From 9d9dc969c451c87b7ad3c84f807db2c2d9109f41 Mon Sep 17 00:00:00 2001 From: larry dial Date: Sat, 23 Aug 2025 13:08:59 -0700 Subject: [PATCH 01/14] adding sparse attention gate --- records/071825_TritonMuon/record.txt | 2848 +++++++++++++++++ .../020630eb-2191-4ba2-9ee4-4cdc94316943.txt | 2802 ++++++++++++++++ .../21e732fb-4c4b-4db9-94bc-9fcd5d59b080.txt | 2802 ++++++++++++++++ .../4518e917-cec2-4c81-9c1a-53b0644c2326.txt | 2802 ++++++++++++++++ .../48b19604-5049-48c9-956c-8ddc4d0781fb.txt | 2802 ++++++++++++++++ .../50524dcb-cf95-4b75-bf89-ba8ff3c5e1af.txt | 2802 ++++++++++++++++ .../53ecb4ef-77ed-4af6-b776-47cd4006614b.txt | 2802 ++++++++++++++++ .../6701af06-6c40-4553-bb04-f501fdd56284.txt | 2802 ++++++++++++++++ .../6df384bb-9c24-46b3-826b-f7c07168c27a.txt | 2802 ++++++++++++++++ records/082325_SparseAttnGate/README.md | 45 + .../a39b1ae8-3a2a-4952-8032-13183b157053.txt | 2802 ++++++++++++++++ .../c6be54c1-12d0-45a3-83cb-41cad0868d15.txt | 2802 ++++++++++++++++ .../ca042caf-b232-4a25-b28f-88e39a2009d3.txt | 2802 ++++++++++++++++ .../d3e1ea3c-521c-4abd-a549-950c698d6cbf.txt | 2802 ++++++++++++++++ .../e8891a98-8bf2-43cc-bac5-728aa53482ce.txt | 2802 ++++++++++++++++ .../eb6d347b-fd4a-4077-a490-436c64f97ce2.txt | 2802 ++++++++++++++++ requirements.txt | 1 + train_gpt.py | 420 ++- 18 files changed, 42471 insertions(+), 71 deletions(-) create mode 100644 records/071825_TritonMuon/record.txt create mode 100644 records/082325_SparseAttnGate/020630eb-2191-4ba2-9ee4-4cdc94316943.txt create mode 100644 records/082325_SparseAttnGate/21e732fb-4c4b-4db9-94bc-9fcd5d59b080.txt create mode 100644 records/082325_SparseAttnGate/4518e917-cec2-4c81-9c1a-53b0644c2326.txt create mode 100644 records/082325_SparseAttnGate/48b19604-5049-48c9-956c-8ddc4d0781fb.txt create mode 100644 records/082325_SparseAttnGate/50524dcb-cf95-4b75-bf89-ba8ff3c5e1af.txt create mode 100644 records/082325_SparseAttnGate/53ecb4ef-77ed-4af6-b776-47cd4006614b.txt create mode 100644 records/082325_SparseAttnGate/6701af06-6c40-4553-bb04-f501fdd56284.txt create mode 100644 records/082325_SparseAttnGate/6df384bb-9c24-46b3-826b-f7c07168c27a.txt create mode 100644 records/082325_SparseAttnGate/README.md create mode 100644 records/082325_SparseAttnGate/a39b1ae8-3a2a-4952-8032-13183b157053.txt create mode 100644 records/082325_SparseAttnGate/c6be54c1-12d0-45a3-83cb-41cad0868d15.txt create mode 100644 records/082325_SparseAttnGate/ca042caf-b232-4a25-b28f-88e39a2009d3.txt create mode 100644 records/082325_SparseAttnGate/d3e1ea3c-521c-4abd-a549-950c698d6cbf.txt create mode 100644 records/082325_SparseAttnGate/e8891a98-8bf2-43cc-bac5-728aa53482ce.txt create mode 100644 records/082325_SparseAttnGate/eb6d347b-fd4a-4077-a490-436c64f97ce2.txt diff --git a/records/071825_TritonMuon/record.txt b/records/071825_TritonMuon/record.txt new file mode 100644 index 000000000..30d539ceb --- /dev/null +++ b/records/071825_TritonMuon/record.txt @@ -0,0 +1,2848 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from 
torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, 
+ BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + 
A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = 
A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
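+        # Sharded step, as implemented below: (1) async reduce_scatter averages
+        # each parameter's gradient across ranks, with each rank owning every
+        # world_size-th parameter of a group; (2) the owning rank applies weight
+        # decay, Nesterov-style momentum, and the Newton-Schulz orthogonalization
+        # to its parameters; (3) async all_gather broadcasts the updated
+        # parameters back to every rank.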
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + 
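+                    # Note: both Adam moments are allocated for this rank's
+                    # 1/world_size slice of the parameter only, so optimizer
+                    # state is sharded across GPUs rather than replicated.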
state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, 
max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 
50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 27.5 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) + + def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(block_masks) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5)) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +# find world_size starting indicies, such that each begins with token 50256 and local_batches don't overlap +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 + boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos + start = boundary_positions[0].item() + starts = [] + for i in range(1, len(boundary_positions)): + end = boundary_positions[i].item() + if end - start >= seq_len: + starts.append(start) # append start once end pos is confirmed + if len(starts) == dist.get_world_size(): + return starts, end - pos + start = end + assert False # increase token_window if necessary + +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): + rank = dist.get_rank() + world_size = dist.get_world_size() + batch_size = seq_len * world_size + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + while True: + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle 
samples up to length seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1750 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +#assert world_size == 8 # this code is designed for 8xH100 +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
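+
+# The global batch is held fixed across world sizes: each rank processes
+# train_seq_len tokens per micro-step, and grad_accum_steps micro-steps are
+# accumulated per optimizer step, so world_size * grad_accum_steps == 8 always.
+# A minimal illustrative sketch (the helper below is ours, unused by the loop):
+def _tokens_per_optimizer_step(seq_len: int) -> int:
+    # e.g. world_size=8 -> grad_accum_steps=1; world_size=2 -> grad_accum_steps=4
+    return world_size * grad_accum_steps * seq_len  # always == 8 * seq_len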
+ +# begin logging +logfile = None +if master_process: + run_id = uuid.uuid4() + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, 
get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Jul 18 15:57:50 2025 
++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 30C P0 132W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 32C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 29C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 29C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 32C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 28C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 111824 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 111825 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111826 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111827 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111828 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111829 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111830 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111831 C /usr/bin/python3 614MiB | +| 1 N/A N/A 111825 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 111826 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 111827 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 111828 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 111829 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 111830 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 111831 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== 
+step:0/1750 val_loss:10.8258 train_time:0ms step_avg:0.04ms +step:1/1750 train_time:151ms step_avg:151.05ms +step:2/1750 train_time:176ms step_avg:87.79ms +step:3/1750 train_time:240ms step_avg:79.89ms +step:4/1750 train_time:329ms step_avg:82.25ms +step:5/1750 train_time:420ms step_avg:83.91ms +step:6/1750 train_time:510ms step_avg:84.94ms +step:7/1750 train_time:601ms step_avg:85.79ms +step:8/1750 train_time:691ms step_avg:86.32ms +step:9/1750 train_time:781ms step_avg:86.76ms +step:10/1750 train_time:871ms step_avg:87.08ms +step:11/1750 train_time:961ms step_avg:87.37ms +step:12/1750 train_time:1053ms step_avg:87.78ms +step:13/1750 train_time:1148ms step_avg:88.27ms +step:14/1750 train_time:1242ms step_avg:88.73ms +step:15/1750 train_time:1334ms step_avg:88.93ms +step:16/1750 train_time:1425ms step_avg:89.09ms +step:17/1750 train_time:1517ms step_avg:89.22ms +step:18/1750 train_time:1608ms step_avg:89.31ms +step:19/1750 train_time:1699ms step_avg:89.40ms +step:20/1750 train_time:1790ms step_avg:89.49ms +step:21/1750 train_time:1880ms step_avg:89.52ms +step:22/1750 train_time:1970ms step_avg:89.56ms +step:23/1750 train_time:2062ms step_avg:89.65ms +step:24/1750 train_time:2156ms step_avg:89.82ms +step:25/1750 train_time:2248ms step_avg:89.92ms +step:26/1750 train_time:2341ms step_avg:90.03ms +step:27/1750 train_time:2432ms step_avg:90.07ms +step:28/1750 train_time:2523ms step_avg:90.10ms +step:29/1750 train_time:2613ms step_avg:90.12ms +step:30/1750 train_time:2705ms step_avg:90.17ms +step:31/1750 train_time:2796ms step_avg:90.19ms +step:32/1750 train_time:2887ms step_avg:90.21ms +step:33/1750 train_time:2978ms step_avg:90.24ms +step:34/1750 train_time:3069ms step_avg:90.27ms +step:35/1750 train_time:3162ms step_avg:90.33ms +step:36/1750 train_time:3254ms step_avg:90.38ms +step:37/1750 train_time:3346ms step_avg:90.42ms +step:38/1750 train_time:3438ms step_avg:90.47ms +step:39/1750 train_time:3529ms step_avg:90.49ms +step:40/1750 train_time:3621ms step_avg:90.52ms +step:41/1750 train_time:3712ms step_avg:90.55ms +step:42/1750 train_time:3803ms step_avg:90.55ms +step:43/1750 train_time:3894ms step_avg:90.56ms +step:44/1750 train_time:3985ms step_avg:90.57ms +step:45/1750 train_time:4077ms step_avg:90.60ms +step:46/1750 train_time:4169ms step_avg:90.63ms +step:47/1750 train_time:4261ms step_avg:90.65ms +step:48/1750 train_time:4352ms step_avg:90.66ms +step:49/1750 train_time:4444ms step_avg:90.70ms +step:50/1750 train_time:4535ms step_avg:90.70ms +step:51/1750 train_time:4628ms step_avg:90.74ms +step:52/1750 train_time:4719ms step_avg:90.75ms +step:53/1750 train_time:4810ms step_avg:90.76ms +step:54/1750 train_time:4902ms step_avg:90.77ms +step:55/1750 train_time:4993ms step_avg:90.78ms +step:56/1750 train_time:5085ms step_avg:90.80ms +step:57/1750 train_time:5176ms step_avg:90.80ms +step:58/1750 train_time:5267ms step_avg:90.82ms +step:59/1750 train_time:5359ms step_avg:90.84ms +step:60/1750 train_time:5451ms step_avg:90.85ms +step:61/1750 train_time:5543ms step_avg:90.86ms +step:62/1750 train_time:5634ms step_avg:90.88ms +step:63/1750 train_time:5726ms step_avg:90.89ms +step:64/1750 train_time:5819ms step_avg:90.91ms +step:65/1750 train_time:5910ms step_avg:90.93ms +step:66/1750 train_time:6002ms step_avg:90.94ms +step:67/1750 train_time:6093ms step_avg:90.94ms +step:68/1750 train_time:6185ms step_avg:90.95ms +step:69/1750 train_time:6276ms step_avg:90.96ms +step:70/1750 train_time:6368ms step_avg:90.97ms +step:71/1750 train_time:6459ms step_avg:90.98ms +step:72/1750 train_time:6551ms 
step_avg:90.99ms +step:73/1750 train_time:6643ms step_avg:91.00ms +step:74/1750 train_time:6734ms step_avg:91.00ms +step:75/1750 train_time:6826ms step_avg:91.02ms +step:76/1750 train_time:6918ms step_avg:91.02ms +step:77/1750 train_time:7009ms step_avg:91.03ms +step:78/1750 train_time:7101ms step_avg:91.04ms +step:79/1750 train_time:7192ms step_avg:91.04ms +step:80/1750 train_time:7284ms step_avg:91.05ms +step:81/1750 train_time:7375ms step_avg:91.05ms +step:82/1750 train_time:7466ms step_avg:91.05ms +step:83/1750 train_time:7558ms step_avg:91.06ms +step:84/1750 train_time:7650ms step_avg:91.07ms +step:85/1750 train_time:7742ms step_avg:91.08ms +step:86/1750 train_time:7833ms step_avg:91.08ms +step:87/1750 train_time:7925ms step_avg:91.09ms +step:88/1750 train_time:8016ms step_avg:91.09ms +step:89/1750 train_time:8108ms step_avg:91.10ms +step:90/1750 train_time:8200ms step_avg:91.11ms +step:91/1750 train_time:8290ms step_avg:91.10ms +step:92/1750 train_time:8382ms step_avg:91.11ms +step:93/1750 train_time:8474ms step_avg:91.12ms +step:94/1750 train_time:8566ms step_avg:91.12ms +step:95/1750 train_time:8657ms step_avg:91.13ms +step:96/1750 train_time:8749ms step_avg:91.14ms +step:97/1750 train_time:8841ms step_avg:91.14ms +step:98/1750 train_time:8931ms step_avg:91.14ms +step:99/1750 train_time:9023ms step_avg:91.14ms +step:100/1750 train_time:9114ms step_avg:91.14ms +step:101/1750 train_time:9207ms step_avg:91.15ms +step:102/1750 train_time:9298ms step_avg:91.15ms +step:103/1750 train_time:9389ms step_avg:91.16ms +step:104/1750 train_time:9481ms step_avg:91.16ms +step:105/1750 train_time:9573ms step_avg:91.17ms +step:106/1750 train_time:9665ms step_avg:91.17ms +step:107/1750 train_time:9756ms step_avg:91.18ms +step:108/1750 train_time:9847ms step_avg:91.18ms +step:109/1750 train_time:9939ms step_avg:91.18ms +step:110/1750 train_time:10031ms step_avg:91.19ms +step:111/1750 train_time:10123ms step_avg:91.20ms +step:112/1750 train_time:10215ms step_avg:91.20ms +step:113/1750 train_time:10307ms step_avg:91.21ms +step:114/1750 train_time:10398ms step_avg:91.21ms +step:115/1750 train_time:10490ms step_avg:91.22ms +step:116/1750 train_time:10581ms step_avg:91.22ms +step:117/1750 train_time:10673ms step_avg:91.22ms +step:118/1750 train_time:10765ms step_avg:91.23ms +step:119/1750 train_time:10856ms step_avg:91.23ms +step:120/1750 train_time:10947ms step_avg:91.23ms +step:121/1750 train_time:11039ms step_avg:91.23ms +step:122/1750 train_time:11131ms step_avg:91.24ms +step:123/1750 train_time:11224ms step_avg:91.25ms +step:124/1750 train_time:11315ms step_avg:91.25ms +step:125/1750 train_time:11407ms step_avg:91.26ms +step:125/1750 val_loss:4.6317 train_time:11502ms step_avg:92.02ms +step:126/1750 train_time:11528ms step_avg:91.49ms +step:127/1750 train_time:11599ms step_avg:91.33ms +step:128/1750 train_time:11696ms step_avg:91.38ms +step:129/1750 train_time:11790ms step_avg:91.39ms +step:130/1750 train_time:11881ms step_avg:91.39ms +step:131/1750 train_time:11972ms step_avg:91.39ms +step:132/1750 train_time:12062ms step_avg:91.38ms +step:133/1750 train_time:12153ms step_avg:91.38ms +step:134/1750 train_time:12244ms step_avg:91.38ms +step:135/1750 train_time:12336ms step_avg:91.37ms +step:136/1750 train_time:12427ms step_avg:91.37ms +step:137/1750 train_time:12520ms step_avg:91.39ms +step:138/1750 train_time:12613ms step_avg:91.40ms +step:139/1750 train_time:12708ms step_avg:91.42ms +step:140/1750 train_time:12802ms step_avg:91.44ms +step:141/1750 train_time:12894ms step_avg:91.45ms +step:142/1750 
train_time:12986ms step_avg:91.45ms
+step:143/1750 train_time:13078ms step_avg:91.45ms
[... per-step timing lines for steps 144-249 elided ...]
+step:250/1750 train_time:22923ms step_avg:91.69ms
+step:250/1750 val_loss:4.0972 train_time:23018ms step_avg:92.07ms
+step:251/1750 train_time:23041ms step_avg:91.80ms
[... per-step timing lines for steps 252-374 elided ...]
+step:375/1750 train_time:34477ms step_avg:91.94ms
+step:375/1750 val_loss:3.8911 train_time:34573ms step_avg:92.19ms
+step:376/1750 train_time:34597ms step_avg:92.01ms
[... per-step timing lines for steps 377-499 elided ...]
+step:500/1750 train_time:46251ms step_avg:92.50ms
+step:500/1750 val_loss:3.7412 train_time:46349ms step_avg:92.70ms
+step:501/1750 train_time:46373ms step_avg:92.56ms
[... per-step timing lines for steps 502-624 elided ...]
+step:625/1750 train_time:58086ms step_avg:92.94ms
+step:625/1750 val_loss:3.6537 train_time:58184ms step_avg:93.09ms
+step:626/1750 train_time:58208ms step_avg:92.98ms
[... per-step timing lines for steps 627-749 elided ...]
+step:750/1750 train_time:70057ms step_avg:93.41ms
+step:750/1750 val_loss:3.5904 train_time:70156ms step_avg:93.54ms
+step:751/1750 train_time:70181ms step_avg:93.45ms
[... per-step timing lines for steps 752-874 elided ...]
+step:875/1750 train_time:82091ms step_avg:93.82ms
+step:875/1750 val_loss:3.5431 train_time:82190ms step_avg:93.93ms
+step:876/1750 train_time:82213ms step_avg:93.85ms
[... per-step timing lines for steps 877-999 elided ...]
+step:1000/1750 train_time:94258ms step_avg:94.26ms
+step:1000/1750 val_loss:3.5036 train_time:94359ms step_avg:94.36ms
+step:1001/1750 train_time:94383ms step_avg:94.29ms
[... per-step timing lines for steps 1002-1124 elided ...]
+step:1125/1750 train_time:106484ms step_avg:94.65ms
+step:1125/1750 val_loss:3.4519 train_time:106584ms step_avg:94.74ms
+step:1126/1750 train_time:106609ms step_avg:94.68ms
[... per-step timing lines for steps 1127-1249 elided ...]
+step:1250/1750 train_time:118836ms step_avg:95.07ms
+step:1250/1750 val_loss:3.4067 train_time:118938ms step_avg:95.15ms
+step:1251/1750 train_time:118962ms step_avg:95.09ms
[... per-step timing lines for steps 1252-1374 elided ...]
+step:1375/1750 train_time:131248ms step_avg:95.45ms
+step:1375/1750 val_loss:3.3669 train_time:131351ms step_avg:95.53ms
+step:1376/1750 train_time:131375ms step_avg:95.48ms
[... per-step timing lines for steps 1377-1499 elided ...]
+step:1500/1750 train_time:143730ms step_avg:95.82ms
+step:1500/1750 val_loss:3.3314 train_time:143832ms step_avg:95.89ms
+step:1501/1750 train_time:143856ms
step_avg:95.84ms +step:1502/1750 train_time:143941ms step_avg:95.83ms +step:1503/1750 train_time:144043ms step_avg:95.84ms +step:1504/1750 train_time:144142ms step_avg:95.84ms +step:1505/1750 train_time:144241ms step_avg:95.84ms +step:1506/1750 train_time:144340ms step_avg:95.84ms +step:1507/1750 train_time:144438ms step_avg:95.84ms +step:1508/1750 train_time:144537ms step_avg:95.85ms +step:1509/1750 train_time:144635ms step_avg:95.85ms +step:1510/1750 train_time:144734ms step_avg:95.85ms +step:1511/1750 train_time:144835ms step_avg:95.85ms +step:1512/1750 train_time:144937ms step_avg:95.86ms +step:1513/1750 train_time:145038ms step_avg:95.86ms +step:1514/1750 train_time:145139ms step_avg:95.86ms +step:1515/1750 train_time:145243ms step_avg:95.87ms +step:1516/1750 train_time:145342ms step_avg:95.87ms +step:1517/1750 train_time:145441ms step_avg:95.87ms +step:1518/1750 train_time:145540ms step_avg:95.88ms +step:1519/1750 train_time:145640ms step_avg:95.88ms +step:1520/1750 train_time:145741ms step_avg:95.88ms +step:1521/1750 train_time:145842ms step_avg:95.89ms +step:1522/1750 train_time:145943ms step_avg:95.89ms +step:1523/1750 train_time:146042ms step_avg:95.89ms +step:1524/1750 train_time:146143ms step_avg:95.89ms +step:1525/1750 train_time:146243ms step_avg:95.90ms +step:1526/1750 train_time:146342ms step_avg:95.90ms +step:1527/1750 train_time:146442ms step_avg:95.90ms +step:1528/1750 train_time:146544ms step_avg:95.91ms +step:1529/1750 train_time:146644ms step_avg:95.91ms +step:1530/1750 train_time:146744ms step_avg:95.91ms +step:1531/1750 train_time:146844ms step_avg:95.91ms +step:1532/1750 train_time:146945ms step_avg:95.92ms +step:1533/1750 train_time:147044ms step_avg:95.92ms +step:1534/1750 train_time:147144ms step_avg:95.92ms +step:1535/1750 train_time:147244ms step_avg:95.92ms +step:1536/1750 train_time:147343ms step_avg:95.93ms +step:1537/1750 train_time:147442ms step_avg:95.93ms +step:1538/1750 train_time:147543ms step_avg:95.93ms +step:1539/1750 train_time:147642ms step_avg:95.93ms +step:1540/1750 train_time:147742ms step_avg:95.94ms +step:1541/1750 train_time:147844ms step_avg:95.94ms +step:1542/1750 train_time:147947ms step_avg:95.95ms +step:1543/1750 train_time:148049ms step_avg:95.95ms +step:1544/1750 train_time:148148ms step_avg:95.95ms +step:1545/1750 train_time:148249ms step_avg:95.95ms +step:1546/1750 train_time:148348ms step_avg:95.96ms +step:1547/1750 train_time:148452ms step_avg:95.96ms +step:1548/1750 train_time:148553ms step_avg:95.96ms +step:1549/1750 train_time:148654ms step_avg:95.97ms +step:1550/1750 train_time:148755ms step_avg:95.97ms +step:1551/1750 train_time:148855ms step_avg:95.97ms +step:1552/1750 train_time:148957ms step_avg:95.98ms +step:1553/1750 train_time:149056ms step_avg:95.98ms +step:1554/1750 train_time:149155ms step_avg:95.98ms +step:1555/1750 train_time:149256ms step_avg:95.98ms +step:1556/1750 train_time:149357ms step_avg:95.99ms +step:1557/1750 train_time:149457ms step_avg:95.99ms +step:1558/1750 train_time:149557ms step_avg:95.99ms +step:1559/1750 train_time:149657ms step_avg:96.00ms +step:1560/1750 train_time:149757ms step_avg:96.00ms +step:1561/1750 train_time:149858ms step_avg:96.00ms +step:1562/1750 train_time:149959ms step_avg:96.00ms +step:1563/1750 train_time:150062ms step_avg:96.01ms +step:1564/1750 train_time:150162ms step_avg:96.01ms +step:1565/1750 train_time:150262ms step_avg:96.01ms +step:1566/1750 train_time:150362ms step_avg:96.02ms +step:1567/1750 train_time:150461ms step_avg:96.02ms +step:1568/1750 train_time:150561ms 
step_avg:96.02ms +step:1569/1750 train_time:150662ms step_avg:96.02ms +step:1570/1750 train_time:150764ms step_avg:96.03ms +step:1571/1750 train_time:150864ms step_avg:96.03ms +step:1572/1750 train_time:150964ms step_avg:96.03ms +step:1573/1750 train_time:151065ms step_avg:96.04ms +step:1574/1750 train_time:151165ms step_avg:96.04ms +step:1575/1750 train_time:151266ms step_avg:96.04ms +step:1576/1750 train_time:151366ms step_avg:96.04ms +step:1577/1750 train_time:151468ms step_avg:96.05ms +step:1578/1750 train_time:151568ms step_avg:96.05ms +step:1579/1750 train_time:151668ms step_avg:96.05ms +step:1580/1750 train_time:151771ms step_avg:96.06ms +step:1581/1750 train_time:151871ms step_avg:96.06ms +step:1582/1750 train_time:151970ms step_avg:96.06ms +step:1583/1750 train_time:152072ms step_avg:96.07ms +step:1584/1750 train_time:152173ms step_avg:96.07ms +step:1585/1750 train_time:152272ms step_avg:96.07ms +step:1586/1750 train_time:152372ms step_avg:96.07ms +step:1587/1750 train_time:152472ms step_avg:96.08ms +step:1588/1750 train_time:152573ms step_avg:96.08ms +step:1589/1750 train_time:152672ms step_avg:96.08ms +step:1590/1750 train_time:152773ms step_avg:96.08ms +step:1591/1750 train_time:152873ms step_avg:96.09ms +step:1592/1750 train_time:152975ms step_avg:96.09ms +step:1593/1750 train_time:153074ms step_avg:96.09ms +step:1594/1750 train_time:153178ms step_avg:96.10ms +step:1595/1750 train_time:153277ms step_avg:96.10ms +step:1596/1750 train_time:153377ms step_avg:96.10ms +step:1597/1750 train_time:153476ms step_avg:96.10ms +step:1598/1750 train_time:153579ms step_avg:96.11ms +step:1599/1750 train_time:153679ms step_avg:96.11ms +step:1600/1750 train_time:153781ms step_avg:96.11ms +step:1601/1750 train_time:153882ms step_avg:96.12ms +step:1602/1750 train_time:153982ms step_avg:96.12ms +step:1603/1750 train_time:154082ms step_avg:96.12ms +step:1604/1750 train_time:154182ms step_avg:96.12ms +step:1605/1750 train_time:154282ms step_avg:96.13ms +step:1606/1750 train_time:154381ms step_avg:96.13ms +step:1607/1750 train_time:154482ms step_avg:96.13ms +step:1608/1750 train_time:154581ms step_avg:96.13ms +step:1609/1750 train_time:154681ms step_avg:96.13ms +step:1610/1750 train_time:154782ms step_avg:96.14ms +step:1611/1750 train_time:154883ms step_avg:96.14ms +step:1612/1750 train_time:154984ms step_avg:96.14ms +step:1613/1750 train_time:155084ms step_avg:96.15ms +step:1614/1750 train_time:155184ms step_avg:96.15ms +step:1615/1750 train_time:155285ms step_avg:96.15ms +step:1616/1750 train_time:155385ms step_avg:96.15ms +step:1617/1750 train_time:155485ms step_avg:96.16ms +step:1618/1750 train_time:155586ms step_avg:96.16ms +step:1619/1750 train_time:155686ms step_avg:96.16ms +step:1620/1750 train_time:155786ms step_avg:96.16ms +step:1621/1750 train_time:155887ms step_avg:96.17ms +step:1622/1750 train_time:155986ms step_avg:96.17ms +step:1623/1750 train_time:156086ms step_avg:96.17ms +step:1624/1750 train_time:156187ms step_avg:96.17ms +step:1625/1750 train_time:156290ms step_avg:96.18ms +step:1625/1750 val_loss:3.3014 train_time:156394ms step_avg:96.24ms +step:1626/1750 train_time:156420ms step_avg:96.20ms +step:1627/1750 train_time:156503ms step_avg:96.19ms +step:1628/1750 train_time:156603ms step_avg:96.19ms +step:1629/1750 train_time:156703ms step_avg:96.20ms +step:1630/1750 train_time:156802ms step_avg:96.20ms +step:1631/1750 train_time:156902ms step_avg:96.20ms +step:1632/1750 train_time:157002ms step_avg:96.20ms +step:1633/1750 train_time:157101ms step_avg:96.20ms +step:1634/1750 
train_time:157202ms step_avg:96.21ms +step:1635/1750 train_time:157302ms step_avg:96.21ms +step:1636/1750 train_time:157404ms step_avg:96.21ms +step:1637/1750 train_time:157507ms step_avg:96.22ms +step:1638/1750 train_time:157608ms step_avg:96.22ms +step:1639/1750 train_time:157709ms step_avg:96.22ms +step:1640/1750 train_time:157808ms step_avg:96.22ms +step:1641/1750 train_time:157907ms step_avg:96.23ms +step:1642/1750 train_time:158006ms step_avg:96.23ms +step:1643/1750 train_time:158105ms step_avg:96.23ms +step:1644/1750 train_time:158205ms step_avg:96.23ms +step:1645/1750 train_time:158305ms step_avg:96.23ms +step:1646/1750 train_time:158406ms step_avg:96.24ms +step:1647/1750 train_time:158509ms step_avg:96.24ms +step:1648/1750 train_time:158611ms step_avg:96.24ms +step:1649/1750 train_time:158712ms step_avg:96.25ms +step:1650/1750 train_time:158811ms step_avg:96.25ms +step:1651/1750 train_time:158910ms step_avg:96.25ms +step:1652/1750 train_time:159010ms step_avg:96.25ms +step:1653/1750 train_time:159111ms step_avg:96.26ms +step:1654/1750 train_time:159211ms step_avg:96.26ms +step:1655/1750 train_time:159311ms step_avg:96.26ms +step:1656/1750 train_time:159412ms step_avg:96.26ms +step:1657/1750 train_time:159512ms step_avg:96.27ms +step:1658/1750 train_time:159614ms step_avg:96.27ms +step:1659/1750 train_time:159719ms step_avg:96.27ms +step:1660/1750 train_time:159820ms step_avg:96.28ms +step:1661/1750 train_time:159921ms step_avg:96.28ms +step:1662/1750 train_time:160021ms step_avg:96.28ms +step:1663/1750 train_time:160122ms step_avg:96.29ms +step:1664/1750 train_time:160222ms step_avg:96.29ms +step:1665/1750 train_time:160323ms step_avg:96.29ms +step:1666/1750 train_time:160425ms step_avg:96.29ms +step:1667/1750 train_time:160525ms step_avg:96.30ms +step:1668/1750 train_time:160628ms step_avg:96.30ms +step:1669/1750 train_time:160729ms step_avg:96.30ms +step:1670/1750 train_time:160830ms step_avg:96.31ms +step:1671/1750 train_time:160930ms step_avg:96.31ms +step:1672/1750 train_time:161031ms step_avg:96.31ms +step:1673/1750 train_time:161130ms step_avg:96.31ms +step:1674/1750 train_time:161232ms step_avg:96.32ms +step:1675/1750 train_time:161334ms step_avg:96.32ms +step:1676/1750 train_time:161434ms step_avg:96.32ms +step:1677/1750 train_time:161535ms step_avg:96.32ms +step:1678/1750 train_time:161635ms step_avg:96.33ms +step:1679/1750 train_time:161736ms step_avg:96.33ms +step:1680/1750 train_time:161836ms step_avg:96.33ms +step:1681/1750 train_time:161938ms step_avg:96.33ms +step:1682/1750 train_time:162042ms step_avg:96.34ms +step:1683/1750 train_time:162142ms step_avg:96.34ms +step:1684/1750 train_time:162244ms step_avg:96.34ms +step:1685/1750 train_time:162343ms step_avg:96.35ms +step:1686/1750 train_time:162443ms step_avg:96.35ms +step:1687/1750 train_time:162543ms step_avg:96.35ms +step:1688/1750 train_time:162644ms step_avg:96.35ms +step:1689/1750 train_time:162744ms step_avg:96.36ms +step:1690/1750 train_time:162846ms step_avg:96.36ms +step:1691/1750 train_time:162948ms step_avg:96.36ms +step:1692/1750 train_time:163048ms step_avg:96.36ms +step:1693/1750 train_time:163149ms step_avg:96.37ms +step:1694/1750 train_time:163250ms step_avg:96.37ms +step:1695/1750 train_time:163353ms step_avg:96.37ms +step:1696/1750 train_time:163453ms step_avg:96.38ms +step:1697/1750 train_time:163556ms step_avg:96.38ms +step:1698/1750 train_time:163656ms step_avg:96.38ms +step:1699/1750 train_time:163757ms step_avg:96.38ms +step:1700/1750 train_time:163858ms step_avg:96.39ms +step:1701/1750 
train_time:163958ms step_avg:96.39ms +step:1702/1750 train_time:164063ms step_avg:96.39ms +step:1703/1750 train_time:164164ms step_avg:96.40ms +step:1704/1750 train_time:164266ms step_avg:96.40ms +step:1705/1750 train_time:164366ms step_avg:96.40ms +step:1706/1750 train_time:164466ms step_avg:96.40ms +step:1707/1750 train_time:164569ms step_avg:96.41ms +step:1708/1750 train_time:164672ms step_avg:96.41ms +step:1709/1750 train_time:164774ms step_avg:96.42ms +step:1710/1750 train_time:164875ms step_avg:96.42ms +step:1711/1750 train_time:164976ms step_avg:96.42ms +step:1712/1750 train_time:165077ms step_avg:96.42ms +step:1713/1750 train_time:165178ms step_avg:96.43ms +step:1714/1750 train_time:165277ms step_avg:96.43ms +step:1715/1750 train_time:165381ms step_avg:96.43ms +step:1716/1750 train_time:165484ms step_avg:96.44ms +step:1717/1750 train_time:165585ms step_avg:96.44ms +step:1718/1750 train_time:165686ms step_avg:96.44ms +step:1719/1750 train_time:165790ms step_avg:96.45ms +step:1720/1750 train_time:165891ms step_avg:96.45ms +step:1721/1750 train_time:165993ms step_avg:96.45ms +step:1722/1750 train_time:166094ms step_avg:96.45ms +step:1723/1750 train_time:166195ms step_avg:96.46ms +step:1724/1750 train_time:166297ms step_avg:96.46ms +step:1725/1750 train_time:166398ms step_avg:96.46ms +step:1726/1750 train_time:166500ms step_avg:96.47ms +step:1727/1750 train_time:166601ms step_avg:96.47ms +step:1728/1750 train_time:166705ms step_avg:96.47ms +step:1729/1750 train_time:166809ms step_avg:96.48ms +step:1730/1750 train_time:166908ms step_avg:96.48ms +step:1731/1750 train_time:167009ms step_avg:96.48ms +step:1732/1750 train_time:167110ms step_avg:96.48ms +step:1733/1750 train_time:167212ms step_avg:96.49ms +step:1734/1750 train_time:167314ms step_avg:96.49ms +step:1735/1750 train_time:167415ms step_avg:96.49ms +step:1736/1750 train_time:167516ms step_avg:96.50ms +step:1737/1750 train_time:167617ms step_avg:96.50ms +step:1738/1750 train_time:167718ms step_avg:96.50ms +step:1739/1750 train_time:167820ms step_avg:96.50ms +step:1740/1750 train_time:167921ms step_avg:96.51ms +step:1741/1750 train_time:168026ms step_avg:96.51ms +step:1742/1750 train_time:168126ms step_avg:96.51ms +step:1743/1750 train_time:168226ms step_avg:96.52ms +step:1744/1750 train_time:168327ms step_avg:96.52ms +step:1745/1750 train_time:168427ms step_avg:96.52ms +step:1746/1750 train_time:168528ms step_avg:96.52ms +step:1747/1750 train_time:168630ms step_avg:96.53ms +step:1748/1750 train_time:168733ms step_avg:96.53ms +step:1749/1750 train_time:168834ms step_avg:96.53ms +step:1750/1750 train_time:168935ms step_avg:96.53ms +step:1750/1750 val_loss:3.2780 train_time:169042ms step_avg:96.60ms +peak memory allocated: 34460 MiB reserved: 48914 MiB \ No newline at end of file diff --git a/records/082325_SparseAttnGate/020630eb-2191-4ba2-9ee4-4cdc94316943.txt b/records/082325_SparseAttnGate/020630eb-2191-4ba2-9ee4-4cdc94316943.txt new file mode 100644 index 000000000..70325646a --- /dev/null +++ b/records/082325_SparseAttnGate/020630eb-2191-4ba2-9ee4-4cdc94316943.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch 
import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + 
grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * 
c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + 
accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
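+
+    For intuition: the fused Triton path above (ns_line_1 / ns_line_2 / ns_line_3)
+    computes the same quintic Newton-Schulz iteration as the plain-PyTorch sketch
+    below (illustrative reference only; the helper name is ours, and the fused
+    kernels exist purely for speed):
+
+        def newton_schulz_reference(G: Tensor, steps: int = 5) -> Tensor:
+            a, b, c = (3.4445, -4.7750, 2.0315)  # quintic iteration coefficients
+            X = G.bfloat16()
+            if G.size(-2) > G.size(-1):
+                X = X.mT  # operate on the wide orientation
+            # Frobenius normalization guarantees the spectral norm is at most 1
+            X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+            for _ in range(steps):
+                A = X @ X.mT           # ns_line_1
+                B = b * A + c * A @ A  # ns_line_2 (A is symmetric)
+                X = a * X + B @ X      # ns_line_3
+            if G.size(-2) > G.size(-1):
+                X = X.mT
+            return X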
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
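+        # note: lr_mul / wd_mul are plain attributes on the parameters (not
+        # registered state); Muon and DistAdam read them back via
+        # getattr(p, "lr_mul", 1.0) and getattr(p, "wd_mul", 1.0) to scale the
+        # per-parameter learning rate and weight decay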
+ self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256]) + # docs = increments.cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) + + def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(block_masks) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 + boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos + start = boundary_positions[0].item() + starts = [] + for i in range(1, len(boundary_positions)): + end = boundary_positions[i].item() + if end - start >= seq_len: + starts.append(start) # append start once end pos is confirmed + if len(starts) == dist.get_world_size(): + return starts, end - pos + start = end + assert False # increase token_window if necessary + +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): + rank = dist.get_rank() + world_size = dist.get_world_size() + batch_size = seq_len * world_size + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + while True: + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length 
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:51:59 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 |
+| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 |
+| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 |
+| N/A 38C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 |
+| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 |
+| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 |
+| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 |
+| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 |
+| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=========================================================================================|
+| 0 N/A N/A 321540 C /usr/bin/python3 1510MiB |
+| 0 N/A N/A 321541 C /usr/bin/python3 614MiB |
+| 0 N/A N/A 321542 C /usr/bin/python3 614MiB |
+| 0 N/A N/A 321543 C /usr/bin/python3 614MiB |
+| 0 N/A N/A 321544 C /usr/bin/python3 614MiB |
+| 0 N/A N/A 321545 C /usr/bin/python3 614MiB |
+| 0 N/A N/A 321546 C /usr/bin/python3 614MiB |
+| 0 N/A N/A 321547 C /usr/bin/python3 614MiB |
+| 1 N/A N/A 321541 C /usr/bin/python3 1510MiB |
+| 2 N/A N/A 321542 C /usr/bin/python3 1510MiB |
+| 3 N/A N/A 321543 C /usr/bin/python3 1510MiB |
+| 4 N/A N/A 321544 C /usr/bin/python3 1510MiB |
+| 5 N/A N/A 321545 C /usr/bin/python3 1510MiB |
+| 6 N/A N/A 321546 C /usr/bin/python3 1510MiB |
+| 7 N/A N/A 321547 C /usr/bin/python3 1510MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms
+step:1/1695 train_time:153ms step_avg:153.49ms
+step:2/1695 train_time:179ms step_avg:89.47ms
+step:3/1695 train_time:250ms step_avg:83.25ms
+step:4/1695 train_time:342ms step_avg:85.39ms
+step:5/1695 train_time:435ms step_avg:86.93ms
+step:6/1695 train_time:527ms step_avg:87.77ms
+step:7/1695 train_time:620ms step_avg:88.62ms
+step:8/1695
train_time:713ms step_avg:89.16ms +step:9/1695 train_time:806ms step_avg:89.56ms +step:10/1695 train_time:899ms step_avg:89.92ms +step:11/1695 train_time:992ms step_avg:90.20ms +step:12/1695 train_time:1086ms step_avg:90.53ms +step:13/1695 train_time:1183ms step_avg:91.00ms +step:14/1695 train_time:1278ms step_avg:91.26ms +step:15/1695 train_time:1371ms step_avg:91.40ms +step:16/1695 train_time:1465ms step_avg:91.55ms +step:17/1695 train_time:1558ms step_avg:91.67ms +step:18/1695 train_time:1651ms step_avg:91.74ms +step:19/1695 train_time:1744ms step_avg:91.81ms +step:20/1695 train_time:1839ms step_avg:91.97ms +step:21/1695 train_time:1934ms step_avg:92.08ms +step:22/1695 train_time:2027ms step_avg:92.15ms +step:23/1695 train_time:2122ms step_avg:92.25ms +step:24/1695 train_time:2216ms step_avg:92.35ms +step:25/1695 train_time:2310ms step_avg:92.41ms +step:26/1695 train_time:2404ms step_avg:92.47ms +step:27/1695 train_time:2498ms step_avg:92.50ms +step:28/1695 train_time:2591ms step_avg:92.55ms +step:29/1695 train_time:2684ms step_avg:92.57ms +step:30/1695 train_time:2778ms step_avg:92.59ms +step:31/1695 train_time:2872ms step_avg:92.63ms +step:32/1695 train_time:2966ms step_avg:92.68ms +step:33/1695 train_time:3060ms step_avg:92.73ms +step:34/1695 train_time:3154ms step_avg:92.78ms +step:35/1695 train_time:3247ms step_avg:92.79ms +step:36/1695 train_time:3342ms step_avg:92.84ms +step:37/1695 train_time:3437ms step_avg:92.89ms +step:38/1695 train_time:3531ms step_avg:92.91ms +step:39/1695 train_time:3624ms step_avg:92.93ms +step:40/1695 train_time:3719ms step_avg:92.96ms +step:41/1695 train_time:3813ms step_avg:93.01ms +step:42/1695 train_time:3907ms step_avg:93.03ms +step:43/1695 train_time:4001ms step_avg:93.05ms +step:44/1695 train_time:4095ms step_avg:93.06ms +step:45/1695 train_time:4188ms step_avg:93.06ms +step:46/1695 train_time:4281ms step_avg:93.06ms +step:47/1695 train_time:4374ms step_avg:93.07ms +step:48/1695 train_time:4469ms step_avg:93.11ms +step:49/1695 train_time:4563ms step_avg:93.12ms +step:50/1695 train_time:4658ms step_avg:93.16ms +step:51/1695 train_time:4751ms step_avg:93.16ms +step:52/1695 train_time:4845ms step_avg:93.17ms +step:53/1695 train_time:4939ms step_avg:93.20ms +step:54/1695 train_time:5034ms step_avg:93.22ms +step:55/1695 train_time:5127ms step_avg:93.22ms +step:56/1695 train_time:5220ms step_avg:93.22ms +step:57/1695 train_time:5314ms step_avg:93.24ms +step:58/1695 train_time:5409ms step_avg:93.25ms +step:59/1695 train_time:5502ms step_avg:93.26ms +step:60/1695 train_time:5596ms step_avg:93.27ms +step:61/1695 train_time:5690ms step_avg:93.27ms +step:62/1695 train_time:5783ms step_avg:93.28ms +step:63/1695 train_time:5877ms step_avg:93.28ms +step:64/1695 train_time:5970ms step_avg:93.29ms +step:65/1695 train_time:6065ms step_avg:93.30ms +step:66/1695 train_time:6159ms step_avg:93.31ms +step:67/1695 train_time:6252ms step_avg:93.31ms +step:68/1695 train_time:6346ms step_avg:93.32ms +step:69/1695 train_time:6440ms step_avg:93.34ms +step:70/1695 train_time:6535ms step_avg:93.36ms +step:71/1695 train_time:6629ms step_avg:93.36ms +step:72/1695 train_time:6724ms step_avg:93.38ms +step:73/1695 train_time:6818ms step_avg:93.39ms +step:74/1695 train_time:6912ms step_avg:93.40ms +step:75/1695 train_time:7005ms step_avg:93.40ms +step:76/1695 train_time:7099ms step_avg:93.41ms +step:77/1695 train_time:7193ms step_avg:93.42ms +step:78/1695 train_time:7287ms step_avg:93.42ms +step:79/1695 train_time:7380ms step_avg:93.42ms +step:80/1695 train_time:7475ms 
step_avg:93.44ms +step:81/1695 train_time:7568ms step_avg:93.44ms +step:82/1695 train_time:7662ms step_avg:93.44ms +step:83/1695 train_time:7756ms step_avg:93.45ms +step:84/1695 train_time:7850ms step_avg:93.45ms +step:85/1695 train_time:7943ms step_avg:93.45ms +step:86/1695 train_time:8037ms step_avg:93.46ms +step:87/1695 train_time:8130ms step_avg:93.45ms +step:88/1695 train_time:8224ms step_avg:93.46ms +step:89/1695 train_time:8319ms step_avg:93.47ms +step:90/1695 train_time:8413ms step_avg:93.48ms +step:91/1695 train_time:8507ms step_avg:93.48ms +step:92/1695 train_time:8600ms step_avg:93.48ms +step:93/1695 train_time:8694ms step_avg:93.48ms +step:94/1695 train_time:8787ms step_avg:93.48ms +step:95/1695 train_time:8880ms step_avg:93.47ms +step:96/1695 train_time:8974ms step_avg:93.48ms +step:97/1695 train_time:9068ms step_avg:93.48ms +step:98/1695 train_time:9161ms step_avg:93.48ms +step:99/1695 train_time:9254ms step_avg:93.48ms +step:100/1695 train_time:9348ms step_avg:93.48ms +step:101/1695 train_time:9442ms step_avg:93.48ms +step:102/1695 train_time:9536ms step_avg:93.49ms +step:103/1695 train_time:9630ms step_avg:93.49ms +step:104/1695 train_time:9723ms step_avg:93.49ms +step:105/1695 train_time:9817ms step_avg:93.49ms +step:106/1695 train_time:9910ms step_avg:93.49ms +step:107/1695 train_time:10004ms step_avg:93.49ms +step:108/1695 train_time:10098ms step_avg:93.50ms +step:109/1695 train_time:10191ms step_avg:93.49ms +step:110/1695 train_time:10284ms step_avg:93.49ms +step:111/1695 train_time:10378ms step_avg:93.50ms +step:112/1695 train_time:10471ms step_avg:93.49ms +step:113/1695 train_time:10565ms step_avg:93.50ms +step:114/1695 train_time:10658ms step_avg:93.49ms +step:115/1695 train_time:10752ms step_avg:93.49ms +step:116/1695 train_time:10845ms step_avg:93.49ms +step:117/1695 train_time:10940ms step_avg:93.50ms +step:118/1695 train_time:11035ms step_avg:93.51ms +step:119/1695 train_time:11128ms step_avg:93.52ms +step:120/1695 train_time:11222ms step_avg:93.51ms +step:121/1695 train_time:11315ms step_avg:93.51ms +step:122/1695 train_time:11409ms step_avg:93.51ms +step:123/1695 train_time:11502ms step_avg:93.51ms +step:124/1695 train_time:11596ms step_avg:93.51ms +step:125/1695 train_time:11689ms step_avg:93.51ms +step:125/1695 val_loss:4.6029 train_time:11781ms step_avg:94.25ms +step:126/1695 train_time:11809ms step_avg:93.72ms +step:127/1695 train_time:11886ms step_avg:93.59ms +step:128/1695 train_time:11989ms step_avg:93.67ms +step:129/1695 train_time:12085ms step_avg:93.68ms +step:130/1695 train_time:12179ms step_avg:93.68ms +step:131/1695 train_time:12272ms step_avg:93.68ms +step:132/1695 train_time:12365ms step_avg:93.68ms +step:133/1695 train_time:12459ms step_avg:93.67ms +step:134/1695 train_time:12552ms step_avg:93.67ms +step:135/1695 train_time:12646ms step_avg:93.67ms +step:136/1695 train_time:12739ms step_avg:93.67ms +step:137/1695 train_time:12833ms step_avg:93.67ms +step:138/1695 train_time:12931ms step_avg:93.70ms +step:139/1695 train_time:13028ms step_avg:93.73ms +step:140/1695 train_time:13124ms step_avg:93.74ms +step:141/1695 train_time:13218ms step_avg:93.74ms +step:142/1695 train_time:13311ms step_avg:93.74ms +step:143/1695 train_time:13405ms step_avg:93.74ms +step:144/1695 train_time:13498ms step_avg:93.74ms +step:145/1695 train_time:13592ms step_avg:93.74ms +step:146/1695 train_time:13685ms step_avg:93.73ms +step:147/1695 train_time:13779ms step_avg:93.73ms +step:148/1695 train_time:13874ms step_avg:93.74ms +step:149/1695 train_time:13968ms 
step_avg:93.74ms +step:150/1695 train_time:14063ms step_avg:93.75ms +step:151/1695 train_time:14157ms step_avg:93.76ms +step:152/1695 train_time:14251ms step_avg:93.76ms +step:153/1695 train_time:14345ms step_avg:93.76ms +step:154/1695 train_time:14439ms step_avg:93.76ms +step:155/1695 train_time:14533ms step_avg:93.76ms +step:156/1695 train_time:14627ms step_avg:93.76ms +step:157/1695 train_time:14721ms step_avg:93.77ms +step:158/1695 train_time:14816ms step_avg:93.77ms +step:159/1695 train_time:14910ms step_avg:93.77ms +step:160/1695 train_time:15005ms step_avg:93.78ms +step:161/1695 train_time:15100ms step_avg:93.79ms +step:162/1695 train_time:15195ms step_avg:93.79ms +step:163/1695 train_time:15289ms step_avg:93.80ms +step:164/1695 train_time:15384ms step_avg:93.80ms +step:165/1695 train_time:15478ms step_avg:93.81ms +step:166/1695 train_time:15572ms step_avg:93.81ms +step:167/1695 train_time:15666ms step_avg:93.81ms +step:168/1695 train_time:15760ms step_avg:93.81ms +step:169/1695 train_time:15853ms step_avg:93.80ms +step:170/1695 train_time:15947ms step_avg:93.81ms +step:171/1695 train_time:16041ms step_avg:93.81ms +step:172/1695 train_time:16135ms step_avg:93.81ms +step:173/1695 train_time:16230ms step_avg:93.81ms +step:174/1695 train_time:16324ms step_avg:93.82ms +step:175/1695 train_time:16419ms step_avg:93.82ms +step:176/1695 train_time:16513ms step_avg:93.82ms +step:177/1695 train_time:16606ms step_avg:93.82ms +step:178/1695 train_time:16700ms step_avg:93.82ms +step:179/1695 train_time:16795ms step_avg:93.83ms +step:180/1695 train_time:16888ms step_avg:93.82ms +step:181/1695 train_time:16983ms step_avg:93.83ms +step:182/1695 train_time:17077ms step_avg:93.83ms +step:183/1695 train_time:17171ms step_avg:93.83ms +step:184/1695 train_time:17265ms step_avg:93.83ms +step:185/1695 train_time:17359ms step_avg:93.83ms +step:186/1695 train_time:17453ms step_avg:93.83ms +step:187/1695 train_time:17547ms step_avg:93.83ms +step:188/1695 train_time:17642ms step_avg:93.84ms +step:189/1695 train_time:17736ms step_avg:93.84ms +step:190/1695 train_time:17829ms step_avg:93.84ms +step:191/1695 train_time:17923ms step_avg:93.84ms +step:192/1695 train_time:18016ms step_avg:93.84ms +step:193/1695 train_time:18110ms step_avg:93.84ms +step:194/1695 train_time:18205ms step_avg:93.84ms +step:195/1695 train_time:18299ms step_avg:93.84ms +step:196/1695 train_time:18392ms step_avg:93.84ms +step:197/1695 train_time:18486ms step_avg:93.84ms +step:198/1695 train_time:18581ms step_avg:93.84ms +step:199/1695 train_time:18676ms step_avg:93.85ms +step:200/1695 train_time:18770ms step_avg:93.85ms +step:201/1695 train_time:18864ms step_avg:93.85ms +step:202/1695 train_time:18958ms step_avg:93.85ms +step:203/1695 train_time:19052ms step_avg:93.85ms +step:204/1695 train_time:19145ms step_avg:93.85ms +step:205/1695 train_time:19239ms step_avg:93.85ms +step:206/1695 train_time:19333ms step_avg:93.85ms +step:207/1695 train_time:19427ms step_avg:93.85ms +step:208/1695 train_time:19521ms step_avg:93.85ms +step:209/1695 train_time:19616ms step_avg:93.86ms +step:210/1695 train_time:19710ms step_avg:93.86ms +step:211/1695 train_time:19805ms step_avg:93.86ms +step:212/1695 train_time:19899ms step_avg:93.87ms +step:213/1695 train_time:19992ms step_avg:93.86ms +step:214/1695 train_time:20086ms step_avg:93.86ms +step:215/1695 train_time:20181ms step_avg:93.86ms +step:216/1695 train_time:20275ms step_avg:93.86ms +step:217/1695 train_time:20370ms step_avg:93.87ms +step:218/1695 train_time:20464ms step_avg:93.87ms +step:219/1695 
train_time:20558ms step_avg:93.87ms +step:220/1695 train_time:20652ms step_avg:93.87ms +step:221/1695 train_time:20747ms step_avg:93.88ms +step:222/1695 train_time:20842ms step_avg:93.88ms +step:223/1695 train_time:20936ms step_avg:93.88ms +step:224/1695 train_time:21030ms step_avg:93.88ms +step:225/1695 train_time:21124ms step_avg:93.88ms +step:226/1695 train_time:21218ms step_avg:93.89ms +step:227/1695 train_time:21312ms step_avg:93.89ms +step:228/1695 train_time:21406ms step_avg:93.89ms +step:229/1695 train_time:21501ms step_avg:93.89ms +step:230/1695 train_time:21595ms step_avg:93.89ms +step:231/1695 train_time:21689ms step_avg:93.89ms +step:232/1695 train_time:21783ms step_avg:93.89ms +step:233/1695 train_time:21878ms step_avg:93.90ms +step:234/1695 train_time:21972ms step_avg:93.90ms +step:235/1695 train_time:22066ms step_avg:93.90ms +step:236/1695 train_time:22159ms step_avg:93.90ms +step:237/1695 train_time:22253ms step_avg:93.90ms +step:238/1695 train_time:22347ms step_avg:93.89ms +step:239/1695 train_time:22442ms step_avg:93.90ms +step:240/1695 train_time:22537ms step_avg:93.91ms +step:241/1695 train_time:22631ms step_avg:93.90ms +step:242/1695 train_time:22724ms step_avg:93.90ms +step:243/1695 train_time:22819ms step_avg:93.91ms +step:244/1695 train_time:22913ms step_avg:93.91ms +step:245/1695 train_time:23008ms step_avg:93.91ms +step:246/1695 train_time:23102ms step_avg:93.91ms +step:247/1695 train_time:23196ms step_avg:93.91ms +step:248/1695 train_time:23290ms step_avg:93.91ms +step:249/1695 train_time:23384ms step_avg:93.91ms +step:250/1695 train_time:23478ms step_avg:93.91ms +step:250/1695 val_loss:4.0788 train_time:23571ms step_avg:94.28ms +step:251/1695 train_time:23598ms step_avg:94.02ms +step:252/1695 train_time:23676ms step_avg:93.95ms +step:253/1695 train_time:23777ms step_avg:93.98ms +step:254/1695 train_time:23872ms step_avg:93.98ms +step:255/1695 train_time:23966ms step_avg:93.98ms +step:256/1695 train_time:24060ms step_avg:93.98ms +step:257/1695 train_time:24153ms step_avg:93.98ms +step:258/1695 train_time:24248ms step_avg:93.98ms +step:259/1695 train_time:24341ms step_avg:93.98ms +step:260/1695 train_time:24435ms step_avg:93.98ms +step:261/1695 train_time:24529ms step_avg:93.98ms +step:262/1695 train_time:24623ms step_avg:93.98ms +step:263/1695 train_time:24720ms step_avg:93.99ms +step:264/1695 train_time:24816ms step_avg:94.00ms +step:265/1695 train_time:24910ms step_avg:94.00ms +step:266/1695 train_time:25004ms step_avg:94.00ms +step:267/1695 train_time:25098ms step_avg:94.00ms +step:268/1695 train_time:25192ms step_avg:94.00ms +step:269/1695 train_time:25286ms step_avg:94.00ms +step:270/1695 train_time:25380ms step_avg:94.00ms +step:271/1695 train_time:25473ms step_avg:94.00ms +step:272/1695 train_time:25568ms step_avg:94.00ms +step:273/1695 train_time:25664ms step_avg:94.01ms +step:274/1695 train_time:25759ms step_avg:94.01ms +step:275/1695 train_time:25854ms step_avg:94.01ms +step:276/1695 train_time:25948ms step_avg:94.02ms +step:277/1695 train_time:26043ms step_avg:94.02ms +step:278/1695 train_time:26138ms step_avg:94.02ms +step:279/1695 train_time:26232ms step_avg:94.02ms +step:280/1695 train_time:26326ms step_avg:94.02ms +step:281/1695 train_time:26420ms step_avg:94.02ms +step:282/1695 train_time:26514ms step_avg:94.02ms +step:283/1695 train_time:26608ms step_avg:94.02ms +step:284/1695 train_time:26703ms step_avg:94.02ms +step:285/1695 train_time:26799ms step_avg:94.03ms +step:286/1695 train_time:26893ms step_avg:94.03ms +step:287/1695 train_time:26988ms 
step_avg:94.03ms +step:288/1695 train_time:27082ms step_avg:94.04ms +step:289/1695 train_time:27177ms step_avg:94.04ms +step:290/1695 train_time:27271ms step_avg:94.04ms +step:291/1695 train_time:27365ms step_avg:94.04ms +step:292/1695 train_time:27460ms step_avg:94.04ms +step:293/1695 train_time:27554ms step_avg:94.04ms +step:294/1695 train_time:27648ms step_avg:94.04ms +step:295/1695 train_time:27743ms step_avg:94.04ms +step:296/1695 train_time:27838ms step_avg:94.05ms +step:297/1695 train_time:27933ms step_avg:94.05ms +step:298/1695 train_time:28026ms step_avg:94.05ms +step:299/1695 train_time:28121ms step_avg:94.05ms +step:300/1695 train_time:28216ms step_avg:94.05ms +step:301/1695 train_time:28310ms step_avg:94.05ms +step:302/1695 train_time:28404ms step_avg:94.05ms +step:303/1695 train_time:28499ms step_avg:94.06ms +step:304/1695 train_time:28593ms step_avg:94.06ms +step:305/1695 train_time:28687ms step_avg:94.06ms +step:306/1695 train_time:28782ms step_avg:94.06ms +step:307/1695 train_time:28877ms step_avg:94.06ms +step:308/1695 train_time:28971ms step_avg:94.06ms +step:309/1695 train_time:29065ms step_avg:94.06ms +step:310/1695 train_time:29159ms step_avg:94.06ms +step:311/1695 train_time:29254ms step_avg:94.06ms +step:312/1695 train_time:29348ms step_avg:94.06ms +step:313/1695 train_time:29442ms step_avg:94.06ms +step:314/1695 train_time:29538ms step_avg:94.07ms +step:315/1695 train_time:29631ms step_avg:94.07ms +step:316/1695 train_time:29725ms step_avg:94.07ms +step:317/1695 train_time:29820ms step_avg:94.07ms +step:318/1695 train_time:29915ms step_avg:94.07ms +step:319/1695 train_time:30009ms step_avg:94.07ms +step:320/1695 train_time:30104ms step_avg:94.07ms +step:321/1695 train_time:30198ms step_avg:94.08ms +step:322/1695 train_time:30293ms step_avg:94.08ms +step:323/1695 train_time:30387ms step_avg:94.08ms +step:324/1695 train_time:30482ms step_avg:94.08ms +step:325/1695 train_time:30577ms step_avg:94.08ms +step:326/1695 train_time:30670ms step_avg:94.08ms +step:327/1695 train_time:30765ms step_avg:94.08ms +step:328/1695 train_time:30859ms step_avg:94.08ms +step:329/1695 train_time:30953ms step_avg:94.08ms +step:330/1695 train_time:31048ms step_avg:94.08ms +step:331/1695 train_time:31142ms step_avg:94.09ms +step:332/1695 train_time:31237ms step_avg:94.09ms +step:333/1695 train_time:31332ms step_avg:94.09ms +step:334/1695 train_time:31426ms step_avg:94.09ms +step:335/1695 train_time:31521ms step_avg:94.09ms +step:336/1695 train_time:31616ms step_avg:94.10ms +step:337/1695 train_time:31710ms step_avg:94.10ms +step:338/1695 train_time:31805ms step_avg:94.10ms +step:339/1695 train_time:31900ms step_avg:94.10ms +step:340/1695 train_time:31994ms step_avg:94.10ms +step:341/1695 train_time:32088ms step_avg:94.10ms +step:342/1695 train_time:32183ms step_avg:94.10ms +step:343/1695 train_time:32278ms step_avg:94.10ms +step:344/1695 train_time:32372ms step_avg:94.10ms +step:345/1695 train_time:32466ms step_avg:94.10ms +step:346/1695 train_time:32561ms step_avg:94.11ms +step:347/1695 train_time:32655ms step_avg:94.11ms +step:348/1695 train_time:32749ms step_avg:94.11ms +step:349/1695 train_time:32844ms step_avg:94.11ms +step:350/1695 train_time:32937ms step_avg:94.11ms +step:351/1695 train_time:33032ms step_avg:94.11ms +step:352/1695 train_time:33126ms step_avg:94.11ms +step:353/1695 train_time:33221ms step_avg:94.11ms +step:354/1695 train_time:33315ms step_avg:94.11ms +step:355/1695 train_time:33409ms step_avg:94.11ms +step:356/1695 train_time:33503ms step_avg:94.11ms +step:357/1695 
train_time:33598ms step_avg:94.11ms +step:358/1695 train_time:33692ms step_avg:94.11ms +step:359/1695 train_time:33787ms step_avg:94.11ms +step:360/1695 train_time:33882ms step_avg:94.12ms +step:361/1695 train_time:33977ms step_avg:94.12ms +step:362/1695 train_time:34071ms step_avg:94.12ms +step:363/1695 train_time:34165ms step_avg:94.12ms +step:364/1695 train_time:34261ms step_avg:94.12ms +step:365/1695 train_time:34356ms step_avg:94.13ms +step:366/1695 train_time:34450ms step_avg:94.13ms +step:367/1695 train_time:34545ms step_avg:94.13ms +step:368/1695 train_time:34640ms step_avg:94.13ms +step:369/1695 train_time:34735ms step_avg:94.13ms +step:370/1695 train_time:34829ms step_avg:94.13ms +step:371/1695 train_time:34924ms step_avg:94.13ms +step:372/1695 train_time:35018ms step_avg:94.13ms +step:373/1695 train_time:35112ms step_avg:94.13ms +step:374/1695 train_time:35206ms step_avg:94.13ms +step:375/1695 train_time:35300ms step_avg:94.13ms +step:375/1695 val_loss:3.8792 train_time:35392ms step_avg:94.38ms +step:376/1695 train_time:35419ms step_avg:94.20ms +step:377/1695 train_time:35496ms step_avg:94.15ms +step:378/1695 train_time:35596ms step_avg:94.17ms +step:379/1695 train_time:35694ms step_avg:94.18ms +step:380/1695 train_time:35789ms step_avg:94.18ms +step:381/1695 train_time:35884ms step_avg:94.18ms +step:382/1695 train_time:35980ms step_avg:94.19ms +step:383/1695 train_time:36076ms step_avg:94.19ms +step:384/1695 train_time:36172ms step_avg:94.20ms +step:385/1695 train_time:36267ms step_avg:94.20ms +step:386/1695 train_time:36363ms step_avg:94.20ms +step:387/1695 train_time:36460ms step_avg:94.21ms +step:388/1695 train_time:36557ms step_avg:94.22ms +step:389/1695 train_time:36654ms step_avg:94.23ms +step:390/1695 train_time:36752ms step_avg:94.23ms +step:391/1695 train_time:36849ms step_avg:94.24ms +step:392/1695 train_time:36944ms step_avg:94.25ms +step:393/1695 train_time:37040ms step_avg:94.25ms +step:394/1695 train_time:37136ms step_avg:94.25ms +step:395/1695 train_time:37231ms step_avg:94.26ms +step:396/1695 train_time:37327ms step_avg:94.26ms +step:397/1695 train_time:37422ms step_avg:94.26ms +step:398/1695 train_time:37519ms step_avg:94.27ms +step:399/1695 train_time:37615ms step_avg:94.27ms +step:400/1695 train_time:37713ms step_avg:94.28ms +step:401/1695 train_time:37811ms step_avg:94.29ms +step:402/1695 train_time:37908ms step_avg:94.30ms +step:403/1695 train_time:38003ms step_avg:94.30ms +step:404/1695 train_time:38099ms step_avg:94.31ms +step:405/1695 train_time:38196ms step_avg:94.31ms +step:406/1695 train_time:38291ms step_avg:94.31ms +step:407/1695 train_time:38387ms step_avg:94.32ms +step:408/1695 train_time:38483ms step_avg:94.32ms +step:409/1695 train_time:38579ms step_avg:94.33ms +step:410/1695 train_time:38675ms step_avg:94.33ms +step:411/1695 train_time:38772ms step_avg:94.34ms +step:412/1695 train_time:38869ms step_avg:94.34ms +step:413/1695 train_time:38967ms step_avg:94.35ms +step:414/1695 train_time:39063ms step_avg:94.35ms +step:415/1695 train_time:39158ms step_avg:94.36ms +step:416/1695 train_time:39255ms step_avg:94.36ms +step:417/1695 train_time:39350ms step_avg:94.36ms +step:418/1695 train_time:39446ms step_avg:94.37ms +step:419/1695 train_time:39541ms step_avg:94.37ms +step:420/1695 train_time:39638ms step_avg:94.38ms +step:421/1695 train_time:39734ms step_avg:94.38ms +step:422/1695 train_time:39830ms step_avg:94.38ms +step:423/1695 train_time:39926ms step_avg:94.39ms +step:424/1695 train_time:40022ms step_avg:94.39ms +step:425/1695 train_time:40118ms 
step_avg:94.40ms +step:426/1695 train_time:40214ms step_avg:94.40ms +step:427/1695 train_time:40312ms step_avg:94.41ms +step:428/1695 train_time:40408ms step_avg:94.41ms +step:429/1695 train_time:40504ms step_avg:94.41ms +step:430/1695 train_time:40599ms step_avg:94.42ms +step:431/1695 train_time:40696ms step_avg:94.42ms +step:432/1695 train_time:40792ms step_avg:94.42ms +step:433/1695 train_time:40887ms step_avg:94.43ms +step:434/1695 train_time:40983ms step_avg:94.43ms +step:435/1695 train_time:41080ms step_avg:94.44ms +step:436/1695 train_time:41177ms step_avg:94.44ms +step:437/1695 train_time:41274ms step_avg:94.45ms +step:438/1695 train_time:41370ms step_avg:94.45ms +step:439/1695 train_time:41467ms step_avg:94.46ms +step:440/1695 train_time:41563ms step_avg:94.46ms +step:441/1695 train_time:41658ms step_avg:94.46ms +step:442/1695 train_time:41754ms step_avg:94.47ms +step:443/1695 train_time:41850ms step_avg:94.47ms +step:444/1695 train_time:41947ms step_avg:94.47ms +step:445/1695 train_time:42042ms step_avg:94.48ms +step:446/1695 train_time:42138ms step_avg:94.48ms +step:447/1695 train_time:42234ms step_avg:94.48ms +step:448/1695 train_time:42330ms step_avg:94.49ms +step:449/1695 train_time:42426ms step_avg:94.49ms +step:450/1695 train_time:42522ms step_avg:94.49ms +step:451/1695 train_time:42618ms step_avg:94.50ms +step:452/1695 train_time:42714ms step_avg:94.50ms +step:453/1695 train_time:42810ms step_avg:94.50ms +step:454/1695 train_time:42906ms step_avg:94.51ms +step:455/1695 train_time:43002ms step_avg:94.51ms +step:456/1695 train_time:43098ms step_avg:94.51ms +step:457/1695 train_time:43194ms step_avg:94.52ms +step:458/1695 train_time:43290ms step_avg:94.52ms +step:459/1695 train_time:43385ms step_avg:94.52ms +step:460/1695 train_time:43481ms step_avg:94.52ms +step:461/1695 train_time:43578ms step_avg:94.53ms +step:462/1695 train_time:43674ms step_avg:94.53ms +step:463/1695 train_time:43770ms step_avg:94.54ms +step:464/1695 train_time:43866ms step_avg:94.54ms +step:465/1695 train_time:43962ms step_avg:94.54ms +step:466/1695 train_time:44058ms step_avg:94.54ms +step:467/1695 train_time:44154ms step_avg:94.55ms +step:468/1695 train_time:44250ms step_avg:94.55ms +step:469/1695 train_time:44346ms step_avg:94.55ms +step:470/1695 train_time:44442ms step_avg:94.56ms +step:471/1695 train_time:44538ms step_avg:94.56ms +step:472/1695 train_time:44634ms step_avg:94.56ms +step:473/1695 train_time:44730ms step_avg:94.57ms +step:474/1695 train_time:44826ms step_avg:94.57ms +step:475/1695 train_time:44922ms step_avg:94.57ms +step:476/1695 train_time:45019ms step_avg:94.58ms +step:477/1695 train_time:45115ms step_avg:94.58ms +step:478/1695 train_time:45211ms step_avg:94.58ms +step:479/1695 train_time:45307ms step_avg:94.59ms +step:480/1695 train_time:45403ms step_avg:94.59ms +step:481/1695 train_time:45499ms step_avg:94.59ms +step:482/1695 train_time:45595ms step_avg:94.60ms +step:483/1695 train_time:45692ms step_avg:94.60ms +step:484/1695 train_time:45788ms step_avg:94.60ms +step:485/1695 train_time:45884ms step_avg:94.61ms +step:486/1695 train_time:45980ms step_avg:94.61ms +step:487/1695 train_time:46076ms step_avg:94.61ms +step:488/1695 train_time:46173ms step_avg:94.62ms +step:489/1695 train_time:46270ms step_avg:94.62ms +step:490/1695 train_time:46365ms step_avg:94.62ms +step:491/1695 train_time:46462ms step_avg:94.63ms +step:492/1695 train_time:46557ms step_avg:94.63ms +step:493/1695 train_time:46654ms step_avg:94.63ms +step:494/1695 train_time:46750ms step_avg:94.64ms +step:495/1695 
train_time:46847ms step_avg:94.64ms +step:496/1695 train_time:46943ms step_avg:94.64ms +step:497/1695 train_time:47039ms step_avg:94.65ms +step:498/1695 train_time:47135ms step_avg:94.65ms +step:499/1695 train_time:47232ms step_avg:94.65ms +step:500/1695 train_time:47329ms step_avg:94.66ms +step:500/1695 val_loss:3.7326 train_time:47422ms step_avg:94.84ms +step:501/1695 train_time:47453ms step_avg:94.72ms +step:502/1695 train_time:47532ms step_avg:94.68ms +step:503/1695 train_time:47632ms step_avg:94.70ms +step:504/1695 train_time:47730ms step_avg:94.70ms +step:505/1695 train_time:47826ms step_avg:94.71ms +step:506/1695 train_time:47922ms step_avg:94.71ms +step:507/1695 train_time:48018ms step_avg:94.71ms +step:508/1695 train_time:48113ms step_avg:94.71ms +step:509/1695 train_time:48209ms step_avg:94.71ms +step:510/1695 train_time:48305ms step_avg:94.72ms +step:511/1695 train_time:48401ms step_avg:94.72ms +step:512/1695 train_time:48499ms step_avg:94.72ms +step:513/1695 train_time:48597ms step_avg:94.73ms +step:514/1695 train_time:48695ms step_avg:94.74ms +step:515/1695 train_time:48792ms step_avg:94.74ms +step:516/1695 train_time:48889ms step_avg:94.75ms +step:517/1695 train_time:48984ms step_avg:94.75ms +step:518/1695 train_time:49080ms step_avg:94.75ms +step:519/1695 train_time:49176ms step_avg:94.75ms +step:520/1695 train_time:49273ms step_avg:94.76ms +step:521/1695 train_time:49370ms step_avg:94.76ms +step:522/1695 train_time:49467ms step_avg:94.76ms +step:523/1695 train_time:49563ms step_avg:94.77ms +step:524/1695 train_time:49660ms step_avg:94.77ms +step:525/1695 train_time:49757ms step_avg:94.78ms +step:526/1695 train_time:49856ms step_avg:94.78ms +step:527/1695 train_time:49953ms step_avg:94.79ms +step:528/1695 train_time:50049ms step_avg:94.79ms +step:529/1695 train_time:50145ms step_avg:94.79ms +step:530/1695 train_time:50241ms step_avg:94.79ms +step:531/1695 train_time:50337ms step_avg:94.80ms +step:532/1695 train_time:50434ms step_avg:94.80ms +step:533/1695 train_time:50531ms step_avg:94.81ms +step:534/1695 train_time:50630ms step_avg:94.81ms +step:535/1695 train_time:50727ms step_avg:94.82ms +step:536/1695 train_time:50824ms step_avg:94.82ms +step:537/1695 train_time:50921ms step_avg:94.82ms +step:538/1695 train_time:51017ms step_avg:94.83ms +step:539/1695 train_time:51114ms step_avg:94.83ms +step:540/1695 train_time:51210ms step_avg:94.83ms +step:541/1695 train_time:51307ms step_avg:94.84ms +step:542/1695 train_time:51404ms step_avg:94.84ms +step:543/1695 train_time:51499ms step_avg:94.84ms +step:544/1695 train_time:51596ms step_avg:94.85ms +step:545/1695 train_time:51693ms step_avg:94.85ms +step:546/1695 train_time:51790ms step_avg:94.85ms +step:547/1695 train_time:51887ms step_avg:94.86ms +step:548/1695 train_time:51983ms step_avg:94.86ms +step:549/1695 train_time:52079ms step_avg:94.86ms +step:550/1695 train_time:52175ms step_avg:94.86ms +step:551/1695 train_time:52272ms step_avg:94.87ms +step:552/1695 train_time:52369ms step_avg:94.87ms +step:553/1695 train_time:52466ms step_avg:94.88ms +step:554/1695 train_time:52563ms step_avg:94.88ms +step:555/1695 train_time:52659ms step_avg:94.88ms +step:556/1695 train_time:52755ms step_avg:94.88ms +step:557/1695 train_time:52852ms step_avg:94.89ms +step:558/1695 train_time:52949ms step_avg:94.89ms +step:559/1695 train_time:53045ms step_avg:94.89ms +step:560/1695 train_time:53141ms step_avg:94.90ms +step:561/1695 train_time:53238ms step_avg:94.90ms +step:562/1695 train_time:53334ms step_avg:94.90ms +step:563/1695 train_time:53431ms 
step_avg:94.90ms +step:564/1695 train_time:53527ms step_avg:94.91ms +step:565/1695 train_time:53624ms step_avg:94.91ms +step:566/1695 train_time:53721ms step_avg:94.91ms +step:567/1695 train_time:53818ms step_avg:94.92ms +step:568/1695 train_time:53916ms step_avg:94.92ms +step:569/1695 train_time:54013ms step_avg:94.93ms +step:570/1695 train_time:54110ms step_avg:94.93ms +step:571/1695 train_time:54207ms step_avg:94.93ms +step:572/1695 train_time:54304ms step_avg:94.94ms +step:573/1695 train_time:54400ms step_avg:94.94ms +step:574/1695 train_time:54496ms step_avg:94.94ms +step:575/1695 train_time:54593ms step_avg:94.94ms +step:576/1695 train_time:54691ms step_avg:94.95ms +step:577/1695 train_time:54788ms step_avg:94.95ms +step:578/1695 train_time:54884ms step_avg:94.95ms +step:579/1695 train_time:54979ms step_avg:94.96ms +step:580/1695 train_time:55075ms step_avg:94.96ms +step:581/1695 train_time:55172ms step_avg:94.96ms +step:582/1695 train_time:55270ms step_avg:94.97ms +step:583/1695 train_time:55367ms step_avg:94.97ms +step:584/1695 train_time:55464ms step_avg:94.97ms +step:585/1695 train_time:55560ms step_avg:94.97ms +step:586/1695 train_time:55657ms step_avg:94.98ms +step:587/1695 train_time:55755ms step_avg:94.98ms +step:588/1695 train_time:55852ms step_avg:94.99ms +step:589/1695 train_time:55949ms step_avg:94.99ms +step:590/1695 train_time:56046ms step_avg:94.99ms +step:591/1695 train_time:56143ms step_avg:95.00ms +step:592/1695 train_time:56238ms step_avg:95.00ms +step:593/1695 train_time:56335ms step_avg:95.00ms +step:594/1695 train_time:56432ms step_avg:95.00ms +step:595/1695 train_time:56529ms step_avg:95.01ms +step:596/1695 train_time:56626ms step_avg:95.01ms +step:597/1695 train_time:56722ms step_avg:95.01ms +step:598/1695 train_time:56818ms step_avg:95.01ms +step:599/1695 train_time:56914ms step_avg:95.02ms +step:600/1695 train_time:57011ms step_avg:95.02ms +step:601/1695 train_time:57108ms step_avg:95.02ms +step:602/1695 train_time:57205ms step_avg:95.03ms +step:603/1695 train_time:57302ms step_avg:95.03ms +step:604/1695 train_time:57398ms step_avg:95.03ms +step:605/1695 train_time:57496ms step_avg:95.03ms +step:606/1695 train_time:57592ms step_avg:95.04ms +step:607/1695 train_time:57690ms step_avg:95.04ms +step:608/1695 train_time:57787ms step_avg:95.04ms +step:609/1695 train_time:57884ms step_avg:95.05ms +step:610/1695 train_time:57980ms step_avg:95.05ms +step:611/1695 train_time:58076ms step_avg:95.05ms +step:612/1695 train_time:58173ms step_avg:95.05ms +step:613/1695 train_time:58270ms step_avg:95.06ms +step:614/1695 train_time:58367ms step_avg:95.06ms +step:615/1695 train_time:58463ms step_avg:95.06ms +step:616/1695 train_time:58559ms step_avg:95.06ms +step:617/1695 train_time:58656ms step_avg:95.07ms +step:618/1695 train_time:58752ms step_avg:95.07ms +step:619/1695 train_time:58849ms step_avg:95.07ms +step:620/1695 train_time:58946ms step_avg:95.07ms +step:621/1695 train_time:59042ms step_avg:95.08ms +step:622/1695 train_time:59138ms step_avg:95.08ms +step:623/1695 train_time:59234ms step_avg:95.08ms +step:624/1695 train_time:59331ms step_avg:95.08ms +step:625/1695 train_time:59429ms step_avg:95.09ms +step:625/1695 val_loss:3.6468 train_time:59523ms step_avg:95.24ms +step:626/1695 train_time:59550ms step_avg:95.13ms +step:627/1695 train_time:59629ms step_avg:95.10ms +step:628/1695 train_time:59729ms step_avg:95.11ms +step:629/1695 train_time:60148ms step_avg:95.62ms +step:630/1695 train_time:60244ms step_avg:95.62ms +step:631/1695 train_time:60341ms step_avg:95.63ms 
+step:632/1695 train_time:60438ms step_avg:95.63ms +step:633/1695 train_time:60535ms step_avg:95.63ms +step:634/1695 train_time:60631ms step_avg:95.63ms +step:635/1695 train_time:60966ms step_avg:96.01ms +step:636/1695 train_time:61063ms step_avg:96.01ms +step:637/1695 train_time:61161ms step_avg:96.01ms +step:638/1695 train_time:61258ms step_avg:96.02ms +step:639/1695 train_time:61355ms step_avg:96.02ms +step:640/1695 train_time:61452ms step_avg:96.02ms +step:641/1695 train_time:61549ms step_avg:96.02ms +step:642/1695 train_time:61647ms step_avg:96.02ms +step:643/1695 train_time:61745ms step_avg:96.03ms +step:644/1695 train_time:61846ms step_avg:96.03ms +step:645/1695 train_time:61947ms step_avg:96.04ms +step:646/1695 train_time:62047ms step_avg:96.05ms +step:647/1695 train_time:62146ms step_avg:96.05ms +step:648/1695 train_time:62245ms step_avg:96.06ms +step:649/1695 train_time:62343ms step_avg:96.06ms +step:650/1695 train_time:62441ms step_avg:96.06ms +step:651/1695 train_time:62539ms step_avg:96.07ms +step:652/1695 train_time:62636ms step_avg:96.07ms +step:653/1695 train_time:62734ms step_avg:96.07ms +step:654/1695 train_time:62832ms step_avg:96.07ms +step:655/1695 train_time:62930ms step_avg:96.08ms +step:656/1695 train_time:63030ms step_avg:96.08ms +step:657/1695 train_time:63129ms step_avg:96.09ms +step:658/1695 train_time:63227ms step_avg:96.09ms +step:659/1695 train_time:63326ms step_avg:96.09ms +step:660/1695 train_time:63424ms step_avg:96.10ms +step:661/1695 train_time:63522ms step_avg:96.10ms +step:662/1695 train_time:63620ms step_avg:96.10ms +step:663/1695 train_time:63718ms step_avg:96.11ms +step:664/1695 train_time:63817ms step_avg:96.11ms +step:665/1695 train_time:63914ms step_avg:96.11ms +step:666/1695 train_time:64012ms step_avg:96.11ms +step:667/1695 train_time:64109ms step_avg:96.12ms +step:668/1695 train_time:64207ms step_avg:96.12ms +step:669/1695 train_time:64305ms step_avg:96.12ms +step:670/1695 train_time:64403ms step_avg:96.12ms +step:671/1695 train_time:64501ms step_avg:96.13ms +step:672/1695 train_time:64600ms step_avg:96.13ms +step:673/1695 train_time:64696ms step_avg:96.13ms +step:674/1695 train_time:64794ms step_avg:96.13ms +step:675/1695 train_time:64893ms step_avg:96.14ms +step:676/1695 train_time:64991ms step_avg:96.14ms +step:677/1695 train_time:65089ms step_avg:96.14ms +step:678/1695 train_time:65187ms step_avg:96.15ms +step:679/1695 train_time:65285ms step_avg:96.15ms +step:680/1695 train_time:65384ms step_avg:96.15ms +step:681/1695 train_time:65481ms step_avg:96.15ms +step:682/1695 train_time:65580ms step_avg:96.16ms +step:683/1695 train_time:65678ms step_avg:96.16ms +step:684/1695 train_time:65777ms step_avg:96.17ms +step:685/1695 train_time:65875ms step_avg:96.17ms +step:686/1695 train_time:65973ms step_avg:96.17ms +step:687/1695 train_time:66071ms step_avg:96.17ms +step:688/1695 train_time:66169ms step_avg:96.18ms +step:689/1695 train_time:66267ms step_avg:96.18ms +step:690/1695 train_time:66365ms step_avg:96.18ms +step:691/1695 train_time:66463ms step_avg:96.18ms +step:692/1695 train_time:66561ms step_avg:96.19ms +step:693/1695 train_time:66659ms step_avg:96.19ms +step:694/1695 train_time:66759ms step_avg:96.19ms +step:695/1695 train_time:66857ms step_avg:96.20ms +step:696/1695 train_time:66955ms step_avg:96.20ms +step:697/1695 train_time:67053ms step_avg:96.20ms +step:698/1695 train_time:67150ms step_avg:96.20ms +step:699/1695 train_time:67248ms step_avg:96.21ms +step:700/1695 train_time:67346ms step_avg:96.21ms +step:701/1695 train_time:67444ms 
step_avg:96.21ms +step:702/1695 train_time:67541ms step_avg:96.21ms +step:703/1695 train_time:67639ms step_avg:96.22ms +step:704/1695 train_time:67738ms step_avg:96.22ms +step:705/1695 train_time:67836ms step_avg:96.22ms +step:706/1695 train_time:67934ms step_avg:96.22ms +step:707/1695 train_time:68032ms step_avg:96.23ms +step:708/1695 train_time:68130ms step_avg:96.23ms +step:709/1695 train_time:68228ms step_avg:96.23ms +step:710/1695 train_time:68326ms step_avg:96.23ms +step:711/1695 train_time:68424ms step_avg:96.24ms +step:712/1695 train_time:68522ms step_avg:96.24ms +step:713/1695 train_time:68620ms step_avg:96.24ms +step:714/1695 train_time:68720ms step_avg:96.25ms +step:715/1695 train_time:68818ms step_avg:96.25ms +step:716/1695 train_time:68916ms step_avg:96.25ms +step:717/1695 train_time:69015ms step_avg:96.25ms +step:718/1695 train_time:69113ms step_avg:96.26ms +step:719/1695 train_time:69210ms step_avg:96.26ms +step:720/1695 train_time:69307ms step_avg:96.26ms +step:721/1695 train_time:69405ms step_avg:96.26ms +step:722/1695 train_time:69824ms step_avg:96.71ms +step:723/1695 train_time:69919ms step_avg:96.71ms +step:724/1695 train_time:70016ms step_avg:96.71ms +step:725/1695 train_time:70113ms step_avg:96.71ms +step:726/1695 train_time:70211ms step_avg:96.71ms +step:727/1695 train_time:70308ms step_avg:96.71ms +step:728/1695 train_time:70405ms step_avg:96.71ms +step:729/1695 train_time:70503ms step_avg:96.71ms +step:730/1695 train_time:70600ms step_avg:96.71ms +step:731/1695 train_time:70697ms step_avg:96.71ms +step:732/1695 train_time:70798ms step_avg:96.72ms +step:733/1695 train_time:70897ms step_avg:96.72ms +step:734/1695 train_time:70995ms step_avg:96.72ms +step:735/1695 train_time:71092ms step_avg:96.72ms +step:736/1695 train_time:71190ms step_avg:96.73ms +step:737/1695 train_time:71287ms step_avg:96.73ms +step:738/1695 train_time:71384ms step_avg:96.73ms +step:739/1695 train_time:71482ms step_avg:96.73ms +step:740/1695 train_time:71579ms step_avg:96.73ms +step:741/1695 train_time:71676ms step_avg:96.73ms +step:742/1695 train_time:71776ms step_avg:96.73ms +step:743/1695 train_time:71874ms step_avg:96.74ms +step:744/1695 train_time:71972ms step_avg:96.74ms +step:745/1695 train_time:72069ms step_avg:96.74ms +step:746/1695 train_time:72168ms step_avg:96.74ms +step:747/1695 train_time:72265ms step_avg:96.74ms +step:748/1695 train_time:72362ms step_avg:96.74ms +step:749/1695 train_time:72459ms step_avg:96.74ms +step:750/1695 train_time:72557ms step_avg:96.74ms +step:750/1695 val_loss:3.5854 train_time:72652ms step_avg:96.87ms +step:751/1695 train_time:72679ms step_avg:96.78ms +step:752/1695 train_time:72766ms step_avg:96.76ms +step:753/1695 train_time:72868ms step_avg:96.77ms +step:754/1695 train_time:72967ms step_avg:96.77ms +step:755/1695 train_time:73063ms step_avg:96.77ms +step:756/1695 train_time:73160ms step_avg:96.77ms +step:757/1695 train_time:73258ms step_avg:96.77ms +step:758/1695 train_time:73356ms step_avg:96.78ms +step:759/1695 train_time:73454ms step_avg:96.78ms +step:760/1695 train_time:73552ms step_avg:96.78ms +step:761/1695 train_time:73651ms step_avg:96.78ms +step:762/1695 train_time:73751ms step_avg:96.79ms +step:763/1695 train_time:73853ms step_avg:96.79ms +step:764/1695 train_time:73952ms step_avg:96.80ms +step:765/1695 train_time:74052ms step_avg:96.80ms +step:766/1695 train_time:74150ms step_avg:96.80ms +step:767/1695 train_time:74248ms step_avg:96.80ms +step:768/1695 train_time:74347ms step_avg:96.81ms +step:769/1695 train_time:74445ms step_avg:96.81ms 
+step:770/1695 train_time:74542ms step_avg:96.81ms +step:771/1695 train_time:74639ms step_avg:96.81ms +step:772/1695 train_time:74962ms step_avg:97.10ms +step:773/1695 train_time:75058ms step_avg:97.10ms +step:774/1695 train_time:75156ms step_avg:97.10ms +step:775/1695 train_time:75254ms step_avg:97.10ms +step:776/1695 train_time:75351ms step_avg:97.10ms +step:777/1695 train_time:75679ms step_avg:97.40ms +step:778/1695 train_time:75776ms step_avg:97.40ms +step:779/1695 train_time:75873ms step_avg:97.40ms +step:780/1695 train_time:75970ms step_avg:97.40ms +step:781/1695 train_time:76068ms step_avg:97.40ms +step:782/1695 train_time:76166ms step_avg:97.40ms +step:783/1695 train_time:76263ms step_avg:97.40ms +step:784/1695 train_time:76360ms step_avg:97.40ms +step:785/1695 train_time:76458ms step_avg:97.40ms +step:786/1695 train_time:76557ms step_avg:97.40ms +step:787/1695 train_time:76660ms step_avg:97.41ms +step:788/1695 train_time:76760ms step_avg:97.41ms +step:789/1695 train_time:76858ms step_avg:97.41ms +step:790/1695 train_time:76957ms step_avg:97.41ms +step:791/1695 train_time:77057ms step_avg:97.42ms +step:792/1695 train_time:77156ms step_avg:97.42ms +step:793/1695 train_time:77254ms step_avg:97.42ms +step:794/1695 train_time:77352ms step_avg:97.42ms +step:795/1695 train_time:77746ms step_avg:97.79ms +step:796/1695 train_time:77794ms step_avg:97.73ms +step:797/1695 train_time:77891ms step_avg:97.73ms +step:798/1695 train_time:77988ms step_avg:97.73ms +step:799/1695 train_time:78084ms step_avg:97.73ms +step:800/1695 train_time:78181ms step_avg:97.73ms +step:801/1695 train_time:78278ms step_avg:97.73ms +step:802/1695 train_time:78377ms step_avg:97.73ms +step:803/1695 train_time:78474ms step_avg:97.73ms +step:804/1695 train_time:78572ms step_avg:97.73ms +step:805/1695 train_time:78670ms step_avg:97.73ms +step:806/1695 train_time:78771ms step_avg:97.73ms +step:807/1695 train_time:78871ms step_avg:97.73ms +step:808/1695 train_time:78970ms step_avg:97.74ms +step:809/1695 train_time:79068ms step_avg:97.74ms +step:810/1695 train_time:79166ms step_avg:97.74ms +step:811/1695 train_time:79263ms step_avg:97.74ms +step:812/1695 train_time:79361ms step_avg:97.73ms +step:813/1695 train_time:79458ms step_avg:97.73ms +step:814/1695 train_time:79556ms step_avg:97.73ms +step:815/1695 train_time:79654ms step_avg:97.73ms +step:816/1695 train_time:79753ms step_avg:97.74ms +step:817/1695 train_time:79852ms step_avg:97.74ms +step:818/1695 train_time:79951ms step_avg:97.74ms +step:819/1695 train_time:80051ms step_avg:97.74ms +step:820/1695 train_time:80149ms step_avg:97.74ms +step:821/1695 train_time:80248ms step_avg:97.74ms +step:822/1695 train_time:80347ms step_avg:97.75ms +step:823/1695 train_time:80445ms step_avg:97.75ms +step:824/1695 train_time:80543ms step_avg:97.75ms +step:825/1695 train_time:80640ms step_avg:97.75ms +step:826/1695 train_time:80738ms step_avg:97.75ms +step:827/1695 train_time:80837ms step_avg:97.75ms +step:828/1695 train_time:80935ms step_avg:97.75ms +step:829/1695 train_time:81033ms step_avg:97.75ms +step:830/1695 train_time:81132ms step_avg:97.75ms +step:831/1695 train_time:81231ms step_avg:97.75ms +step:832/1695 train_time:81329ms step_avg:97.75ms +step:833/1695 train_time:81427ms step_avg:97.75ms +step:834/1695 train_time:81526ms step_avg:97.75ms +step:835/1695 train_time:81624ms step_avg:97.75ms +step:836/1695 train_time:81723ms step_avg:97.75ms +step:837/1695 train_time:81821ms step_avg:97.75ms +step:838/1695 train_time:81919ms step_avg:97.76ms +step:839/1695 train_time:82017ms 
step_avg:97.76ms +step:840/1695 train_time:82116ms step_avg:97.76ms +step:841/1695 train_time:82214ms step_avg:97.76ms +step:842/1695 train_time:82313ms step_avg:97.76ms +step:843/1695 train_time:82412ms step_avg:97.76ms +step:844/1695 train_time:82512ms step_avg:97.76ms +step:845/1695 train_time:82612ms step_avg:97.77ms +step:846/1695 train_time:82712ms step_avg:97.77ms +step:847/1695 train_time:82811ms step_avg:97.77ms +step:848/1695 train_time:82910ms step_avg:97.77ms +step:849/1695 train_time:83009ms step_avg:97.77ms +step:850/1695 train_time:83109ms step_avg:97.77ms +step:851/1695 train_time:83207ms step_avg:97.78ms +step:852/1695 train_time:83304ms step_avg:97.77ms +step:853/1695 train_time:83402ms step_avg:97.77ms +step:854/1695 train_time:83499ms step_avg:97.77ms +step:855/1695 train_time:83598ms step_avg:97.78ms +step:856/1695 train_time:83697ms step_avg:97.78ms +step:857/1695 train_time:83796ms step_avg:97.78ms +step:858/1695 train_time:83894ms step_avg:97.78ms +step:859/1695 train_time:83993ms step_avg:97.78ms +step:860/1695 train_time:84092ms step_avg:97.78ms +step:861/1695 train_time:84190ms step_avg:97.78ms +step:862/1695 train_time:84288ms step_avg:97.78ms +step:863/1695 train_time:84386ms step_avg:97.78ms +step:864/1695 train_time:84486ms step_avg:97.78ms +step:865/1695 train_time:84584ms step_avg:97.78ms +step:866/1695 train_time:84682ms step_avg:97.78ms +step:867/1695 train_time:84780ms step_avg:97.79ms +step:868/1695 train_time:84878ms step_avg:97.79ms +step:869/1695 train_time:84977ms step_avg:97.79ms +step:870/1695 train_time:85076ms step_avg:97.79ms +step:871/1695 train_time:85174ms step_avg:97.79ms +step:872/1695 train_time:85273ms step_avg:97.79ms +step:873/1695 train_time:85371ms step_avg:97.79ms +step:874/1695 train_time:85470ms step_avg:97.79ms +step:875/1695 train_time:85568ms step_avg:97.79ms +step:875/1695 val_loss:3.5355 train_time:85664ms step_avg:97.90ms +step:876/1695 train_time:85691ms step_avg:97.82ms +step:877/1695 train_time:85779ms step_avg:97.81ms +step:878/1695 train_time:85879ms step_avg:97.81ms +step:879/1695 train_time:86210ms step_avg:98.08ms +step:880/1695 train_time:86304ms step_avg:98.07ms +step:881/1695 train_time:86402ms step_avg:98.07ms +step:882/1695 train_time:86501ms step_avg:98.07ms +step:883/1695 train_time:86600ms step_avg:98.08ms +step:884/1695 train_time:86699ms step_avg:98.08ms +step:885/1695 train_time:86798ms step_avg:98.08ms +step:886/1695 train_time:86897ms step_avg:98.08ms +step:887/1695 train_time:86995ms step_avg:98.08ms +step:888/1695 train_time:87100ms step_avg:98.09ms +step:889/1695 train_time:87203ms step_avg:98.09ms +step:890/1695 train_time:87302ms step_avg:98.09ms +step:891/1695 train_time:87401ms step_avg:98.09ms +step:892/1695 train_time:87500ms step_avg:98.09ms +step:893/1695 train_time:87599ms step_avg:98.10ms +step:894/1695 train_time:87698ms step_avg:98.10ms +step:895/1695 train_time:87797ms step_avg:98.10ms +step:896/1695 train_time:87896ms step_avg:98.10ms +step:897/1695 train_time:87996ms step_avg:98.10ms +step:898/1695 train_time:88099ms step_avg:98.11ms +step:899/1695 train_time:88200ms step_avg:98.11ms +step:900/1695 train_time:88300ms step_avg:98.11ms +step:901/1695 train_time:88400ms step_avg:98.11ms +step:902/1695 train_time:88499ms step_avg:98.11ms +step:903/1695 train_time:88598ms step_avg:98.12ms +step:904/1695 train_time:88697ms step_avg:98.12ms +step:905/1695 train_time:88797ms step_avg:98.12ms +step:906/1695 train_time:88896ms step_avg:98.12ms +step:907/1695 train_time:88996ms step_avg:98.12ms 
+step:908/1695 train_time:89097ms step_avg:98.12ms +step:909/1695 train_time:89200ms step_avg:98.13ms +step:910/1695 train_time:89299ms step_avg:98.13ms +step:911/1695 train_time:89399ms step_avg:98.13ms +step:912/1695 train_time:89499ms step_avg:98.13ms +step:913/1695 train_time:89599ms step_avg:98.14ms +step:914/1695 train_time:89698ms step_avg:98.14ms +step:915/1695 train_time:89797ms step_avg:98.14ms +step:916/1695 train_time:89897ms step_avg:98.14ms +step:917/1695 train_time:89997ms step_avg:98.14ms +step:918/1695 train_time:90099ms step_avg:98.15ms +step:919/1695 train_time:90199ms step_avg:98.15ms +step:920/1695 train_time:90299ms step_avg:98.15ms +step:921/1695 train_time:90399ms step_avg:98.15ms +step:922/1695 train_time:90499ms step_avg:98.15ms +step:923/1695 train_time:90599ms step_avg:98.16ms +step:924/1695 train_time:90698ms step_avg:98.16ms +step:925/1695 train_time:90798ms step_avg:98.16ms +step:926/1695 train_time:90898ms step_avg:98.16ms +step:927/1695 train_time:90997ms step_avg:98.16ms +step:928/1695 train_time:91098ms step_avg:98.17ms +step:929/1695 train_time:91198ms step_avg:98.17ms +step:930/1695 train_time:91298ms step_avg:98.17ms +step:931/1695 train_time:91398ms step_avg:98.17ms +step:932/1695 train_time:91498ms step_avg:98.17ms +step:933/1695 train_time:91597ms step_avg:98.17ms +step:934/1695 train_time:91697ms step_avg:98.18ms +step:935/1695 train_time:91797ms step_avg:98.18ms +step:936/1695 train_time:91898ms step_avg:98.18ms +step:937/1695 train_time:91997ms step_avg:98.18ms +step:938/1695 train_time:92098ms step_avg:98.19ms +step:939/1695 train_time:92199ms step_avg:98.19ms +step:940/1695 train_time:92299ms step_avg:98.19ms +step:941/1695 train_time:92399ms step_avg:98.19ms +step:942/1695 train_time:92499ms step_avg:98.19ms +step:943/1695 train_time:92599ms step_avg:98.20ms +step:944/1695 train_time:92698ms step_avg:98.20ms +step:945/1695 train_time:92799ms step_avg:98.20ms +step:946/1695 train_time:92899ms step_avg:98.20ms +step:947/1695 train_time:92998ms step_avg:98.20ms +step:948/1695 train_time:93098ms step_avg:98.20ms +step:949/1695 train_time:93198ms step_avg:98.21ms +step:950/1695 train_time:93298ms step_avg:98.21ms +step:951/1695 train_time:93397ms step_avg:98.21ms +step:952/1695 train_time:93498ms step_avg:98.21ms +step:953/1695 train_time:93598ms step_avg:98.21ms +step:954/1695 train_time:93698ms step_avg:98.22ms +step:955/1695 train_time:93797ms step_avg:98.22ms +step:956/1695 train_time:93898ms step_avg:98.22ms +step:957/1695 train_time:93997ms step_avg:98.22ms +step:958/1695 train_time:94097ms step_avg:98.22ms +step:959/1695 train_time:94197ms step_avg:98.22ms +step:960/1695 train_time:94297ms step_avg:98.23ms +step:961/1695 train_time:94396ms step_avg:98.23ms +step:962/1695 train_time:94495ms step_avg:98.23ms +step:963/1695 train_time:94595ms step_avg:98.23ms +step:964/1695 train_time:94695ms step_avg:98.23ms +step:965/1695 train_time:94795ms step_avg:98.23ms +step:966/1695 train_time:94895ms step_avg:98.23ms +step:967/1695 train_time:94995ms step_avg:98.24ms +step:968/1695 train_time:95094ms step_avg:98.24ms +step:969/1695 train_time:95195ms step_avg:98.24ms +step:970/1695 train_time:95295ms step_avg:98.24ms +step:971/1695 train_time:95394ms step_avg:98.24ms +step:972/1695 train_time:95495ms step_avg:98.25ms +step:973/1695 train_time:95595ms step_avg:98.25ms +step:974/1695 train_time:95695ms step_avg:98.25ms +step:975/1695 train_time:95795ms step_avg:98.25ms +step:976/1695 train_time:95895ms step_avg:98.25ms +step:977/1695 train_time:95995ms 
step_avg:98.26ms +step:978/1695 train_time:96094ms step_avg:98.26ms +step:979/1695 train_time:96195ms step_avg:98.26ms +step:980/1695 train_time:96295ms step_avg:98.26ms +step:981/1695 train_time:96395ms step_avg:98.26ms +step:982/1695 train_time:96496ms step_avg:98.26ms +step:983/1695 train_time:96596ms step_avg:98.27ms +step:984/1695 train_time:96696ms step_avg:98.27ms +step:985/1695 train_time:96796ms step_avg:98.27ms +step:986/1695 train_time:96897ms step_avg:98.27ms +step:987/1695 train_time:96998ms step_avg:98.28ms +step:988/1695 train_time:97098ms step_avg:98.28ms +step:989/1695 train_time:97198ms step_avg:98.28ms +step:990/1695 train_time:97297ms step_avg:98.28ms +step:991/1695 train_time:97399ms step_avg:98.28ms +step:992/1695 train_time:97499ms step_avg:98.28ms +step:993/1695 train_time:97599ms step_avg:98.29ms +step:994/1695 train_time:97699ms step_avg:98.29ms +step:995/1695 train_time:97798ms step_avg:98.29ms +step:996/1695 train_time:97898ms step_avg:98.29ms +step:997/1695 train_time:97998ms step_avg:98.29ms +step:998/1695 train_time:98097ms step_avg:98.29ms +step:999/1695 train_time:98198ms step_avg:98.30ms +step:1000/1695 train_time:98297ms step_avg:98.30ms +step:1000/1695 val_loss:3.4893 train_time:98396ms step_avg:98.40ms +step:1001/1695 train_time:98423ms step_avg:98.32ms +step:1002/1695 train_time:98510ms step_avg:98.31ms +step:1003/1695 train_time:98610ms step_avg:98.31ms +step:1004/1695 train_time:98710ms step_avg:98.32ms +step:1005/1695 train_time:98810ms step_avg:98.32ms +step:1006/1695 train_time:98910ms step_avg:98.32ms +step:1007/1695 train_time:99009ms step_avg:98.32ms +step:1008/1695 train_time:99108ms step_avg:98.32ms +step:1009/1695 train_time:99208ms step_avg:98.32ms +step:1010/1695 train_time:99307ms step_avg:98.32ms +step:1011/1695 train_time:99410ms step_avg:98.33ms +step:1012/1695 train_time:99513ms step_avg:98.33ms +step:1013/1695 train_time:99614ms step_avg:98.34ms +step:1014/1695 train_time:99714ms step_avg:98.34ms +step:1015/1695 train_time:99813ms step_avg:98.34ms +step:1016/1695 train_time:99913ms step_avg:98.34ms +step:1017/1695 train_time:100012ms step_avg:98.34ms +step:1018/1695 train_time:100112ms step_avg:98.34ms +step:1019/1695 train_time:100211ms step_avg:98.34ms +step:1020/1695 train_time:100312ms step_avg:98.35ms +step:1021/1695 train_time:100414ms step_avg:98.35ms +step:1022/1695 train_time:100514ms step_avg:98.35ms +step:1023/1695 train_time:100615ms step_avg:98.35ms +step:1024/1695 train_time:100717ms step_avg:98.36ms +step:1025/1695 train_time:100816ms step_avg:98.36ms +step:1026/1695 train_time:100916ms step_avg:98.36ms +step:1027/1695 train_time:101015ms step_avg:98.36ms +step:1028/1695 train_time:101114ms step_avg:98.36ms +step:1029/1695 train_time:101215ms step_avg:98.36ms +step:1030/1695 train_time:101314ms step_avg:98.36ms +step:1031/1695 train_time:101414ms step_avg:98.37ms +step:1032/1695 train_time:101514ms step_avg:98.37ms +step:1033/1695 train_time:101614ms step_avg:98.37ms +step:1034/1695 train_time:101714ms step_avg:98.37ms +step:1035/1695 train_time:101815ms step_avg:98.37ms +step:1036/1695 train_time:101914ms step_avg:98.37ms +step:1037/1695 train_time:102015ms step_avg:98.38ms +step:1038/1695 train_time:102114ms step_avg:98.38ms +step:1039/1695 train_time:102213ms step_avg:98.38ms +step:1040/1695 train_time:102312ms step_avg:98.38ms +step:1041/1695 train_time:102413ms step_avg:98.38ms +step:1042/1695 train_time:102512ms step_avg:98.38ms +step:1043/1695 train_time:102613ms step_avg:98.38ms +step:1044/1695 
train_time:102713ms step_avg:98.38ms +step:1045/1695 train_time:102813ms step_avg:98.39ms +step:1046/1695 train_time:102913ms step_avg:98.39ms +step:1047/1695 train_time:103013ms step_avg:98.39ms +step:1048/1695 train_time:103113ms step_avg:98.39ms +step:1049/1695 train_time:103212ms step_avg:98.39ms +step:1050/1695 train_time:103312ms step_avg:98.39ms +step:1051/1695 train_time:103413ms step_avg:98.39ms +step:1052/1695 train_time:103513ms step_avg:98.40ms +step:1053/1695 train_time:103613ms step_avg:98.40ms +step:1054/1695 train_time:103713ms step_avg:98.40ms +step:1055/1695 train_time:103813ms step_avg:98.40ms +step:1056/1695 train_time:103913ms step_avg:98.40ms +step:1057/1695 train_time:104013ms step_avg:98.40ms +step:1058/1695 train_time:104113ms step_avg:98.41ms +step:1059/1695 train_time:104213ms step_avg:98.41ms +step:1060/1695 train_time:104312ms step_avg:98.41ms +step:1061/1695 train_time:104411ms step_avg:98.41ms +step:1062/1695 train_time:104511ms step_avg:98.41ms +step:1063/1695 train_time:104612ms step_avg:98.41ms +step:1064/1695 train_time:104713ms step_avg:98.41ms +step:1065/1695 train_time:104813ms step_avg:98.42ms +step:1066/1695 train_time:104913ms step_avg:98.42ms +step:1067/1695 train_time:105013ms step_avg:98.42ms +step:1068/1695 train_time:105113ms step_avg:98.42ms +step:1069/1695 train_time:105212ms step_avg:98.42ms +step:1070/1695 train_time:105312ms step_avg:98.42ms +step:1071/1695 train_time:105412ms step_avg:98.42ms +step:1072/1695 train_time:105512ms step_avg:98.43ms +step:1073/1695 train_time:105613ms step_avg:98.43ms +step:1074/1695 train_time:105713ms step_avg:98.43ms +step:1075/1695 train_time:105813ms step_avg:98.43ms +step:1076/1695 train_time:105912ms step_avg:98.43ms +step:1077/1695 train_time:106014ms step_avg:98.43ms +step:1078/1695 train_time:106113ms step_avg:98.44ms +step:1079/1695 train_time:106213ms step_avg:98.44ms +step:1080/1695 train_time:106312ms step_avg:98.44ms +step:1081/1695 train_time:106412ms step_avg:98.44ms +step:1082/1695 train_time:106512ms step_avg:98.44ms +step:1083/1695 train_time:106612ms step_avg:98.44ms +step:1084/1695 train_time:106712ms step_avg:98.44ms +step:1085/1695 train_time:106813ms step_avg:98.44ms +step:1086/1695 train_time:106913ms step_avg:98.45ms +step:1087/1695 train_time:107013ms step_avg:98.45ms +step:1088/1695 train_time:107113ms step_avg:98.45ms +step:1089/1695 train_time:107213ms step_avg:98.45ms +step:1090/1695 train_time:107313ms step_avg:98.45ms +step:1091/1695 train_time:107413ms step_avg:98.45ms +step:1092/1695 train_time:107513ms step_avg:98.45ms +step:1093/1695 train_time:107613ms step_avg:98.46ms +step:1094/1695 train_time:107713ms step_avg:98.46ms +step:1095/1695 train_time:107813ms step_avg:98.46ms +step:1096/1695 train_time:107914ms step_avg:98.46ms +step:1097/1695 train_time:108013ms step_avg:98.46ms +step:1098/1695 train_time:108113ms step_avg:98.46ms +step:1099/1695 train_time:108212ms step_avg:98.46ms +step:1100/1695 train_time:108313ms step_avg:98.47ms +step:1101/1695 train_time:108412ms step_avg:98.47ms +step:1102/1695 train_time:108512ms step_avg:98.47ms +step:1103/1695 train_time:108612ms step_avg:98.47ms +step:1104/1695 train_time:108712ms step_avg:98.47ms +step:1105/1695 train_time:108812ms step_avg:98.47ms +step:1106/1695 train_time:108913ms step_avg:98.48ms +step:1107/1695 train_time:109013ms step_avg:98.48ms +step:1108/1695 train_time:109113ms step_avg:98.48ms +step:1109/1695 train_time:109213ms step_avg:98.48ms +step:1110/1695 train_time:109313ms step_avg:98.48ms +step:1111/1695 
train_time:109413ms step_avg:98.48ms +step:1112/1695 train_time:109513ms step_avg:98.48ms +step:1113/1695 train_time:109613ms step_avg:98.48ms +step:1114/1695 train_time:109713ms step_avg:98.49ms +step:1115/1695 train_time:109813ms step_avg:98.49ms +step:1116/1695 train_time:109913ms step_avg:98.49ms +step:1117/1695 train_time:110012ms step_avg:98.49ms +step:1118/1695 train_time:110111ms step_avg:98.49ms +step:1119/1695 train_time:110211ms step_avg:98.49ms +step:1120/1695 train_time:110313ms step_avg:98.49ms +step:1121/1695 train_time:110413ms step_avg:98.50ms +step:1122/1695 train_time:110513ms step_avg:98.50ms +step:1123/1695 train_time:110613ms step_avg:98.50ms +step:1124/1695 train_time:110713ms step_avg:98.50ms +step:1125/1695 train_time:110813ms step_avg:98.50ms +step:1125/1695 val_loss:3.4387 train_time:110911ms step_avg:98.59ms +step:1126/1695 train_time:110938ms step_avg:98.52ms +step:1127/1695 train_time:111023ms step_avg:98.51ms +step:1128/1695 train_time:111124ms step_avg:98.51ms +step:1129/1695 train_time:111226ms step_avg:98.52ms +step:1130/1695 train_time:111327ms step_avg:98.52ms +step:1131/1695 train_time:111427ms step_avg:98.52ms +step:1132/1695 train_time:111526ms step_avg:98.52ms +step:1133/1695 train_time:111626ms step_avg:98.52ms +step:1134/1695 train_time:111726ms step_avg:98.52ms +step:1135/1695 train_time:111825ms step_avg:98.52ms +step:1136/1695 train_time:111927ms step_avg:98.53ms +step:1137/1695 train_time:112031ms step_avg:98.53ms +step:1138/1695 train_time:112132ms step_avg:98.53ms +step:1139/1695 train_time:112233ms step_avg:98.54ms +step:1140/1695 train_time:112333ms step_avg:98.54ms +step:1141/1695 train_time:112433ms step_avg:98.54ms +step:1142/1695 train_time:112533ms step_avg:98.54ms +step:1143/1695 train_time:112634ms step_avg:98.54ms +step:1144/1695 train_time:112734ms step_avg:98.54ms +step:1145/1695 train_time:112837ms step_avg:98.55ms +step:1146/1695 train_time:112939ms step_avg:98.55ms +step:1147/1695 train_time:113041ms step_avg:98.55ms +step:1148/1695 train_time:113142ms step_avg:98.56ms +step:1149/1695 train_time:113244ms step_avg:98.56ms +step:1150/1695 train_time:113346ms step_avg:98.56ms +step:1151/1695 train_time:113446ms step_avg:98.56ms +step:1152/1695 train_time:113547ms step_avg:98.57ms +step:1153/1695 train_time:113648ms step_avg:98.57ms +step:1154/1695 train_time:113749ms step_avg:98.57ms +step:1155/1695 train_time:113850ms step_avg:98.57ms +step:1156/1695 train_time:113951ms step_avg:98.57ms +step:1157/1695 train_time:114054ms step_avg:98.58ms +step:1158/1695 train_time:114154ms step_avg:98.58ms +step:1159/1695 train_time:114255ms step_avg:98.58ms +step:1160/1695 train_time:114355ms step_avg:98.58ms +step:1161/1695 train_time:114455ms step_avg:98.58ms +step:1162/1695 train_time:114554ms step_avg:98.58ms +step:1163/1695 train_time:114657ms step_avg:98.59ms +step:1164/1695 train_time:114757ms step_avg:98.59ms +step:1165/1695 train_time:114859ms step_avg:98.59ms +step:1166/1695 train_time:114961ms step_avg:98.59ms +step:1167/1695 train_time:115063ms step_avg:98.60ms +step:1168/1695 train_time:115165ms step_avg:98.60ms +step:1169/1695 train_time:115266ms step_avg:98.60ms +step:1170/1695 train_time:115367ms step_avg:98.60ms +step:1171/1695 train_time:115467ms step_avg:98.61ms +step:1172/1695 train_time:115570ms step_avg:98.61ms +step:1173/1695 train_time:115670ms step_avg:98.61ms +step:1174/1695 train_time:115771ms step_avg:98.61ms +step:1175/1695 train_time:115871ms step_avg:98.61ms +step:1176/1695 train_time:115973ms step_avg:98.62ms 
+step:1177/1695 train_time:116073ms step_avg:98.62ms +step:1178/1695 train_time:116173ms step_avg:98.62ms +step:1179/1695 train_time:116276ms step_avg:98.62ms +step:1180/1695 train_time:116375ms step_avg:98.62ms +step:1181/1695 train_time:116476ms step_avg:98.63ms +step:1182/1695 train_time:116577ms step_avg:98.63ms +step:1183/1695 train_time:116677ms step_avg:98.63ms +step:1184/1695 train_time:116780ms step_avg:98.63ms +step:1185/1695 train_time:116882ms step_avg:98.63ms +step:1186/1695 train_time:116984ms step_avg:98.64ms +step:1187/1695 train_time:117085ms step_avg:98.64ms +step:1188/1695 train_time:117186ms step_avg:98.64ms +step:1189/1695 train_time:117286ms step_avg:98.64ms +step:1190/1695 train_time:117387ms step_avg:98.64ms +step:1191/1695 train_time:117490ms step_avg:98.65ms +step:1192/1695 train_time:117591ms step_avg:98.65ms +step:1193/1695 train_time:117692ms step_avg:98.65ms +step:1194/1695 train_time:117794ms step_avg:98.65ms +step:1195/1695 train_time:117894ms step_avg:98.66ms +step:1196/1695 train_time:117994ms step_avg:98.66ms +step:1197/1695 train_time:118095ms step_avg:98.66ms +step:1198/1695 train_time:118195ms step_avg:98.66ms +step:1199/1695 train_time:118296ms step_avg:98.66ms +step:1200/1695 train_time:118396ms step_avg:98.66ms +step:1201/1695 train_time:118497ms step_avg:98.67ms +step:1202/1695 train_time:118600ms step_avg:98.67ms +step:1203/1695 train_time:118702ms step_avg:98.67ms +step:1204/1695 train_time:118804ms step_avg:98.67ms +step:1205/1695 train_time:118904ms step_avg:98.68ms +step:1206/1695 train_time:119005ms step_avg:98.68ms +step:1207/1695 train_time:119107ms step_avg:98.68ms +step:1208/1695 train_time:119209ms step_avg:98.68ms +step:1209/1695 train_time:119309ms step_avg:98.68ms +step:1210/1695 train_time:119410ms step_avg:98.69ms +step:1211/1695 train_time:119512ms step_avg:98.69ms +step:1212/1695 train_time:119613ms step_avg:98.69ms +step:1213/1695 train_time:119714ms step_avg:98.69ms +step:1214/1695 train_time:119814ms step_avg:98.69ms +step:1215/1695 train_time:119914ms step_avg:98.69ms +step:1216/1695 train_time:120016ms step_avg:98.70ms +step:1217/1695 train_time:120117ms step_avg:98.70ms +step:1218/1695 train_time:120218ms step_avg:98.70ms +step:1219/1695 train_time:120320ms step_avg:98.70ms +step:1220/1695 train_time:120422ms step_avg:98.71ms +step:1221/1695 train_time:120523ms step_avg:98.71ms +step:1222/1695 train_time:120623ms step_avg:98.71ms +step:1223/1695 train_time:120727ms step_avg:98.71ms +step:1224/1695 train_time:120827ms step_avg:98.72ms +step:1225/1695 train_time:120928ms step_avg:98.72ms +step:1226/1695 train_time:121029ms step_avg:98.72ms +step:1227/1695 train_time:121130ms step_avg:98.72ms +step:1228/1695 train_time:121231ms step_avg:98.72ms +step:1229/1695 train_time:121332ms step_avg:98.72ms +step:1230/1695 train_time:121432ms step_avg:98.73ms +step:1231/1695 train_time:121533ms step_avg:98.73ms +step:1232/1695 train_time:121634ms step_avg:98.73ms +step:1233/1695 train_time:121734ms step_avg:98.73ms +step:1234/1695 train_time:121836ms step_avg:98.73ms +step:1235/1695 train_time:121936ms step_avg:98.73ms +step:1236/1695 train_time:122038ms step_avg:98.74ms +step:1237/1695 train_time:122140ms step_avg:98.74ms +step:1238/1695 train_time:122242ms step_avg:98.74ms +step:1239/1695 train_time:122345ms step_avg:98.75ms +step:1240/1695 train_time:122445ms step_avg:98.75ms +step:1241/1695 train_time:122546ms step_avg:98.75ms +step:1242/1695 train_time:122647ms step_avg:98.75ms +step:1243/1695 train_time:122748ms step_avg:98.75ms 
+step:1244/1695 train_time:122850ms step_avg:98.75ms +step:1245/1695 train_time:122951ms step_avg:98.76ms +step:1246/1695 train_time:123053ms step_avg:98.76ms +step:1247/1695 train_time:123154ms step_avg:98.76ms +step:1248/1695 train_time:123256ms step_avg:98.76ms +step:1249/1695 train_time:123356ms step_avg:98.76ms +step:1250/1695 train_time:123455ms step_avg:98.76ms +step:1250/1695 val_loss:3.3926 train_time:123554ms step_avg:98.84ms +step:1251/1695 train_time:123581ms step_avg:98.79ms +step:1252/1695 train_time:123668ms step_avg:98.78ms +step:1253/1695 train_time:123770ms step_avg:98.78ms +step:1254/1695 train_time:123872ms step_avg:98.78ms +step:1255/1695 train_time:123973ms step_avg:98.78ms +step:1256/1695 train_time:124073ms step_avg:98.78ms +step:1257/1695 train_time:124173ms step_avg:98.79ms +step:1258/1695 train_time:124273ms step_avg:98.79ms +step:1259/1695 train_time:124373ms step_avg:98.79ms +step:1260/1695 train_time:124473ms step_avg:98.79ms +step:1261/1695 train_time:124576ms step_avg:98.79ms +step:1262/1695 train_time:124678ms step_avg:98.79ms +step:1263/1695 train_time:124779ms step_avg:98.80ms +step:1264/1695 train_time:124880ms step_avg:98.80ms +step:1265/1695 train_time:124981ms step_avg:98.80ms +step:1266/1695 train_time:125081ms step_avg:98.80ms +step:1267/1695 train_time:125182ms step_avg:98.80ms +step:1268/1695 train_time:125282ms step_avg:98.80ms +step:1269/1695 train_time:125383ms step_avg:98.80ms +step:1270/1695 train_time:125485ms step_avg:98.81ms +step:1271/1695 train_time:125588ms step_avg:98.81ms +step:1272/1695 train_time:125689ms step_avg:98.81ms +step:1273/1695 train_time:125790ms step_avg:98.81ms +step:1274/1695 train_time:125891ms step_avg:98.82ms +step:1275/1695 train_time:125992ms step_avg:98.82ms +step:1276/1695 train_time:126094ms step_avg:98.82ms +step:1277/1695 train_time:126195ms step_avg:98.82ms +step:1278/1695 train_time:126295ms step_avg:98.82ms +step:1279/1695 train_time:126396ms step_avg:98.82ms +step:1280/1695 train_time:126496ms step_avg:98.83ms +step:1281/1695 train_time:126597ms step_avg:98.83ms +step:1282/1695 train_time:126697ms step_avg:98.83ms +step:1283/1695 train_time:126797ms step_avg:98.83ms +step:1284/1695 train_time:126898ms step_avg:98.83ms +step:1285/1695 train_time:126999ms step_avg:98.83ms +step:1286/1695 train_time:127100ms step_avg:98.83ms +step:1287/1695 train_time:127200ms step_avg:98.83ms +step:1288/1695 train_time:127302ms step_avg:98.84ms +step:1289/1695 train_time:127404ms step_avg:98.84ms +step:1290/1695 train_time:127504ms step_avg:98.84ms +step:1291/1695 train_time:127605ms step_avg:98.84ms +step:1292/1695 train_time:127707ms step_avg:98.84ms +step:1293/1695 train_time:127809ms step_avg:98.85ms +step:1294/1695 train_time:127912ms step_avg:98.85ms +step:1295/1695 train_time:128014ms step_avg:98.85ms +step:1296/1695 train_time:128114ms step_avg:98.85ms +step:1297/1695 train_time:128216ms step_avg:98.86ms +step:1298/1695 train_time:128317ms step_avg:98.86ms +step:1299/1695 train_time:128418ms step_avg:98.86ms +step:1300/1695 train_time:128518ms step_avg:98.86ms +step:1301/1695 train_time:128619ms step_avg:98.86ms +step:1302/1695 train_time:128720ms step_avg:98.86ms +step:1303/1695 train_time:128821ms step_avg:98.87ms +step:1304/1695 train_time:128923ms step_avg:98.87ms +step:1305/1695 train_time:129024ms step_avg:98.87ms +step:1306/1695 train_time:129127ms step_avg:98.87ms +step:1307/1695 train_time:129229ms step_avg:98.87ms +step:1308/1695 train_time:129330ms step_avg:98.88ms +step:1309/1695 train_time:129433ms 
step_avg:98.88ms +step:1310/1695 train_time:129535ms step_avg:98.88ms +step:1311/1695 train_time:129636ms step_avg:98.88ms +step:1312/1695 train_time:129736ms step_avg:98.88ms +step:1313/1695 train_time:129838ms step_avg:98.89ms +step:1314/1695 train_time:129938ms step_avg:98.89ms +step:1315/1695 train_time:130038ms step_avg:98.89ms +step:1316/1695 train_time:130139ms step_avg:98.89ms +step:1317/1695 train_time:130239ms step_avg:98.89ms +step:1318/1695 train_time:130339ms step_avg:98.89ms +step:1319/1695 train_time:130441ms step_avg:98.89ms +step:1320/1695 train_time:130543ms step_avg:98.90ms +step:1321/1695 train_time:130646ms step_avg:98.90ms +step:1322/1695 train_time:130749ms step_avg:98.90ms +step:1323/1695 train_time:130850ms step_avg:98.90ms +step:1324/1695 train_time:130952ms step_avg:98.91ms +step:1325/1695 train_time:131053ms step_avg:98.91ms +step:1326/1695 train_time:131154ms step_avg:98.91ms +step:1327/1695 train_time:131255ms step_avg:98.91ms +step:1328/1695 train_time:131355ms step_avg:98.91ms +step:1329/1695 train_time:131456ms step_avg:98.91ms +step:1330/1695 train_time:131556ms step_avg:98.91ms +step:1331/1695 train_time:131657ms step_avg:98.92ms +step:1332/1695 train_time:131758ms step_avg:98.92ms +step:1333/1695 train_time:131859ms step_avg:98.92ms +step:1334/1695 train_time:131960ms step_avg:98.92ms +step:1335/1695 train_time:132062ms step_avg:98.92ms +step:1336/1695 train_time:132163ms step_avg:98.92ms +step:1337/1695 train_time:132265ms step_avg:98.93ms +step:1338/1695 train_time:132366ms step_avg:98.93ms +step:1339/1695 train_time:132467ms step_avg:98.93ms +step:1340/1695 train_time:132569ms step_avg:98.93ms +step:1341/1695 train_time:132670ms step_avg:98.93ms +step:1342/1695 train_time:132770ms step_avg:98.93ms +step:1343/1695 train_time:132873ms step_avg:98.94ms +step:1344/1695 train_time:132974ms step_avg:98.94ms +step:1345/1695 train_time:133074ms step_avg:98.94ms +step:1346/1695 train_time:133176ms step_avg:98.94ms +step:1347/1695 train_time:133277ms step_avg:98.94ms +step:1348/1695 train_time:133377ms step_avg:98.94ms +step:1349/1695 train_time:133477ms step_avg:98.95ms +step:1350/1695 train_time:133578ms step_avg:98.95ms +step:1351/1695 train_time:133678ms step_avg:98.95ms +step:1352/1695 train_time:133778ms step_avg:98.95ms +step:1353/1695 train_time:133879ms step_avg:98.95ms +step:1354/1695 train_time:133980ms step_avg:98.95ms +step:1355/1695 train_time:134081ms step_avg:98.95ms +step:1356/1695 train_time:134183ms step_avg:98.95ms +step:1357/1695 train_time:134283ms step_avg:98.96ms +step:1358/1695 train_time:134384ms step_avg:98.96ms +step:1359/1695 train_time:134485ms step_avg:98.96ms +step:1360/1695 train_time:134586ms step_avg:98.96ms +step:1361/1695 train_time:134686ms step_avg:98.96ms +step:1362/1695 train_time:134787ms step_avg:98.96ms +step:1363/1695 train_time:134889ms step_avg:98.96ms +step:1364/1695 train_time:134992ms step_avg:98.97ms +step:1365/1695 train_time:135094ms step_avg:98.97ms +step:1366/1695 train_time:135194ms step_avg:98.97ms +step:1367/1695 train_time:135295ms step_avg:98.97ms +step:1368/1695 train_time:135397ms step_avg:98.97ms +step:1369/1695 train_time:135496ms step_avg:98.97ms +step:1370/1695 train_time:135596ms step_avg:98.98ms +step:1371/1695 train_time:135697ms step_avg:98.98ms +step:1372/1695 train_time:135798ms step_avg:98.98ms +step:1373/1695 train_time:135899ms step_avg:98.98ms +step:1374/1695 train_time:136000ms step_avg:98.98ms +step:1375/1695 train_time:136101ms step_avg:98.98ms +step:1375/1695 val_loss:3.3538 
train_time:136200ms step_avg:99.05ms +step:1376/1695 train_time:136228ms step_avg:99.00ms +step:1377/1695 train_time:136317ms step_avg:99.00ms +step:1378/1695 train_time:136418ms step_avg:99.00ms +step:1379/1695 train_time:136518ms step_avg:99.00ms +step:1380/1695 train_time:136620ms step_avg:99.00ms +step:1381/1695 train_time:136720ms step_avg:99.00ms +step:1382/1695 train_time:136820ms step_avg:99.00ms +step:1383/1695 train_time:136919ms step_avg:99.00ms +step:1384/1695 train_time:137019ms step_avg:99.00ms +step:1385/1695 train_time:137122ms step_avg:99.01ms +step:1386/1695 train_time:137230ms step_avg:99.01ms +step:1387/1695 train_time:137331ms step_avg:99.01ms +step:1388/1695 train_time:137434ms step_avg:99.02ms +step:1389/1695 train_time:137537ms step_avg:99.02ms +step:1390/1695 train_time:137638ms step_avg:99.02ms +step:1391/1695 train_time:137739ms step_avg:99.02ms +step:1392/1695 train_time:137840ms step_avg:99.02ms +step:1393/1695 train_time:137941ms step_avg:99.02ms +step:1394/1695 train_time:138043ms step_avg:99.03ms +step:1395/1695 train_time:138146ms step_avg:99.03ms +step:1396/1695 train_time:138249ms step_avg:99.03ms +step:1397/1695 train_time:138353ms step_avg:99.04ms +step:1398/1695 train_time:138456ms step_avg:99.04ms +step:1399/1695 train_time:138559ms step_avg:99.04ms +step:1400/1695 train_time:138660ms step_avg:99.04ms +step:1401/1695 train_time:138762ms step_avg:99.04ms +step:1402/1695 train_time:138863ms step_avg:99.05ms +step:1403/1695 train_time:138965ms step_avg:99.05ms +step:1404/1695 train_time:139067ms step_avg:99.05ms +step:1405/1695 train_time:139170ms step_avg:99.05ms +step:1406/1695 train_time:139273ms step_avg:99.06ms +step:1407/1695 train_time:139376ms step_avg:99.06ms +step:1408/1695 train_time:139477ms step_avg:99.06ms +step:1409/1695 train_time:139582ms step_avg:99.06ms +step:1410/1695 train_time:139683ms step_avg:99.07ms +step:1411/1695 train_time:139785ms step_avg:99.07ms +step:1412/1695 train_time:139889ms step_avg:99.07ms +step:1413/1695 train_time:139989ms step_avg:99.07ms +step:1414/1695 train_time:140091ms step_avg:99.07ms +step:1415/1695 train_time:140194ms step_avg:99.08ms +step:1416/1695 train_time:140295ms step_avg:99.08ms +step:1417/1695 train_time:140397ms step_avg:99.08ms +step:1418/1695 train_time:140498ms step_avg:99.08ms +step:1419/1695 train_time:140601ms step_avg:99.08ms +step:1420/1695 train_time:140702ms step_avg:99.09ms +step:1421/1695 train_time:140804ms step_avg:99.09ms +step:1422/1695 train_time:140907ms step_avg:99.09ms +step:1423/1695 train_time:141009ms step_avg:99.09ms +step:1424/1695 train_time:141112ms step_avg:99.10ms +step:1425/1695 train_time:141213ms step_avg:99.10ms +step:1426/1695 train_time:141315ms step_avg:99.10ms +step:1427/1695 train_time:141417ms step_avg:99.10ms +step:1428/1695 train_time:141519ms step_avg:99.10ms +step:1429/1695 train_time:141620ms step_avg:99.10ms +step:1430/1695 train_time:141721ms step_avg:99.11ms +step:1431/1695 train_time:141824ms step_avg:99.11ms +step:1432/1695 train_time:141925ms step_avg:99.11ms +step:1433/1695 train_time:142028ms step_avg:99.11ms +step:1434/1695 train_time:142129ms step_avg:99.11ms +step:1435/1695 train_time:142233ms step_avg:99.12ms +step:1436/1695 train_time:142337ms step_avg:99.12ms +step:1437/1695 train_time:142439ms step_avg:99.12ms +step:1438/1695 train_time:142540ms step_avg:99.12ms +step:1439/1695 train_time:142643ms step_avg:99.13ms +step:1440/1695 train_time:142745ms step_avg:99.13ms +step:1441/1695 train_time:142848ms step_avg:99.13ms +step:1442/1695 
train_time:142948ms step_avg:99.13ms +step:1443/1695 train_time:143049ms step_avg:99.13ms +step:1444/1695 train_time:143152ms step_avg:99.14ms +step:1445/1695 train_time:143253ms step_avg:99.14ms +step:1446/1695 train_time:143354ms step_avg:99.14ms +step:1447/1695 train_time:143455ms step_avg:99.14ms +step:1448/1695 train_time:143558ms step_avg:99.14ms +step:1449/1695 train_time:143659ms step_avg:99.14ms +step:1450/1695 train_time:143760ms step_avg:99.14ms +step:1451/1695 train_time:143861ms step_avg:99.15ms +step:1452/1695 train_time:143964ms step_avg:99.15ms +step:1453/1695 train_time:144069ms step_avg:99.15ms +step:1454/1695 train_time:144171ms step_avg:99.15ms +step:1455/1695 train_time:144273ms step_avg:99.16ms +step:1456/1695 train_time:144375ms step_avg:99.16ms +step:1457/1695 train_time:144478ms step_avg:99.16ms +step:1458/1695 train_time:144580ms step_avg:99.16ms +step:1459/1695 train_time:144682ms step_avg:99.17ms +step:1460/1695 train_time:144784ms step_avg:99.17ms +step:1461/1695 train_time:144887ms step_avg:99.17ms +step:1462/1695 train_time:144988ms step_avg:99.17ms +step:1463/1695 train_time:145089ms step_avg:99.17ms +step:1464/1695 train_time:145192ms step_avg:99.17ms +step:1465/1695 train_time:145293ms step_avg:99.18ms +step:1466/1695 train_time:145395ms step_avg:99.18ms +step:1467/1695 train_time:145495ms step_avg:99.18ms +step:1468/1695 train_time:145598ms step_avg:99.18ms +step:1469/1695 train_time:145700ms step_avg:99.18ms +step:1470/1695 train_time:145801ms step_avg:99.18ms +step:1471/1695 train_time:145903ms step_avg:99.19ms +step:1472/1695 train_time:146005ms step_avg:99.19ms +step:1473/1695 train_time:146107ms step_avg:99.19ms +step:1474/1695 train_time:146209ms step_avg:99.19ms +step:1475/1695 train_time:146310ms step_avg:99.19ms +step:1476/1695 train_time:146413ms step_avg:99.20ms +step:1477/1695 train_time:146515ms step_avg:99.20ms +step:1478/1695 train_time:146617ms step_avg:99.20ms +step:1479/1695 train_time:146718ms step_avg:99.20ms +step:1480/1695 train_time:146819ms step_avg:99.20ms +step:1481/1695 train_time:146920ms step_avg:99.20ms +step:1482/1695 train_time:147022ms step_avg:99.21ms +step:1483/1695 train_time:147125ms step_avg:99.21ms +step:1484/1695 train_time:147227ms step_avg:99.21ms +step:1485/1695 train_time:147329ms step_avg:99.21ms +step:1486/1695 train_time:147431ms step_avg:99.21ms +step:1487/1695 train_time:147532ms step_avg:99.21ms +step:1488/1695 train_time:147635ms step_avg:99.22ms +step:1489/1695 train_time:147737ms step_avg:99.22ms +step:1490/1695 train_time:147838ms step_avg:99.22ms +step:1491/1695 train_time:147940ms step_avg:99.22ms +step:1492/1695 train_time:148041ms step_avg:99.22ms +step:1493/1695 train_time:148142ms step_avg:99.22ms +step:1494/1695 train_time:148245ms step_avg:99.23ms +step:1495/1695 train_time:148349ms step_avg:99.23ms +step:1496/1695 train_time:148451ms step_avg:99.23ms +step:1497/1695 train_time:148552ms step_avg:99.23ms +step:1498/1695 train_time:148654ms step_avg:99.23ms +step:1499/1695 train_time:148755ms step_avg:99.24ms +step:1500/1695 train_time:148857ms step_avg:99.24ms +step:1500/1695 val_loss:3.3189 train_time:148955ms step_avg:99.30ms +step:1501/1695 train_time:148982ms step_avg:99.26ms +step:1502/1695 train_time:149067ms step_avg:99.25ms +step:1503/1695 train_time:149169ms step_avg:99.25ms +step:1504/1695 train_time:149269ms step_avg:99.25ms +step:1505/1695 train_time:149370ms step_avg:99.25ms +step:1506/1695 train_time:149472ms step_avg:99.25ms +step:1507/1695 train_time:149572ms step_avg:99.25ms 
+step:1508/1695 train_time:149673ms step_avg:99.25ms +step:1509/1695 train_time:149776ms step_avg:99.26ms +step:1510/1695 train_time:149878ms step_avg:99.26ms +step:1511/1695 train_time:149982ms step_avg:99.26ms +step:1512/1695 train_time:150085ms step_avg:99.26ms +step:1513/1695 train_time:150187ms step_avg:99.26ms +step:1514/1695 train_time:150289ms step_avg:99.27ms +step:1515/1695 train_time:150395ms step_avg:99.27ms +step:1516/1695 train_time:150496ms step_avg:99.27ms +step:1517/1695 train_time:150597ms step_avg:99.27ms +step:1518/1695 train_time:150699ms step_avg:99.27ms +step:1519/1695 train_time:150803ms step_avg:99.28ms +step:1520/1695 train_time:150905ms step_avg:99.28ms +step:1521/1695 train_time:151006ms step_avg:99.28ms +step:1522/1695 train_time:151108ms step_avg:99.28ms +step:1523/1695 train_time:151210ms step_avg:99.28ms +step:1524/1695 train_time:151314ms step_avg:99.29ms +step:1525/1695 train_time:151417ms step_avg:99.29ms +step:1526/1695 train_time:151520ms step_avg:99.29ms +step:1527/1695 train_time:151623ms step_avg:99.29ms +step:1528/1695 train_time:151727ms step_avg:99.30ms +step:1529/1695 train_time:151828ms step_avg:99.30ms +step:1530/1695 train_time:151931ms step_avg:99.30ms +step:1531/1695 train_time:152032ms step_avg:99.30ms +step:1532/1695 train_time:152135ms step_avg:99.31ms +step:1533/1695 train_time:152238ms step_avg:99.31ms +step:1534/1695 train_time:152340ms step_avg:99.31ms +step:1535/1695 train_time:152442ms step_avg:99.31ms +step:1536/1695 train_time:152543ms step_avg:99.31ms +step:1537/1695 train_time:152644ms step_avg:99.31ms +step:1538/1695 train_time:152746ms step_avg:99.31ms +step:1539/1695 train_time:152847ms step_avg:99.32ms +step:1540/1695 train_time:152950ms step_avg:99.32ms +step:1541/1695 train_time:153052ms step_avg:99.32ms +step:1542/1695 train_time:153156ms step_avg:99.32ms +step:1543/1695 train_time:153258ms step_avg:99.32ms +step:1544/1695 train_time:153360ms step_avg:99.33ms +step:1545/1695 train_time:153462ms step_avg:99.33ms +step:1546/1695 train_time:153564ms step_avg:99.33ms +step:1547/1695 train_time:153666ms step_avg:99.33ms +step:1548/1695 train_time:153768ms step_avg:99.33ms +step:1549/1695 train_time:153870ms step_avg:99.34ms +step:1550/1695 train_time:153971ms step_avg:99.34ms +step:1551/1695 train_time:154074ms step_avg:99.34ms +step:1552/1695 train_time:154175ms step_avg:99.34ms +step:1553/1695 train_time:154280ms step_avg:99.34ms +step:1554/1695 train_time:154381ms step_avg:99.34ms +step:1555/1695 train_time:154482ms step_avg:99.35ms +step:1556/1695 train_time:154584ms step_avg:99.35ms +step:1557/1695 train_time:154688ms step_avg:99.35ms +step:1558/1695 train_time:154791ms step_avg:99.35ms +step:1559/1695 train_time:154893ms step_avg:99.35ms +step:1560/1695 train_time:154995ms step_avg:99.36ms +step:1561/1695 train_time:155097ms step_avg:99.36ms +step:1562/1695 train_time:155200ms step_avg:99.36ms +step:1563/1695 train_time:155304ms step_avg:99.36ms +step:1564/1695 train_time:155405ms step_avg:99.36ms +step:1565/1695 train_time:155506ms step_avg:99.37ms +step:1566/1695 train_time:155608ms step_avg:99.37ms +step:1567/1695 train_time:155709ms step_avg:99.37ms +step:1568/1695 train_time:155810ms step_avg:99.37ms +step:1569/1695 train_time:155911ms step_avg:99.37ms +step:1570/1695 train_time:156015ms step_avg:99.37ms +step:1571/1695 train_time:156117ms step_avg:99.37ms +step:1572/1695 train_time:156219ms step_avg:99.38ms +step:1573/1695 train_time:156321ms step_avg:99.38ms +step:1574/1695 train_time:156422ms step_avg:99.38ms 
+step:1575/1695 train_time:156525ms step_avg:99.38ms +step:1576/1695 train_time:156627ms step_avg:99.38ms +step:1577/1695 train_time:156730ms step_avg:99.38ms +step:1578/1695 train_time:156831ms step_avg:99.39ms +step:1579/1695 train_time:156932ms step_avg:99.39ms +step:1580/1695 train_time:157035ms step_avg:99.39ms +step:1581/1695 train_time:157137ms step_avg:99.39ms +step:1582/1695 train_time:157240ms step_avg:99.39ms +step:1583/1695 train_time:157343ms step_avg:99.40ms +step:1584/1695 train_time:157446ms step_avg:99.40ms +step:1585/1695 train_time:157547ms step_avg:99.40ms +step:1586/1695 train_time:157650ms step_avg:99.40ms +step:1587/1695 train_time:157751ms step_avg:99.40ms +step:1588/1695 train_time:157851ms step_avg:99.40ms +step:1589/1695 train_time:157952ms step_avg:99.40ms +step:1590/1695 train_time:158053ms step_avg:99.40ms +step:1591/1695 train_time:158155ms step_avg:99.41ms +step:1592/1695 train_time:158259ms step_avg:99.41ms +step:1593/1695 train_time:158361ms step_avg:99.41ms +step:1594/1695 train_time:158464ms step_avg:99.41ms +step:1595/1695 train_time:158566ms step_avg:99.41ms +step:1596/1695 train_time:158667ms step_avg:99.42ms +step:1597/1695 train_time:158769ms step_avg:99.42ms +step:1598/1695 train_time:158873ms step_avg:99.42ms +step:1599/1695 train_time:158974ms step_avg:99.42ms +step:1600/1695 train_time:159076ms step_avg:99.42ms +step:1601/1695 train_time:159179ms step_avg:99.42ms +step:1602/1695 train_time:159280ms step_avg:99.43ms +step:1603/1695 train_time:159382ms step_avg:99.43ms +step:1604/1695 train_time:159483ms step_avg:99.43ms +step:1605/1695 train_time:159587ms step_avg:99.43ms +step:1606/1695 train_time:159690ms step_avg:99.43ms +step:1607/1695 train_time:159790ms step_avg:99.43ms +step:1608/1695 train_time:159891ms step_avg:99.43ms +step:1609/1695 train_time:159992ms step_avg:99.44ms +step:1610/1695 train_time:160094ms step_avg:99.44ms +step:1611/1695 train_time:160196ms step_avg:99.44ms +step:1612/1695 train_time:160299ms step_avg:99.44ms +step:1613/1695 train_time:160401ms step_avg:99.44ms +step:1614/1695 train_time:160502ms step_avg:99.44ms +step:1615/1695 train_time:160604ms step_avg:99.45ms +step:1616/1695 train_time:160707ms step_avg:99.45ms +step:1617/1695 train_time:160809ms step_avg:99.45ms +step:1618/1695 train_time:160910ms step_avg:99.45ms +step:1619/1695 train_time:161011ms step_avg:99.45ms +step:1620/1695 train_time:161114ms step_avg:99.45ms +step:1621/1695 train_time:161215ms step_avg:99.45ms +step:1622/1695 train_time:161318ms step_avg:99.46ms +step:1623/1695 train_time:161420ms step_avg:99.46ms +step:1624/1695 train_time:161522ms step_avg:99.46ms +step:1625/1695 train_time:161626ms step_avg:99.46ms +step:1625/1695 val_loss:3.2901 train_time:161726ms step_avg:99.52ms +step:1626/1695 train_time:161754ms step_avg:99.48ms +step:1627/1695 train_time:161839ms step_avg:99.47ms +step:1628/1695 train_time:161940ms step_avg:99.47ms +step:1629/1695 train_time:162042ms step_avg:99.47ms +step:1630/1695 train_time:162143ms step_avg:99.47ms +step:1631/1695 train_time:162245ms step_avg:99.48ms +step:1632/1695 train_time:162347ms step_avg:99.48ms +step:1633/1695 train_time:162449ms step_avg:99.48ms +step:1634/1695 train_time:162551ms step_avg:99.48ms +step:1635/1695 train_time:162655ms step_avg:99.48ms +step:1636/1695 train_time:162758ms step_avg:99.49ms +step:1637/1695 train_time:162860ms step_avg:99.49ms +step:1638/1695 train_time:162963ms step_avg:99.49ms +step:1639/1695 train_time:163064ms step_avg:99.49ms +step:1640/1695 train_time:163167ms 
step_avg:99.49ms +step:1641/1695 train_time:163270ms step_avg:99.49ms +step:1642/1695 train_time:163371ms step_avg:99.50ms +step:1643/1695 train_time:163473ms step_avg:99.50ms +step:1644/1695 train_time:163575ms step_avg:99.50ms +step:1645/1695 train_time:163679ms step_avg:99.50ms +step:1646/1695 train_time:163781ms step_avg:99.50ms +step:1647/1695 train_time:163886ms step_avg:99.51ms +step:1648/1695 train_time:163990ms step_avg:99.51ms +step:1649/1695 train_time:164092ms step_avg:99.51ms +step:1650/1695 train_time:164195ms step_avg:99.51ms +step:1651/1695 train_time:164297ms step_avg:99.51ms +step:1652/1695 train_time:164400ms step_avg:99.52ms +step:1653/1695 train_time:164503ms step_avg:99.52ms +step:1654/1695 train_time:164605ms step_avg:99.52ms +step:1655/1695 train_time:164707ms step_avg:99.52ms +step:1656/1695 train_time:164809ms step_avg:99.52ms +step:1657/1695 train_time:164912ms step_avg:99.52ms +step:1658/1695 train_time:165016ms step_avg:99.53ms +step:1659/1695 train_time:165122ms step_avg:99.53ms +step:1660/1695 train_time:165224ms step_avg:99.53ms +step:1661/1695 train_time:165328ms step_avg:99.53ms +step:1662/1695 train_time:165432ms step_avg:99.54ms +step:1663/1695 train_time:165536ms step_avg:99.54ms +step:1664/1695 train_time:165638ms step_avg:99.54ms +step:1665/1695 train_time:165743ms step_avg:99.55ms +step:1666/1695 train_time:165846ms step_avg:99.55ms +step:1667/1695 train_time:165950ms step_avg:99.55ms +step:1668/1695 train_time:166057ms step_avg:99.55ms +step:1669/1695 train_time:166161ms step_avg:99.56ms +step:1670/1695 train_time:166264ms step_avg:99.56ms +step:1671/1695 train_time:166366ms step_avg:99.56ms +step:1672/1695 train_time:166469ms step_avg:99.56ms +step:1673/1695 train_time:166572ms step_avg:99.56ms +step:1674/1695 train_time:166674ms step_avg:99.57ms +step:1675/1695 train_time:166777ms step_avg:99.57ms +step:1676/1695 train_time:166881ms step_avg:99.57ms +step:1677/1695 train_time:166982ms step_avg:99.57ms +step:1678/1695 train_time:167088ms step_avg:99.58ms +step:1679/1695 train_time:167190ms step_avg:99.58ms +step:1680/1695 train_time:167292ms step_avg:99.58ms +step:1681/1695 train_time:167395ms step_avg:99.58ms +step:1682/1695 train_time:167502ms step_avg:99.58ms +step:1683/1695 train_time:167603ms step_avg:99.59ms +step:1684/1695 train_time:167707ms step_avg:99.59ms +step:1685/1695 train_time:167810ms step_avg:99.59ms +step:1686/1695 train_time:167913ms step_avg:99.59ms +step:1687/1695 train_time:168017ms step_avg:99.60ms +step:1688/1695 train_time:168119ms step_avg:99.60ms +step:1689/1695 train_time:168220ms step_avg:99.60ms +step:1690/1695 train_time:168322ms step_avg:99.60ms +step:1691/1695 train_time:168425ms step_avg:99.60ms +step:1692/1695 train_time:168528ms step_avg:99.60ms +step:1693/1695 train_time:168632ms step_avg:99.61ms +step:1694/1695 train_time:168736ms step_avg:99.61ms +step:1695/1695 train_time:168839ms step_avg:99.61ms +step:1695/1695 val_loss:3.2772 train_time:168938ms step_avg:99.67ms +peak memory allocated: 34004 MiB reserved: 49660 MiB diff --git a/records/082325_SparseAttnGate/21e732fb-4c4b-4db9-94bc-9fcd5d59b080.txt b/records/082325_SparseAttnGate/21e732fb-4c4b-4db9-94bc-9fcd5d59b080.txt new file mode 100644 index 000000000..6b4098881 --- /dev/null +++ b/records/082325_SparseAttnGate/21e732fb-4c4b-4db9-94bc-9fcd5d59b080.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
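+ # (descriptive note) at_ptrs, the A.T view, advances by the same column stride as a_ptrs: both operands of the A @ A.T product stream the K axis of the single input matrix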
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
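+
+    For intuition, a non-distributed sketch of one update for a single 2D parameter p
+    (ignoring the reduce-scatter/all-gather sharding that step() below overlaps with
+    compute; `buf` denotes that parameter's momentum buffer):
+
+        buf.lerp_(p.grad, 1 - momentum)        # update momentum buffer
+        g = p.grad.lerp(buf, momentum)         # Nesterov-style blend
+        v = newton_schulz_triton(g)            # orthogonalize the update
+        p.mul_(1 - lr * weight_decay)          # decoupled weight decay
+        p.add_(v, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)
+
+    The real step() additionally applies the per-parameter lr_mul / wd_mul attributes
+    set by the model code.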
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, sa_lambdas, block_mask)
+        x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 5) % dist.get_world_size()
+        self.scalars = nn.Parameter(torch.cat([
+            torch.ones(num_layers), # skip_weights
+            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
+            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
+            torch.ones(pad),
+        ]))
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
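+        # lr_mul is consumed by the optimizer (DistAdam above reads
+        # lr = group["lr"] * getattr(p, "lr_mul", 1.0)), so the embeddings train at
+        # 75x the Adam base lr, the scalars at 5x, and the lm_head at 1x (set just below)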
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
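+
+    # Worked example for the two masks returned above: with BLOCK_SIZE=128 and
+    # sliding_window_num_blocks=4, a query block sees at most ~4 trailing kv blocks
+    # (512 tokens). Fully visible blocks are clamped to window_size_blocks - 1, and
+    # partially visible blocks (which still evaluate mask_mod at runtime) fill the
+    # remainder, at least 1. The second mask halves the window, so the short-attention
+    # layers get a 2-block (256-token) field at the same point in training.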
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
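+
+# For reference, the shard layout consumed above: a 256-word int32 header
+# (header[0]=20240520 magic, header[1]=1 version, header[2]=token count) followed by
+# the raw uint16 token ids. A minimal writer sketch; write_data_shard is a hypothetical
+# helper for producing compatible shards, not something this training run uses:
+def write_data_shard(file: Path, tokens: Tensor):
+    assert tokens.dtype == torch.uint16 and tokens.ndim == 1
+    header = torch.zeros(256, dtype=torch.int32)
+    header[0], header[1], header[2] = 20240520, 1, len(tokens) # magic, version, token count
+    with file.open("wb") as f:
+        f.write(header.numpy().tobytes())
+        f.write(tokens.numpy().tobytes())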
+
+# find world_size starting indices, such that each begins with token 50256 and local batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len
+        if pos + token_window + 1 >= len(tokens):
+            tokens = _load_data_shard(next(file_iter))
+            pos = 0
+        for _ in range(grad_accum_steps):
+            if align_to_bos:
+                batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window)
+                start_idx = batch_starts[rank]
+            else:
+                tokens_consumed = batch_size
+                start_idx = pos + rank * seq_len
+            buf = tokens[start_idx:][:seq_len + 1]
+            inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
+            targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
+            pos += tokens_consumed
+            token_window -= tokens_consumed
+            yield inputs, targets
+
+# -----------------------------------------------------------------------------
+# int main
+
+
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
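+
+# note on the split above: every >=2-D weight inside model.blocks goes to Muon below,
+# while the embeddings, the lm_head, and the 1-D scalars (skip weights/lambdas) go to
+# the sharded DistAdam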
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x < 1
+    if x < 1 - args.cooldown_frac:
+        return 1.0
+    else:
+        w = (1 - x) / args.cooldown_frac
+        return w * 1.0 + (1 - w) * 0.1
+
+# attention window size schedule: linearly increase
+@lru_cache(1)
+def get_window_size_blocks_helper(window_size: int):
+    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+def get_window_size_blocks(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x <= 1
+    # Linearly increase the block-wise sliding window size over training 128 -> 1792
+    # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
+    window_size = next_multiple_of_n(1728 * x, n=128)
+    return get_window_size_blocks_helper(window_size)
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 10
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
+for _ in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    model(inputs, targets, get_window_size_blocks(1)).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        val_batch_size = world_size * args.val_seq_len
+        assert args.val_tokens % val_batch_size == 0
+        val_steps = args.val_tokens // val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets = next(val_loader)
+                val_loss += model(inputs, targets, get_window_size_blocks(step))
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
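+        # each rank's val_loss is already a per-token mean (eval uses reduction="mean"
+        # in GPT.forward), so the AVG all-reduce just averages it across ranks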
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:08:17 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 296819 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 296820 C /usr/bin/python3 614MiB | +| 0 N/A N/A 296821 C /usr/bin/python3 614MiB | +| 0 N/A N/A 296822 C /usr/bin/python3 614MiB | +| 0 N/A N/A 296823 C /usr/bin/python3 614MiB | +| 0 N/A N/A 296824 C /usr/bin/python3 614MiB | +| 0 N/A N/A 296825 C /usr/bin/python3 614MiB | +| 0 N/A N/A 296826 C /usr/bin/python3 614MiB | +| 1 N/A N/A 296820 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 296821 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 296822 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 296823 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 296824 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 296825 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 296826 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.04ms +step:1/1695 train_time:157ms step_avg:156.91ms +step:2/1695 train_time:187ms step_avg:93.26ms +step:3/1695 train_time:254ms step_avg:84.82ms +step:4/1695 train_time:346ms step_avg:86.61ms +step:5/1695 train_time:439ms step_avg:87.84ms +step:6/1695 train_time:532ms step_avg:88.67ms +step:7/1695 train_time:626ms step_avg:89.39ms +step:8/1695 
train_time:719ms step_avg:89.89ms +step:9/1695 train_time:813ms step_avg:90.31ms +step:10/1695 train_time:906ms step_avg:90.55ms +step:11/1695 train_time:998ms step_avg:90.75ms +step:12/1695 train_time:1092ms step_avg:91.04ms +step:13/1695 train_time:1190ms step_avg:91.51ms +step:14/1695 train_time:1285ms step_avg:91.79ms +step:15/1695 train_time:1379ms step_avg:91.96ms +step:16/1695 train_time:1473ms step_avg:92.03ms +step:17/1695 train_time:1566ms step_avg:92.15ms +step:18/1695 train_time:1659ms step_avg:92.19ms +step:19/1695 train_time:1753ms step_avg:92.25ms +step:20/1695 train_time:1846ms step_avg:92.32ms +step:21/1695 train_time:1940ms step_avg:92.38ms +step:22/1695 train_time:2034ms step_avg:92.44ms +step:23/1695 train_time:2128ms step_avg:92.54ms +step:24/1695 train_time:2223ms step_avg:92.62ms +step:25/1695 train_time:2317ms step_avg:92.70ms +step:26/1695 train_time:2411ms step_avg:92.75ms +step:27/1695 train_time:2505ms step_avg:92.80ms +step:28/1695 train_time:2599ms step_avg:92.83ms +step:29/1695 train_time:2693ms step_avg:92.86ms +step:30/1695 train_time:2786ms step_avg:92.88ms +step:31/1695 train_time:2880ms step_avg:92.92ms +step:32/1695 train_time:2974ms step_avg:92.93ms +step:33/1695 train_time:3068ms step_avg:92.96ms +step:34/1695 train_time:3162ms step_avg:92.99ms +step:35/1695 train_time:3255ms step_avg:93.00ms +step:36/1695 train_time:3349ms step_avg:93.02ms +step:37/1695 train_time:3443ms step_avg:93.05ms +step:38/1695 train_time:3537ms step_avg:93.08ms +step:39/1695 train_time:3631ms step_avg:93.10ms +step:40/1695 train_time:3725ms step_avg:93.13ms +step:41/1695 train_time:3819ms step_avg:93.15ms +step:42/1695 train_time:3913ms step_avg:93.16ms +step:43/1695 train_time:4007ms step_avg:93.19ms +step:44/1695 train_time:4101ms step_avg:93.20ms +step:45/1695 train_time:4194ms step_avg:93.20ms +step:46/1695 train_time:4288ms step_avg:93.22ms +step:47/1695 train_time:4382ms step_avg:93.23ms +step:48/1695 train_time:4476ms step_avg:93.24ms +step:49/1695 train_time:4571ms step_avg:93.28ms +step:50/1695 train_time:4665ms step_avg:93.29ms +step:51/1695 train_time:4759ms step_avg:93.30ms +step:52/1695 train_time:4852ms step_avg:93.30ms +step:53/1695 train_time:4946ms step_avg:93.32ms +step:54/1695 train_time:5041ms step_avg:93.36ms +step:55/1695 train_time:5135ms step_avg:93.36ms +step:56/1695 train_time:5229ms step_avg:93.37ms +step:57/1695 train_time:5322ms step_avg:93.37ms +step:58/1695 train_time:5415ms step_avg:93.36ms +step:59/1695 train_time:5508ms step_avg:93.36ms +step:60/1695 train_time:5602ms step_avg:93.37ms +step:61/1695 train_time:5695ms step_avg:93.37ms +step:62/1695 train_time:5789ms step_avg:93.38ms +step:63/1695 train_time:5883ms step_avg:93.38ms +step:64/1695 train_time:5977ms step_avg:93.39ms +step:65/1695 train_time:6070ms step_avg:93.39ms +step:66/1695 train_time:6164ms step_avg:93.39ms +step:67/1695 train_time:6257ms step_avg:93.39ms +step:68/1695 train_time:6352ms step_avg:93.41ms +step:69/1695 train_time:6447ms step_avg:93.43ms +step:70/1695 train_time:6541ms step_avg:93.45ms +step:71/1695 train_time:6635ms step_avg:93.45ms +step:72/1695 train_time:6730ms step_avg:93.47ms +step:73/1695 train_time:6824ms step_avg:93.48ms +step:74/1695 train_time:6918ms step_avg:93.49ms +step:75/1695 train_time:7012ms step_avg:93.49ms +step:76/1695 train_time:7106ms step_avg:93.49ms +step:77/1695 train_time:7198ms step_avg:93.49ms +step:78/1695 train_time:7292ms step_avg:93.49ms +step:79/1695 train_time:7386ms step_avg:93.49ms +step:80/1695 train_time:7480ms 
step_avg:93.49ms +step:81/1695 train_time:7573ms step_avg:93.49ms +step:82/1695 train_time:7666ms step_avg:93.49ms +step:83/1695 train_time:7760ms step_avg:93.49ms +step:84/1695 train_time:7854ms step_avg:93.50ms +step:85/1695 train_time:7948ms step_avg:93.50ms +step:86/1695 train_time:8041ms step_avg:93.50ms +step:87/1695 train_time:8135ms step_avg:93.51ms +step:88/1695 train_time:8228ms step_avg:93.51ms +step:89/1695 train_time:8322ms step_avg:93.51ms +step:90/1695 train_time:8416ms step_avg:93.51ms +step:91/1695 train_time:8510ms step_avg:93.51ms +step:92/1695 train_time:8603ms step_avg:93.51ms +step:93/1695 train_time:8696ms step_avg:93.51ms +step:94/1695 train_time:8790ms step_avg:93.51ms +step:95/1695 train_time:8884ms step_avg:93.51ms +step:96/1695 train_time:8977ms step_avg:93.51ms +step:97/1695 train_time:9071ms step_avg:93.52ms +step:98/1695 train_time:9166ms step_avg:93.53ms +step:99/1695 train_time:9259ms step_avg:93.52ms +step:100/1695 train_time:9353ms step_avg:93.53ms +step:101/1695 train_time:9447ms step_avg:93.54ms +step:102/1695 train_time:9541ms step_avg:93.54ms +step:103/1695 train_time:9635ms step_avg:93.54ms +step:104/1695 train_time:9729ms step_avg:93.55ms +step:105/1695 train_time:9823ms step_avg:93.55ms +step:106/1695 train_time:9916ms step_avg:93.55ms +step:107/1695 train_time:10011ms step_avg:93.56ms +step:108/1695 train_time:10105ms step_avg:93.56ms +step:109/1695 train_time:10198ms step_avg:93.56ms +step:110/1695 train_time:10292ms step_avg:93.56ms +step:111/1695 train_time:10386ms step_avg:93.57ms +step:112/1695 train_time:10479ms step_avg:93.57ms +step:113/1695 train_time:10573ms step_avg:93.57ms +step:114/1695 train_time:10667ms step_avg:93.57ms +step:115/1695 train_time:10762ms step_avg:93.58ms +step:116/1695 train_time:10855ms step_avg:93.58ms +step:117/1695 train_time:10950ms step_avg:93.59ms +step:118/1695 train_time:11045ms step_avg:93.60ms +step:119/1695 train_time:11139ms step_avg:93.60ms +step:120/1695 train_time:11231ms step_avg:93.59ms +step:121/1695 train_time:11324ms step_avg:93.58ms +step:122/1695 train_time:11417ms step_avg:93.58ms +step:123/1695 train_time:11512ms step_avg:93.59ms +step:124/1695 train_time:11605ms step_avg:93.59ms +step:125/1695 train_time:11698ms step_avg:93.58ms +step:125/1695 val_loss:4.6063 train_time:11789ms step_avg:94.31ms +step:126/1695 train_time:11816ms step_avg:93.78ms +step:127/1695 train_time:11894ms step_avg:93.65ms +step:128/1695 train_time:11996ms step_avg:93.72ms +step:129/1695 train_time:12091ms step_avg:93.73ms +step:130/1695 train_time:12186ms step_avg:93.74ms +step:131/1695 train_time:12279ms step_avg:93.73ms +step:132/1695 train_time:12372ms step_avg:93.73ms +step:133/1695 train_time:12466ms step_avg:93.73ms +step:134/1695 train_time:12559ms step_avg:93.73ms +step:135/1695 train_time:12653ms step_avg:93.72ms +step:136/1695 train_time:12747ms step_avg:93.73ms +step:137/1695 train_time:12840ms step_avg:93.72ms +step:138/1695 train_time:12935ms step_avg:93.74ms +step:139/1695 train_time:13031ms step_avg:93.75ms +step:140/1695 train_time:13127ms step_avg:93.76ms +step:141/1695 train_time:13221ms step_avg:93.77ms +step:142/1695 train_time:13315ms step_avg:93.77ms +step:143/1695 train_time:13409ms step_avg:93.77ms +step:144/1695 train_time:13502ms step_avg:93.77ms +step:145/1695 train_time:13596ms step_avg:93.76ms +step:146/1695 train_time:13689ms step_avg:93.76ms +step:147/1695 train_time:13785ms step_avg:93.78ms +step:148/1695 train_time:13877ms step_avg:93.76ms +step:149/1695 train_time:13972ms 
step_avg:93.77ms +step:150/1695 train_time:14068ms step_avg:93.78ms +step:151/1695 train_time:14164ms step_avg:93.80ms +step:152/1695 train_time:14258ms step_avg:93.80ms +step:153/1695 train_time:14353ms step_avg:93.81ms +step:154/1695 train_time:14446ms step_avg:93.80ms +step:155/1695 train_time:14539ms step_avg:93.80ms +step:156/1695 train_time:14633ms step_avg:93.80ms +step:157/1695 train_time:14726ms step_avg:93.80ms +step:158/1695 train_time:14821ms step_avg:93.80ms +step:159/1695 train_time:14915ms step_avg:93.80ms +step:160/1695 train_time:15009ms step_avg:93.81ms +step:161/1695 train_time:15105ms step_avg:93.82ms +step:162/1695 train_time:15199ms step_avg:93.82ms +step:163/1695 train_time:15294ms step_avg:93.83ms +step:164/1695 train_time:15389ms step_avg:93.84ms +step:165/1695 train_time:15483ms step_avg:93.84ms +step:166/1695 train_time:15577ms step_avg:93.84ms +step:167/1695 train_time:15671ms step_avg:93.84ms +step:168/1695 train_time:15766ms step_avg:93.85ms +step:169/1695 train_time:15860ms step_avg:93.85ms +step:170/1695 train_time:15954ms step_avg:93.85ms +step:171/1695 train_time:16047ms step_avg:93.84ms +step:172/1695 train_time:16142ms step_avg:93.85ms +step:173/1695 train_time:16236ms step_avg:93.85ms +step:174/1695 train_time:16331ms step_avg:93.85ms +step:175/1695 train_time:16425ms step_avg:93.86ms +step:176/1695 train_time:16519ms step_avg:93.86ms +step:177/1695 train_time:16614ms step_avg:93.86ms +step:178/1695 train_time:16709ms step_avg:93.87ms +step:179/1695 train_time:16803ms step_avg:93.87ms +step:180/1695 train_time:16897ms step_avg:93.87ms +step:181/1695 train_time:16991ms step_avg:93.87ms +step:182/1695 train_time:17085ms step_avg:93.88ms +step:183/1695 train_time:17179ms step_avg:93.88ms +step:184/1695 train_time:17273ms step_avg:93.88ms +step:185/1695 train_time:17368ms step_avg:93.88ms +step:186/1695 train_time:17462ms step_avg:93.88ms +step:187/1695 train_time:17557ms step_avg:93.89ms +step:188/1695 train_time:17651ms step_avg:93.89ms +step:189/1695 train_time:17746ms step_avg:93.89ms +step:190/1695 train_time:17841ms step_avg:93.90ms +step:191/1695 train_time:17934ms step_avg:93.90ms +step:192/1695 train_time:18028ms step_avg:93.90ms +step:193/1695 train_time:18123ms step_avg:93.90ms +step:194/1695 train_time:18217ms step_avg:93.90ms +step:195/1695 train_time:18311ms step_avg:93.90ms +step:196/1695 train_time:18406ms step_avg:93.91ms +step:197/1695 train_time:18499ms step_avg:93.91ms +step:198/1695 train_time:18593ms step_avg:93.91ms +step:199/1695 train_time:18688ms step_avg:93.91ms +step:200/1695 train_time:18782ms step_avg:93.91ms +step:201/1695 train_time:18876ms step_avg:93.91ms +step:202/1695 train_time:18970ms step_avg:93.91ms +step:203/1695 train_time:19065ms step_avg:93.92ms +step:204/1695 train_time:19159ms step_avg:93.92ms +step:205/1695 train_time:19253ms step_avg:93.92ms +step:206/1695 train_time:19348ms step_avg:93.92ms +step:207/1695 train_time:19442ms step_avg:93.92ms +step:208/1695 train_time:19536ms step_avg:93.92ms +step:209/1695 train_time:19631ms step_avg:93.93ms +step:210/1695 train_time:19726ms step_avg:93.93ms +step:211/1695 train_time:19820ms step_avg:93.93ms +step:212/1695 train_time:19913ms step_avg:93.93ms +step:213/1695 train_time:20007ms step_avg:93.93ms +step:214/1695 train_time:20102ms step_avg:93.93ms +step:215/1695 train_time:20195ms step_avg:93.93ms +step:216/1695 train_time:20289ms step_avg:93.93ms +step:217/1695 train_time:20383ms step_avg:93.93ms +step:218/1695 train_time:20478ms step_avg:93.93ms +step:219/1695 
train_time:20572ms step_avg:93.93ms +step:220/1695 train_time:20667ms step_avg:93.94ms +step:221/1695 train_time:20760ms step_avg:93.94ms +step:222/1695 train_time:20854ms step_avg:93.94ms +step:223/1695 train_time:20948ms step_avg:93.94ms +step:224/1695 train_time:21042ms step_avg:93.94ms +step:225/1695 train_time:21136ms step_avg:93.94ms +step:226/1695 train_time:21230ms step_avg:93.94ms +step:227/1695 train_time:21324ms step_avg:93.94ms +step:228/1695 train_time:21418ms step_avg:93.94ms +step:229/1695 train_time:21513ms step_avg:93.94ms +step:230/1695 train_time:21607ms step_avg:93.94ms +step:231/1695 train_time:21702ms step_avg:93.95ms +step:232/1695 train_time:21795ms step_avg:93.94ms +step:233/1695 train_time:21889ms step_avg:93.95ms +step:234/1695 train_time:21983ms step_avg:93.95ms +step:235/1695 train_time:22077ms step_avg:93.94ms +step:236/1695 train_time:22172ms step_avg:93.95ms +step:237/1695 train_time:22265ms step_avg:93.95ms +step:238/1695 train_time:22359ms step_avg:93.95ms +step:239/1695 train_time:22454ms step_avg:93.95ms +step:240/1695 train_time:22548ms step_avg:93.95ms +step:241/1695 train_time:22642ms step_avg:93.95ms +step:242/1695 train_time:22736ms step_avg:93.95ms +step:243/1695 train_time:22830ms step_avg:93.95ms +step:244/1695 train_time:22925ms step_avg:93.95ms +step:245/1695 train_time:23018ms step_avg:93.95ms +step:246/1695 train_time:23113ms step_avg:93.95ms +step:247/1695 train_time:23207ms step_avg:93.95ms +step:248/1695 train_time:23301ms step_avg:93.96ms +step:249/1695 train_time:23395ms step_avg:93.96ms +step:250/1695 train_time:23489ms step_avg:93.96ms +step:250/1695 val_loss:4.0781 train_time:23582ms step_avg:94.33ms +step:251/1695 train_time:23610ms step_avg:94.06ms +step:252/1695 train_time:23686ms step_avg:93.99ms +step:253/1695 train_time:23785ms step_avg:94.01ms +step:254/1695 train_time:23880ms step_avg:94.02ms +step:255/1695 train_time:23975ms step_avg:94.02ms +step:256/1695 train_time:24069ms step_avg:94.02ms +step:257/1695 train_time:24162ms step_avg:94.02ms +step:258/1695 train_time:24257ms step_avg:94.02ms +step:259/1695 train_time:24350ms step_avg:94.02ms +step:260/1695 train_time:24444ms step_avg:94.01ms +step:261/1695 train_time:24540ms step_avg:94.02ms +step:262/1695 train_time:24636ms step_avg:94.03ms +step:263/1695 train_time:24731ms step_avg:94.03ms +step:264/1695 train_time:24826ms step_avg:94.04ms +step:265/1695 train_time:24921ms step_avg:94.04ms +step:266/1695 train_time:25016ms step_avg:94.04ms +step:267/1695 train_time:25110ms step_avg:94.05ms +step:268/1695 train_time:25204ms step_avg:94.05ms +step:269/1695 train_time:25298ms step_avg:94.05ms +step:270/1695 train_time:25393ms step_avg:94.05ms +step:271/1695 train_time:25486ms step_avg:94.05ms +step:272/1695 train_time:25582ms step_avg:94.05ms +step:273/1695 train_time:25678ms step_avg:94.06ms +step:274/1695 train_time:25774ms step_avg:94.07ms +step:275/1695 train_time:25868ms step_avg:94.07ms +step:276/1695 train_time:25963ms step_avg:94.07ms +step:277/1695 train_time:26058ms step_avg:94.07ms +step:278/1695 train_time:26153ms step_avg:94.08ms +step:279/1695 train_time:26247ms step_avg:94.08ms +step:280/1695 train_time:26341ms step_avg:94.07ms +step:281/1695 train_time:26435ms step_avg:94.07ms +step:282/1695 train_time:26529ms step_avg:94.07ms +step:283/1695 train_time:26625ms step_avg:94.08ms +step:284/1695 train_time:26719ms step_avg:94.08ms +step:285/1695 train_time:26814ms step_avg:94.09ms +step:286/1695 train_time:26908ms step_avg:94.08ms +step:287/1695 train_time:27002ms 
step_avg:94.08ms +step:288/1695 train_time:27098ms step_avg:94.09ms +step:289/1695 train_time:27193ms step_avg:94.09ms +step:290/1695 train_time:27287ms step_avg:94.09ms +step:291/1695 train_time:27381ms step_avg:94.09ms +step:292/1695 train_time:27476ms step_avg:94.10ms +step:293/1695 train_time:27570ms step_avg:94.10ms +step:294/1695 train_time:27665ms step_avg:94.10ms +step:295/1695 train_time:27760ms step_avg:94.10ms +step:296/1695 train_time:27854ms step_avg:94.10ms +step:297/1695 train_time:27948ms step_avg:94.10ms +step:298/1695 train_time:28043ms step_avg:94.10ms +step:299/1695 train_time:28138ms step_avg:94.11ms +step:300/1695 train_time:28233ms step_avg:94.11ms +step:301/1695 train_time:28326ms step_avg:94.11ms +step:302/1695 train_time:28421ms step_avg:94.11ms +step:303/1695 train_time:28516ms step_avg:94.11ms +step:304/1695 train_time:28610ms step_avg:94.11ms +step:305/1695 train_time:28704ms step_avg:94.11ms +step:306/1695 train_time:28800ms step_avg:94.12ms +step:307/1695 train_time:28894ms step_avg:94.12ms +step:308/1695 train_time:28989ms step_avg:94.12ms +step:309/1695 train_time:29084ms step_avg:94.12ms +step:310/1695 train_time:29179ms step_avg:94.12ms +step:311/1695 train_time:29274ms step_avg:94.13ms +step:312/1695 train_time:29368ms step_avg:94.13ms +step:313/1695 train_time:29463ms step_avg:94.13ms +step:314/1695 train_time:29558ms step_avg:94.13ms +step:315/1695 train_time:29653ms step_avg:94.14ms +step:316/1695 train_time:29747ms step_avg:94.14ms +step:317/1695 train_time:29842ms step_avg:94.14ms +step:318/1695 train_time:29937ms step_avg:94.14ms +step:319/1695 train_time:30031ms step_avg:94.14ms +step:320/1695 train_time:30125ms step_avg:94.14ms +step:321/1695 train_time:30220ms step_avg:94.14ms +step:322/1695 train_time:30315ms step_avg:94.15ms +step:323/1695 train_time:30409ms step_avg:94.15ms +step:324/1695 train_time:30504ms step_avg:94.15ms +step:325/1695 train_time:30598ms step_avg:94.15ms +step:326/1695 train_time:30694ms step_avg:94.15ms +step:327/1695 train_time:30788ms step_avg:94.15ms +step:328/1695 train_time:30882ms step_avg:94.15ms +step:329/1695 train_time:30978ms step_avg:94.16ms +step:330/1695 train_time:31072ms step_avg:94.16ms +step:331/1695 train_time:31166ms step_avg:94.16ms +step:332/1695 train_time:31261ms step_avg:94.16ms +step:333/1695 train_time:31356ms step_avg:94.16ms +step:334/1695 train_time:31450ms step_avg:94.16ms +step:335/1695 train_time:31544ms step_avg:94.16ms +step:336/1695 train_time:31638ms step_avg:94.16ms +step:337/1695 train_time:31734ms step_avg:94.17ms +step:338/1695 train_time:31828ms step_avg:94.16ms +step:339/1695 train_time:31922ms step_avg:94.17ms +step:340/1695 train_time:32017ms step_avg:94.17ms +step:341/1695 train_time:32113ms step_avg:94.17ms +step:342/1695 train_time:32206ms step_avg:94.17ms +step:343/1695 train_time:32301ms step_avg:94.17ms +step:344/1695 train_time:32396ms step_avg:94.18ms +step:345/1695 train_time:32490ms step_avg:94.18ms +step:346/1695 train_time:32585ms step_avg:94.18ms +step:347/1695 train_time:32680ms step_avg:94.18ms +step:348/1695 train_time:32775ms step_avg:94.18ms +step:349/1695 train_time:32869ms step_avg:94.18ms +step:350/1695 train_time:32964ms step_avg:94.18ms +step:351/1695 train_time:33059ms step_avg:94.18ms +step:352/1695 train_time:33154ms step_avg:94.19ms +step:353/1695 train_time:33248ms step_avg:94.19ms +step:354/1695 train_time:33342ms step_avg:94.19ms +step:355/1695 train_time:33436ms step_avg:94.19ms +step:356/1695 train_time:33530ms step_avg:94.19ms +step:357/1695 
train_time:33625ms step_avg:94.19ms +step:358/1695 train_time:33719ms step_avg:94.19ms +step:359/1695 train_time:33814ms step_avg:94.19ms +step:360/1695 train_time:33907ms step_avg:94.19ms +step:361/1695 train_time:34002ms step_avg:94.19ms +step:362/1695 train_time:34097ms step_avg:94.19ms +step:363/1695 train_time:34192ms step_avg:94.19ms +step:364/1695 train_time:34286ms step_avg:94.19ms +step:365/1695 train_time:34380ms step_avg:94.19ms +step:366/1695 train_time:34474ms step_avg:94.19ms +step:367/1695 train_time:34568ms step_avg:94.19ms +step:368/1695 train_time:34664ms step_avg:94.19ms +step:369/1695 train_time:34759ms step_avg:94.20ms +step:370/1695 train_time:34853ms step_avg:94.20ms +step:371/1695 train_time:34948ms step_avg:94.20ms +step:372/1695 train_time:35042ms step_avg:94.20ms +step:373/1695 train_time:35137ms step_avg:94.20ms +step:374/1695 train_time:35232ms step_avg:94.20ms +step:375/1695 train_time:35326ms step_avg:94.20ms +step:375/1695 val_loss:3.8822 train_time:35417ms step_avg:94.45ms +step:376/1695 train_time:35446ms step_avg:94.27ms +step:377/1695 train_time:35524ms step_avg:94.23ms +step:378/1695 train_time:35623ms step_avg:94.24ms +step:379/1695 train_time:35721ms step_avg:94.25ms +step:380/1695 train_time:35817ms step_avg:94.26ms +step:381/1695 train_time:35913ms step_avg:94.26ms +step:382/1695 train_time:36007ms step_avg:94.26ms +step:383/1695 train_time:36102ms step_avg:94.26ms +step:384/1695 train_time:36198ms step_avg:94.26ms +step:385/1695 train_time:36293ms step_avg:94.27ms +step:386/1695 train_time:36389ms step_avg:94.27ms +step:387/1695 train_time:36486ms step_avg:94.28ms +step:388/1695 train_time:36584ms step_avg:94.29ms +step:389/1695 train_time:36682ms step_avg:94.30ms +step:390/1695 train_time:36780ms step_avg:94.31ms +step:391/1695 train_time:36877ms step_avg:94.32ms +step:392/1695 train_time:36973ms step_avg:94.32ms +step:393/1695 train_time:37069ms step_avg:94.32ms +step:394/1695 train_time:37164ms step_avg:94.33ms +step:395/1695 train_time:37261ms step_avg:94.33ms +step:396/1695 train_time:37356ms step_avg:94.33ms +step:397/1695 train_time:37452ms step_avg:94.34ms +step:398/1695 train_time:37548ms step_avg:94.34ms +step:399/1695 train_time:37644ms step_avg:94.35ms +step:400/1695 train_time:37742ms step_avg:94.35ms +step:401/1695 train_time:37837ms step_avg:94.36ms +step:402/1695 train_time:37934ms step_avg:94.36ms +step:403/1695 train_time:38030ms step_avg:94.37ms +step:404/1695 train_time:38125ms step_avg:94.37ms +step:405/1695 train_time:38221ms step_avg:94.37ms +step:406/1695 train_time:38317ms step_avg:94.38ms +step:407/1695 train_time:38412ms step_avg:94.38ms +step:408/1695 train_time:38509ms step_avg:94.38ms +step:409/1695 train_time:38605ms step_avg:94.39ms +step:410/1695 train_time:38702ms step_avg:94.39ms +step:411/1695 train_time:38799ms step_avg:94.40ms +step:412/1695 train_time:38895ms step_avg:94.41ms +step:413/1695 train_time:38992ms step_avg:94.41ms +step:414/1695 train_time:39088ms step_avg:94.41ms +step:415/1695 train_time:39183ms step_avg:94.42ms +step:416/1695 train_time:39279ms step_avg:94.42ms +step:417/1695 train_time:39377ms step_avg:94.43ms +step:418/1695 train_time:39472ms step_avg:94.43ms +step:419/1695 train_time:39568ms step_avg:94.43ms +step:420/1695 train_time:39664ms step_avg:94.44ms +step:421/1695 train_time:39760ms step_avg:94.44ms +step:422/1695 train_time:39857ms step_avg:94.45ms +step:423/1695 train_time:39954ms step_avg:94.45ms +step:424/1695 train_time:40051ms step_avg:94.46ms +step:425/1695 train_time:40147ms 
step_avg:94.46ms +step:426/1695 train_time:40242ms step_avg:94.47ms +step:427/1695 train_time:40338ms step_avg:94.47ms +step:428/1695 train_time:40434ms step_avg:94.47ms +step:429/1695 train_time:40531ms step_avg:94.48ms +step:430/1695 train_time:40627ms step_avg:94.48ms +step:431/1695 train_time:40722ms step_avg:94.48ms +step:432/1695 train_time:40819ms step_avg:94.49ms +step:433/1695 train_time:40916ms step_avg:94.49ms +step:434/1695 train_time:41013ms step_avg:94.50ms +step:435/1695 train_time:41109ms step_avg:94.50ms +step:436/1695 train_time:41204ms step_avg:94.51ms +step:437/1695 train_time:41301ms step_avg:94.51ms +step:438/1695 train_time:41397ms step_avg:94.51ms +step:439/1695 train_time:41493ms step_avg:94.52ms +step:440/1695 train_time:41589ms step_avg:94.52ms +step:441/1695 train_time:41684ms step_avg:94.52ms +step:442/1695 train_time:41781ms step_avg:94.53ms +step:443/1695 train_time:41878ms step_avg:94.53ms +step:444/1695 train_time:41975ms step_avg:94.54ms +step:445/1695 train_time:42071ms step_avg:94.54ms +step:446/1695 train_time:42167ms step_avg:94.55ms +step:447/1695 train_time:42263ms step_avg:94.55ms +step:448/1695 train_time:42359ms step_avg:94.55ms +step:449/1695 train_time:42456ms step_avg:94.56ms +step:450/1695 train_time:42552ms step_avg:94.56ms +step:451/1695 train_time:42648ms step_avg:94.56ms +step:452/1695 train_time:42744ms step_avg:94.57ms +step:453/1695 train_time:42840ms step_avg:94.57ms +step:454/1695 train_time:42937ms step_avg:94.57ms +step:455/1695 train_time:43033ms step_avg:94.58ms +step:456/1695 train_time:43129ms step_avg:94.58ms +step:457/1695 train_time:43225ms step_avg:94.58ms +step:458/1695 train_time:43321ms step_avg:94.59ms +step:459/1695 train_time:43418ms step_avg:94.59ms +step:460/1695 train_time:43514ms step_avg:94.60ms +step:461/1695 train_time:43610ms step_avg:94.60ms +step:462/1695 train_time:43707ms step_avg:94.60ms +step:463/1695 train_time:43802ms step_avg:94.60ms +step:464/1695 train_time:43899ms step_avg:94.61ms +step:465/1695 train_time:43995ms step_avg:94.61ms +step:466/1695 train_time:44091ms step_avg:94.62ms +step:467/1695 train_time:44187ms step_avg:94.62ms +step:468/1695 train_time:44283ms step_avg:94.62ms +step:469/1695 train_time:44379ms step_avg:94.63ms +step:470/1695 train_time:44477ms step_avg:94.63ms +step:471/1695 train_time:44573ms step_avg:94.64ms +step:472/1695 train_time:44669ms step_avg:94.64ms +step:473/1695 train_time:44765ms step_avg:94.64ms +step:474/1695 train_time:44862ms step_avg:94.64ms +step:475/1695 train_time:44958ms step_avg:94.65ms +step:476/1695 train_time:45054ms step_avg:94.65ms +step:477/1695 train_time:45151ms step_avg:94.66ms +step:478/1695 train_time:45246ms step_avg:94.66ms +step:479/1695 train_time:45342ms step_avg:94.66ms +step:480/1695 train_time:45439ms step_avg:94.66ms +step:481/1695 train_time:45535ms step_avg:94.67ms +step:482/1695 train_time:45632ms step_avg:94.67ms +step:483/1695 train_time:45727ms step_avg:94.67ms +step:484/1695 train_time:45824ms step_avg:94.68ms +step:485/1695 train_time:45920ms step_avg:94.68ms +step:486/1695 train_time:46017ms step_avg:94.68ms +step:487/1695 train_time:46113ms step_avg:94.69ms +step:488/1695 train_time:46209ms step_avg:94.69ms +step:489/1695 train_time:46305ms step_avg:94.69ms +step:490/1695 train_time:46401ms step_avg:94.70ms +step:491/1695 train_time:46498ms step_avg:94.70ms +step:492/1695 train_time:46594ms step_avg:94.70ms +step:493/1695 train_time:46690ms step_avg:94.71ms +step:494/1695 train_time:46787ms step_avg:94.71ms +step:495/1695 
train_time:46883ms step_avg:94.71ms +step:496/1695 train_time:46979ms step_avg:94.72ms +step:497/1695 train_time:47077ms step_avg:94.72ms +step:498/1695 train_time:47173ms step_avg:94.72ms +step:499/1695 train_time:47270ms step_avg:94.73ms +step:500/1695 train_time:47366ms step_avg:94.73ms +step:500/1695 val_loss:3.7326 train_time:47459ms step_avg:94.92ms +step:501/1695 train_time:47487ms step_avg:94.78ms +step:502/1695 train_time:47571ms step_avg:94.76ms +step:503/1695 train_time:47672ms step_avg:94.78ms +step:504/1695 train_time:47768ms step_avg:94.78ms +step:505/1695 train_time:47863ms step_avg:94.78ms +step:506/1695 train_time:47959ms step_avg:94.78ms +step:507/1695 train_time:48055ms step_avg:94.78ms +step:508/1695 train_time:48151ms step_avg:94.78ms +step:509/1695 train_time:48246ms step_avg:94.79ms +step:510/1695 train_time:48342ms step_avg:94.79ms +step:511/1695 train_time:48438ms step_avg:94.79ms +step:512/1695 train_time:48537ms step_avg:94.80ms +step:513/1695 train_time:48635ms step_avg:94.80ms +step:514/1695 train_time:48732ms step_avg:94.81ms +step:515/1695 train_time:48829ms step_avg:94.81ms +step:516/1695 train_time:48925ms step_avg:94.82ms +step:517/1695 train_time:49021ms step_avg:94.82ms +step:518/1695 train_time:49117ms step_avg:94.82ms +step:519/1695 train_time:49213ms step_avg:94.82ms +step:520/1695 train_time:49309ms step_avg:94.82ms +step:521/1695 train_time:49404ms step_avg:94.83ms +step:522/1695 train_time:49501ms step_avg:94.83ms +step:523/1695 train_time:49598ms step_avg:94.83ms +step:524/1695 train_time:49696ms step_avg:94.84ms +step:525/1695 train_time:49794ms step_avg:94.85ms +step:526/1695 train_time:49890ms step_avg:94.85ms +step:527/1695 train_time:49986ms step_avg:94.85ms +step:528/1695 train_time:50082ms step_avg:94.85ms +step:529/1695 train_time:50179ms step_avg:94.86ms +step:530/1695 train_time:50275ms step_avg:94.86ms +step:531/1695 train_time:50371ms step_avg:94.86ms +step:532/1695 train_time:50467ms step_avg:94.86ms +step:533/1695 train_time:50564ms step_avg:94.87ms +step:534/1695 train_time:50661ms step_avg:94.87ms +step:535/1695 train_time:50760ms step_avg:94.88ms +step:536/1695 train_time:50858ms step_avg:94.88ms +step:537/1695 train_time:50954ms step_avg:94.89ms +step:538/1695 train_time:51050ms step_avg:94.89ms +step:539/1695 train_time:51146ms step_avg:94.89ms +step:540/1695 train_time:51243ms step_avg:94.90ms +step:541/1695 train_time:51340ms step_avg:94.90ms +step:542/1695 train_time:51436ms step_avg:94.90ms +step:543/1695 train_time:51533ms step_avg:94.90ms +step:544/1695 train_time:51629ms step_avg:94.91ms +step:545/1695 train_time:51725ms step_avg:94.91ms +step:546/1695 train_time:51822ms step_avg:94.91ms +step:547/1695 train_time:51920ms step_avg:94.92ms +step:548/1695 train_time:52017ms step_avg:94.92ms +step:549/1695 train_time:52113ms step_avg:94.92ms +step:550/1695 train_time:52210ms step_avg:94.93ms +step:551/1695 train_time:52306ms step_avg:94.93ms +step:552/1695 train_time:52402ms step_avg:94.93ms +step:553/1695 train_time:52499ms step_avg:94.94ms +step:554/1695 train_time:52596ms step_avg:94.94ms +step:555/1695 train_time:52693ms step_avg:94.94ms +step:556/1695 train_time:52789ms step_avg:94.94ms +step:557/1695 train_time:52887ms step_avg:94.95ms +step:558/1695 train_time:52985ms step_avg:94.95ms +step:559/1695 train_time:53083ms step_avg:94.96ms +step:560/1695 train_time:53180ms step_avg:94.96ms +step:561/1695 train_time:53277ms step_avg:94.97ms +step:562/1695 train_time:53373ms step_avg:94.97ms +step:563/1695 train_time:53469ms 
step_avg:94.97ms +step:564/1695 train_time:53565ms step_avg:94.97ms +step:565/1695 train_time:53662ms step_avg:94.98ms +step:566/1695 train_time:53758ms step_avg:94.98ms +step:567/1695 train_time:53855ms step_avg:94.98ms +step:568/1695 train_time:53951ms step_avg:94.98ms +step:569/1695 train_time:54047ms step_avg:94.99ms +step:570/1695 train_time:54143ms step_avg:94.99ms +step:571/1695 train_time:54240ms step_avg:94.99ms +step:572/1695 train_time:54337ms step_avg:94.99ms +step:573/1695 train_time:54433ms step_avg:95.00ms +step:574/1695 train_time:54528ms step_avg:95.00ms +step:575/1695 train_time:54624ms step_avg:95.00ms +step:576/1695 train_time:54721ms step_avg:95.00ms +step:577/1695 train_time:54818ms step_avg:95.01ms +step:578/1695 train_time:54915ms step_avg:95.01ms +step:579/1695 train_time:55011ms step_avg:95.01ms +step:580/1695 train_time:55107ms step_avg:95.01ms +step:581/1695 train_time:55204ms step_avg:95.02ms +step:582/1695 train_time:55300ms step_avg:95.02ms +step:583/1695 train_time:55397ms step_avg:95.02ms +step:584/1695 train_time:55495ms step_avg:95.03ms +step:585/1695 train_time:55591ms step_avg:95.03ms +step:586/1695 train_time:55687ms step_avg:95.03ms +step:587/1695 train_time:55784ms step_avg:95.03ms +step:588/1695 train_time:55882ms step_avg:95.04ms +step:589/1695 train_time:55980ms step_avg:95.04ms +step:590/1695 train_time:56076ms step_avg:95.04ms +step:591/1695 train_time:56173ms step_avg:95.05ms +step:592/1695 train_time:56269ms step_avg:95.05ms +step:593/1695 train_time:56366ms step_avg:95.05ms +step:594/1695 train_time:56463ms step_avg:95.06ms +step:595/1695 train_time:56560ms step_avg:95.06ms +step:596/1695 train_time:56656ms step_avg:95.06ms +step:597/1695 train_time:56752ms step_avg:95.06ms +step:598/1695 train_time:56850ms step_avg:95.07ms +step:599/1695 train_time:56945ms step_avg:95.07ms +step:600/1695 train_time:57042ms step_avg:95.07ms +step:601/1695 train_time:57139ms step_avg:95.07ms +step:602/1695 train_time:57236ms step_avg:95.08ms +step:603/1695 train_time:57331ms step_avg:95.08ms +step:604/1695 train_time:57427ms step_avg:95.08ms +step:605/1695 train_time:57523ms step_avg:95.08ms +step:606/1695 train_time:57619ms step_avg:95.08ms +step:607/1695 train_time:57717ms step_avg:95.09ms +step:608/1695 train_time:57813ms step_avg:95.09ms +step:609/1695 train_time:57909ms step_avg:95.09ms +step:610/1695 train_time:58005ms step_avg:95.09ms +step:611/1695 train_time:58102ms step_avg:95.09ms +step:612/1695 train_time:58200ms step_avg:95.10ms +step:613/1695 train_time:58298ms step_avg:95.10ms +step:614/1695 train_time:58394ms step_avg:95.11ms +step:615/1695 train_time:58490ms step_avg:95.11ms +step:616/1695 train_time:58586ms step_avg:95.11ms +step:617/1695 train_time:58684ms step_avg:95.11ms +step:618/1695 train_time:58781ms step_avg:95.12ms +step:619/1695 train_time:58878ms step_avg:95.12ms +step:620/1695 train_time:58974ms step_avg:95.12ms +step:621/1695 train_time:59070ms step_avg:95.12ms +step:622/1695 train_time:59166ms step_avg:95.12ms +step:623/1695 train_time:59262ms step_avg:95.12ms +step:624/1695 train_time:59358ms step_avg:95.13ms +step:625/1695 train_time:59455ms step_avg:95.13ms +step:625/1695 val_loss:3.6470 train_time:59549ms step_avg:95.28ms +step:626/1695 train_time:59577ms step_avg:95.17ms +step:627/1695 train_time:59658ms step_avg:95.15ms +step:628/1695 train_time:59759ms step_avg:95.16ms +step:629/1695 train_time:59857ms step_avg:95.16ms +step:630/1695 train_time:59954ms step_avg:95.17ms +step:631/1695 train_time:60051ms step_avg:95.17ms 
+step:632/1695 train_time:60148ms step_avg:95.17ms +step:633/1695 train_time:60245ms step_avg:95.17ms +step:634/1695 train_time:60342ms step_avg:95.18ms +step:635/1695 train_time:60669ms step_avg:95.54ms +step:636/1695 train_time:60764ms step_avg:95.54ms +step:637/1695 train_time:60861ms step_avg:95.54ms +step:638/1695 train_time:60959ms step_avg:95.55ms +step:639/1695 train_time:61056ms step_avg:95.55ms +step:640/1695 train_time:61153ms step_avg:95.55ms +step:641/1695 train_time:61547ms step_avg:96.02ms +step:642/1695 train_time:61643ms step_avg:96.02ms +step:643/1695 train_time:61741ms step_avg:96.02ms +step:644/1695 train_time:61838ms step_avg:96.02ms +step:645/1695 train_time:61935ms step_avg:96.02ms +step:646/1695 train_time:62032ms step_avg:96.03ms +step:647/1695 train_time:62129ms step_avg:96.03ms +step:648/1695 train_time:62226ms step_avg:96.03ms +step:649/1695 train_time:62323ms step_avg:96.03ms +step:650/1695 train_time:62421ms step_avg:96.03ms +step:651/1695 train_time:62522ms step_avg:96.04ms +step:652/1695 train_time:62917ms step_avg:96.50ms +step:653/1695 train_time:62966ms step_avg:96.43ms +step:654/1695 train_time:63062ms step_avg:96.43ms +step:655/1695 train_time:63160ms step_avg:96.43ms +step:656/1695 train_time:63258ms step_avg:96.43ms +step:657/1695 train_time:63355ms step_avg:96.43ms +step:658/1695 train_time:63452ms step_avg:96.43ms +step:659/1695 train_time:63549ms step_avg:96.43ms +step:660/1695 train_time:63646ms step_avg:96.43ms +step:661/1695 train_time:63743ms step_avg:96.43ms +step:662/1695 train_time:63844ms step_avg:96.44ms +step:663/1695 train_time:63943ms step_avg:96.45ms +step:664/1695 train_time:64042ms step_avg:96.45ms +step:665/1695 train_time:64140ms step_avg:96.45ms +step:666/1695 train_time:64238ms step_avg:96.45ms +step:667/1695 train_time:64336ms step_avg:96.46ms +step:668/1695 train_time:64434ms step_avg:96.46ms +step:669/1695 train_time:64532ms step_avg:96.46ms +step:670/1695 train_time:64629ms step_avg:96.46ms +step:671/1695 train_time:64727ms step_avg:96.46ms +step:672/1695 train_time:64824ms step_avg:96.46ms +step:673/1695 train_time:64922ms step_avg:96.47ms +step:674/1695 train_time:65020ms step_avg:96.47ms +step:675/1695 train_time:65118ms step_avg:96.47ms +step:676/1695 train_time:65216ms step_avg:96.47ms +step:677/1695 train_time:65314ms step_avg:96.47ms +step:678/1695 train_time:65411ms step_avg:96.48ms +step:679/1695 train_time:65508ms step_avg:96.48ms +step:680/1695 train_time:65606ms step_avg:96.48ms +step:681/1695 train_time:65703ms step_avg:96.48ms +step:682/1695 train_time:65802ms step_avg:96.48ms +step:683/1695 train_time:65899ms step_avg:96.48ms +step:684/1695 train_time:65997ms step_avg:96.49ms +step:685/1695 train_time:66094ms step_avg:96.49ms +step:686/1695 train_time:66192ms step_avg:96.49ms +step:687/1695 train_time:66290ms step_avg:96.49ms +step:688/1695 train_time:66387ms step_avg:96.49ms +step:689/1695 train_time:66485ms step_avg:96.49ms +step:690/1695 train_time:66583ms step_avg:96.50ms +step:691/1695 train_time:66682ms step_avg:96.50ms +step:692/1695 train_time:66780ms step_avg:96.50ms +step:693/1695 train_time:66878ms step_avg:96.50ms +step:694/1695 train_time:66975ms step_avg:96.51ms +step:695/1695 train_time:67073ms step_avg:96.51ms +step:696/1695 train_time:67171ms step_avg:96.51ms +step:697/1695 train_time:67269ms step_avg:96.51ms +step:698/1695 train_time:67366ms step_avg:96.51ms +step:699/1695 train_time:67464ms step_avg:96.51ms +step:700/1695 train_time:67562ms step_avg:96.52ms +step:701/1695 train_time:67661ms 
step_avg:96.52ms +step:702/1695 train_time:67758ms step_avg:96.52ms +step:703/1695 train_time:67856ms step_avg:96.52ms +step:704/1695 train_time:67954ms step_avg:96.53ms +step:705/1695 train_time:68051ms step_avg:96.53ms +step:706/1695 train_time:68149ms step_avg:96.53ms +step:707/1695 train_time:68247ms step_avg:96.53ms +step:708/1695 train_time:68344ms step_avg:96.53ms +step:709/1695 train_time:68442ms step_avg:96.53ms +step:710/1695 train_time:68541ms step_avg:96.54ms +step:711/1695 train_time:68638ms step_avg:96.54ms +step:712/1695 train_time:68735ms step_avg:96.54ms +step:713/1695 train_time:68833ms step_avg:96.54ms +step:714/1695 train_time:68930ms step_avg:96.54ms +step:715/1695 train_time:69028ms step_avg:96.54ms +step:716/1695 train_time:69125ms step_avg:96.54ms +step:717/1695 train_time:69224ms step_avg:96.55ms +step:718/1695 train_time:69322ms step_avg:96.55ms +step:719/1695 train_time:69420ms step_avg:96.55ms +step:720/1695 train_time:69518ms step_avg:96.55ms +step:721/1695 train_time:69616ms step_avg:96.56ms +step:722/1695 train_time:69714ms step_avg:96.56ms +step:723/1695 train_time:69811ms step_avg:96.56ms +step:724/1695 train_time:69909ms step_avg:96.56ms +step:725/1695 train_time:70007ms step_avg:96.56ms +step:726/1695 train_time:70105ms step_avg:96.56ms +step:727/1695 train_time:70203ms step_avg:96.57ms +step:728/1695 train_time:70302ms step_avg:96.57ms +step:729/1695 train_time:70684ms step_avg:96.96ms +step:730/1695 train_time:70778ms step_avg:96.96ms +step:731/1695 train_time:70875ms step_avg:96.96ms +step:732/1695 train_time:70973ms step_avg:96.96ms +step:733/1695 train_time:71069ms step_avg:96.96ms +step:734/1695 train_time:71166ms step_avg:96.96ms +step:735/1695 train_time:71263ms step_avg:96.96ms +step:736/1695 train_time:71361ms step_avg:96.96ms +step:737/1695 train_time:71457ms step_avg:96.96ms +step:738/1695 train_time:71554ms step_avg:96.96ms +step:739/1695 train_time:71657ms step_avg:96.96ms +step:740/1695 train_time:71756ms step_avg:96.97ms +step:741/1695 train_time:71854ms step_avg:96.97ms +step:742/1695 train_time:71952ms step_avg:96.97ms +step:743/1695 train_time:72049ms step_avg:96.97ms +step:744/1695 train_time:72146ms step_avg:96.97ms +step:745/1695 train_time:72244ms step_avg:96.97ms +step:746/1695 train_time:72342ms step_avg:96.97ms +step:747/1695 train_time:72440ms step_avg:96.97ms +step:748/1695 train_time:72537ms step_avg:96.98ms +step:749/1695 train_time:72636ms step_avg:96.98ms +step:750/1695 train_time:72735ms step_avg:96.98ms +step:750/1695 val_loss:3.5832 train_time:72831ms step_avg:97.11ms +step:751/1695 train_time:72860ms step_avg:97.02ms +step:752/1695 train_time:72941ms step_avg:97.00ms +step:753/1695 train_time:73042ms step_avg:97.00ms +step:754/1695 train_time:73141ms step_avg:97.00ms +step:755/1695 train_time:73237ms step_avg:97.00ms +step:756/1695 train_time:73335ms step_avg:97.00ms +step:757/1695 train_time:73432ms step_avg:97.00ms +step:758/1695 train_time:73531ms step_avg:97.01ms +step:759/1695 train_time:73629ms step_avg:97.01ms +step:760/1695 train_time:73727ms step_avg:97.01ms +step:761/1695 train_time:73824ms step_avg:97.01ms +step:762/1695 train_time:73923ms step_avg:97.01ms +step:763/1695 train_time:74023ms step_avg:97.02ms +step:764/1695 train_time:74121ms step_avg:97.02ms +step:765/1695 train_time:74219ms step_avg:97.02ms +step:766/1695 train_time:74317ms step_avg:97.02ms +step:767/1695 train_time:74415ms step_avg:97.02ms +step:768/1695 train_time:74512ms step_avg:97.02ms +step:769/1695 train_time:74610ms step_avg:97.02ms 
+step:770/1695 train_time:74708ms step_avg:97.02ms +step:771/1695 train_time:74807ms step_avg:97.03ms +step:772/1695 train_time:75220ms step_avg:97.43ms +step:773/1695 train_time:75315ms step_avg:97.43ms +step:774/1695 train_time:75412ms step_avg:97.43ms +step:775/1695 train_time:75509ms step_avg:97.43ms +step:776/1695 train_time:75607ms step_avg:97.43ms +step:777/1695 train_time:75704ms step_avg:97.43ms +step:778/1695 train_time:76026ms step_avg:97.72ms +step:779/1695 train_time:76122ms step_avg:97.72ms +step:780/1695 train_time:76219ms step_avg:97.72ms +step:781/1695 train_time:76316ms step_avg:97.72ms +step:782/1695 train_time:76413ms step_avg:97.71ms +step:783/1695 train_time:76511ms step_avg:97.71ms +step:784/1695 train_time:76608ms step_avg:97.71ms +step:785/1695 train_time:76705ms step_avg:97.71ms +step:786/1695 train_time:76802ms step_avg:97.71ms +step:787/1695 train_time:76901ms step_avg:97.71ms +step:788/1695 train_time:77000ms step_avg:97.72ms +step:789/1695 train_time:77099ms step_avg:97.72ms +step:790/1695 train_time:77197ms step_avg:97.72ms +step:791/1695 train_time:77610ms step_avg:98.12ms +step:792/1695 train_time:77707ms step_avg:98.11ms +step:793/1695 train_time:77803ms step_avg:98.11ms +step:794/1695 train_time:77900ms step_avg:98.11ms +step:795/1695 train_time:77997ms step_avg:98.11ms +step:796/1695 train_time:78095ms step_avg:98.11ms +step:797/1695 train_time:78192ms step_avg:98.11ms +step:798/1695 train_time:78290ms step_avg:98.11ms +step:799/1695 train_time:78387ms step_avg:98.11ms +step:800/1695 train_time:78490ms step_avg:98.11ms +step:801/1695 train_time:78592ms step_avg:98.12ms +step:802/1695 train_time:78691ms step_avg:98.12ms +step:803/1695 train_time:78789ms step_avg:98.12ms +step:804/1695 train_time:78888ms step_avg:98.12ms +step:805/1695 train_time:78987ms step_avg:98.12ms +step:806/1695 train_time:79085ms step_avg:98.12ms +step:807/1695 train_time:79183ms step_avg:98.12ms +step:808/1695 train_time:79281ms step_avg:98.12ms +step:809/1695 train_time:79378ms step_avg:98.12ms +step:810/1695 train_time:79478ms step_avg:98.12ms +step:811/1695 train_time:79578ms step_avg:98.12ms +step:812/1695 train_time:79677ms step_avg:98.12ms +step:813/1695 train_time:79774ms step_avg:98.12ms +step:814/1695 train_time:79872ms step_avg:98.12ms +step:815/1695 train_time:79970ms step_avg:98.12ms +step:816/1695 train_time:80069ms step_avg:98.12ms +step:817/1695 train_time:80168ms step_avg:98.12ms +step:818/1695 train_time:80266ms step_avg:98.12ms +step:819/1695 train_time:80364ms step_avg:98.12ms +step:820/1695 train_time:80463ms step_avg:98.13ms +step:821/1695 train_time:80562ms step_avg:98.13ms +step:822/1695 train_time:80661ms step_avg:98.13ms +step:823/1695 train_time:80761ms step_avg:98.13ms +step:824/1695 train_time:80859ms step_avg:98.13ms +step:825/1695 train_time:80957ms step_avg:98.13ms +step:826/1695 train_time:81055ms step_avg:98.13ms +step:827/1695 train_time:81153ms step_avg:98.13ms +step:828/1695 train_time:81251ms step_avg:98.13ms +step:829/1695 train_time:81350ms step_avg:98.13ms +step:830/1695 train_time:81449ms step_avg:98.13ms +step:831/1695 train_time:81548ms step_avg:98.13ms +step:832/1695 train_time:81648ms step_avg:98.13ms +step:833/1695 train_time:81747ms step_avg:98.14ms +step:834/1695 train_time:81846ms step_avg:98.14ms +step:835/1695 train_time:81945ms step_avg:98.14ms +step:836/1695 train_time:82043ms step_avg:98.14ms +step:837/1695 train_time:82141ms step_avg:98.14ms +step:838/1695 train_time:82240ms step_avg:98.14ms +step:839/1695 train_time:82338ms 
step_avg:98.14ms +step:840/1695 train_time:82436ms step_avg:98.14ms +step:841/1695 train_time:82534ms step_avg:98.14ms +step:842/1695 train_time:82632ms step_avg:98.14ms +step:843/1695 train_time:82730ms step_avg:98.14ms +step:844/1695 train_time:82829ms step_avg:98.14ms +step:845/1695 train_time:82928ms step_avg:98.14ms +step:846/1695 train_time:83028ms step_avg:98.14ms +step:847/1695 train_time:83126ms step_avg:98.14ms +step:848/1695 train_time:83226ms step_avg:98.14ms +step:849/1695 train_time:83326ms step_avg:98.15ms +step:850/1695 train_time:83425ms step_avg:98.15ms +step:851/1695 train_time:83524ms step_avg:98.15ms +step:852/1695 train_time:83622ms step_avg:98.15ms +step:853/1695 train_time:83721ms step_avg:98.15ms +step:854/1695 train_time:83819ms step_avg:98.15ms +step:855/1695 train_time:83916ms step_avg:98.15ms +step:856/1695 train_time:84014ms step_avg:98.15ms +step:857/1695 train_time:84112ms step_avg:98.15ms +step:858/1695 train_time:84210ms step_avg:98.15ms +step:859/1695 train_time:84309ms step_avg:98.15ms +step:860/1695 train_time:84408ms step_avg:98.15ms +step:861/1695 train_time:84507ms step_avg:98.15ms +step:862/1695 train_time:84605ms step_avg:98.15ms +step:863/1695 train_time:84704ms step_avg:98.15ms +step:864/1695 train_time:84802ms step_avg:98.15ms +step:865/1695 train_time:84900ms step_avg:98.15ms +step:866/1695 train_time:84999ms step_avg:98.15ms +step:867/1695 train_time:85097ms step_avg:98.15ms +step:868/1695 train_time:85195ms step_avg:98.15ms +step:869/1695 train_time:85293ms step_avg:98.15ms +step:870/1695 train_time:85391ms step_avg:98.15ms +step:871/1695 train_time:85490ms step_avg:98.15ms +step:872/1695 train_time:85589ms step_avg:98.15ms +step:873/1695 train_time:85689ms step_avg:98.15ms +step:874/1695 train_time:85788ms step_avg:98.16ms +step:875/1695 train_time:85887ms step_avg:98.16ms +step:875/1695 val_loss:3.5360 train_time:85984ms step_avg:98.27ms +step:876/1695 train_time:86012ms step_avg:98.19ms +step:877/1695 train_time:86096ms step_avg:98.17ms +step:878/1695 train_time:86196ms step_avg:98.17ms +step:879/1695 train_time:86296ms step_avg:98.17ms +step:880/1695 train_time:86394ms step_avg:98.17ms +step:881/1695 train_time:86493ms step_avg:98.18ms +step:882/1695 train_time:86592ms step_avg:98.18ms +step:883/1695 train_time:86691ms step_avg:98.18ms +step:884/1695 train_time:86791ms step_avg:98.18ms +step:885/1695 train_time:86889ms step_avg:98.18ms +step:886/1695 train_time:86990ms step_avg:98.18ms +step:887/1695 train_time:87092ms step_avg:98.19ms +step:888/1695 train_time:87195ms step_avg:98.19ms +step:889/1695 train_time:87295ms step_avg:98.19ms +step:890/1695 train_time:87395ms step_avg:98.20ms +step:891/1695 train_time:87494ms step_avg:98.20ms +step:892/1695 train_time:87593ms step_avg:98.20ms +step:893/1695 train_time:87694ms step_avg:98.20ms +step:894/1695 train_time:87793ms step_avg:98.20ms +step:895/1695 train_time:87891ms step_avg:98.20ms +step:896/1695 train_time:87992ms step_avg:98.21ms +step:897/1695 train_time:88093ms step_avg:98.21ms +step:898/1695 train_time:88193ms step_avg:98.21ms +step:899/1695 train_time:88294ms step_avg:98.21ms +step:900/1695 train_time:88394ms step_avg:98.22ms +step:901/1695 train_time:88493ms step_avg:98.22ms +step:902/1695 train_time:88592ms step_avg:98.22ms +step:903/1695 train_time:88691ms step_avg:98.22ms +step:904/1695 train_time:88791ms step_avg:98.22ms +step:905/1695 train_time:88890ms step_avg:98.22ms +step:906/1695 train_time:88989ms step_avg:98.22ms +step:907/1695 train_time:89089ms step_avg:98.22ms 
+step:908/1695 train_time:89190ms step_avg:98.23ms +step:909/1695 train_time:89290ms step_avg:98.23ms +step:910/1695 train_time:89390ms step_avg:98.23ms +step:911/1695 train_time:89490ms step_avg:98.23ms +step:912/1695 train_time:89589ms step_avg:98.23ms +step:913/1695 train_time:89687ms step_avg:98.23ms +step:914/1695 train_time:89787ms step_avg:98.24ms +step:915/1695 train_time:89885ms step_avg:98.24ms +step:916/1695 train_time:89984ms step_avg:98.24ms +step:917/1695 train_time:90083ms step_avg:98.24ms +step:918/1695 train_time:90182ms step_avg:98.24ms +step:919/1695 train_time:90281ms step_avg:98.24ms +step:920/1695 train_time:90382ms step_avg:98.24ms +step:921/1695 train_time:90481ms step_avg:98.24ms +step:922/1695 train_time:90582ms step_avg:98.24ms +step:923/1695 train_time:90681ms step_avg:98.25ms +step:924/1695 train_time:90782ms step_avg:98.25ms +step:925/1695 train_time:90882ms step_avg:98.25ms +step:926/1695 train_time:90982ms step_avg:98.25ms +step:927/1695 train_time:91081ms step_avg:98.25ms +step:928/1695 train_time:91180ms step_avg:98.25ms +step:929/1695 train_time:91279ms step_avg:98.26ms +step:930/1695 train_time:91380ms step_avg:98.26ms +step:931/1695 train_time:91479ms step_avg:98.26ms +step:932/1695 train_time:91579ms step_avg:98.26ms +step:933/1695 train_time:91680ms step_avg:98.26ms +step:934/1695 train_time:91780ms step_avg:98.27ms +step:935/1695 train_time:91879ms step_avg:98.27ms +step:936/1695 train_time:91980ms step_avg:98.27ms +step:937/1695 train_time:92080ms step_avg:98.27ms +step:938/1695 train_time:92181ms step_avg:98.27ms +step:939/1695 train_time:92280ms step_avg:98.27ms +step:940/1695 train_time:92380ms step_avg:98.28ms +step:941/1695 train_time:92480ms step_avg:98.28ms +step:942/1695 train_time:92579ms step_avg:98.28ms +step:943/1695 train_time:92679ms step_avg:98.28ms +step:944/1695 train_time:92780ms step_avg:98.28ms +step:945/1695 train_time:92880ms step_avg:98.29ms +step:946/1695 train_time:92980ms step_avg:98.29ms +step:947/1695 train_time:93080ms step_avg:98.29ms +step:948/1695 train_time:93179ms step_avg:98.29ms +step:949/1695 train_time:93279ms step_avg:98.29ms +step:950/1695 train_time:93379ms step_avg:98.29ms +step:951/1695 train_time:93479ms step_avg:98.30ms +step:952/1695 train_time:93579ms step_avg:98.30ms +step:953/1695 train_time:93679ms step_avg:98.30ms +step:954/1695 train_time:93779ms step_avg:98.30ms +step:955/1695 train_time:93879ms step_avg:98.30ms +step:956/1695 train_time:93980ms step_avg:98.30ms +step:957/1695 train_time:94079ms step_avg:98.31ms +step:958/1695 train_time:94180ms step_avg:98.31ms +step:959/1695 train_time:94279ms step_avg:98.31ms +step:960/1695 train_time:94379ms step_avg:98.31ms +step:961/1695 train_time:94480ms step_avg:98.31ms +step:962/1695 train_time:94580ms step_avg:98.32ms +step:963/1695 train_time:94681ms step_avg:98.32ms +step:964/1695 train_time:94781ms step_avg:98.32ms +step:965/1695 train_time:94880ms step_avg:98.32ms +step:966/1695 train_time:94980ms step_avg:98.32ms +step:967/1695 train_time:95080ms step_avg:98.32ms +step:968/1695 train_time:95181ms step_avg:98.33ms +step:969/1695 train_time:95281ms step_avg:98.33ms +step:970/1695 train_time:95380ms step_avg:98.33ms +step:971/1695 train_time:95480ms step_avg:98.33ms +step:972/1695 train_time:95580ms step_avg:98.33ms +step:973/1695 train_time:95680ms step_avg:98.33ms +step:974/1695 train_time:95779ms step_avg:98.34ms +step:975/1695 train_time:95879ms step_avg:98.34ms +step:976/1695 train_time:95978ms step_avg:98.34ms +step:977/1695 train_time:96079ms 
step_avg:98.34ms +step:978/1695 train_time:96180ms step_avg:98.34ms +step:979/1695 train_time:96280ms step_avg:98.34ms +step:980/1695 train_time:96380ms step_avg:98.35ms +step:981/1695 train_time:96479ms step_avg:98.35ms +step:982/1695 train_time:96578ms step_avg:98.35ms +step:983/1695 train_time:96680ms step_avg:98.35ms +step:984/1695 train_time:96780ms step_avg:98.35ms +step:985/1695 train_time:96880ms step_avg:98.36ms +step:986/1695 train_time:96980ms step_avg:98.36ms +step:987/1695 train_time:97080ms step_avg:98.36ms +step:988/1695 train_time:97181ms step_avg:98.36ms +step:989/1695 train_time:97281ms step_avg:98.36ms +step:990/1695 train_time:97381ms step_avg:98.36ms +step:991/1695 train_time:97480ms step_avg:98.37ms +step:992/1695 train_time:97579ms step_avg:98.37ms +step:993/1695 train_time:97679ms step_avg:98.37ms +step:994/1695 train_time:97779ms step_avg:98.37ms +step:995/1695 train_time:97879ms step_avg:98.37ms +step:996/1695 train_time:97979ms step_avg:98.37ms +step:997/1695 train_time:98079ms step_avg:98.37ms +step:998/1695 train_time:98178ms step_avg:98.38ms +step:999/1695 train_time:98279ms step_avg:98.38ms +step:1000/1695 train_time:98379ms step_avg:98.38ms +step:1000/1695 val_loss:3.4899 train_time:98475ms step_avg:98.48ms +step:1001/1695 train_time:98504ms step_avg:98.41ms +step:1002/1695 train_time:98585ms step_avg:98.39ms +step:1003/1695 train_time:98686ms step_avg:98.39ms +step:1004/1695 train_time:98786ms step_avg:98.39ms +step:1005/1695 train_time:98885ms step_avg:98.39ms +step:1006/1695 train_time:98984ms step_avg:98.39ms +step:1007/1695 train_time:99083ms step_avg:98.39ms +step:1008/1695 train_time:99182ms step_avg:98.39ms +step:1009/1695 train_time:99282ms step_avg:98.40ms +step:1010/1695 train_time:99381ms step_avg:98.40ms +step:1011/1695 train_time:99483ms step_avg:98.40ms +step:1012/1695 train_time:99585ms step_avg:98.40ms +step:1013/1695 train_time:99686ms step_avg:98.41ms +step:1014/1695 train_time:99788ms step_avg:98.41ms +step:1015/1695 train_time:99888ms step_avg:98.41ms +step:1016/1695 train_time:99987ms step_avg:98.41ms +step:1017/1695 train_time:100087ms step_avg:98.41ms +step:1018/1695 train_time:100188ms step_avg:98.42ms +step:1019/1695 train_time:100287ms step_avg:98.42ms +step:1020/1695 train_time:100389ms step_avg:98.42ms +step:1021/1695 train_time:100490ms step_avg:98.42ms +step:1022/1695 train_time:100591ms step_avg:98.43ms +step:1023/1695 train_time:100692ms step_avg:98.43ms +step:1024/1695 train_time:100794ms step_avg:98.43ms +step:1025/1695 train_time:100894ms step_avg:98.43ms +step:1026/1695 train_time:100994ms step_avg:98.43ms +step:1027/1695 train_time:101093ms step_avg:98.44ms +step:1028/1695 train_time:101194ms step_avg:98.44ms +step:1029/1695 train_time:101295ms step_avg:98.44ms +step:1030/1695 train_time:101394ms step_avg:98.44ms +step:1031/1695 train_time:101494ms step_avg:98.44ms +step:1032/1695 train_time:101594ms step_avg:98.44ms +step:1033/1695 train_time:101693ms step_avg:98.44ms +step:1034/1695 train_time:101792ms step_avg:98.45ms +step:1035/1695 train_time:101892ms step_avg:98.45ms +step:1036/1695 train_time:101991ms step_avg:98.45ms +step:1037/1695 train_time:102091ms step_avg:98.45ms +step:1038/1695 train_time:102192ms step_avg:98.45ms +step:1039/1695 train_time:102292ms step_avg:98.45ms +step:1040/1695 train_time:102392ms step_avg:98.45ms +step:1041/1695 train_time:102492ms step_avg:98.46ms +step:1042/1695 train_time:102593ms step_avg:98.46ms +step:1043/1695 train_time:102693ms step_avg:98.46ms +step:1044/1695 
train_time:102792ms step_avg:98.46ms +step:1045/1695 train_time:102892ms step_avg:98.46ms +step:1046/1695 train_time:102992ms step_avg:98.46ms +step:1047/1695 train_time:103091ms step_avg:98.46ms +step:1048/1695 train_time:103191ms step_avg:98.46ms +step:1049/1695 train_time:103291ms step_avg:98.47ms +step:1050/1695 train_time:103392ms step_avg:98.47ms +step:1051/1695 train_time:103493ms step_avg:98.47ms +step:1052/1695 train_time:103593ms step_avg:98.47ms +step:1053/1695 train_time:103693ms step_avg:98.47ms +step:1054/1695 train_time:103793ms step_avg:98.47ms +step:1055/1695 train_time:103892ms step_avg:98.48ms +step:1056/1695 train_time:103991ms step_avg:98.48ms +step:1057/1695 train_time:104091ms step_avg:98.48ms +step:1058/1695 train_time:104191ms step_avg:98.48ms +step:1059/1695 train_time:104290ms step_avg:98.48ms +step:1060/1695 train_time:104390ms step_avg:98.48ms +step:1061/1695 train_time:104491ms step_avg:98.48ms +step:1062/1695 train_time:104591ms step_avg:98.49ms +step:1063/1695 train_time:104691ms step_avg:98.49ms +step:1064/1695 train_time:104791ms step_avg:98.49ms +step:1065/1695 train_time:104891ms step_avg:98.49ms +step:1066/1695 train_time:104990ms step_avg:98.49ms +step:1067/1695 train_time:105091ms step_avg:98.49ms +step:1068/1695 train_time:105191ms step_avg:98.49ms +step:1069/1695 train_time:105292ms step_avg:98.50ms +step:1070/1695 train_time:105392ms step_avg:98.50ms +step:1071/1695 train_time:105492ms step_avg:98.50ms +step:1072/1695 train_time:105592ms step_avg:98.50ms +step:1073/1695 train_time:105691ms step_avg:98.50ms +step:1074/1695 train_time:105791ms step_avg:98.50ms +step:1075/1695 train_time:105890ms step_avg:98.50ms +step:1076/1695 train_time:105991ms step_avg:98.50ms +step:1077/1695 train_time:106091ms step_avg:98.51ms +step:1078/1695 train_time:106192ms step_avg:98.51ms +step:1079/1695 train_time:106292ms step_avg:98.51ms +step:1080/1695 train_time:106392ms step_avg:98.51ms +step:1081/1695 train_time:106491ms step_avg:98.51ms +step:1082/1695 train_time:106593ms step_avg:98.51ms +step:1083/1695 train_time:106692ms step_avg:98.52ms +step:1084/1695 train_time:106792ms step_avg:98.52ms +step:1085/1695 train_time:106891ms step_avg:98.52ms +step:1086/1695 train_time:106991ms step_avg:98.52ms +step:1087/1695 train_time:107092ms step_avg:98.52ms +step:1088/1695 train_time:107192ms step_avg:98.52ms +step:1089/1695 train_time:107291ms step_avg:98.52ms +step:1090/1695 train_time:107392ms step_avg:98.52ms +step:1091/1695 train_time:107492ms step_avg:98.53ms +step:1092/1695 train_time:107593ms step_avg:98.53ms +step:1093/1695 train_time:107692ms step_avg:98.53ms +step:1094/1695 train_time:107791ms step_avg:98.53ms +step:1095/1695 train_time:107891ms step_avg:98.53ms +step:1096/1695 train_time:107991ms step_avg:98.53ms +step:1097/1695 train_time:108090ms step_avg:98.53ms +step:1098/1695 train_time:108191ms step_avg:98.53ms +step:1099/1695 train_time:108289ms step_avg:98.53ms +step:1100/1695 train_time:108389ms step_avg:98.54ms +step:1101/1695 train_time:108490ms step_avg:98.54ms +step:1102/1695 train_time:108590ms step_avg:98.54ms +step:1103/1695 train_time:108690ms step_avg:98.54ms +step:1104/1695 train_time:108791ms step_avg:98.54ms +step:1105/1695 train_time:108892ms step_avg:98.54ms +step:1106/1695 train_time:108991ms step_avg:98.55ms +step:1107/1695 train_time:109092ms step_avg:98.55ms +step:1108/1695 train_time:109193ms step_avg:98.55ms +step:1109/1695 train_time:109291ms step_avg:98.55ms +step:1110/1695 train_time:109392ms step_avg:98.55ms +step:1111/1695 
train_time:109493ms step_avg:98.55ms +step:1112/1695 train_time:109593ms step_avg:98.55ms +step:1113/1695 train_time:109692ms step_avg:98.56ms +step:1114/1695 train_time:109792ms step_avg:98.56ms +step:1115/1695 train_time:109892ms step_avg:98.56ms +step:1116/1695 train_time:109991ms step_avg:98.56ms +step:1117/1695 train_time:110092ms step_avg:98.56ms +step:1118/1695 train_time:110193ms step_avg:98.56ms +step:1119/1695 train_time:110293ms step_avg:98.56ms +step:1120/1695 train_time:110392ms step_avg:98.56ms +step:1121/1695 train_time:110492ms step_avg:98.57ms +step:1122/1695 train_time:110593ms step_avg:98.57ms +step:1123/1695 train_time:110693ms step_avg:98.57ms +step:1124/1695 train_time:110792ms step_avg:98.57ms +step:1125/1695 train_time:110892ms step_avg:98.57ms +step:1125/1695 val_loss:3.4392 train_time:110988ms step_avg:98.66ms +step:1126/1695 train_time:111017ms step_avg:98.59ms +step:1127/1695 train_time:111099ms step_avg:98.58ms +step:1128/1695 train_time:111201ms step_avg:98.58ms +step:1129/1695 train_time:111301ms step_avg:98.58ms +step:1130/1695 train_time:111401ms step_avg:98.59ms +step:1131/1695 train_time:111502ms step_avg:98.59ms +step:1132/1695 train_time:111602ms step_avg:98.59ms +step:1133/1695 train_time:111702ms step_avg:98.59ms +step:1134/1695 train_time:111802ms step_avg:98.59ms +step:1135/1695 train_time:111903ms step_avg:98.59ms +step:1136/1695 train_time:112004ms step_avg:98.60ms +step:1137/1695 train_time:112106ms step_avg:98.60ms +step:1138/1695 train_time:112206ms step_avg:98.60ms +step:1139/1695 train_time:112306ms step_avg:98.60ms +step:1140/1695 train_time:112406ms step_avg:98.60ms +step:1141/1695 train_time:112506ms step_avg:98.60ms +step:1142/1695 train_time:112607ms step_avg:98.60ms +step:1143/1695 train_time:112706ms step_avg:98.61ms +step:1144/1695 train_time:112807ms step_avg:98.61ms +step:1145/1695 train_time:112908ms step_avg:98.61ms +step:1146/1695 train_time:113008ms step_avg:98.61ms +step:1147/1695 train_time:113108ms step_avg:98.61ms +step:1148/1695 train_time:113207ms step_avg:98.61ms +step:1149/1695 train_time:113308ms step_avg:98.61ms +step:1150/1695 train_time:113408ms step_avg:98.62ms +step:1151/1695 train_time:113509ms step_avg:98.62ms +step:1152/1695 train_time:113609ms step_avg:98.62ms +step:1153/1695 train_time:113710ms step_avg:98.62ms +step:1154/1695 train_time:113810ms step_avg:98.62ms +step:1155/1695 train_time:113910ms step_avg:98.62ms +step:1156/1695 train_time:114010ms step_avg:98.62ms +step:1157/1695 train_time:114112ms step_avg:98.63ms +step:1158/1695 train_time:114212ms step_avg:98.63ms +step:1159/1695 train_time:114313ms step_avg:98.63ms +step:1160/1695 train_time:114417ms step_avg:98.64ms +step:1161/1695 train_time:114518ms step_avg:98.64ms +step:1162/1695 train_time:114619ms step_avg:98.64ms +step:1163/1695 train_time:114723ms step_avg:98.64ms +step:1164/1695 train_time:114825ms step_avg:98.65ms +step:1165/1695 train_time:114926ms step_avg:98.65ms +step:1166/1695 train_time:115027ms step_avg:98.65ms +step:1167/1695 train_time:115127ms step_avg:98.65ms +step:1168/1695 train_time:115228ms step_avg:98.65ms +step:1169/1695 train_time:115328ms step_avg:98.66ms +step:1170/1695 train_time:115429ms step_avg:98.66ms +step:1171/1695 train_time:115532ms step_avg:98.66ms +step:1172/1695 train_time:115635ms step_avg:98.66ms +step:1173/1695 train_time:115736ms step_avg:98.67ms +step:1174/1695 train_time:115837ms step_avg:98.67ms +step:1175/1695 train_time:115938ms step_avg:98.67ms +step:1176/1695 train_time:116039ms step_avg:98.67ms 
+step:1177/1695 train_time:116140ms step_avg:98.67ms +step:1178/1695 train_time:116241ms step_avg:98.68ms +step:1179/1695 train_time:116346ms step_avg:98.68ms +step:1180/1695 train_time:116447ms step_avg:98.68ms +step:1181/1695 train_time:116547ms step_avg:98.69ms +step:1182/1695 train_time:116648ms step_avg:98.69ms +step:1183/1695 train_time:116748ms step_avg:98.69ms +step:1184/1695 train_time:116850ms step_avg:98.69ms +step:1185/1695 train_time:116951ms step_avg:98.69ms +step:1186/1695 train_time:117053ms step_avg:98.70ms +step:1187/1695 train_time:117154ms step_avg:98.70ms +step:1188/1695 train_time:117256ms step_avg:98.70ms +step:1189/1695 train_time:117357ms step_avg:98.70ms +step:1190/1695 train_time:117460ms step_avg:98.71ms +step:1191/1695 train_time:117560ms step_avg:98.71ms +step:1192/1695 train_time:117662ms step_avg:98.71ms +step:1193/1695 train_time:117762ms step_avg:98.71ms +step:1194/1695 train_time:117863ms step_avg:98.71ms +step:1195/1695 train_time:117964ms step_avg:98.71ms +step:1196/1695 train_time:118065ms step_avg:98.72ms +step:1197/1695 train_time:118166ms step_avg:98.72ms +step:1198/1695 train_time:118267ms step_avg:98.72ms +step:1199/1695 train_time:118367ms step_avg:98.72ms +step:1200/1695 train_time:118467ms step_avg:98.72ms +step:1201/1695 train_time:118569ms step_avg:98.72ms +step:1202/1695 train_time:118670ms step_avg:98.73ms +step:1203/1695 train_time:118778ms step_avg:98.74ms +step:1204/1695 train_time:118873ms step_avg:98.73ms +step:1205/1695 train_time:118975ms step_avg:98.73ms +step:1206/1695 train_time:119077ms step_avg:98.74ms +step:1207/1695 train_time:119179ms step_avg:98.74ms +step:1208/1695 train_time:119280ms step_avg:98.74ms +step:1209/1695 train_time:119381ms step_avg:98.74ms +step:1210/1695 train_time:119482ms step_avg:98.75ms +step:1211/1695 train_time:119583ms step_avg:98.75ms +step:1212/1695 train_time:119685ms step_avg:98.75ms +step:1213/1695 train_time:119785ms step_avg:98.75ms +step:1214/1695 train_time:119885ms step_avg:98.75ms +step:1215/1695 train_time:119986ms step_avg:98.75ms +step:1216/1695 train_time:120089ms step_avg:98.76ms +step:1217/1695 train_time:120189ms step_avg:98.76ms +step:1218/1695 train_time:120290ms step_avg:98.76ms +step:1219/1695 train_time:120390ms step_avg:98.76ms +step:1220/1695 train_time:120491ms step_avg:98.76ms +step:1221/1695 train_time:120592ms step_avg:98.77ms +step:1222/1695 train_time:120693ms step_avg:98.77ms +step:1223/1695 train_time:120795ms step_avg:98.77ms +step:1224/1695 train_time:120895ms step_avg:98.77ms +step:1225/1695 train_time:120998ms step_avg:98.77ms +step:1226/1695 train_time:121101ms step_avg:98.78ms +step:1227/1695 train_time:121201ms step_avg:98.78ms +step:1228/1695 train_time:121302ms step_avg:98.78ms +step:1229/1695 train_time:121404ms step_avg:98.78ms +step:1230/1695 train_time:121505ms step_avg:98.78ms +step:1231/1695 train_time:121606ms step_avg:98.79ms +step:1232/1695 train_time:121706ms step_avg:98.79ms +step:1233/1695 train_time:121806ms step_avg:98.79ms +step:1234/1695 train_time:121907ms step_avg:98.79ms +step:1235/1695 train_time:122008ms step_avg:98.79ms +step:1236/1695 train_time:122109ms step_avg:98.79ms +step:1237/1695 train_time:122209ms step_avg:98.79ms +step:1238/1695 train_time:122310ms step_avg:98.80ms +step:1239/1695 train_time:122410ms step_avg:98.80ms +step:1240/1695 train_time:122512ms step_avg:98.80ms +step:1241/1695 train_time:122613ms step_avg:98.80ms +step:1242/1695 train_time:122714ms step_avg:98.80ms +step:1243/1695 train_time:122815ms step_avg:98.81ms 
+step:1244/1695 train_time:122916ms step_avg:98.81ms +step:1245/1695 train_time:123018ms step_avg:98.81ms +step:1246/1695 train_time:123121ms step_avg:98.81ms +step:1247/1695 train_time:123222ms step_avg:98.81ms +step:1248/1695 train_time:123323ms step_avg:98.82ms +step:1249/1695 train_time:123424ms step_avg:98.82ms +step:1250/1695 train_time:123525ms step_avg:98.82ms +step:1250/1695 val_loss:3.3925 train_time:123622ms step_avg:98.90ms +step:1251/1695 train_time:123650ms step_avg:98.84ms +step:1252/1695 train_time:123734ms step_avg:98.83ms +step:1253/1695 train_time:123835ms step_avg:98.83ms +step:1254/1695 train_time:123936ms step_avg:98.83ms +step:1255/1695 train_time:124037ms step_avg:98.83ms +step:1256/1695 train_time:124138ms step_avg:98.84ms +step:1257/1695 train_time:124238ms step_avg:98.84ms +step:1258/1695 train_time:124339ms step_avg:98.84ms +step:1259/1695 train_time:124439ms step_avg:98.84ms +step:1260/1695 train_time:124540ms step_avg:98.84ms +step:1261/1695 train_time:124642ms step_avg:98.84ms +step:1262/1695 train_time:124744ms step_avg:98.85ms +step:1263/1695 train_time:124844ms step_avg:98.85ms +step:1264/1695 train_time:124944ms step_avg:98.85ms +step:1265/1695 train_time:125043ms step_avg:98.85ms +step:1266/1695 train_time:125143ms step_avg:98.85ms +step:1267/1695 train_time:125243ms step_avg:98.85ms +step:1268/1695 train_time:125344ms step_avg:98.85ms +step:1269/1695 train_time:125445ms step_avg:98.85ms +step:1270/1695 train_time:125546ms step_avg:98.86ms +step:1271/1695 train_time:125648ms step_avg:98.86ms +step:1272/1695 train_time:125748ms step_avg:98.86ms +step:1273/1695 train_time:125849ms step_avg:98.86ms +step:1274/1695 train_time:125949ms step_avg:98.86ms +step:1275/1695 train_time:126051ms step_avg:98.86ms +step:1276/1695 train_time:126154ms step_avg:98.87ms +step:1277/1695 train_time:126255ms step_avg:98.87ms +step:1278/1695 train_time:126357ms step_avg:98.87ms +step:1279/1695 train_time:126460ms step_avg:98.87ms +step:1280/1695 train_time:126560ms step_avg:98.88ms +step:1281/1695 train_time:126662ms step_avg:98.88ms +step:1282/1695 train_time:126762ms step_avg:98.88ms +step:1283/1695 train_time:126861ms step_avg:98.88ms +step:1284/1695 train_time:126961ms step_avg:98.88ms +step:1285/1695 train_time:127062ms step_avg:98.88ms +step:1286/1695 train_time:127161ms step_avg:98.88ms +step:1287/1695 train_time:127262ms step_avg:98.88ms +step:1288/1695 train_time:127363ms step_avg:98.88ms +step:1289/1695 train_time:127465ms step_avg:98.89ms +step:1290/1695 train_time:127565ms step_avg:98.89ms +step:1291/1695 train_time:127666ms step_avg:98.89ms +step:1292/1695 train_time:127766ms step_avg:98.89ms +step:1293/1695 train_time:127867ms step_avg:98.89ms +step:1294/1695 train_time:127969ms step_avg:98.89ms +step:1295/1695 train_time:128070ms step_avg:98.90ms +step:1296/1695 train_time:128171ms step_avg:98.90ms +step:1297/1695 train_time:128273ms step_avg:98.90ms +step:1298/1695 train_time:128374ms step_avg:98.90ms +step:1299/1695 train_time:128476ms step_avg:98.90ms +step:1300/1695 train_time:128578ms step_avg:98.91ms +step:1301/1695 train_time:128679ms step_avg:98.91ms +step:1302/1695 train_time:128780ms step_avg:98.91ms +step:1303/1695 train_time:128881ms step_avg:98.91ms +step:1304/1695 train_time:128982ms step_avg:98.91ms +step:1305/1695 train_time:129083ms step_avg:98.91ms +step:1306/1695 train_time:129182ms step_avg:98.91ms +step:1307/1695 train_time:129283ms step_avg:98.92ms +step:1308/1695 train_time:129385ms step_avg:98.92ms +step:1309/1695 train_time:129486ms 
step_avg:98.92ms +step:1310/1695 train_time:129587ms step_avg:98.92ms +step:1311/1695 train_time:129688ms step_avg:98.92ms +step:1312/1695 train_time:129788ms step_avg:98.92ms +step:1313/1695 train_time:129890ms step_avg:98.93ms +step:1314/1695 train_time:129992ms step_avg:98.93ms +step:1315/1695 train_time:130093ms step_avg:98.93ms +step:1316/1695 train_time:130194ms step_avg:98.93ms +step:1317/1695 train_time:130294ms step_avg:98.93ms +step:1318/1695 train_time:130396ms step_avg:98.93ms +step:1319/1695 train_time:130498ms step_avg:98.94ms +step:1320/1695 train_time:130600ms step_avg:98.94ms +step:1321/1695 train_time:130701ms step_avg:98.94ms +step:1322/1695 train_time:130802ms step_avg:98.94ms +step:1323/1695 train_time:130903ms step_avg:98.94ms +step:1324/1695 train_time:131003ms step_avg:98.95ms +step:1325/1695 train_time:131104ms step_avg:98.95ms +step:1326/1695 train_time:131205ms step_avg:98.95ms +step:1327/1695 train_time:131307ms step_avg:98.95ms +step:1328/1695 train_time:131406ms step_avg:98.95ms +step:1329/1695 train_time:131506ms step_avg:98.95ms +step:1330/1695 train_time:131606ms step_avg:98.95ms +step:1331/1695 train_time:131709ms step_avg:98.95ms +step:1332/1695 train_time:131811ms step_avg:98.96ms +step:1333/1695 train_time:131912ms step_avg:98.96ms +step:1334/1695 train_time:132013ms step_avg:98.96ms +step:1335/1695 train_time:132115ms step_avg:98.96ms +step:1336/1695 train_time:132216ms step_avg:98.96ms +step:1337/1695 train_time:132318ms step_avg:98.97ms +step:1338/1695 train_time:132419ms step_avg:98.97ms +step:1339/1695 train_time:132521ms step_avg:98.97ms +step:1340/1695 train_time:132621ms step_avg:98.97ms +step:1341/1695 train_time:132722ms step_avg:98.97ms +step:1342/1695 train_time:132822ms step_avg:98.97ms +step:1343/1695 train_time:132922ms step_avg:98.97ms +step:1344/1695 train_time:133022ms step_avg:98.97ms +step:1345/1695 train_time:133122ms step_avg:98.98ms +step:1346/1695 train_time:133223ms step_avg:98.98ms +step:1347/1695 train_time:133324ms step_avg:98.98ms +step:1348/1695 train_time:133425ms step_avg:98.98ms +step:1349/1695 train_time:133527ms step_avg:98.98ms +step:1350/1695 train_time:133628ms step_avg:98.98ms +step:1351/1695 train_time:133728ms step_avg:98.98ms +step:1352/1695 train_time:133829ms step_avg:98.99ms +step:1353/1695 train_time:133930ms step_avg:98.99ms +step:1354/1695 train_time:134031ms step_avg:98.99ms +step:1355/1695 train_time:134132ms step_avg:98.99ms +step:1356/1695 train_time:134234ms step_avg:98.99ms +step:1357/1695 train_time:134335ms step_avg:98.99ms +step:1358/1695 train_time:134436ms step_avg:99.00ms +step:1359/1695 train_time:134538ms step_avg:99.00ms +step:1360/1695 train_time:134638ms step_avg:99.00ms +step:1361/1695 train_time:134739ms step_avg:99.00ms +step:1362/1695 train_time:134839ms step_avg:99.00ms +step:1363/1695 train_time:134942ms step_avg:99.00ms +step:1364/1695 train_time:135042ms step_avg:99.00ms +step:1365/1695 train_time:135143ms step_avg:99.01ms +step:1366/1695 train_time:135243ms step_avg:99.01ms +step:1367/1695 train_time:135343ms step_avg:99.01ms +step:1368/1695 train_time:135444ms step_avg:99.01ms +step:1369/1695 train_time:135544ms step_avg:99.01ms +step:1370/1695 train_time:135646ms step_avg:99.01ms +step:1371/1695 train_time:135747ms step_avg:99.01ms +step:1372/1695 train_time:135848ms step_avg:99.01ms +step:1373/1695 train_time:135951ms step_avg:99.02ms +step:1374/1695 train_time:136051ms step_avg:99.02ms +step:1375/1695 train_time:136154ms step_avg:99.02ms +step:1375/1695 val_loss:3.3531 
train_time:136252ms step_avg:99.09ms +step:1376/1695 train_time:136281ms step_avg:99.04ms +step:1377/1695 train_time:136367ms step_avg:99.03ms +step:1378/1695 train_time:136471ms step_avg:99.04ms +step:1379/1695 train_time:136572ms step_avg:99.04ms +step:1380/1695 train_time:136675ms step_avg:99.04ms +step:1381/1695 train_time:136775ms step_avg:99.04ms +step:1382/1695 train_time:136874ms step_avg:99.04ms +step:1383/1695 train_time:136973ms step_avg:99.04ms +step:1384/1695 train_time:137074ms step_avg:99.04ms +step:1385/1695 train_time:137174ms step_avg:99.04ms +step:1386/1695 train_time:137276ms step_avg:99.04ms +step:1387/1695 train_time:137378ms step_avg:99.05ms +step:1388/1695 train_time:137481ms step_avg:99.05ms +step:1389/1695 train_time:137585ms step_avg:99.05ms +step:1390/1695 train_time:137687ms step_avg:99.06ms +step:1391/1695 train_time:137788ms step_avg:99.06ms +step:1392/1695 train_time:137891ms step_avg:99.06ms +step:1393/1695 train_time:137992ms step_avg:99.06ms +step:1394/1695 train_time:138093ms step_avg:99.06ms +step:1395/1695 train_time:138194ms step_avg:99.06ms +step:1396/1695 train_time:138296ms step_avg:99.07ms +step:1397/1695 train_time:138398ms step_avg:99.07ms +step:1398/1695 train_time:138500ms step_avg:99.07ms +step:1399/1695 train_time:138601ms step_avg:99.07ms +step:1400/1695 train_time:138704ms step_avg:99.07ms +step:1401/1695 train_time:138806ms step_avg:99.08ms +step:1402/1695 train_time:138909ms step_avg:99.08ms +step:1403/1695 train_time:139012ms step_avg:99.08ms +step:1404/1695 train_time:139114ms step_avg:99.08ms +step:1405/1695 train_time:139215ms step_avg:99.09ms +step:1406/1695 train_time:139317ms step_avg:99.09ms +step:1407/1695 train_time:139418ms step_avg:99.09ms +step:1408/1695 train_time:139519ms step_avg:99.09ms +step:1409/1695 train_time:139625ms step_avg:99.09ms +step:1410/1695 train_time:139726ms step_avg:99.10ms +step:1411/1695 train_time:139828ms step_avg:99.10ms +step:1412/1695 train_time:139931ms step_avg:99.10ms +step:1413/1695 train_time:140032ms step_avg:99.10ms +step:1414/1695 train_time:140135ms step_avg:99.11ms +step:1415/1695 train_time:140237ms step_avg:99.11ms +step:1416/1695 train_time:140337ms step_avg:99.11ms +step:1417/1695 train_time:140438ms step_avg:99.11ms +step:1418/1695 train_time:140539ms step_avg:99.11ms +step:1419/1695 train_time:140641ms step_avg:99.11ms +step:1420/1695 train_time:140743ms step_avg:99.12ms +step:1421/1695 train_time:140846ms step_avg:99.12ms +step:1422/1695 train_time:140947ms step_avg:99.12ms +step:1423/1695 train_time:141050ms step_avg:99.12ms +step:1424/1695 train_time:141152ms step_avg:99.12ms +step:1425/1695 train_time:141253ms step_avg:99.12ms +step:1426/1695 train_time:141356ms step_avg:99.13ms +step:1427/1695 train_time:141457ms step_avg:99.13ms +step:1428/1695 train_time:141559ms step_avg:99.13ms +step:1429/1695 train_time:141661ms step_avg:99.13ms +step:1430/1695 train_time:141762ms step_avg:99.13ms +step:1431/1695 train_time:141863ms step_avg:99.14ms +step:1432/1695 train_time:141966ms step_avg:99.14ms +step:1433/1695 train_time:142068ms step_avg:99.14ms +step:1434/1695 train_time:142169ms step_avg:99.14ms +step:1435/1695 train_time:142272ms step_avg:99.14ms +step:1436/1695 train_time:142375ms step_avg:99.15ms +step:1437/1695 train_time:142477ms step_avg:99.15ms +step:1438/1695 train_time:142577ms step_avg:99.15ms +step:1439/1695 train_time:142680ms step_avg:99.15ms +step:1440/1695 train_time:142782ms step_avg:99.15ms +step:1441/1695 train_time:142885ms step_avg:99.16ms +step:1442/1695 
train_time:142986ms step_avg:99.16ms +step:1443/1695 train_time:143088ms step_avg:99.16ms +step:1444/1695 train_time:143190ms step_avg:99.16ms +step:1445/1695 train_time:143292ms step_avg:99.16ms +step:1446/1695 train_time:143395ms step_avg:99.17ms +step:1447/1695 train_time:143496ms step_avg:99.17ms +step:1448/1695 train_time:143600ms step_avg:99.17ms +step:1449/1695 train_time:143701ms step_avg:99.17ms +step:1450/1695 train_time:143802ms step_avg:99.17ms +step:1451/1695 train_time:143903ms step_avg:99.18ms +step:1452/1695 train_time:144005ms step_avg:99.18ms +step:1453/1695 train_time:144109ms step_avg:99.18ms +step:1454/1695 train_time:144212ms step_avg:99.18ms +step:1455/1695 train_time:144314ms step_avg:99.19ms +step:1456/1695 train_time:144416ms step_avg:99.19ms +step:1457/1695 train_time:144519ms step_avg:99.19ms +step:1458/1695 train_time:144621ms step_avg:99.19ms +step:1459/1695 train_time:144723ms step_avg:99.19ms +step:1460/1695 train_time:144825ms step_avg:99.20ms +step:1461/1695 train_time:144928ms step_avg:99.20ms +step:1462/1695 train_time:145029ms step_avg:99.20ms +step:1463/1695 train_time:145130ms step_avg:99.20ms +step:1464/1695 train_time:145232ms step_avg:99.20ms +step:1465/1695 train_time:145333ms step_avg:99.20ms +step:1466/1695 train_time:145435ms step_avg:99.21ms +step:1467/1695 train_time:145535ms step_avg:99.21ms +step:1468/1695 train_time:145638ms step_avg:99.21ms +step:1469/1695 train_time:145740ms step_avg:99.21ms +step:1470/1695 train_time:145842ms step_avg:99.21ms +step:1471/1695 train_time:145944ms step_avg:99.21ms +step:1472/1695 train_time:146046ms step_avg:99.22ms +step:1473/1695 train_time:146148ms step_avg:99.22ms +step:1474/1695 train_time:146249ms step_avg:99.22ms +step:1475/1695 train_time:146351ms step_avg:99.22ms +step:1476/1695 train_time:146454ms step_avg:99.22ms +step:1477/1695 train_time:146556ms step_avg:99.23ms +step:1478/1695 train_time:146657ms step_avg:99.23ms +step:1479/1695 train_time:146759ms step_avg:99.23ms +step:1480/1695 train_time:146860ms step_avg:99.23ms +step:1481/1695 train_time:146962ms step_avg:99.23ms +step:1482/1695 train_time:147064ms step_avg:99.23ms +step:1483/1695 train_time:147166ms step_avg:99.24ms +step:1484/1695 train_time:147269ms step_avg:99.24ms +step:1485/1695 train_time:147371ms step_avg:99.24ms +step:1486/1695 train_time:147473ms step_avg:99.24ms +step:1487/1695 train_time:147574ms step_avg:99.24ms +step:1488/1695 train_time:147676ms step_avg:99.24ms +step:1489/1695 train_time:147778ms step_avg:99.25ms +step:1490/1695 train_time:147880ms step_avg:99.25ms +step:1491/1695 train_time:147982ms step_avg:99.25ms +step:1492/1695 train_time:148083ms step_avg:99.25ms +step:1493/1695 train_time:148185ms step_avg:99.25ms +step:1494/1695 train_time:148288ms step_avg:99.26ms +step:1495/1695 train_time:148391ms step_avg:99.26ms +step:1496/1695 train_time:148492ms step_avg:99.26ms +step:1497/1695 train_time:148593ms step_avg:99.26ms +step:1498/1695 train_time:148695ms step_avg:99.26ms +step:1499/1695 train_time:148796ms step_avg:99.26ms +step:1500/1695 train_time:148898ms step_avg:99.27ms +step:1500/1695 val_loss:3.3182 train_time:148996ms step_avg:99.33ms +step:1501/1695 train_time:149025ms step_avg:99.28ms +step:1502/1695 train_time:149112ms step_avg:99.28ms +step:1503/1695 train_time:149215ms step_avg:99.28ms +step:1504/1695 train_time:149315ms step_avg:99.28ms +step:1505/1695 train_time:149418ms step_avg:99.28ms +step:1506/1695 train_time:149519ms step_avg:99.28ms +step:1507/1695 train_time:149621ms step_avg:99.28ms 
+step:1508/1695 train_time:149721ms step_avg:99.28ms +step:1509/1695 train_time:149823ms step_avg:99.29ms +step:1510/1695 train_time:149925ms step_avg:99.29ms +step:1511/1695 train_time:150029ms step_avg:99.29ms +step:1512/1695 train_time:150131ms step_avg:99.29ms +step:1513/1695 train_time:150233ms step_avg:99.29ms +step:1514/1695 train_time:150335ms step_avg:99.30ms +step:1515/1695 train_time:150441ms step_avg:99.30ms +step:1516/1695 train_time:150542ms step_avg:99.30ms +step:1517/1695 train_time:150642ms step_avg:99.30ms +step:1518/1695 train_time:150743ms step_avg:99.30ms +step:1519/1695 train_time:150847ms step_avg:99.31ms +step:1520/1695 train_time:150949ms step_avg:99.31ms +step:1521/1695 train_time:151050ms step_avg:99.31ms +step:1522/1695 train_time:151151ms step_avg:99.31ms +step:1523/1695 train_time:151253ms step_avg:99.31ms +step:1524/1695 train_time:151357ms step_avg:99.32ms +step:1525/1695 train_time:151461ms step_avg:99.32ms +step:1526/1695 train_time:151563ms step_avg:99.32ms +step:1527/1695 train_time:151665ms step_avg:99.32ms +step:1528/1695 train_time:151770ms step_avg:99.33ms +step:1529/1695 train_time:151872ms step_avg:99.33ms +step:1530/1695 train_time:151976ms step_avg:99.33ms +step:1531/1695 train_time:152078ms step_avg:99.33ms +step:1532/1695 train_time:152180ms step_avg:99.33ms +step:1533/1695 train_time:152282ms step_avg:99.34ms +step:1534/1695 train_time:152383ms step_avg:99.34ms +step:1535/1695 train_time:152486ms step_avg:99.34ms +step:1536/1695 train_time:152587ms step_avg:99.34ms +step:1537/1695 train_time:152688ms step_avg:99.34ms +step:1538/1695 train_time:152790ms step_avg:99.34ms +step:1539/1695 train_time:152892ms step_avg:99.34ms +step:1540/1695 train_time:152996ms step_avg:99.35ms +step:1541/1695 train_time:153099ms step_avg:99.35ms +step:1542/1695 train_time:153203ms step_avg:99.35ms +step:1543/1695 train_time:153305ms step_avg:99.36ms +step:1544/1695 train_time:153407ms step_avg:99.36ms +step:1545/1695 train_time:153509ms step_avg:99.36ms +step:1546/1695 train_time:153610ms step_avg:99.36ms +step:1547/1695 train_time:153712ms step_avg:99.36ms +step:1548/1695 train_time:153814ms step_avg:99.36ms +step:1549/1695 train_time:153916ms step_avg:99.36ms +step:1550/1695 train_time:154019ms step_avg:99.37ms +step:1551/1695 train_time:154121ms step_avg:99.37ms +step:1552/1695 train_time:154223ms step_avg:99.37ms +step:1553/1695 train_time:154325ms step_avg:99.37ms +step:1554/1695 train_time:154426ms step_avg:99.37ms +step:1555/1695 train_time:154528ms step_avg:99.37ms +step:1556/1695 train_time:154630ms step_avg:99.38ms +step:1557/1695 train_time:154734ms step_avg:99.38ms +step:1558/1695 train_time:154837ms step_avg:99.38ms +step:1559/1695 train_time:154940ms step_avg:99.38ms +step:1560/1695 train_time:155042ms step_avg:99.39ms +step:1561/1695 train_time:155144ms step_avg:99.39ms +step:1562/1695 train_time:155246ms step_avg:99.39ms +step:1563/1695 train_time:155350ms step_avg:99.39ms +step:1564/1695 train_time:155451ms step_avg:99.39ms +step:1565/1695 train_time:155553ms step_avg:99.39ms +step:1566/1695 train_time:155655ms step_avg:99.40ms +step:1567/1695 train_time:155755ms step_avg:99.40ms +step:1568/1695 train_time:155857ms step_avg:99.40ms +step:1569/1695 train_time:155958ms step_avg:99.40ms +step:1570/1695 train_time:156062ms step_avg:99.40ms +step:1571/1695 train_time:156163ms step_avg:99.40ms +step:1572/1695 train_time:156265ms step_avg:99.41ms +step:1573/1695 train_time:156367ms step_avg:99.41ms +step:1574/1695 train_time:156468ms step_avg:99.41ms 
+step:1575/1695 train_time:156568ms step_avg:99.41ms +step:1576/1695 train_time:156670ms step_avg:99.41ms +step:1577/1695 train_time:156773ms step_avg:99.41ms +step:1578/1695 train_time:156875ms step_avg:99.41ms +step:1579/1695 train_time:156978ms step_avg:99.42ms +step:1580/1695 train_time:157082ms step_avg:99.42ms +step:1581/1695 train_time:157184ms step_avg:99.42ms +step:1582/1695 train_time:157285ms step_avg:99.42ms +step:1583/1695 train_time:157389ms step_avg:99.42ms +step:1584/1695 train_time:157492ms step_avg:99.43ms +step:1585/1695 train_time:157593ms step_avg:99.43ms +step:1586/1695 train_time:157696ms step_avg:99.43ms +step:1587/1695 train_time:157798ms step_avg:99.43ms +step:1588/1695 train_time:157899ms step_avg:99.43ms +step:1589/1695 train_time:158000ms step_avg:99.43ms +step:1590/1695 train_time:158102ms step_avg:99.44ms +step:1591/1695 train_time:158205ms step_avg:99.44ms +step:1592/1695 train_time:158307ms step_avg:99.44ms +step:1593/1695 train_time:158408ms step_avg:99.44ms +step:1594/1695 train_time:158512ms step_avg:99.44ms +step:1595/1695 train_time:158614ms step_avg:99.44ms +step:1596/1695 train_time:158715ms step_avg:99.45ms +step:1597/1695 train_time:158818ms step_avg:99.45ms +step:1598/1695 train_time:158922ms step_avg:99.45ms +step:1599/1695 train_time:159023ms step_avg:99.45ms +step:1600/1695 train_time:159124ms step_avg:99.45ms +step:1601/1695 train_time:159227ms step_avg:99.45ms +step:1602/1695 train_time:159329ms step_avg:99.46ms +step:1603/1695 train_time:159429ms step_avg:99.46ms +step:1604/1695 train_time:159531ms step_avg:99.46ms +step:1605/1695 train_time:159633ms step_avg:99.46ms +step:1606/1695 train_time:159735ms step_avg:99.46ms +step:1607/1695 train_time:159836ms step_avg:99.46ms +step:1608/1695 train_time:159939ms step_avg:99.46ms +step:1609/1695 train_time:160040ms step_avg:99.47ms +step:1610/1695 train_time:160143ms step_avg:99.47ms +step:1611/1695 train_time:160245ms step_avg:99.47ms +step:1612/1695 train_time:160347ms step_avg:99.47ms +step:1613/1695 train_time:160449ms step_avg:99.47ms +step:1614/1695 train_time:160549ms step_avg:99.47ms +step:1615/1695 train_time:160651ms step_avg:99.47ms +step:1616/1695 train_time:160752ms step_avg:99.48ms +step:1617/1695 train_time:160854ms step_avg:99.48ms +step:1618/1695 train_time:160956ms step_avg:99.48ms +step:1619/1695 train_time:161058ms step_avg:99.48ms +step:1620/1695 train_time:161162ms step_avg:99.48ms +step:1621/1695 train_time:161263ms step_avg:99.48ms +step:1622/1695 train_time:161365ms step_avg:99.49ms +step:1623/1695 train_time:161467ms step_avg:99.49ms +step:1624/1695 train_time:161569ms step_avg:99.49ms +step:1625/1695 train_time:161673ms step_avg:99.49ms +step:1625/1695 val_loss:3.2898 train_time:161771ms step_avg:99.55ms +step:1626/1695 train_time:161800ms step_avg:99.51ms +step:1627/1695 train_time:161886ms step_avg:99.50ms +step:1628/1695 train_time:161988ms step_avg:99.50ms +step:1629/1695 train_time:162092ms step_avg:99.50ms +step:1630/1695 train_time:162193ms step_avg:99.51ms +step:1631/1695 train_time:162295ms step_avg:99.51ms +step:1632/1695 train_time:162397ms step_avg:99.51ms +step:1633/1695 train_time:162498ms step_avg:99.51ms +step:1634/1695 train_time:162600ms step_avg:99.51ms +step:1635/1695 train_time:162702ms step_avg:99.51ms +step:1636/1695 train_time:162805ms step_avg:99.51ms +step:1637/1695 train_time:162908ms step_avg:99.52ms +step:1638/1695 train_time:163010ms step_avg:99.52ms +step:1639/1695 train_time:163113ms step_avg:99.52ms +step:1640/1695 train_time:163216ms 
step_avg:99.52ms +step:1641/1695 train_time:163318ms step_avg:99.52ms +step:1642/1695 train_time:163421ms step_avg:99.53ms +step:1643/1695 train_time:163522ms step_avg:99.53ms +step:1644/1695 train_time:163624ms step_avg:99.53ms +step:1645/1695 train_time:163727ms step_avg:99.53ms +step:1646/1695 train_time:163830ms step_avg:99.53ms +step:1647/1695 train_time:163934ms step_avg:99.54ms +step:1648/1695 train_time:164040ms step_avg:99.54ms +step:1649/1695 train_time:164142ms step_avg:99.54ms +step:1650/1695 train_time:164244ms step_avg:99.54ms +step:1651/1695 train_time:164345ms step_avg:99.54ms +step:1652/1695 train_time:164447ms step_avg:99.54ms +step:1653/1695 train_time:164550ms step_avg:99.55ms +step:1654/1695 train_time:164654ms step_avg:99.55ms +step:1655/1695 train_time:164756ms step_avg:99.55ms +step:1656/1695 train_time:164859ms step_avg:99.55ms +step:1657/1695 train_time:164961ms step_avg:99.55ms +step:1658/1695 train_time:165064ms step_avg:99.56ms +step:1659/1695 train_time:165170ms step_avg:99.56ms +step:1660/1695 train_time:165272ms step_avg:99.56ms +step:1661/1695 train_time:165377ms step_avg:99.56ms +step:1662/1695 train_time:165482ms step_avg:99.57ms +step:1663/1695 train_time:165584ms step_avg:99.57ms +step:1664/1695 train_time:165686ms step_avg:99.57ms +step:1665/1695 train_time:165795ms step_avg:99.58ms +step:1666/1695 train_time:165898ms step_avg:99.58ms +step:1667/1695 train_time:166001ms step_avg:99.58ms +step:1668/1695 train_time:166107ms step_avg:99.58ms +step:1669/1695 train_time:166211ms step_avg:99.59ms +step:1670/1695 train_time:166313ms step_avg:99.59ms +step:1671/1695 train_time:166416ms step_avg:99.59ms +step:1672/1695 train_time:166520ms step_avg:99.59ms +step:1673/1695 train_time:166622ms step_avg:99.59ms +step:1674/1695 train_time:166724ms step_avg:99.60ms +step:1675/1695 train_time:166826ms step_avg:99.60ms +step:1676/1695 train_time:166931ms step_avg:99.60ms +step:1677/1695 train_time:167033ms step_avg:99.60ms +step:1678/1695 train_time:167137ms step_avg:99.60ms +step:1679/1695 train_time:167240ms step_avg:99.61ms +step:1680/1695 train_time:167342ms step_avg:99.61ms +step:1681/1695 train_time:167444ms step_avg:99.61ms +step:1682/1695 train_time:167549ms step_avg:99.61ms +step:1683/1695 train_time:167651ms step_avg:99.61ms +step:1684/1695 train_time:167754ms step_avg:99.62ms +step:1685/1695 train_time:167857ms step_avg:99.62ms +step:1686/1695 train_time:167960ms step_avg:99.62ms +step:1687/1695 train_time:168062ms step_avg:99.62ms +step:1688/1695 train_time:168165ms step_avg:99.62ms +step:1689/1695 train_time:168266ms step_avg:99.62ms +step:1690/1695 train_time:168369ms step_avg:99.63ms +step:1691/1695 train_time:168472ms step_avg:99.63ms +step:1692/1695 train_time:168573ms step_avg:99.63ms +step:1693/1695 train_time:168676ms step_avg:99.63ms +step:1694/1695 train_time:168779ms step_avg:99.63ms +step:1695/1695 train_time:168882ms step_avg:99.64ms +step:1695/1695 val_loss:3.2767 train_time:168981ms step_avg:99.69ms +peak memory allocated: 34077 MiB reserved: 49660 MiB diff --git a/records/082325_SparseAttnGate/4518e917-cec2-4c81-9c1a-53b0644c2326.txt b/records/082325_SparseAttnGate/4518e917-cec2-4c81-9c1a-53b0644c2326.txt new file mode 100644 index 000000000..f79b4ae11 --- /dev/null +++ b/records/082325_SparseAttnGate/4518e917-cec2-4c81-9c1a-53b0644c2326.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
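+        # advance the A.T view by the same K-block; both pointer arrays step along A's
+        # column stride, since A.T is just A read with row/column strides swapped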
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
+ self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256]) + # docs = increments.cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) + + def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(block_masks) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +# find world_size starting indicies, such that each begins with token 50256 and local_batches don't overlap +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 + boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos + start = boundary_positions[0].item() + starts = [] + for i in range(1, len(boundary_positions)): + end = boundary_positions[i].item() + if end - start >= seq_len: + starts.append(start) # append start once end pos is confirmed + if len(starts) == dist.get_world_size(): + return starts, end - pos + start = end + assert False # increase token_window if necessary + +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): + rank = dist.get_rank() + world_size = dist.get_world_size() + batch_size = seq_len * world_size + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + while True: + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length 
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:44:03 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 38C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 317064 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 317065 C /usr/bin/python3 614MiB | +| 0 N/A N/A 317066 C /usr/bin/python3 614MiB | +| 0 N/A N/A 317067 C /usr/bin/python3 614MiB | +| 0 N/A N/A 317068 C /usr/bin/python3 614MiB | +| 0 N/A N/A 317069 C /usr/bin/python3 614MiB | +| 0 N/A N/A 317070 C /usr/bin/python3 614MiB | +| 0 N/A N/A 317071 C /usr/bin/python3 614MiB | +| 1 N/A N/A 317065 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 317066 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 317067 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 317068 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 317069 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 317070 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 317071 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:155ms step_avg:155.10ms +step:2/1695 train_time:182ms step_avg:91.07ms +step:3/1695 train_time:253ms step_avg:84.30ms +step:4/1695 train_time:345ms step_avg:86.24ms +step:5/1695 train_time:437ms step_avg:87.38ms +step:6/1695 train_time:530ms step_avg:88.29ms +step:7/1695 train_time:623ms step_avg:88.94ms +step:8/1695 
train_time:715ms step_avg:89.43ms +step:9/1695 train_time:808ms step_avg:89.82ms +step:10/1695 train_time:901ms step_avg:90.13ms +step:11/1695 train_time:994ms step_avg:90.37ms +step:12/1695 train_time:1088ms step_avg:90.68ms +step:13/1695 train_time:1184ms step_avg:91.05ms +step:14/1695 train_time:1278ms step_avg:91.30ms +step:15/1695 train_time:1372ms step_avg:91.47ms +step:16/1695 train_time:1466ms step_avg:91.61ms +step:17/1695 train_time:1559ms step_avg:91.71ms +step:18/1695 train_time:1652ms step_avg:91.77ms +step:19/1695 train_time:1746ms step_avg:91.88ms +step:20/1695 train_time:1839ms step_avg:91.93ms +step:21/1695 train_time:1932ms step_avg:91.98ms +step:22/1695 train_time:2025ms step_avg:92.03ms +step:23/1695 train_time:2120ms step_avg:92.16ms +step:24/1695 train_time:2213ms step_avg:92.20ms +step:25/1695 train_time:2307ms step_avg:92.29ms +step:26/1695 train_time:2401ms step_avg:92.35ms +step:27/1695 train_time:2495ms step_avg:92.40ms +step:28/1695 train_time:2589ms step_avg:92.45ms +step:29/1695 train_time:2682ms step_avg:92.49ms +step:30/1695 train_time:2776ms step_avg:92.53ms +step:31/1695 train_time:2870ms step_avg:92.59ms +step:32/1695 train_time:2964ms step_avg:92.62ms +step:33/1695 train_time:3057ms step_avg:92.65ms +step:34/1695 train_time:3152ms step_avg:92.70ms +step:35/1695 train_time:3247ms step_avg:92.76ms +step:36/1695 train_time:3341ms step_avg:92.81ms +step:37/1695 train_time:3436ms step_avg:92.85ms +step:38/1695 train_time:3530ms step_avg:92.89ms +step:39/1695 train_time:3624ms step_avg:92.91ms +step:40/1695 train_time:3718ms step_avg:92.95ms +step:41/1695 train_time:3812ms step_avg:92.98ms +step:42/1695 train_time:3906ms step_avg:93.01ms +step:43/1695 train_time:4001ms step_avg:93.04ms +step:44/1695 train_time:4094ms step_avg:93.04ms +step:45/1695 train_time:4187ms step_avg:93.05ms +step:46/1695 train_time:4282ms step_avg:93.08ms +step:47/1695 train_time:4376ms step_avg:93.10ms +step:48/1695 train_time:4471ms step_avg:93.14ms +step:49/1695 train_time:4565ms step_avg:93.16ms +step:50/1695 train_time:4658ms step_avg:93.17ms +step:51/1695 train_time:4752ms step_avg:93.18ms +step:52/1695 train_time:4846ms step_avg:93.18ms +step:53/1695 train_time:4940ms step_avg:93.20ms +step:54/1695 train_time:5033ms step_avg:93.20ms +step:55/1695 train_time:5126ms step_avg:93.21ms +step:56/1695 train_time:5220ms step_avg:93.21ms +step:57/1695 train_time:5313ms step_avg:93.21ms +step:58/1695 train_time:5407ms step_avg:93.22ms +step:59/1695 train_time:5501ms step_avg:93.23ms +step:60/1695 train_time:5594ms step_avg:93.24ms +step:61/1695 train_time:5689ms step_avg:93.26ms +step:62/1695 train_time:5782ms step_avg:93.27ms +step:63/1695 train_time:5876ms step_avg:93.27ms +step:64/1695 train_time:5971ms step_avg:93.29ms +step:65/1695 train_time:6064ms step_avg:93.29ms +step:66/1695 train_time:6157ms step_avg:93.29ms +step:67/1695 train_time:6251ms step_avg:93.30ms +step:68/1695 train_time:6345ms step_avg:93.31ms +step:69/1695 train_time:6438ms step_avg:93.31ms +step:70/1695 train_time:6531ms step_avg:93.31ms +step:71/1695 train_time:6625ms step_avg:93.31ms +step:72/1695 train_time:6719ms step_avg:93.32ms +step:73/1695 train_time:6812ms step_avg:93.31ms +step:74/1695 train_time:6906ms step_avg:93.32ms +step:75/1695 train_time:7000ms step_avg:93.33ms +step:76/1695 train_time:7093ms step_avg:93.33ms +step:77/1695 train_time:7188ms step_avg:93.35ms +step:78/1695 train_time:7280ms step_avg:93.33ms +step:79/1695 train_time:7373ms step_avg:93.33ms +step:80/1695 train_time:7468ms 
step_avg:93.35ms +step:81/1695 train_time:7562ms step_avg:93.36ms +step:82/1695 train_time:7656ms step_avg:93.36ms +step:83/1695 train_time:7750ms step_avg:93.37ms +step:84/1695 train_time:7844ms step_avg:93.38ms +step:85/1695 train_time:7937ms step_avg:93.38ms +step:86/1695 train_time:8031ms step_avg:93.38ms +step:87/1695 train_time:8124ms step_avg:93.38ms +step:88/1695 train_time:8217ms step_avg:93.38ms +step:89/1695 train_time:8311ms step_avg:93.38ms +step:90/1695 train_time:8405ms step_avg:93.39ms +step:91/1695 train_time:8498ms step_avg:93.39ms +step:92/1695 train_time:8591ms step_avg:93.39ms +step:93/1695 train_time:8685ms step_avg:93.39ms +step:94/1695 train_time:8780ms step_avg:93.40ms +step:95/1695 train_time:8873ms step_avg:93.40ms +step:96/1695 train_time:8967ms step_avg:93.41ms +step:97/1695 train_time:9062ms step_avg:93.42ms +step:98/1695 train_time:9155ms step_avg:93.41ms +step:99/1695 train_time:9248ms step_avg:93.41ms +step:100/1695 train_time:9342ms step_avg:93.42ms +step:101/1695 train_time:9435ms step_avg:93.42ms +step:102/1695 train_time:9530ms step_avg:93.43ms +step:103/1695 train_time:9624ms step_avg:93.43ms +step:104/1695 train_time:9717ms step_avg:93.43ms +step:105/1695 train_time:9810ms step_avg:93.43ms +step:106/1695 train_time:9903ms step_avg:93.43ms +step:107/1695 train_time:9996ms step_avg:93.42ms +step:108/1695 train_time:10090ms step_avg:93.42ms +step:109/1695 train_time:10184ms step_avg:93.43ms +step:110/1695 train_time:10278ms step_avg:93.44ms +step:111/1695 train_time:10371ms step_avg:93.43ms +step:112/1695 train_time:10465ms step_avg:93.44ms +step:113/1695 train_time:10558ms step_avg:93.43ms +step:114/1695 train_time:10652ms step_avg:93.43ms +step:115/1695 train_time:10746ms step_avg:93.45ms +step:116/1695 train_time:10840ms step_avg:93.45ms +step:117/1695 train_time:10934ms step_avg:93.45ms +step:118/1695 train_time:11028ms step_avg:93.46ms +step:119/1695 train_time:11122ms step_avg:93.46ms +step:120/1695 train_time:11215ms step_avg:93.46ms +step:121/1695 train_time:11309ms step_avg:93.46ms +step:122/1695 train_time:11402ms step_avg:93.46ms +step:123/1695 train_time:11495ms step_avg:93.46ms +step:124/1695 train_time:11589ms step_avg:93.46ms +step:125/1695 train_time:11683ms step_avg:93.46ms +step:125/1695 val_loss:4.6053 train_time:11775ms step_avg:94.20ms +step:126/1695 train_time:11803ms step_avg:93.67ms +step:127/1695 train_time:11879ms step_avg:93.54ms +step:128/1695 train_time:11981ms step_avg:93.60ms +step:129/1695 train_time:12078ms step_avg:93.63ms +step:130/1695 train_time:12172ms step_avg:93.63ms +step:131/1695 train_time:12265ms step_avg:93.63ms +step:132/1695 train_time:12358ms step_avg:93.62ms +step:133/1695 train_time:12452ms step_avg:93.62ms +step:134/1695 train_time:12545ms step_avg:93.62ms +step:135/1695 train_time:12638ms step_avg:93.62ms +step:136/1695 train_time:12732ms step_avg:93.61ms +step:137/1695 train_time:12826ms step_avg:93.62ms +step:138/1695 train_time:12921ms step_avg:93.63ms +step:139/1695 train_time:13017ms step_avg:93.64ms +step:140/1695 train_time:13112ms step_avg:93.66ms +step:141/1695 train_time:13207ms step_avg:93.66ms +step:142/1695 train_time:13300ms step_avg:93.66ms +step:143/1695 train_time:13394ms step_avg:93.66ms +step:144/1695 train_time:13488ms step_avg:93.67ms +step:145/1695 train_time:13581ms step_avg:93.67ms +step:146/1695 train_time:13675ms step_avg:93.66ms +step:147/1695 train_time:13769ms step_avg:93.66ms +step:148/1695 train_time:13862ms step_avg:93.66ms +step:149/1695 train_time:13957ms 
step_avg:93.67ms +step:150/1695 train_time:14053ms step_avg:93.68ms +step:151/1695 train_time:14148ms step_avg:93.69ms +step:152/1695 train_time:14242ms step_avg:93.70ms +step:153/1695 train_time:14335ms step_avg:93.69ms +step:154/1695 train_time:14430ms step_avg:93.70ms +step:155/1695 train_time:14523ms step_avg:93.70ms +step:156/1695 train_time:14617ms step_avg:93.70ms +step:157/1695 train_time:14712ms step_avg:93.70ms +step:158/1695 train_time:14806ms step_avg:93.71ms +step:159/1695 train_time:14900ms step_avg:93.71ms +step:160/1695 train_time:14993ms step_avg:93.71ms +step:161/1695 train_time:15087ms step_avg:93.71ms +step:162/1695 train_time:15182ms step_avg:93.72ms +step:163/1695 train_time:15276ms step_avg:93.72ms +step:164/1695 train_time:15370ms step_avg:93.72ms +step:165/1695 train_time:15464ms step_avg:93.72ms +step:166/1695 train_time:15558ms step_avg:93.72ms +step:167/1695 train_time:15653ms step_avg:93.73ms +step:168/1695 train_time:15747ms step_avg:93.73ms +step:169/1695 train_time:15840ms step_avg:93.73ms +step:170/1695 train_time:15935ms step_avg:93.73ms +step:171/1695 train_time:16029ms step_avg:93.74ms +step:172/1695 train_time:16124ms step_avg:93.74ms +step:173/1695 train_time:16218ms step_avg:93.74ms +step:174/1695 train_time:16312ms step_avg:93.75ms +step:175/1695 train_time:16405ms step_avg:93.74ms +step:176/1695 train_time:16499ms step_avg:93.75ms +step:177/1695 train_time:16593ms step_avg:93.75ms +step:178/1695 train_time:16687ms step_avg:93.75ms +step:179/1695 train_time:16781ms step_avg:93.75ms +step:180/1695 train_time:16875ms step_avg:93.75ms +step:181/1695 train_time:16969ms step_avg:93.75ms +step:182/1695 train_time:17064ms step_avg:93.76ms +step:183/1695 train_time:17157ms step_avg:93.76ms +step:184/1695 train_time:17252ms step_avg:93.76ms +step:185/1695 train_time:17348ms step_avg:93.77ms +step:186/1695 train_time:17441ms step_avg:93.77ms +step:187/1695 train_time:17536ms step_avg:93.77ms +step:188/1695 train_time:17629ms step_avg:93.77ms +step:189/1695 train_time:17723ms step_avg:93.77ms +step:190/1695 train_time:17817ms step_avg:93.78ms +step:191/1695 train_time:17911ms step_avg:93.78ms +step:192/1695 train_time:18005ms step_avg:93.78ms +step:193/1695 train_time:18099ms step_avg:93.78ms +step:194/1695 train_time:18193ms step_avg:93.78ms +step:195/1695 train_time:18288ms step_avg:93.78ms +step:196/1695 train_time:18381ms step_avg:93.78ms +step:197/1695 train_time:18476ms step_avg:93.78ms +step:198/1695 train_time:18570ms step_avg:93.79ms +step:199/1695 train_time:18664ms step_avg:93.79ms +step:200/1695 train_time:18758ms step_avg:93.79ms +step:201/1695 train_time:18853ms step_avg:93.79ms +step:202/1695 train_time:18948ms step_avg:93.80ms +step:203/1695 train_time:19043ms step_avg:93.81ms +step:204/1695 train_time:19136ms step_avg:93.80ms +step:205/1695 train_time:19230ms step_avg:93.81ms +step:206/1695 train_time:19325ms step_avg:93.81ms +step:207/1695 train_time:19418ms step_avg:93.81ms +step:208/1695 train_time:19512ms step_avg:93.81ms +step:209/1695 train_time:19606ms step_avg:93.81ms +step:210/1695 train_time:19700ms step_avg:93.81ms +step:211/1695 train_time:19795ms step_avg:93.81ms +step:212/1695 train_time:19888ms step_avg:93.81ms +step:213/1695 train_time:19982ms step_avg:93.81ms +step:214/1695 train_time:20077ms step_avg:93.82ms +step:215/1695 train_time:20171ms step_avg:93.82ms +step:216/1695 train_time:20266ms step_avg:93.82ms +step:217/1695 train_time:20359ms step_avg:93.82ms +step:218/1695 train_time:20454ms step_avg:93.82ms +step:219/1695 
train_time:20548ms step_avg:93.83ms +step:220/1695 train_time:20642ms step_avg:93.83ms +step:221/1695 train_time:20736ms step_avg:93.83ms +step:222/1695 train_time:20831ms step_avg:93.83ms +step:223/1695 train_time:20926ms step_avg:93.84ms +step:224/1695 train_time:21019ms step_avg:93.84ms +step:225/1695 train_time:21113ms step_avg:93.84ms +step:226/1695 train_time:21208ms step_avg:93.84ms +step:227/1695 train_time:21301ms step_avg:93.84ms +step:228/1695 train_time:21396ms step_avg:93.84ms +step:229/1695 train_time:21489ms step_avg:93.84ms +step:230/1695 train_time:21583ms step_avg:93.84ms +step:231/1695 train_time:21677ms step_avg:93.84ms +step:232/1695 train_time:21771ms step_avg:93.84ms +step:233/1695 train_time:21864ms step_avg:93.84ms +step:234/1695 train_time:21959ms step_avg:93.84ms +step:235/1695 train_time:22053ms step_avg:93.84ms +step:236/1695 train_time:22148ms step_avg:93.85ms +step:237/1695 train_time:22242ms step_avg:93.85ms +step:238/1695 train_time:22336ms step_avg:93.85ms +step:239/1695 train_time:22431ms step_avg:93.85ms +step:240/1695 train_time:22525ms step_avg:93.85ms +step:241/1695 train_time:22619ms step_avg:93.85ms +step:242/1695 train_time:22712ms step_avg:93.85ms +step:243/1695 train_time:22806ms step_avg:93.85ms +step:244/1695 train_time:22901ms step_avg:93.86ms +step:245/1695 train_time:22995ms step_avg:93.86ms +step:246/1695 train_time:23090ms step_avg:93.86ms +step:247/1695 train_time:23184ms step_avg:93.86ms +step:248/1695 train_time:23277ms step_avg:93.86ms +step:249/1695 train_time:23373ms step_avg:93.87ms +step:250/1695 train_time:23467ms step_avg:93.87ms +step:250/1695 val_loss:4.0744 train_time:23559ms step_avg:94.24ms +step:251/1695 train_time:23587ms step_avg:93.97ms +step:252/1695 train_time:23664ms step_avg:93.90ms +step:253/1695 train_time:23762ms step_avg:93.92ms +step:254/1695 train_time:23857ms step_avg:93.93ms +step:255/1695 train_time:23951ms step_avg:93.93ms +step:256/1695 train_time:24045ms step_avg:93.92ms +step:257/1695 train_time:24138ms step_avg:93.92ms +step:258/1695 train_time:24232ms step_avg:93.92ms +step:259/1695 train_time:24327ms step_avg:93.93ms +step:260/1695 train_time:24419ms step_avg:93.92ms +step:261/1695 train_time:24514ms step_avg:93.92ms +step:262/1695 train_time:24609ms step_avg:93.93ms +step:263/1695 train_time:24706ms step_avg:93.94ms +step:264/1695 train_time:24801ms step_avg:93.94ms +step:265/1695 train_time:24897ms step_avg:93.95ms +step:266/1695 train_time:24992ms step_avg:93.96ms +step:267/1695 train_time:25087ms step_avg:93.96ms +step:268/1695 train_time:25180ms step_avg:93.96ms +step:269/1695 train_time:25275ms step_avg:93.96ms +step:270/1695 train_time:25369ms step_avg:93.96ms +step:271/1695 train_time:25462ms step_avg:93.96ms +step:272/1695 train_time:25557ms step_avg:93.96ms +step:273/1695 train_time:25653ms step_avg:93.97ms +step:274/1695 train_time:25748ms step_avg:93.97ms +step:275/1695 train_time:25843ms step_avg:93.97ms +step:276/1695 train_time:25937ms step_avg:93.98ms +step:277/1695 train_time:26032ms step_avg:93.98ms +step:278/1695 train_time:26127ms step_avg:93.98ms +step:279/1695 train_time:26221ms step_avg:93.98ms +step:280/1695 train_time:26315ms step_avg:93.98ms +step:281/1695 train_time:26408ms step_avg:93.98ms +step:282/1695 train_time:26502ms step_avg:93.98ms +step:283/1695 train_time:26597ms step_avg:93.98ms +step:284/1695 train_time:26692ms step_avg:93.99ms +step:285/1695 train_time:26787ms step_avg:93.99ms +step:286/1695 train_time:26882ms step_avg:93.99ms +step:287/1695 train_time:26976ms 
step_avg:93.99ms +step:288/1695 train_time:27071ms step_avg:94.00ms +step:289/1695 train_time:27166ms step_avg:94.00ms +step:290/1695 train_time:27260ms step_avg:94.00ms +step:291/1695 train_time:27355ms step_avg:94.00ms +step:292/1695 train_time:27449ms step_avg:94.00ms +step:293/1695 train_time:27544ms step_avg:94.01ms +step:294/1695 train_time:27638ms step_avg:94.01ms +step:295/1695 train_time:27734ms step_avg:94.01ms +step:296/1695 train_time:27828ms step_avg:94.01ms +step:297/1695 train_time:27922ms step_avg:94.01ms +step:298/1695 train_time:28016ms step_avg:94.01ms +step:299/1695 train_time:28111ms step_avg:94.02ms +step:300/1695 train_time:28206ms step_avg:94.02ms +step:301/1695 train_time:28300ms step_avg:94.02ms +step:302/1695 train_time:28394ms step_avg:94.02ms +step:303/1695 train_time:28488ms step_avg:94.02ms +step:304/1695 train_time:28583ms step_avg:94.02ms +step:305/1695 train_time:28677ms step_avg:94.02ms +step:306/1695 train_time:28772ms step_avg:94.03ms +step:307/1695 train_time:28867ms step_avg:94.03ms +step:308/1695 train_time:28961ms step_avg:94.03ms +step:309/1695 train_time:29056ms step_avg:94.03ms +step:310/1695 train_time:29151ms step_avg:94.04ms +step:311/1695 train_time:29246ms step_avg:94.04ms +step:312/1695 train_time:29340ms step_avg:94.04ms +step:313/1695 train_time:29434ms step_avg:94.04ms +step:314/1695 train_time:29529ms step_avg:94.04ms +step:315/1695 train_time:29623ms step_avg:94.04ms +step:316/1695 train_time:29717ms step_avg:94.04ms +step:317/1695 train_time:29812ms step_avg:94.04ms +step:318/1695 train_time:29907ms step_avg:94.05ms +step:319/1695 train_time:30001ms step_avg:94.05ms +step:320/1695 train_time:30096ms step_avg:94.05ms +step:321/1695 train_time:30191ms step_avg:94.05ms +step:322/1695 train_time:30285ms step_avg:94.05ms +step:323/1695 train_time:30379ms step_avg:94.05ms +step:324/1695 train_time:30474ms step_avg:94.05ms +step:325/1695 train_time:30569ms step_avg:94.06ms +step:326/1695 train_time:30663ms step_avg:94.06ms +step:327/1695 train_time:30757ms step_avg:94.06ms +step:328/1695 train_time:30854ms step_avg:94.07ms +step:329/1695 train_time:30948ms step_avg:94.07ms +step:330/1695 train_time:31043ms step_avg:94.07ms +step:331/1695 train_time:31136ms step_avg:94.07ms +step:332/1695 train_time:31232ms step_avg:94.07ms +step:333/1695 train_time:31327ms step_avg:94.07ms +step:334/1695 train_time:31420ms step_avg:94.07ms +step:335/1695 train_time:31515ms step_avg:94.08ms +step:336/1695 train_time:31610ms step_avg:94.08ms +step:337/1695 train_time:31704ms step_avg:94.08ms +step:338/1695 train_time:31798ms step_avg:94.08ms +step:339/1695 train_time:31893ms step_avg:94.08ms +step:340/1695 train_time:31988ms step_avg:94.08ms +step:341/1695 train_time:32082ms step_avg:94.08ms +step:342/1695 train_time:32176ms step_avg:94.08ms +step:343/1695 train_time:32271ms step_avg:94.08ms +step:344/1695 train_time:32366ms step_avg:94.09ms +step:345/1695 train_time:32460ms step_avg:94.09ms +step:346/1695 train_time:32555ms step_avg:94.09ms +step:347/1695 train_time:32650ms step_avg:94.09ms +step:348/1695 train_time:32743ms step_avg:94.09ms +step:349/1695 train_time:32838ms step_avg:94.09ms +step:350/1695 train_time:32932ms step_avg:94.09ms +step:351/1695 train_time:33027ms step_avg:94.09ms +step:352/1695 train_time:33121ms step_avg:94.09ms +step:353/1695 train_time:33215ms step_avg:94.09ms +step:354/1695 train_time:33310ms step_avg:94.10ms +step:355/1695 train_time:33404ms step_avg:94.09ms +step:356/1695 train_time:33498ms step_avg:94.09ms +step:357/1695 
train_time:33593ms step_avg:94.10ms +step:358/1695 train_time:33687ms step_avg:94.10ms +step:359/1695 train_time:33781ms step_avg:94.10ms +step:360/1695 train_time:33876ms step_avg:94.10ms +step:361/1695 train_time:33971ms step_avg:94.10ms +step:362/1695 train_time:34066ms step_avg:94.10ms +step:363/1695 train_time:34160ms step_avg:94.10ms +step:364/1695 train_time:34255ms step_avg:94.11ms +step:365/1695 train_time:34350ms step_avg:94.11ms +step:366/1695 train_time:34444ms step_avg:94.11ms +step:367/1695 train_time:34538ms step_avg:94.11ms +step:368/1695 train_time:34633ms step_avg:94.11ms +step:369/1695 train_time:34728ms step_avg:94.11ms +step:370/1695 train_time:34822ms step_avg:94.11ms +step:371/1695 train_time:34917ms step_avg:94.12ms +step:372/1695 train_time:35012ms step_avg:94.12ms +step:373/1695 train_time:35107ms step_avg:94.12ms +step:374/1695 train_time:35202ms step_avg:94.12ms +step:375/1695 train_time:35296ms step_avg:94.12ms +step:375/1695 val_loss:3.8794 train_time:35390ms step_avg:94.37ms +step:376/1695 train_time:35418ms step_avg:94.20ms +step:377/1695 train_time:35494ms step_avg:94.15ms +step:378/1695 train_time:35595ms step_avg:94.17ms +step:379/1695 train_time:35692ms step_avg:94.17ms +step:380/1695 train_time:35789ms step_avg:94.18ms +step:381/1695 train_time:35885ms step_avg:94.19ms +step:382/1695 train_time:35980ms step_avg:94.19ms +step:383/1695 train_time:36075ms step_avg:94.19ms +step:384/1695 train_time:36172ms step_avg:94.20ms +step:385/1695 train_time:36267ms step_avg:94.20ms +step:386/1695 train_time:36363ms step_avg:94.20ms +step:387/1695 train_time:36459ms step_avg:94.21ms +step:388/1695 train_time:36556ms step_avg:94.22ms +step:389/1695 train_time:36653ms step_avg:94.22ms +step:390/1695 train_time:36750ms step_avg:94.23ms +step:391/1695 train_time:36846ms step_avg:94.24ms +step:392/1695 train_time:36943ms step_avg:94.24ms +step:393/1695 train_time:37039ms step_avg:94.25ms +step:394/1695 train_time:37135ms step_avg:94.25ms +step:395/1695 train_time:37231ms step_avg:94.26ms +step:396/1695 train_time:37327ms step_avg:94.26ms +step:397/1695 train_time:37424ms step_avg:94.27ms +step:398/1695 train_time:37521ms step_avg:94.27ms +step:399/1695 train_time:37618ms step_avg:94.28ms +step:400/1695 train_time:37714ms step_avg:94.29ms +step:401/1695 train_time:37811ms step_avg:94.29ms +step:402/1695 train_time:37909ms step_avg:94.30ms +step:403/1695 train_time:38005ms step_avg:94.31ms +step:404/1695 train_time:38101ms step_avg:94.31ms +step:405/1695 train_time:38196ms step_avg:94.31ms +step:406/1695 train_time:38292ms step_avg:94.32ms +step:407/1695 train_time:38388ms step_avg:94.32ms +step:408/1695 train_time:38485ms step_avg:94.33ms +step:409/1695 train_time:38582ms step_avg:94.33ms +step:410/1695 train_time:38679ms step_avg:94.34ms +step:411/1695 train_time:38776ms step_avg:94.34ms +step:412/1695 train_time:38872ms step_avg:94.35ms +step:413/1695 train_time:38969ms step_avg:94.36ms +step:414/1695 train_time:39065ms step_avg:94.36ms +step:415/1695 train_time:39161ms step_avg:94.36ms +step:416/1695 train_time:39256ms step_avg:94.37ms +step:417/1695 train_time:39351ms step_avg:94.37ms +step:418/1695 train_time:39448ms step_avg:94.37ms +step:419/1695 train_time:39545ms step_avg:94.38ms +step:420/1695 train_time:39643ms step_avg:94.39ms +step:421/1695 train_time:39739ms step_avg:94.39ms +step:422/1695 train_time:39834ms step_avg:94.39ms +step:423/1695 train_time:39931ms step_avg:94.40ms +step:424/1695 train_time:40027ms step_avg:94.40ms +step:425/1695 train_time:40124ms 
step_avg:94.41ms +step:426/1695 train_time:40221ms step_avg:94.42ms +step:427/1695 train_time:40317ms step_avg:94.42ms +step:428/1695 train_time:40413ms step_avg:94.42ms +step:429/1695 train_time:40509ms step_avg:94.43ms +step:430/1695 train_time:40607ms step_avg:94.44ms +step:431/1695 train_time:40703ms step_avg:94.44ms +step:432/1695 train_time:40800ms step_avg:94.44ms +step:433/1695 train_time:40896ms step_avg:94.45ms +step:434/1695 train_time:40991ms step_avg:94.45ms +step:435/1695 train_time:41087ms step_avg:94.45ms +step:436/1695 train_time:41184ms step_avg:94.46ms +step:437/1695 train_time:41280ms step_avg:94.46ms +step:438/1695 train_time:41376ms step_avg:94.47ms +step:439/1695 train_time:41472ms step_avg:94.47ms +step:440/1695 train_time:41568ms step_avg:94.47ms +step:441/1695 train_time:41665ms step_avg:94.48ms +step:442/1695 train_time:41760ms step_avg:94.48ms +step:443/1695 train_time:41856ms step_avg:94.48ms +step:444/1695 train_time:41952ms step_avg:94.49ms +step:445/1695 train_time:42048ms step_avg:94.49ms +step:446/1695 train_time:42144ms step_avg:94.49ms +step:447/1695 train_time:42241ms step_avg:94.50ms +step:448/1695 train_time:42337ms step_avg:94.50ms +step:449/1695 train_time:42433ms step_avg:94.50ms +step:450/1695 train_time:42530ms step_avg:94.51ms +step:451/1695 train_time:42626ms step_avg:94.51ms +step:452/1695 train_time:42721ms step_avg:94.52ms +step:453/1695 train_time:42817ms step_avg:94.52ms +step:454/1695 train_time:42913ms step_avg:94.52ms +step:455/1695 train_time:43010ms step_avg:94.53ms +step:456/1695 train_time:43107ms step_avg:94.53ms +step:457/1695 train_time:43203ms step_avg:94.54ms +step:458/1695 train_time:43300ms step_avg:94.54ms +step:459/1695 train_time:43396ms step_avg:94.54ms +step:460/1695 train_time:43492ms step_avg:94.55ms +step:461/1695 train_time:43589ms step_avg:94.55ms +step:462/1695 train_time:43685ms step_avg:94.56ms +step:463/1695 train_time:43782ms step_avg:94.56ms +step:464/1695 train_time:43878ms step_avg:94.56ms +step:465/1695 train_time:43973ms step_avg:94.57ms +step:466/1695 train_time:44069ms step_avg:94.57ms +step:467/1695 train_time:44165ms step_avg:94.57ms +step:468/1695 train_time:44262ms step_avg:94.58ms +step:469/1695 train_time:44358ms step_avg:94.58ms +step:470/1695 train_time:44453ms step_avg:94.58ms +step:471/1695 train_time:44549ms step_avg:94.58ms +step:472/1695 train_time:44646ms step_avg:94.59ms +step:473/1695 train_time:44742ms step_avg:94.59ms +step:474/1695 train_time:44838ms step_avg:94.60ms +step:475/1695 train_time:44934ms step_avg:94.60ms +step:476/1695 train_time:45031ms step_avg:94.60ms +step:477/1695 train_time:45128ms step_avg:94.61ms +step:478/1695 train_time:45224ms step_avg:94.61ms +step:479/1695 train_time:45321ms step_avg:94.62ms +step:480/1695 train_time:45417ms step_avg:94.62ms +step:481/1695 train_time:45513ms step_avg:94.62ms +step:482/1695 train_time:45609ms step_avg:94.63ms +step:483/1695 train_time:45707ms step_avg:94.63ms +step:484/1695 train_time:45804ms step_avg:94.64ms +step:485/1695 train_time:45900ms step_avg:94.64ms +step:486/1695 train_time:45996ms step_avg:94.64ms +step:487/1695 train_time:46092ms step_avg:94.64ms +step:488/1695 train_time:46188ms step_avg:94.65ms +step:489/1695 train_time:46284ms step_avg:94.65ms +step:490/1695 train_time:46381ms step_avg:94.65ms +step:491/1695 train_time:46477ms step_avg:94.66ms +step:492/1695 train_time:46572ms step_avg:94.66ms +step:493/1695 train_time:46669ms step_avg:94.66ms +step:494/1695 train_time:46766ms step_avg:94.67ms +step:495/1695 
train_time:46863ms step_avg:94.67ms +step:496/1695 train_time:46959ms step_avg:94.68ms +step:497/1695 train_time:47055ms step_avg:94.68ms +step:498/1695 train_time:47150ms step_avg:94.68ms +step:499/1695 train_time:47247ms step_avg:94.68ms +step:500/1695 train_time:47342ms step_avg:94.68ms +step:500/1695 val_loss:3.7347 train_time:47437ms step_avg:94.87ms +step:501/1695 train_time:47465ms step_avg:94.74ms +step:502/1695 train_time:47545ms step_avg:94.71ms +step:503/1695 train_time:47647ms step_avg:94.73ms +step:504/1695 train_time:47744ms step_avg:94.73ms +step:505/1695 train_time:47840ms step_avg:94.73ms +step:506/1695 train_time:47935ms step_avg:94.73ms +step:507/1695 train_time:48032ms step_avg:94.74ms +step:508/1695 train_time:48127ms step_avg:94.74ms +step:509/1695 train_time:48223ms step_avg:94.74ms +step:510/1695 train_time:48318ms step_avg:94.74ms +step:511/1695 train_time:48414ms step_avg:94.74ms +step:512/1695 train_time:48514ms step_avg:94.75ms +step:513/1695 train_time:48613ms step_avg:94.76ms +step:514/1695 train_time:48712ms step_avg:94.77ms +step:515/1695 train_time:48810ms step_avg:94.78ms +step:516/1695 train_time:48907ms step_avg:94.78ms +step:517/1695 train_time:49004ms step_avg:94.78ms +step:518/1695 train_time:49099ms step_avg:94.78ms +step:519/1695 train_time:49195ms step_avg:94.79ms +step:520/1695 train_time:49291ms step_avg:94.79ms +step:521/1695 train_time:49388ms step_avg:94.79ms +step:522/1695 train_time:49485ms step_avg:94.80ms +step:523/1695 train_time:49581ms step_avg:94.80ms +step:524/1695 train_time:49679ms step_avg:94.81ms +step:525/1695 train_time:49776ms step_avg:94.81ms +step:526/1695 train_time:49873ms step_avg:94.82ms +step:527/1695 train_time:49969ms step_avg:94.82ms +step:528/1695 train_time:50066ms step_avg:94.82ms +step:529/1695 train_time:50163ms step_avg:94.83ms +step:530/1695 train_time:50259ms step_avg:94.83ms +step:531/1695 train_time:50354ms step_avg:94.83ms +step:532/1695 train_time:50451ms step_avg:94.83ms +step:533/1695 train_time:50548ms step_avg:94.84ms +step:534/1695 train_time:50646ms step_avg:94.84ms +step:535/1695 train_time:50743ms step_avg:94.85ms +step:536/1695 train_time:50840ms step_avg:94.85ms +step:537/1695 train_time:50936ms step_avg:94.85ms +step:538/1695 train_time:51033ms step_avg:94.86ms +step:539/1695 train_time:51130ms step_avg:94.86ms +step:540/1695 train_time:51227ms step_avg:94.86ms +step:541/1695 train_time:51323ms step_avg:94.87ms +step:542/1695 train_time:51418ms step_avg:94.87ms +step:543/1695 train_time:51515ms step_avg:94.87ms +step:544/1695 train_time:51614ms step_avg:94.88ms +step:545/1695 train_time:51711ms step_avg:94.88ms +step:546/1695 train_time:51808ms step_avg:94.89ms +step:547/1695 train_time:51905ms step_avg:94.89ms +step:548/1695 train_time:52001ms step_avg:94.89ms +step:549/1695 train_time:52097ms step_avg:94.89ms +step:550/1695 train_time:52193ms step_avg:94.90ms +step:551/1695 train_time:52290ms step_avg:94.90ms +step:552/1695 train_time:52387ms step_avg:94.90ms +step:553/1695 train_time:52484ms step_avg:94.91ms +step:554/1695 train_time:52580ms step_avg:94.91ms +step:555/1695 train_time:52676ms step_avg:94.91ms +step:556/1695 train_time:52773ms step_avg:94.92ms +step:557/1695 train_time:52870ms step_avg:94.92ms +step:558/1695 train_time:52966ms step_avg:94.92ms +step:559/1695 train_time:53062ms step_avg:94.92ms +step:560/1695 train_time:53158ms step_avg:94.93ms +step:561/1695 train_time:53254ms step_avg:94.93ms +step:562/1695 train_time:53351ms step_avg:94.93ms +step:563/1695 train_time:53448ms 
step_avg:94.93ms +step:564/1695 train_time:53545ms step_avg:94.94ms +step:565/1695 train_time:53641ms step_avg:94.94ms +step:566/1695 train_time:53738ms step_avg:94.94ms +step:567/1695 train_time:53834ms step_avg:94.95ms +step:568/1695 train_time:53931ms step_avg:94.95ms +step:569/1695 train_time:54028ms step_avg:94.95ms +step:570/1695 train_time:54125ms step_avg:94.96ms +step:571/1695 train_time:54469ms step_avg:95.39ms +step:572/1695 train_time:54563ms step_avg:95.39ms +step:573/1695 train_time:54659ms step_avg:95.39ms +step:574/1695 train_time:54754ms step_avg:95.39ms +step:575/1695 train_time:54850ms step_avg:95.39ms +step:576/1695 train_time:54946ms step_avg:95.39ms +step:577/1695 train_time:55041ms step_avg:95.39ms +step:578/1695 train_time:55137ms step_avg:95.39ms +step:579/1695 train_time:55233ms step_avg:95.39ms +step:580/1695 train_time:55329ms step_avg:95.39ms +step:581/1695 train_time:55428ms step_avg:95.40ms +step:582/1695 train_time:55526ms step_avg:95.41ms +step:583/1695 train_time:55624ms step_avg:95.41ms +step:584/1695 train_time:55721ms step_avg:95.41ms +step:585/1695 train_time:55817ms step_avg:95.41ms +step:586/1695 train_time:55913ms step_avg:95.41ms +step:587/1695 train_time:56009ms step_avg:95.42ms +step:588/1695 train_time:56106ms step_avg:95.42ms +step:589/1695 train_time:56202ms step_avg:95.42ms +step:590/1695 train_time:56298ms step_avg:95.42ms +step:591/1695 train_time:56395ms step_avg:95.42ms +step:592/1695 train_time:56493ms step_avg:95.43ms +step:593/1695 train_time:56591ms step_avg:95.43ms +step:594/1695 train_time:56689ms step_avg:95.44ms +step:595/1695 train_time:56787ms step_avg:95.44ms +step:596/1695 train_time:56883ms step_avg:95.44ms +step:597/1695 train_time:56979ms step_avg:95.44ms +step:598/1695 train_time:57075ms step_avg:95.44ms +step:599/1695 train_time:57172ms step_avg:95.45ms +step:600/1695 train_time:57268ms step_avg:95.45ms +step:601/1695 train_time:57365ms step_avg:95.45ms +step:602/1695 train_time:57461ms step_avg:95.45ms +step:603/1695 train_time:57559ms step_avg:95.45ms +step:604/1695 train_time:57656ms step_avg:95.46ms +step:605/1695 train_time:57755ms step_avg:95.46ms +step:606/1695 train_time:57852ms step_avg:95.46ms +step:607/1695 train_time:57949ms step_avg:95.47ms +step:608/1695 train_time:58045ms step_avg:95.47ms +step:609/1695 train_time:58141ms step_avg:95.47ms +step:610/1695 train_time:58236ms step_avg:95.47ms +step:611/1695 train_time:58333ms step_avg:95.47ms +step:612/1695 train_time:58429ms step_avg:95.47ms +step:613/1695 train_time:58527ms step_avg:95.48ms +step:614/1695 train_time:58623ms step_avg:95.48ms +step:615/1695 train_time:58720ms step_avg:95.48ms +step:616/1695 train_time:58816ms step_avg:95.48ms +step:617/1695 train_time:58914ms step_avg:95.48ms +step:618/1695 train_time:59012ms step_avg:95.49ms +step:619/1695 train_time:59109ms step_avg:95.49ms +step:620/1695 train_time:59206ms step_avg:95.49ms +step:621/1695 train_time:59301ms step_avg:95.49ms +step:622/1695 train_time:59396ms step_avg:95.49ms +step:623/1695 train_time:59493ms step_avg:95.49ms +step:624/1695 train_time:59590ms step_avg:95.50ms +step:625/1695 train_time:59688ms step_avg:95.50ms +step:625/1695 val_loss:3.6510 train_time:59783ms step_avg:95.65ms +step:626/1695 train_time:59811ms step_avg:95.54ms +step:627/1695 train_time:59894ms step_avg:95.52ms +step:628/1695 train_time:59994ms step_avg:95.53ms +step:629/1695 train_time:60306ms step_avg:95.88ms +step:630/1695 train_time:60402ms step_avg:95.88ms +step:631/1695 train_time:60498ms step_avg:95.88ms 
+step:632/1695 train_time:60595ms step_avg:95.88ms +step:633/1695 train_time:60693ms step_avg:95.88ms +step:634/1695 train_time:60790ms step_avg:95.88ms +step:635/1695 train_time:60887ms step_avg:95.89ms +step:636/1695 train_time:60984ms step_avg:95.89ms +step:637/1695 train_time:61082ms step_avg:95.89ms +step:638/1695 train_time:61179ms step_avg:95.89ms +step:639/1695 train_time:61280ms step_avg:95.90ms +step:640/1695 train_time:61379ms step_avg:95.90ms +step:641/1695 train_time:61477ms step_avg:95.91ms +step:642/1695 train_time:61574ms step_avg:95.91ms +step:643/1695 train_time:61672ms step_avg:95.91ms +step:644/1695 train_time:61769ms step_avg:95.92ms +step:645/1695 train_time:61867ms step_avg:95.92ms +step:646/1695 train_time:61964ms step_avg:95.92ms +step:647/1695 train_time:62061ms step_avg:95.92ms +step:648/1695 train_time:62159ms step_avg:95.92ms +step:649/1695 train_time:62257ms step_avg:95.93ms +step:650/1695 train_time:62356ms step_avg:95.93ms +step:651/1695 train_time:62453ms step_avg:95.93ms +step:652/1695 train_time:62552ms step_avg:95.94ms +step:653/1695 train_time:62885ms step_avg:96.30ms +step:654/1695 train_time:62980ms step_avg:96.30ms +step:655/1695 train_time:63077ms step_avg:96.30ms +step:656/1695 train_time:63175ms step_avg:96.30ms +step:657/1695 train_time:63273ms step_avg:96.31ms +step:658/1695 train_time:63370ms step_avg:96.31ms +step:659/1695 train_time:63468ms step_avg:96.31ms +step:660/1695 train_time:63564ms step_avg:96.31ms +step:661/1695 train_time:63661ms step_avg:96.31ms +step:662/1695 train_time:63759ms step_avg:96.31ms +step:663/1695 train_time:63860ms step_avg:96.32ms +step:664/1695 train_time:63959ms step_avg:96.32ms +step:665/1695 train_time:64057ms step_avg:96.33ms +step:666/1695 train_time:64155ms step_avg:96.33ms +step:667/1695 train_time:64253ms step_avg:96.33ms +step:668/1695 train_time:64351ms step_avg:96.33ms +step:669/1695 train_time:64448ms step_avg:96.33ms +step:670/1695 train_time:64545ms step_avg:96.34ms +step:671/1695 train_time:64642ms step_avg:96.34ms +step:672/1695 train_time:64740ms step_avg:96.34ms +step:673/1695 train_time:64838ms step_avg:96.34ms +step:674/1695 train_time:64938ms step_avg:96.35ms +step:675/1695 train_time:65036ms step_avg:96.35ms +step:676/1695 train_time:65134ms step_avg:96.35ms +step:677/1695 train_time:65232ms step_avg:96.35ms +step:678/1695 train_time:65330ms step_avg:96.36ms +step:679/1695 train_time:65427ms step_avg:96.36ms +step:680/1695 train_time:65524ms step_avg:96.36ms +step:681/1695 train_time:65621ms step_avg:96.36ms +step:682/1695 train_time:65718ms step_avg:96.36ms +step:683/1695 train_time:65817ms step_avg:96.36ms +step:684/1695 train_time:65916ms step_avg:96.37ms +step:685/1695 train_time:66015ms step_avg:96.37ms +step:686/1695 train_time:66114ms step_avg:96.38ms +step:687/1695 train_time:66212ms step_avg:96.38ms +step:688/1695 train_time:66311ms step_avg:96.38ms +step:689/1695 train_time:66409ms step_avg:96.39ms +step:690/1695 train_time:66507ms step_avg:96.39ms +step:691/1695 train_time:66606ms step_avg:96.39ms +step:692/1695 train_time:66703ms step_avg:96.39ms +step:693/1695 train_time:66801ms step_avg:96.39ms +step:694/1695 train_time:66899ms step_avg:96.40ms +step:695/1695 train_time:66998ms step_avg:96.40ms +step:696/1695 train_time:67096ms step_avg:96.40ms +step:697/1695 train_time:67195ms step_avg:96.41ms +step:698/1695 train_time:67294ms step_avg:96.41ms +step:699/1695 train_time:67392ms step_avg:96.41ms +step:700/1695 train_time:67490ms step_avg:96.41ms +step:701/1695 train_time:67589ms 
step_avg:96.42ms +step:702/1695 train_time:67686ms step_avg:96.42ms +step:703/1695 train_time:67786ms step_avg:96.42ms +step:704/1695 train_time:67885ms step_avg:96.43ms +step:705/1695 train_time:67984ms step_avg:96.43ms +step:706/1695 train_time:68083ms step_avg:96.43ms +step:707/1695 train_time:68180ms step_avg:96.44ms +step:708/1695 train_time:68278ms step_avg:96.44ms +step:709/1695 train_time:68375ms step_avg:96.44ms +step:710/1695 train_time:68473ms step_avg:96.44ms +step:711/1695 train_time:68570ms step_avg:96.44ms +step:712/1695 train_time:68668ms step_avg:96.44ms +step:713/1695 train_time:68767ms step_avg:96.45ms +step:714/1695 train_time:69177ms step_avg:96.89ms +step:715/1695 train_time:69271ms step_avg:96.88ms +step:716/1695 train_time:69368ms step_avg:96.88ms +step:717/1695 train_time:69465ms step_avg:96.88ms +step:718/1695 train_time:69561ms step_avg:96.88ms +step:719/1695 train_time:69658ms step_avg:96.88ms +step:720/1695 train_time:70021ms step_avg:97.25ms +step:721/1695 train_time:70117ms step_avg:97.25ms +step:722/1695 train_time:70214ms step_avg:97.25ms +step:723/1695 train_time:70311ms step_avg:97.25ms +step:724/1695 train_time:70408ms step_avg:97.25ms +step:725/1695 train_time:70505ms step_avg:97.25ms +step:726/1695 train_time:70601ms step_avg:97.25ms +step:727/1695 train_time:70698ms step_avg:97.25ms +step:728/1695 train_time:70796ms step_avg:97.25ms +step:729/1695 train_time:70898ms step_avg:97.25ms +step:730/1695 train_time:70998ms step_avg:97.26ms +step:731/1695 train_time:71097ms step_avg:97.26ms +step:732/1695 train_time:71196ms step_avg:97.26ms +step:733/1695 train_time:71293ms step_avg:97.26ms +step:734/1695 train_time:71391ms step_avg:97.26ms +step:735/1695 train_time:71490ms step_avg:97.27ms +step:736/1695 train_time:71588ms step_avg:97.27ms +step:737/1695 train_time:71685ms step_avg:97.27ms +step:738/1695 train_time:71782ms step_avg:97.27ms +step:739/1695 train_time:71879ms step_avg:97.27ms +step:740/1695 train_time:71978ms step_avg:97.27ms +step:741/1695 train_time:72076ms step_avg:97.27ms +step:742/1695 train_time:72175ms step_avg:97.27ms +step:743/1695 train_time:72274ms step_avg:97.27ms +step:744/1695 train_time:72372ms step_avg:97.27ms +step:745/1695 train_time:72469ms step_avg:97.27ms +step:746/1695 train_time:72566ms step_avg:97.27ms +step:747/1695 train_time:72664ms step_avg:97.27ms +step:748/1695 train_time:72761ms step_avg:97.27ms +step:749/1695 train_time:72860ms step_avg:97.28ms +step:750/1695 train_time:72958ms step_avg:97.28ms +step:750/1695 val_loss:3.5898 train_time:73054ms step_avg:97.40ms +step:751/1695 train_time:73082ms step_avg:97.31ms +step:752/1695 train_time:73165ms step_avg:97.29ms +step:753/1695 train_time:73269ms step_avg:97.30ms +step:754/1695 train_time:73367ms step_avg:97.30ms +step:755/1695 train_time:73465ms step_avg:97.30ms +step:756/1695 train_time:73564ms step_avg:97.31ms +step:757/1695 train_time:73661ms step_avg:97.31ms +step:758/1695 train_time:73759ms step_avg:97.31ms +step:759/1695 train_time:73856ms step_avg:97.31ms +step:760/1695 train_time:73953ms step_avg:97.31ms +step:761/1695 train_time:74050ms step_avg:97.31ms +step:762/1695 train_time:74149ms step_avg:97.31ms +step:763/1695 train_time:74248ms step_avg:97.31ms +step:764/1695 train_time:74347ms step_avg:97.31ms +step:765/1695 train_time:74445ms step_avg:97.31ms +step:766/1695 train_time:74544ms step_avg:97.32ms +step:767/1695 train_time:74643ms step_avg:97.32ms +step:768/1695 train_time:74741ms step_avg:97.32ms +step:769/1695 train_time:74839ms step_avg:97.32ms 
+step:770/1695 train_time:74936ms step_avg:97.32ms +step:771/1695 train_time:75034ms step_avg:97.32ms +step:772/1695 train_time:75370ms step_avg:97.63ms +step:773/1695 train_time:75465ms step_avg:97.63ms +step:774/1695 train_time:75562ms step_avg:97.63ms +step:775/1695 train_time:75660ms step_avg:97.63ms +step:776/1695 train_time:75757ms step_avg:97.62ms +step:777/1695 train_time:75854ms step_avg:97.62ms +step:778/1695 train_time:75950ms step_avg:97.62ms +step:779/1695 train_time:76047ms step_avg:97.62ms +step:780/1695 train_time:76144ms step_avg:97.62ms +step:781/1695 train_time:76243ms step_avg:97.62ms +step:782/1695 train_time:76346ms step_avg:97.63ms +step:783/1695 train_time:76444ms step_avg:97.63ms +step:784/1695 train_time:76543ms step_avg:97.63ms +step:785/1695 train_time:76641ms step_avg:97.63ms +step:786/1695 train_time:76739ms step_avg:97.63ms +step:787/1695 train_time:76837ms step_avg:97.63ms +step:788/1695 train_time:76935ms step_avg:97.63ms +step:789/1695 train_time:77032ms step_avg:97.63ms +step:790/1695 train_time:77358ms step_avg:97.92ms +step:791/1695 train_time:77455ms step_avg:97.92ms +step:792/1695 train_time:77552ms step_avg:97.92ms +step:793/1695 train_time:77649ms step_avg:97.92ms +step:794/1695 train_time:77746ms step_avg:97.92ms +step:795/1695 train_time:78185ms step_avg:98.35ms +step:796/1695 train_time:78234ms step_avg:98.28ms +step:797/1695 train_time:78331ms step_avg:98.28ms +step:798/1695 train_time:78428ms step_avg:98.28ms +step:799/1695 train_time:78525ms step_avg:98.28ms +step:800/1695 train_time:78623ms step_avg:98.28ms +step:801/1695 train_time:78720ms step_avg:98.28ms +step:802/1695 train_time:78818ms step_avg:98.28ms +step:803/1695 train_time:78915ms step_avg:98.27ms +step:804/1695 train_time:79011ms step_avg:98.27ms +step:805/1695 train_time:79111ms step_avg:98.27ms +step:806/1695 train_time:79211ms step_avg:98.28ms +step:807/1695 train_time:79309ms step_avg:98.28ms +step:808/1695 train_time:79407ms step_avg:98.28ms +step:809/1695 train_time:79506ms step_avg:98.28ms +step:810/1695 train_time:79605ms step_avg:98.28ms +step:811/1695 train_time:79703ms step_avg:98.28ms +step:812/1695 train_time:79801ms step_avg:98.28ms +step:813/1695 train_time:79898ms step_avg:98.28ms +step:814/1695 train_time:79996ms step_avg:98.27ms +step:815/1695 train_time:80093ms step_avg:98.27ms +step:816/1695 train_time:80192ms step_avg:98.27ms +step:817/1695 train_time:80290ms step_avg:98.27ms +step:818/1695 train_time:80388ms step_avg:98.27ms +step:819/1695 train_time:80487ms step_avg:98.27ms +step:820/1695 train_time:80585ms step_avg:98.27ms +step:821/1695 train_time:80684ms step_avg:98.27ms +step:822/1695 train_time:80782ms step_avg:98.27ms +step:823/1695 train_time:80880ms step_avg:98.27ms +step:824/1695 train_time:80978ms step_avg:98.27ms +step:825/1695 train_time:81076ms step_avg:98.27ms +step:826/1695 train_time:81174ms step_avg:98.27ms +step:827/1695 train_time:81273ms step_avg:98.27ms +step:828/1695 train_time:81372ms step_avg:98.28ms +step:829/1695 train_time:81470ms step_avg:98.28ms +step:830/1695 train_time:81568ms step_avg:98.27ms +step:831/1695 train_time:81667ms step_avg:98.27ms +step:832/1695 train_time:81765ms step_avg:98.28ms +step:833/1695 train_time:81862ms step_avg:98.27ms +step:834/1695 train_time:81961ms step_avg:98.27ms +step:835/1695 train_time:82058ms step_avg:98.27ms +step:836/1695 train_time:82157ms step_avg:98.27ms +step:837/1695 train_time:82256ms step_avg:98.27ms +step:838/1695 train_time:82355ms step_avg:98.28ms +step:839/1695 train_time:82454ms 
step_avg:98.28ms +step:840/1695 train_time:82554ms step_avg:98.28ms +step:841/1695 train_time:82651ms step_avg:98.28ms +step:842/1695 train_time:82750ms step_avg:98.28ms +step:843/1695 train_time:82848ms step_avg:98.28ms +step:844/1695 train_time:82945ms step_avg:98.28ms +step:845/1695 train_time:83043ms step_avg:98.28ms +step:846/1695 train_time:83141ms step_avg:98.28ms +step:847/1695 train_time:83240ms step_avg:98.28ms +step:848/1695 train_time:83339ms step_avg:98.28ms +step:849/1695 train_time:83439ms step_avg:98.28ms +step:850/1695 train_time:83538ms step_avg:98.28ms +step:851/1695 train_time:83637ms step_avg:98.28ms +step:852/1695 train_time:83736ms step_avg:98.28ms +step:853/1695 train_time:83835ms step_avg:98.28ms +step:854/1695 train_time:83934ms step_avg:98.28ms +step:855/1695 train_time:84032ms step_avg:98.28ms +step:856/1695 train_time:84129ms step_avg:98.28ms +step:857/1695 train_time:84226ms step_avg:98.28ms +step:858/1695 train_time:84324ms step_avg:98.28ms +step:859/1695 train_time:84423ms step_avg:98.28ms +step:860/1695 train_time:84522ms step_avg:98.28ms +step:861/1695 train_time:84621ms step_avg:98.28ms +step:862/1695 train_time:84720ms step_avg:98.28ms +step:863/1695 train_time:84818ms step_avg:98.28ms +step:864/1695 train_time:84917ms step_avg:98.28ms +step:865/1695 train_time:85016ms step_avg:98.28ms +step:866/1695 train_time:85115ms step_avg:98.29ms +step:867/1695 train_time:85213ms step_avg:98.29ms +step:868/1695 train_time:85311ms step_avg:98.28ms +step:869/1695 train_time:85409ms step_avg:98.28ms +step:870/1695 train_time:85506ms step_avg:98.28ms +step:871/1695 train_time:85605ms step_avg:98.28ms +step:872/1695 train_time:85703ms step_avg:98.28ms +step:873/1695 train_time:85802ms step_avg:98.28ms +step:874/1695 train_time:85902ms step_avg:98.29ms +step:875/1695 train_time:86003ms step_avg:98.29ms +step:875/1695 val_loss:3.5387 train_time:86100ms step_avg:98.40ms +step:876/1695 train_time:86128ms step_avg:98.32ms +step:877/1695 train_time:86208ms step_avg:98.30ms +step:878/1695 train_time:86308ms step_avg:98.30ms +step:879/1695 train_time:86406ms step_avg:98.30ms +step:880/1695 train_time:86503ms step_avg:98.30ms +step:881/1695 train_time:86602ms step_avg:98.30ms +step:882/1695 train_time:86701ms step_avg:98.30ms +step:883/1695 train_time:86800ms step_avg:98.30ms +step:884/1695 train_time:86899ms step_avg:98.30ms +step:885/1695 train_time:86997ms step_avg:98.30ms +step:886/1695 train_time:87099ms step_avg:98.31ms +step:887/1695 train_time:87201ms step_avg:98.31ms +step:888/1695 train_time:87303ms step_avg:98.31ms +step:889/1695 train_time:87404ms step_avg:98.32ms +step:890/1695 train_time:87503ms step_avg:98.32ms +step:891/1695 train_time:87603ms step_avg:98.32ms +step:892/1695 train_time:87703ms step_avg:98.32ms +step:893/1695 train_time:87802ms step_avg:98.32ms +step:894/1695 train_time:87901ms step_avg:98.32ms +step:895/1695 train_time:88000ms step_avg:98.32ms +step:896/1695 train_time:88099ms step_avg:98.33ms +step:897/1695 train_time:88200ms step_avg:98.33ms +step:898/1695 train_time:88301ms step_avg:98.33ms +step:899/1695 train_time:88402ms step_avg:98.33ms +step:900/1695 train_time:88503ms step_avg:98.34ms +step:901/1695 train_time:88602ms step_avg:98.34ms +step:902/1695 train_time:88702ms step_avg:98.34ms +step:903/1695 train_time:88801ms step_avg:98.34ms +step:904/1695 train_time:88901ms step_avg:98.34ms +step:905/1695 train_time:89001ms step_avg:98.34ms +step:906/1695 train_time:89100ms step_avg:98.34ms +step:907/1695 train_time:89200ms step_avg:98.35ms 
+step:908/1695 train_time:89301ms step_avg:98.35ms +step:909/1695 train_time:89402ms step_avg:98.35ms +step:910/1695 train_time:89502ms step_avg:98.35ms +step:911/1695 train_time:89602ms step_avg:98.36ms +step:912/1695 train_time:89702ms step_avg:98.36ms +step:913/1695 train_time:89801ms step_avg:98.36ms +step:914/1695 train_time:89901ms step_avg:98.36ms +step:915/1695 train_time:90001ms step_avg:98.36ms +step:916/1695 train_time:90101ms step_avg:98.36ms +step:917/1695 train_time:90201ms step_avg:98.37ms +step:918/1695 train_time:90302ms step_avg:98.37ms +step:919/1695 train_time:90403ms step_avg:98.37ms +step:920/1695 train_time:90503ms step_avg:98.37ms +step:921/1695 train_time:90604ms step_avg:98.38ms +step:922/1695 train_time:90703ms step_avg:98.38ms +step:923/1695 train_time:90803ms step_avg:98.38ms +step:924/1695 train_time:90902ms step_avg:98.38ms +step:925/1695 train_time:91001ms step_avg:98.38ms +step:926/1695 train_time:91100ms step_avg:98.38ms +step:927/1695 train_time:91200ms step_avg:98.38ms +step:928/1695 train_time:91300ms step_avg:98.38ms +step:929/1695 train_time:91400ms step_avg:98.39ms +step:930/1695 train_time:91500ms step_avg:98.39ms +step:931/1695 train_time:91601ms step_avg:98.39ms +step:932/1695 train_time:91701ms step_avg:98.39ms +step:933/1695 train_time:91801ms step_avg:98.39ms +step:934/1695 train_time:91900ms step_avg:98.39ms +step:935/1695 train_time:91999ms step_avg:98.40ms +step:936/1695 train_time:92098ms step_avg:98.40ms +step:937/1695 train_time:92198ms step_avg:98.40ms +step:938/1695 train_time:92298ms step_avg:98.40ms +step:939/1695 train_time:92399ms step_avg:98.40ms +step:940/1695 train_time:92499ms step_avg:98.40ms +step:941/1695 train_time:92600ms step_avg:98.41ms +step:942/1695 train_time:92701ms step_avg:98.41ms +step:943/1695 train_time:92801ms step_avg:98.41ms +step:944/1695 train_time:92900ms step_avg:98.41ms +step:945/1695 train_time:93001ms step_avg:98.41ms +step:946/1695 train_time:93101ms step_avg:98.42ms +step:947/1695 train_time:93201ms step_avg:98.42ms +step:948/1695 train_time:93300ms step_avg:98.42ms +step:949/1695 train_time:93400ms step_avg:98.42ms +step:950/1695 train_time:93500ms step_avg:98.42ms +step:951/1695 train_time:93601ms step_avg:98.42ms +step:952/1695 train_time:93701ms step_avg:98.43ms +step:953/1695 train_time:93801ms step_avg:98.43ms +step:954/1695 train_time:93901ms step_avg:98.43ms +step:955/1695 train_time:94001ms step_avg:98.43ms +step:956/1695 train_time:94100ms step_avg:98.43ms +step:957/1695 train_time:94199ms step_avg:98.43ms +step:958/1695 train_time:94299ms step_avg:98.43ms +step:959/1695 train_time:94398ms step_avg:98.43ms +step:960/1695 train_time:94498ms step_avg:98.44ms +step:961/1695 train_time:94600ms step_avg:98.44ms +step:962/1695 train_time:94700ms step_avg:98.44ms +step:963/1695 train_time:94800ms step_avg:98.44ms +step:964/1695 train_time:94900ms step_avg:98.44ms +step:965/1695 train_time:95000ms step_avg:98.45ms +step:966/1695 train_time:95099ms step_avg:98.45ms +step:967/1695 train_time:95199ms step_avg:98.45ms +step:968/1695 train_time:95299ms step_avg:98.45ms +step:969/1695 train_time:95400ms step_avg:98.45ms +step:970/1695 train_time:95500ms step_avg:98.45ms +step:971/1695 train_time:95599ms step_avg:98.45ms +step:972/1695 train_time:95701ms step_avg:98.46ms +step:973/1695 train_time:95800ms step_avg:98.46ms +step:974/1695 train_time:95900ms step_avg:98.46ms +step:975/1695 train_time:96001ms step_avg:98.46ms +step:976/1695 train_time:96102ms step_avg:98.46ms +step:977/1695 train_time:96201ms 
step_avg:98.47ms +step:978/1695 train_time:96301ms step_avg:98.47ms +step:979/1695 train_time:96401ms step_avg:98.47ms +step:980/1695 train_time:96501ms step_avg:98.47ms +step:981/1695 train_time:96601ms step_avg:98.47ms +step:982/1695 train_time:96701ms step_avg:98.47ms +step:983/1695 train_time:96801ms step_avg:98.48ms +step:984/1695 train_time:96901ms step_avg:98.48ms +step:985/1695 train_time:97002ms step_avg:98.48ms +step:986/1695 train_time:97103ms step_avg:98.48ms +step:987/1695 train_time:97205ms step_avg:98.48ms +step:988/1695 train_time:97304ms step_avg:98.49ms +step:989/1695 train_time:97403ms step_avg:98.49ms +step:990/1695 train_time:97503ms step_avg:98.49ms +step:991/1695 train_time:97604ms step_avg:98.49ms +step:992/1695 train_time:97704ms step_avg:98.49ms +step:993/1695 train_time:97803ms step_avg:98.49ms +step:994/1695 train_time:97903ms step_avg:98.49ms +step:995/1695 train_time:98003ms step_avg:98.50ms +step:996/1695 train_time:98102ms step_avg:98.50ms +step:997/1695 train_time:98202ms step_avg:98.50ms +step:998/1695 train_time:98301ms step_avg:98.50ms +step:999/1695 train_time:98402ms step_avg:98.50ms +step:1000/1695 train_time:98501ms step_avg:98.50ms +step:1000/1695 val_loss:3.4933 train_time:98599ms step_avg:98.60ms +step:1001/1695 train_time:98627ms step_avg:98.53ms +step:1002/1695 train_time:98709ms step_avg:98.51ms +step:1003/1695 train_time:98815ms step_avg:98.52ms +step:1004/1695 train_time:98916ms step_avg:98.52ms +step:1005/1695 train_time:99015ms step_avg:98.52ms +step:1006/1695 train_time:99115ms step_avg:98.52ms +step:1007/1695 train_time:99214ms step_avg:98.52ms +step:1008/1695 train_time:99313ms step_avg:98.52ms +step:1009/1695 train_time:99412ms step_avg:98.53ms +step:1010/1695 train_time:99510ms step_avg:98.52ms +step:1011/1695 train_time:99611ms step_avg:98.53ms +step:1012/1695 train_time:99712ms step_avg:98.53ms +step:1013/1695 train_time:99813ms step_avg:98.53ms +step:1014/1695 train_time:99915ms step_avg:98.54ms +step:1015/1695 train_time:100016ms step_avg:98.54ms +step:1016/1695 train_time:100116ms step_avg:98.54ms +step:1017/1695 train_time:100216ms step_avg:98.54ms +step:1018/1695 train_time:100316ms step_avg:98.54ms +step:1019/1695 train_time:100415ms step_avg:98.54ms +step:1020/1695 train_time:100515ms step_avg:98.54ms +step:1021/1695 train_time:100616ms step_avg:98.55ms +step:1022/1695 train_time:100716ms step_avg:98.55ms +step:1023/1695 train_time:100818ms step_avg:98.55ms +step:1024/1695 train_time:100920ms step_avg:98.55ms +step:1025/1695 train_time:101021ms step_avg:98.56ms +step:1026/1695 train_time:101121ms step_avg:98.56ms +step:1027/1695 train_time:101220ms step_avg:98.56ms +step:1028/1695 train_time:101319ms step_avg:98.56ms +step:1029/1695 train_time:101419ms step_avg:98.56ms +step:1030/1695 train_time:101519ms step_avg:98.56ms +step:1031/1695 train_time:101620ms step_avg:98.56ms +step:1032/1695 train_time:101720ms step_avg:98.57ms +step:1033/1695 train_time:101821ms step_avg:98.57ms +step:1034/1695 train_time:101921ms step_avg:98.57ms +step:1035/1695 train_time:102021ms step_avg:98.57ms +step:1036/1695 train_time:102121ms step_avg:98.57ms +step:1037/1695 train_time:102221ms step_avg:98.57ms +step:1038/1695 train_time:102320ms step_avg:98.57ms +step:1039/1695 train_time:102419ms step_avg:98.57ms +step:1040/1695 train_time:102520ms step_avg:98.58ms +step:1041/1695 train_time:102620ms step_avg:98.58ms +step:1042/1695 train_time:102720ms step_avg:98.58ms +step:1043/1695 train_time:102820ms step_avg:98.58ms +step:1044/1695 
train_time:102920ms step_avg:98.58ms +step:1045/1695 train_time:103020ms step_avg:98.58ms +step:1046/1695 train_time:103120ms step_avg:98.59ms +step:1047/1695 train_time:103220ms step_avg:98.59ms +step:1048/1695 train_time:103320ms step_avg:98.59ms +step:1049/1695 train_time:103419ms step_avg:98.59ms +step:1050/1695 train_time:103519ms step_avg:98.59ms +step:1051/1695 train_time:103620ms step_avg:98.59ms +step:1052/1695 train_time:103719ms step_avg:98.59ms +step:1053/1695 train_time:103819ms step_avg:98.59ms +step:1054/1695 train_time:103918ms step_avg:98.59ms +step:1055/1695 train_time:104018ms step_avg:98.60ms +step:1056/1695 train_time:104117ms step_avg:98.60ms +step:1057/1695 train_time:104217ms step_avg:98.60ms +step:1058/1695 train_time:104316ms step_avg:98.60ms +step:1059/1695 train_time:104416ms step_avg:98.60ms +step:1060/1695 train_time:104516ms step_avg:98.60ms +step:1061/1695 train_time:104616ms step_avg:98.60ms +step:1062/1695 train_time:104717ms step_avg:98.60ms +step:1063/1695 train_time:104817ms step_avg:98.60ms +step:1064/1695 train_time:104917ms step_avg:98.61ms +step:1065/1695 train_time:105017ms step_avg:98.61ms +step:1066/1695 train_time:105117ms step_avg:98.61ms +step:1067/1695 train_time:105218ms step_avg:98.61ms +step:1068/1695 train_time:105317ms step_avg:98.61ms +step:1069/1695 train_time:105417ms step_avg:98.61ms +step:1070/1695 train_time:105517ms step_avg:98.61ms +step:1071/1695 train_time:105617ms step_avg:98.61ms +step:1072/1695 train_time:105716ms step_avg:98.62ms +step:1073/1695 train_time:105816ms step_avg:98.62ms +step:1074/1695 train_time:105916ms step_avg:98.62ms +step:1075/1695 train_time:106016ms step_avg:98.62ms +step:1076/1695 train_time:106116ms step_avg:98.62ms +step:1077/1695 train_time:106217ms step_avg:98.62ms +step:1078/1695 train_time:106317ms step_avg:98.62ms +step:1079/1695 train_time:106418ms step_avg:98.63ms +step:1080/1695 train_time:106517ms step_avg:98.63ms +step:1081/1695 train_time:106616ms step_avg:98.63ms +step:1082/1695 train_time:106717ms step_avg:98.63ms +step:1083/1695 train_time:106817ms step_avg:98.63ms +step:1084/1695 train_time:106917ms step_avg:98.63ms +step:1085/1695 train_time:107017ms step_avg:98.63ms +step:1086/1695 train_time:107117ms step_avg:98.63ms +step:1087/1695 train_time:107217ms step_avg:98.64ms +step:1088/1695 train_time:107317ms step_avg:98.64ms +step:1089/1695 train_time:107417ms step_avg:98.64ms +step:1090/1695 train_time:107518ms step_avg:98.64ms +step:1091/1695 train_time:107619ms step_avg:98.64ms +step:1092/1695 train_time:107719ms step_avg:98.64ms +step:1093/1695 train_time:107818ms step_avg:98.64ms +step:1094/1695 train_time:107918ms step_avg:98.64ms +step:1095/1695 train_time:108017ms step_avg:98.65ms +step:1096/1695 train_time:108117ms step_avg:98.65ms +step:1097/1695 train_time:108217ms step_avg:98.65ms +step:1098/1695 train_time:108317ms step_avg:98.65ms +step:1099/1695 train_time:108416ms step_avg:98.65ms +step:1100/1695 train_time:108516ms step_avg:98.65ms +step:1101/1695 train_time:108616ms step_avg:98.65ms +step:1102/1695 train_time:108716ms step_avg:98.65ms +step:1103/1695 train_time:108816ms step_avg:98.65ms +step:1104/1695 train_time:108916ms step_avg:98.66ms +step:1105/1695 train_time:109016ms step_avg:98.66ms +step:1106/1695 train_time:109116ms step_avg:98.66ms +step:1107/1695 train_time:109217ms step_avg:98.66ms +step:1108/1695 train_time:109317ms step_avg:98.66ms +step:1109/1695 train_time:109417ms step_avg:98.66ms +step:1110/1695 train_time:109517ms step_avg:98.66ms +step:1111/1695 
train_time:109617ms step_avg:98.66ms +step:1112/1695 train_time:109717ms step_avg:98.67ms +step:1113/1695 train_time:109817ms step_avg:98.67ms +step:1114/1695 train_time:109917ms step_avg:98.67ms +step:1115/1695 train_time:110017ms step_avg:98.67ms +step:1116/1695 train_time:110117ms step_avg:98.67ms +step:1117/1695 train_time:110217ms step_avg:98.67ms +step:1118/1695 train_time:110316ms step_avg:98.67ms +step:1119/1695 train_time:110416ms step_avg:98.67ms +step:1120/1695 train_time:110516ms step_avg:98.68ms +step:1121/1695 train_time:110616ms step_avg:98.68ms +step:1122/1695 train_time:110716ms step_avg:98.68ms +step:1123/1695 train_time:110816ms step_avg:98.68ms +step:1124/1695 train_time:110916ms step_avg:98.68ms +step:1125/1695 train_time:111016ms step_avg:98.68ms +step:1125/1695 val_loss:3.4421 train_time:111114ms step_avg:98.77ms +step:1126/1695 train_time:111142ms step_avg:98.70ms +step:1127/1695 train_time:111225ms step_avg:98.69ms +step:1128/1695 train_time:111328ms step_avg:98.70ms +step:1129/1695 train_time:111428ms step_avg:98.70ms +step:1130/1695 train_time:111526ms step_avg:98.70ms +step:1131/1695 train_time:111625ms step_avg:98.70ms +step:1132/1695 train_time:111724ms step_avg:98.70ms +step:1133/1695 train_time:111824ms step_avg:98.70ms +step:1134/1695 train_time:111923ms step_avg:98.70ms +step:1135/1695 train_time:112022ms step_avg:98.70ms +step:1136/1695 train_time:112124ms step_avg:98.70ms +step:1137/1695 train_time:112225ms step_avg:98.70ms +step:1138/1695 train_time:112328ms step_avg:98.71ms +step:1139/1695 train_time:112429ms step_avg:98.71ms +step:1140/1695 train_time:112530ms step_avg:98.71ms +step:1141/1695 train_time:112630ms step_avg:98.71ms +step:1142/1695 train_time:112731ms step_avg:98.71ms +step:1143/1695 train_time:112831ms step_avg:98.72ms +step:1144/1695 train_time:112933ms step_avg:98.72ms +step:1145/1695 train_time:113034ms step_avg:98.72ms +step:1146/1695 train_time:113135ms step_avg:98.72ms +step:1147/1695 train_time:113236ms step_avg:98.72ms +step:1148/1695 train_time:113337ms step_avg:98.73ms +step:1149/1695 train_time:113438ms step_avg:98.73ms +step:1150/1695 train_time:113539ms step_avg:98.73ms +step:1151/1695 train_time:113640ms step_avg:98.73ms +step:1152/1695 train_time:113740ms step_avg:98.73ms +step:1153/1695 train_time:113842ms step_avg:98.74ms +step:1154/1695 train_time:113942ms step_avg:98.74ms +step:1155/1695 train_time:114042ms step_avg:98.74ms +step:1156/1695 train_time:114142ms step_avg:98.74ms +step:1157/1695 train_time:114243ms step_avg:98.74ms +step:1158/1695 train_time:114342ms step_avg:98.74ms +step:1159/1695 train_time:114442ms step_avg:98.74ms +step:1160/1695 train_time:114544ms step_avg:98.74ms +step:1161/1695 train_time:114644ms step_avg:98.75ms +step:1162/1695 train_time:114743ms step_avg:98.75ms +step:1163/1695 train_time:114846ms step_avg:98.75ms +step:1164/1695 train_time:114947ms step_avg:98.75ms +step:1165/1695 train_time:115048ms step_avg:98.75ms +step:1166/1695 train_time:115149ms step_avg:98.76ms +step:1167/1695 train_time:115250ms step_avg:98.76ms +step:1168/1695 train_time:115353ms step_avg:98.76ms +step:1169/1695 train_time:115455ms step_avg:98.76ms +step:1170/1695 train_time:115555ms step_avg:98.77ms +step:1171/1695 train_time:115655ms step_avg:98.77ms +step:1172/1695 train_time:115756ms step_avg:98.77ms +step:1173/1695 train_time:115857ms step_avg:98.77ms +step:1174/1695 train_time:115959ms step_avg:98.77ms +step:1175/1695 train_time:116060ms step_avg:98.77ms +step:1176/1695 train_time:116161ms step_avg:98.78ms 
+step:1177/1695 train_time:116262ms step_avg:98.78ms +step:1178/1695 train_time:116363ms step_avg:98.78ms +step:1179/1695 train_time:116465ms step_avg:98.78ms +step:1180/1695 train_time:116565ms step_avg:98.78ms +step:1181/1695 train_time:116666ms step_avg:98.79ms +step:1182/1695 train_time:116765ms step_avg:98.79ms +step:1183/1695 train_time:116866ms step_avg:98.79ms +step:1184/1695 train_time:116971ms step_avg:98.79ms +step:1185/1695 train_time:117074ms step_avg:98.80ms +step:1186/1695 train_time:117175ms step_avg:98.80ms +step:1187/1695 train_time:117275ms step_avg:98.80ms +step:1188/1695 train_time:117376ms step_avg:98.80ms +step:1189/1695 train_time:117476ms step_avg:98.80ms +step:1190/1695 train_time:117576ms step_avg:98.80ms +step:1191/1695 train_time:117676ms step_avg:98.80ms +step:1192/1695 train_time:117777ms step_avg:98.81ms +step:1193/1695 train_time:117878ms step_avg:98.81ms +step:1194/1695 train_time:117980ms step_avg:98.81ms +step:1195/1695 train_time:118081ms step_avg:98.81ms +step:1196/1695 train_time:118181ms step_avg:98.81ms +step:1197/1695 train_time:118282ms step_avg:98.82ms +step:1198/1695 train_time:118382ms step_avg:98.82ms +step:1199/1695 train_time:118482ms step_avg:98.82ms +step:1200/1695 train_time:118582ms step_avg:98.82ms +step:1201/1695 train_time:118682ms step_avg:98.82ms +step:1202/1695 train_time:118783ms step_avg:98.82ms +step:1203/1695 train_time:118883ms step_avg:98.82ms +step:1204/1695 train_time:118984ms step_avg:98.82ms +step:1205/1695 train_time:119084ms step_avg:98.83ms +step:1206/1695 train_time:119185ms step_avg:98.83ms +step:1207/1695 train_time:119285ms step_avg:98.83ms +step:1208/1695 train_time:119385ms step_avg:98.83ms +step:1209/1695 train_time:119486ms step_avg:98.83ms +step:1210/1695 train_time:119586ms step_avg:98.83ms +step:1211/1695 train_time:119689ms step_avg:98.84ms +step:1212/1695 train_time:119790ms step_avg:98.84ms +step:1213/1695 train_time:119892ms step_avg:98.84ms +step:1214/1695 train_time:119993ms step_avg:98.84ms +step:1215/1695 train_time:120095ms step_avg:98.84ms +step:1216/1695 train_time:120198ms step_avg:98.85ms +step:1217/1695 train_time:120298ms step_avg:98.85ms +step:1218/1695 train_time:120400ms step_avg:98.85ms +step:1219/1695 train_time:120500ms step_avg:98.85ms +step:1220/1695 train_time:120601ms step_avg:98.85ms +step:1221/1695 train_time:120702ms step_avg:98.85ms +step:1222/1695 train_time:120803ms step_avg:98.86ms +step:1223/1695 train_time:120903ms step_avg:98.86ms +step:1224/1695 train_time:121003ms step_avg:98.86ms +step:1225/1695 train_time:121104ms step_avg:98.86ms +step:1226/1695 train_time:121204ms step_avg:98.86ms +step:1227/1695 train_time:121305ms step_avg:98.86ms +step:1228/1695 train_time:121405ms step_avg:98.86ms +step:1229/1695 train_time:121506ms step_avg:98.87ms +step:1230/1695 train_time:121606ms step_avg:98.87ms +step:1231/1695 train_time:121708ms step_avg:98.87ms +step:1232/1695 train_time:121809ms step_avg:98.87ms +step:1233/1695 train_time:121911ms step_avg:98.87ms +step:1234/1695 train_time:122013ms step_avg:98.88ms +step:1235/1695 train_time:122114ms step_avg:98.88ms +step:1236/1695 train_time:122215ms step_avg:98.88ms +step:1237/1695 train_time:122317ms step_avg:98.88ms +step:1238/1695 train_time:122418ms step_avg:98.88ms +step:1239/1695 train_time:122519ms step_avg:98.89ms +step:1240/1695 train_time:122620ms step_avg:98.89ms +step:1241/1695 train_time:122721ms step_avg:98.89ms +step:1242/1695 train_time:122823ms step_avg:98.89ms +step:1243/1695 train_time:122922ms step_avg:98.89ms 
+step:1244/1695 train_time:123022ms step_avg:98.89ms +step:1245/1695 train_time:123122ms step_avg:98.89ms +step:1246/1695 train_time:123223ms step_avg:98.89ms +step:1247/1695 train_time:123323ms step_avg:98.90ms +step:1248/1695 train_time:123423ms step_avg:98.90ms +step:1249/1695 train_time:123523ms step_avg:98.90ms +step:1250/1695 train_time:123623ms step_avg:98.90ms +step:1250/1695 val_loss:3.3966 train_time:123721ms step_avg:98.98ms +step:1251/1695 train_time:123749ms step_avg:98.92ms +step:1252/1695 train_time:123832ms step_avg:98.91ms +step:1253/1695 train_time:123934ms step_avg:98.91ms +step:1254/1695 train_time:124035ms step_avg:98.91ms +step:1255/1695 train_time:124136ms step_avg:98.91ms +step:1256/1695 train_time:124235ms step_avg:98.91ms +step:1257/1695 train_time:124335ms step_avg:98.91ms +step:1258/1695 train_time:124436ms step_avg:98.92ms +step:1259/1695 train_time:124537ms step_avg:98.92ms +step:1260/1695 train_time:124637ms step_avg:98.92ms +step:1261/1695 train_time:124739ms step_avg:98.92ms +step:1262/1695 train_time:124844ms step_avg:98.93ms +step:1263/1695 train_time:124945ms step_avg:98.93ms +step:1264/1695 train_time:125045ms step_avg:98.93ms +step:1265/1695 train_time:125145ms step_avg:98.93ms +step:1266/1695 train_time:125245ms step_avg:98.93ms +step:1267/1695 train_time:125346ms step_avg:98.93ms +step:1268/1695 train_time:125447ms step_avg:98.93ms +step:1269/1695 train_time:125547ms step_avg:98.93ms +step:1270/1695 train_time:125647ms step_avg:98.93ms +step:1271/1695 train_time:125749ms step_avg:98.94ms +step:1272/1695 train_time:125849ms step_avg:98.94ms +step:1273/1695 train_time:125951ms step_avg:98.94ms +step:1274/1695 train_time:126051ms step_avg:98.94ms +step:1275/1695 train_time:126152ms step_avg:98.94ms +step:1276/1695 train_time:126255ms step_avg:98.95ms +step:1277/1695 train_time:126357ms step_avg:98.95ms +step:1278/1695 train_time:126459ms step_avg:98.95ms +step:1279/1695 train_time:126559ms step_avg:98.95ms +step:1280/1695 train_time:126660ms step_avg:98.95ms +step:1281/1695 train_time:126762ms step_avg:98.96ms +step:1282/1695 train_time:126863ms step_avg:98.96ms +step:1283/1695 train_time:126964ms step_avg:98.96ms +step:1284/1695 train_time:127064ms step_avg:98.96ms +step:1285/1695 train_time:127165ms step_avg:98.96ms +step:1286/1695 train_time:127266ms step_avg:98.96ms +step:1287/1695 train_time:127367ms step_avg:98.96ms +step:1288/1695 train_time:127467ms step_avg:98.97ms +step:1289/1695 train_time:127568ms step_avg:98.97ms +step:1290/1695 train_time:127668ms step_avg:98.97ms +step:1291/1695 train_time:127769ms step_avg:98.97ms +step:1292/1695 train_time:127869ms step_avg:98.97ms +step:1293/1695 train_time:127970ms step_avg:98.97ms +step:1294/1695 train_time:128071ms step_avg:98.97ms +step:1295/1695 train_time:128172ms step_avg:98.97ms +step:1296/1695 train_time:128274ms step_avg:98.98ms +step:1297/1695 train_time:128375ms step_avg:98.98ms +step:1298/1695 train_time:128476ms step_avg:98.98ms +step:1299/1695 train_time:128577ms step_avg:98.98ms +step:1300/1695 train_time:128678ms step_avg:98.98ms +step:1301/1695 train_time:128780ms step_avg:98.99ms +step:1302/1695 train_time:128882ms step_avg:98.99ms +step:1303/1695 train_time:128983ms step_avg:98.99ms +step:1304/1695 train_time:129084ms step_avg:98.99ms +step:1305/1695 train_time:129186ms step_avg:98.99ms +step:1306/1695 train_time:129286ms step_avg:98.99ms +step:1307/1695 train_time:129386ms step_avg:98.99ms +step:1308/1695 train_time:129486ms step_avg:99.00ms +step:1309/1695 train_time:129587ms 
step_avg:99.00ms +step:1310/1695 train_time:129688ms step_avg:99.00ms +step:1311/1695 train_time:129789ms step_avg:99.00ms +step:1312/1695 train_time:129890ms step_avg:99.00ms +step:1313/1695 train_time:129992ms step_avg:99.00ms +step:1314/1695 train_time:130093ms step_avg:99.01ms +step:1315/1695 train_time:130194ms step_avg:99.01ms +step:1316/1695 train_time:130297ms step_avg:99.01ms +step:1317/1695 train_time:130397ms step_avg:99.01ms +step:1318/1695 train_time:130498ms step_avg:99.01ms +step:1319/1695 train_time:130600ms step_avg:99.01ms +step:1320/1695 train_time:130703ms step_avg:99.02ms +step:1321/1695 train_time:130804ms step_avg:99.02ms +step:1322/1695 train_time:130905ms step_avg:99.02ms +step:1323/1695 train_time:131005ms step_avg:99.02ms +step:1324/1695 train_time:131106ms step_avg:99.02ms +step:1325/1695 train_time:131207ms step_avg:99.02ms +step:1326/1695 train_time:131307ms step_avg:99.03ms +step:1327/1695 train_time:131408ms step_avg:99.03ms +step:1328/1695 train_time:131507ms step_avg:99.03ms +step:1329/1695 train_time:131607ms step_avg:99.03ms +step:1330/1695 train_time:131708ms step_avg:99.03ms +step:1331/1695 train_time:131808ms step_avg:99.03ms +step:1332/1695 train_time:131909ms step_avg:99.03ms +step:1333/1695 train_time:132010ms step_avg:99.03ms +step:1334/1695 train_time:132112ms step_avg:99.03ms +step:1335/1695 train_time:132212ms step_avg:99.04ms +step:1336/1695 train_time:132314ms step_avg:99.04ms +step:1337/1695 train_time:132415ms step_avg:99.04ms +step:1338/1695 train_time:132516ms step_avg:99.04ms +step:1339/1695 train_time:132618ms step_avg:99.04ms +step:1340/1695 train_time:132718ms step_avg:99.04ms +step:1341/1695 train_time:132820ms step_avg:99.05ms +step:1342/1695 train_time:132921ms step_avg:99.05ms +step:1343/1695 train_time:133022ms step_avg:99.05ms +step:1344/1695 train_time:133122ms step_avg:99.05ms +step:1345/1695 train_time:133224ms step_avg:99.05ms +step:1346/1695 train_time:133326ms step_avg:99.05ms +step:1347/1695 train_time:133427ms step_avg:99.06ms +step:1348/1695 train_time:133527ms step_avg:99.06ms +step:1349/1695 train_time:133627ms step_avg:99.06ms +step:1350/1695 train_time:133728ms step_avg:99.06ms +step:1351/1695 train_time:133828ms step_avg:99.06ms +step:1352/1695 train_time:133928ms step_avg:99.06ms +step:1353/1695 train_time:134028ms step_avg:99.06ms +step:1354/1695 train_time:134129ms step_avg:99.06ms +step:1355/1695 train_time:134230ms step_avg:99.06ms +step:1356/1695 train_time:134332ms step_avg:99.06ms +step:1357/1695 train_time:134434ms step_avg:99.07ms +step:1358/1695 train_time:134536ms step_avg:99.07ms +step:1359/1695 train_time:134637ms step_avg:99.07ms +step:1360/1695 train_time:134737ms step_avg:99.07ms +step:1361/1695 train_time:134837ms step_avg:99.07ms +step:1362/1695 train_time:134938ms step_avg:99.07ms +step:1363/1695 train_time:135040ms step_avg:99.08ms +step:1364/1695 train_time:135142ms step_avg:99.08ms +step:1365/1695 train_time:135244ms step_avg:99.08ms +step:1366/1695 train_time:135344ms step_avg:99.08ms +step:1367/1695 train_time:135445ms step_avg:99.08ms +step:1368/1695 train_time:135546ms step_avg:99.08ms +step:1369/1695 train_time:135646ms step_avg:99.08ms +step:1370/1695 train_time:135745ms step_avg:99.08ms +step:1371/1695 train_time:135847ms step_avg:99.09ms +step:1372/1695 train_time:135948ms step_avg:99.09ms +step:1373/1695 train_time:136048ms step_avg:99.09ms +step:1374/1695 train_time:136148ms step_avg:99.09ms +step:1375/1695 train_time:136248ms step_avg:99.09ms +step:1375/1695 val_loss:3.3574 
train_time:136346ms step_avg:99.16ms +step:1376/1695 train_time:136374ms step_avg:99.11ms +step:1377/1695 train_time:136458ms step_avg:99.10ms +step:1378/1695 train_time:136564ms step_avg:99.10ms +step:1379/1695 train_time:136664ms step_avg:99.10ms +step:1380/1695 train_time:136767ms step_avg:99.11ms +step:1381/1695 train_time:136867ms step_avg:99.11ms +step:1382/1695 train_time:136966ms step_avg:99.11ms +step:1383/1695 train_time:137066ms step_avg:99.11ms +step:1384/1695 train_time:137167ms step_avg:99.11ms +step:1385/1695 train_time:137268ms step_avg:99.11ms +step:1386/1695 train_time:137372ms step_avg:99.11ms +step:1387/1695 train_time:137475ms step_avg:99.12ms +step:1388/1695 train_time:137577ms step_avg:99.12ms +step:1389/1695 train_time:137679ms step_avg:99.12ms +step:1390/1695 train_time:137780ms step_avg:99.12ms +step:1391/1695 train_time:137881ms step_avg:99.12ms +step:1392/1695 train_time:137983ms step_avg:99.13ms +step:1393/1695 train_time:138085ms step_avg:99.13ms +step:1394/1695 train_time:138187ms step_avg:99.13ms +step:1395/1695 train_time:138290ms step_avg:99.13ms +step:1396/1695 train_time:138391ms step_avg:99.13ms +step:1397/1695 train_time:138494ms step_avg:99.14ms +step:1398/1695 train_time:138595ms step_avg:99.14ms +step:1399/1695 train_time:138697ms step_avg:99.14ms +step:1400/1695 train_time:138799ms step_avg:99.14ms +step:1401/1695 train_time:138900ms step_avg:99.14ms +step:1402/1695 train_time:139001ms step_avg:99.14ms +step:1403/1695 train_time:139105ms step_avg:99.15ms +step:1404/1695 train_time:139207ms step_avg:99.15ms +step:1405/1695 train_time:139309ms step_avg:99.15ms +step:1406/1695 train_time:139411ms step_avg:99.15ms +step:1407/1695 train_time:139513ms step_avg:99.16ms +step:1408/1695 train_time:139614ms step_avg:99.16ms +step:1409/1695 train_time:139718ms step_avg:99.16ms +step:1410/1695 train_time:139819ms step_avg:99.16ms +step:1411/1695 train_time:139920ms step_avg:99.16ms +step:1412/1695 train_time:140024ms step_avg:99.17ms +step:1413/1695 train_time:140125ms step_avg:99.17ms +step:1414/1695 train_time:140227ms step_avg:99.17ms +step:1415/1695 train_time:140329ms step_avg:99.17ms +step:1416/1695 train_time:140431ms step_avg:99.17ms +step:1417/1695 train_time:140532ms step_avg:99.18ms +step:1418/1695 train_time:140633ms step_avg:99.18ms +step:1419/1695 train_time:140736ms step_avg:99.18ms +step:1420/1695 train_time:140837ms step_avg:99.18ms +step:1421/1695 train_time:140939ms step_avg:99.18ms +step:1422/1695 train_time:141041ms step_avg:99.18ms +step:1423/1695 train_time:141142ms step_avg:99.19ms +step:1424/1695 train_time:141245ms step_avg:99.19ms +step:1425/1695 train_time:141347ms step_avg:99.19ms +step:1426/1695 train_time:141450ms step_avg:99.19ms +step:1427/1695 train_time:141552ms step_avg:99.20ms +step:1428/1695 train_time:141653ms step_avg:99.20ms +step:1429/1695 train_time:141754ms step_avg:99.20ms +step:1430/1695 train_time:141856ms step_avg:99.20ms +step:1431/1695 train_time:141957ms step_avg:99.20ms +step:1432/1695 train_time:142059ms step_avg:99.20ms +step:1433/1695 train_time:142161ms step_avg:99.21ms +step:1434/1695 train_time:142263ms step_avg:99.21ms +step:1435/1695 train_time:142366ms step_avg:99.21ms +step:1436/1695 train_time:142469ms step_avg:99.21ms +step:1437/1695 train_time:142571ms step_avg:99.21ms +step:1438/1695 train_time:142672ms step_avg:99.22ms +step:1439/1695 train_time:142775ms step_avg:99.22ms +step:1440/1695 train_time:142877ms step_avg:99.22ms +step:1441/1695 train_time:142980ms step_avg:99.22ms +step:1442/1695 
train_time:143081ms step_avg:99.22ms +step:1443/1695 train_time:143182ms step_avg:99.23ms +step:1444/1695 train_time:143284ms step_avg:99.23ms +step:1445/1695 train_time:143384ms step_avg:99.23ms +step:1446/1695 train_time:143487ms step_avg:99.23ms +step:1447/1695 train_time:143588ms step_avg:99.23ms +step:1448/1695 train_time:143691ms step_avg:99.23ms +step:1449/1695 train_time:143792ms step_avg:99.24ms +step:1450/1695 train_time:143893ms step_avg:99.24ms +step:1451/1695 train_time:143994ms step_avg:99.24ms +step:1452/1695 train_time:144095ms step_avg:99.24ms +step:1453/1695 train_time:144196ms step_avg:99.24ms +step:1454/1695 train_time:144299ms step_avg:99.24ms +step:1455/1695 train_time:144403ms step_avg:99.25ms +step:1456/1695 train_time:144505ms step_avg:99.25ms +step:1457/1695 train_time:144607ms step_avg:99.25ms +step:1458/1695 train_time:144709ms step_avg:99.25ms +step:1459/1695 train_time:144811ms step_avg:99.25ms +step:1460/1695 train_time:144912ms step_avg:99.25ms +step:1461/1695 train_time:145014ms step_avg:99.26ms +step:1462/1695 train_time:145116ms step_avg:99.26ms +step:1463/1695 train_time:145217ms step_avg:99.26ms +step:1464/1695 train_time:145319ms step_avg:99.26ms +step:1465/1695 train_time:145419ms step_avg:99.26ms +step:1466/1695 train_time:145522ms step_avg:99.26ms +step:1467/1695 train_time:145624ms step_avg:99.27ms +step:1468/1695 train_time:145727ms step_avg:99.27ms +step:1469/1695 train_time:145831ms step_avg:99.27ms +step:1470/1695 train_time:145932ms step_avg:99.27ms +step:1471/1695 train_time:146032ms step_avg:99.27ms +step:1472/1695 train_time:146133ms step_avg:99.27ms +step:1473/1695 train_time:146234ms step_avg:99.28ms +step:1474/1695 train_time:146335ms step_avg:99.28ms +step:1475/1695 train_time:146438ms step_avg:99.28ms +step:1476/1695 train_time:146540ms step_avg:99.28ms +step:1477/1695 train_time:146644ms step_avg:99.28ms +step:1478/1695 train_time:146747ms step_avg:99.29ms +step:1479/1695 train_time:146848ms step_avg:99.29ms +step:1480/1695 train_time:146950ms step_avg:99.29ms +step:1481/1695 train_time:147057ms step_avg:99.30ms +step:1482/1695 train_time:147152ms step_avg:99.29ms +step:1483/1695 train_time:147253ms step_avg:99.29ms +step:1484/1695 train_time:147355ms step_avg:99.30ms +step:1485/1695 train_time:147456ms step_avg:99.30ms +step:1486/1695 train_time:147558ms step_avg:99.30ms +step:1487/1695 train_time:147661ms step_avg:99.30ms +step:1488/1695 train_time:147765ms step_avg:99.30ms +step:1489/1695 train_time:147867ms step_avg:99.31ms +step:1490/1695 train_time:147969ms step_avg:99.31ms +step:1491/1695 train_time:148070ms step_avg:99.31ms +step:1492/1695 train_time:148171ms step_avg:99.31ms +step:1493/1695 train_time:148272ms step_avg:99.31ms +step:1494/1695 train_time:148374ms step_avg:99.31ms +step:1495/1695 train_time:148475ms step_avg:99.31ms +step:1496/1695 train_time:148576ms step_avg:99.32ms +step:1497/1695 train_time:148676ms step_avg:99.32ms +step:1498/1695 train_time:148779ms step_avg:99.32ms +step:1499/1695 train_time:148881ms step_avg:99.32ms +step:1500/1695 train_time:148984ms step_avg:99.32ms +step:1500/1695 val_loss:3.3221 train_time:149083ms step_avg:99.39ms +step:1501/1695 train_time:149111ms step_avg:99.34ms +step:1502/1695 train_time:149199ms step_avg:99.33ms +step:1503/1695 train_time:149302ms step_avg:99.34ms +step:1504/1695 train_time:149404ms step_avg:99.34ms +step:1505/1695 train_time:149505ms step_avg:99.34ms +step:1506/1695 train_time:149606ms step_avg:99.34ms +step:1507/1695 train_time:149707ms step_avg:99.34ms 
+step:1508/1695 train_time:149808ms step_avg:99.34ms +step:1509/1695 train_time:149910ms step_avg:99.34ms +step:1510/1695 train_time:150012ms step_avg:99.35ms +step:1511/1695 train_time:150115ms step_avg:99.35ms +step:1512/1695 train_time:150217ms step_avg:99.35ms +step:1513/1695 train_time:150319ms step_avg:99.35ms +step:1514/1695 train_time:150421ms step_avg:99.35ms +step:1515/1695 train_time:150527ms step_avg:99.36ms +step:1516/1695 train_time:150629ms step_avg:99.36ms +step:1517/1695 train_time:150729ms step_avg:99.36ms +step:1518/1695 train_time:150830ms step_avg:99.36ms +step:1519/1695 train_time:150934ms step_avg:99.36ms +step:1520/1695 train_time:151035ms step_avg:99.37ms +step:1521/1695 train_time:151137ms step_avg:99.37ms +step:1522/1695 train_time:151238ms step_avg:99.37ms +step:1523/1695 train_time:151340ms step_avg:99.37ms +step:1524/1695 train_time:151446ms step_avg:99.37ms +step:1525/1695 train_time:151550ms step_avg:99.38ms +step:1526/1695 train_time:151652ms step_avg:99.38ms +step:1527/1695 train_time:151753ms step_avg:99.38ms +step:1528/1695 train_time:151858ms step_avg:99.38ms +step:1529/1695 train_time:151960ms step_avg:99.39ms +step:1530/1695 train_time:152063ms step_avg:99.39ms +step:1531/1695 train_time:152164ms step_avg:99.39ms +step:1532/1695 train_time:152267ms step_avg:99.39ms +step:1533/1695 train_time:152369ms step_avg:99.39ms +step:1534/1695 train_time:152470ms step_avg:99.39ms +step:1535/1695 train_time:152572ms step_avg:99.40ms +step:1536/1695 train_time:152673ms step_avg:99.40ms +step:1537/1695 train_time:152774ms step_avg:99.40ms +step:1538/1695 train_time:152875ms step_avg:99.40ms +step:1539/1695 train_time:152977ms step_avg:99.40ms +step:1540/1695 train_time:153081ms step_avg:99.40ms +step:1541/1695 train_time:153186ms step_avg:99.41ms +step:1542/1695 train_time:153290ms step_avg:99.41ms +step:1543/1695 train_time:153393ms step_avg:99.41ms +step:1544/1695 train_time:153495ms step_avg:99.41ms +step:1545/1695 train_time:153597ms step_avg:99.42ms +step:1546/1695 train_time:153698ms step_avg:99.42ms +step:1547/1695 train_time:153800ms step_avg:99.42ms +step:1548/1695 train_time:153902ms step_avg:99.42ms +step:1549/1695 train_time:154005ms step_avg:99.42ms +step:1550/1695 train_time:154106ms step_avg:99.42ms +step:1551/1695 train_time:154208ms step_avg:99.42ms +step:1552/1695 train_time:154311ms step_avg:99.43ms +step:1553/1695 train_time:154413ms step_avg:99.43ms +step:1554/1695 train_time:154515ms step_avg:99.43ms +step:1555/1695 train_time:154616ms step_avg:99.43ms +step:1556/1695 train_time:154718ms step_avg:99.43ms +step:1557/1695 train_time:154823ms step_avg:99.44ms +step:1558/1695 train_time:154926ms step_avg:99.44ms +step:1559/1695 train_time:155027ms step_avg:99.44ms +step:1560/1695 train_time:155129ms step_avg:99.44ms +step:1561/1695 train_time:155230ms step_avg:99.44ms +step:1562/1695 train_time:155333ms step_avg:99.44ms +step:1563/1695 train_time:155437ms step_avg:99.45ms +step:1564/1695 train_time:155538ms step_avg:99.45ms +step:1565/1695 train_time:155639ms step_avg:99.45ms +step:1566/1695 train_time:155740ms step_avg:99.45ms +step:1567/1695 train_time:155841ms step_avg:99.45ms +step:1568/1695 train_time:155942ms step_avg:99.45ms +step:1569/1695 train_time:156044ms step_avg:99.45ms +step:1570/1695 train_time:156148ms step_avg:99.46ms +step:1571/1695 train_time:156250ms step_avg:99.46ms +step:1572/1695 train_time:156351ms step_avg:99.46ms +step:1573/1695 train_time:156453ms step_avg:99.46ms +step:1574/1695 train_time:156556ms step_avg:99.46ms 
+step:1575/1695 train_time:156657ms step_avg:99.46ms +step:1576/1695 train_time:156758ms step_avg:99.47ms +step:1577/1695 train_time:156861ms step_avg:99.47ms +step:1578/1695 train_time:156963ms step_avg:99.47ms +step:1579/1695 train_time:157065ms step_avg:99.47ms +step:1580/1695 train_time:157167ms step_avg:99.47ms +step:1581/1695 train_time:157270ms step_avg:99.47ms +step:1582/1695 train_time:157371ms step_avg:99.48ms +step:1583/1695 train_time:157475ms step_avg:99.48ms +step:1584/1695 train_time:157579ms step_avg:99.48ms +step:1585/1695 train_time:157680ms step_avg:99.48ms +step:1586/1695 train_time:157782ms step_avg:99.48ms +step:1587/1695 train_time:157884ms step_avg:99.49ms +step:1588/1695 train_time:157985ms step_avg:99.49ms +step:1589/1695 train_time:158087ms step_avg:99.49ms +step:1590/1695 train_time:158189ms step_avg:99.49ms +step:1591/1695 train_time:158290ms step_avg:99.49ms +step:1592/1695 train_time:158393ms step_avg:99.49ms +step:1593/1695 train_time:158494ms step_avg:99.49ms +step:1594/1695 train_time:158598ms step_avg:99.50ms +step:1595/1695 train_time:158699ms step_avg:99.50ms +step:1596/1695 train_time:158801ms step_avg:99.50ms +step:1597/1695 train_time:158904ms step_avg:99.50ms +step:1598/1695 train_time:159008ms step_avg:99.50ms +step:1599/1695 train_time:159109ms step_avg:99.51ms +step:1600/1695 train_time:159211ms step_avg:99.51ms +step:1601/1695 train_time:159313ms step_avg:99.51ms +step:1602/1695 train_time:159416ms step_avg:99.51ms +step:1603/1695 train_time:159517ms step_avg:99.51ms +step:1604/1695 train_time:159618ms step_avg:99.51ms +step:1605/1695 train_time:159721ms step_avg:99.51ms +step:1606/1695 train_time:159823ms step_avg:99.52ms +step:1607/1695 train_time:159925ms step_avg:99.52ms +step:1608/1695 train_time:160026ms step_avg:99.52ms +step:1609/1695 train_time:160127ms step_avg:99.52ms +step:1610/1695 train_time:160230ms step_avg:99.52ms +step:1611/1695 train_time:160332ms step_avg:99.52ms +step:1612/1695 train_time:160435ms step_avg:99.53ms +step:1613/1695 train_time:160535ms step_avg:99.53ms +step:1614/1695 train_time:160636ms step_avg:99.53ms +step:1615/1695 train_time:160738ms step_avg:99.53ms +step:1616/1695 train_time:160838ms step_avg:99.53ms +step:1617/1695 train_time:160942ms step_avg:99.53ms +step:1618/1695 train_time:161044ms step_avg:99.53ms +step:1619/1695 train_time:161147ms step_avg:99.53ms +step:1620/1695 train_time:161249ms step_avg:99.54ms +step:1621/1695 train_time:161351ms step_avg:99.54ms +step:1622/1695 train_time:161452ms step_avg:99.54ms +step:1623/1695 train_time:161552ms step_avg:99.54ms +step:1624/1695 train_time:161654ms step_avg:99.54ms +step:1625/1695 train_time:161756ms step_avg:99.54ms +step:1625/1695 val_loss:3.2931 train_time:161855ms step_avg:99.60ms +step:1626/1695 train_time:161883ms step_avg:99.56ms +step:1627/1695 train_time:161972ms step_avg:99.55ms +step:1628/1695 train_time:162075ms step_avg:99.55ms +step:1629/1695 train_time:162178ms step_avg:99.56ms +step:1630/1695 train_time:162279ms step_avg:99.56ms +step:1631/1695 train_time:162381ms step_avg:99.56ms +step:1632/1695 train_time:162482ms step_avg:99.56ms +step:1633/1695 train_time:162582ms step_avg:99.56ms +step:1634/1695 train_time:162686ms step_avg:99.56ms +step:1635/1695 train_time:162787ms step_avg:99.56ms +step:1636/1695 train_time:162892ms step_avg:99.57ms +step:1637/1695 train_time:162994ms step_avg:99.57ms +step:1638/1695 train_time:163097ms step_avg:99.57ms +step:1639/1695 train_time:163199ms step_avg:99.57ms +step:1640/1695 train_time:163301ms 
step_avg:99.57ms +step:1641/1695 train_time:163405ms step_avg:99.58ms +step:1642/1695 train_time:163506ms step_avg:99.58ms +step:1643/1695 train_time:163608ms step_avg:99.58ms +step:1644/1695 train_time:163709ms step_avg:99.58ms +step:1645/1695 train_time:163812ms step_avg:99.58ms +step:1646/1695 train_time:163914ms step_avg:99.58ms +step:1647/1695 train_time:164019ms step_avg:99.59ms +step:1648/1695 train_time:164123ms step_avg:99.59ms +step:1649/1695 train_time:164226ms step_avg:99.59ms +step:1650/1695 train_time:164328ms step_avg:99.59ms +step:1651/1695 train_time:164430ms step_avg:99.59ms +step:1652/1695 train_time:164533ms step_avg:99.60ms +step:1653/1695 train_time:164636ms step_avg:99.60ms +step:1654/1695 train_time:164739ms step_avg:99.60ms +step:1655/1695 train_time:164841ms step_avg:99.60ms +step:1656/1695 train_time:164944ms step_avg:99.60ms +step:1657/1695 train_time:165046ms step_avg:99.61ms +step:1658/1695 train_time:165148ms step_avg:99.61ms +step:1659/1695 train_time:165254ms step_avg:99.61ms +step:1660/1695 train_time:165357ms step_avg:99.61ms +step:1661/1695 train_time:165460ms step_avg:99.61ms +step:1662/1695 train_time:165565ms step_avg:99.62ms +step:1663/1695 train_time:165668ms step_avg:99.62ms +step:1664/1695 train_time:165771ms step_avg:99.62ms +step:1665/1695 train_time:165878ms step_avg:99.63ms +step:1666/1695 train_time:165981ms step_avg:99.63ms +step:1667/1695 train_time:166083ms step_avg:99.63ms +step:1668/1695 train_time:166187ms step_avg:99.63ms +step:1669/1695 train_time:166291ms step_avg:99.64ms +step:1670/1695 train_time:166393ms step_avg:99.64ms +step:1671/1695 train_time:166495ms step_avg:99.64ms +step:1672/1695 train_time:166600ms step_avg:99.64ms +step:1673/1695 train_time:166703ms step_avg:99.64ms +step:1674/1695 train_time:166806ms step_avg:99.64ms +step:1675/1695 train_time:166908ms step_avg:99.65ms +step:1676/1695 train_time:167013ms step_avg:99.65ms +step:1677/1695 train_time:167115ms step_avg:99.65ms +step:1678/1695 train_time:167220ms step_avg:99.65ms +step:1679/1695 train_time:167325ms step_avg:99.66ms +step:1680/1695 train_time:167427ms step_avg:99.66ms +step:1681/1695 train_time:167530ms step_avg:99.66ms +step:1682/1695 train_time:167636ms step_avg:99.66ms +step:1683/1695 train_time:167739ms step_avg:99.67ms +step:1684/1695 train_time:167842ms step_avg:99.67ms +step:1685/1695 train_time:167945ms step_avg:99.67ms +step:1686/1695 train_time:168047ms step_avg:99.67ms +step:1687/1695 train_time:168149ms step_avg:99.67ms +step:1688/1695 train_time:168251ms step_avg:99.67ms +step:1689/1695 train_time:168352ms step_avg:99.68ms +step:1690/1695 train_time:168455ms step_avg:99.68ms +step:1691/1695 train_time:168558ms step_avg:99.68ms +step:1692/1695 train_time:168661ms step_avg:99.68ms +step:1693/1695 train_time:168764ms step_avg:99.68ms +step:1694/1695 train_time:168868ms step_avg:99.69ms +step:1695/1695 train_time:168971ms step_avg:99.69ms +step:1695/1695 val_loss:3.2802 train_time:169070ms step_avg:99.75ms +peak memory allocated: 34004 MiB reserved: 49600 MiB diff --git a/records/082325_SparseAttnGate/48b19604-5049-48c9-956c-8ddc4d0781fb.txt b/records/082325_SparseAttnGate/48b19604-5049-48c9-956c-8ddc4d0781fb.txt new file mode 100644 index 000000000..c9fef39e2 --- /dev/null +++ b/records/082325_SparseAttnGate/48b19604-5049-48c9-956c-8ddc4d0781fb.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
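+        # (editor's note, not in the original record) the A and A.T tile pointers
+        # advance together along K; the masked loads above zero-fill the ragged final K tile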
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
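+
+    A minimal usage sketch (editor's addition, not part of the original record): Muon only
+    handles the 2-D hidden matrices, and step() communicates via reduce-scatter/all-gather,
+    so torch.distributed must already be initialized, as done later in this script:
+
+        hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+        optimizer = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+        loss.backward()
+        optimizer.step()
+        model.zero_grad(set_to_none=True)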
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+        # sparse attention gate (this record's change): a per-head sigmoid gate computed from
+        # the first dim//args.dampen_factor input channels (12 of 768 here); the weight is
+        # zero-initialized, so every head starts with a neutral gate of sigmoid(0) = 0.5
+        self.dampen = CastedLinear(dim//args.dampen_factor, num_heads)
+        self.dampen.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask):
+        B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim
+        assert B == 1, "Must use batch size = 1 for FlexAttention"
+        # per-(token, head) gate in (0, 1), broadcast over head_dim
+        dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1)
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * dampen_factor # scale each head's output by its learned gate
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because the optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by the 8-GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x),
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
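+        # (editor's note) lr_mul/wd_mul are plain attributes attached to parameter tensors;
+        # the optimizers read them back via getattr(p, "lr_mul", 1.0) and getattr(p, "wd_mul", 1.0)
+        # to scale that parameter's learning rate and weight decay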
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+        start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
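+        # (editor's note) val_loss is now averaged across all ranks; the clock was
+        # stopped above, so validation time is excluded from the reported train_time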
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:48:02 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 319303 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 319304 C /usr/bin/python3 614MiB | +| 0 N/A N/A 319305 C /usr/bin/python3 614MiB | +| 0 N/A N/A 319306 C /usr/bin/python3 614MiB | +| 0 N/A N/A 319307 C /usr/bin/python3 614MiB | +| 0 N/A N/A 319308 C /usr/bin/python3 614MiB | +| 0 N/A N/A 319309 C /usr/bin/python3 614MiB | +| 0 N/A N/A 319310 C /usr/bin/python3 614MiB | +| 1 N/A N/A 319304 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 319305 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 319306 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 319307 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 319308 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 319309 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 319310 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.06ms +step:1/1695 train_time:155ms step_avg:155.16ms +step:2/1695 train_time:182ms step_avg:90.82ms +step:3/1695 train_time:250ms step_avg:83.49ms +step:4/1695 train_time:342ms step_avg:85.55ms +step:5/1695 train_time:435ms step_avg:86.91ms +step:6/1695 train_time:527ms step_avg:87.78ms +step:7/1695 train_time:619ms step_avg:88.50ms +step:8/1695 
train_time:712ms step_avg:89.01ms +step:9/1695 train_time:805ms step_avg:89.42ms +step:10/1695 train_time:898ms step_avg:89.80ms +step:11/1695 train_time:990ms step_avg:90.03ms +step:12/1695 train_time:1086ms step_avg:90.46ms +step:13/1695 train_time:1182ms step_avg:90.92ms +step:14/1695 train_time:1277ms step_avg:91.23ms +step:15/1695 train_time:1371ms step_avg:91.38ms +step:16/1695 train_time:1465ms step_avg:91.54ms +step:17/1695 train_time:1558ms step_avg:91.66ms +step:18/1695 train_time:1651ms step_avg:91.75ms +step:19/1695 train_time:1744ms step_avg:91.79ms +step:20/1695 train_time:1838ms step_avg:91.88ms +step:21/1695 train_time:1931ms step_avg:91.93ms +step:22/1695 train_time:2025ms step_avg:92.02ms +step:23/1695 train_time:2119ms step_avg:92.11ms +step:24/1695 train_time:2213ms step_avg:92.19ms +step:25/1695 train_time:2306ms step_avg:92.25ms +step:26/1695 train_time:2401ms step_avg:92.34ms +step:27/1695 train_time:2494ms step_avg:92.38ms +step:28/1695 train_time:2588ms step_avg:92.43ms +step:29/1695 train_time:2682ms step_avg:92.48ms +step:30/1695 train_time:2776ms step_avg:92.52ms +step:31/1695 train_time:2869ms step_avg:92.54ms +step:32/1695 train_time:2962ms step_avg:92.57ms +step:33/1695 train_time:3056ms step_avg:92.61ms +step:34/1695 train_time:3149ms step_avg:92.62ms +step:35/1695 train_time:3243ms step_avg:92.66ms +step:36/1695 train_time:3338ms step_avg:92.73ms +step:37/1695 train_time:3431ms step_avg:92.73ms +step:38/1695 train_time:3525ms step_avg:92.76ms +step:39/1695 train_time:3619ms step_avg:92.79ms +step:40/1695 train_time:3713ms step_avg:92.82ms +step:41/1695 train_time:3806ms step_avg:92.83ms +step:42/1695 train_time:3900ms step_avg:92.85ms +step:43/1695 train_time:3993ms step_avg:92.86ms +step:44/1695 train_time:4086ms step_avg:92.87ms +step:45/1695 train_time:4181ms step_avg:92.90ms +step:46/1695 train_time:4274ms step_avg:92.92ms +step:47/1695 train_time:4368ms step_avg:92.93ms +step:48/1695 train_time:4461ms step_avg:92.95ms +step:49/1695 train_time:4555ms step_avg:92.96ms +step:50/1695 train_time:4648ms step_avg:92.96ms +step:51/1695 train_time:4743ms step_avg:92.99ms +step:52/1695 train_time:4835ms step_avg:92.98ms +step:53/1695 train_time:4929ms step_avg:93.00ms +step:54/1695 train_time:5023ms step_avg:93.01ms +step:55/1695 train_time:5116ms step_avg:93.02ms +step:56/1695 train_time:5210ms step_avg:93.03ms +step:57/1695 train_time:5304ms step_avg:93.06ms +step:58/1695 train_time:5399ms step_avg:93.08ms +step:59/1695 train_time:5494ms step_avg:93.12ms +step:60/1695 train_time:5588ms step_avg:93.13ms +step:61/1695 train_time:5681ms step_avg:93.13ms +step:62/1695 train_time:5775ms step_avg:93.14ms +step:63/1695 train_time:5868ms step_avg:93.14ms +step:64/1695 train_time:5961ms step_avg:93.15ms +step:65/1695 train_time:6054ms step_avg:93.14ms +step:66/1695 train_time:6147ms step_avg:93.14ms +step:67/1695 train_time:6240ms step_avg:93.13ms +step:68/1695 train_time:6333ms step_avg:93.14ms +step:69/1695 train_time:6427ms step_avg:93.14ms +step:70/1695 train_time:6522ms step_avg:93.17ms +step:71/1695 train_time:6616ms step_avg:93.18ms +step:72/1695 train_time:6710ms step_avg:93.19ms +step:73/1695 train_time:6804ms step_avg:93.21ms +step:74/1695 train_time:6899ms step_avg:93.23ms +step:75/1695 train_time:6993ms step_avg:93.24ms +step:76/1695 train_time:7086ms step_avg:93.24ms +step:77/1695 train_time:7181ms step_avg:93.25ms +step:78/1695 train_time:7275ms step_avg:93.27ms +step:79/1695 train_time:7369ms step_avg:93.28ms +step:80/1695 train_time:7462ms 
step_avg:93.28ms
[... ~1,400 per-step "step:N/1695 train_time:...ms step_avg:...ms" log lines for steps 81-1508 elided; step_avg rises gradually from ~93.3ms at step 81 to ~99.1ms by step 1500; the validation-loss checkpoint lines are retained below ...]
+step:125/1695 val_loss:4.5872 train_time:11770ms step_avg:94.16ms
+step:250/1695 val_loss:4.0644 train_time:23548ms step_avg:94.19ms
+step:375/1695 val_loss:3.8654 train_time:35382ms step_avg:94.35ms
+step:500/1695 val_loss:3.7236 train_time:47410ms step_avg:94.82ms
+step:625/1695 val_loss:3.6385 train_time:59496ms step_avg:95.19ms
+step:750/1695 val_loss:3.5787 train_time:72898ms step_avg:97.20ms
+step:875/1695 val_loss:3.5322 train_time:85677ms step_avg:97.92ms
+step:1000/1695 val_loss:3.4877 train_time:98160ms step_avg:98.16ms
+step:1125/1695 val_loss:3.4364 train_time:110663ms step_avg:98.37ms
+step:1250/1695 val_loss:3.3910 train_time:123290ms step_avg:98.63ms
+step:1375/1695 val_loss:3.3517 train_time:135923ms step_avg:98.85ms
+step:1500/1695 val_loss:3.3169 train_time:148663ms step_avg:99.11ms
+step:1508/1695
train_time:149385ms step_avg:99.06ms +step:1509/1695 train_time:149488ms step_avg:99.06ms +step:1510/1695 train_time:149589ms step_avg:99.07ms +step:1511/1695 train_time:149694ms step_avg:99.07ms +step:1512/1695 train_time:149796ms step_avg:99.07ms +step:1513/1695 train_time:149899ms step_avg:99.07ms +step:1514/1695 train_time:150001ms step_avg:99.08ms +step:1515/1695 train_time:150106ms step_avg:99.08ms +step:1516/1695 train_time:150207ms step_avg:99.08ms +step:1517/1695 train_time:150307ms step_avg:99.08ms +step:1518/1695 train_time:150408ms step_avg:99.08ms +step:1519/1695 train_time:150512ms step_avg:99.09ms +step:1520/1695 train_time:150614ms step_avg:99.09ms +step:1521/1695 train_time:150716ms step_avg:99.09ms +step:1522/1695 train_time:150818ms step_avg:99.09ms +step:1523/1695 train_time:150920ms step_avg:99.09ms +step:1524/1695 train_time:151025ms step_avg:99.10ms +step:1525/1695 train_time:151129ms step_avg:99.10ms +step:1526/1695 train_time:151231ms step_avg:99.10ms +step:1527/1695 train_time:151332ms step_avg:99.10ms +step:1528/1695 train_time:151438ms step_avg:99.11ms +step:1529/1695 train_time:151539ms step_avg:99.11ms +step:1530/1695 train_time:151642ms step_avg:99.11ms +step:1531/1695 train_time:151742ms step_avg:99.11ms +step:1532/1695 train_time:151845ms step_avg:99.12ms +step:1533/1695 train_time:151945ms step_avg:99.12ms +step:1534/1695 train_time:152046ms step_avg:99.12ms +step:1535/1695 train_time:152147ms step_avg:99.12ms +step:1536/1695 train_time:152248ms step_avg:99.12ms +step:1537/1695 train_time:152350ms step_avg:99.12ms +step:1538/1695 train_time:152452ms step_avg:99.12ms +step:1539/1695 train_time:152555ms step_avg:99.13ms +step:1540/1695 train_time:152659ms step_avg:99.13ms +step:1541/1695 train_time:152762ms step_avg:99.13ms +step:1542/1695 train_time:152868ms step_avg:99.14ms +step:1543/1695 train_time:152969ms step_avg:99.14ms +step:1544/1695 train_time:153071ms step_avg:99.14ms +step:1545/1695 train_time:153173ms step_avg:99.14ms +step:1546/1695 train_time:153274ms step_avg:99.14ms +step:1547/1695 train_time:153377ms step_avg:99.14ms +step:1548/1695 train_time:153479ms step_avg:99.15ms +step:1549/1695 train_time:153581ms step_avg:99.15ms +step:1550/1695 train_time:153683ms step_avg:99.15ms +step:1551/1695 train_time:153786ms step_avg:99.15ms +step:1552/1695 train_time:153887ms step_avg:99.15ms +step:1553/1695 train_time:153990ms step_avg:99.16ms +step:1554/1695 train_time:154091ms step_avg:99.16ms +step:1555/1695 train_time:154193ms step_avg:99.16ms +step:1556/1695 train_time:154294ms step_avg:99.16ms +step:1557/1695 train_time:154398ms step_avg:99.16ms +step:1558/1695 train_time:154501ms step_avg:99.17ms +step:1559/1695 train_time:154603ms step_avg:99.17ms +step:1560/1695 train_time:154704ms step_avg:99.17ms +step:1561/1695 train_time:154806ms step_avg:99.17ms +step:1562/1695 train_time:154909ms step_avg:99.17ms +step:1563/1695 train_time:155013ms step_avg:99.18ms +step:1564/1695 train_time:155114ms step_avg:99.18ms +step:1565/1695 train_time:155215ms step_avg:99.18ms +step:1566/1695 train_time:155317ms step_avg:99.18ms +step:1567/1695 train_time:155418ms step_avg:99.18ms +step:1568/1695 train_time:155519ms step_avg:99.18ms +step:1569/1695 train_time:155621ms step_avg:99.19ms +step:1570/1695 train_time:155726ms step_avg:99.19ms +step:1571/1695 train_time:155827ms step_avg:99.19ms +step:1572/1695 train_time:155929ms step_avg:99.19ms +step:1573/1695 train_time:156031ms step_avg:99.19ms +step:1574/1695 train_time:156132ms step_avg:99.19ms +step:1575/1695 
train_time:156233ms step_avg:99.20ms +step:1576/1695 train_time:156336ms step_avg:99.20ms +step:1577/1695 train_time:156439ms step_avg:99.20ms +step:1578/1695 train_time:156541ms step_avg:99.20ms +step:1579/1695 train_time:156643ms step_avg:99.20ms +step:1580/1695 train_time:156746ms step_avg:99.21ms +step:1581/1695 train_time:156849ms step_avg:99.21ms +step:1582/1695 train_time:156950ms step_avg:99.21ms +step:1583/1695 train_time:157054ms step_avg:99.21ms +step:1584/1695 train_time:157156ms step_avg:99.21ms +step:1585/1695 train_time:157257ms step_avg:99.22ms +step:1586/1695 train_time:157360ms step_avg:99.22ms +step:1587/1695 train_time:157462ms step_avg:99.22ms +step:1588/1695 train_time:157563ms step_avg:99.22ms +step:1589/1695 train_time:157664ms step_avg:99.22ms +step:1590/1695 train_time:157766ms step_avg:99.22ms +step:1591/1695 train_time:157867ms step_avg:99.23ms +step:1592/1695 train_time:157969ms step_avg:99.23ms +step:1593/1695 train_time:158070ms step_avg:99.23ms +step:1594/1695 train_time:158174ms step_avg:99.23ms +step:1595/1695 train_time:158277ms step_avg:99.23ms +step:1596/1695 train_time:158378ms step_avg:99.23ms +step:1597/1695 train_time:158481ms step_avg:99.24ms +step:1598/1695 train_time:158584ms step_avg:99.24ms +step:1599/1695 train_time:158691ms step_avg:99.24ms +step:1600/1695 train_time:158787ms step_avg:99.24ms +step:1601/1695 train_time:158890ms step_avg:99.24ms +step:1602/1695 train_time:158991ms step_avg:99.25ms +step:1603/1695 train_time:159093ms step_avg:99.25ms +step:1604/1695 train_time:159194ms step_avg:99.25ms +step:1605/1695 train_time:159296ms step_avg:99.25ms +step:1606/1695 train_time:159399ms step_avg:99.25ms +step:1607/1695 train_time:159500ms step_avg:99.25ms +step:1608/1695 train_time:159601ms step_avg:99.25ms +step:1609/1695 train_time:159703ms step_avg:99.26ms +step:1610/1695 train_time:159806ms step_avg:99.26ms +step:1611/1695 train_time:159908ms step_avg:99.26ms +step:1612/1695 train_time:160010ms step_avg:99.26ms +step:1613/1695 train_time:160111ms step_avg:99.26ms +step:1614/1695 train_time:160212ms step_avg:99.26ms +step:1615/1695 train_time:160314ms step_avg:99.27ms +step:1616/1695 train_time:160415ms step_avg:99.27ms +step:1617/1695 train_time:160518ms step_avg:99.27ms +step:1618/1695 train_time:160620ms step_avg:99.27ms +step:1619/1695 train_time:160723ms step_avg:99.27ms +step:1620/1695 train_time:160826ms step_avg:99.28ms +step:1621/1695 train_time:160927ms step_avg:99.28ms +step:1622/1695 train_time:161028ms step_avg:99.28ms +step:1623/1695 train_time:161130ms step_avg:99.28ms +step:1624/1695 train_time:161231ms step_avg:99.28ms +step:1625/1695 train_time:161335ms step_avg:99.28ms +step:1625/1695 val_loss:3.2888 train_time:161435ms step_avg:99.34ms +step:1626/1695 train_time:161464ms step_avg:99.30ms +step:1627/1695 train_time:161550ms step_avg:99.29ms +step:1628/1695 train_time:161655ms step_avg:99.30ms +step:1629/1695 train_time:161757ms step_avg:99.30ms +step:1630/1695 train_time:161859ms step_avg:99.30ms +step:1631/1695 train_time:161961ms step_avg:99.30ms +step:1632/1695 train_time:162062ms step_avg:99.30ms +step:1633/1695 train_time:162162ms step_avg:99.30ms +step:1634/1695 train_time:162265ms step_avg:99.31ms +step:1635/1695 train_time:162367ms step_avg:99.31ms +step:1636/1695 train_time:162470ms step_avg:99.31ms +step:1637/1695 train_time:162574ms step_avg:99.31ms +step:1638/1695 train_time:162677ms step_avg:99.31ms +step:1639/1695 train_time:162779ms step_avg:99.32ms +step:1640/1695 train_time:162881ms step_avg:99.32ms 
+step:1641/1695 train_time:162984ms step_avg:99.32ms +step:1642/1695 train_time:163085ms step_avg:99.32ms +step:1643/1695 train_time:163187ms step_avg:99.32ms +step:1644/1695 train_time:163293ms step_avg:99.33ms +step:1645/1695 train_time:163393ms step_avg:99.33ms +step:1646/1695 train_time:163496ms step_avg:99.33ms +step:1647/1695 train_time:163601ms step_avg:99.33ms +step:1648/1695 train_time:163704ms step_avg:99.33ms +step:1649/1695 train_time:163807ms step_avg:99.34ms +step:1650/1695 train_time:163910ms step_avg:99.34ms +step:1651/1695 train_time:164012ms step_avg:99.34ms +step:1652/1695 train_time:164115ms step_avg:99.34ms +step:1653/1695 train_time:164219ms step_avg:99.35ms +step:1654/1695 train_time:164321ms step_avg:99.35ms +step:1655/1695 train_time:164423ms step_avg:99.35ms +step:1656/1695 train_time:164525ms step_avg:99.35ms +step:1657/1695 train_time:164627ms step_avg:99.35ms +step:1658/1695 train_time:164730ms step_avg:99.35ms +step:1659/1695 train_time:164836ms step_avg:99.36ms +step:1660/1695 train_time:164938ms step_avg:99.36ms +step:1661/1695 train_time:165042ms step_avg:99.36ms +step:1662/1695 train_time:165147ms step_avg:99.37ms +step:1663/1695 train_time:165249ms step_avg:99.37ms +step:1664/1695 train_time:165351ms step_avg:99.37ms +step:1665/1695 train_time:165456ms step_avg:99.37ms +step:1666/1695 train_time:165560ms step_avg:99.38ms +step:1667/1695 train_time:165661ms step_avg:99.38ms +step:1668/1695 train_time:165766ms step_avg:99.38ms +step:1669/1695 train_time:165871ms step_avg:99.38ms +step:1670/1695 train_time:165973ms step_avg:99.39ms +step:1671/1695 train_time:166077ms step_avg:99.39ms +step:1672/1695 train_time:166181ms step_avg:99.39ms +step:1673/1695 train_time:166282ms step_avg:99.39ms +step:1674/1695 train_time:166383ms step_avg:99.39ms +step:1675/1695 train_time:166486ms step_avg:99.39ms +step:1676/1695 train_time:166594ms step_avg:99.40ms +step:1677/1695 train_time:166695ms step_avg:99.40ms +step:1678/1695 train_time:166798ms step_avg:99.40ms +step:1679/1695 train_time:166901ms step_avg:99.40ms +step:1680/1695 train_time:167002ms step_avg:99.41ms +step:1681/1695 train_time:167105ms step_avg:99.41ms +step:1682/1695 train_time:167213ms step_avg:99.41ms +step:1683/1695 train_time:167315ms step_avg:99.41ms +step:1684/1695 train_time:167419ms step_avg:99.42ms +step:1685/1695 train_time:167522ms step_avg:99.42ms +step:1686/1695 train_time:167624ms step_avg:99.42ms +step:1687/1695 train_time:167726ms step_avg:99.42ms +step:1688/1695 train_time:167829ms step_avg:99.42ms +step:1689/1695 train_time:167931ms step_avg:99.43ms +step:1690/1695 train_time:168033ms step_avg:99.43ms +step:1691/1695 train_time:168136ms step_avg:99.43ms +step:1692/1695 train_time:168238ms step_avg:99.43ms +step:1693/1695 train_time:168341ms step_avg:99.43ms +step:1694/1695 train_time:168444ms step_avg:99.44ms +step:1695/1695 train_time:168548ms step_avg:99.44ms +step:1695/1695 val_loss:3.2760 train_time:168647ms step_avg:99.50ms +peak memory allocated: 34761 MiB reserved: 49140 MiB diff --git a/records/082325_SparseAttnGate/50524dcb-cf95-4b75-bf89-ba8ff3c5e1af.txt b/records/082325_SparseAttnGate/50524dcb-cf95-4b75-bf89-ba8ff3c5e1af.txt new file mode 100644 index 000000000..d12ded416 --- /dev/null +++ b/records/082325_SparseAttnGate/50524dcb-cf95-4b75-bf89-ba8ff3c5e1af.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import 
dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, 
w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += 
BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) 
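+ # note (added comment): this kernel fuses the epilogue C = alpha * (A @ A.T) + beta * A, so the beta * A term needs a block of A loaded here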
+ a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
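+
+ Concretely, the orthogonalization here is the quintic Newton-Schulz iteration implemented in
+ newton_schulz_triton above: with A = X @ X.T, each of the 5 steps computes
+ X <- a*X + (b*A + c*A@A) @ X, with (a, b, c) = (3.4445, -4.7750, 2.0315),
+ i.e. each singular value s of X is mapped to a*s + b*s^3 + c*s^5, pushing it toward 1 without an SVD.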
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ hdim = num_heads * head_dim
+ assert hdim == dim, "num_heads * head_dim must equal model_dim"
+ std = 0.5 * (dim ** -0.5)
+ bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+ # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+ # https://x.com/hi_tysam/status/1879699187107033311
+ self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+ with torch.no_grad():
+ self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+ self.qkvo_w[3].zero_() # init output weights to zero
+ self.rotary = Rotary(head_dim, max_seq_len)
+ # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
+ # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+ self.attn_scale = 0.12
+ # sparse attention gate: a small linear layer reads the first dim//dampen_factor channels of x and
+ # produces one sigmoid gate per head; zero init means every gate starts at sigmoid(0) = 0.5
+ self.dampen = CastedLinear(dim//args.dampen_factor, num_heads)
+ self.dampen.weight.detach().zero_()
+
+ def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask):
+ B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim
+ assert B == 1, "Must use batch size = 1 for FlexAttention"
+ dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1)
+ q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+ q, k = norm(q), norm(k) # QK norm @Grad62304977
+ q, k = self.rotary(q), self.rotary(k)
+ if ve is not None:
+ v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+ else: # skip mid-layers token value embeddings by @YouJiacheng
+ v = lambdas[0] * v
+ y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
+ y = y.view(B, T, self.num_heads, self.head_dim)
+ y = y * dampen_factor # apply the per-head sparse attention gate to the attention output
+ y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+ y = F.linear(y, self.qkvo_w[3].type_as(y))
+ return y
+
+class MLP(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ hdim = 4 * dim
+ # make both matrices have the same shape because optimizer sorts params by shape
+ # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
+ self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+ self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+ std = 0.5 * (dim ** -0.5)
+ bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+ with torch.no_grad():
+ self.c_fc.uniform_(-bound, bound)
+ self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+ def forward(self, x: Tensor):
+ x = F.linear(x, self.c_fc.T.type_as(x))
+ x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+ x = F.linear(x, self.c_proj.type_as(x))
+ return x
+
+class Block(nn.Module):
+ def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+ super().__init__()
+ # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+ self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+ self.mlp = MLP(dim)
+
+ def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask):
+ x = lambdas[0] * x + lambdas[1] * x0
+ if self.attn is not None:
+ x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
+ self.lm_head.weight.lr_mul = 1.0
+ self.scalars.lr_mul = 5.0
+
+ def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+ BLOCK_SIZE = 128
+ docs = (input_seq == 50256).cumsum(0)
+ # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+ # docs = increments.cumsum(0)
+
+ def document_causal(b, h, q_idx, kv_idx):
+ causal_mask = q_idx >= kv_idx
+ document_mask = docs[q_idx] == docs[kv_idx]
+ return causal_mask & document_mask
+
+ def dense_to_ordered(dense_blockmask: Tensor):
+ num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+ indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+ return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+ # manual block mask creation by @YouJiacheng
+ assert len(input_seq) % BLOCK_SIZE == 0
+ NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+ block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+ causal_blockmask_any = block_idx[:, None] >= block_idx
+ causal_blockmask_all = block_idx[:, None] > block_idx
+ docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+ docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+ document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+ document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+ blockmask_any = causal_blockmask_any & document_blockmask_any
+ blockmask_all = causal_blockmask_all & document_blockmask_all
+ partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+ full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+ def build_bm(window_size_blocks: Tensor) -> BlockMask:
+ return BlockMask.from_kv_blocks(
+ torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+ partial_kv_indices,
+ torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+ full_kv_indices,
+ BLOCK_SIZE=BLOCK_SIZE,
+ mask_mod=document_causal,
+ )
+ # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+ return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
+
+ def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+ assert input_seq.ndim == 1
+
+ ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+ # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+ ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+ assert len(ve) == len(self.blocks)
+
+ long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+ block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+ assert len(block_masks) == len(self.blocks)
+
+ x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+ # U-net design by @brendanh0gan
+ skip_connections = []
+ skip_weights = self.scalars[:(len(self.blocks) // 2)]
+ lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+ sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+ n = len(self.blocks) // 2
+
+ for i in range(len(self.blocks)):
+ if i >= n:
+ x = x + skip_weights[i - n] * skip_connections.pop()
+ x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+ if i < n:
+ skip_connections.append(x)
+
+ x = norm(x)
+ logits = self.lm_head(x).float()
+ # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+ logits = 30 * torch.sigmoid(logits / 7.5)
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+ return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+ header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+ assert header[1] == 1, "unsupported version"
+ num_tokens = int(header[2]) # number of tokens (claimed)
+ with file.open("rb", buffering=0) as f:
+ tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+ f.seek(256 * 4)
+ nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+ assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+ return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+ boundary_mask = tokens[pos : pos + token_window] == 50256
+ boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+ start = boundary_positions[0].item()
+ starts = []
+ for i in range(1, len(boundary_positions)):
+ end = boundary_positions[i].item()
+ if end - start >= seq_len:
+ starts.append(start) # append start once end pos is confirmed
+ if len(starts) == dist.get_world_size():
+ return starts, end - pos
+ start = end
+ assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ batch_size = seq_len * world_size
+ files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+ file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+ while True:
+ token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length 
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
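# after the all-reduce, every rank holds the same averaged validation loss +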
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:12:16 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 299071 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 299072 C /usr/bin/python3 614MiB | +| 0 N/A N/A 299073 C /usr/bin/python3 614MiB | +| 0 N/A N/A 299074 C /usr/bin/python3 614MiB | +| 0 N/A N/A 299075 C /usr/bin/python3 614MiB | +| 0 N/A N/A 299076 C /usr/bin/python3 614MiB | +| 0 N/A N/A 299077 C /usr/bin/python3 614MiB | +| 0 N/A N/A 299078 C /usr/bin/python3 614MiB | +| 1 N/A N/A 299072 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 299073 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 299074 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 299075 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 299076 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 299077 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 299078 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1695 train_time:157ms step_avg:157.11ms +step:2/1695 train_time:184ms step_avg:92.16ms +step:3/1695 train_time:253ms step_avg:84.19ms +step:4/1695 train_time:345ms step_avg:86.24ms +step:5/1695 train_time:438ms step_avg:87.58ms +step:6/1695 train_time:531ms step_avg:88.50ms +step:7/1695 train_time:624ms step_avg:89.17ms +step:8/1695 
train_time:718ms step_avg:89.74ms +step:9/1695 train_time:811ms step_avg:90.06ms +step:10/1695 train_time:903ms step_avg:90.34ms +step:11/1695 train_time:997ms step_avg:90.61ms +step:12/1695 train_time:1090ms step_avg:90.82ms +step:13/1695 train_time:1184ms step_avg:91.09ms +step:14/1695 train_time:1279ms step_avg:91.33ms +step:15/1695 train_time:1373ms step_avg:91.54ms +step:16/1695 train_time:1467ms step_avg:91.66ms +step:17/1695 train_time:1561ms step_avg:91.80ms +step:18/1695 train_time:1655ms step_avg:91.92ms +step:19/1695 train_time:1748ms step_avg:91.98ms +step:20/1695 train_time:1842ms step_avg:92.10ms +step:21/1695 train_time:1936ms step_avg:92.18ms +step:22/1695 train_time:2030ms step_avg:92.27ms +step:23/1695 train_time:2124ms step_avg:92.34ms +step:24/1695 train_time:2218ms step_avg:92.41ms +step:25/1695 train_time:2311ms step_avg:92.46ms +step:26/1695 train_time:2405ms step_avg:92.50ms +step:27/1695 train_time:2498ms step_avg:92.53ms +step:28/1695 train_time:2593ms step_avg:92.60ms +step:29/1695 train_time:2686ms step_avg:92.64ms +step:30/1695 train_time:2780ms step_avg:92.68ms +step:31/1695 train_time:2874ms step_avg:92.71ms +step:32/1695 train_time:2967ms step_avg:92.72ms +step:33/1695 train_time:3061ms step_avg:92.74ms +step:34/1695 train_time:3154ms step_avg:92.77ms +step:35/1695 train_time:3248ms step_avg:92.80ms +step:36/1695 train_time:3342ms step_avg:92.84ms +step:37/1695 train_time:3437ms step_avg:92.88ms +step:38/1695 train_time:3532ms step_avg:92.96ms +step:39/1695 train_time:3624ms step_avg:92.93ms +step:40/1695 train_time:3719ms step_avg:92.97ms +step:41/1695 train_time:3813ms step_avg:93.01ms +step:42/1695 train_time:3908ms step_avg:93.04ms +step:43/1695 train_time:4002ms step_avg:93.07ms +step:44/1695 train_time:4096ms step_avg:93.08ms +step:45/1695 train_time:4188ms step_avg:93.08ms +step:46/1695 train_time:4282ms step_avg:93.10ms +step:47/1695 train_time:4376ms step_avg:93.11ms +step:48/1695 train_time:4471ms step_avg:93.14ms +step:49/1695 train_time:4564ms step_avg:93.14ms +step:50/1695 train_time:4658ms step_avg:93.17ms +step:51/1695 train_time:4753ms step_avg:93.20ms +step:52/1695 train_time:4847ms step_avg:93.21ms +step:53/1695 train_time:4941ms step_avg:93.23ms +step:54/1695 train_time:5035ms step_avg:93.24ms +step:55/1695 train_time:5128ms step_avg:93.24ms +step:56/1695 train_time:5221ms step_avg:93.24ms +step:57/1695 train_time:5315ms step_avg:93.25ms +step:58/1695 train_time:5410ms step_avg:93.28ms +step:59/1695 train_time:5503ms step_avg:93.27ms +step:60/1695 train_time:5597ms step_avg:93.28ms +step:61/1695 train_time:5690ms step_avg:93.28ms +step:62/1695 train_time:5784ms step_avg:93.29ms +step:63/1695 train_time:5879ms step_avg:93.32ms +step:64/1695 train_time:5973ms step_avg:93.33ms +step:65/1695 train_time:6067ms step_avg:93.34ms +step:66/1695 train_time:6161ms step_avg:93.35ms +step:67/1695 train_time:6254ms step_avg:93.35ms +step:68/1695 train_time:6348ms step_avg:93.35ms +step:69/1695 train_time:6442ms step_avg:93.36ms +step:70/1695 train_time:6536ms step_avg:93.37ms +step:71/1695 train_time:6629ms step_avg:93.36ms +step:72/1695 train_time:6723ms step_avg:93.37ms +step:73/1695 train_time:6817ms step_avg:93.38ms +step:74/1695 train_time:6911ms step_avg:93.39ms +step:75/1695 train_time:7004ms step_avg:93.39ms +step:76/1695 train_time:7099ms step_avg:93.41ms +step:77/1695 train_time:7193ms step_avg:93.41ms +step:78/1695 train_time:7286ms step_avg:93.41ms +step:79/1695 train_time:7380ms step_avg:93.41ms +step:80/1695 train_time:7474ms 
step_avg:93.42ms +step:81/1695 train_time:7567ms step_avg:93.42ms +step:82/1695 train_time:7662ms step_avg:93.44ms +step:83/1695 train_time:7756ms step_avg:93.45ms +step:84/1695 train_time:7851ms step_avg:93.46ms +step:85/1695 train_time:7944ms step_avg:93.46ms +step:86/1695 train_time:8039ms step_avg:93.47ms +step:87/1695 train_time:8132ms step_avg:93.48ms +step:88/1695 train_time:8227ms step_avg:93.48ms +step:89/1695 train_time:8320ms step_avg:93.49ms +step:90/1695 train_time:8414ms step_avg:93.49ms +step:91/1695 train_time:8507ms step_avg:93.48ms +step:92/1695 train_time:8601ms step_avg:93.49ms +step:93/1695 train_time:8694ms step_avg:93.49ms +step:94/1695 train_time:8788ms step_avg:93.49ms +step:95/1695 train_time:8882ms step_avg:93.50ms +step:96/1695 train_time:8976ms step_avg:93.50ms +step:97/1695 train_time:9070ms step_avg:93.51ms +step:98/1695 train_time:9164ms step_avg:93.51ms +step:99/1695 train_time:9258ms step_avg:93.52ms +step:100/1695 train_time:9352ms step_avg:93.52ms +step:101/1695 train_time:9446ms step_avg:93.53ms +step:102/1695 train_time:9541ms step_avg:93.54ms +step:103/1695 train_time:9635ms step_avg:93.54ms +step:104/1695 train_time:9729ms step_avg:93.54ms +step:105/1695 train_time:9822ms step_avg:93.55ms +step:106/1695 train_time:9917ms step_avg:93.55ms +step:107/1695 train_time:10010ms step_avg:93.55ms +step:108/1695 train_time:10104ms step_avg:93.56ms +step:109/1695 train_time:10199ms step_avg:93.57ms +step:110/1695 train_time:10292ms step_avg:93.57ms +step:111/1695 train_time:10385ms step_avg:93.56ms +step:112/1695 train_time:10479ms step_avg:93.56ms +step:113/1695 train_time:10573ms step_avg:93.57ms +step:114/1695 train_time:10667ms step_avg:93.57ms +step:115/1695 train_time:10761ms step_avg:93.57ms +step:116/1695 train_time:10855ms step_avg:93.58ms +step:117/1695 train_time:10949ms step_avg:93.58ms +step:118/1695 train_time:11043ms step_avg:93.59ms +step:119/1695 train_time:11136ms step_avg:93.58ms +step:120/1695 train_time:11231ms step_avg:93.59ms +step:121/1695 train_time:11324ms step_avg:93.59ms +step:122/1695 train_time:11419ms step_avg:93.60ms +step:123/1695 train_time:11512ms step_avg:93.60ms +step:124/1695 train_time:11605ms step_avg:93.59ms +step:125/1695 train_time:11698ms step_avg:93.59ms +step:125/1695 val_loss:4.6089 train_time:11790ms step_avg:94.32ms +step:126/1695 train_time:11820ms step_avg:93.81ms +step:127/1695 train_time:11894ms step_avg:93.65ms +step:128/1695 train_time:11995ms step_avg:93.71ms +step:129/1695 train_time:12092ms step_avg:93.73ms +step:130/1695 train_time:12186ms step_avg:93.74ms +step:131/1695 train_time:12279ms step_avg:93.74ms +step:132/1695 train_time:12373ms step_avg:93.73ms +step:133/1695 train_time:12467ms step_avg:93.73ms +step:134/1695 train_time:12560ms step_avg:93.73ms +step:135/1695 train_time:12653ms step_avg:93.73ms +step:136/1695 train_time:12746ms step_avg:93.72ms +step:137/1695 train_time:12840ms step_avg:93.73ms +step:138/1695 train_time:12935ms step_avg:93.74ms +step:139/1695 train_time:13032ms step_avg:93.76ms +step:140/1695 train_time:13128ms step_avg:93.77ms +step:141/1695 train_time:13222ms step_avg:93.77ms +step:142/1695 train_time:13315ms step_avg:93.77ms +step:143/1695 train_time:13409ms step_avg:93.77ms +step:144/1695 train_time:13503ms step_avg:93.77ms +step:145/1695 train_time:13596ms step_avg:93.77ms +step:146/1695 train_time:13690ms step_avg:93.77ms +step:147/1695 train_time:13784ms step_avg:93.77ms +step:148/1695 train_time:13878ms step_avg:93.77ms +step:149/1695 train_time:13972ms 
step_avg:93.77ms +step:150/1695 train_time:14067ms step_avg:93.78ms +step:151/1695 train_time:14162ms step_avg:93.79ms +step:152/1695 train_time:14257ms step_avg:93.79ms +step:153/1695 train_time:14351ms step_avg:93.80ms +step:154/1695 train_time:14446ms step_avg:93.80ms +step:155/1695 train_time:14539ms step_avg:93.80ms +step:156/1695 train_time:14633ms step_avg:93.80ms +step:157/1695 train_time:14727ms step_avg:93.80ms +step:158/1695 train_time:14820ms step_avg:93.80ms +step:159/1695 train_time:14914ms step_avg:93.80ms +step:160/1695 train_time:15008ms step_avg:93.80ms +step:161/1695 train_time:15102ms step_avg:93.80ms +step:162/1695 train_time:15196ms step_avg:93.80ms +step:163/1695 train_time:15290ms step_avg:93.80ms +step:164/1695 train_time:15384ms step_avg:93.81ms +step:165/1695 train_time:15479ms step_avg:93.81ms +step:166/1695 train_time:15572ms step_avg:93.81ms +step:167/1695 train_time:15667ms step_avg:93.82ms +step:168/1695 train_time:15762ms step_avg:93.82ms +step:169/1695 train_time:15855ms step_avg:93.82ms +step:170/1695 train_time:15949ms step_avg:93.82ms +step:171/1695 train_time:16044ms step_avg:93.82ms +step:172/1695 train_time:16141ms step_avg:93.84ms +step:173/1695 train_time:16232ms step_avg:93.83ms +step:174/1695 train_time:16326ms step_avg:93.83ms +step:175/1695 train_time:16421ms step_avg:93.83ms +step:176/1695 train_time:16515ms step_avg:93.84ms +step:177/1695 train_time:16610ms step_avg:93.84ms +step:178/1695 train_time:16704ms step_avg:93.85ms +step:179/1695 train_time:16798ms step_avg:93.84ms +step:180/1695 train_time:16892ms step_avg:93.84ms +step:181/1695 train_time:16986ms step_avg:93.85ms +step:182/1695 train_time:17080ms step_avg:93.85ms +step:183/1695 train_time:17174ms step_avg:93.85ms +step:184/1695 train_time:17270ms step_avg:93.86ms +step:185/1695 train_time:17365ms step_avg:93.86ms +step:186/1695 train_time:17458ms step_avg:93.86ms +step:187/1695 train_time:17553ms step_avg:93.87ms +step:188/1695 train_time:17648ms step_avg:93.87ms +step:189/1695 train_time:17742ms step_avg:93.87ms +step:190/1695 train_time:17835ms step_avg:93.87ms +step:191/1695 train_time:17930ms step_avg:93.87ms +step:192/1695 train_time:18024ms step_avg:93.88ms +step:193/1695 train_time:18119ms step_avg:93.88ms +step:194/1695 train_time:18213ms step_avg:93.88ms +step:195/1695 train_time:18308ms step_avg:93.89ms +step:196/1695 train_time:18402ms step_avg:93.89ms +step:197/1695 train_time:18496ms step_avg:93.89ms +step:198/1695 train_time:18590ms step_avg:93.89ms +step:199/1695 train_time:18684ms step_avg:93.89ms +step:200/1695 train_time:18777ms step_avg:93.89ms +step:201/1695 train_time:18871ms step_avg:93.89ms +step:202/1695 train_time:18967ms step_avg:93.90ms +step:203/1695 train_time:19062ms step_avg:93.90ms +step:204/1695 train_time:19156ms step_avg:93.90ms +step:205/1695 train_time:19250ms step_avg:93.90ms +step:206/1695 train_time:19344ms step_avg:93.90ms +step:207/1695 train_time:19438ms step_avg:93.90ms +step:208/1695 train_time:19532ms step_avg:93.90ms +step:209/1695 train_time:19626ms step_avg:93.91ms +step:210/1695 train_time:19721ms step_avg:93.91ms +step:211/1695 train_time:19815ms step_avg:93.91ms +step:212/1695 train_time:19908ms step_avg:93.91ms +step:213/1695 train_time:20003ms step_avg:93.91ms +step:214/1695 train_time:20098ms step_avg:93.92ms +step:215/1695 train_time:20192ms step_avg:93.92ms +step:216/1695 train_time:20286ms step_avg:93.92ms +step:217/1695 train_time:20380ms step_avg:93.92ms +step:218/1695 train_time:20475ms step_avg:93.92ms +step:219/1695 
train_time:20569ms step_avg:93.92ms +step:220/1695 train_time:20662ms step_avg:93.92ms +step:221/1695 train_time:20756ms step_avg:93.92ms +step:222/1695 train_time:20850ms step_avg:93.92ms +step:223/1695 train_time:20943ms step_avg:93.92ms +step:224/1695 train_time:21037ms step_avg:93.92ms +step:225/1695 train_time:21131ms step_avg:93.92ms +step:226/1695 train_time:21226ms step_avg:93.92ms +step:227/1695 train_time:21320ms step_avg:93.92ms +step:228/1695 train_time:21414ms step_avg:93.92ms +step:229/1695 train_time:21508ms step_avg:93.92ms +step:230/1695 train_time:21602ms step_avg:93.92ms +step:231/1695 train_time:21695ms step_avg:93.92ms +step:232/1695 train_time:21789ms step_avg:93.92ms +step:233/1695 train_time:21883ms step_avg:93.92ms +step:234/1695 train_time:21977ms step_avg:93.92ms +step:235/1695 train_time:22071ms step_avg:93.92ms +step:236/1695 train_time:22165ms step_avg:93.92ms +step:237/1695 train_time:22258ms step_avg:93.92ms +step:238/1695 train_time:22352ms step_avg:93.92ms +step:239/1695 train_time:22447ms step_avg:93.92ms +step:240/1695 train_time:22541ms step_avg:93.92ms +step:241/1695 train_time:22635ms step_avg:93.92ms +step:242/1695 train_time:22729ms step_avg:93.92ms +step:243/1695 train_time:22824ms step_avg:93.93ms +step:244/1695 train_time:22918ms step_avg:93.93ms +step:245/1695 train_time:23012ms step_avg:93.93ms +step:246/1695 train_time:23106ms step_avg:93.93ms +step:247/1695 train_time:23199ms step_avg:93.92ms +step:248/1695 train_time:23293ms step_avg:93.92ms +step:249/1695 train_time:23387ms step_avg:93.92ms +step:250/1695 train_time:23482ms step_avg:93.93ms +step:250/1695 val_loss:4.0734 train_time:23574ms step_avg:94.30ms +step:251/1695 train_time:23603ms step_avg:94.04ms +step:252/1695 train_time:23679ms step_avg:93.97ms +step:253/1695 train_time:23779ms step_avg:93.99ms +step:254/1695 train_time:23874ms step_avg:93.99ms +step:255/1695 train_time:23968ms step_avg:93.99ms +step:256/1695 train_time:24063ms step_avg:94.00ms +step:257/1695 train_time:24157ms step_avg:93.99ms +step:258/1695 train_time:24250ms step_avg:93.99ms +step:259/1695 train_time:24345ms step_avg:93.99ms +step:260/1695 train_time:24439ms step_avg:93.99ms +step:261/1695 train_time:24532ms step_avg:93.99ms +step:262/1695 train_time:24628ms step_avg:94.00ms +step:263/1695 train_time:24724ms step_avg:94.01ms +step:264/1695 train_time:24819ms step_avg:94.01ms +step:265/1695 train_time:24914ms step_avg:94.02ms +step:266/1695 train_time:25009ms step_avg:94.02ms +step:267/1695 train_time:25104ms step_avg:94.02ms +step:268/1695 train_time:25199ms step_avg:94.03ms +step:269/1695 train_time:25292ms step_avg:94.02ms +step:270/1695 train_time:25386ms step_avg:94.02ms +step:271/1695 train_time:25480ms step_avg:94.02ms +step:272/1695 train_time:25573ms step_avg:94.02ms +step:273/1695 train_time:25669ms step_avg:94.03ms +step:274/1695 train_time:25765ms step_avg:94.03ms +step:275/1695 train_time:25860ms step_avg:94.04ms +step:276/1695 train_time:25955ms step_avg:94.04ms +step:277/1695 train_time:26049ms step_avg:94.04ms +step:278/1695 train_time:26145ms step_avg:94.05ms +step:279/1695 train_time:26239ms step_avg:94.05ms +step:280/1695 train_time:26333ms step_avg:94.05ms +step:281/1695 train_time:26427ms step_avg:94.05ms +step:282/1695 train_time:26521ms step_avg:94.05ms +step:283/1695 train_time:26614ms step_avg:94.04ms +step:284/1695 train_time:26709ms step_avg:94.05ms +step:285/1695 train_time:26805ms step_avg:94.05ms +step:286/1695 train_time:26899ms step_avg:94.05ms +step:287/1695 train_time:26994ms 
step_avg:94.05ms +step:288/1695 train_time:27089ms step_avg:94.06ms +step:289/1695 train_time:27185ms step_avg:94.07ms +step:290/1695 train_time:27280ms step_avg:94.07ms +step:291/1695 train_time:27374ms step_avg:94.07ms +step:292/1695 train_time:27469ms step_avg:94.07ms +step:293/1695 train_time:27564ms step_avg:94.07ms +step:294/1695 train_time:27659ms step_avg:94.08ms +step:295/1695 train_time:27753ms step_avg:94.08ms +step:296/1695 train_time:27847ms step_avg:94.08ms +step:297/1695 train_time:27942ms step_avg:94.08ms +step:298/1695 train_time:28037ms step_avg:94.08ms +step:299/1695 train_time:28131ms step_avg:94.08ms +step:300/1695 train_time:28226ms step_avg:94.09ms +step:301/1695 train_time:28321ms step_avg:94.09ms +step:302/1695 train_time:28415ms step_avg:94.09ms +step:303/1695 train_time:28509ms step_avg:94.09ms +step:304/1695 train_time:28605ms step_avg:94.09ms +step:305/1695 train_time:28698ms step_avg:94.09ms +step:306/1695 train_time:28791ms step_avg:94.09ms +step:307/1695 train_time:28886ms step_avg:94.09ms +step:308/1695 train_time:28981ms step_avg:94.09ms +step:309/1695 train_time:29076ms step_avg:94.10ms +step:310/1695 train_time:29170ms step_avg:94.10ms +step:311/1695 train_time:29265ms step_avg:94.10ms +step:312/1695 train_time:29360ms step_avg:94.10ms +step:313/1695 train_time:29453ms step_avg:94.10ms +step:314/1695 train_time:29548ms step_avg:94.10ms +step:315/1695 train_time:29643ms step_avg:94.10ms +step:316/1695 train_time:29737ms step_avg:94.11ms +step:317/1695 train_time:29831ms step_avg:94.10ms +step:318/1695 train_time:29925ms step_avg:94.11ms +step:319/1695 train_time:30020ms step_avg:94.11ms +step:320/1695 train_time:30113ms step_avg:94.10ms +step:321/1695 train_time:30209ms step_avg:94.11ms +step:322/1695 train_time:30304ms step_avg:94.11ms +step:323/1695 train_time:30398ms step_avg:94.11ms +step:324/1695 train_time:30492ms step_avg:94.11ms +step:325/1695 train_time:30586ms step_avg:94.11ms +step:326/1695 train_time:30681ms step_avg:94.11ms +step:327/1695 train_time:30775ms step_avg:94.11ms +step:328/1695 train_time:30870ms step_avg:94.12ms +step:329/1695 train_time:30965ms step_avg:94.12ms +step:330/1695 train_time:31060ms step_avg:94.12ms +step:331/1695 train_time:31154ms step_avg:94.12ms +step:332/1695 train_time:31248ms step_avg:94.12ms +step:333/1695 train_time:31342ms step_avg:94.12ms +step:334/1695 train_time:31437ms step_avg:94.12ms +step:335/1695 train_time:31531ms step_avg:94.12ms +step:336/1695 train_time:31625ms step_avg:94.12ms +step:337/1695 train_time:31720ms step_avg:94.12ms +step:338/1695 train_time:31814ms step_avg:94.12ms +step:339/1695 train_time:31908ms step_avg:94.12ms +step:340/1695 train_time:32003ms step_avg:94.13ms +step:341/1695 train_time:32097ms step_avg:94.13ms +step:342/1695 train_time:32192ms step_avg:94.13ms +step:343/1695 train_time:32287ms step_avg:94.13ms +step:344/1695 train_time:32382ms step_avg:94.13ms +step:345/1695 train_time:32477ms step_avg:94.13ms +step:346/1695 train_time:32571ms step_avg:94.13ms +step:347/1695 train_time:32666ms step_avg:94.14ms +step:348/1695 train_time:32760ms step_avg:94.14ms +step:349/1695 train_time:32854ms step_avg:94.14ms +step:350/1695 train_time:32949ms step_avg:94.14ms +step:351/1695 train_time:33045ms step_avg:94.14ms +step:352/1695 train_time:33139ms step_avg:94.15ms +step:353/1695 train_time:33233ms step_avg:94.15ms +step:354/1695 train_time:33329ms step_avg:94.15ms +step:355/1695 train_time:33423ms step_avg:94.15ms +step:356/1695 train_time:33517ms step_avg:94.15ms +step:357/1695 
train_time:33611ms step_avg:94.15ms +step:358/1695 train_time:33706ms step_avg:94.15ms +step:359/1695 train_time:33802ms step_avg:94.16ms +step:360/1695 train_time:33897ms step_avg:94.16ms +step:361/1695 train_time:33991ms step_avg:94.16ms +step:362/1695 train_time:34085ms step_avg:94.16ms +step:363/1695 train_time:34181ms step_avg:94.16ms +step:364/1695 train_time:34275ms step_avg:94.16ms +step:365/1695 train_time:34370ms step_avg:94.16ms +step:366/1695 train_time:34465ms step_avg:94.17ms +step:367/1695 train_time:34559ms step_avg:94.17ms +step:368/1695 train_time:34653ms step_avg:94.16ms +step:369/1695 train_time:34747ms step_avg:94.17ms +step:370/1695 train_time:34842ms step_avg:94.17ms +step:371/1695 train_time:34936ms step_avg:94.17ms +step:372/1695 train_time:35030ms step_avg:94.17ms +step:373/1695 train_time:35126ms step_avg:94.17ms +step:374/1695 train_time:35222ms step_avg:94.18ms +step:375/1695 train_time:35315ms step_avg:94.17ms +step:375/1695 val_loss:3.8753 train_time:35407ms step_avg:94.42ms +step:376/1695 train_time:35436ms step_avg:94.24ms +step:377/1695 train_time:35517ms step_avg:94.21ms +step:378/1695 train_time:35616ms step_avg:94.22ms +step:379/1695 train_time:35712ms step_avg:94.23ms +step:380/1695 train_time:35808ms step_avg:94.23ms +step:381/1695 train_time:35905ms step_avg:94.24ms +step:382/1695 train_time:36000ms step_avg:94.24ms +step:383/1695 train_time:36095ms step_avg:94.24ms +step:384/1695 train_time:36191ms step_avg:94.25ms +step:385/1695 train_time:36286ms step_avg:94.25ms +step:386/1695 train_time:36382ms step_avg:94.25ms +step:387/1695 train_time:36479ms step_avg:94.26ms +step:388/1695 train_time:36577ms step_avg:94.27ms +step:389/1695 train_time:36674ms step_avg:94.28ms +step:390/1695 train_time:36770ms step_avg:94.28ms +step:391/1695 train_time:36866ms step_avg:94.29ms +step:392/1695 train_time:36961ms step_avg:94.29ms +step:393/1695 train_time:37057ms step_avg:94.29ms +step:394/1695 train_time:37153ms step_avg:94.30ms +step:395/1695 train_time:37248ms step_avg:94.30ms +step:396/1695 train_time:37344ms step_avg:94.30ms +step:397/1695 train_time:37441ms step_avg:94.31ms +step:398/1695 train_time:37537ms step_avg:94.32ms +step:399/1695 train_time:37634ms step_avg:94.32ms +step:400/1695 train_time:37731ms step_avg:94.33ms +step:401/1695 train_time:37827ms step_avg:94.33ms +step:402/1695 train_time:37924ms step_avg:94.34ms +step:403/1695 train_time:38021ms step_avg:94.34ms +step:404/1695 train_time:38117ms step_avg:94.35ms +step:405/1695 train_time:38214ms step_avg:94.35ms +step:406/1695 train_time:38309ms step_avg:94.36ms +step:407/1695 train_time:38405ms step_avg:94.36ms +step:408/1695 train_time:38501ms step_avg:94.37ms +step:409/1695 train_time:38597ms step_avg:94.37ms +step:410/1695 train_time:38693ms step_avg:94.37ms +step:411/1695 train_time:38789ms step_avg:94.38ms +step:412/1695 train_time:38886ms step_avg:94.38ms +step:413/1695 train_time:38984ms step_avg:94.39ms +step:414/1695 train_time:39081ms step_avg:94.40ms +step:415/1695 train_time:39177ms step_avg:94.40ms +step:416/1695 train_time:39273ms step_avg:94.41ms +step:417/1695 train_time:39369ms step_avg:94.41ms +step:418/1695 train_time:39465ms step_avg:94.41ms +step:419/1695 train_time:39562ms step_avg:94.42ms +step:420/1695 train_time:39659ms step_avg:94.43ms +step:421/1695 train_time:39755ms step_avg:94.43ms +step:422/1695 train_time:39851ms step_avg:94.43ms +step:423/1695 train_time:39948ms step_avg:94.44ms +step:424/1695 train_time:40045ms step_avg:94.44ms +step:425/1695 train_time:40141ms 
step_avg:94.45ms +step:426/1695 train_time:40237ms step_avg:94.45ms +step:427/1695 train_time:40333ms step_avg:94.46ms +step:428/1695 train_time:40428ms step_avg:94.46ms +step:429/1695 train_time:40525ms step_avg:94.46ms +step:430/1695 train_time:40622ms step_avg:94.47ms +step:431/1695 train_time:40718ms step_avg:94.47ms +step:432/1695 train_time:40814ms step_avg:94.48ms +step:433/1695 train_time:40910ms step_avg:94.48ms +step:434/1695 train_time:41006ms step_avg:94.48ms +step:435/1695 train_time:41103ms step_avg:94.49ms +step:436/1695 train_time:41200ms step_avg:94.50ms +step:437/1695 train_time:41297ms step_avg:94.50ms +step:438/1695 train_time:41392ms step_avg:94.50ms +step:439/1695 train_time:41488ms step_avg:94.51ms +step:440/1695 train_time:41585ms step_avg:94.51ms +step:441/1695 train_time:41682ms step_avg:94.52ms +step:442/1695 train_time:41778ms step_avg:94.52ms +step:443/1695 train_time:41874ms step_avg:94.52ms +step:444/1695 train_time:41970ms step_avg:94.53ms +step:445/1695 train_time:42066ms step_avg:94.53ms +step:446/1695 train_time:42163ms step_avg:94.54ms +step:447/1695 train_time:42260ms step_avg:94.54ms +step:448/1695 train_time:42356ms step_avg:94.55ms +step:449/1695 train_time:42454ms step_avg:94.55ms +step:450/1695 train_time:42549ms step_avg:94.55ms +step:451/1695 train_time:42646ms step_avg:94.56ms +step:452/1695 train_time:42742ms step_avg:94.56ms +step:453/1695 train_time:42838ms step_avg:94.57ms +step:454/1695 train_time:42935ms step_avg:94.57ms +step:455/1695 train_time:43031ms step_avg:94.57ms +step:456/1695 train_time:43126ms step_avg:94.58ms +step:457/1695 train_time:43223ms step_avg:94.58ms +step:458/1695 train_time:43320ms step_avg:94.59ms +step:459/1695 train_time:43417ms step_avg:94.59ms +step:460/1695 train_time:43513ms step_avg:94.59ms +step:461/1695 train_time:43609ms step_avg:94.60ms +step:462/1695 train_time:43705ms step_avg:94.60ms +step:463/1695 train_time:43802ms step_avg:94.60ms +step:464/1695 train_time:43898ms step_avg:94.61ms +step:465/1695 train_time:43994ms step_avg:94.61ms +step:466/1695 train_time:44090ms step_avg:94.61ms +step:467/1695 train_time:44186ms step_avg:94.62ms +step:468/1695 train_time:44283ms step_avg:94.62ms +step:469/1695 train_time:44381ms step_avg:94.63ms +step:470/1695 train_time:44477ms step_avg:94.63ms +step:471/1695 train_time:44573ms step_avg:94.64ms +step:472/1695 train_time:44669ms step_avg:94.64ms +step:473/1695 train_time:44765ms step_avg:94.64ms +step:474/1695 train_time:44861ms step_avg:94.64ms +step:475/1695 train_time:44958ms step_avg:94.65ms +step:476/1695 train_time:45055ms step_avg:94.65ms +step:477/1695 train_time:45150ms step_avg:94.65ms +step:478/1695 train_time:45246ms step_avg:94.66ms +step:479/1695 train_time:45343ms step_avg:94.66ms +step:480/1695 train_time:45439ms step_avg:94.67ms +step:481/1695 train_time:45536ms step_avg:94.67ms +step:482/1695 train_time:45632ms step_avg:94.67ms +step:483/1695 train_time:45728ms step_avg:94.67ms +step:484/1695 train_time:45823ms step_avg:94.68ms +step:485/1695 train_time:45920ms step_avg:94.68ms +step:486/1695 train_time:46016ms step_avg:94.68ms +step:487/1695 train_time:46112ms step_avg:94.69ms +step:488/1695 train_time:46208ms step_avg:94.69ms +step:489/1695 train_time:46304ms step_avg:94.69ms +step:490/1695 train_time:46401ms step_avg:94.70ms +step:491/1695 train_time:46498ms step_avg:94.70ms +step:492/1695 train_time:46593ms step_avg:94.70ms +step:493/1695 train_time:46689ms step_avg:94.70ms +step:494/1695 train_time:46785ms step_avg:94.71ms +step:495/1695 
train_time:46882ms step_avg:94.71ms +step:496/1695 train_time:46979ms step_avg:94.72ms +step:497/1695 train_time:47075ms step_avg:94.72ms +step:498/1695 train_time:47170ms step_avg:94.72ms +step:499/1695 train_time:47267ms step_avg:94.72ms +step:500/1695 train_time:47364ms step_avg:94.73ms +step:500/1695 val_loss:3.7308 train_time:47458ms step_avg:94.92ms +step:501/1695 train_time:47487ms step_avg:94.78ms +step:502/1695 train_time:47568ms step_avg:94.76ms +step:503/1695 train_time:47669ms step_avg:94.77ms +step:504/1695 train_time:47766ms step_avg:94.77ms +step:505/1695 train_time:47862ms step_avg:94.78ms +step:506/1695 train_time:47958ms step_avg:94.78ms +step:507/1695 train_time:48053ms step_avg:94.78ms +step:508/1695 train_time:48149ms step_avg:94.78ms +step:509/1695 train_time:48245ms step_avg:94.78ms +step:510/1695 train_time:48341ms step_avg:94.79ms +step:511/1695 train_time:48436ms step_avg:94.79ms +step:512/1695 train_time:48533ms step_avg:94.79ms +step:513/1695 train_time:48631ms step_avg:94.80ms +step:514/1695 train_time:48729ms step_avg:94.80ms +step:515/1695 train_time:48826ms step_avg:94.81ms +step:516/1695 train_time:48923ms step_avg:94.81ms +step:517/1695 train_time:49019ms step_avg:94.82ms +step:518/1695 train_time:49115ms step_avg:94.82ms +step:519/1695 train_time:49211ms step_avg:94.82ms +step:520/1695 train_time:49308ms step_avg:94.82ms +step:521/1695 train_time:49404ms step_avg:94.83ms +step:522/1695 train_time:49501ms step_avg:94.83ms +step:523/1695 train_time:49598ms step_avg:94.83ms +step:524/1695 train_time:49694ms step_avg:94.84ms +step:525/1695 train_time:49792ms step_avg:94.84ms +step:526/1695 train_time:49890ms step_avg:94.85ms +step:527/1695 train_time:49989ms step_avg:94.86ms +step:528/1695 train_time:50085ms step_avg:94.86ms +step:529/1695 train_time:50181ms step_avg:94.86ms +step:530/1695 train_time:50277ms step_avg:94.86ms +step:531/1695 train_time:50373ms step_avg:94.87ms +step:532/1695 train_time:50470ms step_avg:94.87ms +step:533/1695 train_time:50567ms step_avg:94.87ms +step:534/1695 train_time:50665ms step_avg:94.88ms +step:535/1695 train_time:50763ms step_avg:94.88ms +step:536/1695 train_time:50861ms step_avg:94.89ms +step:537/1695 train_time:50956ms step_avg:94.89ms +step:538/1695 train_time:51053ms step_avg:94.89ms +step:539/1695 train_time:51149ms step_avg:94.90ms +step:540/1695 train_time:51246ms step_avg:94.90ms +step:541/1695 train_time:51342ms step_avg:94.90ms +step:542/1695 train_time:51438ms step_avg:94.90ms +step:543/1695 train_time:51533ms step_avg:94.90ms +step:544/1695 train_time:51630ms step_avg:94.91ms +step:545/1695 train_time:51726ms step_avg:94.91ms +step:546/1695 train_time:51824ms step_avg:94.91ms +step:547/1695 train_time:51921ms step_avg:94.92ms +step:548/1695 train_time:52018ms step_avg:94.92ms +step:549/1695 train_time:52114ms step_avg:94.93ms +step:550/1695 train_time:52210ms step_avg:94.93ms +step:551/1695 train_time:52308ms step_avg:94.93ms +step:552/1695 train_time:52404ms step_avg:94.94ms +step:553/1695 train_time:52501ms step_avg:94.94ms +step:554/1695 train_time:52597ms step_avg:94.94ms +step:555/1695 train_time:52694ms step_avg:94.94ms +step:556/1695 train_time:52790ms step_avg:94.95ms +step:557/1695 train_time:52888ms step_avg:94.95ms +step:558/1695 train_time:52986ms step_avg:94.96ms +step:559/1695 train_time:53082ms step_avg:94.96ms +step:560/1695 train_time:53179ms step_avg:94.96ms +step:561/1695 train_time:53275ms step_avg:94.96ms +step:562/1695 train_time:53371ms step_avg:94.97ms +step:563/1695 train_time:53467ms 
step_avg:94.97ms +step:564/1695 train_time:53564ms step_avg:94.97ms +step:565/1695 train_time:53661ms step_avg:94.98ms +step:566/1695 train_time:53758ms step_avg:94.98ms +step:567/1695 train_time:53854ms step_avg:94.98ms +step:568/1695 train_time:53951ms step_avg:94.98ms +step:569/1695 train_time:54048ms step_avg:94.99ms +step:570/1695 train_time:54144ms step_avg:94.99ms +step:571/1695 train_time:54241ms step_avg:94.99ms +step:572/1695 train_time:54337ms step_avg:95.00ms +step:573/1695 train_time:54433ms step_avg:95.00ms +step:574/1695 train_time:54530ms step_avg:95.00ms +step:575/1695 train_time:54626ms step_avg:95.00ms +step:576/1695 train_time:54724ms step_avg:95.01ms +step:577/1695 train_time:54821ms step_avg:95.01ms +step:578/1695 train_time:54918ms step_avg:95.01ms +step:579/1695 train_time:55015ms step_avg:95.02ms +step:580/1695 train_time:55111ms step_avg:95.02ms +step:581/1695 train_time:55208ms step_avg:95.02ms +step:582/1695 train_time:55304ms step_avg:95.02ms +step:583/1695 train_time:55400ms step_avg:95.03ms +step:584/1695 train_time:55497ms step_avg:95.03ms +step:585/1695 train_time:55593ms step_avg:95.03ms +step:586/1695 train_time:55690ms step_avg:95.03ms +step:587/1695 train_time:55787ms step_avg:95.04ms +step:588/1695 train_time:55885ms step_avg:95.04ms +step:589/1695 train_time:55981ms step_avg:95.04ms +step:590/1695 train_time:56077ms step_avg:95.05ms +step:591/1695 train_time:56173ms step_avg:95.05ms +step:592/1695 train_time:56270ms step_avg:95.05ms +step:593/1695 train_time:56367ms step_avg:95.05ms +step:594/1695 train_time:56465ms step_avg:95.06ms +step:595/1695 train_time:56561ms step_avg:95.06ms +step:596/1695 train_time:56657ms step_avg:95.06ms +step:597/1695 train_time:56754ms step_avg:95.06ms +step:598/1695 train_time:56850ms step_avg:95.07ms +step:599/1695 train_time:56946ms step_avg:95.07ms +step:600/1695 train_time:57043ms step_avg:95.07ms +step:601/1695 train_time:57139ms step_avg:95.07ms +step:602/1695 train_time:57235ms step_avg:95.08ms +step:603/1695 train_time:57331ms step_avg:95.08ms +step:604/1695 train_time:57428ms step_avg:95.08ms +step:605/1695 train_time:57526ms step_avg:95.08ms +step:606/1695 train_time:57623ms step_avg:95.09ms +step:607/1695 train_time:57720ms step_avg:95.09ms +step:608/1695 train_time:57815ms step_avg:95.09ms +step:609/1695 train_time:57911ms step_avg:95.09ms +step:610/1695 train_time:58009ms step_avg:95.10ms +step:611/1695 train_time:58107ms step_avg:95.10ms +step:612/1695 train_time:58203ms step_avg:95.10ms +step:613/1695 train_time:58300ms step_avg:95.11ms +step:614/1695 train_time:58396ms step_avg:95.11ms +step:615/1695 train_time:58492ms step_avg:95.11ms +step:616/1695 train_time:58591ms step_avg:95.11ms +step:617/1695 train_time:58689ms step_avg:95.12ms +step:618/1695 train_time:58784ms step_avg:95.12ms +step:619/1695 train_time:58881ms step_avg:95.12ms +step:620/1695 train_time:58977ms step_avg:95.12ms +step:621/1695 train_time:59073ms step_avg:95.13ms +step:622/1695 train_time:59170ms step_avg:95.13ms +step:623/1695 train_time:59267ms step_avg:95.13ms +step:624/1695 train_time:59363ms step_avg:95.13ms +step:625/1695 train_time:59459ms step_avg:95.13ms +step:625/1695 val_loss:3.6477 train_time:59553ms step_avg:95.28ms +step:626/1695 train_time:59581ms step_avg:95.18ms +step:627/1695 train_time:59663ms step_avg:95.16ms +step:628/1695 train_time:59764ms step_avg:95.17ms +step:629/1695 train_time:60094ms step_avg:95.54ms +step:630/1695 train_time:60191ms step_avg:95.54ms +step:631/1695 train_time:60287ms step_avg:95.54ms 
+step:632/1695 train_time:60671ms step_avg:96.00ms +step:633/1695 train_time:60767ms step_avg:96.00ms +step:634/1695 train_time:60865ms step_avg:96.00ms +step:635/1695 train_time:61199ms step_avg:96.38ms +step:636/1695 train_time:61295ms step_avg:96.38ms +step:637/1695 train_time:61392ms step_avg:96.38ms +step:638/1695 train_time:61489ms step_avg:96.38ms +step:639/1695 train_time:61586ms step_avg:96.38ms +step:640/1695 train_time:61684ms step_avg:96.38ms +step:641/1695 train_time:61780ms step_avg:96.38ms +step:642/1695 train_time:61877ms step_avg:96.38ms +step:643/1695 train_time:61974ms step_avg:96.38ms +step:644/1695 train_time:62073ms step_avg:96.39ms +step:645/1695 train_time:62174ms step_avg:96.39ms +step:646/1695 train_time:62274ms step_avg:96.40ms +step:647/1695 train_time:62606ms step_avg:96.76ms +step:648/1695 train_time:62702ms step_avg:96.76ms +step:649/1695 train_time:62799ms step_avg:96.76ms +step:650/1695 train_time:62896ms step_avg:96.76ms +step:651/1695 train_time:62993ms step_avg:96.76ms +step:652/1695 train_time:63091ms step_avg:96.76ms +step:653/1695 train_time:63577ms step_avg:97.36ms +step:654/1695 train_time:63626ms step_avg:97.29ms +step:655/1695 train_time:63722ms step_avg:97.29ms +step:656/1695 train_time:63819ms step_avg:97.28ms +step:657/1695 train_time:63916ms step_avg:97.28ms +step:658/1695 train_time:64012ms step_avg:97.28ms +step:659/1695 train_time:64110ms step_avg:97.28ms +step:660/1695 train_time:64207ms step_avg:97.28ms +step:661/1695 train_time:64304ms step_avg:97.28ms +step:662/1695 train_time:64401ms step_avg:97.28ms +step:663/1695 train_time:64501ms step_avg:97.29ms +step:664/1695 train_time:64602ms step_avg:97.29ms +step:665/1695 train_time:64701ms step_avg:97.29ms +step:666/1695 train_time:64799ms step_avg:97.30ms +step:667/1695 train_time:64895ms step_avg:97.29ms +step:668/1695 train_time:64992ms step_avg:97.29ms +step:669/1695 train_time:65090ms step_avg:97.29ms +step:670/1695 train_time:65187ms step_avg:97.29ms +step:671/1695 train_time:65284ms step_avg:97.29ms +step:672/1695 train_time:65381ms step_avg:97.29ms +step:673/1695 train_time:65479ms step_avg:97.29ms +step:674/1695 train_time:65577ms step_avg:97.30ms +step:675/1695 train_time:65676ms step_avg:97.30ms +step:676/1695 train_time:65775ms step_avg:97.30ms +step:677/1695 train_time:65873ms step_avg:97.30ms +step:678/1695 train_time:65971ms step_avg:97.30ms +step:679/1695 train_time:66070ms step_avg:97.30ms +step:680/1695 train_time:66167ms step_avg:97.31ms +step:681/1695 train_time:66265ms step_avg:97.30ms +step:682/1695 train_time:66362ms step_avg:97.31ms +step:683/1695 train_time:66460ms step_avg:97.31ms +step:684/1695 train_time:66558ms step_avg:97.31ms +step:685/1695 train_time:66655ms step_avg:97.31ms +step:686/1695 train_time:66753ms step_avg:97.31ms +step:687/1695 train_time:66852ms step_avg:97.31ms +step:688/1695 train_time:66950ms step_avg:97.31ms +step:689/1695 train_time:67049ms step_avg:97.31ms +step:690/1695 train_time:67146ms step_avg:97.31ms +step:691/1695 train_time:67244ms step_avg:97.31ms +step:692/1695 train_time:67342ms step_avg:97.31ms +step:693/1695 train_time:67440ms step_avg:97.32ms +step:694/1695 train_time:67537ms step_avg:97.32ms +step:695/1695 train_time:67634ms step_avg:97.32ms +step:696/1695 train_time:67733ms step_avg:97.32ms +step:697/1695 train_time:67831ms step_avg:97.32ms +step:698/1695 train_time:67929ms step_avg:97.32ms +step:699/1695 train_time:68027ms step_avg:97.32ms +step:700/1695 train_time:68124ms step_avg:97.32ms +step:701/1695 train_time:68222ms 
step_avg:97.32ms +step:702/1695 train_time:68319ms step_avg:97.32ms +step:703/1695 train_time:68417ms step_avg:97.32ms +step:704/1695 train_time:68514ms step_avg:97.32ms +step:705/1695 train_time:68613ms step_avg:97.32ms +step:706/1695 train_time:68712ms step_avg:97.33ms +step:707/1695 train_time:68810ms step_avg:97.33ms +step:708/1695 train_time:68908ms step_avg:97.33ms +step:709/1695 train_time:69006ms step_avg:97.33ms +step:710/1695 train_time:69103ms step_avg:97.33ms +step:711/1695 train_time:69200ms step_avg:97.33ms +step:712/1695 train_time:69298ms step_avg:97.33ms +step:713/1695 train_time:69395ms step_avg:97.33ms +step:714/1695 train_time:69494ms step_avg:97.33ms +step:715/1695 train_time:69591ms step_avg:97.33ms +step:716/1695 train_time:69690ms step_avg:97.33ms +step:717/1695 train_time:69789ms step_avg:97.33ms +step:718/1695 train_time:69886ms step_avg:97.33ms +step:719/1695 train_time:69984ms step_avg:97.33ms +step:720/1695 train_time:70081ms step_avg:97.33ms +step:721/1695 train_time:70179ms step_avg:97.34ms +step:722/1695 train_time:70277ms step_avg:97.34ms +step:723/1695 train_time:70374ms step_avg:97.34ms +step:724/1695 train_time:70473ms step_avg:97.34ms +step:725/1695 train_time:70571ms step_avg:97.34ms +step:726/1695 train_time:70669ms step_avg:97.34ms +step:727/1695 train_time:70767ms step_avg:97.34ms +step:728/1695 train_time:70864ms step_avg:97.34ms +step:729/1695 train_time:70962ms step_avg:97.34ms +step:730/1695 train_time:71060ms step_avg:97.34ms +step:731/1695 train_time:71158ms step_avg:97.34ms +step:732/1695 train_time:71256ms step_avg:97.34ms +step:733/1695 train_time:71353ms step_avg:97.34ms +step:734/1695 train_time:71451ms step_avg:97.34ms +step:735/1695 train_time:71549ms step_avg:97.35ms +step:736/1695 train_time:71647ms step_avg:97.35ms +step:737/1695 train_time:71744ms step_avg:97.35ms +step:738/1695 train_time:71842ms step_avg:97.35ms +step:739/1695 train_time:71939ms step_avg:97.35ms +step:740/1695 train_time:72037ms step_avg:97.35ms +step:741/1695 train_time:72135ms step_avg:97.35ms +step:742/1695 train_time:72233ms step_avg:97.35ms +step:743/1695 train_time:72332ms step_avg:97.35ms +step:744/1695 train_time:72430ms step_avg:97.35ms +step:745/1695 train_time:72528ms step_avg:97.35ms +step:746/1695 train_time:72626ms step_avg:97.35ms +step:747/1695 train_time:72724ms step_avg:97.36ms +step:748/1695 train_time:72822ms step_avg:97.36ms +step:749/1695 train_time:72920ms step_avg:97.36ms +step:750/1695 train_time:73018ms step_avg:97.36ms +step:750/1695 val_loss:3.5859 train_time:73112ms step_avg:97.48ms +step:751/1695 train_time:73142ms step_avg:97.39ms +step:752/1695 train_time:73224ms step_avg:97.37ms +step:753/1695 train_time:73324ms step_avg:97.38ms +step:754/1695 train_time:73422ms step_avg:97.38ms +step:755/1695 train_time:73519ms step_avg:97.38ms +step:756/1695 train_time:73617ms step_avg:97.38ms +step:757/1695 train_time:73714ms step_avg:97.38ms +step:758/1695 train_time:73812ms step_avg:97.38ms +step:759/1695 train_time:73909ms step_avg:97.38ms +step:760/1695 train_time:74228ms step_avg:97.67ms +step:761/1695 train_time:74324ms step_avg:97.67ms +step:762/1695 train_time:74422ms step_avg:97.67ms +step:763/1695 train_time:74520ms step_avg:97.67ms +step:764/1695 train_time:74617ms step_avg:97.67ms +step:765/1695 train_time:74715ms step_avg:97.67ms +step:766/1695 train_time:74812ms step_avg:97.67ms +step:767/1695 train_time:74910ms step_avg:97.67ms +step:768/1695 train_time:75007ms step_avg:97.67ms +step:769/1695 train_time:75104ms step_avg:97.66ms 
+step:770/1695 train_time:75203ms step_avg:97.67ms +step:771/1695 train_time:75302ms step_avg:97.67ms +step:772/1695 train_time:75401ms step_avg:97.67ms +step:773/1695 train_time:75499ms step_avg:97.67ms +step:774/1695 train_time:75597ms step_avg:97.67ms +step:775/1695 train_time:75695ms step_avg:97.67ms +step:776/1695 train_time:75792ms step_avg:97.67ms +step:777/1695 train_time:75889ms step_avg:97.67ms +step:778/1695 train_time:75988ms step_avg:97.67ms +step:779/1695 train_time:76086ms step_avg:97.67ms +step:780/1695 train_time:76184ms step_avg:97.67ms +step:781/1695 train_time:76282ms step_avg:97.67ms +step:782/1695 train_time:76381ms step_avg:97.67ms +step:783/1695 train_time:76480ms step_avg:97.68ms +step:784/1695 train_time:76578ms step_avg:97.68ms +step:785/1695 train_time:76677ms step_avg:97.68ms +step:786/1695 train_time:76775ms step_avg:97.68ms +step:787/1695 train_time:76872ms step_avg:97.68ms +step:788/1695 train_time:76969ms step_avg:97.68ms +step:789/1695 train_time:77068ms step_avg:97.68ms +step:790/1695 train_time:77167ms step_avg:97.68ms +step:791/1695 train_time:77265ms step_avg:97.68ms +step:792/1695 train_time:77363ms step_avg:97.68ms +step:793/1695 train_time:77461ms step_avg:97.68ms +step:794/1695 train_time:77559ms step_avg:97.68ms +step:795/1695 train_time:77656ms step_avg:97.68ms +step:796/1695 train_time:77755ms step_avg:97.68ms +step:797/1695 train_time:78197ms step_avg:98.11ms +step:798/1695 train_time:78293ms step_avg:98.11ms +step:799/1695 train_time:78390ms step_avg:98.11ms +step:800/1695 train_time:78488ms step_avg:98.11ms +step:801/1695 train_time:78584ms step_avg:98.11ms +step:802/1695 train_time:78681ms step_avg:98.11ms +step:803/1695 train_time:78778ms step_avg:98.11ms +step:804/1695 train_time:78877ms step_avg:98.11ms +step:805/1695 train_time:78974ms step_avg:98.10ms +step:806/1695 train_time:79073ms step_avg:98.11ms +step:807/1695 train_time:79177ms step_avg:98.11ms +step:808/1695 train_time:79277ms step_avg:98.12ms +step:809/1695 train_time:79376ms step_avg:98.12ms +step:810/1695 train_time:79475ms step_avg:98.12ms +step:811/1695 train_time:79575ms step_avg:98.12ms +step:812/1695 train_time:79673ms step_avg:98.12ms +step:813/1695 train_time:79771ms step_avg:98.12ms +step:814/1695 train_time:79869ms step_avg:98.12ms +step:815/1695 train_time:79966ms step_avg:98.12ms +step:816/1695 train_time:80064ms step_avg:98.12ms +step:817/1695 train_time:80162ms step_avg:98.12ms +step:818/1695 train_time:80260ms step_avg:98.12ms +step:819/1695 train_time:80359ms step_avg:98.12ms +step:820/1695 train_time:80457ms step_avg:98.12ms +step:821/1695 train_time:80556ms step_avg:98.12ms +step:822/1695 train_time:80654ms step_avg:98.12ms +step:823/1695 train_time:80753ms step_avg:98.12ms +step:824/1695 train_time:80852ms step_avg:98.12ms +step:825/1695 train_time:80951ms step_avg:98.12ms +step:826/1695 train_time:81048ms step_avg:98.12ms +step:827/1695 train_time:81148ms step_avg:98.12ms +step:828/1695 train_time:81246ms step_avg:98.12ms +step:829/1695 train_time:81344ms step_avg:98.12ms +step:830/1695 train_time:81442ms step_avg:98.12ms +step:831/1695 train_time:81540ms step_avg:98.12ms +step:832/1695 train_time:81639ms step_avg:98.12ms +step:833/1695 train_time:81736ms step_avg:98.12ms +step:834/1695 train_time:81835ms step_avg:98.12ms +step:835/1695 train_time:81933ms step_avg:98.12ms +step:836/1695 train_time:82032ms step_avg:98.12ms +step:837/1695 train_time:82130ms step_avg:98.12ms +step:838/1695 train_time:82229ms step_avg:98.13ms +step:839/1695 train_time:82327ms 
step_avg:98.13ms +step:840/1695 train_time:82425ms step_avg:98.13ms +step:841/1695 train_time:82522ms step_avg:98.12ms +step:842/1695 train_time:82620ms step_avg:98.12ms +step:843/1695 train_time:82718ms step_avg:98.12ms +step:844/1695 train_time:82816ms step_avg:98.12ms +step:845/1695 train_time:82915ms step_avg:98.12ms +step:846/1695 train_time:83013ms step_avg:98.12ms +step:847/1695 train_time:83112ms step_avg:98.13ms +step:848/1695 train_time:83211ms step_avg:98.13ms +step:849/1695 train_time:83309ms step_avg:98.13ms +step:850/1695 train_time:83408ms step_avg:98.13ms +step:851/1695 train_time:83507ms step_avg:98.13ms +step:852/1695 train_time:83605ms step_avg:98.13ms +step:853/1695 train_time:83703ms step_avg:98.13ms +step:854/1695 train_time:83801ms step_avg:98.13ms +step:855/1695 train_time:83898ms step_avg:98.13ms +step:856/1695 train_time:83997ms step_avg:98.13ms +step:857/1695 train_time:84095ms step_avg:98.13ms +step:858/1695 train_time:84194ms step_avg:98.13ms +step:859/1695 train_time:84293ms step_avg:98.13ms +step:860/1695 train_time:84393ms step_avg:98.13ms +step:861/1695 train_time:84492ms step_avg:98.13ms +step:862/1695 train_time:84591ms step_avg:98.13ms +step:863/1695 train_time:84690ms step_avg:98.13ms +step:864/1695 train_time:84789ms step_avg:98.14ms +step:865/1695 train_time:84888ms step_avg:98.14ms +step:866/1695 train_time:84987ms step_avg:98.14ms +step:867/1695 train_time:85085ms step_avg:98.14ms +step:868/1695 train_time:85182ms step_avg:98.14ms +step:869/1695 train_time:85280ms step_avg:98.14ms +step:870/1695 train_time:85379ms step_avg:98.14ms +step:871/1695 train_time:85479ms step_avg:98.14ms +step:872/1695 train_time:85578ms step_avg:98.14ms +step:873/1695 train_time:85678ms step_avg:98.14ms +step:874/1695 train_time:85777ms step_avg:98.14ms +step:875/1695 train_time:85876ms step_avg:98.14ms +step:875/1695 val_loss:3.5367 train_time:85973ms step_avg:98.25ms +step:876/1695 train_time:86005ms step_avg:98.18ms +step:877/1695 train_time:86083ms step_avg:98.16ms +step:878/1695 train_time:86183ms step_avg:98.16ms +step:879/1695 train_time:86281ms step_avg:98.16ms +step:880/1695 train_time:86380ms step_avg:98.16ms +step:881/1695 train_time:86479ms step_avg:98.16ms +step:882/1695 train_time:86578ms step_avg:98.16ms +step:883/1695 train_time:86677ms step_avg:98.16ms +step:884/1695 train_time:86776ms step_avg:98.16ms +step:885/1695 train_time:86875ms step_avg:98.16ms +step:886/1695 train_time:86975ms step_avg:98.17ms +step:887/1695 train_time:87077ms step_avg:98.17ms +step:888/1695 train_time:87178ms step_avg:98.17ms +step:889/1695 train_time:87278ms step_avg:98.18ms +step:890/1695 train_time:87378ms step_avg:98.18ms +step:891/1695 train_time:87478ms step_avg:98.18ms +step:892/1695 train_time:87577ms step_avg:98.18ms +step:893/1695 train_time:87677ms step_avg:98.18ms +step:894/1695 train_time:87775ms step_avg:98.18ms +step:895/1695 train_time:87874ms step_avg:98.18ms +step:896/1695 train_time:87974ms step_avg:98.19ms +step:897/1695 train_time:88074ms step_avg:98.19ms +step:898/1695 train_time:88174ms step_avg:98.19ms +step:899/1695 train_time:88275ms step_avg:98.19ms +step:900/1695 train_time:88376ms step_avg:98.20ms +step:901/1695 train_time:88476ms step_avg:98.20ms +step:902/1695 train_time:88576ms step_avg:98.20ms +step:903/1695 train_time:88676ms step_avg:98.20ms +step:904/1695 train_time:88776ms step_avg:98.20ms +step:905/1695 train_time:88875ms step_avg:98.20ms +step:906/1695 train_time:88975ms step_avg:98.21ms +step:907/1695 train_time:89076ms step_avg:98.21ms 
+step:908/1695 train_time:89177ms step_avg:98.21ms +step:909/1695 train_time:89276ms step_avg:98.21ms +step:910/1695 train_time:89376ms step_avg:98.22ms +step:911/1695 train_time:89476ms step_avg:98.22ms +step:912/1695 train_time:89575ms step_avg:98.22ms +step:913/1695 train_time:89675ms step_avg:98.22ms +step:914/1695 train_time:89774ms step_avg:98.22ms +step:915/1695 train_time:89873ms step_avg:98.22ms +step:916/1695 train_time:89973ms step_avg:98.22ms +step:917/1695 train_time:90074ms step_avg:98.23ms +step:918/1695 train_time:90174ms step_avg:98.23ms +step:919/1695 train_time:90275ms step_avg:98.23ms +step:920/1695 train_time:90377ms step_avg:98.24ms +step:921/1695 train_time:90477ms step_avg:98.24ms +step:922/1695 train_time:90577ms step_avg:98.24ms +step:923/1695 train_time:90677ms step_avg:98.24ms +step:924/1695 train_time:90776ms step_avg:98.24ms +step:925/1695 train_time:90876ms step_avg:98.24ms +step:926/1695 train_time:90976ms step_avg:98.25ms +step:927/1695 train_time:91077ms step_avg:98.25ms +step:928/1695 train_time:91177ms step_avg:98.25ms +step:929/1695 train_time:91278ms step_avg:98.25ms +step:930/1695 train_time:91378ms step_avg:98.26ms +step:931/1695 train_time:91478ms step_avg:98.26ms +step:932/1695 train_time:91578ms step_avg:98.26ms +step:933/1695 train_time:91677ms step_avg:98.26ms +step:934/1695 train_time:91776ms step_avg:98.26ms +step:935/1695 train_time:91876ms step_avg:98.26ms +step:936/1695 train_time:91975ms step_avg:98.26ms +step:937/1695 train_time:92075ms step_avg:98.27ms +step:938/1695 train_time:92175ms step_avg:98.27ms +step:939/1695 train_time:92276ms step_avg:98.27ms +step:940/1695 train_time:92376ms step_avg:98.27ms +step:941/1695 train_time:92477ms step_avg:98.28ms +step:942/1695 train_time:92576ms step_avg:98.28ms +step:943/1695 train_time:92677ms step_avg:98.28ms +step:944/1695 train_time:92777ms step_avg:98.28ms +step:945/1695 train_time:92877ms step_avg:98.28ms +step:946/1695 train_time:92976ms step_avg:98.28ms +step:947/1695 train_time:93076ms step_avg:98.29ms +step:948/1695 train_time:93176ms step_avg:98.29ms +step:949/1695 train_time:93276ms step_avg:98.29ms +step:950/1695 train_time:93376ms step_avg:98.29ms +step:951/1695 train_time:93476ms step_avg:98.29ms +step:952/1695 train_time:93576ms step_avg:98.29ms +step:953/1695 train_time:93676ms step_avg:98.30ms +step:954/1695 train_time:93776ms step_avg:98.30ms +step:955/1695 train_time:93876ms step_avg:98.30ms +step:956/1695 train_time:93975ms step_avg:98.30ms +step:957/1695 train_time:94076ms step_avg:98.30ms +step:958/1695 train_time:94176ms step_avg:98.30ms +step:959/1695 train_time:94275ms step_avg:98.31ms +step:960/1695 train_time:94374ms step_avg:98.31ms +step:961/1695 train_time:94475ms step_avg:98.31ms +step:962/1695 train_time:94576ms step_avg:98.31ms +step:963/1695 train_time:94677ms step_avg:98.31ms +step:964/1695 train_time:94777ms step_avg:98.32ms +step:965/1695 train_time:94878ms step_avg:98.32ms +step:966/1695 train_time:94978ms step_avg:98.32ms +step:967/1695 train_time:95078ms step_avg:98.32ms +step:968/1695 train_time:95178ms step_avg:98.32ms +step:969/1695 train_time:95278ms step_avg:98.33ms +step:970/1695 train_time:95377ms step_avg:98.33ms +step:971/1695 train_time:95477ms step_avg:98.33ms +step:972/1695 train_time:95577ms step_avg:98.33ms +step:973/1695 train_time:95677ms step_avg:98.33ms +step:974/1695 train_time:95777ms step_avg:98.33ms +step:975/1695 train_time:95877ms step_avg:98.34ms +step:976/1695 train_time:95977ms step_avg:98.34ms +step:977/1695 train_time:96077ms 
step_avg:98.34ms +step:978/1695 train_time:96176ms step_avg:98.34ms +step:979/1695 train_time:96277ms step_avg:98.34ms +step:980/1695 train_time:96377ms step_avg:98.34ms +step:981/1695 train_time:96477ms step_avg:98.35ms +step:982/1695 train_time:96577ms step_avg:98.35ms +step:983/1695 train_time:96678ms step_avg:98.35ms +step:984/1695 train_time:96778ms step_avg:98.35ms +step:985/1695 train_time:96877ms step_avg:98.35ms +step:986/1695 train_time:96978ms step_avg:98.35ms +step:987/1695 train_time:97078ms step_avg:98.36ms +step:988/1695 train_time:97178ms step_avg:98.36ms +step:989/1695 train_time:97277ms step_avg:98.36ms +step:990/1695 train_time:97377ms step_avg:98.36ms +step:991/1695 train_time:97477ms step_avg:98.36ms +step:992/1695 train_time:97576ms step_avg:98.36ms +step:993/1695 train_time:97676ms step_avg:98.37ms +step:994/1695 train_time:97776ms step_avg:98.37ms +step:995/1695 train_time:97877ms step_avg:98.37ms +step:996/1695 train_time:97976ms step_avg:98.37ms +step:997/1695 train_time:98077ms step_avg:98.37ms +step:998/1695 train_time:98176ms step_avg:98.37ms +step:999/1695 train_time:98277ms step_avg:98.37ms +step:1000/1695 train_time:98376ms step_avg:98.38ms +step:1000/1695 val_loss:3.4912 train_time:98474ms step_avg:98.47ms +step:1001/1695 train_time:98503ms step_avg:98.40ms +step:1002/1695 train_time:98585ms step_avg:98.39ms +step:1003/1695 train_time:98687ms step_avg:98.39ms +step:1004/1695 train_time:98788ms step_avg:98.39ms +step:1005/1695 train_time:98887ms step_avg:98.40ms +step:1006/1695 train_time:98986ms step_avg:98.40ms +step:1007/1695 train_time:99085ms step_avg:98.40ms +step:1008/1695 train_time:99184ms step_avg:98.40ms +step:1009/1695 train_time:99283ms step_avg:98.40ms +step:1010/1695 train_time:99382ms step_avg:98.40ms +step:1011/1695 train_time:99483ms step_avg:98.40ms +step:1012/1695 train_time:99587ms step_avg:98.41ms +step:1013/1695 train_time:99689ms step_avg:98.41ms +step:1014/1695 train_time:99789ms step_avg:98.41ms +step:1015/1695 train_time:99888ms step_avg:98.41ms +step:1016/1695 train_time:99987ms step_avg:98.41ms +step:1017/1695 train_time:100087ms step_avg:98.41ms +step:1018/1695 train_time:100185ms step_avg:98.41ms +step:1019/1695 train_time:100284ms step_avg:98.41ms +step:1020/1695 train_time:100384ms step_avg:98.42ms +step:1021/1695 train_time:100485ms step_avg:98.42ms +step:1022/1695 train_time:100586ms step_avg:98.42ms +step:1023/1695 train_time:100687ms step_avg:98.42ms +step:1024/1695 train_time:100790ms step_avg:98.43ms +step:1025/1695 train_time:100891ms step_avg:98.43ms +step:1026/1695 train_time:100991ms step_avg:98.43ms +step:1027/1695 train_time:101090ms step_avg:98.43ms +step:1028/1695 train_time:101189ms step_avg:98.43ms +step:1029/1695 train_time:101290ms step_avg:98.44ms +step:1030/1695 train_time:101388ms step_avg:98.44ms +step:1031/1695 train_time:101488ms step_avg:98.44ms +step:1032/1695 train_time:101588ms step_avg:98.44ms +step:1033/1695 train_time:101688ms step_avg:98.44ms +step:1034/1695 train_time:101788ms step_avg:98.44ms +step:1035/1695 train_time:101888ms step_avg:98.44ms +step:1036/1695 train_time:101987ms step_avg:98.44ms +step:1037/1695 train_time:102088ms step_avg:98.45ms +step:1038/1695 train_time:102187ms step_avg:98.45ms +step:1039/1695 train_time:102286ms step_avg:98.45ms +step:1040/1695 train_time:102386ms step_avg:98.45ms +step:1041/1695 train_time:102487ms step_avg:98.45ms +step:1042/1695 train_time:102587ms step_avg:98.45ms +step:1043/1695 train_time:102689ms step_avg:98.46ms +step:1044/1695 
train_time:102789ms step_avg:98.46ms +step:1045/1695 train_time:102888ms step_avg:98.46ms +step:1046/1695 train_time:102989ms step_avg:98.46ms +step:1047/1695 train_time:103088ms step_avg:98.46ms +step:1048/1695 train_time:103188ms step_avg:98.46ms +step:1049/1695 train_time:103287ms step_avg:98.46ms +step:1050/1695 train_time:103386ms step_avg:98.46ms +step:1051/1695 train_time:103486ms step_avg:98.46ms +step:1052/1695 train_time:103587ms step_avg:98.47ms +step:1053/1695 train_time:103688ms step_avg:98.47ms +step:1054/1695 train_time:103788ms step_avg:98.47ms +step:1055/1695 train_time:103889ms step_avg:98.47ms +step:1056/1695 train_time:103989ms step_avg:98.47ms +step:1057/1695 train_time:104088ms step_avg:98.48ms +step:1058/1695 train_time:104188ms step_avg:98.48ms +step:1059/1695 train_time:104286ms step_avg:98.48ms +step:1060/1695 train_time:104386ms step_avg:98.48ms +step:1061/1695 train_time:104485ms step_avg:98.48ms +step:1062/1695 train_time:104585ms step_avg:98.48ms +step:1063/1695 train_time:104686ms step_avg:98.48ms +step:1064/1695 train_time:104787ms step_avg:98.48ms +step:1065/1695 train_time:104888ms step_avg:98.49ms +step:1066/1695 train_time:104988ms step_avg:98.49ms +step:1067/1695 train_time:105088ms step_avg:98.49ms +step:1068/1695 train_time:105188ms step_avg:98.49ms +step:1069/1695 train_time:105287ms step_avg:98.49ms +step:1070/1695 train_time:105388ms step_avg:98.49ms +step:1071/1695 train_time:105488ms step_avg:98.49ms +step:1072/1695 train_time:105588ms step_avg:98.50ms +step:1073/1695 train_time:105687ms step_avg:98.50ms +step:1074/1695 train_time:105787ms step_avg:98.50ms +step:1075/1695 train_time:105887ms step_avg:98.50ms +step:1076/1695 train_time:105987ms step_avg:98.50ms +step:1077/1695 train_time:106087ms step_avg:98.50ms +step:1078/1695 train_time:106188ms step_avg:98.50ms +step:1079/1695 train_time:106288ms step_avg:98.51ms +step:1080/1695 train_time:106387ms step_avg:98.51ms +step:1081/1695 train_time:106487ms step_avg:98.51ms +step:1082/1695 train_time:106587ms step_avg:98.51ms +step:1083/1695 train_time:106687ms step_avg:98.51ms +step:1084/1695 train_time:106787ms step_avg:98.51ms +step:1085/1695 train_time:106887ms step_avg:98.51ms +step:1086/1695 train_time:106988ms step_avg:98.52ms +step:1087/1695 train_time:107088ms step_avg:98.52ms +step:1088/1695 train_time:107188ms step_avg:98.52ms +step:1089/1695 train_time:107288ms step_avg:98.52ms +step:1090/1695 train_time:107388ms step_avg:98.52ms +step:1091/1695 train_time:107488ms step_avg:98.52ms +step:1092/1695 train_time:107588ms step_avg:98.52ms +step:1093/1695 train_time:107687ms step_avg:98.52ms +step:1094/1695 train_time:107787ms step_avg:98.53ms +step:1095/1695 train_time:107887ms step_avg:98.53ms +step:1096/1695 train_time:107987ms step_avg:98.53ms +step:1097/1695 train_time:108087ms step_avg:98.53ms +step:1098/1695 train_time:108187ms step_avg:98.53ms +step:1099/1695 train_time:108287ms step_avg:98.53ms +step:1100/1695 train_time:108387ms step_avg:98.53ms +step:1101/1695 train_time:108486ms step_avg:98.53ms +step:1102/1695 train_time:108587ms step_avg:98.54ms +step:1103/1695 train_time:108686ms step_avg:98.54ms +step:1104/1695 train_time:108786ms step_avg:98.54ms +step:1105/1695 train_time:108886ms step_avg:98.54ms +step:1106/1695 train_time:108986ms step_avg:98.54ms +step:1107/1695 train_time:109086ms step_avg:98.54ms +step:1108/1695 train_time:109187ms step_avg:98.54ms +step:1109/1695 train_time:109286ms step_avg:98.54ms +step:1110/1695 train_time:109387ms step_avg:98.55ms +step:1111/1695 
train_time:109487ms step_avg:98.55ms +step:1112/1695 train_time:109587ms step_avg:98.55ms +step:1113/1695 train_time:109687ms step_avg:98.55ms +step:1114/1695 train_time:109787ms step_avg:98.55ms +step:1115/1695 train_time:109887ms step_avg:98.55ms +step:1116/1695 train_time:109988ms step_avg:98.56ms +step:1117/1695 train_time:110089ms step_avg:98.56ms +step:1118/1695 train_time:110189ms step_avg:98.56ms +step:1119/1695 train_time:110288ms step_avg:98.56ms +step:1120/1695 train_time:110388ms step_avg:98.56ms +step:1121/1695 train_time:110488ms step_avg:98.56ms +step:1122/1695 train_time:110588ms step_avg:98.56ms +step:1123/1695 train_time:110688ms step_avg:98.56ms +step:1124/1695 train_time:110787ms step_avg:98.56ms +step:1125/1695 train_time:110886ms step_avg:98.57ms +step:1125/1695 val_loss:3.4413 train_time:110983ms step_avg:98.65ms +step:1126/1695 train_time:111013ms step_avg:98.59ms +step:1127/1695 train_time:111097ms step_avg:98.58ms +step:1128/1695 train_time:111197ms step_avg:98.58ms +step:1129/1695 train_time:111297ms step_avg:98.58ms +step:1130/1695 train_time:111397ms step_avg:98.58ms +step:1131/1695 train_time:111496ms step_avg:98.58ms +step:1132/1695 train_time:111596ms step_avg:98.58ms +step:1133/1695 train_time:111696ms step_avg:98.58ms +step:1134/1695 train_time:111795ms step_avg:98.58ms +step:1135/1695 train_time:111896ms step_avg:98.59ms +step:1136/1695 train_time:111998ms step_avg:98.59ms +step:1137/1695 train_time:112100ms step_avg:98.59ms +step:1138/1695 train_time:112202ms step_avg:98.60ms +step:1139/1695 train_time:112302ms step_avg:98.60ms +step:1140/1695 train_time:112401ms step_avg:98.60ms +step:1141/1695 train_time:112501ms step_avg:98.60ms +step:1142/1695 train_time:112602ms step_avg:98.60ms +step:1143/1695 train_time:112702ms step_avg:98.60ms +step:1144/1695 train_time:112803ms step_avg:98.60ms +step:1145/1695 train_time:112905ms step_avg:98.61ms +step:1146/1695 train_time:113007ms step_avg:98.61ms +step:1147/1695 train_time:113109ms step_avg:98.61ms +step:1148/1695 train_time:113210ms step_avg:98.62ms +step:1149/1695 train_time:113311ms step_avg:98.62ms +step:1150/1695 train_time:113413ms step_avg:98.62ms +step:1151/1695 train_time:113515ms step_avg:98.62ms +step:1152/1695 train_time:113616ms step_avg:98.62ms +step:1153/1695 train_time:113717ms step_avg:98.63ms +step:1154/1695 train_time:113819ms step_avg:98.63ms +step:1155/1695 train_time:113920ms step_avg:98.63ms +step:1156/1695 train_time:114020ms step_avg:98.63ms +step:1157/1695 train_time:114121ms step_avg:98.64ms +step:1158/1695 train_time:114220ms step_avg:98.64ms +step:1159/1695 train_time:114320ms step_avg:98.64ms +step:1160/1695 train_time:114420ms step_avg:98.64ms +step:1161/1695 train_time:114520ms step_avg:98.64ms +step:1162/1695 train_time:114620ms step_avg:98.64ms +step:1163/1695 train_time:114722ms step_avg:98.64ms +step:1164/1695 train_time:114822ms step_avg:98.64ms +step:1165/1695 train_time:114923ms step_avg:98.65ms +step:1166/1695 train_time:115023ms step_avg:98.65ms +step:1167/1695 train_time:115123ms step_avg:98.65ms +step:1168/1695 train_time:115224ms step_avg:98.65ms +step:1169/1695 train_time:115323ms step_avg:98.65ms +step:1170/1695 train_time:115425ms step_avg:98.65ms +step:1171/1695 train_time:115526ms step_avg:98.66ms +step:1172/1695 train_time:115629ms step_avg:98.66ms +step:1173/1695 train_time:115730ms step_avg:98.66ms +step:1174/1695 train_time:115831ms step_avg:98.66ms +step:1175/1695 train_time:115932ms step_avg:98.67ms +step:1176/1695 train_time:116034ms step_avg:98.67ms 
+step:1177/1695 train_time:116135ms step_avg:98.67ms +step:1178/1695 train_time:116236ms step_avg:98.67ms +step:1179/1695 train_time:116340ms step_avg:98.68ms +step:1180/1695 train_time:116439ms step_avg:98.68ms +step:1181/1695 train_time:116541ms step_avg:98.68ms +step:1182/1695 train_time:116641ms step_avg:98.68ms +step:1183/1695 train_time:116740ms step_avg:98.68ms +step:1184/1695 train_time:116843ms step_avg:98.69ms +step:1185/1695 train_time:116946ms step_avg:98.69ms +step:1186/1695 train_time:117047ms step_avg:98.69ms +step:1187/1695 train_time:117150ms step_avg:98.69ms +step:1188/1695 train_time:117252ms step_avg:98.70ms +step:1189/1695 train_time:117352ms step_avg:98.70ms +step:1190/1695 train_time:117453ms step_avg:98.70ms +step:1191/1695 train_time:117553ms step_avg:98.70ms +step:1192/1695 train_time:117655ms step_avg:98.70ms +step:1193/1695 train_time:117757ms step_avg:98.71ms +step:1194/1695 train_time:117858ms step_avg:98.71ms +step:1195/1695 train_time:117958ms step_avg:98.71ms +step:1196/1695 train_time:118059ms step_avg:98.71ms +step:1197/1695 train_time:118160ms step_avg:98.71ms +step:1198/1695 train_time:118260ms step_avg:98.71ms +step:1199/1695 train_time:118360ms step_avg:98.72ms +step:1200/1695 train_time:118460ms step_avg:98.72ms +step:1201/1695 train_time:118561ms step_avg:98.72ms +step:1202/1695 train_time:118662ms step_avg:98.72ms +step:1203/1695 train_time:118764ms step_avg:98.72ms +step:1204/1695 train_time:118865ms step_avg:98.72ms +step:1205/1695 train_time:118965ms step_avg:98.73ms +step:1206/1695 train_time:119066ms step_avg:98.73ms +step:1207/1695 train_time:119168ms step_avg:98.73ms +step:1208/1695 train_time:119269ms step_avg:98.73ms +step:1209/1695 train_time:119371ms step_avg:98.74ms +step:1210/1695 train_time:119473ms step_avg:98.74ms +step:1211/1695 train_time:119574ms step_avg:98.74ms +step:1212/1695 train_time:119675ms step_avg:98.74ms +step:1213/1695 train_time:119776ms step_avg:98.74ms +step:1214/1695 train_time:119877ms step_avg:98.75ms +step:1215/1695 train_time:119978ms step_avg:98.75ms +step:1216/1695 train_time:120079ms step_avg:98.75ms +step:1217/1695 train_time:120181ms step_avg:98.75ms +step:1218/1695 train_time:120281ms step_avg:98.75ms +step:1219/1695 train_time:120381ms step_avg:98.75ms +step:1220/1695 train_time:120482ms step_avg:98.76ms +step:1221/1695 train_time:120583ms step_avg:98.76ms +step:1222/1695 train_time:120685ms step_avg:98.76ms +step:1223/1695 train_time:120786ms step_avg:98.76ms +step:1224/1695 train_time:120887ms step_avg:98.76ms +step:1225/1695 train_time:120989ms step_avg:98.77ms +step:1226/1695 train_time:121090ms step_avg:98.77ms +step:1227/1695 train_time:121192ms step_avg:98.77ms +step:1228/1695 train_time:121293ms step_avg:98.77ms +step:1229/1695 train_time:121394ms step_avg:98.77ms +step:1230/1695 train_time:121496ms step_avg:98.78ms +step:1231/1695 train_time:121597ms step_avg:98.78ms +step:1232/1695 train_time:121698ms step_avg:98.78ms +step:1233/1695 train_time:121798ms step_avg:98.78ms +step:1234/1695 train_time:121901ms step_avg:98.79ms +step:1235/1695 train_time:122000ms step_avg:98.79ms +step:1236/1695 train_time:122101ms step_avg:98.79ms +step:1237/1695 train_time:122201ms step_avg:98.79ms +step:1238/1695 train_time:122302ms step_avg:98.79ms +step:1239/1695 train_time:122404ms step_avg:98.79ms +step:1240/1695 train_time:122506ms step_avg:98.80ms +step:1241/1695 train_time:122606ms step_avg:98.80ms +step:1242/1695 train_time:122708ms step_avg:98.80ms +step:1243/1695 train_time:122810ms step_avg:98.80ms 
+step:1244/1695 train_time:122912ms step_avg:98.80ms +step:1245/1695 train_time:123013ms step_avg:98.81ms +step:1246/1695 train_time:123115ms step_avg:98.81ms +step:1247/1695 train_time:123217ms step_avg:98.81ms +step:1248/1695 train_time:123319ms step_avg:98.81ms +step:1249/1695 train_time:123419ms step_avg:98.81ms +step:1250/1695 train_time:123519ms step_avg:98.82ms +step:1250/1695 val_loss:3.3953 train_time:123617ms step_avg:98.89ms +step:1251/1695 train_time:123646ms step_avg:98.84ms +step:1252/1695 train_time:123729ms step_avg:98.82ms +step:1253/1695 train_time:123830ms step_avg:98.83ms +step:1254/1695 train_time:123932ms step_avg:98.83ms +step:1255/1695 train_time:124033ms step_avg:98.83ms +step:1256/1695 train_time:124134ms step_avg:98.83ms +step:1257/1695 train_time:124234ms step_avg:98.83ms +step:1258/1695 train_time:124335ms step_avg:98.84ms +step:1259/1695 train_time:124435ms step_avg:98.84ms +step:1260/1695 train_time:124536ms step_avg:98.84ms +step:1261/1695 train_time:124638ms step_avg:98.84ms +step:1262/1695 train_time:124743ms step_avg:98.85ms +step:1263/1695 train_time:124844ms step_avg:98.85ms +step:1264/1695 train_time:124944ms step_avg:98.85ms +step:1265/1695 train_time:125045ms step_avg:98.85ms +step:1266/1695 train_time:125146ms step_avg:98.85ms +step:1267/1695 train_time:125246ms step_avg:98.85ms +step:1268/1695 train_time:125347ms step_avg:98.85ms +step:1269/1695 train_time:125447ms step_avg:98.86ms +step:1270/1695 train_time:125548ms step_avg:98.86ms +step:1271/1695 train_time:125649ms step_avg:98.86ms +step:1272/1695 train_time:125751ms step_avg:98.86ms +step:1273/1695 train_time:125852ms step_avg:98.86ms +step:1274/1695 train_time:125953ms step_avg:98.86ms +step:1275/1695 train_time:126055ms step_avg:98.87ms +step:1276/1695 train_time:126157ms step_avg:98.87ms +step:1277/1695 train_time:126259ms step_avg:98.87ms +step:1278/1695 train_time:126360ms step_avg:98.87ms +step:1279/1695 train_time:126461ms step_avg:98.88ms +step:1280/1695 train_time:126563ms step_avg:98.88ms +step:1281/1695 train_time:126665ms step_avg:98.88ms +step:1282/1695 train_time:126766ms step_avg:98.88ms +step:1283/1695 train_time:126867ms step_avg:98.88ms +step:1284/1695 train_time:126967ms step_avg:98.88ms +step:1285/1695 train_time:127067ms step_avg:98.88ms +step:1286/1695 train_time:127167ms step_avg:98.89ms +step:1287/1695 train_time:127267ms step_avg:98.89ms +step:1288/1695 train_time:127367ms step_avg:98.89ms +step:1289/1695 train_time:127468ms step_avg:98.89ms +step:1290/1695 train_time:127569ms step_avg:98.89ms +step:1291/1695 train_time:127670ms step_avg:98.89ms +step:1292/1695 train_time:127770ms step_avg:98.89ms +step:1293/1695 train_time:127871ms step_avg:98.89ms +step:1294/1695 train_time:127973ms step_avg:98.90ms +step:1295/1695 train_time:128074ms step_avg:98.90ms +step:1296/1695 train_time:128174ms step_avg:98.90ms +step:1297/1695 train_time:128276ms step_avg:98.90ms +step:1298/1695 train_time:128376ms step_avg:98.90ms +step:1299/1695 train_time:128478ms step_avg:98.90ms +step:1300/1695 train_time:128581ms step_avg:98.91ms +step:1301/1695 train_time:128682ms step_avg:98.91ms +step:1302/1695 train_time:128784ms step_avg:98.91ms +step:1303/1695 train_time:128885ms step_avg:98.91ms +step:1304/1695 train_time:128985ms step_avg:98.92ms +step:1305/1695 train_time:129086ms step_avg:98.92ms +step:1306/1695 train_time:129187ms step_avg:98.92ms +step:1307/1695 train_time:129286ms step_avg:98.92ms +step:1308/1695 train_time:129386ms step_avg:98.92ms +step:1309/1695 train_time:129486ms 
step_avg:98.92ms +step:1310/1695 train_time:129586ms step_avg:98.92ms +step:1311/1695 train_time:129688ms step_avg:98.92ms +step:1312/1695 train_time:129789ms step_avg:98.92ms +step:1313/1695 train_time:129889ms step_avg:98.93ms +step:1314/1695 train_time:129990ms step_avg:98.93ms +step:1315/1695 train_time:130091ms step_avg:98.93ms +step:1316/1695 train_time:130192ms step_avg:98.93ms +step:1317/1695 train_time:130293ms step_avg:98.93ms +step:1318/1695 train_time:130393ms step_avg:98.93ms +step:1319/1695 train_time:130494ms step_avg:98.93ms +step:1320/1695 train_time:130597ms step_avg:98.94ms +step:1321/1695 train_time:130698ms step_avg:98.94ms +step:1322/1695 train_time:130800ms step_avg:98.94ms +step:1323/1695 train_time:130902ms step_avg:98.94ms +step:1324/1695 train_time:131003ms step_avg:98.95ms +step:1325/1695 train_time:131104ms step_avg:98.95ms +step:1326/1695 train_time:131206ms step_avg:98.95ms +step:1327/1695 train_time:131307ms step_avg:98.95ms +step:1328/1695 train_time:131406ms step_avg:98.95ms +step:1329/1695 train_time:131507ms step_avg:98.95ms +step:1330/1695 train_time:131607ms step_avg:98.95ms +step:1331/1695 train_time:131708ms step_avg:98.95ms +step:1332/1695 train_time:131808ms step_avg:98.95ms +step:1333/1695 train_time:131910ms step_avg:98.96ms +step:1334/1695 train_time:132013ms step_avg:98.96ms +step:1335/1695 train_time:132114ms step_avg:98.96ms +step:1336/1695 train_time:132215ms step_avg:98.96ms +step:1337/1695 train_time:132318ms step_avg:98.97ms +step:1338/1695 train_time:132419ms step_avg:98.97ms +step:1339/1695 train_time:132520ms step_avg:98.97ms +step:1340/1695 train_time:132621ms step_avg:98.97ms +step:1341/1695 train_time:132723ms step_avg:98.97ms +step:1342/1695 train_time:132825ms step_avg:98.98ms +step:1343/1695 train_time:132926ms step_avg:98.98ms +step:1344/1695 train_time:133026ms step_avg:98.98ms +step:1345/1695 train_time:133127ms step_avg:98.98ms +step:1346/1695 train_time:133228ms step_avg:98.98ms +step:1347/1695 train_time:133328ms step_avg:98.98ms +step:1348/1695 train_time:133429ms step_avg:98.98ms +step:1349/1695 train_time:133529ms step_avg:98.98ms +step:1350/1695 train_time:133631ms step_avg:98.99ms +step:1351/1695 train_time:133732ms step_avg:98.99ms +step:1352/1695 train_time:133833ms step_avg:98.99ms +step:1353/1695 train_time:133935ms step_avg:98.99ms +step:1354/1695 train_time:134036ms step_avg:98.99ms +step:1355/1695 train_time:134139ms step_avg:99.00ms +step:1356/1695 train_time:134242ms step_avg:99.00ms +step:1357/1695 train_time:134343ms step_avg:99.00ms +step:1358/1695 train_time:134444ms step_avg:99.00ms +step:1359/1695 train_time:134545ms step_avg:99.00ms +step:1360/1695 train_time:134645ms step_avg:99.00ms +step:1361/1695 train_time:134747ms step_avg:99.01ms +step:1362/1695 train_time:134847ms step_avg:99.01ms +step:1363/1695 train_time:134948ms step_avg:99.01ms +step:1364/1695 train_time:135049ms step_avg:99.01ms +step:1365/1695 train_time:135151ms step_avg:99.01ms +step:1366/1695 train_time:135253ms step_avg:99.01ms +step:1367/1695 train_time:135354ms step_avg:99.02ms +step:1368/1695 train_time:135455ms step_avg:99.02ms +step:1369/1695 train_time:135556ms step_avg:99.02ms +step:1370/1695 train_time:135658ms step_avg:99.02ms +step:1371/1695 train_time:135760ms step_avg:99.02ms +step:1372/1695 train_time:135861ms step_avg:99.02ms +step:1373/1695 train_time:135963ms step_avg:99.03ms +step:1374/1695 train_time:136064ms step_avg:99.03ms +step:1375/1695 train_time:136166ms step_avg:99.03ms +step:1375/1695 val_loss:3.3558 
train_time:136264ms step_avg:99.10ms +step:1376/1695 train_time:136293ms step_avg:99.05ms +step:1377/1695 train_time:136379ms step_avg:99.04ms +step:1378/1695 train_time:136482ms step_avg:99.04ms +step:1379/1695 train_time:136584ms step_avg:99.05ms +step:1380/1695 train_time:136687ms step_avg:99.05ms +step:1381/1695 train_time:136787ms step_avg:99.05ms +step:1382/1695 train_time:136886ms step_avg:99.05ms +step:1383/1695 train_time:136986ms step_avg:99.05ms +step:1384/1695 train_time:137086ms step_avg:99.05ms +step:1385/1695 train_time:137188ms step_avg:99.05ms +step:1386/1695 train_time:137294ms step_avg:99.06ms +step:1387/1695 train_time:137395ms step_avg:99.06ms +step:1388/1695 train_time:137496ms step_avg:99.06ms +step:1389/1695 train_time:137600ms step_avg:99.06ms +step:1390/1695 train_time:137703ms step_avg:99.07ms +step:1391/1695 train_time:137806ms step_avg:99.07ms +step:1392/1695 train_time:137907ms step_avg:99.07ms +step:1393/1695 train_time:138009ms step_avg:99.07ms +step:1394/1695 train_time:138110ms step_avg:99.07ms +step:1395/1695 train_time:138211ms step_avg:99.08ms +step:1396/1695 train_time:138313ms step_avg:99.08ms +step:1397/1695 train_time:138416ms step_avg:99.08ms +step:1398/1695 train_time:138519ms step_avg:99.08ms +step:1399/1695 train_time:138621ms step_avg:99.09ms +step:1400/1695 train_time:138724ms step_avg:99.09ms +step:1401/1695 train_time:138826ms step_avg:99.09ms +step:1402/1695 train_time:138928ms step_avg:99.09ms +step:1403/1695 train_time:139030ms step_avg:99.09ms +step:1404/1695 train_time:139131ms step_avg:99.10ms +step:1405/1695 train_time:139233ms step_avg:99.10ms +step:1406/1695 train_time:139334ms step_avg:99.10ms +step:1407/1695 train_time:139436ms step_avg:99.10ms +step:1408/1695 train_time:139536ms step_avg:99.10ms +step:1409/1695 train_time:139640ms step_avg:99.11ms +step:1410/1695 train_time:139743ms step_avg:99.11ms +step:1411/1695 train_time:139846ms step_avg:99.11ms +step:1412/1695 train_time:139949ms step_avg:99.11ms +step:1413/1695 train_time:140050ms step_avg:99.12ms +step:1414/1695 train_time:140152ms step_avg:99.12ms +step:1415/1695 train_time:140253ms step_avg:99.12ms +step:1416/1695 train_time:140354ms step_avg:99.12ms +step:1417/1695 train_time:140454ms step_avg:99.12ms +step:1418/1695 train_time:140555ms step_avg:99.12ms +step:1419/1695 train_time:140656ms step_avg:99.12ms +step:1420/1695 train_time:140759ms step_avg:99.13ms +step:1421/1695 train_time:140862ms step_avg:99.13ms +step:1422/1695 train_time:140964ms step_avg:99.13ms +step:1423/1695 train_time:141067ms step_avg:99.13ms +step:1424/1695 train_time:141169ms step_avg:99.14ms +step:1425/1695 train_time:141271ms step_avg:99.14ms +step:1426/1695 train_time:141374ms step_avg:99.14ms +step:1427/1695 train_time:141476ms step_avg:99.14ms +step:1428/1695 train_time:141578ms step_avg:99.14ms +step:1429/1695 train_time:141679ms step_avg:99.15ms +step:1430/1695 train_time:141781ms step_avg:99.15ms +step:1431/1695 train_time:141883ms step_avg:99.15ms +step:1432/1695 train_time:141985ms step_avg:99.15ms +step:1433/1695 train_time:142089ms step_avg:99.15ms +step:1434/1695 train_time:142190ms step_avg:99.16ms +step:1435/1695 train_time:142293ms step_avg:99.16ms +step:1436/1695 train_time:142395ms step_avg:99.16ms +step:1437/1695 train_time:142497ms step_avg:99.16ms +step:1438/1695 train_time:142599ms step_avg:99.16ms +step:1439/1695 train_time:142701ms step_avg:99.17ms +step:1440/1695 train_time:142804ms step_avg:99.17ms +step:1441/1695 train_time:142907ms step_avg:99.17ms +step:1442/1695 
train_time:143009ms step_avg:99.17ms +step:1443/1695 train_time:143110ms step_avg:99.18ms +step:1444/1695 train_time:143212ms step_avg:99.18ms +step:1445/1695 train_time:143313ms step_avg:99.18ms +step:1446/1695 train_time:143414ms step_avg:99.18ms +step:1447/1695 train_time:143515ms step_avg:99.18ms +step:1448/1695 train_time:143618ms step_avg:99.18ms +step:1449/1695 train_time:143719ms step_avg:99.18ms +step:1450/1695 train_time:143820ms step_avg:99.19ms +step:1451/1695 train_time:143923ms step_avg:99.19ms +step:1452/1695 train_time:144026ms step_avg:99.19ms +step:1453/1695 train_time:144130ms step_avg:99.19ms +step:1454/1695 train_time:144232ms step_avg:99.20ms +step:1455/1695 train_time:144334ms step_avg:99.20ms +step:1456/1695 train_time:144435ms step_avg:99.20ms +step:1457/1695 train_time:144536ms step_avg:99.20ms +step:1458/1695 train_time:144638ms step_avg:99.20ms +step:1459/1695 train_time:144740ms step_avg:99.20ms +step:1460/1695 train_time:144841ms step_avg:99.21ms +step:1461/1695 train_time:144943ms step_avg:99.21ms +step:1462/1695 train_time:145046ms step_avg:99.21ms +step:1463/1695 train_time:145148ms step_avg:99.21ms +step:1464/1695 train_time:145250ms step_avg:99.21ms +step:1465/1695 train_time:145351ms step_avg:99.22ms +step:1466/1695 train_time:145452ms step_avg:99.22ms +step:1467/1695 train_time:145553ms step_avg:99.22ms +step:1468/1695 train_time:145656ms step_avg:99.22ms +step:1469/1695 train_time:145758ms step_avg:99.22ms +step:1470/1695 train_time:145859ms step_avg:99.22ms +step:1471/1695 train_time:145962ms step_avg:99.23ms +step:1472/1695 train_time:146064ms step_avg:99.23ms +step:1473/1695 train_time:146167ms step_avg:99.23ms +step:1474/1695 train_time:146268ms step_avg:99.23ms +step:1475/1695 train_time:146369ms step_avg:99.23ms +step:1476/1695 train_time:146471ms step_avg:99.24ms +step:1477/1695 train_time:146573ms step_avg:99.24ms +step:1478/1695 train_time:146674ms step_avg:99.24ms +step:1479/1695 train_time:146775ms step_avg:99.24ms +step:1480/1695 train_time:146877ms step_avg:99.24ms +step:1481/1695 train_time:146980ms step_avg:99.24ms +step:1482/1695 train_time:147082ms step_avg:99.25ms +step:1483/1695 train_time:147184ms step_avg:99.25ms +step:1484/1695 train_time:147288ms step_avg:99.25ms +step:1485/1695 train_time:147389ms step_avg:99.25ms +step:1486/1695 train_time:147491ms step_avg:99.25ms +step:1487/1695 train_time:147593ms step_avg:99.26ms +step:1488/1695 train_time:147695ms step_avg:99.26ms +step:1489/1695 train_time:147797ms step_avg:99.26ms +step:1490/1695 train_time:147899ms step_avg:99.26ms +step:1491/1695 train_time:148000ms step_avg:99.26ms +step:1492/1695 train_time:148102ms step_avg:99.26ms +step:1493/1695 train_time:148205ms step_avg:99.27ms +step:1494/1695 train_time:148307ms step_avg:99.27ms +step:1495/1695 train_time:148410ms step_avg:99.27ms +step:1496/1695 train_time:148512ms step_avg:99.27ms +step:1497/1695 train_time:148612ms step_avg:99.27ms +step:1498/1695 train_time:148714ms step_avg:99.27ms +step:1499/1695 train_time:148815ms step_avg:99.28ms +step:1500/1695 train_time:148916ms step_avg:99.28ms +step:1500/1695 val_loss:3.3212 train_time:149015ms step_avg:99.34ms +step:1501/1695 train_time:149044ms step_avg:99.30ms +step:1502/1695 train_time:149127ms step_avg:99.29ms +step:1503/1695 train_time:149230ms step_avg:99.29ms +step:1504/1695 train_time:149333ms step_avg:99.29ms +step:1505/1695 train_time:149434ms step_avg:99.29ms +step:1506/1695 train_time:149535ms step_avg:99.29ms +step:1507/1695 train_time:149637ms step_avg:99.29ms 
+step:1508/1695 train_time:149738ms step_avg:99.30ms +step:1509/1695 train_time:149841ms step_avg:99.30ms +step:1510/1695 train_time:149943ms step_avg:99.30ms +step:1511/1695 train_time:150046ms step_avg:99.30ms +step:1512/1695 train_time:150149ms step_avg:99.30ms +step:1513/1695 train_time:150252ms step_avg:99.31ms +step:1514/1695 train_time:150354ms step_avg:99.31ms +step:1515/1695 train_time:150459ms step_avg:99.31ms +step:1516/1695 train_time:150560ms step_avg:99.31ms +step:1517/1695 train_time:150661ms step_avg:99.31ms +step:1518/1695 train_time:150762ms step_avg:99.32ms +step:1519/1695 train_time:150866ms step_avg:99.32ms +step:1520/1695 train_time:150968ms step_avg:99.32ms +step:1521/1695 train_time:151069ms step_avg:99.32ms +step:1522/1695 train_time:151170ms step_avg:99.32ms +step:1523/1695 train_time:151272ms step_avg:99.33ms +step:1524/1695 train_time:151376ms step_avg:99.33ms +step:1525/1695 train_time:151479ms step_avg:99.33ms +step:1526/1695 train_time:151582ms step_avg:99.33ms +step:1527/1695 train_time:151684ms step_avg:99.33ms +step:1528/1695 train_time:151788ms step_avg:99.34ms +step:1529/1695 train_time:151890ms step_avg:99.34ms +step:1530/1695 train_time:151994ms step_avg:99.34ms +step:1531/1695 train_time:152095ms step_avg:99.34ms +step:1532/1695 train_time:152198ms step_avg:99.35ms +step:1533/1695 train_time:152301ms step_avg:99.35ms +step:1534/1695 train_time:152403ms step_avg:99.35ms +step:1535/1695 train_time:152504ms step_avg:99.35ms +step:1536/1695 train_time:152605ms step_avg:99.35ms +step:1537/1695 train_time:152709ms step_avg:99.35ms +step:1538/1695 train_time:152810ms step_avg:99.36ms +step:1539/1695 train_time:152911ms step_avg:99.36ms +step:1540/1695 train_time:153013ms step_avg:99.36ms +step:1541/1695 train_time:153118ms step_avg:99.36ms +step:1542/1695 train_time:153222ms step_avg:99.37ms +step:1543/1695 train_time:153324ms step_avg:99.37ms +step:1544/1695 train_time:153425ms step_avg:99.37ms +step:1545/1695 train_time:153526ms step_avg:99.37ms +step:1546/1695 train_time:153627ms step_avg:99.37ms +step:1547/1695 train_time:153729ms step_avg:99.37ms +step:1548/1695 train_time:153831ms step_avg:99.37ms +step:1549/1695 train_time:153934ms step_avg:99.38ms +step:1550/1695 train_time:154035ms step_avg:99.38ms +step:1551/1695 train_time:154137ms step_avg:99.38ms +step:1552/1695 train_time:154240ms step_avg:99.38ms +step:1553/1695 train_time:154343ms step_avg:99.38ms +step:1554/1695 train_time:154444ms step_avg:99.38ms +step:1555/1695 train_time:154545ms step_avg:99.39ms +step:1556/1695 train_time:154647ms step_avg:99.39ms +step:1557/1695 train_time:154750ms step_avg:99.39ms +step:1558/1695 train_time:154852ms step_avg:99.39ms +step:1559/1695 train_time:154954ms step_avg:99.39ms +step:1560/1695 train_time:155056ms step_avg:99.40ms +step:1561/1695 train_time:155159ms step_avg:99.40ms +step:1562/1695 train_time:155262ms step_avg:99.40ms +step:1563/1695 train_time:155366ms step_avg:99.40ms +step:1564/1695 train_time:155467ms step_avg:99.40ms +step:1565/1695 train_time:155568ms step_avg:99.40ms +step:1566/1695 train_time:155669ms step_avg:99.41ms +step:1567/1695 train_time:155771ms step_avg:99.41ms +step:1568/1695 train_time:155872ms step_avg:99.41ms +step:1569/1695 train_time:155973ms step_avg:99.41ms +step:1570/1695 train_time:156078ms step_avg:99.41ms +step:1571/1695 train_time:156179ms step_avg:99.41ms +step:1572/1695 train_time:156281ms step_avg:99.42ms +step:1573/1695 train_time:156383ms step_avg:99.42ms +step:1574/1695 train_time:156485ms step_avg:99.42ms 
+step:1575/1695 train_time:156585ms step_avg:99.42ms +step:1576/1695 train_time:156687ms step_avg:99.42ms +step:1577/1695 train_time:156790ms step_avg:99.42ms +step:1578/1695 train_time:156891ms step_avg:99.42ms +step:1579/1695 train_time:156993ms step_avg:99.43ms +step:1580/1695 train_time:157095ms step_avg:99.43ms +step:1581/1695 train_time:157198ms step_avg:99.43ms +step:1582/1695 train_time:157299ms step_avg:99.43ms +step:1583/1695 train_time:157402ms step_avg:99.43ms +step:1584/1695 train_time:157505ms step_avg:99.43ms +step:1585/1695 train_time:157606ms step_avg:99.44ms +step:1586/1695 train_time:157708ms step_avg:99.44ms +step:1587/1695 train_time:157809ms step_avg:99.44ms +step:1588/1695 train_time:157910ms step_avg:99.44ms +step:1589/1695 train_time:158012ms step_avg:99.44ms +step:1590/1695 train_time:158114ms step_avg:99.44ms +step:1591/1695 train_time:158216ms step_avg:99.44ms +step:1592/1695 train_time:158319ms step_avg:99.45ms +step:1593/1695 train_time:158421ms step_avg:99.45ms +step:1594/1695 train_time:158525ms step_avg:99.45ms +step:1595/1695 train_time:158626ms step_avg:99.45ms +step:1596/1695 train_time:158727ms step_avg:99.45ms +step:1597/1695 train_time:158829ms step_avg:99.45ms +step:1598/1695 train_time:158931ms step_avg:99.46ms +step:1599/1695 train_time:159032ms step_avg:99.46ms +step:1600/1695 train_time:159134ms step_avg:99.46ms +step:1601/1695 train_time:159238ms step_avg:99.46ms +step:1602/1695 train_time:159340ms step_avg:99.46ms +step:1603/1695 train_time:159441ms step_avg:99.46ms +step:1604/1695 train_time:159544ms step_avg:99.47ms +step:1605/1695 train_time:159647ms step_avg:99.47ms +step:1606/1695 train_time:159749ms step_avg:99.47ms +step:1607/1695 train_time:159850ms step_avg:99.47ms +step:1608/1695 train_time:159951ms step_avg:99.47ms +step:1609/1695 train_time:160052ms step_avg:99.47ms +step:1610/1695 train_time:160153ms step_avg:99.47ms +step:1611/1695 train_time:160256ms step_avg:99.48ms +step:1612/1695 train_time:160359ms step_avg:99.48ms +step:1613/1695 train_time:160460ms step_avg:99.48ms +step:1614/1695 train_time:160561ms step_avg:99.48ms +step:1615/1695 train_time:160664ms step_avg:99.48ms +step:1616/1695 train_time:160765ms step_avg:99.48ms +step:1617/1695 train_time:160867ms step_avg:99.48ms +step:1618/1695 train_time:160968ms step_avg:99.49ms +step:1619/1695 train_time:161070ms step_avg:99.49ms +step:1620/1695 train_time:161172ms step_avg:99.49ms +step:1621/1695 train_time:161274ms step_avg:99.49ms +step:1622/1695 train_time:161375ms step_avg:99.49ms +step:1623/1695 train_time:161478ms step_avg:99.49ms +step:1624/1695 train_time:161582ms step_avg:99.50ms +step:1625/1695 train_time:161686ms step_avg:99.50ms +step:1625/1695 val_loss:3.2922 train_time:161787ms step_avg:99.56ms +step:1626/1695 train_time:161816ms step_avg:99.52ms +step:1627/1695 train_time:161901ms step_avg:99.51ms +step:1628/1695 train_time:162003ms step_avg:99.51ms +step:1629/1695 train_time:162105ms step_avg:99.51ms +step:1630/1695 train_time:162207ms step_avg:99.51ms +step:1631/1695 train_time:162308ms step_avg:99.51ms +step:1632/1695 train_time:162410ms step_avg:99.52ms +step:1633/1695 train_time:162511ms step_avg:99.52ms +step:1634/1695 train_time:162614ms step_avg:99.52ms +step:1635/1695 train_time:162716ms step_avg:99.52ms +step:1636/1695 train_time:162819ms step_avg:99.52ms +step:1637/1695 train_time:162922ms step_avg:99.52ms +step:1638/1695 train_time:163025ms step_avg:99.53ms +step:1639/1695 train_time:163129ms step_avg:99.53ms +step:1640/1695 train_time:163231ms 
step_avg:99.53ms +step:1641/1695 train_time:163333ms step_avg:99.53ms +step:1642/1695 train_time:163435ms step_avg:99.53ms +step:1643/1695 train_time:163536ms step_avg:99.54ms +step:1644/1695 train_time:163638ms step_avg:99.54ms +step:1645/1695 train_time:163741ms step_avg:99.54ms +step:1646/1695 train_time:163844ms step_avg:99.54ms +step:1647/1695 train_time:163950ms step_avg:99.54ms +step:1648/1695 train_time:164054ms step_avg:99.55ms +step:1649/1695 train_time:164156ms step_avg:99.55ms +step:1650/1695 train_time:164258ms step_avg:99.55ms +step:1651/1695 train_time:164360ms step_avg:99.55ms +step:1652/1695 train_time:164462ms step_avg:99.55ms +step:1653/1695 train_time:164565ms step_avg:99.56ms +step:1654/1695 train_time:164668ms step_avg:99.56ms +step:1655/1695 train_time:164771ms step_avg:99.56ms +step:1656/1695 train_time:164874ms step_avg:99.56ms +step:1657/1695 train_time:164976ms step_avg:99.56ms +step:1658/1695 train_time:165078ms step_avg:99.56ms +step:1659/1695 train_time:165184ms step_avg:99.57ms +step:1660/1695 train_time:165287ms step_avg:99.57ms +step:1661/1695 train_time:165392ms step_avg:99.57ms +step:1662/1695 train_time:165496ms step_avg:99.58ms +step:1663/1695 train_time:165599ms step_avg:99.58ms +step:1664/1695 train_time:165700ms step_avg:99.58ms +step:1665/1695 train_time:165808ms step_avg:99.58ms +step:1666/1695 train_time:165911ms step_avg:99.59ms +step:1667/1695 train_time:166014ms step_avg:99.59ms +step:1668/1695 train_time:166119ms step_avg:99.59ms +step:1669/1695 train_time:166223ms step_avg:99.59ms +step:1670/1695 train_time:166326ms step_avg:99.60ms +step:1671/1695 train_time:166429ms step_avg:99.60ms +step:1672/1695 train_time:166533ms step_avg:99.60ms +step:1673/1695 train_time:166635ms step_avg:99.60ms +step:1674/1695 train_time:166738ms step_avg:99.60ms +step:1675/1695 train_time:166840ms step_avg:99.61ms +step:1676/1695 train_time:166945ms step_avg:99.61ms +step:1677/1695 train_time:167046ms step_avg:99.61ms +step:1678/1695 train_time:167152ms step_avg:99.61ms +step:1679/1695 train_time:167254ms step_avg:99.62ms +step:1680/1695 train_time:167356ms step_avg:99.62ms +step:1681/1695 train_time:167459ms step_avg:99.62ms +step:1682/1695 train_time:167564ms step_avg:99.62ms +step:1683/1695 train_time:167667ms step_avg:99.62ms +step:1684/1695 train_time:167770ms step_avg:99.63ms +step:1685/1695 train_time:167873ms step_avg:99.63ms +step:1686/1695 train_time:167975ms step_avg:99.63ms +step:1687/1695 train_time:168077ms step_avg:99.63ms +step:1688/1695 train_time:168179ms step_avg:99.63ms +step:1689/1695 train_time:168280ms step_avg:99.63ms +step:1690/1695 train_time:168382ms step_avg:99.63ms +step:1691/1695 train_time:168485ms step_avg:99.64ms +step:1692/1695 train_time:168588ms step_avg:99.64ms +step:1693/1695 train_time:168692ms step_avg:99.64ms +step:1694/1695 train_time:168796ms step_avg:99.64ms +step:1695/1695 train_time:168900ms step_avg:99.65ms +step:1695/1695 val_loss:3.2795 train_time:168999ms step_avg:99.70ms +peak memory allocated: 33860 MiB reserved: 49040 MiB diff --git a/records/082325_SparseAttnGate/53ecb4ef-77ed-4af6-b776-47cd4006614b.txt b/records/082325_SparseAttnGate/53ecb4ef-77ed-4af6-b776-47cd4006614b.txt new file mode 100644 index 000000000..01dbfe241 --- /dev/null +++ b/records/082325_SparseAttnGate/53ecb4ef-77ed-4af6-b776-47cd4006614b.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
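+
+    Usage sketch (illustrative only, not taken from this script; `model` and `loss`
+    are placeholder names, and torch.distributed must already be initialized, since
+    step() calls dist.get_rank() and dist.get_world_size()):
+
+        hidden_matrices = [p for p in model.blocks.parameters() if p.ndim >= 2]
+        opt = Muon(hidden_matrices, lr=0.02, weight_decay=0.01, momentum=0.95)
+        loss.backward()
+        opt.step()  # reduce-scatter grads, orthogonalize updates via Newton-Schulz, all-gather params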
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
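+ # note: these per-parameter LR multipliers are consumed inside the optimizer step (see DistAdam.step above): the effective step size is group["lr"] * getattr(p, "lr_mul", 1.0), and weight decay is rescaled the same way via getattr(p, "wd_mul", 1.0) + # e.g. with the Adam lr of 0.008 set below, lr_mul = 75. gives the embeddings an effective lr of 0.008 * 75 = 0.6 (before the schedule and bias correction)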
+ self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256]) + # docs = increments.cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) + + def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(block_masks) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 + boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos + start = boundary_positions[0].item() + starts = [] + for i in range(1, len(boundary_positions)): + end = boundary_positions[i].item() + if end - start >= seq_len: + starts.append(start) # append start once end pos is confirmed + if len(starts) == dist.get_world_size(): + return starts, end - pos + start = end + assert False # increase token_window if necessary + +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): + rank = dist.get_rank() + world_size = dist.get_world_size() + batch_size = seq_len * world_size + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + while True: + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length 
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
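# each rank evaluates a disjoint shard of the validation stream with reduction="mean", and all ranks see the same number of tokens per step, so averaging the per-rank means here yields the exact global validation loss + 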
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:28:09 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 308062 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 308063 C /usr/bin/python3 614MiB | +| 0 N/A N/A 308064 C /usr/bin/python3 614MiB | +| 0 N/A N/A 308065 C /usr/bin/python3 614MiB | +| 0 N/A N/A 308066 C /usr/bin/python3 614MiB | +| 0 N/A N/A 308067 C /usr/bin/python3 614MiB | +| 0 N/A N/A 308068 C /usr/bin/python3 614MiB | +| 0 N/A N/A 308069 C /usr/bin/python3 614MiB | +| 1 N/A N/A 308063 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 308064 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 308065 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 308066 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 308067 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 308068 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 308069 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:154ms step_avg:153.72ms +step:2/1695 train_time:179ms step_avg:89.42ms +step:3/1695 train_time:250ms step_avg:83.25ms +step:4/1695 train_time:342ms step_avg:85.38ms +step:5/1695 train_time:434ms step_avg:86.79ms +step:6/1695 train_time:526ms step_avg:87.73ms +step:7/1695 train_time:620ms step_avg:88.51ms +step:8/1695 
train_time:712ms step_avg:89.00ms +step:9/1695 train_time:805ms step_avg:89.40ms +step:10/1695 train_time:898ms step_avg:89.77ms +step:11/1695 train_time:990ms step_avg:90.05ms +step:12/1695 train_time:1084ms step_avg:90.36ms +step:13/1695 train_time:1180ms step_avg:90.78ms +step:14/1695 train_time:1274ms step_avg:91.03ms +step:15/1695 train_time:1368ms step_avg:91.19ms +step:16/1695 train_time:1462ms step_avg:91.35ms +step:17/1695 train_time:1555ms step_avg:91.48ms +step:18/1695 train_time:1648ms step_avg:91.58ms +step:19/1695 train_time:1742ms step_avg:91.67ms +step:20/1695 train_time:1835ms step_avg:91.74ms +step:21/1695 train_time:1927ms step_avg:91.77ms +step:22/1695 train_time:2021ms step_avg:91.85ms +step:23/1695 train_time:2115ms step_avg:91.97ms +step:24/1695 train_time:2211ms step_avg:92.11ms +step:25/1695 train_time:2305ms step_avg:92.20ms +step:26/1695 train_time:2399ms step_avg:92.26ms +step:27/1695 train_time:2493ms step_avg:92.32ms +step:28/1695 train_time:2586ms step_avg:92.36ms +step:29/1695 train_time:2680ms step_avg:92.40ms +step:30/1695 train_time:2773ms step_avg:92.44ms +step:31/1695 train_time:2866ms step_avg:92.46ms +step:32/1695 train_time:2961ms step_avg:92.52ms +step:33/1695 train_time:3054ms step_avg:92.55ms +step:34/1695 train_time:3149ms step_avg:92.61ms +step:35/1695 train_time:3242ms step_avg:92.63ms +step:36/1695 train_time:3336ms step_avg:92.67ms +step:37/1695 train_time:3430ms step_avg:92.71ms +step:38/1695 train_time:3523ms step_avg:92.71ms +step:39/1695 train_time:3617ms step_avg:92.73ms +step:40/1695 train_time:3710ms step_avg:92.74ms +step:41/1695 train_time:3802ms step_avg:92.74ms +step:42/1695 train_time:3896ms step_avg:92.76ms +step:43/1695 train_time:3989ms step_avg:92.77ms +step:44/1695 train_time:4083ms step_avg:92.80ms +step:45/1695 train_time:4177ms step_avg:92.81ms +step:46/1695 train_time:4270ms step_avg:92.83ms +step:47/1695 train_time:4364ms step_avg:92.85ms +step:48/1695 train_time:4459ms step_avg:92.89ms +step:49/1695 train_time:4553ms step_avg:92.92ms +step:50/1695 train_time:4646ms step_avg:92.93ms +step:51/1695 train_time:4740ms step_avg:92.95ms +step:52/1695 train_time:4834ms step_avg:92.96ms +step:53/1695 train_time:4927ms step_avg:92.97ms +step:54/1695 train_time:5021ms step_avg:92.99ms +step:55/1695 train_time:5116ms step_avg:93.01ms +step:56/1695 train_time:5209ms step_avg:93.02ms +step:57/1695 train_time:5303ms step_avg:93.03ms +step:58/1695 train_time:5397ms step_avg:93.05ms +step:59/1695 train_time:5491ms step_avg:93.07ms +step:60/1695 train_time:5585ms step_avg:93.08ms +step:61/1695 train_time:5678ms step_avg:93.08ms +step:62/1695 train_time:5771ms step_avg:93.09ms +step:63/1695 train_time:5865ms step_avg:93.09ms +step:64/1695 train_time:5959ms step_avg:93.10ms +step:65/1695 train_time:6052ms step_avg:93.10ms +step:66/1695 train_time:6145ms step_avg:93.11ms +step:67/1695 train_time:6239ms step_avg:93.11ms +step:68/1695 train_time:6333ms step_avg:93.13ms +step:69/1695 train_time:6426ms step_avg:93.13ms +step:70/1695 train_time:6520ms step_avg:93.14ms +step:71/1695 train_time:6614ms step_avg:93.16ms +step:72/1695 train_time:6709ms step_avg:93.18ms +step:73/1695 train_time:6803ms step_avg:93.19ms +step:74/1695 train_time:6896ms step_avg:93.19ms +step:75/1695 train_time:6989ms step_avg:93.18ms +step:76/1695 train_time:7083ms step_avg:93.20ms +step:77/1695 train_time:7177ms step_avg:93.21ms +step:78/1695 train_time:7271ms step_avg:93.22ms +step:79/1695 train_time:7364ms step_avg:93.22ms +step:80/1695 train_time:7458ms 
step_avg:93.22ms +step:81/1695 train_time:7551ms step_avg:93.22ms +step:82/1695 train_time:7644ms step_avg:93.22ms +step:83/1695 train_time:7739ms step_avg:93.24ms +step:84/1695 train_time:7832ms step_avg:93.24ms +step:85/1695 train_time:7925ms step_avg:93.23ms +step:86/1695 train_time:8019ms step_avg:93.24ms +step:87/1695 train_time:8112ms step_avg:93.24ms +step:88/1695 train_time:8205ms step_avg:93.24ms +step:89/1695 train_time:8299ms step_avg:93.25ms +step:90/1695 train_time:8393ms step_avg:93.26ms +step:91/1695 train_time:8486ms step_avg:93.25ms +step:92/1695 train_time:8580ms step_avg:93.26ms +step:93/1695 train_time:8674ms step_avg:93.26ms +step:94/1695 train_time:8767ms step_avg:93.27ms +step:95/1695 train_time:8861ms step_avg:93.27ms +step:96/1695 train_time:8954ms step_avg:93.27ms +step:97/1695 train_time:9047ms step_avg:93.27ms +step:98/1695 train_time:9141ms step_avg:93.28ms +step:99/1695 train_time:9235ms step_avg:93.28ms +step:100/1695 train_time:9329ms step_avg:93.29ms +step:101/1695 train_time:9423ms step_avg:93.29ms +step:102/1695 train_time:9516ms step_avg:93.29ms +step:103/1695 train_time:9609ms step_avg:93.29ms +step:104/1695 train_time:9703ms step_avg:93.29ms +step:105/1695 train_time:9797ms step_avg:93.30ms +step:106/1695 train_time:9890ms step_avg:93.30ms +step:107/1695 train_time:9983ms step_avg:93.30ms +step:108/1695 train_time:10077ms step_avg:93.31ms +step:109/1695 train_time:10171ms step_avg:93.31ms +step:110/1695 train_time:10264ms step_avg:93.31ms +step:111/1695 train_time:10357ms step_avg:93.31ms +step:112/1695 train_time:10451ms step_avg:93.31ms +step:113/1695 train_time:10545ms step_avg:93.31ms +step:114/1695 train_time:10639ms step_avg:93.32ms +step:115/1695 train_time:10732ms step_avg:93.32ms +step:116/1695 train_time:10826ms step_avg:93.32ms +step:117/1695 train_time:10919ms step_avg:93.32ms +step:118/1695 train_time:11012ms step_avg:93.32ms +step:119/1695 train_time:11105ms step_avg:93.32ms +step:120/1695 train_time:11199ms step_avg:93.33ms +step:121/1695 train_time:11293ms step_avg:93.33ms +step:122/1695 train_time:11387ms step_avg:93.33ms +step:123/1695 train_time:11480ms step_avg:93.33ms +step:124/1695 train_time:11573ms step_avg:93.33ms +step:125/1695 train_time:11667ms step_avg:93.33ms +step:125/1695 val_loss:4.6063 train_time:11759ms step_avg:94.07ms +step:126/1695 train_time:11785ms step_avg:93.53ms +step:127/1695 train_time:11864ms step_avg:93.41ms +step:128/1695 train_time:11963ms step_avg:93.46ms +step:129/1695 train_time:12057ms step_avg:93.46ms +step:130/1695 train_time:12150ms step_avg:93.46ms +step:131/1695 train_time:12243ms step_avg:93.46ms +step:132/1695 train_time:12337ms step_avg:93.46ms +step:133/1695 train_time:12430ms step_avg:93.46ms +step:134/1695 train_time:12524ms step_avg:93.46ms +step:135/1695 train_time:12617ms step_avg:93.46ms +step:136/1695 train_time:12711ms step_avg:93.46ms +step:137/1695 train_time:12806ms step_avg:93.48ms +step:138/1695 train_time:12903ms step_avg:93.50ms +step:139/1695 train_time:12999ms step_avg:93.52ms +step:140/1695 train_time:13094ms step_avg:93.53ms +step:141/1695 train_time:13187ms step_avg:93.53ms +step:142/1695 train_time:13282ms step_avg:93.53ms +step:143/1695 train_time:13376ms step_avg:93.54ms +step:144/1695 train_time:13469ms step_avg:93.53ms +step:145/1695 train_time:13563ms step_avg:93.54ms +step:146/1695 train_time:13656ms step_avg:93.53ms +step:147/1695 train_time:13750ms step_avg:93.54ms +step:148/1695 train_time:13845ms step_avg:93.55ms +step:149/1695 train_time:13940ms 
step_avg:93.56ms +step:150/1695 train_time:14035ms step_avg:93.57ms +step:151/1695 train_time:14129ms step_avg:93.57ms +step:152/1695 train_time:14223ms step_avg:93.57ms +step:153/1695 train_time:14318ms step_avg:93.58ms +step:154/1695 train_time:14411ms step_avg:93.58ms +step:155/1695 train_time:14504ms step_avg:93.58ms +step:156/1695 train_time:14598ms step_avg:93.58ms +step:157/1695 train_time:14692ms step_avg:93.58ms +step:158/1695 train_time:14785ms step_avg:93.58ms +step:159/1695 train_time:14879ms step_avg:93.58ms +step:160/1695 train_time:14974ms step_avg:93.59ms +step:161/1695 train_time:15068ms step_avg:93.59ms +step:162/1695 train_time:15162ms step_avg:93.59ms +step:163/1695 train_time:15256ms step_avg:93.59ms +step:164/1695 train_time:15349ms step_avg:93.59ms +step:165/1695 train_time:15443ms step_avg:93.60ms +step:166/1695 train_time:15537ms step_avg:93.59ms +step:167/1695 train_time:15630ms step_avg:93.59ms +step:168/1695 train_time:15724ms step_avg:93.59ms +step:169/1695 train_time:15818ms step_avg:93.60ms +step:170/1695 train_time:15912ms step_avg:93.60ms +step:171/1695 train_time:16006ms step_avg:93.60ms +step:172/1695 train_time:16101ms step_avg:93.61ms +step:173/1695 train_time:16197ms step_avg:93.62ms +step:174/1695 train_time:16290ms step_avg:93.62ms +step:175/1695 train_time:16385ms step_avg:93.63ms +step:176/1695 train_time:16479ms step_avg:93.63ms +step:177/1695 train_time:16572ms step_avg:93.63ms +step:178/1695 train_time:16666ms step_avg:93.63ms +step:179/1695 train_time:16759ms step_avg:93.63ms +step:180/1695 train_time:16854ms step_avg:93.63ms +step:181/1695 train_time:16948ms step_avg:93.63ms +step:182/1695 train_time:17041ms step_avg:93.63ms +step:183/1695 train_time:17137ms step_avg:93.64ms +step:184/1695 train_time:17230ms step_avg:93.64ms +step:185/1695 train_time:17325ms step_avg:93.65ms +step:186/1695 train_time:17420ms step_avg:93.65ms +step:187/1695 train_time:17514ms step_avg:93.66ms +step:188/1695 train_time:17607ms step_avg:93.66ms +step:189/1695 train_time:17701ms step_avg:93.66ms +step:190/1695 train_time:17795ms step_avg:93.66ms +step:191/1695 train_time:17889ms step_avg:93.66ms +step:192/1695 train_time:17983ms step_avg:93.66ms +step:193/1695 train_time:18077ms step_avg:93.66ms +step:194/1695 train_time:18171ms step_avg:93.66ms +step:195/1695 train_time:18266ms step_avg:93.67ms +step:196/1695 train_time:18361ms step_avg:93.68ms +step:197/1695 train_time:18455ms step_avg:93.68ms +step:198/1695 train_time:18549ms step_avg:93.68ms +step:199/1695 train_time:18642ms step_avg:93.68ms +step:200/1695 train_time:18737ms step_avg:93.68ms +step:201/1695 train_time:18830ms step_avg:93.68ms +step:202/1695 train_time:18924ms step_avg:93.68ms +step:203/1695 train_time:19019ms step_avg:93.69ms +step:204/1695 train_time:19114ms step_avg:93.70ms +step:205/1695 train_time:19207ms step_avg:93.69ms +step:206/1695 train_time:19302ms step_avg:93.70ms +step:207/1695 train_time:19397ms step_avg:93.70ms +step:208/1695 train_time:19490ms step_avg:93.70ms +step:209/1695 train_time:19585ms step_avg:93.71ms +step:210/1695 train_time:19680ms step_avg:93.71ms +step:211/1695 train_time:19773ms step_avg:93.71ms +step:212/1695 train_time:19867ms step_avg:93.71ms +step:213/1695 train_time:19961ms step_avg:93.71ms +step:214/1695 train_time:20055ms step_avg:93.72ms +step:215/1695 train_time:20148ms step_avg:93.71ms +step:216/1695 train_time:20242ms step_avg:93.71ms +step:217/1695 train_time:20335ms step_avg:93.71ms +step:218/1695 train_time:20429ms step_avg:93.71ms +step:219/1695 
train_time:20523ms step_avg:93.71ms +step:220/1695 train_time:20618ms step_avg:93.72ms +step:221/1695 train_time:20711ms step_avg:93.72ms +step:222/1695 train_time:20805ms step_avg:93.72ms +step:223/1695 train_time:20900ms step_avg:93.72ms +step:224/1695 train_time:20993ms step_avg:93.72ms +step:225/1695 train_time:21087ms step_avg:93.72ms +step:226/1695 train_time:21181ms step_avg:93.72ms +step:227/1695 train_time:21275ms step_avg:93.72ms +step:228/1695 train_time:21368ms step_avg:93.72ms +step:229/1695 train_time:21463ms step_avg:93.72ms +step:230/1695 train_time:21558ms step_avg:93.73ms +step:231/1695 train_time:21651ms step_avg:93.73ms +step:232/1695 train_time:21745ms step_avg:93.73ms +step:233/1695 train_time:21839ms step_avg:93.73ms +step:234/1695 train_time:21932ms step_avg:93.73ms +step:235/1695 train_time:22026ms step_avg:93.73ms +step:236/1695 train_time:22121ms step_avg:93.73ms +step:237/1695 train_time:22216ms step_avg:93.74ms +step:238/1695 train_time:22309ms step_avg:93.74ms +step:239/1695 train_time:22404ms step_avg:93.74ms +step:240/1695 train_time:22498ms step_avg:93.74ms +step:241/1695 train_time:22592ms step_avg:93.74ms +step:242/1695 train_time:22685ms step_avg:93.74ms +step:243/1695 train_time:22780ms step_avg:93.74ms +step:244/1695 train_time:22873ms step_avg:93.74ms +step:245/1695 train_time:22967ms step_avg:93.74ms +step:246/1695 train_time:23061ms step_avg:93.74ms +step:247/1695 train_time:23154ms step_avg:93.74ms +step:248/1695 train_time:23248ms step_avg:93.74ms +step:249/1695 train_time:23342ms step_avg:93.74ms +step:250/1695 train_time:23436ms step_avg:93.74ms +step:250/1695 val_loss:4.0653 train_time:23527ms step_avg:94.11ms +step:251/1695 train_time:23554ms step_avg:93.84ms +step:252/1695 train_time:23632ms step_avg:93.78ms +step:253/1695 train_time:23730ms step_avg:93.79ms +step:254/1695 train_time:23824ms step_avg:93.80ms +step:255/1695 train_time:23918ms step_avg:93.80ms +step:256/1695 train_time:24012ms step_avg:93.80ms +step:257/1695 train_time:24106ms step_avg:93.80ms +step:258/1695 train_time:24200ms step_avg:93.80ms +step:259/1695 train_time:24293ms step_avg:93.80ms +step:260/1695 train_time:24387ms step_avg:93.79ms +step:261/1695 train_time:24482ms step_avg:93.80ms +step:262/1695 train_time:24579ms step_avg:93.81ms +step:263/1695 train_time:24673ms step_avg:93.81ms +step:264/1695 train_time:24768ms step_avg:93.82ms +step:265/1695 train_time:24863ms step_avg:93.82ms +step:266/1695 train_time:24957ms step_avg:93.82ms +step:267/1695 train_time:25051ms step_avg:93.82ms +step:268/1695 train_time:25145ms step_avg:93.82ms +step:269/1695 train_time:25239ms step_avg:93.83ms +step:270/1695 train_time:25334ms step_avg:93.83ms +step:271/1695 train_time:25427ms step_avg:93.83ms +step:272/1695 train_time:25523ms step_avg:93.83ms +step:273/1695 train_time:25619ms step_avg:93.84ms +step:274/1695 train_time:25714ms step_avg:93.85ms +step:275/1695 train_time:25809ms step_avg:93.85ms +step:276/1695 train_time:25903ms step_avg:93.85ms +step:277/1695 train_time:25997ms step_avg:93.85ms +step:278/1695 train_time:26091ms step_avg:93.85ms +step:279/1695 train_time:26186ms step_avg:93.86ms +step:280/1695 train_time:26281ms step_avg:93.86ms +step:281/1695 train_time:26375ms step_avg:93.86ms +step:282/1695 train_time:26469ms step_avg:93.86ms +step:283/1695 train_time:26563ms step_avg:93.86ms +step:284/1695 train_time:26658ms step_avg:93.87ms +step:285/1695 train_time:26753ms step_avg:93.87ms +step:286/1695 train_time:26847ms step_avg:93.87ms +step:287/1695 train_time:26941ms 
step_avg:93.87ms +step:288/1695 train_time:27037ms step_avg:93.88ms +step:289/1695 train_time:27131ms step_avg:93.88ms +step:290/1695 train_time:27225ms step_avg:93.88ms +step:291/1695 train_time:27319ms step_avg:93.88ms +step:292/1695 train_time:27414ms step_avg:93.88ms +step:293/1695 train_time:27507ms step_avg:93.88ms +step:294/1695 train_time:27602ms step_avg:93.88ms +step:295/1695 train_time:27697ms step_avg:93.89ms +step:296/1695 train_time:27792ms step_avg:93.89ms +step:297/1695 train_time:27887ms step_avg:93.90ms +step:298/1695 train_time:27982ms step_avg:93.90ms +step:299/1695 train_time:28077ms step_avg:93.90ms +step:300/1695 train_time:28171ms step_avg:93.90ms +step:301/1695 train_time:28265ms step_avg:93.90ms +step:302/1695 train_time:28360ms step_avg:93.91ms +step:303/1695 train_time:28454ms step_avg:93.91ms +step:304/1695 train_time:28548ms step_avg:93.91ms +step:305/1695 train_time:28642ms step_avg:93.91ms +step:306/1695 train_time:28736ms step_avg:93.91ms +step:307/1695 train_time:28830ms step_avg:93.91ms +step:308/1695 train_time:28924ms step_avg:93.91ms +step:309/1695 train_time:29018ms step_avg:93.91ms +step:310/1695 train_time:29113ms step_avg:93.91ms +step:311/1695 train_time:29206ms step_avg:93.91ms +step:312/1695 train_time:29302ms step_avg:93.92ms +step:313/1695 train_time:29398ms step_avg:93.92ms +step:314/1695 train_time:29492ms step_avg:93.92ms +step:315/1695 train_time:29585ms step_avg:93.92ms +step:316/1695 train_time:29681ms step_avg:93.93ms +step:317/1695 train_time:29775ms step_avg:93.93ms +step:318/1695 train_time:29869ms step_avg:93.93ms +step:319/1695 train_time:29963ms step_avg:93.93ms +step:320/1695 train_time:30058ms step_avg:93.93ms +step:321/1695 train_time:30153ms step_avg:93.93ms +step:322/1695 train_time:30247ms step_avg:93.93ms +step:323/1695 train_time:30341ms step_avg:93.94ms +step:324/1695 train_time:30436ms step_avg:93.94ms +step:325/1695 train_time:30531ms step_avg:93.94ms +step:326/1695 train_time:30625ms step_avg:93.94ms +step:327/1695 train_time:30721ms step_avg:93.95ms +step:328/1695 train_time:30815ms step_avg:93.95ms +step:329/1695 train_time:30909ms step_avg:93.95ms +step:330/1695 train_time:31003ms step_avg:93.95ms +step:331/1695 train_time:31098ms step_avg:93.95ms +step:332/1695 train_time:31193ms step_avg:93.95ms +step:333/1695 train_time:31286ms step_avg:93.95ms +step:334/1695 train_time:31381ms step_avg:93.95ms +step:335/1695 train_time:31476ms step_avg:93.96ms +step:336/1695 train_time:31569ms step_avg:93.96ms +step:337/1695 train_time:31664ms step_avg:93.96ms +step:338/1695 train_time:31758ms step_avg:93.96ms +step:339/1695 train_time:31853ms step_avg:93.96ms +step:340/1695 train_time:31946ms step_avg:93.96ms +step:341/1695 train_time:32041ms step_avg:93.96ms +step:342/1695 train_time:32136ms step_avg:93.96ms +step:343/1695 train_time:32229ms step_avg:93.96ms +step:344/1695 train_time:32323ms step_avg:93.96ms +step:345/1695 train_time:32418ms step_avg:93.97ms +step:346/1695 train_time:32512ms step_avg:93.97ms +step:347/1695 train_time:32606ms step_avg:93.97ms +step:348/1695 train_time:32701ms step_avg:93.97ms +step:349/1695 train_time:32796ms step_avg:93.97ms +step:350/1695 train_time:32889ms step_avg:93.97ms +step:351/1695 train_time:32985ms step_avg:93.97ms +step:352/1695 train_time:33079ms step_avg:93.97ms +step:353/1695 train_time:33174ms step_avg:93.98ms +step:354/1695 train_time:33268ms step_avg:93.98ms +step:355/1695 train_time:33362ms step_avg:93.98ms +step:356/1695 train_time:33457ms step_avg:93.98ms +step:357/1695 
train_time:33551ms step_avg:93.98ms +step:358/1695 train_time:33645ms step_avg:93.98ms +step:359/1695 train_time:33740ms step_avg:93.98ms +step:360/1695 train_time:33836ms step_avg:93.99ms +step:361/1695 train_time:33930ms step_avg:93.99ms +step:362/1695 train_time:34024ms step_avg:93.99ms +step:363/1695 train_time:34119ms step_avg:93.99ms +step:364/1695 train_time:34213ms step_avg:93.99ms +step:365/1695 train_time:34306ms step_avg:93.99ms +step:366/1695 train_time:34402ms step_avg:93.99ms +step:367/1695 train_time:34497ms step_avg:94.00ms +step:368/1695 train_time:34591ms step_avg:94.00ms +step:369/1695 train_time:34685ms step_avg:94.00ms +step:370/1695 train_time:34780ms step_avg:94.00ms +step:371/1695 train_time:34875ms step_avg:94.00ms +step:372/1695 train_time:34969ms step_avg:94.00ms +step:373/1695 train_time:35063ms step_avg:94.00ms +step:374/1695 train_time:35157ms step_avg:94.00ms +step:375/1695 train_time:35251ms step_avg:94.00ms +step:375/1695 val_loss:3.8671 train_time:35344ms step_avg:94.25ms +step:376/1695 train_time:35370ms step_avg:94.07ms +step:377/1695 train_time:35449ms step_avg:94.03ms +step:378/1695 train_time:35549ms step_avg:94.05ms +step:379/1695 train_time:35645ms step_avg:94.05ms +step:380/1695 train_time:35740ms step_avg:94.05ms +step:381/1695 train_time:35836ms step_avg:94.06ms +step:382/1695 train_time:35931ms step_avg:94.06ms +step:383/1695 train_time:36026ms step_avg:94.06ms +step:384/1695 train_time:36122ms step_avg:94.07ms +step:385/1695 train_time:36218ms step_avg:94.07ms +step:386/1695 train_time:36313ms step_avg:94.07ms +step:387/1695 train_time:36409ms step_avg:94.08ms +step:388/1695 train_time:36506ms step_avg:94.09ms +step:389/1695 train_time:36604ms step_avg:94.10ms +step:390/1695 train_time:36700ms step_avg:94.10ms +step:391/1695 train_time:36796ms step_avg:94.11ms +step:392/1695 train_time:36893ms step_avg:94.11ms +step:393/1695 train_time:36989ms step_avg:94.12ms +step:394/1695 train_time:37084ms step_avg:94.12ms +step:395/1695 train_time:37179ms step_avg:94.12ms +step:396/1695 train_time:37275ms step_avg:94.13ms +step:397/1695 train_time:37371ms step_avg:94.13ms +step:398/1695 train_time:37468ms step_avg:94.14ms +step:399/1695 train_time:37564ms step_avg:94.15ms +step:400/1695 train_time:37660ms step_avg:94.15ms +step:401/1695 train_time:37757ms step_avg:94.16ms +step:402/1695 train_time:37853ms step_avg:94.16ms +step:403/1695 train_time:37950ms step_avg:94.17ms +step:404/1695 train_time:38045ms step_avg:94.17ms +step:405/1695 train_time:38141ms step_avg:94.18ms +step:406/1695 train_time:38237ms step_avg:94.18ms +step:407/1695 train_time:38333ms step_avg:94.18ms +step:408/1695 train_time:38429ms step_avg:94.19ms +step:409/1695 train_time:38525ms step_avg:94.19ms +step:410/1695 train_time:38621ms step_avg:94.20ms +step:411/1695 train_time:38717ms step_avg:94.20ms +step:412/1695 train_time:38813ms step_avg:94.21ms +step:413/1695 train_time:38909ms step_avg:94.21ms +step:414/1695 train_time:39005ms step_avg:94.21ms +step:415/1695 train_time:39101ms step_avg:94.22ms +step:416/1695 train_time:39197ms step_avg:94.22ms +step:417/1695 train_time:39292ms step_avg:94.23ms +step:418/1695 train_time:39388ms step_avg:94.23ms +step:419/1695 train_time:39484ms step_avg:94.23ms +step:420/1695 train_time:39580ms step_avg:94.24ms +step:421/1695 train_time:39676ms step_avg:94.24ms +step:422/1695 train_time:39772ms step_avg:94.25ms +step:423/1695 train_time:39869ms step_avg:94.25ms +step:424/1695 train_time:39965ms step_avg:94.26ms +step:425/1695 train_time:40061ms 
step_avg:94.26ms +step:426/1695 train_time:40157ms step_avg:94.27ms +step:427/1695 train_time:40253ms step_avg:94.27ms +step:428/1695 train_time:40349ms step_avg:94.27ms +step:429/1695 train_time:40446ms step_avg:94.28ms +step:430/1695 train_time:40541ms step_avg:94.28ms +step:431/1695 train_time:40637ms step_avg:94.29ms +step:432/1695 train_time:40733ms step_avg:94.29ms +step:433/1695 train_time:40829ms step_avg:94.29ms +step:434/1695 train_time:40926ms step_avg:94.30ms +step:435/1695 train_time:41022ms step_avg:94.30ms +step:436/1695 train_time:41118ms step_avg:94.31ms +step:437/1695 train_time:41214ms step_avg:94.31ms +step:438/1695 train_time:41312ms step_avg:94.32ms +step:439/1695 train_time:41408ms step_avg:94.32ms +step:440/1695 train_time:41504ms step_avg:94.33ms +step:441/1695 train_time:41599ms step_avg:94.33ms +step:442/1695 train_time:41695ms step_avg:94.33ms +step:443/1695 train_time:41792ms step_avg:94.34ms +step:444/1695 train_time:41888ms step_avg:94.34ms +step:445/1695 train_time:41986ms step_avg:94.35ms +step:446/1695 train_time:42082ms step_avg:94.35ms +step:447/1695 train_time:42178ms step_avg:94.36ms +step:448/1695 train_time:42274ms step_avg:94.36ms +step:449/1695 train_time:42371ms step_avg:94.37ms +step:450/1695 train_time:42467ms step_avg:94.37ms +step:451/1695 train_time:42562ms step_avg:94.37ms +step:452/1695 train_time:42658ms step_avg:94.38ms +step:453/1695 train_time:42755ms step_avg:94.38ms +step:454/1695 train_time:42851ms step_avg:94.39ms +step:455/1695 train_time:42948ms step_avg:94.39ms +step:456/1695 train_time:43044ms step_avg:94.39ms +step:457/1695 train_time:43139ms step_avg:94.40ms +step:458/1695 train_time:43236ms step_avg:94.40ms +step:459/1695 train_time:43332ms step_avg:94.41ms +step:460/1695 train_time:43429ms step_avg:94.41ms +step:461/1695 train_time:43526ms step_avg:94.42ms +step:462/1695 train_time:43621ms step_avg:94.42ms +step:463/1695 train_time:43717ms step_avg:94.42ms +step:464/1695 train_time:43812ms step_avg:94.42ms +step:465/1695 train_time:43909ms step_avg:94.43ms +step:466/1695 train_time:44005ms step_avg:94.43ms +step:467/1695 train_time:44100ms step_avg:94.43ms +step:468/1695 train_time:44197ms step_avg:94.44ms +step:469/1695 train_time:44293ms step_avg:94.44ms +step:470/1695 train_time:44389ms step_avg:94.45ms +step:471/1695 train_time:44486ms step_avg:94.45ms +step:472/1695 train_time:44582ms step_avg:94.45ms +step:473/1695 train_time:44677ms step_avg:94.46ms +step:474/1695 train_time:44773ms step_avg:94.46ms +step:475/1695 train_time:44870ms step_avg:94.46ms +step:476/1695 train_time:44966ms step_avg:94.47ms +step:477/1695 train_time:45061ms step_avg:94.47ms +step:478/1695 train_time:45157ms step_avg:94.47ms +step:479/1695 train_time:45253ms step_avg:94.47ms +step:480/1695 train_time:45350ms step_avg:94.48ms +step:481/1695 train_time:45446ms step_avg:94.48ms +step:482/1695 train_time:45542ms step_avg:94.49ms +step:483/1695 train_time:45638ms step_avg:94.49ms +step:484/1695 train_time:45735ms step_avg:94.49ms +step:485/1695 train_time:45830ms step_avg:94.49ms +step:486/1695 train_time:45927ms step_avg:94.50ms +step:487/1695 train_time:46023ms step_avg:94.50ms +step:488/1695 train_time:46119ms step_avg:94.51ms +step:489/1695 train_time:46216ms step_avg:94.51ms +step:490/1695 train_time:46312ms step_avg:94.51ms +step:491/1695 train_time:46408ms step_avg:94.52ms +step:492/1695 train_time:46503ms step_avg:94.52ms +step:493/1695 train_time:46598ms step_avg:94.52ms +step:494/1695 train_time:46694ms step_avg:94.52ms +step:495/1695 
train_time:46790ms step_avg:94.53ms +step:496/1695 train_time:46886ms step_avg:94.53ms +step:497/1695 train_time:46982ms step_avg:94.53ms +step:498/1695 train_time:47078ms step_avg:94.53ms +step:499/1695 train_time:47175ms step_avg:94.54ms +step:500/1695 train_time:47271ms step_avg:94.54ms +step:500/1695 val_loss:3.7271 train_time:47366ms step_avg:94.73ms +step:501/1695 train_time:47393ms step_avg:94.60ms +step:502/1695 train_time:47474ms step_avg:94.57ms +step:503/1695 train_time:47572ms step_avg:94.58ms +step:504/1695 train_time:47669ms step_avg:94.58ms +step:505/1695 train_time:47764ms step_avg:94.58ms +step:506/1695 train_time:47860ms step_avg:94.58ms +step:507/1695 train_time:47955ms step_avg:94.59ms +step:508/1695 train_time:48051ms step_avg:94.59ms +step:509/1695 train_time:48146ms step_avg:94.59ms +step:510/1695 train_time:48242ms step_avg:94.59ms +step:511/1695 train_time:48337ms step_avg:94.59ms +step:512/1695 train_time:48434ms step_avg:94.60ms +step:513/1695 train_time:48532ms step_avg:94.60ms +step:514/1695 train_time:48629ms step_avg:94.61ms +step:515/1695 train_time:48725ms step_avg:94.61ms +step:516/1695 train_time:48821ms step_avg:94.61ms +step:517/1695 train_time:48917ms step_avg:94.62ms +step:518/1695 train_time:49014ms step_avg:94.62ms +step:519/1695 train_time:49110ms step_avg:94.62ms +step:520/1695 train_time:49205ms step_avg:94.63ms +step:521/1695 train_time:49301ms step_avg:94.63ms +step:522/1695 train_time:49398ms step_avg:94.63ms +step:523/1695 train_time:49495ms step_avg:94.64ms +step:524/1695 train_time:49593ms step_avg:94.64ms +step:525/1695 train_time:49690ms step_avg:94.65ms +step:526/1695 train_time:49788ms step_avg:94.65ms +step:527/1695 train_time:49883ms step_avg:94.66ms +step:528/1695 train_time:49979ms step_avg:94.66ms +step:529/1695 train_time:50075ms step_avg:94.66ms +step:530/1695 train_time:50171ms step_avg:94.66ms +step:531/1695 train_time:50268ms step_avg:94.67ms +step:532/1695 train_time:50363ms step_avg:94.67ms +step:533/1695 train_time:50459ms step_avg:94.67ms +step:534/1695 train_time:50557ms step_avg:94.68ms +step:535/1695 train_time:50654ms step_avg:94.68ms +step:536/1695 train_time:50752ms step_avg:94.69ms +step:537/1695 train_time:50848ms step_avg:94.69ms +step:538/1695 train_time:50944ms step_avg:94.69ms +step:539/1695 train_time:51040ms step_avg:94.69ms +step:540/1695 train_time:51136ms step_avg:94.70ms +step:541/1695 train_time:51233ms step_avg:94.70ms +step:542/1695 train_time:51329ms step_avg:94.70ms +step:543/1695 train_time:51425ms step_avg:94.71ms +step:544/1695 train_time:51521ms step_avg:94.71ms +step:545/1695 train_time:51619ms step_avg:94.71ms +step:546/1695 train_time:51717ms step_avg:94.72ms +step:547/1695 train_time:51815ms step_avg:94.73ms +step:548/1695 train_time:51912ms step_avg:94.73ms +step:549/1695 train_time:52009ms step_avg:94.73ms +step:550/1695 train_time:52105ms step_avg:94.74ms +step:551/1695 train_time:52201ms step_avg:94.74ms +step:552/1695 train_time:52298ms step_avg:94.74ms +step:553/1695 train_time:52394ms step_avg:94.74ms +step:554/1695 train_time:52491ms step_avg:94.75ms +step:555/1695 train_time:52587ms step_avg:94.75ms +step:556/1695 train_time:52683ms step_avg:94.75ms +step:557/1695 train_time:52780ms step_avg:94.76ms +step:558/1695 train_time:52877ms step_avg:94.76ms +step:559/1695 train_time:52974ms step_avg:94.77ms +step:560/1695 train_time:53070ms step_avg:94.77ms +step:561/1695 train_time:53166ms step_avg:94.77ms +step:562/1695 train_time:53262ms step_avg:94.77ms +step:563/1695 train_time:53358ms 
step_avg:94.78ms +step:564/1695 train_time:53455ms step_avg:94.78ms +step:565/1695 train_time:53551ms step_avg:94.78ms +step:566/1695 train_time:53647ms step_avg:94.78ms +step:567/1695 train_time:53743ms step_avg:94.79ms +step:568/1695 train_time:54068ms step_avg:95.19ms +step:569/1695 train_time:54162ms step_avg:95.19ms +step:570/1695 train_time:54258ms step_avg:95.19ms +step:571/1695 train_time:54354ms step_avg:95.19ms +step:572/1695 train_time:54448ms step_avg:95.19ms +step:573/1695 train_time:54544ms step_avg:95.19ms +step:574/1695 train_time:54640ms step_avg:95.19ms +step:575/1695 train_time:54735ms step_avg:95.19ms +step:576/1695 train_time:54831ms step_avg:95.19ms +step:577/1695 train_time:54927ms step_avg:95.19ms +step:578/1695 train_time:55025ms step_avg:95.20ms +step:579/1695 train_time:55123ms step_avg:95.20ms +step:580/1695 train_time:55220ms step_avg:95.21ms +step:581/1695 train_time:55317ms step_avg:95.21ms +step:582/1695 train_time:55413ms step_avg:95.21ms +step:583/1695 train_time:55508ms step_avg:95.21ms +step:584/1695 train_time:55604ms step_avg:95.21ms +step:585/1695 train_time:55699ms step_avg:95.21ms +step:586/1695 train_time:55796ms step_avg:95.21ms +step:587/1695 train_time:55891ms step_avg:95.22ms +step:588/1695 train_time:55989ms step_avg:95.22ms +step:589/1695 train_time:56085ms step_avg:95.22ms +step:590/1695 train_time:56181ms step_avg:95.22ms +step:591/1695 train_time:56279ms step_avg:95.23ms +step:592/1695 train_time:56376ms step_avg:95.23ms +step:593/1695 train_time:56473ms step_avg:95.23ms +step:594/1695 train_time:56569ms step_avg:95.23ms +step:595/1695 train_time:56664ms step_avg:95.23ms +step:596/1695 train_time:56760ms step_avg:95.23ms +step:597/1695 train_time:56856ms step_avg:95.24ms +step:598/1695 train_time:56953ms step_avg:95.24ms +step:599/1695 train_time:57050ms step_avg:95.24ms +step:600/1695 train_time:57147ms step_avg:95.24ms +step:601/1695 train_time:57244ms step_avg:95.25ms +step:602/1695 train_time:57340ms step_avg:95.25ms +step:603/1695 train_time:57437ms step_avg:95.25ms +step:604/1695 train_time:57533ms step_avg:95.25ms +step:605/1695 train_time:57630ms step_avg:95.26ms +step:606/1695 train_time:57725ms step_avg:95.26ms +step:607/1695 train_time:57821ms step_avg:95.26ms +step:608/1695 train_time:57918ms step_avg:95.26ms +step:609/1695 train_time:58015ms step_avg:95.26ms +step:610/1695 train_time:58112ms step_avg:95.27ms +step:611/1695 train_time:58209ms step_avg:95.27ms +step:612/1695 train_time:58305ms step_avg:95.27ms +step:613/1695 train_time:58401ms step_avg:95.27ms +step:614/1695 train_time:58498ms step_avg:95.27ms +step:615/1695 train_time:58594ms step_avg:95.28ms +step:616/1695 train_time:58690ms step_avg:95.28ms +step:617/1695 train_time:58786ms step_avg:95.28ms +step:618/1695 train_time:58882ms step_avg:95.28ms +step:619/1695 train_time:58979ms step_avg:95.28ms +step:620/1695 train_time:59076ms step_avg:95.28ms +step:621/1695 train_time:59173ms step_avg:95.29ms +step:622/1695 train_time:59270ms step_avg:95.29ms +step:623/1695 train_time:59366ms step_avg:95.29ms +step:624/1695 train_time:59462ms step_avg:95.29ms +step:625/1695 train_time:59558ms step_avg:95.29ms +step:625/1695 val_loss:3.6433 train_time:59653ms step_avg:95.44ms +step:626/1695 train_time:59679ms step_avg:95.33ms +step:627/1695 train_time:59760ms step_avg:95.31ms +step:628/1695 train_time:59861ms step_avg:95.32ms +step:629/1695 train_time:60187ms step_avg:95.69ms +step:630/1695 train_time:60284ms step_avg:95.69ms +step:631/1695 train_time:60380ms step_avg:95.69ms 
+step:632/1695 train_time:60477ms step_avg:95.69ms +step:633/1695 train_time:60574ms step_avg:95.69ms +step:634/1695 train_time:60672ms step_avg:95.70ms +step:635/1695 train_time:61019ms step_avg:96.09ms +step:636/1695 train_time:61114ms step_avg:96.09ms +step:637/1695 train_time:61211ms step_avg:96.09ms +step:638/1695 train_time:61308ms step_avg:96.09ms +step:639/1695 train_time:61405ms step_avg:96.10ms +step:640/1695 train_time:61502ms step_avg:96.10ms +step:641/1695 train_time:61599ms step_avg:96.10ms +step:642/1695 train_time:61696ms step_avg:96.10ms +step:643/1695 train_time:61794ms step_avg:96.10ms +step:644/1695 train_time:61891ms step_avg:96.10ms +step:645/1695 train_time:61990ms step_avg:96.11ms +step:646/1695 train_time:62090ms step_avg:96.11ms +step:647/1695 train_time:62188ms step_avg:96.12ms +step:648/1695 train_time:62285ms step_avg:96.12ms +step:649/1695 train_time:62383ms step_avg:96.12ms +step:650/1695 train_time:62479ms step_avg:96.12ms +step:651/1695 train_time:62577ms step_avg:96.12ms +step:652/1695 train_time:62674ms step_avg:96.13ms +step:653/1695 train_time:62772ms step_avg:96.13ms +step:654/1695 train_time:62870ms step_avg:96.13ms +step:655/1695 train_time:62969ms step_avg:96.14ms +step:656/1695 train_time:63068ms step_avg:96.14ms +step:657/1695 train_time:63166ms step_avg:96.14ms +step:658/1695 train_time:63264ms step_avg:96.15ms +step:659/1695 train_time:63361ms step_avg:96.15ms +step:660/1695 train_time:63458ms step_avg:96.15ms +step:661/1695 train_time:63556ms step_avg:96.15ms +step:662/1695 train_time:63654ms step_avg:96.15ms +step:663/1695 train_time:63752ms step_avg:96.16ms +step:664/1695 train_time:63850ms step_avg:96.16ms +step:665/1695 train_time:63948ms step_avg:96.16ms +step:666/1695 train_time:64046ms step_avg:96.17ms +step:667/1695 train_time:64143ms step_avg:96.17ms +step:668/1695 train_time:64241ms step_avg:96.17ms +step:669/1695 train_time:64339ms step_avg:96.17ms +step:670/1695 train_time:64438ms step_avg:96.18ms +step:671/1695 train_time:64536ms step_avg:96.18ms +step:672/1695 train_time:64634ms step_avg:96.18ms +step:673/1695 train_time:64732ms step_avg:96.18ms +step:674/1695 train_time:64830ms step_avg:96.19ms +step:675/1695 train_time:64928ms step_avg:96.19ms +step:676/1695 train_time:65025ms step_avg:96.19ms +step:677/1695 train_time:65123ms step_avg:96.19ms +step:678/1695 train_time:65222ms step_avg:96.20ms +step:679/1695 train_time:65320ms step_avg:96.20ms +step:680/1695 train_time:65418ms step_avg:96.20ms +step:681/1695 train_time:65516ms step_avg:96.21ms +step:682/1695 train_time:65614ms step_avg:96.21ms +step:683/1695 train_time:65712ms step_avg:96.21ms +step:684/1695 train_time:65810ms step_avg:96.21ms +step:685/1695 train_time:65908ms step_avg:96.22ms +step:686/1695 train_time:66007ms step_avg:96.22ms +step:687/1695 train_time:66105ms step_avg:96.22ms +step:688/1695 train_time:66203ms step_avg:96.23ms +step:689/1695 train_time:66301ms step_avg:96.23ms +step:690/1695 train_time:66398ms step_avg:96.23ms +step:691/1695 train_time:66496ms step_avg:96.23ms +step:692/1695 train_time:66594ms step_avg:96.23ms +step:693/1695 train_time:66692ms step_avg:96.24ms +step:694/1695 train_time:66789ms step_avg:96.24ms +step:695/1695 train_time:66887ms step_avg:96.24ms +step:696/1695 train_time:66984ms step_avg:96.24ms +step:697/1695 train_time:67082ms step_avg:96.24ms +step:698/1695 train_time:67180ms step_avg:96.25ms +step:699/1695 train_time:67279ms step_avg:96.25ms +step:700/1695 train_time:67377ms step_avg:96.25ms +step:701/1695 train_time:67476ms 
step_avg:96.26ms +step:702/1695 train_time:67574ms step_avg:96.26ms +step:703/1695 train_time:67672ms step_avg:96.26ms +step:704/1695 train_time:67769ms step_avg:96.26ms +step:705/1695 train_time:67867ms step_avg:96.27ms +step:706/1695 train_time:67965ms step_avg:96.27ms +step:707/1695 train_time:68063ms step_avg:96.27ms +step:708/1695 train_time:68160ms step_avg:96.27ms +step:709/1695 train_time:68258ms step_avg:96.27ms +step:710/1695 train_time:68356ms step_avg:96.28ms +step:711/1695 train_time:68453ms step_avg:96.28ms +step:712/1695 train_time:68552ms step_avg:96.28ms +step:713/1695 train_time:68650ms step_avg:96.28ms +step:714/1695 train_time:68747ms step_avg:96.28ms +step:715/1695 train_time:68844ms step_avg:96.29ms +step:716/1695 train_time:68942ms step_avg:96.29ms +step:717/1695 train_time:69041ms step_avg:96.29ms +step:718/1695 train_time:69138ms step_avg:96.29ms +step:719/1695 train_time:69236ms step_avg:96.30ms +step:720/1695 train_time:69335ms step_avg:96.30ms +step:721/1695 train_time:69434ms step_avg:96.30ms +step:722/1695 train_time:69533ms step_avg:96.31ms +step:723/1695 train_time:69630ms step_avg:96.31ms +step:724/1695 train_time:69728ms step_avg:96.31ms +step:725/1695 train_time:69826ms step_avg:96.31ms +step:726/1695 train_time:70167ms step_avg:96.65ms +step:727/1695 train_time:70262ms step_avg:96.65ms +step:728/1695 train_time:70360ms step_avg:96.65ms +step:729/1695 train_time:70457ms step_avg:96.65ms +step:730/1695 train_time:70555ms step_avg:96.65ms +step:731/1695 train_time:70653ms step_avg:96.65ms +step:732/1695 train_time:70750ms step_avg:96.65ms +step:733/1695 train_time:70847ms step_avg:96.65ms +step:734/1695 train_time:70943ms step_avg:96.65ms +step:735/1695 train_time:71040ms step_avg:96.65ms +step:736/1695 train_time:71142ms step_avg:96.66ms +step:737/1695 train_time:71241ms step_avg:96.66ms +step:738/1695 train_time:71338ms step_avg:96.66ms +step:739/1695 train_time:71437ms step_avg:96.67ms +step:740/1695 train_time:71535ms step_avg:96.67ms +step:741/1695 train_time:71633ms step_avg:96.67ms +step:742/1695 train_time:71731ms step_avg:96.67ms +step:743/1695 train_time:71829ms step_avg:96.67ms +step:744/1695 train_time:71926ms step_avg:96.67ms +step:745/1695 train_time:72023ms step_avg:96.68ms +step:746/1695 train_time:72122ms step_avg:96.68ms +step:747/1695 train_time:72221ms step_avg:96.68ms +step:748/1695 train_time:72319ms step_avg:96.68ms +step:749/1695 train_time:72417ms step_avg:96.69ms +step:750/1695 train_time:72516ms step_avg:96.69ms +step:750/1695 val_loss:3.5813 train_time:72612ms step_avg:96.82ms +step:751/1695 train_time:72638ms step_avg:96.72ms +step:752/1695 train_time:72722ms step_avg:96.71ms +step:753/1695 train_time:72825ms step_avg:96.71ms +step:754/1695 train_time:72924ms step_avg:96.72ms +step:755/1695 train_time:73022ms step_avg:96.72ms +step:756/1695 train_time:73119ms step_avg:96.72ms +step:757/1695 train_time:73216ms step_avg:96.72ms +step:758/1695 train_time:73313ms step_avg:96.72ms +step:759/1695 train_time:73410ms step_avg:96.72ms +step:760/1695 train_time:73508ms step_avg:96.72ms +step:761/1695 train_time:73605ms step_avg:96.72ms +step:762/1695 train_time:73705ms step_avg:96.73ms +step:763/1695 train_time:73806ms step_avg:96.73ms +step:764/1695 train_time:73907ms step_avg:96.74ms +step:765/1695 train_time:74008ms step_avg:96.74ms +step:766/1695 train_time:74106ms step_avg:96.74ms +step:767/1695 train_time:74205ms step_avg:96.75ms +step:768/1695 train_time:74304ms step_avg:96.75ms +step:769/1695 train_time:74402ms step_avg:96.75ms 
+step:770/1695 train_time:74499ms step_avg:96.75ms +step:771/1695 train_time:74597ms step_avg:96.75ms +step:772/1695 train_time:74695ms step_avg:96.76ms +step:773/1695 train_time:74794ms step_avg:96.76ms +step:774/1695 train_time:74892ms step_avg:96.76ms +step:775/1695 train_time:74990ms step_avg:96.76ms +step:776/1695 train_time:75089ms step_avg:96.76ms +step:777/1695 train_time:75187ms step_avg:96.77ms +step:778/1695 train_time:75285ms step_avg:96.77ms +step:779/1695 train_time:75675ms step_avg:97.14ms +step:780/1695 train_time:75771ms step_avg:97.14ms +step:781/1695 train_time:75868ms step_avg:97.14ms +step:782/1695 train_time:75966ms step_avg:97.14ms +step:783/1695 train_time:76064ms step_avg:97.14ms +step:784/1695 train_time:76161ms step_avg:97.14ms +step:785/1695 train_time:76258ms step_avg:97.14ms +step:786/1695 train_time:76355ms step_avg:97.14ms +step:787/1695 train_time:76451ms step_avg:97.14ms +step:788/1695 train_time:76548ms step_avg:97.14ms +step:789/1695 train_time:76951ms step_avg:97.53ms +step:790/1695 train_time:77001ms step_avg:97.47ms +step:791/1695 train_time:77124ms step_avg:97.50ms +step:792/1695 train_time:77221ms step_avg:97.50ms +step:793/1695 train_time:77318ms step_avg:97.50ms +step:794/1695 train_time:77415ms step_avg:97.50ms +step:795/1695 train_time:77512ms step_avg:97.50ms +step:796/1695 train_time:77609ms step_avg:97.50ms +step:797/1695 train_time:77707ms step_avg:97.50ms +step:798/1695 train_time:77804ms step_avg:97.50ms +step:799/1695 train_time:77902ms step_avg:97.50ms +step:800/1695 train_time:78002ms step_avg:97.50ms +step:801/1695 train_time:78102ms step_avg:97.51ms +step:802/1695 train_time:78201ms step_avg:97.51ms +step:803/1695 train_time:78299ms step_avg:97.51ms +step:804/1695 train_time:78397ms step_avg:97.51ms +step:805/1695 train_time:78495ms step_avg:97.51ms +step:806/1695 train_time:78592ms step_avg:97.51ms +step:807/1695 train_time:78690ms step_avg:97.51ms +step:808/1695 train_time:78788ms step_avg:97.51ms +step:809/1695 train_time:78887ms step_avg:97.51ms +step:810/1695 train_time:78986ms step_avg:97.51ms +step:811/1695 train_time:79086ms step_avg:97.52ms +step:812/1695 train_time:79185ms step_avg:97.52ms +step:813/1695 train_time:79285ms step_avg:97.52ms +step:814/1695 train_time:79383ms step_avg:97.52ms +step:815/1695 train_time:79483ms step_avg:97.52ms +step:816/1695 train_time:79581ms step_avg:97.53ms +step:817/1695 train_time:79680ms step_avg:97.53ms +step:818/1695 train_time:79778ms step_avg:97.53ms +step:819/1695 train_time:79876ms step_avg:97.53ms +step:820/1695 train_time:79974ms step_avg:97.53ms +step:821/1695 train_time:80071ms step_avg:97.53ms +step:822/1695 train_time:80169ms step_avg:97.53ms +step:823/1695 train_time:80267ms step_avg:97.53ms +step:824/1695 train_time:80365ms step_avg:97.53ms +step:825/1695 train_time:80463ms step_avg:97.53ms +step:826/1695 train_time:80561ms step_avg:97.53ms +step:827/1695 train_time:80659ms step_avg:97.53ms +step:828/1695 train_time:80757ms step_avg:97.53ms +step:829/1695 train_time:80855ms step_avg:97.53ms +step:830/1695 train_time:80953ms step_avg:97.53ms +step:831/1695 train_time:81051ms step_avg:97.53ms +step:832/1695 train_time:81150ms step_avg:97.54ms +step:833/1695 train_time:81248ms step_avg:97.54ms +step:834/1695 train_time:81346ms step_avg:97.54ms +step:835/1695 train_time:81445ms step_avg:97.54ms +step:836/1695 train_time:81543ms step_avg:97.54ms +step:837/1695 train_time:81642ms step_avg:97.54ms +step:838/1695 train_time:81741ms step_avg:97.54ms +step:839/1695 train_time:81839ms 
step_avg:97.54ms +step:840/1695 train_time:81937ms step_avg:97.54ms +step:841/1695 train_time:82036ms step_avg:97.55ms +step:842/1695 train_time:82134ms step_avg:97.55ms +step:843/1695 train_time:82233ms step_avg:97.55ms +step:844/1695 train_time:82330ms step_avg:97.55ms +step:845/1695 train_time:82428ms step_avg:97.55ms +step:846/1695 train_time:82526ms step_avg:97.55ms +step:847/1695 train_time:82624ms step_avg:97.55ms +step:848/1695 train_time:82723ms step_avg:97.55ms +step:849/1695 train_time:82821ms step_avg:97.55ms +step:850/1695 train_time:82920ms step_avg:97.55ms +step:851/1695 train_time:83018ms step_avg:97.55ms +step:852/1695 train_time:83117ms step_avg:97.55ms +step:853/1695 train_time:83216ms step_avg:97.56ms +step:854/1695 train_time:83315ms step_avg:97.56ms +step:855/1695 train_time:83412ms step_avg:97.56ms +step:856/1695 train_time:83509ms step_avg:97.56ms +step:857/1695 train_time:83607ms step_avg:97.56ms +step:858/1695 train_time:83705ms step_avg:97.56ms +step:859/1695 train_time:83803ms step_avg:97.56ms +step:860/1695 train_time:83901ms step_avg:97.56ms +step:861/1695 train_time:84000ms step_avg:97.56ms +step:862/1695 train_time:84099ms step_avg:97.56ms +step:863/1695 train_time:84197ms step_avg:97.56ms +step:864/1695 train_time:84571ms step_avg:97.88ms +step:865/1695 train_time:84667ms step_avg:97.88ms +step:866/1695 train_time:84764ms step_avg:97.88ms +step:867/1695 train_time:84862ms step_avg:97.88ms +step:868/1695 train_time:84959ms step_avg:97.88ms +step:869/1695 train_time:85056ms step_avg:97.88ms +step:870/1695 train_time:85152ms step_avg:97.88ms +step:871/1695 train_time:85250ms step_avg:97.88ms +step:872/1695 train_time:85346ms step_avg:97.87ms +step:873/1695 train_time:85448ms step_avg:97.88ms +step:874/1695 train_time:85549ms step_avg:97.88ms +step:875/1695 train_time:85648ms step_avg:97.88ms +step:875/1695 val_loss:3.5344 train_time:85745ms step_avg:97.99ms +step:876/1695 train_time:85771ms step_avg:97.91ms +step:877/1695 train_time:85852ms step_avg:97.89ms +step:878/1695 train_time:85957ms step_avg:97.90ms +step:879/1695 train_time:86055ms step_avg:97.90ms +step:880/1695 train_time:86153ms step_avg:97.90ms +step:881/1695 train_time:86252ms step_avg:97.90ms +step:882/1695 train_time:86352ms step_avg:97.90ms +step:883/1695 train_time:86452ms step_avg:97.91ms +step:884/1695 train_time:86551ms step_avg:97.91ms +step:885/1695 train_time:86650ms step_avg:97.91ms +step:886/1695 train_time:86750ms step_avg:97.91ms +step:887/1695 train_time:86851ms step_avg:97.92ms +step:888/1695 train_time:86954ms step_avg:97.92ms +step:889/1695 train_time:87054ms step_avg:97.92ms +step:890/1695 train_time:87154ms step_avg:97.93ms +step:891/1695 train_time:87253ms step_avg:97.93ms +step:892/1695 train_time:87351ms step_avg:97.93ms +step:893/1695 train_time:87451ms step_avg:97.93ms +step:894/1695 train_time:87550ms step_avg:97.93ms +step:895/1695 train_time:87649ms step_avg:97.93ms +step:896/1695 train_time:87748ms step_avg:97.93ms +step:897/1695 train_time:87849ms step_avg:97.94ms +step:898/1695 train_time:87949ms step_avg:97.94ms +step:899/1695 train_time:88050ms step_avg:97.94ms +step:900/1695 train_time:88150ms step_avg:97.94ms +step:901/1695 train_time:88250ms step_avg:97.95ms +step:902/1695 train_time:88349ms step_avg:97.95ms +step:903/1695 train_time:88449ms step_avg:97.95ms +step:904/1695 train_time:88548ms step_avg:97.95ms +step:905/1695 train_time:88647ms step_avg:97.95ms +step:906/1695 train_time:88746ms step_avg:97.95ms +step:907/1695 train_time:88845ms step_avg:97.96ms 
+step:908/1695 train_time:88945ms step_avg:97.96ms +step:909/1695 train_time:89045ms step_avg:97.96ms +step:910/1695 train_time:89146ms step_avg:97.96ms +step:911/1695 train_time:89245ms step_avg:97.96ms +step:912/1695 train_time:89345ms step_avg:97.97ms +step:913/1695 train_time:89444ms step_avg:97.97ms +step:914/1695 train_time:89543ms step_avg:97.97ms +step:915/1695 train_time:89642ms step_avg:97.97ms +step:916/1695 train_time:89740ms step_avg:97.97ms +step:917/1695 train_time:89839ms step_avg:97.97ms +step:918/1695 train_time:89937ms step_avg:97.97ms +step:919/1695 train_time:90037ms step_avg:97.97ms +step:920/1695 train_time:90136ms step_avg:97.97ms +step:921/1695 train_time:90235ms step_avg:97.98ms +step:922/1695 train_time:90335ms step_avg:97.98ms +step:923/1695 train_time:90434ms step_avg:97.98ms +step:924/1695 train_time:90533ms step_avg:97.98ms +step:925/1695 train_time:90633ms step_avg:97.98ms +step:926/1695 train_time:90733ms step_avg:97.98ms +step:927/1695 train_time:90833ms step_avg:97.99ms +step:928/1695 train_time:90933ms step_avg:97.99ms +step:929/1695 train_time:91033ms step_avg:97.99ms +step:930/1695 train_time:91133ms step_avg:97.99ms +step:931/1695 train_time:91233ms step_avg:97.99ms +step:932/1695 train_time:91332ms step_avg:98.00ms +step:933/1695 train_time:91432ms step_avg:98.00ms +step:934/1695 train_time:91532ms step_avg:98.00ms +step:935/1695 train_time:91631ms step_avg:98.00ms +step:936/1695 train_time:91731ms step_avg:98.00ms +step:937/1695 train_time:91831ms step_avg:98.01ms +step:938/1695 train_time:91932ms step_avg:98.01ms +step:939/1695 train_time:92032ms step_avg:98.01ms +step:940/1695 train_time:92131ms step_avg:98.01ms +step:941/1695 train_time:92231ms step_avg:98.01ms +step:942/1695 train_time:92331ms step_avg:98.02ms +step:943/1695 train_time:92431ms step_avg:98.02ms +step:944/1695 train_time:92531ms step_avg:98.02ms +step:945/1695 train_time:92632ms step_avg:98.02ms +step:946/1695 train_time:92731ms step_avg:98.02ms +step:947/1695 train_time:92830ms step_avg:98.03ms +step:948/1695 train_time:92929ms step_avg:98.03ms +step:949/1695 train_time:93029ms step_avg:98.03ms +step:950/1695 train_time:93129ms step_avg:98.03ms +step:951/1695 train_time:93230ms step_avg:98.03ms +step:952/1695 train_time:93329ms step_avg:98.03ms +step:953/1695 train_time:93429ms step_avg:98.04ms +step:954/1695 train_time:93529ms step_avg:98.04ms +step:955/1695 train_time:93628ms step_avg:98.04ms +step:956/1695 train_time:93728ms step_avg:98.04ms +step:957/1695 train_time:93827ms step_avg:98.04ms +step:958/1695 train_time:93928ms step_avg:98.05ms +step:959/1695 train_time:94028ms step_avg:98.05ms +step:960/1695 train_time:94128ms step_avg:98.05ms +step:961/1695 train_time:94227ms step_avg:98.05ms +step:962/1695 train_time:94326ms step_avg:98.05ms +step:963/1695 train_time:94426ms step_avg:98.05ms +step:964/1695 train_time:94526ms step_avg:98.06ms +step:965/1695 train_time:94626ms step_avg:98.06ms +step:966/1695 train_time:94726ms step_avg:98.06ms +step:967/1695 train_time:94825ms step_avg:98.06ms +step:968/1695 train_time:94925ms step_avg:98.06ms +step:969/1695 train_time:95026ms step_avg:98.07ms +step:970/1695 train_time:95126ms step_avg:98.07ms +step:971/1695 train_time:95225ms step_avg:98.07ms +step:972/1695 train_time:95324ms step_avg:98.07ms +step:973/1695 train_time:95425ms step_avg:98.07ms +step:974/1695 train_time:95524ms step_avg:98.07ms +step:975/1695 train_time:95623ms step_avg:98.07ms +step:976/1695 train_time:95722ms step_avg:98.08ms +step:977/1695 train_time:95820ms 
step_avg:98.08ms +step:978/1695 train_time:95919ms step_avg:98.08ms +step:979/1695 train_time:96018ms step_avg:98.08ms +step:980/1695 train_time:96117ms step_avg:98.08ms +step:981/1695 train_time:96217ms step_avg:98.08ms +step:982/1695 train_time:96316ms step_avg:98.08ms +step:983/1695 train_time:96416ms step_avg:98.08ms +step:984/1695 train_time:96516ms step_avg:98.09ms +step:985/1695 train_time:96616ms step_avg:98.09ms +step:986/1695 train_time:96716ms step_avg:98.09ms +step:987/1695 train_time:96815ms step_avg:98.09ms +step:988/1695 train_time:96914ms step_avg:98.09ms +step:989/1695 train_time:97014ms step_avg:98.09ms +step:990/1695 train_time:97114ms step_avg:98.09ms +step:991/1695 train_time:97213ms step_avg:98.10ms +step:992/1695 train_time:97313ms step_avg:98.10ms +step:993/1695 train_time:97413ms step_avg:98.10ms +step:994/1695 train_time:97512ms step_avg:98.10ms +step:995/1695 train_time:97612ms step_avg:98.10ms +step:996/1695 train_time:97712ms step_avg:98.10ms +step:997/1695 train_time:97812ms step_avg:98.11ms +step:998/1695 train_time:97912ms step_avg:98.11ms +step:999/1695 train_time:98012ms step_avg:98.11ms +step:1000/1695 train_time:98111ms step_avg:98.11ms +step:1000/1695 val_loss:3.4900 train_time:98209ms step_avg:98.21ms +step:1001/1695 train_time:98235ms step_avg:98.14ms +step:1002/1695 train_time:98321ms step_avg:98.12ms +step:1003/1695 train_time:98422ms step_avg:98.13ms +step:1004/1695 train_time:98523ms step_avg:98.13ms +step:1005/1695 train_time:98622ms step_avg:98.13ms +step:1006/1695 train_time:98721ms step_avg:98.13ms +step:1007/1695 train_time:98820ms step_avg:98.13ms +step:1008/1695 train_time:98919ms step_avg:98.13ms +step:1009/1695 train_time:99018ms step_avg:98.13ms +step:1010/1695 train_time:99116ms step_avg:98.13ms +step:1011/1695 train_time:99219ms step_avg:98.14ms +step:1012/1695 train_time:99323ms step_avg:98.15ms +step:1013/1695 train_time:99424ms step_avg:98.15ms +step:1014/1695 train_time:99524ms step_avg:98.15ms +step:1015/1695 train_time:99624ms step_avg:98.15ms +step:1016/1695 train_time:99722ms step_avg:98.15ms +step:1017/1695 train_time:99822ms step_avg:98.15ms +step:1018/1695 train_time:99920ms step_avg:98.15ms +step:1019/1695 train_time:100020ms step_avg:98.15ms +step:1020/1695 train_time:100119ms step_avg:98.16ms +step:1021/1695 train_time:100220ms step_avg:98.16ms +step:1022/1695 train_time:100321ms step_avg:98.16ms +step:1023/1695 train_time:100423ms step_avg:98.16ms +step:1024/1695 train_time:100524ms step_avg:98.17ms +step:1025/1695 train_time:100624ms step_avg:98.17ms +step:1026/1695 train_time:100724ms step_avg:98.17ms +step:1027/1695 train_time:100823ms step_avg:98.17ms +step:1028/1695 train_time:100922ms step_avg:98.17ms +step:1029/1695 train_time:101023ms step_avg:98.18ms +step:1030/1695 train_time:101122ms step_avg:98.18ms +step:1031/1695 train_time:101223ms step_avg:98.18ms +step:1032/1695 train_time:101322ms step_avg:98.18ms +step:1033/1695 train_time:101423ms step_avg:98.18ms +step:1034/1695 train_time:101522ms step_avg:98.18ms +step:1035/1695 train_time:101622ms step_avg:98.19ms +step:1036/1695 train_time:101722ms step_avg:98.19ms +step:1037/1695 train_time:101823ms step_avg:98.19ms +step:1038/1695 train_time:101922ms step_avg:98.19ms +step:1039/1695 train_time:102021ms step_avg:98.19ms +step:1040/1695 train_time:102121ms step_avg:98.19ms +step:1041/1695 train_time:102221ms step_avg:98.19ms +step:1042/1695 train_time:102321ms step_avg:98.20ms +step:1043/1695 train_time:102421ms step_avg:98.20ms +step:1044/1695 
train_time:102521ms step_avg:98.20ms +step:1045/1695 train_time:102621ms step_avg:98.20ms +step:1046/1695 train_time:102722ms step_avg:98.20ms +step:1047/1695 train_time:102821ms step_avg:98.21ms +step:1048/1695 train_time:102920ms step_avg:98.21ms +step:1049/1695 train_time:103019ms step_avg:98.21ms +step:1050/1695 train_time:103119ms step_avg:98.21ms +step:1051/1695 train_time:103219ms step_avg:98.21ms +step:1052/1695 train_time:103320ms step_avg:98.21ms +step:1053/1695 train_time:103420ms step_avg:98.22ms +step:1054/1695 train_time:103521ms step_avg:98.22ms +step:1055/1695 train_time:103621ms step_avg:98.22ms +step:1056/1695 train_time:103721ms step_avg:98.22ms +step:1057/1695 train_time:103821ms step_avg:98.22ms +step:1058/1695 train_time:103921ms step_avg:98.22ms +step:1059/1695 train_time:104021ms step_avg:98.23ms +step:1060/1695 train_time:104120ms step_avg:98.23ms +step:1061/1695 train_time:104220ms step_avg:98.23ms +step:1062/1695 train_time:104320ms step_avg:98.23ms +step:1063/1695 train_time:104420ms step_avg:98.23ms +step:1064/1695 train_time:104521ms step_avg:98.23ms +step:1065/1695 train_time:104622ms step_avg:98.24ms +step:1066/1695 train_time:104722ms step_avg:98.24ms +step:1067/1695 train_time:104822ms step_avg:98.24ms +step:1068/1695 train_time:104921ms step_avg:98.24ms +step:1069/1695 train_time:105021ms step_avg:98.24ms +step:1070/1695 train_time:105121ms step_avg:98.24ms +step:1071/1695 train_time:105221ms step_avg:98.25ms +step:1072/1695 train_time:105322ms step_avg:98.25ms +step:1073/1695 train_time:105421ms step_avg:98.25ms +step:1074/1695 train_time:105521ms step_avg:98.25ms +step:1075/1695 train_time:105621ms step_avg:98.25ms +step:1076/1695 train_time:105721ms step_avg:98.25ms +step:1077/1695 train_time:105822ms step_avg:98.26ms +step:1078/1695 train_time:105922ms step_avg:98.26ms +step:1079/1695 train_time:106021ms step_avg:98.26ms +step:1080/1695 train_time:106121ms step_avg:98.26ms +step:1081/1695 train_time:106221ms step_avg:98.26ms +step:1082/1695 train_time:106322ms step_avg:98.26ms +step:1083/1695 train_time:106422ms step_avg:98.27ms +step:1084/1695 train_time:106522ms step_avg:98.27ms +step:1085/1695 train_time:106621ms step_avg:98.27ms +step:1086/1695 train_time:106722ms step_avg:98.27ms +step:1087/1695 train_time:106821ms step_avg:98.27ms +step:1088/1695 train_time:106921ms step_avg:98.27ms +step:1089/1695 train_time:107021ms step_avg:98.27ms +step:1090/1695 train_time:107121ms step_avg:98.28ms +step:1091/1695 train_time:107221ms step_avg:98.28ms +step:1092/1695 train_time:107321ms step_avg:98.28ms +step:1093/1695 train_time:107421ms step_avg:98.28ms +step:1094/1695 train_time:107522ms step_avg:98.28ms +step:1095/1695 train_time:107621ms step_avg:98.28ms +step:1096/1695 train_time:107722ms step_avg:98.29ms +step:1097/1695 train_time:107822ms step_avg:98.29ms +step:1098/1695 train_time:107921ms step_avg:98.29ms +step:1099/1695 train_time:108021ms step_avg:98.29ms +step:1100/1695 train_time:108121ms step_avg:98.29ms +step:1101/1695 train_time:108221ms step_avg:98.29ms +step:1102/1695 train_time:108321ms step_avg:98.29ms +step:1103/1695 train_time:108420ms step_avg:98.30ms +step:1104/1695 train_time:108520ms step_avg:98.30ms +step:1105/1695 train_time:108621ms step_avg:98.30ms +step:1106/1695 train_time:108720ms step_avg:98.30ms +step:1107/1695 train_time:108821ms step_avg:98.30ms +step:1108/1695 train_time:108921ms step_avg:98.30ms +step:1109/1695 train_time:109021ms step_avg:98.31ms +step:1110/1695 train_time:109121ms step_avg:98.31ms +step:1111/1695 
train_time:109221ms step_avg:98.31ms +step:1112/1695 train_time:109321ms step_avg:98.31ms +step:1113/1695 train_time:109421ms step_avg:98.31ms +step:1114/1695 train_time:109520ms step_avg:98.31ms +step:1115/1695 train_time:109620ms step_avg:98.31ms +step:1116/1695 train_time:109721ms step_avg:98.32ms +step:1117/1695 train_time:109821ms step_avg:98.32ms +step:1118/1695 train_time:109921ms step_avg:98.32ms +step:1119/1695 train_time:110021ms step_avg:98.32ms +step:1120/1695 train_time:110121ms step_avg:98.32ms +step:1121/1695 train_time:110222ms step_avg:98.32ms +step:1122/1695 train_time:110323ms step_avg:98.33ms +step:1123/1695 train_time:110423ms step_avg:98.33ms +step:1124/1695 train_time:110522ms step_avg:98.33ms +step:1125/1695 train_time:110622ms step_avg:98.33ms +step:1125/1695 val_loss:3.4397 train_time:110720ms step_avg:98.42ms +step:1126/1695 train_time:110745ms step_avg:98.35ms +step:1127/1695 train_time:110830ms step_avg:98.34ms +step:1128/1695 train_time:110933ms step_avg:98.34ms +step:1129/1695 train_time:111032ms step_avg:98.35ms +step:1130/1695 train_time:111131ms step_avg:98.35ms +step:1131/1695 train_time:111229ms step_avg:98.35ms +step:1132/1695 train_time:111328ms step_avg:98.35ms +step:1133/1695 train_time:111428ms step_avg:98.35ms +step:1134/1695 train_time:111527ms step_avg:98.35ms +step:1135/1695 train_time:111626ms step_avg:98.35ms +step:1136/1695 train_time:111727ms step_avg:98.35ms +step:1137/1695 train_time:111830ms step_avg:98.36ms +step:1138/1695 train_time:111930ms step_avg:98.36ms +step:1139/1695 train_time:112030ms step_avg:98.36ms +step:1140/1695 train_time:112130ms step_avg:98.36ms +step:1141/1695 train_time:112230ms step_avg:98.36ms +step:1142/1695 train_time:112329ms step_avg:98.36ms +step:1143/1695 train_time:112429ms step_avg:98.36ms +step:1144/1695 train_time:112529ms step_avg:98.36ms +step:1145/1695 train_time:112629ms step_avg:98.37ms +step:1146/1695 train_time:112730ms step_avg:98.37ms +step:1147/1695 train_time:112831ms step_avg:98.37ms +step:1148/1695 train_time:112931ms step_avg:98.37ms +step:1149/1695 train_time:113032ms step_avg:98.37ms +step:1150/1695 train_time:113132ms step_avg:98.38ms +step:1151/1695 train_time:113231ms step_avg:98.38ms +step:1152/1695 train_time:113332ms step_avg:98.38ms +step:1153/1695 train_time:113433ms step_avg:98.38ms +step:1154/1695 train_time:113533ms step_avg:98.38ms +step:1155/1695 train_time:113634ms step_avg:98.38ms +step:1156/1695 train_time:113736ms step_avg:98.39ms +step:1157/1695 train_time:113838ms step_avg:98.39ms +step:1158/1695 train_time:113938ms step_avg:98.39ms +step:1159/1695 train_time:114039ms step_avg:98.39ms +step:1160/1695 train_time:114141ms step_avg:98.40ms +step:1161/1695 train_time:114242ms step_avg:98.40ms +step:1162/1695 train_time:114343ms step_avg:98.40ms +step:1163/1695 train_time:114446ms step_avg:98.41ms +step:1164/1695 train_time:114547ms step_avg:98.41ms +step:1165/1695 train_time:114647ms step_avg:98.41ms +step:1166/1695 train_time:114747ms step_avg:98.41ms +step:1167/1695 train_time:114848ms step_avg:98.41ms +step:1168/1695 train_time:114948ms step_avg:98.41ms +step:1169/1695 train_time:115048ms step_avg:98.42ms +step:1170/1695 train_time:115149ms step_avg:98.42ms +step:1171/1695 train_time:115248ms step_avg:98.42ms +step:1172/1695 train_time:115350ms step_avg:98.42ms +step:1173/1695 train_time:115451ms step_avg:98.42ms +step:1174/1695 train_time:115552ms step_avg:98.43ms +step:1175/1695 train_time:115653ms step_avg:98.43ms +step:1176/1695 train_time:115754ms step_avg:98.43ms 
+step:1177/1695 train_time:115855ms step_avg:98.43ms +step:1178/1695 train_time:115957ms step_avg:98.44ms +step:1179/1695 train_time:116059ms step_avg:98.44ms +step:1180/1695 train_time:116160ms step_avg:98.44ms +step:1181/1695 train_time:116261ms step_avg:98.44ms +step:1182/1695 train_time:116364ms step_avg:98.45ms +step:1183/1695 train_time:116465ms step_avg:98.45ms +step:1184/1695 train_time:116566ms step_avg:98.45ms +step:1185/1695 train_time:116667ms step_avg:98.45ms +step:1186/1695 train_time:116768ms step_avg:98.46ms +step:1187/1695 train_time:116869ms step_avg:98.46ms +step:1188/1695 train_time:116969ms step_avg:98.46ms +step:1189/1695 train_time:117068ms step_avg:98.46ms +step:1190/1695 train_time:117168ms step_avg:98.46ms +step:1191/1695 train_time:117268ms step_avg:98.46ms +step:1192/1695 train_time:117368ms step_avg:98.46ms +step:1193/1695 train_time:117468ms step_avg:98.46ms +step:1194/1695 train_time:117569ms step_avg:98.47ms +step:1195/1695 train_time:117668ms step_avg:98.47ms +step:1196/1695 train_time:117768ms step_avg:98.47ms +step:1197/1695 train_time:117869ms step_avg:98.47ms +step:1198/1695 train_time:117968ms step_avg:98.47ms +step:1199/1695 train_time:118069ms step_avg:98.47ms +step:1200/1695 train_time:118168ms step_avg:98.47ms +step:1201/1695 train_time:118268ms step_avg:98.47ms +step:1202/1695 train_time:118370ms step_avg:98.48ms +step:1203/1695 train_time:118470ms step_avg:98.48ms +step:1204/1695 train_time:118571ms step_avg:98.48ms +step:1205/1695 train_time:118671ms step_avg:98.48ms +step:1206/1695 train_time:118771ms step_avg:98.48ms +step:1207/1695 train_time:118872ms step_avg:98.49ms +step:1208/1695 train_time:118972ms step_avg:98.49ms +step:1209/1695 train_time:119071ms step_avg:98.49ms +step:1210/1695 train_time:119171ms step_avg:98.49ms +step:1211/1695 train_time:119273ms step_avg:98.49ms +step:1212/1695 train_time:119373ms step_avg:98.49ms +step:1213/1695 train_time:119476ms step_avg:98.50ms +step:1214/1695 train_time:119576ms step_avg:98.50ms +step:1215/1695 train_time:119676ms step_avg:98.50ms +step:1216/1695 train_time:119778ms step_avg:98.50ms +step:1217/1695 train_time:119880ms step_avg:98.50ms +step:1218/1695 train_time:119982ms step_avg:98.51ms +step:1219/1695 train_time:120084ms step_avg:98.51ms +step:1220/1695 train_time:120185ms step_avg:98.51ms +step:1221/1695 train_time:120285ms step_avg:98.51ms +step:1222/1695 train_time:120386ms step_avg:98.52ms +step:1223/1695 train_time:120487ms step_avg:98.52ms +step:1224/1695 train_time:120587ms step_avg:98.52ms +step:1225/1695 train_time:120688ms step_avg:98.52ms +step:1226/1695 train_time:120787ms step_avg:98.52ms +step:1227/1695 train_time:120888ms step_avg:98.52ms +step:1228/1695 train_time:120989ms step_avg:98.52ms +step:1229/1695 train_time:121088ms step_avg:98.53ms +step:1230/1695 train_time:121187ms step_avg:98.53ms +step:1231/1695 train_time:121287ms step_avg:98.53ms +step:1232/1695 train_time:121388ms step_avg:98.53ms +step:1233/1695 train_time:121488ms step_avg:98.53ms +step:1234/1695 train_time:121589ms step_avg:98.53ms +step:1235/1695 train_time:121689ms step_avg:98.53ms +step:1236/1695 train_time:121789ms step_avg:98.53ms +step:1237/1695 train_time:121889ms step_avg:98.54ms +step:1238/1695 train_time:121989ms step_avg:98.54ms +step:1239/1695 train_time:122089ms step_avg:98.54ms +step:1240/1695 train_time:122189ms step_avg:98.54ms +step:1241/1695 train_time:122290ms step_avg:98.54ms +step:1242/1695 train_time:122390ms step_avg:98.54ms +step:1243/1695 train_time:122490ms step_avg:98.54ms 
+step:1244/1695 train_time:122589ms step_avg:98.54ms +step:1245/1695 train_time:122689ms step_avg:98.55ms +step:1246/1695 train_time:122790ms step_avg:98.55ms +step:1247/1695 train_time:122889ms step_avg:98.55ms +step:1248/1695 train_time:122990ms step_avg:98.55ms +step:1249/1695 train_time:123090ms step_avg:98.55ms +step:1250/1695 train_time:123190ms step_avg:98.55ms +step:1250/1695 val_loss:3.3940 train_time:123288ms step_avg:98.63ms +step:1251/1695 train_time:123313ms step_avg:98.57ms +step:1252/1695 train_time:123398ms step_avg:98.56ms +step:1253/1695 train_time:123502ms step_avg:98.56ms +step:1254/1695 train_time:123603ms step_avg:98.57ms +step:1255/1695 train_time:123704ms step_avg:98.57ms +step:1256/1695 train_time:123804ms step_avg:98.57ms +step:1257/1695 train_time:123903ms step_avg:98.57ms +step:1258/1695 train_time:124004ms step_avg:98.57ms +step:1259/1695 train_time:124103ms step_avg:98.57ms +step:1260/1695 train_time:124204ms step_avg:98.57ms +step:1261/1695 train_time:124307ms step_avg:98.58ms +step:1262/1695 train_time:124411ms step_avg:98.58ms +step:1263/1695 train_time:124513ms step_avg:98.59ms +step:1264/1695 train_time:124613ms step_avg:98.59ms +step:1265/1695 train_time:124713ms step_avg:98.59ms +step:1266/1695 train_time:124812ms step_avg:98.59ms +step:1267/1695 train_time:124912ms step_avg:98.59ms +step:1268/1695 train_time:125012ms step_avg:98.59ms +step:1269/1695 train_time:125113ms step_avg:98.59ms +step:1270/1695 train_time:125213ms step_avg:98.59ms +step:1271/1695 train_time:125315ms step_avg:98.60ms +step:1272/1695 train_time:125416ms step_avg:98.60ms +step:1273/1695 train_time:125517ms step_avg:98.60ms +step:1274/1695 train_time:125617ms step_avg:98.60ms +step:1275/1695 train_time:125719ms step_avg:98.60ms +step:1276/1695 train_time:125821ms step_avg:98.61ms +step:1277/1695 train_time:125922ms step_avg:98.61ms +step:1278/1695 train_time:126024ms step_avg:98.61ms +step:1279/1695 train_time:126126ms step_avg:98.61ms +step:1280/1695 train_time:126227ms step_avg:98.61ms +step:1281/1695 train_time:126328ms step_avg:98.62ms +step:1282/1695 train_time:126429ms step_avg:98.62ms +step:1283/1695 train_time:126529ms step_avg:98.62ms +step:1284/1695 train_time:126630ms step_avg:98.62ms +step:1285/1695 train_time:126731ms step_avg:98.62ms +step:1286/1695 train_time:126832ms step_avg:98.62ms +step:1287/1695 train_time:126932ms step_avg:98.63ms +step:1288/1695 train_time:127031ms step_avg:98.63ms +step:1289/1695 train_time:127132ms step_avg:98.63ms +step:1290/1695 train_time:127232ms step_avg:98.63ms +step:1291/1695 train_time:127333ms step_avg:98.63ms +step:1292/1695 train_time:127433ms step_avg:98.63ms +step:1293/1695 train_time:127534ms step_avg:98.63ms +step:1294/1695 train_time:127635ms step_avg:98.64ms +step:1295/1695 train_time:127736ms step_avg:98.64ms +step:1296/1695 train_time:127837ms step_avg:98.64ms +step:1297/1695 train_time:127938ms step_avg:98.64ms +step:1298/1695 train_time:128039ms step_avg:98.64ms +step:1299/1695 train_time:128141ms step_avg:98.65ms +step:1300/1695 train_time:128242ms step_avg:98.65ms +step:1301/1695 train_time:128344ms step_avg:98.65ms +step:1302/1695 train_time:128446ms step_avg:98.65ms +step:1303/1695 train_time:128548ms step_avg:98.66ms +step:1304/1695 train_time:128649ms step_avg:98.66ms +step:1305/1695 train_time:128750ms step_avg:98.66ms +step:1306/1695 train_time:128849ms step_avg:98.66ms +step:1307/1695 train_time:128950ms step_avg:98.66ms +step:1308/1695 train_time:129051ms step_avg:98.66ms +step:1309/1695 train_time:129152ms 
step_avg:98.66ms +step:1310/1695 train_time:129253ms step_avg:98.67ms +step:1311/1695 train_time:129354ms step_avg:98.67ms +step:1312/1695 train_time:129455ms step_avg:98.67ms +step:1313/1695 train_time:129556ms step_avg:98.67ms +step:1314/1695 train_time:129657ms step_avg:98.67ms +step:1315/1695 train_time:129758ms step_avg:98.68ms +step:1316/1695 train_time:129859ms step_avg:98.68ms +step:1317/1695 train_time:129959ms step_avg:98.68ms +step:1318/1695 train_time:130060ms step_avg:98.68ms +step:1319/1695 train_time:130162ms step_avg:98.68ms +step:1320/1695 train_time:130265ms step_avg:98.69ms +step:1321/1695 train_time:130366ms step_avg:98.69ms +step:1322/1695 train_time:130468ms step_avg:98.69ms +step:1323/1695 train_time:130568ms step_avg:98.69ms +step:1324/1695 train_time:130669ms step_avg:98.69ms +step:1325/1695 train_time:130770ms step_avg:98.69ms +step:1326/1695 train_time:130870ms step_avg:98.70ms +step:1327/1695 train_time:130971ms step_avg:98.70ms +step:1328/1695 train_time:131071ms step_avg:98.70ms +step:1329/1695 train_time:131172ms step_avg:98.70ms +step:1330/1695 train_time:131272ms step_avg:98.70ms +step:1331/1695 train_time:131372ms step_avg:98.70ms +step:1332/1695 train_time:131472ms step_avg:98.70ms +step:1333/1695 train_time:131572ms step_avg:98.70ms +step:1334/1695 train_time:131673ms step_avg:98.71ms +step:1335/1695 train_time:131773ms step_avg:98.71ms +step:1336/1695 train_time:131874ms step_avg:98.71ms +step:1337/1695 train_time:131975ms step_avg:98.71ms +step:1338/1695 train_time:132075ms step_avg:98.71ms +step:1339/1695 train_time:132177ms step_avg:98.71ms +step:1340/1695 train_time:132277ms step_avg:98.71ms +step:1341/1695 train_time:132378ms step_avg:98.72ms +step:1342/1695 train_time:132478ms step_avg:98.72ms +step:1343/1695 train_time:132580ms step_avg:98.72ms +step:1344/1695 train_time:132680ms step_avg:98.72ms +step:1345/1695 train_time:132782ms step_avg:98.72ms +step:1346/1695 train_time:132883ms step_avg:98.72ms +step:1347/1695 train_time:132985ms step_avg:98.73ms +step:1348/1695 train_time:133087ms step_avg:98.73ms +step:1349/1695 train_time:133187ms step_avg:98.73ms +step:1350/1695 train_time:133289ms step_avg:98.73ms +step:1351/1695 train_time:133389ms step_avg:98.73ms +step:1352/1695 train_time:133489ms step_avg:98.73ms +step:1353/1695 train_time:133590ms step_avg:98.74ms +step:1354/1695 train_time:133691ms step_avg:98.74ms +step:1355/1695 train_time:133791ms step_avg:98.74ms +step:1356/1695 train_time:133891ms step_avg:98.74ms +step:1357/1695 train_time:133991ms step_avg:98.74ms +step:1358/1695 train_time:134091ms step_avg:98.74ms +step:1359/1695 train_time:134191ms step_avg:98.74ms +step:1360/1695 train_time:134292ms step_avg:98.74ms +step:1361/1695 train_time:134392ms step_avg:98.75ms +step:1362/1695 train_time:134492ms step_avg:98.75ms +step:1363/1695 train_time:134593ms step_avg:98.75ms +step:1364/1695 train_time:134693ms step_avg:98.75ms +step:1365/1695 train_time:134794ms step_avg:98.75ms +step:1366/1695 train_time:134895ms step_avg:98.75ms +step:1367/1695 train_time:134995ms step_avg:98.75ms +step:1368/1695 train_time:135096ms step_avg:98.75ms +step:1369/1695 train_time:135197ms step_avg:98.76ms +step:1370/1695 train_time:135299ms step_avg:98.76ms +step:1371/1695 train_time:135399ms step_avg:98.76ms +step:1372/1695 train_time:135499ms step_avg:98.76ms +step:1373/1695 train_time:135600ms step_avg:98.76ms +step:1374/1695 train_time:135701ms step_avg:98.76ms +step:1375/1695 train_time:135804ms step_avg:98.77ms +step:1375/1695 val_loss:3.3538 
train_time:135904ms step_avg:98.84ms +step:1376/1695 train_time:135930ms step_avg:98.79ms +step:1377/1695 train_time:136016ms step_avg:98.78ms +step:1378/1695 train_time:136117ms step_avg:98.78ms +step:1379/1695 train_time:136217ms step_avg:98.78ms +step:1380/1695 train_time:136318ms step_avg:98.78ms +step:1381/1695 train_time:136418ms step_avg:98.78ms +step:1382/1695 train_time:136517ms step_avg:98.78ms +step:1383/1695 train_time:136617ms step_avg:98.78ms +step:1384/1695 train_time:136718ms step_avg:98.78ms +step:1385/1695 train_time:136821ms step_avg:98.79ms +step:1386/1695 train_time:136926ms step_avg:98.79ms +step:1387/1695 train_time:137029ms step_avg:98.79ms +step:1388/1695 train_time:137131ms step_avg:98.80ms +step:1389/1695 train_time:137233ms step_avg:98.80ms +step:1390/1695 train_time:137334ms step_avg:98.80ms +step:1391/1695 train_time:137435ms step_avg:98.80ms +step:1392/1695 train_time:137536ms step_avg:98.80ms +step:1393/1695 train_time:137638ms step_avg:98.81ms +step:1394/1695 train_time:137739ms step_avg:98.81ms +step:1395/1695 train_time:137840ms step_avg:98.81ms +step:1396/1695 train_time:137944ms step_avg:98.81ms +step:1397/1695 train_time:138048ms step_avg:98.82ms +step:1398/1695 train_time:138150ms step_avg:98.82ms +step:1399/1695 train_time:138252ms step_avg:98.82ms +step:1400/1695 train_time:138354ms step_avg:98.82ms +step:1401/1695 train_time:138454ms step_avg:98.83ms +step:1402/1695 train_time:138555ms step_avg:98.83ms +step:1403/1695 train_time:138656ms step_avg:98.83ms +step:1404/1695 train_time:138758ms step_avg:98.83ms +step:1405/1695 train_time:138858ms step_avg:98.83ms +step:1406/1695 train_time:138962ms step_avg:98.84ms +step:1407/1695 train_time:139064ms step_avg:98.84ms +step:1408/1695 train_time:139166ms step_avg:98.84ms +step:1409/1695 train_time:139272ms step_avg:98.84ms +step:1410/1695 train_time:139374ms step_avg:98.85ms +step:1411/1695 train_time:139474ms step_avg:98.85ms +step:1412/1695 train_time:139577ms step_avg:98.85ms +step:1413/1695 train_time:139678ms step_avg:98.85ms +step:1414/1695 train_time:139779ms step_avg:98.85ms +step:1415/1695 train_time:139881ms step_avg:98.86ms +step:1416/1695 train_time:139982ms step_avg:98.86ms +step:1417/1695 train_time:140083ms step_avg:98.86ms +step:1418/1695 train_time:140185ms step_avg:98.86ms +step:1419/1695 train_time:140290ms step_avg:98.87ms +step:1420/1695 train_time:140392ms step_avg:98.87ms +step:1421/1695 train_time:140494ms step_avg:98.87ms +step:1422/1695 train_time:140595ms step_avg:98.87ms +step:1423/1695 train_time:140696ms step_avg:98.87ms +step:1424/1695 train_time:140797ms step_avg:98.87ms +step:1425/1695 train_time:140901ms step_avg:98.88ms +step:1426/1695 train_time:141002ms step_avg:98.88ms +step:1427/1695 train_time:141103ms step_avg:98.88ms +step:1428/1695 train_time:141205ms step_avg:98.88ms +step:1429/1695 train_time:141309ms step_avg:98.89ms +step:1430/1695 train_time:141411ms step_avg:98.89ms +step:1431/1695 train_time:141512ms step_avg:98.89ms +step:1432/1695 train_time:141614ms step_avg:98.89ms +step:1433/1695 train_time:141716ms step_avg:98.89ms +step:1434/1695 train_time:141817ms step_avg:98.90ms +step:1435/1695 train_time:141919ms step_avg:98.90ms +step:1436/1695 train_time:142021ms step_avg:98.90ms +step:1437/1695 train_time:142123ms step_avg:98.90ms +step:1438/1695 train_time:142225ms step_avg:98.90ms +step:1439/1695 train_time:142329ms step_avg:98.91ms +step:1440/1695 train_time:142432ms step_avg:98.91ms +step:1441/1695 train_time:142535ms step_avg:98.91ms +step:1442/1695 
train_time:142636ms step_avg:98.92ms +step:1443/1695 train_time:142737ms step_avg:98.92ms +step:1444/1695 train_time:142838ms step_avg:98.92ms +step:1445/1695 train_time:142938ms step_avg:98.92ms +step:1446/1695 train_time:143040ms step_avg:98.92ms +step:1447/1695 train_time:143142ms step_avg:98.92ms +step:1448/1695 train_time:143245ms step_avg:98.93ms +step:1449/1695 train_time:143346ms step_avg:98.93ms +step:1450/1695 train_time:143449ms step_avg:98.93ms +step:1451/1695 train_time:143552ms step_avg:98.93ms +step:1452/1695 train_time:143653ms step_avg:98.93ms +step:1453/1695 train_time:143756ms step_avg:98.94ms +step:1454/1695 train_time:143859ms step_avg:98.94ms +step:1455/1695 train_time:143960ms step_avg:98.94ms +step:1456/1695 train_time:144061ms step_avg:98.94ms +step:1457/1695 train_time:144163ms step_avg:98.95ms +step:1458/1695 train_time:144266ms step_avg:98.95ms +step:1459/1695 train_time:144368ms step_avg:98.95ms +step:1460/1695 train_time:144470ms step_avg:98.95ms +step:1461/1695 train_time:144572ms step_avg:98.95ms +step:1462/1695 train_time:144673ms step_avg:98.96ms +step:1463/1695 train_time:144774ms step_avg:98.96ms +step:1464/1695 train_time:144876ms step_avg:98.96ms +step:1465/1695 train_time:144977ms step_avg:98.96ms +step:1466/1695 train_time:145077ms step_avg:98.96ms +step:1467/1695 train_time:145179ms step_avg:98.96ms +step:1468/1695 train_time:145282ms step_avg:98.97ms +step:1469/1695 train_time:145386ms step_avg:98.97ms +step:1470/1695 train_time:145487ms step_avg:98.97ms +step:1471/1695 train_time:145590ms step_avg:98.97ms +step:1472/1695 train_time:145691ms step_avg:98.97ms +step:1473/1695 train_time:145792ms step_avg:98.98ms +step:1474/1695 train_time:145895ms step_avg:98.98ms +step:1475/1695 train_time:145995ms step_avg:98.98ms +step:1476/1695 train_time:146097ms step_avg:98.98ms +step:1477/1695 train_time:146198ms step_avg:98.98ms +step:1478/1695 train_time:146300ms step_avg:98.99ms +step:1479/1695 train_time:146402ms step_avg:98.99ms +step:1480/1695 train_time:146504ms step_avg:98.99ms +step:1481/1695 train_time:146606ms step_avg:98.99ms +step:1482/1695 train_time:146710ms step_avg:98.99ms +step:1483/1695 train_time:146812ms step_avg:99.00ms +step:1484/1695 train_time:146916ms step_avg:99.00ms +step:1485/1695 train_time:147017ms step_avg:99.00ms +step:1486/1695 train_time:147118ms step_avg:99.00ms +step:1487/1695 train_time:147219ms step_avg:99.00ms +step:1488/1695 train_time:147321ms step_avg:99.01ms +step:1489/1695 train_time:147424ms step_avg:99.01ms +step:1490/1695 train_time:147526ms step_avg:99.01ms +step:1491/1695 train_time:147628ms step_avg:99.01ms +step:1492/1695 train_time:147730ms step_avg:99.01ms +step:1493/1695 train_time:147832ms step_avg:99.02ms +step:1494/1695 train_time:147933ms step_avg:99.02ms +step:1495/1695 train_time:148034ms step_avg:99.02ms +step:1496/1695 train_time:148135ms step_avg:99.02ms +step:1497/1695 train_time:148235ms step_avg:99.02ms +step:1498/1695 train_time:148336ms step_avg:99.02ms +step:1499/1695 train_time:148440ms step_avg:99.03ms +step:1500/1695 train_time:148542ms step_avg:99.03ms +step:1500/1695 val_loss:3.3196 train_time:148642ms step_avg:99.09ms +step:1501/1695 train_time:148668ms step_avg:99.05ms +step:1502/1695 train_time:148756ms step_avg:99.04ms +step:1503/1695 train_time:148858ms step_avg:99.04ms +step:1504/1695 train_time:148960ms step_avg:99.04ms +step:1505/1695 train_time:149061ms step_avg:99.04ms +step:1506/1695 train_time:149163ms step_avg:99.05ms +step:1507/1695 train_time:149264ms step_avg:99.05ms 
+step:1508/1695 train_time:149364ms step_avg:99.05ms +step:1509/1695 train_time:149467ms step_avg:99.05ms +step:1510/1695 train_time:149569ms step_avg:99.05ms +step:1511/1695 train_time:149672ms step_avg:99.05ms +step:1512/1695 train_time:149775ms step_avg:99.06ms +step:1513/1695 train_time:149878ms step_avg:99.06ms +step:1514/1695 train_time:149980ms step_avg:99.06ms +step:1515/1695 train_time:150086ms step_avg:99.07ms +step:1516/1695 train_time:150188ms step_avg:99.07ms +step:1517/1695 train_time:150289ms step_avg:99.07ms +step:1518/1695 train_time:150391ms step_avg:99.07ms +step:1519/1695 train_time:150494ms step_avg:99.07ms +step:1520/1695 train_time:150595ms step_avg:99.08ms +step:1521/1695 train_time:150696ms step_avg:99.08ms +step:1522/1695 train_time:150798ms step_avg:99.08ms +step:1523/1695 train_time:150899ms step_avg:99.08ms +step:1524/1695 train_time:151003ms step_avg:99.08ms +step:1525/1695 train_time:151106ms step_avg:99.09ms +step:1526/1695 train_time:151209ms step_avg:99.09ms +step:1527/1695 train_time:151311ms step_avg:99.09ms +step:1528/1695 train_time:151416ms step_avg:99.09ms +step:1529/1695 train_time:151518ms step_avg:99.10ms +step:1530/1695 train_time:151621ms step_avg:99.10ms +step:1531/1695 train_time:151723ms step_avg:99.10ms +step:1532/1695 train_time:151825ms step_avg:99.10ms +step:1533/1695 train_time:151925ms step_avg:99.10ms +step:1534/1695 train_time:152027ms step_avg:99.11ms +step:1535/1695 train_time:152129ms step_avg:99.11ms +step:1536/1695 train_time:152232ms step_avg:99.11ms +step:1537/1695 train_time:152333ms step_avg:99.11ms +step:1538/1695 train_time:152435ms step_avg:99.11ms +step:1539/1695 train_time:152536ms step_avg:99.11ms +step:1540/1695 train_time:152638ms step_avg:99.12ms +step:1541/1695 train_time:152741ms step_avg:99.12ms +step:1542/1695 train_time:152844ms step_avg:99.12ms +step:1543/1695 train_time:152947ms step_avg:99.12ms +step:1544/1695 train_time:153049ms step_avg:99.12ms +step:1545/1695 train_time:153151ms step_avg:99.13ms +step:1546/1695 train_time:153253ms step_avg:99.13ms +step:1547/1695 train_time:153356ms step_avg:99.13ms +step:1548/1695 train_time:153458ms step_avg:99.13ms +step:1549/1695 train_time:153560ms step_avg:99.14ms +step:1550/1695 train_time:153661ms step_avg:99.14ms +step:1551/1695 train_time:153763ms step_avg:99.14ms +step:1552/1695 train_time:153865ms step_avg:99.14ms +step:1553/1695 train_time:153968ms step_avg:99.14ms +step:1554/1695 train_time:154069ms step_avg:99.14ms +step:1555/1695 train_time:154172ms step_avg:99.15ms +step:1556/1695 train_time:154275ms step_avg:99.15ms +step:1557/1695 train_time:154378ms step_avg:99.15ms +step:1558/1695 train_time:154482ms step_avg:99.15ms +step:1559/1695 train_time:154584ms step_avg:99.16ms +step:1560/1695 train_time:154686ms step_avg:99.16ms +step:1561/1695 train_time:154787ms step_avg:99.16ms +step:1562/1695 train_time:154890ms step_avg:99.16ms +step:1563/1695 train_time:154994ms step_avg:99.16ms +step:1564/1695 train_time:155095ms step_avg:99.17ms +step:1565/1695 train_time:155196ms step_avg:99.17ms +step:1566/1695 train_time:155298ms step_avg:99.17ms +step:1567/1695 train_time:155398ms step_avg:99.17ms +step:1568/1695 train_time:155499ms step_avg:99.17ms +step:1569/1695 train_time:155600ms step_avg:99.17ms +step:1570/1695 train_time:155703ms step_avg:99.17ms +step:1571/1695 train_time:155805ms step_avg:99.18ms +step:1572/1695 train_time:155905ms step_avg:99.18ms +step:1573/1695 train_time:156008ms step_avg:99.18ms +step:1574/1695 train_time:156109ms step_avg:99.18ms 
+step:1575/1695 train_time:156211ms step_avg:99.18ms +step:1576/1695 train_time:156315ms step_avg:99.18ms +step:1577/1695 train_time:156418ms step_avg:99.19ms +step:1578/1695 train_time:156520ms step_avg:99.19ms +step:1579/1695 train_time:156621ms step_avg:99.19ms +step:1580/1695 train_time:156723ms step_avg:99.19ms +step:1581/1695 train_time:156826ms step_avg:99.19ms +step:1582/1695 train_time:156927ms step_avg:99.20ms +step:1583/1695 train_time:157030ms step_avg:99.20ms +step:1584/1695 train_time:157133ms step_avg:99.20ms +step:1585/1695 train_time:157235ms step_avg:99.20ms +step:1586/1695 train_time:157338ms step_avg:99.20ms +step:1587/1695 train_time:157439ms step_avg:99.21ms +step:1588/1695 train_time:157539ms step_avg:99.21ms +step:1589/1695 train_time:157640ms step_avg:99.21ms +step:1590/1695 train_time:157742ms step_avg:99.21ms +step:1591/1695 train_time:157844ms step_avg:99.21ms +step:1592/1695 train_time:157946ms step_avg:99.21ms +step:1593/1695 train_time:158047ms step_avg:99.21ms +step:1594/1695 train_time:158151ms step_avg:99.22ms +step:1595/1695 train_time:158254ms step_avg:99.22ms +step:1596/1695 train_time:158355ms step_avg:99.22ms +step:1597/1695 train_time:158457ms step_avg:99.22ms +step:1598/1695 train_time:158560ms step_avg:99.22ms +step:1599/1695 train_time:158660ms step_avg:99.22ms +step:1600/1695 train_time:158762ms step_avg:99.23ms +step:1601/1695 train_time:158864ms step_avg:99.23ms +step:1602/1695 train_time:158965ms step_avg:99.23ms +step:1603/1695 train_time:159067ms step_avg:99.23ms +step:1604/1695 train_time:159170ms step_avg:99.23ms +step:1605/1695 train_time:159273ms step_avg:99.24ms +step:1606/1695 train_time:159376ms step_avg:99.24ms +step:1607/1695 train_time:159477ms step_avg:99.24ms +step:1608/1695 train_time:159578ms step_avg:99.24ms +step:1609/1695 train_time:159679ms step_avg:99.24ms +step:1610/1695 train_time:159781ms step_avg:99.24ms +step:1611/1695 train_time:159884ms step_avg:99.24ms +step:1612/1695 train_time:159985ms step_avg:99.25ms +step:1613/1695 train_time:160087ms step_avg:99.25ms +step:1614/1695 train_time:160188ms step_avg:99.25ms +step:1615/1695 train_time:160291ms step_avg:99.25ms +step:1616/1695 train_time:160392ms step_avg:99.25ms +step:1617/1695 train_time:160495ms step_avg:99.25ms +step:1618/1695 train_time:160597ms step_avg:99.26ms +step:1619/1695 train_time:160698ms step_avg:99.26ms +step:1620/1695 train_time:160801ms step_avg:99.26ms +step:1621/1695 train_time:160902ms step_avg:99.26ms +step:1622/1695 train_time:161003ms step_avg:99.26ms +step:1623/1695 train_time:161105ms step_avg:99.26ms +step:1624/1695 train_time:161208ms step_avg:99.27ms +step:1625/1695 train_time:161311ms step_avg:99.27ms +step:1625/1695 val_loss:3.2905 train_time:161412ms step_avg:99.33ms +step:1626/1695 train_time:161438ms step_avg:99.29ms +step:1627/1695 train_time:161526ms step_avg:99.28ms +step:1628/1695 train_time:161629ms step_avg:99.28ms +step:1629/1695 train_time:161733ms step_avg:99.28ms +step:1630/1695 train_time:161835ms step_avg:99.29ms +step:1631/1695 train_time:161936ms step_avg:99.29ms +step:1632/1695 train_time:162038ms step_avg:99.29ms +step:1633/1695 train_time:162138ms step_avg:99.29ms +step:1634/1695 train_time:162241ms step_avg:99.29ms +step:1635/1695 train_time:162342ms step_avg:99.29ms +step:1636/1695 train_time:162445ms step_avg:99.29ms +step:1637/1695 train_time:162548ms step_avg:99.30ms +step:1638/1695 train_time:162652ms step_avg:99.30ms +step:1639/1695 train_time:162756ms step_avg:99.30ms +step:1640/1695 train_time:162859ms 
step_avg:99.30ms +step:1641/1695 train_time:162963ms step_avg:99.31ms +step:1642/1695 train_time:163064ms step_avg:99.31ms +step:1643/1695 train_time:163168ms step_avg:99.31ms +step:1644/1695 train_time:163269ms step_avg:99.31ms +step:1645/1695 train_time:163373ms step_avg:99.31ms +step:1646/1695 train_time:163475ms step_avg:99.32ms +step:1647/1695 train_time:163579ms step_avg:99.32ms +step:1648/1695 train_time:163682ms step_avg:99.32ms +step:1649/1695 train_time:163785ms step_avg:99.32ms +step:1650/1695 train_time:163889ms step_avg:99.33ms +step:1651/1695 train_time:163992ms step_avg:99.33ms +step:1652/1695 train_time:164096ms step_avg:99.33ms +step:1653/1695 train_time:164199ms step_avg:99.33ms +step:1654/1695 train_time:164300ms step_avg:99.33ms +step:1655/1695 train_time:164402ms step_avg:99.34ms +step:1656/1695 train_time:164505ms step_avg:99.34ms +step:1657/1695 train_time:164607ms step_avg:99.34ms +step:1658/1695 train_time:164710ms step_avg:99.34ms +step:1659/1695 train_time:164816ms step_avg:99.35ms +step:1660/1695 train_time:164918ms step_avg:99.35ms +step:1661/1695 train_time:165022ms step_avg:99.35ms +step:1662/1695 train_time:165126ms step_avg:99.35ms +step:1663/1695 train_time:165230ms step_avg:99.36ms +step:1664/1695 train_time:165334ms step_avg:99.36ms +step:1665/1695 train_time:165439ms step_avg:99.36ms +step:1666/1695 train_time:165542ms step_avg:99.36ms +step:1667/1695 train_time:165643ms step_avg:99.37ms +step:1668/1695 train_time:165747ms step_avg:99.37ms +step:1669/1695 train_time:165852ms step_avg:99.37ms +step:1670/1695 train_time:165954ms step_avg:99.37ms +step:1671/1695 train_time:166057ms step_avg:99.38ms +step:1672/1695 train_time:166160ms step_avg:99.38ms +step:1673/1695 train_time:166262ms step_avg:99.38ms +step:1674/1695 train_time:166365ms step_avg:99.38ms +step:1675/1695 train_time:166467ms step_avg:99.38ms +step:1676/1695 train_time:166572ms step_avg:99.39ms +step:1677/1695 train_time:166675ms step_avg:99.39ms +step:1678/1695 train_time:166778ms step_avg:99.39ms +step:1679/1695 train_time:166881ms step_avg:99.39ms +step:1680/1695 train_time:166983ms step_avg:99.39ms +step:1681/1695 train_time:167086ms step_avg:99.40ms +step:1682/1695 train_time:167192ms step_avg:99.40ms +step:1683/1695 train_time:167295ms step_avg:99.40ms +step:1684/1695 train_time:167397ms step_avg:99.40ms +step:1685/1695 train_time:167500ms step_avg:99.41ms +step:1686/1695 train_time:167602ms step_avg:99.41ms +step:1687/1695 train_time:167703ms step_avg:99.41ms +step:1688/1695 train_time:167805ms step_avg:99.41ms +step:1689/1695 train_time:167907ms step_avg:99.41ms +step:1690/1695 train_time:168010ms step_avg:99.41ms +step:1691/1695 train_time:168114ms step_avg:99.42ms +step:1692/1695 train_time:168216ms step_avg:99.42ms +step:1693/1695 train_time:168320ms step_avg:99.42ms +step:1694/1695 train_time:168425ms step_avg:99.42ms +step:1695/1695 train_time:168528ms step_avg:99.43ms +step:1695/1695 val_loss:3.2774 train_time:168627ms step_avg:99.48ms +peak memory allocated: 34004 MiB reserved: 49680 MiB diff --git a/records/082325_SparseAttnGate/6701af06-6c40-4553-bb04-f501fdd56284.txt b/records/082325_SparseAttnGate/6701af06-6c40-4553-bb04-f501fdd56284.txt new file mode 100644 index 000000000..28ee46440 --- /dev/null +++ b/records/082325_SparseAttnGate/6701af06-6c40-4553-bb04-f501fdd56284.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
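+        # note: the A.T tile advances by the same column stride, since both pointer
+        # grids walk the K dimension of the single buffer A (read once as A, once as A.T)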
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
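+
+    Minimal usage sketch (illustrative only; it mirrors how this script
+    constructs the optimizer further below, and assumes the distributed
+    process group has already been initialized; `model`, `inputs`, `targets`,
+    and `window_blocks` are placeholder names):
+
+        matrix_params = [p for n, p in model.blocks.named_parameters()
+                         if p.ndim >= 2 and "embed" not in n]
+        opt = Muon(matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+        loss = model(inputs, targets, window_blocks)
+        loss.backward()
+        opt.step()
+        model.zero_grad(set_to_none=True)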
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
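+        # (lr_mul and wd_mul are read via getattr() inside Muon.step and
+        # DistAdam.step, so they scale the per-parameter learning rate and
+        # weight decay without requiring separate param groups)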
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
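+            # val_loss was averaged over val_steps locally; the all-reduce
+            # above then averages it across ranks, so rank 0 logs the global
+            # validation loss below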
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:24:11 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 305810 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 305811 C /usr/bin/python3 614MiB | +| 0 N/A N/A 305812 C /usr/bin/python3 614MiB | +| 0 N/A N/A 305813 C /usr/bin/python3 614MiB | +| 0 N/A N/A 305814 C /usr/bin/python3 614MiB | +| 0 N/A N/A 305815 C /usr/bin/python3 614MiB | +| 0 N/A N/A 305816 C /usr/bin/python3 614MiB | +| 0 N/A N/A 305817 C /usr/bin/python3 614MiB | +| 1 N/A N/A 305811 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 305812 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 305813 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 305814 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 305815 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 305816 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 305817 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:154ms step_avg:153.57ms +step:2/1695 train_time:178ms step_avg:89.05ms +step:3/1695 train_time:249ms step_avg:83.13ms +step:4/1695 train_time:342ms step_avg:85.39ms +step:5/1695 train_time:434ms step_avg:86.82ms +step:6/1695 train_time:527ms step_avg:87.75ms +step:7/1695 train_time:620ms step_avg:88.51ms +step:8/1695 
train_time:713ms step_avg:89.12ms +step:9/1695 train_time:806ms step_avg:89.55ms +step:10/1695 train_time:899ms step_avg:89.86ms +step:11/1695 train_time:991ms step_avg:90.12ms +step:12/1695 train_time:1085ms step_avg:90.45ms +step:13/1695 train_time:1181ms step_avg:90.81ms +step:14/1695 train_time:1276ms step_avg:91.15ms +step:15/1695 train_time:1370ms step_avg:91.32ms +step:16/1695 train_time:1463ms step_avg:91.47ms +step:17/1695 train_time:1557ms step_avg:91.58ms +step:18/1695 train_time:1650ms step_avg:91.68ms +step:19/1695 train_time:1744ms step_avg:91.82ms +step:20/1695 train_time:1838ms step_avg:91.90ms +step:21/1695 train_time:1931ms step_avg:91.96ms +step:22/1695 train_time:2025ms step_avg:92.05ms +step:23/1695 train_time:2119ms step_avg:92.12ms +step:24/1695 train_time:2213ms step_avg:92.19ms +step:25/1695 train_time:2307ms step_avg:92.28ms +step:26/1695 train_time:2401ms step_avg:92.36ms +step:27/1695 train_time:2495ms step_avg:92.42ms +step:28/1695 train_time:2590ms step_avg:92.51ms +step:29/1695 train_time:2684ms step_avg:92.55ms +step:30/1695 train_time:2778ms step_avg:92.59ms +step:31/1695 train_time:2871ms step_avg:92.62ms +step:32/1695 train_time:2965ms step_avg:92.66ms +step:33/1695 train_time:3059ms step_avg:92.70ms +step:34/1695 train_time:3153ms step_avg:92.73ms +step:35/1695 train_time:3247ms step_avg:92.77ms +step:36/1695 train_time:3341ms step_avg:92.82ms +step:37/1695 train_time:3436ms step_avg:92.86ms +step:38/1695 train_time:3530ms step_avg:92.89ms +step:39/1695 train_time:3624ms step_avg:92.92ms +step:40/1695 train_time:3718ms step_avg:92.96ms +step:41/1695 train_time:3813ms step_avg:93.00ms +step:42/1695 train_time:3906ms step_avg:93.01ms +step:43/1695 train_time:4000ms step_avg:93.01ms +step:44/1695 train_time:4093ms step_avg:93.03ms +step:45/1695 train_time:4187ms step_avg:93.04ms +step:46/1695 train_time:4281ms step_avg:93.07ms +step:47/1695 train_time:4375ms step_avg:93.08ms +step:48/1695 train_time:4468ms step_avg:93.09ms +step:49/1695 train_time:4562ms step_avg:93.11ms +step:50/1695 train_time:4656ms step_avg:93.12ms +step:51/1695 train_time:4750ms step_avg:93.13ms +step:52/1695 train_time:4844ms step_avg:93.15ms +step:53/1695 train_time:4938ms step_avg:93.16ms +step:54/1695 train_time:5031ms step_avg:93.16ms +step:55/1695 train_time:5124ms step_avg:93.17ms +step:56/1695 train_time:5218ms step_avg:93.19ms +step:57/1695 train_time:5311ms step_avg:93.18ms +step:58/1695 train_time:5405ms step_avg:93.19ms +step:59/1695 train_time:5499ms step_avg:93.20ms +step:60/1695 train_time:5592ms step_avg:93.20ms +step:61/1695 train_time:5686ms step_avg:93.22ms +step:62/1695 train_time:5780ms step_avg:93.23ms +step:63/1695 train_time:5874ms step_avg:93.24ms +step:64/1695 train_time:5967ms step_avg:93.24ms +step:65/1695 train_time:6061ms step_avg:93.25ms +step:66/1695 train_time:6155ms step_avg:93.26ms +step:67/1695 train_time:6248ms step_avg:93.26ms +step:68/1695 train_time:6342ms step_avg:93.27ms +step:69/1695 train_time:6436ms step_avg:93.27ms +step:70/1695 train_time:6529ms step_avg:93.27ms +step:71/1695 train_time:6623ms step_avg:93.28ms +step:72/1695 train_time:6717ms step_avg:93.30ms +step:73/1695 train_time:6811ms step_avg:93.30ms +step:74/1695 train_time:6905ms step_avg:93.31ms +step:75/1695 train_time:7000ms step_avg:93.33ms +step:76/1695 train_time:7093ms step_avg:93.32ms +step:77/1695 train_time:7186ms step_avg:93.32ms +step:78/1695 train_time:7279ms step_avg:93.32ms +step:79/1695 train_time:7373ms step_avg:93.32ms +step:80/1695 train_time:7466ms 
step_avg:93.32ms +step:81/1695 train_time:7560ms step_avg:93.33ms +step:82/1695 train_time:7654ms step_avg:93.35ms +step:83/1695 train_time:7748ms step_avg:93.35ms +step:84/1695 train_time:7842ms step_avg:93.36ms +step:85/1695 train_time:7936ms step_avg:93.36ms +step:86/1695 train_time:8029ms step_avg:93.36ms +step:87/1695 train_time:8123ms step_avg:93.37ms +step:88/1695 train_time:8218ms step_avg:93.39ms +step:89/1695 train_time:8312ms step_avg:93.39ms +step:90/1695 train_time:8405ms step_avg:93.39ms +step:91/1695 train_time:8498ms step_avg:93.39ms +step:92/1695 train_time:8591ms step_avg:93.39ms +step:93/1695 train_time:8685ms step_avg:93.39ms +step:94/1695 train_time:8779ms step_avg:93.39ms +step:95/1695 train_time:8872ms step_avg:93.39ms +step:96/1695 train_time:8966ms step_avg:93.39ms +step:97/1695 train_time:9060ms step_avg:93.40ms +step:98/1695 train_time:9155ms step_avg:93.42ms +step:99/1695 train_time:9247ms step_avg:93.40ms +step:100/1695 train_time:9341ms step_avg:93.41ms +step:101/1695 train_time:9434ms step_avg:93.41ms +step:102/1695 train_time:9528ms step_avg:93.41ms +step:103/1695 train_time:9622ms step_avg:93.42ms +step:104/1695 train_time:9716ms step_avg:93.42ms +step:105/1695 train_time:9809ms step_avg:93.42ms +step:106/1695 train_time:9903ms step_avg:93.42ms +step:107/1695 train_time:9997ms step_avg:93.43ms +step:108/1695 train_time:10091ms step_avg:93.43ms +step:109/1695 train_time:10184ms step_avg:93.43ms +step:110/1695 train_time:10277ms step_avg:93.43ms +step:111/1695 train_time:10370ms step_avg:93.43ms +step:112/1695 train_time:10464ms step_avg:93.43ms +step:113/1695 train_time:10558ms step_avg:93.43ms +step:114/1695 train_time:10651ms step_avg:93.43ms +step:115/1695 train_time:10745ms step_avg:93.44ms +step:116/1695 train_time:10840ms step_avg:93.45ms +step:117/1695 train_time:10933ms step_avg:93.44ms +step:118/1695 train_time:11027ms step_avg:93.45ms +step:119/1695 train_time:11121ms step_avg:93.45ms +step:120/1695 train_time:11214ms step_avg:93.45ms +step:121/1695 train_time:11307ms step_avg:93.45ms +step:122/1695 train_time:11401ms step_avg:93.45ms +step:123/1695 train_time:11495ms step_avg:93.45ms +step:124/1695 train_time:11588ms step_avg:93.45ms +step:125/1695 train_time:11682ms step_avg:93.45ms +step:125/1695 val_loss:4.5907 train_time:11774ms step_avg:94.19ms +step:126/1695 train_time:11800ms step_avg:93.65ms +step:127/1695 train_time:11878ms step_avg:93.52ms +step:128/1695 train_time:11980ms step_avg:93.60ms +step:129/1695 train_time:12076ms step_avg:93.61ms +step:130/1695 train_time:12171ms step_avg:93.62ms +step:131/1695 train_time:12264ms step_avg:93.62ms +step:132/1695 train_time:12357ms step_avg:93.61ms +step:133/1695 train_time:12451ms step_avg:93.62ms +step:134/1695 train_time:12544ms step_avg:93.61ms +step:135/1695 train_time:12638ms step_avg:93.61ms +step:136/1695 train_time:12732ms step_avg:93.62ms +step:137/1695 train_time:12828ms step_avg:93.63ms +step:138/1695 train_time:12924ms step_avg:93.65ms +step:139/1695 train_time:13018ms step_avg:93.65ms +step:140/1695 train_time:13114ms step_avg:93.67ms +step:141/1695 train_time:13209ms step_avg:93.68ms +step:142/1695 train_time:13302ms step_avg:93.68ms +step:143/1695 train_time:13396ms step_avg:93.68ms +step:144/1695 train_time:13490ms step_avg:93.68ms +step:145/1695 train_time:13583ms step_avg:93.68ms +step:146/1695 train_time:13676ms step_avg:93.67ms +step:147/1695 train_time:13771ms step_avg:93.68ms +step:148/1695 train_time:13866ms step_avg:93.69ms +step:149/1695 train_time:13960ms 
step_avg:93.69ms +step:150/1695 train_time:14055ms step_avg:93.70ms +step:151/1695 train_time:14151ms step_avg:93.71ms +step:152/1695 train_time:14245ms step_avg:93.72ms +step:153/1695 train_time:14339ms step_avg:93.72ms +step:154/1695 train_time:14433ms step_avg:93.72ms +step:155/1695 train_time:14527ms step_avg:93.72ms +step:156/1695 train_time:14620ms step_avg:93.72ms +step:157/1695 train_time:14714ms step_avg:93.72ms +step:158/1695 train_time:14809ms step_avg:93.73ms +step:159/1695 train_time:14903ms step_avg:93.73ms +step:160/1695 train_time:14997ms step_avg:93.73ms +step:161/1695 train_time:15092ms step_avg:93.74ms +step:162/1695 train_time:15187ms step_avg:93.74ms +step:163/1695 train_time:15280ms step_avg:93.74ms +step:164/1695 train_time:15374ms step_avg:93.74ms +step:165/1695 train_time:15468ms step_avg:93.75ms +step:166/1695 train_time:15563ms step_avg:93.75ms +step:167/1695 train_time:15656ms step_avg:93.75ms +step:168/1695 train_time:15750ms step_avg:93.75ms +step:169/1695 train_time:15845ms step_avg:93.75ms +step:170/1695 train_time:15939ms step_avg:93.76ms +step:171/1695 train_time:16033ms step_avg:93.76ms +step:172/1695 train_time:16128ms step_avg:93.77ms +step:173/1695 train_time:16222ms step_avg:93.77ms +step:174/1695 train_time:16316ms step_avg:93.77ms +step:175/1695 train_time:16410ms step_avg:93.77ms +step:176/1695 train_time:16505ms step_avg:93.78ms +step:177/1695 train_time:16599ms step_avg:93.78ms +step:178/1695 train_time:16693ms step_avg:93.78ms +step:179/1695 train_time:16787ms step_avg:93.78ms +step:180/1695 train_time:16882ms step_avg:93.79ms +step:181/1695 train_time:16976ms step_avg:93.79ms +step:182/1695 train_time:17070ms step_avg:93.79ms +step:183/1695 train_time:17164ms step_avg:93.79ms +step:184/1695 train_time:17258ms step_avg:93.79ms +step:185/1695 train_time:17353ms step_avg:93.80ms +step:186/1695 train_time:17448ms step_avg:93.81ms +step:187/1695 train_time:17542ms step_avg:93.81ms +step:188/1695 train_time:17636ms step_avg:93.81ms +step:189/1695 train_time:17730ms step_avg:93.81ms +step:190/1695 train_time:17824ms step_avg:93.81ms +step:191/1695 train_time:17918ms step_avg:93.81ms +step:192/1695 train_time:18012ms step_avg:93.81ms +step:193/1695 train_time:18107ms step_avg:93.82ms +step:194/1695 train_time:18200ms step_avg:93.82ms +step:195/1695 train_time:18294ms step_avg:93.82ms +step:196/1695 train_time:18388ms step_avg:93.82ms +step:197/1695 train_time:18482ms step_avg:93.82ms +step:198/1695 train_time:18576ms step_avg:93.82ms +step:199/1695 train_time:18670ms step_avg:93.82ms +step:200/1695 train_time:18764ms step_avg:93.82ms +step:201/1695 train_time:18857ms step_avg:93.82ms +step:202/1695 train_time:18951ms step_avg:93.82ms +step:203/1695 train_time:19046ms step_avg:93.82ms +step:204/1695 train_time:19140ms step_avg:93.82ms +step:205/1695 train_time:19234ms step_avg:93.82ms +step:206/1695 train_time:19328ms step_avg:93.83ms +step:207/1695 train_time:19422ms step_avg:93.82ms +step:208/1695 train_time:19516ms step_avg:93.83ms +step:209/1695 train_time:19610ms step_avg:93.83ms +step:210/1695 train_time:19704ms step_avg:93.83ms +step:211/1695 train_time:19798ms step_avg:93.83ms +step:212/1695 train_time:19893ms step_avg:93.83ms +step:213/1695 train_time:19987ms step_avg:93.83ms +step:214/1695 train_time:20080ms step_avg:93.83ms +step:215/1695 train_time:20174ms step_avg:93.83ms +step:216/1695 train_time:20269ms step_avg:93.84ms +step:217/1695 train_time:20365ms step_avg:93.85ms +step:218/1695 train_time:20459ms step_avg:93.85ms +step:219/1695 
train_time:20553ms step_avg:93.85ms +step:220/1695 train_time:20649ms step_avg:93.86ms +step:221/1695 train_time:20741ms step_avg:93.85ms +step:222/1695 train_time:20836ms step_avg:93.85ms +step:223/1695 train_time:20930ms step_avg:93.86ms +step:224/1695 train_time:21024ms step_avg:93.86ms +step:225/1695 train_time:21118ms step_avg:93.86ms +step:226/1695 train_time:21213ms step_avg:93.86ms +step:227/1695 train_time:21308ms step_avg:93.87ms +step:228/1695 train_time:21402ms step_avg:93.87ms +step:229/1695 train_time:21496ms step_avg:93.87ms +step:230/1695 train_time:21590ms step_avg:93.87ms +step:231/1695 train_time:21684ms step_avg:93.87ms +step:232/1695 train_time:21778ms step_avg:93.87ms +step:233/1695 train_time:21872ms step_avg:93.87ms +step:234/1695 train_time:21966ms step_avg:93.87ms +step:235/1695 train_time:22060ms step_avg:93.87ms +step:236/1695 train_time:22154ms step_avg:93.87ms +step:237/1695 train_time:22250ms step_avg:93.88ms +step:238/1695 train_time:22343ms step_avg:93.88ms +step:239/1695 train_time:22437ms step_avg:93.88ms +step:240/1695 train_time:22533ms step_avg:93.89ms +step:241/1695 train_time:22628ms step_avg:93.89ms +step:242/1695 train_time:22722ms step_avg:93.89ms +step:243/1695 train_time:22816ms step_avg:93.89ms +step:244/1695 train_time:22910ms step_avg:93.89ms +step:245/1695 train_time:23003ms step_avg:93.89ms +step:246/1695 train_time:23098ms step_avg:93.89ms +step:247/1695 train_time:23192ms step_avg:93.89ms +step:248/1695 train_time:23286ms step_avg:93.89ms +step:249/1695 train_time:23380ms step_avg:93.89ms +step:250/1695 train_time:23474ms step_avg:93.90ms +step:250/1695 val_loss:4.0689 train_time:23568ms step_avg:94.27ms +step:251/1695 train_time:23593ms step_avg:94.00ms +step:252/1695 train_time:23674ms step_avg:93.94ms +step:253/1695 train_time:23772ms step_avg:93.96ms +step:254/1695 train_time:23867ms step_avg:93.97ms +step:255/1695 train_time:23962ms step_avg:93.97ms +step:256/1695 train_time:24056ms step_avg:93.97ms +step:257/1695 train_time:24150ms step_avg:93.97ms +step:258/1695 train_time:24243ms step_avg:93.97ms +step:259/1695 train_time:24337ms step_avg:93.96ms +step:260/1695 train_time:24431ms step_avg:93.97ms +step:261/1695 train_time:24525ms step_avg:93.97ms +step:262/1695 train_time:24620ms step_avg:93.97ms +step:263/1695 train_time:24716ms step_avg:93.98ms +step:264/1695 train_time:24812ms step_avg:93.99ms +step:265/1695 train_time:24907ms step_avg:93.99ms +step:266/1695 train_time:25002ms step_avg:93.99ms +step:267/1695 train_time:25095ms step_avg:93.99ms +step:268/1695 train_time:25190ms step_avg:93.99ms +step:269/1695 train_time:25284ms step_avg:93.99ms +step:270/1695 train_time:25378ms step_avg:93.99ms +step:271/1695 train_time:25471ms step_avg:93.99ms +step:272/1695 train_time:25566ms step_avg:93.99ms +step:273/1695 train_time:25660ms step_avg:93.99ms +step:274/1695 train_time:25756ms step_avg:94.00ms +step:275/1695 train_time:25851ms step_avg:94.01ms +step:276/1695 train_time:25947ms step_avg:94.01ms +step:277/1695 train_time:26041ms step_avg:94.01ms +step:278/1695 train_time:26135ms step_avg:94.01ms +step:279/1695 train_time:26230ms step_avg:94.02ms +step:280/1695 train_time:26325ms step_avg:94.02ms +step:281/1695 train_time:26419ms step_avg:94.02ms +step:282/1695 train_time:26513ms step_avg:94.02ms +step:283/1695 train_time:26609ms step_avg:94.02ms +step:284/1695 train_time:26703ms step_avg:94.03ms +step:285/1695 train_time:26798ms step_avg:94.03ms +step:286/1695 train_time:26893ms step_avg:94.03ms +step:287/1695 train_time:26989ms 
step_avg:94.04ms +step:288/1695 train_time:27083ms step_avg:94.04ms +step:289/1695 train_time:27177ms step_avg:94.04ms +step:290/1695 train_time:27272ms step_avg:94.04ms +step:291/1695 train_time:27368ms step_avg:94.05ms +step:292/1695 train_time:27462ms step_avg:94.05ms +step:293/1695 train_time:27556ms step_avg:94.05ms +step:294/1695 train_time:27650ms step_avg:94.05ms +step:295/1695 train_time:27745ms step_avg:94.05ms +step:296/1695 train_time:27839ms step_avg:94.05ms +step:297/1695 train_time:27934ms step_avg:94.05ms +step:298/1695 train_time:28028ms step_avg:94.05ms +step:299/1695 train_time:28123ms step_avg:94.06ms +step:300/1695 train_time:28217ms step_avg:94.06ms +step:301/1695 train_time:28312ms step_avg:94.06ms +step:302/1695 train_time:28407ms step_avg:94.06ms +step:303/1695 train_time:28500ms step_avg:94.06ms +step:304/1695 train_time:28595ms step_avg:94.06ms +step:305/1695 train_time:28691ms step_avg:94.07ms +step:306/1695 train_time:28785ms step_avg:94.07ms +step:307/1695 train_time:28879ms step_avg:94.07ms +step:308/1695 train_time:28973ms step_avg:94.07ms +step:309/1695 train_time:29068ms step_avg:94.07ms +step:310/1695 train_time:29163ms step_avg:94.07ms +step:311/1695 train_time:29257ms step_avg:94.07ms +step:312/1695 train_time:29352ms step_avg:94.08ms +step:313/1695 train_time:29447ms step_avg:94.08ms +step:314/1695 train_time:29540ms step_avg:94.08ms +step:315/1695 train_time:29635ms step_avg:94.08ms +step:316/1695 train_time:29729ms step_avg:94.08ms +step:317/1695 train_time:29824ms step_avg:94.08ms +step:318/1695 train_time:29918ms step_avg:94.08ms +step:319/1695 train_time:30012ms step_avg:94.08ms +step:320/1695 train_time:30107ms step_avg:94.09ms +step:321/1695 train_time:30201ms step_avg:94.09ms +step:322/1695 train_time:30296ms step_avg:94.09ms +step:323/1695 train_time:30390ms step_avg:94.09ms +step:324/1695 train_time:30485ms step_avg:94.09ms +step:325/1695 train_time:30578ms step_avg:94.09ms +step:326/1695 train_time:30674ms step_avg:94.09ms +step:327/1695 train_time:30769ms step_avg:94.09ms +step:328/1695 train_time:30863ms step_avg:94.09ms +step:329/1695 train_time:30957ms step_avg:94.09ms +step:330/1695 train_time:31052ms step_avg:94.10ms +step:331/1695 train_time:31146ms step_avg:94.10ms +step:332/1695 train_time:31240ms step_avg:94.10ms +step:333/1695 train_time:31335ms step_avg:94.10ms +step:334/1695 train_time:31430ms step_avg:94.10ms +step:335/1695 train_time:31524ms step_avg:94.10ms +step:336/1695 train_time:31618ms step_avg:94.10ms +step:337/1695 train_time:31713ms step_avg:94.11ms +step:338/1695 train_time:31808ms step_avg:94.11ms +step:339/1695 train_time:31902ms step_avg:94.11ms +step:340/1695 train_time:31996ms step_avg:94.11ms +step:341/1695 train_time:32092ms step_avg:94.11ms +step:342/1695 train_time:32187ms step_avg:94.12ms +step:343/1695 train_time:32281ms step_avg:94.11ms +step:344/1695 train_time:32375ms step_avg:94.11ms +step:345/1695 train_time:32470ms step_avg:94.12ms +step:346/1695 train_time:32564ms step_avg:94.12ms +step:347/1695 train_time:32658ms step_avg:94.11ms +step:348/1695 train_time:32753ms step_avg:94.12ms +step:349/1695 train_time:32848ms step_avg:94.12ms +step:350/1695 train_time:32942ms step_avg:94.12ms +step:351/1695 train_time:33036ms step_avg:94.12ms +step:352/1695 train_time:33132ms step_avg:94.12ms +step:353/1695 train_time:33226ms step_avg:94.13ms +step:354/1695 train_time:33320ms step_avg:94.12ms +step:355/1695 train_time:33415ms step_avg:94.13ms +step:356/1695 train_time:33509ms step_avg:94.13ms +step:357/1695 
train_time:33604ms step_avg:94.13ms +step:358/1695 train_time:33697ms step_avg:94.13ms +step:359/1695 train_time:33792ms step_avg:94.13ms +step:360/1695 train_time:33887ms step_avg:94.13ms +step:361/1695 train_time:33981ms step_avg:94.13ms +step:362/1695 train_time:34075ms step_avg:94.13ms +step:363/1695 train_time:34171ms step_avg:94.14ms +step:364/1695 train_time:34265ms step_avg:94.14ms +step:365/1695 train_time:34359ms step_avg:94.13ms +step:366/1695 train_time:34454ms step_avg:94.14ms +step:367/1695 train_time:34548ms step_avg:94.14ms +step:368/1695 train_time:34642ms step_avg:94.14ms +step:369/1695 train_time:34736ms step_avg:94.14ms +step:370/1695 train_time:34831ms step_avg:94.14ms +step:371/1695 train_time:34925ms step_avg:94.14ms +step:372/1695 train_time:35019ms step_avg:94.14ms +step:373/1695 train_time:35113ms step_avg:94.14ms +step:374/1695 train_time:35209ms step_avg:94.14ms +step:375/1695 train_time:35302ms step_avg:94.14ms +step:375/1695 val_loss:3.8750 train_time:35394ms step_avg:94.39ms +step:376/1695 train_time:35420ms step_avg:94.20ms +step:377/1695 train_time:35500ms step_avg:94.16ms +step:378/1695 train_time:35598ms step_avg:94.18ms +step:379/1695 train_time:35696ms step_avg:94.18ms +step:380/1695 train_time:35791ms step_avg:94.19ms +step:381/1695 train_time:35888ms step_avg:94.19ms +step:382/1695 train_time:35983ms step_avg:94.20ms +step:383/1695 train_time:36079ms step_avg:94.20ms +step:384/1695 train_time:36174ms step_avg:94.20ms +step:385/1695 train_time:36269ms step_avg:94.21ms +step:386/1695 train_time:36365ms step_avg:94.21ms +step:387/1695 train_time:36462ms step_avg:94.22ms +step:388/1695 train_time:36559ms step_avg:94.22ms +step:389/1695 train_time:36656ms step_avg:94.23ms +step:390/1695 train_time:36753ms step_avg:94.24ms +step:391/1695 train_time:36850ms step_avg:94.25ms +step:392/1695 train_time:36946ms step_avg:94.25ms +step:393/1695 train_time:37041ms step_avg:94.25ms +step:394/1695 train_time:37138ms step_avg:94.26ms +step:395/1695 train_time:37233ms step_avg:94.26ms +step:396/1695 train_time:37329ms step_avg:94.27ms +step:397/1695 train_time:37426ms step_avg:94.27ms +step:398/1695 train_time:37523ms step_avg:94.28ms +step:399/1695 train_time:37619ms step_avg:94.28ms +step:400/1695 train_time:37714ms step_avg:94.29ms +step:401/1695 train_time:37812ms step_avg:94.29ms +step:402/1695 train_time:37908ms step_avg:94.30ms +step:403/1695 train_time:38005ms step_avg:94.31ms +step:404/1695 train_time:38100ms step_avg:94.31ms +step:405/1695 train_time:38196ms step_avg:94.31ms +step:406/1695 train_time:38292ms step_avg:94.32ms +step:407/1695 train_time:38389ms step_avg:94.32ms +step:408/1695 train_time:38486ms step_avg:94.33ms +step:409/1695 train_time:38582ms step_avg:94.33ms +step:410/1695 train_time:38677ms step_avg:94.33ms +step:411/1695 train_time:38774ms step_avg:94.34ms +step:412/1695 train_time:38871ms step_avg:94.35ms +step:413/1695 train_time:38967ms step_avg:94.35ms +step:414/1695 train_time:39063ms step_avg:94.36ms +step:415/1695 train_time:39159ms step_avg:94.36ms +step:416/1695 train_time:39255ms step_avg:94.36ms +step:417/1695 train_time:39351ms step_avg:94.37ms +step:418/1695 train_time:39448ms step_avg:94.37ms +step:419/1695 train_time:39545ms step_avg:94.38ms +step:420/1695 train_time:39641ms step_avg:94.38ms +step:421/1695 train_time:39737ms step_avg:94.39ms +step:422/1695 train_time:39833ms step_avg:94.39ms +step:423/1695 train_time:39930ms step_avg:94.40ms +step:424/1695 train_time:40027ms step_avg:94.40ms +step:425/1695 train_time:40123ms 
step_avg:94.41ms +step:426/1695 train_time:40219ms step_avg:94.41ms +step:427/1695 train_time:40314ms step_avg:94.41ms +step:428/1695 train_time:40410ms step_avg:94.42ms +step:429/1695 train_time:40506ms step_avg:94.42ms +step:430/1695 train_time:40602ms step_avg:94.42ms +step:431/1695 train_time:40698ms step_avg:94.43ms +step:432/1695 train_time:40794ms step_avg:94.43ms +step:433/1695 train_time:40891ms step_avg:94.44ms +step:434/1695 train_time:40988ms step_avg:94.44ms +step:435/1695 train_time:41085ms step_avg:94.45ms +step:436/1695 train_time:41181ms step_avg:94.45ms +step:437/1695 train_time:41277ms step_avg:94.46ms +step:438/1695 train_time:41373ms step_avg:94.46ms +step:439/1695 train_time:41469ms step_avg:94.46ms +step:440/1695 train_time:41566ms step_avg:94.47ms +step:441/1695 train_time:41662ms step_avg:94.47ms +step:442/1695 train_time:41757ms step_avg:94.47ms +step:443/1695 train_time:41854ms step_avg:94.48ms +step:444/1695 train_time:41951ms step_avg:94.48ms +step:445/1695 train_time:42047ms step_avg:94.49ms +step:446/1695 train_time:42144ms step_avg:94.49ms +step:447/1695 train_time:42240ms step_avg:94.50ms +step:448/1695 train_time:42336ms step_avg:94.50ms +step:449/1695 train_time:42432ms step_avg:94.50ms +step:450/1695 train_time:42529ms step_avg:94.51ms +step:451/1695 train_time:42626ms step_avg:94.51ms +step:452/1695 train_time:42722ms step_avg:94.52ms +step:453/1695 train_time:42818ms step_avg:94.52ms +step:454/1695 train_time:42914ms step_avg:94.52ms +step:455/1695 train_time:43010ms step_avg:94.53ms +step:456/1695 train_time:43107ms step_avg:94.53ms +step:457/1695 train_time:43203ms step_avg:94.54ms +step:458/1695 train_time:43299ms step_avg:94.54ms +step:459/1695 train_time:43395ms step_avg:94.54ms +step:460/1695 train_time:43492ms step_avg:94.55ms +step:461/1695 train_time:43589ms step_avg:94.55ms +step:462/1695 train_time:43685ms step_avg:94.56ms +step:463/1695 train_time:43782ms step_avg:94.56ms +step:464/1695 train_time:43878ms step_avg:94.56ms +step:465/1695 train_time:43974ms step_avg:94.57ms +step:466/1695 train_time:44070ms step_avg:94.57ms +step:467/1695 train_time:44167ms step_avg:94.58ms +step:468/1695 train_time:44263ms step_avg:94.58ms +step:469/1695 train_time:44359ms step_avg:94.58ms +step:470/1695 train_time:44455ms step_avg:94.59ms +step:471/1695 train_time:44552ms step_avg:94.59ms +step:472/1695 train_time:44650ms step_avg:94.60ms +step:473/1695 train_time:44746ms step_avg:94.60ms +step:474/1695 train_time:44842ms step_avg:94.60ms +step:475/1695 train_time:44938ms step_avg:94.61ms +step:476/1695 train_time:45035ms step_avg:94.61ms +step:477/1695 train_time:45131ms step_avg:94.61ms +step:478/1695 train_time:45228ms step_avg:94.62ms +step:479/1695 train_time:45324ms step_avg:94.62ms +step:480/1695 train_time:45420ms step_avg:94.63ms +step:481/1695 train_time:45516ms step_avg:94.63ms +step:482/1695 train_time:45613ms step_avg:94.63ms +step:483/1695 train_time:45709ms step_avg:94.64ms +step:484/1695 train_time:45805ms step_avg:94.64ms +step:485/1695 train_time:45901ms step_avg:94.64ms +step:486/1695 train_time:45997ms step_avg:94.64ms +step:487/1695 train_time:46093ms step_avg:94.65ms +step:488/1695 train_time:46190ms step_avg:94.65ms +step:489/1695 train_time:46286ms step_avg:94.66ms +step:490/1695 train_time:46383ms step_avg:94.66ms +step:491/1695 train_time:46478ms step_avg:94.66ms +step:492/1695 train_time:46574ms step_avg:94.66ms +step:493/1695 train_time:46671ms step_avg:94.67ms +step:494/1695 train_time:46767ms step_avg:94.67ms +step:495/1695 
train_time:46863ms step_avg:94.67ms +step:496/1695 train_time:46959ms step_avg:94.68ms +step:497/1695 train_time:47055ms step_avg:94.68ms +step:498/1695 train_time:47151ms step_avg:94.68ms +step:499/1695 train_time:47248ms step_avg:94.68ms +step:500/1695 train_time:47344ms step_avg:94.69ms +step:500/1695 val_loss:3.7291 train_time:47440ms step_avg:94.88ms +step:501/1695 train_time:47465ms step_avg:94.74ms +step:502/1695 train_time:47549ms step_avg:94.72ms +step:503/1695 train_time:47648ms step_avg:94.73ms +step:504/1695 train_time:47745ms step_avg:94.73ms +step:505/1695 train_time:47840ms step_avg:94.73ms +step:506/1695 train_time:47937ms step_avg:94.74ms +step:507/1695 train_time:48033ms step_avg:94.74ms +step:508/1695 train_time:48128ms step_avg:94.74ms +step:509/1695 train_time:48224ms step_avg:94.74ms +step:510/1695 train_time:48319ms step_avg:94.74ms +step:511/1695 train_time:48416ms step_avg:94.75ms +step:512/1695 train_time:48515ms step_avg:94.76ms +step:513/1695 train_time:48614ms step_avg:94.76ms +step:514/1695 train_time:48712ms step_avg:94.77ms +step:515/1695 train_time:48811ms step_avg:94.78ms +step:516/1695 train_time:48906ms step_avg:94.78ms +step:517/1695 train_time:49001ms step_avg:94.78ms +step:518/1695 train_time:49097ms step_avg:94.78ms +step:519/1695 train_time:49194ms step_avg:94.79ms +step:520/1695 train_time:49290ms step_avg:94.79ms +step:521/1695 train_time:49386ms step_avg:94.79ms +step:522/1695 train_time:49483ms step_avg:94.79ms +step:523/1695 train_time:49580ms step_avg:94.80ms +step:524/1695 train_time:49678ms step_avg:94.81ms +step:525/1695 train_time:49776ms step_avg:94.81ms +step:526/1695 train_time:49874ms step_avg:94.82ms +step:527/1695 train_time:49971ms step_avg:94.82ms +step:528/1695 train_time:50068ms step_avg:94.83ms +step:529/1695 train_time:50163ms step_avg:94.83ms +step:530/1695 train_time:50259ms step_avg:94.83ms +step:531/1695 train_time:50356ms step_avg:94.83ms +step:532/1695 train_time:50454ms step_avg:94.84ms +step:533/1695 train_time:50551ms step_avg:94.84ms +step:534/1695 train_time:50650ms step_avg:94.85ms +step:535/1695 train_time:50748ms step_avg:94.86ms +step:536/1695 train_time:50844ms step_avg:94.86ms +step:537/1695 train_time:50940ms step_avg:94.86ms +step:538/1695 train_time:51037ms step_avg:94.86ms +step:539/1695 train_time:51134ms step_avg:94.87ms +step:540/1695 train_time:51230ms step_avg:94.87ms +step:541/1695 train_time:51326ms step_avg:94.87ms +step:542/1695 train_time:51422ms step_avg:94.87ms +step:543/1695 train_time:51518ms step_avg:94.88ms +step:544/1695 train_time:51616ms step_avg:94.88ms +step:545/1695 train_time:51714ms step_avg:94.89ms +step:546/1695 train_time:51811ms step_avg:94.89ms +step:547/1695 train_time:51909ms step_avg:94.90ms +step:548/1695 train_time:52006ms step_avg:94.90ms +step:549/1695 train_time:52102ms step_avg:94.90ms +step:550/1695 train_time:52198ms step_avg:94.91ms +step:551/1695 train_time:52296ms step_avg:94.91ms +step:552/1695 train_time:52393ms step_avg:94.92ms +step:553/1695 train_time:52489ms step_avg:94.92ms +step:554/1695 train_time:52585ms step_avg:94.92ms +step:555/1695 train_time:52682ms step_avg:94.92ms +step:556/1695 train_time:52778ms step_avg:94.93ms +step:557/1695 train_time:52876ms step_avg:94.93ms +step:558/1695 train_time:52974ms step_avg:94.94ms +step:559/1695 train_time:53072ms step_avg:94.94ms +step:560/1695 train_time:53169ms step_avg:94.94ms +step:561/1695 train_time:53265ms step_avg:94.95ms +step:562/1695 train_time:53361ms step_avg:94.95ms +step:563/1695 train_time:53458ms 
step_avg:94.95ms +step:564/1695 train_time:53555ms step_avg:94.96ms +step:565/1695 train_time:53652ms step_avg:94.96ms +step:566/1695 train_time:53749ms step_avg:94.96ms +step:567/1695 train_time:53845ms step_avg:94.96ms +step:568/1695 train_time:53941ms step_avg:94.97ms +step:569/1695 train_time:54038ms step_avg:94.97ms +step:570/1695 train_time:54135ms step_avg:94.97ms +step:571/1695 train_time:54232ms step_avg:94.98ms +step:572/1695 train_time:54329ms step_avg:94.98ms +step:573/1695 train_time:54426ms step_avg:94.98ms +step:574/1695 train_time:54522ms step_avg:94.99ms +step:575/1695 train_time:54618ms step_avg:94.99ms +step:576/1695 train_time:54715ms step_avg:94.99ms +step:577/1695 train_time:54813ms step_avg:95.00ms +step:578/1695 train_time:54910ms step_avg:95.00ms +step:579/1695 train_time:55006ms step_avg:95.00ms +step:580/1695 train_time:55103ms step_avg:95.00ms +step:581/1695 train_time:55199ms step_avg:95.01ms +step:582/1695 train_time:55296ms step_avg:95.01ms +step:583/1695 train_time:55394ms step_avg:95.01ms +step:584/1695 train_time:55491ms step_avg:95.02ms +step:585/1695 train_time:55588ms step_avg:95.02ms +step:586/1695 train_time:55684ms step_avg:95.02ms +step:587/1695 train_time:55780ms step_avg:95.02ms +step:588/1695 train_time:55877ms step_avg:95.03ms +step:589/1695 train_time:55973ms step_avg:95.03ms +step:590/1695 train_time:56071ms step_avg:95.04ms +step:591/1695 train_time:56168ms step_avg:95.04ms +step:592/1695 train_time:56264ms step_avg:95.04ms +step:593/1695 train_time:56361ms step_avg:95.04ms +step:594/1695 train_time:56457ms step_avg:95.05ms +step:595/1695 train_time:56556ms step_avg:95.05ms +step:596/1695 train_time:56653ms step_avg:95.06ms +step:597/1695 train_time:56750ms step_avg:95.06ms +step:598/1695 train_time:56847ms step_avg:95.06ms +step:599/1695 train_time:56943ms step_avg:95.06ms +step:600/1695 train_time:57039ms step_avg:95.07ms +step:601/1695 train_time:57136ms step_avg:95.07ms +step:602/1695 train_time:57234ms step_avg:95.07ms +step:603/1695 train_time:57331ms step_avg:95.08ms +step:604/1695 train_time:57427ms step_avg:95.08ms +step:605/1695 train_time:57524ms step_avg:95.08ms +step:606/1695 train_time:57620ms step_avg:95.08ms +step:607/1695 train_time:57717ms step_avg:95.09ms +step:608/1695 train_time:57814ms step_avg:95.09ms +step:609/1695 train_time:57911ms step_avg:95.09ms +step:610/1695 train_time:58006ms step_avg:95.09ms +step:611/1695 train_time:58102ms step_avg:95.09ms +step:612/1695 train_time:58198ms step_avg:95.10ms +step:613/1695 train_time:58296ms step_avg:95.10ms +step:614/1695 train_time:58393ms step_avg:95.10ms +step:615/1695 train_time:58490ms step_avg:95.11ms +step:616/1695 train_time:58587ms step_avg:95.11ms +step:617/1695 train_time:58682ms step_avg:95.11ms +step:618/1695 train_time:58779ms step_avg:95.11ms +step:619/1695 train_time:58876ms step_avg:95.11ms +step:620/1695 train_time:58973ms step_avg:95.12ms +step:621/1695 train_time:59070ms step_avg:95.12ms +step:622/1695 train_time:59167ms step_avg:95.12ms +step:623/1695 train_time:59263ms step_avg:95.13ms +step:624/1695 train_time:59360ms step_avg:95.13ms +step:625/1695 train_time:59457ms step_avg:95.13ms +step:625/1695 val_loss:3.6467 train_time:59552ms step_avg:95.28ms +step:626/1695 train_time:59578ms step_avg:95.17ms +step:627/1695 train_time:59664ms step_avg:95.16ms +step:628/1695 train_time:59765ms step_avg:95.17ms +step:629/1695 train_time:59862ms step_avg:95.17ms +step:630/1695 train_time:59960ms step_avg:95.17ms +step:631/1695 train_time:60058ms step_avg:95.18ms 
+step:632/1695 train_time:60155ms step_avg:95.18ms +step:633/1695 train_time:60252ms step_avg:95.19ms +step:634/1695 train_time:60349ms step_avg:95.19ms +step:635/1695 train_time:60685ms step_avg:95.57ms +step:636/1695 train_time:60780ms step_avg:95.57ms +step:637/1695 train_time:60877ms step_avg:95.57ms +step:638/1695 train_time:60975ms step_avg:95.57ms +step:639/1695 train_time:61072ms step_avg:95.57ms +step:640/1695 train_time:61169ms step_avg:95.58ms +step:641/1695 train_time:61528ms step_avg:95.99ms +step:642/1695 train_time:61622ms step_avg:95.98ms +step:643/1695 train_time:61720ms step_avg:95.99ms +step:644/1695 train_time:61817ms step_avg:95.99ms +step:645/1695 train_time:61914ms step_avg:95.99ms +step:646/1695 train_time:62011ms step_avg:95.99ms +step:647/1695 train_time:62109ms step_avg:95.99ms +step:648/1695 train_time:62206ms step_avg:96.00ms +step:649/1695 train_time:62303ms step_avg:96.00ms +step:650/1695 train_time:62400ms step_avg:96.00ms +step:651/1695 train_time:62503ms step_avg:96.01ms +step:652/1695 train_time:62882ms step_avg:96.44ms +step:653/1695 train_time:62932ms step_avg:96.37ms +step:654/1695 train_time:63029ms step_avg:96.37ms +step:655/1695 train_time:63126ms step_avg:96.38ms +step:656/1695 train_time:63224ms step_avg:96.38ms +step:657/1695 train_time:63320ms step_avg:96.38ms +step:658/1695 train_time:63418ms step_avg:96.38ms +step:659/1695 train_time:63515ms step_avg:96.38ms +step:660/1695 train_time:63612ms step_avg:96.38ms +step:661/1695 train_time:63709ms step_avg:96.38ms +step:662/1695 train_time:63808ms step_avg:96.39ms +step:663/1695 train_time:63908ms step_avg:96.39ms +step:664/1695 train_time:64006ms step_avg:96.39ms +step:665/1695 train_time:64104ms step_avg:96.40ms +step:666/1695 train_time:64202ms step_avg:96.40ms +step:667/1695 train_time:64299ms step_avg:96.40ms +step:668/1695 train_time:64398ms step_avg:96.40ms +step:669/1695 train_time:64495ms step_avg:96.41ms +step:670/1695 train_time:64593ms step_avg:96.41ms +step:671/1695 train_time:64690ms step_avg:96.41ms +step:672/1695 train_time:64788ms step_avg:96.41ms +step:673/1695 train_time:64886ms step_avg:96.41ms +step:674/1695 train_time:64985ms step_avg:96.42ms +step:675/1695 train_time:65083ms step_avg:96.42ms +step:676/1695 train_time:65181ms step_avg:96.42ms +step:677/1695 train_time:65280ms step_avg:96.42ms +step:678/1695 train_time:65377ms step_avg:96.43ms +step:679/1695 train_time:65475ms step_avg:96.43ms +step:680/1695 train_time:65572ms step_avg:96.43ms +step:681/1695 train_time:65670ms step_avg:96.43ms +step:682/1695 train_time:65768ms step_avg:96.43ms +step:683/1695 train_time:65866ms step_avg:96.44ms +step:684/1695 train_time:65964ms step_avg:96.44ms +step:685/1695 train_time:66063ms step_avg:96.44ms +step:686/1695 train_time:66161ms step_avg:96.45ms +step:687/1695 train_time:66260ms step_avg:96.45ms +step:688/1695 train_time:66359ms step_avg:96.45ms +step:689/1695 train_time:66456ms step_avg:96.45ms +step:690/1695 train_time:66554ms step_avg:96.46ms +step:691/1695 train_time:66652ms step_avg:96.46ms +step:692/1695 train_time:66751ms step_avg:96.46ms +step:693/1695 train_time:66849ms step_avg:96.46ms +step:694/1695 train_time:66947ms step_avg:96.47ms +step:695/1695 train_time:67045ms step_avg:96.47ms +step:696/1695 train_time:67142ms step_avg:96.47ms +step:697/1695 train_time:67240ms step_avg:96.47ms +step:698/1695 train_time:67338ms step_avg:96.47ms +step:699/1695 train_time:67436ms step_avg:96.47ms +step:700/1695 train_time:67534ms step_avg:96.48ms +step:701/1695 train_time:67632ms 
step_avg:96.48ms +step:702/1695 train_time:67729ms step_avg:96.48ms +step:703/1695 train_time:67827ms step_avg:96.48ms +step:704/1695 train_time:67925ms step_avg:96.48ms +step:705/1695 train_time:68023ms step_avg:96.49ms +step:706/1695 train_time:68121ms step_avg:96.49ms +step:707/1695 train_time:68219ms step_avg:96.49ms +step:708/1695 train_time:68318ms step_avg:96.49ms +step:709/1695 train_time:68415ms step_avg:96.50ms +step:710/1695 train_time:68513ms step_avg:96.50ms +step:711/1695 train_time:68611ms step_avg:96.50ms +step:712/1695 train_time:68940ms step_avg:96.83ms +step:713/1695 train_time:69036ms step_avg:96.82ms +step:714/1695 train_time:69133ms step_avg:96.83ms +step:715/1695 train_time:69230ms step_avg:96.82ms +step:716/1695 train_time:69327ms step_avg:96.82ms +step:717/1695 train_time:69424ms step_avg:96.83ms +step:718/1695 train_time:69521ms step_avg:96.83ms +step:719/1695 train_time:69921ms step_avg:97.25ms +step:720/1695 train_time:70016ms step_avg:97.25ms +step:721/1695 train_time:70113ms step_avg:97.24ms +step:722/1695 train_time:70210ms step_avg:97.24ms +step:723/1695 train_time:70307ms step_avg:97.24ms +step:724/1695 train_time:70404ms step_avg:97.24ms +step:725/1695 train_time:70501ms step_avg:97.24ms +step:726/1695 train_time:70599ms step_avg:97.24ms +step:727/1695 train_time:70696ms step_avg:97.24ms +step:728/1695 train_time:70793ms step_avg:97.24ms +step:729/1695 train_time:70895ms step_avg:97.25ms +step:730/1695 train_time:70994ms step_avg:97.25ms +step:731/1695 train_time:71092ms step_avg:97.25ms +step:732/1695 train_time:71189ms step_avg:97.25ms +step:733/1695 train_time:71286ms step_avg:97.25ms +step:734/1695 train_time:71384ms step_avg:97.25ms +step:735/1695 train_time:71481ms step_avg:97.25ms +step:736/1695 train_time:71579ms step_avg:97.25ms +step:737/1695 train_time:71676ms step_avg:97.25ms +step:738/1695 train_time:71774ms step_avg:97.25ms +step:739/1695 train_time:71872ms step_avg:97.26ms +step:740/1695 train_time:71971ms step_avg:97.26ms +step:741/1695 train_time:72069ms step_avg:97.26ms +step:742/1695 train_time:72167ms step_avg:97.26ms +step:743/1695 train_time:72264ms step_avg:97.26ms +step:744/1695 train_time:72362ms step_avg:97.26ms +step:745/1695 train_time:72459ms step_avg:97.26ms +step:746/1695 train_time:72557ms step_avg:97.26ms +step:747/1695 train_time:72654ms step_avg:97.26ms +step:748/1695 train_time:72752ms step_avg:97.26ms +step:749/1695 train_time:72849ms step_avg:97.26ms +step:750/1695 train_time:72947ms step_avg:97.26ms +step:750/1695 val_loss:3.5846 train_time:73043ms step_avg:97.39ms +step:751/1695 train_time:73069ms step_avg:97.30ms +step:752/1695 train_time:73158ms step_avg:97.28ms +step:753/1695 train_time:73258ms step_avg:97.29ms +step:754/1695 train_time:73356ms step_avg:97.29ms +step:755/1695 train_time:73453ms step_avg:97.29ms +step:756/1695 train_time:73550ms step_avg:97.29ms +step:757/1695 train_time:73648ms step_avg:97.29ms +step:758/1695 train_time:73746ms step_avg:97.29ms +step:759/1695 train_time:73843ms step_avg:97.29ms +step:760/1695 train_time:73940ms step_avg:97.29ms +step:761/1695 train_time:74038ms step_avg:97.29ms +step:762/1695 train_time:74136ms step_avg:97.29ms +step:763/1695 train_time:74236ms step_avg:97.29ms +step:764/1695 train_time:74334ms step_avg:97.30ms +step:765/1695 train_time:74432ms step_avg:97.30ms +step:766/1695 train_time:74530ms step_avg:97.30ms +step:767/1695 train_time:74628ms step_avg:97.30ms +step:768/1695 train_time:74725ms step_avg:97.30ms +step:769/1695 train_time:74823ms step_avg:97.30ms 
+step:770/1695 train_time:74921ms step_avg:97.30ms +step:771/1695 train_time:75019ms step_avg:97.30ms +step:772/1695 train_time:75117ms step_avg:97.30ms +step:773/1695 train_time:75215ms step_avg:97.30ms +step:774/1695 train_time:75313ms step_avg:97.30ms +step:775/1695 train_time:75412ms step_avg:97.31ms +step:776/1695 train_time:75509ms step_avg:97.31ms +step:777/1695 train_time:75607ms step_avg:97.31ms +step:778/1695 train_time:75931ms step_avg:97.60ms +step:779/1695 train_time:76028ms step_avg:97.60ms +step:780/1695 train_time:76126ms step_avg:97.60ms +step:781/1695 train_time:76223ms step_avg:97.60ms +step:782/1695 train_time:76321ms step_avg:97.60ms +step:783/1695 train_time:76418ms step_avg:97.60ms +step:784/1695 train_time:76515ms step_avg:97.60ms +step:785/1695 train_time:76841ms step_avg:97.89ms +step:786/1695 train_time:76937ms step_avg:97.88ms +step:787/1695 train_time:77034ms step_avg:97.88ms +step:788/1695 train_time:77131ms step_avg:97.88ms +step:789/1695 train_time:77228ms step_avg:97.88ms +step:790/1695 train_time:77326ms step_avg:97.88ms +step:791/1695 train_time:77423ms step_avg:97.88ms +step:792/1695 train_time:77521ms step_avg:97.88ms +step:793/1695 train_time:77617ms step_avg:97.88ms +step:794/1695 train_time:77716ms step_avg:97.88ms +step:795/1695 train_time:77814ms step_avg:97.88ms +step:796/1695 train_time:77913ms step_avg:97.88ms +step:797/1695 train_time:78012ms step_avg:97.88ms +step:798/1695 train_time:78110ms step_avg:97.88ms +step:799/1695 train_time:78208ms step_avg:97.88ms +step:800/1695 train_time:78306ms step_avg:97.88ms +step:801/1695 train_time:78404ms step_avg:97.88ms +step:802/1695 train_time:78501ms step_avg:97.88ms +step:803/1695 train_time:78599ms step_avg:97.88ms +step:804/1695 train_time:78697ms step_avg:97.88ms +step:805/1695 train_time:78795ms step_avg:97.88ms +step:806/1695 train_time:78893ms step_avg:97.88ms +step:807/1695 train_time:78991ms step_avg:97.88ms +step:808/1695 train_time:79090ms step_avg:97.88ms +step:809/1695 train_time:79189ms step_avg:97.88ms +step:810/1695 train_time:79287ms step_avg:97.88ms +step:811/1695 train_time:79386ms step_avg:97.89ms +step:812/1695 train_time:79483ms step_avg:97.89ms +step:813/1695 train_time:79581ms step_avg:97.89ms +step:814/1695 train_time:79679ms step_avg:97.89ms +step:815/1695 train_time:79778ms step_avg:97.89ms +step:816/1695 train_time:79876ms step_avg:97.89ms +step:817/1695 train_time:79974ms step_avg:97.89ms +step:818/1695 train_time:80071ms step_avg:97.89ms +step:819/1695 train_time:80170ms step_avg:97.89ms +step:820/1695 train_time:80268ms step_avg:97.89ms +step:821/1695 train_time:80367ms step_avg:97.89ms +step:822/1695 train_time:80465ms step_avg:97.89ms +step:823/1695 train_time:80563ms step_avg:97.89ms +step:824/1695 train_time:80663ms step_avg:97.89ms +step:825/1695 train_time:80762ms step_avg:97.89ms +step:826/1695 train_time:80861ms step_avg:97.89ms +step:827/1695 train_time:80959ms step_avg:97.89ms +step:828/1695 train_time:81058ms step_avg:97.90ms +step:829/1695 train_time:81156ms step_avg:97.90ms +step:830/1695 train_time:81253ms step_avg:97.90ms +step:831/1695 train_time:81351ms step_avg:97.90ms +step:832/1695 train_time:81449ms step_avg:97.90ms +step:833/1695 train_time:81547ms step_avg:97.90ms +step:834/1695 train_time:81647ms step_avg:97.90ms +step:835/1695 train_time:81746ms step_avg:97.90ms +step:836/1695 train_time:81846ms step_avg:97.90ms +step:837/1695 train_time:81945ms step_avg:97.90ms +step:838/1695 train_time:82044ms step_avg:97.90ms +step:839/1695 train_time:82143ms 
step_avg:97.91ms +step:840/1695 train_time:82242ms step_avg:97.91ms +step:841/1695 train_time:82340ms step_avg:97.91ms +step:842/1695 train_time:82439ms step_avg:97.91ms +step:843/1695 train_time:82536ms step_avg:97.91ms +step:844/1695 train_time:82634ms step_avg:97.91ms +step:845/1695 train_time:82732ms step_avg:97.91ms +step:846/1695 train_time:82830ms step_avg:97.91ms +step:847/1695 train_time:82928ms step_avg:97.91ms +step:848/1695 train_time:83028ms step_avg:97.91ms +step:849/1695 train_time:83127ms step_avg:97.91ms +step:850/1695 train_time:83225ms step_avg:97.91ms +step:851/1695 train_time:83324ms step_avg:97.91ms +step:852/1695 train_time:83424ms step_avg:97.92ms +step:853/1695 train_time:83523ms step_avg:97.92ms +step:854/1695 train_time:83620ms step_avg:97.92ms +step:855/1695 train_time:83719ms step_avg:97.92ms +step:856/1695 train_time:83817ms step_avg:97.92ms +step:857/1695 train_time:83916ms step_avg:97.92ms +step:858/1695 train_time:84015ms step_avg:97.92ms +step:859/1695 train_time:84113ms step_avg:97.92ms +step:860/1695 train_time:84211ms step_avg:97.92ms +step:861/1695 train_time:84309ms step_avg:97.92ms +step:862/1695 train_time:84661ms step_avg:98.21ms +step:863/1695 train_time:84758ms step_avg:98.21ms +step:864/1695 train_time:84855ms step_avg:98.21ms +step:865/1695 train_time:84952ms step_avg:98.21ms +step:866/1695 train_time:85050ms step_avg:98.21ms +step:867/1695 train_time:85147ms step_avg:98.21ms +step:868/1695 train_time:85245ms step_avg:98.21ms +step:869/1695 train_time:85343ms step_avg:98.21ms +step:870/1695 train_time:85441ms step_avg:98.21ms +step:871/1695 train_time:85540ms step_avg:98.21ms +step:872/1695 train_time:85641ms step_avg:98.21ms +step:873/1695 train_time:85740ms step_avg:98.21ms +step:874/1695 train_time:85839ms step_avg:98.21ms +step:875/1695 train_time:85937ms step_avg:98.21ms +step:875/1695 val_loss:3.5373 train_time:86033ms step_avg:98.32ms +step:876/1695 train_time:86058ms step_avg:98.24ms +step:877/1695 train_time:86144ms step_avg:98.23ms +step:878/1695 train_time:86248ms step_avg:98.23ms +step:879/1695 train_time:86346ms step_avg:98.23ms +step:880/1695 train_time:86443ms step_avg:98.23ms +step:881/1695 train_time:86542ms step_avg:98.23ms +step:882/1695 train_time:86641ms step_avg:98.23ms +step:883/1695 train_time:86740ms step_avg:98.23ms +step:884/1695 train_time:86840ms step_avg:98.24ms +step:885/1695 train_time:86939ms step_avg:98.24ms +step:886/1695 train_time:87038ms step_avg:98.24ms +step:887/1695 train_time:87139ms step_avg:98.24ms +step:888/1695 train_time:87241ms step_avg:98.24ms +step:889/1695 train_time:87342ms step_avg:98.25ms +step:890/1695 train_time:87442ms step_avg:98.25ms +step:891/1695 train_time:87541ms step_avg:98.25ms +step:892/1695 train_time:87641ms step_avg:98.25ms +step:893/1695 train_time:87740ms step_avg:98.25ms +step:894/1695 train_time:87839ms step_avg:98.25ms +step:895/1695 train_time:87938ms step_avg:98.25ms +step:896/1695 train_time:88038ms step_avg:98.26ms +step:897/1695 train_time:88137ms step_avg:98.26ms +step:898/1695 train_time:88238ms step_avg:98.26ms +step:899/1695 train_time:88340ms step_avg:98.26ms +step:900/1695 train_time:88441ms step_avg:98.27ms +step:901/1695 train_time:88540ms step_avg:98.27ms +step:902/1695 train_time:88640ms step_avg:98.27ms +step:903/1695 train_time:88740ms step_avg:98.27ms +step:904/1695 train_time:88839ms step_avg:98.27ms +step:905/1695 train_time:88939ms step_avg:98.27ms +step:906/1695 train_time:89039ms step_avg:98.28ms +step:907/1695 train_time:89139ms step_avg:98.28ms 
+step:908/1695 train_time:89240ms step_avg:98.28ms +step:909/1695 train_time:89340ms step_avg:98.28ms +step:910/1695 train_time:89440ms step_avg:98.29ms +step:911/1695 train_time:89540ms step_avg:98.29ms +step:912/1695 train_time:89640ms step_avg:98.29ms +step:913/1695 train_time:89739ms step_avg:98.29ms +step:914/1695 train_time:89839ms step_avg:98.29ms +step:915/1695 train_time:89938ms step_avg:98.29ms +step:916/1695 train_time:90038ms step_avg:98.29ms +step:917/1695 train_time:90138ms step_avg:98.30ms +step:918/1695 train_time:90239ms step_avg:98.30ms +step:919/1695 train_time:90340ms step_avg:98.30ms +step:920/1695 train_time:90441ms step_avg:98.31ms +step:921/1695 train_time:90541ms step_avg:98.31ms +step:922/1695 train_time:90641ms step_avg:98.31ms +step:923/1695 train_time:90741ms step_avg:98.31ms +step:924/1695 train_time:90841ms step_avg:98.31ms +step:925/1695 train_time:90940ms step_avg:98.31ms +step:926/1695 train_time:91039ms step_avg:98.31ms +step:927/1695 train_time:91139ms step_avg:98.32ms +step:928/1695 train_time:91238ms step_avg:98.32ms +step:929/1695 train_time:91339ms step_avg:98.32ms +step:930/1695 train_time:91441ms step_avg:98.32ms +step:931/1695 train_time:91541ms step_avg:98.33ms +step:932/1695 train_time:91641ms step_avg:98.33ms +step:933/1695 train_time:91741ms step_avg:98.33ms +step:934/1695 train_time:91840ms step_avg:98.33ms +step:935/1695 train_time:91939ms step_avg:98.33ms +step:936/1695 train_time:92039ms step_avg:98.33ms +step:937/1695 train_time:92140ms step_avg:98.33ms +step:938/1695 train_time:92240ms step_avg:98.34ms +step:939/1695 train_time:92340ms step_avg:98.34ms +step:940/1695 train_time:92440ms step_avg:98.34ms +step:941/1695 train_time:92541ms step_avg:98.34ms +step:942/1695 train_time:92640ms step_avg:98.34ms +step:943/1695 train_time:92741ms step_avg:98.35ms +step:944/1695 train_time:92840ms step_avg:98.35ms +step:945/1695 train_time:92941ms step_avg:98.35ms +step:946/1695 train_time:93040ms step_avg:98.35ms +step:947/1695 train_time:93139ms step_avg:98.35ms +step:948/1695 train_time:93239ms step_avg:98.35ms +step:949/1695 train_time:93339ms step_avg:98.36ms +step:950/1695 train_time:93440ms step_avg:98.36ms +step:951/1695 train_time:93540ms step_avg:98.36ms +step:952/1695 train_time:93641ms step_avg:98.36ms +step:953/1695 train_time:93740ms step_avg:98.36ms +step:954/1695 train_time:93840ms step_avg:98.36ms +step:955/1695 train_time:93940ms step_avg:98.37ms +step:956/1695 train_time:94040ms step_avg:98.37ms +step:957/1695 train_time:94140ms step_avg:98.37ms +step:958/1695 train_time:94239ms step_avg:98.37ms +step:959/1695 train_time:94339ms step_avg:98.37ms +step:960/1695 train_time:94440ms step_avg:98.37ms +step:961/1695 train_time:94540ms step_avg:98.38ms +step:962/1695 train_time:94640ms step_avg:98.38ms +step:963/1695 train_time:94740ms step_avg:98.38ms +step:964/1695 train_time:94841ms step_avg:98.38ms +step:965/1695 train_time:94941ms step_avg:98.38ms +step:966/1695 train_time:95040ms step_avg:98.39ms +step:967/1695 train_time:95140ms step_avg:98.39ms +step:968/1695 train_time:95240ms step_avg:98.39ms +step:969/1695 train_time:95340ms step_avg:98.39ms +step:970/1695 train_time:95439ms step_avg:98.39ms +step:971/1695 train_time:95540ms step_avg:98.39ms +step:972/1695 train_time:95640ms step_avg:98.40ms +step:973/1695 train_time:95741ms step_avg:98.40ms +step:974/1695 train_time:95840ms step_avg:98.40ms +step:975/1695 train_time:95940ms step_avg:98.40ms +step:976/1695 train_time:96040ms step_avg:98.40ms +step:977/1695 train_time:96140ms 
step_avg:98.40ms +step:978/1695 train_time:96239ms step_avg:98.40ms +step:979/1695 train_time:96339ms step_avg:98.41ms +step:980/1695 train_time:96440ms step_avg:98.41ms +step:981/1695 train_time:96539ms step_avg:98.41ms +step:982/1695 train_time:96640ms step_avg:98.41ms +step:983/1695 train_time:96741ms step_avg:98.41ms +step:984/1695 train_time:96841ms step_avg:98.42ms +step:985/1695 train_time:96941ms step_avg:98.42ms +step:986/1695 train_time:97041ms step_avg:98.42ms +step:987/1695 train_time:97142ms step_avg:98.42ms +step:988/1695 train_time:97241ms step_avg:98.42ms +step:989/1695 train_time:97340ms step_avg:98.42ms +step:990/1695 train_time:97440ms step_avg:98.42ms +step:991/1695 train_time:97540ms step_avg:98.43ms +step:992/1695 train_time:97640ms step_avg:98.43ms +step:993/1695 train_time:97740ms step_avg:98.43ms +step:994/1695 train_time:97840ms step_avg:98.43ms +step:995/1695 train_time:97940ms step_avg:98.43ms +step:996/1695 train_time:98040ms step_avg:98.43ms +step:997/1695 train_time:98140ms step_avg:98.44ms +step:998/1695 train_time:98239ms step_avg:98.44ms +step:999/1695 train_time:98339ms step_avg:98.44ms +step:1000/1695 train_time:98438ms step_avg:98.44ms +step:1000/1695 val_loss:3.4932 train_time:98537ms step_avg:98.54ms +step:1001/1695 train_time:98563ms step_avg:98.46ms +step:1002/1695 train_time:98647ms step_avg:98.45ms +step:1003/1695 train_time:98748ms step_avg:98.45ms +step:1004/1695 train_time:98847ms step_avg:98.45ms +step:1005/1695 train_time:98946ms step_avg:98.45ms +step:1006/1695 train_time:99045ms step_avg:98.45ms +step:1007/1695 train_time:99143ms step_avg:98.45ms +step:1008/1695 train_time:99242ms step_avg:98.45ms +step:1009/1695 train_time:99340ms step_avg:98.45ms +step:1010/1695 train_time:99439ms step_avg:98.45ms +step:1011/1695 train_time:99539ms step_avg:98.46ms +step:1012/1695 train_time:99640ms step_avg:98.46ms +step:1013/1695 train_time:99741ms step_avg:98.46ms +step:1014/1695 train_time:99842ms step_avg:98.46ms +step:1015/1695 train_time:99942ms step_avg:98.46ms +step:1016/1695 train_time:100041ms step_avg:98.47ms +step:1017/1695 train_time:100142ms step_avg:98.47ms +step:1018/1695 train_time:100241ms step_avg:98.47ms +step:1019/1695 train_time:100339ms step_avg:98.47ms +step:1020/1695 train_time:100439ms step_avg:98.47ms +step:1021/1695 train_time:100539ms step_avg:98.47ms +step:1022/1695 train_time:100639ms step_avg:98.47ms +step:1023/1695 train_time:100740ms step_avg:98.48ms +step:1024/1695 train_time:100843ms step_avg:98.48ms +step:1025/1695 train_time:100943ms step_avg:98.48ms +step:1026/1695 train_time:101043ms step_avg:98.48ms +step:1027/1695 train_time:101142ms step_avg:98.48ms +step:1028/1695 train_time:101241ms step_avg:98.48ms +step:1029/1695 train_time:101342ms step_avg:98.49ms +step:1030/1695 train_time:101440ms step_avg:98.49ms +step:1031/1695 train_time:101540ms step_avg:98.49ms +step:1032/1695 train_time:101640ms step_avg:98.49ms +step:1033/1695 train_time:101740ms step_avg:98.49ms +step:1034/1695 train_time:101840ms step_avg:98.49ms +step:1035/1695 train_time:101941ms step_avg:98.49ms +step:1036/1695 train_time:102041ms step_avg:98.50ms +step:1037/1695 train_time:102142ms step_avg:98.50ms +step:1038/1695 train_time:102241ms step_avg:98.50ms +step:1039/1695 train_time:102340ms step_avg:98.50ms +step:1040/1695 train_time:102439ms step_avg:98.50ms +step:1041/1695 train_time:102539ms step_avg:98.50ms +step:1042/1695 train_time:102639ms step_avg:98.50ms +step:1043/1695 train_time:102739ms step_avg:98.50ms +step:1044/1695 
train_time:102839ms step_avg:98.51ms +step:1045/1695 train_time:102940ms step_avg:98.51ms +step:1046/1695 train_time:103041ms step_avg:98.51ms +step:1047/1695 train_time:103141ms step_avg:98.51ms +step:1048/1695 train_time:103241ms step_avg:98.51ms +step:1049/1695 train_time:103340ms step_avg:98.51ms +step:1050/1695 train_time:103440ms step_avg:98.51ms +step:1051/1695 train_time:103541ms step_avg:98.52ms +step:1052/1695 train_time:103640ms step_avg:98.52ms +step:1053/1695 train_time:103740ms step_avg:98.52ms +step:1054/1695 train_time:103840ms step_avg:98.52ms +step:1055/1695 train_time:103941ms step_avg:98.52ms +step:1056/1695 train_time:104041ms step_avg:98.52ms +step:1057/1695 train_time:104141ms step_avg:98.53ms +step:1058/1695 train_time:104241ms step_avg:98.53ms +step:1059/1695 train_time:104339ms step_avg:98.53ms +step:1060/1695 train_time:104439ms step_avg:98.53ms +step:1061/1695 train_time:104538ms step_avg:98.53ms +step:1062/1695 train_time:104639ms step_avg:98.53ms +step:1063/1695 train_time:104739ms step_avg:98.53ms +step:1064/1695 train_time:104839ms step_avg:98.53ms +step:1065/1695 train_time:104940ms step_avg:98.54ms +step:1066/1695 train_time:105040ms step_avg:98.54ms +step:1067/1695 train_time:105140ms step_avg:98.54ms +step:1068/1695 train_time:105240ms step_avg:98.54ms +step:1069/1695 train_time:105340ms step_avg:98.54ms +step:1070/1695 train_time:105440ms step_avg:98.54ms +step:1071/1695 train_time:105539ms step_avg:98.54ms +step:1072/1695 train_time:105639ms step_avg:98.54ms +step:1073/1695 train_time:105739ms step_avg:98.54ms +step:1074/1695 train_time:105838ms step_avg:98.55ms +step:1075/1695 train_time:105939ms step_avg:98.55ms +step:1076/1695 train_time:106038ms step_avg:98.55ms +step:1077/1695 train_time:106140ms step_avg:98.55ms +step:1078/1695 train_time:106239ms step_avg:98.55ms +step:1079/1695 train_time:106339ms step_avg:98.55ms +step:1080/1695 train_time:106439ms step_avg:98.55ms +step:1081/1695 train_time:106538ms step_avg:98.56ms +step:1082/1695 train_time:106639ms step_avg:98.56ms +step:1083/1695 train_time:106739ms step_avg:98.56ms +step:1084/1695 train_time:106839ms step_avg:98.56ms +step:1085/1695 train_time:106939ms step_avg:98.56ms +step:1086/1695 train_time:107040ms step_avg:98.56ms +step:1087/1695 train_time:107140ms step_avg:98.56ms +step:1088/1695 train_time:107240ms step_avg:98.57ms +step:1089/1695 train_time:107339ms step_avg:98.57ms +step:1090/1695 train_time:107439ms step_avg:98.57ms +step:1091/1695 train_time:107539ms step_avg:98.57ms +step:1092/1695 train_time:107639ms step_avg:98.57ms +step:1093/1695 train_time:107739ms step_avg:98.57ms +step:1094/1695 train_time:107839ms step_avg:98.57ms +step:1095/1695 train_time:107939ms step_avg:98.57ms +step:1096/1695 train_time:108040ms step_avg:98.58ms +step:1097/1695 train_time:108140ms step_avg:98.58ms +step:1098/1695 train_time:108240ms step_avg:98.58ms +step:1099/1695 train_time:108339ms step_avg:98.58ms +step:1100/1695 train_time:108439ms step_avg:98.58ms +step:1101/1695 train_time:108538ms step_avg:98.58ms +step:1102/1695 train_time:108639ms step_avg:98.58ms +step:1103/1695 train_time:108739ms step_avg:98.58ms +step:1104/1695 train_time:108839ms step_avg:98.59ms +step:1105/1695 train_time:108938ms step_avg:98.59ms +step:1106/1695 train_time:109040ms step_avg:98.59ms +step:1107/1695 train_time:109140ms step_avg:98.59ms +step:1108/1695 train_time:109240ms step_avg:98.59ms +step:1109/1695 train_time:109340ms step_avg:98.59ms +step:1110/1695 train_time:109439ms step_avg:98.59ms +step:1111/1695 
train_time:109539ms step_avg:98.60ms +step:1112/1695 train_time:109640ms step_avg:98.60ms +step:1113/1695 train_time:109740ms step_avg:98.60ms +step:1114/1695 train_time:109840ms step_avg:98.60ms +step:1115/1695 train_time:109940ms step_avg:98.60ms +step:1116/1695 train_time:110040ms step_avg:98.60ms +step:1117/1695 train_time:110141ms step_avg:98.60ms +step:1118/1695 train_time:110240ms step_avg:98.60ms +step:1119/1695 train_time:110340ms step_avg:98.61ms +step:1120/1695 train_time:110440ms step_avg:98.61ms +step:1121/1695 train_time:110541ms step_avg:98.61ms +step:1122/1695 train_time:110640ms step_avg:98.61ms +step:1123/1695 train_time:110740ms step_avg:98.61ms +step:1124/1695 train_time:110839ms step_avg:98.61ms +step:1125/1695 train_time:110940ms step_avg:98.61ms +step:1125/1695 val_loss:3.4399 train_time:111037ms step_avg:98.70ms +step:1126/1695 train_time:111063ms step_avg:98.64ms +step:1127/1695 train_time:111148ms step_avg:98.62ms +step:1128/1695 train_time:111250ms step_avg:98.63ms +step:1129/1695 train_time:111351ms step_avg:98.63ms +step:1130/1695 train_time:111451ms step_avg:98.63ms +step:1131/1695 train_time:111550ms step_avg:98.63ms +step:1132/1695 train_time:111650ms step_avg:98.63ms +step:1133/1695 train_time:111750ms step_avg:98.63ms +step:1134/1695 train_time:111850ms step_avg:98.63ms +step:1135/1695 train_time:111949ms step_avg:98.63ms +step:1136/1695 train_time:112052ms step_avg:98.64ms +step:1137/1695 train_time:112157ms step_avg:98.64ms +step:1138/1695 train_time:112261ms step_avg:98.65ms +step:1139/1695 train_time:112360ms step_avg:98.65ms +step:1140/1695 train_time:112461ms step_avg:98.65ms +step:1141/1695 train_time:112560ms step_avg:98.65ms +step:1142/1695 train_time:112660ms step_avg:98.65ms +step:1143/1695 train_time:112759ms step_avg:98.65ms +step:1144/1695 train_time:112859ms step_avg:98.65ms +step:1145/1695 train_time:112961ms step_avg:98.66ms +step:1146/1695 train_time:113061ms step_avg:98.66ms +step:1147/1695 train_time:113162ms step_avg:98.66ms +step:1148/1695 train_time:113263ms step_avg:98.66ms +step:1149/1695 train_time:113363ms step_avg:98.66ms +step:1150/1695 train_time:113463ms step_avg:98.66ms +step:1151/1695 train_time:113563ms step_avg:98.67ms +step:1152/1695 train_time:113664ms step_avg:98.67ms +step:1153/1695 train_time:113765ms step_avg:98.67ms +step:1154/1695 train_time:113866ms step_avg:98.67ms +step:1155/1695 train_time:113966ms step_avg:98.67ms +step:1156/1695 train_time:114067ms step_avg:98.67ms +step:1157/1695 train_time:114169ms step_avg:98.68ms +step:1158/1695 train_time:114269ms step_avg:98.68ms +step:1159/1695 train_time:114370ms step_avg:98.68ms +step:1160/1695 train_time:114473ms step_avg:98.68ms +step:1161/1695 train_time:114576ms step_avg:98.69ms +step:1162/1695 train_time:114677ms step_avg:98.69ms +step:1163/1695 train_time:114780ms step_avg:98.69ms +step:1164/1695 train_time:114881ms step_avg:98.69ms +step:1165/1695 train_time:114981ms step_avg:98.70ms +step:1166/1695 train_time:115082ms step_avg:98.70ms +step:1167/1695 train_time:115182ms step_avg:98.70ms +step:1168/1695 train_time:115282ms step_avg:98.70ms +step:1169/1695 train_time:115383ms step_avg:98.70ms +step:1170/1695 train_time:115483ms step_avg:98.70ms +step:1171/1695 train_time:115583ms step_avg:98.70ms +step:1172/1695 train_time:115688ms step_avg:98.71ms +step:1173/1695 train_time:115789ms step_avg:98.71ms +step:1174/1695 train_time:115890ms step_avg:98.71ms +step:1175/1695 train_time:115991ms step_avg:98.72ms +step:1176/1695 train_time:116092ms step_avg:98.72ms 
+step:1177/1695 train_time:116193ms step_avg:98.72ms +step:1178/1695 train_time:116295ms step_avg:98.72ms +step:1179/1695 train_time:116399ms step_avg:98.73ms +step:1180/1695 train_time:116500ms step_avg:98.73ms +step:1181/1695 train_time:116600ms step_avg:98.73ms +step:1182/1695 train_time:116701ms step_avg:98.73ms +step:1183/1695 train_time:116801ms step_avg:98.73ms +step:1184/1695 train_time:116902ms step_avg:98.74ms +step:1185/1695 train_time:117003ms step_avg:98.74ms +step:1186/1695 train_time:117103ms step_avg:98.74ms +step:1187/1695 train_time:117203ms step_avg:98.74ms +step:1188/1695 train_time:117304ms step_avg:98.74ms +step:1189/1695 train_time:117405ms step_avg:98.74ms +step:1190/1695 train_time:117506ms step_avg:98.74ms +step:1191/1695 train_time:117608ms step_avg:98.75ms +step:1192/1695 train_time:117709ms step_avg:98.75ms +step:1193/1695 train_time:117809ms step_avg:98.75ms +step:1194/1695 train_time:117910ms step_avg:98.75ms +step:1195/1695 train_time:118011ms step_avg:98.75ms +step:1196/1695 train_time:118112ms step_avg:98.76ms +step:1197/1695 train_time:118215ms step_avg:98.76ms +step:1198/1695 train_time:118317ms step_avg:98.76ms +step:1199/1695 train_time:118418ms step_avg:98.76ms +step:1200/1695 train_time:118519ms step_avg:98.77ms +step:1201/1695 train_time:118619ms step_avg:98.77ms +step:1202/1695 train_time:118721ms step_avg:98.77ms +step:1203/1695 train_time:118822ms step_avg:98.77ms +step:1204/1695 train_time:118923ms step_avg:98.77ms +step:1205/1695 train_time:119022ms step_avg:98.77ms +step:1206/1695 train_time:119122ms step_avg:98.77ms +step:1207/1695 train_time:119222ms step_avg:98.78ms +step:1208/1695 train_time:119322ms step_avg:98.78ms +step:1209/1695 train_time:119422ms step_avg:98.78ms +step:1210/1695 train_time:119522ms step_avg:98.78ms +step:1211/1695 train_time:119623ms step_avg:98.78ms +step:1212/1695 train_time:119723ms step_avg:98.78ms +step:1213/1695 train_time:119825ms step_avg:98.78ms +step:1214/1695 train_time:119925ms step_avg:98.79ms +step:1215/1695 train_time:120026ms step_avg:98.79ms +step:1216/1695 train_time:120128ms step_avg:98.79ms +step:1217/1695 train_time:120229ms step_avg:98.79ms +step:1218/1695 train_time:120330ms step_avg:98.79ms +step:1219/1695 train_time:120431ms step_avg:98.79ms +step:1220/1695 train_time:120533ms step_avg:98.80ms +step:1221/1695 train_time:120636ms step_avg:98.80ms +step:1222/1695 train_time:120738ms step_avg:98.80ms +step:1223/1695 train_time:120839ms step_avg:98.81ms +step:1224/1695 train_time:120939ms step_avg:98.81ms +step:1225/1695 train_time:121040ms step_avg:98.81ms +step:1226/1695 train_time:121140ms step_avg:98.81ms +step:1227/1695 train_time:121241ms step_avg:98.81ms +step:1228/1695 train_time:121341ms step_avg:98.81ms +step:1229/1695 train_time:121440ms step_avg:98.81ms +step:1230/1695 train_time:121540ms step_avg:98.81ms +step:1231/1695 train_time:121641ms step_avg:98.81ms +step:1232/1695 train_time:121741ms step_avg:98.82ms +step:1233/1695 train_time:121843ms step_avg:98.82ms +step:1234/1695 train_time:121944ms step_avg:98.82ms +step:1235/1695 train_time:122043ms step_avg:98.82ms +step:1236/1695 train_time:122143ms step_avg:98.82ms +step:1237/1695 train_time:122243ms step_avg:98.82ms +step:1238/1695 train_time:122343ms step_avg:98.82ms +step:1239/1695 train_time:122443ms step_avg:98.82ms +step:1240/1695 train_time:122544ms step_avg:98.83ms +step:1241/1695 train_time:122645ms step_avg:98.83ms +step:1242/1695 train_time:122746ms step_avg:98.83ms +step:1243/1695 train_time:122849ms step_avg:98.83ms 
+step:1244/1695 train_time:122949ms step_avg:98.83ms +step:1245/1695 train_time:123050ms step_avg:98.84ms +step:1246/1695 train_time:123151ms step_avg:98.84ms +step:1247/1695 train_time:123252ms step_avg:98.84ms +step:1248/1695 train_time:123354ms step_avg:98.84ms +step:1249/1695 train_time:123455ms step_avg:98.84ms +step:1250/1695 train_time:123556ms step_avg:98.84ms +step:1250/1695 val_loss:3.3958 train_time:123656ms step_avg:98.92ms +step:1251/1695 train_time:123681ms step_avg:98.87ms +step:1252/1695 train_time:123768ms step_avg:98.86ms +step:1253/1695 train_time:123872ms step_avg:98.86ms +step:1254/1695 train_time:123972ms step_avg:98.86ms +step:1255/1695 train_time:124073ms step_avg:98.86ms +step:1256/1695 train_time:124173ms step_avg:98.86ms +step:1257/1695 train_time:124272ms step_avg:98.86ms +step:1258/1695 train_time:124372ms step_avg:98.86ms +step:1259/1695 train_time:124471ms step_avg:98.87ms +step:1260/1695 train_time:124571ms step_avg:98.87ms +step:1261/1695 train_time:124671ms step_avg:98.87ms +step:1262/1695 train_time:124773ms step_avg:98.87ms +step:1263/1695 train_time:124874ms step_avg:98.87ms +step:1264/1695 train_time:124974ms step_avg:98.87ms +step:1265/1695 train_time:125075ms step_avg:98.87ms +step:1266/1695 train_time:125175ms step_avg:98.87ms +step:1267/1695 train_time:125275ms step_avg:98.88ms +step:1268/1695 train_time:125375ms step_avg:98.88ms +step:1269/1695 train_time:125476ms step_avg:98.88ms +step:1270/1695 train_time:125577ms step_avg:98.88ms +step:1271/1695 train_time:125679ms step_avg:98.88ms +step:1272/1695 train_time:125780ms step_avg:98.88ms +step:1273/1695 train_time:125882ms step_avg:98.89ms +step:1274/1695 train_time:125983ms step_avg:98.89ms +step:1275/1695 train_time:126084ms step_avg:98.89ms +step:1276/1695 train_time:126188ms step_avg:98.89ms +step:1277/1695 train_time:126289ms step_avg:98.90ms +step:1278/1695 train_time:126390ms step_avg:98.90ms +step:1279/1695 train_time:126491ms step_avg:98.90ms +step:1280/1695 train_time:126591ms step_avg:98.90ms +step:1281/1695 train_time:126691ms step_avg:98.90ms +step:1282/1695 train_time:126791ms step_avg:98.90ms +step:1283/1695 train_time:126892ms step_avg:98.90ms +step:1284/1695 train_time:126992ms step_avg:98.90ms +step:1285/1695 train_time:127092ms step_avg:98.90ms +step:1286/1695 train_time:127192ms step_avg:98.91ms +step:1287/1695 train_time:127293ms step_avg:98.91ms +step:1288/1695 train_time:127393ms step_avg:98.91ms +step:1289/1695 train_time:127493ms step_avg:98.91ms +step:1290/1695 train_time:127593ms step_avg:98.91ms +step:1291/1695 train_time:127693ms step_avg:98.91ms +step:1292/1695 train_time:127794ms step_avg:98.91ms +step:1293/1695 train_time:127894ms step_avg:98.91ms +step:1294/1695 train_time:127995ms step_avg:98.91ms +step:1295/1695 train_time:128096ms step_avg:98.92ms +step:1296/1695 train_time:128196ms step_avg:98.92ms +step:1297/1695 train_time:128297ms step_avg:98.92ms +step:1298/1695 train_time:128397ms step_avg:98.92ms +step:1299/1695 train_time:128498ms step_avg:98.92ms +step:1300/1695 train_time:128598ms step_avg:98.92ms +step:1301/1695 train_time:128699ms step_avg:98.92ms +step:1302/1695 train_time:128802ms step_avg:98.93ms +step:1303/1695 train_time:128903ms step_avg:98.93ms +step:1304/1695 train_time:129005ms step_avg:98.93ms +step:1305/1695 train_time:129107ms step_avg:98.93ms +step:1306/1695 train_time:129208ms step_avg:98.93ms +step:1307/1695 train_time:129309ms step_avg:98.94ms +step:1308/1695 train_time:129410ms step_avg:98.94ms +step:1309/1695 train_time:129511ms 
step_avg:98.94ms +step:1310/1695 train_time:129613ms step_avg:98.94ms +step:1311/1695 train_time:129714ms step_avg:98.94ms +step:1312/1695 train_time:129814ms step_avg:98.94ms +step:1313/1695 train_time:129915ms step_avg:98.94ms +step:1314/1695 train_time:130015ms step_avg:98.95ms +step:1315/1695 train_time:130115ms step_avg:98.95ms +step:1316/1695 train_time:130217ms step_avg:98.95ms +step:1317/1695 train_time:130318ms step_avg:98.95ms +step:1318/1695 train_time:130418ms step_avg:98.95ms +step:1319/1695 train_time:130519ms step_avg:98.95ms +step:1320/1695 train_time:130621ms step_avg:98.96ms +step:1321/1695 train_time:130723ms step_avg:98.96ms +step:1322/1695 train_time:130824ms step_avg:98.96ms +step:1323/1695 train_time:130925ms step_avg:98.96ms +step:1324/1695 train_time:131027ms step_avg:98.96ms +step:1325/1695 train_time:131129ms step_avg:98.97ms +step:1326/1695 train_time:131232ms step_avg:98.97ms +step:1327/1695 train_time:131333ms step_avg:98.97ms +step:1328/1695 train_time:131433ms step_avg:98.97ms +step:1329/1695 train_time:131533ms step_avg:98.97ms +step:1330/1695 train_time:131634ms step_avg:98.97ms +step:1331/1695 train_time:131734ms step_avg:98.97ms +step:1332/1695 train_time:131834ms step_avg:98.97ms +step:1333/1695 train_time:131934ms step_avg:98.98ms +step:1334/1695 train_time:132037ms step_avg:98.98ms +step:1335/1695 train_time:132139ms step_avg:98.98ms +step:1336/1695 train_time:132241ms step_avg:98.98ms +step:1337/1695 train_time:132342ms step_avg:98.98ms +step:1338/1695 train_time:132443ms step_avg:98.99ms +step:1339/1695 train_time:132544ms step_avg:98.99ms +step:1340/1695 train_time:132645ms step_avg:98.99ms +step:1341/1695 train_time:132746ms step_avg:98.99ms +step:1342/1695 train_time:132848ms step_avg:98.99ms +step:1343/1695 train_time:132950ms step_avg:98.99ms +step:1344/1695 train_time:133050ms step_avg:99.00ms +step:1345/1695 train_time:133151ms step_avg:99.00ms +step:1346/1695 train_time:133252ms step_avg:99.00ms +step:1347/1695 train_time:133353ms step_avg:99.00ms +step:1348/1695 train_time:133452ms step_avg:99.00ms +step:1349/1695 train_time:133552ms step_avg:99.00ms +step:1350/1695 train_time:133653ms step_avg:99.00ms +step:1351/1695 train_time:133753ms step_avg:99.00ms +step:1352/1695 train_time:133853ms step_avg:99.00ms +step:1353/1695 train_time:133953ms step_avg:99.00ms +step:1354/1695 train_time:134054ms step_avg:99.01ms +step:1355/1695 train_time:134155ms step_avg:99.01ms +step:1356/1695 train_time:134255ms step_avg:99.01ms +step:1357/1695 train_time:134356ms step_avg:99.01ms +step:1358/1695 train_time:134457ms step_avg:99.01ms +step:1359/1695 train_time:134556ms step_avg:99.01ms +step:1360/1695 train_time:134657ms step_avg:99.01ms +step:1361/1695 train_time:134758ms step_avg:99.01ms +step:1362/1695 train_time:134860ms step_avg:99.02ms +step:1363/1695 train_time:134960ms step_avg:99.02ms +step:1364/1695 train_time:135062ms step_avg:99.02ms +step:1365/1695 train_time:135163ms step_avg:99.02ms +step:1366/1695 train_time:135265ms step_avg:99.02ms +step:1367/1695 train_time:135368ms step_avg:99.03ms +step:1368/1695 train_time:135470ms step_avg:99.03ms +step:1369/1695 train_time:135570ms step_avg:99.03ms +step:1370/1695 train_time:135670ms step_avg:99.03ms +step:1371/1695 train_time:135771ms step_avg:99.03ms +step:1372/1695 train_time:135872ms step_avg:99.03ms +step:1373/1695 train_time:135974ms step_avg:99.03ms +step:1374/1695 train_time:136074ms step_avg:99.03ms +step:1375/1695 train_time:136175ms step_avg:99.04ms +step:1375/1695 val_loss:3.3561 
train_time:136274ms step_avg:99.11ms +step:1376/1695 train_time:136299ms step_avg:99.05ms +step:1377/1695 train_time:136390ms step_avg:99.05ms +step:1378/1695 train_time:136492ms step_avg:99.05ms +step:1379/1695 train_time:136594ms step_avg:99.05ms +step:1380/1695 train_time:136696ms step_avg:99.06ms +step:1381/1695 train_time:136796ms step_avg:99.06ms +step:1382/1695 train_time:136897ms step_avg:99.06ms +step:1383/1695 train_time:136996ms step_avg:99.06ms +step:1384/1695 train_time:137097ms step_avg:99.06ms +step:1385/1695 train_time:137198ms step_avg:99.06ms +step:1386/1695 train_time:137302ms step_avg:99.06ms +step:1387/1695 train_time:137405ms step_avg:99.07ms +step:1388/1695 train_time:137506ms step_avg:99.07ms +step:1389/1695 train_time:137608ms step_avg:99.07ms +step:1390/1695 train_time:137709ms step_avg:99.07ms +step:1391/1695 train_time:137810ms step_avg:99.07ms +step:1392/1695 train_time:137911ms step_avg:99.07ms +step:1393/1695 train_time:138013ms step_avg:99.08ms +step:1394/1695 train_time:138114ms step_avg:99.08ms +step:1395/1695 train_time:138218ms step_avg:99.08ms +step:1396/1695 train_time:138320ms step_avg:99.08ms +step:1397/1695 train_time:138424ms step_avg:99.09ms +step:1398/1695 train_time:138526ms step_avg:99.09ms +step:1399/1695 train_time:138629ms step_avg:99.09ms +step:1400/1695 train_time:138731ms step_avg:99.09ms +step:1401/1695 train_time:138831ms step_avg:99.09ms +step:1402/1695 train_time:138933ms step_avg:99.10ms +step:1403/1695 train_time:139036ms step_avg:99.10ms +step:1404/1695 train_time:139138ms step_avg:99.10ms +step:1405/1695 train_time:139240ms step_avg:99.10ms +step:1406/1695 train_time:139342ms step_avg:99.11ms +step:1407/1695 train_time:139444ms step_avg:99.11ms +step:1408/1695 train_time:139545ms step_avg:99.11ms +step:1409/1695 train_time:139649ms step_avg:99.11ms +step:1410/1695 train_time:139749ms step_avg:99.11ms +step:1411/1695 train_time:139850ms step_avg:99.11ms +step:1412/1695 train_time:139953ms step_avg:99.12ms +step:1413/1695 train_time:140054ms step_avg:99.12ms +step:1414/1695 train_time:140157ms step_avg:99.12ms +step:1415/1695 train_time:140260ms step_avg:99.12ms +step:1416/1695 train_time:140360ms step_avg:99.12ms +step:1417/1695 train_time:140461ms step_avg:99.13ms +step:1418/1695 train_time:140563ms step_avg:99.13ms +step:1419/1695 train_time:140665ms step_avg:99.13ms +step:1420/1695 train_time:140766ms step_avg:99.13ms +step:1421/1695 train_time:140867ms step_avg:99.13ms +step:1422/1695 train_time:140968ms step_avg:99.13ms +step:1423/1695 train_time:141069ms step_avg:99.14ms +step:1424/1695 train_time:141172ms step_avg:99.14ms +step:1425/1695 train_time:141275ms step_avg:99.14ms +step:1426/1695 train_time:141379ms step_avg:99.14ms +step:1427/1695 train_time:141481ms step_avg:99.15ms +step:1428/1695 train_time:141582ms step_avg:99.15ms +step:1429/1695 train_time:141684ms step_avg:99.15ms +step:1430/1695 train_time:141785ms step_avg:99.15ms +step:1431/1695 train_time:141886ms step_avg:99.15ms +step:1432/1695 train_time:141986ms step_avg:99.15ms +step:1433/1695 train_time:142089ms step_avg:99.16ms +step:1434/1695 train_time:142192ms step_avg:99.16ms +step:1435/1695 train_time:142294ms step_avg:99.16ms +step:1436/1695 train_time:142399ms step_avg:99.16ms +step:1437/1695 train_time:142501ms step_avg:99.17ms +step:1438/1695 train_time:142602ms step_avg:99.17ms +step:1439/1695 train_time:142705ms step_avg:99.17ms +step:1440/1695 train_time:142807ms step_avg:99.17ms +step:1441/1695 train_time:142910ms step_avg:99.17ms +step:1442/1695 
train_time:143011ms step_avg:99.18ms +step:1443/1695 train_time:143112ms step_avg:99.18ms +step:1444/1695 train_time:143214ms step_avg:99.18ms +step:1445/1695 train_time:143315ms step_avg:99.18ms +step:1446/1695 train_time:143418ms step_avg:99.18ms +step:1447/1695 train_time:143520ms step_avg:99.18ms +step:1448/1695 train_time:143624ms step_avg:99.19ms +step:1449/1695 train_time:143725ms step_avg:99.19ms +step:1450/1695 train_time:143827ms step_avg:99.19ms +step:1451/1695 train_time:143928ms step_avg:99.19ms +step:1452/1695 train_time:144030ms step_avg:99.19ms +step:1453/1695 train_time:144132ms step_avg:99.20ms +step:1454/1695 train_time:144234ms step_avg:99.20ms +step:1455/1695 train_time:144337ms step_avg:99.20ms +step:1456/1695 train_time:144439ms step_avg:99.20ms +step:1457/1695 train_time:144541ms step_avg:99.20ms +step:1458/1695 train_time:144644ms step_avg:99.21ms +step:1459/1695 train_time:144746ms step_avg:99.21ms +step:1460/1695 train_time:144846ms step_avg:99.21ms +step:1461/1695 train_time:144949ms step_avg:99.21ms +step:1462/1695 train_time:145050ms step_avg:99.21ms +step:1463/1695 train_time:145152ms step_avg:99.21ms +step:1464/1695 train_time:145253ms step_avg:99.22ms +step:1465/1695 train_time:145355ms step_avg:99.22ms +step:1466/1695 train_time:145458ms step_avg:99.22ms +step:1467/1695 train_time:145560ms step_avg:99.22ms +step:1468/1695 train_time:145663ms step_avg:99.23ms +step:1469/1695 train_time:145765ms step_avg:99.23ms +step:1470/1695 train_time:145865ms step_avg:99.23ms +step:1471/1695 train_time:145967ms step_avg:99.23ms +step:1472/1695 train_time:146069ms step_avg:99.23ms +step:1473/1695 train_time:146170ms step_avg:99.23ms +step:1474/1695 train_time:146273ms step_avg:99.24ms +step:1475/1695 train_time:146376ms step_avg:99.24ms +step:1476/1695 train_time:146478ms step_avg:99.24ms +step:1477/1695 train_time:146580ms step_avg:99.24ms +step:1478/1695 train_time:146684ms step_avg:99.24ms +step:1479/1695 train_time:146785ms step_avg:99.25ms +step:1480/1695 train_time:146887ms step_avg:99.25ms +step:1481/1695 train_time:146989ms step_avg:99.25ms +step:1482/1695 train_time:147091ms step_avg:99.25ms +step:1483/1695 train_time:147193ms step_avg:99.25ms +step:1484/1695 train_time:147296ms step_avg:99.26ms +step:1485/1695 train_time:147398ms step_avg:99.26ms +step:1486/1695 train_time:147500ms step_avg:99.26ms +step:1487/1695 train_time:147601ms step_avg:99.26ms +step:1488/1695 train_time:147704ms step_avg:99.26ms +step:1489/1695 train_time:147806ms step_avg:99.27ms +step:1490/1695 train_time:147908ms step_avg:99.27ms +step:1491/1695 train_time:148009ms step_avg:99.27ms +step:1492/1695 train_time:148110ms step_avg:99.27ms +step:1493/1695 train_time:148213ms step_avg:99.27ms +step:1494/1695 train_time:148314ms step_avg:99.27ms +step:1495/1695 train_time:148417ms step_avg:99.28ms +step:1496/1695 train_time:148519ms step_avg:99.28ms +step:1497/1695 train_time:148621ms step_avg:99.28ms +step:1498/1695 train_time:148723ms step_avg:99.28ms +step:1499/1695 train_time:148824ms step_avg:99.28ms +step:1500/1695 train_time:148926ms step_avg:99.28ms +step:1500/1695 val_loss:3.3213 train_time:149024ms step_avg:99.35ms +step:1501/1695 train_time:149049ms step_avg:99.30ms +step:1502/1695 train_time:149137ms step_avg:99.29ms +step:1503/1695 train_time:149237ms step_avg:99.29ms +step:1504/1695 train_time:149338ms step_avg:99.29ms +step:1505/1695 train_time:149438ms step_avg:99.29ms +step:1506/1695 train_time:149539ms step_avg:99.30ms +step:1507/1695 train_time:149640ms step_avg:99.30ms 
+step:1508/1695 train_time:149740ms step_avg:99.30ms +step:1509/1695 train_time:149842ms step_avg:99.30ms +step:1510/1695 train_time:149943ms step_avg:99.30ms +step:1511/1695 train_time:150048ms step_avg:99.30ms +step:1512/1695 train_time:150152ms step_avg:99.31ms +step:1513/1695 train_time:150253ms step_avg:99.31ms +step:1514/1695 train_time:150355ms step_avg:99.31ms +step:1515/1695 train_time:150460ms step_avg:99.31ms +step:1516/1695 train_time:150562ms step_avg:99.32ms +step:1517/1695 train_time:150662ms step_avg:99.32ms +step:1518/1695 train_time:150764ms step_avg:99.32ms +step:1519/1695 train_time:150868ms step_avg:99.32ms +step:1520/1695 train_time:150969ms step_avg:99.32ms +step:1521/1695 train_time:151071ms step_avg:99.32ms +step:1522/1695 train_time:151173ms step_avg:99.33ms +step:1523/1695 train_time:151275ms step_avg:99.33ms +step:1524/1695 train_time:151379ms step_avg:99.33ms +step:1525/1695 train_time:151483ms step_avg:99.33ms +step:1526/1695 train_time:151586ms step_avg:99.34ms +step:1527/1695 train_time:151687ms step_avg:99.34ms +step:1528/1695 train_time:151794ms step_avg:99.34ms +step:1529/1695 train_time:151895ms step_avg:99.34ms +step:1530/1695 train_time:151997ms step_avg:99.34ms +step:1531/1695 train_time:152098ms step_avg:99.35ms +step:1532/1695 train_time:152200ms step_avg:99.35ms +step:1533/1695 train_time:152303ms step_avg:99.35ms +step:1534/1695 train_time:152406ms step_avg:99.35ms +step:1535/1695 train_time:152508ms step_avg:99.35ms +step:1536/1695 train_time:152610ms step_avg:99.36ms +step:1537/1695 train_time:152712ms step_avg:99.36ms +step:1538/1695 train_time:152814ms step_avg:99.36ms +step:1539/1695 train_time:152915ms step_avg:99.36ms +step:1540/1695 train_time:153017ms step_avg:99.36ms +step:1541/1695 train_time:153120ms step_avg:99.36ms +step:1542/1695 train_time:153224ms step_avg:99.37ms +step:1543/1695 train_time:153326ms step_avg:99.37ms +step:1544/1695 train_time:153428ms step_avg:99.37ms +step:1545/1695 train_time:153530ms step_avg:99.37ms +step:1546/1695 train_time:153633ms step_avg:99.37ms +step:1547/1695 train_time:153736ms step_avg:99.38ms +step:1548/1695 train_time:153837ms step_avg:99.38ms +step:1549/1695 train_time:153939ms step_avg:99.38ms +step:1550/1695 train_time:154040ms step_avg:99.38ms +step:1551/1695 train_time:154143ms step_avg:99.38ms +step:1552/1695 train_time:154244ms step_avg:99.38ms +step:1553/1695 train_time:154347ms step_avg:99.39ms +step:1554/1695 train_time:154449ms step_avg:99.39ms +step:1555/1695 train_time:154551ms step_avg:99.39ms +step:1556/1695 train_time:154654ms step_avg:99.39ms +step:1557/1695 train_time:154757ms step_avg:99.39ms +step:1558/1695 train_time:154859ms step_avg:99.40ms +step:1559/1695 train_time:154962ms step_avg:99.40ms +step:1560/1695 train_time:155063ms step_avg:99.40ms +step:1561/1695 train_time:155165ms step_avg:99.40ms +step:1562/1695 train_time:155268ms step_avg:99.40ms +step:1563/1695 train_time:155372ms step_avg:99.41ms +step:1564/1695 train_time:155473ms step_avg:99.41ms +step:1565/1695 train_time:155574ms step_avg:99.41ms +step:1566/1695 train_time:155675ms step_avg:99.41ms +step:1567/1695 train_time:155776ms step_avg:99.41ms +step:1568/1695 train_time:155876ms step_avg:99.41ms +step:1569/1695 train_time:155977ms step_avg:99.41ms +step:1570/1695 train_time:156080ms step_avg:99.41ms +step:1571/1695 train_time:156181ms step_avg:99.42ms +step:1572/1695 train_time:156282ms step_avg:99.42ms +step:1573/1695 train_time:156384ms step_avg:99.42ms +step:1574/1695 train_time:156486ms step_avg:99.42ms 
+step:1575/1695 train_time:156588ms step_avg:99.42ms +step:1576/1695 train_time:156691ms step_avg:99.42ms +step:1577/1695 train_time:156794ms step_avg:99.43ms +step:1578/1695 train_time:156895ms step_avg:99.43ms +step:1579/1695 train_time:156997ms step_avg:99.43ms +step:1580/1695 train_time:157100ms step_avg:99.43ms +step:1581/1695 train_time:157202ms step_avg:99.43ms +step:1582/1695 train_time:157303ms step_avg:99.43ms +step:1583/1695 train_time:157406ms step_avg:99.44ms +step:1584/1695 train_time:157509ms step_avg:99.44ms +step:1585/1695 train_time:157610ms step_avg:99.44ms +step:1586/1695 train_time:157713ms step_avg:99.44ms +step:1587/1695 train_time:157814ms step_avg:99.44ms +step:1588/1695 train_time:157916ms step_avg:99.44ms +step:1589/1695 train_time:158017ms step_avg:99.44ms +step:1590/1695 train_time:158118ms step_avg:99.45ms +step:1591/1695 train_time:158220ms step_avg:99.45ms +step:1592/1695 train_time:158322ms step_avg:99.45ms +step:1593/1695 train_time:158423ms step_avg:99.45ms +step:1594/1695 train_time:158527ms step_avg:99.45ms +step:1595/1695 train_time:158629ms step_avg:99.45ms +step:1596/1695 train_time:158731ms step_avg:99.46ms +step:1597/1695 train_time:158833ms step_avg:99.46ms +step:1598/1695 train_time:158936ms step_avg:99.46ms +step:1599/1695 train_time:159037ms step_avg:99.46ms +step:1600/1695 train_time:159139ms step_avg:99.46ms +step:1601/1695 train_time:159241ms step_avg:99.46ms +step:1602/1695 train_time:159343ms step_avg:99.46ms +step:1603/1695 train_time:159444ms step_avg:99.47ms +step:1604/1695 train_time:159545ms step_avg:99.47ms +step:1605/1695 train_time:159647ms step_avg:99.47ms +step:1606/1695 train_time:159750ms step_avg:99.47ms +step:1607/1695 train_time:159852ms step_avg:99.47ms +step:1608/1695 train_time:159955ms step_avg:99.47ms +step:1609/1695 train_time:160056ms step_avg:99.48ms +step:1610/1695 train_time:160158ms step_avg:99.48ms +step:1611/1695 train_time:160260ms step_avg:99.48ms +step:1612/1695 train_time:160362ms step_avg:99.48ms +step:1613/1695 train_time:160463ms step_avg:99.48ms +step:1614/1695 train_time:160564ms step_avg:99.48ms +step:1615/1695 train_time:160667ms step_avg:99.48ms +step:1616/1695 train_time:160767ms step_avg:99.48ms +step:1617/1695 train_time:160871ms step_avg:99.49ms +step:1618/1695 train_time:160974ms step_avg:99.49ms +step:1619/1695 train_time:161076ms step_avg:99.49ms +step:1620/1695 train_time:161178ms step_avg:99.49ms +step:1621/1695 train_time:161280ms step_avg:99.49ms +step:1622/1695 train_time:161382ms step_avg:99.50ms +step:1623/1695 train_time:161483ms step_avg:99.50ms +step:1624/1695 train_time:161584ms step_avg:99.50ms +step:1625/1695 train_time:161687ms step_avg:99.50ms +step:1625/1695 val_loss:3.2926 train_time:161788ms step_avg:99.56ms +step:1626/1695 train_time:161814ms step_avg:99.52ms +step:1627/1695 train_time:161900ms step_avg:99.51ms +step:1628/1695 train_time:162004ms step_avg:99.51ms +step:1629/1695 train_time:162107ms step_avg:99.51ms +step:1630/1695 train_time:162209ms step_avg:99.51ms +step:1631/1695 train_time:162311ms step_avg:99.52ms +step:1632/1695 train_time:162412ms step_avg:99.52ms +step:1633/1695 train_time:162512ms step_avg:99.52ms +step:1634/1695 train_time:162615ms step_avg:99.52ms +step:1635/1695 train_time:162717ms step_avg:99.52ms +step:1636/1695 train_time:162820ms step_avg:99.52ms +step:1637/1695 train_time:162924ms step_avg:99.53ms +step:1638/1695 train_time:163027ms step_avg:99.53ms +step:1639/1695 train_time:163130ms step_avg:99.53ms +step:1640/1695 train_time:163232ms 
step_avg:99.53ms +step:1641/1695 train_time:163335ms step_avg:99.53ms +step:1642/1695 train_time:163437ms step_avg:99.54ms +step:1643/1695 train_time:163538ms step_avg:99.54ms +step:1644/1695 train_time:163640ms step_avg:99.54ms +step:1645/1695 train_time:163744ms step_avg:99.54ms +step:1646/1695 train_time:163847ms step_avg:99.54ms +step:1647/1695 train_time:163952ms step_avg:99.55ms +step:1648/1695 train_time:164055ms step_avg:99.55ms +step:1649/1695 train_time:164158ms step_avg:99.55ms +step:1650/1695 train_time:164262ms step_avg:99.55ms +step:1651/1695 train_time:164364ms step_avg:99.55ms +step:1652/1695 train_time:164468ms step_avg:99.56ms +step:1653/1695 train_time:164572ms step_avg:99.56ms +step:1654/1695 train_time:164673ms step_avg:99.56ms +step:1655/1695 train_time:164776ms step_avg:99.56ms +step:1656/1695 train_time:164879ms step_avg:99.56ms +step:1657/1695 train_time:164981ms step_avg:99.57ms +step:1658/1695 train_time:165084ms step_avg:99.57ms +step:1659/1695 train_time:165190ms step_avg:99.57ms +step:1660/1695 train_time:165293ms step_avg:99.57ms +step:1661/1695 train_time:165396ms step_avg:99.58ms +step:1662/1695 train_time:165501ms step_avg:99.58ms +step:1663/1695 train_time:165603ms step_avg:99.58ms +step:1664/1695 train_time:165707ms step_avg:99.58ms +step:1665/1695 train_time:165812ms step_avg:99.59ms +step:1666/1695 train_time:165915ms step_avg:99.59ms +step:1667/1695 train_time:166016ms step_avg:99.59ms +step:1668/1695 train_time:166121ms step_avg:99.59ms +step:1669/1695 train_time:166227ms step_avg:99.60ms +step:1670/1695 train_time:166329ms step_avg:99.60ms +step:1671/1695 train_time:166432ms step_avg:99.60ms +step:1672/1695 train_time:166536ms step_avg:99.60ms +step:1673/1695 train_time:166637ms step_avg:99.60ms +step:1674/1695 train_time:166739ms step_avg:99.60ms +step:1675/1695 train_time:166841ms step_avg:99.61ms +step:1676/1695 train_time:166946ms step_avg:99.61ms +step:1677/1695 train_time:167049ms step_avg:99.61ms +step:1678/1695 train_time:167153ms step_avg:99.61ms +step:1679/1695 train_time:167256ms step_avg:99.62ms +step:1680/1695 train_time:167357ms step_avg:99.62ms +step:1681/1695 train_time:167460ms step_avg:99.62ms +step:1682/1695 train_time:167567ms step_avg:99.62ms +step:1683/1695 train_time:167670ms step_avg:99.63ms +step:1684/1695 train_time:167773ms step_avg:99.63ms +step:1685/1695 train_time:167876ms step_avg:99.63ms +step:1686/1695 train_time:167978ms step_avg:99.63ms +step:1687/1695 train_time:168080ms step_avg:99.63ms +step:1688/1695 train_time:168183ms step_avg:99.63ms +step:1689/1695 train_time:168284ms step_avg:99.64ms +step:1690/1695 train_time:168388ms step_avg:99.64ms +step:1691/1695 train_time:168491ms step_avg:99.64ms +step:1692/1695 train_time:168593ms step_avg:99.64ms +step:1693/1695 train_time:168696ms step_avg:99.64ms +step:1694/1695 train_time:168800ms step_avg:99.65ms +step:1695/1695 train_time:168903ms step_avg:99.65ms +step:1695/1695 val_loss:3.2796 train_time:169003ms step_avg:99.71ms +peak memory allocated: 34004 MiB reserved: 49720 MiB diff --git a/records/082325_SparseAttnGate/6df384bb-9c24-46b3-826b-f7c07168c27a.txt b/records/082325_SparseAttnGate/6df384bb-9c24-46b3-826b-f7c07168c27a.txt new file mode 100644 index 000000000..a2ba6ed9d --- /dev/null +++ b/records/082325_SparseAttnGate/6df384bb-9c24-46b3-826b-f7c07168c27a.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
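+        # advance both the A tile and the A.T tile one BLOCK_SIZE_K step along K;
+        # the same column stride applies to both since they read the same matrix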
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
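+
+    A minimal usage sketch (an illustration only, assuming a torch.distributed
+    process group is already initialized; it mirrors how this script constructs
+    the optimizer further below):
+
+        hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+        optimizer = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)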
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+        # sparse attention gate: a per-token, per-head sigmoid gate computed from the
+        # first dim//dampen_factor channels of the input; the zero-initialized weight
+        # makes every gate start at sigmoid(0) = 0.5
+        self.dampen = CastedLinear(dim//args.dampen_factor, num_heads)
+        self.dampen.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask):
+        B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim
+        assert B == 1, "Must use batch size = 1 for FlexAttention"
+        dampen_gate = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) # per-token, per-head gate in (0, 1)
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2)
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * dampen_gate # apply the attention gate before the output projection
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
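+        # lr_mul is a per-parameter multiplier: both DistAdam and Muon scale their
+        # group learning rate by getattr(p, "lr_mul", 1.0), so the embedding tables
+        # above train at 75x the Adam group's base lr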
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following the Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following the Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length 
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
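+        # val_loss is already a local mean over val_steps; the all_reduce above
+        # averages it across ranks, so rank 0 logs the global validation loss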
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:04:05 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 27C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 26C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 293658 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 293659 C /usr/bin/python3 614MiB | +| 0 N/A N/A 293660 C /usr/bin/python3 614MiB | +| 0 N/A N/A 293661 C /usr/bin/python3 614MiB | +| 0 N/A N/A 293662 C /usr/bin/python3 614MiB | +| 0 N/A N/A 293663 C /usr/bin/python3 614MiB | +| 0 N/A N/A 293664 C /usr/bin/python3 614MiB | +| 0 N/A N/A 293665 C /usr/bin/python3 614MiB | +| 1 N/A N/A 293659 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 293660 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 293661 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 293662 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 293663 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 293664 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 293665 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1695 train_time:148ms step_avg:148.13ms +step:2/1695 train_time:174ms step_avg:86.81ms +step:3/1695 train_time:245ms step_avg:81.60ms +step:4/1695 train_time:336ms step_avg:84.10ms +step:5/1695 train_time:429ms step_avg:85.81ms +step:6/1695 train_time:522ms step_avg:86.96ms +step:7/1695 train_time:615ms step_avg:87.92ms +step:8/1695 
train_time:708ms step_avg:88.55ms +step:9/1695 train_time:801ms step_avg:89.03ms +step:10/1695 train_time:894ms step_avg:89.41ms +step:11/1695 train_time:988ms step_avg:89.80ms +step:12/1695 train_time:1081ms step_avg:90.12ms +step:13/1695 train_time:1177ms step_avg:90.54ms +step:14/1695 train_time:1271ms step_avg:90.81ms +step:15/1695 train_time:1365ms step_avg:91.01ms +step:16/1695 train_time:1459ms step_avg:91.18ms +step:17/1695 train_time:1552ms step_avg:91.28ms +step:18/1695 train_time:1645ms step_avg:91.37ms +step:19/1695 train_time:1738ms step_avg:91.46ms +step:20/1695 train_time:1831ms step_avg:91.54ms +step:21/1695 train_time:1924ms step_avg:91.63ms +step:22/1695 train_time:2017ms step_avg:91.70ms +step:23/1695 train_time:2111ms step_avg:91.80ms +step:24/1695 train_time:2207ms step_avg:91.94ms +step:25/1695 train_time:2301ms step_avg:92.06ms +step:26/1695 train_time:2396ms step_avg:92.15ms +step:27/1695 train_time:2489ms step_avg:92.20ms +step:28/1695 train_time:2583ms step_avg:92.26ms +step:29/1695 train_time:2677ms step_avg:92.32ms +step:30/1695 train_time:2770ms step_avg:92.33ms +step:31/1695 train_time:2863ms step_avg:92.36ms +step:32/1695 train_time:2956ms step_avg:92.38ms +step:33/1695 train_time:3050ms step_avg:92.42ms +step:34/1695 train_time:3144ms step_avg:92.47ms +step:35/1695 train_time:3238ms step_avg:92.53ms +step:36/1695 train_time:3332ms step_avg:92.56ms +step:37/1695 train_time:3427ms step_avg:92.61ms +step:38/1695 train_time:3520ms step_avg:92.64ms +step:39/1695 train_time:3613ms step_avg:92.65ms +step:40/1695 train_time:3706ms step_avg:92.66ms +step:41/1695 train_time:3800ms step_avg:92.67ms +step:42/1695 train_time:3893ms step_avg:92.69ms +step:43/1695 train_time:3986ms step_avg:92.69ms +step:44/1695 train_time:4080ms step_avg:92.72ms +step:45/1695 train_time:4173ms step_avg:92.72ms +step:46/1695 train_time:4268ms step_avg:92.77ms +step:47/1695 train_time:4361ms step_avg:92.79ms +step:48/1695 train_time:4455ms step_avg:92.81ms +step:49/1695 train_time:4549ms step_avg:92.83ms +step:50/1695 train_time:4643ms step_avg:92.86ms +step:51/1695 train_time:4737ms step_avg:92.89ms +step:52/1695 train_time:4830ms step_avg:92.89ms +step:53/1695 train_time:4924ms step_avg:92.91ms +step:54/1695 train_time:5017ms step_avg:92.91ms +step:55/1695 train_time:5110ms step_avg:92.91ms +step:56/1695 train_time:5204ms step_avg:92.93ms +step:57/1695 train_time:5297ms step_avg:92.93ms +step:58/1695 train_time:5390ms step_avg:92.94ms +step:59/1695 train_time:5485ms step_avg:92.96ms +step:60/1695 train_time:5578ms step_avg:92.97ms +step:61/1695 train_time:5672ms step_avg:92.98ms +step:62/1695 train_time:5766ms step_avg:93.00ms +step:63/1695 train_time:5859ms step_avg:93.00ms +step:64/1695 train_time:5952ms step_avg:93.01ms +step:65/1695 train_time:6046ms step_avg:93.02ms +step:66/1695 train_time:6140ms step_avg:93.04ms +step:67/1695 train_time:6234ms step_avg:93.04ms +step:68/1695 train_time:6328ms step_avg:93.05ms +step:69/1695 train_time:6421ms step_avg:93.06ms +step:70/1695 train_time:6514ms step_avg:93.06ms +step:71/1695 train_time:6608ms step_avg:93.07ms +step:72/1695 train_time:6702ms step_avg:93.09ms +step:73/1695 train_time:6797ms step_avg:93.11ms +step:74/1695 train_time:6889ms step_avg:93.10ms +step:75/1695 train_time:6982ms step_avg:93.10ms +step:76/1695 train_time:7077ms step_avg:93.11ms +step:77/1695 train_time:7170ms step_avg:93.11ms +step:78/1695 train_time:7264ms step_avg:93.12ms +step:79/1695 train_time:7356ms step_avg:93.12ms +step:80/1695 train_time:7450ms 
step_avg:93.12ms
+[... steps 81-1507/1695: per-step train_time/step_avg log entries elided; step_avg rises steadily from 93.13ms to 98.05ms over this span. Validation checkpoints retained below ...]
+step:125/1695 val_loss:4.6123 train_time:11750ms step_avg:94.00ms
+step:250/1695 val_loss:4.0789 train_time:23511ms step_avg:94.04ms
+step:375/1695 val_loss:3.8806 train_time:35339ms step_avg:94.24ms
+step:500/1695 val_loss:3.7364 train_time:47368ms step_avg:94.74ms
+step:625/1695 val_loss:3.6465 train_time:59477ms step_avg:95.16ms
+step:750/1695 val_loss:3.5852 train_time:71745ms step_avg:95.66ms
+step:875/1695 val_loss:3.5356 train_time:84068ms step_avg:96.08ms
+step:1000/1695 val_loss:3.4915 train_time:96581ms step_avg:96.58ms
+step:1125/1695 val_loss:3.4400 train_time:109096ms step_avg:96.97ms
+step:1250/1695 val_loss:3.3953 train_time:121731ms step_avg:97.38ms
+step:1375/1695 val_loss:3.3553 train_time:134391ms step_avg:97.74ms
+step:1500/1695 val_loss:3.3201 train_time:147143ms step_avg:98.10ms
+step:1508/1695
train_time:147865ms step_avg:98.05ms +step:1509/1695 train_time:147968ms step_avg:98.06ms +step:1510/1695 train_time:148070ms step_avg:98.06ms +step:1511/1695 train_time:148173ms step_avg:98.06ms +step:1512/1695 train_time:148275ms step_avg:98.07ms +step:1513/1695 train_time:148376ms step_avg:98.07ms +step:1514/1695 train_time:148479ms step_avg:98.07ms +step:1515/1695 train_time:148583ms step_avg:98.07ms +step:1516/1695 train_time:148685ms step_avg:98.08ms +step:1517/1695 train_time:148785ms step_avg:98.08ms +step:1518/1695 train_time:148887ms step_avg:98.08ms +step:1519/1695 train_time:148990ms step_avg:98.08ms +step:1520/1695 train_time:149091ms step_avg:98.09ms +step:1521/1695 train_time:149192ms step_avg:98.09ms +step:1522/1695 train_time:149294ms step_avg:98.09ms +step:1523/1695 train_time:149398ms step_avg:98.09ms +step:1524/1695 train_time:149503ms step_avg:98.10ms +step:1525/1695 train_time:149606ms step_avg:98.10ms +step:1526/1695 train_time:149708ms step_avg:98.10ms +step:1527/1695 train_time:149809ms step_avg:98.11ms +step:1528/1695 train_time:149915ms step_avg:98.11ms +step:1529/1695 train_time:150016ms step_avg:98.11ms +step:1530/1695 train_time:150118ms step_avg:98.12ms +step:1531/1695 train_time:150220ms step_avg:98.12ms +step:1532/1695 train_time:150322ms step_avg:98.12ms +step:1533/1695 train_time:150424ms step_avg:98.12ms +step:1534/1695 train_time:150526ms step_avg:98.13ms +step:1535/1695 train_time:150628ms step_avg:98.13ms +step:1536/1695 train_time:150731ms step_avg:98.13ms +step:1537/1695 train_time:150833ms step_avg:98.13ms +step:1538/1695 train_time:150934ms step_avg:98.14ms +step:1539/1695 train_time:151035ms step_avg:98.14ms +step:1540/1695 train_time:151136ms step_avg:98.14ms +step:1541/1695 train_time:151239ms step_avg:98.14ms +step:1542/1695 train_time:151344ms step_avg:98.15ms +step:1543/1695 train_time:151446ms step_avg:98.15ms +step:1544/1695 train_time:151548ms step_avg:98.15ms +step:1545/1695 train_time:151649ms step_avg:98.15ms +step:1546/1695 train_time:151751ms step_avg:98.16ms +step:1547/1695 train_time:151855ms step_avg:98.16ms +step:1548/1695 train_time:151956ms step_avg:98.16ms +step:1549/1695 train_time:152058ms step_avg:98.17ms +step:1550/1695 train_time:152160ms step_avg:98.17ms +step:1551/1695 train_time:152262ms step_avg:98.17ms +step:1552/1695 train_time:152364ms step_avg:98.17ms +step:1553/1695 train_time:152467ms step_avg:98.18ms +step:1554/1695 train_time:152569ms step_avg:98.18ms +step:1555/1695 train_time:152670ms step_avg:98.18ms +step:1556/1695 train_time:152773ms step_avg:98.18ms +step:1557/1695 train_time:152876ms step_avg:98.19ms +step:1558/1695 train_time:152978ms step_avg:98.19ms +step:1559/1695 train_time:153081ms step_avg:98.19ms +step:1560/1695 train_time:153181ms step_avg:98.19ms +step:1561/1695 train_time:153283ms step_avg:98.20ms +step:1562/1695 train_time:153385ms step_avg:98.20ms +step:1563/1695 train_time:153488ms step_avg:98.20ms +step:1564/1695 train_time:153590ms step_avg:98.20ms +step:1565/1695 train_time:153691ms step_avg:98.21ms +step:1566/1695 train_time:153793ms step_avg:98.21ms +step:1567/1695 train_time:153894ms step_avg:98.21ms +step:1568/1695 train_time:153995ms step_avg:98.21ms +step:1569/1695 train_time:154096ms step_avg:98.21ms +step:1570/1695 train_time:154198ms step_avg:98.22ms +step:1571/1695 train_time:154299ms step_avg:98.22ms +step:1572/1695 train_time:154401ms step_avg:98.22ms +step:1573/1695 train_time:154504ms step_avg:98.22ms +step:1574/1695 train_time:154606ms step_avg:98.22ms +step:1575/1695 
train_time:154708ms step_avg:98.23ms +step:1576/1695 train_time:154810ms step_avg:98.23ms +step:1577/1695 train_time:154914ms step_avg:98.23ms +step:1578/1695 train_time:155015ms step_avg:98.23ms +step:1579/1695 train_time:155116ms step_avg:98.24ms +step:1580/1695 train_time:155218ms step_avg:98.24ms +step:1581/1695 train_time:155319ms step_avg:98.24ms +step:1582/1695 train_time:155421ms step_avg:98.24ms +step:1583/1695 train_time:155524ms step_avg:98.25ms +step:1584/1695 train_time:155626ms step_avg:98.25ms +step:1585/1695 train_time:155728ms step_avg:98.25ms +step:1586/1695 train_time:155831ms step_avg:98.25ms +step:1587/1695 train_time:155933ms step_avg:98.26ms +step:1588/1695 train_time:156034ms step_avg:98.26ms +step:1589/1695 train_time:156135ms step_avg:98.26ms +step:1590/1695 train_time:156237ms step_avg:98.26ms +step:1591/1695 train_time:156340ms step_avg:98.26ms +step:1592/1695 train_time:156442ms step_avg:98.27ms +step:1593/1695 train_time:156543ms step_avg:98.27ms +step:1594/1695 train_time:156647ms step_avg:98.27ms +step:1595/1695 train_time:156748ms step_avg:98.27ms +step:1596/1695 train_time:156850ms step_avg:98.28ms +step:1597/1695 train_time:156952ms step_avg:98.28ms +step:1598/1695 train_time:157055ms step_avg:98.28ms +step:1599/1695 train_time:157156ms step_avg:98.28ms +step:1600/1695 train_time:157257ms step_avg:98.29ms +step:1601/1695 train_time:157360ms step_avg:98.29ms +step:1602/1695 train_time:157461ms step_avg:98.29ms +step:1603/1695 train_time:157563ms step_avg:98.29ms +step:1604/1695 train_time:157664ms step_avg:98.29ms +step:1605/1695 train_time:157767ms step_avg:98.30ms +step:1606/1695 train_time:157870ms step_avg:98.30ms +step:1607/1695 train_time:157972ms step_avg:98.30ms +step:1608/1695 train_time:158074ms step_avg:98.30ms +step:1609/1695 train_time:158175ms step_avg:98.31ms +step:1610/1695 train_time:158277ms step_avg:98.31ms +step:1611/1695 train_time:158380ms step_avg:98.31ms +step:1612/1695 train_time:158481ms step_avg:98.31ms +step:1613/1695 train_time:158582ms step_avg:98.32ms +step:1614/1695 train_time:158683ms step_avg:98.32ms +step:1615/1695 train_time:158784ms step_avg:98.32ms +step:1616/1695 train_time:158887ms step_avg:98.32ms +step:1617/1695 train_time:158990ms step_avg:98.32ms +step:1618/1695 train_time:159093ms step_avg:98.33ms +step:1619/1695 train_time:159195ms step_avg:98.33ms +step:1620/1695 train_time:159297ms step_avg:98.33ms +step:1621/1695 train_time:159399ms step_avg:98.33ms +step:1622/1695 train_time:159499ms step_avg:98.34ms +step:1623/1695 train_time:159600ms step_avg:98.34ms +step:1624/1695 train_time:159702ms step_avg:98.34ms +step:1625/1695 train_time:159806ms step_avg:98.34ms +step:1625/1695 val_loss:3.2915 train_time:159906ms step_avg:98.40ms +step:1626/1695 train_time:159932ms step_avg:98.36ms +step:1627/1695 train_time:160019ms step_avg:98.35ms +step:1628/1695 train_time:160122ms step_avg:98.36ms +step:1629/1695 train_time:160225ms step_avg:98.36ms +step:1630/1695 train_time:160326ms step_avg:98.36ms +step:1631/1695 train_time:160428ms step_avg:98.36ms +step:1632/1695 train_time:160529ms step_avg:98.36ms +step:1633/1695 train_time:160629ms step_avg:98.36ms +step:1634/1695 train_time:160731ms step_avg:98.37ms +step:1635/1695 train_time:160834ms step_avg:98.37ms +step:1636/1695 train_time:160938ms step_avg:98.37ms +step:1637/1695 train_time:161041ms step_avg:98.38ms +step:1638/1695 train_time:161143ms step_avg:98.38ms +step:1639/1695 train_time:161246ms step_avg:98.38ms +step:1640/1695 train_time:161349ms step_avg:98.38ms 
+step:1641/1695 train_time:161452ms step_avg:98.39ms +step:1642/1695 train_time:161554ms step_avg:98.39ms +step:1643/1695 train_time:161656ms step_avg:98.39ms +step:1644/1695 train_time:161759ms step_avg:98.39ms +step:1645/1695 train_time:161861ms step_avg:98.40ms +step:1646/1695 train_time:161965ms step_avg:98.40ms +step:1647/1695 train_time:162069ms step_avg:98.40ms +step:1648/1695 train_time:162173ms step_avg:98.41ms +step:1649/1695 train_time:162276ms step_avg:98.41ms +step:1650/1695 train_time:162379ms step_avg:98.41ms +step:1651/1695 train_time:162482ms step_avg:98.41ms +step:1652/1695 train_time:162585ms step_avg:98.42ms +step:1653/1695 train_time:162688ms step_avg:98.42ms +step:1654/1695 train_time:162790ms step_avg:98.42ms +step:1655/1695 train_time:162893ms step_avg:98.42ms +step:1656/1695 train_time:162998ms step_avg:98.43ms +step:1657/1695 train_time:163100ms step_avg:98.43ms +step:1658/1695 train_time:163203ms step_avg:98.43ms +step:1659/1695 train_time:163308ms step_avg:98.44ms +step:1660/1695 train_time:163411ms step_avg:98.44ms +step:1661/1695 train_time:163516ms step_avg:98.44ms +step:1662/1695 train_time:163620ms step_avg:98.45ms +step:1663/1695 train_time:163723ms step_avg:98.45ms +step:1664/1695 train_time:163825ms step_avg:98.45ms +step:1665/1695 train_time:163931ms step_avg:98.46ms +step:1666/1695 train_time:164034ms step_avg:98.46ms +step:1667/1695 train_time:164136ms step_avg:98.46ms +step:1668/1695 train_time:164242ms step_avg:98.47ms +step:1669/1695 train_time:164346ms step_avg:98.47ms +step:1670/1695 train_time:164448ms step_avg:98.47ms +step:1671/1695 train_time:164550ms step_avg:98.47ms +step:1672/1695 train_time:164654ms step_avg:98.48ms +step:1673/1695 train_time:164757ms step_avg:98.48ms +step:1674/1695 train_time:164860ms step_avg:98.48ms +step:1675/1695 train_time:164963ms step_avg:98.49ms +step:1676/1695 train_time:165067ms step_avg:98.49ms +step:1677/1695 train_time:165168ms step_avg:98.49ms +step:1678/1695 train_time:165272ms step_avg:98.49ms +step:1679/1695 train_time:165377ms step_avg:98.50ms +step:1680/1695 train_time:165480ms step_avg:98.50ms +step:1681/1695 train_time:165584ms step_avg:98.50ms +step:1682/1695 train_time:165690ms step_avg:98.51ms +step:1683/1695 train_time:165791ms step_avg:98.51ms +step:1684/1695 train_time:165894ms step_avg:98.51ms +step:1685/1695 train_time:165997ms step_avg:98.51ms +step:1686/1695 train_time:166100ms step_avg:98.52ms +step:1687/1695 train_time:166203ms step_avg:98.52ms +step:1688/1695 train_time:166305ms step_avg:98.52ms +step:1689/1695 train_time:166406ms step_avg:98.52ms +step:1690/1695 train_time:166508ms step_avg:98.53ms +step:1691/1695 train_time:166611ms step_avg:98.53ms +step:1692/1695 train_time:166713ms step_avg:98.53ms +step:1693/1695 train_time:166816ms step_avg:98.53ms +step:1694/1695 train_time:166920ms step_avg:98.54ms +step:1695/1695 train_time:167023ms step_avg:98.54ms +step:1695/1695 val_loss:3.2786 train_time:167122ms step_avg:98.60ms +peak memory allocated: 34004 MiB reserved: 49180 MiB diff --git a/records/082325_SparseAttnGate/README.md b/records/082325_SparseAttnGate/README.md new file mode 100644 index 000000000..584f392d1 --- /dev/null +++ b/records/082325_SparseAttnGate/README.md @@ -0,0 +1,45 @@ +## New record 08/23/25 + +1. Included WR improvements on Triton and grad batching from https://github.com/KellerJordan/modded-nanogpt/pull/109 by @byronxu99 +2. Added a sparse attention gate on the attention output to enable a context based no-op. 
Found the mechanism performed well with only 12 active dimensions from the residual stream. If curious, here is a related blog post, with detailed plots, from an earlier investigation into a non-sparse attention gate: https://medium.com/@larry36d/modulating-attention-scores-cc0bcd853f06. The blog demonstrates how the attention gate reduces the need for the bos_token to function as an attention sink. This is particularly relevant in a sliding window attention context because the bos_token is not always in the context window. RoPE embeddings cause the bos_token attention sink to change based on relative distance, whereas a sparse attention gate is indifferent to distance from the start of the sample. (A minimal sketch of the gate is included after the validated results below.) Estimate of impact: 50 fewer steps, with a slight increase in time per step.
+3. As a follow-on from 2: reduced the number of iterations from 1750 to 1695.
+4. Reverted the lm head scaling changes made on Feb 10th: https://github.com/KellerJordan/modded-nanogpt/commit/85a0a5201f08c4d6bb288ef348bb252d9c33e132. When tested on a single A100, reverting this change drops the L2 norm of the LM head weights from 250 down to 10. The logits need to express values roughly from -10 to 10 in order to capture the range of token probabilities. Dividing by 27.5 (x.size(-1)**0.5) was causing the weights to grow substantially to accomplish this, since the residual stream was being normed prior to the lm_head. The second moment estimate of Adam depends on the parameter scale, and the Adam learning rates were likely heavily tuned prior to the Feb 10th update. If curious, more details near the end of this blog post: https://medium.com/@larry36d/exploration-log-exploring-initializing-transformers-with-bigram-distribution-70f9c8800b21. Estimate of impact: 5-10 steps (in this case, just a cleaner cut below 3.28).
+5. Chose to keep the minimum lr at 0.1. The bos_align record decreased the minimum lr to 0.05, and a later refactor, perhaps unintentionally, moved it back to 0.1. On further testing, the impact of this value on the mean loss is marginal, but a lower minimum lr appears to increase the variance of the final loss, making testing more challenging. A lower minimum lr may have higher variance because it commits to descending into a local region earlier, and is somewhat rolling the dice on whether that region is promising. On reflection, I likely originally picked 0.05 because taking the min loss over a grid search naturally biases toward higher-variance configurations, which is the opposite of what we want.
+
+
+Validated results (p=0.0059) with 14 runs:
+```
+import scipy.stats
+import torch
+
+accs = [3.2774, 3.2782, 3.2796, 3.2815, 3.276 , 3.2777, 3.2784, 3.2795,
+        3.281 , 3.2802, 3.2767, 3.2772, 3.28  , 3.2786
+        ]
+times = [
+    168.627, 169.037, 169.003, 168.727, 168.647, 169.024, 168.917,
+    168.999, 168.728, 169.07 , 168.981, 168.938, 168.718, 167.122]
+
+print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
+# p=0.0059
+
+print('acc:',torch.std_mean(torch.tensor(accs)))
+# acc: (tensor(0.0016), tensor(3.2787))
+
+print('time:',torch.std_mean(torch.tensor(times)))
+# time: (tensor(0.4946), tensor(168.7527))
+# Running on a fresh cluster gave 167.695. Actively working in Jupyter notebooks on the same machines during these runs may be adding variance to the timing.
+```
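+
+For reference, a condensed, standalone sketch of the gate as implemented in this record's train_gpt.py: a per-head sigmoid gate computed from the first `dim // 64 = 12` channels of the attention block's (normed) input, zero-initialized so every gate starts at 0.5. The `SparseAttnGate` module name is used only for this sketch; the record implements it inline with a `CastedLinear` named `dampen` and `args.dampen_factor = 64`.
+```python
+import torch
+from torch import nn
+
+class SparseAttnGate(nn.Module):
+    def __init__(self, dim: int, num_heads: int, dampen_factor: int = 64):
+        super().__init__()
+        self.active_dims = dim // dampen_factor  # 768 // 64 = 12 active dimensions
+        self.proj = nn.Linear(self.active_dims, num_heads, bias=False)
+        self.proj.weight.detach().zero_()  # zero init: every gate starts at sigmoid(0) = 0.5
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # x: (B, T, dim) block input (normed in the record)
+        # y: (B, T, num_heads, head_dim) per-head attention output, before the output projection
+        gate = torch.sigmoid(self.proj(x[..., :self.active_dims]))  # (B, T, num_heads)
+        return y * gate.unsqueeze(-1)  # a gate near 0 turns the head into a context-based no-op
+```
+The gate multiplies each head's output before the output projection, so a head can opt out of contributing at a given position; the extra cost is only a 12-by-num_heads matmul per token per layer.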
+
+### Negative and neutral test results during this process:
+
+1. Initialize embedding tokens using the bigram distribution. Bigram statistics can be calculated for 100 million tokens in ~1 second or less. I tested initializing the embedding layer using `z = relu(log(p(y|x)/p(y))); embed = norm(rand_linear(z))`, where p(y|x) is the bigram probability of token y given x. This initialization makes tokens with similar bigram statistics start with similar embeddings. If I froze the embedding layer, this initialization performed better than random initialization. However, for non-frozen embeddings, the impact was not statistically significant. (A toy sketch of this initialization is included at the end of this README.)
+2. Weight freezing during training. Since the majority of the time in each step is spent computing gradients, freezing a subset of weights can substantially decrease the time per step. Unfortunately, every combination of this that I tested failed to yield an improvement. A typical matmul op requires N FLOPs on the forward pass and 2N FLOPs on the backward pass: the 2N covers the gradient with respect to the weights (to update the weights) and the gradient with respect to the data (to pass the gradient onwards), at N FLOPs each. The torch compiler is smart enough to compute only N FLOPs on the backward pass for leaf operations. To leverage this, I tested updating the first 3 layers to run in parallel, and then froze the embedding after a portion of training, so that those 3 layers became leaf operations. The change was not kept, as the performance drop outweighed the speedup. (A minimal demonstration of the 2N -> N saving is included at the end of this README.)
+3. Logit shift parameter. The residual-space activations at all positions are heavily aligned away (>120 degrees) from the lm_head vectors of tokens that never appear in the training set. In other words, the ~400 tokens that never appear in the 50348 vocab size (including the 91 padding vocab entries) may be skewing the topology of the activations in the residual stream. Adding a simple logits += logit_shift lets the model learn the unigram distribution directly (or even just a static variable that is -inf on padding tokens), without disrupting the residual space. Unfortunately, my implementation of this change was running into memory issues on an A100. On the H100 setup, the change dropped the loss by 0.01 but was slightly edged out by the increase in time per step. I don't have the budget to fiddle substantially with params I can't test on an A100. If a more compute-optimized version can be found, this is an easy improvement to the loss, likely equivalent to 50+ steps.
+4. Removing torch.compile on zeropower_via_newtonschulz5(). Surprisingly, the torch compiler makes the output of newtonschulz() vary based on the batch dimension size, with a 2% change depending on the batch size. This is relevant when we are batching qkv in one op. This appears to occur because of rounding issues with bfloat16 and some internal accumulations the compiler is altering, as the percent diff drops to less than 0.1% for float32. On an A100, removing the compile gave an improvement when I was testing different batch sizes, but the change was not statistically significant on the H100 w/ fp8 lm_head. Unclear exactly what is going on here, but noting that bfloat16 can lead to very unintuitive consequences.
+5. Megabatch Newton-Schulz. Inspired by @byronxu99, I tested further impacts of batching for zeropower_via_newtonschulz5(). The results were quite surprising on an A100: the run time was heavily dependent on the batch size, with larger batch sizes running up to twice as fast, based on initial testing (honestly need to sanity-check this, seemed too crazy). As a result, I experimented with storing all MLP params in one contiguous variable and doing a single call to zeropower_via_newtonschulz5(), with a [3, 4*768, 768] input to each GPU as one pass, and [6, 768, 768] for attn to each GPU as a second pass. This gave a total of only 2 zeropower_via_newtonschulz5() calls on each GPU per step. I was running into memory errors on the 8H100 setup, and need to get a cheaper distributed setup before I test further.
+6. 0.5 init weighting for the x0 stream instead of 0. At the end of training on an A100, the x0 weight for many layers is 50x higher than the x weight. Updating the init weighting to 0.5 gave a statistically significant improvement on the A100, but this was not replicated on the 8H100 setup with the fp8 lm_head.
+7. Normalize value embedding inputs during the forward pass. This seemed like a natural thing to do, given the norms on the input embedding and the existing lambda to scale the value weights. However, it yielded worse performance, perhaps because the value embeddings need a much higher weight than the values, and the lambda scaling parameter was not tuned to handle this itself.
+8. Renormalize the embedding in place between forward passes. The L2 norm of the embedding layer climbs from 27 to 500 over the course of training, leading to a different effective learning rate depending on the stage of training. Normalizing this parameter may enable the lr to be tuned more precisely. However, I found that norm() still needed to be included in the forward pass for an accurate grad calc, at which point the compute penalty of a second norm outside the forward pass made it not worthwhile.
+9. Removing value entirely (only use the value embedding) for the first and last 3 layers. The trained weights indicate that the value embedding dominates the calculated attention value, and I can save some matmul ops if I can drop 6 layers of value calcs. The change cost roughly 0.015 loss, which unfortunately outweighed the speedup achieved with the parameters used.
+10. Bigram full initialization. Similar to 1, I tested initializing the lm_head and embedding layer to approximate the bigram distribution. (A bigram model could in theory start learning from around 5.7 loss, with potentially better generalization during training.) Unfortunately, it is not analytically simple to set embed and lm_head to achieve a known bigram distribution, because of the nonlinearity of the softmax. Attempting to approximate this yielded worse performance than random initialization.
+11. Dual loss on the bigram distribution. I tested having the first X iterations minimize a combination of the next-token prediction loss and the bigram distribution for that token. The intuition was that since I can compute the bigram distribution of 100 million tokens in 1s, the bigram distribution encodes a higher density of information than the single high-variance loss signal of a 500,000-token batch. However, the 50,000x50,000 bigram matrix proved too bulky for compute-efficient steps.
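+
+Minimal sketches for items 1 and 2 above, under stated assumptions (illustrative names like `bigram_embed_init`; toy sizes only, since the full 50k x 50k bigram matrix is impractical to materialize, per item 11; not the exact code used in the experiments):
+```python
+# Item 1: z = relu(log(p(y|x)/p(y))); embed = norm(rand_linear(z)), with rand_linear
+# taken here to be a fixed random projection (an assumption of this sketch).
+import torch
+import torch.nn.functional as F
+
+def bigram_embed_init(tokens: torch.Tensor, vocab_size: int, model_dim: int) -> torch.Tensor:
+    # counts[x, y] = number of times token y follows token x
+    counts = torch.zeros(vocab_size, vocab_size)
+    counts.index_put_((tokens[:-1].long(), tokens[1:].long()),
+                      torch.ones(tokens.numel() - 1), accumulate=True)
+    p_y_given_x = counts / counts.sum(dim=1, keepdim=True).clamp_min(1)
+    p_y = counts.sum(dim=0) / counts.sum().clamp_min(1)
+    # rows of z are similar for tokens with similar bigram statistics
+    z = torch.log((p_y_given_x / p_y.clamp_min(1e-9)).clamp_min(1e-9)).relu()
+    # random projection down to model_dim, then row normalization
+    return F.normalize(z @ torch.randn(vocab_size, model_dim), dim=-1)
+```
+```python
+# Item 2: the 2N -> N backward saving. When an input is a leaf that doesn't require
+# grad (e.g. a frozen embedding feeding the first layers), autograd skips the
+# grad-wrt-data matmul and computes only the grad-wrt-weight matmul.
+import torch
+
+x = torch.randn(4096, 768)                     # frozen input: requires_grad=False
+w = torch.randn(768, 768, requires_grad=True)  # trainable weight
+(x @ w).relu().sum().backward()
+assert w.grad is not None and x.grad is None   # only the weight gradient was computed
+```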
diff --git a/records/082325_SparseAttnGate/a39b1ae8-3a2a-4952-8032-13183b157053.txt b/records/082325_SparseAttnGate/a39b1ae8-3a2a-4952-8032-13183b157053.txt new file mode 100644 index 000000000..080453143 --- /dev/null +++ b/records/082325_SparseAttnGate/a39b1ae8-3a2a-4952-8032-13183b157053.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + 
num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + 
c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = 
lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + 
state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = 
Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value 
embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256]) + # docs = increments.cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), 
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # learnable scalars: U-net skip weights, per-block lambdas, and SA lambdas, padded so world_size divides the count
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 5) % dist.get_world_size()
+        self.scalars = nn.Parameter(torch.cat([
+            torch.ones(num_layers), # skip_weights
+            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
+            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
+            torch.ones(pad),
+        ]))
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following the Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
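+# Editor's note: the softcap above is an exact rewrite of a shifted tanh cap, since
+# tanh(u) = 2*sigmoid(2u) - 1 gives 30*sigmoid(z/7.5) == 15*(tanh(z/15) + 1),
+# squashing logits into (0, 30). A quick standalone check (hypothetical helper,
+# never invoked):
+def _softcap_identity_sketch():
+    z = torch.linspace(-60.0, 60.0, steps=9)
+    assert torch.allclose(30 * torch.sigmoid(z / 7.5), 15 * (torch.tanh(z / 15) + 1), atol=1e-5)
+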
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+        start = end
+    assert False # increase token_window if necessary
+
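+# Editor's illustration of the BOS alignment above on a toy token stream
+# (standalone: uses a fixed world_size instead of dist.get_world_size();
+# hypothetical helper, never invoked):
+def _bos_alignment_sketch(world_size: int = 2, seq_len: int = 4):
+    tokens = torch.tensor([50256, 1, 2, 3, 4, 50256, 5, 6, 7, 8, 9, 50256, 10])
+    bounds = torch.nonzero(tokens == 50256).squeeze(-1).tolist()  # [0, 5, 11]
+    starts, start = [], bounds[0]
+    for end in bounds[1:]:
+        if end - start >= seq_len:  # document long enough for one full sequence
+            starts.append(start)
+            if len(starts) == world_size:
+                break
+        start = end
+    return starts  # [0, 5]: rank 0 reads tokens[0:5], rank 1 reads tokens[5:10], no overlap
+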
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len
+        if pos + token_window + 1 >= len(tokens):
+            tokens = _load_data_shard(next(file_iter))
+            pos = 0
+        for _ in range(grad_accum_steps):
+            if align_to_bos:
+                batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window)
+                start_idx = batch_starts[rank]
+            else:
+                tokens_consumed = batch_size
+                start_idx = pos + rank * seq_len
+            buf = tokens[start_idx:][:seq_len + 1]
+            inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
+            targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
+            pos += tokens_consumed
+            token_window -= tokens_consumed
+            yield inputs, targets
+
+# -----------------------------------------------------------------------------
+# int main
+
+
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x < 1
+    if x < 1 - args.cooldown_frac:
+        return 1.0
+    else:
+        w = (1 - x) / args.cooldown_frac
+        return w * 1.0 + (1 - w) * 0.1
+
+# attention window size schedule: linearly increase
+@lru_cache(1)
+def get_window_size_blocks_helper(window_size: int):
+    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+def get_window_size_blocks(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x <= 1
+    # Linearly increase the block-wise sliding window size over training 128 -> 1792
+    # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
+    window_size = next_multiple_of_n(1728 * x, n=128)
+    return get_window_size_blocks_helper(window_size)
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 10
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
+for _ in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    model(inputs, targets, get_window_size_blocks(1)).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        val_batch_size = world_size * args.val_seq_len
+        assert args.val_tokens % val_batch_size == 0
+        val_steps = args.val_tokens // val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets = next(val_loader)
+                val_loss += model(inputs, targets, get_window_size_blocks(step))
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+ 
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:40:05 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 314814 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 314815 C /usr/bin/python3 614MiB | +| 0 N/A N/A 314816 C /usr/bin/python3 614MiB | +| 0 N/A N/A 314817 C /usr/bin/python3 614MiB | +| 0 N/A N/A 314818 C /usr/bin/python3 614MiB | +| 0 N/A N/A 314819 C /usr/bin/python3 614MiB | +| 0 N/A N/A 314820 C /usr/bin/python3 614MiB | +| 0 N/A N/A 314821 C /usr/bin/python3 614MiB | +| 1 N/A N/A 314815 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 314816 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 314817 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 314818 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 314819 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 314820 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 314821 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:153ms step_avg:152.82ms +step:2/1695 train_time:177ms step_avg:88.47ms +step:3/1695 train_time:250ms step_avg:83.26ms +step:4/1695 train_time:341ms step_avg:85.31ms +step:5/1695 train_time:435ms step_avg:86.95ms +step:6/1695 train_time:526ms step_avg:87.74ms +step:7/1695 train_time:619ms step_avg:88.47ms +step:8/1695 
train_time:713ms step_avg:89.11ms +step:9/1695 train_time:806ms step_avg:89.59ms +step:10/1695 train_time:900ms step_avg:89.95ms +step:11/1695 train_time:993ms step_avg:90.28ms +step:12/1695 train_time:1088ms step_avg:90.64ms +step:13/1695 train_time:1182ms step_avg:90.95ms +step:14/1695 train_time:1278ms step_avg:91.28ms +step:15/1695 train_time:1372ms step_avg:91.44ms +step:16/1695 train_time:1465ms step_avg:91.59ms +step:17/1695 train_time:1559ms step_avg:91.70ms +step:18/1695 train_time:1653ms step_avg:91.82ms +step:19/1695 train_time:1746ms step_avg:91.89ms +step:20/1695 train_time:1839ms step_avg:91.96ms +step:21/1695 train_time:1932ms step_avg:92.01ms +step:22/1695 train_time:2026ms step_avg:92.09ms +step:23/1695 train_time:2119ms step_avg:92.15ms +step:24/1695 train_time:2214ms step_avg:92.26ms +step:25/1695 train_time:2309ms step_avg:92.34ms +step:26/1695 train_time:2403ms step_avg:92.41ms +step:27/1695 train_time:2496ms step_avg:92.45ms +step:28/1695 train_time:2590ms step_avg:92.51ms +step:29/1695 train_time:2684ms step_avg:92.56ms +step:30/1695 train_time:2777ms step_avg:92.58ms +step:31/1695 train_time:2871ms step_avg:92.62ms +step:32/1695 train_time:2964ms step_avg:92.64ms +step:33/1695 train_time:3058ms step_avg:92.65ms +step:34/1695 train_time:3151ms step_avg:92.69ms +step:35/1695 train_time:3245ms step_avg:92.72ms +step:36/1695 train_time:3339ms step_avg:92.75ms +step:37/1695 train_time:3434ms step_avg:92.80ms +step:38/1695 train_time:3528ms step_avg:92.84ms +step:39/1695 train_time:3622ms step_avg:92.87ms +step:40/1695 train_time:3716ms step_avg:92.89ms +step:41/1695 train_time:3810ms step_avg:92.93ms +step:42/1695 train_time:3905ms step_avg:92.99ms +step:43/1695 train_time:3999ms step_avg:93.00ms +step:44/1695 train_time:4093ms step_avg:93.03ms +step:45/1695 train_time:4187ms step_avg:93.03ms +step:46/1695 train_time:4279ms step_avg:93.03ms +step:47/1695 train_time:4373ms step_avg:93.04ms +step:48/1695 train_time:4467ms step_avg:93.05ms +step:49/1695 train_time:4561ms step_avg:93.08ms +step:50/1695 train_time:4655ms step_avg:93.10ms +step:51/1695 train_time:4750ms step_avg:93.13ms +step:52/1695 train_time:4844ms step_avg:93.15ms +step:53/1695 train_time:4938ms step_avg:93.16ms +step:54/1695 train_time:5032ms step_avg:93.19ms +step:55/1695 train_time:5127ms step_avg:93.22ms +step:56/1695 train_time:5221ms step_avg:93.22ms +step:57/1695 train_time:5314ms step_avg:93.23ms +step:58/1695 train_time:5409ms step_avg:93.25ms +step:59/1695 train_time:5502ms step_avg:93.26ms +step:60/1695 train_time:5596ms step_avg:93.27ms +step:61/1695 train_time:5690ms step_avg:93.28ms +step:62/1695 train_time:5783ms step_avg:93.28ms +step:63/1695 train_time:5877ms step_avg:93.29ms +step:64/1695 train_time:5971ms step_avg:93.30ms +step:65/1695 train_time:6066ms step_avg:93.32ms +step:66/1695 train_time:6159ms step_avg:93.32ms +step:67/1695 train_time:6254ms step_avg:93.35ms +step:68/1695 train_time:6349ms step_avg:93.37ms +step:69/1695 train_time:6443ms step_avg:93.37ms +step:70/1695 train_time:6536ms step_avg:93.37ms +step:71/1695 train_time:6630ms step_avg:93.39ms +step:72/1695 train_time:6724ms step_avg:93.40ms +step:73/1695 train_time:6818ms step_avg:93.39ms +step:74/1695 train_time:6911ms step_avg:93.39ms +step:75/1695 train_time:7005ms step_avg:93.40ms +step:76/1695 train_time:7098ms step_avg:93.40ms +step:77/1695 train_time:7192ms step_avg:93.41ms +step:78/1695 train_time:7287ms step_avg:93.42ms +step:79/1695 train_time:7381ms step_avg:93.42ms +step:80/1695 train_time:7474ms 
step_avg:93.42ms +step:81/1695 train_time:7568ms step_avg:93.43ms +step:82/1695 train_time:7662ms step_avg:93.43ms +step:83/1695 train_time:7755ms step_avg:93.44ms +step:84/1695 train_time:7851ms step_avg:93.46ms +step:85/1695 train_time:7944ms step_avg:93.46ms +step:86/1695 train_time:8037ms step_avg:93.46ms +step:87/1695 train_time:8131ms step_avg:93.46ms +step:88/1695 train_time:8225ms step_avg:93.47ms +step:89/1695 train_time:8318ms step_avg:93.46ms +step:90/1695 train_time:8412ms step_avg:93.47ms +step:91/1695 train_time:8506ms step_avg:93.47ms +step:92/1695 train_time:8600ms step_avg:93.48ms +step:93/1695 train_time:8694ms step_avg:93.48ms +step:94/1695 train_time:8788ms step_avg:93.49ms +step:95/1695 train_time:8881ms step_avg:93.49ms +step:96/1695 train_time:8975ms step_avg:93.49ms +step:97/1695 train_time:9069ms step_avg:93.50ms +step:98/1695 train_time:9163ms step_avg:93.50ms +step:99/1695 train_time:9257ms step_avg:93.51ms +step:100/1695 train_time:9351ms step_avg:93.51ms +step:101/1695 train_time:9444ms step_avg:93.51ms +step:102/1695 train_time:9538ms step_avg:93.51ms +step:103/1695 train_time:9632ms step_avg:93.51ms +step:104/1695 train_time:9726ms step_avg:93.52ms +step:105/1695 train_time:9820ms step_avg:93.52ms +step:106/1695 train_time:9914ms step_avg:93.53ms +step:107/1695 train_time:10008ms step_avg:93.53ms +step:108/1695 train_time:10101ms step_avg:93.53ms +step:109/1695 train_time:10195ms step_avg:93.53ms +step:110/1695 train_time:10289ms step_avg:93.53ms +step:111/1695 train_time:10382ms step_avg:93.53ms +step:112/1695 train_time:10475ms step_avg:93.53ms +step:113/1695 train_time:10569ms step_avg:93.53ms +step:114/1695 train_time:10662ms step_avg:93.53ms +step:115/1695 train_time:10757ms step_avg:93.54ms +step:116/1695 train_time:10850ms step_avg:93.54ms +step:117/1695 train_time:10943ms step_avg:93.53ms +step:118/1695 train_time:11037ms step_avg:93.54ms +step:119/1695 train_time:11131ms step_avg:93.54ms +step:120/1695 train_time:11225ms step_avg:93.54ms +step:121/1695 train_time:11318ms step_avg:93.54ms +step:122/1695 train_time:11413ms step_avg:93.55ms +step:123/1695 train_time:11507ms step_avg:93.55ms +step:124/1695 train_time:11601ms step_avg:93.56ms +step:125/1695 train_time:11695ms step_avg:93.56ms +step:125/1695 val_loss:4.5970 train_time:11787ms step_avg:94.30ms +step:126/1695 train_time:11813ms step_avg:93.75ms +step:127/1695 train_time:11890ms step_avg:93.62ms +step:128/1695 train_time:11993ms step_avg:93.70ms +step:129/1695 train_time:12090ms step_avg:93.72ms +step:130/1695 train_time:12184ms step_avg:93.72ms +step:131/1695 train_time:12277ms step_avg:93.72ms +step:132/1695 train_time:12371ms step_avg:93.72ms +step:133/1695 train_time:12465ms step_avg:93.72ms +step:134/1695 train_time:12558ms step_avg:93.72ms +step:135/1695 train_time:12652ms step_avg:93.72ms +step:136/1695 train_time:12745ms step_avg:93.71ms +step:137/1695 train_time:12839ms step_avg:93.72ms +step:138/1695 train_time:12936ms step_avg:93.74ms +step:139/1695 train_time:13034ms step_avg:93.77ms +step:140/1695 train_time:13130ms step_avg:93.78ms +step:141/1695 train_time:13224ms step_avg:93.79ms +step:142/1695 train_time:13318ms step_avg:93.79ms +step:143/1695 train_time:13412ms step_avg:93.79ms +step:144/1695 train_time:13505ms step_avg:93.79ms +step:145/1695 train_time:13599ms step_avg:93.79ms +step:146/1695 train_time:13693ms step_avg:93.79ms +step:147/1695 train_time:13788ms step_avg:93.79ms +step:148/1695 train_time:13882ms step_avg:93.80ms +step:149/1695 train_time:13977ms 
step_avg:93.80ms +step:150/1695 train_time:14072ms step_avg:93.82ms +step:151/1695 train_time:14168ms step_avg:93.83ms +step:152/1695 train_time:14262ms step_avg:93.83ms +step:153/1695 train_time:14356ms step_avg:93.83ms +step:154/1695 train_time:14451ms step_avg:93.84ms +step:155/1695 train_time:14545ms step_avg:93.84ms +step:156/1695 train_time:14638ms step_avg:93.83ms +step:157/1695 train_time:14732ms step_avg:93.83ms +step:158/1695 train_time:14826ms step_avg:93.84ms +step:159/1695 train_time:14921ms step_avg:93.84ms +step:160/1695 train_time:15016ms step_avg:93.85ms +step:161/1695 train_time:15112ms step_avg:93.86ms +step:162/1695 train_time:15207ms step_avg:93.87ms +step:163/1695 train_time:15301ms step_avg:93.87ms +step:164/1695 train_time:15395ms step_avg:93.87ms +step:165/1695 train_time:15490ms step_avg:93.88ms +step:166/1695 train_time:15584ms step_avg:93.88ms +step:167/1695 train_time:15677ms step_avg:93.88ms +step:168/1695 train_time:15772ms step_avg:93.88ms +step:169/1695 train_time:15867ms step_avg:93.89ms +step:170/1695 train_time:15961ms step_avg:93.89ms +step:171/1695 train_time:16055ms step_avg:93.89ms +step:172/1695 train_time:16149ms step_avg:93.89ms +step:173/1695 train_time:16244ms step_avg:93.90ms +step:174/1695 train_time:16338ms step_avg:93.89ms +step:175/1695 train_time:16433ms step_avg:93.90ms +step:176/1695 train_time:16528ms step_avg:93.91ms +step:177/1695 train_time:16623ms step_avg:93.91ms +step:178/1695 train_time:16717ms step_avg:93.91ms +step:179/1695 train_time:16811ms step_avg:93.92ms +step:180/1695 train_time:16905ms step_avg:93.92ms +step:181/1695 train_time:16999ms step_avg:93.92ms +step:182/1695 train_time:17094ms step_avg:93.93ms +step:183/1695 train_time:17189ms step_avg:93.93ms +step:184/1695 train_time:17283ms step_avg:93.93ms +step:185/1695 train_time:17376ms step_avg:93.93ms +step:186/1695 train_time:17471ms step_avg:93.93ms +step:187/1695 train_time:17567ms step_avg:93.94ms +step:188/1695 train_time:17660ms step_avg:93.94ms +step:189/1695 train_time:17754ms step_avg:93.94ms +step:190/1695 train_time:17849ms step_avg:93.94ms +step:191/1695 train_time:17943ms step_avg:93.94ms +step:192/1695 train_time:18037ms step_avg:93.94ms +step:193/1695 train_time:18134ms step_avg:93.96ms +step:194/1695 train_time:18229ms step_avg:93.96ms +step:195/1695 train_time:18322ms step_avg:93.96ms +step:196/1695 train_time:18416ms step_avg:93.96ms +step:197/1695 train_time:18511ms step_avg:93.96ms +step:198/1695 train_time:18605ms step_avg:93.96ms +step:199/1695 train_time:18698ms step_avg:93.96ms +step:200/1695 train_time:18793ms step_avg:93.96ms +step:201/1695 train_time:18888ms step_avg:93.97ms +step:202/1695 train_time:18982ms step_avg:93.97ms +step:203/1695 train_time:19076ms step_avg:93.97ms +step:204/1695 train_time:19171ms step_avg:93.98ms +step:205/1695 train_time:19265ms step_avg:93.98ms +step:206/1695 train_time:19359ms step_avg:93.97ms +step:207/1695 train_time:19453ms step_avg:93.98ms +step:208/1695 train_time:19548ms step_avg:93.98ms +step:209/1695 train_time:19641ms step_avg:93.98ms +step:210/1695 train_time:19735ms step_avg:93.98ms +step:211/1695 train_time:19829ms step_avg:93.98ms +step:212/1695 train_time:19923ms step_avg:93.98ms +step:213/1695 train_time:20018ms step_avg:93.98ms +step:214/1695 train_time:20113ms step_avg:93.98ms +step:215/1695 train_time:20208ms step_avg:93.99ms +step:216/1695 train_time:20302ms step_avg:93.99ms +step:217/1695 train_time:20395ms step_avg:93.99ms +step:218/1695 train_time:20489ms step_avg:93.99ms +step:219/1695 
train_time:20582ms step_avg:93.98ms +step:220/1695 train_time:20676ms step_avg:93.98ms +step:221/1695 train_time:20770ms step_avg:93.98ms +step:222/1695 train_time:20866ms step_avg:93.99ms +step:223/1695 train_time:20961ms step_avg:93.99ms +step:224/1695 train_time:21054ms step_avg:93.99ms +step:225/1695 train_time:21149ms step_avg:93.99ms +step:226/1695 train_time:21243ms step_avg:93.99ms +step:227/1695 train_time:21337ms step_avg:93.99ms +step:228/1695 train_time:21432ms step_avg:94.00ms +step:229/1695 train_time:21527ms step_avg:94.00ms +step:230/1695 train_time:21620ms step_avg:94.00ms +step:231/1695 train_time:21714ms step_avg:94.00ms +step:232/1695 train_time:21808ms step_avg:94.00ms +step:233/1695 train_time:21903ms step_avg:94.00ms +step:234/1695 train_time:21996ms step_avg:94.00ms +step:235/1695 train_time:22091ms step_avg:94.00ms +step:236/1695 train_time:22184ms step_avg:94.00ms +step:237/1695 train_time:22278ms step_avg:94.00ms +step:238/1695 train_time:22372ms step_avg:94.00ms +step:239/1695 train_time:22467ms step_avg:94.00ms +step:240/1695 train_time:22562ms step_avg:94.01ms +step:241/1695 train_time:22657ms step_avg:94.01ms +step:242/1695 train_time:22752ms step_avg:94.02ms +step:243/1695 train_time:22847ms step_avg:94.02ms +step:244/1695 train_time:22941ms step_avg:94.02ms +step:245/1695 train_time:23035ms step_avg:94.02ms +step:246/1695 train_time:23130ms step_avg:94.03ms +step:247/1695 train_time:23225ms step_avg:94.03ms +step:248/1695 train_time:23318ms step_avg:94.02ms +step:249/1695 train_time:23413ms step_avg:94.03ms +step:250/1695 train_time:23508ms step_avg:94.03ms +step:250/1695 val_loss:4.0722 train_time:23600ms step_avg:94.40ms +step:251/1695 train_time:23626ms step_avg:94.13ms +step:252/1695 train_time:23704ms step_avg:94.07ms +step:253/1695 train_time:23806ms step_avg:94.09ms +step:254/1695 train_time:23901ms step_avg:94.10ms +step:255/1695 train_time:23995ms step_avg:94.10ms +step:256/1695 train_time:24088ms step_avg:94.10ms +step:257/1695 train_time:24182ms step_avg:94.09ms +step:258/1695 train_time:24276ms step_avg:94.09ms +step:259/1695 train_time:24369ms step_avg:94.09ms +step:260/1695 train_time:24464ms step_avg:94.09ms +step:261/1695 train_time:24558ms step_avg:94.09ms +step:262/1695 train_time:24653ms step_avg:94.09ms +step:263/1695 train_time:24749ms step_avg:94.10ms +step:264/1695 train_time:24846ms step_avg:94.11ms +step:265/1695 train_time:24942ms step_avg:94.12ms +step:266/1695 train_time:25037ms step_avg:94.12ms +step:267/1695 train_time:25130ms step_avg:94.12ms +step:268/1695 train_time:25225ms step_avg:94.12ms +step:269/1695 train_time:25319ms step_avg:94.12ms +step:270/1695 train_time:25413ms step_avg:94.12ms +step:271/1695 train_time:25507ms step_avg:94.12ms +step:272/1695 train_time:25602ms step_avg:94.13ms +step:273/1695 train_time:25698ms step_avg:94.13ms +step:274/1695 train_time:25792ms step_avg:94.13ms +step:275/1695 train_time:25888ms step_avg:94.14ms +step:276/1695 train_time:25983ms step_avg:94.14ms +step:277/1695 train_time:26079ms step_avg:94.15ms +step:278/1695 train_time:26173ms step_avg:94.15ms +step:279/1695 train_time:26267ms step_avg:94.15ms +step:280/1695 train_time:26362ms step_avg:94.15ms +step:281/1695 train_time:26457ms step_avg:94.15ms +step:282/1695 train_time:26550ms step_avg:94.15ms +step:283/1695 train_time:26645ms step_avg:94.15ms +step:284/1695 train_time:26740ms step_avg:94.15ms +step:285/1695 train_time:26835ms step_avg:94.16ms +step:286/1695 train_time:26930ms step_avg:94.16ms +step:287/1695 train_time:27026ms 
step_avg:94.17ms +step:288/1695 train_time:27121ms step_avg:94.17ms +step:289/1695 train_time:27215ms step_avg:94.17ms +step:290/1695 train_time:27309ms step_avg:94.17ms +step:291/1695 train_time:27404ms step_avg:94.17ms +step:292/1695 train_time:27499ms step_avg:94.17ms +step:293/1695 train_time:27592ms step_avg:94.17ms +step:294/1695 train_time:27686ms step_avg:94.17ms +step:295/1695 train_time:27781ms step_avg:94.17ms +step:296/1695 train_time:27875ms step_avg:94.17ms +step:297/1695 train_time:27969ms step_avg:94.17ms +step:298/1695 train_time:28066ms step_avg:94.18ms +step:299/1695 train_time:28162ms step_avg:94.19ms +step:300/1695 train_time:28256ms step_avg:94.19ms +step:301/1695 train_time:28350ms step_avg:94.19ms +step:302/1695 train_time:28445ms step_avg:94.19ms +step:303/1695 train_time:28539ms step_avg:94.19ms +step:304/1695 train_time:28633ms step_avg:94.19ms +step:305/1695 train_time:28728ms step_avg:94.19ms +step:306/1695 train_time:28823ms step_avg:94.19ms +step:307/1695 train_time:28917ms step_avg:94.19ms +step:308/1695 train_time:29012ms step_avg:94.20ms +step:309/1695 train_time:29107ms step_avg:94.20ms +step:310/1695 train_time:29201ms step_avg:94.20ms +step:311/1695 train_time:29296ms step_avg:94.20ms +step:312/1695 train_time:29390ms step_avg:94.20ms +step:313/1695 train_time:29485ms step_avg:94.20ms +step:314/1695 train_time:29580ms step_avg:94.20ms +step:315/1695 train_time:29674ms step_avg:94.20ms +step:316/1695 train_time:29768ms step_avg:94.20ms +step:317/1695 train_time:29865ms step_avg:94.21ms +step:318/1695 train_time:29960ms step_avg:94.21ms +step:319/1695 train_time:30054ms step_avg:94.21ms +step:320/1695 train_time:30149ms step_avg:94.21ms +step:321/1695 train_time:30243ms step_avg:94.21ms +step:322/1695 train_time:30338ms step_avg:94.22ms +step:323/1695 train_time:30432ms step_avg:94.22ms +step:324/1695 train_time:30527ms step_avg:94.22ms +step:325/1695 train_time:30622ms step_avg:94.22ms +step:326/1695 train_time:30717ms step_avg:94.22ms +step:327/1695 train_time:30812ms step_avg:94.23ms +step:328/1695 train_time:30907ms step_avg:94.23ms +step:329/1695 train_time:31001ms step_avg:94.23ms +step:330/1695 train_time:31096ms step_avg:94.23ms +step:331/1695 train_time:31190ms step_avg:94.23ms +step:332/1695 train_time:31285ms step_avg:94.23ms +step:333/1695 train_time:31379ms step_avg:94.23ms +step:334/1695 train_time:31475ms step_avg:94.24ms +step:335/1695 train_time:31569ms step_avg:94.23ms +step:336/1695 train_time:31664ms step_avg:94.24ms +step:337/1695 train_time:31759ms step_avg:94.24ms +step:338/1695 train_time:31853ms step_avg:94.24ms +step:339/1695 train_time:31947ms step_avg:94.24ms +step:340/1695 train_time:32043ms step_avg:94.24ms +step:341/1695 train_time:32138ms step_avg:94.25ms +step:342/1695 train_time:32233ms step_avg:94.25ms +step:343/1695 train_time:32327ms step_avg:94.25ms +step:344/1695 train_time:32422ms step_avg:94.25ms +step:345/1695 train_time:32517ms step_avg:94.25ms +step:346/1695 train_time:32611ms step_avg:94.25ms +step:347/1695 train_time:32705ms step_avg:94.25ms +step:348/1695 train_time:32800ms step_avg:94.25ms +step:349/1695 train_time:32895ms step_avg:94.25ms +step:350/1695 train_time:32989ms step_avg:94.25ms +step:351/1695 train_time:33083ms step_avg:94.25ms +step:352/1695 train_time:33178ms step_avg:94.26ms +step:353/1695 train_time:33273ms step_avg:94.26ms +step:354/1695 train_time:33367ms step_avg:94.26ms +step:355/1695 train_time:33463ms step_avg:94.26ms +step:356/1695 train_time:33557ms step_avg:94.26ms +step:357/1695 
train_time:33652ms step_avg:94.26ms +step:358/1695 train_time:33746ms step_avg:94.26ms +step:359/1695 train_time:33841ms step_avg:94.27ms +step:360/1695 train_time:33936ms step_avg:94.27ms +step:361/1695 train_time:34031ms step_avg:94.27ms +step:362/1695 train_time:34126ms step_avg:94.27ms +step:363/1695 train_time:34221ms step_avg:94.27ms +step:364/1695 train_time:34315ms step_avg:94.27ms +step:365/1695 train_time:34409ms step_avg:94.27ms +step:366/1695 train_time:34504ms step_avg:94.27ms +step:367/1695 train_time:34599ms step_avg:94.27ms +step:368/1695 train_time:34693ms step_avg:94.27ms +step:369/1695 train_time:34787ms step_avg:94.27ms +step:370/1695 train_time:34882ms step_avg:94.28ms +step:371/1695 train_time:34978ms step_avg:94.28ms +step:372/1695 train_time:35072ms step_avg:94.28ms +step:373/1695 train_time:35167ms step_avg:94.28ms +step:374/1695 train_time:35261ms step_avg:94.28ms +step:375/1695 train_time:35355ms step_avg:94.28ms +step:375/1695 val_loss:3.8750 train_time:35447ms step_avg:94.53ms +step:376/1695 train_time:35473ms step_avg:94.34ms +step:377/1695 train_time:35551ms step_avg:94.30ms +step:378/1695 train_time:35650ms step_avg:94.31ms +step:379/1695 train_time:35747ms step_avg:94.32ms +step:380/1695 train_time:35843ms step_avg:94.32ms +step:381/1695 train_time:35940ms step_avg:94.33ms +step:382/1695 train_time:36035ms step_avg:94.33ms +step:383/1695 train_time:36130ms step_avg:94.33ms +step:384/1695 train_time:36226ms step_avg:94.34ms +step:385/1695 train_time:36322ms step_avg:94.34ms +step:386/1695 train_time:36418ms step_avg:94.35ms +step:387/1695 train_time:36516ms step_avg:94.36ms +step:388/1695 train_time:36613ms step_avg:94.36ms +step:389/1695 train_time:36710ms step_avg:94.37ms +step:390/1695 train_time:36807ms step_avg:94.38ms +step:391/1695 train_time:36902ms step_avg:94.38ms +step:392/1695 train_time:36998ms step_avg:94.38ms +step:393/1695 train_time:37095ms step_avg:94.39ms +step:394/1695 train_time:37191ms step_avg:94.39ms +step:395/1695 train_time:37286ms step_avg:94.40ms +step:396/1695 train_time:37382ms step_avg:94.40ms +step:397/1695 train_time:37479ms step_avg:94.41ms +step:398/1695 train_time:37577ms step_avg:94.42ms +step:399/1695 train_time:37675ms step_avg:94.42ms +step:400/1695 train_time:37771ms step_avg:94.43ms +step:401/1695 train_time:37867ms step_avg:94.43ms +step:402/1695 train_time:37963ms step_avg:94.44ms +step:403/1695 train_time:38059ms step_avg:94.44ms +step:404/1695 train_time:38156ms step_avg:94.45ms +step:405/1695 train_time:38252ms step_avg:94.45ms +step:406/1695 train_time:38347ms step_avg:94.45ms +step:407/1695 train_time:38443ms step_avg:94.46ms +step:408/1695 train_time:38541ms step_avg:94.46ms +step:409/1695 train_time:38637ms step_avg:94.47ms +step:410/1695 train_time:38734ms step_avg:94.47ms +step:411/1695 train_time:38830ms step_avg:94.48ms +step:412/1695 train_time:38926ms step_avg:94.48ms +step:413/1695 train_time:39022ms step_avg:94.48ms +step:414/1695 train_time:39118ms step_avg:94.49ms +step:415/1695 train_time:39215ms step_avg:94.49ms +step:416/1695 train_time:39311ms step_avg:94.50ms +step:417/1695 train_time:39407ms step_avg:94.50ms +step:418/1695 train_time:39504ms step_avg:94.51ms +step:419/1695 train_time:39600ms step_avg:94.51ms +step:420/1695 train_time:39697ms step_avg:94.52ms +step:421/1695 train_time:39793ms step_avg:94.52ms +step:422/1695 train_time:39890ms step_avg:94.53ms +step:423/1695 train_time:39986ms step_avg:94.53ms +step:424/1695 train_time:40082ms step_avg:94.53ms +step:425/1695 train_time:40179ms 
step_avg:94.54ms +step:426/1695 train_time:40276ms step_avg:94.55ms +step:427/1695 train_time:40372ms step_avg:94.55ms +step:428/1695 train_time:40468ms step_avg:94.55ms +step:429/1695 train_time:40564ms step_avg:94.55ms +step:430/1695 train_time:40660ms step_avg:94.56ms +step:431/1695 train_time:40757ms step_avg:94.56ms +step:432/1695 train_time:40854ms step_avg:94.57ms +step:433/1695 train_time:40951ms step_avg:94.57ms +step:434/1695 train_time:41047ms step_avg:94.58ms +step:435/1695 train_time:41143ms step_avg:94.58ms +step:436/1695 train_time:41239ms step_avg:94.59ms +step:437/1695 train_time:41336ms step_avg:94.59ms +step:438/1695 train_time:41432ms step_avg:94.59ms +step:439/1695 train_time:41528ms step_avg:94.60ms +step:440/1695 train_time:41624ms step_avg:94.60ms +step:441/1695 train_time:41721ms step_avg:94.60ms +step:442/1695 train_time:41818ms step_avg:94.61ms +step:443/1695 train_time:41915ms step_avg:94.62ms +step:444/1695 train_time:42011ms step_avg:94.62ms +step:445/1695 train_time:42107ms step_avg:94.62ms +step:446/1695 train_time:42203ms step_avg:94.62ms +step:447/1695 train_time:42299ms step_avg:94.63ms +step:448/1695 train_time:42395ms step_avg:94.63ms +step:449/1695 train_time:42491ms step_avg:94.63ms +step:450/1695 train_time:42586ms step_avg:94.64ms +step:451/1695 train_time:42683ms step_avg:94.64ms +step:452/1695 train_time:42780ms step_avg:94.65ms +step:453/1695 train_time:42877ms step_avg:94.65ms +step:454/1695 train_time:42974ms step_avg:94.66ms +step:455/1695 train_time:43071ms step_avg:94.66ms +step:456/1695 train_time:43167ms step_avg:94.66ms +step:457/1695 train_time:43263ms step_avg:94.67ms +step:458/1695 train_time:43360ms step_avg:94.67ms +step:459/1695 train_time:43457ms step_avg:94.68ms +step:460/1695 train_time:43553ms step_avg:94.68ms +step:461/1695 train_time:43649ms step_avg:94.68ms +step:462/1695 train_time:43745ms step_avg:94.69ms +step:463/1695 train_time:43841ms step_avg:94.69ms +step:464/1695 train_time:43938ms step_avg:94.69ms +step:465/1695 train_time:44035ms step_avg:94.70ms +step:466/1695 train_time:44131ms step_avg:94.70ms +step:467/1695 train_time:44227ms step_avg:94.70ms +step:468/1695 train_time:44323ms step_avg:94.71ms +step:469/1695 train_time:44420ms step_avg:94.71ms +step:470/1695 train_time:44517ms step_avg:94.72ms +step:471/1695 train_time:44613ms step_avg:94.72ms +step:472/1695 train_time:44709ms step_avg:94.72ms +step:473/1695 train_time:44804ms step_avg:94.72ms +step:474/1695 train_time:44901ms step_avg:94.73ms +step:475/1695 train_time:44998ms step_avg:94.73ms +step:476/1695 train_time:45095ms step_avg:94.74ms +step:477/1695 train_time:45191ms step_avg:94.74ms +step:478/1695 train_time:45287ms step_avg:94.74ms +step:479/1695 train_time:45382ms step_avg:94.74ms +step:480/1695 train_time:45479ms step_avg:94.75ms +step:481/1695 train_time:45575ms step_avg:94.75ms +step:482/1695 train_time:45671ms step_avg:94.75ms +step:483/1695 train_time:45767ms step_avg:94.76ms +step:484/1695 train_time:45863ms step_avg:94.76ms +step:485/1695 train_time:45960ms step_avg:94.76ms +step:486/1695 train_time:46057ms step_avg:94.77ms +step:487/1695 train_time:46154ms step_avg:94.77ms +step:488/1695 train_time:46250ms step_avg:94.78ms +step:489/1695 train_time:46346ms step_avg:94.78ms +step:490/1695 train_time:46442ms step_avg:94.78ms +step:491/1695 train_time:46539ms step_avg:94.78ms +step:492/1695 train_time:46635ms step_avg:94.79ms +step:493/1695 train_time:46732ms step_avg:94.79ms +step:494/1695 train_time:46827ms step_avg:94.79ms +step:495/1695 
train_time:46923ms step_avg:94.79ms +step:496/1695 train_time:47020ms step_avg:94.80ms +step:497/1695 train_time:47117ms step_avg:94.80ms +step:498/1695 train_time:47214ms step_avg:94.81ms +step:499/1695 train_time:47310ms step_avg:94.81ms +step:500/1695 train_time:47406ms step_avg:94.81ms +step:500/1695 val_loss:3.7286 train_time:47501ms step_avg:95.00ms +step:501/1695 train_time:47526ms step_avg:94.86ms +step:502/1695 train_time:47607ms step_avg:94.83ms +step:503/1695 train_time:47709ms step_avg:94.85ms +step:504/1695 train_time:47805ms step_avg:94.85ms +step:505/1695 train_time:47900ms step_avg:94.85ms +step:506/1695 train_time:47997ms step_avg:94.86ms +step:507/1695 train_time:48093ms step_avg:94.86ms +step:508/1695 train_time:48188ms step_avg:94.86ms +step:509/1695 train_time:48284ms step_avg:94.86ms +step:510/1695 train_time:48380ms step_avg:94.86ms +step:511/1695 train_time:48478ms step_avg:94.87ms +step:512/1695 train_time:48577ms step_avg:94.88ms +step:513/1695 train_time:48678ms step_avg:94.89ms +step:514/1695 train_time:48775ms step_avg:94.89ms +step:515/1695 train_time:48871ms step_avg:94.90ms +step:516/1695 train_time:48967ms step_avg:94.90ms +step:517/1695 train_time:49063ms step_avg:94.90ms +step:518/1695 train_time:49159ms step_avg:94.90ms +step:519/1695 train_time:49256ms step_avg:94.91ms +step:520/1695 train_time:49351ms step_avg:94.91ms +step:521/1695 train_time:49447ms step_avg:94.91ms +step:522/1695 train_time:49544ms step_avg:94.91ms +step:523/1695 train_time:49642ms step_avg:94.92ms +step:524/1695 train_time:49740ms step_avg:94.92ms +step:525/1695 train_time:49839ms step_avg:94.93ms +step:526/1695 train_time:49936ms step_avg:94.94ms +step:527/1695 train_time:50033ms step_avg:94.94ms +step:528/1695 train_time:50129ms step_avg:94.94ms +step:529/1695 train_time:50225ms step_avg:94.94ms +step:530/1695 train_time:50321ms step_avg:94.94ms +step:531/1695 train_time:50418ms step_avg:94.95ms +step:532/1695 train_time:50516ms step_avg:94.95ms +step:533/1695 train_time:50613ms step_avg:94.96ms +step:534/1695 train_time:50710ms step_avg:94.96ms +step:535/1695 train_time:50808ms step_avg:94.97ms +step:536/1695 train_time:50904ms step_avg:94.97ms +step:537/1695 train_time:51001ms step_avg:94.97ms +step:538/1695 train_time:51098ms step_avg:94.98ms +step:539/1695 train_time:51194ms step_avg:94.98ms +step:540/1695 train_time:51290ms step_avg:94.98ms +step:541/1695 train_time:51386ms step_avg:94.98ms +step:542/1695 train_time:51483ms step_avg:94.99ms +step:543/1695 train_time:51580ms step_avg:94.99ms +step:544/1695 train_time:51678ms step_avg:95.00ms +step:545/1695 train_time:51776ms step_avg:95.00ms +step:546/1695 train_time:51873ms step_avg:95.00ms +step:547/1695 train_time:51969ms step_avg:95.01ms +step:548/1695 train_time:52065ms step_avg:95.01ms +step:549/1695 train_time:52161ms step_avg:95.01ms +step:550/1695 train_time:52258ms step_avg:95.01ms +step:551/1695 train_time:52354ms step_avg:95.02ms +step:552/1695 train_time:52451ms step_avg:95.02ms +step:553/1695 train_time:52547ms step_avg:95.02ms +step:554/1695 train_time:52644ms step_avg:95.03ms +step:555/1695 train_time:52741ms step_avg:95.03ms +step:556/1695 train_time:52839ms step_avg:95.03ms +step:557/1695 train_time:52937ms step_avg:95.04ms +step:558/1695 train_time:53034ms step_avg:95.04ms +step:559/1695 train_time:53132ms step_avg:95.05ms +step:560/1695 train_time:53227ms step_avg:95.05ms +step:561/1695 train_time:53324ms step_avg:95.05ms +step:562/1695 train_time:53420ms step_avg:95.05ms +step:563/1695 train_time:53518ms 
step_avg:95.06ms +step:564/1695 train_time:53614ms step_avg:95.06ms +step:565/1695 train_time:53711ms step_avg:95.06ms +step:566/1695 train_time:53807ms step_avg:95.07ms +step:567/1695 train_time:53904ms step_avg:95.07ms +step:568/1695 train_time:54002ms step_avg:95.07ms +step:569/1695 train_time:54099ms step_avg:95.08ms +step:570/1695 train_time:54196ms step_avg:95.08ms +step:571/1695 train_time:54292ms step_avg:95.08ms +step:572/1695 train_time:54388ms step_avg:95.08ms +step:573/1695 train_time:54484ms step_avg:95.09ms +step:574/1695 train_time:54580ms step_avg:95.09ms +step:575/1695 train_time:54677ms step_avg:95.09ms +step:576/1695 train_time:54774ms step_avg:95.09ms +step:577/1695 train_time:54871ms step_avg:95.10ms +step:578/1695 train_time:54967ms step_avg:95.10ms +step:579/1695 train_time:55063ms step_avg:95.10ms +step:580/1695 train_time:55160ms step_avg:95.10ms +step:581/1695 train_time:55257ms step_avg:95.11ms +step:582/1695 train_time:55355ms step_avg:95.11ms +step:583/1695 train_time:55452ms step_avg:95.11ms +step:584/1695 train_time:55548ms step_avg:95.12ms +step:585/1695 train_time:55644ms step_avg:95.12ms +step:586/1695 train_time:55741ms step_avg:95.12ms +step:587/1695 train_time:55838ms step_avg:95.12ms +step:588/1695 train_time:55935ms step_avg:95.13ms +step:589/1695 train_time:56032ms step_avg:95.13ms +step:590/1695 train_time:56128ms step_avg:95.13ms +step:591/1695 train_time:56225ms step_avg:95.13ms +step:592/1695 train_time:56322ms step_avg:95.14ms +step:593/1695 train_time:56419ms step_avg:95.14ms +step:594/1695 train_time:56516ms step_avg:95.15ms +step:595/1695 train_time:56614ms step_avg:95.15ms +step:596/1695 train_time:56710ms step_avg:95.15ms +step:597/1695 train_time:56806ms step_avg:95.15ms +step:598/1695 train_time:56902ms step_avg:95.15ms +step:599/1695 train_time:56999ms step_avg:95.16ms +step:600/1695 train_time:57096ms step_avg:95.16ms +step:601/1695 train_time:57192ms step_avg:95.16ms +step:602/1695 train_time:57289ms step_avg:95.16ms +step:603/1695 train_time:57386ms step_avg:95.17ms +step:604/1695 train_time:57482ms step_avg:95.17ms +step:605/1695 train_time:57580ms step_avg:95.17ms +step:606/1695 train_time:57677ms step_avg:95.18ms +step:607/1695 train_time:57774ms step_avg:95.18ms +step:608/1695 train_time:57870ms step_avg:95.18ms +step:609/1695 train_time:57966ms step_avg:95.18ms +step:610/1695 train_time:58062ms step_avg:95.18ms +step:611/1695 train_time:58159ms step_avg:95.19ms +step:612/1695 train_time:58257ms step_avg:95.19ms +step:613/1695 train_time:58354ms step_avg:95.19ms +step:614/1695 train_time:58451ms step_avg:95.20ms +step:615/1695 train_time:58548ms step_avg:95.20ms +step:616/1695 train_time:58645ms step_avg:95.20ms +step:617/1695 train_time:58741ms step_avg:95.20ms +step:618/1695 train_time:58839ms step_avg:95.21ms +step:619/1695 train_time:58936ms step_avg:95.21ms +step:620/1695 train_time:59032ms step_avg:95.21ms +step:621/1695 train_time:59128ms step_avg:95.21ms +step:622/1695 train_time:59225ms step_avg:95.22ms +step:623/1695 train_time:59321ms step_avg:95.22ms +step:624/1695 train_time:59419ms step_avg:95.22ms +step:625/1695 train_time:59516ms step_avg:95.23ms +step:625/1695 val_loss:3.6445 train_time:59612ms step_avg:95.38ms +step:626/1695 train_time:59637ms step_avg:95.27ms +step:627/1695 train_time:59719ms step_avg:95.25ms +step:628/1695 train_time:59822ms step_avg:95.26ms +step:629/1695 train_time:60137ms step_avg:95.61ms +step:630/1695 train_time:60233ms step_avg:95.61ms +step:631/1695 train_time:60330ms step_avg:95.61ms 
+step:632/1695 train_time:60427ms step_avg:95.61ms +step:633/1695 train_time:60523ms step_avg:95.61ms +step:634/1695 train_time:60620ms step_avg:95.62ms +step:635/1695 train_time:60718ms step_avg:95.62ms +step:636/1695 train_time:60815ms step_avg:95.62ms +step:637/1695 train_time:60912ms step_avg:95.62ms +step:638/1695 train_time:61010ms step_avg:95.63ms +step:639/1695 train_time:61113ms step_avg:95.64ms +step:640/1695 train_time:61212ms step_avg:95.64ms +step:641/1695 train_time:61310ms step_avg:95.65ms +step:642/1695 train_time:61407ms step_avg:95.65ms +step:643/1695 train_time:61831ms step_avg:96.16ms +step:644/1695 train_time:61927ms step_avg:96.16ms +step:645/1695 train_time:62024ms step_avg:96.16ms +step:646/1695 train_time:62121ms step_avg:96.16ms +step:647/1695 train_time:62218ms step_avg:96.16ms +step:648/1695 train_time:62316ms step_avg:96.17ms +step:649/1695 train_time:62413ms step_avg:96.17ms +step:650/1695 train_time:62511ms step_avg:96.17ms +step:651/1695 train_time:62607ms step_avg:96.17ms +step:652/1695 train_time:62705ms step_avg:96.17ms +step:653/1695 train_time:63092ms step_avg:96.62ms +step:654/1695 train_time:63188ms step_avg:96.62ms +step:655/1695 train_time:63285ms step_avg:96.62ms +step:656/1695 train_time:63382ms step_avg:96.62ms +step:657/1695 train_time:63480ms step_avg:96.62ms +step:658/1695 train_time:63577ms step_avg:96.62ms +step:659/1695 train_time:63920ms step_avg:96.99ms +step:660/1695 train_time:64016ms step_avg:96.99ms +step:661/1695 train_time:64113ms step_avg:96.99ms +step:662/1695 train_time:64210ms step_avg:96.99ms +step:663/1695 train_time:64307ms step_avg:96.99ms +step:664/1695 train_time:64404ms step_avg:96.99ms +step:665/1695 train_time:64501ms step_avg:96.99ms +step:666/1695 train_time:64598ms step_avg:96.99ms +step:667/1695 train_time:64696ms step_avg:96.99ms +step:668/1695 train_time:64793ms step_avg:97.00ms +step:669/1695 train_time:64893ms step_avg:97.00ms +step:670/1695 train_time:64993ms step_avg:97.00ms +step:671/1695 train_time:65091ms step_avg:97.01ms +step:672/1695 train_time:65189ms step_avg:97.01ms +step:673/1695 train_time:65286ms step_avg:97.01ms +step:674/1695 train_time:65383ms step_avg:97.01ms +step:675/1695 train_time:65480ms step_avg:97.01ms +step:676/1695 train_time:65578ms step_avg:97.01ms +step:677/1695 train_time:65675ms step_avg:97.01ms +step:678/1695 train_time:65773ms step_avg:97.01ms +step:679/1695 train_time:65871ms step_avg:97.01ms +step:680/1695 train_time:65970ms step_avg:97.01ms +step:681/1695 train_time:66068ms step_avg:97.02ms +step:682/1695 train_time:66166ms step_avg:97.02ms +step:683/1695 train_time:66264ms step_avg:97.02ms +step:684/1695 train_time:66362ms step_avg:97.02ms +step:685/1695 train_time:66459ms step_avg:97.02ms +step:686/1695 train_time:66557ms step_avg:97.02ms +step:687/1695 train_time:66655ms step_avg:97.02ms +step:688/1695 train_time:66753ms step_avg:97.02ms +step:689/1695 train_time:66851ms step_avg:97.03ms +step:690/1695 train_time:66949ms step_avg:97.03ms +step:691/1695 train_time:67047ms step_avg:97.03ms +step:692/1695 train_time:67145ms step_avg:97.03ms +step:693/1695 train_time:67243ms step_avg:97.03ms +step:694/1695 train_time:67341ms step_avg:97.03ms +step:695/1695 train_time:67439ms step_avg:97.03ms +step:696/1695 train_time:67538ms step_avg:97.04ms +step:697/1695 train_time:67636ms step_avg:97.04ms +step:698/1695 train_time:67734ms step_avg:97.04ms +step:699/1695 train_time:67832ms step_avg:97.04ms +step:700/1695 train_time:67930ms step_avg:97.04ms +step:701/1695 train_time:68028ms 
step_avg:97.04ms +step:702/1695 train_time:68126ms step_avg:97.05ms +step:703/1695 train_time:68224ms step_avg:97.05ms +step:704/1695 train_time:68322ms step_avg:97.05ms +step:705/1695 train_time:68420ms step_avg:97.05ms +step:706/1695 train_time:68518ms step_avg:97.05ms +step:707/1695 train_time:68616ms step_avg:97.05ms +step:708/1695 train_time:68714ms step_avg:97.05ms +step:709/1695 train_time:68812ms step_avg:97.05ms +step:710/1695 train_time:68910ms step_avg:97.06ms +step:711/1695 train_time:69008ms step_avg:97.06ms +step:712/1695 train_time:69106ms step_avg:97.06ms +step:713/1695 train_time:69424ms step_avg:97.37ms +step:714/1695 train_time:69520ms step_avg:97.37ms +step:715/1695 train_time:69617ms step_avg:97.37ms +step:716/1695 train_time:69715ms step_avg:97.37ms +step:717/1695 train_time:69812ms step_avg:97.37ms +step:718/1695 train_time:69909ms step_avg:97.37ms +step:719/1695 train_time:70006ms step_avg:97.37ms +step:720/1695 train_time:70103ms step_avg:97.36ms +step:721/1695 train_time:70199ms step_avg:97.36ms +step:722/1695 train_time:70299ms step_avg:97.37ms +step:723/1695 train_time:70402ms step_avg:97.37ms +step:724/1695 train_time:70500ms step_avg:97.38ms +step:725/1695 train_time:70598ms step_avg:97.38ms +step:726/1695 train_time:70696ms step_avg:97.38ms +step:727/1695 train_time:70794ms step_avg:97.38ms +step:728/1695 train_time:70892ms step_avg:97.38ms +step:729/1695 train_time:70990ms step_avg:97.38ms +step:730/1695 train_time:71087ms step_avg:97.38ms +step:731/1695 train_time:71184ms step_avg:97.38ms +step:732/1695 train_time:71282ms step_avg:97.38ms +step:733/1695 train_time:71382ms step_avg:97.38ms +step:734/1695 train_time:71481ms step_avg:97.39ms +step:735/1695 train_time:71579ms step_avg:97.39ms +step:736/1695 train_time:71677ms step_avg:97.39ms +step:737/1695 train_time:71775ms step_avg:97.39ms +step:738/1695 train_time:71874ms step_avg:97.39ms +step:739/1695 train_time:71972ms step_avg:97.39ms +step:740/1695 train_time:72070ms step_avg:97.39ms +step:741/1695 train_time:72168ms step_avg:97.39ms +step:742/1695 train_time:72267ms step_avg:97.39ms +step:743/1695 train_time:72365ms step_avg:97.40ms +step:744/1695 train_time:72463ms step_avg:97.40ms +step:745/1695 train_time:72561ms step_avg:97.40ms +step:746/1695 train_time:72659ms step_avg:97.40ms +step:747/1695 train_time:72757ms step_avg:97.40ms +step:748/1695 train_time:72855ms step_avg:97.40ms +step:749/1695 train_time:72952ms step_avg:97.40ms +step:750/1695 train_time:73050ms step_avg:97.40ms +step:750/1695 val_loss:3.5832 train_time:73147ms step_avg:97.53ms +step:751/1695 train_time:73173ms step_avg:97.43ms +step:752/1695 train_time:73256ms step_avg:97.41ms +step:753/1695 train_time:73356ms step_avg:97.42ms +step:754/1695 train_time:73454ms step_avg:97.42ms +step:755/1695 train_time:73552ms step_avg:97.42ms +step:756/1695 train_time:73650ms step_avg:97.42ms +step:757/1695 train_time:73747ms step_avg:97.42ms +step:758/1695 train_time:73845ms step_avg:97.42ms +step:759/1695 train_time:73943ms step_avg:97.42ms +step:760/1695 train_time:74041ms step_avg:97.42ms +step:761/1695 train_time:74138ms step_avg:97.42ms +step:762/1695 train_time:74237ms step_avg:97.42ms +step:763/1695 train_time:74336ms step_avg:97.43ms +step:764/1695 train_time:74434ms step_avg:97.43ms +step:765/1695 train_time:74533ms step_avg:97.43ms +step:766/1695 train_time:74630ms step_avg:97.43ms +step:767/1695 train_time:74728ms step_avg:97.43ms +step:768/1695 train_time:74827ms step_avg:97.43ms +step:769/1695 train_time:74925ms step_avg:97.43ms 
+step:770/1695 train_time:75023ms step_avg:97.43ms +step:771/1695 train_time:75122ms step_avg:97.43ms +step:772/1695 train_time:75221ms step_avg:97.44ms +step:773/1695 train_time:75321ms step_avg:97.44ms +step:774/1695 train_time:75420ms step_avg:97.44ms +step:775/1695 train_time:75521ms step_avg:97.45ms +step:776/1695 train_time:75620ms step_avg:97.45ms +step:777/1695 train_time:75719ms step_avg:97.45ms +step:778/1695 train_time:75817ms step_avg:97.45ms +step:779/1695 train_time:75915ms step_avg:97.45ms +step:780/1695 train_time:76012ms step_avg:97.45ms +step:781/1695 train_time:76110ms step_avg:97.45ms +step:782/1695 train_time:76207ms step_avg:97.45ms +step:783/1695 train_time:76305ms step_avg:97.45ms +step:784/1695 train_time:76663ms step_avg:97.78ms +step:785/1695 train_time:76832ms step_avg:97.87ms +step:786/1695 train_time:76928ms step_avg:97.87ms +step:787/1695 train_time:77025ms step_avg:97.87ms +step:788/1695 train_time:77122ms step_avg:97.87ms +step:789/1695 train_time:77511ms step_avg:98.24ms +step:790/1695 train_time:77562ms step_avg:98.18ms +step:791/1695 train_time:77659ms step_avg:98.18ms +step:792/1695 train_time:77756ms step_avg:98.18ms +step:793/1695 train_time:77854ms step_avg:98.18ms +step:794/1695 train_time:77951ms step_avg:98.17ms +step:795/1695 train_time:78048ms step_avg:98.17ms +step:796/1695 train_time:78146ms step_avg:98.17ms +step:797/1695 train_time:78243ms step_avg:98.17ms +step:798/1695 train_time:78340ms step_avg:98.17ms +step:799/1695 train_time:78440ms step_avg:98.17ms +step:800/1695 train_time:78541ms step_avg:98.18ms +step:801/1695 train_time:78641ms step_avg:98.18ms +step:802/1695 train_time:78739ms step_avg:98.18ms +step:803/1695 train_time:78837ms step_avg:98.18ms +step:804/1695 train_time:78935ms step_avg:98.18ms +step:805/1695 train_time:79033ms step_avg:98.18ms +step:806/1695 train_time:79130ms step_avg:98.18ms +step:807/1695 train_time:79228ms step_avg:98.18ms +step:808/1695 train_time:79326ms step_avg:98.18ms +step:809/1695 train_time:79426ms step_avg:98.18ms +step:810/1695 train_time:79525ms step_avg:98.18ms +step:811/1695 train_time:79625ms step_avg:98.18ms +step:812/1695 train_time:79725ms step_avg:98.18ms +step:813/1695 train_time:79825ms step_avg:98.19ms +step:814/1695 train_time:79924ms step_avg:98.19ms +step:815/1695 train_time:80024ms step_avg:98.19ms +step:816/1695 train_time:80123ms step_avg:98.19ms +step:817/1695 train_time:80221ms step_avg:98.19ms +step:818/1695 train_time:80319ms step_avg:98.19ms +step:819/1695 train_time:80417ms step_avg:98.19ms +step:820/1695 train_time:80515ms step_avg:98.19ms +step:821/1695 train_time:80613ms step_avg:98.19ms +step:822/1695 train_time:80711ms step_avg:98.19ms +step:823/1695 train_time:80809ms step_avg:98.19ms +step:824/1695 train_time:80908ms step_avg:98.19ms +step:825/1695 train_time:81007ms step_avg:98.19ms +step:826/1695 train_time:81105ms step_avg:98.19ms +step:827/1695 train_time:81204ms step_avg:98.19ms +step:828/1695 train_time:81303ms step_avg:98.19ms +step:829/1695 train_time:81402ms step_avg:98.19ms +step:830/1695 train_time:81500ms step_avg:98.19ms +step:831/1695 train_time:81599ms step_avg:98.19ms +step:832/1695 train_time:81697ms step_avg:98.19ms +step:833/1695 train_time:81796ms step_avg:98.19ms +step:834/1695 train_time:81893ms step_avg:98.19ms +step:835/1695 train_time:81990ms step_avg:98.19ms +step:836/1695 train_time:82088ms step_avg:98.19ms +step:837/1695 train_time:82186ms step_avg:98.19ms +step:838/1695 train_time:82284ms step_avg:98.19ms +step:839/1695 train_time:82383ms 
step_avg:98.19ms +step:840/1695 train_time:82483ms step_avg:98.19ms +step:841/1695 train_time:82581ms step_avg:98.19ms +step:842/1695 train_time:82680ms step_avg:98.20ms +step:843/1695 train_time:82779ms step_avg:98.20ms +step:844/1695 train_time:82879ms step_avg:98.20ms +step:845/1695 train_time:82977ms step_avg:98.20ms +step:846/1695 train_time:83075ms step_avg:98.20ms +step:847/1695 train_time:83173ms step_avg:98.20ms +step:848/1695 train_time:83271ms step_avg:98.20ms +step:849/1695 train_time:83369ms step_avg:98.20ms +step:850/1695 train_time:83467ms step_avg:98.20ms +step:851/1695 train_time:83566ms step_avg:98.20ms +step:852/1695 train_time:83665ms step_avg:98.20ms +step:853/1695 train_time:83765ms step_avg:98.20ms +step:854/1695 train_time:83865ms step_avg:98.20ms +step:855/1695 train_time:83964ms step_avg:98.20ms +step:856/1695 train_time:84064ms step_avg:98.21ms +step:857/1695 train_time:84164ms step_avg:98.21ms +step:858/1695 train_time:84264ms step_avg:98.21ms +step:859/1695 train_time:84364ms step_avg:98.21ms +step:860/1695 train_time:84462ms step_avg:98.21ms +step:861/1695 train_time:84560ms step_avg:98.21ms +step:862/1695 train_time:84664ms step_avg:98.22ms +step:863/1695 train_time:84757ms step_avg:98.21ms +step:864/1695 train_time:84855ms step_avg:98.21ms +step:865/1695 train_time:84953ms step_avg:98.21ms +step:866/1695 train_time:85051ms step_avg:98.21ms +step:867/1695 train_time:85149ms step_avg:98.21ms +step:868/1695 train_time:85247ms step_avg:98.21ms +step:869/1695 train_time:85346ms step_avg:98.21ms +step:870/1695 train_time:85445ms step_avg:98.21ms +step:871/1695 train_time:85543ms step_avg:98.21ms +step:872/1695 train_time:85642ms step_avg:98.21ms +step:873/1695 train_time:85740ms step_avg:98.21ms +step:874/1695 train_time:85839ms step_avg:98.21ms +step:875/1695 train_time:85938ms step_avg:98.22ms +step:875/1695 val_loss:3.5364 train_time:86035ms step_avg:98.33ms +step:876/1695 train_time:86061ms step_avg:98.24ms +step:877/1695 train_time:86146ms step_avg:98.23ms +step:878/1695 train_time:86245ms step_avg:98.23ms +step:879/1695 train_time:86344ms step_avg:98.23ms +step:880/1695 train_time:86443ms step_avg:98.23ms +step:881/1695 train_time:86541ms step_avg:98.23ms +step:882/1695 train_time:86639ms step_avg:98.23ms +step:883/1695 train_time:86738ms step_avg:98.23ms +step:884/1695 train_time:86838ms step_avg:98.23ms +step:885/1695 train_time:86936ms step_avg:98.23ms +step:886/1695 train_time:87037ms step_avg:98.24ms +step:887/1695 train_time:87137ms step_avg:98.24ms +step:888/1695 train_time:87239ms step_avg:98.24ms +step:889/1695 train_time:87339ms step_avg:98.24ms +step:890/1695 train_time:87439ms step_avg:98.25ms +step:891/1695 train_time:87538ms step_avg:98.25ms +step:892/1695 train_time:87637ms step_avg:98.25ms +step:893/1695 train_time:87736ms step_avg:98.25ms +step:894/1695 train_time:87834ms step_avg:98.25ms +step:895/1695 train_time:87934ms step_avg:98.25ms +step:896/1695 train_time:88034ms step_avg:98.25ms +step:897/1695 train_time:88133ms step_avg:98.25ms +step:898/1695 train_time:88233ms step_avg:98.26ms +step:899/1695 train_time:88334ms step_avg:98.26ms +step:900/1695 train_time:88436ms step_avg:98.26ms +step:901/1695 train_time:88536ms step_avg:98.26ms +step:902/1695 train_time:88637ms step_avg:98.27ms +step:903/1695 train_time:88736ms step_avg:98.27ms +step:904/1695 train_time:88835ms step_avg:98.27ms +step:905/1695 train_time:88933ms step_avg:98.27ms +step:906/1695 train_time:89033ms step_avg:98.27ms +step:907/1695 train_time:89133ms step_avg:98.27ms 
+step:908/1695 train_time:89233ms step_avg:98.27ms +step:909/1695 train_time:89333ms step_avg:98.28ms +step:910/1695 train_time:89433ms step_avg:98.28ms +step:911/1695 train_time:89535ms step_avg:98.28ms +step:912/1695 train_time:89635ms step_avg:98.28ms +step:913/1695 train_time:89734ms step_avg:98.29ms +step:914/1695 train_time:89834ms step_avg:98.29ms +step:915/1695 train_time:89933ms step_avg:98.29ms +step:916/1695 train_time:90032ms step_avg:98.29ms +step:917/1695 train_time:90133ms step_avg:98.29ms +step:918/1695 train_time:90234ms step_avg:98.29ms +step:919/1695 train_time:90334ms step_avg:98.30ms +step:920/1695 train_time:90434ms step_avg:98.30ms +step:921/1695 train_time:90535ms step_avg:98.30ms +step:922/1695 train_time:90636ms step_avg:98.30ms +step:923/1695 train_time:90736ms step_avg:98.31ms +step:924/1695 train_time:90835ms step_avg:98.31ms +step:925/1695 train_time:90935ms step_avg:98.31ms +step:926/1695 train_time:91035ms step_avg:98.31ms +step:927/1695 train_time:91135ms step_avg:98.31ms +step:928/1695 train_time:91234ms step_avg:98.31ms +step:929/1695 train_time:91334ms step_avg:98.31ms +step:930/1695 train_time:91435ms step_avg:98.32ms +step:931/1695 train_time:91537ms step_avg:98.32ms +step:932/1695 train_time:91637ms step_avg:98.32ms +step:933/1695 train_time:91736ms step_avg:98.32ms +step:934/1695 train_time:91836ms step_avg:98.33ms +step:935/1695 train_time:91935ms step_avg:98.33ms +step:936/1695 train_time:92035ms step_avg:98.33ms +step:937/1695 train_time:92134ms step_avg:98.33ms +step:938/1695 train_time:92234ms step_avg:98.33ms +step:939/1695 train_time:92334ms step_avg:98.33ms +step:940/1695 train_time:92434ms step_avg:98.33ms +step:941/1695 train_time:92535ms step_avg:98.34ms +step:942/1695 train_time:92636ms step_avg:98.34ms +step:943/1695 train_time:92736ms step_avg:98.34ms +step:944/1695 train_time:92835ms step_avg:98.34ms +step:945/1695 train_time:92936ms step_avg:98.34ms +step:946/1695 train_time:93036ms step_avg:98.35ms +step:947/1695 train_time:93135ms step_avg:98.35ms +step:948/1695 train_time:93235ms step_avg:98.35ms +step:949/1695 train_time:93334ms step_avg:98.35ms +step:950/1695 train_time:93435ms step_avg:98.35ms +step:951/1695 train_time:93535ms step_avg:98.35ms +step:952/1695 train_time:93635ms step_avg:98.36ms +step:953/1695 train_time:93736ms step_avg:98.36ms +step:954/1695 train_time:93835ms step_avg:98.36ms +step:955/1695 train_time:93935ms step_avg:98.36ms +step:956/1695 train_time:94034ms step_avg:98.36ms +step:957/1695 train_time:94135ms step_avg:98.36ms +step:958/1695 train_time:94234ms step_avg:98.37ms +step:959/1695 train_time:94334ms step_avg:98.37ms +step:960/1695 train_time:94434ms step_avg:98.37ms +step:961/1695 train_time:94535ms step_avg:98.37ms +step:962/1695 train_time:94635ms step_avg:98.37ms +step:963/1695 train_time:94735ms step_avg:98.37ms +step:964/1695 train_time:94834ms step_avg:98.38ms +step:965/1695 train_time:94934ms step_avg:98.38ms +step:966/1695 train_time:95035ms step_avg:98.38ms +step:967/1695 train_time:95136ms step_avg:98.38ms +step:968/1695 train_time:95235ms step_avg:98.38ms +step:969/1695 train_time:95335ms step_avg:98.39ms +step:970/1695 train_time:95435ms step_avg:98.39ms +step:971/1695 train_time:95535ms step_avg:98.39ms +step:972/1695 train_time:95635ms step_avg:98.39ms +step:973/1695 train_time:95736ms step_avg:98.39ms +step:974/1695 train_time:95835ms step_avg:98.39ms +step:975/1695 train_time:95935ms step_avg:98.39ms +step:976/1695 train_time:96034ms step_avg:98.40ms +step:977/1695 train_time:96134ms 
step_avg:98.40ms +step:978/1695 train_time:96234ms step_avg:98.40ms +step:979/1695 train_time:96334ms step_avg:98.40ms +step:980/1695 train_time:96435ms step_avg:98.40ms +step:981/1695 train_time:96536ms step_avg:98.41ms +step:982/1695 train_time:96638ms step_avg:98.41ms +step:983/1695 train_time:96737ms step_avg:98.41ms +step:984/1695 train_time:96836ms step_avg:98.41ms +step:985/1695 train_time:96937ms step_avg:98.41ms +step:986/1695 train_time:97037ms step_avg:98.41ms +step:987/1695 train_time:97137ms step_avg:98.42ms +step:988/1695 train_time:97236ms step_avg:98.42ms +step:989/1695 train_time:97336ms step_avg:98.42ms +step:990/1695 train_time:97436ms step_avg:98.42ms +step:991/1695 train_time:97536ms step_avg:98.42ms +step:992/1695 train_time:97636ms step_avg:98.42ms +step:993/1695 train_time:97735ms step_avg:98.42ms +step:994/1695 train_time:97835ms step_avg:98.43ms +step:995/1695 train_time:97935ms step_avg:98.43ms +step:996/1695 train_time:98034ms step_avg:98.43ms +step:997/1695 train_time:98135ms step_avg:98.43ms +step:998/1695 train_time:98235ms step_avg:98.43ms +step:999/1695 train_time:98335ms step_avg:98.43ms +step:1000/1695 train_time:98434ms step_avg:98.43ms +step:1000/1695 val_loss:3.4915 train_time:98532ms step_avg:98.53ms +step:1001/1695 train_time:98558ms step_avg:98.46ms +step:1002/1695 train_time:98645ms step_avg:98.45ms +step:1003/1695 train_time:98749ms step_avg:98.45ms +step:1004/1695 train_time:98849ms step_avg:98.46ms +step:1005/1695 train_time:98949ms step_avg:98.46ms +step:1006/1695 train_time:99049ms step_avg:98.46ms +step:1007/1695 train_time:99149ms step_avg:98.46ms +step:1008/1695 train_time:99248ms step_avg:98.46ms +step:1009/1695 train_time:99347ms step_avg:98.46ms +step:1010/1695 train_time:99445ms step_avg:98.46ms +step:1011/1695 train_time:99547ms step_avg:98.46ms +step:1012/1695 train_time:99650ms step_avg:98.47ms +step:1013/1695 train_time:99752ms step_avg:98.47ms +step:1014/1695 train_time:99856ms step_avg:98.48ms +step:1015/1695 train_time:99955ms step_avg:98.48ms +step:1016/1695 train_time:100054ms step_avg:98.48ms +step:1017/1695 train_time:100154ms step_avg:98.48ms +step:1018/1695 train_time:100253ms step_avg:98.48ms +step:1019/1695 train_time:100352ms step_avg:98.48ms +step:1020/1695 train_time:100453ms step_avg:98.48ms +step:1021/1695 train_time:100555ms step_avg:98.49ms +step:1022/1695 train_time:100656ms step_avg:98.49ms +step:1023/1695 train_time:100756ms step_avg:98.49ms +step:1024/1695 train_time:100858ms step_avg:98.49ms +step:1025/1695 train_time:100958ms step_avg:98.50ms +step:1026/1695 train_time:101058ms step_avg:98.50ms +step:1027/1695 train_time:101156ms step_avg:98.50ms +step:1028/1695 train_time:101256ms step_avg:98.50ms +step:1029/1695 train_time:101355ms step_avg:98.50ms +step:1030/1695 train_time:101456ms step_avg:98.50ms +step:1031/1695 train_time:101556ms step_avg:98.50ms +step:1032/1695 train_time:101656ms step_avg:98.50ms +step:1033/1695 train_time:101757ms step_avg:98.51ms +step:1034/1695 train_time:101858ms step_avg:98.51ms +step:1035/1695 train_time:101957ms step_avg:98.51ms +step:1036/1695 train_time:102057ms step_avg:98.51ms +step:1037/1695 train_time:102157ms step_avg:98.51ms +step:1038/1695 train_time:102256ms step_avg:98.51ms +step:1039/1695 train_time:102355ms step_avg:98.51ms +step:1040/1695 train_time:102456ms step_avg:98.52ms +step:1041/1695 train_time:102556ms step_avg:98.52ms +step:1042/1695 train_time:102657ms step_avg:98.52ms +step:1043/1695 train_time:102757ms step_avg:98.52ms +step:1044/1695 
train_time:102856ms step_avg:98.52ms +step:1045/1695 train_time:102956ms step_avg:98.52ms +step:1046/1695 train_time:103056ms step_avg:98.52ms +step:1047/1695 train_time:103156ms step_avg:98.52ms +step:1048/1695 train_time:103255ms step_avg:98.53ms +step:1049/1695 train_time:103355ms step_avg:98.53ms +step:1050/1695 train_time:103454ms step_avg:98.53ms +step:1051/1695 train_time:103556ms step_avg:98.53ms +step:1052/1695 train_time:103655ms step_avg:98.53ms +step:1053/1695 train_time:103754ms step_avg:98.53ms +step:1054/1695 train_time:103855ms step_avg:98.53ms +step:1055/1695 train_time:103956ms step_avg:98.54ms +step:1056/1695 train_time:104055ms step_avg:98.54ms +step:1057/1695 train_time:104155ms step_avg:98.54ms +step:1058/1695 train_time:104254ms step_avg:98.54ms +step:1059/1695 train_time:104353ms step_avg:98.54ms +step:1060/1695 train_time:104452ms step_avg:98.54ms +step:1061/1695 train_time:104553ms step_avg:98.54ms +step:1062/1695 train_time:104654ms step_avg:98.54ms +step:1063/1695 train_time:104756ms step_avg:98.55ms +step:1064/1695 train_time:104856ms step_avg:98.55ms +step:1065/1695 train_time:104956ms step_avg:98.55ms +step:1066/1695 train_time:105056ms step_avg:98.55ms +step:1067/1695 train_time:105155ms step_avg:98.55ms +step:1068/1695 train_time:105255ms step_avg:98.55ms +step:1069/1695 train_time:105355ms step_avg:98.55ms +step:1070/1695 train_time:105454ms step_avg:98.56ms +step:1071/1695 train_time:105554ms step_avg:98.56ms +step:1072/1695 train_time:105655ms step_avg:98.56ms +step:1073/1695 train_time:105755ms step_avg:98.56ms +step:1074/1695 train_time:105855ms step_avg:98.56ms +step:1075/1695 train_time:105955ms step_avg:98.56ms +step:1076/1695 train_time:106055ms step_avg:98.56ms +step:1077/1695 train_time:106156ms step_avg:98.57ms +step:1078/1695 train_time:106255ms step_avg:98.57ms +step:1079/1695 train_time:106354ms step_avg:98.57ms +step:1080/1695 train_time:106453ms step_avg:98.57ms +step:1081/1695 train_time:106552ms step_avg:98.57ms +step:1082/1695 train_time:106652ms step_avg:98.57ms +step:1083/1695 train_time:106752ms step_avg:98.57ms +step:1084/1695 train_time:106852ms step_avg:98.57ms +step:1085/1695 train_time:106953ms step_avg:98.57ms +step:1086/1695 train_time:107053ms step_avg:98.58ms +step:1087/1695 train_time:107154ms step_avg:98.58ms +step:1088/1695 train_time:107254ms step_avg:98.58ms +step:1089/1695 train_time:107355ms step_avg:98.58ms +step:1090/1695 train_time:107455ms step_avg:98.58ms +step:1091/1695 train_time:107556ms step_avg:98.59ms +step:1092/1695 train_time:107655ms step_avg:98.59ms +step:1093/1695 train_time:107755ms step_avg:98.59ms +step:1094/1695 train_time:107855ms step_avg:98.59ms +step:1095/1695 train_time:107955ms step_avg:98.59ms +step:1096/1695 train_time:108054ms step_avg:98.59ms +step:1097/1695 train_time:108153ms step_avg:98.59ms +step:1098/1695 train_time:108254ms step_avg:98.59ms +step:1099/1695 train_time:108353ms step_avg:98.59ms +step:1100/1695 train_time:108454ms step_avg:98.59ms +step:1101/1695 train_time:108554ms step_avg:98.60ms +step:1102/1695 train_time:108654ms step_avg:98.60ms +step:1103/1695 train_time:108755ms step_avg:98.60ms +step:1104/1695 train_time:108855ms step_avg:98.60ms +step:1105/1695 train_time:108955ms step_avg:98.60ms +step:1106/1695 train_time:109056ms step_avg:98.60ms +step:1107/1695 train_time:109155ms step_avg:98.60ms +step:1108/1695 train_time:109255ms step_avg:98.61ms +step:1109/1695 train_time:109354ms step_avg:98.61ms +step:1110/1695 train_time:109454ms step_avg:98.61ms +step:1111/1695 
train_time:109555ms step_avg:98.61ms +step:1112/1695 train_time:109656ms step_avg:98.61ms +step:1113/1695 train_time:109756ms step_avg:98.61ms +step:1114/1695 train_time:109856ms step_avg:98.61ms +step:1115/1695 train_time:109956ms step_avg:98.62ms +step:1116/1695 train_time:110056ms step_avg:98.62ms +step:1117/1695 train_time:110155ms step_avg:98.62ms +step:1118/1695 train_time:110255ms step_avg:98.62ms +step:1119/1695 train_time:110354ms step_avg:98.62ms +step:1120/1695 train_time:110454ms step_avg:98.62ms +step:1121/1695 train_time:110554ms step_avg:98.62ms +step:1122/1695 train_time:110654ms step_avg:98.62ms +step:1123/1695 train_time:110754ms step_avg:98.62ms +step:1124/1695 train_time:110854ms step_avg:98.62ms +step:1125/1695 train_time:110955ms step_avg:98.63ms +step:1125/1695 val_loss:3.4410 train_time:111053ms step_avg:98.71ms +step:1126/1695 train_time:111079ms step_avg:98.65ms +step:1127/1695 train_time:111166ms step_avg:98.64ms +step:1128/1695 train_time:111269ms step_avg:98.64ms +step:1129/1695 train_time:111369ms step_avg:98.64ms +step:1130/1695 train_time:111468ms step_avg:98.64ms +step:1131/1695 train_time:111567ms step_avg:98.64ms +step:1132/1695 train_time:111666ms step_avg:98.64ms +step:1133/1695 train_time:111766ms step_avg:98.65ms +step:1134/1695 train_time:111866ms step_avg:98.65ms +step:1135/1695 train_time:111966ms step_avg:98.65ms +step:1136/1695 train_time:112068ms step_avg:98.65ms +step:1137/1695 train_time:112171ms step_avg:98.66ms +step:1138/1695 train_time:112272ms step_avg:98.66ms +step:1139/1695 train_time:112373ms step_avg:98.66ms +step:1140/1695 train_time:112473ms step_avg:98.66ms +step:1141/1695 train_time:112574ms step_avg:98.66ms +step:1142/1695 train_time:112675ms step_avg:98.66ms +step:1143/1695 train_time:112775ms step_avg:98.67ms +step:1144/1695 train_time:112877ms step_avg:98.67ms +step:1145/1695 train_time:112979ms step_avg:98.67ms +step:1146/1695 train_time:113080ms step_avg:98.67ms +step:1147/1695 train_time:113181ms step_avg:98.68ms +step:1148/1695 train_time:113283ms step_avg:98.68ms +step:1149/1695 train_time:113384ms step_avg:98.68ms +step:1150/1695 train_time:113486ms step_avg:98.68ms +step:1151/1695 train_time:113587ms step_avg:98.69ms +step:1152/1695 train_time:113688ms step_avg:98.69ms +step:1153/1695 train_time:113789ms step_avg:98.69ms +step:1154/1695 train_time:113889ms step_avg:98.69ms +step:1155/1695 train_time:113989ms step_avg:98.69ms +step:1156/1695 train_time:114090ms step_avg:98.69ms +step:1157/1695 train_time:114191ms step_avg:98.70ms +step:1158/1695 train_time:114291ms step_avg:98.70ms +step:1159/1695 train_time:114391ms step_avg:98.70ms +step:1160/1695 train_time:114491ms step_avg:98.70ms +step:1161/1695 train_time:114591ms step_avg:98.70ms +step:1162/1695 train_time:114692ms step_avg:98.70ms +step:1163/1695 train_time:114793ms step_avg:98.70ms +step:1164/1695 train_time:114893ms step_avg:98.71ms +step:1165/1695 train_time:114996ms step_avg:98.71ms +step:1166/1695 train_time:115098ms step_avg:98.71ms +step:1167/1695 train_time:115197ms step_avg:98.71ms +step:1168/1695 train_time:115299ms step_avg:98.72ms +step:1169/1695 train_time:115401ms step_avg:98.72ms +step:1170/1695 train_time:115502ms step_avg:98.72ms +step:1171/1695 train_time:115603ms step_avg:98.72ms +step:1172/1695 train_time:115705ms step_avg:98.72ms +step:1173/1695 train_time:115805ms step_avg:98.73ms +step:1174/1695 train_time:115907ms step_avg:98.73ms +step:1175/1695 train_time:116008ms step_avg:98.73ms +step:1176/1695 train_time:116109ms step_avg:98.73ms 
+step:1177/1695 train_time:116209ms step_avg:98.73ms +step:1178/1695 train_time:116310ms step_avg:98.74ms +step:1179/1695 train_time:116414ms step_avg:98.74ms +step:1180/1695 train_time:116513ms step_avg:98.74ms +step:1181/1695 train_time:116613ms step_avg:98.74ms +step:1182/1695 train_time:116714ms step_avg:98.74ms +step:1183/1695 train_time:116815ms step_avg:98.74ms +step:1184/1695 train_time:116918ms step_avg:98.75ms +step:1185/1695 train_time:117020ms step_avg:98.75ms +step:1186/1695 train_time:117121ms step_avg:98.75ms +step:1187/1695 train_time:117222ms step_avg:98.75ms +step:1188/1695 train_time:117323ms step_avg:98.76ms +step:1189/1695 train_time:117424ms step_avg:98.76ms +step:1190/1695 train_time:117525ms step_avg:98.76ms +step:1191/1695 train_time:117625ms step_avg:98.76ms +step:1192/1695 train_time:117726ms step_avg:98.76ms +step:1193/1695 train_time:117827ms step_avg:98.77ms +step:1194/1695 train_time:117928ms step_avg:98.77ms +step:1195/1695 train_time:118029ms step_avg:98.77ms +step:1196/1695 train_time:118131ms step_avg:98.77ms +step:1197/1695 train_time:118231ms step_avg:98.77ms +step:1198/1695 train_time:118331ms step_avg:98.77ms +step:1199/1695 train_time:118432ms step_avg:98.78ms +step:1200/1695 train_time:118531ms step_avg:98.78ms +step:1201/1695 train_time:118630ms step_avg:98.78ms +step:1202/1695 train_time:118731ms step_avg:98.78ms +step:1203/1695 train_time:118832ms step_avg:98.78ms +step:1204/1695 train_time:118933ms step_avg:98.78ms +step:1205/1695 train_time:119033ms step_avg:98.78ms +step:1206/1695 train_time:119135ms step_avg:98.79ms +step:1207/1695 train_time:119236ms step_avg:98.79ms +step:1208/1695 train_time:119338ms step_avg:98.79ms +step:1209/1695 train_time:119439ms step_avg:98.79ms +step:1210/1695 train_time:119540ms step_avg:98.79ms +step:1211/1695 train_time:119641ms step_avg:98.80ms +step:1212/1695 train_time:119743ms step_avg:98.80ms +step:1213/1695 train_time:119844ms step_avg:98.80ms +step:1214/1695 train_time:119945ms step_avg:98.80ms +step:1215/1695 train_time:120047ms step_avg:98.80ms +step:1216/1695 train_time:120149ms step_avg:98.81ms +step:1217/1695 train_time:120249ms step_avg:98.81ms +step:1218/1695 train_time:120350ms step_avg:98.81ms +step:1219/1695 train_time:120450ms step_avg:98.81ms +step:1220/1695 train_time:120551ms step_avg:98.81ms +step:1221/1695 train_time:120652ms step_avg:98.81ms +step:1222/1695 train_time:120752ms step_avg:98.82ms +step:1223/1695 train_time:120852ms step_avg:98.82ms +step:1224/1695 train_time:120952ms step_avg:98.82ms +step:1225/1695 train_time:121053ms step_avg:98.82ms +step:1226/1695 train_time:121153ms step_avg:98.82ms +step:1227/1695 train_time:121254ms step_avg:98.82ms +step:1228/1695 train_time:121355ms step_avg:98.82ms +step:1229/1695 train_time:121456ms step_avg:98.82ms +step:1230/1695 train_time:121556ms step_avg:98.83ms +step:1231/1695 train_time:121657ms step_avg:98.83ms +step:1232/1695 train_time:121758ms step_avg:98.83ms +step:1233/1695 train_time:121859ms step_avg:98.83ms +step:1234/1695 train_time:121962ms step_avg:98.83ms +step:1235/1695 train_time:122063ms step_avg:98.84ms +step:1236/1695 train_time:122166ms step_avg:98.84ms +step:1237/1695 train_time:122267ms step_avg:98.84ms +step:1238/1695 train_time:122369ms step_avg:98.84ms +step:1239/1695 train_time:122469ms step_avg:98.85ms +step:1240/1695 train_time:122570ms step_avg:98.85ms +step:1241/1695 train_time:122671ms step_avg:98.85ms +step:1242/1695 train_time:122771ms step_avg:98.85ms +step:1243/1695 train_time:122871ms step_avg:98.85ms 
+step:1244/1695 train_time:122971ms step_avg:98.85ms +step:1245/1695 train_time:123071ms step_avg:98.85ms +step:1246/1695 train_time:123172ms step_avg:98.85ms +step:1247/1695 train_time:123273ms step_avg:98.86ms +step:1248/1695 train_time:123373ms step_avg:98.86ms +step:1249/1695 train_time:123473ms step_avg:98.86ms +step:1250/1695 train_time:123573ms step_avg:98.86ms +step:1250/1695 val_loss:3.3958 train_time:123671ms step_avg:98.94ms +step:1251/1695 train_time:123697ms step_avg:98.88ms +step:1252/1695 train_time:123786ms step_avg:98.87ms +step:1253/1695 train_time:123887ms step_avg:98.87ms +step:1254/1695 train_time:123989ms step_avg:98.87ms +step:1255/1695 train_time:124090ms step_avg:98.88ms +step:1256/1695 train_time:124190ms step_avg:98.88ms +step:1257/1695 train_time:124289ms step_avg:98.88ms +step:1258/1695 train_time:124390ms step_avg:98.88ms +step:1259/1695 train_time:124490ms step_avg:98.88ms +step:1260/1695 train_time:124591ms step_avg:98.88ms +step:1261/1695 train_time:124693ms step_avg:98.88ms +step:1262/1695 train_time:124797ms step_avg:98.89ms +step:1263/1695 train_time:124897ms step_avg:98.89ms +step:1264/1695 train_time:124997ms step_avg:98.89ms +step:1265/1695 train_time:125097ms step_avg:98.89ms +step:1266/1695 train_time:125196ms step_avg:98.89ms +step:1267/1695 train_time:125296ms step_avg:98.89ms +step:1268/1695 train_time:125396ms step_avg:98.89ms +step:1269/1695 train_time:125497ms step_avg:98.89ms +step:1270/1695 train_time:125598ms step_avg:98.90ms +step:1271/1695 train_time:125699ms step_avg:98.90ms +step:1272/1695 train_time:125800ms step_avg:98.90ms +step:1273/1695 train_time:125902ms step_avg:98.90ms +step:1274/1695 train_time:126002ms step_avg:98.90ms +step:1275/1695 train_time:126104ms step_avg:98.91ms +step:1276/1695 train_time:126209ms step_avg:98.91ms +step:1277/1695 train_time:126310ms step_avg:98.91ms +step:1278/1695 train_time:126411ms step_avg:98.91ms +step:1279/1695 train_time:126512ms step_avg:98.91ms +step:1280/1695 train_time:126613ms step_avg:98.92ms +step:1281/1695 train_time:126714ms step_avg:98.92ms +step:1282/1695 train_time:126814ms step_avg:98.92ms +step:1283/1695 train_time:126915ms step_avg:98.92ms +step:1284/1695 train_time:127015ms step_avg:98.92ms +step:1285/1695 train_time:127115ms step_avg:98.92ms +step:1286/1695 train_time:127216ms step_avg:98.92ms +step:1287/1695 train_time:127318ms step_avg:98.93ms +step:1288/1695 train_time:127418ms step_avg:98.93ms +step:1289/1695 train_time:127518ms step_avg:98.93ms +step:1290/1695 train_time:127619ms step_avg:98.93ms +step:1291/1695 train_time:127720ms step_avg:98.93ms +step:1292/1695 train_time:127821ms step_avg:98.93ms +step:1293/1695 train_time:127924ms step_avg:98.94ms +step:1294/1695 train_time:128026ms step_avg:98.94ms +step:1295/1695 train_time:128128ms step_avg:98.94ms +step:1296/1695 train_time:128229ms step_avg:98.94ms +step:1297/1695 train_time:128330ms step_avg:98.94ms +step:1298/1695 train_time:128431ms step_avg:98.95ms +step:1299/1695 train_time:128532ms step_avg:98.95ms +step:1300/1695 train_time:128633ms step_avg:98.95ms +step:1301/1695 train_time:128734ms step_avg:98.95ms +step:1302/1695 train_time:128836ms step_avg:98.95ms +step:1303/1695 train_time:128936ms step_avg:98.95ms +step:1304/1695 train_time:129037ms step_avg:98.95ms +step:1305/1695 train_time:129137ms step_avg:98.96ms +step:1306/1695 train_time:129237ms step_avg:98.96ms +step:1307/1695 train_time:129338ms step_avg:98.96ms +step:1308/1695 train_time:129439ms step_avg:98.96ms +step:1309/1695 train_time:129539ms 
step_avg:98.96ms +step:1310/1695 train_time:129640ms step_avg:98.96ms +step:1311/1695 train_time:129742ms step_avg:98.96ms +step:1312/1695 train_time:129843ms step_avg:98.97ms +step:1313/1695 train_time:129946ms step_avg:98.97ms +step:1314/1695 train_time:130047ms step_avg:98.97ms +step:1315/1695 train_time:130147ms step_avg:98.97ms +step:1316/1695 train_time:130249ms step_avg:98.97ms +step:1317/1695 train_time:130351ms step_avg:98.98ms +step:1318/1695 train_time:130451ms step_avg:98.98ms +step:1319/1695 train_time:130553ms step_avg:98.98ms +step:1320/1695 train_time:130655ms step_avg:98.98ms +step:1321/1695 train_time:130756ms step_avg:98.98ms +step:1322/1695 train_time:130857ms step_avg:98.98ms +step:1323/1695 train_time:130956ms step_avg:98.98ms +step:1324/1695 train_time:131057ms step_avg:98.99ms +step:1325/1695 train_time:131158ms step_avg:98.99ms +step:1326/1695 train_time:131259ms step_avg:98.99ms +step:1327/1695 train_time:131361ms step_avg:98.99ms +step:1328/1695 train_time:131463ms step_avg:98.99ms +step:1329/1695 train_time:131564ms step_avg:98.99ms +step:1330/1695 train_time:131664ms step_avg:99.00ms +step:1331/1695 train_time:131766ms step_avg:99.00ms +step:1332/1695 train_time:131867ms step_avg:99.00ms +step:1333/1695 train_time:131969ms step_avg:99.00ms +step:1334/1695 train_time:132070ms step_avg:99.00ms +step:1335/1695 train_time:132172ms step_avg:99.01ms +step:1336/1695 train_time:132272ms step_avg:99.01ms +step:1337/1695 train_time:132374ms step_avg:99.01ms +step:1338/1695 train_time:132474ms step_avg:99.01ms +step:1339/1695 train_time:132575ms step_avg:99.01ms +step:1340/1695 train_time:132675ms step_avg:99.01ms +step:1341/1695 train_time:132777ms step_avg:99.01ms +step:1342/1695 train_time:132877ms step_avg:99.01ms +step:1343/1695 train_time:132978ms step_avg:99.02ms +step:1344/1695 train_time:133078ms step_avg:99.02ms +step:1345/1695 train_time:133179ms step_avg:99.02ms +step:1346/1695 train_time:133280ms step_avg:99.02ms +step:1347/1695 train_time:133382ms step_avg:99.02ms +step:1348/1695 train_time:133484ms step_avg:99.02ms +step:1349/1695 train_time:133586ms step_avg:99.03ms +step:1350/1695 train_time:133688ms step_avg:99.03ms +step:1351/1695 train_time:133789ms step_avg:99.03ms +step:1352/1695 train_time:133889ms step_avg:99.03ms +step:1353/1695 train_time:133990ms step_avg:99.03ms +step:1354/1695 train_time:134090ms step_avg:99.03ms +step:1355/1695 train_time:134190ms step_avg:99.03ms +step:1356/1695 train_time:134292ms step_avg:99.04ms +step:1357/1695 train_time:134393ms step_avg:99.04ms +step:1358/1695 train_time:134494ms step_avg:99.04ms +step:1359/1695 train_time:134595ms step_avg:99.04ms +step:1360/1695 train_time:134696ms step_avg:99.04ms +step:1361/1695 train_time:134796ms step_avg:99.04ms +step:1362/1695 train_time:134896ms step_avg:99.04ms +step:1363/1695 train_time:134997ms step_avg:99.04ms +step:1364/1695 train_time:135098ms step_avg:99.05ms +step:1365/1695 train_time:135199ms step_avg:99.05ms +step:1366/1695 train_time:135301ms step_avg:99.05ms +step:1367/1695 train_time:135403ms step_avg:99.05ms +step:1368/1695 train_time:135505ms step_avg:99.05ms +step:1369/1695 train_time:135607ms step_avg:99.06ms +step:1370/1695 train_time:135709ms step_avg:99.06ms +step:1371/1695 train_time:135809ms step_avg:99.06ms +step:1372/1695 train_time:135910ms step_avg:99.06ms +step:1373/1695 train_time:136011ms step_avg:99.06ms +step:1374/1695 train_time:136113ms step_avg:99.06ms +step:1375/1695 train_time:136215ms step_avg:99.07ms +step:1375/1695 val_loss:3.3554 
train_time:136313ms step_avg:99.14ms +step:1376/1695 train_time:136338ms step_avg:99.08ms +step:1377/1695 train_time:136429ms step_avg:99.08ms +step:1378/1695 train_time:136531ms step_avg:99.08ms +step:1379/1695 train_time:136632ms step_avg:99.08ms +step:1380/1695 train_time:136735ms step_avg:99.08ms +step:1381/1695 train_time:136835ms step_avg:99.08ms +step:1382/1695 train_time:136935ms step_avg:99.08ms +step:1383/1695 train_time:137035ms step_avg:99.09ms +step:1384/1695 train_time:137136ms step_avg:99.09ms +step:1385/1695 train_time:137237ms step_avg:99.09ms +step:1386/1695 train_time:137341ms step_avg:99.09ms +step:1387/1695 train_time:137443ms step_avg:99.09ms +step:1388/1695 train_time:137545ms step_avg:99.10ms +step:1389/1695 train_time:137647ms step_avg:99.10ms +step:1390/1695 train_time:137748ms step_avg:99.10ms +step:1391/1695 train_time:137849ms step_avg:99.10ms +step:1392/1695 train_time:137951ms step_avg:99.10ms +step:1393/1695 train_time:138053ms step_avg:99.10ms +step:1394/1695 train_time:138154ms step_avg:99.11ms +step:1395/1695 train_time:138255ms step_avg:99.11ms +step:1396/1695 train_time:138359ms step_avg:99.11ms +step:1397/1695 train_time:138462ms step_avg:99.11ms +step:1398/1695 train_time:138565ms step_avg:99.12ms +step:1399/1695 train_time:138667ms step_avg:99.12ms +step:1400/1695 train_time:138769ms step_avg:99.12ms +step:1401/1695 train_time:138870ms step_avg:99.12ms +step:1402/1695 train_time:138972ms step_avg:99.12ms +step:1403/1695 train_time:139075ms step_avg:99.13ms +step:1404/1695 train_time:139178ms step_avg:99.13ms +step:1405/1695 train_time:139279ms step_avg:99.13ms +step:1406/1695 train_time:139382ms step_avg:99.13ms +step:1407/1695 train_time:139483ms step_avg:99.14ms +step:1408/1695 train_time:139584ms step_avg:99.14ms +step:1409/1695 train_time:139688ms step_avg:99.14ms +step:1410/1695 train_time:139790ms step_avg:99.14ms +step:1411/1695 train_time:139891ms step_avg:99.14ms +step:1412/1695 train_time:139995ms step_avg:99.15ms +step:1413/1695 train_time:140096ms step_avg:99.15ms +step:1414/1695 train_time:140199ms step_avg:99.15ms +step:1415/1695 train_time:140302ms step_avg:99.15ms +step:1416/1695 train_time:140402ms step_avg:99.15ms +step:1417/1695 train_time:140503ms step_avg:99.15ms +step:1418/1695 train_time:140604ms step_avg:99.16ms +step:1419/1695 train_time:140705ms step_avg:99.16ms +step:1420/1695 train_time:140806ms step_avg:99.16ms +step:1421/1695 train_time:140908ms step_avg:99.16ms +step:1422/1695 train_time:141009ms step_avg:99.16ms +step:1423/1695 train_time:141111ms step_avg:99.16ms +step:1424/1695 train_time:141213ms step_avg:99.17ms +step:1425/1695 train_time:141316ms step_avg:99.17ms +step:1426/1695 train_time:141420ms step_avg:99.17ms +step:1427/1695 train_time:141521ms step_avg:99.17ms +step:1428/1695 train_time:141623ms step_avg:99.18ms +step:1429/1695 train_time:141724ms step_avg:99.18ms +step:1430/1695 train_time:141825ms step_avg:99.18ms +step:1431/1695 train_time:141927ms step_avg:99.18ms +step:1432/1695 train_time:142028ms step_avg:99.18ms +step:1433/1695 train_time:142130ms step_avg:99.18ms +step:1434/1695 train_time:142232ms step_avg:99.19ms +step:1435/1695 train_time:142334ms step_avg:99.19ms +step:1436/1695 train_time:142437ms step_avg:99.19ms +step:1437/1695 train_time:142538ms step_avg:99.19ms +step:1438/1695 train_time:142639ms step_avg:99.19ms +step:1439/1695 train_time:142742ms step_avg:99.20ms +step:1440/1695 train_time:142844ms step_avg:99.20ms +step:1441/1695 train_time:142947ms step_avg:99.20ms +step:1442/1695 
train_time:143048ms step_avg:99.20ms +step:1443/1695 train_time:143148ms step_avg:99.20ms +step:1444/1695 train_time:143251ms step_avg:99.20ms +step:1445/1695 train_time:143352ms step_avg:99.21ms +step:1446/1695 train_time:143454ms step_avg:99.21ms +step:1447/1695 train_time:143557ms step_avg:99.21ms +step:1448/1695 train_time:143662ms step_avg:99.21ms +step:1449/1695 train_time:143762ms step_avg:99.21ms +step:1450/1695 train_time:143864ms step_avg:99.22ms +step:1451/1695 train_time:143965ms step_avg:99.22ms +step:1452/1695 train_time:144067ms step_avg:99.22ms +step:1453/1695 train_time:144169ms step_avg:99.22ms +step:1454/1695 train_time:144272ms step_avg:99.22ms +step:1455/1695 train_time:144374ms step_avg:99.23ms +step:1456/1695 train_time:144476ms step_avg:99.23ms +step:1457/1695 train_time:144578ms step_avg:99.23ms +step:1458/1695 train_time:144680ms step_avg:99.23ms +step:1459/1695 train_time:144782ms step_avg:99.23ms +step:1460/1695 train_time:144882ms step_avg:99.23ms +step:1461/1695 train_time:144985ms step_avg:99.24ms +step:1462/1695 train_time:145086ms step_avg:99.24ms +step:1463/1695 train_time:145187ms step_avg:99.24ms +step:1464/1695 train_time:145288ms step_avg:99.24ms +step:1465/1695 train_time:145390ms step_avg:99.24ms +step:1466/1695 train_time:145494ms step_avg:99.25ms +step:1467/1695 train_time:145596ms step_avg:99.25ms +step:1468/1695 train_time:145699ms step_avg:99.25ms +step:1469/1695 train_time:145802ms step_avg:99.25ms +step:1470/1695 train_time:145902ms step_avg:99.25ms +step:1471/1695 train_time:146003ms step_avg:99.25ms +step:1472/1695 train_time:146104ms step_avg:99.26ms +step:1473/1695 train_time:146204ms step_avg:99.26ms +step:1474/1695 train_time:146306ms step_avg:99.26ms +step:1475/1695 train_time:146408ms step_avg:99.26ms +step:1476/1695 train_time:146510ms step_avg:99.26ms +step:1477/1695 train_time:146613ms step_avg:99.26ms +step:1478/1695 train_time:146716ms step_avg:99.27ms +step:1479/1695 train_time:146818ms step_avg:99.27ms +step:1480/1695 train_time:146919ms step_avg:99.27ms +step:1481/1695 train_time:147021ms step_avg:99.27ms +step:1482/1695 train_time:147122ms step_avg:99.27ms +step:1483/1695 train_time:147224ms step_avg:99.27ms +step:1484/1695 train_time:147326ms step_avg:99.28ms +step:1485/1695 train_time:147429ms step_avg:99.28ms +step:1486/1695 train_time:147530ms step_avg:99.28ms +step:1487/1695 train_time:147631ms step_avg:99.28ms +step:1488/1695 train_time:147734ms step_avg:99.28ms +step:1489/1695 train_time:147837ms step_avg:99.29ms +step:1490/1695 train_time:147941ms step_avg:99.29ms +step:1491/1695 train_time:148042ms step_avg:99.29ms +step:1492/1695 train_time:148143ms step_avg:99.29ms +step:1493/1695 train_time:148244ms step_avg:99.29ms +step:1494/1695 train_time:148345ms step_avg:99.29ms +step:1495/1695 train_time:148447ms step_avg:99.30ms +step:1496/1695 train_time:148548ms step_avg:99.30ms +step:1497/1695 train_time:148649ms step_avg:99.30ms +step:1498/1695 train_time:148752ms step_avg:99.30ms +step:1499/1695 train_time:148855ms step_avg:99.30ms +step:1500/1695 train_time:148957ms step_avg:99.30ms +step:1500/1695 val_loss:3.3207 train_time:149058ms step_avg:99.37ms +step:1501/1695 train_time:149083ms step_avg:99.32ms +step:1502/1695 train_time:149173ms step_avg:99.32ms +step:1503/1695 train_time:149276ms step_avg:99.32ms +step:1504/1695 train_time:149377ms step_avg:99.32ms +step:1505/1695 train_time:149478ms step_avg:99.32ms +step:1506/1695 train_time:149579ms step_avg:99.32ms +step:1507/1695 train_time:149679ms step_avg:99.32ms 
+step:1508/1695 train_time:149779ms step_avg:99.32ms +step:1509/1695 train_time:149881ms step_avg:99.32ms +step:1510/1695 train_time:149982ms step_avg:99.33ms +step:1511/1695 train_time:150086ms step_avg:99.33ms +step:1512/1695 train_time:150190ms step_avg:99.33ms +step:1513/1695 train_time:150292ms step_avg:99.33ms +step:1514/1695 train_time:150394ms step_avg:99.34ms +step:1515/1695 train_time:150499ms step_avg:99.34ms +step:1516/1695 train_time:150600ms step_avg:99.34ms +step:1517/1695 train_time:150700ms step_avg:99.34ms +step:1518/1695 train_time:150801ms step_avg:99.34ms +step:1519/1695 train_time:150904ms step_avg:99.34ms +step:1520/1695 train_time:151006ms step_avg:99.35ms +step:1521/1695 train_time:151108ms step_avg:99.35ms +step:1522/1695 train_time:151210ms step_avg:99.35ms +step:1523/1695 train_time:151313ms step_avg:99.35ms +step:1524/1695 train_time:151417ms step_avg:99.35ms +step:1525/1695 train_time:151520ms step_avg:99.36ms +step:1526/1695 train_time:151622ms step_avg:99.36ms +step:1527/1695 train_time:151724ms step_avg:99.36ms +step:1528/1695 train_time:151830ms step_avg:99.37ms +step:1529/1695 train_time:151931ms step_avg:99.37ms +step:1530/1695 train_time:152034ms step_avg:99.37ms +step:1531/1695 train_time:152135ms step_avg:99.37ms +step:1532/1695 train_time:152237ms step_avg:99.37ms +step:1533/1695 train_time:152338ms step_avg:99.37ms +step:1534/1695 train_time:152440ms step_avg:99.37ms +step:1535/1695 train_time:152543ms step_avg:99.38ms +step:1536/1695 train_time:152644ms step_avg:99.38ms +step:1537/1695 train_time:152746ms step_avg:99.38ms +step:1538/1695 train_time:152848ms step_avg:99.38ms +step:1539/1695 train_time:152950ms step_avg:99.38ms +step:1540/1695 train_time:153052ms step_avg:99.38ms +step:1541/1695 train_time:153156ms step_avg:99.39ms +step:1542/1695 train_time:153259ms step_avg:99.39ms +step:1543/1695 train_time:153362ms step_avg:99.39ms +step:1544/1695 train_time:153464ms step_avg:99.39ms +step:1545/1695 train_time:153566ms step_avg:99.40ms +step:1546/1695 train_time:153667ms step_avg:99.40ms +step:1547/1695 train_time:153769ms step_avg:99.40ms +step:1548/1695 train_time:153871ms step_avg:99.40ms +step:1549/1695 train_time:153974ms step_avg:99.40ms +step:1550/1695 train_time:154075ms step_avg:99.40ms +step:1551/1695 train_time:154178ms step_avg:99.41ms +step:1552/1695 train_time:154280ms step_avg:99.41ms +step:1553/1695 train_time:154382ms step_avg:99.41ms +step:1554/1695 train_time:154483ms step_avg:99.41ms +step:1555/1695 train_time:154585ms step_avg:99.41ms +step:1556/1695 train_time:154686ms step_avg:99.41ms +step:1557/1695 train_time:154790ms step_avg:99.42ms +step:1558/1695 train_time:154894ms step_avg:99.42ms +step:1559/1695 train_time:154997ms step_avg:99.42ms +step:1560/1695 train_time:155098ms step_avg:99.42ms +step:1561/1695 train_time:155199ms step_avg:99.42ms +step:1562/1695 train_time:155301ms step_avg:99.42ms +step:1563/1695 train_time:155406ms step_avg:99.43ms +step:1564/1695 train_time:155508ms step_avg:99.43ms +step:1565/1695 train_time:155609ms step_avg:99.43ms +step:1566/1695 train_time:155710ms step_avg:99.43ms +step:1567/1695 train_time:155812ms step_avg:99.43ms +step:1568/1695 train_time:155913ms step_avg:99.43ms +step:1569/1695 train_time:156014ms step_avg:99.44ms +step:1570/1695 train_time:156118ms step_avg:99.44ms +step:1571/1695 train_time:156219ms step_avg:99.44ms +step:1572/1695 train_time:156320ms step_avg:99.44ms +step:1573/1695 train_time:156422ms step_avg:99.44ms +step:1574/1695 train_time:156523ms step_avg:99.44ms 
+step:1575/1695 train_time:156625ms step_avg:99.44ms +step:1576/1695 train_time:156728ms step_avg:99.45ms +step:1577/1695 train_time:156832ms step_avg:99.45ms +step:1578/1695 train_time:156933ms step_avg:99.45ms +step:1579/1695 train_time:157035ms step_avg:99.45ms +step:1580/1695 train_time:157137ms step_avg:99.45ms +step:1581/1695 train_time:157238ms step_avg:99.45ms +step:1582/1695 train_time:157339ms step_avg:99.46ms +step:1583/1695 train_time:157442ms step_avg:99.46ms +step:1584/1695 train_time:157545ms step_avg:99.46ms +step:1585/1695 train_time:157647ms step_avg:99.46ms +step:1586/1695 train_time:157750ms step_avg:99.46ms +step:1587/1695 train_time:157852ms step_avg:99.47ms +step:1588/1695 train_time:157953ms step_avg:99.47ms +step:1589/1695 train_time:158055ms step_avg:99.47ms +step:1590/1695 train_time:158157ms step_avg:99.47ms +step:1591/1695 train_time:158258ms step_avg:99.47ms +step:1592/1695 train_time:158360ms step_avg:99.47ms +step:1593/1695 train_time:158461ms step_avg:99.47ms +step:1594/1695 train_time:158565ms step_avg:99.48ms +step:1595/1695 train_time:158668ms step_avg:99.48ms +step:1596/1695 train_time:158771ms step_avg:99.48ms +step:1597/1695 train_time:158873ms step_avg:99.48ms +step:1598/1695 train_time:158975ms step_avg:99.48ms +step:1599/1695 train_time:159076ms step_avg:99.48ms +step:1600/1695 train_time:159177ms step_avg:99.49ms +step:1601/1695 train_time:159279ms step_avg:99.49ms +step:1602/1695 train_time:159381ms step_avg:99.49ms +step:1603/1695 train_time:159482ms step_avg:99.49ms +step:1604/1695 train_time:159583ms step_avg:99.49ms +step:1605/1695 train_time:159685ms step_avg:99.49ms +step:1606/1695 train_time:159788ms step_avg:99.49ms +step:1607/1695 train_time:159890ms step_avg:99.50ms +step:1608/1695 train_time:159992ms step_avg:99.50ms +step:1609/1695 train_time:160093ms step_avg:99.50ms +step:1610/1695 train_time:160197ms step_avg:99.50ms +step:1611/1695 train_time:160298ms step_avg:99.50ms +step:1612/1695 train_time:160399ms step_avg:99.50ms +step:1613/1695 train_time:160500ms step_avg:99.50ms +step:1614/1695 train_time:160600ms step_avg:99.50ms +step:1615/1695 train_time:160701ms step_avg:99.51ms +step:1616/1695 train_time:160802ms step_avg:99.51ms +step:1617/1695 train_time:160905ms step_avg:99.51ms +step:1618/1695 train_time:161008ms step_avg:99.51ms +step:1619/1695 train_time:161111ms step_avg:99.51ms +step:1620/1695 train_time:161215ms step_avg:99.52ms +step:1621/1695 train_time:161316ms step_avg:99.52ms +step:1622/1695 train_time:161420ms step_avg:99.52ms +step:1623/1695 train_time:161522ms step_avg:99.52ms +step:1624/1695 train_time:161623ms step_avg:99.52ms +step:1625/1695 train_time:161727ms step_avg:99.52ms +step:1625/1695 val_loss:3.2913 train_time:161827ms step_avg:99.59ms +step:1626/1695 train_time:161853ms step_avg:99.54ms +step:1627/1695 train_time:161939ms step_avg:99.53ms +step:1628/1695 train_time:162041ms step_avg:99.53ms +step:1629/1695 train_time:162144ms step_avg:99.54ms +step:1630/1695 train_time:162245ms step_avg:99.54ms +step:1631/1695 train_time:162346ms step_avg:99.54ms +step:1632/1695 train_time:162448ms step_avg:99.54ms +step:1633/1695 train_time:162548ms step_avg:99.54ms +step:1634/1695 train_time:162651ms step_avg:99.54ms +step:1635/1695 train_time:162753ms step_avg:99.54ms +step:1636/1695 train_time:162856ms step_avg:99.55ms +step:1637/1695 train_time:162960ms step_avg:99.55ms +step:1638/1695 train_time:163064ms step_avg:99.55ms +step:1639/1695 train_time:163167ms step_avg:99.55ms +step:1640/1695 train_time:163268ms 
step_avg:99.55ms +step:1641/1695 train_time:163371ms step_avg:99.56ms +step:1642/1695 train_time:163473ms step_avg:99.56ms +step:1643/1695 train_time:163574ms step_avg:99.56ms +step:1644/1695 train_time:163676ms step_avg:99.56ms +step:1645/1695 train_time:163781ms step_avg:99.56ms +step:1646/1695 train_time:163884ms step_avg:99.57ms +step:1647/1695 train_time:163990ms step_avg:99.57ms +step:1648/1695 train_time:164093ms step_avg:99.57ms +step:1649/1695 train_time:164196ms step_avg:99.57ms +step:1650/1695 train_time:164298ms step_avg:99.57ms +step:1651/1695 train_time:164400ms step_avg:99.58ms +step:1652/1695 train_time:164504ms step_avg:99.58ms +step:1653/1695 train_time:164608ms step_avg:99.58ms +step:1654/1695 train_time:164710ms step_avg:99.58ms +step:1655/1695 train_time:164814ms step_avg:99.59ms +step:1656/1695 train_time:164917ms step_avg:99.59ms +step:1657/1695 train_time:165019ms step_avg:99.59ms +step:1658/1695 train_time:165122ms step_avg:99.59ms +step:1659/1695 train_time:165228ms step_avg:99.60ms +step:1660/1695 train_time:165330ms step_avg:99.60ms +step:1661/1695 train_time:165434ms step_avg:99.60ms +step:1662/1695 train_time:165539ms step_avg:99.60ms +step:1663/1695 train_time:165642ms step_avg:99.60ms +step:1664/1695 train_time:165745ms step_avg:99.61ms +step:1665/1695 train_time:165850ms step_avg:99.61ms +step:1666/1695 train_time:165953ms step_avg:99.61ms +step:1667/1695 train_time:166055ms step_avg:99.61ms +step:1668/1695 train_time:166159ms step_avg:99.62ms +step:1669/1695 train_time:166265ms step_avg:99.62ms +step:1670/1695 train_time:166367ms step_avg:99.62ms +step:1671/1695 train_time:166469ms step_avg:99.62ms +step:1672/1695 train_time:166572ms step_avg:99.62ms +step:1673/1695 train_time:166674ms step_avg:99.63ms +step:1674/1695 train_time:166775ms step_avg:99.63ms +step:1675/1695 train_time:166878ms step_avg:99.63ms +step:1676/1695 train_time:166984ms step_avg:99.63ms +step:1677/1695 train_time:167086ms step_avg:99.63ms +step:1678/1695 train_time:167189ms step_avg:99.64ms +step:1679/1695 train_time:167293ms step_avg:99.64ms +step:1680/1695 train_time:167394ms step_avg:99.64ms +step:1681/1695 train_time:167496ms step_avg:99.64ms +step:1682/1695 train_time:167602ms step_avg:99.64ms +step:1683/1695 train_time:167705ms step_avg:99.65ms +step:1684/1695 train_time:167809ms step_avg:99.65ms +step:1685/1695 train_time:167911ms step_avg:99.65ms +step:1686/1695 train_time:168013ms step_avg:99.65ms +step:1687/1695 train_time:168115ms step_avg:99.65ms +step:1688/1695 train_time:168218ms step_avg:99.66ms +step:1689/1695 train_time:168319ms step_avg:99.66ms +step:1690/1695 train_time:168422ms step_avg:99.66ms +step:1691/1695 train_time:168525ms step_avg:99.66ms +step:1692/1695 train_time:168628ms step_avg:99.66ms +step:1693/1695 train_time:168731ms step_avg:99.66ms +step:1694/1695 train_time:168834ms step_avg:99.67ms +step:1695/1695 train_time:168938ms step_avg:99.67ms +step:1695/1695 val_loss:3.2782 train_time:169037ms step_avg:99.73ms +peak memory allocated: 34005 MiB reserved: 49660 MiB diff --git a/records/082325_SparseAttnGate/c6be54c1-12d0-45a3-83cb-41cad0868d15.txt b/records/082325_SparseAttnGate/c6be54c1-12d0-45a3-83cb-41cad0868d15.txt new file mode 100644 index 000000000..d1010d1e1 --- /dev/null +++ b/records/082325_SparseAttnGate/c6be54c1-12d0-45a3-83cb-41cad0868d15.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
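# the A.T tile pointer advances by the same column stride, staying in lockstep along K +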
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
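+ 
+ A minimal usage sketch, assuming torch.distributed is already initialized and
+ `matrix_params` is a hypothetical list of the network's 2D weight matrices:
+ 
+ opt = Muon(matrix_params, lr=0.02, weight_decay=0.01, momentum=0.95)
+ loss.backward()
+ opt.step() # reduce-scatter grads, orthogonalize each update, all-gather params
+ opt.zero_grad(set_to_none=True)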
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
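+        # note: the lr_mul (and wd_mul) attributes attached to parameters here are not used by
+        # the modules themselves; they are read back via getattr(p, "lr_mul", 1.0) in Muon.step
+        # and DistAdam.step above, scaling the effective learning rate and weight decay per parameter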
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
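+
+    # editorial note: build_bm(W) above yields a document-aware causal mask restricted to a
+    # sliding window of W 128-token blocks; forward() below interleaves the full-window ("long")
+    # mask with the half-window ("short") mask across layers, following the Gemma 2 pattern cited above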
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len
+        if pos + token_window + 1 >= len(tokens):
+            tokens = _load_data_shard(next(file_iter))
+            pos = 0
+        for _ in range(grad_accum_steps):
+            if align_to_bos:
+                batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window)
+                start_idx = batch_starts[rank]
+            else:
+                tokens_consumed = batch_size
+                start_idx = pos + rank * seq_len
+            buf = tokens[start_idx:][:seq_len + 1]
+            inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
+            targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
+            pos += tokens_consumed
+            token_window -= tokens_consumed
+            yield inputs, targets
+
+# -----------------------------------------------------------------------------
+# int main
+
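+# NOTE: the Hyperparameters definition referenced as `args` throughout this file is missing from
+# this record. A minimal reconstruction follows: num_iterations and val_loss_every are read off
+# the log below; every other value is an assumption modeled on other modded-nanogpt records and
+# is not necessarily what this run used.
+@dataclass
+class Hyperparameters:
+    # data (assumed shard patterns, joined with DATA_PATH below)
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # assumed
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # assumed
+    val_tokens: int = 10485760  # assumed
+    train_seq_len: int = 48 * 1024  # assumed
+    val_seq_len: int = 4 * 64 * 1024  # assumed
+    # optimization
+    num_iterations: int = 1695  # matches step:{i}/1695 in the log below
+    cooldown_frac: float = 0.45  # assumed
+    dampen_factor: int = 4  # assumed; divisor for the attention-gate input width
+    # evaluation and logging
+    val_loss_every: int = 125  # matches the validation cadence in the log below
+    save_checkpoint: bool = False  # assumed
+    run_id: str = str(uuid.uuid4())  # assumed
+args = Hyperparameters()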
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:56:00 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 323795 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 323796 C /usr/bin/python3 614MiB | +| 0 N/A N/A 323797 C /usr/bin/python3 614MiB | +| 0 N/A N/A 323798 C /usr/bin/python3 614MiB | +| 0 N/A N/A 323799 C /usr/bin/python3 614MiB | +| 0 N/A N/A 323800 C /usr/bin/python3 614MiB | +| 0 N/A N/A 323801 C /usr/bin/python3 614MiB | +| 0 N/A N/A 323802 C /usr/bin/python3 614MiB | +| 1 N/A N/A 323796 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 323797 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 323798 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 323799 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 323800 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 323801 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 323802 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:154ms step_avg:153.64ms +step:2/1695 train_time:180ms step_avg:90.19ms +step:3/1695 train_time:250ms step_avg:83.26ms +step:4/1695 train_time:342ms step_avg:85.41ms +step:5/1695 train_time:434ms step_avg:86.81ms +step:6/1695 train_time:526ms step_avg:87.73ms +step:7/1695 train_time:620ms step_avg:88.53ms +step:8/1695 
train_time:712ms step_avg:89.01ms +step:9/1695 train_time:805ms step_avg:89.45ms +step:10/1695 train_time:898ms step_avg:89.83ms +step:11/1695 train_time:991ms step_avg:90.09ms +step:12/1695 train_time:1087ms step_avg:90.60ms +step:13/1695 train_time:1185ms step_avg:91.12ms +step:14/1695 train_time:1280ms step_avg:91.40ms +step:15/1695 train_time:1373ms step_avg:91.53ms +step:16/1695 train_time:1466ms step_avg:91.65ms +step:17/1695 train_time:1561ms step_avg:91.82ms +step:18/1695 train_time:1653ms step_avg:91.83ms +step:19/1695 train_time:1746ms step_avg:91.87ms +step:20/1695 train_time:1839ms step_avg:91.94ms +step:21/1695 train_time:1932ms step_avg:92.00ms +step:22/1695 train_time:2025ms step_avg:92.06ms +step:23/1695 train_time:2120ms step_avg:92.17ms +step:24/1695 train_time:2215ms step_avg:92.27ms +step:25/1695 train_time:2310ms step_avg:92.40ms +step:26/1695 train_time:2404ms step_avg:92.46ms +step:27/1695 train_time:2498ms step_avg:92.53ms +step:28/1695 train_time:2592ms step_avg:92.57ms +step:29/1695 train_time:2686ms step_avg:92.63ms +step:30/1695 train_time:2779ms step_avg:92.64ms +step:31/1695 train_time:2872ms step_avg:92.65ms +step:32/1695 train_time:2966ms step_avg:92.68ms +step:33/1695 train_time:3061ms step_avg:92.75ms +step:34/1695 train_time:3154ms step_avg:92.77ms +step:35/1695 train_time:3248ms step_avg:92.80ms +step:36/1695 train_time:3342ms step_avg:92.83ms +step:37/1695 train_time:3435ms step_avg:92.85ms +step:38/1695 train_time:3529ms step_avg:92.88ms +step:39/1695 train_time:3624ms step_avg:92.92ms +step:40/1695 train_time:3717ms step_avg:92.93ms +step:41/1695 train_time:3811ms step_avg:92.94ms +step:42/1695 train_time:3905ms step_avg:92.97ms +step:43/1695 train_time:3998ms step_avg:92.98ms +step:44/1695 train_time:4093ms step_avg:93.02ms +step:45/1695 train_time:4186ms step_avg:93.02ms +step:46/1695 train_time:4280ms step_avg:93.04ms +step:47/1695 train_time:4374ms step_avg:93.06ms +step:48/1695 train_time:4468ms step_avg:93.09ms +step:49/1695 train_time:4562ms step_avg:93.10ms +step:50/1695 train_time:4655ms step_avg:93.10ms +step:51/1695 train_time:4749ms step_avg:93.12ms +step:52/1695 train_time:4843ms step_avg:93.14ms +step:53/1695 train_time:4937ms step_avg:93.16ms +step:54/1695 train_time:5031ms step_avg:93.17ms +step:55/1695 train_time:5125ms step_avg:93.19ms +step:56/1695 train_time:5219ms step_avg:93.19ms +step:57/1695 train_time:5312ms step_avg:93.20ms +step:58/1695 train_time:5405ms step_avg:93.20ms +step:59/1695 train_time:5500ms step_avg:93.21ms +step:60/1695 train_time:5593ms step_avg:93.22ms +step:61/1695 train_time:5687ms step_avg:93.24ms +step:62/1695 train_time:5782ms step_avg:93.26ms +step:63/1695 train_time:5876ms step_avg:93.27ms +step:64/1695 train_time:5970ms step_avg:93.29ms +step:65/1695 train_time:6063ms step_avg:93.28ms +step:66/1695 train_time:6157ms step_avg:93.28ms +step:67/1695 train_time:6250ms step_avg:93.29ms +step:68/1695 train_time:6344ms step_avg:93.30ms +step:69/1695 train_time:6438ms step_avg:93.30ms +step:70/1695 train_time:6532ms step_avg:93.31ms +step:71/1695 train_time:6625ms step_avg:93.30ms +step:72/1695 train_time:6718ms step_avg:93.31ms +step:73/1695 train_time:6812ms step_avg:93.32ms +step:74/1695 train_time:6906ms step_avg:93.33ms +step:75/1695 train_time:7000ms step_avg:93.33ms +step:76/1695 train_time:7093ms step_avg:93.33ms +step:77/1695 train_time:7186ms step_avg:93.33ms +step:78/1695 train_time:7280ms step_avg:93.33ms +step:79/1695 train_time:7374ms step_avg:93.34ms +step:80/1695 train_time:7467ms 
step_avg:93.33ms +step:81/1695 train_time:7562ms step_avg:93.35ms +step:82/1695 train_time:7655ms step_avg:93.35ms +step:83/1695 train_time:7749ms step_avg:93.36ms +step:84/1695 train_time:7843ms step_avg:93.36ms +step:85/1695 train_time:7936ms step_avg:93.36ms +step:86/1695 train_time:8029ms step_avg:93.36ms +step:87/1695 train_time:8123ms step_avg:93.37ms +step:88/1695 train_time:8216ms step_avg:93.37ms +step:89/1695 train_time:8309ms step_avg:93.37ms +step:90/1695 train_time:8403ms step_avg:93.37ms +step:91/1695 train_time:8497ms step_avg:93.37ms +step:92/1695 train_time:8590ms step_avg:93.37ms +step:93/1695 train_time:8685ms step_avg:93.38ms +step:94/1695 train_time:8779ms step_avg:93.39ms +step:95/1695 train_time:8872ms step_avg:93.39ms +step:96/1695 train_time:8966ms step_avg:93.40ms +step:97/1695 train_time:9060ms step_avg:93.40ms +step:98/1695 train_time:9152ms step_avg:93.39ms +step:99/1695 train_time:9246ms step_avg:93.39ms +step:100/1695 train_time:9340ms step_avg:93.40ms +step:101/1695 train_time:9433ms step_avg:93.40ms +step:102/1695 train_time:9527ms step_avg:93.40ms +step:103/1695 train_time:9621ms step_avg:93.40ms +step:104/1695 train_time:9714ms step_avg:93.40ms +step:105/1695 train_time:9808ms step_avg:93.41ms +step:106/1695 train_time:9902ms step_avg:93.41ms +step:107/1695 train_time:9996ms step_avg:93.42ms +step:108/1695 train_time:10090ms step_avg:93.43ms +step:109/1695 train_time:10185ms step_avg:93.44ms +step:110/1695 train_time:10278ms step_avg:93.43ms +step:111/1695 train_time:10371ms step_avg:93.43ms +step:112/1695 train_time:10464ms step_avg:93.43ms +step:113/1695 train_time:10557ms step_avg:93.43ms +step:114/1695 train_time:10651ms step_avg:93.43ms +step:115/1695 train_time:10744ms step_avg:93.42ms +step:116/1695 train_time:10837ms step_avg:93.42ms +step:117/1695 train_time:10931ms step_avg:93.43ms +step:118/1695 train_time:11025ms step_avg:93.44ms +step:119/1695 train_time:11119ms step_avg:93.43ms +step:120/1695 train_time:11212ms step_avg:93.44ms +step:121/1695 train_time:11306ms step_avg:93.44ms +step:122/1695 train_time:11400ms step_avg:93.44ms +step:123/1695 train_time:11493ms step_avg:93.44ms +step:124/1695 train_time:11587ms step_avg:93.44ms +step:125/1695 train_time:11680ms step_avg:93.44ms +step:125/1695 val_loss:4.6007 train_time:11772ms step_avg:94.18ms +step:126/1695 train_time:11800ms step_avg:93.65ms +step:127/1695 train_time:11876ms step_avg:93.52ms +step:128/1695 train_time:11979ms step_avg:93.59ms +step:129/1695 train_time:12076ms step_avg:93.61ms +step:130/1695 train_time:12170ms step_avg:93.62ms +step:131/1695 train_time:12263ms step_avg:93.61ms +step:132/1695 train_time:12357ms step_avg:93.61ms +step:133/1695 train_time:12450ms step_avg:93.61ms +step:134/1695 train_time:12543ms step_avg:93.61ms +step:135/1695 train_time:12637ms step_avg:93.60ms +step:136/1695 train_time:12731ms step_avg:93.61ms +step:137/1695 train_time:12827ms step_avg:93.63ms +step:138/1695 train_time:12924ms step_avg:93.65ms +step:139/1695 train_time:13021ms step_avg:93.67ms +step:140/1695 train_time:13116ms step_avg:93.69ms +step:141/1695 train_time:13211ms step_avg:93.70ms +step:142/1695 train_time:13305ms step_avg:93.70ms +step:143/1695 train_time:13399ms step_avg:93.70ms +step:144/1695 train_time:13492ms step_avg:93.70ms +step:145/1695 train_time:13586ms step_avg:93.70ms +step:146/1695 train_time:13679ms step_avg:93.69ms +step:147/1695 train_time:13773ms step_avg:93.70ms +step:148/1695 train_time:13868ms step_avg:93.70ms +step:149/1695 train_time:13962ms 
step_avg:93.71ms +step:150/1695 train_time:14058ms step_avg:93.72ms +step:151/1695 train_time:14153ms step_avg:93.73ms +step:152/1695 train_time:14247ms step_avg:93.73ms +step:153/1695 train_time:14341ms step_avg:93.73ms +step:154/1695 train_time:14435ms step_avg:93.74ms +step:155/1695 train_time:14529ms step_avg:93.74ms +step:156/1695 train_time:14623ms step_avg:93.73ms +step:157/1695 train_time:14717ms step_avg:93.74ms +step:158/1695 train_time:14811ms step_avg:93.74ms +step:159/1695 train_time:14907ms step_avg:93.75ms +step:160/1695 train_time:15000ms step_avg:93.75ms +step:161/1695 train_time:15094ms step_avg:93.75ms +step:162/1695 train_time:15188ms step_avg:93.76ms +step:163/1695 train_time:15283ms step_avg:93.76ms +step:164/1695 train_time:15377ms step_avg:93.76ms +step:165/1695 train_time:15471ms step_avg:93.76ms +step:166/1695 train_time:15565ms step_avg:93.76ms +step:167/1695 train_time:15658ms step_avg:93.76ms +step:168/1695 train_time:15752ms step_avg:93.76ms +step:169/1695 train_time:15847ms step_avg:93.77ms +step:170/1695 train_time:15941ms step_avg:93.77ms +step:171/1695 train_time:16034ms step_avg:93.77ms +step:172/1695 train_time:16129ms step_avg:93.77ms +step:173/1695 train_time:16222ms step_avg:93.77ms +step:174/1695 train_time:16316ms step_avg:93.77ms +step:175/1695 train_time:16411ms step_avg:93.78ms +step:176/1695 train_time:16504ms step_avg:93.77ms +step:177/1695 train_time:16599ms step_avg:93.78ms +step:178/1695 train_time:16693ms step_avg:93.78ms +step:179/1695 train_time:16787ms step_avg:93.78ms +step:180/1695 train_time:16881ms step_avg:93.78ms +step:181/1695 train_time:16974ms step_avg:93.78ms +step:182/1695 train_time:17068ms step_avg:93.78ms +step:183/1695 train_time:17161ms step_avg:93.78ms +step:184/1695 train_time:17255ms step_avg:93.78ms +step:185/1695 train_time:17349ms step_avg:93.78ms +step:186/1695 train_time:17443ms step_avg:93.78ms +step:187/1695 train_time:17538ms step_avg:93.78ms +step:188/1695 train_time:17632ms step_avg:93.78ms +step:189/1695 train_time:17726ms step_avg:93.79ms +step:190/1695 train_time:17820ms step_avg:93.79ms +step:191/1695 train_time:17915ms step_avg:93.79ms +step:192/1695 train_time:18009ms step_avg:93.80ms +step:193/1695 train_time:18102ms step_avg:93.79ms +step:194/1695 train_time:18197ms step_avg:93.80ms +step:195/1695 train_time:18290ms step_avg:93.80ms +step:196/1695 train_time:18384ms step_avg:93.80ms +step:197/1695 train_time:18478ms step_avg:93.80ms +step:198/1695 train_time:18572ms step_avg:93.80ms +step:199/1695 train_time:18667ms step_avg:93.80ms +step:200/1695 train_time:18761ms step_avg:93.80ms +step:201/1695 train_time:18854ms step_avg:93.80ms +step:202/1695 train_time:18949ms step_avg:93.81ms +step:203/1695 train_time:19043ms step_avg:93.81ms +step:204/1695 train_time:19137ms step_avg:93.81ms +step:205/1695 train_time:19231ms step_avg:93.81ms +step:206/1695 train_time:19325ms step_avg:93.81ms +step:207/1695 train_time:19419ms step_avg:93.81ms +step:208/1695 train_time:19514ms step_avg:93.82ms +step:209/1695 train_time:19607ms step_avg:93.81ms +step:210/1695 train_time:19701ms step_avg:93.81ms +step:211/1695 train_time:19796ms step_avg:93.82ms +step:212/1695 train_time:19889ms step_avg:93.82ms +step:213/1695 train_time:19984ms step_avg:93.82ms +step:214/1695 train_time:20078ms step_avg:93.82ms +step:215/1695 train_time:20173ms step_avg:93.83ms +step:216/1695 train_time:20267ms step_avg:93.83ms +step:217/1695 train_time:20361ms step_avg:93.83ms +step:218/1695 train_time:20455ms step_avg:93.83ms +step:219/1695 
train_time:20548ms step_avg:93.83ms +step:220/1695 train_time:20642ms step_avg:93.83ms +step:221/1695 train_time:20737ms step_avg:93.83ms +step:222/1695 train_time:20830ms step_avg:93.83ms +step:223/1695 train_time:20924ms step_avg:93.83ms +step:224/1695 train_time:21019ms step_avg:93.83ms +step:225/1695 train_time:21113ms step_avg:93.83ms +step:226/1695 train_time:21207ms step_avg:93.83ms +step:227/1695 train_time:21301ms step_avg:93.84ms +step:228/1695 train_time:21394ms step_avg:93.83ms +step:229/1695 train_time:21489ms step_avg:93.84ms +step:230/1695 train_time:21583ms step_avg:93.84ms +step:231/1695 train_time:21676ms step_avg:93.84ms +step:232/1695 train_time:21770ms step_avg:93.84ms +step:233/1695 train_time:21864ms step_avg:93.84ms +step:234/1695 train_time:21958ms step_avg:93.84ms +step:235/1695 train_time:22052ms step_avg:93.84ms +step:236/1695 train_time:22147ms step_avg:93.84ms +step:237/1695 train_time:22241ms step_avg:93.84ms +step:238/1695 train_time:22336ms step_avg:93.85ms +step:239/1695 train_time:22430ms step_avg:93.85ms +step:240/1695 train_time:22525ms step_avg:93.85ms +step:241/1695 train_time:22619ms step_avg:93.86ms +step:242/1695 train_time:22714ms step_avg:93.86ms +step:243/1695 train_time:22808ms step_avg:93.86ms +step:244/1695 train_time:22901ms step_avg:93.86ms +step:245/1695 train_time:22995ms step_avg:93.86ms +step:246/1695 train_time:23090ms step_avg:93.86ms +step:247/1695 train_time:23183ms step_avg:93.86ms +step:248/1695 train_time:23278ms step_avg:93.86ms +step:249/1695 train_time:23374ms step_avg:93.87ms +step:250/1695 train_time:23468ms step_avg:93.87ms +step:250/1695 val_loss:4.0786 train_time:23559ms step_avg:94.24ms +step:251/1695 train_time:23587ms step_avg:93.97ms +step:252/1695 train_time:23665ms step_avg:93.91ms +step:253/1695 train_time:23761ms step_avg:93.92ms +step:254/1695 train_time:23855ms step_avg:93.92ms +step:255/1695 train_time:23949ms step_avg:93.92ms +step:256/1695 train_time:24044ms step_avg:93.92ms +step:257/1695 train_time:24137ms step_avg:93.92ms +step:258/1695 train_time:24231ms step_avg:93.92ms +step:259/1695 train_time:24325ms step_avg:93.92ms +step:260/1695 train_time:24418ms step_avg:93.92ms +step:261/1695 train_time:24514ms step_avg:93.92ms +step:262/1695 train_time:24612ms step_avg:93.94ms +step:263/1695 train_time:24709ms step_avg:93.95ms +step:264/1695 train_time:24805ms step_avg:93.96ms +step:265/1695 train_time:24899ms step_avg:93.96ms +step:266/1695 train_time:24994ms step_avg:93.96ms +step:267/1695 train_time:25088ms step_avg:93.96ms +step:268/1695 train_time:25182ms step_avg:93.96ms +step:269/1695 train_time:25275ms step_avg:93.96ms +step:270/1695 train_time:25370ms step_avg:93.96ms +step:271/1695 train_time:25463ms step_avg:93.96ms +step:272/1695 train_time:25558ms step_avg:93.96ms +step:273/1695 train_time:25655ms step_avg:93.98ms +step:274/1695 train_time:25752ms step_avg:93.98ms +step:275/1695 train_time:25847ms step_avg:93.99ms +step:276/1695 train_time:25940ms step_avg:93.99ms +step:277/1695 train_time:26035ms step_avg:93.99ms +step:278/1695 train_time:26131ms step_avg:94.00ms +step:279/1695 train_time:26225ms step_avg:94.00ms +step:280/1695 train_time:26319ms step_avg:94.00ms +step:281/1695 train_time:26412ms step_avg:93.99ms +step:282/1695 train_time:26506ms step_avg:93.99ms +step:283/1695 train_time:26601ms step_avg:94.00ms +step:284/1695 train_time:26697ms step_avg:94.00ms +step:285/1695 train_time:26792ms step_avg:94.01ms +step:286/1695 train_time:26887ms step_avg:94.01ms +step:287/1695 train_time:26980ms 
step_avg:94.01ms +step:288/1695 train_time:27075ms step_avg:94.01ms +step:289/1695 train_time:27170ms step_avg:94.01ms +step:290/1695 train_time:27265ms step_avg:94.02ms +step:291/1695 train_time:27359ms step_avg:94.02ms +step:292/1695 train_time:27453ms step_avg:94.02ms +step:293/1695 train_time:27546ms step_avg:94.01ms +step:294/1695 train_time:27641ms step_avg:94.02ms +step:295/1695 train_time:27736ms step_avg:94.02ms +step:296/1695 train_time:27830ms step_avg:94.02ms +step:297/1695 train_time:27925ms step_avg:94.02ms +step:298/1695 train_time:28019ms step_avg:94.02ms +step:299/1695 train_time:28113ms step_avg:94.02ms +step:300/1695 train_time:28208ms step_avg:94.03ms +step:301/1695 train_time:28303ms step_avg:94.03ms +step:302/1695 train_time:28397ms step_avg:94.03ms +step:303/1695 train_time:28491ms step_avg:94.03ms +step:304/1695 train_time:28585ms step_avg:94.03ms +step:305/1695 train_time:28679ms step_avg:94.03ms +step:306/1695 train_time:28774ms step_avg:94.03ms +step:307/1695 train_time:28869ms step_avg:94.04ms +step:308/1695 train_time:28963ms step_avg:94.04ms +step:309/1695 train_time:29058ms step_avg:94.04ms +step:310/1695 train_time:29152ms step_avg:94.04ms +step:311/1695 train_time:29247ms step_avg:94.04ms +step:312/1695 train_time:29341ms step_avg:94.04ms +step:313/1695 train_time:29435ms step_avg:94.04ms +step:314/1695 train_time:29530ms step_avg:94.04ms +step:315/1695 train_time:29624ms step_avg:94.04ms +step:316/1695 train_time:29718ms step_avg:94.04ms +step:317/1695 train_time:29813ms step_avg:94.05ms +step:318/1695 train_time:29907ms step_avg:94.05ms +step:319/1695 train_time:30001ms step_avg:94.05ms +step:320/1695 train_time:30096ms step_avg:94.05ms +step:321/1695 train_time:30192ms step_avg:94.06ms +step:322/1695 train_time:30286ms step_avg:94.06ms +step:323/1695 train_time:30380ms step_avg:94.06ms +step:324/1695 train_time:30474ms step_avg:94.06ms +step:325/1695 train_time:30569ms step_avg:94.06ms +step:326/1695 train_time:30664ms step_avg:94.06ms +step:327/1695 train_time:30757ms step_avg:94.06ms +step:328/1695 train_time:30853ms step_avg:94.06ms +step:329/1695 train_time:30948ms step_avg:94.07ms +step:330/1695 train_time:31042ms step_avg:94.07ms +step:331/1695 train_time:31136ms step_avg:94.07ms +step:332/1695 train_time:31232ms step_avg:94.07ms +step:333/1695 train_time:31327ms step_avg:94.07ms +step:334/1695 train_time:31421ms step_avg:94.08ms +step:335/1695 train_time:31516ms step_avg:94.08ms +step:336/1695 train_time:31611ms step_avg:94.08ms +step:337/1695 train_time:31706ms step_avg:94.08ms +step:338/1695 train_time:31800ms step_avg:94.08ms +step:339/1695 train_time:31894ms step_avg:94.08ms +step:340/1695 train_time:31989ms step_avg:94.08ms +step:341/1695 train_time:32083ms step_avg:94.08ms +step:342/1695 train_time:32177ms step_avg:94.09ms +step:343/1695 train_time:32271ms step_avg:94.09ms +step:344/1695 train_time:32367ms step_avg:94.09ms +step:345/1695 train_time:32461ms step_avg:94.09ms +step:346/1695 train_time:32555ms step_avg:94.09ms +step:347/1695 train_time:32649ms step_avg:94.09ms +step:348/1695 train_time:32744ms step_avg:94.09ms +step:349/1695 train_time:32838ms step_avg:94.09ms +step:350/1695 train_time:32932ms step_avg:94.09ms +step:351/1695 train_time:33027ms step_avg:94.09ms +step:352/1695 train_time:33121ms step_avg:94.09ms +step:353/1695 train_time:33215ms step_avg:94.09ms +step:354/1695 train_time:33310ms step_avg:94.10ms +step:355/1695 train_time:33406ms step_avg:94.10ms +step:356/1695 train_time:33499ms step_avg:94.10ms +step:357/1695 
train_time:33593ms step_avg:94.10ms +step:358/1695 train_time:33688ms step_avg:94.10ms +step:359/1695 train_time:33781ms step_avg:94.10ms +step:360/1695 train_time:33876ms step_avg:94.10ms +step:361/1695 train_time:33972ms step_avg:94.10ms +step:362/1695 train_time:34065ms step_avg:94.10ms +step:363/1695 train_time:34159ms step_avg:94.10ms +step:364/1695 train_time:34255ms step_avg:94.11ms +step:365/1695 train_time:34349ms step_avg:94.11ms +step:366/1695 train_time:34444ms step_avg:94.11ms +step:367/1695 train_time:34538ms step_avg:94.11ms +step:368/1695 train_time:34632ms step_avg:94.11ms +step:369/1695 train_time:34727ms step_avg:94.11ms +step:370/1695 train_time:34822ms step_avg:94.11ms +step:371/1695 train_time:34916ms step_avg:94.11ms +step:372/1695 train_time:35010ms step_avg:94.11ms +step:373/1695 train_time:35104ms step_avg:94.11ms +step:374/1695 train_time:35199ms step_avg:94.11ms +step:375/1695 train_time:35293ms step_avg:94.12ms +step:375/1695 val_loss:3.8792 train_time:35386ms step_avg:94.36ms +step:376/1695 train_time:35413ms step_avg:94.18ms +step:377/1695 train_time:35491ms step_avg:94.14ms +step:378/1695 train_time:35593ms step_avg:94.16ms +step:379/1695 train_time:35690ms step_avg:94.17ms +step:380/1695 train_time:35785ms step_avg:94.17ms +step:381/1695 train_time:35881ms step_avg:94.17ms +step:382/1695 train_time:35976ms step_avg:94.18ms +step:383/1695 train_time:36071ms step_avg:94.18ms +step:384/1695 train_time:36167ms step_avg:94.18ms +step:385/1695 train_time:36262ms step_avg:94.19ms +step:386/1695 train_time:36357ms step_avg:94.19ms +step:387/1695 train_time:36453ms step_avg:94.20ms +step:388/1695 train_time:36552ms step_avg:94.21ms +step:389/1695 train_time:36650ms step_avg:94.22ms +step:390/1695 train_time:36748ms step_avg:94.22ms +step:391/1695 train_time:36844ms step_avg:94.23ms +step:392/1695 train_time:36940ms step_avg:94.24ms +step:393/1695 train_time:37036ms step_avg:94.24ms +step:394/1695 train_time:37132ms step_avg:94.24ms +step:395/1695 train_time:37228ms step_avg:94.25ms +step:396/1695 train_time:37323ms step_avg:94.25ms +step:397/1695 train_time:37419ms step_avg:94.25ms +step:398/1695 train_time:37515ms step_avg:94.26ms +step:399/1695 train_time:37612ms step_avg:94.27ms +step:400/1695 train_time:37710ms step_avg:94.27ms +step:401/1695 train_time:37807ms step_avg:94.28ms +step:402/1695 train_time:37903ms step_avg:94.29ms +step:403/1695 train_time:37999ms step_avg:94.29ms +step:404/1695 train_time:38095ms step_avg:94.30ms +step:405/1695 train_time:38190ms step_avg:94.30ms +step:406/1695 train_time:38286ms step_avg:94.30ms +step:407/1695 train_time:38382ms step_avg:94.30ms +step:408/1695 train_time:38477ms step_avg:94.31ms +step:409/1695 train_time:38573ms step_avg:94.31ms +step:410/1695 train_time:38670ms step_avg:94.32ms +step:411/1695 train_time:38766ms step_avg:94.32ms +step:412/1695 train_time:38862ms step_avg:94.33ms +step:413/1695 train_time:38958ms step_avg:94.33ms +step:414/1695 train_time:39054ms step_avg:94.33ms +step:415/1695 train_time:39151ms step_avg:94.34ms +step:416/1695 train_time:39246ms step_avg:94.34ms +step:417/1695 train_time:39342ms step_avg:94.34ms +step:418/1695 train_time:39438ms step_avg:94.35ms +step:419/1695 train_time:39534ms step_avg:94.35ms +step:420/1695 train_time:39630ms step_avg:94.36ms +step:421/1695 train_time:39726ms step_avg:94.36ms +step:422/1695 train_time:39823ms step_avg:94.37ms +step:423/1695 train_time:39920ms step_avg:94.37ms +step:424/1695 train_time:40016ms step_avg:94.38ms +step:425/1695 train_time:40111ms 
step_avg:94.38ms +step:426/1695 train_time:40208ms step_avg:94.38ms +step:427/1695 train_time:40304ms step_avg:94.39ms +step:428/1695 train_time:40401ms step_avg:94.39ms +step:429/1695 train_time:40495ms step_avg:94.39ms +step:430/1695 train_time:40591ms step_avg:94.40ms +step:431/1695 train_time:40687ms step_avg:94.40ms +step:432/1695 train_time:40783ms step_avg:94.41ms +step:433/1695 train_time:40878ms step_avg:94.41ms +step:434/1695 train_time:40974ms step_avg:94.41ms +step:435/1695 train_time:41070ms step_avg:94.41ms +step:436/1695 train_time:41166ms step_avg:94.42ms +step:437/1695 train_time:41263ms step_avg:94.42ms +step:438/1695 train_time:41359ms step_avg:94.43ms +step:439/1695 train_time:41455ms step_avg:94.43ms +step:440/1695 train_time:41550ms step_avg:94.43ms +step:441/1695 train_time:41647ms step_avg:94.44ms +step:442/1695 train_time:41743ms step_avg:94.44ms +step:443/1695 train_time:41840ms step_avg:94.45ms +step:444/1695 train_time:41936ms step_avg:94.45ms +step:445/1695 train_time:42033ms step_avg:94.46ms +step:446/1695 train_time:42129ms step_avg:94.46ms +step:447/1695 train_time:42225ms step_avg:94.46ms +step:448/1695 train_time:42321ms step_avg:94.47ms +step:449/1695 train_time:42417ms step_avg:94.47ms +step:450/1695 train_time:42513ms step_avg:94.47ms +step:451/1695 train_time:42610ms step_avg:94.48ms +step:452/1695 train_time:42706ms step_avg:94.48ms +step:453/1695 train_time:42802ms step_avg:94.49ms +step:454/1695 train_time:42898ms step_avg:94.49ms +step:455/1695 train_time:42994ms step_avg:94.49ms +step:456/1695 train_time:43090ms step_avg:94.50ms +step:457/1695 train_time:43187ms step_avg:94.50ms +step:458/1695 train_time:43283ms step_avg:94.50ms +step:459/1695 train_time:43379ms step_avg:94.51ms +step:460/1695 train_time:43475ms step_avg:94.51ms +step:461/1695 train_time:43571ms step_avg:94.51ms +step:462/1695 train_time:43667ms step_avg:94.52ms +step:463/1695 train_time:43764ms step_avg:94.52ms +step:464/1695 train_time:43860ms step_avg:94.53ms +step:465/1695 train_time:43956ms step_avg:94.53ms +step:466/1695 train_time:44052ms step_avg:94.53ms +step:467/1695 train_time:44149ms step_avg:94.54ms +step:468/1695 train_time:44246ms step_avg:94.54ms +step:469/1695 train_time:44343ms step_avg:94.55ms +step:470/1695 train_time:44439ms step_avg:94.55ms +step:471/1695 train_time:44535ms step_avg:94.55ms +step:472/1695 train_time:44631ms step_avg:94.56ms +step:473/1695 train_time:44727ms step_avg:94.56ms +step:474/1695 train_time:44823ms step_avg:94.56ms +step:475/1695 train_time:44920ms step_avg:94.57ms +step:476/1695 train_time:45016ms step_avg:94.57ms +step:477/1695 train_time:45112ms step_avg:94.57ms +step:478/1695 train_time:45208ms step_avg:94.58ms +step:479/1695 train_time:45305ms step_avg:94.58ms +step:480/1695 train_time:45401ms step_avg:94.59ms +step:481/1695 train_time:45497ms step_avg:94.59ms +step:482/1695 train_time:45593ms step_avg:94.59ms +step:483/1695 train_time:45688ms step_avg:94.59ms +step:484/1695 train_time:45785ms step_avg:94.60ms +step:485/1695 train_time:45881ms step_avg:94.60ms +step:486/1695 train_time:45977ms step_avg:94.60ms +step:487/1695 train_time:46073ms step_avg:94.61ms +step:488/1695 train_time:46170ms step_avg:94.61ms +step:489/1695 train_time:46267ms step_avg:94.61ms +step:490/1695 train_time:46363ms step_avg:94.62ms +step:491/1695 train_time:46460ms step_avg:94.62ms +step:492/1695 train_time:46555ms step_avg:94.62ms +step:493/1695 train_time:46650ms step_avg:94.63ms +step:494/1695 train_time:46747ms step_avg:94.63ms +step:495/1695 
train_time:46844ms step_avg:94.63ms
+[... per-step train_time/step_avg entries for steps 496-1694 elided; validation checkpoints and the final summary follow ...]
+step:500/1695 val_loss:3.7347 train_time:47420ms step_avg:94.84ms
+step:625/1695 val_loss:3.6497 train_time:59497ms step_avg:95.20ms
+step:750/1695 val_loss:3.5863 train_time:72953ms step_avg:97.27ms
+step:875/1695 val_loss:3.5358 train_time:85959ms step_avg:98.24ms
+step:1000/1695 val_loss:3.4922 train_time:98433ms step_avg:98.43ms
+step:1125/1695 val_loss:3.4399 train_time:110931ms step_avg:98.61ms
+step:1250/1695 val_loss:3.3944 train_time:123558ms step_avg:98.85ms
+step:1375/1695 val_loss:3.3545 train_time:136205ms step_avg:99.06ms
+step:1500/1695 val_loss:3.3204 train_time:148940ms step_avg:99.29ms
+step:1625/1695 val_loss:3.2911 train_time:161708ms step_avg:99.51ms
+step:1695/1695 train_time:168817ms step_avg:99.60ms
+step:1695/1695 val_loss:3.2784 train_time:168917ms step_avg:99.66ms
+peak memory allocated: 34004 MiB reserved: 49660 MiB
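A quick consistency check on the record above, using only numbers it prints: step_avg is the cumulative train_time divided by the step index. A two-line illustration (not part of the record):

    train_time_ms, steps = 168817, 1695
    print(round(train_time_ms / steps, 2))  # 99.6, printed in the log as step_avg:99.60ms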
diff --git a/records/082325_SparseAttnGate/ca042caf-b232-4a25-b28f-88e39a2009d3.txt b/records/082325_SparseAttnGate/ca042caf-b232-4a25-b28f-88e39a2009d3.txt
new file mode 100644
index 000000000..2b5eb6f73
--- /dev/null
+++ b/records/082325_SparseAttnGate/ca042caf-b232-4a25-b28f-88e39a2009d3.txt
@@ -0,0 +1,2802 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+# use of FlexAttention contributed by @KoszarskyB
+from torch.nn.attention.flex_attention import BlockMask, flex_attention
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import triton
+import triton.language as tl
+
+@dataclass
+class Hyperparameters:
+    # data
+    dampen_factor = 64
+    run_id = f'final/{uuid.uuid4()}'
+    train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_seq_len = 48*1024 # FlexAttention sequence length
+    val_seq_len = 4*64*1024 # FlexAttention sequence length for validation
+    # optimization
+    num_iterations = 1695 # number of iterations to run
+    cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint = False
+args = Hyperparameters()
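+# Derived figures (a reader's gloss on the fields above, not original comments):
+# the loop runs num_iterations = 1695 steps; val_loss_every = 125 places
+# validation at steps 125, 250, ..., 1625 plus the final step; and since
+# val_seq_len = 4*64*1024 = 262144 tokens, each evaluation covers
+# val_tokens / val_seq_len = 10485760 / 262144 = 40 full-length sequences.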
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
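+ + A minimal per-parameter sketch of the update rule (hypothetical names buf/grad/p/lr; the actual step below additionally shards parameters across ranks, applies weight decay, and honors per-parameter lr_mul): + + buf.lerp_(grad, 1 - momentum) # momentum buffer, initialized to zeros_like(grad) on first step + g = grad.lerp_(buf, momentum) # Nesterov-style blend of gradient and momentum + p.add_(newton_schulz_triton(g), alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)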
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
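+ # note: lr_mul is a per-parameter learning-rate multiplier; both DistAdam.step and Muon.step read it via getattr(p, "lr_mul", 1.0) and scale their group lr by it (wd_mul is read analogously for weight decay)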
+ self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256]) + # docs = increments.cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) + + def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(block_masks) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 + boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos + start = boundary_positions[0].item() + starts = [] + for i in range(1, len(boundary_positions)): + end = boundary_positions[i].item() + if end - start >= seq_len: + starts.append(start) # append start once end pos is confirmed + if len(starts) == dist.get_world_size(): + return starts, end - pos + start = end + assert False # increase token_window if necessary + +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): + rank = dist.get_rank() + world_size = dist.get_world_size() + batch_size = seq_len * world_size + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + while True: + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length 
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:20:14 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 303577 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 303578 C /usr/bin/python3 614MiB | +| 0 N/A N/A 303579 C /usr/bin/python3 614MiB | +| 0 N/A N/A 303580 C /usr/bin/python3 614MiB | +| 0 N/A N/A 303581 C /usr/bin/python3 614MiB | +| 0 N/A N/A 303582 C /usr/bin/python3 614MiB | +| 0 N/A N/A 303583 C /usr/bin/python3 614MiB | +| 0 N/A N/A 303584 C /usr/bin/python3 614MiB | +| 1 N/A N/A 303578 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 303579 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 303580 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 303581 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 303582 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 303583 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 303584 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:154ms step_avg:154.29ms +step:2/1695 train_time:181ms step_avg:90.43ms +step:3/1695 train_time:250ms step_avg:83.32ms +step:4/1695 train_time:342ms step_avg:85.40ms +step:5/1695 train_time:434ms step_avg:86.90ms +step:6/1695 train_time:527ms step_avg:87.85ms +step:7/1695 train_time:620ms step_avg:88.57ms +step:8/1695 
train_time:713ms step_avg:89.11ms +step:9/1695 train_time:806ms step_avg:89.52ms +step:10/1695 train_time:899ms step_avg:89.85ms +step:11/1695 train_time:993ms step_avg:90.24ms +step:12/1695 train_time:1089ms step_avg:90.71ms +step:13/1695 train_time:1183ms step_avg:91.00ms +step:14/1695 train_time:1278ms step_avg:91.25ms +step:15/1695 train_time:1372ms step_avg:91.45ms +step:16/1695 train_time:1465ms step_avg:91.59ms +step:17/1695 train_time:1559ms step_avg:91.68ms +step:18/1695 train_time:1652ms step_avg:91.77ms +step:19/1695 train_time:1745ms step_avg:91.83ms +step:20/1695 train_time:1838ms step_avg:91.91ms +step:21/1695 train_time:1933ms step_avg:92.04ms +step:22/1695 train_time:2027ms step_avg:92.14ms +step:23/1695 train_time:2121ms step_avg:92.21ms +step:24/1695 train_time:2216ms step_avg:92.32ms +step:25/1695 train_time:2310ms step_avg:92.40ms +step:26/1695 train_time:2405ms step_avg:92.49ms +step:27/1695 train_time:2499ms step_avg:92.54ms +step:28/1695 train_time:2593ms step_avg:92.60ms +step:29/1695 train_time:2687ms step_avg:92.65ms +step:30/1695 train_time:2780ms step_avg:92.67ms +step:31/1695 train_time:2873ms step_avg:92.69ms +step:32/1695 train_time:2967ms step_avg:92.72ms +step:33/1695 train_time:3060ms step_avg:92.73ms +step:34/1695 train_time:3155ms step_avg:92.79ms +step:35/1695 train_time:3249ms step_avg:92.84ms +step:36/1695 train_time:3343ms step_avg:92.87ms +step:37/1695 train_time:3437ms step_avg:92.89ms +step:38/1695 train_time:3531ms step_avg:92.93ms +step:39/1695 train_time:3625ms step_avg:92.95ms +step:40/1695 train_time:3718ms step_avg:92.96ms +step:41/1695 train_time:3812ms step_avg:92.97ms +step:42/1695 train_time:3906ms step_avg:92.99ms +step:43/1695 train_time:3999ms step_avg:92.99ms +step:44/1695 train_time:4092ms step_avg:93.00ms +step:45/1695 train_time:4186ms step_avg:93.03ms +step:46/1695 train_time:4280ms step_avg:93.04ms +step:47/1695 train_time:4374ms step_avg:93.06ms +step:48/1695 train_time:4468ms step_avg:93.08ms +step:49/1695 train_time:4561ms step_avg:93.09ms +step:50/1695 train_time:4655ms step_avg:93.10ms +step:51/1695 train_time:4750ms step_avg:93.13ms +step:52/1695 train_time:4843ms step_avg:93.13ms +step:53/1695 train_time:4936ms step_avg:93.14ms +step:54/1695 train_time:5031ms step_avg:93.17ms +step:55/1695 train_time:5126ms step_avg:93.20ms +step:56/1695 train_time:5219ms step_avg:93.20ms +step:57/1695 train_time:5315ms step_avg:93.24ms +step:58/1695 train_time:5409ms step_avg:93.26ms +step:59/1695 train_time:5502ms step_avg:93.26ms +step:60/1695 train_time:5596ms step_avg:93.26ms +step:61/1695 train_time:5690ms step_avg:93.28ms +step:62/1695 train_time:5784ms step_avg:93.29ms +step:63/1695 train_time:5877ms step_avg:93.28ms +step:64/1695 train_time:5971ms step_avg:93.30ms +step:65/1695 train_time:6065ms step_avg:93.31ms +step:66/1695 train_time:6159ms step_avg:93.31ms +step:67/1695 train_time:6253ms step_avg:93.33ms +step:68/1695 train_time:6347ms step_avg:93.34ms +step:69/1695 train_time:6440ms step_avg:93.34ms +step:70/1695 train_time:6535ms step_avg:93.35ms +step:71/1695 train_time:6628ms step_avg:93.35ms +step:72/1695 train_time:6721ms step_avg:93.35ms +step:73/1695 train_time:6815ms step_avg:93.35ms +step:74/1695 train_time:6909ms step_avg:93.36ms +step:75/1695 train_time:7003ms step_avg:93.37ms +step:76/1695 train_time:7096ms step_avg:93.37ms +step:77/1695 train_time:7191ms step_avg:93.39ms +step:78/1695 train_time:7285ms step_avg:93.40ms +step:79/1695 train_time:7379ms step_avg:93.40ms +step:80/1695 train_time:7473ms 
step_avg:93.41ms +step:81/1695 train_time:7566ms step_avg:93.41ms +step:82/1695 train_time:7659ms step_avg:93.41ms +step:83/1695 train_time:7753ms step_avg:93.41ms +step:84/1695 train_time:7846ms step_avg:93.41ms +step:85/1695 train_time:7940ms step_avg:93.42ms +step:86/1695 train_time:8034ms step_avg:93.42ms +step:87/1695 train_time:8127ms step_avg:93.42ms +step:88/1695 train_time:8220ms step_avg:93.41ms +step:89/1695 train_time:8315ms step_avg:93.43ms +step:90/1695 train_time:8410ms step_avg:93.44ms +step:91/1695 train_time:8504ms step_avg:93.45ms +step:92/1695 train_time:8597ms step_avg:93.45ms +step:93/1695 train_time:8691ms step_avg:93.45ms +step:94/1695 train_time:8784ms step_avg:93.44ms +step:95/1695 train_time:8877ms step_avg:93.44ms +step:96/1695 train_time:8971ms step_avg:93.45ms +step:97/1695 train_time:9064ms step_avg:93.44ms +step:98/1695 train_time:9157ms step_avg:93.44ms +step:99/1695 train_time:9251ms step_avg:93.45ms +step:100/1695 train_time:9345ms step_avg:93.45ms +step:101/1695 train_time:9439ms step_avg:93.46ms +step:102/1695 train_time:9534ms step_avg:93.47ms +step:103/1695 train_time:9627ms step_avg:93.47ms +step:104/1695 train_time:9721ms step_avg:93.47ms +step:105/1695 train_time:9814ms step_avg:93.47ms +step:106/1695 train_time:9907ms step_avg:93.47ms +step:107/1695 train_time:10000ms step_avg:93.46ms +step:108/1695 train_time:10094ms step_avg:93.46ms +step:109/1695 train_time:10188ms step_avg:93.47ms +step:110/1695 train_time:10281ms step_avg:93.47ms +step:111/1695 train_time:10375ms step_avg:93.47ms +step:112/1695 train_time:10469ms step_avg:93.47ms +step:113/1695 train_time:10562ms step_avg:93.47ms +step:114/1695 train_time:10656ms step_avg:93.47ms +step:115/1695 train_time:10750ms step_avg:93.48ms +step:116/1695 train_time:10844ms step_avg:93.48ms +step:117/1695 train_time:10937ms step_avg:93.47ms +step:118/1695 train_time:11031ms step_avg:93.48ms +step:119/1695 train_time:11124ms step_avg:93.48ms +step:120/1695 train_time:11218ms step_avg:93.48ms +step:121/1695 train_time:11312ms step_avg:93.48ms +step:122/1695 train_time:11406ms step_avg:93.49ms +step:123/1695 train_time:11499ms step_avg:93.49ms +step:124/1695 train_time:11593ms step_avg:93.49ms +step:125/1695 train_time:11687ms step_avg:93.50ms +step:125/1695 val_loss:4.6033 train_time:11779ms step_avg:94.23ms +step:126/1695 train_time:11806ms step_avg:93.70ms +step:127/1695 train_time:11881ms step_avg:93.55ms +step:128/1695 train_time:11983ms step_avg:93.62ms +step:129/1695 train_time:12081ms step_avg:93.65ms +step:130/1695 train_time:12176ms step_avg:93.66ms +step:131/1695 train_time:12269ms step_avg:93.66ms +step:132/1695 train_time:12362ms step_avg:93.65ms +step:133/1695 train_time:12455ms step_avg:93.65ms +step:134/1695 train_time:12549ms step_avg:93.65ms +step:135/1695 train_time:12642ms step_avg:93.64ms +step:136/1695 train_time:12736ms step_avg:93.64ms +step:137/1695 train_time:12830ms step_avg:93.65ms +step:138/1695 train_time:12925ms step_avg:93.66ms +step:139/1695 train_time:13021ms step_avg:93.68ms +step:140/1695 train_time:13117ms step_avg:93.70ms +step:141/1695 train_time:13212ms step_avg:93.71ms +step:142/1695 train_time:13307ms step_avg:93.71ms +step:143/1695 train_time:13400ms step_avg:93.71ms +step:144/1695 train_time:13494ms step_avg:93.71ms +step:145/1695 train_time:13588ms step_avg:93.71ms +step:146/1695 train_time:13682ms step_avg:93.71ms +step:147/1695 train_time:13775ms step_avg:93.71ms +step:148/1695 train_time:13869ms step_avg:93.71ms +step:149/1695 train_time:13964ms 
step_avg:93.72ms +step:150/1695 train_time:14058ms step_avg:93.72ms +step:151/1695 train_time:14154ms step_avg:93.74ms +step:152/1695 train_time:14248ms step_avg:93.74ms +step:153/1695 train_time:14341ms step_avg:93.73ms +step:154/1695 train_time:14437ms step_avg:93.74ms +step:155/1695 train_time:14532ms step_avg:93.75ms +step:156/1695 train_time:14626ms step_avg:93.75ms +step:157/1695 train_time:14719ms step_avg:93.75ms +step:158/1695 train_time:14814ms step_avg:93.76ms +step:159/1695 train_time:14908ms step_avg:93.76ms +step:160/1695 train_time:15002ms step_avg:93.77ms +step:161/1695 train_time:15097ms step_avg:93.77ms +step:162/1695 train_time:15192ms step_avg:93.78ms +step:163/1695 train_time:15286ms step_avg:93.78ms +step:164/1695 train_time:15380ms step_avg:93.78ms +step:165/1695 train_time:15474ms step_avg:93.78ms +step:166/1695 train_time:15569ms step_avg:93.79ms +step:167/1695 train_time:15663ms step_avg:93.79ms +step:168/1695 train_time:15757ms step_avg:93.79ms +step:169/1695 train_time:15851ms step_avg:93.79ms +step:170/1695 train_time:15945ms step_avg:93.80ms +step:171/1695 train_time:16039ms step_avg:93.79ms +step:172/1695 train_time:16133ms step_avg:93.80ms +step:173/1695 train_time:16227ms step_avg:93.80ms +step:174/1695 train_time:16320ms step_avg:93.79ms +step:175/1695 train_time:16415ms step_avg:93.80ms +step:176/1695 train_time:16509ms step_avg:93.80ms +step:177/1695 train_time:16603ms step_avg:93.80ms +step:178/1695 train_time:16697ms step_avg:93.80ms +step:179/1695 train_time:16791ms step_avg:93.81ms +step:180/1695 train_time:16885ms step_avg:93.81ms +step:181/1695 train_time:16979ms step_avg:93.81ms +step:182/1695 train_time:17074ms step_avg:93.81ms +step:183/1695 train_time:17168ms step_avg:93.82ms +step:184/1695 train_time:17262ms step_avg:93.82ms +step:185/1695 train_time:17356ms step_avg:93.82ms +step:186/1695 train_time:17451ms step_avg:93.82ms +step:187/1695 train_time:17545ms step_avg:93.83ms +step:188/1695 train_time:17639ms step_avg:93.83ms +step:189/1695 train_time:17733ms step_avg:93.83ms +step:190/1695 train_time:17827ms step_avg:93.82ms +step:191/1695 train_time:17920ms step_avg:93.82ms +step:192/1695 train_time:18015ms step_avg:93.83ms +step:193/1695 train_time:18110ms step_avg:93.83ms +step:194/1695 train_time:18203ms step_avg:93.83ms +step:195/1695 train_time:18297ms step_avg:93.83ms +step:196/1695 train_time:18392ms step_avg:93.84ms +step:197/1695 train_time:18486ms step_avg:93.84ms +step:198/1695 train_time:18580ms step_avg:93.84ms +step:199/1695 train_time:18675ms step_avg:93.84ms +step:200/1695 train_time:18769ms step_avg:93.85ms +step:201/1695 train_time:18862ms step_avg:93.84ms +step:202/1695 train_time:18957ms step_avg:93.84ms +step:203/1695 train_time:19052ms step_avg:93.85ms +step:204/1695 train_time:19145ms step_avg:93.85ms +step:205/1695 train_time:19239ms step_avg:93.85ms +step:206/1695 train_time:19333ms step_avg:93.85ms +step:207/1695 train_time:19427ms step_avg:93.85ms +step:208/1695 train_time:19520ms step_avg:93.85ms +step:209/1695 train_time:19615ms step_avg:93.85ms +step:210/1695 train_time:19709ms step_avg:93.85ms +step:211/1695 train_time:19803ms step_avg:93.85ms +step:212/1695 train_time:19897ms step_avg:93.86ms +step:213/1695 train_time:19991ms step_avg:93.86ms +step:214/1695 train_time:20085ms step_avg:93.86ms +step:215/1695 train_time:20179ms step_avg:93.86ms +step:216/1695 train_time:20274ms step_avg:93.86ms +step:217/1695 train_time:20368ms step_avg:93.86ms +step:218/1695 train_time:20461ms step_avg:93.86ms +step:219/1695 
[per-step train_time/step_avg log lines for steps 219-1574 elided; validation-loss milestones retained below]
+step:250/1695 val_loss:4.0846 train_time:23562ms step_avg:94.25ms
+step:375/1695 val_loss:3.8779 train_time:35389ms step_avg:94.37ms
+step:500/1695 val_loss:3.7325 train_time:47418ms step_avg:94.84ms
+step:625/1695 val_loss:3.6469 train_time:59507ms step_avg:95.21ms
+step:750/1695 val_loss:3.5857 train_time:72236ms step_avg:96.32ms
+step:875/1695 val_loss:3.5376 train_time:85809ms step_avg:98.07ms
+step:1000/1695 val_loss:3.4945 train_time:98304ms step_avg:98.30ms
+step:1125/1695 val_loss:3.4411 train_time:110802ms step_avg:98.49ms
+step:1250/1695 val_loss:3.3962 train_time:123387ms step_avg:98.71ms
+step:1375/1695 val_loss:3.3563 train_time:135999ms step_avg:98.91ms
+step:1500/1695 val_loss:3.3218 train_time:148735ms step_avg:99.16ms
+step:1574/1695 train_time:156198ms step_avg:99.24ms
+step:1575/1695 train_time:156299ms step_avg:99.24ms +step:1576/1695 train_time:156401ms step_avg:99.24ms +step:1577/1695 train_time:156505ms step_avg:99.24ms +step:1578/1695 train_time:156606ms step_avg:99.24ms +step:1579/1695 train_time:156707ms step_avg:99.24ms +step:1580/1695 train_time:156810ms step_avg:99.25ms +step:1581/1695 train_time:156911ms step_avg:99.25ms +step:1582/1695 train_time:157012ms step_avg:99.25ms +step:1583/1695 train_time:157115ms step_avg:99.25ms +step:1584/1695 train_time:157217ms step_avg:99.25ms +step:1585/1695 train_time:157319ms step_avg:99.25ms +step:1586/1695 train_time:157423ms step_avg:99.26ms +step:1587/1695 train_time:157525ms step_avg:99.26ms +step:1588/1695 train_time:157626ms step_avg:99.26ms +step:1589/1695 train_time:157728ms step_avg:99.26ms +step:1590/1695 train_time:157830ms step_avg:99.26ms +step:1591/1695 train_time:157931ms step_avg:99.27ms +step:1592/1695 train_time:158033ms step_avg:99.27ms +step:1593/1695 train_time:158134ms step_avg:99.27ms +step:1594/1695 train_time:158238ms step_avg:99.27ms +step:1595/1695 train_time:158340ms step_avg:99.27ms +step:1596/1695 train_time:158442ms step_avg:99.27ms +step:1597/1695 train_time:158544ms step_avg:99.28ms +step:1598/1695 train_time:158648ms step_avg:99.28ms +step:1599/1695 train_time:158750ms step_avg:99.28ms +step:1600/1695 train_time:158852ms step_avg:99.28ms +step:1601/1695 train_time:158954ms step_avg:99.28ms +step:1602/1695 train_time:159056ms step_avg:99.29ms +step:1603/1695 train_time:159157ms step_avg:99.29ms +step:1604/1695 train_time:159259ms step_avg:99.29ms +step:1605/1695 train_time:159362ms step_avg:99.29ms +step:1606/1695 train_time:159465ms step_avg:99.29ms +step:1607/1695 train_time:159566ms step_avg:99.29ms +step:1608/1695 train_time:159666ms step_avg:99.29ms +step:1609/1695 train_time:159767ms step_avg:99.30ms +step:1610/1695 train_time:159870ms step_avg:99.30ms +step:1611/1695 train_time:159972ms step_avg:99.30ms +step:1612/1695 train_time:160073ms step_avg:99.30ms +step:1613/1695 train_time:160174ms step_avg:99.30ms +step:1614/1695 train_time:160276ms step_avg:99.30ms +step:1615/1695 train_time:160378ms step_avg:99.31ms +step:1616/1695 train_time:160480ms step_avg:99.31ms +step:1617/1695 train_time:160583ms step_avg:99.31ms +step:1618/1695 train_time:160686ms step_avg:99.31ms +step:1619/1695 train_time:160788ms step_avg:99.31ms +step:1620/1695 train_time:160890ms step_avg:99.32ms +step:1621/1695 train_time:160991ms step_avg:99.32ms +step:1622/1695 train_time:161092ms step_avg:99.32ms +step:1623/1695 train_time:161195ms step_avg:99.32ms +step:1624/1695 train_time:161297ms step_avg:99.32ms +step:1625/1695 train_time:161401ms step_avg:99.32ms +step:1625/1695 val_loss:3.2929 train_time:161502ms step_avg:99.39ms +step:1626/1695 train_time:161530ms step_avg:99.34ms +step:1627/1695 train_time:161615ms step_avg:99.33ms +step:1628/1695 train_time:161716ms step_avg:99.33ms +step:1629/1695 train_time:161818ms step_avg:99.34ms +step:1630/1695 train_time:161920ms step_avg:99.34ms +step:1631/1695 train_time:162021ms step_avg:99.34ms +step:1632/1695 train_time:162121ms step_avg:99.34ms +step:1633/1695 train_time:162222ms step_avg:99.34ms +step:1634/1695 train_time:162325ms step_avg:99.34ms +step:1635/1695 train_time:162427ms step_avg:99.34ms +step:1636/1695 train_time:162529ms step_avg:99.35ms +step:1637/1695 train_time:162634ms step_avg:99.35ms +step:1638/1695 train_time:162736ms step_avg:99.35ms +step:1639/1695 train_time:162838ms step_avg:99.35ms +step:1640/1695 train_time:162941ms 
step_avg:99.35ms +step:1641/1695 train_time:163043ms step_avg:99.36ms +step:1642/1695 train_time:163145ms step_avg:99.36ms +step:1643/1695 train_time:163247ms step_avg:99.36ms +step:1644/1695 train_time:163349ms step_avg:99.36ms +step:1645/1695 train_time:163452ms step_avg:99.36ms +step:1646/1695 train_time:163556ms step_avg:99.37ms +step:1647/1695 train_time:163661ms step_avg:99.37ms +step:1648/1695 train_time:163764ms step_avg:99.37ms +step:1649/1695 train_time:163867ms step_avg:99.37ms +step:1650/1695 train_time:163971ms step_avg:99.38ms +step:1651/1695 train_time:164074ms step_avg:99.38ms +step:1652/1695 train_time:164178ms step_avg:99.38ms +step:1653/1695 train_time:164281ms step_avg:99.38ms +step:1654/1695 train_time:164383ms step_avg:99.38ms +step:1655/1695 train_time:164484ms step_avg:99.39ms +step:1656/1695 train_time:164587ms step_avg:99.39ms +step:1657/1695 train_time:164690ms step_avg:99.39ms +step:1658/1695 train_time:164794ms step_avg:99.39ms +step:1659/1695 train_time:164899ms step_avg:99.40ms +step:1660/1695 train_time:165001ms step_avg:99.40ms +step:1661/1695 train_time:165105ms step_avg:99.40ms +step:1662/1695 train_time:165213ms step_avg:99.41ms +step:1663/1695 train_time:165316ms step_avg:99.41ms +step:1664/1695 train_time:165417ms step_avg:99.41ms +step:1665/1695 train_time:165523ms step_avg:99.41ms +step:1666/1695 train_time:165626ms step_avg:99.42ms +step:1667/1695 train_time:165727ms step_avg:99.42ms +step:1668/1695 train_time:165833ms step_avg:99.42ms +step:1669/1695 train_time:165938ms step_avg:99.42ms +step:1670/1695 train_time:166041ms step_avg:99.43ms +step:1671/1695 train_time:166143ms step_avg:99.43ms +step:1672/1695 train_time:166247ms step_avg:99.43ms +step:1673/1695 train_time:166349ms step_avg:99.43ms +step:1674/1695 train_time:166451ms step_avg:99.43ms +step:1675/1695 train_time:166555ms step_avg:99.44ms +step:1676/1695 train_time:166659ms step_avg:99.44ms +step:1677/1695 train_time:166760ms step_avg:99.44ms +step:1678/1695 train_time:166864ms step_avg:99.44ms +step:1679/1695 train_time:166968ms step_avg:99.45ms +step:1680/1695 train_time:167071ms step_avg:99.45ms +step:1681/1695 train_time:167175ms step_avg:99.45ms +step:1682/1695 train_time:167282ms step_avg:99.45ms +step:1683/1695 train_time:167383ms step_avg:99.46ms +step:1684/1695 train_time:167486ms step_avg:99.46ms +step:1685/1695 train_time:167590ms step_avg:99.46ms +step:1686/1695 train_time:167693ms step_avg:99.46ms +step:1687/1695 train_time:167796ms step_avg:99.46ms +step:1688/1695 train_time:167899ms step_avg:99.47ms +step:1689/1695 train_time:168000ms step_avg:99.47ms +step:1690/1695 train_time:168103ms step_avg:99.47ms +step:1691/1695 train_time:168206ms step_avg:99.47ms +step:1692/1695 train_time:168309ms step_avg:99.47ms +step:1693/1695 train_time:168412ms step_avg:99.48ms +step:1694/1695 train_time:168516ms step_avg:99.48ms +step:1695/1695 train_time:168619ms step_avg:99.48ms +step:1695/1695 val_loss:3.2800 train_time:168718ms step_avg:99.54ms +peak memory allocated: 34221 MiB reserved: 49440 MiB diff --git a/records/082325_SparseAttnGate/d3e1ea3c-521c-4abd-a549-950c698d6cbf.txt b/records/082325_SparseAttnGate/d3e1ea3c-521c-4abd-a549-950c698d6cbf.txt new file mode 100644 index 000000000..050c5f3ef --- /dev/null +++ b/records/082325_SparseAttnGate/d3e1ea3c-521c-4abd-a549-950c698d6cbf.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
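+
+    A minimal single-GPU sketch of the update this optimizer applies to each 2D
+    parameter p (illustrative only: it ignores weight decay and per-parameter lr
+    multipliers, and the distributed step() below shards the work across ranks):
+
+        buf.lerp_(grad, 1 - momentum)   # momentum buffer update
+        g = grad.lerp_(buf, momentum)   # Nesterov-style blend of grad and buffer
+        v = newton_schulz_triton(g)     # orthogonalize the update
+        p.add_(v, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)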
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + self.dampen = CastedLinear(dim//args.dampen_factor, num_heads) + self.dampen.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * dampen_factor + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), 
ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. 
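+        # lr_mul is a per-parameter multiplier read by both optimizers below
+        # (each scales its group lr by getattr(p, "lr_mul", 1.0)), so embeddings
+        # train at ~75x the base Adam rate while the logit head stays at 1x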
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
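+    # usage sketch (illustrative numbers, not this run's actual schedule): with
+    # sliding_window_num_blocks = 8, the long mask lets a query attend to roughly
+    # the last 8*128 = 1024 tokens of its own document and the short mask to half
+    # that; forward() below assigns the long mask to layers 0, 4, 7, 11 and the
+    # short mask to the rest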
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
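+        # each rank evaluated a disjoint slice of the validation stream, so the
+        # cross-rank average recovers the mean loss over all val_tokens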
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:16:15 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 301329 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 301330 C /usr/bin/python3 614MiB | +| 0 N/A N/A 301331 C /usr/bin/python3 614MiB | +| 0 N/A N/A 301332 C /usr/bin/python3 614MiB | +| 0 N/A N/A 301333 C /usr/bin/python3 614MiB | +| 0 N/A N/A 301334 C /usr/bin/python3 614MiB | +| 0 N/A N/A 301335 C /usr/bin/python3 614MiB | +| 0 N/A N/A 301336 C /usr/bin/python3 614MiB | +| 1 N/A N/A 301330 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 301331 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 301332 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 301333 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 301334 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 301335 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 301336 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:159ms step_avg:159.07ms +step:2/1695 train_time:185ms step_avg:92.35ms +step:3/1695 train_time:255ms step_avg:84.99ms +step:4/1695 train_time:347ms step_avg:86.81ms +step:5/1695 train_time:439ms step_avg:87.81ms +step:6/1695 train_time:532ms step_avg:88.59ms +step:7/1695 train_time:624ms step_avg:89.20ms +step:8/1695 
train_time:717ms step_avg:89.59ms +step:9/1695 train_time:810ms step_avg:89.98ms +step:10/1695 train_time:903ms step_avg:90.31ms +step:11/1695 train_time:996ms step_avg:90.52ms +step:12/1695 train_time:1091ms step_avg:90.88ms +step:13/1695 train_time:1189ms step_avg:91.46ms +step:14/1695 train_time:1285ms step_avg:91.78ms +step:15/1695 train_time:1379ms step_avg:91.96ms +step:16/1695 train_time:1473ms step_avg:92.04ms +step:17/1695 train_time:1567ms step_avg:92.17ms +step:18/1695 train_time:1660ms step_avg:92.24ms +step:19/1695 train_time:1753ms step_avg:92.28ms +step:20/1695 train_time:1847ms step_avg:92.36ms +step:21/1695 train_time:1940ms step_avg:92.36ms +step:22/1695 train_time:2032ms step_avg:92.38ms +step:23/1695 train_time:2127ms step_avg:92.46ms +step:24/1695 train_time:2222ms step_avg:92.58ms +step:25/1695 train_time:2316ms step_avg:92.62ms +step:26/1695 train_time:2410ms step_avg:92.69ms +step:27/1695 train_time:2504ms step_avg:92.74ms +step:28/1695 train_time:2599ms step_avg:92.81ms +step:29/1695 train_time:2692ms step_avg:92.83ms +step:30/1695 train_time:2786ms step_avg:92.87ms +step:31/1695 train_time:2880ms step_avg:92.89ms +step:32/1695 train_time:2973ms step_avg:92.90ms +step:33/1695 train_time:3067ms step_avg:92.93ms +step:34/1695 train_time:3161ms step_avg:92.97ms +step:35/1695 train_time:3254ms step_avg:92.98ms +step:36/1695 train_time:3349ms step_avg:93.03ms +step:37/1695 train_time:3443ms step_avg:93.06ms +step:38/1695 train_time:3538ms step_avg:93.10ms +step:39/1695 train_time:3632ms step_avg:93.12ms +step:40/1695 train_time:3725ms step_avg:93.14ms +step:41/1695 train_time:3819ms step_avg:93.14ms +step:42/1695 train_time:3912ms step_avg:93.14ms +step:43/1695 train_time:4006ms step_avg:93.17ms +step:44/1695 train_time:4100ms step_avg:93.18ms +step:45/1695 train_time:4194ms step_avg:93.20ms +step:46/1695 train_time:4288ms step_avg:93.22ms +step:47/1695 train_time:4385ms step_avg:93.29ms +step:48/1695 train_time:4476ms step_avg:93.26ms +step:49/1695 train_time:4570ms step_avg:93.28ms +step:50/1695 train_time:4664ms step_avg:93.29ms +step:51/1695 train_time:4758ms step_avg:93.29ms +step:52/1695 train_time:4851ms step_avg:93.30ms +step:53/1695 train_time:4944ms step_avg:93.29ms +step:54/1695 train_time:5038ms step_avg:93.30ms +step:55/1695 train_time:5131ms step_avg:93.29ms +step:56/1695 train_time:5225ms step_avg:93.30ms +step:57/1695 train_time:5319ms step_avg:93.31ms +step:58/1695 train_time:5412ms step_avg:93.31ms +step:59/1695 train_time:5506ms step_avg:93.32ms +step:60/1695 train_time:5600ms step_avg:93.33ms +step:61/1695 train_time:5693ms step_avg:93.33ms +step:62/1695 train_time:5787ms step_avg:93.34ms +step:63/1695 train_time:5881ms step_avg:93.35ms +step:64/1695 train_time:5974ms step_avg:93.34ms +step:65/1695 train_time:6067ms step_avg:93.34ms +step:66/1695 train_time:6160ms step_avg:93.34ms +step:67/1695 train_time:6255ms step_avg:93.36ms +step:68/1695 train_time:6349ms step_avg:93.37ms +step:69/1695 train_time:6443ms step_avg:93.38ms +step:70/1695 train_time:6537ms step_avg:93.39ms +step:71/1695 train_time:6630ms step_avg:93.38ms +step:72/1695 train_time:6724ms step_avg:93.39ms +step:73/1695 train_time:6817ms step_avg:93.39ms +step:74/1695 train_time:6910ms step_avg:93.38ms +step:75/1695 train_time:7004ms step_avg:93.39ms +step:76/1695 train_time:7098ms step_avg:93.40ms +step:77/1695 train_time:7191ms step_avg:93.39ms +step:78/1695 train_time:7285ms step_avg:93.40ms +step:79/1695 train_time:7379ms step_avg:93.41ms +step:80/1695 train_time:7472ms 
step_avg:93.41ms +step:81/1695 train_time:7567ms step_avg:93.42ms +step:82/1695 train_time:7660ms step_avg:93.42ms +step:83/1695 train_time:7754ms step_avg:93.42ms +step:84/1695 train_time:7848ms step_avg:93.43ms +step:85/1695 train_time:7942ms step_avg:93.43ms +step:86/1695 train_time:8035ms step_avg:93.43ms +step:87/1695 train_time:8129ms step_avg:93.44ms +step:88/1695 train_time:8223ms step_avg:93.45ms +step:89/1695 train_time:8317ms step_avg:93.45ms +step:90/1695 train_time:8411ms step_avg:93.46ms +step:91/1695 train_time:8506ms step_avg:93.47ms +step:92/1695 train_time:8600ms step_avg:93.48ms +step:93/1695 train_time:8694ms step_avg:93.48ms +step:94/1695 train_time:8788ms step_avg:93.49ms +step:95/1695 train_time:8881ms step_avg:93.49ms +step:96/1695 train_time:8975ms step_avg:93.49ms +step:97/1695 train_time:9068ms step_avg:93.49ms +step:98/1695 train_time:9162ms step_avg:93.49ms +step:99/1695 train_time:9255ms step_avg:93.49ms +step:100/1695 train_time:9349ms step_avg:93.49ms +step:101/1695 train_time:9442ms step_avg:93.49ms +step:102/1695 train_time:9536ms step_avg:93.49ms +step:103/1695 train_time:9630ms step_avg:93.50ms +step:104/1695 train_time:9724ms step_avg:93.50ms +step:105/1695 train_time:9818ms step_avg:93.50ms +step:106/1695 train_time:9911ms step_avg:93.50ms +step:107/1695 train_time:10005ms step_avg:93.50ms +step:108/1695 train_time:10099ms step_avg:93.51ms +step:109/1695 train_time:10192ms step_avg:93.51ms +step:110/1695 train_time:10286ms step_avg:93.51ms +step:111/1695 train_time:10380ms step_avg:93.51ms +step:112/1695 train_time:10473ms step_avg:93.51ms +step:113/1695 train_time:10567ms step_avg:93.52ms +step:114/1695 train_time:10661ms step_avg:93.52ms +step:115/1695 train_time:10755ms step_avg:93.52ms +step:116/1695 train_time:10849ms step_avg:93.53ms +step:117/1695 train_time:10943ms step_avg:93.53ms +step:118/1695 train_time:11037ms step_avg:93.53ms +step:119/1695 train_time:11130ms step_avg:93.53ms +step:120/1695 train_time:11224ms step_avg:93.53ms +step:121/1695 train_time:11317ms step_avg:93.53ms +step:122/1695 train_time:11411ms step_avg:93.53ms +step:123/1695 train_time:11505ms step_avg:93.54ms +step:124/1695 train_time:11599ms step_avg:93.54ms +step:125/1695 train_time:11692ms step_avg:93.54ms +step:125/1695 val_loss:4.6040 train_time:11785ms step_avg:94.28ms +step:126/1695 train_time:11811ms step_avg:93.74ms +step:127/1695 train_time:11889ms step_avg:93.61ms +step:128/1695 train_time:11992ms step_avg:93.69ms +step:129/1695 train_time:12088ms step_avg:93.71ms +step:130/1695 train_time:12183ms step_avg:93.71ms +step:131/1695 train_time:12277ms step_avg:93.71ms +step:132/1695 train_time:12370ms step_avg:93.71ms +step:133/1695 train_time:12463ms step_avg:93.71ms +step:134/1695 train_time:12557ms step_avg:93.71ms +step:135/1695 train_time:12651ms step_avg:93.71ms +step:136/1695 train_time:12744ms step_avg:93.71ms +step:137/1695 train_time:12839ms step_avg:93.71ms +step:138/1695 train_time:12935ms step_avg:93.73ms +step:139/1695 train_time:13030ms step_avg:93.74ms +step:140/1695 train_time:13126ms step_avg:93.76ms +step:141/1695 train_time:13221ms step_avg:93.76ms +step:142/1695 train_time:13315ms step_avg:93.77ms +step:143/1695 train_time:13408ms step_avg:93.76ms +step:144/1695 train_time:13501ms step_avg:93.76ms +step:145/1695 train_time:13595ms step_avg:93.76ms +step:146/1695 train_time:13689ms step_avg:93.76ms +step:147/1695 train_time:13782ms step_avg:93.75ms +step:148/1695 train_time:13878ms step_avg:93.77ms +step:149/1695 train_time:13972ms 
step_avg:93.77ms +step:150/1695 train_time:14067ms step_avg:93.78ms +step:151/1695 train_time:14161ms step_avg:93.78ms +step:152/1695 train_time:14256ms step_avg:93.79ms +step:153/1695 train_time:14351ms step_avg:93.79ms +step:154/1695 train_time:14445ms step_avg:93.80ms +step:155/1695 train_time:14540ms step_avg:93.81ms +step:156/1695 train_time:14634ms step_avg:93.80ms +step:157/1695 train_time:14727ms step_avg:93.80ms +step:158/1695 train_time:14822ms step_avg:93.81ms +step:159/1695 train_time:14917ms step_avg:93.82ms +step:160/1695 train_time:15011ms step_avg:93.82ms +step:161/1695 train_time:15105ms step_avg:93.82ms +step:162/1695 train_time:15200ms step_avg:93.83ms +step:163/1695 train_time:15294ms step_avg:93.83ms +step:164/1695 train_time:15389ms step_avg:93.84ms +step:165/1695 train_time:15483ms step_avg:93.84ms +step:166/1695 train_time:15577ms step_avg:93.84ms +step:167/1695 train_time:15671ms step_avg:93.84ms +step:168/1695 train_time:15764ms step_avg:93.83ms +step:169/1695 train_time:15859ms step_avg:93.84ms +step:170/1695 train_time:15952ms step_avg:93.84ms +step:171/1695 train_time:16046ms step_avg:93.84ms +step:172/1695 train_time:16143ms step_avg:93.85ms +step:173/1695 train_time:16237ms step_avg:93.86ms +step:174/1695 train_time:16331ms step_avg:93.86ms +step:175/1695 train_time:16424ms step_avg:93.85ms +step:176/1695 train_time:16519ms step_avg:93.86ms +step:177/1695 train_time:16612ms step_avg:93.86ms +step:178/1695 train_time:16707ms step_avg:93.86ms +step:179/1695 train_time:16801ms step_avg:93.86ms +step:180/1695 train_time:16896ms step_avg:93.87ms +step:181/1695 train_time:16990ms step_avg:93.87ms +step:182/1695 train_time:17084ms step_avg:93.87ms +step:183/1695 train_time:17179ms step_avg:93.88ms +step:184/1695 train_time:17273ms step_avg:93.87ms +step:185/1695 train_time:17367ms step_avg:93.87ms +step:186/1695 train_time:17461ms step_avg:93.87ms +step:187/1695 train_time:17555ms step_avg:93.87ms +step:188/1695 train_time:17648ms step_avg:93.87ms +step:189/1695 train_time:17743ms step_avg:93.88ms +step:190/1695 train_time:17838ms step_avg:93.88ms +step:191/1695 train_time:17932ms step_avg:93.88ms +step:192/1695 train_time:18025ms step_avg:93.88ms +step:193/1695 train_time:18120ms step_avg:93.88ms +step:194/1695 train_time:18214ms step_avg:93.89ms +step:195/1695 train_time:18307ms step_avg:93.88ms +step:196/1695 train_time:18403ms step_avg:93.89ms +step:197/1695 train_time:18498ms step_avg:93.90ms +step:198/1695 train_time:18592ms step_avg:93.90ms +step:199/1695 train_time:18686ms step_avg:93.90ms +step:200/1695 train_time:18780ms step_avg:93.90ms +step:201/1695 train_time:18874ms step_avg:93.90ms +step:202/1695 train_time:18968ms step_avg:93.90ms +step:203/1695 train_time:19062ms step_avg:93.90ms +step:204/1695 train_time:19156ms step_avg:93.90ms +step:205/1695 train_time:19251ms step_avg:93.91ms +step:206/1695 train_time:19345ms step_avg:93.91ms +step:207/1695 train_time:19439ms step_avg:93.91ms +step:208/1695 train_time:19533ms step_avg:93.91ms +step:209/1695 train_time:19626ms step_avg:93.90ms +step:210/1695 train_time:19720ms step_avg:93.90ms +step:211/1695 train_time:19814ms step_avg:93.90ms +step:212/1695 train_time:19908ms step_avg:93.90ms +step:213/1695 train_time:20003ms step_avg:93.91ms +step:214/1695 train_time:20098ms step_avg:93.91ms +step:215/1695 train_time:20191ms step_avg:93.91ms +step:216/1695 train_time:20286ms step_avg:93.91ms +step:217/1695 train_time:20380ms step_avg:93.92ms +step:218/1695 train_time:20475ms step_avg:93.92ms +step:219/1695 
train_time:20569ms step_avg:93.92ms +step:220/1695 train_time:20663ms step_avg:93.92ms +step:221/1695 train_time:20757ms step_avg:93.92ms +step:222/1695 train_time:20851ms step_avg:93.92ms +step:223/1695 train_time:20945ms step_avg:93.92ms +step:224/1695 train_time:21039ms step_avg:93.92ms +step:225/1695 train_time:21133ms step_avg:93.92ms +step:226/1695 train_time:21227ms step_avg:93.92ms +step:227/1695 train_time:21322ms step_avg:93.93ms +step:228/1695 train_time:21416ms step_avg:93.93ms +step:229/1695 train_time:21510ms step_avg:93.93ms +step:230/1695 train_time:21605ms step_avg:93.94ms +step:231/1695 train_time:21699ms step_avg:93.94ms +step:232/1695 train_time:21794ms step_avg:93.94ms +step:233/1695 train_time:21887ms step_avg:93.94ms +step:234/1695 train_time:21981ms step_avg:93.94ms +step:235/1695 train_time:22076ms step_avg:93.94ms +step:236/1695 train_time:22170ms step_avg:93.94ms +step:237/1695 train_time:22263ms step_avg:93.94ms +step:238/1695 train_time:22358ms step_avg:93.94ms +step:239/1695 train_time:22452ms step_avg:93.94ms +step:240/1695 train_time:22546ms step_avg:93.94ms +step:241/1695 train_time:22640ms step_avg:93.94ms +step:242/1695 train_time:22734ms step_avg:93.94ms +step:243/1695 train_time:22828ms step_avg:93.94ms +step:244/1695 train_time:22921ms step_avg:93.94ms +step:245/1695 train_time:23016ms step_avg:93.94ms +step:246/1695 train_time:23110ms step_avg:93.94ms +step:247/1695 train_time:23204ms step_avg:93.94ms +step:248/1695 train_time:23298ms step_avg:93.94ms +step:249/1695 train_time:23392ms step_avg:93.94ms +step:250/1695 train_time:23486ms step_avg:93.94ms +step:250/1695 val_loss:4.0758 train_time:23579ms step_avg:94.31ms +step:251/1695 train_time:23605ms step_avg:94.04ms +step:252/1695 train_time:23686ms step_avg:93.99ms +step:253/1695 train_time:23789ms step_avg:94.03ms +step:254/1695 train_time:23885ms step_avg:94.04ms +step:255/1695 train_time:23980ms step_avg:94.04ms +step:256/1695 train_time:24074ms step_avg:94.04ms +step:257/1695 train_time:24168ms step_avg:94.04ms +step:258/1695 train_time:24261ms step_avg:94.04ms +step:259/1695 train_time:24355ms step_avg:94.03ms +step:260/1695 train_time:24449ms step_avg:94.03ms +step:261/1695 train_time:24542ms step_avg:94.03ms +step:262/1695 train_time:24638ms step_avg:94.04ms +step:263/1695 train_time:24734ms step_avg:94.05ms +step:264/1695 train_time:24830ms step_avg:94.05ms +step:265/1695 train_time:24927ms step_avg:94.06ms +step:266/1695 train_time:25022ms step_avg:94.07ms +step:267/1695 train_time:25117ms step_avg:94.07ms +step:268/1695 train_time:25211ms step_avg:94.07ms +step:269/1695 train_time:25304ms step_avg:94.07ms +step:270/1695 train_time:25398ms step_avg:94.07ms +step:271/1695 train_time:25492ms step_avg:94.06ms +step:272/1695 train_time:25586ms step_avg:94.07ms +step:273/1695 train_time:25682ms step_avg:94.07ms +step:274/1695 train_time:25777ms step_avg:94.08ms +step:275/1695 train_time:25872ms step_avg:94.08ms +step:276/1695 train_time:25967ms step_avg:94.08ms +step:277/1695 train_time:26062ms step_avg:94.09ms +step:278/1695 train_time:26156ms step_avg:94.09ms +step:279/1695 train_time:26249ms step_avg:94.08ms +step:280/1695 train_time:26344ms step_avg:94.09ms +step:281/1695 train_time:26438ms step_avg:94.09ms +step:282/1695 train_time:26532ms step_avg:94.09ms +step:283/1695 train_time:26626ms step_avg:94.09ms +step:284/1695 train_time:26722ms step_avg:94.09ms +step:285/1695 train_time:26817ms step_avg:94.09ms +step:286/1695 train_time:26911ms step_avg:94.10ms +step:287/1695 train_time:27006ms 
step_avg:94.10ms +step:288/1695 train_time:27101ms step_avg:94.10ms +step:289/1695 train_time:27195ms step_avg:94.10ms +step:290/1695 train_time:27289ms step_avg:94.10ms +step:291/1695 train_time:27383ms step_avg:94.10ms +step:292/1695 train_time:27477ms step_avg:94.10ms +step:293/1695 train_time:27571ms step_avg:94.10ms +step:294/1695 train_time:27666ms step_avg:94.10ms +step:295/1695 train_time:27760ms step_avg:94.10ms +step:296/1695 train_time:27855ms step_avg:94.10ms +step:297/1695 train_time:27949ms step_avg:94.10ms +step:298/1695 train_time:28044ms step_avg:94.11ms +step:299/1695 train_time:28139ms step_avg:94.11ms +step:300/1695 train_time:28232ms step_avg:94.11ms +step:301/1695 train_time:28327ms step_avg:94.11ms +step:302/1695 train_time:28421ms step_avg:94.11ms +step:303/1695 train_time:28515ms step_avg:94.11ms +step:304/1695 train_time:28609ms step_avg:94.11ms +step:305/1695 train_time:28705ms step_avg:94.11ms +step:306/1695 train_time:28799ms step_avg:94.11ms +step:307/1695 train_time:28893ms step_avg:94.11ms +step:308/1695 train_time:28987ms step_avg:94.11ms +step:309/1695 train_time:29081ms step_avg:94.11ms +step:310/1695 train_time:29175ms step_avg:94.11ms +step:311/1695 train_time:29269ms step_avg:94.11ms +step:312/1695 train_time:29364ms step_avg:94.12ms +step:313/1695 train_time:29459ms step_avg:94.12ms +step:314/1695 train_time:29553ms step_avg:94.12ms +step:315/1695 train_time:29647ms step_avg:94.12ms +step:316/1695 train_time:29742ms step_avg:94.12ms +step:317/1695 train_time:29836ms step_avg:94.12ms +step:318/1695 train_time:29930ms step_avg:94.12ms +step:319/1695 train_time:30025ms step_avg:94.12ms +step:320/1695 train_time:30120ms step_avg:94.13ms +step:321/1695 train_time:30214ms step_avg:94.12ms +step:322/1695 train_time:30308ms step_avg:94.12ms +step:323/1695 train_time:30402ms step_avg:94.12ms +step:324/1695 train_time:30495ms step_avg:94.12ms +step:325/1695 train_time:30589ms step_avg:94.12ms +step:326/1695 train_time:30684ms step_avg:94.12ms +step:327/1695 train_time:30779ms step_avg:94.13ms +step:328/1695 train_time:30874ms step_avg:94.13ms +step:329/1695 train_time:30968ms step_avg:94.13ms +step:330/1695 train_time:31062ms step_avg:94.13ms +step:331/1695 train_time:31156ms step_avg:94.13ms +step:332/1695 train_time:31251ms step_avg:94.13ms +step:333/1695 train_time:31346ms step_avg:94.13ms +step:334/1695 train_time:31441ms step_avg:94.13ms +step:335/1695 train_time:31535ms step_avg:94.13ms +step:336/1695 train_time:31629ms step_avg:94.13ms +step:337/1695 train_time:31724ms step_avg:94.14ms +step:338/1695 train_time:31819ms step_avg:94.14ms +step:339/1695 train_time:31913ms step_avg:94.14ms +step:340/1695 train_time:32007ms step_avg:94.14ms +step:341/1695 train_time:32102ms step_avg:94.14ms +step:342/1695 train_time:32196ms step_avg:94.14ms +step:343/1695 train_time:32290ms step_avg:94.14ms +step:344/1695 train_time:32385ms step_avg:94.14ms +step:345/1695 train_time:32480ms step_avg:94.15ms +step:346/1695 train_time:32574ms step_avg:94.14ms +step:347/1695 train_time:32668ms step_avg:94.14ms +step:348/1695 train_time:32763ms step_avg:94.15ms +step:349/1695 train_time:32857ms step_avg:94.15ms +step:350/1695 train_time:32952ms step_avg:94.15ms +step:351/1695 train_time:33047ms step_avg:94.15ms +step:352/1695 train_time:33140ms step_avg:94.15ms +step:353/1695 train_time:33234ms step_avg:94.15ms +step:354/1695 train_time:33328ms step_avg:94.15ms +step:355/1695 train_time:33423ms step_avg:94.15ms +step:356/1695 train_time:33517ms step_avg:94.15ms +step:357/1695 
train_time:33611ms step_avg:94.15ms +step:358/1695 train_time:33706ms step_avg:94.15ms +step:359/1695 train_time:33800ms step_avg:94.15ms +step:360/1695 train_time:33894ms step_avg:94.15ms +step:361/1695 train_time:33988ms step_avg:94.15ms +step:362/1695 train_time:34083ms step_avg:94.15ms +step:363/1695 train_time:34177ms step_avg:94.15ms +step:364/1695 train_time:34272ms step_avg:94.15ms +step:365/1695 train_time:34366ms step_avg:94.15ms +step:366/1695 train_time:34460ms step_avg:94.15ms +step:367/1695 train_time:34554ms step_avg:94.15ms +step:368/1695 train_time:34650ms step_avg:94.16ms +step:369/1695 train_time:34745ms step_avg:94.16ms +step:370/1695 train_time:34840ms step_avg:94.16ms +step:371/1695 train_time:34934ms step_avg:94.16ms +step:372/1695 train_time:35028ms step_avg:94.16ms +step:373/1695 train_time:35123ms step_avg:94.16ms +step:374/1695 train_time:35217ms step_avg:94.16ms +step:375/1695 train_time:35311ms step_avg:94.16ms +step:375/1695 val_loss:3.8846 train_time:35404ms step_avg:94.41ms +step:376/1695 train_time:35430ms step_avg:94.23ms +step:377/1695 train_time:35510ms step_avg:94.19ms +step:378/1695 train_time:35609ms step_avg:94.20ms +step:379/1695 train_time:35707ms step_avg:94.21ms +step:380/1695 train_time:35803ms step_avg:94.22ms +step:381/1695 train_time:35899ms step_avg:94.22ms +step:382/1695 train_time:35995ms step_avg:94.23ms +step:383/1695 train_time:36091ms step_avg:94.23ms +step:384/1695 train_time:36186ms step_avg:94.23ms +step:385/1695 train_time:36281ms step_avg:94.24ms +step:386/1695 train_time:36377ms step_avg:94.24ms +step:387/1695 train_time:36475ms step_avg:94.25ms +step:388/1695 train_time:36574ms step_avg:94.26ms +step:389/1695 train_time:36672ms step_avg:94.27ms +step:390/1695 train_time:36768ms step_avg:94.28ms +step:391/1695 train_time:36863ms step_avg:94.28ms +step:392/1695 train_time:36959ms step_avg:94.28ms +step:393/1695 train_time:37055ms step_avg:94.29ms +step:394/1695 train_time:37151ms step_avg:94.29ms +step:395/1695 train_time:37246ms step_avg:94.29ms +step:396/1695 train_time:37342ms step_avg:94.30ms +step:397/1695 train_time:37438ms step_avg:94.30ms +step:398/1695 train_time:37536ms step_avg:94.31ms +step:399/1695 train_time:37634ms step_avg:94.32ms +step:400/1695 train_time:37731ms step_avg:94.33ms +step:401/1695 train_time:37828ms step_avg:94.33ms +step:402/1695 train_time:37923ms step_avg:94.34ms +step:403/1695 train_time:38020ms step_avg:94.34ms +step:404/1695 train_time:38116ms step_avg:94.35ms +step:405/1695 train_time:38212ms step_avg:94.35ms +step:406/1695 train_time:38307ms step_avg:94.35ms +step:407/1695 train_time:38403ms step_avg:94.36ms +step:408/1695 train_time:38499ms step_avg:94.36ms +step:409/1695 train_time:38597ms step_avg:94.37ms +step:410/1695 train_time:38693ms step_avg:94.37ms +step:411/1695 train_time:38790ms step_avg:94.38ms +step:412/1695 train_time:38886ms step_avg:94.38ms +step:413/1695 train_time:38982ms step_avg:94.39ms +step:414/1695 train_time:39078ms step_avg:94.39ms +step:415/1695 train_time:39175ms step_avg:94.40ms +step:416/1695 train_time:39271ms step_avg:94.40ms +step:417/1695 train_time:39367ms step_avg:94.41ms +step:418/1695 train_time:39463ms step_avg:94.41ms +step:419/1695 train_time:39560ms step_avg:94.41ms +step:420/1695 train_time:39656ms step_avg:94.42ms +step:421/1695 train_time:39753ms step_avg:94.42ms +step:422/1695 train_time:39849ms step_avg:94.43ms +step:423/1695 train_time:39945ms step_avg:94.43ms +step:424/1695 train_time:40041ms step_avg:94.44ms +step:425/1695 train_time:40138ms 
step_avg:94.44ms +step:426/1695 train_time:40235ms step_avg:94.45ms +step:427/1695 train_time:40331ms step_avg:94.45ms +step:428/1695 train_time:40427ms step_avg:94.46ms +step:429/1695 train_time:40523ms step_avg:94.46ms +step:430/1695 train_time:40620ms step_avg:94.47ms +step:431/1695 train_time:40717ms step_avg:94.47ms +step:432/1695 train_time:40815ms step_avg:94.48ms +step:433/1695 train_time:40912ms step_avg:94.49ms +step:434/1695 train_time:41009ms step_avg:94.49ms +step:435/1695 train_time:41104ms step_avg:94.49ms +step:436/1695 train_time:41200ms step_avg:94.49ms +step:437/1695 train_time:41296ms step_avg:94.50ms +step:438/1695 train_time:41392ms step_avg:94.50ms +step:439/1695 train_time:41488ms step_avg:94.51ms +step:440/1695 train_time:41583ms step_avg:94.51ms +step:441/1695 train_time:41680ms step_avg:94.51ms +step:442/1695 train_time:41778ms step_avg:94.52ms +step:443/1695 train_time:41874ms step_avg:94.52ms +step:444/1695 train_time:41971ms step_avg:94.53ms +step:445/1695 train_time:42067ms step_avg:94.53ms +step:446/1695 train_time:42163ms step_avg:94.53ms +step:447/1695 train_time:42259ms step_avg:94.54ms +step:448/1695 train_time:42356ms step_avg:94.54ms +step:449/1695 train_time:42452ms step_avg:94.55ms +step:450/1695 train_time:42548ms step_avg:94.55ms +step:451/1695 train_time:42643ms step_avg:94.55ms +step:452/1695 train_time:42739ms step_avg:94.56ms +step:453/1695 train_time:42836ms step_avg:94.56ms +step:454/1695 train_time:42933ms step_avg:94.57ms +step:455/1695 train_time:43029ms step_avg:94.57ms +step:456/1695 train_time:43125ms step_avg:94.57ms +step:457/1695 train_time:43220ms step_avg:94.57ms +step:458/1695 train_time:43316ms step_avg:94.58ms +step:459/1695 train_time:43413ms step_avg:94.58ms +step:460/1695 train_time:43509ms step_avg:94.59ms +step:461/1695 train_time:43605ms step_avg:94.59ms +step:462/1695 train_time:43700ms step_avg:94.59ms +step:463/1695 train_time:43796ms step_avg:94.59ms +step:464/1695 train_time:43893ms step_avg:94.60ms +step:465/1695 train_time:43989ms step_avg:94.60ms +step:466/1695 train_time:44085ms step_avg:94.60ms +step:467/1695 train_time:44180ms step_avg:94.60ms +step:468/1695 train_time:44276ms step_avg:94.61ms +step:469/1695 train_time:44372ms step_avg:94.61ms +step:470/1695 train_time:44469ms step_avg:94.61ms +step:471/1695 train_time:44564ms step_avg:94.62ms +step:472/1695 train_time:44661ms step_avg:94.62ms +step:473/1695 train_time:44757ms step_avg:94.62ms +step:474/1695 train_time:44854ms step_avg:94.63ms +step:475/1695 train_time:44951ms step_avg:94.63ms +step:476/1695 train_time:45047ms step_avg:94.64ms +step:477/1695 train_time:45142ms step_avg:94.64ms +step:478/1695 train_time:45238ms step_avg:94.64ms +step:479/1695 train_time:45334ms step_avg:94.64ms +step:480/1695 train_time:45430ms step_avg:94.65ms +step:481/1695 train_time:45527ms step_avg:94.65ms +step:482/1695 train_time:45622ms step_avg:94.65ms +step:483/1695 train_time:45718ms step_avg:94.66ms +step:484/1695 train_time:45816ms step_avg:94.66ms +step:485/1695 train_time:45913ms step_avg:94.67ms +step:486/1695 train_time:46009ms step_avg:94.67ms +step:487/1695 train_time:46105ms step_avg:94.67ms +step:488/1695 train_time:46200ms step_avg:94.67ms +step:489/1695 train_time:46296ms step_avg:94.68ms +step:490/1695 train_time:46393ms step_avg:94.68ms +step:491/1695 train_time:46490ms step_avg:94.68ms +step:492/1695 train_time:46586ms step_avg:94.69ms +step:493/1695 train_time:46681ms step_avg:94.69ms +step:494/1695 train_time:46777ms step_avg:94.69ms +step:495/1695 
train_time:46874ms step_avg:94.69ms +step:496/1695 train_time:46970ms step_avg:94.70ms +step:497/1695 train_time:47066ms step_avg:94.70ms +step:498/1695 train_time:47161ms step_avg:94.70ms +step:499/1695 train_time:47258ms step_avg:94.71ms +step:500/1695 train_time:47356ms step_avg:94.71ms +step:500/1695 val_loss:3.7382 train_time:47450ms step_avg:94.90ms +step:501/1695 train_time:47475ms step_avg:94.76ms +step:502/1695 train_time:47558ms step_avg:94.74ms +step:503/1695 train_time:47660ms step_avg:94.75ms +step:504/1695 train_time:47758ms step_avg:94.76ms +step:505/1695 train_time:47854ms step_avg:94.76ms +step:506/1695 train_time:47950ms step_avg:94.76ms +step:507/1695 train_time:48046ms step_avg:94.76ms +step:508/1695 train_time:48142ms step_avg:94.77ms +step:509/1695 train_time:48238ms step_avg:94.77ms +step:510/1695 train_time:48334ms step_avg:94.77ms +step:511/1695 train_time:48429ms step_avg:94.77ms +step:512/1695 train_time:48526ms step_avg:94.78ms +step:513/1695 train_time:48624ms step_avg:94.78ms +step:514/1695 train_time:48723ms step_avg:94.79ms +step:515/1695 train_time:48821ms step_avg:94.80ms +step:516/1695 train_time:48919ms step_avg:94.80ms +step:517/1695 train_time:49016ms step_avg:94.81ms +step:518/1695 train_time:49112ms step_avg:94.81ms +step:519/1695 train_time:49208ms step_avg:94.81ms +step:520/1695 train_time:49303ms step_avg:94.81ms +step:521/1695 train_time:49399ms step_avg:94.82ms +step:522/1695 train_time:49497ms step_avg:94.82ms +step:523/1695 train_time:49595ms step_avg:94.83ms +step:524/1695 train_time:49692ms step_avg:94.83ms +step:525/1695 train_time:49789ms step_avg:94.84ms +step:526/1695 train_time:49884ms step_avg:94.84ms +step:527/1695 train_time:49981ms step_avg:94.84ms +step:528/1695 train_time:50078ms step_avg:94.85ms +step:529/1695 train_time:50177ms step_avg:94.85ms +step:530/1695 train_time:50272ms step_avg:94.85ms +step:531/1695 train_time:50367ms step_avg:94.85ms +step:532/1695 train_time:50463ms step_avg:94.86ms +step:533/1695 train_time:50560ms step_avg:94.86ms +step:534/1695 train_time:50657ms step_avg:94.86ms +step:535/1695 train_time:50755ms step_avg:94.87ms +step:536/1695 train_time:50851ms step_avg:94.87ms +step:537/1695 train_time:50948ms step_avg:94.87ms +step:538/1695 train_time:51044ms step_avg:94.88ms +step:539/1695 train_time:51140ms step_avg:94.88ms +step:540/1695 train_time:51237ms step_avg:94.88ms +step:541/1695 train_time:51333ms step_avg:94.89ms +step:542/1695 train_time:51429ms step_avg:94.89ms +step:543/1695 train_time:51525ms step_avg:94.89ms +step:544/1695 train_time:51621ms step_avg:94.89ms +step:545/1695 train_time:51718ms step_avg:94.90ms +step:546/1695 train_time:51815ms step_avg:94.90ms +step:547/1695 train_time:51911ms step_avg:94.90ms +step:548/1695 train_time:52007ms step_avg:94.90ms +step:549/1695 train_time:52103ms step_avg:94.91ms +step:550/1695 train_time:52201ms step_avg:94.91ms +step:551/1695 train_time:52297ms step_avg:94.91ms +step:552/1695 train_time:52394ms step_avg:94.92ms +step:553/1695 train_time:52490ms step_avg:94.92ms +step:554/1695 train_time:52586ms step_avg:94.92ms +step:555/1695 train_time:52683ms step_avg:94.92ms +step:556/1695 train_time:52780ms step_avg:94.93ms +step:557/1695 train_time:52878ms step_avg:94.93ms +step:558/1695 train_time:52975ms step_avg:94.94ms +step:559/1695 train_time:53071ms step_avg:94.94ms +step:560/1695 train_time:53167ms step_avg:94.94ms +step:561/1695 train_time:53263ms step_avg:94.94ms +step:562/1695 train_time:53361ms step_avg:94.95ms +step:563/1695 train_time:53458ms 
step_avg:94.95ms +step:564/1695 train_time:53555ms step_avg:94.96ms +step:565/1695 train_time:53651ms step_avg:94.96ms +step:566/1695 train_time:53747ms step_avg:94.96ms +step:567/1695 train_time:53844ms step_avg:94.96ms +step:568/1695 train_time:53942ms step_avg:94.97ms +step:569/1695 train_time:54039ms step_avg:94.97ms +step:570/1695 train_time:54136ms step_avg:94.98ms +step:571/1695 train_time:54233ms step_avg:94.98ms +step:572/1695 train_time:54330ms step_avg:94.98ms +step:573/1695 train_time:54425ms step_avg:94.98ms +step:574/1695 train_time:54522ms step_avg:94.99ms +step:575/1695 train_time:54619ms step_avg:94.99ms +step:576/1695 train_time:54716ms step_avg:94.99ms +step:577/1695 train_time:54812ms step_avg:95.00ms +step:578/1695 train_time:54909ms step_avg:95.00ms +step:579/1695 train_time:55005ms step_avg:95.00ms +step:580/1695 train_time:55101ms step_avg:95.00ms +step:581/1695 train_time:55198ms step_avg:95.01ms +step:582/1695 train_time:55295ms step_avg:95.01ms +step:583/1695 train_time:55391ms step_avg:95.01ms +step:584/1695 train_time:55487ms step_avg:95.01ms +step:585/1695 train_time:55584ms step_avg:95.01ms +step:586/1695 train_time:55681ms step_avg:95.02ms +step:587/1695 train_time:55777ms step_avg:95.02ms +step:588/1695 train_time:55875ms step_avg:95.03ms +step:589/1695 train_time:55972ms step_avg:95.03ms +step:590/1695 train_time:56068ms step_avg:95.03ms +step:591/1695 train_time:56165ms step_avg:95.03ms +step:592/1695 train_time:56262ms step_avg:95.04ms +step:593/1695 train_time:56360ms step_avg:95.04ms +step:594/1695 train_time:56457ms step_avg:95.05ms +step:595/1695 train_time:56554ms step_avg:95.05ms +step:596/1695 train_time:56650ms step_avg:95.05ms +step:597/1695 train_time:56746ms step_avg:95.05ms +step:598/1695 train_time:56842ms step_avg:95.05ms +step:599/1695 train_time:56939ms step_avg:95.06ms +step:600/1695 train_time:57036ms step_avg:95.06ms +step:601/1695 train_time:57133ms step_avg:95.06ms +step:602/1695 train_time:57229ms step_avg:95.06ms +step:603/1695 train_time:57325ms step_avg:95.07ms +step:604/1695 train_time:57421ms step_avg:95.07ms +step:605/1695 train_time:57518ms step_avg:95.07ms +step:606/1695 train_time:57615ms step_avg:95.07ms +step:607/1695 train_time:57712ms step_avg:95.08ms +step:608/1695 train_time:57808ms step_avg:95.08ms +step:609/1695 train_time:57904ms step_avg:95.08ms +step:610/1695 train_time:58000ms step_avg:95.08ms +step:611/1695 train_time:58096ms step_avg:95.08ms +step:612/1695 train_time:58193ms step_avg:95.09ms +step:613/1695 train_time:58289ms step_avg:95.09ms +step:614/1695 train_time:58385ms step_avg:95.09ms +step:615/1695 train_time:58481ms step_avg:95.09ms +step:616/1695 train_time:58578ms step_avg:95.09ms +step:617/1695 train_time:58675ms step_avg:95.10ms +step:618/1695 train_time:58772ms step_avg:95.10ms +step:619/1695 train_time:58868ms step_avg:95.10ms +step:620/1695 train_time:58964ms step_avg:95.10ms +step:621/1695 train_time:59062ms step_avg:95.11ms +step:622/1695 train_time:59158ms step_avg:95.11ms +step:623/1695 train_time:59255ms step_avg:95.11ms +step:624/1695 train_time:59351ms step_avg:95.11ms +step:625/1695 train_time:59448ms step_avg:95.12ms +step:625/1695 val_loss:3.6536 train_time:59542ms step_avg:95.27ms +step:626/1695 train_time:59568ms step_avg:95.16ms +step:627/1695 train_time:59649ms step_avg:95.13ms +step:628/1695 train_time:59750ms step_avg:95.14ms +step:629/1695 train_time:59847ms step_avg:95.15ms +step:630/1695 train_time:59945ms step_avg:95.15ms +step:631/1695 train_time:60041ms step_avg:95.15ms 
+step:632/1695 train_time:60139ms step_avg:95.16ms +step:633/1695 train_time:60236ms step_avg:95.16ms +step:634/1695 train_time:60333ms step_avg:95.16ms +step:635/1695 train_time:60429ms step_avg:95.16ms +step:636/1695 train_time:60527ms step_avg:95.17ms +step:637/1695 train_time:60625ms step_avg:95.17ms +step:638/1695 train_time:60723ms step_avg:95.18ms +step:639/1695 train_time:60823ms step_avg:95.18ms +step:640/1695 train_time:60922ms step_avg:95.19ms +step:641/1695 train_time:61020ms step_avg:95.19ms +step:642/1695 train_time:61118ms step_avg:95.20ms +step:643/1695 train_time:61216ms step_avg:95.20ms +step:644/1695 train_time:61313ms step_avg:95.21ms +step:645/1695 train_time:61410ms step_avg:95.21ms +step:646/1695 train_time:61507ms step_avg:95.21ms +step:647/1695 train_time:61825ms step_avg:95.56ms +step:648/1695 train_time:61921ms step_avg:95.56ms +step:649/1695 train_time:62019ms step_avg:95.56ms +step:650/1695 train_time:62117ms step_avg:95.56ms +step:651/1695 train_time:62213ms step_avg:95.57ms +step:652/1695 train_time:62310ms step_avg:95.57ms +step:653/1695 train_time:62407ms step_avg:95.57ms +step:654/1695 train_time:62503ms step_avg:95.57ms +step:655/1695 train_time:62600ms step_avg:95.57ms +step:656/1695 train_time:62700ms step_avg:95.58ms +step:657/1695 train_time:62802ms step_avg:95.59ms +step:658/1695 train_time:62902ms step_avg:95.60ms +step:659/1695 train_time:63001ms step_avg:95.60ms +step:660/1695 train_time:63099ms step_avg:95.60ms +step:661/1695 train_time:63197ms step_avg:95.61ms +step:662/1695 train_time:63295ms step_avg:95.61ms +step:663/1695 train_time:63392ms step_avg:95.61ms +step:664/1695 train_time:63490ms step_avg:95.62ms +step:665/1695 train_time:63586ms step_avg:95.62ms +step:666/1695 train_time:63684ms step_avg:95.62ms +step:667/1695 train_time:63782ms step_avg:95.63ms +step:668/1695 train_time:63881ms step_avg:95.63ms +step:669/1695 train_time:63979ms step_avg:95.63ms +step:670/1695 train_time:64077ms step_avg:95.64ms +step:671/1695 train_time:64175ms step_avg:95.64ms +step:672/1695 train_time:64273ms step_avg:95.64ms +step:673/1695 train_time:64370ms step_avg:95.65ms +step:674/1695 train_time:64467ms step_avg:95.65ms +step:675/1695 train_time:64563ms step_avg:95.65ms +step:676/1695 train_time:64661ms step_avg:95.65ms +step:677/1695 train_time:64759ms step_avg:95.66ms +step:678/1695 train_time:64858ms step_avg:95.66ms +step:679/1695 train_time:64957ms step_avg:95.67ms +step:680/1695 train_time:65055ms step_avg:95.67ms +step:681/1695 train_time:65153ms step_avg:95.67ms +step:682/1695 train_time:65251ms step_avg:95.68ms +step:683/1695 train_time:65349ms step_avg:95.68ms +step:684/1695 train_time:65447ms step_avg:95.68ms +step:685/1695 train_time:65544ms step_avg:95.68ms +step:686/1695 train_time:65641ms step_avg:95.69ms +step:687/1695 train_time:65739ms step_avg:95.69ms +step:688/1695 train_time:65837ms step_avg:95.69ms +step:689/1695 train_time:65936ms step_avg:95.70ms +step:690/1695 train_time:66035ms step_avg:95.70ms +step:691/1695 train_time:66133ms step_avg:95.71ms +step:692/1695 train_time:66231ms step_avg:95.71ms +step:693/1695 train_time:66329ms step_avg:95.71ms +step:694/1695 train_time:66427ms step_avg:95.72ms +step:695/1695 train_time:66525ms step_avg:95.72ms +step:696/1695 train_time:66622ms step_avg:95.72ms +step:697/1695 train_time:66720ms step_avg:95.72ms +step:698/1695 train_time:66818ms step_avg:95.73ms +step:699/1695 train_time:66917ms step_avg:95.73ms +step:700/1695 train_time:67015ms step_avg:95.74ms +step:701/1695 train_time:67358ms 
step_avg:96.09ms +step:702/1695 train_time:67454ms step_avg:96.09ms +step:703/1695 train_time:67784ms step_avg:96.42ms +step:704/1695 train_time:67879ms step_avg:96.42ms +step:705/1695 train_time:67976ms step_avg:96.42ms +step:706/1695 train_time:68073ms step_avg:96.42ms +step:707/1695 train_time:68170ms step_avg:96.42ms +step:708/1695 train_time:68267ms step_avg:96.42ms +step:709/1695 train_time:68364ms step_avg:96.42ms +step:710/1695 train_time:68461ms step_avg:96.42ms +step:711/1695 train_time:68558ms step_avg:96.42ms +step:712/1695 train_time:68655ms step_avg:96.43ms +step:713/1695 train_time:68757ms step_avg:96.43ms +step:714/1695 train_time:68860ms step_avg:96.44ms +step:715/1695 train_time:68958ms step_avg:96.45ms +step:716/1695 train_time:69056ms step_avg:96.45ms +step:717/1695 train_time:69153ms step_avg:96.45ms +step:718/1695 train_time:69251ms step_avg:96.45ms +step:719/1695 train_time:69348ms step_avg:96.45ms +step:720/1695 train_time:69445ms step_avg:96.45ms +step:721/1695 train_time:69541ms step_avg:96.45ms +step:722/1695 train_time:69639ms step_avg:96.45ms +step:723/1695 train_time:69738ms step_avg:96.46ms +step:724/1695 train_time:69837ms step_avg:96.46ms +step:725/1695 train_time:69936ms step_avg:96.46ms +step:726/1695 train_time:70035ms step_avg:96.47ms +step:727/1695 train_time:70133ms step_avg:96.47ms +step:728/1695 train_time:70231ms step_avg:96.47ms +step:729/1695 train_time:70329ms step_avg:96.47ms +step:730/1695 train_time:70426ms step_avg:96.47ms +step:731/1695 train_time:70523ms step_avg:96.48ms +step:732/1695 train_time:70620ms step_avg:96.48ms +step:733/1695 train_time:70718ms step_avg:96.48ms +step:734/1695 train_time:70816ms step_avg:96.48ms +step:735/1695 train_time:70914ms step_avg:96.48ms +step:736/1695 train_time:71012ms step_avg:96.48ms +step:737/1695 train_time:71110ms step_avg:96.49ms +step:738/1695 train_time:71208ms step_avg:96.49ms +step:739/1695 train_time:71306ms step_avg:96.49ms +step:740/1695 train_time:71402ms step_avg:96.49ms +step:741/1695 train_time:71500ms step_avg:96.49ms +step:742/1695 train_time:71598ms step_avg:96.49ms +step:743/1695 train_time:71696ms step_avg:96.50ms +step:744/1695 train_time:71794ms step_avg:96.50ms +step:745/1695 train_time:71893ms step_avg:96.50ms +step:746/1695 train_time:71990ms step_avg:96.50ms +step:747/1695 train_time:72088ms step_avg:96.50ms +step:748/1695 train_time:72187ms step_avg:96.51ms +step:749/1695 train_time:72285ms step_avg:96.51ms +step:750/1695 train_time:72382ms step_avg:96.51ms +step:750/1695 val_loss:3.5872 train_time:72478ms step_avg:96.64ms +step:751/1695 train_time:72504ms step_avg:96.54ms +step:752/1695 train_time:72586ms step_avg:96.52ms +step:753/1695 train_time:72687ms step_avg:96.53ms +step:754/1695 train_time:72784ms step_avg:96.53ms +step:755/1695 train_time:72882ms step_avg:96.53ms +step:756/1695 train_time:72981ms step_avg:96.54ms +step:757/1695 train_time:73079ms step_avg:96.54ms +step:758/1695 train_time:73177ms step_avg:96.54ms +step:759/1695 train_time:73275ms step_avg:96.54ms +step:760/1695 train_time:73372ms step_avg:96.54ms +step:761/1695 train_time:73471ms step_avg:96.54ms +step:762/1695 train_time:73569ms step_avg:96.55ms +step:763/1695 train_time:73668ms step_avg:96.55ms +step:764/1695 train_time:73766ms step_avg:96.55ms +step:765/1695 train_time:73864ms step_avg:96.55ms +step:766/1695 train_time:73961ms step_avg:96.56ms +step:767/1695 train_time:74060ms step_avg:96.56ms +step:768/1695 train_time:74158ms step_avg:96.56ms +step:769/1695 train_time:74255ms step_avg:96.56ms 
+step:770/1695 train_time:74354ms step_avg:96.56ms +step:771/1695 train_time:74452ms step_avg:96.57ms +step:772/1695 train_time:74551ms step_avg:96.57ms +step:773/1695 train_time:74650ms step_avg:96.57ms +step:774/1695 train_time:74748ms step_avg:96.57ms +step:775/1695 train_time:74846ms step_avg:96.58ms +step:776/1695 train_time:74944ms step_avg:96.58ms +step:777/1695 train_time:75042ms step_avg:96.58ms +step:778/1695 train_time:75360ms step_avg:96.86ms +step:779/1695 train_time:75456ms step_avg:96.86ms +step:780/1695 train_time:75554ms step_avg:96.86ms +step:781/1695 train_time:75650ms step_avg:96.86ms +step:782/1695 train_time:75747ms step_avg:96.86ms +step:783/1695 train_time:75844ms step_avg:96.86ms +step:784/1695 train_time:75942ms step_avg:96.86ms +step:785/1695 train_time:76040ms step_avg:96.87ms +step:786/1695 train_time:76138ms step_avg:96.87ms +step:787/1695 train_time:76236ms step_avg:96.87ms +step:788/1695 train_time:76339ms step_avg:96.88ms +step:789/1695 train_time:76728ms step_avg:97.25ms +step:790/1695 train_time:76779ms step_avg:97.19ms +step:791/1695 train_time:76876ms step_avg:97.19ms +step:792/1695 train_time:76973ms step_avg:97.19ms +step:793/1695 train_time:77070ms step_avg:97.19ms +step:794/1695 train_time:77167ms step_avg:97.19ms +step:795/1695 train_time:77264ms step_avg:97.19ms +step:796/1695 train_time:77362ms step_avg:97.19ms +step:797/1695 train_time:77783ms step_avg:97.59ms +step:798/1695 train_time:77879ms step_avg:97.59ms +step:799/1695 train_time:77976ms step_avg:97.59ms +step:800/1695 train_time:78074ms step_avg:97.59ms +step:801/1695 train_time:78170ms step_avg:97.59ms +step:802/1695 train_time:78513ms step_avg:97.90ms +step:803/1695 train_time:78608ms step_avg:97.89ms +step:804/1695 train_time:78706ms step_avg:97.89ms +step:805/1695 train_time:78803ms step_avg:97.89ms +step:806/1695 train_time:78901ms step_avg:97.89ms +step:807/1695 train_time:78999ms step_avg:97.89ms +step:808/1695 train_time:79096ms step_avg:97.89ms +step:809/1695 train_time:79194ms step_avg:97.89ms +step:810/1695 train_time:79291ms step_avg:97.89ms +step:811/1695 train_time:79388ms step_avg:97.89ms +step:812/1695 train_time:79489ms step_avg:97.89ms +step:813/1695 train_time:79589ms step_avg:97.90ms +step:814/1695 train_time:79686ms step_avg:97.89ms +step:815/1695 train_time:79784ms step_avg:97.89ms +step:816/1695 train_time:79881ms step_avg:97.89ms +step:817/1695 train_time:79979ms step_avg:97.89ms +step:818/1695 train_time:80077ms step_avg:97.89ms +step:819/1695 train_time:80173ms step_avg:97.89ms +step:820/1695 train_time:80271ms step_avg:97.89ms +step:821/1695 train_time:80368ms step_avg:97.89ms +step:822/1695 train_time:80467ms step_avg:97.89ms +step:823/1695 train_time:80565ms step_avg:97.89ms +step:824/1695 train_time:80663ms step_avg:97.89ms +step:825/1695 train_time:80762ms step_avg:97.89ms +step:826/1695 train_time:80860ms step_avg:97.89ms +step:827/1695 train_time:80959ms step_avg:97.89ms +step:828/1695 train_time:81057ms step_avg:97.89ms +step:829/1695 train_time:81154ms step_avg:97.89ms +step:830/1695 train_time:81251ms step_avg:97.89ms +step:831/1695 train_time:81350ms step_avg:97.89ms +step:832/1695 train_time:81448ms step_avg:97.89ms +step:833/1695 train_time:81546ms step_avg:97.89ms +step:834/1695 train_time:81644ms step_avg:97.89ms +step:835/1695 train_time:81742ms step_avg:97.89ms +step:836/1695 train_time:81841ms step_avg:97.90ms +step:837/1695 train_time:81939ms step_avg:97.90ms +step:838/1695 train_time:82038ms step_avg:97.90ms +step:839/1695 train_time:82135ms 
step_avg:97.90ms +step:840/1695 train_time:82233ms step_avg:97.90ms +step:841/1695 train_time:82332ms step_avg:97.90ms +step:842/1695 train_time:82431ms step_avg:97.90ms +step:843/1695 train_time:82529ms step_avg:97.90ms +step:844/1695 train_time:82627ms step_avg:97.90ms +step:845/1695 train_time:82724ms step_avg:97.90ms +step:846/1695 train_time:82822ms step_avg:97.90ms +step:847/1695 train_time:82920ms step_avg:97.90ms +step:848/1695 train_time:83018ms step_avg:97.90ms +step:849/1695 train_time:83116ms step_avg:97.90ms +step:850/1695 train_time:83214ms step_avg:97.90ms +step:851/1695 train_time:83312ms step_avg:97.90ms +step:852/1695 train_time:83410ms step_avg:97.90ms +step:853/1695 train_time:83509ms step_avg:97.90ms +step:854/1695 train_time:83606ms step_avg:97.90ms +step:855/1695 train_time:83704ms step_avg:97.90ms +step:856/1695 train_time:83802ms step_avg:97.90ms +step:857/1695 train_time:83899ms step_avg:97.90ms +step:858/1695 train_time:83998ms step_avg:97.90ms +step:859/1695 train_time:84097ms step_avg:97.90ms +step:860/1695 train_time:84196ms step_avg:97.90ms +step:861/1695 train_time:84294ms step_avg:97.90ms +step:862/1695 train_time:84392ms step_avg:97.90ms +step:863/1695 train_time:84490ms step_avg:97.90ms +step:864/1695 train_time:84589ms step_avg:97.90ms +step:865/1695 train_time:84688ms step_avg:97.90ms +step:866/1695 train_time:84786ms step_avg:97.91ms +step:867/1695 train_time:84884ms step_avg:97.90ms +step:868/1695 train_time:84981ms step_avg:97.90ms +step:869/1695 train_time:85078ms step_avg:97.90ms +step:870/1695 train_time:85177ms step_avg:97.90ms +step:871/1695 train_time:85276ms step_avg:97.91ms +step:872/1695 train_time:85375ms step_avg:97.91ms +step:873/1695 train_time:85473ms step_avg:97.91ms +step:874/1695 train_time:85572ms step_avg:97.91ms +step:875/1695 train_time:85671ms step_avg:97.91ms +step:875/1695 val_loss:3.5420 train_time:85767ms step_avg:98.02ms +step:876/1695 train_time:85793ms step_avg:97.94ms +step:877/1695 train_time:85875ms step_avg:97.92ms +step:878/1695 train_time:85975ms step_avg:97.92ms +step:879/1695 train_time:86074ms step_avg:97.92ms +step:880/1695 train_time:86172ms step_avg:97.92ms +step:881/1695 train_time:86270ms step_avg:97.92ms +step:882/1695 train_time:86369ms step_avg:97.92ms +step:883/1695 train_time:86467ms step_avg:97.92ms +step:884/1695 train_time:86566ms step_avg:97.93ms +step:885/1695 train_time:86664ms step_avg:97.93ms +step:886/1695 train_time:86763ms step_avg:97.93ms +step:887/1695 train_time:86863ms step_avg:97.93ms +step:888/1695 train_time:86964ms step_avg:97.93ms +step:889/1695 train_time:87064ms step_avg:97.93ms +step:890/1695 train_time:87163ms step_avg:97.94ms +step:891/1695 train_time:87262ms step_avg:97.94ms +step:892/1695 train_time:87362ms step_avg:97.94ms +step:893/1695 train_time:87462ms step_avg:97.94ms +step:894/1695 train_time:87562ms step_avg:97.94ms +step:895/1695 train_time:87661ms step_avg:97.94ms +step:896/1695 train_time:87761ms step_avg:97.95ms +step:897/1695 train_time:87861ms step_avg:97.95ms +step:898/1695 train_time:87962ms step_avg:97.95ms +step:899/1695 train_time:88062ms step_avg:97.96ms +step:900/1695 train_time:88162ms step_avg:97.96ms +step:901/1695 train_time:88262ms step_avg:97.96ms +step:902/1695 train_time:88361ms step_avg:97.96ms +step:903/1695 train_time:88461ms step_avg:97.96ms +step:904/1695 train_time:88561ms step_avg:97.97ms +step:905/1695 train_time:88660ms step_avg:97.97ms +step:906/1695 train_time:88760ms step_avg:97.97ms +step:907/1695 train_time:88860ms step_avg:97.97ms 
+step:908/1695 train_time:88960ms step_avg:97.97ms +step:909/1695 train_time:89061ms step_avg:97.98ms +step:910/1695 train_time:89161ms step_avg:97.98ms +step:911/1695 train_time:89262ms step_avg:97.98ms +step:912/1695 train_time:89362ms step_avg:97.98ms +step:913/1695 train_time:89461ms step_avg:97.99ms +step:914/1695 train_time:89561ms step_avg:97.99ms +step:915/1695 train_time:89661ms step_avg:97.99ms +step:916/1695 train_time:89760ms step_avg:97.99ms +step:917/1695 train_time:89860ms step_avg:97.99ms +step:918/1695 train_time:89960ms step_avg:98.00ms +step:919/1695 train_time:90059ms step_avg:98.00ms +step:920/1695 train_time:90159ms step_avg:98.00ms +step:921/1695 train_time:90259ms step_avg:98.00ms +step:922/1695 train_time:90359ms step_avg:98.00ms +step:923/1695 train_time:90459ms step_avg:98.01ms +step:924/1695 train_time:90560ms step_avg:98.01ms +step:925/1695 train_time:90661ms step_avg:98.01ms +step:926/1695 train_time:90760ms step_avg:98.01ms +step:927/1695 train_time:90859ms step_avg:98.01ms +step:928/1695 train_time:90959ms step_avg:98.02ms +step:929/1695 train_time:91059ms step_avg:98.02ms +step:930/1695 train_time:91160ms step_avg:98.02ms +step:931/1695 train_time:91260ms step_avg:98.02ms +step:932/1695 train_time:91362ms step_avg:98.03ms +step:933/1695 train_time:91462ms step_avg:98.03ms +step:934/1695 train_time:91561ms step_avg:98.03ms +step:935/1695 train_time:91661ms step_avg:98.03ms +step:936/1695 train_time:91762ms step_avg:98.04ms +step:937/1695 train_time:91862ms step_avg:98.04ms +step:938/1695 train_time:91961ms step_avg:98.04ms +step:939/1695 train_time:92062ms step_avg:98.04ms +step:940/1695 train_time:92162ms step_avg:98.04ms +step:941/1695 train_time:92262ms step_avg:98.05ms +step:942/1695 train_time:92362ms step_avg:98.05ms +step:943/1695 train_time:92462ms step_avg:98.05ms +step:944/1695 train_time:92561ms step_avg:98.05ms +step:945/1695 train_time:92662ms step_avg:98.05ms +step:946/1695 train_time:92761ms step_avg:98.06ms +step:947/1695 train_time:92861ms step_avg:98.06ms +step:948/1695 train_time:92961ms step_avg:98.06ms +step:949/1695 train_time:93060ms step_avg:98.06ms +step:950/1695 train_time:93160ms step_avg:98.06ms +step:951/1695 train_time:93260ms step_avg:98.07ms +step:952/1695 train_time:93360ms step_avg:98.07ms +step:953/1695 train_time:93460ms step_avg:98.07ms +step:954/1695 train_time:93561ms step_avg:98.07ms +step:955/1695 train_time:93662ms step_avg:98.08ms +step:956/1695 train_time:93762ms step_avg:98.08ms +step:957/1695 train_time:93861ms step_avg:98.08ms +step:958/1695 train_time:93961ms step_avg:98.08ms +step:959/1695 train_time:94060ms step_avg:98.08ms +step:960/1695 train_time:94160ms step_avg:98.08ms +step:961/1695 train_time:94261ms step_avg:98.09ms +step:962/1695 train_time:94362ms step_avg:98.09ms +step:963/1695 train_time:94462ms step_avg:98.09ms +step:964/1695 train_time:94563ms step_avg:98.09ms +step:965/1695 train_time:94663ms step_avg:98.10ms +step:966/1695 train_time:94762ms step_avg:98.10ms +step:967/1695 train_time:94862ms step_avg:98.10ms +step:968/1695 train_time:94962ms step_avg:98.10ms +step:969/1695 train_time:95061ms step_avg:98.10ms +step:970/1695 train_time:95160ms step_avg:98.10ms +step:971/1695 train_time:95260ms step_avg:98.10ms +step:972/1695 train_time:95359ms step_avg:98.11ms +step:973/1695 train_time:95459ms step_avg:98.11ms +step:974/1695 train_time:95559ms step_avg:98.11ms +step:975/1695 train_time:95659ms step_avg:98.11ms +step:976/1695 train_time:95760ms step_avg:98.11ms +step:977/1695 train_time:95860ms 
step_avg:98.12ms +step:978/1695 train_time:95960ms step_avg:98.12ms +step:979/1695 train_time:96060ms step_avg:98.12ms +step:980/1695 train_time:96160ms step_avg:98.12ms +step:981/1695 train_time:96259ms step_avg:98.12ms +step:982/1695 train_time:96360ms step_avg:98.13ms +step:983/1695 train_time:96460ms step_avg:98.13ms +step:984/1695 train_time:96561ms step_avg:98.13ms +step:985/1695 train_time:96661ms step_avg:98.13ms +step:986/1695 train_time:96762ms step_avg:98.14ms +step:987/1695 train_time:96862ms step_avg:98.14ms +step:988/1695 train_time:96961ms step_avg:98.14ms +step:989/1695 train_time:97061ms step_avg:98.14ms +step:990/1695 train_time:97161ms step_avg:98.14ms +step:991/1695 train_time:97262ms step_avg:98.14ms +step:992/1695 train_time:97361ms step_avg:98.15ms +step:993/1695 train_time:97460ms step_avg:98.15ms +step:994/1695 train_time:97561ms step_avg:98.15ms +step:995/1695 train_time:97662ms step_avg:98.15ms +step:996/1695 train_time:97762ms step_avg:98.15ms +step:997/1695 train_time:97862ms step_avg:98.16ms +step:998/1695 train_time:97961ms step_avg:98.16ms +step:999/1695 train_time:98061ms step_avg:98.16ms +step:1000/1695 train_time:98160ms step_avg:98.16ms +step:1000/1695 val_loss:3.4970 train_time:98257ms step_avg:98.26ms +step:1001/1695 train_time:98283ms step_avg:98.19ms +step:1002/1695 train_time:98371ms step_avg:98.17ms +step:1003/1695 train_time:98472ms step_avg:98.18ms +step:1004/1695 train_time:98572ms step_avg:98.18ms +step:1005/1695 train_time:98672ms step_avg:98.18ms +step:1006/1695 train_time:98772ms step_avg:98.18ms +step:1007/1695 train_time:98871ms step_avg:98.18ms +step:1008/1695 train_time:98970ms step_avg:98.18ms +step:1009/1695 train_time:99069ms step_avg:98.19ms +step:1010/1695 train_time:99168ms step_avg:98.19ms +step:1011/1695 train_time:99271ms step_avg:98.19ms +step:1012/1695 train_time:99373ms step_avg:98.20ms +step:1013/1695 train_time:99475ms step_avg:98.20ms +step:1014/1695 train_time:99575ms step_avg:98.20ms +step:1015/1695 train_time:99675ms step_avg:98.20ms +step:1016/1695 train_time:99774ms step_avg:98.20ms +step:1017/1695 train_time:99874ms step_avg:98.20ms +step:1018/1695 train_time:99972ms step_avg:98.20ms +step:1019/1695 train_time:100072ms step_avg:98.21ms +step:1020/1695 train_time:100172ms step_avg:98.21ms +step:1021/1695 train_time:100274ms step_avg:98.21ms +step:1022/1695 train_time:100375ms step_avg:98.21ms +step:1023/1695 train_time:100476ms step_avg:98.22ms +step:1024/1695 train_time:100578ms step_avg:98.22ms +step:1025/1695 train_time:100677ms step_avg:98.22ms +step:1026/1695 train_time:100778ms step_avg:98.22ms +step:1027/1695 train_time:100877ms step_avg:98.23ms +step:1028/1695 train_time:100978ms step_avg:98.23ms +step:1029/1695 train_time:101079ms step_avg:98.23ms +step:1030/1695 train_time:101179ms step_avg:98.23ms +step:1031/1695 train_time:101279ms step_avg:98.23ms +step:1032/1695 train_time:101378ms step_avg:98.23ms +step:1033/1695 train_time:101478ms step_avg:98.24ms +step:1034/1695 train_time:101577ms step_avg:98.24ms +step:1035/1695 train_time:101677ms step_avg:98.24ms +step:1036/1695 train_time:101777ms step_avg:98.24ms +step:1037/1695 train_time:101877ms step_avg:98.24ms +step:1038/1695 train_time:101977ms step_avg:98.24ms +step:1039/1695 train_time:102077ms step_avg:98.25ms +step:1040/1695 train_time:102178ms step_avg:98.25ms +step:1041/1695 train_time:102278ms step_avg:98.25ms +step:1042/1695 train_time:102377ms step_avg:98.25ms +step:1043/1695 train_time:102477ms step_avg:98.25ms +step:1044/1695 
train_time:102577ms step_avg:98.25ms +step:1045/1695 train_time:102677ms step_avg:98.26ms +step:1046/1695 train_time:102778ms step_avg:98.26ms +step:1047/1695 train_time:102878ms step_avg:98.26ms +step:1048/1695 train_time:102978ms step_avg:98.26ms +step:1049/1695 train_time:103078ms step_avg:98.26ms +step:1050/1695 train_time:103178ms step_avg:98.26ms +step:1051/1695 train_time:103277ms step_avg:98.27ms +step:1052/1695 train_time:103377ms step_avg:98.27ms +step:1053/1695 train_time:103477ms step_avg:98.27ms +step:1054/1695 train_time:103577ms step_avg:98.27ms +step:1055/1695 train_time:103677ms step_avg:98.27ms +step:1056/1695 train_time:103776ms step_avg:98.27ms +step:1057/1695 train_time:103876ms step_avg:98.27ms +step:1058/1695 train_time:103977ms step_avg:98.28ms +step:1059/1695 train_time:104077ms step_avg:98.28ms +step:1060/1695 train_time:104177ms step_avg:98.28ms +step:1061/1695 train_time:104277ms step_avg:98.28ms +step:1062/1695 train_time:104377ms step_avg:98.28ms +step:1063/1695 train_time:104477ms step_avg:98.28ms +step:1064/1695 train_time:104577ms step_avg:98.29ms +step:1065/1695 train_time:104677ms step_avg:98.29ms +step:1066/1695 train_time:104776ms step_avg:98.29ms +step:1067/1695 train_time:104877ms step_avg:98.29ms +step:1068/1695 train_time:104976ms step_avg:98.29ms +step:1069/1695 train_time:105077ms step_avg:98.29ms +step:1070/1695 train_time:105178ms step_avg:98.30ms +step:1071/1695 train_time:105278ms step_avg:98.30ms +step:1072/1695 train_time:105378ms step_avg:98.30ms +step:1073/1695 train_time:105478ms step_avg:98.30ms +step:1074/1695 train_time:105578ms step_avg:98.30ms +step:1075/1695 train_time:105677ms step_avg:98.30ms +step:1076/1695 train_time:105777ms step_avg:98.31ms +step:1077/1695 train_time:105877ms step_avg:98.31ms +step:1078/1695 train_time:105977ms step_avg:98.31ms +step:1079/1695 train_time:106079ms step_avg:98.31ms +step:1080/1695 train_time:106178ms step_avg:98.31ms +step:1081/1695 train_time:106278ms step_avg:98.31ms +step:1082/1695 train_time:106378ms step_avg:98.32ms +step:1083/1695 train_time:106478ms step_avg:98.32ms +step:1084/1695 train_time:106577ms step_avg:98.32ms +step:1085/1695 train_time:106677ms step_avg:98.32ms +step:1086/1695 train_time:106777ms step_avg:98.32ms +step:1087/1695 train_time:106877ms step_avg:98.32ms +step:1088/1695 train_time:106976ms step_avg:98.32ms +step:1089/1695 train_time:107077ms step_avg:98.33ms +step:1090/1695 train_time:107177ms step_avg:98.33ms +step:1091/1695 train_time:107277ms step_avg:98.33ms +step:1092/1695 train_time:107378ms step_avg:98.33ms +step:1093/1695 train_time:107477ms step_avg:98.33ms +step:1094/1695 train_time:107578ms step_avg:98.33ms +step:1095/1695 train_time:107678ms step_avg:98.34ms +step:1096/1695 train_time:107778ms step_avg:98.34ms +step:1097/1695 train_time:107877ms step_avg:98.34ms +step:1098/1695 train_time:107978ms step_avg:98.34ms +step:1099/1695 train_time:108077ms step_avg:98.34ms +step:1100/1695 train_time:108177ms step_avg:98.34ms +step:1101/1695 train_time:108277ms step_avg:98.34ms +step:1102/1695 train_time:108377ms step_avg:98.35ms +step:1103/1695 train_time:108477ms step_avg:98.35ms +step:1104/1695 train_time:108576ms step_avg:98.35ms +step:1105/1695 train_time:108677ms step_avg:98.35ms +step:1106/1695 train_time:108777ms step_avg:98.35ms +step:1107/1695 train_time:108877ms step_avg:98.35ms +step:1108/1695 train_time:108977ms step_avg:98.36ms +step:1109/1695 train_time:109077ms step_avg:98.36ms +step:1110/1695 train_time:109178ms step_avg:98.36ms +step:1111/1695 
train_time:109279ms step_avg:98.36ms +step:1112/1695 train_time:109379ms step_avg:98.36ms +step:1113/1695 train_time:109479ms step_avg:98.36ms +step:1114/1695 train_time:109579ms step_avg:98.37ms +step:1115/1695 train_time:109678ms step_avg:98.37ms +step:1116/1695 train_time:109778ms step_avg:98.37ms +step:1117/1695 train_time:109878ms step_avg:98.37ms +step:1118/1695 train_time:109977ms step_avg:98.37ms +step:1119/1695 train_time:110077ms step_avg:98.37ms +step:1120/1695 train_time:110177ms step_avg:98.37ms +step:1121/1695 train_time:110277ms step_avg:98.37ms +step:1122/1695 train_time:110378ms step_avg:98.38ms +step:1123/1695 train_time:110479ms step_avg:98.38ms +step:1124/1695 train_time:110578ms step_avg:98.38ms +step:1125/1695 train_time:110679ms step_avg:98.38ms +step:1125/1695 val_loss:3.4440 train_time:110776ms step_avg:98.47ms +step:1126/1695 train_time:110802ms step_avg:98.40ms +step:1127/1695 train_time:110888ms step_avg:98.39ms +step:1128/1695 train_time:110993ms step_avg:98.40ms +step:1129/1695 train_time:111094ms step_avg:98.40ms +step:1130/1695 train_time:111193ms step_avg:98.40ms +step:1131/1695 train_time:111293ms step_avg:98.40ms +step:1132/1695 train_time:111392ms step_avg:98.40ms +step:1133/1695 train_time:111493ms step_avg:98.40ms +step:1134/1695 train_time:111593ms step_avg:98.41ms +step:1135/1695 train_time:111693ms step_avg:98.41ms +step:1136/1695 train_time:111793ms step_avg:98.41ms +step:1137/1695 train_time:111896ms step_avg:98.41ms +step:1138/1695 train_time:111998ms step_avg:98.42ms +step:1139/1695 train_time:112099ms step_avg:98.42ms +step:1140/1695 train_time:112199ms step_avg:98.42ms +step:1141/1695 train_time:112299ms step_avg:98.42ms +step:1142/1695 train_time:112399ms step_avg:98.42ms +step:1143/1695 train_time:112499ms step_avg:98.42ms +step:1144/1695 train_time:112600ms step_avg:98.43ms +step:1145/1695 train_time:112700ms step_avg:98.43ms +step:1146/1695 train_time:112801ms step_avg:98.43ms +step:1147/1695 train_time:112902ms step_avg:98.43ms +step:1148/1695 train_time:113004ms step_avg:98.44ms +step:1149/1695 train_time:113105ms step_avg:98.44ms +step:1150/1695 train_time:113205ms step_avg:98.44ms +step:1151/1695 train_time:113306ms step_avg:98.44ms +step:1152/1695 train_time:113407ms step_avg:98.44ms +step:1153/1695 train_time:113508ms step_avg:98.45ms +step:1154/1695 train_time:113609ms step_avg:98.45ms +step:1155/1695 train_time:113710ms step_avg:98.45ms +step:1156/1695 train_time:113811ms step_avg:98.45ms +step:1157/1695 train_time:113914ms step_avg:98.46ms +step:1158/1695 train_time:114015ms step_avg:98.46ms +step:1159/1695 train_time:114116ms step_avg:98.46ms +step:1160/1695 train_time:114216ms step_avg:98.46ms +step:1161/1695 train_time:114317ms step_avg:98.46ms +step:1162/1695 train_time:114416ms step_avg:98.47ms +step:1163/1695 train_time:114519ms step_avg:98.47ms +step:1164/1695 train_time:114619ms step_avg:98.47ms +step:1165/1695 train_time:114719ms step_avg:98.47ms +step:1166/1695 train_time:114819ms step_avg:98.47ms +step:1167/1695 train_time:114919ms step_avg:98.47ms +step:1168/1695 train_time:115020ms step_avg:98.48ms +step:1169/1695 train_time:115120ms step_avg:98.48ms +step:1170/1695 train_time:115220ms step_avg:98.48ms +step:1171/1695 train_time:115321ms step_avg:98.48ms +step:1172/1695 train_time:115424ms step_avg:98.48ms +step:1173/1695 train_time:115525ms step_avg:98.49ms +step:1174/1695 train_time:115626ms step_avg:98.49ms +step:1175/1695 train_time:115727ms step_avg:98.49ms +step:1176/1695 train_time:115828ms step_avg:98.49ms 
+step:1177/1695 train_time:115930ms step_avg:98.50ms +step:1178/1695 train_time:116031ms step_avg:98.50ms +step:1179/1695 train_time:116135ms step_avg:98.50ms +step:1180/1695 train_time:116235ms step_avg:98.50ms +step:1181/1695 train_time:116335ms step_avg:98.51ms +step:1182/1695 train_time:116437ms step_avg:98.51ms +step:1183/1695 train_time:116536ms step_avg:98.51ms +step:1184/1695 train_time:116637ms step_avg:98.51ms +step:1185/1695 train_time:116738ms step_avg:98.51ms +step:1186/1695 train_time:116838ms step_avg:98.51ms +step:1187/1695 train_time:116938ms step_avg:98.52ms +step:1188/1695 train_time:117039ms step_avg:98.52ms +step:1189/1695 train_time:117140ms step_avg:98.52ms +step:1190/1695 train_time:117240ms step_avg:98.52ms +step:1191/1695 train_time:117342ms step_avg:98.52ms +step:1192/1695 train_time:117442ms step_avg:98.53ms +step:1193/1695 train_time:117543ms step_avg:98.53ms +step:1194/1695 train_time:117643ms step_avg:98.53ms +step:1195/1695 train_time:117744ms step_avg:98.53ms +step:1196/1695 train_time:117845ms step_avg:98.53ms +step:1197/1695 train_time:117946ms step_avg:98.53ms +step:1198/1695 train_time:118048ms step_avg:98.54ms +step:1199/1695 train_time:118150ms step_avg:98.54ms +step:1200/1695 train_time:118252ms step_avg:98.54ms +step:1201/1695 train_time:118354ms step_avg:98.55ms +step:1202/1695 train_time:118456ms step_avg:98.55ms +step:1203/1695 train_time:118556ms step_avg:98.55ms +step:1204/1695 train_time:118656ms step_avg:98.55ms +step:1205/1695 train_time:118756ms step_avg:98.55ms +step:1206/1695 train_time:118856ms step_avg:98.55ms +step:1207/1695 train_time:118957ms step_avg:98.56ms +step:1208/1695 train_time:119057ms step_avg:98.56ms +step:1209/1695 train_time:119158ms step_avg:98.56ms +step:1210/1695 train_time:119258ms step_avg:98.56ms +step:1211/1695 train_time:119359ms step_avg:98.56ms +step:1212/1695 train_time:119459ms step_avg:98.56ms +step:1213/1695 train_time:119560ms step_avg:98.57ms +step:1214/1695 train_time:119659ms step_avg:98.57ms +step:1215/1695 train_time:119760ms step_avg:98.57ms +step:1216/1695 train_time:119862ms step_avg:98.57ms +step:1217/1695 train_time:119962ms step_avg:98.57ms +step:1218/1695 train_time:120063ms step_avg:98.57ms +step:1219/1695 train_time:120163ms step_avg:98.58ms +step:1220/1695 train_time:120265ms step_avg:98.58ms +step:1221/1695 train_time:120366ms step_avg:98.58ms +step:1222/1695 train_time:120467ms step_avg:98.58ms +step:1223/1695 train_time:120570ms step_avg:98.59ms +step:1224/1695 train_time:120671ms step_avg:98.59ms +step:1225/1695 train_time:120773ms step_avg:98.59ms +step:1226/1695 train_time:120874ms step_avg:98.59ms +step:1227/1695 train_time:120975ms step_avg:98.59ms +step:1228/1695 train_time:121076ms step_avg:98.60ms +step:1229/1695 train_time:121176ms step_avg:98.60ms +step:1230/1695 train_time:121277ms step_avg:98.60ms +step:1231/1695 train_time:121377ms step_avg:98.60ms +step:1232/1695 train_time:121478ms step_avg:98.60ms +step:1233/1695 train_time:121578ms step_avg:98.60ms +step:1234/1695 train_time:121680ms step_avg:98.61ms +step:1235/1695 train_time:121779ms step_avg:98.61ms +step:1236/1695 train_time:121880ms step_avg:98.61ms +step:1237/1695 train_time:121980ms step_avg:98.61ms +step:1238/1695 train_time:122080ms step_avg:98.61ms +step:1239/1695 train_time:122180ms step_avg:98.61ms +step:1240/1695 train_time:122282ms step_avg:98.61ms +step:1241/1695 train_time:122382ms step_avg:98.62ms +step:1242/1695 train_time:122483ms step_avg:98.62ms +step:1243/1695 train_time:122584ms step_avg:98.62ms 
+step:1244/1695 train_time:122685ms step_avg:98.62ms +step:1245/1695 train_time:122786ms step_avg:98.62ms +step:1246/1695 train_time:122888ms step_avg:98.63ms +step:1247/1695 train_time:122988ms step_avg:98.63ms +step:1248/1695 train_time:123090ms step_avg:98.63ms +step:1249/1695 train_time:123192ms step_avg:98.63ms +step:1250/1695 train_time:123293ms step_avg:98.63ms +step:1250/1695 val_loss:3.3986 train_time:123392ms step_avg:98.71ms +step:1251/1695 train_time:123417ms step_avg:98.65ms +step:1252/1695 train_time:123507ms step_avg:98.65ms +step:1253/1695 train_time:123608ms step_avg:98.65ms +step:1254/1695 train_time:123710ms step_avg:98.65ms +step:1255/1695 train_time:123811ms step_avg:98.65ms +step:1256/1695 train_time:123911ms step_avg:98.66ms +step:1257/1695 train_time:124011ms step_avg:98.66ms +step:1258/1695 train_time:124112ms step_avg:98.66ms +step:1259/1695 train_time:124212ms step_avg:98.66ms +step:1260/1695 train_time:124313ms step_avg:98.66ms +step:1261/1695 train_time:124417ms step_avg:98.67ms +step:1262/1695 train_time:124520ms step_avg:98.67ms +step:1263/1695 train_time:124621ms step_avg:98.67ms +step:1264/1695 train_time:124720ms step_avg:98.67ms +step:1265/1695 train_time:124821ms step_avg:98.67ms +step:1266/1695 train_time:124921ms step_avg:98.67ms +step:1267/1695 train_time:125021ms step_avg:98.67ms +step:1268/1695 train_time:125121ms step_avg:98.68ms +step:1269/1695 train_time:125221ms step_avg:98.68ms +step:1270/1695 train_time:125321ms step_avg:98.68ms +step:1271/1695 train_time:125422ms step_avg:98.68ms +step:1272/1695 train_time:125523ms step_avg:98.68ms +step:1273/1695 train_time:125625ms step_avg:98.68ms +step:1274/1695 train_time:125727ms step_avg:98.69ms +step:1275/1695 train_time:125828ms step_avg:98.69ms +step:1276/1695 train_time:125929ms step_avg:98.69ms +step:1277/1695 train_time:126030ms step_avg:98.69ms +step:1278/1695 train_time:126130ms step_avg:98.69ms +step:1279/1695 train_time:126232ms step_avg:98.70ms +step:1280/1695 train_time:126334ms step_avg:98.70ms +step:1281/1695 train_time:126438ms step_avg:98.70ms +step:1282/1695 train_time:126539ms step_avg:98.70ms +step:1283/1695 train_time:126640ms step_avg:98.71ms +step:1284/1695 train_time:126740ms step_avg:98.71ms +step:1285/1695 train_time:126840ms step_avg:98.71ms +step:1286/1695 train_time:126939ms step_avg:98.71ms +step:1287/1695 train_time:127040ms step_avg:98.71ms +step:1288/1695 train_time:127140ms step_avg:98.71ms +step:1289/1695 train_time:127242ms step_avg:98.71ms +step:1290/1695 train_time:127342ms step_avg:98.71ms +step:1291/1695 train_time:127443ms step_avg:98.72ms +step:1292/1695 train_time:127543ms step_avg:98.72ms +step:1293/1695 train_time:127644ms step_avg:98.72ms +step:1294/1695 train_time:127745ms step_avg:98.72ms +step:1295/1695 train_time:127846ms step_avg:98.72ms +step:1296/1695 train_time:127947ms step_avg:98.72ms +step:1297/1695 train_time:128047ms step_avg:98.73ms +step:1298/1695 train_time:128148ms step_avg:98.73ms +step:1299/1695 train_time:128249ms step_avg:98.73ms +step:1300/1695 train_time:128349ms step_avg:98.73ms +step:1301/1695 train_time:128450ms step_avg:98.73ms +step:1302/1695 train_time:128552ms step_avg:98.73ms +step:1303/1695 train_time:128654ms step_avg:98.74ms +step:1304/1695 train_time:128756ms step_avg:98.74ms +step:1305/1695 train_time:128857ms step_avg:98.74ms +step:1306/1695 train_time:128958ms step_avg:98.74ms +step:1307/1695 train_time:129058ms step_avg:98.74ms +step:1308/1695 train_time:129158ms step_avg:98.74ms +step:1309/1695 train_time:129260ms 
step_avg:98.75ms +step:1310/1695 train_time:129360ms step_avg:98.75ms +step:1311/1695 train_time:129461ms step_avg:98.75ms +step:1312/1695 train_time:129561ms step_avg:98.75ms +step:1313/1695 train_time:129662ms step_avg:98.75ms +step:1314/1695 train_time:129762ms step_avg:98.75ms +step:1315/1695 train_time:129863ms step_avg:98.76ms +step:1316/1695 train_time:129963ms step_avg:98.76ms +step:1317/1695 train_time:130064ms step_avg:98.76ms +step:1318/1695 train_time:130164ms step_avg:98.76ms +step:1319/1695 train_time:130265ms step_avg:98.76ms +step:1320/1695 train_time:130366ms step_avg:98.76ms +step:1321/1695 train_time:130468ms step_avg:98.76ms +step:1322/1695 train_time:130568ms step_avg:98.77ms +step:1323/1695 train_time:130669ms step_avg:98.77ms +step:1324/1695 train_time:130771ms step_avg:98.77ms +step:1325/1695 train_time:130872ms step_avg:98.77ms +step:1326/1695 train_time:130973ms step_avg:98.77ms +step:1327/1695 train_time:131076ms step_avg:98.78ms +step:1328/1695 train_time:131177ms step_avg:98.78ms +step:1329/1695 train_time:131278ms step_avg:98.78ms +step:1330/1695 train_time:131378ms step_avg:98.78ms +step:1331/1695 train_time:131479ms step_avg:98.78ms +step:1332/1695 train_time:131579ms step_avg:98.78ms +step:1333/1695 train_time:131681ms step_avg:98.79ms +step:1334/1695 train_time:131780ms step_avg:98.79ms +step:1335/1695 train_time:131880ms step_avg:98.79ms +step:1336/1695 train_time:131982ms step_avg:98.79ms +step:1337/1695 train_time:132082ms step_avg:98.79ms +step:1338/1695 train_time:132182ms step_avg:98.79ms +step:1339/1695 train_time:132282ms step_avg:98.79ms +step:1340/1695 train_time:132383ms step_avg:98.79ms +step:1341/1695 train_time:132485ms step_avg:98.80ms +step:1342/1695 train_time:132585ms step_avg:98.80ms +step:1343/1695 train_time:132684ms step_avg:98.80ms +step:1344/1695 train_time:132784ms step_avg:98.80ms +step:1345/1695 train_time:132884ms step_avg:98.80ms +step:1346/1695 train_time:132986ms step_avg:98.80ms +step:1347/1695 train_time:133087ms step_avg:98.80ms +step:1348/1695 train_time:133188ms step_avg:98.80ms +step:1349/1695 train_time:133288ms step_avg:98.81ms +step:1350/1695 train_time:133390ms step_avg:98.81ms +step:1351/1695 train_time:133491ms step_avg:98.81ms +step:1352/1695 train_time:133593ms step_avg:98.81ms +step:1353/1695 train_time:133695ms step_avg:98.81ms +step:1354/1695 train_time:133796ms step_avg:98.82ms +step:1355/1695 train_time:133897ms step_avg:98.82ms +step:1356/1695 train_time:133998ms step_avg:98.82ms +step:1357/1695 train_time:134098ms step_avg:98.82ms +step:1358/1695 train_time:134200ms step_avg:98.82ms +step:1359/1695 train_time:134301ms step_avg:98.82ms +step:1360/1695 train_time:134402ms step_avg:98.83ms +step:1361/1695 train_time:134502ms step_avg:98.83ms +step:1362/1695 train_time:134601ms step_avg:98.83ms +step:1363/1695 train_time:134702ms step_avg:98.83ms +step:1364/1695 train_time:134803ms step_avg:98.83ms +step:1365/1695 train_time:134905ms step_avg:98.83ms +step:1366/1695 train_time:135006ms step_avg:98.83ms +step:1367/1695 train_time:135108ms step_avg:98.84ms +step:1368/1695 train_time:135209ms step_avg:98.84ms +step:1369/1695 train_time:135310ms step_avg:98.84ms +step:1370/1695 train_time:135411ms step_avg:98.84ms +step:1371/1695 train_time:135513ms step_avg:98.84ms +step:1372/1695 train_time:135616ms step_avg:98.85ms +step:1373/1695 train_time:135717ms step_avg:98.85ms +step:1374/1695 train_time:135819ms step_avg:98.85ms +step:1375/1695 train_time:135919ms step_avg:98.85ms +step:1375/1695 val_loss:3.3581 
train_time:136018ms step_avg:98.92ms +step:1376/1695 train_time:136044ms step_avg:98.87ms +step:1377/1695 train_time:136129ms step_avg:98.86ms +step:1378/1695 train_time:136230ms step_avg:98.86ms +step:1379/1695 train_time:136332ms step_avg:98.86ms +step:1380/1695 train_time:136435ms step_avg:98.87ms +step:1381/1695 train_time:136535ms step_avg:98.87ms +step:1382/1695 train_time:136634ms step_avg:98.87ms +step:1383/1695 train_time:136734ms step_avg:98.87ms +step:1384/1695 train_time:136834ms step_avg:98.87ms +step:1385/1695 train_time:136936ms step_avg:98.87ms +step:1386/1695 train_time:137039ms step_avg:98.87ms +step:1387/1695 train_time:137141ms step_avg:98.88ms +step:1388/1695 train_time:137242ms step_avg:98.88ms +step:1389/1695 train_time:137344ms step_avg:98.88ms +step:1390/1695 train_time:137446ms step_avg:98.88ms +step:1391/1695 train_time:137548ms step_avg:98.88ms +step:1392/1695 train_time:137650ms step_avg:98.89ms +step:1393/1695 train_time:137753ms step_avg:98.89ms +step:1394/1695 train_time:137855ms step_avg:98.89ms +step:1395/1695 train_time:137956ms step_avg:98.89ms +step:1396/1695 train_time:138059ms step_avg:98.90ms +step:1397/1695 train_time:138162ms step_avg:98.90ms +step:1398/1695 train_time:138264ms step_avg:98.90ms +step:1399/1695 train_time:138366ms step_avg:98.90ms +step:1400/1695 train_time:138469ms step_avg:98.91ms +step:1401/1695 train_time:138570ms step_avg:98.91ms +step:1402/1695 train_time:138673ms step_avg:98.91ms +step:1403/1695 train_time:138775ms step_avg:98.91ms +step:1404/1695 train_time:138877ms step_avg:98.92ms +step:1405/1695 train_time:138979ms step_avg:98.92ms +step:1406/1695 train_time:139082ms step_avg:98.92ms +step:1407/1695 train_time:139183ms step_avg:98.92ms +step:1408/1695 train_time:139283ms step_avg:98.92ms +step:1409/1695 train_time:139387ms step_avg:98.93ms +step:1410/1695 train_time:139488ms step_avg:98.93ms +step:1411/1695 train_time:139589ms step_avg:98.93ms +step:1412/1695 train_time:139694ms step_avg:98.93ms +step:1413/1695 train_time:139796ms step_avg:98.94ms +step:1414/1695 train_time:139897ms step_avg:98.94ms +step:1415/1695 train_time:139999ms step_avg:98.94ms +step:1416/1695 train_time:140100ms step_avg:98.94ms +step:1417/1695 train_time:140200ms step_avg:98.94ms +step:1418/1695 train_time:140301ms step_avg:98.94ms +step:1419/1695 train_time:140403ms step_avg:98.94ms +step:1420/1695 train_time:140504ms step_avg:98.95ms +step:1421/1695 train_time:140606ms step_avg:98.95ms +step:1422/1695 train_time:140708ms step_avg:98.95ms +step:1423/1695 train_time:140812ms step_avg:98.95ms +step:1424/1695 train_time:140914ms step_avg:98.96ms +step:1425/1695 train_time:141015ms step_avg:98.96ms +step:1426/1695 train_time:141119ms step_avg:98.96ms +step:1427/1695 train_time:141220ms step_avg:98.96ms +step:1428/1695 train_time:141322ms step_avg:98.97ms +step:1429/1695 train_time:141424ms step_avg:98.97ms +step:1430/1695 train_time:141525ms step_avg:98.97ms +step:1431/1695 train_time:141627ms step_avg:98.97ms +step:1432/1695 train_time:141728ms step_avg:98.97ms +step:1433/1695 train_time:141831ms step_avg:98.98ms +step:1434/1695 train_time:141932ms step_avg:98.98ms +step:1435/1695 train_time:142035ms step_avg:98.98ms +step:1436/1695 train_time:142137ms step_avg:98.98ms +step:1437/1695 train_time:142239ms step_avg:98.98ms +step:1438/1695 train_time:142339ms step_avg:98.98ms +step:1439/1695 train_time:142442ms step_avg:98.99ms +step:1440/1695 train_time:142546ms step_avg:98.99ms +step:1441/1695 train_time:142648ms step_avg:98.99ms +step:1442/1695 
train_time:142749ms step_avg:98.99ms +step:1443/1695 train_time:142850ms step_avg:99.00ms +step:1444/1695 train_time:142952ms step_avg:99.00ms +step:1445/1695 train_time:143054ms step_avg:99.00ms +step:1446/1695 train_time:143155ms step_avg:99.00ms +step:1447/1695 train_time:143256ms step_avg:99.00ms +step:1448/1695 train_time:143361ms step_avg:99.01ms +step:1449/1695 train_time:143461ms step_avg:99.01ms +step:1450/1695 train_time:143563ms step_avg:99.01ms +step:1451/1695 train_time:143665ms step_avg:99.01ms +step:1452/1695 train_time:143768ms step_avg:99.01ms +step:1453/1695 train_time:143869ms step_avg:99.02ms +step:1454/1695 train_time:143972ms step_avg:99.02ms +step:1455/1695 train_time:144075ms step_avg:99.02ms +step:1456/1695 train_time:144177ms step_avg:99.02ms +step:1457/1695 train_time:144279ms step_avg:99.02ms +step:1458/1695 train_time:144382ms step_avg:99.03ms +step:1459/1695 train_time:144484ms step_avg:99.03ms +step:1460/1695 train_time:144585ms step_avg:99.03ms +step:1461/1695 train_time:144688ms step_avg:99.03ms +step:1462/1695 train_time:144788ms step_avg:99.03ms +step:1463/1695 train_time:144890ms step_avg:99.04ms +step:1464/1695 train_time:144992ms step_avg:99.04ms +step:1465/1695 train_time:145093ms step_avg:99.04ms +step:1466/1695 train_time:145195ms step_avg:99.04ms +step:1467/1695 train_time:145296ms step_avg:99.04ms +step:1468/1695 train_time:145399ms step_avg:99.05ms +step:1469/1695 train_time:145501ms step_avg:99.05ms +step:1470/1695 train_time:145601ms step_avg:99.05ms +step:1471/1695 train_time:145704ms step_avg:99.05ms +step:1472/1695 train_time:145807ms step_avg:99.05ms +step:1473/1695 train_time:145908ms step_avg:99.05ms +step:1474/1695 train_time:146009ms step_avg:99.06ms +step:1475/1695 train_time:146111ms step_avg:99.06ms +step:1476/1695 train_time:146214ms step_avg:99.06ms +step:1477/1695 train_time:146316ms step_avg:99.06ms +step:1478/1695 train_time:146418ms step_avg:99.06ms +step:1479/1695 train_time:146518ms step_avg:99.07ms +step:1480/1695 train_time:146620ms step_avg:99.07ms +step:1481/1695 train_time:146721ms step_avg:99.07ms +step:1482/1695 train_time:146823ms step_avg:99.07ms +step:1483/1695 train_time:146926ms step_avg:99.07ms +step:1484/1695 train_time:147028ms step_avg:99.08ms +step:1485/1695 train_time:147129ms step_avg:99.08ms +step:1486/1695 train_time:147231ms step_avg:99.08ms +step:1487/1695 train_time:147332ms step_avg:99.08ms +step:1488/1695 train_time:147435ms step_avg:99.08ms +step:1489/1695 train_time:147538ms step_avg:99.09ms +step:1490/1695 train_time:147640ms step_avg:99.09ms +step:1491/1695 train_time:147741ms step_avg:99.09ms +step:1492/1695 train_time:147842ms step_avg:99.09ms +step:1493/1695 train_time:147944ms step_avg:99.09ms +step:1494/1695 train_time:148046ms step_avg:99.09ms +step:1495/1695 train_time:148148ms step_avg:99.10ms +step:1496/1695 train_time:148250ms step_avg:99.10ms +step:1497/1695 train_time:148351ms step_avg:99.10ms +step:1498/1695 train_time:148454ms step_avg:99.10ms +step:1499/1695 train_time:148556ms step_avg:99.10ms +step:1500/1695 train_time:148658ms step_avg:99.11ms +step:1500/1695 val_loss:3.3231 train_time:148757ms step_avg:99.17ms +step:1501/1695 train_time:148784ms step_avg:99.12ms +step:1502/1695 train_time:148870ms step_avg:99.11ms +step:1503/1695 train_time:148971ms step_avg:99.12ms +step:1504/1695 train_time:149071ms step_avg:99.12ms +step:1505/1695 train_time:149172ms step_avg:99.12ms +step:1506/1695 train_time:149272ms step_avg:99.12ms +step:1507/1695 train_time:149372ms step_avg:99.12ms 
+step:1508/1695 train_time:149473ms step_avg:99.12ms +step:1509/1695 train_time:149575ms step_avg:99.12ms +step:1510/1695 train_time:149677ms step_avg:99.12ms +step:1511/1695 train_time:149781ms step_avg:99.13ms +step:1512/1695 train_time:149884ms step_avg:99.13ms +step:1513/1695 train_time:149986ms step_avg:99.13ms +step:1514/1695 train_time:150088ms step_avg:99.13ms +step:1515/1695 train_time:150193ms step_avg:99.14ms +step:1516/1695 train_time:150294ms step_avg:99.14ms +step:1517/1695 train_time:150394ms step_avg:99.14ms +step:1518/1695 train_time:150495ms step_avg:99.14ms +step:1519/1695 train_time:150598ms step_avg:99.14ms +step:1520/1695 train_time:150700ms step_avg:99.14ms +step:1521/1695 train_time:150802ms step_avg:99.15ms +step:1522/1695 train_time:150903ms step_avg:99.15ms +step:1523/1695 train_time:151006ms step_avg:99.15ms +step:1524/1695 train_time:151111ms step_avg:99.15ms +step:1525/1695 train_time:151213ms step_avg:99.16ms +step:1526/1695 train_time:151315ms step_avg:99.16ms +step:1527/1695 train_time:151416ms step_avg:99.16ms +step:1528/1695 train_time:151522ms step_avg:99.16ms +step:1529/1695 train_time:151624ms step_avg:99.17ms +step:1530/1695 train_time:151728ms step_avg:99.17ms +step:1531/1695 train_time:151829ms step_avg:99.17ms +step:1532/1695 train_time:151930ms step_avg:99.17ms +step:1533/1695 train_time:152032ms step_avg:99.17ms +step:1534/1695 train_time:152133ms step_avg:99.17ms +step:1535/1695 train_time:152236ms step_avg:99.18ms +step:1536/1695 train_time:152336ms step_avg:99.18ms +step:1537/1695 train_time:152438ms step_avg:99.18ms +step:1538/1695 train_time:152540ms step_avg:99.18ms +step:1539/1695 train_time:152641ms step_avg:99.18ms +step:1540/1695 train_time:152745ms step_avg:99.18ms +step:1541/1695 train_time:152848ms step_avg:99.19ms +step:1542/1695 train_time:152953ms step_avg:99.19ms +step:1543/1695 train_time:153056ms step_avg:99.19ms +step:1544/1695 train_time:153159ms step_avg:99.20ms +step:1545/1695 train_time:153260ms step_avg:99.20ms +step:1546/1695 train_time:153361ms step_avg:99.20ms +step:1547/1695 train_time:153464ms step_avg:99.20ms +step:1548/1695 train_time:153566ms step_avg:99.20ms +step:1549/1695 train_time:153668ms step_avg:99.20ms +step:1550/1695 train_time:153769ms step_avg:99.21ms +step:1551/1695 train_time:153872ms step_avg:99.21ms +step:1552/1695 train_time:153974ms step_avg:99.21ms +step:1553/1695 train_time:154077ms step_avg:99.21ms +step:1554/1695 train_time:154177ms step_avg:99.21ms +step:1555/1695 train_time:154280ms step_avg:99.22ms +step:1556/1695 train_time:154381ms step_avg:99.22ms +step:1557/1695 train_time:154484ms step_avg:99.22ms +step:1558/1695 train_time:154586ms step_avg:99.22ms +step:1559/1695 train_time:154689ms step_avg:99.22ms +step:1560/1695 train_time:154791ms step_avg:99.22ms +step:1561/1695 train_time:154892ms step_avg:99.23ms +step:1562/1695 train_time:154995ms step_avg:99.23ms +step:1563/1695 train_time:155099ms step_avg:99.23ms +step:1564/1695 train_time:155201ms step_avg:99.23ms +step:1565/1695 train_time:155302ms step_avg:99.23ms +step:1566/1695 train_time:155403ms step_avg:99.24ms +step:1567/1695 train_time:155505ms step_avg:99.24ms +step:1568/1695 train_time:155607ms step_avg:99.24ms +step:1569/1695 train_time:155708ms step_avg:99.24ms +step:1570/1695 train_time:155811ms step_avg:99.24ms +step:1571/1695 train_time:155913ms step_avg:99.24ms +step:1572/1695 train_time:156014ms step_avg:99.25ms +step:1573/1695 train_time:156115ms step_avg:99.25ms +step:1574/1695 train_time:156216ms step_avg:99.25ms 
+step:1575/1695 train_time:156317ms step_avg:99.25ms +step:1576/1695 train_time:156420ms step_avg:99.25ms +step:1577/1695 train_time:156524ms step_avg:99.25ms +step:1578/1695 train_time:156625ms step_avg:99.26ms +step:1579/1695 train_time:156727ms step_avg:99.26ms +step:1580/1695 train_time:156829ms step_avg:99.26ms +step:1581/1695 train_time:156931ms step_avg:99.26ms +step:1582/1695 train_time:157032ms step_avg:99.26ms +step:1583/1695 train_time:157135ms step_avg:99.26ms +step:1584/1695 train_time:157238ms step_avg:99.27ms +step:1585/1695 train_time:157340ms step_avg:99.27ms +step:1586/1695 train_time:157443ms step_avg:99.27ms +step:1587/1695 train_time:157545ms step_avg:99.27ms +step:1588/1695 train_time:157646ms step_avg:99.27ms +step:1589/1695 train_time:157748ms step_avg:99.27ms +step:1590/1695 train_time:157850ms step_avg:99.28ms +step:1591/1695 train_time:157951ms step_avg:99.28ms +step:1592/1695 train_time:158053ms step_avg:99.28ms +step:1593/1695 train_time:158155ms step_avg:99.28ms +step:1594/1695 train_time:158259ms step_avg:99.28ms +step:1595/1695 train_time:158361ms step_avg:99.29ms +step:1596/1695 train_time:158462ms step_avg:99.29ms +step:1597/1695 train_time:158565ms step_avg:99.29ms +step:1598/1695 train_time:158668ms step_avg:99.29ms +step:1599/1695 train_time:158769ms step_avg:99.29ms +step:1600/1695 train_time:158872ms step_avg:99.29ms +step:1601/1695 train_time:158975ms step_avg:99.30ms +step:1602/1695 train_time:159077ms step_avg:99.30ms +step:1603/1695 train_time:159178ms step_avg:99.30ms +step:1604/1695 train_time:159279ms step_avg:99.30ms +step:1605/1695 train_time:159381ms step_avg:99.30ms +step:1606/1695 train_time:159483ms step_avg:99.30ms +step:1607/1695 train_time:159584ms step_avg:99.31ms +step:1608/1695 train_time:159685ms step_avg:99.31ms +step:1609/1695 train_time:159787ms step_avg:99.31ms +step:1610/1695 train_time:159890ms step_avg:99.31ms +step:1611/1695 train_time:159993ms step_avg:99.31ms +step:1612/1695 train_time:160095ms step_avg:99.31ms +step:1613/1695 train_time:160196ms step_avg:99.32ms +step:1614/1695 train_time:160297ms step_avg:99.32ms +step:1615/1695 train_time:160399ms step_avg:99.32ms +step:1616/1695 train_time:160500ms step_avg:99.32ms +step:1617/1695 train_time:160601ms step_avg:99.32ms +step:1618/1695 train_time:160703ms step_avg:99.32ms +step:1619/1695 train_time:160806ms step_avg:99.32ms +step:1620/1695 train_time:160909ms step_avg:99.33ms +step:1621/1695 train_time:161011ms step_avg:99.33ms +step:1622/1695 train_time:161112ms step_avg:99.33ms +step:1623/1695 train_time:161214ms step_avg:99.33ms +step:1624/1695 train_time:161316ms step_avg:99.33ms +step:1625/1695 train_time:161419ms step_avg:99.33ms +step:1625/1695 val_loss:3.2939 train_time:161518ms step_avg:99.40ms +step:1626/1695 train_time:161544ms step_avg:99.35ms +step:1627/1695 train_time:161636ms step_avg:99.35ms +step:1628/1695 train_time:161739ms step_avg:99.35ms +step:1629/1695 train_time:161842ms step_avg:99.35ms +step:1630/1695 train_time:161943ms step_avg:99.35ms +step:1631/1695 train_time:162045ms step_avg:99.35ms +step:1632/1695 train_time:162147ms step_avg:99.35ms +step:1633/1695 train_time:162248ms step_avg:99.36ms +step:1634/1695 train_time:162350ms step_avg:99.36ms +step:1635/1695 train_time:162452ms step_avg:99.36ms +step:1636/1695 train_time:162555ms step_avg:99.36ms +step:1637/1695 train_time:162658ms step_avg:99.36ms +step:1638/1695 train_time:162761ms step_avg:99.37ms +step:1639/1695 train_time:162864ms step_avg:99.37ms +step:1640/1695 train_time:162966ms 
step_avg:99.37ms +step:1641/1695 train_time:163069ms step_avg:99.37ms +step:1642/1695 train_time:163171ms step_avg:99.37ms +step:1643/1695 train_time:163273ms step_avg:99.38ms +step:1644/1695 train_time:163375ms step_avg:99.38ms +step:1645/1695 train_time:163478ms step_avg:99.38ms +step:1646/1695 train_time:163580ms step_avg:99.38ms +step:1647/1695 train_time:163684ms step_avg:99.38ms +step:1648/1695 train_time:163788ms step_avg:99.39ms +step:1649/1695 train_time:163890ms step_avg:99.39ms +step:1650/1695 train_time:163993ms step_avg:99.39ms +step:1651/1695 train_time:164094ms step_avg:99.39ms +step:1652/1695 train_time:164197ms step_avg:99.39ms +step:1653/1695 train_time:164299ms step_avg:99.39ms +step:1654/1695 train_time:164402ms step_avg:99.40ms +step:1655/1695 train_time:164504ms step_avg:99.40ms +step:1656/1695 train_time:164607ms step_avg:99.40ms +step:1657/1695 train_time:164710ms step_avg:99.40ms +step:1658/1695 train_time:164812ms step_avg:99.40ms +step:1659/1695 train_time:164918ms step_avg:99.41ms +step:1660/1695 train_time:165021ms step_avg:99.41ms +step:1661/1695 train_time:165125ms step_avg:99.41ms +step:1662/1695 train_time:165229ms step_avg:99.42ms +step:1663/1695 train_time:165332ms step_avg:99.42ms +step:1664/1695 train_time:165434ms step_avg:99.42ms +step:1665/1695 train_time:165539ms step_avg:99.42ms +step:1666/1695 train_time:165642ms step_avg:99.43ms +step:1667/1695 train_time:165744ms step_avg:99.43ms +step:1668/1695 train_time:165850ms step_avg:99.43ms +step:1669/1695 train_time:165954ms step_avg:99.43ms +step:1670/1695 train_time:166055ms step_avg:99.43ms +step:1671/1695 train_time:166158ms step_avg:99.44ms +step:1672/1695 train_time:166260ms step_avg:99.44ms +step:1673/1695 train_time:166363ms step_avg:99.44ms +step:1674/1695 train_time:166465ms step_avg:99.44ms +step:1675/1695 train_time:166570ms step_avg:99.45ms +step:1676/1695 train_time:166675ms step_avg:99.45ms +step:1677/1695 train_time:166776ms step_avg:99.45ms +step:1678/1695 train_time:166879ms step_avg:99.45ms +step:1679/1695 train_time:166984ms step_avg:99.45ms +step:1680/1695 train_time:167086ms step_avg:99.46ms +step:1681/1695 train_time:167189ms step_avg:99.46ms +step:1682/1695 train_time:167295ms step_avg:99.46ms +step:1683/1695 train_time:167396ms step_avg:99.46ms +step:1684/1695 train_time:167499ms step_avg:99.46ms +step:1685/1695 train_time:167603ms step_avg:99.47ms +step:1686/1695 train_time:167705ms step_avg:99.47ms +step:1687/1695 train_time:167808ms step_avg:99.47ms +step:1688/1695 train_time:167911ms step_avg:99.47ms +step:1689/1695 train_time:168013ms step_avg:99.47ms +step:1690/1695 train_time:168114ms step_avg:99.48ms +step:1691/1695 train_time:168216ms step_avg:99.48ms +step:1692/1695 train_time:168318ms step_avg:99.48ms +step:1693/1695 train_time:168420ms step_avg:99.48ms +step:1694/1695 train_time:168524ms step_avg:99.48ms +step:1695/1695 train_time:168628ms step_avg:99.49ms +step:1695/1695 val_loss:3.2810 train_time:168728ms step_avg:99.54ms +peak memory allocated: 34004 MiB reserved: 49660 MiB diff --git a/records/082325_SparseAttnGate/e8891a98-8bf2-43cc-bac5-728aa53482ce.txt b/records/082325_SparseAttnGate/e8891a98-8bf2-43cc-bac5-728aa53482ce.txt new file mode 100644 index 000000000..1eed97d0a --- /dev/null +++ b/records/082325_SparseAttnGate/e8891a98-8bf2-43cc-bac5-728aa53482ce.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
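+
+    Concretely, newton_schulz_triton above runs five fixed iterations of
+    X <- a*X + (b*A + c*(A @ A)) @ X with A = X @ X.T and coefficients
+    (a, b, c) = (3.4445, -4.7750, 2.0315), after dividing X by its Frobenius
+    norm (which bounds the spectral norm by 1); everything runs in bfloat16.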
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+        # sparse attention gate: one logit per head, computed from the first dim//args.dampen_factor channels; zero init starts every gate at sigmoid(0) = 0.5
+        self.dampen = CastedLinear(dim//args.dampen_factor, num_heads)
+        self.dampen.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask):
+        B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim
+        assert B == 1, "Must use batch size = 1 for FlexAttention"
+        # per-head gate in (0, 1); note: unrelated to args.dampen_factor, which is the channel-reduction factor
+        dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1)
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2)
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * dampen_factor # apply the per-head gate to the attention output
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
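+
+# The attention gate above (the SparseAttnGate of this record) in a nutshell:
+# self.dampen maps the first dim//args.dampen_factor channels of each token to one
+# logit per head, and its sigmoid rescales that head's output before the output
+# projection; zero-initialized weights start every gate at sigmoid(0) = 0.5.
+# Rough shape sketch (illustrative numbers only, assuming dim=768, dampen_factor=64,
+# num_heads=6, head_dim=128, so the gate input is x[..., :12]):
+#   gate = torch.sigmoid(x[:, :, :12] @ W_dampen.T)  # (B, T, 6), W_dampen: (6, 12)
+#   y = y * gate.unsqueeze(-1)                       # (B, T, 6, 128): per-head scaling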
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, sa_lambdas, block_mask)
+        x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 5) % dist.get_world_size()
+        self.scalars = nn.Parameter(torch.cat([
+            torch.ones(num_layers), # skip_weights
+            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
+            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
+            torch.ones(pad),
+        ]))
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
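+        # lr_mul / wd_mul are plain attributes stamped onto individual parameters;
+        # Muon and DistAdam read them via getattr(p, "lr_mul", 1.0) and getattr(p, "wd_mul", 1.0)
+        # to scale the effective learning rate and weight decay per parameter (default 1.0).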
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from a suggestion by @Grad62304977, following the Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
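+    # Note on build_bm above: BlockMask.from_kv_blocks takes, per 128-token query
+    # block, the counts and indices of kv blocks that are partially visible (where
+    # mask_mod must still be evaluated) and fully visible (computed without masking);
+    # clamping those counts against window_size_blocks is what enforces the sliding
+    # window, and halving the window yields the short mask of the (long, short) pair.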
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
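+# Shard layout implied by the reader above (the filename below is hypothetical):
+# a 256-word int32 header -- magic 20240520, version 1, token count -- followed by
+# the tokens as uint16 GPT-2 ids, so a shard holding N tokens is 1024 + 2*N bytes:
+#
+#   header = torch.from_file("fineweb_train_000001.bin", False, 256, dtype=torch.int32)
+#   assert os.path.getsize("fineweb_train_000001.bin") == 256 * 4 + 2 * int(header[2])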
+# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len
+        if pos + token_window + 1 >= len(tokens):
+            tokens = _load_data_shard(next(file_iter))
+            pos = 0
+        for _ in range(grad_accum_steps):
+            if align_to_bos:
+                batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window)
+                start_idx = batch_starts[rank]
+            else:
+                tokens_consumed = batch_size
+                start_idx = pos + rank * seq_len
+            buf = tokens[start_idx:][:seq_len + 1]
+            inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
+            targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
+            pos += tokens_consumed
+            token_window -= tokens_consumed
+            yield inputs, targets
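+# Usage sketch for the generator above (sizes are hypothetical; it needs the
+# torch.distributed process group set up below): every next() yields one
+# (inputs, targets) pair of seq_len tokens already on the GPU, with targets
+# shifted one token ahead, and the training loop below pulls grad_accum_steps
+# of these per optimizer step.
+#
+#   loader = distributed_data_generator("fineweb_train_*.bin", seq_len=1024, grad_accum_steps=1, align_to_bos=True)
+#   inputs, targets = next(loader)  # int32 inputs, int64 targets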
+# -----------------------------------------------------------------------------
+# int main
+
+
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x < 1
+    if x < 1 - args.cooldown_frac:
+        return 1.0
+    else:
+        w = (1 - x) / args.cooldown_frac
+        return w * 1.0 + (1 - w) * 0.1
+
+# attention window size schedule: linearly increase
+@lru_cache(1)
+def get_window_size_blocks_helper(window_size: int):
+    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+def get_window_size_blocks(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x <= 1
+    # Linearly increase the block-wise sliding window size over training 128 -> 1792
+    # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
+    window_size = next_multiple_of_n(1728 * x, n=128)
+    return get_window_size_blocks_helper(window_size)
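+# Worked example of the two schedules above: get_lr holds the multiplier at 1.0
+# for the first (1 - cooldown_frac) of training, then decays it linearly to 0.1.
+# For the window, at x = 0.5 we get next_multiple_of_n(864, n=128) = 896 tokens
+# = 7 blocks; at x = 1.0, 1728 rounds up to 1792 tokens = 14 blocks.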
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:36:07 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 312571 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 312572 C /usr/bin/python3 614MiB | +| 0 N/A N/A 312573 C /usr/bin/python3 614MiB | +| 0 N/A N/A 312574 C /usr/bin/python3 614MiB | +| 0 N/A N/A 312575 C /usr/bin/python3 614MiB | +| 0 N/A N/A 312576 C /usr/bin/python3 614MiB | +| 0 N/A N/A 312577 C /usr/bin/python3 614MiB | +| 0 N/A N/A 312578 C /usr/bin/python3 614MiB | +| 1 N/A N/A 312572 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 312573 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 312574 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 312575 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 312576 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 312577 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 312578 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:157ms step_avg:157.06ms +step:2/1695 train_time:183ms step_avg:91.36ms +step:3/1695 train_time:253ms step_avg:84.35ms +step:4/1695 train_time:344ms step_avg:86.09ms +step:5/1695 train_time:438ms step_avg:87.63ms +step:6/1695 train_time:531ms step_avg:88.54ms +step:7/1695 train_time:624ms step_avg:89.16ms +step:8/1695 
train_time:718ms step_avg:89.72ms +step:9/1695 train_time:810ms step_avg:90.03ms +step:10/1695 train_time:903ms step_avg:90.28ms +step:11/1695 train_time:996ms step_avg:90.53ms +step:12/1695 train_time:1090ms step_avg:90.82ms +step:13/1695 train_time:1185ms step_avg:91.13ms +step:14/1695 train_time:1280ms step_avg:91.40ms +step:15/1695 train_time:1375ms step_avg:91.65ms +step:16/1695 train_time:1469ms step_avg:91.81ms +step:17/1695 train_time:1562ms step_avg:91.88ms +step:18/1695 train_time:1655ms step_avg:91.96ms +step:19/1695 train_time:1748ms step_avg:92.00ms +step:20/1695 train_time:1841ms step_avg:92.03ms +step:21/1695 train_time:1934ms step_avg:92.12ms +step:22/1695 train_time:2029ms step_avg:92.25ms +step:23/1695 train_time:2123ms step_avg:92.29ms +step:24/1695 train_time:2217ms step_avg:92.36ms +step:25/1695 train_time:2311ms step_avg:92.45ms +step:26/1695 train_time:2406ms step_avg:92.52ms +step:27/1695 train_time:2500ms step_avg:92.59ms +step:28/1695 train_time:2595ms step_avg:92.69ms +step:29/1695 train_time:2689ms step_avg:92.72ms +step:30/1695 train_time:2782ms step_avg:92.72ms +step:31/1695 train_time:2875ms step_avg:92.74ms +step:32/1695 train_time:2969ms step_avg:92.77ms +step:33/1695 train_time:3062ms step_avg:92.80ms +step:34/1695 train_time:3156ms step_avg:92.81ms +step:35/1695 train_time:3250ms step_avg:92.86ms +step:36/1695 train_time:3343ms step_avg:92.87ms +step:37/1695 train_time:3437ms step_avg:92.89ms +step:38/1695 train_time:3531ms step_avg:92.93ms +step:39/1695 train_time:3625ms step_avg:92.96ms +step:40/1695 train_time:3719ms step_avg:92.97ms +step:41/1695 train_time:3812ms step_avg:92.98ms +step:42/1695 train_time:3906ms step_avg:92.99ms +step:43/1695 train_time:3999ms step_avg:93.00ms +step:44/1695 train_time:4093ms step_avg:93.03ms +step:45/1695 train_time:4187ms step_avg:93.04ms +step:46/1695 train_time:4280ms step_avg:93.05ms +step:47/1695 train_time:4374ms step_avg:93.06ms +step:48/1695 train_time:4468ms step_avg:93.08ms +step:49/1695 train_time:4562ms step_avg:93.10ms +step:50/1695 train_time:4656ms step_avg:93.12ms +step:51/1695 train_time:4750ms step_avg:93.13ms +step:52/1695 train_time:4844ms step_avg:93.15ms +step:53/1695 train_time:4937ms step_avg:93.15ms +step:54/1695 train_time:5030ms step_avg:93.15ms +step:55/1695 train_time:5123ms step_avg:93.15ms +step:56/1695 train_time:5217ms step_avg:93.16ms +step:57/1695 train_time:5311ms step_avg:93.17ms +step:58/1695 train_time:5404ms step_avg:93.18ms +step:59/1695 train_time:5498ms step_avg:93.19ms +step:60/1695 train_time:5592ms step_avg:93.20ms +step:61/1695 train_time:5686ms step_avg:93.22ms +step:62/1695 train_time:5779ms step_avg:93.22ms +step:63/1695 train_time:5874ms step_avg:93.23ms +step:64/1695 train_time:5969ms step_avg:93.26ms +step:65/1695 train_time:6062ms step_avg:93.26ms +step:66/1695 train_time:6155ms step_avg:93.26ms +step:67/1695 train_time:6248ms step_avg:93.25ms +step:68/1695 train_time:6342ms step_avg:93.26ms +step:69/1695 train_time:6435ms step_avg:93.27ms +step:70/1695 train_time:6529ms step_avg:93.27ms +step:71/1695 train_time:6623ms step_avg:93.28ms +step:72/1695 train_time:6717ms step_avg:93.30ms +step:73/1695 train_time:6811ms step_avg:93.31ms +step:74/1695 train_time:6906ms step_avg:93.32ms +step:75/1695 train_time:6999ms step_avg:93.32ms +step:76/1695 train_time:7093ms step_avg:93.33ms +step:77/1695 train_time:7186ms step_avg:93.33ms +step:78/1695 train_time:7280ms step_avg:93.33ms +step:79/1695 train_time:7374ms step_avg:93.34ms +step:80/1695 train_time:7468ms 
step_avg:93.35ms +step:81/1695 train_time:7561ms step_avg:93.35ms +step:82/1695 train_time:7656ms step_avg:93.36ms +step:83/1695 train_time:7749ms step_avg:93.36ms +step:84/1695 train_time:7843ms step_avg:93.37ms +step:85/1695 train_time:7937ms step_avg:93.38ms +step:86/1695 train_time:8031ms step_avg:93.38ms +step:87/1695 train_time:8123ms step_avg:93.37ms +step:88/1695 train_time:8217ms step_avg:93.37ms +step:89/1695 train_time:8310ms step_avg:93.37ms +step:90/1695 train_time:8403ms step_avg:93.37ms +step:91/1695 train_time:8497ms step_avg:93.37ms +step:92/1695 train_time:8591ms step_avg:93.38ms +step:93/1695 train_time:8685ms step_avg:93.38ms +step:94/1695 train_time:8778ms step_avg:93.39ms +step:95/1695 train_time:8872ms step_avg:93.39ms +step:96/1695 train_time:8968ms step_avg:93.41ms +step:97/1695 train_time:9060ms step_avg:93.40ms +step:98/1695 train_time:9153ms step_avg:93.40ms +step:99/1695 train_time:9247ms step_avg:93.40ms +step:100/1695 train_time:9340ms step_avg:93.40ms +step:101/1695 train_time:9433ms step_avg:93.40ms +step:102/1695 train_time:9526ms step_avg:93.40ms +step:103/1695 train_time:9619ms step_avg:93.39ms +step:104/1695 train_time:9713ms step_avg:93.40ms +step:105/1695 train_time:9808ms step_avg:93.41ms +step:106/1695 train_time:9901ms step_avg:93.41ms +step:107/1695 train_time:9994ms step_avg:93.41ms +step:108/1695 train_time:10088ms step_avg:93.41ms +step:109/1695 train_time:10181ms step_avg:93.40ms +step:110/1695 train_time:10275ms step_avg:93.41ms +step:111/1695 train_time:10369ms step_avg:93.41ms +step:112/1695 train_time:10462ms step_avg:93.41ms +step:113/1695 train_time:10556ms step_avg:93.41ms +step:114/1695 train_time:10649ms step_avg:93.41ms +step:115/1695 train_time:10742ms step_avg:93.41ms +step:116/1695 train_time:10836ms step_avg:93.41ms +step:117/1695 train_time:10930ms step_avg:93.42ms +step:118/1695 train_time:11024ms step_avg:93.42ms +step:119/1695 train_time:11117ms step_avg:93.42ms +step:120/1695 train_time:11211ms step_avg:93.43ms +step:121/1695 train_time:11305ms step_avg:93.43ms +step:122/1695 train_time:11398ms step_avg:93.43ms +step:123/1695 train_time:11492ms step_avg:93.43ms +step:124/1695 train_time:11586ms step_avg:93.43ms +step:125/1695 train_time:11679ms step_avg:93.43ms +step:125/1695 val_loss:4.6043 train_time:11771ms step_avg:94.17ms +step:126/1695 train_time:11798ms step_avg:93.64ms +step:127/1695 train_time:11873ms step_avg:93.49ms +step:128/1695 train_time:11977ms step_avg:93.57ms +step:129/1695 train_time:12072ms step_avg:93.58ms +step:130/1695 train_time:12166ms step_avg:93.58ms +step:131/1695 train_time:12260ms step_avg:93.59ms +step:132/1695 train_time:12353ms step_avg:93.58ms +step:133/1695 train_time:12447ms step_avg:93.58ms +step:134/1695 train_time:12541ms step_avg:93.59ms +step:135/1695 train_time:12634ms step_avg:93.59ms +step:136/1695 train_time:12727ms step_avg:93.58ms +step:137/1695 train_time:12821ms step_avg:93.59ms +step:138/1695 train_time:12917ms step_avg:93.60ms +step:139/1695 train_time:13011ms step_avg:93.61ms +step:140/1695 train_time:13105ms step_avg:93.61ms +step:141/1695 train_time:13200ms step_avg:93.62ms +step:142/1695 train_time:13294ms step_avg:93.62ms +step:143/1695 train_time:13387ms step_avg:93.62ms +step:144/1695 train_time:13481ms step_avg:93.62ms +step:145/1695 train_time:13575ms step_avg:93.62ms +step:146/1695 train_time:13668ms step_avg:93.62ms +step:147/1695 train_time:13762ms step_avg:93.62ms +step:148/1695 train_time:13856ms step_avg:93.62ms +step:149/1695 train_time:13950ms 
step_avg:93.63ms +step:150/1695 train_time:14045ms step_avg:93.63ms +step:151/1695 train_time:14140ms step_avg:93.64ms +step:152/1695 train_time:14234ms step_avg:93.65ms +step:153/1695 train_time:14328ms step_avg:93.65ms +step:154/1695 train_time:14422ms step_avg:93.65ms +step:155/1695 train_time:14516ms step_avg:93.65ms +step:156/1695 train_time:14611ms step_avg:93.66ms +step:157/1695 train_time:14705ms step_avg:93.66ms +step:158/1695 train_time:14798ms step_avg:93.66ms +step:159/1695 train_time:14892ms step_avg:93.66ms +step:160/1695 train_time:14986ms step_avg:93.66ms +step:161/1695 train_time:15081ms step_avg:93.67ms +step:162/1695 train_time:15175ms step_avg:93.67ms +step:163/1695 train_time:15270ms step_avg:93.68ms +step:164/1695 train_time:15363ms step_avg:93.68ms +step:165/1695 train_time:15457ms step_avg:93.68ms +step:166/1695 train_time:15552ms step_avg:93.69ms +step:167/1695 train_time:15645ms step_avg:93.69ms +step:168/1695 train_time:15739ms step_avg:93.69ms +step:169/1695 train_time:15834ms step_avg:93.69ms +step:170/1695 train_time:15928ms step_avg:93.69ms +step:171/1695 train_time:16021ms step_avg:93.69ms +step:172/1695 train_time:16116ms step_avg:93.70ms +step:173/1695 train_time:16209ms step_avg:93.69ms +step:174/1695 train_time:16303ms step_avg:93.70ms +step:175/1695 train_time:16398ms step_avg:93.70ms +step:176/1695 train_time:16492ms step_avg:93.70ms +step:177/1695 train_time:16586ms step_avg:93.70ms +step:178/1695 train_time:16680ms step_avg:93.71ms +step:179/1695 train_time:16774ms step_avg:93.71ms +step:180/1695 train_time:16868ms step_avg:93.71ms +step:181/1695 train_time:16962ms step_avg:93.71ms +step:182/1695 train_time:17057ms step_avg:93.72ms +step:183/1695 train_time:17151ms step_avg:93.72ms +step:184/1695 train_time:17244ms step_avg:93.72ms +step:185/1695 train_time:17338ms step_avg:93.72ms +step:186/1695 train_time:17433ms step_avg:93.73ms +step:187/1695 train_time:17526ms step_avg:93.72ms +step:188/1695 train_time:17621ms step_avg:93.73ms +step:189/1695 train_time:17715ms step_avg:93.73ms +step:190/1695 train_time:17809ms step_avg:93.73ms +step:191/1695 train_time:17903ms step_avg:93.73ms +step:192/1695 train_time:17996ms step_avg:93.73ms +step:193/1695 train_time:18090ms step_avg:93.73ms +step:194/1695 train_time:18183ms step_avg:93.73ms +step:195/1695 train_time:18278ms step_avg:93.73ms +step:196/1695 train_time:18372ms step_avg:93.74ms +step:197/1695 train_time:18466ms step_avg:93.74ms +step:198/1695 train_time:18560ms step_avg:93.74ms +step:199/1695 train_time:18654ms step_avg:93.74ms +step:200/1695 train_time:18747ms step_avg:93.73ms +step:201/1695 train_time:18840ms step_avg:93.73ms +step:202/1695 train_time:18934ms step_avg:93.73ms +step:203/1695 train_time:19028ms step_avg:93.73ms +step:204/1695 train_time:19121ms step_avg:93.73ms +step:205/1695 train_time:19216ms step_avg:93.74ms +step:206/1695 train_time:19311ms step_avg:93.74ms +step:207/1695 train_time:19404ms step_avg:93.74ms +step:208/1695 train_time:19499ms step_avg:93.74ms +step:209/1695 train_time:19593ms step_avg:93.75ms +step:210/1695 train_time:19687ms step_avg:93.75ms +step:211/1695 train_time:19781ms step_avg:93.75ms +step:212/1695 train_time:19877ms step_avg:93.76ms +step:213/1695 train_time:19971ms step_avg:93.76ms +step:214/1695 train_time:20065ms step_avg:93.76ms +step:215/1695 train_time:20159ms step_avg:93.76ms +step:216/1695 train_time:20253ms step_avg:93.77ms +step:217/1695 train_time:20347ms step_avg:93.77ms +step:218/1695 train_time:20441ms step_avg:93.77ms +step:219/1695 
train_time:20536ms step_avg:93.77ms +step:220/1695 train_time:20630ms step_avg:93.77ms +step:221/1695 train_time:20723ms step_avg:93.77ms +step:222/1695 train_time:20817ms step_avg:93.77ms +step:223/1695 train_time:20912ms step_avg:93.78ms +step:224/1695 train_time:21005ms step_avg:93.77ms +step:225/1695 train_time:21099ms step_avg:93.77ms +step:226/1695 train_time:21192ms step_avg:93.77ms +step:227/1695 train_time:21285ms step_avg:93.77ms +step:228/1695 train_time:21379ms step_avg:93.77ms +step:229/1695 train_time:21473ms step_avg:93.77ms +step:230/1695 train_time:21567ms step_avg:93.77ms +step:231/1695 train_time:21661ms step_avg:93.77ms +step:232/1695 train_time:21755ms step_avg:93.77ms +step:233/1695 train_time:21849ms step_avg:93.77ms +step:234/1695 train_time:21943ms step_avg:93.77ms +step:235/1695 train_time:22037ms step_avg:93.78ms +step:236/1695 train_time:22131ms step_avg:93.78ms +step:237/1695 train_time:22225ms step_avg:93.78ms +step:238/1695 train_time:22320ms step_avg:93.78ms +step:239/1695 train_time:22414ms step_avg:93.78ms +step:240/1695 train_time:22507ms step_avg:93.78ms +step:241/1695 train_time:22602ms step_avg:93.79ms +step:242/1695 train_time:22697ms step_avg:93.79ms +step:243/1695 train_time:22791ms step_avg:93.79ms +step:244/1695 train_time:22884ms step_avg:93.79ms +step:245/1695 train_time:22980ms step_avg:93.80ms +step:246/1695 train_time:23076ms step_avg:93.81ms +step:247/1695 train_time:23169ms step_avg:93.80ms +step:248/1695 train_time:23263ms step_avg:93.80ms +step:249/1695 train_time:23356ms step_avg:93.80ms +step:250/1695 train_time:23450ms step_avg:93.80ms +step:250/1695 val_loss:4.0858 train_time:23541ms step_avg:94.16ms +step:251/1695 train_time:23567ms step_avg:93.89ms +step:252/1695 train_time:23645ms step_avg:93.83ms +step:253/1695 train_time:23744ms step_avg:93.85ms +step:254/1695 train_time:23840ms step_avg:93.86ms +step:255/1695 train_time:23933ms step_avg:93.86ms +step:256/1695 train_time:24027ms step_avg:93.86ms +step:257/1695 train_time:24121ms step_avg:93.85ms +step:258/1695 train_time:24214ms step_avg:93.85ms +step:259/1695 train_time:24308ms step_avg:93.85ms +step:260/1695 train_time:24402ms step_avg:93.85ms +step:261/1695 train_time:24496ms step_avg:93.85ms +step:262/1695 train_time:24593ms step_avg:93.87ms +step:263/1695 train_time:24691ms step_avg:93.88ms +step:264/1695 train_time:24788ms step_avg:93.89ms +step:265/1695 train_time:24885ms step_avg:93.90ms +step:266/1695 train_time:24978ms step_avg:93.90ms +step:267/1695 train_time:25072ms step_avg:93.90ms +step:268/1695 train_time:25166ms step_avg:93.90ms +step:269/1695 train_time:25260ms step_avg:93.90ms +step:270/1695 train_time:25353ms step_avg:93.90ms +step:271/1695 train_time:25448ms step_avg:93.90ms +step:272/1695 train_time:25542ms step_avg:93.90ms +step:273/1695 train_time:25636ms step_avg:93.91ms +step:274/1695 train_time:25733ms step_avg:93.92ms +step:275/1695 train_time:25829ms step_avg:93.92ms +step:276/1695 train_time:25923ms step_avg:93.92ms +step:277/1695 train_time:26017ms step_avg:93.92ms +step:278/1695 train_time:26112ms step_avg:93.93ms +step:279/1695 train_time:26207ms step_avg:93.93ms +step:280/1695 train_time:26300ms step_avg:93.93ms +step:281/1695 train_time:26394ms step_avg:93.93ms +step:282/1695 train_time:26488ms step_avg:93.93ms +step:283/1695 train_time:26583ms step_avg:93.93ms +step:284/1695 train_time:26678ms step_avg:93.93ms +step:285/1695 train_time:26772ms step_avg:93.94ms +step:286/1695 train_time:26867ms step_avg:93.94ms +step:287/1695 train_time:26962ms 
step_avg:93.95ms +step:288/1695 train_time:27057ms step_avg:93.95ms +step:289/1695 train_time:27151ms step_avg:93.95ms +step:290/1695 train_time:27246ms step_avg:93.95ms +step:291/1695 train_time:27340ms step_avg:93.95ms +step:292/1695 train_time:27433ms step_avg:93.95ms +step:293/1695 train_time:27527ms step_avg:93.95ms +step:294/1695 train_time:27622ms step_avg:93.95ms +step:295/1695 train_time:27717ms step_avg:93.96ms +step:296/1695 train_time:27813ms step_avg:93.96ms +step:297/1695 train_time:27908ms step_avg:93.97ms +step:298/1695 train_time:28003ms step_avg:93.97ms +step:299/1695 train_time:28096ms step_avg:93.97ms +step:300/1695 train_time:28193ms step_avg:93.98ms +step:301/1695 train_time:28288ms step_avg:93.98ms +step:302/1695 train_time:28382ms step_avg:93.98ms +step:303/1695 train_time:28476ms step_avg:93.98ms +step:304/1695 train_time:28570ms step_avg:93.98ms +step:305/1695 train_time:28664ms step_avg:93.98ms +step:306/1695 train_time:28758ms step_avg:93.98ms +step:307/1695 train_time:28853ms step_avg:93.98ms +step:308/1695 train_time:28948ms step_avg:93.99ms +step:309/1695 train_time:29043ms step_avg:93.99ms +step:310/1695 train_time:29137ms step_avg:93.99ms +step:311/1695 train_time:29233ms step_avg:94.00ms +step:312/1695 train_time:29328ms step_avg:94.00ms +step:313/1695 train_time:29422ms step_avg:94.00ms +step:314/1695 train_time:29516ms step_avg:94.00ms +step:315/1695 train_time:29611ms step_avg:94.00ms +step:316/1695 train_time:29706ms step_avg:94.00ms +step:317/1695 train_time:29800ms step_avg:94.01ms +step:318/1695 train_time:29894ms step_avg:94.01ms +step:319/1695 train_time:29989ms step_avg:94.01ms +step:320/1695 train_time:30084ms step_avg:94.01ms +step:321/1695 train_time:30178ms step_avg:94.01ms +step:322/1695 train_time:30273ms step_avg:94.01ms +step:323/1695 train_time:30368ms step_avg:94.02ms +step:324/1695 train_time:30463ms step_avg:94.02ms +step:325/1695 train_time:30556ms step_avg:94.02ms +step:326/1695 train_time:30651ms step_avg:94.02ms +step:327/1695 train_time:30746ms step_avg:94.02ms +step:328/1695 train_time:30839ms step_avg:94.02ms +step:329/1695 train_time:30934ms step_avg:94.02ms +step:330/1695 train_time:31029ms step_avg:94.03ms +step:331/1695 train_time:31124ms step_avg:94.03ms +step:332/1695 train_time:31218ms step_avg:94.03ms +step:333/1695 train_time:31313ms step_avg:94.03ms +step:334/1695 train_time:31407ms step_avg:94.03ms +step:335/1695 train_time:31502ms step_avg:94.03ms +step:336/1695 train_time:31595ms step_avg:94.03ms +step:337/1695 train_time:31691ms step_avg:94.04ms +step:338/1695 train_time:31785ms step_avg:94.04ms +step:339/1695 train_time:31879ms step_avg:94.04ms +step:340/1695 train_time:31973ms step_avg:94.04ms +step:341/1695 train_time:32069ms step_avg:94.04ms +step:342/1695 train_time:32163ms step_avg:94.04ms +step:343/1695 train_time:32257ms step_avg:94.04ms +step:344/1695 train_time:32352ms step_avg:94.05ms +step:345/1695 train_time:32446ms step_avg:94.05ms +step:346/1695 train_time:32541ms step_avg:94.05ms +step:347/1695 train_time:32634ms step_avg:94.05ms +step:348/1695 train_time:32728ms step_avg:94.05ms +step:349/1695 train_time:32822ms step_avg:94.05ms +step:350/1695 train_time:32916ms step_avg:94.05ms +step:351/1695 train_time:33012ms step_avg:94.05ms +step:352/1695 train_time:33107ms step_avg:94.05ms +step:353/1695 train_time:33201ms step_avg:94.05ms +step:354/1695 train_time:33295ms step_avg:94.05ms +step:355/1695 train_time:33389ms step_avg:94.05ms +step:356/1695 train_time:33484ms step_avg:94.06ms +step:357/1695 
train_time:33578ms step_avg:94.05ms +step:358/1695 train_time:33672ms step_avg:94.05ms +step:359/1695 train_time:33767ms step_avg:94.06ms +step:360/1695 train_time:33860ms step_avg:94.06ms +step:361/1695 train_time:33954ms step_avg:94.06ms +step:362/1695 train_time:34050ms step_avg:94.06ms +step:363/1695 train_time:34145ms step_avg:94.06ms +step:364/1695 train_time:34239ms step_avg:94.06ms +step:365/1695 train_time:34332ms step_avg:94.06ms +step:366/1695 train_time:34427ms step_avg:94.06ms +step:367/1695 train_time:34521ms step_avg:94.06ms +step:368/1695 train_time:34615ms step_avg:94.06ms +step:369/1695 train_time:34709ms step_avg:94.06ms +step:370/1695 train_time:34805ms step_avg:94.07ms +step:371/1695 train_time:34899ms step_avg:94.07ms +step:372/1695 train_time:34993ms step_avg:94.07ms +step:373/1695 train_time:35088ms step_avg:94.07ms +step:374/1695 train_time:35183ms step_avg:94.07ms +step:375/1695 train_time:35277ms step_avg:94.07ms +step:375/1695 val_loss:3.8837 train_time:35370ms step_avg:94.32ms +step:376/1695 train_time:35397ms step_avg:94.14ms +step:377/1695 train_time:35476ms step_avg:94.10ms +step:378/1695 train_time:35573ms step_avg:94.11ms +step:379/1695 train_time:35670ms step_avg:94.12ms +step:380/1695 train_time:35765ms step_avg:94.12ms +step:381/1695 train_time:35861ms step_avg:94.12ms +step:382/1695 train_time:35956ms step_avg:94.13ms +step:383/1695 train_time:36051ms step_avg:94.13ms +step:384/1695 train_time:36146ms step_avg:94.13ms +step:385/1695 train_time:36242ms step_avg:94.13ms +step:386/1695 train_time:36337ms step_avg:94.14ms +step:387/1695 train_time:36434ms step_avg:94.14ms +step:388/1695 train_time:36531ms step_avg:94.15ms +step:389/1695 train_time:36628ms step_avg:94.16ms +step:390/1695 train_time:36726ms step_avg:94.17ms +step:391/1695 train_time:36822ms step_avg:94.17ms +step:392/1695 train_time:36918ms step_avg:94.18ms +step:393/1695 train_time:37014ms step_avg:94.18ms +step:394/1695 train_time:37109ms step_avg:94.18ms +step:395/1695 train_time:37205ms step_avg:94.19ms +step:396/1695 train_time:37302ms step_avg:94.20ms +step:397/1695 train_time:37398ms step_avg:94.20ms +step:398/1695 train_time:37494ms step_avg:94.21ms +step:399/1695 train_time:37590ms step_avg:94.21ms +step:400/1695 train_time:37687ms step_avg:94.22ms +step:401/1695 train_time:37785ms step_avg:94.23ms +step:402/1695 train_time:37881ms step_avg:94.23ms +step:403/1695 train_time:37978ms step_avg:94.24ms +step:404/1695 train_time:38073ms step_avg:94.24ms +step:405/1695 train_time:38169ms step_avg:94.24ms +step:406/1695 train_time:38265ms step_avg:94.25ms +step:407/1695 train_time:38362ms step_avg:94.25ms +step:408/1695 train_time:38459ms step_avg:94.26ms +step:409/1695 train_time:38555ms step_avg:94.27ms +step:410/1695 train_time:38650ms step_avg:94.27ms +step:411/1695 train_time:38746ms step_avg:94.27ms +step:412/1695 train_time:38842ms step_avg:94.28ms +step:413/1695 train_time:38939ms step_avg:94.28ms +step:414/1695 train_time:39035ms step_avg:94.29ms +step:415/1695 train_time:39131ms step_avg:94.29ms +step:416/1695 train_time:39227ms step_avg:94.30ms +step:417/1695 train_time:39323ms step_avg:94.30ms +step:418/1695 train_time:39420ms step_avg:94.31ms +step:419/1695 train_time:39516ms step_avg:94.31ms +step:420/1695 train_time:39612ms step_avg:94.31ms +step:421/1695 train_time:39708ms step_avg:94.32ms +step:422/1695 train_time:39805ms step_avg:94.32ms +step:423/1695 train_time:39902ms step_avg:94.33ms +step:424/1695 train_time:39998ms step_avg:94.34ms +step:425/1695 train_time:40094ms 
step_avg:94.34ms +step:426/1695 train_time:40190ms step_avg:94.34ms +step:427/1695 train_time:40286ms step_avg:94.35ms +step:428/1695 train_time:40383ms step_avg:94.35ms +step:429/1695 train_time:40479ms step_avg:94.36ms +step:430/1695 train_time:40575ms step_avg:94.36ms +step:431/1695 train_time:40672ms step_avg:94.37ms +step:432/1695 train_time:40768ms step_avg:94.37ms +step:433/1695 train_time:40864ms step_avg:94.37ms +step:434/1695 train_time:40962ms step_avg:94.38ms +step:435/1695 train_time:41059ms step_avg:94.39ms +step:436/1695 train_time:41155ms step_avg:94.39ms +step:437/1695 train_time:41251ms step_avg:94.40ms +step:438/1695 train_time:41347ms step_avg:94.40ms +step:439/1695 train_time:41444ms step_avg:94.40ms +step:440/1695 train_time:41541ms step_avg:94.41ms +step:441/1695 train_time:41637ms step_avg:94.42ms +step:442/1695 train_time:41733ms step_avg:94.42ms +step:443/1695 train_time:41829ms step_avg:94.42ms +step:444/1695 train_time:41925ms step_avg:94.43ms +step:445/1695 train_time:42022ms step_avg:94.43ms +step:446/1695 train_time:42120ms step_avg:94.44ms +step:447/1695 train_time:42216ms step_avg:94.44ms +step:448/1695 train_time:42313ms step_avg:94.45ms +step:449/1695 train_time:42408ms step_avg:94.45ms +step:450/1695 train_time:42505ms step_avg:94.46ms +step:451/1695 train_time:42602ms step_avg:94.46ms +step:452/1695 train_time:42699ms step_avg:94.47ms +step:453/1695 train_time:42795ms step_avg:94.47ms +step:454/1695 train_time:42891ms step_avg:94.47ms +step:455/1695 train_time:42987ms step_avg:94.48ms +step:456/1695 train_time:43084ms step_avg:94.48ms +step:457/1695 train_time:43181ms step_avg:94.49ms +step:458/1695 train_time:43278ms step_avg:94.49ms +step:459/1695 train_time:43374ms step_avg:94.50ms +step:460/1695 train_time:43470ms step_avg:94.50ms +step:461/1695 train_time:43566ms step_avg:94.50ms +step:462/1695 train_time:43663ms step_avg:94.51ms +step:463/1695 train_time:43760ms step_avg:94.51ms +step:464/1695 train_time:43856ms step_avg:94.52ms +step:465/1695 train_time:43952ms step_avg:94.52ms +step:466/1695 train_time:44048ms step_avg:94.52ms +step:467/1695 train_time:44144ms step_avg:94.53ms +step:468/1695 train_time:44241ms step_avg:94.53ms +step:469/1695 train_time:44338ms step_avg:94.54ms +step:470/1695 train_time:44434ms step_avg:94.54ms +step:471/1695 train_time:44529ms step_avg:94.54ms +step:472/1695 train_time:44625ms step_avg:94.54ms +step:473/1695 train_time:44722ms step_avg:94.55ms +step:474/1695 train_time:44819ms step_avg:94.55ms +step:475/1695 train_time:44915ms step_avg:94.56ms +step:476/1695 train_time:45011ms step_avg:94.56ms +step:477/1695 train_time:45107ms step_avg:94.56ms +step:478/1695 train_time:45204ms step_avg:94.57ms +step:479/1695 train_time:45300ms step_avg:94.57ms +step:480/1695 train_time:45397ms step_avg:94.58ms +step:481/1695 train_time:45493ms step_avg:94.58ms +step:482/1695 train_time:45588ms step_avg:94.58ms +step:483/1695 train_time:45685ms step_avg:94.59ms +step:484/1695 train_time:45782ms step_avg:94.59ms +step:485/1695 train_time:45877ms step_avg:94.59ms +step:486/1695 train_time:45973ms step_avg:94.60ms +step:487/1695 train_time:46069ms step_avg:94.60ms +step:488/1695 train_time:46166ms step_avg:94.60ms +step:489/1695 train_time:46262ms step_avg:94.61ms +step:490/1695 train_time:46359ms step_avg:94.61ms +step:491/1695 train_time:46456ms step_avg:94.61ms +step:492/1695 train_time:46550ms step_avg:94.61ms +step:493/1695 train_time:46646ms step_avg:94.62ms +step:494/1695 train_time:46743ms step_avg:94.62ms +step:495/1695 
train_time:46839ms step_avg:94.62ms +step:496/1695 train_time:46936ms step_avg:94.63ms +step:497/1695 train_time:47031ms step_avg:94.63ms +step:498/1695 train_time:47127ms step_avg:94.63ms +step:499/1695 train_time:47223ms step_avg:94.64ms +step:500/1695 train_time:47320ms step_avg:94.64ms +step:500/1695 val_loss:3.7375 train_time:47413ms step_avg:94.83ms +step:501/1695 train_time:47441ms step_avg:94.69ms +step:502/1695 train_time:47522ms step_avg:94.66ms +step:503/1695 train_time:47624ms step_avg:94.68ms +step:504/1695 train_time:47721ms step_avg:94.68ms +step:505/1695 train_time:47817ms step_avg:94.69ms +step:506/1695 train_time:47912ms step_avg:94.69ms +step:507/1695 train_time:48008ms step_avg:94.69ms +step:508/1695 train_time:48104ms step_avg:94.69ms +step:509/1695 train_time:48200ms step_avg:94.69ms +step:510/1695 train_time:48295ms step_avg:94.70ms +step:511/1695 train_time:48391ms step_avg:94.70ms +step:512/1695 train_time:48489ms step_avg:94.70ms +step:513/1695 train_time:48588ms step_avg:94.71ms +step:514/1695 train_time:48686ms step_avg:94.72ms +step:515/1695 train_time:48783ms step_avg:94.72ms +step:516/1695 train_time:48880ms step_avg:94.73ms +step:517/1695 train_time:48976ms step_avg:94.73ms +step:518/1695 train_time:49072ms step_avg:94.73ms +step:519/1695 train_time:49168ms step_avg:94.74ms +step:520/1695 train_time:49264ms step_avg:94.74ms +step:521/1695 train_time:49360ms step_avg:94.74ms +step:522/1695 train_time:49456ms step_avg:94.74ms +step:523/1695 train_time:49553ms step_avg:94.75ms +step:524/1695 train_time:49650ms step_avg:94.75ms +step:525/1695 train_time:49747ms step_avg:94.76ms +step:526/1695 train_time:49845ms step_avg:94.76ms +step:527/1695 train_time:49942ms step_avg:94.77ms +step:528/1695 train_time:50039ms step_avg:94.77ms +step:529/1695 train_time:50136ms step_avg:94.77ms +step:530/1695 train_time:50231ms step_avg:94.78ms +step:531/1695 train_time:50327ms step_avg:94.78ms +step:532/1695 train_time:50424ms step_avg:94.78ms +step:533/1695 train_time:50521ms step_avg:94.79ms +step:534/1695 train_time:50618ms step_avg:94.79ms +step:535/1695 train_time:50713ms step_avg:94.79ms +step:536/1695 train_time:50810ms step_avg:94.80ms +step:537/1695 train_time:50908ms step_avg:94.80ms +step:538/1695 train_time:51006ms step_avg:94.81ms +step:539/1695 train_time:51105ms step_avg:94.81ms +step:540/1695 train_time:51200ms step_avg:94.82ms +step:541/1695 train_time:51296ms step_avg:94.82ms +step:542/1695 train_time:51392ms step_avg:94.82ms +step:543/1695 train_time:51488ms step_avg:94.82ms +step:544/1695 train_time:51585ms step_avg:94.83ms +step:545/1695 train_time:51682ms step_avg:94.83ms +step:546/1695 train_time:51779ms step_avg:94.83ms +step:547/1695 train_time:51875ms step_avg:94.84ms +step:548/1695 train_time:51972ms step_avg:94.84ms +step:549/1695 train_time:52069ms step_avg:94.84ms +step:550/1695 train_time:52166ms step_avg:94.85ms +step:551/1695 train_time:52263ms step_avg:94.85ms +step:552/1695 train_time:52360ms step_avg:94.86ms +step:553/1695 train_time:52456ms step_avg:94.86ms +step:554/1695 train_time:52552ms step_avg:94.86ms +step:555/1695 train_time:52648ms step_avg:94.86ms +step:556/1695 train_time:52746ms step_avg:94.87ms +step:557/1695 train_time:52842ms step_avg:94.87ms +step:558/1695 train_time:52939ms step_avg:94.87ms +step:559/1695 train_time:53036ms step_avg:94.88ms +step:560/1695 train_time:53132ms step_avg:94.88ms +step:561/1695 train_time:53228ms step_avg:94.88ms +step:562/1695 train_time:53326ms step_avg:94.89ms +step:563/1695 train_time:53423ms 
step_avg:94.89ms +step:564/1695 train_time:53520ms step_avg:94.89ms +step:565/1695 train_time:53616ms step_avg:94.90ms +step:566/1695 train_time:53712ms step_avg:94.90ms +step:567/1695 train_time:53809ms step_avg:94.90ms +step:568/1695 train_time:53906ms step_avg:94.91ms +step:569/1695 train_time:54003ms step_avg:94.91ms +step:570/1695 train_time:54100ms step_avg:94.91ms +step:571/1695 train_time:54196ms step_avg:94.91ms +step:572/1695 train_time:54292ms step_avg:94.92ms +step:573/1695 train_time:54389ms step_avg:94.92ms +step:574/1695 train_time:54486ms step_avg:94.92ms +step:575/1695 train_time:54584ms step_avg:94.93ms +step:576/1695 train_time:54681ms step_avg:94.93ms +step:577/1695 train_time:54778ms step_avg:94.94ms +step:578/1695 train_time:54874ms step_avg:94.94ms +step:579/1695 train_time:54972ms step_avg:94.94ms +step:580/1695 train_time:55069ms step_avg:94.95ms +step:581/1695 train_time:55166ms step_avg:94.95ms +step:582/1695 train_time:55263ms step_avg:94.95ms +step:583/1695 train_time:55361ms step_avg:94.96ms +step:584/1695 train_time:55457ms step_avg:94.96ms +step:585/1695 train_time:55553ms step_avg:94.96ms +step:586/1695 train_time:55649ms step_avg:94.96ms +step:587/1695 train_time:55747ms step_avg:94.97ms +step:588/1695 train_time:55844ms step_avg:94.97ms +step:589/1695 train_time:55941ms step_avg:94.98ms +step:590/1695 train_time:56037ms step_avg:94.98ms +step:591/1695 train_time:56133ms step_avg:94.98ms +step:592/1695 train_time:56230ms step_avg:94.98ms +step:593/1695 train_time:56328ms step_avg:94.99ms +step:594/1695 train_time:56426ms step_avg:94.99ms +step:595/1695 train_time:56523ms step_avg:95.00ms +step:596/1695 train_time:56618ms step_avg:95.00ms +step:597/1695 train_time:56714ms step_avg:95.00ms +step:598/1695 train_time:56811ms step_avg:95.00ms +step:599/1695 train_time:56908ms step_avg:95.00ms +step:600/1695 train_time:57006ms step_avg:95.01ms +step:601/1695 train_time:57103ms step_avg:95.01ms +step:602/1695 train_time:57199ms step_avg:95.02ms +step:603/1695 train_time:57295ms step_avg:95.02ms +step:604/1695 train_time:57392ms step_avg:95.02ms +step:605/1695 train_time:57489ms step_avg:95.02ms +step:606/1695 train_time:57587ms step_avg:95.03ms +step:607/1695 train_time:57684ms step_avg:95.03ms +step:608/1695 train_time:57780ms step_avg:95.03ms +step:609/1695 train_time:57877ms step_avg:95.04ms +step:610/1695 train_time:57973ms step_avg:95.04ms +step:611/1695 train_time:58070ms step_avg:95.04ms +step:612/1695 train_time:58167ms step_avg:95.04ms +step:613/1695 train_time:58264ms step_avg:95.05ms +step:614/1695 train_time:58361ms step_avg:95.05ms +step:615/1695 train_time:58457ms step_avg:95.05ms +step:616/1695 train_time:58553ms step_avg:95.05ms +step:617/1695 train_time:58649ms step_avg:95.05ms +step:618/1695 train_time:58745ms step_avg:95.06ms +step:619/1695 train_time:58843ms step_avg:95.06ms +step:620/1695 train_time:58939ms step_avg:95.06ms +step:621/1695 train_time:59035ms step_avg:95.07ms +step:622/1695 train_time:59131ms step_avg:95.07ms +step:623/1695 train_time:59228ms step_avg:95.07ms +step:624/1695 train_time:59325ms step_avg:95.07ms +step:625/1695 train_time:59423ms step_avg:95.08ms +step:625/1695 val_loss:3.6553 train_time:59518ms step_avg:95.23ms +step:626/1695 train_time:59544ms step_avg:95.12ms +step:627/1695 train_time:59627ms step_avg:95.10ms +step:628/1695 train_time:59729ms step_avg:95.11ms +step:629/1695 train_time:59826ms step_avg:95.11ms +step:630/1695 train_time:59924ms step_avg:95.12ms +step:631/1695 train_time:60021ms step_avg:95.12ms 
+step:632/1695 train_time:60119ms step_avg:95.13ms +step:633/1695 train_time:60216ms step_avg:95.13ms +step:634/1695 train_time:60313ms step_avg:95.13ms +step:635/1695 train_time:60642ms step_avg:95.50ms +step:636/1695 train_time:60738ms step_avg:95.50ms +step:637/1695 train_time:60834ms step_avg:95.50ms +step:638/1695 train_time:60931ms step_avg:95.50ms +step:639/1695 train_time:61028ms step_avg:95.51ms +step:640/1695 train_time:61125ms step_avg:95.51ms +step:641/1695 train_time:61222ms step_avg:95.51ms +step:642/1695 train_time:61318ms step_avg:95.51ms +step:643/1695 train_time:61415ms step_avg:95.51ms +step:644/1695 train_time:61512ms step_avg:95.51ms +step:645/1695 train_time:61612ms step_avg:95.52ms +step:646/1695 train_time:61711ms step_avg:95.53ms +step:647/1695 train_time:61810ms step_avg:95.53ms +step:648/1695 train_time:61907ms step_avg:95.54ms +step:649/1695 train_time:62005ms step_avg:95.54ms +step:650/1695 train_time:62103ms step_avg:95.54ms +step:651/1695 train_time:62201ms step_avg:95.55ms +step:652/1695 train_time:62298ms step_avg:95.55ms +step:653/1695 train_time:62395ms step_avg:95.55ms +step:654/1695 train_time:62493ms step_avg:95.55ms +step:655/1695 train_time:62590ms step_avg:95.56ms +step:656/1695 train_time:62689ms step_avg:95.56ms +step:657/1695 train_time:62788ms step_avg:95.57ms +step:658/1695 train_time:62886ms step_avg:95.57ms +step:659/1695 train_time:62984ms step_avg:95.57ms +step:660/1695 train_time:63081ms step_avg:95.58ms +step:661/1695 train_time:63177ms step_avg:95.58ms +step:662/1695 train_time:63274ms step_avg:95.58ms +step:663/1695 train_time:63371ms step_avg:95.58ms +step:664/1695 train_time:63469ms step_avg:95.59ms +step:665/1695 train_time:63567ms step_avg:95.59ms +step:666/1695 train_time:63666ms step_avg:95.59ms +step:667/1695 train_time:63765ms step_avg:95.60ms +step:668/1695 train_time:63864ms step_avg:95.60ms +step:669/1695 train_time:63963ms step_avg:95.61ms +step:670/1695 train_time:64061ms step_avg:95.61ms +step:671/1695 train_time:64158ms step_avg:95.62ms +step:672/1695 train_time:64255ms step_avg:95.62ms +step:673/1695 train_time:64352ms step_avg:95.62ms +step:674/1695 train_time:64449ms step_avg:95.62ms +step:675/1695 train_time:64546ms step_avg:95.62ms +step:676/1695 train_time:64644ms step_avg:95.63ms +step:677/1695 train_time:64742ms step_avg:95.63ms +step:678/1695 train_time:64841ms step_avg:95.64ms +step:679/1695 train_time:64940ms step_avg:95.64ms +step:680/1695 train_time:65039ms step_avg:95.65ms +step:681/1695 train_time:65136ms step_avg:95.65ms +step:682/1695 train_time:65233ms step_avg:95.65ms +step:683/1695 train_time:65330ms step_avg:95.65ms +step:684/1695 train_time:65429ms step_avg:95.66ms +step:685/1695 train_time:65526ms step_avg:95.66ms +step:686/1695 train_time:65625ms step_avg:95.66ms +step:687/1695 train_time:65724ms step_avg:95.67ms +step:688/1695 train_time:65822ms step_avg:95.67ms +step:689/1695 train_time:65920ms step_avg:95.68ms +step:690/1695 train_time:66018ms step_avg:95.68ms +step:691/1695 train_time:66115ms step_avg:95.68ms +step:692/1695 train_time:66213ms step_avg:95.68ms +step:693/1695 train_time:66310ms step_avg:95.69ms +step:694/1695 train_time:66408ms step_avg:95.69ms +step:695/1695 train_time:66506ms step_avg:95.69ms +step:696/1695 train_time:66892ms step_avg:96.11ms +step:697/1695 train_time:66988ms step_avg:96.11ms +step:698/1695 train_time:67085ms step_avg:96.11ms +step:699/1695 train_time:67182ms step_avg:96.11ms +step:700/1695 train_time:67280ms step_avg:96.11ms +step:701/1695 train_time:67377ms 
step_avg:96.12ms +step:702/1695 train_time:67473ms step_avg:96.12ms +step:703/1695 train_time:67570ms step_avg:96.12ms +step:704/1695 train_time:67668ms step_avg:96.12ms +step:705/1695 train_time:67765ms step_avg:96.12ms +step:706/1695 train_time:67869ms step_avg:96.13ms +step:707/1695 train_time:67968ms step_avg:96.14ms +step:708/1695 train_time:68067ms step_avg:96.14ms +step:709/1695 train_time:68165ms step_avg:96.14ms +step:710/1695 train_time:68263ms step_avg:96.14ms +step:711/1695 train_time:68361ms step_avg:96.15ms +step:712/1695 train_time:68459ms step_avg:96.15ms +step:713/1695 train_time:68557ms step_avg:96.15ms +step:714/1695 train_time:68656ms step_avg:96.16ms +step:715/1695 train_time:68753ms step_avg:96.16ms +step:716/1695 train_time:68851ms step_avg:96.16ms +step:717/1695 train_time:68950ms step_avg:96.16ms +step:718/1695 train_time:69047ms step_avg:96.17ms +step:719/1695 train_time:69389ms step_avg:96.51ms +step:720/1695 train_time:69486ms step_avg:96.51ms +step:721/1695 train_time:69583ms step_avg:96.51ms +step:722/1695 train_time:69681ms step_avg:96.51ms +step:723/1695 train_time:69777ms step_avg:96.51ms +step:724/1695 train_time:69874ms step_avg:96.51ms +step:725/1695 train_time:69971ms step_avg:96.51ms +step:726/1695 train_time:70068ms step_avg:96.51ms +step:727/1695 train_time:70164ms step_avg:96.51ms +step:728/1695 train_time:70265ms step_avg:96.52ms +step:729/1695 train_time:70367ms step_avg:96.53ms +step:730/1695 train_time:70466ms step_avg:96.53ms +step:731/1695 train_time:70564ms step_avg:96.53ms +step:732/1695 train_time:70662ms step_avg:96.53ms +step:733/1695 train_time:70760ms step_avg:96.53ms +step:734/1695 train_time:70858ms step_avg:96.54ms +step:735/1695 train_time:70956ms step_avg:96.54ms +step:736/1695 train_time:71053ms step_avg:96.54ms +step:737/1695 train_time:71149ms step_avg:96.54ms +step:738/1695 train_time:71246ms step_avg:96.54ms +step:739/1695 train_time:71344ms step_avg:96.54ms +step:740/1695 train_time:71442ms step_avg:96.54ms +step:741/1695 train_time:71540ms step_avg:96.55ms +step:742/1695 train_time:71638ms step_avg:96.55ms +step:743/1695 train_time:71736ms step_avg:96.55ms +step:744/1695 train_time:71833ms step_avg:96.55ms +step:745/1695 train_time:71930ms step_avg:96.55ms +step:746/1695 train_time:72029ms step_avg:96.55ms +step:747/1695 train_time:72127ms step_avg:96.55ms +step:748/1695 train_time:72224ms step_avg:96.56ms +step:749/1695 train_time:72322ms step_avg:96.56ms +step:750/1695 train_time:72420ms step_avg:96.56ms +step:750/1695 val_loss:3.5918 train_time:72516ms step_avg:96.69ms +step:751/1695 train_time:72543ms step_avg:96.59ms +step:752/1695 train_time:72623ms step_avg:96.57ms +step:753/1695 train_time:72722ms step_avg:96.58ms +step:754/1695 train_time:72821ms step_avg:96.58ms +step:755/1695 train_time:72918ms step_avg:96.58ms +step:756/1695 train_time:73015ms step_avg:96.58ms +step:757/1695 train_time:73113ms step_avg:96.58ms +step:758/1695 train_time:73210ms step_avg:96.58ms +step:759/1695 train_time:73308ms step_avg:96.58ms +step:760/1695 train_time:73405ms step_avg:96.59ms +step:761/1695 train_time:73503ms step_avg:96.59ms +step:762/1695 train_time:73602ms step_avg:96.59ms +step:763/1695 train_time:73701ms step_avg:96.59ms +step:764/1695 train_time:73800ms step_avg:96.60ms +step:765/1695 train_time:73898ms step_avg:96.60ms +step:766/1695 train_time:73995ms step_avg:96.60ms +step:767/1695 train_time:74093ms step_avg:96.60ms +step:768/1695 train_time:74190ms step_avg:96.60ms +step:769/1695 train_time:74288ms step_avg:96.60ms 
+step:770/1695 train_time:74386ms step_avg:96.60ms +step:771/1695 train_time:74484ms step_avg:96.61ms +step:772/1695 train_time:74867ms step_avg:96.98ms +step:773/1695 train_time:74963ms step_avg:96.98ms +step:774/1695 train_time:75060ms step_avg:96.98ms +step:775/1695 train_time:75158ms step_avg:96.98ms +step:776/1695 train_time:75255ms step_avg:96.98ms +step:777/1695 train_time:75353ms step_avg:96.98ms +step:778/1695 train_time:75683ms step_avg:97.28ms +step:779/1695 train_time:75780ms step_avg:97.28ms +step:780/1695 train_time:75877ms step_avg:97.28ms +step:781/1695 train_time:75974ms step_avg:97.28ms +step:782/1695 train_time:76072ms step_avg:97.28ms +step:783/1695 train_time:76170ms step_avg:97.28ms +step:784/1695 train_time:76268ms step_avg:97.28ms +step:785/1695 train_time:76365ms step_avg:97.28ms +step:786/1695 train_time:76462ms step_avg:97.28ms +step:787/1695 train_time:76562ms step_avg:97.28ms +step:788/1695 train_time:76663ms step_avg:97.29ms +step:789/1695 train_time:77083ms step_avg:97.70ms +step:790/1695 train_time:77133ms step_avg:97.64ms +step:791/1695 train_time:77239ms step_avg:97.65ms +step:792/1695 train_time:77336ms step_avg:97.65ms +step:793/1695 train_time:77433ms step_avg:97.65ms +step:794/1695 train_time:77530ms step_avg:97.65ms +step:795/1695 train_time:77628ms step_avg:97.65ms +step:796/1695 train_time:77726ms step_avg:97.65ms +step:797/1695 train_time:77823ms step_avg:97.65ms +step:798/1695 train_time:78155ms step_avg:97.94ms +step:799/1695 train_time:78249ms step_avg:97.93ms +step:800/1695 train_time:78346ms step_avg:97.93ms +step:801/1695 train_time:78444ms step_avg:97.93ms +step:802/1695 train_time:78541ms step_avg:97.93ms +step:803/1695 train_time:78638ms step_avg:97.93ms +step:804/1695 train_time:78735ms step_avg:97.93ms +step:805/1695 train_time:78832ms step_avg:97.93ms +step:806/1695 train_time:78930ms step_avg:97.93ms +step:807/1695 train_time:79030ms step_avg:97.93ms +step:808/1695 train_time:79132ms step_avg:97.94ms +step:809/1695 train_time:79232ms step_avg:97.94ms +step:810/1695 train_time:79331ms step_avg:97.94ms +step:811/1695 train_time:79430ms step_avg:97.94ms +step:812/1695 train_time:79529ms step_avg:97.94ms +step:813/1695 train_time:79627ms step_avg:97.94ms +step:814/1695 train_time:79724ms step_avg:97.94ms +step:815/1695 train_time:79822ms step_avg:97.94ms +step:816/1695 train_time:79918ms step_avg:97.94ms +step:817/1695 train_time:80016ms step_avg:97.94ms +step:818/1695 train_time:80114ms step_avg:97.94ms +step:819/1695 train_time:80213ms step_avg:97.94ms +step:820/1695 train_time:80311ms step_avg:97.94ms +step:821/1695 train_time:80410ms step_avg:97.94ms +step:822/1695 train_time:80509ms step_avg:97.94ms +step:823/1695 train_time:80607ms step_avg:97.94ms +step:824/1695 train_time:80705ms step_avg:97.94ms +step:825/1695 train_time:80802ms step_avg:97.94ms +step:826/1695 train_time:80900ms step_avg:97.94ms +step:827/1695 train_time:80998ms step_avg:97.94ms +step:828/1695 train_time:81096ms step_avg:97.94ms +step:829/1695 train_time:81194ms step_avg:97.94ms +step:830/1695 train_time:81292ms step_avg:97.94ms +step:831/1695 train_time:81391ms step_avg:97.94ms +step:832/1695 train_time:81490ms step_avg:97.94ms +step:833/1695 train_time:81588ms step_avg:97.94ms +step:834/1695 train_time:81686ms step_avg:97.94ms +step:835/1695 train_time:81783ms step_avg:97.94ms +step:836/1695 train_time:81881ms step_avg:97.94ms +step:837/1695 train_time:81978ms step_avg:97.94ms +step:838/1695 train_time:82076ms step_avg:97.94ms +step:839/1695 train_time:82174ms 
step_avg:97.94ms +step:840/1695 train_time:82272ms step_avg:97.94ms +step:841/1695 train_time:82370ms step_avg:97.94ms +step:842/1695 train_time:82469ms step_avg:97.94ms +step:843/1695 train_time:82567ms step_avg:97.94ms +step:844/1695 train_time:82666ms step_avg:97.95ms +step:845/1695 train_time:82763ms step_avg:97.94ms +step:846/1695 train_time:82861ms step_avg:97.94ms +step:847/1695 train_time:82959ms step_avg:97.94ms +step:848/1695 train_time:83058ms step_avg:97.95ms +step:849/1695 train_time:83156ms step_avg:97.95ms +step:850/1695 train_time:83254ms step_avg:97.95ms +step:851/1695 train_time:83352ms step_avg:97.95ms +step:852/1695 train_time:83451ms step_avg:97.95ms +step:853/1695 train_time:83549ms step_avg:97.95ms +step:854/1695 train_time:83647ms step_avg:97.95ms +step:855/1695 train_time:83745ms step_avg:97.95ms +step:856/1695 train_time:83843ms step_avg:97.95ms +step:857/1695 train_time:83941ms step_avg:97.95ms +step:858/1695 train_time:84039ms step_avg:97.95ms +step:859/1695 train_time:84138ms step_avg:97.95ms +step:860/1695 train_time:84235ms step_avg:97.95ms +step:861/1695 train_time:84334ms step_avg:97.95ms +step:862/1695 train_time:84431ms step_avg:97.95ms +step:863/1695 train_time:84529ms step_avg:97.95ms +step:864/1695 train_time:84626ms step_avg:97.95ms +step:865/1695 train_time:84725ms step_avg:97.95ms +step:866/1695 train_time:84823ms step_avg:97.95ms +step:867/1695 train_time:84922ms step_avg:97.95ms +step:868/1695 train_time:85020ms step_avg:97.95ms +step:869/1695 train_time:85119ms step_avg:97.95ms +step:870/1695 train_time:85216ms step_avg:97.95ms +step:871/1695 train_time:85315ms step_avg:97.95ms +step:872/1695 train_time:85413ms step_avg:97.95ms +step:873/1695 train_time:85511ms step_avg:97.95ms +step:874/1695 train_time:85610ms step_avg:97.95ms +step:875/1695 train_time:85709ms step_avg:97.95ms +step:875/1695 val_loss:3.5440 train_time:85805ms step_avg:98.06ms +step:876/1695 train_time:85831ms step_avg:97.98ms +step:877/1695 train_time:85916ms step_avg:97.97ms +step:878/1695 train_time:86017ms step_avg:97.97ms +step:879/1695 train_time:86116ms step_avg:97.97ms +step:880/1695 train_time:86214ms step_avg:97.97ms +step:881/1695 train_time:86313ms step_avg:97.97ms +step:882/1695 train_time:86413ms step_avg:97.97ms +step:883/1695 train_time:86512ms step_avg:97.98ms +step:884/1695 train_time:86612ms step_avg:97.98ms +step:885/1695 train_time:86711ms step_avg:97.98ms +step:886/1695 train_time:86813ms step_avg:97.98ms +step:887/1695 train_time:86915ms step_avg:97.99ms +step:888/1695 train_time:87016ms step_avg:97.99ms +step:889/1695 train_time:87116ms step_avg:97.99ms +step:890/1695 train_time:87215ms step_avg:97.99ms +step:891/1695 train_time:87314ms step_avg:98.00ms +step:892/1695 train_time:87413ms step_avg:98.00ms +step:893/1695 train_time:87512ms step_avg:98.00ms +step:894/1695 train_time:87612ms step_avg:98.00ms +step:895/1695 train_time:87711ms step_avg:98.00ms +step:896/1695 train_time:87813ms step_avg:98.01ms +step:897/1695 train_time:87915ms step_avg:98.01ms +step:898/1695 train_time:88015ms step_avg:98.01ms +step:899/1695 train_time:88115ms step_avg:98.01ms +step:900/1695 train_time:88215ms step_avg:98.02ms +step:901/1695 train_time:88314ms step_avg:98.02ms +step:902/1695 train_time:88413ms step_avg:98.02ms +step:903/1695 train_time:88512ms step_avg:98.02ms +step:904/1695 train_time:88612ms step_avg:98.02ms +step:905/1695 train_time:88712ms step_avg:98.02ms +step:906/1695 train_time:88812ms step_avg:98.03ms +step:907/1695 train_time:88913ms step_avg:98.03ms 
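The step_avg column is the cumulative train_time divided by the step index, so any row can be checked by hand: at step 900 above, 88215 ms / 900 ≈ 98.02 ms, matching the logged value. A minimal parser for these rows (the regex and variable names are illustrative, not part of the training script):

    import re
    row = "step:900/1695 train_time:88215ms step_avg:98.02ms"
    m = re.match(r"step:(\d+)/(\d+) train_time:(\d+)ms step_avg:([\d.]+)ms", row)
    step, train_ms, avg_ms = int(m[1]), int(m[3]), float(m[4])
    assert abs(train_ms / step - avg_ms) < 0.01  # step_avg == train_time / step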
+step:908/1695 train_time:89016ms step_avg:98.03ms +step:909/1695 train_time:89116ms step_avg:98.04ms +step:910/1695 train_time:89215ms step_avg:98.04ms +step:911/1695 train_time:89315ms step_avg:98.04ms +step:912/1695 train_time:89414ms step_avg:98.04ms +step:913/1695 train_time:89513ms step_avg:98.04ms +step:914/1695 train_time:89613ms step_avg:98.04ms +step:915/1695 train_time:89712ms step_avg:98.05ms +step:916/1695 train_time:89812ms step_avg:98.05ms +step:917/1695 train_time:89913ms step_avg:98.05ms +step:918/1695 train_time:90014ms step_avg:98.05ms +step:919/1695 train_time:90114ms step_avg:98.06ms +step:920/1695 train_time:90215ms step_avg:98.06ms +step:921/1695 train_time:90314ms step_avg:98.06ms +step:922/1695 train_time:90414ms step_avg:98.06ms +step:923/1695 train_time:90514ms step_avg:98.06ms +step:924/1695 train_time:90613ms step_avg:98.07ms +step:925/1695 train_time:90713ms step_avg:98.07ms +step:926/1695 train_time:90813ms step_avg:98.07ms +step:927/1695 train_time:90914ms step_avg:98.07ms +step:928/1695 train_time:91015ms step_avg:98.08ms +step:929/1695 train_time:91115ms step_avg:98.08ms +step:930/1695 train_time:91214ms step_avg:98.08ms +step:931/1695 train_time:91315ms step_avg:98.08ms +step:932/1695 train_time:91415ms step_avg:98.08ms +step:933/1695 train_time:91515ms step_avg:98.09ms +step:934/1695 train_time:91614ms step_avg:98.09ms +step:935/1695 train_time:91713ms step_avg:98.09ms +step:936/1695 train_time:91813ms step_avg:98.09ms +step:937/1695 train_time:91914ms step_avg:98.09ms +step:938/1695 train_time:92014ms step_avg:98.10ms +step:939/1695 train_time:92116ms step_avg:98.10ms +step:940/1695 train_time:92215ms step_avg:98.10ms +step:941/1695 train_time:92314ms step_avg:98.10ms +step:942/1695 train_time:92414ms step_avg:98.10ms +step:943/1695 train_time:92514ms step_avg:98.11ms +step:944/1695 train_time:92613ms step_avg:98.11ms +step:945/1695 train_time:92715ms step_avg:98.11ms +step:946/1695 train_time:92814ms step_avg:98.11ms +step:947/1695 train_time:92913ms step_avg:98.11ms +step:948/1695 train_time:93013ms step_avg:98.11ms +step:949/1695 train_time:93113ms step_avg:98.12ms +step:950/1695 train_time:93213ms step_avg:98.12ms +step:951/1695 train_time:93312ms step_avg:98.12ms +step:952/1695 train_time:93413ms step_avg:98.12ms +step:953/1695 train_time:93513ms step_avg:98.12ms +step:954/1695 train_time:93613ms step_avg:98.13ms +step:955/1695 train_time:93712ms step_avg:98.13ms +step:956/1695 train_time:93812ms step_avg:98.13ms +step:957/1695 train_time:93912ms step_avg:98.13ms +step:958/1695 train_time:94012ms step_avg:98.13ms +step:959/1695 train_time:94112ms step_avg:98.14ms +step:960/1695 train_time:94213ms step_avg:98.14ms +step:961/1695 train_time:94313ms step_avg:98.14ms +step:962/1695 train_time:94413ms step_avg:98.14ms +step:963/1695 train_time:94513ms step_avg:98.14ms +step:964/1695 train_time:94613ms step_avg:98.15ms +step:965/1695 train_time:94713ms step_avg:98.15ms +step:966/1695 train_time:94814ms step_avg:98.15ms +step:967/1695 train_time:94914ms step_avg:98.15ms +step:968/1695 train_time:95014ms step_avg:98.15ms +step:969/1695 train_time:95114ms step_avg:98.16ms +step:970/1695 train_time:95214ms step_avg:98.16ms +step:971/1695 train_time:95313ms step_avg:98.16ms +step:972/1695 train_time:95413ms step_avg:98.16ms +step:973/1695 train_time:95513ms step_avg:98.16ms +step:974/1695 train_time:95613ms step_avg:98.17ms +step:975/1695 train_time:95713ms step_avg:98.17ms +step:976/1695 train_time:95813ms step_avg:98.17ms +step:977/1695 train_time:95913ms 
step_avg:98.17ms +step:978/1695 train_time:96014ms step_avg:98.17ms +step:979/1695 train_time:96115ms step_avg:98.18ms +step:980/1695 train_time:96214ms step_avg:98.18ms +step:981/1695 train_time:96314ms step_avg:98.18ms +step:982/1695 train_time:96415ms step_avg:98.18ms +step:983/1695 train_time:96514ms step_avg:98.18ms +step:984/1695 train_time:96614ms step_avg:98.19ms +step:985/1695 train_time:96713ms step_avg:98.19ms +step:986/1695 train_time:96813ms step_avg:98.19ms +step:987/1695 train_time:96914ms step_avg:98.19ms +step:988/1695 train_time:97014ms step_avg:98.19ms +step:989/1695 train_time:97114ms step_avg:98.19ms +step:990/1695 train_time:97214ms step_avg:98.20ms +step:991/1695 train_time:97314ms step_avg:98.20ms +step:992/1695 train_time:97414ms step_avg:98.20ms +step:993/1695 train_time:97515ms step_avg:98.20ms +step:994/1695 train_time:97613ms step_avg:98.20ms +step:995/1695 train_time:97713ms step_avg:98.20ms +step:996/1695 train_time:97814ms step_avg:98.21ms +step:997/1695 train_time:97914ms step_avg:98.21ms +step:998/1695 train_time:98014ms step_avg:98.21ms +step:999/1695 train_time:98114ms step_avg:98.21ms +step:1000/1695 train_time:98215ms step_avg:98.21ms +step:1000/1695 val_loss:3.4977 train_time:98313ms step_avg:98.31ms +step:1001/1695 train_time:98340ms step_avg:98.24ms +step:1002/1695 train_time:98426ms step_avg:98.23ms +step:1003/1695 train_time:98528ms step_avg:98.23ms +step:1004/1695 train_time:98628ms step_avg:98.24ms +step:1005/1695 train_time:98728ms step_avg:98.24ms +step:1006/1695 train_time:98827ms step_avg:98.24ms +step:1007/1695 train_time:98928ms step_avg:98.24ms +step:1008/1695 train_time:99027ms step_avg:98.24ms +step:1009/1695 train_time:99126ms step_avg:98.24ms +step:1010/1695 train_time:99225ms step_avg:98.24ms +step:1011/1695 train_time:99328ms step_avg:98.25ms +step:1012/1695 train_time:99430ms step_avg:98.25ms +step:1013/1695 train_time:99531ms step_avg:98.25ms +step:1014/1695 train_time:99631ms step_avg:98.26ms +step:1015/1695 train_time:99731ms step_avg:98.26ms +step:1016/1695 train_time:99830ms step_avg:98.26ms +step:1017/1695 train_time:99929ms step_avg:98.26ms +step:1018/1695 train_time:100028ms step_avg:98.26ms +step:1019/1695 train_time:100126ms step_avg:98.26ms +step:1020/1695 train_time:100226ms step_avg:98.26ms +step:1021/1695 train_time:100329ms step_avg:98.27ms +step:1022/1695 train_time:100430ms step_avg:98.27ms +step:1023/1695 train_time:100530ms step_avg:98.27ms +step:1024/1695 train_time:100631ms step_avg:98.27ms +step:1025/1695 train_time:100731ms step_avg:98.27ms +step:1026/1695 train_time:100830ms step_avg:98.28ms +step:1027/1695 train_time:100929ms step_avg:98.28ms +step:1028/1695 train_time:101029ms step_avg:98.28ms +step:1029/1695 train_time:101129ms step_avg:98.28ms +step:1030/1695 train_time:101228ms step_avg:98.28ms +step:1031/1695 train_time:101329ms step_avg:98.28ms +step:1032/1695 train_time:101429ms step_avg:98.28ms +step:1033/1695 train_time:101529ms step_avg:98.29ms +step:1034/1695 train_time:101630ms step_avg:98.29ms +step:1035/1695 train_time:101730ms step_avg:98.29ms +step:1036/1695 train_time:101830ms step_avg:98.29ms +step:1037/1695 train_time:101930ms step_avg:98.29ms +step:1038/1695 train_time:102029ms step_avg:98.29ms +step:1039/1695 train_time:102128ms step_avg:98.29ms +step:1040/1695 train_time:102228ms step_avg:98.30ms +step:1041/1695 train_time:102328ms step_avg:98.30ms +step:1042/1695 train_time:102427ms step_avg:98.30ms +step:1043/1695 train_time:102528ms step_avg:98.30ms +step:1044/1695 
train_time:102628ms step_avg:98.30ms +step:1045/1695 train_time:102728ms step_avg:98.30ms +step:1046/1695 train_time:102829ms step_avg:98.31ms +step:1047/1695 train_time:102929ms step_avg:98.31ms +step:1048/1695 train_time:103028ms step_avg:98.31ms +step:1049/1695 train_time:103127ms step_avg:98.31ms +step:1050/1695 train_time:103227ms step_avg:98.31ms +step:1051/1695 train_time:103328ms step_avg:98.31ms +step:1052/1695 train_time:103428ms step_avg:98.32ms +step:1053/1695 train_time:103528ms step_avg:98.32ms +step:1054/1695 train_time:103628ms step_avg:98.32ms +step:1055/1695 train_time:103728ms step_avg:98.32ms +step:1056/1695 train_time:103828ms step_avg:98.32ms +step:1057/1695 train_time:103928ms step_avg:98.32ms +step:1058/1695 train_time:104028ms step_avg:98.33ms +step:1059/1695 train_time:104127ms step_avg:98.33ms +step:1060/1695 train_time:104227ms step_avg:98.33ms +step:1061/1695 train_time:104327ms step_avg:98.33ms +step:1062/1695 train_time:104427ms step_avg:98.33ms +step:1063/1695 train_time:104528ms step_avg:98.33ms +step:1064/1695 train_time:104628ms step_avg:98.33ms +step:1065/1695 train_time:104728ms step_avg:98.34ms +step:1066/1695 train_time:104828ms step_avg:98.34ms +step:1067/1695 train_time:104928ms step_avg:98.34ms +step:1068/1695 train_time:105028ms step_avg:98.34ms +step:1069/1695 train_time:105128ms step_avg:98.34ms +step:1070/1695 train_time:105228ms step_avg:98.34ms +step:1071/1695 train_time:105328ms step_avg:98.35ms +step:1072/1695 train_time:105428ms step_avg:98.35ms +step:1073/1695 train_time:105527ms step_avg:98.35ms +step:1074/1695 train_time:105627ms step_avg:98.35ms +step:1075/1695 train_time:105728ms step_avg:98.35ms +step:1076/1695 train_time:105827ms step_avg:98.35ms +step:1077/1695 train_time:105928ms step_avg:98.36ms +step:1078/1695 train_time:106028ms step_avg:98.36ms +step:1079/1695 train_time:106128ms step_avg:98.36ms +step:1080/1695 train_time:106228ms step_avg:98.36ms +step:1081/1695 train_time:106328ms step_avg:98.36ms +step:1082/1695 train_time:106428ms step_avg:98.36ms +step:1083/1695 train_time:106527ms step_avg:98.36ms +step:1084/1695 train_time:106628ms step_avg:98.36ms +step:1085/1695 train_time:106728ms step_avg:98.37ms +step:1086/1695 train_time:106828ms step_avg:98.37ms +step:1087/1695 train_time:106928ms step_avg:98.37ms +step:1088/1695 train_time:107028ms step_avg:98.37ms +step:1089/1695 train_time:107128ms step_avg:98.37ms +step:1090/1695 train_time:107228ms step_avg:98.37ms +step:1091/1695 train_time:107328ms step_avg:98.38ms +step:1092/1695 train_time:107428ms step_avg:98.38ms +step:1093/1695 train_time:107528ms step_avg:98.38ms +step:1094/1695 train_time:107629ms step_avg:98.38ms +step:1095/1695 train_time:107727ms step_avg:98.38ms +step:1096/1695 train_time:107827ms step_avg:98.38ms +step:1097/1695 train_time:107927ms step_avg:98.38ms +step:1098/1695 train_time:108027ms step_avg:98.39ms +step:1099/1695 train_time:108127ms step_avg:98.39ms +step:1100/1695 train_time:108227ms step_avg:98.39ms +step:1101/1695 train_time:108326ms step_avg:98.39ms +step:1102/1695 train_time:108427ms step_avg:98.39ms +step:1103/1695 train_time:108526ms step_avg:98.39ms +step:1104/1695 train_time:108627ms step_avg:98.39ms +step:1105/1695 train_time:108726ms step_avg:98.39ms +step:1106/1695 train_time:108827ms step_avg:98.40ms +step:1107/1695 train_time:108927ms step_avg:98.40ms +step:1108/1695 train_time:109027ms step_avg:98.40ms +step:1109/1695 train_time:109127ms step_avg:98.40ms +step:1110/1695 train_time:109227ms step_avg:98.40ms +step:1111/1695 
train_time:109328ms step_avg:98.40ms +step:1112/1695 train_time:109428ms step_avg:98.41ms +step:1113/1695 train_time:109528ms step_avg:98.41ms +step:1114/1695 train_time:109628ms step_avg:98.41ms +step:1115/1695 train_time:109728ms step_avg:98.41ms +step:1116/1695 train_time:109827ms step_avg:98.41ms +step:1117/1695 train_time:109928ms step_avg:98.41ms +step:1118/1695 train_time:110028ms step_avg:98.41ms +step:1119/1695 train_time:110127ms step_avg:98.42ms +step:1120/1695 train_time:110228ms step_avg:98.42ms +step:1121/1695 train_time:110327ms step_avg:98.42ms +step:1122/1695 train_time:110428ms step_avg:98.42ms +step:1123/1695 train_time:110528ms step_avg:98.42ms +step:1124/1695 train_time:110628ms step_avg:98.42ms +step:1125/1695 train_time:110729ms step_avg:98.43ms +step:1125/1695 val_loss:3.4459 train_time:110827ms step_avg:98.51ms +step:1126/1695 train_time:110853ms step_avg:98.45ms +step:1127/1695 train_time:110940ms step_avg:98.44ms +step:1128/1695 train_time:111042ms step_avg:98.44ms +step:1129/1695 train_time:111143ms step_avg:98.44ms +step:1130/1695 train_time:111243ms step_avg:98.44ms +step:1131/1695 train_time:111342ms step_avg:98.45ms +step:1132/1695 train_time:111442ms step_avg:98.45ms +step:1133/1695 train_time:111543ms step_avg:98.45ms +step:1134/1695 train_time:111642ms step_avg:98.45ms +step:1135/1695 train_time:111741ms step_avg:98.45ms +step:1136/1695 train_time:111844ms step_avg:98.45ms +step:1137/1695 train_time:111947ms step_avg:98.46ms +step:1138/1695 train_time:112048ms step_avg:98.46ms +step:1139/1695 train_time:112148ms step_avg:98.46ms +step:1140/1695 train_time:112248ms step_avg:98.46ms +step:1141/1695 train_time:112347ms step_avg:98.46ms +step:1142/1695 train_time:112447ms step_avg:98.46ms +step:1143/1695 train_time:112547ms step_avg:98.47ms +step:1144/1695 train_time:112648ms step_avg:98.47ms +step:1145/1695 train_time:112749ms step_avg:98.47ms +step:1146/1695 train_time:112849ms step_avg:98.47ms +step:1147/1695 train_time:112949ms step_avg:98.47ms +step:1148/1695 train_time:113050ms step_avg:98.48ms +step:1149/1695 train_time:113151ms step_avg:98.48ms +step:1150/1695 train_time:113252ms step_avg:98.48ms +step:1151/1695 train_time:113352ms step_avg:98.48ms +step:1152/1695 train_time:113453ms step_avg:98.48ms +step:1153/1695 train_time:113553ms step_avg:98.48ms +step:1154/1695 train_time:113655ms step_avg:98.49ms +step:1155/1695 train_time:113756ms step_avg:98.49ms +step:1156/1695 train_time:113857ms step_avg:98.49ms +step:1157/1695 train_time:113958ms step_avg:98.49ms +step:1158/1695 train_time:114059ms step_avg:98.50ms +step:1159/1695 train_time:114161ms step_avg:98.50ms +step:1160/1695 train_time:114263ms step_avg:98.50ms +step:1161/1695 train_time:114364ms step_avg:98.50ms +step:1162/1695 train_time:114464ms step_avg:98.51ms +step:1163/1695 train_time:114567ms step_avg:98.51ms +step:1164/1695 train_time:114668ms step_avg:98.51ms +step:1165/1695 train_time:114768ms step_avg:98.51ms +step:1166/1695 train_time:114868ms step_avg:98.51ms +step:1167/1695 train_time:114967ms step_avg:98.52ms +step:1168/1695 train_time:115068ms step_avg:98.52ms +step:1169/1695 train_time:115167ms step_avg:98.52ms +step:1170/1695 train_time:115268ms step_avg:98.52ms +step:1171/1695 train_time:115367ms step_avg:98.52ms +step:1172/1695 train_time:115469ms step_avg:98.52ms +step:1173/1695 train_time:115569ms step_avg:98.52ms +step:1174/1695 train_time:115670ms step_avg:98.53ms +step:1175/1695 train_time:115769ms step_avg:98.53ms +step:1176/1695 train_time:115869ms step_avg:98.53ms 
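The val_loss rows arrive every 125 steps (750, 875, 1000, 1125, ...), consistent with the val_loss_every = 125 setting in the Hyperparameters of the script added later in this patch. A short sketch for extracting the validation curve from a record file; the helper name is illustrative:

    import re

    def val_curve(path: str) -> list[tuple[int, float]]:
        # returns (step, val_loss) pairs in logged order
        pairs = []
        with open(path) as f:
            for line in f:
                m = re.search(r"step:(\d+)/\d+ val_loss:([\d.]+)", line)
                if m:
                    pairs.append((int(m[1]), float(m[2])))
        return pairs

    # val_curve("records/071825_TritonMuon/record.txt") includes, for this stretch,
    # (750, 3.5918), (875, 3.5440), (1000, 3.4977), (1125, 3.4459)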
+step:1177/1695 train_time:115970ms step_avg:98.53ms +step:1178/1695 train_time:116069ms step_avg:98.53ms +step:1179/1695 train_time:116172ms step_avg:98.53ms +step:1180/1695 train_time:116273ms step_avg:98.54ms +step:1181/1695 train_time:116375ms step_avg:98.54ms +step:1182/1695 train_time:116475ms step_avg:98.54ms +step:1183/1695 train_time:116577ms step_avg:98.54ms +step:1184/1695 train_time:116679ms step_avg:98.55ms +step:1185/1695 train_time:116780ms step_avg:98.55ms +step:1186/1695 train_time:116881ms step_avg:98.55ms +step:1187/1695 train_time:116982ms step_avg:98.55ms +step:1188/1695 train_time:117084ms step_avg:98.56ms +step:1189/1695 train_time:117184ms step_avg:98.56ms +step:1190/1695 train_time:117284ms step_avg:98.56ms +step:1191/1695 train_time:117385ms step_avg:98.56ms +step:1192/1695 train_time:117486ms step_avg:98.56ms +step:1193/1695 train_time:117587ms step_avg:98.56ms +step:1194/1695 train_time:117688ms step_avg:98.57ms +step:1195/1695 train_time:117788ms step_avg:98.57ms +step:1196/1695 train_time:117889ms step_avg:98.57ms +step:1197/1695 train_time:117989ms step_avg:98.57ms +step:1198/1695 train_time:118089ms step_avg:98.57ms +step:1199/1695 train_time:118190ms step_avg:98.57ms +step:1200/1695 train_time:118289ms step_avg:98.57ms +step:1201/1695 train_time:118389ms step_avg:98.58ms +step:1202/1695 train_time:118491ms step_avg:98.58ms +step:1203/1695 train_time:118593ms step_avg:98.58ms +step:1204/1695 train_time:118694ms step_avg:98.58ms +step:1205/1695 train_time:118795ms step_avg:98.59ms +step:1206/1695 train_time:118897ms step_avg:98.59ms +step:1207/1695 train_time:118999ms step_avg:98.59ms +step:1208/1695 train_time:119100ms step_avg:98.59ms +step:1209/1695 train_time:119201ms step_avg:98.59ms +step:1210/1695 train_time:119302ms step_avg:98.60ms +step:1211/1695 train_time:119403ms step_avg:98.60ms +step:1212/1695 train_time:119503ms step_avg:98.60ms +step:1213/1695 train_time:119604ms step_avg:98.60ms +step:1214/1695 train_time:119705ms step_avg:98.60ms +step:1215/1695 train_time:119806ms step_avg:98.61ms +step:1216/1695 train_time:119907ms step_avg:98.61ms +step:1217/1695 train_time:120007ms step_avg:98.61ms +step:1218/1695 train_time:120108ms step_avg:98.61ms +step:1219/1695 train_time:120208ms step_avg:98.61ms +step:1220/1695 train_time:120309ms step_avg:98.61ms +step:1221/1695 train_time:120409ms step_avg:98.61ms +step:1222/1695 train_time:120508ms step_avg:98.62ms +step:1223/1695 train_time:120609ms step_avg:98.62ms +step:1224/1695 train_time:120709ms step_avg:98.62ms +step:1225/1695 train_time:120811ms step_avg:98.62ms +step:1226/1695 train_time:120912ms step_avg:98.62ms +step:1227/1695 train_time:121013ms step_avg:98.63ms +step:1228/1695 train_time:121113ms step_avg:98.63ms +step:1229/1695 train_time:121214ms step_avg:98.63ms +step:1230/1695 train_time:121315ms step_avg:98.63ms +step:1231/1695 train_time:121416ms step_avg:98.63ms +step:1232/1695 train_time:121518ms step_avg:98.63ms +step:1233/1695 train_time:121619ms step_avg:98.64ms +step:1234/1695 train_time:121723ms step_avg:98.64ms +step:1235/1695 train_time:121823ms step_avg:98.64ms +step:1236/1695 train_time:121925ms step_avg:98.64ms +step:1237/1695 train_time:122026ms step_avg:98.65ms +step:1238/1695 train_time:122127ms step_avg:98.65ms +step:1239/1695 train_time:122227ms step_avg:98.65ms +step:1240/1695 train_time:122327ms step_avg:98.65ms +step:1241/1695 train_time:122430ms step_avg:98.65ms +step:1242/1695 train_time:122530ms step_avg:98.66ms +step:1243/1695 train_time:122629ms step_avg:98.66ms 
+step:1244/1695 train_time:122729ms step_avg:98.66ms +step:1245/1695 train_time:122829ms step_avg:98.66ms +step:1246/1695 train_time:122931ms step_avg:98.66ms +step:1247/1695 train_time:123031ms step_avg:98.66ms +step:1248/1695 train_time:123133ms step_avg:98.66ms +step:1249/1695 train_time:123235ms step_avg:98.67ms +step:1250/1695 train_time:123336ms step_avg:98.67ms +step:1250/1695 val_loss:3.3987 train_time:123435ms step_avg:98.75ms +step:1251/1695 train_time:123462ms step_avg:98.69ms +step:1252/1695 train_time:123549ms step_avg:98.68ms +step:1253/1695 train_time:123649ms step_avg:98.68ms +step:1254/1695 train_time:123749ms step_avg:98.68ms +step:1255/1695 train_time:123849ms step_avg:98.68ms +step:1256/1695 train_time:123948ms step_avg:98.68ms +step:1257/1695 train_time:124048ms step_avg:98.69ms +step:1258/1695 train_time:124148ms step_avg:98.69ms +step:1259/1695 train_time:124247ms step_avg:98.69ms +step:1260/1695 train_time:124347ms step_avg:98.69ms +step:1261/1695 train_time:124449ms step_avg:98.69ms +step:1262/1695 train_time:124552ms step_avg:98.69ms +step:1263/1695 train_time:124652ms step_avg:98.70ms +step:1264/1695 train_time:124753ms step_avg:98.70ms +step:1265/1695 train_time:124853ms step_avg:98.70ms +step:1266/1695 train_time:124953ms step_avg:98.70ms +step:1267/1695 train_time:125054ms step_avg:98.70ms +step:1268/1695 train_time:125153ms step_avg:98.70ms +step:1269/1695 train_time:125254ms step_avg:98.70ms +step:1270/1695 train_time:125356ms step_avg:98.71ms +step:1271/1695 train_time:125458ms step_avg:98.71ms +step:1272/1695 train_time:125559ms step_avg:98.71ms +step:1273/1695 train_time:125661ms step_avg:98.71ms +step:1274/1695 train_time:125761ms step_avg:98.71ms +step:1275/1695 train_time:125861ms step_avg:98.71ms +step:1276/1695 train_time:125964ms step_avg:98.72ms +step:1277/1695 train_time:126065ms step_avg:98.72ms +step:1278/1695 train_time:126165ms step_avg:98.72ms +step:1279/1695 train_time:126267ms step_avg:98.72ms +step:1280/1695 train_time:126368ms step_avg:98.72ms +step:1281/1695 train_time:126468ms step_avg:98.73ms +step:1282/1695 train_time:126568ms step_avg:98.73ms +step:1283/1695 train_time:126668ms step_avg:98.73ms +step:1284/1695 train_time:126768ms step_avg:98.73ms +step:1285/1695 train_time:126868ms step_avg:98.73ms +step:1286/1695 train_time:126968ms step_avg:98.73ms +step:1287/1695 train_time:127068ms step_avg:98.73ms +step:1288/1695 train_time:127168ms step_avg:98.73ms +step:1289/1695 train_time:127268ms step_avg:98.73ms +step:1290/1695 train_time:127368ms step_avg:98.73ms +step:1291/1695 train_time:127468ms step_avg:98.74ms +step:1292/1695 train_time:127568ms step_avg:98.74ms +step:1293/1695 train_time:127669ms step_avg:98.74ms +step:1294/1695 train_time:127770ms step_avg:98.74ms +step:1295/1695 train_time:127871ms step_avg:98.74ms +step:1296/1695 train_time:127971ms step_avg:98.74ms +step:1297/1695 train_time:128073ms step_avg:98.75ms +step:1298/1695 train_time:128174ms step_avg:98.75ms +step:1299/1695 train_time:128274ms step_avg:98.75ms +step:1300/1695 train_time:128375ms step_avg:98.75ms +step:1301/1695 train_time:128477ms step_avg:98.75ms +step:1302/1695 train_time:128578ms step_avg:98.75ms +step:1303/1695 train_time:128681ms step_avg:98.76ms +step:1304/1695 train_time:128783ms step_avg:98.76ms +step:1305/1695 train_time:128885ms step_avg:98.76ms +step:1306/1695 train_time:128986ms step_avg:98.76ms +step:1307/1695 train_time:129087ms step_avg:98.77ms +step:1308/1695 train_time:129187ms step_avg:98.77ms +step:1309/1695 train_time:129287ms 
step_avg:98.77ms +step:1310/1695 train_time:129387ms step_avg:98.77ms +step:1311/1695 train_time:129488ms step_avg:98.77ms +step:1312/1695 train_time:129587ms step_avg:98.77ms +step:1313/1695 train_time:129688ms step_avg:98.77ms +step:1314/1695 train_time:129789ms step_avg:98.77ms +step:1315/1695 train_time:129890ms step_avg:98.78ms +step:1316/1695 train_time:129991ms step_avg:98.78ms +step:1317/1695 train_time:130092ms step_avg:98.78ms +step:1318/1695 train_time:130193ms step_avg:98.78ms +step:1319/1695 train_time:130295ms step_avg:98.78ms +step:1320/1695 train_time:130394ms step_avg:98.78ms +step:1321/1695 train_time:130496ms step_avg:98.79ms +step:1322/1695 train_time:130599ms step_avg:98.79ms +step:1323/1695 train_time:130699ms step_avg:98.79ms +step:1324/1695 train_time:130801ms step_avg:98.79ms +step:1325/1695 train_time:130902ms step_avg:98.79ms +step:1326/1695 train_time:131004ms step_avg:98.80ms +step:1327/1695 train_time:131105ms step_avg:98.80ms +step:1328/1695 train_time:131205ms step_avg:98.80ms +step:1329/1695 train_time:131306ms step_avg:98.80ms +step:1330/1695 train_time:131407ms step_avg:98.80ms +step:1331/1695 train_time:131508ms step_avg:98.80ms +step:1332/1695 train_time:131609ms step_avg:98.81ms +step:1333/1695 train_time:131710ms step_avg:98.81ms +step:1334/1695 train_time:131810ms step_avg:98.81ms +step:1335/1695 train_time:131910ms step_avg:98.81ms +step:1336/1695 train_time:132011ms step_avg:98.81ms +step:1337/1695 train_time:132113ms step_avg:98.81ms +step:1338/1695 train_time:132214ms step_avg:98.81ms +step:1339/1695 train_time:132315ms step_avg:98.82ms +step:1340/1695 train_time:132417ms step_avg:98.82ms +step:1341/1695 train_time:132518ms step_avg:98.82ms +step:1342/1695 train_time:132621ms step_avg:98.82ms +step:1343/1695 train_time:132723ms step_avg:98.83ms +step:1344/1695 train_time:132823ms step_avg:98.83ms +step:1345/1695 train_time:132923ms step_avg:98.83ms +step:1346/1695 train_time:133024ms step_avg:98.83ms +step:1347/1695 train_time:133125ms step_avg:98.83ms +step:1348/1695 train_time:133225ms step_avg:98.83ms +step:1349/1695 train_time:133326ms step_avg:98.83ms +step:1350/1695 train_time:133427ms step_avg:98.84ms +step:1351/1695 train_time:133529ms step_avg:98.84ms +step:1352/1695 train_time:133629ms step_avg:98.84ms +step:1353/1695 train_time:133729ms step_avg:98.84ms +step:1354/1695 train_time:133828ms step_avg:98.84ms +step:1355/1695 train_time:133928ms step_avg:98.84ms +step:1356/1695 train_time:134028ms step_avg:98.84ms +step:1357/1695 train_time:134128ms step_avg:98.84ms +step:1358/1695 train_time:134228ms step_avg:98.84ms +step:1359/1695 train_time:134328ms step_avg:98.84ms +step:1360/1695 train_time:134429ms step_avg:98.84ms +step:1361/1695 train_time:134529ms step_avg:98.85ms +step:1362/1695 train_time:134629ms step_avg:98.85ms +step:1363/1695 train_time:134731ms step_avg:98.85ms +step:1364/1695 train_time:134831ms step_avg:98.85ms +step:1365/1695 train_time:134933ms step_avg:98.85ms +step:1366/1695 train_time:135035ms step_avg:98.85ms +step:1367/1695 train_time:135135ms step_avg:98.86ms +step:1368/1695 train_time:135237ms step_avg:98.86ms +step:1369/1695 train_time:135338ms step_avg:98.86ms +step:1370/1695 train_time:135439ms step_avg:98.86ms +step:1371/1695 train_time:135540ms step_avg:98.86ms +step:1372/1695 train_time:135641ms step_avg:98.86ms +step:1373/1695 train_time:135743ms step_avg:98.87ms +step:1374/1695 train_time:135844ms step_avg:98.87ms +step:1375/1695 train_time:135945ms step_avg:98.87ms +step:1375/1695 val_loss:3.3591 
train_time:136044ms step_avg:98.94ms +step:1376/1695 train_time:136071ms step_avg:98.89ms +step:1377/1695 train_time:136154ms step_avg:98.88ms +step:1378/1695 train_time:136255ms step_avg:98.88ms +step:1379/1695 train_time:136357ms step_avg:98.88ms +step:1380/1695 train_time:136459ms step_avg:98.88ms +step:1381/1695 train_time:136559ms step_avg:98.88ms +step:1382/1695 train_time:136660ms step_avg:98.89ms +step:1383/1695 train_time:136760ms step_avg:98.89ms +step:1384/1695 train_time:136861ms step_avg:98.89ms +step:1385/1695 train_time:136962ms step_avg:98.89ms +step:1386/1695 train_time:137067ms step_avg:98.89ms +step:1387/1695 train_time:137169ms step_avg:98.90ms +step:1388/1695 train_time:137271ms step_avg:98.90ms +step:1389/1695 train_time:137372ms step_avg:98.90ms +step:1390/1695 train_time:137473ms step_avg:98.90ms +step:1391/1695 train_time:137574ms step_avg:98.90ms +step:1392/1695 train_time:137675ms step_avg:98.90ms +step:1393/1695 train_time:137775ms step_avg:98.91ms +step:1394/1695 train_time:137877ms step_avg:98.91ms +step:1395/1695 train_time:137979ms step_avg:98.91ms +step:1396/1695 train_time:138082ms step_avg:98.91ms +step:1397/1695 train_time:138186ms step_avg:98.92ms +step:1398/1695 train_time:138289ms step_avg:98.92ms +step:1399/1695 train_time:138391ms step_avg:98.92ms +step:1400/1695 train_time:138493ms step_avg:98.92ms +step:1401/1695 train_time:138593ms step_avg:98.92ms +step:1402/1695 train_time:138695ms step_avg:98.93ms +step:1403/1695 train_time:138797ms step_avg:98.93ms +step:1404/1695 train_time:138898ms step_avg:98.93ms +step:1405/1695 train_time:138999ms step_avg:98.93ms +step:1406/1695 train_time:139102ms step_avg:98.93ms +step:1407/1695 train_time:139204ms step_avg:98.94ms +step:1408/1695 train_time:139306ms step_avg:98.94ms +step:1409/1695 train_time:139411ms step_avg:98.94ms +step:1410/1695 train_time:139512ms step_avg:98.94ms +step:1411/1695 train_time:139613ms step_avg:98.95ms +step:1412/1695 train_time:139716ms step_avg:98.95ms +step:1413/1695 train_time:139816ms step_avg:98.95ms +step:1414/1695 train_time:139918ms step_avg:98.95ms +step:1415/1695 train_time:140020ms step_avg:98.95ms +step:1416/1695 train_time:140121ms step_avg:98.96ms +step:1417/1695 train_time:140223ms step_avg:98.96ms +step:1418/1695 train_time:140325ms step_avg:98.96ms +step:1419/1695 train_time:140428ms step_avg:98.96ms +step:1420/1695 train_time:140530ms step_avg:98.96ms +step:1421/1695 train_time:140632ms step_avg:98.97ms +step:1422/1695 train_time:140733ms step_avg:98.97ms +step:1423/1695 train_time:140833ms step_avg:98.97ms +step:1424/1695 train_time:140935ms step_avg:98.97ms +step:1425/1695 train_time:141036ms step_avg:98.97ms +step:1426/1695 train_time:141138ms step_avg:98.97ms +step:1427/1695 train_time:141241ms step_avg:98.98ms +step:1428/1695 train_time:141343ms step_avg:98.98ms +step:1429/1695 train_time:141446ms step_avg:98.98ms +step:1430/1695 train_time:141548ms step_avg:98.98ms +step:1431/1695 train_time:141649ms step_avg:98.99ms +step:1432/1695 train_time:141750ms step_avg:98.99ms +step:1433/1695 train_time:141851ms step_avg:98.99ms +step:1434/1695 train_time:141951ms step_avg:98.99ms +step:1435/1695 train_time:142053ms step_avg:98.99ms +step:1436/1695 train_time:142156ms step_avg:98.99ms +step:1437/1695 train_time:142257ms step_avg:99.00ms +step:1438/1695 train_time:142360ms step_avg:99.00ms +step:1439/1695 train_time:142463ms step_avg:99.00ms +step:1440/1695 train_time:142565ms step_avg:99.00ms +step:1441/1695 train_time:142667ms step_avg:99.01ms +step:1442/1695 
train_time:142768ms step_avg:99.01ms +step:1443/1695 train_time:142869ms step_avg:99.01ms +step:1444/1695 train_time:142970ms step_avg:99.01ms +step:1445/1695 train_time:143072ms step_avg:99.01ms +step:1446/1695 train_time:143172ms step_avg:99.01ms +step:1447/1695 train_time:143275ms step_avg:99.02ms +step:1448/1695 train_time:143377ms step_avg:99.02ms +step:1449/1695 train_time:143479ms step_avg:99.02ms +step:1450/1695 train_time:143582ms step_avg:99.02ms +step:1451/1695 train_time:143682ms step_avg:99.02ms +step:1452/1695 train_time:143785ms step_avg:99.03ms +step:1453/1695 train_time:143888ms step_avg:99.03ms +step:1454/1695 train_time:143990ms step_avg:99.03ms +step:1455/1695 train_time:144091ms step_avg:99.03ms +step:1456/1695 train_time:144192ms step_avg:99.03ms +step:1457/1695 train_time:144293ms step_avg:99.03ms +step:1458/1695 train_time:144395ms step_avg:99.04ms +step:1459/1695 train_time:144496ms step_avg:99.04ms +step:1460/1695 train_time:144597ms step_avg:99.04ms +step:1461/1695 train_time:144699ms step_avg:99.04ms +step:1462/1695 train_time:144801ms step_avg:99.04ms +step:1463/1695 train_time:144904ms step_avg:99.05ms +step:1464/1695 train_time:145007ms step_avg:99.05ms +step:1465/1695 train_time:145109ms step_avg:99.05ms +step:1466/1695 train_time:145210ms step_avg:99.05ms +step:1467/1695 train_time:145311ms step_avg:99.05ms +step:1468/1695 train_time:145412ms step_avg:99.05ms +step:1469/1695 train_time:145515ms step_avg:99.06ms +step:1470/1695 train_time:145615ms step_avg:99.06ms +step:1471/1695 train_time:145717ms step_avg:99.06ms +step:1472/1695 train_time:145820ms step_avg:99.06ms +step:1473/1695 train_time:145922ms step_avg:99.06ms +step:1474/1695 train_time:146024ms step_avg:99.07ms +step:1475/1695 train_time:146125ms step_avg:99.07ms +step:1476/1695 train_time:146227ms step_avg:99.07ms +step:1477/1695 train_time:146329ms step_avg:99.07ms +step:1478/1695 train_time:146431ms step_avg:99.07ms +step:1479/1695 train_time:146532ms step_avg:99.08ms +step:1480/1695 train_time:146633ms step_avg:99.08ms +step:1481/1695 train_time:146735ms step_avg:99.08ms +step:1482/1695 train_time:146837ms step_avg:99.08ms +step:1483/1695 train_time:146940ms step_avg:99.08ms +step:1484/1695 train_time:147042ms step_avg:99.08ms +step:1485/1695 train_time:147143ms step_avg:99.09ms +step:1486/1695 train_time:147244ms step_avg:99.09ms +step:1487/1695 train_time:147346ms step_avg:99.09ms +step:1488/1695 train_time:147450ms step_avg:99.09ms +step:1489/1695 train_time:147551ms step_avg:99.09ms +step:1490/1695 train_time:147652ms step_avg:99.10ms +step:1491/1695 train_time:147753ms step_avg:99.10ms +step:1492/1695 train_time:147853ms step_avg:99.10ms +step:1493/1695 train_time:147955ms step_avg:99.10ms +step:1494/1695 train_time:148056ms step_avg:99.10ms +step:1495/1695 train_time:148160ms step_avg:99.10ms +step:1496/1695 train_time:148262ms step_avg:99.11ms +step:1497/1695 train_time:148365ms step_avg:99.11ms +step:1498/1695 train_time:148468ms step_avg:99.11ms +step:1499/1695 train_time:148570ms step_avg:99.11ms +step:1500/1695 train_time:148672ms step_avg:99.11ms +step:1500/1695 val_loss:3.3236 train_time:148770ms step_avg:99.18ms +step:1501/1695 train_time:148796ms step_avg:99.13ms +step:1502/1695 train_time:148882ms step_avg:99.12ms +step:1503/1695 train_time:148986ms step_avg:99.13ms +step:1504/1695 train_time:149088ms step_avg:99.13ms +step:1505/1695 train_time:149189ms step_avg:99.13ms +step:1506/1695 train_time:149290ms step_avg:99.13ms +step:1507/1695 train_time:149391ms step_avg:99.13ms 
+step:1508/1695 train_time:149491ms step_avg:99.13ms +step:1509/1695 train_time:149593ms step_avg:99.13ms +step:1510/1695 train_time:149694ms step_avg:99.14ms +step:1511/1695 train_time:149798ms step_avg:99.14ms +step:1512/1695 train_time:149900ms step_avg:99.14ms +step:1513/1695 train_time:150002ms step_avg:99.14ms +step:1514/1695 train_time:150105ms step_avg:99.14ms +step:1515/1695 train_time:150210ms step_avg:99.15ms +step:1516/1695 train_time:150312ms step_avg:99.15ms +step:1517/1695 train_time:150412ms step_avg:99.15ms +step:1518/1695 train_time:150513ms step_avg:99.15ms +step:1519/1695 train_time:150616ms step_avg:99.15ms +step:1520/1695 train_time:150718ms step_avg:99.16ms +step:1521/1695 train_time:150819ms step_avg:99.16ms +step:1522/1695 train_time:150920ms step_avg:99.16ms +step:1523/1695 train_time:151022ms step_avg:99.16ms +step:1524/1695 train_time:151126ms step_avg:99.16ms +step:1525/1695 train_time:151230ms step_avg:99.17ms +step:1526/1695 train_time:151332ms step_avg:99.17ms +step:1527/1695 train_time:151434ms step_avg:99.17ms +step:1528/1695 train_time:151539ms step_avg:99.17ms +step:1529/1695 train_time:151641ms step_avg:99.18ms +step:1530/1695 train_time:151744ms step_avg:99.18ms +step:1531/1695 train_time:151845ms step_avg:99.18ms +step:1532/1695 train_time:151947ms step_avg:99.18ms +step:1533/1695 train_time:152049ms step_avg:99.18ms +step:1534/1695 train_time:152150ms step_avg:99.19ms +step:1535/1695 train_time:152252ms step_avg:99.19ms +step:1536/1695 train_time:152354ms step_avg:99.19ms +step:1537/1695 train_time:152455ms step_avg:99.19ms +step:1538/1695 train_time:152557ms step_avg:99.19ms +step:1539/1695 train_time:152658ms step_avg:99.19ms +step:1540/1695 train_time:152760ms step_avg:99.20ms +step:1541/1695 train_time:152866ms step_avg:99.20ms +step:1542/1695 train_time:152969ms step_avg:99.20ms +step:1543/1695 train_time:153071ms step_avg:99.20ms +step:1544/1695 train_time:153172ms step_avg:99.20ms +step:1545/1695 train_time:153273ms step_avg:99.21ms +step:1546/1695 train_time:153374ms step_avg:99.21ms +step:1547/1695 train_time:153476ms step_avg:99.21ms +step:1548/1695 train_time:153579ms step_avg:99.21ms +step:1549/1695 train_time:153681ms step_avg:99.21ms +step:1550/1695 train_time:153782ms step_avg:99.21ms +step:1551/1695 train_time:153885ms step_avg:99.22ms +step:1552/1695 train_time:153989ms step_avg:99.22ms +step:1553/1695 train_time:154092ms step_avg:99.22ms +step:1554/1695 train_time:154193ms step_avg:99.22ms +step:1555/1695 train_time:154293ms step_avg:99.22ms +step:1556/1695 train_time:154395ms step_avg:99.23ms +step:1557/1695 train_time:154498ms step_avg:99.23ms +step:1558/1695 train_time:154601ms step_avg:99.23ms +step:1559/1695 train_time:154703ms step_avg:99.23ms +step:1560/1695 train_time:154805ms step_avg:99.23ms +step:1561/1695 train_time:154906ms step_avg:99.24ms +step:1562/1695 train_time:155010ms step_avg:99.24ms +step:1563/1695 train_time:155114ms step_avg:99.24ms +step:1564/1695 train_time:155215ms step_avg:99.24ms +step:1565/1695 train_time:155316ms step_avg:99.24ms +step:1566/1695 train_time:155417ms step_avg:99.24ms +step:1567/1695 train_time:155517ms step_avg:99.25ms +step:1568/1695 train_time:155618ms step_avg:99.25ms +step:1569/1695 train_time:155718ms step_avg:99.25ms +step:1570/1695 train_time:155821ms step_avg:99.25ms +step:1571/1695 train_time:155923ms step_avg:99.25ms +step:1572/1695 train_time:156026ms step_avg:99.25ms +step:1573/1695 train_time:156129ms step_avg:99.26ms +step:1574/1695 train_time:156230ms step_avg:99.26ms 
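With cooldown_frac = 0.45 in the Hyperparameters below, the final ~763 of the 1695 steps (from roughly step 933 on) are spent cooling the learning rate down, so every step in this stretch is already inside the cooldown window. A minimal schedule consistent with that description (the script's actual get_lr is not reproduced in this excerpt, so treat this as a sketch):

    def lr_scale(step: int, num_iterations: int = 1695, cooldown_frac: float = 0.45) -> float:
        # constant LR for the first (1 - cooldown_frac) of training,
        # then a linear decay to zero over the final cooldown_frac
        x = step / num_iterations
        if x < 1 - cooldown_frac:
            return 1.0
        return (1 - x) / cooldown_frac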
+step:1575/1695 train_time:156332ms step_avg:99.26ms +step:1576/1695 train_time:156434ms step_avg:99.26ms +step:1577/1695 train_time:156538ms step_avg:99.26ms +step:1578/1695 train_time:156638ms step_avg:99.26ms +step:1579/1695 train_time:156739ms step_avg:99.26ms +step:1580/1695 train_time:156840ms step_avg:99.27ms +step:1581/1695 train_time:156944ms step_avg:99.27ms +step:1582/1695 train_time:157046ms step_avg:99.27ms +step:1583/1695 train_time:157148ms step_avg:99.27ms +step:1584/1695 train_time:157251ms step_avg:99.27ms +step:1585/1695 train_time:157352ms step_avg:99.28ms +step:1586/1695 train_time:157456ms step_avg:99.28ms +step:1587/1695 train_time:157557ms step_avg:99.28ms +step:1588/1695 train_time:157657ms step_avg:99.28ms +step:1589/1695 train_time:157758ms step_avg:99.28ms +step:1590/1695 train_time:157859ms step_avg:99.28ms +step:1591/1695 train_time:157961ms step_avg:99.28ms +step:1592/1695 train_time:158064ms step_avg:99.29ms +step:1593/1695 train_time:158165ms step_avg:99.29ms +step:1594/1695 train_time:158269ms step_avg:99.29ms +step:1595/1695 train_time:158371ms step_avg:99.29ms +step:1596/1695 train_time:158473ms step_avg:99.29ms +step:1597/1695 train_time:158575ms step_avg:99.30ms +step:1598/1695 train_time:158678ms step_avg:99.30ms +step:1599/1695 train_time:158779ms step_avg:99.30ms +step:1600/1695 train_time:158881ms step_avg:99.30ms +step:1601/1695 train_time:158983ms step_avg:99.30ms +step:1602/1695 train_time:159086ms step_avg:99.30ms +step:1603/1695 train_time:159187ms step_avg:99.31ms +step:1604/1695 train_time:159288ms step_avg:99.31ms +step:1605/1695 train_time:159391ms step_avg:99.31ms +step:1606/1695 train_time:159493ms step_avg:99.31ms +step:1607/1695 train_time:159595ms step_avg:99.31ms +step:1608/1695 train_time:159695ms step_avg:99.31ms +step:1609/1695 train_time:159796ms step_avg:99.31ms +step:1610/1695 train_time:159900ms step_avg:99.32ms +step:1611/1695 train_time:160002ms step_avg:99.32ms +step:1612/1695 train_time:160104ms step_avg:99.32ms +step:1613/1695 train_time:160205ms step_avg:99.32ms +step:1614/1695 train_time:160306ms step_avg:99.32ms +step:1615/1695 train_time:160409ms step_avg:99.32ms +step:1616/1695 train_time:160511ms step_avg:99.33ms +step:1617/1695 train_time:160614ms step_avg:99.33ms +step:1618/1695 train_time:160715ms step_avg:99.33ms +step:1619/1695 train_time:160816ms step_avg:99.33ms +step:1620/1695 train_time:160917ms step_avg:99.33ms +step:1621/1695 train_time:161018ms step_avg:99.33ms +step:1622/1695 train_time:161119ms step_avg:99.33ms +step:1623/1695 train_time:161221ms step_avg:99.34ms +step:1624/1695 train_time:161324ms step_avg:99.34ms +step:1625/1695 train_time:161428ms step_avg:99.34ms +step:1625/1695 val_loss:3.2946 train_time:161529ms step_avg:99.40ms +step:1626/1695 train_time:161558ms step_avg:99.36ms +step:1627/1695 train_time:161640ms step_avg:99.35ms +step:1628/1695 train_time:161743ms step_avg:99.35ms +step:1629/1695 train_time:161845ms step_avg:99.35ms +step:1630/1695 train_time:161946ms step_avg:99.35ms +step:1631/1695 train_time:162047ms step_avg:99.35ms +step:1632/1695 train_time:162148ms step_avg:99.36ms +step:1633/1695 train_time:162249ms step_avg:99.36ms +step:1634/1695 train_time:162351ms step_avg:99.36ms +step:1635/1695 train_time:162453ms step_avg:99.36ms +step:1636/1695 train_time:162557ms step_avg:99.36ms +step:1637/1695 train_time:162661ms step_avg:99.37ms +step:1638/1695 train_time:162763ms step_avg:99.37ms +step:1639/1695 train_time:162865ms step_avg:99.37ms +step:1640/1695 train_time:162968ms 
step_avg:99.37ms +step:1641/1695 train_time:163070ms step_avg:99.37ms +step:1642/1695 train_time:163172ms step_avg:99.37ms +step:1643/1695 train_time:163273ms step_avg:99.37ms +step:1644/1695 train_time:163375ms step_avg:99.38ms +step:1645/1695 train_time:163478ms step_avg:99.38ms +step:1646/1695 train_time:163581ms step_avg:99.38ms +step:1647/1695 train_time:163685ms step_avg:99.38ms +step:1648/1695 train_time:163789ms step_avg:99.39ms +step:1649/1695 train_time:163891ms step_avg:99.39ms +step:1650/1695 train_time:163994ms step_avg:99.39ms +step:1651/1695 train_time:164096ms step_avg:99.39ms +step:1652/1695 train_time:164199ms step_avg:99.39ms +step:1653/1695 train_time:164302ms step_avg:99.40ms +step:1654/1695 train_time:164403ms step_avg:99.40ms +step:1655/1695 train_time:164507ms step_avg:99.40ms +step:1656/1695 train_time:164608ms step_avg:99.40ms +step:1657/1695 train_time:164712ms step_avg:99.40ms +step:1658/1695 train_time:164814ms step_avg:99.41ms +step:1659/1695 train_time:164921ms step_avg:99.41ms +step:1660/1695 train_time:165023ms step_avg:99.41ms +step:1661/1695 train_time:165126ms step_avg:99.41ms +step:1662/1695 train_time:165231ms step_avg:99.42ms +step:1663/1695 train_time:165334ms step_avg:99.42ms +step:1664/1695 train_time:165436ms step_avg:99.42ms +step:1665/1695 train_time:165541ms step_avg:99.42ms +step:1666/1695 train_time:165644ms step_avg:99.43ms +step:1667/1695 train_time:165746ms step_avg:99.43ms +step:1668/1695 train_time:165851ms step_avg:99.43ms +step:1669/1695 train_time:165956ms step_avg:99.43ms +step:1670/1695 train_time:166059ms step_avg:99.44ms +step:1671/1695 train_time:166161ms step_avg:99.44ms +step:1672/1695 train_time:166264ms step_avg:99.44ms +step:1673/1695 train_time:166365ms step_avg:99.44ms +step:1674/1695 train_time:166467ms step_avg:99.44ms +step:1675/1695 train_time:166569ms step_avg:99.44ms +step:1676/1695 train_time:166673ms step_avg:99.45ms +step:1677/1695 train_time:166776ms step_avg:99.45ms +step:1678/1695 train_time:166880ms step_avg:99.45ms +step:1679/1695 train_time:166983ms step_avg:99.45ms +step:1680/1695 train_time:167085ms step_avg:99.46ms +step:1681/1695 train_time:167188ms step_avg:99.46ms +step:1682/1695 train_time:167293ms step_avg:99.46ms +step:1683/1695 train_time:167396ms step_avg:99.46ms +step:1684/1695 train_time:167499ms step_avg:99.47ms +step:1685/1695 train_time:167602ms step_avg:99.47ms +step:1686/1695 train_time:167704ms step_avg:99.47ms +step:1687/1695 train_time:167806ms step_avg:99.47ms +step:1688/1695 train_time:167908ms step_avg:99.47ms +step:1689/1695 train_time:168010ms step_avg:99.47ms +step:1690/1695 train_time:168112ms step_avg:99.47ms +step:1691/1695 train_time:168215ms step_avg:99.48ms +step:1692/1695 train_time:168318ms step_avg:99.48ms +step:1693/1695 train_time:168421ms step_avg:99.48ms +step:1694/1695 train_time:168525ms step_avg:99.48ms +step:1695/1695 train_time:168628ms step_avg:99.49ms +step:1695/1695 val_loss:3.2815 train_time:168727ms step_avg:99.54ms +peak memory allocated: 34004 MiB reserved: 49660 MiB diff --git a/records/082325_SparseAttnGate/eb6d347b-fd4a-4077-a490-436c64f97ce2.txt b/records/082325_SparseAttnGate/eb6d347b-fd4a-4077-a490-436c64f97ce2.txt new file mode 100644 index 000000000..33fd2ea91 --- /dev/null +++ b/records/082325_SparseAttnGate/eb6d347b-fd4a-4077-a490-436c64f97ce2.txt @@ -0,0 +1,2802 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from 
dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +@dataclass +class Hyperparameters: + # data + dampen_factor = 64 + run_id = f'final/{uuid.uuid4()}' + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1695 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: 
Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + 
at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, 
BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
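+
+    For a single 2D parameter p with gradient g and momentum buffer m, the update
+    applied in step() below amounts to the following sketch (ignoring the multi-GPU
+    reduce-scatter/all-gather sharding and the per-tensor lr/wd multipliers):
+
+        m.lerp_(g, 1 - momentum)              # m = momentum*m + (1-momentum)*g
+        u = g.lerp_(m, momentum)              # Nesterov-style blend of g and m
+        p.mul_(1 - lr * weight_decay)         # decoupled weight decay
+        scale = max(1, p.size(-2) / p.size(-1)) ** 0.5
+        p.add_(newton_schulz_triton(u), alpha=-lr * scale)
+
+    newton_schulz_triton (defined above) runs five iterations of
+    X <- a*X + (b*A + c*A@A) @ X with A = X @ X.T and
+    (a, b, c) = (3.4445, -4.7750, 2.0315), driving X toward the nearest
+    semi-orthogonal matrix.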
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + 
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0, bias=False): + super().__init__(in_features, out_features, bias=bias) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, 
num_heads: int, max_seq_len: int, head_dim=128):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+        # sparse attention gate: a small zero-initialized linear head over the first
+        # dim//args.dampen_factor channels produces one sigmoid gate per attention head
+        self.dampen = CastedLinear(dim//args.dampen_factor, num_heads)
+        self.dampen.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask):
+        B, T, d_model = x.size(0), x.size(1), x.size(-1) # batch size, sequence length, model dim
+        assert B == 1, "Must use batch size = 1 for FlexAttention"
+        dampen_factor = torch.sigmoid(self.dampen(x[..., :d_model//args.dampen_factor])).view(B, T, self.num_heads, 1) # per-token, per-head gates in (0, 1)
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2)
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * dampen_factor # apply the learned gate to each head's output
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
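+
+# -----------------------------------------------------------------------------
+# Reference sketch of the sparse attention gate above. Illustrative only: this
+# helper is not called anywhere in this script and its name is made up. It
+# spells out the same math as CausalSelfAttention.forward, where w is the
+# zero-initialized dampen weight of shape (num_heads, dim//args.dampen_factor),
+# so every gate starts at sigmoid(0) = 0.5.
+
+def _sparse_attn_gate_reference(x: Tensor, w: Tensor) -> Tensor:
+    x_slice = x[..., :w.size(-1)]        # low-rank slice of the input, (B, T, dim//dampen_factor)
+    gate = torch.sigmoid(x_slice @ w.T)  # one gate per head, (B, T, num_heads)
+    return gate.unsqueeze(-1)            # broadcasts over head_dim when multiplied into y
+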
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, sa_lambdas, block_mask)
+        x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 5) % dist.get_world_size()
+        self.scalars = nn.Parameter(torch.cat([
+            torch.ones(num_layers), # skip_weights
+            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
+            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
+            torch.ones(pad),
+        ]))
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
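+        # note: `lr_mul` (and `wd_mul`) are plain attributes stamped onto tensors;
+        # both DistAdam and Muon read them via getattr(p, "lr_mul", 1.0) /
+        # getattr(p, "wd_mul", 1.0) to scale per-parameter learning rate and weight
+        # decay without needing extra param groups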
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
+        BLOCK_SIZE = 128
+        docs = (input_seq == 50256).cumsum(0)
+        # increments = (input_seq == 50256) | torch.cat([torch.tensor([False], device="cuda"), input_seq[:-1] == 50256])
+        # docs = increments.cumsum(0)
+
+        def document_causal(b, h, q_idx, kv_idx):
+            causal_mask = q_idx >= kv_idx
+            document_mask = docs[q_idx] == docs[kv_idx]
+            return causal_mask & document_mask
+
+        def dense_to_ordered(dense_blockmask: Tensor):
+            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
+            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
+            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
+
+        # manual block mask creation by @YouJiacheng
+        assert len(input_seq) % BLOCK_SIZE == 0
+        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
+        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
+        causal_blockmask_any = block_idx[:, None] >= block_idx
+        causal_blockmask_all = block_idx[:, None] > block_idx
+        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
+        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
+        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
+        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
+        blockmask_any = causal_blockmask_any & document_blockmask_any
+        blockmask_all = causal_blockmask_all & document_blockmask_all
+        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
+        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
+        def build_bm(window_size_blocks: Tensor) -> BlockMask:
+            return BlockMask.from_kv_blocks(
+                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
+                partial_kv_indices,
+                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
+                full_kv_indices,
+                BLOCK_SIZE=BLOCK_SIZE,
+                mask_mod=document_causal,
+            )
+        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
+        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
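+
+    # create_blockmasks returns a (long, short) pair per step, where the short mask
+    # uses half the sliding window. Illustrative numbers: with BLOCK_SIZE = 128 and
+    # sliding_window_num_blocks = 8, long-mask layers attend to at most 8 * 128 = 1024
+    # trailing tokens and short-mask layers to at most 4 * 128 = 512, always clipped
+    # to the current document by document_causal.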
+    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
+        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(block_masks) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+# find world_size starting indices, such that each local batch begins with token 50256 and local batches don't overlap
+def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int):
+    boundary_mask = tokens[pos : pos + token_window] == 50256
+    boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos
+    start = boundary_positions[0].item()
+    starts = []
+    for i in range(1, len(boundary_positions)):
+        end = boundary_positions[i].item()
+        if end - start >= seq_len:
+            starts.append(start) # append start once end pos is confirmed
+            if len(starts) == dist.get_world_size():
+                return starts, end - pos
+            start = end
+    assert False # increase token_window if necessary
+
+def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool):
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    batch_size = seq_len * world_size
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+    while True:
+        token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length
seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + + + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x < 1 + if x < 1 - args.cooldown_frac: + return 1.0 + else: + w = (1 - x) / args.cooldown_frac + return w * 1.0 + (1 - w) * 0.1 + +# attention window size schedule: linearly increase +@lru_cache(1) +def get_window_size_blocks_helper(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) +def get_window_size_blocks(step: int): + x = step / args.num_iterations # progress in training + assert 0 <= x <= 1 + # Linearly increase the block-wise sliding window size over training 128 -> 1792 + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * x, n=128) + return get_window_size_blocks_helper(window_size) + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 10 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +for _ in range(warmup_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(1)).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_batch_size = world_size * args.val_seq_len + assert args.val_tokens % val_batch_size == 0 + val_steps = args.val_tokens // val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + 
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250713+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sat Aug 23 13:32:07 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 310315 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 310316 C /usr/bin/python3 614MiB | +| 0 N/A N/A 310317 C /usr/bin/python3 614MiB | +| 0 N/A N/A 310318 C /usr/bin/python3 614MiB | +| 0 N/A N/A 310319 C /usr/bin/python3 614MiB | +| 0 N/A N/A 310320 C /usr/bin/python3 614MiB | +| 0 N/A N/A 310321 C /usr/bin/python3 614MiB | +| 0 N/A N/A 310322 C /usr/bin/python3 614MiB | +| 1 N/A N/A 310316 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 310317 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 310318 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 310319 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 310320 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 310321 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 310322 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:152ms step_avg:152.47ms +step:2/1695 train_time:178ms step_avg:88.89ms +step:3/1695 train_time:255ms step_avg:84.86ms +step:4/1695 train_time:346ms step_avg:86.58ms +step:5/1695 train_time:439ms step_avg:87.77ms +step:6/1695 train_time:532ms step_avg:88.62ms +step:7/1695 train_time:625ms step_avg:89.24ms +step:8/1695 
train_time:719ms step_avg:89.87ms +step:9/1695 train_time:812ms step_avg:90.20ms +step:10/1695 train_time:905ms step_avg:90.49ms +step:11/1695 train_time:998ms step_avg:90.74ms +step:12/1695 train_time:1092ms step_avg:90.98ms +step:13/1695 train_time:1187ms step_avg:91.28ms +step:14/1695 train_time:1281ms step_avg:91.50ms +step:15/1695 train_time:1375ms step_avg:91.69ms +step:16/1695 train_time:1470ms step_avg:91.86ms +step:17/1695 train_time:1564ms step_avg:91.99ms +step:18/1695 train_time:1658ms step_avg:92.08ms +step:19/1695 train_time:1751ms step_avg:92.16ms +step:20/1695 train_time:1844ms step_avg:92.19ms +step:21/1695 train_time:1938ms step_avg:92.27ms +step:22/1695 train_time:2031ms step_avg:92.34ms +step:23/1695 train_time:2126ms step_avg:92.42ms +step:24/1695 train_time:2220ms step_avg:92.49ms +step:25/1695 train_time:2314ms step_avg:92.56ms +step:26/1695 train_time:2407ms step_avg:92.60ms +step:27/1695 train_time:2502ms step_avg:92.66ms +step:28/1695 train_time:2597ms step_avg:92.74ms +step:29/1695 train_time:2691ms step_avg:92.80ms +step:30/1695 train_time:2784ms step_avg:92.80ms +step:31/1695 train_time:2878ms step_avg:92.82ms +step:32/1695 train_time:2971ms step_avg:92.86ms +step:33/1695 train_time:3065ms step_avg:92.88ms +step:34/1695 train_time:3159ms step_avg:92.92ms +step:35/1695 train_time:3253ms step_avg:92.94ms +step:36/1695 train_time:3346ms step_avg:92.95ms +step:37/1695 train_time:3441ms step_avg:92.99ms +step:38/1695 train_time:3535ms step_avg:93.04ms +step:39/1695 train_time:3630ms step_avg:93.07ms +step:40/1695 train_time:3724ms step_avg:93.09ms +step:41/1695 train_time:3818ms step_avg:93.11ms +step:42/1695 train_time:3911ms step_avg:93.13ms +step:43/1695 train_time:4005ms step_avg:93.13ms +step:44/1695 train_time:4098ms step_avg:93.14ms +step:45/1695 train_time:4192ms step_avg:93.15ms +step:46/1695 train_time:4285ms step_avg:93.16ms +step:47/1695 train_time:4379ms step_avg:93.17ms +step:48/1695 train_time:4473ms step_avg:93.19ms +step:49/1695 train_time:4567ms step_avg:93.20ms +step:50/1695 train_time:4660ms step_avg:93.21ms +step:51/1695 train_time:4754ms step_avg:93.21ms +step:52/1695 train_time:4848ms step_avg:93.23ms +step:53/1695 train_time:4942ms step_avg:93.24ms +step:54/1695 train_time:5036ms step_avg:93.27ms +step:55/1695 train_time:5131ms step_avg:93.28ms +step:56/1695 train_time:5224ms step_avg:93.28ms +step:57/1695 train_time:5318ms step_avg:93.29ms +step:58/1695 train_time:5412ms step_avg:93.31ms +step:59/1695 train_time:5505ms step_avg:93.31ms +step:60/1695 train_time:5599ms step_avg:93.32ms +step:61/1695 train_time:5693ms step_avg:93.32ms +step:62/1695 train_time:5786ms step_avg:93.32ms +step:63/1695 train_time:5879ms step_avg:93.32ms +step:64/1695 train_time:5973ms step_avg:93.33ms +step:65/1695 train_time:6067ms step_avg:93.34ms +step:66/1695 train_time:6160ms step_avg:93.34ms +step:67/1695 train_time:6254ms step_avg:93.35ms +step:68/1695 train_time:6349ms step_avg:93.36ms +step:69/1695 train_time:6442ms step_avg:93.37ms +step:70/1695 train_time:6536ms step_avg:93.37ms +step:71/1695 train_time:6630ms step_avg:93.39ms +step:72/1695 train_time:6724ms step_avg:93.39ms +step:73/1695 train_time:6818ms step_avg:93.39ms +step:74/1695 train_time:6912ms step_avg:93.41ms +step:75/1695 train_time:7005ms step_avg:93.41ms +step:76/1695 train_time:7099ms step_avg:93.41ms +step:77/1695 train_time:7193ms step_avg:93.41ms +step:78/1695 train_time:7286ms step_avg:93.41ms +step:79/1695 train_time:7379ms step_avg:93.41ms +step:80/1695 train_time:7473ms 
step_avg:93.41ms +step:81/1695 train_time:7566ms step_avg:93.41ms +step:82/1695 train_time:7660ms step_avg:93.41ms +step:83/1695 train_time:7753ms step_avg:93.42ms +step:84/1695 train_time:7847ms step_avg:93.41ms +step:85/1695 train_time:7940ms step_avg:93.42ms +step:86/1695 train_time:8035ms step_avg:93.42ms +step:87/1695 train_time:8128ms step_avg:93.42ms +step:88/1695 train_time:8221ms step_avg:93.43ms +step:89/1695 train_time:8316ms step_avg:93.44ms +step:90/1695 train_time:8410ms step_avg:93.45ms +step:91/1695 train_time:8504ms step_avg:93.45ms +step:92/1695 train_time:8598ms step_avg:93.46ms +step:93/1695 train_time:8692ms step_avg:93.46ms +step:94/1695 train_time:8785ms step_avg:93.45ms +step:95/1695 train_time:8879ms step_avg:93.47ms +step:96/1695 train_time:8973ms step_avg:93.47ms +step:97/1695 train_time:9067ms step_avg:93.47ms +step:98/1695 train_time:9160ms step_avg:93.47ms +step:99/1695 train_time:9254ms step_avg:93.47ms +step:100/1695 train_time:9347ms step_avg:93.47ms +step:101/1695 train_time:9441ms step_avg:93.48ms +step:102/1695 train_time:9534ms step_avg:93.47ms +step:103/1695 train_time:9627ms step_avg:93.47ms +step:104/1695 train_time:9721ms step_avg:93.47ms +step:105/1695 train_time:9815ms step_avg:93.48ms +step:106/1695 train_time:9909ms step_avg:93.48ms +step:107/1695 train_time:10002ms step_avg:93.48ms +step:108/1695 train_time:10097ms step_avg:93.49ms +step:109/1695 train_time:10191ms step_avg:93.50ms +step:110/1695 train_time:10284ms step_avg:93.49ms +step:111/1695 train_time:10378ms step_avg:93.50ms +step:112/1695 train_time:10472ms step_avg:93.50ms +step:113/1695 train_time:10566ms step_avg:93.51ms +step:114/1695 train_time:10659ms step_avg:93.50ms +step:115/1695 train_time:10753ms step_avg:93.51ms +step:116/1695 train_time:10847ms step_avg:93.51ms +step:117/1695 train_time:10941ms step_avg:93.51ms +step:118/1695 train_time:11035ms step_avg:93.52ms +step:119/1695 train_time:11130ms step_avg:93.53ms +step:120/1695 train_time:11223ms step_avg:93.53ms +step:121/1695 train_time:11316ms step_avg:93.52ms +step:122/1695 train_time:11411ms step_avg:93.53ms +step:123/1695 train_time:11503ms step_avg:93.52ms +step:124/1695 train_time:11598ms step_avg:93.53ms +step:125/1695 train_time:11691ms step_avg:93.53ms +step:125/1695 val_loss:4.6000 train_time:11783ms step_avg:94.26ms +step:126/1695 train_time:11809ms step_avg:93.72ms +step:127/1695 train_time:11885ms step_avg:93.59ms +step:128/1695 train_time:11987ms step_avg:93.65ms +step:129/1695 train_time:12084ms step_avg:93.67ms +step:130/1695 train_time:12177ms step_avg:93.67ms +step:131/1695 train_time:12270ms step_avg:93.67ms +step:132/1695 train_time:12364ms step_avg:93.67ms +step:133/1695 train_time:12458ms step_avg:93.67ms +step:134/1695 train_time:12551ms step_avg:93.67ms +step:135/1695 train_time:12645ms step_avg:93.67ms +step:136/1695 train_time:12738ms step_avg:93.66ms +step:137/1695 train_time:12834ms step_avg:93.68ms +step:138/1695 train_time:12931ms step_avg:93.70ms +step:139/1695 train_time:13027ms step_avg:93.72ms +step:140/1695 train_time:13122ms step_avg:93.73ms +step:141/1695 train_time:13215ms step_avg:93.72ms +step:142/1695 train_time:13309ms step_avg:93.72ms +step:143/1695 train_time:13402ms step_avg:93.72ms +step:144/1695 train_time:13495ms step_avg:93.72ms +step:145/1695 train_time:13589ms step_avg:93.72ms +step:146/1695 train_time:13682ms step_avg:93.71ms +step:147/1695 train_time:13775ms step_avg:93.71ms +step:148/1695 train_time:13870ms step_avg:93.72ms +step:149/1695 train_time:13966ms 
step_avg:93.73ms +step:150/1695 train_time:14060ms step_avg:93.73ms +step:151/1695 train_time:14155ms step_avg:93.74ms +step:152/1695 train_time:14250ms step_avg:93.75ms +step:153/1695 train_time:14344ms step_avg:93.75ms +step:154/1695 train_time:14437ms step_avg:93.75ms +step:155/1695 train_time:14531ms step_avg:93.75ms +step:156/1695 train_time:14626ms step_avg:93.75ms +step:157/1695 train_time:14719ms step_avg:93.75ms +step:158/1695 train_time:14813ms step_avg:93.75ms +step:159/1695 train_time:14907ms step_avg:93.75ms +step:160/1695 train_time:15001ms step_avg:93.76ms +step:161/1695 train_time:15096ms step_avg:93.76ms +step:162/1695 train_time:15191ms step_avg:93.77ms +step:163/1695 train_time:15285ms step_avg:93.77ms +step:164/1695 train_time:15379ms step_avg:93.77ms +step:165/1695 train_time:15472ms step_avg:93.77ms +step:166/1695 train_time:15567ms step_avg:93.78ms +step:167/1695 train_time:15661ms step_avg:93.78ms +step:168/1695 train_time:15754ms step_avg:93.78ms +step:169/1695 train_time:15850ms step_avg:93.78ms +step:170/1695 train_time:15943ms step_avg:93.78ms +step:171/1695 train_time:16037ms step_avg:93.78ms +step:172/1695 train_time:16133ms step_avg:93.80ms +step:173/1695 train_time:16228ms step_avg:93.80ms +step:174/1695 train_time:16322ms step_avg:93.80ms +step:175/1695 train_time:16415ms step_avg:93.80ms +step:176/1695 train_time:16510ms step_avg:93.80ms +step:177/1695 train_time:16604ms step_avg:93.81ms +step:178/1695 train_time:16697ms step_avg:93.81ms +step:179/1695 train_time:16792ms step_avg:93.81ms +step:180/1695 train_time:16886ms step_avg:93.81ms +step:181/1695 train_time:16979ms step_avg:93.81ms +step:182/1695 train_time:17074ms step_avg:93.81ms +step:183/1695 train_time:17168ms step_avg:93.81ms +step:184/1695 train_time:17262ms step_avg:93.81ms +step:185/1695 train_time:17357ms step_avg:93.82ms +step:186/1695 train_time:17451ms step_avg:93.83ms +step:187/1695 train_time:17545ms step_avg:93.83ms +step:188/1695 train_time:17639ms step_avg:93.83ms +step:189/1695 train_time:17733ms step_avg:93.83ms +step:190/1695 train_time:17828ms step_avg:93.83ms +step:191/1695 train_time:17922ms step_avg:93.83ms +step:192/1695 train_time:18015ms step_avg:93.83ms +step:193/1695 train_time:18109ms step_avg:93.83ms +step:194/1695 train_time:18203ms step_avg:93.83ms +step:195/1695 train_time:18297ms step_avg:93.83ms +step:196/1695 train_time:18392ms step_avg:93.84ms +step:197/1695 train_time:18486ms step_avg:93.84ms +step:198/1695 train_time:18580ms step_avg:93.84ms +step:199/1695 train_time:18674ms step_avg:93.84ms +step:200/1695 train_time:18768ms step_avg:93.84ms +step:201/1695 train_time:18862ms step_avg:93.84ms +step:202/1695 train_time:18957ms step_avg:93.85ms +step:203/1695 train_time:19051ms step_avg:93.85ms +step:204/1695 train_time:19146ms step_avg:93.85ms +step:205/1695 train_time:19239ms step_avg:93.85ms +step:206/1695 train_time:19334ms step_avg:93.85ms +step:207/1695 train_time:19427ms step_avg:93.85ms +step:208/1695 train_time:19521ms step_avg:93.85ms +step:209/1695 train_time:19615ms step_avg:93.85ms +step:210/1695 train_time:19710ms step_avg:93.86ms +step:211/1695 train_time:19805ms step_avg:93.86ms +step:212/1695 train_time:19899ms step_avg:93.86ms +step:213/1695 train_time:19993ms step_avg:93.86ms +step:214/1695 train_time:20087ms step_avg:93.86ms +step:215/1695 train_time:20181ms step_avg:93.86ms +step:216/1695 train_time:20274ms step_avg:93.86ms +step:217/1695 train_time:20368ms step_avg:93.86ms +step:218/1695 train_time:20462ms step_avg:93.86ms +step:219/1695 
train_time:20555ms step_avg:93.86ms +step:220/1695 train_time:20650ms step_avg:93.86ms +step:221/1695 train_time:20744ms step_avg:93.86ms +step:222/1695 train_time:20838ms step_avg:93.87ms +step:223/1695 train_time:20933ms step_avg:93.87ms +step:224/1695 train_time:21027ms step_avg:93.87ms +step:225/1695 train_time:21122ms step_avg:93.87ms +step:226/1695 train_time:21216ms step_avg:93.88ms +step:227/1695 train_time:21310ms step_avg:93.88ms +step:228/1695 train_time:21405ms step_avg:93.88ms +step:229/1695 train_time:21498ms step_avg:93.88ms +step:230/1695 train_time:21592ms step_avg:93.88ms +step:231/1695 train_time:21686ms step_avg:93.88ms +step:232/1695 train_time:21780ms step_avg:93.88ms +step:233/1695 train_time:21874ms step_avg:93.88ms +step:234/1695 train_time:21968ms step_avg:93.88ms +step:235/1695 train_time:22062ms step_avg:93.88ms +step:236/1695 train_time:22156ms step_avg:93.88ms +step:237/1695 train_time:22250ms step_avg:93.88ms +step:238/1695 train_time:22345ms step_avg:93.89ms +step:239/1695 train_time:22438ms step_avg:93.88ms +step:240/1695 train_time:22533ms step_avg:93.89ms +step:241/1695 train_time:22627ms step_avg:93.89ms +step:242/1695 train_time:22720ms step_avg:93.89ms +step:243/1695 train_time:22814ms step_avg:93.88ms +step:244/1695 train_time:22908ms step_avg:93.88ms +step:245/1695 train_time:23002ms step_avg:93.88ms +step:246/1695 train_time:23096ms step_avg:93.89ms +step:247/1695 train_time:23191ms step_avg:93.89ms +step:248/1695 train_time:23286ms step_avg:93.89ms +step:249/1695 train_time:23380ms step_avg:93.90ms +step:250/1695 train_time:23473ms step_avg:93.89ms +step:250/1695 val_loss:4.0715 train_time:23566ms step_avg:94.26ms +step:251/1695 train_time:23592ms step_avg:93.99ms +step:252/1695 train_time:23670ms step_avg:93.93ms +step:253/1695 train_time:23772ms step_avg:93.96ms +step:254/1695 train_time:23867ms step_avg:93.97ms +step:255/1695 train_time:23961ms step_avg:93.97ms +step:256/1695 train_time:24055ms step_avg:93.96ms +step:257/1695 train_time:24149ms step_avg:93.96ms +step:258/1695 train_time:24242ms step_avg:93.96ms +step:259/1695 train_time:24335ms step_avg:93.96ms +step:260/1695 train_time:24429ms step_avg:93.96ms +step:261/1695 train_time:24522ms step_avg:93.96ms +step:262/1695 train_time:24617ms step_avg:93.96ms +step:263/1695 train_time:24714ms step_avg:93.97ms +step:264/1695 train_time:24811ms step_avg:93.98ms +step:265/1695 train_time:24907ms step_avg:93.99ms +step:266/1695 train_time:25001ms step_avg:93.99ms +step:267/1695 train_time:25096ms step_avg:93.99ms +step:268/1695 train_time:25190ms step_avg:93.99ms +step:269/1695 train_time:25283ms step_avg:93.99ms +step:270/1695 train_time:25377ms step_avg:93.99ms +step:271/1695 train_time:25471ms step_avg:93.99ms +step:272/1695 train_time:25565ms step_avg:93.99ms +step:273/1695 train_time:25659ms step_avg:93.99ms +step:274/1695 train_time:25754ms step_avg:93.99ms +step:275/1695 train_time:25851ms step_avg:94.00ms +step:276/1695 train_time:25947ms step_avg:94.01ms +step:277/1695 train_time:26041ms step_avg:94.01ms +step:278/1695 train_time:26135ms step_avg:94.01ms +step:279/1695 train_time:26230ms step_avg:94.01ms +step:280/1695 train_time:26324ms step_avg:94.01ms +step:281/1695 train_time:26417ms step_avg:94.01ms +step:282/1695 train_time:26511ms step_avg:94.01ms +step:283/1695 train_time:26606ms step_avg:94.01ms +step:284/1695 train_time:26701ms step_avg:94.02ms +step:285/1695 train_time:26796ms step_avg:94.02ms +step:286/1695 train_time:26891ms step_avg:94.02ms +step:287/1695 train_time:26987ms 
step_avg:94.03ms +step:288/1695 train_time:27081ms step_avg:94.03ms +step:289/1695 train_time:27175ms step_avg:94.03ms +step:290/1695 train_time:27270ms step_avg:94.04ms +step:291/1695 train_time:27364ms step_avg:94.03ms +step:292/1695 train_time:27458ms step_avg:94.03ms +step:293/1695 train_time:27552ms step_avg:94.03ms +step:294/1695 train_time:27648ms step_avg:94.04ms +step:295/1695 train_time:27741ms step_avg:94.04ms +step:296/1695 train_time:27836ms step_avg:94.04ms +step:297/1695 train_time:27932ms step_avg:94.05ms +step:298/1695 train_time:28026ms step_avg:94.05ms +step:299/1695 train_time:28121ms step_avg:94.05ms +step:300/1695 train_time:28215ms step_avg:94.05ms +step:301/1695 train_time:28310ms step_avg:94.05ms +step:302/1695 train_time:28405ms step_avg:94.05ms +step:303/1695 train_time:28499ms step_avg:94.05ms +step:304/1695 train_time:28593ms step_avg:94.06ms +step:305/1695 train_time:28688ms step_avg:94.06ms +step:306/1695 train_time:28782ms step_avg:94.06ms +step:307/1695 train_time:28876ms step_avg:94.06ms +step:308/1695 train_time:28972ms step_avg:94.06ms +step:309/1695 train_time:29067ms step_avg:94.07ms +step:310/1695 train_time:29162ms step_avg:94.07ms +step:311/1695 train_time:29255ms step_avg:94.07ms +step:312/1695 train_time:29350ms step_avg:94.07ms +step:313/1695 train_time:29445ms step_avg:94.07ms +step:314/1695 train_time:29539ms step_avg:94.07ms +step:315/1695 train_time:29634ms step_avg:94.08ms +step:316/1695 train_time:29729ms step_avg:94.08ms +step:317/1695 train_time:29823ms step_avg:94.08ms +step:318/1695 train_time:29917ms step_avg:94.08ms +step:319/1695 train_time:30013ms step_avg:94.08ms +step:320/1695 train_time:30109ms step_avg:94.09ms +step:321/1695 train_time:30203ms step_avg:94.09ms +step:322/1695 train_time:30297ms step_avg:94.09ms +step:323/1695 train_time:30391ms step_avg:94.09ms +step:324/1695 train_time:30486ms step_avg:94.09ms +step:325/1695 train_time:30581ms step_avg:94.10ms +step:326/1695 train_time:30675ms step_avg:94.10ms +step:327/1695 train_time:30770ms step_avg:94.10ms +step:328/1695 train_time:30865ms step_avg:94.10ms +step:329/1695 train_time:30960ms step_avg:94.10ms +step:330/1695 train_time:31055ms step_avg:94.11ms +step:331/1695 train_time:31150ms step_avg:94.11ms +step:332/1695 train_time:31244ms step_avg:94.11ms +step:333/1695 train_time:31338ms step_avg:94.11ms +step:334/1695 train_time:31433ms step_avg:94.11ms +step:335/1695 train_time:31528ms step_avg:94.11ms +step:336/1695 train_time:31622ms step_avg:94.11ms +step:337/1695 train_time:31716ms step_avg:94.11ms +step:338/1695 train_time:31811ms step_avg:94.12ms +step:339/1695 train_time:31906ms step_avg:94.12ms +step:340/1695 train_time:31999ms step_avg:94.12ms +step:341/1695 train_time:32094ms step_avg:94.12ms +step:342/1695 train_time:32189ms step_avg:94.12ms +step:343/1695 train_time:32283ms step_avg:94.12ms +step:344/1695 train_time:32378ms step_avg:94.12ms +step:345/1695 train_time:32472ms step_avg:94.12ms +step:346/1695 train_time:32567ms step_avg:94.13ms +step:347/1695 train_time:32661ms step_avg:94.13ms +step:348/1695 train_time:32755ms step_avg:94.12ms +step:349/1695 train_time:32850ms step_avg:94.13ms +step:350/1695 train_time:32945ms step_avg:94.13ms +step:351/1695 train_time:33039ms step_avg:94.13ms +step:352/1695 train_time:33134ms step_avg:94.13ms +step:353/1695 train_time:33229ms step_avg:94.13ms +step:354/1695 train_time:33323ms step_avg:94.13ms +step:355/1695 train_time:33417ms step_avg:94.13ms +step:356/1695 train_time:33513ms step_avg:94.14ms +step:357/1695 
train_time:33608ms step_avg:94.14ms +step:358/1695 train_time:33701ms step_avg:94.14ms +step:359/1695 train_time:33795ms step_avg:94.14ms +step:360/1695 train_time:33890ms step_avg:94.14ms +step:361/1695 train_time:33985ms step_avg:94.14ms +step:362/1695 train_time:34079ms step_avg:94.14ms +step:363/1695 train_time:34173ms step_avg:94.14ms +step:364/1695 train_time:34268ms step_avg:94.14ms +step:365/1695 train_time:34362ms step_avg:94.14ms +step:366/1695 train_time:34457ms step_avg:94.14ms +step:367/1695 train_time:34552ms step_avg:94.15ms +step:368/1695 train_time:34648ms step_avg:94.15ms +step:369/1695 train_time:34741ms step_avg:94.15ms +step:370/1695 train_time:34835ms step_avg:94.15ms +step:371/1695 train_time:34931ms step_avg:94.15ms +step:372/1695 train_time:35025ms step_avg:94.15ms +step:373/1695 train_time:35119ms step_avg:94.15ms +step:374/1695 train_time:35213ms step_avg:94.15ms +step:375/1695 train_time:35308ms step_avg:94.15ms +step:375/1695 val_loss:3.8701 train_time:35400ms step_avg:94.40ms +step:376/1695 train_time:35426ms step_avg:94.22ms +step:377/1695 train_time:35504ms step_avg:94.18ms +step:378/1695 train_time:35605ms step_avg:94.19ms +step:379/1695 train_time:35702ms step_avg:94.20ms +step:380/1695 train_time:35798ms step_avg:94.20ms +step:381/1695 train_time:35893ms step_avg:94.21ms +step:382/1695 train_time:35988ms step_avg:94.21ms +step:383/1695 train_time:36084ms step_avg:94.21ms +step:384/1695 train_time:36179ms step_avg:94.22ms +step:385/1695 train_time:36274ms step_avg:94.22ms +step:386/1695 train_time:36370ms step_avg:94.22ms +step:387/1695 train_time:36467ms step_avg:94.23ms +step:388/1695 train_time:36566ms step_avg:94.24ms +step:389/1695 train_time:36665ms step_avg:94.25ms +step:390/1695 train_time:36762ms step_avg:94.26ms +step:391/1695 train_time:36859ms step_avg:94.27ms +step:392/1695 train_time:36955ms step_avg:94.27ms +step:393/1695 train_time:37051ms step_avg:94.28ms +step:394/1695 train_time:37146ms step_avg:94.28ms +step:395/1695 train_time:37242ms step_avg:94.28ms +step:396/1695 train_time:37338ms step_avg:94.29ms +step:397/1695 train_time:37434ms step_avg:94.29ms +step:398/1695 train_time:37530ms step_avg:94.30ms +step:399/1695 train_time:37626ms step_avg:94.30ms +step:400/1695 train_time:37724ms step_avg:94.31ms +step:401/1695 train_time:37821ms step_avg:94.32ms +step:402/1695 train_time:37918ms step_avg:94.32ms +step:403/1695 train_time:38013ms step_avg:94.33ms +step:404/1695 train_time:38109ms step_avg:94.33ms +step:405/1695 train_time:38205ms step_avg:94.33ms +step:406/1695 train_time:38301ms step_avg:94.34ms +step:407/1695 train_time:38397ms step_avg:94.34ms +step:408/1695 train_time:38493ms step_avg:94.35ms +step:409/1695 train_time:38589ms step_avg:94.35ms +step:410/1695 train_time:38685ms step_avg:94.35ms +step:411/1695 train_time:38782ms step_avg:94.36ms +step:412/1695 train_time:38880ms step_avg:94.37ms +step:413/1695 train_time:38976ms step_avg:94.37ms +step:414/1695 train_time:39072ms step_avg:94.38ms +step:415/1695 train_time:39168ms step_avg:94.38ms +step:416/1695 train_time:39264ms step_avg:94.38ms +step:417/1695 train_time:39361ms step_avg:94.39ms +step:418/1695 train_time:39457ms step_avg:94.39ms +step:419/1695 train_time:39553ms step_avg:94.40ms +step:420/1695 train_time:39649ms step_avg:94.40ms +step:421/1695 train_time:39744ms step_avg:94.40ms +step:422/1695 train_time:39842ms step_avg:94.41ms +step:423/1695 train_time:39939ms step_avg:94.42ms +step:424/1695 train_time:40034ms step_avg:94.42ms +step:425/1695 train_time:40130ms 
step_avg:94.42ms +step:426/1695 train_time:40226ms step_avg:94.43ms +step:427/1695 train_time:40322ms step_avg:94.43ms +step:428/1695 train_time:40419ms step_avg:94.44ms +step:429/1695 train_time:40516ms step_avg:94.44ms +step:430/1695 train_time:40611ms step_avg:94.44ms +step:431/1695 train_time:40707ms step_avg:94.45ms +step:432/1695 train_time:40803ms step_avg:94.45ms +step:433/1695 train_time:40900ms step_avg:94.46ms +step:434/1695 train_time:40997ms step_avg:94.46ms +step:435/1695 train_time:41093ms step_avg:94.47ms +step:436/1695 train_time:41188ms step_avg:94.47ms +step:437/1695 train_time:41285ms step_avg:94.47ms +step:438/1695 train_time:41382ms step_avg:94.48ms +step:439/1695 train_time:41478ms step_avg:94.48ms +step:440/1695 train_time:41574ms step_avg:94.49ms +step:441/1695 train_time:41669ms step_avg:94.49ms +step:442/1695 train_time:41766ms step_avg:94.49ms +step:443/1695 train_time:41863ms step_avg:94.50ms +step:444/1695 train_time:41960ms step_avg:94.50ms +step:445/1695 train_time:42057ms step_avg:94.51ms +step:446/1695 train_time:42153ms step_avg:94.51ms +step:447/1695 train_time:42249ms step_avg:94.52ms +step:448/1695 train_time:42345ms step_avg:94.52ms +step:449/1695 train_time:42441ms step_avg:94.52ms +step:450/1695 train_time:42538ms step_avg:94.53ms +step:451/1695 train_time:42634ms step_avg:94.53ms +step:452/1695 train_time:42729ms step_avg:94.53ms +step:453/1695 train_time:42826ms step_avg:94.54ms +step:454/1695 train_time:42922ms step_avg:94.54ms +step:455/1695 train_time:43019ms step_avg:94.55ms +step:456/1695 train_time:43115ms step_avg:94.55ms +step:457/1695 train_time:43211ms step_avg:94.55ms +step:458/1695 train_time:43307ms step_avg:94.56ms +step:459/1695 train_time:43404ms step_avg:94.56ms +step:460/1695 train_time:43500ms step_avg:94.57ms +step:461/1695 train_time:43596ms step_avg:94.57ms +step:462/1695 train_time:43692ms step_avg:94.57ms +step:463/1695 train_time:43788ms step_avg:94.57ms +step:464/1695 train_time:43885ms step_avg:94.58ms +step:465/1695 train_time:43981ms step_avg:94.58ms +step:466/1695 train_time:44078ms step_avg:94.59ms +step:467/1695 train_time:44174ms step_avg:94.59ms +step:468/1695 train_time:44270ms step_avg:94.59ms +step:469/1695 train_time:44366ms step_avg:94.60ms +step:470/1695 train_time:44463ms step_avg:94.60ms +step:471/1695 train_time:44560ms step_avg:94.61ms +step:472/1695 train_time:44656ms step_avg:94.61ms +step:473/1695 train_time:44752ms step_avg:94.61ms +step:474/1695 train_time:44847ms step_avg:94.61ms +step:475/1695 train_time:44943ms step_avg:94.62ms +step:476/1695 train_time:45039ms step_avg:94.62ms +step:477/1695 train_time:45135ms step_avg:94.62ms +step:478/1695 train_time:45230ms step_avg:94.62ms +step:479/1695 train_time:45326ms step_avg:94.63ms +step:480/1695 train_time:45423ms step_avg:94.63ms +step:481/1695 train_time:45520ms step_avg:94.64ms +step:482/1695 train_time:45617ms step_avg:94.64ms +step:483/1695 train_time:45713ms step_avg:94.64ms +step:484/1695 train_time:45809ms step_avg:94.65ms +step:485/1695 train_time:45906ms step_avg:94.65ms +step:486/1695 train_time:46002ms step_avg:94.65ms +step:487/1695 train_time:46099ms step_avg:94.66ms +step:488/1695 train_time:46195ms step_avg:94.66ms +step:489/1695 train_time:46291ms step_avg:94.66ms +step:490/1695 train_time:46387ms step_avg:94.67ms +step:491/1695 train_time:46484ms step_avg:94.67ms +step:492/1695 train_time:46581ms step_avg:94.68ms +step:493/1695 train_time:46678ms step_avg:94.68ms +step:494/1695 train_time:46774ms step_avg:94.68ms +step:495/1695 
train_time:46869ms step_avg:94.69ms +step:496/1695 train_time:46965ms step_avg:94.69ms +step:497/1695 train_time:47063ms step_avg:94.69ms +step:498/1695 train_time:47160ms step_avg:94.70ms +step:499/1695 train_time:47257ms step_avg:94.70ms +step:500/1695 train_time:47353ms step_avg:94.71ms +step:500/1695 val_loss:3.7285 train_time:47446ms step_avg:94.89ms +step:501/1695 train_time:47472ms step_avg:94.76ms +step:502/1695 train_time:47553ms step_avg:94.73ms +step:503/1695 train_time:47653ms step_avg:94.74ms +step:504/1695 train_time:47748ms step_avg:94.74ms +step:505/1695 train_time:47844ms step_avg:94.74ms +step:506/1695 train_time:47941ms step_avg:94.74ms +step:507/1695 train_time:48037ms step_avg:94.75ms +step:508/1695 train_time:48132ms step_avg:94.75ms +step:509/1695 train_time:48227ms step_avg:94.75ms +step:510/1695 train_time:48324ms step_avg:94.75ms +step:511/1695 train_time:48420ms step_avg:94.76ms +step:512/1695 train_time:48518ms step_avg:94.76ms +step:513/1695 train_time:48616ms step_avg:94.77ms +step:514/1695 train_time:48714ms step_avg:94.78ms +step:515/1695 train_time:48810ms step_avg:94.78ms +step:516/1695 train_time:48906ms step_avg:94.78ms +step:517/1695 train_time:49003ms step_avg:94.78ms +step:518/1695 train_time:49099ms step_avg:94.79ms +step:519/1695 train_time:49195ms step_avg:94.79ms +step:520/1695 train_time:49290ms step_avg:94.79ms +step:521/1695 train_time:49386ms step_avg:94.79ms +step:522/1695 train_time:49484ms step_avg:94.80ms +step:523/1695 train_time:49581ms step_avg:94.80ms +step:524/1695 train_time:49678ms step_avg:94.81ms +step:525/1695 train_time:49776ms step_avg:94.81ms +step:526/1695 train_time:49871ms step_avg:94.81ms +step:527/1695 train_time:49968ms step_avg:94.82ms +step:528/1695 train_time:50064ms step_avg:94.82ms +step:529/1695 train_time:50160ms step_avg:94.82ms +step:530/1695 train_time:50256ms step_avg:94.82ms +step:531/1695 train_time:50352ms step_avg:94.82ms +step:532/1695 train_time:50447ms step_avg:94.83ms +step:533/1695 train_time:50544ms step_avg:94.83ms +step:534/1695 train_time:50641ms step_avg:94.83ms +step:535/1695 train_time:50737ms step_avg:94.84ms +step:536/1695 train_time:50834ms step_avg:94.84ms +step:537/1695 train_time:50930ms step_avg:94.84ms +step:538/1695 train_time:51026ms step_avg:94.84ms +step:539/1695 train_time:51123ms step_avg:94.85ms +step:540/1695 train_time:51220ms step_avg:94.85ms +step:541/1695 train_time:51316ms step_avg:94.85ms +step:542/1695 train_time:51412ms step_avg:94.86ms +step:543/1695 train_time:51507ms step_avg:94.86ms +step:544/1695 train_time:51604ms step_avg:94.86ms +step:545/1695 train_time:51701ms step_avg:94.86ms +step:546/1695 train_time:51798ms step_avg:94.87ms +step:547/1695 train_time:51894ms step_avg:94.87ms +step:548/1695 train_time:51990ms step_avg:94.87ms +step:549/1695 train_time:52086ms step_avg:94.87ms +step:550/1695 train_time:52183ms step_avg:94.88ms +step:551/1695 train_time:52281ms step_avg:94.88ms +step:552/1695 train_time:52378ms step_avg:94.89ms +step:553/1695 train_time:52473ms step_avg:94.89ms +step:554/1695 train_time:52569ms step_avg:94.89ms +step:555/1695 train_time:52665ms step_avg:94.89ms +step:556/1695 train_time:52763ms step_avg:94.90ms +step:557/1695 train_time:52860ms step_avg:94.90ms +step:558/1695 train_time:52956ms step_avg:94.90ms +step:559/1695 train_time:53052ms step_avg:94.91ms +step:560/1695 train_time:53148ms step_avg:94.91ms +step:561/1695 train_time:53245ms step_avg:94.91ms +step:562/1695 train_time:53342ms step_avg:94.91ms +step:563/1695 train_time:53439ms 
step_avg:94.92ms +step:564/1695 train_time:53535ms step_avg:94.92ms +step:565/1695 train_time:53632ms step_avg:94.92ms +step:566/1695 train_time:53728ms step_avg:94.93ms +step:567/1695 train_time:53824ms step_avg:94.93ms +step:568/1695 train_time:53921ms step_avg:94.93ms +step:569/1695 train_time:54018ms step_avg:94.93ms +step:570/1695 train_time:54114ms step_avg:94.94ms +step:571/1695 train_time:54210ms step_avg:94.94ms +step:572/1695 train_time:54306ms step_avg:94.94ms +step:573/1695 train_time:54403ms step_avg:94.94ms +step:574/1695 train_time:54500ms step_avg:94.95ms +step:575/1695 train_time:54598ms step_avg:94.95ms +step:576/1695 train_time:54696ms step_avg:94.96ms +step:577/1695 train_time:54792ms step_avg:94.96ms +step:578/1695 train_time:54888ms step_avg:94.96ms +step:579/1695 train_time:54985ms step_avg:94.97ms +step:580/1695 train_time:55082ms step_avg:94.97ms +step:581/1695 train_time:55179ms step_avg:94.97ms +step:582/1695 train_time:55276ms step_avg:94.98ms +step:583/1695 train_time:55372ms step_avg:94.98ms +step:584/1695 train_time:55468ms step_avg:94.98ms +step:585/1695 train_time:55566ms step_avg:94.98ms +step:586/1695 train_time:55663ms step_avg:94.99ms +step:587/1695 train_time:55760ms step_avg:94.99ms +step:588/1695 train_time:55858ms step_avg:95.00ms +step:589/1695 train_time:55953ms step_avg:95.00ms +step:590/1695 train_time:56048ms step_avg:95.00ms +step:591/1695 train_time:56145ms step_avg:95.00ms +step:592/1695 train_time:56242ms step_avg:95.00ms +step:593/1695 train_time:56338ms step_avg:95.01ms +step:594/1695 train_time:56435ms step_avg:95.01ms +step:595/1695 train_time:56530ms step_avg:95.01ms +step:596/1695 train_time:56627ms step_avg:95.01ms +step:597/1695 train_time:56723ms step_avg:95.01ms +step:598/1695 train_time:56821ms step_avg:95.02ms +step:599/1695 train_time:56918ms step_avg:95.02ms +step:600/1695 train_time:57014ms step_avg:95.02ms +step:601/1695 train_time:57110ms step_avg:95.02ms +step:602/1695 train_time:57205ms step_avg:95.03ms +step:603/1695 train_time:57302ms step_avg:95.03ms +step:604/1695 train_time:57399ms step_avg:95.03ms +step:605/1695 train_time:57496ms step_avg:95.03ms +step:606/1695 train_time:57592ms step_avg:95.04ms +step:607/1695 train_time:57688ms step_avg:95.04ms +step:608/1695 train_time:57785ms step_avg:95.04ms +step:609/1695 train_time:57882ms step_avg:95.04ms +step:610/1695 train_time:57978ms step_avg:95.05ms +step:611/1695 train_time:58074ms step_avg:95.05ms +step:612/1695 train_time:58169ms step_avg:95.05ms +step:613/1695 train_time:58266ms step_avg:95.05ms +step:614/1695 train_time:58364ms step_avg:95.06ms +step:615/1695 train_time:58461ms step_avg:95.06ms +step:616/1695 train_time:58558ms step_avg:95.06ms +step:617/1695 train_time:58655ms step_avg:95.06ms +step:618/1695 train_time:58751ms step_avg:95.07ms +step:619/1695 train_time:58847ms step_avg:95.07ms +step:620/1695 train_time:58945ms step_avg:95.07ms +step:621/1695 train_time:59042ms step_avg:95.08ms +step:622/1695 train_time:59139ms step_avg:95.08ms +step:623/1695 train_time:59236ms step_avg:95.08ms +step:624/1695 train_time:59332ms step_avg:95.08ms +step:625/1695 train_time:59429ms step_avg:95.09ms +step:625/1695 val_loss:3.6465 train_time:59524ms step_avg:95.24ms +step:626/1695 train_time:59550ms step_avg:95.13ms +step:627/1695 train_time:59630ms step_avg:95.10ms +step:628/1695 train_time:59729ms step_avg:95.11ms +step:629/1695 train_time:59825ms step_avg:95.11ms +step:630/1695 train_time:59922ms step_avg:95.11ms +step:631/1695 train_time:60019ms step_avg:95.12ms 
+step:632/1695 train_time:60117ms step_avg:95.12ms +step:633/1695 train_time:60214ms step_avg:95.13ms +step:634/1695 train_time:60311ms step_avg:95.13ms +step:635/1695 train_time:60408ms step_avg:95.13ms +step:636/1695 train_time:60505ms step_avg:95.13ms +step:637/1695 train_time:60603ms step_avg:95.14ms +step:638/1695 train_time:60702ms step_avg:95.14ms +step:639/1695 train_time:60800ms step_avg:95.15ms +step:640/1695 train_time:60899ms step_avg:95.15ms +step:641/1695 train_time:60996ms step_avg:95.16ms +step:642/1695 train_time:61094ms step_avg:95.16ms +step:643/1695 train_time:61191ms step_avg:95.16ms +step:644/1695 train_time:61289ms step_avg:95.17ms +step:645/1695 train_time:61386ms step_avg:95.17ms +step:646/1695 train_time:61484ms step_avg:95.18ms +step:647/1695 train_time:61888ms step_avg:95.65ms +step:648/1695 train_time:61984ms step_avg:95.65ms +step:649/1695 train_time:62080ms step_avg:95.65ms +step:650/1695 train_time:62177ms step_avg:95.66ms +step:651/1695 train_time:62274ms step_avg:95.66ms +step:652/1695 train_time:62371ms step_avg:95.66ms +step:653/1695 train_time:62468ms step_avg:95.66ms +step:654/1695 train_time:62565ms step_avg:95.66ms +step:655/1695 train_time:62661ms step_avg:95.67ms +step:656/1695 train_time:62759ms step_avg:95.67ms +step:657/1695 train_time:62862ms step_avg:95.68ms +step:658/1695 train_time:62961ms step_avg:95.69ms +step:659/1695 train_time:63059ms step_avg:95.69ms +step:660/1695 train_time:63157ms step_avg:95.69ms +step:661/1695 train_time:63254ms step_avg:95.70ms +step:662/1695 train_time:63352ms step_avg:95.70ms +step:663/1695 train_time:63449ms step_avg:95.70ms +step:664/1695 train_time:63547ms step_avg:95.70ms +step:665/1695 train_time:63643ms step_avg:95.70ms +step:666/1695 train_time:63740ms step_avg:95.71ms +step:667/1695 train_time:63839ms step_avg:95.71ms +step:668/1695 train_time:63938ms step_avg:95.72ms +step:669/1695 train_time:64037ms step_avg:95.72ms +step:670/1695 train_time:64135ms step_avg:95.72ms +step:671/1695 train_time:64233ms step_avg:95.73ms +step:672/1695 train_time:64330ms step_avg:95.73ms +step:673/1695 train_time:64429ms step_avg:95.73ms +step:674/1695 train_time:64527ms step_avg:95.74ms +step:675/1695 train_time:64624ms step_avg:95.74ms +step:676/1695 train_time:64722ms step_avg:95.74ms +step:677/1695 train_time:64819ms step_avg:95.74ms +step:678/1695 train_time:64918ms step_avg:95.75ms +step:679/1695 train_time:65016ms step_avg:95.75ms +step:680/1695 train_time:65115ms step_avg:95.76ms +step:681/1695 train_time:65213ms step_avg:95.76ms +step:682/1695 train_time:65310ms step_avg:95.76ms +step:683/1695 train_time:65408ms step_avg:95.77ms +step:684/1695 train_time:65505ms step_avg:95.77ms +step:685/1695 train_time:65602ms step_avg:95.77ms +step:686/1695 train_time:65699ms step_avg:95.77ms +step:687/1695 train_time:65797ms step_avg:95.77ms +step:688/1695 train_time:65896ms step_avg:95.78ms +step:689/1695 train_time:65994ms step_avg:95.78ms +step:690/1695 train_time:66092ms step_avg:95.79ms +step:691/1695 train_time:66190ms step_avg:95.79ms +step:692/1695 train_time:66288ms step_avg:95.79ms +step:693/1695 train_time:66385ms step_avg:95.79ms +step:694/1695 train_time:66483ms step_avg:95.80ms +step:695/1695 train_time:66815ms step_avg:96.14ms +step:696/1695 train_time:66911ms step_avg:96.14ms +step:697/1695 train_time:67008ms step_avg:96.14ms +step:698/1695 train_time:67105ms step_avg:96.14ms +step:699/1695 train_time:67202ms step_avg:96.14ms +step:700/1695 train_time:67299ms step_avg:96.14ms +step:701/1695 train_time:67626ms 
step_avg:96.47ms +step:702/1695 train_time:67721ms step_avg:96.47ms +step:703/1695 train_time:67818ms step_avg:96.47ms +step:704/1695 train_time:67916ms step_avg:96.47ms +step:705/1695 train_time:68013ms step_avg:96.47ms +step:706/1695 train_time:68110ms step_avg:96.47ms +step:707/1695 train_time:68207ms step_avg:96.47ms +step:708/1695 train_time:68304ms step_avg:96.47ms +step:709/1695 train_time:68401ms step_avg:96.48ms +step:710/1695 train_time:68500ms step_avg:96.48ms +step:711/1695 train_time:68601ms step_avg:96.49ms +step:712/1695 train_time:68700ms step_avg:96.49ms +step:713/1695 train_time:68797ms step_avg:96.49ms +step:714/1695 train_time:68895ms step_avg:96.49ms +step:715/1695 train_time:68993ms step_avg:96.49ms +step:716/1695 train_time:69090ms step_avg:96.50ms +step:717/1695 train_time:69188ms step_avg:96.50ms +step:718/1695 train_time:69284ms step_avg:96.50ms +step:719/1695 train_time:69381ms step_avg:96.50ms +step:720/1695 train_time:69478ms step_avg:96.50ms +step:721/1695 train_time:69578ms step_avg:96.50ms +step:722/1695 train_time:69677ms step_avg:96.51ms +step:723/1695 train_time:69776ms step_avg:96.51ms +step:724/1695 train_time:69873ms step_avg:96.51ms +step:725/1695 train_time:69971ms step_avg:96.51ms +step:726/1695 train_time:70069ms step_avg:96.51ms +step:727/1695 train_time:70167ms step_avg:96.52ms +step:728/1695 train_time:70265ms step_avg:96.52ms +step:729/1695 train_time:70362ms step_avg:96.52ms +step:730/1695 train_time:70459ms step_avg:96.52ms +step:731/1695 train_time:70557ms step_avg:96.52ms +step:732/1695 train_time:70656ms step_avg:96.52ms +step:733/1695 train_time:70754ms step_avg:96.53ms +step:734/1695 train_time:70852ms step_avg:96.53ms +step:735/1695 train_time:70950ms step_avg:96.53ms +step:736/1695 train_time:71049ms step_avg:96.53ms +step:737/1695 train_time:71145ms step_avg:96.53ms +step:738/1695 train_time:71243ms step_avg:96.53ms +step:739/1695 train_time:71340ms step_avg:96.54ms +step:740/1695 train_time:71438ms step_avg:96.54ms +step:741/1695 train_time:71537ms step_avg:96.54ms +step:742/1695 train_time:71635ms step_avg:96.54ms +step:743/1695 train_time:71733ms step_avg:96.54ms +step:744/1695 train_time:71831ms step_avg:96.55ms +step:745/1695 train_time:71929ms step_avg:96.55ms +step:746/1695 train_time:72027ms step_avg:96.55ms +step:747/1695 train_time:72125ms step_avg:96.55ms +step:748/1695 train_time:72222ms step_avg:96.55ms +step:749/1695 train_time:72319ms step_avg:96.55ms +step:750/1695 train_time:72417ms step_avg:96.56ms +step:750/1695 val_loss:3.5809 train_time:72513ms step_avg:96.68ms +step:751/1695 train_time:72539ms step_avg:96.59ms +step:752/1695 train_time:72622ms step_avg:96.57ms +step:753/1695 train_time:72723ms step_avg:96.58ms +step:754/1695 train_time:72822ms step_avg:96.58ms +step:755/1695 train_time:72920ms step_avg:96.58ms +step:756/1695 train_time:73018ms step_avg:96.58ms +step:757/1695 train_time:73116ms step_avg:96.59ms +step:758/1695 train_time:73213ms step_avg:96.59ms +step:759/1695 train_time:73311ms step_avg:96.59ms +step:760/1695 train_time:73408ms step_avg:96.59ms +step:761/1695 train_time:73506ms step_avg:96.59ms +step:762/1695 train_time:73605ms step_avg:96.59ms +step:763/1695 train_time:73705ms step_avg:96.60ms +step:764/1695 train_time:73803ms step_avg:96.60ms +step:765/1695 train_time:73902ms step_avg:96.60ms +step:766/1695 train_time:74000ms step_avg:96.61ms +step:767/1695 train_time:74099ms step_avg:96.61ms +step:768/1695 train_time:74196ms step_avg:96.61ms +step:769/1695 train_time:74294ms step_avg:96.61ms 
+step:770/1695 train_time:74392ms step_avg:96.61ms +step:771/1695 train_time:74490ms step_avg:96.62ms +step:772/1695 train_time:74589ms step_avg:96.62ms +step:773/1695 train_time:74687ms step_avg:96.62ms +step:774/1695 train_time:74785ms step_avg:96.62ms +step:775/1695 train_time:74883ms step_avg:96.62ms +step:776/1695 train_time:74982ms step_avg:96.63ms +step:777/1695 train_time:75080ms step_avg:96.63ms +step:778/1695 train_time:75477ms step_avg:97.01ms +step:779/1695 train_time:75573ms step_avg:97.01ms +step:780/1695 train_time:75670ms step_avg:97.01ms +step:781/1695 train_time:75767ms step_avg:97.01ms +step:782/1695 train_time:75865ms step_avg:97.01ms +step:783/1695 train_time:75962ms step_avg:97.01ms +step:784/1695 train_time:76287ms step_avg:97.31ms +step:785/1695 train_time:76384ms step_avg:97.30ms +step:786/1695 train_time:76481ms step_avg:97.30ms +step:787/1695 train_time:76578ms step_avg:97.30ms +step:788/1695 train_time:76675ms step_avg:97.30ms +step:789/1695 train_time:76772ms step_avg:97.30ms +step:790/1695 train_time:77113ms step_avg:97.61ms +step:791/1695 train_time:77209ms step_avg:97.61ms +step:792/1695 train_time:77306ms step_avg:97.61ms +step:793/1695 train_time:77404ms step_avg:97.61ms +step:794/1695 train_time:77501ms step_avg:97.61ms +step:795/1695 train_time:77599ms step_avg:97.61ms +step:796/1695 train_time:77697ms step_avg:97.61ms +step:797/1695 train_time:77794ms step_avg:97.61ms +step:798/1695 train_time:77891ms step_avg:97.61ms +step:799/1695 train_time:77988ms step_avg:97.61ms +step:800/1695 train_time:78090ms step_avg:97.61ms +step:801/1695 train_time:78471ms step_avg:97.97ms +step:802/1695 train_time:78522ms step_avg:97.91ms +step:803/1695 train_time:78618ms step_avg:97.91ms +step:804/1695 train_time:78715ms step_avg:97.90ms +step:805/1695 train_time:78812ms step_avg:97.90ms +step:806/1695 train_time:78909ms step_avg:97.90ms +step:807/1695 train_time:79006ms step_avg:97.90ms +step:808/1695 train_time:79104ms step_avg:97.90ms +step:809/1695 train_time:79201ms step_avg:97.90ms +step:810/1695 train_time:79299ms step_avg:97.90ms +step:811/1695 train_time:79398ms step_avg:97.90ms +step:812/1695 train_time:79498ms step_avg:97.90ms +step:813/1695 train_time:79598ms step_avg:97.91ms +step:814/1695 train_time:79696ms step_avg:97.91ms +step:815/1695 train_time:80037ms step_avg:98.20ms +step:816/1695 train_time:80133ms step_avg:98.20ms +step:817/1695 train_time:80229ms step_avg:98.20ms +step:818/1695 train_time:80327ms step_avg:98.20ms +step:819/1695 train_time:80424ms step_avg:98.20ms +step:820/1695 train_time:80522ms step_avg:98.20ms +step:821/1695 train_time:80620ms step_avg:98.20ms +step:822/1695 train_time:80717ms step_avg:98.20ms +step:823/1695 train_time:80815ms step_avg:98.20ms +step:824/1695 train_time:80915ms step_avg:98.20ms +step:825/1695 train_time:81016ms step_avg:98.20ms +step:826/1695 train_time:81116ms step_avg:98.20ms +step:827/1695 train_time:81215ms step_avg:98.20ms +step:828/1695 train_time:81313ms step_avg:98.20ms +step:829/1695 train_time:81410ms step_avg:98.20ms +step:830/1695 train_time:81508ms step_avg:98.20ms +step:831/1695 train_time:81606ms step_avg:98.20ms +step:832/1695 train_time:81704ms step_avg:98.20ms +step:833/1695 train_time:81802ms step_avg:98.20ms +step:834/1695 train_time:81900ms step_avg:98.20ms +step:835/1695 train_time:81999ms step_avg:98.20ms +step:836/1695 train_time:82098ms step_avg:98.20ms +step:837/1695 train_time:82196ms step_avg:98.20ms +step:838/1695 train_time:82295ms step_avg:98.20ms +step:839/1695 train_time:82394ms 
step_avg:98.21ms +step:840/1695 train_time:82492ms step_avg:98.20ms +step:841/1695 train_time:82590ms step_avg:98.20ms +step:842/1695 train_time:82687ms step_avg:98.20ms +step:843/1695 train_time:82785ms step_avg:98.20ms +step:844/1695 train_time:82882ms step_avg:98.20ms +step:845/1695 train_time:82980ms step_avg:98.20ms +step:846/1695 train_time:83079ms step_avg:98.20ms +step:847/1695 train_time:83178ms step_avg:98.20ms +step:848/1695 train_time:83277ms step_avg:98.20ms +step:849/1695 train_time:83376ms step_avg:98.20ms +step:850/1695 train_time:83474ms step_avg:98.20ms +step:851/1695 train_time:83572ms step_avg:98.21ms +step:852/1695 train_time:83671ms step_avg:98.21ms +step:853/1695 train_time:83769ms step_avg:98.21ms +step:854/1695 train_time:83868ms step_avg:98.21ms +step:855/1695 train_time:83966ms step_avg:98.21ms +step:856/1695 train_time:84063ms step_avg:98.20ms +step:857/1695 train_time:84161ms step_avg:98.20ms +step:858/1695 train_time:84260ms step_avg:98.21ms +step:859/1695 train_time:84360ms step_avg:98.21ms +step:860/1695 train_time:84459ms step_avg:98.21ms +step:861/1695 train_time:84558ms step_avg:98.21ms +step:862/1695 train_time:84657ms step_avg:98.21ms +step:863/1695 train_time:84756ms step_avg:98.21ms +step:864/1695 train_time:84854ms step_avg:98.21ms +step:865/1695 train_time:84953ms step_avg:98.21ms +step:866/1695 train_time:85052ms step_avg:98.21ms +step:867/1695 train_time:85148ms step_avg:98.21ms +step:868/1695 train_time:85246ms step_avg:98.21ms +step:869/1695 train_time:85345ms step_avg:98.21ms +step:870/1695 train_time:85442ms step_avg:98.21ms +step:871/1695 train_time:85542ms step_avg:98.21ms +step:872/1695 train_time:85641ms step_avg:98.21ms +step:873/1695 train_time:85740ms step_avg:98.21ms +step:874/1695 train_time:85840ms step_avg:98.22ms +step:875/1695 train_time:85939ms step_avg:98.22ms +step:875/1695 val_loss:3.5360 train_time:86035ms step_avg:98.33ms +step:876/1695 train_time:86061ms step_avg:98.24ms +step:877/1695 train_time:86146ms step_avg:98.23ms +step:878/1695 train_time:86249ms step_avg:98.23ms +step:879/1695 train_time:86349ms step_avg:98.24ms +step:880/1695 train_time:86447ms step_avg:98.24ms +step:881/1695 train_time:86547ms step_avg:98.24ms +step:882/1695 train_time:86647ms step_avg:98.24ms +step:883/1695 train_time:86746ms step_avg:98.24ms +step:884/1695 train_time:86845ms step_avg:98.24ms +step:885/1695 train_time:86944ms step_avg:98.24ms +step:886/1695 train_time:87045ms step_avg:98.24ms +step:887/1695 train_time:87147ms step_avg:98.25ms +step:888/1695 train_time:87249ms step_avg:98.25ms +step:889/1695 train_time:87350ms step_avg:98.26ms +step:890/1695 train_time:87450ms step_avg:98.26ms +step:891/1695 train_time:87549ms step_avg:98.26ms +step:892/1695 train_time:87648ms step_avg:98.26ms +step:893/1695 train_time:87748ms step_avg:98.26ms +step:894/1695 train_time:87847ms step_avg:98.26ms +step:895/1695 train_time:87946ms step_avg:98.26ms +step:896/1695 train_time:88046ms step_avg:98.27ms +step:897/1695 train_time:88147ms step_avg:98.27ms +step:898/1695 train_time:88249ms step_avg:98.27ms +step:899/1695 train_time:88349ms step_avg:98.28ms +step:900/1695 train_time:88450ms step_avg:98.28ms +step:901/1695 train_time:88549ms step_avg:98.28ms +step:902/1695 train_time:88648ms step_avg:98.28ms +step:903/1695 train_time:88748ms step_avg:98.28ms +step:904/1695 train_time:88847ms step_avg:98.28ms +step:905/1695 train_time:88946ms step_avg:98.28ms +step:906/1695 train_time:89046ms step_avg:98.28ms +step:907/1695 train_time:89147ms step_avg:98.29ms 
+step:908/1695 train_time:89248ms step_avg:98.29ms +step:909/1695 train_time:89349ms step_avg:98.29ms +step:910/1695 train_time:89450ms step_avg:98.30ms +step:911/1695 train_time:89549ms step_avg:98.30ms +step:912/1695 train_time:89649ms step_avg:98.30ms +step:913/1695 train_time:89749ms step_avg:98.30ms +step:914/1695 train_time:89848ms step_avg:98.30ms +step:915/1695 train_time:89948ms step_avg:98.30ms +step:916/1695 train_time:90047ms step_avg:98.30ms +step:917/1695 train_time:90148ms step_avg:98.31ms +step:918/1695 train_time:90249ms step_avg:98.31ms +step:919/1695 train_time:90349ms step_avg:98.31ms +step:920/1695 train_time:90450ms step_avg:98.32ms +step:921/1695 train_time:90550ms step_avg:98.32ms +step:922/1695 train_time:90649ms step_avg:98.32ms +step:923/1695 train_time:90749ms step_avg:98.32ms +step:924/1695 train_time:90849ms step_avg:98.32ms +step:925/1695 train_time:90949ms step_avg:98.32ms +step:926/1695 train_time:91050ms step_avg:98.33ms +step:927/1695 train_time:91149ms step_avg:98.33ms +step:928/1695 train_time:91249ms step_avg:98.33ms +step:929/1695 train_time:91350ms step_avg:98.33ms +step:930/1695 train_time:91449ms step_avg:98.33ms +step:931/1695 train_time:91549ms step_avg:98.33ms +step:932/1695 train_time:91649ms step_avg:98.34ms +step:933/1695 train_time:91748ms step_avg:98.34ms +step:934/1695 train_time:91848ms step_avg:98.34ms +step:935/1695 train_time:91947ms step_avg:98.34ms +step:936/1695 train_time:92047ms step_avg:98.34ms +step:937/1695 train_time:92148ms step_avg:98.34ms +step:938/1695 train_time:92248ms step_avg:98.35ms +step:939/1695 train_time:92348ms step_avg:98.35ms +step:940/1695 train_time:92449ms step_avg:98.35ms +step:941/1695 train_time:92550ms step_avg:98.35ms +step:942/1695 train_time:92650ms step_avg:98.35ms +step:943/1695 train_time:92750ms step_avg:98.36ms +step:944/1695 train_time:92849ms step_avg:98.36ms +step:945/1695 train_time:92950ms step_avg:98.36ms +step:946/1695 train_time:93049ms step_avg:98.36ms +step:947/1695 train_time:93149ms step_avg:98.36ms +step:948/1695 train_time:93249ms step_avg:98.36ms +step:949/1695 train_time:93348ms step_avg:98.36ms +step:950/1695 train_time:93448ms step_avg:98.37ms +step:951/1695 train_time:93548ms step_avg:98.37ms +step:952/1695 train_time:93649ms step_avg:98.37ms +step:953/1695 train_time:93749ms step_avg:98.37ms +step:954/1695 train_time:93848ms step_avg:98.37ms +step:955/1695 train_time:93948ms step_avg:98.37ms +step:956/1695 train_time:94048ms step_avg:98.38ms +step:957/1695 train_time:94148ms step_avg:98.38ms +step:958/1695 train_time:94248ms step_avg:98.38ms +step:959/1695 train_time:94348ms step_avg:98.38ms +step:960/1695 train_time:94448ms step_avg:98.38ms +step:961/1695 train_time:94549ms step_avg:98.39ms +step:962/1695 train_time:94649ms step_avg:98.39ms +step:963/1695 train_time:94748ms step_avg:98.39ms +step:964/1695 train_time:94848ms step_avg:98.39ms +step:965/1695 train_time:94948ms step_avg:98.39ms +step:966/1695 train_time:95048ms step_avg:98.39ms +step:967/1695 train_time:95148ms step_avg:98.39ms +step:968/1695 train_time:95248ms step_avg:98.40ms +step:969/1695 train_time:95349ms step_avg:98.40ms +step:970/1695 train_time:95448ms step_avg:98.40ms +step:971/1695 train_time:95548ms step_avg:98.40ms +step:972/1695 train_time:95648ms step_avg:98.40ms +step:973/1695 train_time:95748ms step_avg:98.40ms +step:974/1695 train_time:95847ms step_avg:98.41ms +step:975/1695 train_time:95947ms step_avg:98.41ms +step:976/1695 train_time:96047ms step_avg:98.41ms +step:977/1695 train_time:96147ms 
step_avg:98.41ms +step:978/1695 train_time:96248ms step_avg:98.41ms +step:979/1695 train_time:96349ms step_avg:98.42ms +step:980/1695 train_time:96449ms step_avg:98.42ms +step:981/1695 train_time:96548ms step_avg:98.42ms +step:982/1695 train_time:96649ms step_avg:98.42ms +step:983/1695 train_time:96749ms step_avg:98.42ms +step:984/1695 train_time:96848ms step_avg:98.42ms +step:985/1695 train_time:96948ms step_avg:98.42ms +step:986/1695 train_time:97049ms step_avg:98.43ms +step:987/1695 train_time:97150ms step_avg:98.43ms +step:988/1695 train_time:97250ms step_avg:98.43ms +step:989/1695 train_time:97350ms step_avg:98.43ms +step:990/1695 train_time:97449ms step_avg:98.43ms +step:991/1695 train_time:97550ms step_avg:98.44ms +step:992/1695 train_time:97650ms step_avg:98.44ms +step:993/1695 train_time:97750ms step_avg:98.44ms +step:994/1695 train_time:97850ms step_avg:98.44ms +step:995/1695 train_time:97949ms step_avg:98.44ms +step:996/1695 train_time:98049ms step_avg:98.44ms +step:997/1695 train_time:98149ms step_avg:98.44ms +step:998/1695 train_time:98248ms step_avg:98.44ms +step:999/1695 train_time:98349ms step_avg:98.45ms +step:1000/1695 train_time:98448ms step_avg:98.45ms +step:1000/1695 val_loss:3.4900 train_time:98546ms step_avg:98.55ms +step:1001/1695 train_time:98572ms step_avg:98.47ms +step:1002/1695 train_time:98657ms step_avg:98.46ms +step:1003/1695 train_time:98758ms step_avg:98.46ms +step:1004/1695 train_time:98858ms step_avg:98.46ms +step:1005/1695 train_time:98957ms step_avg:98.46ms +step:1006/1695 train_time:99056ms step_avg:98.47ms +step:1007/1695 train_time:99155ms step_avg:98.47ms +step:1008/1695 train_time:99255ms step_avg:98.47ms +step:1009/1695 train_time:99354ms step_avg:98.47ms +step:1010/1695 train_time:99452ms step_avg:98.47ms +step:1011/1695 train_time:99555ms step_avg:98.47ms +step:1012/1695 train_time:99658ms step_avg:98.48ms +step:1013/1695 train_time:99760ms step_avg:98.48ms +step:1014/1695 train_time:99860ms step_avg:98.48ms +step:1015/1695 train_time:99959ms step_avg:98.48ms +step:1016/1695 train_time:100058ms step_avg:98.48ms +step:1017/1695 train_time:100157ms step_avg:98.48ms +step:1018/1695 train_time:100256ms step_avg:98.48ms +step:1019/1695 train_time:100355ms step_avg:98.48ms +step:1020/1695 train_time:100455ms step_avg:98.49ms +step:1021/1695 train_time:100556ms step_avg:98.49ms +step:1022/1695 train_time:100658ms step_avg:98.49ms +step:1023/1695 train_time:100758ms step_avg:98.49ms +step:1024/1695 train_time:100860ms step_avg:98.50ms +step:1025/1695 train_time:100960ms step_avg:98.50ms +step:1026/1695 train_time:101060ms step_avg:98.50ms +step:1027/1695 train_time:101159ms step_avg:98.50ms +step:1028/1695 train_time:101258ms step_avg:98.50ms +step:1029/1695 train_time:101359ms step_avg:98.50ms +step:1030/1695 train_time:101459ms step_avg:98.50ms +step:1031/1695 train_time:101559ms step_avg:98.51ms +step:1032/1695 train_time:101658ms step_avg:98.51ms +step:1033/1695 train_time:101759ms step_avg:98.51ms +step:1034/1695 train_time:101859ms step_avg:98.51ms +step:1035/1695 train_time:101959ms step_avg:98.51ms +step:1036/1695 train_time:102059ms step_avg:98.51ms +step:1037/1695 train_time:102159ms step_avg:98.51ms +step:1038/1695 train_time:102259ms step_avg:98.52ms +step:1039/1695 train_time:102358ms step_avg:98.52ms +step:1040/1695 train_time:102457ms step_avg:98.52ms +step:1041/1695 train_time:102558ms step_avg:98.52ms +step:1042/1695 train_time:102657ms step_avg:98.52ms +step:1043/1695 train_time:102758ms step_avg:98.52ms +step:1044/1695 
train_time:102858ms step_avg:98.52ms +step:1045/1695 train_time:102959ms step_avg:98.52ms +step:1046/1695 train_time:103059ms step_avg:98.53ms +step:1047/1695 train_time:103158ms step_avg:98.53ms +step:1048/1695 train_time:103258ms step_avg:98.53ms +step:1049/1695 train_time:103357ms step_avg:98.53ms +step:1050/1695 train_time:103457ms step_avg:98.53ms +step:1051/1695 train_time:103557ms step_avg:98.53ms +step:1052/1695 train_time:103657ms step_avg:98.53ms +step:1053/1695 train_time:103757ms step_avg:98.53ms +step:1054/1695 train_time:103857ms step_avg:98.54ms +step:1055/1695 train_time:103958ms step_avg:98.54ms +step:1056/1695 train_time:104058ms step_avg:98.54ms +step:1057/1695 train_time:104157ms step_avg:98.54ms +step:1058/1695 train_time:104258ms step_avg:98.54ms +step:1059/1695 train_time:104358ms step_avg:98.54ms +step:1060/1695 train_time:104458ms step_avg:98.54ms +step:1061/1695 train_time:104557ms step_avg:98.55ms +step:1062/1695 train_time:104657ms step_avg:98.55ms +step:1063/1695 train_time:104759ms step_avg:98.55ms +step:1064/1695 train_time:104859ms step_avg:98.55ms +step:1065/1695 train_time:104959ms step_avg:98.55ms +step:1066/1695 train_time:105058ms step_avg:98.55ms +step:1067/1695 train_time:105158ms step_avg:98.55ms +step:1068/1695 train_time:105258ms step_avg:98.56ms +step:1069/1695 train_time:105358ms step_avg:98.56ms +step:1070/1695 train_time:105459ms step_avg:98.56ms +step:1071/1695 train_time:105558ms step_avg:98.56ms +step:1072/1695 train_time:105658ms step_avg:98.56ms +step:1073/1695 train_time:105758ms step_avg:98.56ms +step:1074/1695 train_time:105858ms step_avg:98.56ms +step:1075/1695 train_time:105958ms step_avg:98.57ms +step:1076/1695 train_time:106057ms step_avg:98.57ms +step:1077/1695 train_time:106159ms step_avg:98.57ms +step:1078/1695 train_time:106258ms step_avg:98.57ms +step:1079/1695 train_time:106358ms step_avg:98.57ms +step:1080/1695 train_time:106458ms step_avg:98.57ms +step:1081/1695 train_time:106558ms step_avg:98.57ms +step:1082/1695 train_time:106658ms step_avg:98.58ms +step:1083/1695 train_time:106758ms step_avg:98.58ms +step:1084/1695 train_time:106858ms step_avg:98.58ms +step:1085/1695 train_time:106959ms step_avg:98.58ms +step:1086/1695 train_time:107059ms step_avg:98.58ms +step:1087/1695 train_time:107158ms step_avg:98.58ms +step:1088/1695 train_time:107259ms step_avg:98.58ms +step:1089/1695 train_time:107358ms step_avg:98.58ms +step:1090/1695 train_time:107459ms step_avg:98.59ms +step:1091/1695 train_time:107559ms step_avg:98.59ms +step:1092/1695 train_time:107659ms step_avg:98.59ms +step:1093/1695 train_time:107759ms step_avg:98.59ms +step:1094/1695 train_time:107860ms step_avg:98.59ms +step:1095/1695 train_time:107959ms step_avg:98.59ms +step:1096/1695 train_time:108059ms step_avg:98.59ms +step:1097/1695 train_time:108158ms step_avg:98.59ms +step:1098/1695 train_time:108258ms step_avg:98.60ms +step:1099/1695 train_time:108358ms step_avg:98.60ms +step:1100/1695 train_time:108457ms step_avg:98.60ms +step:1101/1695 train_time:108557ms step_avg:98.60ms +step:1102/1695 train_time:108658ms step_avg:98.60ms +step:1103/1695 train_time:108757ms step_avg:98.60ms +step:1104/1695 train_time:108857ms step_avg:98.60ms +step:1105/1695 train_time:108958ms step_avg:98.60ms +step:1106/1695 train_time:109058ms step_avg:98.61ms +step:1107/1695 train_time:109158ms step_avg:98.61ms +step:1108/1695 train_time:109258ms step_avg:98.61ms +step:1109/1695 train_time:109357ms step_avg:98.61ms +step:1110/1695 train_time:109458ms step_avg:98.61ms +step:1111/1695 
train_time:109558ms step_avg:98.61ms +step:1112/1695 train_time:109658ms step_avg:98.61ms +step:1113/1695 train_time:109757ms step_avg:98.61ms +step:1114/1695 train_time:109858ms step_avg:98.62ms +step:1115/1695 train_time:109958ms step_avg:98.62ms +step:1116/1695 train_time:110057ms step_avg:98.62ms +step:1117/1695 train_time:110158ms step_avg:98.62ms +step:1118/1695 train_time:110258ms step_avg:98.62ms +step:1119/1695 train_time:110358ms step_avg:98.62ms +step:1120/1695 train_time:110457ms step_avg:98.62ms +step:1121/1695 train_time:110557ms step_avg:98.62ms +step:1122/1695 train_time:110657ms step_avg:98.63ms +step:1123/1695 train_time:110758ms step_avg:98.63ms +step:1124/1695 train_time:110858ms step_avg:98.63ms +step:1125/1695 train_time:110959ms step_avg:98.63ms +step:1125/1695 val_loss:3.4391 train_time:111057ms step_avg:98.72ms +step:1126/1695 train_time:111083ms step_avg:98.65ms +step:1127/1695 train_time:111170ms step_avg:98.64ms +step:1128/1695 train_time:111271ms step_avg:98.64ms +step:1129/1695 train_time:111371ms step_avg:98.65ms +step:1130/1695 train_time:111471ms step_avg:98.65ms +step:1131/1695 train_time:111570ms step_avg:98.65ms +step:1132/1695 train_time:111669ms step_avg:98.65ms +step:1133/1695 train_time:111769ms step_avg:98.65ms +step:1134/1695 train_time:111868ms step_avg:98.65ms +step:1135/1695 train_time:111969ms step_avg:98.65ms +step:1136/1695 train_time:112072ms step_avg:98.65ms +step:1137/1695 train_time:112174ms step_avg:98.66ms +step:1138/1695 train_time:112275ms step_avg:98.66ms +step:1139/1695 train_time:112375ms step_avg:98.66ms +step:1140/1695 train_time:112475ms step_avg:98.66ms +step:1141/1695 train_time:112575ms step_avg:98.66ms +step:1142/1695 train_time:112675ms step_avg:98.66ms +step:1143/1695 train_time:112774ms step_avg:98.67ms +step:1144/1695 train_time:112875ms step_avg:98.67ms +step:1145/1695 train_time:112976ms step_avg:98.67ms +step:1146/1695 train_time:113075ms step_avg:98.67ms +step:1147/1695 train_time:113176ms step_avg:98.67ms +step:1148/1695 train_time:113276ms step_avg:98.67ms +step:1149/1695 train_time:113376ms step_avg:98.67ms +step:1150/1695 train_time:113476ms step_avg:98.68ms +step:1151/1695 train_time:113576ms step_avg:98.68ms +step:1152/1695 train_time:113676ms step_avg:98.68ms +step:1153/1695 train_time:113777ms step_avg:98.68ms +step:1154/1695 train_time:113877ms step_avg:98.68ms +step:1155/1695 train_time:113977ms step_avg:98.68ms +step:1156/1695 train_time:114077ms step_avg:98.68ms +step:1157/1695 train_time:114178ms step_avg:98.68ms +step:1158/1695 train_time:114279ms step_avg:98.69ms +step:1159/1695 train_time:114379ms step_avg:98.69ms +step:1160/1695 train_time:114480ms step_avg:98.69ms +step:1161/1695 train_time:114580ms step_avg:98.69ms +step:1162/1695 train_time:114680ms step_avg:98.69ms +step:1163/1695 train_time:114782ms step_avg:98.69ms +step:1164/1695 train_time:114884ms step_avg:98.70ms +step:1165/1695 train_time:114986ms step_avg:98.70ms +step:1166/1695 train_time:115089ms step_avg:98.70ms +step:1167/1695 train_time:115190ms step_avg:98.71ms +step:1168/1695 train_time:115291ms step_avg:98.71ms +step:1169/1695 train_time:115392ms step_avg:98.71ms +step:1170/1695 train_time:115492ms step_avg:98.71ms +step:1171/1695 train_time:115593ms step_avg:98.71ms +step:1172/1695 train_time:115695ms step_avg:98.72ms +step:1173/1695 train_time:115795ms step_avg:98.72ms +step:1174/1695 train_time:115896ms step_avg:98.72ms +step:1175/1695 train_time:115997ms step_avg:98.72ms +step:1176/1695 train_time:116097ms step_avg:98.72ms 
+step:1177/1695 train_time:116197ms step_avg:98.72ms +step:1178/1695 train_time:116300ms step_avg:98.73ms +step:1179/1695 train_time:116402ms step_avg:98.73ms +step:1180/1695 train_time:116504ms step_avg:98.73ms +step:1181/1695 train_time:116605ms step_avg:98.73ms +step:1182/1695 train_time:116706ms step_avg:98.74ms +step:1183/1695 train_time:116806ms step_avg:98.74ms +step:1184/1695 train_time:116909ms step_avg:98.74ms +step:1185/1695 train_time:117011ms step_avg:98.74ms +step:1186/1695 train_time:117112ms step_avg:98.74ms +step:1187/1695 train_time:117212ms step_avg:98.75ms +step:1188/1695 train_time:117313ms step_avg:98.75ms +step:1189/1695 train_time:117414ms step_avg:98.75ms +step:1190/1695 train_time:117514ms step_avg:98.75ms +step:1191/1695 train_time:117616ms step_avg:98.75ms +step:1192/1695 train_time:117715ms step_avg:98.75ms +step:1193/1695 train_time:117816ms step_avg:98.76ms +step:1194/1695 train_time:117917ms step_avg:98.76ms +step:1195/1695 train_time:118016ms step_avg:98.76ms +step:1196/1695 train_time:118117ms step_avg:98.76ms +step:1197/1695 train_time:118217ms step_avg:98.76ms +step:1198/1695 train_time:118318ms step_avg:98.76ms +step:1199/1695 train_time:118419ms step_avg:98.76ms +step:1200/1695 train_time:118519ms step_avg:98.77ms +step:1201/1695 train_time:118620ms step_avg:98.77ms +step:1202/1695 train_time:118721ms step_avg:98.77ms +step:1203/1695 train_time:118823ms step_avg:98.77ms +step:1204/1695 train_time:118923ms step_avg:98.77ms +step:1205/1695 train_time:119024ms step_avg:98.78ms +step:1206/1695 train_time:119125ms step_avg:98.78ms +step:1207/1695 train_time:119228ms step_avg:98.78ms +step:1208/1695 train_time:119329ms step_avg:98.78ms +step:1209/1695 train_time:119430ms step_avg:98.78ms +step:1210/1695 train_time:119531ms step_avg:98.79ms +step:1211/1695 train_time:119632ms step_avg:98.79ms +step:1212/1695 train_time:119732ms step_avg:98.79ms +step:1213/1695 train_time:119833ms step_avg:98.79ms +step:1214/1695 train_time:119933ms step_avg:98.79ms +step:1215/1695 train_time:120034ms step_avg:98.79ms +step:1216/1695 train_time:120135ms step_avg:98.80ms +step:1217/1695 train_time:120235ms step_avg:98.80ms +step:1218/1695 train_time:120335ms step_avg:98.80ms +step:1219/1695 train_time:120435ms step_avg:98.80ms +step:1220/1695 train_time:120536ms step_avg:98.80ms +step:1221/1695 train_time:120636ms step_avg:98.80ms +step:1222/1695 train_time:120737ms step_avg:98.80ms +step:1223/1695 train_time:120837ms step_avg:98.80ms +step:1224/1695 train_time:120937ms step_avg:98.80ms +step:1225/1695 train_time:121038ms step_avg:98.81ms +step:1226/1695 train_time:121139ms step_avg:98.81ms +step:1227/1695 train_time:121241ms step_avg:98.81ms +step:1228/1695 train_time:121342ms step_avg:98.81ms +step:1229/1695 train_time:121443ms step_avg:98.81ms +step:1230/1695 train_time:121544ms step_avg:98.82ms +step:1231/1695 train_time:121644ms step_avg:98.82ms +step:1232/1695 train_time:121745ms step_avg:98.82ms +step:1233/1695 train_time:121845ms step_avg:98.82ms +step:1234/1695 train_time:121949ms step_avg:98.82ms +step:1235/1695 train_time:122050ms step_avg:98.83ms +step:1236/1695 train_time:122152ms step_avg:98.83ms +step:1237/1695 train_time:122253ms step_avg:98.83ms +step:1238/1695 train_time:122353ms step_avg:98.83ms +step:1239/1695 train_time:122454ms step_avg:98.83ms +step:1240/1695 train_time:122554ms step_avg:98.83ms +step:1241/1695 train_time:122655ms step_avg:98.84ms +step:1242/1695 train_time:122756ms step_avg:98.84ms +step:1243/1695 train_time:122857ms step_avg:98.84ms 
+step:1244/1695 train_time:122957ms step_avg:98.84ms +step:1245/1695 train_time:123057ms step_avg:98.84ms +step:1246/1695 train_time:123158ms step_avg:98.84ms +step:1247/1695 train_time:123259ms step_avg:98.84ms +step:1248/1695 train_time:123359ms step_avg:98.85ms +step:1249/1695 train_time:123459ms step_avg:98.85ms +step:1250/1695 train_time:123559ms step_avg:98.85ms +step:1250/1695 val_loss:3.3934 train_time:123659ms step_avg:98.93ms +step:1251/1695 train_time:123684ms step_avg:98.87ms +step:1252/1695 train_time:123775ms step_avg:98.86ms +step:1253/1695 train_time:123879ms step_avg:98.87ms +step:1254/1695 train_time:123981ms step_avg:98.87ms +step:1255/1695 train_time:124081ms step_avg:98.87ms +step:1256/1695 train_time:124181ms step_avg:98.87ms +step:1257/1695 train_time:124280ms step_avg:98.87ms +step:1258/1695 train_time:124381ms step_avg:98.87ms +step:1259/1695 train_time:124480ms step_avg:98.87ms +step:1260/1695 train_time:124581ms step_avg:98.87ms +step:1261/1695 train_time:124683ms step_avg:98.88ms +step:1262/1695 train_time:124786ms step_avg:98.88ms +step:1263/1695 train_time:124887ms step_avg:98.88ms +step:1264/1695 train_time:124987ms step_avg:98.88ms +step:1265/1695 train_time:125087ms step_avg:98.88ms +step:1266/1695 train_time:125186ms step_avg:98.88ms +step:1267/1695 train_time:125287ms step_avg:98.88ms +step:1268/1695 train_time:125387ms step_avg:98.89ms +step:1269/1695 train_time:125487ms step_avg:98.89ms +step:1270/1695 train_time:125588ms step_avg:98.89ms +step:1271/1695 train_time:125689ms step_avg:98.89ms +step:1272/1695 train_time:125788ms step_avg:98.89ms +step:1273/1695 train_time:125889ms step_avg:98.89ms +step:1274/1695 train_time:125990ms step_avg:98.89ms +step:1275/1695 train_time:126091ms step_avg:98.90ms +step:1276/1695 train_time:126194ms step_avg:98.90ms +step:1277/1695 train_time:126295ms step_avg:98.90ms +step:1278/1695 train_time:126396ms step_avg:98.90ms +step:1279/1695 train_time:126498ms step_avg:98.90ms +step:1280/1695 train_time:126599ms step_avg:98.91ms +step:1281/1695 train_time:126701ms step_avg:98.91ms +step:1282/1695 train_time:126801ms step_avg:98.91ms +step:1283/1695 train_time:126903ms step_avg:98.91ms +step:1284/1695 train_time:127004ms step_avg:98.91ms +step:1285/1695 train_time:127104ms step_avg:98.91ms +step:1286/1695 train_time:127205ms step_avg:98.92ms +step:1287/1695 train_time:127306ms step_avg:98.92ms +step:1288/1695 train_time:127406ms step_avg:98.92ms +step:1289/1695 train_time:127507ms step_avg:98.92ms +step:1290/1695 train_time:127607ms step_avg:98.92ms +step:1291/1695 train_time:127708ms step_avg:98.92ms +step:1292/1695 train_time:127808ms step_avg:98.92ms +step:1293/1695 train_time:127909ms step_avg:98.92ms +step:1294/1695 train_time:128010ms step_avg:98.93ms +step:1295/1695 train_time:128112ms step_avg:98.93ms +step:1296/1695 train_time:128213ms step_avg:98.93ms +step:1297/1695 train_time:128313ms step_avg:98.93ms +step:1298/1695 train_time:128416ms step_avg:98.93ms +step:1299/1695 train_time:128518ms step_avg:98.94ms +step:1300/1695 train_time:128619ms step_avg:98.94ms +step:1301/1695 train_time:128721ms step_avg:98.94ms +step:1302/1695 train_time:128822ms step_avg:98.94ms +step:1303/1695 train_time:128924ms step_avg:98.94ms +step:1304/1695 train_time:129025ms step_avg:98.95ms +step:1305/1695 train_time:129125ms step_avg:98.95ms +step:1306/1695 train_time:129225ms step_avg:98.95ms +step:1307/1695 train_time:129325ms step_avg:98.95ms +step:1308/1695 train_time:129426ms step_avg:98.95ms +step:1309/1695 train_time:129527ms 
step_avg:98.95ms +step:1310/1695 train_time:129628ms step_avg:98.95ms +step:1311/1695 train_time:129729ms step_avg:98.95ms +step:1312/1695 train_time:129829ms step_avg:98.95ms +step:1313/1695 train_time:129930ms step_avg:98.96ms +step:1314/1695 train_time:130031ms step_avg:98.96ms +step:1315/1695 train_time:130133ms step_avg:98.96ms +step:1316/1695 train_time:130234ms step_avg:98.96ms +step:1317/1695 train_time:130336ms step_avg:98.96ms +step:1318/1695 train_time:130436ms step_avg:98.97ms +step:1319/1695 train_time:130538ms step_avg:98.97ms +step:1320/1695 train_time:130641ms step_avg:98.97ms +step:1321/1695 train_time:130740ms step_avg:98.97ms +step:1322/1695 train_time:130842ms step_avg:98.97ms +step:1323/1695 train_time:130942ms step_avg:98.97ms +step:1324/1695 train_time:131043ms step_avg:98.98ms +step:1325/1695 train_time:131144ms step_avg:98.98ms +step:1326/1695 train_time:131246ms step_avg:98.98ms +step:1327/1695 train_time:131346ms step_avg:98.98ms +step:1328/1695 train_time:131447ms step_avg:98.98ms +step:1329/1695 train_time:131546ms step_avg:98.98ms +step:1330/1695 train_time:131646ms step_avg:98.98ms +step:1331/1695 train_time:131747ms step_avg:98.98ms +step:1332/1695 train_time:131848ms step_avg:98.98ms +step:1333/1695 train_time:131948ms step_avg:98.99ms +step:1334/1695 train_time:132050ms step_avg:98.99ms +step:1335/1695 train_time:132151ms step_avg:98.99ms +step:1336/1695 train_time:132254ms step_avg:98.99ms +step:1337/1695 train_time:132355ms step_avg:98.99ms +step:1338/1695 train_time:132456ms step_avg:99.00ms +step:1339/1695 train_time:132558ms step_avg:99.00ms +step:1340/1695 train_time:132659ms step_avg:99.00ms +step:1341/1695 train_time:132760ms step_avg:99.00ms +step:1342/1695 train_time:132861ms step_avg:99.00ms +step:1343/1695 train_time:132962ms step_avg:99.00ms +step:1344/1695 train_time:133062ms step_avg:99.00ms +step:1345/1695 train_time:133166ms step_avg:99.01ms +step:1346/1695 train_time:133266ms step_avg:99.01ms +step:1347/1695 train_time:133368ms step_avg:99.01ms +step:1348/1695 train_time:133467ms step_avg:99.01ms +step:1349/1695 train_time:133567ms step_avg:99.01ms +step:1350/1695 train_time:133668ms step_avg:99.01ms +step:1351/1695 train_time:133767ms step_avg:99.01ms +step:1352/1695 train_time:133867ms step_avg:99.01ms +step:1353/1695 train_time:133968ms step_avg:99.02ms +step:1354/1695 train_time:134069ms step_avg:99.02ms +step:1355/1695 train_time:134170ms step_avg:99.02ms +step:1356/1695 train_time:134271ms step_avg:99.02ms +step:1357/1695 train_time:134371ms step_avg:99.02ms +step:1358/1695 train_time:134472ms step_avg:99.02ms +step:1359/1695 train_time:134572ms step_avg:99.02ms +step:1360/1695 train_time:134674ms step_avg:99.03ms +step:1361/1695 train_time:134775ms step_avg:99.03ms +step:1362/1695 train_time:134876ms step_avg:99.03ms +step:1363/1695 train_time:134978ms step_avg:99.03ms +step:1364/1695 train_time:135080ms step_avg:99.03ms +step:1365/1695 train_time:135182ms step_avg:99.03ms +step:1366/1695 train_time:135282ms step_avg:99.04ms +step:1367/1695 train_time:135383ms step_avg:99.04ms +step:1368/1695 train_time:135485ms step_avg:99.04ms +step:1369/1695 train_time:135585ms step_avg:99.04ms +step:1370/1695 train_time:135686ms step_avg:99.04ms +step:1371/1695 train_time:135786ms step_avg:99.04ms +step:1372/1695 train_time:135887ms step_avg:99.04ms +step:1373/1695 train_time:135988ms step_avg:99.04ms +step:1374/1695 train_time:136089ms step_avg:99.05ms +step:1375/1695 train_time:136190ms step_avg:99.05ms +step:1375/1695 val_loss:3.3541 
train_time:136290ms step_avg:99.12ms +step:1376/1695 train_time:136316ms step_avg:99.07ms +step:1377/1695 train_time:136402ms step_avg:99.06ms +step:1378/1695 train_time:136505ms step_avg:99.06ms +step:1379/1695 train_time:136606ms step_avg:99.06ms +step:1380/1695 train_time:136709ms step_avg:99.06ms +step:1381/1695 train_time:136809ms step_avg:99.07ms +step:1382/1695 train_time:136909ms step_avg:99.07ms +step:1383/1695 train_time:137008ms step_avg:99.07ms +step:1384/1695 train_time:137109ms step_avg:99.07ms +step:1385/1695 train_time:137210ms step_avg:99.07ms +step:1386/1695 train_time:137313ms step_avg:99.07ms +step:1387/1695 train_time:137416ms step_avg:99.07ms +step:1388/1695 train_time:137516ms step_avg:99.07ms +step:1389/1695 train_time:137618ms step_avg:99.08ms +step:1390/1695 train_time:137720ms step_avg:99.08ms +step:1391/1695 train_time:137823ms step_avg:99.08ms +step:1392/1695 train_time:137924ms step_avg:99.08ms +step:1393/1695 train_time:138026ms step_avg:99.09ms +step:1394/1695 train_time:138127ms step_avg:99.09ms +step:1395/1695 train_time:138229ms step_avg:99.09ms +step:1396/1695 train_time:138330ms step_avg:99.09ms +step:1397/1695 train_time:138433ms step_avg:99.09ms +step:1398/1695 train_time:138534ms step_avg:99.09ms +step:1399/1695 train_time:138636ms step_avg:99.10ms +step:1400/1695 train_time:138739ms step_avg:99.10ms +step:1401/1695 train_time:138842ms step_avg:99.10ms +step:1402/1695 train_time:138944ms step_avg:99.10ms +step:1403/1695 train_time:139047ms step_avg:99.11ms +step:1404/1695 train_time:139150ms step_avg:99.11ms +step:1405/1695 train_time:139251ms step_avg:99.11ms +step:1406/1695 train_time:139352ms step_avg:99.11ms +step:1407/1695 train_time:139454ms step_avg:99.11ms +step:1408/1695 train_time:139556ms step_avg:99.12ms +step:1409/1695 train_time:139660ms step_avg:99.12ms +step:1410/1695 train_time:139761ms step_avg:99.12ms +step:1411/1695 train_time:139862ms step_avg:99.12ms +step:1412/1695 train_time:139967ms step_avg:99.13ms +step:1413/1695 train_time:140068ms step_avg:99.13ms +step:1414/1695 train_time:140170ms step_avg:99.13ms +step:1415/1695 train_time:140272ms step_avg:99.13ms +step:1416/1695 train_time:140372ms step_avg:99.13ms +step:1417/1695 train_time:140473ms step_avg:99.13ms +step:1418/1695 train_time:140574ms step_avg:99.14ms +step:1419/1695 train_time:140676ms step_avg:99.14ms +step:1420/1695 train_time:140778ms step_avg:99.14ms +step:1421/1695 train_time:140880ms step_avg:99.14ms +step:1422/1695 train_time:140981ms step_avg:99.14ms +step:1423/1695 train_time:141083ms step_avg:99.14ms +step:1424/1695 train_time:141187ms step_avg:99.15ms +step:1425/1695 train_time:141288ms step_avg:99.15ms +step:1426/1695 train_time:141391ms step_avg:99.15ms +step:1427/1695 train_time:141492ms step_avg:99.15ms +step:1428/1695 train_time:141594ms step_avg:99.16ms +step:1429/1695 train_time:141696ms step_avg:99.16ms +step:1430/1695 train_time:141797ms step_avg:99.16ms +step:1431/1695 train_time:141900ms step_avg:99.16ms +step:1432/1695 train_time:142001ms step_avg:99.16ms +step:1433/1695 train_time:142104ms step_avg:99.17ms +step:1434/1695 train_time:142206ms step_avg:99.17ms +step:1435/1695 train_time:142309ms step_avg:99.17ms +step:1436/1695 train_time:142411ms step_avg:99.17ms +step:1437/1695 train_time:142513ms step_avg:99.17ms +step:1438/1695 train_time:142614ms step_avg:99.17ms +step:1439/1695 train_time:142716ms step_avg:99.18ms +step:1440/1695 train_time:142819ms step_avg:99.18ms +step:1441/1695 train_time:142922ms step_avg:99.18ms +step:1442/1695 
train_time:143022ms step_avg:99.18ms +step:1443/1695 train_time:143124ms step_avg:99.18ms +step:1444/1695 train_time:143227ms step_avg:99.19ms +step:1445/1695 train_time:143328ms step_avg:99.19ms +step:1446/1695 train_time:143430ms step_avg:99.19ms +step:1447/1695 train_time:143530ms step_avg:99.19ms +step:1448/1695 train_time:143633ms step_avg:99.19ms +step:1449/1695 train_time:143734ms step_avg:99.20ms +step:1450/1695 train_time:143835ms step_avg:99.20ms +step:1451/1695 train_time:143937ms step_avg:99.20ms +step:1452/1695 train_time:144040ms step_avg:99.20ms +step:1453/1695 train_time:144142ms step_avg:99.20ms +step:1454/1695 train_time:144245ms step_avg:99.21ms +step:1455/1695 train_time:144348ms step_avg:99.21ms +step:1456/1695 train_time:144449ms step_avg:99.21ms +step:1457/1695 train_time:144551ms step_avg:99.21ms +step:1458/1695 train_time:144653ms step_avg:99.21ms +step:1459/1695 train_time:144755ms step_avg:99.22ms +step:1460/1695 train_time:144856ms step_avg:99.22ms +step:1461/1695 train_time:144958ms step_avg:99.22ms +step:1462/1695 train_time:145060ms step_avg:99.22ms +step:1463/1695 train_time:145162ms step_avg:99.22ms +step:1464/1695 train_time:145264ms step_avg:99.22ms +step:1465/1695 train_time:145366ms step_avg:99.23ms +step:1466/1695 train_time:145468ms step_avg:99.23ms +step:1467/1695 train_time:145570ms step_avg:99.23ms +step:1468/1695 train_time:145671ms step_avg:99.23ms +step:1469/1695 train_time:145774ms step_avg:99.23ms +step:1470/1695 train_time:145875ms step_avg:99.23ms +step:1471/1695 train_time:145976ms step_avg:99.24ms +step:1472/1695 train_time:146078ms step_avg:99.24ms +step:1473/1695 train_time:146179ms step_avg:99.24ms +step:1474/1695 train_time:146282ms step_avg:99.24ms +step:1475/1695 train_time:146384ms step_avg:99.24ms +step:1476/1695 train_time:146487ms step_avg:99.25ms +step:1477/1695 train_time:146590ms step_avg:99.25ms +step:1478/1695 train_time:146691ms step_avg:99.25ms +step:1479/1695 train_time:146792ms step_avg:99.25ms +step:1480/1695 train_time:146894ms step_avg:99.25ms +step:1481/1695 train_time:146996ms step_avg:99.25ms +step:1482/1695 train_time:147097ms step_avg:99.26ms +step:1483/1695 train_time:147201ms step_avg:99.26ms +step:1484/1695 train_time:147303ms step_avg:99.26ms +step:1485/1695 train_time:147404ms step_avg:99.26ms +step:1486/1695 train_time:147505ms step_avg:99.26ms +step:1487/1695 train_time:147607ms step_avg:99.27ms +step:1488/1695 train_time:147711ms step_avg:99.27ms +step:1489/1695 train_time:147812ms step_avg:99.27ms +step:1490/1695 train_time:147915ms step_avg:99.27ms +step:1491/1695 train_time:148017ms step_avg:99.27ms +step:1492/1695 train_time:148117ms step_avg:99.27ms +step:1493/1695 train_time:148220ms step_avg:99.28ms +step:1494/1695 train_time:148322ms step_avg:99.28ms +step:1495/1695 train_time:148424ms step_avg:99.28ms +step:1496/1695 train_time:148527ms step_avg:99.28ms +step:1497/1695 train_time:148629ms step_avg:99.28ms +step:1498/1695 train_time:148731ms step_avg:99.29ms +step:1499/1695 train_time:148831ms step_avg:99.29ms +step:1500/1695 train_time:148932ms step_avg:99.29ms +step:1500/1695 val_loss:3.3188 train_time:149031ms step_avg:99.35ms +step:1501/1695 train_time:149057ms step_avg:99.31ms +step:1502/1695 train_time:149145ms step_avg:99.30ms +step:1503/1695 train_time:149246ms step_avg:99.30ms +step:1504/1695 train_time:149347ms step_avg:99.30ms +step:1505/1695 train_time:149447ms step_avg:99.30ms +step:1506/1695 train_time:149549ms step_avg:99.30ms +step:1507/1695 train_time:149650ms step_avg:99.30ms 
+step:1508/1695 train_time:149750ms step_avg:99.30ms +step:1509/1695 train_time:149853ms step_avg:99.31ms +step:1510/1695 train_time:149955ms step_avg:99.31ms +step:1511/1695 train_time:150060ms step_avg:99.31ms +step:1512/1695 train_time:150163ms step_avg:99.31ms +step:1513/1695 train_time:150264ms step_avg:99.32ms +step:1514/1695 train_time:150366ms step_avg:99.32ms +step:1515/1695 train_time:150471ms step_avg:99.32ms +step:1516/1695 train_time:150573ms step_avg:99.32ms +step:1517/1695 train_time:150674ms step_avg:99.32ms +step:1518/1695 train_time:150775ms step_avg:99.32ms +step:1519/1695 train_time:150878ms step_avg:99.33ms +step:1520/1695 train_time:150979ms step_avg:99.33ms +step:1521/1695 train_time:151082ms step_avg:99.33ms +step:1522/1695 train_time:151184ms step_avg:99.33ms +step:1523/1695 train_time:151285ms step_avg:99.33ms +step:1524/1695 train_time:151389ms step_avg:99.34ms +step:1525/1695 train_time:151492ms step_avg:99.34ms +step:1526/1695 train_time:151594ms step_avg:99.34ms +step:1527/1695 train_time:151696ms step_avg:99.34ms +step:1528/1695 train_time:151802ms step_avg:99.35ms +step:1529/1695 train_time:151903ms step_avg:99.35ms +step:1530/1695 train_time:152007ms step_avg:99.35ms +step:1531/1695 train_time:152109ms step_avg:99.35ms +step:1532/1695 train_time:152212ms step_avg:99.36ms +step:1533/1695 train_time:152315ms step_avg:99.36ms +step:1534/1695 train_time:152418ms step_avg:99.36ms +step:1535/1695 train_time:152520ms step_avg:99.36ms +step:1536/1695 train_time:152621ms step_avg:99.36ms +step:1537/1695 train_time:152722ms step_avg:99.36ms +step:1538/1695 train_time:152824ms step_avg:99.37ms +step:1539/1695 train_time:152925ms step_avg:99.37ms +step:1540/1695 train_time:153026ms step_avg:99.37ms +step:1541/1695 train_time:153131ms step_avg:99.37ms +step:1542/1695 train_time:153236ms step_avg:99.37ms +step:1543/1695 train_time:153338ms step_avg:99.38ms +step:1544/1695 train_time:153439ms step_avg:99.38ms +step:1545/1695 train_time:153541ms step_avg:99.38ms +step:1546/1695 train_time:153643ms step_avg:99.38ms +step:1547/1695 train_time:153746ms step_avg:99.38ms +step:1548/1695 train_time:153847ms step_avg:99.38ms +step:1549/1695 train_time:153948ms step_avg:99.39ms +step:1550/1695 train_time:154050ms step_avg:99.39ms +step:1551/1695 train_time:154152ms step_avg:99.39ms +step:1552/1695 train_time:154254ms step_avg:99.39ms +step:1553/1695 train_time:154358ms step_avg:99.39ms +step:1554/1695 train_time:154459ms step_avg:99.39ms +step:1555/1695 train_time:154560ms step_avg:99.40ms +step:1556/1695 train_time:154662ms step_avg:99.40ms +step:1557/1695 train_time:154765ms step_avg:99.40ms +step:1558/1695 train_time:154867ms step_avg:99.40ms +step:1559/1695 train_time:154970ms step_avg:99.40ms +step:1560/1695 train_time:155071ms step_avg:99.40ms +step:1561/1695 train_time:155173ms step_avg:99.41ms +step:1562/1695 train_time:155277ms step_avg:99.41ms +step:1563/1695 train_time:155381ms step_avg:99.41ms +step:1564/1695 train_time:155482ms step_avg:99.41ms +step:1565/1695 train_time:155584ms step_avg:99.41ms +step:1566/1695 train_time:155684ms step_avg:99.42ms +step:1567/1695 train_time:155786ms step_avg:99.42ms +step:1568/1695 train_time:155886ms step_avg:99.42ms +step:1569/1695 train_time:155986ms step_avg:99.42ms +step:1570/1695 train_time:156089ms step_avg:99.42ms +step:1571/1695 train_time:156190ms step_avg:99.42ms +step:1572/1695 train_time:156292ms step_avg:99.42ms +step:1573/1695 train_time:156395ms step_avg:99.42ms +step:1574/1695 train_time:156496ms step_avg:99.43ms 
+step:1575/1695 train_time:156599ms step_avg:99.43ms +step:1576/1695 train_time:156703ms step_avg:99.43ms +step:1577/1695 train_time:156807ms step_avg:99.43ms +step:1578/1695 train_time:156907ms step_avg:99.43ms +step:1579/1695 train_time:157009ms step_avg:99.44ms +step:1580/1695 train_time:157110ms step_avg:99.44ms +step:1581/1695 train_time:157212ms step_avg:99.44ms +step:1582/1695 train_time:157314ms step_avg:99.44ms +step:1583/1695 train_time:157417ms step_avg:99.44ms +step:1584/1695 train_time:157521ms step_avg:99.44ms +step:1585/1695 train_time:157622ms step_avg:99.45ms +step:1586/1695 train_time:157725ms step_avg:99.45ms +step:1587/1695 train_time:157827ms step_avg:99.45ms +step:1588/1695 train_time:157928ms step_avg:99.45ms +step:1589/1695 train_time:158029ms step_avg:99.45ms +step:1590/1695 train_time:158130ms step_avg:99.45ms +step:1591/1695 train_time:158231ms step_avg:99.45ms +step:1592/1695 train_time:158334ms step_avg:99.46ms +step:1593/1695 train_time:158434ms step_avg:99.46ms +step:1594/1695 train_time:158538ms step_avg:99.46ms +step:1595/1695 train_time:158641ms step_avg:99.46ms +step:1596/1695 train_time:158743ms step_avg:99.46ms +step:1597/1695 train_time:158845ms step_avg:99.46ms +step:1598/1695 train_time:158947ms step_avg:99.47ms +step:1599/1695 train_time:159048ms step_avg:99.47ms +step:1600/1695 train_time:159150ms step_avg:99.47ms +step:1601/1695 train_time:159252ms step_avg:99.47ms +step:1602/1695 train_time:159355ms step_avg:99.47ms +step:1603/1695 train_time:159456ms step_avg:99.47ms +step:1604/1695 train_time:159558ms step_avg:99.48ms +step:1605/1695 train_time:159661ms step_avg:99.48ms +step:1606/1695 train_time:159764ms step_avg:99.48ms +step:1607/1695 train_time:159864ms step_avg:99.48ms +step:1608/1695 train_time:159965ms step_avg:99.48ms +step:1609/1695 train_time:160067ms step_avg:99.48ms +step:1610/1695 train_time:160169ms step_avg:99.48ms +step:1611/1695 train_time:160272ms step_avg:99.49ms +step:1612/1695 train_time:160374ms step_avg:99.49ms +step:1613/1695 train_time:160475ms step_avg:99.49ms +step:1614/1695 train_time:160576ms step_avg:99.49ms +step:1615/1695 train_time:160678ms step_avg:99.49ms +step:1616/1695 train_time:160780ms step_avg:99.49ms +step:1617/1695 train_time:160883ms step_avg:99.49ms +step:1618/1695 train_time:160985ms step_avg:99.50ms +step:1619/1695 train_time:161086ms step_avg:99.50ms +step:1620/1695 train_time:161188ms step_avg:99.50ms +step:1621/1695 train_time:161289ms step_avg:99.50ms +step:1622/1695 train_time:161391ms step_avg:99.50ms +step:1623/1695 train_time:161493ms step_avg:99.50ms +step:1624/1695 train_time:161595ms step_avg:99.50ms +step:1625/1695 train_time:161700ms step_avg:99.51ms +step:1625/1695 val_loss:3.2904 train_time:161801ms step_avg:99.57ms +step:1626/1695 train_time:161827ms step_avg:99.52ms +step:1627/1695 train_time:161915ms step_avg:99.52ms +step:1628/1695 train_time:162017ms step_avg:99.52ms +step:1629/1695 train_time:162119ms step_avg:99.52ms +step:1630/1695 train_time:162220ms step_avg:99.52ms +step:1631/1695 train_time:162321ms step_avg:99.52ms +step:1632/1695 train_time:162422ms step_avg:99.52ms +step:1633/1695 train_time:162522ms step_avg:99.52ms +step:1634/1695 train_time:162625ms step_avg:99.53ms +step:1635/1695 train_time:162728ms step_avg:99.53ms +step:1636/1695 train_time:162831ms step_avg:99.53ms +step:1637/1695 train_time:162935ms step_avg:99.53ms +step:1638/1695 train_time:163038ms step_avg:99.53ms +step:1639/1695 train_time:163140ms step_avg:99.54ms +step:1640/1695 train_time:163243ms 
step_avg:99.54ms +step:1641/1695 train_time:163347ms step_avg:99.54ms +step:1642/1695 train_time:163449ms step_avg:99.54ms +step:1643/1695 train_time:163552ms step_avg:99.54ms +step:1644/1695 train_time:163655ms step_avg:99.55ms +step:1645/1695 train_time:163758ms step_avg:99.55ms +step:1646/1695 train_time:163861ms step_avg:99.55ms +step:1647/1695 train_time:163965ms step_avg:99.55ms +step:1648/1695 train_time:164070ms step_avg:99.56ms +step:1649/1695 train_time:164173ms step_avg:99.56ms +step:1650/1695 train_time:164275ms step_avg:99.56ms +step:1651/1695 train_time:164379ms step_avg:99.56ms +step:1652/1695 train_time:164481ms step_avg:99.56ms +step:1653/1695 train_time:164584ms step_avg:99.57ms +step:1654/1695 train_time:164686ms step_avg:99.57ms +step:1655/1695 train_time:164789ms step_avg:99.57ms +step:1656/1695 train_time:164894ms step_avg:99.57ms +step:1657/1695 train_time:164996ms step_avg:99.57ms +step:1658/1695 train_time:165098ms step_avg:99.58ms +step:1659/1695 train_time:165204ms step_avg:99.58ms +step:1660/1695 train_time:165306ms step_avg:99.58ms +step:1661/1695 train_time:165412ms step_avg:99.59ms +step:1662/1695 train_time:165517ms step_avg:99.59ms +step:1663/1695 train_time:165620ms step_avg:99.59ms +step:1664/1695 train_time:165722ms step_avg:99.59ms +step:1665/1695 train_time:165828ms step_avg:99.60ms +step:1666/1695 train_time:165931ms step_avg:99.60ms +step:1667/1695 train_time:166033ms step_avg:99.60ms +step:1668/1695 train_time:166139ms step_avg:99.60ms +step:1669/1695 train_time:166242ms step_avg:99.61ms +step:1670/1695 train_time:166344ms step_avg:99.61ms +step:1671/1695 train_time:166447ms step_avg:99.61ms +step:1672/1695 train_time:166551ms step_avg:99.61ms +step:1673/1695 train_time:166654ms step_avg:99.61ms +step:1674/1695 train_time:166757ms step_avg:99.62ms +step:1675/1695 train_time:166859ms step_avg:99.62ms +step:1676/1695 train_time:166963ms step_avg:99.62ms +step:1677/1695 train_time:167065ms step_avg:99.62ms +step:1678/1695 train_time:167169ms step_avg:99.62ms +step:1679/1695 train_time:167273ms step_avg:99.63ms +step:1680/1695 train_time:167376ms step_avg:99.63ms +step:1681/1695 train_time:167479ms step_avg:99.63ms +step:1682/1695 train_time:167585ms step_avg:99.63ms +step:1683/1695 train_time:167688ms step_avg:99.64ms +step:1684/1695 train_time:167792ms step_avg:99.64ms +step:1685/1695 train_time:167895ms step_avg:99.64ms +step:1686/1695 train_time:167998ms step_avg:99.64ms +step:1687/1695 train_time:168100ms step_avg:99.64ms +step:1688/1695 train_time:168202ms step_avg:99.65ms +step:1689/1695 train_time:168304ms step_avg:99.65ms +step:1690/1695 train_time:168407ms step_avg:99.65ms +step:1691/1695 train_time:168510ms step_avg:99.65ms +step:1692/1695 train_time:168613ms step_avg:99.65ms +step:1693/1695 train_time:168716ms step_avg:99.66ms +step:1694/1695 train_time:168820ms step_avg:99.66ms +step:1695/1695 train_time:168923ms step_avg:99.66ms +step:1695/1695 val_loss:3.2777 train_time:169024ms step_avg:99.72ms +peak memory allocated: 34761 MiB reserved: 49580 MiB
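[Note on the diffs that follow] requirements.txt gains an explicit triton dependency, and train_gpt.py swaps the compiled PyTorch Newton-Schulz iteration inside Muon for the hand-written Triton kernels ns_line_1 and ns_line_2. Both A = X @ X.T (line 1 of the iteration) and B = b*A + c*A@A (line 2) are symmetric, so each kernel computes only one triangle of the product and mirrors it across the diagonal, saving roughly half the matmul work. For orientation, here is a minimal PyTorch sketch of the iteration the kernels accelerate; the function name is illustrative, and the quintic coefficients are the NS5 constants commonly used for Muon, not necessarily a verbatim copy of the removed zeropower_via_newtonschulz5:

import torch

def newtonschulz5_reference(G: torch.Tensor, steps: int) -> torch.Tensor:
    # Illustrative sketch, assuming the usual NS5 quintic coefficients.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT  # iterate on the wide orientation
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # scale so the iteration converges
    for _ in range(steps):
        A = X @ X.mT           # "line 1": symmetric product, what ns_line_1 computes
        B = b * A + c * A @ A  # "line 2": symmetric polynomial in A, what ns_line_2 computes (beta=b, alpha=c)
        X = a * X + B @ X
    return X.mT if transposed else X

Because A and B are symmetric, a triangular kernel plus a mirrored store reproduces the full matrices at roughly half the FLOPs of two dense matmuls.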
import lru_cache from pathlib import Path os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -19,6 +19,8 @@ # use of FlexAttention contributed by @KoszarskyB from torch.nn.attention.flex_attention import BlockMask, flex_attention #torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl # ----------------------------------------------------------------------------- # Custom operators: FP8 matmul by @YouJiacheng @@ -102,37 +104,287 @@ def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): mm_op.register_autograd(backward, setup_context=setup_context) # ----------------------------------------------------------------------------- -# Muon optimizer +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + 
accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) -@torch.compile -def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * 
a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. 
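# For reference, one iteration of the quintic Newton-Schulz step that the three fused
# kernels above implement, as plain PyTorch (a sketch; newton_schulz_triton below
# assumes a bfloat16 X normalized so its spectral norm is at most 1):
#     a, b, c = 3.4445, -4.7750, 2.0315
#     A = X @ X.mT             # ns_line_1: one triangle of the symmetric product, then mirrored
#     B = b * A + c * (A @ A)  # ns_line_2: alpha * A @ A.T + beta * A, with alpha=c, beta=b
#     X = a * X + B @ X        # ns_line_3: fused into addmm/baddbmm(X, B, X, beta=a)
# i.e. X <- a*X + b*(X X^T)X + c*(X X^T)^2 X, a degree-5 polynomial update in X.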
+ Launch Triton kernel to compute C = alpha * A @ A.T + beta * A """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) - X = G + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT # Ensure spectral norm is at most 1 X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + # Perform the NS iterations - for _ in range(steps): - A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies if G.size(-2) > G.size(-1): X = X.mT return X +# ----------------------------------------------------------------------------- +# Muon optimizer + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -166,7 +418,7 @@ def step(self): rank = dist.get_rank() world_size = dist.get_world_size() reduce_scatter_futures: list[torch.Future] = [] - all_reduce_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] for group in self.param_groups: params: list[Tensor] = group["params"] grad = torch.empty_like(params[-1]) @@ -196,11 +448,11 @@ def step(self): p.mul_(1 - eff_weight_decay) momentum_buffer.lerp_(grad, 1 - momentum) grad = grad.lerp_(momentum_buffer, momentum) - v = zeropower_via_newtonschulz5(grad.bfloat16(), 5) + v = newton_schulz_triton(grad) p.add_(other=v, alpha=-eff_lr) idx += 1 - all_reduce_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_reduce_futures).wait() + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() class DistAdam(torch.optim.Optimizer): def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): @@ -221,11 +473,10 @@ def step(self): rank = dist.get_rank() world_size = dist.get_world_size() reduce_scatter_futures: list[torch.Future] = [] - 
all_reduce_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] grad_slices = [] for group in self.param_groups: params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) for base_i in range(len(params)): grad = params[base_i].grad rank_size = grad.shape[0] // world_size @@ -272,8 +523,8 @@ def step(self): update = exp_avg.div(denom).mul_(step_size) p_slice.add_(other=update, alpha=-1.0) idx += 1 - all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_reduce_futures).wait() + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the model @@ -328,45 +579,61 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): self.num_heads = num_heads self.head_dim = head_dim hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" std = 0.5 * (dim ** -0.5) bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng # https://x.com/hi_tysam/status/1879699187107033311 - self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound)) + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero self.rotary = Rotary(head_dim, max_seq_len) - self.c_proj = CastedLinear(hdim, dim) - self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977 # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 self.attn_scale = 0.12 + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): B, T = x.size(0), x.size(1) # batch size, sequence length assert B == 1, "Must use batch size = 1 for FlexAttention" - q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) q, k = norm(q), norm(k) # QK norm @Grad62304977 q, k = self.rotary(q), self.rotary(k) if ve is not None: v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 else: # skip mid-layers token value embeddings by @YouJiacheng v = lambdas[0] * v - y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2) + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = self.c_proj(y) + y = F.linear(y, 
self.qkvo_w[3].type_as(y)) return y class MLP(nn.Module): def __init__(self, dim: int): super().__init__() hdim = 4 * dim - self.c_fc = CastedLinear(dim, hdim) - self.c_proj = CastedLinear(hdim, dim) - self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977 + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 def forward(self, x: Tensor): - x = self.c_fc(x) + x = F.linear(x, self.c_fc.T.type_as(x)) x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = self.c_proj(x) + x = F.linear(x, self.c_proj.type_as(x)) return x class Block(nn.Module): @@ -400,7 +667,8 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448) + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) self.lm_head.weight.detach().zero_() # @Grad62304977 # Add learnable skip connection weights for decoder layers assert num_layers % 2 == 0 @@ -416,7 +684,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: param.lr_mul = 75. for param in self.value_embeds.parameters(): param.lr_mul = 75. 
- self.lm_head.weight.lr_mul = 27.5 + self.lm_head.weight.lr_mul = 1.0 self.scalars.lr_mul = 5.0 def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): @@ -491,7 +759,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_bloc x = norm(x) logits = self.lm_head(x).float() # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5)) + logits = 30 * torch.sigmoid(logits / 7.5) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") return loss @@ -510,44 +778,46 @@ def _load_data_shard(file: Path): assert nbytes == 2 * num_tokens, "number of tokens read does not match header" return tokens -# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap -def find_batch_starts(tokens: Tensor, pos: int, local_batch_size: int, max_batch_span: int): - boundary_mask = tokens[pos : pos + max_batch_span] == 50256 +# find world_size starting indices, such that each begins with token 50256 and local_batches don't overlap by @classiclarryd +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos start = boundary_positions[0].item() starts = [] for i in range(1, len(boundary_positions)): end = boundary_positions[i].item() - if end - start >= local_batch_size: + if end - start >= seq_len: starts.append(start) # append start once end pos is confirmed if len(starts) == dist.get_world_size(): return starts, end - pos start = end - assert False # increase max_batch_span if necessary + assert False # increase token_window if necessary -def distributed_data_generator(filename_pattern: str, batch_size: int, align_to_bos: bool): +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): rank = dist.get_rank() world_size = dist.get_world_size() + batch_size = seq_len * world_size files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - assert batch_size % world_size == 0 - local_batch_size = batch_size // world_size file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training tokens, pos = _load_data_shard(next(file_iter)), 0 - max_batch_span = 2 * batch_size if align_to_bos else batch_size # provide buffer to handle samples up to length local_batch_size while True: - if pos + max_batch_span + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - if align_to_bos: - batch_starts, batch_span = find_batch_starts(tokens, pos, local_batch_size, max_batch_span) - start_idx = batch_starts[rank] - else: - batch_span = batch_size - start_idx = pos + rank * local_batch_size - buf = tokens[start_idx:][:local_batch_size + 1] - inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; - targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
- pos += batch_span - yield inputs, targets + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets # ----------------------------------------------------------------------------- # int main @@ -561,17 +831,23 @@ class Hyperparameters: train_seq_len = 48*1024 # FlexAttention sequence length val_seq_len = 4*64*1024 # FlexAttention sequence length for validation # optimization - num_iterations = 1750 # number of iterations to run + num_iterations = 1695 # number of iterations to run cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate # evaluation and logging + run_id = uuid.uuid4() val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end save_checkpoint = False args = Hyperparameters() +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + # torchrun sets these env variables rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) -assert world_size == 8 # this code is designed for 8xH100 +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size assert torch.cuda.is_available() device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) torch.cuda.set_device(device) @@ -582,7 +858,7 @@ class Hyperparameters: # begin logging logfile = None if master_process: - run_id = uuid.uuid4() + run_id = args.run_id os.makedirs("logs", exist_ok=True) logfile = f"logs/{run_id}.txt" print(logfile) @@ -599,6 +875,7 @@ def print0(s, console=False): # log information about the hardware/software environment this is running on print0(f"Running Python {sys.version}") print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") def nvidia_smi(): import subprocess # avoid top level import return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout @@ -650,7 +927,7 @@ def get_window_size_blocks(step: int): window_size = next_multiple_of_n(1728 * x, n=128) return get_window_size_blocks_helper(window_size) -model: nn.Module = torch.compile(model, dynamic=False) +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) ######################################## # Warmup kernels # @@ -660,7 +937,7 @@ def get_window_size_blocks(step: int): warmup_steps = 10 initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, align_to_bos=True) +train_loader = distributed_data_generator(args.train_files, 
args.train_seq_len, grad_accum_steps, align_to_bos=True) for _ in range(warmup_steps): inputs, targets = next(train_loader) model(inputs, targets, get_window_size_blocks(1)).backward() @@ -676,7 +953,7 @@ def get_window_size_blocks(step: int): # Training and validation # ######################################## -train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, align_to_bos=True) +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) training_time_ms = 0 # start the clock torch.cuda.synchronize() @@ -695,7 +972,7 @@ def get_window_size_blocks(step: int): val_batch_size = world_size * args.val_seq_len assert args.val_tokens % val_batch_size == 0 val_steps = args.val_tokens // val_batch_size - val_loader = distributed_data_generator(args.val_files, val_batch_size, align_to_bos=False) + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) val_loss = 0 with torch.no_grad(): for _ in range(val_steps): @@ -719,8 +996,9 @@ def get_window_size_blocks(step: int): break # --------------- TRAINING SECTION ----------------- - inputs, targets = next(train_loader) - model(inputs, targets, get_window_size_blocks(step)).backward() + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() # set optimization hyperparameters for opt in optimizers: for group in opt.param_groups: From 12dfed77678c6ef53ec0540def0156ca852aeb5b Mon Sep 17 00:00:00 2001 From: ClassicLarry <42926649+ClassicLarry@users.noreply.github.com> Date: Sat, 23 Aug 2025 13:13:29 -0700 Subject: [PATCH 02/14] Update README.md --- records/082325_SparseAttnGate/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/records/082325_SparseAttnGate/README.md b/records/082325_SparseAttnGate/README.md index 584f392d1..89f8d1182 100644 --- a/records/082325_SparseAttnGate/README.md +++ b/records/082325_SparseAttnGate/README.md @@ -1,7 +1,7 @@ ## New record 08/23/25 1. Included WR improvements on Triton and grad batching from https://github.com/KellerJordan/modded-nanogpt/pull/109 by @byronxu99 -2. Added a sparse attention gate on the attention output to enable a context based no-op. Found the mechanism was performant with only 12 active dimensions from the residual stream. If curious, here is a related blog post from an earlier investigation into non-sparse attention gate with detailed plots: https://medium.com/@larry36d/modulating-attention-scores-cc0bcd853f06. The blog demonstrates how the attention gate reduces the need for the bos_token to function as an attention sink. This is particularly relevant in a sliding window attention context because the bos_token is not always in the context window. ROPE embeddings cause the bos_token attention sink to change based on relative distance, whereas a sparse attention gate is indifferent to distance from start of sample. Estimate of impact: 50 steps fewer, with slight increase in time per step. +2. Added a sparse attention gate on the attention output to enable a context based no-op. Found the mechanism was performant with 12 active dimensions from the residual stream. If curious, here is a related blog post from an earlier investigation into non-sparse attention gate with detailed plots: https://medium.com/@larry36d/modulating-attention-scores-cc0bcd853f06. 
The blog demonstrates how the attention gate reduces the need for the bos_token to function as an attention sink. This is particularly relevant in a sliding window attention context because the bos_token is not always in the context window. RoPE embeddings cause the bos_token attention sink to change based on relative distance, whereas a sparse attention gate is indifferent to distance from the start of the sample. Estimate of impact: 50 steps fewer, with a slight increase in time per step. (A minimal sketch of the gate appears after this list.) 3. As a follow-on from 2: Reduced the number of iterations from 1750 to 1695. 4. Reverted the lm head scaling changes made on Feb 10th: https://github.com/KellerJordan/modded-nanogpt/commit/85a0a5201f08c4d6bb288ef348bb252d9c33e132. When tested on a single A100, reverting this change drops the L2 norm of the LM head weights from 250 down to 10. The logits need to express values roughly from -10 to 10 in order to capture the range of token probabilities. Dividing by 27.5 (x.size(-1)**0.5) was causing the weights to grow substantially to accomplish this, since the residual stream was being normed prior to the lm_head. The second moment estimate of Adam depends on the parameter scale, and the Adam learning rates were likely heavily tuned prior to the Feb 10th update. (A back-of-envelope check of this scaling argument appears after this list.) If curious, there are more details near the end of this blog post: https://medium.com/@larry36d/exploration-log-exploring-initializing-transformers-with-bigram-distribution-70f9c8800b21. Estimate of impact: 5-10 steps (in this case, just a cleaner cut below 3.28). 5. Chose to keep the minimum lr at 0.1. The bos_align record decreased the minimum lr to 0.05, and a later refactor, perhaps unintentionally, moved it back to 0.1. On further testing, the impact of this value on mean loss is marginal, but a lower minimum lr appears to increase the variance of the final loss, making testing more challenging. A lower minimum lr may have higher variance because it commits to diving deep into the local space earlier, somewhat rolling the dice on whether that region is promising. On reflection, I likely originally picked 0.05 because taking the min loss over a grid search will naturally bias toward higher-variance configurations, which is the opposite of what we want.
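For reference, here is a minimal, self-contained sketch of the gate from point 2 (shapes and names are illustrative, not the exact train_gpt.py code): each head's attention output is scaled by a per-head sigmoid gate computed from only the first 12 dimensions of the residual stream, and the zero-initialized weights mean every gate starts at sigmoid(0) = 0.5.

import torch
import torch.nn.functional as F

B, T, num_heads, head_dim, gate_dim = 1, 8, 6, 128, 12   # 6 heads, as in the 768-dim model
gate_w = torch.zeros(num_heads, gate_dim)                # zero init, as in the record
x = torch.randn(B, T, num_heads * head_dim)              # residual stream entering the attention block
y = torch.randn(B, T, num_heads, head_dim)               # per-head attention output
g = torch.sigmoid(F.linear(x[..., :gate_dim], gate_w))   # (B, T, num_heads); all 0.5 at init
y = y * g.unsqueeze(-1)                                  # a head can learn a context-based no-op by driving its gate toward 0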
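And a back-of-envelope check of the scaling argument in point 4 (the 250 and 10 are the measured L2 norms quoted above; the rest follows from model_dim = 768 and the identity tanh(z) = 2*sigmoid(2*z) - 1, which makes 30*sigmoid(logits/7.5) the shifted tanh softcap 15*tanh(logits/15) + 15):

import torch

d = 768
print(d ** 0.5)   # ~27.7: the extra factor in the old denominator (7.5 * sqrt(d) vs 7.5)
print(250 / 10)   # 25.0: observed growth of the lm_head weight L2 norm, consistent with ~27.5x
z = torch.linspace(-100.0, 100.0, steps=11)
assert torch.allclose(30 * torch.sigmoid(z / 7.5), 15 * torch.tanh(z / 15) + 15, atol=1e-5)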
From 84cd472e8c7dfc6b531b71d7fcdf30834755e665 Mon Sep 17 00:00:00 2001 From: Varun Neal Srivastava Date: Wed, 27 Aug 2025 01:17:42 -0400 Subject: [PATCH 03/14] Added FA3 record --- .../17e712ee-7cf8-44c9-a784-3762e61b174c.txt | 2808 +++++++++++++++++ .../1d46fee6-b32c-48de-bd61-0a326442ec4e.txt | 2808 +++++++++++++++++ .../27d1e0d2-df15-41a9-9496-492a21943fb1.txt | 2808 +++++++++++++++++ .../7a492532-c19b-40dd-958d-fec55aa4d3fd.txt | 2808 +++++++++++++++++ records/082725_FA3/README.md | 147 + .../ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt | 2808 +++++++++++++++++ .../bb331245-5e49-4366-b902-6caff64ed8d6.txt | 2808 +++++++++++++++++ .../be1069a9-64f4-4316-bd26-4a7f5b697509.txt | 2808 +++++++++++++++++ requirements.txt | 3 +- train_gpt.py | 292 +- 10 files changed, 19963 insertions(+), 135 deletions(-) create mode 100644 records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt create mode 100644 records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt create mode 100644 records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt create mode 100644 records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt create mode 100644 records/082725_FA3/README.md create mode 100644 records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt create mode 100644 records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt create mode 100644 records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt diff --git a/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt b/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt new file mode 100644 index 000000000..b5371a4da --- /dev/null +++ b/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), 
w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = 
(LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: 
+ return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, 
alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: 
float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def 
__init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with 
torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_() # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 5) % dist.get_world_size()
+        self.scalars = nn.Parameter(torch.cat([
+            torch.ones(num_layers), # skip_weights
+            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
+            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
+            torch.ones(pad),
+        ]))
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
+
+    def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int):
+        assert input_seq.ndim == 2
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1),
+                               reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
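+# Editor's note: an illustrative reader for the shard layout _load_data_shard
+# expects (`peek_shard_header` is a hypothetical helper, not used by the run):
+# a 256-int32 header whose first three words are the magic number 20240520, the
+# version 1, and the token count, followed by that many uint16 token ids.
+def peek_shard_header(path):
+    import numpy as np # numpy is already a dependency of the data loader below
+    header = np.fromfile(path, dtype=np.int32, count=256)
+    assert header[0] == 20240520, "bad magic"
+    assert header[1] == 1, "unsupported version"
+    return int(header[2]) # number of uint16 tokens stored after the 1024-byte header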
shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + + while True: + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, pos = _load_data_shard(next(file_iter)), 0 + finder = EOSBatchFinder(tokens, world_size=world_size) + continue + + bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] + buf = torch.stack(bufs, dim=0) + _inputs = buf[:, :-1] + _targets = buf[:, 1:] + else: + batch_span = num_tokens_global + start_pos_local = pos + rank * (batch_size_local * seq_len) + end_pos_local = start_pos_local + (batch_size_local * seq_len) + + buf = tokens[start_pos_local: end_pos_local + 1] + + _inputs = buf[:-1].view(batch_size_local, seq_len) + _targets = buf[1:].view(batch_size_local, seq_len) + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) + ) + + pos += batch_span + + if new_params is not None: + # makes it possible for generator to recieve new (batch_size, seq_len) via .send() + new_batch_size, new_seq_len = new_params + assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" + batch_size = new_batch_size + seq_len = new_seq_len + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len: int = 1024 * 2 + train_batch_size: int = 24 * 8 + val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. + # optimization + num_iterations: int = 1695 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_seq_len: int = 1024 * 2
+    train_batch_size: int = 24 * 8
+    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
+    # optimization
+    num_iterations: int = 1695 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    bandwidth: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_seq_len, args.val_seq_len)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
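+# Editor's note: GPT.__init__ above tags parameters with an `lr_mul` attribute
+# (75.0 for embeddings, 5.0 for the scalars, 1.0 for the lm_head). A minimal
+# sketch of how such per-parameter multipliers can be folded into optimizer
+# groups (illustrative only; the DistAdam and Muon optimizers defined earlier
+# in this file may consume lr_mul differently):
+# param_groups = [dict(params=[p], lr=0.008 * getattr(p, "lr_mul", 1.0))
+#                 for p in scalar_params + head_params + embed_params]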
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr_and_ws(step: int):
+    x = step / (1 + args.num_iterations) # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return lr, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 60
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+for step in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, ws, ws // 2).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    lr, ws = get_lr_and_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % (world_size * args.val_seq_len) == 0
+        val_steps = args.val_tokens // (world_size * args.val_seq_len)
+        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets = next(val_loader)
+                val_loss += model(inputs, targets, ws, ws // 2)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 04:15:50 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 30C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1695 train_time:524ms step_avg:524.12ms +step:2/1695 train_time:549ms step_avg:274.51ms +step:3/1695 train_time:617ms step_avg:205.57ms +step:4/1695 train_time:709ms step_avg:177.26ms +step:5/1695 train_time:803ms step_avg:160.52ms +step:6/1695 train_time:897ms step_avg:149.44ms +step:7/1695 train_time:989ms step_avg:141.33ms +step:8/1695 train_time:1082ms step_avg:135.31ms +step:9/1695 train_time:1177ms step_avg:130.74ms +step:10/1695 train_time:1270ms step_avg:127.00ms +step:11/1695 train_time:1364ms step_avg:123.98ms +step:12/1695 train_time:1460ms step_avg:121.70ms +step:13/1695 train_time:1558ms step_avg:119.84ms +step:14/1695 train_time:1653ms step_avg:118.08ms +step:15/1695 train_time:1748ms step_avg:116.53ms +step:16/1695 train_time:1843ms step_avg:115.16ms +step:17/1695 train_time:1937ms step_avg:113.93ms +step:18/1695 train_time:2030ms step_avg:112.80ms +step:19/1695 train_time:2124ms step_avg:111.79ms +step:20/1695 train_time:2218ms step_avg:110.90ms +step:21/1695 train_time:2312ms step_avg:110.11ms +step:22/1695 train_time:2407ms step_avg:109.43ms 
+step:23/1695 train_time:2504ms step_avg:108.87ms +step:24/1695 train_time:2600ms step_avg:108.35ms +step:25/1695 train_time:2696ms step_avg:107.83ms +step:26/1695 train_time:2790ms step_avg:107.30ms +step:27/1695 train_time:2884ms step_avg:106.81ms +step:28/1695 train_time:2979ms step_avg:106.39ms +step:29/1695 train_time:3073ms step_avg:105.96ms +step:30/1695 train_time:3167ms step_avg:105.56ms +step:31/1695 train_time:3261ms step_avg:105.19ms +step:32/1695 train_time:3356ms step_avg:104.89ms +step:33/1695 train_time:3451ms step_avg:104.57ms +step:34/1695 train_time:3546ms step_avg:104.31ms +step:35/1695 train_time:3644ms step_avg:104.12ms +step:36/1695 train_time:3740ms step_avg:103.89ms +step:37/1695 train_time:3835ms step_avg:103.64ms +step:38/1695 train_time:3929ms step_avg:103.38ms +step:39/1695 train_time:4023ms step_avg:103.16ms +step:40/1695 train_time:4117ms step_avg:102.94ms +step:41/1695 train_time:4211ms step_avg:102.71ms +step:42/1695 train_time:4305ms step_avg:102.51ms +step:43/1695 train_time:4401ms step_avg:102.34ms +step:44/1695 train_time:4497ms step_avg:102.20ms +step:45/1695 train_time:4592ms step_avg:102.04ms +step:46/1695 train_time:4686ms step_avg:101.87ms +step:47/1695 train_time:4782ms step_avg:101.74ms +step:48/1695 train_time:4877ms step_avg:101.60ms +step:49/1695 train_time:4971ms step_avg:101.45ms +step:50/1695 train_time:5065ms step_avg:101.30ms +step:51/1695 train_time:5160ms step_avg:101.17ms +step:52/1695 train_time:5254ms step_avg:101.03ms +step:53/1695 train_time:5347ms step_avg:100.89ms +step:54/1695 train_time:5442ms step_avg:100.79ms +step:55/1695 train_time:5538ms step_avg:100.69ms +step:56/1695 train_time:5632ms step_avg:100.57ms +step:57/1695 train_time:5726ms step_avg:100.46ms +step:58/1695 train_time:5822ms step_avg:100.38ms +step:59/1695 train_time:5918ms step_avg:100.30ms +step:60/1695 train_time:6012ms step_avg:100.20ms +step:61/1695 train_time:6106ms step_avg:100.10ms +step:62/1695 train_time:6201ms step_avg:100.02ms +step:63/1695 train_time:6296ms step_avg:99.94ms +step:64/1695 train_time:6390ms step_avg:99.84ms +step:65/1695 train_time:6485ms step_avg:99.77ms +step:66/1695 train_time:6579ms step_avg:99.68ms +step:67/1695 train_time:6673ms step_avg:99.59ms +step:68/1695 train_time:6767ms step_avg:99.52ms +step:69/1695 train_time:6863ms step_avg:99.47ms +step:70/1695 train_time:6958ms step_avg:99.40ms +step:71/1695 train_time:7052ms step_avg:99.32ms +step:72/1695 train_time:7146ms step_avg:99.25ms +step:73/1695 train_time:7241ms step_avg:99.19ms +step:74/1695 train_time:7337ms step_avg:99.15ms +step:75/1695 train_time:7431ms step_avg:99.08ms +step:76/1695 train_time:7526ms step_avg:99.02ms +step:77/1695 train_time:7621ms step_avg:98.97ms +step:78/1695 train_time:7716ms step_avg:98.92ms +step:79/1695 train_time:7809ms step_avg:98.85ms +step:80/1695 train_time:7905ms step_avg:98.81ms +step:81/1695 train_time:8000ms step_avg:98.76ms +step:82/1695 train_time:8094ms step_avg:98.70ms +step:83/1695 train_time:8189ms step_avg:98.66ms +step:84/1695 train_time:8283ms step_avg:98.60ms +step:85/1695 train_time:8378ms step_avg:98.56ms +step:86/1695 train_time:8471ms step_avg:98.50ms +step:87/1695 train_time:8566ms step_avg:98.45ms +step:88/1695 train_time:8661ms step_avg:98.42ms +step:89/1695 train_time:8755ms step_avg:98.37ms +step:90/1695 train_time:8849ms step_avg:98.32ms +step:91/1695 train_time:8944ms step_avg:98.29ms +step:92/1695 train_time:9039ms step_avg:98.25ms +step:93/1695 train_time:9133ms step_avg:98.21ms +step:94/1695 train_time:9227ms 
step_avg:98.16ms +step:95/1695 train_time:9322ms step_avg:98.13ms +step:96/1695 train_time:9417ms step_avg:98.10ms +step:97/1695 train_time:9511ms step_avg:98.05ms +step:98/1695 train_time:9606ms step_avg:98.02ms +step:99/1695 train_time:9702ms step_avg:98.00ms +step:100/1695 train_time:9797ms step_avg:97.97ms +step:101/1695 train_time:9891ms step_avg:97.93ms +step:102/1695 train_time:9985ms step_avg:97.89ms +step:103/1695 train_time:10079ms step_avg:97.85ms +step:104/1695 train_time:10174ms step_avg:97.82ms +step:105/1695 train_time:10268ms step_avg:97.79ms +step:106/1695 train_time:10362ms step_avg:97.76ms +step:107/1695 train_time:10456ms step_avg:97.72ms +step:108/1695 train_time:10550ms step_avg:97.69ms +step:109/1695 train_time:10645ms step_avg:97.66ms +step:110/1695 train_time:10740ms step_avg:97.63ms +step:111/1695 train_time:10835ms step_avg:97.61ms +step:112/1695 train_time:10929ms step_avg:97.58ms +step:113/1695 train_time:11023ms step_avg:97.55ms +step:114/1695 train_time:11118ms step_avg:97.53ms +step:115/1695 train_time:11213ms step_avg:97.50ms +step:116/1695 train_time:11307ms step_avg:97.47ms +step:117/1695 train_time:11402ms step_avg:97.45ms +step:118/1695 train_time:11496ms step_avg:97.42ms +step:119/1695 train_time:11589ms step_avg:97.39ms +step:120/1695 train_time:11684ms step_avg:97.37ms +step:121/1695 train_time:11779ms step_avg:97.34ms +step:122/1695 train_time:11873ms step_avg:97.32ms +step:123/1695 train_time:11967ms step_avg:97.29ms +step:124/1695 train_time:12062ms step_avg:97.28ms +step:125/1695 train_time:12157ms step_avg:97.25ms +step:125/1695 val_loss:4.3113 train_time:12248ms step_avg:97.99ms +step:126/1695 train_time:12274ms step_avg:97.41ms +step:127/1695 train_time:12351ms step_avg:97.25ms +step:128/1695 train_time:12451ms step_avg:97.28ms +step:129/1695 train_time:12547ms step_avg:97.26ms +step:130/1695 train_time:12640ms step_avg:97.23ms +step:131/1695 train_time:12734ms step_avg:97.21ms +step:132/1695 train_time:12828ms step_avg:97.18ms +step:133/1695 train_time:12921ms step_avg:97.15ms +step:134/1695 train_time:13014ms step_avg:97.12ms +step:135/1695 train_time:13108ms step_avg:97.09ms +step:136/1695 train_time:13201ms step_avg:97.07ms +step:137/1695 train_time:13297ms step_avg:97.05ms +step:138/1695 train_time:13394ms step_avg:97.06ms +step:139/1695 train_time:13490ms step_avg:97.05ms +step:140/1695 train_time:13584ms step_avg:97.03ms +step:141/1695 train_time:13678ms step_avg:97.00ms +step:142/1695 train_time:13772ms step_avg:96.98ms +step:143/1695 train_time:13865ms step_avg:96.96ms +step:144/1695 train_time:13958ms step_avg:96.93ms +step:145/1695 train_time:14052ms step_avg:96.91ms +step:146/1695 train_time:14144ms step_avg:96.88ms +step:147/1695 train_time:14238ms step_avg:96.86ms +step:148/1695 train_time:14333ms step_avg:96.84ms +step:149/1695 train_time:14429ms step_avg:96.84ms +step:150/1695 train_time:14523ms step_avg:96.82ms +step:151/1695 train_time:14618ms step_avg:96.81ms +step:152/1695 train_time:14713ms step_avg:96.79ms +step:153/1695 train_time:14806ms step_avg:96.77ms +step:154/1695 train_time:14900ms step_avg:96.75ms +step:155/1695 train_time:14995ms step_avg:96.74ms +step:156/1695 train_time:15088ms step_avg:96.72ms +step:157/1695 train_time:15181ms step_avg:96.69ms +step:158/1695 train_time:15274ms step_avg:96.67ms +step:159/1695 train_time:15369ms step_avg:96.66ms +step:160/1695 train_time:15464ms step_avg:96.65ms +step:161/1695 train_time:15558ms step_avg:96.63ms +step:162/1695 train_time:15653ms step_avg:96.62ms +step:163/1695 
train_time:15748ms step_avg:96.61ms +step:164/1695 train_time:15841ms step_avg:96.59ms +step:165/1695 train_time:15936ms step_avg:96.58ms +step:166/1695 train_time:16030ms step_avg:96.57ms +step:167/1695 train_time:16124ms step_avg:96.55ms +step:168/1695 train_time:16217ms step_avg:96.53ms +step:169/1695 train_time:16311ms step_avg:96.52ms +step:170/1695 train_time:16406ms step_avg:96.50ms +step:171/1695 train_time:16500ms step_avg:96.49ms +step:172/1695 train_time:16595ms step_avg:96.48ms +step:173/1695 train_time:16964ms step_avg:98.06ms +step:174/1695 train_time:17033ms step_avg:97.89ms +step:175/1695 train_time:17126ms step_avg:97.86ms +step:176/1695 train_time:17219ms step_avg:97.83ms +step:177/1695 train_time:17312ms step_avg:97.81ms +step:178/1695 train_time:17405ms step_avg:97.78ms +step:179/1695 train_time:17499ms step_avg:97.76ms +step:180/1695 train_time:17591ms step_avg:97.73ms +step:181/1695 train_time:17684ms step_avg:97.70ms +step:182/1695 train_time:17777ms step_avg:97.68ms +step:183/1695 train_time:17876ms step_avg:97.68ms +step:184/1695 train_time:17973ms step_avg:97.68ms +step:185/1695 train_time:18068ms step_avg:97.67ms +step:186/1695 train_time:18162ms step_avg:97.64ms +step:187/1695 train_time:18256ms step_avg:97.63ms +step:188/1695 train_time:18350ms step_avg:97.61ms +step:189/1695 train_time:18443ms step_avg:97.58ms +step:190/1695 train_time:18537ms step_avg:97.56ms +step:191/1695 train_time:18630ms step_avg:97.54ms +step:192/1695 train_time:18724ms step_avg:97.52ms +step:193/1695 train_time:18818ms step_avg:97.50ms +step:194/1695 train_time:18914ms step_avg:97.49ms +step:195/1695 train_time:19010ms step_avg:97.48ms +step:196/1695 train_time:19105ms step_avg:97.47ms +step:197/1695 train_time:19198ms step_avg:97.45ms +step:198/1695 train_time:19292ms step_avg:97.44ms +step:199/1695 train_time:19387ms step_avg:97.42ms +step:200/1695 train_time:19480ms step_avg:97.40ms +step:201/1695 train_time:19574ms step_avg:97.39ms +step:202/1695 train_time:19669ms step_avg:97.37ms +step:203/1695 train_time:19763ms step_avg:97.35ms +step:204/1695 train_time:19857ms step_avg:97.34ms +step:205/1695 train_time:19952ms step_avg:97.32ms +step:206/1695 train_time:20047ms step_avg:97.31ms +step:207/1695 train_time:20140ms step_avg:97.30ms +step:208/1695 train_time:20235ms step_avg:97.28ms +step:209/1695 train_time:20329ms step_avg:97.27ms +step:210/1695 train_time:20423ms step_avg:97.25ms +step:211/1695 train_time:20516ms step_avg:97.23ms +step:212/1695 train_time:20611ms step_avg:97.22ms +step:213/1695 train_time:20706ms step_avg:97.21ms +step:214/1695 train_time:20799ms step_avg:97.19ms +step:215/1695 train_time:20894ms step_avg:97.18ms +step:216/1695 train_time:20988ms step_avg:97.17ms +step:217/1695 train_time:21082ms step_avg:97.15ms +step:218/1695 train_time:21176ms step_avg:97.14ms +step:219/1695 train_time:21271ms step_avg:97.13ms +step:220/1695 train_time:21365ms step_avg:97.12ms +step:221/1695 train_time:21459ms step_avg:97.10ms +step:222/1695 train_time:21554ms step_avg:97.09ms +step:223/1695 train_time:21648ms step_avg:97.07ms +step:224/1695 train_time:21741ms step_avg:97.06ms +step:225/1695 train_time:21835ms step_avg:97.05ms +step:226/1695 train_time:21931ms step_avg:97.04ms +step:227/1695 train_time:22024ms step_avg:97.02ms +step:228/1695 train_time:22118ms step_avg:97.01ms +step:229/1695 train_time:22213ms step_avg:97.00ms +step:230/1695 train_time:22308ms step_avg:96.99ms +step:231/1695 train_time:22401ms step_avg:96.98ms +step:232/1695 train_time:22496ms step_avg:96.96ms 
+step:233/1695 train_time:22589ms step_avg:96.95ms +step:234/1695 train_time:22683ms step_avg:96.93ms +step:235/1695 train_time:22776ms step_avg:96.92ms +step:236/1695 train_time:22872ms step_avg:96.91ms +step:237/1695 train_time:22967ms step_avg:96.91ms +step:238/1695 train_time:23061ms step_avg:96.89ms +step:239/1695 train_time:23155ms step_avg:96.88ms +step:240/1695 train_time:23249ms step_avg:96.87ms +step:241/1695 train_time:23343ms step_avg:96.86ms +step:242/1695 train_time:23437ms step_avg:96.85ms +step:243/1695 train_time:23532ms step_avg:96.84ms +step:244/1695 train_time:23626ms step_avg:96.83ms +step:245/1695 train_time:23719ms step_avg:96.81ms +step:246/1695 train_time:23814ms step_avg:96.81ms +step:247/1695 train_time:23909ms step_avg:96.80ms +step:248/1695 train_time:24003ms step_avg:96.79ms +step:249/1695 train_time:24097ms step_avg:96.78ms +step:250/1695 train_time:24191ms step_avg:96.77ms +step:250/1695 val_loss:3.9807 train_time:24284ms step_avg:97.14ms +step:251/1695 train_time:24310ms step_avg:96.85ms +step:252/1695 train_time:24384ms step_avg:96.76ms +step:253/1695 train_time:24484ms step_avg:96.78ms +step:254/1695 train_time:24580ms step_avg:96.77ms +step:255/1695 train_time:24673ms step_avg:96.76ms +step:256/1695 train_time:24766ms step_avg:96.74ms +step:257/1695 train_time:24859ms step_avg:96.73ms +step:258/1695 train_time:24953ms step_avg:96.72ms +step:259/1695 train_time:25046ms step_avg:96.70ms +step:260/1695 train_time:25139ms step_avg:96.69ms +step:261/1695 train_time:25232ms step_avg:96.68ms +step:262/1695 train_time:25328ms step_avg:96.67ms +step:263/1695 train_time:25425ms step_avg:96.67ms +step:264/1695 train_time:25521ms step_avg:96.67ms +step:265/1695 train_time:25616ms step_avg:96.67ms +step:266/1695 train_time:25710ms step_avg:96.65ms +step:267/1695 train_time:25804ms step_avg:96.64ms +step:268/1695 train_time:25898ms step_avg:96.64ms +step:269/1695 train_time:25992ms step_avg:96.62ms +step:270/1695 train_time:26085ms step_avg:96.61ms +step:271/1695 train_time:26179ms step_avg:96.60ms +step:272/1695 train_time:26273ms step_avg:96.59ms +step:273/1695 train_time:26368ms step_avg:96.59ms +step:274/1695 train_time:26465ms step_avg:96.59ms +step:275/1695 train_time:26560ms step_avg:96.58ms +step:276/1695 train_time:26655ms step_avg:96.58ms +step:277/1695 train_time:26749ms step_avg:96.57ms +step:278/1695 train_time:26843ms step_avg:96.56ms +step:279/1695 train_time:26936ms step_avg:96.55ms +step:280/1695 train_time:27030ms step_avg:96.53ms +step:281/1695 train_time:27124ms step_avg:96.53ms +step:282/1695 train_time:27218ms step_avg:96.52ms +step:283/1695 train_time:27312ms step_avg:96.51ms +step:284/1695 train_time:27407ms step_avg:96.50ms +step:285/1695 train_time:27502ms step_avg:96.50ms +step:286/1695 train_time:27596ms step_avg:96.49ms +step:287/1695 train_time:27690ms step_avg:96.48ms +step:288/1695 train_time:27785ms step_avg:96.47ms +step:289/1695 train_time:27879ms step_avg:96.47ms +step:290/1695 train_time:27971ms step_avg:96.45ms +step:291/1695 train_time:28065ms step_avg:96.44ms +step:292/1695 train_time:28160ms step_avg:96.44ms +step:293/1695 train_time:28254ms step_avg:96.43ms +step:294/1695 train_time:28348ms step_avg:96.42ms +step:295/1695 train_time:28443ms step_avg:96.42ms +step:296/1695 train_time:28538ms step_avg:96.41ms +step:297/1695 train_time:28632ms step_avg:96.40ms +step:298/1695 train_time:28726ms step_avg:96.40ms +step:299/1695 train_time:28820ms step_avg:96.39ms +step:300/1695 train_time:28913ms step_avg:96.38ms +step:301/1695 
train_time:29006ms step_avg:96.37ms +step:302/1695 train_time:29100ms step_avg:96.36ms +step:303/1695 train_time:29194ms step_avg:96.35ms +step:304/1695 train_time:29288ms step_avg:96.34ms +step:305/1695 train_time:29382ms step_avg:96.33ms +step:306/1695 train_time:29477ms step_avg:96.33ms +step:307/1695 train_time:29571ms step_avg:96.32ms +step:308/1695 train_time:29666ms step_avg:96.32ms +step:309/1695 train_time:29761ms step_avg:96.31ms +step:310/1695 train_time:29855ms step_avg:96.31ms +step:311/1695 train_time:29948ms step_avg:96.30ms +step:312/1695 train_time:30042ms step_avg:96.29ms +step:313/1695 train_time:30135ms step_avg:96.28ms +step:314/1695 train_time:30229ms step_avg:96.27ms +step:315/1695 train_time:30322ms step_avg:96.26ms +step:316/1695 train_time:30416ms step_avg:96.25ms +step:317/1695 train_time:30510ms step_avg:96.24ms +step:318/1695 train_time:30605ms step_avg:96.24ms +step:319/1695 train_time:30700ms step_avg:96.24ms +step:320/1695 train_time:30795ms step_avg:96.23ms +step:321/1695 train_time:30888ms step_avg:96.23ms +step:322/1695 train_time:30983ms step_avg:96.22ms +step:323/1695 train_time:31077ms step_avg:96.21ms +step:324/1695 train_time:31170ms step_avg:96.20ms +step:325/1695 train_time:31264ms step_avg:96.20ms +step:326/1695 train_time:31360ms step_avg:96.20ms +step:327/1695 train_time:31453ms step_avg:96.19ms +step:328/1695 train_time:31547ms step_avg:96.18ms +step:329/1695 train_time:31642ms step_avg:96.18ms +step:330/1695 train_time:31737ms step_avg:96.17ms +step:331/1695 train_time:31832ms step_avg:96.17ms +step:332/1695 train_time:31926ms step_avg:96.16ms +step:333/1695 train_time:32020ms step_avg:96.16ms +step:334/1695 train_time:32114ms step_avg:96.15ms +step:335/1695 train_time:32207ms step_avg:96.14ms +step:336/1695 train_time:32302ms step_avg:96.14ms +step:337/1695 train_time:32395ms step_avg:96.13ms +step:338/1695 train_time:32488ms step_avg:96.12ms +step:339/1695 train_time:32582ms step_avg:96.11ms +step:340/1695 train_time:32677ms step_avg:96.11ms +step:341/1695 train_time:32771ms step_avg:96.10ms +step:342/1695 train_time:32866ms step_avg:96.10ms +step:343/1695 train_time:32961ms step_avg:96.10ms +step:344/1695 train_time:33055ms step_avg:96.09ms +step:345/1695 train_time:33378ms step_avg:96.75ms +step:346/1695 train_time:33470ms step_avg:96.73ms +step:347/1695 train_time:33563ms step_avg:96.72ms +step:348/1695 train_time:33655ms step_avg:96.71ms +step:349/1695 train_time:33748ms step_avg:96.70ms +step:350/1695 train_time:33841ms step_avg:96.69ms +step:351/1695 train_time:33934ms step_avg:96.68ms +step:352/1695 train_time:34027ms step_avg:96.67ms +step:353/1695 train_time:34120ms step_avg:96.66ms +step:354/1695 train_time:34213ms step_avg:96.65ms +step:355/1695 train_time:34312ms step_avg:96.65ms +step:356/1695 train_time:34409ms step_avg:96.65ms +step:357/1695 train_time:34506ms step_avg:96.66ms +step:358/1695 train_time:34602ms step_avg:96.65ms +step:359/1695 train_time:34695ms step_avg:96.64ms +step:360/1695 train_time:34788ms step_avg:96.63ms +step:361/1695 train_time:34881ms step_avg:96.62ms +step:362/1695 train_time:34974ms step_avg:96.61ms +step:363/1695 train_time:35068ms step_avg:96.60ms +step:364/1695 train_time:35161ms step_avg:96.60ms +step:365/1695 train_time:35256ms step_avg:96.59ms +step:366/1695 train_time:35351ms step_avg:96.59ms +step:367/1695 train_time:35448ms step_avg:96.59ms +step:368/1695 train_time:35544ms step_avg:96.59ms +step:369/1695 train_time:35639ms step_avg:96.58ms +step:370/1695 train_time:35733ms step_avg:96.58ms 
+step:371/1695 train_time:35826ms step_avg:96.57ms +step:372/1695 train_time:35920ms step_avg:96.56ms +step:373/1695 train_time:36013ms step_avg:96.55ms +step:374/1695 train_time:36106ms step_avg:96.54ms +step:375/1695 train_time:36199ms step_avg:96.53ms +step:375/1695 val_loss:3.8148 train_time:36291ms step_avg:96.78ms +step:376/1695 train_time:36317ms step_avg:96.59ms +step:377/1695 train_time:36395ms step_avg:96.54ms +step:378/1695 train_time:36491ms step_avg:96.54ms +step:379/1695 train_time:36586ms step_avg:96.53ms +step:380/1695 train_time:36680ms step_avg:96.53ms +step:381/1695 train_time:36773ms step_avg:96.52ms +step:382/1695 train_time:36866ms step_avg:96.51ms +step:383/1695 train_time:36960ms step_avg:96.50ms +step:384/1695 train_time:37054ms step_avg:96.49ms +step:385/1695 train_time:37147ms step_avg:96.49ms +step:386/1695 train_time:37241ms step_avg:96.48ms +step:387/1695 train_time:37337ms step_avg:96.48ms +step:388/1695 train_time:37433ms step_avg:96.48ms +step:389/1695 train_time:37529ms step_avg:96.47ms +step:390/1695 train_time:37623ms step_avg:96.47ms +step:391/1695 train_time:37716ms step_avg:96.46ms +step:392/1695 train_time:37809ms step_avg:96.45ms +step:393/1695 train_time:37902ms step_avg:96.44ms +step:394/1695 train_time:37996ms step_avg:96.44ms +step:395/1695 train_time:38089ms step_avg:96.43ms +step:396/1695 train_time:38183ms step_avg:96.42ms +step:397/1695 train_time:38277ms step_avg:96.42ms +step:398/1695 train_time:38371ms step_avg:96.41ms +step:399/1695 train_time:38466ms step_avg:96.41ms +step:400/1695 train_time:38561ms step_avg:96.40ms +step:401/1695 train_time:38655ms step_avg:96.40ms +step:402/1695 train_time:38749ms step_avg:96.39ms +step:403/1695 train_time:38843ms step_avg:96.38ms +step:404/1695 train_time:38936ms step_avg:96.38ms +step:405/1695 train_time:39030ms step_avg:96.37ms +step:406/1695 train_time:39124ms step_avg:96.36ms +step:407/1695 train_time:39217ms step_avg:96.36ms +step:408/1695 train_time:39311ms step_avg:96.35ms +step:409/1695 train_time:39406ms step_avg:96.35ms +step:410/1695 train_time:39501ms step_avg:96.34ms +step:411/1695 train_time:39594ms step_avg:96.34ms +step:412/1695 train_time:39689ms step_avg:96.33ms +step:413/1695 train_time:39783ms step_avg:96.33ms +step:414/1695 train_time:39876ms step_avg:96.32ms +step:415/1695 train_time:39970ms step_avg:96.31ms +step:416/1695 train_time:40063ms step_avg:96.31ms +step:417/1695 train_time:40157ms step_avg:96.30ms +step:418/1695 train_time:40251ms step_avg:96.29ms +step:419/1695 train_time:40346ms step_avg:96.29ms +step:420/1695 train_time:40440ms step_avg:96.29ms +step:421/1695 train_time:40534ms step_avg:96.28ms +step:422/1695 train_time:40629ms step_avg:96.28ms +step:423/1695 train_time:40723ms step_avg:96.27ms +step:424/1695 train_time:40816ms step_avg:96.26ms +step:425/1695 train_time:40910ms step_avg:96.26ms +step:426/1695 train_time:41005ms step_avg:96.25ms +step:427/1695 train_time:41099ms step_avg:96.25ms +step:428/1695 train_time:41192ms step_avg:96.24ms +step:429/1695 train_time:41288ms step_avg:96.24ms +step:430/1695 train_time:41383ms step_avg:96.24ms +step:431/1695 train_time:41476ms step_avg:96.23ms +step:432/1695 train_time:41570ms step_avg:96.23ms +step:433/1695 train_time:41664ms step_avg:96.22ms +step:434/1695 train_time:41758ms step_avg:96.22ms +step:435/1695 train_time:41851ms step_avg:96.21ms +step:436/1695 train_time:41946ms step_avg:96.21ms +step:437/1695 train_time:42039ms step_avg:96.20ms +step:438/1695 train_time:42132ms step_avg:96.19ms +step:439/1695 
train_time:42227ms step_avg:96.19ms +step:440/1695 train_time:42322ms step_avg:96.19ms +step:441/1695 train_time:42416ms step_avg:96.18ms +step:442/1695 train_time:42510ms step_avg:96.18ms +step:443/1695 train_time:42605ms step_avg:96.17ms +step:444/1695 train_time:42698ms step_avg:96.17ms +step:445/1695 train_time:42792ms step_avg:96.16ms +step:446/1695 train_time:42887ms step_avg:96.16ms +step:447/1695 train_time:42982ms step_avg:96.16ms +step:448/1695 train_time:43075ms step_avg:96.15ms +step:449/1695 train_time:43169ms step_avg:96.14ms +step:450/1695 train_time:43264ms step_avg:96.14ms +step:451/1695 train_time:43358ms step_avg:96.14ms +step:452/1695 train_time:43452ms step_avg:96.13ms +step:453/1695 train_time:43548ms step_avg:96.13ms +step:454/1695 train_time:43642ms step_avg:96.13ms +step:455/1695 train_time:43735ms step_avg:96.12ms +step:456/1695 train_time:43828ms step_avg:96.11ms +step:457/1695 train_time:43923ms step_avg:96.11ms +step:458/1695 train_time:44018ms step_avg:96.11ms +step:459/1695 train_time:44112ms step_avg:96.10ms +step:460/1695 train_time:44205ms step_avg:96.10ms +step:461/1695 train_time:44300ms step_avg:96.09ms +step:462/1695 train_time:44393ms step_avg:96.09ms +step:463/1695 train_time:44487ms step_avg:96.08ms +step:464/1695 train_time:44582ms step_avg:96.08ms +step:465/1695 train_time:44676ms step_avg:96.08ms +step:466/1695 train_time:44770ms step_avg:96.07ms +step:467/1695 train_time:44864ms step_avg:96.07ms +step:468/1695 train_time:44959ms step_avg:96.07ms +step:469/1695 train_time:45053ms step_avg:96.06ms +step:470/1695 train_time:45147ms step_avg:96.06ms +step:471/1695 train_time:45242ms step_avg:96.05ms +step:472/1695 train_time:45335ms step_avg:96.05ms +step:473/1695 train_time:45429ms step_avg:96.04ms +step:474/1695 train_time:45523ms step_avg:96.04ms +step:475/1695 train_time:45616ms step_avg:96.03ms +step:476/1695 train_time:45710ms step_avg:96.03ms +step:477/1695 train_time:45805ms step_avg:96.03ms +step:478/1695 train_time:45899ms step_avg:96.02ms +step:479/1695 train_time:45992ms step_avg:96.02ms +step:480/1695 train_time:46087ms step_avg:96.01ms +step:481/1695 train_time:46181ms step_avg:96.01ms +step:482/1695 train_time:46275ms step_avg:96.01ms +step:483/1695 train_time:46369ms step_avg:96.00ms +step:484/1695 train_time:46464ms step_avg:96.00ms +step:485/1695 train_time:46559ms step_avg:96.00ms +step:486/1695 train_time:46653ms step_avg:95.99ms +step:487/1695 train_time:46747ms step_avg:95.99ms +step:488/1695 train_time:46842ms step_avg:95.99ms +step:489/1695 train_time:46935ms step_avg:95.98ms +step:490/1695 train_time:47029ms step_avg:95.98ms +step:491/1695 train_time:47122ms step_avg:95.97ms +step:492/1695 train_time:47216ms step_avg:95.97ms +step:493/1695 train_time:47309ms step_avg:95.96ms +step:494/1695 train_time:47403ms step_avg:95.96ms +step:495/1695 train_time:47496ms step_avg:95.95ms +step:496/1695 train_time:47591ms step_avg:95.95ms +step:497/1695 train_time:47686ms step_avg:95.95ms +step:498/1695 train_time:47780ms step_avg:95.94ms +step:499/1695 train_time:47874ms step_avg:95.94ms +step:500/1695 train_time:47968ms step_avg:95.94ms +step:500/1695 val_loss:3.7151 train_time:48060ms step_avg:96.12ms +step:501/1695 train_time:48087ms step_avg:95.98ms +step:502/1695 train_time:48163ms step_avg:95.94ms +step:503/1695 train_time:48261ms step_avg:95.95ms +step:504/1695 train_time:48355ms step_avg:95.94ms +step:505/1695 train_time:48448ms step_avg:95.94ms +step:506/1695 train_time:48542ms step_avg:95.93ms +step:507/1695 train_time:48634ms 
step_avg:95.93ms +step:508/1695 train_time:48728ms step_avg:95.92ms +step:509/1695 train_time:48820ms step_avg:95.91ms +step:510/1695 train_time:48913ms step_avg:95.91ms +step:511/1695 train_time:49007ms step_avg:95.90ms +step:512/1695 train_time:49103ms step_avg:95.90ms +step:513/1695 train_time:49198ms step_avg:95.90ms +step:514/1695 train_time:49293ms step_avg:95.90ms +step:515/1695 train_time:49388ms step_avg:95.90ms +step:516/1695 train_time:49482ms step_avg:95.90ms +step:517/1695 train_time:49575ms step_avg:95.89ms +step:518/1695 train_time:49669ms step_avg:95.89ms +step:519/1695 train_time:50009ms step_avg:96.36ms +step:520/1695 train_time:50195ms step_avg:96.53ms +step:521/1695 train_time:50287ms step_avg:96.52ms +step:522/1695 train_time:50380ms step_avg:96.51ms +step:523/1695 train_time:50473ms step_avg:96.51ms +step:524/1695 train_time:50566ms step_avg:96.50ms +step:525/1695 train_time:50658ms step_avg:96.49ms +step:526/1695 train_time:50751ms step_avg:96.49ms +step:527/1695 train_time:50845ms step_avg:96.48ms +step:528/1695 train_time:50937ms step_avg:96.47ms +step:529/1695 train_time:51035ms step_avg:96.47ms +step:530/1695 train_time:51134ms step_avg:96.48ms +step:531/1695 train_time:51232ms step_avg:96.48ms +step:532/1695 train_time:51328ms step_avg:96.48ms +step:533/1695 train_time:51422ms step_avg:96.48ms +step:534/1695 train_time:51515ms step_avg:96.47ms +step:535/1695 train_time:51609ms step_avg:96.47ms +step:536/1695 train_time:51703ms step_avg:96.46ms +step:537/1695 train_time:51795ms step_avg:96.45ms +step:538/1695 train_time:51889ms step_avg:96.45ms +step:539/1695 train_time:51982ms step_avg:96.44ms +step:540/1695 train_time:52077ms step_avg:96.44ms +step:541/1695 train_time:52173ms step_avg:96.44ms +step:542/1695 train_time:52270ms step_avg:96.44ms +step:543/1695 train_time:52365ms step_avg:96.44ms +step:544/1695 train_time:52458ms step_avg:96.43ms +step:545/1695 train_time:52552ms step_avg:96.43ms +step:546/1695 train_time:52646ms step_avg:96.42ms +step:547/1695 train_time:52739ms step_avg:96.42ms +step:548/1695 train_time:52833ms step_avg:96.41ms +step:549/1695 train_time:52926ms step_avg:96.40ms +step:550/1695 train_time:53020ms step_avg:96.40ms +step:551/1695 train_time:53114ms step_avg:96.40ms +step:552/1695 train_time:53209ms step_avg:96.39ms +step:553/1695 train_time:53303ms step_avg:96.39ms +step:554/1695 train_time:53397ms step_avg:96.38ms +step:555/1695 train_time:53491ms step_avg:96.38ms +step:556/1695 train_time:53584ms step_avg:96.37ms +step:557/1695 train_time:53678ms step_avg:96.37ms +step:558/1695 train_time:53771ms step_avg:96.36ms +step:559/1695 train_time:53865ms step_avg:96.36ms +step:560/1695 train_time:53959ms step_avg:96.35ms +step:561/1695 train_time:54053ms step_avg:96.35ms +step:562/1695 train_time:54147ms step_avg:96.35ms +step:563/1695 train_time:54242ms step_avg:96.34ms +step:564/1695 train_time:54336ms step_avg:96.34ms +step:565/1695 train_time:54430ms step_avg:96.34ms +step:566/1695 train_time:54524ms step_avg:96.33ms +step:567/1695 train_time:54618ms step_avg:96.33ms +step:568/1695 train_time:54713ms step_avg:96.33ms +step:569/1695 train_time:54809ms step_avg:96.32ms +step:570/1695 train_time:54905ms step_avg:96.33ms +step:571/1695 train_time:55001ms step_avg:96.32ms +step:572/1695 train_time:55096ms step_avg:96.32ms +step:573/1695 train_time:55193ms step_avg:96.32ms +step:574/1695 train_time:55289ms step_avg:96.32ms +step:575/1695 train_time:55385ms step_avg:96.32ms +step:576/1695 train_time:55481ms step_avg:96.32ms +step:577/1695 
train_time:55576ms step_avg:96.32ms +step:578/1695 train_time:55672ms step_avg:96.32ms +step:579/1695 train_time:55769ms step_avg:96.32ms +step:580/1695 train_time:55865ms step_avg:96.32ms +step:581/1695 train_time:55961ms step_avg:96.32ms +step:582/1695 train_time:56056ms step_avg:96.32ms +step:583/1695 train_time:56152ms step_avg:96.32ms +step:584/1695 train_time:56248ms step_avg:96.32ms +step:585/1695 train_time:56345ms step_avg:96.32ms +step:586/1695 train_time:56442ms step_avg:96.32ms +step:587/1695 train_time:56537ms step_avg:96.32ms +step:588/1695 train_time:56633ms step_avg:96.31ms +step:589/1695 train_time:56729ms step_avg:96.31ms +step:590/1695 train_time:56824ms step_avg:96.31ms +step:591/1695 train_time:56919ms step_avg:96.31ms +step:592/1695 train_time:57015ms step_avg:96.31ms +step:593/1695 train_time:57111ms step_avg:96.31ms +step:594/1695 train_time:57208ms step_avg:96.31ms +step:595/1695 train_time:57304ms step_avg:96.31ms +step:596/1695 train_time:57401ms step_avg:96.31ms +step:597/1695 train_time:57496ms step_avg:96.31ms +step:598/1695 train_time:57592ms step_avg:96.31ms +step:599/1695 train_time:57689ms step_avg:96.31ms +step:600/1695 train_time:57784ms step_avg:96.31ms +step:601/1695 train_time:57880ms step_avg:96.31ms +step:602/1695 train_time:57976ms step_avg:96.31ms +step:603/1695 train_time:58071ms step_avg:96.30ms +step:604/1695 train_time:58168ms step_avg:96.30ms +step:605/1695 train_time:58264ms step_avg:96.30ms +step:606/1695 train_time:58360ms step_avg:96.30ms +step:607/1695 train_time:58455ms step_avg:96.30ms +step:608/1695 train_time:58550ms step_avg:96.30ms +step:609/1695 train_time:58647ms step_avg:96.30ms +step:610/1695 train_time:58744ms step_avg:96.30ms +step:611/1695 train_time:58841ms step_avg:96.30ms +step:612/1695 train_time:58936ms step_avg:96.30ms +step:613/1695 train_time:59032ms step_avg:96.30ms +step:614/1695 train_time:59128ms step_avg:96.30ms +step:615/1695 train_time:59224ms step_avg:96.30ms +step:616/1695 train_time:59320ms step_avg:96.30ms +step:617/1695 train_time:59415ms step_avg:96.30ms +step:618/1695 train_time:59511ms step_avg:96.30ms +step:619/1695 train_time:59607ms step_avg:96.30ms +step:620/1695 train_time:59703ms step_avg:96.29ms +step:621/1695 train_time:59798ms step_avg:96.29ms +step:622/1695 train_time:59894ms step_avg:96.29ms +step:623/1695 train_time:59991ms step_avg:96.29ms +step:624/1695 train_time:60087ms step_avg:96.29ms +step:625/1695 train_time:60183ms step_avg:96.29ms +step:625/1695 val_loss:3.6179 train_time:60276ms step_avg:96.44ms +step:626/1695 train_time:60301ms step_avg:96.33ms +step:627/1695 train_time:60381ms step_avg:96.30ms +step:628/1695 train_time:60478ms step_avg:96.30ms +step:629/1695 train_time:60574ms step_avg:96.30ms +step:630/1695 train_time:60669ms step_avg:96.30ms +step:631/1695 train_time:60764ms step_avg:96.30ms +step:632/1695 train_time:60858ms step_avg:96.29ms +step:633/1695 train_time:60953ms step_avg:96.29ms +step:634/1695 train_time:61048ms step_avg:96.29ms +step:635/1695 train_time:61143ms step_avg:96.29ms +step:636/1695 train_time:61240ms step_avg:96.29ms +step:637/1695 train_time:61337ms step_avg:96.29ms +step:638/1695 train_time:61435ms step_avg:96.29ms +step:639/1695 train_time:61532ms step_avg:96.29ms +step:640/1695 train_time:61628ms step_avg:96.29ms +step:641/1695 train_time:61725ms step_avg:96.30ms +step:642/1695 train_time:61821ms step_avg:96.29ms +step:643/1695 train_time:61915ms step_avg:96.29ms +step:644/1695 train_time:62011ms step_avg:96.29ms +step:645/1695 train_time:62106ms 
step_avg:96.29ms +step:646/1695 train_time:62201ms step_avg:96.29ms +step:647/1695 train_time:62297ms step_avg:96.29ms +step:648/1695 train_time:62394ms step_avg:96.29ms +step:649/1695 train_time:62492ms step_avg:96.29ms +step:650/1695 train_time:62589ms step_avg:96.29ms +step:651/1695 train_time:62685ms step_avg:96.29ms +step:652/1695 train_time:62781ms step_avg:96.29ms +step:653/1695 train_time:62876ms step_avg:96.29ms +step:654/1695 train_time:62971ms step_avg:96.29ms +step:655/1695 train_time:63067ms step_avg:96.29ms +step:656/1695 train_time:63164ms step_avg:96.29ms +step:657/1695 train_time:63261ms step_avg:96.29ms +step:658/1695 train_time:63357ms step_avg:96.29ms +step:659/1695 train_time:63453ms step_avg:96.29ms +step:660/1695 train_time:63550ms step_avg:96.29ms +step:661/1695 train_time:63647ms step_avg:96.29ms +step:662/1695 train_time:63743ms step_avg:96.29ms +step:663/1695 train_time:63838ms step_avg:96.29ms +step:664/1695 train_time:63933ms step_avg:96.28ms +step:665/1695 train_time:64029ms step_avg:96.28ms +step:666/1695 train_time:64126ms step_avg:96.28ms +step:667/1695 train_time:64222ms step_avg:96.29ms +step:668/1695 train_time:64318ms step_avg:96.28ms +step:669/1695 train_time:64413ms step_avg:96.28ms +step:670/1695 train_time:64511ms step_avg:96.28ms +step:671/1695 train_time:64608ms step_avg:96.29ms +step:672/1695 train_time:64705ms step_avg:96.29ms +step:673/1695 train_time:64801ms step_avg:96.29ms +step:674/1695 train_time:64896ms step_avg:96.28ms +step:675/1695 train_time:64993ms step_avg:96.29ms +step:676/1695 train_time:65089ms step_avg:96.29ms +step:677/1695 train_time:65186ms step_avg:96.29ms +step:678/1695 train_time:65283ms step_avg:96.29ms +step:679/1695 train_time:65378ms step_avg:96.29ms +step:680/1695 train_time:65474ms step_avg:96.29ms +step:681/1695 train_time:65571ms step_avg:96.29ms +step:682/1695 train_time:65667ms step_avg:96.29ms +step:683/1695 train_time:65763ms step_avg:96.29ms +step:684/1695 train_time:65859ms step_avg:96.28ms +step:685/1695 train_time:65954ms step_avg:96.28ms +step:686/1695 train_time:66050ms step_avg:96.28ms +step:687/1695 train_time:66146ms step_avg:96.28ms +step:688/1695 train_time:66242ms step_avg:96.28ms +step:689/1695 train_time:66338ms step_avg:96.28ms +step:690/1695 train_time:66433ms step_avg:96.28ms +step:691/1695 train_time:66794ms step_avg:96.66ms +step:692/1695 train_time:66958ms step_avg:96.76ms +step:693/1695 train_time:67053ms step_avg:96.76ms +step:694/1695 train_time:67148ms step_avg:96.76ms +step:695/1695 train_time:67244ms step_avg:96.75ms +step:696/1695 train_time:67338ms step_avg:96.75ms +step:697/1695 train_time:67433ms step_avg:96.75ms +step:698/1695 train_time:67528ms step_avg:96.75ms +step:699/1695 train_time:67624ms step_avg:96.74ms +step:700/1695 train_time:67718ms step_avg:96.74ms +step:701/1695 train_time:67820ms step_avg:96.75ms +step:702/1695 train_time:67919ms step_avg:96.75ms +step:703/1695 train_time:68017ms step_avg:96.75ms +step:704/1695 train_time:68115ms step_avg:96.75ms +step:705/1695 train_time:68211ms step_avg:96.75ms +step:706/1695 train_time:68306ms step_avg:96.75ms +step:707/1695 train_time:68400ms step_avg:96.75ms +step:708/1695 train_time:68495ms step_avg:96.74ms +step:709/1695 train_time:68590ms step_avg:96.74ms +step:710/1695 train_time:68686ms step_avg:96.74ms +step:711/1695 train_time:68784ms step_avg:96.74ms +step:712/1695 train_time:68882ms step_avg:96.74ms +step:713/1695 train_time:68978ms step_avg:96.74ms +step:714/1695 train_time:69074ms step_avg:96.74ms +step:715/1695 
train_time:69170ms step_avg:96.74ms +step:716/1695 train_time:69266ms step_avg:96.74ms +step:717/1695 train_time:69361ms step_avg:96.74ms +step:718/1695 train_time:69455ms step_avg:96.73ms +step:719/1695 train_time:69551ms step_avg:96.73ms +step:720/1695 train_time:69646ms step_avg:96.73ms +step:721/1695 train_time:69742ms step_avg:96.73ms +step:722/1695 train_time:69838ms step_avg:96.73ms +step:723/1695 train_time:69935ms step_avg:96.73ms +step:724/1695 train_time:70032ms step_avg:96.73ms +step:725/1695 train_time:70129ms step_avg:96.73ms +step:726/1695 train_time:70225ms step_avg:96.73ms +step:727/1695 train_time:70322ms step_avg:96.73ms +step:728/1695 train_time:70417ms step_avg:96.73ms +step:729/1695 train_time:70512ms step_avg:96.72ms +step:730/1695 train_time:70608ms step_avg:96.72ms +step:731/1695 train_time:70704ms step_avg:96.72ms +step:732/1695 train_time:70801ms step_avg:96.72ms +step:733/1695 train_time:70896ms step_avg:96.72ms +step:734/1695 train_time:70993ms step_avg:96.72ms +step:735/1695 train_time:71090ms step_avg:96.72ms +step:736/1695 train_time:71187ms step_avg:96.72ms +step:737/1695 train_time:71284ms step_avg:96.72ms +step:738/1695 train_time:71380ms step_avg:96.72ms +step:739/1695 train_time:71475ms step_avg:96.72ms +step:740/1695 train_time:71570ms step_avg:96.72ms +step:741/1695 train_time:71667ms step_avg:96.72ms +step:742/1695 train_time:71765ms step_avg:96.72ms +step:743/1695 train_time:71861ms step_avg:96.72ms +step:744/1695 train_time:71957ms step_avg:96.72ms +step:745/1695 train_time:72053ms step_avg:96.71ms +step:746/1695 train_time:72150ms step_avg:96.72ms +step:747/1695 train_time:72247ms step_avg:96.72ms +step:748/1695 train_time:72343ms step_avg:96.72ms +step:749/1695 train_time:72439ms step_avg:96.71ms +step:750/1695 train_time:72534ms step_avg:96.71ms +step:750/1695 val_loss:3.5645 train_time:72628ms step_avg:96.84ms +step:751/1695 train_time:72654ms step_avg:96.74ms +step:752/1695 train_time:72734ms step_avg:96.72ms +step:753/1695 train_time:72834ms step_avg:96.72ms +step:754/1695 train_time:72930ms step_avg:96.72ms +step:755/1695 train_time:73026ms step_avg:96.72ms +step:756/1695 train_time:73121ms step_avg:96.72ms +step:757/1695 train_time:73215ms step_avg:96.72ms +step:758/1695 train_time:73311ms step_avg:96.72ms +step:759/1695 train_time:73407ms step_avg:96.72ms +step:760/1695 train_time:73501ms step_avg:96.71ms +step:761/1695 train_time:73598ms step_avg:96.71ms +step:762/1695 train_time:73696ms step_avg:96.71ms +step:763/1695 train_time:73795ms step_avg:96.72ms +step:764/1695 train_time:73893ms step_avg:96.72ms +step:765/1695 train_time:73990ms step_avg:96.72ms +step:766/1695 train_time:74086ms step_avg:96.72ms +step:767/1695 train_time:74182ms step_avg:96.72ms +step:768/1695 train_time:74276ms step_avg:96.71ms +step:769/1695 train_time:74372ms step_avg:96.71ms +step:770/1695 train_time:74468ms step_avg:96.71ms +step:771/1695 train_time:74564ms step_avg:96.71ms +step:772/1695 train_time:74661ms step_avg:96.71ms +step:773/1695 train_time:74758ms step_avg:96.71ms +step:774/1695 train_time:74854ms step_avg:96.71ms +step:775/1695 train_time:74951ms step_avg:96.71ms +step:776/1695 train_time:75048ms step_avg:96.71ms +step:777/1695 train_time:75143ms step_avg:96.71ms +step:778/1695 train_time:75238ms step_avg:96.71ms +step:779/1695 train_time:75333ms step_avg:96.71ms +step:780/1695 train_time:75429ms step_avg:96.70ms +step:781/1695 train_time:75526ms step_avg:96.70ms +step:782/1695 train_time:75622ms step_avg:96.70ms +step:783/1695 train_time:75718ms 
step_avg:96.70ms +step:784/1695 train_time:75815ms step_avg:96.70ms +step:785/1695 train_time:75911ms step_avg:96.70ms +step:786/1695 train_time:76008ms step_avg:96.70ms +step:787/1695 train_time:76104ms step_avg:96.70ms +step:788/1695 train_time:76199ms step_avg:96.70ms +step:789/1695 train_time:76295ms step_avg:96.70ms +step:790/1695 train_time:76390ms step_avg:96.70ms +step:791/1695 train_time:76486ms step_avg:96.70ms +step:792/1695 train_time:76582ms step_avg:96.69ms +step:793/1695 train_time:76678ms step_avg:96.69ms +step:794/1695 train_time:76774ms step_avg:96.69ms +step:795/1695 train_time:76871ms step_avg:96.69ms +step:796/1695 train_time:76969ms step_avg:96.69ms +step:797/1695 train_time:77067ms step_avg:96.70ms +step:798/1695 train_time:77163ms step_avg:96.70ms +step:799/1695 train_time:77259ms step_avg:96.69ms +step:800/1695 train_time:77355ms step_avg:96.69ms +step:801/1695 train_time:77450ms step_avg:96.69ms +step:802/1695 train_time:77546ms step_avg:96.69ms +step:803/1695 train_time:77642ms step_avg:96.69ms +step:804/1695 train_time:77738ms step_avg:96.69ms +step:805/1695 train_time:77834ms step_avg:96.69ms +step:806/1695 train_time:77931ms step_avg:96.69ms +step:807/1695 train_time:78029ms step_avg:96.69ms +step:808/1695 train_time:78126ms step_avg:96.69ms +step:809/1695 train_time:78222ms step_avg:96.69ms +step:810/1695 train_time:78317ms step_avg:96.69ms +step:811/1695 train_time:78413ms step_avg:96.69ms +step:812/1695 train_time:78509ms step_avg:96.69ms +step:813/1695 train_time:78604ms step_avg:96.68ms +step:814/1695 train_time:78699ms step_avg:96.68ms +step:815/1695 train_time:78795ms step_avg:96.68ms +step:816/1695 train_time:78891ms step_avg:96.68ms +step:817/1695 train_time:78988ms step_avg:96.68ms +step:818/1695 train_time:79085ms step_avg:96.68ms +step:819/1695 train_time:79182ms step_avg:96.68ms +step:820/1695 train_time:79278ms step_avg:96.68ms +step:821/1695 train_time:79373ms step_avg:96.68ms +step:822/1695 train_time:79469ms step_avg:96.68ms +step:823/1695 train_time:79565ms step_avg:96.68ms +step:824/1695 train_time:79660ms step_avg:96.67ms +step:825/1695 train_time:79755ms step_avg:96.67ms +step:826/1695 train_time:79852ms step_avg:96.67ms +step:827/1695 train_time:79949ms step_avg:96.67ms +step:828/1695 train_time:80045ms step_avg:96.67ms +step:829/1695 train_time:80142ms step_avg:96.67ms +step:830/1695 train_time:80237ms step_avg:96.67ms +step:831/1695 train_time:80333ms step_avg:96.67ms +step:832/1695 train_time:80430ms step_avg:96.67ms +step:833/1695 train_time:80527ms step_avg:96.67ms +step:834/1695 train_time:80624ms step_avg:96.67ms +step:835/1695 train_time:80719ms step_avg:96.67ms +step:836/1695 train_time:80815ms step_avg:96.67ms +step:837/1695 train_time:80911ms step_avg:96.67ms +step:838/1695 train_time:81007ms step_avg:96.67ms +step:839/1695 train_time:81103ms step_avg:96.67ms +step:840/1695 train_time:81199ms step_avg:96.67ms +step:841/1695 train_time:81295ms step_avg:96.66ms +step:842/1695 train_time:81392ms step_avg:96.66ms +step:843/1695 train_time:81488ms step_avg:96.66ms +step:844/1695 train_time:81583ms step_avg:96.66ms +step:845/1695 train_time:81678ms step_avg:96.66ms +step:846/1695 train_time:81773ms step_avg:96.66ms +step:847/1695 train_time:81869ms step_avg:96.66ms +step:848/1695 train_time:81965ms step_avg:96.66ms +step:849/1695 train_time:82062ms step_avg:96.66ms +step:850/1695 train_time:82158ms step_avg:96.66ms +step:851/1695 train_time:82254ms step_avg:96.66ms +step:852/1695 train_time:82350ms step_avg:96.66ms +step:853/1695 
train_time:82447ms step_avg:96.66ms +step:854/1695 train_time:82542ms step_avg:96.65ms +step:855/1695 train_time:82637ms step_avg:96.65ms +step:856/1695 train_time:82733ms step_avg:96.65ms +step:857/1695 train_time:82829ms step_avg:96.65ms +step:858/1695 train_time:82925ms step_avg:96.65ms +step:859/1695 train_time:83021ms step_avg:96.65ms +step:860/1695 train_time:83116ms step_avg:96.65ms +step:861/1695 train_time:83212ms step_avg:96.65ms +step:862/1695 train_time:83309ms step_avg:96.65ms +step:863/1695 train_time:83635ms step_avg:96.91ms +step:864/1695 train_time:83834ms step_avg:97.03ms +step:865/1695 train_time:83928ms step_avg:97.03ms +step:866/1695 train_time:84024ms step_avg:97.03ms +step:867/1695 train_time:84118ms step_avg:97.02ms +step:868/1695 train_time:84213ms step_avg:97.02ms +step:869/1695 train_time:84308ms step_avg:97.02ms +step:870/1695 train_time:84404ms step_avg:97.02ms +step:871/1695 train_time:84498ms step_avg:97.01ms +step:872/1695 train_time:84593ms step_avg:97.01ms +step:873/1695 train_time:84695ms step_avg:97.02ms +step:874/1695 train_time:84795ms step_avg:97.02ms +step:875/1695 train_time:84893ms step_avg:97.02ms +step:875/1695 val_loss:3.5224 train_time:84986ms step_avg:97.13ms +step:876/1695 train_time:85012ms step_avg:97.05ms +step:877/1695 train_time:85093ms step_avg:97.03ms +step:878/1695 train_time:85195ms step_avg:97.03ms +step:879/1695 train_time:85293ms step_avg:97.03ms +step:880/1695 train_time:85390ms step_avg:97.03ms +step:881/1695 train_time:85485ms step_avg:97.03ms +step:882/1695 train_time:85580ms step_avg:97.03ms +step:883/1695 train_time:85675ms step_avg:97.03ms +step:884/1695 train_time:85769ms step_avg:97.02ms +step:885/1695 train_time:85864ms step_avg:97.02ms +step:886/1695 train_time:85959ms step_avg:97.02ms +step:887/1695 train_time:86057ms step_avg:97.02ms +step:888/1695 train_time:86155ms step_avg:97.02ms +step:889/1695 train_time:86253ms step_avg:97.02ms +step:890/1695 train_time:86350ms step_avg:97.02ms +step:891/1695 train_time:86447ms step_avg:97.02ms +step:892/1695 train_time:86542ms step_avg:97.02ms +step:893/1695 train_time:86637ms step_avg:97.02ms +step:894/1695 train_time:86732ms step_avg:97.02ms +step:895/1695 train_time:86827ms step_avg:97.01ms +step:896/1695 train_time:86922ms step_avg:97.01ms +step:897/1695 train_time:87018ms step_avg:97.01ms +step:898/1695 train_time:87115ms step_avg:97.01ms +step:899/1695 train_time:87213ms step_avg:97.01ms +step:900/1695 train_time:87311ms step_avg:97.01ms +step:901/1695 train_time:87409ms step_avg:97.01ms +step:902/1695 train_time:87506ms step_avg:97.01ms +step:903/1695 train_time:87601ms step_avg:97.01ms +step:904/1695 train_time:87696ms step_avg:97.01ms +step:905/1695 train_time:87792ms step_avg:97.01ms +step:906/1695 train_time:87887ms step_avg:97.01ms +step:907/1695 train_time:87983ms step_avg:97.00ms +step:908/1695 train_time:88079ms step_avg:97.00ms +step:909/1695 train_time:88175ms step_avg:97.00ms +step:910/1695 train_time:88271ms step_avg:97.00ms +step:911/1695 train_time:88367ms step_avg:97.00ms +step:912/1695 train_time:88463ms step_avg:97.00ms +step:913/1695 train_time:88558ms step_avg:97.00ms +step:914/1695 train_time:88654ms step_avg:97.00ms +step:915/1695 train_time:88750ms step_avg:96.99ms +step:916/1695 train_time:88846ms step_avg:96.99ms +step:917/1695 train_time:88942ms step_avg:96.99ms +step:918/1695 train_time:89038ms step_avg:96.99ms +step:919/1695 train_time:89134ms step_avg:96.99ms +step:920/1695 train_time:89231ms step_avg:96.99ms +step:921/1695 train_time:89327ms 
step_avg:96.99ms +step:922/1695 train_time:89422ms step_avg:96.99ms +step:923/1695 train_time:89518ms step_avg:96.99ms +step:924/1695 train_time:89614ms step_avg:96.98ms +step:925/1695 train_time:89709ms step_avg:96.98ms +step:926/1695 train_time:89806ms step_avg:96.98ms +step:927/1695 train_time:89901ms step_avg:96.98ms +step:928/1695 train_time:89997ms step_avg:96.98ms +step:929/1695 train_time:90093ms step_avg:96.98ms +step:930/1695 train_time:90190ms step_avg:96.98ms +step:931/1695 train_time:90287ms step_avg:96.98ms +step:932/1695 train_time:90384ms step_avg:96.98ms +step:933/1695 train_time:90479ms step_avg:96.98ms +step:934/1695 train_time:90575ms step_avg:96.98ms +step:935/1695 train_time:90671ms step_avg:96.97ms +step:936/1695 train_time:90767ms step_avg:96.97ms +step:937/1695 train_time:90863ms step_avg:96.97ms +step:938/1695 train_time:90958ms step_avg:96.97ms +step:939/1695 train_time:91055ms step_avg:96.97ms +step:940/1695 train_time:91150ms step_avg:96.97ms +step:941/1695 train_time:91246ms step_avg:96.97ms +step:942/1695 train_time:91342ms step_avg:96.97ms +step:943/1695 train_time:91438ms step_avg:96.96ms +step:944/1695 train_time:91533ms step_avg:96.96ms +step:945/1695 train_time:91630ms step_avg:96.96ms +step:946/1695 train_time:91727ms step_avg:96.96ms +step:947/1695 train_time:91824ms step_avg:96.96ms +step:948/1695 train_time:91920ms step_avg:96.96ms +step:949/1695 train_time:92015ms step_avg:96.96ms +step:950/1695 train_time:92112ms step_avg:96.96ms +step:951/1695 train_time:92208ms step_avg:96.96ms +step:952/1695 train_time:92305ms step_avg:96.96ms +step:953/1695 train_time:92401ms step_avg:96.96ms +step:954/1695 train_time:92496ms step_avg:96.96ms +step:955/1695 train_time:92592ms step_avg:96.96ms +step:956/1695 train_time:92689ms step_avg:96.95ms +step:957/1695 train_time:92785ms step_avg:96.95ms +step:958/1695 train_time:92881ms step_avg:96.95ms +step:959/1695 train_time:92976ms step_avg:96.95ms +step:960/1695 train_time:93072ms step_avg:96.95ms +step:961/1695 train_time:93168ms step_avg:96.95ms +step:962/1695 train_time:93265ms step_avg:96.95ms +step:963/1695 train_time:93362ms step_avg:96.95ms +step:964/1695 train_time:93458ms step_avg:96.95ms +step:965/1695 train_time:93555ms step_avg:96.95ms +step:966/1695 train_time:93652ms step_avg:96.95ms +step:967/1695 train_time:93748ms step_avg:96.95ms +step:968/1695 train_time:93844ms step_avg:96.95ms +step:969/1695 train_time:93939ms step_avg:96.94ms +step:970/1695 train_time:94035ms step_avg:96.94ms +step:971/1695 train_time:94132ms step_avg:96.94ms +step:972/1695 train_time:94229ms step_avg:96.94ms +step:973/1695 train_time:94326ms step_avg:96.94ms +step:974/1695 train_time:94422ms step_avg:96.94ms +step:975/1695 train_time:94517ms step_avg:96.94ms +step:976/1695 train_time:94614ms step_avg:96.94ms +step:977/1695 train_time:94712ms step_avg:96.94ms +step:978/1695 train_time:94808ms step_avg:96.94ms +step:979/1695 train_time:94904ms step_avg:96.94ms +step:980/1695 train_time:95000ms step_avg:96.94ms +step:981/1695 train_time:95095ms step_avg:96.94ms +step:982/1695 train_time:95192ms step_avg:96.94ms +step:983/1695 train_time:95289ms step_avg:96.94ms +step:984/1695 train_time:95386ms step_avg:96.94ms +step:985/1695 train_time:95482ms step_avg:96.94ms +step:986/1695 train_time:95577ms step_avg:96.93ms +step:987/1695 train_time:95673ms step_avg:96.93ms +step:988/1695 train_time:95770ms step_avg:96.93ms +step:989/1695 train_time:95866ms step_avg:96.93ms +step:990/1695 train_time:95963ms step_avg:96.93ms +step:991/1695 
train_time:96059ms step_avg:96.93ms +step:992/1695 train_time:96154ms step_avg:96.93ms +step:993/1695 train_time:96251ms step_avg:96.93ms +step:994/1695 train_time:96348ms step_avg:96.93ms +step:995/1695 train_time:96445ms step_avg:96.93ms +step:996/1695 train_time:96539ms step_avg:96.93ms +step:997/1695 train_time:96634ms step_avg:96.93ms +step:998/1695 train_time:96732ms step_avg:96.93ms +step:999/1695 train_time:96829ms step_avg:96.93ms +step:1000/1695 train_time:96926ms step_avg:96.93ms +step:1000/1695 val_loss:3.4843 train_time:97020ms step_avg:97.02ms +step:1001/1695 train_time:97046ms step_avg:96.95ms +step:1002/1695 train_time:97126ms step_avg:96.93ms +step:1003/1695 train_time:97224ms step_avg:96.93ms +step:1004/1695 train_time:97319ms step_avg:96.93ms +step:1005/1695 train_time:97415ms step_avg:96.93ms +step:1006/1695 train_time:97511ms step_avg:96.93ms +step:1007/1695 train_time:97607ms step_avg:96.93ms +step:1008/1695 train_time:97701ms step_avg:96.93ms +step:1009/1695 train_time:97796ms step_avg:96.92ms +step:1010/1695 train_time:97892ms step_avg:96.92ms +step:1011/1695 train_time:97989ms step_avg:96.92ms +step:1012/1695 train_time:98087ms step_avg:96.92ms +step:1013/1695 train_time:98185ms step_avg:96.93ms +step:1014/1695 train_time:98281ms step_avg:96.92ms +step:1015/1695 train_time:98377ms step_avg:96.92ms +step:1016/1695 train_time:98474ms step_avg:96.92ms +step:1017/1695 train_time:98570ms step_avg:96.92ms +step:1018/1695 train_time:98665ms step_avg:96.92ms +step:1019/1695 train_time:98760ms step_avg:96.92ms +step:1020/1695 train_time:98855ms step_avg:96.92ms +step:1021/1695 train_time:98951ms step_avg:96.92ms +step:1022/1695 train_time:99049ms step_avg:96.92ms +step:1023/1695 train_time:99145ms step_avg:96.92ms +step:1024/1695 train_time:99241ms step_avg:96.92ms +step:1025/1695 train_time:99337ms step_avg:96.91ms +step:1026/1695 train_time:99434ms step_avg:96.91ms +step:1027/1695 train_time:99531ms step_avg:96.91ms +step:1028/1695 train_time:99627ms step_avg:96.91ms +step:1029/1695 train_time:99724ms step_avg:96.91ms +step:1030/1695 train_time:99817ms step_avg:96.91ms +step:1031/1695 train_time:99913ms step_avg:96.91ms +step:1032/1695 train_time:100009ms step_avg:96.91ms +step:1033/1695 train_time:100105ms step_avg:96.91ms +step:1034/1695 train_time:100202ms step_avg:96.91ms +step:1035/1695 train_time:100298ms step_avg:96.91ms +step:1036/1695 train_time:100628ms step_avg:97.13ms +step:1037/1695 train_time:100810ms step_avg:97.21ms +step:1038/1695 train_time:100904ms step_avg:97.21ms +step:1039/1695 train_time:100998ms step_avg:97.21ms +step:1040/1695 train_time:101093ms step_avg:97.20ms +step:1041/1695 train_time:101188ms step_avg:97.20ms +step:1042/1695 train_time:101283ms step_avg:97.20ms +step:1043/1695 train_time:101377ms step_avg:97.20ms +step:1044/1695 train_time:101472ms step_avg:97.20ms +step:1045/1695 train_time:101567ms step_avg:97.19ms +step:1046/1695 train_time:101666ms step_avg:97.20ms +step:1047/1695 train_time:101765ms step_avg:97.20ms +step:1048/1695 train_time:101862ms step_avg:97.20ms +step:1049/1695 train_time:101959ms step_avg:97.20ms +step:1050/1695 train_time:102054ms step_avg:97.19ms +step:1051/1695 train_time:102149ms step_avg:97.19ms +step:1052/1695 train_time:102244ms step_avg:97.19ms +step:1053/1695 train_time:102339ms step_avg:97.19ms +step:1054/1695 train_time:102434ms step_avg:97.19ms +step:1055/1695 train_time:102529ms step_avg:97.18ms +step:1056/1695 train_time:102626ms step_avg:97.18ms +step:1057/1695 train_time:102724ms step_avg:97.18ms 
+step:1058/1695 train_time:102821ms step_avg:97.18ms +step:1059/1695 train_time:102917ms step_avg:97.18ms +step:1060/1695 train_time:103015ms step_avg:97.18ms +step:1061/1695 train_time:103111ms step_avg:97.18ms +step:1062/1695 train_time:103207ms step_avg:97.18ms +step:1063/1695 train_time:103302ms step_avg:97.18ms +step:1064/1695 train_time:103397ms step_avg:97.18ms +step:1065/1695 train_time:103493ms step_avg:97.18ms +step:1066/1695 train_time:103589ms step_avg:97.18ms +step:1067/1695 train_time:103684ms step_avg:97.17ms +step:1068/1695 train_time:103781ms step_avg:97.17ms +step:1069/1695 train_time:103877ms step_avg:97.17ms +step:1070/1695 train_time:103974ms step_avg:97.17ms +step:1071/1695 train_time:104070ms step_avg:97.17ms +step:1072/1695 train_time:104166ms step_avg:97.17ms +step:1073/1695 train_time:104262ms step_avg:97.17ms +step:1074/1695 train_time:104357ms step_avg:97.17ms +step:1075/1695 train_time:104453ms step_avg:97.17ms +step:1076/1695 train_time:104549ms step_avg:97.16ms +step:1077/1695 train_time:104644ms step_avg:97.16ms +step:1078/1695 train_time:104739ms step_avg:97.16ms +step:1079/1695 train_time:104836ms step_avg:97.16ms +step:1080/1695 train_time:104933ms step_avg:97.16ms +step:1081/1695 train_time:105029ms step_avg:97.16ms +step:1082/1695 train_time:105125ms step_avg:97.16ms +step:1083/1695 train_time:105221ms step_avg:97.16ms +step:1084/1695 train_time:105316ms step_avg:97.15ms +step:1085/1695 train_time:105412ms step_avg:97.15ms +step:1086/1695 train_time:105508ms step_avg:97.15ms +step:1087/1695 train_time:105604ms step_avg:97.15ms +step:1088/1695 train_time:105699ms step_avg:97.15ms +step:1089/1695 train_time:105795ms step_avg:97.15ms +step:1090/1695 train_time:105893ms step_avg:97.15ms +step:1091/1695 train_time:105990ms step_avg:97.15ms +step:1092/1695 train_time:106086ms step_avg:97.15ms +step:1093/1695 train_time:106181ms step_avg:97.15ms +step:1094/1695 train_time:106277ms step_avg:97.15ms +step:1095/1695 train_time:106373ms step_avg:97.14ms +step:1096/1695 train_time:106469ms step_avg:97.14ms +step:1097/1695 train_time:106565ms step_avg:97.14ms +step:1098/1695 train_time:106661ms step_avg:97.14ms +step:1099/1695 train_time:106756ms step_avg:97.14ms +step:1100/1695 train_time:106854ms step_avg:97.14ms +step:1101/1695 train_time:106950ms step_avg:97.14ms +step:1102/1695 train_time:107046ms step_avg:97.14ms +step:1103/1695 train_time:107142ms step_avg:97.14ms +step:1104/1695 train_time:107237ms step_avg:97.13ms +step:1105/1695 train_time:107333ms step_avg:97.13ms +step:1106/1695 train_time:107429ms step_avg:97.13ms +step:1107/1695 train_time:107526ms step_avg:97.13ms +step:1108/1695 train_time:107622ms step_avg:97.13ms +step:1109/1695 train_time:107718ms step_avg:97.13ms +step:1110/1695 train_time:107814ms step_avg:97.13ms +step:1111/1695 train_time:107912ms step_avg:97.13ms +step:1112/1695 train_time:108009ms step_avg:97.13ms +step:1113/1695 train_time:108105ms step_avg:97.13ms +step:1114/1695 train_time:108200ms step_avg:97.13ms +step:1115/1695 train_time:108296ms step_avg:97.13ms +step:1116/1695 train_time:108393ms step_avg:97.13ms +step:1117/1695 train_time:108490ms step_avg:97.13ms +step:1118/1695 train_time:108587ms step_avg:97.13ms +step:1119/1695 train_time:108683ms step_avg:97.13ms +step:1120/1695 train_time:108778ms step_avg:97.12ms +step:1121/1695 train_time:108875ms step_avg:97.12ms +step:1122/1695 train_time:108970ms step_avg:97.12ms +step:1123/1695 train_time:109068ms step_avg:97.12ms +step:1124/1695 train_time:109164ms step_avg:97.12ms 
+step:1125/1695 train_time:109260ms step_avg:97.12ms +step:1125/1695 val_loss:3.4352 train_time:109353ms step_avg:97.20ms +step:1126/1695 train_time:109379ms step_avg:97.14ms +step:1127/1695 train_time:109456ms step_avg:97.12ms +step:1128/1695 train_time:109554ms step_avg:97.12ms +step:1129/1695 train_time:109650ms step_avg:97.12ms +step:1130/1695 train_time:109745ms step_avg:97.12ms +step:1131/1695 train_time:109840ms step_avg:97.12ms +step:1132/1695 train_time:109934ms step_avg:97.12ms +step:1133/1695 train_time:110031ms step_avg:97.11ms +step:1134/1695 train_time:110129ms step_avg:97.12ms +step:1135/1695 train_time:110228ms step_avg:97.12ms +step:1136/1695 train_time:110328ms step_avg:97.12ms +step:1137/1695 train_time:110431ms step_avg:97.12ms +step:1138/1695 train_time:110532ms step_avg:97.13ms +step:1139/1695 train_time:110631ms step_avg:97.13ms +step:1140/1695 train_time:110729ms step_avg:97.13ms +step:1141/1695 train_time:110826ms step_avg:97.13ms +step:1142/1695 train_time:110923ms step_avg:97.13ms +step:1143/1695 train_time:111020ms step_avg:97.13ms +step:1144/1695 train_time:111116ms step_avg:97.13ms +step:1145/1695 train_time:111214ms step_avg:97.13ms +step:1146/1695 train_time:111312ms step_avg:97.13ms +step:1147/1695 train_time:111412ms step_avg:97.13ms +step:1148/1695 train_time:111511ms step_avg:97.14ms +step:1149/1695 train_time:111611ms step_avg:97.14ms +step:1150/1695 train_time:111710ms step_avg:97.14ms +step:1151/1695 train_time:111808ms step_avg:97.14ms +step:1152/1695 train_time:111907ms step_avg:97.14ms +step:1153/1695 train_time:112006ms step_avg:97.14ms +step:1154/1695 train_time:112104ms step_avg:97.14ms +step:1155/1695 train_time:112201ms step_avg:97.14ms +step:1156/1695 train_time:112299ms step_avg:97.14ms +step:1157/1695 train_time:112396ms step_avg:97.14ms +step:1158/1695 train_time:112495ms step_avg:97.15ms +step:1159/1695 train_time:112593ms step_avg:97.15ms +step:1160/1695 train_time:112691ms step_avg:97.15ms +step:1161/1695 train_time:112790ms step_avg:97.15ms +step:1162/1695 train_time:112888ms step_avg:97.15ms +step:1163/1695 train_time:112986ms step_avg:97.15ms +step:1164/1695 train_time:113084ms step_avg:97.15ms +step:1165/1695 train_time:113181ms step_avg:97.15ms +step:1166/1695 train_time:113279ms step_avg:97.15ms +step:1167/1695 train_time:113376ms step_avg:97.15ms +step:1168/1695 train_time:113473ms step_avg:97.15ms +step:1169/1695 train_time:113571ms step_avg:97.15ms +step:1170/1695 train_time:113669ms step_avg:97.15ms +step:1171/1695 train_time:113767ms step_avg:97.15ms +step:1172/1695 train_time:113864ms step_avg:97.15ms +step:1173/1695 train_time:113962ms step_avg:97.15ms +step:1174/1695 train_time:114060ms step_avg:97.15ms +step:1175/1695 train_time:114157ms step_avg:97.16ms +step:1176/1695 train_time:114255ms step_avg:97.16ms +step:1177/1695 train_time:114352ms step_avg:97.16ms +step:1178/1695 train_time:114451ms step_avg:97.16ms +step:1179/1695 train_time:114550ms step_avg:97.16ms +step:1180/1695 train_time:114648ms step_avg:97.16ms +step:1181/1695 train_time:114747ms step_avg:97.16ms +step:1182/1695 train_time:114844ms step_avg:97.16ms +step:1183/1695 train_time:114942ms step_avg:97.16ms +step:1184/1695 train_time:115039ms step_avg:97.16ms +step:1185/1695 train_time:115136ms step_avg:97.16ms +step:1186/1695 train_time:115234ms step_avg:97.16ms +step:1187/1695 train_time:115331ms step_avg:97.16ms +step:1188/1695 train_time:115430ms step_avg:97.16ms +step:1189/1695 train_time:115529ms step_avg:97.16ms +step:1190/1695 train_time:115627ms 
step_avg:97.17ms +step:1191/1695 train_time:115725ms step_avg:97.17ms +step:1192/1695 train_time:115823ms step_avg:97.17ms +step:1193/1695 train_time:115922ms step_avg:97.17ms +step:1194/1695 train_time:116019ms step_avg:97.17ms +step:1195/1695 train_time:116117ms step_avg:97.17ms +step:1196/1695 train_time:116214ms step_avg:97.17ms +step:1197/1695 train_time:116311ms step_avg:97.17ms +step:1198/1695 train_time:116409ms step_avg:97.17ms +step:1199/1695 train_time:116507ms step_avg:97.17ms +step:1200/1695 train_time:116604ms step_avg:97.17ms +step:1201/1695 train_time:116702ms step_avg:97.17ms +step:1202/1695 train_time:116799ms step_avg:97.17ms +step:1203/1695 train_time:116897ms step_avg:97.17ms +step:1204/1695 train_time:116995ms step_avg:97.17ms +step:1205/1695 train_time:117093ms step_avg:97.17ms +step:1206/1695 train_time:117191ms step_avg:97.17ms +step:1207/1695 train_time:117289ms step_avg:97.17ms +step:1208/1695 train_time:117624ms step_avg:97.37ms +step:1209/1695 train_time:117814ms step_avg:97.45ms +step:1210/1695 train_time:117909ms step_avg:97.45ms +step:1211/1695 train_time:118006ms step_avg:97.44ms +step:1212/1695 train_time:118103ms step_avg:97.44ms +step:1213/1695 train_time:118199ms step_avg:97.44ms +step:1214/1695 train_time:118295ms step_avg:97.44ms +step:1215/1695 train_time:118393ms step_avg:97.44ms +step:1216/1695 train_time:118490ms step_avg:97.44ms +step:1217/1695 train_time:118587ms step_avg:97.44ms +step:1218/1695 train_time:118689ms step_avg:97.45ms +step:1219/1695 train_time:118792ms step_avg:97.45ms +step:1220/1695 train_time:118892ms step_avg:97.45ms +step:1221/1695 train_time:118991ms step_avg:97.45ms +step:1222/1695 train_time:119090ms step_avg:97.46ms +step:1223/1695 train_time:119190ms step_avg:97.46ms +step:1224/1695 train_time:119287ms step_avg:97.46ms +step:1225/1695 train_time:119384ms step_avg:97.46ms +step:1226/1695 train_time:119481ms step_avg:97.46ms +step:1227/1695 train_time:119577ms step_avg:97.45ms +step:1228/1695 train_time:119673ms step_avg:97.45ms +step:1229/1695 train_time:119773ms step_avg:97.46ms +step:1230/1695 train_time:119873ms step_avg:97.46ms +step:1231/1695 train_time:119972ms step_avg:97.46ms +step:1232/1695 train_time:120070ms step_avg:97.46ms +step:1233/1695 train_time:120168ms step_avg:97.46ms +step:1234/1695 train_time:120268ms step_avg:97.46ms +step:1235/1695 train_time:120366ms step_avg:97.46ms +step:1236/1695 train_time:120464ms step_avg:97.46ms +step:1237/1695 train_time:120561ms step_avg:97.46ms +step:1238/1695 train_time:120658ms step_avg:97.46ms +step:1239/1695 train_time:120756ms step_avg:97.46ms +step:1240/1695 train_time:120853ms step_avg:97.46ms +step:1241/1695 train_time:120952ms step_avg:97.46ms +step:1242/1695 train_time:121051ms step_avg:97.46ms +step:1243/1695 train_time:121150ms step_avg:97.47ms +step:1244/1695 train_time:121249ms step_avg:97.47ms +step:1245/1695 train_time:121347ms step_avg:97.47ms +step:1246/1695 train_time:121444ms step_avg:97.47ms +step:1247/1695 train_time:121542ms step_avg:97.47ms +step:1248/1695 train_time:121640ms step_avg:97.47ms +step:1249/1695 train_time:121738ms step_avg:97.47ms +step:1250/1695 train_time:121835ms step_avg:97.47ms +step:1250/1695 val_loss:3.3886 train_time:121930ms step_avg:97.54ms +step:1251/1695 train_time:121956ms step_avg:97.49ms +step:1252/1695 train_time:122037ms step_avg:97.47ms +step:1253/1695 train_time:122135ms step_avg:97.47ms +step:1254/1695 train_time:122231ms step_avg:97.47ms +step:1255/1695 train_time:122327ms step_avg:97.47ms +step:1256/1695 
train_time:122424ms step_avg:97.47ms +step:1257/1695 train_time:122520ms step_avg:97.47ms +step:1258/1695 train_time:122616ms step_avg:97.47ms +step:1259/1695 train_time:122713ms step_avg:97.47ms +step:1260/1695 train_time:122809ms step_avg:97.47ms +step:1261/1695 train_time:122913ms step_avg:97.47ms +step:1262/1695 train_time:123013ms step_avg:97.47ms +step:1263/1695 train_time:123111ms step_avg:97.48ms +step:1264/1695 train_time:123209ms step_avg:97.48ms +step:1265/1695 train_time:123306ms step_avg:97.48ms +step:1266/1695 train_time:123403ms step_avg:97.47ms +step:1267/1695 train_time:123499ms step_avg:97.47ms +step:1268/1695 train_time:123596ms step_avg:97.47ms +step:1269/1695 train_time:123693ms step_avg:97.47ms +step:1270/1695 train_time:123789ms step_avg:97.47ms +step:1271/1695 train_time:123889ms step_avg:97.47ms +step:1272/1695 train_time:123988ms step_avg:97.47ms +step:1273/1695 train_time:124086ms step_avg:97.48ms +step:1274/1695 train_time:124185ms step_avg:97.48ms +step:1275/1695 train_time:124284ms step_avg:97.48ms +step:1276/1695 train_time:124381ms step_avg:97.48ms +step:1277/1695 train_time:124479ms step_avg:97.48ms +step:1278/1695 train_time:124576ms step_avg:97.48ms +step:1279/1695 train_time:124673ms step_avg:97.48ms +step:1280/1695 train_time:124770ms step_avg:97.48ms +step:1281/1695 train_time:124868ms step_avg:97.48ms +step:1282/1695 train_time:124966ms step_avg:97.48ms +step:1283/1695 train_time:125065ms step_avg:97.48ms +step:1284/1695 train_time:125165ms step_avg:97.48ms +step:1285/1695 train_time:125264ms step_avg:97.48ms +step:1286/1695 train_time:125363ms step_avg:97.48ms +step:1287/1695 train_time:125460ms step_avg:97.48ms +step:1288/1695 train_time:125558ms step_avg:97.48ms +step:1289/1695 train_time:125656ms step_avg:97.48ms +step:1290/1695 train_time:125755ms step_avg:97.48ms +step:1291/1695 train_time:125853ms step_avg:97.48ms +step:1292/1695 train_time:125950ms step_avg:97.48ms +step:1293/1695 train_time:126048ms step_avg:97.49ms +step:1294/1695 train_time:126147ms step_avg:97.49ms +step:1295/1695 train_time:126245ms step_avg:97.49ms +step:1296/1695 train_time:126344ms step_avg:97.49ms +step:1297/1695 train_time:126440ms step_avg:97.49ms +step:1298/1695 train_time:126538ms step_avg:97.49ms +step:1299/1695 train_time:126635ms step_avg:97.49ms +step:1300/1695 train_time:126731ms step_avg:97.49ms +step:1301/1695 train_time:126829ms step_avg:97.49ms +step:1302/1695 train_time:126927ms step_avg:97.49ms +step:1303/1695 train_time:127025ms step_avg:97.49ms +step:1304/1695 train_time:127124ms step_avg:97.49ms +step:1305/1695 train_time:127222ms step_avg:97.49ms +step:1306/1695 train_time:127320ms step_avg:97.49ms +step:1307/1695 train_time:127418ms step_avg:97.49ms +step:1308/1695 train_time:127515ms step_avg:97.49ms +step:1309/1695 train_time:127612ms step_avg:97.49ms +step:1310/1695 train_time:127710ms step_avg:97.49ms +step:1311/1695 train_time:127807ms step_avg:97.49ms +step:1312/1695 train_time:127905ms step_avg:97.49ms +step:1313/1695 train_time:128004ms step_avg:97.49ms +step:1314/1695 train_time:128104ms step_avg:97.49ms +step:1315/1695 train_time:128203ms step_avg:97.49ms +step:1316/1695 train_time:128301ms step_avg:97.49ms +step:1317/1695 train_time:128399ms step_avg:97.49ms +step:1318/1695 train_time:128498ms step_avg:97.49ms +step:1319/1695 train_time:128596ms step_avg:97.50ms +step:1320/1695 train_time:128695ms step_avg:97.50ms +step:1321/1695 train_time:128792ms step_avg:97.50ms +step:1322/1695 train_time:128889ms step_avg:97.50ms +step:1323/1695 
train_time:128986ms step_avg:97.50ms +step:1324/1695 train_time:129085ms step_avg:97.50ms +step:1325/1695 train_time:129184ms step_avg:97.50ms +step:1326/1695 train_time:129282ms step_avg:97.50ms +step:1327/1695 train_time:129379ms step_avg:97.50ms +step:1328/1695 train_time:129477ms step_avg:97.50ms +step:1329/1695 train_time:129574ms step_avg:97.50ms +step:1330/1695 train_time:129672ms step_avg:97.50ms +step:1331/1695 train_time:129769ms step_avg:97.50ms +step:1332/1695 train_time:129866ms step_avg:97.50ms +step:1333/1695 train_time:129964ms step_avg:97.50ms +step:1334/1695 train_time:130063ms step_avg:97.50ms +step:1335/1695 train_time:130162ms step_avg:97.50ms +step:1336/1695 train_time:130260ms step_avg:97.50ms +step:1337/1695 train_time:130358ms step_avg:97.50ms +step:1338/1695 train_time:130455ms step_avg:97.50ms +step:1339/1695 train_time:130553ms step_avg:97.50ms +step:1340/1695 train_time:130652ms step_avg:97.50ms +step:1341/1695 train_time:130749ms step_avg:97.50ms +step:1342/1695 train_time:130846ms step_avg:97.50ms +step:1343/1695 train_time:130944ms step_avg:97.50ms +step:1344/1695 train_time:131043ms step_avg:97.50ms +step:1345/1695 train_time:131142ms step_avg:97.50ms +step:1346/1695 train_time:131241ms step_avg:97.50ms +step:1347/1695 train_time:131340ms step_avg:97.51ms +step:1348/1695 train_time:131438ms step_avg:97.51ms +step:1349/1695 train_time:131537ms step_avg:97.51ms +step:1350/1695 train_time:131636ms step_avg:97.51ms +step:1351/1695 train_time:131734ms step_avg:97.51ms +step:1352/1695 train_time:131832ms step_avg:97.51ms +step:1353/1695 train_time:131930ms step_avg:97.51ms +step:1354/1695 train_time:132028ms step_avg:97.51ms +step:1355/1695 train_time:132126ms step_avg:97.51ms +step:1356/1695 train_time:132223ms step_avg:97.51ms +step:1357/1695 train_time:132321ms step_avg:97.51ms +step:1358/1695 train_time:132419ms step_avg:97.51ms +step:1359/1695 train_time:132517ms step_avg:97.51ms +step:1360/1695 train_time:132614ms step_avg:97.51ms +step:1361/1695 train_time:132711ms step_avg:97.51ms +step:1362/1695 train_time:132808ms step_avg:97.51ms +step:1363/1695 train_time:132905ms step_avg:97.51ms +step:1364/1695 train_time:133004ms step_avg:97.51ms +step:1365/1695 train_time:133102ms step_avg:97.51ms +step:1366/1695 train_time:133200ms step_avg:97.51ms +step:1367/1695 train_time:133297ms step_avg:97.51ms +step:1368/1695 train_time:133393ms step_avg:97.51ms +step:1369/1695 train_time:133491ms step_avg:97.51ms +step:1370/1695 train_time:133589ms step_avg:97.51ms +step:1371/1695 train_time:133687ms step_avg:97.51ms +step:1372/1695 train_time:133785ms step_avg:97.51ms +step:1373/1695 train_time:133884ms step_avg:97.51ms +step:1374/1695 train_time:133982ms step_avg:97.51ms +step:1375/1695 train_time:134080ms step_avg:97.51ms +step:1375/1695 val_loss:3.3495 train_time:134174ms step_avg:97.58ms +step:1376/1695 train_time:134203ms step_avg:97.53ms +step:1377/1695 train_time:134283ms step_avg:97.52ms +step:1378/1695 train_time:134384ms step_avg:97.52ms +step:1379/1695 train_time:134483ms step_avg:97.52ms +step:1380/1695 train_time:134581ms step_avg:97.52ms +step:1381/1695 train_time:134941ms step_avg:97.71ms +step:1382/1695 train_time:135109ms step_avg:97.76ms +step:1383/1695 train_time:135205ms step_avg:97.76ms +step:1384/1695 train_time:135302ms step_avg:97.76ms +step:1385/1695 train_time:135398ms step_avg:97.76ms +step:1386/1695 train_time:135494ms step_avg:97.76ms +step:1387/1695 train_time:135591ms step_avg:97.76ms +step:1388/1695 train_time:135686ms step_avg:97.76ms 
+step:1389/1695 train_time:135783ms step_avg:97.76ms +step:1390/1695 train_time:135880ms step_avg:97.76ms +step:1391/1695 train_time:135981ms step_avg:97.76ms +step:1392/1695 train_time:136087ms step_avg:97.76ms +step:1393/1695 train_time:136186ms step_avg:97.76ms +step:1394/1695 train_time:136284ms step_avg:97.76ms +step:1395/1695 train_time:136381ms step_avg:97.76ms +step:1396/1695 train_time:136480ms step_avg:97.76ms +step:1397/1695 train_time:136577ms step_avg:97.76ms +step:1398/1695 train_time:136674ms step_avg:97.76ms +step:1399/1695 train_time:136770ms step_avg:97.76ms +step:1400/1695 train_time:136866ms step_avg:97.76ms +step:1401/1695 train_time:136964ms step_avg:97.76ms +step:1402/1695 train_time:137065ms step_avg:97.76ms +step:1403/1695 train_time:137165ms step_avg:97.77ms +step:1404/1695 train_time:137264ms step_avg:97.77ms +step:1405/1695 train_time:137362ms step_avg:97.77ms +step:1406/1695 train_time:137460ms step_avg:97.77ms +step:1407/1695 train_time:137558ms step_avg:97.77ms +step:1408/1695 train_time:137656ms step_avg:97.77ms +step:1409/1695 train_time:137752ms step_avg:97.77ms +step:1410/1695 train_time:137849ms step_avg:97.77ms +step:1411/1695 train_time:137946ms step_avg:97.76ms +step:1412/1695 train_time:138046ms step_avg:97.77ms +step:1413/1695 train_time:138144ms step_avg:97.77ms +step:1414/1695 train_time:138244ms step_avg:97.77ms +step:1415/1695 train_time:138342ms step_avg:97.77ms +step:1416/1695 train_time:138440ms step_avg:97.77ms +step:1417/1695 train_time:138538ms step_avg:97.77ms +step:1418/1695 train_time:138637ms step_avg:97.77ms +step:1419/1695 train_time:138735ms step_avg:97.77ms +step:1420/1695 train_time:138832ms step_avg:97.77ms +step:1421/1695 train_time:138928ms step_avg:97.77ms +step:1422/1695 train_time:139025ms step_avg:97.77ms +step:1423/1695 train_time:139123ms step_avg:97.77ms +step:1424/1695 train_time:139221ms step_avg:97.77ms +step:1425/1695 train_time:139320ms step_avg:97.77ms +step:1426/1695 train_time:139418ms step_avg:97.77ms +step:1427/1695 train_time:139515ms step_avg:97.77ms +step:1428/1695 train_time:139612ms step_avg:97.77ms +step:1429/1695 train_time:139709ms step_avg:97.77ms +step:1430/1695 train_time:139807ms step_avg:97.77ms +step:1431/1695 train_time:139905ms step_avg:97.77ms +step:1432/1695 train_time:140004ms step_avg:97.77ms +step:1433/1695 train_time:140104ms step_avg:97.77ms +step:1434/1695 train_time:140202ms step_avg:97.77ms +step:1435/1695 train_time:140300ms step_avg:97.77ms +step:1436/1695 train_time:140397ms step_avg:97.77ms +step:1437/1695 train_time:140496ms step_avg:97.77ms +step:1438/1695 train_time:140594ms step_avg:97.77ms +step:1439/1695 train_time:140691ms step_avg:97.77ms +step:1440/1695 train_time:140788ms step_avg:97.77ms +step:1441/1695 train_time:140885ms step_avg:97.77ms +step:1442/1695 train_time:140982ms step_avg:97.77ms +step:1443/1695 train_time:141080ms step_avg:97.77ms +step:1444/1695 train_time:141178ms step_avg:97.77ms +step:1445/1695 train_time:141275ms step_avg:97.77ms +step:1446/1695 train_time:141373ms step_avg:97.77ms +step:1447/1695 train_time:141471ms step_avg:97.77ms +step:1448/1695 train_time:141569ms step_avg:97.77ms +step:1449/1695 train_time:141667ms step_avg:97.77ms +step:1450/1695 train_time:141765ms step_avg:97.77ms +step:1451/1695 train_time:141864ms step_avg:97.77ms +step:1452/1695 train_time:141962ms step_avg:97.77ms +step:1453/1695 train_time:142060ms step_avg:97.77ms +step:1454/1695 train_time:142159ms step_avg:97.77ms +step:1455/1695 train_time:142257ms step_avg:97.77ms 
+step:1456/1695 train_time:142355ms step_avg:97.77ms +step:1457/1695 train_time:142452ms step_avg:97.77ms +step:1458/1695 train_time:142550ms step_avg:97.77ms +step:1459/1695 train_time:142648ms step_avg:97.77ms +step:1460/1695 train_time:142746ms step_avg:97.77ms +step:1461/1695 train_time:142844ms step_avg:97.77ms +step:1462/1695 train_time:142943ms step_avg:97.77ms +step:1463/1695 train_time:143041ms step_avg:97.77ms +step:1464/1695 train_time:143141ms step_avg:97.77ms +step:1465/1695 train_time:143241ms step_avg:97.78ms +step:1466/1695 train_time:143339ms step_avg:97.78ms +step:1467/1695 train_time:143438ms step_avg:97.78ms +step:1468/1695 train_time:143537ms step_avg:97.78ms +step:1469/1695 train_time:143635ms step_avg:97.78ms +step:1470/1695 train_time:143732ms step_avg:97.78ms +step:1471/1695 train_time:143829ms step_avg:97.78ms +step:1472/1695 train_time:143927ms step_avg:97.78ms +step:1473/1695 train_time:144025ms step_avg:97.78ms +step:1474/1695 train_time:144123ms step_avg:97.78ms +step:1475/1695 train_time:144221ms step_avg:97.78ms +step:1476/1695 train_time:144319ms step_avg:97.78ms +step:1477/1695 train_time:144418ms step_avg:97.78ms +step:1478/1695 train_time:144515ms step_avg:97.78ms +step:1479/1695 train_time:144613ms step_avg:97.78ms +step:1480/1695 train_time:144711ms step_avg:97.78ms +step:1481/1695 train_time:144809ms step_avg:97.78ms +step:1482/1695 train_time:144906ms step_avg:97.78ms +step:1483/1695 train_time:145004ms step_avg:97.78ms +step:1484/1695 train_time:145101ms step_avg:97.78ms +step:1485/1695 train_time:145200ms step_avg:97.78ms +step:1486/1695 train_time:145298ms step_avg:97.78ms +step:1487/1695 train_time:145397ms step_avg:97.78ms +step:1488/1695 train_time:145495ms step_avg:97.78ms +step:1489/1695 train_time:145593ms step_avg:97.78ms +step:1490/1695 train_time:145690ms step_avg:97.78ms +step:1491/1695 train_time:145787ms step_avg:97.78ms +step:1492/1695 train_time:145884ms step_avg:97.78ms +step:1493/1695 train_time:145982ms step_avg:97.78ms +step:1494/1695 train_time:146079ms step_avg:97.78ms +step:1495/1695 train_time:146177ms step_avg:97.78ms +step:1496/1695 train_time:146274ms step_avg:97.78ms +step:1497/1695 train_time:146372ms step_avg:97.78ms +step:1498/1695 train_time:146469ms step_avg:97.78ms +step:1499/1695 train_time:146567ms step_avg:97.78ms +step:1500/1695 train_time:146665ms step_avg:97.78ms +step:1500/1695 val_loss:3.3162 train_time:146761ms step_avg:97.84ms +step:1501/1695 train_time:146787ms step_avg:97.79ms +step:1502/1695 train_time:146870ms step_avg:97.78ms +step:1503/1695 train_time:146968ms step_avg:97.78ms +step:1504/1695 train_time:147065ms step_avg:97.78ms +step:1505/1695 train_time:147162ms step_avg:97.78ms +step:1506/1695 train_time:147259ms step_avg:97.78ms +step:1507/1695 train_time:147355ms step_avg:97.78ms +step:1508/1695 train_time:147452ms step_avg:97.78ms +step:1509/1695 train_time:147548ms step_avg:97.78ms +step:1510/1695 train_time:147645ms step_avg:97.78ms +step:1511/1695 train_time:147745ms step_avg:97.78ms +step:1512/1695 train_time:147846ms step_avg:97.78ms +step:1513/1695 train_time:147945ms step_avg:97.78ms +step:1514/1695 train_time:148043ms step_avg:97.78ms +step:1515/1695 train_time:148141ms step_avg:97.78ms +step:1516/1695 train_time:148240ms step_avg:97.78ms +step:1517/1695 train_time:148337ms step_avg:97.78ms +step:1518/1695 train_time:148433ms step_avg:97.78ms +step:1519/1695 train_time:148529ms step_avg:97.78ms +step:1520/1695 train_time:148627ms step_avg:97.78ms +step:1521/1695 train_time:148726ms 
step_avg:97.78ms +step:1522/1695 train_time:148825ms step_avg:97.78ms +step:1523/1695 train_time:148924ms step_avg:97.78ms +step:1524/1695 train_time:149022ms step_avg:97.78ms +step:1525/1695 train_time:149121ms step_avg:97.78ms +step:1526/1695 train_time:149219ms step_avg:97.78ms +step:1527/1695 train_time:149317ms step_avg:97.78ms +step:1528/1695 train_time:149413ms step_avg:97.78ms +step:1529/1695 train_time:149511ms step_avg:97.78ms +step:1530/1695 train_time:149608ms step_avg:97.78ms +step:1531/1695 train_time:149705ms step_avg:97.78ms +step:1532/1695 train_time:149804ms step_avg:97.78ms +step:1533/1695 train_time:149903ms step_avg:97.78ms +step:1534/1695 train_time:150001ms step_avg:97.78ms +step:1535/1695 train_time:150100ms step_avg:97.78ms +step:1536/1695 train_time:150198ms step_avg:97.79ms +step:1537/1695 train_time:150296ms step_avg:97.79ms +step:1538/1695 train_time:150393ms step_avg:97.78ms +step:1539/1695 train_time:150491ms step_avg:97.78ms +step:1540/1695 train_time:150588ms step_avg:97.78ms +step:1541/1695 train_time:150685ms step_avg:97.78ms +step:1542/1695 train_time:150783ms step_avg:97.78ms +step:1543/1695 train_time:150881ms step_avg:97.78ms +step:1544/1695 train_time:150981ms step_avg:97.79ms +step:1545/1695 train_time:151081ms step_avg:97.79ms +step:1546/1695 train_time:151181ms step_avg:97.79ms +step:1547/1695 train_time:151279ms step_avg:97.79ms +step:1548/1695 train_time:151378ms step_avg:97.79ms +step:1549/1695 train_time:151478ms step_avg:97.79ms +step:1550/1695 train_time:151576ms step_avg:97.79ms +step:1551/1695 train_time:151674ms step_avg:97.79ms +step:1552/1695 train_time:152071ms step_avg:97.98ms +step:1553/1695 train_time:152147ms step_avg:97.97ms +step:1554/1695 train_time:152242ms step_avg:97.97ms +step:1555/1695 train_time:152339ms step_avg:97.97ms +step:1556/1695 train_time:152435ms step_avg:97.97ms +step:1557/1695 train_time:152532ms step_avg:97.97ms +step:1558/1695 train_time:152628ms step_avg:97.96ms +step:1559/1695 train_time:152724ms step_avg:97.96ms +step:1560/1695 train_time:152821ms step_avg:97.96ms +step:1561/1695 train_time:152919ms step_avg:97.96ms +step:1562/1695 train_time:153018ms step_avg:97.96ms +step:1563/1695 train_time:153125ms step_avg:97.97ms +step:1564/1695 train_time:153224ms step_avg:97.97ms +step:1565/1695 train_time:153322ms step_avg:97.97ms +step:1566/1695 train_time:153420ms step_avg:97.97ms +step:1567/1695 train_time:153517ms step_avg:97.97ms +step:1568/1695 train_time:153616ms step_avg:97.97ms +step:1569/1695 train_time:153713ms step_avg:97.97ms +step:1570/1695 train_time:153810ms step_avg:97.97ms +step:1571/1695 train_time:153906ms step_avg:97.97ms +step:1572/1695 train_time:154004ms step_avg:97.97ms +step:1573/1695 train_time:154103ms step_avg:97.97ms +step:1574/1695 train_time:154203ms step_avg:97.97ms +step:1575/1695 train_time:154302ms step_avg:97.97ms +step:1576/1695 train_time:154400ms step_avg:97.97ms +step:1577/1695 train_time:154498ms step_avg:97.97ms +step:1578/1695 train_time:154596ms step_avg:97.97ms +step:1579/1695 train_time:154693ms step_avg:97.97ms +step:1580/1695 train_time:154791ms step_avg:97.97ms +step:1581/1695 train_time:154888ms step_avg:97.97ms +step:1582/1695 train_time:154985ms step_avg:97.97ms +step:1583/1695 train_time:155083ms step_avg:97.97ms +step:1584/1695 train_time:155182ms step_avg:97.97ms +step:1585/1695 train_time:155280ms step_avg:97.97ms +step:1586/1695 train_time:155380ms step_avg:97.97ms +step:1587/1695 train_time:155478ms step_avg:97.97ms +step:1588/1695 train_time:155575ms 
step_avg:97.97ms +step:1589/1695 train_time:155673ms step_avg:97.97ms +step:1590/1695 train_time:155771ms step_avg:97.97ms +step:1591/1695 train_time:155868ms step_avg:97.97ms +step:1592/1695 train_time:155965ms step_avg:97.97ms +step:1593/1695 train_time:156063ms step_avg:97.97ms +step:1594/1695 train_time:156160ms step_avg:97.97ms +step:1595/1695 train_time:156258ms step_avg:97.97ms +step:1596/1695 train_time:156357ms step_avg:97.97ms +step:1597/1695 train_time:156456ms step_avg:97.97ms +step:1598/1695 train_time:156554ms step_avg:97.97ms +step:1599/1695 train_time:156651ms step_avg:97.97ms +step:1600/1695 train_time:156748ms step_avg:97.97ms +step:1601/1695 train_time:156845ms step_avg:97.97ms +step:1602/1695 train_time:156943ms step_avg:97.97ms +step:1603/1695 train_time:157041ms step_avg:97.97ms +step:1604/1695 train_time:157140ms step_avg:97.97ms +step:1605/1695 train_time:157239ms step_avg:97.97ms +step:1606/1695 train_time:157337ms step_avg:97.97ms +step:1607/1695 train_time:157436ms step_avg:97.97ms +step:1608/1695 train_time:157535ms step_avg:97.97ms +step:1609/1695 train_time:157632ms step_avg:97.97ms +step:1610/1695 train_time:157730ms step_avg:97.97ms +step:1611/1695 train_time:157827ms step_avg:97.97ms +step:1612/1695 train_time:157925ms step_avg:97.97ms +step:1613/1695 train_time:158022ms step_avg:97.97ms +step:1614/1695 train_time:158119ms step_avg:97.97ms +step:1615/1695 train_time:158218ms step_avg:97.97ms +step:1616/1695 train_time:158317ms step_avg:97.97ms +step:1617/1695 train_time:158415ms step_avg:97.97ms +step:1618/1695 train_time:158513ms step_avg:97.97ms +step:1619/1695 train_time:158611ms step_avg:97.97ms +step:1620/1695 train_time:158708ms step_avg:97.97ms +step:1621/1695 train_time:158805ms step_avg:97.97ms +step:1622/1695 train_time:158903ms step_avg:97.97ms +step:1623/1695 train_time:159001ms step_avg:97.97ms +step:1624/1695 train_time:159099ms step_avg:97.97ms +step:1625/1695 train_time:159197ms step_avg:97.97ms +step:1625/1695 val_loss:3.2895 train_time:159292ms step_avg:98.03ms +step:1626/1695 train_time:159319ms step_avg:97.98ms +step:1627/1695 train_time:159403ms step_avg:97.97ms +step:1628/1695 train_time:159501ms step_avg:97.97ms +step:1629/1695 train_time:159598ms step_avg:97.97ms +step:1630/1695 train_time:159696ms step_avg:97.97ms +step:1631/1695 train_time:159793ms step_avg:97.97ms +step:1632/1695 train_time:159890ms step_avg:97.97ms +step:1633/1695 train_time:159986ms step_avg:97.97ms +step:1634/1695 train_time:160083ms step_avg:97.97ms +step:1635/1695 train_time:160179ms step_avg:97.97ms +step:1636/1695 train_time:160280ms step_avg:97.97ms +step:1637/1695 train_time:160382ms step_avg:97.97ms +step:1638/1695 train_time:160482ms step_avg:97.97ms +step:1639/1695 train_time:160580ms step_avg:97.97ms +step:1640/1695 train_time:160678ms step_avg:97.97ms +step:1641/1695 train_time:160774ms step_avg:97.97ms +step:1642/1695 train_time:160871ms step_avg:97.97ms +step:1643/1695 train_time:160969ms step_avg:97.97ms +step:1644/1695 train_time:161066ms step_avg:97.97ms +step:1645/1695 train_time:161162ms step_avg:97.97ms +step:1646/1695 train_time:161261ms step_avg:97.97ms +step:1647/1695 train_time:161362ms step_avg:97.97ms +step:1648/1695 train_time:161461ms step_avg:97.97ms +step:1649/1695 train_time:161559ms step_avg:97.97ms +step:1650/1695 train_time:161657ms step_avg:97.97ms +step:1651/1695 train_time:161755ms step_avg:97.97ms +step:1652/1695 train_time:161852ms step_avg:97.97ms +step:1653/1695 train_time:161951ms step_avg:97.97ms +step:1654/1695 
train_time:162049ms step_avg:97.97ms +step:1655/1695 train_time:162146ms step_avg:97.97ms +step:1656/1695 train_time:162244ms step_avg:97.97ms +step:1657/1695 train_time:162342ms step_avg:97.97ms +step:1658/1695 train_time:162440ms step_avg:97.97ms +step:1659/1695 train_time:162538ms step_avg:97.97ms +step:1660/1695 train_time:162636ms step_avg:97.97ms +step:1661/1695 train_time:162734ms step_avg:97.97ms +step:1662/1695 train_time:162831ms step_avg:97.97ms +step:1663/1695 train_time:162928ms step_avg:97.97ms +step:1664/1695 train_time:163026ms step_avg:97.97ms +step:1665/1695 train_time:163122ms step_avg:97.97ms +step:1666/1695 train_time:163221ms step_avg:97.97ms +step:1667/1695 train_time:163320ms step_avg:97.97ms +step:1668/1695 train_time:163418ms step_avg:97.97ms +step:1669/1695 train_time:163518ms step_avg:97.97ms +step:1670/1695 train_time:163617ms step_avg:97.97ms +step:1671/1695 train_time:163715ms step_avg:97.97ms +step:1672/1695 train_time:163812ms step_avg:97.97ms +step:1673/1695 train_time:163911ms step_avg:97.97ms +step:1674/1695 train_time:164008ms step_avg:97.97ms +step:1675/1695 train_time:164105ms step_avg:97.97ms +step:1676/1695 train_time:164202ms step_avg:97.97ms +step:1677/1695 train_time:164300ms step_avg:97.97ms +step:1678/1695 train_time:164397ms step_avg:97.97ms +step:1679/1695 train_time:164495ms step_avg:97.97ms +step:1680/1695 train_time:164593ms step_avg:97.97ms +step:1681/1695 train_time:164691ms step_avg:97.97ms +step:1682/1695 train_time:164789ms step_avg:97.97ms +step:1683/1695 train_time:164886ms step_avg:97.97ms +step:1684/1695 train_time:164984ms step_avg:97.97ms +step:1685/1695 train_time:165081ms step_avg:97.97ms +step:1686/1695 train_time:165179ms step_avg:97.97ms +step:1687/1695 train_time:165278ms step_avg:97.97ms +step:1688/1695 train_time:165377ms step_avg:97.97ms +step:1689/1695 train_time:165475ms step_avg:97.97ms +step:1690/1695 train_time:165573ms step_avg:97.97ms +step:1691/1695 train_time:165672ms step_avg:97.97ms +step:1692/1695 train_time:165770ms step_avg:97.97ms +step:1693/1695 train_time:165868ms step_avg:97.97ms +step:1694/1695 train_time:165966ms step_avg:97.97ms +step:1695/1695 train_time:166062ms step_avg:97.97ms +step:1695/1695 val_loss:3.2782 train_time:166157ms step_avg:98.03ms +peak memory allocated: 34361 MiB reserved: 49576 MiB diff --git a/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt b/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt new file mode 100644 index 000000000..32ec95b7e --- /dev/null +++ b/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
-----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
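The pair of custom ops above wraps torch._scaled_mm so the matmul runs with both operands cast to FP8 (e4m3 in the forward pass, e5m2 for the incoming gradient) under fixed scale factors, while presenting an ordinary autograd interface; judging by the (50304, 768) shape named in the backward comment, it serves the output-projection layer. A minimal sketch of exercising the op, assuming an FP8-capable GPU (e.g. H100) and placeholder scales of 1.0 (illustration only, not part of the record):

    x = torch.randn(256, 768, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(1024, 768, device="cuda", dtype=torch.bfloat16)
    out, x_f8, w_f8 = torch.ops.nanogpt.mm(x, w, 1.0, 1.0, 1.0)
    # Compare against a full-precision reference: the error is small but nonzero,
    # since both operands were rounded to FP8 before the multiply.
    ref = x.float() @ w.float().T
    print(((out.float() - ref).norm() / ref.norm()).item())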
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_1(A: torch.Tensor, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = A @ A.T
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert out.size(-2) == M, "Output matrix has incorrect shape"
+    assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_1_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        K=K,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+    )
+    return out
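ns_line_1 exploits the symmetry of C = A @ A.T: tiles strictly on the wrong side of the diagonal return early, and each computed tile is stored twice (once mirrored), so roughly half the tile-level work of a general matmul is skipped. A quick correctness check against plain PyTorch, as a sketch (arbitrary shapes, CUDA device assumed; illustration only, not part of the record):

    A = torch.randn(6, 1024, 768, device="cuda", dtype=torch.bfloat16)
    C = torch.empty(6, 1024, 1024, device="cuda", dtype=torch.bfloat16)
    ns_line_1(A, out=C)
    # The kernel accumulates in fp32, so this should agree with the reference
    # up to bfloat16 rounding of the stored output.
    print((C.float() - A.float() @ A.float().mT).abs().max())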
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_2_kernel(
+    A_ptr, C_ptr,
+    M,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    alpha, beta,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A
+    # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    # Load block of A to add (corresponds to the current block of C)
+    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
+    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
+    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)
+
+    # Apply alpha and beta
+    accumulator *= alpha
+    accumulator += a_add * beta
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert M == K, "Input matrix must be square"
+    assert out.size(-2) == M
+    assert out.size(-1) == M
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_2_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+        alpha=alpha,
+        beta=beta,
+    )
+    return out
+
+@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower
+def newton_schulz_triton(G: torch.Tensor):
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+
+    # Allocate buffers
+    X = X.contiguous()
+    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
+    B = torch.empty_like(A)
+    C = torch.empty_like(X)
+
+    ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm
+
+    # Perform the NS iterations
+    for _ in range(5):
+        ns_line_1(X, out=A) # A = X @ X.mT
+        ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A
+        ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X
+        X, C = C, X # Swap references to avoid unnecessary copies
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
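newton_schulz_triton runs five steps of the quintic iteration X <- a*X + (b*A + c*A^2) @ X with A = X @ X.mT, which pushes the singular values of the pre-normalized input toward 1, i.e. toward the nearest semi-orthogonal matrix; ns_line_1 and ns_line_2 are simply fused kernels for the two symmetric products. A dense PyTorch rendering of the same recurrence, usable as a reference check (a sketch with coefficients copied from above, not part of the record):

    def newton_schulz_reference(G: torch.Tensor, steps: int = 5):
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.bfloat16()
        if G.size(-2) > G.size(-1):
            X = X.mT
        X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
        for _ in range(steps):
            A = X @ X.mT            # what ns_line_1 computes
            B = b * A + c * A @ A   # what ns_line_2 computes
            X = a * X + B @ X       # the addmm/baddbmm line
        if G.size(-2) > G.size(-1):
            X = X.mT
        return X

    G = torch.randn(1024, 768, device="cuda")
    X = newton_schulz_reference(G)
    # X.mT @ X should be roughly the identity; the iteration is approximate
    # and runs in bfloat16, so expect loose tolerance.
    print((X.mT @ X - torch.eye(768, device="cuda")).abs().mean())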
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+
+        # sparse gated attention to enable context-based no-op by @classiclarryd
+        self.attn_gate_dim = 12
+        self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int):
+        from flash_attn import flash_attn_func # flash-attn package assumed installed; this import is missing from the excerpt's top-level import block
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+
+        y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because the optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by the 8-GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, sa_lambdas, bm_size)
+        x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr
https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + 
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+import numpy as np # needed by EOSBatchFinder below; missing from the excerpt's top-level import block
+
+class EOSBatchFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256):
+        # Precompute EOS positions once per shard
+        self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0 # pointer into eos_idx (start EOS for next step)
+        self.pos = 0 # logical stream position within this shard
+        self.world_size = world_size
+    def seek(self, pos: int):
+        # Set pointer to the first EOS >= pos
+        self.i = np.searchsorted(self.eos_idx, pos)
+        if self.i >= len(self.eos_idx):
+            raise StopIteration("Seek past last EOS.")
+        self.pos = pos
+    def next_batch(self, batch_size_local: int, seq_len: int):
+        n = len(self.eos_idx)
+        if self.i >= n:
+            raise StopIteration("No more EOS in this shard.")
+        starts = [[] for _ in range(self.world_size)]
+        idx = self.i
+        cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1
+        for r in range(self.world_size):
+            for _ in range(batch_size_local):
+                start = cur + 1
+                target = start + seq_len # need seq_len tokens before next EOS
+                j = np.searchsorted(self.eos_idx, target)
+                if j >= n:
+                    raise StopIteration("Insufficient EOS ahead; hit tail of shard.")
+                starts[r].append(start)
+                idx = j
+                cur = self.eos_idx[idx] # next seq must also start at a new doc
+        advance = self.eos_idx[idx] - self.pos # advance the stream to the end of the last selected sequence
+        self.pos += advance
+        self.i = idx
+        return starts, advance
+
+
+def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token and sequences don't overlap
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert batch_size % world_size == 0, "Batch size must be divisible by world size"
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens, pos = _load_data_shard(next(file_iter)), 0
+
+    finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None
+    if align_to_bos: finder.seek(pos)
+
+    while True:
+        batch_size_local = batch_size // world_size
+        num_tokens_global = batch_size * seq_len
+
+        if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens):
+            tokens, pos = _load_data_shard(next(file_iter)), 0
+
+        if align_to_bos:
+            try:
+                batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len)
+                start_idxs = batch_starts[rank]
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+                finder = EOSBatchFinder(tokens, world_size=world_size)
+                continue
+
+            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
+            buf = torch.stack(bufs, dim=0)
+            _inputs = buf[:, :-1]
+            _targets = buf[:, 1:]
+        else:
+            batch_span = num_tokens_global
+            start_pos_local = pos + rank * (batch_size_local * seq_len)
+            end_pos_local = start_pos_local + (batch_size_local * seq_len)
+
+            buf = tokens[start_pos_local: end_pos_local + 1]
+
+            _inputs = buf[:-1].view(batch_size_local, seq_len)
+            _targets = buf[1:].view(batch_size_local, seq_len)
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
+        )
+
+        pos += batch_span
+
+        if new_params is not None:
+            # makes it possible for the generator to receive a new (batch_size, seq_len) via .send()
+            new_batch_size, new_seq_len = new_params
+            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
+            batch_size = new_batch_size
+            seq_len = new_seq_len
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_seq_len: int = 1024 * 2
+    train_batch_size: int = 24 * 8
+    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
+    # optimization
+    num_iterations: int = 1695 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    bandwidth: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
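+
+# Example (illustrative, not executed by this script): because
+# distributed_data_generator is a plain Python generator, callers can resize
+# batches mid-run via .send(), e.g.
+#     loader = distributed_data_generator(args.train_files, 192, 2048)
+#     inputs, targets = next(loader)             # per-rank tensors of seq_len 2048
+#     inputs, targets = loader.send((96, 4096))  # later batches use (96, 4096)
+# This script keeps (train_batch_size, train_seq_len) fixed and only calls next().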
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_seq_len, args.val_seq_len)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr_and_ws(step: int):
+    x = step / (1 + args.num_iterations) # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return lr, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 60
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+for step in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each one
+    model(inputs, targets, ws, ws // 2).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
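+# Sanity check of the schedule above (values computed from get_lr_and_ws with
+# num_iterations=1695, cooldown_frac=0.45, ws_schedule=(3, 7, 11); illustrative only):
+#     step=0    -> x~0.00: lr=1.00 (stable phase), ws=3  -> long window 3*128 = 384 tokens
+#     step=848  -> x=0.50: lr=1.00,                ws=7  -> long window 896 tokens
+#     step=1356 -> x~0.80: lr~0.50 (cooldown),     ws=11 -> long window 1408 tokens
+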
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + lr, ws = get_lr_and_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % (world_size * args.val_seq_len) == 0 + val_steps = args.val_tokens // (world_size * args.val_seq_len) + val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, ws, ws // 2) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 04:04:43 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 29C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 29C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 29C P0 109W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 33C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 31C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:517ms step_avg:517.24ms +step:2/1695 train_time:541ms step_avg:270.47ms +step:3/1695 train_time:610ms step_avg:203.20ms +step:4/1695 train_time:702ms step_avg:175.48ms +step:5/1695 train_time:795ms step_avg:158.96ms +step:6/1695 train_time:888ms step_avg:148.00ms +step:7/1695 train_time:982ms step_avg:140.22ms +step:8/1695 train_time:1075ms step_avg:134.38ms +step:9/1695 train_time:1169ms step_avg:129.83ms +step:10/1695 train_time:1262ms step_avg:126.22ms +step:11/1695 train_time:1356ms step_avg:123.26ms +step:12/1695 train_time:1454ms step_avg:121.13ms +step:13/1695 train_time:1553ms step_avg:119.44ms +step:14/1695 train_time:1650ms step_avg:117.83ms +step:15/1695 train_time:1744ms step_avg:116.30ms +step:16/1695 train_time:1838ms step_avg:114.86ms +step:17/1695 train_time:1931ms step_avg:113.60ms +step:18/1695 train_time:2025ms step_avg:112.49ms +step:19/1695 train_time:2118ms step_avg:111.45ms +step:20/1695 train_time:2211ms step_avg:110.56ms +step:21/1695 train_time:2306ms step_avg:109.79ms +step:22/1695 train_time:2401ms step_avg:109.13ms 
+step:23/1695 train_time:2497ms step_avg:108.55ms +step:24/1695 train_time:2592ms step_avg:108.02ms +step:25/1695 train_time:2689ms step_avg:107.56ms +step:26/1695 train_time:2785ms step_avg:107.10ms +step:27/1695 train_time:2879ms step_avg:106.62ms +step:28/1695 train_time:2973ms step_avg:106.16ms +step:29/1695 train_time:3067ms step_avg:105.75ms +step:30/1695 train_time:3160ms step_avg:105.34ms +step:31/1695 train_time:3253ms step_avg:104.95ms +step:32/1695 train_time:3349ms step_avg:104.66ms +step:33/1695 train_time:3445ms step_avg:104.39ms +step:34/1695 train_time:3540ms step_avg:104.10ms +step:35/1695 train_time:3634ms step_avg:103.84ms +step:36/1695 train_time:3731ms step_avg:103.64ms +step:37/1695 train_time:3827ms step_avg:103.44ms +step:38/1695 train_time:3922ms step_avg:103.22ms +step:39/1695 train_time:4016ms step_avg:102.98ms +step:40/1695 train_time:4111ms step_avg:102.77ms +step:41/1695 train_time:4205ms step_avg:102.56ms +step:42/1695 train_time:4299ms step_avg:102.35ms +step:43/1695 train_time:4393ms step_avg:102.16ms +step:44/1695 train_time:4488ms step_avg:102.01ms +step:45/1695 train_time:4584ms step_avg:101.86ms +step:46/1695 train_time:4678ms step_avg:101.68ms +step:47/1695 train_time:4773ms step_avg:101.54ms +step:48/1695 train_time:4867ms step_avg:101.40ms +step:49/1695 train_time:4962ms step_avg:101.27ms +step:50/1695 train_time:5056ms step_avg:101.12ms +step:51/1695 train_time:5150ms step_avg:100.98ms +step:52/1695 train_time:5244ms step_avg:100.85ms +step:53/1695 train_time:5338ms step_avg:100.72ms +step:54/1695 train_time:5432ms step_avg:100.60ms +step:55/1695 train_time:5528ms step_avg:100.51ms +step:56/1695 train_time:5623ms step_avg:100.42ms +step:57/1695 train_time:5717ms step_avg:100.31ms +step:58/1695 train_time:5812ms step_avg:100.20ms +step:59/1695 train_time:5908ms step_avg:100.13ms +step:60/1695 train_time:6003ms step_avg:100.04ms +step:61/1695 train_time:6096ms step_avg:99.94ms +step:62/1695 train_time:6190ms step_avg:99.84ms +step:63/1695 train_time:6284ms step_avg:99.75ms +step:64/1695 train_time:6378ms step_avg:99.66ms +step:65/1695 train_time:6472ms step_avg:99.57ms +step:66/1695 train_time:6568ms step_avg:99.51ms +step:67/1695 train_time:6663ms step_avg:99.45ms +step:68/1695 train_time:6756ms step_avg:99.36ms +step:69/1695 train_time:6851ms step_avg:99.29ms +step:70/1695 train_time:6946ms step_avg:99.23ms +step:71/1695 train_time:7040ms step_avg:99.16ms +step:72/1695 train_time:7133ms step_avg:99.07ms +step:73/1695 train_time:7228ms step_avg:99.02ms +step:74/1695 train_time:7324ms step_avg:98.97ms +step:75/1695 train_time:7418ms step_avg:98.90ms +step:76/1695 train_time:7512ms step_avg:98.84ms +step:77/1695 train_time:7607ms step_avg:98.80ms +step:78/1695 train_time:7702ms step_avg:98.75ms +step:79/1695 train_time:7796ms step_avg:98.69ms +step:80/1695 train_time:7891ms step_avg:98.63ms +step:81/1695 train_time:7985ms step_avg:98.58ms +step:82/1695 train_time:8079ms step_avg:98.52ms +step:83/1695 train_time:8172ms step_avg:98.46ms +step:84/1695 train_time:8268ms step_avg:98.42ms +step:85/1695 train_time:8363ms step_avg:98.39ms +step:86/1695 train_time:8457ms step_avg:98.33ms +step:87/1695 train_time:8551ms step_avg:98.29ms +step:88/1695 train_time:8647ms step_avg:98.26ms +step:89/1695 train_time:8742ms step_avg:98.22ms +step:90/1695 train_time:8835ms step_avg:98.17ms +step:91/1695 train_time:8931ms step_avg:98.14ms +step:92/1695 train_time:9024ms step_avg:98.09ms +step:93/1695 train_time:9118ms step_avg:98.04ms +step:94/1695 train_time:9211ms 
step_avg:97.99ms +step:95/1695 train_time:9305ms step_avg:97.95ms +step:96/1695 train_time:9399ms step_avg:97.91ms +step:97/1695 train_time:9494ms step_avg:97.87ms +step:98/1695 train_time:9588ms step_avg:97.84ms +step:99/1695 train_time:9684ms step_avg:97.81ms +step:100/1695 train_time:9777ms step_avg:97.77ms +step:101/1695 train_time:9872ms step_avg:97.75ms +step:102/1695 train_time:9968ms step_avg:97.73ms +step:103/1695 train_time:10063ms step_avg:97.70ms +step:104/1695 train_time:10157ms step_avg:97.66ms +step:105/1695 train_time:10251ms step_avg:97.63ms +step:106/1695 train_time:10345ms step_avg:97.59ms +step:107/1695 train_time:10439ms step_avg:97.56ms +step:108/1695 train_time:10533ms step_avg:97.52ms +step:109/1695 train_time:10627ms step_avg:97.50ms +step:110/1695 train_time:10721ms step_avg:97.46ms +step:111/1695 train_time:10814ms step_avg:97.42ms +step:112/1695 train_time:10909ms step_avg:97.40ms +step:113/1695 train_time:11004ms step_avg:97.38ms +step:114/1695 train_time:11098ms step_avg:97.35ms +step:115/1695 train_time:11192ms step_avg:97.33ms +step:116/1695 train_time:11288ms step_avg:97.31ms +step:117/1695 train_time:11381ms step_avg:97.28ms +step:118/1695 train_time:11476ms step_avg:97.25ms +step:119/1695 train_time:11570ms step_avg:97.23ms +step:120/1695 train_time:11664ms step_avg:97.20ms +step:121/1695 train_time:11758ms step_avg:97.17ms +step:122/1695 train_time:11852ms step_avg:97.15ms +step:123/1695 train_time:11947ms step_avg:97.13ms +step:124/1695 train_time:12042ms step_avg:97.11ms +step:125/1695 train_time:12135ms step_avg:97.08ms +step:125/1695 val_loss:4.3142 train_time:12227ms step_avg:97.82ms +step:126/1695 train_time:12252ms step_avg:97.24ms +step:127/1695 train_time:12329ms step_avg:97.08ms +step:128/1695 train_time:12428ms step_avg:97.09ms +step:129/1695 train_time:12522ms step_avg:97.07ms +step:130/1695 train_time:12616ms step_avg:97.05ms +step:131/1695 train_time:12710ms step_avg:97.02ms +step:132/1695 train_time:12803ms step_avg:96.99ms +step:133/1695 train_time:12896ms step_avg:96.96ms +step:134/1695 train_time:12990ms step_avg:96.94ms +step:135/1695 train_time:13083ms step_avg:96.91ms +step:136/1695 train_time:13177ms step_avg:96.89ms +step:137/1695 train_time:13273ms step_avg:96.88ms +step:138/1695 train_time:13370ms step_avg:96.88ms +step:139/1695 train_time:13466ms step_avg:96.88ms +step:140/1695 train_time:13560ms step_avg:96.86ms +step:141/1695 train_time:13654ms step_avg:96.84ms +step:142/1695 train_time:13748ms step_avg:96.82ms +step:143/1695 train_time:13841ms step_avg:96.79ms +step:144/1695 train_time:13935ms step_avg:96.77ms +step:145/1695 train_time:14029ms step_avg:96.75ms +step:146/1695 train_time:14122ms step_avg:96.72ms +step:147/1695 train_time:14215ms step_avg:96.70ms +step:148/1695 train_time:14312ms step_avg:96.71ms +step:149/1695 train_time:14409ms step_avg:96.70ms +step:150/1695 train_time:14503ms step_avg:96.69ms +step:151/1695 train_time:14597ms step_avg:96.67ms +step:152/1695 train_time:14692ms step_avg:96.66ms +step:153/1695 train_time:14787ms step_avg:96.65ms +step:154/1695 train_time:14881ms step_avg:96.63ms +step:155/1695 train_time:14974ms step_avg:96.61ms +step:156/1695 train_time:15068ms step_avg:96.59ms +step:157/1695 train_time:15161ms step_avg:96.56ms +step:158/1695 train_time:15255ms step_avg:96.55ms +step:159/1695 train_time:15350ms step_avg:96.54ms +step:160/1695 train_time:15446ms step_avg:96.53ms +step:161/1695 train_time:15540ms step_avg:96.52ms +step:162/1695 train_time:15635ms step_avg:96.51ms +step:163/1695 
train_time:15729ms step_avg:96.50ms +step:164/1695 train_time:15824ms step_avg:96.49ms +step:165/1695 train_time:15917ms step_avg:96.47ms +step:166/1695 train_time:16012ms step_avg:96.46ms +step:167/1695 train_time:16106ms step_avg:96.44ms +step:168/1695 train_time:16199ms step_avg:96.42ms +step:169/1695 train_time:16293ms step_avg:96.41ms +step:170/1695 train_time:16388ms step_avg:96.40ms +step:171/1695 train_time:16482ms step_avg:96.39ms +step:172/1695 train_time:16576ms step_avg:96.37ms +step:173/1695 train_time:16958ms step_avg:98.02ms +step:174/1695 train_time:17044ms step_avg:97.95ms +step:175/1695 train_time:17136ms step_avg:97.92ms +step:176/1695 train_time:17228ms step_avg:97.89ms +step:177/1695 train_time:17321ms step_avg:97.86ms +step:178/1695 train_time:17414ms step_avg:97.83ms +step:179/1695 train_time:17507ms step_avg:97.80ms +step:180/1695 train_time:17599ms step_avg:97.77ms +step:181/1695 train_time:17693ms step_avg:97.75ms +step:182/1695 train_time:17787ms step_avg:97.73ms +step:183/1695 train_time:17880ms step_avg:97.71ms +step:184/1695 train_time:17978ms step_avg:97.71ms +step:185/1695 train_time:18075ms step_avg:97.71ms +step:186/1695 train_time:18170ms step_avg:97.69ms +step:187/1695 train_time:18264ms step_avg:97.67ms +step:188/1695 train_time:18357ms step_avg:97.64ms +step:189/1695 train_time:18451ms step_avg:97.62ms +step:190/1695 train_time:18545ms step_avg:97.60ms +step:191/1695 train_time:18637ms step_avg:97.58ms +step:192/1695 train_time:18731ms step_avg:97.56ms +step:193/1695 train_time:18825ms step_avg:97.54ms +step:194/1695 train_time:18919ms step_avg:97.52ms +step:195/1695 train_time:19016ms step_avg:97.52ms +step:196/1695 train_time:19111ms step_avg:97.51ms +step:197/1695 train_time:19207ms step_avg:97.50ms +step:198/1695 train_time:19300ms step_avg:97.48ms +step:199/1695 train_time:19395ms step_avg:97.46ms +step:200/1695 train_time:19489ms step_avg:97.44ms +step:201/1695 train_time:19582ms step_avg:97.42ms +step:202/1695 train_time:19676ms step_avg:97.40ms +step:203/1695 train_time:19769ms step_avg:97.39ms +step:204/1695 train_time:19863ms step_avg:97.37ms +step:205/1695 train_time:19957ms step_avg:97.35ms +step:206/1695 train_time:20052ms step_avg:97.34ms +step:207/1695 train_time:20148ms step_avg:97.33ms +step:208/1695 train_time:20242ms step_avg:97.32ms +step:209/1695 train_time:20336ms step_avg:97.30ms +step:210/1695 train_time:20430ms step_avg:97.29ms +step:211/1695 train_time:20523ms step_avg:97.27ms +step:212/1695 train_time:20617ms step_avg:97.25ms +step:213/1695 train_time:20710ms step_avg:97.23ms +step:214/1695 train_time:20804ms step_avg:97.21ms +step:215/1695 train_time:20897ms step_avg:97.20ms +step:216/1695 train_time:20991ms step_avg:97.18ms +step:217/1695 train_time:21086ms step_avg:97.17ms +step:218/1695 train_time:21180ms step_avg:97.15ms +step:219/1695 train_time:21274ms step_avg:97.14ms +step:220/1695 train_time:21369ms step_avg:97.13ms +step:221/1695 train_time:21463ms step_avg:97.12ms +step:222/1695 train_time:21557ms step_avg:97.10ms +step:223/1695 train_time:21651ms step_avg:97.09ms +step:224/1695 train_time:21745ms step_avg:97.07ms +step:225/1695 train_time:21838ms step_avg:97.06ms +step:226/1695 train_time:21933ms step_avg:97.05ms +step:227/1695 train_time:22028ms step_avg:97.04ms +step:228/1695 train_time:22122ms step_avg:97.03ms +step:229/1695 train_time:22216ms step_avg:97.01ms +step:230/1695 train_time:22311ms step_avg:97.01ms +step:231/1695 train_time:22406ms step_avg:97.00ms +step:232/1695 train_time:22500ms step_avg:96.98ms 
+step:233/1695 train_time:22593ms step_avg:96.97ms +step:234/1695 train_time:22688ms step_avg:96.96ms +step:235/1695 train_time:22781ms step_avg:96.94ms +step:236/1695 train_time:22874ms step_avg:96.93ms +step:237/1695 train_time:22969ms step_avg:96.91ms +step:238/1695 train_time:23062ms step_avg:96.90ms +step:239/1695 train_time:23155ms step_avg:96.88ms +step:240/1695 train_time:23249ms step_avg:96.87ms +step:241/1695 train_time:23343ms step_avg:96.86ms +step:242/1695 train_time:23437ms step_avg:96.85ms +step:243/1695 train_time:23531ms step_avg:96.84ms +step:244/1695 train_time:23625ms step_avg:96.82ms +step:245/1695 train_time:23719ms step_avg:96.81ms +step:246/1695 train_time:23813ms step_avg:96.80ms +step:247/1695 train_time:23908ms step_avg:96.79ms +step:248/1695 train_time:24002ms step_avg:96.78ms +step:249/1695 train_time:24096ms step_avg:96.77ms +step:250/1695 train_time:24190ms step_avg:96.76ms +step:250/1695 val_loss:3.9738 train_time:24282ms step_avg:97.13ms +step:251/1695 train_time:24306ms step_avg:96.84ms +step:252/1695 train_time:24385ms step_avg:96.77ms +step:253/1695 train_time:24484ms step_avg:96.78ms +step:254/1695 train_time:24579ms step_avg:96.77ms +step:255/1695 train_time:24672ms step_avg:96.75ms +step:256/1695 train_time:24766ms step_avg:96.74ms +step:257/1695 train_time:24858ms step_avg:96.73ms +step:258/1695 train_time:24951ms step_avg:96.71ms +step:259/1695 train_time:25044ms step_avg:96.69ms +step:260/1695 train_time:25137ms step_avg:96.68ms +step:261/1695 train_time:25231ms step_avg:96.67ms +step:262/1695 train_time:25325ms step_avg:96.66ms +step:263/1695 train_time:25423ms step_avg:96.66ms +step:264/1695 train_time:25519ms step_avg:96.66ms +step:265/1695 train_time:25614ms step_avg:96.66ms +step:266/1695 train_time:25708ms step_avg:96.65ms +step:267/1695 train_time:25801ms step_avg:96.63ms +step:268/1695 train_time:25895ms step_avg:96.62ms +step:269/1695 train_time:25988ms step_avg:96.61ms +step:270/1695 train_time:26081ms step_avg:96.60ms +step:271/1695 train_time:26174ms step_avg:96.58ms +step:272/1695 train_time:26267ms step_avg:96.57ms +step:273/1695 train_time:26362ms step_avg:96.57ms +step:274/1695 train_time:26458ms step_avg:96.56ms +step:275/1695 train_time:26553ms step_avg:96.56ms +step:276/1695 train_time:26648ms step_avg:96.55ms +step:277/1695 train_time:26741ms step_avg:96.54ms +step:278/1695 train_time:26836ms step_avg:96.53ms +step:279/1695 train_time:26931ms step_avg:96.53ms +step:280/1695 train_time:27023ms step_avg:96.51ms +step:281/1695 train_time:27117ms step_avg:96.50ms +step:282/1695 train_time:27210ms step_avg:96.49ms +step:283/1695 train_time:27303ms step_avg:96.48ms +step:284/1695 train_time:27398ms step_avg:96.47ms +step:285/1695 train_time:27492ms step_avg:96.46ms +step:286/1695 train_time:27587ms step_avg:96.46ms +step:287/1695 train_time:27681ms step_avg:96.45ms +step:288/1695 train_time:27775ms step_avg:96.44ms +step:289/1695 train_time:27870ms step_avg:96.44ms +step:290/1695 train_time:27963ms step_avg:96.43ms +step:291/1695 train_time:28057ms step_avg:96.42ms +step:292/1695 train_time:28151ms step_avg:96.41ms +step:293/1695 train_time:28244ms step_avg:96.39ms +step:294/1695 train_time:28338ms step_avg:96.39ms +step:295/1695 train_time:28432ms step_avg:96.38ms +step:296/1695 train_time:28526ms step_avg:96.37ms +step:297/1695 train_time:28620ms step_avg:96.36ms +step:298/1695 train_time:28714ms step_avg:96.35ms +step:299/1695 train_time:28808ms step_avg:96.35ms +step:300/1695 train_time:28901ms step_avg:96.34ms +step:301/1695 
train_time:28995ms step_avg:96.33ms +step:302/1695 train_time:29089ms step_avg:96.32ms +step:303/1695 train_time:29182ms step_avg:96.31ms +step:304/1695 train_time:29276ms step_avg:96.30ms +step:305/1695 train_time:29370ms step_avg:96.30ms +step:306/1695 train_time:29464ms step_avg:96.29ms +step:307/1695 train_time:29558ms step_avg:96.28ms +step:308/1695 train_time:29653ms step_avg:96.28ms +step:309/1695 train_time:29748ms step_avg:96.27ms +step:310/1695 train_time:29841ms step_avg:96.26ms +step:311/1695 train_time:29936ms step_avg:96.26ms +step:312/1695 train_time:30031ms step_avg:96.25ms +step:313/1695 train_time:30124ms step_avg:96.24ms +step:314/1695 train_time:30218ms step_avg:96.24ms +step:315/1695 train_time:30312ms step_avg:96.23ms +step:316/1695 train_time:30405ms step_avg:96.22ms +step:317/1695 train_time:30500ms step_avg:96.21ms +step:318/1695 train_time:30595ms step_avg:96.21ms +step:319/1695 train_time:30689ms step_avg:96.20ms +step:320/1695 train_time:30783ms step_avg:96.20ms +step:321/1695 train_time:30877ms step_avg:96.19ms +step:322/1695 train_time:30972ms step_avg:96.19ms +step:323/1695 train_time:31066ms step_avg:96.18ms +step:324/1695 train_time:31159ms step_avg:96.17ms +step:325/1695 train_time:31252ms step_avg:96.16ms +step:326/1695 train_time:31346ms step_avg:96.15ms +step:327/1695 train_time:31440ms step_avg:96.15ms +step:328/1695 train_time:31535ms step_avg:96.14ms +step:329/1695 train_time:31630ms step_avg:96.14ms +step:330/1695 train_time:31723ms step_avg:96.13ms +step:331/1695 train_time:31817ms step_avg:96.12ms +step:332/1695 train_time:31912ms step_avg:96.12ms +step:333/1695 train_time:32006ms step_avg:96.11ms +step:334/1695 train_time:32099ms step_avg:96.11ms +step:335/1695 train_time:32194ms step_avg:96.10ms +step:336/1695 train_time:32288ms step_avg:96.09ms +step:337/1695 train_time:32381ms step_avg:96.09ms +step:338/1695 train_time:32475ms step_avg:96.08ms +step:339/1695 train_time:32570ms step_avg:96.08ms +step:340/1695 train_time:32664ms step_avg:96.07ms +step:341/1695 train_time:32758ms step_avg:96.06ms +step:342/1695 train_time:32852ms step_avg:96.06ms +step:343/1695 train_time:32947ms step_avg:96.05ms +step:344/1695 train_time:33040ms step_avg:96.05ms +step:345/1695 train_time:33366ms step_avg:96.71ms +step:346/1695 train_time:33470ms step_avg:96.73ms +step:347/1695 train_time:33562ms step_avg:96.72ms +step:348/1695 train_time:33655ms step_avg:96.71ms +step:349/1695 train_time:33748ms step_avg:96.70ms +step:350/1695 train_time:33840ms step_avg:96.69ms +step:351/1695 train_time:33933ms step_avg:96.68ms +step:352/1695 train_time:34026ms step_avg:96.66ms +step:353/1695 train_time:34119ms step_avg:96.65ms +step:354/1695 train_time:34212ms step_avg:96.65ms +step:355/1695 train_time:34310ms step_avg:96.65ms +step:356/1695 train_time:34408ms step_avg:96.65ms +step:357/1695 train_time:34502ms step_avg:96.65ms +step:358/1695 train_time:34597ms step_avg:96.64ms +step:359/1695 train_time:34690ms step_avg:96.63ms +step:360/1695 train_time:34783ms step_avg:96.62ms +step:361/1695 train_time:34876ms step_avg:96.61ms +step:362/1695 train_time:34970ms step_avg:96.60ms +step:363/1695 train_time:35062ms step_avg:96.59ms +step:364/1695 train_time:35155ms step_avg:96.58ms +step:365/1695 train_time:35250ms step_avg:96.58ms +step:366/1695 train_time:35345ms step_avg:96.57ms +step:367/1695 train_time:35440ms step_avg:96.57ms +step:368/1695 train_time:35535ms step_avg:96.56ms +step:369/1695 train_time:35630ms step_avg:96.56ms +step:370/1695 train_time:35723ms step_avg:96.55ms 
+step:371/1695 train_time:35817ms step_avg:96.54ms +step:372/1695 train_time:35911ms step_avg:96.53ms +step:373/1695 train_time:36004ms step_avg:96.52ms +step:374/1695 train_time:36097ms step_avg:96.52ms +step:375/1695 train_time:36191ms step_avg:96.51ms +step:375/1695 val_loss:3.8151 train_time:36283ms step_avg:96.76ms +step:376/1695 train_time:36310ms step_avg:96.57ms +step:377/1695 train_time:36385ms step_avg:96.51ms +step:378/1695 train_time:36485ms step_avg:96.52ms +step:379/1695 train_time:36582ms step_avg:96.52ms +step:380/1695 train_time:36675ms step_avg:96.51ms +step:381/1695 train_time:36768ms step_avg:96.50ms +step:382/1695 train_time:36861ms step_avg:96.50ms +step:383/1695 train_time:36955ms step_avg:96.49ms +step:384/1695 train_time:37047ms step_avg:96.48ms +step:385/1695 train_time:37140ms step_avg:96.47ms +step:386/1695 train_time:37233ms step_avg:96.46ms +step:387/1695 train_time:37328ms step_avg:96.46ms +step:388/1695 train_time:37424ms step_avg:96.45ms +step:389/1695 train_time:37521ms step_avg:96.46ms +step:390/1695 train_time:37617ms step_avg:96.45ms +step:391/1695 train_time:37711ms step_avg:96.45ms +step:392/1695 train_time:37804ms step_avg:96.44ms +step:393/1695 train_time:37897ms step_avg:96.43ms +step:394/1695 train_time:37990ms step_avg:96.42ms +step:395/1695 train_time:38083ms step_avg:96.41ms +step:396/1695 train_time:38176ms step_avg:96.40ms +step:397/1695 train_time:38270ms step_avg:96.40ms +step:398/1695 train_time:38364ms step_avg:96.39ms +step:399/1695 train_time:38460ms step_avg:96.39ms +step:400/1695 train_time:38556ms step_avg:96.39ms +step:401/1695 train_time:38650ms step_avg:96.38ms +step:402/1695 train_time:38744ms step_avg:96.38ms +step:403/1695 train_time:38837ms step_avg:96.37ms +step:404/1695 train_time:38930ms step_avg:96.36ms +step:405/1695 train_time:39024ms step_avg:96.35ms +step:406/1695 train_time:39118ms step_avg:96.35ms +step:407/1695 train_time:39211ms step_avg:96.34ms +step:408/1695 train_time:39304ms step_avg:96.33ms +step:409/1695 train_time:39398ms step_avg:96.33ms +step:410/1695 train_time:39494ms step_avg:96.33ms +step:411/1695 train_time:39588ms step_avg:96.32ms +step:412/1695 train_time:39683ms step_avg:96.32ms +step:413/1695 train_time:39776ms step_avg:96.31ms +step:414/1695 train_time:39870ms step_avg:96.30ms +step:415/1695 train_time:39963ms step_avg:96.30ms +step:416/1695 train_time:40058ms step_avg:96.29ms +step:417/1695 train_time:40152ms step_avg:96.29ms +step:418/1695 train_time:40246ms step_avg:96.28ms +step:419/1695 train_time:40339ms step_avg:96.27ms +step:420/1695 train_time:40433ms step_avg:96.27ms +step:421/1695 train_time:40528ms step_avg:96.27ms +step:422/1695 train_time:40622ms step_avg:96.26ms +step:423/1695 train_time:40716ms step_avg:96.26ms +step:424/1695 train_time:40810ms step_avg:96.25ms +step:425/1695 train_time:40903ms step_avg:96.24ms +step:426/1695 train_time:40997ms step_avg:96.24ms +step:427/1695 train_time:41091ms step_avg:96.23ms +step:428/1695 train_time:41184ms step_avg:96.22ms +step:429/1695 train_time:41278ms step_avg:96.22ms +step:430/1695 train_time:41372ms step_avg:96.21ms +step:431/1695 train_time:41466ms step_avg:96.21ms +step:432/1695 train_time:41561ms step_avg:96.21ms +step:433/1695 train_time:41656ms step_avg:96.20ms +step:434/1695 train_time:41750ms step_avg:96.20ms +step:435/1695 train_time:41843ms step_avg:96.19ms +step:436/1695 train_time:41938ms step_avg:96.19ms +step:437/1695 train_time:42032ms step_avg:96.18ms +step:438/1695 train_time:42126ms step_avg:96.18ms +step:439/1695 
train_time:42220ms step_avg:96.17ms +step:440/1695 train_time:42314ms step_avg:96.17ms +step:441/1695 train_time:42408ms step_avg:96.16ms +step:442/1695 train_time:42501ms step_avg:96.16ms +step:443/1695 train_time:42596ms step_avg:96.15ms +step:444/1695 train_time:42691ms step_avg:96.15ms +step:445/1695 train_time:42784ms step_avg:96.14ms +step:446/1695 train_time:42878ms step_avg:96.14ms +step:447/1695 train_time:42972ms step_avg:96.13ms +step:448/1695 train_time:43066ms step_avg:96.13ms +step:449/1695 train_time:43160ms step_avg:96.12ms +step:450/1695 train_time:43255ms step_avg:96.12ms +step:451/1695 train_time:43348ms step_avg:96.12ms +step:452/1695 train_time:43443ms step_avg:96.11ms +step:453/1695 train_time:43537ms step_avg:96.11ms +step:454/1695 train_time:43632ms step_avg:96.11ms +step:455/1695 train_time:43725ms step_avg:96.10ms +step:456/1695 train_time:43819ms step_avg:96.09ms +step:457/1695 train_time:43913ms step_avg:96.09ms +step:458/1695 train_time:44006ms step_avg:96.08ms +step:459/1695 train_time:44100ms step_avg:96.08ms +step:460/1695 train_time:44194ms step_avg:96.07ms +step:461/1695 train_time:44287ms step_avg:96.07ms +step:462/1695 train_time:44381ms step_avg:96.06ms +step:463/1695 train_time:44476ms step_avg:96.06ms +step:464/1695 train_time:44569ms step_avg:96.05ms +step:465/1695 train_time:44663ms step_avg:96.05ms +step:466/1695 train_time:44758ms step_avg:96.05ms +step:467/1695 train_time:44853ms step_avg:96.04ms +step:468/1695 train_time:44947ms step_avg:96.04ms +step:469/1695 train_time:45040ms step_avg:96.03ms +step:470/1695 train_time:45134ms step_avg:96.03ms +step:471/1695 train_time:45229ms step_avg:96.03ms +step:472/1695 train_time:45323ms step_avg:96.02ms +step:473/1695 train_time:45418ms step_avg:96.02ms +step:474/1695 train_time:45512ms step_avg:96.02ms +step:475/1695 train_time:45605ms step_avg:96.01ms +step:476/1695 train_time:45699ms step_avg:96.01ms +step:477/1695 train_time:45795ms step_avg:96.01ms +step:478/1695 train_time:45888ms step_avg:96.00ms +step:479/1695 train_time:45982ms step_avg:96.00ms +step:480/1695 train_time:46076ms step_avg:95.99ms +step:481/1695 train_time:46170ms step_avg:95.99ms +step:482/1695 train_time:46264ms step_avg:95.98ms +step:483/1695 train_time:46358ms step_avg:95.98ms +step:484/1695 train_time:46454ms step_avg:95.98ms +step:485/1695 train_time:46547ms step_avg:95.97ms +step:486/1695 train_time:46641ms step_avg:95.97ms +step:487/1695 train_time:46735ms step_avg:95.97ms +step:488/1695 train_time:46830ms step_avg:95.96ms +step:489/1695 train_time:46924ms step_avg:95.96ms +step:490/1695 train_time:47018ms step_avg:95.96ms +step:491/1695 train_time:47112ms step_avg:95.95ms +step:492/1695 train_time:47205ms step_avg:95.95ms +step:493/1695 train_time:47299ms step_avg:95.94ms +step:494/1695 train_time:47393ms step_avg:95.94ms +step:495/1695 train_time:47487ms step_avg:95.93ms +step:496/1695 train_time:47581ms step_avg:95.93ms +step:497/1695 train_time:47675ms step_avg:95.93ms +step:498/1695 train_time:47768ms step_avg:95.92ms +step:499/1695 train_time:47862ms step_avg:95.92ms +step:500/1695 train_time:47957ms step_avg:95.91ms +step:500/1695 val_loss:3.7158 train_time:48050ms step_avg:96.10ms +step:501/1695 train_time:48074ms step_avg:95.96ms +step:502/1695 train_time:48155ms step_avg:95.93ms +step:503/1695 train_time:48257ms step_avg:95.94ms +step:504/1695 train_time:48350ms step_avg:95.93ms +step:505/1695 train_time:48444ms step_avg:95.93ms +step:506/1695 train_time:48537ms step_avg:95.92ms +step:507/1695 train_time:48630ms 
step_avg:95.92ms +step:508/1695 train_time:48723ms step_avg:95.91ms +step:509/1695 train_time:48816ms step_avg:95.91ms +step:510/1695 train_time:48909ms step_avg:95.90ms +step:511/1695 train_time:49002ms step_avg:95.89ms +step:512/1695 train_time:49098ms step_avg:95.89ms +step:513/1695 train_time:49195ms step_avg:95.90ms +step:514/1695 train_time:49290ms step_avg:95.90ms +step:515/1695 train_time:49386ms step_avg:95.89ms +step:516/1695 train_time:49480ms step_avg:95.89ms +step:517/1695 train_time:49573ms step_avg:95.89ms +step:518/1695 train_time:49666ms step_avg:95.88ms +step:519/1695 train_time:49996ms step_avg:96.33ms +step:520/1695 train_time:50189ms step_avg:96.52ms +step:521/1695 train_time:50281ms step_avg:96.51ms +step:522/1695 train_time:50374ms step_avg:96.50ms +step:523/1695 train_time:50467ms step_avg:96.50ms +step:524/1695 train_time:50560ms step_avg:96.49ms +step:525/1695 train_time:50653ms step_avg:96.48ms +step:526/1695 train_time:50746ms step_avg:96.48ms +step:527/1695 train_time:50840ms step_avg:96.47ms +step:528/1695 train_time:50933ms step_avg:96.46ms +step:529/1695 train_time:51027ms step_avg:96.46ms +step:530/1695 train_time:51125ms step_avg:96.46ms +step:531/1695 train_time:51222ms step_avg:96.46ms +step:532/1695 train_time:51316ms step_avg:96.46ms +step:533/1695 train_time:51409ms step_avg:96.45ms +step:534/1695 train_time:51504ms step_avg:96.45ms +step:535/1695 train_time:51596ms step_avg:96.44ms +step:536/1695 train_time:51689ms step_avg:96.43ms +step:537/1695 train_time:51782ms step_avg:96.43ms +step:538/1695 train_time:51875ms step_avg:96.42ms +step:539/1695 train_time:51969ms step_avg:96.42ms +step:540/1695 train_time:52063ms step_avg:96.41ms +step:541/1695 train_time:52158ms step_avg:96.41ms +step:542/1695 train_time:52252ms step_avg:96.41ms +step:543/1695 train_time:52348ms step_avg:96.40ms +step:544/1695 train_time:52442ms step_avg:96.40ms +step:545/1695 train_time:52536ms step_avg:96.40ms +step:546/1695 train_time:52629ms step_avg:96.39ms +step:547/1695 train_time:52723ms step_avg:96.39ms +step:548/1695 train_time:52816ms step_avg:96.38ms +step:549/1695 train_time:52909ms step_avg:96.37ms +step:550/1695 train_time:53003ms step_avg:96.37ms +step:551/1695 train_time:53097ms step_avg:96.36ms +step:552/1695 train_time:53192ms step_avg:96.36ms +step:553/1695 train_time:53287ms step_avg:96.36ms +step:554/1695 train_time:53382ms step_avg:96.36ms +step:555/1695 train_time:53476ms step_avg:96.35ms +step:556/1695 train_time:53569ms step_avg:96.35ms +step:557/1695 train_time:53663ms step_avg:96.34ms +step:558/1695 train_time:53756ms step_avg:96.34ms +step:559/1695 train_time:53849ms step_avg:96.33ms +step:560/1695 train_time:53943ms step_avg:96.33ms +step:561/1695 train_time:54037ms step_avg:96.32ms +step:562/1695 train_time:54131ms step_avg:96.32ms +step:563/1695 train_time:54226ms step_avg:96.32ms +step:564/1695 train_time:54321ms step_avg:96.31ms +step:565/1695 train_time:54414ms step_avg:96.31ms +step:566/1695 train_time:54508ms step_avg:96.30ms +step:567/1695 train_time:54602ms step_avg:96.30ms +step:568/1695 train_time:54697ms step_avg:96.30ms +step:569/1695 train_time:54792ms step_avg:96.30ms +step:570/1695 train_time:54889ms step_avg:96.30ms +step:571/1695 train_time:54986ms step_avg:96.30ms +step:572/1695 train_time:55083ms step_avg:96.30ms +step:573/1695 train_time:55180ms step_avg:96.30ms +step:574/1695 train_time:55276ms step_avg:96.30ms +step:575/1695 train_time:55373ms step_avg:96.30ms +step:576/1695 train_time:55469ms step_avg:96.30ms +step:577/1695 
train_time:55565ms step_avg:96.30ms +step:578/1695 train_time:55662ms step_avg:96.30ms +step:579/1695 train_time:55757ms step_avg:96.30ms +step:580/1695 train_time:55852ms step_avg:96.30ms +step:581/1695 train_time:55948ms step_avg:96.30ms +step:582/1695 train_time:56044ms step_avg:96.30ms +step:583/1695 train_time:56140ms step_avg:96.30ms +step:584/1695 train_time:56236ms step_avg:96.29ms +step:585/1695 train_time:56331ms step_avg:96.29ms +step:586/1695 train_time:56427ms step_avg:96.29ms +step:587/1695 train_time:56523ms step_avg:96.29ms +step:588/1695 train_time:56621ms step_avg:96.29ms +step:589/1695 train_time:56716ms step_avg:96.29ms +step:590/1695 train_time:56812ms step_avg:96.29ms +step:591/1695 train_time:56908ms step_avg:96.29ms +step:592/1695 train_time:57005ms step_avg:96.29ms +step:593/1695 train_time:57101ms step_avg:96.29ms +step:594/1695 train_time:57196ms step_avg:96.29ms +step:595/1695 train_time:57291ms step_avg:96.29ms +step:596/1695 train_time:57388ms step_avg:96.29ms +step:597/1695 train_time:57484ms step_avg:96.29ms +step:598/1695 train_time:57581ms step_avg:96.29ms +step:599/1695 train_time:57678ms step_avg:96.29ms +step:600/1695 train_time:57773ms step_avg:96.29ms +step:601/1695 train_time:57869ms step_avg:96.29ms +step:602/1695 train_time:57964ms step_avg:96.29ms +step:603/1695 train_time:58060ms step_avg:96.28ms +step:604/1695 train_time:58155ms step_avg:96.28ms +step:605/1695 train_time:58251ms step_avg:96.28ms +step:606/1695 train_time:58348ms step_avg:96.28ms +step:607/1695 train_time:58444ms step_avg:96.28ms +step:608/1695 train_time:58541ms step_avg:96.28ms +step:609/1695 train_time:58636ms step_avg:96.28ms +step:610/1695 train_time:58732ms step_avg:96.28ms +step:611/1695 train_time:58827ms step_avg:96.28ms +step:612/1695 train_time:58923ms step_avg:96.28ms +step:613/1695 train_time:59020ms step_avg:96.28ms +step:614/1695 train_time:59116ms step_avg:96.28ms +step:615/1695 train_time:59211ms step_avg:96.28ms +step:616/1695 train_time:59307ms step_avg:96.28ms +step:617/1695 train_time:59404ms step_avg:96.28ms +step:618/1695 train_time:59499ms step_avg:96.28ms +step:619/1695 train_time:59595ms step_avg:96.28ms +step:620/1695 train_time:59691ms step_avg:96.28ms +step:621/1695 train_time:59788ms step_avg:96.28ms +step:622/1695 train_time:59884ms step_avg:96.28ms +step:623/1695 train_time:59982ms step_avg:96.28ms +step:624/1695 train_time:60078ms step_avg:96.28ms +step:625/1695 train_time:60173ms step_avg:96.28ms +step:625/1695 val_loss:3.6195 train_time:60266ms step_avg:96.43ms +step:626/1695 train_time:60290ms step_avg:96.31ms +step:627/1695 train_time:60370ms step_avg:96.28ms +step:628/1695 train_time:60467ms step_avg:96.29ms +step:629/1695 train_time:60563ms step_avg:96.28ms +step:630/1695 train_time:60658ms step_avg:96.28ms +step:631/1695 train_time:60753ms step_avg:96.28ms +step:632/1695 train_time:60847ms step_avg:96.28ms +step:633/1695 train_time:60943ms step_avg:96.28ms +step:634/1695 train_time:61038ms step_avg:96.27ms +step:635/1695 train_time:61133ms step_avg:96.27ms +step:636/1695 train_time:61232ms step_avg:96.28ms +step:637/1695 train_time:61331ms step_avg:96.28ms +step:638/1695 train_time:61426ms step_avg:96.28ms +step:639/1695 train_time:61522ms step_avg:96.28ms +step:640/1695 train_time:61617ms step_avg:96.28ms +step:641/1695 train_time:61712ms step_avg:96.28ms +step:642/1695 train_time:61808ms step_avg:96.27ms +step:643/1695 train_time:61902ms step_avg:96.27ms +step:644/1695 train_time:61997ms step_avg:96.27ms +step:645/1695 train_time:62092ms 
step_avg:96.27ms +step:646/1695 train_time:62188ms step_avg:96.27ms +step:647/1695 train_time:62284ms step_avg:96.27ms +step:648/1695 train_time:62382ms step_avg:96.27ms +step:649/1695 train_time:62478ms step_avg:96.27ms +step:650/1695 train_time:62575ms step_avg:96.27ms +step:651/1695 train_time:62671ms step_avg:96.27ms +step:652/1695 train_time:62767ms step_avg:96.27ms +step:653/1695 train_time:62862ms step_avg:96.27ms +step:654/1695 train_time:62957ms step_avg:96.27ms +step:655/1695 train_time:63053ms step_avg:96.26ms +step:656/1695 train_time:63148ms step_avg:96.26ms +step:657/1695 train_time:63245ms step_avg:96.26ms +step:658/1695 train_time:63341ms step_avg:96.26ms +step:659/1695 train_time:63438ms step_avg:96.26ms +step:660/1695 train_time:63535ms step_avg:96.27ms +step:661/1695 train_time:63632ms step_avg:96.27ms +step:662/1695 train_time:63727ms step_avg:96.26ms +step:663/1695 train_time:63822ms step_avg:96.26ms +step:664/1695 train_time:63918ms step_avg:96.26ms +step:665/1695 train_time:64014ms step_avg:96.26ms +step:666/1695 train_time:64111ms step_avg:96.26ms +step:667/1695 train_time:64207ms step_avg:96.26ms +step:668/1695 train_time:64303ms step_avg:96.26ms +step:669/1695 train_time:64400ms step_avg:96.26ms +step:670/1695 train_time:64497ms step_avg:96.26ms +step:671/1695 train_time:64594ms step_avg:96.27ms +step:672/1695 train_time:64691ms step_avg:96.27ms +step:673/1695 train_time:64786ms step_avg:96.26ms +step:674/1695 train_time:64881ms step_avg:96.26ms +step:675/1695 train_time:64977ms step_avg:96.26ms +step:676/1695 train_time:65073ms step_avg:96.26ms +step:677/1695 train_time:65170ms step_avg:96.26ms +step:678/1695 train_time:65266ms step_avg:96.26ms +step:679/1695 train_time:65361ms step_avg:96.26ms +step:680/1695 train_time:65458ms step_avg:96.26ms +step:681/1695 train_time:65554ms step_avg:96.26ms +step:682/1695 train_time:65651ms step_avg:96.26ms +step:683/1695 train_time:65747ms step_avg:96.26ms +step:684/1695 train_time:65843ms step_avg:96.26ms +step:685/1695 train_time:65939ms step_avg:96.26ms +step:686/1695 train_time:66034ms step_avg:96.26ms +step:687/1695 train_time:66129ms step_avg:96.26ms +step:688/1695 train_time:66225ms step_avg:96.26ms +step:689/1695 train_time:66320ms step_avg:96.26ms +step:690/1695 train_time:66417ms step_avg:96.26ms +step:691/1695 train_time:66874ms step_avg:96.78ms +step:692/1695 train_time:66944ms step_avg:96.74ms +step:693/1695 train_time:67039ms step_avg:96.74ms +step:694/1695 train_time:67134ms step_avg:96.74ms +step:695/1695 train_time:67229ms step_avg:96.73ms +step:696/1695 train_time:67323ms step_avg:96.73ms +step:697/1695 train_time:67418ms step_avg:96.73ms +step:698/1695 train_time:67513ms step_avg:96.72ms +step:699/1695 train_time:67607ms step_avg:96.72ms +step:700/1695 train_time:67702ms step_avg:96.72ms +step:701/1695 train_time:67802ms step_avg:96.72ms +step:702/1695 train_time:67902ms step_avg:96.73ms +step:703/1695 train_time:68000ms step_avg:96.73ms +step:704/1695 train_time:68096ms step_avg:96.73ms +step:705/1695 train_time:68192ms step_avg:96.73ms +step:706/1695 train_time:68287ms step_avg:96.72ms +step:707/1695 train_time:68382ms step_avg:96.72ms +step:708/1695 train_time:68478ms step_avg:96.72ms +step:709/1695 train_time:68574ms step_avg:96.72ms +step:710/1695 train_time:68670ms step_avg:96.72ms +step:711/1695 train_time:68765ms step_avg:96.72ms +step:712/1695 train_time:68862ms step_avg:96.72ms +step:713/1695 train_time:68959ms step_avg:96.72ms +step:714/1695 train_time:69056ms step_avg:96.72ms +step:715/1695 
train_time:69153ms step_avg:96.72ms +step:716/1695 train_time:69250ms step_avg:96.72ms +step:717/1695 train_time:69344ms step_avg:96.71ms +step:718/1695 train_time:69439ms step_avg:96.71ms +step:719/1695 train_time:69535ms step_avg:96.71ms +step:720/1695 train_time:69631ms step_avg:96.71ms +step:721/1695 train_time:69726ms step_avg:96.71ms +step:722/1695 train_time:69822ms step_avg:96.71ms +step:723/1695 train_time:69918ms step_avg:96.71ms +step:724/1695 train_time:70015ms step_avg:96.71ms +step:725/1695 train_time:70113ms step_avg:96.71ms +step:726/1695 train_time:70210ms step_avg:96.71ms +step:727/1695 train_time:70305ms step_avg:96.71ms +step:728/1695 train_time:70400ms step_avg:96.70ms +step:729/1695 train_time:70496ms step_avg:96.70ms +step:730/1695 train_time:70591ms step_avg:96.70ms +step:731/1695 train_time:70686ms step_avg:96.70ms +step:732/1695 train_time:70782ms step_avg:96.70ms +step:733/1695 train_time:70878ms step_avg:96.70ms +step:734/1695 train_time:70974ms step_avg:96.69ms +step:735/1695 train_time:71069ms step_avg:96.69ms +step:736/1695 train_time:71166ms step_avg:96.69ms +step:737/1695 train_time:71262ms step_avg:96.69ms +step:738/1695 train_time:71358ms step_avg:96.69ms +step:739/1695 train_time:71454ms step_avg:96.69ms +step:740/1695 train_time:71549ms step_avg:96.69ms +step:741/1695 train_time:71644ms step_avg:96.69ms +step:742/1695 train_time:71739ms step_avg:96.68ms +step:743/1695 train_time:71836ms step_avg:96.68ms +step:744/1695 train_time:71931ms step_avg:96.68ms +step:745/1695 train_time:72027ms step_avg:96.68ms +step:746/1695 train_time:72123ms step_avg:96.68ms +step:747/1695 train_time:72220ms step_avg:96.68ms +step:748/1695 train_time:72317ms step_avg:96.68ms +step:749/1695 train_time:72413ms step_avg:96.68ms +step:750/1695 train_time:72510ms step_avg:96.68ms +step:750/1695 val_loss:3.5686 train_time:72604ms step_avg:96.81ms +step:751/1695 train_time:72630ms step_avg:96.71ms +step:752/1695 train_time:72710ms step_avg:96.69ms +step:753/1695 train_time:72807ms step_avg:96.69ms +step:754/1695 train_time:72902ms step_avg:96.69ms +step:755/1695 train_time:72998ms step_avg:96.69ms +step:756/1695 train_time:73092ms step_avg:96.68ms +step:757/1695 train_time:73186ms step_avg:96.68ms +step:758/1695 train_time:73281ms step_avg:96.68ms +step:759/1695 train_time:73376ms step_avg:96.67ms +step:760/1695 train_time:73470ms step_avg:96.67ms +step:761/1695 train_time:73566ms step_avg:96.67ms +step:762/1695 train_time:73665ms step_avg:96.67ms +step:763/1695 train_time:73763ms step_avg:96.67ms +step:764/1695 train_time:73859ms step_avg:96.67ms +step:765/1695 train_time:73955ms step_avg:96.67ms +step:766/1695 train_time:74051ms step_avg:96.67ms +step:767/1695 train_time:74146ms step_avg:96.67ms +step:768/1695 train_time:74241ms step_avg:96.67ms +step:769/1695 train_time:74336ms step_avg:96.67ms +step:770/1695 train_time:74430ms step_avg:96.66ms +step:771/1695 train_time:74525ms step_avg:96.66ms +step:772/1695 train_time:74622ms step_avg:96.66ms +step:773/1695 train_time:74719ms step_avg:96.66ms +step:774/1695 train_time:74817ms step_avg:96.66ms +step:775/1695 train_time:74913ms step_avg:96.66ms +step:776/1695 train_time:75008ms step_avg:96.66ms +step:777/1695 train_time:75104ms step_avg:96.66ms +step:778/1695 train_time:75199ms step_avg:96.66ms +step:779/1695 train_time:75294ms step_avg:96.65ms +step:780/1695 train_time:75389ms step_avg:96.65ms +step:781/1695 train_time:75484ms step_avg:96.65ms +step:782/1695 train_time:75580ms step_avg:96.65ms +step:783/1695 train_time:75677ms 
step_avg:96.65ms +step:784/1695 train_time:75773ms step_avg:96.65ms +step:785/1695 train_time:75870ms step_avg:96.65ms +step:786/1695 train_time:75965ms step_avg:96.65ms +step:787/1695 train_time:76061ms step_avg:96.65ms +step:788/1695 train_time:76156ms step_avg:96.64ms +step:789/1695 train_time:76251ms step_avg:96.64ms +step:790/1695 train_time:76346ms step_avg:96.64ms +step:791/1695 train_time:76441ms step_avg:96.64ms +step:792/1695 train_time:76538ms step_avg:96.64ms +step:793/1695 train_time:76634ms step_avg:96.64ms +step:794/1695 train_time:76730ms step_avg:96.64ms +step:795/1695 train_time:76825ms step_avg:96.64ms +step:796/1695 train_time:76921ms step_avg:96.63ms +step:797/1695 train_time:77017ms step_avg:96.63ms +step:798/1695 train_time:77113ms step_avg:96.63ms +step:799/1695 train_time:77207ms step_avg:96.63ms +step:800/1695 train_time:77303ms step_avg:96.63ms +step:801/1695 train_time:77398ms step_avg:96.63ms +step:802/1695 train_time:77493ms step_avg:96.62ms +step:803/1695 train_time:77588ms step_avg:96.62ms +step:804/1695 train_time:77684ms step_avg:96.62ms +step:805/1695 train_time:77782ms step_avg:96.62ms +step:806/1695 train_time:77879ms step_avg:96.62ms +step:807/1695 train_time:77976ms step_avg:96.62ms +step:808/1695 train_time:78073ms step_avg:96.62ms +step:809/1695 train_time:78168ms step_avg:96.62ms +step:810/1695 train_time:78262ms step_avg:96.62ms +step:811/1695 train_time:78357ms step_avg:96.62ms +step:812/1695 train_time:78453ms step_avg:96.62ms +step:813/1695 train_time:78547ms step_avg:96.61ms +step:814/1695 train_time:78643ms step_avg:96.61ms +step:815/1695 train_time:78739ms step_avg:96.61ms +step:816/1695 train_time:78836ms step_avg:96.61ms +step:817/1695 train_time:78932ms step_avg:96.61ms +step:818/1695 train_time:79028ms step_avg:96.61ms +step:819/1695 train_time:79124ms step_avg:96.61ms +step:820/1695 train_time:79219ms step_avg:96.61ms +step:821/1695 train_time:79315ms step_avg:96.61ms +step:822/1695 train_time:79410ms step_avg:96.61ms +step:823/1695 train_time:79505ms step_avg:96.60ms +step:824/1695 train_time:79600ms step_avg:96.60ms +step:825/1695 train_time:79696ms step_avg:96.60ms +step:826/1695 train_time:79791ms step_avg:96.60ms +step:827/1695 train_time:79887ms step_avg:96.60ms +step:828/1695 train_time:79984ms step_avg:96.60ms +step:829/1695 train_time:80081ms step_avg:96.60ms +step:830/1695 train_time:80178ms step_avg:96.60ms +step:831/1695 train_time:80273ms step_avg:96.60ms +step:832/1695 train_time:80368ms step_avg:96.60ms +step:833/1695 train_time:80463ms step_avg:96.59ms +step:834/1695 train_time:80559ms step_avg:96.59ms +step:835/1695 train_time:80656ms step_avg:96.59ms +step:836/1695 train_time:80752ms step_avg:96.59ms +step:837/1695 train_time:80847ms step_avg:96.59ms +step:838/1695 train_time:80942ms step_avg:96.59ms +step:839/1695 train_time:81038ms step_avg:96.59ms +step:840/1695 train_time:81134ms step_avg:96.59ms +step:841/1695 train_time:81229ms step_avg:96.59ms +step:842/1695 train_time:81324ms step_avg:96.58ms +step:843/1695 train_time:81420ms step_avg:96.58ms +step:844/1695 train_time:81516ms step_avg:96.58ms +step:845/1695 train_time:81612ms step_avg:96.58ms +step:846/1695 train_time:81707ms step_avg:96.58ms +step:847/1695 train_time:81802ms step_avg:96.58ms +step:848/1695 train_time:81899ms step_avg:96.58ms +step:849/1695 train_time:81994ms step_avg:96.58ms +step:850/1695 train_time:82090ms step_avg:96.58ms +step:851/1695 train_time:82186ms step_avg:96.58ms +step:852/1695 train_time:82281ms step_avg:96.57ms +step:853/1695 
train_time:82377ms step_avg:96.57ms +step:854/1695 train_time:82473ms step_avg:96.57ms +step:855/1695 train_time:82569ms step_avg:96.57ms +step:856/1695 train_time:82664ms step_avg:96.57ms +step:857/1695 train_time:82760ms step_avg:96.57ms +step:858/1695 train_time:82855ms step_avg:96.57ms +step:859/1695 train_time:82952ms step_avg:96.57ms +step:860/1695 train_time:83047ms step_avg:96.57ms +step:861/1695 train_time:83143ms step_avg:96.57ms +step:862/1695 train_time:83239ms step_avg:96.56ms +step:863/1695 train_time:83566ms step_avg:96.83ms +step:864/1695 train_time:83759ms step_avg:96.94ms +step:865/1695 train_time:83853ms step_avg:96.94ms +step:866/1695 train_time:83948ms step_avg:96.94ms +step:867/1695 train_time:84042ms step_avg:96.93ms +step:868/1695 train_time:84138ms step_avg:96.93ms +step:869/1695 train_time:84233ms step_avg:96.93ms +step:870/1695 train_time:84327ms step_avg:96.93ms +step:871/1695 train_time:84421ms step_avg:96.92ms +step:872/1695 train_time:84516ms step_avg:96.92ms +step:873/1695 train_time:84616ms step_avg:96.93ms +step:874/1695 train_time:84714ms step_avg:96.93ms +step:875/1695 train_time:84811ms step_avg:96.93ms +step:875/1695 val_loss:3.5270 train_time:84905ms step_avg:97.03ms +step:876/1695 train_time:84930ms step_avg:96.95ms +step:877/1695 train_time:85012ms step_avg:96.94ms +step:878/1695 train_time:85111ms step_avg:96.94ms +step:879/1695 train_time:85209ms step_avg:96.94ms +step:880/1695 train_time:85304ms step_avg:96.94ms +step:881/1695 train_time:85400ms step_avg:96.94ms +step:882/1695 train_time:85494ms step_avg:96.93ms +step:883/1695 train_time:85589ms step_avg:96.93ms +step:884/1695 train_time:85685ms step_avg:96.93ms +step:885/1695 train_time:85780ms step_avg:96.93ms +step:886/1695 train_time:85876ms step_avg:96.93ms +step:887/1695 train_time:85973ms step_avg:96.93ms +step:888/1695 train_time:86071ms step_avg:96.93ms +step:889/1695 train_time:86170ms step_avg:96.93ms +step:890/1695 train_time:86267ms step_avg:96.93ms +step:891/1695 train_time:86364ms step_avg:96.93ms +step:892/1695 train_time:86459ms step_avg:96.93ms +step:893/1695 train_time:86554ms step_avg:96.92ms +step:894/1695 train_time:86649ms step_avg:96.92ms +step:895/1695 train_time:86745ms step_avg:96.92ms +step:896/1695 train_time:86842ms step_avg:96.92ms +step:897/1695 train_time:86938ms step_avg:96.92ms +step:898/1695 train_time:87034ms step_avg:96.92ms +step:899/1695 train_time:87131ms step_avg:96.92ms +step:900/1695 train_time:87229ms step_avg:96.92ms +step:901/1695 train_time:87325ms step_avg:96.92ms +step:902/1695 train_time:87422ms step_avg:96.92ms +step:903/1695 train_time:87517ms step_avg:96.92ms +step:904/1695 train_time:87612ms step_avg:96.92ms +step:905/1695 train_time:87708ms step_avg:96.91ms +step:906/1695 train_time:87804ms step_avg:96.91ms +step:907/1695 train_time:87901ms step_avg:96.91ms +step:908/1695 train_time:87997ms step_avg:96.91ms +step:909/1695 train_time:88092ms step_avg:96.91ms +step:910/1695 train_time:88188ms step_avg:96.91ms +step:911/1695 train_time:88285ms step_avg:96.91ms +step:912/1695 train_time:88381ms step_avg:96.91ms +step:913/1695 train_time:88477ms step_avg:96.91ms +step:914/1695 train_time:88572ms step_avg:96.91ms +step:915/1695 train_time:88667ms step_avg:96.90ms +step:916/1695 train_time:88763ms step_avg:96.90ms +step:917/1695 train_time:88860ms step_avg:96.90ms +step:918/1695 train_time:88955ms step_avg:96.90ms +step:919/1695 train_time:89050ms step_avg:96.90ms +step:920/1695 train_time:89147ms step_avg:96.90ms +step:921/1695 train_time:89245ms 
step_avg:96.90ms +step:922/1695 train_time:89341ms step_avg:96.90ms +step:923/1695 train_time:89436ms step_avg:96.90ms +step:924/1695 train_time:89531ms step_avg:96.90ms +step:925/1695 train_time:89627ms step_avg:96.89ms +step:926/1695 train_time:89724ms step_avg:96.89ms +step:927/1695 train_time:89821ms step_avg:96.89ms +step:928/1695 train_time:89916ms step_avg:96.89ms +step:929/1695 train_time:90011ms step_avg:96.89ms +step:930/1695 train_time:90108ms step_avg:96.89ms +step:931/1695 train_time:90205ms step_avg:96.89ms +step:932/1695 train_time:90302ms step_avg:96.89ms +step:933/1695 train_time:90398ms step_avg:96.89ms +step:934/1695 train_time:90493ms step_avg:96.89ms +step:935/1695 train_time:90589ms step_avg:96.89ms +step:936/1695 train_time:90685ms step_avg:96.89ms +step:937/1695 train_time:90782ms step_avg:96.89ms +step:938/1695 train_time:90877ms step_avg:96.88ms +step:939/1695 train_time:90973ms step_avg:96.88ms +step:940/1695 train_time:91069ms step_avg:96.88ms +step:941/1695 train_time:91167ms step_avg:96.88ms +step:942/1695 train_time:91264ms step_avg:96.88ms +step:943/1695 train_time:91361ms step_avg:96.88ms +step:944/1695 train_time:91456ms step_avg:96.88ms +step:945/1695 train_time:91552ms step_avg:96.88ms +step:946/1695 train_time:91648ms step_avg:96.88ms +step:947/1695 train_time:91744ms step_avg:96.88ms +step:948/1695 train_time:91840ms step_avg:96.88ms +step:949/1695 train_time:91935ms step_avg:96.88ms +step:950/1695 train_time:92031ms step_avg:96.87ms +step:951/1695 train_time:92127ms step_avg:96.87ms +step:952/1695 train_time:92223ms step_avg:96.87ms +step:953/1695 train_time:92319ms step_avg:96.87ms +step:954/1695 train_time:92415ms step_avg:96.87ms +step:955/1695 train_time:92510ms step_avg:96.87ms +step:956/1695 train_time:92606ms step_avg:96.87ms +step:957/1695 train_time:92702ms step_avg:96.87ms +step:958/1695 train_time:92798ms step_avg:96.87ms +step:959/1695 train_time:92894ms step_avg:96.87ms +step:960/1695 train_time:92990ms step_avg:96.86ms +step:961/1695 train_time:93087ms step_avg:96.86ms +step:962/1695 train_time:93183ms step_avg:96.86ms +step:963/1695 train_time:93279ms step_avg:96.86ms +step:964/1695 train_time:93374ms step_avg:96.86ms +step:965/1695 train_time:93471ms step_avg:96.86ms +step:966/1695 train_time:93568ms step_avg:96.86ms +step:967/1695 train_time:93664ms step_avg:96.86ms +step:968/1695 train_time:93760ms step_avg:96.86ms +step:969/1695 train_time:93856ms step_avg:96.86ms +step:970/1695 train_time:93951ms step_avg:96.86ms +step:971/1695 train_time:94048ms step_avg:96.86ms +step:972/1695 train_time:94144ms step_avg:96.86ms +step:973/1695 train_time:94241ms step_avg:96.86ms +step:974/1695 train_time:94336ms step_avg:96.85ms +step:975/1695 train_time:94431ms step_avg:96.85ms +step:976/1695 train_time:94528ms step_avg:96.85ms +step:977/1695 train_time:94625ms step_avg:96.85ms +step:978/1695 train_time:94722ms step_avg:96.85ms +step:979/1695 train_time:94818ms step_avg:96.85ms +step:980/1695 train_time:94913ms step_avg:96.85ms +step:981/1695 train_time:95010ms step_avg:96.85ms +step:982/1695 train_time:95106ms step_avg:96.85ms +step:983/1695 train_time:95202ms step_avg:96.85ms +step:984/1695 train_time:95298ms step_avg:96.85ms +step:985/1695 train_time:95393ms step_avg:96.85ms +step:986/1695 train_time:95489ms step_avg:96.84ms +step:987/1695 train_time:95585ms step_avg:96.84ms +step:988/1695 train_time:95682ms step_avg:96.84ms +step:989/1695 train_time:95777ms step_avg:96.84ms +step:990/1695 train_time:95872ms step_avg:96.84ms +step:991/1695 
train_time:95967ms step_avg:96.84ms +step:992/1695 train_time:96064ms step_avg:96.84ms +step:993/1695 train_time:96160ms step_avg:96.84ms +step:994/1695 train_time:96255ms step_avg:96.84ms +step:995/1695 train_time:96351ms step_avg:96.83ms +step:996/1695 train_time:96446ms step_avg:96.83ms +step:997/1695 train_time:96543ms step_avg:96.83ms +step:998/1695 train_time:96638ms step_avg:96.83ms +step:999/1695 train_time:96734ms step_avg:96.83ms +step:1000/1695 train_time:96830ms step_avg:96.83ms +step:1000/1695 val_loss:3.4844 train_time:96924ms step_avg:96.92ms +step:1001/1695 train_time:96949ms step_avg:96.85ms +step:1002/1695 train_time:97032ms step_avg:96.84ms +step:1003/1695 train_time:97130ms step_avg:96.84ms +step:1004/1695 train_time:97226ms step_avg:96.84ms +step:1005/1695 train_time:97322ms step_avg:96.84ms +step:1006/1695 train_time:97417ms step_avg:96.84ms +step:1007/1695 train_time:97512ms step_avg:96.83ms +step:1008/1695 train_time:97606ms step_avg:96.83ms +step:1009/1695 train_time:97702ms step_avg:96.83ms +step:1010/1695 train_time:97797ms step_avg:96.83ms +step:1011/1695 train_time:97893ms step_avg:96.83ms +step:1012/1695 train_time:97991ms step_avg:96.83ms +step:1013/1695 train_time:98089ms step_avg:96.83ms +step:1014/1695 train_time:98187ms step_avg:96.83ms +step:1015/1695 train_time:98284ms step_avg:96.83ms +step:1016/1695 train_time:98379ms step_avg:96.83ms +step:1017/1695 train_time:98474ms step_avg:96.83ms +step:1018/1695 train_time:98569ms step_avg:96.83ms +step:1019/1695 train_time:98665ms step_avg:96.83ms +step:1020/1695 train_time:98761ms step_avg:96.82ms +step:1021/1695 train_time:98857ms step_avg:96.82ms +step:1022/1695 train_time:98954ms step_avg:96.82ms +step:1023/1695 train_time:99050ms step_avg:96.82ms +step:1024/1695 train_time:99147ms step_avg:96.82ms +step:1025/1695 train_time:99243ms step_avg:96.82ms +step:1026/1695 train_time:99340ms step_avg:96.82ms +step:1027/1695 train_time:99436ms step_avg:96.82ms +step:1028/1695 train_time:99530ms step_avg:96.82ms +step:1029/1695 train_time:99625ms step_avg:96.82ms +step:1030/1695 train_time:99721ms step_avg:96.82ms +step:1031/1695 train_time:99818ms step_avg:96.82ms +step:1032/1695 train_time:99913ms step_avg:96.82ms +step:1033/1695 train_time:100009ms step_avg:96.81ms +step:1034/1695 train_time:100107ms step_avg:96.82ms +step:1035/1695 train_time:100204ms step_avg:96.82ms +step:1036/1695 train_time:100552ms step_avg:97.06ms +step:1037/1695 train_time:100724ms step_avg:97.13ms +step:1038/1695 train_time:100817ms step_avg:97.13ms +step:1039/1695 train_time:100912ms step_avg:97.12ms +step:1040/1695 train_time:101007ms step_avg:97.12ms +step:1041/1695 train_time:101101ms step_avg:97.12ms +step:1042/1695 train_time:101196ms step_avg:97.12ms +step:1043/1695 train_time:101290ms step_avg:97.11ms +step:1044/1695 train_time:101385ms step_avg:97.11ms +step:1045/1695 train_time:101480ms step_avg:97.11ms +step:1046/1695 train_time:101579ms step_avg:97.11ms +step:1047/1695 train_time:101680ms step_avg:97.12ms +step:1048/1695 train_time:101777ms step_avg:97.12ms +step:1049/1695 train_time:101873ms step_avg:97.11ms +step:1050/1695 train_time:101968ms step_avg:97.11ms +step:1051/1695 train_time:102063ms step_avg:97.11ms +step:1052/1695 train_time:102159ms step_avg:97.11ms +step:1053/1695 train_time:102253ms step_avg:97.11ms +step:1054/1695 train_time:102348ms step_avg:97.10ms +step:1055/1695 train_time:102443ms step_avg:97.10ms +step:1056/1695 train_time:102540ms step_avg:97.10ms +step:1057/1695 train_time:102637ms step_avg:97.10ms 
+step:1058/1695 train_time:102733ms step_avg:97.10ms +step:1059/1695 train_time:102829ms step_avg:97.10ms +step:1060/1695 train_time:102925ms step_avg:97.10ms +step:1061/1695 train_time:103021ms step_avg:97.10ms +step:1062/1695 train_time:103117ms step_avg:97.10ms +step:1063/1695 train_time:103212ms step_avg:97.09ms +step:1064/1695 train_time:103307ms step_avg:97.09ms +step:1065/1695 train_time:103402ms step_avg:97.09ms +step:1066/1695 train_time:103499ms step_avg:97.09ms +step:1067/1695 train_time:103595ms step_avg:97.09ms +step:1068/1695 train_time:103691ms step_avg:97.09ms +step:1069/1695 train_time:103789ms step_avg:97.09ms +step:1070/1695 train_time:103886ms step_avg:97.09ms +step:1071/1695 train_time:103983ms step_avg:97.09ms +step:1072/1695 train_time:104080ms step_avg:97.09ms +step:1073/1695 train_time:104176ms step_avg:97.09ms +step:1074/1695 train_time:104270ms step_avg:97.09ms +step:1075/1695 train_time:104366ms step_avg:97.08ms +step:1076/1695 train_time:104461ms step_avg:97.08ms +step:1077/1695 train_time:104558ms step_avg:97.08ms +step:1078/1695 train_time:104654ms step_avg:97.08ms +step:1079/1695 train_time:104750ms step_avg:97.08ms +step:1080/1695 train_time:104847ms step_avg:97.08ms +step:1081/1695 train_time:104946ms step_avg:97.08ms +step:1082/1695 train_time:105043ms step_avg:97.08ms +step:1083/1695 train_time:105140ms step_avg:97.08ms +step:1084/1695 train_time:105235ms step_avg:97.08ms +step:1085/1695 train_time:105330ms step_avg:97.08ms +step:1086/1695 train_time:105426ms step_avg:97.08ms +step:1087/1695 train_time:105523ms step_avg:97.08ms +step:1088/1695 train_time:105619ms step_avg:97.08ms +step:1089/1695 train_time:105715ms step_avg:97.08ms +step:1090/1695 train_time:105811ms step_avg:97.07ms +step:1091/1695 train_time:105907ms step_avg:97.07ms +step:1092/1695 train_time:106003ms step_avg:97.07ms +step:1093/1695 train_time:106099ms step_avg:97.07ms +step:1094/1695 train_time:106194ms step_avg:97.07ms +step:1095/1695 train_time:106289ms step_avg:97.07ms +step:1096/1695 train_time:106386ms step_avg:97.07ms +step:1097/1695 train_time:106482ms step_avg:97.07ms +step:1098/1695 train_time:106578ms step_avg:97.07ms +step:1099/1695 train_time:106673ms step_avg:97.06ms +step:1100/1695 train_time:106769ms step_avg:97.06ms +step:1101/1695 train_time:106866ms step_avg:97.06ms +step:1102/1695 train_time:106963ms step_avg:97.06ms +step:1103/1695 train_time:107060ms step_avg:97.06ms +step:1104/1695 train_time:107156ms step_avg:97.06ms +step:1105/1695 train_time:107251ms step_avg:97.06ms +step:1106/1695 train_time:107347ms step_avg:97.06ms +step:1107/1695 train_time:107443ms step_avg:97.06ms +step:1108/1695 train_time:107540ms step_avg:97.06ms +step:1109/1695 train_time:107636ms step_avg:97.06ms +step:1110/1695 train_time:107730ms step_avg:97.05ms +step:1111/1695 train_time:107826ms step_avg:97.05ms +step:1112/1695 train_time:107922ms step_avg:97.05ms +step:1113/1695 train_time:108019ms step_avg:97.05ms +step:1114/1695 train_time:108115ms step_avg:97.05ms +step:1115/1695 train_time:108211ms step_avg:97.05ms +step:1116/1695 train_time:108306ms step_avg:97.05ms +step:1117/1695 train_time:108403ms step_avg:97.05ms +step:1118/1695 train_time:108499ms step_avg:97.05ms +step:1119/1695 train_time:108594ms step_avg:97.05ms +step:1120/1695 train_time:108690ms step_avg:97.04ms +step:1121/1695 train_time:108786ms step_avg:97.04ms +step:1122/1695 train_time:108883ms step_avg:97.04ms +step:1123/1695 train_time:108979ms step_avg:97.04ms +step:1124/1695 train_time:109075ms step_avg:97.04ms 
+step:1125/1695 train_time:109170ms step_avg:97.04ms +step:1125/1695 val_loss:3.4368 train_time:109264ms step_avg:97.12ms +step:1126/1695 train_time:109288ms step_avg:97.06ms +step:1127/1695 train_time:109371ms step_avg:97.05ms +step:1128/1695 train_time:109469ms step_avg:97.05ms +step:1129/1695 train_time:109566ms step_avg:97.05ms +step:1130/1695 train_time:109662ms step_avg:97.05ms +step:1131/1695 train_time:109757ms step_avg:97.04ms +step:1132/1695 train_time:109852ms step_avg:97.04ms +step:1133/1695 train_time:109950ms step_avg:97.04ms +step:1134/1695 train_time:110047ms step_avg:97.04ms +step:1135/1695 train_time:110144ms step_avg:97.04ms +step:1136/1695 train_time:110243ms step_avg:97.04ms +step:1137/1695 train_time:110343ms step_avg:97.05ms +step:1138/1695 train_time:110442ms step_avg:97.05ms +step:1139/1695 train_time:110539ms step_avg:97.05ms +step:1140/1695 train_time:110636ms step_avg:97.05ms +step:1141/1695 train_time:110733ms step_avg:97.05ms +step:1142/1695 train_time:110829ms step_avg:97.05ms +step:1143/1695 train_time:110927ms step_avg:97.05ms +step:1144/1695 train_time:111024ms step_avg:97.05ms +step:1145/1695 train_time:111121ms step_avg:97.05ms +step:1146/1695 train_time:111220ms step_avg:97.05ms +step:1147/1695 train_time:111319ms step_avg:97.05ms +step:1148/1695 train_time:111417ms step_avg:97.05ms +step:1149/1695 train_time:111515ms step_avg:97.05ms +step:1150/1695 train_time:111613ms step_avg:97.05ms +step:1151/1695 train_time:111710ms step_avg:97.05ms +step:1152/1695 train_time:111807ms step_avg:97.05ms +step:1153/1695 train_time:111903ms step_avg:97.05ms +step:1154/1695 train_time:112000ms step_avg:97.05ms +step:1155/1695 train_time:112096ms step_avg:97.05ms +step:1156/1695 train_time:112194ms step_avg:97.05ms +step:1157/1695 train_time:112292ms step_avg:97.05ms +step:1158/1695 train_time:112392ms step_avg:97.06ms +step:1159/1695 train_time:112491ms step_avg:97.06ms +step:1160/1695 train_time:112592ms step_avg:97.06ms +step:1161/1695 train_time:112691ms step_avg:97.06ms +step:1162/1695 train_time:112789ms step_avg:97.06ms +step:1163/1695 train_time:112885ms step_avg:97.06ms +step:1164/1695 train_time:112983ms step_avg:97.06ms +step:1165/1695 train_time:113080ms step_avg:97.06ms +step:1166/1695 train_time:113177ms step_avg:97.06ms +step:1167/1695 train_time:113274ms step_avg:97.06ms +step:1168/1695 train_time:113372ms step_avg:97.06ms +step:1169/1695 train_time:113471ms step_avg:97.07ms +step:1170/1695 train_time:113571ms step_avg:97.07ms +step:1171/1695 train_time:113670ms step_avg:97.07ms +step:1172/1695 train_time:113769ms step_avg:97.07ms +step:1173/1695 train_time:113866ms step_avg:97.07ms +step:1174/1695 train_time:113964ms step_avg:97.07ms +step:1175/1695 train_time:114063ms step_avg:97.08ms +step:1176/1695 train_time:114161ms step_avg:97.08ms +step:1177/1695 train_time:114259ms step_avg:97.08ms +step:1178/1695 train_time:114356ms step_avg:97.08ms +step:1179/1695 train_time:114453ms step_avg:97.08ms +step:1180/1695 train_time:114551ms step_avg:97.08ms +step:1181/1695 train_time:114649ms step_avg:97.08ms +step:1182/1695 train_time:114746ms step_avg:97.08ms +step:1183/1695 train_time:114844ms step_avg:97.08ms +step:1184/1695 train_time:114942ms step_avg:97.08ms +step:1185/1695 train_time:115039ms step_avg:97.08ms +step:1186/1695 train_time:115136ms step_avg:97.08ms +step:1187/1695 train_time:115233ms step_avg:97.08ms +step:1188/1695 train_time:115331ms step_avg:97.08ms +step:1189/1695 train_time:115429ms step_avg:97.08ms +step:1190/1695 train_time:115527ms 
step_avg:97.08ms +step:1191/1695 train_time:115625ms step_avg:97.08ms +step:1192/1695 train_time:115723ms step_avg:97.08ms +step:1193/1695 train_time:115821ms step_avg:97.08ms +step:1194/1695 train_time:115918ms step_avg:97.08ms +step:1195/1695 train_time:116016ms step_avg:97.08ms +step:1196/1695 train_time:116114ms step_avg:97.09ms +step:1197/1695 train_time:116212ms step_avg:97.09ms +step:1198/1695 train_time:116310ms step_avg:97.09ms +step:1199/1695 train_time:116409ms step_avg:97.09ms +step:1200/1695 train_time:116509ms step_avg:97.09ms +step:1201/1695 train_time:116609ms step_avg:97.09ms +step:1202/1695 train_time:116708ms step_avg:97.09ms +step:1203/1695 train_time:116808ms step_avg:97.10ms +step:1204/1695 train_time:116906ms step_avg:97.10ms +step:1205/1695 train_time:117005ms step_avg:97.10ms +step:1206/1695 train_time:117103ms step_avg:97.10ms +step:1207/1695 train_time:117201ms step_avg:97.10ms +step:1208/1695 train_time:117548ms step_avg:97.31ms +step:1209/1695 train_time:117728ms step_avg:97.38ms +step:1210/1695 train_time:117823ms step_avg:97.37ms +step:1211/1695 train_time:117920ms step_avg:97.37ms +step:1212/1695 train_time:118016ms step_avg:97.37ms +step:1213/1695 train_time:118112ms step_avg:97.37ms +step:1214/1695 train_time:118209ms step_avg:97.37ms +step:1215/1695 train_time:118306ms step_avg:97.37ms +step:1216/1695 train_time:118402ms step_avg:97.37ms +step:1217/1695 train_time:118500ms step_avg:97.37ms +step:1218/1695 train_time:118604ms step_avg:97.38ms +step:1219/1695 train_time:118704ms step_avg:97.38ms +step:1220/1695 train_time:118801ms step_avg:97.38ms +step:1221/1695 train_time:118897ms step_avg:97.38ms +step:1222/1695 train_time:118994ms step_avg:97.38ms +step:1223/1695 train_time:119090ms step_avg:97.38ms +step:1224/1695 train_time:119187ms step_avg:97.38ms +step:1225/1695 train_time:119285ms step_avg:97.38ms +step:1226/1695 train_time:119382ms step_avg:97.38ms +step:1227/1695 train_time:119480ms step_avg:97.38ms +step:1228/1695 train_time:119579ms step_avg:97.38ms +step:1229/1695 train_time:119678ms step_avg:97.38ms +step:1230/1695 train_time:119776ms step_avg:97.38ms +step:1231/1695 train_time:119874ms step_avg:97.38ms +step:1232/1695 train_time:119971ms step_avg:97.38ms +step:1233/1695 train_time:120068ms step_avg:97.38ms +step:1234/1695 train_time:120166ms step_avg:97.38ms +step:1235/1695 train_time:120263ms step_avg:97.38ms +step:1236/1695 train_time:120360ms step_avg:97.38ms +step:1237/1695 train_time:120457ms step_avg:97.38ms +step:1238/1695 train_time:120555ms step_avg:97.38ms +step:1239/1695 train_time:120654ms step_avg:97.38ms +step:1240/1695 train_time:120752ms step_avg:97.38ms +step:1241/1695 train_time:120851ms step_avg:97.38ms +step:1242/1695 train_time:120950ms step_avg:97.38ms +step:1243/1695 train_time:121048ms step_avg:97.38ms +step:1244/1695 train_time:121145ms step_avg:97.38ms +step:1245/1695 train_time:121243ms step_avg:97.38ms +step:1246/1695 train_time:121340ms step_avg:97.38ms +step:1247/1695 train_time:121437ms step_avg:97.38ms +step:1248/1695 train_time:121534ms step_avg:97.38ms +step:1249/1695 train_time:121632ms step_avg:97.38ms +step:1250/1695 train_time:121731ms step_avg:97.38ms +step:1250/1695 val_loss:3.3897 train_time:121827ms step_avg:97.46ms +step:1251/1695 train_time:121854ms step_avg:97.40ms +step:1252/1695 train_time:121931ms step_avg:97.39ms +step:1253/1695 train_time:122027ms step_avg:97.39ms +step:1254/1695 train_time:122123ms step_avg:97.39ms +step:1255/1695 train_time:122220ms step_avg:97.39ms +step:1256/1695 
train_time:122317ms step_avg:97.39ms +step:1257/1695 train_time:122414ms step_avg:97.39ms +step:1258/1695 train_time:122510ms step_avg:97.38ms +step:1259/1695 train_time:122606ms step_avg:97.38ms +step:1260/1695 train_time:122702ms step_avg:97.38ms +step:1261/1695 train_time:122805ms step_avg:97.39ms +step:1262/1695 train_time:122904ms step_avg:97.39ms +step:1263/1695 train_time:123002ms step_avg:97.39ms +step:1264/1695 train_time:123099ms step_avg:97.39ms +step:1265/1695 train_time:123196ms step_avg:97.39ms +step:1266/1695 train_time:123292ms step_avg:97.39ms +step:1267/1695 train_time:123390ms step_avg:97.39ms +step:1268/1695 train_time:123486ms step_avg:97.39ms +step:1269/1695 train_time:123583ms step_avg:97.39ms +step:1270/1695 train_time:123681ms step_avg:97.39ms +step:1271/1695 train_time:123780ms step_avg:97.39ms +step:1272/1695 train_time:123879ms step_avg:97.39ms +step:1273/1695 train_time:123978ms step_avg:97.39ms +step:1274/1695 train_time:124078ms step_avg:97.39ms +step:1275/1695 train_time:124176ms step_avg:97.39ms +step:1276/1695 train_time:124275ms step_avg:97.39ms +step:1277/1695 train_time:124372ms step_avg:97.39ms +step:1278/1695 train_time:124470ms step_avg:97.39ms +step:1279/1695 train_time:124567ms step_avg:97.39ms +step:1280/1695 train_time:124664ms step_avg:97.39ms +step:1281/1695 train_time:124761ms step_avg:97.39ms +step:1282/1695 train_time:124859ms step_avg:97.39ms +step:1283/1695 train_time:124959ms step_avg:97.40ms +step:1284/1695 train_time:125057ms step_avg:97.40ms +step:1285/1695 train_time:125156ms step_avg:97.40ms +step:1286/1695 train_time:125254ms step_avg:97.40ms +step:1287/1695 train_time:125351ms step_avg:97.40ms +step:1288/1695 train_time:125449ms step_avg:97.40ms +step:1289/1695 train_time:125546ms step_avg:97.40ms +step:1290/1695 train_time:125644ms step_avg:97.40ms +step:1291/1695 train_time:125741ms step_avg:97.40ms +step:1292/1695 train_time:125839ms step_avg:97.40ms +step:1293/1695 train_time:125938ms step_avg:97.40ms +step:1294/1695 train_time:126037ms step_avg:97.40ms +step:1295/1695 train_time:126136ms step_avg:97.40ms +step:1296/1695 train_time:126234ms step_avg:97.40ms +step:1297/1695 train_time:126333ms step_avg:97.40ms +step:1298/1695 train_time:126431ms step_avg:97.40ms +step:1299/1695 train_time:126529ms step_avg:97.41ms +step:1300/1695 train_time:126628ms step_avg:97.41ms +step:1301/1695 train_time:126726ms step_avg:97.41ms +step:1302/1695 train_time:126823ms step_avg:97.41ms +step:1303/1695 train_time:126921ms step_avg:97.41ms +step:1304/1695 train_time:127018ms step_avg:97.41ms +step:1305/1695 train_time:127117ms step_avg:97.41ms +step:1306/1695 train_time:127216ms step_avg:97.41ms +step:1307/1695 train_time:127314ms step_avg:97.41ms +step:1308/1695 train_time:127412ms step_avg:97.41ms +step:1309/1695 train_time:127509ms step_avg:97.41ms +step:1310/1695 train_time:127608ms step_avg:97.41ms +step:1311/1695 train_time:127705ms step_avg:97.41ms +step:1312/1695 train_time:127802ms step_avg:97.41ms +step:1313/1695 train_time:127899ms step_avg:97.41ms +step:1314/1695 train_time:127996ms step_avg:97.41ms +step:1315/1695 train_time:128095ms step_avg:97.41ms +step:1316/1695 train_time:128193ms step_avg:97.41ms +step:1317/1695 train_time:128291ms step_avg:97.41ms +step:1318/1695 train_time:128389ms step_avg:97.41ms +step:1319/1695 train_time:128485ms step_avg:97.41ms +step:1320/1695 train_time:128582ms step_avg:97.41ms +step:1321/1695 train_time:128680ms step_avg:97.41ms +step:1322/1695 train_time:128778ms step_avg:97.41ms +step:1323/1695 
train_time:128876ms step_avg:97.41ms +step:1324/1695 train_time:128974ms step_avg:97.41ms +step:1325/1695 train_time:129072ms step_avg:97.41ms +step:1326/1695 train_time:129170ms step_avg:97.41ms +step:1327/1695 train_time:129268ms step_avg:97.41ms +step:1328/1695 train_time:129366ms step_avg:97.41ms +step:1329/1695 train_time:129463ms step_avg:97.41ms +step:1330/1695 train_time:129561ms step_avg:97.41ms +step:1331/1695 train_time:129659ms step_avg:97.41ms +step:1332/1695 train_time:129758ms step_avg:97.42ms +step:1333/1695 train_time:129857ms step_avg:97.42ms +step:1334/1695 train_time:129955ms step_avg:97.42ms +step:1335/1695 train_time:130053ms step_avg:97.42ms +step:1336/1695 train_time:130151ms step_avg:97.42ms +step:1337/1695 train_time:130248ms step_avg:97.42ms +step:1338/1695 train_time:130347ms step_avg:97.42ms +step:1339/1695 train_time:130444ms step_avg:97.42ms +step:1340/1695 train_time:130541ms step_avg:97.42ms +step:1341/1695 train_time:130639ms step_avg:97.42ms +step:1342/1695 train_time:130736ms step_avg:97.42ms +step:1343/1695 train_time:130835ms step_avg:97.42ms +step:1344/1695 train_time:130933ms step_avg:97.42ms +step:1345/1695 train_time:131030ms step_avg:97.42ms +step:1346/1695 train_time:131127ms step_avg:97.42ms +step:1347/1695 train_time:131224ms step_avg:97.42ms +step:1348/1695 train_time:131321ms step_avg:97.42ms +step:1349/1695 train_time:131419ms step_avg:97.42ms +step:1350/1695 train_time:131518ms step_avg:97.42ms +step:1351/1695 train_time:131615ms step_avg:97.42ms +step:1352/1695 train_time:131714ms step_avg:97.42ms +step:1353/1695 train_time:131813ms step_avg:97.42ms +step:1354/1695 train_time:131911ms step_avg:97.42ms +step:1355/1695 train_time:132009ms step_avg:97.42ms +step:1356/1695 train_time:132106ms step_avg:97.42ms +step:1357/1695 train_time:132203ms step_avg:97.42ms +step:1358/1695 train_time:132300ms step_avg:97.42ms +step:1359/1695 train_time:132398ms step_avg:97.42ms +step:1360/1695 train_time:132497ms step_avg:97.42ms +step:1361/1695 train_time:132595ms step_avg:97.42ms +step:1362/1695 train_time:132693ms step_avg:97.43ms +step:1363/1695 train_time:132792ms step_avg:97.43ms +step:1364/1695 train_time:132890ms step_avg:97.43ms +step:1365/1695 train_time:132988ms step_avg:97.43ms +step:1366/1695 train_time:133085ms step_avg:97.43ms +step:1367/1695 train_time:133182ms step_avg:97.43ms +step:1368/1695 train_time:133279ms step_avg:97.43ms +step:1369/1695 train_time:133377ms step_avg:97.43ms +step:1370/1695 train_time:133476ms step_avg:97.43ms +step:1371/1695 train_time:133574ms step_avg:97.43ms +step:1372/1695 train_time:133671ms step_avg:97.43ms +step:1373/1695 train_time:133769ms step_avg:97.43ms +step:1374/1695 train_time:133867ms step_avg:97.43ms +step:1375/1695 train_time:133964ms step_avg:97.43ms +step:1375/1695 val_loss:3.3507 train_time:134060ms step_avg:97.50ms +step:1376/1695 train_time:134085ms step_avg:97.45ms +step:1377/1695 train_time:134167ms step_avg:97.43ms +step:1378/1695 train_time:134266ms step_avg:97.44ms +step:1379/1695 train_time:134364ms step_avg:97.44ms +step:1380/1695 train_time:134461ms step_avg:97.44ms +step:1381/1695 train_time:134815ms step_avg:97.62ms +step:1382/1695 train_time:134984ms step_avg:97.67ms +step:1383/1695 train_time:135080ms step_avg:97.67ms +step:1384/1695 train_time:135176ms step_avg:97.67ms +step:1385/1695 train_time:135272ms step_avg:97.67ms +step:1386/1695 train_time:135369ms step_avg:97.67ms +step:1387/1695 train_time:135465ms step_avg:97.67ms +step:1388/1695 train_time:135562ms step_avg:97.67ms 
+step:1389/1695 train_time:135658ms step_avg:97.67ms +step:1390/1695 train_time:135756ms step_avg:97.67ms +step:1391/1695 train_time:135859ms step_avg:97.67ms +step:1392/1695 train_time:135961ms step_avg:97.67ms +step:1393/1695 train_time:136060ms step_avg:97.67ms +step:1394/1695 train_time:136156ms step_avg:97.67ms +step:1395/1695 train_time:136253ms step_avg:97.67ms +step:1396/1695 train_time:136350ms step_avg:97.67ms +step:1397/1695 train_time:136446ms step_avg:97.67ms +step:1398/1695 train_time:136542ms step_avg:97.67ms +step:1399/1695 train_time:136639ms step_avg:97.67ms +step:1400/1695 train_time:136736ms step_avg:97.67ms +step:1401/1695 train_time:136834ms step_avg:97.67ms +step:1402/1695 train_time:136933ms step_avg:97.67ms +step:1403/1695 train_time:137032ms step_avg:97.67ms +step:1404/1695 train_time:137131ms step_avg:97.67ms +step:1405/1695 train_time:137230ms step_avg:97.67ms +step:1406/1695 train_time:137328ms step_avg:97.67ms +step:1407/1695 train_time:137425ms step_avg:97.67ms +step:1408/1695 train_time:137522ms step_avg:97.67ms +step:1409/1695 train_time:137619ms step_avg:97.67ms +step:1410/1695 train_time:137716ms step_avg:97.67ms +step:1411/1695 train_time:137813ms step_avg:97.67ms +step:1412/1695 train_time:137912ms step_avg:97.67ms +step:1413/1695 train_time:138011ms step_avg:97.67ms +step:1414/1695 train_time:138111ms step_avg:97.67ms +step:1415/1695 train_time:138210ms step_avg:97.67ms +step:1416/1695 train_time:138308ms step_avg:97.67ms +step:1417/1695 train_time:138405ms step_avg:97.67ms +step:1418/1695 train_time:138502ms step_avg:97.67ms +step:1419/1695 train_time:138600ms step_avg:97.67ms +step:1420/1695 train_time:138698ms step_avg:97.67ms +step:1421/1695 train_time:138795ms step_avg:97.67ms +step:1422/1695 train_time:138892ms step_avg:97.67ms +step:1423/1695 train_time:138991ms step_avg:97.67ms +step:1424/1695 train_time:139089ms step_avg:97.67ms +step:1425/1695 train_time:139188ms step_avg:97.68ms +step:1426/1695 train_time:139285ms step_avg:97.68ms +step:1427/1695 train_time:139382ms step_avg:97.68ms +step:1428/1695 train_time:139480ms step_avg:97.68ms +step:1429/1695 train_time:139577ms step_avg:97.67ms +step:1430/1695 train_time:139673ms step_avg:97.67ms +step:1431/1695 train_time:139771ms step_avg:97.67ms +step:1432/1695 train_time:139869ms step_avg:97.67ms +step:1433/1695 train_time:139968ms step_avg:97.67ms +step:1434/1695 train_time:140067ms step_avg:97.68ms +step:1435/1695 train_time:140167ms step_avg:97.68ms +step:1436/1695 train_time:140266ms step_avg:97.68ms +step:1437/1695 train_time:140363ms step_avg:97.68ms +step:1438/1695 train_time:140461ms step_avg:97.68ms +step:1439/1695 train_time:140558ms step_avg:97.68ms +step:1440/1695 train_time:140655ms step_avg:97.68ms +step:1441/1695 train_time:140752ms step_avg:97.68ms +step:1442/1695 train_time:140849ms step_avg:97.68ms +step:1443/1695 train_time:140948ms step_avg:97.68ms +step:1444/1695 train_time:141045ms step_avg:97.68ms +step:1445/1695 train_time:141144ms step_avg:97.68ms +step:1446/1695 train_time:141242ms step_avg:97.68ms +step:1447/1695 train_time:141340ms step_avg:97.68ms +step:1448/1695 train_time:141438ms step_avg:97.68ms +step:1449/1695 train_time:141534ms step_avg:97.68ms +step:1450/1695 train_time:141631ms step_avg:97.68ms +step:1451/1695 train_time:141728ms step_avg:97.68ms +step:1452/1695 train_time:141826ms step_avg:97.68ms +step:1453/1695 train_time:141924ms step_avg:97.68ms +step:1454/1695 train_time:142021ms step_avg:97.68ms +step:1455/1695 train_time:142117ms step_avg:97.68ms 
+step:1456/1695 train_time:142216ms step_avg:97.68ms +step:1457/1695 train_time:142314ms step_avg:97.68ms +step:1458/1695 train_time:142413ms step_avg:97.68ms +step:1459/1695 train_time:142510ms step_avg:97.68ms +step:1460/1695 train_time:142608ms step_avg:97.68ms +step:1461/1695 train_time:142706ms step_avg:97.68ms +step:1462/1695 train_time:142803ms step_avg:97.68ms +step:1463/1695 train_time:142901ms step_avg:97.68ms +step:1464/1695 train_time:142999ms step_avg:97.68ms +step:1465/1695 train_time:143096ms step_avg:97.68ms +step:1466/1695 train_time:143194ms step_avg:97.68ms +step:1467/1695 train_time:143291ms step_avg:97.68ms +step:1468/1695 train_time:143389ms step_avg:97.68ms +step:1469/1695 train_time:143487ms step_avg:97.68ms +step:1470/1695 train_time:143585ms step_avg:97.68ms +step:1471/1695 train_time:143682ms step_avg:97.68ms +step:1472/1695 train_time:143779ms step_avg:97.68ms +step:1473/1695 train_time:143877ms step_avg:97.68ms +step:1474/1695 train_time:143974ms step_avg:97.68ms +step:1475/1695 train_time:144072ms step_avg:97.68ms +step:1476/1695 train_time:144169ms step_avg:97.68ms +step:1477/1695 train_time:144267ms step_avg:97.68ms +step:1478/1695 train_time:144365ms step_avg:97.68ms +step:1479/1695 train_time:144462ms step_avg:97.68ms +step:1480/1695 train_time:144559ms step_avg:97.68ms +step:1481/1695 train_time:144657ms step_avg:97.67ms +step:1482/1695 train_time:144754ms step_avg:97.67ms +step:1483/1695 train_time:144852ms step_avg:97.67ms +step:1484/1695 train_time:144949ms step_avg:97.67ms +step:1485/1695 train_time:145048ms step_avg:97.68ms +step:1486/1695 train_time:145146ms step_avg:97.68ms +step:1487/1695 train_time:145244ms step_avg:97.68ms +step:1488/1695 train_time:145341ms step_avg:97.68ms +step:1489/1695 train_time:145438ms step_avg:97.67ms +step:1490/1695 train_time:145535ms step_avg:97.67ms +step:1491/1695 train_time:145632ms step_avg:97.67ms +step:1492/1695 train_time:145730ms step_avg:97.67ms +step:1493/1695 train_time:145829ms step_avg:97.68ms +step:1494/1695 train_time:145927ms step_avg:97.68ms +step:1495/1695 train_time:146024ms step_avg:97.68ms +step:1496/1695 train_time:146123ms step_avg:97.68ms +step:1497/1695 train_time:146220ms step_avg:97.68ms +step:1498/1695 train_time:146317ms step_avg:97.67ms +step:1499/1695 train_time:146414ms step_avg:97.67ms +step:1500/1695 train_time:146512ms step_avg:97.67ms +step:1500/1695 val_loss:3.3178 train_time:146608ms step_avg:97.74ms +step:1501/1695 train_time:146633ms step_avg:97.69ms +step:1502/1695 train_time:146718ms step_avg:97.68ms +step:1503/1695 train_time:146818ms step_avg:97.68ms +step:1504/1695 train_time:146916ms step_avg:97.68ms +step:1505/1695 train_time:147013ms step_avg:97.68ms +step:1506/1695 train_time:147110ms step_avg:97.68ms +step:1507/1695 train_time:147206ms step_avg:97.68ms +step:1508/1695 train_time:147302ms step_avg:97.68ms +step:1509/1695 train_time:147399ms step_avg:97.68ms +step:1510/1695 train_time:147495ms step_avg:97.68ms +step:1511/1695 train_time:147595ms step_avg:97.68ms +step:1512/1695 train_time:147697ms step_avg:97.68ms +step:1513/1695 train_time:147797ms step_avg:97.68ms +step:1514/1695 train_time:147896ms step_avg:97.69ms +step:1515/1695 train_time:147994ms step_avg:97.69ms +step:1516/1695 train_time:148092ms step_avg:97.69ms +step:1517/1695 train_time:148190ms step_avg:97.69ms +step:1518/1695 train_time:148287ms step_avg:97.69ms +step:1519/1695 train_time:148384ms step_avg:97.69ms +step:1520/1695 train_time:148481ms step_avg:97.68ms +step:1521/1695 train_time:148578ms 
step_avg:97.68ms +step:1522/1695 train_time:148676ms step_avg:97.68ms +step:1523/1695 train_time:148776ms step_avg:97.69ms +step:1524/1695 train_time:148874ms step_avg:97.69ms +step:1525/1695 train_time:148973ms step_avg:97.69ms +step:1526/1695 train_time:149072ms step_avg:97.69ms +step:1527/1695 train_time:149169ms step_avg:97.69ms +step:1528/1695 train_time:149266ms step_avg:97.69ms +step:1529/1695 train_time:149362ms step_avg:97.69ms +step:1530/1695 train_time:149460ms step_avg:97.69ms +step:1531/1695 train_time:149557ms step_avg:97.69ms +step:1532/1695 train_time:149655ms step_avg:97.69ms +step:1533/1695 train_time:149753ms step_avg:97.69ms +step:1534/1695 train_time:149852ms step_avg:97.69ms +step:1535/1695 train_time:149950ms step_avg:97.69ms +step:1536/1695 train_time:150048ms step_avg:97.69ms +step:1537/1695 train_time:150146ms step_avg:97.69ms +step:1538/1695 train_time:150244ms step_avg:97.69ms +step:1539/1695 train_time:150341ms step_avg:97.69ms +step:1540/1695 train_time:150438ms step_avg:97.69ms +step:1541/1695 train_time:150535ms step_avg:97.69ms +step:1542/1695 train_time:150633ms step_avg:97.69ms +step:1543/1695 train_time:150731ms step_avg:97.69ms +step:1544/1695 train_time:150830ms step_avg:97.69ms +step:1545/1695 train_time:150928ms step_avg:97.69ms +step:1546/1695 train_time:151027ms step_avg:97.69ms +step:1547/1695 train_time:151123ms step_avg:97.69ms +step:1548/1695 train_time:151220ms step_avg:97.69ms +step:1549/1695 train_time:151317ms step_avg:97.69ms +step:1550/1695 train_time:151415ms step_avg:97.69ms +step:1551/1695 train_time:151513ms step_avg:97.69ms +step:1552/1695 train_time:151866ms step_avg:97.85ms +step:1553/1695 train_time:152044ms step_avg:97.90ms +step:1554/1695 train_time:152139ms step_avg:97.90ms +step:1555/1695 train_time:152235ms step_avg:97.90ms +step:1556/1695 train_time:152332ms step_avg:97.90ms +step:1557/1695 train_time:152428ms step_avg:97.90ms +step:1558/1695 train_time:152525ms step_avg:97.90ms +step:1559/1695 train_time:152621ms step_avg:97.90ms +step:1560/1695 train_time:152717ms step_avg:97.90ms +step:1561/1695 train_time:152815ms step_avg:97.90ms +step:1562/1695 train_time:152920ms step_avg:97.90ms +step:1563/1695 train_time:153021ms step_avg:97.90ms +step:1564/1695 train_time:153120ms step_avg:97.90ms +step:1565/1695 train_time:153218ms step_avg:97.90ms +step:1566/1695 train_time:153315ms step_avg:97.90ms +step:1567/1695 train_time:153412ms step_avg:97.90ms +step:1568/1695 train_time:153509ms step_avg:97.90ms +step:1569/1695 train_time:153606ms step_avg:97.90ms +step:1570/1695 train_time:153702ms step_avg:97.90ms +step:1571/1695 train_time:153798ms step_avg:97.90ms +step:1572/1695 train_time:153898ms step_avg:97.90ms +step:1573/1695 train_time:153999ms step_avg:97.90ms +step:1574/1695 train_time:154099ms step_avg:97.90ms +step:1575/1695 train_time:154197ms step_avg:97.90ms +step:1576/1695 train_time:154294ms step_avg:97.90ms +step:1577/1695 train_time:154393ms step_avg:97.90ms +step:1578/1695 train_time:154490ms step_avg:97.90ms +step:1579/1695 train_time:154587ms step_avg:97.90ms +step:1580/1695 train_time:154684ms step_avg:97.90ms +step:1581/1695 train_time:154781ms step_avg:97.90ms +step:1582/1695 train_time:154879ms step_avg:97.90ms +step:1583/1695 train_time:154977ms step_avg:97.90ms +step:1584/1695 train_time:155075ms step_avg:97.90ms +step:1585/1695 train_time:155174ms step_avg:97.90ms +step:1586/1695 train_time:155273ms step_avg:97.90ms +step:1587/1695 train_time:155371ms step_avg:97.90ms +step:1588/1695 train_time:155469ms 
step_avg:97.90ms +step:1589/1695 train_time:155566ms step_avg:97.90ms +step:1590/1695 train_time:155663ms step_avg:97.90ms +step:1591/1695 train_time:155760ms step_avg:97.90ms +step:1592/1695 train_time:155858ms step_avg:97.90ms +step:1593/1695 train_time:155956ms step_avg:97.90ms +step:1594/1695 train_time:156054ms step_avg:97.90ms +step:1595/1695 train_time:156152ms step_avg:97.90ms +step:1596/1695 train_time:156250ms step_avg:97.90ms +step:1597/1695 train_time:156350ms step_avg:97.90ms +step:1598/1695 train_time:156448ms step_avg:97.90ms +step:1599/1695 train_time:156546ms step_avg:97.90ms +step:1600/1695 train_time:156644ms step_avg:97.90ms +step:1601/1695 train_time:156741ms step_avg:97.90ms +step:1602/1695 train_time:156838ms step_avg:97.90ms +step:1603/1695 train_time:156936ms step_avg:97.90ms +step:1604/1695 train_time:157034ms step_avg:97.90ms +step:1605/1695 train_time:157133ms step_avg:97.90ms +step:1606/1695 train_time:157234ms step_avg:97.90ms +step:1607/1695 train_time:157333ms step_avg:97.90ms +step:1608/1695 train_time:157431ms step_avg:97.91ms +step:1609/1695 train_time:157529ms step_avg:97.91ms +step:1610/1695 train_time:157627ms step_avg:97.91ms +step:1611/1695 train_time:157726ms step_avg:97.91ms +step:1612/1695 train_time:157824ms step_avg:97.91ms +step:1613/1695 train_time:157921ms step_avg:97.91ms +step:1614/1695 train_time:158017ms step_avg:97.90ms +step:1615/1695 train_time:158114ms step_avg:97.90ms +step:1616/1695 train_time:158212ms step_avg:97.90ms +step:1617/1695 train_time:158312ms step_avg:97.90ms +step:1618/1695 train_time:158412ms step_avg:97.91ms +step:1619/1695 train_time:158510ms step_avg:97.91ms +step:1620/1695 train_time:158609ms step_avg:97.91ms +step:1621/1695 train_time:158708ms step_avg:97.91ms +step:1622/1695 train_time:158806ms step_avg:97.91ms +step:1623/1695 train_time:158905ms step_avg:97.91ms +step:1624/1695 train_time:159001ms step_avg:97.91ms +step:1625/1695 train_time:159097ms step_avg:97.91ms +step:1625/1695 val_loss:3.2907 train_time:159193ms step_avg:97.96ms +step:1626/1695 train_time:159217ms step_avg:97.92ms +step:1627/1695 train_time:159299ms step_avg:97.91ms +step:1628/1695 train_time:159398ms step_avg:97.91ms +step:1629/1695 train_time:159495ms step_avg:97.91ms +step:1630/1695 train_time:159593ms step_avg:97.91ms +step:1631/1695 train_time:159690ms step_avg:97.91ms +step:1632/1695 train_time:159787ms step_avg:97.91ms +step:1633/1695 train_time:159884ms step_avg:97.91ms +step:1634/1695 train_time:159981ms step_avg:97.91ms +step:1635/1695 train_time:160077ms step_avg:97.91ms +step:1636/1695 train_time:160176ms step_avg:97.91ms +step:1637/1695 train_time:160276ms step_avg:97.91ms +step:1638/1695 train_time:160375ms step_avg:97.91ms +step:1639/1695 train_time:160474ms step_avg:97.91ms +step:1640/1695 train_time:160571ms step_avg:97.91ms +step:1641/1695 train_time:160669ms step_avg:97.91ms +step:1642/1695 train_time:160766ms step_avg:97.91ms +step:1643/1695 train_time:160864ms step_avg:97.91ms +step:1644/1695 train_time:160961ms step_avg:97.91ms +step:1645/1695 train_time:161058ms step_avg:97.91ms +step:1646/1695 train_time:161157ms step_avg:97.91ms +step:1647/1695 train_time:161255ms step_avg:97.91ms +step:1648/1695 train_time:161353ms step_avg:97.91ms +step:1649/1695 train_time:161452ms step_avg:97.91ms +step:1650/1695 train_time:161551ms step_avg:97.91ms +step:1651/1695 train_time:161649ms step_avg:97.91ms +step:1652/1695 train_time:161746ms step_avg:97.91ms +step:1653/1695 train_time:161843ms step_avg:97.91ms +step:1654/1695 
train_time:161940ms step_avg:97.91ms +step:1655/1695 train_time:162037ms step_avg:97.91ms +step:1656/1695 train_time:162135ms step_avg:97.91ms +step:1657/1695 train_time:162233ms step_avg:97.91ms +step:1658/1695 train_time:162332ms step_avg:97.91ms +step:1659/1695 train_time:162430ms step_avg:97.91ms +step:1660/1695 train_time:162528ms step_avg:97.91ms +step:1661/1695 train_time:162627ms step_avg:97.91ms +step:1662/1695 train_time:162724ms step_avg:97.91ms +step:1663/1695 train_time:162820ms step_avg:97.91ms +step:1664/1695 train_time:162918ms step_avg:97.91ms +step:1665/1695 train_time:163015ms step_avg:97.91ms +step:1666/1695 train_time:163114ms step_avg:97.91ms +step:1667/1695 train_time:163211ms step_avg:97.91ms +step:1668/1695 train_time:163309ms step_avg:97.91ms +step:1669/1695 train_time:163407ms step_avg:97.91ms +step:1670/1695 train_time:163505ms step_avg:97.91ms +step:1671/1695 train_time:163603ms step_avg:97.91ms +step:1672/1695 train_time:163700ms step_avg:97.91ms +step:1673/1695 train_time:163797ms step_avg:97.91ms +step:1674/1695 train_time:163894ms step_avg:97.91ms +step:1675/1695 train_time:163992ms step_avg:97.91ms +step:1676/1695 train_time:164091ms step_avg:97.91ms +step:1677/1695 train_time:164189ms step_avg:97.91ms +step:1678/1695 train_time:164287ms step_avg:97.91ms +step:1679/1695 train_time:164385ms step_avg:97.91ms +step:1680/1695 train_time:164482ms step_avg:97.91ms +step:1681/1695 train_time:164580ms step_avg:97.91ms +step:1682/1695 train_time:164677ms step_avg:97.91ms +step:1683/1695 train_time:164775ms step_avg:97.91ms +step:1684/1695 train_time:164873ms step_avg:97.91ms +step:1685/1695 train_time:164971ms step_avg:97.91ms +step:1686/1695 train_time:165069ms step_avg:97.91ms +step:1687/1695 train_time:165167ms step_avg:97.91ms +step:1688/1695 train_time:165265ms step_avg:97.91ms +step:1689/1695 train_time:165363ms step_avg:97.91ms +step:1690/1695 train_time:165461ms step_avg:97.91ms +step:1691/1695 train_time:165559ms step_avg:97.91ms +step:1692/1695 train_time:165656ms step_avg:97.91ms +step:1693/1695 train_time:165754ms step_avg:97.91ms +step:1694/1695 train_time:165851ms step_avg:97.91ms +step:1695/1695 train_time:165950ms step_avg:97.91ms +step:1695/1695 val_loss:3.2791 train_time:166045ms step_avg:97.96ms +peak memory allocated: 34073 MiB reserved: 49476 MiB diff --git a/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt b/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt new file mode 100644 index 000000000..9652d6c2d --- /dev/null +++ b/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
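+    # Worked example (illustrative, not from the run): with M = 1024 and
+    # BLOCK_SIZE_M = BLOCK_SIZE_N = 128, num_pid_m = num_pid_n = 8, so each matrix in
+    # the batch owns 64 blocks. pid = 70 then selects batch_idx = 70 // 64 = 1 and,
+    # pre-swizzle, (pid_m, pid_n) = (0, 6) from pid % 64 = 6; tl.swizzle2d only
+    # regroups blocks within the 8x8 grid for L2 reuse, it does not change the set.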
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
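+        # Hedged reference sketch (not executed by this run): stripped of the
+        # reduce-scatter/all-gather plumbing and the lr_mul/wd_mul multipliers, the
+        # per-parameter Muon update below is equivalent to:
+        def _reference_muon_update(p: Tensor, momentum_buffer: Tensor, lr: float, weight_decay: float, momentum: float):
+            grad = p.grad
+            p.mul_(1 - lr * weight_decay)                   # decoupled weight decay
+            momentum_buffer.lerp_(grad, 1 - momentum)       # running average of grads
+            update = grad.lerp(momentum_buffer, momentum)   # Nesterov-style blend
+            v = newton_schulz_triton(update)                # orthogonalize the update
+            scale = max(1, p.size(-2) / p.size(-1)) ** 0.5  # aspect-ratio LR scaling
+            p.add_(v, alpha=-lr * scale)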
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
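+                    # (The update below is standard bias-corrected Adam on this rank's
+                    # slice: step size lr * sqrt(1 - beta2^t) / (1 - beta1^t) applied to
+                    # exp_avg / (sqrt(exp_avg_sq) + eps), with decoupled weight decay
+                    # folded in as p *= 1 - lr * wd before the moment updates.)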
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + 
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +class EOSBatchFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): + # Precompute EOS positions once per shard + self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 # pointer into eos_idx (start EOS for next step) + self.pos = 0 # logical stream position within this shard + self.world_size = world_size + def seek(self, pos: int): + # Set pointer to the first EOS >= pos + self.i = np.searchsorted(self.eos_idx, pos) + if self.i >= len(self.eos_idx): + raise StopIteration("Seek past last EOS.") + self.pos = pos + def next_batch(self, batch_size_local: int, seq_len: int): + n = len(self.eos_idx) + if self.i >= n: + raise StopIteration("No more EOS in this shard.") + starts = [[] for _ in range(self.world_size)] + idx = self.i + cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 + for r in range(self.world_size): + for _ in range(batch_size_local): + start = cur + 1 + target = start + seq_len # need seq_len tokens before next EOS + j = np.searchsorted(self.eos_idx, target) + if j >= n: + raise StopIteration("Insufficient EOS ahead; hit tail of shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + + while True: + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
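+                # (Invariant kept by align_to_bos: every sampled sequence starts at a
+                # document boundary, i.e. the token right after an EOS, and sequences
+                # never overlap; when no EOS lies far enough ahead in this shard, the
+                # finder raises StopIteration and we reload here and retry on a fresh shard.)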
+ tokens, pos = _load_data_shard(next(file_iter)), 0 + finder = EOSBatchFinder(tokens, world_size=world_size) + continue + + bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] + buf = torch.stack(bufs, dim=0) + _inputs = buf[:, :-1] + _targets = buf[:, 1:] + else: + batch_span = num_tokens_global + start_pos_local = pos + rank * (batch_size_local * seq_len) + end_pos_local = start_pos_local + (batch_size_local * seq_len) + + buf = tokens[start_pos_local: end_pos_local + 1] + + _inputs = buf[:-1].view(batch_size_local, seq_len) + _targets = buf[1:].view(batch_size_local, seq_len) + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) + ) + + pos += batch_span + + if new_params is not None: + # makes it possible for generator to receive new (batch_size, seq_len) via .send() + new_batch_size, new_seq_len = new_params + assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" + batch_size = new_batch_size + seq_len = new_seq_len + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len: int = 1024 * 2 + train_batch_size: int = 24 * 8 + val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. + # optimization + num_iterations: int = 1695 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + bandwidth: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc.
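+# Usage note (illustrative only; this run keeps batch_size and seq_len fixed):
+# distributed_data_generator above also accepts new (batch_size, seq_len) mid-stream
+# via generator .send(), e.g.
+#   gen = distributed_data_generator(args.train_files, 192, 2048)
+#   inputs, targets = next(gen)               # batches at seq_len 2048
+#   inputs, targets = gen.send((192, 4096))   # later batches at seq_len 4096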
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_seq_len, args.val_seq_len) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr_and_ws(step: int): + x = step / (1 + args.num_iterations) # progress in training + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + ws_idx = int(len(args.ws_schedule) * x) + return lr, args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 60 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +for step in range(warmup_steps): + inputs, targets = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, ws, ws // 2).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + 
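+# Schedule worked example (illustrative): with num_iterations=1695, cooldown_frac=0.45
+# and ws_schedule=(3, 7, 11), get_lr_and_ws yields: step 0 -> lr mult 1.0, ws 3;
+# step 1000 (x ~ 0.59, past the cooldown start at x = 0.55) -> lr mult ~ 0.92, ws 7;
+# step 1600 (x ~ 0.94) -> lr mult ~ 0.21, ws 11. The long attention window is then
+# ws * args.bandwidth = ws * 128 tokens and the short window uses ws // 2.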
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + lr, ws = get_lr_and_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % (world_size * args.val_seq_len) == 0 + val_steps = args.val_tokens // (world_size * args.val_seq_len) + val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, ws, ws // 2) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 03:43:24 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 32C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 30C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:511ms step_avg:510.59ms +step:2/1695 train_time:534ms step_avg:266.84ms +step:3/1695 train_time:604ms step_avg:201.32ms +step:4/1695 train_time:696ms step_avg:174.01ms +step:5/1695 train_time:789ms step_avg:157.81ms +step:6/1695 train_time:882ms step_avg:147.03ms +step:7/1695 train_time:975ms step_avg:139.35ms +step:8/1695 train_time:1069ms step_avg:133.63ms +step:9/1695 train_time:1163ms step_avg:129.21ms +step:10/1695 train_time:1256ms step_avg:125.59ms +step:11/1695 train_time:1349ms step_avg:122.66ms +step:12/1695 train_time:1447ms step_avg:120.59ms +step:13/1695 train_time:1547ms step_avg:119.02ms +step:14/1695 train_time:1644ms step_avg:117.41ms +step:15/1695 train_time:1738ms step_avg:115.86ms +step:16/1695 train_time:1831ms step_avg:114.46ms +step:17/1695 train_time:1925ms step_avg:113.25ms +step:18/1695 train_time:2019ms step_avg:112.18ms +step:19/1695 train_time:2113ms step_avg:111.19ms +step:20/1695 train_time:2207ms step_avg:110.34ms +step:21/1695 train_time:2300ms step_avg:109.54ms +step:22/1695 train_time:2395ms step_avg:108.87ms 
+step:23/1695 train_time:2491ms step_avg:108.31ms +step:24/1695 train_time:2587ms step_avg:107.80ms +step:25/1695 train_time:2683ms step_avg:107.33ms +step:26/1695 train_time:2778ms step_avg:106.85ms +step:27/1695 train_time:2872ms step_avg:106.37ms +step:28/1695 train_time:2967ms step_avg:105.96ms +step:29/1695 train_time:3061ms step_avg:105.56ms +step:30/1695 train_time:3155ms step_avg:105.16ms +step:31/1695 train_time:3249ms step_avg:104.81ms +step:32/1695 train_time:3344ms step_avg:104.50ms +step:33/1695 train_time:3438ms step_avg:104.19ms +step:34/1695 train_time:3533ms step_avg:103.92ms +step:35/1695 train_time:3629ms step_avg:103.69ms +step:36/1695 train_time:3726ms step_avg:103.50ms +step:37/1695 train_time:3822ms step_avg:103.29ms +step:38/1695 train_time:3916ms step_avg:103.06ms +step:39/1695 train_time:4010ms step_avg:102.83ms +step:40/1695 train_time:4105ms step_avg:102.62ms +step:41/1695 train_time:4199ms step_avg:102.42ms +step:42/1695 train_time:4293ms step_avg:102.21ms +step:43/1695 train_time:4388ms step_avg:102.06ms +step:44/1695 train_time:4484ms step_avg:101.90ms +step:45/1695 train_time:4579ms step_avg:101.75ms +step:46/1695 train_time:4674ms step_avg:101.61ms +step:47/1695 train_time:4770ms step_avg:101.48ms +step:48/1695 train_time:4865ms step_avg:101.36ms +step:49/1695 train_time:4960ms step_avg:101.22ms +step:50/1695 train_time:5054ms step_avg:101.07ms +step:51/1695 train_time:5149ms step_avg:100.96ms +step:52/1695 train_time:5243ms step_avg:100.83ms +step:53/1695 train_time:5338ms step_avg:100.71ms +step:54/1695 train_time:5432ms step_avg:100.59ms +step:55/1695 train_time:5528ms step_avg:100.51ms +step:56/1695 train_time:5623ms step_avg:100.41ms +step:57/1695 train_time:5718ms step_avg:100.31ms +step:58/1695 train_time:5813ms step_avg:100.22ms +step:59/1695 train_time:5908ms step_avg:100.14ms +step:60/1695 train_time:6003ms step_avg:100.05ms +step:61/1695 train_time:6097ms step_avg:99.95ms +step:62/1695 train_time:6191ms step_avg:99.86ms +step:63/1695 train_time:6286ms step_avg:99.78ms +step:64/1695 train_time:6381ms step_avg:99.70ms +step:65/1695 train_time:6475ms step_avg:99.62ms +step:66/1695 train_time:6570ms step_avg:99.55ms +step:67/1695 train_time:6665ms step_avg:99.48ms +step:68/1695 train_time:6760ms step_avg:99.40ms +step:69/1695 train_time:6854ms step_avg:99.33ms +step:70/1695 train_time:6949ms step_avg:99.26ms +step:71/1695 train_time:7043ms step_avg:99.20ms +step:72/1695 train_time:7138ms step_avg:99.14ms +step:73/1695 train_time:7232ms step_avg:99.07ms +step:74/1695 train_time:7327ms step_avg:99.01ms +step:75/1695 train_time:7422ms step_avg:98.96ms +step:76/1695 train_time:7516ms step_avg:98.89ms +step:77/1695 train_time:7610ms step_avg:98.84ms +step:78/1695 train_time:7706ms step_avg:98.79ms +step:79/1695 train_time:7801ms step_avg:98.75ms +step:80/1695 train_time:7895ms step_avg:98.69ms +step:81/1695 train_time:7990ms step_avg:98.64ms +step:82/1695 train_time:8085ms step_avg:98.59ms +step:83/1695 train_time:8179ms step_avg:98.54ms +step:84/1695 train_time:8273ms step_avg:98.49ms +step:85/1695 train_time:8369ms step_avg:98.46ms +step:86/1695 train_time:8463ms step_avg:98.41ms +step:87/1695 train_time:8557ms step_avg:98.36ms +step:88/1695 train_time:8652ms step_avg:98.31ms +step:89/1695 train_time:8747ms step_avg:98.28ms +step:90/1695 train_time:8842ms step_avg:98.24ms +step:91/1695 train_time:8936ms step_avg:98.20ms +step:92/1695 train_time:9030ms step_avg:98.16ms +step:93/1695 train_time:9126ms step_avg:98.13ms +step:94/1695 train_time:9221ms 
step_avg:98.10ms +step:95/1695 train_time:9315ms step_avg:98.05ms +step:96/1695 train_time:9409ms step_avg:98.01ms +step:97/1695 train_time:9504ms step_avg:97.98ms +step:98/1695 train_time:9597ms step_avg:97.93ms +step:99/1695 train_time:9691ms step_avg:97.89ms +step:100/1695 train_time:9786ms step_avg:97.86ms +step:101/1695 train_time:9881ms step_avg:97.84ms +step:102/1695 train_time:9975ms step_avg:97.79ms +step:103/1695 train_time:10069ms step_avg:97.76ms +step:104/1695 train_time:10164ms step_avg:97.73ms +step:105/1695 train_time:10259ms step_avg:97.70ms +step:106/1695 train_time:10352ms step_avg:97.66ms +step:107/1695 train_time:10448ms step_avg:97.64ms +step:108/1695 train_time:10543ms step_avg:97.62ms +step:109/1695 train_time:10636ms step_avg:97.58ms +step:110/1695 train_time:10730ms step_avg:97.54ms +step:111/1695 train_time:10826ms step_avg:97.53ms +step:112/1695 train_time:10922ms step_avg:97.51ms +step:113/1695 train_time:11016ms step_avg:97.48ms +step:114/1695 train_time:11110ms step_avg:97.45ms +step:115/1695 train_time:11205ms step_avg:97.44ms +step:116/1695 train_time:11300ms step_avg:97.41ms +step:117/1695 train_time:11393ms step_avg:97.37ms +step:118/1695 train_time:11487ms step_avg:97.35ms +step:119/1695 train_time:11582ms step_avg:97.33ms +step:120/1695 train_time:11676ms step_avg:97.30ms +step:121/1695 train_time:11770ms step_avg:97.27ms +step:122/1695 train_time:11864ms step_avg:97.25ms +step:123/1695 train_time:11959ms step_avg:97.22ms +step:124/1695 train_time:12052ms step_avg:97.20ms +step:125/1695 train_time:12148ms step_avg:97.18ms +step:125/1695 val_loss:4.3128 train_time:12241ms step_avg:97.92ms +step:126/1695 train_time:12267ms step_avg:97.36ms +step:127/1695 train_time:12345ms step_avg:97.20ms +step:128/1695 train_time:12446ms step_avg:97.23ms +step:129/1695 train_time:12540ms step_avg:97.21ms +step:130/1695 train_time:12635ms step_avg:97.19ms +step:131/1695 train_time:12728ms step_avg:97.16ms +step:132/1695 train_time:12821ms step_avg:97.13ms +step:133/1695 train_time:12915ms step_avg:97.11ms +step:134/1695 train_time:13008ms step_avg:97.08ms +step:135/1695 train_time:13102ms step_avg:97.05ms +step:136/1695 train_time:13196ms step_avg:97.03ms +step:137/1695 train_time:13291ms step_avg:97.01ms +step:138/1695 train_time:13388ms step_avg:97.01ms +step:139/1695 train_time:13482ms step_avg:96.99ms +step:140/1695 train_time:13577ms step_avg:96.98ms +step:141/1695 train_time:13672ms step_avg:96.96ms +step:142/1695 train_time:13766ms step_avg:96.94ms +step:143/1695 train_time:13859ms step_avg:96.92ms +step:144/1695 train_time:13954ms step_avg:96.90ms +step:145/1695 train_time:14047ms step_avg:96.88ms +step:146/1695 train_time:14140ms step_avg:96.85ms +step:147/1695 train_time:14235ms step_avg:96.84ms +step:148/1695 train_time:14331ms step_avg:96.83ms +step:149/1695 train_time:14426ms step_avg:96.82ms +step:150/1695 train_time:14520ms step_avg:96.80ms +step:151/1695 train_time:14614ms step_avg:96.78ms +step:152/1695 train_time:14709ms step_avg:96.77ms +step:153/1695 train_time:14802ms step_avg:96.75ms +step:154/1695 train_time:14896ms step_avg:96.73ms +step:155/1695 train_time:14991ms step_avg:96.72ms +step:156/1695 train_time:15084ms step_avg:96.69ms +step:157/1695 train_time:15178ms step_avg:96.68ms +step:158/1695 train_time:15272ms step_avg:96.66ms +step:159/1695 train_time:15366ms step_avg:96.64ms +step:160/1695 train_time:15461ms step_avg:96.63ms +step:161/1695 train_time:15556ms step_avg:96.62ms +step:162/1695 train_time:15651ms step_avg:96.61ms +step:163/1695 
train_time:15745ms step_avg:96.59ms
+step:164/1695 train_time:15839ms step_avg:96.58ms
+step:165/1695 train_time:15934ms step_avg:96.57ms
+step:166/1695 train_time:16029ms step_avg:96.56ms
+step:167/1695 train_time:16123ms step_avg:96.55ms
+step:168/1695 train_time:16217ms step_avg:96.53ms
+step:169/1695 train_time:16312ms step_avg:96.52ms
+step:170/1695 train_time:16406ms step_avg:96.51ms
+step:171/1695 train_time:16501ms step_avg:96.50ms
+step:172/1695 train_time:16596ms step_avg:96.49ms
+step:173/1695 train_time:16939ms step_avg:97.91ms
+step:174/1695 train_time:17041ms step_avg:97.94ms
+step:175/1695 train_time:17135ms step_avg:97.91ms
+step:176/1695 train_time:17228ms step_avg:97.88ms
+step:177/1695 train_time:17321ms step_avg:97.86ms
+step:178/1695 train_time:17414ms step_avg:97.83ms
+step:179/1695 train_time:17509ms step_avg:97.82ms
+step:180/1695 train_time:17602ms step_avg:97.79ms
+step:181/1695 train_time:17695ms step_avg:97.77ms
+step:182/1695 train_time:17788ms step_avg:97.74ms
+step:183/1695 train_time:17884ms step_avg:97.73ms
+step:184/1695 train_time:17983ms step_avg:97.73ms
+step:185/1695 train_time:18078ms step_avg:97.72ms
+step:186/1695 train_time:18174ms step_avg:97.71ms
+step:187/1695 train_time:18268ms step_avg:97.69ms
+step:188/1695 train_time:18361ms step_avg:97.67ms
+step:189/1695 train_time:18455ms step_avg:97.65ms
+step:190/1695 train_time:18548ms step_avg:97.62ms
+step:191/1695 train_time:18641ms step_avg:97.60ms
+step:192/1695 train_time:18735ms step_avg:97.58ms
+step:193/1695 train_time:18830ms step_avg:97.56ms
+step:194/1695 train_time:18924ms step_avg:97.55ms
+step:195/1695 train_time:19018ms step_avg:97.53ms
+step:196/1695 train_time:19114ms step_avg:97.52ms
+step:197/1695 train_time:19209ms step_avg:97.51ms
+step:198/1695 train_time:19302ms step_avg:97.49ms
+step:199/1695 train_time:19397ms step_avg:97.47ms
+step:200/1695 train_time:19490ms step_avg:97.45ms
+step:201/1695 train_time:19583ms step_avg:97.43ms
+step:202/1695 train_time:19677ms step_avg:97.41ms
+step:203/1695 train_time:19771ms step_avg:97.40ms
+step:204/1695 train_time:19866ms step_avg:97.38ms
+step:205/1695 train_time:19960ms step_avg:97.37ms
+step:206/1695 train_time:20056ms step_avg:97.36ms
+step:207/1695 train_time:20151ms step_avg:97.35ms
+step:208/1695 train_time:20244ms step_avg:97.33ms
+step:209/1695 train_time:20338ms step_avg:97.31ms
+step:210/1695 train_time:20433ms step_avg:97.30ms
+step:211/1695 train_time:20527ms step_avg:97.29ms
+step:212/1695 train_time:20621ms step_avg:97.27ms
+step:213/1695 train_time:20715ms step_avg:97.26ms
+step:214/1695 train_time:20810ms step_avg:97.24ms
+step:215/1695 train_time:20904ms step_avg:97.23ms
+step:216/1695 train_time:20998ms step_avg:97.21ms
+step:217/1695 train_time:21094ms step_avg:97.21ms
+step:218/1695 train_time:21188ms step_avg:97.19ms
+step:219/1695 train_time:21281ms step_avg:97.17ms
+step:220/1695 train_time:21376ms step_avg:97.16ms
+step:221/1695 train_time:21470ms step_avg:97.15ms
+step:222/1695 train_time:21563ms step_avg:97.13ms
+step:223/1695 train_time:21657ms step_avg:97.12ms
+step:224/1695 train_time:21751ms step_avg:97.10ms
+step:225/1695 train_time:21845ms step_avg:97.09ms
+step:226/1695 train_time:21939ms step_avg:97.07ms
+step:227/1695 train_time:22034ms step_avg:97.06ms
+step:228/1695 train_time:22128ms step_avg:97.05ms
+step:229/1695 train_time:22221ms step_avg:97.03ms
+step:230/1695 train_time:22315ms step_avg:97.02ms
+step:231/1695 train_time:22410ms step_avg:97.01ms
+step:232/1695 train_time:22505ms step_avg:97.00ms
+step:233/1695 train_time:22599ms step_avg:96.99ms
+step:234/1695 train_time:22693ms step_avg:96.98ms
+step:235/1695 train_time:22787ms step_avg:96.97ms
+step:236/1695 train_time:22881ms step_avg:96.95ms
+step:237/1695 train_time:22975ms step_avg:96.94ms
+step:238/1695 train_time:23071ms step_avg:96.94ms
+step:239/1695 train_time:23166ms step_avg:96.93ms
+step:240/1695 train_time:23259ms step_avg:96.91ms
+step:241/1695 train_time:23354ms step_avg:96.90ms
+step:242/1695 train_time:23448ms step_avg:96.89ms
+step:243/1695 train_time:23541ms step_avg:96.88ms
+step:244/1695 train_time:23637ms step_avg:96.87ms
+step:245/1695 train_time:23731ms step_avg:96.86ms
+step:246/1695 train_time:23825ms step_avg:96.85ms
+step:247/1695 train_time:23919ms step_avg:96.84ms
+step:248/1695 train_time:24014ms step_avg:96.83ms
+step:249/1695 train_time:24109ms step_avg:96.83ms
+step:250/1695 train_time:24204ms step_avg:96.81ms
+step:250/1695 val_loss:3.9758 train_time:24295ms step_avg:97.18ms
+step:251/1695 train_time:24320ms step_avg:96.89ms
+step:252/1695 train_time:24399ms step_avg:96.82ms
+step:253/1695 train_time:24500ms step_avg:96.84ms
+step:254/1695 train_time:24595ms step_avg:96.83ms
+step:255/1695 train_time:24689ms step_avg:96.82ms
+step:256/1695 train_time:24782ms step_avg:96.81ms
+step:257/1695 train_time:24876ms step_avg:96.79ms
+step:258/1695 train_time:24969ms step_avg:96.78ms
+step:259/1695 train_time:25062ms step_avg:96.76ms
+step:260/1695 train_time:25155ms step_avg:96.75ms
+step:261/1695 train_time:25248ms step_avg:96.74ms
+step:262/1695 train_time:25344ms step_avg:96.73ms
+step:263/1695 train_time:25441ms step_avg:96.73ms
+step:264/1695 train_time:25538ms step_avg:96.73ms
+step:265/1695 train_time:25632ms step_avg:96.73ms
+step:266/1695 train_time:25726ms step_avg:96.71ms
+step:267/1695 train_time:25820ms step_avg:96.70ms
+step:268/1695 train_time:25914ms step_avg:96.69ms
+step:269/1695 train_time:26007ms step_avg:96.68ms
+step:270/1695 train_time:26100ms step_avg:96.67ms
+step:271/1695 train_time:26193ms step_avg:96.65ms
+step:272/1695 train_time:26287ms step_avg:96.64ms
+step:273/1695 train_time:26382ms step_avg:96.64ms
+step:274/1695 train_time:26478ms step_avg:96.63ms
+step:275/1695 train_time:26574ms step_avg:96.63ms
+step:276/1695 train_time:26668ms step_avg:96.62ms
+step:277/1695 train_time:26762ms step_avg:96.61ms
+step:278/1695 train_time:26856ms step_avg:96.61ms
+step:279/1695 train_time:26950ms step_avg:96.60ms
+step:280/1695 train_time:27044ms step_avg:96.59ms
+step:281/1695 train_time:27137ms step_avg:96.57ms
+step:282/1695 train_time:27231ms step_avg:96.56ms
+step:283/1695 train_time:27324ms step_avg:96.55ms
+step:284/1695 train_time:27419ms step_avg:96.54ms
+step:285/1695 train_time:27513ms step_avg:96.54ms
+step:286/1695 train_time:27608ms step_avg:96.53ms
+step:287/1695 train_time:27702ms step_avg:96.52ms
+step:288/1695 train_time:27796ms step_avg:96.51ms
+step:289/1695 train_time:27891ms step_avg:96.51ms
+step:290/1695 train_time:27984ms step_avg:96.50ms
+step:291/1695 train_time:28078ms step_avg:96.49ms
+step:292/1695 train_time:28171ms step_avg:96.48ms
+step:293/1695 train_time:28265ms step_avg:96.47ms
+step:294/1695 train_time:28359ms step_avg:96.46ms
+step:295/1695 train_time:28453ms step_avg:96.45ms
+step:296/1695 train_time:28547ms step_avg:96.44ms
+step:297/1695 train_time:28641ms step_avg:96.43ms
+step:298/1695 train_time:28735ms step_avg:96.43ms
+step:299/1695 train_time:28831ms step_avg:96.42ms
+step:300/1695 train_time:28925ms step_avg:96.42ms
+step:301/1695 train_time:29019ms step_avg:96.41ms
+step:302/1695 train_time:29114ms step_avg:96.40ms
+step:303/1695 train_time:29208ms step_avg:96.39ms
+step:304/1695 train_time:29301ms step_avg:96.39ms
+step:305/1695 train_time:29396ms step_avg:96.38ms
+step:306/1695 train_time:29490ms step_avg:96.37ms
+step:307/1695 train_time:29584ms step_avg:96.37ms
+step:308/1695 train_time:29678ms step_avg:96.36ms
+step:309/1695 train_time:29772ms step_avg:96.35ms
+step:310/1695 train_time:29866ms step_avg:96.34ms
+step:311/1695 train_time:29960ms step_avg:96.33ms
+step:312/1695 train_time:30055ms step_avg:96.33ms
+step:313/1695 train_time:30149ms step_avg:96.32ms
+step:314/1695 train_time:30242ms step_avg:96.31ms
+step:315/1695 train_time:30337ms step_avg:96.31ms
+step:316/1695 train_time:30431ms step_avg:96.30ms
+step:317/1695 train_time:30525ms step_avg:96.29ms
+step:318/1695 train_time:30619ms step_avg:96.29ms
+step:319/1695 train_time:30713ms step_avg:96.28ms
+step:320/1695 train_time:30807ms step_avg:96.27ms
+step:321/1695 train_time:30900ms step_avg:96.26ms
+step:322/1695 train_time:30995ms step_avg:96.26ms
+step:323/1695 train_time:31089ms step_avg:96.25ms
+step:324/1695 train_time:31182ms step_avg:96.24ms
+step:325/1695 train_time:31277ms step_avg:96.24ms
+step:326/1695 train_time:31371ms step_avg:96.23ms
+step:327/1695 train_time:31465ms step_avg:96.22ms
+step:328/1695 train_time:31559ms step_avg:96.22ms
+step:329/1695 train_time:31654ms step_avg:96.21ms
+step:330/1695 train_time:31750ms step_avg:96.21ms
+step:331/1695 train_time:31843ms step_avg:96.20ms
+step:332/1695 train_time:31938ms step_avg:96.20ms
+step:333/1695 train_time:32032ms step_avg:96.19ms
+step:334/1695 train_time:32126ms step_avg:96.19ms
+step:335/1695 train_time:32220ms step_avg:96.18ms
+step:336/1695 train_time:32315ms step_avg:96.18ms
+step:337/1695 train_time:32410ms step_avg:96.17ms
+step:338/1695 train_time:32503ms step_avg:96.16ms
+step:339/1695 train_time:32597ms step_avg:96.16ms
+step:340/1695 train_time:32692ms step_avg:96.15ms
+step:341/1695 train_time:32785ms step_avg:96.14ms
+step:342/1695 train_time:32879ms step_avg:96.14ms
+step:343/1695 train_time:32974ms step_avg:96.13ms
+step:344/1695 train_time:33069ms step_avg:96.13ms
+step:345/1695 train_time:33399ms step_avg:96.81ms
+step:346/1695 train_time:33523ms step_avg:96.89ms
+step:347/1695 train_time:33615ms step_avg:96.87ms
+step:348/1695 train_time:33709ms step_avg:96.87ms
+step:349/1695 train_time:33802ms step_avg:96.85ms
+step:350/1695 train_time:33895ms step_avg:96.84ms
+step:351/1695 train_time:33988ms step_avg:96.83ms
+step:352/1695 train_time:34081ms step_avg:96.82ms
+step:353/1695 train_time:34174ms step_avg:96.81ms
+step:354/1695 train_time:34267ms step_avg:96.80ms
+step:355/1695 train_time:34364ms step_avg:96.80ms
+step:356/1695 train_time:34462ms step_avg:96.80ms
+step:357/1695 train_time:34558ms step_avg:96.80ms
+step:358/1695 train_time:34653ms step_avg:96.80ms
+step:359/1695 train_time:34747ms step_avg:96.79ms
+step:360/1695 train_time:34840ms step_avg:96.78ms
+step:361/1695 train_time:34934ms step_avg:96.77ms
+step:362/1695 train_time:35026ms step_avg:96.76ms
+step:363/1695 train_time:35119ms step_avg:96.75ms
+step:364/1695 train_time:35213ms step_avg:96.74ms
+step:365/1695 train_time:35306ms step_avg:96.73ms
+step:366/1695 train_time:35402ms step_avg:96.73ms
+step:367/1695 train_time:35497ms step_avg:96.72ms
+step:368/1695 train_time:35591ms step_avg:96.72ms
+step:369/1695 train_time:35686ms step_avg:96.71ms
+step:370/1695 train_time:35780ms step_avg:96.70ms
+step:371/1695 train_time:35874ms step_avg:96.69ms
+step:372/1695 train_time:35967ms step_avg:96.68ms
+step:373/1695 train_time:36060ms step_avg:96.68ms
+step:374/1695 train_time:36154ms step_avg:96.67ms
+step:375/1695 train_time:36248ms step_avg:96.66ms
+step:375/1695 val_loss:3.8203 train_time:36339ms step_avg:96.90ms
+step:376/1695 train_time:36364ms step_avg:96.71ms
+step:377/1695 train_time:36442ms step_avg:96.66ms
+step:378/1695 train_time:36539ms step_avg:96.66ms
+step:379/1695 train_time:36633ms step_avg:96.66ms
+step:380/1695 train_time:36726ms step_avg:96.65ms
+step:381/1695 train_time:36820ms step_avg:96.64ms
+step:382/1695 train_time:36912ms step_avg:96.63ms
+step:383/1695 train_time:37005ms step_avg:96.62ms
+step:384/1695 train_time:37098ms step_avg:96.61ms
+step:385/1695 train_time:37190ms step_avg:96.60ms
+step:386/1695 train_time:37284ms step_avg:96.59ms
+step:387/1695 train_time:37379ms step_avg:96.59ms
+step:388/1695 train_time:37475ms step_avg:96.59ms
+step:389/1695 train_time:37570ms step_avg:96.58ms
+step:390/1695 train_time:37665ms step_avg:96.58ms
+step:391/1695 train_time:37759ms step_avg:96.57ms
+step:392/1695 train_time:37852ms step_avg:96.56ms
+step:393/1695 train_time:37946ms step_avg:96.55ms
+step:394/1695 train_time:38039ms step_avg:96.55ms
+step:395/1695 train_time:38131ms step_avg:96.53ms
+step:396/1695 train_time:38225ms step_avg:96.53ms
+step:397/1695 train_time:38319ms step_avg:96.52ms
+step:398/1695 train_time:38413ms step_avg:96.52ms
+step:399/1695 train_time:38508ms step_avg:96.51ms
+step:400/1695 train_time:38604ms step_avg:96.51ms
+step:401/1695 train_time:38699ms step_avg:96.51ms
+step:402/1695 train_time:38792ms step_avg:96.50ms
+step:403/1695 train_time:38886ms step_avg:96.49ms
+step:404/1695 train_time:38980ms step_avg:96.49ms
+step:405/1695 train_time:39073ms step_avg:96.48ms
+step:406/1695 train_time:39166ms step_avg:96.47ms
+step:407/1695 train_time:39260ms step_avg:96.46ms
+step:408/1695 train_time:39354ms step_avg:96.46ms
+step:409/1695 train_time:39448ms step_avg:96.45ms
+step:410/1695 train_time:39543ms step_avg:96.45ms
+step:411/1695 train_time:39638ms step_avg:96.44ms
+step:412/1695 train_time:39732ms step_avg:96.44ms
+step:413/1695 train_time:39825ms step_avg:96.43ms
+step:414/1695 train_time:39919ms step_avg:96.42ms
+step:415/1695 train_time:40012ms step_avg:96.41ms
+step:416/1695 train_time:40105ms step_avg:96.41ms
+step:417/1695 train_time:40199ms step_avg:96.40ms
+step:418/1695 train_time:40292ms step_avg:96.39ms
+step:419/1695 train_time:40386ms step_avg:96.39ms
+step:420/1695 train_time:40481ms step_avg:96.38ms
+step:421/1695 train_time:40575ms step_avg:96.38ms
+step:422/1695 train_time:40669ms step_avg:96.37ms
+step:423/1695 train_time:40764ms step_avg:96.37ms
+step:424/1695 train_time:40858ms step_avg:96.36ms
+step:425/1695 train_time:40952ms step_avg:96.36ms
+step:426/1695 train_time:41046ms step_avg:96.35ms
+step:427/1695 train_time:41140ms step_avg:96.35ms
+step:428/1695 train_time:41233ms step_avg:96.34ms
+step:429/1695 train_time:41327ms step_avg:96.33ms
+step:430/1695 train_time:41420ms step_avg:96.33ms
+step:431/1695 train_time:41514ms step_avg:96.32ms
+step:432/1695 train_time:41608ms step_avg:96.32ms
+step:433/1695 train_time:41702ms step_avg:96.31ms
+step:434/1695 train_time:41797ms step_avg:96.31ms
+step:435/1695 train_time:41890ms step_avg:96.30ms
+step:436/1695 train_time:41985ms step_avg:96.30ms
+step:437/1695 train_time:42079ms step_avg:96.29ms
+step:438/1695 train_time:42173ms step_avg:96.29ms
+step:439/1695 train_time:42267ms step_avg:96.28ms
+step:440/1695 train_time:42361ms step_avg:96.28ms
+step:441/1695 train_time:42455ms step_avg:96.27ms
+step:442/1695 train_time:42549ms step_avg:96.26ms
+step:443/1695 train_time:42643ms step_avg:96.26ms
+step:444/1695 train_time:42737ms step_avg:96.25ms
+step:445/1695 train_time:42830ms step_avg:96.25ms
+step:446/1695 train_time:42924ms step_avg:96.24ms
+step:447/1695 train_time:43019ms step_avg:96.24ms
+step:448/1695 train_time:43114ms step_avg:96.24ms
+step:449/1695 train_time:43208ms step_avg:96.23ms
+step:450/1695 train_time:43302ms step_avg:96.23ms
+step:451/1695 train_time:43397ms step_avg:96.22ms
+step:452/1695 train_time:43490ms step_avg:96.22ms
+step:453/1695 train_time:43584ms step_avg:96.21ms
+step:454/1695 train_time:43679ms step_avg:96.21ms
+step:455/1695 train_time:43773ms step_avg:96.20ms
+step:456/1695 train_time:43866ms step_avg:96.20ms
+step:457/1695 train_time:43960ms step_avg:96.19ms
+step:458/1695 train_time:44055ms step_avg:96.19ms
+step:459/1695 train_time:44149ms step_avg:96.18ms
+step:460/1695 train_time:44243ms step_avg:96.18ms
+step:461/1695 train_time:44338ms step_avg:96.18ms
+step:462/1695 train_time:44432ms step_avg:96.17ms
+step:463/1695 train_time:44526ms step_avg:96.17ms
+step:464/1695 train_time:44621ms step_avg:96.16ms
+step:465/1695 train_time:44714ms step_avg:96.16ms
+step:466/1695 train_time:44808ms step_avg:96.15ms
+step:467/1695 train_time:44902ms step_avg:96.15ms
+step:468/1695 train_time:44996ms step_avg:96.14ms
+step:469/1695 train_time:45089ms step_avg:96.14ms
+step:470/1695 train_time:45184ms step_avg:96.14ms
+step:471/1695 train_time:45278ms step_avg:96.13ms
+step:472/1695 train_time:45372ms step_avg:96.13ms
+step:473/1695 train_time:45466ms step_avg:96.12ms
+step:474/1695 train_time:45560ms step_avg:96.12ms
+step:475/1695 train_time:45653ms step_avg:96.11ms
+step:476/1695 train_time:45748ms step_avg:96.11ms
+step:477/1695 train_time:45842ms step_avg:96.10ms
+step:478/1695 train_time:45935ms step_avg:96.10ms
+step:479/1695 train_time:46029ms step_avg:96.09ms
+step:480/1695 train_time:46123ms step_avg:96.09ms
+step:481/1695 train_time:46217ms step_avg:96.08ms
+step:482/1695 train_time:46311ms step_avg:96.08ms
+step:483/1695 train_time:46405ms step_avg:96.08ms
+step:484/1695 train_time:46499ms step_avg:96.07ms
+step:485/1695 train_time:46593ms step_avg:96.07ms
+step:486/1695 train_time:46688ms step_avg:96.07ms
+step:487/1695 train_time:46782ms step_avg:96.06ms
+step:488/1695 train_time:46876ms step_avg:96.06ms
+step:489/1695 train_time:46969ms step_avg:96.05ms
+step:490/1695 train_time:47064ms step_avg:96.05ms
+step:491/1695 train_time:47158ms step_avg:96.05ms
+step:492/1695 train_time:47252ms step_avg:96.04ms
+step:493/1695 train_time:47346ms step_avg:96.04ms
+step:494/1695 train_time:47442ms step_avg:96.04ms
+step:495/1695 train_time:47536ms step_avg:96.03ms
+step:496/1695 train_time:47629ms step_avg:96.03ms
+step:497/1695 train_time:47724ms step_avg:96.02ms
+step:498/1695 train_time:47820ms step_avg:96.02ms
+step:499/1695 train_time:47914ms step_avg:96.02ms
+step:500/1695 train_time:48008ms step_avg:96.02ms
+step:500/1695 val_loss:3.7161 train_time:48100ms step_avg:96.20ms
+step:501/1695 train_time:48124ms step_avg:96.06ms
+step:502/1695 train_time:48204ms step_avg:96.02ms
+step:503/1695 train_time:48302ms step_avg:96.03ms
+step:504/1695 train_time:48397ms step_avg:96.03ms
+step:505/1695 train_time:48491ms step_avg:96.02ms
+step:506/1695 train_time:48584ms step_avg:96.02ms
+step:507/1695 train_time:48678ms step_avg:96.01ms
+step:508/1695 train_time:48771ms step_avg:96.01ms
+step:509/1695 train_time:48864ms step_avg:96.00ms
+step:510/1695 train_time:48957ms step_avg:95.99ms
+step:511/1695 train_time:49050ms step_avg:95.99ms
+step:512/1695 train_time:49146ms step_avg:95.99ms
+step:513/1695 train_time:49242ms step_avg:95.99ms
+step:514/1695 train_time:49337ms step_avg:95.99ms
+step:515/1695 train_time:49432ms step_avg:95.98ms
+step:516/1695 train_time:49525ms step_avg:95.98ms
+step:517/1695 train_time:49619ms step_avg:95.97ms
+step:518/1695 train_time:49713ms step_avg:95.97ms
+step:519/1695 train_time:50082ms step_avg:96.50ms
+step:520/1695 train_time:50228ms step_avg:96.59ms
+step:521/1695 train_time:50320ms step_avg:96.58ms
+step:522/1695 train_time:50412ms step_avg:96.58ms
+step:523/1695 train_time:50505ms step_avg:96.57ms
+step:524/1695 train_time:50598ms step_avg:96.56ms
+step:525/1695 train_time:50691ms step_avg:96.55ms
+step:526/1695 train_time:50784ms step_avg:96.55ms
+step:527/1695 train_time:50878ms step_avg:96.54ms
+step:528/1695 train_time:50971ms step_avg:96.54ms
+step:529/1695 train_time:51069ms step_avg:96.54ms
+step:530/1695 train_time:51167ms step_avg:96.54ms
+step:531/1695 train_time:51264ms step_avg:96.54ms
+step:532/1695 train_time:51358ms step_avg:96.54ms
+step:533/1695 train_time:51452ms step_avg:96.53ms
+step:534/1695 train_time:51545ms step_avg:96.53ms
+step:535/1695 train_time:51638ms step_avg:96.52ms
+step:536/1695 train_time:51732ms step_avg:96.51ms
+step:537/1695 train_time:51824ms step_avg:96.51ms
+step:538/1695 train_time:51918ms step_avg:96.50ms
+step:539/1695 train_time:52014ms step_avg:96.50ms
+step:540/1695 train_time:52110ms step_avg:96.50ms
+step:541/1695 train_time:52204ms step_avg:96.50ms
+step:542/1695 train_time:52299ms step_avg:96.49ms
+step:543/1695 train_time:52393ms step_avg:96.49ms
+step:544/1695 train_time:52486ms step_avg:96.48ms
+step:545/1695 train_time:52580ms step_avg:96.48ms
+step:546/1695 train_time:52674ms step_avg:96.47ms
+step:547/1695 train_time:52767ms step_avg:96.47ms
+step:548/1695 train_time:52860ms step_avg:96.46ms
+step:549/1695 train_time:52954ms step_avg:96.46ms
+step:550/1695 train_time:53049ms step_avg:96.45ms
+step:551/1695 train_time:53143ms step_avg:96.45ms
+step:552/1695 train_time:53238ms step_avg:96.45ms
+step:553/1695 train_time:53331ms step_avg:96.44ms
+step:554/1695 train_time:53425ms step_avg:96.43ms
+step:555/1695 train_time:53519ms step_avg:96.43ms
+step:556/1695 train_time:53614ms step_avg:96.43ms
+step:557/1695 train_time:53708ms step_avg:96.42ms
+step:558/1695 train_time:53801ms step_avg:96.42ms
+step:559/1695 train_time:53895ms step_avg:96.41ms
+step:560/1695 train_time:53989ms step_avg:96.41ms
+step:561/1695 train_time:54083ms step_avg:96.40ms
+step:562/1695 train_time:54178ms step_avg:96.40ms
+step:563/1695 train_time:54273ms step_avg:96.40ms
+step:564/1695 train_time:54367ms step_avg:96.40ms
+step:565/1695 train_time:54461ms step_avg:96.39ms
+step:566/1695 train_time:54555ms step_avg:96.39ms
+step:567/1695 train_time:54650ms step_avg:96.38ms
+step:568/1695 train_time:54746ms step_avg:96.38ms
+step:569/1695 train_time:54841ms step_avg:96.38ms
+step:570/1695 train_time:54938ms step_avg:96.38ms
+step:571/1695 train_time:55035ms step_avg:96.38ms
+step:572/1695 train_time:55131ms step_avg:96.38ms
+step:573/1695 train_time:55228ms step_avg:96.38ms
+step:574/1695 train_time:55323ms step_avg:96.38ms
+step:575/1695 train_time:55420ms step_avg:96.38ms
+step:576/1695 train_time:55517ms step_avg:96.38ms
+step:577/1695 train_time:55614ms step_avg:96.38ms
+step:578/1695 train_time:55711ms step_avg:96.39ms
+step:579/1695 train_time:55806ms step_avg:96.38ms
+step:580/1695 train_time:55903ms step_avg:96.38ms
+step:581/1695 train_time:55999ms step_avg:96.38ms
+step:582/1695 train_time:56097ms step_avg:96.39ms
+step:583/1695 train_time:56194ms step_avg:96.39ms
+step:584/1695 train_time:56290ms step_avg:96.39ms
+step:585/1695 train_time:56386ms step_avg:96.39ms
+step:586/1695 train_time:56482ms step_avg:96.39ms
+step:587/1695 train_time:56578ms step_avg:96.39ms
+step:588/1695 train_time:56675ms step_avg:96.39ms
+step:589/1695 train_time:56772ms step_avg:96.39ms
+step:590/1695 train_time:56868ms step_avg:96.39ms
+step:591/1695 train_time:56964ms step_avg:96.39ms
+step:592/1695 train_time:57061ms step_avg:96.39ms
+step:593/1695 train_time:57157ms step_avg:96.39ms
+step:594/1695 train_time:57254ms step_avg:96.39ms
+step:595/1695 train_time:57350ms step_avg:96.39ms
+step:596/1695 train_time:57446ms step_avg:96.39ms
+step:597/1695 train_time:57542ms step_avg:96.39ms
+step:598/1695 train_time:57638ms step_avg:96.38ms
+step:599/1695 train_time:57735ms step_avg:96.39ms
+step:600/1695 train_time:57831ms step_avg:96.39ms
+step:601/1695 train_time:57926ms step_avg:96.38ms
+step:602/1695 train_time:58022ms step_avg:96.38ms
+step:603/1695 train_time:58118ms step_avg:96.38ms
+step:604/1695 train_time:58215ms step_avg:96.38ms
+step:605/1695 train_time:58312ms step_avg:96.38ms
+step:606/1695 train_time:58408ms step_avg:96.38ms
+step:607/1695 train_time:58504ms step_avg:96.38ms
+step:608/1695 train_time:58600ms step_avg:96.38ms
+step:609/1695 train_time:58696ms step_avg:96.38ms
+step:610/1695 train_time:58794ms step_avg:96.38ms
+step:611/1695 train_time:58891ms step_avg:96.38ms
+step:612/1695 train_time:58987ms step_avg:96.38ms
+step:613/1695 train_time:59083ms step_avg:96.38ms
+step:614/1695 train_time:59180ms step_avg:96.38ms
+step:615/1695 train_time:59276ms step_avg:96.38ms
+step:616/1695 train_time:59373ms step_avg:96.39ms
+step:617/1695 train_time:59470ms step_avg:96.39ms
+step:618/1695 train_time:59565ms step_avg:96.38ms
+step:619/1695 train_time:59661ms step_avg:96.38ms
+step:620/1695 train_time:59758ms step_avg:96.38ms
+step:621/1695 train_time:59855ms step_avg:96.39ms
+step:622/1695 train_time:59952ms step_avg:96.39ms
+step:623/1695 train_time:60048ms step_avg:96.39ms
+step:624/1695 train_time:60143ms step_avg:96.38ms
+step:625/1695 train_time:60239ms step_avg:96.38ms
+step:625/1695 val_loss:3.6203 train_time:60334ms step_avg:96.53ms
+step:626/1695 train_time:60358ms step_avg:96.42ms
+step:627/1695 train_time:60442ms step_avg:96.40ms
+step:628/1695 train_time:60540ms step_avg:96.40ms
+step:629/1695 train_time:60637ms step_avg:96.40ms
+step:630/1695 train_time:60732ms step_avg:96.40ms
+step:631/1695 train_time:60827ms step_avg:96.40ms
+step:632/1695 train_time:60921ms step_avg:96.39ms
+step:633/1695 train_time:61017ms step_avg:96.39ms
+step:634/1695 train_time:61112ms step_avg:96.39ms
+step:635/1695 train_time:61208ms step_avg:96.39ms
+step:636/1695 train_time:61305ms step_avg:96.39ms
+step:637/1695 train_time:61405ms step_avg:96.40ms
+step:638/1695 train_time:61502ms step_avg:96.40ms
+step:639/1695 train_time:61599ms step_avg:96.40ms
+step:640/1695 train_time:61696ms step_avg:96.40ms
+step:641/1695 train_time:61793ms step_avg:96.40ms
+step:642/1695 train_time:61887ms step_avg:96.40ms
+step:643/1695 train_time:61983ms step_avg:96.40ms
+step:644/1695 train_time:62078ms step_avg:96.39ms
+step:645/1695 train_time:62175ms step_avg:96.40ms
+step:646/1695 train_time:62272ms step_avg:96.40ms
+step:647/1695 train_time:62368ms step_avg:96.40ms
+step:648/1695 train_time:62465ms step_avg:96.40ms
+step:649/1695 train_time:62562ms step_avg:96.40ms
+step:650/1695 train_time:62660ms step_avg:96.40ms
+step:651/1695 train_time:62756ms step_avg:96.40ms
+step:652/1695 train_time:62852ms step_avg:96.40ms
+step:653/1695 train_time:62947ms step_avg:96.40ms
+step:654/1695 train_time:63041ms step_avg:96.39ms
+step:655/1695 train_time:63139ms step_avg:96.39ms
+step:656/1695 train_time:63237ms step_avg:96.40ms
+step:657/1695 train_time:63335ms step_avg:96.40ms
+step:658/1695 train_time:63431ms step_avg:96.40ms
+step:659/1695 train_time:63527ms step_avg:96.40ms
+step:660/1695 train_time:63623ms step_avg:96.40ms
+step:661/1695 train_time:63719ms step_avg:96.40ms
+step:662/1695 train_time:63815ms step_avg:96.40ms
+step:663/1695 train_time:63911ms step_avg:96.40ms
+step:664/1695 train_time:64007ms step_avg:96.40ms
+step:665/1695 train_time:64103ms step_avg:96.40ms
+step:666/1695 train_time:64199ms step_avg:96.39ms
+step:667/1695 train_time:64295ms step_avg:96.39ms
+step:668/1695 train_time:64391ms step_avg:96.39ms
+step:669/1695 train_time:64487ms step_avg:96.39ms
+step:670/1695 train_time:64583ms step_avg:96.39ms
+step:671/1695 train_time:64679ms step_avg:96.39ms
+step:672/1695 train_time:64775ms step_avg:96.39ms
+step:673/1695 train_time:64871ms step_avg:96.39ms
+step:674/1695 train_time:64967ms step_avg:96.39ms
+step:675/1695 train_time:65062ms step_avg:96.39ms
+step:676/1695 train_time:65158ms step_avg:96.39ms
+step:677/1695 train_time:65255ms step_avg:96.39ms
+step:678/1695 train_time:65351ms step_avg:96.39ms
+step:679/1695 train_time:65447ms step_avg:96.39ms
+step:680/1695 train_time:65542ms step_avg:96.39ms
+step:681/1695 train_time:65639ms step_avg:96.39ms
+step:682/1695 train_time:65735ms step_avg:96.39ms
+step:683/1695 train_time:65832ms step_avg:96.39ms
+step:684/1695 train_time:65929ms step_avg:96.39ms
+step:685/1695 train_time:66025ms step_avg:96.39ms
+step:686/1695 train_time:66120ms step_avg:96.38ms
+step:687/1695 train_time:66216ms step_avg:96.38ms
+step:688/1695 train_time:66313ms step_avg:96.39ms
+step:689/1695 train_time:66410ms step_avg:96.39ms
+step:690/1695 train_time:66505ms step_avg:96.38ms
+step:691/1695 train_time:66945ms step_avg:96.88ms
+step:692/1695 train_time:67026ms step_avg:96.86ms
+step:693/1695 train_time:67121ms step_avg:96.86ms
+step:694/1695 train_time:67216ms step_avg:96.85ms
+step:695/1695 train_time:67311ms step_avg:96.85ms
+step:696/1695 train_time:67406ms step_avg:96.85ms
+step:697/1695 train_time:67501ms step_avg:96.84ms
+step:698/1695 train_time:67597ms step_avg:96.84ms
+step:699/1695 train_time:67692ms step_avg:96.84ms
+step:700/1695 train_time:67787ms step_avg:96.84ms
+step:701/1695 train_time:67886ms step_avg:96.84ms
+step:702/1695 train_time:67985ms step_avg:96.84ms
+step:703/1695 train_time:68082ms step_avg:96.84ms
+step:704/1695 train_time:68178ms step_avg:96.84ms
+step:705/1695 train_time:68274ms step_avg:96.84ms
+step:706/1695 train_time:68370ms step_avg:96.84ms
+step:707/1695 train_time:68465ms step_avg:96.84ms
+step:708/1695 train_time:68561ms step_avg:96.84ms
+step:709/1695 train_time:68656ms step_avg:96.84ms
+step:710/1695 train_time:68752ms step_avg:96.83ms
+step:711/1695 train_time:68850ms step_avg:96.84ms
+step:712/1695 train_time:68946ms step_avg:96.83ms
+step:713/1695 train_time:69043ms step_avg:96.83ms
+step:714/1695 train_time:69139ms step_avg:96.83ms
+step:715/1695 train_time:69235ms step_avg:96.83ms
+step:716/1695 train_time:69331ms step_avg:96.83ms
+step:717/1695 train_time:69426ms step_avg:96.83ms
+step:718/1695 train_time:69522ms step_avg:96.83ms
+step:719/1695 train_time:69618ms step_avg:96.83ms
+step:720/1695 train_time:69714ms step_avg:96.82ms
+step:721/1695 train_time:69811ms step_avg:96.83ms
+step:722/1695 train_time:69908ms step_avg:96.83ms
+step:723/1695 train_time:70004ms step_avg:96.82ms
+step:724/1695 train_time:70101ms step_avg:96.82ms
+step:725/1695 train_time:70198ms step_avg:96.83ms
+step:726/1695 train_time:70296ms step_avg:96.83ms
+step:727/1695 train_time:70392ms step_avg:96.83ms
+step:728/1695 train_time:70487ms step_avg:96.82ms
+step:729/1695 train_time:70583ms step_avg:96.82ms
+step:730/1695 train_time:70679ms step_avg:96.82ms
+step:731/1695 train_time:70776ms step_avg:96.82ms
+step:732/1695 train_time:70874ms step_avg:96.82ms
+step:733/1695 train_time:70970ms step_avg:96.82ms
+step:734/1695 train_time:71066ms step_avg:96.82ms
+step:735/1695 train_time:71162ms step_avg:96.82ms
+step:736/1695 train_time:71258ms step_avg:96.82ms
+step:737/1695 train_time:71355ms step_avg:96.82ms
+step:738/1695 train_time:71451ms step_avg:96.82ms
+step:739/1695 train_time:71546ms step_avg:96.81ms
+step:740/1695 train_time:71641ms step_avg:96.81ms
+step:741/1695 train_time:71739ms step_avg:96.81ms
+step:742/1695 train_time:71837ms step_avg:96.81ms
+step:743/1695 train_time:71934ms step_avg:96.82ms
+step:744/1695 train_time:72030ms step_avg:96.81ms
+step:745/1695 train_time:72125ms step_avg:96.81ms
+step:746/1695 train_time:72221ms step_avg:96.81ms
+step:747/1695 train_time:72318ms step_avg:96.81ms
+step:748/1695 train_time:72414ms step_avg:96.81ms
+step:749/1695 train_time:72510ms step_avg:96.81ms
+step:750/1695 train_time:72605ms step_avg:96.81ms
+step:750/1695 val_loss:3.5663 train_time:72700ms step_avg:96.93ms
+step:751/1695 train_time:72724ms step_avg:96.84ms
+step:752/1695 train_time:72807ms step_avg:96.82ms
+step:753/1695 train_time:72904ms step_avg:96.82ms
+step:754/1695 train_time:73002ms step_avg:96.82ms
+step:755/1695 train_time:73098ms step_avg:96.82ms
+step:756/1695 train_time:73192ms step_avg:96.82ms
+step:757/1695 train_time:73287ms step_avg:96.81ms
+step:758/1695 train_time:73381ms step_avg:96.81ms
+step:759/1695 train_time:73476ms step_avg:96.81ms
+step:760/1695 train_time:73571ms step_avg:96.80ms
+step:761/1695 train_time:73668ms step_avg:96.80ms
+step:762/1695 train_time:73766ms step_avg:96.81ms
+step:763/1695 train_time:73864ms step_avg:96.81ms
+step:764/1695 train_time:73962ms step_avg:96.81ms
+step:765/1695 train_time:74059ms step_avg:96.81ms
+step:766/1695 train_time:74154ms step_avg:96.81ms
+step:767/1695 train_time:74249ms step_avg:96.80ms
+step:768/1695 train_time:74344ms step_avg:96.80ms
+step:769/1695 train_time:74439ms step_avg:96.80ms
+step:770/1695 train_time:74535ms step_avg:96.80ms
+step:771/1695 train_time:74630ms step_avg:96.80ms
+step:772/1695 train_time:74726ms step_avg:96.80ms
+step:773/1695 train_time:74824ms step_avg:96.80ms
+step:774/1695 train_time:74921ms step_avg:96.80ms
+step:775/1695 train_time:75018ms step_avg:96.80ms
+step:776/1695 train_time:75114ms step_avg:96.80ms
+step:777/1695 train_time:75209ms step_avg:96.79ms
+step:778/1695 train_time:75304ms step_avg:96.79ms
+step:779/1695 train_time:75400ms step_avg:96.79ms
+step:780/1695 train_time:75496ms step_avg:96.79ms
+step:781/1695 train_time:75592ms step_avg:96.79ms
+step:782/1695 train_time:75687ms step_avg:96.79ms
+step:783/1695 train_time:75784ms step_avg:96.79ms
+step:784/1695 train_time:75880ms step_avg:96.79ms
+step:785/1695 train_time:75977ms step_avg:96.79ms
+step:786/1695 train_time:76073ms step_avg:96.79ms
+step:787/1695 train_time:76168ms step_avg:96.78ms
+step:788/1695 train_time:76264ms step_avg:96.78ms
+step:789/1695 train_time:76359ms step_avg:96.78ms
+step:790/1695 train_time:76455ms step_avg:96.78ms
+step:791/1695 train_time:76549ms step_avg:96.78ms
+step:792/1695 train_time:76645ms step_avg:96.77ms
+step:793/1695 train_time:76742ms step_avg:96.77ms
+step:794/1695 train_time:76839ms step_avg:96.77ms
+step:795/1695 train_time:76935ms step_avg:96.77ms
+step:796/1695 train_time:77031ms step_avg:96.77ms
+step:797/1695 train_time:77126ms step_avg:96.77ms
+step:798/1695 train_time:77224ms step_avg:96.77ms
+step:799/1695 train_time:77320ms step_avg:96.77ms
+step:800/1695 train_time:77416ms step_avg:96.77ms
+step:801/1695 train_time:77511ms step_avg:96.77ms
+step:802/1695 train_time:77607ms step_avg:96.77ms
+step:803/1695 train_time:77702ms step_avg:96.76ms
+step:804/1695 train_time:77797ms step_avg:96.76ms
+step:805/1695 train_time:77893ms step_avg:96.76ms
+step:806/1695 train_time:77989ms step_avg:96.76ms
+step:807/1695 train_time:78085ms step_avg:96.76ms
+step:808/1695 train_time:78181ms step_avg:96.76ms
+step:809/1695 train_time:78276ms step_avg:96.76ms
+step:810/1695 train_time:78371ms step_avg:96.75ms
+step:811/1695 train_time:78467ms step_avg:96.75ms
+step:812/1695 train_time:78562ms step_avg:96.75ms
+step:813/1695 train_time:78657ms step_avg:96.75ms
+step:814/1695 train_time:78752ms step_avg:96.75ms
+step:815/1695 train_time:78848ms step_avg:96.75ms
+step:816/1695 train_time:78945ms step_avg:96.75ms
+step:817/1695 train_time:79042ms step_avg:96.75ms
+step:818/1695 train_time:79138ms step_avg:96.75ms
+step:819/1695 train_time:79234ms step_avg:96.74ms
+step:820/1695 train_time:79329ms step_avg:96.74ms
+step:821/1695 train_time:79425ms step_avg:96.74ms
+step:822/1695 train_time:79520ms step_avg:96.74ms
+step:823/1695 train_time:79616ms step_avg:96.74ms
+step:824/1695 train_time:79711ms step_avg:96.74ms
+step:825/1695 train_time:79807ms step_avg:96.74ms
+step:826/1695 train_time:79903ms step_avg:96.73ms
+step:827/1695 train_time:79998ms step_avg:96.73ms
+step:828/1695 train_time:80094ms step_avg:96.73ms
+step:829/1695 train_time:80190ms step_avg:96.73ms
+step:830/1695 train_time:80286ms step_avg:96.73ms
+step:831/1695 train_time:80382ms step_avg:96.73ms
+step:832/1695 train_time:80477ms step_avg:96.73ms
+step:833/1695 train_time:80572ms step_avg:96.73ms
+step:834/1695 train_time:80668ms step_avg:96.72ms
+step:835/1695 train_time:80763ms step_avg:96.72ms
+step:836/1695 train_time:80859ms step_avg:96.72ms
+step:837/1695 train_time:80955ms step_avg:96.72ms
+step:838/1695 train_time:81050ms step_avg:96.72ms
+step:839/1695 train_time:81146ms step_avg:96.72ms
+step:840/1695 train_time:81243ms step_avg:96.72ms
+step:841/1695 train_time:81339ms step_avg:96.72ms
+step:842/1695 train_time:81435ms step_avg:96.72ms
+step:843/1695 train_time:81530ms step_avg:96.71ms
+step:844/1695 train_time:81625ms step_avg:96.71ms
+step:845/1695 train_time:81722ms step_avg:96.71ms
+step:846/1695 train_time:81818ms step_avg:96.71ms
+step:847/1695 train_time:81915ms step_avg:96.71ms
+step:848/1695 train_time:82011ms step_avg:96.71ms
+step:849/1695 train_time:82106ms step_avg:96.71ms
+step:850/1695 train_time:82202ms step_avg:96.71ms
+step:851/1695 train_time:82298ms step_avg:96.71ms
+step:852/1695 train_time:82393ms step_avg:96.70ms
+step:853/1695 train_time:82488ms step_avg:96.70ms
+step:854/1695 train_time:82583ms step_avg:96.70ms
+step:855/1695 train_time:82679ms step_avg:96.70ms
+step:856/1695 train_time:82775ms step_avg:96.70ms
+step:857/1695 train_time:82870ms step_avg:96.70ms
+step:858/1695 train_time:82966ms step_avg:96.70ms
+step:859/1695 train_time:83063ms step_avg:96.70ms
+step:860/1695 train_time:83159ms step_avg:96.70ms
+step:861/1695 train_time:83255ms step_avg:96.70ms
+step:862/1695 train_time:83350ms step_avg:96.69ms
+step:863/1695 train_time:83679ms step_avg:96.96ms
+step:864/1695 train_time:83862ms step_avg:97.06ms
+step:865/1695 train_time:83955ms step_avg:97.06ms
+step:866/1695 train_time:84050ms step_avg:97.06ms
+step:867/1695 train_time:84145ms step_avg:97.05ms
+step:868/1695 train_time:84240ms step_avg:97.05ms
+step:869/1695 train_time:84336ms step_avg:97.05ms
+step:870/1695 train_time:84431ms step_avg:97.05ms
+step:871/1695 train_time:84525ms step_avg:97.04ms
+step:872/1695 train_time:84620ms step_avg:97.04ms
+step:873/1695 train_time:84718ms step_avg:97.04ms
+step:874/1695 train_time:84818ms step_avg:97.05ms
+step:875/1695 train_time:84917ms step_avg:97.05ms
+step:875/1695 val_loss:3.5235 train_time:85011ms step_avg:97.16ms
+step:876/1695 train_time:85037ms step_avg:97.07ms
+step:877/1695 train_time:85116ms step_avg:97.05ms
+step:878/1695 train_time:85213ms step_avg:97.05ms
+step:879/1695 train_time:85309ms step_avg:97.05ms
+step:880/1695 train_time:85404ms step_avg:97.05ms
+step:881/1695 train_time:85499ms step_avg:97.05ms
+step:882/1695 train_time:85594ms step_avg:97.05ms
+step:883/1695 train_time:85690ms step_avg:97.04ms
+step:884/1695 train_time:85785ms step_avg:97.04ms
+step:885/1695 train_time:85879ms step_avg:97.04ms
+step:886/1695 train_time:85976ms step_avg:97.04ms
+step:887/1695 train_time:86075ms step_avg:97.04ms
+step:888/1695 train_time:86174ms step_avg:97.04ms
+step:889/1695 train_time:86271ms step_avg:97.04ms
+step:890/1695 train_time:86367ms step_avg:97.04ms
+step:891/1695 train_time:86462ms step_avg:97.04ms
+step:892/1695 train_time:86557ms step_avg:97.04ms
+step:893/1695 train_time:86653ms step_avg:97.04ms
+step:894/1695 train_time:86749ms step_avg:97.04ms
+step:895/1695 train_time:86845ms step_avg:97.03ms
+step:896/1695 train_time:86940ms step_avg:97.03ms
+step:897/1695 train_time:87038ms step_avg:97.03ms
+step:898/1695 train_time:87136ms step_avg:97.03ms
+step:899/1695 train_time:87234ms step_avg:97.03ms
+step:900/1695 train_time:87331ms step_avg:97.03ms
+step:901/1695 train_time:87427ms step_avg:97.03ms
+step:902/1695 train_time:87522ms step_avg:97.03ms
+step:903/1695 train_time:87617ms step_avg:97.03ms
+step:904/1695 train_time:87714ms step_avg:97.03ms
+step:905/1695 train_time:87810ms step_avg:97.03ms
+step:906/1695 train_time:87905ms step_avg:97.03ms
+step:907/1695 train_time:88002ms step_avg:97.03ms
+step:908/1695 train_time:88099ms step_avg:97.03ms
+step:909/1695 train_time:88197ms step_avg:97.03ms
+step:910/1695 train_time:88295ms step_avg:97.03ms
+step:911/1695 train_time:88392ms step_avg:97.03ms
+step:912/1695 train_time:88487ms step_avg:97.03ms
+step:913/1695 train_time:88582ms step_avg:97.02ms
+step:914/1695 train_time:88678ms step_avg:97.02ms
+step:915/1695 train_time:88774ms step_avg:97.02ms
+step:916/1695 train_time:88870ms step_avg:97.02ms
+step:917/1695 train_time:88965ms step_avg:97.02ms
+step:918/1695 train_time:89060ms step_avg:97.02ms
+step:919/1695 train_time:89157ms step_avg:97.02ms
+step:920/1695 train_time:89255ms step_avg:97.02ms
+step:921/1695 train_time:89353ms step_avg:97.02ms
+step:922/1695 train_time:89451ms step_avg:97.02ms
+step:923/1695 train_time:89547ms step_avg:97.02ms
+step:924/1695 train_time:89642ms step_avg:97.01ms
+step:925/1695 train_time:89738ms step_avg:97.01ms
+step:926/1695 train_time:89833ms step_avg:97.01ms
+step:927/1695 train_time:89929ms step_avg:97.01ms
+step:928/1695 train_time:90025ms step_avg:97.01ms
+step:929/1695 train_time:90120ms step_avg:97.01ms
+step:930/1695 train_time:90216ms step_avg:97.01ms
+step:931/1695 train_time:90313ms step_avg:97.01ms
+step:932/1695 train_time:90410ms step_avg:97.01ms
+step:933/1695 train_time:90506ms step_avg:97.01ms
+step:934/1695 train_time:90602ms step_avg:97.00ms
+step:935/1695 train_time:90697ms step_avg:97.00ms
+step:936/1695 train_time:90794ms step_avg:97.00ms
+step:937/1695 train_time:90890ms step_avg:97.00ms
+step:938/1695 train_time:90985ms step_avg:97.00ms
+step:939/1695 train_time:91081ms step_avg:97.00ms
+step:940/1695 train_time:91177ms step_avg:97.00ms
+step:941/1695 train_time:91274ms step_avg:97.00ms
+step:942/1695 train_time:91371ms step_avg:97.00ms
+step:943/1695 train_time:91468ms step_avg:97.00ms
+step:944/1695 train_time:91563ms step_avg:97.00ms
+step:945/1695 train_time:91660ms step_avg:96.99ms
+step:946/1695 train_time:91756ms step_avg:96.99ms
+step:947/1695 train_time:91855ms step_avg:97.00ms
+step:948/1695 train_time:91952ms step_avg:97.00ms
+step:949/1695 train_time:92048ms step_avg:96.99ms
+step:950/1695 train_time:92144ms step_avg:96.99ms
+step:951/1695 train_time:92239ms step_avg:96.99ms
+step:952/1695 train_time:92336ms step_avg:96.99ms
+step:953/1695 train_time:92433ms step_avg:96.99ms
+step:954/1695 train_time:92529ms step_avg:96.99ms
+step:955/1695 train_time:92625ms step_avg:96.99ms
+step:956/1695 train_time:92721ms step_avg:96.99ms
+step:957/1695 train_time:92817ms step_avg:96.99ms
+step:958/1695 train_time:92914ms step_avg:96.99ms
+step:959/1695 train_time:93011ms step_avg:96.99ms
+step:960/1695 train_time:93107ms step_avg:96.99ms
+step:961/1695 train_time:93203ms step_avg:96.99ms
+step:962/1695 train_time:93299ms step_avg:96.98ms
+step:963/1695 train_time:93395ms step_avg:96.98ms
+step:964/1695 train_time:93491ms step_avg:96.98ms
+step:965/1695 train_time:93587ms step_avg:96.98ms
+step:966/1695 train_time:93682ms step_avg:96.98ms
+step:967/1695 train_time:93779ms step_avg:96.98ms
+step:968/1695 train_time:93874ms step_avg:96.98ms
+step:969/1695 train_time:93970ms step_avg:96.98ms
+step:970/1695 train_time:94066ms step_avg:96.98ms
+step:971/1695 train_time:94162ms step_avg:96.97ms
+step:972/1695 train_time:94259ms step_avg:96.97ms
+step:973/1695 train_time:94356ms step_avg:96.97ms
+step:974/1695 train_time:94452ms step_avg:96.97ms
+step:975/1695 train_time:94549ms step_avg:96.97ms
+step:976/1695 train_time:94643ms step_avg:96.97ms
+step:977/1695 train_time:94739ms step_avg:96.97ms
+step:978/1695 train_time:94835ms step_avg:96.97ms
+step:979/1695 train_time:94932ms step_avg:96.97ms
+step:980/1695 train_time:95028ms step_avg:96.97ms
+step:981/1695 train_time:95124ms step_avg:96.97ms
+step:982/1695 train_time:95219ms step_avg:96.96ms
+step:983/1695 train_time:95316ms step_avg:96.96ms
+step:984/1695 train_time:95412ms step_avg:96.96ms
+step:985/1695 train_time:95508ms step_avg:96.96ms
+step:986/1695 train_time:95603ms step_avg:96.96ms
+step:987/1695 train_time:95699ms step_avg:96.96ms
+step:988/1695 train_time:95795ms step_avg:96.96ms
+step:989/1695 train_time:95891ms step_avg:96.96ms
+step:990/1695 train_time:95987ms step_avg:96.96ms
+step:991/1695 train_time:96083ms step_avg:96.96ms
+step:992/1695 train_time:96178ms step_avg:96.95ms
+step:993/1695 train_time:96274ms step_avg:96.95ms
+step:994/1695 train_time:96370ms step_avg:96.95ms
+step:995/1695 train_time:96465ms step_avg:96.95ms
+step:996/1695 train_time:96560ms step_avg:96.95ms
+step:997/1695 train_time:96657ms step_avg:96.95ms
+step:998/1695 train_time:96753ms step_avg:96.95ms
+step:999/1695 train_time:96850ms step_avg:96.95ms
+step:1000/1695 train_time:96947ms step_avg:96.95ms
+step:1000/1695 val_loss:3.4841 train_time:97040ms step_avg:97.04ms
+step:1001/1695 train_time:97064ms step_avg:96.97ms
+step:1002/1695 train_time:97146ms step_avg:96.95ms
+step:1003/1695 train_time:97243ms step_avg:96.95ms
+step:1004/1695 train_time:97339ms step_avg:96.95ms
+step:1005/1695 train_time:97434ms step_avg:96.95ms
+step:1006/1695 train_time:97530ms step_avg:96.95ms
+step:1007/1695 train_time:97624ms step_avg:96.95ms
+step:1008/1695 train_time:97719ms step_avg:96.94ms
+step:1009/1695 train_time:97815ms step_avg:96.94ms
+step:1010/1695 train_time:97910ms step_avg:96.94ms
+step:1011/1695 train_time:98007ms step_avg:96.94ms
+step:1012/1695 train_time:98104ms step_avg:96.94ms
+step:1013/1695 train_time:98201ms step_avg:96.94ms
+step:1014/1695 train_time:98298ms step_avg:96.94ms
+step:1015/1695 train_time:98395ms step_avg:96.94ms
+step:1016/1695 train_time:98490ms step_avg:96.94ms
+step:1017/1695 train_time:98586ms step_avg:96.94ms
+step:1018/1695 train_time:98682ms step_avg:96.94ms
+step:1019/1695 train_time:98777ms step_avg:96.93ms
+step:1020/1695 train_time:98872ms step_avg:96.93ms
+step:1021/1695 train_time:98967ms step_avg:96.93ms
+step:1022/1695 train_time:99063ms step_avg:96.93ms
+step:1023/1695 train_time:99160ms step_avg:96.93ms
+step:1024/1695 train_time:99258ms step_avg:96.93ms
+step:1025/1695 train_time:99355ms step_avg:96.93ms
+step:1026/1695 train_time:99451ms step_avg:96.93ms
+step:1027/1695 train_time:99547ms step_avg:96.93ms
+step:1028/1695 train_time:99643ms step_avg:96.93ms
+step:1029/1695 train_time:99738ms step_avg:96.93ms
+step:1030/1695 train_time:99833ms step_avg:96.93ms
+step:1031/1695 train_time:99930ms step_avg:96.93ms
+step:1032/1695 train_time:100026ms step_avg:96.92ms
+step:1033/1695 train_time:100122ms step_avg:96.92ms
+step:1034/1695 train_time:100218ms step_avg:96.92ms
+step:1035/1695 train_time:100315ms step_avg:96.92ms
+step:1036/1695 train_time:100647ms step_avg:97.15ms
+step:1037/1695 train_time:100826ms step_avg:97.23ms
+step:1038/1695 train_time:100920ms step_avg:97.22ms
+step:1039/1695 train_time:101015ms step_avg:97.22ms
+step:1040/1695 train_time:101110ms step_avg:97.22ms
+step:1041/1695 train_time:101204ms step_avg:97.22ms
+step:1042/1695 train_time:101299ms step_avg:97.22ms
+step:1043/1695 train_time:101394ms step_avg:97.21ms
+step:1044/1695 train_time:101489ms step_avg:97.21ms
+step:1045/1695 train_time:101584ms step_avg:97.21ms
+step:1046/1695 train_time:101681ms step_avg:97.21ms
+step:1047/1695 train_time:101783ms step_avg:97.21ms
+step:1048/1695 train_time:101883ms step_avg:97.22ms
+step:1049/1695 train_time:101980ms step_avg:97.22ms
+step:1050/1695 train_time:102077ms step_avg:97.22ms
+step:1051/1695 train_time:102173ms step_avg:97.21ms
+step:1052/1695 train_time:102268ms step_avg:97.21ms
+step:1053/1695 train_time:102362ms step_avg:97.21ms
+step:1054/1695 train_time:102457ms step_avg:97.21ms
+step:1055/1695 train_time:102553ms step_avg:97.21ms
+step:1056/1695 train_time:102650ms step_avg:97.21ms
+step:1057/1695 train_time:102747ms step_avg:97.21ms
+step:1058/1695 train_time:102844ms step_avg:97.21ms
+step:1059/1695 train_time:102941ms step_avg:97.21ms
+step:1060/1695 train_time:103037ms step_avg:97.20ms
+step:1061/1695 train_time:103134ms step_avg:97.20ms
+step:1062/1695 train_time:103230ms step_avg:97.20ms
+step:1063/1695 train_time:103325ms step_avg:97.20ms
+step:1064/1695 train_time:103421ms step_avg:97.20ms
+step:1065/1695 train_time:103516ms step_avg:97.20ms
+step:1066/1695 train_time:103612ms step_avg:97.20ms
+step:1067/1695 train_time:103709ms step_avg:97.20ms
+step:1068/1695 train_time:103806ms step_avg:97.20ms
+step:1069/1695 train_time:103901ms step_avg:97.19ms
+step:1070/1695 train_time:103997ms step_avg:97.19ms
+step:1071/1695 train_time:104094ms step_avg:97.19ms
+step:1072/1695 train_time:104190ms step_avg:97.19ms
+step:1073/1695 train_time:104285ms step_avg:97.19ms
+step:1074/1695 train_time:104381ms step_avg:97.19ms
+step:1075/1695 train_time:104476ms step_avg:97.19ms
+step:1076/1695 train_time:104572ms step_avg:97.19ms
+step:1077/1695 train_time:104668ms step_avg:97.18ms
+step:1078/1695 train_time:104764ms step_avg:97.18ms
+step:1079/1695 train_time:104860ms step_avg:97.18ms
+step:1080/1695 train_time:104956ms step_avg:97.18ms
+step:1081/1695 train_time:105052ms step_avg:97.18ms
+step:1082/1695 train_time:105149ms step_avg:97.18ms
+step:1083/1695 train_time:105246ms step_avg:97.18ms
+step:1084/1695 train_time:105341ms step_avg:97.18ms
+step:1085/1695 train_time:105437ms step_avg:97.18ms
+step:1086/1695 train_time:105533ms step_avg:97.18ms
+step:1087/1695 train_time:105630ms step_avg:97.18ms
+step:1088/1695 train_time:105726ms step_avg:97.18ms
+step:1089/1695 train_time:105822ms step_avg:97.17ms
+step:1090/1695 train_time:105918ms step_avg:97.17ms
+step:1091/1695 train_time:106014ms step_avg:97.17ms
+step:1092/1695 train_time:106110ms step_avg:97.17ms
+step:1093/1695 train_time:106206ms step_avg:97.17ms
+step:1094/1695 train_time:106301ms step_avg:97.17ms
+step:1095/1695 train_time:106397ms step_avg:97.17ms
+step:1096/1695 train_time:106493ms step_avg:97.16ms
+step:1097/1695 train_time:106589ms step_avg:97.16ms
+step:1098/1695 train_time:106685ms step_avg:97.16ms
+step:1099/1695 train_time:106781ms step_avg:97.16ms
+step:1100/1695 train_time:106878ms step_avg:97.16ms
+step:1101/1695 train_time:106974ms step_avg:97.16ms
+step:1102/1695 train_time:107070ms step_avg:97.16ms
+step:1103/1695 train_time:107166ms step_avg:97.16ms
+step:1104/1695 train_time:107261ms step_avg:97.16ms
+step:1105/1695 train_time:107358ms step_avg:97.16ms
+step:1106/1695 train_time:107455ms step_avg:97.16ms
+step:1107/1695 train_time:107551ms step_avg:97.16ms
+step:1108/1695 train_time:107648ms step_avg:97.16ms
+step:1109/1695 train_time:107743ms step_avg:97.15ms
+step:1110/1695 train_time:107839ms step_avg:97.15ms
+step:1111/1695 train_time:107935ms step_avg:97.15ms
+step:1112/1695 train_time:108031ms step_avg:97.15ms
+step:1113/1695 train_time:108127ms step_avg:97.15ms
+step:1114/1695 train_time:108223ms step_avg:97.15ms
+step:1115/1695 train_time:108319ms step_avg:97.15ms
+step:1116/1695 train_time:108416ms step_avg:97.15ms
+step:1117/1695 train_time:108512ms step_avg:97.15ms
+step:1118/1695 train_time:108608ms step_avg:97.15ms
+step:1119/1695 train_time:108704ms step_avg:97.14ms
+step:1120/1695 train_time:108799ms step_avg:97.14ms
+step:1121/1695 train_time:108896ms step_avg:97.14ms
+step:1122/1695 train_time:108993ms step_avg:97.14ms
+step:1123/1695 train_time:109090ms step_avg:97.14ms
+step:1124/1695 train_time:109186ms step_avg:97.14ms
+step:1125/1695 train_time:109281ms step_avg:97.14ms
+step:1125/1695 val_loss:3.4352 train_time:109375ms step_avg:97.22ms
+step:1126/1695 train_time:109400ms step_avg:97.16ms
+step:1127/1695 train_time:109483ms step_avg:97.15ms
+step:1128/1695 train_time:109580ms step_avg:97.15ms
+step:1129/1695 train_time:109676ms step_avg:97.14ms
+step:1130/1695 train_time:109771ms step_avg:97.14ms
+step:1131/1695 train_time:109866ms step_avg:97.14ms
+step:1132/1695 train_time:109960ms step_avg:97.14ms
+step:1133/1695 train_time:110056ms step_avg:97.14ms
+step:1134/1695 train_time:110153ms step_avg:97.14ms
+step:1135/1695 train_time:110251ms step_avg:97.14ms
+step:1136/1695 train_time:110349ms step_avg:97.14ms
+step:1137/1695 train_time:110450ms step_avg:97.14ms
+step:1138/1695 train_time:110549ms step_avg:97.14ms
+step:1139/1695 train_time:110648ms step_avg:97.14ms
+step:1140/1695 train_time:110745ms step_avg:97.14ms
+step:1141/1695 train_time:110841ms step_avg:97.14ms
+step:1142/1695 train_time:110938ms step_avg:97.14ms
+step:1143/1695 train_time:111036ms step_avg:97.14ms
+step:1144/1695 train_time:111133ms step_avg:97.14ms
+step:1145/1695 train_time:111230ms step_avg:97.14ms
+step:1146/1695 train_time:111328ms step_avg:97.15ms
+step:1147/1695 train_time:111427ms step_avg:97.15ms
+step:1148/1695 train_time:111527ms step_avg:97.15ms
+step:1149/1695 train_time:111626ms step_avg:97.15ms
+step:1150/1695 train_time:111724ms step_avg:97.15ms
+step:1151/1695 train_time:111820ms step_avg:97.15ms
+step:1152/1695 train_time:111918ms step_avg:97.15ms
+step:1153/1695 train_time:112015ms step_avg:97.15ms
+step:1154/1695 train_time:112112ms step_avg:97.15ms
+step:1155/1695 train_time:112210ms step_avg:97.15ms
+step:1156/1695 train_time:112307ms step_avg:97.15ms
+step:1157/1695 train_time:112405ms step_avg:97.15ms
+step:1158/1695 train_time:112503ms step_avg:97.15ms
+step:1159/1695 train_time:112602ms step_avg:97.15ms
+step:1160/1695 train_time:112701ms step_avg:97.16ms
+step:1161/1695 train_time:112798ms step_avg:97.16ms
+step:1162/1695 train_time:112895ms step_avg:97.16ms
+step:1163/1695 train_time:112993ms step_avg:97.16ms
+step:1164/1695 train_time:113089ms step_avg:97.16ms
+step:1165/1695 train_time:113186ms step_avg:97.16ms
+step:1166/1695 train_time:113283ms step_avg:97.16ms
+step:1167/1695 train_time:113381ms step_avg:97.16ms
+step:1168/1695 train_time:113479ms step_avg:97.16ms
+step:1169/1695 train_time:113578ms step_avg:97.16ms
+step:1170/1695 train_time:113678ms step_avg:97.16ms
+step:1171/1695 train_time:113777ms step_avg:97.16ms
+step:1172/1695 train_time:113875ms step_avg:97.16ms
+step:1173/1695 train_time:113973ms step_avg:97.16ms
+step:1174/1695 train_time:114071ms step_avg:97.16ms
+step:1175/1695 train_time:114168ms step_avg:97.16ms
+step:1176/1695 train_time:114265ms step_avg:97.16ms
+step:1177/1695 train_time:114362ms step_avg:97.16ms
+step:1178/1695 train_time:114460ms step_avg:97.16ms
+step:1179/1695 train_time:114558ms step_avg:97.17ms
+step:1180/1695 train_time:114659ms step_avg:97.17ms
+step:1181/1695 train_time:114757ms step_avg:97.17ms
+step:1182/1695 train_time:114856ms step_avg:97.17ms
+step:1183/1695 train_time:114955ms step_avg:97.17ms
+step:1184/1695 train_time:115054ms step_avg:97.17ms
+step:1185/1695 train_time:115152ms step_avg:97.17ms
+step:1186/1695 train_time:115250ms step_avg:97.18ms
+step:1187/1695 train_time:115348ms step_avg:97.18ms
+step:1188/1695 train_time:115445ms step_avg:97.18ms
+step:1189/1695 train_time:115543ms step_avg:97.18ms
+step:1190/1695 train_time:115640ms step_avg:97.18ms
+step:1191/1695 train_time:115739ms step_avg:97.18ms
+step:1192/1695 train_time:115838ms step_avg:97.18ms
+step:1193/1695 train_time:115938ms step_avg:97.18ms
+step:1194/1695 train_time:116038ms step_avg:97.18ms
+step:1195/1695 train_time:116138ms step_avg:97.19ms
+step:1196/1695 train_time:116236ms step_avg:97.19ms
+step:1197/1695 train_time:116336ms step_avg:97.19ms
+step:1198/1695 train_time:116435ms step_avg:97.19ms
+step:1199/1695 train_time:116533ms step_avg:97.19ms
+step:1200/1695 train_time:116632ms step_avg:97.19ms
+step:1201/1695 train_time:116730ms step_avg:97.19ms
+step:1202/1695 train_time:116828ms step_avg:97.19ms
+step:1203/1695 train_time:116927ms step_avg:97.20ms
+step:1204/1695 train_time:117024ms step_avg:97.20ms
+step:1205/1695 train_time:117122ms step_avg:97.20ms
+step:1206/1695 train_time:117220ms step_avg:97.20ms
+step:1207/1695 train_time:117318ms step_avg:97.20ms
+step:1208/1695 train_time:117661ms step_avg:97.40ms
+step:1209/1695 train_time:117847ms step_avg:97.47ms
+step:1210/1695 train_time:117942ms step_avg:97.47ms
+step:1211/1695 train_time:118039ms step_avg:97.47ms
+step:1212/1695 train_time:118136ms step_avg:97.47ms
+step:1213/1695 train_time:118233ms step_avg:97.47ms
+step:1214/1695 train_time:118330ms step_avg:97.47ms
+step:1215/1695 train_time:118426ms step_avg:97.47ms
+step:1216/1695 train_time:118522ms step_avg:97.47ms
+step:1217/1695 train_time:118618ms step_avg:97.47ms
+step:1218/1695 train_time:118722ms step_avg:97.47ms
+step:1219/1695 train_time:118826ms step_avg:97.48ms
+step:1220/1695 train_time:118924ms step_avg:97.48ms
+step:1221/1695 train_time:119021ms step_avg:97.48ms
+step:1222/1695 train_time:119118ms step_avg:97.48ms
+step:1223/1695 train_time:119215ms step_avg:97.48ms
+step:1224/1695 train_time:119312ms step_avg:97.48ms
+step:1225/1695 train_time:119409ms step_avg:97.48ms
+step:1226/1695 train_time:119505ms step_avg:97.48ms
+step:1227/1695 train_time:119602ms step_avg:97.48ms
+step:1228/1695 train_time:119701ms step_avg:97.48ms
+step:1229/1695 train_time:119802ms step_avg:97.48ms
+step:1230/1695 train_time:119902ms step_avg:97.48ms
+step:1231/1695 train_time:120000ms step_avg:97.48ms
+step:1232/1695 train_time:120098ms step_avg:97.48ms
+step:1233/1695 train_time:120196ms step_avg:97.48ms
+step:1234/1695 train_time:120294ms step_avg:97.48ms
+step:1235/1695 train_time:120391ms step_avg:97.48ms
+step:1236/1695 train_time:120488ms step_avg:97.48ms
+step:1237/1695 train_time:120585ms step_avg:97.48ms
+step:1238/1695 train_time:120683ms step_avg:97.48ms
+step:1239/1695 train_time:120781ms step_avg:97.48ms
+step:1240/1695 train_time:120880ms step_avg:97.48ms
+step:1241/1695 train_time:120978ms step_avg:97.48ms
+step:1242/1695 train_time:121077ms step_avg:97.49ms
+step:1243/1695 train_time:121174ms step_avg:97.48ms
+step:1244/1695 train_time:121272ms step_avg:97.49ms
+step:1245/1695 train_time:121369ms step_avg:97.48ms
+step:1246/1695 train_time:121466ms step_avg:97.48ms
+step:1247/1695 train_time:121562ms step_avg:97.48ms
+step:1248/1695 train_time:121661ms step_avg:97.48ms
+step:1249/1695 train_time:121759ms step_avg:97.49ms
+step:1250/1695 train_time:121859ms step_avg:97.49ms
+step:1250/1695 val_loss:3.3872 train_time:121956ms step_avg:97.56ms
+step:1251/1695 train_time:121980ms step_avg:97.51ms
+step:1252/1695 train_time:122061ms step_avg:97.49ms
+step:1253/1695 train_time:122158ms step_avg:97.49ms
+step:1254/1695 train_time:122255ms step_avg:97.49ms
+step:1255/1695 train_time:122351ms step_avg:97.49ms
+step:1256/1695 train_time:122448ms step_avg:97.49ms
+step:1257/1695 train_time:122545ms step_avg:97.49ms
+step:1258/1695 train_time:122641ms step_avg:97.49ms
+step:1259/1695 train_time:122737ms step_avg:97.49ms
+step:1260/1695 train_time:122835ms step_avg:97.49ms
+step:1261/1695 train_time:122937ms step_avg:97.49ms
+step:1262/1695 train_time:123037ms step_avg:97.49ms
+step:1263/1695 train_time:123135ms step_avg:97.49ms
+step:1264/1695 train_time:123233ms step_avg:97.49ms
+step:1265/1695 train_time:123330ms step_avg:97.49ms
+step:1266/1695 train_time:123427ms step_avg:97.49ms
+step:1267/1695 train_time:123524ms step_avg:97.49ms
+step:1268/1695 train_time:123621ms step_avg:97.49ms
+step:1269/1695 train_time:123718ms step_avg:97.49ms
+step:1270/1695 train_time:123816ms step_avg:97.49ms
+step:1271/1695 train_time:123915ms step_avg:97.49ms
+step:1272/1695 train_time:124014ms step_avg:97.50ms
+step:1273/1695 train_time:124113ms step_avg:97.50ms
+step:1274/1695 train_time:124210ms step_avg:97.50ms
+step:1275/1695 train_time:124308ms step_avg:97.50ms
+step:1276/1695 train_time:124406ms step_avg:97.50ms
+step:1277/1695 train_time:124503ms step_avg:97.50ms
+step:1278/1695 train_time:124600ms step_avg:97.50ms
+step:1279/1695 train_time:124697ms step_avg:97.50ms
+step:1280/1695 train_time:124795ms step_avg:97.50ms
+step:1281/1695 train_time:124893ms step_avg:97.50ms
+step:1282/1695 train_time:124991ms step_avg:97.50ms
+step:1283/1695 train_time:125089ms step_avg:97.50ms
+step:1284/1695 train_time:125187ms step_avg:97.50ms
+step:1285/1695 train_time:125285ms step_avg:97.50ms
+step:1286/1695 train_time:125383ms step_avg:97.50ms
+step:1287/1695 train_time:125480ms step_avg:97.50ms
+step:1288/1695 train_time:125578ms step_avg:97.50ms
+step:1289/1695 train_time:125675ms step_avg:97.50ms
+step:1290/1695 train_time:125772ms step_avg:97.50ms
+step:1291/1695 train_time:125870ms step_avg:97.50ms
+step:1292/1695 train_time:125968ms step_avg:97.50ms
+step:1293/1695 train_time:126066ms step_avg:97.50ms
+step:1294/1695 train_time:126165ms step_avg:97.50ms
+step:1295/1695 train_time:126264ms step_avg:97.50ms
+step:1296/1695 train_time:126364ms step_avg:97.50ms
+step:1297/1695 train_time:126462ms step_avg:97.50ms
+step:1298/1695 train_time:126560ms step_avg:97.50ms
+step:1299/1695 train_time:126658ms step_avg:97.50ms
+step:1300/1695 train_time:126756ms step_avg:97.50ms
+step:1301/1695 train_time:126855ms step_avg:97.51ms
+step:1302/1695 train_time:126952ms step_avg:97.51ms
+step:1303/1695 train_time:127049ms step_avg:97.50ms
+step:1304/1695 train_time:127147ms step_avg:97.51ms
+step:1305/1695 train_time:127244ms step_avg:97.51ms
+step:1306/1695 train_time:127343ms step_avg:97.51ms
+step:1307/1695 train_time:127440ms step_avg:97.51ms
+step:1308/1695 train_time:127539ms step_avg:97.51ms
+step:1309/1695 train_time:127637ms step_avg:97.51ms
+step:1310/1695 train_time:127735ms step_avg:97.51ms
+step:1311/1695 train_time:127834ms step_avg:97.51ms
+step:1312/1695 train_time:127931ms step_avg:97.51ms
+step:1313/1695 train_time:128028ms step_avg:97.51ms
+step:1314/1695 train_time:128126ms step_avg:97.51ms
+step:1315/1695 train_time:128224ms step_avg:97.51ms
+step:1316/1695 train_time:128322ms step_avg:97.51ms
+step:1317/1695 train_time:128420ms step_avg:97.51ms
+step:1318/1695 train_time:128518ms step_avg:97.51ms
+step:1319/1695 train_time:128616ms step_avg:97.51ms
+step:1320/1695 train_time:128714ms step_avg:97.51ms
+step:1321/1695 train_time:128813ms step_avg:97.51ms
+step:1322/1695 train_time:128910ms step_avg:97.51ms
+step:1323/1695 train_time:129008ms step_avg:97.51ms
+step:1324/1695 train_time:129106ms step_avg:97.51ms
+step:1325/1695 train_time:129204ms step_avg:97.51ms
+step:1326/1695 train_time:129303ms step_avg:97.51ms
+step:1327/1695 train_time:129401ms step_avg:97.51ms
+step:1328/1695 train_time:129499ms step_avg:97.51ms
+step:1329/1695 train_time:129597ms step_avg:97.51ms
+step:1330/1695 train_time:129695ms step_avg:97.51ms
+step:1331/1695 train_time:129792ms step_avg:97.51ms
+step:1332/1695 train_time:129890ms step_avg:97.52ms
+step:1333/1695 train_time:129988ms step_avg:97.52ms
+step:1334/1695 train_time:130085ms step_avg:97.51ms
+step:1335/1695 train_time:130182ms step_avg:97.51ms
+step:1336/1695 train_time:130281ms step_avg:97.52ms
+step:1337/1695 train_time:130381ms step_avg:97.52ms
+step:1338/1695 train_time:130478ms step_avg:97.52ms
+step:1339/1695 train_time:130577ms step_avg:97.52ms
+step:1340/1695 train_time:130674ms step_avg:97.52ms
+step:1341/1695 train_time:130773ms step_avg:97.52ms
+step:1342/1695 train_time:130870ms step_avg:97.52ms
+step:1343/1695 train_time:130967ms step_avg:97.52ms
+step:1344/1695 train_time:131063ms step_avg:97.52ms
+step:1345/1695 train_time:131161ms step_avg:97.52ms
+step:1346/1695 train_time:131259ms step_avg:97.52ms
+step:1347/1695 train_time:131357ms step_avg:97.52ms
+step:1348/1695 train_time:131455ms step_avg:97.52ms
+step:1349/1695 train_time:131553ms step_avg:97.52ms
+step:1350/1695 train_time:131651ms step_avg:97.52ms
+step:1351/1695 train_time:131749ms step_avg:97.52ms
+step:1352/1695 train_time:131847ms step_avg:97.52ms
+step:1353/1695 train_time:131946ms step_avg:97.52ms
+step:1354/1695 train_time:132044ms step_avg:97.52ms
+step:1355/1695 train_time:132142ms step_avg:97.52ms
+step:1356/1695 train_time:132239ms step_avg:97.52ms
+step:1357/1695 train_time:132336ms step_avg:97.52ms
+step:1358/1695 train_time:132434ms step_avg:97.52ms
+step:1359/1695 train_time:132532ms step_avg:97.52ms
+step:1360/1695 train_time:132629ms step_avg:97.52ms
+step:1361/1695 train_time:132727ms step_avg:97.52ms
+step:1362/1695 train_time:132825ms step_avg:97.52ms
+step:1363/1695 train_time:132925ms step_avg:97.52ms
+step:1364/1695 train_time:133023ms step_avg:97.52ms
+step:1365/1695 train_time:133120ms step_avg:97.52ms
+step:1366/1695 train_time:133218ms step_avg:97.52ms
+step:1367/1695 train_time:133316ms step_avg:97.52ms
+step:1368/1695 train_time:133413ms step_avg:97.52ms
+step:1369/1695 train_time:133511ms step_avg:97.52ms
+step:1370/1695 train_time:133609ms step_avg:97.52ms
+step:1371/1695 train_time:133707ms step_avg:97.53ms
+step:1372/1695 train_time:133805ms step_avg:97.53ms
+step:1373/1695 train_time:133904ms step_avg:97.53ms
+step:1374/1695 train_time:134003ms step_avg:97.53ms
+step:1375/1695 train_time:134101ms step_avg:97.53ms
+step:1375/1695 val_loss:3.3494 train_time:134197ms step_avg:97.60ms
+step:1376/1695 train_time:134222ms step_avg:97.55ms
+step:1377/1695 train_time:134308ms step_avg:97.54ms
+step:1378/1695 train_time:134406ms step_avg:97.54ms
+step:1379/1695 train_time:134504ms step_avg:97.54ms
+step:1380/1695 train_time:134602ms step_avg:97.54ms
+step:1381/1695 train_time:135056ms step_avg:97.80ms
+step:1382/1695 train_time:135131ms step_avg:97.78ms
+step:1383/1695 train_time:135227ms step_avg:97.78ms
+step:1384/1695 train_time:135324ms step_avg:97.78ms
+step:1385/1695 train_time:135420ms step_avg:97.78ms
+step:1386/1695 train_time:135517ms step_avg:97.78ms
+step:1387/1695 train_time:135613ms step_avg:97.77ms
+step:1388/1695 train_time:135709ms step_avg:97.77ms
+step:1389/1695 train_time:135806ms step_avg:97.77ms
+step:1390/1695 train_time:135905ms step_avg:97.77ms
+step:1391/1695 train_time:136009ms step_avg:97.78ms
+step:1392/1695 train_time:136109ms step_avg:97.78ms
+step:1393/1695 train_time:136207ms step_avg:97.78ms
+step:1394/1695 train_time:136304ms step_avg:97.78ms
+step:1395/1695 train_time:136402ms step_avg:97.78ms
+step:1396/1695 train_time:136499ms step_avg:97.78ms
+step:1397/1695 train_time:136596ms step_avg:97.78ms
+step:1398/1695 train_time:136692ms step_avg:97.78ms
+step:1399/1695 train_time:136788ms step_avg:97.78ms
+step:1400/1695 train_time:136887ms step_avg:97.78ms
+step:1401/1695 train_time:136985ms step_avg:97.78ms
+step:1402/1695 train_time:137085ms step_avg:97.78ms
+step:1403/1695 train_time:137185ms step_avg:97.78ms
+step:1404/1695 train_time:137283ms step_avg:97.78ms
+step:1405/1695 train_time:137381ms step_avg:97.78ms
+step:1406/1695 train_time:137479ms step_avg:97.78ms
+step:1407/1695 train_time:137577ms step_avg:97.78ms
+step:1408/1695 train_time:137674ms step_avg:97.78ms
+step:1409/1695 train_time:137771ms step_avg:97.78ms
+step:1410/1695 train_time:137869ms step_avg:97.78ms
+step:1411/1695 train_time:137967ms step_avg:97.78ms
+step:1412/1695 train_time:138066ms step_avg:97.78ms
+step:1413/1695 train_time:138164ms step_avg:97.78ms
+step:1414/1695 train_time:138263ms step_avg:97.78ms
+step:1415/1695 train_time:138361ms step_avg:97.78ms
+step:1416/1695 train_time:138459ms step_avg:97.78ms
+step:1417/1695 train_time:138557ms step_avg:97.78ms
+step:1418/1695 train_time:138655ms step_avg:97.78ms
+step:1419/1695 train_time:138753ms step_avg:97.78ms
+step:1420/1695 train_time:138850ms step_avg:97.78ms
+step:1421/1695 train_time:138947ms step_avg:97.78ms
+step:1422/1695 train_time:139045ms step_avg:97.78ms
+step:1423/1695 train_time:139143ms step_avg:97.78ms
+step:1424/1695 train_time:139242ms step_avg:97.78ms
+step:1425/1695 train_time:139340ms step_avg:97.78ms
+step:1426/1695 train_time:139438ms step_avg:97.78ms
+step:1427/1695 train_time:139535ms step_avg:97.78ms
+step:1428/1695 train_time:139633ms step_avg:97.78ms
+step:1429/1695 train_time:139730ms step_avg:97.78ms
+step:1430/1695 train_time:139828ms step_avg:97.78ms
+step:1431/1695 train_time:139926ms step_avg:97.78ms
+step:1432/1695 train_time:140024ms step_avg:97.78ms
+step:1433/1695 train_time:140122ms step_avg:97.78ms
+step:1434/1695 train_time:140222ms step_avg:97.78ms
+step:1435/1695 train_time:140319ms step_avg:97.78ms
+step:1436/1695 train_time:140418ms step_avg:97.78ms
+step:1437/1695 train_time:140516ms step_avg:97.78ms
+step:1438/1695 train_time:140613ms step_avg:97.78ms
+step:1439/1695 train_time:140711ms step_avg:97.78ms
+step:1440/1695 train_time:140808ms step_avg:97.78ms
+step:1441/1695 train_time:140905ms step_avg:97.78ms
+step:1442/1695 train_time:141003ms step_avg:97.78ms
+step:1443/1695 train_time:141101ms step_avg:97.78ms
+step:1444/1695 train_time:141199ms step_avg:97.78ms
+step:1445/1695 train_time:141297ms step_avg:97.78ms
+step:1446/1695 train_time:141395ms step_avg:97.78ms
+step:1447/1695 train_time:141493ms step_avg:97.78ms
+step:1448/1695 train_time:141590ms step_avg:97.78ms
+step:1449/1695 train_time:141687ms step_avg:97.78ms
+step:1450/1695 train_time:141785ms step_avg:97.78ms
+step:1451/1695 train_time:141882ms step_avg:97.78ms
+step:1452/1695 train_time:141980ms step_avg:97.78ms
+step:1453/1695 train_time:142078ms step_avg:97.78ms
+step:1454/1695 train_time:142176ms step_avg:97.78ms
+step:1455/1695 train_time:142274ms step_avg:97.78ms
+step:1456/1695 train_time:142372ms step_avg:97.78ms
+step:1457/1695 train_time:142470ms step_avg:97.78ms
+step:1458/1695 train_time:142568ms step_avg:97.78ms
+step:1459/1695 train_time:142665ms step_avg:97.78ms
+step:1460/1695 train_time:142764ms step_avg:97.78ms
+step:1461/1695 train_time:142863ms step_avg:97.78ms
+step:1462/1695 train_time:142961ms step_avg:97.78ms
+step:1463/1695 train_time:143059ms step_avg:97.78ms
+step:1464/1695 train_time:143156ms step_avg:97.78ms
+step:1465/1695 train_time:143253ms step_avg:97.78ms
+step:1466/1695 train_time:143351ms step_avg:97.78ms
+step:1467/1695 train_time:143448ms step_avg:97.78ms
+step:1468/1695 train_time:143546ms step_avg:97.78ms
+step:1469/1695 train_time:143644ms step_avg:97.78ms
+step:1470/1695 train_time:143742ms step_avg:97.78ms
+step:1471/1695 train_time:143841ms step_avg:97.78ms
+step:1472/1695 train_time:143938ms step_avg:97.78ms
+step:1473/1695 train_time:144036ms step_avg:97.78ms
+step:1474/1695 train_time:144132ms step_avg:97.78ms
+step:1475/1695 train_time:144231ms step_avg:97.78ms
+step:1476/1695 train_time:144328ms step_avg:97.78ms
+step:1477/1695 train_time:144426ms step_avg:97.78ms
+step:1478/1695 train_time:144524ms step_avg:97.78ms
+step:1479/1695 train_time:144623ms step_avg:97.78ms
+step:1480/1695 train_time:144722ms step_avg:97.79ms
+step:1481/1695 train_time:144821ms step_avg:97.79ms
+step:1482/1695 train_time:144919ms step_avg:97.79ms
+step:1483/1695 train_time:145017ms step_avg:97.79ms
+step:1484/1695 train_time:145115ms step_avg:97.79ms
+step:1485/1695 train_time:145212ms step_avg:97.79ms
+step:1486/1695 train_time:145310ms step_avg:97.79ms
+step:1487/1695 train_time:145407ms step_avg:97.79ms
+step:1488/1695 train_time:145505ms step_avg:97.79ms
+step:1489/1695 train_time:145603ms step_avg:97.79ms
+step:1490/1695 train_time:145702ms step_avg:97.79ms
+step:1491/1695 train_time:145799ms step_avg:97.79ms
+step:1492/1695 train_time:145896ms step_avg:97.79ms
+step:1493/1695 train_time:145994ms step_avg:97.79ms
+step:1494/1695 train_time:146092ms step_avg:97.79ms
+step:1495/1695 train_time:146190ms step_avg:97.79ms
+step:1496/1695 train_time:146288ms step_avg:97.79ms
+step:1497/1695 train_time:146385ms step_avg:97.79ms
+step:1498/1695 train_time:146482ms step_avg:97.79ms
+step:1499/1695 train_time:146580ms step_avg:97.79ms
+step:1500/1695 train_time:146679ms step_avg:97.79ms
+step:1500/1695 val_loss:3.3158 train_time:146775ms step_avg:97.85ms
+step:1501/1695 train_time:146802ms step_avg:97.80ms
+step:1502/1695 train_time:146885ms step_avg:97.79ms
+step:1503/1695 train_time:146985ms step_avg:97.79ms
+step:1504/1695 train_time:147082ms step_avg:97.79ms
+step:1505/1695 train_time:147180ms step_avg:97.79ms
+step:1506/1695 train_time:147276ms step_avg:97.79ms
+step:1507/1695 train_time:147372ms step_avg:97.79ms
+step:1508/1695 train_time:147469ms step_avg:97.79ms
+step:1509/1695 train_time:147566ms step_avg:97.79ms
+step:1510/1695 train_time:147663ms step_avg:97.79ms
+step:1511/1695 train_time:147762ms step_avg:97.79ms
+step:1512/1695 train_time:147865ms step_avg:97.79ms
+step:1513/1695 train_time:147965ms step_avg:97.80ms
+step:1514/1695 train_time:148064ms step_avg:97.80ms
+step:1515/1695 train_time:148162ms step_avg:97.80ms
+step:1516/1695 train_time:148260ms step_avg:97.80ms
+step:1517/1695 train_time:148356ms step_avg:97.80ms
+step:1518/1695 train_time:148454ms step_avg:97.80ms
+step:1519/1695 train_time:148551ms step_avg:97.80ms
+step:1520/1695 train_time:148647ms step_avg:97.79ms
+step:1521/1695 train_time:148745ms
step_avg:97.79ms +step:1522/1695 train_time:148844ms step_avg:97.79ms +step:1523/1695 train_time:148943ms step_avg:97.80ms +step:1524/1695 train_time:149041ms step_avg:97.80ms +step:1525/1695 train_time:149140ms step_avg:97.80ms +step:1526/1695 train_time:149238ms step_avg:97.80ms +step:1527/1695 train_time:149336ms step_avg:97.80ms +step:1528/1695 train_time:149434ms step_avg:97.80ms +step:1529/1695 train_time:149531ms step_avg:97.80ms +step:1530/1695 train_time:149628ms step_avg:97.80ms +step:1531/1695 train_time:149726ms step_avg:97.80ms +step:1532/1695 train_time:149824ms step_avg:97.80ms +step:1533/1695 train_time:149922ms step_avg:97.80ms +step:1534/1695 train_time:150020ms step_avg:97.80ms +step:1535/1695 train_time:150119ms step_avg:97.80ms +step:1536/1695 train_time:150217ms step_avg:97.80ms +step:1537/1695 train_time:150316ms step_avg:97.80ms +step:1538/1695 train_time:150413ms step_avg:97.80ms +step:1539/1695 train_time:150510ms step_avg:97.80ms +step:1540/1695 train_time:150607ms step_avg:97.80ms +step:1541/1695 train_time:150704ms step_avg:97.80ms +step:1542/1695 train_time:150803ms step_avg:97.80ms +step:1543/1695 train_time:150902ms step_avg:97.80ms +step:1544/1695 train_time:151000ms step_avg:97.80ms +step:1545/1695 train_time:151099ms step_avg:97.80ms +step:1546/1695 train_time:151198ms step_avg:97.80ms +step:1547/1695 train_time:151296ms step_avg:97.80ms +step:1548/1695 train_time:151394ms step_avg:97.80ms +step:1549/1695 train_time:151492ms step_avg:97.80ms +step:1550/1695 train_time:151590ms step_avg:97.80ms +step:1551/1695 train_time:151687ms step_avg:97.80ms +step:1552/1695 train_time:152039ms step_avg:97.96ms +step:1553/1695 train_time:152209ms step_avg:98.01ms +step:1554/1695 train_time:152305ms step_avg:98.01ms +step:1555/1695 train_time:152401ms step_avg:98.01ms +step:1556/1695 train_time:152498ms step_avg:98.01ms +step:1557/1695 train_time:152595ms step_avg:98.01ms +step:1558/1695 train_time:152692ms step_avg:98.00ms +step:1559/1695 train_time:152787ms step_avg:98.00ms +step:1560/1695 train_time:152884ms step_avg:98.00ms +step:1561/1695 train_time:152980ms step_avg:98.00ms +step:1562/1695 train_time:153085ms step_avg:98.01ms +step:1563/1695 train_time:153187ms step_avg:98.01ms +step:1564/1695 train_time:153287ms step_avg:98.01ms +step:1565/1695 train_time:153384ms step_avg:98.01ms +step:1566/1695 train_time:153482ms step_avg:98.01ms +step:1567/1695 train_time:153579ms step_avg:98.01ms +step:1568/1695 train_time:153677ms step_avg:98.01ms +step:1569/1695 train_time:153774ms step_avg:98.01ms +step:1570/1695 train_time:153871ms step_avg:98.01ms +step:1571/1695 train_time:153968ms step_avg:98.01ms +step:1572/1695 train_time:154066ms step_avg:98.01ms +step:1573/1695 train_time:154166ms step_avg:98.01ms +step:1574/1695 train_time:154265ms step_avg:98.01ms +step:1575/1695 train_time:154364ms step_avg:98.01ms +step:1576/1695 train_time:154461ms step_avg:98.01ms +step:1577/1695 train_time:154559ms step_avg:98.01ms +step:1578/1695 train_time:154656ms step_avg:98.01ms +step:1579/1695 train_time:154753ms step_avg:98.01ms +step:1580/1695 train_time:154850ms step_avg:98.01ms +step:1581/1695 train_time:154946ms step_avg:98.01ms +step:1582/1695 train_time:155044ms step_avg:98.01ms +step:1583/1695 train_time:155144ms step_avg:98.01ms +step:1584/1695 train_time:155243ms step_avg:98.01ms +step:1585/1695 train_time:155341ms step_avg:98.01ms +step:1586/1695 train_time:155439ms step_avg:98.01ms +step:1587/1695 train_time:155537ms step_avg:98.01ms +step:1588/1695 train_time:155634ms 
step_avg:98.01ms +step:1589/1695 train_time:155731ms step_avg:98.01ms +step:1590/1695 train_time:155829ms step_avg:98.01ms +step:1591/1695 train_time:155926ms step_avg:98.00ms +step:1592/1695 train_time:156023ms step_avg:98.00ms +step:1593/1695 train_time:156121ms step_avg:98.00ms +step:1594/1695 train_time:156220ms step_avg:98.01ms +step:1595/1695 train_time:156320ms step_avg:98.01ms +step:1596/1695 train_time:156419ms step_avg:98.01ms +step:1597/1695 train_time:156517ms step_avg:98.01ms +step:1598/1695 train_time:156615ms step_avg:98.01ms +step:1599/1695 train_time:156713ms step_avg:98.01ms +step:1600/1695 train_time:156810ms step_avg:98.01ms +step:1601/1695 train_time:156908ms step_avg:98.01ms +step:1602/1695 train_time:157005ms step_avg:98.01ms +step:1603/1695 train_time:157103ms step_avg:98.01ms +step:1604/1695 train_time:157201ms step_avg:98.01ms +step:1605/1695 train_time:157300ms step_avg:98.01ms +step:1606/1695 train_time:157399ms step_avg:98.01ms +step:1607/1695 train_time:157497ms step_avg:98.01ms +step:1608/1695 train_time:157595ms step_avg:98.01ms +step:1609/1695 train_time:157693ms step_avg:98.01ms +step:1610/1695 train_time:157791ms step_avg:98.01ms +step:1611/1695 train_time:157889ms step_avg:98.01ms +step:1612/1695 train_time:157986ms step_avg:98.01ms +step:1613/1695 train_time:158083ms step_avg:98.01ms +step:1614/1695 train_time:158181ms step_avg:98.01ms +step:1615/1695 train_time:158279ms step_avg:98.01ms +step:1616/1695 train_time:158378ms step_avg:98.01ms +step:1617/1695 train_time:158477ms step_avg:98.01ms +step:1618/1695 train_time:158575ms step_avg:98.01ms +step:1619/1695 train_time:158672ms step_avg:98.01ms +step:1620/1695 train_time:158771ms step_avg:98.01ms +step:1621/1695 train_time:158869ms step_avg:98.01ms +step:1622/1695 train_time:158967ms step_avg:98.01ms +step:1623/1695 train_time:159064ms step_avg:98.01ms +step:1624/1695 train_time:159161ms step_avg:98.01ms +step:1625/1695 train_time:159259ms step_avg:98.01ms +step:1625/1695 val_loss:3.2885 train_time:159356ms step_avg:98.07ms +step:1626/1695 train_time:159382ms step_avg:98.02ms +step:1627/1695 train_time:159464ms step_avg:98.01ms +step:1628/1695 train_time:159563ms step_avg:98.01ms +step:1629/1695 train_time:159661ms step_avg:98.01ms +step:1630/1695 train_time:159759ms step_avg:98.01ms +step:1631/1695 train_time:159856ms step_avg:98.01ms +step:1632/1695 train_time:159953ms step_avg:98.01ms +step:1633/1695 train_time:160051ms step_avg:98.01ms +step:1634/1695 train_time:160147ms step_avg:98.01ms +step:1635/1695 train_time:160244ms step_avg:98.01ms +step:1636/1695 train_time:160346ms step_avg:98.01ms +step:1637/1695 train_time:160445ms step_avg:98.01ms +step:1638/1695 train_time:160544ms step_avg:98.01ms +step:1639/1695 train_time:160643ms step_avg:98.01ms +step:1640/1695 train_time:160740ms step_avg:98.01ms +step:1641/1695 train_time:160839ms step_avg:98.01ms +step:1642/1695 train_time:160936ms step_avg:98.01ms +step:1643/1695 train_time:161034ms step_avg:98.01ms +step:1644/1695 train_time:161132ms step_avg:98.01ms +step:1645/1695 train_time:161230ms step_avg:98.01ms +step:1646/1695 train_time:161329ms step_avg:98.01ms +step:1647/1695 train_time:161427ms step_avg:98.01ms +step:1648/1695 train_time:161526ms step_avg:98.01ms +step:1649/1695 train_time:161624ms step_avg:98.01ms +step:1650/1695 train_time:161721ms step_avg:98.01ms +step:1651/1695 train_time:161819ms step_avg:98.01ms +step:1652/1695 train_time:161917ms step_avg:98.01ms +step:1653/1695 train_time:162016ms step_avg:98.01ms +step:1654/1695 
train_time:162114ms step_avg:98.01ms +step:1655/1695 train_time:162213ms step_avg:98.01ms +step:1656/1695 train_time:162313ms step_avg:98.01ms +step:1657/1695 train_time:162411ms step_avg:98.02ms +step:1658/1695 train_time:162510ms step_avg:98.02ms +step:1659/1695 train_time:162608ms step_avg:98.02ms +step:1660/1695 train_time:162706ms step_avg:98.02ms +step:1661/1695 train_time:162804ms step_avg:98.02ms +step:1662/1695 train_time:162901ms step_avg:98.02ms +step:1663/1695 train_time:163000ms step_avg:98.02ms +step:1664/1695 train_time:163099ms step_avg:98.02ms +step:1665/1695 train_time:163198ms step_avg:98.02ms +step:1666/1695 train_time:163297ms step_avg:98.02ms +step:1667/1695 train_time:163396ms step_avg:98.02ms +step:1668/1695 train_time:163496ms step_avg:98.02ms +step:1669/1695 train_time:163595ms step_avg:98.02ms +step:1670/1695 train_time:163695ms step_avg:98.02ms +step:1671/1695 train_time:163793ms step_avg:98.02ms +step:1672/1695 train_time:163891ms step_avg:98.02ms +step:1673/1695 train_time:163988ms step_avg:98.02ms +step:1674/1695 train_time:164085ms step_avg:98.02ms +step:1675/1695 train_time:164182ms step_avg:98.02ms +step:1676/1695 train_time:164280ms step_avg:98.02ms +step:1677/1695 train_time:164379ms step_avg:98.02ms +step:1678/1695 train_time:164478ms step_avg:98.02ms +step:1679/1695 train_time:164578ms step_avg:98.02ms +step:1680/1695 train_time:164677ms step_avg:98.02ms +step:1681/1695 train_time:164774ms step_avg:98.02ms +step:1682/1695 train_time:164872ms step_avg:98.02ms +step:1683/1695 train_time:164969ms step_avg:98.02ms +step:1684/1695 train_time:165067ms step_avg:98.02ms +step:1685/1695 train_time:165164ms step_avg:98.02ms +step:1686/1695 train_time:165262ms step_avg:98.02ms +step:1687/1695 train_time:165359ms step_avg:98.02ms +step:1688/1695 train_time:165458ms step_avg:98.02ms +step:1689/1695 train_time:165558ms step_avg:98.02ms +step:1690/1695 train_time:165659ms step_avg:98.02ms +step:1691/1695 train_time:165757ms step_avg:98.02ms +step:1692/1695 train_time:165855ms step_avg:98.02ms +step:1693/1695 train_time:165953ms step_avg:98.02ms +step:1694/1695 train_time:166052ms step_avg:98.02ms +step:1695/1695 train_time:166151ms step_avg:98.02ms +step:1695/1695 val_loss:3.2769 train_time:166247ms step_avg:98.08ms +peak memory allocated: 34505 MiB reserved: 49576 MiB diff --git a/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt b/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt new file mode 100644 index 000000000..7e21a501e --- /dev/null +++ b/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + 
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +class EOSBatchFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): + # Precompute EOS positions once per shard + self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 # pointer into eos_idx (start EOS for next step) + self.pos = 0 # logical stream position within this shard + self.world_size = world_size + def seek(self, pos: int): + # Set pointer to the first EOS >= pos + self.i = np.searchsorted(self.eos_idx, pos) + if self.i >= len(self.eos_idx): + raise StopIteration("Seek past last EOS.") + self.pos = pos + def next_batch(self, batch_size_local: int, seq_len: int): + n = len(self.eos_idx) + if self.i >= n: + raise StopIteration("No more EOS in this shard.") + starts = [[] for _ in range(self.world_size)] + idx = self.i + cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 + for r in range(self.world_size): + for _ in range(batch_size_local): + start = cur + 1 + target = start + seq_len # need seq_len tokens before next EOS + j = np.searchsorted(self.eos_idx, target) + if j >= n: + raise StopIteration("Insufficient EOS ahead; hit tail of shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + + while True: + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
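+                # Note: no tokens are carried over from the exhausted shard; a fresh
+                # EOSBatchFinder is built on the new shard and `continue` retries the
+                # whole batch from position 0, so the old shard's tail is dropped.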
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+                finder = EOSBatchFinder(tokens, world_size=world_size)
+                continue
+
+            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
+            buf = torch.stack(bufs, dim=0)
+            _inputs = buf[:, :-1]
+            _targets = buf[:, 1:]
+        else:
+            batch_span = num_tokens_global
+            start_pos_local = pos + rank * (batch_size_local * seq_len)
+            end_pos_local = start_pos_local + (batch_size_local * seq_len)
+
+            buf = tokens[start_pos_local: end_pos_local + 1]
+
+            _inputs = buf[:-1].view(batch_size_local, seq_len)
+            _targets = buf[1:].view(batch_size_local, seq_len)
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
+        )
+
+        pos += batch_span
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (batch_size, seq_len) via .send()
+            new_batch_size, new_seq_len = new_params
+            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
+            batch_size = new_batch_size
+            seq_len = new_seq_len
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_seq_len: int = 1024 * 2
+    train_batch_size: int = 24 * 8
+    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
+    # optimization
+    num_iterations: int = 1695 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    bandwidth: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
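+# Illustrative sketch (never called in this script) restating the sparse attention gate
+# wired into CausalSelfAttention above: a zero-initialized linear layer reads the first
+# attn_gate_dim channels of the normed block input, emits one logit per head, and each
+# head's attention output is scaled by sigmoid(logit). With zero-initialized weights,
+# sigmoid(0) = 0.5, so every head starts half-open and can learn a per-token no-op.
+# The shapes below are illustrative assumptions matching this file's config.
+def _attn_gate_sketch():
+    import torch
+    import torch.nn.functional as F
+    B, T, num_heads, head_dim, gate_dim = 2, 8, 6, 128, 12
+    x = torch.randn(B, T, num_heads * head_dim)        # normed block input, as passed to attn
+    y = torch.randn(B, T, num_heads, head_dim)         # per-head attention output
+    gate_w = torch.zeros(num_heads, gate_dim)          # zero init, like attn_gate.weight
+    gate = torch.sigmoid(F.linear(x[..., :gate_dim], gate_w))  # (B, T, num_heads)
+    gated = y * gate.view(B, T, num_heads, 1)
+    assert torch.allclose(gated, 0.5 * y)              # at init the gate halves every head
+    return gated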
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_seq_len, args.val_seq_len)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr_and_ws(step: int):
+    x = step / (1 + args.num_iterations) # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return lr, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 60
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+for step in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, ws, ws // 2).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
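+# Worked example of the schedule defined above, as consumed by the loop below: with
+# num_iterations=1695 and cooldown_frac=0.45, x = step/1696, so the LR multiplier holds
+# at 1.0 until roughly step 933 (x >= 0.55) and then decays linearly toward 0.1. The
+# window-size index int(3*x) walks ws through 3 -> 7 -> 11 at thirds of training, and
+# each attention layer sees a causal band of ws (long) or ws//2 (short) times
+# bandwidth (=128) tokens.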
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + lr, ws = get_lr_and_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % (world_size * args.val_seq_len) == 0 + val_steps = args.val_tokens // (world_size * args.val_seq_len) + val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, ws, ws // 2) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 03:53:12 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 30C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 32C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 30C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1695 train_time:508ms step_avg:507.99ms +step:2/1695 train_time:531ms step_avg:265.69ms +step:3/1695 train_time:603ms step_avg:200.91ms +step:4/1695 train_time:695ms step_avg:173.68ms +step:5/1695 train_time:787ms step_avg:157.47ms +step:6/1695 train_time:881ms step_avg:146.76ms +step:7/1695 train_time:974ms step_avg:139.13ms +step:8/1695 train_time:1067ms step_avg:133.43ms +step:9/1695 train_time:1161ms step_avg:128.95ms +step:10/1695 train_time:1253ms step_avg:125.34ms +step:11/1695 train_time:1347ms step_avg:122.48ms +step:12/1695 train_time:1445ms step_avg:120.43ms +step:13/1695 train_time:1543ms step_avg:118.70ms +step:14/1695 train_time:1638ms step_avg:117.04ms +step:15/1695 train_time:1732ms step_avg:115.49ms +step:16/1695 train_time:1827ms step_avg:114.19ms +step:17/1695 train_time:1921ms step_avg:113.01ms +step:18/1695 train_time:2015ms step_avg:111.93ms +step:19/1695 train_time:2108ms step_avg:110.95ms +step:20/1695 train_time:2202ms step_avg:110.08ms +step:21/1695 train_time:2295ms step_avg:109.29ms +step:22/1695 train_time:2390ms step_avg:108.63ms 
+step:23/1695 train_time:2486ms step_avg:108.07ms +step:24/1695 train_time:2583ms step_avg:107.61ms +step:25/1695 train_time:2678ms step_avg:107.10ms +step:26/1695 train_time:2771ms step_avg:106.59ms +step:27/1695 train_time:2867ms step_avg:106.18ms +step:28/1695 train_time:2962ms step_avg:105.78ms +step:29/1695 train_time:3055ms step_avg:105.34ms +step:30/1695 train_time:3149ms step_avg:104.96ms +step:31/1695 train_time:3243ms step_avg:104.60ms +step:32/1695 train_time:3336ms step_avg:104.26ms +step:33/1695 train_time:3430ms step_avg:103.94ms +step:34/1695 train_time:3526ms step_avg:103.70ms +step:35/1695 train_time:3622ms step_avg:103.48ms +step:36/1695 train_time:3717ms step_avg:103.24ms +step:37/1695 train_time:3811ms step_avg:102.99ms +step:38/1695 train_time:3906ms step_avg:102.79ms +step:39/1695 train_time:4001ms step_avg:102.58ms +step:40/1695 train_time:4094ms step_avg:102.36ms +step:41/1695 train_time:4188ms step_avg:102.14ms +step:42/1695 train_time:4282ms step_avg:101.96ms +step:43/1695 train_time:4376ms step_avg:101.76ms +step:44/1695 train_time:4470ms step_avg:101.58ms +step:45/1695 train_time:4565ms step_avg:101.45ms +step:46/1695 train_time:4661ms step_avg:101.32ms +step:47/1695 train_time:4754ms step_avg:101.15ms +step:48/1695 train_time:4849ms step_avg:101.01ms +step:49/1695 train_time:4944ms step_avg:100.89ms +step:50/1695 train_time:5039ms step_avg:100.77ms +step:51/1695 train_time:5132ms step_avg:100.63ms +step:52/1695 train_time:5227ms step_avg:100.51ms +step:53/1695 train_time:5322ms step_avg:100.41ms +step:54/1695 train_time:5416ms step_avg:100.29ms +step:55/1695 train_time:5510ms step_avg:100.18ms +step:56/1695 train_time:5606ms step_avg:100.10ms +step:57/1695 train_time:5701ms step_avg:100.02ms +step:58/1695 train_time:5795ms step_avg:99.91ms +step:59/1695 train_time:5889ms step_avg:99.81ms +step:60/1695 train_time:5984ms step_avg:99.73ms +step:61/1695 train_time:6077ms step_avg:99.63ms +step:62/1695 train_time:6171ms step_avg:99.54ms +step:63/1695 train_time:6267ms step_avg:99.47ms +step:64/1695 train_time:6362ms step_avg:99.40ms +step:65/1695 train_time:6457ms step_avg:99.33ms +step:66/1695 train_time:6551ms step_avg:99.26ms +step:67/1695 train_time:6647ms step_avg:99.21ms +step:68/1695 train_time:6743ms step_avg:99.15ms +step:69/1695 train_time:6837ms step_avg:99.08ms +step:70/1695 train_time:6930ms step_avg:99.00ms +step:71/1695 train_time:7025ms step_avg:98.95ms +step:72/1695 train_time:7119ms step_avg:98.88ms +step:73/1695 train_time:7214ms step_avg:98.82ms +step:74/1695 train_time:7308ms step_avg:98.75ms +step:75/1695 train_time:7403ms step_avg:98.70ms +step:76/1695 train_time:7497ms step_avg:98.64ms +step:77/1695 train_time:7591ms step_avg:98.59ms +step:78/1695 train_time:7685ms step_avg:98.53ms +step:79/1695 train_time:7781ms step_avg:98.50ms +step:80/1695 train_time:7876ms step_avg:98.45ms +step:81/1695 train_time:7969ms step_avg:98.39ms +step:82/1695 train_time:8065ms step_avg:98.35ms +step:83/1695 train_time:8160ms step_avg:98.31ms +step:84/1695 train_time:8253ms step_avg:98.25ms +step:85/1695 train_time:8347ms step_avg:98.21ms +step:86/1695 train_time:8442ms step_avg:98.16ms +step:87/1695 train_time:8536ms step_avg:98.11ms +step:88/1695 train_time:8630ms step_avg:98.07ms +step:89/1695 train_time:8725ms step_avg:98.03ms +step:90/1695 train_time:8819ms step_avg:97.99ms +step:91/1695 train_time:8913ms step_avg:97.95ms +step:92/1695 train_time:9008ms step_avg:97.91ms +step:93/1695 train_time:9102ms step_avg:97.87ms +step:94/1695 train_time:9196ms 
step_avg:97.83ms +step:95/1695 train_time:9290ms step_avg:97.79ms +step:96/1695 train_time:9385ms step_avg:97.76ms +step:97/1695 train_time:9480ms step_avg:97.73ms +step:98/1695 train_time:9574ms step_avg:97.69ms +step:99/1695 train_time:9669ms step_avg:97.66ms +step:100/1695 train_time:9764ms step_avg:97.64ms +step:101/1695 train_time:9858ms step_avg:97.61ms +step:102/1695 train_time:9952ms step_avg:97.57ms +step:103/1695 train_time:10047ms step_avg:97.54ms +step:104/1695 train_time:10141ms step_avg:97.51ms +step:105/1695 train_time:10235ms step_avg:97.47ms +step:106/1695 train_time:10329ms step_avg:97.44ms +step:107/1695 train_time:10424ms step_avg:97.42ms +step:108/1695 train_time:10519ms step_avg:97.40ms +step:109/1695 train_time:10613ms step_avg:97.37ms +step:110/1695 train_time:10708ms step_avg:97.34ms +step:111/1695 train_time:10802ms step_avg:97.32ms +step:112/1695 train_time:10896ms step_avg:97.29ms +step:113/1695 train_time:10990ms step_avg:97.25ms +step:114/1695 train_time:11084ms step_avg:97.22ms +step:115/1695 train_time:11178ms step_avg:97.20ms +step:116/1695 train_time:11271ms step_avg:97.17ms +step:117/1695 train_time:11366ms step_avg:97.14ms +step:118/1695 train_time:11460ms step_avg:97.12ms +step:119/1695 train_time:11555ms step_avg:97.10ms +step:120/1695 train_time:11649ms step_avg:97.07ms +step:121/1695 train_time:11744ms step_avg:97.05ms +step:122/1695 train_time:11839ms step_avg:97.04ms +step:123/1695 train_time:11933ms step_avg:97.01ms +step:124/1695 train_time:12028ms step_avg:97.00ms +step:125/1695 train_time:12122ms step_avg:96.98ms +step:125/1695 val_loss:4.3129 train_time:12214ms step_avg:97.71ms +step:126/1695 train_time:12238ms step_avg:97.13ms +step:127/1695 train_time:12320ms step_avg:97.01ms +step:128/1695 train_time:12421ms step_avg:97.04ms +step:129/1695 train_time:12516ms step_avg:97.02ms +step:130/1695 train_time:12609ms step_avg:97.00ms +step:131/1695 train_time:12702ms step_avg:96.96ms +step:132/1695 train_time:12795ms step_avg:96.94ms +step:133/1695 train_time:12889ms step_avg:96.91ms +step:134/1695 train_time:12982ms step_avg:96.88ms +step:135/1695 train_time:13075ms step_avg:96.85ms +step:136/1695 train_time:13168ms step_avg:96.83ms +step:137/1695 train_time:13264ms step_avg:96.82ms +step:138/1695 train_time:13359ms step_avg:96.81ms +step:139/1695 train_time:13455ms step_avg:96.80ms +step:140/1695 train_time:13550ms step_avg:96.79ms +step:141/1695 train_time:13645ms step_avg:96.77ms +step:142/1695 train_time:13739ms step_avg:96.75ms +step:143/1695 train_time:13832ms step_avg:96.73ms +step:144/1695 train_time:13926ms step_avg:96.71ms +step:145/1695 train_time:14019ms step_avg:96.68ms +step:146/1695 train_time:14112ms step_avg:96.66ms +step:147/1695 train_time:14206ms step_avg:96.64ms +step:148/1695 train_time:14302ms step_avg:96.63ms +step:149/1695 train_time:14396ms step_avg:96.62ms +step:150/1695 train_time:14492ms step_avg:96.61ms +step:151/1695 train_time:14586ms step_avg:96.60ms +step:152/1695 train_time:14681ms step_avg:96.59ms +step:153/1695 train_time:14775ms step_avg:96.57ms +step:154/1695 train_time:14869ms step_avg:96.55ms +step:155/1695 train_time:14963ms step_avg:96.53ms +step:156/1695 train_time:15056ms step_avg:96.51ms +step:157/1695 train_time:15149ms step_avg:96.49ms +step:158/1695 train_time:15245ms step_avg:96.49ms +step:159/1695 train_time:15340ms step_avg:96.48ms +step:160/1695 train_time:15434ms step_avg:96.46ms +step:161/1695 train_time:15529ms step_avg:96.45ms +step:162/1695 train_time:15624ms step_avg:96.45ms +step:163/1695 
train_time:15719ms step_avg:96.44ms +step:164/1695 train_time:15813ms step_avg:96.42ms +step:165/1695 train_time:15907ms step_avg:96.40ms +step:166/1695 train_time:16001ms step_avg:96.39ms +step:167/1695 train_time:16095ms step_avg:96.38ms +step:168/1695 train_time:16189ms step_avg:96.36ms +step:169/1695 train_time:16283ms step_avg:96.35ms +step:170/1695 train_time:16378ms step_avg:96.34ms +step:171/1695 train_time:16471ms step_avg:96.32ms +step:172/1695 train_time:16567ms step_avg:96.32ms +step:173/1695 train_time:16951ms step_avg:97.98ms +step:174/1695 train_time:17020ms step_avg:97.82ms +step:175/1695 train_time:17112ms step_avg:97.78ms +step:176/1695 train_time:17206ms step_avg:97.76ms +step:177/1695 train_time:17299ms step_avg:97.73ms +step:178/1695 train_time:17391ms step_avg:97.70ms +step:179/1695 train_time:17484ms step_avg:97.68ms +step:180/1695 train_time:17578ms step_avg:97.65ms +step:181/1695 train_time:17671ms step_avg:97.63ms +step:182/1695 train_time:17764ms step_avg:97.61ms +step:183/1695 train_time:17859ms step_avg:97.59ms +step:184/1695 train_time:17956ms step_avg:97.59ms +step:185/1695 train_time:18052ms step_avg:97.58ms +step:186/1695 train_time:18148ms step_avg:97.57ms +step:187/1695 train_time:18243ms step_avg:97.56ms +step:188/1695 train_time:18336ms step_avg:97.53ms +step:189/1695 train_time:18429ms step_avg:97.51ms +step:190/1695 train_time:18522ms step_avg:97.49ms +step:191/1695 train_time:18615ms step_avg:97.46ms +step:192/1695 train_time:18709ms step_avg:97.44ms +step:193/1695 train_time:18803ms step_avg:97.42ms +step:194/1695 train_time:18898ms step_avg:97.41ms +step:195/1695 train_time:18993ms step_avg:97.40ms +step:196/1695 train_time:19088ms step_avg:97.39ms +step:197/1695 train_time:19184ms step_avg:97.38ms +step:198/1695 train_time:19278ms step_avg:97.37ms +step:199/1695 train_time:19372ms step_avg:97.35ms +step:200/1695 train_time:19466ms step_avg:97.33ms +step:201/1695 train_time:19560ms step_avg:97.31ms +step:202/1695 train_time:19652ms step_avg:97.29ms +step:203/1695 train_time:19747ms step_avg:97.28ms +step:204/1695 train_time:19841ms step_avg:97.26ms +step:205/1695 train_time:19935ms step_avg:97.24ms +step:206/1695 train_time:20030ms step_avg:97.23ms +step:207/1695 train_time:20124ms step_avg:97.22ms +step:208/1695 train_time:20219ms step_avg:97.21ms +step:209/1695 train_time:20312ms step_avg:97.19ms +step:210/1695 train_time:20406ms step_avg:97.17ms +step:211/1695 train_time:20500ms step_avg:97.16ms +step:212/1695 train_time:20593ms step_avg:97.14ms +step:213/1695 train_time:20686ms step_avg:97.12ms +step:214/1695 train_time:20781ms step_avg:97.11ms +step:215/1695 train_time:20874ms step_avg:97.09ms +step:216/1695 train_time:20969ms step_avg:97.08ms +step:217/1695 train_time:21063ms step_avg:97.07ms +step:218/1695 train_time:21158ms step_avg:97.05ms +step:219/1695 train_time:21252ms step_avg:97.04ms +step:220/1695 train_time:21348ms step_avg:97.04ms +step:221/1695 train_time:21442ms step_avg:97.02ms +step:222/1695 train_time:21535ms step_avg:97.01ms +step:223/1695 train_time:21629ms step_avg:96.99ms +step:224/1695 train_time:21723ms step_avg:96.98ms +step:225/1695 train_time:21817ms step_avg:96.97ms +step:226/1695 train_time:21911ms step_avg:96.95ms +step:227/1695 train_time:22005ms step_avg:96.94ms +step:228/1695 train_time:22099ms step_avg:96.93ms +step:229/1695 train_time:22193ms step_avg:96.91ms +step:230/1695 train_time:22287ms step_avg:96.90ms +step:231/1695 train_time:22382ms step_avg:96.89ms +step:232/1695 train_time:22476ms step_avg:96.88ms 
+step:233/1695 train_time:22570ms step_avg:96.87ms +step:234/1695 train_time:22663ms step_avg:96.85ms +step:235/1695 train_time:22756ms step_avg:96.83ms +step:236/1695 train_time:22850ms step_avg:96.82ms +step:237/1695 train_time:22945ms step_avg:96.81ms +step:238/1695 train_time:23040ms step_avg:96.81ms +step:239/1695 train_time:23134ms step_avg:96.79ms +step:240/1695 train_time:23228ms step_avg:96.78ms +step:241/1695 train_time:23323ms step_avg:96.78ms +step:242/1695 train_time:23418ms step_avg:96.77ms +step:243/1695 train_time:23512ms step_avg:96.76ms +step:244/1695 train_time:23606ms step_avg:96.75ms +step:245/1695 train_time:23699ms step_avg:96.73ms +step:246/1695 train_time:23792ms step_avg:96.72ms +step:247/1695 train_time:23886ms step_avg:96.71ms +step:248/1695 train_time:23981ms step_avg:96.70ms +step:249/1695 train_time:24075ms step_avg:96.69ms +step:250/1695 train_time:24169ms step_avg:96.68ms +step:250/1695 val_loss:3.9787 train_time:24262ms step_avg:97.05ms +step:251/1695 train_time:24286ms step_avg:96.76ms +step:252/1695 train_time:24364ms step_avg:96.68ms +step:253/1695 train_time:24461ms step_avg:96.68ms +step:254/1695 train_time:24555ms step_avg:96.67ms +step:255/1695 train_time:24648ms step_avg:96.66ms +step:256/1695 train_time:24742ms step_avg:96.65ms +step:257/1695 train_time:24835ms step_avg:96.63ms +step:258/1695 train_time:24928ms step_avg:96.62ms +step:259/1695 train_time:25021ms step_avg:96.61ms +step:260/1695 train_time:25114ms step_avg:96.59ms +step:261/1695 train_time:25208ms step_avg:96.58ms +step:262/1695 train_time:25304ms step_avg:96.58ms +step:263/1695 train_time:25401ms step_avg:96.58ms +step:264/1695 train_time:25495ms step_avg:96.57ms +step:265/1695 train_time:25590ms step_avg:96.57ms +step:266/1695 train_time:25684ms step_avg:96.56ms +step:267/1695 train_time:25777ms step_avg:96.54ms +step:268/1695 train_time:25870ms step_avg:96.53ms +step:269/1695 train_time:25964ms step_avg:96.52ms +step:270/1695 train_time:26058ms step_avg:96.51ms +step:271/1695 train_time:26151ms step_avg:96.50ms +step:272/1695 train_time:26245ms step_avg:96.49ms +step:273/1695 train_time:26341ms step_avg:96.49ms +step:274/1695 train_time:26436ms step_avg:96.48ms +step:275/1695 train_time:26531ms step_avg:96.48ms +step:276/1695 train_time:26626ms step_avg:96.47ms +step:277/1695 train_time:26720ms step_avg:96.46ms +step:278/1695 train_time:26813ms step_avg:96.45ms +step:279/1695 train_time:26906ms step_avg:96.44ms +step:280/1695 train_time:26999ms step_avg:96.43ms +step:281/1695 train_time:27092ms step_avg:96.41ms +step:282/1695 train_time:27187ms step_avg:96.41ms +step:283/1695 train_time:27282ms step_avg:96.40ms +step:284/1695 train_time:27376ms step_avg:96.40ms +step:285/1695 train_time:27471ms step_avg:96.39ms +step:286/1695 train_time:27566ms step_avg:96.39ms +step:287/1695 train_time:27661ms step_avg:96.38ms +step:288/1695 train_time:27754ms step_avg:96.37ms +step:289/1695 train_time:27849ms step_avg:96.36ms +step:290/1695 train_time:27943ms step_avg:96.35ms +step:291/1695 train_time:28037ms step_avg:96.35ms +step:292/1695 train_time:28130ms step_avg:96.33ms +step:293/1695 train_time:28224ms step_avg:96.33ms +step:294/1695 train_time:28318ms step_avg:96.32ms +step:295/1695 train_time:28412ms step_avg:96.31ms +step:296/1695 train_time:28507ms step_avg:96.31ms +step:297/1695 train_time:28603ms step_avg:96.30ms +step:298/1695 train_time:28697ms step_avg:96.30ms +step:299/1695 train_time:28791ms step_avg:96.29ms +step:300/1695 train_time:28885ms step_avg:96.28ms +step:301/1695 
train_time:28979ms step_avg:96.27ms +step:302/1695 train_time:29072ms step_avg:96.27ms +step:303/1695 train_time:29166ms step_avg:96.26ms +step:304/1695 train_time:29261ms step_avg:96.25ms +step:305/1695 train_time:29355ms step_avg:96.25ms +step:306/1695 train_time:29449ms step_avg:96.24ms +step:307/1695 train_time:29544ms step_avg:96.23ms +step:308/1695 train_time:29638ms step_avg:96.23ms +step:309/1695 train_time:29732ms step_avg:96.22ms +step:310/1695 train_time:29827ms step_avg:96.21ms +step:311/1695 train_time:29921ms step_avg:96.21ms +step:312/1695 train_time:30014ms step_avg:96.20ms +step:313/1695 train_time:30108ms step_avg:96.19ms +step:314/1695 train_time:30203ms step_avg:96.19ms +step:315/1695 train_time:30297ms step_avg:96.18ms +step:316/1695 train_time:30391ms step_avg:96.17ms +step:317/1695 train_time:30486ms step_avg:96.17ms +step:318/1695 train_time:30580ms step_avg:96.16ms +step:319/1695 train_time:30674ms step_avg:96.16ms +step:320/1695 train_time:30769ms step_avg:96.15ms +step:321/1695 train_time:30864ms step_avg:96.15ms +step:322/1695 train_time:30957ms step_avg:96.14ms +step:323/1695 train_time:31050ms step_avg:96.13ms +step:324/1695 train_time:31145ms step_avg:96.13ms +step:325/1695 train_time:31239ms step_avg:96.12ms +step:326/1695 train_time:31333ms step_avg:96.11ms +step:327/1695 train_time:31428ms step_avg:96.11ms +step:328/1695 train_time:31523ms step_avg:96.11ms +step:329/1695 train_time:31616ms step_avg:96.10ms +step:330/1695 train_time:31710ms step_avg:96.09ms +step:331/1695 train_time:31805ms step_avg:96.09ms +step:332/1695 train_time:31899ms step_avg:96.08ms +step:333/1695 train_time:31993ms step_avg:96.07ms +step:334/1695 train_time:32087ms step_avg:96.07ms +step:335/1695 train_time:32180ms step_avg:96.06ms +step:336/1695 train_time:32274ms step_avg:96.05ms +step:337/1695 train_time:32368ms step_avg:96.05ms +step:338/1695 train_time:32468ms step_avg:96.06ms +step:339/1695 train_time:32561ms step_avg:96.05ms +step:340/1695 train_time:32655ms step_avg:96.04ms +step:341/1695 train_time:32748ms step_avg:96.04ms +step:342/1695 train_time:32839ms step_avg:96.02ms +step:343/1695 train_time:32933ms step_avg:96.01ms +step:344/1695 train_time:33027ms step_avg:96.01ms +step:345/1695 train_time:33355ms step_avg:96.68ms +step:346/1695 train_time:33456ms step_avg:96.69ms +step:347/1695 train_time:33548ms step_avg:96.68ms +step:348/1695 train_time:33642ms step_avg:96.67ms +step:349/1695 train_time:33735ms step_avg:96.66ms +step:350/1695 train_time:33828ms step_avg:96.65ms +step:351/1695 train_time:33921ms step_avg:96.64ms +step:352/1695 train_time:34014ms step_avg:96.63ms +step:353/1695 train_time:34106ms step_avg:96.62ms +step:354/1695 train_time:34200ms step_avg:96.61ms +step:355/1695 train_time:34296ms step_avg:96.61ms +step:356/1695 train_time:34393ms step_avg:96.61ms +step:357/1695 train_time:34490ms step_avg:96.61ms +step:358/1695 train_time:34584ms step_avg:96.60ms +step:359/1695 train_time:34678ms step_avg:96.60ms +step:360/1695 train_time:34771ms step_avg:96.59ms +step:361/1695 train_time:34865ms step_avg:96.58ms +step:362/1695 train_time:34959ms step_avg:96.57ms +step:363/1695 train_time:35051ms step_avg:96.56ms +step:364/1695 train_time:35145ms step_avg:96.55ms +step:365/1695 train_time:35238ms step_avg:96.54ms +step:366/1695 train_time:35332ms step_avg:96.54ms +step:367/1695 train_time:35428ms step_avg:96.53ms +step:368/1695 train_time:35523ms step_avg:96.53ms +step:369/1695 train_time:35616ms step_avg:96.52ms +step:370/1695 train_time:35710ms step_avg:96.51ms 
+step:371/1695 train_time:35803ms step_avg:96.51ms +step:372/1695 train_time:35897ms step_avg:96.50ms +step:373/1695 train_time:35990ms step_avg:96.49ms +step:374/1695 train_time:36083ms step_avg:96.48ms +step:375/1695 train_time:36176ms step_avg:96.47ms +step:375/1695 val_loss:3.8232 train_time:36268ms step_avg:96.71ms +step:376/1695 train_time:36292ms step_avg:96.52ms +step:377/1695 train_time:36372ms step_avg:96.48ms +step:378/1695 train_time:36470ms step_avg:96.48ms +step:379/1695 train_time:36565ms step_avg:96.48ms +step:380/1695 train_time:36659ms step_avg:96.47ms +step:381/1695 train_time:36752ms step_avg:96.46ms +step:382/1695 train_time:36845ms step_avg:96.45ms +step:383/1695 train_time:36938ms step_avg:96.44ms +step:384/1695 train_time:37031ms step_avg:96.43ms +step:385/1695 train_time:37124ms step_avg:96.43ms +step:386/1695 train_time:37218ms step_avg:96.42ms +step:387/1695 train_time:37313ms step_avg:96.42ms +step:388/1695 train_time:37410ms step_avg:96.42ms +step:389/1695 train_time:37505ms step_avg:96.41ms +step:390/1695 train_time:37599ms step_avg:96.41ms +step:391/1695 train_time:37693ms step_avg:96.40ms +step:392/1695 train_time:37787ms step_avg:96.40ms +step:393/1695 train_time:37880ms step_avg:96.39ms +step:394/1695 train_time:37973ms step_avg:96.38ms +step:395/1695 train_time:38067ms step_avg:96.37ms +step:396/1695 train_time:38160ms step_avg:96.36ms +step:397/1695 train_time:38254ms step_avg:96.36ms +step:398/1695 train_time:38350ms step_avg:96.36ms +step:399/1695 train_time:38445ms step_avg:96.35ms +step:400/1695 train_time:38540ms step_avg:96.35ms +step:401/1695 train_time:38633ms step_avg:96.34ms +step:402/1695 train_time:38727ms step_avg:96.34ms +step:403/1695 train_time:38822ms step_avg:96.33ms +step:404/1695 train_time:38915ms step_avg:96.32ms +step:405/1695 train_time:39008ms step_avg:96.32ms +step:406/1695 train_time:39101ms step_avg:96.31ms +step:407/1695 train_time:39194ms step_avg:96.30ms +step:408/1695 train_time:39289ms step_avg:96.30ms +step:409/1695 train_time:39384ms step_avg:96.29ms +step:410/1695 train_time:39478ms step_avg:96.29ms +step:411/1695 train_time:39572ms step_avg:96.28ms +step:412/1695 train_time:39667ms step_avg:96.28ms +step:413/1695 train_time:39762ms step_avg:96.28ms +step:414/1695 train_time:39856ms step_avg:96.27ms +step:415/1695 train_time:39949ms step_avg:96.26ms +step:416/1695 train_time:40043ms step_avg:96.26ms +step:417/1695 train_time:40137ms step_avg:96.25ms +step:418/1695 train_time:40230ms step_avg:96.24ms +step:419/1695 train_time:40324ms step_avg:96.24ms +step:420/1695 train_time:40419ms step_avg:96.24ms +step:421/1695 train_time:40513ms step_avg:96.23ms +step:422/1695 train_time:40607ms step_avg:96.22ms +step:423/1695 train_time:40702ms step_avg:96.22ms +step:424/1695 train_time:40795ms step_avg:96.21ms +step:425/1695 train_time:40889ms step_avg:96.21ms +step:426/1695 train_time:40984ms step_avg:96.21ms +step:427/1695 train_time:41077ms step_avg:96.20ms +step:428/1695 train_time:41170ms step_avg:96.19ms +step:429/1695 train_time:41264ms step_avg:96.19ms +step:430/1695 train_time:41359ms step_avg:96.18ms +step:431/1695 train_time:41453ms step_avg:96.18ms +step:432/1695 train_time:41547ms step_avg:96.17ms +step:433/1695 train_time:41641ms step_avg:96.17ms +step:434/1695 train_time:41734ms step_avg:96.16ms +step:435/1695 train_time:41828ms step_avg:96.16ms +step:436/1695 train_time:41922ms step_avg:96.15ms +step:437/1695 train_time:42016ms step_avg:96.15ms +step:438/1695 train_time:42110ms step_avg:96.14ms +step:439/1695 
train_time:42203ms step_avg:96.13ms +step:440/1695 train_time:42297ms step_avg:96.13ms +step:441/1695 train_time:42392ms step_avg:96.13ms +step:442/1695 train_time:42487ms step_avg:96.12ms +step:443/1695 train_time:42583ms step_avg:96.12ms +step:444/1695 train_time:42676ms step_avg:96.12ms +step:445/1695 train_time:42770ms step_avg:96.11ms +step:446/1695 train_time:42864ms step_avg:96.11ms +step:447/1695 train_time:42957ms step_avg:96.10ms +step:448/1695 train_time:43050ms step_avg:96.09ms +step:449/1695 train_time:43144ms step_avg:96.09ms +step:450/1695 train_time:43237ms step_avg:96.08ms +step:451/1695 train_time:43331ms step_avg:96.08ms +step:452/1695 train_time:43426ms step_avg:96.07ms +step:453/1695 train_time:43519ms step_avg:96.07ms +step:454/1695 train_time:43613ms step_avg:96.06ms +step:455/1695 train_time:43707ms step_avg:96.06ms +step:456/1695 train_time:43801ms step_avg:96.06ms +step:457/1695 train_time:43895ms step_avg:96.05ms +step:458/1695 train_time:43990ms step_avg:96.05ms +step:459/1695 train_time:44084ms step_avg:96.04ms +step:460/1695 train_time:44178ms step_avg:96.04ms +step:461/1695 train_time:44271ms step_avg:96.03ms +step:462/1695 train_time:44365ms step_avg:96.03ms +step:463/1695 train_time:44459ms step_avg:96.02ms +step:464/1695 train_time:44553ms step_avg:96.02ms +step:465/1695 train_time:44647ms step_avg:96.02ms +step:466/1695 train_time:44742ms step_avg:96.01ms +step:467/1695 train_time:44836ms step_avg:96.01ms +step:468/1695 train_time:44930ms step_avg:96.00ms +step:469/1695 train_time:45025ms step_avg:96.00ms +step:470/1695 train_time:45119ms step_avg:96.00ms +step:471/1695 train_time:45212ms step_avg:95.99ms +step:472/1695 train_time:45306ms step_avg:95.99ms +step:473/1695 train_time:45401ms step_avg:95.98ms +step:474/1695 train_time:45494ms step_avg:95.98ms +step:475/1695 train_time:45589ms step_avg:95.98ms +step:476/1695 train_time:45683ms step_avg:95.97ms +step:477/1695 train_time:45777ms step_avg:95.97ms +step:478/1695 train_time:45870ms step_avg:95.96ms +step:479/1695 train_time:45964ms step_avg:95.96ms +step:480/1695 train_time:46059ms step_avg:95.96ms +step:481/1695 train_time:46153ms step_avg:95.95ms +step:482/1695 train_time:46247ms step_avg:95.95ms +step:483/1695 train_time:46341ms step_avg:95.94ms +step:484/1695 train_time:46435ms step_avg:95.94ms +step:485/1695 train_time:46529ms step_avg:95.94ms +step:486/1695 train_time:46623ms step_avg:95.93ms +step:487/1695 train_time:46718ms step_avg:95.93ms +step:488/1695 train_time:46811ms step_avg:95.92ms +step:489/1695 train_time:46905ms step_avg:95.92ms +step:490/1695 train_time:46998ms step_avg:95.91ms +step:491/1695 train_time:47092ms step_avg:95.91ms +step:492/1695 train_time:47186ms step_avg:95.91ms +step:493/1695 train_time:47280ms step_avg:95.90ms +step:494/1695 train_time:47374ms step_avg:95.90ms +step:495/1695 train_time:47468ms step_avg:95.89ms +step:496/1695 train_time:47562ms step_avg:95.89ms +step:497/1695 train_time:47657ms step_avg:95.89ms +step:498/1695 train_time:47751ms step_avg:95.88ms +step:499/1695 train_time:47845ms step_avg:95.88ms +step:500/1695 train_time:47938ms step_avg:95.88ms +step:500/1695 val_loss:3.7202 train_time:48030ms step_avg:96.06ms +step:501/1695 train_time:48054ms step_avg:95.92ms +step:502/1695 train_time:48133ms step_avg:95.88ms +step:503/1695 train_time:48232ms step_avg:95.89ms +step:504/1695 train_time:48327ms step_avg:95.89ms +step:505/1695 train_time:48419ms step_avg:95.88ms +step:506/1695 train_time:48513ms step_avg:95.88ms +step:507/1695 train_time:48607ms 
step_avg:95.87ms +step:508/1695 train_time:48699ms step_avg:95.86ms +step:509/1695 train_time:48792ms step_avg:95.86ms +step:510/1695 train_time:48885ms step_avg:95.85ms +step:511/1695 train_time:48979ms step_avg:95.85ms +step:512/1695 train_time:49076ms step_avg:95.85ms +step:513/1695 train_time:49173ms step_avg:95.85ms +step:514/1695 train_time:49268ms step_avg:95.85ms +step:515/1695 train_time:49363ms step_avg:95.85ms +step:516/1695 train_time:49456ms step_avg:95.84ms +step:517/1695 train_time:49549ms step_avg:95.84ms +step:518/1695 train_time:49643ms step_avg:95.84ms +step:519/1695 train_time:49968ms step_avg:96.28ms +step:520/1695 train_time:50168ms step_avg:96.48ms +step:521/1695 train_time:50261ms step_avg:96.47ms +step:522/1695 train_time:50353ms step_avg:96.46ms +step:523/1695 train_time:50446ms step_avg:96.46ms +step:524/1695 train_time:50539ms step_avg:96.45ms +step:525/1695 train_time:50632ms step_avg:96.44ms +step:526/1695 train_time:50725ms step_avg:96.43ms +step:527/1695 train_time:50817ms step_avg:96.43ms +step:528/1695 train_time:50910ms step_avg:96.42ms +step:529/1695 train_time:51008ms step_avg:96.42ms +step:530/1695 train_time:51106ms step_avg:96.43ms +step:531/1695 train_time:51202ms step_avg:96.43ms +step:532/1695 train_time:51296ms step_avg:96.42ms +step:533/1695 train_time:51389ms step_avg:96.41ms +step:534/1695 train_time:51482ms step_avg:96.41ms +step:535/1695 train_time:51575ms step_avg:96.40ms +step:536/1695 train_time:51668ms step_avg:96.40ms +step:537/1695 train_time:51761ms step_avg:96.39ms +step:538/1695 train_time:51854ms step_avg:96.38ms +step:539/1695 train_time:51949ms step_avg:96.38ms +step:540/1695 train_time:52044ms step_avg:96.38ms +step:541/1695 train_time:52139ms step_avg:96.38ms +step:542/1695 train_time:52234ms step_avg:96.37ms +step:543/1695 train_time:52328ms step_avg:96.37ms +step:544/1695 train_time:52421ms step_avg:96.36ms +step:545/1695 train_time:52515ms step_avg:96.36ms +step:546/1695 train_time:52609ms step_avg:96.35ms +step:547/1695 train_time:52702ms step_avg:96.35ms +step:548/1695 train_time:52795ms step_avg:96.34ms +step:549/1695 train_time:52889ms step_avg:96.34ms +step:550/1695 train_time:52983ms step_avg:96.33ms +step:551/1695 train_time:53077ms step_avg:96.33ms +step:552/1695 train_time:53172ms step_avg:96.33ms +step:553/1695 train_time:53267ms step_avg:96.32ms +step:554/1695 train_time:53361ms step_avg:96.32ms +step:555/1695 train_time:53454ms step_avg:96.31ms +step:556/1695 train_time:53548ms step_avg:96.31ms +step:557/1695 train_time:53640ms step_avg:96.30ms +step:558/1695 train_time:53734ms step_avg:96.30ms +step:559/1695 train_time:53827ms step_avg:96.29ms +step:560/1695 train_time:53920ms step_avg:96.29ms +step:561/1695 train_time:54015ms step_avg:96.28ms +step:562/1695 train_time:54111ms step_avg:96.28ms +step:563/1695 train_time:54207ms step_avg:96.28ms +step:564/1695 train_time:54301ms step_avg:96.28ms +step:565/1695 train_time:54395ms step_avg:96.27ms +step:566/1695 train_time:54488ms step_avg:96.27ms +step:567/1695 train_time:54583ms step_avg:96.27ms +step:568/1695 train_time:54678ms step_avg:96.26ms +step:569/1695 train_time:54774ms step_avg:96.26ms +step:570/1695 train_time:54869ms step_avg:96.26ms +step:571/1695 train_time:54965ms step_avg:96.26ms +step:572/1695 train_time:55061ms step_avg:96.26ms +step:573/1695 train_time:55157ms step_avg:96.26ms +step:574/1695 train_time:55253ms step_avg:96.26ms +step:575/1695 train_time:55349ms step_avg:96.26ms +step:576/1695 train_time:55445ms step_avg:96.26ms +step:577/1695 
train_time:55540ms step_avg:96.26ms +step:578/1695 train_time:55636ms step_avg:96.26ms +step:579/1695 train_time:55732ms step_avg:96.26ms +step:580/1695 train_time:55828ms step_avg:96.25ms +step:581/1695 train_time:55923ms step_avg:96.25ms +step:582/1695 train_time:56018ms step_avg:96.25ms +step:583/1695 train_time:56116ms step_avg:96.25ms +step:584/1695 train_time:56213ms step_avg:96.25ms +step:585/1695 train_time:56310ms step_avg:96.26ms +step:586/1695 train_time:56407ms step_avg:96.26ms +step:587/1695 train_time:56502ms step_avg:96.26ms +step:588/1695 train_time:56597ms step_avg:96.25ms +step:589/1695 train_time:56693ms step_avg:96.25ms +step:590/1695 train_time:56789ms step_avg:96.25ms +step:591/1695 train_time:56885ms step_avg:96.25ms +step:592/1695 train_time:56981ms step_avg:96.25ms +step:593/1695 train_time:57077ms step_avg:96.25ms +step:594/1695 train_time:57173ms step_avg:96.25ms +step:595/1695 train_time:57270ms step_avg:96.25ms +step:596/1695 train_time:57365ms step_avg:96.25ms +step:597/1695 train_time:57461ms step_avg:96.25ms +step:598/1695 train_time:57556ms step_avg:96.25ms +step:599/1695 train_time:57653ms step_avg:96.25ms +step:600/1695 train_time:57749ms step_avg:96.25ms +step:601/1695 train_time:57845ms step_avg:96.25ms +step:602/1695 train_time:57941ms step_avg:96.25ms +step:603/1695 train_time:58037ms step_avg:96.25ms +step:604/1695 train_time:58134ms step_avg:96.25ms +step:605/1695 train_time:58230ms step_avg:96.25ms +step:606/1695 train_time:58326ms step_avg:96.25ms +step:607/1695 train_time:58421ms step_avg:96.25ms +step:608/1695 train_time:58518ms step_avg:96.25ms +step:609/1695 train_time:58614ms step_avg:96.25ms +step:610/1695 train_time:58711ms step_avg:96.25ms +step:611/1695 train_time:58808ms step_avg:96.25ms +step:612/1695 train_time:58904ms step_avg:96.25ms +step:613/1695 train_time:59000ms step_avg:96.25ms +step:614/1695 train_time:59095ms step_avg:96.25ms +step:615/1695 train_time:59192ms step_avg:96.25ms +step:616/1695 train_time:59289ms step_avg:96.25ms +step:617/1695 train_time:59385ms step_avg:96.25ms +step:618/1695 train_time:59481ms step_avg:96.25ms +step:619/1695 train_time:59576ms step_avg:96.25ms +step:620/1695 train_time:59673ms step_avg:96.25ms +step:621/1695 train_time:59769ms step_avg:96.25ms +step:622/1695 train_time:59865ms step_avg:96.25ms +step:623/1695 train_time:59960ms step_avg:96.24ms +step:624/1695 train_time:60056ms step_avg:96.24ms +step:625/1695 train_time:60152ms step_avg:96.24ms +step:625/1695 val_loss:3.6216 train_time:60246ms step_avg:96.39ms +step:626/1695 train_time:60272ms step_avg:96.28ms +step:627/1695 train_time:60351ms step_avg:96.25ms +step:628/1695 train_time:60448ms step_avg:96.25ms +step:629/1695 train_time:60544ms step_avg:96.25ms +step:630/1695 train_time:60638ms step_avg:96.25ms +step:631/1695 train_time:60733ms step_avg:96.25ms +step:632/1695 train_time:60828ms step_avg:96.25ms +step:633/1695 train_time:60923ms step_avg:96.25ms +step:634/1695 train_time:61018ms step_avg:96.24ms +step:635/1695 train_time:61112ms step_avg:96.24ms +step:636/1695 train_time:61210ms step_avg:96.24ms +step:637/1695 train_time:61308ms step_avg:96.25ms +step:638/1695 train_time:61406ms step_avg:96.25ms +step:639/1695 train_time:61503ms step_avg:96.25ms +step:640/1695 train_time:61598ms step_avg:96.25ms +step:641/1695 train_time:61693ms step_avg:96.24ms +step:642/1695 train_time:61788ms step_avg:96.24ms +step:643/1695 train_time:61883ms step_avg:96.24ms +step:644/1695 train_time:61977ms step_avg:96.24ms +step:645/1695 train_time:62072ms 
step_avg:96.24ms +step:646/1695 train_time:62168ms step_avg:96.23ms +step:647/1695 train_time:62264ms step_avg:96.24ms +step:648/1695 train_time:62361ms step_avg:96.24ms +step:649/1695 train_time:62458ms step_avg:96.24ms +step:650/1695 train_time:62554ms step_avg:96.24ms +step:651/1695 train_time:62651ms step_avg:96.24ms +step:652/1695 train_time:62747ms step_avg:96.24ms +step:653/1695 train_time:62843ms step_avg:96.24ms +step:654/1695 train_time:62937ms step_avg:96.23ms +step:655/1695 train_time:63034ms step_avg:96.23ms +step:656/1695 train_time:63130ms step_avg:96.24ms +step:657/1695 train_time:63228ms step_avg:96.24ms +step:658/1695 train_time:63324ms step_avg:96.24ms +step:659/1695 train_time:63420ms step_avg:96.24ms +step:660/1695 train_time:63515ms step_avg:96.24ms +step:661/1695 train_time:63612ms step_avg:96.24ms +step:662/1695 train_time:63707ms step_avg:96.23ms +step:663/1695 train_time:63803ms step_avg:96.23ms +step:664/1695 train_time:63897ms step_avg:96.23ms +step:665/1695 train_time:63993ms step_avg:96.23ms +step:666/1695 train_time:64089ms step_avg:96.23ms +step:667/1695 train_time:64186ms step_avg:96.23ms +step:668/1695 train_time:64282ms step_avg:96.23ms +step:669/1695 train_time:64377ms step_avg:96.23ms +step:670/1695 train_time:64473ms step_avg:96.23ms +step:671/1695 train_time:64570ms step_avg:96.23ms +step:672/1695 train_time:64666ms step_avg:96.23ms +step:673/1695 train_time:64762ms step_avg:96.23ms +step:674/1695 train_time:64857ms step_avg:96.23ms +step:675/1695 train_time:64954ms step_avg:96.23ms +step:676/1695 train_time:65048ms step_avg:96.23ms +step:677/1695 train_time:65144ms step_avg:96.22ms +step:678/1695 train_time:65238ms step_avg:96.22ms +step:679/1695 train_time:65335ms step_avg:96.22ms +step:680/1695 train_time:65431ms step_avg:96.22ms +step:681/1695 train_time:65528ms step_avg:96.22ms +step:682/1695 train_time:65624ms step_avg:96.22ms +step:683/1695 train_time:65720ms step_avg:96.22ms +step:684/1695 train_time:65815ms step_avg:96.22ms +step:685/1695 train_time:65911ms step_avg:96.22ms +step:686/1695 train_time:66007ms step_avg:96.22ms +step:687/1695 train_time:66103ms step_avg:96.22ms +step:688/1695 train_time:66199ms step_avg:96.22ms +step:689/1695 train_time:66294ms step_avg:96.22ms +step:690/1695 train_time:66390ms step_avg:96.22ms +step:691/1695 train_time:66847ms step_avg:96.74ms +step:692/1695 train_time:66917ms step_avg:96.70ms +step:693/1695 train_time:67011ms step_avg:96.70ms +step:694/1695 train_time:67106ms step_avg:96.69ms +step:695/1695 train_time:67201ms step_avg:96.69ms +step:696/1695 train_time:67296ms step_avg:96.69ms +step:697/1695 train_time:67392ms step_avg:96.69ms +step:698/1695 train_time:67487ms step_avg:96.69ms +step:699/1695 train_time:67581ms step_avg:96.68ms +step:700/1695 train_time:67676ms step_avg:96.68ms +step:701/1695 train_time:67776ms step_avg:96.69ms +step:702/1695 train_time:67879ms step_avg:96.69ms +step:703/1695 train_time:67976ms step_avg:96.69ms +step:704/1695 train_time:68073ms step_avg:96.69ms +step:705/1695 train_time:68169ms step_avg:96.69ms +step:706/1695 train_time:68266ms step_avg:96.69ms +step:707/1695 train_time:68361ms step_avg:96.69ms +step:708/1695 train_time:68455ms step_avg:96.69ms +step:709/1695 train_time:68550ms step_avg:96.69ms +step:710/1695 train_time:68645ms step_avg:96.68ms +step:711/1695 train_time:68741ms step_avg:96.68ms +step:712/1695 train_time:68838ms step_avg:96.68ms +step:713/1695 train_time:68936ms step_avg:96.68ms +step:714/1695 train_time:69034ms step_avg:96.69ms +step:715/1695 
train_time:69130ms step_avg:96.69ms +step:716/1695 train_time:69226ms step_avg:96.68ms +step:717/1695 train_time:69321ms step_avg:96.68ms +step:718/1695 train_time:69416ms step_avg:96.68ms +step:719/1695 train_time:69512ms step_avg:96.68ms +step:720/1695 train_time:69607ms step_avg:96.68ms +step:721/1695 train_time:69702ms step_avg:96.67ms +step:722/1695 train_time:69799ms step_avg:96.67ms +step:723/1695 train_time:69895ms step_avg:96.67ms +step:724/1695 train_time:69992ms step_avg:96.67ms +step:725/1695 train_time:70088ms step_avg:96.67ms +step:726/1695 train_time:70185ms step_avg:96.67ms +step:727/1695 train_time:70280ms step_avg:96.67ms +step:728/1695 train_time:70375ms step_avg:96.67ms +step:729/1695 train_time:70471ms step_avg:96.67ms +step:730/1695 train_time:70566ms step_avg:96.67ms +step:731/1695 train_time:70661ms step_avg:96.66ms +step:732/1695 train_time:70757ms step_avg:96.66ms +step:733/1695 train_time:70853ms step_avg:96.66ms +step:734/1695 train_time:70951ms step_avg:96.66ms +step:735/1695 train_time:71049ms step_avg:96.66ms +step:736/1695 train_time:71145ms step_avg:96.66ms +step:737/1695 train_time:71240ms step_avg:96.66ms +step:738/1695 train_time:71336ms step_avg:96.66ms +step:739/1695 train_time:71432ms step_avg:96.66ms +step:740/1695 train_time:71528ms step_avg:96.66ms +step:741/1695 train_time:71623ms step_avg:96.66ms +step:742/1695 train_time:71719ms step_avg:96.66ms +step:743/1695 train_time:71814ms step_avg:96.65ms +step:744/1695 train_time:71910ms step_avg:96.65ms +step:745/1695 train_time:72007ms step_avg:96.65ms +step:746/1695 train_time:72103ms step_avg:96.65ms +step:747/1695 train_time:72199ms step_avg:96.65ms +step:748/1695 train_time:72295ms step_avg:96.65ms +step:749/1695 train_time:72391ms step_avg:96.65ms +step:750/1695 train_time:72488ms step_avg:96.65ms +step:750/1695 val_loss:3.5671 train_time:72581ms step_avg:96.77ms +step:751/1695 train_time:72608ms step_avg:96.68ms +step:752/1695 train_time:72687ms step_avg:96.66ms +step:753/1695 train_time:72784ms step_avg:96.66ms +step:754/1695 train_time:72880ms step_avg:96.66ms +step:755/1695 train_time:72976ms step_avg:96.66ms +step:756/1695 train_time:73072ms step_avg:96.66ms +step:757/1695 train_time:73167ms step_avg:96.65ms +step:758/1695 train_time:73262ms step_avg:96.65ms +step:759/1695 train_time:73357ms step_avg:96.65ms +step:760/1695 train_time:73452ms step_avg:96.65ms +step:761/1695 train_time:73549ms step_avg:96.65ms +step:762/1695 train_time:73646ms step_avg:96.65ms +step:763/1695 train_time:73744ms step_avg:96.65ms +step:764/1695 train_time:73841ms step_avg:96.65ms +step:765/1695 train_time:73938ms step_avg:96.65ms +step:766/1695 train_time:74034ms step_avg:96.65ms +step:767/1695 train_time:74129ms step_avg:96.65ms +step:768/1695 train_time:74224ms step_avg:96.65ms +step:769/1695 train_time:74319ms step_avg:96.64ms +step:770/1695 train_time:74414ms step_avg:96.64ms +step:771/1695 train_time:74509ms step_avg:96.64ms +step:772/1695 train_time:74606ms step_avg:96.64ms +step:773/1695 train_time:74703ms step_avg:96.64ms +step:774/1695 train_time:74800ms step_avg:96.64ms +step:775/1695 train_time:74897ms step_avg:96.64ms +step:776/1695 train_time:74994ms step_avg:96.64ms +step:777/1695 train_time:75090ms step_avg:96.64ms +step:778/1695 train_time:75185ms step_avg:96.64ms +step:779/1695 train_time:75280ms step_avg:96.64ms +step:780/1695 train_time:75375ms step_avg:96.63ms +step:781/1695 train_time:75472ms step_avg:96.63ms +step:782/1695 train_time:75567ms step_avg:96.63ms +step:783/1695 train_time:75663ms 
step_avg:96.63ms +step:784/1695 train_time:75760ms step_avg:96.63ms +step:785/1695 train_time:75857ms step_avg:96.63ms +step:786/1695 train_time:75953ms step_avg:96.63ms +step:787/1695 train_time:76049ms step_avg:96.63ms +step:788/1695 train_time:76144ms step_avg:96.63ms +step:789/1695 train_time:76239ms step_avg:96.63ms +step:790/1695 train_time:76335ms step_avg:96.63ms +step:791/1695 train_time:76433ms step_avg:96.63ms +step:792/1695 train_time:76529ms step_avg:96.63ms +step:793/1695 train_time:76624ms step_avg:96.63ms +step:794/1695 train_time:76721ms step_avg:96.63ms +step:795/1695 train_time:76818ms step_avg:96.63ms +step:796/1695 train_time:76916ms step_avg:96.63ms +step:797/1695 train_time:77012ms step_avg:96.63ms +step:798/1695 train_time:77108ms step_avg:96.63ms +step:799/1695 train_time:77202ms step_avg:96.62ms +step:800/1695 train_time:77298ms step_avg:96.62ms +step:801/1695 train_time:77394ms step_avg:96.62ms +step:802/1695 train_time:77490ms step_avg:96.62ms +step:803/1695 train_time:77585ms step_avg:96.62ms +step:804/1695 train_time:77681ms step_avg:96.62ms +step:805/1695 train_time:77778ms step_avg:96.62ms +step:806/1695 train_time:77874ms step_avg:96.62ms +step:807/1695 train_time:77971ms step_avg:96.62ms +step:808/1695 train_time:78067ms step_avg:96.62ms +step:809/1695 train_time:78163ms step_avg:96.62ms +step:810/1695 train_time:78258ms step_avg:96.62ms +step:811/1695 train_time:78355ms step_avg:96.61ms +step:812/1695 train_time:78450ms step_avg:96.61ms +step:813/1695 train_time:78545ms step_avg:96.61ms +step:814/1695 train_time:78641ms step_avg:96.61ms +step:815/1695 train_time:78738ms step_avg:96.61ms +step:816/1695 train_time:78835ms step_avg:96.61ms +step:817/1695 train_time:78932ms step_avg:96.61ms +step:818/1695 train_time:79028ms step_avg:96.61ms +step:819/1695 train_time:79123ms step_avg:96.61ms +step:820/1695 train_time:79219ms step_avg:96.61ms +step:821/1695 train_time:79316ms step_avg:96.61ms +step:822/1695 train_time:79412ms step_avg:96.61ms +step:823/1695 train_time:79508ms step_avg:96.61ms +step:824/1695 train_time:79604ms step_avg:96.61ms +step:825/1695 train_time:79700ms step_avg:96.61ms +step:826/1695 train_time:79796ms step_avg:96.60ms +step:827/1695 train_time:79892ms step_avg:96.61ms +step:828/1695 train_time:79989ms step_avg:96.61ms +step:829/1695 train_time:80084ms step_avg:96.60ms +step:830/1695 train_time:80180ms step_avg:96.60ms +step:831/1695 train_time:80276ms step_avg:96.60ms +step:832/1695 train_time:80373ms step_avg:96.60ms +step:833/1695 train_time:80469ms step_avg:96.60ms +step:834/1695 train_time:80565ms step_avg:96.60ms +step:835/1695 train_time:80660ms step_avg:96.60ms +step:836/1695 train_time:80756ms step_avg:96.60ms +step:837/1695 train_time:80853ms step_avg:96.60ms +step:838/1695 train_time:80949ms step_avg:96.60ms +step:839/1695 train_time:81045ms step_avg:96.60ms +step:840/1695 train_time:81141ms step_avg:96.60ms +step:841/1695 train_time:81237ms step_avg:96.60ms +step:842/1695 train_time:81333ms step_avg:96.59ms +step:843/1695 train_time:81429ms step_avg:96.59ms +step:844/1695 train_time:81524ms step_avg:96.59ms +step:845/1695 train_time:81619ms step_avg:96.59ms +step:846/1695 train_time:81716ms step_avg:96.59ms +step:847/1695 train_time:81812ms step_avg:96.59ms +step:848/1695 train_time:81908ms step_avg:96.59ms +step:849/1695 train_time:82003ms step_avg:96.59ms +step:850/1695 train_time:82099ms step_avg:96.59ms +step:851/1695 train_time:82196ms step_avg:96.59ms +step:852/1695 train_time:82292ms step_avg:96.59ms +step:853/1695 
train_time:82388ms step_avg:96.59ms +step:854/1695 train_time:82484ms step_avg:96.59ms +step:855/1695 train_time:82579ms step_avg:96.58ms +step:856/1695 train_time:82676ms step_avg:96.58ms +step:857/1695 train_time:82773ms step_avg:96.58ms +step:858/1695 train_time:82869ms step_avg:96.58ms +step:859/1695 train_time:82964ms step_avg:96.58ms +step:860/1695 train_time:83060ms step_avg:96.58ms +step:861/1695 train_time:83156ms step_avg:96.58ms +step:862/1695 train_time:83252ms step_avg:96.58ms +step:863/1695 train_time:83584ms step_avg:96.85ms +step:864/1695 train_time:83778ms step_avg:96.96ms +step:865/1695 train_time:83872ms step_avg:96.96ms +step:866/1695 train_time:83966ms step_avg:96.96ms +step:867/1695 train_time:84061ms step_avg:96.96ms +step:868/1695 train_time:84156ms step_avg:96.95ms +step:869/1695 train_time:84250ms step_avg:96.95ms +step:870/1695 train_time:84345ms step_avg:96.95ms +step:871/1695 train_time:84440ms step_avg:96.95ms +step:872/1695 train_time:84535ms step_avg:96.94ms +step:873/1695 train_time:84637ms step_avg:96.95ms +step:874/1695 train_time:84737ms step_avg:96.95ms +step:875/1695 train_time:84837ms step_avg:96.96ms +step:875/1695 val_loss:3.5244 train_time:84933ms step_avg:97.07ms +step:876/1695 train_time:84957ms step_avg:96.98ms +step:877/1695 train_time:85038ms step_avg:96.96ms +step:878/1695 train_time:85138ms step_avg:96.97ms +step:879/1695 train_time:85235ms step_avg:96.97ms +step:880/1695 train_time:85330ms step_avg:96.97ms +step:881/1695 train_time:85425ms step_avg:96.96ms +step:882/1695 train_time:85520ms step_avg:96.96ms +step:883/1695 train_time:85614ms step_avg:96.96ms +step:884/1695 train_time:85709ms step_avg:96.96ms +step:885/1695 train_time:85804ms step_avg:96.95ms +step:886/1695 train_time:85901ms step_avg:96.95ms +step:887/1695 train_time:86000ms step_avg:96.96ms +step:888/1695 train_time:86100ms step_avg:96.96ms +step:889/1695 train_time:86198ms step_avg:96.96ms +step:890/1695 train_time:86295ms step_avg:96.96ms +step:891/1695 train_time:86391ms step_avg:96.96ms +step:892/1695 train_time:86486ms step_avg:96.96ms +step:893/1695 train_time:86581ms step_avg:96.96ms +step:894/1695 train_time:86677ms step_avg:96.95ms +step:895/1695 train_time:86772ms step_avg:96.95ms +step:896/1695 train_time:86868ms step_avg:96.95ms +step:897/1695 train_time:86964ms step_avg:96.95ms +step:898/1695 train_time:87062ms step_avg:96.95ms +step:899/1695 train_time:87159ms step_avg:96.95ms +step:900/1695 train_time:87256ms step_avg:96.95ms +step:901/1695 train_time:87351ms step_avg:96.95ms +step:902/1695 train_time:87447ms step_avg:96.95ms +step:903/1695 train_time:87542ms step_avg:96.95ms +step:904/1695 train_time:87638ms step_avg:96.95ms +step:905/1695 train_time:87734ms step_avg:96.94ms +step:906/1695 train_time:87829ms step_avg:96.94ms +step:907/1695 train_time:87925ms step_avg:96.94ms +step:908/1695 train_time:88021ms step_avg:96.94ms +step:909/1695 train_time:88118ms step_avg:96.94ms +step:910/1695 train_time:88216ms step_avg:96.94ms +step:911/1695 train_time:88312ms step_avg:96.94ms +step:912/1695 train_time:88407ms step_avg:96.94ms +step:913/1695 train_time:88503ms step_avg:96.94ms +step:914/1695 train_time:88599ms step_avg:96.94ms +step:915/1695 train_time:88695ms step_avg:96.93ms +step:916/1695 train_time:88791ms step_avg:96.93ms +step:917/1695 train_time:88887ms step_avg:96.93ms +step:918/1695 train_time:88983ms step_avg:96.93ms +step:919/1695 train_time:89080ms step_avg:96.93ms +step:920/1695 train_time:89177ms step_avg:96.93ms +step:921/1695 train_time:89274ms 
step_avg:96.93ms +step:922/1695 train_time:89370ms step_avg:96.93ms +step:923/1695 train_time:89467ms step_avg:96.93ms +step:924/1695 train_time:89562ms step_avg:96.93ms +step:925/1695 train_time:89658ms step_avg:96.93ms +step:926/1695 train_time:89752ms step_avg:96.92ms +step:927/1695 train_time:89848ms step_avg:96.92ms +step:928/1695 train_time:89944ms step_avg:96.92ms +step:929/1695 train_time:90041ms step_avg:96.92ms +step:930/1695 train_time:90138ms step_avg:96.92ms +step:931/1695 train_time:90235ms step_avg:96.92ms +step:932/1695 train_time:90332ms step_avg:96.92ms +step:933/1695 train_time:90428ms step_avg:96.92ms +step:934/1695 train_time:90523ms step_avg:96.92ms +step:935/1695 train_time:90620ms step_avg:96.92ms +step:936/1695 train_time:90717ms step_avg:96.92ms +step:937/1695 train_time:90813ms step_avg:96.92ms +step:938/1695 train_time:90908ms step_avg:96.92ms +step:939/1695 train_time:91003ms step_avg:96.92ms +step:940/1695 train_time:91099ms step_avg:96.91ms +step:941/1695 train_time:91195ms step_avg:96.91ms +step:942/1695 train_time:91291ms step_avg:96.91ms +step:943/1695 train_time:91386ms step_avg:96.91ms +step:944/1695 train_time:91482ms step_avg:96.91ms +step:945/1695 train_time:91578ms step_avg:96.91ms +step:946/1695 train_time:91675ms step_avg:96.91ms +step:947/1695 train_time:91771ms step_avg:96.91ms +step:948/1695 train_time:91866ms step_avg:96.91ms +step:949/1695 train_time:91962ms step_avg:96.90ms +step:950/1695 train_time:92057ms step_avg:96.90ms +step:951/1695 train_time:92153ms step_avg:96.90ms +step:952/1695 train_time:92249ms step_avg:96.90ms +step:953/1695 train_time:92345ms step_avg:96.90ms +step:954/1695 train_time:92441ms step_avg:96.90ms +step:955/1695 train_time:92538ms step_avg:96.90ms +step:956/1695 train_time:92636ms step_avg:96.90ms +step:957/1695 train_time:92732ms step_avg:96.90ms +step:958/1695 train_time:92827ms step_avg:96.90ms +step:959/1695 train_time:92923ms step_avg:96.90ms +step:960/1695 train_time:93018ms step_avg:96.89ms +step:961/1695 train_time:93114ms step_avg:96.89ms +step:962/1695 train_time:93210ms step_avg:96.89ms +step:963/1695 train_time:93305ms step_avg:96.89ms +step:964/1695 train_time:93401ms step_avg:96.89ms +step:965/1695 train_time:93496ms step_avg:96.89ms +step:966/1695 train_time:93592ms step_avg:96.89ms +step:967/1695 train_time:93688ms step_avg:96.88ms +step:968/1695 train_time:93783ms step_avg:96.88ms +step:969/1695 train_time:93880ms step_avg:96.88ms +step:970/1695 train_time:93976ms step_avg:96.88ms +step:971/1695 train_time:94073ms step_avg:96.88ms +step:972/1695 train_time:94168ms step_avg:96.88ms +step:973/1695 train_time:94263ms step_avg:96.88ms +step:974/1695 train_time:94360ms step_avg:96.88ms +step:975/1695 train_time:94456ms step_avg:96.88ms +step:976/1695 train_time:94553ms step_avg:96.88ms +step:977/1695 train_time:94648ms step_avg:96.88ms +step:978/1695 train_time:94743ms step_avg:96.87ms +step:979/1695 train_time:94840ms step_avg:96.87ms +step:980/1695 train_time:94937ms step_avg:96.87ms +step:981/1695 train_time:95033ms step_avg:96.87ms +step:982/1695 train_time:95129ms step_avg:96.87ms +step:983/1695 train_time:95224ms step_avg:96.87ms +step:984/1695 train_time:95320ms step_avg:96.87ms +step:985/1695 train_time:95416ms step_avg:96.87ms +step:986/1695 train_time:95512ms step_avg:96.87ms +step:987/1695 train_time:95608ms step_avg:96.87ms +step:988/1695 train_time:95703ms step_avg:96.87ms +step:989/1695 train_time:95799ms step_avg:96.86ms +step:990/1695 train_time:95895ms step_avg:96.86ms +step:991/1695 
train_time:95991ms step_avg:96.86ms +step:992/1695 train_time:96087ms step_avg:96.86ms +step:993/1695 train_time:96183ms step_avg:96.86ms +step:994/1695 train_time:96280ms step_avg:96.86ms +step:995/1695 train_time:96375ms step_avg:96.86ms +step:996/1695 train_time:96471ms step_avg:96.86ms +step:997/1695 train_time:96566ms step_avg:96.86ms +step:998/1695 train_time:96663ms step_avg:96.86ms +step:999/1695 train_time:96758ms step_avg:96.85ms +step:1000/1695 train_time:96854ms step_avg:96.85ms +step:1000/1695 val_loss:3.4839 train_time:96948ms step_avg:96.95ms +step:1001/1695 train_time:96972ms step_avg:96.88ms +step:1002/1695 train_time:97055ms step_avg:96.86ms +step:1003/1695 train_time:97153ms step_avg:96.86ms +step:1004/1695 train_time:97250ms step_avg:96.86ms +step:1005/1695 train_time:97345ms step_avg:96.86ms +step:1006/1695 train_time:97440ms step_avg:96.86ms +step:1007/1695 train_time:97534ms step_avg:96.86ms +step:1008/1695 train_time:97629ms step_avg:96.85ms +step:1009/1695 train_time:97724ms step_avg:96.85ms +step:1010/1695 train_time:97819ms step_avg:96.85ms +step:1011/1695 train_time:97914ms step_avg:96.85ms +step:1012/1695 train_time:98013ms step_avg:96.85ms +step:1013/1695 train_time:98112ms step_avg:96.85ms +step:1014/1695 train_time:98211ms step_avg:96.86ms +step:1015/1695 train_time:98309ms step_avg:96.86ms +step:1016/1695 train_time:98406ms step_avg:96.86ms +step:1017/1695 train_time:98501ms step_avg:96.85ms +step:1018/1695 train_time:98596ms step_avg:96.85ms +step:1019/1695 train_time:98690ms step_avg:96.85ms +step:1020/1695 train_time:98786ms step_avg:96.85ms +step:1021/1695 train_time:98883ms step_avg:96.85ms +step:1022/1695 train_time:98979ms step_avg:96.85ms +step:1023/1695 train_time:99077ms step_avg:96.85ms +step:1024/1695 train_time:99173ms step_avg:96.85ms +step:1025/1695 train_time:99270ms step_avg:96.85ms +step:1026/1695 train_time:99366ms step_avg:96.85ms +step:1027/1695 train_time:99462ms step_avg:96.85ms +step:1028/1695 train_time:99556ms step_avg:96.84ms +step:1029/1695 train_time:99651ms step_avg:96.84ms +step:1030/1695 train_time:99746ms step_avg:96.84ms +step:1031/1695 train_time:99842ms step_avg:96.84ms +step:1032/1695 train_time:99938ms step_avg:96.84ms +step:1033/1695 train_time:100035ms step_avg:96.84ms +step:1034/1695 train_time:100131ms step_avg:96.84ms +step:1035/1695 train_time:100228ms step_avg:96.84ms +step:1036/1695 train_time:100552ms step_avg:97.06ms +step:1037/1695 train_time:100740ms step_avg:97.15ms +step:1038/1695 train_time:100833ms step_avg:97.14ms +step:1039/1695 train_time:100928ms step_avg:97.14ms +step:1040/1695 train_time:101024ms step_avg:97.14ms +step:1041/1695 train_time:101118ms step_avg:97.14ms +step:1042/1695 train_time:101212ms step_avg:97.13ms +step:1043/1695 train_time:101307ms step_avg:97.13ms +step:1044/1695 train_time:101402ms step_avg:97.13ms +step:1045/1695 train_time:101497ms step_avg:97.13ms +step:1046/1695 train_time:101598ms step_avg:97.13ms +step:1047/1695 train_time:101696ms step_avg:97.13ms +step:1048/1695 train_time:101794ms step_avg:97.13ms +step:1049/1695 train_time:101890ms step_avg:97.13ms +step:1050/1695 train_time:101986ms step_avg:97.13ms +step:1051/1695 train_time:102083ms step_avg:97.13ms +step:1052/1695 train_time:102178ms step_avg:97.13ms +step:1053/1695 train_time:102272ms step_avg:97.12ms +step:1054/1695 train_time:102367ms step_avg:97.12ms +step:1055/1695 train_time:102462ms step_avg:97.12ms +step:1056/1695 train_time:102560ms step_avg:97.12ms +step:1057/1695 train_time:102656ms step_avg:97.12ms 
+step:1058/1695 train_time:102753ms step_avg:97.12ms +step:1059/1695 train_time:102850ms step_avg:97.12ms +step:1060/1695 train_time:102947ms step_avg:97.12ms +step:1061/1695 train_time:103044ms step_avg:97.12ms +step:1062/1695 train_time:103140ms step_avg:97.12ms +step:1063/1695 train_time:103235ms step_avg:97.12ms +step:1064/1695 train_time:103330ms step_avg:97.11ms +step:1065/1695 train_time:103425ms step_avg:97.11ms +step:1066/1695 train_time:103521ms step_avg:97.11ms +step:1067/1695 train_time:103617ms step_avg:97.11ms +step:1068/1695 train_time:103714ms step_avg:97.11ms +step:1069/1695 train_time:103810ms step_avg:97.11ms +step:1070/1695 train_time:103907ms step_avg:97.11ms +step:1071/1695 train_time:104004ms step_avg:97.11ms +step:1072/1695 train_time:104100ms step_avg:97.11ms +step:1073/1695 train_time:104195ms step_avg:97.11ms +step:1074/1695 train_time:104289ms step_avg:97.10ms +step:1075/1695 train_time:104385ms step_avg:97.10ms +step:1076/1695 train_time:104481ms step_avg:97.10ms +step:1077/1695 train_time:104578ms step_avg:97.10ms +step:1078/1695 train_time:104673ms step_avg:97.10ms +step:1079/1695 train_time:104769ms step_avg:97.10ms +step:1080/1695 train_time:104865ms step_avg:97.10ms +step:1081/1695 train_time:104961ms step_avg:97.10ms +step:1082/1695 train_time:105056ms step_avg:97.09ms +step:1083/1695 train_time:105152ms step_avg:97.09ms +step:1084/1695 train_time:105248ms step_avg:97.09ms +step:1085/1695 train_time:105344ms step_avg:97.09ms +step:1086/1695 train_time:105439ms step_avg:97.09ms +step:1087/1695 train_time:105535ms step_avg:97.09ms +step:1088/1695 train_time:105631ms step_avg:97.09ms +step:1089/1695 train_time:105729ms step_avg:97.09ms +step:1090/1695 train_time:105825ms step_avg:97.09ms +step:1091/1695 train_time:105922ms step_avg:97.09ms +step:1092/1695 train_time:106018ms step_avg:97.09ms +step:1093/1695 train_time:106114ms step_avg:97.08ms +step:1094/1695 train_time:106209ms step_avg:97.08ms +step:1095/1695 train_time:106306ms step_avg:97.08ms +step:1096/1695 train_time:106402ms step_avg:97.08ms +step:1097/1695 train_time:106497ms step_avg:97.08ms +step:1098/1695 train_time:106592ms step_avg:97.08ms +step:1099/1695 train_time:106688ms step_avg:97.08ms +step:1100/1695 train_time:106785ms step_avg:97.08ms +step:1101/1695 train_time:106882ms step_avg:97.08ms +step:1102/1695 train_time:106978ms step_avg:97.08ms +step:1103/1695 train_time:107073ms step_avg:97.07ms +step:1104/1695 train_time:107169ms step_avg:97.07ms +step:1105/1695 train_time:107265ms step_avg:97.07ms +step:1106/1695 train_time:107362ms step_avg:97.07ms +step:1107/1695 train_time:107457ms step_avg:97.07ms +step:1108/1695 train_time:107552ms step_avg:97.07ms +step:1109/1695 train_time:107648ms step_avg:97.07ms +step:1110/1695 train_time:107745ms step_avg:97.07ms +step:1111/1695 train_time:107841ms step_avg:97.07ms +step:1112/1695 train_time:107938ms step_avg:97.07ms +step:1113/1695 train_time:108034ms step_avg:97.07ms +step:1114/1695 train_time:108130ms step_avg:97.06ms +step:1115/1695 train_time:108227ms step_avg:97.06ms +step:1116/1695 train_time:108323ms step_avg:97.06ms +step:1117/1695 train_time:108419ms step_avg:97.06ms +step:1118/1695 train_time:108515ms step_avg:97.06ms +step:1119/1695 train_time:108611ms step_avg:97.06ms +step:1120/1695 train_time:108708ms step_avg:97.06ms +step:1121/1695 train_time:108804ms step_avg:97.06ms +step:1122/1695 train_time:108901ms step_avg:97.06ms +step:1123/1695 train_time:108995ms step_avg:97.06ms +step:1124/1695 train_time:109090ms step_avg:97.06ms 
+step:1125/1695 train_time:109187ms step_avg:97.06ms +step:1125/1695 val_loss:3.4370 train_time:109281ms step_avg:97.14ms +step:1126/1695 train_time:109306ms step_avg:97.08ms +step:1127/1695 train_time:109389ms step_avg:97.06ms +step:1128/1695 train_time:109486ms step_avg:97.06ms +step:1129/1695 train_time:109581ms step_avg:97.06ms +step:1130/1695 train_time:109676ms step_avg:97.06ms +step:1131/1695 train_time:109771ms step_avg:97.06ms +step:1132/1695 train_time:109865ms step_avg:97.05ms +step:1133/1695 train_time:109961ms step_avg:97.05ms +step:1134/1695 train_time:110058ms step_avg:97.05ms +step:1135/1695 train_time:110154ms step_avg:97.05ms +step:1136/1695 train_time:110253ms step_avg:97.05ms +step:1137/1695 train_time:110354ms step_avg:97.06ms +step:1138/1695 train_time:110454ms step_avg:97.06ms +step:1139/1695 train_time:110552ms step_avg:97.06ms +step:1140/1695 train_time:110650ms step_avg:97.06ms +step:1141/1695 train_time:110746ms step_avg:97.06ms +step:1142/1695 train_time:110842ms step_avg:97.06ms +step:1143/1695 train_time:110939ms step_avg:97.06ms +step:1144/1695 train_time:111035ms step_avg:97.06ms +step:1145/1695 train_time:111133ms step_avg:97.06ms +step:1146/1695 train_time:111230ms step_avg:97.06ms +step:1147/1695 train_time:111329ms step_avg:97.06ms +step:1148/1695 train_time:111428ms step_avg:97.06ms +step:1149/1695 train_time:111526ms step_avg:97.06ms +step:1150/1695 train_time:111624ms step_avg:97.06ms +step:1151/1695 train_time:111721ms step_avg:97.06ms +step:1152/1695 train_time:111819ms step_avg:97.06ms +step:1153/1695 train_time:111916ms step_avg:97.06ms +step:1154/1695 train_time:112013ms step_avg:97.06ms +step:1155/1695 train_time:112110ms step_avg:97.06ms +step:1156/1695 train_time:112207ms step_avg:97.07ms +step:1157/1695 train_time:112305ms step_avg:97.07ms +step:1158/1695 train_time:112403ms step_avg:97.07ms +step:1159/1695 train_time:112502ms step_avg:97.07ms +step:1160/1695 train_time:112601ms step_avg:97.07ms +step:1161/1695 train_time:112699ms step_avg:97.07ms +step:1162/1695 train_time:112796ms step_avg:97.07ms +step:1163/1695 train_time:112894ms step_avg:97.07ms +step:1164/1695 train_time:112991ms step_avg:97.07ms +step:1165/1695 train_time:113088ms step_avg:97.07ms +step:1166/1695 train_time:113185ms step_avg:97.07ms +step:1167/1695 train_time:113282ms step_avg:97.07ms +step:1168/1695 train_time:113380ms step_avg:97.07ms +step:1169/1695 train_time:113479ms step_avg:97.07ms +step:1170/1695 train_time:113579ms step_avg:97.08ms +step:1171/1695 train_time:113676ms step_avg:97.08ms +step:1172/1695 train_time:113774ms step_avg:97.08ms +step:1173/1695 train_time:113870ms step_avg:97.08ms +step:1174/1695 train_time:113966ms step_avg:97.08ms +step:1175/1695 train_time:114063ms step_avg:97.07ms +step:1176/1695 train_time:114161ms step_avg:97.08ms +step:1177/1695 train_time:114259ms step_avg:97.08ms +step:1178/1695 train_time:114357ms step_avg:97.08ms +step:1179/1695 train_time:114455ms step_avg:97.08ms +step:1180/1695 train_time:114553ms step_avg:97.08ms +step:1181/1695 train_time:114651ms step_avg:97.08ms +step:1182/1695 train_time:114748ms step_avg:97.08ms +step:1183/1695 train_time:114845ms step_avg:97.08ms +step:1184/1695 train_time:114942ms step_avg:97.08ms +step:1185/1695 train_time:115041ms step_avg:97.08ms +step:1186/1695 train_time:115140ms step_avg:97.08ms +step:1187/1695 train_time:115238ms step_avg:97.08ms +step:1188/1695 train_time:115337ms step_avg:97.09ms +step:1189/1695 train_time:115435ms step_avg:97.09ms +step:1190/1695 train_time:115533ms 
step_avg:97.09ms +step:1191/1695 train_time:115632ms step_avg:97.09ms +step:1192/1695 train_time:115730ms step_avg:97.09ms +step:1193/1695 train_time:115826ms step_avg:97.09ms +step:1194/1695 train_time:115923ms step_avg:97.09ms +step:1195/1695 train_time:116021ms step_avg:97.09ms +step:1196/1695 train_time:116119ms step_avg:97.09ms +step:1197/1695 train_time:116216ms step_avg:97.09ms +step:1198/1695 train_time:116313ms step_avg:97.09ms +step:1199/1695 train_time:116411ms step_avg:97.09ms +step:1200/1695 train_time:116508ms step_avg:97.09ms +step:1201/1695 train_time:116605ms step_avg:97.09ms +step:1202/1695 train_time:116704ms step_avg:97.09ms +step:1203/1695 train_time:116801ms step_avg:97.09ms +step:1204/1695 train_time:116899ms step_avg:97.09ms +step:1205/1695 train_time:116997ms step_avg:97.09ms +step:1206/1695 train_time:117095ms step_avg:97.09ms +step:1207/1695 train_time:117193ms step_avg:97.09ms +step:1208/1695 train_time:117515ms step_avg:97.28ms +step:1209/1695 train_time:117719ms step_avg:97.37ms +step:1210/1695 train_time:117814ms step_avg:97.37ms +step:1211/1695 train_time:117911ms step_avg:97.37ms +step:1212/1695 train_time:118008ms step_avg:97.37ms +step:1213/1695 train_time:118103ms step_avg:97.36ms +step:1214/1695 train_time:118200ms step_avg:97.36ms +step:1215/1695 train_time:118297ms step_avg:97.36ms +step:1216/1695 train_time:118393ms step_avg:97.36ms +step:1217/1695 train_time:118491ms step_avg:97.36ms +step:1218/1695 train_time:118592ms step_avg:97.37ms +step:1219/1695 train_time:118695ms step_avg:97.37ms +step:1220/1695 train_time:118794ms step_avg:97.37ms +step:1221/1695 train_time:118893ms step_avg:97.37ms +step:1222/1695 train_time:118990ms step_avg:97.37ms +step:1223/1695 train_time:119087ms step_avg:97.37ms +step:1224/1695 train_time:119183ms step_avg:97.37ms +step:1225/1695 train_time:119281ms step_avg:97.37ms +step:1226/1695 train_time:119378ms step_avg:97.37ms +step:1227/1695 train_time:119475ms step_avg:97.37ms +step:1228/1695 train_time:119573ms step_avg:97.37ms +step:1229/1695 train_time:119673ms step_avg:97.37ms +step:1230/1695 train_time:119772ms step_avg:97.38ms +step:1231/1695 train_time:119870ms step_avg:97.38ms +step:1232/1695 train_time:119966ms step_avg:97.38ms +step:1233/1695 train_time:120064ms step_avg:97.38ms +step:1234/1695 train_time:120162ms step_avg:97.38ms +step:1235/1695 train_time:120259ms step_avg:97.38ms +step:1236/1695 train_time:120357ms step_avg:97.38ms +step:1237/1695 train_time:120454ms step_avg:97.38ms +step:1238/1695 train_time:120552ms step_avg:97.38ms +step:1239/1695 train_time:120650ms step_avg:97.38ms +step:1240/1695 train_time:120748ms step_avg:97.38ms +step:1241/1695 train_time:120845ms step_avg:97.38ms +step:1242/1695 train_time:120944ms step_avg:97.38ms +step:1243/1695 train_time:121041ms step_avg:97.38ms +step:1244/1695 train_time:121139ms step_avg:97.38ms +step:1245/1695 train_time:121235ms step_avg:97.38ms +step:1246/1695 train_time:121332ms step_avg:97.38ms +step:1247/1695 train_time:121429ms step_avg:97.38ms +step:1248/1695 train_time:121527ms step_avg:97.38ms +step:1249/1695 train_time:121625ms step_avg:97.38ms +step:1250/1695 train_time:121724ms step_avg:97.38ms +step:1250/1695 val_loss:3.3885 train_time:121819ms step_avg:97.46ms +step:1251/1695 train_time:121843ms step_avg:97.40ms +step:1252/1695 train_time:121929ms step_avg:97.39ms +step:1253/1695 train_time:122027ms step_avg:97.39ms +step:1254/1695 train_time:122123ms step_avg:97.39ms +step:1255/1695 train_time:122220ms step_avg:97.39ms +step:1256/1695 
train_time:122317ms step_avg:97.39ms +step:1257/1695 train_time:122414ms step_avg:97.39ms +step:1258/1695 train_time:122510ms step_avg:97.38ms +step:1259/1695 train_time:122606ms step_avg:97.38ms +step:1260/1695 train_time:122703ms step_avg:97.38ms +step:1261/1695 train_time:122803ms step_avg:97.39ms +step:1262/1695 train_time:122905ms step_avg:97.39ms +step:1263/1695 train_time:123004ms step_avg:97.39ms +step:1264/1695 train_time:123101ms step_avg:97.39ms +step:1265/1695 train_time:123199ms step_avg:97.39ms +step:1266/1695 train_time:123296ms step_avg:97.39ms +step:1267/1695 train_time:123393ms step_avg:97.39ms +step:1268/1695 train_time:123490ms step_avg:97.39ms +step:1269/1695 train_time:123587ms step_avg:97.39ms +step:1270/1695 train_time:123686ms step_avg:97.39ms +step:1271/1695 train_time:123781ms step_avg:97.39ms +step:1272/1695 train_time:123881ms step_avg:97.39ms +step:1273/1695 train_time:123982ms step_avg:97.39ms +step:1274/1695 train_time:124082ms step_avg:97.40ms +step:1275/1695 train_time:124180ms step_avg:97.40ms +step:1276/1695 train_time:124278ms step_avg:97.40ms +step:1277/1695 train_time:124376ms step_avg:97.40ms +step:1278/1695 train_time:124474ms step_avg:97.40ms +step:1279/1695 train_time:124570ms step_avg:97.40ms +step:1280/1695 train_time:124666ms step_avg:97.40ms +step:1281/1695 train_time:124763ms step_avg:97.40ms +step:1282/1695 train_time:124862ms step_avg:97.40ms +step:1283/1695 train_time:124961ms step_avg:97.40ms +step:1284/1695 train_time:125061ms step_avg:97.40ms +step:1285/1695 train_time:125160ms step_avg:97.40ms +step:1286/1695 train_time:125258ms step_avg:97.40ms +step:1287/1695 train_time:125356ms step_avg:97.40ms +step:1288/1695 train_time:125454ms step_avg:97.40ms +step:1289/1695 train_time:125552ms step_avg:97.40ms +step:1290/1695 train_time:125649ms step_avg:97.40ms +step:1291/1695 train_time:125746ms step_avg:97.40ms +step:1292/1695 train_time:125843ms step_avg:97.40ms +step:1293/1695 train_time:125941ms step_avg:97.40ms +step:1294/1695 train_time:126039ms step_avg:97.40ms +step:1295/1695 train_time:126138ms step_avg:97.40ms +step:1296/1695 train_time:126237ms step_avg:97.40ms +step:1297/1695 train_time:126335ms step_avg:97.41ms +step:1298/1695 train_time:126432ms step_avg:97.41ms +step:1299/1695 train_time:126529ms step_avg:97.41ms +step:1300/1695 train_time:126626ms step_avg:97.40ms +step:1301/1695 train_time:126723ms step_avg:97.40ms +step:1302/1695 train_time:126821ms step_avg:97.40ms +step:1303/1695 train_time:126919ms step_avg:97.41ms +step:1304/1695 train_time:127018ms step_avg:97.41ms +step:1305/1695 train_time:127116ms step_avg:97.41ms +step:1306/1695 train_time:127214ms step_avg:97.41ms +step:1307/1695 train_time:127312ms step_avg:97.41ms +step:1308/1695 train_time:127410ms step_avg:97.41ms +step:1309/1695 train_time:127508ms step_avg:97.41ms +step:1310/1695 train_time:127605ms step_avg:97.41ms +step:1311/1695 train_time:127703ms step_avg:97.41ms +step:1312/1695 train_time:127800ms step_avg:97.41ms +step:1313/1695 train_time:127898ms step_avg:97.41ms +step:1314/1695 train_time:127996ms step_avg:97.41ms +step:1315/1695 train_time:128093ms step_avg:97.41ms +step:1316/1695 train_time:128191ms step_avg:97.41ms +step:1317/1695 train_time:128289ms step_avg:97.41ms +step:1318/1695 train_time:128386ms step_avg:97.41ms +step:1319/1695 train_time:128483ms step_avg:97.41ms +step:1320/1695 train_time:128582ms step_avg:97.41ms +step:1321/1695 train_time:128679ms step_avg:97.41ms +step:1322/1695 train_time:128777ms step_avg:97.41ms +step:1323/1695 
train_time:128875ms step_avg:97.41ms +step:1324/1695 train_time:128973ms step_avg:97.41ms +step:1325/1695 train_time:129070ms step_avg:97.41ms +step:1326/1695 train_time:129168ms step_avg:97.41ms +step:1327/1695 train_time:129265ms step_avg:97.41ms +step:1328/1695 train_time:129363ms step_avg:97.41ms +step:1329/1695 train_time:129461ms step_avg:97.41ms +step:1330/1695 train_time:129559ms step_avg:97.41ms +step:1331/1695 train_time:129658ms step_avg:97.41ms +step:1332/1695 train_time:129756ms step_avg:97.41ms +step:1333/1695 train_time:129854ms step_avg:97.42ms +step:1334/1695 train_time:129952ms step_avg:97.42ms +step:1335/1695 train_time:130049ms step_avg:97.41ms +step:1336/1695 train_time:130146ms step_avg:97.41ms +step:1337/1695 train_time:130244ms step_avg:97.42ms +step:1338/1695 train_time:130342ms step_avg:97.42ms +step:1339/1695 train_time:130440ms step_avg:97.42ms +step:1340/1695 train_time:130539ms step_avg:97.42ms +step:1341/1695 train_time:130637ms step_avg:97.42ms +step:1342/1695 train_time:130735ms step_avg:97.42ms +step:1343/1695 train_time:130832ms step_avg:97.42ms +step:1344/1695 train_time:130929ms step_avg:97.42ms +step:1345/1695 train_time:131026ms step_avg:97.42ms +step:1346/1695 train_time:131124ms step_avg:97.42ms +step:1347/1695 train_time:131220ms step_avg:97.42ms +step:1348/1695 train_time:131318ms step_avg:97.42ms +step:1349/1695 train_time:131417ms step_avg:97.42ms +step:1350/1695 train_time:131515ms step_avg:97.42ms +step:1351/1695 train_time:131613ms step_avg:97.42ms +step:1352/1695 train_time:131710ms step_avg:97.42ms +step:1353/1695 train_time:131808ms step_avg:97.42ms +step:1354/1695 train_time:131905ms step_avg:97.42ms +step:1355/1695 train_time:132003ms step_avg:97.42ms +step:1356/1695 train_time:132102ms step_avg:97.42ms +step:1357/1695 train_time:132200ms step_avg:97.42ms +step:1358/1695 train_time:132297ms step_avg:97.42ms +step:1359/1695 train_time:132395ms step_avg:97.42ms +step:1360/1695 train_time:132492ms step_avg:97.42ms +step:1361/1695 train_time:132590ms step_avg:97.42ms +step:1362/1695 train_time:132688ms step_avg:97.42ms +step:1363/1695 train_time:132786ms step_avg:97.42ms +step:1364/1695 train_time:132883ms step_avg:97.42ms +step:1365/1695 train_time:132981ms step_avg:97.42ms +step:1366/1695 train_time:133080ms step_avg:97.42ms +step:1367/1695 train_time:133179ms step_avg:97.42ms +step:1368/1695 train_time:133277ms step_avg:97.42ms +step:1369/1695 train_time:133375ms step_avg:97.43ms +step:1370/1695 train_time:133473ms step_avg:97.43ms +step:1371/1695 train_time:133571ms step_avg:97.43ms +step:1372/1695 train_time:133669ms step_avg:97.43ms +step:1373/1695 train_time:133766ms step_avg:97.43ms +step:1374/1695 train_time:133863ms step_avg:97.43ms +step:1375/1695 train_time:133961ms step_avg:97.43ms +step:1375/1695 val_loss:3.3505 train_time:134057ms step_avg:97.50ms +step:1376/1695 train_time:134084ms step_avg:97.45ms +step:1377/1695 train_time:134163ms step_avg:97.43ms +step:1378/1695 train_time:134261ms step_avg:97.43ms +step:1379/1695 train_time:134359ms step_avg:97.43ms +step:1380/1695 train_time:134456ms step_avg:97.43ms +step:1381/1695 train_time:134781ms step_avg:97.60ms +step:1382/1695 train_time:134987ms step_avg:97.68ms +step:1383/1695 train_time:135083ms step_avg:97.67ms +step:1384/1695 train_time:135179ms step_avg:97.67ms +step:1385/1695 train_time:135276ms step_avg:97.67ms +step:1386/1695 train_time:135374ms step_avg:97.67ms +step:1387/1695 train_time:135471ms step_avg:97.67ms +step:1388/1695 train_time:135568ms step_avg:97.67ms 
+step:1389/1695 train_time:135665ms step_avg:97.67ms +step:1390/1695 train_time:135763ms step_avg:97.67ms +step:1391/1695 train_time:135867ms step_avg:97.68ms +step:1392/1695 train_time:135968ms step_avg:97.68ms +step:1393/1695 train_time:136065ms step_avg:97.68ms +step:1394/1695 train_time:136163ms step_avg:97.68ms +step:1395/1695 train_time:136260ms step_avg:97.68ms +step:1396/1695 train_time:136356ms step_avg:97.68ms +step:1397/1695 train_time:136453ms step_avg:97.68ms +step:1398/1695 train_time:136549ms step_avg:97.67ms +step:1399/1695 train_time:136646ms step_avg:97.67ms +step:1400/1695 train_time:136743ms step_avg:97.67ms +step:1401/1695 train_time:136841ms step_avg:97.67ms +step:1402/1695 train_time:136940ms step_avg:97.67ms +step:1403/1695 train_time:137039ms step_avg:97.68ms +step:1404/1695 train_time:137137ms step_avg:97.68ms +step:1405/1695 train_time:137234ms step_avg:97.68ms +step:1406/1695 train_time:137332ms step_avg:97.68ms +step:1407/1695 train_time:137429ms step_avg:97.67ms +step:1408/1695 train_time:137525ms step_avg:97.67ms +step:1409/1695 train_time:137622ms step_avg:97.67ms +step:1410/1695 train_time:137719ms step_avg:97.67ms +step:1411/1695 train_time:137817ms step_avg:97.67ms +step:1412/1695 train_time:137916ms step_avg:97.67ms +step:1413/1695 train_time:138014ms step_avg:97.67ms +step:1414/1695 train_time:138113ms step_avg:97.68ms +step:1415/1695 train_time:138212ms step_avg:97.68ms +step:1416/1695 train_time:138309ms step_avg:97.68ms +step:1417/1695 train_time:138405ms step_avg:97.67ms +step:1418/1695 train_time:138501ms step_avg:97.67ms +step:1419/1695 train_time:138598ms step_avg:97.67ms +step:1420/1695 train_time:138696ms step_avg:97.67ms +step:1421/1695 train_time:138794ms step_avg:97.67ms +step:1422/1695 train_time:138893ms step_avg:97.67ms +step:1423/1695 train_time:138990ms step_avg:97.67ms +step:1424/1695 train_time:139089ms step_avg:97.67ms +step:1425/1695 train_time:139188ms step_avg:97.68ms +step:1426/1695 train_time:139286ms step_avg:97.68ms +step:1427/1695 train_time:139382ms step_avg:97.68ms +step:1428/1695 train_time:139479ms step_avg:97.67ms +step:1429/1695 train_time:139577ms step_avg:97.67ms +step:1430/1695 train_time:139675ms step_avg:97.67ms +step:1431/1695 train_time:139772ms step_avg:97.67ms +step:1432/1695 train_time:139871ms step_avg:97.68ms +step:1433/1695 train_time:139969ms step_avg:97.68ms +step:1434/1695 train_time:140067ms step_avg:97.68ms +step:1435/1695 train_time:140164ms step_avg:97.68ms +step:1436/1695 train_time:140261ms step_avg:97.68ms +step:1437/1695 train_time:140358ms step_avg:97.67ms +step:1438/1695 train_time:140455ms step_avg:97.67ms +step:1439/1695 train_time:140553ms step_avg:97.67ms +step:1440/1695 train_time:140651ms step_avg:97.67ms +step:1441/1695 train_time:140749ms step_avg:97.67ms +step:1442/1695 train_time:140847ms step_avg:97.67ms +step:1443/1695 train_time:140945ms step_avg:97.68ms +step:1444/1695 train_time:141043ms step_avg:97.68ms +step:1445/1695 train_time:141140ms step_avg:97.67ms +step:1446/1695 train_time:141238ms step_avg:97.67ms +step:1447/1695 train_time:141335ms step_avg:97.67ms +step:1448/1695 train_time:141433ms step_avg:97.67ms +step:1449/1695 train_time:141530ms step_avg:97.67ms +step:1450/1695 train_time:141628ms step_avg:97.67ms +step:1451/1695 train_time:141725ms step_avg:97.67ms +step:1452/1695 train_time:141822ms step_avg:97.67ms +step:1453/1695 train_time:141920ms step_avg:97.67ms +step:1454/1695 train_time:142019ms step_avg:97.67ms +step:1455/1695 train_time:142118ms step_avg:97.68ms 
+step:1456/1695 train_time:142215ms step_avg:97.68ms +step:1457/1695 train_time:142313ms step_avg:97.68ms +step:1458/1695 train_time:142410ms step_avg:97.68ms +step:1459/1695 train_time:142507ms step_avg:97.67ms +step:1460/1695 train_time:142604ms step_avg:97.67ms +step:1461/1695 train_time:142701ms step_avg:97.67ms +step:1462/1695 train_time:142798ms step_avg:97.67ms +step:1463/1695 train_time:142896ms step_avg:97.67ms +step:1464/1695 train_time:142994ms step_avg:97.67ms +step:1465/1695 train_time:143092ms step_avg:97.67ms +step:1466/1695 train_time:143190ms step_avg:97.67ms +step:1467/1695 train_time:143287ms step_avg:97.67ms +step:1468/1695 train_time:143385ms step_avg:97.67ms +step:1469/1695 train_time:143482ms step_avg:97.67ms +step:1470/1695 train_time:143579ms step_avg:97.67ms +step:1471/1695 train_time:143678ms step_avg:97.67ms +step:1472/1695 train_time:143776ms step_avg:97.67ms +step:1473/1695 train_time:143873ms step_avg:97.67ms +step:1474/1695 train_time:143972ms step_avg:97.67ms +step:1475/1695 train_time:144070ms step_avg:97.67ms +step:1476/1695 train_time:144167ms step_avg:97.67ms +step:1477/1695 train_time:144265ms step_avg:97.67ms +step:1478/1695 train_time:144361ms step_avg:97.67ms +step:1479/1695 train_time:144459ms step_avg:97.67ms +step:1480/1695 train_time:144555ms step_avg:97.67ms +step:1481/1695 train_time:144653ms step_avg:97.67ms +step:1482/1695 train_time:144751ms step_avg:97.67ms +step:1483/1695 train_time:144849ms step_avg:97.67ms +step:1484/1695 train_time:144947ms step_avg:97.67ms +step:1485/1695 train_time:145044ms step_avg:97.67ms +step:1486/1695 train_time:145141ms step_avg:97.67ms +step:1487/1695 train_time:145239ms step_avg:97.67ms +step:1488/1695 train_time:145337ms step_avg:97.67ms +step:1489/1695 train_time:145435ms step_avg:97.67ms +step:1490/1695 train_time:145533ms step_avg:97.67ms +step:1491/1695 train_time:145630ms step_avg:97.67ms +step:1492/1695 train_time:145728ms step_avg:97.67ms +step:1493/1695 train_time:145825ms step_avg:97.67ms +step:1494/1695 train_time:145922ms step_avg:97.67ms +step:1495/1695 train_time:146020ms step_avg:97.67ms +step:1496/1695 train_time:146118ms step_avg:97.67ms +step:1497/1695 train_time:146216ms step_avg:97.67ms +step:1498/1695 train_time:146315ms step_avg:97.67ms +step:1499/1695 train_time:146413ms step_avg:97.67ms +step:1500/1695 train_time:146511ms step_avg:97.67ms +step:1500/1695 val_loss:3.3179 train_time:146606ms step_avg:97.74ms +step:1501/1695 train_time:146632ms step_avg:97.69ms +step:1502/1695 train_time:146715ms step_avg:97.68ms +step:1503/1695 train_time:146816ms step_avg:97.68ms +step:1504/1695 train_time:146914ms step_avg:97.68ms +step:1505/1695 train_time:147011ms step_avg:97.68ms +step:1506/1695 train_time:147108ms step_avg:97.68ms +step:1507/1695 train_time:147205ms step_avg:97.68ms +step:1508/1695 train_time:147301ms step_avg:97.68ms +step:1509/1695 train_time:147397ms step_avg:97.68ms +step:1510/1695 train_time:147494ms step_avg:97.68ms +step:1511/1695 train_time:147596ms step_avg:97.68ms +step:1512/1695 train_time:147698ms step_avg:97.68ms +step:1513/1695 train_time:147797ms step_avg:97.68ms +step:1514/1695 train_time:147896ms step_avg:97.69ms +step:1515/1695 train_time:147994ms step_avg:97.69ms +step:1516/1695 train_time:148092ms step_avg:97.69ms +step:1517/1695 train_time:148189ms step_avg:97.69ms +step:1518/1695 train_time:148286ms step_avg:97.68ms +step:1519/1695 train_time:148382ms step_avg:97.68ms +step:1520/1695 train_time:148478ms step_avg:97.68ms +step:1521/1695 train_time:148577ms 
step_avg:97.68ms +step:1522/1695 train_time:148678ms step_avg:97.69ms +step:1523/1695 train_time:148777ms step_avg:97.69ms +step:1524/1695 train_time:148876ms step_avg:97.69ms +step:1525/1695 train_time:148975ms step_avg:97.69ms +step:1526/1695 train_time:149074ms step_avg:97.69ms +step:1527/1695 train_time:149172ms step_avg:97.69ms +step:1528/1695 train_time:149269ms step_avg:97.69ms +step:1529/1695 train_time:149367ms step_avg:97.69ms +step:1530/1695 train_time:149465ms step_avg:97.69ms +step:1531/1695 train_time:149562ms step_avg:97.69ms +step:1532/1695 train_time:149660ms step_avg:97.69ms +step:1533/1695 train_time:149758ms step_avg:97.69ms +step:1534/1695 train_time:149856ms step_avg:97.69ms +step:1535/1695 train_time:149955ms step_avg:97.69ms +step:1536/1695 train_time:150054ms step_avg:97.69ms +step:1537/1695 train_time:150152ms step_avg:97.69ms +step:1538/1695 train_time:150249ms step_avg:97.69ms +step:1539/1695 train_time:150346ms step_avg:97.69ms +step:1540/1695 train_time:150443ms step_avg:97.69ms +step:1541/1695 train_time:150540ms step_avg:97.69ms +step:1542/1695 train_time:150638ms step_avg:97.69ms +step:1543/1695 train_time:150736ms step_avg:97.69ms +step:1544/1695 train_time:150835ms step_avg:97.69ms +step:1545/1695 train_time:150932ms step_avg:97.69ms +step:1546/1695 train_time:151030ms step_avg:97.69ms +step:1547/1695 train_time:151128ms step_avg:97.69ms +step:1548/1695 train_time:151225ms step_avg:97.69ms +step:1549/1695 train_time:151322ms step_avg:97.69ms +step:1550/1695 train_time:151419ms step_avg:97.69ms +step:1551/1695 train_time:151517ms step_avg:97.69ms +step:1552/1695 train_time:151888ms step_avg:97.87ms +step:1553/1695 train_time:151963ms step_avg:97.85ms +step:1554/1695 train_time:152058ms step_avg:97.85ms +step:1555/1695 train_time:152155ms step_avg:97.85ms +step:1556/1695 train_time:152252ms step_avg:97.85ms +step:1557/1695 train_time:152349ms step_avg:97.85ms +step:1558/1695 train_time:152445ms step_avg:97.85ms +step:1559/1695 train_time:152541ms step_avg:97.85ms +step:1560/1695 train_time:152638ms step_avg:97.84ms +step:1561/1695 train_time:152735ms step_avg:97.84ms +step:1562/1695 train_time:152839ms step_avg:97.85ms +step:1563/1695 train_time:152939ms step_avg:97.85ms +step:1564/1695 train_time:153039ms step_avg:97.85ms +step:1565/1695 train_time:153136ms step_avg:97.85ms +step:1566/1695 train_time:153235ms step_avg:97.85ms +step:1567/1695 train_time:153331ms step_avg:97.85ms +step:1568/1695 train_time:153428ms step_avg:97.85ms +step:1569/1695 train_time:153525ms step_avg:97.85ms +step:1570/1695 train_time:153622ms step_avg:97.85ms +step:1571/1695 train_time:153720ms step_avg:97.85ms +step:1572/1695 train_time:153819ms step_avg:97.85ms +step:1573/1695 train_time:153917ms step_avg:97.85ms +step:1574/1695 train_time:154017ms step_avg:97.85ms +step:1575/1695 train_time:154115ms step_avg:97.85ms +step:1576/1695 train_time:154215ms step_avg:97.85ms +step:1577/1695 train_time:154312ms step_avg:97.85ms +step:1578/1695 train_time:154410ms step_avg:97.85ms +step:1579/1695 train_time:154507ms step_avg:97.85ms +step:1580/1695 train_time:154604ms step_avg:97.85ms +step:1581/1695 train_time:154701ms step_avg:97.85ms +step:1582/1695 train_time:154798ms step_avg:97.85ms +step:1583/1695 train_time:154897ms step_avg:97.85ms +step:1584/1695 train_time:154996ms step_avg:97.85ms +step:1585/1695 train_time:155096ms step_avg:97.85ms +step:1586/1695 train_time:155195ms step_avg:97.85ms +step:1587/1695 train_time:155293ms step_avg:97.85ms +step:1588/1695 train_time:155391ms 
step_avg:97.85ms +step:1589/1695 train_time:155489ms step_avg:97.85ms +step:1590/1695 train_time:155586ms step_avg:97.85ms +step:1591/1695 train_time:155682ms step_avg:97.85ms +step:1592/1695 train_time:155779ms step_avg:97.85ms +step:1593/1695 train_time:155876ms step_avg:97.85ms +step:1594/1695 train_time:155976ms step_avg:97.85ms +step:1595/1695 train_time:156075ms step_avg:97.85ms +step:1596/1695 train_time:156175ms step_avg:97.85ms +step:1597/1695 train_time:156275ms step_avg:97.86ms +step:1598/1695 train_time:156375ms step_avg:97.86ms +step:1599/1695 train_time:156473ms step_avg:97.86ms +step:1600/1695 train_time:156571ms step_avg:97.86ms +step:1601/1695 train_time:156669ms step_avg:97.86ms +step:1602/1695 train_time:156767ms step_avg:97.86ms +step:1603/1695 train_time:156864ms step_avg:97.86ms +step:1604/1695 train_time:156962ms step_avg:97.86ms +step:1605/1695 train_time:157060ms step_avg:97.86ms +step:1606/1695 train_time:157159ms step_avg:97.86ms +step:1607/1695 train_time:157257ms step_avg:97.86ms +step:1608/1695 train_time:157354ms step_avg:97.86ms +step:1609/1695 train_time:157453ms step_avg:97.86ms +step:1610/1695 train_time:157551ms step_avg:97.86ms +step:1611/1695 train_time:157649ms step_avg:97.86ms +step:1612/1695 train_time:157747ms step_avg:97.86ms +step:1613/1695 train_time:157844ms step_avg:97.86ms +step:1614/1695 train_time:157942ms step_avg:97.86ms +step:1615/1695 train_time:158039ms step_avg:97.86ms +step:1616/1695 train_time:158137ms step_avg:97.86ms +step:1617/1695 train_time:158235ms step_avg:97.86ms +step:1618/1695 train_time:158333ms step_avg:97.86ms +step:1619/1695 train_time:158432ms step_avg:97.86ms +step:1620/1695 train_time:158530ms step_avg:97.86ms +step:1621/1695 train_time:158628ms step_avg:97.86ms +step:1622/1695 train_time:158726ms step_avg:97.86ms +step:1623/1695 train_time:158824ms step_avg:97.86ms +step:1624/1695 train_time:158922ms step_avg:97.86ms +step:1625/1695 train_time:159019ms step_avg:97.86ms +step:1625/1695 val_loss:3.2905 train_time:159114ms step_avg:97.92ms +step:1626/1695 train_time:159139ms step_avg:97.87ms +step:1627/1695 train_time:159222ms step_avg:97.86ms +step:1628/1695 train_time:159320ms step_avg:97.86ms +step:1629/1695 train_time:159418ms step_avg:97.86ms +step:1630/1695 train_time:159515ms step_avg:97.86ms +step:1631/1695 train_time:159612ms step_avg:97.86ms +step:1632/1695 train_time:159709ms step_avg:97.86ms +step:1633/1695 train_time:159806ms step_avg:97.86ms +step:1634/1695 train_time:159902ms step_avg:97.86ms +step:1635/1695 train_time:159999ms step_avg:97.86ms +step:1636/1695 train_time:160099ms step_avg:97.86ms +step:1637/1695 train_time:160200ms step_avg:97.86ms +step:1638/1695 train_time:160300ms step_avg:97.86ms +step:1639/1695 train_time:160398ms step_avg:97.86ms +step:1640/1695 train_time:160495ms step_avg:97.86ms +step:1641/1695 train_time:160592ms step_avg:97.86ms +step:1642/1695 train_time:160689ms step_avg:97.86ms +step:1643/1695 train_time:160786ms step_avg:97.86ms +step:1644/1695 train_time:160882ms step_avg:97.86ms +step:1645/1695 train_time:160980ms step_avg:97.86ms +step:1646/1695 train_time:161078ms step_avg:97.86ms +step:1647/1695 train_time:161177ms step_avg:97.86ms +step:1648/1695 train_time:161277ms step_avg:97.86ms +step:1649/1695 train_time:161378ms step_avg:97.86ms +step:1650/1695 train_time:161476ms step_avg:97.86ms +step:1651/1695 train_time:161573ms step_avg:97.86ms +step:1652/1695 train_time:161670ms step_avg:97.86ms +step:1653/1695 train_time:161767ms step_avg:97.86ms +step:1654/1695 
train_time:161864ms step_avg:97.86ms
+step:1655/1695 train_time:161961ms step_avg:97.86ms
+step:1656/1695 train_time:162059ms step_avg:97.86ms
+step:1657/1695 train_time:162157ms step_avg:97.86ms
+step:1658/1695 train_time:162256ms step_avg:97.86ms
+step:1659/1695 train_time:162356ms step_avg:97.86ms
+step:1660/1695 train_time:162456ms step_avg:97.86ms
+step:1661/1695 train_time:162555ms step_avg:97.87ms
+step:1662/1695 train_time:162654ms step_avg:97.87ms
+step:1663/1695 train_time:162751ms step_avg:97.87ms
+step:1664/1695 train_time:162849ms step_avg:97.87ms
+step:1665/1695 train_time:162946ms step_avg:97.87ms
+step:1666/1695 train_time:163044ms step_avg:97.87ms
+step:1667/1695 train_time:163142ms step_avg:97.87ms
+step:1668/1695 train_time:163239ms step_avg:97.87ms
+step:1669/1695 train_time:163337ms step_avg:97.87ms
+step:1670/1695 train_time:163435ms step_avg:97.87ms
+step:1671/1695 train_time:163534ms step_avg:97.87ms
+step:1672/1695 train_time:163633ms step_avg:97.87ms
+step:1673/1695 train_time:163731ms step_avg:97.87ms
+step:1674/1695 train_time:163829ms step_avg:97.87ms
+step:1675/1695 train_time:163927ms step_avg:97.87ms
+step:1676/1695 train_time:164024ms step_avg:97.87ms
+step:1677/1695 train_time:164122ms step_avg:97.87ms
+step:1678/1695 train_time:164219ms step_avg:97.87ms
+step:1679/1695 train_time:164317ms step_avg:97.87ms
+step:1680/1695 train_time:164415ms step_avg:97.87ms
+step:1681/1695 train_time:164513ms step_avg:97.87ms
+step:1682/1695 train_time:164612ms step_avg:97.87ms
+step:1683/1695 train_time:164710ms step_avg:97.87ms
+step:1684/1695 train_time:164809ms step_avg:97.87ms
+step:1685/1695 train_time:164906ms step_avg:97.87ms
+step:1686/1695 train_time:165004ms step_avg:97.87ms
+step:1687/1695 train_time:165101ms step_avg:97.87ms
+step:1688/1695 train_time:165199ms step_avg:97.87ms
+step:1689/1695 train_time:165296ms step_avg:97.87ms
+step:1690/1695 train_time:165393ms step_avg:97.87ms
+step:1691/1695 train_time:165491ms step_avg:97.87ms
+step:1692/1695 train_time:165589ms step_avg:97.87ms
+step:1693/1695 train_time:165686ms step_avg:97.87ms
+step:1694/1695 train_time:165783ms step_avg:97.86ms
+step:1695/1695 train_time:165881ms step_avg:97.86ms
+step:1695/1695 val_loss:3.2790 train_time:165977ms step_avg:97.92ms
+peak memory allocated: 34000 MiB reserved: 49756 MiB
diff --git a/records/082725_FA3/README.md b/records/082725_FA3/README.md
new file mode 100644
index 000000000..a4079630d
--- /dev/null
+++ b/records/082725_FA3/README.md
@@ -0,0 +1,147 @@
+# New record 08/27/25
+
+This submission includes recent WR changes by
+@ClassicLarry [(08/23/25)](https://github.com/ClassicLarry/modded-nanogpt/tree/master/records/082325_SparseAttnGate)
+and @byronxu99 [(07/18/25)](https://github.com/KellerJordan/modded-nanogpt/pull/109).
+
+The main idea of this record is to use input tensors with `batch_size > 1` throughout our training run.
+Increasing `batch_size` increases GPU utilization and allows us to use shorter input sequences for training.
+However, since Flex Attention is inefficient for `batch_size > 1`, we use [Flash Attention v3](https://github.com/Dao-AILab/flash-attention).
+The official version of this module is incompatible with `torch.compile` and causes graph breaks.
+Fortunately, a [recent PR](https://github.com/Dao-AILab/flash-attention/pull/1769) by
+[@guilhermeleobas](https://github.com/guilhermeleobas) addresses this issue.
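+
+Concretely, the per-step token budget is unchanged; only its shape changes. A quick sanity check,
+using the shapes from the FA3 measurement discussed below:
+
+```
+tokens_per_step = 48 * 1024   # same per-step total as previous records
+old_shape = (1, 48 * 1024)    # previous records: no batch dimension, one very long sequence
+new_shape = (24, 2048)        # this record: batch_size > 1, shorter sequences
+assert old_shape[0] * old_shape[1] == tokens_per_step
+assert new_shape[0] * new_shape[1] == tokens_per_step
+```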
+
+
+## Timing and Validation
+
+Validated over 7 runs:
+- In 1695 training steps, this run achieves a loss <3.28 (`p=0.0008`)
+- In 166.10 seconds on average, and in <166.25 seconds (`p=0.0024`)
+
+```
+import scipy.stats
+import torch
+import numpy as np
+
+accs = [
+    3.2769, 3.2782, 3.2790, 3.2791, 3.2791, 3.2780, 3.2782
+]
+
+times = [
+    166.247, 166.117, 165.977, 166.135, 166.045, 166.044, 166.157
+]
+
+print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
+# p=0.0008
+
+print('p=%.4f' % scipy.stats.ttest_1samp(times, 166.25, alternative='less').pvalue)
+# p=0.0024
+
+print(f"{np.mean(times):.4f}")
+# 166.1031
+```
+
+In my timing, this is a 2.1 second mean improvement over [PR#117](https://github.com/KellerJordan/modded-nanogpt/pull/117).
+The number of steps could probably also be brought down by 5-15 while still achieving loss <3.28.
+
+I used an 8 x H100 SXM5 node via Prime Intellect for validation compute.
+
+## Further Details
+
+### Motivation
+
+PyTorch's Flex Attention experiences a >10% wallclock slowdown for inputs with `batch_size > 1`.
+As such, previous records would train on very long sequence lengths (`48 * 1024`) with no batch dimension.
+Attention is approximately `O(|seq_len|^2 x |batch_size|)`, so this is theoretically bad,
+but it was mitigated by aggressive block masking.
+Attention used a `block_mask` whose window grew to at most `1664` tokens (and was often shorter due to document masking).
+However, GPU utilization for attention is higher when tokens are distributed along the batch dimension.
+
+
+Additionally, increasing the batch size allows us to decrease sequence length while maintaining the total
+number of tokens processed per step.
+WR#26 by @ClassicLarry found that validation loss decreases when we train only
+on sequences beginning with the Beginning of Sequence token (``).
+Decreasing the sequence length makes it more likely that `` is present in the attention window.
+In order to generate batches where each sequence begins with ``, I have created the helper class
+`EOSBatchFinder`. This class pre-indexes shards with the locations of `` for slight speedups.
+
+### Flash Attention 3
+
+Most of the Hopper-specific benefits in Flash Attention 3 are already incorporated into
+PyTorch's Flex Attention. However, the latter implementation is fastest with `batch_size == 1`.
+Flash Attention 3 is as fast as Flex Attention for one-dimensional input sequences, and increases
+in speed as we distribute tokens along the batch dimension.
+On a single Hopper H100, I measured a 9% wallclock decrease for FA3 when using an optimal ratio
+of batch dimension to sequence length (`24:2048`) over a single batch dimension (`1:49152`).
+
+As mentioned above, we need to use an unmerged PR in order to use FA3 with `torch.compile`.
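+
+Once that wheel is built and installed (instructions just below), a minimal smoke test along these
+lines should compile with `fullgraph=True` and no graph breaks; with the stock FA3 release, the
+graph break makes it fail. This is only a sketch: the shapes are arbitrary, and `0.12` is simply
+the attention scale this record happens to use.
+
+```
+import torch
+from flash_attn_interface import flash_attn_func
+
+@torch.compile(fullgraph=True)  # fullgraph=True raises on any graph break
+def windowed_attn(q, k, v):
+    # attend to at most 1024 tokens to the left and none to the right
+    return flash_attn_func(q, k, v, softmax_scale=0.12, window_size=(1024, 0))
+
+q = torch.randn(24, 2048, 6, 128, device="cuda", dtype=torch.bfloat16)
+k, v = torch.randn_like(q), torch.randn_like(q)
+print(windowed_attn(q, k, v).shape)  # torch.Size([24, 2048, 6, 128])
+```
+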
+You can build the wheel like so:
+
+```
+pip install -U pip wheel setuptools ninja numpy packaging psutil
+
+git clone https://github.com/guilhermeleobas/flash-attention.git
+cd flash-attention/hopper
+git switch guilhermeleobas/fa3-compile
+
+export MAX_JOBS=32 # Can increase based on machine
+export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch
+export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only
+export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8
+export FLASH_ATTENTION_DISABLE_HDIM64=TRUE # NanoGPT only uses HDIM = 128
+export FLASH_ATTENTION_DISABLE_HDIM96=TRUE
+export FLASH_ATTENTION_DISABLE_HDIM192=TRUE
+export FLASH_ATTENTION_DISABLE_HDIM256=TRUE
+
+python setup.py bdist_wheel
+```
+
+Additionally, I have uploaded a prebuilt wheel
+[here](https://github.com/varunneal/flash-attention/releases/tag/v3.0.0b1-alpha),
+though it will likely be faster to build it yourself than to download this wheel.
+
+For exact reproduction, I recommend installing Torch Nightly 2.9.0.dev20250718 and
+then installing the FA3 wheel:
+
+```
+pip install --pre "torch==2.9.0.dev20250718+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126
+
+# typical path to the FA3 wheel
+pip install flash-attention/hopper/dist/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
+```
+
+For me, Torch Nightly 2.9.0.dev20250713 was incompatible with PR#109.
+
+### Attention Masks
+
+Unfortunately, Flash Attention does not support complex block masks the way Flex Attention does.
+Therefore, `create_blockmasks` was removed. Instead, we are only given the parameter `window_size`,
+which specifies how many tokens to the left each query may attend to.
+
+I kept the existing long-short sliding window block mask pattern, as well as the idea
+that the window sizes should increase linearly over the length of the training run.
+To aid with this, I modified `get_lr(step)` to instead be `get_lr_and_ws(step)`.
+Additionally, I added a hyperparameter `ws_schedule`, which specifies what the
+longer window size should be during each portion of the run, and a hyperparameter
+`bandwidth=128` for the size of the blocks in a window.
+
+I have picked a linear schedule with three steps: `ws_schedule=(3, 7, 11)`.
+Currently, `torch.compile` creates a new compilation graph for each step in `ws_schedule`,
+so each graph needs to be warmed up separately. I have increased the number
+of warmup steps from `10` to `60`. The compile time is dominated by the first iteration,
+so warmup will take approximately `len(ws_schedule)` times longer than before.
+
+Removing document masking had a noticeably negative impact on validation loss;
+however, the benefits of a short sequence length counteract this.
+
+### Potential Improvements
+
+- Batch size scheduling: Previously, the block mask acted as a proxy for batch size.
+Now batch size can be controlled explicitly and scheduled according to critical batch
+size theory. I have added code in `distributed_data_generator` that allows for changing the
+batch size and sequence length yielded after the generator is created (see the sketch after this list).
+- The current block mask window schedule `(3, 7, 11)` can almost certainly be improved upon.
+- Hyperparameter tuning might change with a smaller sequence length. Rotary base, validation
+sequence length, learning rates, etc. should be re-tuned. I haven't done that for this run.
+- FA3 has additional features over Flex Attention that may be useful.
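+
+As a rough illustration of that hook (a hypothetical sketch, not the record's actual
+implementation; the real generator is sharded across ranks and aligns sequences to the
+Beginning of Sequence token), the generator below yields fixed-geometry batches until the
+caller `.send()`s a new `(batch_size, seq_len)` pair:
+
+```
+def data_generator_sketch(tokens, batch_size, seq_len):
+    # tokens: a 1-D tensor holding the token stream (placeholder for real shard loading)
+    pos = 0
+    while pos + batch_size * seq_len + 1 <= len(tokens):
+        buf = tokens[pos : pos + batch_size * seq_len + 1]
+        inputs = buf[:-1].view(batch_size, seq_len)
+        targets = buf[1:].view(batch_size, seq_len)
+        pos += batch_size * seq_len
+        new_shape = yield inputs, targets  # a value passed via .send() lands here
+        if new_shape is not None:
+            batch_size, seq_len = new_shape
+
+gen = data_generator_sketch(tokens, batch_size=24, seq_len=2048)
+inputs, targets = next(gen)
+inputs, targets = gen.send((48, 1024))  # same tokens per step, new geometry
+```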
\ No newline at end of file diff --git a/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt b/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt new file mode 100644 index 000000000..7a5ed0b1c --- /dev/null +++ b/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, 
None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & 
(offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = 
C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) + x = x + self.mlp(norm(x)) + return x + +# 
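-----------------------------------------------------------------------------
+# Editor's sketch: the sparse attention gate (added for this record's writeup, not
+# part of the original run). The gate reads only the first attn_gate_dim=12 channels
+# of the pre-attention residual stream and emits one sigmoid gate per head. Because
+# the gate weights start at zero, every head is scaled by sigmoid(0) = 0.5 at init,
+# and training can push individual heads toward a context-based no-op (gate -> 0)
+# on a per-token basis. The function is defined but never called.
+def _attn_gate_sketch(num_heads: int = 6, head_dim: int = 128, attn_gate_dim: int = 12):
+    B, T = 2, 4
+    x = torch.randn(B, T, num_heads * head_dim) # pre-attention residual stream
+    y = torch.randn(B, T, num_heads, head_dim) # stand-in for the attention output
+    gate = CastedLinear(attn_gate_dim, num_heads)
+    gate.weight.detach().zero_()
+    g = torch.sigmoid(gate(x[..., :attn_gate_dim])).view(B, T, num_heads, 1)
+    assert torch.allclose(g, torch.full_like(g, 0.5)) # zero init => 0.5 everywhere
+    return y * g # per-token, per-head scaling, as in CausalSelfAttention.forward
+
+# 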
----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +class EOSBatchFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): + # Precompute EOS positions once per shard + self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 # pointer into eos_idx (start EOS for next step) + self.pos = 0 # logical stream position within this shard + self.world_size = world_size + def seek(self, pos: int): + # Set pointer to the first EOS >= pos + self.i = np.searchsorted(self.eos_idx, pos) + if self.i >= len(self.eos_idx): + raise StopIteration("Seek past last EOS.") + self.pos = pos + def next_batch(self, batch_size_local: int, seq_len: int): + n = len(self.eos_idx) + if self.i >= n: + raise StopIteration("No more EOS in this shard.") + starts = [[] for _ in range(self.world_size)] + idx = self.i + cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 + for r in range(self.world_size): + for _ in range(batch_size_local): + start = cur + 1 + target = start + seq_len # need seq_len tokens before next EOS + j = np.searchsorted(self.eos_idx, target) + if j >= n: + raise StopIteration("Insufficient EOS ahead; hit tail of 
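shard.")
+                starts[r].append(start)
+                idx = j
+                cur = self.eos_idx[idx] # next seq must also start at a new doc
+        advance = self.eos_idx[idx] - self.pos # move stream to the last end
+        self.pos += advance
+        self.i = idx
+        return starts, advance
+
+# Editor's note: a minimal usage sketch for EOSBatchFinder (not part of the original
+# run; the toy token stream is invented for illustration). Every returned start index
+# points just past an EOS token, so each training sequence begins at a document
+# boundary and has seq_len tokens available before the next EOS. Defined, never called.
+def _eos_batch_finder_demo():
+    eos = 50256
+    # three tiny "documents", each terminated by EOS, after a leading EOS
+    toy = torch.tensor([eos, 1, 2, 3, eos, 4, 5, 6, 7, 8, eos, 9, 10, 11, 12, 13, 14, eos], dtype=torch.uint16)
+    finder = EOSBatchFinder(toy, world_size=1)
+    starts, advance = finder.next_batch(batch_size_local=1, seq_len=3)
+    assert starts[0][0] == 1 # first sequence starts right after the leading EOS
+    return starts, advance
+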
shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + + while True: + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, pos = _load_data_shard(next(file_iter)), 0 + finder = EOSBatchFinder(tokens, world_size=world_size) + continue + + bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] + buf = torch.stack(bufs, dim=0) + _inputs = buf[:, :-1] + _targets = buf[:, 1:] + else: + batch_span = num_tokens_global + start_pos_local = pos + rank * (batch_size_local * seq_len) + end_pos_local = start_pos_local + (batch_size_local * seq_len) + + buf = tokens[start_pos_local: end_pos_local + 1] + + _inputs = buf[:-1].view(batch_size_local, seq_len) + _targets = buf[1:].view(batch_size_local, seq_len) + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) + ) + + pos += batch_span + + if new_params is not None: + # makes it possible for generator to recieve new (batch_size, seq_len) via .send() + new_batch_size, new_seq_len = new_params + assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" + batch_size = new_batch_size + seq_len = new_seq_len + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len: int = 1024 * 2 + train_batch_size: int = 24 * 8 + val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. + # optimization + num_iterations: int = 1695 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end + save_checkpoint: bool = False + # attention masking + bandwidth: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_seq_len, args.val_seq_len) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
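this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr_and_ws(step: int):
+    x = step / (1 + args.num_iterations) # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return lr, args.ws_schedule[ws_idx]
+
+# Editor's note: a worked example of the schedule above (not part of the original
+# run). With num_iterations=1695 and cooldown_frac=0.45, the lr multiplier stays at
+# 1.0 until x reaches 0.55 (around step 933), then decays linearly toward 0.1, while
+# the window-size index steps through ws_schedule=(3, 7, 11) as x passes 1/3 and 2/3
+# of training. Defined, never called.
+def _lr_and_ws_examples():
+    for s in (0, 600, 933, 1300, 1694):
+        lr_mult, ws = get_lr_and_ws(s)
+        print(f"step {s}: lr multiplier {lr_mult:.3f}, window-size multiple {ws}")
+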
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)

+########################################
+#            Warmup kernels            #
+########################################

+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 60
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+for step in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, ws, ws // 2).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state

+########################################
+#      Training and validation         #
+########################################

+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    lr, ws = get_lr_and_ws(step)

+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % (world_size * args.val_seq_len) == 0
+        val_steps = args.val_tokens // (world_size * args.val_seq_len)
+        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets = next(val_loader)
+                val_loss += model(inputs, targets, ws, ws // 2)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()

+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            
os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 03:58:09 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 34C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 30C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 35C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 33C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1695 train_time:516ms step_avg:515.52ms +step:2/1695 train_time:539ms step_avg:269.65ms +step:3/1695 train_time:612ms step_avg:203.90ms +step:4/1695 train_time:704ms step_avg:175.97ms +step:5/1695 train_time:797ms step_avg:159.42ms +step:6/1695 train_time:891ms step_avg:148.48ms +step:7/1695 train_time:984ms step_avg:140.60ms +step:8/1695 train_time:1078ms step_avg:134.78ms +step:9/1695 train_time:1172ms step_avg:130.23ms +step:10/1695 train_time:1265ms step_avg:126.49ms +step:11/1695 train_time:1359ms step_avg:123.52ms +step:12/1695 train_time:1457ms step_avg:121.44ms +step:13/1695 train_time:1555ms step_avg:119.64ms +step:14/1695 train_time:1650ms step_avg:117.89ms +step:15/1695 train_time:1745ms step_avg:116.30ms +step:16/1695 train_time:1839ms step_avg:114.95ms +step:17/1695 train_time:1933ms step_avg:113.72ms +step:18/1695 train_time:2027ms step_avg:112.62ms +step:19/1695 train_time:2122ms step_avg:111.67ms +step:20/1695 train_time:2216ms step_avg:110.82ms +step:21/1695 train_time:2311ms step_avg:110.03ms +step:22/1695 train_time:2405ms step_avg:109.33ms 
+step:23/1695 train_time:2501ms step_avg:108.73ms +step:24/1695 train_time:2597ms step_avg:108.21ms +step:25/1695 train_time:2693ms step_avg:107.73ms +step:26/1695 train_time:2788ms step_avg:107.23ms +step:27/1695 train_time:2882ms step_avg:106.74ms +step:28/1695 train_time:2977ms step_avg:106.33ms +step:29/1695 train_time:3071ms step_avg:105.90ms +step:30/1695 train_time:3165ms step_avg:105.50ms +step:31/1695 train_time:3259ms step_avg:105.14ms +step:32/1695 train_time:3355ms step_avg:104.84ms +step:33/1695 train_time:3449ms step_avg:104.53ms +step:34/1695 train_time:3545ms step_avg:104.26ms +step:35/1695 train_time:3640ms step_avg:104.01ms +step:36/1695 train_time:3736ms step_avg:103.78ms +step:37/1695 train_time:3831ms step_avg:103.53ms +step:38/1695 train_time:3925ms step_avg:103.28ms +step:39/1695 train_time:4019ms step_avg:103.06ms +step:40/1695 train_time:4113ms step_avg:102.84ms +step:41/1695 train_time:4207ms step_avg:102.60ms +step:42/1695 train_time:4301ms step_avg:102.40ms +step:43/1695 train_time:4396ms step_avg:102.23ms +step:44/1695 train_time:4491ms step_avg:102.07ms +step:45/1695 train_time:4587ms step_avg:101.92ms +step:46/1695 train_time:4681ms step_avg:101.76ms +step:47/1695 train_time:4777ms step_avg:101.63ms +step:48/1695 train_time:4873ms step_avg:101.53ms +step:49/1695 train_time:4966ms step_avg:101.35ms +step:50/1695 train_time:5060ms step_avg:101.21ms +step:51/1695 train_time:5155ms step_avg:101.07ms +step:52/1695 train_time:5249ms step_avg:100.94ms +step:53/1695 train_time:5343ms step_avg:100.82ms +step:54/1695 train_time:5439ms step_avg:100.72ms +step:55/1695 train_time:5534ms step_avg:100.62ms +step:56/1695 train_time:5629ms step_avg:100.51ms +step:57/1695 train_time:5723ms step_avg:100.41ms +step:58/1695 train_time:5818ms step_avg:100.32ms +step:59/1695 train_time:5913ms step_avg:100.23ms +step:60/1695 train_time:6007ms step_avg:100.11ms +step:61/1695 train_time:6100ms step_avg:100.00ms +step:62/1695 train_time:6196ms step_avg:99.93ms +step:63/1695 train_time:6290ms step_avg:99.84ms +step:64/1695 train_time:6384ms step_avg:99.75ms +step:65/1695 train_time:6479ms step_avg:99.68ms +step:66/1695 train_time:6573ms step_avg:99.59ms +step:67/1695 train_time:6667ms step_avg:99.51ms +step:68/1695 train_time:6762ms step_avg:99.44ms +step:69/1695 train_time:6856ms step_avg:99.37ms +step:70/1695 train_time:6950ms step_avg:99.29ms +step:71/1695 train_time:7044ms step_avg:99.21ms +step:72/1695 train_time:7139ms step_avg:99.16ms +step:73/1695 train_time:7234ms step_avg:99.10ms +step:74/1695 train_time:7329ms step_avg:99.04ms +step:75/1695 train_time:7423ms step_avg:98.98ms +step:76/1695 train_time:7519ms step_avg:98.94ms +step:77/1695 train_time:7614ms step_avg:98.88ms +step:78/1695 train_time:7709ms step_avg:98.83ms +step:79/1695 train_time:7803ms step_avg:98.77ms +step:80/1695 train_time:7897ms step_avg:98.71ms +step:81/1695 train_time:7991ms step_avg:98.66ms +step:82/1695 train_time:8085ms step_avg:98.60ms +step:83/1695 train_time:8179ms step_avg:98.55ms +step:84/1695 train_time:8274ms step_avg:98.50ms +step:85/1695 train_time:8368ms step_avg:98.45ms +step:86/1695 train_time:8462ms step_avg:98.39ms +step:87/1695 train_time:8558ms step_avg:98.36ms +step:88/1695 train_time:8653ms step_avg:98.33ms +step:89/1695 train_time:8747ms step_avg:98.28ms +step:90/1695 train_time:8841ms step_avg:98.24ms +step:91/1695 train_time:8936ms step_avg:98.20ms +step:92/1695 train_time:9031ms step_avg:98.16ms +step:93/1695 train_time:9125ms step_avg:98.12ms +step:94/1695 train_time:9219ms 
step_avg:98.08ms +step:95/1695 train_time:9313ms step_avg:98.03ms +step:96/1695 train_time:9406ms step_avg:97.98ms +step:97/1695 train_time:9500ms step_avg:97.94ms +step:98/1695 train_time:9596ms step_avg:97.91ms +step:99/1695 train_time:9690ms step_avg:97.88ms +step:100/1695 train_time:9784ms step_avg:97.84ms +step:101/1695 train_time:9878ms step_avg:97.80ms +step:102/1695 train_time:9973ms step_avg:97.78ms +step:103/1695 train_time:10067ms step_avg:97.74ms +step:104/1695 train_time:10162ms step_avg:97.71ms +step:105/1695 train_time:10256ms step_avg:97.67ms +step:106/1695 train_time:10350ms step_avg:97.64ms +step:107/1695 train_time:10444ms step_avg:97.60ms +step:108/1695 train_time:10540ms step_avg:97.59ms +step:109/1695 train_time:10635ms step_avg:97.57ms +step:110/1695 train_time:10731ms step_avg:97.55ms +step:111/1695 train_time:10825ms step_avg:97.52ms +step:112/1695 train_time:10920ms step_avg:97.50ms +step:113/1695 train_time:11014ms step_avg:97.47ms +step:114/1695 train_time:11108ms step_avg:97.44ms +step:115/1695 train_time:11201ms step_avg:97.40ms +step:116/1695 train_time:11297ms step_avg:97.38ms +step:117/1695 train_time:11391ms step_avg:97.36ms +step:118/1695 train_time:11485ms step_avg:97.33ms +step:119/1695 train_time:11580ms step_avg:97.31ms +step:120/1695 train_time:11675ms step_avg:97.29ms +step:121/1695 train_time:11769ms step_avg:97.27ms +step:122/1695 train_time:11863ms step_avg:97.24ms +step:123/1695 train_time:11959ms step_avg:97.22ms +step:124/1695 train_time:12054ms step_avg:97.21ms +step:125/1695 train_time:12148ms step_avg:97.18ms +step:125/1695 val_loss:4.3195 train_time:12239ms step_avg:97.92ms +step:126/1695 train_time:12266ms step_avg:97.35ms +step:127/1695 train_time:12343ms step_avg:97.19ms +step:128/1695 train_time:12442ms step_avg:97.21ms +step:129/1695 train_time:12537ms step_avg:97.19ms +step:130/1695 train_time:12631ms step_avg:97.16ms +step:131/1695 train_time:12725ms step_avg:97.14ms +step:132/1695 train_time:12818ms step_avg:97.11ms +step:133/1695 train_time:12911ms step_avg:97.07ms +step:134/1695 train_time:13005ms step_avg:97.05ms +step:135/1695 train_time:13098ms step_avg:97.02ms +step:136/1695 train_time:13191ms step_avg:97.00ms +step:137/1695 train_time:13287ms step_avg:96.99ms +step:138/1695 train_time:13385ms step_avg:97.00ms +step:139/1695 train_time:13481ms step_avg:96.98ms +step:140/1695 train_time:13575ms step_avg:96.97ms +step:141/1695 train_time:13669ms step_avg:96.94ms +step:142/1695 train_time:13763ms step_avg:96.92ms +step:143/1695 train_time:13856ms step_avg:96.90ms +step:144/1695 train_time:13949ms step_avg:96.87ms +step:145/1695 train_time:14043ms step_avg:96.85ms +step:146/1695 train_time:14136ms step_avg:96.82ms +step:147/1695 train_time:14230ms step_avg:96.80ms +step:148/1695 train_time:14326ms step_avg:96.79ms +step:149/1695 train_time:14422ms step_avg:96.79ms +step:150/1695 train_time:14517ms step_avg:96.78ms +step:151/1695 train_time:14612ms step_avg:96.77ms +step:152/1695 train_time:14707ms step_avg:96.75ms +step:153/1695 train_time:14801ms step_avg:96.74ms +step:154/1695 train_time:14895ms step_avg:96.72ms +step:155/1695 train_time:14989ms step_avg:96.70ms +step:156/1695 train_time:15083ms step_avg:96.69ms +step:157/1695 train_time:15176ms step_avg:96.66ms +step:158/1695 train_time:15270ms step_avg:96.65ms +step:159/1695 train_time:15365ms step_avg:96.63ms +step:160/1695 train_time:15461ms step_avg:96.63ms +step:161/1695 train_time:15556ms step_avg:96.62ms +step:162/1695 train_time:15650ms step_avg:96.60ms +step:163/1695 
train_time:15744ms step_avg:96.59ms +step:164/1695 train_time:15839ms step_avg:96.58ms +step:165/1695 train_time:15933ms step_avg:96.56ms +step:166/1695 train_time:16027ms step_avg:96.55ms +step:167/1695 train_time:16121ms step_avg:96.53ms +step:168/1695 train_time:16215ms step_avg:96.52ms +step:169/1695 train_time:16309ms step_avg:96.50ms +step:170/1695 train_time:16404ms step_avg:96.50ms +step:171/1695 train_time:16499ms step_avg:96.49ms +step:172/1695 train_time:16594ms step_avg:96.47ms +step:173/1695 train_time:16931ms step_avg:97.87ms +step:174/1695 train_time:17020ms step_avg:97.82ms +step:175/1695 train_time:17114ms step_avg:97.79ms +step:176/1695 train_time:17207ms step_avg:97.77ms +step:177/1695 train_time:17301ms step_avg:97.74ms +step:178/1695 train_time:17394ms step_avg:97.72ms +step:179/1695 train_time:17487ms step_avg:97.69ms +step:180/1695 train_time:17581ms step_avg:97.67ms +step:181/1695 train_time:17673ms step_avg:97.64ms +step:182/1695 train_time:17767ms step_avg:97.62ms +step:183/1695 train_time:17864ms step_avg:97.62ms +step:184/1695 train_time:17961ms step_avg:97.61ms +step:185/1695 train_time:18055ms step_avg:97.60ms +step:186/1695 train_time:18149ms step_avg:97.58ms +step:187/1695 train_time:18244ms step_avg:97.56ms +step:188/1695 train_time:18339ms step_avg:97.55ms +step:189/1695 train_time:18432ms step_avg:97.52ms +step:190/1695 train_time:18526ms step_avg:97.50ms +step:191/1695 train_time:18620ms step_avg:97.49ms +step:192/1695 train_time:18713ms step_avg:97.46ms +step:193/1695 train_time:18808ms step_avg:97.45ms +step:194/1695 train_time:18904ms step_avg:97.44ms +step:195/1695 train_time:19000ms step_avg:97.43ms +step:196/1695 train_time:19093ms step_avg:97.42ms +step:197/1695 train_time:19188ms step_avg:97.40ms +step:198/1695 train_time:19283ms step_avg:97.39ms +step:199/1695 train_time:19377ms step_avg:97.37ms +step:200/1695 train_time:19470ms step_avg:97.35ms +step:201/1695 train_time:19564ms step_avg:97.33ms +step:202/1695 train_time:19657ms step_avg:97.31ms +step:203/1695 train_time:19751ms step_avg:97.30ms +step:204/1695 train_time:19847ms step_avg:97.29ms +step:205/1695 train_time:19942ms step_avg:97.28ms +step:206/1695 train_time:20036ms step_avg:97.26ms +step:207/1695 train_time:20129ms step_avg:97.24ms +step:208/1695 train_time:20224ms step_avg:97.23ms +step:209/1695 train_time:20318ms step_avg:97.21ms +step:210/1695 train_time:20411ms step_avg:97.20ms +step:211/1695 train_time:20506ms step_avg:97.18ms +step:212/1695 train_time:20601ms step_avg:97.17ms +step:213/1695 train_time:20693ms step_avg:97.15ms +step:214/1695 train_time:20787ms step_avg:97.14ms +step:215/1695 train_time:20883ms step_avg:97.13ms +step:216/1695 train_time:20977ms step_avg:97.12ms +step:217/1695 train_time:21071ms step_avg:97.10ms +step:218/1695 train_time:21166ms step_avg:97.09ms +step:219/1695 train_time:21261ms step_avg:97.08ms +step:220/1695 train_time:21355ms step_avg:97.07ms +step:221/1695 train_time:21449ms step_avg:97.05ms +step:222/1695 train_time:21544ms step_avg:97.04ms +step:223/1695 train_time:21637ms step_avg:97.03ms +step:224/1695 train_time:21730ms step_avg:97.01ms +step:225/1695 train_time:21825ms step_avg:97.00ms +step:226/1695 train_time:21919ms step_avg:96.99ms +step:227/1695 train_time:22013ms step_avg:96.97ms +step:228/1695 train_time:22108ms step_avg:96.97ms +step:229/1695 train_time:22203ms step_avg:96.96ms +step:230/1695 train_time:22298ms step_avg:96.95ms +step:231/1695 train_time:22391ms step_avg:96.93ms +step:232/1695 train_time:22486ms step_avg:96.92ms 
+step:233/1695 train_time:22580ms step_avg:96.91ms +step:234/1695 train_time:22673ms step_avg:96.89ms +step:235/1695 train_time:22767ms step_avg:96.88ms +step:236/1695 train_time:22862ms step_avg:96.87ms +step:237/1695 train_time:22957ms step_avg:96.86ms +step:238/1695 train_time:23051ms step_avg:96.85ms +step:239/1695 train_time:23146ms step_avg:96.85ms +step:240/1695 train_time:23242ms step_avg:96.84ms +step:241/1695 train_time:23336ms step_avg:96.83ms +step:242/1695 train_time:23430ms step_avg:96.82ms +step:243/1695 train_time:23526ms step_avg:96.81ms +step:244/1695 train_time:23621ms step_avg:96.81ms +step:245/1695 train_time:23715ms step_avg:96.79ms +step:246/1695 train_time:23809ms step_avg:96.78ms +step:247/1695 train_time:23902ms step_avg:96.77ms +step:248/1695 train_time:23996ms step_avg:96.76ms +step:249/1695 train_time:24089ms step_avg:96.74ms +step:250/1695 train_time:24184ms step_avg:96.73ms +step:250/1695 val_loss:3.9759 train_time:24277ms step_avg:97.11ms +step:251/1695 train_time:24301ms step_avg:96.82ms +step:252/1695 train_time:24381ms step_avg:96.75ms +step:253/1695 train_time:24478ms step_avg:96.75ms +step:254/1695 train_time:24574ms step_avg:96.75ms +step:255/1695 train_time:24668ms step_avg:96.74ms +step:256/1695 train_time:24760ms step_avg:96.72ms +step:257/1695 train_time:24854ms step_avg:96.71ms +step:258/1695 train_time:24948ms step_avg:96.70ms +step:259/1695 train_time:25041ms step_avg:96.68ms +step:260/1695 train_time:25134ms step_avg:96.67ms +step:261/1695 train_time:25228ms step_avg:96.66ms +step:262/1695 train_time:25322ms step_avg:96.65ms +step:263/1695 train_time:25418ms step_avg:96.65ms +step:264/1695 train_time:25514ms step_avg:96.64ms +step:265/1695 train_time:25609ms step_avg:96.64ms +step:266/1695 train_time:25703ms step_avg:96.63ms +step:267/1695 train_time:25796ms step_avg:96.62ms +step:268/1695 train_time:25890ms step_avg:96.60ms +step:269/1695 train_time:25983ms step_avg:96.59ms +step:270/1695 train_time:26076ms step_avg:96.58ms +step:271/1695 train_time:26169ms step_avg:96.57ms +step:272/1695 train_time:26264ms step_avg:96.56ms +step:273/1695 train_time:26357ms step_avg:96.55ms +step:274/1695 train_time:26453ms step_avg:96.54ms +step:275/1695 train_time:26549ms step_avg:96.54ms +step:276/1695 train_time:26643ms step_avg:96.53ms +step:277/1695 train_time:26737ms step_avg:96.52ms +step:278/1695 train_time:26831ms step_avg:96.51ms +step:279/1695 train_time:26924ms step_avg:96.50ms +step:280/1695 train_time:27017ms step_avg:96.49ms +step:281/1695 train_time:27111ms step_avg:96.48ms +step:282/1695 train_time:27205ms step_avg:96.47ms +step:283/1695 train_time:27298ms step_avg:96.46ms +step:284/1695 train_time:27393ms step_avg:96.45ms +step:285/1695 train_time:27488ms step_avg:96.45ms +step:286/1695 train_time:27582ms step_avg:96.44ms +step:287/1695 train_time:27677ms step_avg:96.43ms +step:288/1695 train_time:27772ms step_avg:96.43ms +step:289/1695 train_time:27868ms step_avg:96.43ms +step:290/1695 train_time:27961ms step_avg:96.42ms +step:291/1695 train_time:28055ms step_avg:96.41ms +step:292/1695 train_time:28149ms step_avg:96.40ms +step:293/1695 train_time:28242ms step_avg:96.39ms +step:294/1695 train_time:28336ms step_avg:96.38ms +step:295/1695 train_time:28431ms step_avg:96.38ms +step:296/1695 train_time:28526ms step_avg:96.37ms +step:297/1695 train_time:28619ms step_avg:96.36ms +step:298/1695 train_time:28713ms step_avg:96.35ms +step:299/1695 train_time:28809ms step_avg:96.35ms +step:300/1695 train_time:28903ms step_avg:96.34ms +step:301/1695 
train_time:28997ms step_avg:96.33ms +step:302/1695 train_time:29091ms step_avg:96.33ms +step:303/1695 train_time:29186ms step_avg:96.32ms +step:304/1695 train_time:29279ms step_avg:96.31ms +step:305/1695 train_time:29374ms step_avg:96.31ms +step:306/1695 train_time:29469ms step_avg:96.31ms +step:307/1695 train_time:29565ms step_avg:96.30ms +step:308/1695 train_time:29659ms step_avg:96.30ms +step:309/1695 train_time:29753ms step_avg:96.29ms +step:310/1695 train_time:29848ms step_avg:96.29ms +step:311/1695 train_time:29942ms step_avg:96.28ms +step:312/1695 train_time:30036ms step_avg:96.27ms +step:313/1695 train_time:30131ms step_avg:96.26ms +step:314/1695 train_time:30224ms step_avg:96.25ms +step:315/1695 train_time:30317ms step_avg:96.25ms +step:316/1695 train_time:30412ms step_avg:96.24ms +step:317/1695 train_time:30508ms step_avg:96.24ms +step:318/1695 train_time:30602ms step_avg:96.23ms +step:319/1695 train_time:30696ms step_avg:96.23ms +step:320/1695 train_time:30790ms step_avg:96.22ms +step:321/1695 train_time:30884ms step_avg:96.21ms +step:322/1695 train_time:30978ms step_avg:96.21ms +step:323/1695 train_time:31073ms step_avg:96.20ms +step:324/1695 train_time:31168ms step_avg:96.20ms +step:325/1695 train_time:31261ms step_avg:96.19ms +step:326/1695 train_time:31354ms step_avg:96.18ms +step:327/1695 train_time:31448ms step_avg:96.17ms +step:328/1695 train_time:31543ms step_avg:96.17ms +step:329/1695 train_time:31637ms step_avg:96.16ms +step:330/1695 train_time:31733ms step_avg:96.16ms +step:331/1695 train_time:31828ms step_avg:96.16ms +step:332/1695 train_time:31921ms step_avg:96.15ms +step:333/1695 train_time:32015ms step_avg:96.14ms +step:334/1695 train_time:32109ms step_avg:96.14ms +step:335/1695 train_time:32203ms step_avg:96.13ms +step:336/1695 train_time:32296ms step_avg:96.12ms +step:337/1695 train_time:32390ms step_avg:96.11ms +step:338/1695 train_time:32483ms step_avg:96.10ms +step:339/1695 train_time:32577ms step_avg:96.10ms +step:340/1695 train_time:32672ms step_avg:96.09ms +step:341/1695 train_time:32767ms step_avg:96.09ms +step:342/1695 train_time:32861ms step_avg:96.08ms +step:343/1695 train_time:32956ms step_avg:96.08ms +step:344/1695 train_time:33050ms step_avg:96.08ms +step:345/1695 train_time:33388ms step_avg:96.78ms +step:346/1695 train_time:33462ms step_avg:96.71ms +step:347/1695 train_time:33554ms step_avg:96.70ms +step:348/1695 train_time:33647ms step_avg:96.69ms +step:349/1695 train_time:33740ms step_avg:96.68ms +step:350/1695 train_time:33833ms step_avg:96.67ms +step:351/1695 train_time:33927ms step_avg:96.66ms +step:352/1695 train_time:34019ms step_avg:96.65ms +step:353/1695 train_time:34112ms step_avg:96.64ms +step:354/1695 train_time:34205ms step_avg:96.62ms +step:355/1695 train_time:34300ms step_avg:96.62ms +step:356/1695 train_time:34397ms step_avg:96.62ms +step:357/1695 train_time:34494ms step_avg:96.62ms +step:358/1695 train_time:34589ms step_avg:96.62ms +step:359/1695 train_time:34682ms step_avg:96.61ms +step:360/1695 train_time:34775ms step_avg:96.60ms +step:361/1695 train_time:34868ms step_avg:96.59ms +step:362/1695 train_time:34961ms step_avg:96.58ms +step:363/1695 train_time:35055ms step_avg:96.57ms +step:364/1695 train_time:35148ms step_avg:96.56ms +step:365/1695 train_time:35243ms step_avg:96.56ms +step:366/1695 train_time:35338ms step_avg:96.55ms +step:367/1695 train_time:35434ms step_avg:96.55ms +step:368/1695 train_time:35530ms step_avg:96.55ms +step:369/1695 train_time:35624ms step_avg:96.54ms +step:370/1695 train_time:35717ms step_avg:96.53ms 
+step:371/1695 train_time:35811ms step_avg:96.53ms +step:372/1695 train_time:35905ms step_avg:96.52ms +step:373/1695 train_time:35998ms step_avg:96.51ms +step:374/1695 train_time:36092ms step_avg:96.50ms +step:375/1695 train_time:36187ms step_avg:96.50ms +step:375/1695 val_loss:3.8237 train_time:36278ms step_avg:96.74ms +step:376/1695 train_time:36303ms step_avg:96.55ms +step:377/1695 train_time:36381ms step_avg:96.50ms +step:378/1695 train_time:36478ms step_avg:96.50ms +step:379/1695 train_time:36573ms step_avg:96.50ms +step:380/1695 train_time:36667ms step_avg:96.49ms +step:381/1695 train_time:36761ms step_avg:96.48ms +step:382/1695 train_time:36854ms step_avg:96.48ms +step:383/1695 train_time:36948ms step_avg:96.47ms +step:384/1695 train_time:37040ms step_avg:96.46ms +step:385/1695 train_time:37133ms step_avg:96.45ms +step:386/1695 train_time:37227ms step_avg:96.44ms +step:387/1695 train_time:37323ms step_avg:96.44ms +step:388/1695 train_time:37419ms step_avg:96.44ms +step:389/1695 train_time:37514ms step_avg:96.44ms +step:390/1695 train_time:37609ms step_avg:96.43ms +step:391/1695 train_time:37703ms step_avg:96.43ms +step:392/1695 train_time:37795ms step_avg:96.42ms +step:393/1695 train_time:37890ms step_avg:96.41ms +step:394/1695 train_time:37982ms step_avg:96.40ms +step:395/1695 train_time:38075ms step_avg:96.39ms +step:396/1695 train_time:38169ms step_avg:96.39ms +step:397/1695 train_time:38263ms step_avg:96.38ms +step:398/1695 train_time:38358ms step_avg:96.38ms +step:399/1695 train_time:38453ms step_avg:96.37ms +step:400/1695 train_time:38548ms step_avg:96.37ms +step:401/1695 train_time:38642ms step_avg:96.36ms +step:402/1695 train_time:38736ms step_avg:96.36ms +step:403/1695 train_time:38830ms step_avg:96.35ms +step:404/1695 train_time:38924ms step_avg:96.35ms +step:405/1695 train_time:39017ms step_avg:96.34ms +step:406/1695 train_time:39111ms step_avg:96.33ms +step:407/1695 train_time:39205ms step_avg:96.33ms +step:408/1695 train_time:39298ms step_avg:96.32ms +step:409/1695 train_time:39394ms step_avg:96.32ms +step:410/1695 train_time:39489ms step_avg:96.31ms +step:411/1695 train_time:39583ms step_avg:96.31ms +step:412/1695 train_time:39676ms step_avg:96.30ms +step:413/1695 train_time:39771ms step_avg:96.30ms +step:414/1695 train_time:39865ms step_avg:96.29ms +step:415/1695 train_time:39959ms step_avg:96.29ms +step:416/1695 train_time:40053ms step_avg:96.28ms +step:417/1695 train_time:40148ms step_avg:96.28ms +step:418/1695 train_time:40241ms step_avg:96.27ms +step:419/1695 train_time:40335ms step_avg:96.27ms +step:420/1695 train_time:40430ms step_avg:96.26ms +step:421/1695 train_time:40525ms step_avg:96.26ms +step:422/1695 train_time:40618ms step_avg:96.25ms +step:423/1695 train_time:40712ms step_avg:96.25ms +step:424/1695 train_time:40806ms step_avg:96.24ms +step:425/1695 train_time:40900ms step_avg:96.23ms +step:426/1695 train_time:40994ms step_avg:96.23ms +step:427/1695 train_time:41088ms step_avg:96.23ms +step:428/1695 train_time:41181ms step_avg:96.22ms +step:429/1695 train_time:41275ms step_avg:96.21ms +step:430/1695 train_time:41370ms step_avg:96.21ms +step:431/1695 train_time:41465ms step_avg:96.21ms +step:432/1695 train_time:41559ms step_avg:96.20ms +step:433/1695 train_time:41654ms step_avg:96.20ms +step:434/1695 train_time:41748ms step_avg:96.19ms +step:435/1695 train_time:41841ms step_avg:96.19ms +step:436/1695 train_time:41934ms step_avg:96.18ms +step:437/1695 train_time:42028ms step_avg:96.17ms +step:438/1695 train_time:42121ms step_avg:96.17ms +step:439/1695 
train_time:42215ms step_avg:96.16ms +step:440/1695 train_time:42309ms step_avg:96.16ms +step:441/1695 train_time:42403ms step_avg:96.15ms +step:442/1695 train_time:42497ms step_avg:96.15ms +step:443/1695 train_time:42592ms step_avg:96.15ms +step:444/1695 train_time:42687ms step_avg:96.14ms +step:445/1695 train_time:42781ms step_avg:96.14ms +step:446/1695 train_time:42874ms step_avg:96.13ms +step:447/1695 train_time:42968ms step_avg:96.12ms +step:448/1695 train_time:43061ms step_avg:96.12ms +step:449/1695 train_time:43155ms step_avg:96.11ms +step:450/1695 train_time:43249ms step_avg:96.11ms +step:451/1695 train_time:43343ms step_avg:96.10ms +step:452/1695 train_time:43437ms step_avg:96.10ms +step:453/1695 train_time:43531ms step_avg:96.09ms +step:454/1695 train_time:43625ms step_avg:96.09ms +step:455/1695 train_time:43719ms step_avg:96.08ms +step:456/1695 train_time:43813ms step_avg:96.08ms +step:457/1695 train_time:43907ms step_avg:96.08ms +step:458/1695 train_time:44000ms step_avg:96.07ms +step:459/1695 train_time:44095ms step_avg:96.07ms +step:460/1695 train_time:44189ms step_avg:96.06ms +step:461/1695 train_time:44283ms step_avg:96.06ms +step:462/1695 train_time:44377ms step_avg:96.05ms +step:463/1695 train_time:44471ms step_avg:96.05ms +step:464/1695 train_time:44566ms step_avg:96.05ms +step:465/1695 train_time:44660ms step_avg:96.04ms +step:466/1695 train_time:44754ms step_avg:96.04ms +step:467/1695 train_time:44849ms step_avg:96.04ms +step:468/1695 train_time:44942ms step_avg:96.03ms +step:469/1695 train_time:45036ms step_avg:96.03ms +step:470/1695 train_time:45131ms step_avg:96.02ms +step:471/1695 train_time:45224ms step_avg:96.02ms +step:472/1695 train_time:45318ms step_avg:96.01ms +step:473/1695 train_time:45413ms step_avg:96.01ms +step:474/1695 train_time:45507ms step_avg:96.01ms +step:475/1695 train_time:45601ms step_avg:96.00ms +step:476/1695 train_time:45695ms step_avg:96.00ms +step:477/1695 train_time:45789ms step_avg:95.99ms +step:478/1695 train_time:45882ms step_avg:95.99ms +step:479/1695 train_time:45976ms step_avg:95.98ms +step:480/1695 train_time:46071ms step_avg:95.98ms +step:481/1695 train_time:46165ms step_avg:95.98ms +step:482/1695 train_time:46259ms step_avg:95.97ms +step:483/1695 train_time:46353ms step_avg:95.97ms +step:484/1695 train_time:46449ms step_avg:95.97ms +step:485/1695 train_time:46542ms step_avg:95.96ms +step:486/1695 train_time:46636ms step_avg:95.96ms +step:487/1695 train_time:46731ms step_avg:95.96ms +step:488/1695 train_time:46825ms step_avg:95.95ms +step:489/1695 train_time:46918ms step_avg:95.95ms +step:490/1695 train_time:47013ms step_avg:95.94ms +step:491/1695 train_time:47107ms step_avg:95.94ms +step:492/1695 train_time:47200ms step_avg:95.94ms +step:493/1695 train_time:47295ms step_avg:95.93ms +step:494/1695 train_time:47390ms step_avg:95.93ms +step:495/1695 train_time:47484ms step_avg:95.93ms +step:496/1695 train_time:47577ms step_avg:95.92ms +step:497/1695 train_time:47672ms step_avg:95.92ms +step:498/1695 train_time:47766ms step_avg:95.92ms +step:499/1695 train_time:47860ms step_avg:95.91ms +step:500/1695 train_time:47954ms step_avg:95.91ms +step:500/1695 val_loss:3.7206 train_time:48046ms step_avg:96.09ms +step:501/1695 train_time:48071ms step_avg:95.95ms +step:502/1695 train_time:48149ms step_avg:95.91ms +step:503/1695 train_time:48247ms step_avg:95.92ms +step:504/1695 train_time:48342ms step_avg:95.92ms +step:505/1695 train_time:48436ms step_avg:95.91ms +step:506/1695 train_time:48529ms step_avg:95.91ms +step:507/1695 train_time:48622ms 
step_avg:95.90ms +step:508/1695 train_time:48715ms step_avg:95.90ms +step:509/1695 train_time:48808ms step_avg:95.89ms +step:510/1695 train_time:48901ms step_avg:95.88ms +step:511/1695 train_time:48994ms step_avg:95.88ms +step:512/1695 train_time:49090ms step_avg:95.88ms +step:513/1695 train_time:49185ms step_avg:95.88ms +step:514/1695 train_time:49281ms step_avg:95.88ms +step:515/1695 train_time:49377ms step_avg:95.88ms +step:516/1695 train_time:49471ms step_avg:95.87ms +step:517/1695 train_time:49564ms step_avg:95.87ms +step:518/1695 train_time:49657ms step_avg:95.86ms +step:519/1695 train_time:49991ms step_avg:96.32ms +step:520/1695 train_time:50182ms step_avg:96.50ms +step:521/1695 train_time:50274ms step_avg:96.50ms +step:522/1695 train_time:50366ms step_avg:96.49ms +step:523/1695 train_time:50458ms step_avg:96.48ms +step:524/1695 train_time:50551ms step_avg:96.47ms +step:525/1695 train_time:50644ms step_avg:96.46ms +step:526/1695 train_time:50737ms step_avg:96.46ms +step:527/1695 train_time:50829ms step_avg:96.45ms +step:528/1695 train_time:50922ms step_avg:96.44ms +step:529/1695 train_time:51018ms step_avg:96.44ms +step:530/1695 train_time:51114ms step_avg:96.44ms +step:531/1695 train_time:51210ms step_avg:96.44ms +step:532/1695 train_time:51305ms step_avg:96.44ms +step:533/1695 train_time:51399ms step_avg:96.43ms +step:534/1695 train_time:51493ms step_avg:96.43ms +step:535/1695 train_time:51586ms step_avg:96.42ms +step:536/1695 train_time:51679ms step_avg:96.42ms +step:537/1695 train_time:51772ms step_avg:96.41ms +step:538/1695 train_time:51865ms step_avg:96.40ms +step:539/1695 train_time:51958ms step_avg:96.40ms +step:540/1695 train_time:52052ms step_avg:96.39ms +step:541/1695 train_time:52147ms step_avg:96.39ms +step:542/1695 train_time:52241ms step_avg:96.39ms +step:543/1695 train_time:52336ms step_avg:96.38ms +step:544/1695 train_time:52429ms step_avg:96.38ms +step:545/1695 train_time:52523ms step_avg:96.37ms +step:546/1695 train_time:52617ms step_avg:96.37ms +step:547/1695 train_time:52710ms step_avg:96.36ms +step:548/1695 train_time:52803ms step_avg:96.36ms +step:549/1695 train_time:52898ms step_avg:96.35ms +step:550/1695 train_time:52991ms step_avg:96.35ms +step:551/1695 train_time:53086ms step_avg:96.34ms +step:552/1695 train_time:53182ms step_avg:96.34ms +step:553/1695 train_time:53276ms step_avg:96.34ms +step:554/1695 train_time:53370ms step_avg:96.34ms +step:555/1695 train_time:53464ms step_avg:96.33ms +step:556/1695 train_time:53558ms step_avg:96.33ms +step:557/1695 train_time:53652ms step_avg:96.32ms +step:558/1695 train_time:53745ms step_avg:96.32ms +step:559/1695 train_time:53839ms step_avg:96.31ms +step:560/1695 train_time:53933ms step_avg:96.31ms +step:561/1695 train_time:54027ms step_avg:96.30ms +step:562/1695 train_time:54121ms step_avg:96.30ms +step:563/1695 train_time:54216ms step_avg:96.30ms +step:564/1695 train_time:54311ms step_avg:96.30ms +step:565/1695 train_time:54405ms step_avg:96.29ms +step:566/1695 train_time:54499ms step_avg:96.29ms +step:567/1695 train_time:54594ms step_avg:96.29ms +step:568/1695 train_time:54690ms step_avg:96.28ms +step:569/1695 train_time:54785ms step_avg:96.28ms +step:570/1695 train_time:54881ms step_avg:96.28ms +step:571/1695 train_time:54978ms step_avg:96.28ms +step:572/1695 train_time:55074ms step_avg:96.28ms +step:573/1695 train_time:55169ms step_avg:96.28ms +step:574/1695 train_time:55265ms step_avg:96.28ms +step:575/1695 train_time:55363ms step_avg:96.28ms +step:576/1695 train_time:55460ms step_avg:96.28ms +step:577/1695 
train_time:55557ms step_avg:96.29ms +step:578/1695 train_time:55653ms step_avg:96.29ms +step:579/1695 train_time:55748ms step_avg:96.28ms +step:580/1695 train_time:55845ms step_avg:96.28ms +step:581/1695 train_time:55940ms step_avg:96.28ms +step:582/1695 train_time:56037ms step_avg:96.28ms +step:583/1695 train_time:56132ms step_avg:96.28ms +step:584/1695 train_time:56228ms step_avg:96.28ms +step:585/1695 train_time:56324ms step_avg:96.28ms +step:586/1695 train_time:56421ms step_avg:96.28ms +step:587/1695 train_time:56519ms step_avg:96.28ms +step:588/1695 train_time:56617ms step_avg:96.29ms +step:589/1695 train_time:56712ms step_avg:96.29ms +step:590/1695 train_time:56807ms step_avg:96.28ms +step:591/1695 train_time:56903ms step_avg:96.28ms +step:592/1695 train_time:57000ms step_avg:96.28ms +step:593/1695 train_time:57096ms step_avg:96.28ms +step:594/1695 train_time:57192ms step_avg:96.28ms +step:595/1695 train_time:57288ms step_avg:96.28ms +step:596/1695 train_time:57384ms step_avg:96.28ms +step:597/1695 train_time:57481ms step_avg:96.28ms +step:598/1695 train_time:57577ms step_avg:96.28ms +step:599/1695 train_time:57673ms step_avg:96.28ms +step:600/1695 train_time:57769ms step_avg:96.28ms +step:601/1695 train_time:57864ms step_avg:96.28ms +step:602/1695 train_time:57960ms step_avg:96.28ms +step:603/1695 train_time:58057ms step_avg:96.28ms +step:604/1695 train_time:58153ms step_avg:96.28ms +step:605/1695 train_time:58248ms step_avg:96.28ms +step:606/1695 train_time:58344ms step_avg:96.28ms +step:607/1695 train_time:58440ms step_avg:96.28ms +step:608/1695 train_time:58537ms step_avg:96.28ms +step:609/1695 train_time:58633ms step_avg:96.28ms +step:610/1695 train_time:58728ms step_avg:96.28ms +step:611/1695 train_time:58824ms step_avg:96.27ms +step:612/1695 train_time:58920ms step_avg:96.27ms +step:613/1695 train_time:59016ms step_avg:96.27ms +step:614/1695 train_time:59111ms step_avg:96.27ms +step:615/1695 train_time:59208ms step_avg:96.27ms +step:616/1695 train_time:59305ms step_avg:96.27ms +step:617/1695 train_time:59401ms step_avg:96.27ms +step:618/1695 train_time:59499ms step_avg:96.28ms +step:619/1695 train_time:59596ms step_avg:96.28ms +step:620/1695 train_time:59692ms step_avg:96.28ms +step:621/1695 train_time:59789ms step_avg:96.28ms +step:622/1695 train_time:59885ms step_avg:96.28ms +step:623/1695 train_time:59982ms step_avg:96.28ms +step:624/1695 train_time:60079ms step_avg:96.28ms +step:625/1695 train_time:60176ms step_avg:96.28ms +step:625/1695 val_loss:3.6228 train_time:60270ms step_avg:96.43ms +step:626/1695 train_time:60294ms step_avg:96.32ms +step:627/1695 train_time:60378ms step_avg:96.30ms +step:628/1695 train_time:60473ms step_avg:96.30ms +step:629/1695 train_time:60570ms step_avg:96.30ms +step:630/1695 train_time:60665ms step_avg:96.29ms +step:631/1695 train_time:60761ms step_avg:96.29ms +step:632/1695 train_time:60855ms step_avg:96.29ms +step:633/1695 train_time:60949ms step_avg:96.29ms +step:634/1695 train_time:61044ms step_avg:96.28ms +step:635/1695 train_time:61139ms step_avg:96.28ms +step:636/1695 train_time:61237ms step_avg:96.28ms +step:637/1695 train_time:61335ms step_avg:96.29ms +step:638/1695 train_time:61432ms step_avg:96.29ms +step:639/1695 train_time:61530ms step_avg:96.29ms +step:640/1695 train_time:61627ms step_avg:96.29ms +step:641/1695 train_time:61723ms step_avg:96.29ms +step:642/1695 train_time:61818ms step_avg:96.29ms +step:643/1695 train_time:61912ms step_avg:96.29ms +step:644/1695 train_time:62007ms step_avg:96.28ms +step:645/1695 train_time:62103ms 
step_avg:96.28ms +step:646/1695 train_time:62199ms step_avg:96.28ms +step:647/1695 train_time:62297ms step_avg:96.29ms +step:648/1695 train_time:62394ms step_avg:96.29ms +step:649/1695 train_time:62491ms step_avg:96.29ms +step:650/1695 train_time:62588ms step_avg:96.29ms +step:651/1695 train_time:62686ms step_avg:96.29ms +step:652/1695 train_time:62783ms step_avg:96.29ms +step:653/1695 train_time:62879ms step_avg:96.29ms +step:654/1695 train_time:62973ms step_avg:96.29ms +step:655/1695 train_time:63068ms step_avg:96.29ms +step:656/1695 train_time:63166ms step_avg:96.29ms +step:657/1695 train_time:63264ms step_avg:96.29ms +step:658/1695 train_time:63362ms step_avg:96.29ms +step:659/1695 train_time:63459ms step_avg:96.30ms +step:660/1695 train_time:63555ms step_avg:96.30ms +step:661/1695 train_time:63652ms step_avg:96.30ms +step:662/1695 train_time:63747ms step_avg:96.30ms +step:663/1695 train_time:63844ms step_avg:96.30ms +step:664/1695 train_time:63939ms step_avg:96.29ms +step:665/1695 train_time:64034ms step_avg:96.29ms +step:666/1695 train_time:64130ms step_avg:96.29ms +step:667/1695 train_time:64227ms step_avg:96.29ms +step:668/1695 train_time:64325ms step_avg:96.29ms +step:669/1695 train_time:64422ms step_avg:96.30ms +step:670/1695 train_time:64518ms step_avg:96.30ms +step:671/1695 train_time:64614ms step_avg:96.30ms +step:672/1695 train_time:64710ms step_avg:96.29ms +step:673/1695 train_time:64806ms step_avg:96.29ms +step:674/1695 train_time:64902ms step_avg:96.29ms +step:675/1695 train_time:64998ms step_avg:96.29ms +step:676/1695 train_time:65093ms step_avg:96.29ms +step:677/1695 train_time:65190ms step_avg:96.29ms +step:678/1695 train_time:65287ms step_avg:96.29ms +step:679/1695 train_time:65384ms step_avg:96.30ms +step:680/1695 train_time:65482ms step_avg:96.30ms +step:681/1695 train_time:65578ms step_avg:96.30ms +step:682/1695 train_time:65674ms step_avg:96.30ms +step:683/1695 train_time:65769ms step_avg:96.29ms +step:684/1695 train_time:65866ms step_avg:96.30ms +step:685/1695 train_time:65962ms step_avg:96.29ms +step:686/1695 train_time:66058ms step_avg:96.29ms +step:687/1695 train_time:66153ms step_avg:96.29ms +step:688/1695 train_time:66249ms step_avg:96.29ms +step:689/1695 train_time:66345ms step_avg:96.29ms +step:690/1695 train_time:66441ms step_avg:96.29ms +step:691/1695 train_time:66884ms step_avg:96.79ms +step:692/1695 train_time:66970ms step_avg:96.78ms +step:693/1695 train_time:67064ms step_avg:96.77ms +step:694/1695 train_time:67159ms step_avg:96.77ms +step:695/1695 train_time:67254ms step_avg:96.77ms +step:696/1695 train_time:67348ms step_avg:96.77ms +step:697/1695 train_time:67443ms step_avg:96.76ms +step:698/1695 train_time:67538ms step_avg:96.76ms +step:699/1695 train_time:67633ms step_avg:96.76ms +step:700/1695 train_time:67727ms step_avg:96.75ms +step:701/1695 train_time:67826ms step_avg:96.76ms +step:702/1695 train_time:67927ms step_avg:96.76ms +step:703/1695 train_time:68025ms step_avg:96.76ms +step:704/1695 train_time:68122ms step_avg:96.76ms +step:705/1695 train_time:68218ms step_avg:96.76ms +step:706/1695 train_time:68313ms step_avg:96.76ms +step:707/1695 train_time:68408ms step_avg:96.76ms +step:708/1695 train_time:68503ms step_avg:96.76ms +step:709/1695 train_time:68598ms step_avg:96.75ms +step:710/1695 train_time:68693ms step_avg:96.75ms +step:711/1695 train_time:68789ms step_avg:96.75ms +step:712/1695 train_time:68885ms step_avg:96.75ms +step:713/1695 train_time:68982ms step_avg:96.75ms +step:714/1695 train_time:69079ms step_avg:96.75ms +step:715/1695 
train_time:69175ms step_avg:96.75ms +step:716/1695 train_time:69271ms step_avg:96.75ms +step:717/1695 train_time:69366ms step_avg:96.74ms +step:718/1695 train_time:69461ms step_avg:96.74ms +step:719/1695 train_time:69556ms step_avg:96.74ms +step:720/1695 train_time:69652ms step_avg:96.74ms +step:721/1695 train_time:69748ms step_avg:96.74ms +step:722/1695 train_time:69845ms step_avg:96.74ms +step:723/1695 train_time:69942ms step_avg:96.74ms +step:724/1695 train_time:70038ms step_avg:96.74ms +step:725/1695 train_time:70134ms step_avg:96.74ms +step:726/1695 train_time:70230ms step_avg:96.74ms +step:727/1695 train_time:70328ms step_avg:96.74ms +step:728/1695 train_time:70425ms step_avg:96.74ms +step:729/1695 train_time:70520ms step_avg:96.74ms +step:730/1695 train_time:70615ms step_avg:96.73ms +step:731/1695 train_time:70710ms step_avg:96.73ms +step:732/1695 train_time:70806ms step_avg:96.73ms +step:733/1695 train_time:70904ms step_avg:96.73ms +step:734/1695 train_time:71000ms step_avg:96.73ms +step:735/1695 train_time:71097ms step_avg:96.73ms +step:736/1695 train_time:71193ms step_avg:96.73ms +step:737/1695 train_time:71289ms step_avg:96.73ms +step:738/1695 train_time:71386ms step_avg:96.73ms +step:739/1695 train_time:71483ms step_avg:96.73ms +step:740/1695 train_time:71579ms step_avg:96.73ms +step:741/1695 train_time:71674ms step_avg:96.73ms +step:742/1695 train_time:71769ms step_avg:96.72ms +step:743/1695 train_time:71866ms step_avg:96.72ms +step:744/1695 train_time:71963ms step_avg:96.72ms +step:745/1695 train_time:72059ms step_avg:96.72ms +step:746/1695 train_time:72155ms step_avg:96.72ms +step:747/1695 train_time:72251ms step_avg:96.72ms +step:748/1695 train_time:72348ms step_avg:96.72ms +step:749/1695 train_time:72444ms step_avg:96.72ms +step:750/1695 train_time:72540ms step_avg:96.72ms +step:750/1695 val_loss:3.5691 train_time:72633ms step_avg:96.84ms +step:751/1695 train_time:72658ms step_avg:96.75ms +step:752/1695 train_time:72740ms step_avg:96.73ms +step:753/1695 train_time:72841ms step_avg:96.73ms +step:754/1695 train_time:72937ms step_avg:96.73ms +step:755/1695 train_time:73032ms step_avg:96.73ms +step:756/1695 train_time:73127ms step_avg:96.73ms +step:757/1695 train_time:73221ms step_avg:96.73ms +step:758/1695 train_time:73316ms step_avg:96.72ms +step:759/1695 train_time:73411ms step_avg:96.72ms +step:760/1695 train_time:73506ms step_avg:96.72ms +step:761/1695 train_time:73604ms step_avg:96.72ms +step:762/1695 train_time:73701ms step_avg:96.72ms +step:763/1695 train_time:73800ms step_avg:96.72ms +step:764/1695 train_time:73896ms step_avg:96.72ms +step:765/1695 train_time:73992ms step_avg:96.72ms +step:766/1695 train_time:74088ms step_avg:96.72ms +step:767/1695 train_time:74183ms step_avg:96.72ms +step:768/1695 train_time:74278ms step_avg:96.72ms +step:769/1695 train_time:74373ms step_avg:96.71ms +step:770/1695 train_time:74470ms step_avg:96.71ms +step:771/1695 train_time:74566ms step_avg:96.71ms +step:772/1695 train_time:74663ms step_avg:96.71ms +step:773/1695 train_time:74761ms step_avg:96.71ms +step:774/1695 train_time:74857ms step_avg:96.72ms +step:775/1695 train_time:74954ms step_avg:96.71ms +step:776/1695 train_time:75051ms step_avg:96.72ms +step:777/1695 train_time:75147ms step_avg:96.71ms +step:778/1695 train_time:75242ms step_avg:96.71ms +step:779/1695 train_time:75338ms step_avg:96.71ms +step:780/1695 train_time:75434ms step_avg:96.71ms +step:781/1695 train_time:75531ms step_avg:96.71ms +step:782/1695 train_time:75627ms step_avg:96.71ms +step:783/1695 train_time:75724ms 
step_avg:96.71ms +step:784/1695 train_time:75820ms step_avg:96.71ms +step:785/1695 train_time:75916ms step_avg:96.71ms +step:786/1695 train_time:76013ms step_avg:96.71ms +step:787/1695 train_time:76109ms step_avg:96.71ms +step:788/1695 train_time:76206ms step_avg:96.71ms +step:789/1695 train_time:76301ms step_avg:96.71ms +step:790/1695 train_time:76397ms step_avg:96.71ms +step:791/1695 train_time:76493ms step_avg:96.70ms +step:792/1695 train_time:76591ms step_avg:96.71ms +step:793/1695 train_time:76689ms step_avg:96.71ms +step:794/1695 train_time:76785ms step_avg:96.71ms +step:795/1695 train_time:76881ms step_avg:96.71ms +step:796/1695 train_time:76977ms step_avg:96.71ms +step:797/1695 train_time:77073ms step_avg:96.70ms +step:798/1695 train_time:77169ms step_avg:96.70ms +step:799/1695 train_time:77264ms step_avg:96.70ms +step:800/1695 train_time:77359ms step_avg:96.70ms +step:801/1695 train_time:77455ms step_avg:96.70ms +step:802/1695 train_time:77551ms step_avg:96.70ms +step:803/1695 train_time:77647ms step_avg:96.70ms +step:804/1695 train_time:77743ms step_avg:96.69ms +step:805/1695 train_time:77839ms step_avg:96.69ms +step:806/1695 train_time:77936ms step_avg:96.69ms +step:807/1695 train_time:78033ms step_avg:96.69ms +step:808/1695 train_time:78130ms step_avg:96.69ms +step:809/1695 train_time:78226ms step_avg:96.69ms +step:810/1695 train_time:78321ms step_avg:96.69ms +step:811/1695 train_time:78417ms step_avg:96.69ms +step:812/1695 train_time:78513ms step_avg:96.69ms +step:813/1695 train_time:78609ms step_avg:96.69ms +step:814/1695 train_time:78706ms step_avg:96.69ms +step:815/1695 train_time:78802ms step_avg:96.69ms +step:816/1695 train_time:78898ms step_avg:96.69ms +step:817/1695 train_time:78996ms step_avg:96.69ms +step:818/1695 train_time:79093ms step_avg:96.69ms +step:819/1695 train_time:79190ms step_avg:96.69ms +step:820/1695 train_time:79286ms step_avg:96.69ms +step:821/1695 train_time:79381ms step_avg:96.69ms +step:822/1695 train_time:79477ms step_avg:96.69ms +step:823/1695 train_time:79574ms step_avg:96.69ms +step:824/1695 train_time:79671ms step_avg:96.69ms +step:825/1695 train_time:79767ms step_avg:96.69ms +step:826/1695 train_time:79863ms step_avg:96.69ms +step:827/1695 train_time:79959ms step_avg:96.69ms +step:828/1695 train_time:80056ms step_avg:96.69ms +step:829/1695 train_time:80153ms step_avg:96.69ms +step:830/1695 train_time:80249ms step_avg:96.69ms +step:831/1695 train_time:80345ms step_avg:96.68ms +step:832/1695 train_time:80440ms step_avg:96.68ms +step:833/1695 train_time:80536ms step_avg:96.68ms +step:834/1695 train_time:80633ms step_avg:96.68ms +step:835/1695 train_time:80730ms step_avg:96.68ms +step:836/1695 train_time:80826ms step_avg:96.68ms +step:837/1695 train_time:80922ms step_avg:96.68ms +step:838/1695 train_time:81019ms step_avg:96.68ms +step:839/1695 train_time:81116ms step_avg:96.68ms +step:840/1695 train_time:81213ms step_avg:96.68ms +step:841/1695 train_time:81309ms step_avg:96.68ms +step:842/1695 train_time:81405ms step_avg:96.68ms +step:843/1695 train_time:81502ms step_avg:96.68ms +step:844/1695 train_time:81597ms step_avg:96.68ms +step:845/1695 train_time:81694ms step_avg:96.68ms +step:846/1695 train_time:81791ms step_avg:96.68ms +step:847/1695 train_time:81888ms step_avg:96.68ms +step:848/1695 train_time:81983ms step_avg:96.68ms +step:849/1695 train_time:82078ms step_avg:96.68ms +step:850/1695 train_time:82174ms step_avg:96.68ms +step:851/1695 train_time:82271ms step_avg:96.68ms +step:852/1695 train_time:82368ms step_avg:96.68ms +step:853/1695 
train_time:82464ms step_avg:96.68ms +step:854/1695 train_time:82559ms step_avg:96.67ms +step:855/1695 train_time:82655ms step_avg:96.67ms +step:856/1695 train_time:82752ms step_avg:96.67ms +step:857/1695 train_time:82848ms step_avg:96.67ms +step:858/1695 train_time:82944ms step_avg:96.67ms +step:859/1695 train_time:83039ms step_avg:96.67ms +step:860/1695 train_time:83136ms step_avg:96.67ms +step:861/1695 train_time:83233ms step_avg:96.67ms +step:862/1695 train_time:83331ms step_avg:96.67ms +step:863/1695 train_time:83651ms step_avg:96.93ms +step:864/1695 train_time:83850ms step_avg:97.05ms +step:865/1695 train_time:83944ms step_avg:97.05ms +step:866/1695 train_time:84039ms step_avg:97.04ms +step:867/1695 train_time:84134ms step_avg:97.04ms +step:868/1695 train_time:84229ms step_avg:97.04ms +step:869/1695 train_time:84323ms step_avg:97.03ms +step:870/1695 train_time:84418ms step_avg:97.03ms +step:871/1695 train_time:84513ms step_avg:97.03ms +step:872/1695 train_time:84609ms step_avg:97.03ms +step:873/1695 train_time:84705ms step_avg:97.03ms +step:874/1695 train_time:84806ms step_avg:97.03ms +step:875/1695 train_time:84904ms step_avg:97.03ms +step:875/1695 val_loss:3.5252 train_time:84998ms step_avg:97.14ms +step:876/1695 train_time:85024ms step_avg:97.06ms +step:877/1695 train_time:85101ms step_avg:97.04ms +step:878/1695 train_time:85198ms step_avg:97.04ms +step:879/1695 train_time:85293ms step_avg:97.03ms +step:880/1695 train_time:85388ms step_avg:97.03ms +step:881/1695 train_time:85484ms step_avg:97.03ms +step:882/1695 train_time:85579ms step_avg:97.03ms +step:883/1695 train_time:85673ms step_avg:97.03ms +step:884/1695 train_time:85768ms step_avg:97.02ms +step:885/1695 train_time:85863ms step_avg:97.02ms +step:886/1695 train_time:85962ms step_avg:97.02ms +step:887/1695 train_time:86061ms step_avg:97.02ms +step:888/1695 train_time:86157ms step_avg:97.02ms +step:889/1695 train_time:86253ms step_avg:97.02ms +step:890/1695 train_time:86350ms step_avg:97.02ms +step:891/1695 train_time:86446ms step_avg:97.02ms +step:892/1695 train_time:86542ms step_avg:97.02ms +step:893/1695 train_time:86637ms step_avg:97.02ms +step:894/1695 train_time:86732ms step_avg:97.02ms +step:895/1695 train_time:86828ms step_avg:97.01ms +step:896/1695 train_time:86924ms step_avg:97.01ms +step:897/1695 train_time:87021ms step_avg:97.01ms +step:898/1695 train_time:87117ms step_avg:97.01ms +step:899/1695 train_time:87213ms step_avg:97.01ms +step:900/1695 train_time:87309ms step_avg:97.01ms +step:901/1695 train_time:87406ms step_avg:97.01ms +step:902/1695 train_time:87502ms step_avg:97.01ms +step:903/1695 train_time:87598ms step_avg:97.01ms +step:904/1695 train_time:87693ms step_avg:97.01ms +step:905/1695 train_time:87789ms step_avg:97.00ms +step:906/1695 train_time:87885ms step_avg:97.00ms +step:907/1695 train_time:87982ms step_avg:97.00ms +step:908/1695 train_time:88078ms step_avg:97.00ms +step:909/1695 train_time:88174ms step_avg:97.00ms +step:910/1695 train_time:88270ms step_avg:97.00ms +step:911/1695 train_time:88366ms step_avg:97.00ms +step:912/1695 train_time:88462ms step_avg:97.00ms +step:913/1695 train_time:88558ms step_avg:97.00ms +step:914/1695 train_time:88653ms step_avg:96.99ms +step:915/1695 train_time:88749ms step_avg:96.99ms +step:916/1695 train_time:88845ms step_avg:96.99ms +step:917/1695 train_time:88942ms step_avg:96.99ms +step:918/1695 train_time:89038ms step_avg:96.99ms +step:919/1695 train_time:89133ms step_avg:96.99ms +step:920/1695 train_time:89230ms step_avg:96.99ms +step:921/1695 train_time:89327ms 
step_avg:96.99ms +step:922/1695 train_time:89423ms step_avg:96.99ms +step:923/1695 train_time:89519ms step_avg:96.99ms +step:924/1695 train_time:89615ms step_avg:96.99ms +step:925/1695 train_time:89711ms step_avg:96.98ms +step:926/1695 train_time:89806ms step_avg:96.98ms +step:927/1695 train_time:89902ms step_avg:96.98ms +step:928/1695 train_time:89998ms step_avg:96.98ms +step:929/1695 train_time:90094ms step_avg:96.98ms +step:930/1695 train_time:90190ms step_avg:96.98ms +step:931/1695 train_time:90287ms step_avg:96.98ms +step:932/1695 train_time:90383ms step_avg:96.98ms +step:933/1695 train_time:90479ms step_avg:96.98ms +step:934/1695 train_time:90575ms step_avg:96.98ms +step:935/1695 train_time:90670ms step_avg:96.97ms +step:936/1695 train_time:90767ms step_avg:96.97ms +step:937/1695 train_time:90863ms step_avg:96.97ms +step:938/1695 train_time:90960ms step_avg:96.97ms +step:939/1695 train_time:91055ms step_avg:96.97ms +step:940/1695 train_time:91151ms step_avg:96.97ms +step:941/1695 train_time:91247ms step_avg:96.97ms +step:942/1695 train_time:91343ms step_avg:96.97ms +step:943/1695 train_time:91440ms step_avg:96.97ms +step:944/1695 train_time:91535ms step_avg:96.97ms +step:945/1695 train_time:91631ms step_avg:96.96ms +step:946/1695 train_time:91726ms step_avg:96.96ms +step:947/1695 train_time:91822ms step_avg:96.96ms +step:948/1695 train_time:91918ms step_avg:96.96ms +step:949/1695 train_time:92013ms step_avg:96.96ms +step:950/1695 train_time:92109ms step_avg:96.96ms +step:951/1695 train_time:92205ms step_avg:96.96ms +step:952/1695 train_time:92301ms step_avg:96.95ms +step:953/1695 train_time:92396ms step_avg:96.95ms +step:954/1695 train_time:92491ms step_avg:96.95ms +step:955/1695 train_time:92587ms step_avg:96.95ms +step:956/1695 train_time:92683ms step_avg:96.95ms +step:957/1695 train_time:92779ms step_avg:96.95ms +step:958/1695 train_time:92875ms step_avg:96.95ms +step:959/1695 train_time:92971ms step_avg:96.95ms +step:960/1695 train_time:93068ms step_avg:96.95ms +step:961/1695 train_time:93164ms step_avg:96.95ms +step:962/1695 train_time:93260ms step_avg:96.94ms +step:963/1695 train_time:93356ms step_avg:96.94ms +step:964/1695 train_time:93453ms step_avg:96.94ms +step:965/1695 train_time:93549ms step_avg:96.94ms +step:966/1695 train_time:93646ms step_avg:96.94ms +step:967/1695 train_time:93743ms step_avg:96.94ms +step:968/1695 train_time:93839ms step_avg:96.94ms +step:969/1695 train_time:93934ms step_avg:96.94ms +step:970/1695 train_time:94030ms step_avg:96.94ms +step:971/1695 train_time:94126ms step_avg:96.94ms +step:972/1695 train_time:94222ms step_avg:96.94ms +step:973/1695 train_time:94318ms step_avg:96.94ms +step:974/1695 train_time:94414ms step_avg:96.93ms +step:975/1695 train_time:94510ms step_avg:96.93ms +step:976/1695 train_time:94607ms step_avg:96.93ms +step:977/1695 train_time:94702ms step_avg:96.93ms +step:978/1695 train_time:94799ms step_avg:96.93ms +step:979/1695 train_time:94894ms step_avg:96.93ms +step:980/1695 train_time:94990ms step_avg:96.93ms +step:981/1695 train_time:95087ms step_avg:96.93ms +step:982/1695 train_time:95184ms step_avg:96.93ms +step:983/1695 train_time:95280ms step_avg:96.93ms +step:984/1695 train_time:95376ms step_avg:96.93ms +step:985/1695 train_time:95472ms step_avg:96.93ms +step:986/1695 train_time:95568ms step_avg:96.92ms +step:987/1695 train_time:95664ms step_avg:96.92ms +step:988/1695 train_time:95760ms step_avg:96.92ms +step:989/1695 train_time:95855ms step_avg:96.92ms +step:990/1695 train_time:95950ms step_avg:96.92ms +step:991/1695 
train_time:96047ms step_avg:96.92ms +step:992/1695 train_time:96144ms step_avg:96.92ms +step:993/1695 train_time:96240ms step_avg:96.92ms +step:994/1695 train_time:96335ms step_avg:96.92ms +step:995/1695 train_time:96431ms step_avg:96.92ms +step:996/1695 train_time:96528ms step_avg:96.92ms +step:997/1695 train_time:96625ms step_avg:96.92ms +step:998/1695 train_time:96722ms step_avg:96.92ms +step:999/1695 train_time:96817ms step_avg:96.91ms +step:1000/1695 train_time:96912ms step_avg:96.91ms +step:1000/1695 val_loss:3.4845 train_time:97007ms step_avg:97.01ms +step:1001/1695 train_time:97031ms step_avg:96.93ms +step:1002/1695 train_time:97110ms step_avg:96.92ms +step:1003/1695 train_time:97207ms step_avg:96.92ms +step:1004/1695 train_time:97304ms step_avg:96.92ms +step:1005/1695 train_time:97399ms step_avg:96.91ms +step:1006/1695 train_time:97495ms step_avg:96.91ms +step:1007/1695 train_time:97589ms step_avg:96.91ms +step:1008/1695 train_time:97684ms step_avg:96.91ms +step:1009/1695 train_time:97779ms step_avg:96.91ms +step:1010/1695 train_time:97874ms step_avg:96.90ms +step:1011/1695 train_time:97971ms step_avg:96.90ms +step:1012/1695 train_time:98069ms step_avg:96.91ms +step:1013/1695 train_time:98165ms step_avg:96.91ms +step:1014/1695 train_time:98263ms step_avg:96.91ms +step:1015/1695 train_time:98359ms step_avg:96.91ms +step:1016/1695 train_time:98454ms step_avg:96.90ms +step:1017/1695 train_time:98549ms step_avg:96.90ms +step:1018/1695 train_time:98644ms step_avg:96.90ms +step:1019/1695 train_time:98740ms step_avg:96.90ms +step:1020/1695 train_time:98835ms step_avg:96.90ms +step:1021/1695 train_time:98930ms step_avg:96.89ms +step:1022/1695 train_time:99026ms step_avg:96.89ms +step:1023/1695 train_time:99123ms step_avg:96.89ms +step:1024/1695 train_time:99220ms step_avg:96.89ms +step:1025/1695 train_time:99317ms step_avg:96.89ms +step:1026/1695 train_time:99414ms step_avg:96.90ms +step:1027/1695 train_time:99510ms step_avg:96.89ms +step:1028/1695 train_time:99605ms step_avg:96.89ms +step:1029/1695 train_time:99700ms step_avg:96.89ms +step:1030/1695 train_time:99795ms step_avg:96.89ms +step:1031/1695 train_time:99891ms step_avg:96.89ms +step:1032/1695 train_time:99986ms step_avg:96.89ms +step:1033/1695 train_time:100082ms step_avg:96.88ms +step:1034/1695 train_time:100179ms step_avg:96.88ms +step:1035/1695 train_time:100274ms step_avg:96.88ms +step:1036/1695 train_time:100600ms step_avg:97.10ms +step:1037/1695 train_time:100773ms step_avg:97.18ms +step:1038/1695 train_time:100866ms step_avg:97.17ms +step:1039/1695 train_time:100961ms step_avg:97.17ms +step:1040/1695 train_time:101056ms step_avg:97.17ms +step:1041/1695 train_time:101151ms step_avg:97.17ms +step:1042/1695 train_time:101245ms step_avg:97.16ms +step:1043/1695 train_time:101341ms step_avg:97.16ms +step:1044/1695 train_time:101436ms step_avg:97.16ms +step:1045/1695 train_time:101530ms step_avg:97.16ms +step:1046/1695 train_time:101626ms step_avg:97.16ms +step:1047/1695 train_time:101728ms step_avg:97.16ms +step:1048/1695 train_time:101826ms step_avg:97.16ms +step:1049/1695 train_time:101923ms step_avg:97.16ms +step:1050/1695 train_time:102019ms step_avg:97.16ms +step:1051/1695 train_time:102115ms step_avg:97.16ms +step:1052/1695 train_time:102210ms step_avg:97.16ms +step:1053/1695 train_time:102304ms step_avg:97.16ms +step:1054/1695 train_time:102399ms step_avg:97.15ms +step:1055/1695 train_time:102494ms step_avg:97.15ms +step:1056/1695 train_time:102590ms step_avg:97.15ms +step:1057/1695 train_time:102688ms step_avg:97.15ms 
+step:1058/1695 train_time:102784ms step_avg:97.15ms +step:1059/1695 train_time:102881ms step_avg:97.15ms +step:1060/1695 train_time:102978ms step_avg:97.15ms +step:1061/1695 train_time:103075ms step_avg:97.15ms +step:1062/1695 train_time:103171ms step_avg:97.15ms +step:1063/1695 train_time:103265ms step_avg:97.15ms +step:1064/1695 train_time:103360ms step_avg:97.14ms +step:1065/1695 train_time:103456ms step_avg:97.14ms +step:1066/1695 train_time:103553ms step_avg:97.14ms +step:1067/1695 train_time:103650ms step_avg:97.14ms +step:1068/1695 train_time:103746ms step_avg:97.14ms +step:1069/1695 train_time:103842ms step_avg:97.14ms +step:1070/1695 train_time:103939ms step_avg:97.14ms +step:1071/1695 train_time:104036ms step_avg:97.14ms +step:1072/1695 train_time:104132ms step_avg:97.14ms +step:1073/1695 train_time:104227ms step_avg:97.14ms +step:1074/1695 train_time:104322ms step_avg:97.13ms +step:1075/1695 train_time:104418ms step_avg:97.13ms +step:1076/1695 train_time:104514ms step_avg:97.13ms +step:1077/1695 train_time:104609ms step_avg:97.13ms +step:1078/1695 train_time:104705ms step_avg:97.13ms +step:1079/1695 train_time:104802ms step_avg:97.13ms +step:1080/1695 train_time:104898ms step_avg:97.13ms +step:1081/1695 train_time:104995ms step_avg:97.13ms +step:1082/1695 train_time:105091ms step_avg:97.13ms +step:1083/1695 train_time:105187ms step_avg:97.13ms +step:1084/1695 train_time:105282ms step_avg:97.12ms +step:1085/1695 train_time:105378ms step_avg:97.12ms +step:1086/1695 train_time:105474ms step_avg:97.12ms +step:1087/1695 train_time:105569ms step_avg:97.12ms +step:1088/1695 train_time:105664ms step_avg:97.12ms +step:1089/1695 train_time:105761ms step_avg:97.12ms +step:1090/1695 train_time:105858ms step_avg:97.12ms +step:1091/1695 train_time:105955ms step_avg:97.12ms +step:1092/1695 train_time:106050ms step_avg:97.12ms +step:1093/1695 train_time:106146ms step_avg:97.11ms +step:1094/1695 train_time:106241ms step_avg:97.11ms +step:1095/1695 train_time:106339ms step_avg:97.11ms +step:1096/1695 train_time:106435ms step_avg:97.11ms +step:1097/1695 train_time:106531ms step_avg:97.11ms +step:1098/1695 train_time:106627ms step_avg:97.11ms +step:1099/1695 train_time:106723ms step_avg:97.11ms +step:1100/1695 train_time:106819ms step_avg:97.11ms +step:1101/1695 train_time:106915ms step_avg:97.11ms +step:1102/1695 train_time:107012ms step_avg:97.11ms +step:1103/1695 train_time:107107ms step_avg:97.11ms +step:1104/1695 train_time:107203ms step_avg:97.10ms +step:1105/1695 train_time:107300ms step_avg:97.10ms +step:1106/1695 train_time:107396ms step_avg:97.10ms +step:1107/1695 train_time:107492ms step_avg:97.10ms +step:1108/1695 train_time:107587ms step_avg:97.10ms +step:1109/1695 train_time:107683ms step_avg:97.10ms +step:1110/1695 train_time:107779ms step_avg:97.10ms +step:1111/1695 train_time:107875ms step_avg:97.10ms +step:1112/1695 train_time:107971ms step_avg:97.10ms +step:1113/1695 train_time:108066ms step_avg:97.09ms +step:1114/1695 train_time:108163ms step_avg:97.09ms +step:1115/1695 train_time:108259ms step_avg:97.09ms +step:1116/1695 train_time:108356ms step_avg:97.09ms +step:1117/1695 train_time:108453ms step_avg:97.09ms +step:1118/1695 train_time:108549ms step_avg:97.09ms +step:1119/1695 train_time:108645ms step_avg:97.09ms +step:1120/1695 train_time:108741ms step_avg:97.09ms +step:1121/1695 train_time:108838ms step_avg:97.09ms +step:1122/1695 train_time:108934ms step_avg:97.09ms +step:1123/1695 train_time:109030ms step_avg:97.09ms +step:1124/1695 train_time:109126ms step_avg:97.09ms 
+step:1125/1695 train_time:109222ms step_avg:97.09ms +step:1125/1695 val_loss:3.4375 train_time:109315ms step_avg:97.17ms +step:1126/1695 train_time:109340ms step_avg:97.10ms +step:1127/1695 train_time:109421ms step_avg:97.09ms +step:1128/1695 train_time:109520ms step_avg:97.09ms +step:1129/1695 train_time:109618ms step_avg:97.09ms +step:1130/1695 train_time:109713ms step_avg:97.09ms +step:1131/1695 train_time:109808ms step_avg:97.09ms +step:1132/1695 train_time:109903ms step_avg:97.09ms +step:1133/1695 train_time:110000ms step_avg:97.09ms +step:1134/1695 train_time:110096ms step_avg:97.09ms +step:1135/1695 train_time:110193ms step_avg:97.09ms +step:1136/1695 train_time:110294ms step_avg:97.09ms +step:1137/1695 train_time:110395ms step_avg:97.09ms +step:1138/1695 train_time:110495ms step_avg:97.10ms +step:1139/1695 train_time:110595ms step_avg:97.10ms +step:1140/1695 train_time:110691ms step_avg:97.10ms +step:1141/1695 train_time:110788ms step_avg:97.10ms +step:1142/1695 train_time:110886ms step_avg:97.10ms +step:1143/1695 train_time:110983ms step_avg:97.10ms +step:1144/1695 train_time:111080ms step_avg:97.10ms +step:1145/1695 train_time:111178ms step_avg:97.10ms +step:1146/1695 train_time:111276ms step_avg:97.10ms +step:1147/1695 train_time:111375ms step_avg:97.10ms +step:1148/1695 train_time:111474ms step_avg:97.10ms +step:1149/1695 train_time:111572ms step_avg:97.10ms +step:1150/1695 train_time:111669ms step_avg:97.10ms +step:1151/1695 train_time:111766ms step_avg:97.10ms +step:1152/1695 train_time:111863ms step_avg:97.10ms +step:1153/1695 train_time:111959ms step_avg:97.10ms +step:1154/1695 train_time:112057ms step_avg:97.10ms +step:1155/1695 train_time:112154ms step_avg:97.10ms +step:1156/1695 train_time:112251ms step_avg:97.10ms +step:1157/1695 train_time:112349ms step_avg:97.10ms +step:1158/1695 train_time:112448ms step_avg:97.11ms +step:1159/1695 train_time:112547ms step_avg:97.11ms +step:1160/1695 train_time:112645ms step_avg:97.11ms +step:1161/1695 train_time:112744ms step_avg:97.11ms +step:1162/1695 train_time:112843ms step_avg:97.11ms +step:1163/1695 train_time:112940ms step_avg:97.11ms +step:1164/1695 train_time:113037ms step_avg:97.11ms +step:1165/1695 train_time:113135ms step_avg:97.11ms +step:1166/1695 train_time:113232ms step_avg:97.11ms +step:1167/1695 train_time:113330ms step_avg:97.11ms +step:1168/1695 train_time:113428ms step_avg:97.11ms +step:1169/1695 train_time:113527ms step_avg:97.11ms +step:1170/1695 train_time:113626ms step_avg:97.12ms +step:1171/1695 train_time:113723ms step_avg:97.12ms +step:1172/1695 train_time:113822ms step_avg:97.12ms +step:1173/1695 train_time:113919ms step_avg:97.12ms +step:1174/1695 train_time:114018ms step_avg:97.12ms +step:1175/1695 train_time:114115ms step_avg:97.12ms +step:1176/1695 train_time:114212ms step_avg:97.12ms +step:1177/1695 train_time:114310ms step_avg:97.12ms +step:1178/1695 train_time:114406ms step_avg:97.12ms +step:1179/1695 train_time:114504ms step_avg:97.12ms +step:1180/1695 train_time:114604ms step_avg:97.12ms +step:1181/1695 train_time:114703ms step_avg:97.12ms +step:1182/1695 train_time:114801ms step_avg:97.12ms +step:1183/1695 train_time:114898ms step_avg:97.12ms +step:1184/1695 train_time:114996ms step_avg:97.13ms +step:1185/1695 train_time:115095ms step_avg:97.13ms +step:1186/1695 train_time:115193ms step_avg:97.13ms +step:1187/1695 train_time:115290ms step_avg:97.13ms +step:1188/1695 train_time:115388ms step_avg:97.13ms +step:1189/1695 train_time:115486ms step_avg:97.13ms +step:1190/1695 train_time:115584ms 
step_avg:97.13ms +step:1191/1695 train_time:115682ms step_avg:97.13ms +step:1192/1695 train_time:115779ms step_avg:97.13ms +step:1193/1695 train_time:115876ms step_avg:97.13ms +step:1194/1695 train_time:115973ms step_avg:97.13ms +step:1195/1695 train_time:116071ms step_avg:97.13ms +step:1196/1695 train_time:116169ms step_avg:97.13ms +step:1197/1695 train_time:116266ms step_avg:97.13ms +step:1198/1695 train_time:116365ms step_avg:97.13ms +step:1199/1695 train_time:116464ms step_avg:97.13ms +step:1200/1695 train_time:116562ms step_avg:97.14ms +step:1201/1695 train_time:116661ms step_avg:97.14ms +step:1202/1695 train_time:116758ms step_avg:97.14ms +step:1203/1695 train_time:116856ms step_avg:97.14ms +step:1204/1695 train_time:116955ms step_avg:97.14ms +step:1205/1695 train_time:117053ms step_avg:97.14ms +step:1206/1695 train_time:117150ms step_avg:97.14ms +step:1207/1695 train_time:117247ms step_avg:97.14ms +step:1208/1695 train_time:117580ms step_avg:97.33ms +step:1209/1695 train_time:117762ms step_avg:97.40ms +step:1210/1695 train_time:117858ms step_avg:97.40ms +step:1211/1695 train_time:117955ms step_avg:97.40ms +step:1212/1695 train_time:118051ms step_avg:97.40ms +step:1213/1695 train_time:118147ms step_avg:97.40ms +step:1214/1695 train_time:118243ms step_avg:97.40ms +step:1215/1695 train_time:118340ms step_avg:97.40ms +step:1216/1695 train_time:118436ms step_avg:97.40ms +step:1217/1695 train_time:118533ms step_avg:97.40ms +step:1218/1695 train_time:118636ms step_avg:97.40ms +step:1219/1695 train_time:118739ms step_avg:97.41ms +step:1220/1695 train_time:118838ms step_avg:97.41ms +step:1221/1695 train_time:118935ms step_avg:97.41ms +step:1222/1695 train_time:119031ms step_avg:97.41ms +step:1223/1695 train_time:119128ms step_avg:97.41ms +step:1224/1695 train_time:119225ms step_avg:97.41ms +step:1225/1695 train_time:119323ms step_avg:97.41ms +step:1226/1695 train_time:119420ms step_avg:97.41ms +step:1227/1695 train_time:119519ms step_avg:97.41ms +step:1228/1695 train_time:119618ms step_avg:97.41ms +step:1229/1695 train_time:119719ms step_avg:97.41ms +step:1230/1695 train_time:119818ms step_avg:97.41ms +step:1231/1695 train_time:119916ms step_avg:97.41ms +step:1232/1695 train_time:120013ms step_avg:97.41ms +step:1233/1695 train_time:120109ms step_avg:97.41ms +step:1234/1695 train_time:120206ms step_avg:97.41ms +step:1235/1695 train_time:120302ms step_avg:97.41ms +step:1236/1695 train_time:120399ms step_avg:97.41ms +step:1237/1695 train_time:120498ms step_avg:97.41ms +step:1238/1695 train_time:120595ms step_avg:97.41ms +step:1239/1695 train_time:120694ms step_avg:97.41ms +step:1240/1695 train_time:120792ms step_avg:97.41ms +step:1241/1695 train_time:120890ms step_avg:97.41ms +step:1242/1695 train_time:120987ms step_avg:97.41ms +step:1243/1695 train_time:121085ms step_avg:97.41ms +step:1244/1695 train_time:121183ms step_avg:97.41ms +step:1245/1695 train_time:121280ms step_avg:97.41ms +step:1246/1695 train_time:121377ms step_avg:97.41ms +step:1247/1695 train_time:121475ms step_avg:97.41ms +step:1248/1695 train_time:121573ms step_avg:97.41ms +step:1249/1695 train_time:121671ms step_avg:97.41ms +step:1250/1695 train_time:121769ms step_avg:97.41ms +step:1250/1695 val_loss:3.3901 train_time:121865ms step_avg:97.49ms +step:1251/1695 train_time:121890ms step_avg:97.43ms +step:1252/1695 train_time:121970ms step_avg:97.42ms +step:1253/1695 train_time:122068ms step_avg:97.42ms +step:1254/1695 train_time:122165ms step_avg:97.42ms +step:1255/1695 train_time:122262ms step_avg:97.42ms +step:1256/1695 
train_time:122358ms step_avg:97.42ms +step:1257/1695 train_time:122455ms step_avg:97.42ms +step:1258/1695 train_time:122552ms step_avg:97.42ms +step:1259/1695 train_time:122648ms step_avg:97.42ms +step:1260/1695 train_time:122746ms step_avg:97.42ms +step:1261/1695 train_time:122850ms step_avg:97.42ms +step:1262/1695 train_time:122952ms step_avg:97.43ms +step:1263/1695 train_time:123051ms step_avg:97.43ms +step:1264/1695 train_time:123148ms step_avg:97.43ms +step:1265/1695 train_time:123246ms step_avg:97.43ms +step:1266/1695 train_time:123344ms step_avg:97.43ms +step:1267/1695 train_time:123440ms step_avg:97.43ms +step:1268/1695 train_time:123537ms step_avg:97.43ms +step:1269/1695 train_time:123633ms step_avg:97.43ms +step:1270/1695 train_time:123732ms step_avg:97.43ms +step:1271/1695 train_time:123833ms step_avg:97.43ms +step:1272/1695 train_time:123933ms step_avg:97.43ms +step:1273/1695 train_time:124031ms step_avg:97.43ms +step:1274/1695 train_time:124131ms step_avg:97.43ms +step:1275/1695 train_time:124230ms step_avg:97.44ms +step:1276/1695 train_time:124329ms step_avg:97.44ms +step:1277/1695 train_time:124427ms step_avg:97.44ms +step:1278/1695 train_time:124525ms step_avg:97.44ms +step:1279/1695 train_time:124622ms step_avg:97.44ms +step:1280/1695 train_time:124719ms step_avg:97.44ms +step:1281/1695 train_time:124817ms step_avg:97.44ms +step:1282/1695 train_time:124915ms step_avg:97.44ms +step:1283/1695 train_time:125014ms step_avg:97.44ms +step:1284/1695 train_time:125113ms step_avg:97.44ms +step:1285/1695 train_time:125211ms step_avg:97.44ms +step:1286/1695 train_time:125310ms step_avg:97.44ms +step:1287/1695 train_time:125409ms step_avg:97.44ms +step:1288/1695 train_time:125508ms step_avg:97.44ms +step:1289/1695 train_time:125606ms step_avg:97.44ms +step:1290/1695 train_time:125703ms step_avg:97.44ms +step:1291/1695 train_time:125801ms step_avg:97.44ms +step:1292/1695 train_time:125898ms step_avg:97.44ms +step:1293/1695 train_time:125996ms step_avg:97.44ms +step:1294/1695 train_time:126094ms step_avg:97.44ms +step:1295/1695 train_time:126193ms step_avg:97.45ms +step:1296/1695 train_time:126291ms step_avg:97.45ms +step:1297/1695 train_time:126390ms step_avg:97.45ms +step:1298/1695 train_time:126489ms step_avg:97.45ms +step:1299/1695 train_time:126586ms step_avg:97.45ms +step:1300/1695 train_time:126684ms step_avg:97.45ms +step:1301/1695 train_time:126781ms step_avg:97.45ms +step:1302/1695 train_time:126879ms step_avg:97.45ms +step:1303/1695 train_time:126976ms step_avg:97.45ms +step:1304/1695 train_time:127074ms step_avg:97.45ms +step:1305/1695 train_time:127172ms step_avg:97.45ms +step:1306/1695 train_time:127271ms step_avg:97.45ms +step:1307/1695 train_time:127369ms step_avg:97.45ms +step:1308/1695 train_time:127467ms step_avg:97.45ms +step:1309/1695 train_time:127566ms step_avg:97.45ms +step:1310/1695 train_time:127664ms step_avg:97.45ms +step:1311/1695 train_time:127762ms step_avg:97.45ms +step:1312/1695 train_time:127860ms step_avg:97.45ms +step:1313/1695 train_time:127957ms step_avg:97.45ms +step:1314/1695 train_time:128054ms step_avg:97.45ms +step:1315/1695 train_time:128151ms step_avg:97.45ms +step:1316/1695 train_time:128250ms step_avg:97.45ms +step:1317/1695 train_time:128349ms step_avg:97.46ms +step:1318/1695 train_time:128448ms step_avg:97.46ms +step:1319/1695 train_time:128546ms step_avg:97.46ms +step:1320/1695 train_time:128645ms step_avg:97.46ms +step:1321/1695 train_time:128743ms step_avg:97.46ms +step:1322/1695 train_time:128840ms step_avg:97.46ms +step:1323/1695 
train_time:128938ms step_avg:97.46ms +step:1324/1695 train_time:129035ms step_avg:97.46ms +step:1325/1695 train_time:129132ms step_avg:97.46ms +step:1326/1695 train_time:129230ms step_avg:97.46ms +step:1327/1695 train_time:129328ms step_avg:97.46ms +step:1328/1695 train_time:129427ms step_avg:97.46ms +step:1329/1695 train_time:129525ms step_avg:97.46ms +step:1330/1695 train_time:129623ms step_avg:97.46ms +step:1331/1695 train_time:129721ms step_avg:97.46ms +step:1332/1695 train_time:129820ms step_avg:97.46ms +step:1333/1695 train_time:129917ms step_avg:97.46ms +step:1334/1695 train_time:130014ms step_avg:97.46ms +step:1335/1695 train_time:130112ms step_avg:97.46ms +step:1336/1695 train_time:130209ms step_avg:97.46ms +step:1337/1695 train_time:130306ms step_avg:97.46ms +step:1338/1695 train_time:130405ms step_avg:97.46ms +step:1339/1695 train_time:130503ms step_avg:97.46ms +step:1340/1695 train_time:130600ms step_avg:97.46ms +step:1341/1695 train_time:130698ms step_avg:97.46ms +step:1342/1695 train_time:130795ms step_avg:97.46ms +step:1343/1695 train_time:130894ms step_avg:97.46ms +step:1344/1695 train_time:130992ms step_avg:97.46ms +step:1345/1695 train_time:131090ms step_avg:97.46ms +step:1346/1695 train_time:131188ms step_avg:97.46ms +step:1347/1695 train_time:131285ms step_avg:97.46ms +step:1348/1695 train_time:131382ms step_avg:97.46ms +step:1349/1695 train_time:131480ms step_avg:97.46ms +step:1350/1695 train_time:131577ms step_avg:97.46ms +step:1351/1695 train_time:131675ms step_avg:97.46ms +step:1352/1695 train_time:131773ms step_avg:97.47ms +step:1353/1695 train_time:131872ms step_avg:97.47ms +step:1354/1695 train_time:131970ms step_avg:97.47ms +step:1355/1695 train_time:132069ms step_avg:97.47ms +step:1356/1695 train_time:132167ms step_avg:97.47ms +step:1357/1695 train_time:132264ms step_avg:97.47ms +step:1358/1695 train_time:132362ms step_avg:97.47ms +step:1359/1695 train_time:132460ms step_avg:97.47ms +step:1360/1695 train_time:132556ms step_avg:97.47ms +step:1361/1695 train_time:132654ms step_avg:97.47ms +step:1362/1695 train_time:132752ms step_avg:97.47ms +step:1363/1695 train_time:132851ms step_avg:97.47ms +step:1364/1695 train_time:132950ms step_avg:97.47ms +step:1365/1695 train_time:133049ms step_avg:97.47ms +step:1366/1695 train_time:133147ms step_avg:97.47ms +step:1367/1695 train_time:133246ms step_avg:97.47ms +step:1368/1695 train_time:133344ms step_avg:97.47ms +step:1369/1695 train_time:133443ms step_avg:97.47ms +step:1370/1695 train_time:133540ms step_avg:97.47ms +step:1371/1695 train_time:133637ms step_avg:97.47ms +step:1372/1695 train_time:133734ms step_avg:97.47ms +step:1373/1695 train_time:133830ms step_avg:97.47ms +step:1374/1695 train_time:133929ms step_avg:97.47ms +step:1375/1695 train_time:134027ms step_avg:97.47ms +step:1375/1695 val_loss:3.3508 train_time:134124ms step_avg:97.54ms +step:1376/1695 train_time:134149ms step_avg:97.49ms +step:1377/1695 train_time:134233ms step_avg:97.48ms +step:1378/1695 train_time:134332ms step_avg:97.48ms +step:1379/1695 train_time:134430ms step_avg:97.48ms +step:1380/1695 train_time:134528ms step_avg:97.48ms +step:1381/1695 train_time:134984ms step_avg:97.74ms +step:1382/1695 train_time:135059ms step_avg:97.73ms +step:1383/1695 train_time:135155ms step_avg:97.73ms +step:1384/1695 train_time:135251ms step_avg:97.72ms +step:1385/1695 train_time:135348ms step_avg:97.72ms +step:1386/1695 train_time:135445ms step_avg:97.72ms +step:1387/1695 train_time:135542ms step_avg:97.72ms +step:1388/1695 train_time:135638ms step_avg:97.72ms 
+step:1389/1695 train_time:135734ms step_avg:97.72ms +step:1390/1695 train_time:135832ms step_avg:97.72ms +step:1391/1695 train_time:135934ms step_avg:97.72ms +step:1392/1695 train_time:136034ms step_avg:97.73ms +step:1393/1695 train_time:136133ms step_avg:97.73ms +step:1394/1695 train_time:136230ms step_avg:97.73ms +step:1395/1695 train_time:136328ms step_avg:97.73ms +step:1396/1695 train_time:136425ms step_avg:97.73ms +step:1397/1695 train_time:136523ms step_avg:97.73ms +step:1398/1695 train_time:136620ms step_avg:97.73ms +step:1399/1695 train_time:136717ms step_avg:97.72ms +step:1400/1695 train_time:136814ms step_avg:97.72ms +step:1401/1695 train_time:136913ms step_avg:97.73ms +step:1402/1695 train_time:137012ms step_avg:97.73ms +step:1403/1695 train_time:137111ms step_avg:97.73ms +step:1404/1695 train_time:137211ms step_avg:97.73ms +step:1405/1695 train_time:137309ms step_avg:97.73ms +step:1406/1695 train_time:137407ms step_avg:97.73ms +step:1407/1695 train_time:137504ms step_avg:97.73ms +step:1408/1695 train_time:137601ms step_avg:97.73ms +step:1409/1695 train_time:137698ms step_avg:97.73ms +step:1410/1695 train_time:137795ms step_avg:97.73ms +step:1411/1695 train_time:137894ms step_avg:97.73ms +step:1412/1695 train_time:137993ms step_avg:97.73ms +step:1413/1695 train_time:138092ms step_avg:97.73ms +step:1414/1695 train_time:138191ms step_avg:97.73ms +step:1415/1695 train_time:138289ms step_avg:97.73ms +step:1416/1695 train_time:138386ms step_avg:97.73ms +step:1417/1695 train_time:138483ms step_avg:97.73ms +step:1418/1695 train_time:138580ms step_avg:97.73ms +step:1419/1695 train_time:138677ms step_avg:97.73ms +step:1420/1695 train_time:138774ms step_avg:97.73ms +step:1421/1695 train_time:138873ms step_avg:97.73ms +step:1422/1695 train_time:138971ms step_avg:97.73ms +step:1423/1695 train_time:139071ms step_avg:97.73ms +step:1424/1695 train_time:139170ms step_avg:97.73ms +step:1425/1695 train_time:139269ms step_avg:97.73ms +step:1426/1695 train_time:139367ms step_avg:97.73ms +step:1427/1695 train_time:139467ms step_avg:97.73ms +step:1428/1695 train_time:139565ms step_avg:97.73ms +step:1429/1695 train_time:139662ms step_avg:97.73ms +step:1430/1695 train_time:139760ms step_avg:97.73ms +step:1431/1695 train_time:139857ms step_avg:97.73ms +step:1432/1695 train_time:139954ms step_avg:97.73ms +step:1433/1695 train_time:140052ms step_avg:97.73ms +step:1434/1695 train_time:140149ms step_avg:97.73ms +step:1435/1695 train_time:140246ms step_avg:97.73ms +step:1436/1695 train_time:140344ms step_avg:97.73ms +step:1437/1695 train_time:140441ms step_avg:97.73ms +step:1438/1695 train_time:140538ms step_avg:97.73ms +step:1439/1695 train_time:140635ms step_avg:97.73ms +step:1440/1695 train_time:140733ms step_avg:97.73ms +step:1441/1695 train_time:140831ms step_avg:97.73ms +step:1442/1695 train_time:140929ms step_avg:97.73ms +step:1443/1695 train_time:141027ms step_avg:97.73ms +step:1444/1695 train_time:141125ms step_avg:97.73ms +step:1445/1695 train_time:141223ms step_avg:97.73ms +step:1446/1695 train_time:141320ms step_avg:97.73ms +step:1447/1695 train_time:141417ms step_avg:97.73ms +step:1448/1695 train_time:141514ms step_avg:97.73ms +step:1449/1695 train_time:141612ms step_avg:97.73ms +step:1450/1695 train_time:141710ms step_avg:97.73ms +step:1451/1695 train_time:141808ms step_avg:97.73ms +step:1452/1695 train_time:141905ms step_avg:97.73ms +step:1453/1695 train_time:142003ms step_avg:97.73ms +step:1454/1695 train_time:142100ms step_avg:97.73ms +step:1455/1695 train_time:142198ms step_avg:97.73ms 
+step:1456/1695 train_time:142295ms step_avg:97.73ms +step:1457/1695 train_time:142393ms step_avg:97.73ms +step:1458/1695 train_time:142491ms step_avg:97.73ms +step:1459/1695 train_time:142590ms step_avg:97.73ms +step:1460/1695 train_time:142687ms step_avg:97.73ms +step:1461/1695 train_time:142784ms step_avg:97.73ms +step:1462/1695 train_time:142881ms step_avg:97.73ms +step:1463/1695 train_time:142979ms step_avg:97.73ms +step:1464/1695 train_time:143076ms step_avg:97.73ms +step:1465/1695 train_time:143174ms step_avg:97.73ms +step:1466/1695 train_time:143271ms step_avg:97.73ms +step:1467/1695 train_time:143369ms step_avg:97.73ms +step:1468/1695 train_time:143468ms step_avg:97.73ms +step:1469/1695 train_time:143568ms step_avg:97.73ms +step:1470/1695 train_time:143665ms step_avg:97.73ms +step:1471/1695 train_time:143763ms step_avg:97.73ms +step:1472/1695 train_time:143860ms step_avg:97.73ms +step:1473/1695 train_time:143957ms step_avg:97.73ms +step:1474/1695 train_time:144054ms step_avg:97.73ms +step:1475/1695 train_time:144151ms step_avg:97.73ms +step:1476/1695 train_time:144248ms step_avg:97.73ms +step:1477/1695 train_time:144346ms step_avg:97.73ms +step:1478/1695 train_time:144444ms step_avg:97.73ms +step:1479/1695 train_time:144542ms step_avg:97.73ms +step:1480/1695 train_time:144640ms step_avg:97.73ms +step:1481/1695 train_time:144737ms step_avg:97.73ms +step:1482/1695 train_time:144834ms step_avg:97.73ms +step:1483/1695 train_time:144931ms step_avg:97.73ms +step:1484/1695 train_time:145029ms step_avg:97.73ms +step:1485/1695 train_time:145127ms step_avg:97.73ms +step:1486/1695 train_time:145225ms step_avg:97.73ms +step:1487/1695 train_time:145322ms step_avg:97.73ms +step:1488/1695 train_time:145420ms step_avg:97.73ms +step:1489/1695 train_time:145518ms step_avg:97.73ms +step:1490/1695 train_time:145616ms step_avg:97.73ms +step:1491/1695 train_time:145714ms step_avg:97.73ms +step:1492/1695 train_time:145811ms step_avg:97.73ms +step:1493/1695 train_time:145909ms step_avg:97.73ms +step:1494/1695 train_time:146007ms step_avg:97.73ms +step:1495/1695 train_time:146105ms step_avg:97.73ms +step:1496/1695 train_time:146203ms step_avg:97.73ms +step:1497/1695 train_time:146300ms step_avg:97.73ms +step:1498/1695 train_time:146397ms step_avg:97.73ms +step:1499/1695 train_time:146494ms step_avg:97.73ms +step:1500/1695 train_time:146592ms step_avg:97.73ms +step:1500/1695 val_loss:3.3185 train_time:146688ms step_avg:97.79ms +step:1501/1695 train_time:146713ms step_avg:97.74ms +step:1502/1695 train_time:146798ms step_avg:97.74ms +step:1503/1695 train_time:146898ms step_avg:97.74ms +step:1504/1695 train_time:146996ms step_avg:97.74ms +step:1505/1695 train_time:147093ms step_avg:97.74ms +step:1506/1695 train_time:147190ms step_avg:97.74ms +step:1507/1695 train_time:147287ms step_avg:97.73ms +step:1508/1695 train_time:147383ms step_avg:97.73ms +step:1509/1695 train_time:147479ms step_avg:97.73ms +step:1510/1695 train_time:147576ms step_avg:97.73ms +step:1511/1695 train_time:147675ms step_avg:97.73ms +step:1512/1695 train_time:147778ms step_avg:97.74ms +step:1513/1695 train_time:147878ms step_avg:97.74ms +step:1514/1695 train_time:147977ms step_avg:97.74ms +step:1515/1695 train_time:148075ms step_avg:97.74ms +step:1516/1695 train_time:148173ms step_avg:97.74ms +step:1517/1695 train_time:148270ms step_avg:97.74ms +step:1518/1695 train_time:148368ms step_avg:97.74ms +step:1519/1695 train_time:148465ms step_avg:97.74ms +step:1520/1695 train_time:148562ms step_avg:97.74ms +step:1521/1695 train_time:148659ms 
step_avg:97.74ms +step:1522/1695 train_time:148758ms step_avg:97.74ms +step:1523/1695 train_time:148858ms step_avg:97.74ms +step:1524/1695 train_time:148957ms step_avg:97.74ms +step:1525/1695 train_time:149056ms step_avg:97.74ms +step:1526/1695 train_time:149154ms step_avg:97.74ms +step:1527/1695 train_time:149252ms step_avg:97.74ms +step:1528/1695 train_time:149349ms step_avg:97.74ms +step:1529/1695 train_time:149446ms step_avg:97.74ms +step:1530/1695 train_time:149544ms step_avg:97.74ms +step:1531/1695 train_time:149641ms step_avg:97.74ms +step:1532/1695 train_time:149739ms step_avg:97.74ms +step:1533/1695 train_time:149836ms step_avg:97.74ms +step:1534/1695 train_time:149935ms step_avg:97.74ms +step:1535/1695 train_time:150033ms step_avg:97.74ms +step:1536/1695 train_time:150131ms step_avg:97.74ms +step:1537/1695 train_time:150228ms step_avg:97.74ms +step:1538/1695 train_time:150325ms step_avg:97.74ms +step:1539/1695 train_time:150423ms step_avg:97.74ms +step:1540/1695 train_time:150520ms step_avg:97.74ms +step:1541/1695 train_time:150617ms step_avg:97.74ms +step:1542/1695 train_time:150715ms step_avg:97.74ms +step:1543/1695 train_time:150815ms step_avg:97.74ms +step:1544/1695 train_time:150914ms step_avg:97.74ms +step:1545/1695 train_time:151013ms step_avg:97.74ms +step:1546/1695 train_time:151111ms step_avg:97.74ms +step:1547/1695 train_time:151209ms step_avg:97.74ms +step:1548/1695 train_time:151307ms step_avg:97.74ms +step:1549/1695 train_time:151404ms step_avg:97.74ms +step:1550/1695 train_time:151502ms step_avg:97.74ms +step:1551/1695 train_time:151599ms step_avg:97.74ms +step:1552/1695 train_time:152044ms step_avg:97.97ms +step:1553/1695 train_time:152117ms step_avg:97.95ms +step:1554/1695 train_time:152213ms step_avg:97.95ms +step:1555/1695 train_time:152309ms step_avg:97.95ms +step:1556/1695 train_time:152406ms step_avg:97.95ms +step:1557/1695 train_time:152502ms step_avg:97.95ms +step:1558/1695 train_time:152598ms step_avg:97.94ms +step:1559/1695 train_time:152694ms step_avg:97.94ms +step:1560/1695 train_time:152791ms step_avg:97.94ms +step:1561/1695 train_time:152888ms step_avg:97.94ms +step:1562/1695 train_time:152990ms step_avg:97.95ms +step:1563/1695 train_time:153092ms step_avg:97.95ms +step:1564/1695 train_time:153194ms step_avg:97.95ms +step:1565/1695 train_time:153294ms step_avg:97.95ms +step:1566/1695 train_time:153391ms step_avg:97.95ms +step:1567/1695 train_time:153490ms step_avg:97.95ms +step:1568/1695 train_time:153587ms step_avg:97.95ms +step:1569/1695 train_time:153684ms step_avg:97.95ms +step:1570/1695 train_time:153780ms step_avg:97.95ms +step:1571/1695 train_time:153877ms step_avg:97.95ms +step:1572/1695 train_time:153976ms step_avg:97.95ms +step:1573/1695 train_time:154077ms step_avg:97.95ms +step:1574/1695 train_time:154176ms step_avg:97.95ms +step:1575/1695 train_time:154276ms step_avg:97.95ms +step:1576/1695 train_time:154375ms step_avg:97.95ms +step:1577/1695 train_time:154473ms step_avg:97.95ms +step:1578/1695 train_time:154571ms step_avg:97.95ms +step:1579/1695 train_time:154669ms step_avg:97.95ms +step:1580/1695 train_time:154766ms step_avg:97.95ms +step:1581/1695 train_time:154863ms step_avg:97.95ms +step:1582/1695 train_time:154961ms step_avg:97.95ms +step:1583/1695 train_time:155059ms step_avg:97.95ms +step:1584/1695 train_time:155156ms step_avg:97.95ms +step:1585/1695 train_time:155255ms step_avg:97.95ms +step:1586/1695 train_time:155354ms step_avg:97.95ms +step:1587/1695 train_time:155454ms step_avg:97.95ms +step:1588/1695 train_time:155553ms 
step_avg:97.96ms +step:1589/1695 train_time:155651ms step_avg:97.96ms +step:1590/1695 train_time:155749ms step_avg:97.96ms +step:1591/1695 train_time:155847ms step_avg:97.96ms +step:1592/1695 train_time:155946ms step_avg:97.96ms +step:1593/1695 train_time:156045ms step_avg:97.96ms +step:1594/1695 train_time:156143ms step_avg:97.96ms +step:1595/1695 train_time:156241ms step_avg:97.96ms +step:1596/1695 train_time:156339ms step_avg:97.96ms +step:1597/1695 train_time:156436ms step_avg:97.96ms +step:1598/1695 train_time:156534ms step_avg:97.96ms +step:1599/1695 train_time:156632ms step_avg:97.96ms +step:1600/1695 train_time:156731ms step_avg:97.96ms +step:1601/1695 train_time:156829ms step_avg:97.96ms +step:1602/1695 train_time:156928ms step_avg:97.96ms +step:1603/1695 train_time:157027ms step_avg:97.96ms +step:1604/1695 train_time:157126ms step_avg:97.96ms +step:1605/1695 train_time:157224ms step_avg:97.96ms +step:1606/1695 train_time:157323ms step_avg:97.96ms +step:1607/1695 train_time:157420ms step_avg:97.96ms +step:1608/1695 train_time:157518ms step_avg:97.96ms +step:1609/1695 train_time:157615ms step_avg:97.96ms +step:1610/1695 train_time:157712ms step_avg:97.96ms +step:1611/1695 train_time:157810ms step_avg:97.96ms +step:1612/1695 train_time:157909ms step_avg:97.96ms +step:1613/1695 train_time:158007ms step_avg:97.96ms +step:1614/1695 train_time:158106ms step_avg:97.96ms +step:1615/1695 train_time:158204ms step_avg:97.96ms +step:1616/1695 train_time:158302ms step_avg:97.96ms +step:1617/1695 train_time:158400ms step_avg:97.96ms +step:1618/1695 train_time:158498ms step_avg:97.96ms +step:1619/1695 train_time:158594ms step_avg:97.96ms +step:1620/1695 train_time:158691ms step_avg:97.96ms +step:1621/1695 train_time:158789ms step_avg:97.96ms +step:1622/1695 train_time:158887ms step_avg:97.96ms +step:1623/1695 train_time:158984ms step_avg:97.96ms +step:1624/1695 train_time:159083ms step_avg:97.96ms +step:1625/1695 train_time:159181ms step_avg:97.96ms +step:1625/1695 val_loss:3.2909 train_time:159277ms step_avg:98.02ms +step:1626/1695 train_time:159302ms step_avg:97.97ms +step:1627/1695 train_time:159385ms step_avg:97.96ms +step:1628/1695 train_time:159484ms step_avg:97.96ms +step:1629/1695 train_time:159582ms step_avg:97.96ms +step:1630/1695 train_time:159678ms step_avg:97.96ms +step:1631/1695 train_time:159775ms step_avg:97.96ms +step:1632/1695 train_time:159872ms step_avg:97.96ms +step:1633/1695 train_time:159970ms step_avg:97.96ms +step:1634/1695 train_time:160066ms step_avg:97.96ms +step:1635/1695 train_time:160163ms step_avg:97.96ms +step:1636/1695 train_time:160263ms step_avg:97.96ms +step:1637/1695 train_time:160363ms step_avg:97.96ms +step:1638/1695 train_time:160463ms step_avg:97.96ms +step:1639/1695 train_time:160561ms step_avg:97.96ms +step:1640/1695 train_time:160658ms step_avg:97.96ms +step:1641/1695 train_time:160756ms step_avg:97.96ms +step:1642/1695 train_time:160853ms step_avg:97.96ms +step:1643/1695 train_time:160949ms step_avg:97.96ms +step:1644/1695 train_time:161046ms step_avg:97.96ms +step:1645/1695 train_time:161142ms step_avg:97.96ms +step:1646/1695 train_time:161240ms step_avg:97.96ms +step:1647/1695 train_time:161339ms step_avg:97.96ms +step:1648/1695 train_time:161438ms step_avg:97.96ms +step:1649/1695 train_time:161536ms step_avg:97.96ms +step:1650/1695 train_time:161634ms step_avg:97.96ms +step:1651/1695 train_time:161732ms step_avg:97.96ms +step:1652/1695 train_time:161831ms step_avg:97.96ms +step:1653/1695 train_time:161930ms step_avg:97.96ms +step:1654/1695 
train_time:162027ms step_avg:97.96ms +step:1655/1695 train_time:162124ms step_avg:97.96ms +step:1656/1695 train_time:162223ms step_avg:97.96ms +step:1657/1695 train_time:162320ms step_avg:97.96ms +step:1658/1695 train_time:162418ms step_avg:97.96ms +step:1659/1695 train_time:162516ms step_avg:97.96ms +step:1660/1695 train_time:162614ms step_avg:97.96ms +step:1661/1695 train_time:162712ms step_avg:97.96ms +step:1662/1695 train_time:162811ms step_avg:97.96ms +step:1663/1695 train_time:162909ms step_avg:97.96ms +step:1664/1695 train_time:163007ms step_avg:97.96ms +step:1665/1695 train_time:163105ms step_avg:97.96ms +step:1666/1695 train_time:163204ms step_avg:97.96ms +step:1667/1695 train_time:163303ms step_avg:97.96ms +step:1668/1695 train_time:163402ms step_avg:97.96ms +step:1669/1695 train_time:163499ms step_avg:97.96ms +step:1670/1695 train_time:163597ms step_avg:97.96ms +step:1671/1695 train_time:163694ms step_avg:97.96ms +step:1672/1695 train_time:163791ms step_avg:97.96ms +step:1673/1695 train_time:163889ms step_avg:97.96ms +step:1674/1695 train_time:163987ms step_avg:97.96ms +step:1675/1695 train_time:164085ms step_avg:97.96ms +step:1676/1695 train_time:164183ms step_avg:97.96ms +step:1677/1695 train_time:164281ms step_avg:97.96ms +step:1678/1695 train_time:164379ms step_avg:97.96ms +step:1679/1695 train_time:164476ms step_avg:97.96ms +step:1680/1695 train_time:164573ms step_avg:97.96ms +step:1681/1695 train_time:164671ms step_avg:97.96ms +step:1682/1695 train_time:164768ms step_avg:97.96ms +step:1683/1695 train_time:164866ms step_avg:97.96ms +step:1684/1695 train_time:164963ms step_avg:97.96ms +step:1685/1695 train_time:165060ms step_avg:97.96ms +step:1686/1695 train_time:165157ms step_avg:97.96ms +step:1687/1695 train_time:165255ms step_avg:97.96ms +step:1688/1695 train_time:165354ms step_avg:97.96ms +step:1689/1695 train_time:165453ms step_avg:97.96ms +step:1690/1695 train_time:165551ms step_avg:97.96ms +step:1691/1695 train_time:165650ms step_avg:97.96ms +step:1692/1695 train_time:165748ms step_avg:97.96ms +step:1693/1695 train_time:165845ms step_avg:97.96ms +step:1694/1695 train_time:165943ms step_avg:97.96ms +step:1695/1695 train_time:166040ms step_avg:97.96ms +step:1695/1695 val_loss:3.2791 train_time:166135ms step_avg:98.01ms +peak memory allocated: 34000 MiB reserved: 49416 MiB diff --git a/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt b/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt new file mode 100644 index 000000000..f68fe219a --- /dev/null +++ b/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
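+        # Sharding scheme (ZeRO-style): parameters within each same-shape group
+        # are assigned round-robin to ranks. reduce_scatter averages gradients
+        # across ranks and hands each rank the full gradient of its assigned
+        # parameter; that rank alone applies weight decay, momentum, and the
+        # Newton-Schulz orthogonalization, and all_gather then broadcasts the
+        # updated parameters back to every rank.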
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + 
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +class EOSBatchFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): + # Precompute EOS positions once per shard + self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 # pointer into eos_idx (start EOS for next step) + self.pos = 0 # logical stream position within this shard + self.world_size = world_size + def seek(self, pos: int): + # Set pointer to the first EOS >= pos + self.i = np.searchsorted(self.eos_idx, pos) + if self.i >= len(self.eos_idx): + raise StopIteration("Seek past last EOS.") + self.pos = pos + def next_batch(self, batch_size_local: int, seq_len: int): + n = len(self.eos_idx) + if self.i >= n: + raise StopIteration("No more EOS in this shard.") + starts = [[] for _ in range(self.world_size)] + idx = self.i + cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 + for r in range(self.world_size): + for _ in range(batch_size_local): + start = cur + 1 + target = start + seq_len # need seq_len tokens before next EOS + j = np.searchsorted(self.eos_idx, target) + if j >= n: + raise StopIteration("Insufficient EOS ahead; hit tail of shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + + while True: + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
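+                # Note: pos resets to 0 and a fresh EOSBatchFinder is built for
+                # the new shard, so any documents at the tail of the old shard
+                # that could not fill a batch are simply dropped.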
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+                finder = EOSBatchFinder(tokens, world_size=world_size)
+                continue
+
+            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
+            buf = torch.stack(bufs, dim=0)
+            _inputs = buf[:, :-1]
+            _targets = buf[:, 1:]
+        else:
+            batch_span = num_tokens_global
+            start_pos_local = pos + rank * (batch_size_local * seq_len)
+            end_pos_local = start_pos_local + (batch_size_local * seq_len)
+
+            buf = tokens[start_pos_local: end_pos_local + 1]
+
+            _inputs = buf[:-1].view(batch_size_local, seq_len)
+            _targets = buf[1:].view(batch_size_local, seq_len)
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
+        )
+
+        pos += batch_span
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (batch_size, seq_len) via .send()
+            new_batch_size, new_seq_len = new_params
+            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
+            batch_size = new_batch_size
+            seq_len = new_seq_len
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_seq_len: int = 1024 * 2
+    train_batch_size: int = 24 * 8
+    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
+    # optimization
+    num_iterations: int = 1695 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    bandwidth: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
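+# (Editor's illustrative sketch, not part of the recorded run.) A minimal
+# single-device version of the Muon update defined above, with the distributed
+# reduce-scatter/all-gather plumbing stripped away. Assumes p is a 2D parameter
+# with p.grad populated and momentum_buffer allocated as torch.zeros_like(p.grad);
+# newton_schulz_triton is the routine defined earlier in this file.
+@torch.no_grad()
+def _muon_update_reference(p: Tensor, momentum_buffer: Tensor, lr: float = 0.05, momentum: float = 0.95):
+    momentum_buffer.lerp_(p.grad, 1 - momentum)            # running average of gradients
+    g = p.grad.lerp(momentum_buffer, momentum)             # Nesterov-style blend of grad and momentum
+    v = newton_schulz_triton(g)                            # orthogonalize the update
+    eff_lr = lr * max(1, p.size(-2) / p.size(-1)) ** 0.5   # aspect-ratio learning-rate scaling
+    p.add_(v, alpha=-eff_lr)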
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_seq_len, args.val_seq_len)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr_and_ws(step: int):
+    x = step / (1 + args.num_iterations) # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return lr, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 60
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+for step in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, ws, ws // 2).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
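+# (Editor's note, illustrative.) Sample values of get_lr_and_ws above, with
+# num_iterations=1695, cooldown_frac=0.45, ws_schedule=(3, 7, 11):
+#   step    0 -> x~0.00: lr multiplier 1.00, ws 3   (constant LR, short windows)
+#   step  848 -> x~0.50: lr multiplier 1.00, ws 7   (still constant; windows widen)
+#   step 1356 -> x~0.80: lr multiplier 0.50, ws 11  (linear cooldown toward 0.1)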
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + lr, ws = get_lr_and_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % (world_size * args.val_seq_len) == 0 + val_steps = args.val_tokens // (world_size * args.val_seq_len) + val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, ws, ws // 2) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 04:10:32 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 29C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 31C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 29C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 33C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 32C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1695 train_time:507ms step_avg:507.01ms +step:2/1695 train_time:531ms step_avg:265.61ms +step:3/1695 train_time:603ms step_avg:201.15ms +step:4/1695 train_time:695ms step_avg:173.81ms +step:5/1695 train_time:789ms step_avg:157.80ms +step:6/1695 train_time:883ms step_avg:147.20ms +step:7/1695 train_time:976ms step_avg:139.42ms +step:8/1695 train_time:1069ms step_avg:133.63ms +step:9/1695 train_time:1162ms step_avg:129.16ms +step:10/1695 train_time:1256ms step_avg:125.61ms +step:11/1695 train_time:1350ms step_avg:122.74ms +step:12/1695 train_time:1446ms step_avg:120.48ms +step:13/1695 train_time:1542ms step_avg:118.60ms +step:14/1695 train_time:1637ms step_avg:116.89ms +step:15/1695 train_time:1731ms step_avg:115.41ms +step:16/1695 train_time:1827ms step_avg:114.16ms +step:17/1695 train_time:1921ms step_avg:112.98ms +step:18/1695 train_time:2014ms step_avg:111.90ms +step:19/1695 train_time:2108ms step_avg:110.97ms +step:20/1695 train_time:2204ms step_avg:110.18ms +step:21/1695 train_time:2297ms step_avg:109.37ms +step:22/1695 train_time:2392ms step_avg:108.71ms 
+step:23/1695 train_time:2487ms step_avg:108.14ms +step:24/1695 train_time:2583ms step_avg:107.62ms +step:25/1695 train_time:2677ms step_avg:107.07ms +step:26/1695 train_time:2771ms step_avg:106.58ms +step:27/1695 train_time:2866ms step_avg:106.14ms +step:28/1695 train_time:2960ms step_avg:105.72ms +step:29/1695 train_time:3054ms step_avg:105.31ms +step:30/1695 train_time:3149ms step_avg:104.96ms +step:31/1695 train_time:3243ms step_avg:104.62ms +step:32/1695 train_time:3337ms step_avg:104.27ms +step:33/1695 train_time:3432ms step_avg:104.01ms +step:34/1695 train_time:3529ms step_avg:103.78ms +step:35/1695 train_time:3624ms step_avg:103.56ms +step:36/1695 train_time:3719ms step_avg:103.30ms +step:37/1695 train_time:3814ms step_avg:103.07ms +step:38/1695 train_time:3909ms step_avg:102.86ms +step:39/1695 train_time:4004ms step_avg:102.66ms +step:40/1695 train_time:4098ms step_avg:102.45ms +step:41/1695 train_time:4192ms step_avg:102.24ms +step:42/1695 train_time:4286ms step_avg:102.04ms +step:43/1695 train_time:4380ms step_avg:101.86ms +step:44/1695 train_time:4474ms step_avg:101.68ms +step:45/1695 train_time:4569ms step_avg:101.53ms +step:46/1695 train_time:4664ms step_avg:101.39ms +step:47/1695 train_time:4757ms step_avg:101.22ms +step:48/1695 train_time:4852ms step_avg:101.07ms +step:49/1695 train_time:4947ms step_avg:100.95ms +step:50/1695 train_time:5041ms step_avg:100.82ms +step:51/1695 train_time:5135ms step_avg:100.68ms +step:52/1695 train_time:5229ms step_avg:100.56ms +step:53/1695 train_time:5324ms step_avg:100.45ms +step:54/1695 train_time:5418ms step_avg:100.33ms +step:55/1695 train_time:5513ms step_avg:100.23ms +step:56/1695 train_time:5608ms step_avg:100.14ms +step:57/1695 train_time:5703ms step_avg:100.05ms +step:58/1695 train_time:5796ms step_avg:99.94ms +step:59/1695 train_time:5891ms step_avg:99.85ms +step:60/1695 train_time:5986ms step_avg:99.76ms +step:61/1695 train_time:6079ms step_avg:99.66ms +step:62/1695 train_time:6173ms step_avg:99.56ms +step:63/1695 train_time:6268ms step_avg:99.50ms +step:64/1695 train_time:6364ms step_avg:99.44ms +step:65/1695 train_time:6458ms step_avg:99.36ms +step:66/1695 train_time:6553ms step_avg:99.29ms +step:67/1695 train_time:6649ms step_avg:99.24ms +step:68/1695 train_time:6745ms step_avg:99.19ms +step:69/1695 train_time:6837ms step_avg:99.08ms +step:70/1695 train_time:6931ms step_avg:99.02ms +step:71/1695 train_time:7026ms step_avg:98.95ms +step:72/1695 train_time:7119ms step_avg:98.87ms +step:73/1695 train_time:7213ms step_avg:98.81ms +step:74/1695 train_time:7307ms step_avg:98.75ms +step:75/1695 train_time:7402ms step_avg:98.69ms +step:76/1695 train_time:7495ms step_avg:98.62ms +step:77/1695 train_time:7590ms step_avg:98.58ms +step:78/1695 train_time:7686ms step_avg:98.53ms +step:79/1695 train_time:7779ms step_avg:98.47ms +step:80/1695 train_time:7873ms step_avg:98.41ms +step:81/1695 train_time:7969ms step_avg:98.38ms +step:82/1695 train_time:8064ms step_avg:98.34ms +step:83/1695 train_time:8158ms step_avg:98.28ms +step:84/1695 train_time:8252ms step_avg:98.23ms +step:85/1695 train_time:8346ms step_avg:98.18ms +step:86/1695 train_time:8439ms step_avg:98.13ms +step:87/1695 train_time:8533ms step_avg:98.08ms +step:88/1695 train_time:8629ms step_avg:98.06ms +step:89/1695 train_time:8724ms step_avg:98.02ms +step:90/1695 train_time:8817ms step_avg:97.97ms +step:91/1695 train_time:8911ms step_avg:97.93ms +step:92/1695 train_time:9006ms step_avg:97.89ms +step:93/1695 train_time:9100ms step_avg:97.85ms +step:94/1695 train_time:9194ms 
step_avg:97.80ms +step:95/1695 train_time:9289ms step_avg:97.77ms +step:96/1695 train_time:9383ms step_avg:97.74ms +step:97/1695 train_time:9476ms step_avg:97.69ms +step:98/1695 train_time:9570ms step_avg:97.66ms +step:99/1695 train_time:9665ms step_avg:97.63ms +step:100/1695 train_time:9760ms step_avg:97.60ms +step:101/1695 train_time:9854ms step_avg:97.56ms +step:102/1695 train_time:9949ms step_avg:97.54ms +step:103/1695 train_time:10043ms step_avg:97.51ms +step:104/1695 train_time:10137ms step_avg:97.47ms +step:105/1695 train_time:10231ms step_avg:97.44ms +step:106/1695 train_time:10326ms step_avg:97.41ms +step:107/1695 train_time:10420ms step_avg:97.38ms +step:108/1695 train_time:10513ms step_avg:97.34ms +step:109/1695 train_time:10608ms step_avg:97.32ms +step:110/1695 train_time:10703ms step_avg:97.30ms +step:111/1695 train_time:10797ms step_avg:97.27ms +step:112/1695 train_time:10891ms step_avg:97.24ms +step:113/1695 train_time:10985ms step_avg:97.21ms +step:114/1695 train_time:11078ms step_avg:97.18ms +step:115/1695 train_time:11172ms step_avg:97.15ms +step:116/1695 train_time:11267ms step_avg:97.13ms +step:117/1695 train_time:11362ms step_avg:97.11ms +step:118/1695 train_time:11456ms step_avg:97.08ms +step:119/1695 train_time:11550ms step_avg:97.06ms +step:120/1695 train_time:11644ms step_avg:97.03ms +step:121/1695 train_time:11737ms step_avg:97.00ms +step:122/1695 train_time:11832ms step_avg:96.98ms +step:123/1695 train_time:11927ms step_avg:96.97ms +step:124/1695 train_time:12022ms step_avg:96.95ms +step:125/1695 train_time:12115ms step_avg:96.92ms +step:125/1695 val_loss:4.3104 train_time:12207ms step_avg:97.66ms +step:126/1695 train_time:12232ms step_avg:97.08ms +step:127/1695 train_time:12310ms step_avg:96.93ms +step:128/1695 train_time:12410ms step_avg:96.96ms +step:129/1695 train_time:12506ms step_avg:96.94ms +step:130/1695 train_time:12599ms step_avg:96.92ms +step:131/1695 train_time:12692ms step_avg:96.89ms +step:132/1695 train_time:12785ms step_avg:96.86ms +step:133/1695 train_time:12878ms step_avg:96.83ms +step:134/1695 train_time:12972ms step_avg:96.80ms +step:135/1695 train_time:13064ms step_avg:96.77ms +step:136/1695 train_time:13157ms step_avg:96.75ms +step:137/1695 train_time:13252ms step_avg:96.73ms +step:138/1695 train_time:13348ms step_avg:96.73ms +step:139/1695 train_time:13444ms step_avg:96.72ms +step:140/1695 train_time:13539ms step_avg:96.70ms +step:141/1695 train_time:13633ms step_avg:96.69ms +step:142/1695 train_time:13727ms step_avg:96.67ms +step:143/1695 train_time:13821ms step_avg:96.65ms +step:144/1695 train_time:13916ms step_avg:96.64ms +step:145/1695 train_time:14010ms step_avg:96.62ms +step:146/1695 train_time:14103ms step_avg:96.59ms +step:147/1695 train_time:14197ms step_avg:96.58ms +step:148/1695 train_time:14293ms step_avg:96.57ms +step:149/1695 train_time:14387ms step_avg:96.56ms +step:150/1695 train_time:14482ms step_avg:96.55ms +step:151/1695 train_time:14578ms step_avg:96.55ms +step:152/1695 train_time:14673ms step_avg:96.53ms +step:153/1695 train_time:14766ms step_avg:96.51ms +step:154/1695 train_time:14860ms step_avg:96.49ms +step:155/1695 train_time:14953ms step_avg:96.47ms +step:156/1695 train_time:15046ms step_avg:96.45ms +step:157/1695 train_time:15140ms step_avg:96.43ms +step:158/1695 train_time:15233ms step_avg:96.41ms +step:159/1695 train_time:15327ms step_avg:96.40ms +step:160/1695 train_time:15422ms step_avg:96.39ms +step:161/1695 train_time:15517ms step_avg:96.38ms +step:162/1695 train_time:15613ms step_avg:96.38ms +step:163/1695 
train_time:15706ms step_avg:96.36ms +step:164/1695 train_time:15800ms step_avg:96.34ms +step:165/1695 train_time:15893ms step_avg:96.32ms +step:166/1695 train_time:15987ms step_avg:96.30ms +step:167/1695 train_time:16080ms step_avg:96.29ms +step:168/1695 train_time:16174ms step_avg:96.27ms +step:169/1695 train_time:16267ms step_avg:96.26ms +step:170/1695 train_time:16361ms step_avg:96.24ms +step:171/1695 train_time:16455ms step_avg:96.23ms +step:172/1695 train_time:16551ms step_avg:96.22ms +step:173/1695 train_time:16932ms step_avg:97.87ms +step:174/1695 train_time:17007ms step_avg:97.74ms +step:175/1695 train_time:17099ms step_avg:97.71ms +step:176/1695 train_time:17192ms step_avg:97.68ms +step:177/1695 train_time:17285ms step_avg:97.65ms +step:178/1695 train_time:17378ms step_avg:97.63ms +step:179/1695 train_time:17472ms step_avg:97.61ms +step:180/1695 train_time:17565ms step_avg:97.58ms +step:181/1695 train_time:17658ms step_avg:97.56ms +step:182/1695 train_time:17752ms step_avg:97.54ms +step:183/1695 train_time:17847ms step_avg:97.53ms +step:184/1695 train_time:17944ms step_avg:97.52ms +step:185/1695 train_time:18041ms step_avg:97.52ms +step:186/1695 train_time:18135ms step_avg:97.50ms +step:187/1695 train_time:18229ms step_avg:97.48ms +step:188/1695 train_time:18322ms step_avg:97.46ms +step:189/1695 train_time:18416ms step_avg:97.44ms +step:190/1695 train_time:18510ms step_avg:97.42ms +step:191/1695 train_time:18602ms step_avg:97.40ms +step:192/1695 train_time:18696ms step_avg:97.38ms +step:193/1695 train_time:18790ms step_avg:97.36ms +step:194/1695 train_time:18885ms step_avg:97.34ms +step:195/1695 train_time:18980ms step_avg:97.33ms +step:196/1695 train_time:19075ms step_avg:97.32ms +step:197/1695 train_time:19169ms step_avg:97.31ms +step:198/1695 train_time:19263ms step_avg:97.29ms +step:199/1695 train_time:19357ms step_avg:97.27ms +step:200/1695 train_time:19451ms step_avg:97.26ms +step:201/1695 train_time:19544ms step_avg:97.23ms +step:202/1695 train_time:19638ms step_avg:97.22ms +step:203/1695 train_time:19733ms step_avg:97.21ms +step:204/1695 train_time:19826ms step_avg:97.19ms +step:205/1695 train_time:19921ms step_avg:97.18ms +step:206/1695 train_time:20016ms step_avg:97.17ms +step:207/1695 train_time:20111ms step_avg:97.15ms +step:208/1695 train_time:20204ms step_avg:97.14ms +step:209/1695 train_time:20299ms step_avg:97.12ms +step:210/1695 train_time:20393ms step_avg:97.11ms +step:211/1695 train_time:20487ms step_avg:97.10ms +step:212/1695 train_time:20581ms step_avg:97.08ms +step:213/1695 train_time:20674ms step_avg:97.06ms +step:214/1695 train_time:20768ms step_avg:97.05ms +step:215/1695 train_time:20862ms step_avg:97.03ms +step:216/1695 train_time:20957ms step_avg:97.02ms +step:217/1695 train_time:21051ms step_avg:97.01ms +step:218/1695 train_time:21145ms step_avg:97.00ms +step:219/1695 train_time:21240ms step_avg:96.98ms +step:220/1695 train_time:21334ms step_avg:96.97ms +step:221/1695 train_time:21428ms step_avg:96.96ms +step:222/1695 train_time:21521ms step_avg:96.94ms +step:223/1695 train_time:21616ms step_avg:96.93ms +step:224/1695 train_time:21710ms step_avg:96.92ms +step:225/1695 train_time:21804ms step_avg:96.91ms +step:226/1695 train_time:21899ms step_avg:96.90ms +step:227/1695 train_time:21994ms step_avg:96.89ms +step:228/1695 train_time:22087ms step_avg:96.87ms +step:229/1695 train_time:22181ms step_avg:96.86ms +step:230/1695 train_time:22276ms step_avg:96.85ms +step:231/1695 train_time:22371ms step_avg:96.85ms +step:232/1695 train_time:22465ms step_avg:96.83ms 
+step:233/1695 train_time:22559ms step_avg:96.82ms +step:234/1695 train_time:22653ms step_avg:96.81ms +step:235/1695 train_time:22747ms step_avg:96.79ms +step:236/1695 train_time:22840ms step_avg:96.78ms +step:237/1695 train_time:22935ms step_avg:96.77ms +step:238/1695 train_time:23029ms step_avg:96.76ms +step:239/1695 train_time:23122ms step_avg:96.75ms +step:240/1695 train_time:23217ms step_avg:96.74ms +step:241/1695 train_time:23312ms step_avg:96.73ms +step:242/1695 train_time:23405ms step_avg:96.72ms +step:243/1695 train_time:23499ms step_avg:96.71ms +step:244/1695 train_time:23593ms step_avg:96.69ms +step:245/1695 train_time:23686ms step_avg:96.68ms +step:246/1695 train_time:23780ms step_avg:96.67ms +step:247/1695 train_time:23875ms step_avg:96.66ms +step:248/1695 train_time:23969ms step_avg:96.65ms +step:249/1695 train_time:24063ms step_avg:96.64ms +step:250/1695 train_time:24157ms step_avg:96.63ms +step:250/1695 val_loss:3.9654 train_time:24251ms step_avg:97.00ms +step:251/1695 train_time:24276ms step_avg:96.72ms +step:252/1695 train_time:24355ms step_avg:96.65ms +step:253/1695 train_time:24453ms step_avg:96.65ms +step:254/1695 train_time:24548ms step_avg:96.65ms +step:255/1695 train_time:24642ms step_avg:96.64ms +step:256/1695 train_time:24735ms step_avg:96.62ms +step:257/1695 train_time:24828ms step_avg:96.61ms +step:258/1695 train_time:24921ms step_avg:96.59ms +step:259/1695 train_time:25014ms step_avg:96.58ms +step:260/1695 train_time:25108ms step_avg:96.57ms +step:261/1695 train_time:25201ms step_avg:96.55ms +step:262/1695 train_time:25295ms step_avg:96.55ms +step:263/1695 train_time:25391ms step_avg:96.54ms +step:264/1695 train_time:25487ms step_avg:96.54ms +step:265/1695 train_time:25582ms step_avg:96.54ms +step:266/1695 train_time:25676ms step_avg:96.52ms +step:267/1695 train_time:25769ms step_avg:96.51ms +step:268/1695 train_time:25862ms step_avg:96.50ms +step:269/1695 train_time:25955ms step_avg:96.49ms +step:270/1695 train_time:26048ms step_avg:96.47ms +step:271/1695 train_time:26141ms step_avg:96.46ms +step:272/1695 train_time:26235ms step_avg:96.45ms +step:273/1695 train_time:26330ms step_avg:96.45ms +step:274/1695 train_time:26426ms step_avg:96.45ms +step:275/1695 train_time:26522ms step_avg:96.44ms +step:276/1695 train_time:26616ms step_avg:96.44ms +step:277/1695 train_time:26710ms step_avg:96.43ms +step:278/1695 train_time:26803ms step_avg:96.42ms +step:279/1695 train_time:26897ms step_avg:96.40ms +step:280/1695 train_time:26990ms step_avg:96.39ms +step:281/1695 train_time:27083ms step_avg:96.38ms +step:282/1695 train_time:27176ms step_avg:96.37ms +step:283/1695 train_time:27270ms step_avg:96.36ms +step:284/1695 train_time:27365ms step_avg:96.36ms +step:285/1695 train_time:27459ms step_avg:96.35ms +step:286/1695 train_time:27553ms step_avg:96.34ms +step:287/1695 train_time:27648ms step_avg:96.34ms +step:288/1695 train_time:27742ms step_avg:96.33ms +step:289/1695 train_time:27836ms step_avg:96.32ms +step:290/1695 train_time:27930ms step_avg:96.31ms +step:291/1695 train_time:28024ms step_avg:96.30ms +step:292/1695 train_time:28118ms step_avg:96.29ms +step:293/1695 train_time:28211ms step_avg:96.28ms +step:294/1695 train_time:28305ms step_avg:96.28ms +step:295/1695 train_time:28399ms step_avg:96.27ms +step:296/1695 train_time:28493ms step_avg:96.26ms +step:297/1695 train_time:28587ms step_avg:96.25ms +step:298/1695 train_time:28682ms step_avg:96.25ms +step:299/1695 train_time:28776ms step_avg:96.24ms +step:300/1695 train_time:28870ms step_avg:96.23ms +step:301/1695 
train_time:28964ms step_avg:96.23ms +step:302/1695 train_time:29057ms step_avg:96.22ms +step:303/1695 train_time:29151ms step_avg:96.21ms +step:304/1695 train_time:29245ms step_avg:96.20ms +step:305/1695 train_time:29340ms step_avg:96.20ms +step:306/1695 train_time:29434ms step_avg:96.19ms +step:307/1695 train_time:29528ms step_avg:96.18ms +step:308/1695 train_time:29623ms step_avg:96.18ms +step:309/1695 train_time:29717ms step_avg:96.17ms +step:310/1695 train_time:29812ms step_avg:96.17ms +step:311/1695 train_time:29907ms step_avg:96.16ms +step:312/1695 train_time:30001ms step_avg:96.16ms +step:313/1695 train_time:30094ms step_avg:96.15ms +step:314/1695 train_time:30188ms step_avg:96.14ms +step:315/1695 train_time:30282ms step_avg:96.13ms +step:316/1695 train_time:30375ms step_avg:96.12ms +step:317/1695 train_time:30469ms step_avg:96.12ms +step:318/1695 train_time:30563ms step_avg:96.11ms +step:319/1695 train_time:30658ms step_avg:96.11ms +step:320/1695 train_time:30751ms step_avg:96.10ms +step:321/1695 train_time:30846ms step_avg:96.09ms +step:322/1695 train_time:30941ms step_avg:96.09ms +step:323/1695 train_time:31034ms step_avg:96.08ms +step:324/1695 train_time:31128ms step_avg:96.07ms +step:325/1695 train_time:31223ms step_avg:96.07ms +step:326/1695 train_time:31317ms step_avg:96.06ms +step:327/1695 train_time:31410ms step_avg:96.06ms +step:328/1695 train_time:31504ms step_avg:96.05ms +step:329/1695 train_time:31597ms step_avg:96.04ms +step:330/1695 train_time:31691ms step_avg:96.03ms +step:331/1695 train_time:31785ms step_avg:96.03ms +step:332/1695 train_time:31881ms step_avg:96.03ms +step:333/1695 train_time:31974ms step_avg:96.02ms +step:334/1695 train_time:32069ms step_avg:96.01ms +step:335/1695 train_time:32163ms step_avg:96.01ms +step:336/1695 train_time:32257ms step_avg:96.00ms +step:337/1695 train_time:32350ms step_avg:95.99ms +step:338/1695 train_time:32446ms step_avg:95.99ms +step:339/1695 train_time:32541ms step_avg:95.99ms +step:340/1695 train_time:32634ms step_avg:95.98ms +step:341/1695 train_time:32728ms step_avg:95.98ms +step:342/1695 train_time:32823ms step_avg:95.97ms +step:343/1695 train_time:32916ms step_avg:95.97ms +step:344/1695 train_time:33010ms step_avg:95.96ms +step:345/1695 train_time:33345ms step_avg:96.65ms +step:346/1695 train_time:33427ms step_avg:96.61ms +step:347/1695 train_time:33519ms step_avg:96.60ms +step:348/1695 train_time:33612ms step_avg:96.59ms +step:349/1695 train_time:33705ms step_avg:96.58ms +step:350/1695 train_time:33798ms step_avg:96.57ms +step:351/1695 train_time:33891ms step_avg:96.55ms +step:352/1695 train_time:33984ms step_avg:96.54ms +step:353/1695 train_time:34077ms step_avg:96.53ms +step:354/1695 train_time:34169ms step_avg:96.52ms +step:355/1695 train_time:34266ms step_avg:96.52ms +step:356/1695 train_time:34363ms step_avg:96.52ms +step:357/1695 train_time:34458ms step_avg:96.52ms +step:358/1695 train_time:34552ms step_avg:96.51ms +step:359/1695 train_time:34646ms step_avg:96.51ms +step:360/1695 train_time:34741ms step_avg:96.50ms +step:361/1695 train_time:34834ms step_avg:96.49ms +step:362/1695 train_time:34927ms step_avg:96.48ms +step:363/1695 train_time:35020ms step_avg:96.47ms +step:364/1695 train_time:35113ms step_avg:96.46ms +step:365/1695 train_time:35207ms step_avg:96.46ms +step:366/1695 train_time:35303ms step_avg:96.45ms +step:367/1695 train_time:35398ms step_avg:96.45ms +step:368/1695 train_time:35492ms step_avg:96.45ms +step:369/1695 train_time:35587ms step_avg:96.44ms +step:370/1695 train_time:35681ms step_avg:96.44ms 
+step:371/1695 train_time:35775ms step_avg:96.43ms +step:372/1695 train_time:35868ms step_avg:96.42ms +step:373/1695 train_time:35961ms step_avg:96.41ms +step:374/1695 train_time:36054ms step_avg:96.40ms +step:375/1695 train_time:36147ms step_avg:96.39ms +step:375/1695 val_loss:3.8151 train_time:36240ms step_avg:96.64ms +step:376/1695 train_time:36265ms step_avg:96.45ms +step:377/1695 train_time:36343ms step_avg:96.40ms +step:378/1695 train_time:36441ms step_avg:96.41ms +step:379/1695 train_time:36535ms step_avg:96.40ms +step:380/1695 train_time:36628ms step_avg:96.39ms +step:381/1695 train_time:36721ms step_avg:96.38ms +step:382/1695 train_time:36814ms step_avg:96.37ms +step:383/1695 train_time:36907ms step_avg:96.36ms +step:384/1695 train_time:36999ms step_avg:96.35ms +step:385/1695 train_time:37092ms step_avg:96.34ms +step:386/1695 train_time:37185ms step_avg:96.33ms +step:387/1695 train_time:37280ms step_avg:96.33ms +step:388/1695 train_time:37376ms step_avg:96.33ms +step:389/1695 train_time:37474ms step_avg:96.33ms +step:390/1695 train_time:37570ms step_avg:96.33ms +step:391/1695 train_time:37665ms step_avg:96.33ms +step:392/1695 train_time:37758ms step_avg:96.32ms +step:393/1695 train_time:37851ms step_avg:96.31ms +step:394/1695 train_time:37945ms step_avg:96.31ms +step:395/1695 train_time:38038ms step_avg:96.30ms +step:396/1695 train_time:38132ms step_avg:96.29ms +step:397/1695 train_time:38226ms step_avg:96.29ms +step:398/1695 train_time:38320ms step_avg:96.28ms +step:399/1695 train_time:38414ms step_avg:96.28ms +step:400/1695 train_time:38510ms step_avg:96.27ms +step:401/1695 train_time:38606ms step_avg:96.27ms +step:402/1695 train_time:38700ms step_avg:96.27ms +step:403/1695 train_time:38793ms step_avg:96.26ms +step:404/1695 train_time:38888ms step_avg:96.26ms +step:405/1695 train_time:38981ms step_avg:96.25ms +step:406/1695 train_time:39074ms step_avg:96.24ms +step:407/1695 train_time:39168ms step_avg:96.24ms +step:408/1695 train_time:39262ms step_avg:96.23ms +step:409/1695 train_time:39356ms step_avg:96.22ms +step:410/1695 train_time:39450ms step_avg:96.22ms +step:411/1695 train_time:39545ms step_avg:96.22ms +step:412/1695 train_time:39639ms step_avg:96.21ms +step:413/1695 train_time:39733ms step_avg:96.21ms +step:414/1695 train_time:39827ms step_avg:96.20ms +step:415/1695 train_time:39920ms step_avg:96.19ms +step:416/1695 train_time:40013ms step_avg:96.19ms +step:417/1695 train_time:40107ms step_avg:96.18ms +step:418/1695 train_time:40200ms step_avg:96.17ms +step:419/1695 train_time:40294ms step_avg:96.17ms +step:420/1695 train_time:40388ms step_avg:96.16ms +step:421/1695 train_time:40482ms step_avg:96.16ms +step:422/1695 train_time:40577ms step_avg:96.15ms +step:423/1695 train_time:40672ms step_avg:96.15ms +step:424/1695 train_time:40767ms step_avg:96.15ms +step:425/1695 train_time:40862ms step_avg:96.15ms +step:426/1695 train_time:40955ms step_avg:96.14ms +step:427/1695 train_time:41049ms step_avg:96.13ms +step:428/1695 train_time:41143ms step_avg:96.13ms +step:429/1695 train_time:41236ms step_avg:96.12ms +step:430/1695 train_time:41330ms step_avg:96.12ms +step:431/1695 train_time:41424ms step_avg:96.11ms +step:432/1695 train_time:41518ms step_avg:96.11ms +step:433/1695 train_time:41612ms step_avg:96.10ms +step:434/1695 train_time:41707ms step_avg:96.10ms +step:435/1695 train_time:41800ms step_avg:96.09ms +step:436/1695 train_time:41894ms step_avg:96.09ms +step:437/1695 train_time:41988ms step_avg:96.08ms +step:438/1695 train_time:42081ms step_avg:96.08ms +step:439/1695 
train_time:42175ms step_avg:96.07ms +step:440/1695 train_time:42269ms step_avg:96.07ms +step:441/1695 train_time:42365ms step_avg:96.06ms +step:442/1695 train_time:42458ms step_avg:96.06ms +step:443/1695 train_time:42553ms step_avg:96.06ms +step:444/1695 train_time:42647ms step_avg:96.05ms +step:445/1695 train_time:42741ms step_avg:96.05ms +step:446/1695 train_time:42835ms step_avg:96.04ms +step:447/1695 train_time:42929ms step_avg:96.04ms +step:448/1695 train_time:43023ms step_avg:96.03ms +step:449/1695 train_time:43116ms step_avg:96.03ms +step:450/1695 train_time:43210ms step_avg:96.02ms +step:451/1695 train_time:43305ms step_avg:96.02ms +step:452/1695 train_time:43398ms step_avg:96.01ms +step:453/1695 train_time:43492ms step_avg:96.01ms +step:454/1695 train_time:43587ms step_avg:96.01ms +step:455/1695 train_time:43680ms step_avg:96.00ms +step:456/1695 train_time:43774ms step_avg:96.00ms +step:457/1695 train_time:43869ms step_avg:95.99ms +step:458/1695 train_time:43964ms step_avg:95.99ms +step:459/1695 train_time:44057ms step_avg:95.99ms +step:460/1695 train_time:44151ms step_avg:95.98ms +step:461/1695 train_time:44245ms step_avg:95.98ms +step:462/1695 train_time:44338ms step_avg:95.97ms +step:463/1695 train_time:44432ms step_avg:95.97ms +step:464/1695 train_time:44526ms step_avg:95.96ms +step:465/1695 train_time:44620ms step_avg:95.96ms +step:466/1695 train_time:44714ms step_avg:95.95ms +step:467/1695 train_time:44808ms step_avg:95.95ms +step:468/1695 train_time:44902ms step_avg:95.95ms +step:469/1695 train_time:44995ms step_avg:95.94ms +step:470/1695 train_time:45090ms step_avg:95.94ms +step:471/1695 train_time:45184ms step_avg:95.93ms +step:472/1695 train_time:45277ms step_avg:95.93ms +step:473/1695 train_time:45371ms step_avg:95.92ms +step:474/1695 train_time:45466ms step_avg:95.92ms +step:475/1695 train_time:45559ms step_avg:95.91ms +step:476/1695 train_time:45653ms step_avg:95.91ms +step:477/1695 train_time:45748ms step_avg:95.91ms +step:478/1695 train_time:45843ms step_avg:95.91ms +step:479/1695 train_time:45936ms step_avg:95.90ms +step:480/1695 train_time:46031ms step_avg:95.90ms +step:481/1695 train_time:46125ms step_avg:95.89ms +step:482/1695 train_time:46218ms step_avg:95.89ms +step:483/1695 train_time:46312ms step_avg:95.88ms +step:484/1695 train_time:46406ms step_avg:95.88ms +step:485/1695 train_time:46499ms step_avg:95.87ms +step:486/1695 train_time:46593ms step_avg:95.87ms +step:487/1695 train_time:46687ms step_avg:95.87ms +step:488/1695 train_time:46780ms step_avg:95.86ms +step:489/1695 train_time:46874ms step_avg:95.86ms +step:490/1695 train_time:46969ms step_avg:95.86ms +step:491/1695 train_time:47064ms step_avg:95.85ms +step:492/1695 train_time:47157ms step_avg:95.85ms +step:493/1695 train_time:47251ms step_avg:95.84ms +step:494/1695 train_time:47345ms step_avg:95.84ms +step:495/1695 train_time:47439ms step_avg:95.84ms +step:496/1695 train_time:47533ms step_avg:95.83ms +step:497/1695 train_time:47628ms step_avg:95.83ms +step:498/1695 train_time:47722ms step_avg:95.83ms +step:499/1695 train_time:47814ms step_avg:95.82ms +step:500/1695 train_time:47908ms step_avg:95.82ms +step:500/1695 val_loss:3.7156 train_time:48001ms step_avg:96.00ms +step:501/1695 train_time:48026ms step_avg:95.86ms +step:502/1695 train_time:48105ms step_avg:95.83ms +step:503/1695 train_time:48205ms step_avg:95.83ms +step:504/1695 train_time:48300ms step_avg:95.83ms +step:505/1695 train_time:48393ms step_avg:95.83ms +step:506/1695 train_time:48486ms step_avg:95.82ms +step:507/1695 train_time:48579ms 
step_avg:95.82ms +step:508/1695 train_time:48673ms step_avg:95.81ms +step:509/1695 train_time:48765ms step_avg:95.81ms +step:510/1695 train_time:48859ms step_avg:95.80ms +step:511/1695 train_time:48952ms step_avg:95.80ms +step:512/1695 train_time:49046ms step_avg:95.79ms +step:513/1695 train_time:49142ms step_avg:95.79ms +step:514/1695 train_time:49239ms step_avg:95.80ms +step:515/1695 train_time:49334ms step_avg:95.79ms +step:516/1695 train_time:49427ms step_avg:95.79ms +step:517/1695 train_time:49520ms step_avg:95.78ms +step:518/1695 train_time:49614ms step_avg:95.78ms +step:519/1695 train_time:50068ms step_avg:96.47ms +step:520/1695 train_time:50139ms step_avg:96.42ms +step:521/1695 train_time:50231ms step_avg:96.41ms +step:522/1695 train_time:50324ms step_avg:96.41ms +step:523/1695 train_time:50417ms step_avg:96.40ms +step:524/1695 train_time:50510ms step_avg:96.39ms +step:525/1695 train_time:50602ms step_avg:96.39ms +step:526/1695 train_time:50695ms step_avg:96.38ms +step:527/1695 train_time:50789ms step_avg:96.37ms +step:528/1695 train_time:50882ms step_avg:96.37ms +step:529/1695 train_time:50977ms step_avg:96.37ms +step:530/1695 train_time:51074ms step_avg:96.37ms +step:531/1695 train_time:51171ms step_avg:96.37ms +step:532/1695 train_time:51264ms step_avg:96.36ms +step:533/1695 train_time:51358ms step_avg:96.36ms +step:534/1695 train_time:51451ms step_avg:96.35ms +step:535/1695 train_time:51543ms step_avg:96.34ms +step:536/1695 train_time:51636ms step_avg:96.34ms +step:537/1695 train_time:51730ms step_avg:96.33ms +step:538/1695 train_time:51823ms step_avg:96.32ms +step:539/1695 train_time:51916ms step_avg:96.32ms +step:540/1695 train_time:52011ms step_avg:96.32ms +step:541/1695 train_time:52105ms step_avg:96.31ms +step:542/1695 train_time:52200ms step_avg:96.31ms +step:543/1695 train_time:52295ms step_avg:96.31ms +step:544/1695 train_time:52389ms step_avg:96.30ms +step:545/1695 train_time:52482ms step_avg:96.30ms +step:546/1695 train_time:52577ms step_avg:96.29ms +step:547/1695 train_time:52671ms step_avg:96.29ms +step:548/1695 train_time:52764ms step_avg:96.28ms +step:549/1695 train_time:52858ms step_avg:96.28ms +step:550/1695 train_time:52951ms step_avg:96.27ms +step:551/1695 train_time:53045ms step_avg:96.27ms +step:552/1695 train_time:53139ms step_avg:96.27ms +step:553/1695 train_time:53233ms step_avg:96.26ms +step:554/1695 train_time:53326ms step_avg:96.26ms +step:555/1695 train_time:53420ms step_avg:96.25ms +step:556/1695 train_time:53514ms step_avg:96.25ms +step:557/1695 train_time:53608ms step_avg:96.24ms +step:558/1695 train_time:53701ms step_avg:96.24ms +step:559/1695 train_time:53795ms step_avg:96.24ms +step:560/1695 train_time:53890ms step_avg:96.23ms +step:561/1695 train_time:53983ms step_avg:96.23ms +step:562/1695 train_time:54077ms step_avg:96.22ms +step:563/1695 train_time:54172ms step_avg:96.22ms +step:564/1695 train_time:54265ms step_avg:96.21ms +step:565/1695 train_time:54359ms step_avg:96.21ms +step:566/1695 train_time:54453ms step_avg:96.21ms +step:567/1695 train_time:54546ms step_avg:96.20ms +step:568/1695 train_time:54641ms step_avg:96.20ms +step:569/1695 train_time:54737ms step_avg:96.20ms +step:570/1695 train_time:54833ms step_avg:96.20ms +step:571/1695 train_time:54929ms step_avg:96.20ms +step:572/1695 train_time:55024ms step_avg:96.20ms +step:573/1695 train_time:55120ms step_avg:96.20ms +step:574/1695 train_time:55216ms step_avg:96.20ms +step:575/1695 train_time:55312ms step_avg:96.19ms +step:576/1695 train_time:55407ms step_avg:96.19ms +step:577/1695 
train_time:55502ms step_avg:96.19ms +step:578/1695 train_time:55599ms step_avg:96.19ms +step:579/1695 train_time:55695ms step_avg:96.19ms +step:580/1695 train_time:55792ms step_avg:96.19ms +step:581/1695 train_time:55887ms step_avg:96.19ms +step:582/1695 train_time:55983ms step_avg:96.19ms +step:583/1695 train_time:56080ms step_avg:96.19ms +step:584/1695 train_time:56177ms step_avg:96.19ms +step:585/1695 train_time:56274ms step_avg:96.19ms +step:586/1695 train_time:56371ms step_avg:96.20ms +step:587/1695 train_time:56467ms step_avg:96.20ms +step:588/1695 train_time:56562ms step_avg:96.19ms +step:589/1695 train_time:56659ms step_avg:96.19ms +step:590/1695 train_time:56755ms step_avg:96.20ms +step:591/1695 train_time:56851ms step_avg:96.20ms +step:592/1695 train_time:56946ms step_avg:96.19ms +step:593/1695 train_time:57041ms step_avg:96.19ms +step:594/1695 train_time:57138ms step_avg:96.19ms +step:595/1695 train_time:57234ms step_avg:96.19ms +step:596/1695 train_time:57332ms step_avg:96.19ms +step:597/1695 train_time:57428ms step_avg:96.19ms +step:598/1695 train_time:57523ms step_avg:96.19ms +step:599/1695 train_time:57619ms step_avg:96.19ms +step:600/1695 train_time:57716ms step_avg:96.19ms +step:601/1695 train_time:57812ms step_avg:96.19ms +step:602/1695 train_time:57906ms step_avg:96.19ms +step:603/1695 train_time:58002ms step_avg:96.19ms +step:604/1695 train_time:58098ms step_avg:96.19ms +step:605/1695 train_time:58194ms step_avg:96.19ms +step:606/1695 train_time:58290ms step_avg:96.19ms +step:607/1695 train_time:58386ms step_avg:96.19ms +step:608/1695 train_time:58482ms step_avg:96.19ms +step:609/1695 train_time:58579ms step_avg:96.19ms +step:610/1695 train_time:58676ms step_avg:96.19ms +step:611/1695 train_time:58773ms step_avg:96.19ms +step:612/1695 train_time:58868ms step_avg:96.19ms +step:613/1695 train_time:58963ms step_avg:96.19ms +step:614/1695 train_time:59059ms step_avg:96.19ms +step:615/1695 train_time:59156ms step_avg:96.19ms +step:616/1695 train_time:59252ms step_avg:96.19ms +step:617/1695 train_time:59348ms step_avg:96.19ms +step:618/1695 train_time:59444ms step_avg:96.19ms +step:619/1695 train_time:59540ms step_avg:96.19ms +step:620/1695 train_time:59636ms step_avg:96.19ms +step:621/1695 train_time:59732ms step_avg:96.19ms +step:622/1695 train_time:59828ms step_avg:96.19ms +step:623/1695 train_time:59923ms step_avg:96.18ms +step:624/1695 train_time:60019ms step_avg:96.18ms +step:625/1695 train_time:60116ms step_avg:96.19ms +step:625/1695 val_loss:3.6216 train_time:60211ms step_avg:96.34ms +step:626/1695 train_time:60235ms step_avg:96.22ms +step:627/1695 train_time:60317ms step_avg:96.20ms +step:628/1695 train_time:60413ms step_avg:96.20ms +step:629/1695 train_time:60508ms step_avg:96.20ms +step:630/1695 train_time:60603ms step_avg:96.19ms +step:631/1695 train_time:60697ms step_avg:96.19ms +step:632/1695 train_time:60792ms step_avg:96.19ms +step:633/1695 train_time:60888ms step_avg:96.19ms +step:634/1695 train_time:60982ms step_avg:96.19ms +step:635/1695 train_time:61078ms step_avg:96.19ms +step:636/1695 train_time:61178ms step_avg:96.19ms +step:637/1695 train_time:61278ms step_avg:96.20ms +step:638/1695 train_time:61377ms step_avg:96.20ms +step:639/1695 train_time:61475ms step_avg:96.20ms +step:640/1695 train_time:61572ms step_avg:96.21ms +step:641/1695 train_time:61668ms step_avg:96.21ms +step:642/1695 train_time:61763ms step_avg:96.20ms +step:643/1695 train_time:61858ms step_avg:96.20ms +step:644/1695 train_time:61954ms step_avg:96.20ms +step:645/1695 train_time:62049ms 
step_avg:96.20ms +step:646/1695 train_time:62144ms step_avg:96.20ms +step:647/1695 train_time:62241ms step_avg:96.20ms +step:648/1695 train_time:62339ms step_avg:96.20ms +step:649/1695 train_time:62436ms step_avg:96.20ms +step:650/1695 train_time:62533ms step_avg:96.21ms +step:651/1695 train_time:62630ms step_avg:96.21ms +step:652/1695 train_time:62725ms step_avg:96.20ms +step:653/1695 train_time:62820ms step_avg:96.20ms +step:654/1695 train_time:62916ms step_avg:96.20ms +step:655/1695 train_time:63011ms step_avg:96.20ms +step:656/1695 train_time:63107ms step_avg:96.20ms +step:657/1695 train_time:63202ms step_avg:96.20ms +step:658/1695 train_time:63298ms step_avg:96.20ms +step:659/1695 train_time:63396ms step_avg:96.20ms +step:660/1695 train_time:63494ms step_avg:96.20ms +step:661/1695 train_time:63591ms step_avg:96.20ms +step:662/1695 train_time:63688ms step_avg:96.21ms +step:663/1695 train_time:63784ms step_avg:96.20ms +step:664/1695 train_time:63879ms step_avg:96.20ms +step:665/1695 train_time:63974ms step_avg:96.20ms +step:666/1695 train_time:64070ms step_avg:96.20ms +step:667/1695 train_time:64165ms step_avg:96.20ms +step:668/1695 train_time:64260ms step_avg:96.20ms +step:669/1695 train_time:64356ms step_avg:96.20ms +step:670/1695 train_time:64453ms step_avg:96.20ms +step:671/1695 train_time:64550ms step_avg:96.20ms +step:672/1695 train_time:64647ms step_avg:96.20ms +step:673/1695 train_time:64742ms step_avg:96.20ms +step:674/1695 train_time:64838ms step_avg:96.20ms +step:675/1695 train_time:64934ms step_avg:96.20ms +step:676/1695 train_time:65030ms step_avg:96.20ms +step:677/1695 train_time:65126ms step_avg:96.20ms +step:678/1695 train_time:65222ms step_avg:96.20ms +step:679/1695 train_time:65318ms step_avg:96.20ms +step:680/1695 train_time:65414ms step_avg:96.20ms +step:681/1695 train_time:65510ms step_avg:96.20ms +step:682/1695 train_time:65606ms step_avg:96.20ms +step:683/1695 train_time:65701ms step_avg:96.20ms +step:684/1695 train_time:65798ms step_avg:96.20ms +step:685/1695 train_time:65894ms step_avg:96.20ms +step:686/1695 train_time:65990ms step_avg:96.20ms +step:687/1695 train_time:66086ms step_avg:96.20ms +step:688/1695 train_time:66181ms step_avg:96.19ms +step:689/1695 train_time:66277ms step_avg:96.19ms +step:690/1695 train_time:66373ms step_avg:96.19ms +step:691/1695 train_time:66817ms step_avg:96.70ms +step:692/1695 train_time:66898ms step_avg:96.67ms +step:693/1695 train_time:66992ms step_avg:96.67ms +step:694/1695 train_time:67087ms step_avg:96.67ms +step:695/1695 train_time:67181ms step_avg:96.66ms +step:696/1695 train_time:67277ms step_avg:96.66ms +step:697/1695 train_time:67371ms step_avg:96.66ms +step:698/1695 train_time:67466ms step_avg:96.66ms +step:699/1695 train_time:67560ms step_avg:96.65ms +step:700/1695 train_time:67656ms step_avg:96.65ms +step:701/1695 train_time:67756ms step_avg:96.66ms +step:702/1695 train_time:67856ms step_avg:96.66ms +step:703/1695 train_time:67953ms step_avg:96.66ms +step:704/1695 train_time:68049ms step_avg:96.66ms +step:705/1695 train_time:68144ms step_avg:96.66ms +step:706/1695 train_time:68238ms step_avg:96.65ms +step:707/1695 train_time:68334ms step_avg:96.65ms +step:708/1695 train_time:68430ms step_avg:96.65ms +step:709/1695 train_time:68524ms step_avg:96.65ms +step:710/1695 train_time:68618ms step_avg:96.65ms +step:711/1695 train_time:68715ms step_avg:96.65ms +step:712/1695 train_time:68813ms step_avg:96.65ms +step:713/1695 train_time:68909ms step_avg:96.65ms +step:714/1695 train_time:69005ms step_avg:96.65ms +step:715/1695 
train_time:69101ms step_avg:96.64ms +step:716/1695 train_time:69196ms step_avg:96.64ms +step:717/1695 train_time:69293ms step_avg:96.64ms +step:718/1695 train_time:69390ms step_avg:96.64ms +step:719/1695 train_time:69485ms step_avg:96.64ms +step:720/1695 train_time:69580ms step_avg:96.64ms +step:721/1695 train_time:69676ms step_avg:96.64ms +step:722/1695 train_time:69774ms step_avg:96.64ms +step:723/1695 train_time:69871ms step_avg:96.64ms +step:724/1695 train_time:69969ms step_avg:96.64ms +step:725/1695 train_time:70065ms step_avg:96.64ms +step:726/1695 train_time:70160ms step_avg:96.64ms +step:727/1695 train_time:70256ms step_avg:96.64ms +step:728/1695 train_time:70353ms step_avg:96.64ms +step:729/1695 train_time:70450ms step_avg:96.64ms +step:730/1695 train_time:70547ms step_avg:96.64ms +step:731/1695 train_time:70642ms step_avg:96.64ms +step:732/1695 train_time:70737ms step_avg:96.64ms +step:733/1695 train_time:70834ms step_avg:96.64ms +step:734/1695 train_time:70933ms step_avg:96.64ms +step:735/1695 train_time:71029ms step_avg:96.64ms +step:736/1695 train_time:71125ms step_avg:96.64ms +step:737/1695 train_time:71220ms step_avg:96.63ms +step:738/1695 train_time:71316ms step_avg:96.63ms +step:739/1695 train_time:71413ms step_avg:96.63ms +step:740/1695 train_time:71508ms step_avg:96.63ms +step:741/1695 train_time:71603ms step_avg:96.63ms +step:742/1695 train_time:71698ms step_avg:96.63ms +step:743/1695 train_time:71794ms step_avg:96.63ms +step:744/1695 train_time:71891ms step_avg:96.63ms +step:745/1695 train_time:71988ms step_avg:96.63ms +step:746/1695 train_time:72083ms step_avg:96.63ms +step:747/1695 train_time:72179ms step_avg:96.63ms +step:748/1695 train_time:72275ms step_avg:96.62ms +step:749/1695 train_time:72371ms step_avg:96.62ms +step:750/1695 train_time:72467ms step_avg:96.62ms +step:750/1695 val_loss:3.5657 train_time:72560ms step_avg:96.75ms +step:751/1695 train_time:72585ms step_avg:96.65ms +step:752/1695 train_time:72667ms step_avg:96.63ms +step:753/1695 train_time:72765ms step_avg:96.63ms +step:754/1695 train_time:72860ms step_avg:96.63ms +step:755/1695 train_time:72956ms step_avg:96.63ms +step:756/1695 train_time:73050ms step_avg:96.63ms +step:757/1695 train_time:73145ms step_avg:96.62ms +step:758/1695 train_time:73239ms step_avg:96.62ms +step:759/1695 train_time:73334ms step_avg:96.62ms +step:760/1695 train_time:73429ms step_avg:96.62ms +step:761/1695 train_time:73526ms step_avg:96.62ms +step:762/1695 train_time:73625ms step_avg:96.62ms +step:763/1695 train_time:73723ms step_avg:96.62ms +step:764/1695 train_time:73819ms step_avg:96.62ms +step:765/1695 train_time:73916ms step_avg:96.62ms +step:766/1695 train_time:74011ms step_avg:96.62ms +step:767/1695 train_time:74106ms step_avg:96.62ms +step:768/1695 train_time:74201ms step_avg:96.62ms +step:769/1695 train_time:74296ms step_avg:96.61ms +step:770/1695 train_time:74391ms step_avg:96.61ms +step:771/1695 train_time:74487ms step_avg:96.61ms +step:772/1695 train_time:74584ms step_avg:96.61ms +step:773/1695 train_time:74682ms step_avg:96.61ms +step:774/1695 train_time:74779ms step_avg:96.61ms +step:775/1695 train_time:74875ms step_avg:96.61ms +step:776/1695 train_time:74971ms step_avg:96.61ms +step:777/1695 train_time:75066ms step_avg:96.61ms +step:778/1695 train_time:75161ms step_avg:96.61ms +step:779/1695 train_time:75257ms step_avg:96.61ms +step:780/1695 train_time:75352ms step_avg:96.61ms +step:781/1695 train_time:75446ms step_avg:96.60ms +step:782/1695 train_time:75542ms step_avg:96.60ms +step:783/1695 train_time:75639ms 
step_avg:96.60ms +step:784/1695 train_time:75736ms step_avg:96.60ms +step:785/1695 train_time:75833ms step_avg:96.60ms +step:786/1695 train_time:75930ms step_avg:96.60ms +step:787/1695 train_time:76024ms step_avg:96.60ms +step:788/1695 train_time:76119ms step_avg:96.60ms +step:789/1695 train_time:76215ms step_avg:96.60ms +step:790/1695 train_time:76309ms step_avg:96.59ms +step:791/1695 train_time:76404ms step_avg:96.59ms +step:792/1695 train_time:76500ms step_avg:96.59ms +step:793/1695 train_time:76597ms step_avg:96.59ms +step:794/1695 train_time:76694ms step_avg:96.59ms +step:795/1695 train_time:76790ms step_avg:96.59ms +step:796/1695 train_time:76886ms step_avg:96.59ms +step:797/1695 train_time:76982ms step_avg:96.59ms +step:798/1695 train_time:77077ms step_avg:96.59ms +step:799/1695 train_time:77173ms step_avg:96.59ms +step:800/1695 train_time:77268ms step_avg:96.58ms +step:801/1695 train_time:77362ms step_avg:96.58ms +step:802/1695 train_time:77458ms step_avg:96.58ms +step:803/1695 train_time:77555ms step_avg:96.58ms +step:804/1695 train_time:77651ms step_avg:96.58ms +step:805/1695 train_time:77747ms step_avg:96.58ms +step:806/1695 train_time:77843ms step_avg:96.58ms +step:807/1695 train_time:77939ms step_avg:96.58ms +step:808/1695 train_time:78036ms step_avg:96.58ms +step:809/1695 train_time:78132ms step_avg:96.58ms +step:810/1695 train_time:78228ms step_avg:96.58ms +step:811/1695 train_time:78322ms step_avg:96.58ms +step:812/1695 train_time:78418ms step_avg:96.57ms +step:813/1695 train_time:78514ms step_avg:96.57ms +step:814/1695 train_time:78611ms step_avg:96.57ms +step:815/1695 train_time:78706ms step_avg:96.57ms +step:816/1695 train_time:78802ms step_avg:96.57ms +step:817/1695 train_time:78899ms step_avg:96.57ms +step:818/1695 train_time:78995ms step_avg:96.57ms +step:819/1695 train_time:79091ms step_avg:96.57ms +step:820/1695 train_time:79186ms step_avg:96.57ms +step:821/1695 train_time:79281ms step_avg:96.57ms +step:822/1695 train_time:79378ms step_avg:96.57ms +step:823/1695 train_time:79474ms step_avg:96.57ms +step:824/1695 train_time:79570ms step_avg:96.57ms +step:825/1695 train_time:79665ms step_avg:96.56ms +step:826/1695 train_time:79761ms step_avg:96.56ms +step:827/1695 train_time:79857ms step_avg:96.56ms +step:828/1695 train_time:79953ms step_avg:96.56ms +step:829/1695 train_time:80049ms step_avg:96.56ms +step:830/1695 train_time:80145ms step_avg:96.56ms +step:831/1695 train_time:80241ms step_avg:96.56ms +step:832/1695 train_time:80337ms step_avg:96.56ms +step:833/1695 train_time:80434ms step_avg:96.56ms +step:834/1695 train_time:80529ms step_avg:96.56ms +step:835/1695 train_time:80624ms step_avg:96.56ms +step:836/1695 train_time:80720ms step_avg:96.55ms +step:837/1695 train_time:80816ms step_avg:96.55ms +step:838/1695 train_time:80912ms step_avg:96.55ms +step:839/1695 train_time:81007ms step_avg:96.55ms +step:840/1695 train_time:81103ms step_avg:96.55ms +step:841/1695 train_time:81199ms step_avg:96.55ms +step:842/1695 train_time:81294ms step_avg:96.55ms +step:843/1695 train_time:81390ms step_avg:96.55ms +step:844/1695 train_time:81485ms step_avg:96.55ms +step:845/1695 train_time:81581ms step_avg:96.55ms +step:846/1695 train_time:81676ms step_avg:96.54ms +step:847/1695 train_time:81772ms step_avg:96.54ms +step:848/1695 train_time:81867ms step_avg:96.54ms +step:849/1695 train_time:81963ms step_avg:96.54ms +step:850/1695 train_time:82059ms step_avg:96.54ms +step:851/1695 train_time:82155ms step_avg:96.54ms +step:852/1695 train_time:82251ms step_avg:96.54ms +step:853/1695 
train_time:82347ms step_avg:96.54ms +step:854/1695 train_time:82443ms step_avg:96.54ms +step:855/1695 train_time:82539ms step_avg:96.54ms +step:856/1695 train_time:82635ms step_avg:96.54ms +step:857/1695 train_time:82730ms step_avg:96.53ms +step:858/1695 train_time:82826ms step_avg:96.53ms +step:859/1695 train_time:82921ms step_avg:96.53ms +step:860/1695 train_time:83017ms step_avg:96.53ms +step:861/1695 train_time:83113ms step_avg:96.53ms +step:862/1695 train_time:83208ms step_avg:96.53ms +step:863/1695 train_time:83631ms step_avg:96.91ms +step:864/1695 train_time:83735ms step_avg:96.92ms +step:865/1695 train_time:83829ms step_avg:96.91ms +step:866/1695 train_time:83923ms step_avg:96.91ms +step:867/1695 train_time:84018ms step_avg:96.91ms +step:868/1695 train_time:84113ms step_avg:96.90ms +step:869/1695 train_time:84207ms step_avg:96.90ms +step:870/1695 train_time:84301ms step_avg:96.90ms +step:871/1695 train_time:84397ms step_avg:96.90ms +step:872/1695 train_time:84492ms step_avg:96.89ms +step:873/1695 train_time:84593ms step_avg:96.90ms +step:874/1695 train_time:84691ms step_avg:96.90ms +step:875/1695 train_time:84788ms step_avg:96.90ms +step:875/1695 val_loss:3.5240 train_time:84881ms step_avg:97.01ms +step:876/1695 train_time:84907ms step_avg:96.93ms +step:877/1695 train_time:84991ms step_avg:96.91ms +step:878/1695 train_time:85089ms step_avg:96.91ms +step:879/1695 train_time:85186ms step_avg:96.91ms +step:880/1695 train_time:85282ms step_avg:96.91ms +step:881/1695 train_time:85377ms step_avg:96.91ms +step:882/1695 train_time:85472ms step_avg:96.91ms +step:883/1695 train_time:85567ms step_avg:96.90ms +step:884/1695 train_time:85663ms step_avg:96.90ms +step:885/1695 train_time:85757ms step_avg:96.90ms +step:886/1695 train_time:85854ms step_avg:96.90ms +step:887/1695 train_time:85952ms step_avg:96.90ms +step:888/1695 train_time:86049ms step_avg:96.90ms +step:889/1695 train_time:86145ms step_avg:96.90ms +step:890/1695 train_time:86242ms step_avg:96.90ms +step:891/1695 train_time:86339ms step_avg:96.90ms +step:892/1695 train_time:86435ms step_avg:96.90ms +step:893/1695 train_time:86530ms step_avg:96.90ms +step:894/1695 train_time:86625ms step_avg:96.90ms +step:895/1695 train_time:86720ms step_avg:96.89ms +step:896/1695 train_time:86816ms step_avg:96.89ms +step:897/1695 train_time:86913ms step_avg:96.89ms +step:898/1695 train_time:87009ms step_avg:96.89ms +step:899/1695 train_time:87105ms step_avg:96.89ms +step:900/1695 train_time:87201ms step_avg:96.89ms +step:901/1695 train_time:87298ms step_avg:96.89ms +step:902/1695 train_time:87394ms step_avg:96.89ms +step:903/1695 train_time:87489ms step_avg:96.89ms +step:904/1695 train_time:87584ms step_avg:96.88ms +step:905/1695 train_time:87679ms step_avg:96.88ms +step:906/1695 train_time:87776ms step_avg:96.88ms +step:907/1695 train_time:87872ms step_avg:96.88ms +step:908/1695 train_time:87968ms step_avg:96.88ms +step:909/1695 train_time:88065ms step_avg:96.88ms +step:910/1695 train_time:88162ms step_avg:96.88ms +step:911/1695 train_time:88259ms step_avg:96.88ms +step:912/1695 train_time:88354ms step_avg:96.88ms +step:913/1695 train_time:88449ms step_avg:96.88ms +step:914/1695 train_time:88544ms step_avg:96.88ms +step:915/1695 train_time:88640ms step_avg:96.87ms +step:916/1695 train_time:88736ms step_avg:96.87ms +step:917/1695 train_time:88831ms step_avg:96.87ms +step:918/1695 train_time:88927ms step_avg:96.87ms +step:919/1695 train_time:89023ms step_avg:96.87ms +step:920/1695 train_time:89119ms step_avg:96.87ms +step:921/1695 train_time:89216ms 
step_avg:96.87ms +step:922/1695 train_time:89312ms step_avg:96.87ms +step:923/1695 train_time:89407ms step_avg:96.87ms +step:924/1695 train_time:89504ms step_avg:96.87ms +step:925/1695 train_time:89600ms step_avg:96.86ms +step:926/1695 train_time:89696ms step_avg:96.86ms +step:927/1695 train_time:89792ms step_avg:96.86ms +step:928/1695 train_time:89887ms step_avg:96.86ms +step:929/1695 train_time:89983ms step_avg:96.86ms +step:930/1695 train_time:90080ms step_avg:96.86ms +step:931/1695 train_time:90177ms step_avg:96.86ms +step:932/1695 train_time:90273ms step_avg:96.86ms +step:933/1695 train_time:90369ms step_avg:96.86ms +step:934/1695 train_time:90465ms step_avg:96.86ms +step:935/1695 train_time:90561ms step_avg:96.86ms +step:936/1695 train_time:90657ms step_avg:96.86ms +step:937/1695 train_time:90753ms step_avg:96.86ms +step:938/1695 train_time:90848ms step_avg:96.85ms +step:939/1695 train_time:90944ms step_avg:96.85ms +step:940/1695 train_time:91040ms step_avg:96.85ms +step:941/1695 train_time:91136ms step_avg:96.85ms +step:942/1695 train_time:91232ms step_avg:96.85ms +step:943/1695 train_time:91329ms step_avg:96.85ms +step:944/1695 train_time:91424ms step_avg:96.85ms +step:945/1695 train_time:91520ms step_avg:96.85ms +step:946/1695 train_time:91616ms step_avg:96.85ms +step:947/1695 train_time:91711ms step_avg:96.84ms +step:948/1695 train_time:91807ms step_avg:96.84ms +step:949/1695 train_time:91904ms step_avg:96.84ms +step:950/1695 train_time:92001ms step_avg:96.84ms +step:951/1695 train_time:92098ms step_avg:96.84ms +step:952/1695 train_time:92194ms step_avg:96.84ms +step:953/1695 train_time:92290ms step_avg:96.84ms +step:954/1695 train_time:92385ms step_avg:96.84ms +step:955/1695 train_time:92481ms step_avg:96.84ms +step:956/1695 train_time:92578ms step_avg:96.84ms +step:957/1695 train_time:92675ms step_avg:96.84ms +step:958/1695 train_time:92770ms step_avg:96.84ms +step:959/1695 train_time:92866ms step_avg:96.84ms +step:960/1695 train_time:92963ms step_avg:96.84ms +step:961/1695 train_time:93060ms step_avg:96.84ms +step:962/1695 train_time:93158ms step_avg:96.84ms +step:963/1695 train_time:93254ms step_avg:96.84ms +step:964/1695 train_time:93349ms step_avg:96.83ms +step:965/1695 train_time:93444ms step_avg:96.83ms +step:966/1695 train_time:93540ms step_avg:96.83ms +step:967/1695 train_time:93636ms step_avg:96.83ms +step:968/1695 train_time:93733ms step_avg:96.83ms +step:969/1695 train_time:93828ms step_avg:96.83ms +step:970/1695 train_time:93924ms step_avg:96.83ms +step:971/1695 train_time:94021ms step_avg:96.83ms +step:972/1695 train_time:94119ms step_avg:96.83ms +step:973/1695 train_time:94215ms step_avg:96.83ms +step:974/1695 train_time:94310ms step_avg:96.83ms +step:975/1695 train_time:94406ms step_avg:96.83ms +step:976/1695 train_time:94501ms step_avg:96.82ms +step:977/1695 train_time:94597ms step_avg:96.82ms +step:978/1695 train_time:94693ms step_avg:96.82ms +step:979/1695 train_time:94788ms step_avg:96.82ms +step:980/1695 train_time:94884ms step_avg:96.82ms +step:981/1695 train_time:94980ms step_avg:96.82ms +step:982/1695 train_time:95077ms step_avg:96.82ms +step:983/1695 train_time:95173ms step_avg:96.82ms +step:984/1695 train_time:95268ms step_avg:96.82ms +step:985/1695 train_time:95364ms step_avg:96.82ms +step:986/1695 train_time:95460ms step_avg:96.82ms +step:987/1695 train_time:95556ms step_avg:96.81ms +step:988/1695 train_time:95652ms step_avg:96.81ms +step:989/1695 train_time:95747ms step_avg:96.81ms +step:990/1695 train_time:95844ms step_avg:96.81ms +step:991/1695 
train_time:95940ms step_avg:96.81ms +step:992/1695 train_time:96037ms step_avg:96.81ms +step:993/1695 train_time:96133ms step_avg:96.81ms +step:994/1695 train_time:96229ms step_avg:96.81ms +step:995/1695 train_time:96325ms step_avg:96.81ms +step:996/1695 train_time:96422ms step_avg:96.81ms +step:997/1695 train_time:96518ms step_avg:96.81ms +step:998/1695 train_time:96613ms step_avg:96.81ms +step:999/1695 train_time:96708ms step_avg:96.80ms +step:1000/1695 train_time:96804ms step_avg:96.80ms +step:1000/1695 val_loss:3.4845 train_time:96898ms step_avg:96.90ms +step:1001/1695 train_time:96924ms step_avg:96.83ms +step:1002/1695 train_time:97001ms step_avg:96.81ms +step:1003/1695 train_time:97099ms step_avg:96.81ms +step:1004/1695 train_time:97195ms step_avg:96.81ms +step:1005/1695 train_time:97291ms step_avg:96.81ms +step:1006/1695 train_time:97386ms step_avg:96.81ms +step:1007/1695 train_time:97481ms step_avg:96.80ms +step:1008/1695 train_time:97576ms step_avg:96.80ms +step:1009/1695 train_time:97672ms step_avg:96.80ms +step:1010/1695 train_time:97766ms step_avg:96.80ms +step:1011/1695 train_time:97863ms step_avg:96.80ms +step:1012/1695 train_time:97961ms step_avg:96.80ms +step:1013/1695 train_time:98057ms step_avg:96.80ms +step:1014/1695 train_time:98153ms step_avg:96.80ms +step:1015/1695 train_time:98249ms step_avg:96.80ms +step:1016/1695 train_time:98346ms step_avg:96.80ms +step:1017/1695 train_time:98440ms step_avg:96.79ms +step:1018/1695 train_time:98535ms step_avg:96.79ms +step:1019/1695 train_time:98632ms step_avg:96.79ms +step:1020/1695 train_time:98728ms step_avg:96.79ms +step:1021/1695 train_time:98823ms step_avg:96.79ms +step:1022/1695 train_time:98919ms step_avg:96.79ms +step:1023/1695 train_time:99016ms step_avg:96.79ms +step:1024/1695 train_time:99114ms step_avg:96.79ms +step:1025/1695 train_time:99211ms step_avg:96.79ms +step:1026/1695 train_time:99307ms step_avg:96.79ms +step:1027/1695 train_time:99402ms step_avg:96.79ms +step:1028/1695 train_time:99497ms step_avg:96.79ms +step:1029/1695 train_time:99593ms step_avg:96.79ms +step:1030/1695 train_time:99690ms step_avg:96.79ms +step:1031/1695 train_time:99786ms step_avg:96.79ms +step:1032/1695 train_time:99882ms step_avg:96.78ms +step:1033/1695 train_time:99978ms step_avg:96.78ms +step:1034/1695 train_time:100075ms step_avg:96.78ms +step:1035/1695 train_time:100172ms step_avg:96.78ms +step:1036/1695 train_time:100506ms step_avg:97.01ms +step:1037/1695 train_time:100695ms step_avg:97.10ms +step:1038/1695 train_time:100788ms step_avg:97.10ms +step:1039/1695 train_time:100882ms step_avg:97.10ms +step:1040/1695 train_time:100977ms step_avg:97.09ms +step:1041/1695 train_time:101073ms step_avg:97.09ms +step:1042/1695 train_time:101168ms step_avg:97.09ms +step:1043/1695 train_time:101262ms step_avg:97.09ms +step:1044/1695 train_time:101357ms step_avg:97.09ms +step:1045/1695 train_time:101452ms step_avg:97.08ms +step:1046/1695 train_time:101548ms step_avg:97.08ms +step:1047/1695 train_time:101649ms step_avg:97.09ms +step:1048/1695 train_time:101747ms step_avg:97.09ms +step:1049/1695 train_time:101842ms step_avg:97.09ms +step:1050/1695 train_time:101937ms step_avg:97.08ms +step:1051/1695 train_time:102034ms step_avg:97.08ms +step:1052/1695 train_time:102130ms step_avg:97.08ms +step:1053/1695 train_time:102225ms step_avg:97.08ms +step:1054/1695 train_time:102320ms step_avg:97.08ms +step:1055/1695 train_time:102415ms step_avg:97.08ms +step:1056/1695 train_time:102511ms step_avg:97.07ms +step:1057/1695 train_time:102610ms step_avg:97.08ms 
+step:1058/1695 train_time:102707ms step_avg:97.08ms +step:1059/1695 train_time:102804ms step_avg:97.08ms +step:1060/1695 train_time:102900ms step_avg:97.08ms +step:1061/1695 train_time:102996ms step_avg:97.07ms +step:1062/1695 train_time:103092ms step_avg:97.07ms +step:1063/1695 train_time:103187ms step_avg:97.07ms +step:1064/1695 train_time:103282ms step_avg:97.07ms +step:1065/1695 train_time:103377ms step_avg:97.07ms +step:1066/1695 train_time:103473ms step_avg:97.07ms +step:1067/1695 train_time:103569ms step_avg:97.07ms +step:1068/1695 train_time:103666ms step_avg:97.07ms +step:1069/1695 train_time:103762ms step_avg:97.06ms +step:1070/1695 train_time:103858ms step_avg:97.06ms +step:1071/1695 train_time:103954ms step_avg:97.06ms +step:1072/1695 train_time:104050ms step_avg:97.06ms +step:1073/1695 train_time:104145ms step_avg:97.06ms +step:1074/1695 train_time:104241ms step_avg:97.06ms +step:1075/1695 train_time:104336ms step_avg:97.06ms +step:1076/1695 train_time:104432ms step_avg:97.06ms +step:1077/1695 train_time:104527ms step_avg:97.05ms +step:1078/1695 train_time:104624ms step_avg:97.05ms +step:1079/1695 train_time:104719ms step_avg:97.05ms +step:1080/1695 train_time:104817ms step_avg:97.05ms +step:1081/1695 train_time:104915ms step_avg:97.05ms +step:1082/1695 train_time:105011ms step_avg:97.05ms +step:1083/1695 train_time:105107ms step_avg:97.05ms +step:1084/1695 train_time:105203ms step_avg:97.05ms +step:1085/1695 train_time:105298ms step_avg:97.05ms +step:1086/1695 train_time:105394ms step_avg:97.05ms +step:1087/1695 train_time:105489ms step_avg:97.05ms +step:1088/1695 train_time:105585ms step_avg:97.05ms +step:1089/1695 train_time:105680ms step_avg:97.04ms +step:1090/1695 train_time:105776ms step_avg:97.04ms +step:1091/1695 train_time:105873ms step_avg:97.04ms +step:1092/1695 train_time:105970ms step_avg:97.04ms +step:1093/1695 train_time:106065ms step_avg:97.04ms +step:1094/1695 train_time:106160ms step_avg:97.04ms +step:1095/1695 train_time:106255ms step_avg:97.04ms +step:1096/1695 train_time:106351ms step_avg:97.04ms +step:1097/1695 train_time:106448ms step_avg:97.04ms +step:1098/1695 train_time:106544ms step_avg:97.03ms +step:1099/1695 train_time:106640ms step_avg:97.03ms +step:1100/1695 train_time:106735ms step_avg:97.03ms +step:1101/1695 train_time:106832ms step_avg:97.03ms +step:1102/1695 train_time:106929ms step_avg:97.03ms +step:1103/1695 train_time:107026ms step_avg:97.03ms +step:1104/1695 train_time:107121ms step_avg:97.03ms +step:1105/1695 train_time:107216ms step_avg:97.03ms +step:1106/1695 train_time:107312ms step_avg:97.03ms +step:1107/1695 train_time:107408ms step_avg:97.03ms +step:1108/1695 train_time:107503ms step_avg:97.02ms +step:1109/1695 train_time:107599ms step_avg:97.02ms +step:1110/1695 train_time:107695ms step_avg:97.02ms +step:1111/1695 train_time:107792ms step_avg:97.02ms +step:1112/1695 train_time:107889ms step_avg:97.02ms +step:1113/1695 train_time:107985ms step_avg:97.02ms +step:1114/1695 train_time:108080ms step_avg:97.02ms +step:1115/1695 train_time:108176ms step_avg:97.02ms +step:1116/1695 train_time:108272ms step_avg:97.02ms +step:1117/1695 train_time:108369ms step_avg:97.02ms +step:1118/1695 train_time:108464ms step_avg:97.02ms +step:1119/1695 train_time:108560ms step_avg:97.02ms +step:1120/1695 train_time:108655ms step_avg:97.01ms +step:1121/1695 train_time:108752ms step_avg:97.01ms +step:1122/1695 train_time:108848ms step_avg:97.01ms +step:1123/1695 train_time:108944ms step_avg:97.01ms +step:1124/1695 train_time:109040ms step_avg:97.01ms 
+step:1125/1695 train_time:109136ms step_avg:97.01ms +step:1125/1695 val_loss:3.4374 train_time:109230ms step_avg:97.09ms +step:1126/1695 train_time:109255ms step_avg:97.03ms +step:1127/1695 train_time:109339ms step_avg:97.02ms +step:1128/1695 train_time:109436ms step_avg:97.02ms +step:1129/1695 train_time:109533ms step_avg:97.02ms +step:1130/1695 train_time:109628ms step_avg:97.02ms +step:1131/1695 train_time:109724ms step_avg:97.01ms +step:1132/1695 train_time:109818ms step_avg:97.01ms +step:1133/1695 train_time:109915ms step_avg:97.01ms +step:1134/1695 train_time:110012ms step_avg:97.01ms +step:1135/1695 train_time:110108ms step_avg:97.01ms +step:1136/1695 train_time:110207ms step_avg:97.01ms +step:1137/1695 train_time:110306ms step_avg:97.02ms +step:1138/1695 train_time:110406ms step_avg:97.02ms +step:1139/1695 train_time:110505ms step_avg:97.02ms +step:1140/1695 train_time:110602ms step_avg:97.02ms +step:1141/1695 train_time:110699ms step_avg:97.02ms +step:1142/1695 train_time:110796ms step_avg:97.02ms +step:1143/1695 train_time:110894ms step_avg:97.02ms +step:1144/1695 train_time:110992ms step_avg:97.02ms +step:1145/1695 train_time:111089ms step_avg:97.02ms +step:1146/1695 train_time:111188ms step_avg:97.02ms +step:1147/1695 train_time:111286ms step_avg:97.02ms +step:1148/1695 train_time:111384ms step_avg:97.02ms +step:1149/1695 train_time:111483ms step_avg:97.03ms +step:1150/1695 train_time:111581ms step_avg:97.03ms +step:1151/1695 train_time:111678ms step_avg:97.03ms +step:1152/1695 train_time:111776ms step_avg:97.03ms +step:1153/1695 train_time:111873ms step_avg:97.03ms +step:1154/1695 train_time:111970ms step_avg:97.03ms +step:1155/1695 train_time:112067ms step_avg:97.03ms +step:1156/1695 train_time:112165ms step_avg:97.03ms +step:1157/1695 train_time:112264ms step_avg:97.03ms +step:1158/1695 train_time:112361ms step_avg:97.03ms +step:1159/1695 train_time:112460ms step_avg:97.03ms +step:1160/1695 train_time:112558ms step_avg:97.03ms +step:1161/1695 train_time:112656ms step_avg:97.03ms +step:1162/1695 train_time:112754ms step_avg:97.03ms +step:1163/1695 train_time:112852ms step_avg:97.04ms +step:1164/1695 train_time:112949ms step_avg:97.04ms +step:1165/1695 train_time:113046ms step_avg:97.04ms +step:1166/1695 train_time:113143ms step_avg:97.04ms +step:1167/1695 train_time:113241ms step_avg:97.04ms +step:1168/1695 train_time:113339ms step_avg:97.04ms +step:1169/1695 train_time:113437ms step_avg:97.04ms +step:1170/1695 train_time:113535ms step_avg:97.04ms +step:1171/1695 train_time:113634ms step_avg:97.04ms +step:1172/1695 train_time:113733ms step_avg:97.04ms +step:1173/1695 train_time:113830ms step_avg:97.04ms +step:1174/1695 train_time:113927ms step_avg:97.04ms +step:1175/1695 train_time:114025ms step_avg:97.04ms +step:1176/1695 train_time:114121ms step_avg:97.04ms +step:1177/1695 train_time:114219ms step_avg:97.04ms +step:1178/1695 train_time:114317ms step_avg:97.04ms +step:1179/1695 train_time:114416ms step_avg:97.05ms +step:1180/1695 train_time:114515ms step_avg:97.05ms +step:1181/1695 train_time:114613ms step_avg:97.05ms +step:1182/1695 train_time:114712ms step_avg:97.05ms +step:1183/1695 train_time:114809ms step_avg:97.05ms +step:1184/1695 train_time:114908ms step_avg:97.05ms +step:1185/1695 train_time:115005ms step_avg:97.05ms +step:1186/1695 train_time:115102ms step_avg:97.05ms +step:1187/1695 train_time:115199ms step_avg:97.05ms +step:1188/1695 train_time:115296ms step_avg:97.05ms +step:1189/1695 train_time:115395ms step_avg:97.05ms +step:1190/1695 train_time:115494ms 
step_avg:97.05ms +step:1191/1695 train_time:115593ms step_avg:97.06ms +step:1192/1695 train_time:115692ms step_avg:97.06ms +step:1193/1695 train_time:115790ms step_avg:97.06ms +step:1194/1695 train_time:115888ms step_avg:97.06ms +step:1195/1695 train_time:115986ms step_avg:97.06ms +step:1196/1695 train_time:116084ms step_avg:97.06ms +step:1197/1695 train_time:116183ms step_avg:97.06ms +step:1198/1695 train_time:116280ms step_avg:97.06ms +step:1199/1695 train_time:116378ms step_avg:97.06ms +step:1200/1695 train_time:116476ms step_avg:97.06ms +step:1201/1695 train_time:116574ms step_avg:97.06ms +step:1202/1695 train_time:116672ms step_avg:97.07ms +step:1203/1695 train_time:116770ms step_avg:97.07ms +step:1204/1695 train_time:116868ms step_avg:97.07ms +step:1205/1695 train_time:116966ms step_avg:97.07ms +step:1206/1695 train_time:117064ms step_avg:97.07ms +step:1207/1695 train_time:117162ms step_avg:97.07ms +step:1208/1695 train_time:117508ms step_avg:97.27ms +step:1209/1695 train_time:117691ms step_avg:97.35ms +step:1210/1695 train_time:117787ms step_avg:97.34ms +step:1211/1695 train_time:117883ms step_avg:97.34ms +step:1212/1695 train_time:117979ms step_avg:97.34ms +step:1213/1695 train_time:118076ms step_avg:97.34ms +step:1214/1695 train_time:118174ms step_avg:97.34ms +step:1215/1695 train_time:118270ms step_avg:97.34ms +step:1216/1695 train_time:118367ms step_avg:97.34ms +step:1217/1695 train_time:118463ms step_avg:97.34ms +step:1218/1695 train_time:118566ms step_avg:97.34ms +step:1219/1695 train_time:118667ms step_avg:97.35ms +step:1220/1695 train_time:118767ms step_avg:97.35ms +step:1221/1695 train_time:118863ms step_avg:97.35ms +step:1222/1695 train_time:118959ms step_avg:97.35ms +step:1223/1695 train_time:119056ms step_avg:97.35ms +step:1224/1695 train_time:119152ms step_avg:97.35ms +step:1225/1695 train_time:119249ms step_avg:97.35ms +step:1226/1695 train_time:119346ms step_avg:97.35ms +step:1227/1695 train_time:119443ms step_avg:97.35ms +step:1228/1695 train_time:119541ms step_avg:97.35ms +step:1229/1695 train_time:119640ms step_avg:97.35ms +step:1230/1695 train_time:119740ms step_avg:97.35ms +step:1231/1695 train_time:119838ms step_avg:97.35ms +step:1232/1695 train_time:119936ms step_avg:97.35ms +step:1233/1695 train_time:120033ms step_avg:97.35ms +step:1234/1695 train_time:120130ms step_avg:97.35ms +step:1235/1695 train_time:120227ms step_avg:97.35ms +step:1236/1695 train_time:120323ms step_avg:97.35ms +step:1237/1695 train_time:120420ms step_avg:97.35ms +step:1238/1695 train_time:120517ms step_avg:97.35ms +step:1239/1695 train_time:120616ms step_avg:97.35ms +step:1240/1695 train_time:120716ms step_avg:97.35ms +step:1241/1695 train_time:120815ms step_avg:97.35ms +step:1242/1695 train_time:120914ms step_avg:97.35ms +step:1243/1695 train_time:121012ms step_avg:97.36ms +step:1244/1695 train_time:121110ms step_avg:97.35ms +step:1245/1695 train_time:121206ms step_avg:97.35ms +step:1246/1695 train_time:121304ms step_avg:97.35ms +step:1247/1695 train_time:121400ms step_avg:97.35ms +step:1248/1695 train_time:121498ms step_avg:97.35ms +step:1249/1695 train_time:121596ms step_avg:97.35ms +step:1250/1695 train_time:121696ms step_avg:97.36ms +step:1250/1695 val_loss:3.3886 train_time:121792ms step_avg:97.43ms +step:1251/1695 train_time:121818ms step_avg:97.38ms +step:1252/1695 train_time:121899ms step_avg:97.36ms +step:1253/1695 train_time:121997ms step_avg:97.36ms +step:1254/1695 train_time:122094ms step_avg:97.36ms +step:1255/1695 train_time:122190ms step_avg:97.36ms +step:1256/1695 
train_time:122287ms step_avg:97.36ms +step:1257/1695 train_time:122383ms step_avg:97.36ms +step:1258/1695 train_time:122480ms step_avg:97.36ms +step:1259/1695 train_time:122576ms step_avg:97.36ms +step:1260/1695 train_time:122673ms step_avg:97.36ms +step:1261/1695 train_time:122774ms step_avg:97.36ms +step:1262/1695 train_time:122874ms step_avg:97.36ms +step:1263/1695 train_time:122972ms step_avg:97.36ms +step:1264/1695 train_time:123070ms step_avg:97.37ms +step:1265/1695 train_time:123167ms step_avg:97.37ms +step:1266/1695 train_time:123265ms step_avg:97.37ms +step:1267/1695 train_time:123361ms step_avg:97.36ms +step:1268/1695 train_time:123458ms step_avg:97.36ms +step:1269/1695 train_time:123554ms step_avg:97.36ms +step:1270/1695 train_time:123650ms step_avg:97.36ms +step:1271/1695 train_time:123749ms step_avg:97.36ms +step:1272/1695 train_time:123849ms step_avg:97.37ms +step:1273/1695 train_time:123947ms step_avg:97.37ms +step:1274/1695 train_time:124045ms step_avg:97.37ms +step:1275/1695 train_time:124143ms step_avg:97.37ms +step:1276/1695 train_time:124242ms step_avg:97.37ms +step:1277/1695 train_time:124340ms step_avg:97.37ms +step:1278/1695 train_time:124437ms step_avg:97.37ms +step:1279/1695 train_time:124534ms step_avg:97.37ms +step:1280/1695 train_time:124632ms step_avg:97.37ms +step:1281/1695 train_time:124730ms step_avg:97.37ms +step:1282/1695 train_time:124828ms step_avg:97.37ms +step:1283/1695 train_time:124926ms step_avg:97.37ms +step:1284/1695 train_time:125024ms step_avg:97.37ms +step:1285/1695 train_time:125122ms step_avg:97.37ms +step:1286/1695 train_time:125220ms step_avg:97.37ms +step:1287/1695 train_time:125318ms step_avg:97.37ms +step:1288/1695 train_time:125415ms step_avg:97.37ms +step:1289/1695 train_time:125512ms step_avg:97.37ms +step:1290/1695 train_time:125609ms step_avg:97.37ms +step:1291/1695 train_time:125707ms step_avg:97.37ms +step:1292/1695 train_time:125806ms step_avg:97.37ms +step:1293/1695 train_time:125904ms step_avg:97.37ms +step:1294/1695 train_time:126003ms step_avg:97.37ms +step:1295/1695 train_time:126102ms step_avg:97.38ms +step:1296/1695 train_time:126200ms step_avg:97.38ms +step:1297/1695 train_time:126297ms step_avg:97.38ms +step:1298/1695 train_time:126395ms step_avg:97.38ms +step:1299/1695 train_time:126492ms step_avg:97.38ms +step:1300/1695 train_time:126589ms step_avg:97.38ms +step:1301/1695 train_time:126686ms step_avg:97.38ms +step:1302/1695 train_time:126785ms step_avg:97.38ms +step:1303/1695 train_time:126883ms step_avg:97.38ms +step:1304/1695 train_time:126981ms step_avg:97.38ms +step:1305/1695 train_time:127079ms step_avg:97.38ms +step:1306/1695 train_time:127178ms step_avg:97.38ms +step:1307/1695 train_time:127276ms step_avg:97.38ms +step:1308/1695 train_time:127373ms step_avg:97.38ms +step:1309/1695 train_time:127470ms step_avg:97.38ms +step:1310/1695 train_time:127568ms step_avg:97.38ms +step:1311/1695 train_time:127665ms step_avg:97.38ms +step:1312/1695 train_time:127764ms step_avg:97.38ms +step:1313/1695 train_time:127862ms step_avg:97.38ms +step:1314/1695 train_time:127961ms step_avg:97.38ms +step:1315/1695 train_time:128059ms step_avg:97.38ms +step:1316/1695 train_time:128158ms step_avg:97.38ms +step:1317/1695 train_time:128255ms step_avg:97.38ms +step:1318/1695 train_time:128352ms step_avg:97.38ms +step:1319/1695 train_time:128449ms step_avg:97.38ms +step:1320/1695 train_time:128546ms step_avg:97.38ms +step:1321/1695 train_time:128644ms step_avg:97.38ms +step:1322/1695 train_time:128743ms step_avg:97.38ms +step:1323/1695 
train_time:128841ms step_avg:97.39ms +step:1324/1695 train_time:128939ms step_avg:97.39ms +step:1325/1695 train_time:129038ms step_avg:97.39ms +step:1326/1695 train_time:129136ms step_avg:97.39ms +step:1327/1695 train_time:129233ms step_avg:97.39ms +step:1328/1695 train_time:129332ms step_avg:97.39ms +step:1329/1695 train_time:129429ms step_avg:97.39ms +step:1330/1695 train_time:129525ms step_avg:97.39ms +step:1331/1695 train_time:129623ms step_avg:97.39ms +step:1332/1695 train_time:129721ms step_avg:97.39ms +step:1333/1695 train_time:129820ms step_avg:97.39ms +step:1334/1695 train_time:129917ms step_avg:97.39ms +step:1335/1695 train_time:130015ms step_avg:97.39ms +step:1336/1695 train_time:130112ms step_avg:97.39ms +step:1337/1695 train_time:130209ms step_avg:97.39ms +step:1338/1695 train_time:130307ms step_avg:97.39ms +step:1339/1695 train_time:130405ms step_avg:97.39ms +step:1340/1695 train_time:130503ms step_avg:97.39ms +step:1341/1695 train_time:130601ms step_avg:97.39ms +step:1342/1695 train_time:130699ms step_avg:97.39ms +step:1343/1695 train_time:130797ms step_avg:97.39ms +step:1344/1695 train_time:130894ms step_avg:97.39ms +step:1345/1695 train_time:130991ms step_avg:97.39ms +step:1346/1695 train_time:131088ms step_avg:97.39ms +step:1347/1695 train_time:131185ms step_avg:97.39ms +step:1348/1695 train_time:131283ms step_avg:97.39ms +step:1349/1695 train_time:131381ms step_avg:97.39ms +step:1350/1695 train_time:131481ms step_avg:97.39ms +step:1351/1695 train_time:131579ms step_avg:97.39ms +step:1352/1695 train_time:131676ms step_avg:97.39ms +step:1353/1695 train_time:131774ms step_avg:97.39ms +step:1354/1695 train_time:131871ms step_avg:97.39ms +step:1355/1695 train_time:131969ms step_avg:97.39ms +step:1356/1695 train_time:132066ms step_avg:97.39ms +step:1357/1695 train_time:132164ms step_avg:97.39ms +step:1358/1695 train_time:132262ms step_avg:97.39ms +step:1359/1695 train_time:132360ms step_avg:97.40ms +step:1360/1695 train_time:132459ms step_avg:97.40ms +step:1361/1695 train_time:132558ms step_avg:97.40ms +step:1362/1695 train_time:132655ms step_avg:97.40ms +step:1363/1695 train_time:132753ms step_avg:97.40ms +step:1364/1695 train_time:132850ms step_avg:97.40ms +step:1365/1695 train_time:132948ms step_avg:97.40ms +step:1366/1695 train_time:133046ms step_avg:97.40ms +step:1367/1695 train_time:133144ms step_avg:97.40ms +step:1368/1695 train_time:133242ms step_avg:97.40ms +step:1369/1695 train_time:133340ms step_avg:97.40ms +step:1370/1695 train_time:133439ms step_avg:97.40ms +step:1371/1695 train_time:133537ms step_avg:97.40ms +step:1372/1695 train_time:133634ms step_avg:97.40ms +step:1373/1695 train_time:133732ms step_avg:97.40ms +step:1374/1695 train_time:133829ms step_avg:97.40ms +step:1375/1695 train_time:133926ms step_avg:97.40ms +step:1375/1695 val_loss:3.3508 train_time:134022ms step_avg:97.47ms +step:1376/1695 train_time:134049ms step_avg:97.42ms +step:1377/1695 train_time:134131ms step_avg:97.41ms +step:1378/1695 train_time:134229ms step_avg:97.41ms +step:1379/1695 train_time:134327ms step_avg:97.41ms +step:1380/1695 train_time:134424ms step_avg:97.41ms +step:1381/1695 train_time:134877ms step_avg:97.67ms +step:1382/1695 train_time:134952ms step_avg:97.65ms +step:1383/1695 train_time:135047ms step_avg:97.65ms +step:1384/1695 train_time:135144ms step_avg:97.65ms +step:1385/1695 train_time:135241ms step_avg:97.65ms +step:1386/1695 train_time:135338ms step_avg:97.65ms +step:1387/1695 train_time:135434ms step_avg:97.65ms +step:1388/1695 train_time:135530ms step_avg:97.64ms 
+step:1389/1695 train_time:135626ms step_avg:97.64ms +step:1390/1695 train_time:135723ms step_avg:97.64ms +step:1391/1695 train_time:135831ms step_avg:97.65ms +step:1392/1695 train_time:135931ms step_avg:97.65ms +step:1393/1695 train_time:136030ms step_avg:97.65ms +step:1394/1695 train_time:136128ms step_avg:97.65ms +step:1395/1695 train_time:136225ms step_avg:97.65ms +step:1396/1695 train_time:136322ms step_avg:97.65ms +step:1397/1695 train_time:136419ms step_avg:97.65ms +step:1398/1695 train_time:136516ms step_avg:97.65ms +step:1399/1695 train_time:136613ms step_avg:97.65ms +step:1400/1695 train_time:136711ms step_avg:97.65ms +step:1401/1695 train_time:136810ms step_avg:97.65ms +step:1402/1695 train_time:136910ms step_avg:97.65ms +step:1403/1695 train_time:137008ms step_avg:97.65ms +step:1404/1695 train_time:137106ms step_avg:97.65ms +step:1405/1695 train_time:137204ms step_avg:97.65ms +step:1406/1695 train_time:137301ms step_avg:97.65ms +step:1407/1695 train_time:137398ms step_avg:97.65ms +step:1408/1695 train_time:137494ms step_avg:97.65ms +step:1409/1695 train_time:137591ms step_avg:97.65ms +step:1410/1695 train_time:137688ms step_avg:97.65ms +step:1411/1695 train_time:137787ms step_avg:97.65ms +step:1412/1695 train_time:137887ms step_avg:97.65ms +step:1413/1695 train_time:137987ms step_avg:97.66ms +step:1414/1695 train_time:138085ms step_avg:97.66ms +step:1415/1695 train_time:138183ms step_avg:97.66ms +step:1416/1695 train_time:138281ms step_avg:97.66ms +step:1417/1695 train_time:138378ms step_avg:97.66ms +step:1418/1695 train_time:138476ms step_avg:97.66ms +step:1419/1695 train_time:138574ms step_avg:97.66ms +step:1420/1695 train_time:138669ms step_avg:97.65ms +step:1421/1695 train_time:138766ms step_avg:97.65ms +step:1422/1695 train_time:138866ms step_avg:97.66ms +step:1423/1695 train_time:138965ms step_avg:97.66ms +step:1424/1695 train_time:139065ms step_avg:97.66ms +step:1425/1695 train_time:139163ms step_avg:97.66ms +step:1426/1695 train_time:139260ms step_avg:97.66ms +step:1427/1695 train_time:139359ms step_avg:97.66ms +step:1428/1695 train_time:139456ms step_avg:97.66ms +step:1429/1695 train_time:139553ms step_avg:97.66ms +step:1430/1695 train_time:139650ms step_avg:97.66ms +step:1431/1695 train_time:139748ms step_avg:97.66ms +step:1432/1695 train_time:139845ms step_avg:97.66ms +step:1433/1695 train_time:139945ms step_avg:97.66ms +step:1434/1695 train_time:140045ms step_avg:97.66ms +step:1435/1695 train_time:140144ms step_avg:97.66ms +step:1436/1695 train_time:140243ms step_avg:97.66ms +step:1437/1695 train_time:140341ms step_avg:97.66ms +step:1438/1695 train_time:140440ms step_avg:97.66ms +step:1439/1695 train_time:140539ms step_avg:97.66ms +step:1440/1695 train_time:140637ms step_avg:97.66ms +step:1441/1695 train_time:140733ms step_avg:97.66ms +step:1442/1695 train_time:140830ms step_avg:97.66ms +step:1443/1695 train_time:140927ms step_avg:97.66ms +step:1444/1695 train_time:141024ms step_avg:97.66ms +step:1445/1695 train_time:141124ms step_avg:97.66ms +step:1446/1695 train_time:141222ms step_avg:97.66ms +step:1447/1695 train_time:141321ms step_avg:97.66ms +step:1448/1695 train_time:141421ms step_avg:97.67ms +step:1449/1695 train_time:141519ms step_avg:97.67ms +step:1450/1695 train_time:141618ms step_avg:97.67ms +step:1451/1695 train_time:141716ms step_avg:97.67ms +step:1452/1695 train_time:141813ms step_avg:97.67ms +step:1453/1695 train_time:141909ms step_avg:97.67ms +step:1454/1695 train_time:142006ms step_avg:97.67ms +step:1455/1695 train_time:142104ms step_avg:97.67ms 
+step:1456/1695 train_time:142202ms step_avg:97.67ms +step:1457/1695 train_time:142302ms step_avg:97.67ms +step:1458/1695 train_time:142401ms step_avg:97.67ms +step:1459/1695 train_time:142500ms step_avg:97.67ms +step:1460/1695 train_time:142599ms step_avg:97.67ms +step:1461/1695 train_time:142697ms step_avg:97.67ms +step:1462/1695 train_time:142795ms step_avg:97.67ms +step:1463/1695 train_time:142892ms step_avg:97.67ms +step:1464/1695 train_time:142989ms step_avg:97.67ms +step:1465/1695 train_time:143087ms step_avg:97.67ms +step:1466/1695 train_time:143185ms step_avg:97.67ms +step:1467/1695 train_time:143283ms step_avg:97.67ms +step:1468/1695 train_time:143382ms step_avg:97.67ms +step:1469/1695 train_time:143481ms step_avg:97.67ms +step:1470/1695 train_time:143579ms step_avg:97.67ms +step:1471/1695 train_time:143678ms step_avg:97.67ms +step:1472/1695 train_time:143775ms step_avg:97.67ms +step:1473/1695 train_time:143872ms step_avg:97.67ms +step:1474/1695 train_time:143968ms step_avg:97.67ms +step:1475/1695 train_time:144065ms step_avg:97.67ms +step:1476/1695 train_time:144163ms step_avg:97.67ms +step:1477/1695 train_time:144261ms step_avg:97.67ms +step:1478/1695 train_time:144359ms step_avg:97.67ms +step:1479/1695 train_time:144458ms step_avg:97.67ms +step:1480/1695 train_time:144557ms step_avg:97.67ms +step:1481/1695 train_time:144654ms step_avg:97.67ms +step:1482/1695 train_time:144751ms step_avg:97.67ms +step:1483/1695 train_time:144849ms step_avg:97.67ms +step:1484/1695 train_time:144946ms step_avg:97.67ms +step:1485/1695 train_time:145044ms step_avg:97.67ms +step:1486/1695 train_time:145141ms step_avg:97.67ms +step:1487/1695 train_time:145238ms step_avg:97.67ms +step:1488/1695 train_time:145335ms step_avg:97.67ms +step:1489/1695 train_time:145433ms step_avg:97.67ms +step:1490/1695 train_time:145532ms step_avg:97.67ms +step:1491/1695 train_time:145630ms step_avg:97.67ms +step:1492/1695 train_time:145728ms step_avg:97.67ms +step:1493/1695 train_time:145826ms step_avg:97.67ms +step:1494/1695 train_time:145924ms step_avg:97.67ms +step:1495/1695 train_time:146023ms step_avg:97.67ms +step:1496/1695 train_time:146121ms step_avg:97.67ms +step:1497/1695 train_time:146218ms step_avg:97.67ms +step:1498/1695 train_time:146316ms step_avg:97.67ms +step:1499/1695 train_time:146414ms step_avg:97.67ms +step:1500/1695 train_time:146512ms step_avg:97.67ms +step:1500/1695 val_loss:3.3173 train_time:146608ms step_avg:97.74ms +step:1501/1695 train_time:146635ms step_avg:97.69ms +step:1502/1695 train_time:146715ms step_avg:97.68ms +step:1503/1695 train_time:146815ms step_avg:97.68ms +step:1504/1695 train_time:146912ms step_avg:97.68ms +step:1505/1695 train_time:147009ms step_avg:97.68ms +step:1506/1695 train_time:147105ms step_avg:97.68ms +step:1507/1695 train_time:147202ms step_avg:97.68ms +step:1508/1695 train_time:147299ms step_avg:97.68ms +step:1509/1695 train_time:147395ms step_avg:97.68ms +step:1510/1695 train_time:147491ms step_avg:97.68ms +step:1511/1695 train_time:147591ms step_avg:97.68ms +step:1512/1695 train_time:147692ms step_avg:97.68ms +step:1513/1695 train_time:147792ms step_avg:97.68ms +step:1514/1695 train_time:147890ms step_avg:97.68ms +step:1515/1695 train_time:147987ms step_avg:97.68ms +step:1516/1695 train_time:148084ms step_avg:97.68ms +step:1517/1695 train_time:148181ms step_avg:97.68ms +step:1518/1695 train_time:148278ms step_avg:97.68ms +step:1519/1695 train_time:148375ms step_avg:97.68ms +step:1520/1695 train_time:148472ms step_avg:97.68ms +step:1521/1695 train_time:148570ms 
step_avg:97.68ms +step:1522/1695 train_time:148669ms step_avg:97.68ms +step:1523/1695 train_time:148769ms step_avg:97.68ms +step:1524/1695 train_time:148867ms step_avg:97.68ms +step:1525/1695 train_time:148965ms step_avg:97.68ms +step:1526/1695 train_time:149062ms step_avg:97.68ms +step:1527/1695 train_time:149159ms step_avg:97.68ms +step:1528/1695 train_time:149256ms step_avg:97.68ms +step:1529/1695 train_time:149352ms step_avg:97.68ms +step:1530/1695 train_time:149448ms step_avg:97.68ms +step:1531/1695 train_time:149546ms step_avg:97.68ms +step:1532/1695 train_time:149645ms step_avg:97.68ms +step:1533/1695 train_time:149745ms step_avg:97.68ms +step:1534/1695 train_time:149844ms step_avg:97.68ms +step:1535/1695 train_time:149943ms step_avg:97.68ms +step:1536/1695 train_time:150041ms step_avg:97.68ms +step:1537/1695 train_time:150138ms step_avg:97.68ms +step:1538/1695 train_time:150235ms step_avg:97.68ms +step:1539/1695 train_time:150333ms step_avg:97.68ms +step:1540/1695 train_time:150429ms step_avg:97.68ms +step:1541/1695 train_time:150527ms step_avg:97.68ms +step:1542/1695 train_time:150625ms step_avg:97.68ms +step:1543/1695 train_time:150724ms step_avg:97.68ms +step:1544/1695 train_time:150823ms step_avg:97.68ms +step:1545/1695 train_time:150922ms step_avg:97.68ms +step:1546/1695 train_time:151020ms step_avg:97.68ms +step:1547/1695 train_time:151118ms step_avg:97.68ms +step:1548/1695 train_time:151214ms step_avg:97.68ms +step:1549/1695 train_time:151311ms step_avg:97.68ms +step:1550/1695 train_time:151407ms step_avg:97.68ms +step:1551/1695 train_time:151505ms step_avg:97.68ms +step:1552/1695 train_time:151855ms step_avg:97.84ms +step:1553/1695 train_time:152033ms step_avg:97.90ms +step:1554/1695 train_time:152130ms step_avg:97.90ms +step:1555/1695 train_time:152226ms step_avg:97.89ms +step:1556/1695 train_time:152323ms step_avg:97.89ms +step:1557/1695 train_time:152420ms step_avg:97.89ms +step:1558/1695 train_time:152516ms step_avg:97.89ms +step:1559/1695 train_time:152612ms step_avg:97.89ms +step:1560/1695 train_time:152708ms step_avg:97.89ms +step:1561/1695 train_time:152804ms step_avg:97.89ms +step:1562/1695 train_time:152908ms step_avg:97.89ms +step:1563/1695 train_time:153011ms step_avg:97.90ms +step:1564/1695 train_time:153110ms step_avg:97.90ms +step:1565/1695 train_time:153207ms step_avg:97.90ms +step:1566/1695 train_time:153305ms step_avg:97.90ms +step:1567/1695 train_time:153401ms step_avg:97.89ms +step:1568/1695 train_time:153499ms step_avg:97.89ms +step:1569/1695 train_time:153596ms step_avg:97.89ms +step:1570/1695 train_time:153694ms step_avg:97.89ms +step:1571/1695 train_time:153791ms step_avg:97.89ms +step:1572/1695 train_time:153889ms step_avg:97.89ms +step:1573/1695 train_time:153989ms step_avg:97.89ms +step:1574/1695 train_time:154087ms step_avg:97.90ms +step:1575/1695 train_time:154185ms step_avg:97.90ms +step:1576/1695 train_time:154283ms step_avg:97.90ms +step:1577/1695 train_time:154380ms step_avg:97.89ms +step:1578/1695 train_time:154477ms step_avg:97.89ms +step:1579/1695 train_time:154574ms step_avg:97.89ms +step:1580/1695 train_time:154672ms step_avg:97.89ms +step:1581/1695 train_time:154769ms step_avg:97.89ms +step:1582/1695 train_time:154867ms step_avg:97.89ms +step:1583/1695 train_time:154967ms step_avg:97.89ms +step:1584/1695 train_time:155066ms step_avg:97.90ms +step:1585/1695 train_time:155164ms step_avg:97.90ms +step:1586/1695 train_time:155262ms step_avg:97.90ms +step:1587/1695 train_time:155360ms step_avg:97.90ms +step:1588/1695 train_time:155457ms 
step_avg:97.89ms +step:1589/1695 train_time:155554ms step_avg:97.89ms +step:1590/1695 train_time:155651ms step_avg:97.89ms +step:1591/1695 train_time:155749ms step_avg:97.89ms +step:1592/1695 train_time:155846ms step_avg:97.89ms +step:1593/1695 train_time:155945ms step_avg:97.89ms +step:1594/1695 train_time:156044ms step_avg:97.89ms +step:1595/1695 train_time:156143ms step_avg:97.90ms +step:1596/1695 train_time:156240ms step_avg:97.89ms +step:1597/1695 train_time:156338ms step_avg:97.89ms +step:1598/1695 train_time:156435ms step_avg:97.89ms +step:1599/1695 train_time:156532ms step_avg:97.89ms +step:1600/1695 train_time:156629ms step_avg:97.89ms +step:1601/1695 train_time:156727ms step_avg:97.89ms +step:1602/1695 train_time:156825ms step_avg:97.89ms +step:1603/1695 train_time:156923ms step_avg:97.89ms +step:1604/1695 train_time:157022ms step_avg:97.89ms +step:1605/1695 train_time:157121ms step_avg:97.89ms +step:1606/1695 train_time:157220ms step_avg:97.90ms +step:1607/1695 train_time:157318ms step_avg:97.90ms +step:1608/1695 train_time:157416ms step_avg:97.90ms +step:1609/1695 train_time:157515ms step_avg:97.90ms +step:1610/1695 train_time:157611ms step_avg:97.90ms +step:1611/1695 train_time:157709ms step_avg:97.90ms +step:1612/1695 train_time:157806ms step_avg:97.89ms +step:1613/1695 train_time:157904ms step_avg:97.89ms +step:1614/1695 train_time:158003ms step_avg:97.90ms +step:1615/1695 train_time:158102ms step_avg:97.90ms +step:1616/1695 train_time:158200ms step_avg:97.90ms +step:1617/1695 train_time:158299ms step_avg:97.90ms +step:1618/1695 train_time:158398ms step_avg:97.90ms +step:1619/1695 train_time:158495ms step_avg:97.90ms +step:1620/1695 train_time:158593ms step_avg:97.90ms +step:1621/1695 train_time:158691ms step_avg:97.90ms +step:1622/1695 train_time:158787ms step_avg:97.90ms +step:1623/1695 train_time:158885ms step_avg:97.90ms +step:1624/1695 train_time:158983ms step_avg:97.90ms +step:1625/1695 train_time:159082ms step_avg:97.90ms +step:1625/1695 val_loss:3.2899 train_time:159178ms step_avg:97.96ms +step:1626/1695 train_time:159206ms step_avg:97.91ms +step:1627/1695 train_time:159288ms step_avg:97.90ms +step:1628/1695 train_time:159387ms step_avg:97.90ms +step:1629/1695 train_time:159485ms step_avg:97.90ms +step:1630/1695 train_time:159582ms step_avg:97.90ms +step:1631/1695 train_time:159679ms step_avg:97.90ms +step:1632/1695 train_time:159776ms step_avg:97.90ms +step:1633/1695 train_time:159872ms step_avg:97.90ms +step:1634/1695 train_time:159968ms step_avg:97.90ms +step:1635/1695 train_time:160065ms step_avg:97.90ms +step:1636/1695 train_time:160166ms step_avg:97.90ms +step:1637/1695 train_time:160267ms step_avg:97.90ms +step:1638/1695 train_time:160367ms step_avg:97.90ms +step:1639/1695 train_time:160466ms step_avg:97.90ms +step:1640/1695 train_time:160563ms step_avg:97.90ms +step:1641/1695 train_time:160661ms step_avg:97.90ms +step:1642/1695 train_time:160759ms step_avg:97.90ms +step:1643/1695 train_time:160856ms step_avg:97.90ms +step:1644/1695 train_time:160953ms step_avg:97.90ms +step:1645/1695 train_time:161050ms step_avg:97.90ms +step:1646/1695 train_time:161148ms step_avg:97.90ms +step:1647/1695 train_time:161247ms step_avg:97.90ms +step:1648/1695 train_time:161346ms step_avg:97.90ms +step:1649/1695 train_time:161445ms step_avg:97.90ms +step:1650/1695 train_time:161544ms step_avg:97.91ms +step:1651/1695 train_time:161641ms step_avg:97.90ms +step:1652/1695 train_time:161739ms step_avg:97.91ms +step:1653/1695 train_time:161838ms step_avg:97.91ms +step:1654/1695 
train_time:161935ms step_avg:97.91ms +step:1655/1695 train_time:162032ms step_avg:97.90ms +step:1656/1695 train_time:162128ms step_avg:97.90ms +step:1657/1695 train_time:162226ms step_avg:97.90ms +step:1658/1695 train_time:162326ms step_avg:97.90ms +step:1659/1695 train_time:162424ms step_avg:97.90ms +step:1660/1695 train_time:162522ms step_avg:97.90ms +step:1661/1695 train_time:162620ms step_avg:97.90ms +step:1662/1695 train_time:162717ms step_avg:97.90ms +step:1663/1695 train_time:162815ms step_avg:97.90ms +step:1664/1695 train_time:162913ms step_avg:97.90ms +step:1665/1695 train_time:163010ms step_avg:97.90ms +step:1666/1695 train_time:163107ms step_avg:97.90ms +step:1667/1695 train_time:163205ms step_avg:97.90ms +step:1668/1695 train_time:163305ms step_avg:97.90ms +step:1669/1695 train_time:163404ms step_avg:97.91ms +step:1670/1695 train_time:163503ms step_avg:97.91ms +step:1671/1695 train_time:163601ms step_avg:97.91ms +step:1672/1695 train_time:163699ms step_avg:97.91ms +step:1673/1695 train_time:163797ms step_avg:97.91ms +step:1674/1695 train_time:163896ms step_avg:97.91ms +step:1675/1695 train_time:163994ms step_avg:97.91ms +step:1676/1695 train_time:164090ms step_avg:97.91ms +step:1677/1695 train_time:164187ms step_avg:97.91ms +step:1678/1695 train_time:164285ms step_avg:97.91ms +step:1679/1695 train_time:164384ms step_avg:97.91ms +step:1680/1695 train_time:164482ms step_avg:97.91ms +step:1681/1695 train_time:164581ms step_avg:97.91ms +step:1682/1695 train_time:164679ms step_avg:97.91ms +step:1683/1695 train_time:164777ms step_avg:97.91ms +step:1684/1695 train_time:164875ms step_avg:97.91ms +step:1685/1695 train_time:164972ms step_avg:97.91ms +step:1686/1695 train_time:165068ms step_avg:97.91ms +step:1687/1695 train_time:165167ms step_avg:97.91ms +step:1688/1695 train_time:165265ms step_avg:97.91ms +step:1689/1695 train_time:165363ms step_avg:97.91ms +step:1690/1695 train_time:165460ms step_avg:97.91ms +step:1691/1695 train_time:165558ms step_avg:97.91ms +step:1692/1695 train_time:165656ms step_avg:97.91ms +step:1693/1695 train_time:165754ms step_avg:97.91ms +step:1694/1695 train_time:165851ms step_avg:97.91ms +step:1695/1695 train_time:165949ms step_avg:97.90ms +step:1695/1695 val_loss:3.2780 train_time:166044ms step_avg:97.96ms +peak memory allocated: 34000 MiB reserved: 49636 MiB diff --git a/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt b/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt new file mode 100644 index 000000000..789ccc0d1 --- /dev/null +++ b/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt @@ -0,0 +1,2808 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
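+        # (descriptive note) each param group holds same-shape params, so ownership is
+        # round-robined: rank r updates params[base_i + r]. Gradients are averaged with
+        # reduce_scatter, the owning rank applies momentum + Newton-Schulz locally, and
+        # all_gather broadcasts the updated weights back; async futures overlap the
+        # communication with compute.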
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + 
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +class EOSBatchFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): + # Precompute EOS positions once per shard + self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 # pointer into eos_idx (start EOS for next step) + self.pos = 0 # logical stream position within this shard + self.world_size = world_size + def seek(self, pos: int): + # Set pointer to the first EOS >= pos + self.i = np.searchsorted(self.eos_idx, pos) + if self.i >= len(self.eos_idx): + raise StopIteration("Seek past last EOS.") + self.pos = pos + def next_batch(self, batch_size_local: int, seq_len: int): + n = len(self.eos_idx) + if self.i >= n: + raise StopIteration("No more EOS in this shard.") + starts = [[] for _ in range(self.world_size)] + idx = self.i + cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 + for r in range(self.world_size): + for _ in range(batch_size_local): + start = cur + 1 + target = start + seq_len # need seq_len tokens before next EOS + j = np.searchsorted(self.eos_idx, target) + if j >= n: + raise StopIteration("Insufficient EOS ahead; hit tail of shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + + while True: + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
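+                # (descriptive note, assuming shards delimit documents with EOS tokens)
+                # alignment discards the tail tokens of the exhausted shard; the fresh
+                # finder restarts at position 0, so every sequence again begins right
+                # after an EOS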
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+                finder = EOSBatchFinder(tokens, world_size=world_size)
+                continue
+
+            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
+            buf = torch.stack(bufs, dim=0)
+            _inputs = buf[:, :-1]
+            _targets = buf[:, 1:]
+        else:
+            batch_span = num_tokens_global
+            start_pos_local = pos + rank * (batch_size_local * seq_len)
+            end_pos_local = start_pos_local + (batch_size_local * seq_len)
+
+            buf = tokens[start_pos_local: end_pos_local + 1]
+
+            _inputs = buf[:-1].view(batch_size_local, seq_len)
+            _targets = buf[1:].view(batch_size_local, seq_len)
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
+        )
+
+        pos += batch_span
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (batch_size, seq_len) via .send()
+            new_batch_size, new_seq_len = new_params
+            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
+            batch_size = new_batch_size
+            seq_len = new_seq_len
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_seq_len: int = 1024 * 2
+    train_batch_size: int = 24 * 8
+    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
+    # optimization
+    num_iterations: int = 1695 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    bandwidth: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
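+# (illustrative sketch, not part of the recorded run) distributed_data_generator above
+# is a plain Python generator, so a mid-run batch-shape change is just a .send() with a
+# new (batch_size, seq_len) pair:
+#     loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+#     inputs, targets = next(loader)              # first batch
+#     inputs, targets = loader.send((192, 4096))  # same token stream, new shape
+# the new batch size must remain divisible by world_size, as asserted in the generator.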
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_seq_len, args.val_seq_len)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr_and_ws(step: int):
+    x = step / (1 + args.num_iterations) # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return lr, args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+# Warmup kernels #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 60
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+for step in range(warmup_steps):
+    inputs, targets = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, ws, ws // 2).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+# Training and validation #
+########################################
+
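+# (illustrative worked example for get_lr_and_ws above) with num_iterations=1695,
+# cooldown_frac=0.45 and ws_schedule=(3, 7, 11): at step 1187, x = 1187/1696 ~ 0.700,
+# so w = (1 - 0.700)/0.45 ~ 0.667 and lr ~ 0.667*1.0 + 0.333*0.1 ~ 0.70, while
+# ws_idx = int(3 * 0.700) = 2 selects ws = 11; the model multiplies ws by
+# args.bandwidth = 128 tokens inside GPT.forward to get the attention window.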
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + lr, ws = get_lr_and_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % (world_size * args.val_seq_len) == 0 + val_steps = args.val_tokens // (world_size * args.val_seq_len) + val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets = next(val_loader) + val_loss += model(inputs, targets, ws, ws // 2) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, ws, ws // 2).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * lr + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Aug 27 03:47:47 2025 ++---------------------------------------------------------------------------------------+ +| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | +|-----------------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+======================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | +| N/A 32C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | +| N/A 36C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | +| N/A 37C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | +| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | +| N/A 32C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | +| N/A 38C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | +| N/A 36C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | +| N/A 33C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+----------------------+----------------------+ + ++---------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=======================================================================================| ++---------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1695 train_time:520ms step_avg:519.82ms +step:2/1695 train_time:546ms step_avg:272.89ms +step:3/1695 train_time:615ms step_avg:204.90ms +step:4/1695 train_time:707ms step_avg:176.66ms +step:5/1695 train_time:800ms step_avg:159.97ms +step:6/1695 train_time:892ms step_avg:148.74ms +step:7/1695 train_time:986ms step_avg:140.81ms +step:8/1695 train_time:1079ms step_avg:134.91ms +step:9/1695 train_time:1173ms step_avg:130.29ms +step:10/1695 train_time:1266ms step_avg:126.65ms +step:11/1695 train_time:1360ms step_avg:123.62ms +step:12/1695 train_time:1455ms step_avg:121.29ms +step:13/1695 train_time:1553ms step_avg:119.48ms +step:14/1695 train_time:1649ms step_avg:117.78ms +step:15/1695 train_time:1745ms step_avg:116.36ms +step:16/1695 train_time:1840ms step_avg:115.01ms +step:17/1695 train_time:1933ms step_avg:113.72ms +step:18/1695 train_time:2027ms step_avg:112.62ms +step:19/1695 train_time:2121ms step_avg:111.64ms +step:20/1695 train_time:2215ms step_avg:110.74ms +step:21/1695 train_time:2309ms step_avg:109.95ms +step:22/1695 train_time:2405ms step_avg:109.30ms 
+step:23/1695 train_time:2501ms step_avg:108.76ms +step:24/1695 train_time:2596ms step_avg:108.18ms +step:25/1695 train_time:2692ms step_avg:107.66ms +step:26/1695 train_time:2787ms step_avg:107.19ms +step:27/1695 train_time:2882ms step_avg:106.75ms +step:28/1695 train_time:2976ms step_avg:106.30ms +step:29/1695 train_time:3070ms step_avg:105.87ms +step:30/1695 train_time:3164ms step_avg:105.48ms +step:31/1695 train_time:3258ms step_avg:105.10ms +step:32/1695 train_time:3352ms step_avg:104.76ms +step:33/1695 train_time:3448ms step_avg:104.49ms +step:34/1695 train_time:3546ms step_avg:104.28ms +step:35/1695 train_time:3641ms step_avg:104.04ms +step:36/1695 train_time:3735ms step_avg:103.76ms +step:37/1695 train_time:3830ms step_avg:103.51ms +step:38/1695 train_time:3925ms step_avg:103.28ms +step:39/1695 train_time:4019ms step_avg:103.05ms +step:40/1695 train_time:4113ms step_avg:102.83ms +step:41/1695 train_time:4208ms step_avg:102.63ms +step:42/1695 train_time:4302ms step_avg:102.43ms +step:43/1695 train_time:4396ms step_avg:102.22ms +step:44/1695 train_time:4490ms step_avg:102.06ms +step:45/1695 train_time:4587ms step_avg:101.92ms +step:46/1695 train_time:4682ms step_avg:101.78ms +step:47/1695 train_time:4776ms step_avg:101.62ms +step:48/1695 train_time:4871ms step_avg:101.47ms +step:49/1695 train_time:4965ms step_avg:101.33ms +step:50/1695 train_time:5061ms step_avg:101.22ms +step:51/1695 train_time:5154ms step_avg:101.05ms +step:52/1695 train_time:5248ms step_avg:100.93ms +step:53/1695 train_time:5343ms step_avg:100.82ms +step:54/1695 train_time:5439ms step_avg:100.73ms +step:55/1695 train_time:5533ms step_avg:100.59ms +step:56/1695 train_time:5628ms step_avg:100.50ms +step:57/1695 train_time:5723ms step_avg:100.41ms +step:58/1695 train_time:5819ms step_avg:100.33ms +step:59/1695 train_time:5913ms step_avg:100.22ms +step:60/1695 train_time:6008ms step_avg:100.13ms +step:61/1695 train_time:6101ms step_avg:100.02ms +step:62/1695 train_time:6195ms step_avg:99.91ms +step:63/1695 train_time:6289ms step_avg:99.82ms +step:64/1695 train_time:6384ms step_avg:99.75ms +step:65/1695 train_time:6479ms step_avg:99.68ms +step:66/1695 train_time:6573ms step_avg:99.59ms +step:67/1695 train_time:6668ms step_avg:99.52ms +step:68/1695 train_time:6762ms step_avg:99.45ms +step:69/1695 train_time:6856ms step_avg:99.37ms +step:70/1695 train_time:6950ms step_avg:99.29ms +step:71/1695 train_time:7044ms step_avg:99.21ms +step:72/1695 train_time:7138ms step_avg:99.14ms +step:73/1695 train_time:7232ms step_avg:99.07ms +step:74/1695 train_time:7327ms step_avg:99.02ms +step:75/1695 train_time:7423ms step_avg:98.97ms +step:76/1695 train_time:7518ms step_avg:98.92ms +step:77/1695 train_time:7613ms step_avg:98.87ms +step:78/1695 train_time:7709ms step_avg:98.83ms +step:79/1695 train_time:7803ms step_avg:98.78ms +step:80/1695 train_time:7897ms step_avg:98.72ms +step:81/1695 train_time:7991ms step_avg:98.65ms +step:82/1695 train_time:8085ms step_avg:98.60ms +step:83/1695 train_time:8180ms step_avg:98.56ms +step:84/1695 train_time:8274ms step_avg:98.50ms +step:85/1695 train_time:8368ms step_avg:98.45ms +step:86/1695 train_time:8463ms step_avg:98.40ms +step:87/1695 train_time:8556ms step_avg:98.35ms +step:88/1695 train_time:8651ms step_avg:98.31ms +step:89/1695 train_time:8747ms step_avg:98.29ms +step:90/1695 train_time:8843ms step_avg:98.26ms +step:91/1695 train_time:8938ms step_avg:98.22ms +step:92/1695 train_time:9031ms step_avg:98.16ms +step:93/1695 train_time:9125ms step_avg:98.12ms +step:94/1695 train_time:9220ms 
step_avg:98.08ms
+step:95/1695 train_time:9314ms step_avg:98.04ms
+step:96/1695 train_time:9408ms step_avg:98.00ms
+step:97/1695 train_time:9503ms step_avg:97.97ms
+step:98/1695 train_time:9597ms step_avg:97.93ms
+step:99/1695 train_time:9691ms step_avg:97.89ms
+step:100/1695 train_time:9787ms step_avg:97.87ms
+step:101/1695 train_time:9881ms step_avg:97.84ms
+step:102/1695 train_time:9975ms step_avg:97.80ms
+step:103/1695 train_time:10069ms step_avg:97.76ms
+step:104/1695 train_time:10164ms step_avg:97.74ms
+step:105/1695 train_time:10258ms step_avg:97.70ms
+step:106/1695 train_time:10352ms step_avg:97.66ms
+step:107/1695 train_time:10447ms step_avg:97.64ms
+step:108/1695 train_time:10543ms step_avg:97.62ms
+step:109/1695 train_time:10637ms step_avg:97.59ms
+step:110/1695 train_time:10732ms step_avg:97.56ms
+step:111/1695 train_time:10827ms step_avg:97.54ms
+step:112/1695 train_time:10922ms step_avg:97.52ms
+step:113/1695 train_time:11016ms step_avg:97.49ms
+step:114/1695 train_time:11110ms step_avg:97.46ms
+step:115/1695 train_time:11205ms step_avg:97.44ms
+step:116/1695 train_time:11299ms step_avg:97.41ms
+step:117/1695 train_time:11393ms step_avg:97.38ms
+step:118/1695 train_time:11488ms step_avg:97.35ms
+step:119/1695 train_time:11583ms step_avg:97.33ms
+step:120/1695 train_time:11677ms step_avg:97.31ms
+step:121/1695 train_time:11771ms step_avg:97.28ms
+step:122/1695 train_time:11867ms step_avg:97.27ms
+step:123/1695 train_time:11961ms step_avg:97.24ms
+step:124/1695 train_time:12054ms step_avg:97.21ms
+step:125/1695 train_time:12149ms step_avg:97.19ms
+step:125/1695 val_loss:4.3107 train_time:12241ms step_avg:97.93ms
+step:126/1695 train_time:12268ms step_avg:97.36ms
+step:127/1695 train_time:12344ms step_avg:97.20ms
+step:128/1695 train_time:12448ms step_avg:97.25ms
+step:129/1695 train_time:12544ms step_avg:97.24ms
+step:130/1695 train_time:12638ms step_avg:97.21ms
+step:131/1695 train_time:12730ms step_avg:97.18ms
+step:132/1695 train_time:12824ms step_avg:97.15ms
+step:133/1695 train_time:12918ms step_avg:97.13ms
+step:134/1695 train_time:13011ms step_avg:97.10ms
+step:135/1695 train_time:13105ms step_avg:97.07ms
+step:136/1695 train_time:13198ms step_avg:97.04ms
+step:137/1695 train_time:13292ms step_avg:97.02ms
+step:138/1695 train_time:13389ms step_avg:97.02ms
+step:139/1695 train_time:13485ms step_avg:97.01ms
+step:140/1695 train_time:13580ms step_avg:97.00ms
+step:141/1695 train_time:13675ms step_avg:96.98ms
+step:142/1695 train_time:13770ms step_avg:96.97ms
+step:143/1695 train_time:13862ms step_avg:96.94ms
+step:144/1695 train_time:13955ms step_avg:96.91ms
+step:145/1695 train_time:14049ms step_avg:96.89ms
+step:146/1695 train_time:14144ms step_avg:96.87ms
+step:147/1695 train_time:14238ms step_avg:96.86ms
+step:148/1695 train_time:14333ms step_avg:96.84ms
+step:149/1695 train_time:14427ms step_avg:96.83ms
+step:150/1695 train_time:14524ms step_avg:96.82ms
+step:151/1695 train_time:14617ms step_avg:96.80ms
+step:152/1695 train_time:14711ms step_avg:96.78ms
+step:153/1695 train_time:14805ms step_avg:96.76ms
+step:154/1695 train_time:14899ms step_avg:96.75ms
+step:155/1695 train_time:14992ms step_avg:96.72ms
+step:156/1695 train_time:15085ms step_avg:96.70ms
+step:157/1695 train_time:15180ms step_avg:96.69ms
+step:158/1695 train_time:15274ms step_avg:96.67ms
+step:159/1695 train_time:15368ms step_avg:96.66ms
+step:160/1695 train_time:15463ms step_avg:96.65ms
+step:161/1695 train_time:15558ms step_avg:96.63ms
+step:162/1695 train_time:15652ms step_avg:96.62ms
+step:163/1695 train_time:15746ms step_avg:96.60ms
+step:164/1695 train_time:15841ms step_avg:96.59ms
+step:165/1695 train_time:15935ms step_avg:96.57ms
+step:166/1695 train_time:16029ms step_avg:96.56ms
+step:167/1695 train_time:16124ms step_avg:96.55ms
+step:168/1695 train_time:16217ms step_avg:96.53ms
+step:169/1695 train_time:16313ms step_avg:96.52ms
+step:170/1695 train_time:16406ms step_avg:96.51ms
+step:171/1695 train_time:16501ms step_avg:96.50ms
+step:172/1695 train_time:16595ms step_avg:96.48ms
+step:173/1695 train_time:16938ms step_avg:97.91ms
+step:174/1695 train_time:17053ms step_avg:98.00ms
+step:175/1695 train_time:17146ms step_avg:97.98ms
+step:176/1695 train_time:17239ms step_avg:97.95ms
+step:177/1695 train_time:17332ms step_avg:97.92ms
+step:178/1695 train_time:17426ms step_avg:97.90ms
+step:179/1695 train_time:17518ms step_avg:97.87ms
+step:180/1695 train_time:17611ms step_avg:97.84ms
+step:181/1695 train_time:17705ms step_avg:97.82ms
+step:182/1695 train_time:17798ms step_avg:97.79ms
+step:183/1695 train_time:17897ms step_avg:97.80ms
+step:184/1695 train_time:17993ms step_avg:97.79ms
+step:185/1695 train_time:18087ms step_avg:97.77ms
+step:186/1695 train_time:18181ms step_avg:97.75ms
+step:187/1695 train_time:18275ms step_avg:97.73ms
+step:188/1695 train_time:18369ms step_avg:97.71ms
+step:189/1695 train_time:18462ms step_avg:97.68ms
+step:190/1695 train_time:18555ms step_avg:97.66ms
+step:191/1695 train_time:18648ms step_avg:97.63ms
+step:192/1695 train_time:18743ms step_avg:97.62ms
+step:193/1695 train_time:18839ms step_avg:97.61ms
+step:194/1695 train_time:18934ms step_avg:97.60ms
+step:195/1695 train_time:19028ms step_avg:97.58ms
+step:196/1695 train_time:19122ms step_avg:97.56ms
+step:197/1695 train_time:19216ms step_avg:97.55ms
+step:198/1695 train_time:19310ms step_avg:97.53ms
+step:199/1695 train_time:19404ms step_avg:97.51ms
+step:200/1695 train_time:19498ms step_avg:97.49ms
+step:201/1695 train_time:19591ms step_avg:97.47ms
+step:202/1695 train_time:19684ms step_avg:97.45ms
+step:203/1695 train_time:19779ms step_avg:97.43ms
+step:204/1695 train_time:19873ms step_avg:97.42ms
+step:205/1695 train_time:19967ms step_avg:97.40ms
+step:206/1695 train_time:20062ms step_avg:97.39ms
+step:207/1695 train_time:20157ms step_avg:97.38ms
+step:208/1695 train_time:20251ms step_avg:97.36ms
+step:209/1695 train_time:20345ms step_avg:97.34ms
+step:210/1695 train_time:20439ms step_avg:97.33ms
+step:211/1695 train_time:20532ms step_avg:97.31ms
+step:212/1695 train_time:20625ms step_avg:97.29ms
+step:213/1695 train_time:20719ms step_avg:97.27ms
+step:214/1695 train_time:20814ms step_avg:97.26ms
+step:215/1695 train_time:20908ms step_avg:97.25ms
+step:216/1695 train_time:21002ms step_avg:97.23ms
+step:217/1695 train_time:21097ms step_avg:97.22ms
+step:218/1695 train_time:21191ms step_avg:97.21ms
+step:219/1695 train_time:21285ms step_avg:97.19ms
+step:220/1695 train_time:21380ms step_avg:97.18ms
+step:221/1695 train_time:21474ms step_avg:97.17ms
+step:222/1695 train_time:21568ms step_avg:97.15ms
+step:223/1695 train_time:21664ms step_avg:97.15ms
+step:224/1695 train_time:21756ms step_avg:97.13ms
+step:225/1695 train_time:21850ms step_avg:97.11ms
+step:226/1695 train_time:21945ms step_avg:97.10ms
+step:227/1695 train_time:22040ms step_avg:97.09ms
+step:228/1695 train_time:22134ms step_avg:97.08ms
+step:229/1695 train_time:22228ms step_avg:97.06ms
+step:230/1695 train_time:22322ms step_avg:97.05ms
+step:231/1695 train_time:22417ms step_avg:97.04ms
+step:232/1695 train_time:22510ms step_avg:97.02ms
+step:233/1695 train_time:22604ms step_avg:97.01ms
+step:234/1695 train_time:22699ms step_avg:97.00ms
+step:235/1695 train_time:22792ms step_avg:96.99ms
+step:236/1695 train_time:22886ms step_avg:96.97ms
+step:237/1695 train_time:22981ms step_avg:96.96ms
+step:238/1695 train_time:23075ms step_avg:96.95ms
+step:239/1695 train_time:23170ms step_avg:96.94ms
+step:240/1695 train_time:23264ms step_avg:96.93ms
+step:241/1695 train_time:23358ms step_avg:96.92ms
+step:242/1695 train_time:23451ms step_avg:96.91ms
+step:243/1695 train_time:23546ms step_avg:96.90ms
+step:244/1695 train_time:23641ms step_avg:96.89ms
+step:245/1695 train_time:23735ms step_avg:96.88ms
+step:246/1695 train_time:23829ms step_avg:96.87ms
+step:247/1695 train_time:23923ms step_avg:96.85ms
+step:248/1695 train_time:24018ms step_avg:96.85ms
+step:249/1695 train_time:24112ms step_avg:96.83ms
+step:250/1695 train_time:24206ms step_avg:96.82ms
+step:250/1695 val_loss:3.9776 train_time:24298ms step_avg:97.19ms
+step:251/1695 train_time:24328ms step_avg:96.92ms
+step:252/1695 train_time:24399ms step_avg:96.82ms
+step:253/1695 train_time:24498ms step_avg:96.83ms
+step:254/1695 train_time:24592ms step_avg:96.82ms
+step:255/1695 train_time:24685ms step_avg:96.81ms
+step:256/1695 train_time:24778ms step_avg:96.79ms
+step:257/1695 train_time:24870ms step_avg:96.77ms
+step:258/1695 train_time:24964ms step_avg:96.76ms
+step:259/1695 train_time:25057ms step_avg:96.74ms
+step:260/1695 train_time:25150ms step_avg:96.73ms
+step:261/1695 train_time:25243ms step_avg:96.72ms
+step:262/1695 train_time:25338ms step_avg:96.71ms
+step:263/1695 train_time:25434ms step_avg:96.71ms
+step:264/1695 train_time:25529ms step_avg:96.70ms
+step:265/1695 train_time:25625ms step_avg:96.70ms
+step:266/1695 train_time:25719ms step_avg:96.69ms
+step:267/1695 train_time:25812ms step_avg:96.68ms
+step:268/1695 train_time:25905ms step_avg:96.66ms
+step:269/1695 train_time:25998ms step_avg:96.65ms
+step:270/1695 train_time:26091ms step_avg:96.63ms
+step:271/1695 train_time:26185ms step_avg:96.62ms
+step:272/1695 train_time:26279ms step_avg:96.61ms
+step:273/1695 train_time:26373ms step_avg:96.60ms
+step:274/1695 train_time:26468ms step_avg:96.60ms
+step:275/1695 train_time:26563ms step_avg:96.59ms
+step:276/1695 train_time:26658ms step_avg:96.59ms
+step:277/1695 train_time:26751ms step_avg:96.57ms
+step:278/1695 train_time:26845ms step_avg:96.56ms
+step:279/1695 train_time:26939ms step_avg:96.56ms
+step:280/1695 train_time:27032ms step_avg:96.54ms
+step:281/1695 train_time:27125ms step_avg:96.53ms
+step:282/1695 train_time:27219ms step_avg:96.52ms
+step:283/1695 train_time:27313ms step_avg:96.51ms
+step:284/1695 train_time:27407ms step_avg:96.50ms
+step:285/1695 train_time:27502ms step_avg:96.50ms
+step:286/1695 train_time:27597ms step_avg:96.49ms
+step:287/1695 train_time:27691ms step_avg:96.49ms
+step:288/1695 train_time:27785ms step_avg:96.48ms
+step:289/1695 train_time:27879ms step_avg:96.47ms
+step:290/1695 train_time:27973ms step_avg:96.46ms
+step:291/1695 train_time:28067ms step_avg:96.45ms
+step:292/1695 train_time:28160ms step_avg:96.44ms
+step:293/1695 train_time:28254ms step_avg:96.43ms
+step:294/1695 train_time:28347ms step_avg:96.42ms
+step:295/1695 train_time:28442ms step_avg:96.41ms
+step:296/1695 train_time:28537ms step_avg:96.41ms
+step:297/1695 train_time:28632ms step_avg:96.40ms
+step:298/1695 train_time:28726ms step_avg:96.40ms
+step:299/1695 train_time:28821ms step_avg:96.39ms
+step:300/1695 train_time:28915ms step_avg:96.38ms
+step:301/1695 train_time:29008ms step_avg:96.37ms
+step:302/1695 train_time:29101ms step_avg:96.36ms
+step:303/1695 train_time:29195ms step_avg:96.35ms
+step:304/1695 train_time:29289ms step_avg:96.34ms
+step:305/1695 train_time:29383ms step_avg:96.34ms
+step:306/1695 train_time:29478ms step_avg:96.33ms
+step:307/1695 train_time:29572ms step_avg:96.32ms
+step:308/1695 train_time:29667ms step_avg:96.32ms
+step:309/1695 train_time:29761ms step_avg:96.31ms
+step:310/1695 train_time:29855ms step_avg:96.31ms
+step:311/1695 train_time:29949ms step_avg:96.30ms
+step:312/1695 train_time:30044ms step_avg:96.29ms
+step:313/1695 train_time:30138ms step_avg:96.29ms
+step:314/1695 train_time:30231ms step_avg:96.28ms
+step:315/1695 train_time:30325ms step_avg:96.27ms
+step:316/1695 train_time:30419ms step_avg:96.26ms
+step:317/1695 train_time:30513ms step_avg:96.26ms
+step:318/1695 train_time:30607ms step_avg:96.25ms
+step:319/1695 train_time:30702ms step_avg:96.24ms
+step:320/1695 train_time:30796ms step_avg:96.24ms
+step:321/1695 train_time:30890ms step_avg:96.23ms
+step:322/1695 train_time:30984ms step_avg:96.22ms
+step:323/1695 train_time:31078ms step_avg:96.22ms
+step:324/1695 train_time:31171ms step_avg:96.21ms
+step:325/1695 train_time:31265ms step_avg:96.20ms
+step:326/1695 train_time:31359ms step_avg:96.19ms
+step:327/1695 train_time:31452ms step_avg:96.18ms
+step:328/1695 train_time:31545ms step_avg:96.17ms
+step:329/1695 train_time:31640ms step_avg:96.17ms
+step:330/1695 train_time:31735ms step_avg:96.17ms
+step:331/1695 train_time:31829ms step_avg:96.16ms
+step:332/1695 train_time:31923ms step_avg:96.15ms
+step:333/1695 train_time:32016ms step_avg:96.14ms
+step:334/1695 train_time:32109ms step_avg:96.14ms
+step:335/1695 train_time:32203ms step_avg:96.13ms
+step:336/1695 train_time:32298ms step_avg:96.12ms
+step:337/1695 train_time:32392ms step_avg:96.12ms
+step:338/1695 train_time:32486ms step_avg:96.11ms
+step:339/1695 train_time:32580ms step_avg:96.11ms
+step:340/1695 train_time:32674ms step_avg:96.10ms
+step:341/1695 train_time:32767ms step_avg:96.09ms
+step:342/1695 train_time:32862ms step_avg:96.09ms
+step:343/1695 train_time:32957ms step_avg:96.08ms
+step:344/1695 train_time:33050ms step_avg:96.08ms
+step:345/1695 train_time:33394ms step_avg:96.79ms
+step:346/1695 train_time:33465ms step_avg:96.72ms
+step:347/1695 train_time:33557ms step_avg:96.71ms
+step:348/1695 train_time:33650ms step_avg:96.70ms
+step:349/1695 train_time:33744ms step_avg:96.69ms
+step:350/1695 train_time:33837ms step_avg:96.68ms
+step:351/1695 train_time:33930ms step_avg:96.67ms
+step:352/1695 train_time:34023ms step_avg:96.66ms
+step:353/1695 train_time:34116ms step_avg:96.65ms
+step:354/1695 train_time:34209ms step_avg:96.64ms
+step:355/1695 train_time:34305ms step_avg:96.63ms
+step:356/1695 train_time:34401ms step_avg:96.63ms
+step:357/1695 train_time:34497ms step_avg:96.63ms
+step:358/1695 train_time:34591ms step_avg:96.62ms
+step:359/1695 train_time:34685ms step_avg:96.61ms
+step:360/1695 train_time:34779ms step_avg:96.61ms
+step:361/1695 train_time:34871ms step_avg:96.60ms
+step:362/1695 train_time:34965ms step_avg:96.59ms
+step:363/1695 train_time:35060ms step_avg:96.58ms
+step:364/1695 train_time:35153ms step_avg:96.57ms
+step:365/1695 train_time:35246ms step_avg:96.56ms
+step:366/1695 train_time:35341ms step_avg:96.56ms
+step:367/1695 train_time:35436ms step_avg:96.56ms
+step:368/1695 train_time:35530ms step_avg:96.55ms
+step:369/1695 train_time:35625ms step_avg:96.54ms
+step:370/1695 train_time:35719ms step_avg:96.54ms
+step:371/1695 train_time:35812ms step_avg:96.53ms
+step:372/1695 train_time:35905ms step_avg:96.52ms
+step:373/1695 train_time:35999ms step_avg:96.51ms
+step:374/1695 train_time:36092ms step_avg:96.50ms
+step:375/1695 train_time:36185ms step_avg:96.49ms
+step:375/1695 val_loss:3.8206 train_time:36277ms step_avg:96.74ms
+step:376/1695 train_time:36303ms step_avg:96.55ms
+step:377/1695 train_time:36376ms step_avg:96.49ms
+step:378/1695 train_time:36477ms step_avg:96.50ms
+step:379/1695 train_time:36572ms step_avg:96.50ms
+step:380/1695 train_time:36667ms step_avg:96.49ms
+step:381/1695 train_time:36760ms step_avg:96.48ms
+step:382/1695 train_time:36852ms step_avg:96.47ms
+step:383/1695 train_time:36946ms step_avg:96.46ms
+step:384/1695 train_time:37039ms step_avg:96.45ms
+step:385/1695 train_time:37131ms step_avg:96.44ms
+step:386/1695 train_time:37224ms step_avg:96.44ms
+step:387/1695 train_time:37319ms step_avg:96.43ms
+step:388/1695 train_time:37415ms step_avg:96.43ms
+step:389/1695 train_time:37511ms step_avg:96.43ms
+step:390/1695 train_time:37607ms step_avg:96.43ms
+step:391/1695 train_time:37701ms step_avg:96.42ms
+step:392/1695 train_time:37794ms step_avg:96.41ms
+step:393/1695 train_time:37888ms step_avg:96.41ms
+step:394/1695 train_time:37982ms step_avg:96.40ms
+step:395/1695 train_time:38075ms step_avg:96.39ms
+step:396/1695 train_time:38168ms step_avg:96.39ms
+step:397/1695 train_time:38263ms step_avg:96.38ms
+step:398/1695 train_time:38358ms step_avg:96.38ms
+step:399/1695 train_time:38452ms step_avg:96.37ms
+step:400/1695 train_time:38548ms step_avg:96.37ms
+step:401/1695 train_time:38643ms step_avg:96.37ms
+step:402/1695 train_time:38737ms step_avg:96.36ms
+step:403/1695 train_time:38830ms step_avg:96.35ms
+step:404/1695 train_time:38924ms step_avg:96.35ms
+step:405/1695 train_time:39017ms step_avg:96.34ms
+step:406/1695 train_time:39110ms step_avg:96.33ms
+step:407/1695 train_time:39203ms step_avg:96.32ms
+step:408/1695 train_time:39296ms step_avg:96.31ms
+step:409/1695 train_time:39391ms step_avg:96.31ms
+step:410/1695 train_time:39486ms step_avg:96.31ms
+step:411/1695 train_time:39580ms step_avg:96.30ms
+step:412/1695 train_time:39673ms step_avg:96.29ms
+step:413/1695 train_time:39768ms step_avg:96.29ms
+step:414/1695 train_time:39863ms step_avg:96.29ms
+step:415/1695 train_time:39957ms step_avg:96.28ms
+step:416/1695 train_time:40050ms step_avg:96.27ms
+step:417/1695 train_time:40144ms step_avg:96.27ms
+step:418/1695 train_time:40237ms step_avg:96.26ms
+step:419/1695 train_time:40331ms step_avg:96.25ms
+step:420/1695 train_time:40425ms step_avg:96.25ms
+step:421/1695 train_time:40519ms step_avg:96.25ms
+step:422/1695 train_time:40613ms step_avg:96.24ms
+step:423/1695 train_time:40707ms step_avg:96.23ms
+step:424/1695 train_time:40802ms step_avg:96.23ms
+step:425/1695 train_time:40895ms step_avg:96.22ms
+step:426/1695 train_time:40990ms step_avg:96.22ms
+step:427/1695 train_time:41085ms step_avg:96.22ms
+step:428/1695 train_time:41178ms step_avg:96.21ms
+step:429/1695 train_time:41272ms step_avg:96.20ms
+step:430/1695 train_time:41365ms step_avg:96.20ms
+step:431/1695 train_time:41459ms step_avg:96.19ms
+step:432/1695 train_time:41553ms step_avg:96.19ms
+step:433/1695 train_time:41648ms step_avg:96.18ms
+step:434/1695 train_time:41743ms step_avg:96.18ms
+step:435/1695 train_time:41837ms step_avg:96.18ms
+step:436/1695 train_time:41930ms step_avg:96.17ms
+step:437/1695 train_time:42025ms step_avg:96.17ms
+step:438/1695 train_time:42119ms step_avg:96.16ms
+step:439/1695 train_time:42212ms step_avg:96.15ms
+step:440/1695 train_time:42305ms step_avg:96.15ms
+step:441/1695 train_time:42399ms step_avg:96.14ms
+step:442/1695 train_time:42492ms step_avg:96.14ms
+step:443/1695 train_time:42587ms step_avg:96.13ms
+step:444/1695 train_time:42680ms step_avg:96.13ms
+step:445/1695 train_time:42774ms step_avg:96.12ms
+step:446/1695 train_time:42868ms step_avg:96.12ms
+step:447/1695 train_time:42963ms step_avg:96.11ms
+step:448/1695 train_time:43057ms step_avg:96.11ms
+step:449/1695 train_time:43150ms step_avg:96.10ms
+step:450/1695 train_time:43245ms step_avg:96.10ms
+step:451/1695 train_time:43338ms step_avg:96.09ms
+step:452/1695 train_time:43432ms step_avg:96.09ms
+step:453/1695 train_time:43526ms step_avg:96.08ms
+step:454/1695 train_time:43621ms step_avg:96.08ms
+step:455/1695 train_time:43715ms step_avg:96.08ms
+step:456/1695 train_time:43809ms step_avg:96.07ms
+step:457/1695 train_time:43903ms step_avg:96.07ms
+step:458/1695 train_time:43997ms step_avg:96.06ms
+step:459/1695 train_time:44091ms step_avg:96.06ms
+step:460/1695 train_time:44185ms step_avg:96.05ms
+step:461/1695 train_time:44278ms step_avg:96.05ms
+step:462/1695 train_time:44372ms step_avg:96.04ms
+step:463/1695 train_time:44466ms step_avg:96.04ms
+step:464/1695 train_time:44560ms step_avg:96.03ms
+step:465/1695 train_time:44654ms step_avg:96.03ms
+step:466/1695 train_time:44748ms step_avg:96.03ms
+step:467/1695 train_time:44842ms step_avg:96.02ms
+step:468/1695 train_time:44935ms step_avg:96.02ms
+step:469/1695 train_time:45031ms step_avg:96.01ms
+step:470/1695 train_time:45125ms step_avg:96.01ms
+step:471/1695 train_time:45218ms step_avg:96.00ms
+step:472/1695 train_time:45312ms step_avg:96.00ms
+step:473/1695 train_time:45406ms step_avg:96.00ms
+step:474/1695 train_time:45499ms step_avg:95.99ms
+step:475/1695 train_time:45594ms step_avg:95.99ms
+step:476/1695 train_time:45689ms step_avg:95.99ms
+step:477/1695 train_time:45784ms step_avg:95.98ms
+step:478/1695 train_time:45877ms step_avg:95.98ms
+step:479/1695 train_time:45971ms step_avg:95.97ms
+step:480/1695 train_time:46065ms step_avg:95.97ms
+step:481/1695 train_time:46159ms step_avg:95.96ms
+step:482/1695 train_time:46252ms step_avg:95.96ms
+step:483/1695 train_time:46347ms step_avg:95.96ms
+step:484/1695 train_time:46442ms step_avg:95.95ms
+step:485/1695 train_time:46535ms step_avg:95.95ms
+step:486/1695 train_time:46629ms step_avg:95.95ms
+step:487/1695 train_time:46723ms step_avg:95.94ms
+step:488/1695 train_time:46818ms step_avg:95.94ms
+step:489/1695 train_time:46911ms step_avg:95.93ms
+step:490/1695 train_time:47004ms step_avg:95.93ms
+step:491/1695 train_time:47098ms step_avg:95.92ms
+step:492/1695 train_time:47191ms step_avg:95.92ms
+step:493/1695 train_time:47286ms step_avg:95.91ms
+step:494/1695 train_time:47381ms step_avg:95.91ms
+step:495/1695 train_time:47474ms step_avg:95.91ms
+step:496/1695 train_time:47568ms step_avg:95.90ms
+step:497/1695 train_time:47663ms step_avg:95.90ms
+step:498/1695 train_time:47756ms step_avg:95.90ms
+step:499/1695 train_time:47851ms step_avg:95.89ms
+step:500/1695 train_time:47944ms step_avg:95.89ms
+step:500/1695 val_loss:3.7169 train_time:48035ms step_avg:96.07ms
+step:501/1695 train_time:48064ms step_avg:95.94ms
+step:502/1695 train_time:48136ms step_avg:95.89ms
+step:503/1695 train_time:48238ms step_avg:95.90ms
+step:504/1695 train_time:48335ms step_avg:95.90ms
+step:505/1695 train_time:48428ms step_avg:95.90ms
+step:506/1695 train_time:48521ms step_avg:95.89ms
+step:507/1695 train_time:48614ms step_avg:95.89ms
+step:508/1695 train_time:48707ms step_avg:95.88ms
+step:509/1695 train_time:48800ms step_avg:95.87ms
+step:510/1695 train_time:48893ms step_avg:95.87ms
+step:511/1695 train_time:48986ms step_avg:95.86ms
+step:512/1695 train_time:49080ms step_avg:95.86ms
+step:513/1695 train_time:49177ms step_avg:95.86ms
+step:514/1695 train_time:49275ms step_avg:95.87ms
+step:515/1695 train_time:49371ms step_avg:95.87ms
+step:516/1695 train_time:49465ms step_avg:95.86ms
+step:517/1695 train_time:49558ms step_avg:95.86ms
+step:518/1695 train_time:49652ms step_avg:95.85ms
+step:519/1695 train_time:49984ms step_avg:96.31ms
+step:520/1695 train_time:50159ms step_avg:96.46ms
+step:521/1695 train_time:50250ms step_avg:96.45ms
+step:522/1695 train_time:50343ms step_avg:96.44ms
+step:523/1695 train_time:50436ms step_avg:96.44ms
+step:524/1695 train_time:50529ms step_avg:96.43ms
+step:525/1695 train_time:50621ms step_avg:96.42ms
+step:526/1695 train_time:50714ms step_avg:96.41ms
+step:527/1695 train_time:50806ms step_avg:96.41ms
+step:528/1695 train_time:50898ms step_avg:96.40ms
+step:529/1695 train_time:50992ms step_avg:96.39ms
+step:530/1695 train_time:51091ms step_avg:96.40ms
+step:531/1695 train_time:51188ms step_avg:96.40ms
+step:532/1695 train_time:51282ms step_avg:96.40ms
+step:533/1695 train_time:51376ms step_avg:96.39ms
+step:534/1695 train_time:51470ms step_avg:96.39ms
+step:535/1695 train_time:51564ms step_avg:96.38ms
+step:536/1695 train_time:51656ms step_avg:96.37ms
+step:537/1695 train_time:51750ms step_avg:96.37ms
+step:538/1695 train_time:51842ms step_avg:96.36ms
+step:539/1695 train_time:51935ms step_avg:96.35ms
+step:540/1695 train_time:52032ms step_avg:96.36ms
+step:541/1695 train_time:52126ms step_avg:96.35ms
+step:542/1695 train_time:52220ms step_avg:96.35ms
+step:543/1695 train_time:52316ms step_avg:96.35ms
+step:544/1695 train_time:52411ms step_avg:96.34ms
+step:545/1695 train_time:52505ms step_avg:96.34ms
+step:546/1695 train_time:52598ms step_avg:96.33ms
+step:547/1695 train_time:52691ms step_avg:96.33ms
+step:548/1695 train_time:52784ms step_avg:96.32ms
+step:549/1695 train_time:52877ms step_avg:96.32ms
+step:550/1695 train_time:52971ms step_avg:96.31ms
+step:551/1695 train_time:53066ms step_avg:96.31ms
+step:552/1695 train_time:53161ms step_avg:96.31ms
+step:553/1695 train_time:53255ms step_avg:96.30ms
+step:554/1695 train_time:53352ms step_avg:96.30ms
+step:555/1695 train_time:53446ms step_avg:96.30ms
+step:556/1695 train_time:53539ms step_avg:96.29ms
+step:557/1695 train_time:53633ms step_avg:96.29ms
+step:558/1695 train_time:53727ms step_avg:96.29ms
+step:559/1695 train_time:53820ms step_avg:96.28ms
+step:560/1695 train_time:53914ms step_avg:96.27ms
+step:561/1695 train_time:54009ms step_avg:96.27ms
+step:562/1695 train_time:54102ms step_avg:96.27ms
+step:563/1695 train_time:54196ms step_avg:96.26ms
+step:564/1695 train_time:54291ms step_avg:96.26ms
+step:565/1695 train_time:54385ms step_avg:96.26ms
+step:566/1695 train_time:54479ms step_avg:96.25ms
+step:567/1695 train_time:54573ms step_avg:96.25ms
+step:568/1695 train_time:54670ms step_avg:96.25ms
+step:569/1695 train_time:54765ms step_avg:96.25ms
+step:570/1695 train_time:54859ms step_avg:96.24ms
+step:571/1695 train_time:54955ms step_avg:96.24ms
+step:572/1695 train_time:55052ms step_avg:96.24ms
+step:573/1695 train_time:55149ms step_avg:96.25ms
+step:574/1695 train_time:55245ms step_avg:96.25ms
+step:575/1695 train_time:55344ms step_avg:96.25ms
+step:576/1695 train_time:55437ms step_avg:96.25ms
+step:577/1695 train_time:55534ms step_avg:96.25ms
+step:578/1695 train_time:55631ms step_avg:96.25ms
+step:579/1695 train_time:55727ms step_avg:96.25ms
+step:580/1695 train_time:55823ms step_avg:96.25ms
+step:581/1695 train_time:55918ms step_avg:96.24ms
+step:582/1695 train_time:56015ms step_avg:96.25ms
+step:583/1695 train_time:56111ms step_avg:96.25ms
+step:584/1695 train_time:56209ms step_avg:96.25ms
+step:585/1695 train_time:56306ms step_avg:96.25ms
+step:586/1695 train_time:56402ms step_avg:96.25ms
+step:587/1695 train_time:56497ms step_avg:96.25ms
+step:588/1695 train_time:56593ms step_avg:96.25ms
+step:589/1695 train_time:56689ms step_avg:96.25ms
+step:590/1695 train_time:56786ms step_avg:96.25ms
+step:591/1695 train_time:56882ms step_avg:96.25ms
+step:592/1695 train_time:56978ms step_avg:96.25ms
+step:593/1695 train_time:57074ms step_avg:96.25ms
+step:594/1695 train_time:57172ms step_avg:96.25ms
+step:595/1695 train_time:57269ms step_avg:96.25ms
+step:596/1695 train_time:57365ms step_avg:96.25ms
+step:597/1695 train_time:57461ms step_avg:96.25ms
+step:598/1695 train_time:57555ms step_avg:96.25ms
+step:599/1695 train_time:57652ms step_avg:96.25ms
+step:600/1695 train_time:57749ms step_avg:96.25ms
+step:601/1695 train_time:57846ms step_avg:96.25ms
+step:602/1695 train_time:57942ms step_avg:96.25ms
+step:603/1695 train_time:58037ms step_avg:96.25ms
+step:604/1695 train_time:58133ms step_avg:96.25ms
+step:605/1695 train_time:58229ms step_avg:96.25ms
+step:606/1695 train_time:58326ms step_avg:96.25ms
+step:607/1695 train_time:58421ms step_avg:96.25ms
+step:608/1695 train_time:58516ms step_avg:96.24ms
+step:609/1695 train_time:58613ms step_avg:96.25ms
+step:610/1695 train_time:58710ms step_avg:96.25ms
+step:611/1695 train_time:58807ms step_avg:96.25ms
+step:612/1695 train_time:58903ms step_avg:96.25ms
+step:613/1695 train_time:58998ms step_avg:96.24ms
+step:614/1695 train_time:59095ms step_avg:96.25ms
+step:615/1695 train_time:59191ms step_avg:96.25ms
+step:616/1695 train_time:59288ms step_avg:96.25ms
+step:617/1695 train_time:59384ms step_avg:96.25ms
+step:618/1695 train_time:59479ms step_avg:96.24ms
+step:619/1695 train_time:59577ms step_avg:96.25ms
+step:620/1695 train_time:59674ms step_avg:96.25ms
+step:621/1695 train_time:59772ms step_avg:96.25ms
+step:622/1695 train_time:59868ms step_avg:96.25ms
+step:623/1695 train_time:59964ms step_avg:96.25ms
+step:624/1695 train_time:60059ms step_avg:96.25ms
+step:625/1695 train_time:60155ms step_avg:96.25ms
+step:625/1695 val_loss:3.6208 train_time:60249ms step_avg:96.40ms
+step:626/1695 train_time:60275ms step_avg:96.29ms
+step:627/1695 train_time:60358ms step_avg:96.26ms
+step:628/1695 train_time:60456ms step_avg:96.27ms
+step:629/1695 train_time:60551ms step_avg:96.27ms
+step:630/1695 train_time:60646ms step_avg:96.26ms
+step:631/1695 train_time:60741ms step_avg:96.26ms
+step:632/1695 train_time:60836ms step_avg:96.26ms
+step:633/1695 train_time:60930ms step_avg:96.26ms
+step:634/1695 train_time:61025ms step_avg:96.25ms
+step:635/1695 train_time:61122ms step_avg:96.25ms
+step:636/1695 train_time:61217ms step_avg:96.25ms
+step:637/1695 train_time:61315ms step_avg:96.26ms
+step:638/1695 train_time:61413ms step_avg:96.26ms
+step:639/1695 train_time:61509ms step_avg:96.26ms
+step:640/1695 train_time:61606ms step_avg:96.26ms
+step:641/1695 train_time:61702ms step_avg:96.26ms
+step:642/1695 train_time:61797ms step_avg:96.26ms
+step:643/1695 train_time:61892ms step_avg:96.26ms
+step:644/1695 train_time:61988ms step_avg:96.25ms
+step:645/1695 train_time:62084ms step_avg:96.25ms
+step:646/1695 train_time:62181ms step_avg:96.26ms
+step:647/1695 train_time:62278ms step_avg:96.26ms
+step:648/1695 train_time:62376ms step_avg:96.26ms
+step:649/1695 train_time:62472ms step_avg:96.26ms
+step:650/1695 train_time:62568ms step_avg:96.26ms
+step:651/1695 train_time:62664ms step_avg:96.26ms
+step:652/1695 train_time:62760ms step_avg:96.26ms
+step:653/1695 train_time:62855ms step_avg:96.26ms
+step:654/1695 train_time:62950ms step_avg:96.25ms
+step:655/1695 train_time:63045ms step_avg:96.25ms
+step:656/1695 train_time:63141ms step_avg:96.25ms
+step:657/1695 train_time:63238ms step_avg:96.25ms
+step:658/1695 train_time:63335ms step_avg:96.25ms
+step:659/1695 train_time:63431ms step_avg:96.25ms
+step:660/1695 train_time:63529ms step_avg:96.26ms
+step:661/1695 train_time:63626ms step_avg:96.26ms
+step:662/1695 train_time:63721ms step_avg:96.26ms
+step:663/1695 train_time:63816ms step_avg:96.25ms
+step:664/1695 train_time:63911ms step_avg:96.25ms
+step:665/1695 train_time:64007ms step_avg:96.25ms
+step:666/1695 train_time:64103ms step_avg:96.25ms
+step:667/1695 train_time:64199ms step_avg:96.25ms
+step:668/1695 train_time:64295ms step_avg:96.25ms
+step:669/1695 train_time:64390ms step_avg:96.25ms
+step:670/1695 train_time:64487ms step_avg:96.25ms
+step:671/1695 train_time:64585ms step_avg:96.25ms
+step:672/1695 train_time:64682ms step_avg:96.25ms
+step:673/1695 train_time:64778ms step_avg:96.25ms
+step:674/1695 train_time:64873ms step_avg:96.25ms
+step:675/1695 train_time:64969ms step_avg:96.25ms
+step:676/1695 train_time:65065ms step_avg:96.25ms
+step:677/1695 train_time:65161ms step_avg:96.25ms
+step:678/1695 train_time:65257ms step_avg:96.25ms
+step:679/1695 train_time:65353ms step_avg:96.25ms
+step:680/1695 train_time:65449ms step_avg:96.25ms
+step:681/1695 train_time:65545ms step_avg:96.25ms
+step:682/1695 train_time:65642ms step_avg:96.25ms
+step:683/1695 train_time:65738ms step_avg:96.25ms
+step:684/1695 train_time:65833ms step_avg:96.25ms
+step:685/1695 train_time:65929ms step_avg:96.25ms
+step:686/1695 train_time:66024ms step_avg:96.25ms
+step:687/1695 train_time:66120ms step_avg:96.25ms
+step:688/1695 train_time:66216ms step_avg:96.24ms
+step:689/1695 train_time:66311ms step_avg:96.24ms
+step:690/1695 train_time:66408ms step_avg:96.24ms
+step:691/1695 train_time:66766ms step_avg:96.62ms
+step:692/1695 train_time:66929ms step_avg:96.72ms
+step:693/1695 train_time:67024ms step_avg:96.72ms
+step:694/1695 train_time:67118ms step_avg:96.71ms
+step:695/1695 train_time:67212ms step_avg:96.71ms
+step:696/1695 train_time:67307ms step_avg:96.71ms
+step:697/1695 train_time:67402ms step_avg:96.70ms
+step:698/1695 train_time:67496ms step_avg:96.70ms
+step:699/1695 train_time:67590ms step_avg:96.70ms
+step:700/1695 train_time:67686ms step_avg:96.69ms
+step:701/1695 train_time:67789ms step_avg:96.70ms
+step:702/1695 train_time:67889ms step_avg:96.71ms
+step:703/1695 train_time:67986ms step_avg:96.71ms
+step:704/1695 train_time:68082ms step_avg:96.71ms
+step:705/1695 train_time:68176ms step_avg:96.70ms
+step:706/1695 train_time:68272ms step_avg:96.70ms
+step:707/1695 train_time:68368ms step_avg:96.70ms
+step:708/1695 train_time:68463ms step_avg:96.70ms
+step:709/1695 train_time:68558ms step_avg:96.70ms
+step:710/1695 train_time:68653ms step_avg:96.69ms
+step:711/1695 train_time:68750ms step_avg:96.70ms
+step:712/1695 train_time:68848ms step_avg:96.70ms
+step:713/1695 train_time:68946ms step_avg:96.70ms
+step:714/1695 train_time:69043ms step_avg:96.70ms
+step:715/1695 train_time:69140ms step_avg:96.70ms
+step:716/1695 train_time:69235ms step_avg:96.70ms
+step:717/1695 train_time:69330ms step_avg:96.69ms
+step:718/1695 train_time:69426ms step_avg:96.69ms
+step:719/1695 train_time:69522ms step_avg:96.69ms
+step:720/1695 train_time:69617ms step_avg:96.69ms
+step:721/1695 train_time:69713ms step_avg:96.69ms
+step:722/1695 train_time:69810ms step_avg:96.69ms
+step:723/1695 train_time:69907ms step_avg:96.69ms
+step:724/1695 train_time:70003ms step_avg:96.69ms
+step:725/1695 train_time:70100ms step_avg:96.69ms
+step:726/1695 train_time:70195ms step_avg:96.69ms
+step:727/1695 train_time:70291ms step_avg:96.69ms
+step:728/1695 train_time:70387ms step_avg:96.69ms
+step:729/1695 train_time:70482ms step_avg:96.68ms
+step:730/1695 train_time:70578ms step_avg:96.68ms
+step:731/1695 train_time:70673ms step_avg:96.68ms
+step:732/1695 train_time:70769ms step_avg:96.68ms
+step:733/1695 train_time:70865ms step_avg:96.68ms
+step:734/1695 train_time:70963ms step_avg:96.68ms
+step:735/1695 train_time:71059ms step_avg:96.68ms
+step:736/1695 train_time:71155ms step_avg:96.68ms
+step:737/1695 train_time:71250ms step_avg:96.68ms
+step:738/1695 train_time:71345ms step_avg:96.67ms
+step:739/1695 train_time:71441ms step_avg:96.67ms
+step:740/1695 train_time:71537ms step_avg:96.67ms
+step:741/1695 train_time:71632ms step_avg:96.67ms
+step:742/1695 train_time:71728ms step_avg:96.67ms
+step:743/1695 train_time:71824ms step_avg:96.67ms
+step:744/1695 train_time:71920ms step_avg:96.67ms
+step:745/1695 train_time:72016ms step_avg:96.67ms
+step:746/1695 train_time:72111ms step_avg:96.66ms
+step:747/1695 train_time:72208ms step_avg:96.66ms
+step:748/1695 train_time:72303ms step_avg:96.66ms
+step:749/1695 train_time:72399ms step_avg:96.66ms
+step:750/1695 train_time:72495ms step_avg:96.66ms
+step:750/1695 val_loss:3.5658 train_time:72587ms step_avg:96.78ms
+step:751/1695 train_time:72614ms step_avg:96.69ms
+step:752/1695 train_time:72692ms step_avg:96.67ms
+step:753/1695 train_time:72794ms step_avg:96.67ms
+step:754/1695 train_time:72890ms step_avg:96.67ms
+step:755/1695 train_time:72986ms step_avg:96.67ms
+step:756/1695 train_time:73081ms step_avg:96.67ms
+step:757/1695 train_time:73176ms step_avg:96.67ms
+step:758/1695 train_time:73271ms step_avg:96.66ms
+step:759/1695 train_time:73365ms step_avg:96.66ms
+step:760/1695 train_time:73460ms step_avg:96.66ms
+step:761/1695 train_time:73556ms step_avg:96.66ms
+step:762/1695 train_time:73653ms step_avg:96.66ms
+step:763/1695 train_time:73751ms step_avg:96.66ms
+step:764/1695 train_time:73849ms step_avg:96.66ms
+step:765/1695 train_time:73945ms step_avg:96.66ms
+step:766/1695 train_time:74040ms step_avg:96.66ms
+step:767/1695 train_time:74135ms step_avg:96.66ms
+step:768/1695 train_time:74231ms step_avg:96.66ms
+step:769/1695 train_time:74327ms step_avg:96.65ms
+step:770/1695 train_time:74421ms step_avg:96.65ms
+step:771/1695 train_time:74516ms step_avg:96.65ms
+step:772/1695 train_time:74614ms step_avg:96.65ms
+step:773/1695 train_time:74712ms step_avg:96.65ms
+step:774/1695 train_time:74809ms step_avg:96.65ms
+step:775/1695 train_time:74908ms step_avg:96.65ms
+step:776/1695 train_time:75004ms step_avg:96.65ms
+step:777/1695 train_time:75099ms step_avg:96.65ms
+step:778/1695 train_time:75194ms step_avg:96.65ms
+step:779/1695 train_time:75290ms step_avg:96.65ms
+step:780/1695 train_time:75386ms step_avg:96.65ms
+step:781/1695 train_time:75482ms step_avg:96.65ms
+step:782/1695 train_time:75578ms step_avg:96.65ms
+step:783/1695 train_time:75675ms step_avg:96.65ms
+step:784/1695 train_time:75772ms step_avg:96.65ms
+step:785/1695 train_time:75870ms step_avg:96.65ms
+step:786/1695 train_time:75967ms step_avg:96.65ms
+step:787/1695 train_time:76062ms step_avg:96.65ms
+step:788/1695 train_time:76158ms step_avg:96.65ms
+step:789/1695 train_time:76253ms step_avg:96.65ms
+step:790/1695 train_time:76350ms step_avg:96.65ms
+step:791/1695 train_time:76445ms step_avg:96.64ms
+step:792/1695 train_time:76541ms step_avg:96.64ms
+step:793/1695 train_time:76638ms step_avg:96.64ms
+step:794/1695 train_time:76735ms step_avg:96.64ms
+step:795/1695 train_time:76832ms step_avg:96.64ms
+step:796/1695 train_time:76930ms step_avg:96.65ms
+step:797/1695 train_time:77026ms step_avg:96.65ms
+step:798/1695 train_time:77121ms step_avg:96.64ms
+step:799/1695 train_time:77216ms step_avg:96.64ms
+step:800/1695 train_time:77312ms step_avg:96.64ms
+step:801/1695 train_time:77411ms step_avg:96.64ms
+step:802/1695 train_time:77506ms step_avg:96.64ms
+step:803/1695 train_time:77602ms step_avg:96.64ms
+step:804/1695 train_time:77697ms step_avg:96.64ms
+step:805/1695 train_time:77795ms step_avg:96.64ms
+step:806/1695 train_time:77893ms step_avg:96.64ms
+step:807/1695 train_time:77990ms step_avg:96.64ms
+step:808/1695 train_time:78087ms step_avg:96.64ms
+step:809/1695 train_time:78183ms step_avg:96.64ms
+step:810/1695 train_time:78281ms step_avg:96.64ms
+step:811/1695 train_time:78374ms step_avg:96.64ms
+step:812/1695 train_time:78471ms step_avg:96.64ms
+step:813/1695 train_time:78568ms step_avg:96.64ms
+step:814/1695 train_time:78663ms step_avg:96.64ms
+step:815/1695 train_time:78759ms step_avg:96.64ms
+step:816/1695 train_time:78855ms step_avg:96.64ms
+step:817/1695 train_time:78952ms step_avg:96.64ms
+step:818/1695 train_time:79048ms step_avg:96.64ms
+step:819/1695 train_time:79144ms step_avg:96.64ms
+step:820/1695 train_time:79239ms step_avg:96.63ms
+step:821/1695 train_time:79336ms step_avg:96.63ms
+step:822/1695 train_time:79431ms step_avg:96.63ms
+step:823/1695 train_time:79527ms step_avg:96.63ms
+step:824/1695 train_time:79623ms step_avg:96.63ms
+step:825/1695 train_time:79719ms step_avg:96.63ms
+step:826/1695 train_time:79816ms step_avg:96.63ms
+step:827/1695 train_time:79912ms step_avg:96.63ms
+step:828/1695 train_time:80008ms step_avg:96.63ms
+step:829/1695 train_time:80103ms step_avg:96.63ms
+step:830/1695 train_time:80199ms step_avg:96.62ms
+step:831/1695 train_time:80295ms step_avg:96.62ms
+step:832/1695 train_time:80391ms step_avg:96.62ms
+step:833/1695 train_time:80487ms step_avg:96.62ms
+step:834/1695 train_time:80584ms step_avg:96.62ms
+step:835/1695 train_time:80680ms step_avg:96.62ms
+step:836/1695 train_time:80775ms step_avg:96.62ms
+step:837/1695 train_time:80872ms step_avg:96.62ms
+step:838/1695 train_time:80968ms step_avg:96.62ms
+step:839/1695 train_time:81064ms step_avg:96.62ms
+step:840/1695 train_time:81160ms step_avg:96.62ms
+step:841/1695 train_time:81257ms step_avg:96.62ms
+step:842/1695 train_time:81352ms step_avg:96.62ms
+step:843/1695 train_time:81450ms step_avg:96.62ms
+step:844/1695 train_time:81546ms step_avg:96.62ms
+step:845/1695 train_time:81643ms step_avg:96.62ms
+step:846/1695 train_time:81738ms step_avg:96.62ms
+step:847/1695 train_time:81834ms step_avg:96.62ms
+step:848/1695 train_time:81931ms step_avg:96.62ms
+step:849/1695 train_time:82026ms step_avg:96.62ms
+step:850/1695 train_time:82122ms step_avg:96.61ms
+step:851/1695 train_time:82218ms step_avg:96.61ms
+step:852/1695 train_time:82313ms step_avg:96.61ms
+step:853/1695 train_time:82409ms step_avg:96.61ms
+step:854/1695 train_time:82506ms step_avg:96.61ms
+step:855/1695 train_time:82602ms step_avg:96.61ms
+step:856/1695 train_time:82698ms step_avg:96.61ms
+step:857/1695 train_time:82795ms step_avg:96.61ms
+step:858/1695 train_time:82891ms step_avg:96.61ms
+step:859/1695 train_time:82987ms step_avg:96.61ms
+step:860/1695 train_time:83083ms step_avg:96.61ms
+step:861/1695 train_time:83180ms step_avg:96.61ms
+step:862/1695 train_time:83275ms step_avg:96.61ms
+step:863/1695 train_time:83625ms step_avg:96.90ms
+step:864/1695 train_time:83802ms step_avg:96.99ms
+step:865/1695 train_time:83896ms step_avg:96.99ms
+step:866/1695 train_time:83991ms step_avg:96.99ms
+step:867/1695 train_time:84086ms step_avg:96.99ms
+step:868/1695 train_time:84180ms step_avg:96.98ms
+step:869/1695 train_time:84275ms step_avg:96.98ms
+step:870/1695 train_time:84370ms step_avg:96.98ms
+step:871/1695 train_time:84465ms step_avg:96.97ms
+step:872/1695 train_time:84559ms step_avg:96.97ms
+step:873/1695 train_time:84660ms step_avg:96.98ms
+step:874/1695 train_time:84758ms step_avg:96.98ms
+step:875/1695 train_time:84857ms step_avg:96.98ms
+step:875/1695 val_loss:3.5251 train_time:84952ms step_avg:97.09ms
+step:876/1695 train_time:84978ms step_avg:97.01ms
+step:877/1695 train_time:85058ms step_avg:96.99ms
+step:878/1695 train_time:85158ms step_avg:96.99ms
+step:879/1695 train_time:85256ms step_avg:96.99ms
+step:880/1695 train_time:85352ms step_avg:96.99ms
+step:881/1695 train_time:85447ms step_avg:96.99ms
+step:882/1695 train_time:85541ms step_avg:96.99ms
+step:883/1695 train_time:85636ms step_avg:96.98ms
+step:884/1695 train_time:85731ms step_avg:96.98ms
+step:885/1695 train_time:85825ms step_avg:96.98ms
+step:886/1695 train_time:85920ms step_avg:96.98ms
+step:887/1695 train_time:86020ms step_avg:96.98ms
+step:888/1695 train_time:86118ms step_avg:96.98ms
+step:889/1695 train_time:86217ms step_avg:96.98ms
+step:890/1695 train_time:86314ms step_avg:96.98ms
+step:891/1695 train_time:86411ms step_avg:96.98ms
+step:892/1695 train_time:86505ms step_avg:96.98ms
+step:893/1695 train_time:86600ms step_avg:96.98ms
+step:894/1695 train_time:86695ms step_avg:96.97ms
+step:895/1695 train_time:86790ms step_avg:96.97ms
+step:896/1695 train_time:86886ms step_avg:96.97ms
+step:897/1695 train_time:86982ms step_avg:96.97ms
+step:898/1695 train_time:87080ms step_avg:96.97ms
+step:899/1695 train_time:87177ms step_avg:96.97ms
+step:900/1695 train_time:87276ms step_avg:96.97ms
+step:901/1695 train_time:87374ms step_avg:96.97ms
+step:902/1695 train_time:87471ms step_avg:96.97ms
+step:903/1695 train_time:87567ms step_avg:96.97ms
+step:904/1695 train_time:87661ms step_avg:96.97ms
+step:905/1695 train_time:87756ms step_avg:96.97ms
+step:906/1695 train_time:87852ms step_avg:96.97ms
+step:907/1695 train_time:87948ms step_avg:96.97ms
+step:908/1695 train_time:88044ms step_avg:96.96ms
+step:909/1695 train_time:88140ms step_avg:96.96ms
+step:910/1695 train_time:88237ms step_avg:96.96ms
+step:911/1695 train_time:88336ms step_avg:96.97ms
+step:912/1695 train_time:88433ms step_avg:96.97ms
+step:913/1695 train_time:88528ms step_avg:96.96ms
+step:914/1695 train_time:88624ms step_avg:96.96ms
+step:915/1695 train_time:88718ms step_avg:96.96ms
+step:916/1695 train_time:88816ms step_avg:96.96ms
+step:917/1695 train_time:88914ms step_avg:96.96ms
+step:918/1695 train_time:89010ms step_avg:96.96ms
+step:919/1695 train_time:89107ms step_avg:96.96ms
+step:920/1695 train_time:89203ms step_avg:96.96ms
+step:921/1695 train_time:89299ms step_avg:96.96ms
+step:922/1695 train_time:89396ms step_avg:96.96ms
+step:923/1695 train_time:89493ms step_avg:96.96ms
+step:924/1695 train_time:89590ms step_avg:96.96ms
+step:925/1695 train_time:89686ms step_avg:96.96ms
+step:926/1695 train_time:89782ms step_avg:96.96ms
+step:927/1695 train_time:89878ms step_avg:96.96ms
+step:928/1695 train_time:89974ms step_avg:96.95ms
+step:929/1695 train_time:90071ms step_avg:96.95ms
+step:930/1695 train_time:90167ms step_avg:96.95ms
+step:931/1695 train_time:90262ms step_avg:96.95ms
+step:932/1695 train_time:90358ms step_avg:96.95ms
+step:933/1695 train_time:90455ms step_avg:96.95ms
+step:934/1695 train_time:90551ms step_avg:96.95ms
+step:935/1695 train_time:90648ms step_avg:96.95ms
+step:936/1695 train_time:90743ms step_avg:96.95ms
+step:937/1695 train_time:90839ms step_avg:96.95ms
+step:938/1695 train_time:90935ms step_avg:96.95ms
+step:939/1695 train_time:91030ms step_avg:96.94ms
+step:940/1695 train_time:91126ms step_avg:96.94ms
+step:941/1695 train_time:91221ms step_avg:96.94ms
+step:942/1695 train_time:91317ms step_avg:96.94ms
+step:943/1695 train_time:91413ms step_avg:96.94ms
+step:944/1695 train_time:91510ms step_avg:96.94ms
+step:945/1695 train_time:91607ms step_avg:96.94ms
+step:946/1695 train_time:91703ms step_avg:96.94ms
+step:947/1695 train_time:91798ms step_avg:96.94ms
+step:948/1695 train_time:91894ms step_avg:96.93ms
+step:949/1695 train_time:91989ms step_avg:96.93ms
+step:950/1695 train_time:92085ms step_avg:96.93ms
+step:951/1695 train_time:92181ms step_avg:96.93ms
+step:952/1695 train_time:92277ms step_avg:96.93ms
+step:953/1695 train_time:92373ms step_avg:96.93ms
+step:954/1695 train_time:92469ms step_avg:96.93ms
+step:955/1695 train_time:92566ms step_avg:96.93ms
+step:956/1695 train_time:92662ms step_avg:96.93ms
+step:957/1695 train_time:92758ms step_avg:96.93ms
+step:958/1695 train_time:92854ms step_avg:96.92ms
+step:959/1695 train_time:92950ms step_avg:96.92ms
+step:960/1695 train_time:93046ms step_avg:96.92ms
+step:961/1695 train_time:93142ms step_avg:96.92ms
+step:962/1695 train_time:93238ms step_avg:96.92ms
+step:963/1695 train_time:93334ms step_avg:96.92ms
+step:964/1695 train_time:93431ms step_avg:96.92ms
+step:965/1695 train_time:93526ms step_avg:96.92ms
+step:966/1695 train_time:93621ms step_avg:96.92ms
+step:967/1695 train_time:93718ms step_avg:96.92ms
+step:968/1695 train_time:93815ms step_avg:96.92ms
+step:969/1695 train_time:93911ms step_avg:96.92ms
+step:970/1695 train_time:94008ms step_avg:96.92ms
+step:971/1695 train_time:94104ms step_avg:96.91ms
+step:972/1695 train_time:94200ms step_avg:96.91ms
+step:973/1695 train_time:94295ms step_avg:96.91ms
+step:974/1695 train_time:94392ms step_avg:96.91ms
+step:975/1695 train_time:94489ms step_avg:96.91ms
+step:976/1695 train_time:94585ms step_avg:96.91ms
+step:977/1695 train_time:94681ms step_avg:96.91ms
+step:978/1695 train_time:94778ms step_avg:96.91ms
+step:979/1695 train_time:94874ms step_avg:96.91ms
+step:980/1695 train_time:94970ms step_avg:96.91ms
+step:981/1695 train_time:95066ms step_avg:96.91ms
+step:982/1695 train_time:95161ms step_avg:96.91ms
+step:983/1695 train_time:95257ms step_avg:96.90ms
+step:984/1695 train_time:95353ms step_avg:96.90ms
+step:985/1695 train_time:95450ms step_avg:96.90ms
+step:986/1695 train_time:95547ms step_avg:96.90ms
+step:987/1695 train_time:95644ms step_avg:96.90ms
+step:988/1695 train_time:95739ms step_avg:96.90ms
+step:989/1695 train_time:95835ms step_avg:96.90ms
+step:990/1695 train_time:95932ms step_avg:96.90ms
+step:991/1695 train_time:96028ms step_avg:96.90ms
+step:992/1695 train_time:96123ms step_avg:96.90ms
+step:993/1695 train_time:96218ms step_avg:96.90ms
+step:994/1695 train_time:96314ms step_avg:96.90ms
+step:995/1695 train_time:96411ms step_avg:96.90ms
+step:996/1695 train_time:96508ms step_avg:96.90ms
+step:997/1695 train_time:96605ms step_avg:96.90ms
+step:998/1695 train_time:96701ms step_avg:96.89ms
+step:999/1695 train_time:96796ms step_avg:96.89ms
+step:1000/1695 train_time:96893ms step_avg:96.89ms
+step:1000/1695 val_loss:3.4830 train_time:96988ms step_avg:96.99ms
+step:1001/1695 train_time:97014ms step_avg:96.92ms
+step:1002/1695 train_time:97096ms step_avg:96.90ms
+step:1003/1695 train_time:97195ms step_avg:96.90ms
+step:1004/1695 train_time:97291ms step_avg:96.90ms
+step:1005/1695 train_time:97387ms step_avg:96.90ms
+step:1006/1695 train_time:97482ms step_avg:96.90ms
+step:1007/1695 train_time:97577ms step_avg:96.90ms
+step:1008/1695 train_time:97673ms step_avg:96.90ms
+step:1009/1695 train_time:97768ms step_avg:96.90ms
+step:1010/1695 train_time:97862ms step_avg:96.89ms
+step:1011/1695 train_time:97959ms step_avg:96.89ms
+step:1012/1695 train_time:98058ms step_avg:96.90ms
+step:1013/1695 train_time:98157ms step_avg:96.90ms
+step:1014/1695 train_time:98256ms step_avg:96.90ms
+step:1015/1695 train_time:98354ms step_avg:96.90ms
+step:1016/1695 train_time:98450ms step_avg:96.90ms
+step:1017/1695 train_time:98545ms step_avg:96.90ms
+step:1018/1695 train_time:98640ms step_avg:96.90ms
+step:1019/1695 train_time:98736ms step_avg:96.89ms
+step:1020/1695 train_time:98831ms step_avg:96.89ms
+step:1021/1695 train_time:98927ms step_avg:96.89ms
+step:1022/1695 train_time:99025ms step_avg:96.89ms
+step:1023/1695 train_time:99121ms step_avg:96.89ms
+step:1024/1695 train_time:99218ms step_avg:96.89ms
+step:1025/1695 train_time:99316ms step_avg:96.89ms
+step:1026/1695 train_time:99413ms step_avg:96.89ms
+step:1027/1695 train_time:99510ms step_avg:96.89ms
+step:1028/1695 train_time:99605ms step_avg:96.89ms
+step:1029/1695 train_time:99700ms step_avg:96.89ms
+step:1030/1695 train_time:99795ms step_avg:96.89ms
+step:1031/1695 train_time:99891ms step_avg:96.89ms
+step:1032/1695 train_time:99987ms step_avg:96.89ms
+step:1033/1695 train_time:100085ms step_avg:96.89ms
+step:1034/1695 train_time:100181ms step_avg:96.89ms
+step:1035/1695 train_time:100277ms step_avg:96.89ms
+step:1036/1695 train_time:100616ms step_avg:97.12ms
+step:1037/1695 train_time:100781ms step_avg:97.19ms
+step:1038/1695 train_time:100876ms step_avg:97.18ms
+step:1039/1695 train_time:100971ms step_avg:97.18ms
+step:1040/1695 train_time:101066ms step_avg:97.18ms
+step:1041/1695 train_time:101161ms step_avg:97.18ms
+step:1042/1695 train_time:101256ms step_avg:97.17ms
+step:1043/1695 train_time:101350ms step_avg:97.17ms
+step:1044/1695 train_time:101444ms step_avg:97.17ms
+step:1045/1695 train_time:101539ms step_avg:97.17ms
+step:1046/1695 train_time:101641ms step_avg:97.17ms
+step:1047/1695 train_time:101740ms step_avg:97.17ms
+step:1048/1695 train_time:101838ms step_avg:97.17ms
+step:1049/1695 train_time:101934ms step_avg:97.17ms
+step:1050/1695 train_time:102031ms step_avg:97.17ms
+step:1051/1695 train_time:102126ms step_avg:97.17ms
+step:1052/1695 train_time:102220ms step_avg:97.17ms
+step:1053/1695 train_time:102315ms step_avg:97.17ms
+step:1054/1695 train_time:102411ms step_avg:97.16ms
+step:1055/1695 train_time:102506ms step_avg:97.16ms
+step:1056/1695 train_time:102603ms step_avg:97.16ms
+step:1057/1695 train_time:102700ms step_avg:97.16ms
+step:1058/1695 train_time:102797ms step_avg:97.16ms
+step:1059/1695 train_time:102894ms step_avg:97.16ms
+step:1060/1695 train_time:102991ms step_avg:97.16ms
+step:1061/1695 train_time:103087ms step_avg:97.16ms
+step:1062/1695 train_time:103182ms step_avg:97.16ms
+step:1063/1695 train_time:103277ms step_avg:97.16ms
+step:1064/1695 train_time:103373ms step_avg:97.16ms
+step:1065/1695 train_time:103469ms step_avg:97.15ms
+step:1066/1695 train_time:103565ms step_avg:97.15ms
+step:1067/1695 train_time:103661ms step_avg:97.15ms
+step:1068/1695 train_time:103756ms step_avg:97.15ms
+step:1069/1695 train_time:103854ms step_avg:97.15ms
+step:1070/1695 train_time:103951ms step_avg:97.15ms
+step:1071/1695 train_time:104048ms step_avg:97.15ms
+step:1072/1695 train_time:104144ms step_avg:97.15ms
+step:1073/1695 train_time:104239ms step_avg:97.15ms
+step:1074/1695 train_time:104335ms step_avg:97.15ms
+step:1075/1695 train_time:104431ms step_avg:97.14ms
+step:1076/1695 train_time:104527ms step_avg:97.14ms
+step:1077/1695 train_time:104623ms step_avg:97.14ms
+step:1078/1695 train_time:104719ms step_avg:97.14ms
+step:1079/1695 train_time:104816ms step_avg:97.14ms
+step:1080/1695 train_time:104912ms step_avg:97.14ms
+step:1081/1695 train_time:105009ms step_avg:97.14ms
+step:1082/1695 train_time:105105ms step_avg:97.14ms
+step:1083/1695 train_time:105200ms step_avg:97.14ms
+step:1084/1695 train_time:105296ms step_avg:97.14ms
+step:1085/1695 train_time:105391ms step_avg:97.13ms
+step:1086/1695 train_time:105487ms step_avg:97.13ms
+step:1087/1695 train_time:105583ms step_avg:97.13ms
+step:1088/1695 train_time:105678ms step_avg:97.13ms
+step:1089/1695 train_time:105775ms step_avg:97.13ms
+step:1090/1695 train_time:105871ms step_avg:97.13ms
+step:1091/1695 train_time:105967ms step_avg:97.13ms
+step:1092/1695 train_time:106063ms step_avg:97.13ms
+step:1093/1695 train_time:106158ms step_avg:97.13ms
+step:1094/1695 train_time:106254ms step_avg:97.12ms
+step:1095/1695 train_time:106350ms step_avg:97.12ms
+step:1096/1695 train_time:106446ms step_avg:97.12ms
+step:1097/1695 train_time:106541ms step_avg:97.12ms
+step:1098/1695 train_time:106636ms step_avg:97.12ms
+step:1099/1695 train_time:106732ms step_avg:97.12ms
+step:1100/1695 train_time:106829ms step_avg:97.12ms
+step:1101/1695 train_time:106925ms step_avg:97.12ms
+step:1102/1695 train_time:107021ms step_avg:97.12ms
+step:1103/1695 train_time:107117ms step_avg:97.11ms
+step:1104/1695 train_time:107213ms step_avg:97.11ms
+step:1105/1695 train_time:107309ms step_avg:97.11ms
+step:1106/1695 train_time:107405ms step_avg:97.11ms
+step:1107/1695 train_time:107501ms step_avg:97.11ms
+step:1108/1695 train_time:107597ms step_avg:97.11ms
+step:1109/1695 train_time:107693ms step_avg:97.11ms
+step:1110/1695 train_time:107790ms step_avg:97.11ms
+step:1111/1695 train_time:107886ms step_avg:97.11ms
+step:1112/1695 train_time:107982ms step_avg:97.11ms
+step:1113/1695 train_time:108078ms step_avg:97.11ms
+step:1114/1695 train_time:108174ms step_avg:97.10ms
+step:1115/1695 train_time:108272ms step_avg:97.10ms
+step:1116/1695 train_time:108369ms step_avg:97.10ms
+step:1117/1695 train_time:108463ms step_avg:97.10ms
+step:1118/1695 train_time:108558ms step_avg:97.10ms
+step:1119/1695 train_time:108653ms step_avg:97.10ms
+step:1120/1695 train_time:108749ms step_avg:97.10ms
+step:1121/1695 train_time:108846ms step_avg:97.10ms
+step:1122/1695 train_time:108942ms step_avg:97.10ms
+step:1123/1695 train_time:109037ms step_avg:97.09ms
+step:1124/1695 train_time:109133ms step_avg:97.09ms
+step:1125/1695 train_time:109229ms step_avg:97.09ms
+step:1125/1695 val_loss:3.4364 train_time:109322ms step_avg:97.18ms
+step:1126/1695 train_time:109349ms step_avg:97.11ms
+step:1127/1695 train_time:109426ms step_avg:97.10ms
+step:1128/1695 train_time:109523ms step_avg:97.09ms
+step:1129/1695 train_time:109619ms step_avg:97.09ms
+step:1130/1695 train_time:109715ms step_avg:97.09ms
+step:1131/1695 train_time:109810ms step_avg:97.09ms
+step:1132/1695 train_time:109905ms step_avg:97.09ms
+step:1133/1695 train_time:110001ms step_avg:97.09ms
+step:1134/1695 train_time:110098ms step_avg:97.09ms
+step:1135/1695 train_time:110195ms step_avg:97.09ms
+step:1136/1695 train_time:110294ms step_avg:97.09ms
+step:1137/1695 train_time:110397ms step_avg:97.09ms
+step:1138/1695 train_time:110496ms step_avg:97.10ms
+step:1139/1695 train_time:110596ms step_avg:97.10ms
+step:1140/1695 train_time:110695ms step_avg:97.10ms
+step:1141/1695 train_time:110793ms step_avg:97.10ms
+step:1142/1695 train_time:110891ms step_avg:97.10ms
+step:1143/1695 train_time:110988ms step_avg:97.10ms
+step:1144/1695 train_time:111085ms step_avg:97.10ms
+step:1145/1695 train_time:111182ms step_avg:97.10ms
+step:1146/1695 train_time:111280ms step_avg:97.10ms
+step:1147/1695 train_time:111379ms step_avg:97.10ms
+step:1148/1695 train_time:111478ms step_avg:97.11ms
+step:1149/1695 train_time:111576ms step_avg:97.11ms
+step:1150/1695 train_time:111674ms step_avg:97.11ms
+step:1151/1695 train_time:111773ms step_avg:97.11ms
+step:1152/1695 train_time:111871ms step_avg:97.11ms
+step:1153/1695 train_time:111969ms step_avg:97.11ms
+step:1154/1695 train_time:112065ms step_avg:97.11ms
+step:1155/1695 train_time:112162ms step_avg:97.11ms
+step:1156/1695 train_time:112260ms step_avg:97.11ms
+step:1157/1695 train_time:112358ms step_avg:97.11ms
+step:1158/1695 train_time:112458ms step_avg:97.11ms
+step:1159/1695 train_time:112557ms step_avg:97.12ms
+step:1160/1695 train_time:112656ms step_avg:97.12ms
+step:1161/1695 train_time:112755ms step_avg:97.12ms
+step:1162/1695 train_time:112852ms step_avg:97.12ms
+step:1163/1695 train_time:112951ms step_avg:97.12ms
+step:1164/1695 train_time:113048ms step_avg:97.12ms
+step:1165/1695 train_time:113144ms step_avg:97.12ms
+step:1166/1695 train_time:113242ms step_avg:97.12ms
+step:1167/1695 train_time:113340ms step_avg:97.12ms
+step:1168/1695 train_time:113437ms step_avg:97.12ms
+step:1169/1695 train_time:113536ms step_avg:97.12ms
+step:1170/1695 train_time:113634ms step_avg:97.12ms
+step:1171/1695 train_time:113733ms step_avg:97.12ms
+step:1172/1695 train_time:113832ms step_avg:97.13ms
+step:1173/1695 train_time:113931ms step_avg:97.13ms
+step:1174/1695 train_time:114029ms step_avg:97.13ms
+step:1175/1695 train_time:114127ms step_avg:97.13ms
+step:1176/1695 train_time:114224ms step_avg:97.13ms
+step:1177/1695 train_time:114322ms step_avg:97.13ms
+step:1178/1695 train_time:114419ms step_avg:97.13ms
+step:1179/1695 train_time:114517ms step_avg:97.13ms
+step:1180/1695 train_time:114615ms step_avg:97.13ms
+step:1181/1695 train_time:114713ms step_avg:97.13ms
+step:1182/1695 train_time:114810ms step_avg:97.13ms
+step:1183/1695 train_time:114908ms step_avg:97.13ms
+step:1184/1695 train_time:115005ms step_avg:97.13ms
+step:1185/1695 train_time:115103ms step_avg:97.13ms
+step:1186/1695 train_time:115200ms step_avg:97.13ms
+step:1187/1695 train_time:115298ms step_avg:97.13ms
+step:1188/1695 train_time:115397ms step_avg:97.14ms
+step:1189/1695 train_time:115495ms step_avg:97.14ms
+step:1190/1695 train_time:115593ms step_avg:97.14ms
+step:1191/1695 train_time:115691ms step_avg:97.14ms
+step:1192/1695 train_time:115789ms step_avg:97.14ms
+step:1193/1695 train_time:115887ms step_avg:97.14ms
+step:1194/1695 train_time:115984ms step_avg:97.14ms
+step:1195/1695 train_time:116082ms step_avg:97.14ms
+step:1196/1695 train_time:116180ms step_avg:97.14ms
+step:1197/1695 train_time:116279ms step_avg:97.14ms
+step:1198/1695 train_time:116377ms step_avg:97.14ms
+step:1199/1695 train_time:116476ms step_avg:97.14ms
+step:1200/1695 train_time:116574ms step_avg:97.14ms
+step:1201/1695 train_time:116672ms step_avg:97.15ms
+step:1202/1695 train_time:116770ms step_avg:97.15ms
+step:1203/1695 train_time:116868ms step_avg:97.15ms
+step:1204/1695 train_time:116965ms step_avg:97.15ms
+step:1205/1695 train_time:117062ms step_avg:97.15ms
+step:1206/1695 train_time:117160ms step_avg:97.15ms
+step:1207/1695 train_time:117259ms step_avg:97.15ms
+step:1208/1695 train_time:117635ms step_avg:97.38ms
+step:1209/1695 train_time:117773ms step_avg:97.41ms
+step:1210/1695 train_time:117868ms step_avg:97.41ms
+step:1211/1695 train_time:117964ms step_avg:97.41ms
+step:1212/1695 train_time:118061ms step_avg:97.41ms
+step:1213/1695 train_time:118157ms step_avg:97.41ms
+step:1214/1695 train_time:118253ms step_avg:97.41ms
+step:1215/1695 train_time:118350ms step_avg:97.41ms
+step:1216/1695 train_time:118446ms step_avg:97.41ms
+step:1217/1695 train_time:118542ms step_avg:97.41ms
+step:1218/1695 train_time:118646ms step_avg:97.41ms
+step:1219/1695 train_time:118750ms step_avg:97.42ms
+step:1220/1695 train_time:118849ms step_avg:97.42ms
+step:1221/1695 train_time:118947ms step_avg:97.42ms
+step:1222/1695 train_time:119044ms step_avg:97.42ms
+step:1223/1695 train_time:119140ms step_avg:97.42ms
+step:1224/1695 train_time:119238ms step_avg:97.42ms
+step:1225/1695 train_time:119335ms step_avg:97.42ms
+step:1226/1695 train_time:119432ms step_avg:97.42ms
+step:1227/1695 train_time:119529ms step_avg:97.42ms
+step:1228/1695 train_time:119626ms step_avg:97.42ms
+step:1229/1695 train_time:119726ms step_avg:97.42ms
+step:1230/1695 train_time:119824ms step_avg:97.42ms
+step:1231/1695 train_time:119923ms step_avg:97.42ms
+step:1232/1695 train_time:120021ms step_avg:97.42ms
+step:1233/1695 train_time:120119ms step_avg:97.42ms
+step:1234/1695 train_time:120216ms step_avg:97.42ms
+step:1235/1695 train_time:120313ms step_avg:97.42ms
+step:1236/1695 train_time:120410ms step_avg:97.42ms
+step:1237/1695 train_time:120506ms step_avg:97.42ms
+step:1238/1695 train_time:120604ms step_avg:97.42ms
+step:1239/1695 train_time:120703ms step_avg:97.42ms
+step:1240/1695 train_time:120801ms step_avg:97.42ms
+step:1241/1695 train_time:120900ms step_avg:97.42ms
+step:1242/1695 train_time:120998ms step_avg:97.42ms
+step:1243/1695 train_time:121096ms step_avg:97.42ms
+step:1244/1695 train_time:121194ms step_avg:97.42ms
+step:1245/1695 train_time:121292ms step_avg:97.42ms
+step:1246/1695 train_time:121388ms step_avg:97.42ms
+step:1247/1695 train_time:121485ms step_avg:97.42ms
+step:1248/1695 train_time:121583ms step_avg:97.42ms
+step:1249/1695 train_time:121682ms step_avg:97.42ms
+step:1250/1695 train_time:121780ms step_avg:97.42ms
+step:1250/1695 val_loss:3.3889 train_time:121876ms step_avg:97.50ms
+step:1251/1695 train_time:121915ms step_avg:97.45ms
+step:1252/1695 train_time:121984ms step_avg:97.43ms
+step:1253/1695 train_time:122082ms step_avg:97.43ms
+step:1254/1695 train_time:122178ms step_avg:97.43ms
+step:1255/1695 train_time:122274ms step_avg:97.43ms
+step:1256/1695 train_time:122371ms step_avg:97.43ms
+step:1257/1695 train_time:122467ms step_avg:97.43ms
+step:1258/1695 train_time:122564ms step_avg:97.43ms
+step:1259/1695 train_time:122660ms step_avg:97.43ms
+step:1260/1695 train_time:122759ms step_avg:97.43ms
+step:1261/1695 train_time:122864ms step_avg:97.43ms
+step:1262/1695 train_time:122963ms step_avg:97.43ms
+step:1263/1695 train_time:123061ms step_avg:97.44ms
+step:1264/1695 train_time:123158ms step_avg:97.43ms
+step:1265/1695 train_time:123255ms step_avg:97.43ms
+step:1266/1695 train_time:123352ms step_avg:97.43ms
+step:1267/1695 train_time:123449ms step_avg:97.43ms
+step:1268/1695 train_time:123546ms step_avg:97.43ms
+step:1269/1695 train_time:123642ms step_avg:97.43ms
+step:1270/1695 train_time:123739ms step_avg:97.43ms
+step:1271/1695 train_time:123838ms step_avg:97.43ms
+step:1272/1695 train_time:123937ms step_avg:97.43ms
+step:1273/1695 train_time:124036ms step_avg:97.44ms
+step:1274/1695 train_time:124135ms step_avg:97.44ms
+step:1275/1695 train_time:124233ms step_avg:97.44ms
+step:1276/1695 train_time:124329ms step_avg:97.44ms
+step:1277/1695 train_time:124426ms step_avg:97.44ms
+step:1278/1695 train_time:124524ms step_avg:97.44ms
+step:1279/1695 train_time:124619ms step_avg:97.44ms
+step:1280/1695 train_time:124717ms step_avg:97.43ms
+step:1281/1695 train_time:124815ms step_avg:97.44ms
+step:1282/1695 train_time:124914ms step_avg:97.44ms
+step:1283/1695 train_time:125012ms step_avg:97.44ms
+step:1284/1695 train_time:125111ms step_avg:97.44ms
+step:1285/1695 train_time:125208ms step_avg:97.44ms
+step:1286/1695 train_time:125306ms step_avg:97.44ms
+step:1287/1695 train_time:125403ms step_avg:97.44ms
+step:1288/1695 train_time:125499ms step_avg:97.44ms
+step:1289/1695 train_time:125597ms step_avg:97.44ms
+step:1290/1695 train_time:125693ms step_avg:97.44ms
+step:1291/1695 train_time:125793ms step_avg:97.44ms
+step:1292/1695 train_time:125892ms step_avg:97.44ms
+step:1293/1695 train_time:125993ms step_avg:97.44ms
+step:1294/1695 train_time:126092ms step_avg:97.44ms
+step:1295/1695 train_time:126191ms step_avg:97.44ms
+step:1296/1695 train_time:126289ms step_avg:97.45ms
+step:1297/1695 train_time:126387ms step_avg:97.45ms
+step:1298/1695 train_time:126485ms step_avg:97.45ms
+step:1299/1695 train_time:126581ms step_avg:97.45ms
+step:1300/1695 train_time:126679ms step_avg:97.45ms
+step:1301/1695 train_time:126776ms step_avg:97.44ms
+step:1302/1695 train_time:126874ms step_avg:97.45ms
+step:1303/1695 train_time:126975ms step_avg:97.45ms
+step:1304/1695 train_time:127073ms step_avg:97.45ms
+step:1305/1695 train_time:127173ms step_avg:97.45ms
+step:1306/1695 train_time:127270ms step_avg:97.45ms
+step:1307/1695 train_time:127369ms step_avg:97.45ms
+step:1308/1695 train_time:127467ms step_avg:97.45ms
+step:1309/1695 train_time:127564ms step_avg:97.45ms
+step:1310/1695 train_time:127661ms step_avg:97.45ms
+step:1311/1695 train_time:127759ms step_avg:97.45ms
+step:1312/1695 train_time:127855ms step_avg:97.45ms
+step:1313/1695 train_time:127953ms step_avg:97.45ms
+step:1314/1695 train_time:128051ms step_avg:97.45ms
+step:1315/1695 train_time:128149ms step_avg:97.45ms
+step:1316/1695 train_time:128246ms step_avg:97.45ms
+step:1317/1695 train_time:128343ms step_avg:97.45ms
+step:1318/1695 train_time:128440ms step_avg:97.45ms
+step:1319/1695 train_time:128538ms step_avg:97.45ms
+step:1320/1695 train_time:128636ms step_avg:97.45ms
+step:1321/1695 train_time:128734ms step_avg:97.45ms
+step:1322/1695 train_time:128833ms step_avg:97.45ms
+step:1323/1695 train_time:128932ms step_avg:97.45ms
+step:1324/1695 train_time:129029ms step_avg:97.45ms
+step:1325/1695 train_time:129127ms step_avg:97.45ms
+step:1326/1695 train_time:129225ms step_avg:97.45ms
+step:1327/1695 train_time:129322ms step_avg:97.45ms
+step:1328/1695 train_time:129419ms step_avg:97.45ms
+step:1329/1695 train_time:129517ms step_avg:97.45ms
+step:1330/1695 train_time:129615ms step_avg:97.45ms
+step:1331/1695 train_time:129713ms step_avg:97.46ms
+step:1332/1695 train_time:129811ms step_avg:97.46ms
+step:1333/1695 train_time:129910ms step_avg:97.46ms
+step:1334/1695 train_time:130007ms step_avg:97.46ms
+step:1335/1695 train_time:130105ms step_avg:97.46ms
+step:1336/1695 train_time:130204ms step_avg:97.46ms
+step:1337/1695 train_time:130300ms step_avg:97.46ms
+step:1338/1695 train_time:130398ms step_avg:97.46ms
+step:1339/1695 train_time:130495ms step_avg:97.46ms
+step:1340/1695 train_time:130593ms step_avg:97.46ms
+step:1341/1695 train_time:130692ms step_avg:97.46ms
+step:1342/1695 train_time:130790ms step_avg:97.46ms
+step:1343/1695 train_time:130888ms step_avg:97.46ms
+step:1344/1695 train_time:130986ms step_avg:97.46ms
+step:1345/1695 train_time:131084ms step_avg:97.46ms
+step:1346/1695 train_time:131182ms step_avg:97.46ms
+step:1347/1695 train_time:131280ms step_avg:97.46ms
+step:1348/1695 train_time:131377ms step_avg:97.46ms
+step:1349/1695 train_time:131475ms step_avg:97.46ms
+step:1350/1695 train_time:131572ms step_avg:97.46ms
+step:1351/1695 train_time:131672ms step_avg:97.46ms
+step:1352/1695 train_time:131770ms step_avg:97.46ms
+step:1353/1695 train_time:131869ms step_avg:97.46ms
+step:1354/1695 train_time:131966ms step_avg:97.46ms
+step:1355/1695 train_time:132064ms step_avg:97.46ms
+step:1356/1695 train_time:132161ms step_avg:97.46ms
+step:1357/1695 train_time:132258ms step_avg:97.46ms
+step:1358/1695 train_time:132356ms step_avg:97.46ms
+step:1359/1695 train_time:132454ms step_avg:97.46ms
+step:1360/1695 train_time:132552ms step_avg:97.46ms
+step:1361/1695 train_time:132650ms step_avg:97.46ms
+step:1362/1695 train_time:132748ms step_avg:97.47ms
+step:1363/1695 train_time:132845ms step_avg:97.47ms
+step:1364/1695 train_time:132942ms step_avg:97.46ms
+step:1365/1695 train_time:133038ms step_avg:97.46ms
+step:1366/1695 train_time:133136ms step_avg:97.46ms
+step:1367/1695 train_time:133235ms step_avg:97.46ms
+step:1368/1695 train_time:133333ms step_avg:97.47ms
+step:1369/1695 train_time:133430ms step_avg:97.47ms
+step:1370/1695 train_time:133529ms step_avg:97.47ms
+step:1371/1695 train_time:133630ms step_avg:97.47ms
+step:1372/1695 train_time:133725ms step_avg:97.47ms
+step:1373/1695 train_time:133823ms step_avg:97.47ms
+step:1374/1695 train_time:133921ms step_avg:97.47ms
+step:1375/1695 train_time:134023ms step_avg:97.47ms
+step:1375/1695 val_loss:3.3505 train_time:134113ms step_avg:97.54ms
+step:1376/1695 train_time:134162ms step_avg:97.50ms
+step:1377/1695 train_time:134219ms step_avg:97.47ms
+step:1378/1695 train_time:134319ms step_avg:97.47ms
+step:1379/1695 train_time:134416ms step_avg:97.47ms
+step:1380/1695 train_time:134512ms step_avg:97.47ms
+step:1381/1695 train_time:134884ms step_avg:97.67ms
+step:1382/1695 train_time:135043ms step_avg:97.72ms
+step:1383/1695 train_time:135139ms step_avg:97.71ms
+step:1384/1695 train_time:135235ms step_avg:97.71ms
+step:1385/1695 train_time:135332ms step_avg:97.71ms
+step:1386/1695 train_time:135428ms step_avg:97.71ms
+step:1387/1695 train_time:135526ms step_avg:97.71ms
+step:1388/1695 train_time:135621ms step_avg:97.71ms
+step:1389/1695 train_time:135716ms step_avg:97.71ms
+step:1390/1695 train_time:135813ms step_avg:97.71ms
+step:1391/1695 train_time:135916ms step_avg:97.71ms
+step:1392/1695 train_time:136020ms step_avg:97.72ms
+step:1393/1695 train_time:136121ms step_avg:97.72ms
+step:1394/1695 train_time:136218ms step_avg:97.72ms
+step:1395/1695 train_time:136315ms step_avg:97.72ms
+step:1396/1695 train_time:136412ms step_avg:97.72ms
+step:1397/1695 train_time:136509ms step_avg:97.72ms
+step:1398/1695 train_time:136607ms step_avg:97.72ms
+step:1399/1695 train_time:136704ms step_avg:97.72ms
+step:1400/1695 train_time:136801ms step_avg:97.71ms
+step:1401/1695 train_time:136899ms step_avg:97.72ms
+step:1402/1695 train_time:136997ms step_avg:97.72ms
+step:1403/1695 train_time:137096ms step_avg:97.72ms
+step:1404/1695 train_time:137194ms step_avg:97.72ms
+step:1405/1695 train_time:137293ms step_avg:97.72ms
+step:1406/1695 train_time:137390ms step_avg:97.72ms
+step:1407/1695 train_time:137487ms step_avg:97.72ms
+step:1408/1695 train_time:137584ms step_avg:97.72ms
+step:1409/1695 train_time:137681ms step_avg:97.72ms
+step:1410/1695 train_time:137778ms step_avg:97.72ms
+step:1411/1695 train_time:137876ms step_avg:97.72ms
+step:1412/1695 train_time:137975ms step_avg:97.72ms
+step:1413/1695 train_time:138075ms step_avg:97.72ms
+step:1414/1695 train_time:138173ms step_avg:97.72ms
+step:1415/1695 train_time:138272ms step_avg:97.72ms
+step:1416/1695 train_time:138368ms step_avg:97.72ms
+step:1417/1695 train_time:138466ms step_avg:97.72ms
+step:1418/1695 train_time:138563ms step_avg:97.72ms
+step:1419/1695 train_time:138661ms step_avg:97.72ms
+step:1420/1695 train_time:138757ms step_avg:97.72ms
+step:1421/1695 train_time:138854ms step_avg:97.72ms
+step:1422/1695 train_time:138953ms step_avg:97.72ms
+step:1423/1695 train_time:139053ms step_avg:97.72ms
+step:1424/1695 train_time:139153ms step_avg:97.72ms
+step:1425/1695 train_time:139252ms step_avg:97.72ms
+step:1426/1695 train_time:139350ms step_avg:97.72ms
+step:1427/1695 train_time:139448ms step_avg:97.72ms
+step:1428/1695 train_time:139546ms step_avg:97.72ms
+step:1429/1695 train_time:139644ms step_avg:97.72ms
+step:1430/1695 train_time:139742ms step_avg:97.72ms
+step:1431/1695 train_time:139839ms step_avg:97.72ms
+step:1432/1695 train_time:139938ms step_avg:97.72ms
+step:1433/1695 train_time:140036ms step_avg:97.72ms
+step:1434/1695 train_time:140134ms step_avg:97.72ms
+step:1435/1695 train_time:140232ms step_avg:97.72ms
+step:1436/1695 train_time:140330ms step_avg:97.72ms
+step:1437/1695 train_time:140428ms step_avg:97.72ms
+step:1438/1695 train_time:140526ms step_avg:97.72ms
+step:1439/1695 train_time:140623ms step_avg:97.72ms
+step:1440/1695 train_time:140720ms step_avg:97.72ms
+step:1441/1695 train_time:140818ms step_avg:97.72ms
+step:1442/1695 train_time:140917ms step_avg:97.72ms
+step:1443/1695 train_time:141014ms step_avg:97.72ms
+step:1444/1695 train_time:141113ms step_avg:97.72ms
+step:1445/1695 train_time:141213ms step_avg:97.73ms
+step:1446/1695 train_time:141312ms step_avg:97.73ms
+step:1447/1695 train_time:141411ms step_avg:97.73ms
+step:1448/1695 train_time:141508ms step_avg:97.73ms
+step:1449/1695 train_time:141607ms step_avg:97.73ms
+step:1450/1695 train_time:141706ms step_avg:97.73ms
+step:1451/1695 train_time:141804ms step_avg:97.73ms
+step:1452/1695 train_time:141902ms step_avg:97.73ms
+step:1453/1695 train_time:141999ms step_avg:97.73ms
+step:1454/1695 train_time:142097ms step_avg:97.73ms
+step:1455/1695 train_time:142194ms step_avg:97.73ms
+step:1456/1695 train_time:142292ms step_avg:97.73ms +step:1457/1695 train_time:142390ms step_avg:97.73ms +step:1458/1695 train_time:142488ms step_avg:97.73ms +step:1459/1695 train_time:142588ms step_avg:97.73ms +step:1460/1695 train_time:142685ms step_avg:97.73ms +step:1461/1695 train_time:142784ms step_avg:97.73ms +step:1462/1695 train_time:142881ms step_avg:97.73ms +step:1463/1695 train_time:142978ms step_avg:97.73ms +step:1464/1695 train_time:143076ms step_avg:97.73ms +step:1465/1695 train_time:143173ms step_avg:97.73ms +step:1466/1695 train_time:143272ms step_avg:97.73ms +step:1467/1695 train_time:143369ms step_avg:97.73ms +step:1468/1695 train_time:143467ms step_avg:97.73ms +step:1469/1695 train_time:143564ms step_avg:97.73ms +step:1470/1695 train_time:143661ms step_avg:97.73ms +step:1471/1695 train_time:143759ms step_avg:97.73ms +step:1472/1695 train_time:143856ms step_avg:97.73ms +step:1473/1695 train_time:143956ms step_avg:97.73ms +step:1474/1695 train_time:144053ms step_avg:97.73ms +step:1475/1695 train_time:144149ms step_avg:97.73ms +step:1476/1695 train_time:144248ms step_avg:97.73ms +step:1477/1695 train_time:144346ms step_avg:97.73ms +step:1478/1695 train_time:144445ms step_avg:97.73ms +step:1479/1695 train_time:144541ms step_avg:97.73ms +step:1480/1695 train_time:144639ms step_avg:97.73ms +step:1481/1695 train_time:144737ms step_avg:97.73ms +step:1482/1695 train_time:144835ms step_avg:97.73ms +step:1483/1695 train_time:144933ms step_avg:97.73ms +step:1484/1695 train_time:145031ms step_avg:97.73ms +step:1485/1695 train_time:145128ms step_avg:97.73ms +step:1486/1695 train_time:145225ms step_avg:97.73ms +step:1487/1695 train_time:145323ms step_avg:97.73ms +step:1488/1695 train_time:145421ms step_avg:97.73ms +step:1489/1695 train_time:145519ms step_avg:97.73ms +step:1490/1695 train_time:145616ms step_avg:97.73ms +step:1491/1695 train_time:145714ms step_avg:97.73ms +step:1492/1695 train_time:145813ms step_avg:97.73ms +step:1493/1695 train_time:145910ms step_avg:97.73ms +step:1494/1695 train_time:146008ms step_avg:97.73ms +step:1495/1695 train_time:146105ms step_avg:97.73ms +step:1496/1695 train_time:146203ms step_avg:97.73ms +step:1497/1695 train_time:146299ms step_avg:97.73ms +step:1498/1695 train_time:146396ms step_avg:97.73ms +step:1499/1695 train_time:146495ms step_avg:97.73ms +step:1500/1695 train_time:146593ms step_avg:97.73ms +step:1500/1695 val_loss:3.3176 train_time:146690ms step_avg:97.79ms +step:1501/1695 train_time:146740ms step_avg:97.76ms +step:1502/1695 train_time:146800ms step_avg:97.74ms +step:1503/1695 train_time:146898ms step_avg:97.74ms +step:1504/1695 train_time:146995ms step_avg:97.74ms +step:1505/1695 train_time:147094ms step_avg:97.74ms +step:1506/1695 train_time:147190ms step_avg:97.74ms +step:1507/1695 train_time:147287ms step_avg:97.74ms +step:1508/1695 train_time:147383ms step_avg:97.73ms +step:1509/1695 train_time:147481ms step_avg:97.73ms +step:1510/1695 train_time:147577ms step_avg:97.73ms +step:1511/1695 train_time:147677ms step_avg:97.73ms +step:1512/1695 train_time:147778ms step_avg:97.74ms +step:1513/1695 train_time:147878ms step_avg:97.74ms +step:1514/1695 train_time:147975ms step_avg:97.74ms +step:1515/1695 train_time:148073ms step_avg:97.74ms +step:1516/1695 train_time:148170ms step_avg:97.74ms +step:1517/1695 train_time:148269ms step_avg:97.74ms +step:1518/1695 train_time:148367ms step_avg:97.74ms +step:1519/1695 train_time:148463ms step_avg:97.74ms +step:1520/1695 train_time:148560ms step_avg:97.74ms +step:1521/1695 train_time:148658ms 
step_avg:97.74ms +step:1522/1695 train_time:148758ms step_avg:97.74ms +step:1523/1695 train_time:148857ms step_avg:97.74ms +step:1524/1695 train_time:148956ms step_avg:97.74ms +step:1525/1695 train_time:149055ms step_avg:97.74ms +step:1526/1695 train_time:149152ms step_avg:97.74ms +step:1527/1695 train_time:149250ms step_avg:97.74ms +step:1528/1695 train_time:149347ms step_avg:97.74ms +step:1529/1695 train_time:149444ms step_avg:97.74ms +step:1530/1695 train_time:149541ms step_avg:97.74ms +step:1531/1695 train_time:149638ms step_avg:97.74ms +step:1532/1695 train_time:149735ms step_avg:97.74ms +step:1533/1695 train_time:149835ms step_avg:97.74ms +step:1534/1695 train_time:149933ms step_avg:97.74ms +step:1535/1695 train_time:150032ms step_avg:97.74ms +step:1536/1695 train_time:150129ms step_avg:97.74ms +step:1537/1695 train_time:150227ms step_avg:97.74ms +step:1538/1695 train_time:150324ms step_avg:97.74ms +step:1539/1695 train_time:150421ms step_avg:97.74ms +step:1540/1695 train_time:150519ms step_avg:97.74ms +step:1541/1695 train_time:150616ms step_avg:97.74ms +step:1542/1695 train_time:150715ms step_avg:97.74ms +step:1543/1695 train_time:150813ms step_avg:97.74ms +step:1544/1695 train_time:150912ms step_avg:97.74ms +step:1545/1695 train_time:151010ms step_avg:97.74ms +step:1546/1695 train_time:151108ms step_avg:97.74ms +step:1547/1695 train_time:151206ms step_avg:97.74ms +step:1548/1695 train_time:151304ms step_avg:97.74ms +step:1549/1695 train_time:151401ms step_avg:97.74ms +step:1550/1695 train_time:151498ms step_avg:97.74ms +step:1551/1695 train_time:151596ms step_avg:97.74ms +step:1552/1695 train_time:151941ms step_avg:97.90ms +step:1553/1695 train_time:152117ms step_avg:97.95ms +step:1554/1695 train_time:152212ms step_avg:97.95ms +step:1555/1695 train_time:152308ms step_avg:97.95ms +step:1556/1695 train_time:152403ms step_avg:97.95ms +step:1557/1695 train_time:152499ms step_avg:97.94ms +step:1558/1695 train_time:152596ms step_avg:97.94ms +step:1559/1695 train_time:152693ms step_avg:97.94ms +step:1560/1695 train_time:152790ms step_avg:97.94ms +step:1561/1695 train_time:152886ms step_avg:97.94ms +step:1562/1695 train_time:152987ms step_avg:97.94ms +step:1563/1695 train_time:153089ms step_avg:97.95ms +step:1564/1695 train_time:153189ms step_avg:97.95ms +step:1565/1695 train_time:153285ms step_avg:97.95ms +step:1566/1695 train_time:153382ms step_avg:97.95ms +step:1567/1695 train_time:153479ms step_avg:97.94ms +step:1568/1695 train_time:153576ms step_avg:97.94ms +step:1569/1695 train_time:153673ms step_avg:97.94ms +step:1570/1695 train_time:153769ms step_avg:97.94ms +step:1571/1695 train_time:153866ms step_avg:97.94ms +step:1572/1695 train_time:153964ms step_avg:97.94ms +step:1573/1695 train_time:154064ms step_avg:97.94ms +step:1574/1695 train_time:154164ms step_avg:97.94ms +step:1575/1695 train_time:154262ms step_avg:97.94ms +step:1576/1695 train_time:154360ms step_avg:97.94ms +step:1577/1695 train_time:154457ms step_avg:97.94ms +step:1578/1695 train_time:154554ms step_avg:97.94ms +step:1579/1695 train_time:154651ms step_avg:97.94ms +step:1580/1695 train_time:154748ms step_avg:97.94ms +step:1581/1695 train_time:154845ms step_avg:97.94ms +step:1582/1695 train_time:154943ms step_avg:97.94ms +step:1583/1695 train_time:155042ms step_avg:97.94ms +step:1584/1695 train_time:155142ms step_avg:97.94ms +step:1585/1695 train_time:155239ms step_avg:97.94ms +step:1586/1695 train_time:155337ms step_avg:97.94ms +step:1587/1695 train_time:155435ms step_avg:97.94ms +step:1588/1695 train_time:155532ms 
step_avg:97.94ms +step:1589/1695 train_time:155630ms step_avg:97.94ms +step:1590/1695 train_time:155726ms step_avg:97.94ms +step:1591/1695 train_time:155823ms step_avg:97.94ms +step:1592/1695 train_time:155921ms step_avg:97.94ms +step:1593/1695 train_time:156019ms step_avg:97.94ms +step:1594/1695 train_time:156119ms step_avg:97.94ms +step:1595/1695 train_time:156218ms step_avg:97.94ms +step:1596/1695 train_time:156317ms step_avg:97.94ms +step:1597/1695 train_time:156416ms step_avg:97.94ms +step:1598/1695 train_time:156513ms step_avg:97.94ms +step:1599/1695 train_time:156610ms step_avg:97.94ms +step:1600/1695 train_time:156707ms step_avg:97.94ms +step:1601/1695 train_time:156804ms step_avg:97.94ms +step:1602/1695 train_time:156901ms step_avg:97.94ms +step:1603/1695 train_time:156999ms step_avg:97.94ms +step:1604/1695 train_time:157097ms step_avg:97.94ms +step:1605/1695 train_time:157196ms step_avg:97.94ms +step:1606/1695 train_time:157295ms step_avg:97.94ms +step:1607/1695 train_time:157393ms step_avg:97.94ms +step:1608/1695 train_time:157491ms step_avg:97.94ms +step:1609/1695 train_time:157590ms step_avg:97.94ms +step:1610/1695 train_time:157687ms step_avg:97.94ms +step:1611/1695 train_time:157784ms step_avg:97.94ms +step:1612/1695 train_time:157881ms step_avg:97.94ms +step:1613/1695 train_time:157978ms step_avg:97.94ms +step:1614/1695 train_time:158077ms step_avg:97.94ms +step:1615/1695 train_time:158176ms step_avg:97.94ms +step:1616/1695 train_time:158274ms step_avg:97.94ms +step:1617/1695 train_time:158373ms step_avg:97.94ms +step:1618/1695 train_time:158471ms step_avg:97.94ms +step:1619/1695 train_time:158569ms step_avg:97.94ms +step:1620/1695 train_time:158666ms step_avg:97.94ms +step:1621/1695 train_time:158765ms step_avg:97.94ms +step:1622/1695 train_time:158861ms step_avg:97.94ms +step:1623/1695 train_time:158958ms step_avg:97.94ms +step:1624/1695 train_time:159057ms step_avg:97.94ms +step:1625/1695 train_time:159155ms step_avg:97.94ms +step:1625/1695 val_loss:3.2898 train_time:159251ms step_avg:98.00ms +step:1626/1695 train_time:159278ms step_avg:97.96ms +step:1627/1695 train_time:159359ms step_avg:97.95ms +step:1628/1695 train_time:159459ms step_avg:97.95ms +step:1629/1695 train_time:159560ms step_avg:97.95ms +step:1630/1695 train_time:159658ms step_avg:97.95ms +step:1631/1695 train_time:159755ms step_avg:97.95ms +step:1632/1695 train_time:159852ms step_avg:97.95ms +step:1633/1695 train_time:159950ms step_avg:97.95ms +step:1634/1695 train_time:160046ms step_avg:97.95ms +step:1635/1695 train_time:160142ms step_avg:97.95ms +step:1636/1695 train_time:160242ms step_avg:97.95ms +step:1637/1695 train_time:160342ms step_avg:97.95ms +step:1638/1695 train_time:160441ms step_avg:97.95ms +step:1639/1695 train_time:160539ms step_avg:97.95ms +step:1640/1695 train_time:160639ms step_avg:97.95ms +step:1641/1695 train_time:160737ms step_avg:97.95ms +step:1642/1695 train_time:160835ms step_avg:97.95ms +step:1643/1695 train_time:160933ms step_avg:97.95ms +step:1644/1695 train_time:161031ms step_avg:97.95ms +step:1645/1695 train_time:161128ms step_avg:97.95ms +step:1646/1695 train_time:161225ms step_avg:97.95ms +step:1647/1695 train_time:161323ms step_avg:97.95ms +step:1648/1695 train_time:161421ms step_avg:97.95ms +step:1649/1695 train_time:161519ms step_avg:97.95ms +step:1650/1695 train_time:161618ms step_avg:97.95ms +step:1651/1695 train_time:161717ms step_avg:97.95ms +step:1652/1695 train_time:161816ms step_avg:97.95ms +step:1653/1695 train_time:161914ms step_avg:97.95ms +step:1654/1695 
train_time:162010ms step_avg:97.95ms +step:1655/1695 train_time:162107ms step_avg:97.95ms +step:1656/1695 train_time:162205ms step_avg:97.95ms +step:1657/1695 train_time:162302ms step_avg:97.95ms +step:1658/1695 train_time:162399ms step_avg:97.95ms +step:1659/1695 train_time:162496ms step_avg:97.95ms +step:1660/1695 train_time:162595ms step_avg:97.95ms +step:1661/1695 train_time:162693ms step_avg:97.95ms +step:1662/1695 train_time:162791ms step_avg:97.95ms +step:1663/1695 train_time:162888ms step_avg:97.95ms +step:1664/1695 train_time:162985ms step_avg:97.95ms +step:1665/1695 train_time:163083ms step_avg:97.95ms +step:1666/1695 train_time:163181ms step_avg:97.95ms +step:1667/1695 train_time:163280ms step_avg:97.95ms +step:1668/1695 train_time:163379ms step_avg:97.95ms +step:1669/1695 train_time:163476ms step_avg:97.95ms +step:1670/1695 train_time:163574ms step_avg:97.95ms +step:1671/1695 train_time:163672ms step_avg:97.95ms +step:1672/1695 train_time:163769ms step_avg:97.95ms +step:1673/1695 train_time:163867ms step_avg:97.95ms +step:1674/1695 train_time:163964ms step_avg:97.95ms +step:1675/1695 train_time:164062ms step_avg:97.95ms +step:1676/1695 train_time:164160ms step_avg:97.95ms +step:1677/1695 train_time:164258ms step_avg:97.95ms +step:1678/1695 train_time:164359ms step_avg:97.95ms +step:1679/1695 train_time:164455ms step_avg:97.95ms +step:1680/1695 train_time:164553ms step_avg:97.95ms +step:1681/1695 train_time:164651ms step_avg:97.95ms +step:1682/1695 train_time:164750ms step_avg:97.95ms +step:1683/1695 train_time:164849ms step_avg:97.95ms +step:1684/1695 train_time:164947ms step_avg:97.95ms +step:1685/1695 train_time:165044ms step_avg:97.95ms +step:1686/1695 train_time:165141ms step_avg:97.95ms +step:1687/1695 train_time:165239ms step_avg:97.95ms +step:1688/1695 train_time:165336ms step_avg:97.95ms +step:1689/1695 train_time:165435ms step_avg:97.95ms +step:1690/1695 train_time:165533ms step_avg:97.95ms +step:1691/1695 train_time:165631ms step_avg:97.95ms +step:1692/1695 train_time:165728ms step_avg:97.95ms +step:1693/1695 train_time:165825ms step_avg:97.95ms +step:1694/1695 train_time:165922ms step_avg:97.95ms +step:1695/1695 train_time:166021ms step_avg:97.95ms +step:1695/1695 val_loss:3.2782 train_time:166117ms step_avg:98.00ms +peak memory allocated: 34001 MiB reserved: 49716 MiB diff --git a/requirements.txt b/requirements.txt index 80dc92a80..f97ac89a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ numpy tqdm torch huggingface-hub -triton \ No newline at end of file +triton +flash-attn \ No newline at end of file diff --git a/train_gpt.py b/train_gpt.py index 74556953b..bbb431bed 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -16,11 +16,13 @@ from torch import Tensor, nn import torch.nn.functional as F import torch.distributed as dist -# use of FlexAttention contributed by @KoszarskyB -from torch.nn.attention.flex_attention import BlockMask, flex_attention #torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np import triton import triton.language as tl +from flash_attn_interface import flash_attn_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 # ----------------------------------------------------------------------------- # Custom operators: FP8 matmul by @YouJiacheng @@ -598,10 +600,9 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): self.attn_gate = CastedLinear(self.attn_gate_dim, 
num_heads) self.attn_gate.weight.detach().zero_() - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): B, T = x.size(0), x.size(1) # batch size, sequence length - assert B == 1, "Must use batch size = 1 for FlexAttention" - + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) q, k = norm(q), norm(k) # QK norm @Grad62304977 q, k = self.rotary(q), self.rotary(k) @@ -609,7 +610,8 @@ def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: Blo v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 else: # skip mid-layers token value embeddings by @YouJiacheng v = lambdas[0] * v - y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2) + + y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal y = y.view(B, T, self.num_heads, self.head_dim) y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side @@ -643,10 +645,10 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None self.mlp = MLP(dim) - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): x = lambdas[0] * x + lambdas[1] * x0 if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, block_mask) + x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) x = x + self.mlp(norm(x)) return x @@ -687,59 +689,20 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: self.lm_head.weight.lr_mul = 1.0 self.scalars.lr_mul = 5.0 - def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): - BLOCK_SIZE = 128 - docs = (input_seq == 50256).cumsum(0) - - def document_causal(b, h, q_idx, kv_idx): - causal_mask = q_idx >= kv_idx - document_mask = docs[q_idx] == docs[kv_idx] - return causal_mask & document_mask - - def dense_to_ordered(dense_blockmask: Tensor): - num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) - indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) - return num_blocks[None, None].contiguous(), indices[None, None].contiguous() - - # manual block mask creation by @YouJiacheng - assert len(input_seq) % BLOCK_SIZE == 0 - NUM_BLOCKS = len(input_seq) // BLOCK_SIZE - block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") - causal_blockmask_any = block_idx[:, None] >= block_idx - causal_blockmask_all = block_idx[:, None] > block_idx - docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() - docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() - document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) - document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) - blockmask_any = causal_blockmask_any & document_blockmask_any - blockmask_all = causal_blockmask_all & document_blockmask_all - partial_kv_num_blocks, partial_kv_indices 
= dense_to_ordered(blockmask_any & ~blockmask_all) - full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) - def build_bm(window_size_blocks: Tensor) -> BlockMask: - return BlockMask.from_kv_blocks( - torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), - partial_kv_indices, - torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), - full_kv_indices, - BLOCK_SIZE=BLOCK_SIZE, - mask_mod=document_causal, - ) - # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper - return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) - - def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): - assert input_seq.ndim == 1 + + def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): + assert input_seq.ndim == 2 ve = [value_embed(input_seq) for value_embed in self.value_embeds] # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] assert len(ve) == len(self.blocks) - long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) - block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(block_masks) == len(self.blocks) + long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) - x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 # U-net design by @brendanh0gan skip_connections = [] @@ -752,7 +715,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_bloc for i in range(len(self.blocks)): if i >= n: x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) if i < n: skip_connections.append(x) @@ -760,7 +723,8 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_bloc logits = self.lm_head(x).float() # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), + reduction="sum" if self.training else "mean") return loss # ----------------------------------------------------------------------------- @@ -778,46 +742,104 @@ def _load_data_shard(file: Path): assert nbytes == 2 * num_tokens, "number of tokens read does not match header" return tokens -# find world_size starting indicies, such that each begins with token 50256 and local_batches don't overlap by @classiclarryd -def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): - boundary_mask = tokens[pos : pos + token_window] == 50256 - boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos - start = 
boundary_positions[0].item() - starts = [] - for i in range(1, len(boundary_positions)): - end = boundary_positions[i].item() - if end - start >= seq_len: - starts.append(start) # append start once end pos is confirmed - if len(starts) == dist.get_world_size(): - return starts, end - pos - start = end - assert False # increase token_window if necessary - -def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): - rank = dist.get_rank() - world_size = dist.get_world_size() - batch_size = seq_len * world_size +class EOSBatchFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): + # Precompute EOS positions once per shard + self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 # pointer into eos_idx (start EOS for next step) + self.pos = 0 # logical stream position within this shard + self.world_size = world_size + def seek(self, pos: int): + # Set pointer to the first EOS >= pos + self.i = np.searchsorted(self.eos_idx, pos) + if self.i >= len(self.eos_idx): + raise StopIteration("Seek past last EOS.") + self.pos = pos + def next_batch(self, batch_size_local: int, seq_len: int): + n = len(self.eos_idx) + if self.i >= n: + raise StopIteration("No more EOS in this shard.") + starts = [[] for _ in range(self.world_size)] + idx = self.i + cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 + for r in range(self.world_size): + for _ in range(batch_size_local): + start = cur + 1 + target = start + seq_len # need seq_len tokens before next EOS + j = np.searchsorted(self.eos_idx, target) + if j >= n: + raise StopIteration("Insufficient EOS ahead; hit tail of shard.") + starts[r].append(start) + idx = j + cur = self.eos_idx[idx] # next seq must also start at a new doc + advance = self.eos_idx[idx] - self.pos # move stream to the last end + self.pos += advance + self.i = idx + return starts, advance + + +def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert batch_size % world_size == 0, "Batch size must be divisible by world size" + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training tokens, pos = _load_data_shard(next(file_iter)), 0 + + finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None + if align_to_bos: finder.seek(pos) + while True: - token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len - if pos + token_window + 1 >= len(tokens): - tokens = _load_data_shard(next(file_iter)) - pos = 0 - for _ in range(grad_accum_steps): - if align_to_bos: - batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) - start_idx = batch_starts[rank] - else: - tokens_consumed = batch_size - start_idx = pos + rank 
* seq_len - buf = tokens[start_idx:][:seq_len + 1] - inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; - targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. - pos += tokens_consumed - token_window -= tokens_consumed - yield inputs, targets + batch_size_local = batch_size // world_size + num_tokens_global = batch_size * seq_len + + if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): + tokens, pos = _load_data_shard(next(file_iter)), 0 + + if align_to_bos: + try: + batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) + start_idxs = batch_starts[rank] + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, pos = _load_data_shard(next(file_iter)), 0 + finder = EOSBatchFinder(tokens, world_size=world_size) + continue + + bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] + buf = torch.stack(bufs, dim=0) + _inputs = buf[:, :-1] + _targets = buf[:, 1:] + else: + batch_span = num_tokens_global + start_pos_local = pos + rank * (batch_size_local * seq_len) + end_pos_local = start_pos_local + (batch_size_local * seq_len) + + buf = tokens[start_pos_local: end_pos_local + 1] + + _inputs = buf[:-1].view(batch_size_local, seq_len) + _targets = buf[1:].view(batch_size_local, seq_len) + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) + ) + + pos += batch_span + + if new_params is not None: + # makes it possible for generator to receive new (batch_size, seq_len) via .send() + new_batch_size, new_seq_len = new_params + assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" + batch_size = new_batch_size + seq_len = new_seq_len + # ----------------------------------------------------------------------------- # int main @@ -825,18 +847,23 @@ def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_s @dataclass class Hyperparameters: # data - train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len = 48*1024 # FlexAttention sequence length - val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len: int = 1024 * 2 + train_batch_size: int = 24 * 8 + val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. # optimization - num_iterations = 1695 # number of iterations to run - cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + num_iterations: int = 1695 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate # evaluation and logging - run_id = uuid.uuid4() - val_loss_every = 125 # every how many steps to evaluate val loss?
0 for only at the end - save_checkpoint = False + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + bandwidth: int = 128 + ws_schedule: tuple = (3, 7, 11) + args = Hyperparameters() data_path = os.environ.get("DATA_PATH", ".") @@ -876,13 +903,20 @@ def print0(s, console=False): print0(f"Running Python {sys.version}") print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") print0(f"Running Triton version {triton.__version__}") + def nvidia_smi(): import subprocess # avoid top level import return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout print0(nvidia_smi()) print0("="*100) -model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda() +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_seq_len, args.val_seq_len) +).cuda() for m in model.modules(): if isinstance(m, nn.Embedding): m.bfloat16() @@ -906,26 +940,15 @@ def nvidia_smi(): group["initial_lr"] = group["lr"] # learning rate schedule: stable then decay -def get_lr(step: int): - x = step / args.num_iterations # progress in training +def get_lr_and_ws(step: int): + x = step / (1 + args.num_iterations) # progress in training assert 0 <= x < 1 - if x < 1 - args.cooldown_frac: - return 1.0 - else: + lr = 1.0 + if x >= 1 - args.cooldown_frac: w = (1 - x) / args.cooldown_frac - return w * 1.0 + (1 - w) * 0.1 - -# attention window size schedule: linearly increase -@lru_cache(1) -def get_window_size_blocks_helper(window_size: int): - return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) -def get_window_size_blocks(step: int): - x = step / args.num_iterations # progress in training - assert 0 <= x <= 1 - # Linearly increase the block-wise sliding window size over training 128 -> 1792 - # increase by @fernbear.bsky.social; block-wise by @YouJiacheng - window_size = next_multiple_of_n(1728 * x, n=128) - return get_window_size_blocks_helper(window_size) + lr = w * 1.0 + (1 - w) * 0.1 + ws_idx = int(len(args.ws_schedule) * x) + return lr, args.ws_schedule[ws_idx] model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) @@ -934,13 +957,14 @@ def get_window_size_blocks(step: int): ######################################## # Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 10 +warmup_steps = 60 initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) -for _ in range(warmup_steps): +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +for step in range(warmup_steps): inputs, targets = next(train_loader) - model(inputs, targets, get_window_size_blocks(1)).backward() + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, ws, ws // 2).backward() for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) @@ -953,7 +977,7 @@ def get_window_size_blocks(step: int): # Training and validation # ########################################
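A minimal standalone sketch of the combined learning-rate and window-size schedule introduced in the hunk above, assuming the Hyperparameters defaults from this patch (num_iterations=1695, cooldown_frac=0.45, ws_schedule=(3, 7, 11), bandwidth=128); the helper name lr_ws_at is hypothetical and is not part of the diff:

def lr_ws_at(step: int, num_iterations: int = 1695, cooldown_frac: float = 0.45, ws_schedule: tuple = (3, 7, 11)):
    # mirrors get_lr_and_ws: constant lr, then a linear cooldown toward 0.1x over the final 45% of training
    x = step / (1 + num_iterations)
    lr = 1.0
    if x >= 1 - cooldown_frac:
        w = (1 - x) / cooldown_frac
        lr = w * 1.0 + (1 - w) * 0.1
    # the window size steps through the schedule in equal thirds of training
    return lr, ws_schedule[int(len(ws_schedule) * x)]

# e.g. lr_ws_at(0) == (1.0, 3); lr_ws_at(600) == (1.0, 7); lr_ws_at(1694) gives lr ~= 0.10 with ws 11.
# With bandwidth=128, ws values 3/7/11 correspond to long sliding windows of 384/896/1408 tokens
# (the training loop also passes ws // 2 to the model for the short-window layers).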
-train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) training_time_ms = 0 # start the clock torch.cuda.synchronize() @@ -962,6 +986,7 @@ def get_window_size_blocks(step: int): train_steps = args.num_iterations for step in range(train_steps + 1): last_step = (step == train_steps) + lr, ws = get_lr_and_ws(step) # --------------- VALIDATION SECTION ----------------- if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): @@ -969,15 +994,14 @@ def get_window_size_blocks(step: int): torch.cuda.synchronize() training_time_ms += 1000 * (time.perf_counter() - t0) model.eval() - val_batch_size = world_size * args.val_seq_len - assert args.val_tokens % val_batch_size == 0 - val_steps = args.val_tokens // val_batch_size - val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) + assert args.val_tokens % (world_size * args.val_seq_len) == 0 + val_steps = args.val_tokens // (world_size * args.val_seq_len) + val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) val_loss = 0 with torch.no_grad(): for _ in range(val_steps): inputs, targets = next(val_loader) - val_loss += model(inputs, targets, get_window_size_blocks(step)) + val_loss += model(inputs, targets, ws, ws // 2) val_loss /= val_steps del val_loader dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) @@ -998,11 +1022,11 @@ def get_window_size_blocks(step: int): # --------------- TRAINING SECTION ----------------- for _ in range(grad_accum_steps): inputs, targets = next(train_loader) - model(inputs, targets, get_window_size_blocks(step)).backward() + model(inputs, targets, ws, ws // 2).backward() # set optimization hyperparameters for opt in optimizers: for group in opt.param_groups: - group["lr"] = group["initial_lr"] * get_lr(step) + group["lr"] = group["initial_lr"] * lr for group in optimizer2.param_groups: frac = min(step / 300, 1) # momentum warmup for muon group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 From ad8cdc402c20b27f9a8837e5c9fea4b8e6974061 Mon Sep 17 00:00:00 2001 From: Varun Srivastava <32204417+varunneal@users.noreply.github.com> Date: Wed, 3 Sep 2025 21:39:36 -0400 Subject: [PATCH 04/14] FA3 with flash_attn_varlen_func (#1) See README --- .gitignore | 1 + .../17e712ee-7cf8-44c9-a784-3762e61b174c.txt | 2808 ---------------- .../1d46fee6-b32c-48de-bd61-0a326442ec4e.txt | 2808 ---------------- .../27d1e0d2-df15-41a9-9496-492a21943fb1.txt | 2808 ---------------- .../7a492532-c19b-40dd-958d-fec55aa4d3fd.txt | 2808 ---------------- records/082725_FA3/README.md | 147 - .../ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt | 2808 ---------------- .../bb331245-5e49-4366-b902-6caff64ed8d6.txt | 2808 ---------------- .../be1069a9-64f4-4316-bd26-4a7f5b697509.txt | 2808 ---------------- .../44fc1276-0510-4961-92c0-730c65e5feba.txt | 2814 +++++++++++++++++ .../4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt | 2814 +++++++++++++++++ .../65b0d9c0-3089-40eb-a1bc-45b15f897462.txt | 2814 +++++++++++++++++ .../831dade9-9b29-43ff-9106-80fc680b3e57.txt | 2814 +++++++++++++++++ records/090325_FA3/README.md | 133 + .../ce3400f2-2ca1-4e0e-a784-089451df1913.txt | 2814 +++++++++++++++++ .../d5d05889-69c7-4887-ac9b-baaae1a5f499.txt | 2814 +++++++++++++++++ .../f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt | 2814 +++++++++++++++++ .../media/attn_speed_vs_batch_s1024_ws384.png | Bin 
0 -> 105002 bytes train_gpt.py | 214 +- 19 files changed, 19947 insertions(+), 19902 deletions(-) delete mode 100644 records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt delete mode 100644 records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt delete mode 100644 records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt delete mode 100644 records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt delete mode 100644 records/082725_FA3/README.md delete mode 100644 records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt delete mode 100644 records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt delete mode 100644 records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt create mode 100644 records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt create mode 100644 records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt create mode 100644 records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt create mode 100644 records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt create mode 100644 records/090325_FA3/README.md create mode 100644 records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt create mode 100644 records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt create mode 100644 records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt create mode 100644 records/090325_FA3/media/attn_speed_vs_batch_s1024_ws384.png diff --git a/.gitignore b/.gitignore index c0d3296a0..19c517d6a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ fineweb10B/ pylog124M/ __pycache__/ logs/ +.DS_Store diff --git a/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt b/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt deleted file mode 100644 index b5371a4da..000000000 --- a/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# ----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == 
w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( 
- pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - 
skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 
2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
-
-class DistAdam(torch.optim.Optimizer):
-    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        params = list(params)
-        sizes = {p.shape for p in params}
-        # create one buffer per unique parameter-size
-        param_groups = []
-        for size in sizes:
-            group_params = [p for p in params if p.shape == size]
-            param_groups.append(dict(params=group_params))
-        super().__init__(param_groups, defaults)
-        # DistributedAdam implementation by @vagrawal
-
-    @torch.compile
-    @torch.no_grad()
-    def step(self):
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-        reduce_scatter_futures: list[torch.Future] = []
-        all_gather_futures: list[torch.Future] = []
-        grad_slices = []
-        for group in self.param_groups:
-            params: list[Tensor] = group["params"]
-            for base_i in range(len(params)):
-                grad = params[base_i].grad
-                rank_size = grad.shape[0] // world_size
-                grad_slice = torch.empty_like(grad[:rank_size])
-                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
-                grad_slices.append(grad_slice)
-
-        idx = 0
-        for group in self.param_groups:
-            beta1, beta2 = group['betas']
-            eps = group['eps']
-            wd = group['weight_decay']
-            params = group['params']
-            for base in range(len(params)):
-                reduce_scatter_futures[idx].wait()
-                p = params[base]
-                rank_size = p.shape[0] // world_size
-                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
-                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
-                state = self.state[p]
-                g_slice = grad_slices[idx]
-                # State init
-                if not state:
-                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
-                    state['exp_avg'] = torch.zeros_like(p_slice)
-                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
-                exp_avg = state['exp_avg']
-                exp_avg_sq = state['exp_avg_sq']
-                state['step'] += 1
-                t = state['step']
-                # weight decay
-                if wd != 0:
-                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
-                    p_slice.mul_(1 - eff_weight_decay)
-                # update running averages
-                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
-                # bias corrections
-                bias1 = 1 - beta1 ** t
-                bias2 = 1 - beta2 ** t
-                # compute step
-                denom = exp_avg_sq.sqrt().add_(eps)
-                step_size = lr * (torch.sqrt(bias2) / bias1)
-                update = exp_avg.div(denom).mul_(step_size)
-                p_slice.add_(other=update, alpha=-1.0)
-                idx += 1
-                all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
-        torch.futures.collect_all(all_gather_futures).wait()
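# The sharded update above is textbook bias-corrected Adam; a single-tensor
# sketch (no collectives, hypothetical values matching the Adam
# hyperparameters used further below) showing that the very first step moves
# each coordinate by exactly -lr when the gradient is constant.
import torch

p, g = torch.zeros(4), torch.ones(4)
exp_avg, exp_avg_sq = torch.zeros(4), torch.zeros(4)
beta1, beta2, eps, lr, t = 0.8, 0.95, 1e-10, 0.008, 1

exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)
bias1, bias2 = 1 - beta1 ** t, 1 - beta2 ** t
update = lr * (bias2 ** 0.5 / bias1) * exp_avg / exp_avg_sq.sqrt().add_(eps)
p -= update
print(p)  # tensor([-0.0080, -0.0080, -0.0080, -0.0080])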
-
-# -----------------------------------------------------------------------------
-# PyTorch nn.Module definitions for the model

-def norm(x: Tensor):
-    return F.rms_norm(x, (x.size(-1),))

-class CastedLinear(nn.Linear):
-    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
-        super().__init__(in_features, out_features, bias=False)
-        self.use_fp8 = use_fp8
-        self.x_s = x_s
-        self.w_s = w_s
-        self.grad_s = grad_s
-
-    def reset_parameters(self) -> None:
-        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
-        bound = (3 ** 0.5) * std
-        with torch.no_grad():
-            self.weight.uniform_(-bound, bound)
-
-    def forward(self, x: Tensor):
-        if self.use_fp8 and self.training:
-            _x = x.flatten(0, -2)
-            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
-            return out.reshape(*x.shape[:-1], -1)
-        else:
-            return F.linear(x, self.weight.type_as(x))

-class Rotary(nn.Module):
-    def __init__(self, dim: int, max_seq_len: int):
-        super().__init__()
-        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
-        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
-        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
-        t = torch.arange(max_seq_len, dtype=torch.float32)
-        theta = torch.einsum("i,j -> ij", t, angular_freq)
-        self.cos = nn.Buffer(theta.cos(), persistent=False)
-        self.sin = nn.Buffer(theta.sin(), persistent=False)
-
-    def forward(self, x_BTHD: Tensor):
-        assert self.cos.size(0) >= x_BTHD.size(-3)
-        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
-        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
-        y1 = x1 * cos + x2 * sin
-        y2 = x1 * (-sin) + x2 * cos
-        return torch.cat((y1, y2), 3).type_as(x_BTHD)
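# A standalone sketch (hypothetical sizes) of the half-truncated RoPE above:
# only the first head_dim//4 channel pairs get a nonzero frequency, the
# zero-frequency half passes through unrotated, and the whole map is
# norm-preserving.
import torch

head_dim, seq_len = 128, 16
angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim // 4)
angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim // 4)])
theta = torch.arange(seq_len).float()[:, None] * angular_freq[None, :]  # (T, head_dim//2)
cos, sin = theta.cos(), theta.sin()

x = torch.randn(seq_len, head_dim)
x1, x2 = x.chunk(2, dim=-1)
y = torch.cat([x1 * cos + x2 * sin, -x1 * sin + x2 * cos], dim=-1)
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)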
-
-class CausalSelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
-        super().__init__()
-        self.num_heads = num_heads
-        self.head_dim = head_dim
-        hdim = num_heads * head_dim
-        assert hdim == dim, "num_heads * head_dim must equal model_dim"
-        std = 0.5 * (dim ** -0.5)
-        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
-        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
-        # https://x.com/hi_tysam/status/1879699187107033311
-        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
-        with torch.no_grad():
-            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
-            self.qkvo_w[3].zero_() # init output weights to zero
-        self.rotary = Rotary(head_dim, max_seq_len)
-        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
-        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
-        self.attn_scale = 0.12
-
-        # sparse gated attention to enable context based no-op by @classiclarryd
-        self.attn_gate_dim = 12
-        self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
-        self.attn_gate.weight.detach().zero_()
-
-    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int):
-        B, T = x.size(0), x.size(1) # batch size, sequence length
-
-        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
-        q, k = norm(q), norm(k) # QK norm @Grad62304977
-        q, k = self.rotary(q), self.rotary(k)
-        if ve is not None:
-            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
-        else: # skip mid-layers token value embeddings by @YouJiacheng
-            v = lambdas[0] * v
-
-        y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal
-        y = y.view(B, T, self.num_heads, self.head_dim)
-        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1)
-        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
-        y = F.linear(y, self.qkvo_w[3].type_as(y))
-        return y

-class MLP(nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        hdim = 4 * dim
-        # make both matrices have the same shape because optimizer sorts params by shape
-        # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
-        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
-        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
-        std = 0.5 * (dim ** -0.5)
-        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
-        with torch.no_grad():
-            self.c_fc.uniform_(-bound, bound)
-            self.c_proj.zero_() # zero init suggested by @Grad62304977
-
-    def forward(self, x: Tensor):
-        x = F.linear(x, self.c_fc.T.type_as(x))
-        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
-        x = F.linear(x, self.c_proj.type_as(x))
-        return x

-class Block(nn.Module):
-    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
-        super().__init__()
-        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
-        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
-        self.mlp = MLP(dim)
-
-    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int):
-        x = lambdas[0] * x + lambdas[1] * x0
-        if self.attn is not None:
-            x = x + self.attn(norm(x), ve, sa_lambdas, bm_size)
-        x = x + self.mlp(norm(x))
-        return x
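# The sparse attention gate this patch adds, isolated: a per-head sigmoid
# gate computed from the first attn_gate_dim channels of the (normed)
# residual stream. Zero-initialized weights mean every gate starts at
# sigmoid(0) = 0.5, and training can drive a head toward a context-dependent
# no-op. A standalone sketch with hypothetical shapes (nn.Linear standing in
# for CastedLinear):
import torch

B, T, num_heads, head_dim, gate_dim = 2, 8, 6, 128, 12
x = torch.randn(B, T, num_heads * head_dim)  # pre-attention activations
y = torch.randn(B, T, num_heads, head_dim)   # per-head attention outputs
attn_gate = torch.nn.Linear(gate_dim, num_heads, bias=False)
torch.nn.init.zeros_(attn_gate.weight)

gate = torch.sigmoid(attn_gate(x[..., :gate_dim]))  # (B, T, num_heads)
y_gated = y * gate.view(B, T, num_heads, 1)
assert gate.min() == gate.max() == 0.5  # at init, all heads half-open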
-
-# -----------------------------------------------------------------------------
-# The main model

-def next_multiple_of_n(v: float | int, *, n: int):
-    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

-class GPT(nn.Module):
-    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
-        super().__init__()
-        vocab_size = next_multiple_of_n(vocab_size, n=128)
-        self.embed = nn.Embedding(vocab_size, model_dim)
-        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
-        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
-        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
-        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
-        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
-        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
-        use_fp8 = not os.environ.get("DISABLE_FP8", False)
-        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
-        self.lm_head.weight.detach().zero_() # @Grad62304977
-        # Add learnable skip connection weights for decoder layers
-        assert num_layers % 2 == 0
-        pad = (-num_layers * 5) % dist.get_world_size()
-        self.scalars = nn.Parameter(torch.cat([
-            torch.ones(num_layers), # skip_weights
-            *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas
-            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas
-            torch.ones(pad),
-        ]))
-        # set learning rates
-        for param in self.embed.parameters():
-            param.lr_mul = 75.
-        for param in self.value_embeds.parameters():
-            param.lr_mul = 75.
-        self.lm_head.weight.lr_mul = 1.0
-        self.scalars.lr_mul = 5.0
-
-    def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int):
-        assert input_seq.ndim == 2
-
-        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
-        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
-        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
-        assert len(ve) == len(self.blocks)
-
-        long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth
-        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
-        assert len(bm_sizes) == len(self.blocks)
-
-        x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977
-
-        # U-net design by @brendanh0gan
-        skip_connections = []
-        skip_weights = self.scalars[:(len(self.blocks) // 2)]
-        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
-        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
-
-        n = len(self.blocks) // 2
-
-        for i in range(len(self.blocks)):
-            if i >= n:
-                x = x + skip_weights[i - n] * skip_connections.pop()
-            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i])
-            if i < n:
-                skip_connections.append(x)
-
-        x = norm(x)
-        logits = self.lm_head(x).float()
-        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
-        logits = 30 * torch.sigmoid(logits / 7.5)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1),
-                               reduction="sum" if self.training else "mean")
-        return loss
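# The softcap comment above can be checked numerically: since
# 2*sigmoid(2u) = tanh(u) + 1, taking u = z/15 gives
# 30*sigmoid(z/7.5) = 15*tanh(z/15) + 15, i.e. a tanh cap at +/-15
# shifted into the range (0, 30). A quick sketch:
import torch

z = torch.linspace(-100, 100, steps=1001)
assert torch.allclose(30 * torch.sigmoid(z / 7.5), 15 * torch.tanh(z / 15) + 15, atol=1e-5)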
-
-# -----------------------------------------------------------------------------
-# Distributed data loader

-def _load_data_shard(file: Path):
-    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
-    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
-    assert header[1] == 1, "unsupported version"
-    num_tokens = int(header[2]) # number of tokens (claimed)
-    with file.open("rb", buffering=0) as f:
-        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
-        f.seek(256 * 4)
-        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
-        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
-    return tokens

-class EOSBatchFinder:
-    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
-    def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256):
-        # Precompute EOS positions once per shard
-        self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
-        self.i = 0 # pointer into eos_idx (start EOS for next step)
-        self.pos = 0 # logical stream position within this shard
-        self.world_size = world_size
-    def seek(self, pos: int):
-        # Set pointer to the first EOS >= pos
-        self.i = np.searchsorted(self.eos_idx, pos)
-        if self.i >= len(self.eos_idx):
-            raise StopIteration("Seek past last EOS.")
-        self.pos = pos
-    def next_batch(self, batch_size_local: int, seq_len: int):
-        n = len(self.eos_idx)
-        if self.i >= n:
-            raise StopIteration("No more EOS in this shard.")
-        starts = [[] for _ in range(self.world_size)]
-        idx = self.i
-        cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1
-        for r in range(self.world_size):
-            for _ in range(batch_size_local):
-                start = cur + 1
-                target = start + seq_len # need seq_len tokens before next EOS
-                j = np.searchsorted(self.eos_idx, target)
-                if j >= n:
-                    raise StopIteration("Insufficient EOS ahead; hit tail of shard.")
-                starts[r].append(start)
-                idx = j
-                cur = self.eos_idx[idx] # next seq must also start at a new doc
-        advance = self.eos_idx[idx] - self.pos # move stream to the last end
-        self.pos += advance
-        self.i = idx
-        return starts, advance
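# A toy sketch (synthetic token stream, tiny numbers) of the searchsorted
# logic above: a sequence must start right after an EOS, and the next
# sequence starts after the first EOS at or past start + seq_len.
import numpy as np

eos_id = 50256
tokens = np.array([7, 8, eos_id, 1, 2, 3, 4, eos_id, 5, 6, 7, 8, 9, eos_id])
eos_idx = np.flatnonzero(tokens == eos_id)          # [2, 7, 13]

seq_len = 4
start = eos_idx[0] + 1                              # 3: first token after an EOS
j = np.searchsorted(eos_idx, start + seq_len)       # first EOS at/past position 7
print(start, int(eos_idx[j]))                       # 3 7 -> next start would be 8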
-
-def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True):
-    # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap
-    rank = dist.get_rank() if dist.is_initialized() else 0
-    world_size = dist.get_world_size() if dist.is_initialized() else 1
-    assert batch_size % world_size == 0, "Batch size must be divisible by world size"
-
-    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
-    if not files:
-        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
-
-    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
-    tokens, pos = _load_data_shard(next(file_iter)), 0
-
-    finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None
-    if align_to_bos: finder.seek(pos)
-
-    while True:
-        batch_size_local = batch_size // world_size
-        num_tokens_global = batch_size * seq_len
-
-        if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens):
-            tokens, pos = _load_data_shard(next(file_iter)), 0
-
-        if align_to_bos:
-            try:
-                batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len)
-                start_idxs = batch_starts[rank]
-            except StopIteration:
-                # This shard is exhausted, load the next one in the next loop iteration.
-                tokens, pos = _load_data_shard(next(file_iter)), 0
-                finder = EOSBatchFinder(tokens, world_size=world_size)
-                continue
-
-            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
-            buf = torch.stack(bufs, dim=0)
-            _inputs = buf[:, :-1]
-            _targets = buf[:, 1:]
-        else:
-            batch_span = num_tokens_global
-            start_pos_local = pos + rank * (batch_size_local * seq_len)
-            end_pos_local = start_pos_local + (batch_size_local * seq_len)
-
-            buf = tokens[start_pos_local: end_pos_local + 1]
-
-            _inputs = buf[:-1].view(batch_size_local, seq_len)
-            _targets = buf[1:].view(batch_size_local, seq_len)
-
-        new_params = yield (
-            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
-            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
-        )
-
-        pos += batch_span
-
-        if new_params is not None:
-            # makes it possible for generator to receive new (batch_size, seq_len) via .send()
-            new_batch_size, new_seq_len = new_params
-            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
-            batch_size = new_batch_size
-            seq_len = new_seq_len
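# Because distributed_data_generator is a plain generator, callers can
# reshape batches mid-run by pushing a new (batch_size, seq_len) through
# .send() instead of rebuilding the loader. A hypothetical usage sketch
# (single process, illustrative path and sizes, CUDA assumed):
loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", batch_size=8, seq_len=2048)
inputs, targets = next(loader)            # (8, 2048) int32/int64 CUDA tensors
inputs, targets = loader.send((4, 4096))  # next batch arrives as (4, 4096)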
-
-# -----------------------------------------------------------------------------
-# int main

-@dataclass
-class Hyperparameters:
-    # data
-    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
-    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
-    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
-    train_seq_len: int = 1024 * 2
-    train_batch_size: int = 24 * 8
-    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
-    # optimization
-    num_iterations: int = 1695 # number of iterations to run
-    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
-    # evaluation and logging
-    run_id: str = str(uuid.uuid4())
-    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
-    save_checkpoint: bool = False
-    # attention masking
-    bandwidth: int = 128
-    ws_schedule: tuple = (3, 7, 11)

-args = Hyperparameters()

-data_path = os.environ.get("DATA_PATH", ".")
-args.train_files = os.path.join(data_path, args.train_files)
-args.val_files = os.path.join(data_path, args.val_files)

-# torchrun sets these env variables
-rank = int(os.environ["RANK"])
-world_size = int(os.environ["WORLD_SIZE"])
-assert 8 % world_size == 0, "world_size must be a divisor of 8"
-grad_accum_steps = 8 // world_size
-assert torch.cuda.is_available()
-device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
-torch.cuda.set_device(device)
-dist.init_process_group(backend="nccl", device_id=device)
-dist.barrier()
-master_process = (rank == 0) # this process will do logging, checkpointing etc.

-# begin logging
-logfile = None
-if master_process:
-    run_id = args.run_id
-    os.makedirs("logs", exist_ok=True)
-    logfile = f"logs/{run_id}.txt"
-    print(logfile)
-def print0(s, console=False):
-    if master_process:
-        with open(logfile, "a") as f:
-            if console:
-                print(s)
-            print(s, file=f)

-# begin by printing this file (the Python code)
-print0(code)
-print0("="*100)
-# log information about the hardware/software environment this is running on
-print0(f"Running Python {sys.version}")
-print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
-print0(f"Running Triton version {triton.__version__}")

-def nvidia_smi():
-    import subprocess # avoid top level import
-    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
-print0(nvidia_smi())
-print0("="*100)

-model: nn.Module = GPT(
-    vocab_size=50257,
-    num_layers=12,
-    num_heads=6,
-    model_dim=768,
-    max_seq_len=max(args.train_seq_len, args.val_seq_len)
-).cuda()
-for m in model.modules():
-    if isinstance(m, nn.Embedding):
-        m.bfloat16()
-for param in model.parameters():
-    dist.broadcast(param.detach(), 0)

-# collect the parameters to optimize
-hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
-embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-scalar_params = [p for p in model.parameters() if p.ndim < 2]
-head_params = [model.lm_head.weight]

-# init the optimizer(s)
-# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
-# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
-optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
-optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
-optimizers = [optimizer1, optimizer2]
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["initial_lr"] = group["lr"]

-# learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
-    assert 0 <= x < 1
-    lr = 1.0
-    if x >= 1 - args.cooldown_frac:
-        w = (1 - x) / args.cooldown_frac
-        lr = w * 1.0 + (1 - w) * 0.1
-    ws_idx = int(len(args.ws_schedule) * x)
-    return lr, args.ws_schedule[ws_idx]

-model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)

-########################################
-#            Warmup kernels            #
-########################################

-# Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
-initial_state = dict(model=copy.deepcopy(model.state_dict()),
-                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
-    model(inputs, targets, ws, ws // 2).backward()
-    for opt in optimizers:
-        opt.step()
-    model.zero_grad(set_to_none=True)
-model.load_state_dict(initial_state["model"])
-for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
-    opt.load_state_dict(opt_state)
-del train_loader, initial_state
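# The schedule above holds lr at the base value, then decays linearly to
# 0.1x over the final cooldown_frac of training, while the window-size
# multiplier steps 3 -> 7 -> 11 at one-third and two-thirds of training.
# A quick sketch probing a few steps (defaults from Hyperparameters above):
for probe in [0, 600, 1000, 1400, 1694]:
    lr_frac, ws_mul = get_lr_and_ws(probe)
    print(probe, round(lr_frac, 3), ws_mul)
# lr_frac is 1.0 through ~step 932 (x < 0.55), ~0.1 by the final step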
-
-########################################
-#        Training and validation       #
-########################################

-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-training_time_ms = 0
-# start the clock
-torch.cuda.synchronize()
-t0 = time.perf_counter()
-# begin training
-train_steps = args.num_iterations
-for step in range(train_steps + 1):
-    last_step = (step == train_steps)
-    lr, ws = get_lr_and_ws(step)
-
-    # --------------- VALIDATION SECTION -----------------
-    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
-        # stop the clock
-        torch.cuda.synchronize()
-        training_time_ms += 1000 * (time.perf_counter() - t0)
-        model.eval()
-        assert args.val_tokens % (world_size * args.val_seq_len) == 0
-        val_steps = args.val_tokens // (world_size * args.val_seq_len)
-        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
-        val_loss = 0
-        with torch.no_grad():
-            for _ in range(val_steps):
-                inputs, targets = next(val_loader)
-                val_loss += model(inputs, targets, ws, ws // 2)
-        val_loss /= val_steps
-        del val_loader
-        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
-        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
-        model.train()
-        # start the clock again
-        torch.cuda.synchronize()
-        t0 = time.perf_counter()
-
-    if last_step:
-        if master_process and args.save_checkpoint:
-            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
-            os.makedirs(f"logs/{run_id}", exist_ok=True)
-            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
-        # the last step only has the validation loop, so break to avoid training
-        break
-
-    # --------------- TRAINING SECTION -----------------
-    for _ in range(grad_accum_steps):
-        inputs, targets = next(train_loader)
-        model(inputs, targets, ws, ws // 2).backward()
-    # set optimization hyperparameters
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lr
-    for group in optimizer2.param_groups:
-        frac = min(step / 300, 1) # momentum warmup for muon
-        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
-    # step the optimizers
-    for opt in optimizers:
-        opt.step()
-    # null the gradients
-    model.zero_grad(set_to_none=True)
-    # logging
-    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
-    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)

-print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
-       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
-dist.destroy_process_group()
-====================================================================================================
-Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
-Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
-Running Triton version 3.4.0
-Wed Aug 27 04:15:50 2025
-+---------------------------------------------------------------------------------------+
-| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.6     |
-|-----------------------------------------+----------------------+----------------------+
-| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
-| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
-|                                         |                      |               MIG M. |
-|=========================================+======================+======================|
-|   0  NVIDIA H100 80GB HBM3          On  | 00000000:00:0B.0 Off |                  Off |
-| N/A   30C    P0             115W / 700W |   5858MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   1  NVIDIA H100 80GB HBM3          On  | 00000000:00:0C.0 Off |                  Off |
-| N/A   32C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   2  NVIDIA H100 80GB HBM3          On  | 00000000:00:0D.0 Off |                  Off |
-| N/A   33C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   3  NVIDIA H100 80GB HBM3          On  | 00000000:00:0E.0 Off |                  Off |
-| N/A   30C    P0             113W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   4  NVIDIA H100 80GB HBM3          On  | 00000000:00:0F.0 Off |                  Off |
-| N/A   30C    P0             110W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   5  NVIDIA H100 80GB HBM3          On  | 00000000:00:10.0 Off |                  Off |
-| N/A   34C    P0             116W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   6  NVIDIA H100 80GB HBM3          On  | 00000000:00:11.0 Off |                  Off |
-| N/A   32C    P0             111W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   7  NVIDIA H100 80GB HBM3          On  | 00000000:00:12.0 Off |                  Off |
-| N/A   31C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+

-+---------------------------------------------------------------------------------------+
-| Processes:                                                                            |
-|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
-|        ID   ID                                                             Usage      |
-|=======================================================================================|
-+---------------------------------------------------------------------------------------+

-====================================================================================================
-step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms
-step:1/1695 train_time:524ms step_avg:524.12ms
-step:2/1695 train_time:549ms step_avg:274.51ms
-step:3/1695 train_time:617ms step_avg:205.57ms
-step:4/1695 train_time:709ms step_avg:177.26ms
-step:5/1695 train_time:803ms step_avg:160.52ms
-step:6/1695 train_time:897ms step_avg:149.44ms
-step:7/1695 train_time:989ms step_avg:141.33ms
-step:8/1695 train_time:1082ms step_avg:135.31ms
-step:9/1695 train_time:1177ms step_avg:130.74ms
-step:10/1695 train_time:1270ms step_avg:127.00ms
-step:11/1695 train_time:1364ms step_avg:123.98ms
-step:12/1695 train_time:1460ms step_avg:121.70ms
-step:13/1695 train_time:1558ms step_avg:119.84ms
-step:14/1695 train_time:1653ms step_avg:118.08ms
-step:15/1695 train_time:1748ms step_avg:116.53ms
-step:16/1695 train_time:1843ms step_avg:115.16ms
-step:17/1695 train_time:1937ms step_avg:113.93ms
-step:18/1695 train_time:2030ms step_avg:112.80ms
-step:19/1695 train_time:2124ms step_avg:111.79ms
-step:20/1695 train_time:2218ms step_avg:110.90ms
-step:21/1695 train_time:2312ms step_avg:110.11ms
-step:22/1695 train_time:2407ms step_avg:109.43ms
-step:23/1695 train_time:2504ms step_avg:108.87ms
[... steps 24-124: per-step train_time/step_avg lines, step_avg settling from ~108ms to ~97ms ...]
-step:125/1695 val_loss:4.3113 train_time:12248ms step_avg:97.99ms
[... steps 126-249: per-step train_time/step_avg lines, step_avg ~97ms ...]
-step:250/1695 val_loss:3.9807 train_time:24284ms step_avg:97.14ms
[... steps 251-374: per-step train_time/step_avg lines, step_avg ~96.4-96.9ms ...]
-step:375/1695 val_loss:3.8148 train_time:36291ms step_avg:96.78ms
[... steps 376-499: per-step train_time/step_avg lines, step_avg ~96.0-96.6ms ...]
-step:500/1695 val_loss:3.7151 train_time:48060ms step_avg:96.12ms
[... steps 501-624: per-step train_time/step_avg lines, step_avg ~95.9-96.5ms ...]
-step:625/1695 val_loss:3.6179 train_time:60276ms step_avg:96.44ms
[... steps 626-644: per-step train_time/step_avg lines, step_avg ~96.3ms ...]
-step:645/1695 train_time:62106ms
step_avg:96.29ms -step:646/1695 train_time:62201ms step_avg:96.29ms -step:647/1695 train_time:62297ms step_avg:96.29ms -step:648/1695 train_time:62394ms step_avg:96.29ms -step:649/1695 train_time:62492ms step_avg:96.29ms -step:650/1695 train_time:62589ms step_avg:96.29ms -step:651/1695 train_time:62685ms step_avg:96.29ms -step:652/1695 train_time:62781ms step_avg:96.29ms -step:653/1695 train_time:62876ms step_avg:96.29ms -step:654/1695 train_time:62971ms step_avg:96.29ms -step:655/1695 train_time:63067ms step_avg:96.29ms -step:656/1695 train_time:63164ms step_avg:96.29ms -step:657/1695 train_time:63261ms step_avg:96.29ms -step:658/1695 train_time:63357ms step_avg:96.29ms -step:659/1695 train_time:63453ms step_avg:96.29ms -step:660/1695 train_time:63550ms step_avg:96.29ms -step:661/1695 train_time:63647ms step_avg:96.29ms -step:662/1695 train_time:63743ms step_avg:96.29ms -step:663/1695 train_time:63838ms step_avg:96.29ms -step:664/1695 train_time:63933ms step_avg:96.28ms -step:665/1695 train_time:64029ms step_avg:96.28ms -step:666/1695 train_time:64126ms step_avg:96.28ms -step:667/1695 train_time:64222ms step_avg:96.29ms -step:668/1695 train_time:64318ms step_avg:96.28ms -step:669/1695 train_time:64413ms step_avg:96.28ms -step:670/1695 train_time:64511ms step_avg:96.28ms -step:671/1695 train_time:64608ms step_avg:96.29ms -step:672/1695 train_time:64705ms step_avg:96.29ms -step:673/1695 train_time:64801ms step_avg:96.29ms -step:674/1695 train_time:64896ms step_avg:96.28ms -step:675/1695 train_time:64993ms step_avg:96.29ms -step:676/1695 train_time:65089ms step_avg:96.29ms -step:677/1695 train_time:65186ms step_avg:96.29ms -step:678/1695 train_time:65283ms step_avg:96.29ms -step:679/1695 train_time:65378ms step_avg:96.29ms -step:680/1695 train_time:65474ms step_avg:96.29ms -step:681/1695 train_time:65571ms step_avg:96.29ms -step:682/1695 train_time:65667ms step_avg:96.29ms -step:683/1695 train_time:65763ms step_avg:96.29ms -step:684/1695 train_time:65859ms step_avg:96.28ms -step:685/1695 train_time:65954ms step_avg:96.28ms -step:686/1695 train_time:66050ms step_avg:96.28ms -step:687/1695 train_time:66146ms step_avg:96.28ms -step:688/1695 train_time:66242ms step_avg:96.28ms -step:689/1695 train_time:66338ms step_avg:96.28ms -step:690/1695 train_time:66433ms step_avg:96.28ms -step:691/1695 train_time:66794ms step_avg:96.66ms -step:692/1695 train_time:66958ms step_avg:96.76ms -step:693/1695 train_time:67053ms step_avg:96.76ms -step:694/1695 train_time:67148ms step_avg:96.76ms -step:695/1695 train_time:67244ms step_avg:96.75ms -step:696/1695 train_time:67338ms step_avg:96.75ms -step:697/1695 train_time:67433ms step_avg:96.75ms -step:698/1695 train_time:67528ms step_avg:96.75ms -step:699/1695 train_time:67624ms step_avg:96.74ms -step:700/1695 train_time:67718ms step_avg:96.74ms -step:701/1695 train_time:67820ms step_avg:96.75ms -step:702/1695 train_time:67919ms step_avg:96.75ms -step:703/1695 train_time:68017ms step_avg:96.75ms -step:704/1695 train_time:68115ms step_avg:96.75ms -step:705/1695 train_time:68211ms step_avg:96.75ms -step:706/1695 train_time:68306ms step_avg:96.75ms -step:707/1695 train_time:68400ms step_avg:96.75ms -step:708/1695 train_time:68495ms step_avg:96.74ms -step:709/1695 train_time:68590ms step_avg:96.74ms -step:710/1695 train_time:68686ms step_avg:96.74ms -step:711/1695 train_time:68784ms step_avg:96.74ms -step:712/1695 train_time:68882ms step_avg:96.74ms -step:713/1695 train_time:68978ms step_avg:96.74ms -step:714/1695 train_time:69074ms step_avg:96.74ms -step:715/1695 
train_time:69170ms step_avg:96.74ms -step:716/1695 train_time:69266ms step_avg:96.74ms -step:717/1695 train_time:69361ms step_avg:96.74ms -step:718/1695 train_time:69455ms step_avg:96.73ms -step:719/1695 train_time:69551ms step_avg:96.73ms -step:720/1695 train_time:69646ms step_avg:96.73ms -step:721/1695 train_time:69742ms step_avg:96.73ms -step:722/1695 train_time:69838ms step_avg:96.73ms -step:723/1695 train_time:69935ms step_avg:96.73ms -step:724/1695 train_time:70032ms step_avg:96.73ms -step:725/1695 train_time:70129ms step_avg:96.73ms -step:726/1695 train_time:70225ms step_avg:96.73ms -step:727/1695 train_time:70322ms step_avg:96.73ms -step:728/1695 train_time:70417ms step_avg:96.73ms -step:729/1695 train_time:70512ms step_avg:96.72ms -step:730/1695 train_time:70608ms step_avg:96.72ms -step:731/1695 train_time:70704ms step_avg:96.72ms -step:732/1695 train_time:70801ms step_avg:96.72ms -step:733/1695 train_time:70896ms step_avg:96.72ms -step:734/1695 train_time:70993ms step_avg:96.72ms -step:735/1695 train_time:71090ms step_avg:96.72ms -step:736/1695 train_time:71187ms step_avg:96.72ms -step:737/1695 train_time:71284ms step_avg:96.72ms -step:738/1695 train_time:71380ms step_avg:96.72ms -step:739/1695 train_time:71475ms step_avg:96.72ms -step:740/1695 train_time:71570ms step_avg:96.72ms -step:741/1695 train_time:71667ms step_avg:96.72ms -step:742/1695 train_time:71765ms step_avg:96.72ms -step:743/1695 train_time:71861ms step_avg:96.72ms -step:744/1695 train_time:71957ms step_avg:96.72ms -step:745/1695 train_time:72053ms step_avg:96.71ms -step:746/1695 train_time:72150ms step_avg:96.72ms -step:747/1695 train_time:72247ms step_avg:96.72ms -step:748/1695 train_time:72343ms step_avg:96.72ms -step:749/1695 train_time:72439ms step_avg:96.71ms -step:750/1695 train_time:72534ms step_avg:96.71ms -step:750/1695 val_loss:3.5645 train_time:72628ms step_avg:96.84ms -step:751/1695 train_time:72654ms step_avg:96.74ms -step:752/1695 train_time:72734ms step_avg:96.72ms -step:753/1695 train_time:72834ms step_avg:96.72ms -step:754/1695 train_time:72930ms step_avg:96.72ms -step:755/1695 train_time:73026ms step_avg:96.72ms -step:756/1695 train_time:73121ms step_avg:96.72ms -step:757/1695 train_time:73215ms step_avg:96.72ms -step:758/1695 train_time:73311ms step_avg:96.72ms -step:759/1695 train_time:73407ms step_avg:96.72ms -step:760/1695 train_time:73501ms step_avg:96.71ms -step:761/1695 train_time:73598ms step_avg:96.71ms -step:762/1695 train_time:73696ms step_avg:96.71ms -step:763/1695 train_time:73795ms step_avg:96.72ms -step:764/1695 train_time:73893ms step_avg:96.72ms -step:765/1695 train_time:73990ms step_avg:96.72ms -step:766/1695 train_time:74086ms step_avg:96.72ms -step:767/1695 train_time:74182ms step_avg:96.72ms -step:768/1695 train_time:74276ms step_avg:96.71ms -step:769/1695 train_time:74372ms step_avg:96.71ms -step:770/1695 train_time:74468ms step_avg:96.71ms -step:771/1695 train_time:74564ms step_avg:96.71ms -step:772/1695 train_time:74661ms step_avg:96.71ms -step:773/1695 train_time:74758ms step_avg:96.71ms -step:774/1695 train_time:74854ms step_avg:96.71ms -step:775/1695 train_time:74951ms step_avg:96.71ms -step:776/1695 train_time:75048ms step_avg:96.71ms -step:777/1695 train_time:75143ms step_avg:96.71ms -step:778/1695 train_time:75238ms step_avg:96.71ms -step:779/1695 train_time:75333ms step_avg:96.71ms -step:780/1695 train_time:75429ms step_avg:96.70ms -step:781/1695 train_time:75526ms step_avg:96.70ms -step:782/1695 train_time:75622ms step_avg:96.70ms -step:783/1695 train_time:75718ms 
step_avg:96.70ms -step:784/1695 train_time:75815ms step_avg:96.70ms -step:785/1695 train_time:75911ms step_avg:96.70ms -step:786/1695 train_time:76008ms step_avg:96.70ms -step:787/1695 train_time:76104ms step_avg:96.70ms -step:788/1695 train_time:76199ms step_avg:96.70ms -step:789/1695 train_time:76295ms step_avg:96.70ms -step:790/1695 train_time:76390ms step_avg:96.70ms -step:791/1695 train_time:76486ms step_avg:96.70ms -step:792/1695 train_time:76582ms step_avg:96.69ms -step:793/1695 train_time:76678ms step_avg:96.69ms -step:794/1695 train_time:76774ms step_avg:96.69ms -step:795/1695 train_time:76871ms step_avg:96.69ms -step:796/1695 train_time:76969ms step_avg:96.69ms -step:797/1695 train_time:77067ms step_avg:96.70ms -step:798/1695 train_time:77163ms step_avg:96.70ms -step:799/1695 train_time:77259ms step_avg:96.69ms -step:800/1695 train_time:77355ms step_avg:96.69ms -step:801/1695 train_time:77450ms step_avg:96.69ms -step:802/1695 train_time:77546ms step_avg:96.69ms -step:803/1695 train_time:77642ms step_avg:96.69ms -step:804/1695 train_time:77738ms step_avg:96.69ms -step:805/1695 train_time:77834ms step_avg:96.69ms -step:806/1695 train_time:77931ms step_avg:96.69ms -step:807/1695 train_time:78029ms step_avg:96.69ms -step:808/1695 train_time:78126ms step_avg:96.69ms -step:809/1695 train_time:78222ms step_avg:96.69ms -step:810/1695 train_time:78317ms step_avg:96.69ms -step:811/1695 train_time:78413ms step_avg:96.69ms -step:812/1695 train_time:78509ms step_avg:96.69ms -step:813/1695 train_time:78604ms step_avg:96.68ms -step:814/1695 train_time:78699ms step_avg:96.68ms -step:815/1695 train_time:78795ms step_avg:96.68ms -step:816/1695 train_time:78891ms step_avg:96.68ms -step:817/1695 train_time:78988ms step_avg:96.68ms -step:818/1695 train_time:79085ms step_avg:96.68ms -step:819/1695 train_time:79182ms step_avg:96.68ms -step:820/1695 train_time:79278ms step_avg:96.68ms -step:821/1695 train_time:79373ms step_avg:96.68ms -step:822/1695 train_time:79469ms step_avg:96.68ms -step:823/1695 train_time:79565ms step_avg:96.68ms -step:824/1695 train_time:79660ms step_avg:96.67ms -step:825/1695 train_time:79755ms step_avg:96.67ms -step:826/1695 train_time:79852ms step_avg:96.67ms -step:827/1695 train_time:79949ms step_avg:96.67ms -step:828/1695 train_time:80045ms step_avg:96.67ms -step:829/1695 train_time:80142ms step_avg:96.67ms -step:830/1695 train_time:80237ms step_avg:96.67ms -step:831/1695 train_time:80333ms step_avg:96.67ms -step:832/1695 train_time:80430ms step_avg:96.67ms -step:833/1695 train_time:80527ms step_avg:96.67ms -step:834/1695 train_time:80624ms step_avg:96.67ms -step:835/1695 train_time:80719ms step_avg:96.67ms -step:836/1695 train_time:80815ms step_avg:96.67ms -step:837/1695 train_time:80911ms step_avg:96.67ms -step:838/1695 train_time:81007ms step_avg:96.67ms -step:839/1695 train_time:81103ms step_avg:96.67ms -step:840/1695 train_time:81199ms step_avg:96.67ms -step:841/1695 train_time:81295ms step_avg:96.66ms -step:842/1695 train_time:81392ms step_avg:96.66ms -step:843/1695 train_time:81488ms step_avg:96.66ms -step:844/1695 train_time:81583ms step_avg:96.66ms -step:845/1695 train_time:81678ms step_avg:96.66ms -step:846/1695 train_time:81773ms step_avg:96.66ms -step:847/1695 train_time:81869ms step_avg:96.66ms -step:848/1695 train_time:81965ms step_avg:96.66ms -step:849/1695 train_time:82062ms step_avg:96.66ms -step:850/1695 train_time:82158ms step_avg:96.66ms -step:851/1695 train_time:82254ms step_avg:96.66ms -step:852/1695 train_time:82350ms step_avg:96.66ms -step:853/1695 
train_time:82447ms step_avg:96.66ms -step:854/1695 train_time:82542ms step_avg:96.65ms -step:855/1695 train_time:82637ms step_avg:96.65ms -step:856/1695 train_time:82733ms step_avg:96.65ms -step:857/1695 train_time:82829ms step_avg:96.65ms -step:858/1695 train_time:82925ms step_avg:96.65ms -step:859/1695 train_time:83021ms step_avg:96.65ms -step:860/1695 train_time:83116ms step_avg:96.65ms -step:861/1695 train_time:83212ms step_avg:96.65ms -step:862/1695 train_time:83309ms step_avg:96.65ms -step:863/1695 train_time:83635ms step_avg:96.91ms -step:864/1695 train_time:83834ms step_avg:97.03ms -step:865/1695 train_time:83928ms step_avg:97.03ms -step:866/1695 train_time:84024ms step_avg:97.03ms -step:867/1695 train_time:84118ms step_avg:97.02ms -step:868/1695 train_time:84213ms step_avg:97.02ms -step:869/1695 train_time:84308ms step_avg:97.02ms -step:870/1695 train_time:84404ms step_avg:97.02ms -step:871/1695 train_time:84498ms step_avg:97.01ms -step:872/1695 train_time:84593ms step_avg:97.01ms -step:873/1695 train_time:84695ms step_avg:97.02ms -step:874/1695 train_time:84795ms step_avg:97.02ms -step:875/1695 train_time:84893ms step_avg:97.02ms -step:875/1695 val_loss:3.5224 train_time:84986ms step_avg:97.13ms -step:876/1695 train_time:85012ms step_avg:97.05ms -step:877/1695 train_time:85093ms step_avg:97.03ms -step:878/1695 train_time:85195ms step_avg:97.03ms -step:879/1695 train_time:85293ms step_avg:97.03ms -step:880/1695 train_time:85390ms step_avg:97.03ms -step:881/1695 train_time:85485ms step_avg:97.03ms -step:882/1695 train_time:85580ms step_avg:97.03ms -step:883/1695 train_time:85675ms step_avg:97.03ms -step:884/1695 train_time:85769ms step_avg:97.02ms -step:885/1695 train_time:85864ms step_avg:97.02ms -step:886/1695 train_time:85959ms step_avg:97.02ms -step:887/1695 train_time:86057ms step_avg:97.02ms -step:888/1695 train_time:86155ms step_avg:97.02ms -step:889/1695 train_time:86253ms step_avg:97.02ms -step:890/1695 train_time:86350ms step_avg:97.02ms -step:891/1695 train_time:86447ms step_avg:97.02ms -step:892/1695 train_time:86542ms step_avg:97.02ms -step:893/1695 train_time:86637ms step_avg:97.02ms -step:894/1695 train_time:86732ms step_avg:97.02ms -step:895/1695 train_time:86827ms step_avg:97.01ms -step:896/1695 train_time:86922ms step_avg:97.01ms -step:897/1695 train_time:87018ms step_avg:97.01ms -step:898/1695 train_time:87115ms step_avg:97.01ms -step:899/1695 train_time:87213ms step_avg:97.01ms -step:900/1695 train_time:87311ms step_avg:97.01ms -step:901/1695 train_time:87409ms step_avg:97.01ms -step:902/1695 train_time:87506ms step_avg:97.01ms -step:903/1695 train_time:87601ms step_avg:97.01ms -step:904/1695 train_time:87696ms step_avg:97.01ms -step:905/1695 train_time:87792ms step_avg:97.01ms -step:906/1695 train_time:87887ms step_avg:97.01ms -step:907/1695 train_time:87983ms step_avg:97.00ms -step:908/1695 train_time:88079ms step_avg:97.00ms -step:909/1695 train_time:88175ms step_avg:97.00ms -step:910/1695 train_time:88271ms step_avg:97.00ms -step:911/1695 train_time:88367ms step_avg:97.00ms -step:912/1695 train_time:88463ms step_avg:97.00ms -step:913/1695 train_time:88558ms step_avg:97.00ms -step:914/1695 train_time:88654ms step_avg:97.00ms -step:915/1695 train_time:88750ms step_avg:96.99ms -step:916/1695 train_time:88846ms step_avg:96.99ms -step:917/1695 train_time:88942ms step_avg:96.99ms -step:918/1695 train_time:89038ms step_avg:96.99ms -step:919/1695 train_time:89134ms step_avg:96.99ms -step:920/1695 train_time:89231ms step_avg:96.99ms -step:921/1695 train_time:89327ms 
step_avg:96.99ms -step:922/1695 train_time:89422ms step_avg:96.99ms -step:923/1695 train_time:89518ms step_avg:96.99ms -step:924/1695 train_time:89614ms step_avg:96.98ms -step:925/1695 train_time:89709ms step_avg:96.98ms -step:926/1695 train_time:89806ms step_avg:96.98ms -step:927/1695 train_time:89901ms step_avg:96.98ms -step:928/1695 train_time:89997ms step_avg:96.98ms -step:929/1695 train_time:90093ms step_avg:96.98ms -step:930/1695 train_time:90190ms step_avg:96.98ms -step:931/1695 train_time:90287ms step_avg:96.98ms -step:932/1695 train_time:90384ms step_avg:96.98ms -step:933/1695 train_time:90479ms step_avg:96.98ms -step:934/1695 train_time:90575ms step_avg:96.98ms -step:935/1695 train_time:90671ms step_avg:96.97ms -step:936/1695 train_time:90767ms step_avg:96.97ms -step:937/1695 train_time:90863ms step_avg:96.97ms -step:938/1695 train_time:90958ms step_avg:96.97ms -step:939/1695 train_time:91055ms step_avg:96.97ms -step:940/1695 train_time:91150ms step_avg:96.97ms -step:941/1695 train_time:91246ms step_avg:96.97ms -step:942/1695 train_time:91342ms step_avg:96.97ms -step:943/1695 train_time:91438ms step_avg:96.96ms -step:944/1695 train_time:91533ms step_avg:96.96ms -step:945/1695 train_time:91630ms step_avg:96.96ms -step:946/1695 train_time:91727ms step_avg:96.96ms -step:947/1695 train_time:91824ms step_avg:96.96ms -step:948/1695 train_time:91920ms step_avg:96.96ms -step:949/1695 train_time:92015ms step_avg:96.96ms -step:950/1695 train_time:92112ms step_avg:96.96ms -step:951/1695 train_time:92208ms step_avg:96.96ms -step:952/1695 train_time:92305ms step_avg:96.96ms -step:953/1695 train_time:92401ms step_avg:96.96ms -step:954/1695 train_time:92496ms step_avg:96.96ms -step:955/1695 train_time:92592ms step_avg:96.96ms -step:956/1695 train_time:92689ms step_avg:96.95ms -step:957/1695 train_time:92785ms step_avg:96.95ms -step:958/1695 train_time:92881ms step_avg:96.95ms -step:959/1695 train_time:92976ms step_avg:96.95ms -step:960/1695 train_time:93072ms step_avg:96.95ms -step:961/1695 train_time:93168ms step_avg:96.95ms -step:962/1695 train_time:93265ms step_avg:96.95ms -step:963/1695 train_time:93362ms step_avg:96.95ms -step:964/1695 train_time:93458ms step_avg:96.95ms -step:965/1695 train_time:93555ms step_avg:96.95ms -step:966/1695 train_time:93652ms step_avg:96.95ms -step:967/1695 train_time:93748ms step_avg:96.95ms -step:968/1695 train_time:93844ms step_avg:96.95ms -step:969/1695 train_time:93939ms step_avg:96.94ms -step:970/1695 train_time:94035ms step_avg:96.94ms -step:971/1695 train_time:94132ms step_avg:96.94ms -step:972/1695 train_time:94229ms step_avg:96.94ms -step:973/1695 train_time:94326ms step_avg:96.94ms -step:974/1695 train_time:94422ms step_avg:96.94ms -step:975/1695 train_time:94517ms step_avg:96.94ms -step:976/1695 train_time:94614ms step_avg:96.94ms -step:977/1695 train_time:94712ms step_avg:96.94ms -step:978/1695 train_time:94808ms step_avg:96.94ms -step:979/1695 train_time:94904ms step_avg:96.94ms -step:980/1695 train_time:95000ms step_avg:96.94ms -step:981/1695 train_time:95095ms step_avg:96.94ms -step:982/1695 train_time:95192ms step_avg:96.94ms -step:983/1695 train_time:95289ms step_avg:96.94ms -step:984/1695 train_time:95386ms step_avg:96.94ms -step:985/1695 train_time:95482ms step_avg:96.94ms -step:986/1695 train_time:95577ms step_avg:96.93ms -step:987/1695 train_time:95673ms step_avg:96.93ms -step:988/1695 train_time:95770ms step_avg:96.93ms -step:989/1695 train_time:95866ms step_avg:96.93ms -step:990/1695 train_time:95963ms step_avg:96.93ms -step:991/1695 
train_time:96059ms step_avg:96.93ms -step:992/1695 train_time:96154ms step_avg:96.93ms -step:993/1695 train_time:96251ms step_avg:96.93ms -step:994/1695 train_time:96348ms step_avg:96.93ms -step:995/1695 train_time:96445ms step_avg:96.93ms -step:996/1695 train_time:96539ms step_avg:96.93ms -step:997/1695 train_time:96634ms step_avg:96.93ms -step:998/1695 train_time:96732ms step_avg:96.93ms -step:999/1695 train_time:96829ms step_avg:96.93ms -step:1000/1695 train_time:96926ms step_avg:96.93ms -step:1000/1695 val_loss:3.4843 train_time:97020ms step_avg:97.02ms -step:1001/1695 train_time:97046ms step_avg:96.95ms -step:1002/1695 train_time:97126ms step_avg:96.93ms -step:1003/1695 train_time:97224ms step_avg:96.93ms -step:1004/1695 train_time:97319ms step_avg:96.93ms -step:1005/1695 train_time:97415ms step_avg:96.93ms -step:1006/1695 train_time:97511ms step_avg:96.93ms -step:1007/1695 train_time:97607ms step_avg:96.93ms -step:1008/1695 train_time:97701ms step_avg:96.93ms -step:1009/1695 train_time:97796ms step_avg:96.92ms -step:1010/1695 train_time:97892ms step_avg:96.92ms -step:1011/1695 train_time:97989ms step_avg:96.92ms -step:1012/1695 train_time:98087ms step_avg:96.92ms -step:1013/1695 train_time:98185ms step_avg:96.93ms -step:1014/1695 train_time:98281ms step_avg:96.92ms -step:1015/1695 train_time:98377ms step_avg:96.92ms -step:1016/1695 train_time:98474ms step_avg:96.92ms -step:1017/1695 train_time:98570ms step_avg:96.92ms -step:1018/1695 train_time:98665ms step_avg:96.92ms -step:1019/1695 train_time:98760ms step_avg:96.92ms -step:1020/1695 train_time:98855ms step_avg:96.92ms -step:1021/1695 train_time:98951ms step_avg:96.92ms -step:1022/1695 train_time:99049ms step_avg:96.92ms -step:1023/1695 train_time:99145ms step_avg:96.92ms -step:1024/1695 train_time:99241ms step_avg:96.92ms -step:1025/1695 train_time:99337ms step_avg:96.91ms -step:1026/1695 train_time:99434ms step_avg:96.91ms -step:1027/1695 train_time:99531ms step_avg:96.91ms -step:1028/1695 train_time:99627ms step_avg:96.91ms -step:1029/1695 train_time:99724ms step_avg:96.91ms -step:1030/1695 train_time:99817ms step_avg:96.91ms -step:1031/1695 train_time:99913ms step_avg:96.91ms -step:1032/1695 train_time:100009ms step_avg:96.91ms -step:1033/1695 train_time:100105ms step_avg:96.91ms -step:1034/1695 train_time:100202ms step_avg:96.91ms -step:1035/1695 train_time:100298ms step_avg:96.91ms -step:1036/1695 train_time:100628ms step_avg:97.13ms -step:1037/1695 train_time:100810ms step_avg:97.21ms -step:1038/1695 train_time:100904ms step_avg:97.21ms -step:1039/1695 train_time:100998ms step_avg:97.21ms -step:1040/1695 train_time:101093ms step_avg:97.20ms -step:1041/1695 train_time:101188ms step_avg:97.20ms -step:1042/1695 train_time:101283ms step_avg:97.20ms -step:1043/1695 train_time:101377ms step_avg:97.20ms -step:1044/1695 train_time:101472ms step_avg:97.20ms -step:1045/1695 train_time:101567ms step_avg:97.19ms -step:1046/1695 train_time:101666ms step_avg:97.20ms -step:1047/1695 train_time:101765ms step_avg:97.20ms -step:1048/1695 train_time:101862ms step_avg:97.20ms -step:1049/1695 train_time:101959ms step_avg:97.20ms -step:1050/1695 train_time:102054ms step_avg:97.19ms -step:1051/1695 train_time:102149ms step_avg:97.19ms -step:1052/1695 train_time:102244ms step_avg:97.19ms -step:1053/1695 train_time:102339ms step_avg:97.19ms -step:1054/1695 train_time:102434ms step_avg:97.19ms -step:1055/1695 train_time:102529ms step_avg:97.18ms -step:1056/1695 train_time:102626ms step_avg:97.18ms -step:1057/1695 train_time:102724ms step_avg:97.18ms 
-step:1058/1695 train_time:102821ms step_avg:97.18ms -step:1059/1695 train_time:102917ms step_avg:97.18ms -step:1060/1695 train_time:103015ms step_avg:97.18ms -step:1061/1695 train_time:103111ms step_avg:97.18ms -step:1062/1695 train_time:103207ms step_avg:97.18ms -step:1063/1695 train_time:103302ms step_avg:97.18ms -step:1064/1695 train_time:103397ms step_avg:97.18ms -step:1065/1695 train_time:103493ms step_avg:97.18ms -step:1066/1695 train_time:103589ms step_avg:97.18ms -step:1067/1695 train_time:103684ms step_avg:97.17ms -step:1068/1695 train_time:103781ms step_avg:97.17ms -step:1069/1695 train_time:103877ms step_avg:97.17ms -step:1070/1695 train_time:103974ms step_avg:97.17ms -step:1071/1695 train_time:104070ms step_avg:97.17ms -step:1072/1695 train_time:104166ms step_avg:97.17ms -step:1073/1695 train_time:104262ms step_avg:97.17ms -step:1074/1695 train_time:104357ms step_avg:97.17ms -step:1075/1695 train_time:104453ms step_avg:97.17ms -step:1076/1695 train_time:104549ms step_avg:97.16ms -step:1077/1695 train_time:104644ms step_avg:97.16ms -step:1078/1695 train_time:104739ms step_avg:97.16ms -step:1079/1695 train_time:104836ms step_avg:97.16ms -step:1080/1695 train_time:104933ms step_avg:97.16ms -step:1081/1695 train_time:105029ms step_avg:97.16ms -step:1082/1695 train_time:105125ms step_avg:97.16ms -step:1083/1695 train_time:105221ms step_avg:97.16ms -step:1084/1695 train_time:105316ms step_avg:97.15ms -step:1085/1695 train_time:105412ms step_avg:97.15ms -step:1086/1695 train_time:105508ms step_avg:97.15ms -step:1087/1695 train_time:105604ms step_avg:97.15ms -step:1088/1695 train_time:105699ms step_avg:97.15ms -step:1089/1695 train_time:105795ms step_avg:97.15ms -step:1090/1695 train_time:105893ms step_avg:97.15ms -step:1091/1695 train_time:105990ms step_avg:97.15ms -step:1092/1695 train_time:106086ms step_avg:97.15ms -step:1093/1695 train_time:106181ms step_avg:97.15ms -step:1094/1695 train_time:106277ms step_avg:97.15ms -step:1095/1695 train_time:106373ms step_avg:97.14ms -step:1096/1695 train_time:106469ms step_avg:97.14ms -step:1097/1695 train_time:106565ms step_avg:97.14ms -step:1098/1695 train_time:106661ms step_avg:97.14ms -step:1099/1695 train_time:106756ms step_avg:97.14ms -step:1100/1695 train_time:106854ms step_avg:97.14ms -step:1101/1695 train_time:106950ms step_avg:97.14ms -step:1102/1695 train_time:107046ms step_avg:97.14ms -step:1103/1695 train_time:107142ms step_avg:97.14ms -step:1104/1695 train_time:107237ms step_avg:97.13ms -step:1105/1695 train_time:107333ms step_avg:97.13ms -step:1106/1695 train_time:107429ms step_avg:97.13ms -step:1107/1695 train_time:107526ms step_avg:97.13ms -step:1108/1695 train_time:107622ms step_avg:97.13ms -step:1109/1695 train_time:107718ms step_avg:97.13ms -step:1110/1695 train_time:107814ms step_avg:97.13ms -step:1111/1695 train_time:107912ms step_avg:97.13ms -step:1112/1695 train_time:108009ms step_avg:97.13ms -step:1113/1695 train_time:108105ms step_avg:97.13ms -step:1114/1695 train_time:108200ms step_avg:97.13ms -step:1115/1695 train_time:108296ms step_avg:97.13ms -step:1116/1695 train_time:108393ms step_avg:97.13ms -step:1117/1695 train_time:108490ms step_avg:97.13ms -step:1118/1695 train_time:108587ms step_avg:97.13ms -step:1119/1695 train_time:108683ms step_avg:97.13ms -step:1120/1695 train_time:108778ms step_avg:97.12ms -step:1121/1695 train_time:108875ms step_avg:97.12ms -step:1122/1695 train_time:108970ms step_avg:97.12ms -step:1123/1695 train_time:109068ms step_avg:97.12ms -step:1124/1695 train_time:109164ms step_avg:97.12ms 
-step:1125/1695 train_time:109260ms step_avg:97.12ms -step:1125/1695 val_loss:3.4352 train_time:109353ms step_avg:97.20ms -step:1126/1695 train_time:109379ms step_avg:97.14ms -step:1127/1695 train_time:109456ms step_avg:97.12ms -step:1128/1695 train_time:109554ms step_avg:97.12ms -step:1129/1695 train_time:109650ms step_avg:97.12ms -step:1130/1695 train_time:109745ms step_avg:97.12ms -step:1131/1695 train_time:109840ms step_avg:97.12ms -step:1132/1695 train_time:109934ms step_avg:97.12ms -step:1133/1695 train_time:110031ms step_avg:97.11ms -step:1134/1695 train_time:110129ms step_avg:97.12ms -step:1135/1695 train_time:110228ms step_avg:97.12ms -step:1136/1695 train_time:110328ms step_avg:97.12ms -step:1137/1695 train_time:110431ms step_avg:97.12ms -step:1138/1695 train_time:110532ms step_avg:97.13ms -step:1139/1695 train_time:110631ms step_avg:97.13ms -step:1140/1695 train_time:110729ms step_avg:97.13ms -step:1141/1695 train_time:110826ms step_avg:97.13ms -step:1142/1695 train_time:110923ms step_avg:97.13ms -step:1143/1695 train_time:111020ms step_avg:97.13ms -step:1144/1695 train_time:111116ms step_avg:97.13ms -step:1145/1695 train_time:111214ms step_avg:97.13ms -step:1146/1695 train_time:111312ms step_avg:97.13ms -step:1147/1695 train_time:111412ms step_avg:97.13ms -step:1148/1695 train_time:111511ms step_avg:97.14ms -step:1149/1695 train_time:111611ms step_avg:97.14ms -step:1150/1695 train_time:111710ms step_avg:97.14ms -step:1151/1695 train_time:111808ms step_avg:97.14ms -step:1152/1695 train_time:111907ms step_avg:97.14ms -step:1153/1695 train_time:112006ms step_avg:97.14ms -step:1154/1695 train_time:112104ms step_avg:97.14ms -step:1155/1695 train_time:112201ms step_avg:97.14ms -step:1156/1695 train_time:112299ms step_avg:97.14ms -step:1157/1695 train_time:112396ms step_avg:97.14ms -step:1158/1695 train_time:112495ms step_avg:97.15ms -step:1159/1695 train_time:112593ms step_avg:97.15ms -step:1160/1695 train_time:112691ms step_avg:97.15ms -step:1161/1695 train_time:112790ms step_avg:97.15ms -step:1162/1695 train_time:112888ms step_avg:97.15ms -step:1163/1695 train_time:112986ms step_avg:97.15ms -step:1164/1695 train_time:113084ms step_avg:97.15ms -step:1165/1695 train_time:113181ms step_avg:97.15ms -step:1166/1695 train_time:113279ms step_avg:97.15ms -step:1167/1695 train_time:113376ms step_avg:97.15ms -step:1168/1695 train_time:113473ms step_avg:97.15ms -step:1169/1695 train_time:113571ms step_avg:97.15ms -step:1170/1695 train_time:113669ms step_avg:97.15ms -step:1171/1695 train_time:113767ms step_avg:97.15ms -step:1172/1695 train_time:113864ms step_avg:97.15ms -step:1173/1695 train_time:113962ms step_avg:97.15ms -step:1174/1695 train_time:114060ms step_avg:97.15ms -step:1175/1695 train_time:114157ms step_avg:97.16ms -step:1176/1695 train_time:114255ms step_avg:97.16ms -step:1177/1695 train_time:114352ms step_avg:97.16ms -step:1178/1695 train_time:114451ms step_avg:97.16ms -step:1179/1695 train_time:114550ms step_avg:97.16ms -step:1180/1695 train_time:114648ms step_avg:97.16ms -step:1181/1695 train_time:114747ms step_avg:97.16ms -step:1182/1695 train_time:114844ms step_avg:97.16ms -step:1183/1695 train_time:114942ms step_avg:97.16ms -step:1184/1695 train_time:115039ms step_avg:97.16ms -step:1185/1695 train_time:115136ms step_avg:97.16ms -step:1186/1695 train_time:115234ms step_avg:97.16ms -step:1187/1695 train_time:115331ms step_avg:97.16ms -step:1188/1695 train_time:115430ms step_avg:97.16ms -step:1189/1695 train_time:115529ms step_avg:97.16ms -step:1190/1695 train_time:115627ms 
step_avg:97.17ms -step:1191/1695 train_time:115725ms step_avg:97.17ms -step:1192/1695 train_time:115823ms step_avg:97.17ms -step:1193/1695 train_time:115922ms step_avg:97.17ms -step:1194/1695 train_time:116019ms step_avg:97.17ms -step:1195/1695 train_time:116117ms step_avg:97.17ms -step:1196/1695 train_time:116214ms step_avg:97.17ms -step:1197/1695 train_time:116311ms step_avg:97.17ms -step:1198/1695 train_time:116409ms step_avg:97.17ms -step:1199/1695 train_time:116507ms step_avg:97.17ms -step:1200/1695 train_time:116604ms step_avg:97.17ms -step:1201/1695 train_time:116702ms step_avg:97.17ms -step:1202/1695 train_time:116799ms step_avg:97.17ms -step:1203/1695 train_time:116897ms step_avg:97.17ms -step:1204/1695 train_time:116995ms step_avg:97.17ms -step:1205/1695 train_time:117093ms step_avg:97.17ms -step:1206/1695 train_time:117191ms step_avg:97.17ms -step:1207/1695 train_time:117289ms step_avg:97.17ms -step:1208/1695 train_time:117624ms step_avg:97.37ms -step:1209/1695 train_time:117814ms step_avg:97.45ms -step:1210/1695 train_time:117909ms step_avg:97.45ms -step:1211/1695 train_time:118006ms step_avg:97.44ms -step:1212/1695 train_time:118103ms step_avg:97.44ms -step:1213/1695 train_time:118199ms step_avg:97.44ms -step:1214/1695 train_time:118295ms step_avg:97.44ms -step:1215/1695 train_time:118393ms step_avg:97.44ms -step:1216/1695 train_time:118490ms step_avg:97.44ms -step:1217/1695 train_time:118587ms step_avg:97.44ms -step:1218/1695 train_time:118689ms step_avg:97.45ms -step:1219/1695 train_time:118792ms step_avg:97.45ms -step:1220/1695 train_time:118892ms step_avg:97.45ms -step:1221/1695 train_time:118991ms step_avg:97.45ms -step:1222/1695 train_time:119090ms step_avg:97.46ms -step:1223/1695 train_time:119190ms step_avg:97.46ms -step:1224/1695 train_time:119287ms step_avg:97.46ms -step:1225/1695 train_time:119384ms step_avg:97.46ms -step:1226/1695 train_time:119481ms step_avg:97.46ms -step:1227/1695 train_time:119577ms step_avg:97.45ms -step:1228/1695 train_time:119673ms step_avg:97.45ms -step:1229/1695 train_time:119773ms step_avg:97.46ms -step:1230/1695 train_time:119873ms step_avg:97.46ms -step:1231/1695 train_time:119972ms step_avg:97.46ms -step:1232/1695 train_time:120070ms step_avg:97.46ms -step:1233/1695 train_time:120168ms step_avg:97.46ms -step:1234/1695 train_time:120268ms step_avg:97.46ms -step:1235/1695 train_time:120366ms step_avg:97.46ms -step:1236/1695 train_time:120464ms step_avg:97.46ms -step:1237/1695 train_time:120561ms step_avg:97.46ms -step:1238/1695 train_time:120658ms step_avg:97.46ms -step:1239/1695 train_time:120756ms step_avg:97.46ms -step:1240/1695 train_time:120853ms step_avg:97.46ms -step:1241/1695 train_time:120952ms step_avg:97.46ms -step:1242/1695 train_time:121051ms step_avg:97.46ms -step:1243/1695 train_time:121150ms step_avg:97.47ms -step:1244/1695 train_time:121249ms step_avg:97.47ms -step:1245/1695 train_time:121347ms step_avg:97.47ms -step:1246/1695 train_time:121444ms step_avg:97.47ms -step:1247/1695 train_time:121542ms step_avg:97.47ms -step:1248/1695 train_time:121640ms step_avg:97.47ms -step:1249/1695 train_time:121738ms step_avg:97.47ms -step:1250/1695 train_time:121835ms step_avg:97.47ms -step:1250/1695 val_loss:3.3886 train_time:121930ms step_avg:97.54ms -step:1251/1695 train_time:121956ms step_avg:97.49ms -step:1252/1695 train_time:122037ms step_avg:97.47ms -step:1253/1695 train_time:122135ms step_avg:97.47ms -step:1254/1695 train_time:122231ms step_avg:97.47ms -step:1255/1695 train_time:122327ms step_avg:97.47ms -step:1256/1695 
train_time:122424ms step_avg:97.47ms -step:1257/1695 train_time:122520ms step_avg:97.47ms -step:1258/1695 train_time:122616ms step_avg:97.47ms -step:1259/1695 train_time:122713ms step_avg:97.47ms -step:1260/1695 train_time:122809ms step_avg:97.47ms -step:1261/1695 train_time:122913ms step_avg:97.47ms -step:1262/1695 train_time:123013ms step_avg:97.47ms -step:1263/1695 train_time:123111ms step_avg:97.48ms -step:1264/1695 train_time:123209ms step_avg:97.48ms -step:1265/1695 train_time:123306ms step_avg:97.48ms -step:1266/1695 train_time:123403ms step_avg:97.47ms -step:1267/1695 train_time:123499ms step_avg:97.47ms -step:1268/1695 train_time:123596ms step_avg:97.47ms -step:1269/1695 train_time:123693ms step_avg:97.47ms -step:1270/1695 train_time:123789ms step_avg:97.47ms -step:1271/1695 train_time:123889ms step_avg:97.47ms -step:1272/1695 train_time:123988ms step_avg:97.47ms -step:1273/1695 train_time:124086ms step_avg:97.48ms -step:1274/1695 train_time:124185ms step_avg:97.48ms -step:1275/1695 train_time:124284ms step_avg:97.48ms -step:1276/1695 train_time:124381ms step_avg:97.48ms -step:1277/1695 train_time:124479ms step_avg:97.48ms -step:1278/1695 train_time:124576ms step_avg:97.48ms -step:1279/1695 train_time:124673ms step_avg:97.48ms -step:1280/1695 train_time:124770ms step_avg:97.48ms -step:1281/1695 train_time:124868ms step_avg:97.48ms -step:1282/1695 train_time:124966ms step_avg:97.48ms -step:1283/1695 train_time:125065ms step_avg:97.48ms -step:1284/1695 train_time:125165ms step_avg:97.48ms -step:1285/1695 train_time:125264ms step_avg:97.48ms -step:1286/1695 train_time:125363ms step_avg:97.48ms -step:1287/1695 train_time:125460ms step_avg:97.48ms -step:1288/1695 train_time:125558ms step_avg:97.48ms -step:1289/1695 train_time:125656ms step_avg:97.48ms -step:1290/1695 train_time:125755ms step_avg:97.48ms -step:1291/1695 train_time:125853ms step_avg:97.48ms -step:1292/1695 train_time:125950ms step_avg:97.48ms -step:1293/1695 train_time:126048ms step_avg:97.49ms -step:1294/1695 train_time:126147ms step_avg:97.49ms -step:1295/1695 train_time:126245ms step_avg:97.49ms -step:1296/1695 train_time:126344ms step_avg:97.49ms -step:1297/1695 train_time:126440ms step_avg:97.49ms -step:1298/1695 train_time:126538ms step_avg:97.49ms -step:1299/1695 train_time:126635ms step_avg:97.49ms -step:1300/1695 train_time:126731ms step_avg:97.49ms -step:1301/1695 train_time:126829ms step_avg:97.49ms -step:1302/1695 train_time:126927ms step_avg:97.49ms -step:1303/1695 train_time:127025ms step_avg:97.49ms -step:1304/1695 train_time:127124ms step_avg:97.49ms -step:1305/1695 train_time:127222ms step_avg:97.49ms -step:1306/1695 train_time:127320ms step_avg:97.49ms -step:1307/1695 train_time:127418ms step_avg:97.49ms -step:1308/1695 train_time:127515ms step_avg:97.49ms -step:1309/1695 train_time:127612ms step_avg:97.49ms -step:1310/1695 train_time:127710ms step_avg:97.49ms -step:1311/1695 train_time:127807ms step_avg:97.49ms -step:1312/1695 train_time:127905ms step_avg:97.49ms -step:1313/1695 train_time:128004ms step_avg:97.49ms -step:1314/1695 train_time:128104ms step_avg:97.49ms -step:1315/1695 train_time:128203ms step_avg:97.49ms -step:1316/1695 train_time:128301ms step_avg:97.49ms -step:1317/1695 train_time:128399ms step_avg:97.49ms -step:1318/1695 train_time:128498ms step_avg:97.49ms -step:1319/1695 train_time:128596ms step_avg:97.50ms -step:1320/1695 train_time:128695ms step_avg:97.50ms -step:1321/1695 train_time:128792ms step_avg:97.50ms -step:1322/1695 train_time:128889ms step_avg:97.50ms -step:1323/1695 
train_time:128986ms step_avg:97.50ms -step:1324/1695 train_time:129085ms step_avg:97.50ms -step:1325/1695 train_time:129184ms step_avg:97.50ms -step:1326/1695 train_time:129282ms step_avg:97.50ms -step:1327/1695 train_time:129379ms step_avg:97.50ms -step:1328/1695 train_time:129477ms step_avg:97.50ms -step:1329/1695 train_time:129574ms step_avg:97.50ms -step:1330/1695 train_time:129672ms step_avg:97.50ms -step:1331/1695 train_time:129769ms step_avg:97.50ms -step:1332/1695 train_time:129866ms step_avg:97.50ms -step:1333/1695 train_time:129964ms step_avg:97.50ms -step:1334/1695 train_time:130063ms step_avg:97.50ms -step:1335/1695 train_time:130162ms step_avg:97.50ms -step:1336/1695 train_time:130260ms step_avg:97.50ms -step:1337/1695 train_time:130358ms step_avg:97.50ms -step:1338/1695 train_time:130455ms step_avg:97.50ms -step:1339/1695 train_time:130553ms step_avg:97.50ms -step:1340/1695 train_time:130652ms step_avg:97.50ms -step:1341/1695 train_time:130749ms step_avg:97.50ms -step:1342/1695 train_time:130846ms step_avg:97.50ms -step:1343/1695 train_time:130944ms step_avg:97.50ms -step:1344/1695 train_time:131043ms step_avg:97.50ms -step:1345/1695 train_time:131142ms step_avg:97.50ms -step:1346/1695 train_time:131241ms step_avg:97.50ms -step:1347/1695 train_time:131340ms step_avg:97.51ms -step:1348/1695 train_time:131438ms step_avg:97.51ms -step:1349/1695 train_time:131537ms step_avg:97.51ms -step:1350/1695 train_time:131636ms step_avg:97.51ms -step:1351/1695 train_time:131734ms step_avg:97.51ms -step:1352/1695 train_time:131832ms step_avg:97.51ms -step:1353/1695 train_time:131930ms step_avg:97.51ms -step:1354/1695 train_time:132028ms step_avg:97.51ms -step:1355/1695 train_time:132126ms step_avg:97.51ms -step:1356/1695 train_time:132223ms step_avg:97.51ms -step:1357/1695 train_time:132321ms step_avg:97.51ms -step:1358/1695 train_time:132419ms step_avg:97.51ms -step:1359/1695 train_time:132517ms step_avg:97.51ms -step:1360/1695 train_time:132614ms step_avg:97.51ms -step:1361/1695 train_time:132711ms step_avg:97.51ms -step:1362/1695 train_time:132808ms step_avg:97.51ms -step:1363/1695 train_time:132905ms step_avg:97.51ms -step:1364/1695 train_time:133004ms step_avg:97.51ms -step:1365/1695 train_time:133102ms step_avg:97.51ms -step:1366/1695 train_time:133200ms step_avg:97.51ms -step:1367/1695 train_time:133297ms step_avg:97.51ms -step:1368/1695 train_time:133393ms step_avg:97.51ms -step:1369/1695 train_time:133491ms step_avg:97.51ms -step:1370/1695 train_time:133589ms step_avg:97.51ms -step:1371/1695 train_time:133687ms step_avg:97.51ms -step:1372/1695 train_time:133785ms step_avg:97.51ms -step:1373/1695 train_time:133884ms step_avg:97.51ms -step:1374/1695 train_time:133982ms step_avg:97.51ms -step:1375/1695 train_time:134080ms step_avg:97.51ms -step:1375/1695 val_loss:3.3495 train_time:134174ms step_avg:97.58ms -step:1376/1695 train_time:134203ms step_avg:97.53ms -step:1377/1695 train_time:134283ms step_avg:97.52ms -step:1378/1695 train_time:134384ms step_avg:97.52ms -step:1379/1695 train_time:134483ms step_avg:97.52ms -step:1380/1695 train_time:134581ms step_avg:97.52ms -step:1381/1695 train_time:134941ms step_avg:97.71ms -step:1382/1695 train_time:135109ms step_avg:97.76ms -step:1383/1695 train_time:135205ms step_avg:97.76ms -step:1384/1695 train_time:135302ms step_avg:97.76ms -step:1385/1695 train_time:135398ms step_avg:97.76ms -step:1386/1695 train_time:135494ms step_avg:97.76ms -step:1387/1695 train_time:135591ms step_avg:97.76ms -step:1388/1695 train_time:135686ms step_avg:97.76ms 
-step:1389/1695 train_time:135783ms step_avg:97.76ms -step:1390/1695 train_time:135880ms step_avg:97.76ms -step:1391/1695 train_time:135981ms step_avg:97.76ms -step:1392/1695 train_time:136087ms step_avg:97.76ms -step:1393/1695 train_time:136186ms step_avg:97.76ms -step:1394/1695 train_time:136284ms step_avg:97.76ms -step:1395/1695 train_time:136381ms step_avg:97.76ms -step:1396/1695 train_time:136480ms step_avg:97.76ms -step:1397/1695 train_time:136577ms step_avg:97.76ms -step:1398/1695 train_time:136674ms step_avg:97.76ms -step:1399/1695 train_time:136770ms step_avg:97.76ms -step:1400/1695 train_time:136866ms step_avg:97.76ms -step:1401/1695 train_time:136964ms step_avg:97.76ms -step:1402/1695 train_time:137065ms step_avg:97.76ms -step:1403/1695 train_time:137165ms step_avg:97.77ms -step:1404/1695 train_time:137264ms step_avg:97.77ms -step:1405/1695 train_time:137362ms step_avg:97.77ms -step:1406/1695 train_time:137460ms step_avg:97.77ms -step:1407/1695 train_time:137558ms step_avg:97.77ms -step:1408/1695 train_time:137656ms step_avg:97.77ms -step:1409/1695 train_time:137752ms step_avg:97.77ms -step:1410/1695 train_time:137849ms step_avg:97.77ms -step:1411/1695 train_time:137946ms step_avg:97.76ms -step:1412/1695 train_time:138046ms step_avg:97.77ms -step:1413/1695 train_time:138144ms step_avg:97.77ms -step:1414/1695 train_time:138244ms step_avg:97.77ms -step:1415/1695 train_time:138342ms step_avg:97.77ms -step:1416/1695 train_time:138440ms step_avg:97.77ms -step:1417/1695 train_time:138538ms step_avg:97.77ms -step:1418/1695 train_time:138637ms step_avg:97.77ms -step:1419/1695 train_time:138735ms step_avg:97.77ms -step:1420/1695 train_time:138832ms step_avg:97.77ms -step:1421/1695 train_time:138928ms step_avg:97.77ms -step:1422/1695 train_time:139025ms step_avg:97.77ms -step:1423/1695 train_time:139123ms step_avg:97.77ms -step:1424/1695 train_time:139221ms step_avg:97.77ms -step:1425/1695 train_time:139320ms step_avg:97.77ms -step:1426/1695 train_time:139418ms step_avg:97.77ms -step:1427/1695 train_time:139515ms step_avg:97.77ms -step:1428/1695 train_time:139612ms step_avg:97.77ms -step:1429/1695 train_time:139709ms step_avg:97.77ms -step:1430/1695 train_time:139807ms step_avg:97.77ms -step:1431/1695 train_time:139905ms step_avg:97.77ms -step:1432/1695 train_time:140004ms step_avg:97.77ms -step:1433/1695 train_time:140104ms step_avg:97.77ms -step:1434/1695 train_time:140202ms step_avg:97.77ms -step:1435/1695 train_time:140300ms step_avg:97.77ms -step:1436/1695 train_time:140397ms step_avg:97.77ms -step:1437/1695 train_time:140496ms step_avg:97.77ms -step:1438/1695 train_time:140594ms step_avg:97.77ms -step:1439/1695 train_time:140691ms step_avg:97.77ms -step:1440/1695 train_time:140788ms step_avg:97.77ms -step:1441/1695 train_time:140885ms step_avg:97.77ms -step:1442/1695 train_time:140982ms step_avg:97.77ms -step:1443/1695 train_time:141080ms step_avg:97.77ms -step:1444/1695 train_time:141178ms step_avg:97.77ms -step:1445/1695 train_time:141275ms step_avg:97.77ms -step:1446/1695 train_time:141373ms step_avg:97.77ms -step:1447/1695 train_time:141471ms step_avg:97.77ms -step:1448/1695 train_time:141569ms step_avg:97.77ms -step:1449/1695 train_time:141667ms step_avg:97.77ms -step:1450/1695 train_time:141765ms step_avg:97.77ms -step:1451/1695 train_time:141864ms step_avg:97.77ms -step:1452/1695 train_time:141962ms step_avg:97.77ms -step:1453/1695 train_time:142060ms step_avg:97.77ms -step:1454/1695 train_time:142159ms step_avg:97.77ms -step:1455/1695 train_time:142257ms step_avg:97.77ms 
-step:1456/1695 train_time:142355ms step_avg:97.77ms -step:1457/1695 train_time:142452ms step_avg:97.77ms -step:1458/1695 train_time:142550ms step_avg:97.77ms -step:1459/1695 train_time:142648ms step_avg:97.77ms -step:1460/1695 train_time:142746ms step_avg:97.77ms -step:1461/1695 train_time:142844ms step_avg:97.77ms -step:1462/1695 train_time:142943ms step_avg:97.77ms -step:1463/1695 train_time:143041ms step_avg:97.77ms -step:1464/1695 train_time:143141ms step_avg:97.77ms -step:1465/1695 train_time:143241ms step_avg:97.78ms -step:1466/1695 train_time:143339ms step_avg:97.78ms -step:1467/1695 train_time:143438ms step_avg:97.78ms -step:1468/1695 train_time:143537ms step_avg:97.78ms -step:1469/1695 train_time:143635ms step_avg:97.78ms -step:1470/1695 train_time:143732ms step_avg:97.78ms -step:1471/1695 train_time:143829ms step_avg:97.78ms -step:1472/1695 train_time:143927ms step_avg:97.78ms -step:1473/1695 train_time:144025ms step_avg:97.78ms -step:1474/1695 train_time:144123ms step_avg:97.78ms -step:1475/1695 train_time:144221ms step_avg:97.78ms -step:1476/1695 train_time:144319ms step_avg:97.78ms -step:1477/1695 train_time:144418ms step_avg:97.78ms -step:1478/1695 train_time:144515ms step_avg:97.78ms -step:1479/1695 train_time:144613ms step_avg:97.78ms -step:1480/1695 train_time:144711ms step_avg:97.78ms -step:1481/1695 train_time:144809ms step_avg:97.78ms -step:1482/1695 train_time:144906ms step_avg:97.78ms -step:1483/1695 train_time:145004ms step_avg:97.78ms -step:1484/1695 train_time:145101ms step_avg:97.78ms -step:1485/1695 train_time:145200ms step_avg:97.78ms -step:1486/1695 train_time:145298ms step_avg:97.78ms -step:1487/1695 train_time:145397ms step_avg:97.78ms -step:1488/1695 train_time:145495ms step_avg:97.78ms -step:1489/1695 train_time:145593ms step_avg:97.78ms -step:1490/1695 train_time:145690ms step_avg:97.78ms -step:1491/1695 train_time:145787ms step_avg:97.78ms -step:1492/1695 train_time:145884ms step_avg:97.78ms -step:1493/1695 train_time:145982ms step_avg:97.78ms -step:1494/1695 train_time:146079ms step_avg:97.78ms -step:1495/1695 train_time:146177ms step_avg:97.78ms -step:1496/1695 train_time:146274ms step_avg:97.78ms -step:1497/1695 train_time:146372ms step_avg:97.78ms -step:1498/1695 train_time:146469ms step_avg:97.78ms -step:1499/1695 train_time:146567ms step_avg:97.78ms -step:1500/1695 train_time:146665ms step_avg:97.78ms -step:1500/1695 val_loss:3.3162 train_time:146761ms step_avg:97.84ms -step:1501/1695 train_time:146787ms step_avg:97.79ms -step:1502/1695 train_time:146870ms step_avg:97.78ms -step:1503/1695 train_time:146968ms step_avg:97.78ms -step:1504/1695 train_time:147065ms step_avg:97.78ms -step:1505/1695 train_time:147162ms step_avg:97.78ms -step:1506/1695 train_time:147259ms step_avg:97.78ms -step:1507/1695 train_time:147355ms step_avg:97.78ms -step:1508/1695 train_time:147452ms step_avg:97.78ms -step:1509/1695 train_time:147548ms step_avg:97.78ms -step:1510/1695 train_time:147645ms step_avg:97.78ms -step:1511/1695 train_time:147745ms step_avg:97.78ms -step:1512/1695 train_time:147846ms step_avg:97.78ms -step:1513/1695 train_time:147945ms step_avg:97.78ms -step:1514/1695 train_time:148043ms step_avg:97.78ms -step:1515/1695 train_time:148141ms step_avg:97.78ms -step:1516/1695 train_time:148240ms step_avg:97.78ms -step:1517/1695 train_time:148337ms step_avg:97.78ms -step:1518/1695 train_time:148433ms step_avg:97.78ms -step:1519/1695 train_time:148529ms step_avg:97.78ms -step:1520/1695 train_time:148627ms step_avg:97.78ms -step:1521/1695 train_time:148726ms 
step_avg:97.78ms -step:1522/1695 train_time:148825ms step_avg:97.78ms -step:1523/1695 train_time:148924ms step_avg:97.78ms -step:1524/1695 train_time:149022ms step_avg:97.78ms -step:1525/1695 train_time:149121ms step_avg:97.78ms -step:1526/1695 train_time:149219ms step_avg:97.78ms -step:1527/1695 train_time:149317ms step_avg:97.78ms -step:1528/1695 train_time:149413ms step_avg:97.78ms -step:1529/1695 train_time:149511ms step_avg:97.78ms -step:1530/1695 train_time:149608ms step_avg:97.78ms -step:1531/1695 train_time:149705ms step_avg:97.78ms -step:1532/1695 train_time:149804ms step_avg:97.78ms -step:1533/1695 train_time:149903ms step_avg:97.78ms -step:1534/1695 train_time:150001ms step_avg:97.78ms -step:1535/1695 train_time:150100ms step_avg:97.78ms -step:1536/1695 train_time:150198ms step_avg:97.79ms -step:1537/1695 train_time:150296ms step_avg:97.79ms -step:1538/1695 train_time:150393ms step_avg:97.78ms -step:1539/1695 train_time:150491ms step_avg:97.78ms -step:1540/1695 train_time:150588ms step_avg:97.78ms -step:1541/1695 train_time:150685ms step_avg:97.78ms -step:1542/1695 train_time:150783ms step_avg:97.78ms -step:1543/1695 train_time:150881ms step_avg:97.78ms -step:1544/1695 train_time:150981ms step_avg:97.79ms -step:1545/1695 train_time:151081ms step_avg:97.79ms -step:1546/1695 train_time:151181ms step_avg:97.79ms -step:1547/1695 train_time:151279ms step_avg:97.79ms -step:1548/1695 train_time:151378ms step_avg:97.79ms -step:1549/1695 train_time:151478ms step_avg:97.79ms -step:1550/1695 train_time:151576ms step_avg:97.79ms -step:1551/1695 train_time:151674ms step_avg:97.79ms -step:1552/1695 train_time:152071ms step_avg:97.98ms -step:1553/1695 train_time:152147ms step_avg:97.97ms -step:1554/1695 train_time:152242ms step_avg:97.97ms -step:1555/1695 train_time:152339ms step_avg:97.97ms -step:1556/1695 train_time:152435ms step_avg:97.97ms -step:1557/1695 train_time:152532ms step_avg:97.97ms -step:1558/1695 train_time:152628ms step_avg:97.96ms -step:1559/1695 train_time:152724ms step_avg:97.96ms -step:1560/1695 train_time:152821ms step_avg:97.96ms -step:1561/1695 train_time:152919ms step_avg:97.96ms -step:1562/1695 train_time:153018ms step_avg:97.96ms -step:1563/1695 train_time:153125ms step_avg:97.97ms -step:1564/1695 train_time:153224ms step_avg:97.97ms -step:1565/1695 train_time:153322ms step_avg:97.97ms -step:1566/1695 train_time:153420ms step_avg:97.97ms -step:1567/1695 train_time:153517ms step_avg:97.97ms -step:1568/1695 train_time:153616ms step_avg:97.97ms -step:1569/1695 train_time:153713ms step_avg:97.97ms -step:1570/1695 train_time:153810ms step_avg:97.97ms -step:1571/1695 train_time:153906ms step_avg:97.97ms -step:1572/1695 train_time:154004ms step_avg:97.97ms -step:1573/1695 train_time:154103ms step_avg:97.97ms -step:1574/1695 train_time:154203ms step_avg:97.97ms -step:1575/1695 train_time:154302ms step_avg:97.97ms -step:1576/1695 train_time:154400ms step_avg:97.97ms -step:1577/1695 train_time:154498ms step_avg:97.97ms -step:1578/1695 train_time:154596ms step_avg:97.97ms -step:1579/1695 train_time:154693ms step_avg:97.97ms -step:1580/1695 train_time:154791ms step_avg:97.97ms -step:1581/1695 train_time:154888ms step_avg:97.97ms -step:1582/1695 train_time:154985ms step_avg:97.97ms -step:1583/1695 train_time:155083ms step_avg:97.97ms -step:1584/1695 train_time:155182ms step_avg:97.97ms -step:1585/1695 train_time:155280ms step_avg:97.97ms -step:1586/1695 train_time:155380ms step_avg:97.97ms -step:1587/1695 train_time:155478ms step_avg:97.97ms -step:1588/1695 train_time:155575ms 
step_avg:97.97ms -step:1589/1695 train_time:155673ms step_avg:97.97ms -step:1590/1695 train_time:155771ms step_avg:97.97ms -step:1591/1695 train_time:155868ms step_avg:97.97ms -step:1592/1695 train_time:155965ms step_avg:97.97ms -step:1593/1695 train_time:156063ms step_avg:97.97ms -step:1594/1695 train_time:156160ms step_avg:97.97ms -step:1595/1695 train_time:156258ms step_avg:97.97ms -step:1596/1695 train_time:156357ms step_avg:97.97ms -step:1597/1695 train_time:156456ms step_avg:97.97ms -step:1598/1695 train_time:156554ms step_avg:97.97ms -step:1599/1695 train_time:156651ms step_avg:97.97ms -step:1600/1695 train_time:156748ms step_avg:97.97ms -step:1601/1695 train_time:156845ms step_avg:97.97ms -step:1602/1695 train_time:156943ms step_avg:97.97ms -step:1603/1695 train_time:157041ms step_avg:97.97ms -step:1604/1695 train_time:157140ms step_avg:97.97ms -step:1605/1695 train_time:157239ms step_avg:97.97ms -step:1606/1695 train_time:157337ms step_avg:97.97ms -step:1607/1695 train_time:157436ms step_avg:97.97ms -step:1608/1695 train_time:157535ms step_avg:97.97ms -step:1609/1695 train_time:157632ms step_avg:97.97ms -step:1610/1695 train_time:157730ms step_avg:97.97ms -step:1611/1695 train_time:157827ms step_avg:97.97ms -step:1612/1695 train_time:157925ms step_avg:97.97ms -step:1613/1695 train_time:158022ms step_avg:97.97ms -step:1614/1695 train_time:158119ms step_avg:97.97ms -step:1615/1695 train_time:158218ms step_avg:97.97ms -step:1616/1695 train_time:158317ms step_avg:97.97ms -step:1617/1695 train_time:158415ms step_avg:97.97ms -step:1618/1695 train_time:158513ms step_avg:97.97ms -step:1619/1695 train_time:158611ms step_avg:97.97ms -step:1620/1695 train_time:158708ms step_avg:97.97ms -step:1621/1695 train_time:158805ms step_avg:97.97ms -step:1622/1695 train_time:158903ms step_avg:97.97ms -step:1623/1695 train_time:159001ms step_avg:97.97ms -step:1624/1695 train_time:159099ms step_avg:97.97ms -step:1625/1695 train_time:159197ms step_avg:97.97ms -step:1625/1695 val_loss:3.2895 train_time:159292ms step_avg:98.03ms -step:1626/1695 train_time:159319ms step_avg:97.98ms -step:1627/1695 train_time:159403ms step_avg:97.97ms -step:1628/1695 train_time:159501ms step_avg:97.97ms -step:1629/1695 train_time:159598ms step_avg:97.97ms -step:1630/1695 train_time:159696ms step_avg:97.97ms -step:1631/1695 train_time:159793ms step_avg:97.97ms -step:1632/1695 train_time:159890ms step_avg:97.97ms -step:1633/1695 train_time:159986ms step_avg:97.97ms -step:1634/1695 train_time:160083ms step_avg:97.97ms -step:1635/1695 train_time:160179ms step_avg:97.97ms -step:1636/1695 train_time:160280ms step_avg:97.97ms -step:1637/1695 train_time:160382ms step_avg:97.97ms -step:1638/1695 train_time:160482ms step_avg:97.97ms -step:1639/1695 train_time:160580ms step_avg:97.97ms -step:1640/1695 train_time:160678ms step_avg:97.97ms -step:1641/1695 train_time:160774ms step_avg:97.97ms -step:1642/1695 train_time:160871ms step_avg:97.97ms -step:1643/1695 train_time:160969ms step_avg:97.97ms -step:1644/1695 train_time:161066ms step_avg:97.97ms -step:1645/1695 train_time:161162ms step_avg:97.97ms -step:1646/1695 train_time:161261ms step_avg:97.97ms -step:1647/1695 train_time:161362ms step_avg:97.97ms -step:1648/1695 train_time:161461ms step_avg:97.97ms -step:1649/1695 train_time:161559ms step_avg:97.97ms -step:1650/1695 train_time:161657ms step_avg:97.97ms -step:1651/1695 train_time:161755ms step_avg:97.97ms -step:1652/1695 train_time:161852ms step_avg:97.97ms -step:1653/1695 train_time:161951ms step_avg:97.97ms -step:1654/1695 
train_time:162049ms step_avg:97.97ms -step:1655/1695 train_time:162146ms step_avg:97.97ms -step:1656/1695 train_time:162244ms step_avg:97.97ms -step:1657/1695 train_time:162342ms step_avg:97.97ms -step:1658/1695 train_time:162440ms step_avg:97.97ms -step:1659/1695 train_time:162538ms step_avg:97.97ms -step:1660/1695 train_time:162636ms step_avg:97.97ms -step:1661/1695 train_time:162734ms step_avg:97.97ms -step:1662/1695 train_time:162831ms step_avg:97.97ms -step:1663/1695 train_time:162928ms step_avg:97.97ms -step:1664/1695 train_time:163026ms step_avg:97.97ms -step:1665/1695 train_time:163122ms step_avg:97.97ms -step:1666/1695 train_time:163221ms step_avg:97.97ms -step:1667/1695 train_time:163320ms step_avg:97.97ms -step:1668/1695 train_time:163418ms step_avg:97.97ms -step:1669/1695 train_time:163518ms step_avg:97.97ms -step:1670/1695 train_time:163617ms step_avg:97.97ms -step:1671/1695 train_time:163715ms step_avg:97.97ms -step:1672/1695 train_time:163812ms step_avg:97.97ms -step:1673/1695 train_time:163911ms step_avg:97.97ms -step:1674/1695 train_time:164008ms step_avg:97.97ms -step:1675/1695 train_time:164105ms step_avg:97.97ms -step:1676/1695 train_time:164202ms step_avg:97.97ms -step:1677/1695 train_time:164300ms step_avg:97.97ms -step:1678/1695 train_time:164397ms step_avg:97.97ms -step:1679/1695 train_time:164495ms step_avg:97.97ms -step:1680/1695 train_time:164593ms step_avg:97.97ms -step:1681/1695 train_time:164691ms step_avg:97.97ms -step:1682/1695 train_time:164789ms step_avg:97.97ms -step:1683/1695 train_time:164886ms step_avg:97.97ms -step:1684/1695 train_time:164984ms step_avg:97.97ms -step:1685/1695 train_time:165081ms step_avg:97.97ms -step:1686/1695 train_time:165179ms step_avg:97.97ms -step:1687/1695 train_time:165278ms step_avg:97.97ms -step:1688/1695 train_time:165377ms step_avg:97.97ms -step:1689/1695 train_time:165475ms step_avg:97.97ms -step:1690/1695 train_time:165573ms step_avg:97.97ms -step:1691/1695 train_time:165672ms step_avg:97.97ms -step:1692/1695 train_time:165770ms step_avg:97.97ms -step:1693/1695 train_time:165868ms step_avg:97.97ms -step:1694/1695 train_time:165966ms step_avg:97.97ms -step:1695/1695 train_time:166062ms step_avg:97.97ms -step:1695/1695 val_loss:3.2782 train_time:166157ms step_avg:98.03ms -peak memory allocated: 34361 MiB reserved: 49576 MiB diff --git a/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt b/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt deleted file mode 100644 index 32ec95b7e..000000000 --- a/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
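# A plain-PyTorch sketch of the scaled-FP8 numerics implemented by mm_op
# above (the scales here are illustrative, and this unfused emulation stands
# in for the torch._scaled_mm call, which folds the dequantization scales
# into the accumulator rather than materializing bf16 copies):
import torch
def fp8_mm_emulated(x: torch.Tensor, w: torch.Tensor, x_s: float = 8.0, w_s: float = 32.0) -> torch.Tensor:
    x_f8 = x.div(x_s).to(torch.float8_e4m3fn)  # static-scale quantization
    w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
    # dequantize and matmul in bf16: out ~= x @ w.T up to fp8 rounding error
    return (x_f8.to(torch.bfloat16) * x_s) @ (w_f8.to(torch.bfloat16) * w_s).T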
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
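# For reference, the quintic Newton-Schulz iteration that the fused kernels
# above implement, written as a minimal unfused PyTorch sketch (coefficients
# and normalization copied from newton_schulz_triton; Muon applies this to
# the momentum-mixed gradient of each 2D parameter):
import torch
def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # spectral norm <= 1
    for _ in range(steps):
        A = X @ X.mT           # ns_line_1
        B = b * A + c * A @ A  # ns_line_2 (A is symmetric, so A @ A == A @ A.mT)
        X = a * X + B @ X      # ns_line_3
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X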
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - 
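# A minimal writer for the shard layout that _load_data_shard above expects,
# reconstructed from the reader's checks (256 int32s of header -- magic
# 20240520, version 1, token count -- followed by uint16 tokens); the
# function name and path argument are illustrative only:
import numpy as np
def write_shard(path: str, tokens: np.ndarray):
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520      # magic number checked by the reader
    header[1] = 1             # version checked by the reader
    header[2] = len(tokens)   # claimed token count
    with open(path, "wb") as f:
        f.write(header.tobytes())                     # reader seeks past 256 * 4 bytes
        f.write(tokens.astype(np.uint16).tobytes())   # 2 bytes per token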
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
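# A toy walk-through of the EOS bookkeeping in EOSBatchFinder above, with
# made-up positions: a sequence may start right after any EOS token, and
# np.searchsorted finds the first EOS at or beyond a target position.
import numpy as np
eos_idx = np.array([9, 25, 41, 99])            # EOS positions within a shard
start = eos_idx[0] + 1                         # next document begins at 10
seq_len = 16
j = np.searchsorted(eos_idx, start + seq_len)  # first EOS >= 26 is index 2 (position 41)
next_start = eos_idx[j] + 1                    # the following sequence starts at 42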
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to receive new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc.
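# A shape-level sketch (made-up sizes, matching the model config below) of
# the per-head attention gate defined in CausalSelfAttention above: a tiny
# zero-initialized linear on the first 12 input channels yields one sigmoid
# gate per head, so every gate starts at 0.5 and a head can learn a
# context-dependent no-op by driving its gate toward 0.
import torch
B, T, num_heads, head_dim, gate_dim = 2, 8, 6, 128, 12
x = torch.randn(B, T, num_heads * head_dim)           # block input
y = torch.randn(B, T, num_heads, head_dim)            # per-head attention output
gate_w = torch.zeros(num_heads, gate_dim)             # zero init, as above
gates = torch.sigmoid(x[..., :gate_dim] @ gate_w.T)   # (B, T, num_heads), all 0.5 here
y = y * gates.view(B, T, num_heads, 1)                # scale each head's output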
- -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## -
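# Hand-checked example of the schedule above: at step 1356,
# x = 1356 / 1696 ~= 0.7995, which is past 1 - cooldown_frac = 0.55, so
# w = (1 - 0.7995) / 0.45 ~= 0.4455 and the LR multiplier is
# 0.4455 * 1.0 + 0.5545 * 0.1 ~= 0.50, while the window index is
# int(3 * 0.7995) = 2, selecting ws_schedule[2] = 11.
example_lr, example_ws = get_lr_and_ws(1356)  # -> (~0.50, 11)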
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 04:04:43 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 29C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 29C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 29C P0 109W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 33C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 31C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms -step:1/1695 train_time:517ms step_avg:517.24ms -step:2/1695 train_time:541ms step_avg:270.47ms -step:3/1695 train_time:610ms step_avg:203.20ms -step:4/1695 train_time:702ms step_avg:175.48ms -step:5/1695 train_time:795ms step_avg:158.96ms -step:6/1695 train_time:888ms step_avg:148.00ms -step:7/1695 train_time:982ms step_avg:140.22ms -step:8/1695 train_time:1075ms step_avg:134.38ms -step:9/1695 train_time:1169ms step_avg:129.83ms -step:10/1695 train_time:1262ms step_avg:126.22ms -step:11/1695 train_time:1356ms step_avg:123.26ms -step:12/1695 train_time:1454ms step_avg:121.13ms -step:13/1695 train_time:1553ms step_avg:119.44ms -step:14/1695 train_time:1650ms step_avg:117.83ms -step:15/1695 train_time:1744ms step_avg:116.30ms -step:16/1695 train_time:1838ms step_avg:114.86ms -step:17/1695 train_time:1931ms step_avg:113.60ms -step:18/1695 train_time:2025ms step_avg:112.49ms -step:19/1695 train_time:2118ms step_avg:111.45ms -step:20/1695 train_time:2211ms step_avg:110.56ms -step:21/1695 train_time:2306ms step_avg:109.79ms -step:22/1695 train_time:2401ms step_avg:109.13ms 
-step:23/1695 train_time:2497ms step_avg:108.55ms -step:24/1695 train_time:2592ms step_avg:108.02ms -step:25/1695 train_time:2689ms step_avg:107.56ms -step:26/1695 train_time:2785ms step_avg:107.10ms -step:27/1695 train_time:2879ms step_avg:106.62ms -step:28/1695 train_time:2973ms step_avg:106.16ms -step:29/1695 train_time:3067ms step_avg:105.75ms -step:30/1695 train_time:3160ms step_avg:105.34ms -step:31/1695 train_time:3253ms step_avg:104.95ms -step:32/1695 train_time:3349ms step_avg:104.66ms -step:33/1695 train_time:3445ms step_avg:104.39ms -step:34/1695 train_time:3540ms step_avg:104.10ms -step:35/1695 train_time:3634ms step_avg:103.84ms -step:36/1695 train_time:3731ms step_avg:103.64ms -step:37/1695 train_time:3827ms step_avg:103.44ms -step:38/1695 train_time:3922ms step_avg:103.22ms -step:39/1695 train_time:4016ms step_avg:102.98ms -step:40/1695 train_time:4111ms step_avg:102.77ms -step:41/1695 train_time:4205ms step_avg:102.56ms -step:42/1695 train_time:4299ms step_avg:102.35ms -step:43/1695 train_time:4393ms step_avg:102.16ms -step:44/1695 train_time:4488ms step_avg:102.01ms -step:45/1695 train_time:4584ms step_avg:101.86ms -step:46/1695 train_time:4678ms step_avg:101.68ms -step:47/1695 train_time:4773ms step_avg:101.54ms -step:48/1695 train_time:4867ms step_avg:101.40ms -step:49/1695 train_time:4962ms step_avg:101.27ms -step:50/1695 train_time:5056ms step_avg:101.12ms -step:51/1695 train_time:5150ms step_avg:100.98ms -step:52/1695 train_time:5244ms step_avg:100.85ms -step:53/1695 train_time:5338ms step_avg:100.72ms -step:54/1695 train_time:5432ms step_avg:100.60ms -step:55/1695 train_time:5528ms step_avg:100.51ms -step:56/1695 train_time:5623ms step_avg:100.42ms -step:57/1695 train_time:5717ms step_avg:100.31ms -step:58/1695 train_time:5812ms step_avg:100.20ms -step:59/1695 train_time:5908ms step_avg:100.13ms -step:60/1695 train_time:6003ms step_avg:100.04ms -step:61/1695 train_time:6096ms step_avg:99.94ms -step:62/1695 train_time:6190ms step_avg:99.84ms -step:63/1695 train_time:6284ms step_avg:99.75ms -step:64/1695 train_time:6378ms step_avg:99.66ms -step:65/1695 train_time:6472ms step_avg:99.57ms -step:66/1695 train_time:6568ms step_avg:99.51ms -step:67/1695 train_time:6663ms step_avg:99.45ms -step:68/1695 train_time:6756ms step_avg:99.36ms -step:69/1695 train_time:6851ms step_avg:99.29ms -step:70/1695 train_time:6946ms step_avg:99.23ms -step:71/1695 train_time:7040ms step_avg:99.16ms -step:72/1695 train_time:7133ms step_avg:99.07ms -step:73/1695 train_time:7228ms step_avg:99.02ms -step:74/1695 train_time:7324ms step_avg:98.97ms -step:75/1695 train_time:7418ms step_avg:98.90ms -step:76/1695 train_time:7512ms step_avg:98.84ms -step:77/1695 train_time:7607ms step_avg:98.80ms -step:78/1695 train_time:7702ms step_avg:98.75ms -step:79/1695 train_time:7796ms step_avg:98.69ms -step:80/1695 train_time:7891ms step_avg:98.63ms -step:81/1695 train_time:7985ms step_avg:98.58ms -step:82/1695 train_time:8079ms step_avg:98.52ms -step:83/1695 train_time:8172ms step_avg:98.46ms -step:84/1695 train_time:8268ms step_avg:98.42ms -step:85/1695 train_time:8363ms step_avg:98.39ms -step:86/1695 train_time:8457ms step_avg:98.33ms -step:87/1695 train_time:8551ms step_avg:98.29ms -step:88/1695 train_time:8647ms step_avg:98.26ms -step:89/1695 train_time:8742ms step_avg:98.22ms -step:90/1695 train_time:8835ms step_avg:98.17ms -step:91/1695 train_time:8931ms step_avg:98.14ms -step:92/1695 train_time:9024ms step_avg:98.09ms -step:93/1695 train_time:9118ms step_avg:98.04ms -step:94/1695 train_time:9211ms 
step_avg:97.99ms -step:95/1695 train_time:9305ms step_avg:97.95ms -step:96/1695 train_time:9399ms step_avg:97.91ms -step:97/1695 train_time:9494ms step_avg:97.87ms -step:98/1695 train_time:9588ms step_avg:97.84ms -step:99/1695 train_time:9684ms step_avg:97.81ms -step:100/1695 train_time:9777ms step_avg:97.77ms -step:101/1695 train_time:9872ms step_avg:97.75ms -step:102/1695 train_time:9968ms step_avg:97.73ms -step:103/1695 train_time:10063ms step_avg:97.70ms -step:104/1695 train_time:10157ms step_avg:97.66ms -step:105/1695 train_time:10251ms step_avg:97.63ms -step:106/1695 train_time:10345ms step_avg:97.59ms -step:107/1695 train_time:10439ms step_avg:97.56ms -step:108/1695 train_time:10533ms step_avg:97.52ms -step:109/1695 train_time:10627ms step_avg:97.50ms -step:110/1695 train_time:10721ms step_avg:97.46ms -step:111/1695 train_time:10814ms step_avg:97.42ms -step:112/1695 train_time:10909ms step_avg:97.40ms -step:113/1695 train_time:11004ms step_avg:97.38ms -step:114/1695 train_time:11098ms step_avg:97.35ms -step:115/1695 train_time:11192ms step_avg:97.33ms -step:116/1695 train_time:11288ms step_avg:97.31ms -step:117/1695 train_time:11381ms step_avg:97.28ms -step:118/1695 train_time:11476ms step_avg:97.25ms -step:119/1695 train_time:11570ms step_avg:97.23ms -step:120/1695 train_time:11664ms step_avg:97.20ms -step:121/1695 train_time:11758ms step_avg:97.17ms -step:122/1695 train_time:11852ms step_avg:97.15ms -step:123/1695 train_time:11947ms step_avg:97.13ms -step:124/1695 train_time:12042ms step_avg:97.11ms -step:125/1695 train_time:12135ms step_avg:97.08ms -step:125/1695 val_loss:4.3142 train_time:12227ms step_avg:97.82ms -step:126/1695 train_time:12252ms step_avg:97.24ms -step:127/1695 train_time:12329ms step_avg:97.08ms -step:128/1695 train_time:12428ms step_avg:97.09ms -step:129/1695 train_time:12522ms step_avg:97.07ms -step:130/1695 train_time:12616ms step_avg:97.05ms -step:131/1695 train_time:12710ms step_avg:97.02ms -step:132/1695 train_time:12803ms step_avg:96.99ms -step:133/1695 train_time:12896ms step_avg:96.96ms -step:134/1695 train_time:12990ms step_avg:96.94ms -step:135/1695 train_time:13083ms step_avg:96.91ms -step:136/1695 train_time:13177ms step_avg:96.89ms -step:137/1695 train_time:13273ms step_avg:96.88ms -step:138/1695 train_time:13370ms step_avg:96.88ms -step:139/1695 train_time:13466ms step_avg:96.88ms -step:140/1695 train_time:13560ms step_avg:96.86ms -step:141/1695 train_time:13654ms step_avg:96.84ms -step:142/1695 train_time:13748ms step_avg:96.82ms -step:143/1695 train_time:13841ms step_avg:96.79ms -step:144/1695 train_time:13935ms step_avg:96.77ms -step:145/1695 train_time:14029ms step_avg:96.75ms -step:146/1695 train_time:14122ms step_avg:96.72ms -step:147/1695 train_time:14215ms step_avg:96.70ms -step:148/1695 train_time:14312ms step_avg:96.71ms -step:149/1695 train_time:14409ms step_avg:96.70ms -step:150/1695 train_time:14503ms step_avg:96.69ms -step:151/1695 train_time:14597ms step_avg:96.67ms -step:152/1695 train_time:14692ms step_avg:96.66ms -step:153/1695 train_time:14787ms step_avg:96.65ms -step:154/1695 train_time:14881ms step_avg:96.63ms -step:155/1695 train_time:14974ms step_avg:96.61ms -step:156/1695 train_time:15068ms step_avg:96.59ms -step:157/1695 train_time:15161ms step_avg:96.56ms -step:158/1695 train_time:15255ms step_avg:96.55ms -step:159/1695 train_time:15350ms step_avg:96.54ms -step:160/1695 train_time:15446ms step_avg:96.53ms -step:161/1695 train_time:15540ms step_avg:96.52ms -step:162/1695 train_time:15635ms step_avg:96.51ms -step:163/1695 
train_time:15729ms step_avg:96.50ms
+step:164/1695 train_time:15824ms step_avg:96.49ms
+step:165/1695 train_time:15917ms step_avg:96.47ms
+step:166/1695 train_time:16012ms step_avg:96.46ms
+step:167/1695 train_time:16106ms step_avg:96.44ms
+step:168/1695 train_time:16199ms step_avg:96.42ms
+step:169/1695 train_time:16293ms step_avg:96.41ms
+step:170/1695 train_time:16388ms step_avg:96.40ms
+step:171/1695 train_time:16482ms step_avg:96.39ms
+step:172/1695 train_time:16576ms step_avg:96.37ms
+step:173/1695 train_time:16958ms step_avg:98.02ms
+step:174/1695 train_time:17044ms step_avg:97.95ms
+step:175/1695 train_time:17136ms step_avg:97.92ms
+step:176/1695 train_time:17228ms step_avg:97.89ms
+step:177/1695 train_time:17321ms step_avg:97.86ms
+step:178/1695 train_time:17414ms step_avg:97.83ms
+step:179/1695 train_time:17507ms step_avg:97.80ms
+step:180/1695 train_time:17599ms step_avg:97.77ms
+step:181/1695 train_time:17693ms step_avg:97.75ms
+step:182/1695 train_time:17787ms step_avg:97.73ms
+step:183/1695 train_time:17880ms step_avg:97.71ms
+step:184/1695 train_time:17978ms step_avg:97.71ms
+step:185/1695 train_time:18075ms step_avg:97.71ms
+step:186/1695 train_time:18170ms step_avg:97.69ms
+step:187/1695 train_time:18264ms step_avg:97.67ms
+step:188/1695 train_time:18357ms step_avg:97.64ms
+step:189/1695 train_time:18451ms step_avg:97.62ms
+step:190/1695 train_time:18545ms step_avg:97.60ms
+step:191/1695 train_time:18637ms step_avg:97.58ms
+step:192/1695 train_time:18731ms step_avg:97.56ms
+step:193/1695 train_time:18825ms step_avg:97.54ms
+step:194/1695 train_time:18919ms step_avg:97.52ms
+step:195/1695 train_time:19016ms step_avg:97.52ms
+step:196/1695 train_time:19111ms step_avg:97.51ms
+step:197/1695 train_time:19207ms step_avg:97.50ms
+step:198/1695 train_time:19300ms step_avg:97.48ms
+step:199/1695 train_time:19395ms step_avg:97.46ms
+step:200/1695 train_time:19489ms step_avg:97.44ms
+step:201/1695 train_time:19582ms step_avg:97.42ms
+step:202/1695 train_time:19676ms step_avg:97.40ms
+step:203/1695 train_time:19769ms step_avg:97.39ms
+step:204/1695 train_time:19863ms step_avg:97.37ms
+step:205/1695 train_time:19957ms step_avg:97.35ms
+step:206/1695 train_time:20052ms step_avg:97.34ms
+step:207/1695 train_time:20148ms step_avg:97.33ms
+step:208/1695 train_time:20242ms step_avg:97.32ms
+step:209/1695 train_time:20336ms step_avg:97.30ms
+step:210/1695 train_time:20430ms step_avg:97.29ms
+step:211/1695 train_time:20523ms step_avg:97.27ms
+step:212/1695 train_time:20617ms step_avg:97.25ms
+step:213/1695 train_time:20710ms step_avg:97.23ms
+step:214/1695 train_time:20804ms step_avg:97.21ms
+step:215/1695 train_time:20897ms step_avg:97.20ms
+step:216/1695 train_time:20991ms step_avg:97.18ms
+step:217/1695 train_time:21086ms step_avg:97.17ms
+step:218/1695 train_time:21180ms step_avg:97.15ms
+step:219/1695 train_time:21274ms step_avg:97.14ms
+step:220/1695 train_time:21369ms step_avg:97.13ms
+step:221/1695 train_time:21463ms step_avg:97.12ms
+step:222/1695 train_time:21557ms step_avg:97.10ms
+step:223/1695 train_time:21651ms step_avg:97.09ms
+step:224/1695 train_time:21745ms step_avg:97.07ms
+step:225/1695 train_time:21838ms step_avg:97.06ms
+step:226/1695 train_time:21933ms step_avg:97.05ms
+step:227/1695 train_time:22028ms step_avg:97.04ms
+step:228/1695 train_time:22122ms step_avg:97.03ms
+step:229/1695 train_time:22216ms step_avg:97.01ms
+step:230/1695 train_time:22311ms step_avg:97.01ms
+step:231/1695 train_time:22406ms step_avg:97.00ms
+step:232/1695 train_time:22500ms step_avg:96.98ms
+step:233/1695 train_time:22593ms step_avg:96.97ms
+step:234/1695 train_time:22688ms step_avg:96.96ms
+step:235/1695 train_time:22781ms step_avg:96.94ms
+step:236/1695 train_time:22874ms step_avg:96.93ms
+step:237/1695 train_time:22969ms step_avg:96.91ms
+step:238/1695 train_time:23062ms step_avg:96.90ms
+step:239/1695 train_time:23155ms step_avg:96.88ms
+step:240/1695 train_time:23249ms step_avg:96.87ms
+step:241/1695 train_time:23343ms step_avg:96.86ms
+step:242/1695 train_time:23437ms step_avg:96.85ms
+step:243/1695 train_time:23531ms step_avg:96.84ms
+step:244/1695 train_time:23625ms step_avg:96.82ms
+step:245/1695 train_time:23719ms step_avg:96.81ms
+step:246/1695 train_time:23813ms step_avg:96.80ms
+step:247/1695 train_time:23908ms step_avg:96.79ms
+step:248/1695 train_time:24002ms step_avg:96.78ms
+step:249/1695 train_time:24096ms step_avg:96.77ms
+step:250/1695 train_time:24190ms step_avg:96.76ms
+step:250/1695 val_loss:3.9738 train_time:24282ms step_avg:97.13ms
+step:251/1695 train_time:24306ms step_avg:96.84ms
+step:252/1695 train_time:24385ms step_avg:96.77ms
+step:253/1695 train_time:24484ms step_avg:96.78ms
+step:254/1695 train_time:24579ms step_avg:96.77ms
+step:255/1695 train_time:24672ms step_avg:96.75ms
+step:256/1695 train_time:24766ms step_avg:96.74ms
+step:257/1695 train_time:24858ms step_avg:96.73ms
+step:258/1695 train_time:24951ms step_avg:96.71ms
+step:259/1695 train_time:25044ms step_avg:96.69ms
+step:260/1695 train_time:25137ms step_avg:96.68ms
+step:261/1695 train_time:25231ms step_avg:96.67ms
+step:262/1695 train_time:25325ms step_avg:96.66ms
+step:263/1695 train_time:25423ms step_avg:96.66ms
+step:264/1695 train_time:25519ms step_avg:96.66ms
+step:265/1695 train_time:25614ms step_avg:96.66ms
+step:266/1695 train_time:25708ms step_avg:96.65ms
+step:267/1695 train_time:25801ms step_avg:96.63ms
+step:268/1695 train_time:25895ms step_avg:96.62ms
+step:269/1695 train_time:25988ms step_avg:96.61ms
+step:270/1695 train_time:26081ms step_avg:96.60ms
+step:271/1695 train_time:26174ms step_avg:96.58ms
+step:272/1695 train_time:26267ms step_avg:96.57ms
+step:273/1695 train_time:26362ms step_avg:96.57ms
+step:274/1695 train_time:26458ms step_avg:96.56ms
+step:275/1695 train_time:26553ms step_avg:96.56ms
+step:276/1695 train_time:26648ms step_avg:96.55ms
+step:277/1695 train_time:26741ms step_avg:96.54ms
+step:278/1695 train_time:26836ms step_avg:96.53ms
+step:279/1695 train_time:26931ms step_avg:96.53ms
+step:280/1695 train_time:27023ms step_avg:96.51ms
+step:281/1695 train_time:27117ms step_avg:96.50ms
+step:282/1695 train_time:27210ms step_avg:96.49ms
+step:283/1695 train_time:27303ms step_avg:96.48ms
+step:284/1695 train_time:27398ms step_avg:96.47ms
+step:285/1695 train_time:27492ms step_avg:96.46ms
+step:286/1695 train_time:27587ms step_avg:96.46ms
+step:287/1695 train_time:27681ms step_avg:96.45ms
+step:288/1695 train_time:27775ms step_avg:96.44ms
+step:289/1695 train_time:27870ms step_avg:96.44ms
+step:290/1695 train_time:27963ms step_avg:96.43ms
+step:291/1695 train_time:28057ms step_avg:96.42ms
+step:292/1695 train_time:28151ms step_avg:96.41ms
+step:293/1695 train_time:28244ms step_avg:96.39ms
+step:294/1695 train_time:28338ms step_avg:96.39ms
+step:295/1695 train_time:28432ms step_avg:96.38ms
+step:296/1695 train_time:28526ms step_avg:96.37ms
+step:297/1695 train_time:28620ms step_avg:96.36ms
+step:298/1695 train_time:28714ms step_avg:96.35ms
+step:299/1695 train_time:28808ms step_avg:96.35ms
+step:300/1695 train_time:28901ms step_avg:96.34ms
+step:301/1695 train_time:28995ms step_avg:96.33ms
+step:302/1695 train_time:29089ms step_avg:96.32ms
+step:303/1695 train_time:29182ms step_avg:96.31ms
+step:304/1695 train_time:29276ms step_avg:96.30ms
+step:305/1695 train_time:29370ms step_avg:96.30ms
+step:306/1695 train_time:29464ms step_avg:96.29ms
+step:307/1695 train_time:29558ms step_avg:96.28ms
+step:308/1695 train_time:29653ms step_avg:96.28ms
+step:309/1695 train_time:29748ms step_avg:96.27ms
+step:310/1695 train_time:29841ms step_avg:96.26ms
+step:311/1695 train_time:29936ms step_avg:96.26ms
+step:312/1695 train_time:30031ms step_avg:96.25ms
+step:313/1695 train_time:30124ms step_avg:96.24ms
+step:314/1695 train_time:30218ms step_avg:96.24ms
+step:315/1695 train_time:30312ms step_avg:96.23ms
+step:316/1695 train_time:30405ms step_avg:96.22ms
+step:317/1695 train_time:30500ms step_avg:96.21ms
+step:318/1695 train_time:30595ms step_avg:96.21ms
+step:319/1695 train_time:30689ms step_avg:96.20ms
+step:320/1695 train_time:30783ms step_avg:96.20ms
+step:321/1695 train_time:30877ms step_avg:96.19ms
+step:322/1695 train_time:30972ms step_avg:96.19ms
+step:323/1695 train_time:31066ms step_avg:96.18ms
+step:324/1695 train_time:31159ms step_avg:96.17ms
+step:325/1695 train_time:31252ms step_avg:96.16ms
+step:326/1695 train_time:31346ms step_avg:96.15ms
+step:327/1695 train_time:31440ms step_avg:96.15ms
+step:328/1695 train_time:31535ms step_avg:96.14ms
+step:329/1695 train_time:31630ms step_avg:96.14ms
+step:330/1695 train_time:31723ms step_avg:96.13ms
+step:331/1695 train_time:31817ms step_avg:96.12ms
+step:332/1695 train_time:31912ms step_avg:96.12ms
+step:333/1695 train_time:32006ms step_avg:96.11ms
+step:334/1695 train_time:32099ms step_avg:96.11ms
+step:335/1695 train_time:32194ms step_avg:96.10ms
+step:336/1695 train_time:32288ms step_avg:96.09ms
+step:337/1695 train_time:32381ms step_avg:96.09ms
+step:338/1695 train_time:32475ms step_avg:96.08ms
+step:339/1695 train_time:32570ms step_avg:96.08ms
+step:340/1695 train_time:32664ms step_avg:96.07ms
+step:341/1695 train_time:32758ms step_avg:96.06ms
+step:342/1695 train_time:32852ms step_avg:96.06ms
+step:343/1695 train_time:32947ms step_avg:96.05ms
+step:344/1695 train_time:33040ms step_avg:96.05ms
+step:345/1695 train_time:33366ms step_avg:96.71ms
+step:346/1695 train_time:33470ms step_avg:96.73ms
+step:347/1695 train_time:33562ms step_avg:96.72ms
+step:348/1695 train_time:33655ms step_avg:96.71ms
+step:349/1695 train_time:33748ms step_avg:96.70ms
+step:350/1695 train_time:33840ms step_avg:96.69ms
+step:351/1695 train_time:33933ms step_avg:96.68ms
+step:352/1695 train_time:34026ms step_avg:96.66ms
+step:353/1695 train_time:34119ms step_avg:96.65ms
+step:354/1695 train_time:34212ms step_avg:96.65ms
+step:355/1695 train_time:34310ms step_avg:96.65ms
+step:356/1695 train_time:34408ms step_avg:96.65ms
+step:357/1695 train_time:34502ms step_avg:96.65ms
+step:358/1695 train_time:34597ms step_avg:96.64ms
+step:359/1695 train_time:34690ms step_avg:96.63ms
+step:360/1695 train_time:34783ms step_avg:96.62ms
+step:361/1695 train_time:34876ms step_avg:96.61ms
+step:362/1695 train_time:34970ms step_avg:96.60ms
+step:363/1695 train_time:35062ms step_avg:96.59ms
+step:364/1695 train_time:35155ms step_avg:96.58ms
+step:365/1695 train_time:35250ms step_avg:96.58ms
+step:366/1695 train_time:35345ms step_avg:96.57ms
+step:367/1695 train_time:35440ms step_avg:96.57ms
+step:368/1695 train_time:35535ms step_avg:96.56ms
+step:369/1695 train_time:35630ms step_avg:96.56ms
+step:370/1695 train_time:35723ms step_avg:96.55ms
+step:371/1695 train_time:35817ms step_avg:96.54ms
+step:372/1695 train_time:35911ms step_avg:96.53ms
+step:373/1695 train_time:36004ms step_avg:96.52ms
+step:374/1695 train_time:36097ms step_avg:96.52ms
+step:375/1695 train_time:36191ms step_avg:96.51ms
+step:375/1695 val_loss:3.8151 train_time:36283ms step_avg:96.76ms
+step:376/1695 train_time:36310ms step_avg:96.57ms
+step:377/1695 train_time:36385ms step_avg:96.51ms
+step:378/1695 train_time:36485ms step_avg:96.52ms
+step:379/1695 train_time:36582ms step_avg:96.52ms
+step:380/1695 train_time:36675ms step_avg:96.51ms
+step:381/1695 train_time:36768ms step_avg:96.50ms
+step:382/1695 train_time:36861ms step_avg:96.50ms
+step:383/1695 train_time:36955ms step_avg:96.49ms
+step:384/1695 train_time:37047ms step_avg:96.48ms
+step:385/1695 train_time:37140ms step_avg:96.47ms
+step:386/1695 train_time:37233ms step_avg:96.46ms
+step:387/1695 train_time:37328ms step_avg:96.46ms
+step:388/1695 train_time:37424ms step_avg:96.45ms
+step:389/1695 train_time:37521ms step_avg:96.46ms
+step:390/1695 train_time:37617ms step_avg:96.45ms
+step:391/1695 train_time:37711ms step_avg:96.45ms
+step:392/1695 train_time:37804ms step_avg:96.44ms
+step:393/1695 train_time:37897ms step_avg:96.43ms
+step:394/1695 train_time:37990ms step_avg:96.42ms
+step:395/1695 train_time:38083ms step_avg:96.41ms
+step:396/1695 train_time:38176ms step_avg:96.40ms
+step:397/1695 train_time:38270ms step_avg:96.40ms
+step:398/1695 train_time:38364ms step_avg:96.39ms
+step:399/1695 train_time:38460ms step_avg:96.39ms
+step:400/1695 train_time:38556ms step_avg:96.39ms
+step:401/1695 train_time:38650ms step_avg:96.38ms
+step:402/1695 train_time:38744ms step_avg:96.38ms
+step:403/1695 train_time:38837ms step_avg:96.37ms
+step:404/1695 train_time:38930ms step_avg:96.36ms
+step:405/1695 train_time:39024ms step_avg:96.35ms
+step:406/1695 train_time:39118ms step_avg:96.35ms
+step:407/1695 train_time:39211ms step_avg:96.34ms
+step:408/1695 train_time:39304ms step_avg:96.33ms
+step:409/1695 train_time:39398ms step_avg:96.33ms
+step:410/1695 train_time:39494ms step_avg:96.33ms
+step:411/1695 train_time:39588ms step_avg:96.32ms
+step:412/1695 train_time:39683ms step_avg:96.32ms
+step:413/1695 train_time:39776ms step_avg:96.31ms
+step:414/1695 train_time:39870ms step_avg:96.30ms
+step:415/1695 train_time:39963ms step_avg:96.30ms
+step:416/1695 train_time:40058ms step_avg:96.29ms
+step:417/1695 train_time:40152ms step_avg:96.29ms
+step:418/1695 train_time:40246ms step_avg:96.28ms
+step:419/1695 train_time:40339ms step_avg:96.27ms
+step:420/1695 train_time:40433ms step_avg:96.27ms
+step:421/1695 train_time:40528ms step_avg:96.27ms
+step:422/1695 train_time:40622ms step_avg:96.26ms
+step:423/1695 train_time:40716ms step_avg:96.26ms
+step:424/1695 train_time:40810ms step_avg:96.25ms
+step:425/1695 train_time:40903ms step_avg:96.24ms
+step:426/1695 train_time:40997ms step_avg:96.24ms
+step:427/1695 train_time:41091ms step_avg:96.23ms
+step:428/1695 train_time:41184ms step_avg:96.22ms
+step:429/1695 train_time:41278ms step_avg:96.22ms
+step:430/1695 train_time:41372ms step_avg:96.21ms
+step:431/1695 train_time:41466ms step_avg:96.21ms
+step:432/1695 train_time:41561ms step_avg:96.21ms
+step:433/1695 train_time:41656ms step_avg:96.20ms
+step:434/1695 train_time:41750ms step_avg:96.20ms
+step:435/1695 train_time:41843ms step_avg:96.19ms
+step:436/1695 train_time:41938ms step_avg:96.19ms
+step:437/1695 train_time:42032ms step_avg:96.18ms
+step:438/1695 train_time:42126ms step_avg:96.18ms
+step:439/1695 train_time:42220ms step_avg:96.17ms
+step:440/1695 train_time:42314ms step_avg:96.17ms
+step:441/1695 train_time:42408ms step_avg:96.16ms
+step:442/1695 train_time:42501ms step_avg:96.16ms
+step:443/1695 train_time:42596ms step_avg:96.15ms
+step:444/1695 train_time:42691ms step_avg:96.15ms
+step:445/1695 train_time:42784ms step_avg:96.14ms
+step:446/1695 train_time:42878ms step_avg:96.14ms
+step:447/1695 train_time:42972ms step_avg:96.13ms
+step:448/1695 train_time:43066ms step_avg:96.13ms
+step:449/1695 train_time:43160ms step_avg:96.12ms
+step:450/1695 train_time:43255ms step_avg:96.12ms
+step:451/1695 train_time:43348ms step_avg:96.12ms
+step:452/1695 train_time:43443ms step_avg:96.11ms
+step:453/1695 train_time:43537ms step_avg:96.11ms
+step:454/1695 train_time:43632ms step_avg:96.11ms
+step:455/1695 train_time:43725ms step_avg:96.10ms
+step:456/1695 train_time:43819ms step_avg:96.09ms
+step:457/1695 train_time:43913ms step_avg:96.09ms
+step:458/1695 train_time:44006ms step_avg:96.08ms
+step:459/1695 train_time:44100ms step_avg:96.08ms
+step:460/1695 train_time:44194ms step_avg:96.07ms
+step:461/1695 train_time:44287ms step_avg:96.07ms
+step:462/1695 train_time:44381ms step_avg:96.06ms
+step:463/1695 train_time:44476ms step_avg:96.06ms
+step:464/1695 train_time:44569ms step_avg:96.05ms
+step:465/1695 train_time:44663ms step_avg:96.05ms
+step:466/1695 train_time:44758ms step_avg:96.05ms
+step:467/1695 train_time:44853ms step_avg:96.04ms
+step:468/1695 train_time:44947ms step_avg:96.04ms
+step:469/1695 train_time:45040ms step_avg:96.03ms
+step:470/1695 train_time:45134ms step_avg:96.03ms
+step:471/1695 train_time:45229ms step_avg:96.03ms
+step:472/1695 train_time:45323ms step_avg:96.02ms
+step:473/1695 train_time:45418ms step_avg:96.02ms
+step:474/1695 train_time:45512ms step_avg:96.02ms
+step:475/1695 train_time:45605ms step_avg:96.01ms
+step:476/1695 train_time:45699ms step_avg:96.01ms
+step:477/1695 train_time:45795ms step_avg:96.01ms
+step:478/1695 train_time:45888ms step_avg:96.00ms
+step:479/1695 train_time:45982ms step_avg:96.00ms
+step:480/1695 train_time:46076ms step_avg:95.99ms
+step:481/1695 train_time:46170ms step_avg:95.99ms
+step:482/1695 train_time:46264ms step_avg:95.98ms
+step:483/1695 train_time:46358ms step_avg:95.98ms
+step:484/1695 train_time:46454ms step_avg:95.98ms
+step:485/1695 train_time:46547ms step_avg:95.97ms
+step:486/1695 train_time:46641ms step_avg:95.97ms
+step:487/1695 train_time:46735ms step_avg:95.97ms
+step:488/1695 train_time:46830ms step_avg:95.96ms
+step:489/1695 train_time:46924ms step_avg:95.96ms
+step:490/1695 train_time:47018ms step_avg:95.96ms
+step:491/1695 train_time:47112ms step_avg:95.95ms
+step:492/1695 train_time:47205ms step_avg:95.95ms
+step:493/1695 train_time:47299ms step_avg:95.94ms
+step:494/1695 train_time:47393ms step_avg:95.94ms
+step:495/1695 train_time:47487ms step_avg:95.93ms
+step:496/1695 train_time:47581ms step_avg:95.93ms
+step:497/1695 train_time:47675ms step_avg:95.93ms
+step:498/1695 train_time:47768ms step_avg:95.92ms
+step:499/1695 train_time:47862ms step_avg:95.92ms
+step:500/1695 train_time:47957ms step_avg:95.91ms
+step:500/1695 val_loss:3.7158 train_time:48050ms step_avg:96.10ms
+step:501/1695 train_time:48074ms step_avg:95.96ms
+step:502/1695 train_time:48155ms step_avg:95.93ms
+step:503/1695 train_time:48257ms step_avg:95.94ms
+step:504/1695 train_time:48350ms step_avg:95.93ms
+step:505/1695 train_time:48444ms step_avg:95.93ms
+step:506/1695 train_time:48537ms step_avg:95.92ms
+step:507/1695 train_time:48630ms step_avg:95.92ms
+step:508/1695 train_time:48723ms step_avg:95.91ms
+step:509/1695 train_time:48816ms step_avg:95.91ms
+step:510/1695 train_time:48909ms step_avg:95.90ms
+step:511/1695 train_time:49002ms step_avg:95.89ms
+step:512/1695 train_time:49098ms step_avg:95.89ms
+step:513/1695 train_time:49195ms step_avg:95.90ms
+step:514/1695 train_time:49290ms step_avg:95.90ms
+step:515/1695 train_time:49386ms step_avg:95.89ms
+step:516/1695 train_time:49480ms step_avg:95.89ms
+step:517/1695 train_time:49573ms step_avg:95.89ms
+step:518/1695 train_time:49666ms step_avg:95.88ms
+step:519/1695 train_time:49996ms step_avg:96.33ms
+step:520/1695 train_time:50189ms step_avg:96.52ms
+step:521/1695 train_time:50281ms step_avg:96.51ms
+step:522/1695 train_time:50374ms step_avg:96.50ms
+step:523/1695 train_time:50467ms step_avg:96.50ms
+step:524/1695 train_time:50560ms step_avg:96.49ms
+step:525/1695 train_time:50653ms step_avg:96.48ms
+step:526/1695 train_time:50746ms step_avg:96.48ms
+step:527/1695 train_time:50840ms step_avg:96.47ms
+step:528/1695 train_time:50933ms step_avg:96.46ms
+step:529/1695 train_time:51027ms step_avg:96.46ms
+step:530/1695 train_time:51125ms step_avg:96.46ms
+step:531/1695 train_time:51222ms step_avg:96.46ms
+step:532/1695 train_time:51316ms step_avg:96.46ms
+step:533/1695 train_time:51409ms step_avg:96.45ms
+step:534/1695 train_time:51504ms step_avg:96.45ms
+step:535/1695 train_time:51596ms step_avg:96.44ms
+step:536/1695 train_time:51689ms step_avg:96.43ms
+step:537/1695 train_time:51782ms step_avg:96.43ms
+step:538/1695 train_time:51875ms step_avg:96.42ms
+step:539/1695 train_time:51969ms step_avg:96.42ms
+step:540/1695 train_time:52063ms step_avg:96.41ms
+step:541/1695 train_time:52158ms step_avg:96.41ms
+step:542/1695 train_time:52252ms step_avg:96.41ms
+step:543/1695 train_time:52348ms step_avg:96.40ms
+step:544/1695 train_time:52442ms step_avg:96.40ms
+step:545/1695 train_time:52536ms step_avg:96.40ms
+step:546/1695 train_time:52629ms step_avg:96.39ms
+step:547/1695 train_time:52723ms step_avg:96.39ms
+step:548/1695 train_time:52816ms step_avg:96.38ms
+step:549/1695 train_time:52909ms step_avg:96.37ms
+step:550/1695 train_time:53003ms step_avg:96.37ms
+step:551/1695 train_time:53097ms step_avg:96.36ms
+step:552/1695 train_time:53192ms step_avg:96.36ms
+step:553/1695 train_time:53287ms step_avg:96.36ms
+step:554/1695 train_time:53382ms step_avg:96.36ms
+step:555/1695 train_time:53476ms step_avg:96.35ms
+step:556/1695 train_time:53569ms step_avg:96.35ms
+step:557/1695 train_time:53663ms step_avg:96.34ms
+step:558/1695 train_time:53756ms step_avg:96.34ms
+step:559/1695 train_time:53849ms step_avg:96.33ms
+step:560/1695 train_time:53943ms step_avg:96.33ms
+step:561/1695 train_time:54037ms step_avg:96.32ms
+step:562/1695 train_time:54131ms step_avg:96.32ms
+step:563/1695 train_time:54226ms step_avg:96.32ms
+step:564/1695 train_time:54321ms step_avg:96.31ms
+step:565/1695 train_time:54414ms step_avg:96.31ms
+step:566/1695 train_time:54508ms step_avg:96.30ms
+step:567/1695 train_time:54602ms step_avg:96.30ms
+step:568/1695 train_time:54697ms step_avg:96.30ms
+step:569/1695 train_time:54792ms step_avg:96.30ms
+step:570/1695 train_time:54889ms step_avg:96.30ms
+step:571/1695 train_time:54986ms step_avg:96.30ms
+step:572/1695 train_time:55083ms step_avg:96.30ms
+step:573/1695 train_time:55180ms step_avg:96.30ms
+step:574/1695 train_time:55276ms step_avg:96.30ms
+step:575/1695 train_time:55373ms step_avg:96.30ms
+step:576/1695 train_time:55469ms step_avg:96.30ms
+step:577/1695 train_time:55565ms step_avg:96.30ms
+step:578/1695 train_time:55662ms step_avg:96.30ms
+step:579/1695 train_time:55757ms step_avg:96.30ms
+step:580/1695 train_time:55852ms step_avg:96.30ms
+step:581/1695 train_time:55948ms step_avg:96.30ms
+step:582/1695 train_time:56044ms step_avg:96.30ms
+step:583/1695 train_time:56140ms step_avg:96.30ms
+step:584/1695 train_time:56236ms step_avg:96.29ms
+step:585/1695 train_time:56331ms step_avg:96.29ms
+step:586/1695 train_time:56427ms step_avg:96.29ms
+step:587/1695 train_time:56523ms step_avg:96.29ms
+step:588/1695 train_time:56621ms step_avg:96.29ms
+step:589/1695 train_time:56716ms step_avg:96.29ms
+step:590/1695 train_time:56812ms step_avg:96.29ms
+step:591/1695 train_time:56908ms step_avg:96.29ms
+step:592/1695 train_time:57005ms step_avg:96.29ms
+step:593/1695 train_time:57101ms step_avg:96.29ms
+step:594/1695 train_time:57196ms step_avg:96.29ms
+step:595/1695 train_time:57291ms step_avg:96.29ms
+step:596/1695 train_time:57388ms step_avg:96.29ms
+step:597/1695 train_time:57484ms step_avg:96.29ms
+step:598/1695 train_time:57581ms step_avg:96.29ms
+step:599/1695 train_time:57678ms step_avg:96.29ms
+step:600/1695 train_time:57773ms step_avg:96.29ms
+step:601/1695 train_time:57869ms step_avg:96.29ms
+step:602/1695 train_time:57964ms step_avg:96.29ms
+step:603/1695 train_time:58060ms step_avg:96.28ms
+step:604/1695 train_time:58155ms step_avg:96.28ms
+step:605/1695 train_time:58251ms step_avg:96.28ms
+step:606/1695 train_time:58348ms step_avg:96.28ms
+step:607/1695 train_time:58444ms step_avg:96.28ms
+step:608/1695 train_time:58541ms step_avg:96.28ms
+step:609/1695 train_time:58636ms step_avg:96.28ms
+step:610/1695 train_time:58732ms step_avg:96.28ms
+step:611/1695 train_time:58827ms step_avg:96.28ms
+step:612/1695 train_time:58923ms step_avg:96.28ms
+step:613/1695 train_time:59020ms step_avg:96.28ms
+step:614/1695 train_time:59116ms step_avg:96.28ms
+step:615/1695 train_time:59211ms step_avg:96.28ms
+step:616/1695 train_time:59307ms step_avg:96.28ms
+step:617/1695 train_time:59404ms step_avg:96.28ms
+step:618/1695 train_time:59499ms step_avg:96.28ms
+step:619/1695 train_time:59595ms step_avg:96.28ms
+step:620/1695 train_time:59691ms step_avg:96.28ms
+step:621/1695 train_time:59788ms step_avg:96.28ms
+step:622/1695 train_time:59884ms step_avg:96.28ms
+step:623/1695 train_time:59982ms step_avg:96.28ms
+step:624/1695 train_time:60078ms step_avg:96.28ms
+step:625/1695 train_time:60173ms step_avg:96.28ms
+step:625/1695 val_loss:3.6195 train_time:60266ms step_avg:96.43ms
+step:626/1695 train_time:60290ms step_avg:96.31ms
+step:627/1695 train_time:60370ms step_avg:96.28ms
+step:628/1695 train_time:60467ms step_avg:96.29ms
+step:629/1695 train_time:60563ms step_avg:96.28ms
+step:630/1695 train_time:60658ms step_avg:96.28ms
+step:631/1695 train_time:60753ms step_avg:96.28ms
+step:632/1695 train_time:60847ms step_avg:96.28ms
+step:633/1695 train_time:60943ms step_avg:96.28ms
+step:634/1695 train_time:61038ms step_avg:96.27ms
+step:635/1695 train_time:61133ms step_avg:96.27ms
+step:636/1695 train_time:61232ms step_avg:96.28ms
+step:637/1695 train_time:61331ms step_avg:96.28ms
+step:638/1695 train_time:61426ms step_avg:96.28ms
+step:639/1695 train_time:61522ms step_avg:96.28ms
+step:640/1695 train_time:61617ms step_avg:96.28ms
+step:641/1695 train_time:61712ms step_avg:96.28ms
+step:642/1695 train_time:61808ms step_avg:96.27ms
+step:643/1695 train_time:61902ms step_avg:96.27ms
+step:644/1695 train_time:61997ms step_avg:96.27ms
+step:645/1695 train_time:62092ms step_avg:96.27ms
+step:646/1695 train_time:62188ms step_avg:96.27ms
+step:647/1695 train_time:62284ms step_avg:96.27ms
+step:648/1695 train_time:62382ms step_avg:96.27ms
+step:649/1695 train_time:62478ms step_avg:96.27ms
+step:650/1695 train_time:62575ms step_avg:96.27ms
+step:651/1695 train_time:62671ms step_avg:96.27ms
+step:652/1695 train_time:62767ms step_avg:96.27ms
+step:653/1695 train_time:62862ms step_avg:96.27ms
+step:654/1695 train_time:62957ms step_avg:96.27ms
+step:655/1695 train_time:63053ms step_avg:96.26ms
+step:656/1695 train_time:63148ms step_avg:96.26ms
+step:657/1695 train_time:63245ms step_avg:96.26ms
+step:658/1695 train_time:63341ms step_avg:96.26ms
+step:659/1695 train_time:63438ms step_avg:96.26ms
+step:660/1695 train_time:63535ms step_avg:96.27ms
+step:661/1695 train_time:63632ms step_avg:96.27ms
+step:662/1695 train_time:63727ms step_avg:96.26ms
+step:663/1695 train_time:63822ms step_avg:96.26ms
+step:664/1695 train_time:63918ms step_avg:96.26ms
+step:665/1695 train_time:64014ms step_avg:96.26ms
+step:666/1695 train_time:64111ms step_avg:96.26ms
+step:667/1695 train_time:64207ms step_avg:96.26ms
+step:668/1695 train_time:64303ms step_avg:96.26ms
+step:669/1695 train_time:64400ms step_avg:96.26ms
+step:670/1695 train_time:64497ms step_avg:96.26ms
+step:671/1695 train_time:64594ms step_avg:96.27ms
+step:672/1695 train_time:64691ms step_avg:96.27ms
+step:673/1695 train_time:64786ms step_avg:96.26ms
+step:674/1695 train_time:64881ms step_avg:96.26ms
+step:675/1695 train_time:64977ms step_avg:96.26ms
+step:676/1695 train_time:65073ms step_avg:96.26ms
+step:677/1695 train_time:65170ms step_avg:96.26ms
+step:678/1695 train_time:65266ms step_avg:96.26ms
+step:679/1695 train_time:65361ms step_avg:96.26ms
+step:680/1695 train_time:65458ms step_avg:96.26ms
+step:681/1695 train_time:65554ms step_avg:96.26ms
+step:682/1695 train_time:65651ms step_avg:96.26ms
+step:683/1695 train_time:65747ms step_avg:96.26ms
+step:684/1695 train_time:65843ms step_avg:96.26ms
+step:685/1695 train_time:65939ms step_avg:96.26ms
+step:686/1695 train_time:66034ms step_avg:96.26ms
+step:687/1695 train_time:66129ms step_avg:96.26ms
+step:688/1695 train_time:66225ms step_avg:96.26ms
+step:689/1695 train_time:66320ms step_avg:96.26ms
+step:690/1695 train_time:66417ms step_avg:96.26ms
+step:691/1695 train_time:66874ms step_avg:96.78ms
+step:692/1695 train_time:66944ms step_avg:96.74ms
+step:693/1695 train_time:67039ms step_avg:96.74ms
+step:694/1695 train_time:67134ms step_avg:96.74ms
+step:695/1695 train_time:67229ms step_avg:96.73ms
+step:696/1695 train_time:67323ms step_avg:96.73ms
+step:697/1695 train_time:67418ms step_avg:96.73ms
+step:698/1695 train_time:67513ms step_avg:96.72ms
+step:699/1695 train_time:67607ms step_avg:96.72ms
+step:700/1695 train_time:67702ms step_avg:96.72ms
+step:701/1695 train_time:67802ms step_avg:96.72ms
+step:702/1695 train_time:67902ms step_avg:96.73ms
+step:703/1695 train_time:68000ms step_avg:96.73ms
+step:704/1695 train_time:68096ms step_avg:96.73ms
+step:705/1695 train_time:68192ms step_avg:96.73ms
+step:706/1695 train_time:68287ms step_avg:96.72ms
+step:707/1695 train_time:68382ms step_avg:96.72ms
+step:708/1695 train_time:68478ms step_avg:96.72ms
+step:709/1695 train_time:68574ms step_avg:96.72ms
+step:710/1695 train_time:68670ms step_avg:96.72ms
+step:711/1695 train_time:68765ms step_avg:96.72ms
+step:712/1695 train_time:68862ms step_avg:96.72ms
+step:713/1695 train_time:68959ms step_avg:96.72ms
+step:714/1695 train_time:69056ms step_avg:96.72ms
+step:715/1695 train_time:69153ms step_avg:96.72ms
+step:716/1695 train_time:69250ms step_avg:96.72ms
+step:717/1695 train_time:69344ms step_avg:96.71ms
+step:718/1695 train_time:69439ms step_avg:96.71ms
+step:719/1695 train_time:69535ms step_avg:96.71ms
+step:720/1695 train_time:69631ms step_avg:96.71ms
+step:721/1695 train_time:69726ms step_avg:96.71ms
+step:722/1695 train_time:69822ms step_avg:96.71ms
+step:723/1695 train_time:69918ms step_avg:96.71ms
+step:724/1695 train_time:70015ms step_avg:96.71ms
+step:725/1695 train_time:70113ms step_avg:96.71ms
+step:726/1695 train_time:70210ms step_avg:96.71ms
+step:727/1695 train_time:70305ms step_avg:96.71ms
+step:728/1695 train_time:70400ms step_avg:96.70ms
+step:729/1695 train_time:70496ms step_avg:96.70ms
+step:730/1695 train_time:70591ms step_avg:96.70ms
+step:731/1695 train_time:70686ms step_avg:96.70ms
+step:732/1695 train_time:70782ms step_avg:96.70ms
+step:733/1695 train_time:70878ms step_avg:96.70ms
+step:734/1695 train_time:70974ms step_avg:96.69ms
+step:735/1695 train_time:71069ms step_avg:96.69ms
+step:736/1695 train_time:71166ms step_avg:96.69ms
+step:737/1695 train_time:71262ms step_avg:96.69ms
+step:738/1695 train_time:71358ms step_avg:96.69ms
+step:739/1695 train_time:71454ms step_avg:96.69ms
+step:740/1695 train_time:71549ms step_avg:96.69ms
+step:741/1695 train_time:71644ms step_avg:96.69ms
+step:742/1695 train_time:71739ms step_avg:96.68ms
+step:743/1695 train_time:71836ms step_avg:96.68ms
+step:744/1695 train_time:71931ms step_avg:96.68ms
+step:745/1695 train_time:72027ms step_avg:96.68ms
+step:746/1695 train_time:72123ms step_avg:96.68ms
+step:747/1695 train_time:72220ms step_avg:96.68ms
+step:748/1695 train_time:72317ms step_avg:96.68ms
+step:749/1695 train_time:72413ms step_avg:96.68ms
+step:750/1695 train_time:72510ms step_avg:96.68ms
+step:750/1695 val_loss:3.5686 train_time:72604ms step_avg:96.81ms
+step:751/1695 train_time:72630ms step_avg:96.71ms
+step:752/1695 train_time:72710ms step_avg:96.69ms
+step:753/1695 train_time:72807ms step_avg:96.69ms
+step:754/1695 train_time:72902ms step_avg:96.69ms
+step:755/1695 train_time:72998ms step_avg:96.69ms
+step:756/1695 train_time:73092ms step_avg:96.68ms
+step:757/1695 train_time:73186ms step_avg:96.68ms
+step:758/1695 train_time:73281ms step_avg:96.68ms
+step:759/1695 train_time:73376ms step_avg:96.67ms
+step:760/1695 train_time:73470ms step_avg:96.67ms
+step:761/1695 train_time:73566ms step_avg:96.67ms
+step:762/1695 train_time:73665ms step_avg:96.67ms
+step:763/1695 train_time:73763ms step_avg:96.67ms
+step:764/1695 train_time:73859ms step_avg:96.67ms
+step:765/1695 train_time:73955ms step_avg:96.67ms
+step:766/1695 train_time:74051ms step_avg:96.67ms
+step:767/1695 train_time:74146ms step_avg:96.67ms
+step:768/1695 train_time:74241ms step_avg:96.67ms
+step:769/1695 train_time:74336ms step_avg:96.67ms
+step:770/1695 train_time:74430ms step_avg:96.66ms
+step:771/1695 train_time:74525ms step_avg:96.66ms
+step:772/1695 train_time:74622ms step_avg:96.66ms
+step:773/1695 train_time:74719ms step_avg:96.66ms
+step:774/1695 train_time:74817ms step_avg:96.66ms
+step:775/1695 train_time:74913ms step_avg:96.66ms
+step:776/1695 train_time:75008ms step_avg:96.66ms
+step:777/1695 train_time:75104ms step_avg:96.66ms
+step:778/1695 train_time:75199ms step_avg:96.66ms
+step:779/1695 train_time:75294ms step_avg:96.65ms
+step:780/1695 train_time:75389ms step_avg:96.65ms
+step:781/1695 train_time:75484ms step_avg:96.65ms
+step:782/1695 train_time:75580ms step_avg:96.65ms
+step:783/1695 train_time:75677ms step_avg:96.65ms
+step:784/1695 train_time:75773ms step_avg:96.65ms
+step:785/1695 train_time:75870ms step_avg:96.65ms
+step:786/1695 train_time:75965ms step_avg:96.65ms
+step:787/1695 train_time:76061ms step_avg:96.65ms
+step:788/1695 train_time:76156ms step_avg:96.64ms
+step:789/1695 train_time:76251ms step_avg:96.64ms
+step:790/1695 train_time:76346ms step_avg:96.64ms
+step:791/1695 train_time:76441ms step_avg:96.64ms
+step:792/1695 train_time:76538ms step_avg:96.64ms
+step:793/1695 train_time:76634ms step_avg:96.64ms
+step:794/1695 train_time:76730ms step_avg:96.64ms
+step:795/1695 train_time:76825ms step_avg:96.64ms
+step:796/1695 train_time:76921ms step_avg:96.63ms
+step:797/1695 train_time:77017ms step_avg:96.63ms
+step:798/1695 train_time:77113ms step_avg:96.63ms
+step:799/1695 train_time:77207ms step_avg:96.63ms
+step:800/1695 train_time:77303ms step_avg:96.63ms
+step:801/1695 train_time:77398ms step_avg:96.63ms
+step:802/1695 train_time:77493ms step_avg:96.62ms
+step:803/1695 train_time:77588ms step_avg:96.62ms
+step:804/1695 train_time:77684ms step_avg:96.62ms
+step:805/1695 train_time:77782ms step_avg:96.62ms
+step:806/1695 train_time:77879ms step_avg:96.62ms
+step:807/1695 train_time:77976ms step_avg:96.62ms
+step:808/1695 train_time:78073ms step_avg:96.62ms
+step:809/1695 train_time:78168ms step_avg:96.62ms
+step:810/1695 train_time:78262ms step_avg:96.62ms
+step:811/1695 train_time:78357ms step_avg:96.62ms
+step:812/1695 train_time:78453ms step_avg:96.62ms
+step:813/1695 train_time:78547ms step_avg:96.61ms
+step:814/1695 train_time:78643ms step_avg:96.61ms
+step:815/1695 train_time:78739ms step_avg:96.61ms
+step:816/1695 train_time:78836ms step_avg:96.61ms
+step:817/1695 train_time:78932ms step_avg:96.61ms
+step:818/1695 train_time:79028ms step_avg:96.61ms
+step:819/1695 train_time:79124ms step_avg:96.61ms
+step:820/1695 train_time:79219ms step_avg:96.61ms
+step:821/1695 train_time:79315ms step_avg:96.61ms
+step:822/1695 train_time:79410ms step_avg:96.61ms
+step:823/1695 train_time:79505ms step_avg:96.60ms
+step:824/1695 train_time:79600ms step_avg:96.60ms
+step:825/1695 train_time:79696ms step_avg:96.60ms
+step:826/1695 train_time:79791ms step_avg:96.60ms
+step:827/1695 train_time:79887ms step_avg:96.60ms
+step:828/1695 train_time:79984ms step_avg:96.60ms
+step:829/1695 train_time:80081ms step_avg:96.60ms
+step:830/1695 train_time:80178ms step_avg:96.60ms
+step:831/1695 train_time:80273ms step_avg:96.60ms
+step:832/1695 train_time:80368ms step_avg:96.60ms
+step:833/1695 train_time:80463ms step_avg:96.59ms
+step:834/1695 train_time:80559ms step_avg:96.59ms
+step:835/1695 train_time:80656ms step_avg:96.59ms
+step:836/1695 train_time:80752ms step_avg:96.59ms
+step:837/1695 train_time:80847ms step_avg:96.59ms
+step:838/1695 train_time:80942ms step_avg:96.59ms
+step:839/1695 train_time:81038ms step_avg:96.59ms
+step:840/1695 train_time:81134ms step_avg:96.59ms
+step:841/1695 train_time:81229ms step_avg:96.59ms
+step:842/1695 train_time:81324ms step_avg:96.58ms
+step:843/1695 train_time:81420ms step_avg:96.58ms
+step:844/1695 train_time:81516ms step_avg:96.58ms
+step:845/1695 train_time:81612ms step_avg:96.58ms
+step:846/1695 train_time:81707ms step_avg:96.58ms
+step:847/1695 train_time:81802ms step_avg:96.58ms
+step:848/1695 train_time:81899ms step_avg:96.58ms
+step:849/1695 train_time:81994ms step_avg:96.58ms
+step:850/1695 train_time:82090ms step_avg:96.58ms
+step:851/1695 train_time:82186ms step_avg:96.58ms
+step:852/1695 train_time:82281ms step_avg:96.57ms
+step:853/1695 train_time:82377ms step_avg:96.57ms
+step:854/1695 train_time:82473ms step_avg:96.57ms
+step:855/1695 train_time:82569ms step_avg:96.57ms
+step:856/1695 train_time:82664ms step_avg:96.57ms
+step:857/1695 train_time:82760ms step_avg:96.57ms
+step:858/1695 train_time:82855ms step_avg:96.57ms
+step:859/1695 train_time:82952ms step_avg:96.57ms
+step:860/1695 train_time:83047ms step_avg:96.57ms
+step:861/1695 train_time:83143ms step_avg:96.57ms
+step:862/1695 train_time:83239ms step_avg:96.56ms
+step:863/1695 train_time:83566ms step_avg:96.83ms
+step:864/1695 train_time:83759ms step_avg:96.94ms
+step:865/1695 train_time:83853ms step_avg:96.94ms
+step:866/1695 train_time:83948ms step_avg:96.94ms
+step:867/1695 train_time:84042ms step_avg:96.93ms
+step:868/1695 train_time:84138ms step_avg:96.93ms
+step:869/1695 train_time:84233ms step_avg:96.93ms
+step:870/1695 train_time:84327ms step_avg:96.93ms
+step:871/1695 train_time:84421ms step_avg:96.92ms
+step:872/1695 train_time:84516ms step_avg:96.92ms
+step:873/1695 train_time:84616ms step_avg:96.93ms
+step:874/1695 train_time:84714ms step_avg:96.93ms
+step:875/1695 train_time:84811ms step_avg:96.93ms
+step:875/1695 val_loss:3.5270 train_time:84905ms step_avg:97.03ms
+step:876/1695 train_time:84930ms step_avg:96.95ms
+step:877/1695 train_time:85012ms step_avg:96.94ms
+step:878/1695 train_time:85111ms step_avg:96.94ms
+step:879/1695 train_time:85209ms step_avg:96.94ms
+step:880/1695 train_time:85304ms step_avg:96.94ms
+step:881/1695 train_time:85400ms step_avg:96.94ms
+step:882/1695 train_time:85494ms step_avg:96.93ms
+step:883/1695 train_time:85589ms step_avg:96.93ms
+step:884/1695 train_time:85685ms step_avg:96.93ms
+step:885/1695 train_time:85780ms step_avg:96.93ms
+step:886/1695 train_time:85876ms step_avg:96.93ms
+step:887/1695 train_time:85973ms step_avg:96.93ms
+step:888/1695 train_time:86071ms step_avg:96.93ms
+step:889/1695 train_time:86170ms step_avg:96.93ms
+step:890/1695 train_time:86267ms step_avg:96.93ms
+step:891/1695 train_time:86364ms step_avg:96.93ms
+step:892/1695 train_time:86459ms step_avg:96.93ms
+step:893/1695 train_time:86554ms step_avg:96.92ms
+step:894/1695 train_time:86649ms step_avg:96.92ms
+step:895/1695 train_time:86745ms step_avg:96.92ms
+step:896/1695 train_time:86842ms step_avg:96.92ms
+step:897/1695 train_time:86938ms step_avg:96.92ms
+step:898/1695 train_time:87034ms step_avg:96.92ms
+step:899/1695 train_time:87131ms step_avg:96.92ms
+step:900/1695 train_time:87229ms step_avg:96.92ms
+step:901/1695 train_time:87325ms step_avg:96.92ms
+step:902/1695 train_time:87422ms step_avg:96.92ms
+step:903/1695 train_time:87517ms step_avg:96.92ms
+step:904/1695 train_time:87612ms step_avg:96.92ms
+step:905/1695 train_time:87708ms step_avg:96.91ms
+step:906/1695 train_time:87804ms step_avg:96.91ms
+step:907/1695 train_time:87901ms step_avg:96.91ms
+step:908/1695 train_time:87997ms step_avg:96.91ms
+step:909/1695 train_time:88092ms step_avg:96.91ms
+step:910/1695 train_time:88188ms step_avg:96.91ms
+step:911/1695 train_time:88285ms step_avg:96.91ms
+step:912/1695 train_time:88381ms step_avg:96.91ms
+step:913/1695 train_time:88477ms step_avg:96.91ms
+step:914/1695 train_time:88572ms step_avg:96.91ms
+step:915/1695 train_time:88667ms step_avg:96.90ms
+step:916/1695 train_time:88763ms step_avg:96.90ms
+step:917/1695 train_time:88860ms step_avg:96.90ms
+step:918/1695 train_time:88955ms step_avg:96.90ms
+step:919/1695 train_time:89050ms step_avg:96.90ms
+step:920/1695 train_time:89147ms step_avg:96.90ms
+step:921/1695 train_time:89245ms step_avg:96.90ms
+step:922/1695 train_time:89341ms step_avg:96.90ms
+step:923/1695 train_time:89436ms step_avg:96.90ms
+step:924/1695 train_time:89531ms step_avg:96.90ms
+step:925/1695 train_time:89627ms step_avg:96.89ms
+step:926/1695 train_time:89724ms step_avg:96.89ms
+step:927/1695 train_time:89821ms step_avg:96.89ms
+step:928/1695 train_time:89916ms step_avg:96.89ms
+step:929/1695 train_time:90011ms step_avg:96.89ms
+step:930/1695 train_time:90108ms step_avg:96.89ms
+step:931/1695 train_time:90205ms step_avg:96.89ms
+step:932/1695 train_time:90302ms step_avg:96.89ms
+step:933/1695 train_time:90398ms step_avg:96.89ms
+step:934/1695 train_time:90493ms step_avg:96.89ms
+step:935/1695 train_time:90589ms step_avg:96.89ms
+step:936/1695 train_time:90685ms step_avg:96.89ms
+step:937/1695 train_time:90782ms step_avg:96.89ms
+step:938/1695 train_time:90877ms step_avg:96.88ms
+step:939/1695 train_time:90973ms step_avg:96.88ms
+step:940/1695 train_time:91069ms step_avg:96.88ms
+step:941/1695 train_time:91167ms step_avg:96.88ms
+step:942/1695 train_time:91264ms step_avg:96.88ms
+step:943/1695 train_time:91361ms step_avg:96.88ms
+step:944/1695 train_time:91456ms step_avg:96.88ms
+step:945/1695 train_time:91552ms step_avg:96.88ms
+step:946/1695 train_time:91648ms step_avg:96.88ms
+step:947/1695 train_time:91744ms step_avg:96.88ms
+step:948/1695 train_time:91840ms step_avg:96.88ms
+step:949/1695 train_time:91935ms step_avg:96.88ms
+step:950/1695 train_time:92031ms step_avg:96.87ms
+step:951/1695 train_time:92127ms step_avg:96.87ms
+step:952/1695 train_time:92223ms step_avg:96.87ms
+step:953/1695 train_time:92319ms step_avg:96.87ms
+step:954/1695 train_time:92415ms step_avg:96.87ms
+step:955/1695 train_time:92510ms step_avg:96.87ms
+step:956/1695 train_time:92606ms step_avg:96.87ms
+step:957/1695 train_time:92702ms step_avg:96.87ms
+step:958/1695 train_time:92798ms step_avg:96.87ms
+step:959/1695 train_time:92894ms step_avg:96.87ms
+step:960/1695 train_time:92990ms step_avg:96.86ms
+step:961/1695 train_time:93087ms step_avg:96.86ms
+step:962/1695 train_time:93183ms step_avg:96.86ms
+step:963/1695 train_time:93279ms step_avg:96.86ms
+step:964/1695 train_time:93374ms step_avg:96.86ms
+step:965/1695 train_time:93471ms step_avg:96.86ms
+step:966/1695 train_time:93568ms step_avg:96.86ms
+step:967/1695 train_time:93664ms step_avg:96.86ms
+step:968/1695 train_time:93760ms step_avg:96.86ms
+step:969/1695 train_time:93856ms step_avg:96.86ms
+step:970/1695 train_time:93951ms step_avg:96.86ms
+step:971/1695 train_time:94048ms step_avg:96.86ms
+step:972/1695 train_time:94144ms step_avg:96.86ms
+step:973/1695 train_time:94241ms step_avg:96.86ms
+step:974/1695 train_time:94336ms step_avg:96.85ms
+step:975/1695 train_time:94431ms step_avg:96.85ms
+step:976/1695 train_time:94528ms step_avg:96.85ms
+step:977/1695 train_time:94625ms step_avg:96.85ms
+step:978/1695 train_time:94722ms step_avg:96.85ms
+step:979/1695 train_time:94818ms step_avg:96.85ms
+step:980/1695 train_time:94913ms step_avg:96.85ms
+step:981/1695 train_time:95010ms step_avg:96.85ms
+step:982/1695 train_time:95106ms step_avg:96.85ms
+step:983/1695 train_time:95202ms step_avg:96.85ms
+step:984/1695 train_time:95298ms step_avg:96.85ms
+step:985/1695 train_time:95393ms step_avg:96.85ms
+step:986/1695 train_time:95489ms step_avg:96.84ms
+step:987/1695 train_time:95585ms step_avg:96.84ms
+step:988/1695 train_time:95682ms step_avg:96.84ms
+step:989/1695 train_time:95777ms step_avg:96.84ms
+step:990/1695 train_time:95872ms step_avg:96.84ms
+step:991/1695 train_time:95967ms step_avg:96.84ms
+step:992/1695 train_time:96064ms step_avg:96.84ms
+step:993/1695 train_time:96160ms step_avg:96.84ms
+step:994/1695 train_time:96255ms step_avg:96.84ms
+step:995/1695 train_time:96351ms step_avg:96.83ms
+step:996/1695 train_time:96446ms step_avg:96.83ms
+step:997/1695 train_time:96543ms step_avg:96.83ms
+step:998/1695 train_time:96638ms step_avg:96.83ms
+step:999/1695 train_time:96734ms step_avg:96.83ms
+step:1000/1695 train_time:96830ms step_avg:96.83ms
+step:1000/1695 val_loss:3.4844 train_time:96924ms step_avg:96.92ms
+step:1001/1695 train_time:96949ms step_avg:96.85ms
+step:1002/1695 train_time:97032ms step_avg:96.84ms
+step:1003/1695 train_time:97130ms step_avg:96.84ms
+step:1004/1695 train_time:97226ms step_avg:96.84ms
+step:1005/1695 train_time:97322ms step_avg:96.84ms
+step:1006/1695 train_time:97417ms step_avg:96.84ms
+step:1007/1695 train_time:97512ms step_avg:96.83ms
+step:1008/1695 train_time:97606ms step_avg:96.83ms
+step:1009/1695 train_time:97702ms step_avg:96.83ms
+step:1010/1695 train_time:97797ms step_avg:96.83ms
+step:1011/1695 train_time:97893ms step_avg:96.83ms
+step:1012/1695 train_time:97991ms step_avg:96.83ms
+step:1013/1695 train_time:98089ms step_avg:96.83ms
+step:1014/1695 train_time:98187ms step_avg:96.83ms
+step:1015/1695 train_time:98284ms step_avg:96.83ms
+step:1016/1695 train_time:98379ms step_avg:96.83ms
+step:1017/1695 train_time:98474ms step_avg:96.83ms
+step:1018/1695 train_time:98569ms step_avg:96.83ms
+step:1019/1695 train_time:98665ms step_avg:96.83ms
+step:1020/1695 train_time:98761ms step_avg:96.82ms
+step:1021/1695 train_time:98857ms step_avg:96.82ms
+step:1022/1695 train_time:98954ms step_avg:96.82ms
+step:1023/1695 train_time:99050ms step_avg:96.82ms
+step:1024/1695 train_time:99147ms step_avg:96.82ms
+step:1025/1695 train_time:99243ms step_avg:96.82ms
+step:1026/1695 train_time:99340ms step_avg:96.82ms
+step:1027/1695 train_time:99436ms step_avg:96.82ms
+step:1028/1695 train_time:99530ms step_avg:96.82ms
+step:1029/1695 train_time:99625ms step_avg:96.82ms
+step:1030/1695 train_time:99721ms step_avg:96.82ms
+step:1031/1695 train_time:99818ms step_avg:96.82ms
+step:1032/1695 train_time:99913ms step_avg:96.82ms
+step:1033/1695 train_time:100009ms step_avg:96.81ms
+step:1034/1695 train_time:100107ms step_avg:96.82ms
+step:1035/1695 train_time:100204ms step_avg:96.82ms
+step:1036/1695 train_time:100552ms step_avg:97.06ms
+step:1037/1695 train_time:100724ms step_avg:97.13ms
+step:1038/1695 train_time:100817ms step_avg:97.13ms
+step:1039/1695 train_time:100912ms step_avg:97.12ms
+step:1040/1695 train_time:101007ms step_avg:97.12ms
+step:1041/1695 train_time:101101ms step_avg:97.12ms
+step:1042/1695 train_time:101196ms step_avg:97.12ms
+step:1043/1695 train_time:101290ms step_avg:97.11ms
+step:1044/1695 train_time:101385ms step_avg:97.11ms
+step:1045/1695 train_time:101480ms step_avg:97.11ms
+step:1046/1695 train_time:101579ms step_avg:97.11ms
+step:1047/1695 train_time:101680ms step_avg:97.12ms
+step:1048/1695 train_time:101777ms step_avg:97.12ms
+step:1049/1695 train_time:101873ms step_avg:97.11ms
+step:1050/1695 train_time:101968ms step_avg:97.11ms
+step:1051/1695 train_time:102063ms step_avg:97.11ms
+step:1052/1695 train_time:102159ms step_avg:97.11ms
+step:1053/1695 train_time:102253ms step_avg:97.11ms
+step:1054/1695 train_time:102348ms step_avg:97.10ms
+step:1055/1695 train_time:102443ms step_avg:97.10ms
+step:1056/1695 train_time:102540ms step_avg:97.10ms
+step:1057/1695 train_time:102637ms step_avg:97.10ms
+step:1058/1695 train_time:102733ms step_avg:97.10ms
+step:1059/1695 train_time:102829ms step_avg:97.10ms
+step:1060/1695 train_time:102925ms step_avg:97.10ms
+step:1061/1695 train_time:103021ms step_avg:97.10ms
+step:1062/1695 train_time:103117ms step_avg:97.10ms
+step:1063/1695 train_time:103212ms step_avg:97.09ms
+step:1064/1695 train_time:103307ms step_avg:97.09ms
+step:1065/1695 train_time:103402ms step_avg:97.09ms
+step:1066/1695 train_time:103499ms step_avg:97.09ms
+step:1067/1695 train_time:103595ms step_avg:97.09ms
+step:1068/1695 train_time:103691ms step_avg:97.09ms
+step:1069/1695 train_time:103789ms step_avg:97.09ms
+step:1070/1695 train_time:103886ms step_avg:97.09ms
+step:1071/1695 train_time:103983ms step_avg:97.09ms
+step:1072/1695 train_time:104080ms step_avg:97.09ms
+step:1073/1695 train_time:104176ms step_avg:97.09ms
+step:1074/1695 train_time:104270ms step_avg:97.09ms
+step:1075/1695 train_time:104366ms step_avg:97.08ms
+step:1076/1695 train_time:104461ms step_avg:97.08ms
+step:1077/1695 train_time:104558ms step_avg:97.08ms
+step:1078/1695 train_time:104654ms step_avg:97.08ms
+step:1079/1695 train_time:104750ms step_avg:97.08ms
+step:1080/1695 train_time:104847ms step_avg:97.08ms
+step:1081/1695 train_time:104946ms step_avg:97.08ms
+step:1082/1695 train_time:105043ms step_avg:97.08ms
+step:1083/1695 train_time:105140ms step_avg:97.08ms
+step:1084/1695 train_time:105235ms step_avg:97.08ms
+step:1085/1695 train_time:105330ms step_avg:97.08ms
+step:1086/1695 train_time:105426ms step_avg:97.08ms
+step:1087/1695 train_time:105523ms step_avg:97.08ms
+step:1088/1695 train_time:105619ms step_avg:97.08ms
+step:1089/1695 train_time:105715ms step_avg:97.08ms
+step:1090/1695 train_time:105811ms step_avg:97.07ms
+step:1091/1695 train_time:105907ms step_avg:97.07ms
+step:1092/1695 train_time:106003ms step_avg:97.07ms
+step:1093/1695 train_time:106099ms step_avg:97.07ms
+step:1094/1695 train_time:106194ms step_avg:97.07ms
+step:1095/1695 train_time:106289ms step_avg:97.07ms
+step:1096/1695 train_time:106386ms step_avg:97.07ms
+step:1097/1695 train_time:106482ms step_avg:97.07ms
+step:1098/1695 train_time:106578ms step_avg:97.07ms
+step:1099/1695 train_time:106673ms step_avg:97.06ms
+step:1100/1695 train_time:106769ms step_avg:97.06ms
+step:1101/1695 train_time:106866ms step_avg:97.06ms
+step:1102/1695 train_time:106963ms step_avg:97.06ms
+step:1103/1695 train_time:107060ms step_avg:97.06ms
+step:1104/1695 train_time:107156ms step_avg:97.06ms
+step:1105/1695 train_time:107251ms step_avg:97.06ms
+step:1106/1695 train_time:107347ms step_avg:97.06ms
+step:1107/1695 train_time:107443ms step_avg:97.06ms
+step:1108/1695 train_time:107540ms step_avg:97.06ms
+step:1109/1695 train_time:107636ms step_avg:97.06ms
+step:1110/1695 train_time:107730ms step_avg:97.05ms
+step:1111/1695 train_time:107826ms step_avg:97.05ms
+step:1112/1695 train_time:107922ms step_avg:97.05ms
+step:1113/1695 train_time:108019ms step_avg:97.05ms
+step:1114/1695 train_time:108115ms step_avg:97.05ms
+step:1115/1695 train_time:108211ms step_avg:97.05ms
+step:1116/1695 train_time:108306ms step_avg:97.05ms
+step:1117/1695 train_time:108403ms step_avg:97.05ms
+step:1118/1695 train_time:108499ms step_avg:97.05ms
+step:1119/1695 train_time:108594ms step_avg:97.05ms
+step:1120/1695 train_time:108690ms step_avg:97.04ms
+step:1121/1695 train_time:108786ms step_avg:97.04ms
+step:1122/1695 train_time:108883ms step_avg:97.04ms
+step:1123/1695 train_time:108979ms step_avg:97.04ms
+step:1124/1695 train_time:109075ms step_avg:97.04ms
+step:1125/1695 train_time:109170ms step_avg:97.04ms
+step:1125/1695 val_loss:3.4368 train_time:109264ms step_avg:97.12ms
+step:1126/1695 train_time:109288ms step_avg:97.06ms
+step:1127/1695 train_time:109371ms step_avg:97.05ms
+step:1128/1695 train_time:109469ms step_avg:97.05ms
+step:1129/1695 train_time:109566ms step_avg:97.05ms
+step:1130/1695 train_time:109662ms step_avg:97.05ms
+step:1131/1695 train_time:109757ms step_avg:97.04ms
+step:1132/1695 train_time:109852ms step_avg:97.04ms
+step:1133/1695 train_time:109950ms step_avg:97.04ms
+step:1134/1695 train_time:110047ms step_avg:97.04ms
+step:1135/1695 train_time:110144ms step_avg:97.04ms
+step:1136/1695 train_time:110243ms step_avg:97.04ms
+step:1137/1695 train_time:110343ms step_avg:97.05ms
+step:1138/1695 train_time:110442ms step_avg:97.05ms
+step:1139/1695 train_time:110539ms step_avg:97.05ms
+step:1140/1695 train_time:110636ms step_avg:97.05ms
+step:1141/1695 train_time:110733ms step_avg:97.05ms
+step:1142/1695 train_time:110829ms step_avg:97.05ms
+step:1143/1695 train_time:110927ms step_avg:97.05ms
+step:1144/1695 train_time:111024ms step_avg:97.05ms
+step:1145/1695 train_time:111121ms step_avg:97.05ms
+step:1146/1695 train_time:111220ms step_avg:97.05ms
+step:1147/1695 train_time:111319ms step_avg:97.05ms
+step:1148/1695 train_time:111417ms step_avg:97.05ms
+step:1149/1695 train_time:111515ms step_avg:97.05ms
+step:1150/1695 train_time:111613ms step_avg:97.05ms
+step:1151/1695 train_time:111710ms step_avg:97.05ms
+step:1152/1695 train_time:111807ms step_avg:97.05ms
+step:1153/1695 train_time:111903ms step_avg:97.05ms
+step:1154/1695 train_time:112000ms step_avg:97.05ms
+step:1155/1695 train_time:112096ms step_avg:97.05ms
+step:1156/1695 train_time:112194ms step_avg:97.05ms
+step:1157/1695 train_time:112292ms step_avg:97.05ms
+step:1158/1695 train_time:112392ms step_avg:97.06ms
+step:1159/1695 train_time:112491ms step_avg:97.06ms
+step:1160/1695 train_time:112592ms step_avg:97.06ms
+step:1161/1695 train_time:112691ms step_avg:97.06ms
+step:1162/1695 train_time:112789ms step_avg:97.06ms
+step:1163/1695 train_time:112885ms step_avg:97.06ms
+step:1164/1695 train_time:112983ms step_avg:97.06ms
+step:1165/1695 train_time:113080ms step_avg:97.06ms
+step:1166/1695 train_time:113177ms step_avg:97.06ms
+step:1167/1695 train_time:113274ms step_avg:97.06ms
+step:1168/1695 train_time:113372ms step_avg:97.06ms
+step:1169/1695 train_time:113471ms step_avg:97.07ms
+step:1170/1695 train_time:113571ms step_avg:97.07ms
+step:1171/1695 train_time:113670ms step_avg:97.07ms
+step:1172/1695 train_time:113769ms step_avg:97.07ms
+step:1173/1695 train_time:113866ms step_avg:97.07ms
+step:1174/1695 train_time:113964ms step_avg:97.07ms
+step:1175/1695 train_time:114063ms step_avg:97.08ms
+step:1176/1695 train_time:114161ms step_avg:97.08ms
+step:1177/1695 train_time:114259ms step_avg:97.08ms
+step:1178/1695 train_time:114356ms step_avg:97.08ms
+step:1179/1695 train_time:114453ms step_avg:97.08ms
+step:1180/1695 train_time:114551ms step_avg:97.08ms
+step:1181/1695 train_time:114649ms step_avg:97.08ms
+step:1182/1695 train_time:114746ms step_avg:97.08ms
+step:1183/1695 train_time:114844ms step_avg:97.08ms
+step:1184/1695 train_time:114942ms step_avg:97.08ms
+step:1185/1695 train_time:115039ms step_avg:97.08ms
+step:1186/1695 train_time:115136ms step_avg:97.08ms
+step:1187/1695 train_time:115233ms step_avg:97.08ms
+step:1188/1695 train_time:115331ms step_avg:97.08ms
+step:1189/1695 train_time:115429ms step_avg:97.08ms
+step:1190/1695 train_time:115527ms step_avg:97.08ms
+step:1191/1695 train_time:115625ms step_avg:97.08ms
+step:1192/1695 train_time:115723ms step_avg:97.08ms
+step:1193/1695 train_time:115821ms step_avg:97.08ms
+step:1194/1695 train_time:115918ms step_avg:97.08ms
+step:1195/1695 train_time:116016ms step_avg:97.08ms
+step:1196/1695 train_time:116114ms step_avg:97.09ms
+step:1197/1695 train_time:116212ms step_avg:97.09ms
+step:1198/1695 train_time:116310ms step_avg:97.09ms
+step:1199/1695 train_time:116409ms step_avg:97.09ms
+step:1200/1695 train_time:116509ms step_avg:97.09ms
+step:1201/1695 train_time:116609ms step_avg:97.09ms
+step:1202/1695 train_time:116708ms step_avg:97.09ms
+step:1203/1695 train_time:116808ms step_avg:97.10ms
+step:1204/1695 train_time:116906ms step_avg:97.10ms
+step:1205/1695 train_time:117005ms step_avg:97.10ms
+step:1206/1695 train_time:117103ms step_avg:97.10ms
+step:1207/1695 train_time:117201ms step_avg:97.10ms
+step:1208/1695 train_time:117548ms step_avg:97.31ms
+step:1209/1695 train_time:117728ms step_avg:97.38ms
+step:1210/1695 train_time:117823ms step_avg:97.37ms
+step:1211/1695 train_time:117920ms step_avg:97.37ms
+step:1212/1695 train_time:118016ms step_avg:97.37ms
+step:1213/1695 train_time:118112ms step_avg:97.37ms
+step:1214/1695 train_time:118209ms step_avg:97.37ms
+step:1215/1695 train_time:118306ms step_avg:97.37ms
+step:1216/1695 train_time:118402ms step_avg:97.37ms
+step:1217/1695 train_time:118500ms step_avg:97.37ms
+step:1218/1695 train_time:118604ms step_avg:97.38ms
+step:1219/1695 train_time:118704ms step_avg:97.38ms
+step:1220/1695 train_time:118801ms step_avg:97.38ms
+step:1221/1695 train_time:118897ms step_avg:97.38ms
+step:1222/1695 train_time:118994ms step_avg:97.38ms
+step:1223/1695 train_time:119090ms step_avg:97.38ms
+step:1224/1695 train_time:119187ms step_avg:97.38ms
+step:1225/1695 train_time:119285ms step_avg:97.38ms
+step:1226/1695 train_time:119382ms step_avg:97.38ms
+step:1227/1695 train_time:119480ms step_avg:97.38ms
+step:1228/1695 train_time:119579ms step_avg:97.38ms
+step:1229/1695 train_time:119678ms step_avg:97.38ms
+step:1230/1695 train_time:119776ms step_avg:97.38ms
+step:1231/1695 train_time:119874ms step_avg:97.38ms
+step:1232/1695 train_time:119971ms step_avg:97.38ms
+step:1233/1695 train_time:120068ms step_avg:97.38ms
+step:1234/1695 train_time:120166ms step_avg:97.38ms
+step:1235/1695 train_time:120263ms step_avg:97.38ms
+step:1236/1695 train_time:120360ms step_avg:97.38ms
+step:1237/1695 train_time:120457ms step_avg:97.38ms
+step:1238/1695 train_time:120555ms step_avg:97.38ms
+step:1239/1695 train_time:120654ms step_avg:97.38ms
+step:1240/1695 train_time:120752ms step_avg:97.38ms
+step:1241/1695 train_time:120851ms step_avg:97.38ms
+step:1242/1695 train_time:120950ms step_avg:97.38ms
+step:1243/1695 train_time:121048ms step_avg:97.38ms
+step:1244/1695 train_time:121145ms step_avg:97.38ms
+step:1245/1695 train_time:121243ms step_avg:97.38ms
+step:1246/1695 train_time:121340ms step_avg:97.38ms
+step:1247/1695 train_time:121437ms step_avg:97.38ms
+step:1248/1695 train_time:121534ms step_avg:97.38ms
+step:1249/1695 train_time:121632ms step_avg:97.38ms
+step:1250/1695 train_time:121731ms step_avg:97.38ms
+step:1250/1695 val_loss:3.3897 train_time:121827ms step_avg:97.46ms
+step:1251/1695 train_time:121854ms step_avg:97.40ms
+step:1252/1695 train_time:121931ms step_avg:97.39ms
+step:1253/1695 train_time:122027ms step_avg:97.39ms
+step:1254/1695 train_time:122123ms step_avg:97.39ms
+step:1255/1695 train_time:122220ms step_avg:97.39ms
+step:1256/1695 train_time:122317ms step_avg:97.39ms
+step:1257/1695 train_time:122414ms step_avg:97.39ms
+step:1258/1695 train_time:122510ms step_avg:97.38ms
+step:1259/1695 train_time:122606ms step_avg:97.38ms
+step:1260/1695 train_time:122702ms step_avg:97.38ms
+step:1261/1695 train_time:122805ms step_avg:97.39ms
+step:1262/1695 train_time:122904ms step_avg:97.39ms
+step:1263/1695 train_time:123002ms step_avg:97.39ms
+step:1264/1695 train_time:123099ms step_avg:97.39ms
+step:1265/1695 train_time:123196ms step_avg:97.39ms
+step:1266/1695 train_time:123292ms step_avg:97.39ms
+step:1267/1695 train_time:123390ms step_avg:97.39ms
+step:1268/1695 train_time:123486ms step_avg:97.39ms
+step:1269/1695 train_time:123583ms step_avg:97.39ms
+step:1270/1695 train_time:123681ms step_avg:97.39ms
+step:1271/1695 train_time:123780ms step_avg:97.39ms
+step:1272/1695 train_time:123879ms step_avg:97.39ms
+step:1273/1695 train_time:123978ms step_avg:97.39ms
+step:1274/1695 train_time:124078ms step_avg:97.39ms
+step:1275/1695 train_time:124176ms step_avg:97.39ms
+step:1276/1695 train_time:124275ms step_avg:97.39ms
+step:1277/1695 train_time:124372ms step_avg:97.39ms
+step:1278/1695 train_time:124470ms step_avg:97.39ms
+step:1279/1695 train_time:124567ms step_avg:97.39ms
+step:1280/1695 train_time:124664ms step_avg:97.39ms
+step:1281/1695 train_time:124761ms step_avg:97.39ms
+step:1282/1695 train_time:124859ms step_avg:97.39ms
+step:1283/1695 train_time:124959ms step_avg:97.40ms
+step:1284/1695 train_time:125057ms step_avg:97.40ms
+step:1285/1695 train_time:125156ms step_avg:97.40ms
+step:1286/1695 train_time:125254ms step_avg:97.40ms
+step:1287/1695 train_time:125351ms step_avg:97.40ms
+step:1288/1695 train_time:125449ms step_avg:97.40ms
+step:1289/1695 train_time:125546ms step_avg:97.40ms
+step:1290/1695 train_time:125644ms step_avg:97.40ms
+step:1291/1695 train_time:125741ms step_avg:97.40ms
+step:1292/1695 train_time:125839ms step_avg:97.40ms
+step:1293/1695 train_time:125938ms step_avg:97.40ms
+step:1294/1695 train_time:126037ms step_avg:97.40ms
+step:1295/1695 train_time:126136ms step_avg:97.40ms
+step:1296/1695 train_time:126234ms step_avg:97.40ms
+step:1297/1695 train_time:126333ms step_avg:97.40ms
+step:1298/1695 train_time:126431ms step_avg:97.40ms
+step:1299/1695 train_time:126529ms step_avg:97.41ms
+step:1300/1695 train_time:126628ms step_avg:97.41ms
+step:1301/1695 train_time:126726ms step_avg:97.41ms
+step:1302/1695 train_time:126823ms step_avg:97.41ms
+step:1303/1695 train_time:126921ms step_avg:97.41ms
+step:1304/1695 train_time:127018ms step_avg:97.41ms
+step:1305/1695 train_time:127117ms step_avg:97.41ms
+step:1306/1695 train_time:127216ms step_avg:97.41ms
+step:1307/1695 train_time:127314ms step_avg:97.41ms
+step:1308/1695 train_time:127412ms step_avg:97.41ms
+step:1309/1695 train_time:127509ms step_avg:97.41ms
+step:1310/1695 train_time:127608ms step_avg:97.41ms
+step:1311/1695 train_time:127705ms step_avg:97.41ms
+step:1312/1695 train_time:127802ms step_avg:97.41ms
+step:1313/1695 train_time:127899ms step_avg:97.41ms
+step:1314/1695 train_time:127996ms step_avg:97.41ms
+step:1315/1695 train_time:128095ms step_avg:97.41ms
+step:1316/1695 train_time:128193ms step_avg:97.41ms
+step:1317/1695 train_time:128291ms step_avg:97.41ms
+step:1318/1695 train_time:128389ms step_avg:97.41ms
+step:1319/1695 train_time:128485ms step_avg:97.41ms
+step:1320/1695 train_time:128582ms step_avg:97.41ms
+step:1321/1695 train_time:128680ms step_avg:97.41ms
+step:1322/1695 train_time:128778ms step_avg:97.41ms
+step:1323/1695 train_time:128876ms step_avg:97.41ms
+step:1324/1695 train_time:128974ms step_avg:97.41ms
+step:1325/1695 train_time:129072ms step_avg:97.41ms
+step:1326/1695 train_time:129170ms step_avg:97.41ms
+step:1327/1695 train_time:129268ms step_avg:97.41ms
+step:1328/1695 train_time:129366ms step_avg:97.41ms
+step:1329/1695 train_time:129463ms step_avg:97.41ms
+step:1330/1695 train_time:129561ms step_avg:97.41ms
+step:1331/1695 train_time:129659ms step_avg:97.41ms
+step:1332/1695 train_time:129758ms step_avg:97.42ms
+step:1333/1695 train_time:129857ms step_avg:97.42ms
+step:1334/1695 train_time:129955ms step_avg:97.42ms
+step:1335/1695 train_time:130053ms step_avg:97.42ms
+step:1336/1695 train_time:130151ms step_avg:97.42ms
+step:1337/1695 train_time:130248ms step_avg:97.42ms
+step:1338/1695 train_time:130347ms step_avg:97.42ms
+step:1339/1695 train_time:130444ms step_avg:97.42ms
+step:1340/1695 train_time:130541ms step_avg:97.42ms
+step:1341/1695 train_time:130639ms step_avg:97.42ms
+step:1342/1695 train_time:130736ms step_avg:97.42ms
+step:1343/1695 train_time:130835ms step_avg:97.42ms
+step:1344/1695 train_time:130933ms step_avg:97.42ms
+step:1345/1695 train_time:131030ms step_avg:97.42ms
+step:1346/1695 train_time:131127ms step_avg:97.42ms
+step:1347/1695 train_time:131224ms step_avg:97.42ms
+step:1348/1695 train_time:131321ms step_avg:97.42ms
+step:1349/1695 train_time:131419ms step_avg:97.42ms
+step:1350/1695 train_time:131518ms step_avg:97.42ms
+step:1351/1695 train_time:131615ms step_avg:97.42ms
+step:1352/1695 train_time:131714ms step_avg:97.42ms
+step:1353/1695 train_time:131813ms step_avg:97.42ms
+step:1354/1695 train_time:131911ms step_avg:97.42ms
+step:1355/1695 train_time:132009ms step_avg:97.42ms
+step:1356/1695 train_time:132106ms step_avg:97.42ms
+step:1357/1695 train_time:132203ms step_avg:97.42ms
+step:1358/1695 train_time:132300ms step_avg:97.42ms
+step:1359/1695 train_time:132398ms step_avg:97.42ms
+step:1360/1695 train_time:132497ms step_avg:97.42ms
+step:1361/1695 train_time:132595ms step_avg:97.42ms
+step:1362/1695 train_time:132693ms step_avg:97.43ms
+step:1363/1695 train_time:132792ms step_avg:97.43ms
+step:1364/1695 train_time:132890ms step_avg:97.43ms
+step:1365/1695 train_time:132988ms step_avg:97.43ms
+step:1366/1695 train_time:133085ms step_avg:97.43ms
+step:1367/1695 train_time:133182ms step_avg:97.43ms
+step:1368/1695 train_time:133279ms step_avg:97.43ms
+step:1369/1695 train_time:133377ms step_avg:97.43ms
+step:1370/1695 train_time:133476ms step_avg:97.43ms
+step:1371/1695 train_time:133574ms step_avg:97.43ms
+step:1372/1695 train_time:133671ms step_avg:97.43ms
+step:1373/1695 train_time:133769ms step_avg:97.43ms
+step:1374/1695 train_time:133867ms step_avg:97.43ms
+step:1375/1695 train_time:133964ms step_avg:97.43ms
+step:1375/1695 val_loss:3.3507 train_time:134060ms step_avg:97.50ms
+step:1376/1695 train_time:134085ms step_avg:97.45ms
+step:1377/1695 train_time:134167ms step_avg:97.43ms
+step:1378/1695 train_time:134266ms step_avg:97.44ms
+step:1379/1695 train_time:134364ms step_avg:97.44ms
+step:1380/1695 train_time:134461ms step_avg:97.44ms
+step:1381/1695 train_time:134815ms step_avg:97.62ms
+step:1382/1695 train_time:134984ms step_avg:97.67ms
+step:1383/1695 train_time:135080ms step_avg:97.67ms
+step:1384/1695 train_time:135176ms step_avg:97.67ms
+step:1385/1695 train_time:135272ms step_avg:97.67ms
+step:1386/1695 train_time:135369ms step_avg:97.67ms
+step:1387/1695 train_time:135465ms step_avg:97.67ms
+step:1388/1695 train_time:135562ms step_avg:97.67ms
+step:1389/1695 train_time:135658ms step_avg:97.67ms
+step:1390/1695 train_time:135756ms step_avg:97.67ms
+step:1391/1695 train_time:135859ms step_avg:97.67ms
+step:1392/1695 train_time:135961ms step_avg:97.67ms
+step:1393/1695 train_time:136060ms step_avg:97.67ms
+step:1394/1695 train_time:136156ms step_avg:97.67ms
+step:1395/1695 train_time:136253ms step_avg:97.67ms
+step:1396/1695 train_time:136350ms step_avg:97.67ms
+step:1397/1695 train_time:136446ms step_avg:97.67ms
+step:1398/1695 train_time:136542ms step_avg:97.67ms
+step:1399/1695 train_time:136639ms step_avg:97.67ms
+step:1400/1695 train_time:136736ms step_avg:97.67ms
+step:1401/1695 train_time:136834ms step_avg:97.67ms
+step:1402/1695 train_time:136933ms step_avg:97.67ms
+step:1403/1695 train_time:137032ms step_avg:97.67ms
+step:1404/1695 train_time:137131ms step_avg:97.67ms
+step:1405/1695 train_time:137230ms step_avg:97.67ms
+step:1406/1695 train_time:137328ms step_avg:97.67ms
+step:1407/1695 train_time:137425ms step_avg:97.67ms
+step:1408/1695 train_time:137522ms step_avg:97.67ms
+step:1409/1695 train_time:137619ms step_avg:97.67ms
+step:1410/1695 train_time:137716ms step_avg:97.67ms
+step:1411/1695 train_time:137813ms step_avg:97.67ms
+step:1412/1695 train_time:137912ms step_avg:97.67ms
+step:1413/1695 train_time:138011ms step_avg:97.67ms
+step:1414/1695 train_time:138111ms step_avg:97.67ms
+step:1415/1695 train_time:138210ms step_avg:97.67ms
+step:1416/1695 train_time:138308ms step_avg:97.67ms
+step:1417/1695 train_time:138405ms step_avg:97.67ms
+step:1418/1695 train_time:138502ms step_avg:97.67ms
+step:1419/1695 train_time:138600ms step_avg:97.67ms
+step:1420/1695 train_time:138698ms step_avg:97.67ms
+step:1421/1695 train_time:138795ms step_avg:97.67ms
+step:1422/1695 train_time:138892ms step_avg:97.67ms
+step:1423/1695 train_time:138991ms step_avg:97.67ms
+step:1424/1695 train_time:139089ms step_avg:97.67ms
+step:1425/1695 train_time:139188ms step_avg:97.68ms
+step:1426/1695 train_time:139285ms step_avg:97.68ms
+step:1427/1695 train_time:139382ms step_avg:97.68ms
+step:1428/1695 train_time:139480ms step_avg:97.68ms
+step:1429/1695 train_time:139577ms step_avg:97.67ms
+step:1430/1695 train_time:139673ms step_avg:97.67ms
+step:1431/1695 train_time:139771ms step_avg:97.67ms
+step:1432/1695 train_time:139869ms step_avg:97.67ms
+step:1433/1695 train_time:139968ms step_avg:97.67ms
+step:1434/1695 train_time:140067ms step_avg:97.68ms
+step:1435/1695 train_time:140167ms step_avg:97.68ms
+step:1436/1695 train_time:140266ms step_avg:97.68ms
+step:1437/1695 train_time:140363ms step_avg:97.68ms
+step:1438/1695 train_time:140461ms step_avg:97.68ms
+step:1439/1695 train_time:140558ms step_avg:97.68ms
+step:1440/1695 train_time:140655ms step_avg:97.68ms
+step:1441/1695 train_time:140752ms step_avg:97.68ms
+step:1442/1695 train_time:140849ms step_avg:97.68ms
+step:1443/1695 train_time:140948ms step_avg:97.68ms
+step:1444/1695 train_time:141045ms step_avg:97.68ms
+step:1445/1695 train_time:141144ms step_avg:97.68ms
+step:1446/1695 train_time:141242ms step_avg:97.68ms
+step:1447/1695 train_time:141340ms step_avg:97.68ms
+step:1448/1695 train_time:141438ms step_avg:97.68ms
+step:1449/1695 train_time:141534ms step_avg:97.68ms
+step:1450/1695 train_time:141631ms step_avg:97.68ms
+step:1451/1695 train_time:141728ms step_avg:97.68ms
+step:1452/1695 train_time:141826ms step_avg:97.68ms
+step:1453/1695 train_time:141924ms step_avg:97.68ms
+step:1454/1695 train_time:142021ms step_avg:97.68ms
+step:1455/1695 train_time:142117ms step_avg:97.68ms
+step:1456/1695 train_time:142216ms step_avg:97.68ms
+step:1457/1695 train_time:142314ms step_avg:97.68ms
+step:1458/1695 train_time:142413ms step_avg:97.68ms
+step:1459/1695 train_time:142510ms step_avg:97.68ms
+step:1460/1695 train_time:142608ms step_avg:97.68ms
+step:1461/1695 train_time:142706ms step_avg:97.68ms
+step:1462/1695 train_time:142803ms step_avg:97.68ms
+step:1463/1695 train_time:142901ms step_avg:97.68ms
+step:1464/1695 train_time:142999ms step_avg:97.68ms
+step:1465/1695 train_time:143096ms step_avg:97.68ms
+step:1466/1695 train_time:143194ms step_avg:97.68ms
+step:1467/1695 train_time:143291ms step_avg:97.68ms
+step:1468/1695 train_time:143389ms step_avg:97.68ms
+step:1469/1695 train_time:143487ms step_avg:97.68ms
+step:1470/1695 train_time:143585ms step_avg:97.68ms
+step:1471/1695 train_time:143682ms step_avg:97.68ms
+step:1472/1695 train_time:143779ms step_avg:97.68ms
+step:1473/1695 train_time:143877ms step_avg:97.68ms
+step:1474/1695 train_time:143974ms step_avg:97.68ms
+step:1475/1695 train_time:144072ms step_avg:97.68ms
+step:1476/1695 train_time:144169ms step_avg:97.68ms
+step:1477/1695 train_time:144267ms step_avg:97.68ms
+step:1478/1695 train_time:144365ms step_avg:97.68ms
+step:1479/1695 train_time:144462ms step_avg:97.68ms
+step:1480/1695 train_time:144559ms step_avg:97.68ms
+step:1481/1695 train_time:144657ms step_avg:97.67ms
+step:1482/1695 train_time:144754ms step_avg:97.67ms
+step:1483/1695 train_time:144852ms step_avg:97.67ms
+step:1484/1695 train_time:144949ms step_avg:97.67ms
+step:1485/1695 train_time:145048ms step_avg:97.68ms
+step:1486/1695 train_time:145146ms step_avg:97.68ms
+step:1487/1695 train_time:145244ms step_avg:97.68ms
+step:1488/1695 train_time:145341ms step_avg:97.68ms
+step:1489/1695 train_time:145438ms step_avg:97.67ms
+step:1490/1695 train_time:145535ms step_avg:97.67ms
+step:1491/1695 train_time:145632ms step_avg:97.67ms
+step:1492/1695 train_time:145730ms step_avg:97.67ms
+step:1493/1695 train_time:145829ms step_avg:97.68ms
+step:1494/1695 train_time:145927ms step_avg:97.68ms
+step:1495/1695 train_time:146024ms step_avg:97.68ms
+step:1496/1695 train_time:146123ms step_avg:97.68ms
+step:1497/1695 train_time:146220ms step_avg:97.68ms
+step:1498/1695 train_time:146317ms step_avg:97.67ms
+step:1499/1695 train_time:146414ms step_avg:97.67ms
+step:1500/1695 train_time:146512ms step_avg:97.67ms
+step:1500/1695 val_loss:3.3178 train_time:146608ms step_avg:97.74ms
+step:1501/1695 train_time:146633ms step_avg:97.69ms
+step:1502/1695 train_time:146718ms step_avg:97.68ms
+step:1503/1695 train_time:146818ms step_avg:97.68ms
+step:1504/1695 train_time:146916ms step_avg:97.68ms
+step:1505/1695 train_time:147013ms step_avg:97.68ms
+step:1506/1695 train_time:147110ms step_avg:97.68ms
+step:1507/1695 train_time:147206ms step_avg:97.68ms
+step:1508/1695 train_time:147302ms step_avg:97.68ms
+step:1509/1695 train_time:147399ms step_avg:97.68ms
+step:1510/1695 train_time:147495ms step_avg:97.68ms
+step:1511/1695 train_time:147595ms step_avg:97.68ms
+step:1512/1695 train_time:147697ms step_avg:97.68ms
+step:1513/1695 train_time:147797ms step_avg:97.68ms
+step:1514/1695 train_time:147896ms step_avg:97.69ms
+step:1515/1695 train_time:147994ms step_avg:97.69ms
+step:1516/1695 train_time:148092ms step_avg:97.69ms
+step:1517/1695 train_time:148190ms step_avg:97.69ms
+step:1518/1695 train_time:148287ms step_avg:97.69ms
+step:1519/1695 train_time:148384ms step_avg:97.69ms
+step:1520/1695 train_time:148481ms step_avg:97.68ms
+step:1521/1695 train_time:148578ms step_avg:97.68ms
+step:1522/1695 train_time:148676ms step_avg:97.68ms
+step:1523/1695 train_time:148776ms step_avg:97.69ms
+step:1524/1695 train_time:148874ms step_avg:97.69ms
+step:1525/1695 train_time:148973ms step_avg:97.69ms
+step:1526/1695 train_time:149072ms step_avg:97.69ms
+step:1527/1695 train_time:149169ms step_avg:97.69ms
+step:1528/1695 train_time:149266ms step_avg:97.69ms
+step:1529/1695 train_time:149362ms step_avg:97.69ms
+step:1530/1695 train_time:149460ms step_avg:97.69ms
+step:1531/1695 train_time:149557ms step_avg:97.69ms
+step:1532/1695 train_time:149655ms step_avg:97.69ms
+step:1533/1695 train_time:149753ms step_avg:97.69ms
+step:1534/1695 train_time:149852ms step_avg:97.69ms
+step:1535/1695 train_time:149950ms step_avg:97.69ms
+step:1536/1695 train_time:150048ms step_avg:97.69ms
+step:1537/1695 train_time:150146ms step_avg:97.69ms
+step:1538/1695 train_time:150244ms step_avg:97.69ms
+step:1539/1695 train_time:150341ms step_avg:97.69ms
+step:1540/1695 train_time:150438ms step_avg:97.69ms
+step:1541/1695 train_time:150535ms step_avg:97.69ms
+step:1542/1695 train_time:150633ms step_avg:97.69ms
+step:1543/1695 train_time:150731ms step_avg:97.69ms
+step:1544/1695 train_time:150830ms step_avg:97.69ms
+step:1545/1695 train_time:150928ms step_avg:97.69ms
+step:1546/1695 train_time:151027ms step_avg:97.69ms
+step:1547/1695 train_time:151123ms step_avg:97.69ms
+step:1548/1695 train_time:151220ms step_avg:97.69ms
+step:1549/1695 train_time:151317ms step_avg:97.69ms
+step:1550/1695 train_time:151415ms step_avg:97.69ms
+step:1551/1695 train_time:151513ms step_avg:97.69ms
+step:1552/1695 train_time:151866ms step_avg:97.85ms
+step:1553/1695 train_time:152044ms step_avg:97.90ms
+step:1554/1695 train_time:152139ms step_avg:97.90ms
+step:1555/1695 train_time:152235ms step_avg:97.90ms
+step:1556/1695 train_time:152332ms step_avg:97.90ms
+step:1557/1695 train_time:152428ms step_avg:97.90ms
+step:1558/1695 train_time:152525ms step_avg:97.90ms
+step:1559/1695 train_time:152621ms step_avg:97.90ms
+step:1560/1695 train_time:152717ms step_avg:97.90ms
+step:1561/1695 train_time:152815ms step_avg:97.90ms
+step:1562/1695 train_time:152920ms step_avg:97.90ms
+step:1563/1695 train_time:153021ms step_avg:97.90ms
+step:1564/1695 train_time:153120ms step_avg:97.90ms
+step:1565/1695 train_time:153218ms step_avg:97.90ms
+step:1566/1695 train_time:153315ms step_avg:97.90ms
+step:1567/1695 train_time:153412ms step_avg:97.90ms
+step:1568/1695 train_time:153509ms step_avg:97.90ms
+step:1569/1695 train_time:153606ms step_avg:97.90ms
+step:1570/1695 train_time:153702ms step_avg:97.90ms
+step:1571/1695 train_time:153798ms step_avg:97.90ms
+step:1572/1695 train_time:153898ms step_avg:97.90ms
+step:1573/1695 train_time:153999ms step_avg:97.90ms
+step:1574/1695 train_time:154099ms step_avg:97.90ms
+step:1575/1695 train_time:154197ms step_avg:97.90ms
+step:1576/1695 train_time:154294ms step_avg:97.90ms
+step:1577/1695 train_time:154393ms step_avg:97.90ms
+step:1578/1695 train_time:154490ms step_avg:97.90ms
+step:1579/1695 train_time:154587ms step_avg:97.90ms
+step:1580/1695 train_time:154684ms step_avg:97.90ms
+step:1581/1695 train_time:154781ms step_avg:97.90ms
+step:1582/1695 train_time:154879ms step_avg:97.90ms
+step:1583/1695 train_time:154977ms step_avg:97.90ms
+step:1584/1695 train_time:155075ms step_avg:97.90ms
+step:1585/1695 train_time:155174ms step_avg:97.90ms
+step:1586/1695 train_time:155273ms step_avg:97.90ms
+step:1587/1695 train_time:155371ms step_avg:97.90ms
+step:1588/1695 train_time:155469ms
step_avg:97.90ms -step:1589/1695 train_time:155566ms step_avg:97.90ms -step:1590/1695 train_time:155663ms step_avg:97.90ms -step:1591/1695 train_time:155760ms step_avg:97.90ms -step:1592/1695 train_time:155858ms step_avg:97.90ms -step:1593/1695 train_time:155956ms step_avg:97.90ms -step:1594/1695 train_time:156054ms step_avg:97.90ms -step:1595/1695 train_time:156152ms step_avg:97.90ms -step:1596/1695 train_time:156250ms step_avg:97.90ms -step:1597/1695 train_time:156350ms step_avg:97.90ms -step:1598/1695 train_time:156448ms step_avg:97.90ms -step:1599/1695 train_time:156546ms step_avg:97.90ms -step:1600/1695 train_time:156644ms step_avg:97.90ms -step:1601/1695 train_time:156741ms step_avg:97.90ms -step:1602/1695 train_time:156838ms step_avg:97.90ms -step:1603/1695 train_time:156936ms step_avg:97.90ms -step:1604/1695 train_time:157034ms step_avg:97.90ms -step:1605/1695 train_time:157133ms step_avg:97.90ms -step:1606/1695 train_time:157234ms step_avg:97.90ms -step:1607/1695 train_time:157333ms step_avg:97.90ms -step:1608/1695 train_time:157431ms step_avg:97.91ms -step:1609/1695 train_time:157529ms step_avg:97.91ms -step:1610/1695 train_time:157627ms step_avg:97.91ms -step:1611/1695 train_time:157726ms step_avg:97.91ms -step:1612/1695 train_time:157824ms step_avg:97.91ms -step:1613/1695 train_time:157921ms step_avg:97.91ms -step:1614/1695 train_time:158017ms step_avg:97.90ms -step:1615/1695 train_time:158114ms step_avg:97.90ms -step:1616/1695 train_time:158212ms step_avg:97.90ms -step:1617/1695 train_time:158312ms step_avg:97.90ms -step:1618/1695 train_time:158412ms step_avg:97.91ms -step:1619/1695 train_time:158510ms step_avg:97.91ms -step:1620/1695 train_time:158609ms step_avg:97.91ms -step:1621/1695 train_time:158708ms step_avg:97.91ms -step:1622/1695 train_time:158806ms step_avg:97.91ms -step:1623/1695 train_time:158905ms step_avg:97.91ms -step:1624/1695 train_time:159001ms step_avg:97.91ms -step:1625/1695 train_time:159097ms step_avg:97.91ms -step:1625/1695 val_loss:3.2907 train_time:159193ms step_avg:97.96ms -step:1626/1695 train_time:159217ms step_avg:97.92ms -step:1627/1695 train_time:159299ms step_avg:97.91ms -step:1628/1695 train_time:159398ms step_avg:97.91ms -step:1629/1695 train_time:159495ms step_avg:97.91ms -step:1630/1695 train_time:159593ms step_avg:97.91ms -step:1631/1695 train_time:159690ms step_avg:97.91ms -step:1632/1695 train_time:159787ms step_avg:97.91ms -step:1633/1695 train_time:159884ms step_avg:97.91ms -step:1634/1695 train_time:159981ms step_avg:97.91ms -step:1635/1695 train_time:160077ms step_avg:97.91ms -step:1636/1695 train_time:160176ms step_avg:97.91ms -step:1637/1695 train_time:160276ms step_avg:97.91ms -step:1638/1695 train_time:160375ms step_avg:97.91ms -step:1639/1695 train_time:160474ms step_avg:97.91ms -step:1640/1695 train_time:160571ms step_avg:97.91ms -step:1641/1695 train_time:160669ms step_avg:97.91ms -step:1642/1695 train_time:160766ms step_avg:97.91ms -step:1643/1695 train_time:160864ms step_avg:97.91ms -step:1644/1695 train_time:160961ms step_avg:97.91ms -step:1645/1695 train_time:161058ms step_avg:97.91ms -step:1646/1695 train_time:161157ms step_avg:97.91ms -step:1647/1695 train_time:161255ms step_avg:97.91ms -step:1648/1695 train_time:161353ms step_avg:97.91ms -step:1649/1695 train_time:161452ms step_avg:97.91ms -step:1650/1695 train_time:161551ms step_avg:97.91ms -step:1651/1695 train_time:161649ms step_avg:97.91ms -step:1652/1695 train_time:161746ms step_avg:97.91ms -step:1653/1695 train_time:161843ms step_avg:97.91ms -step:1654/1695 
train_time:161940ms step_avg:97.91ms -step:1655/1695 train_time:162037ms step_avg:97.91ms -step:1656/1695 train_time:162135ms step_avg:97.91ms -step:1657/1695 train_time:162233ms step_avg:97.91ms -step:1658/1695 train_time:162332ms step_avg:97.91ms -step:1659/1695 train_time:162430ms step_avg:97.91ms -step:1660/1695 train_time:162528ms step_avg:97.91ms -step:1661/1695 train_time:162627ms step_avg:97.91ms -step:1662/1695 train_time:162724ms step_avg:97.91ms -step:1663/1695 train_time:162820ms step_avg:97.91ms -step:1664/1695 train_time:162918ms step_avg:97.91ms -step:1665/1695 train_time:163015ms step_avg:97.91ms -step:1666/1695 train_time:163114ms step_avg:97.91ms -step:1667/1695 train_time:163211ms step_avg:97.91ms -step:1668/1695 train_time:163309ms step_avg:97.91ms -step:1669/1695 train_time:163407ms step_avg:97.91ms -step:1670/1695 train_time:163505ms step_avg:97.91ms -step:1671/1695 train_time:163603ms step_avg:97.91ms -step:1672/1695 train_time:163700ms step_avg:97.91ms -step:1673/1695 train_time:163797ms step_avg:97.91ms -step:1674/1695 train_time:163894ms step_avg:97.91ms -step:1675/1695 train_time:163992ms step_avg:97.91ms -step:1676/1695 train_time:164091ms step_avg:97.91ms -step:1677/1695 train_time:164189ms step_avg:97.91ms -step:1678/1695 train_time:164287ms step_avg:97.91ms -step:1679/1695 train_time:164385ms step_avg:97.91ms -step:1680/1695 train_time:164482ms step_avg:97.91ms -step:1681/1695 train_time:164580ms step_avg:97.91ms -step:1682/1695 train_time:164677ms step_avg:97.91ms -step:1683/1695 train_time:164775ms step_avg:97.91ms -step:1684/1695 train_time:164873ms step_avg:97.91ms -step:1685/1695 train_time:164971ms step_avg:97.91ms -step:1686/1695 train_time:165069ms step_avg:97.91ms -step:1687/1695 train_time:165167ms step_avg:97.91ms -step:1688/1695 train_time:165265ms step_avg:97.91ms -step:1689/1695 train_time:165363ms step_avg:97.91ms -step:1690/1695 train_time:165461ms step_avg:97.91ms -step:1691/1695 train_time:165559ms step_avg:97.91ms -step:1692/1695 train_time:165656ms step_avg:97.91ms -step:1693/1695 train_time:165754ms step_avg:97.91ms -step:1694/1695 train_time:165851ms step_avg:97.91ms -step:1695/1695 train_time:165950ms step_avg:97.91ms -step:1695/1695 val_loss:3.2791 train_time:166045ms step_avg:97.96ms -peak memory allocated: 34073 MiB reserved: 49476 MiB diff --git a/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt b/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt deleted file mode 100644 index 9652d6c2d..000000000 --- a/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
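    # (Illustrative numbers, not from the original: with M=768 and 128x128 blocks,
    # num_pid_m = num_pid_n = 6, so each matrix in the batch gets 36 programs; e.g.
    # pid=41 resolves to batch_idx=1 with remainder 5, which the swizzle below then
    # remaps into a grouped (pid_m, pid_n) tile coordinate for better L2 reuse.)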
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - 
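        # the header is 256 int32 values (1024 bytes), so the token payload starts at byte offset 256 * 4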
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to receive new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. 
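# -----------------------------------------------------------------------------
# (Editorial sketch, not part of the original script.) A minimal, self-contained
# illustration of the sparse attention gate defined in CausalSelfAttention above;
# all names and sizes below are illustrative. The real gate is the zero-initialized
# CastedLinear(attn_gate_dim=12, num_heads) applied to the first 12 channels of the
# block input, whose sigmoid scales each head's output per token.
def _attn_gate_sketch():
    import torch
    B, T, H, Dh, gate_dim = 2, 8, 6, 128, 12
    x = torch.randn(B, T, H * Dh)      # input to the attention block (normed residual stream)
    y = torch.randn(B, T, H, Dh)       # stand-in for the per-head attention output
    gate_w = torch.zeros(H, gate_dim)  # zero init => sigmoid(0) = 0.5 for every head at the start
    g = torch.sigmoid(x[..., :gate_dim] @ gate_w.T)  # (B, T, H): one gate per token per head
    return y * g.unsqueeze(-1)         # a gate near 0 lets a head act as a no-op for that token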
- -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## - 
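# (Illustrative walk-through of get_lr_and_ws above, added for clarity; values are
# approximate.) With num_iterations=1695, x = step / 1696. The window-size schedule
# splits training into thirds: step 100 -> x~0.06 -> ws=3 (a 3 * bandwidth = 384-token
# attention window); step 600 -> x~0.35 -> ws=7; step 1200 -> x~0.71 -> ws=11. The lr
# multiplier holds at 1.0 until x = 1 - cooldown_frac = 0.55, then decays linearly
# toward 0.1: at step 1200, w = (1 - 0.71) / 0.45 ~= 0.65, so lr ~= 0.65 + 0.35 * 0.1 ~= 0.69.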
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 03:43:24 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:00:0B.0 Off |                  Off |
| N/A   30C    P0             115W / 700W |   5858MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:00:0C.0 Off |                  Off |
| N/A   32C    P0             113W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:00:0D.0 Off |                  Off |
| N/A   33C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:00:0E.0 Off |                  Off |
| N/A   30C    P0             113W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:00:0F.0 Off |                  Off |
| N/A   30C    P0             111W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:00:10.0 Off |                  Off |
| N/A   34C    P0             116W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:00:11.0 Off |                  Off |
| N/A   32C    P0             111W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:00:12.0 Off |                  Off |
| N/A   31C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI              PID   Type   Process name                      GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

====================================================================================================
step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms
step:1/1695 train_time:511ms step_avg:510.59ms
step:2/1695 train_time:534ms step_avg:266.84ms
step:3/1695 train_time:604ms step_avg:201.32ms
step:4/1695 train_time:696ms step_avg:174.01ms
step:5/1695 train_time:789ms step_avg:157.81ms
step:6/1695 train_time:882ms step_avg:147.03ms
step:7/1695 train_time:975ms step_avg:139.35ms
step:8/1695 train_time:1069ms step_avg:133.63ms
step:9/1695 train_time:1163ms step_avg:129.21ms
step:10/1695 train_time:1256ms step_avg:125.59ms
[... steps 11-124 elided: steady ~93-100 ms/step, with one ~330-440 ms outlier roughly every 172 steps throughout the run ...]
step:125/1695 val_loss:4.3128 train_time:12241ms step_avg:97.92ms
[... steps 126-249 elided ...]
step:250/1695 val_loss:3.9758 train_time:24295ms step_avg:97.18ms
[... steps 251-374 elided ...]
step:375/1695 val_loss:3.8203 train_time:36339ms step_avg:96.90ms
[... steps 376-499 elided ...]
step:500/1695 val_loss:3.7161 train_time:48100ms step_avg:96.20ms
[... steps 501-624 elided ...]
step:625/1695 val_loss:3.6203 train_time:60334ms step_avg:96.53ms
[... steps 626-749 elided ...]
step:750/1695 val_loss:3.5663 train_time:72700ms step_avg:96.93ms
[... steps 751-874 elided ...]
step:875/1695 val_loss:3.5235 train_time:85011ms step_avg:97.16ms
[... steps 876-999 elided ...]
step:1000/1695 val_loss:3.4841 train_time:97040ms step_avg:97.04ms
[... steps 1001-1124 elided ...]
step:1125/1695 val_loss:3.4352 train_time:109375ms step_avg:97.22ms
[... steps 1126-1249 elided ...]
step:1250/1695 val_loss:3.3872 train_time:121956ms step_avg:97.56ms
[... steps 1251-1321 elided ...]
step:1322/1695 train_time:128910ms step_avg:97.51ms
step:1323/1695
train_time:129008ms step_avg:97.51ms -step:1324/1695 train_time:129106ms step_avg:97.51ms -step:1325/1695 train_time:129204ms step_avg:97.51ms -step:1326/1695 train_time:129303ms step_avg:97.51ms -step:1327/1695 train_time:129401ms step_avg:97.51ms -step:1328/1695 train_time:129499ms step_avg:97.51ms -step:1329/1695 train_time:129597ms step_avg:97.51ms -step:1330/1695 train_time:129695ms step_avg:97.51ms -step:1331/1695 train_time:129792ms step_avg:97.51ms -step:1332/1695 train_time:129890ms step_avg:97.52ms -step:1333/1695 train_time:129988ms step_avg:97.52ms -step:1334/1695 train_time:130085ms step_avg:97.51ms -step:1335/1695 train_time:130182ms step_avg:97.51ms -step:1336/1695 train_time:130281ms step_avg:97.52ms -step:1337/1695 train_time:130381ms step_avg:97.52ms -step:1338/1695 train_time:130478ms step_avg:97.52ms -step:1339/1695 train_time:130577ms step_avg:97.52ms -step:1340/1695 train_time:130674ms step_avg:97.52ms -step:1341/1695 train_time:130773ms step_avg:97.52ms -step:1342/1695 train_time:130870ms step_avg:97.52ms -step:1343/1695 train_time:130967ms step_avg:97.52ms -step:1344/1695 train_time:131063ms step_avg:97.52ms -step:1345/1695 train_time:131161ms step_avg:97.52ms -step:1346/1695 train_time:131259ms step_avg:97.52ms -step:1347/1695 train_time:131357ms step_avg:97.52ms -step:1348/1695 train_time:131455ms step_avg:97.52ms -step:1349/1695 train_time:131553ms step_avg:97.52ms -step:1350/1695 train_time:131651ms step_avg:97.52ms -step:1351/1695 train_time:131749ms step_avg:97.52ms -step:1352/1695 train_time:131847ms step_avg:97.52ms -step:1353/1695 train_time:131946ms step_avg:97.52ms -step:1354/1695 train_time:132044ms step_avg:97.52ms -step:1355/1695 train_time:132142ms step_avg:97.52ms -step:1356/1695 train_time:132239ms step_avg:97.52ms -step:1357/1695 train_time:132336ms step_avg:97.52ms -step:1358/1695 train_time:132434ms step_avg:97.52ms -step:1359/1695 train_time:132532ms step_avg:97.52ms -step:1360/1695 train_time:132629ms step_avg:97.52ms -step:1361/1695 train_time:132727ms step_avg:97.52ms -step:1362/1695 train_time:132825ms step_avg:97.52ms -step:1363/1695 train_time:132925ms step_avg:97.52ms -step:1364/1695 train_time:133023ms step_avg:97.52ms -step:1365/1695 train_time:133120ms step_avg:97.52ms -step:1366/1695 train_time:133218ms step_avg:97.52ms -step:1367/1695 train_time:133316ms step_avg:97.52ms -step:1368/1695 train_time:133413ms step_avg:97.52ms -step:1369/1695 train_time:133511ms step_avg:97.52ms -step:1370/1695 train_time:133609ms step_avg:97.52ms -step:1371/1695 train_time:133707ms step_avg:97.53ms -step:1372/1695 train_time:133805ms step_avg:97.53ms -step:1373/1695 train_time:133904ms step_avg:97.53ms -step:1374/1695 train_time:134003ms step_avg:97.53ms -step:1375/1695 train_time:134101ms step_avg:97.53ms -step:1375/1695 val_loss:3.3494 train_time:134197ms step_avg:97.60ms -step:1376/1695 train_time:134222ms step_avg:97.55ms -step:1377/1695 train_time:134308ms step_avg:97.54ms -step:1378/1695 train_time:134406ms step_avg:97.54ms -step:1379/1695 train_time:134504ms step_avg:97.54ms -step:1380/1695 train_time:134602ms step_avg:97.54ms -step:1381/1695 train_time:135056ms step_avg:97.80ms -step:1382/1695 train_time:135131ms step_avg:97.78ms -step:1383/1695 train_time:135227ms step_avg:97.78ms -step:1384/1695 train_time:135324ms step_avg:97.78ms -step:1385/1695 train_time:135420ms step_avg:97.78ms -step:1386/1695 train_time:135517ms step_avg:97.78ms -step:1387/1695 train_time:135613ms step_avg:97.77ms -step:1388/1695 train_time:135709ms step_avg:97.77ms 
-step:1389/1695 train_time:135806ms step_avg:97.77ms -step:1390/1695 train_time:135905ms step_avg:97.77ms -step:1391/1695 train_time:136009ms step_avg:97.78ms -step:1392/1695 train_time:136109ms step_avg:97.78ms -step:1393/1695 train_time:136207ms step_avg:97.78ms -step:1394/1695 train_time:136304ms step_avg:97.78ms -step:1395/1695 train_time:136402ms step_avg:97.78ms -step:1396/1695 train_time:136499ms step_avg:97.78ms -step:1397/1695 train_time:136596ms step_avg:97.78ms -step:1398/1695 train_time:136692ms step_avg:97.78ms -step:1399/1695 train_time:136788ms step_avg:97.78ms -step:1400/1695 train_time:136887ms step_avg:97.78ms -step:1401/1695 train_time:136985ms step_avg:97.78ms -step:1402/1695 train_time:137085ms step_avg:97.78ms -step:1403/1695 train_time:137185ms step_avg:97.78ms -step:1404/1695 train_time:137283ms step_avg:97.78ms -step:1405/1695 train_time:137381ms step_avg:97.78ms -step:1406/1695 train_time:137479ms step_avg:97.78ms -step:1407/1695 train_time:137577ms step_avg:97.78ms -step:1408/1695 train_time:137674ms step_avg:97.78ms -step:1409/1695 train_time:137771ms step_avg:97.78ms -step:1410/1695 train_time:137869ms step_avg:97.78ms -step:1411/1695 train_time:137967ms step_avg:97.78ms -step:1412/1695 train_time:138066ms step_avg:97.78ms -step:1413/1695 train_time:138164ms step_avg:97.78ms -step:1414/1695 train_time:138263ms step_avg:97.78ms -step:1415/1695 train_time:138361ms step_avg:97.78ms -step:1416/1695 train_time:138459ms step_avg:97.78ms -step:1417/1695 train_time:138557ms step_avg:97.78ms -step:1418/1695 train_time:138655ms step_avg:97.78ms -step:1419/1695 train_time:138753ms step_avg:97.78ms -step:1420/1695 train_time:138850ms step_avg:97.78ms -step:1421/1695 train_time:138947ms step_avg:97.78ms -step:1422/1695 train_time:139045ms step_avg:97.78ms -step:1423/1695 train_time:139143ms step_avg:97.78ms -step:1424/1695 train_time:139242ms step_avg:97.78ms -step:1425/1695 train_time:139340ms step_avg:97.78ms -step:1426/1695 train_time:139438ms step_avg:97.78ms -step:1427/1695 train_time:139535ms step_avg:97.78ms -step:1428/1695 train_time:139633ms step_avg:97.78ms -step:1429/1695 train_time:139730ms step_avg:97.78ms -step:1430/1695 train_time:139828ms step_avg:97.78ms -step:1431/1695 train_time:139926ms step_avg:97.78ms -step:1432/1695 train_time:140024ms step_avg:97.78ms -step:1433/1695 train_time:140122ms step_avg:97.78ms -step:1434/1695 train_time:140222ms step_avg:97.78ms -step:1435/1695 train_time:140319ms step_avg:97.78ms -step:1436/1695 train_time:140418ms step_avg:97.78ms -step:1437/1695 train_time:140516ms step_avg:97.78ms -step:1438/1695 train_time:140613ms step_avg:97.78ms -step:1439/1695 train_time:140711ms step_avg:97.78ms -step:1440/1695 train_time:140808ms step_avg:97.78ms -step:1441/1695 train_time:140905ms step_avg:97.78ms -step:1442/1695 train_time:141003ms step_avg:97.78ms -step:1443/1695 train_time:141101ms step_avg:97.78ms -step:1444/1695 train_time:141199ms step_avg:97.78ms -step:1445/1695 train_time:141297ms step_avg:97.78ms -step:1446/1695 train_time:141395ms step_avg:97.78ms -step:1447/1695 train_time:141493ms step_avg:97.78ms -step:1448/1695 train_time:141590ms step_avg:97.78ms -step:1449/1695 train_time:141687ms step_avg:97.78ms -step:1450/1695 train_time:141785ms step_avg:97.78ms -step:1451/1695 train_time:141882ms step_avg:97.78ms -step:1452/1695 train_time:141980ms step_avg:97.78ms -step:1453/1695 train_time:142078ms step_avg:97.78ms -step:1454/1695 train_time:142176ms step_avg:97.78ms -step:1455/1695 train_time:142274ms step_avg:97.78ms 
-step:1456/1695 train_time:142372ms step_avg:97.78ms -step:1457/1695 train_time:142470ms step_avg:97.78ms -step:1458/1695 train_time:142568ms step_avg:97.78ms -step:1459/1695 train_time:142665ms step_avg:97.78ms -step:1460/1695 train_time:142764ms step_avg:97.78ms -step:1461/1695 train_time:142863ms step_avg:97.78ms -step:1462/1695 train_time:142961ms step_avg:97.78ms -step:1463/1695 train_time:143059ms step_avg:97.78ms -step:1464/1695 train_time:143156ms step_avg:97.78ms -step:1465/1695 train_time:143253ms step_avg:97.78ms -step:1466/1695 train_time:143351ms step_avg:97.78ms -step:1467/1695 train_time:143448ms step_avg:97.78ms -step:1468/1695 train_time:143546ms step_avg:97.78ms -step:1469/1695 train_time:143644ms step_avg:97.78ms -step:1470/1695 train_time:143742ms step_avg:97.78ms -step:1471/1695 train_time:143841ms step_avg:97.78ms -step:1472/1695 train_time:143938ms step_avg:97.78ms -step:1473/1695 train_time:144036ms step_avg:97.78ms -step:1474/1695 train_time:144132ms step_avg:97.78ms -step:1475/1695 train_time:144231ms step_avg:97.78ms -step:1476/1695 train_time:144328ms step_avg:97.78ms -step:1477/1695 train_time:144426ms step_avg:97.78ms -step:1478/1695 train_time:144524ms step_avg:97.78ms -step:1479/1695 train_time:144623ms step_avg:97.78ms -step:1480/1695 train_time:144722ms step_avg:97.79ms -step:1481/1695 train_time:144821ms step_avg:97.79ms -step:1482/1695 train_time:144919ms step_avg:97.79ms -step:1483/1695 train_time:145017ms step_avg:97.79ms -step:1484/1695 train_time:145115ms step_avg:97.79ms -step:1485/1695 train_time:145212ms step_avg:97.79ms -step:1486/1695 train_time:145310ms step_avg:97.79ms -step:1487/1695 train_time:145407ms step_avg:97.79ms -step:1488/1695 train_time:145505ms step_avg:97.79ms -step:1489/1695 train_time:145603ms step_avg:97.79ms -step:1490/1695 train_time:145702ms step_avg:97.79ms -step:1491/1695 train_time:145799ms step_avg:97.79ms -step:1492/1695 train_time:145896ms step_avg:97.79ms -step:1493/1695 train_time:145994ms step_avg:97.79ms -step:1494/1695 train_time:146092ms step_avg:97.79ms -step:1495/1695 train_time:146190ms step_avg:97.79ms -step:1496/1695 train_time:146288ms step_avg:97.79ms -step:1497/1695 train_time:146385ms step_avg:97.79ms -step:1498/1695 train_time:146482ms step_avg:97.79ms -step:1499/1695 train_time:146580ms step_avg:97.79ms -step:1500/1695 train_time:146679ms step_avg:97.79ms -step:1500/1695 val_loss:3.3158 train_time:146775ms step_avg:97.85ms -step:1501/1695 train_time:146802ms step_avg:97.80ms -step:1502/1695 train_time:146885ms step_avg:97.79ms -step:1503/1695 train_time:146985ms step_avg:97.79ms -step:1504/1695 train_time:147082ms step_avg:97.79ms -step:1505/1695 train_time:147180ms step_avg:97.79ms -step:1506/1695 train_time:147276ms step_avg:97.79ms -step:1507/1695 train_time:147372ms step_avg:97.79ms -step:1508/1695 train_time:147469ms step_avg:97.79ms -step:1509/1695 train_time:147566ms step_avg:97.79ms -step:1510/1695 train_time:147663ms step_avg:97.79ms -step:1511/1695 train_time:147762ms step_avg:97.79ms -step:1512/1695 train_time:147865ms step_avg:97.79ms -step:1513/1695 train_time:147965ms step_avg:97.80ms -step:1514/1695 train_time:148064ms step_avg:97.80ms -step:1515/1695 train_time:148162ms step_avg:97.80ms -step:1516/1695 train_time:148260ms step_avg:97.80ms -step:1517/1695 train_time:148356ms step_avg:97.80ms -step:1518/1695 train_time:148454ms step_avg:97.80ms -step:1519/1695 train_time:148551ms step_avg:97.80ms -step:1520/1695 train_time:148647ms step_avg:97.79ms -step:1521/1695 train_time:148745ms 
step_avg:97.79ms -step:1522/1695 train_time:148844ms step_avg:97.79ms -step:1523/1695 train_time:148943ms step_avg:97.80ms -step:1524/1695 train_time:149041ms step_avg:97.80ms -step:1525/1695 train_time:149140ms step_avg:97.80ms -step:1526/1695 train_time:149238ms step_avg:97.80ms -step:1527/1695 train_time:149336ms step_avg:97.80ms -step:1528/1695 train_time:149434ms step_avg:97.80ms -step:1529/1695 train_time:149531ms step_avg:97.80ms -step:1530/1695 train_time:149628ms step_avg:97.80ms -step:1531/1695 train_time:149726ms step_avg:97.80ms -step:1532/1695 train_time:149824ms step_avg:97.80ms -step:1533/1695 train_time:149922ms step_avg:97.80ms -step:1534/1695 train_time:150020ms step_avg:97.80ms -step:1535/1695 train_time:150119ms step_avg:97.80ms -step:1536/1695 train_time:150217ms step_avg:97.80ms -step:1537/1695 train_time:150316ms step_avg:97.80ms -step:1538/1695 train_time:150413ms step_avg:97.80ms -step:1539/1695 train_time:150510ms step_avg:97.80ms -step:1540/1695 train_time:150607ms step_avg:97.80ms -step:1541/1695 train_time:150704ms step_avg:97.80ms -step:1542/1695 train_time:150803ms step_avg:97.80ms -step:1543/1695 train_time:150902ms step_avg:97.80ms -step:1544/1695 train_time:151000ms step_avg:97.80ms -step:1545/1695 train_time:151099ms step_avg:97.80ms -step:1546/1695 train_time:151198ms step_avg:97.80ms -step:1547/1695 train_time:151296ms step_avg:97.80ms -step:1548/1695 train_time:151394ms step_avg:97.80ms -step:1549/1695 train_time:151492ms step_avg:97.80ms -step:1550/1695 train_time:151590ms step_avg:97.80ms -step:1551/1695 train_time:151687ms step_avg:97.80ms -step:1552/1695 train_time:152039ms step_avg:97.96ms -step:1553/1695 train_time:152209ms step_avg:98.01ms -step:1554/1695 train_time:152305ms step_avg:98.01ms -step:1555/1695 train_time:152401ms step_avg:98.01ms -step:1556/1695 train_time:152498ms step_avg:98.01ms -step:1557/1695 train_time:152595ms step_avg:98.01ms -step:1558/1695 train_time:152692ms step_avg:98.00ms -step:1559/1695 train_time:152787ms step_avg:98.00ms -step:1560/1695 train_time:152884ms step_avg:98.00ms -step:1561/1695 train_time:152980ms step_avg:98.00ms -step:1562/1695 train_time:153085ms step_avg:98.01ms -step:1563/1695 train_time:153187ms step_avg:98.01ms -step:1564/1695 train_time:153287ms step_avg:98.01ms -step:1565/1695 train_time:153384ms step_avg:98.01ms -step:1566/1695 train_time:153482ms step_avg:98.01ms -step:1567/1695 train_time:153579ms step_avg:98.01ms -step:1568/1695 train_time:153677ms step_avg:98.01ms -step:1569/1695 train_time:153774ms step_avg:98.01ms -step:1570/1695 train_time:153871ms step_avg:98.01ms -step:1571/1695 train_time:153968ms step_avg:98.01ms -step:1572/1695 train_time:154066ms step_avg:98.01ms -step:1573/1695 train_time:154166ms step_avg:98.01ms -step:1574/1695 train_time:154265ms step_avg:98.01ms -step:1575/1695 train_time:154364ms step_avg:98.01ms -step:1576/1695 train_time:154461ms step_avg:98.01ms -step:1577/1695 train_time:154559ms step_avg:98.01ms -step:1578/1695 train_time:154656ms step_avg:98.01ms -step:1579/1695 train_time:154753ms step_avg:98.01ms -step:1580/1695 train_time:154850ms step_avg:98.01ms -step:1581/1695 train_time:154946ms step_avg:98.01ms -step:1582/1695 train_time:155044ms step_avg:98.01ms -step:1583/1695 train_time:155144ms step_avg:98.01ms -step:1584/1695 train_time:155243ms step_avg:98.01ms -step:1585/1695 train_time:155341ms step_avg:98.01ms -step:1586/1695 train_time:155439ms step_avg:98.01ms -step:1587/1695 train_time:155537ms step_avg:98.01ms -step:1588/1695 train_time:155634ms 
step_avg:98.01ms -step:1589/1695 train_time:155731ms step_avg:98.01ms -step:1590/1695 train_time:155829ms step_avg:98.01ms -step:1591/1695 train_time:155926ms step_avg:98.00ms -step:1592/1695 train_time:156023ms step_avg:98.00ms -step:1593/1695 train_time:156121ms step_avg:98.00ms -step:1594/1695 train_time:156220ms step_avg:98.01ms -step:1595/1695 train_time:156320ms step_avg:98.01ms -step:1596/1695 train_time:156419ms step_avg:98.01ms -step:1597/1695 train_time:156517ms step_avg:98.01ms -step:1598/1695 train_time:156615ms step_avg:98.01ms -step:1599/1695 train_time:156713ms step_avg:98.01ms -step:1600/1695 train_time:156810ms step_avg:98.01ms -step:1601/1695 train_time:156908ms step_avg:98.01ms -step:1602/1695 train_time:157005ms step_avg:98.01ms -step:1603/1695 train_time:157103ms step_avg:98.01ms -step:1604/1695 train_time:157201ms step_avg:98.01ms -step:1605/1695 train_time:157300ms step_avg:98.01ms -step:1606/1695 train_time:157399ms step_avg:98.01ms -step:1607/1695 train_time:157497ms step_avg:98.01ms -step:1608/1695 train_time:157595ms step_avg:98.01ms -step:1609/1695 train_time:157693ms step_avg:98.01ms -step:1610/1695 train_time:157791ms step_avg:98.01ms -step:1611/1695 train_time:157889ms step_avg:98.01ms -step:1612/1695 train_time:157986ms step_avg:98.01ms -step:1613/1695 train_time:158083ms step_avg:98.01ms -step:1614/1695 train_time:158181ms step_avg:98.01ms -step:1615/1695 train_time:158279ms step_avg:98.01ms -step:1616/1695 train_time:158378ms step_avg:98.01ms -step:1617/1695 train_time:158477ms step_avg:98.01ms -step:1618/1695 train_time:158575ms step_avg:98.01ms -step:1619/1695 train_time:158672ms step_avg:98.01ms -step:1620/1695 train_time:158771ms step_avg:98.01ms -step:1621/1695 train_time:158869ms step_avg:98.01ms -step:1622/1695 train_time:158967ms step_avg:98.01ms -step:1623/1695 train_time:159064ms step_avg:98.01ms -step:1624/1695 train_time:159161ms step_avg:98.01ms -step:1625/1695 train_time:159259ms step_avg:98.01ms -step:1625/1695 val_loss:3.2885 train_time:159356ms step_avg:98.07ms -step:1626/1695 train_time:159382ms step_avg:98.02ms -step:1627/1695 train_time:159464ms step_avg:98.01ms -step:1628/1695 train_time:159563ms step_avg:98.01ms -step:1629/1695 train_time:159661ms step_avg:98.01ms -step:1630/1695 train_time:159759ms step_avg:98.01ms -step:1631/1695 train_time:159856ms step_avg:98.01ms -step:1632/1695 train_time:159953ms step_avg:98.01ms -step:1633/1695 train_time:160051ms step_avg:98.01ms -step:1634/1695 train_time:160147ms step_avg:98.01ms -step:1635/1695 train_time:160244ms step_avg:98.01ms -step:1636/1695 train_time:160346ms step_avg:98.01ms -step:1637/1695 train_time:160445ms step_avg:98.01ms -step:1638/1695 train_time:160544ms step_avg:98.01ms -step:1639/1695 train_time:160643ms step_avg:98.01ms -step:1640/1695 train_time:160740ms step_avg:98.01ms -step:1641/1695 train_time:160839ms step_avg:98.01ms -step:1642/1695 train_time:160936ms step_avg:98.01ms -step:1643/1695 train_time:161034ms step_avg:98.01ms -step:1644/1695 train_time:161132ms step_avg:98.01ms -step:1645/1695 train_time:161230ms step_avg:98.01ms -step:1646/1695 train_time:161329ms step_avg:98.01ms -step:1647/1695 train_time:161427ms step_avg:98.01ms -step:1648/1695 train_time:161526ms step_avg:98.01ms -step:1649/1695 train_time:161624ms step_avg:98.01ms -step:1650/1695 train_time:161721ms step_avg:98.01ms -step:1651/1695 train_time:161819ms step_avg:98.01ms -step:1652/1695 train_time:161917ms step_avg:98.01ms -step:1653/1695 train_time:162016ms step_avg:98.01ms -step:1654/1695 
train_time:162114ms step_avg:98.01ms -step:1655/1695 train_time:162213ms step_avg:98.01ms -step:1656/1695 train_time:162313ms step_avg:98.01ms -step:1657/1695 train_time:162411ms step_avg:98.02ms -step:1658/1695 train_time:162510ms step_avg:98.02ms -step:1659/1695 train_time:162608ms step_avg:98.02ms -step:1660/1695 train_time:162706ms step_avg:98.02ms -step:1661/1695 train_time:162804ms step_avg:98.02ms -step:1662/1695 train_time:162901ms step_avg:98.02ms -step:1663/1695 train_time:163000ms step_avg:98.02ms -step:1664/1695 train_time:163099ms step_avg:98.02ms -step:1665/1695 train_time:163198ms step_avg:98.02ms -step:1666/1695 train_time:163297ms step_avg:98.02ms -step:1667/1695 train_time:163396ms step_avg:98.02ms -step:1668/1695 train_time:163496ms step_avg:98.02ms -step:1669/1695 train_time:163595ms step_avg:98.02ms -step:1670/1695 train_time:163695ms step_avg:98.02ms -step:1671/1695 train_time:163793ms step_avg:98.02ms -step:1672/1695 train_time:163891ms step_avg:98.02ms -step:1673/1695 train_time:163988ms step_avg:98.02ms -step:1674/1695 train_time:164085ms step_avg:98.02ms -step:1675/1695 train_time:164182ms step_avg:98.02ms -step:1676/1695 train_time:164280ms step_avg:98.02ms -step:1677/1695 train_time:164379ms step_avg:98.02ms -step:1678/1695 train_time:164478ms step_avg:98.02ms -step:1679/1695 train_time:164578ms step_avg:98.02ms -step:1680/1695 train_time:164677ms step_avg:98.02ms -step:1681/1695 train_time:164774ms step_avg:98.02ms -step:1682/1695 train_time:164872ms step_avg:98.02ms -step:1683/1695 train_time:164969ms step_avg:98.02ms -step:1684/1695 train_time:165067ms step_avg:98.02ms -step:1685/1695 train_time:165164ms step_avg:98.02ms -step:1686/1695 train_time:165262ms step_avg:98.02ms -step:1687/1695 train_time:165359ms step_avg:98.02ms -step:1688/1695 train_time:165458ms step_avg:98.02ms -step:1689/1695 train_time:165558ms step_avg:98.02ms -step:1690/1695 train_time:165659ms step_avg:98.02ms -step:1691/1695 train_time:165757ms step_avg:98.02ms -step:1692/1695 train_time:165855ms step_avg:98.02ms -step:1693/1695 train_time:165953ms step_avg:98.02ms -step:1694/1695 train_time:166052ms step_avg:98.02ms -step:1695/1695 train_time:166151ms step_avg:98.02ms -step:1695/1695 val_loss:3.2769 train_time:166247ms step_avg:98.08ms -peak memory allocated: 34505 MiB reserved: 49576 MiB diff --git a/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt b/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt deleted file mode 100644 index 7e21a501e..000000000 --- a/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
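# A minimal sketch of the scaling contract behind nanogpt::mm above, assuming the
# usual float8_e4m3fn range (max finite value 448): tensors are divided by a scale
# before the fp8 cast, and torch._scaled_mm multiplies scale_a * scale_b back in,
# so in exact arithmetic ((x / x_s) @ (w / w_s).T) * (x_s * w_s) == x @ w.T.
# Emulated round trip with plain casts (toy shapes, illustrative only):
import torch
x, w = torch.randn(4, 8), torch.randn(16, 8)
x_s = w_s = 4.0
x_q = (x / x_s).to(torch.float8_e4m3fn).to(torch.float32)
w_q = (w / w_s).to(torch.float8_e4m3fn).to(torch.float32)
approx = (x_q @ w_q.T) * (x_s * w_s)  # matches x @ w.T up to fp8 rounding error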
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
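# The per-parameter update inside Muon.step above, with the reduce-scatter /
# all-gather sharding stripped away (a sketch; lr_mul / wd_mul multipliers omitted):
def muon_update_(p, grad, buf, lr=0.05, momentum=0.95, weight_decay=0.0):
    buf.lerp_(grad, 1 - momentum)      # SGD momentum buffer
    g = grad.lerp_(buf, momentum)      # Nesterov-style blend of grad and buffer
    v = newton_schulz_triton(g)        # orthogonalize the update
    p.mul_(1 - lr * weight_decay)      # decoupled weight decay
    p.add_(v, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)  # aspect-ratio lr scaling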
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
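# Shape-level illustration of the sparse attention gate wired into
# CausalSelfAttention above: the gate reads only the first attn_gate_dim=12
# channels of the block input and emits one sigmoid per head; with the
# zero-initialized weight every gate starts at 0.5. Toy sizes are assumptions.
import torch
B, T, H, D, G = 2, 8, 6, 128, 12
x = torch.randn(B, T, H * D)              # block input (dim = num_heads * head_dim)
y = torch.randn(B, T, H, D)               # attention output, per head
gate_w = torch.zeros(H, G)                # CastedLinear(12, num_heads) weight, zero-init
g = torch.sigmoid(x[..., :G] @ gate_w.T)  # (B, T, H) gates in (0, 1); 0.5 at init
y = y * g.view(B, T, H, 1)                # per-head, per-token down-weighting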
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - 
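# Worked check of the softcap comment in GPT.forward above: since
# 2*sigmoid(2*x) = tanh(x) + 1, taking x = z/15 gives
#   30*sigmoid(z/7.5) = 15*tanh(z/15) + 15,
# i.e. a tanh cap at +/-15 shifted into the range (0, 30).
import torch
z = torch.linspace(-100, 100, steps=7)
assert torch.allclose(30 * torch.sigmoid(z / 7.5), 15 * torch.tanh(z / 15) + 15)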
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
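# Walk-through of EOSBatchFinder.next_batch above with seq_len=4 (E marks eos_id):
#   index:   0 1 2 3 4 5 6 7 8 9 10
#   tokens:  a b E c d e f g h E j    -> eos_idx = [2, 9]
# The first sequence starts just after the EOS at index 2, i.e. at 3, and takes
# tokens[3:8] = c d e f g (inputs c d e f, shifted targets d e f g). The next
# sequence must again start right after an EOS at or beyond 3 + 4 = 7, which is
# the EOS at 9, so it begins at 10; token h is skipped rather than starting a
# sequence mid-document.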
-                tokens, pos = _load_data_shard(next(file_iter)), 0
-                finder = EOSBatchFinder(tokens, world_size=world_size)
-                continue
-
-            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
-            buf = torch.stack(bufs, dim=0)
-            _inputs = buf[:, :-1]
-            _targets = buf[:, 1:]
-        else:
-            batch_span = num_tokens_global
-            start_pos_local = pos + rank * (batch_size_local * seq_len)
-            end_pos_local = start_pos_local + (batch_size_local * seq_len)
-
-            buf = tokens[start_pos_local: end_pos_local + 1]
-
-            _inputs = buf[:-1].view(batch_size_local, seq_len)
-            _targets = buf[1:].view(batch_size_local, seq_len)
-
-        new_params = yield (
-            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
-            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
-        )
-
-        pos += batch_span
-
-        if new_params is not None:
-            # makes it possible for generator to receive new (batch_size, seq_len) via .send()
-            new_batch_size, new_seq_len = new_params
-            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
-            batch_size = new_batch_size
-            seq_len = new_seq_len
-
-
-# -----------------------------------------------------------------------------
-# int main
-
-@dataclass
-class Hyperparameters:
-    # data
-    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
-    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
-    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
-    train_seq_len: int = 1024 * 2
-    train_batch_size: int = 24 * 8
-    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
-    # optimization
-    num_iterations: int = 1695 # number of iterations to run
-    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
-    # evaluation and logging
-    run_id: str = str(uuid.uuid4())
-    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
-    save_checkpoint: bool = False
-    # attention masking
-    bandwidth: int = 128
-    ws_schedule: tuple = (3, 7, 11)
-
-args = Hyperparameters()
-
-data_path = os.environ.get("DATA_PATH", ".")
-args.train_files = os.path.join(data_path, args.train_files)
-args.val_files = os.path.join(data_path, args.val_files)
-
-# torchrun sets these env variables
-rank = int(os.environ["RANK"])
-world_size = int(os.environ["WORLD_SIZE"])
-assert 8 % world_size == 0, "world_size must be a divisor of 8"
-grad_accum_steps = 8 // world_size
-assert torch.cuda.is_available()
-device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
-torch.cuda.set_device(device)
-dist.init_process_group(backend="nccl", device_id=device)
-dist.barrier()
-master_process = (rank == 0) # this process will do logging, checkpointing etc.
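-# The loader above assumes .bin shards laid out as a 256-int32 header
-# (magic number 20240520, version 1, claimed token count) followed by raw
-# uint16 token ids. Below is a minimal sketch of a writer for that layout,
-# handy for smoke-testing the loader; the helper name, file name, token
-# values, and EOS spacing are illustrative assumptions, not part of this run.
-def write_token_shard(path="smoke_shard.bin", num_tokens=2**16, eos_id=50256):
-    import numpy as np # local import, matching the style of nvidia_smi below
-    header = np.zeros(256, dtype=np.int32)
-    header[0] = 20240520    # magic number checked by _load_data_shard
-    header[1] = 1           # format version
-    header[2] = num_tokens  # claimed token count
-    tokens = np.random.randint(0, eos_id, size=num_tokens, dtype=np.uint16)
-    tokens[::4096] = eos_id # space out EOS markers so EOSBatchFinder can find document starts
-    with open(path, "wb") as f:
-        f.write(header.tobytes())
-        f.write(tokens.tobytes())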
-
-# begin logging
-logfile = None
-if master_process:
-    run_id = args.run_id
-    os.makedirs("logs", exist_ok=True)
-    logfile = f"logs/{run_id}.txt"
-    print(logfile)
-def print0(s, console=False):
-    if master_process:
-        with open(logfile, "a") as f:
-            if console:
-                print(s)
-            print(s, file=f)
-
-# begin by printing this file (the Python code)
-print0(code)
-print0("="*100)
-# log information about the hardware/software environment this is running on
-print0(f"Running Python {sys.version}")
-print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
-print0(f"Running Triton version {triton.__version__}")
-
-def nvidia_smi():
-    import subprocess # avoid top level import
-    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
-print0(nvidia_smi())
-print0("="*100)
-
-model: nn.Module = GPT(
-    vocab_size=50257,
-    num_layers=12,
-    num_heads=6,
-    model_dim=768,
-    max_seq_len=max(args.train_seq_len, args.val_seq_len)
-).cuda()
-for m in model.modules():
-    if isinstance(m, nn.Embedding):
-        m.bfloat16()
-for param in model.parameters():
-    dist.broadcast(param.detach(), 0)
-
-# collect the parameters to optimize
-hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
-embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-scalar_params = [p for p in model.parameters() if p.ndim < 2]
-head_params = [model.lm_head.weight]
-
-# init the optimizer(s)
-# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
-# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
-optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
-optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
-optimizers = [optimizer1, optimizer2]
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["initial_lr"] = group["lr"]
-
-# learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
-    assert 0 <= x < 1
-    lr = 1.0
-    if x >= 1 - args.cooldown_frac:
-        w = (1 - x) / args.cooldown_frac
-        lr = w * 1.0 + (1 - w) * 0.1
-    ws_idx = int(len(args.ws_schedule) * x)
-    return lr, args.ws_schedule[ws_idx]
-
-model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
-
-########################################
-#            Warmup kernels            #
-########################################
-
-# Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
-initial_state = dict(model=copy.deepcopy(model.state_dict()),
-                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
-    model(inputs, targets, ws, ws // 2).backward()
-    for opt in optimizers:
-        opt.step()
-    model.zero_grad(set_to_none=True)
-model.load_state_dict(initial_state["model"])
-for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
-    opt.load_state_dict(opt_state)
-del train_loader, initial_state
-
-########################################
-#       Training and validation        #
-########################################
-
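-# Illustrative sanity check of the schedule defined above: with
-# num_iterations=1695, cooldown_frac=0.45, and ws_schedule=(3, 7, 11), the lr
-# multiplier holds at 1.0 until ~step 933 and then anneals linearly toward
-# 0.1, while the window multiplier steps 3 -> 7 -> 11 near steps 566 and 1131.
-# These asserts sketch the expected behavior; they are not checks from the
-# original script.
-assert get_lr_and_ws(0) == (1.0, 3)
-assert get_lr_and_ws(600) == (1.0, 7)
-assert get_lr_and_ws(1200)[1] == 11
-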
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 03:53:12 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 30C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 32C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 30C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms -step:1/1695 train_time:508ms step_avg:507.99ms -step:2/1695 train_time:531ms step_avg:265.69ms -step:3/1695 train_time:603ms step_avg:200.91ms -step:4/1695 train_time:695ms step_avg:173.68ms -step:5/1695 train_time:787ms step_avg:157.47ms -step:6/1695 train_time:881ms step_avg:146.76ms -step:7/1695 train_time:974ms step_avg:139.13ms -step:8/1695 train_time:1067ms step_avg:133.43ms -step:9/1695 train_time:1161ms step_avg:128.95ms -step:10/1695 train_time:1253ms step_avg:125.34ms -step:11/1695 train_time:1347ms step_avg:122.48ms -step:12/1695 train_time:1445ms step_avg:120.43ms -step:13/1695 train_time:1543ms step_avg:118.70ms -step:14/1695 train_time:1638ms step_avg:117.04ms -step:15/1695 train_time:1732ms step_avg:115.49ms -step:16/1695 train_time:1827ms step_avg:114.19ms -step:17/1695 train_time:1921ms step_avg:113.01ms -step:18/1695 train_time:2015ms step_avg:111.93ms -step:19/1695 train_time:2108ms step_avg:110.95ms -step:20/1695 train_time:2202ms step_avg:110.08ms -step:21/1695 train_time:2295ms step_avg:109.29ms -step:22/1695 train_time:2390ms step_avg:108.63ms 
-step:23/1695 train_time:2486ms step_avg:108.07ms -step:24/1695 train_time:2583ms step_avg:107.61ms -step:25/1695 train_time:2678ms step_avg:107.10ms -step:26/1695 train_time:2771ms step_avg:106.59ms -step:27/1695 train_time:2867ms step_avg:106.18ms -step:28/1695 train_time:2962ms step_avg:105.78ms -step:29/1695 train_time:3055ms step_avg:105.34ms -step:30/1695 train_time:3149ms step_avg:104.96ms -step:31/1695 train_time:3243ms step_avg:104.60ms -step:32/1695 train_time:3336ms step_avg:104.26ms -step:33/1695 train_time:3430ms step_avg:103.94ms -step:34/1695 train_time:3526ms step_avg:103.70ms -step:35/1695 train_time:3622ms step_avg:103.48ms -step:36/1695 train_time:3717ms step_avg:103.24ms -step:37/1695 train_time:3811ms step_avg:102.99ms -step:38/1695 train_time:3906ms step_avg:102.79ms -step:39/1695 train_time:4001ms step_avg:102.58ms -step:40/1695 train_time:4094ms step_avg:102.36ms -step:41/1695 train_time:4188ms step_avg:102.14ms -step:42/1695 train_time:4282ms step_avg:101.96ms -step:43/1695 train_time:4376ms step_avg:101.76ms -step:44/1695 train_time:4470ms step_avg:101.58ms -step:45/1695 train_time:4565ms step_avg:101.45ms -step:46/1695 train_time:4661ms step_avg:101.32ms -step:47/1695 train_time:4754ms step_avg:101.15ms -step:48/1695 train_time:4849ms step_avg:101.01ms -step:49/1695 train_time:4944ms step_avg:100.89ms -step:50/1695 train_time:5039ms step_avg:100.77ms -step:51/1695 train_time:5132ms step_avg:100.63ms -step:52/1695 train_time:5227ms step_avg:100.51ms -step:53/1695 train_time:5322ms step_avg:100.41ms -step:54/1695 train_time:5416ms step_avg:100.29ms -step:55/1695 train_time:5510ms step_avg:100.18ms -step:56/1695 train_time:5606ms step_avg:100.10ms -step:57/1695 train_time:5701ms step_avg:100.02ms -step:58/1695 train_time:5795ms step_avg:99.91ms -step:59/1695 train_time:5889ms step_avg:99.81ms -step:60/1695 train_time:5984ms step_avg:99.73ms -step:61/1695 train_time:6077ms step_avg:99.63ms -step:62/1695 train_time:6171ms step_avg:99.54ms -step:63/1695 train_time:6267ms step_avg:99.47ms -step:64/1695 train_time:6362ms step_avg:99.40ms -step:65/1695 train_time:6457ms step_avg:99.33ms -step:66/1695 train_time:6551ms step_avg:99.26ms -step:67/1695 train_time:6647ms step_avg:99.21ms -step:68/1695 train_time:6743ms step_avg:99.15ms -step:69/1695 train_time:6837ms step_avg:99.08ms -step:70/1695 train_time:6930ms step_avg:99.00ms -step:71/1695 train_time:7025ms step_avg:98.95ms -step:72/1695 train_time:7119ms step_avg:98.88ms -step:73/1695 train_time:7214ms step_avg:98.82ms -step:74/1695 train_time:7308ms step_avg:98.75ms -step:75/1695 train_time:7403ms step_avg:98.70ms -step:76/1695 train_time:7497ms step_avg:98.64ms -step:77/1695 train_time:7591ms step_avg:98.59ms -step:78/1695 train_time:7685ms step_avg:98.53ms -step:79/1695 train_time:7781ms step_avg:98.50ms -step:80/1695 train_time:7876ms step_avg:98.45ms -step:81/1695 train_time:7969ms step_avg:98.39ms -step:82/1695 train_time:8065ms step_avg:98.35ms -step:83/1695 train_time:8160ms step_avg:98.31ms -step:84/1695 train_time:8253ms step_avg:98.25ms -step:85/1695 train_time:8347ms step_avg:98.21ms -step:86/1695 train_time:8442ms step_avg:98.16ms -step:87/1695 train_time:8536ms step_avg:98.11ms -step:88/1695 train_time:8630ms step_avg:98.07ms -step:89/1695 train_time:8725ms step_avg:98.03ms -step:90/1695 train_time:8819ms step_avg:97.99ms -step:91/1695 train_time:8913ms step_avg:97.95ms -step:92/1695 train_time:9008ms step_avg:97.91ms -step:93/1695 train_time:9102ms step_avg:97.87ms -step:94/1695 train_time:9196ms 
step_avg:97.83ms -step:95/1695 train_time:9290ms step_avg:97.79ms -step:96/1695 train_time:9385ms step_avg:97.76ms -step:97/1695 train_time:9480ms step_avg:97.73ms -step:98/1695 train_time:9574ms step_avg:97.69ms -step:99/1695 train_time:9669ms step_avg:97.66ms -step:100/1695 train_time:9764ms step_avg:97.64ms -step:101/1695 train_time:9858ms step_avg:97.61ms -step:102/1695 train_time:9952ms step_avg:97.57ms -step:103/1695 train_time:10047ms step_avg:97.54ms -step:104/1695 train_time:10141ms step_avg:97.51ms -step:105/1695 train_time:10235ms step_avg:97.47ms -step:106/1695 train_time:10329ms step_avg:97.44ms -step:107/1695 train_time:10424ms step_avg:97.42ms -step:108/1695 train_time:10519ms step_avg:97.40ms -step:109/1695 train_time:10613ms step_avg:97.37ms -step:110/1695 train_time:10708ms step_avg:97.34ms -step:111/1695 train_time:10802ms step_avg:97.32ms -step:112/1695 train_time:10896ms step_avg:97.29ms -step:113/1695 train_time:10990ms step_avg:97.25ms -step:114/1695 train_time:11084ms step_avg:97.22ms -step:115/1695 train_time:11178ms step_avg:97.20ms -step:116/1695 train_time:11271ms step_avg:97.17ms -step:117/1695 train_time:11366ms step_avg:97.14ms -step:118/1695 train_time:11460ms step_avg:97.12ms -step:119/1695 train_time:11555ms step_avg:97.10ms -step:120/1695 train_time:11649ms step_avg:97.07ms -step:121/1695 train_time:11744ms step_avg:97.05ms -step:122/1695 train_time:11839ms step_avg:97.04ms -step:123/1695 train_time:11933ms step_avg:97.01ms -step:124/1695 train_time:12028ms step_avg:97.00ms -step:125/1695 train_time:12122ms step_avg:96.98ms -step:125/1695 val_loss:4.3129 train_time:12214ms step_avg:97.71ms -step:126/1695 train_time:12238ms step_avg:97.13ms -step:127/1695 train_time:12320ms step_avg:97.01ms -step:128/1695 train_time:12421ms step_avg:97.04ms -step:129/1695 train_time:12516ms step_avg:97.02ms -step:130/1695 train_time:12609ms step_avg:97.00ms -step:131/1695 train_time:12702ms step_avg:96.96ms -step:132/1695 train_time:12795ms step_avg:96.94ms -step:133/1695 train_time:12889ms step_avg:96.91ms -step:134/1695 train_time:12982ms step_avg:96.88ms -step:135/1695 train_time:13075ms step_avg:96.85ms -step:136/1695 train_time:13168ms step_avg:96.83ms -step:137/1695 train_time:13264ms step_avg:96.82ms -step:138/1695 train_time:13359ms step_avg:96.81ms -step:139/1695 train_time:13455ms step_avg:96.80ms -step:140/1695 train_time:13550ms step_avg:96.79ms -step:141/1695 train_time:13645ms step_avg:96.77ms -step:142/1695 train_time:13739ms step_avg:96.75ms -step:143/1695 train_time:13832ms step_avg:96.73ms -step:144/1695 train_time:13926ms step_avg:96.71ms -step:145/1695 train_time:14019ms step_avg:96.68ms -step:146/1695 train_time:14112ms step_avg:96.66ms -step:147/1695 train_time:14206ms step_avg:96.64ms -step:148/1695 train_time:14302ms step_avg:96.63ms -step:149/1695 train_time:14396ms step_avg:96.62ms -step:150/1695 train_time:14492ms step_avg:96.61ms -step:151/1695 train_time:14586ms step_avg:96.60ms -step:152/1695 train_time:14681ms step_avg:96.59ms -step:153/1695 train_time:14775ms step_avg:96.57ms -step:154/1695 train_time:14869ms step_avg:96.55ms -step:155/1695 train_time:14963ms step_avg:96.53ms -step:156/1695 train_time:15056ms step_avg:96.51ms -step:157/1695 train_time:15149ms step_avg:96.49ms -step:158/1695 train_time:15245ms step_avg:96.49ms -step:159/1695 train_time:15340ms step_avg:96.48ms -step:160/1695 train_time:15434ms step_avg:96.46ms -step:161/1695 train_time:15529ms step_avg:96.45ms -step:162/1695 train_time:15624ms step_avg:96.45ms -step:163/1695 
train_time:15719ms step_avg:96.44ms -step:164/1695 train_time:15813ms step_avg:96.42ms -step:165/1695 train_time:15907ms step_avg:96.40ms -step:166/1695 train_time:16001ms step_avg:96.39ms -step:167/1695 train_time:16095ms step_avg:96.38ms -step:168/1695 train_time:16189ms step_avg:96.36ms -step:169/1695 train_time:16283ms step_avg:96.35ms -step:170/1695 train_time:16378ms step_avg:96.34ms -step:171/1695 train_time:16471ms step_avg:96.32ms -step:172/1695 train_time:16567ms step_avg:96.32ms -step:173/1695 train_time:16951ms step_avg:97.98ms -step:174/1695 train_time:17020ms step_avg:97.82ms -step:175/1695 train_time:17112ms step_avg:97.78ms -step:176/1695 train_time:17206ms step_avg:97.76ms -step:177/1695 train_time:17299ms step_avg:97.73ms -step:178/1695 train_time:17391ms step_avg:97.70ms -step:179/1695 train_time:17484ms step_avg:97.68ms -step:180/1695 train_time:17578ms step_avg:97.65ms -step:181/1695 train_time:17671ms step_avg:97.63ms -step:182/1695 train_time:17764ms step_avg:97.61ms -step:183/1695 train_time:17859ms step_avg:97.59ms -step:184/1695 train_time:17956ms step_avg:97.59ms -step:185/1695 train_time:18052ms step_avg:97.58ms -step:186/1695 train_time:18148ms step_avg:97.57ms -step:187/1695 train_time:18243ms step_avg:97.56ms -step:188/1695 train_time:18336ms step_avg:97.53ms -step:189/1695 train_time:18429ms step_avg:97.51ms -step:190/1695 train_time:18522ms step_avg:97.49ms -step:191/1695 train_time:18615ms step_avg:97.46ms -step:192/1695 train_time:18709ms step_avg:97.44ms -step:193/1695 train_time:18803ms step_avg:97.42ms -step:194/1695 train_time:18898ms step_avg:97.41ms -step:195/1695 train_time:18993ms step_avg:97.40ms -step:196/1695 train_time:19088ms step_avg:97.39ms -step:197/1695 train_time:19184ms step_avg:97.38ms -step:198/1695 train_time:19278ms step_avg:97.37ms -step:199/1695 train_time:19372ms step_avg:97.35ms -step:200/1695 train_time:19466ms step_avg:97.33ms -step:201/1695 train_time:19560ms step_avg:97.31ms -step:202/1695 train_time:19652ms step_avg:97.29ms -step:203/1695 train_time:19747ms step_avg:97.28ms -step:204/1695 train_time:19841ms step_avg:97.26ms -step:205/1695 train_time:19935ms step_avg:97.24ms -step:206/1695 train_time:20030ms step_avg:97.23ms -step:207/1695 train_time:20124ms step_avg:97.22ms -step:208/1695 train_time:20219ms step_avg:97.21ms -step:209/1695 train_time:20312ms step_avg:97.19ms -step:210/1695 train_time:20406ms step_avg:97.17ms -step:211/1695 train_time:20500ms step_avg:97.16ms -step:212/1695 train_time:20593ms step_avg:97.14ms -step:213/1695 train_time:20686ms step_avg:97.12ms -step:214/1695 train_time:20781ms step_avg:97.11ms -step:215/1695 train_time:20874ms step_avg:97.09ms -step:216/1695 train_time:20969ms step_avg:97.08ms -step:217/1695 train_time:21063ms step_avg:97.07ms -step:218/1695 train_time:21158ms step_avg:97.05ms -step:219/1695 train_time:21252ms step_avg:97.04ms -step:220/1695 train_time:21348ms step_avg:97.04ms -step:221/1695 train_time:21442ms step_avg:97.02ms -step:222/1695 train_time:21535ms step_avg:97.01ms -step:223/1695 train_time:21629ms step_avg:96.99ms -step:224/1695 train_time:21723ms step_avg:96.98ms -step:225/1695 train_time:21817ms step_avg:96.97ms -step:226/1695 train_time:21911ms step_avg:96.95ms -step:227/1695 train_time:22005ms step_avg:96.94ms -step:228/1695 train_time:22099ms step_avg:96.93ms -step:229/1695 train_time:22193ms step_avg:96.91ms -step:230/1695 train_time:22287ms step_avg:96.90ms -step:231/1695 train_time:22382ms step_avg:96.89ms -step:232/1695 train_time:22476ms step_avg:96.88ms 
-step:233/1695 train_time:22570ms step_avg:96.87ms -step:234/1695 train_time:22663ms step_avg:96.85ms -step:235/1695 train_time:22756ms step_avg:96.83ms -step:236/1695 train_time:22850ms step_avg:96.82ms -step:237/1695 train_time:22945ms step_avg:96.81ms -step:238/1695 train_time:23040ms step_avg:96.81ms -step:239/1695 train_time:23134ms step_avg:96.79ms -step:240/1695 train_time:23228ms step_avg:96.78ms -step:241/1695 train_time:23323ms step_avg:96.78ms -step:242/1695 train_time:23418ms step_avg:96.77ms -step:243/1695 train_time:23512ms step_avg:96.76ms -step:244/1695 train_time:23606ms step_avg:96.75ms -step:245/1695 train_time:23699ms step_avg:96.73ms -step:246/1695 train_time:23792ms step_avg:96.72ms -step:247/1695 train_time:23886ms step_avg:96.71ms -step:248/1695 train_time:23981ms step_avg:96.70ms -step:249/1695 train_time:24075ms step_avg:96.69ms -step:250/1695 train_time:24169ms step_avg:96.68ms -step:250/1695 val_loss:3.9787 train_time:24262ms step_avg:97.05ms -step:251/1695 train_time:24286ms step_avg:96.76ms -step:252/1695 train_time:24364ms step_avg:96.68ms -step:253/1695 train_time:24461ms step_avg:96.68ms -step:254/1695 train_time:24555ms step_avg:96.67ms -step:255/1695 train_time:24648ms step_avg:96.66ms -step:256/1695 train_time:24742ms step_avg:96.65ms -step:257/1695 train_time:24835ms step_avg:96.63ms -step:258/1695 train_time:24928ms step_avg:96.62ms -step:259/1695 train_time:25021ms step_avg:96.61ms -step:260/1695 train_time:25114ms step_avg:96.59ms -step:261/1695 train_time:25208ms step_avg:96.58ms -step:262/1695 train_time:25304ms step_avg:96.58ms -step:263/1695 train_time:25401ms step_avg:96.58ms -step:264/1695 train_time:25495ms step_avg:96.57ms -step:265/1695 train_time:25590ms step_avg:96.57ms -step:266/1695 train_time:25684ms step_avg:96.56ms -step:267/1695 train_time:25777ms step_avg:96.54ms -step:268/1695 train_time:25870ms step_avg:96.53ms -step:269/1695 train_time:25964ms step_avg:96.52ms -step:270/1695 train_time:26058ms step_avg:96.51ms -step:271/1695 train_time:26151ms step_avg:96.50ms -step:272/1695 train_time:26245ms step_avg:96.49ms -step:273/1695 train_time:26341ms step_avg:96.49ms -step:274/1695 train_time:26436ms step_avg:96.48ms -step:275/1695 train_time:26531ms step_avg:96.48ms -step:276/1695 train_time:26626ms step_avg:96.47ms -step:277/1695 train_time:26720ms step_avg:96.46ms -step:278/1695 train_time:26813ms step_avg:96.45ms -step:279/1695 train_time:26906ms step_avg:96.44ms -step:280/1695 train_time:26999ms step_avg:96.43ms -step:281/1695 train_time:27092ms step_avg:96.41ms -step:282/1695 train_time:27187ms step_avg:96.41ms -step:283/1695 train_time:27282ms step_avg:96.40ms -step:284/1695 train_time:27376ms step_avg:96.40ms -step:285/1695 train_time:27471ms step_avg:96.39ms -step:286/1695 train_time:27566ms step_avg:96.39ms -step:287/1695 train_time:27661ms step_avg:96.38ms -step:288/1695 train_time:27754ms step_avg:96.37ms -step:289/1695 train_time:27849ms step_avg:96.36ms -step:290/1695 train_time:27943ms step_avg:96.35ms -step:291/1695 train_time:28037ms step_avg:96.35ms -step:292/1695 train_time:28130ms step_avg:96.33ms -step:293/1695 train_time:28224ms step_avg:96.33ms -step:294/1695 train_time:28318ms step_avg:96.32ms -step:295/1695 train_time:28412ms step_avg:96.31ms -step:296/1695 train_time:28507ms step_avg:96.31ms -step:297/1695 train_time:28603ms step_avg:96.30ms -step:298/1695 train_time:28697ms step_avg:96.30ms -step:299/1695 train_time:28791ms step_avg:96.29ms -step:300/1695 train_time:28885ms step_avg:96.28ms -step:301/1695 
train_time:28979ms step_avg:96.27ms -step:302/1695 train_time:29072ms step_avg:96.27ms -step:303/1695 train_time:29166ms step_avg:96.26ms -step:304/1695 train_time:29261ms step_avg:96.25ms -step:305/1695 train_time:29355ms step_avg:96.25ms -step:306/1695 train_time:29449ms step_avg:96.24ms -step:307/1695 train_time:29544ms step_avg:96.23ms -step:308/1695 train_time:29638ms step_avg:96.23ms -step:309/1695 train_time:29732ms step_avg:96.22ms -step:310/1695 train_time:29827ms step_avg:96.21ms -step:311/1695 train_time:29921ms step_avg:96.21ms -step:312/1695 train_time:30014ms step_avg:96.20ms -step:313/1695 train_time:30108ms step_avg:96.19ms -step:314/1695 train_time:30203ms step_avg:96.19ms -step:315/1695 train_time:30297ms step_avg:96.18ms -step:316/1695 train_time:30391ms step_avg:96.17ms -step:317/1695 train_time:30486ms step_avg:96.17ms -step:318/1695 train_time:30580ms step_avg:96.16ms -step:319/1695 train_time:30674ms step_avg:96.16ms -step:320/1695 train_time:30769ms step_avg:96.15ms -step:321/1695 train_time:30864ms step_avg:96.15ms -step:322/1695 train_time:30957ms step_avg:96.14ms -step:323/1695 train_time:31050ms step_avg:96.13ms -step:324/1695 train_time:31145ms step_avg:96.13ms -step:325/1695 train_time:31239ms step_avg:96.12ms -step:326/1695 train_time:31333ms step_avg:96.11ms -step:327/1695 train_time:31428ms step_avg:96.11ms -step:328/1695 train_time:31523ms step_avg:96.11ms -step:329/1695 train_time:31616ms step_avg:96.10ms -step:330/1695 train_time:31710ms step_avg:96.09ms -step:331/1695 train_time:31805ms step_avg:96.09ms -step:332/1695 train_time:31899ms step_avg:96.08ms -step:333/1695 train_time:31993ms step_avg:96.07ms -step:334/1695 train_time:32087ms step_avg:96.07ms -step:335/1695 train_time:32180ms step_avg:96.06ms -step:336/1695 train_time:32274ms step_avg:96.05ms -step:337/1695 train_time:32368ms step_avg:96.05ms -step:338/1695 train_time:32468ms step_avg:96.06ms -step:339/1695 train_time:32561ms step_avg:96.05ms -step:340/1695 train_time:32655ms step_avg:96.04ms -step:341/1695 train_time:32748ms step_avg:96.04ms -step:342/1695 train_time:32839ms step_avg:96.02ms -step:343/1695 train_time:32933ms step_avg:96.01ms -step:344/1695 train_time:33027ms step_avg:96.01ms -step:345/1695 train_time:33355ms step_avg:96.68ms -step:346/1695 train_time:33456ms step_avg:96.69ms -step:347/1695 train_time:33548ms step_avg:96.68ms -step:348/1695 train_time:33642ms step_avg:96.67ms -step:349/1695 train_time:33735ms step_avg:96.66ms -step:350/1695 train_time:33828ms step_avg:96.65ms -step:351/1695 train_time:33921ms step_avg:96.64ms -step:352/1695 train_time:34014ms step_avg:96.63ms -step:353/1695 train_time:34106ms step_avg:96.62ms -step:354/1695 train_time:34200ms step_avg:96.61ms -step:355/1695 train_time:34296ms step_avg:96.61ms -step:356/1695 train_time:34393ms step_avg:96.61ms -step:357/1695 train_time:34490ms step_avg:96.61ms -step:358/1695 train_time:34584ms step_avg:96.60ms -step:359/1695 train_time:34678ms step_avg:96.60ms -step:360/1695 train_time:34771ms step_avg:96.59ms -step:361/1695 train_time:34865ms step_avg:96.58ms -step:362/1695 train_time:34959ms step_avg:96.57ms -step:363/1695 train_time:35051ms step_avg:96.56ms -step:364/1695 train_time:35145ms step_avg:96.55ms -step:365/1695 train_time:35238ms step_avg:96.54ms -step:366/1695 train_time:35332ms step_avg:96.54ms -step:367/1695 train_time:35428ms step_avg:96.53ms -step:368/1695 train_time:35523ms step_avg:96.53ms -step:369/1695 train_time:35616ms step_avg:96.52ms -step:370/1695 train_time:35710ms step_avg:96.51ms 
-step:371/1695 train_time:35803ms step_avg:96.51ms -step:372/1695 train_time:35897ms step_avg:96.50ms -step:373/1695 train_time:35990ms step_avg:96.49ms -step:374/1695 train_time:36083ms step_avg:96.48ms -step:375/1695 train_time:36176ms step_avg:96.47ms -step:375/1695 val_loss:3.8232 train_time:36268ms step_avg:96.71ms -step:376/1695 train_time:36292ms step_avg:96.52ms -step:377/1695 train_time:36372ms step_avg:96.48ms -step:378/1695 train_time:36470ms step_avg:96.48ms -step:379/1695 train_time:36565ms step_avg:96.48ms -step:380/1695 train_time:36659ms step_avg:96.47ms -step:381/1695 train_time:36752ms step_avg:96.46ms -step:382/1695 train_time:36845ms step_avg:96.45ms -step:383/1695 train_time:36938ms step_avg:96.44ms -step:384/1695 train_time:37031ms step_avg:96.43ms -step:385/1695 train_time:37124ms step_avg:96.43ms -step:386/1695 train_time:37218ms step_avg:96.42ms -step:387/1695 train_time:37313ms step_avg:96.42ms -step:388/1695 train_time:37410ms step_avg:96.42ms -step:389/1695 train_time:37505ms step_avg:96.41ms -step:390/1695 train_time:37599ms step_avg:96.41ms -step:391/1695 train_time:37693ms step_avg:96.40ms -step:392/1695 train_time:37787ms step_avg:96.40ms -step:393/1695 train_time:37880ms step_avg:96.39ms -step:394/1695 train_time:37973ms step_avg:96.38ms -step:395/1695 train_time:38067ms step_avg:96.37ms -step:396/1695 train_time:38160ms step_avg:96.36ms -step:397/1695 train_time:38254ms step_avg:96.36ms -step:398/1695 train_time:38350ms step_avg:96.36ms -step:399/1695 train_time:38445ms step_avg:96.35ms -step:400/1695 train_time:38540ms step_avg:96.35ms -step:401/1695 train_time:38633ms step_avg:96.34ms -step:402/1695 train_time:38727ms step_avg:96.34ms -step:403/1695 train_time:38822ms step_avg:96.33ms -step:404/1695 train_time:38915ms step_avg:96.32ms -step:405/1695 train_time:39008ms step_avg:96.32ms -step:406/1695 train_time:39101ms step_avg:96.31ms -step:407/1695 train_time:39194ms step_avg:96.30ms -step:408/1695 train_time:39289ms step_avg:96.30ms -step:409/1695 train_time:39384ms step_avg:96.29ms -step:410/1695 train_time:39478ms step_avg:96.29ms -step:411/1695 train_time:39572ms step_avg:96.28ms -step:412/1695 train_time:39667ms step_avg:96.28ms -step:413/1695 train_time:39762ms step_avg:96.28ms -step:414/1695 train_time:39856ms step_avg:96.27ms -step:415/1695 train_time:39949ms step_avg:96.26ms -step:416/1695 train_time:40043ms step_avg:96.26ms -step:417/1695 train_time:40137ms step_avg:96.25ms -step:418/1695 train_time:40230ms step_avg:96.24ms -step:419/1695 train_time:40324ms step_avg:96.24ms -step:420/1695 train_time:40419ms step_avg:96.24ms -step:421/1695 train_time:40513ms step_avg:96.23ms -step:422/1695 train_time:40607ms step_avg:96.22ms -step:423/1695 train_time:40702ms step_avg:96.22ms -step:424/1695 train_time:40795ms step_avg:96.21ms -step:425/1695 train_time:40889ms step_avg:96.21ms -step:426/1695 train_time:40984ms step_avg:96.21ms -step:427/1695 train_time:41077ms step_avg:96.20ms -step:428/1695 train_time:41170ms step_avg:96.19ms -step:429/1695 train_time:41264ms step_avg:96.19ms -step:430/1695 train_time:41359ms step_avg:96.18ms -step:431/1695 train_time:41453ms step_avg:96.18ms -step:432/1695 train_time:41547ms step_avg:96.17ms -step:433/1695 train_time:41641ms step_avg:96.17ms -step:434/1695 train_time:41734ms step_avg:96.16ms -step:435/1695 train_time:41828ms step_avg:96.16ms -step:436/1695 train_time:41922ms step_avg:96.15ms -step:437/1695 train_time:42016ms step_avg:96.15ms -step:438/1695 train_time:42110ms step_avg:96.14ms -step:439/1695 
train_time:42203ms step_avg:96.13ms -step:440/1695 train_time:42297ms step_avg:96.13ms -step:441/1695 train_time:42392ms step_avg:96.13ms -step:442/1695 train_time:42487ms step_avg:96.12ms -step:443/1695 train_time:42583ms step_avg:96.12ms -step:444/1695 train_time:42676ms step_avg:96.12ms -step:445/1695 train_time:42770ms step_avg:96.11ms -step:446/1695 train_time:42864ms step_avg:96.11ms -step:447/1695 train_time:42957ms step_avg:96.10ms -step:448/1695 train_time:43050ms step_avg:96.09ms -step:449/1695 train_time:43144ms step_avg:96.09ms -step:450/1695 train_time:43237ms step_avg:96.08ms -step:451/1695 train_time:43331ms step_avg:96.08ms -step:452/1695 train_time:43426ms step_avg:96.07ms -step:453/1695 train_time:43519ms step_avg:96.07ms -step:454/1695 train_time:43613ms step_avg:96.06ms -step:455/1695 train_time:43707ms step_avg:96.06ms -step:456/1695 train_time:43801ms step_avg:96.06ms -step:457/1695 train_time:43895ms step_avg:96.05ms -step:458/1695 train_time:43990ms step_avg:96.05ms -step:459/1695 train_time:44084ms step_avg:96.04ms -step:460/1695 train_time:44178ms step_avg:96.04ms -step:461/1695 train_time:44271ms step_avg:96.03ms -step:462/1695 train_time:44365ms step_avg:96.03ms -step:463/1695 train_time:44459ms step_avg:96.02ms -step:464/1695 train_time:44553ms step_avg:96.02ms -step:465/1695 train_time:44647ms step_avg:96.02ms -step:466/1695 train_time:44742ms step_avg:96.01ms -step:467/1695 train_time:44836ms step_avg:96.01ms -step:468/1695 train_time:44930ms step_avg:96.00ms -step:469/1695 train_time:45025ms step_avg:96.00ms -step:470/1695 train_time:45119ms step_avg:96.00ms -step:471/1695 train_time:45212ms step_avg:95.99ms -step:472/1695 train_time:45306ms step_avg:95.99ms -step:473/1695 train_time:45401ms step_avg:95.98ms -step:474/1695 train_time:45494ms step_avg:95.98ms -step:475/1695 train_time:45589ms step_avg:95.98ms -step:476/1695 train_time:45683ms step_avg:95.97ms -step:477/1695 train_time:45777ms step_avg:95.97ms -step:478/1695 train_time:45870ms step_avg:95.96ms -step:479/1695 train_time:45964ms step_avg:95.96ms -step:480/1695 train_time:46059ms step_avg:95.96ms -step:481/1695 train_time:46153ms step_avg:95.95ms -step:482/1695 train_time:46247ms step_avg:95.95ms -step:483/1695 train_time:46341ms step_avg:95.94ms -step:484/1695 train_time:46435ms step_avg:95.94ms -step:485/1695 train_time:46529ms step_avg:95.94ms -step:486/1695 train_time:46623ms step_avg:95.93ms -step:487/1695 train_time:46718ms step_avg:95.93ms -step:488/1695 train_time:46811ms step_avg:95.92ms -step:489/1695 train_time:46905ms step_avg:95.92ms -step:490/1695 train_time:46998ms step_avg:95.91ms -step:491/1695 train_time:47092ms step_avg:95.91ms -step:492/1695 train_time:47186ms step_avg:95.91ms -step:493/1695 train_time:47280ms step_avg:95.90ms -step:494/1695 train_time:47374ms step_avg:95.90ms -step:495/1695 train_time:47468ms step_avg:95.89ms -step:496/1695 train_time:47562ms step_avg:95.89ms -step:497/1695 train_time:47657ms step_avg:95.89ms -step:498/1695 train_time:47751ms step_avg:95.88ms -step:499/1695 train_time:47845ms step_avg:95.88ms -step:500/1695 train_time:47938ms step_avg:95.88ms -step:500/1695 val_loss:3.7202 train_time:48030ms step_avg:96.06ms -step:501/1695 train_time:48054ms step_avg:95.92ms -step:502/1695 train_time:48133ms step_avg:95.88ms -step:503/1695 train_time:48232ms step_avg:95.89ms -step:504/1695 train_time:48327ms step_avg:95.89ms -step:505/1695 train_time:48419ms step_avg:95.88ms -step:506/1695 train_time:48513ms step_avg:95.88ms -step:507/1695 train_time:48607ms 
step_avg:95.87ms -step:508/1695 train_time:48699ms step_avg:95.86ms -step:509/1695 train_time:48792ms step_avg:95.86ms -step:510/1695 train_time:48885ms step_avg:95.85ms -step:511/1695 train_time:48979ms step_avg:95.85ms -step:512/1695 train_time:49076ms step_avg:95.85ms -step:513/1695 train_time:49173ms step_avg:95.85ms -step:514/1695 train_time:49268ms step_avg:95.85ms -step:515/1695 train_time:49363ms step_avg:95.85ms -step:516/1695 train_time:49456ms step_avg:95.84ms -step:517/1695 train_time:49549ms step_avg:95.84ms -step:518/1695 train_time:49643ms step_avg:95.84ms -step:519/1695 train_time:49968ms step_avg:96.28ms -step:520/1695 train_time:50168ms step_avg:96.48ms -step:521/1695 train_time:50261ms step_avg:96.47ms -step:522/1695 train_time:50353ms step_avg:96.46ms -step:523/1695 train_time:50446ms step_avg:96.46ms -step:524/1695 train_time:50539ms step_avg:96.45ms -step:525/1695 train_time:50632ms step_avg:96.44ms -step:526/1695 train_time:50725ms step_avg:96.43ms -step:527/1695 train_time:50817ms step_avg:96.43ms -step:528/1695 train_time:50910ms step_avg:96.42ms -step:529/1695 train_time:51008ms step_avg:96.42ms -step:530/1695 train_time:51106ms step_avg:96.43ms -step:531/1695 train_time:51202ms step_avg:96.43ms -step:532/1695 train_time:51296ms step_avg:96.42ms -step:533/1695 train_time:51389ms step_avg:96.41ms -step:534/1695 train_time:51482ms step_avg:96.41ms -step:535/1695 train_time:51575ms step_avg:96.40ms -step:536/1695 train_time:51668ms step_avg:96.40ms -step:537/1695 train_time:51761ms step_avg:96.39ms -step:538/1695 train_time:51854ms step_avg:96.38ms -step:539/1695 train_time:51949ms step_avg:96.38ms -step:540/1695 train_time:52044ms step_avg:96.38ms -step:541/1695 train_time:52139ms step_avg:96.38ms -step:542/1695 train_time:52234ms step_avg:96.37ms -step:543/1695 train_time:52328ms step_avg:96.37ms -step:544/1695 train_time:52421ms step_avg:96.36ms -step:545/1695 train_time:52515ms step_avg:96.36ms -step:546/1695 train_time:52609ms step_avg:96.35ms -step:547/1695 train_time:52702ms step_avg:96.35ms -step:548/1695 train_time:52795ms step_avg:96.34ms -step:549/1695 train_time:52889ms step_avg:96.34ms -step:550/1695 train_time:52983ms step_avg:96.33ms -step:551/1695 train_time:53077ms step_avg:96.33ms -step:552/1695 train_time:53172ms step_avg:96.33ms -step:553/1695 train_time:53267ms step_avg:96.32ms -step:554/1695 train_time:53361ms step_avg:96.32ms -step:555/1695 train_time:53454ms step_avg:96.31ms -step:556/1695 train_time:53548ms step_avg:96.31ms -step:557/1695 train_time:53640ms step_avg:96.30ms -step:558/1695 train_time:53734ms step_avg:96.30ms -step:559/1695 train_time:53827ms step_avg:96.29ms -step:560/1695 train_time:53920ms step_avg:96.29ms -step:561/1695 train_time:54015ms step_avg:96.28ms -step:562/1695 train_time:54111ms step_avg:96.28ms -step:563/1695 train_time:54207ms step_avg:96.28ms -step:564/1695 train_time:54301ms step_avg:96.28ms -step:565/1695 train_time:54395ms step_avg:96.27ms -step:566/1695 train_time:54488ms step_avg:96.27ms -step:567/1695 train_time:54583ms step_avg:96.27ms -step:568/1695 train_time:54678ms step_avg:96.26ms -step:569/1695 train_time:54774ms step_avg:96.26ms -step:570/1695 train_time:54869ms step_avg:96.26ms -step:571/1695 train_time:54965ms step_avg:96.26ms -step:572/1695 train_time:55061ms step_avg:96.26ms -step:573/1695 train_time:55157ms step_avg:96.26ms -step:574/1695 train_time:55253ms step_avg:96.26ms -step:575/1695 train_time:55349ms step_avg:96.26ms -step:576/1695 train_time:55445ms step_avg:96.26ms -step:577/1695 
train_time:55540ms step_avg:96.26ms -step:578/1695 train_time:55636ms step_avg:96.26ms -step:579/1695 train_time:55732ms step_avg:96.26ms -step:580/1695 train_time:55828ms step_avg:96.25ms -step:581/1695 train_time:55923ms step_avg:96.25ms -step:582/1695 train_time:56018ms step_avg:96.25ms -step:583/1695 train_time:56116ms step_avg:96.25ms -step:584/1695 train_time:56213ms step_avg:96.25ms -step:585/1695 train_time:56310ms step_avg:96.26ms -step:586/1695 train_time:56407ms step_avg:96.26ms -step:587/1695 train_time:56502ms step_avg:96.26ms -step:588/1695 train_time:56597ms step_avg:96.25ms -step:589/1695 train_time:56693ms step_avg:96.25ms -step:590/1695 train_time:56789ms step_avg:96.25ms -step:591/1695 train_time:56885ms step_avg:96.25ms -step:592/1695 train_time:56981ms step_avg:96.25ms -step:593/1695 train_time:57077ms step_avg:96.25ms -step:594/1695 train_time:57173ms step_avg:96.25ms -step:595/1695 train_time:57270ms step_avg:96.25ms -step:596/1695 train_time:57365ms step_avg:96.25ms -step:597/1695 train_time:57461ms step_avg:96.25ms -step:598/1695 train_time:57556ms step_avg:96.25ms -step:599/1695 train_time:57653ms step_avg:96.25ms -step:600/1695 train_time:57749ms step_avg:96.25ms -step:601/1695 train_time:57845ms step_avg:96.25ms -step:602/1695 train_time:57941ms step_avg:96.25ms -step:603/1695 train_time:58037ms step_avg:96.25ms -step:604/1695 train_time:58134ms step_avg:96.25ms -step:605/1695 train_time:58230ms step_avg:96.25ms -step:606/1695 train_time:58326ms step_avg:96.25ms -step:607/1695 train_time:58421ms step_avg:96.25ms -step:608/1695 train_time:58518ms step_avg:96.25ms -step:609/1695 train_time:58614ms step_avg:96.25ms -step:610/1695 train_time:58711ms step_avg:96.25ms -step:611/1695 train_time:58808ms step_avg:96.25ms -step:612/1695 train_time:58904ms step_avg:96.25ms -step:613/1695 train_time:59000ms step_avg:96.25ms -step:614/1695 train_time:59095ms step_avg:96.25ms -step:615/1695 train_time:59192ms step_avg:96.25ms -step:616/1695 train_time:59289ms step_avg:96.25ms -step:617/1695 train_time:59385ms step_avg:96.25ms -step:618/1695 train_time:59481ms step_avg:96.25ms -step:619/1695 train_time:59576ms step_avg:96.25ms -step:620/1695 train_time:59673ms step_avg:96.25ms -step:621/1695 train_time:59769ms step_avg:96.25ms -step:622/1695 train_time:59865ms step_avg:96.25ms -step:623/1695 train_time:59960ms step_avg:96.24ms -step:624/1695 train_time:60056ms step_avg:96.24ms -step:625/1695 train_time:60152ms step_avg:96.24ms -step:625/1695 val_loss:3.6216 train_time:60246ms step_avg:96.39ms -step:626/1695 train_time:60272ms step_avg:96.28ms -step:627/1695 train_time:60351ms step_avg:96.25ms -step:628/1695 train_time:60448ms step_avg:96.25ms -step:629/1695 train_time:60544ms step_avg:96.25ms -step:630/1695 train_time:60638ms step_avg:96.25ms -step:631/1695 train_time:60733ms step_avg:96.25ms -step:632/1695 train_time:60828ms step_avg:96.25ms -step:633/1695 train_time:60923ms step_avg:96.25ms -step:634/1695 train_time:61018ms step_avg:96.24ms -step:635/1695 train_time:61112ms step_avg:96.24ms -step:636/1695 train_time:61210ms step_avg:96.24ms -step:637/1695 train_time:61308ms step_avg:96.25ms -step:638/1695 train_time:61406ms step_avg:96.25ms -step:639/1695 train_time:61503ms step_avg:96.25ms -step:640/1695 train_time:61598ms step_avg:96.25ms -step:641/1695 train_time:61693ms step_avg:96.24ms -step:642/1695 train_time:61788ms step_avg:96.24ms -step:643/1695 train_time:61883ms step_avg:96.24ms -step:644/1695 train_time:61977ms step_avg:96.24ms -step:645/1695 train_time:62072ms 
step_avg:96.24ms -step:646/1695 train_time:62168ms step_avg:96.23ms -step:647/1695 train_time:62264ms step_avg:96.24ms -step:648/1695 train_time:62361ms step_avg:96.24ms -step:649/1695 train_time:62458ms step_avg:96.24ms -step:650/1695 train_time:62554ms step_avg:96.24ms -step:651/1695 train_time:62651ms step_avg:96.24ms -step:652/1695 train_time:62747ms step_avg:96.24ms -step:653/1695 train_time:62843ms step_avg:96.24ms -step:654/1695 train_time:62937ms step_avg:96.23ms -step:655/1695 train_time:63034ms step_avg:96.23ms -step:656/1695 train_time:63130ms step_avg:96.24ms -step:657/1695 train_time:63228ms step_avg:96.24ms -step:658/1695 train_time:63324ms step_avg:96.24ms -step:659/1695 train_time:63420ms step_avg:96.24ms -step:660/1695 train_time:63515ms step_avg:96.24ms -step:661/1695 train_time:63612ms step_avg:96.24ms -step:662/1695 train_time:63707ms step_avg:96.23ms -step:663/1695 train_time:63803ms step_avg:96.23ms -step:664/1695 train_time:63897ms step_avg:96.23ms -step:665/1695 train_time:63993ms step_avg:96.23ms -step:666/1695 train_time:64089ms step_avg:96.23ms -step:667/1695 train_time:64186ms step_avg:96.23ms -step:668/1695 train_time:64282ms step_avg:96.23ms -step:669/1695 train_time:64377ms step_avg:96.23ms -step:670/1695 train_time:64473ms step_avg:96.23ms -step:671/1695 train_time:64570ms step_avg:96.23ms -step:672/1695 train_time:64666ms step_avg:96.23ms -step:673/1695 train_time:64762ms step_avg:96.23ms -step:674/1695 train_time:64857ms step_avg:96.23ms -step:675/1695 train_time:64954ms step_avg:96.23ms -step:676/1695 train_time:65048ms step_avg:96.23ms -step:677/1695 train_time:65144ms step_avg:96.22ms -step:678/1695 train_time:65238ms step_avg:96.22ms -step:679/1695 train_time:65335ms step_avg:96.22ms -step:680/1695 train_time:65431ms step_avg:96.22ms -step:681/1695 train_time:65528ms step_avg:96.22ms -step:682/1695 train_time:65624ms step_avg:96.22ms -step:683/1695 train_time:65720ms step_avg:96.22ms -step:684/1695 train_time:65815ms step_avg:96.22ms -step:685/1695 train_time:65911ms step_avg:96.22ms -step:686/1695 train_time:66007ms step_avg:96.22ms -step:687/1695 train_time:66103ms step_avg:96.22ms -step:688/1695 train_time:66199ms step_avg:96.22ms -step:689/1695 train_time:66294ms step_avg:96.22ms -step:690/1695 train_time:66390ms step_avg:96.22ms -step:691/1695 train_time:66847ms step_avg:96.74ms -step:692/1695 train_time:66917ms step_avg:96.70ms -step:693/1695 train_time:67011ms step_avg:96.70ms -step:694/1695 train_time:67106ms step_avg:96.69ms -step:695/1695 train_time:67201ms step_avg:96.69ms -step:696/1695 train_time:67296ms step_avg:96.69ms -step:697/1695 train_time:67392ms step_avg:96.69ms -step:698/1695 train_time:67487ms step_avg:96.69ms -step:699/1695 train_time:67581ms step_avg:96.68ms -step:700/1695 train_time:67676ms step_avg:96.68ms -step:701/1695 train_time:67776ms step_avg:96.69ms -step:702/1695 train_time:67879ms step_avg:96.69ms -step:703/1695 train_time:67976ms step_avg:96.69ms -step:704/1695 train_time:68073ms step_avg:96.69ms -step:705/1695 train_time:68169ms step_avg:96.69ms -step:706/1695 train_time:68266ms step_avg:96.69ms -step:707/1695 train_time:68361ms step_avg:96.69ms -step:708/1695 train_time:68455ms step_avg:96.69ms -step:709/1695 train_time:68550ms step_avg:96.69ms -step:710/1695 train_time:68645ms step_avg:96.68ms -step:711/1695 train_time:68741ms step_avg:96.68ms -step:712/1695 train_time:68838ms step_avg:96.68ms -step:713/1695 train_time:68936ms step_avg:96.68ms -step:714/1695 train_time:69034ms step_avg:96.69ms -step:715/1695 
train_time:69130ms step_avg:96.69ms -step:716/1695 train_time:69226ms step_avg:96.68ms -step:717/1695 train_time:69321ms step_avg:96.68ms -step:718/1695 train_time:69416ms step_avg:96.68ms -step:719/1695 train_time:69512ms step_avg:96.68ms -step:720/1695 train_time:69607ms step_avg:96.68ms -step:721/1695 train_time:69702ms step_avg:96.67ms -step:722/1695 train_time:69799ms step_avg:96.67ms -step:723/1695 train_time:69895ms step_avg:96.67ms -step:724/1695 train_time:69992ms step_avg:96.67ms -step:725/1695 train_time:70088ms step_avg:96.67ms -step:726/1695 train_time:70185ms step_avg:96.67ms -step:727/1695 train_time:70280ms step_avg:96.67ms -step:728/1695 train_time:70375ms step_avg:96.67ms -step:729/1695 train_time:70471ms step_avg:96.67ms -step:730/1695 train_time:70566ms step_avg:96.67ms -step:731/1695 train_time:70661ms step_avg:96.66ms -step:732/1695 train_time:70757ms step_avg:96.66ms -step:733/1695 train_time:70853ms step_avg:96.66ms -step:734/1695 train_time:70951ms step_avg:96.66ms -step:735/1695 train_time:71049ms step_avg:96.66ms -step:736/1695 train_time:71145ms step_avg:96.66ms -step:737/1695 train_time:71240ms step_avg:96.66ms -step:738/1695 train_time:71336ms step_avg:96.66ms -step:739/1695 train_time:71432ms step_avg:96.66ms -step:740/1695 train_time:71528ms step_avg:96.66ms -step:741/1695 train_time:71623ms step_avg:96.66ms -step:742/1695 train_time:71719ms step_avg:96.66ms -step:743/1695 train_time:71814ms step_avg:96.65ms -step:744/1695 train_time:71910ms step_avg:96.65ms -step:745/1695 train_time:72007ms step_avg:96.65ms -step:746/1695 train_time:72103ms step_avg:96.65ms -step:747/1695 train_time:72199ms step_avg:96.65ms -step:748/1695 train_time:72295ms step_avg:96.65ms -step:749/1695 train_time:72391ms step_avg:96.65ms -step:750/1695 train_time:72488ms step_avg:96.65ms -step:750/1695 val_loss:3.5671 train_time:72581ms step_avg:96.77ms -step:751/1695 train_time:72608ms step_avg:96.68ms -step:752/1695 train_time:72687ms step_avg:96.66ms -step:753/1695 train_time:72784ms step_avg:96.66ms -step:754/1695 train_time:72880ms step_avg:96.66ms -step:755/1695 train_time:72976ms step_avg:96.66ms -step:756/1695 train_time:73072ms step_avg:96.66ms -step:757/1695 train_time:73167ms step_avg:96.65ms -step:758/1695 train_time:73262ms step_avg:96.65ms -step:759/1695 train_time:73357ms step_avg:96.65ms -step:760/1695 train_time:73452ms step_avg:96.65ms -step:761/1695 train_time:73549ms step_avg:96.65ms -step:762/1695 train_time:73646ms step_avg:96.65ms -step:763/1695 train_time:73744ms step_avg:96.65ms -step:764/1695 train_time:73841ms step_avg:96.65ms -step:765/1695 train_time:73938ms step_avg:96.65ms -step:766/1695 train_time:74034ms step_avg:96.65ms -step:767/1695 train_time:74129ms step_avg:96.65ms -step:768/1695 train_time:74224ms step_avg:96.65ms -step:769/1695 train_time:74319ms step_avg:96.64ms -step:770/1695 train_time:74414ms step_avg:96.64ms -step:771/1695 train_time:74509ms step_avg:96.64ms -step:772/1695 train_time:74606ms step_avg:96.64ms -step:773/1695 train_time:74703ms step_avg:96.64ms -step:774/1695 train_time:74800ms step_avg:96.64ms -step:775/1695 train_time:74897ms step_avg:96.64ms -step:776/1695 train_time:74994ms step_avg:96.64ms -step:777/1695 train_time:75090ms step_avg:96.64ms -step:778/1695 train_time:75185ms step_avg:96.64ms -step:779/1695 train_time:75280ms step_avg:96.64ms -step:780/1695 train_time:75375ms step_avg:96.63ms -step:781/1695 train_time:75472ms step_avg:96.63ms -step:782/1695 train_time:75567ms step_avg:96.63ms -step:783/1695 train_time:75663ms 
step_avg:96.63ms -step:784/1695 train_time:75760ms step_avg:96.63ms -step:785/1695 train_time:75857ms step_avg:96.63ms -step:786/1695 train_time:75953ms step_avg:96.63ms -step:787/1695 train_time:76049ms step_avg:96.63ms -step:788/1695 train_time:76144ms step_avg:96.63ms -step:789/1695 train_time:76239ms step_avg:96.63ms -step:790/1695 train_time:76335ms step_avg:96.63ms -step:791/1695 train_time:76433ms step_avg:96.63ms -step:792/1695 train_time:76529ms step_avg:96.63ms -step:793/1695 train_time:76624ms step_avg:96.63ms -step:794/1695 train_time:76721ms step_avg:96.63ms -step:795/1695 train_time:76818ms step_avg:96.63ms -step:796/1695 train_time:76916ms step_avg:96.63ms -step:797/1695 train_time:77012ms step_avg:96.63ms -step:798/1695 train_time:77108ms step_avg:96.63ms -step:799/1695 train_time:77202ms step_avg:96.62ms -step:800/1695 train_time:77298ms step_avg:96.62ms -step:801/1695 train_time:77394ms step_avg:96.62ms -step:802/1695 train_time:77490ms step_avg:96.62ms -step:803/1695 train_time:77585ms step_avg:96.62ms -step:804/1695 train_time:77681ms step_avg:96.62ms -step:805/1695 train_time:77778ms step_avg:96.62ms -step:806/1695 train_time:77874ms step_avg:96.62ms -step:807/1695 train_time:77971ms step_avg:96.62ms -step:808/1695 train_time:78067ms step_avg:96.62ms -step:809/1695 train_time:78163ms step_avg:96.62ms -step:810/1695 train_time:78258ms step_avg:96.62ms -step:811/1695 train_time:78355ms step_avg:96.61ms -step:812/1695 train_time:78450ms step_avg:96.61ms -step:813/1695 train_time:78545ms step_avg:96.61ms -step:814/1695 train_time:78641ms step_avg:96.61ms -step:815/1695 train_time:78738ms step_avg:96.61ms -step:816/1695 train_time:78835ms step_avg:96.61ms -step:817/1695 train_time:78932ms step_avg:96.61ms -step:818/1695 train_time:79028ms step_avg:96.61ms -step:819/1695 train_time:79123ms step_avg:96.61ms -step:820/1695 train_time:79219ms step_avg:96.61ms -step:821/1695 train_time:79316ms step_avg:96.61ms -step:822/1695 train_time:79412ms step_avg:96.61ms -step:823/1695 train_time:79508ms step_avg:96.61ms -step:824/1695 train_time:79604ms step_avg:96.61ms -step:825/1695 train_time:79700ms step_avg:96.61ms -step:826/1695 train_time:79796ms step_avg:96.60ms -step:827/1695 train_time:79892ms step_avg:96.61ms -step:828/1695 train_time:79989ms step_avg:96.61ms -step:829/1695 train_time:80084ms step_avg:96.60ms -step:830/1695 train_time:80180ms step_avg:96.60ms -step:831/1695 train_time:80276ms step_avg:96.60ms -step:832/1695 train_time:80373ms step_avg:96.60ms -step:833/1695 train_time:80469ms step_avg:96.60ms -step:834/1695 train_time:80565ms step_avg:96.60ms -step:835/1695 train_time:80660ms step_avg:96.60ms -step:836/1695 train_time:80756ms step_avg:96.60ms -step:837/1695 train_time:80853ms step_avg:96.60ms -step:838/1695 train_time:80949ms step_avg:96.60ms -step:839/1695 train_time:81045ms step_avg:96.60ms -step:840/1695 train_time:81141ms step_avg:96.60ms -step:841/1695 train_time:81237ms step_avg:96.60ms -step:842/1695 train_time:81333ms step_avg:96.59ms -step:843/1695 train_time:81429ms step_avg:96.59ms -step:844/1695 train_time:81524ms step_avg:96.59ms -step:845/1695 train_time:81619ms step_avg:96.59ms -step:846/1695 train_time:81716ms step_avg:96.59ms -step:847/1695 train_time:81812ms step_avg:96.59ms -step:848/1695 train_time:81908ms step_avg:96.59ms -step:849/1695 train_time:82003ms step_avg:96.59ms -step:850/1695 train_time:82099ms step_avg:96.59ms -step:851/1695 train_time:82196ms step_avg:96.59ms -step:852/1695 train_time:82292ms step_avg:96.59ms -step:853/1695 
train_time:82388ms step_avg:96.59ms -step:854/1695 train_time:82484ms step_avg:96.59ms -step:855/1695 train_time:82579ms step_avg:96.58ms -step:856/1695 train_time:82676ms step_avg:96.58ms -step:857/1695 train_time:82773ms step_avg:96.58ms -step:858/1695 train_time:82869ms step_avg:96.58ms -step:859/1695 train_time:82964ms step_avg:96.58ms -step:860/1695 train_time:83060ms step_avg:96.58ms -step:861/1695 train_time:83156ms step_avg:96.58ms -step:862/1695 train_time:83252ms step_avg:96.58ms -step:863/1695 train_time:83584ms step_avg:96.85ms -step:864/1695 train_time:83778ms step_avg:96.96ms -step:865/1695 train_time:83872ms step_avg:96.96ms -step:866/1695 train_time:83966ms step_avg:96.96ms -step:867/1695 train_time:84061ms step_avg:96.96ms -step:868/1695 train_time:84156ms step_avg:96.95ms -step:869/1695 train_time:84250ms step_avg:96.95ms -step:870/1695 train_time:84345ms step_avg:96.95ms -step:871/1695 train_time:84440ms step_avg:96.95ms -step:872/1695 train_time:84535ms step_avg:96.94ms -step:873/1695 train_time:84637ms step_avg:96.95ms -step:874/1695 train_time:84737ms step_avg:96.95ms -step:875/1695 train_time:84837ms step_avg:96.96ms -step:875/1695 val_loss:3.5244 train_time:84933ms step_avg:97.07ms -step:876/1695 train_time:84957ms step_avg:96.98ms -step:877/1695 train_time:85038ms step_avg:96.96ms -step:878/1695 train_time:85138ms step_avg:96.97ms -step:879/1695 train_time:85235ms step_avg:96.97ms -step:880/1695 train_time:85330ms step_avg:96.97ms -step:881/1695 train_time:85425ms step_avg:96.96ms -step:882/1695 train_time:85520ms step_avg:96.96ms -step:883/1695 train_time:85614ms step_avg:96.96ms -step:884/1695 train_time:85709ms step_avg:96.96ms -step:885/1695 train_time:85804ms step_avg:96.95ms -step:886/1695 train_time:85901ms step_avg:96.95ms -step:887/1695 train_time:86000ms step_avg:96.96ms -step:888/1695 train_time:86100ms step_avg:96.96ms -step:889/1695 train_time:86198ms step_avg:96.96ms -step:890/1695 train_time:86295ms step_avg:96.96ms -step:891/1695 train_time:86391ms step_avg:96.96ms -step:892/1695 train_time:86486ms step_avg:96.96ms -step:893/1695 train_time:86581ms step_avg:96.96ms -step:894/1695 train_time:86677ms step_avg:96.95ms -step:895/1695 train_time:86772ms step_avg:96.95ms -step:896/1695 train_time:86868ms step_avg:96.95ms -step:897/1695 train_time:86964ms step_avg:96.95ms -step:898/1695 train_time:87062ms step_avg:96.95ms -step:899/1695 train_time:87159ms step_avg:96.95ms -step:900/1695 train_time:87256ms step_avg:96.95ms -step:901/1695 train_time:87351ms step_avg:96.95ms -step:902/1695 train_time:87447ms step_avg:96.95ms -step:903/1695 train_time:87542ms step_avg:96.95ms -step:904/1695 train_time:87638ms step_avg:96.95ms -step:905/1695 train_time:87734ms step_avg:96.94ms -step:906/1695 train_time:87829ms step_avg:96.94ms -step:907/1695 train_time:87925ms step_avg:96.94ms -step:908/1695 train_time:88021ms step_avg:96.94ms -step:909/1695 train_time:88118ms step_avg:96.94ms -step:910/1695 train_time:88216ms step_avg:96.94ms -step:911/1695 train_time:88312ms step_avg:96.94ms -step:912/1695 train_time:88407ms step_avg:96.94ms -step:913/1695 train_time:88503ms step_avg:96.94ms -step:914/1695 train_time:88599ms step_avg:96.94ms -step:915/1695 train_time:88695ms step_avg:96.93ms -step:916/1695 train_time:88791ms step_avg:96.93ms -step:917/1695 train_time:88887ms step_avg:96.93ms -step:918/1695 train_time:88983ms step_avg:96.93ms -step:919/1695 train_time:89080ms step_avg:96.93ms -step:920/1695 train_time:89177ms step_avg:96.93ms -step:921/1695 train_time:89274ms 
step_avg:96.93ms -step:922/1695 train_time:89370ms step_avg:96.93ms -step:923/1695 train_time:89467ms step_avg:96.93ms -step:924/1695 train_time:89562ms step_avg:96.93ms -step:925/1695 train_time:89658ms step_avg:96.93ms -step:926/1695 train_time:89752ms step_avg:96.92ms -step:927/1695 train_time:89848ms step_avg:96.92ms -step:928/1695 train_time:89944ms step_avg:96.92ms -step:929/1695 train_time:90041ms step_avg:96.92ms -step:930/1695 train_time:90138ms step_avg:96.92ms -step:931/1695 train_time:90235ms step_avg:96.92ms -step:932/1695 train_time:90332ms step_avg:96.92ms -step:933/1695 train_time:90428ms step_avg:96.92ms -step:934/1695 train_time:90523ms step_avg:96.92ms -step:935/1695 train_time:90620ms step_avg:96.92ms -step:936/1695 train_time:90717ms step_avg:96.92ms -step:937/1695 train_time:90813ms step_avg:96.92ms -step:938/1695 train_time:90908ms step_avg:96.92ms -step:939/1695 train_time:91003ms step_avg:96.92ms -step:940/1695 train_time:91099ms step_avg:96.91ms -step:941/1695 train_time:91195ms step_avg:96.91ms -step:942/1695 train_time:91291ms step_avg:96.91ms -step:943/1695 train_time:91386ms step_avg:96.91ms -step:944/1695 train_time:91482ms step_avg:96.91ms -step:945/1695 train_time:91578ms step_avg:96.91ms -step:946/1695 train_time:91675ms step_avg:96.91ms -step:947/1695 train_time:91771ms step_avg:96.91ms -step:948/1695 train_time:91866ms step_avg:96.91ms -step:949/1695 train_time:91962ms step_avg:96.90ms -step:950/1695 train_time:92057ms step_avg:96.90ms -step:951/1695 train_time:92153ms step_avg:96.90ms -step:952/1695 train_time:92249ms step_avg:96.90ms -step:953/1695 train_time:92345ms step_avg:96.90ms -step:954/1695 train_time:92441ms step_avg:96.90ms -step:955/1695 train_time:92538ms step_avg:96.90ms -step:956/1695 train_time:92636ms step_avg:96.90ms -step:957/1695 train_time:92732ms step_avg:96.90ms -step:958/1695 train_time:92827ms step_avg:96.90ms -step:959/1695 train_time:92923ms step_avg:96.90ms -step:960/1695 train_time:93018ms step_avg:96.89ms -step:961/1695 train_time:93114ms step_avg:96.89ms -step:962/1695 train_time:93210ms step_avg:96.89ms -step:963/1695 train_time:93305ms step_avg:96.89ms -step:964/1695 train_time:93401ms step_avg:96.89ms -step:965/1695 train_time:93496ms step_avg:96.89ms -step:966/1695 train_time:93592ms step_avg:96.89ms -step:967/1695 train_time:93688ms step_avg:96.88ms -step:968/1695 train_time:93783ms step_avg:96.88ms -step:969/1695 train_time:93880ms step_avg:96.88ms -step:970/1695 train_time:93976ms step_avg:96.88ms -step:971/1695 train_time:94073ms step_avg:96.88ms -step:972/1695 train_time:94168ms step_avg:96.88ms -step:973/1695 train_time:94263ms step_avg:96.88ms -step:974/1695 train_time:94360ms step_avg:96.88ms -step:975/1695 train_time:94456ms step_avg:96.88ms -step:976/1695 train_time:94553ms step_avg:96.88ms -step:977/1695 train_time:94648ms step_avg:96.88ms -step:978/1695 train_time:94743ms step_avg:96.87ms -step:979/1695 train_time:94840ms step_avg:96.87ms -step:980/1695 train_time:94937ms step_avg:96.87ms -step:981/1695 train_time:95033ms step_avg:96.87ms -step:982/1695 train_time:95129ms step_avg:96.87ms -step:983/1695 train_time:95224ms step_avg:96.87ms -step:984/1695 train_time:95320ms step_avg:96.87ms -step:985/1695 train_time:95416ms step_avg:96.87ms -step:986/1695 train_time:95512ms step_avg:96.87ms -step:987/1695 train_time:95608ms step_avg:96.87ms -step:988/1695 train_time:95703ms step_avg:96.87ms -step:989/1695 train_time:95799ms step_avg:96.86ms -step:990/1695 train_time:95895ms step_avg:96.86ms -step:991/1695 
train_time:95991ms step_avg:96.86ms -step:992/1695 train_time:96087ms step_avg:96.86ms -step:993/1695 train_time:96183ms step_avg:96.86ms -step:994/1695 train_time:96280ms step_avg:96.86ms -step:995/1695 train_time:96375ms step_avg:96.86ms -step:996/1695 train_time:96471ms step_avg:96.86ms -step:997/1695 train_time:96566ms step_avg:96.86ms -step:998/1695 train_time:96663ms step_avg:96.86ms -step:999/1695 train_time:96758ms step_avg:96.85ms -step:1000/1695 train_time:96854ms step_avg:96.85ms -step:1000/1695 val_loss:3.4839 train_time:96948ms step_avg:96.95ms -step:1001/1695 train_time:96972ms step_avg:96.88ms -step:1002/1695 train_time:97055ms step_avg:96.86ms -step:1003/1695 train_time:97153ms step_avg:96.86ms -step:1004/1695 train_time:97250ms step_avg:96.86ms -step:1005/1695 train_time:97345ms step_avg:96.86ms -step:1006/1695 train_time:97440ms step_avg:96.86ms -step:1007/1695 train_time:97534ms step_avg:96.86ms -step:1008/1695 train_time:97629ms step_avg:96.85ms -step:1009/1695 train_time:97724ms step_avg:96.85ms -step:1010/1695 train_time:97819ms step_avg:96.85ms -step:1011/1695 train_time:97914ms step_avg:96.85ms -step:1012/1695 train_time:98013ms step_avg:96.85ms -step:1013/1695 train_time:98112ms step_avg:96.85ms -step:1014/1695 train_time:98211ms step_avg:96.86ms -step:1015/1695 train_time:98309ms step_avg:96.86ms -step:1016/1695 train_time:98406ms step_avg:96.86ms -step:1017/1695 train_time:98501ms step_avg:96.85ms -step:1018/1695 train_time:98596ms step_avg:96.85ms -step:1019/1695 train_time:98690ms step_avg:96.85ms -step:1020/1695 train_time:98786ms step_avg:96.85ms -step:1021/1695 train_time:98883ms step_avg:96.85ms -step:1022/1695 train_time:98979ms step_avg:96.85ms -step:1023/1695 train_time:99077ms step_avg:96.85ms -step:1024/1695 train_time:99173ms step_avg:96.85ms -step:1025/1695 train_time:99270ms step_avg:96.85ms -step:1026/1695 train_time:99366ms step_avg:96.85ms -step:1027/1695 train_time:99462ms step_avg:96.85ms -step:1028/1695 train_time:99556ms step_avg:96.84ms -step:1029/1695 train_time:99651ms step_avg:96.84ms -step:1030/1695 train_time:99746ms step_avg:96.84ms -step:1031/1695 train_time:99842ms step_avg:96.84ms -step:1032/1695 train_time:99938ms step_avg:96.84ms -step:1033/1695 train_time:100035ms step_avg:96.84ms -step:1034/1695 train_time:100131ms step_avg:96.84ms -step:1035/1695 train_time:100228ms step_avg:96.84ms -step:1036/1695 train_time:100552ms step_avg:97.06ms -step:1037/1695 train_time:100740ms step_avg:97.15ms -step:1038/1695 train_time:100833ms step_avg:97.14ms -step:1039/1695 train_time:100928ms step_avg:97.14ms -step:1040/1695 train_time:101024ms step_avg:97.14ms -step:1041/1695 train_time:101118ms step_avg:97.14ms -step:1042/1695 train_time:101212ms step_avg:97.13ms -step:1043/1695 train_time:101307ms step_avg:97.13ms -step:1044/1695 train_time:101402ms step_avg:97.13ms -step:1045/1695 train_time:101497ms step_avg:97.13ms -step:1046/1695 train_time:101598ms step_avg:97.13ms -step:1047/1695 train_time:101696ms step_avg:97.13ms -step:1048/1695 train_time:101794ms step_avg:97.13ms -step:1049/1695 train_time:101890ms step_avg:97.13ms -step:1050/1695 train_time:101986ms step_avg:97.13ms -step:1051/1695 train_time:102083ms step_avg:97.13ms -step:1052/1695 train_time:102178ms step_avg:97.13ms -step:1053/1695 train_time:102272ms step_avg:97.12ms -step:1054/1695 train_time:102367ms step_avg:97.12ms -step:1055/1695 train_time:102462ms step_avg:97.12ms -step:1056/1695 train_time:102560ms step_avg:97.12ms -step:1057/1695 train_time:102656ms step_avg:97.12ms 
-step:1058/1695 train_time:102753ms step_avg:97.12ms -step:1059/1695 train_time:102850ms step_avg:97.12ms -step:1060/1695 train_time:102947ms step_avg:97.12ms -step:1061/1695 train_time:103044ms step_avg:97.12ms -step:1062/1695 train_time:103140ms step_avg:97.12ms -step:1063/1695 train_time:103235ms step_avg:97.12ms -step:1064/1695 train_time:103330ms step_avg:97.11ms -step:1065/1695 train_time:103425ms step_avg:97.11ms -step:1066/1695 train_time:103521ms step_avg:97.11ms -step:1067/1695 train_time:103617ms step_avg:97.11ms -step:1068/1695 train_time:103714ms step_avg:97.11ms -step:1069/1695 train_time:103810ms step_avg:97.11ms -step:1070/1695 train_time:103907ms step_avg:97.11ms -step:1071/1695 train_time:104004ms step_avg:97.11ms -step:1072/1695 train_time:104100ms step_avg:97.11ms -step:1073/1695 train_time:104195ms step_avg:97.11ms -step:1074/1695 train_time:104289ms step_avg:97.10ms -step:1075/1695 train_time:104385ms step_avg:97.10ms -step:1076/1695 train_time:104481ms step_avg:97.10ms -step:1077/1695 train_time:104578ms step_avg:97.10ms -step:1078/1695 train_time:104673ms step_avg:97.10ms -step:1079/1695 train_time:104769ms step_avg:97.10ms -step:1080/1695 train_time:104865ms step_avg:97.10ms -step:1081/1695 train_time:104961ms step_avg:97.10ms -step:1082/1695 train_time:105056ms step_avg:97.09ms -step:1083/1695 train_time:105152ms step_avg:97.09ms -step:1084/1695 train_time:105248ms step_avg:97.09ms -step:1085/1695 train_time:105344ms step_avg:97.09ms -step:1086/1695 train_time:105439ms step_avg:97.09ms -step:1087/1695 train_time:105535ms step_avg:97.09ms -step:1088/1695 train_time:105631ms step_avg:97.09ms -step:1089/1695 train_time:105729ms step_avg:97.09ms -step:1090/1695 train_time:105825ms step_avg:97.09ms -step:1091/1695 train_time:105922ms step_avg:97.09ms -step:1092/1695 train_time:106018ms step_avg:97.09ms -step:1093/1695 train_time:106114ms step_avg:97.08ms -step:1094/1695 train_time:106209ms step_avg:97.08ms -step:1095/1695 train_time:106306ms step_avg:97.08ms -step:1096/1695 train_time:106402ms step_avg:97.08ms -step:1097/1695 train_time:106497ms step_avg:97.08ms -step:1098/1695 train_time:106592ms step_avg:97.08ms -step:1099/1695 train_time:106688ms step_avg:97.08ms -step:1100/1695 train_time:106785ms step_avg:97.08ms -step:1101/1695 train_time:106882ms step_avg:97.08ms -step:1102/1695 train_time:106978ms step_avg:97.08ms -step:1103/1695 train_time:107073ms step_avg:97.07ms -step:1104/1695 train_time:107169ms step_avg:97.07ms -step:1105/1695 train_time:107265ms step_avg:97.07ms -step:1106/1695 train_time:107362ms step_avg:97.07ms -step:1107/1695 train_time:107457ms step_avg:97.07ms -step:1108/1695 train_time:107552ms step_avg:97.07ms -step:1109/1695 train_time:107648ms step_avg:97.07ms -step:1110/1695 train_time:107745ms step_avg:97.07ms -step:1111/1695 train_time:107841ms step_avg:97.07ms -step:1112/1695 train_time:107938ms step_avg:97.07ms -step:1113/1695 train_time:108034ms step_avg:97.07ms -step:1114/1695 train_time:108130ms step_avg:97.06ms -step:1115/1695 train_time:108227ms step_avg:97.06ms -step:1116/1695 train_time:108323ms step_avg:97.06ms -step:1117/1695 train_time:108419ms step_avg:97.06ms -step:1118/1695 train_time:108515ms step_avg:97.06ms -step:1119/1695 train_time:108611ms step_avg:97.06ms -step:1120/1695 train_time:108708ms step_avg:97.06ms -step:1121/1695 train_time:108804ms step_avg:97.06ms -step:1122/1695 train_time:108901ms step_avg:97.06ms -step:1123/1695 train_time:108995ms step_avg:97.06ms -step:1124/1695 train_time:109090ms step_avg:97.06ms 
-step:1125/1695 train_time:109187ms step_avg:97.06ms -step:1125/1695 val_loss:3.4370 train_time:109281ms step_avg:97.14ms -step:1126/1695 train_time:109306ms step_avg:97.08ms -step:1127/1695 train_time:109389ms step_avg:97.06ms -step:1128/1695 train_time:109486ms step_avg:97.06ms -step:1129/1695 train_time:109581ms step_avg:97.06ms -step:1130/1695 train_time:109676ms step_avg:97.06ms -step:1131/1695 train_time:109771ms step_avg:97.06ms -step:1132/1695 train_time:109865ms step_avg:97.05ms -step:1133/1695 train_time:109961ms step_avg:97.05ms -step:1134/1695 train_time:110058ms step_avg:97.05ms -step:1135/1695 train_time:110154ms step_avg:97.05ms -step:1136/1695 train_time:110253ms step_avg:97.05ms -step:1137/1695 train_time:110354ms step_avg:97.06ms -step:1138/1695 train_time:110454ms step_avg:97.06ms -step:1139/1695 train_time:110552ms step_avg:97.06ms -step:1140/1695 train_time:110650ms step_avg:97.06ms -step:1141/1695 train_time:110746ms step_avg:97.06ms -step:1142/1695 train_time:110842ms step_avg:97.06ms -step:1143/1695 train_time:110939ms step_avg:97.06ms -step:1144/1695 train_time:111035ms step_avg:97.06ms -step:1145/1695 train_time:111133ms step_avg:97.06ms -step:1146/1695 train_time:111230ms step_avg:97.06ms -step:1147/1695 train_time:111329ms step_avg:97.06ms -step:1148/1695 train_time:111428ms step_avg:97.06ms -step:1149/1695 train_time:111526ms step_avg:97.06ms -step:1150/1695 train_time:111624ms step_avg:97.06ms -step:1151/1695 train_time:111721ms step_avg:97.06ms -step:1152/1695 train_time:111819ms step_avg:97.06ms -step:1153/1695 train_time:111916ms step_avg:97.06ms -step:1154/1695 train_time:112013ms step_avg:97.06ms -step:1155/1695 train_time:112110ms step_avg:97.06ms -step:1156/1695 train_time:112207ms step_avg:97.07ms -step:1157/1695 train_time:112305ms step_avg:97.07ms -step:1158/1695 train_time:112403ms step_avg:97.07ms -step:1159/1695 train_time:112502ms step_avg:97.07ms -step:1160/1695 train_time:112601ms step_avg:97.07ms -step:1161/1695 train_time:112699ms step_avg:97.07ms -step:1162/1695 train_time:112796ms step_avg:97.07ms -step:1163/1695 train_time:112894ms step_avg:97.07ms -step:1164/1695 train_time:112991ms step_avg:97.07ms -step:1165/1695 train_time:113088ms step_avg:97.07ms -step:1166/1695 train_time:113185ms step_avg:97.07ms -step:1167/1695 train_time:113282ms step_avg:97.07ms -step:1168/1695 train_time:113380ms step_avg:97.07ms -step:1169/1695 train_time:113479ms step_avg:97.07ms -step:1170/1695 train_time:113579ms step_avg:97.08ms -step:1171/1695 train_time:113676ms step_avg:97.08ms -step:1172/1695 train_time:113774ms step_avg:97.08ms -step:1173/1695 train_time:113870ms step_avg:97.08ms -step:1174/1695 train_time:113966ms step_avg:97.08ms -step:1175/1695 train_time:114063ms step_avg:97.07ms -step:1176/1695 train_time:114161ms step_avg:97.08ms -step:1177/1695 train_time:114259ms step_avg:97.08ms -step:1178/1695 train_time:114357ms step_avg:97.08ms -step:1179/1695 train_time:114455ms step_avg:97.08ms -step:1180/1695 train_time:114553ms step_avg:97.08ms -step:1181/1695 train_time:114651ms step_avg:97.08ms -step:1182/1695 train_time:114748ms step_avg:97.08ms -step:1183/1695 train_time:114845ms step_avg:97.08ms -step:1184/1695 train_time:114942ms step_avg:97.08ms -step:1185/1695 train_time:115041ms step_avg:97.08ms -step:1186/1695 train_time:115140ms step_avg:97.08ms -step:1187/1695 train_time:115238ms step_avg:97.08ms -step:1188/1695 train_time:115337ms step_avg:97.09ms -step:1189/1695 train_time:115435ms step_avg:97.09ms -step:1190/1695 train_time:115533ms 
step_avg:97.09ms -step:1191/1695 train_time:115632ms step_avg:97.09ms -step:1192/1695 train_time:115730ms step_avg:97.09ms -step:1193/1695 train_time:115826ms step_avg:97.09ms -step:1194/1695 train_time:115923ms step_avg:97.09ms -step:1195/1695 train_time:116021ms step_avg:97.09ms -step:1196/1695 train_time:116119ms step_avg:97.09ms -step:1197/1695 train_time:116216ms step_avg:97.09ms -step:1198/1695 train_time:116313ms step_avg:97.09ms -step:1199/1695 train_time:116411ms step_avg:97.09ms -step:1200/1695 train_time:116508ms step_avg:97.09ms -step:1201/1695 train_time:116605ms step_avg:97.09ms -step:1202/1695 train_time:116704ms step_avg:97.09ms -step:1203/1695 train_time:116801ms step_avg:97.09ms -step:1204/1695 train_time:116899ms step_avg:97.09ms -step:1205/1695 train_time:116997ms step_avg:97.09ms -step:1206/1695 train_time:117095ms step_avg:97.09ms -step:1207/1695 train_time:117193ms step_avg:97.09ms -step:1208/1695 train_time:117515ms step_avg:97.28ms -step:1209/1695 train_time:117719ms step_avg:97.37ms -step:1210/1695 train_time:117814ms step_avg:97.37ms -step:1211/1695 train_time:117911ms step_avg:97.37ms -step:1212/1695 train_time:118008ms step_avg:97.37ms -step:1213/1695 train_time:118103ms step_avg:97.36ms -step:1214/1695 train_time:118200ms step_avg:97.36ms -step:1215/1695 train_time:118297ms step_avg:97.36ms -step:1216/1695 train_time:118393ms step_avg:97.36ms -step:1217/1695 train_time:118491ms step_avg:97.36ms -step:1218/1695 train_time:118592ms step_avg:97.37ms -step:1219/1695 train_time:118695ms step_avg:97.37ms -step:1220/1695 train_time:118794ms step_avg:97.37ms -step:1221/1695 train_time:118893ms step_avg:97.37ms -step:1222/1695 train_time:118990ms step_avg:97.37ms -step:1223/1695 train_time:119087ms step_avg:97.37ms -step:1224/1695 train_time:119183ms step_avg:97.37ms -step:1225/1695 train_time:119281ms step_avg:97.37ms -step:1226/1695 train_time:119378ms step_avg:97.37ms -step:1227/1695 train_time:119475ms step_avg:97.37ms -step:1228/1695 train_time:119573ms step_avg:97.37ms -step:1229/1695 train_time:119673ms step_avg:97.37ms -step:1230/1695 train_time:119772ms step_avg:97.38ms -step:1231/1695 train_time:119870ms step_avg:97.38ms -step:1232/1695 train_time:119966ms step_avg:97.38ms -step:1233/1695 train_time:120064ms step_avg:97.38ms -step:1234/1695 train_time:120162ms step_avg:97.38ms -step:1235/1695 train_time:120259ms step_avg:97.38ms -step:1236/1695 train_time:120357ms step_avg:97.38ms -step:1237/1695 train_time:120454ms step_avg:97.38ms -step:1238/1695 train_time:120552ms step_avg:97.38ms -step:1239/1695 train_time:120650ms step_avg:97.38ms -step:1240/1695 train_time:120748ms step_avg:97.38ms -step:1241/1695 train_time:120845ms step_avg:97.38ms -step:1242/1695 train_time:120944ms step_avg:97.38ms -step:1243/1695 train_time:121041ms step_avg:97.38ms -step:1244/1695 train_time:121139ms step_avg:97.38ms -step:1245/1695 train_time:121235ms step_avg:97.38ms -step:1246/1695 train_time:121332ms step_avg:97.38ms -step:1247/1695 train_time:121429ms step_avg:97.38ms -step:1248/1695 train_time:121527ms step_avg:97.38ms -step:1249/1695 train_time:121625ms step_avg:97.38ms -step:1250/1695 train_time:121724ms step_avg:97.38ms -step:1250/1695 val_loss:3.3885 train_time:121819ms step_avg:97.46ms -step:1251/1695 train_time:121843ms step_avg:97.40ms -step:1252/1695 train_time:121929ms step_avg:97.39ms -step:1253/1695 train_time:122027ms step_avg:97.39ms -step:1254/1695 train_time:122123ms step_avg:97.39ms -step:1255/1695 train_time:122220ms step_avg:97.39ms -step:1256/1695 
train_time:122317ms step_avg:97.39ms -step:1257/1695 train_time:122414ms step_avg:97.39ms -step:1258/1695 train_time:122510ms step_avg:97.38ms -step:1259/1695 train_time:122606ms step_avg:97.38ms -step:1260/1695 train_time:122703ms step_avg:97.38ms -step:1261/1695 train_time:122803ms step_avg:97.39ms -step:1262/1695 train_time:122905ms step_avg:97.39ms -step:1263/1695 train_time:123004ms step_avg:97.39ms -step:1264/1695 train_time:123101ms step_avg:97.39ms -step:1265/1695 train_time:123199ms step_avg:97.39ms -step:1266/1695 train_time:123296ms step_avg:97.39ms -step:1267/1695 train_time:123393ms step_avg:97.39ms -step:1268/1695 train_time:123490ms step_avg:97.39ms -step:1269/1695 train_time:123587ms step_avg:97.39ms -step:1270/1695 train_time:123686ms step_avg:97.39ms -step:1271/1695 train_time:123781ms step_avg:97.39ms -step:1272/1695 train_time:123881ms step_avg:97.39ms -step:1273/1695 train_time:123982ms step_avg:97.39ms -step:1274/1695 train_time:124082ms step_avg:97.40ms -step:1275/1695 train_time:124180ms step_avg:97.40ms -step:1276/1695 train_time:124278ms step_avg:97.40ms -step:1277/1695 train_time:124376ms step_avg:97.40ms -step:1278/1695 train_time:124474ms step_avg:97.40ms -step:1279/1695 train_time:124570ms step_avg:97.40ms -step:1280/1695 train_time:124666ms step_avg:97.40ms -step:1281/1695 train_time:124763ms step_avg:97.40ms -step:1282/1695 train_time:124862ms step_avg:97.40ms -step:1283/1695 train_time:124961ms step_avg:97.40ms -step:1284/1695 train_time:125061ms step_avg:97.40ms -step:1285/1695 train_time:125160ms step_avg:97.40ms -step:1286/1695 train_time:125258ms step_avg:97.40ms -step:1287/1695 train_time:125356ms step_avg:97.40ms -step:1288/1695 train_time:125454ms step_avg:97.40ms -step:1289/1695 train_time:125552ms step_avg:97.40ms -step:1290/1695 train_time:125649ms step_avg:97.40ms -step:1291/1695 train_time:125746ms step_avg:97.40ms -step:1292/1695 train_time:125843ms step_avg:97.40ms -step:1293/1695 train_time:125941ms step_avg:97.40ms -step:1294/1695 train_time:126039ms step_avg:97.40ms -step:1295/1695 train_time:126138ms step_avg:97.40ms -step:1296/1695 train_time:126237ms step_avg:97.40ms -step:1297/1695 train_time:126335ms step_avg:97.41ms -step:1298/1695 train_time:126432ms step_avg:97.41ms -step:1299/1695 train_time:126529ms step_avg:97.41ms -step:1300/1695 train_time:126626ms step_avg:97.40ms -step:1301/1695 train_time:126723ms step_avg:97.40ms -step:1302/1695 train_time:126821ms step_avg:97.40ms -step:1303/1695 train_time:126919ms step_avg:97.41ms -step:1304/1695 train_time:127018ms step_avg:97.41ms -step:1305/1695 train_time:127116ms step_avg:97.41ms -step:1306/1695 train_time:127214ms step_avg:97.41ms -step:1307/1695 train_time:127312ms step_avg:97.41ms -step:1308/1695 train_time:127410ms step_avg:97.41ms -step:1309/1695 train_time:127508ms step_avg:97.41ms -step:1310/1695 train_time:127605ms step_avg:97.41ms -step:1311/1695 train_time:127703ms step_avg:97.41ms -step:1312/1695 train_time:127800ms step_avg:97.41ms -step:1313/1695 train_time:127898ms step_avg:97.41ms -step:1314/1695 train_time:127996ms step_avg:97.41ms -step:1315/1695 train_time:128093ms step_avg:97.41ms -step:1316/1695 train_time:128191ms step_avg:97.41ms -step:1317/1695 train_time:128289ms step_avg:97.41ms -step:1318/1695 train_time:128386ms step_avg:97.41ms -step:1319/1695 train_time:128483ms step_avg:97.41ms -step:1320/1695 train_time:128582ms step_avg:97.41ms -step:1321/1695 train_time:128679ms step_avg:97.41ms -step:1322/1695 train_time:128777ms step_avg:97.41ms -step:1323/1695 
train_time:128875ms step_avg:97.41ms -step:1324/1695 train_time:128973ms step_avg:97.41ms -step:1325/1695 train_time:129070ms step_avg:97.41ms -step:1326/1695 train_time:129168ms step_avg:97.41ms -step:1327/1695 train_time:129265ms step_avg:97.41ms -step:1328/1695 train_time:129363ms step_avg:97.41ms -step:1329/1695 train_time:129461ms step_avg:97.41ms -step:1330/1695 train_time:129559ms step_avg:97.41ms -step:1331/1695 train_time:129658ms step_avg:97.41ms -step:1332/1695 train_time:129756ms step_avg:97.41ms -step:1333/1695 train_time:129854ms step_avg:97.42ms -step:1334/1695 train_time:129952ms step_avg:97.42ms -step:1335/1695 train_time:130049ms step_avg:97.41ms -step:1336/1695 train_time:130146ms step_avg:97.41ms -step:1337/1695 train_time:130244ms step_avg:97.42ms -step:1338/1695 train_time:130342ms step_avg:97.42ms -step:1339/1695 train_time:130440ms step_avg:97.42ms -step:1340/1695 train_time:130539ms step_avg:97.42ms -step:1341/1695 train_time:130637ms step_avg:97.42ms -step:1342/1695 train_time:130735ms step_avg:97.42ms -step:1343/1695 train_time:130832ms step_avg:97.42ms -step:1344/1695 train_time:130929ms step_avg:97.42ms -step:1345/1695 train_time:131026ms step_avg:97.42ms -step:1346/1695 train_time:131124ms step_avg:97.42ms -step:1347/1695 train_time:131220ms step_avg:97.42ms -step:1348/1695 train_time:131318ms step_avg:97.42ms -step:1349/1695 train_time:131417ms step_avg:97.42ms -step:1350/1695 train_time:131515ms step_avg:97.42ms -step:1351/1695 train_time:131613ms step_avg:97.42ms -step:1352/1695 train_time:131710ms step_avg:97.42ms -step:1353/1695 train_time:131808ms step_avg:97.42ms -step:1354/1695 train_time:131905ms step_avg:97.42ms -step:1355/1695 train_time:132003ms step_avg:97.42ms -step:1356/1695 train_time:132102ms step_avg:97.42ms -step:1357/1695 train_time:132200ms step_avg:97.42ms -step:1358/1695 train_time:132297ms step_avg:97.42ms -step:1359/1695 train_time:132395ms step_avg:97.42ms -step:1360/1695 train_time:132492ms step_avg:97.42ms -step:1361/1695 train_time:132590ms step_avg:97.42ms -step:1362/1695 train_time:132688ms step_avg:97.42ms -step:1363/1695 train_time:132786ms step_avg:97.42ms -step:1364/1695 train_time:132883ms step_avg:97.42ms -step:1365/1695 train_time:132981ms step_avg:97.42ms -step:1366/1695 train_time:133080ms step_avg:97.42ms -step:1367/1695 train_time:133179ms step_avg:97.42ms -step:1368/1695 train_time:133277ms step_avg:97.42ms -step:1369/1695 train_time:133375ms step_avg:97.43ms -step:1370/1695 train_time:133473ms step_avg:97.43ms -step:1371/1695 train_time:133571ms step_avg:97.43ms -step:1372/1695 train_time:133669ms step_avg:97.43ms -step:1373/1695 train_time:133766ms step_avg:97.43ms -step:1374/1695 train_time:133863ms step_avg:97.43ms -step:1375/1695 train_time:133961ms step_avg:97.43ms -step:1375/1695 val_loss:3.3505 train_time:134057ms step_avg:97.50ms -step:1376/1695 train_time:134084ms step_avg:97.45ms -step:1377/1695 train_time:134163ms step_avg:97.43ms -step:1378/1695 train_time:134261ms step_avg:97.43ms -step:1379/1695 train_time:134359ms step_avg:97.43ms -step:1380/1695 train_time:134456ms step_avg:97.43ms -step:1381/1695 train_time:134781ms step_avg:97.60ms -step:1382/1695 train_time:134987ms step_avg:97.68ms -step:1383/1695 train_time:135083ms step_avg:97.67ms -step:1384/1695 train_time:135179ms step_avg:97.67ms -step:1385/1695 train_time:135276ms step_avg:97.67ms -step:1386/1695 train_time:135374ms step_avg:97.67ms -step:1387/1695 train_time:135471ms step_avg:97.67ms -step:1388/1695 train_time:135568ms step_avg:97.67ms 
-step:1389/1695 train_time:135665ms step_avg:97.67ms -step:1390/1695 train_time:135763ms step_avg:97.67ms -step:1391/1695 train_time:135867ms step_avg:97.68ms -step:1392/1695 train_time:135968ms step_avg:97.68ms -step:1393/1695 train_time:136065ms step_avg:97.68ms -step:1394/1695 train_time:136163ms step_avg:97.68ms -step:1395/1695 train_time:136260ms step_avg:97.68ms -step:1396/1695 train_time:136356ms step_avg:97.68ms -step:1397/1695 train_time:136453ms step_avg:97.68ms -step:1398/1695 train_time:136549ms step_avg:97.67ms -step:1399/1695 train_time:136646ms step_avg:97.67ms -step:1400/1695 train_time:136743ms step_avg:97.67ms -step:1401/1695 train_time:136841ms step_avg:97.67ms -step:1402/1695 train_time:136940ms step_avg:97.67ms -step:1403/1695 train_time:137039ms step_avg:97.68ms -step:1404/1695 train_time:137137ms step_avg:97.68ms -step:1405/1695 train_time:137234ms step_avg:97.68ms -step:1406/1695 train_time:137332ms step_avg:97.68ms -step:1407/1695 train_time:137429ms step_avg:97.67ms -step:1408/1695 train_time:137525ms step_avg:97.67ms -step:1409/1695 train_time:137622ms step_avg:97.67ms -step:1410/1695 train_time:137719ms step_avg:97.67ms -step:1411/1695 train_time:137817ms step_avg:97.67ms -step:1412/1695 train_time:137916ms step_avg:97.67ms -step:1413/1695 train_time:138014ms step_avg:97.67ms -step:1414/1695 train_time:138113ms step_avg:97.68ms -step:1415/1695 train_time:138212ms step_avg:97.68ms -step:1416/1695 train_time:138309ms step_avg:97.68ms -step:1417/1695 train_time:138405ms step_avg:97.67ms -step:1418/1695 train_time:138501ms step_avg:97.67ms -step:1419/1695 train_time:138598ms step_avg:97.67ms -step:1420/1695 train_time:138696ms step_avg:97.67ms -step:1421/1695 train_time:138794ms step_avg:97.67ms -step:1422/1695 train_time:138893ms step_avg:97.67ms -step:1423/1695 train_time:138990ms step_avg:97.67ms -step:1424/1695 train_time:139089ms step_avg:97.67ms -step:1425/1695 train_time:139188ms step_avg:97.68ms -step:1426/1695 train_time:139286ms step_avg:97.68ms -step:1427/1695 train_time:139382ms step_avg:97.68ms -step:1428/1695 train_time:139479ms step_avg:97.67ms -step:1429/1695 train_time:139577ms step_avg:97.67ms -step:1430/1695 train_time:139675ms step_avg:97.67ms -step:1431/1695 train_time:139772ms step_avg:97.67ms -step:1432/1695 train_time:139871ms step_avg:97.68ms -step:1433/1695 train_time:139969ms step_avg:97.68ms -step:1434/1695 train_time:140067ms step_avg:97.68ms -step:1435/1695 train_time:140164ms step_avg:97.68ms -step:1436/1695 train_time:140261ms step_avg:97.68ms -step:1437/1695 train_time:140358ms step_avg:97.67ms -step:1438/1695 train_time:140455ms step_avg:97.67ms -step:1439/1695 train_time:140553ms step_avg:97.67ms -step:1440/1695 train_time:140651ms step_avg:97.67ms -step:1441/1695 train_time:140749ms step_avg:97.67ms -step:1442/1695 train_time:140847ms step_avg:97.67ms -step:1443/1695 train_time:140945ms step_avg:97.68ms -step:1444/1695 train_time:141043ms step_avg:97.68ms -step:1445/1695 train_time:141140ms step_avg:97.67ms -step:1446/1695 train_time:141238ms step_avg:97.67ms -step:1447/1695 train_time:141335ms step_avg:97.67ms -step:1448/1695 train_time:141433ms step_avg:97.67ms -step:1449/1695 train_time:141530ms step_avg:97.67ms -step:1450/1695 train_time:141628ms step_avg:97.67ms -step:1451/1695 train_time:141725ms step_avg:97.67ms -step:1452/1695 train_time:141822ms step_avg:97.67ms -step:1453/1695 train_time:141920ms step_avg:97.67ms -step:1454/1695 train_time:142019ms step_avg:97.67ms -step:1455/1695 train_time:142118ms step_avg:97.68ms 
-step:1456/1695 train_time:142215ms step_avg:97.68ms -step:1457/1695 train_time:142313ms step_avg:97.68ms -step:1458/1695 train_time:142410ms step_avg:97.68ms -step:1459/1695 train_time:142507ms step_avg:97.67ms -step:1460/1695 train_time:142604ms step_avg:97.67ms -step:1461/1695 train_time:142701ms step_avg:97.67ms -step:1462/1695 train_time:142798ms step_avg:97.67ms -step:1463/1695 train_time:142896ms step_avg:97.67ms -step:1464/1695 train_time:142994ms step_avg:97.67ms -step:1465/1695 train_time:143092ms step_avg:97.67ms -step:1466/1695 train_time:143190ms step_avg:97.67ms -step:1467/1695 train_time:143287ms step_avg:97.67ms -step:1468/1695 train_time:143385ms step_avg:97.67ms -step:1469/1695 train_time:143482ms step_avg:97.67ms -step:1470/1695 train_time:143579ms step_avg:97.67ms -step:1471/1695 train_time:143678ms step_avg:97.67ms -step:1472/1695 train_time:143776ms step_avg:97.67ms -step:1473/1695 train_time:143873ms step_avg:97.67ms -step:1474/1695 train_time:143972ms step_avg:97.67ms -step:1475/1695 train_time:144070ms step_avg:97.67ms -step:1476/1695 train_time:144167ms step_avg:97.67ms -step:1477/1695 train_time:144265ms step_avg:97.67ms -step:1478/1695 train_time:144361ms step_avg:97.67ms -step:1479/1695 train_time:144459ms step_avg:97.67ms -step:1480/1695 train_time:144555ms step_avg:97.67ms -step:1481/1695 train_time:144653ms step_avg:97.67ms -step:1482/1695 train_time:144751ms step_avg:97.67ms -step:1483/1695 train_time:144849ms step_avg:97.67ms -step:1484/1695 train_time:144947ms step_avg:97.67ms -step:1485/1695 train_time:145044ms step_avg:97.67ms -step:1486/1695 train_time:145141ms step_avg:97.67ms -step:1487/1695 train_time:145239ms step_avg:97.67ms -step:1488/1695 train_time:145337ms step_avg:97.67ms -step:1489/1695 train_time:145435ms step_avg:97.67ms -step:1490/1695 train_time:145533ms step_avg:97.67ms -step:1491/1695 train_time:145630ms step_avg:97.67ms -step:1492/1695 train_time:145728ms step_avg:97.67ms -step:1493/1695 train_time:145825ms step_avg:97.67ms -step:1494/1695 train_time:145922ms step_avg:97.67ms -step:1495/1695 train_time:146020ms step_avg:97.67ms -step:1496/1695 train_time:146118ms step_avg:97.67ms -step:1497/1695 train_time:146216ms step_avg:97.67ms -step:1498/1695 train_time:146315ms step_avg:97.67ms -step:1499/1695 train_time:146413ms step_avg:97.67ms -step:1500/1695 train_time:146511ms step_avg:97.67ms -step:1500/1695 val_loss:3.3179 train_time:146606ms step_avg:97.74ms -step:1501/1695 train_time:146632ms step_avg:97.69ms -step:1502/1695 train_time:146715ms step_avg:97.68ms -step:1503/1695 train_time:146816ms step_avg:97.68ms -step:1504/1695 train_time:146914ms step_avg:97.68ms -step:1505/1695 train_time:147011ms step_avg:97.68ms -step:1506/1695 train_time:147108ms step_avg:97.68ms -step:1507/1695 train_time:147205ms step_avg:97.68ms -step:1508/1695 train_time:147301ms step_avg:97.68ms -step:1509/1695 train_time:147397ms step_avg:97.68ms -step:1510/1695 train_time:147494ms step_avg:97.68ms -step:1511/1695 train_time:147596ms step_avg:97.68ms -step:1512/1695 train_time:147698ms step_avg:97.68ms -step:1513/1695 train_time:147797ms step_avg:97.68ms -step:1514/1695 train_time:147896ms step_avg:97.69ms -step:1515/1695 train_time:147994ms step_avg:97.69ms -step:1516/1695 train_time:148092ms step_avg:97.69ms -step:1517/1695 train_time:148189ms step_avg:97.69ms -step:1518/1695 train_time:148286ms step_avg:97.68ms -step:1519/1695 train_time:148382ms step_avg:97.68ms -step:1520/1695 train_time:148478ms step_avg:97.68ms -step:1521/1695 train_time:148577ms 
step_avg:97.68ms -step:1522/1695 train_time:148678ms step_avg:97.69ms -step:1523/1695 train_time:148777ms step_avg:97.69ms -step:1524/1695 train_time:148876ms step_avg:97.69ms -step:1525/1695 train_time:148975ms step_avg:97.69ms -step:1526/1695 train_time:149074ms step_avg:97.69ms -step:1527/1695 train_time:149172ms step_avg:97.69ms -step:1528/1695 train_time:149269ms step_avg:97.69ms -step:1529/1695 train_time:149367ms step_avg:97.69ms -step:1530/1695 train_time:149465ms step_avg:97.69ms -step:1531/1695 train_time:149562ms step_avg:97.69ms -step:1532/1695 train_time:149660ms step_avg:97.69ms -step:1533/1695 train_time:149758ms step_avg:97.69ms -step:1534/1695 train_time:149856ms step_avg:97.69ms -step:1535/1695 train_time:149955ms step_avg:97.69ms -step:1536/1695 train_time:150054ms step_avg:97.69ms -step:1537/1695 train_time:150152ms step_avg:97.69ms -step:1538/1695 train_time:150249ms step_avg:97.69ms -step:1539/1695 train_time:150346ms step_avg:97.69ms -step:1540/1695 train_time:150443ms step_avg:97.69ms -step:1541/1695 train_time:150540ms step_avg:97.69ms -step:1542/1695 train_time:150638ms step_avg:97.69ms -step:1543/1695 train_time:150736ms step_avg:97.69ms -step:1544/1695 train_time:150835ms step_avg:97.69ms -step:1545/1695 train_time:150932ms step_avg:97.69ms -step:1546/1695 train_time:151030ms step_avg:97.69ms -step:1547/1695 train_time:151128ms step_avg:97.69ms -step:1548/1695 train_time:151225ms step_avg:97.69ms -step:1549/1695 train_time:151322ms step_avg:97.69ms -step:1550/1695 train_time:151419ms step_avg:97.69ms -step:1551/1695 train_time:151517ms step_avg:97.69ms -step:1552/1695 train_time:151888ms step_avg:97.87ms -step:1553/1695 train_time:151963ms step_avg:97.85ms -step:1554/1695 train_time:152058ms step_avg:97.85ms -step:1555/1695 train_time:152155ms step_avg:97.85ms -step:1556/1695 train_time:152252ms step_avg:97.85ms -step:1557/1695 train_time:152349ms step_avg:97.85ms -step:1558/1695 train_time:152445ms step_avg:97.85ms -step:1559/1695 train_time:152541ms step_avg:97.85ms -step:1560/1695 train_time:152638ms step_avg:97.84ms -step:1561/1695 train_time:152735ms step_avg:97.84ms -step:1562/1695 train_time:152839ms step_avg:97.85ms -step:1563/1695 train_time:152939ms step_avg:97.85ms -step:1564/1695 train_time:153039ms step_avg:97.85ms -step:1565/1695 train_time:153136ms step_avg:97.85ms -step:1566/1695 train_time:153235ms step_avg:97.85ms -step:1567/1695 train_time:153331ms step_avg:97.85ms -step:1568/1695 train_time:153428ms step_avg:97.85ms -step:1569/1695 train_time:153525ms step_avg:97.85ms -step:1570/1695 train_time:153622ms step_avg:97.85ms -step:1571/1695 train_time:153720ms step_avg:97.85ms -step:1572/1695 train_time:153819ms step_avg:97.85ms -step:1573/1695 train_time:153917ms step_avg:97.85ms -step:1574/1695 train_time:154017ms step_avg:97.85ms -step:1575/1695 train_time:154115ms step_avg:97.85ms -step:1576/1695 train_time:154215ms step_avg:97.85ms -step:1577/1695 train_time:154312ms step_avg:97.85ms -step:1578/1695 train_time:154410ms step_avg:97.85ms -step:1579/1695 train_time:154507ms step_avg:97.85ms -step:1580/1695 train_time:154604ms step_avg:97.85ms -step:1581/1695 train_time:154701ms step_avg:97.85ms -step:1582/1695 train_time:154798ms step_avg:97.85ms -step:1583/1695 train_time:154897ms step_avg:97.85ms -step:1584/1695 train_time:154996ms step_avg:97.85ms -step:1585/1695 train_time:155096ms step_avg:97.85ms -step:1586/1695 train_time:155195ms step_avg:97.85ms -step:1587/1695 train_time:155293ms step_avg:97.85ms -step:1588/1695 train_time:155391ms 
step_avg:97.85ms -step:1589/1695 train_time:155489ms step_avg:97.85ms -step:1590/1695 train_time:155586ms step_avg:97.85ms -step:1591/1695 train_time:155682ms step_avg:97.85ms -step:1592/1695 train_time:155779ms step_avg:97.85ms -step:1593/1695 train_time:155876ms step_avg:97.85ms -step:1594/1695 train_time:155976ms step_avg:97.85ms -step:1595/1695 train_time:156075ms step_avg:97.85ms -step:1596/1695 train_time:156175ms step_avg:97.85ms -step:1597/1695 train_time:156275ms step_avg:97.86ms -step:1598/1695 train_time:156375ms step_avg:97.86ms -step:1599/1695 train_time:156473ms step_avg:97.86ms -step:1600/1695 train_time:156571ms step_avg:97.86ms -step:1601/1695 train_time:156669ms step_avg:97.86ms -step:1602/1695 train_time:156767ms step_avg:97.86ms -step:1603/1695 train_time:156864ms step_avg:97.86ms -step:1604/1695 train_time:156962ms step_avg:97.86ms -step:1605/1695 train_time:157060ms step_avg:97.86ms -step:1606/1695 train_time:157159ms step_avg:97.86ms -step:1607/1695 train_time:157257ms step_avg:97.86ms -step:1608/1695 train_time:157354ms step_avg:97.86ms -step:1609/1695 train_time:157453ms step_avg:97.86ms -step:1610/1695 train_time:157551ms step_avg:97.86ms -step:1611/1695 train_time:157649ms step_avg:97.86ms -step:1612/1695 train_time:157747ms step_avg:97.86ms -step:1613/1695 train_time:157844ms step_avg:97.86ms -step:1614/1695 train_time:157942ms step_avg:97.86ms -step:1615/1695 train_time:158039ms step_avg:97.86ms -step:1616/1695 train_time:158137ms step_avg:97.86ms -step:1617/1695 train_time:158235ms step_avg:97.86ms -step:1618/1695 train_time:158333ms step_avg:97.86ms -step:1619/1695 train_time:158432ms step_avg:97.86ms -step:1620/1695 train_time:158530ms step_avg:97.86ms -step:1621/1695 train_time:158628ms step_avg:97.86ms -step:1622/1695 train_time:158726ms step_avg:97.86ms -step:1623/1695 train_time:158824ms step_avg:97.86ms -step:1624/1695 train_time:158922ms step_avg:97.86ms -step:1625/1695 train_time:159019ms step_avg:97.86ms -step:1625/1695 val_loss:3.2905 train_time:159114ms step_avg:97.92ms -step:1626/1695 train_time:159139ms step_avg:97.87ms -step:1627/1695 train_time:159222ms step_avg:97.86ms -step:1628/1695 train_time:159320ms step_avg:97.86ms -step:1629/1695 train_time:159418ms step_avg:97.86ms -step:1630/1695 train_time:159515ms step_avg:97.86ms -step:1631/1695 train_time:159612ms step_avg:97.86ms -step:1632/1695 train_time:159709ms step_avg:97.86ms -step:1633/1695 train_time:159806ms step_avg:97.86ms -step:1634/1695 train_time:159902ms step_avg:97.86ms -step:1635/1695 train_time:159999ms step_avg:97.86ms -step:1636/1695 train_time:160099ms step_avg:97.86ms -step:1637/1695 train_time:160200ms step_avg:97.86ms -step:1638/1695 train_time:160300ms step_avg:97.86ms -step:1639/1695 train_time:160398ms step_avg:97.86ms -step:1640/1695 train_time:160495ms step_avg:97.86ms -step:1641/1695 train_time:160592ms step_avg:97.86ms -step:1642/1695 train_time:160689ms step_avg:97.86ms -step:1643/1695 train_time:160786ms step_avg:97.86ms -step:1644/1695 train_time:160882ms step_avg:97.86ms -step:1645/1695 train_time:160980ms step_avg:97.86ms -step:1646/1695 train_time:161078ms step_avg:97.86ms -step:1647/1695 train_time:161177ms step_avg:97.86ms -step:1648/1695 train_time:161277ms step_avg:97.86ms -step:1649/1695 train_time:161378ms step_avg:97.86ms -step:1650/1695 train_time:161476ms step_avg:97.86ms -step:1651/1695 train_time:161573ms step_avg:97.86ms -step:1652/1695 train_time:161670ms step_avg:97.86ms -step:1653/1695 train_time:161767ms step_avg:97.86ms -step:1654/1695 
train_time:161864ms step_avg:97.86ms -step:1655/1695 train_time:161961ms step_avg:97.86ms -step:1656/1695 train_time:162059ms step_avg:97.86ms -step:1657/1695 train_time:162157ms step_avg:97.86ms -step:1658/1695 train_time:162256ms step_avg:97.86ms -step:1659/1695 train_time:162356ms step_avg:97.86ms -step:1660/1695 train_time:162456ms step_avg:97.86ms -step:1661/1695 train_time:162555ms step_avg:97.87ms -step:1662/1695 train_time:162654ms step_avg:97.87ms -step:1663/1695 train_time:162751ms step_avg:97.87ms -step:1664/1695 train_time:162849ms step_avg:97.87ms -step:1665/1695 train_time:162946ms step_avg:97.87ms -step:1666/1695 train_time:163044ms step_avg:97.87ms -step:1667/1695 train_time:163142ms step_avg:97.87ms -step:1668/1695 train_time:163239ms step_avg:97.87ms -step:1669/1695 train_time:163337ms step_avg:97.87ms -step:1670/1695 train_time:163435ms step_avg:97.87ms -step:1671/1695 train_time:163534ms step_avg:97.87ms -step:1672/1695 train_time:163633ms step_avg:97.87ms -step:1673/1695 train_time:163731ms step_avg:97.87ms -step:1674/1695 train_time:163829ms step_avg:97.87ms -step:1675/1695 train_time:163927ms step_avg:97.87ms -step:1676/1695 train_time:164024ms step_avg:97.87ms -step:1677/1695 train_time:164122ms step_avg:97.87ms -step:1678/1695 train_time:164219ms step_avg:97.87ms -step:1679/1695 train_time:164317ms step_avg:97.87ms -step:1680/1695 train_time:164415ms step_avg:97.87ms -step:1681/1695 train_time:164513ms step_avg:97.87ms -step:1682/1695 train_time:164612ms step_avg:97.87ms -step:1683/1695 train_time:164710ms step_avg:97.87ms -step:1684/1695 train_time:164809ms step_avg:97.87ms -step:1685/1695 train_time:164906ms step_avg:97.87ms -step:1686/1695 train_time:165004ms step_avg:97.87ms -step:1687/1695 train_time:165101ms step_avg:97.87ms -step:1688/1695 train_time:165199ms step_avg:97.87ms -step:1689/1695 train_time:165296ms step_avg:97.87ms -step:1690/1695 train_time:165393ms step_avg:97.87ms -step:1691/1695 train_time:165491ms step_avg:97.87ms -step:1692/1695 train_time:165589ms step_avg:97.87ms -step:1693/1695 train_time:165686ms step_avg:97.87ms -step:1694/1695 train_time:165783ms step_avg:97.86ms -step:1695/1695 train_time:165881ms step_avg:97.86ms -step:1695/1695 val_loss:3.2790 train_time:165977ms step_avg:97.92ms -peak memory allocated: 34000 MiB reserved: 49756 MiB
diff --git a/records/082725_FA3/README.md b/records/082725_FA3/README.md
deleted file mode 100644
index a4079630d..000000000
--- a/records/082725_FA3/README.md
+++ /dev/null
@@ -1,147 +0,0 @@
-# New record 08/27/25
-
-This submission includes recent WR changes by
-@ClassicLarry [(08/23/25)](https://github.com/ClassicLarry/modded-nanogpt/tree/master/records/082325_SparseAttnGate)
-and @byronxu99 [(07/18/25)](https://github.com/KellerJordan/modded-nanogpt/pull/109).
-
-The main idea of this record is to use input tensors with `batch_size > 1` throughout our training run.
-Increasing `batch_size` increases GPU utilization and allows us to use shorter input sequences for training.
-However, since Flex Attention is inefficient for `batch_size > 1`, we use [Flash Attention v3](https://github.com/Dao-AILab/flash-attention).
-The official version of this module is incompatible with `torch.compile` and causes graph breaks.
-However, a [recent PR](https://github.com/Dao-AILab/flash-attention/pull/1769) by
-[@guilhermeleobas](https://github.com/guilhermeleobas) addresses this issue.
-
-## Timing and Validation
-
-Validated over 7 runs:
-- In 1695 training steps, this run achieves a loss <3.28 (`p=0.0008`)
-- In 166.10 seconds on average, or <166.25 seconds (`p=0.0024`)
-
-```
-import scipy.stats
-import torch
-import numpy as np
-
-accs = [
-    3.2769, 3.2782, 3.2790, 3.2791, 3.2791, 3.2780, 3.2782
-]
-
-times = [
-    166.247, 166.117, 165.977, 166.135, 166.045, 166.044, 166.157
-]
-
-print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
-# p=0.0008
-
-print('p=%.4f' % scipy.stats.ttest_1samp(times, 166.25, alternative='less').pvalue)
-# p=0.0024
-
-print(f"{np.mean(times):.4f}")
-# 166.1031
-```
-
-In my timing, this is a 2.1 second mean improvement over [PR#117](https://github.com/KellerJordan/modded-nanogpt/pull/117).
-The number of steps can also probably be brought down by 5-15 while still achieving loss <3.28.
-
-I used SXM5 8 x H100 via Prime Intellect for validation compute.
-
-## Further Details
-
-### Motivation
-
-PyTorch's Flex Attention experiences a >10% wallclock slowdown for inputs with `batch_size > 1`.
-As such, previous records would train on very long sequence lengths (`48 * 1024`) with no batch dimension.
-Attention is approximately `O(|seq_len|^2 x |batch_size|)`, so this is theoretically bad,
-but it was mitigated by aggressive block masking:
-attention used a `block_mask` which grew to at most `1664` tokens (and was often shorter due to document masking).
-However, GPU utilization for attention is higher when tokens are distributed along the batch dimension.
-
-Additionally, increasing the batch size allows us to decrease sequence length while maintaining the total
-number of tokens processed per step.
-WR#26 by @ClassicLarry found that validation loss decreases when we train only
-on sequences beginning with the Beginning of Sequence token (``).
-Decreasing the sequence length makes it more likely that `` is present in the attention window.
-In order to generate batches where each sequence begins with ``, I have created the helper class
-`EOSBatchFinder`. This class pre-indexes shards with the location of `` for slight speedups.
-
-### Flash Attention 3
-
-Most of the Hopper-specific benefits in Flash Attention 3 are incorporated into
-PyTorch's Flex Attention already. However, the latter implementation is fastest with `batch_size == 1`.
-Flash Attention 3 is as fast as Flex Attention for one-dimensional input sequences, and gets faster
-as we distribute tokens along the batch dimension.
-I measured a 9% wallclock decrease for FA3 when using an optimal ratio of batch dimension to sequence length
-(`24: 2048`) over a single batch dimension (`1: 49152`) (on a single Hopper H100).
-
-As mentioned above, we need to use an unmerged PR in order to use FA3 with `torch.compile`.
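For illustration, here is a minimal sketch of the kind of compiled FA3 call this unlocks. The shapes are hypothetical (matching the `24: 2048` layout above), and `attend` is an illustrative stand-in rather than the record's actual attention module:

```python
import torch
from flash_attn_interface import flash_attn_func  # from the patched FA3 wheel built below

@torch.compile(fullgraph=True)  # graph-breaks with the official FA3 package
def attend(q, k, v):
    return flash_attn_func(q, k, v, causal=True)

# Hypothetical shapes: 24 sequences of 2048 tokens, 6 heads of dim 128.
q = k = v = torch.randn(24, 2048, 6, 128, device="cuda", dtype=torch.bfloat16)
out = attend(q, k, v)  # depending on the FA3 version, this may be an (out, softmax_lse) tuple
```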
-You can build the wheel like so:
-
-```
-pip install -U pip wheel setuptools ninja numpy packaging psutil
-
-git clone https://github.com/guilhermeleobas/flash-attention.git
-cd flash-attention/hopper
-git switch guilhermeleobas/fa3-compile
-
-export MAX_JOBS=32 # Can increase based on machine
-export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch
-export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only
-export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8
-export FLASH_ATTENTION_DISABLE_HDIM64=TRUE # NanoGPT only uses HDIM = 128
-export FLASH_ATTENTION_DISABLE_HDIM96=TRUE
-export FLASH_ATTENTION_DISABLE_HDIM192=TRUE
-export FLASH_ATTENTION_DISABLE_HDIM256=TRUE
-
-python setup.py bdist_wheel
-```
-
-Additionally, I have uploaded a prebuilt wheel
-[here](https://github.com/varunneal/flash-attention/releases/tag/v3.0.0b1-alpha),
-though it will likely be faster to build it yourself than to download this wheel.
-
-For exact reproduction, I recommend that you install Torch Nightly 2.9.0.dev20250718 and
-install the FA3 wheel afterward:
-
-```
-pip install --pre "torch==2.9.0.dev20250718+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126
-
-# typical path to FA3 wheel
-pip install flash-attention/hopper/dist/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
-```
-
-For me, Torch Nightly 2.9.0.dev20250713 was incompatible with PR#109.
-
-### Attention Masks
-
-Unfortunately, Flash Attention does not support complex block masks like Flex Attention.
-Therefore, `create_blockmasks` was removed. Instead, we are only given the parameter `window_size`,
-which specifies how many tokens to the left each query may attend to.
-
-I kept the existing long-short sliding window block mask pattern, as well as the idea
-that the window sizes should linearly increase over the length of the training run.
-To aid with this, I modified `get_lr(step)` to instead be `get_lr_and_ws(step)`.
-Additionally, I added a hyperparameter `ws_schedule` which specifies what the
-longer window size should be during each portion of the run, and exposed the
-size of the blocks in a window as a hyperparameter `bandwidth=128`.
-
-I have picked a linear schedule with three steps: `ws_schedule=(3, 7, 11)`; see the sketch at the end of this README.
-Currently, `torch.compile` creates a new compilation graph for each step in `ws_schedule`,
-so each graph needs to be warmed up separately. I have increased the number
-of warmup steps from `10` to `60`. The compile time is dominated by the first iteration,
-so warmup will take approximately `len(ws_schedule)` times longer than before.
-
-Removing document masking had a noticeably negative impact on validation loss;
-however, the benefits of a short sequence length counteract this.
-
-### Potential Improvements
-
-- Batch size scheduling: previously, the block mask acted as a proxy for batch size.
-Now batch size can be controlled explicitly and scheduled according to critical-batch-size
-theory. I have added code in `distributed_data_generator` that allows for changing the
-batch size and sequence length yielded after the generator is created.
-- The current block mask window schedule `(3, 7, 11)` can almost certainly be improved upon.
-- Hyperparameter tuning might change with a smaller sequence length: rotary base, validation
-sequence length, learning rates, etc. should be re-tuned. I haven't done that for this run.
-- FA3 has additional features over Flex Attention that may be useful.
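As referenced above, here is a minimal sketch of the sliding-window mechanics from the Attention Masks section. `window_blocks` is an illustrative stand-in, not the exact `get_lr_and_ws` from this record; it only assumes `ws_schedule=(3, 7, 11)`, `bandwidth=128`, and FA3's `(left, right)` window convention:

```python
import torch
from flash_attn_interface import flash_attn_func

WS_SCHEDULE = (3, 7, 11)  # long-window size, in blocks, for each third of the run
BANDWIDTH = 128           # tokens per window block

def window_blocks(step: int, num_steps: int) -> int:
    # Illustrative only: pick the schedule entry for the current portion of the run.
    idx = min(len(WS_SCHEDULE) - 1, step * len(WS_SCHEDULE) // num_steps)
    return WS_SCHEDULE[idx]

def attend(q, k, v, step: int, num_steps: int):
    left = window_blocks(step, num_steps) * BANDWIDTH
    # FA3 takes (batch, seqlen, heads, head_dim) tensors; window_size=(left, 0)
    # lets each query attend to at most `left` tokens before it.
    return flash_attn_func(q, k, v, causal=True, window_size=(left, 0))
```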
\ No newline at end of file diff --git a/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt b/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt deleted file mode 100644 index 7a5ed0b1c..000000000 --- a/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# ----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, 
None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & 
(offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = 
C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
- """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = 
num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# 
----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of 
shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. - tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to recieve new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 
-data_path = os.environ.get("DATA_PATH", ".")
-args.train_files = os.path.join(data_path, args.train_files)
-args.val_files = os.path.join(data_path, args.val_files)
-
-# torchrun sets these env variables
-rank = int(os.environ["RANK"])
-world_size = int(os.environ["WORLD_SIZE"])
-assert 8 % world_size == 0, "world_size must be a divisor of 8"
-grad_accum_steps = 8 // world_size
-assert torch.cuda.is_available()
-device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
-torch.cuda.set_device(device)
-dist.init_process_group(backend="nccl", device_id=device)
-dist.barrier()
-master_process = (rank == 0) # this process will do logging, checkpointing etc.
-
-# begin logging
-logfile = None
-if master_process:
-    run_id = args.run_id
-    os.makedirs("logs", exist_ok=True)
-    logfile = f"logs/{run_id}.txt"
-    print(logfile)
-def print0(s, console=False):
-    if master_process:
-        with open(logfile, "a") as f:
-            if console:
-                print(s)
-            print(s, file=f)
-
-# begin by printing this file (the Python code)
-print0(code)
-print0("="*100)
-# log information about the hardware/software environment this is running on
-print0(f"Running Python {sys.version}")
-print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
-print0(f"Running Triton version {triton.__version__}")
-
-def nvidia_smi():
-    import subprocess # avoid top level import
-    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
-print0(nvidia_smi())
-print0("="*100)
-
-model: nn.Module = GPT(
-    vocab_size=50257,
-    num_layers=12,
-    num_heads=6,
-    model_dim=768,
-    max_seq_len=max(args.train_seq_len, args.val_seq_len)
-).cuda()
-for m in model.modules():
-    if isinstance(m, nn.Embedding):
-        m.bfloat16()
-for param in model.parameters():
-    dist.broadcast(param.detach(), 0)
-
-# collect the parameters to optimize
-hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
-embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-scalar_params = [p for p in model.parameters() if p.ndim < 2]
-head_params = [model.lm_head.weight]
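The four lists above are intended to partition the model exactly: Muon gets the hidden 2D-and-up matrices (including the 3-D qkvo_w stacks and the attention-gate weights), while DistAdam gets the embeddings, the 1-D scalars vector, and the lm_head. A disjointness check one could run at this point (editorial sketch, not part of the original record):

# Editorial sketch: verify the optimizer buckets cover every parameter exactly once.
buckets = [hidden_matrix_params, embed_params, scalar_params, head_params]
bucket_ids = [id(p) for b in buckets for p in b]
assert len(bucket_ids) == len(set(bucket_ids)), "a parameter landed in two buckets"
assert set(bucket_ids) == {id(p) for p in model.parameters()}, "a parameter was missed"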
-# init the optimizer(s)
-# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
-# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
-optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
-optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
-optimizers = [optimizer1, optimizer2]
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["initial_lr"] = group["lr"]
-
-# learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
-    assert 0 <= x < 1
-    lr = 1.0
-    if x >= 1 - args.cooldown_frac:
-        w = (1 - x) / args.cooldown_frac
-        lr = w * 1.0 + (1 - w) * 0.1
-    ws_idx = int(len(args.ws_schedule) * x)
-    return lr, args.ws_schedule[ws_idx]
-
-model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
-
-########################################
-#            Warmup kernels            #
-########################################
-
-# Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
-initial_state = dict(model=copy.deepcopy(model.state_dict()),
-                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
-    model(inputs, targets, ws, ws // 2).backward()
-    for opt in optimizers:
-        opt.step()
-    model.zero_grad(set_to_none=True)
-model.load_state_dict(initial_state["model"])
-for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
-    opt.load_state_dict(opt_state)
-del train_loader, initial_state
-
-########################################
-#        Training and validation       #
-########################################
-
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-training_time_ms = 0
-# start the clock
-torch.cuda.synchronize()
-t0 = time.perf_counter()
-# begin training
-train_steps = args.num_iterations
-for step in range(train_steps + 1):
-    last_step = (step == train_steps)
-    lr, ws = get_lr_and_ws(step)
-
-    # --------------- VALIDATION SECTION -----------------
-    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
-        # stop the clock
-        torch.cuda.synchronize()
-        training_time_ms += 1000 * (time.perf_counter() - t0)
-        model.eval()
-        assert args.val_tokens % (world_size * args.val_seq_len) == 0
-        val_steps = args.val_tokens // (world_size * args.val_seq_len)
-        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
-        val_loss = 0
-        with torch.no_grad():
-            for _ in range(val_steps):
-                inputs, targets = next(val_loader)
-                val_loss += model(inputs, targets, ws, ws // 2)
-        val_loss /= val_steps
-        del val_loader
-        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
-        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
-        model.train()
-        # start the clock again
-        torch.cuda.synchronize()
-        t0 = time.perf_counter()
-
-    if last_step:
-        if master_process and args.save_checkpoint:
-            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
-            os.makedirs(f"logs/{run_id}", exist_ok=True)
-            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
-        # the last step only has the validation loop, so break to avoid training
-        break
-
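To make get_lr_and_ws concrete before the training section below (editorial sketch, not part of the original record): with num_iterations=1695 and cooldown_frac=0.45, the multiplier holds at 1.0 until x = 0.55 (around step 933), then decays linearly toward 0.1, while the window-size index walks through (3, 7, 11) at thirds of training.

# Editorial sketch: evaluate the schedule above at a few steps (constants copied from Hyperparameters).
def lr_ws_demo(step, n=1695, cooldown_frac=0.45, ws_schedule=(3, 7, 11)):
    x = step / (1 + n)
    lr = 1.0
    if x >= 1 - cooldown_frac:
        w = (1 - x) / cooldown_frac
        lr = w * 1.0 + (1 - w) * 0.1
    return lr, ws_schedule[int(len(ws_schedule) * x)]

print(lr_ws_demo(500))   # (1.0, 3): x ~ 0.29, stable phase, shortest window
print(lr_ws_demo(1200))  # (~0.68, 11): x ~ 0.71, mid-cooldown, longest window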
-    # --------------- TRAINING SECTION -----------------
-    for _ in range(grad_accum_steps):
-        inputs, targets = next(train_loader)
-        model(inputs, targets, ws, ws // 2).backward()
-    # set optimization hyperparameters
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lr
-    for group in optimizer2.param_groups:
-        frac = min(step / 300, 1) # momentum warmup for muon
-        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
-    # step the optimizers
-    for opt in optimizers:
-        opt.step()
-    # null the gradients
-    model.zero_grad(set_to_none=True)
-    # logging
-    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
-    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
-
-print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
-       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
-dist.destroy_process_group()
-====================================================================================================
-Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
-Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
-Running Triton version 3.4.0
-Wed Aug 27 03:58:09 2025
-+---------------------------------------------------------------------------------------+
-| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.6     |
-|-----------------------------------------+----------------------+----------------------+
-| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
-| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
-|                                         |                      |               MIG M.
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 34C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 30C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 35C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 33C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms -step:1/1695 train_time:516ms step_avg:515.52ms -step:2/1695 train_time:539ms step_avg:269.65ms -step:3/1695 train_time:612ms step_avg:203.90ms -step:4/1695 train_time:704ms step_avg:175.97ms -step:5/1695 train_time:797ms step_avg:159.42ms -step:6/1695 train_time:891ms step_avg:148.48ms -step:7/1695 train_time:984ms step_avg:140.60ms -step:8/1695 train_time:1078ms step_avg:134.78ms -step:9/1695 train_time:1172ms step_avg:130.23ms -step:10/1695 train_time:1265ms step_avg:126.49ms -step:11/1695 train_time:1359ms step_avg:123.52ms -step:12/1695 train_time:1457ms step_avg:121.44ms -step:13/1695 train_time:1555ms step_avg:119.64ms -step:14/1695 train_time:1650ms step_avg:117.89ms -step:15/1695 train_time:1745ms step_avg:116.30ms -step:16/1695 train_time:1839ms step_avg:114.95ms -step:17/1695 train_time:1933ms step_avg:113.72ms -step:18/1695 train_time:2027ms step_avg:112.62ms -step:19/1695 train_time:2122ms step_avg:111.67ms -step:20/1695 train_time:2216ms step_avg:110.82ms -step:21/1695 train_time:2311ms step_avg:110.03ms -step:22/1695 train_time:2405ms step_avg:109.33ms 
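A note on reading these lines (editorial, not part of the record): train_time is cumulative and step_avg is train_time divided by the step count, so the steady-state step_avg of roughly 96 ms, combined with the batch shape above, implies about four million tokens per second.

# Editorial sketch: throughput implied by a steady-state log line (assumes 24*8 sequences of 2048 tokens per step).
tokens_per_step = 192 * 2048    # 393,216
step_avg_s = 0.096              # ~96 ms/step once warmed up
print(f"{tokens_per_step / step_avg_s:,.0f} tokens/s")  # ~4.1M tokens/s across the 8xH100 node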
-step:23/1695 train_time:2501ms step_avg:108.73ms -step:24/1695 train_time:2597ms step_avg:108.21ms -step:25/1695 train_time:2693ms step_avg:107.73ms -step:26/1695 train_time:2788ms step_avg:107.23ms -step:27/1695 train_time:2882ms step_avg:106.74ms -step:28/1695 train_time:2977ms step_avg:106.33ms -step:29/1695 train_time:3071ms step_avg:105.90ms -step:30/1695 train_time:3165ms step_avg:105.50ms -step:31/1695 train_time:3259ms step_avg:105.14ms -step:32/1695 train_time:3355ms step_avg:104.84ms -step:33/1695 train_time:3449ms step_avg:104.53ms -step:34/1695 train_time:3545ms step_avg:104.26ms -step:35/1695 train_time:3640ms step_avg:104.01ms -step:36/1695 train_time:3736ms step_avg:103.78ms -step:37/1695 train_time:3831ms step_avg:103.53ms -step:38/1695 train_time:3925ms step_avg:103.28ms -step:39/1695 train_time:4019ms step_avg:103.06ms -step:40/1695 train_time:4113ms step_avg:102.84ms -step:41/1695 train_time:4207ms step_avg:102.60ms -step:42/1695 train_time:4301ms step_avg:102.40ms -step:43/1695 train_time:4396ms step_avg:102.23ms -step:44/1695 train_time:4491ms step_avg:102.07ms -step:45/1695 train_time:4587ms step_avg:101.92ms -step:46/1695 train_time:4681ms step_avg:101.76ms -step:47/1695 train_time:4777ms step_avg:101.63ms -step:48/1695 train_time:4873ms step_avg:101.53ms -step:49/1695 train_time:4966ms step_avg:101.35ms -step:50/1695 train_time:5060ms step_avg:101.21ms -step:51/1695 train_time:5155ms step_avg:101.07ms -step:52/1695 train_time:5249ms step_avg:100.94ms -step:53/1695 train_time:5343ms step_avg:100.82ms -step:54/1695 train_time:5439ms step_avg:100.72ms -step:55/1695 train_time:5534ms step_avg:100.62ms -step:56/1695 train_time:5629ms step_avg:100.51ms -step:57/1695 train_time:5723ms step_avg:100.41ms -step:58/1695 train_time:5818ms step_avg:100.32ms -step:59/1695 train_time:5913ms step_avg:100.23ms -step:60/1695 train_time:6007ms step_avg:100.11ms -step:61/1695 train_time:6100ms step_avg:100.00ms -step:62/1695 train_time:6196ms step_avg:99.93ms -step:63/1695 train_time:6290ms step_avg:99.84ms -step:64/1695 train_time:6384ms step_avg:99.75ms -step:65/1695 train_time:6479ms step_avg:99.68ms -step:66/1695 train_time:6573ms step_avg:99.59ms -step:67/1695 train_time:6667ms step_avg:99.51ms -step:68/1695 train_time:6762ms step_avg:99.44ms -step:69/1695 train_time:6856ms step_avg:99.37ms -step:70/1695 train_time:6950ms step_avg:99.29ms -step:71/1695 train_time:7044ms step_avg:99.21ms -step:72/1695 train_time:7139ms step_avg:99.16ms -step:73/1695 train_time:7234ms step_avg:99.10ms -step:74/1695 train_time:7329ms step_avg:99.04ms -step:75/1695 train_time:7423ms step_avg:98.98ms -step:76/1695 train_time:7519ms step_avg:98.94ms -step:77/1695 train_time:7614ms step_avg:98.88ms -step:78/1695 train_time:7709ms step_avg:98.83ms -step:79/1695 train_time:7803ms step_avg:98.77ms -step:80/1695 train_time:7897ms step_avg:98.71ms -step:81/1695 train_time:7991ms step_avg:98.66ms -step:82/1695 train_time:8085ms step_avg:98.60ms -step:83/1695 train_time:8179ms step_avg:98.55ms -step:84/1695 train_time:8274ms step_avg:98.50ms -step:85/1695 train_time:8368ms step_avg:98.45ms -step:86/1695 train_time:8462ms step_avg:98.39ms -step:87/1695 train_time:8558ms step_avg:98.36ms -step:88/1695 train_time:8653ms step_avg:98.33ms -step:89/1695 train_time:8747ms step_avg:98.28ms -step:90/1695 train_time:8841ms step_avg:98.24ms -step:91/1695 train_time:8936ms step_avg:98.20ms -step:92/1695 train_time:9031ms step_avg:98.16ms -step:93/1695 train_time:9125ms step_avg:98.12ms -step:94/1695 train_time:9219ms 
step_avg:98.08ms -step:95/1695 train_time:9313ms step_avg:98.03ms -step:96/1695 train_time:9406ms step_avg:97.98ms -step:97/1695 train_time:9500ms step_avg:97.94ms -step:98/1695 train_time:9596ms step_avg:97.91ms -step:99/1695 train_time:9690ms step_avg:97.88ms -step:100/1695 train_time:9784ms step_avg:97.84ms -step:101/1695 train_time:9878ms step_avg:97.80ms -step:102/1695 train_time:9973ms step_avg:97.78ms -step:103/1695 train_time:10067ms step_avg:97.74ms -step:104/1695 train_time:10162ms step_avg:97.71ms -step:105/1695 train_time:10256ms step_avg:97.67ms -step:106/1695 train_time:10350ms step_avg:97.64ms -step:107/1695 train_time:10444ms step_avg:97.60ms -step:108/1695 train_time:10540ms step_avg:97.59ms -step:109/1695 train_time:10635ms step_avg:97.57ms -step:110/1695 train_time:10731ms step_avg:97.55ms -step:111/1695 train_time:10825ms step_avg:97.52ms -step:112/1695 train_time:10920ms step_avg:97.50ms -step:113/1695 train_time:11014ms step_avg:97.47ms -step:114/1695 train_time:11108ms step_avg:97.44ms -step:115/1695 train_time:11201ms step_avg:97.40ms -step:116/1695 train_time:11297ms step_avg:97.38ms -step:117/1695 train_time:11391ms step_avg:97.36ms -step:118/1695 train_time:11485ms step_avg:97.33ms -step:119/1695 train_time:11580ms step_avg:97.31ms -step:120/1695 train_time:11675ms step_avg:97.29ms -step:121/1695 train_time:11769ms step_avg:97.27ms -step:122/1695 train_time:11863ms step_avg:97.24ms -step:123/1695 train_time:11959ms step_avg:97.22ms -step:124/1695 train_time:12054ms step_avg:97.21ms -step:125/1695 train_time:12148ms step_avg:97.18ms -step:125/1695 val_loss:4.3195 train_time:12239ms step_avg:97.92ms -step:126/1695 train_time:12266ms step_avg:97.35ms -step:127/1695 train_time:12343ms step_avg:97.19ms -step:128/1695 train_time:12442ms step_avg:97.21ms -step:129/1695 train_time:12537ms step_avg:97.19ms -step:130/1695 train_time:12631ms step_avg:97.16ms -step:131/1695 train_time:12725ms step_avg:97.14ms -step:132/1695 train_time:12818ms step_avg:97.11ms -step:133/1695 train_time:12911ms step_avg:97.07ms -step:134/1695 train_time:13005ms step_avg:97.05ms -step:135/1695 train_time:13098ms step_avg:97.02ms -step:136/1695 train_time:13191ms step_avg:97.00ms -step:137/1695 train_time:13287ms step_avg:96.99ms -step:138/1695 train_time:13385ms step_avg:97.00ms -step:139/1695 train_time:13481ms step_avg:96.98ms -step:140/1695 train_time:13575ms step_avg:96.97ms -step:141/1695 train_time:13669ms step_avg:96.94ms -step:142/1695 train_time:13763ms step_avg:96.92ms -step:143/1695 train_time:13856ms step_avg:96.90ms -step:144/1695 train_time:13949ms step_avg:96.87ms -step:145/1695 train_time:14043ms step_avg:96.85ms -step:146/1695 train_time:14136ms step_avg:96.82ms -step:147/1695 train_time:14230ms step_avg:96.80ms -step:148/1695 train_time:14326ms step_avg:96.79ms -step:149/1695 train_time:14422ms step_avg:96.79ms -step:150/1695 train_time:14517ms step_avg:96.78ms -step:151/1695 train_time:14612ms step_avg:96.77ms -step:152/1695 train_time:14707ms step_avg:96.75ms -step:153/1695 train_time:14801ms step_avg:96.74ms -step:154/1695 train_time:14895ms step_avg:96.72ms -step:155/1695 train_time:14989ms step_avg:96.70ms -step:156/1695 train_time:15083ms step_avg:96.69ms -step:157/1695 train_time:15176ms step_avg:96.66ms -step:158/1695 train_time:15270ms step_avg:96.65ms -step:159/1695 train_time:15365ms step_avg:96.63ms -step:160/1695 train_time:15461ms step_avg:96.63ms -step:161/1695 train_time:15556ms step_avg:96.62ms -step:162/1695 train_time:15650ms step_avg:96.60ms -step:163/1695 
train_time:15744ms step_avg:96.59ms -step:164/1695 train_time:15839ms step_avg:96.58ms -step:165/1695 train_time:15933ms step_avg:96.56ms -step:166/1695 train_time:16027ms step_avg:96.55ms -step:167/1695 train_time:16121ms step_avg:96.53ms -step:168/1695 train_time:16215ms step_avg:96.52ms -step:169/1695 train_time:16309ms step_avg:96.50ms -step:170/1695 train_time:16404ms step_avg:96.50ms -step:171/1695 train_time:16499ms step_avg:96.49ms -step:172/1695 train_time:16594ms step_avg:96.47ms -step:173/1695 train_time:16931ms step_avg:97.87ms -step:174/1695 train_time:17020ms step_avg:97.82ms -step:175/1695 train_time:17114ms step_avg:97.79ms -step:176/1695 train_time:17207ms step_avg:97.77ms -step:177/1695 train_time:17301ms step_avg:97.74ms -step:178/1695 train_time:17394ms step_avg:97.72ms -step:179/1695 train_time:17487ms step_avg:97.69ms -step:180/1695 train_time:17581ms step_avg:97.67ms -step:181/1695 train_time:17673ms step_avg:97.64ms -step:182/1695 train_time:17767ms step_avg:97.62ms -step:183/1695 train_time:17864ms step_avg:97.62ms -step:184/1695 train_time:17961ms step_avg:97.61ms -step:185/1695 train_time:18055ms step_avg:97.60ms -step:186/1695 train_time:18149ms step_avg:97.58ms -step:187/1695 train_time:18244ms step_avg:97.56ms -step:188/1695 train_time:18339ms step_avg:97.55ms -step:189/1695 train_time:18432ms step_avg:97.52ms -step:190/1695 train_time:18526ms step_avg:97.50ms -step:191/1695 train_time:18620ms step_avg:97.49ms -step:192/1695 train_time:18713ms step_avg:97.46ms -step:193/1695 train_time:18808ms step_avg:97.45ms -step:194/1695 train_time:18904ms step_avg:97.44ms -step:195/1695 train_time:19000ms step_avg:97.43ms -step:196/1695 train_time:19093ms step_avg:97.42ms -step:197/1695 train_time:19188ms step_avg:97.40ms -step:198/1695 train_time:19283ms step_avg:97.39ms -step:199/1695 train_time:19377ms step_avg:97.37ms -step:200/1695 train_time:19470ms step_avg:97.35ms -step:201/1695 train_time:19564ms step_avg:97.33ms -step:202/1695 train_time:19657ms step_avg:97.31ms -step:203/1695 train_time:19751ms step_avg:97.30ms -step:204/1695 train_time:19847ms step_avg:97.29ms -step:205/1695 train_time:19942ms step_avg:97.28ms -step:206/1695 train_time:20036ms step_avg:97.26ms -step:207/1695 train_time:20129ms step_avg:97.24ms -step:208/1695 train_time:20224ms step_avg:97.23ms -step:209/1695 train_time:20318ms step_avg:97.21ms -step:210/1695 train_time:20411ms step_avg:97.20ms -step:211/1695 train_time:20506ms step_avg:97.18ms -step:212/1695 train_time:20601ms step_avg:97.17ms -step:213/1695 train_time:20693ms step_avg:97.15ms -step:214/1695 train_time:20787ms step_avg:97.14ms -step:215/1695 train_time:20883ms step_avg:97.13ms -step:216/1695 train_time:20977ms step_avg:97.12ms -step:217/1695 train_time:21071ms step_avg:97.10ms -step:218/1695 train_time:21166ms step_avg:97.09ms -step:219/1695 train_time:21261ms step_avg:97.08ms -step:220/1695 train_time:21355ms step_avg:97.07ms -step:221/1695 train_time:21449ms step_avg:97.05ms -step:222/1695 train_time:21544ms step_avg:97.04ms -step:223/1695 train_time:21637ms step_avg:97.03ms -step:224/1695 train_time:21730ms step_avg:97.01ms -step:225/1695 train_time:21825ms step_avg:97.00ms -step:226/1695 train_time:21919ms step_avg:96.99ms -step:227/1695 train_time:22013ms step_avg:96.97ms -step:228/1695 train_time:22108ms step_avg:96.97ms -step:229/1695 train_time:22203ms step_avg:96.96ms -step:230/1695 train_time:22298ms step_avg:96.95ms -step:231/1695 train_time:22391ms step_avg:96.93ms -step:232/1695 train_time:22486ms step_avg:96.92ms 
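Since these records are plain text, a small parser is handy when comparing the val_loss trajectories of different runs (editorial sketch; the path below is hypothetical):

# Editorial sketch: extract (step, val_loss) pairs from a record file (hypothetical filename).
import re

val_re = re.compile(r"step:(\d+)/\d+ val_loss:([\d.]+)")
with open("records/082325_SparseAttnGate/example.txt") as f:  # hypothetical path
    points = [(int(m.group(1)), float(m.group(2))) for m in map(val_re.search, f) if m]
print(points)  # e.g. [(0, 10.8258), (125, 4.3195), (250, 3.9759), ...]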
-step:233/1695 train_time:22580ms step_avg:96.91ms -step:234/1695 train_time:22673ms step_avg:96.89ms -step:235/1695 train_time:22767ms step_avg:96.88ms -step:236/1695 train_time:22862ms step_avg:96.87ms -step:237/1695 train_time:22957ms step_avg:96.86ms -step:238/1695 train_time:23051ms step_avg:96.85ms -step:239/1695 train_time:23146ms step_avg:96.85ms -step:240/1695 train_time:23242ms step_avg:96.84ms -step:241/1695 train_time:23336ms step_avg:96.83ms -step:242/1695 train_time:23430ms step_avg:96.82ms -step:243/1695 train_time:23526ms step_avg:96.81ms -step:244/1695 train_time:23621ms step_avg:96.81ms -step:245/1695 train_time:23715ms step_avg:96.79ms -step:246/1695 train_time:23809ms step_avg:96.78ms -step:247/1695 train_time:23902ms step_avg:96.77ms -step:248/1695 train_time:23996ms step_avg:96.76ms -step:249/1695 train_time:24089ms step_avg:96.74ms -step:250/1695 train_time:24184ms step_avg:96.73ms -step:250/1695 val_loss:3.9759 train_time:24277ms step_avg:97.11ms -step:251/1695 train_time:24301ms step_avg:96.82ms -step:252/1695 train_time:24381ms step_avg:96.75ms -step:253/1695 train_time:24478ms step_avg:96.75ms -step:254/1695 train_time:24574ms step_avg:96.75ms -step:255/1695 train_time:24668ms step_avg:96.74ms -step:256/1695 train_time:24760ms step_avg:96.72ms -step:257/1695 train_time:24854ms step_avg:96.71ms -step:258/1695 train_time:24948ms step_avg:96.70ms -step:259/1695 train_time:25041ms step_avg:96.68ms -step:260/1695 train_time:25134ms step_avg:96.67ms -step:261/1695 train_time:25228ms step_avg:96.66ms -step:262/1695 train_time:25322ms step_avg:96.65ms -step:263/1695 train_time:25418ms step_avg:96.65ms -step:264/1695 train_time:25514ms step_avg:96.64ms -step:265/1695 train_time:25609ms step_avg:96.64ms -step:266/1695 train_time:25703ms step_avg:96.63ms -step:267/1695 train_time:25796ms step_avg:96.62ms -step:268/1695 train_time:25890ms step_avg:96.60ms -step:269/1695 train_time:25983ms step_avg:96.59ms -step:270/1695 train_time:26076ms step_avg:96.58ms -step:271/1695 train_time:26169ms step_avg:96.57ms -step:272/1695 train_time:26264ms step_avg:96.56ms -step:273/1695 train_time:26357ms step_avg:96.55ms -step:274/1695 train_time:26453ms step_avg:96.54ms -step:275/1695 train_time:26549ms step_avg:96.54ms -step:276/1695 train_time:26643ms step_avg:96.53ms -step:277/1695 train_time:26737ms step_avg:96.52ms -step:278/1695 train_time:26831ms step_avg:96.51ms -step:279/1695 train_time:26924ms step_avg:96.50ms -step:280/1695 train_time:27017ms step_avg:96.49ms -step:281/1695 train_time:27111ms step_avg:96.48ms -step:282/1695 train_time:27205ms step_avg:96.47ms -step:283/1695 train_time:27298ms step_avg:96.46ms -step:284/1695 train_time:27393ms step_avg:96.45ms -step:285/1695 train_time:27488ms step_avg:96.45ms -step:286/1695 train_time:27582ms step_avg:96.44ms -step:287/1695 train_time:27677ms step_avg:96.43ms -step:288/1695 train_time:27772ms step_avg:96.43ms -step:289/1695 train_time:27868ms step_avg:96.43ms -step:290/1695 train_time:27961ms step_avg:96.42ms -step:291/1695 train_time:28055ms step_avg:96.41ms -step:292/1695 train_time:28149ms step_avg:96.40ms -step:293/1695 train_time:28242ms step_avg:96.39ms -step:294/1695 train_time:28336ms step_avg:96.38ms -step:295/1695 train_time:28431ms step_avg:96.38ms -step:296/1695 train_time:28526ms step_avg:96.37ms -step:297/1695 train_time:28619ms step_avg:96.36ms -step:298/1695 train_time:28713ms step_avg:96.35ms -step:299/1695 train_time:28809ms step_avg:96.35ms -step:300/1695 train_time:28903ms step_avg:96.34ms -step:301/1695 
train_time:28997ms step_avg:96.33ms -step:302/1695 train_time:29091ms step_avg:96.33ms -step:303/1695 train_time:29186ms step_avg:96.32ms -step:304/1695 train_time:29279ms step_avg:96.31ms -step:305/1695 train_time:29374ms step_avg:96.31ms -step:306/1695 train_time:29469ms step_avg:96.31ms -step:307/1695 train_time:29565ms step_avg:96.30ms -step:308/1695 train_time:29659ms step_avg:96.30ms -step:309/1695 train_time:29753ms step_avg:96.29ms -step:310/1695 train_time:29848ms step_avg:96.29ms -step:311/1695 train_time:29942ms step_avg:96.28ms -step:312/1695 train_time:30036ms step_avg:96.27ms -step:313/1695 train_time:30131ms step_avg:96.26ms -step:314/1695 train_time:30224ms step_avg:96.25ms -step:315/1695 train_time:30317ms step_avg:96.25ms -step:316/1695 train_time:30412ms step_avg:96.24ms -step:317/1695 train_time:30508ms step_avg:96.24ms -step:318/1695 train_time:30602ms step_avg:96.23ms -step:319/1695 train_time:30696ms step_avg:96.23ms -step:320/1695 train_time:30790ms step_avg:96.22ms -step:321/1695 train_time:30884ms step_avg:96.21ms -step:322/1695 train_time:30978ms step_avg:96.21ms -step:323/1695 train_time:31073ms step_avg:96.20ms -step:324/1695 train_time:31168ms step_avg:96.20ms -step:325/1695 train_time:31261ms step_avg:96.19ms -step:326/1695 train_time:31354ms step_avg:96.18ms -step:327/1695 train_time:31448ms step_avg:96.17ms -step:328/1695 train_time:31543ms step_avg:96.17ms -step:329/1695 train_time:31637ms step_avg:96.16ms -step:330/1695 train_time:31733ms step_avg:96.16ms -step:331/1695 train_time:31828ms step_avg:96.16ms -step:332/1695 train_time:31921ms step_avg:96.15ms -step:333/1695 train_time:32015ms step_avg:96.14ms -step:334/1695 train_time:32109ms step_avg:96.14ms -step:335/1695 train_time:32203ms step_avg:96.13ms -step:336/1695 train_time:32296ms step_avg:96.12ms -step:337/1695 train_time:32390ms step_avg:96.11ms -step:338/1695 train_time:32483ms step_avg:96.10ms -step:339/1695 train_time:32577ms step_avg:96.10ms -step:340/1695 train_time:32672ms step_avg:96.09ms -step:341/1695 train_time:32767ms step_avg:96.09ms -step:342/1695 train_time:32861ms step_avg:96.08ms -step:343/1695 train_time:32956ms step_avg:96.08ms -step:344/1695 train_time:33050ms step_avg:96.08ms -step:345/1695 train_time:33388ms step_avg:96.78ms -step:346/1695 train_time:33462ms step_avg:96.71ms -step:347/1695 train_time:33554ms step_avg:96.70ms -step:348/1695 train_time:33647ms step_avg:96.69ms -step:349/1695 train_time:33740ms step_avg:96.68ms -step:350/1695 train_time:33833ms step_avg:96.67ms -step:351/1695 train_time:33927ms step_avg:96.66ms -step:352/1695 train_time:34019ms step_avg:96.65ms -step:353/1695 train_time:34112ms step_avg:96.64ms -step:354/1695 train_time:34205ms step_avg:96.62ms -step:355/1695 train_time:34300ms step_avg:96.62ms -step:356/1695 train_time:34397ms step_avg:96.62ms -step:357/1695 train_time:34494ms step_avg:96.62ms -step:358/1695 train_time:34589ms step_avg:96.62ms -step:359/1695 train_time:34682ms step_avg:96.61ms -step:360/1695 train_time:34775ms step_avg:96.60ms -step:361/1695 train_time:34868ms step_avg:96.59ms -step:362/1695 train_time:34961ms step_avg:96.58ms -step:363/1695 train_time:35055ms step_avg:96.57ms -step:364/1695 train_time:35148ms step_avg:96.56ms -step:365/1695 train_time:35243ms step_avg:96.56ms -step:366/1695 train_time:35338ms step_avg:96.55ms -step:367/1695 train_time:35434ms step_avg:96.55ms -step:368/1695 train_time:35530ms step_avg:96.55ms -step:369/1695 train_time:35624ms step_avg:96.54ms -step:370/1695 train_time:35717ms step_avg:96.53ms 
-step:371/1695 train_time:35811ms step_avg:96.53ms -step:372/1695 train_time:35905ms step_avg:96.52ms -step:373/1695 train_time:35998ms step_avg:96.51ms -step:374/1695 train_time:36092ms step_avg:96.50ms -step:375/1695 train_time:36187ms step_avg:96.50ms -step:375/1695 val_loss:3.8237 train_time:36278ms step_avg:96.74ms -step:376/1695 train_time:36303ms step_avg:96.55ms -step:377/1695 train_time:36381ms step_avg:96.50ms -step:378/1695 train_time:36478ms step_avg:96.50ms -step:379/1695 train_time:36573ms step_avg:96.50ms -step:380/1695 train_time:36667ms step_avg:96.49ms -step:381/1695 train_time:36761ms step_avg:96.48ms -step:382/1695 train_time:36854ms step_avg:96.48ms -step:383/1695 train_time:36948ms step_avg:96.47ms -step:384/1695 train_time:37040ms step_avg:96.46ms -step:385/1695 train_time:37133ms step_avg:96.45ms -step:386/1695 train_time:37227ms step_avg:96.44ms -step:387/1695 train_time:37323ms step_avg:96.44ms -step:388/1695 train_time:37419ms step_avg:96.44ms -step:389/1695 train_time:37514ms step_avg:96.44ms -step:390/1695 train_time:37609ms step_avg:96.43ms -step:391/1695 train_time:37703ms step_avg:96.43ms -step:392/1695 train_time:37795ms step_avg:96.42ms -step:393/1695 train_time:37890ms step_avg:96.41ms -step:394/1695 train_time:37982ms step_avg:96.40ms -step:395/1695 train_time:38075ms step_avg:96.39ms -step:396/1695 train_time:38169ms step_avg:96.39ms -step:397/1695 train_time:38263ms step_avg:96.38ms -step:398/1695 train_time:38358ms step_avg:96.38ms -step:399/1695 train_time:38453ms step_avg:96.37ms -step:400/1695 train_time:38548ms step_avg:96.37ms -step:401/1695 train_time:38642ms step_avg:96.36ms -step:402/1695 train_time:38736ms step_avg:96.36ms -step:403/1695 train_time:38830ms step_avg:96.35ms -step:404/1695 train_time:38924ms step_avg:96.35ms -step:405/1695 train_time:39017ms step_avg:96.34ms -step:406/1695 train_time:39111ms step_avg:96.33ms -step:407/1695 train_time:39205ms step_avg:96.33ms -step:408/1695 train_time:39298ms step_avg:96.32ms -step:409/1695 train_time:39394ms step_avg:96.32ms -step:410/1695 train_time:39489ms step_avg:96.31ms -step:411/1695 train_time:39583ms step_avg:96.31ms -step:412/1695 train_time:39676ms step_avg:96.30ms -step:413/1695 train_time:39771ms step_avg:96.30ms -step:414/1695 train_time:39865ms step_avg:96.29ms -step:415/1695 train_time:39959ms step_avg:96.29ms -step:416/1695 train_time:40053ms step_avg:96.28ms -step:417/1695 train_time:40148ms step_avg:96.28ms -step:418/1695 train_time:40241ms step_avg:96.27ms -step:419/1695 train_time:40335ms step_avg:96.27ms -step:420/1695 train_time:40430ms step_avg:96.26ms -step:421/1695 train_time:40525ms step_avg:96.26ms -step:422/1695 train_time:40618ms step_avg:96.25ms -step:423/1695 train_time:40712ms step_avg:96.25ms -step:424/1695 train_time:40806ms step_avg:96.24ms -step:425/1695 train_time:40900ms step_avg:96.23ms -step:426/1695 train_time:40994ms step_avg:96.23ms -step:427/1695 train_time:41088ms step_avg:96.23ms -step:428/1695 train_time:41181ms step_avg:96.22ms -step:429/1695 train_time:41275ms step_avg:96.21ms -step:430/1695 train_time:41370ms step_avg:96.21ms -step:431/1695 train_time:41465ms step_avg:96.21ms -step:432/1695 train_time:41559ms step_avg:96.20ms -step:433/1695 train_time:41654ms step_avg:96.20ms -step:434/1695 train_time:41748ms step_avg:96.19ms -step:435/1695 train_time:41841ms step_avg:96.19ms -step:436/1695 train_time:41934ms step_avg:96.18ms -step:437/1695 train_time:42028ms step_avg:96.17ms -step:438/1695 train_time:42121ms step_avg:96.17ms -step:439/1695 
[... ~1,250 per-step timing lines for steps 440-1694 elided; step_avg stays between 95.9 ms and 98.0 ms throughout. Validation checkpoints and the final summary: ...]
-step:500/1695 val_loss:3.7206 train_time:48046ms step_avg:96.09ms
-step:625/1695 val_loss:3.6228 train_time:60270ms step_avg:96.43ms
-step:750/1695 val_loss:3.5691 train_time:72633ms step_avg:96.84ms
-step:875/1695 val_loss:3.5252 train_time:84998ms step_avg:97.14ms
-step:1000/1695 val_loss:3.4845 train_time:97007ms step_avg:97.01ms
-step:1125/1695 val_loss:3.4375 train_time:109315ms step_avg:97.17ms
-step:1250/1695 val_loss:3.3901 train_time:121865ms step_avg:97.49ms
-step:1375/1695 val_loss:3.3508 train_time:134124ms step_avg:97.54ms
-step:1500/1695 val_loss:3.3185 train_time:146688ms step_avg:97.79ms
-step:1625/1695 val_loss:3.2909 train_time:159277ms step_avg:98.02ms
-step:1695/1695 train_time:166040ms step_avg:97.96ms
-step:1695/1695 val_loss:3.2791 train_time:166135ms step_avg:98.01ms
-peak memory allocated: 34000 MiB reserved: 49416 MiB
diff --git a/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt b/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt
deleted file mode 100644
index f68fe219a..000000000
--- a/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt
+++ /dev/null
@@ -1,2808 +0,0 @@
-import os
-import sys
-with open(sys.argv[0]) as f:
-    code = f.read() # read the code of this file ASAP, for logging
-import uuid
-import time
-import copy
-import glob
-from dataclasses import dataclass
-from functools import lru_cache
-from pathlib import Path
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-import torch
-torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
-from torch import Tensor, nn
-import torch.nn.functional as F
-import torch.distributed as dist
-#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
-import numpy as np
-import triton
-import triton.language as tl
-from flash_attn_interface import flash_attn_func
-import torch._dynamo as dynamo
-dynamo.config.recompile_limit = 64
-
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
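    # For intuition, a small worked example (hypothetical sizes, not taken from
    # this run): with M = 256 and 128x128 blocks, num_pid_m = num_pid_n = 2, so
    # each matrix in the batch owns 2 * 2 = 4 program ids; pid = 6 then maps to
    # batch_idx = 6 // 4 = 1 with local pid 2, i.e. (pid_m, pid_n) = (1, 0)
    # before tl.swizzle2d regroups blocks for better L2 reuse.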
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - 
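        # Shard layout implied by the header checks above and the read below
        # (a reconstruction for reference only):
        #   bytes [0, 1024)   : 256 int32 header; header[0] = 20240520 (magic),
        #                       header[1] = 1 (version), header[2] = num_tokens
        #   bytes [1024, ...) : num_tokens uint16 token ids, so a complete read
        #                       returns exactly 2 * num_tokens bytes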
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
-                tokens, pos = _load_data_shard(next(file_iter)), 0
-                finder = EOSBatchFinder(tokens, world_size=world_size)
-                continue
-
-            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
-            buf = torch.stack(bufs, dim=0)
-            _inputs = buf[:, :-1]
-            _targets = buf[:, 1:]
-        else:
-            batch_span = num_tokens_global
-            start_pos_local = pos + rank * (batch_size_local * seq_len)
-            end_pos_local = start_pos_local + (batch_size_local * seq_len)
-
-            buf = tokens[start_pos_local: end_pos_local + 1]
-
-            _inputs = buf[:-1].view(batch_size_local, seq_len)
-            _targets = buf[1:].view(batch_size_local, seq_len)
-
-        new_params = yield (
-            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
-            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
-        )
-
-        pos += batch_span
-
-        if new_params is not None:
-            # makes it possible for generator to receive new (batch_size, seq_len) via .send()
-            new_batch_size, new_seq_len = new_params
-            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
-            batch_size = new_batch_size
-            seq_len = new_seq_len
-
-
-# -----------------------------------------------------------------------------
-# int main
-
-@dataclass
-class Hyperparameters:
-    # data
-    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
-    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
-    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
-    train_seq_len: int = 1024 * 2
-    train_batch_size: int = 24 * 8
-    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
-    # optimization
-    num_iterations: int = 1695 # number of iterations to run
-    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
-    # evaluation and logging
-    run_id: str = str(uuid.uuid4())
-    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
-    save_checkpoint: bool = False
-    # attention masking
-    bandwidth: int = 128
-    ws_schedule: tuple = (3, 7, 11)
-
-args = Hyperparameters()
-
-data_path = os.environ.get("DATA_PATH", ".")
-args.train_files = os.path.join(data_path, args.train_files)
-args.val_files = os.path.join(data_path, args.val_files)
-
-# torchrun sets these env variables
-rank = int(os.environ["RANK"])
-world_size = int(os.environ["WORLD_SIZE"])
-assert 8 % world_size == 0, "world_size must be a divisor of 8"
-grad_accum_steps = 8 // world_size
-assert torch.cuda.is_available()
-device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
-torch.cuda.set_device(device)
-dist.init_process_group(backend="nccl", device_id=device)
-dist.barrier()
-master_process = (rank == 0) # this process will do logging, checkpointing etc.
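# A hypothetical standalone check (illustrative, not used by the script) that
# the validation hyperparameters above divide evenly, which the eval loop later
# asserts before computing val_steps = args.val_tokens // (world_size * args.val_seq_len);
# world_size = 8 is assumed here, matching the 8xH100 logs below:
#     val_seq_len = 4 * 64 * 1024 = 262144, and 8 * 262144 = 2097152
#     10485760 % 2097152 == 0  and  10485760 // 2097152 == 5
# i.e. each evaluation runs exactly 5 forward passes of one sequence per rank.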
-
-# begin logging
-logfile = None
-if master_process:
-    run_id = args.run_id
-    os.makedirs("logs", exist_ok=True)
-    logfile = f"logs/{run_id}.txt"
-    print(logfile)
-def print0(s, console=False):
-    if master_process:
-        with open(logfile, "a") as f:
-            if console:
-                print(s)
-            print(s, file=f)
-
-# begin by printing this file (the Python code)
-print0(code)
-print0("="*100)
-# log information about the hardware/software environment this is running on
-print0(f"Running Python {sys.version}")
-print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
-print0(f"Running Triton version {triton.__version__}")
-
-def nvidia_smi():
-    import subprocess # avoid top level import
-    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
-print0(nvidia_smi())
-print0("="*100)
-
-model: nn.Module = GPT(
-    vocab_size=50257,
-    num_layers=12,
-    num_heads=6,
-    model_dim=768,
-    max_seq_len=max(args.train_seq_len, args.val_seq_len)
-).cuda()
-for m in model.modules():
-    if isinstance(m, nn.Embedding):
-        m.bfloat16()
-for param in model.parameters():
-    dist.broadcast(param.detach(), 0)
-
-# collect the parameters to optimize
-hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
-embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-scalar_params = [p for p in model.parameters() if p.ndim < 2]
-head_params = [model.lm_head.weight]
-
-# init the optimizer(s)
-# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
-# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
-optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
-optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
-optimizers = [optimizer1, optimizer2]
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["initial_lr"] = group["lr"]
-
-# learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
-    assert 0 <= x < 1
-    lr = 1.0
-    if x >= 1 - args.cooldown_frac:
-        w = (1 - x) / args.cooldown_frac
-        lr = w * 1.0 + (1 - w) * 0.1
-    ws_idx = int(len(args.ws_schedule) * x)
-    return lr, args.ws_schedule[ws_idx]
-
-model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
-
-########################################
-#            Warmup kernels            #
-########################################
-
-# Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
-initial_state = dict(model=copy.deepcopy(model.state_dict()),
-                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
-    model(inputs, targets, ws, ws // 2).backward()
-    for opt in optimizers:
-        opt.step()
-    model.zero_grad(set_to_none=True)
-model.load_state_dict(initial_state["model"])
-for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
-    opt.load_state_dict(opt_state)
-del train_loader, initial_state
-
-########################################
-#        Training and validation       #
-########################################
-
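# For intuition about get_lr_and_ws above, a hypothetical standalone sketch
# (constants copied from Hyperparameters; the three probe steps are arbitrary):
_sched_schedule, _sched_bandwidth = (3, 7, 11), 128
for _sched_step in (0, 800, 1200):
    _x = _sched_step / (1 + 1695)      # progress through training
    _lr = 1.0                          # stable phase
    if _x >= 1 - 0.45:                 # cooldown phase
        _w = (1 - _x) / 0.45
        _lr = _w * 1.0 + (1 - _w) * 0.1
    _ws = _sched_schedule[int(len(_sched_schedule) * _x)]
    # step 0 -> lr 1.0, ws 3 (long window 3 * 128 = 384 tokens);
    # step 800 -> lr 1.0, ws 7 (896); step 1200 -> lr ~0.68, ws 11 (1408);
    # the short window passed to the model is (ws // 2) * 128 tokens.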
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 04:10:32 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 29C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 31C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 29C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 33C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 32C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms -step:1/1695 train_time:507ms step_avg:507.01ms -step:2/1695 train_time:531ms step_avg:265.61ms -step:3/1695 train_time:603ms step_avg:201.15ms -step:4/1695 train_time:695ms step_avg:173.81ms -step:5/1695 train_time:789ms step_avg:157.80ms -step:6/1695 train_time:883ms step_avg:147.20ms -step:7/1695 train_time:976ms step_avg:139.42ms -step:8/1695 train_time:1069ms step_avg:133.63ms -step:9/1695 train_time:1162ms step_avg:129.16ms -step:10/1695 train_time:1256ms step_avg:125.61ms -step:11/1695 train_time:1350ms step_avg:122.74ms -step:12/1695 train_time:1446ms step_avg:120.48ms -step:13/1695 train_time:1542ms step_avg:118.60ms -step:14/1695 train_time:1637ms step_avg:116.89ms -step:15/1695 train_time:1731ms step_avg:115.41ms -step:16/1695 train_time:1827ms step_avg:114.16ms -step:17/1695 train_time:1921ms step_avg:112.98ms -step:18/1695 train_time:2014ms step_avg:111.90ms -step:19/1695 train_time:2108ms step_avg:110.97ms -step:20/1695 train_time:2204ms step_avg:110.18ms -step:21/1695 train_time:2297ms step_avg:109.37ms -step:22/1695 train_time:2392ms step_avg:108.71ms 
-step:23/1695 train_time:2487ms step_avg:108.14ms -step:24/1695 train_time:2583ms step_avg:107.62ms -step:25/1695 train_time:2677ms step_avg:107.07ms -step:26/1695 train_time:2771ms step_avg:106.58ms -step:27/1695 train_time:2866ms step_avg:106.14ms -step:28/1695 train_time:2960ms step_avg:105.72ms -step:29/1695 train_time:3054ms step_avg:105.31ms -step:30/1695 train_time:3149ms step_avg:104.96ms -step:31/1695 train_time:3243ms step_avg:104.62ms -step:32/1695 train_time:3337ms step_avg:104.27ms -step:33/1695 train_time:3432ms step_avg:104.01ms -step:34/1695 train_time:3529ms step_avg:103.78ms -step:35/1695 train_time:3624ms step_avg:103.56ms -step:36/1695 train_time:3719ms step_avg:103.30ms -step:37/1695 train_time:3814ms step_avg:103.07ms -step:38/1695 train_time:3909ms step_avg:102.86ms -step:39/1695 train_time:4004ms step_avg:102.66ms -step:40/1695 train_time:4098ms step_avg:102.45ms -step:41/1695 train_time:4192ms step_avg:102.24ms -step:42/1695 train_time:4286ms step_avg:102.04ms -step:43/1695 train_time:4380ms step_avg:101.86ms -step:44/1695 train_time:4474ms step_avg:101.68ms -step:45/1695 train_time:4569ms step_avg:101.53ms -step:46/1695 train_time:4664ms step_avg:101.39ms -step:47/1695 train_time:4757ms step_avg:101.22ms -step:48/1695 train_time:4852ms step_avg:101.07ms -step:49/1695 train_time:4947ms step_avg:100.95ms -step:50/1695 train_time:5041ms step_avg:100.82ms -step:51/1695 train_time:5135ms step_avg:100.68ms -step:52/1695 train_time:5229ms step_avg:100.56ms -step:53/1695 train_time:5324ms step_avg:100.45ms -step:54/1695 train_time:5418ms step_avg:100.33ms -step:55/1695 train_time:5513ms step_avg:100.23ms -step:56/1695 train_time:5608ms step_avg:100.14ms -step:57/1695 train_time:5703ms step_avg:100.05ms -step:58/1695 train_time:5796ms step_avg:99.94ms -step:59/1695 train_time:5891ms step_avg:99.85ms -step:60/1695 train_time:5986ms step_avg:99.76ms -step:61/1695 train_time:6079ms step_avg:99.66ms -step:62/1695 train_time:6173ms step_avg:99.56ms -step:63/1695 train_time:6268ms step_avg:99.50ms -step:64/1695 train_time:6364ms step_avg:99.44ms -step:65/1695 train_time:6458ms step_avg:99.36ms -step:66/1695 train_time:6553ms step_avg:99.29ms -step:67/1695 train_time:6649ms step_avg:99.24ms -step:68/1695 train_time:6745ms step_avg:99.19ms -step:69/1695 train_time:6837ms step_avg:99.08ms -step:70/1695 train_time:6931ms step_avg:99.02ms -step:71/1695 train_time:7026ms step_avg:98.95ms -step:72/1695 train_time:7119ms step_avg:98.87ms -step:73/1695 train_time:7213ms step_avg:98.81ms -step:74/1695 train_time:7307ms step_avg:98.75ms -step:75/1695 train_time:7402ms step_avg:98.69ms -step:76/1695 train_time:7495ms step_avg:98.62ms -step:77/1695 train_time:7590ms step_avg:98.58ms -step:78/1695 train_time:7686ms step_avg:98.53ms -step:79/1695 train_time:7779ms step_avg:98.47ms -step:80/1695 train_time:7873ms step_avg:98.41ms -step:81/1695 train_time:7969ms step_avg:98.38ms -step:82/1695 train_time:8064ms step_avg:98.34ms -step:83/1695 train_time:8158ms step_avg:98.28ms -step:84/1695 train_time:8252ms step_avg:98.23ms -step:85/1695 train_time:8346ms step_avg:98.18ms -step:86/1695 train_time:8439ms step_avg:98.13ms -step:87/1695 train_time:8533ms step_avg:98.08ms -step:88/1695 train_time:8629ms step_avg:98.06ms -step:89/1695 train_time:8724ms step_avg:98.02ms -step:90/1695 train_time:8817ms step_avg:97.97ms -step:91/1695 train_time:8911ms step_avg:97.93ms -step:92/1695 train_time:9006ms step_avg:97.89ms -step:93/1695 train_time:9100ms step_avg:97.85ms -step:94/1695 train_time:9194ms 
step_avg:97.80ms -step:95/1695 train_time:9289ms step_avg:97.77ms -step:96/1695 train_time:9383ms step_avg:97.74ms -step:97/1695 train_time:9476ms step_avg:97.69ms -step:98/1695 train_time:9570ms step_avg:97.66ms -step:99/1695 train_time:9665ms step_avg:97.63ms -step:100/1695 train_time:9760ms step_avg:97.60ms -step:101/1695 train_time:9854ms step_avg:97.56ms -step:102/1695 train_time:9949ms step_avg:97.54ms -step:103/1695 train_time:10043ms step_avg:97.51ms -step:104/1695 train_time:10137ms step_avg:97.47ms -step:105/1695 train_time:10231ms step_avg:97.44ms -step:106/1695 train_time:10326ms step_avg:97.41ms -step:107/1695 train_time:10420ms step_avg:97.38ms -step:108/1695 train_time:10513ms step_avg:97.34ms -step:109/1695 train_time:10608ms step_avg:97.32ms -step:110/1695 train_time:10703ms step_avg:97.30ms -step:111/1695 train_time:10797ms step_avg:97.27ms -step:112/1695 train_time:10891ms step_avg:97.24ms -step:113/1695 train_time:10985ms step_avg:97.21ms -step:114/1695 train_time:11078ms step_avg:97.18ms -step:115/1695 train_time:11172ms step_avg:97.15ms -step:116/1695 train_time:11267ms step_avg:97.13ms -step:117/1695 train_time:11362ms step_avg:97.11ms -step:118/1695 train_time:11456ms step_avg:97.08ms -step:119/1695 train_time:11550ms step_avg:97.06ms -step:120/1695 train_time:11644ms step_avg:97.03ms -step:121/1695 train_time:11737ms step_avg:97.00ms -step:122/1695 train_time:11832ms step_avg:96.98ms -step:123/1695 train_time:11927ms step_avg:96.97ms -step:124/1695 train_time:12022ms step_avg:96.95ms -step:125/1695 train_time:12115ms step_avg:96.92ms -step:125/1695 val_loss:4.3104 train_time:12207ms step_avg:97.66ms -step:126/1695 train_time:12232ms step_avg:97.08ms -step:127/1695 train_time:12310ms step_avg:96.93ms -step:128/1695 train_time:12410ms step_avg:96.96ms -step:129/1695 train_time:12506ms step_avg:96.94ms -step:130/1695 train_time:12599ms step_avg:96.92ms -step:131/1695 train_time:12692ms step_avg:96.89ms -step:132/1695 train_time:12785ms step_avg:96.86ms -step:133/1695 train_time:12878ms step_avg:96.83ms -step:134/1695 train_time:12972ms step_avg:96.80ms -step:135/1695 train_time:13064ms step_avg:96.77ms -step:136/1695 train_time:13157ms step_avg:96.75ms -step:137/1695 train_time:13252ms step_avg:96.73ms -step:138/1695 train_time:13348ms step_avg:96.73ms -step:139/1695 train_time:13444ms step_avg:96.72ms -step:140/1695 train_time:13539ms step_avg:96.70ms -step:141/1695 train_time:13633ms step_avg:96.69ms -step:142/1695 train_time:13727ms step_avg:96.67ms -step:143/1695 train_time:13821ms step_avg:96.65ms -step:144/1695 train_time:13916ms step_avg:96.64ms -step:145/1695 train_time:14010ms step_avg:96.62ms -step:146/1695 train_time:14103ms step_avg:96.59ms -step:147/1695 train_time:14197ms step_avg:96.58ms -step:148/1695 train_time:14293ms step_avg:96.57ms -step:149/1695 train_time:14387ms step_avg:96.56ms -step:150/1695 train_time:14482ms step_avg:96.55ms -step:151/1695 train_time:14578ms step_avg:96.55ms -step:152/1695 train_time:14673ms step_avg:96.53ms -step:153/1695 train_time:14766ms step_avg:96.51ms -step:154/1695 train_time:14860ms step_avg:96.49ms -step:155/1695 train_time:14953ms step_avg:96.47ms -step:156/1695 train_time:15046ms step_avg:96.45ms -step:157/1695 train_time:15140ms step_avg:96.43ms -step:158/1695 train_time:15233ms step_avg:96.41ms -step:159/1695 train_time:15327ms step_avg:96.40ms -step:160/1695 train_time:15422ms step_avg:96.39ms -step:161/1695 train_time:15517ms step_avg:96.38ms -step:162/1695 train_time:15613ms step_avg:96.38ms -step:163/1695 
train_time:15706ms step_avg:96.36ms
step:164/1695 train_time:15800ms step_avg:96.34ms
[steps 165-249 elided: per-step train_time 15893ms -> 24063ms, step_avg ~96.2-97.9ms]
step:250/1695 train_time:24157ms step_avg:96.63ms
step:250/1695 val_loss:3.9654 train_time:24251ms step_avg:97.00ms
[steps 251-374 elided: per-step train_time 24276ms -> 36054ms, step_avg ~96.0-96.7ms]
step:375/1695 train_time:36147ms step_avg:96.39ms
step:375/1695 val_loss:3.8151 train_time:36240ms step_avg:96.64ms
[steps 376-499 elided: per-step train_time 36265ms -> 47814ms, step_avg ~95.8-96.5ms]
step:500/1695 train_time:47908ms step_avg:95.82ms
step:500/1695 val_loss:3.7156 train_time:48001ms step_avg:96.00ms
[steps 501-624 elided: per-step train_time 48026ms -> 60019ms, step_avg ~95.8-96.5ms]
step:625/1695 train_time:60116ms step_avg:96.19ms
step:625/1695 val_loss:3.6216 train_time:60211ms step_avg:96.34ms
[steps 626-749 elided: per-step train_time 60235ms -> 72371ms, step_avg ~96.2-96.7ms]
step:750/1695 train_time:72467ms step_avg:96.62ms
step:750/1695 val_loss:3.5657 train_time:72560ms step_avg:96.75ms
[steps 751-874 elided: per-step train_time 72585ms -> 84691ms, step_avg ~96.5-96.9ms]
step:875/1695 train_time:84788ms step_avg:96.90ms
step:875/1695 val_loss:3.5240 train_time:84881ms step_avg:97.01ms
[steps 876-999 elided: per-step train_time 84907ms -> 96708ms, step_avg ~96.8-96.9ms]
step:1000/1695 train_time:96804ms step_avg:96.80ms
step:1000/1695 val_loss:3.4845 train_time:96898ms step_avg:96.90ms
[steps 1001-1124 elided: per-step train_time 96924ms -> 109040ms, step_avg ~96.8-97.1ms]
step:1125/1695 train_time:109136ms step_avg:97.01ms
step:1125/1695 val_loss:3.4374 train_time:109230ms step_avg:97.09ms
[steps 1126-1249 elided: per-step train_time 109255ms -> 121596ms, step_avg ~97.0-97.4ms]
step:1250/1695 train_time:121696ms step_avg:97.36ms
step:1250/1695 val_loss:3.3886 train_time:121792ms step_avg:97.43ms
[steps 1251-1374 elided: per-step train_time 121818ms -> 133829ms, step_avg ~97.4ms]
step:1375/1695 train_time:133926ms step_avg:97.40ms
step:1375/1695 val_loss:3.3508 train_time:134022ms step_avg:97.47ms
[steps 1376-1499 elided: per-step train_time 134049ms -> 146414ms, step_avg ~97.4-97.7ms]
step:1500/1695 train_time:146512ms step_avg:97.67ms
step:1500/1695 val_loss:3.3173 train_time:146608ms step_avg:97.74ms
[steps 1501-1520 elided: per-step train_time 146635ms -> 148472ms, step_avg ~97.7ms]
step:1521/1695 train_time:148570ms
step_avg:97.68ms -step:1522/1695 train_time:148669ms step_avg:97.68ms -step:1523/1695 train_time:148769ms step_avg:97.68ms -step:1524/1695 train_time:148867ms step_avg:97.68ms -step:1525/1695 train_time:148965ms step_avg:97.68ms -step:1526/1695 train_time:149062ms step_avg:97.68ms -step:1527/1695 train_time:149159ms step_avg:97.68ms -step:1528/1695 train_time:149256ms step_avg:97.68ms -step:1529/1695 train_time:149352ms step_avg:97.68ms -step:1530/1695 train_time:149448ms step_avg:97.68ms -step:1531/1695 train_time:149546ms step_avg:97.68ms -step:1532/1695 train_time:149645ms step_avg:97.68ms -step:1533/1695 train_time:149745ms step_avg:97.68ms -step:1534/1695 train_time:149844ms step_avg:97.68ms -step:1535/1695 train_time:149943ms step_avg:97.68ms -step:1536/1695 train_time:150041ms step_avg:97.68ms -step:1537/1695 train_time:150138ms step_avg:97.68ms -step:1538/1695 train_time:150235ms step_avg:97.68ms -step:1539/1695 train_time:150333ms step_avg:97.68ms -step:1540/1695 train_time:150429ms step_avg:97.68ms -step:1541/1695 train_time:150527ms step_avg:97.68ms -step:1542/1695 train_time:150625ms step_avg:97.68ms -step:1543/1695 train_time:150724ms step_avg:97.68ms -step:1544/1695 train_time:150823ms step_avg:97.68ms -step:1545/1695 train_time:150922ms step_avg:97.68ms -step:1546/1695 train_time:151020ms step_avg:97.68ms -step:1547/1695 train_time:151118ms step_avg:97.68ms -step:1548/1695 train_time:151214ms step_avg:97.68ms -step:1549/1695 train_time:151311ms step_avg:97.68ms -step:1550/1695 train_time:151407ms step_avg:97.68ms -step:1551/1695 train_time:151505ms step_avg:97.68ms -step:1552/1695 train_time:151855ms step_avg:97.84ms -step:1553/1695 train_time:152033ms step_avg:97.90ms -step:1554/1695 train_time:152130ms step_avg:97.90ms -step:1555/1695 train_time:152226ms step_avg:97.89ms -step:1556/1695 train_time:152323ms step_avg:97.89ms -step:1557/1695 train_time:152420ms step_avg:97.89ms -step:1558/1695 train_time:152516ms step_avg:97.89ms -step:1559/1695 train_time:152612ms step_avg:97.89ms -step:1560/1695 train_time:152708ms step_avg:97.89ms -step:1561/1695 train_time:152804ms step_avg:97.89ms -step:1562/1695 train_time:152908ms step_avg:97.89ms -step:1563/1695 train_time:153011ms step_avg:97.90ms -step:1564/1695 train_time:153110ms step_avg:97.90ms -step:1565/1695 train_time:153207ms step_avg:97.90ms -step:1566/1695 train_time:153305ms step_avg:97.90ms -step:1567/1695 train_time:153401ms step_avg:97.89ms -step:1568/1695 train_time:153499ms step_avg:97.89ms -step:1569/1695 train_time:153596ms step_avg:97.89ms -step:1570/1695 train_time:153694ms step_avg:97.89ms -step:1571/1695 train_time:153791ms step_avg:97.89ms -step:1572/1695 train_time:153889ms step_avg:97.89ms -step:1573/1695 train_time:153989ms step_avg:97.89ms -step:1574/1695 train_time:154087ms step_avg:97.90ms -step:1575/1695 train_time:154185ms step_avg:97.90ms -step:1576/1695 train_time:154283ms step_avg:97.90ms -step:1577/1695 train_time:154380ms step_avg:97.89ms -step:1578/1695 train_time:154477ms step_avg:97.89ms -step:1579/1695 train_time:154574ms step_avg:97.89ms -step:1580/1695 train_time:154672ms step_avg:97.89ms -step:1581/1695 train_time:154769ms step_avg:97.89ms -step:1582/1695 train_time:154867ms step_avg:97.89ms -step:1583/1695 train_time:154967ms step_avg:97.89ms -step:1584/1695 train_time:155066ms step_avg:97.90ms -step:1585/1695 train_time:155164ms step_avg:97.90ms -step:1586/1695 train_time:155262ms step_avg:97.90ms -step:1587/1695 train_time:155360ms step_avg:97.90ms -step:1588/1695 train_time:155457ms 
step_avg:97.89ms -step:1589/1695 train_time:155554ms step_avg:97.89ms -step:1590/1695 train_time:155651ms step_avg:97.89ms -step:1591/1695 train_time:155749ms step_avg:97.89ms -step:1592/1695 train_time:155846ms step_avg:97.89ms -step:1593/1695 train_time:155945ms step_avg:97.89ms -step:1594/1695 train_time:156044ms step_avg:97.89ms -step:1595/1695 train_time:156143ms step_avg:97.90ms -step:1596/1695 train_time:156240ms step_avg:97.89ms -step:1597/1695 train_time:156338ms step_avg:97.89ms -step:1598/1695 train_time:156435ms step_avg:97.89ms -step:1599/1695 train_time:156532ms step_avg:97.89ms -step:1600/1695 train_time:156629ms step_avg:97.89ms -step:1601/1695 train_time:156727ms step_avg:97.89ms -step:1602/1695 train_time:156825ms step_avg:97.89ms -step:1603/1695 train_time:156923ms step_avg:97.89ms -step:1604/1695 train_time:157022ms step_avg:97.89ms -step:1605/1695 train_time:157121ms step_avg:97.89ms -step:1606/1695 train_time:157220ms step_avg:97.90ms -step:1607/1695 train_time:157318ms step_avg:97.90ms -step:1608/1695 train_time:157416ms step_avg:97.90ms -step:1609/1695 train_time:157515ms step_avg:97.90ms -step:1610/1695 train_time:157611ms step_avg:97.90ms -step:1611/1695 train_time:157709ms step_avg:97.90ms -step:1612/1695 train_time:157806ms step_avg:97.89ms -step:1613/1695 train_time:157904ms step_avg:97.89ms -step:1614/1695 train_time:158003ms step_avg:97.90ms -step:1615/1695 train_time:158102ms step_avg:97.90ms -step:1616/1695 train_time:158200ms step_avg:97.90ms -step:1617/1695 train_time:158299ms step_avg:97.90ms -step:1618/1695 train_time:158398ms step_avg:97.90ms -step:1619/1695 train_time:158495ms step_avg:97.90ms -step:1620/1695 train_time:158593ms step_avg:97.90ms -step:1621/1695 train_time:158691ms step_avg:97.90ms -step:1622/1695 train_time:158787ms step_avg:97.90ms -step:1623/1695 train_time:158885ms step_avg:97.90ms -step:1624/1695 train_time:158983ms step_avg:97.90ms -step:1625/1695 train_time:159082ms step_avg:97.90ms -step:1625/1695 val_loss:3.2899 train_time:159178ms step_avg:97.96ms -step:1626/1695 train_time:159206ms step_avg:97.91ms -step:1627/1695 train_time:159288ms step_avg:97.90ms -step:1628/1695 train_time:159387ms step_avg:97.90ms -step:1629/1695 train_time:159485ms step_avg:97.90ms -step:1630/1695 train_time:159582ms step_avg:97.90ms -step:1631/1695 train_time:159679ms step_avg:97.90ms -step:1632/1695 train_time:159776ms step_avg:97.90ms -step:1633/1695 train_time:159872ms step_avg:97.90ms -step:1634/1695 train_time:159968ms step_avg:97.90ms -step:1635/1695 train_time:160065ms step_avg:97.90ms -step:1636/1695 train_time:160166ms step_avg:97.90ms -step:1637/1695 train_time:160267ms step_avg:97.90ms -step:1638/1695 train_time:160367ms step_avg:97.90ms -step:1639/1695 train_time:160466ms step_avg:97.90ms -step:1640/1695 train_time:160563ms step_avg:97.90ms -step:1641/1695 train_time:160661ms step_avg:97.90ms -step:1642/1695 train_time:160759ms step_avg:97.90ms -step:1643/1695 train_time:160856ms step_avg:97.90ms -step:1644/1695 train_time:160953ms step_avg:97.90ms -step:1645/1695 train_time:161050ms step_avg:97.90ms -step:1646/1695 train_time:161148ms step_avg:97.90ms -step:1647/1695 train_time:161247ms step_avg:97.90ms -step:1648/1695 train_time:161346ms step_avg:97.90ms -step:1649/1695 train_time:161445ms step_avg:97.90ms -step:1650/1695 train_time:161544ms step_avg:97.91ms -step:1651/1695 train_time:161641ms step_avg:97.90ms -step:1652/1695 train_time:161739ms step_avg:97.91ms -step:1653/1695 train_time:161838ms step_avg:97.91ms -step:1654/1695 
train_time:161935ms step_avg:97.91ms -step:1655/1695 train_time:162032ms step_avg:97.90ms -step:1656/1695 train_time:162128ms step_avg:97.90ms -step:1657/1695 train_time:162226ms step_avg:97.90ms -step:1658/1695 train_time:162326ms step_avg:97.90ms -step:1659/1695 train_time:162424ms step_avg:97.90ms -step:1660/1695 train_time:162522ms step_avg:97.90ms -step:1661/1695 train_time:162620ms step_avg:97.90ms -step:1662/1695 train_time:162717ms step_avg:97.90ms -step:1663/1695 train_time:162815ms step_avg:97.90ms -step:1664/1695 train_time:162913ms step_avg:97.90ms -step:1665/1695 train_time:163010ms step_avg:97.90ms -step:1666/1695 train_time:163107ms step_avg:97.90ms -step:1667/1695 train_time:163205ms step_avg:97.90ms -step:1668/1695 train_time:163305ms step_avg:97.90ms -step:1669/1695 train_time:163404ms step_avg:97.91ms -step:1670/1695 train_time:163503ms step_avg:97.91ms -step:1671/1695 train_time:163601ms step_avg:97.91ms -step:1672/1695 train_time:163699ms step_avg:97.91ms -step:1673/1695 train_time:163797ms step_avg:97.91ms -step:1674/1695 train_time:163896ms step_avg:97.91ms -step:1675/1695 train_time:163994ms step_avg:97.91ms -step:1676/1695 train_time:164090ms step_avg:97.91ms -step:1677/1695 train_time:164187ms step_avg:97.91ms -step:1678/1695 train_time:164285ms step_avg:97.91ms -step:1679/1695 train_time:164384ms step_avg:97.91ms -step:1680/1695 train_time:164482ms step_avg:97.91ms -step:1681/1695 train_time:164581ms step_avg:97.91ms -step:1682/1695 train_time:164679ms step_avg:97.91ms -step:1683/1695 train_time:164777ms step_avg:97.91ms -step:1684/1695 train_time:164875ms step_avg:97.91ms -step:1685/1695 train_time:164972ms step_avg:97.91ms -step:1686/1695 train_time:165068ms step_avg:97.91ms -step:1687/1695 train_time:165167ms step_avg:97.91ms -step:1688/1695 train_time:165265ms step_avg:97.91ms -step:1689/1695 train_time:165363ms step_avg:97.91ms -step:1690/1695 train_time:165460ms step_avg:97.91ms -step:1691/1695 train_time:165558ms step_avg:97.91ms -step:1692/1695 train_time:165656ms step_avg:97.91ms -step:1693/1695 train_time:165754ms step_avg:97.91ms -step:1694/1695 train_time:165851ms step_avg:97.91ms -step:1695/1695 train_time:165949ms step_avg:97.90ms -step:1695/1695 val_loss:3.2780 train_time:166044ms step_avg:97.96ms -peak memory allocated: 34000 MiB reserved: 49636 MiB diff --git a/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt b/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt deleted file mode 100644 index 789ccc0d1..000000000 --- a/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
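
# A minimal sketch of the scale-then-cast quantization that mm_op above relies
# on: tensors are divided by a scale, cast to float8_e4m3fn, and torch._scaled_mm
# multiplies the scales back in. The helper below is illustrative only (the
# scale value is arbitrary, not the record's tuned constants); it shows the
# round trip that bounds the quantization error.
def fp8_roundtrip_sketch():
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    x_s = 4.0  # scale chosen so x / x_s stays inside e4m3's ~448 dynamic range
    x_f8 = x.div(x_s).to(torch.float8_e4m3fn)  # quantize
    x_hat = x_f8.to(torch.bfloat16).mul(x_s)   # dequantize
    return (x - x_hat).abs().max()             # small when the scale fits the data
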
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
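
# The remainder of this loop applies a standard AdamW update to each parameter
# slice; written out (a sketch of the math, with the bias corrections folded
# into the step size exactly as in the code below):
#     p <- p * (1 - lr * wd)
#     m <- beta1 * m + (1 - beta1) * g
#     v <- beta2 * v + (1 - beta2) * g * g
#     p <- p - lr * sqrt(1 - beta2**t) / (1 - beta1**t) * m / (sqrt(v) + eps)
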
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - 
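
# The sparse attention gate in CausalSelfAttention above is the change this
# record tests: each head's output is scaled by a sigmoid gate computed from
# just the first attn_gate_dim=12 channels of the block input, so a head can
# turn itself into a context-dependent no-op. Zero-initialized gate weights
# start every gate at sigmoid(0) = 0.5. A standalone sketch (names are
# illustrative, not the record's module):
def attn_gate_sketch(x: Tensor, y: Tensor, gate_w: Tensor, gate_dim: int = 12) -> Tensor:
    # x: (B, T, dim) block input; y: (B, T, H, head_dim) attention output
    # gate_w: (H, gate_dim), zero-initialized
    gate = torch.sigmoid(F.linear(x[..., :gate_dim], gate_w))  # (B, T, H), in (0, 1)
    return y * gate.unsqueeze(-1)  # per-head scaling of the attention output
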
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
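
# (The reload below rebuilds the finder on the fresh shard and retries.)
# The core of EOSBatchFinder above is one np.searchsorted per sequence: find
# the first EOS at or past start + seq_len, and begin the next sequence just
# after it, so every sequence starts at a document boundary. A toy trace
# (values made up):
#     eos_idx = np.array([5, 40, 90, 200])      # EOS positions in the shard
#     start = eos_idx[0] + 1                    # first doc starts at 6
#     j = np.searchsorted(eos_idx, start + 64)  # first EOS >= 70 -> index 2
#     next_start = eos_idx[j] + 1               # next sequence starts at 91
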
                tokens, pos = _load_data_shard(next(file_iter)), 0
                finder = EOSBatchFinder(tokens, world_size=world_size)
                continue

            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
            buf = torch.stack(bufs, dim=0)
            _inputs = buf[:, :-1]
            _targets = buf[:, 1:]
        else:
            batch_span = num_tokens_global
            start_pos_local = pos + rank * (batch_size_local * seq_len)
            end_pos_local = start_pos_local + (batch_size_local * seq_len)

            buf = tokens[start_pos_local: end_pos_local + 1]

            _inputs = buf[:-1].view(batch_size_local, seq_len)
            _targets = buf[1:].view(batch_size_local, seq_len)

        new_params = yield (
            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
        )

        pos += batch_span

        if new_params is not None:
            # makes it possible for generator to receive new (batch_size, seq_len) via .send()
            new_batch_size, new_seq_len = new_params
            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
            batch_size = new_batch_size
            seq_len = new_seq_len


# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    # data
    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    train_seq_len: int = 1024 * 2
    train_batch_size: int = 24 * 8
    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
    # optimization
    num_iterations: int = 1695 # number of iterations to run
    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
    # evaluation and logging
    run_id: str = str(uuid.uuid4())
    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
    save_checkpoint: bool = False
    # attention masking
    bandwidth: int = 128
    ws_schedule: tuple = (3, 7, 11)

args = Hyperparameters()

data_path = os.environ.get("DATA_PATH", ".")
args.train_files = os.path.join(data_path, args.train_files)
args.val_files = os.path.join(data_path, args.val_files)

# torchrun sets these env variables
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert 8 % world_size == 0, "world_size must be a divisor of 8"
grad_accum_steps = 8 // world_size
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = (rank == 0) # this process will do logging, checkpointing etc.

# begin logging
logfile = None
if master_process:
    run_id = args.run_id
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id}.txt"
    print(logfile)
def print0(s, console=False):
    if master_process:
        with open(logfile, "a") as f:
            if console:
                print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0("="*100)
# log information about the hardware/software environment this is running on
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
print0(f"Running Triton version {triton.__version__}")

def nvidia_smi():
    import subprocess # avoid top level import
    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("="*100)

model: nn.Module = GPT(
    vocab_size=50257,
    num_layers=12,
    num_heads=6,
    model_dim=768,
    max_seq_len=max(args.train_seq_len, args.val_seq_len)
).cuda()
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

# collect the parameters to optimize
hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]

# init the optimizer(s)
# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
optimizers = [optimizer1, optimizer2]
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]

# learning rate schedule: stable then decay
def get_lr_and_ws(step: int):
    x = step / (1 + args.num_iterations) # progress in training
    assert 0 <= x < 1
    lr = 1.0
    if x >= 1 - args.cooldown_frac:
        w = (1 - x) / args.cooldown_frac
        lr = w * 1.0 + (1 - w) * 0.1
    ws_idx = int(len(args.ws_schedule) * x)
    return lr, args.ws_schedule[ws_idx]

model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)

########################################
#            Warmup kernels            #
########################################

# Warmup the training kernels, then re-initialize the state so we aren't cheating
warmup_steps = 60
initial_state = dict(model=copy.deepcopy(model.state_dict()),
                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
for step in range(warmup_steps):
    inputs, targets = next(train_loader)
    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
    model(inputs, targets, ws, ws // 2).backward()
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
    opt.load_state_dict(opt_state)
del train_loader, initial_state

########################################
#        Training and validation       #
########################################

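# A few sampled values of get_lr_and_ws above (hypothetical snippet, same math):
# the LR multiplier holds at 1.0 for the first 55% of training, then decays
# linearly toward 0.1, while the window-size factor steps through
# ws_schedule = (3, 7, 11) in equal thirds of training.
#     for s in (0, 800, 1200, 1694):
#         x = s / (1 + 1695)
#         lr = 1.0 if x < 0.55 else (1 - x) / 0.45 + (1 - (1 - x) / 0.45) * 0.1
#         print(s, round(lr, 3), (3, 7, 11)[int(3 * x)])
#     # -> 0 1.0 3 | 800 1.0 7 | 1200 0.685 11 | 1694 0.102 11
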
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 03:47:47 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 32C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 36C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 37C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 32C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 38C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 36C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 33C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms -step:1/1695 train_time:520ms step_avg:519.82ms -step:2/1695 train_time:546ms step_avg:272.89ms -step:3/1695 train_time:615ms step_avg:204.90ms -step:4/1695 train_time:707ms step_avg:176.66ms -step:5/1695 train_time:800ms step_avg:159.97ms -step:6/1695 train_time:892ms step_avg:148.74ms -step:7/1695 train_time:986ms step_avg:140.81ms -step:8/1695 train_time:1079ms step_avg:134.91ms -step:9/1695 train_time:1173ms step_avg:130.29ms -step:10/1695 train_time:1266ms step_avg:126.65ms -step:11/1695 train_time:1360ms step_avg:123.62ms -step:12/1695 train_time:1455ms step_avg:121.29ms -step:13/1695 train_time:1553ms step_avg:119.48ms -step:14/1695 train_time:1649ms step_avg:117.78ms -step:15/1695 train_time:1745ms step_avg:116.36ms -step:16/1695 train_time:1840ms step_avg:115.01ms -step:17/1695 train_time:1933ms step_avg:113.72ms -step:18/1695 train_time:2027ms step_avg:112.62ms -step:19/1695 train_time:2121ms step_avg:111.64ms -step:20/1695 train_time:2215ms step_avg:110.74ms -step:21/1695 train_time:2309ms step_avg:109.95ms -step:22/1695 train_time:2405ms step_avg:109.30ms 
-step:23/1695 train_time:2501ms step_avg:108.76ms -step:24/1695 train_time:2596ms step_avg:108.18ms -step:25/1695 train_time:2692ms step_avg:107.66ms -step:26/1695 train_time:2787ms step_avg:107.19ms -step:27/1695 train_time:2882ms step_avg:106.75ms -step:28/1695 train_time:2976ms step_avg:106.30ms -step:29/1695 train_time:3070ms step_avg:105.87ms -step:30/1695 train_time:3164ms step_avg:105.48ms -step:31/1695 train_time:3258ms step_avg:105.10ms -step:32/1695 train_time:3352ms step_avg:104.76ms -step:33/1695 train_time:3448ms step_avg:104.49ms -step:34/1695 train_time:3546ms step_avg:104.28ms -step:35/1695 train_time:3641ms step_avg:104.04ms -step:36/1695 train_time:3735ms step_avg:103.76ms -step:37/1695 train_time:3830ms step_avg:103.51ms -step:38/1695 train_time:3925ms step_avg:103.28ms -step:39/1695 train_time:4019ms step_avg:103.05ms -step:40/1695 train_time:4113ms step_avg:102.83ms -step:41/1695 train_time:4208ms step_avg:102.63ms -step:42/1695 train_time:4302ms step_avg:102.43ms -step:43/1695 train_time:4396ms step_avg:102.22ms -step:44/1695 train_time:4490ms step_avg:102.06ms -step:45/1695 train_time:4587ms step_avg:101.92ms -step:46/1695 train_time:4682ms step_avg:101.78ms -step:47/1695 train_time:4776ms step_avg:101.62ms -step:48/1695 train_time:4871ms step_avg:101.47ms -step:49/1695 train_time:4965ms step_avg:101.33ms -step:50/1695 train_time:5061ms step_avg:101.22ms -step:51/1695 train_time:5154ms step_avg:101.05ms -step:52/1695 train_time:5248ms step_avg:100.93ms -step:53/1695 train_time:5343ms step_avg:100.82ms -step:54/1695 train_time:5439ms step_avg:100.73ms -step:55/1695 train_time:5533ms step_avg:100.59ms -step:56/1695 train_time:5628ms step_avg:100.50ms -step:57/1695 train_time:5723ms step_avg:100.41ms -step:58/1695 train_time:5819ms step_avg:100.33ms -step:59/1695 train_time:5913ms step_avg:100.22ms -step:60/1695 train_time:6008ms step_avg:100.13ms -step:61/1695 train_time:6101ms step_avg:100.02ms -step:62/1695 train_time:6195ms step_avg:99.91ms -step:63/1695 train_time:6289ms step_avg:99.82ms -step:64/1695 train_time:6384ms step_avg:99.75ms -step:65/1695 train_time:6479ms step_avg:99.68ms -step:66/1695 train_time:6573ms step_avg:99.59ms -step:67/1695 train_time:6668ms step_avg:99.52ms -step:68/1695 train_time:6762ms step_avg:99.45ms -step:69/1695 train_time:6856ms step_avg:99.37ms -step:70/1695 train_time:6950ms step_avg:99.29ms -step:71/1695 train_time:7044ms step_avg:99.21ms -step:72/1695 train_time:7138ms step_avg:99.14ms -step:73/1695 train_time:7232ms step_avg:99.07ms -step:74/1695 train_time:7327ms step_avg:99.02ms -step:75/1695 train_time:7423ms step_avg:98.97ms -step:76/1695 train_time:7518ms step_avg:98.92ms -step:77/1695 train_time:7613ms step_avg:98.87ms -step:78/1695 train_time:7709ms step_avg:98.83ms -step:79/1695 train_time:7803ms step_avg:98.78ms -step:80/1695 train_time:7897ms step_avg:98.72ms -step:81/1695 train_time:7991ms step_avg:98.65ms -step:82/1695 train_time:8085ms step_avg:98.60ms -step:83/1695 train_time:8180ms step_avg:98.56ms -step:84/1695 train_time:8274ms step_avg:98.50ms -step:85/1695 train_time:8368ms step_avg:98.45ms -step:86/1695 train_time:8463ms step_avg:98.40ms -step:87/1695 train_time:8556ms step_avg:98.35ms -step:88/1695 train_time:8651ms step_avg:98.31ms -step:89/1695 train_time:8747ms step_avg:98.29ms -step:90/1695 train_time:8843ms step_avg:98.26ms -step:91/1695 train_time:8938ms step_avg:98.22ms -step:92/1695 train_time:9031ms step_avg:98.16ms -step:93/1695 train_time:9125ms step_avg:98.12ms -step:94/1695 train_time:9220ms 
step_avg:98.08ms -step:95/1695 train_time:9314ms step_avg:98.04ms -step:96/1695 train_time:9408ms step_avg:98.00ms -step:97/1695 train_time:9503ms step_avg:97.97ms -step:98/1695 train_time:9597ms step_avg:97.93ms -step:99/1695 train_time:9691ms step_avg:97.89ms -step:100/1695 train_time:9787ms step_avg:97.87ms -step:101/1695 train_time:9881ms step_avg:97.84ms -step:102/1695 train_time:9975ms step_avg:97.80ms -step:103/1695 train_time:10069ms step_avg:97.76ms -step:104/1695 train_time:10164ms step_avg:97.74ms -step:105/1695 train_time:10258ms step_avg:97.70ms -step:106/1695 train_time:10352ms step_avg:97.66ms -step:107/1695 train_time:10447ms step_avg:97.64ms -step:108/1695 train_time:10543ms step_avg:97.62ms -step:109/1695 train_time:10637ms step_avg:97.59ms -step:110/1695 train_time:10732ms step_avg:97.56ms -step:111/1695 train_time:10827ms step_avg:97.54ms -step:112/1695 train_time:10922ms step_avg:97.52ms -step:113/1695 train_time:11016ms step_avg:97.49ms -step:114/1695 train_time:11110ms step_avg:97.46ms -step:115/1695 train_time:11205ms step_avg:97.44ms -step:116/1695 train_time:11299ms step_avg:97.41ms -step:117/1695 train_time:11393ms step_avg:97.38ms -step:118/1695 train_time:11488ms step_avg:97.35ms -step:119/1695 train_time:11583ms step_avg:97.33ms -step:120/1695 train_time:11677ms step_avg:97.31ms -step:121/1695 train_time:11771ms step_avg:97.28ms -step:122/1695 train_time:11867ms step_avg:97.27ms -step:123/1695 train_time:11961ms step_avg:97.24ms -step:124/1695 train_time:12054ms step_avg:97.21ms -step:125/1695 train_time:12149ms step_avg:97.19ms -step:125/1695 val_loss:4.3107 train_time:12241ms step_avg:97.93ms -step:126/1695 train_time:12268ms step_avg:97.36ms -step:127/1695 train_time:12344ms step_avg:97.20ms -step:128/1695 train_time:12448ms step_avg:97.25ms -step:129/1695 train_time:12544ms step_avg:97.24ms -step:130/1695 train_time:12638ms step_avg:97.21ms -step:131/1695 train_time:12730ms step_avg:97.18ms -step:132/1695 train_time:12824ms step_avg:97.15ms -step:133/1695 train_time:12918ms step_avg:97.13ms -step:134/1695 train_time:13011ms step_avg:97.10ms -step:135/1695 train_time:13105ms step_avg:97.07ms -step:136/1695 train_time:13198ms step_avg:97.04ms -step:137/1695 train_time:13292ms step_avg:97.02ms -step:138/1695 train_time:13389ms step_avg:97.02ms -step:139/1695 train_time:13485ms step_avg:97.01ms -step:140/1695 train_time:13580ms step_avg:97.00ms -step:141/1695 train_time:13675ms step_avg:96.98ms -step:142/1695 train_time:13770ms step_avg:96.97ms -step:143/1695 train_time:13862ms step_avg:96.94ms -step:144/1695 train_time:13955ms step_avg:96.91ms -step:145/1695 train_time:14049ms step_avg:96.89ms -step:146/1695 train_time:14144ms step_avg:96.87ms -step:147/1695 train_time:14238ms step_avg:96.86ms -step:148/1695 train_time:14333ms step_avg:96.84ms -step:149/1695 train_time:14427ms step_avg:96.83ms -step:150/1695 train_time:14524ms step_avg:96.82ms -step:151/1695 train_time:14617ms step_avg:96.80ms -step:152/1695 train_time:14711ms step_avg:96.78ms -step:153/1695 train_time:14805ms step_avg:96.76ms -step:154/1695 train_time:14899ms step_avg:96.75ms -step:155/1695 train_time:14992ms step_avg:96.72ms -step:156/1695 train_time:15085ms step_avg:96.70ms -step:157/1695 train_time:15180ms step_avg:96.69ms -step:158/1695 train_time:15274ms step_avg:96.67ms -step:159/1695 train_time:15368ms step_avg:96.66ms -step:160/1695 train_time:15463ms step_avg:96.65ms -step:161/1695 train_time:15558ms step_avg:96.63ms -step:162/1695 train_time:15652ms step_avg:96.62ms -step:163/1695 
train_time:15746ms step_avg:96.60ms -step:164/1695 train_time:15841ms step_avg:96.59ms -step:165/1695 train_time:15935ms step_avg:96.57ms -step:166/1695 train_time:16029ms step_avg:96.56ms -step:167/1695 train_time:16124ms step_avg:96.55ms -step:168/1695 train_time:16217ms step_avg:96.53ms -step:169/1695 train_time:16313ms step_avg:96.52ms -step:170/1695 train_time:16406ms step_avg:96.51ms -step:171/1695 train_time:16501ms step_avg:96.50ms -step:172/1695 train_time:16595ms step_avg:96.48ms -step:173/1695 train_time:16938ms step_avg:97.91ms -step:174/1695 train_time:17053ms step_avg:98.00ms -step:175/1695 train_time:17146ms step_avg:97.98ms -step:176/1695 train_time:17239ms step_avg:97.95ms -step:177/1695 train_time:17332ms step_avg:97.92ms -step:178/1695 train_time:17426ms step_avg:97.90ms -step:179/1695 train_time:17518ms step_avg:97.87ms -step:180/1695 train_time:17611ms step_avg:97.84ms -step:181/1695 train_time:17705ms step_avg:97.82ms -step:182/1695 train_time:17798ms step_avg:97.79ms -step:183/1695 train_time:17897ms step_avg:97.80ms -step:184/1695 train_time:17993ms step_avg:97.79ms -step:185/1695 train_time:18087ms step_avg:97.77ms -step:186/1695 train_time:18181ms step_avg:97.75ms -step:187/1695 train_time:18275ms step_avg:97.73ms -step:188/1695 train_time:18369ms step_avg:97.71ms -step:189/1695 train_time:18462ms step_avg:97.68ms -step:190/1695 train_time:18555ms step_avg:97.66ms -step:191/1695 train_time:18648ms step_avg:97.63ms -step:192/1695 train_time:18743ms step_avg:97.62ms -step:193/1695 train_time:18839ms step_avg:97.61ms -step:194/1695 train_time:18934ms step_avg:97.60ms -step:195/1695 train_time:19028ms step_avg:97.58ms -step:196/1695 train_time:19122ms step_avg:97.56ms -step:197/1695 train_time:19216ms step_avg:97.55ms -step:198/1695 train_time:19310ms step_avg:97.53ms -step:199/1695 train_time:19404ms step_avg:97.51ms -step:200/1695 train_time:19498ms step_avg:97.49ms -step:201/1695 train_time:19591ms step_avg:97.47ms -step:202/1695 train_time:19684ms step_avg:97.45ms -step:203/1695 train_time:19779ms step_avg:97.43ms -step:204/1695 train_time:19873ms step_avg:97.42ms -step:205/1695 train_time:19967ms step_avg:97.40ms -step:206/1695 train_time:20062ms step_avg:97.39ms -step:207/1695 train_time:20157ms step_avg:97.38ms -step:208/1695 train_time:20251ms step_avg:97.36ms -step:209/1695 train_time:20345ms step_avg:97.34ms -step:210/1695 train_time:20439ms step_avg:97.33ms -step:211/1695 train_time:20532ms step_avg:97.31ms -step:212/1695 train_time:20625ms step_avg:97.29ms -step:213/1695 train_time:20719ms step_avg:97.27ms -step:214/1695 train_time:20814ms step_avg:97.26ms -step:215/1695 train_time:20908ms step_avg:97.25ms -step:216/1695 train_time:21002ms step_avg:97.23ms -step:217/1695 train_time:21097ms step_avg:97.22ms -step:218/1695 train_time:21191ms step_avg:97.21ms -step:219/1695 train_time:21285ms step_avg:97.19ms -step:220/1695 train_time:21380ms step_avg:97.18ms -step:221/1695 train_time:21474ms step_avg:97.17ms -step:222/1695 train_time:21568ms step_avg:97.15ms -step:223/1695 train_time:21664ms step_avg:97.15ms -step:224/1695 train_time:21756ms step_avg:97.13ms -step:225/1695 train_time:21850ms step_avg:97.11ms -step:226/1695 train_time:21945ms step_avg:97.10ms -step:227/1695 train_time:22040ms step_avg:97.09ms -step:228/1695 train_time:22134ms step_avg:97.08ms -step:229/1695 train_time:22228ms step_avg:97.06ms -step:230/1695 train_time:22322ms step_avg:97.05ms -step:231/1695 train_time:22417ms step_avg:97.04ms -step:232/1695 train_time:22510ms step_avg:97.02ms 
-step:233/1695 train_time:22604ms step_avg:97.01ms -step:234/1695 train_time:22699ms step_avg:97.00ms -step:235/1695 train_time:22792ms step_avg:96.99ms -step:236/1695 train_time:22886ms step_avg:96.97ms -step:237/1695 train_time:22981ms step_avg:96.96ms -step:238/1695 train_time:23075ms step_avg:96.95ms -step:239/1695 train_time:23170ms step_avg:96.94ms -step:240/1695 train_time:23264ms step_avg:96.93ms -step:241/1695 train_time:23358ms step_avg:96.92ms -step:242/1695 train_time:23451ms step_avg:96.91ms -step:243/1695 train_time:23546ms step_avg:96.90ms -step:244/1695 train_time:23641ms step_avg:96.89ms -step:245/1695 train_time:23735ms step_avg:96.88ms -step:246/1695 train_time:23829ms step_avg:96.87ms -step:247/1695 train_time:23923ms step_avg:96.85ms -step:248/1695 train_time:24018ms step_avg:96.85ms -step:249/1695 train_time:24112ms step_avg:96.83ms -step:250/1695 train_time:24206ms step_avg:96.82ms -step:250/1695 val_loss:3.9776 train_time:24298ms step_avg:97.19ms -step:251/1695 train_time:24328ms step_avg:96.92ms -step:252/1695 train_time:24399ms step_avg:96.82ms -step:253/1695 train_time:24498ms step_avg:96.83ms -step:254/1695 train_time:24592ms step_avg:96.82ms -step:255/1695 train_time:24685ms step_avg:96.81ms -step:256/1695 train_time:24778ms step_avg:96.79ms -step:257/1695 train_time:24870ms step_avg:96.77ms -step:258/1695 train_time:24964ms step_avg:96.76ms -step:259/1695 train_time:25057ms step_avg:96.74ms -step:260/1695 train_time:25150ms step_avg:96.73ms -step:261/1695 train_time:25243ms step_avg:96.72ms -step:262/1695 train_time:25338ms step_avg:96.71ms -step:263/1695 train_time:25434ms step_avg:96.71ms -step:264/1695 train_time:25529ms step_avg:96.70ms -step:265/1695 train_time:25625ms step_avg:96.70ms -step:266/1695 train_time:25719ms step_avg:96.69ms -step:267/1695 train_time:25812ms step_avg:96.68ms -step:268/1695 train_time:25905ms step_avg:96.66ms -step:269/1695 train_time:25998ms step_avg:96.65ms -step:270/1695 train_time:26091ms step_avg:96.63ms -step:271/1695 train_time:26185ms step_avg:96.62ms -step:272/1695 train_time:26279ms step_avg:96.61ms -step:273/1695 train_time:26373ms step_avg:96.60ms -step:274/1695 train_time:26468ms step_avg:96.60ms -step:275/1695 train_time:26563ms step_avg:96.59ms -step:276/1695 train_time:26658ms step_avg:96.59ms -step:277/1695 train_time:26751ms step_avg:96.57ms -step:278/1695 train_time:26845ms step_avg:96.56ms -step:279/1695 train_time:26939ms step_avg:96.56ms -step:280/1695 train_time:27032ms step_avg:96.54ms -step:281/1695 train_time:27125ms step_avg:96.53ms -step:282/1695 train_time:27219ms step_avg:96.52ms -step:283/1695 train_time:27313ms step_avg:96.51ms -step:284/1695 train_time:27407ms step_avg:96.50ms -step:285/1695 train_time:27502ms step_avg:96.50ms -step:286/1695 train_time:27597ms step_avg:96.49ms -step:287/1695 train_time:27691ms step_avg:96.49ms -step:288/1695 train_time:27785ms step_avg:96.48ms -step:289/1695 train_time:27879ms step_avg:96.47ms -step:290/1695 train_time:27973ms step_avg:96.46ms -step:291/1695 train_time:28067ms step_avg:96.45ms -step:292/1695 train_time:28160ms step_avg:96.44ms -step:293/1695 train_time:28254ms step_avg:96.43ms -step:294/1695 train_time:28347ms step_avg:96.42ms -step:295/1695 train_time:28442ms step_avg:96.41ms -step:296/1695 train_time:28537ms step_avg:96.41ms -step:297/1695 train_time:28632ms step_avg:96.40ms -step:298/1695 train_time:28726ms step_avg:96.40ms -step:299/1695 train_time:28821ms step_avg:96.39ms -step:300/1695 train_time:28915ms step_avg:96.38ms -step:301/1695 
train_time:29008ms step_avg:96.37ms -step:302/1695 train_time:29101ms step_avg:96.36ms -step:303/1695 train_time:29195ms step_avg:96.35ms -step:304/1695 train_time:29289ms step_avg:96.34ms -step:305/1695 train_time:29383ms step_avg:96.34ms -step:306/1695 train_time:29478ms step_avg:96.33ms -step:307/1695 train_time:29572ms step_avg:96.32ms -step:308/1695 train_time:29667ms step_avg:96.32ms -step:309/1695 train_time:29761ms step_avg:96.31ms -step:310/1695 train_time:29855ms step_avg:96.31ms -step:311/1695 train_time:29949ms step_avg:96.30ms -step:312/1695 train_time:30044ms step_avg:96.29ms -step:313/1695 train_time:30138ms step_avg:96.29ms -step:314/1695 train_time:30231ms step_avg:96.28ms -step:315/1695 train_time:30325ms step_avg:96.27ms -step:316/1695 train_time:30419ms step_avg:96.26ms -step:317/1695 train_time:30513ms step_avg:96.26ms -step:318/1695 train_time:30607ms step_avg:96.25ms -step:319/1695 train_time:30702ms step_avg:96.24ms -step:320/1695 train_time:30796ms step_avg:96.24ms -step:321/1695 train_time:30890ms step_avg:96.23ms -step:322/1695 train_time:30984ms step_avg:96.22ms -step:323/1695 train_time:31078ms step_avg:96.22ms -step:324/1695 train_time:31171ms step_avg:96.21ms -step:325/1695 train_time:31265ms step_avg:96.20ms -step:326/1695 train_time:31359ms step_avg:96.19ms -step:327/1695 train_time:31452ms step_avg:96.18ms -step:328/1695 train_time:31545ms step_avg:96.17ms -step:329/1695 train_time:31640ms step_avg:96.17ms -step:330/1695 train_time:31735ms step_avg:96.17ms -step:331/1695 train_time:31829ms step_avg:96.16ms -step:332/1695 train_time:31923ms step_avg:96.15ms -step:333/1695 train_time:32016ms step_avg:96.14ms -step:334/1695 train_time:32109ms step_avg:96.14ms -step:335/1695 train_time:32203ms step_avg:96.13ms -step:336/1695 train_time:32298ms step_avg:96.12ms -step:337/1695 train_time:32392ms step_avg:96.12ms -step:338/1695 train_time:32486ms step_avg:96.11ms -step:339/1695 train_time:32580ms step_avg:96.11ms -step:340/1695 train_time:32674ms step_avg:96.10ms -step:341/1695 train_time:32767ms step_avg:96.09ms -step:342/1695 train_time:32862ms step_avg:96.09ms -step:343/1695 train_time:32957ms step_avg:96.08ms -step:344/1695 train_time:33050ms step_avg:96.08ms -step:345/1695 train_time:33394ms step_avg:96.79ms -step:346/1695 train_time:33465ms step_avg:96.72ms -step:347/1695 train_time:33557ms step_avg:96.71ms -step:348/1695 train_time:33650ms step_avg:96.70ms -step:349/1695 train_time:33744ms step_avg:96.69ms -step:350/1695 train_time:33837ms step_avg:96.68ms -step:351/1695 train_time:33930ms step_avg:96.67ms -step:352/1695 train_time:34023ms step_avg:96.66ms -step:353/1695 train_time:34116ms step_avg:96.65ms -step:354/1695 train_time:34209ms step_avg:96.64ms -step:355/1695 train_time:34305ms step_avg:96.63ms -step:356/1695 train_time:34401ms step_avg:96.63ms -step:357/1695 train_time:34497ms step_avg:96.63ms -step:358/1695 train_time:34591ms step_avg:96.62ms -step:359/1695 train_time:34685ms step_avg:96.61ms -step:360/1695 train_time:34779ms step_avg:96.61ms -step:361/1695 train_time:34871ms step_avg:96.60ms -step:362/1695 train_time:34965ms step_avg:96.59ms -step:363/1695 train_time:35060ms step_avg:96.58ms -step:364/1695 train_time:35153ms step_avg:96.57ms -step:365/1695 train_time:35246ms step_avg:96.56ms -step:366/1695 train_time:35341ms step_avg:96.56ms -step:367/1695 train_time:35436ms step_avg:96.56ms -step:368/1695 train_time:35530ms step_avg:96.55ms -step:369/1695 train_time:35625ms step_avg:96.54ms -step:370/1695 train_time:35719ms step_avg:96.54ms 
-step:371/1695 train_time:35812ms step_avg:96.53ms -step:372/1695 train_time:35905ms step_avg:96.52ms -step:373/1695 train_time:35999ms step_avg:96.51ms -step:374/1695 train_time:36092ms step_avg:96.50ms -step:375/1695 train_time:36185ms step_avg:96.49ms -step:375/1695 val_loss:3.8206 train_time:36277ms step_avg:96.74ms -step:376/1695 train_time:36303ms step_avg:96.55ms -step:377/1695 train_time:36376ms step_avg:96.49ms -step:378/1695 train_time:36477ms step_avg:96.50ms -step:379/1695 train_time:36572ms step_avg:96.50ms -step:380/1695 train_time:36667ms step_avg:96.49ms -step:381/1695 train_time:36760ms step_avg:96.48ms -step:382/1695 train_time:36852ms step_avg:96.47ms -step:383/1695 train_time:36946ms step_avg:96.46ms -step:384/1695 train_time:37039ms step_avg:96.45ms -step:385/1695 train_time:37131ms step_avg:96.44ms -step:386/1695 train_time:37224ms step_avg:96.44ms -step:387/1695 train_time:37319ms step_avg:96.43ms -step:388/1695 train_time:37415ms step_avg:96.43ms -step:389/1695 train_time:37511ms step_avg:96.43ms -step:390/1695 train_time:37607ms step_avg:96.43ms -step:391/1695 train_time:37701ms step_avg:96.42ms -step:392/1695 train_time:37794ms step_avg:96.41ms -step:393/1695 train_time:37888ms step_avg:96.41ms -step:394/1695 train_time:37982ms step_avg:96.40ms -step:395/1695 train_time:38075ms step_avg:96.39ms -step:396/1695 train_time:38168ms step_avg:96.39ms -step:397/1695 train_time:38263ms step_avg:96.38ms -step:398/1695 train_time:38358ms step_avg:96.38ms -step:399/1695 train_time:38452ms step_avg:96.37ms -step:400/1695 train_time:38548ms step_avg:96.37ms -step:401/1695 train_time:38643ms step_avg:96.37ms -step:402/1695 train_time:38737ms step_avg:96.36ms -step:403/1695 train_time:38830ms step_avg:96.35ms -step:404/1695 train_time:38924ms step_avg:96.35ms -step:405/1695 train_time:39017ms step_avg:96.34ms -step:406/1695 train_time:39110ms step_avg:96.33ms -step:407/1695 train_time:39203ms step_avg:96.32ms -step:408/1695 train_time:39296ms step_avg:96.31ms -step:409/1695 train_time:39391ms step_avg:96.31ms -step:410/1695 train_time:39486ms step_avg:96.31ms -step:411/1695 train_time:39580ms step_avg:96.30ms -step:412/1695 train_time:39673ms step_avg:96.29ms -step:413/1695 train_time:39768ms step_avg:96.29ms -step:414/1695 train_time:39863ms step_avg:96.29ms -step:415/1695 train_time:39957ms step_avg:96.28ms -step:416/1695 train_time:40050ms step_avg:96.27ms -step:417/1695 train_time:40144ms step_avg:96.27ms -step:418/1695 train_time:40237ms step_avg:96.26ms -step:419/1695 train_time:40331ms step_avg:96.25ms -step:420/1695 train_time:40425ms step_avg:96.25ms -step:421/1695 train_time:40519ms step_avg:96.25ms -step:422/1695 train_time:40613ms step_avg:96.24ms -step:423/1695 train_time:40707ms step_avg:96.23ms -step:424/1695 train_time:40802ms step_avg:96.23ms -step:425/1695 train_time:40895ms step_avg:96.22ms -step:426/1695 train_time:40990ms step_avg:96.22ms -step:427/1695 train_time:41085ms step_avg:96.22ms -step:428/1695 train_time:41178ms step_avg:96.21ms -step:429/1695 train_time:41272ms step_avg:96.20ms -step:430/1695 train_time:41365ms step_avg:96.20ms -step:431/1695 train_time:41459ms step_avg:96.19ms -step:432/1695 train_time:41553ms step_avg:96.19ms -step:433/1695 train_time:41648ms step_avg:96.18ms -step:434/1695 train_time:41743ms step_avg:96.18ms -step:435/1695 train_time:41837ms step_avg:96.18ms -step:436/1695 train_time:41930ms step_avg:96.17ms -step:437/1695 train_time:42025ms step_avg:96.17ms -step:438/1695 train_time:42119ms step_avg:96.16ms -step:439/1695 
train_time:42212ms step_avg:96.15ms -step:440/1695 train_time:42305ms step_avg:96.15ms -step:441/1695 train_time:42399ms step_avg:96.14ms -step:442/1695 train_time:42492ms step_avg:96.14ms -step:443/1695 train_time:42587ms step_avg:96.13ms -step:444/1695 train_time:42680ms step_avg:96.13ms -step:445/1695 train_time:42774ms step_avg:96.12ms -step:446/1695 train_time:42868ms step_avg:96.12ms -step:447/1695 train_time:42963ms step_avg:96.11ms -step:448/1695 train_time:43057ms step_avg:96.11ms -step:449/1695 train_time:43150ms step_avg:96.10ms -step:450/1695 train_time:43245ms step_avg:96.10ms -step:451/1695 train_time:43338ms step_avg:96.09ms -step:452/1695 train_time:43432ms step_avg:96.09ms -step:453/1695 train_time:43526ms step_avg:96.08ms -step:454/1695 train_time:43621ms step_avg:96.08ms -step:455/1695 train_time:43715ms step_avg:96.08ms -step:456/1695 train_time:43809ms step_avg:96.07ms -step:457/1695 train_time:43903ms step_avg:96.07ms -step:458/1695 train_time:43997ms step_avg:96.06ms -step:459/1695 train_time:44091ms step_avg:96.06ms -step:460/1695 train_time:44185ms step_avg:96.05ms -step:461/1695 train_time:44278ms step_avg:96.05ms -step:462/1695 train_time:44372ms step_avg:96.04ms -step:463/1695 train_time:44466ms step_avg:96.04ms -step:464/1695 train_time:44560ms step_avg:96.03ms -step:465/1695 train_time:44654ms step_avg:96.03ms -step:466/1695 train_time:44748ms step_avg:96.03ms -step:467/1695 train_time:44842ms step_avg:96.02ms -step:468/1695 train_time:44935ms step_avg:96.02ms -step:469/1695 train_time:45031ms step_avg:96.01ms -step:470/1695 train_time:45125ms step_avg:96.01ms -step:471/1695 train_time:45218ms step_avg:96.00ms -step:472/1695 train_time:45312ms step_avg:96.00ms -step:473/1695 train_time:45406ms step_avg:96.00ms -step:474/1695 train_time:45499ms step_avg:95.99ms -step:475/1695 train_time:45594ms step_avg:95.99ms -step:476/1695 train_time:45689ms step_avg:95.99ms -step:477/1695 train_time:45784ms step_avg:95.98ms -step:478/1695 train_time:45877ms step_avg:95.98ms -step:479/1695 train_time:45971ms step_avg:95.97ms -step:480/1695 train_time:46065ms step_avg:95.97ms -step:481/1695 train_time:46159ms step_avg:95.96ms -step:482/1695 train_time:46252ms step_avg:95.96ms -step:483/1695 train_time:46347ms step_avg:95.96ms -step:484/1695 train_time:46442ms step_avg:95.95ms -step:485/1695 train_time:46535ms step_avg:95.95ms -step:486/1695 train_time:46629ms step_avg:95.95ms -step:487/1695 train_time:46723ms step_avg:95.94ms -step:488/1695 train_time:46818ms step_avg:95.94ms -step:489/1695 train_time:46911ms step_avg:95.93ms -step:490/1695 train_time:47004ms step_avg:95.93ms -step:491/1695 train_time:47098ms step_avg:95.92ms -step:492/1695 train_time:47191ms step_avg:95.92ms -step:493/1695 train_time:47286ms step_avg:95.91ms -step:494/1695 train_time:47381ms step_avg:95.91ms -step:495/1695 train_time:47474ms step_avg:95.91ms -step:496/1695 train_time:47568ms step_avg:95.90ms -step:497/1695 train_time:47663ms step_avg:95.90ms -step:498/1695 train_time:47756ms step_avg:95.90ms -step:499/1695 train_time:47851ms step_avg:95.89ms -step:500/1695 train_time:47944ms step_avg:95.89ms -step:500/1695 val_loss:3.7169 train_time:48035ms step_avg:96.07ms -step:501/1695 train_time:48064ms step_avg:95.94ms -step:502/1695 train_time:48136ms step_avg:95.89ms -step:503/1695 train_time:48238ms step_avg:95.90ms -step:504/1695 train_time:48335ms step_avg:95.90ms -step:505/1695 train_time:48428ms step_avg:95.90ms -step:506/1695 train_time:48521ms step_avg:95.89ms -step:507/1695 train_time:48614ms 
step_avg:95.89ms -step:508/1695 train_time:48707ms step_avg:95.88ms -step:509/1695 train_time:48800ms step_avg:95.87ms -step:510/1695 train_time:48893ms step_avg:95.87ms -step:511/1695 train_time:48986ms step_avg:95.86ms -step:512/1695 train_time:49080ms step_avg:95.86ms -step:513/1695 train_time:49177ms step_avg:95.86ms -step:514/1695 train_time:49275ms step_avg:95.87ms -step:515/1695 train_time:49371ms step_avg:95.87ms -step:516/1695 train_time:49465ms step_avg:95.86ms -step:517/1695 train_time:49558ms step_avg:95.86ms -step:518/1695 train_time:49652ms step_avg:95.85ms -step:519/1695 train_time:49984ms step_avg:96.31ms -step:520/1695 train_time:50159ms step_avg:96.46ms -step:521/1695 train_time:50250ms step_avg:96.45ms -step:522/1695 train_time:50343ms step_avg:96.44ms -step:523/1695 train_time:50436ms step_avg:96.44ms -step:524/1695 train_time:50529ms step_avg:96.43ms -step:525/1695 train_time:50621ms step_avg:96.42ms -step:526/1695 train_time:50714ms step_avg:96.41ms -step:527/1695 train_time:50806ms step_avg:96.41ms -step:528/1695 train_time:50898ms step_avg:96.40ms -step:529/1695 train_time:50992ms step_avg:96.39ms -step:530/1695 train_time:51091ms step_avg:96.40ms -step:531/1695 train_time:51188ms step_avg:96.40ms -step:532/1695 train_time:51282ms step_avg:96.40ms -step:533/1695 train_time:51376ms step_avg:96.39ms -step:534/1695 train_time:51470ms step_avg:96.39ms -step:535/1695 train_time:51564ms step_avg:96.38ms -step:536/1695 train_time:51656ms step_avg:96.37ms -step:537/1695 train_time:51750ms step_avg:96.37ms -step:538/1695 train_time:51842ms step_avg:96.36ms -step:539/1695 train_time:51935ms step_avg:96.35ms -step:540/1695 train_time:52032ms step_avg:96.36ms -step:541/1695 train_time:52126ms step_avg:96.35ms -step:542/1695 train_time:52220ms step_avg:96.35ms -step:543/1695 train_time:52316ms step_avg:96.35ms -step:544/1695 train_time:52411ms step_avg:96.34ms -step:545/1695 train_time:52505ms step_avg:96.34ms -step:546/1695 train_time:52598ms step_avg:96.33ms -step:547/1695 train_time:52691ms step_avg:96.33ms -step:548/1695 train_time:52784ms step_avg:96.32ms -step:549/1695 train_time:52877ms step_avg:96.32ms -step:550/1695 train_time:52971ms step_avg:96.31ms -step:551/1695 train_time:53066ms step_avg:96.31ms -step:552/1695 train_time:53161ms step_avg:96.31ms -step:553/1695 train_time:53255ms step_avg:96.30ms -step:554/1695 train_time:53352ms step_avg:96.30ms -step:555/1695 train_time:53446ms step_avg:96.30ms -step:556/1695 train_time:53539ms step_avg:96.29ms -step:557/1695 train_time:53633ms step_avg:96.29ms -step:558/1695 train_time:53727ms step_avg:96.29ms -step:559/1695 train_time:53820ms step_avg:96.28ms -step:560/1695 train_time:53914ms step_avg:96.27ms -step:561/1695 train_time:54009ms step_avg:96.27ms -step:562/1695 train_time:54102ms step_avg:96.27ms -step:563/1695 train_time:54196ms step_avg:96.26ms -step:564/1695 train_time:54291ms step_avg:96.26ms -step:565/1695 train_time:54385ms step_avg:96.26ms -step:566/1695 train_time:54479ms step_avg:96.25ms -step:567/1695 train_time:54573ms step_avg:96.25ms -step:568/1695 train_time:54670ms step_avg:96.25ms -step:569/1695 train_time:54765ms step_avg:96.25ms -step:570/1695 train_time:54859ms step_avg:96.24ms -step:571/1695 train_time:54955ms step_avg:96.24ms -step:572/1695 train_time:55052ms step_avg:96.24ms -step:573/1695 train_time:55149ms step_avg:96.25ms -step:574/1695 train_time:55245ms step_avg:96.25ms -step:575/1695 train_time:55344ms step_avg:96.25ms -step:576/1695 train_time:55437ms step_avg:96.25ms -step:577/1695 
train_time:55534ms step_avg:96.25ms -step:578/1695 train_time:55631ms step_avg:96.25ms -step:579/1695 train_time:55727ms step_avg:96.25ms -step:580/1695 train_time:55823ms step_avg:96.25ms -step:581/1695 train_time:55918ms step_avg:96.24ms -step:582/1695 train_time:56015ms step_avg:96.25ms -step:583/1695 train_time:56111ms step_avg:96.25ms -step:584/1695 train_time:56209ms step_avg:96.25ms -step:585/1695 train_time:56306ms step_avg:96.25ms -step:586/1695 train_time:56402ms step_avg:96.25ms -step:587/1695 train_time:56497ms step_avg:96.25ms -step:588/1695 train_time:56593ms step_avg:96.25ms -step:589/1695 train_time:56689ms step_avg:96.25ms -step:590/1695 train_time:56786ms step_avg:96.25ms -step:591/1695 train_time:56882ms step_avg:96.25ms -step:592/1695 train_time:56978ms step_avg:96.25ms -step:593/1695 train_time:57074ms step_avg:96.25ms -step:594/1695 train_time:57172ms step_avg:96.25ms -step:595/1695 train_time:57269ms step_avg:96.25ms -step:596/1695 train_time:57365ms step_avg:96.25ms -step:597/1695 train_time:57461ms step_avg:96.25ms -step:598/1695 train_time:57555ms step_avg:96.25ms -step:599/1695 train_time:57652ms step_avg:96.25ms -step:600/1695 train_time:57749ms step_avg:96.25ms -step:601/1695 train_time:57846ms step_avg:96.25ms -step:602/1695 train_time:57942ms step_avg:96.25ms -step:603/1695 train_time:58037ms step_avg:96.25ms -step:604/1695 train_time:58133ms step_avg:96.25ms -step:605/1695 train_time:58229ms step_avg:96.25ms -step:606/1695 train_time:58326ms step_avg:96.25ms -step:607/1695 train_time:58421ms step_avg:96.25ms -step:608/1695 train_time:58516ms step_avg:96.24ms -step:609/1695 train_time:58613ms step_avg:96.25ms -step:610/1695 train_time:58710ms step_avg:96.25ms -step:611/1695 train_time:58807ms step_avg:96.25ms -step:612/1695 train_time:58903ms step_avg:96.25ms -step:613/1695 train_time:58998ms step_avg:96.24ms -step:614/1695 train_time:59095ms step_avg:96.25ms -step:615/1695 train_time:59191ms step_avg:96.25ms -step:616/1695 train_time:59288ms step_avg:96.25ms -step:617/1695 train_time:59384ms step_avg:96.25ms -step:618/1695 train_time:59479ms step_avg:96.24ms -step:619/1695 train_time:59577ms step_avg:96.25ms -step:620/1695 train_time:59674ms step_avg:96.25ms -step:621/1695 train_time:59772ms step_avg:96.25ms -step:622/1695 train_time:59868ms step_avg:96.25ms -step:623/1695 train_time:59964ms step_avg:96.25ms -step:624/1695 train_time:60059ms step_avg:96.25ms -step:625/1695 train_time:60155ms step_avg:96.25ms -step:625/1695 val_loss:3.6208 train_time:60249ms step_avg:96.40ms -step:626/1695 train_time:60275ms step_avg:96.29ms -step:627/1695 train_time:60358ms step_avg:96.26ms -step:628/1695 train_time:60456ms step_avg:96.27ms -step:629/1695 train_time:60551ms step_avg:96.27ms -step:630/1695 train_time:60646ms step_avg:96.26ms -step:631/1695 train_time:60741ms step_avg:96.26ms -step:632/1695 train_time:60836ms step_avg:96.26ms -step:633/1695 train_time:60930ms step_avg:96.26ms -step:634/1695 train_time:61025ms step_avg:96.25ms -step:635/1695 train_time:61122ms step_avg:96.25ms -step:636/1695 train_time:61217ms step_avg:96.25ms -step:637/1695 train_time:61315ms step_avg:96.26ms -step:638/1695 train_time:61413ms step_avg:96.26ms -step:639/1695 train_time:61509ms step_avg:96.26ms -step:640/1695 train_time:61606ms step_avg:96.26ms -step:641/1695 train_time:61702ms step_avg:96.26ms -step:642/1695 train_time:61797ms step_avg:96.26ms -step:643/1695 train_time:61892ms step_avg:96.26ms -step:644/1695 train_time:61988ms step_avg:96.25ms -step:645/1695 train_time:62084ms 
step_avg:96.25ms -step:646/1695 train_time:62181ms step_avg:96.26ms -step:647/1695 train_time:62278ms step_avg:96.26ms -step:648/1695 train_time:62376ms step_avg:96.26ms -step:649/1695 train_time:62472ms step_avg:96.26ms -step:650/1695 train_time:62568ms step_avg:96.26ms -step:651/1695 train_time:62664ms step_avg:96.26ms -step:652/1695 train_time:62760ms step_avg:96.26ms -step:653/1695 train_time:62855ms step_avg:96.26ms -step:654/1695 train_time:62950ms step_avg:96.25ms -step:655/1695 train_time:63045ms step_avg:96.25ms -step:656/1695 train_time:63141ms step_avg:96.25ms -step:657/1695 train_time:63238ms step_avg:96.25ms -step:658/1695 train_time:63335ms step_avg:96.25ms -step:659/1695 train_time:63431ms step_avg:96.25ms -step:660/1695 train_time:63529ms step_avg:96.26ms -step:661/1695 train_time:63626ms step_avg:96.26ms -step:662/1695 train_time:63721ms step_avg:96.26ms -step:663/1695 train_time:63816ms step_avg:96.25ms -step:664/1695 train_time:63911ms step_avg:96.25ms -step:665/1695 train_time:64007ms step_avg:96.25ms -step:666/1695 train_time:64103ms step_avg:96.25ms -step:667/1695 train_time:64199ms step_avg:96.25ms -step:668/1695 train_time:64295ms step_avg:96.25ms -step:669/1695 train_time:64390ms step_avg:96.25ms -step:670/1695 train_time:64487ms step_avg:96.25ms -step:671/1695 train_time:64585ms step_avg:96.25ms -step:672/1695 train_time:64682ms step_avg:96.25ms -step:673/1695 train_time:64778ms step_avg:96.25ms -step:674/1695 train_time:64873ms step_avg:96.25ms -step:675/1695 train_time:64969ms step_avg:96.25ms -step:676/1695 train_time:65065ms step_avg:96.25ms -step:677/1695 train_time:65161ms step_avg:96.25ms -step:678/1695 train_time:65257ms step_avg:96.25ms -step:679/1695 train_time:65353ms step_avg:96.25ms -step:680/1695 train_time:65449ms step_avg:96.25ms -step:681/1695 train_time:65545ms step_avg:96.25ms -step:682/1695 train_time:65642ms step_avg:96.25ms -step:683/1695 train_time:65738ms step_avg:96.25ms -step:684/1695 train_time:65833ms step_avg:96.25ms -step:685/1695 train_time:65929ms step_avg:96.25ms -step:686/1695 train_time:66024ms step_avg:96.25ms -step:687/1695 train_time:66120ms step_avg:96.25ms -step:688/1695 train_time:66216ms step_avg:96.24ms -step:689/1695 train_time:66311ms step_avg:96.24ms -step:690/1695 train_time:66408ms step_avg:96.24ms -step:691/1695 train_time:66766ms step_avg:96.62ms -step:692/1695 train_time:66929ms step_avg:96.72ms -step:693/1695 train_time:67024ms step_avg:96.72ms -step:694/1695 train_time:67118ms step_avg:96.71ms -step:695/1695 train_time:67212ms step_avg:96.71ms -step:696/1695 train_time:67307ms step_avg:96.71ms -step:697/1695 train_time:67402ms step_avg:96.70ms -step:698/1695 train_time:67496ms step_avg:96.70ms -step:699/1695 train_time:67590ms step_avg:96.70ms -step:700/1695 train_time:67686ms step_avg:96.69ms -step:701/1695 train_time:67789ms step_avg:96.70ms -step:702/1695 train_time:67889ms step_avg:96.71ms -step:703/1695 train_time:67986ms step_avg:96.71ms -step:704/1695 train_time:68082ms step_avg:96.71ms -step:705/1695 train_time:68176ms step_avg:96.70ms -step:706/1695 train_time:68272ms step_avg:96.70ms -step:707/1695 train_time:68368ms step_avg:96.70ms -step:708/1695 train_time:68463ms step_avg:96.70ms -step:709/1695 train_time:68558ms step_avg:96.70ms -step:710/1695 train_time:68653ms step_avg:96.69ms -step:711/1695 train_time:68750ms step_avg:96.70ms -step:712/1695 train_time:68848ms step_avg:96.70ms -step:713/1695 train_time:68946ms step_avg:96.70ms -step:714/1695 train_time:69043ms step_avg:96.70ms -step:715/1695 
train_time:69140ms step_avg:96.70ms -step:716/1695 train_time:69235ms step_avg:96.70ms -step:717/1695 train_time:69330ms step_avg:96.69ms -step:718/1695 train_time:69426ms step_avg:96.69ms -step:719/1695 train_time:69522ms step_avg:96.69ms -step:720/1695 train_time:69617ms step_avg:96.69ms -step:721/1695 train_time:69713ms step_avg:96.69ms -step:722/1695 train_time:69810ms step_avg:96.69ms -step:723/1695 train_time:69907ms step_avg:96.69ms -step:724/1695 train_time:70003ms step_avg:96.69ms -step:725/1695 train_time:70100ms step_avg:96.69ms -step:726/1695 train_time:70195ms step_avg:96.69ms -step:727/1695 train_time:70291ms step_avg:96.69ms -step:728/1695 train_time:70387ms step_avg:96.69ms -step:729/1695 train_time:70482ms step_avg:96.68ms -step:730/1695 train_time:70578ms step_avg:96.68ms -step:731/1695 train_time:70673ms step_avg:96.68ms -step:732/1695 train_time:70769ms step_avg:96.68ms -step:733/1695 train_time:70865ms step_avg:96.68ms -step:734/1695 train_time:70963ms step_avg:96.68ms -step:735/1695 train_time:71059ms step_avg:96.68ms -step:736/1695 train_time:71155ms step_avg:96.68ms -step:737/1695 train_time:71250ms step_avg:96.68ms -step:738/1695 train_time:71345ms step_avg:96.67ms -step:739/1695 train_time:71441ms step_avg:96.67ms -step:740/1695 train_time:71537ms step_avg:96.67ms -step:741/1695 train_time:71632ms step_avg:96.67ms -step:742/1695 train_time:71728ms step_avg:96.67ms -step:743/1695 train_time:71824ms step_avg:96.67ms -step:744/1695 train_time:71920ms step_avg:96.67ms -step:745/1695 train_time:72016ms step_avg:96.67ms -step:746/1695 train_time:72111ms step_avg:96.66ms -step:747/1695 train_time:72208ms step_avg:96.66ms -step:748/1695 train_time:72303ms step_avg:96.66ms -step:749/1695 train_time:72399ms step_avg:96.66ms -step:750/1695 train_time:72495ms step_avg:96.66ms -step:750/1695 val_loss:3.5658 train_time:72587ms step_avg:96.78ms -step:751/1695 train_time:72614ms step_avg:96.69ms -step:752/1695 train_time:72692ms step_avg:96.67ms -step:753/1695 train_time:72794ms step_avg:96.67ms -step:754/1695 train_time:72890ms step_avg:96.67ms -step:755/1695 train_time:72986ms step_avg:96.67ms -step:756/1695 train_time:73081ms step_avg:96.67ms -step:757/1695 train_time:73176ms step_avg:96.67ms -step:758/1695 train_time:73271ms step_avg:96.66ms -step:759/1695 train_time:73365ms step_avg:96.66ms -step:760/1695 train_time:73460ms step_avg:96.66ms -step:761/1695 train_time:73556ms step_avg:96.66ms -step:762/1695 train_time:73653ms step_avg:96.66ms -step:763/1695 train_time:73751ms step_avg:96.66ms -step:764/1695 train_time:73849ms step_avg:96.66ms -step:765/1695 train_time:73945ms step_avg:96.66ms -step:766/1695 train_time:74040ms step_avg:96.66ms -step:767/1695 train_time:74135ms step_avg:96.66ms -step:768/1695 train_time:74231ms step_avg:96.66ms -step:769/1695 train_time:74327ms step_avg:96.65ms -step:770/1695 train_time:74421ms step_avg:96.65ms -step:771/1695 train_time:74516ms step_avg:96.65ms -step:772/1695 train_time:74614ms step_avg:96.65ms -step:773/1695 train_time:74712ms step_avg:96.65ms -step:774/1695 train_time:74809ms step_avg:96.65ms -step:775/1695 train_time:74908ms step_avg:96.65ms -step:776/1695 train_time:75004ms step_avg:96.65ms -step:777/1695 train_time:75099ms step_avg:96.65ms -step:778/1695 train_time:75194ms step_avg:96.65ms -step:779/1695 train_time:75290ms step_avg:96.65ms -step:780/1695 train_time:75386ms step_avg:96.65ms -step:781/1695 train_time:75482ms step_avg:96.65ms -step:782/1695 train_time:75578ms step_avg:96.65ms -step:783/1695 train_time:75675ms 
step_avg:96.65ms -step:784/1695 train_time:75772ms step_avg:96.65ms -step:785/1695 train_time:75870ms step_avg:96.65ms -step:786/1695 train_time:75967ms step_avg:96.65ms -step:787/1695 train_time:76062ms step_avg:96.65ms -step:788/1695 train_time:76158ms step_avg:96.65ms -step:789/1695 train_time:76253ms step_avg:96.65ms -step:790/1695 train_time:76350ms step_avg:96.65ms -step:791/1695 train_time:76445ms step_avg:96.64ms -step:792/1695 train_time:76541ms step_avg:96.64ms -step:793/1695 train_time:76638ms step_avg:96.64ms -step:794/1695 train_time:76735ms step_avg:96.64ms -step:795/1695 train_time:76832ms step_avg:96.64ms -step:796/1695 train_time:76930ms step_avg:96.65ms -step:797/1695 train_time:77026ms step_avg:96.65ms -step:798/1695 train_time:77121ms step_avg:96.64ms -step:799/1695 train_time:77216ms step_avg:96.64ms -step:800/1695 train_time:77312ms step_avg:96.64ms -step:801/1695 train_time:77411ms step_avg:96.64ms -step:802/1695 train_time:77506ms step_avg:96.64ms -step:803/1695 train_time:77602ms step_avg:96.64ms -step:804/1695 train_time:77697ms step_avg:96.64ms -step:805/1695 train_time:77795ms step_avg:96.64ms -step:806/1695 train_time:77893ms step_avg:96.64ms -step:807/1695 train_time:77990ms step_avg:96.64ms -step:808/1695 train_time:78087ms step_avg:96.64ms -step:809/1695 train_time:78183ms step_avg:96.64ms -step:810/1695 train_time:78281ms step_avg:96.64ms -step:811/1695 train_time:78374ms step_avg:96.64ms -step:812/1695 train_time:78471ms step_avg:96.64ms -step:813/1695 train_time:78568ms step_avg:96.64ms -step:814/1695 train_time:78663ms step_avg:96.64ms -step:815/1695 train_time:78759ms step_avg:96.64ms -step:816/1695 train_time:78855ms step_avg:96.64ms -step:817/1695 train_time:78952ms step_avg:96.64ms -step:818/1695 train_time:79048ms step_avg:96.64ms -step:819/1695 train_time:79144ms step_avg:96.64ms -step:820/1695 train_time:79239ms step_avg:96.63ms -step:821/1695 train_time:79336ms step_avg:96.63ms -step:822/1695 train_time:79431ms step_avg:96.63ms -step:823/1695 train_time:79527ms step_avg:96.63ms -step:824/1695 train_time:79623ms step_avg:96.63ms -step:825/1695 train_time:79719ms step_avg:96.63ms -step:826/1695 train_time:79816ms step_avg:96.63ms -step:827/1695 train_time:79912ms step_avg:96.63ms -step:828/1695 train_time:80008ms step_avg:96.63ms -step:829/1695 train_time:80103ms step_avg:96.63ms -step:830/1695 train_time:80199ms step_avg:96.62ms -step:831/1695 train_time:80295ms step_avg:96.62ms -step:832/1695 train_time:80391ms step_avg:96.62ms -step:833/1695 train_time:80487ms step_avg:96.62ms -step:834/1695 train_time:80584ms step_avg:96.62ms -step:835/1695 train_time:80680ms step_avg:96.62ms -step:836/1695 train_time:80775ms step_avg:96.62ms -step:837/1695 train_time:80872ms step_avg:96.62ms -step:838/1695 train_time:80968ms step_avg:96.62ms -step:839/1695 train_time:81064ms step_avg:96.62ms -step:840/1695 train_time:81160ms step_avg:96.62ms -step:841/1695 train_time:81257ms step_avg:96.62ms -step:842/1695 train_time:81352ms step_avg:96.62ms -step:843/1695 train_time:81450ms step_avg:96.62ms -step:844/1695 train_time:81546ms step_avg:96.62ms -step:845/1695 train_time:81643ms step_avg:96.62ms -step:846/1695 train_time:81738ms step_avg:96.62ms -step:847/1695 train_time:81834ms step_avg:96.62ms -step:848/1695 train_time:81931ms step_avg:96.62ms -step:849/1695 train_time:82026ms step_avg:96.62ms -step:850/1695 train_time:82122ms step_avg:96.61ms -step:851/1695 train_time:82218ms step_avg:96.61ms -step:852/1695 train_time:82313ms step_avg:96.61ms -step:853/1695 
train_time:82409ms step_avg:96.61ms -step:854/1695 train_time:82506ms step_avg:96.61ms -step:855/1695 train_time:82602ms step_avg:96.61ms -step:856/1695 train_time:82698ms step_avg:96.61ms -step:857/1695 train_time:82795ms step_avg:96.61ms -step:858/1695 train_time:82891ms step_avg:96.61ms -step:859/1695 train_time:82987ms step_avg:96.61ms -step:860/1695 train_time:83083ms step_avg:96.61ms -step:861/1695 train_time:83180ms step_avg:96.61ms -step:862/1695 train_time:83275ms step_avg:96.61ms -step:863/1695 train_time:83625ms step_avg:96.90ms -step:864/1695 train_time:83802ms step_avg:96.99ms -step:865/1695 train_time:83896ms step_avg:96.99ms -step:866/1695 train_time:83991ms step_avg:96.99ms -step:867/1695 train_time:84086ms step_avg:96.99ms -step:868/1695 train_time:84180ms step_avg:96.98ms -step:869/1695 train_time:84275ms step_avg:96.98ms -step:870/1695 train_time:84370ms step_avg:96.98ms -step:871/1695 train_time:84465ms step_avg:96.97ms -step:872/1695 train_time:84559ms step_avg:96.97ms -step:873/1695 train_time:84660ms step_avg:96.98ms -step:874/1695 train_time:84758ms step_avg:96.98ms -step:875/1695 train_time:84857ms step_avg:96.98ms -step:875/1695 val_loss:3.5251 train_time:84952ms step_avg:97.09ms -step:876/1695 train_time:84978ms step_avg:97.01ms -step:877/1695 train_time:85058ms step_avg:96.99ms -step:878/1695 train_time:85158ms step_avg:96.99ms -step:879/1695 train_time:85256ms step_avg:96.99ms -step:880/1695 train_time:85352ms step_avg:96.99ms -step:881/1695 train_time:85447ms step_avg:96.99ms -step:882/1695 train_time:85541ms step_avg:96.99ms -step:883/1695 train_time:85636ms step_avg:96.98ms -step:884/1695 train_time:85731ms step_avg:96.98ms -step:885/1695 train_time:85825ms step_avg:96.98ms -step:886/1695 train_time:85920ms step_avg:96.98ms -step:887/1695 train_time:86020ms step_avg:96.98ms -step:888/1695 train_time:86118ms step_avg:96.98ms -step:889/1695 train_time:86217ms step_avg:96.98ms -step:890/1695 train_time:86314ms step_avg:96.98ms -step:891/1695 train_time:86411ms step_avg:96.98ms -step:892/1695 train_time:86505ms step_avg:96.98ms -step:893/1695 train_time:86600ms step_avg:96.98ms -step:894/1695 train_time:86695ms step_avg:96.97ms -step:895/1695 train_time:86790ms step_avg:96.97ms -step:896/1695 train_time:86886ms step_avg:96.97ms -step:897/1695 train_time:86982ms step_avg:96.97ms -step:898/1695 train_time:87080ms step_avg:96.97ms -step:899/1695 train_time:87177ms step_avg:96.97ms -step:900/1695 train_time:87276ms step_avg:96.97ms -step:901/1695 train_time:87374ms step_avg:96.97ms -step:902/1695 train_time:87471ms step_avg:96.97ms -step:903/1695 train_time:87567ms step_avg:96.97ms -step:904/1695 train_time:87661ms step_avg:96.97ms -step:905/1695 train_time:87756ms step_avg:96.97ms -step:906/1695 train_time:87852ms step_avg:96.97ms -step:907/1695 train_time:87948ms step_avg:96.97ms -step:908/1695 train_time:88044ms step_avg:96.96ms -step:909/1695 train_time:88140ms step_avg:96.96ms -step:910/1695 train_time:88237ms step_avg:96.96ms -step:911/1695 train_time:88336ms step_avg:96.97ms -step:912/1695 train_time:88433ms step_avg:96.97ms -step:913/1695 train_time:88528ms step_avg:96.96ms -step:914/1695 train_time:88624ms step_avg:96.96ms -step:915/1695 train_time:88718ms step_avg:96.96ms -step:916/1695 train_time:88816ms step_avg:96.96ms -step:917/1695 train_time:88914ms step_avg:96.96ms -step:918/1695 train_time:89010ms step_avg:96.96ms -step:919/1695 train_time:89107ms step_avg:96.96ms -step:920/1695 train_time:89203ms step_avg:96.96ms -step:921/1695 train_time:89299ms 
step_avg:96.96ms -step:922/1695 train_time:89396ms step_avg:96.96ms -step:923/1695 train_time:89493ms step_avg:96.96ms -step:924/1695 train_time:89590ms step_avg:96.96ms -step:925/1695 train_time:89686ms step_avg:96.96ms -step:926/1695 train_time:89782ms step_avg:96.96ms -step:927/1695 train_time:89878ms step_avg:96.96ms -step:928/1695 train_time:89974ms step_avg:96.95ms -step:929/1695 train_time:90071ms step_avg:96.95ms -step:930/1695 train_time:90167ms step_avg:96.95ms -step:931/1695 train_time:90262ms step_avg:96.95ms -step:932/1695 train_time:90358ms step_avg:96.95ms -step:933/1695 train_time:90455ms step_avg:96.95ms -step:934/1695 train_time:90551ms step_avg:96.95ms -step:935/1695 train_time:90648ms step_avg:96.95ms -step:936/1695 train_time:90743ms step_avg:96.95ms -step:937/1695 train_time:90839ms step_avg:96.95ms -step:938/1695 train_time:90935ms step_avg:96.95ms -step:939/1695 train_time:91030ms step_avg:96.94ms -step:940/1695 train_time:91126ms step_avg:96.94ms -step:941/1695 train_time:91221ms step_avg:96.94ms -step:942/1695 train_time:91317ms step_avg:96.94ms -step:943/1695 train_time:91413ms step_avg:96.94ms -step:944/1695 train_time:91510ms step_avg:96.94ms -step:945/1695 train_time:91607ms step_avg:96.94ms -step:946/1695 train_time:91703ms step_avg:96.94ms -step:947/1695 train_time:91798ms step_avg:96.94ms -step:948/1695 train_time:91894ms step_avg:96.93ms -step:949/1695 train_time:91989ms step_avg:96.93ms -step:950/1695 train_time:92085ms step_avg:96.93ms -step:951/1695 train_time:92181ms step_avg:96.93ms -step:952/1695 train_time:92277ms step_avg:96.93ms -step:953/1695 train_time:92373ms step_avg:96.93ms -step:954/1695 train_time:92469ms step_avg:96.93ms -step:955/1695 train_time:92566ms step_avg:96.93ms -step:956/1695 train_time:92662ms step_avg:96.93ms -step:957/1695 train_time:92758ms step_avg:96.93ms -step:958/1695 train_time:92854ms step_avg:96.92ms -step:959/1695 train_time:92950ms step_avg:96.92ms -step:960/1695 train_time:93046ms step_avg:96.92ms -step:961/1695 train_time:93142ms step_avg:96.92ms -step:962/1695 train_time:93238ms step_avg:96.92ms -step:963/1695 train_time:93334ms step_avg:96.92ms -step:964/1695 train_time:93431ms step_avg:96.92ms -step:965/1695 train_time:93526ms step_avg:96.92ms -step:966/1695 train_time:93621ms step_avg:96.92ms -step:967/1695 train_time:93718ms step_avg:96.92ms -step:968/1695 train_time:93815ms step_avg:96.92ms -step:969/1695 train_time:93911ms step_avg:96.92ms -step:970/1695 train_time:94008ms step_avg:96.92ms -step:971/1695 train_time:94104ms step_avg:96.91ms -step:972/1695 train_time:94200ms step_avg:96.91ms -step:973/1695 train_time:94295ms step_avg:96.91ms -step:974/1695 train_time:94392ms step_avg:96.91ms -step:975/1695 train_time:94489ms step_avg:96.91ms -step:976/1695 train_time:94585ms step_avg:96.91ms -step:977/1695 train_time:94681ms step_avg:96.91ms -step:978/1695 train_time:94778ms step_avg:96.91ms -step:979/1695 train_time:94874ms step_avg:96.91ms -step:980/1695 train_time:94970ms step_avg:96.91ms -step:981/1695 train_time:95066ms step_avg:96.91ms -step:982/1695 train_time:95161ms step_avg:96.91ms -step:983/1695 train_time:95257ms step_avg:96.90ms -step:984/1695 train_time:95353ms step_avg:96.90ms -step:985/1695 train_time:95450ms step_avg:96.90ms -step:986/1695 train_time:95547ms step_avg:96.90ms -step:987/1695 train_time:95644ms step_avg:96.90ms -step:988/1695 train_time:95739ms step_avg:96.90ms -step:989/1695 train_time:95835ms step_avg:96.90ms -step:990/1695 train_time:95932ms step_avg:96.90ms -step:991/1695 
train_time:96028ms step_avg:96.90ms -step:992/1695 train_time:96123ms step_avg:96.90ms -step:993/1695 train_time:96218ms step_avg:96.90ms -step:994/1695 train_time:96314ms step_avg:96.90ms -step:995/1695 train_time:96411ms step_avg:96.90ms -step:996/1695 train_time:96508ms step_avg:96.90ms -step:997/1695 train_time:96605ms step_avg:96.90ms -step:998/1695 train_time:96701ms step_avg:96.89ms -step:999/1695 train_time:96796ms step_avg:96.89ms -step:1000/1695 train_time:96893ms step_avg:96.89ms -step:1000/1695 val_loss:3.4830 train_time:96988ms step_avg:96.99ms -step:1001/1695 train_time:97014ms step_avg:96.92ms -step:1002/1695 train_time:97096ms step_avg:96.90ms -step:1003/1695 train_time:97195ms step_avg:96.90ms -step:1004/1695 train_time:97291ms step_avg:96.90ms -step:1005/1695 train_time:97387ms step_avg:96.90ms -step:1006/1695 train_time:97482ms step_avg:96.90ms -step:1007/1695 train_time:97577ms step_avg:96.90ms -step:1008/1695 train_time:97673ms step_avg:96.90ms -step:1009/1695 train_time:97768ms step_avg:96.90ms -step:1010/1695 train_time:97862ms step_avg:96.89ms -step:1011/1695 train_time:97959ms step_avg:96.89ms -step:1012/1695 train_time:98058ms step_avg:96.90ms -step:1013/1695 train_time:98157ms step_avg:96.90ms -step:1014/1695 train_time:98256ms step_avg:96.90ms -step:1015/1695 train_time:98354ms step_avg:96.90ms -step:1016/1695 train_time:98450ms step_avg:96.90ms -step:1017/1695 train_time:98545ms step_avg:96.90ms -step:1018/1695 train_time:98640ms step_avg:96.90ms -step:1019/1695 train_time:98736ms step_avg:96.89ms -step:1020/1695 train_time:98831ms step_avg:96.89ms -step:1021/1695 train_time:98927ms step_avg:96.89ms -step:1022/1695 train_time:99025ms step_avg:96.89ms -step:1023/1695 train_time:99121ms step_avg:96.89ms -step:1024/1695 train_time:99218ms step_avg:96.89ms -step:1025/1695 train_time:99316ms step_avg:96.89ms -step:1026/1695 train_time:99413ms step_avg:96.89ms -step:1027/1695 train_time:99510ms step_avg:96.89ms -step:1028/1695 train_time:99605ms step_avg:96.89ms -step:1029/1695 train_time:99700ms step_avg:96.89ms -step:1030/1695 train_time:99795ms step_avg:96.89ms -step:1031/1695 train_time:99891ms step_avg:96.89ms -step:1032/1695 train_time:99987ms step_avg:96.89ms -step:1033/1695 train_time:100085ms step_avg:96.89ms -step:1034/1695 train_time:100181ms step_avg:96.89ms -step:1035/1695 train_time:100277ms step_avg:96.89ms -step:1036/1695 train_time:100616ms step_avg:97.12ms -step:1037/1695 train_time:100781ms step_avg:97.19ms -step:1038/1695 train_time:100876ms step_avg:97.18ms -step:1039/1695 train_time:100971ms step_avg:97.18ms -step:1040/1695 train_time:101066ms step_avg:97.18ms -step:1041/1695 train_time:101161ms step_avg:97.18ms -step:1042/1695 train_time:101256ms step_avg:97.17ms -step:1043/1695 train_time:101350ms step_avg:97.17ms -step:1044/1695 train_time:101444ms step_avg:97.17ms -step:1045/1695 train_time:101539ms step_avg:97.17ms -step:1046/1695 train_time:101641ms step_avg:97.17ms -step:1047/1695 train_time:101740ms step_avg:97.17ms -step:1048/1695 train_time:101838ms step_avg:97.17ms -step:1049/1695 train_time:101934ms step_avg:97.17ms -step:1050/1695 train_time:102031ms step_avg:97.17ms -step:1051/1695 train_time:102126ms step_avg:97.17ms -step:1052/1695 train_time:102220ms step_avg:97.17ms -step:1053/1695 train_time:102315ms step_avg:97.17ms -step:1054/1695 train_time:102411ms step_avg:97.16ms -step:1055/1695 train_time:102506ms step_avg:97.16ms -step:1056/1695 train_time:102603ms step_avg:97.16ms -step:1057/1695 train_time:102700ms step_avg:97.16ms 
-step:1058/1695 train_time:102797ms step_avg:97.16ms -step:1059/1695 train_time:102894ms step_avg:97.16ms -step:1060/1695 train_time:102991ms step_avg:97.16ms -step:1061/1695 train_time:103087ms step_avg:97.16ms -step:1062/1695 train_time:103182ms step_avg:97.16ms -step:1063/1695 train_time:103277ms step_avg:97.16ms -step:1064/1695 train_time:103373ms step_avg:97.16ms -step:1065/1695 train_time:103469ms step_avg:97.15ms -step:1066/1695 train_time:103565ms step_avg:97.15ms -step:1067/1695 train_time:103661ms step_avg:97.15ms -step:1068/1695 train_time:103756ms step_avg:97.15ms -step:1069/1695 train_time:103854ms step_avg:97.15ms -step:1070/1695 train_time:103951ms step_avg:97.15ms -step:1071/1695 train_time:104048ms step_avg:97.15ms -step:1072/1695 train_time:104144ms step_avg:97.15ms -step:1073/1695 train_time:104239ms step_avg:97.15ms -step:1074/1695 train_time:104335ms step_avg:97.15ms -step:1075/1695 train_time:104431ms step_avg:97.14ms -step:1076/1695 train_time:104527ms step_avg:97.14ms -step:1077/1695 train_time:104623ms step_avg:97.14ms -step:1078/1695 train_time:104719ms step_avg:97.14ms -step:1079/1695 train_time:104816ms step_avg:97.14ms -step:1080/1695 train_time:104912ms step_avg:97.14ms -step:1081/1695 train_time:105009ms step_avg:97.14ms -step:1082/1695 train_time:105105ms step_avg:97.14ms -step:1083/1695 train_time:105200ms step_avg:97.14ms -step:1084/1695 train_time:105296ms step_avg:97.14ms -step:1085/1695 train_time:105391ms step_avg:97.13ms -step:1086/1695 train_time:105487ms step_avg:97.13ms -step:1087/1695 train_time:105583ms step_avg:97.13ms -step:1088/1695 train_time:105678ms step_avg:97.13ms -step:1089/1695 train_time:105775ms step_avg:97.13ms -step:1090/1695 train_time:105871ms step_avg:97.13ms -step:1091/1695 train_time:105967ms step_avg:97.13ms -step:1092/1695 train_time:106063ms step_avg:97.13ms -step:1093/1695 train_time:106158ms step_avg:97.13ms -step:1094/1695 train_time:106254ms step_avg:97.12ms -step:1095/1695 train_time:106350ms step_avg:97.12ms -step:1096/1695 train_time:106446ms step_avg:97.12ms -step:1097/1695 train_time:106541ms step_avg:97.12ms -step:1098/1695 train_time:106636ms step_avg:97.12ms -step:1099/1695 train_time:106732ms step_avg:97.12ms -step:1100/1695 train_time:106829ms step_avg:97.12ms -step:1101/1695 train_time:106925ms step_avg:97.12ms -step:1102/1695 train_time:107021ms step_avg:97.12ms -step:1103/1695 train_time:107117ms step_avg:97.11ms -step:1104/1695 train_time:107213ms step_avg:97.11ms -step:1105/1695 train_time:107309ms step_avg:97.11ms -step:1106/1695 train_time:107405ms step_avg:97.11ms -step:1107/1695 train_time:107501ms step_avg:97.11ms -step:1108/1695 train_time:107597ms step_avg:97.11ms -step:1109/1695 train_time:107693ms step_avg:97.11ms -step:1110/1695 train_time:107790ms step_avg:97.11ms -step:1111/1695 train_time:107886ms step_avg:97.11ms -step:1112/1695 train_time:107982ms step_avg:97.11ms -step:1113/1695 train_time:108078ms step_avg:97.11ms -step:1114/1695 train_time:108174ms step_avg:97.10ms -step:1115/1695 train_time:108272ms step_avg:97.10ms -step:1116/1695 train_time:108369ms step_avg:97.10ms -step:1117/1695 train_time:108463ms step_avg:97.10ms -step:1118/1695 train_time:108558ms step_avg:97.10ms -step:1119/1695 train_time:108653ms step_avg:97.10ms -step:1120/1695 train_time:108749ms step_avg:97.10ms -step:1121/1695 train_time:108846ms step_avg:97.10ms -step:1122/1695 train_time:108942ms step_avg:97.10ms -step:1123/1695 train_time:109037ms step_avg:97.09ms -step:1124/1695 train_time:109133ms step_avg:97.09ms 
-step:1125/1695 train_time:109229ms step_avg:97.09ms -step:1125/1695 val_loss:3.4364 train_time:109322ms step_avg:97.18ms -step:1126/1695 train_time:109349ms step_avg:97.11ms -step:1127/1695 train_time:109426ms step_avg:97.10ms -step:1128/1695 train_time:109523ms step_avg:97.09ms -step:1129/1695 train_time:109619ms step_avg:97.09ms -step:1130/1695 train_time:109715ms step_avg:97.09ms -step:1131/1695 train_time:109810ms step_avg:97.09ms -step:1132/1695 train_time:109905ms step_avg:97.09ms -step:1133/1695 train_time:110001ms step_avg:97.09ms -step:1134/1695 train_time:110098ms step_avg:97.09ms -step:1135/1695 train_time:110195ms step_avg:97.09ms -step:1136/1695 train_time:110294ms step_avg:97.09ms -step:1137/1695 train_time:110397ms step_avg:97.09ms -step:1138/1695 train_time:110496ms step_avg:97.10ms -step:1139/1695 train_time:110596ms step_avg:97.10ms -step:1140/1695 train_time:110695ms step_avg:97.10ms -step:1141/1695 train_time:110793ms step_avg:97.10ms -step:1142/1695 train_time:110891ms step_avg:97.10ms -step:1143/1695 train_time:110988ms step_avg:97.10ms -step:1144/1695 train_time:111085ms step_avg:97.10ms -step:1145/1695 train_time:111182ms step_avg:97.10ms -step:1146/1695 train_time:111280ms step_avg:97.10ms -step:1147/1695 train_time:111379ms step_avg:97.10ms -step:1148/1695 train_time:111478ms step_avg:97.11ms -step:1149/1695 train_time:111576ms step_avg:97.11ms -step:1150/1695 train_time:111674ms step_avg:97.11ms -step:1151/1695 train_time:111773ms step_avg:97.11ms -step:1152/1695 train_time:111871ms step_avg:97.11ms -step:1153/1695 train_time:111969ms step_avg:97.11ms -step:1154/1695 train_time:112065ms step_avg:97.11ms -step:1155/1695 train_time:112162ms step_avg:97.11ms -step:1156/1695 train_time:112260ms step_avg:97.11ms -step:1157/1695 train_time:112358ms step_avg:97.11ms -step:1158/1695 train_time:112458ms step_avg:97.11ms -step:1159/1695 train_time:112557ms step_avg:97.12ms -step:1160/1695 train_time:112656ms step_avg:97.12ms -step:1161/1695 train_time:112755ms step_avg:97.12ms -step:1162/1695 train_time:112852ms step_avg:97.12ms -step:1163/1695 train_time:112951ms step_avg:97.12ms -step:1164/1695 train_time:113048ms step_avg:97.12ms -step:1165/1695 train_time:113144ms step_avg:97.12ms -step:1166/1695 train_time:113242ms step_avg:97.12ms -step:1167/1695 train_time:113340ms step_avg:97.12ms -step:1168/1695 train_time:113437ms step_avg:97.12ms -step:1169/1695 train_time:113536ms step_avg:97.12ms -step:1170/1695 train_time:113634ms step_avg:97.12ms -step:1171/1695 train_time:113733ms step_avg:97.12ms -step:1172/1695 train_time:113832ms step_avg:97.13ms -step:1173/1695 train_time:113931ms step_avg:97.13ms -step:1174/1695 train_time:114029ms step_avg:97.13ms -step:1175/1695 train_time:114127ms step_avg:97.13ms -step:1176/1695 train_time:114224ms step_avg:97.13ms -step:1177/1695 train_time:114322ms step_avg:97.13ms -step:1178/1695 train_time:114419ms step_avg:97.13ms -step:1179/1695 train_time:114517ms step_avg:97.13ms -step:1180/1695 train_time:114615ms step_avg:97.13ms -step:1181/1695 train_time:114713ms step_avg:97.13ms -step:1182/1695 train_time:114810ms step_avg:97.13ms -step:1183/1695 train_time:114908ms step_avg:97.13ms -step:1184/1695 train_time:115005ms step_avg:97.13ms -step:1185/1695 train_time:115103ms step_avg:97.13ms -step:1186/1695 train_time:115200ms step_avg:97.13ms -step:1187/1695 train_time:115298ms step_avg:97.13ms -step:1188/1695 train_time:115397ms step_avg:97.14ms -step:1189/1695 train_time:115495ms step_avg:97.14ms -step:1190/1695 train_time:115593ms 
step_avg:97.14ms -step:1191/1695 train_time:115691ms step_avg:97.14ms -step:1192/1695 train_time:115789ms step_avg:97.14ms -step:1193/1695 train_time:115887ms step_avg:97.14ms -step:1194/1695 train_time:115984ms step_avg:97.14ms -step:1195/1695 train_time:116082ms step_avg:97.14ms -step:1196/1695 train_time:116180ms step_avg:97.14ms -step:1197/1695 train_time:116279ms step_avg:97.14ms -step:1198/1695 train_time:116377ms step_avg:97.14ms -step:1199/1695 train_time:116476ms step_avg:97.14ms -step:1200/1695 train_time:116574ms step_avg:97.14ms -step:1201/1695 train_time:116672ms step_avg:97.15ms -step:1202/1695 train_time:116770ms step_avg:97.15ms -step:1203/1695 train_time:116868ms step_avg:97.15ms -step:1204/1695 train_time:116965ms step_avg:97.15ms -step:1205/1695 train_time:117062ms step_avg:97.15ms -step:1206/1695 train_time:117160ms step_avg:97.15ms -step:1207/1695 train_time:117259ms step_avg:97.15ms -step:1208/1695 train_time:117635ms step_avg:97.38ms -step:1209/1695 train_time:117773ms step_avg:97.41ms -step:1210/1695 train_time:117868ms step_avg:97.41ms -step:1211/1695 train_time:117964ms step_avg:97.41ms -step:1212/1695 train_time:118061ms step_avg:97.41ms -step:1213/1695 train_time:118157ms step_avg:97.41ms -step:1214/1695 train_time:118253ms step_avg:97.41ms -step:1215/1695 train_time:118350ms step_avg:97.41ms -step:1216/1695 train_time:118446ms step_avg:97.41ms -step:1217/1695 train_time:118542ms step_avg:97.41ms -step:1218/1695 train_time:118646ms step_avg:97.41ms -step:1219/1695 train_time:118750ms step_avg:97.42ms -step:1220/1695 train_time:118849ms step_avg:97.42ms -step:1221/1695 train_time:118947ms step_avg:97.42ms -step:1222/1695 train_time:119044ms step_avg:97.42ms -step:1223/1695 train_time:119140ms step_avg:97.42ms -step:1224/1695 train_time:119238ms step_avg:97.42ms -step:1225/1695 train_time:119335ms step_avg:97.42ms -step:1226/1695 train_time:119432ms step_avg:97.42ms -step:1227/1695 train_time:119529ms step_avg:97.42ms -step:1228/1695 train_time:119626ms step_avg:97.42ms -step:1229/1695 train_time:119726ms step_avg:97.42ms -step:1230/1695 train_time:119824ms step_avg:97.42ms -step:1231/1695 train_time:119923ms step_avg:97.42ms -step:1232/1695 train_time:120021ms step_avg:97.42ms -step:1233/1695 train_time:120119ms step_avg:97.42ms -step:1234/1695 train_time:120216ms step_avg:97.42ms -step:1235/1695 train_time:120313ms step_avg:97.42ms -step:1236/1695 train_time:120410ms step_avg:97.42ms -step:1237/1695 train_time:120506ms step_avg:97.42ms -step:1238/1695 train_time:120604ms step_avg:97.42ms -step:1239/1695 train_time:120703ms step_avg:97.42ms -step:1240/1695 train_time:120801ms step_avg:97.42ms -step:1241/1695 train_time:120900ms step_avg:97.42ms -step:1242/1695 train_time:120998ms step_avg:97.42ms -step:1243/1695 train_time:121096ms step_avg:97.42ms -step:1244/1695 train_time:121194ms step_avg:97.42ms -step:1245/1695 train_time:121292ms step_avg:97.42ms -step:1246/1695 train_time:121388ms step_avg:97.42ms -step:1247/1695 train_time:121485ms step_avg:97.42ms -step:1248/1695 train_time:121583ms step_avg:97.42ms -step:1249/1695 train_time:121682ms step_avg:97.42ms -step:1250/1695 train_time:121780ms step_avg:97.42ms -step:1250/1695 val_loss:3.3889 train_time:121876ms step_avg:97.50ms -step:1251/1695 train_time:121915ms step_avg:97.45ms -step:1252/1695 train_time:121984ms step_avg:97.43ms -step:1253/1695 train_time:122082ms step_avg:97.43ms -step:1254/1695 train_time:122178ms step_avg:97.43ms -step:1255/1695 train_time:122274ms step_avg:97.43ms -step:1256/1695 
train_time:122371ms step_avg:97.43ms -step:1257/1695 train_time:122467ms step_avg:97.43ms -step:1258/1695 train_time:122564ms step_avg:97.43ms -step:1259/1695 train_time:122660ms step_avg:97.43ms -step:1260/1695 train_time:122759ms step_avg:97.43ms -step:1261/1695 train_time:122864ms step_avg:97.43ms -step:1262/1695 train_time:122963ms step_avg:97.43ms -step:1263/1695 train_time:123061ms step_avg:97.44ms -step:1264/1695 train_time:123158ms step_avg:97.43ms -step:1265/1695 train_time:123255ms step_avg:97.43ms -step:1266/1695 train_time:123352ms step_avg:97.43ms -step:1267/1695 train_time:123449ms step_avg:97.43ms -step:1268/1695 train_time:123546ms step_avg:97.43ms -step:1269/1695 train_time:123642ms step_avg:97.43ms -step:1270/1695 train_time:123739ms step_avg:97.43ms -step:1271/1695 train_time:123838ms step_avg:97.43ms -step:1272/1695 train_time:123937ms step_avg:97.43ms -step:1273/1695 train_time:124036ms step_avg:97.44ms -step:1274/1695 train_time:124135ms step_avg:97.44ms -step:1275/1695 train_time:124233ms step_avg:97.44ms -step:1276/1695 train_time:124329ms step_avg:97.44ms -step:1277/1695 train_time:124426ms step_avg:97.44ms -step:1278/1695 train_time:124524ms step_avg:97.44ms -step:1279/1695 train_time:124619ms step_avg:97.44ms -step:1280/1695 train_time:124717ms step_avg:97.43ms -step:1281/1695 train_time:124815ms step_avg:97.44ms -step:1282/1695 train_time:124914ms step_avg:97.44ms -step:1283/1695 train_time:125012ms step_avg:97.44ms -step:1284/1695 train_time:125111ms step_avg:97.44ms -step:1285/1695 train_time:125208ms step_avg:97.44ms -step:1286/1695 train_time:125306ms step_avg:97.44ms -step:1287/1695 train_time:125403ms step_avg:97.44ms -step:1288/1695 train_time:125499ms step_avg:97.44ms -step:1289/1695 train_time:125597ms step_avg:97.44ms -step:1290/1695 train_time:125693ms step_avg:97.44ms -step:1291/1695 train_time:125793ms step_avg:97.44ms -step:1292/1695 train_time:125892ms step_avg:97.44ms -step:1293/1695 train_time:125993ms step_avg:97.44ms -step:1294/1695 train_time:126092ms step_avg:97.44ms -step:1295/1695 train_time:126191ms step_avg:97.44ms -step:1296/1695 train_time:126289ms step_avg:97.45ms -step:1297/1695 train_time:126387ms step_avg:97.45ms -step:1298/1695 train_time:126485ms step_avg:97.45ms -step:1299/1695 train_time:126581ms step_avg:97.45ms -step:1300/1695 train_time:126679ms step_avg:97.45ms -step:1301/1695 train_time:126776ms step_avg:97.44ms -step:1302/1695 train_time:126874ms step_avg:97.45ms -step:1303/1695 train_time:126975ms step_avg:97.45ms -step:1304/1695 train_time:127073ms step_avg:97.45ms -step:1305/1695 train_time:127173ms step_avg:97.45ms -step:1306/1695 train_time:127270ms step_avg:97.45ms -step:1307/1695 train_time:127369ms step_avg:97.45ms -step:1308/1695 train_time:127467ms step_avg:97.45ms -step:1309/1695 train_time:127564ms step_avg:97.45ms -step:1310/1695 train_time:127661ms step_avg:97.45ms -step:1311/1695 train_time:127759ms step_avg:97.45ms -step:1312/1695 train_time:127855ms step_avg:97.45ms -step:1313/1695 train_time:127953ms step_avg:97.45ms -step:1314/1695 train_time:128051ms step_avg:97.45ms -step:1315/1695 train_time:128149ms step_avg:97.45ms -step:1316/1695 train_time:128246ms step_avg:97.45ms -step:1317/1695 train_time:128343ms step_avg:97.45ms -step:1318/1695 train_time:128440ms step_avg:97.45ms -step:1319/1695 train_time:128538ms step_avg:97.45ms -step:1320/1695 train_time:128636ms step_avg:97.45ms -step:1321/1695 train_time:128734ms step_avg:97.45ms -step:1322/1695 train_time:128833ms step_avg:97.45ms -step:1323/1695 
train_time:128932ms step_avg:97.45ms -step:1324/1695 train_time:129029ms step_avg:97.45ms -step:1325/1695 train_time:129127ms step_avg:97.45ms -step:1326/1695 train_time:129225ms step_avg:97.45ms -step:1327/1695 train_time:129322ms step_avg:97.45ms -step:1328/1695 train_time:129419ms step_avg:97.45ms -step:1329/1695 train_time:129517ms step_avg:97.45ms -step:1330/1695 train_time:129615ms step_avg:97.45ms -step:1331/1695 train_time:129713ms step_avg:97.46ms -step:1332/1695 train_time:129811ms step_avg:97.46ms -step:1333/1695 train_time:129910ms step_avg:97.46ms -step:1334/1695 train_time:130007ms step_avg:97.46ms -step:1335/1695 train_time:130105ms step_avg:97.46ms -step:1336/1695 train_time:130204ms step_avg:97.46ms -step:1337/1695 train_time:130300ms step_avg:97.46ms -step:1338/1695 train_time:130398ms step_avg:97.46ms -step:1339/1695 train_time:130495ms step_avg:97.46ms -step:1340/1695 train_time:130593ms step_avg:97.46ms -step:1341/1695 train_time:130692ms step_avg:97.46ms -step:1342/1695 train_time:130790ms step_avg:97.46ms -step:1343/1695 train_time:130888ms step_avg:97.46ms -step:1344/1695 train_time:130986ms step_avg:97.46ms -step:1345/1695 train_time:131084ms step_avg:97.46ms -step:1346/1695 train_time:131182ms step_avg:97.46ms -step:1347/1695 train_time:131280ms step_avg:97.46ms -step:1348/1695 train_time:131377ms step_avg:97.46ms -step:1349/1695 train_time:131475ms step_avg:97.46ms -step:1350/1695 train_time:131572ms step_avg:97.46ms -step:1351/1695 train_time:131672ms step_avg:97.46ms -step:1352/1695 train_time:131770ms step_avg:97.46ms -step:1353/1695 train_time:131869ms step_avg:97.46ms -step:1354/1695 train_time:131966ms step_avg:97.46ms -step:1355/1695 train_time:132064ms step_avg:97.46ms -step:1356/1695 train_time:132161ms step_avg:97.46ms -step:1357/1695 train_time:132258ms step_avg:97.46ms -step:1358/1695 train_time:132356ms step_avg:97.46ms -step:1359/1695 train_time:132454ms step_avg:97.46ms -step:1360/1695 train_time:132552ms step_avg:97.46ms -step:1361/1695 train_time:132650ms step_avg:97.46ms -step:1362/1695 train_time:132748ms step_avg:97.47ms -step:1363/1695 train_time:132845ms step_avg:97.47ms -step:1364/1695 train_time:132942ms step_avg:97.46ms -step:1365/1695 train_time:133038ms step_avg:97.46ms -step:1366/1695 train_time:133136ms step_avg:97.46ms -step:1367/1695 train_time:133235ms step_avg:97.46ms -step:1368/1695 train_time:133333ms step_avg:97.47ms -step:1369/1695 train_time:133430ms step_avg:97.47ms -step:1370/1695 train_time:133529ms step_avg:97.47ms -step:1371/1695 train_time:133630ms step_avg:97.47ms -step:1372/1695 train_time:133725ms step_avg:97.47ms -step:1373/1695 train_time:133823ms step_avg:97.47ms -step:1374/1695 train_time:133921ms step_avg:97.47ms -step:1375/1695 train_time:134023ms step_avg:97.47ms -step:1375/1695 val_loss:3.3505 train_time:134113ms step_avg:97.54ms -step:1376/1695 train_time:134162ms step_avg:97.50ms -step:1377/1695 train_time:134219ms step_avg:97.47ms -step:1378/1695 train_time:134319ms step_avg:97.47ms -step:1379/1695 train_time:134416ms step_avg:97.47ms -step:1380/1695 train_time:134512ms step_avg:97.47ms -step:1381/1695 train_time:134884ms step_avg:97.67ms -step:1382/1695 train_time:135043ms step_avg:97.72ms -step:1383/1695 train_time:135139ms step_avg:97.71ms -step:1384/1695 train_time:135235ms step_avg:97.71ms -step:1385/1695 train_time:135332ms step_avg:97.71ms -step:1386/1695 train_time:135428ms step_avg:97.71ms -step:1387/1695 train_time:135526ms step_avg:97.71ms -step:1388/1695 train_time:135621ms step_avg:97.71ms 
-step:1389/1695 train_time:135716ms step_avg:97.71ms -step:1390/1695 train_time:135813ms step_avg:97.71ms -step:1391/1695 train_time:135916ms step_avg:97.71ms -step:1392/1695 train_time:136020ms step_avg:97.72ms -step:1393/1695 train_time:136121ms step_avg:97.72ms -step:1394/1695 train_time:136218ms step_avg:97.72ms -step:1395/1695 train_time:136315ms step_avg:97.72ms -step:1396/1695 train_time:136412ms step_avg:97.72ms -step:1397/1695 train_time:136509ms step_avg:97.72ms -step:1398/1695 train_time:136607ms step_avg:97.72ms -step:1399/1695 train_time:136704ms step_avg:97.72ms -step:1400/1695 train_time:136801ms step_avg:97.71ms -step:1401/1695 train_time:136899ms step_avg:97.72ms -step:1402/1695 train_time:136997ms step_avg:97.72ms -step:1403/1695 train_time:137096ms step_avg:97.72ms -step:1404/1695 train_time:137194ms step_avg:97.72ms -step:1405/1695 train_time:137293ms step_avg:97.72ms -step:1406/1695 train_time:137390ms step_avg:97.72ms -step:1407/1695 train_time:137487ms step_avg:97.72ms -step:1408/1695 train_time:137584ms step_avg:97.72ms -step:1409/1695 train_time:137681ms step_avg:97.72ms -step:1410/1695 train_time:137778ms step_avg:97.72ms -step:1411/1695 train_time:137876ms step_avg:97.72ms -step:1412/1695 train_time:137975ms step_avg:97.72ms -step:1413/1695 train_time:138075ms step_avg:97.72ms -step:1414/1695 train_time:138173ms step_avg:97.72ms -step:1415/1695 train_time:138272ms step_avg:97.72ms -step:1416/1695 train_time:138368ms step_avg:97.72ms -step:1417/1695 train_time:138466ms step_avg:97.72ms -step:1418/1695 train_time:138563ms step_avg:97.72ms -step:1419/1695 train_time:138661ms step_avg:97.72ms -step:1420/1695 train_time:138757ms step_avg:97.72ms -step:1421/1695 train_time:138854ms step_avg:97.72ms -step:1422/1695 train_time:138953ms step_avg:97.72ms -step:1423/1695 train_time:139053ms step_avg:97.72ms -step:1424/1695 train_time:139153ms step_avg:97.72ms -step:1425/1695 train_time:139252ms step_avg:97.72ms -step:1426/1695 train_time:139350ms step_avg:97.72ms -step:1427/1695 train_time:139448ms step_avg:97.72ms -step:1428/1695 train_time:139546ms step_avg:97.72ms -step:1429/1695 train_time:139644ms step_avg:97.72ms -step:1430/1695 train_time:139742ms step_avg:97.72ms -step:1431/1695 train_time:139839ms step_avg:97.72ms -step:1432/1695 train_time:139938ms step_avg:97.72ms -step:1433/1695 train_time:140036ms step_avg:97.72ms -step:1434/1695 train_time:140134ms step_avg:97.72ms -step:1435/1695 train_time:140232ms step_avg:97.72ms -step:1436/1695 train_time:140330ms step_avg:97.72ms -step:1437/1695 train_time:140428ms step_avg:97.72ms -step:1438/1695 train_time:140526ms step_avg:97.72ms -step:1439/1695 train_time:140623ms step_avg:97.72ms -step:1440/1695 train_time:140720ms step_avg:97.72ms -step:1441/1695 train_time:140818ms step_avg:97.72ms -step:1442/1695 train_time:140917ms step_avg:97.72ms -step:1443/1695 train_time:141014ms step_avg:97.72ms -step:1444/1695 train_time:141113ms step_avg:97.72ms -step:1445/1695 train_time:141213ms step_avg:97.73ms -step:1446/1695 train_time:141312ms step_avg:97.73ms -step:1447/1695 train_time:141411ms step_avg:97.73ms -step:1448/1695 train_time:141508ms step_avg:97.73ms -step:1449/1695 train_time:141607ms step_avg:97.73ms -step:1450/1695 train_time:141706ms step_avg:97.73ms -step:1451/1695 train_time:141804ms step_avg:97.73ms -step:1452/1695 train_time:141902ms step_avg:97.73ms -step:1453/1695 train_time:141999ms step_avg:97.73ms -step:1454/1695 train_time:142097ms step_avg:97.73ms -step:1455/1695 train_time:142194ms step_avg:97.73ms 
-step:1456/1695 train_time:142292ms step_avg:97.73ms -step:1457/1695 train_time:142390ms step_avg:97.73ms -step:1458/1695 train_time:142488ms step_avg:97.73ms -step:1459/1695 train_time:142588ms step_avg:97.73ms -step:1460/1695 train_time:142685ms step_avg:97.73ms -step:1461/1695 train_time:142784ms step_avg:97.73ms -step:1462/1695 train_time:142881ms step_avg:97.73ms -step:1463/1695 train_time:142978ms step_avg:97.73ms -step:1464/1695 train_time:143076ms step_avg:97.73ms -step:1465/1695 train_time:143173ms step_avg:97.73ms -step:1466/1695 train_time:143272ms step_avg:97.73ms -step:1467/1695 train_time:143369ms step_avg:97.73ms -step:1468/1695 train_time:143467ms step_avg:97.73ms -step:1469/1695 train_time:143564ms step_avg:97.73ms -step:1470/1695 train_time:143661ms step_avg:97.73ms -step:1471/1695 train_time:143759ms step_avg:97.73ms -step:1472/1695 train_time:143856ms step_avg:97.73ms -step:1473/1695 train_time:143956ms step_avg:97.73ms -step:1474/1695 train_time:144053ms step_avg:97.73ms -step:1475/1695 train_time:144149ms step_avg:97.73ms -step:1476/1695 train_time:144248ms step_avg:97.73ms -step:1477/1695 train_time:144346ms step_avg:97.73ms -step:1478/1695 train_time:144445ms step_avg:97.73ms -step:1479/1695 train_time:144541ms step_avg:97.73ms -step:1480/1695 train_time:144639ms step_avg:97.73ms -step:1481/1695 train_time:144737ms step_avg:97.73ms -step:1482/1695 train_time:144835ms step_avg:97.73ms -step:1483/1695 train_time:144933ms step_avg:97.73ms -step:1484/1695 train_time:145031ms step_avg:97.73ms -step:1485/1695 train_time:145128ms step_avg:97.73ms -step:1486/1695 train_time:145225ms step_avg:97.73ms -step:1487/1695 train_time:145323ms step_avg:97.73ms -step:1488/1695 train_time:145421ms step_avg:97.73ms -step:1489/1695 train_time:145519ms step_avg:97.73ms -step:1490/1695 train_time:145616ms step_avg:97.73ms -step:1491/1695 train_time:145714ms step_avg:97.73ms -step:1492/1695 train_time:145813ms step_avg:97.73ms -step:1493/1695 train_time:145910ms step_avg:97.73ms -step:1494/1695 train_time:146008ms step_avg:97.73ms -step:1495/1695 train_time:146105ms step_avg:97.73ms -step:1496/1695 train_time:146203ms step_avg:97.73ms -step:1497/1695 train_time:146299ms step_avg:97.73ms -step:1498/1695 train_time:146396ms step_avg:97.73ms -step:1499/1695 train_time:146495ms step_avg:97.73ms -step:1500/1695 train_time:146593ms step_avg:97.73ms -step:1500/1695 val_loss:3.3176 train_time:146690ms step_avg:97.79ms -step:1501/1695 train_time:146740ms step_avg:97.76ms -step:1502/1695 train_time:146800ms step_avg:97.74ms -step:1503/1695 train_time:146898ms step_avg:97.74ms -step:1504/1695 train_time:146995ms step_avg:97.74ms -step:1505/1695 train_time:147094ms step_avg:97.74ms -step:1506/1695 train_time:147190ms step_avg:97.74ms -step:1507/1695 train_time:147287ms step_avg:97.74ms -step:1508/1695 train_time:147383ms step_avg:97.73ms -step:1509/1695 train_time:147481ms step_avg:97.73ms -step:1510/1695 train_time:147577ms step_avg:97.73ms -step:1511/1695 train_time:147677ms step_avg:97.73ms -step:1512/1695 train_time:147778ms step_avg:97.74ms -step:1513/1695 train_time:147878ms step_avg:97.74ms -step:1514/1695 train_time:147975ms step_avg:97.74ms -step:1515/1695 train_time:148073ms step_avg:97.74ms -step:1516/1695 train_time:148170ms step_avg:97.74ms -step:1517/1695 train_time:148269ms step_avg:97.74ms -step:1518/1695 train_time:148367ms step_avg:97.74ms -step:1519/1695 train_time:148463ms step_avg:97.74ms -step:1520/1695 train_time:148560ms step_avg:97.74ms -step:1521/1695 train_time:148658ms 
step_avg:97.74ms -step:1522/1695 train_time:148758ms step_avg:97.74ms -step:1523/1695 train_time:148857ms step_avg:97.74ms -step:1524/1695 train_time:148956ms step_avg:97.74ms -step:1525/1695 train_time:149055ms step_avg:97.74ms -step:1526/1695 train_time:149152ms step_avg:97.74ms -step:1527/1695 train_time:149250ms step_avg:97.74ms -step:1528/1695 train_time:149347ms step_avg:97.74ms -step:1529/1695 train_time:149444ms step_avg:97.74ms -step:1530/1695 train_time:149541ms step_avg:97.74ms -step:1531/1695 train_time:149638ms step_avg:97.74ms -step:1532/1695 train_time:149735ms step_avg:97.74ms -step:1533/1695 train_time:149835ms step_avg:97.74ms -step:1534/1695 train_time:149933ms step_avg:97.74ms -step:1535/1695 train_time:150032ms step_avg:97.74ms -step:1536/1695 train_time:150129ms step_avg:97.74ms -step:1537/1695 train_time:150227ms step_avg:97.74ms -step:1538/1695 train_time:150324ms step_avg:97.74ms -step:1539/1695 train_time:150421ms step_avg:97.74ms -step:1540/1695 train_time:150519ms step_avg:97.74ms -step:1541/1695 train_time:150616ms step_avg:97.74ms -step:1542/1695 train_time:150715ms step_avg:97.74ms -step:1543/1695 train_time:150813ms step_avg:97.74ms -step:1544/1695 train_time:150912ms step_avg:97.74ms -step:1545/1695 train_time:151010ms step_avg:97.74ms -step:1546/1695 train_time:151108ms step_avg:97.74ms -step:1547/1695 train_time:151206ms step_avg:97.74ms -step:1548/1695 train_time:151304ms step_avg:97.74ms -step:1549/1695 train_time:151401ms step_avg:97.74ms -step:1550/1695 train_time:151498ms step_avg:97.74ms -step:1551/1695 train_time:151596ms step_avg:97.74ms -step:1552/1695 train_time:151941ms step_avg:97.90ms -step:1553/1695 train_time:152117ms step_avg:97.95ms -step:1554/1695 train_time:152212ms step_avg:97.95ms -step:1555/1695 train_time:152308ms step_avg:97.95ms -step:1556/1695 train_time:152403ms step_avg:97.95ms -step:1557/1695 train_time:152499ms step_avg:97.94ms -step:1558/1695 train_time:152596ms step_avg:97.94ms -step:1559/1695 train_time:152693ms step_avg:97.94ms -step:1560/1695 train_time:152790ms step_avg:97.94ms -step:1561/1695 train_time:152886ms step_avg:97.94ms -step:1562/1695 train_time:152987ms step_avg:97.94ms -step:1563/1695 train_time:153089ms step_avg:97.95ms -step:1564/1695 train_time:153189ms step_avg:97.95ms -step:1565/1695 train_time:153285ms step_avg:97.95ms -step:1566/1695 train_time:153382ms step_avg:97.95ms -step:1567/1695 train_time:153479ms step_avg:97.94ms -step:1568/1695 train_time:153576ms step_avg:97.94ms -step:1569/1695 train_time:153673ms step_avg:97.94ms -step:1570/1695 train_time:153769ms step_avg:97.94ms -step:1571/1695 train_time:153866ms step_avg:97.94ms -step:1572/1695 train_time:153964ms step_avg:97.94ms -step:1573/1695 train_time:154064ms step_avg:97.94ms -step:1574/1695 train_time:154164ms step_avg:97.94ms -step:1575/1695 train_time:154262ms step_avg:97.94ms -step:1576/1695 train_time:154360ms step_avg:97.94ms -step:1577/1695 train_time:154457ms step_avg:97.94ms -step:1578/1695 train_time:154554ms step_avg:97.94ms -step:1579/1695 train_time:154651ms step_avg:97.94ms -step:1580/1695 train_time:154748ms step_avg:97.94ms -step:1581/1695 train_time:154845ms step_avg:97.94ms -step:1582/1695 train_time:154943ms step_avg:97.94ms -step:1583/1695 train_time:155042ms step_avg:97.94ms -step:1584/1695 train_time:155142ms step_avg:97.94ms -step:1585/1695 train_time:155239ms step_avg:97.94ms -step:1586/1695 train_time:155337ms step_avg:97.94ms -step:1587/1695 train_time:155435ms step_avg:97.94ms -step:1588/1695 train_time:155532ms 
step_avg:97.94ms -step:1589/1695 train_time:155630ms step_avg:97.94ms -step:1590/1695 train_time:155726ms step_avg:97.94ms -step:1591/1695 train_time:155823ms step_avg:97.94ms -step:1592/1695 train_time:155921ms step_avg:97.94ms -step:1593/1695 train_time:156019ms step_avg:97.94ms -step:1594/1695 train_time:156119ms step_avg:97.94ms -step:1595/1695 train_time:156218ms step_avg:97.94ms -step:1596/1695 train_time:156317ms step_avg:97.94ms -step:1597/1695 train_time:156416ms step_avg:97.94ms -step:1598/1695 train_time:156513ms step_avg:97.94ms -step:1599/1695 train_time:156610ms step_avg:97.94ms -step:1600/1695 train_time:156707ms step_avg:97.94ms -step:1601/1695 train_time:156804ms step_avg:97.94ms -step:1602/1695 train_time:156901ms step_avg:97.94ms -step:1603/1695 train_time:156999ms step_avg:97.94ms -step:1604/1695 train_time:157097ms step_avg:97.94ms -step:1605/1695 train_time:157196ms step_avg:97.94ms -step:1606/1695 train_time:157295ms step_avg:97.94ms -step:1607/1695 train_time:157393ms step_avg:97.94ms -step:1608/1695 train_time:157491ms step_avg:97.94ms -step:1609/1695 train_time:157590ms step_avg:97.94ms -step:1610/1695 train_time:157687ms step_avg:97.94ms -step:1611/1695 train_time:157784ms step_avg:97.94ms -step:1612/1695 train_time:157881ms step_avg:97.94ms -step:1613/1695 train_time:157978ms step_avg:97.94ms -step:1614/1695 train_time:158077ms step_avg:97.94ms -step:1615/1695 train_time:158176ms step_avg:97.94ms -step:1616/1695 train_time:158274ms step_avg:97.94ms -step:1617/1695 train_time:158373ms step_avg:97.94ms -step:1618/1695 train_time:158471ms step_avg:97.94ms -step:1619/1695 train_time:158569ms step_avg:97.94ms -step:1620/1695 train_time:158666ms step_avg:97.94ms -step:1621/1695 train_time:158765ms step_avg:97.94ms -step:1622/1695 train_time:158861ms step_avg:97.94ms -step:1623/1695 train_time:158958ms step_avg:97.94ms -step:1624/1695 train_time:159057ms step_avg:97.94ms -step:1625/1695 train_time:159155ms step_avg:97.94ms -step:1625/1695 val_loss:3.2898 train_time:159251ms step_avg:98.00ms -step:1626/1695 train_time:159278ms step_avg:97.96ms -step:1627/1695 train_time:159359ms step_avg:97.95ms -step:1628/1695 train_time:159459ms step_avg:97.95ms -step:1629/1695 train_time:159560ms step_avg:97.95ms -step:1630/1695 train_time:159658ms step_avg:97.95ms -step:1631/1695 train_time:159755ms step_avg:97.95ms -step:1632/1695 train_time:159852ms step_avg:97.95ms -step:1633/1695 train_time:159950ms step_avg:97.95ms -step:1634/1695 train_time:160046ms step_avg:97.95ms -step:1635/1695 train_time:160142ms step_avg:97.95ms -step:1636/1695 train_time:160242ms step_avg:97.95ms -step:1637/1695 train_time:160342ms step_avg:97.95ms -step:1638/1695 train_time:160441ms step_avg:97.95ms -step:1639/1695 train_time:160539ms step_avg:97.95ms -step:1640/1695 train_time:160639ms step_avg:97.95ms -step:1641/1695 train_time:160737ms step_avg:97.95ms -step:1642/1695 train_time:160835ms step_avg:97.95ms -step:1643/1695 train_time:160933ms step_avg:97.95ms -step:1644/1695 train_time:161031ms step_avg:97.95ms -step:1645/1695 train_time:161128ms step_avg:97.95ms -step:1646/1695 train_time:161225ms step_avg:97.95ms -step:1647/1695 train_time:161323ms step_avg:97.95ms -step:1648/1695 train_time:161421ms step_avg:97.95ms -step:1649/1695 train_time:161519ms step_avg:97.95ms -step:1650/1695 train_time:161618ms step_avg:97.95ms -step:1651/1695 train_time:161717ms step_avg:97.95ms -step:1652/1695 train_time:161816ms step_avg:97.95ms -step:1653/1695 train_time:161914ms step_avg:97.95ms -step:1654/1695 
train_time:162010ms step_avg:97.95ms -step:1655/1695 train_time:162107ms step_avg:97.95ms -step:1656/1695 train_time:162205ms step_avg:97.95ms -step:1657/1695 train_time:162302ms step_avg:97.95ms -step:1658/1695 train_time:162399ms step_avg:97.95ms -step:1659/1695 train_time:162496ms step_avg:97.95ms -step:1660/1695 train_time:162595ms step_avg:97.95ms -step:1661/1695 train_time:162693ms step_avg:97.95ms -step:1662/1695 train_time:162791ms step_avg:97.95ms -step:1663/1695 train_time:162888ms step_avg:97.95ms -step:1664/1695 train_time:162985ms step_avg:97.95ms -step:1665/1695 train_time:163083ms step_avg:97.95ms -step:1666/1695 train_time:163181ms step_avg:97.95ms -step:1667/1695 train_time:163280ms step_avg:97.95ms -step:1668/1695 train_time:163379ms step_avg:97.95ms -step:1669/1695 train_time:163476ms step_avg:97.95ms -step:1670/1695 train_time:163574ms step_avg:97.95ms -step:1671/1695 train_time:163672ms step_avg:97.95ms -step:1672/1695 train_time:163769ms step_avg:97.95ms -step:1673/1695 train_time:163867ms step_avg:97.95ms -step:1674/1695 train_time:163964ms step_avg:97.95ms -step:1675/1695 train_time:164062ms step_avg:97.95ms -step:1676/1695 train_time:164160ms step_avg:97.95ms -step:1677/1695 train_time:164258ms step_avg:97.95ms -step:1678/1695 train_time:164359ms step_avg:97.95ms -step:1679/1695 train_time:164455ms step_avg:97.95ms -step:1680/1695 train_time:164553ms step_avg:97.95ms -step:1681/1695 train_time:164651ms step_avg:97.95ms -step:1682/1695 train_time:164750ms step_avg:97.95ms -step:1683/1695 train_time:164849ms step_avg:97.95ms -step:1684/1695 train_time:164947ms step_avg:97.95ms -step:1685/1695 train_time:165044ms step_avg:97.95ms -step:1686/1695 train_time:165141ms step_avg:97.95ms -step:1687/1695 train_time:165239ms step_avg:97.95ms -step:1688/1695 train_time:165336ms step_avg:97.95ms -step:1689/1695 train_time:165435ms step_avg:97.95ms -step:1690/1695 train_time:165533ms step_avg:97.95ms -step:1691/1695 train_time:165631ms step_avg:97.95ms -step:1692/1695 train_time:165728ms step_avg:97.95ms -step:1693/1695 train_time:165825ms step_avg:97.95ms -step:1694/1695 train_time:165922ms step_avg:97.95ms -step:1695/1695 train_time:166021ms step_avg:97.95ms -step:1695/1695 val_loss:3.2782 train_time:166117ms step_avg:98.00ms -peak memory allocated: 34001 MiB reserved: 49716 MiB diff --git a/records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt b/records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt new file mode 100644 index 000000000..dbcf68147 --- /dev/null +++ b/records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
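+    # Worked example of the mapping below (illustrative numbers, not from the original source):
+    # with num_pid_m = num_pid_n = 4, pid = 21 yields batch_idx = 21 // 16 = 1 and a
+    # local pid of 5; tl.swizzle2d then remaps the resulting (pid_m, pid_n) within the
+    # 4x4 block grid in groups of GROUP_SIZE_M rows to improve L2 reuse between programs.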
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
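+        # The sharded step below proceeds in three overlapped phases: (1) async
+        # reduce_scatter hands each rank the averaged gradient for the params it owns
+        # (params of the same shape are dealt round-robin across ranks), (2) the owner
+        # rank alone applies weight decay, momentum, and Newton-Schulz orthogonalization
+        # to its shard, and (3) async all_gather replicates the updated params back to
+        # every rank so the model stays in sync.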
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            # greedily pack whole documents (truncated to max_seq_len) until this rank's token budget is filled
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                # a sequence ends at the next document boundary, at max_seq_len, or when the remaining budget (+1 for the target offset) runs out
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:58:00 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 29C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 32C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 28C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 51945 C /usr/bin/python 0MiB | +| 0 N/A N/A 51946 C /usr/bin/python 0MiB | +| 0 N/A N/A 51947 C /usr/bin/python 0MiB | +| 0 N/A N/A 51948 C /usr/bin/python 0MiB | +| 0 N/A N/A 51949 C /usr/bin/python 0MiB | +| 0 N/A N/A 51950 C /usr/bin/python 0MiB | +| 0 N/A N/A 51951 C /usr/bin/python 0MiB | +| 0 N/A N/A 51952 C /usr/bin/python 0MiB | +| 1 N/A N/A 51946 C /usr/bin/python 0MiB | +| 2 N/A N/A 51947 C /usr/bin/python 0MiB | +| 3 N/A N/A 51948 C /usr/bin/python 0MiB | +| 4 N/A N/A 51949 C /usr/bin/python 0MiB | +| 5 N/A N/A 51950 C /usr/bin/python 0MiB | +| 6 N/A N/A 51951 C /usr/bin/python 0MiB | +| 7 N/A N/A 51952 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:473ms step_avg:472.96ms +step:2/1670 train_time:494ms step_avg:246.81ms +step:3/1670 train_time:568ms step_avg:189.19ms +step:4/1670 train_time:661ms step_avg:165.29ms +step:5/1670 train_time:755ms step_avg:151.05ms +step:6/1670 train_time:849ms step_avg:141.57ms +step:7/1670 train_time:944ms step_avg:134.90ms +step:8/1670 train_time:1039ms step_avg:129.87ms +step:9/1670 train_time:1134ms 
step_avg:125.95ms +step:10/1670 train_time:1228ms step_avg:122.84ms +step:11/1670 train_time:1323ms step_avg:120.30ms +step:12/1670 train_time:1421ms step_avg:118.43ms +step:13/1670 train_time:1521ms step_avg:117.02ms +step:14/1670 train_time:1618ms step_avg:115.58ms +step:15/1670 train_time:1713ms step_avg:114.21ms +step:16/1670 train_time:1808ms step_avg:113.03ms +step:17/1670 train_time:1903ms step_avg:111.96ms +step:18/1670 train_time:1998ms step_avg:111.01ms +step:19/1670 train_time:2093ms step_avg:110.16ms +step:20/1670 train_time:2188ms step_avg:109.41ms +step:21/1670 train_time:2284ms step_avg:108.74ms +step:22/1670 train_time:2381ms step_avg:108.21ms +step:23/1670 train_time:2477ms step_avg:107.71ms +step:24/1670 train_time:2574ms step_avg:107.26ms +step:25/1670 train_time:2671ms step_avg:106.83ms +step:26/1670 train_time:2766ms step_avg:106.40ms +step:27/1670 train_time:2862ms step_avg:106.01ms +step:28/1670 train_time:2957ms step_avg:105.61ms +step:29/1670 train_time:3052ms step_avg:105.24ms +step:30/1670 train_time:3146ms step_avg:104.88ms +step:31/1670 train_time:3241ms step_avg:104.56ms +step:32/1670 train_time:3337ms step_avg:104.28ms +step:33/1670 train_time:3433ms step_avg:104.04ms +step:34/1670 train_time:3530ms step_avg:103.84ms +step:35/1670 train_time:3628ms step_avg:103.66ms +step:36/1670 train_time:3724ms step_avg:103.45ms +step:37/1670 train_time:3819ms step_avg:103.23ms +step:38/1670 train_time:3914ms step_avg:103.01ms +step:39/1670 train_time:4010ms step_avg:102.81ms +step:40/1670 train_time:4105ms step_avg:102.63ms +step:41/1670 train_time:4200ms step_avg:102.45ms +step:42/1670 train_time:4296ms step_avg:102.28ms +step:43/1670 train_time:4391ms step_avg:102.11ms +step:44/1670 train_time:4487ms step_avg:101.98ms +step:45/1670 train_time:4583ms step_avg:101.84ms +step:46/1670 train_time:4679ms step_avg:101.72ms +step:47/1670 train_time:4775ms step_avg:101.59ms +step:48/1670 train_time:4871ms step_avg:101.48ms +step:49/1670 train_time:4967ms step_avg:101.36ms +step:50/1670 train_time:5062ms step_avg:101.24ms +step:51/1670 train_time:5157ms step_avg:101.12ms +step:52/1670 train_time:5252ms step_avg:101.00ms +step:53/1670 train_time:5347ms step_avg:100.89ms +step:54/1670 train_time:5443ms step_avg:100.80ms +step:55/1670 train_time:5539ms step_avg:100.71ms +step:56/1670 train_time:5635ms step_avg:100.63ms +step:57/1670 train_time:5731ms step_avg:100.54ms +step:58/1670 train_time:5827ms step_avg:100.47ms +step:59/1670 train_time:5923ms step_avg:100.38ms +step:60/1670 train_time:6018ms step_avg:100.31ms +step:61/1670 train_time:6114ms step_avg:100.22ms +step:62/1670 train_time:6209ms step_avg:100.14ms +step:63/1670 train_time:6305ms step_avg:100.07ms +step:64/1670 train_time:6399ms step_avg:99.99ms +step:65/1670 train_time:6494ms step_avg:99.91ms +step:66/1670 train_time:6590ms step_avg:99.85ms +step:67/1670 train_time:6685ms step_avg:99.78ms +step:68/1670 train_time:6781ms step_avg:99.72ms +step:69/1670 train_time:6876ms step_avg:99.66ms +step:70/1670 train_time:6972ms step_avg:99.60ms +step:71/1670 train_time:7067ms step_avg:99.54ms +step:72/1670 train_time:7163ms step_avg:99.48ms +step:73/1670 train_time:7258ms step_avg:99.43ms +step:74/1670 train_time:7354ms step_avg:99.37ms +step:75/1670 train_time:7449ms step_avg:99.32ms +step:76/1670 train_time:7546ms step_avg:99.29ms +step:77/1670 train_time:7642ms step_avg:99.25ms +step:78/1670 train_time:7737ms step_avg:99.19ms +step:79/1670 train_time:7832ms step_avg:99.15ms +step:80/1670 train_time:7928ms step_avg:99.10ms 
+step:81/1670 train_time:8025ms step_avg:99.08ms +step:82/1670 train_time:8120ms step_avg:99.02ms +step:83/1670 train_time:8215ms step_avg:98.98ms +step:84/1670 train_time:8311ms step_avg:98.94ms +step:85/1670 train_time:8406ms step_avg:98.90ms +step:86/1670 train_time:8502ms step_avg:98.86ms +step:87/1670 train_time:8598ms step_avg:98.82ms +step:88/1670 train_time:8693ms step_avg:98.79ms +step:89/1670 train_time:8789ms step_avg:98.75ms +step:90/1670 train_time:8885ms step_avg:98.72ms +step:91/1670 train_time:8980ms step_avg:98.68ms +step:92/1670 train_time:9075ms step_avg:98.65ms +step:93/1670 train_time:9171ms step_avg:98.62ms +step:94/1670 train_time:9267ms step_avg:98.59ms +step:95/1670 train_time:9362ms step_avg:98.55ms +step:96/1670 train_time:9458ms step_avg:98.52ms +step:97/1670 train_time:9553ms step_avg:98.48ms +step:98/1670 train_time:9648ms step_avg:98.45ms +step:99/1670 train_time:9745ms step_avg:98.43ms +step:100/1670 train_time:9840ms step_avg:98.40ms +step:101/1670 train_time:9936ms step_avg:98.37ms +step:102/1670 train_time:10031ms step_avg:98.34ms +step:103/1670 train_time:10127ms step_avg:98.32ms +step:104/1670 train_time:10222ms step_avg:98.29ms +step:105/1670 train_time:10317ms step_avg:98.26ms +step:106/1670 train_time:10413ms step_avg:98.23ms +step:107/1670 train_time:10508ms step_avg:98.21ms +step:108/1670 train_time:10605ms step_avg:98.19ms +step:109/1670 train_time:10701ms step_avg:98.17ms +step:110/1670 train_time:10797ms step_avg:98.15ms +step:111/1670 train_time:10893ms step_avg:98.13ms +step:112/1670 train_time:10989ms step_avg:98.11ms +step:113/1670 train_time:11085ms step_avg:98.09ms +step:114/1670 train_time:11180ms step_avg:98.07ms +step:115/1670 train_time:11275ms step_avg:98.04ms +step:116/1670 train_time:11370ms step_avg:98.02ms +step:117/1670 train_time:11466ms step_avg:98.00ms +step:118/1670 train_time:11562ms step_avg:97.98ms +step:119/1670 train_time:11658ms step_avg:97.96ms +step:120/1670 train_time:11754ms step_avg:97.95ms +step:121/1670 train_time:11850ms step_avg:97.93ms +step:122/1670 train_time:11946ms step_avg:97.92ms +step:123/1670 train_time:12042ms step_avg:97.90ms +step:124/1670 train_time:12137ms step_avg:97.88ms +step:125/1670 train_time:12233ms step_avg:97.86ms +step:125/1670 val_loss:4.2975 train_time:12327ms step_avg:98.62ms +step:126/1670 train_time:12350ms step_avg:98.02ms +step:127/1670 train_time:12434ms step_avg:97.90ms +step:128/1670 train_time:12538ms step_avg:97.95ms +step:129/1670 train_time:12635ms step_avg:97.95ms +step:130/1670 train_time:12730ms step_avg:97.93ms +step:131/1670 train_time:12825ms step_avg:97.90ms +step:132/1670 train_time:12919ms step_avg:97.87ms +step:133/1670 train_time:13014ms step_avg:97.85ms +step:134/1670 train_time:13109ms step_avg:97.83ms +step:135/1670 train_time:13204ms step_avg:97.81ms +step:136/1670 train_time:13298ms step_avg:97.78ms +step:137/1670 train_time:13394ms step_avg:97.77ms +step:138/1670 train_time:13493ms step_avg:97.77ms +step:139/1670 train_time:13591ms step_avg:97.78ms +step:140/1670 train_time:13687ms step_avg:97.76ms +step:141/1670 train_time:13783ms step_avg:97.75ms +step:142/1670 train_time:13878ms step_avg:97.73ms +step:143/1670 train_time:13973ms step_avg:97.71ms +step:144/1670 train_time:14068ms step_avg:97.70ms +step:145/1670 train_time:14163ms step_avg:97.68ms +step:146/1670 train_time:14258ms step_avg:97.66ms +step:147/1670 train_time:14353ms step_avg:97.64ms +step:148/1670 train_time:14451ms step_avg:97.64ms +step:149/1670 train_time:14548ms step_avg:97.64ms 
+step:150/1670 train_time:14644ms step_avg:97.62ms +step:151/1670 train_time:14739ms step_avg:97.61ms +step:152/1670 train_time:14835ms step_avg:97.60ms +step:153/1670 train_time:14930ms step_avg:97.58ms +step:154/1670 train_time:15025ms step_avg:97.56ms +step:155/1670 train_time:15120ms step_avg:97.55ms +step:156/1670 train_time:15215ms step_avg:97.53ms +step:157/1670 train_time:15311ms step_avg:97.52ms +step:158/1670 train_time:15407ms step_avg:97.51ms +step:159/1670 train_time:15502ms step_avg:97.50ms +step:160/1670 train_time:15598ms step_avg:97.49ms +step:161/1670 train_time:15695ms step_avg:97.48ms +step:162/1670 train_time:15791ms step_avg:97.47ms +step:163/1670 train_time:15887ms step_avg:97.46ms +step:164/1670 train_time:15983ms step_avg:97.46ms +step:165/1670 train_time:16077ms step_avg:97.44ms +step:166/1670 train_time:16172ms step_avg:97.42ms +step:167/1670 train_time:16268ms step_avg:97.41ms +step:168/1670 train_time:16363ms step_avg:97.40ms +step:169/1670 train_time:16459ms step_avg:97.39ms +step:170/1670 train_time:16554ms step_avg:97.38ms +step:171/1670 train_time:16650ms step_avg:97.37ms +step:172/1670 train_time:16746ms step_avg:97.36ms +step:173/1670 train_time:16842ms step_avg:97.35ms +step:174/1670 train_time:16937ms step_avg:97.34ms +step:175/1670 train_time:17033ms step_avg:97.33ms +step:176/1670 train_time:17128ms step_avg:97.32ms +step:177/1670 train_time:17224ms step_avg:97.31ms +step:178/1670 train_time:17320ms step_avg:97.30ms +step:179/1670 train_time:17414ms step_avg:97.29ms +step:180/1670 train_time:17511ms step_avg:97.29ms +step:181/1670 train_time:17608ms step_avg:97.28ms +step:182/1670 train_time:17704ms step_avg:97.27ms +step:183/1670 train_time:17799ms step_avg:97.26ms +step:184/1670 train_time:17895ms step_avg:97.26ms +step:185/1670 train_time:17991ms step_avg:97.25ms +step:186/1670 train_time:18087ms step_avg:97.24ms +step:187/1670 train_time:18182ms step_avg:97.23ms +step:188/1670 train_time:18277ms step_avg:97.22ms +step:189/1670 train_time:18372ms step_avg:97.21ms +step:190/1670 train_time:18468ms step_avg:97.20ms +step:191/1670 train_time:18564ms step_avg:97.19ms +step:192/1670 train_time:18660ms step_avg:97.18ms +step:193/1670 train_time:18755ms step_avg:97.18ms +step:194/1670 train_time:18851ms step_avg:97.17ms +step:195/1670 train_time:18947ms step_avg:97.16ms +step:196/1670 train_time:19042ms step_avg:97.15ms +step:197/1670 train_time:19137ms step_avg:97.14ms +step:198/1670 train_time:19232ms step_avg:97.13ms +step:199/1670 train_time:19327ms step_avg:97.12ms +step:200/1670 train_time:19423ms step_avg:97.11ms +step:201/1670 train_time:19518ms step_avg:97.10ms +step:202/1670 train_time:19613ms step_avg:97.10ms +step:203/1670 train_time:19709ms step_avg:97.09ms +step:204/1670 train_time:19805ms step_avg:97.08ms +step:205/1670 train_time:19900ms step_avg:97.07ms +step:206/1670 train_time:19995ms step_avg:97.06ms +step:207/1670 train_time:20091ms step_avg:97.06ms +step:208/1670 train_time:20186ms step_avg:97.05ms +step:209/1670 train_time:20282ms step_avg:97.04ms +step:210/1670 train_time:20377ms step_avg:97.03ms +step:211/1670 train_time:20472ms step_avg:97.03ms +step:212/1670 train_time:20568ms step_avg:97.02ms +step:213/1670 train_time:20870ms step_avg:97.98ms +step:214/1670 train_time:20943ms step_avg:97.86ms +step:215/1670 train_time:21036ms step_avg:97.84ms +step:216/1670 train_time:21131ms step_avg:97.83ms +step:217/1670 train_time:21226ms step_avg:97.81ms +step:218/1670 train_time:21321ms step_avg:97.80ms +step:219/1670 train_time:21415ms 
step_avg:97.79ms +step:220/1670 train_time:21510ms step_avg:97.77ms +step:221/1670 train_time:21605ms step_avg:97.76ms +step:222/1670 train_time:21699ms step_avg:97.74ms +step:223/1670 train_time:21798ms step_avg:97.75ms +step:224/1670 train_time:21896ms step_avg:97.75ms +step:225/1670 train_time:21992ms step_avg:97.74ms +step:226/1670 train_time:22087ms step_avg:97.73ms +step:227/1670 train_time:22182ms step_avg:97.72ms +step:228/1670 train_time:22276ms step_avg:97.70ms +step:229/1670 train_time:22371ms step_avg:97.69ms +step:230/1670 train_time:22466ms step_avg:97.68ms +step:231/1670 train_time:22560ms step_avg:97.66ms +step:232/1670 train_time:22654ms step_avg:97.65ms +step:233/1670 train_time:22751ms step_avg:97.64ms +step:234/1670 train_time:22849ms step_avg:97.65ms +step:235/1670 train_time:22946ms step_avg:97.64ms +step:236/1670 train_time:23042ms step_avg:97.63ms +step:237/1670 train_time:23137ms step_avg:97.63ms +step:238/1670 train_time:23233ms step_avg:97.62ms +step:239/1670 train_time:23328ms step_avg:97.61ms +step:240/1670 train_time:23423ms step_avg:97.60ms +step:241/1670 train_time:23518ms step_avg:97.58ms +step:242/1670 train_time:23613ms step_avg:97.57ms +step:243/1670 train_time:23709ms step_avg:97.57ms +step:244/1670 train_time:23806ms step_avg:97.56ms +step:245/1670 train_time:23902ms step_avg:97.56ms +step:246/1670 train_time:23998ms step_avg:97.55ms +step:247/1670 train_time:24094ms step_avg:97.55ms +step:248/1670 train_time:24190ms step_avg:97.54ms +step:249/1670 train_time:24286ms step_avg:97.53ms +step:250/1670 train_time:24381ms step_avg:97.53ms +step:250/1670 val_loss:3.9606 train_time:24475ms step_avg:97.90ms +step:251/1670 train_time:24495ms step_avg:97.59ms +step:252/1670 train_time:24577ms step_avg:97.53ms +step:253/1670 train_time:24676ms step_avg:97.53ms +step:254/1670 train_time:24771ms step_avg:97.52ms +step:255/1670 train_time:24866ms step_avg:97.51ms +step:256/1670 train_time:24961ms step_avg:97.50ms +step:257/1670 train_time:25055ms step_avg:97.49ms +step:258/1670 train_time:25150ms step_avg:97.48ms +step:259/1670 train_time:25244ms step_avg:97.47ms +step:260/1670 train_time:25339ms step_avg:97.46ms +step:261/1670 train_time:25434ms step_avg:97.45ms +step:262/1670 train_time:25530ms step_avg:97.44ms +step:263/1670 train_time:25628ms step_avg:97.44ms +step:264/1670 train_time:25726ms step_avg:97.45ms +step:265/1670 train_time:25822ms step_avg:97.44ms +step:266/1670 train_time:25917ms step_avg:97.43ms +step:267/1670 train_time:26012ms step_avg:97.42ms +step:268/1670 train_time:26106ms step_avg:97.41ms +step:269/1670 train_time:26201ms step_avg:97.40ms +step:270/1670 train_time:26295ms step_avg:97.39ms +step:271/1670 train_time:26390ms step_avg:97.38ms +step:272/1670 train_time:26486ms step_avg:97.37ms +step:273/1670 train_time:26583ms step_avg:97.37ms +step:274/1670 train_time:26680ms step_avg:97.37ms +step:275/1670 train_time:26775ms step_avg:97.37ms +step:276/1670 train_time:26871ms step_avg:97.36ms +step:277/1670 train_time:26966ms step_avg:97.35ms +step:278/1670 train_time:27061ms step_avg:97.34ms +step:279/1670 train_time:27156ms step_avg:97.33ms +step:280/1670 train_time:27252ms step_avg:97.33ms +step:281/1670 train_time:27347ms step_avg:97.32ms +step:282/1670 train_time:27443ms step_avg:97.31ms +step:283/1670 train_time:27539ms step_avg:97.31ms +step:284/1670 train_time:27634ms step_avg:97.30ms +step:285/1670 train_time:27730ms step_avg:97.30ms +step:286/1670 train_time:27826ms step_avg:97.29ms +step:287/1670 train_time:27922ms step_avg:97.29ms 
+step:288/1670 train_time:28017ms step_avg:97.28ms +step:289/1670 train_time:28112ms step_avg:97.27ms +step:290/1670 train_time:28206ms step_avg:97.26ms +step:291/1670 train_time:28302ms step_avg:97.26ms +step:292/1670 train_time:28397ms step_avg:97.25ms +step:293/1670 train_time:28492ms step_avg:97.24ms +step:294/1670 train_time:28588ms step_avg:97.24ms +step:295/1670 train_time:28686ms step_avg:97.24ms +step:296/1670 train_time:28782ms step_avg:97.24ms +step:297/1670 train_time:28878ms step_avg:97.23ms +step:298/1670 train_time:28973ms step_avg:97.22ms +step:299/1670 train_time:29068ms step_avg:97.22ms +step:300/1670 train_time:29163ms step_avg:97.21ms +step:301/1670 train_time:29258ms step_avg:97.20ms +step:302/1670 train_time:29352ms step_avg:97.19ms +step:303/1670 train_time:29448ms step_avg:97.19ms +step:304/1670 train_time:29544ms step_avg:97.18ms +step:305/1670 train_time:29640ms step_avg:97.18ms +step:306/1670 train_time:29736ms step_avg:97.18ms +step:307/1670 train_time:29833ms step_avg:97.17ms +step:308/1670 train_time:29928ms step_avg:97.17ms +step:309/1670 train_time:30024ms step_avg:97.16ms +step:310/1670 train_time:30119ms step_avg:97.16ms +step:311/1670 train_time:30214ms step_avg:97.15ms +step:312/1670 train_time:30309ms step_avg:97.14ms +step:313/1670 train_time:30404ms step_avg:97.14ms +step:314/1670 train_time:30499ms step_avg:97.13ms +step:315/1670 train_time:30595ms step_avg:97.13ms +step:316/1670 train_time:30690ms step_avg:97.12ms +step:317/1670 train_time:30787ms step_avg:97.12ms +step:318/1670 train_time:30884ms step_avg:97.12ms +step:319/1670 train_time:30979ms step_avg:97.11ms +step:320/1670 train_time:31074ms step_avg:97.11ms +step:321/1670 train_time:31170ms step_avg:97.10ms +step:322/1670 train_time:31265ms step_avg:97.10ms +step:323/1670 train_time:31361ms step_avg:97.09ms +step:324/1670 train_time:31456ms step_avg:97.09ms +step:325/1670 train_time:31551ms step_avg:97.08ms +step:326/1670 train_time:31647ms step_avg:97.08ms +step:327/1670 train_time:31743ms step_avg:97.07ms +step:328/1670 train_time:31839ms step_avg:97.07ms +step:329/1670 train_time:31934ms step_avg:97.06ms +step:330/1670 train_time:32030ms step_avg:97.06ms +step:331/1670 train_time:32125ms step_avg:97.05ms +step:332/1670 train_time:32221ms step_avg:97.05ms +step:333/1670 train_time:32317ms step_avg:97.05ms +step:334/1670 train_time:32412ms step_avg:97.04ms +step:335/1670 train_time:32507ms step_avg:97.04ms +step:336/1670 train_time:32603ms step_avg:97.03ms +step:337/1670 train_time:32698ms step_avg:97.03ms +step:338/1670 train_time:32793ms step_avg:97.02ms +step:339/1670 train_time:32888ms step_avg:97.01ms +step:340/1670 train_time:32984ms step_avg:97.01ms +step:341/1670 train_time:33080ms step_avg:97.01ms +step:342/1670 train_time:33176ms step_avg:97.01ms +step:343/1670 train_time:33271ms step_avg:97.00ms +step:344/1670 train_time:33367ms step_avg:97.00ms +step:345/1670 train_time:33463ms step_avg:96.99ms +step:346/1670 train_time:33558ms step_avg:96.99ms +step:347/1670 train_time:33654ms step_avg:96.98ms +step:348/1670 train_time:33749ms step_avg:96.98ms +step:349/1670 train_time:33845ms step_avg:96.98ms +step:350/1670 train_time:33940ms step_avg:96.97ms +step:351/1670 train_time:34036ms step_avg:96.97ms +step:352/1670 train_time:34131ms step_avg:96.96ms +step:353/1670 train_time:34228ms step_avg:96.96ms +step:354/1670 train_time:34323ms step_avg:96.96ms +step:355/1670 train_time:34419ms step_avg:96.95ms +step:356/1670 train_time:34514ms step_avg:96.95ms +step:357/1670 train_time:34609ms 
step_avg:96.94ms +step:358/1670 train_time:34705ms step_avg:96.94ms +step:359/1670 train_time:34801ms step_avg:96.94ms +step:360/1670 train_time:34896ms step_avg:96.93ms +step:361/1670 train_time:34991ms step_avg:96.93ms +step:362/1670 train_time:35086ms step_avg:96.92ms +step:363/1670 train_time:35182ms step_avg:96.92ms +step:364/1670 train_time:35278ms step_avg:96.92ms +step:365/1670 train_time:35372ms step_avg:96.91ms +step:366/1670 train_time:35468ms step_avg:96.91ms +step:367/1670 train_time:35563ms step_avg:96.90ms +step:368/1670 train_time:35659ms step_avg:96.90ms +step:369/1670 train_time:35755ms step_avg:96.90ms +step:370/1670 train_time:35850ms step_avg:96.89ms +step:371/1670 train_time:35946ms step_avg:96.89ms +step:372/1670 train_time:36042ms step_avg:96.89ms +step:373/1670 train_time:36137ms step_avg:96.88ms +step:374/1670 train_time:36232ms step_avg:96.88ms +step:375/1670 train_time:36327ms step_avg:96.87ms +step:375/1670 val_loss:3.8096 train_time:36422ms step_avg:97.13ms +step:376/1670 train_time:36443ms step_avg:96.92ms +step:377/1670 train_time:36526ms step_avg:96.89ms +step:378/1670 train_time:36623ms step_avg:96.89ms +step:379/1670 train_time:36718ms step_avg:96.88ms +step:380/1670 train_time:36813ms step_avg:96.88ms +step:381/1670 train_time:36908ms step_avg:96.87ms +step:382/1670 train_time:37003ms step_avg:96.87ms +step:383/1670 train_time:37097ms step_avg:96.86ms +step:384/1670 train_time:37194ms step_avg:96.86ms +step:385/1670 train_time:37288ms step_avg:96.85ms +step:386/1670 train_time:37383ms step_avg:96.85ms +step:387/1670 train_time:37481ms step_avg:96.85ms +step:388/1670 train_time:37579ms step_avg:96.85ms +step:389/1670 train_time:37675ms step_avg:96.85ms +step:390/1670 train_time:37771ms step_avg:96.85ms +step:391/1670 train_time:37866ms step_avg:96.84ms +step:392/1670 train_time:37961ms step_avg:96.84ms +step:393/1670 train_time:38056ms step_avg:96.83ms +step:394/1670 train_time:38152ms step_avg:96.83ms +step:395/1670 train_time:38248ms step_avg:96.83ms +step:396/1670 train_time:38343ms step_avg:96.83ms +step:397/1670 train_time:38439ms step_avg:96.82ms +step:398/1670 train_time:38537ms step_avg:96.83ms +step:399/1670 train_time:38633ms step_avg:96.82ms +step:400/1670 train_time:38728ms step_avg:96.82ms +step:401/1670 train_time:38823ms step_avg:96.82ms +step:402/1670 train_time:38918ms step_avg:96.81ms +step:403/1670 train_time:39014ms step_avg:96.81ms +step:404/1670 train_time:39110ms step_avg:96.81ms +step:405/1670 train_time:39206ms step_avg:96.80ms +step:406/1670 train_time:39301ms step_avg:96.80ms +step:407/1670 train_time:39397ms step_avg:96.80ms +step:408/1670 train_time:39495ms step_avg:96.80ms +step:409/1670 train_time:39592ms step_avg:96.80ms +step:410/1670 train_time:39688ms step_avg:96.80ms +step:411/1670 train_time:39783ms step_avg:96.79ms +step:412/1670 train_time:39878ms step_avg:96.79ms +step:413/1670 train_time:39973ms step_avg:96.79ms +step:414/1670 train_time:40068ms step_avg:96.78ms +step:415/1670 train_time:40163ms step_avg:96.78ms +step:416/1670 train_time:40258ms step_avg:96.77ms +step:417/1670 train_time:40354ms step_avg:96.77ms +step:418/1670 train_time:40450ms step_avg:96.77ms +step:419/1670 train_time:40546ms step_avg:96.77ms +step:420/1670 train_time:40642ms step_avg:96.77ms +step:421/1670 train_time:40738ms step_avg:96.76ms +step:422/1670 train_time:40834ms step_avg:96.76ms +step:423/1670 train_time:40929ms step_avg:96.76ms +step:424/1670 train_time:41024ms step_avg:96.76ms +step:425/1670 train_time:41351ms step_avg:97.30ms 
+step:426/1670 train_time:41425ms step_avg:97.24ms +step:427/1670 train_time:41519ms step_avg:97.23ms +step:428/1670 train_time:41613ms step_avg:97.23ms +step:429/1670 train_time:41708ms step_avg:97.22ms +step:430/1670 train_time:41803ms step_avg:97.22ms +step:431/1670 train_time:41897ms step_avg:97.21ms +step:432/1670 train_time:41992ms step_avg:97.20ms +step:433/1670 train_time:42086ms step_avg:97.20ms +step:434/1670 train_time:42180ms step_avg:97.19ms +step:435/1670 train_time:42281ms step_avg:97.20ms +step:436/1670 train_time:42381ms step_avg:97.20ms +step:437/1670 train_time:42479ms step_avg:97.21ms +step:438/1670 train_time:42575ms step_avg:97.20ms +step:439/1670 train_time:42671ms step_avg:97.20ms +step:440/1670 train_time:42765ms step_avg:97.19ms +step:441/1670 train_time:42859ms step_avg:97.19ms +step:442/1670 train_time:42954ms step_avg:97.18ms +step:443/1670 train_time:43048ms step_avg:97.17ms +step:444/1670 train_time:43143ms step_avg:97.17ms +step:445/1670 train_time:43239ms step_avg:97.17ms +step:446/1670 train_time:43337ms step_avg:97.17ms +step:447/1670 train_time:43435ms step_avg:97.17ms +step:448/1670 train_time:43532ms step_avg:97.17ms +step:449/1670 train_time:43628ms step_avg:97.17ms +step:450/1670 train_time:43722ms step_avg:97.16ms +step:451/1670 train_time:43817ms step_avg:97.15ms +step:452/1670 train_time:43912ms step_avg:97.15ms +step:453/1670 train_time:44006ms step_avg:97.14ms +step:454/1670 train_time:44101ms step_avg:97.14ms +step:455/1670 train_time:44197ms step_avg:97.14ms +step:456/1670 train_time:44293ms step_avg:97.13ms +step:457/1670 train_time:44389ms step_avg:97.13ms +step:458/1670 train_time:44486ms step_avg:97.13ms +step:459/1670 train_time:44583ms step_avg:97.13ms +step:460/1670 train_time:44678ms step_avg:97.13ms +step:461/1670 train_time:44773ms step_avg:97.12ms +step:462/1670 train_time:44868ms step_avg:97.12ms +step:463/1670 train_time:44962ms step_avg:97.11ms +step:464/1670 train_time:45057ms step_avg:97.11ms +step:465/1670 train_time:45153ms step_avg:97.10ms +step:466/1670 train_time:45249ms step_avg:97.10ms +step:467/1670 train_time:45345ms step_avg:97.10ms +step:468/1670 train_time:45440ms step_avg:97.09ms +step:469/1670 train_time:45537ms step_avg:97.09ms +step:470/1670 train_time:45634ms step_avg:97.09ms +step:471/1670 train_time:45729ms step_avg:97.09ms +step:472/1670 train_time:45824ms step_avg:97.08ms +step:473/1670 train_time:45919ms step_avg:97.08ms +step:474/1670 train_time:46014ms step_avg:97.08ms +step:475/1670 train_time:46110ms step_avg:97.07ms +step:476/1670 train_time:46205ms step_avg:97.07ms +step:477/1670 train_time:46300ms step_avg:97.07ms +step:478/1670 train_time:46395ms step_avg:97.06ms +step:479/1670 train_time:46491ms step_avg:97.06ms +step:480/1670 train_time:46588ms step_avg:97.06ms +step:481/1670 train_time:46683ms step_avg:97.05ms +step:482/1670 train_time:46779ms step_avg:97.05ms +step:483/1670 train_time:46875ms step_avg:97.05ms +step:484/1670 train_time:46970ms step_avg:97.05ms +step:485/1670 train_time:47065ms step_avg:97.04ms +step:486/1670 train_time:47160ms step_avg:97.04ms +step:487/1670 train_time:47256ms step_avg:97.04ms +step:488/1670 train_time:47353ms step_avg:97.03ms +step:489/1670 train_time:47449ms step_avg:97.03ms +step:490/1670 train_time:47544ms step_avg:97.03ms +step:491/1670 train_time:47639ms step_avg:97.02ms +step:492/1670 train_time:47735ms step_avg:97.02ms +step:493/1670 train_time:47830ms step_avg:97.02ms +step:494/1670 train_time:47925ms step_avg:97.01ms +step:495/1670 train_time:48020ms 
step_avg:97.01ms +step:496/1670 train_time:48116ms step_avg:97.01ms +step:497/1670 train_time:48212ms step_avg:97.01ms +step:498/1670 train_time:48307ms step_avg:97.00ms +step:499/1670 train_time:48402ms step_avg:97.00ms +step:500/1670 train_time:48498ms step_avg:97.00ms +step:500/1670 val_loss:3.7107 train_time:48594ms step_avg:97.19ms +step:501/1670 train_time:48615ms step_avg:97.04ms +step:502/1670 train_time:48697ms step_avg:97.01ms +step:503/1670 train_time:48795ms step_avg:97.01ms +step:504/1670 train_time:48891ms step_avg:97.01ms +step:505/1670 train_time:48985ms step_avg:97.00ms +step:506/1670 train_time:49080ms step_avg:97.00ms +step:507/1670 train_time:49174ms step_avg:96.99ms +step:508/1670 train_time:49269ms step_avg:96.99ms +step:509/1670 train_time:49364ms step_avg:96.98ms +step:510/1670 train_time:49459ms step_avg:96.98ms +step:511/1670 train_time:49553ms step_avg:96.97ms +step:512/1670 train_time:49651ms step_avg:96.97ms +step:513/1670 train_time:49749ms step_avg:96.98ms +step:514/1670 train_time:49848ms step_avg:96.98ms +step:515/1670 train_time:49944ms step_avg:96.98ms +step:516/1670 train_time:50040ms step_avg:96.98ms +step:517/1670 train_time:50134ms step_avg:96.97ms +step:518/1670 train_time:50229ms step_avg:96.97ms +step:519/1670 train_time:50324ms step_avg:96.96ms +step:520/1670 train_time:50419ms step_avg:96.96ms +step:521/1670 train_time:50514ms step_avg:96.96ms +step:522/1670 train_time:50610ms step_avg:96.95ms +step:523/1670 train_time:50707ms step_avg:96.95ms +step:524/1670 train_time:50805ms step_avg:96.96ms +step:525/1670 train_time:50901ms step_avg:96.95ms +step:526/1670 train_time:50996ms step_avg:96.95ms +step:527/1670 train_time:51091ms step_avg:96.95ms +step:528/1670 train_time:51187ms step_avg:96.94ms +step:529/1670 train_time:51281ms step_avg:96.94ms +step:530/1670 train_time:51376ms step_avg:96.94ms +step:531/1670 train_time:51471ms step_avg:96.93ms +step:532/1670 train_time:51567ms step_avg:96.93ms +step:533/1670 train_time:51663ms step_avg:96.93ms +step:534/1670 train_time:51760ms step_avg:96.93ms +step:535/1670 train_time:51856ms step_avg:96.93ms +step:536/1670 train_time:51951ms step_avg:96.92ms +step:537/1670 train_time:52047ms step_avg:96.92ms +step:538/1670 train_time:52142ms step_avg:96.92ms +step:539/1670 train_time:52237ms step_avg:96.92ms +step:540/1670 train_time:52332ms step_avg:96.91ms +step:541/1670 train_time:52428ms step_avg:96.91ms +step:542/1670 train_time:52523ms step_avg:96.91ms +step:543/1670 train_time:52619ms step_avg:96.90ms +step:544/1670 train_time:52714ms step_avg:96.90ms +step:545/1670 train_time:52810ms step_avg:96.90ms +step:546/1670 train_time:52906ms step_avg:96.90ms +step:547/1670 train_time:53003ms step_avg:96.90ms +step:548/1670 train_time:53099ms step_avg:96.90ms +step:549/1670 train_time:53194ms step_avg:96.89ms +step:550/1670 train_time:53289ms step_avg:96.89ms +step:551/1670 train_time:53384ms step_avg:96.89ms +step:552/1670 train_time:53480ms step_avg:96.88ms +step:553/1670 train_time:53575ms step_avg:96.88ms +step:554/1670 train_time:53670ms step_avg:96.88ms +step:555/1670 train_time:53766ms step_avg:96.88ms +step:556/1670 train_time:53862ms step_avg:96.87ms +step:557/1670 train_time:53958ms step_avg:96.87ms +step:558/1670 train_time:54054ms step_avg:96.87ms +step:559/1670 train_time:54150ms step_avg:96.87ms +step:560/1670 train_time:54247ms step_avg:96.87ms +step:561/1670 train_time:54344ms step_avg:96.87ms +step:562/1670 train_time:54441ms step_avg:96.87ms +step:563/1670 train_time:54537ms step_avg:96.87ms 
+step:564/1670 train_time:54634ms step_avg:96.87ms +step:565/1670 train_time:54730ms step_avg:96.87ms +step:566/1670 train_time:54827ms step_avg:96.87ms +step:567/1670 train_time:54925ms step_avg:96.87ms +step:568/1670 train_time:55023ms step_avg:96.87ms +step:569/1670 train_time:55119ms step_avg:96.87ms +step:570/1670 train_time:55216ms step_avg:96.87ms +step:571/1670 train_time:55312ms step_avg:96.87ms +step:572/1670 train_time:55409ms step_avg:96.87ms +step:573/1670 train_time:55508ms step_avg:96.87ms +step:574/1670 train_time:55606ms step_avg:96.87ms +step:575/1670 train_time:55703ms step_avg:96.88ms +step:576/1670 train_time:55800ms step_avg:96.87ms +step:577/1670 train_time:55897ms step_avg:96.87ms +step:578/1670 train_time:55993ms step_avg:96.87ms +step:579/1670 train_time:56090ms step_avg:96.87ms +step:580/1670 train_time:56188ms step_avg:96.88ms +step:581/1670 train_time:56286ms step_avg:96.88ms +step:582/1670 train_time:56383ms step_avg:96.88ms +step:583/1670 train_time:56480ms step_avg:96.88ms +step:584/1670 train_time:56576ms step_avg:96.88ms +step:585/1670 train_time:56673ms step_avg:96.88ms +step:586/1670 train_time:56770ms step_avg:96.88ms +step:587/1670 train_time:56867ms step_avg:96.88ms +step:588/1670 train_time:56965ms step_avg:96.88ms +step:589/1670 train_time:57061ms step_avg:96.88ms +step:590/1670 train_time:57158ms step_avg:96.88ms +step:591/1670 train_time:57254ms step_avg:96.88ms +step:592/1670 train_time:57350ms step_avg:96.87ms +step:593/1670 train_time:57447ms step_avg:96.88ms +step:594/1670 train_time:57546ms step_avg:96.88ms +step:595/1670 train_time:57645ms step_avg:96.88ms +step:596/1670 train_time:57741ms step_avg:96.88ms +step:597/1670 train_time:57838ms step_avg:96.88ms +step:598/1670 train_time:57934ms step_avg:96.88ms +step:599/1670 train_time:58030ms step_avg:96.88ms +step:600/1670 train_time:58129ms step_avg:96.88ms +step:601/1670 train_time:58227ms step_avg:96.88ms +step:602/1670 train_time:58324ms step_avg:96.88ms +step:603/1670 train_time:58420ms step_avg:96.88ms +step:604/1670 train_time:58517ms step_avg:96.88ms +step:605/1670 train_time:58614ms step_avg:96.88ms +step:606/1670 train_time:58711ms step_avg:96.88ms +step:607/1670 train_time:58809ms step_avg:96.88ms +step:608/1670 train_time:58907ms step_avg:96.89ms +step:609/1670 train_time:59004ms step_avg:96.89ms +step:610/1670 train_time:59100ms step_avg:96.89ms +step:611/1670 train_time:59197ms step_avg:96.89ms +step:612/1670 train_time:59293ms step_avg:96.88ms +step:613/1670 train_time:59390ms step_avg:96.88ms +step:614/1670 train_time:59487ms step_avg:96.88ms +step:615/1670 train_time:59585ms step_avg:96.89ms +step:616/1670 train_time:59681ms step_avg:96.89ms +step:617/1670 train_time:59778ms step_avg:96.89ms +step:618/1670 train_time:59875ms step_avg:96.89ms +step:619/1670 train_time:59972ms step_avg:96.89ms +step:620/1670 train_time:60069ms step_avg:96.89ms +step:621/1670 train_time:60166ms step_avg:96.89ms +step:622/1670 train_time:60263ms step_avg:96.89ms +step:623/1670 train_time:60360ms step_avg:96.89ms +step:624/1670 train_time:60457ms step_avg:96.89ms +step:625/1670 train_time:60554ms step_avg:96.89ms +step:625/1670 val_loss:3.6104 train_time:60650ms step_avg:97.04ms +step:626/1670 train_time:60671ms step_avg:96.92ms +step:627/1670 train_time:60759ms step_avg:96.90ms +step:628/1670 train_time:60858ms step_avg:96.91ms +step:629/1670 train_time:60955ms step_avg:96.91ms +step:630/1670 train_time:61050ms step_avg:96.91ms +step:631/1670 train_time:61146ms step_avg:96.90ms +step:632/1670 
train_time:61242ms step_avg:96.90ms +step:633/1670 train_time:61338ms step_avg:96.90ms +step:634/1670 train_time:61433ms step_avg:96.90ms +step:635/1670 train_time:61529ms step_avg:96.90ms +step:636/1670 train_time:61628ms step_avg:96.90ms +step:637/1670 train_time:61730ms step_avg:96.91ms +step:638/1670 train_time:61830ms step_avg:96.91ms +step:639/1670 train_time:62201ms step_avg:97.34ms +step:640/1670 train_time:62304ms step_avg:97.35ms +step:641/1670 train_time:62399ms step_avg:97.35ms +step:642/1670 train_time:62495ms step_avg:97.34ms +step:643/1670 train_time:62590ms step_avg:97.34ms +step:644/1670 train_time:62687ms step_avg:97.34ms +step:645/1670 train_time:62783ms step_avg:97.34ms +step:646/1670 train_time:62878ms step_avg:97.33ms +step:647/1670 train_time:62973ms step_avg:97.33ms +step:648/1670 train_time:63069ms step_avg:97.33ms +step:649/1670 train_time:63166ms step_avg:97.33ms +step:650/1670 train_time:63268ms step_avg:97.34ms +step:651/1670 train_time:63368ms step_avg:97.34ms +step:652/1670 train_time:63466ms step_avg:97.34ms +step:653/1670 train_time:63563ms step_avg:97.34ms +step:654/1670 train_time:63659ms step_avg:97.34ms +step:655/1670 train_time:63754ms step_avg:97.33ms +step:656/1670 train_time:63850ms step_avg:97.33ms +step:657/1670 train_time:63948ms step_avg:97.33ms +step:658/1670 train_time:64045ms step_avg:97.33ms +step:659/1670 train_time:64141ms step_avg:97.33ms +step:660/1670 train_time:64240ms step_avg:97.33ms +step:661/1670 train_time:64338ms step_avg:97.34ms +step:662/1670 train_time:64436ms step_avg:97.34ms +step:663/1670 train_time:64533ms step_avg:97.33ms +step:664/1670 train_time:64630ms step_avg:97.33ms +step:665/1670 train_time:64726ms step_avg:97.33ms +step:666/1670 train_time:64822ms step_avg:97.33ms +step:667/1670 train_time:64918ms step_avg:97.33ms +step:668/1670 train_time:65014ms step_avg:97.33ms +step:669/1670 train_time:65110ms step_avg:97.33ms +step:670/1670 train_time:65207ms step_avg:97.32ms +step:671/1670 train_time:65306ms step_avg:97.33ms +step:672/1670 train_time:65404ms step_avg:97.33ms +step:673/1670 train_time:65503ms step_avg:97.33ms +step:674/1670 train_time:65600ms step_avg:97.33ms +step:675/1670 train_time:65698ms step_avg:97.33ms +step:676/1670 train_time:65793ms step_avg:97.33ms +step:677/1670 train_time:65889ms step_avg:97.33ms +step:678/1670 train_time:65986ms step_avg:97.32ms +step:679/1670 train_time:66082ms step_avg:97.32ms +step:680/1670 train_time:66178ms step_avg:97.32ms +step:681/1670 train_time:66275ms step_avg:97.32ms +step:682/1670 train_time:66373ms step_avg:97.32ms +step:683/1670 train_time:66470ms step_avg:97.32ms +step:684/1670 train_time:66568ms step_avg:97.32ms +step:685/1670 train_time:66666ms step_avg:97.32ms +step:686/1670 train_time:66764ms step_avg:97.32ms +step:687/1670 train_time:66861ms step_avg:97.32ms +step:688/1670 train_time:66956ms step_avg:97.32ms +step:689/1670 train_time:67052ms step_avg:97.32ms +step:690/1670 train_time:67149ms step_avg:97.32ms +step:691/1670 train_time:67246ms step_avg:97.32ms +step:692/1670 train_time:67344ms step_avg:97.32ms +step:693/1670 train_time:67440ms step_avg:97.32ms +step:694/1670 train_time:67538ms step_avg:97.32ms +step:695/1670 train_time:67634ms step_avg:97.32ms +step:696/1670 train_time:67731ms step_avg:97.31ms +step:697/1670 train_time:67828ms step_avg:97.31ms +step:698/1670 train_time:67926ms step_avg:97.31ms +step:699/1670 train_time:68022ms step_avg:97.31ms +step:700/1670 train_time:68119ms step_avg:97.31ms +step:701/1670 train_time:68215ms step_avg:97.31ms 
+step:702/1670 train_time:68312ms step_avg:97.31ms +step:703/1670 train_time:68408ms step_avg:97.31ms +step:704/1670 train_time:68506ms step_avg:97.31ms +step:705/1670 train_time:68603ms step_avg:97.31ms +step:706/1670 train_time:68700ms step_avg:97.31ms +step:707/1670 train_time:68798ms step_avg:97.31ms +step:708/1670 train_time:68894ms step_avg:97.31ms +step:709/1670 train_time:68991ms step_avg:97.31ms +step:710/1670 train_time:69088ms step_avg:97.31ms +step:711/1670 train_time:69186ms step_avg:97.31ms +step:712/1670 train_time:69284ms step_avg:97.31ms +step:713/1670 train_time:69380ms step_avg:97.31ms +step:714/1670 train_time:69477ms step_avg:97.31ms +step:715/1670 train_time:69573ms step_avg:97.30ms +step:716/1670 train_time:69669ms step_avg:97.30ms +step:717/1670 train_time:69768ms step_avg:97.31ms +step:718/1670 train_time:69865ms step_avg:97.31ms +step:719/1670 train_time:69961ms step_avg:97.30ms +step:720/1670 train_time:70058ms step_avg:97.30ms +step:721/1670 train_time:70155ms step_avg:97.30ms +step:722/1670 train_time:70253ms step_avg:97.30ms +step:723/1670 train_time:70350ms step_avg:97.30ms +step:724/1670 train_time:70447ms step_avg:97.30ms +step:725/1670 train_time:70545ms step_avg:97.30ms +step:726/1670 train_time:70642ms step_avg:97.30ms +step:727/1670 train_time:70739ms step_avg:97.30ms +step:728/1670 train_time:70836ms step_avg:97.30ms +step:729/1670 train_time:70932ms step_avg:97.30ms +step:730/1670 train_time:71029ms step_avg:97.30ms +step:731/1670 train_time:71127ms step_avg:97.30ms +step:732/1670 train_time:71225ms step_avg:97.30ms +step:733/1670 train_time:71323ms step_avg:97.30ms +step:734/1670 train_time:71420ms step_avg:97.30ms +step:735/1670 train_time:71517ms step_avg:97.30ms +step:736/1670 train_time:71614ms step_avg:97.30ms +step:737/1670 train_time:71710ms step_avg:97.30ms +step:738/1670 train_time:71808ms step_avg:97.30ms +step:739/1670 train_time:71906ms step_avg:97.30ms +step:740/1670 train_time:72003ms step_avg:97.30ms +step:741/1670 train_time:72100ms step_avg:97.30ms +step:742/1670 train_time:72196ms step_avg:97.30ms +step:743/1670 train_time:72293ms step_avg:97.30ms +step:744/1670 train_time:72390ms step_avg:97.30ms +step:745/1670 train_time:72488ms step_avg:97.30ms +step:746/1670 train_time:72584ms step_avg:97.30ms +step:747/1670 train_time:72682ms step_avg:97.30ms +step:748/1670 train_time:72778ms step_avg:97.30ms +step:749/1670 train_time:72876ms step_avg:97.30ms +step:750/1670 train_time:72973ms step_avg:97.30ms +step:750/1670 val_loss:3.5576 train_time:73069ms step_avg:97.43ms +step:751/1670 train_time:73089ms step_avg:97.32ms +step:752/1670 train_time:73172ms step_avg:97.30ms +step:753/1670 train_time:73269ms step_avg:97.30ms +step:754/1670 train_time:73364ms step_avg:97.30ms +step:755/1670 train_time:73460ms step_avg:97.30ms +step:756/1670 train_time:73556ms step_avg:97.30ms +step:757/1670 train_time:73653ms step_avg:97.30ms +step:758/1670 train_time:73750ms step_avg:97.29ms +step:759/1670 train_time:73847ms step_avg:97.30ms +step:760/1670 train_time:73943ms step_avg:97.29ms +step:761/1670 train_time:74043ms step_avg:97.30ms +step:762/1670 train_time:74144ms step_avg:97.30ms +step:763/1670 train_time:74243ms step_avg:97.30ms +step:764/1670 train_time:74341ms step_avg:97.30ms +step:765/1670 train_time:74438ms step_avg:97.30ms +step:766/1670 train_time:74534ms step_avg:97.30ms +step:767/1670 train_time:74630ms step_avg:97.30ms +step:768/1670 train_time:74726ms step_avg:97.30ms +step:769/1670 train_time:74824ms step_avg:97.30ms +step:770/1670 
train_time:74921ms step_avg:97.30ms +step:771/1670 train_time:75020ms step_avg:97.30ms +step:772/1670 train_time:75121ms step_avg:97.31ms +step:773/1670 train_time:75221ms step_avg:97.31ms +step:774/1670 train_time:75319ms step_avg:97.31ms +step:775/1670 train_time:75416ms step_avg:97.31ms +step:776/1670 train_time:75512ms step_avg:97.31ms +step:777/1670 train_time:75608ms step_avg:97.31ms +step:778/1670 train_time:75704ms step_avg:97.31ms +step:779/1670 train_time:75801ms step_avg:97.31ms +step:780/1670 train_time:75898ms step_avg:97.31ms +step:781/1670 train_time:75997ms step_avg:97.31ms +step:782/1670 train_time:76095ms step_avg:97.31ms +step:783/1670 train_time:76192ms step_avg:97.31ms +step:784/1670 train_time:76289ms step_avg:97.31ms +step:785/1670 train_time:76385ms step_avg:97.31ms +step:786/1670 train_time:76483ms step_avg:97.31ms +step:787/1670 train_time:76580ms step_avg:97.31ms +step:788/1670 train_time:76678ms step_avg:97.31ms +step:789/1670 train_time:76774ms step_avg:97.31ms +step:790/1670 train_time:76871ms step_avg:97.30ms +step:791/1670 train_time:76967ms step_avg:97.30ms +step:792/1670 train_time:77065ms step_avg:97.30ms +step:793/1670 train_time:77163ms step_avg:97.30ms +step:794/1670 train_time:77262ms step_avg:97.31ms +step:795/1670 train_time:77359ms step_avg:97.31ms +step:796/1670 train_time:77456ms step_avg:97.31ms +step:797/1670 train_time:77553ms step_avg:97.31ms +step:798/1670 train_time:77649ms step_avg:97.30ms +step:799/1670 train_time:77745ms step_avg:97.30ms +step:800/1670 train_time:77843ms step_avg:97.30ms +step:801/1670 train_time:77941ms step_avg:97.30ms +step:802/1670 train_time:78038ms step_avg:97.30ms +step:803/1670 train_time:78136ms step_avg:97.31ms +step:804/1670 train_time:78233ms step_avg:97.30ms +step:805/1670 train_time:78331ms step_avg:97.31ms +step:806/1670 train_time:78427ms step_avg:97.30ms +step:807/1670 train_time:78524ms step_avg:97.30ms +step:808/1670 train_time:78621ms step_avg:97.30ms +step:809/1670 train_time:78718ms step_avg:97.30ms +step:810/1670 train_time:78815ms step_avg:97.30ms +step:811/1670 train_time:78912ms step_avg:97.30ms +step:812/1670 train_time:79008ms step_avg:97.30ms +step:813/1670 train_time:79106ms step_avg:97.30ms +step:814/1670 train_time:79204ms step_avg:97.30ms +step:815/1670 train_time:79303ms step_avg:97.30ms +step:816/1670 train_time:79402ms step_avg:97.31ms +step:817/1670 train_time:79499ms step_avg:97.31ms +step:818/1670 train_time:79595ms step_avg:97.30ms +step:819/1670 train_time:79692ms step_avg:97.30ms +step:820/1670 train_time:79788ms step_avg:97.30ms +step:821/1670 train_time:79884ms step_avg:97.30ms +step:822/1670 train_time:79982ms step_avg:97.30ms +step:823/1670 train_time:80080ms step_avg:97.30ms +step:824/1670 train_time:80178ms step_avg:97.30ms +step:825/1670 train_time:80275ms step_avg:97.30ms +step:826/1670 train_time:80372ms step_avg:97.30ms +step:827/1670 train_time:80469ms step_avg:97.30ms +step:828/1670 train_time:80565ms step_avg:97.30ms +step:829/1670 train_time:80662ms step_avg:97.30ms +step:830/1670 train_time:80760ms step_avg:97.30ms +step:831/1670 train_time:80856ms step_avg:97.30ms +step:832/1670 train_time:80953ms step_avg:97.30ms +step:833/1670 train_time:81050ms step_avg:97.30ms +step:834/1670 train_time:81146ms step_avg:97.30ms +step:835/1670 train_time:81244ms step_avg:97.30ms +step:836/1670 train_time:81344ms step_avg:97.30ms +step:837/1670 train_time:81442ms step_avg:97.30ms +step:838/1670 train_time:81539ms step_avg:97.30ms +step:839/1670 train_time:81637ms step_avg:97.30ms 
+step:840/1670 train_time:81734ms step_avg:97.30ms +step:841/1670 train_time:81831ms step_avg:97.30ms +step:842/1670 train_time:81927ms step_avg:97.30ms +step:843/1670 train_time:82024ms step_avg:97.30ms +step:844/1670 train_time:82120ms step_avg:97.30ms +step:845/1670 train_time:82217ms step_avg:97.30ms +step:846/1670 train_time:82314ms step_avg:97.30ms +step:847/1670 train_time:82411ms step_avg:97.30ms +step:848/1670 train_time:82508ms step_avg:97.30ms +step:849/1670 train_time:82606ms step_avg:97.30ms +step:850/1670 train_time:82702ms step_avg:97.30ms +step:851/1670 train_time:83014ms step_avg:97.55ms +step:852/1670 train_time:83090ms step_avg:97.52ms +step:853/1670 train_time:83185ms step_avg:97.52ms +step:854/1670 train_time:83281ms step_avg:97.52ms +step:855/1670 train_time:83377ms step_avg:97.52ms +step:856/1670 train_time:83472ms step_avg:97.51ms +step:857/1670 train_time:83568ms step_avg:97.51ms +step:858/1670 train_time:83664ms step_avg:97.51ms +step:859/1670 train_time:83761ms step_avg:97.51ms +step:860/1670 train_time:83857ms step_avg:97.51ms +step:861/1670 train_time:83955ms step_avg:97.51ms +step:862/1670 train_time:84058ms step_avg:97.52ms +step:863/1670 train_time:84157ms step_avg:97.52ms +step:864/1670 train_time:84254ms step_avg:97.52ms +step:865/1670 train_time:84350ms step_avg:97.51ms +step:866/1670 train_time:84445ms step_avg:97.51ms +step:867/1670 train_time:84542ms step_avg:97.51ms +step:868/1670 train_time:84638ms step_avg:97.51ms +step:869/1670 train_time:84734ms step_avg:97.51ms +step:870/1670 train_time:84829ms step_avg:97.51ms +step:871/1670 train_time:84926ms step_avg:97.50ms +step:872/1670 train_time:85025ms step_avg:97.51ms +step:873/1670 train_time:85124ms step_avg:97.51ms +step:874/1670 train_time:85222ms step_avg:97.51ms +step:875/1670 train_time:85320ms step_avg:97.51ms +step:875/1670 val_loss:3.5173 train_time:85416ms step_avg:97.62ms +step:876/1670 train_time:85436ms step_avg:97.53ms +step:877/1670 train_time:85522ms step_avg:97.52ms +step:878/1670 train_time:85624ms step_avg:97.52ms +step:879/1670 train_time:85722ms step_avg:97.52ms +step:880/1670 train_time:85818ms step_avg:97.52ms +step:881/1670 train_time:85915ms step_avg:97.52ms +step:882/1670 train_time:86010ms step_avg:97.52ms +step:883/1670 train_time:86106ms step_avg:97.52ms +step:884/1670 train_time:86202ms step_avg:97.51ms +step:885/1670 train_time:86299ms step_avg:97.51ms +step:886/1670 train_time:86397ms step_avg:97.51ms +step:887/1670 train_time:86496ms step_avg:97.52ms +step:888/1670 train_time:86596ms step_avg:97.52ms +step:889/1670 train_time:86694ms step_avg:97.52ms +step:890/1670 train_time:86790ms step_avg:97.52ms +step:891/1670 train_time:86886ms step_avg:97.52ms +step:892/1670 train_time:86982ms step_avg:97.51ms +step:893/1670 train_time:87079ms step_avg:97.51ms +step:894/1670 train_time:87175ms step_avg:97.51ms +step:895/1670 train_time:87271ms step_avg:97.51ms +step:896/1670 train_time:87368ms step_avg:97.51ms +step:897/1670 train_time:87465ms step_avg:97.51ms +step:898/1670 train_time:87564ms step_avg:97.51ms +step:899/1670 train_time:87662ms step_avg:97.51ms +step:900/1670 train_time:87761ms step_avg:97.51ms +step:901/1670 train_time:87860ms step_avg:97.51ms +step:902/1670 train_time:87957ms step_avg:97.51ms +step:903/1670 train_time:88053ms step_avg:97.51ms +step:904/1670 train_time:88149ms step_avg:97.51ms +step:905/1670 train_time:88245ms step_avg:97.51ms +step:906/1670 train_time:88342ms step_avg:97.51ms +step:907/1670 train_time:88439ms step_avg:97.51ms +step:908/1670 
train_time:88537ms step_avg:97.51ms +step:909/1670 train_time:88634ms step_avg:97.51ms +step:910/1670 train_time:88733ms step_avg:97.51ms +step:911/1670 train_time:88831ms step_avg:97.51ms +step:912/1670 train_time:88929ms step_avg:97.51ms +step:913/1670 train_time:89025ms step_avg:97.51ms +step:914/1670 train_time:89122ms step_avg:97.51ms +step:915/1670 train_time:89219ms step_avg:97.51ms +step:916/1670 train_time:89315ms step_avg:97.51ms +step:917/1670 train_time:89412ms step_avg:97.50ms +step:918/1670 train_time:89509ms step_avg:97.50ms +step:919/1670 train_time:89605ms step_avg:97.50ms +step:920/1670 train_time:89703ms step_avg:97.50ms +step:921/1670 train_time:89801ms step_avg:97.50ms +step:922/1670 train_time:89900ms step_avg:97.51ms +step:923/1670 train_time:89998ms step_avg:97.51ms +step:924/1670 train_time:90096ms step_avg:97.51ms +step:925/1670 train_time:90192ms step_avg:97.51ms +step:926/1670 train_time:90288ms step_avg:97.50ms +step:927/1670 train_time:90384ms step_avg:97.50ms +step:928/1670 train_time:90481ms step_avg:97.50ms +step:929/1670 train_time:90579ms step_avg:97.50ms +step:930/1670 train_time:90677ms step_avg:97.50ms +step:931/1670 train_time:90774ms step_avg:97.50ms +step:932/1670 train_time:90872ms step_avg:97.50ms +step:933/1670 train_time:90969ms step_avg:97.50ms +step:934/1670 train_time:91066ms step_avg:97.50ms +step:935/1670 train_time:91163ms step_avg:97.50ms +step:936/1670 train_time:91261ms step_avg:97.50ms +step:937/1670 train_time:91358ms step_avg:97.50ms +step:938/1670 train_time:91455ms step_avg:97.50ms +step:939/1670 train_time:91551ms step_avg:97.50ms +step:940/1670 train_time:91647ms step_avg:97.50ms +step:941/1670 train_time:91744ms step_avg:97.50ms +step:942/1670 train_time:91843ms step_avg:97.50ms +step:943/1670 train_time:91940ms step_avg:97.50ms +step:944/1670 train_time:92038ms step_avg:97.50ms +step:945/1670 train_time:92134ms step_avg:97.50ms +step:946/1670 train_time:92232ms step_avg:97.50ms +step:947/1670 train_time:92328ms step_avg:97.50ms +step:948/1670 train_time:92425ms step_avg:97.50ms +step:949/1670 train_time:92523ms step_avg:97.50ms +step:950/1670 train_time:92621ms step_avg:97.50ms +step:951/1670 train_time:92719ms step_avg:97.50ms +step:952/1670 train_time:92816ms step_avg:97.50ms +step:953/1670 train_time:92913ms step_avg:97.50ms +step:954/1670 train_time:93012ms step_avg:97.50ms +step:955/1670 train_time:93108ms step_avg:97.49ms +step:956/1670 train_time:93204ms step_avg:97.49ms +step:957/1670 train_time:93301ms step_avg:97.49ms +step:958/1670 train_time:93398ms step_avg:97.49ms +step:959/1670 train_time:93495ms step_avg:97.49ms +step:960/1670 train_time:93592ms step_avg:97.49ms +step:961/1670 train_time:93689ms step_avg:97.49ms +step:962/1670 train_time:93787ms step_avg:97.49ms +step:963/1670 train_time:93885ms step_avg:97.49ms +step:964/1670 train_time:93983ms step_avg:97.49ms +step:965/1670 train_time:94080ms step_avg:97.49ms +step:966/1670 train_time:94179ms step_avg:97.49ms +step:967/1670 train_time:94276ms step_avg:97.49ms +step:968/1670 train_time:94373ms step_avg:97.49ms +step:969/1670 train_time:94469ms step_avg:97.49ms +step:970/1670 train_time:94565ms step_avg:97.49ms +step:971/1670 train_time:94663ms step_avg:97.49ms +step:972/1670 train_time:94760ms step_avg:97.49ms +step:973/1670 train_time:94858ms step_avg:97.49ms +step:974/1670 train_time:94955ms step_avg:97.49ms +step:975/1670 train_time:95053ms step_avg:97.49ms +step:976/1670 train_time:95150ms step_avg:97.49ms +step:977/1670 train_time:95246ms step_avg:97.49ms 
+step:978/1670 train_time:95342ms step_avg:97.49ms +step:979/1670 train_time:95440ms step_avg:97.49ms +step:980/1670 train_time:95537ms step_avg:97.49ms +step:981/1670 train_time:95634ms step_avg:97.49ms +step:982/1670 train_time:95732ms step_avg:97.49ms +step:983/1670 train_time:95829ms step_avg:97.49ms +step:984/1670 train_time:95926ms step_avg:97.49ms +step:985/1670 train_time:96022ms step_avg:97.48ms +step:986/1670 train_time:96120ms step_avg:97.48ms +step:987/1670 train_time:96218ms step_avg:97.49ms +step:988/1670 train_time:96315ms step_avg:97.48ms +step:989/1670 train_time:96412ms step_avg:97.48ms +step:990/1670 train_time:96508ms step_avg:97.48ms +step:991/1670 train_time:96604ms step_avg:97.48ms +step:992/1670 train_time:96702ms step_avg:97.48ms +step:993/1670 train_time:96800ms step_avg:97.48ms +step:994/1670 train_time:96898ms step_avg:97.48ms +step:995/1670 train_time:96996ms step_avg:97.48ms +step:996/1670 train_time:97094ms step_avg:97.48ms +step:997/1670 train_time:97190ms step_avg:97.48ms +step:998/1670 train_time:97287ms step_avg:97.48ms +step:999/1670 train_time:97384ms step_avg:97.48ms +step:1000/1670 train_time:97481ms step_avg:97.48ms +step:1000/1670 val_loss:3.4755 train_time:97578ms step_avg:97.58ms +step:1001/1670 train_time:97599ms step_avg:97.50ms +step:1002/1670 train_time:97679ms step_avg:97.48ms +step:1003/1670 train_time:97777ms step_avg:97.48ms +step:1004/1670 train_time:97874ms step_avg:97.48ms +step:1005/1670 train_time:97969ms step_avg:97.48ms +step:1006/1670 train_time:98066ms step_avg:97.48ms +step:1007/1670 train_time:98162ms step_avg:97.48ms +step:1008/1670 train_time:98257ms step_avg:97.48ms +step:1009/1670 train_time:98353ms step_avg:97.48ms +step:1010/1670 train_time:98449ms step_avg:97.47ms +step:1011/1670 train_time:98549ms step_avg:97.48ms +step:1012/1670 train_time:98649ms step_avg:97.48ms +step:1013/1670 train_time:98749ms step_avg:97.48ms +step:1014/1670 train_time:98847ms step_avg:97.48ms +step:1015/1670 train_time:98944ms step_avg:97.48ms +step:1016/1670 train_time:99040ms step_avg:97.48ms +step:1017/1670 train_time:99136ms step_avg:97.48ms +step:1018/1670 train_time:99232ms step_avg:97.48ms +step:1019/1670 train_time:99329ms step_avg:97.48ms +step:1020/1670 train_time:99426ms step_avg:97.48ms +step:1021/1670 train_time:99522ms step_avg:97.48ms +step:1022/1670 train_time:99620ms step_avg:97.48ms +step:1023/1670 train_time:99719ms step_avg:97.48ms +step:1024/1670 train_time:99817ms step_avg:97.48ms +step:1025/1670 train_time:99914ms step_avg:97.48ms +step:1026/1670 train_time:100011ms step_avg:97.48ms +step:1027/1670 train_time:100107ms step_avg:97.48ms +step:1028/1670 train_time:100204ms step_avg:97.47ms +step:1029/1670 train_time:100300ms step_avg:97.47ms +step:1030/1670 train_time:100396ms step_avg:97.47ms +step:1031/1670 train_time:100493ms step_avg:97.47ms +step:1032/1670 train_time:100591ms step_avg:97.47ms +step:1033/1670 train_time:100689ms step_avg:97.47ms +step:1034/1670 train_time:100789ms step_avg:97.47ms +step:1035/1670 train_time:100887ms step_avg:97.48ms +step:1036/1670 train_time:100984ms step_avg:97.47ms +step:1037/1670 train_time:101081ms step_avg:97.47ms +step:1038/1670 train_time:101177ms step_avg:97.47ms +step:1039/1670 train_time:101273ms step_avg:97.47ms +step:1040/1670 train_time:101370ms step_avg:97.47ms +step:1041/1670 train_time:101467ms step_avg:97.47ms +step:1042/1670 train_time:101564ms step_avg:97.47ms +step:1043/1670 train_time:101661ms step_avg:97.47ms +step:1044/1670 train_time:101759ms step_avg:97.47ms 
+step:1045/1670 train_time:101856ms step_avg:97.47ms +step:1046/1670 train_time:101953ms step_avg:97.47ms +step:1047/1670 train_time:102051ms step_avg:97.47ms +step:1048/1670 train_time:102147ms step_avg:97.47ms +step:1049/1670 train_time:102244ms step_avg:97.47ms +step:1050/1670 train_time:102341ms step_avg:97.47ms +step:1051/1670 train_time:102439ms step_avg:97.47ms +step:1052/1670 train_time:102535ms step_avg:97.47ms +step:1053/1670 train_time:102631ms step_avg:97.47ms +step:1054/1670 train_time:102729ms step_avg:97.47ms +step:1055/1670 train_time:102827ms step_avg:97.47ms +step:1056/1670 train_time:102925ms step_avg:97.47ms +step:1057/1670 train_time:103022ms step_avg:97.47ms +step:1058/1670 train_time:103120ms step_avg:97.47ms +step:1059/1670 train_time:103217ms step_avg:97.47ms +step:1060/1670 train_time:103313ms step_avg:97.46ms +step:1061/1670 train_time:103410ms step_avg:97.46ms +step:1062/1670 train_time:103685ms step_avg:97.63ms +step:1063/1670 train_time:103760ms step_avg:97.61ms +step:1064/1670 train_time:103856ms step_avg:97.61ms +step:1065/1670 train_time:103952ms step_avg:97.61ms +step:1066/1670 train_time:104048ms step_avg:97.61ms +step:1067/1670 train_time:104143ms step_avg:97.60ms +step:1068/1670 train_time:104240ms step_avg:97.60ms +step:1069/1670 train_time:104335ms step_avg:97.60ms +step:1070/1670 train_time:104431ms step_avg:97.60ms +step:1071/1670 train_time:104527ms step_avg:97.60ms +step:1072/1670 train_time:104632ms step_avg:97.60ms +step:1073/1670 train_time:104731ms step_avg:97.61ms +step:1074/1670 train_time:104829ms step_avg:97.61ms +step:1075/1670 train_time:104928ms step_avg:97.61ms +step:1076/1670 train_time:105024ms step_avg:97.61ms +step:1077/1670 train_time:105120ms step_avg:97.60ms +step:1078/1670 train_time:105216ms step_avg:97.60ms +step:1079/1670 train_time:105312ms step_avg:97.60ms +step:1080/1670 train_time:105408ms step_avg:97.60ms +step:1081/1670 train_time:105505ms step_avg:97.60ms +step:1082/1670 train_time:105604ms step_avg:97.60ms +step:1083/1670 train_time:105704ms step_avg:97.60ms +step:1084/1670 train_time:105803ms step_avg:97.60ms +step:1085/1670 train_time:105901ms step_avg:97.60ms +step:1086/1670 train_time:105997ms step_avg:97.60ms +step:1087/1670 train_time:106094ms step_avg:97.60ms +step:1088/1670 train_time:106190ms step_avg:97.60ms +step:1089/1670 train_time:106287ms step_avg:97.60ms +step:1090/1670 train_time:106383ms step_avg:97.60ms +step:1091/1670 train_time:106479ms step_avg:97.60ms +step:1092/1670 train_time:106576ms step_avg:97.60ms +step:1093/1670 train_time:106673ms step_avg:97.60ms +step:1094/1670 train_time:106771ms step_avg:97.60ms +step:1095/1670 train_time:106870ms step_avg:97.60ms +step:1096/1670 train_time:106967ms step_avg:97.60ms +step:1097/1670 train_time:107065ms step_avg:97.60ms +step:1098/1670 train_time:107161ms step_avg:97.60ms +step:1099/1670 train_time:107257ms step_avg:97.59ms +step:1100/1670 train_time:107352ms step_avg:97.59ms +step:1101/1670 train_time:107449ms step_avg:97.59ms +step:1102/1670 train_time:107547ms step_avg:97.59ms +step:1103/1670 train_time:107645ms step_avg:97.59ms +step:1104/1670 train_time:107742ms step_avg:97.59ms +step:1105/1670 train_time:107841ms step_avg:97.59ms +step:1106/1670 train_time:107938ms step_avg:97.59ms +step:1107/1670 train_time:108035ms step_avg:97.59ms +step:1108/1670 train_time:108132ms step_avg:97.59ms +step:1109/1670 train_time:108229ms step_avg:97.59ms +step:1110/1670 train_time:108325ms step_avg:97.59ms +step:1111/1670 train_time:108421ms step_avg:97.59ms 
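+# The occasional isolated jumps of a few hundred ms (steps 639, 851, and 1062
+# above) are one-off stalls rather than drift: step_avg returns to its plateau
+# within a step or two, and the log itself does not attribute a cause.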
+step:1112/1670 train_time:108518ms step_avg:97.59ms +step:1113/1670 train_time:108614ms step_avg:97.59ms +step:1114/1670 train_time:108711ms step_avg:97.59ms +step:1115/1670 train_time:108808ms step_avg:97.59ms +step:1116/1670 train_time:108906ms step_avg:97.59ms +step:1117/1670 train_time:109005ms step_avg:97.59ms +step:1118/1670 train_time:109106ms step_avg:97.59ms +step:1119/1670 train_time:109204ms step_avg:97.59ms +step:1120/1670 train_time:109301ms step_avg:97.59ms +step:1121/1670 train_time:109399ms step_avg:97.59ms +step:1122/1670 train_time:109496ms step_avg:97.59ms +step:1123/1670 train_time:109593ms step_avg:97.59ms +step:1124/1670 train_time:109691ms step_avg:97.59ms +step:1125/1670 train_time:109789ms step_avg:97.59ms +step:1125/1670 val_loss:3.4209 train_time:109887ms step_avg:97.68ms +step:1126/1670 train_time:109909ms step_avg:97.61ms +step:1127/1670 train_time:109995ms step_avg:97.60ms +step:1128/1670 train_time:110094ms step_avg:97.60ms +step:1129/1670 train_time:110192ms step_avg:97.60ms +step:1130/1670 train_time:110288ms step_avg:97.60ms +step:1131/1670 train_time:110385ms step_avg:97.60ms +step:1132/1670 train_time:110481ms step_avg:97.60ms +step:1133/1670 train_time:110578ms step_avg:97.60ms +step:1134/1670 train_time:110674ms step_avg:97.60ms +step:1135/1670 train_time:110772ms step_avg:97.60ms +step:1136/1670 train_time:110873ms step_avg:97.60ms +step:1137/1670 train_time:110974ms step_avg:97.60ms +step:1138/1670 train_time:111073ms step_avg:97.60ms +step:1139/1670 train_time:111171ms step_avg:97.60ms +step:1140/1670 train_time:111269ms step_avg:97.60ms +step:1141/1670 train_time:111366ms step_avg:97.60ms +step:1142/1670 train_time:111462ms step_avg:97.60ms +step:1143/1670 train_time:111558ms step_avg:97.60ms +step:1144/1670 train_time:111655ms step_avg:97.60ms +step:1145/1670 train_time:111753ms step_avg:97.60ms +step:1146/1670 train_time:111851ms step_avg:97.60ms +step:1147/1670 train_time:111952ms step_avg:97.60ms +step:1148/1670 train_time:112053ms step_avg:97.61ms +step:1149/1670 train_time:112151ms step_avg:97.61ms +step:1150/1670 train_time:112249ms step_avg:97.61ms +step:1151/1670 train_time:112345ms step_avg:97.61ms +step:1152/1670 train_time:112442ms step_avg:97.61ms +step:1153/1670 train_time:112538ms step_avg:97.60ms +step:1154/1670 train_time:112635ms step_avg:97.60ms +step:1155/1670 train_time:112733ms step_avg:97.60ms +step:1156/1670 train_time:112832ms step_avg:97.61ms +step:1157/1670 train_time:112931ms step_avg:97.61ms +step:1158/1670 train_time:113030ms step_avg:97.61ms +step:1159/1670 train_time:113128ms step_avg:97.61ms +step:1160/1670 train_time:113226ms step_avg:97.61ms +step:1161/1670 train_time:113324ms step_avg:97.61ms +step:1162/1670 train_time:113421ms step_avg:97.61ms +step:1163/1670 train_time:113518ms step_avg:97.61ms +step:1164/1670 train_time:113614ms step_avg:97.61ms +step:1165/1670 train_time:113713ms step_avg:97.61ms +step:1166/1670 train_time:113811ms step_avg:97.61ms +step:1167/1670 train_time:113909ms step_avg:97.61ms +step:1168/1670 train_time:114008ms step_avg:97.61ms +step:1169/1670 train_time:114106ms step_avg:97.61ms +step:1170/1670 train_time:114204ms step_avg:97.61ms +step:1171/1670 train_time:114301ms step_avg:97.61ms +step:1172/1670 train_time:114398ms step_avg:97.61ms +step:1173/1670 train_time:114495ms step_avg:97.61ms +step:1174/1670 train_time:114592ms step_avg:97.61ms +step:1175/1670 train_time:114689ms step_avg:97.61ms +step:1176/1670 train_time:114787ms step_avg:97.61ms +step:1177/1670 train_time:114885ms 
step_avg:97.61ms +step:1178/1670 train_time:114983ms step_avg:97.61ms +step:1179/1670 train_time:115081ms step_avg:97.61ms +step:1180/1670 train_time:115178ms step_avg:97.61ms +step:1181/1670 train_time:115277ms step_avg:97.61ms +step:1182/1670 train_time:115376ms step_avg:97.61ms +step:1183/1670 train_time:115473ms step_avg:97.61ms +step:1184/1670 train_time:115571ms step_avg:97.61ms +step:1185/1670 train_time:115669ms step_avg:97.61ms +step:1186/1670 train_time:115766ms step_avg:97.61ms +step:1187/1670 train_time:115864ms step_avg:97.61ms +step:1188/1670 train_time:115961ms step_avg:97.61ms +step:1189/1670 train_time:116059ms step_avg:97.61ms +step:1190/1670 train_time:116157ms step_avg:97.61ms +step:1191/1670 train_time:116254ms step_avg:97.61ms +step:1192/1670 train_time:116353ms step_avg:97.61ms +step:1193/1670 train_time:116451ms step_avg:97.61ms +step:1194/1670 train_time:116548ms step_avg:97.61ms +step:1195/1670 train_time:116646ms step_avg:97.61ms +step:1196/1670 train_time:116743ms step_avg:97.61ms +step:1197/1670 train_time:116841ms step_avg:97.61ms +step:1198/1670 train_time:116938ms step_avg:97.61ms +step:1199/1670 train_time:117037ms step_avg:97.61ms +step:1200/1670 train_time:117134ms step_avg:97.61ms +step:1201/1670 train_time:117233ms step_avg:97.61ms +step:1202/1670 train_time:117330ms step_avg:97.61ms +step:1203/1670 train_time:117428ms step_avg:97.61ms +step:1204/1670 train_time:117525ms step_avg:97.61ms +step:1205/1670 train_time:117623ms step_avg:97.61ms +step:1206/1670 train_time:117721ms step_avg:97.61ms +step:1207/1670 train_time:117818ms step_avg:97.61ms +step:1208/1670 train_time:117916ms step_avg:97.61ms +step:1209/1670 train_time:118014ms step_avg:97.61ms +step:1210/1670 train_time:118112ms step_avg:97.61ms +step:1211/1670 train_time:118209ms step_avg:97.61ms +step:1212/1670 train_time:118306ms step_avg:97.61ms +step:1213/1670 train_time:118403ms step_avg:97.61ms +step:1214/1670 train_time:118500ms step_avg:97.61ms +step:1215/1670 train_time:118599ms step_avg:97.61ms +step:1216/1670 train_time:118697ms step_avg:97.61ms +step:1217/1670 train_time:118795ms step_avg:97.61ms +step:1218/1670 train_time:118893ms step_avg:97.61ms +step:1219/1670 train_time:118991ms step_avg:97.61ms +step:1220/1670 train_time:119088ms step_avg:97.61ms +step:1221/1670 train_time:119186ms step_avg:97.61ms +step:1222/1670 train_time:119283ms step_avg:97.61ms +step:1223/1670 train_time:119380ms step_avg:97.61ms +step:1224/1670 train_time:119477ms step_avg:97.61ms +step:1225/1670 train_time:119575ms step_avg:97.61ms +step:1226/1670 train_time:119673ms step_avg:97.61ms +step:1227/1670 train_time:119771ms step_avg:97.61ms +step:1228/1670 train_time:119870ms step_avg:97.61ms +step:1229/1670 train_time:119967ms step_avg:97.61ms +step:1230/1670 train_time:120065ms step_avg:97.61ms +step:1231/1670 train_time:120162ms step_avg:97.61ms +step:1232/1670 train_time:120259ms step_avg:97.61ms +step:1233/1670 train_time:120356ms step_avg:97.61ms +step:1234/1670 train_time:120454ms step_avg:97.61ms +step:1235/1670 train_time:120554ms step_avg:97.61ms +step:1236/1670 train_time:120653ms step_avg:97.62ms +step:1237/1670 train_time:120751ms step_avg:97.62ms +step:1238/1670 train_time:120849ms step_avg:97.62ms +step:1239/1670 train_time:120947ms step_avg:97.62ms +step:1240/1670 train_time:121044ms step_avg:97.62ms +step:1241/1670 train_time:121142ms step_avg:97.62ms +step:1242/1670 train_time:121239ms step_avg:97.62ms +step:1243/1670 train_time:121336ms step_avg:97.62ms +step:1244/1670 train_time:121434ms 
step_avg:97.62ms +step:1245/1670 train_time:121533ms step_avg:97.62ms +step:1246/1670 train_time:121632ms step_avg:97.62ms +step:1247/1670 train_time:121730ms step_avg:97.62ms +step:1248/1670 train_time:121828ms step_avg:97.62ms +step:1249/1670 train_time:121926ms step_avg:97.62ms +step:1250/1670 train_time:122024ms step_avg:97.62ms +step:1250/1670 val_loss:3.3788 train_time:122121ms step_avg:97.70ms +step:1251/1670 train_time:122142ms step_avg:97.64ms +step:1252/1670 train_time:122227ms step_avg:97.63ms +step:1253/1670 train_time:122326ms step_avg:97.63ms +step:1254/1670 train_time:122424ms step_avg:97.63ms +step:1255/1670 train_time:122520ms step_avg:97.63ms +step:1256/1670 train_time:122617ms step_avg:97.63ms +step:1257/1670 train_time:122714ms step_avg:97.62ms +step:1258/1670 train_time:122810ms step_avg:97.62ms +step:1259/1670 train_time:122907ms step_avg:97.62ms +step:1260/1670 train_time:123003ms step_avg:97.62ms +step:1261/1670 train_time:123102ms step_avg:97.62ms +step:1262/1670 train_time:123204ms step_avg:97.63ms +step:1263/1670 train_time:123302ms step_avg:97.63ms +step:1264/1670 train_time:123401ms step_avg:97.63ms +step:1265/1670 train_time:123499ms step_avg:97.63ms +step:1266/1670 train_time:123596ms step_avg:97.63ms +step:1267/1670 train_time:123693ms step_avg:97.63ms +step:1268/1670 train_time:123790ms step_avg:97.63ms +step:1269/1670 train_time:123887ms step_avg:97.63ms +step:1270/1670 train_time:123984ms step_avg:97.63ms +step:1271/1670 train_time:124082ms step_avg:97.63ms +step:1272/1670 train_time:124181ms step_avg:97.63ms +step:1273/1670 train_time:124281ms step_avg:97.63ms +step:1274/1670 train_time:124553ms step_avg:97.77ms +step:1275/1670 train_time:124721ms step_avg:97.82ms +step:1276/1670 train_time:124816ms step_avg:97.82ms +step:1277/1670 train_time:124912ms step_avg:97.82ms +step:1278/1670 train_time:125009ms step_avg:97.82ms +step:1279/1670 train_time:125104ms step_avg:97.81ms +step:1280/1670 train_time:125201ms step_avg:97.81ms +step:1281/1670 train_time:125298ms step_avg:97.81ms +step:1282/1670 train_time:125395ms step_avg:97.81ms +step:1283/1670 train_time:125491ms step_avg:97.81ms +step:1284/1670 train_time:125590ms step_avg:97.81ms +step:1285/1670 train_time:125693ms step_avg:97.82ms +step:1286/1670 train_time:125792ms step_avg:97.82ms +step:1287/1670 train_time:125889ms step_avg:97.82ms +step:1288/1670 train_time:125986ms step_avg:97.81ms +step:1289/1670 train_time:126083ms step_avg:97.81ms +step:1290/1670 train_time:126180ms step_avg:97.81ms +step:1291/1670 train_time:126277ms step_avg:97.81ms +step:1292/1670 train_time:126374ms step_avg:97.81ms +step:1293/1670 train_time:126470ms step_avg:97.81ms +step:1294/1670 train_time:126568ms step_avg:97.81ms +step:1295/1670 train_time:126667ms step_avg:97.81ms +step:1296/1670 train_time:126765ms step_avg:97.81ms +step:1297/1670 train_time:126864ms step_avg:97.81ms +step:1298/1670 train_time:126962ms step_avg:97.81ms +step:1299/1670 train_time:127059ms step_avg:97.81ms +step:1300/1670 train_time:127157ms step_avg:97.81ms +step:1301/1670 train_time:127253ms step_avg:97.81ms +step:1302/1670 train_time:127350ms step_avg:97.81ms +step:1303/1670 train_time:127447ms step_avg:97.81ms +step:1304/1670 train_time:127544ms step_avg:97.81ms +step:1305/1670 train_time:127642ms step_avg:97.81ms +step:1306/1670 train_time:127741ms step_avg:97.81ms +step:1307/1670 train_time:127840ms step_avg:97.81ms +step:1308/1670 train_time:127940ms step_avg:97.81ms +step:1309/1670 train_time:128038ms step_avg:97.81ms +step:1310/1670 
train_time:128136ms step_avg:97.81ms +step:1311/1670 train_time:128233ms step_avg:97.81ms +step:1312/1670 train_time:128330ms step_avg:97.81ms +step:1313/1670 train_time:128427ms step_avg:97.81ms +step:1314/1670 train_time:128524ms step_avg:97.81ms +step:1315/1670 train_time:128622ms step_avg:97.81ms +step:1316/1670 train_time:128720ms step_avg:97.81ms +step:1317/1670 train_time:128818ms step_avg:97.81ms +step:1318/1670 train_time:128917ms step_avg:97.81ms +step:1319/1670 train_time:129014ms step_avg:97.81ms +step:1320/1670 train_time:129112ms step_avg:97.81ms +step:1321/1670 train_time:129209ms step_avg:97.81ms +step:1322/1670 train_time:129306ms step_avg:97.81ms +step:1323/1670 train_time:129404ms step_avg:97.81ms +step:1324/1670 train_time:129501ms step_avg:97.81ms +step:1325/1670 train_time:129598ms step_avg:97.81ms +step:1326/1670 train_time:129696ms step_avg:97.81ms +step:1327/1670 train_time:129795ms step_avg:97.81ms +step:1328/1670 train_time:129894ms step_avg:97.81ms +step:1329/1670 train_time:129991ms step_avg:97.81ms +step:1330/1670 train_time:130090ms step_avg:97.81ms +step:1331/1670 train_time:130187ms step_avg:97.81ms +step:1332/1670 train_time:130284ms step_avg:97.81ms +step:1333/1670 train_time:130381ms step_avg:97.81ms +step:1334/1670 train_time:130478ms step_avg:97.81ms +step:1335/1670 train_time:130576ms step_avg:97.81ms +step:1336/1670 train_time:130674ms step_avg:97.81ms +step:1337/1670 train_time:130773ms step_avg:97.81ms +step:1338/1670 train_time:130871ms step_avg:97.81ms +step:1339/1670 train_time:130969ms step_avg:97.81ms +step:1340/1670 train_time:131066ms step_avg:97.81ms +step:1341/1670 train_time:131164ms step_avg:97.81ms +step:1342/1670 train_time:131262ms step_avg:97.81ms +step:1343/1670 train_time:131360ms step_avg:97.81ms +step:1344/1670 train_time:131458ms step_avg:97.81ms +step:1345/1670 train_time:131556ms step_avg:97.81ms +step:1346/1670 train_time:131655ms step_avg:97.81ms +step:1347/1670 train_time:131753ms step_avg:97.81ms +step:1348/1670 train_time:131851ms step_avg:97.81ms +step:1349/1670 train_time:131949ms step_avg:97.81ms +step:1350/1670 train_time:132046ms step_avg:97.81ms +step:1351/1670 train_time:132144ms step_avg:97.81ms +step:1352/1670 train_time:132241ms step_avg:97.81ms +step:1353/1670 train_time:132339ms step_avg:97.81ms +step:1354/1670 train_time:132437ms step_avg:97.81ms +step:1355/1670 train_time:132535ms step_avg:97.81ms +step:1356/1670 train_time:132632ms step_avg:97.81ms +step:1357/1670 train_time:132730ms step_avg:97.81ms +step:1358/1670 train_time:132827ms step_avg:97.81ms +step:1359/1670 train_time:132925ms step_avg:97.81ms +step:1360/1670 train_time:133023ms step_avg:97.81ms +step:1361/1670 train_time:133121ms step_avg:97.81ms +step:1362/1670 train_time:133220ms step_avg:97.81ms +step:1363/1670 train_time:133318ms step_avg:97.81ms +step:1364/1670 train_time:133415ms step_avg:97.81ms +step:1365/1670 train_time:133512ms step_avg:97.81ms +step:1366/1670 train_time:133609ms step_avg:97.81ms +step:1367/1670 train_time:133706ms step_avg:97.81ms +step:1368/1670 train_time:133804ms step_avg:97.81ms +step:1369/1670 train_time:133902ms step_avg:97.81ms +step:1370/1670 train_time:134000ms step_avg:97.81ms +step:1371/1670 train_time:134097ms step_avg:97.81ms +step:1372/1670 train_time:134195ms step_avg:97.81ms +step:1373/1670 train_time:134293ms step_avg:97.81ms +step:1374/1670 train_time:134390ms step_avg:97.81ms +step:1375/1670 train_time:134488ms step_avg:97.81ms +step:1375/1670 val_loss:3.3416 train_time:134585ms step_avg:97.88ms 
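+# val_loss rows (every 125 steps here) fold the validation pass into the same
+# train_time counter, so their step_avg ticks up briefly:
+#   134585 / 1375 = 97.88  vs ~97.81 for the neighboring training-only rows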
+step:1376/1670 train_time:134606ms step_avg:97.82ms +step:1377/1670 train_time:134692ms step_avg:97.82ms +step:1378/1670 train_time:134791ms step_avg:97.82ms +step:1379/1670 train_time:134889ms step_avg:97.82ms +step:1380/1670 train_time:134986ms step_avg:97.82ms +step:1381/1670 train_time:135082ms step_avg:97.81ms +step:1382/1670 train_time:135180ms step_avg:97.81ms +step:1383/1670 train_time:135276ms step_avg:97.81ms +step:1384/1670 train_time:135374ms step_avg:97.81ms +step:1385/1670 train_time:135472ms step_avg:97.81ms +step:1386/1670 train_time:135570ms step_avg:97.81ms +step:1387/1670 train_time:135669ms step_avg:97.81ms +step:1388/1670 train_time:135769ms step_avg:97.82ms +step:1389/1670 train_time:135867ms step_avg:97.82ms +step:1390/1670 train_time:135965ms step_avg:97.82ms +step:1391/1670 train_time:136062ms step_avg:97.82ms +step:1392/1670 train_time:136159ms step_avg:97.82ms +step:1393/1670 train_time:136256ms step_avg:97.82ms +step:1394/1670 train_time:136353ms step_avg:97.81ms +step:1395/1670 train_time:136451ms step_avg:97.81ms +step:1396/1670 train_time:136549ms step_avg:97.81ms +step:1397/1670 train_time:136649ms step_avg:97.82ms +step:1398/1670 train_time:136749ms step_avg:97.82ms +step:1399/1670 train_time:136848ms step_avg:97.82ms +step:1400/1670 train_time:136946ms step_avg:97.82ms +step:1401/1670 train_time:137043ms step_avg:97.82ms +step:1402/1670 train_time:137140ms step_avg:97.82ms +step:1403/1670 train_time:137236ms step_avg:97.82ms +step:1404/1670 train_time:137333ms step_avg:97.82ms +step:1405/1670 train_time:137431ms step_avg:97.82ms +step:1406/1670 train_time:137528ms step_avg:97.82ms +step:1407/1670 train_time:137626ms step_avg:97.82ms +step:1408/1670 train_time:137726ms step_avg:97.82ms +step:1409/1670 train_time:137826ms step_avg:97.82ms +step:1410/1670 train_time:137924ms step_avg:97.82ms +step:1411/1670 train_time:138022ms step_avg:97.82ms +step:1412/1670 train_time:138120ms step_avg:97.82ms +step:1413/1670 train_time:138218ms step_avg:97.82ms +step:1414/1670 train_time:138314ms step_avg:97.82ms +step:1415/1670 train_time:138412ms step_avg:97.82ms +step:1416/1670 train_time:138509ms step_avg:97.82ms +step:1417/1670 train_time:138607ms step_avg:97.82ms +step:1418/1670 train_time:138705ms step_avg:97.82ms +step:1419/1670 train_time:138804ms step_avg:97.82ms +step:1420/1670 train_time:138903ms step_avg:97.82ms +step:1421/1670 train_time:139001ms step_avg:97.82ms +step:1422/1670 train_time:139100ms step_avg:97.82ms +step:1423/1670 train_time:139199ms step_avg:97.82ms +step:1424/1670 train_time:139296ms step_avg:97.82ms +step:1425/1670 train_time:139393ms step_avg:97.82ms +step:1426/1670 train_time:139490ms step_avg:97.82ms +step:1427/1670 train_time:139588ms step_avg:97.82ms +step:1428/1670 train_time:139686ms step_avg:97.82ms +step:1429/1670 train_time:139784ms step_avg:97.82ms +step:1430/1670 train_time:139883ms step_avg:97.82ms +step:1431/1670 train_time:139980ms step_avg:97.82ms +step:1432/1670 train_time:140078ms step_avg:97.82ms +step:1433/1670 train_time:140175ms step_avg:97.82ms +step:1434/1670 train_time:140273ms step_avg:97.82ms +step:1435/1670 train_time:140370ms step_avg:97.82ms +step:1436/1670 train_time:140467ms step_avg:97.82ms +step:1437/1670 train_time:140565ms step_avg:97.82ms +step:1438/1670 train_time:140663ms step_avg:97.82ms +step:1439/1670 train_time:140761ms step_avg:97.82ms +step:1440/1670 train_time:140859ms step_avg:97.82ms +step:1441/1670 train_time:140956ms step_avg:97.82ms +step:1442/1670 train_time:141054ms step_avg:97.82ms 
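+# A minimal sketch for turning rows like the above into plottable curves; the
+# filename is a placeholder for whichever record file is being read, and
+# re.search (rather than re.match) tolerates the leading "+" the lines carry
+# when viewed inside a diff:
+import re
+train_curve, val_curve = [], []
+with open("record.txt") as f:  # placeholder path
+    for line in f:
+        m = re.search(r"step:(\d+)/\d+ (?:val_loss:([\d.]+) )?train_time:(\d+)ms", line)
+        if m:
+            step, val, ms = int(m.group(1)), m.group(2), int(m.group(3))
+            train_curve.append((step, ms))
+            if val is not None:
+                val_curve.append((step, float(val)))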
+step:1443/1670 train_time:141152ms step_avg:97.82ms +step:1444/1670 train_time:141250ms step_avg:97.82ms +step:1445/1670 train_time:141347ms step_avg:97.82ms +step:1446/1670 train_time:141444ms step_avg:97.82ms +step:1447/1670 train_time:141542ms step_avg:97.82ms +step:1448/1670 train_time:141640ms step_avg:97.82ms +step:1449/1670 train_time:141737ms step_avg:97.82ms +step:1450/1670 train_time:141835ms step_avg:97.82ms +step:1451/1670 train_time:141932ms step_avg:97.82ms +step:1452/1670 train_time:142030ms step_avg:97.82ms +step:1453/1670 train_time:142128ms step_avg:97.82ms +step:1454/1670 train_time:142227ms step_avg:97.82ms +step:1455/1670 train_time:142325ms step_avg:97.82ms +step:1456/1670 train_time:142423ms step_avg:97.82ms +step:1457/1670 train_time:142521ms step_avg:97.82ms +step:1458/1670 train_time:142618ms step_avg:97.82ms +step:1459/1670 train_time:142716ms step_avg:97.82ms +step:1460/1670 train_time:142813ms step_avg:97.82ms +step:1461/1670 train_time:142910ms step_avg:97.82ms +step:1462/1670 train_time:143008ms step_avg:97.82ms +step:1463/1670 train_time:143106ms step_avg:97.82ms +step:1464/1670 train_time:143205ms step_avg:97.82ms +step:1465/1670 train_time:143304ms step_avg:97.82ms +step:1466/1670 train_time:143402ms step_avg:97.82ms +step:1467/1670 train_time:143500ms step_avg:97.82ms +step:1468/1670 train_time:143597ms step_avg:97.82ms +step:1469/1670 train_time:143695ms step_avg:97.82ms +step:1470/1670 train_time:143793ms step_avg:97.82ms +step:1471/1670 train_time:143890ms step_avg:97.82ms +step:1472/1670 train_time:143988ms step_avg:97.82ms +step:1473/1670 train_time:144086ms step_avg:97.82ms +step:1474/1670 train_time:144184ms step_avg:97.82ms +step:1475/1670 train_time:144282ms step_avg:97.82ms +step:1476/1670 train_time:144379ms step_avg:97.82ms +step:1477/1670 train_time:144476ms step_avg:97.82ms +step:1478/1670 train_time:144574ms step_avg:97.82ms +step:1479/1670 train_time:144671ms step_avg:97.82ms +step:1480/1670 train_time:144770ms step_avg:97.82ms +step:1481/1670 train_time:144867ms step_avg:97.82ms +step:1482/1670 train_time:144965ms step_avg:97.82ms +step:1483/1670 train_time:145063ms step_avg:97.82ms +step:1484/1670 train_time:145161ms step_avg:97.82ms +step:1485/1670 train_time:145431ms step_avg:97.93ms +step:1486/1670 train_time:145516ms step_avg:97.92ms +step:1487/1670 train_time:145612ms step_avg:97.92ms +step:1488/1670 train_time:145709ms step_avg:97.92ms +step:1489/1670 train_time:145806ms step_avg:97.92ms +step:1490/1670 train_time:145902ms step_avg:97.92ms +step:1491/1670 train_time:145998ms step_avg:97.92ms +step:1492/1670 train_time:146094ms step_avg:97.92ms +step:1493/1670 train_time:146191ms step_avg:97.92ms +step:1494/1670 train_time:146288ms step_avg:97.92ms +step:1495/1670 train_time:146392ms step_avg:97.92ms +step:1496/1670 train_time:146493ms step_avg:97.92ms +step:1497/1670 train_time:146592ms step_avg:97.92ms +step:1498/1670 train_time:146689ms step_avg:97.92ms +step:1499/1670 train_time:146787ms step_avg:97.92ms +step:1500/1670 train_time:146883ms step_avg:97.92ms +step:1500/1670 val_loss:3.3100 train_time:146979ms step_avg:97.99ms +step:1501/1670 train_time:146999ms step_avg:97.93ms +step:1502/1670 train_time:147083ms step_avg:97.92ms +step:1503/1670 train_time:147185ms step_avg:97.93ms +step:1504/1670 train_time:147282ms step_avg:97.93ms +step:1505/1670 train_time:147379ms step_avg:97.93ms +step:1506/1670 train_time:147476ms step_avg:97.93ms +step:1507/1670 train_time:147573ms step_avg:97.93ms +step:1508/1670 train_time:147671ms 
step_avg:97.92ms +step:1509/1670 train_time:147768ms step_avg:97.92ms +step:1510/1670 train_time:147865ms step_avg:97.92ms +step:1511/1670 train_time:147963ms step_avg:97.92ms +step:1512/1670 train_time:148063ms step_avg:97.93ms +step:1513/1670 train_time:148164ms step_avg:97.93ms +step:1514/1670 train_time:148262ms step_avg:97.93ms +step:1515/1670 train_time:148359ms step_avg:97.93ms +step:1516/1670 train_time:148456ms step_avg:97.93ms +step:1517/1670 train_time:148553ms step_avg:97.93ms +step:1518/1670 train_time:148650ms step_avg:97.92ms +step:1519/1670 train_time:148747ms step_avg:97.92ms +step:1520/1670 train_time:148845ms step_avg:97.92ms +step:1521/1670 train_time:148942ms step_avg:97.92ms +step:1522/1670 train_time:149041ms step_avg:97.92ms +step:1523/1670 train_time:149140ms step_avg:97.93ms +step:1524/1670 train_time:149239ms step_avg:97.93ms +step:1525/1670 train_time:149338ms step_avg:97.93ms +step:1526/1670 train_time:149435ms step_avg:97.93ms +step:1527/1670 train_time:149533ms step_avg:97.93ms +step:1528/1670 train_time:149631ms step_avg:97.93ms +step:1529/1670 train_time:149728ms step_avg:97.93ms +step:1530/1670 train_time:149825ms step_avg:97.92ms +step:1531/1670 train_time:149922ms step_avg:97.92ms +step:1532/1670 train_time:150021ms step_avg:97.92ms +step:1533/1670 train_time:150119ms step_avg:97.93ms +step:1534/1670 train_time:150218ms step_avg:97.93ms +step:1535/1670 train_time:150316ms step_avg:97.93ms +step:1536/1670 train_time:150413ms step_avg:97.93ms +step:1537/1670 train_time:150510ms step_avg:97.92ms +step:1538/1670 train_time:150607ms step_avg:97.92ms +step:1539/1670 train_time:150704ms step_avg:97.92ms +step:1540/1670 train_time:150801ms step_avg:97.92ms +step:1541/1670 train_time:150899ms step_avg:97.92ms +step:1542/1670 train_time:150998ms step_avg:97.92ms +step:1543/1670 train_time:151096ms step_avg:97.92ms +step:1544/1670 train_time:151195ms step_avg:97.92ms +step:1545/1670 train_time:151294ms step_avg:97.92ms +step:1546/1670 train_time:151392ms step_avg:97.92ms +step:1547/1670 train_time:151489ms step_avg:97.92ms +step:1548/1670 train_time:151587ms step_avg:97.92ms +step:1549/1670 train_time:151684ms step_avg:97.92ms +step:1550/1670 train_time:151781ms step_avg:97.92ms +step:1551/1670 train_time:151879ms step_avg:97.92ms +step:1552/1670 train_time:151975ms step_avg:97.92ms +step:1553/1670 train_time:152075ms step_avg:97.92ms +step:1554/1670 train_time:152174ms step_avg:97.92ms +step:1555/1670 train_time:152272ms step_avg:97.92ms +step:1556/1670 train_time:152370ms step_avg:97.92ms +step:1557/1670 train_time:152467ms step_avg:97.92ms +step:1558/1670 train_time:152564ms step_avg:97.92ms +step:1559/1670 train_time:152661ms step_avg:97.92ms +step:1560/1670 train_time:152758ms step_avg:97.92ms +step:1561/1670 train_time:152856ms step_avg:97.92ms +step:1562/1670 train_time:152955ms step_avg:97.92ms +step:1563/1670 train_time:153053ms step_avg:97.92ms +step:1564/1670 train_time:153150ms step_avg:97.92ms +step:1565/1670 train_time:153249ms step_avg:97.92ms +step:1566/1670 train_time:153348ms step_avg:97.92ms +step:1567/1670 train_time:153447ms step_avg:97.92ms +step:1568/1670 train_time:153544ms step_avg:97.92ms +step:1569/1670 train_time:153641ms step_avg:97.92ms +step:1570/1670 train_time:153738ms step_avg:97.92ms +step:1571/1670 train_time:153836ms step_avg:97.92ms +step:1572/1670 train_time:153934ms step_avg:97.92ms +step:1573/1670 train_time:154032ms step_avg:97.92ms +step:1574/1670 train_time:154130ms step_avg:97.92ms +step:1575/1670 train_time:154228ms 
step_avg:97.92ms +step:1576/1670 train_time:154326ms step_avg:97.92ms +step:1577/1670 train_time:154423ms step_avg:97.92ms +step:1578/1670 train_time:154521ms step_avg:97.92ms +step:1579/1670 train_time:154618ms step_avg:97.92ms +step:1580/1670 train_time:154715ms step_avg:97.92ms +step:1581/1670 train_time:154814ms step_avg:97.92ms +step:1582/1670 train_time:154912ms step_avg:97.92ms +step:1583/1670 train_time:155010ms step_avg:97.92ms +step:1584/1670 train_time:155108ms step_avg:97.92ms +step:1585/1670 train_time:155205ms step_avg:97.92ms +step:1586/1670 train_time:155303ms step_avg:97.92ms +step:1587/1670 train_time:155401ms step_avg:97.92ms +step:1588/1670 train_time:155498ms step_avg:97.92ms +step:1589/1670 train_time:155597ms step_avg:97.92ms +step:1590/1670 train_time:155696ms step_avg:97.92ms +step:1591/1670 train_time:155793ms step_avg:97.92ms +step:1592/1670 train_time:155891ms step_avg:97.92ms +step:1593/1670 train_time:155988ms step_avg:97.92ms +step:1594/1670 train_time:156085ms step_avg:97.92ms +step:1595/1670 train_time:156182ms step_avg:97.92ms +step:1596/1670 train_time:156280ms step_avg:97.92ms +step:1597/1670 train_time:156378ms step_avg:97.92ms +step:1598/1670 train_time:156477ms step_avg:97.92ms +step:1599/1670 train_time:156575ms step_avg:97.92ms +step:1600/1670 train_time:156673ms step_avg:97.92ms +step:1601/1670 train_time:156771ms step_avg:97.92ms +step:1602/1670 train_time:156868ms step_avg:97.92ms +step:1603/1670 train_time:156966ms step_avg:97.92ms +step:1604/1670 train_time:157062ms step_avg:97.92ms +step:1605/1670 train_time:157160ms step_avg:97.92ms +step:1606/1670 train_time:157258ms step_avg:97.92ms +step:1607/1670 train_time:157356ms step_avg:97.92ms +step:1608/1670 train_time:157454ms step_avg:97.92ms +step:1609/1670 train_time:157553ms step_avg:97.92ms +step:1610/1670 train_time:157650ms step_avg:97.92ms +step:1611/1670 train_time:157748ms step_avg:97.92ms +step:1612/1670 train_time:157846ms step_avg:97.92ms +step:1613/1670 train_time:157944ms step_avg:97.92ms +step:1614/1670 train_time:158041ms step_avg:97.92ms +step:1615/1670 train_time:158138ms step_avg:97.92ms +step:1616/1670 train_time:158236ms step_avg:97.92ms +step:1617/1670 train_time:158334ms step_avg:97.92ms +step:1618/1670 train_time:158432ms step_avg:97.92ms +step:1619/1670 train_time:158530ms step_avg:97.92ms +step:1620/1670 train_time:158627ms step_avg:97.92ms +step:1621/1670 train_time:158724ms step_avg:97.92ms +step:1622/1670 train_time:158822ms step_avg:97.92ms +step:1623/1670 train_time:158919ms step_avg:97.92ms +step:1624/1670 train_time:159017ms step_avg:97.92ms +step:1625/1670 train_time:159115ms step_avg:97.92ms +step:1625/1670 val_loss:3.2831 train_time:159213ms step_avg:97.98ms +step:1626/1670 train_time:159234ms step_avg:97.93ms +step:1627/1670 train_time:159317ms step_avg:97.92ms +step:1628/1670 train_time:159417ms step_avg:97.92ms +step:1629/1670 train_time:159515ms step_avg:97.92ms +step:1630/1670 train_time:159612ms step_avg:97.92ms +step:1631/1670 train_time:159709ms step_avg:97.92ms +step:1632/1670 train_time:159806ms step_avg:97.92ms +step:1633/1670 train_time:159902ms step_avg:97.92ms +step:1634/1670 train_time:160000ms step_avg:97.92ms +step:1635/1670 train_time:160097ms step_avg:97.92ms +step:1636/1670 train_time:160196ms step_avg:97.92ms +step:1637/1670 train_time:160296ms step_avg:97.92ms +step:1638/1670 train_time:160394ms step_avg:97.92ms +step:1639/1670 train_time:160493ms step_avg:97.92ms +step:1640/1670 train_time:160591ms step_avg:97.92ms +step:1641/1670 
train_time:160688ms step_avg:97.92ms +step:1642/1670 train_time:160784ms step_avg:97.92ms +step:1643/1670 train_time:160881ms step_avg:97.92ms +step:1644/1670 train_time:160979ms step_avg:97.92ms +step:1645/1670 train_time:161077ms step_avg:97.92ms +step:1646/1670 train_time:161174ms step_avg:97.92ms +step:1647/1670 train_time:161272ms step_avg:97.92ms +step:1648/1670 train_time:161371ms step_avg:97.92ms +step:1649/1670 train_time:161469ms step_avg:97.92ms +step:1650/1670 train_time:161567ms step_avg:97.92ms +step:1651/1670 train_time:161664ms step_avg:97.92ms +step:1652/1670 train_time:161762ms step_avg:97.92ms +step:1653/1670 train_time:161859ms step_avg:97.92ms +step:1654/1670 train_time:161957ms step_avg:97.92ms +step:1655/1670 train_time:162054ms step_avg:97.92ms +step:1656/1670 train_time:162151ms step_avg:97.92ms +step:1657/1670 train_time:162250ms step_avg:97.92ms +step:1658/1670 train_time:162348ms step_avg:97.92ms +step:1659/1670 train_time:162446ms step_avg:97.92ms +step:1660/1670 train_time:162545ms step_avg:97.92ms +step:1661/1670 train_time:162644ms step_avg:97.92ms +step:1662/1670 train_time:162742ms step_avg:97.92ms +step:1663/1670 train_time:162840ms step_avg:97.92ms +step:1664/1670 train_time:162937ms step_avg:97.92ms +step:1665/1670 train_time:163034ms step_avg:97.92ms +step:1666/1670 train_time:163132ms step_avg:97.92ms +step:1667/1670 train_time:163229ms step_avg:97.92ms +step:1668/1670 train_time:163328ms step_avg:97.92ms +step:1669/1670 train_time:163426ms step_avg:97.92ms +step:1670/1670 train_time:163524ms step_avg:97.92ms +step:1670/1670 val_loss:3.2755 train_time:163621ms step_avg:97.98ms +peak memory allocated: 34217 MiB reserved: 49936 MiB diff --git a/records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt b/records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt new file mode 100644 index 000000000..a11cc06d9 --- /dev/null +++ b/records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, 
x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, 
GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = 
torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
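+        # Descriptive sketch of the scheme below: parameters within each size-group are
+        # assigned round-robin across ranks (param base_i + rank belongs to this rank).
+        # Each rank reduce-scatters to obtain the averaged gradient of the params it owns,
+        # applies momentum and Newton-Schulz orthogonalization locally, then all-gathers
+        # the updated params back to every rank; async futures overlap comm with compute.
+        # newton_schulz_triton above iterates the quintic X <- a*X + (b*A + c*A@A) @ X
+        # with A = X @ X^T and (a, b, c) = (3.4445, -4.7750, 2.0315), using the ns_line_1/2
+        # Triton kernels for the symmetric products and addmm/baddbmm for the final line.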
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
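+                    # second-moment (v_t) buffer, matching exp_avg (m_t) above; the update
+                    # below is p <- p - lr * (sqrt(1-beta2^t)/(1-beta1^t)) * m_t / (sqrt(v_t) + eps)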
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:39:51 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 129W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 48531 C /usr/bin/python 0MiB | +| 0 N/A N/A 48532 C /usr/bin/python 0MiB | +| 0 N/A N/A 48533 C /usr/bin/python 0MiB | +| 0 N/A N/A 48534 C /usr/bin/python 0MiB | +| 0 N/A N/A 48535 C /usr/bin/python 0MiB | +| 0 N/A N/A 48536 C /usr/bin/python 0MiB | +| 0 N/A N/A 48537 C /usr/bin/python 0MiB | +| 0 N/A N/A 48538 C /usr/bin/python 0MiB | +| 1 N/A N/A 48532 C /usr/bin/python 0MiB | +| 2 N/A N/A 48533 C /usr/bin/python 0MiB | +| 3 N/A N/A 48534 C /usr/bin/python 0MiB | +| 4 N/A N/A 48535 C /usr/bin/python 0MiB | +| 5 N/A N/A 48536 C /usr/bin/python 0MiB | +| 6 N/A N/A 48537 C /usr/bin/python 0MiB | +| 7 N/A N/A 48538 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:407ms step_avg:407.50ms +step:2/1670 train_time:427ms step_avg:213.68ms +step:3/1670 train_time:501ms step_avg:167.04ms +step:4/1670 train_time:594ms step_avg:148.59ms +step:5/1670 train_time:689ms step_avg:137.71ms +step:6/1670 train_time:784ms step_avg:130.59ms +step:7/1670 train_time:878ms step_avg:125.44ms +step:8/1670 train_time:973ms step_avg:121.64ms +step:9/1670 train_time:1068ms 
step_avg:118.69ms +step:10/1670 train_time:1163ms step_avg:116.33ms +step:11/1670 train_time:1258ms step_avg:114.39ms +step:12/1670 train_time:1355ms step_avg:112.93ms +step:13/1670 train_time:1454ms step_avg:111.86ms +step:14/1670 train_time:1551ms step_avg:110.78ms +step:15/1670 train_time:1648ms step_avg:109.84ms +step:16/1670 train_time:1742ms step_avg:108.90ms +step:17/1670 train_time:1838ms step_avg:108.10ms +step:18/1670 train_time:1933ms step_avg:107.37ms +step:19/1670 train_time:2028ms step_avg:106.73ms +step:20/1670 train_time:2123ms step_avg:106.16ms +step:21/1670 train_time:2218ms step_avg:105.64ms +step:22/1670 train_time:2314ms step_avg:105.20ms +step:23/1670 train_time:2412ms step_avg:104.86ms +step:24/1670 train_time:2508ms step_avg:104.52ms +step:25/1670 train_time:2606ms step_avg:104.22ms +step:26/1670 train_time:2702ms step_avg:103.94ms +step:27/1670 train_time:2799ms step_avg:103.66ms +step:28/1670 train_time:2894ms step_avg:103.34ms +step:29/1670 train_time:2989ms step_avg:103.06ms +step:30/1670 train_time:3084ms step_avg:102.79ms +step:31/1670 train_time:3180ms step_avg:102.59ms +step:32/1670 train_time:3276ms step_avg:102.37ms +step:33/1670 train_time:3372ms step_avg:102.18ms +step:34/1670 train_time:3469ms step_avg:102.02ms +step:35/1670 train_time:3566ms step_avg:101.87ms +step:36/1670 train_time:3662ms step_avg:101.73ms +step:37/1670 train_time:3758ms step_avg:101.58ms +step:38/1670 train_time:3854ms step_avg:101.41ms +step:39/1670 train_time:3950ms step_avg:101.28ms +step:40/1670 train_time:4046ms step_avg:101.14ms +step:41/1670 train_time:4141ms step_avg:101.00ms +step:42/1670 train_time:4237ms step_avg:100.87ms +step:43/1670 train_time:4332ms step_avg:100.74ms +step:44/1670 train_time:4428ms step_avg:100.64ms +step:45/1670 train_time:4525ms step_avg:100.55ms +step:46/1670 train_time:4620ms step_avg:100.44ms +step:47/1670 train_time:4717ms step_avg:100.35ms +step:48/1670 train_time:4812ms step_avg:100.25ms +step:49/1670 train_time:4908ms step_avg:100.17ms +step:50/1670 train_time:5004ms step_avg:100.09ms +step:51/1670 train_time:5100ms step_avg:100.00ms +step:52/1670 train_time:5195ms step_avg:99.90ms +step:53/1670 train_time:5291ms step_avg:99.83ms +step:54/1670 train_time:5387ms step_avg:99.77ms +step:55/1670 train_time:5483ms step_avg:99.69ms +step:56/1670 train_time:5579ms step_avg:99.62ms +step:57/1670 train_time:5674ms step_avg:99.55ms +step:58/1670 train_time:5770ms step_avg:99.47ms +step:59/1670 train_time:5865ms step_avg:99.41ms +step:60/1670 train_time:5962ms step_avg:99.36ms +step:61/1670 train_time:6057ms step_avg:99.30ms +step:62/1670 train_time:6154ms step_avg:99.25ms +step:63/1670 train_time:6249ms step_avg:99.20ms +step:64/1670 train_time:6345ms step_avg:99.14ms +step:65/1670 train_time:6442ms step_avg:99.10ms +step:66/1670 train_time:6538ms step_avg:99.06ms +step:67/1670 train_time:6634ms step_avg:99.01ms +step:68/1670 train_time:6730ms step_avg:98.97ms +step:69/1670 train_time:6826ms step_avg:98.92ms +step:70/1670 train_time:6922ms step_avg:98.88ms +step:71/1670 train_time:7018ms step_avg:98.84ms +step:72/1670 train_time:7113ms step_avg:98.79ms +step:73/1670 train_time:7209ms step_avg:98.75ms +step:74/1670 train_time:7304ms step_avg:98.70ms +step:75/1670 train_time:7400ms step_avg:98.67ms +step:76/1670 train_time:7495ms step_avg:98.62ms +step:77/1670 train_time:7592ms step_avg:98.60ms +step:78/1670 train_time:7688ms step_avg:98.57ms +step:79/1670 train_time:7784ms step_avg:98.53ms +step:80/1670 train_time:7879ms step_avg:98.49ms +step:81/1670 
train_time:7975ms step_avg:98.46ms +step:82/1670 train_time:8071ms step_avg:98.43ms +step:83/1670 train_time:8167ms step_avg:98.40ms +step:84/1670 train_time:8263ms step_avg:98.37ms +step:85/1670 train_time:8358ms step_avg:98.33ms +step:86/1670 train_time:8454ms step_avg:98.30ms +step:87/1670 train_time:8549ms step_avg:98.27ms +step:88/1670 train_time:8645ms step_avg:98.24ms +step:89/1670 train_time:8741ms step_avg:98.21ms +step:90/1670 train_time:8836ms step_avg:98.18ms +step:91/1670 train_time:8932ms step_avg:98.15ms +step:92/1670 train_time:9028ms step_avg:98.13ms +step:93/1670 train_time:9125ms step_avg:98.11ms +step:94/1670 train_time:9220ms step_avg:98.08ms +step:95/1670 train_time:9315ms step_avg:98.06ms +step:96/1670 train_time:9411ms step_avg:98.03ms +step:97/1670 train_time:9507ms step_avg:98.01ms +step:98/1670 train_time:9603ms step_avg:97.99ms +step:99/1670 train_time:9699ms step_avg:97.97ms +step:100/1670 train_time:9795ms step_avg:97.95ms +step:101/1670 train_time:9892ms step_avg:97.94ms +step:102/1670 train_time:9987ms step_avg:97.91ms +step:103/1670 train_time:10082ms step_avg:97.88ms +step:104/1670 train_time:10177ms step_avg:97.86ms +step:105/1670 train_time:10273ms step_avg:97.84ms +step:106/1670 train_time:10369ms step_avg:97.83ms +step:107/1670 train_time:10465ms step_avg:97.81ms +step:108/1670 train_time:10561ms step_avg:97.79ms +step:109/1670 train_time:10656ms step_avg:97.76ms +step:110/1670 train_time:10752ms step_avg:97.75ms +step:111/1670 train_time:10848ms step_avg:97.73ms +step:112/1670 train_time:10944ms step_avg:97.72ms +step:113/1670 train_time:11039ms step_avg:97.69ms +step:114/1670 train_time:11134ms step_avg:97.67ms +step:115/1670 train_time:11230ms step_avg:97.65ms +step:116/1670 train_time:11326ms step_avg:97.64ms +step:117/1670 train_time:11422ms step_avg:97.62ms +step:118/1670 train_time:11518ms step_avg:97.61ms +step:119/1670 train_time:11614ms step_avg:97.59ms +step:120/1670 train_time:11709ms step_avg:97.58ms +step:121/1670 train_time:11805ms step_avg:97.57ms +step:122/1670 train_time:11901ms step_avg:97.55ms +step:123/1670 train_time:11997ms step_avg:97.53ms +step:124/1670 train_time:12092ms step_avg:97.52ms +step:125/1670 train_time:12189ms step_avg:97.51ms +step:125/1670 val_loss:4.2887 train_time:12284ms step_avg:98.27ms +step:126/1670 train_time:12306ms step_avg:97.67ms +step:127/1670 train_time:12388ms step_avg:97.54ms +step:128/1670 train_time:12494ms step_avg:97.61ms +step:129/1670 train_time:12591ms step_avg:97.61ms +step:130/1670 train_time:12686ms step_avg:97.59ms +step:131/1670 train_time:12781ms step_avg:97.56ms +step:132/1670 train_time:12876ms step_avg:97.55ms +step:133/1670 train_time:12971ms step_avg:97.53ms +step:134/1670 train_time:13066ms step_avg:97.51ms +step:135/1670 train_time:13160ms step_avg:97.48ms +step:136/1670 train_time:13255ms step_avg:97.46ms +step:137/1670 train_time:13353ms step_avg:97.47ms +step:138/1670 train_time:13452ms step_avg:97.47ms +step:139/1670 train_time:13550ms step_avg:97.48ms +step:140/1670 train_time:13646ms step_avg:97.47ms +step:141/1670 train_time:13742ms step_avg:97.46ms +step:142/1670 train_time:13838ms step_avg:97.45ms +step:143/1670 train_time:13933ms step_avg:97.43ms +step:144/1670 train_time:14028ms step_avg:97.42ms +step:145/1670 train_time:14123ms step_avg:97.40ms +step:146/1670 train_time:14218ms step_avg:97.38ms +step:147/1670 train_time:14314ms step_avg:97.37ms +step:148/1670 train_time:14410ms step_avg:97.37ms +step:149/1670 train_time:14507ms step_avg:97.36ms +step:150/1670 
train_time:14603ms step_avg:97.35ms +step:151/1670 train_time:14698ms step_avg:97.34ms +step:152/1670 train_time:14794ms step_avg:97.33ms +step:153/1670 train_time:14890ms step_avg:97.32ms +step:154/1670 train_time:14985ms step_avg:97.31ms +step:155/1670 train_time:15080ms step_avg:97.29ms +step:156/1670 train_time:15175ms step_avg:97.28ms +step:157/1670 train_time:15271ms step_avg:97.27ms +step:158/1670 train_time:15367ms step_avg:97.26ms +step:159/1670 train_time:15463ms step_avg:97.25ms +step:160/1670 train_time:15559ms step_avg:97.25ms +step:161/1670 train_time:15656ms step_avg:97.24ms +step:162/1670 train_time:15753ms step_avg:97.24ms +step:163/1670 train_time:15849ms step_avg:97.23ms +step:164/1670 train_time:15945ms step_avg:97.23ms +step:165/1670 train_time:16040ms step_avg:97.21ms +step:166/1670 train_time:16136ms step_avg:97.20ms +step:167/1670 train_time:16232ms step_avg:97.19ms +step:168/1670 train_time:16327ms step_avg:97.18ms +step:169/1670 train_time:16422ms step_avg:97.17ms +step:170/1670 train_time:16518ms step_avg:97.17ms +step:171/1670 train_time:16615ms step_avg:97.16ms +step:172/1670 train_time:16711ms step_avg:97.16ms +step:173/1670 train_time:16807ms step_avg:97.15ms +step:174/1670 train_time:16902ms step_avg:97.14ms +step:175/1670 train_time:16998ms step_avg:97.13ms +step:176/1670 train_time:17093ms step_avg:97.12ms +step:177/1670 train_time:17189ms step_avg:97.11ms +step:178/1670 train_time:17284ms step_avg:97.10ms +step:179/1670 train_time:17379ms step_avg:97.09ms +step:180/1670 train_time:17476ms step_avg:97.09ms +step:181/1670 train_time:17573ms step_avg:97.09ms +step:182/1670 train_time:17669ms step_avg:97.08ms +step:183/1670 train_time:17764ms step_avg:97.07ms +step:184/1670 train_time:17859ms step_avg:97.06ms +step:185/1670 train_time:17955ms step_avg:97.05ms +step:186/1670 train_time:18050ms step_avg:97.04ms +step:187/1670 train_time:18146ms step_avg:97.04ms +step:188/1670 train_time:18241ms step_avg:97.03ms +step:189/1670 train_time:18337ms step_avg:97.02ms +step:190/1670 train_time:18432ms step_avg:97.01ms +step:191/1670 train_time:18528ms step_avg:97.00ms +step:192/1670 train_time:18623ms step_avg:96.99ms +step:193/1670 train_time:18719ms step_avg:96.99ms +step:194/1670 train_time:18816ms step_avg:96.99ms +step:195/1670 train_time:18911ms step_avg:96.98ms +step:196/1670 train_time:19006ms step_avg:96.97ms +step:197/1670 train_time:19101ms step_avg:96.96ms +step:198/1670 train_time:19197ms step_avg:96.96ms +step:199/1670 train_time:19293ms step_avg:96.95ms +step:200/1670 train_time:19388ms step_avg:96.94ms +step:201/1670 train_time:19483ms step_avg:96.93ms +step:202/1670 train_time:19579ms step_avg:96.93ms +step:203/1670 train_time:19675ms step_avg:96.92ms +step:204/1670 train_time:19771ms step_avg:96.92ms +step:205/1670 train_time:19866ms step_avg:96.91ms +step:206/1670 train_time:19962ms step_avg:96.90ms +step:207/1670 train_time:20057ms step_avg:96.89ms +step:208/1670 train_time:20153ms step_avg:96.89ms +step:209/1670 train_time:20249ms step_avg:96.89ms +step:210/1670 train_time:20345ms step_avg:96.88ms +step:211/1670 train_time:20440ms step_avg:96.87ms +step:212/1670 train_time:20536ms step_avg:96.87ms +step:213/1670 train_time:20868ms step_avg:97.97ms +step:214/1670 train_time:20942ms step_avg:97.86ms +step:215/1670 train_time:21036ms step_avg:97.84ms +step:216/1670 train_time:21132ms step_avg:97.83ms +step:217/1670 train_time:21226ms step_avg:97.81ms +step:218/1670 train_time:21320ms step_avg:97.80ms +step:219/1670 train_time:21415ms step_avg:97.78ms 
+step:220/1670 train_time:21509ms step_avg:97.77ms +step:221/1670 train_time:21604ms step_avg:97.75ms +step:222/1670 train_time:21698ms step_avg:97.74ms +step:223/1670 train_time:21794ms step_avg:97.73ms +step:224/1670 train_time:21897ms step_avg:97.75ms +step:225/1670 train_time:21995ms step_avg:97.75ms +step:226/1670 train_time:22091ms step_avg:97.75ms +step:227/1670 train_time:22186ms step_avg:97.74ms +step:228/1670 train_time:22281ms step_avg:97.72ms +step:229/1670 train_time:22377ms step_avg:97.71ms +step:230/1670 train_time:22471ms step_avg:97.70ms +step:231/1670 train_time:22566ms step_avg:97.69ms +step:232/1670 train_time:22660ms step_avg:97.67ms +step:233/1670 train_time:22755ms step_avg:97.66ms +step:234/1670 train_time:22852ms step_avg:97.66ms +step:235/1670 train_time:22950ms step_avg:97.66ms +step:236/1670 train_time:23046ms step_avg:97.65ms +step:237/1670 train_time:23141ms step_avg:97.64ms +step:238/1670 train_time:23236ms step_avg:97.63ms +step:239/1670 train_time:23332ms step_avg:97.62ms +step:240/1670 train_time:23428ms step_avg:97.62ms +step:241/1670 train_time:23523ms step_avg:97.61ms +step:242/1670 train_time:23617ms step_avg:97.59ms +step:243/1670 train_time:23713ms step_avg:97.58ms +step:244/1670 train_time:23809ms step_avg:97.58ms +step:245/1670 train_time:23905ms step_avg:97.57ms +step:246/1670 train_time:24000ms step_avg:97.56ms +step:247/1670 train_time:24097ms step_avg:97.56ms +step:248/1670 train_time:24194ms step_avg:97.55ms +step:249/1670 train_time:24289ms step_avg:97.55ms +step:250/1670 train_time:24384ms step_avg:97.54ms +step:250/1670 val_loss:3.9604 train_time:24478ms step_avg:97.91ms +step:251/1670 train_time:24499ms step_avg:97.61ms +step:252/1670 train_time:24580ms step_avg:97.54ms +step:253/1670 train_time:24678ms step_avg:97.54ms +step:254/1670 train_time:24773ms step_avg:97.53ms +step:255/1670 train_time:24868ms step_avg:97.52ms +step:256/1670 train_time:24963ms step_avg:97.51ms +step:257/1670 train_time:25058ms step_avg:97.50ms +step:258/1670 train_time:25152ms step_avg:97.49ms +step:259/1670 train_time:25248ms step_avg:97.48ms +step:260/1670 train_time:25342ms step_avg:97.47ms +step:261/1670 train_time:25438ms step_avg:97.46ms +step:262/1670 train_time:25535ms step_avg:97.46ms +step:263/1670 train_time:25633ms step_avg:97.46ms +step:264/1670 train_time:25730ms step_avg:97.46ms +step:265/1670 train_time:25826ms step_avg:97.46ms +step:266/1670 train_time:25920ms step_avg:97.45ms +step:267/1670 train_time:26015ms step_avg:97.43ms +step:268/1670 train_time:26110ms step_avg:97.42ms +step:269/1670 train_time:26204ms step_avg:97.41ms +step:270/1670 train_time:26299ms step_avg:97.40ms +step:271/1670 train_time:26394ms step_avg:97.40ms +step:272/1670 train_time:26491ms step_avg:97.39ms +step:273/1670 train_time:26589ms step_avg:97.39ms +step:274/1670 train_time:26685ms step_avg:97.39ms +step:275/1670 train_time:26781ms step_avg:97.39ms +step:276/1670 train_time:26876ms step_avg:97.38ms +step:277/1670 train_time:26971ms step_avg:97.37ms +step:278/1670 train_time:27067ms step_avg:97.36ms +step:279/1670 train_time:27163ms step_avg:97.36ms +step:280/1670 train_time:27257ms step_avg:97.35ms +step:281/1670 train_time:27352ms step_avg:97.34ms +step:282/1670 train_time:27449ms step_avg:97.34ms +step:283/1670 train_time:27545ms step_avg:97.33ms +step:284/1670 train_time:27641ms step_avg:97.33ms +step:285/1670 train_time:27737ms step_avg:97.32ms +step:286/1670 train_time:27832ms step_avg:97.32ms +step:287/1670 train_time:27929ms step_avg:97.31ms +step:288/1670 
train_time:28024ms step_avg:97.30ms +step:289/1670 train_time:28119ms step_avg:97.30ms +step:290/1670 train_time:28213ms step_avg:97.29ms +step:291/1670 train_time:28308ms step_avg:97.28ms +step:292/1670 train_time:28404ms step_avg:97.27ms +step:293/1670 train_time:28500ms step_avg:97.27ms +step:294/1670 train_time:28596ms step_avg:97.26ms +step:295/1670 train_time:28692ms step_avg:97.26ms +step:296/1670 train_time:28788ms step_avg:97.26ms +step:297/1670 train_time:28884ms step_avg:97.25ms +step:298/1670 train_time:28979ms step_avg:97.25ms +step:299/1670 train_time:29075ms step_avg:97.24ms +step:300/1670 train_time:29170ms step_avg:97.23ms +step:301/1670 train_time:29265ms step_avg:97.22ms +step:302/1670 train_time:29361ms step_avg:97.22ms +step:303/1670 train_time:29456ms step_avg:97.22ms +step:304/1670 train_time:29552ms step_avg:97.21ms +step:305/1670 train_time:29648ms step_avg:97.21ms +step:306/1670 train_time:29744ms step_avg:97.20ms +step:307/1670 train_time:29839ms step_avg:97.20ms +step:308/1670 train_time:29935ms step_avg:97.19ms +step:309/1670 train_time:30031ms step_avg:97.19ms +step:310/1670 train_time:30126ms step_avg:97.18ms +step:311/1670 train_time:30222ms step_avg:97.18ms +step:312/1670 train_time:30317ms step_avg:97.17ms +step:313/1670 train_time:30412ms step_avg:97.16ms +step:314/1670 train_time:30507ms step_avg:97.16ms +step:315/1670 train_time:30603ms step_avg:97.15ms +step:316/1670 train_time:30699ms step_avg:97.15ms +step:317/1670 train_time:30794ms step_avg:97.14ms +step:318/1670 train_time:30889ms step_avg:97.14ms +step:319/1670 train_time:30986ms step_avg:97.13ms +step:320/1670 train_time:31082ms step_avg:97.13ms +step:321/1670 train_time:31177ms step_avg:97.13ms +step:322/1670 train_time:31273ms step_avg:97.12ms +step:323/1670 train_time:31367ms step_avg:97.11ms +step:324/1670 train_time:31463ms step_avg:97.11ms +step:325/1670 train_time:31558ms step_avg:97.10ms +step:326/1670 train_time:31653ms step_avg:97.10ms +step:327/1670 train_time:31750ms step_avg:97.09ms +step:328/1670 train_time:31847ms step_avg:97.09ms +step:329/1670 train_time:31943ms step_avg:97.09ms +step:330/1670 train_time:32039ms step_avg:97.09ms +step:331/1670 train_time:32134ms step_avg:97.08ms +step:332/1670 train_time:32230ms step_avg:97.08ms +step:333/1670 train_time:32326ms step_avg:97.08ms +step:334/1670 train_time:32421ms step_avg:97.07ms +step:335/1670 train_time:32516ms step_avg:97.06ms +step:336/1670 train_time:32611ms step_avg:97.06ms +step:337/1670 train_time:32707ms step_avg:97.05ms +step:338/1670 train_time:32803ms step_avg:97.05ms +step:339/1670 train_time:32899ms step_avg:97.05ms +step:340/1670 train_time:32994ms step_avg:97.04ms +step:341/1670 train_time:33089ms step_avg:97.04ms +step:342/1670 train_time:33185ms step_avg:97.03ms +step:343/1670 train_time:33281ms step_avg:97.03ms +step:344/1670 train_time:33376ms step_avg:97.02ms +step:345/1670 train_time:33471ms step_avg:97.02ms +step:346/1670 train_time:33567ms step_avg:97.01ms +step:347/1670 train_time:33663ms step_avg:97.01ms +step:348/1670 train_time:33758ms step_avg:97.00ms +step:349/1670 train_time:33854ms step_avg:97.00ms +step:350/1670 train_time:33950ms step_avg:97.00ms +step:351/1670 train_time:34045ms step_avg:96.99ms +step:352/1670 train_time:34141ms step_avg:96.99ms +step:353/1670 train_time:34236ms step_avg:96.99ms +step:354/1670 train_time:34331ms step_avg:96.98ms +step:355/1670 train_time:34427ms step_avg:96.98ms +step:356/1670 train_time:34523ms step_avg:96.97ms +step:357/1670 train_time:34618ms step_avg:96.97ms 
+step:358/1670 train_time:34713ms step_avg:96.96ms +step:359/1670 train_time:34809ms step_avg:96.96ms +step:360/1670 train_time:34905ms step_avg:96.96ms +step:361/1670 train_time:35001ms step_avg:96.95ms +step:362/1670 train_time:35096ms step_avg:96.95ms +step:363/1670 train_time:35192ms step_avg:96.95ms +step:364/1670 train_time:35288ms step_avg:96.94ms +step:365/1670 train_time:35384ms step_avg:96.94ms +step:366/1670 train_time:35480ms step_avg:96.94ms +step:367/1670 train_time:35575ms step_avg:96.93ms +step:368/1670 train_time:35671ms step_avg:96.93ms +step:369/1670 train_time:35766ms step_avg:96.93ms +step:370/1670 train_time:35862ms step_avg:96.92ms +step:371/1670 train_time:35957ms step_avg:96.92ms +step:372/1670 train_time:36052ms step_avg:96.91ms +step:373/1670 train_time:36149ms step_avg:96.91ms +step:374/1670 train_time:36245ms step_avg:96.91ms +step:375/1670 train_time:36341ms step_avg:96.91ms +step:375/1670 val_loss:3.8122 train_time:36435ms step_avg:97.16ms +step:376/1670 train_time:36456ms step_avg:96.96ms +step:377/1670 train_time:36540ms step_avg:96.92ms +step:378/1670 train_time:36639ms step_avg:96.93ms +step:379/1670 train_time:36736ms step_avg:96.93ms +step:380/1670 train_time:36831ms step_avg:96.92ms +step:381/1670 train_time:36926ms step_avg:96.92ms +step:382/1670 train_time:37021ms step_avg:96.91ms +step:383/1670 train_time:37116ms step_avg:96.91ms +step:384/1670 train_time:37210ms step_avg:96.90ms +step:385/1670 train_time:37304ms step_avg:96.89ms +step:386/1670 train_time:37400ms step_avg:96.89ms +step:387/1670 train_time:37497ms step_avg:96.89ms +step:388/1670 train_time:37595ms step_avg:96.89ms +step:389/1670 train_time:37691ms step_avg:96.89ms +step:390/1670 train_time:37786ms step_avg:96.89ms +step:391/1670 train_time:37882ms step_avg:96.89ms +step:392/1670 train_time:37978ms step_avg:96.88ms +step:393/1670 train_time:38073ms step_avg:96.88ms +step:394/1670 train_time:38167ms step_avg:96.87ms +step:395/1670 train_time:38262ms step_avg:96.87ms +step:396/1670 train_time:38357ms step_avg:96.86ms +step:397/1670 train_time:38453ms step_avg:96.86ms +step:398/1670 train_time:38549ms step_avg:96.86ms +step:399/1670 train_time:38646ms step_avg:96.86ms +step:400/1670 train_time:38743ms step_avg:96.86ms +step:401/1670 train_time:38840ms step_avg:96.86ms +step:402/1670 train_time:38936ms step_avg:96.85ms +step:403/1670 train_time:39031ms step_avg:96.85ms +step:404/1670 train_time:39126ms step_avg:96.85ms +step:405/1670 train_time:39221ms step_avg:96.84ms +step:406/1670 train_time:39317ms step_avg:96.84ms +step:407/1670 train_time:39412ms step_avg:96.84ms +step:408/1670 train_time:39508ms step_avg:96.83ms +step:409/1670 train_time:39605ms step_avg:96.83ms +step:410/1670 train_time:39701ms step_avg:96.83ms +step:411/1670 train_time:39797ms step_avg:96.83ms +step:412/1670 train_time:39893ms step_avg:96.83ms +step:413/1670 train_time:39988ms step_avg:96.82ms +step:414/1670 train_time:40083ms step_avg:96.82ms +step:415/1670 train_time:40178ms step_avg:96.82ms +step:416/1670 train_time:40274ms step_avg:96.81ms +step:417/1670 train_time:40370ms step_avg:96.81ms +step:418/1670 train_time:40465ms step_avg:96.81ms +step:419/1670 train_time:40561ms step_avg:96.80ms +step:420/1670 train_time:40658ms step_avg:96.80ms +step:421/1670 train_time:40753ms step_avg:96.80ms +step:422/1670 train_time:40849ms step_avg:96.80ms +step:423/1670 train_time:40945ms step_avg:96.80ms +step:424/1670 train_time:41041ms step_avg:96.80ms +step:425/1670 train_time:41337ms step_avg:97.26ms +step:426/1670 
train_time:41450ms step_avg:97.30ms +step:427/1670 train_time:41543ms step_avg:97.29ms +step:428/1670 train_time:41638ms step_avg:97.29ms +step:429/1670 train_time:41733ms step_avg:97.28ms +step:430/1670 train_time:41827ms step_avg:97.27ms +step:431/1670 train_time:41922ms step_avg:97.27ms +step:432/1670 train_time:42017ms step_avg:97.26ms +step:433/1670 train_time:42112ms step_avg:97.26ms +step:434/1670 train_time:42206ms step_avg:97.25ms +step:435/1670 train_time:42303ms step_avg:97.25ms +step:436/1670 train_time:42404ms step_avg:97.26ms +step:437/1670 train_time:42504ms step_avg:97.26ms +step:438/1670 train_time:42600ms step_avg:97.26ms +step:439/1670 train_time:42695ms step_avg:97.25ms +step:440/1670 train_time:42789ms step_avg:97.25ms +step:441/1670 train_time:42884ms step_avg:97.24ms +step:442/1670 train_time:42979ms step_avg:97.24ms +step:443/1670 train_time:43073ms step_avg:97.23ms +step:444/1670 train_time:43168ms step_avg:97.22ms +step:445/1670 train_time:43264ms step_avg:97.22ms +step:446/1670 train_time:43362ms step_avg:97.22ms +step:447/1670 train_time:43460ms step_avg:97.23ms +step:448/1670 train_time:43557ms step_avg:97.23ms +step:449/1670 train_time:43653ms step_avg:97.22ms +step:450/1670 train_time:43748ms step_avg:97.22ms +step:451/1670 train_time:43844ms step_avg:97.22ms +step:452/1670 train_time:43940ms step_avg:97.21ms +step:453/1670 train_time:44034ms step_avg:97.21ms +step:454/1670 train_time:44129ms step_avg:97.20ms +step:455/1670 train_time:44224ms step_avg:97.19ms +step:456/1670 train_time:44320ms step_avg:97.19ms +step:457/1670 train_time:44417ms step_avg:97.19ms +step:458/1670 train_time:44513ms step_avg:97.19ms +step:459/1670 train_time:44609ms step_avg:97.19ms +step:460/1670 train_time:44704ms step_avg:97.18ms +step:461/1670 train_time:44800ms step_avg:97.18ms +step:462/1670 train_time:44895ms step_avg:97.18ms +step:463/1670 train_time:44990ms step_avg:97.17ms +step:464/1670 train_time:45084ms step_avg:97.16ms +step:465/1670 train_time:45179ms step_avg:97.16ms +step:466/1670 train_time:45274ms step_avg:97.16ms +step:467/1670 train_time:45370ms step_avg:97.15ms +step:468/1670 train_time:45466ms step_avg:97.15ms +step:469/1670 train_time:45563ms step_avg:97.15ms +step:470/1670 train_time:45660ms step_avg:97.15ms +step:471/1670 train_time:45756ms step_avg:97.15ms +step:472/1670 train_time:45851ms step_avg:97.14ms +step:473/1670 train_time:45946ms step_avg:97.14ms +step:474/1670 train_time:46041ms step_avg:97.13ms +step:475/1670 train_time:46136ms step_avg:97.13ms +step:476/1670 train_time:46231ms step_avg:97.12ms +step:477/1670 train_time:46326ms step_avg:97.12ms +step:478/1670 train_time:46422ms step_avg:97.12ms +step:479/1670 train_time:46519ms step_avg:97.12ms +step:480/1670 train_time:46615ms step_avg:97.11ms +step:481/1670 train_time:46710ms step_avg:97.11ms +step:482/1670 train_time:46806ms step_avg:97.11ms +step:483/1670 train_time:46901ms step_avg:97.10ms +step:484/1670 train_time:46997ms step_avg:97.10ms +step:485/1670 train_time:47092ms step_avg:97.10ms +step:486/1670 train_time:47186ms step_avg:97.09ms +step:487/1670 train_time:47282ms step_avg:97.09ms +step:488/1670 train_time:47378ms step_avg:97.09ms +step:489/1670 train_time:47473ms step_avg:97.08ms +step:490/1670 train_time:47568ms step_avg:97.08ms +step:491/1670 train_time:47665ms step_avg:97.08ms +step:492/1670 train_time:47761ms step_avg:97.07ms +step:493/1670 train_time:47856ms step_avg:97.07ms +step:494/1670 train_time:47952ms step_avg:97.07ms +step:495/1670 train_time:48046ms step_avg:97.06ms 
+step:496/1670 train_time:48142ms step_avg:97.06ms +step:497/1670 train_time:48238ms step_avg:97.06ms +step:498/1670 train_time:48334ms step_avg:97.06ms +step:499/1670 train_time:48429ms step_avg:97.05ms +step:500/1670 train_time:48524ms step_avg:97.05ms +step:500/1670 val_loss:3.7107 train_time:48620ms step_avg:97.24ms +step:501/1670 train_time:48640ms step_avg:97.09ms +step:502/1670 train_time:48722ms step_avg:97.06ms +step:503/1670 train_time:48823ms step_avg:97.06ms +step:504/1670 train_time:48919ms step_avg:97.06ms +step:505/1670 train_time:49014ms step_avg:97.06ms +step:506/1670 train_time:49109ms step_avg:97.05ms +step:507/1670 train_time:49204ms step_avg:97.05ms +step:508/1670 train_time:49298ms step_avg:97.04ms +step:509/1670 train_time:49393ms step_avg:97.04ms +step:510/1670 train_time:49489ms step_avg:97.04ms +step:511/1670 train_time:49583ms step_avg:97.03ms +step:512/1670 train_time:49679ms step_avg:97.03ms +step:513/1670 train_time:49777ms step_avg:97.03ms +step:514/1670 train_time:49876ms step_avg:97.03ms +step:515/1670 train_time:49973ms step_avg:97.03ms +step:516/1670 train_time:50068ms step_avg:97.03ms +step:517/1670 train_time:50164ms step_avg:97.03ms +step:518/1670 train_time:50259ms step_avg:97.02ms +step:519/1670 train_time:50353ms step_avg:97.02ms +step:520/1670 train_time:50448ms step_avg:97.02ms +step:521/1670 train_time:50544ms step_avg:97.01ms +step:522/1670 train_time:50639ms step_avg:97.01ms +step:523/1670 train_time:50735ms step_avg:97.01ms +step:524/1670 train_time:50832ms step_avg:97.01ms +step:525/1670 train_time:50929ms step_avg:97.01ms +step:526/1670 train_time:51025ms step_avg:97.01ms +step:527/1670 train_time:51119ms step_avg:97.00ms +step:528/1670 train_time:51214ms step_avg:97.00ms +step:529/1670 train_time:51310ms step_avg:96.99ms +step:530/1670 train_time:51406ms step_avg:96.99ms +step:531/1670 train_time:51500ms step_avg:96.99ms +step:532/1670 train_time:51596ms step_avg:96.98ms +step:533/1670 train_time:51691ms step_avg:96.98ms +step:534/1670 train_time:51788ms step_avg:96.98ms +step:535/1670 train_time:51884ms step_avg:96.98ms +step:536/1670 train_time:51980ms step_avg:96.98ms +step:537/1670 train_time:52076ms step_avg:96.98ms +step:538/1670 train_time:52172ms step_avg:96.97ms +step:539/1670 train_time:52267ms step_avg:96.97ms +step:540/1670 train_time:52362ms step_avg:96.97ms +step:541/1670 train_time:52456ms step_avg:96.96ms +step:542/1670 train_time:52552ms step_avg:96.96ms +step:543/1670 train_time:52648ms step_avg:96.96ms +step:544/1670 train_time:52744ms step_avg:96.96ms +step:545/1670 train_time:52840ms step_avg:96.95ms +step:546/1670 train_time:52936ms step_avg:96.95ms +step:547/1670 train_time:53031ms step_avg:96.95ms +step:548/1670 train_time:53127ms step_avg:96.95ms +step:549/1670 train_time:53223ms step_avg:96.95ms +step:550/1670 train_time:53318ms step_avg:96.94ms +step:551/1670 train_time:53413ms step_avg:96.94ms +step:552/1670 train_time:53509ms step_avg:96.94ms +step:553/1670 train_time:53604ms step_avg:96.93ms +step:554/1670 train_time:53699ms step_avg:96.93ms +step:555/1670 train_time:53796ms step_avg:96.93ms +step:556/1670 train_time:53891ms step_avg:96.93ms +step:557/1670 train_time:53987ms step_avg:96.93ms +step:558/1670 train_time:54084ms step_avg:96.92ms +step:559/1670 train_time:54180ms step_avg:96.92ms +step:560/1670 train_time:54277ms step_avg:96.92ms +step:561/1670 train_time:54374ms step_avg:96.92ms +step:562/1670 train_time:54471ms step_avg:96.92ms +step:563/1670 train_time:54567ms step_avg:96.92ms +step:564/1670 
train_time:54665ms step_avg:96.92ms +step:565/1670 train_time:54761ms step_avg:96.92ms +step:566/1670 train_time:54857ms step_avg:96.92ms +step:567/1670 train_time:54954ms step_avg:96.92ms +step:568/1670 train_time:55052ms step_avg:96.92ms +step:569/1670 train_time:55150ms step_avg:96.92ms +step:570/1670 train_time:55247ms step_avg:96.92ms +step:571/1670 train_time:55344ms step_avg:96.92ms +step:572/1670 train_time:55441ms step_avg:96.92ms +step:573/1670 train_time:55537ms step_avg:96.92ms +step:574/1670 train_time:55635ms step_avg:96.92ms +step:575/1670 train_time:55732ms step_avg:96.93ms +step:576/1670 train_time:55830ms step_avg:96.93ms +step:577/1670 train_time:55927ms step_avg:96.93ms +step:578/1670 train_time:56024ms step_avg:96.93ms +step:579/1670 train_time:56121ms step_avg:96.93ms +step:580/1670 train_time:56217ms step_avg:96.93ms +step:581/1670 train_time:56315ms step_avg:96.93ms +step:582/1670 train_time:56412ms step_avg:96.93ms +step:583/1670 train_time:56510ms step_avg:96.93ms +step:584/1670 train_time:56607ms step_avg:96.93ms +step:585/1670 train_time:56704ms step_avg:96.93ms +step:586/1670 train_time:56801ms step_avg:96.93ms +step:587/1670 train_time:56898ms step_avg:96.93ms +step:588/1670 train_time:56995ms step_avg:96.93ms +step:589/1670 train_time:57092ms step_avg:96.93ms +step:590/1670 train_time:57190ms step_avg:96.93ms +step:591/1670 train_time:57287ms step_avg:96.93ms +step:592/1670 train_time:57384ms step_avg:96.93ms +step:593/1670 train_time:57480ms step_avg:96.93ms +step:594/1670 train_time:57576ms step_avg:96.93ms +step:595/1670 train_time:57674ms step_avg:96.93ms +step:596/1670 train_time:57771ms step_avg:96.93ms +step:597/1670 train_time:57869ms step_avg:96.93ms +step:598/1670 train_time:57966ms step_avg:96.93ms +step:599/1670 train_time:58062ms step_avg:96.93ms +step:600/1670 train_time:58159ms step_avg:96.93ms +step:601/1670 train_time:58256ms step_avg:96.93ms +step:602/1670 train_time:58355ms step_avg:96.94ms +step:603/1670 train_time:58452ms step_avg:96.94ms +step:604/1670 train_time:58550ms step_avg:96.94ms +step:605/1670 train_time:58646ms step_avg:96.94ms +step:606/1670 train_time:58744ms step_avg:96.94ms +step:607/1670 train_time:58841ms step_avg:96.94ms +step:608/1670 train_time:58937ms step_avg:96.94ms +step:609/1670 train_time:59034ms step_avg:96.94ms +step:610/1670 train_time:59132ms step_avg:96.94ms +step:611/1670 train_time:59229ms step_avg:96.94ms +step:612/1670 train_time:59327ms step_avg:96.94ms +step:613/1670 train_time:59424ms step_avg:96.94ms +step:614/1670 train_time:59521ms step_avg:96.94ms +step:615/1670 train_time:59617ms step_avg:96.94ms +step:616/1670 train_time:59714ms step_avg:96.94ms +step:617/1670 train_time:59812ms step_avg:96.94ms +step:618/1670 train_time:59909ms step_avg:96.94ms +step:619/1670 train_time:60007ms step_avg:96.94ms +step:620/1670 train_time:60103ms step_avg:96.94ms +step:621/1670 train_time:60199ms step_avg:96.94ms +step:622/1670 train_time:60296ms step_avg:96.94ms +step:623/1670 train_time:60393ms step_avg:96.94ms +step:624/1670 train_time:60490ms step_avg:96.94ms +step:625/1670 train_time:60588ms step_avg:96.94ms +step:625/1670 val_loss:3.6144 train_time:60684ms step_avg:97.09ms +step:626/1670 train_time:60705ms step_avg:96.97ms +step:627/1670 train_time:60795ms step_avg:96.96ms +step:628/1670 train_time:60895ms step_avg:96.97ms +step:629/1670 train_time:60991ms step_avg:96.96ms +step:630/1670 train_time:61086ms step_avg:96.96ms +step:631/1670 train_time:61182ms step_avg:96.96ms +step:632/1670 train_time:61277ms 
step_avg:96.96ms +step:633/1670 train_time:61373ms step_avg:96.96ms +step:634/1670 train_time:61469ms step_avg:96.95ms +step:635/1670 train_time:61565ms step_avg:96.95ms +step:636/1670 train_time:61662ms step_avg:96.95ms +step:637/1670 train_time:61763ms step_avg:96.96ms +step:638/1670 train_time:61863ms step_avg:96.96ms +step:639/1670 train_time:62238ms step_avg:97.40ms +step:640/1670 train_time:62338ms step_avg:97.40ms +step:641/1670 train_time:62433ms step_avg:97.40ms +step:642/1670 train_time:62529ms step_avg:97.40ms +step:643/1670 train_time:62625ms step_avg:97.39ms +step:644/1670 train_time:62721ms step_avg:97.39ms +step:645/1670 train_time:62816ms step_avg:97.39ms +step:646/1670 train_time:62912ms step_avg:97.39ms +step:647/1670 train_time:63008ms step_avg:97.38ms +step:648/1670 train_time:63104ms step_avg:97.38ms +step:649/1670 train_time:63201ms step_avg:97.38ms +step:650/1670 train_time:63305ms step_avg:97.39ms +step:651/1670 train_time:63406ms step_avg:97.40ms +step:652/1670 train_time:63505ms step_avg:97.40ms +step:653/1670 train_time:63602ms step_avg:97.40ms +step:654/1670 train_time:63699ms step_avg:97.40ms +step:655/1670 train_time:63795ms step_avg:97.40ms +step:656/1670 train_time:63890ms step_avg:97.39ms +step:657/1670 train_time:63985ms step_avg:97.39ms +step:658/1670 train_time:64082ms step_avg:97.39ms +step:659/1670 train_time:64178ms step_avg:97.39ms +step:660/1670 train_time:64277ms step_avg:97.39ms +step:661/1670 train_time:64374ms step_avg:97.39ms +step:662/1670 train_time:64471ms step_avg:97.39ms +step:663/1670 train_time:64569ms step_avg:97.39ms +step:664/1670 train_time:64667ms step_avg:97.39ms +step:665/1670 train_time:64764ms step_avg:97.39ms +step:666/1670 train_time:64861ms step_avg:97.39ms +step:667/1670 train_time:64957ms step_avg:97.39ms +step:668/1670 train_time:65053ms step_avg:97.38ms +step:669/1670 train_time:65149ms step_avg:97.38ms +step:670/1670 train_time:65247ms step_avg:97.38ms +step:671/1670 train_time:65346ms step_avg:97.39ms +step:672/1670 train_time:65444ms step_avg:97.39ms +step:673/1670 train_time:65542ms step_avg:97.39ms +step:674/1670 train_time:65640ms step_avg:97.39ms +step:675/1670 train_time:65737ms step_avg:97.39ms +step:676/1670 train_time:65833ms step_avg:97.39ms +step:677/1670 train_time:65929ms step_avg:97.38ms +step:678/1670 train_time:66025ms step_avg:97.38ms +step:679/1670 train_time:66123ms step_avg:97.38ms +step:680/1670 train_time:66220ms step_avg:97.38ms +step:681/1670 train_time:66317ms step_avg:97.38ms +step:682/1670 train_time:66414ms step_avg:97.38ms +step:683/1670 train_time:66512ms step_avg:97.38ms +step:684/1670 train_time:66609ms step_avg:97.38ms +step:685/1670 train_time:66709ms step_avg:97.39ms +step:686/1670 train_time:66808ms step_avg:97.39ms +step:687/1670 train_time:66905ms step_avg:97.39ms +step:688/1670 train_time:67001ms step_avg:97.39ms +step:689/1670 train_time:67097ms step_avg:97.38ms +step:690/1670 train_time:67193ms step_avg:97.38ms +step:691/1670 train_time:67290ms step_avg:97.38ms +step:692/1670 train_time:67387ms step_avg:97.38ms +step:693/1670 train_time:67485ms step_avg:97.38ms +step:694/1670 train_time:67583ms step_avg:97.38ms +step:695/1670 train_time:67681ms step_avg:97.38ms +step:696/1670 train_time:67778ms step_avg:97.38ms +step:697/1670 train_time:67874ms step_avg:97.38ms +step:698/1670 train_time:67971ms step_avg:97.38ms +step:699/1670 train_time:68068ms step_avg:97.38ms +step:700/1670 train_time:68165ms step_avg:97.38ms +step:701/1670 train_time:68262ms step_avg:97.38ms +step:702/1670 
train_time:68360ms step_avg:97.38ms +step:703/1670 train_time:68457ms step_avg:97.38ms +step:704/1670 train_time:68554ms step_avg:97.38ms +step:705/1670 train_time:68650ms step_avg:97.38ms +step:706/1670 train_time:68748ms step_avg:97.38ms +step:707/1670 train_time:68847ms step_avg:97.38ms +step:708/1670 train_time:68944ms step_avg:97.38ms +step:709/1670 train_time:69041ms step_avg:97.38ms +step:710/1670 train_time:69137ms step_avg:97.38ms +step:711/1670 train_time:69234ms step_avg:97.38ms +step:712/1670 train_time:69331ms step_avg:97.37ms +step:713/1670 train_time:69428ms step_avg:97.37ms +step:714/1670 train_time:69525ms step_avg:97.37ms +step:715/1670 train_time:69622ms step_avg:97.37ms +step:716/1670 train_time:69720ms step_avg:97.37ms +step:717/1670 train_time:69817ms step_avg:97.37ms +step:718/1670 train_time:69913ms step_avg:97.37ms +step:719/1670 train_time:70010ms step_avg:97.37ms +step:720/1670 train_time:70107ms step_avg:97.37ms +step:721/1670 train_time:70206ms step_avg:97.37ms +step:722/1670 train_time:70303ms step_avg:97.37ms +step:723/1670 train_time:70401ms step_avg:97.37ms +step:724/1670 train_time:70498ms step_avg:97.37ms +step:725/1670 train_time:70595ms step_avg:97.37ms +step:726/1670 train_time:70691ms step_avg:97.37ms +step:727/1670 train_time:70789ms step_avg:97.37ms +step:728/1670 train_time:70887ms step_avg:97.37ms +step:729/1670 train_time:70984ms step_avg:97.37ms +step:730/1670 train_time:71081ms step_avg:97.37ms +step:731/1670 train_time:71177ms step_avg:97.37ms +step:732/1670 train_time:71274ms step_avg:97.37ms +step:733/1670 train_time:71370ms step_avg:97.37ms +step:734/1670 train_time:71469ms step_avg:97.37ms +step:735/1670 train_time:71566ms step_avg:97.37ms +step:736/1670 train_time:71663ms step_avg:97.37ms +step:737/1670 train_time:71761ms step_avg:97.37ms +step:738/1670 train_time:71857ms step_avg:97.37ms +step:739/1670 train_time:71954ms step_avg:97.37ms +step:740/1670 train_time:72051ms step_avg:97.37ms +step:741/1670 train_time:72148ms step_avg:97.37ms +step:742/1670 train_time:72246ms step_avg:97.37ms +step:743/1670 train_time:72343ms step_avg:97.37ms +step:744/1670 train_time:72441ms step_avg:97.37ms +step:745/1670 train_time:72538ms step_avg:97.37ms +step:746/1670 train_time:72634ms step_avg:97.36ms +step:747/1670 train_time:72730ms step_avg:97.36ms +step:748/1670 train_time:72827ms step_avg:97.36ms +step:749/1670 train_time:72925ms step_avg:97.36ms +step:750/1670 train_time:73023ms step_avg:97.36ms +step:750/1670 val_loss:3.5600 train_time:73120ms step_avg:97.49ms +step:751/1670 train_time:73141ms step_avg:97.39ms +step:752/1670 train_time:73225ms step_avg:97.37ms +step:753/1670 train_time:73324ms step_avg:97.38ms +step:754/1670 train_time:73422ms step_avg:97.38ms +step:755/1670 train_time:73518ms step_avg:97.37ms +step:756/1670 train_time:73614ms step_avg:97.37ms +step:757/1670 train_time:73710ms step_avg:97.37ms +step:758/1670 train_time:73806ms step_avg:97.37ms +step:759/1670 train_time:73902ms step_avg:97.37ms +step:760/1670 train_time:73999ms step_avg:97.37ms +step:761/1670 train_time:74097ms step_avg:97.37ms +step:762/1670 train_time:74198ms step_avg:97.37ms +step:763/1670 train_time:74298ms step_avg:97.38ms +step:764/1670 train_time:74397ms step_avg:97.38ms +step:765/1670 train_time:74494ms step_avg:97.38ms +step:766/1670 train_time:74590ms step_avg:97.38ms +step:767/1670 train_time:74686ms step_avg:97.37ms +step:768/1670 train_time:74782ms step_avg:97.37ms +step:769/1670 train_time:74879ms step_avg:97.37ms +step:770/1670 train_time:74976ms 
step_avg:97.37ms +step:771/1670 train_time:75073ms step_avg:97.37ms +step:772/1670 train_time:75171ms step_avg:97.37ms +step:773/1670 train_time:75268ms step_avg:97.37ms +step:774/1670 train_time:75367ms step_avg:97.37ms +step:775/1670 train_time:75464ms step_avg:97.37ms +step:776/1670 train_time:75561ms step_avg:97.37ms +step:777/1670 train_time:75660ms step_avg:97.37ms +step:778/1670 train_time:75756ms step_avg:97.37ms +step:779/1670 train_time:75853ms step_avg:97.37ms +step:780/1670 train_time:75949ms step_avg:97.37ms +step:781/1670 train_time:76045ms step_avg:97.37ms +step:782/1670 train_time:76142ms step_avg:97.37ms +step:783/1670 train_time:76241ms step_avg:97.37ms +step:784/1670 train_time:76339ms step_avg:97.37ms +step:785/1670 train_time:76437ms step_avg:97.37ms +step:786/1670 train_time:76534ms step_avg:97.37ms +step:787/1670 train_time:76631ms step_avg:97.37ms +step:788/1670 train_time:76727ms step_avg:97.37ms +step:789/1670 train_time:76824ms step_avg:97.37ms +step:790/1670 train_time:76921ms step_avg:97.37ms +step:791/1670 train_time:77018ms step_avg:97.37ms +step:792/1670 train_time:77115ms step_avg:97.37ms +step:793/1670 train_time:77214ms step_avg:97.37ms +step:794/1670 train_time:77310ms step_avg:97.37ms +step:795/1670 train_time:77406ms step_avg:97.37ms +step:796/1670 train_time:77504ms step_avg:97.37ms +step:797/1670 train_time:77602ms step_avg:97.37ms +step:798/1670 train_time:77700ms step_avg:97.37ms +step:799/1670 train_time:77796ms step_avg:97.37ms +step:800/1670 train_time:77894ms step_avg:97.37ms +step:801/1670 train_time:77991ms step_avg:97.37ms +step:802/1670 train_time:78087ms step_avg:97.37ms +step:803/1670 train_time:78184ms step_avg:97.37ms +step:804/1670 train_time:78283ms step_avg:97.37ms +step:805/1670 train_time:78381ms step_avg:97.37ms +step:806/1670 train_time:78479ms step_avg:97.37ms +step:807/1670 train_time:78577ms step_avg:97.37ms +step:808/1670 train_time:78674ms step_avg:97.37ms +step:809/1670 train_time:78771ms step_avg:97.37ms +step:810/1670 train_time:78867ms step_avg:97.37ms +step:811/1670 train_time:78963ms step_avg:97.37ms +step:812/1670 train_time:79061ms step_avg:97.37ms +step:813/1670 train_time:79159ms step_avg:97.37ms +step:814/1670 train_time:79256ms step_avg:97.37ms +step:815/1670 train_time:79355ms step_avg:97.37ms +step:816/1670 train_time:79452ms step_avg:97.37ms +step:817/1670 train_time:79548ms step_avg:97.37ms +step:818/1670 train_time:79644ms step_avg:97.36ms +step:819/1670 train_time:79741ms step_avg:97.36ms +step:820/1670 train_time:79838ms step_avg:97.36ms +step:821/1670 train_time:79936ms step_avg:97.36ms +step:822/1670 train_time:80033ms step_avg:97.36ms +step:823/1670 train_time:80129ms step_avg:97.36ms +step:824/1670 train_time:80225ms step_avg:97.36ms +step:825/1670 train_time:80323ms step_avg:97.36ms +step:826/1670 train_time:80423ms step_avg:97.36ms +step:827/1670 train_time:80521ms step_avg:97.37ms +step:828/1670 train_time:80618ms step_avg:97.37ms +step:829/1670 train_time:80715ms step_avg:97.36ms +step:830/1670 train_time:80811ms step_avg:97.36ms +step:831/1670 train_time:80907ms step_avg:97.36ms +step:832/1670 train_time:81004ms step_avg:97.36ms +step:833/1670 train_time:81101ms step_avg:97.36ms +step:834/1670 train_time:81199ms step_avg:97.36ms +step:835/1670 train_time:81297ms step_avg:97.36ms +step:836/1670 train_time:81394ms step_avg:97.36ms +step:837/1670 train_time:81491ms step_avg:97.36ms +step:838/1670 train_time:81588ms step_avg:97.36ms +step:839/1670 train_time:81685ms step_avg:97.36ms +step:840/1670 
train_time:81783ms step_avg:97.36ms +step:841/1670 train_time:81881ms step_avg:97.36ms +step:842/1670 train_time:81978ms step_avg:97.36ms +step:843/1670 train_time:82075ms step_avg:97.36ms +step:844/1670 train_time:82172ms step_avg:97.36ms +step:845/1670 train_time:82268ms step_avg:97.36ms +step:846/1670 train_time:82365ms step_avg:97.36ms +step:847/1670 train_time:82463ms step_avg:97.36ms +step:848/1670 train_time:82560ms step_avg:97.36ms +step:849/1670 train_time:82658ms step_avg:97.36ms +step:850/1670 train_time:82756ms step_avg:97.36ms +step:851/1670 train_time:83061ms step_avg:97.60ms +step:852/1670 train_time:83183ms step_avg:97.63ms +step:853/1670 train_time:83278ms step_avg:97.63ms +step:854/1670 train_time:83373ms step_avg:97.63ms +step:855/1670 train_time:83469ms step_avg:97.62ms +step:856/1670 train_time:83564ms step_avg:97.62ms +step:857/1670 train_time:83660ms step_avg:97.62ms +step:858/1670 train_time:83756ms step_avg:97.62ms +step:859/1670 train_time:83852ms step_avg:97.62ms +step:860/1670 train_time:83948ms step_avg:97.61ms +step:861/1670 train_time:84050ms step_avg:97.62ms +step:862/1670 train_time:84148ms step_avg:97.62ms +step:863/1670 train_time:84248ms step_avg:97.62ms +step:864/1670 train_time:84345ms step_avg:97.62ms +step:865/1670 train_time:84442ms step_avg:97.62ms +step:866/1670 train_time:84539ms step_avg:97.62ms +step:867/1670 train_time:84635ms step_avg:97.62ms +step:868/1670 train_time:84730ms step_avg:97.62ms +step:869/1670 train_time:84826ms step_avg:97.61ms +step:870/1670 train_time:84924ms step_avg:97.61ms +step:871/1670 train_time:85022ms step_avg:97.61ms +step:872/1670 train_time:85122ms step_avg:97.62ms +step:873/1670 train_time:85221ms step_avg:97.62ms +step:874/1670 train_time:85319ms step_avg:97.62ms +step:875/1670 train_time:85417ms step_avg:97.62ms +step:875/1670 val_loss:3.5194 train_time:85513ms step_avg:97.73ms +step:876/1670 train_time:85533ms step_avg:97.64ms +step:877/1670 train_time:85616ms step_avg:97.62ms +step:878/1670 train_time:85715ms step_avg:97.63ms +step:879/1670 train_time:85812ms step_avg:97.62ms +step:880/1670 train_time:85907ms step_avg:97.62ms +step:881/1670 train_time:86003ms step_avg:97.62ms +step:882/1670 train_time:86099ms step_avg:97.62ms +step:883/1670 train_time:86195ms step_avg:97.62ms +step:884/1670 train_time:86292ms step_avg:97.62ms +step:885/1670 train_time:86387ms step_avg:97.61ms +step:886/1670 train_time:86485ms step_avg:97.61ms +step:887/1670 train_time:86585ms step_avg:97.62ms +step:888/1670 train_time:86683ms step_avg:97.62ms +step:889/1670 train_time:86782ms step_avg:97.62ms +step:890/1670 train_time:86880ms step_avg:97.62ms +step:891/1670 train_time:86976ms step_avg:97.62ms +step:892/1670 train_time:87073ms step_avg:97.61ms +step:893/1670 train_time:87168ms step_avg:97.61ms +step:894/1670 train_time:87264ms step_avg:97.61ms +step:895/1670 train_time:87361ms step_avg:97.61ms +step:896/1670 train_time:87459ms step_avg:97.61ms +step:897/1670 train_time:87557ms step_avg:97.61ms +step:898/1670 train_time:87655ms step_avg:97.61ms +step:899/1670 train_time:87752ms step_avg:97.61ms +step:900/1670 train_time:87849ms step_avg:97.61ms +step:901/1670 train_time:87945ms step_avg:97.61ms +step:902/1670 train_time:88044ms step_avg:97.61ms +step:903/1670 train_time:88141ms step_avg:97.61ms +step:904/1670 train_time:88237ms step_avg:97.61ms +step:905/1670 train_time:88334ms step_avg:97.61ms +step:906/1670 train_time:88431ms step_avg:97.61ms +step:907/1670 train_time:88528ms step_avg:97.61ms +step:908/1670 train_time:88625ms 
step_avg:97.60ms +step:909/1670 train_time:88722ms step_avg:97.60ms +step:910/1670 train_time:88820ms step_avg:97.60ms +step:911/1670 train_time:88918ms step_avg:97.60ms +step:912/1670 train_time:89015ms step_avg:97.60ms +step:913/1670 train_time:89112ms step_avg:97.60ms +step:914/1670 train_time:89208ms step_avg:97.60ms +step:915/1670 train_time:89305ms step_avg:97.60ms +step:916/1670 train_time:89403ms step_avg:97.60ms +step:917/1670 train_time:89501ms step_avg:97.60ms +step:918/1670 train_time:89599ms step_avg:97.60ms +step:919/1670 train_time:89698ms step_avg:97.60ms +step:920/1670 train_time:89795ms step_avg:97.60ms +step:921/1670 train_time:89891ms step_avg:97.60ms +step:922/1670 train_time:89988ms step_avg:97.60ms +step:923/1670 train_time:90085ms step_avg:97.60ms +step:924/1670 train_time:90183ms step_avg:97.60ms +step:925/1670 train_time:90280ms step_avg:97.60ms +step:926/1670 train_time:90377ms step_avg:97.60ms +step:927/1670 train_time:90474ms step_avg:97.60ms +step:928/1670 train_time:90572ms step_avg:97.60ms +step:929/1670 train_time:90669ms step_avg:97.60ms +step:930/1670 train_time:90767ms step_avg:97.60ms +step:931/1670 train_time:90864ms step_avg:97.60ms +step:932/1670 train_time:90961ms step_avg:97.60ms +step:933/1670 train_time:91058ms step_avg:97.60ms +step:934/1670 train_time:91155ms step_avg:97.60ms +step:935/1670 train_time:91251ms step_avg:97.60ms +step:936/1670 train_time:91348ms step_avg:97.59ms +step:937/1670 train_time:91445ms step_avg:97.59ms +step:938/1670 train_time:91544ms step_avg:97.59ms +step:939/1670 train_time:91642ms step_avg:97.59ms +step:940/1670 train_time:91739ms step_avg:97.59ms +step:941/1670 train_time:91837ms step_avg:97.59ms +step:942/1670 train_time:91934ms step_avg:97.59ms +step:943/1670 train_time:92030ms step_avg:97.59ms +step:944/1670 train_time:92126ms step_avg:97.59ms +step:945/1670 train_time:92223ms step_avg:97.59ms +step:946/1670 train_time:92320ms step_avg:97.59ms +step:947/1670 train_time:92418ms step_avg:97.59ms +step:948/1670 train_time:92515ms step_avg:97.59ms +step:949/1670 train_time:92613ms step_avg:97.59ms +step:950/1670 train_time:92709ms step_avg:97.59ms +step:951/1670 train_time:92806ms step_avg:97.59ms +step:952/1670 train_time:92903ms step_avg:97.59ms +step:953/1670 train_time:93001ms step_avg:97.59ms +step:954/1670 train_time:93098ms step_avg:97.59ms +step:955/1670 train_time:93195ms step_avg:97.59ms +step:956/1670 train_time:93291ms step_avg:97.59ms +step:957/1670 train_time:93388ms step_avg:97.58ms +step:958/1670 train_time:93484ms step_avg:97.58ms +step:959/1670 train_time:93581ms step_avg:97.58ms +step:960/1670 train_time:93680ms step_avg:97.58ms +step:961/1670 train_time:93779ms step_avg:97.58ms +step:962/1670 train_time:93876ms step_avg:97.58ms +step:963/1670 train_time:93973ms step_avg:97.58ms +step:964/1670 train_time:94070ms step_avg:97.58ms +step:965/1670 train_time:94167ms step_avg:97.58ms +step:966/1670 train_time:94264ms step_avg:97.58ms +step:967/1670 train_time:94360ms step_avg:97.58ms +step:968/1670 train_time:94458ms step_avg:97.58ms +step:969/1670 train_time:94555ms step_avg:97.58ms +step:970/1670 train_time:94652ms step_avg:97.58ms +step:971/1670 train_time:94748ms step_avg:97.58ms +step:972/1670 train_time:94846ms step_avg:97.58ms +step:973/1670 train_time:94944ms step_avg:97.58ms +step:974/1670 train_time:95042ms step_avg:97.58ms +step:975/1670 train_time:95139ms step_avg:97.58ms +step:976/1670 train_time:95236ms step_avg:97.58ms +step:977/1670 train_time:95333ms step_avg:97.58ms +step:978/1670 
train_time:95430ms step_avg:97.58ms +step:979/1670 train_time:95527ms step_avg:97.58ms +step:980/1670 train_time:95624ms step_avg:97.58ms +step:981/1670 train_time:95722ms step_avg:97.58ms +step:982/1670 train_time:95819ms step_avg:97.58ms +step:983/1670 train_time:95917ms step_avg:97.58ms +step:984/1670 train_time:96014ms step_avg:97.58ms +step:985/1670 train_time:96111ms step_avg:97.57ms +step:986/1670 train_time:96207ms step_avg:97.57ms +step:987/1670 train_time:96304ms step_avg:97.57ms +step:988/1670 train_time:96402ms step_avg:97.57ms +step:989/1670 train_time:96501ms step_avg:97.57ms +step:990/1670 train_time:96599ms step_avg:97.57ms +step:991/1670 train_time:96696ms step_avg:97.57ms +step:992/1670 train_time:96793ms step_avg:97.57ms +step:993/1670 train_time:96891ms step_avg:97.57ms +step:994/1670 train_time:96987ms step_avg:97.57ms +step:995/1670 train_time:97085ms step_avg:97.57ms +step:996/1670 train_time:97182ms step_avg:97.57ms +step:997/1670 train_time:97279ms step_avg:97.57ms +step:998/1670 train_time:97377ms step_avg:97.57ms +step:999/1670 train_time:97473ms step_avg:97.57ms +step:1000/1670 train_time:97571ms step_avg:97.57ms +step:1000/1670 val_loss:3.4779 train_time:97667ms step_avg:97.67ms +step:1001/1670 train_time:97688ms step_avg:97.59ms +step:1002/1670 train_time:97768ms step_avg:97.57ms +step:1003/1670 train_time:97868ms step_avg:97.58ms +step:1004/1670 train_time:97964ms step_avg:97.57ms +step:1005/1670 train_time:98060ms step_avg:97.57ms +step:1006/1670 train_time:98156ms step_avg:97.57ms +step:1007/1670 train_time:98252ms step_avg:97.57ms +step:1008/1670 train_time:98348ms step_avg:97.57ms +step:1009/1670 train_time:98445ms step_avg:97.57ms +step:1010/1670 train_time:98541ms step_avg:97.57ms +step:1011/1670 train_time:98640ms step_avg:97.57ms +step:1012/1670 train_time:98740ms step_avg:97.57ms +step:1013/1670 train_time:98841ms step_avg:97.57ms +step:1014/1670 train_time:98939ms step_avg:97.57ms +step:1015/1670 train_time:99036ms step_avg:97.57ms +step:1016/1670 train_time:99133ms step_avg:97.57ms +step:1017/1670 train_time:99229ms step_avg:97.57ms +step:1018/1670 train_time:99325ms step_avg:97.57ms +step:1019/1670 train_time:99420ms step_avg:97.57ms +step:1020/1670 train_time:99517ms step_avg:97.57ms +step:1021/1670 train_time:99616ms step_avg:97.57ms +step:1022/1670 train_time:99714ms step_avg:97.57ms +step:1023/1670 train_time:99813ms step_avg:97.57ms +step:1024/1670 train_time:99912ms step_avg:97.57ms +step:1025/1670 train_time:100010ms step_avg:97.57ms +step:1026/1670 train_time:100106ms step_avg:97.57ms +step:1027/1670 train_time:100202ms step_avg:97.57ms +step:1028/1670 train_time:100299ms step_avg:97.57ms +step:1029/1670 train_time:100395ms step_avg:97.57ms +step:1030/1670 train_time:100492ms step_avg:97.57ms +step:1031/1670 train_time:100589ms step_avg:97.56ms +step:1032/1670 train_time:100686ms step_avg:97.56ms +step:1033/1670 train_time:100784ms step_avg:97.56ms +step:1034/1670 train_time:100882ms step_avg:97.56ms +step:1035/1670 train_time:100979ms step_avg:97.56ms +step:1036/1670 train_time:101077ms step_avg:97.56ms +step:1037/1670 train_time:101175ms step_avg:97.56ms +step:1038/1670 train_time:101271ms step_avg:97.56ms +step:1039/1670 train_time:101368ms step_avg:97.56ms +step:1040/1670 train_time:101463ms step_avg:97.56ms +step:1041/1670 train_time:101561ms step_avg:97.56ms +step:1042/1670 train_time:101658ms step_avg:97.56ms +step:1043/1670 train_time:101757ms step_avg:97.56ms +step:1044/1670 train_time:101855ms step_avg:97.56ms +step:1045/1670 
train_time:101953ms step_avg:97.56ms +step:1046/1670 train_time:102050ms step_avg:97.56ms +step:1047/1670 train_time:102146ms step_avg:97.56ms +step:1048/1670 train_time:102243ms step_avg:97.56ms +step:1049/1670 train_time:102340ms step_avg:97.56ms +step:1050/1670 train_time:102438ms step_avg:97.56ms +step:1051/1670 train_time:102535ms step_avg:97.56ms +step:1052/1670 train_time:102633ms step_avg:97.56ms +step:1053/1670 train_time:102731ms step_avg:97.56ms +step:1054/1670 train_time:102829ms step_avg:97.56ms +step:1055/1670 train_time:102925ms step_avg:97.56ms +step:1056/1670 train_time:103023ms step_avg:97.56ms +step:1057/1670 train_time:103120ms step_avg:97.56ms +step:1058/1670 train_time:103217ms step_avg:97.56ms +step:1059/1670 train_time:103315ms step_avg:97.56ms +step:1060/1670 train_time:103413ms step_avg:97.56ms +step:1061/1670 train_time:103510ms step_avg:97.56ms +step:1062/1670 train_time:103781ms step_avg:97.72ms +step:1063/1670 train_time:103856ms step_avg:97.70ms +step:1064/1670 train_time:103952ms step_avg:97.70ms +step:1065/1670 train_time:104048ms step_avg:97.70ms +step:1066/1670 train_time:104144ms step_avg:97.70ms +step:1067/1670 train_time:104239ms step_avg:97.69ms +step:1068/1670 train_time:104336ms step_avg:97.69ms +step:1069/1670 train_time:104432ms step_avg:97.69ms +step:1070/1670 train_time:104528ms step_avg:97.69ms +step:1071/1670 train_time:104625ms step_avg:97.69ms +step:1072/1670 train_time:104726ms step_avg:97.69ms +step:1073/1670 train_time:104825ms step_avg:97.69ms +step:1074/1670 train_time:104922ms step_avg:97.69ms +step:1075/1670 train_time:105019ms step_avg:97.69ms +step:1076/1670 train_time:105116ms step_avg:97.69ms +step:1077/1670 train_time:105212ms step_avg:97.69ms +step:1078/1670 train_time:105307ms step_avg:97.69ms +step:1079/1670 train_time:105403ms step_avg:97.69ms +step:1080/1670 train_time:105498ms step_avg:97.68ms +step:1081/1670 train_time:105597ms step_avg:97.68ms +step:1082/1670 train_time:105697ms step_avg:97.69ms +step:1083/1670 train_time:105798ms step_avg:97.69ms +step:1084/1670 train_time:105896ms step_avg:97.69ms +step:1085/1670 train_time:105993ms step_avg:97.69ms +step:1086/1670 train_time:106090ms step_avg:97.69ms +step:1087/1670 train_time:106187ms step_avg:97.69ms +step:1088/1670 train_time:106283ms step_avg:97.69ms +step:1089/1670 train_time:106378ms step_avg:97.68ms +step:1090/1670 train_time:106475ms step_avg:97.68ms +step:1091/1670 train_time:106572ms step_avg:97.68ms +step:1092/1670 train_time:106670ms step_avg:97.68ms +step:1093/1670 train_time:106768ms step_avg:97.68ms +step:1094/1670 train_time:106865ms step_avg:97.68ms +step:1095/1670 train_time:106962ms step_avg:97.68ms +step:1096/1670 train_time:107061ms step_avg:97.68ms +step:1097/1670 train_time:107158ms step_avg:97.68ms +step:1098/1670 train_time:107255ms step_avg:97.68ms +step:1099/1670 train_time:107351ms step_avg:97.68ms +step:1100/1670 train_time:107448ms step_avg:97.68ms +step:1101/1670 train_time:107544ms step_avg:97.68ms +step:1102/1670 train_time:107641ms step_avg:97.68ms +step:1103/1670 train_time:107739ms step_avg:97.68ms +step:1104/1670 train_time:107838ms step_avg:97.68ms +step:1105/1670 train_time:107935ms step_avg:97.68ms +step:1106/1670 train_time:108033ms step_avg:97.68ms +step:1107/1670 train_time:108131ms step_avg:97.68ms +step:1108/1670 train_time:108228ms step_avg:97.68ms +step:1109/1670 train_time:108324ms step_avg:97.68ms +step:1110/1670 train_time:108421ms step_avg:97.68ms +step:1111/1670 train_time:108517ms step_avg:97.68ms +step:1112/1670 
train_time:108615ms step_avg:97.68ms +step:1113/1670 train_time:108712ms step_avg:97.67ms +step:1114/1670 train_time:108810ms step_avg:97.67ms +step:1115/1670 train_time:108907ms step_avg:97.67ms +step:1116/1670 train_time:109005ms step_avg:97.68ms +step:1117/1670 train_time:109103ms step_avg:97.68ms +step:1118/1670 train_time:109201ms step_avg:97.68ms +step:1119/1670 train_time:109298ms step_avg:97.67ms +step:1120/1670 train_time:109396ms step_avg:97.68ms +step:1121/1670 train_time:109494ms step_avg:97.68ms +step:1122/1670 train_time:109593ms step_avg:97.68ms +step:1123/1670 train_time:109690ms step_avg:97.68ms +step:1124/1670 train_time:109788ms step_avg:97.68ms +step:1125/1670 train_time:109885ms step_avg:97.68ms +step:1125/1670 val_loss:3.4237 train_time:109982ms step_avg:97.76ms +step:1126/1670 train_time:110003ms step_avg:97.69ms +step:1127/1670 train_time:110090ms step_avg:97.68ms +step:1128/1670 train_time:110187ms step_avg:97.68ms +step:1129/1670 train_time:110284ms step_avg:97.68ms +step:1130/1670 train_time:110380ms step_avg:97.68ms +step:1131/1670 train_time:110476ms step_avg:97.68ms +step:1132/1670 train_time:110574ms step_avg:97.68ms +step:1133/1670 train_time:110670ms step_avg:97.68ms +step:1134/1670 train_time:110767ms step_avg:97.68ms +step:1135/1670 train_time:110864ms step_avg:97.68ms +step:1136/1670 train_time:110964ms step_avg:97.68ms +step:1137/1670 train_time:111064ms step_avg:97.68ms +step:1138/1670 train_time:111163ms step_avg:97.68ms +step:1139/1670 train_time:111261ms step_avg:97.68ms +step:1140/1670 train_time:111358ms step_avg:97.68ms +step:1141/1670 train_time:111455ms step_avg:97.68ms +step:1142/1670 train_time:111552ms step_avg:97.68ms +step:1143/1670 train_time:111649ms step_avg:97.68ms +step:1144/1670 train_time:111745ms step_avg:97.68ms +step:1145/1670 train_time:111841ms step_avg:97.68ms +step:1146/1670 train_time:111940ms step_avg:97.68ms +step:1147/1670 train_time:112040ms step_avg:97.68ms +step:1148/1670 train_time:112139ms step_avg:97.68ms +step:1149/1670 train_time:112238ms step_avg:97.68ms +step:1150/1670 train_time:112336ms step_avg:97.68ms +step:1151/1670 train_time:112434ms step_avg:97.68ms +step:1152/1670 train_time:112531ms step_avg:97.68ms +step:1153/1670 train_time:112628ms step_avg:97.68ms +step:1154/1670 train_time:112724ms step_avg:97.68ms +step:1155/1670 train_time:112821ms step_avg:97.68ms +step:1156/1670 train_time:112918ms step_avg:97.68ms +step:1157/1670 train_time:113017ms step_avg:97.68ms +step:1158/1670 train_time:113117ms step_avg:97.68ms +step:1159/1670 train_time:113217ms step_avg:97.68ms +step:1160/1670 train_time:113316ms step_avg:97.69ms +step:1161/1670 train_time:113415ms step_avg:97.69ms +step:1162/1670 train_time:113513ms step_avg:97.69ms +step:1163/1670 train_time:113609ms step_avg:97.69ms +step:1164/1670 train_time:113707ms step_avg:97.69ms +step:1165/1670 train_time:113804ms step_avg:97.69ms +step:1166/1670 train_time:113901ms step_avg:97.69ms +step:1167/1670 train_time:113999ms step_avg:97.69ms +step:1168/1670 train_time:114097ms step_avg:97.69ms +step:1169/1670 train_time:114195ms step_avg:97.69ms +step:1170/1670 train_time:114294ms step_avg:97.69ms +step:1171/1670 train_time:114393ms step_avg:97.69ms +step:1172/1670 train_time:114490ms step_avg:97.69ms +step:1173/1670 train_time:114587ms step_avg:97.69ms +step:1174/1670 train_time:114683ms step_avg:97.69ms +step:1175/1670 train_time:114780ms step_avg:97.69ms +step:1176/1670 train_time:114878ms step_avg:97.69ms +step:1177/1670 train_time:114977ms step_avg:97.69ms 
+step:1178/1670 train_time:115077ms step_avg:97.69ms +step:1179/1670 train_time:115177ms step_avg:97.69ms +step:1180/1670 train_time:115275ms step_avg:97.69ms +step:1181/1670 train_time:115374ms step_avg:97.69ms +step:1182/1670 train_time:115471ms step_avg:97.69ms +step:1183/1670 train_time:115569ms step_avg:97.69ms +step:1184/1670 train_time:115667ms step_avg:97.69ms +step:1185/1670 train_time:115765ms step_avg:97.69ms +step:1186/1670 train_time:115862ms step_avg:97.69ms +step:1187/1670 train_time:115959ms step_avg:97.69ms +step:1188/1670 train_time:116057ms step_avg:97.69ms +step:1189/1670 train_time:116155ms step_avg:97.69ms +step:1190/1670 train_time:116254ms step_avg:97.69ms +step:1191/1670 train_time:116353ms step_avg:97.69ms +step:1192/1670 train_time:116452ms step_avg:97.69ms +step:1193/1670 train_time:116550ms step_avg:97.69ms +step:1194/1670 train_time:116648ms step_avg:97.70ms +step:1195/1670 train_time:116745ms step_avg:97.69ms +step:1196/1670 train_time:116842ms step_avg:97.69ms +step:1197/1670 train_time:116940ms step_avg:97.69ms +step:1198/1670 train_time:117038ms step_avg:97.69ms +step:1199/1670 train_time:117135ms step_avg:97.69ms +step:1200/1670 train_time:117234ms step_avg:97.69ms +step:1201/1670 train_time:117331ms step_avg:97.69ms +step:1202/1670 train_time:117430ms step_avg:97.70ms +step:1203/1670 train_time:117527ms step_avg:97.70ms +step:1204/1670 train_time:117624ms step_avg:97.69ms +step:1205/1670 train_time:117722ms step_avg:97.69ms +step:1206/1670 train_time:117819ms step_avg:97.69ms +step:1207/1670 train_time:117917ms step_avg:97.69ms +step:1208/1670 train_time:118016ms step_avg:97.70ms +step:1209/1670 train_time:118113ms step_avg:97.70ms +step:1210/1670 train_time:118211ms step_avg:97.70ms +step:1211/1670 train_time:118309ms step_avg:97.70ms +step:1212/1670 train_time:118406ms step_avg:97.70ms +step:1213/1670 train_time:118504ms step_avg:97.70ms +step:1214/1670 train_time:118601ms step_avg:97.69ms +step:1215/1670 train_time:118699ms step_avg:97.69ms +step:1216/1670 train_time:118797ms step_avg:97.69ms +step:1217/1670 train_time:118894ms step_avg:97.69ms +step:1218/1670 train_time:118993ms step_avg:97.70ms +step:1219/1670 train_time:119091ms step_avg:97.70ms +step:1220/1670 train_time:119190ms step_avg:97.70ms +step:1221/1670 train_time:119288ms step_avg:97.70ms +step:1222/1670 train_time:119385ms step_avg:97.70ms +step:1223/1670 train_time:119482ms step_avg:97.70ms +step:1224/1670 train_time:119580ms step_avg:97.70ms +step:1225/1670 train_time:119677ms step_avg:97.70ms +step:1226/1670 train_time:119774ms step_avg:97.70ms +step:1227/1670 train_time:119873ms step_avg:97.70ms +step:1228/1670 train_time:119971ms step_avg:97.70ms +step:1229/1670 train_time:120068ms step_avg:97.70ms +step:1230/1670 train_time:120165ms step_avg:97.70ms +step:1231/1670 train_time:120263ms step_avg:97.70ms +step:1232/1670 train_time:120360ms step_avg:97.69ms +step:1233/1670 train_time:120458ms step_avg:97.69ms +step:1234/1670 train_time:120557ms step_avg:97.70ms +step:1235/1670 train_time:120655ms step_avg:97.70ms +step:1236/1670 train_time:120754ms step_avg:97.70ms +step:1237/1670 train_time:120852ms step_avg:97.70ms +step:1238/1670 train_time:120949ms step_avg:97.70ms +step:1239/1670 train_time:121047ms step_avg:97.70ms +step:1240/1670 train_time:121145ms step_avg:97.70ms +step:1241/1670 train_time:121243ms step_avg:97.70ms +step:1242/1670 train_time:121340ms step_avg:97.70ms +step:1243/1670 train_time:121437ms step_avg:97.70ms +step:1244/1670 train_time:121536ms step_avg:97.70ms 
+step:1245/1670 train_time:121635ms step_avg:97.70ms +step:1246/1670 train_time:121732ms step_avg:97.70ms +step:1247/1670 train_time:121830ms step_avg:97.70ms +step:1248/1670 train_time:121927ms step_avg:97.70ms +step:1249/1670 train_time:122024ms step_avg:97.70ms +step:1250/1670 train_time:122121ms step_avg:97.70ms +step:1250/1670 val_loss:3.3812 train_time:122219ms step_avg:97.77ms +step:1251/1670 train_time:122239ms step_avg:97.71ms +step:1252/1670 train_time:122323ms step_avg:97.70ms +step:1253/1670 train_time:122423ms step_avg:97.70ms +step:1254/1670 train_time:122521ms step_avg:97.70ms +step:1255/1670 train_time:122618ms step_avg:97.70ms +step:1256/1670 train_time:122714ms step_avg:97.70ms +step:1257/1670 train_time:122811ms step_avg:97.70ms +step:1258/1670 train_time:122908ms step_avg:97.70ms +step:1259/1670 train_time:123005ms step_avg:97.70ms +step:1260/1670 train_time:123102ms step_avg:97.70ms +step:1261/1670 train_time:123201ms step_avg:97.70ms +step:1262/1670 train_time:123301ms step_avg:97.70ms +step:1263/1670 train_time:123400ms step_avg:97.70ms +step:1264/1670 train_time:123498ms step_avg:97.70ms +step:1265/1670 train_time:123596ms step_avg:97.70ms +step:1266/1670 train_time:123692ms step_avg:97.70ms +step:1267/1670 train_time:123789ms step_avg:97.70ms +step:1268/1670 train_time:123886ms step_avg:97.70ms +step:1269/1670 train_time:123984ms step_avg:97.70ms +step:1270/1670 train_time:124082ms step_avg:97.70ms +step:1271/1670 train_time:124179ms step_avg:97.70ms +step:1272/1670 train_time:124278ms step_avg:97.70ms +step:1273/1670 train_time:124376ms step_avg:97.70ms +step:1274/1670 train_time:124644ms step_avg:97.84ms +step:1275/1670 train_time:124845ms step_avg:97.92ms +step:1276/1670 train_time:124940ms step_avg:97.92ms +step:1277/1670 train_time:125037ms step_avg:97.91ms +step:1278/1670 train_time:125133ms step_avg:97.91ms +step:1279/1670 train_time:125229ms step_avg:97.91ms +step:1280/1670 train_time:125327ms step_avg:97.91ms +step:1281/1670 train_time:125424ms step_avg:97.91ms +step:1282/1670 train_time:125521ms step_avg:97.91ms +step:1283/1670 train_time:125617ms step_avg:97.91ms +step:1284/1670 train_time:125719ms step_avg:97.91ms +step:1285/1670 train_time:125818ms step_avg:97.91ms +step:1286/1670 train_time:125916ms step_avg:97.91ms +step:1287/1670 train_time:126013ms step_avg:97.91ms +step:1288/1670 train_time:126110ms step_avg:97.91ms +step:1289/1670 train_time:126209ms step_avg:97.91ms +step:1290/1670 train_time:126306ms step_avg:97.91ms +step:1291/1670 train_time:126403ms step_avg:97.91ms +step:1292/1670 train_time:126500ms step_avg:97.91ms +step:1293/1670 train_time:126596ms step_avg:97.91ms +step:1294/1670 train_time:126694ms step_avg:97.91ms +step:1295/1670 train_time:126794ms step_avg:97.91ms +step:1296/1670 train_time:126892ms step_avg:97.91ms +step:1297/1670 train_time:126992ms step_avg:97.91ms +step:1298/1670 train_time:127089ms step_avg:97.91ms +step:1299/1670 train_time:127188ms step_avg:97.91ms +step:1300/1670 train_time:127285ms step_avg:97.91ms +step:1301/1670 train_time:127383ms step_avg:97.91ms +step:1302/1670 train_time:127479ms step_avg:97.91ms +step:1303/1670 train_time:127576ms step_avg:97.91ms +step:1304/1670 train_time:127674ms step_avg:97.91ms +step:1305/1670 train_time:127772ms step_avg:97.91ms +step:1306/1670 train_time:127871ms step_avg:97.91ms +step:1307/1670 train_time:127970ms step_avg:97.91ms +step:1308/1670 train_time:128069ms step_avg:97.91ms +step:1309/1670 train_time:128167ms step_avg:97.91ms +step:1310/1670 train_time:128265ms 
step_avg:97.91ms +step:1311/1670 train_time:128363ms step_avg:97.91ms +step:1312/1670 train_time:128460ms step_avg:97.91ms +step:1313/1670 train_time:128557ms step_avg:97.91ms +step:1314/1670 train_time:128655ms step_avg:97.91ms +step:1315/1670 train_time:128753ms step_avg:97.91ms +step:1316/1670 train_time:128851ms step_avg:97.91ms +step:1317/1670 train_time:128949ms step_avg:97.91ms +step:1318/1670 train_time:129049ms step_avg:97.91ms +step:1319/1670 train_time:129148ms step_avg:97.91ms +step:1320/1670 train_time:129246ms step_avg:97.91ms +step:1321/1670 train_time:129345ms step_avg:97.91ms +step:1322/1670 train_time:129443ms step_avg:97.91ms +step:1323/1670 train_time:129541ms step_avg:97.91ms +step:1324/1670 train_time:129638ms step_avg:97.91ms +step:1325/1670 train_time:129736ms step_avg:97.91ms +step:1326/1670 train_time:129834ms step_avg:97.91ms +step:1327/1670 train_time:129933ms step_avg:97.91ms +step:1328/1670 train_time:130030ms step_avg:97.91ms +step:1329/1670 train_time:130127ms step_avg:97.91ms +step:1330/1670 train_time:130225ms step_avg:97.91ms +step:1331/1670 train_time:130323ms step_avg:97.91ms +step:1332/1670 train_time:130422ms step_avg:97.91ms +step:1333/1670 train_time:130519ms step_avg:97.91ms +step:1334/1670 train_time:130617ms step_avg:97.91ms +step:1335/1670 train_time:130714ms step_avg:97.91ms +step:1336/1670 train_time:130812ms step_avg:97.91ms +step:1337/1670 train_time:130910ms step_avg:97.91ms +step:1338/1670 train_time:131008ms step_avg:97.91ms +step:1339/1670 train_time:131106ms step_avg:97.91ms +step:1340/1670 train_time:131204ms step_avg:97.91ms +step:1341/1670 train_time:131302ms step_avg:97.91ms +step:1342/1670 train_time:131401ms step_avg:97.91ms +step:1343/1670 train_time:131498ms step_avg:97.91ms +step:1344/1670 train_time:131594ms step_avg:97.91ms +step:1345/1670 train_time:131692ms step_avg:97.91ms +step:1346/1670 train_time:131790ms step_avg:97.91ms +step:1347/1670 train_time:131889ms step_avg:97.91ms +step:1348/1670 train_time:131988ms step_avg:97.91ms +step:1349/1670 train_time:132085ms step_avg:97.91ms +step:1350/1670 train_time:132183ms step_avg:97.91ms +step:1351/1670 train_time:132281ms step_avg:97.91ms +step:1352/1670 train_time:132378ms step_avg:97.91ms +step:1353/1670 train_time:132476ms step_avg:97.91ms +step:1354/1670 train_time:132573ms step_avg:97.91ms +step:1355/1670 train_time:132670ms step_avg:97.91ms +step:1356/1670 train_time:132768ms step_avg:97.91ms +step:1357/1670 train_time:132867ms step_avg:97.91ms +step:1358/1670 train_time:132966ms step_avg:97.91ms +step:1359/1670 train_time:133063ms step_avg:97.91ms +step:1360/1670 train_time:133161ms step_avg:97.91ms +step:1361/1670 train_time:133258ms step_avg:97.91ms +step:1362/1670 train_time:133356ms step_avg:97.91ms +step:1363/1670 train_time:133453ms step_avg:97.91ms +step:1364/1670 train_time:133550ms step_avg:97.91ms +step:1365/1670 train_time:133649ms step_avg:97.91ms +step:1366/1670 train_time:133747ms step_avg:97.91ms +step:1367/1670 train_time:133846ms step_avg:97.91ms +step:1368/1670 train_time:133943ms step_avg:97.91ms +step:1369/1670 train_time:134043ms step_avg:97.91ms +step:1370/1670 train_time:134141ms step_avg:97.91ms +step:1371/1670 train_time:134239ms step_avg:97.91ms +step:1372/1670 train_time:134336ms step_avg:97.91ms +step:1373/1670 train_time:134434ms step_avg:97.91ms +step:1374/1670 train_time:134531ms step_avg:97.91ms +step:1375/1670 train_time:134630ms step_avg:97.91ms +step:1375/1670 val_loss:3.3441 train_time:134727ms step_avg:97.98ms +step:1376/1670 
train_time:134749ms step_avg:97.93ms +step:1377/1670 train_time:134833ms step_avg:97.92ms +step:1378/1670 train_time:134933ms step_avg:97.92ms +step:1379/1670 train_time:135030ms step_avg:97.92ms +step:1380/1670 train_time:135127ms step_avg:97.92ms +step:1381/1670 train_time:135223ms step_avg:97.92ms +step:1382/1670 train_time:135320ms step_avg:97.92ms +step:1383/1670 train_time:135416ms step_avg:97.91ms +step:1384/1670 train_time:135514ms step_avg:97.91ms +step:1385/1670 train_time:135612ms step_avg:97.91ms +step:1386/1670 train_time:135712ms step_avg:97.92ms +step:1387/1670 train_time:135813ms step_avg:97.92ms +step:1388/1670 train_time:135913ms step_avg:97.92ms +step:1389/1670 train_time:136011ms step_avg:97.92ms +step:1390/1670 train_time:136108ms step_avg:97.92ms +step:1391/1670 train_time:136205ms step_avg:97.92ms +step:1392/1670 train_time:136301ms step_avg:97.92ms +step:1393/1670 train_time:136398ms step_avg:97.92ms +step:1394/1670 train_time:136495ms step_avg:97.92ms +step:1395/1670 train_time:136592ms step_avg:97.92ms +step:1396/1670 train_time:136691ms step_avg:97.92ms +step:1397/1670 train_time:136790ms step_avg:97.92ms +step:1398/1670 train_time:136890ms step_avg:97.92ms +step:1399/1670 train_time:136989ms step_avg:97.92ms +step:1400/1670 train_time:137087ms step_avg:97.92ms +step:1401/1670 train_time:137186ms step_avg:97.92ms +step:1402/1670 train_time:137283ms step_avg:97.92ms +step:1403/1670 train_time:137380ms step_avg:97.92ms +step:1404/1670 train_time:137476ms step_avg:97.92ms +step:1405/1670 train_time:137573ms step_avg:97.92ms +step:1406/1670 train_time:137671ms step_avg:97.92ms +step:1407/1670 train_time:137770ms step_avg:97.92ms +step:1408/1670 train_time:137869ms step_avg:97.92ms +step:1409/1670 train_time:137967ms step_avg:97.92ms +step:1410/1670 train_time:138064ms step_avg:97.92ms +step:1411/1670 train_time:138162ms step_avg:97.92ms +step:1412/1670 train_time:138260ms step_avg:97.92ms +step:1413/1670 train_time:138358ms step_avg:97.92ms +step:1414/1670 train_time:138455ms step_avg:97.92ms +step:1415/1670 train_time:138552ms step_avg:97.92ms +step:1416/1670 train_time:138649ms step_avg:97.92ms +step:1417/1670 train_time:138747ms step_avg:97.92ms +step:1418/1670 train_time:138844ms step_avg:97.92ms +step:1419/1670 train_time:138942ms step_avg:97.92ms +step:1420/1670 train_time:139040ms step_avg:97.92ms +step:1421/1670 train_time:139139ms step_avg:97.92ms +step:1422/1670 train_time:139238ms step_avg:97.92ms +step:1423/1670 train_time:139336ms step_avg:97.92ms +step:1424/1670 train_time:139433ms step_avg:97.92ms +step:1425/1670 train_time:139530ms step_avg:97.92ms +step:1426/1670 train_time:139628ms step_avg:97.92ms +step:1427/1670 train_time:139725ms step_avg:97.92ms +step:1428/1670 train_time:139823ms step_avg:97.92ms +step:1429/1670 train_time:139920ms step_avg:97.91ms +step:1430/1670 train_time:140018ms step_avg:97.91ms +step:1431/1670 train_time:140117ms step_avg:97.92ms +step:1432/1670 train_time:140215ms step_avg:97.92ms +step:1433/1670 train_time:140314ms step_avg:97.92ms +step:1434/1670 train_time:140411ms step_avg:97.92ms +step:1435/1670 train_time:140509ms step_avg:97.92ms +step:1436/1670 train_time:140605ms step_avg:97.91ms +step:1437/1670 train_time:140703ms step_avg:97.91ms +step:1438/1670 train_time:140800ms step_avg:97.91ms +step:1439/1670 train_time:140899ms step_avg:97.91ms +step:1440/1670 train_time:140997ms step_avg:97.91ms +step:1441/1670 train_time:141096ms step_avg:97.92ms +step:1442/1670 train_time:141195ms step_avg:97.92ms +step:1443/1670 
train_time:141292ms step_avg:97.92ms +step:1444/1670 train_time:141391ms step_avg:97.92ms +step:1445/1670 train_time:141488ms step_avg:97.92ms +step:1446/1670 train_time:141586ms step_avg:97.92ms +step:1447/1670 train_time:141683ms step_avg:97.92ms +step:1448/1670 train_time:141781ms step_avg:97.92ms +step:1449/1670 train_time:141879ms step_avg:97.91ms +step:1450/1670 train_time:141976ms step_avg:97.91ms +step:1451/1670 train_time:142073ms step_avg:97.91ms +step:1452/1670 train_time:142172ms step_avg:97.91ms +step:1453/1670 train_time:142270ms step_avg:97.91ms +step:1454/1670 train_time:142368ms step_avg:97.91ms +step:1455/1670 train_time:142464ms step_avg:97.91ms +step:1456/1670 train_time:142562ms step_avg:97.91ms +step:1457/1670 train_time:142659ms step_avg:97.91ms +step:1458/1670 train_time:142757ms step_avg:97.91ms +step:1459/1670 train_time:142855ms step_avg:97.91ms +step:1460/1670 train_time:142954ms step_avg:97.91ms +step:1461/1670 train_time:143052ms step_avg:97.91ms +step:1462/1670 train_time:143151ms step_avg:97.91ms +step:1463/1670 train_time:143249ms step_avg:97.91ms +step:1464/1670 train_time:143347ms step_avg:97.91ms +step:1465/1670 train_time:143444ms step_avg:97.91ms +step:1466/1670 train_time:143541ms step_avg:97.91ms +step:1467/1670 train_time:143639ms step_avg:97.91ms +step:1468/1670 train_time:143737ms step_avg:97.91ms +step:1469/1670 train_time:143837ms step_avg:97.91ms +step:1470/1670 train_time:143935ms step_avg:97.91ms +step:1471/1670 train_time:144033ms step_avg:97.92ms +step:1472/1670 train_time:144133ms step_avg:97.92ms +step:1473/1670 train_time:144231ms step_avg:97.92ms +step:1474/1670 train_time:144329ms step_avg:97.92ms +step:1475/1670 train_time:144427ms step_avg:97.92ms +step:1476/1670 train_time:144525ms step_avg:97.92ms +step:1477/1670 train_time:144623ms step_avg:97.92ms +step:1478/1670 train_time:144720ms step_avg:97.92ms +step:1479/1670 train_time:144817ms step_avg:97.92ms +step:1480/1670 train_time:144915ms step_avg:97.92ms +step:1481/1670 train_time:145012ms step_avg:97.92ms +step:1482/1670 train_time:145110ms step_avg:97.91ms +step:1483/1670 train_time:145208ms step_avg:97.91ms +step:1484/1670 train_time:145305ms step_avg:97.91ms +step:1485/1670 train_time:145583ms step_avg:98.04ms +step:1486/1670 train_time:145760ms step_avg:98.09ms +step:1487/1670 train_time:145855ms step_avg:98.09ms +step:1488/1670 train_time:145951ms step_avg:98.09ms +step:1489/1670 train_time:146048ms step_avg:98.08ms +step:1490/1670 train_time:146144ms step_avg:98.08ms +step:1491/1670 train_time:146241ms step_avg:98.08ms +step:1492/1670 train_time:146337ms step_avg:98.08ms +step:1493/1670 train_time:146434ms step_avg:98.08ms +step:1494/1670 train_time:146531ms step_avg:98.08ms +step:1495/1670 train_time:146632ms step_avg:98.08ms +step:1496/1670 train_time:146734ms step_avg:98.08ms +step:1497/1670 train_time:146835ms step_avg:98.09ms +step:1498/1670 train_time:146934ms step_avg:98.09ms +step:1499/1670 train_time:147033ms step_avg:98.09ms +step:1500/1670 train_time:147132ms step_avg:98.09ms +step:1500/1670 val_loss:3.3122 train_time:147230ms step_avg:98.15ms +step:1501/1670 train_time:147251ms step_avg:98.10ms +step:1502/1670 train_time:147334ms step_avg:98.09ms +step:1503/1670 train_time:147433ms step_avg:98.09ms +step:1504/1670 train_time:147530ms step_avg:98.09ms +step:1505/1670 train_time:147628ms step_avg:98.09ms +step:1506/1670 train_time:147725ms step_avg:98.09ms +step:1507/1670 train_time:147822ms step_avg:98.09ms +step:1508/1670 train_time:147920ms step_avg:98.09ms 
+step:1509/1670 train_time:148017ms step_avg:98.09ms +step:1510/1670 train_time:148114ms step_avg:98.09ms +step:1511/1670 train_time:148213ms step_avg:98.09ms +step:1512/1670 train_time:148311ms step_avg:98.09ms +step:1513/1670 train_time:148409ms step_avg:98.09ms +step:1514/1670 train_time:148508ms step_avg:98.09ms +step:1515/1670 train_time:148605ms step_avg:98.09ms +step:1516/1670 train_time:148702ms step_avg:98.09ms +step:1517/1670 train_time:148801ms step_avg:98.09ms +step:1518/1670 train_time:148899ms step_avg:98.09ms +step:1519/1670 train_time:148995ms step_avg:98.09ms +step:1520/1670 train_time:149092ms step_avg:98.09ms +step:1521/1670 train_time:149190ms step_avg:98.09ms +step:1522/1670 train_time:149289ms step_avg:98.09ms +step:1523/1670 train_time:149387ms step_avg:98.09ms +step:1524/1670 train_time:149485ms step_avg:98.09ms +step:1525/1670 train_time:149583ms step_avg:98.09ms +step:1526/1670 train_time:149681ms step_avg:98.09ms +step:1527/1670 train_time:149778ms step_avg:98.09ms +step:1528/1670 train_time:149875ms step_avg:98.09ms +step:1529/1670 train_time:149972ms step_avg:98.09ms +step:1530/1670 train_time:150069ms step_avg:98.08ms +step:1531/1670 train_time:150166ms step_avg:98.08ms +step:1532/1670 train_time:150266ms step_avg:98.08ms +step:1533/1670 train_time:150365ms step_avg:98.09ms +step:1534/1670 train_time:150465ms step_avg:98.09ms +step:1535/1670 train_time:150563ms step_avg:98.09ms +step:1536/1670 train_time:150661ms step_avg:98.09ms +step:1537/1670 train_time:150758ms step_avg:98.09ms +step:1538/1670 train_time:150855ms step_avg:98.09ms +step:1539/1670 train_time:150952ms step_avg:98.08ms +step:1540/1670 train_time:151049ms step_avg:98.08ms +step:1541/1670 train_time:151146ms step_avg:98.08ms +step:1542/1670 train_time:151244ms step_avg:98.08ms +step:1543/1670 train_time:151342ms step_avg:98.08ms +step:1544/1670 train_time:151441ms step_avg:98.08ms +step:1545/1670 train_time:151540ms step_avg:98.08ms +step:1546/1670 train_time:151638ms step_avg:98.08ms +step:1547/1670 train_time:151736ms step_avg:98.08ms +step:1548/1670 train_time:151833ms step_avg:98.08ms +step:1549/1670 train_time:151931ms step_avg:98.08ms +step:1550/1670 train_time:152027ms step_avg:98.08ms +step:1551/1670 train_time:152126ms step_avg:98.08ms +step:1552/1670 train_time:152224ms step_avg:98.08ms +step:1553/1670 train_time:152323ms step_avg:98.08ms +step:1554/1670 train_time:152421ms step_avg:98.08ms +step:1555/1670 train_time:152519ms step_avg:98.08ms +step:1556/1670 train_time:152617ms step_avg:98.08ms +step:1557/1670 train_time:152714ms step_avg:98.08ms +step:1558/1670 train_time:152812ms step_avg:98.08ms +step:1559/1670 train_time:152908ms step_avg:98.08ms +step:1560/1670 train_time:153007ms step_avg:98.08ms +step:1561/1670 train_time:153104ms step_avg:98.08ms +step:1562/1670 train_time:153202ms step_avg:98.08ms +step:1563/1670 train_time:153302ms step_avg:98.08ms +step:1564/1670 train_time:153401ms step_avg:98.08ms +step:1565/1670 train_time:153499ms step_avg:98.08ms +step:1566/1670 train_time:153597ms step_avg:98.08ms +step:1567/1670 train_time:153694ms step_avg:98.08ms +step:1568/1670 train_time:153792ms step_avg:98.08ms +step:1569/1670 train_time:153889ms step_avg:98.08ms +step:1570/1670 train_time:153987ms step_avg:98.08ms +step:1571/1670 train_time:154084ms step_avg:98.08ms +step:1572/1670 train_time:154182ms step_avg:98.08ms +step:1573/1670 train_time:154279ms step_avg:98.08ms +step:1574/1670 train_time:154377ms step_avg:98.08ms +step:1575/1670 train_time:154474ms step_avg:98.08ms 
+step:1576/1670 train_time:154571ms step_avg:98.08ms +step:1577/1670 train_time:154668ms step_avg:98.08ms +step:1578/1670 train_time:154767ms step_avg:98.08ms +step:1579/1670 train_time:154865ms step_avg:98.08ms +step:1580/1670 train_time:154964ms step_avg:98.08ms +step:1581/1670 train_time:155062ms step_avg:98.08ms +step:1582/1670 train_time:155161ms step_avg:98.08ms +step:1583/1670 train_time:155258ms step_avg:98.08ms +step:1584/1670 train_time:155355ms step_avg:98.08ms +step:1585/1670 train_time:155453ms step_avg:98.08ms +step:1586/1670 train_time:155550ms step_avg:98.08ms +step:1587/1670 train_time:155648ms step_avg:98.08ms +step:1588/1670 train_time:155745ms step_avg:98.08ms +step:1589/1670 train_time:155843ms step_avg:98.08ms +step:1590/1670 train_time:155942ms step_avg:98.08ms +step:1591/1670 train_time:156040ms step_avg:98.08ms +step:1592/1670 train_time:156138ms step_avg:98.08ms +step:1593/1670 train_time:156235ms step_avg:98.08ms +step:1594/1670 train_time:156332ms step_avg:98.08ms +step:1595/1670 train_time:156431ms step_avg:98.08ms +step:1596/1670 train_time:156528ms step_avg:98.08ms +step:1597/1670 train_time:156626ms step_avg:98.07ms +step:1598/1670 train_time:156724ms step_avg:98.08ms +step:1599/1670 train_time:156823ms step_avg:98.08ms +step:1600/1670 train_time:156921ms step_avg:98.08ms +step:1601/1670 train_time:157019ms step_avg:98.08ms +step:1602/1670 train_time:157116ms step_avg:98.07ms +step:1603/1670 train_time:157214ms step_avg:98.07ms +step:1604/1670 train_time:157311ms step_avg:98.07ms +step:1605/1670 train_time:157408ms step_avg:98.07ms +step:1606/1670 train_time:157505ms step_avg:98.07ms +step:1607/1670 train_time:157604ms step_avg:98.07ms +step:1608/1670 train_time:157702ms step_avg:98.07ms +step:1609/1670 train_time:157800ms step_avg:98.07ms +step:1610/1670 train_time:157898ms step_avg:98.07ms +step:1611/1670 train_time:157996ms step_avg:98.07ms +step:1612/1670 train_time:158093ms step_avg:98.07ms +step:1613/1670 train_time:158191ms step_avg:98.07ms +step:1614/1670 train_time:158288ms step_avg:98.07ms +step:1615/1670 train_time:158387ms step_avg:98.07ms +step:1616/1670 train_time:158485ms step_avg:98.07ms +step:1617/1670 train_time:158582ms step_avg:98.07ms +step:1618/1670 train_time:158680ms step_avg:98.07ms +step:1619/1670 train_time:158779ms step_avg:98.07ms +step:1620/1670 train_time:158876ms step_avg:98.07ms +step:1621/1670 train_time:158973ms step_avg:98.07ms +step:1622/1670 train_time:159071ms step_avg:98.07ms +step:1623/1670 train_time:159168ms step_avg:98.07ms +step:1624/1670 train_time:159266ms step_avg:98.07ms +step:1625/1670 train_time:159364ms step_avg:98.07ms +step:1625/1670 val_loss:3.2853 train_time:159462ms step_avg:98.13ms +step:1626/1670 train_time:159483ms step_avg:98.08ms +step:1627/1670 train_time:159567ms step_avg:98.07ms +step:1628/1670 train_time:159668ms step_avg:98.08ms +step:1629/1670 train_time:159766ms step_avg:98.08ms +step:1630/1670 train_time:159863ms step_avg:98.08ms +step:1631/1670 train_time:159959ms step_avg:98.07ms +step:1632/1670 train_time:160056ms step_avg:98.07ms +step:1633/1670 train_time:160153ms step_avg:98.07ms +step:1634/1670 train_time:160251ms step_avg:98.07ms +step:1635/1670 train_time:160348ms step_avg:98.07ms +step:1636/1670 train_time:160447ms step_avg:98.07ms +step:1637/1670 train_time:160547ms step_avg:98.07ms +step:1638/1670 train_time:160647ms step_avg:98.08ms +step:1639/1670 train_time:160746ms step_avg:98.08ms +step:1640/1670 train_time:160844ms step_avg:98.08ms +step:1641/1670 train_time:160941ms 
step_avg:98.07ms +step:1642/1670 train_time:161038ms step_avg:98.07ms +step:1643/1670 train_time:161135ms step_avg:98.07ms +step:1644/1670 train_time:161232ms step_avg:98.07ms +step:1645/1670 train_time:161329ms step_avg:98.07ms +step:1646/1670 train_time:161427ms step_avg:98.07ms +step:1647/1670 train_time:161526ms step_avg:98.07ms +step:1648/1670 train_time:161625ms step_avg:98.07ms +step:1649/1670 train_time:161723ms step_avg:98.07ms +step:1650/1670 train_time:161821ms step_avg:98.07ms +step:1651/1670 train_time:161918ms step_avg:98.07ms +step:1652/1670 train_time:162015ms step_avg:98.07ms +step:1653/1670 train_time:162113ms step_avg:98.07ms +step:1654/1670 train_time:162210ms step_avg:98.07ms +step:1655/1670 train_time:162308ms step_avg:98.07ms +step:1656/1670 train_time:162406ms step_avg:98.07ms +step:1657/1670 train_time:162504ms step_avg:98.07ms +step:1658/1670 train_time:162603ms step_avg:98.07ms +step:1659/1670 train_time:162701ms step_avg:98.07ms +step:1660/1670 train_time:162798ms step_avg:98.07ms +step:1661/1670 train_time:162896ms step_avg:98.07ms +step:1662/1670 train_time:162994ms step_avg:98.07ms +step:1663/1670 train_time:163091ms step_avg:98.07ms +step:1664/1670 train_time:163189ms step_avg:98.07ms +step:1665/1670 train_time:163286ms step_avg:98.07ms +step:1666/1670 train_time:163384ms step_avg:98.07ms +step:1667/1670 train_time:163482ms step_avg:98.07ms +step:1668/1670 train_time:163579ms step_avg:98.07ms +step:1669/1670 train_time:163676ms step_avg:98.07ms +step:1670/1670 train_time:163774ms step_avg:98.07ms +step:1670/1670 val_loss:3.2771 train_time:163871ms step_avg:98.13ms +peak memory allocated: 34361 MiB reserved: 49276 MiB diff --git a/records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt b/records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt new file mode 100644 index 000000000..80e318320 --- /dev/null +++ b/records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return 
impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + 
GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) 
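+ # Note: in newton_schulz_triton below, this kernel only ever receives A = X @ X.T, which is symmetric, + # so C = alpha * A @ A.T + beta * A is symmetric as well. As in ns_line_1_kernel, only one triangle of + # C is computed, and each output block is mirrored across the diagonal when stored.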
+ + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], 
X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
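+ # Overview: parameters were grouped by shape in __init__; within each group, parameter base_i + rank + # is owned by this rank. Gradients are averaged across ranks with async reduce_scatter, each rank then + # applies momentum and Newton-Schulz orthogonalization only to the parameters it owns, and the updated + # parameters are redistributed to every rank with async all_gather.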
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
a given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable a context-based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences require B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim:
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:44:22 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 37C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 36C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 49666 C /usr/bin/python 0MiB | +| 0 N/A N/A 49667 C /usr/bin/python 0MiB | +| 0 N/A N/A 49668 C /usr/bin/python 0MiB | +| 0 N/A N/A 49669 C /usr/bin/python 0MiB | +| 0 N/A N/A 49670 C /usr/bin/python 0MiB | +| 0 N/A N/A 49671 C /usr/bin/python 0MiB | +| 0 N/A N/A 49672 C /usr/bin/python 0MiB | +| 0 N/A N/A 49673 C /usr/bin/python 0MiB | +| 1 N/A N/A 49667 C /usr/bin/python 0MiB | +| 2 N/A N/A 49668 C /usr/bin/python 0MiB | +| 3 N/A N/A 49669 C /usr/bin/python 0MiB | +| 4 N/A N/A 49670 C /usr/bin/python 0MiB | +| 5 N/A N/A 49671 C /usr/bin/python 0MiB | +| 6 N/A N/A 49672 C /usr/bin/python 0MiB | +| 7 N/A N/A 49673 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:380ms step_avg:379.58ms +step:2/1670 train_time:400ms step_avg:200.05ms +step:3/1670 train_time:473ms step_avg:157.71ms +step:4/1670 train_time:567ms step_avg:141.68ms +step:5/1670 train_time:661ms step_avg:132.24ms +step:6/1670 train_time:756ms step_avg:125.95ms +step:7/1670 train_time:850ms step_avg:121.42ms +step:8/1670 train_time:945ms step_avg:118.11ms +step:9/1670 train_time:1040ms 
step_avg:115.55ms +step:10/1670 train_time:1135ms step_avg:113.51ms +step:11/1670 train_time:1230ms step_avg:111.83ms +step:12/1670 train_time:1327ms step_avg:110.61ms +step:13/1670 train_time:1426ms step_avg:109.69ms +step:14/1670 train_time:1524ms step_avg:108.84ms +step:15/1670 train_time:1620ms step_avg:107.99ms +step:16/1670 train_time:1715ms step_avg:107.20ms +step:17/1670 train_time:1811ms step_avg:106.51ms +step:18/1670 train_time:1906ms step_avg:105.87ms +step:19/1670 train_time:2001ms step_avg:105.33ms +step:20/1670 train_time:2097ms step_avg:104.83ms +step:21/1670 train_time:2192ms step_avg:104.38ms +step:22/1670 train_time:2288ms step_avg:103.99ms +step:23/1670 train_time:2385ms step_avg:103.68ms +step:24/1670 train_time:2482ms step_avg:103.43ms +step:25/1670 train_time:2579ms step_avg:103.15ms +step:26/1670 train_time:2675ms step_avg:102.88ms +step:27/1670 train_time:2770ms step_avg:102.60ms +step:28/1670 train_time:2865ms step_avg:102.34ms +step:29/1670 train_time:2961ms step_avg:102.09ms +step:30/1670 train_time:3056ms step_avg:101.86ms +step:31/1670 train_time:3151ms step_avg:101.65ms +step:32/1670 train_time:3247ms step_avg:101.47ms +step:33/1670 train_time:3344ms step_avg:101.32ms +step:34/1670 train_time:3440ms step_avg:101.18ms +step:35/1670 train_time:3537ms step_avg:101.05ms +step:36/1670 train_time:3633ms step_avg:100.92ms +step:37/1670 train_time:3729ms step_avg:100.80ms +step:38/1670 train_time:3825ms step_avg:100.65ms +step:39/1670 train_time:3920ms step_avg:100.51ms +step:40/1670 train_time:4016ms step_avg:100.40ms +step:41/1670 train_time:4111ms step_avg:100.27ms +step:42/1670 train_time:4206ms step_avg:100.15ms +step:43/1670 train_time:4302ms step_avg:100.05ms +step:44/1670 train_time:4398ms step_avg:99.96ms +step:45/1670 train_time:4495ms step_avg:99.89ms +step:46/1670 train_time:4591ms step_avg:99.80ms +step:47/1670 train_time:4687ms step_avg:99.72ms +step:48/1670 train_time:4782ms step_avg:99.63ms +step:49/1670 train_time:4879ms step_avg:99.57ms +step:50/1670 train_time:4974ms step_avg:99.48ms +step:51/1670 train_time:5069ms step_avg:99.40ms +step:52/1670 train_time:5165ms step_avg:99.32ms +step:53/1670 train_time:5260ms step_avg:99.25ms +step:54/1670 train_time:5356ms step_avg:99.19ms +step:55/1670 train_time:5452ms step_avg:99.13ms +step:56/1670 train_time:5547ms step_avg:99.06ms +step:57/1670 train_time:5643ms step_avg:99.01ms +step:58/1670 train_time:5741ms step_avg:98.98ms +step:59/1670 train_time:5837ms step_avg:98.93ms +step:60/1670 train_time:5933ms step_avg:98.88ms +step:61/1670 train_time:6028ms step_avg:98.82ms +step:62/1670 train_time:6124ms step_avg:98.77ms +step:63/1670 train_time:6220ms step_avg:98.73ms +step:64/1670 train_time:6316ms step_avg:98.69ms +step:65/1670 train_time:6412ms step_avg:98.64ms +step:66/1670 train_time:6507ms step_avg:98.60ms +step:67/1670 train_time:6604ms step_avg:98.56ms +step:68/1670 train_time:6700ms step_avg:98.53ms +step:69/1670 train_time:6797ms step_avg:98.50ms +step:70/1670 train_time:6892ms step_avg:98.46ms +step:71/1670 train_time:6988ms step_avg:98.42ms +step:72/1670 train_time:7084ms step_avg:98.39ms +step:73/1670 train_time:7180ms step_avg:98.35ms +step:74/1670 train_time:7276ms step_avg:98.32ms +step:75/1670 train_time:7371ms step_avg:98.28ms +step:76/1670 train_time:7466ms step_avg:98.24ms +step:77/1670 train_time:7563ms step_avg:98.21ms +step:78/1670 train_time:7658ms step_avg:98.18ms +step:79/1670 train_time:7754ms step_avg:98.16ms +step:80/1670 train_time:7850ms step_avg:98.12ms +step:81/1670 
train_time:7945ms step_avg:98.09ms +step:82/1670 train_time:8041ms step_avg:98.06ms +step:83/1670 train_time:8137ms step_avg:98.04ms +step:84/1670 train_time:8233ms step_avg:98.01ms +step:85/1670 train_time:8329ms step_avg:97.99ms +step:86/1670 train_time:8425ms step_avg:97.96ms +step:87/1670 train_time:8522ms step_avg:97.96ms +step:88/1670 train_time:8618ms step_avg:97.94ms +step:89/1670 train_time:8714ms step_avg:97.91ms +step:90/1670 train_time:8810ms step_avg:97.89ms +step:91/1670 train_time:8905ms step_avg:97.86ms +step:92/1670 train_time:9001ms step_avg:97.83ms +step:93/1670 train_time:9097ms step_avg:97.82ms +step:94/1670 train_time:9192ms step_avg:97.79ms +step:95/1670 train_time:9288ms step_avg:97.77ms +step:96/1670 train_time:9383ms step_avg:97.74ms +step:97/1670 train_time:9479ms step_avg:97.72ms +step:98/1670 train_time:9575ms step_avg:97.70ms +step:99/1670 train_time:9670ms step_avg:97.68ms +step:100/1670 train_time:9766ms step_avg:97.66ms +step:101/1670 train_time:9863ms step_avg:97.65ms +step:102/1670 train_time:9959ms step_avg:97.64ms +step:103/1670 train_time:10055ms step_avg:97.62ms +step:104/1670 train_time:10151ms step_avg:97.60ms +step:105/1670 train_time:10246ms step_avg:97.58ms +step:106/1670 train_time:10342ms step_avg:97.57ms +step:107/1670 train_time:10438ms step_avg:97.55ms +step:108/1670 train_time:10533ms step_avg:97.53ms +step:109/1670 train_time:10630ms step_avg:97.52ms +step:110/1670 train_time:10725ms step_avg:97.50ms +step:111/1670 train_time:10821ms step_avg:97.49ms +step:112/1670 train_time:10918ms step_avg:97.48ms +step:113/1670 train_time:11013ms step_avg:97.46ms +step:114/1670 train_time:11109ms step_avg:97.44ms +step:115/1670 train_time:11204ms step_avg:97.43ms +step:116/1670 train_time:11300ms step_avg:97.42ms +step:117/1670 train_time:11397ms step_avg:97.41ms +step:118/1670 train_time:11493ms step_avg:97.40ms +step:119/1670 train_time:11588ms step_avg:97.38ms +step:120/1670 train_time:11684ms step_avg:97.37ms +step:121/1670 train_time:11780ms step_avg:97.35ms +step:122/1670 train_time:11875ms step_avg:97.34ms +step:123/1670 train_time:11970ms step_avg:97.32ms +step:124/1670 train_time:12067ms step_avg:97.31ms +step:125/1670 train_time:12163ms step_avg:97.30ms +step:125/1670 val_loss:4.3037 train_time:12257ms step_avg:98.06ms +step:126/1670 train_time:12280ms step_avg:97.46ms +step:127/1670 train_time:12362ms step_avg:97.34ms +step:128/1670 train_time:12466ms step_avg:97.39ms +step:129/1670 train_time:12564ms step_avg:97.40ms +step:130/1670 train_time:12660ms step_avg:97.38ms +step:131/1670 train_time:12755ms step_avg:97.36ms +step:132/1670 train_time:12850ms step_avg:97.34ms +step:133/1670 train_time:12945ms step_avg:97.33ms +step:134/1670 train_time:13040ms step_avg:97.32ms +step:135/1670 train_time:13135ms step_avg:97.30ms +step:136/1670 train_time:13230ms step_avg:97.28ms +step:137/1670 train_time:13328ms step_avg:97.29ms +step:138/1670 train_time:13427ms step_avg:97.30ms +step:139/1670 train_time:13525ms step_avg:97.30ms +step:140/1670 train_time:13622ms step_avg:97.30ms +step:141/1670 train_time:13718ms step_avg:97.29ms +step:142/1670 train_time:13814ms step_avg:97.28ms +step:143/1670 train_time:13908ms step_avg:97.26ms +step:144/1670 train_time:14003ms step_avg:97.24ms +step:145/1670 train_time:14098ms step_avg:97.23ms +step:146/1670 train_time:14193ms step_avg:97.21ms +step:147/1670 train_time:14289ms step_avg:97.20ms +step:148/1670 train_time:14386ms step_avg:97.20ms +step:149/1670 train_time:14483ms step_avg:97.20ms +step:150/1670 
train_time:14580ms step_avg:97.20ms
+step:151/1670 train_time:14675ms step_avg:97.19ms
+step:152/1670 train_time:14771ms step_avg:97.18ms
+step:153/1670 train_time:14866ms step_avg:97.16ms
+step:154/1670 train_time:14961ms step_avg:97.15ms
+step:155/1670 train_time:15056ms step_avg:97.14ms
+step:156/1670 train_time:15151ms step_avg:97.12ms
+step:157/1670 train_time:15247ms step_avg:97.11ms
+step:158/1670 train_time:15342ms step_avg:97.10ms
+step:159/1670 train_time:15439ms step_avg:97.10ms
+step:160/1670 train_time:15535ms step_avg:97.10ms
+step:161/1670 train_time:15631ms step_avg:97.09ms
+step:162/1670 train_time:15728ms step_avg:97.09ms
+step:163/1670 train_time:15824ms step_avg:97.08ms
+step:164/1670 train_time:15919ms step_avg:97.07ms
+step:165/1670 train_time:16016ms step_avg:97.06ms
+step:166/1670 train_time:16111ms step_avg:97.05ms
+step:167/1670 train_time:16206ms step_avg:97.04ms
+step:168/1670 train_time:16302ms step_avg:97.03ms
+step:169/1670 train_time:16397ms step_avg:97.03ms
+step:170/1670 train_time:16493ms step_avg:97.02ms
+step:171/1670 train_time:16588ms step_avg:97.00ms
+step:172/1670 train_time:16685ms step_avg:97.00ms
+step:173/1670 train_time:16780ms step_avg:97.00ms
+step:174/1670 train_time:16875ms step_avg:96.99ms
+step:175/1670 train_time:16970ms step_avg:96.97ms
+step:176/1670 train_time:17066ms step_avg:96.96ms
+step:177/1670 train_time:17162ms step_avg:96.96ms
+step:178/1670 train_time:17257ms step_avg:96.95ms
+step:179/1670 train_time:17352ms step_avg:96.94ms
+step:180/1670 train_time:17447ms step_avg:96.93ms
+step:181/1670 train_time:17543ms step_avg:96.92ms
+step:182/1670 train_time:17640ms step_avg:96.92ms
+step:183/1670 train_time:17735ms step_avg:96.91ms
+step:184/1670 train_time:17831ms step_avg:96.91ms
+step:185/1670 train_time:17927ms step_avg:96.90ms
+step:186/1670 train_time:18023ms step_avg:96.90ms
+step:187/1670 train_time:18119ms step_avg:96.89ms
+step:188/1670 train_time:18213ms step_avg:96.88ms
+step:189/1670 train_time:18309ms step_avg:96.88ms
+step:190/1670 train_time:18405ms step_avg:96.87ms
+step:191/1670 train_time:18500ms step_avg:96.86ms
+step:192/1670 train_time:18596ms step_avg:96.85ms
+step:193/1670 train_time:18691ms step_avg:96.85ms
+step:194/1670 train_time:18787ms step_avg:96.84ms
+step:195/1670 train_time:18883ms step_avg:96.84ms
+step:196/1670 train_time:18980ms step_avg:96.83ms
+step:197/1670 train_time:19075ms step_avg:96.83ms
+step:198/1670 train_time:19170ms step_avg:96.82ms
+step:199/1670 train_time:19266ms step_avg:96.81ms
+step:200/1670 train_time:19362ms step_avg:96.81ms
+step:201/1670 train_time:19457ms step_avg:96.80ms
+step:202/1670 train_time:19552ms step_avg:96.79ms
+step:203/1670 train_time:19648ms step_avg:96.79ms
+step:204/1670 train_time:19743ms step_avg:96.78ms
+step:205/1670 train_time:19839ms step_avg:96.78ms
+step:206/1670 train_time:19934ms step_avg:96.77ms
+step:207/1670 train_time:20030ms step_avg:96.77ms
+step:208/1670 train_time:20125ms step_avg:96.76ms
+step:209/1670 train_time:20221ms step_avg:96.75ms
+step:210/1670 train_time:20316ms step_avg:96.74ms
+step:211/1670 train_time:20412ms step_avg:96.74ms
+step:212/1670 train_time:20507ms step_avg:96.73ms
+step:213/1670 train_time:20790ms step_avg:97.61ms
+step:214/1670 train_time:20880ms step_avg:97.57ms
+step:215/1670 train_time:20974ms step_avg:97.56ms
+step:216/1670 train_time:21069ms step_avg:97.54ms
+step:217/1670 train_time:21163ms step_avg:97.53ms
+step:218/1670 train_time:21258ms step_avg:97.51ms
+step:219/1670 train_time:21352ms step_avg:97.50ms
+step:220/1670 train_time:21446ms step_avg:97.48ms
+step:221/1670 train_time:21541ms step_avg:97.47ms
+step:222/1670 train_time:21636ms step_avg:97.46ms
+step:223/1670 train_time:21735ms step_avg:97.47ms
+step:224/1670 train_time:21834ms step_avg:97.47ms
+step:225/1670 train_time:21930ms step_avg:97.47ms
+step:226/1670 train_time:22026ms step_avg:97.46ms
+step:227/1670 train_time:22121ms step_avg:97.45ms
+step:228/1670 train_time:22216ms step_avg:97.44ms
+step:229/1670 train_time:22311ms step_avg:97.43ms
+step:230/1670 train_time:22406ms step_avg:97.42ms
+step:231/1670 train_time:22501ms step_avg:97.41ms
+step:232/1670 train_time:22596ms step_avg:97.40ms
+step:233/1670 train_time:22692ms step_avg:97.39ms
+step:234/1670 train_time:22789ms step_avg:97.39ms
+step:235/1670 train_time:22887ms step_avg:97.39ms
+step:236/1670 train_time:22984ms step_avg:97.39ms
+step:237/1670 train_time:23079ms step_avg:97.38ms
+step:238/1670 train_time:23174ms step_avg:97.37ms
+step:239/1670 train_time:23269ms step_avg:97.36ms
+step:240/1670 train_time:23364ms step_avg:97.35ms
+step:241/1670 train_time:23458ms step_avg:97.34ms
+step:242/1670 train_time:23553ms step_avg:97.33ms
+step:243/1670 train_time:23648ms step_avg:97.32ms
+step:244/1670 train_time:23746ms step_avg:97.32ms
+step:245/1670 train_time:23842ms step_avg:97.32ms
+step:246/1670 train_time:23939ms step_avg:97.31ms
+step:247/1670 train_time:24035ms step_avg:97.31ms
+step:248/1670 train_time:24130ms step_avg:97.30ms
+step:249/1670 train_time:24225ms step_avg:97.29ms
+step:250/1670 train_time:24320ms step_avg:97.28ms
+step:250/1670 val_loss:3.9734 train_time:24414ms step_avg:97.66ms
+step:251/1670 train_time:24435ms step_avg:97.35ms
+step:252/1670 train_time:24516ms step_avg:97.29ms
+step:253/1670 train_time:24618ms step_avg:97.30ms
+step:254/1670 train_time:24714ms step_avg:97.30ms
+step:255/1670 train_time:24809ms step_avg:97.29ms
+step:256/1670 train_time:24903ms step_avg:97.28ms
+step:257/1670 train_time:24998ms step_avg:97.27ms
+step:258/1670 train_time:25093ms step_avg:97.26ms
+step:259/1670 train_time:25187ms step_avg:97.25ms
+step:260/1670 train_time:25282ms step_avg:97.24ms
+step:261/1670 train_time:25378ms step_avg:97.24ms
+step:262/1670 train_time:25478ms step_avg:97.24ms
+step:263/1670 train_time:25576ms step_avg:97.25ms
+step:264/1670 train_time:25672ms step_avg:97.24ms
+step:265/1670 train_time:25769ms step_avg:97.24ms
+step:266/1670 train_time:25864ms step_avg:97.23ms
+step:267/1670 train_time:25959ms step_avg:97.22ms
+step:268/1670 train_time:26053ms step_avg:97.21ms
+step:269/1670 train_time:26148ms step_avg:97.21ms
+step:270/1670 train_time:26243ms step_avg:97.20ms
+step:271/1670 train_time:26338ms step_avg:97.19ms
+step:272/1670 train_time:26435ms step_avg:97.19ms
+step:273/1670 train_time:26532ms step_avg:97.19ms
+step:274/1670 train_time:26628ms step_avg:97.18ms
+step:275/1670 train_time:26724ms step_avg:97.18ms
+step:276/1670 train_time:26819ms step_avg:97.17ms
+step:277/1670 train_time:26915ms step_avg:97.17ms
+step:278/1670 train_time:27010ms step_avg:97.16ms
+step:279/1670 train_time:27105ms step_avg:97.15ms
+step:280/1670 train_time:27200ms step_avg:97.14ms
+step:281/1670 train_time:27295ms step_avg:97.14ms
+step:282/1670 train_time:27391ms step_avg:97.13ms
+step:283/1670 train_time:27487ms step_avg:97.13ms
+step:284/1670 train_time:27582ms step_avg:97.12ms
+step:285/1670 train_time:27679ms step_avg:97.12ms
+step:286/1670 train_time:27776ms step_avg:97.12ms
+step:287/1670 train_time:27871ms step_avg:97.11ms
+step:288/1670 train_time:27968ms step_avg:97.11ms
+step:289/1670 train_time:28062ms step_avg:97.10ms
+step:290/1670 train_time:28158ms step_avg:97.10ms
+step:291/1670 train_time:28254ms step_avg:97.09ms
+step:292/1670 train_time:28349ms step_avg:97.09ms
+step:293/1670 train_time:28445ms step_avg:97.08ms
+step:294/1670 train_time:28540ms step_avg:97.08ms
+step:295/1670 train_time:28636ms step_avg:97.07ms
+step:296/1670 train_time:28733ms step_avg:97.07ms
+step:297/1670 train_time:28828ms step_avg:97.06ms
+step:298/1670 train_time:28924ms step_avg:97.06ms
+step:299/1670 train_time:29019ms step_avg:97.05ms
+step:300/1670 train_time:29116ms step_avg:97.05ms
+step:301/1670 train_time:29211ms step_avg:97.05ms
+step:302/1670 train_time:29306ms step_avg:97.04ms
+step:303/1670 train_time:29401ms step_avg:97.03ms
+step:304/1670 train_time:29498ms step_avg:97.03ms
+step:305/1670 train_time:29594ms step_avg:97.03ms
+step:306/1670 train_time:29690ms step_avg:97.03ms
+step:307/1670 train_time:29785ms step_avg:97.02ms
+step:308/1670 train_time:29880ms step_avg:97.01ms
+step:309/1670 train_time:29976ms step_avg:97.01ms
+step:310/1670 train_time:30071ms step_avg:97.00ms
+step:311/1670 train_time:30167ms step_avg:97.00ms
+step:312/1670 train_time:30262ms step_avg:96.99ms
+step:313/1670 train_time:30357ms step_avg:96.99ms
+step:314/1670 train_time:30453ms step_avg:96.98ms
+step:315/1670 train_time:30548ms step_avg:96.98ms
+step:316/1670 train_time:30644ms step_avg:96.97ms
+step:317/1670 train_time:30740ms step_avg:96.97ms
+step:318/1670 train_time:30837ms step_avg:96.97ms
+step:319/1670 train_time:30934ms step_avg:96.97ms
+step:320/1670 train_time:31030ms step_avg:96.97ms
+step:321/1670 train_time:31126ms step_avg:96.96ms
+step:322/1670 train_time:31221ms step_avg:96.96ms
+step:323/1670 train_time:31318ms step_avg:96.96ms
+step:324/1670 train_time:31414ms step_avg:96.96ms
+step:325/1670 train_time:31509ms step_avg:96.95ms
+step:326/1670 train_time:31604ms step_avg:96.95ms
+step:327/1670 train_time:31700ms step_avg:96.94ms
+step:328/1670 train_time:31796ms step_avg:96.94ms
+step:329/1670 train_time:31892ms step_avg:96.94ms
+step:330/1670 train_time:31987ms step_avg:96.93ms
+step:331/1670 train_time:32082ms step_avg:96.92ms
+step:332/1670 train_time:32178ms step_avg:96.92ms
+step:333/1670 train_time:32275ms step_avg:96.92ms
+step:334/1670 train_time:32370ms step_avg:96.92ms
+step:335/1670 train_time:32465ms step_avg:96.91ms
+step:336/1670 train_time:32561ms step_avg:96.91ms
+step:337/1670 train_time:32658ms step_avg:96.91ms
+step:338/1670 train_time:32754ms step_avg:96.90ms
+step:339/1670 train_time:32849ms step_avg:96.90ms
+step:340/1670 train_time:32944ms step_avg:96.89ms
+step:341/1670 train_time:33040ms step_avg:96.89ms
+step:342/1670 train_time:33136ms step_avg:96.89ms
+step:343/1670 train_time:33232ms step_avg:96.89ms
+step:344/1670 train_time:33328ms step_avg:96.88ms
+step:345/1670 train_time:33424ms step_avg:96.88ms
+step:346/1670 train_time:33519ms step_avg:96.88ms
+step:347/1670 train_time:33616ms step_avg:96.88ms
+step:348/1670 train_time:33711ms step_avg:96.87ms
+step:349/1670 train_time:33806ms step_avg:96.87ms
+step:350/1670 train_time:33901ms step_avg:96.86ms
+step:351/1670 train_time:33997ms step_avg:96.86ms
+step:352/1670 train_time:34093ms step_avg:96.85ms
+step:353/1670 train_time:34188ms step_avg:96.85ms
+step:354/1670 train_time:34284ms step_avg:96.85ms
+step:355/1670 train_time:34379ms step_avg:96.84ms
+step:356/1670 train_time:34475ms step_avg:96.84ms
+step:357/1670 train_time:34571ms step_avg:96.84ms
+step:358/1670 train_time:34667ms step_avg:96.83ms
+step:359/1670 train_time:34762ms step_avg:96.83ms
+step:360/1670 train_time:34858ms step_avg:96.83ms
+step:361/1670 train_time:34954ms step_avg:96.83ms
+step:362/1670 train_time:35050ms step_avg:96.82ms
+step:363/1670 train_time:35145ms step_avg:96.82ms
+step:364/1670 train_time:35240ms step_avg:96.81ms
+step:365/1670 train_time:35337ms step_avg:96.81ms
+step:366/1670 train_time:35432ms step_avg:96.81ms
+step:367/1670 train_time:35528ms step_avg:96.81ms
+step:368/1670 train_time:35624ms step_avg:96.80ms
+step:369/1670 train_time:35719ms step_avg:96.80ms
+step:370/1670 train_time:35815ms step_avg:96.80ms
+step:371/1670 train_time:35911ms step_avg:96.80ms
+step:372/1670 train_time:36006ms step_avg:96.79ms
+step:373/1670 train_time:36101ms step_avg:96.79ms
+step:374/1670 train_time:36198ms step_avg:96.79ms
+step:375/1670 train_time:36294ms step_avg:96.78ms
+step:375/1670 val_loss:3.8171 train_time:36390ms step_avg:97.04ms
+step:376/1670 train_time:36411ms step_avg:96.84ms
+step:377/1670 train_time:36492ms step_avg:96.79ms
+step:378/1670 train_time:36593ms step_avg:96.81ms
+step:379/1670 train_time:36688ms step_avg:96.80ms
+step:380/1670 train_time:36783ms step_avg:96.80ms
+step:381/1670 train_time:36878ms step_avg:96.79ms
+step:382/1670 train_time:36973ms step_avg:96.79ms
+step:383/1670 train_time:37067ms step_avg:96.78ms
+step:384/1670 train_time:37163ms step_avg:96.78ms
+step:385/1670 train_time:37257ms step_avg:96.77ms
+step:386/1670 train_time:37353ms step_avg:96.77ms
+step:387/1670 train_time:37449ms step_avg:96.77ms
+step:388/1670 train_time:37547ms step_avg:96.77ms
+step:389/1670 train_time:37644ms step_avg:96.77ms
+step:390/1670 train_time:37740ms step_avg:96.77ms
+step:391/1670 train_time:37835ms step_avg:96.76ms
+step:392/1670 train_time:37930ms step_avg:96.76ms
+step:393/1670 train_time:38025ms step_avg:96.76ms
+step:394/1670 train_time:38120ms step_avg:96.75ms
+step:395/1670 train_time:38215ms step_avg:96.75ms
+step:396/1670 train_time:38310ms step_avg:96.74ms
+step:397/1670 train_time:38407ms step_avg:96.74ms
+step:398/1670 train_time:38504ms step_avg:96.74ms
+step:399/1670 train_time:38600ms step_avg:96.74ms
+step:400/1670 train_time:38696ms step_avg:96.74ms
+step:401/1670 train_time:38791ms step_avg:96.74ms
+step:402/1670 train_time:38887ms step_avg:96.73ms
+step:403/1670 train_time:38983ms step_avg:96.73ms
+step:404/1670 train_time:39079ms step_avg:96.73ms
+step:405/1670 train_time:39174ms step_avg:96.73ms
+step:406/1670 train_time:39269ms step_avg:96.72ms
+step:407/1670 train_time:39365ms step_avg:96.72ms
+step:408/1670 train_time:39462ms step_avg:96.72ms
+step:409/1670 train_time:39559ms step_avg:96.72ms
+step:410/1670 train_time:39654ms step_avg:96.72ms
+step:411/1670 train_time:39750ms step_avg:96.71ms
+step:412/1670 train_time:39845ms step_avg:96.71ms
+step:413/1670 train_time:39941ms step_avg:96.71ms
+step:414/1670 train_time:40037ms step_avg:96.71ms
+step:415/1670 train_time:40131ms step_avg:96.70ms
+step:416/1670 train_time:40227ms step_avg:96.70ms
+step:417/1670 train_time:40323ms step_avg:96.70ms
+step:418/1670 train_time:40420ms step_avg:96.70ms
+step:419/1670 train_time:40517ms step_avg:96.70ms
+step:420/1670 train_time:40612ms step_avg:96.70ms
+step:421/1670 train_time:40708ms step_avg:96.69ms
+step:422/1670 train_time:40805ms step_avg:96.69ms
+step:423/1670 train_time:40901ms step_avg:96.69ms
+step:424/1670 train_time:40996ms step_avg:96.69ms
+step:425/1670 train_time:41286ms step_avg:97.14ms
+step:426/1670 train_time:41453ms step_avg:97.31ms
+step:427/1670 train_time:41546ms step_avg:97.30ms
+step:428/1670 train_time:41641ms step_avg:97.29ms
+step:429/1670 train_time:41735ms step_avg:97.28ms
+step:430/1670 train_time:41829ms step_avg:97.28ms
+step:431/1670 train_time:41924ms step_avg:97.27ms
+step:432/1670 train_time:42018ms step_avg:97.26ms
+step:433/1670 train_time:42112ms step_avg:97.26ms
+step:434/1670 train_time:42207ms step_avg:97.25ms
+step:435/1670 train_time:42303ms step_avg:97.25ms
+step:436/1670 train_time:42405ms step_avg:97.26ms
+step:437/1670 train_time:42506ms step_avg:97.27ms
+step:438/1670 train_time:42602ms step_avg:97.27ms
+step:439/1670 train_time:42697ms step_avg:97.26ms
+step:440/1670 train_time:42792ms step_avg:97.26ms
+step:441/1670 train_time:42887ms step_avg:97.25ms
+step:442/1670 train_time:42982ms step_avg:97.24ms
+step:443/1670 train_time:43076ms step_avg:97.24ms
+step:444/1670 train_time:43171ms step_avg:97.23ms
+step:445/1670 train_time:43266ms step_avg:97.23ms
+step:446/1670 train_time:43363ms step_avg:97.23ms
+step:447/1670 train_time:43461ms step_avg:97.23ms
+step:448/1670 train_time:43558ms step_avg:97.23ms
+step:449/1670 train_time:43653ms step_avg:97.22ms
+step:450/1670 train_time:43748ms step_avg:97.22ms
+step:451/1670 train_time:43844ms step_avg:97.21ms
+step:452/1670 train_time:43939ms step_avg:97.21ms
+step:453/1670 train_time:44034ms step_avg:97.21ms
+step:454/1670 train_time:44128ms step_avg:97.20ms
+step:455/1670 train_time:44224ms step_avg:97.20ms
+step:456/1670 train_time:44320ms step_avg:97.19ms
+step:457/1670 train_time:44415ms step_avg:97.19ms
+step:458/1670 train_time:44511ms step_avg:97.19ms
+step:459/1670 train_time:44608ms step_avg:97.18ms
+step:460/1670 train_time:44704ms step_avg:97.18ms
+step:461/1670 train_time:44800ms step_avg:97.18ms
+step:462/1670 train_time:44896ms step_avg:97.18ms
+step:463/1670 train_time:44991ms step_avg:97.17ms
+step:464/1670 train_time:45086ms step_avg:97.17ms
+step:465/1670 train_time:45181ms step_avg:97.16ms
+step:466/1670 train_time:45277ms step_avg:97.16ms
+step:467/1670 train_time:45372ms step_avg:97.16ms
+step:468/1670 train_time:45468ms step_avg:97.15ms
+step:469/1670 train_time:45566ms step_avg:97.16ms
+step:470/1670 train_time:45662ms step_avg:97.15ms
+step:471/1670 train_time:45758ms step_avg:97.15ms
+step:472/1670 train_time:45853ms step_avg:97.15ms
+step:473/1670 train_time:45948ms step_avg:97.14ms
+step:474/1670 train_time:46045ms step_avg:97.14ms
+step:475/1670 train_time:46140ms step_avg:97.14ms
+step:476/1670 train_time:46235ms step_avg:97.13ms
+step:477/1670 train_time:46330ms step_avg:97.13ms
+step:478/1670 train_time:46425ms step_avg:97.12ms
+step:479/1670 train_time:46521ms step_avg:97.12ms
+step:480/1670 train_time:46618ms step_avg:97.12ms
+step:481/1670 train_time:46713ms step_avg:97.12ms
+step:482/1670 train_time:46809ms step_avg:97.11ms
+step:483/1670 train_time:46904ms step_avg:97.11ms
+step:484/1670 train_time:47000ms step_avg:97.11ms
+step:485/1670 train_time:47095ms step_avg:97.10ms
+step:486/1670 train_time:47190ms step_avg:97.10ms
+step:487/1670 train_time:47285ms step_avg:97.09ms
+step:488/1670 train_time:47381ms step_avg:97.09ms
+step:489/1670 train_time:47477ms step_avg:97.09ms
+step:490/1670 train_time:47573ms step_avg:97.09ms
+step:491/1670 train_time:47669ms step_avg:97.08ms
+step:492/1670 train_time:47765ms step_avg:97.08ms
+step:493/1670 train_time:47861ms step_avg:97.08ms
+step:494/1670 train_time:47957ms step_avg:97.08ms
+step:495/1670 train_time:48052ms step_avg:97.07ms
+step:496/1670 train_time:48148ms step_avg:97.07ms
+step:497/1670 train_time:48243ms step_avg:97.07ms
+step:498/1670 train_time:48338ms step_avg:97.07ms
+step:499/1670 train_time:48433ms step_avg:97.06ms
+step:500/1670 train_time:48528ms step_avg:97.06ms
+step:500/1670 val_loss:3.7150 train_time:48624ms step_avg:97.25ms
+step:501/1670 train_time:48645ms step_avg:97.10ms
+step:502/1670 train_time:48726ms step_avg:97.06ms
+step:503/1670 train_time:48826ms step_avg:97.07ms
+step:504/1670 train_time:48922ms step_avg:97.07ms
+step:505/1670 train_time:49016ms step_avg:97.06ms
+step:506/1670 train_time:49111ms step_avg:97.06ms
+step:507/1670 train_time:49206ms step_avg:97.05ms
+step:508/1670 train_time:49301ms step_avg:97.05ms
+step:509/1670 train_time:49396ms step_avg:97.05ms
+step:510/1670 train_time:49491ms step_avg:97.04ms
+step:511/1670 train_time:49586ms step_avg:97.04ms
+step:512/1670 train_time:49684ms step_avg:97.04ms
+step:513/1670 train_time:49783ms step_avg:97.04ms
+step:514/1670 train_time:49881ms step_avg:97.04ms
+step:515/1670 train_time:49976ms step_avg:97.04ms
+step:516/1670 train_time:50071ms step_avg:97.04ms
+step:517/1670 train_time:50166ms step_avg:97.03ms
+step:518/1670 train_time:50261ms step_avg:97.03ms
+step:519/1670 train_time:50357ms step_avg:97.03ms
+step:520/1670 train_time:50453ms step_avg:97.02ms
+step:521/1670 train_time:50548ms step_avg:97.02ms
+step:522/1670 train_time:50643ms step_avg:97.02ms
+step:523/1670 train_time:50740ms step_avg:97.02ms
+step:524/1670 train_time:50838ms step_avg:97.02ms
+step:525/1670 train_time:50934ms step_avg:97.02ms
+step:526/1670 train_time:51030ms step_avg:97.02ms
+step:527/1670 train_time:51125ms step_avg:97.01ms
+step:528/1670 train_time:51221ms step_avg:97.01ms
+step:529/1670 train_time:51316ms step_avg:97.00ms
+step:530/1670 train_time:51411ms step_avg:97.00ms
+step:531/1670 train_time:51506ms step_avg:97.00ms
+step:532/1670 train_time:51602ms step_avg:97.00ms
+step:533/1670 train_time:51698ms step_avg:96.99ms
+step:534/1670 train_time:51795ms step_avg:96.99ms
+step:535/1670 train_time:51891ms step_avg:96.99ms
+step:536/1670 train_time:51986ms step_avg:96.99ms
+step:537/1670 train_time:52083ms step_avg:96.99ms
+step:538/1670 train_time:52179ms step_avg:96.99ms
+step:539/1670 train_time:52274ms step_avg:96.98ms
+step:540/1670 train_time:52369ms step_avg:96.98ms
+step:541/1670 train_time:52464ms step_avg:96.98ms
+step:542/1670 train_time:52560ms step_avg:96.97ms
+step:543/1670 train_time:52656ms step_avg:96.97ms
+step:544/1670 train_time:52752ms step_avg:96.97ms
+step:545/1670 train_time:52847ms step_avg:96.97ms
+step:546/1670 train_time:52943ms step_avg:96.97ms
+step:547/1670 train_time:53041ms step_avg:96.97ms
+step:548/1670 train_time:53136ms step_avg:96.96ms
+step:549/1670 train_time:53231ms step_avg:96.96ms
+step:550/1670 train_time:53325ms step_avg:96.96ms
+step:551/1670 train_time:53421ms step_avg:96.95ms
+step:552/1670 train_time:53518ms step_avg:96.95ms
+step:553/1670 train_time:53614ms step_avg:96.95ms
+step:554/1670 train_time:53709ms step_avg:96.95ms
+step:555/1670 train_time:53805ms step_avg:96.95ms
+step:556/1670 train_time:53901ms step_avg:96.94ms
+step:557/1670 train_time:53997ms step_avg:96.94ms
+step:558/1670 train_time:54094ms step_avg:96.94ms
+step:559/1670 train_time:54190ms step_avg:96.94ms
+step:560/1670 train_time:54287ms step_avg:96.94ms
+step:561/1670 train_time:54383ms step_avg:96.94ms
+step:562/1670 train_time:54480ms step_avg:96.94ms
+step:563/1670 train_time:54578ms step_avg:96.94ms
+step:564/1670 train_time:54676ms step_avg:96.94ms
+step:565/1670 train_time:54773ms step_avg:96.94ms
+step:566/1670 train_time:54869ms step_avg:96.94ms
+step:567/1670 train_time:54966ms step_avg:96.94ms
+step:568/1670 train_time:55063ms step_avg:96.94ms
+step:569/1670 train_time:55160ms step_avg:96.94ms
+step:570/1670 train_time:55258ms step_avg:96.94ms
+step:571/1670 train_time:55355ms step_avg:96.94ms
+step:572/1670 train_time:55452ms step_avg:96.94ms
+step:573/1670 train_time:55548ms step_avg:96.94ms
+step:574/1670 train_time:55646ms step_avg:96.94ms
+step:575/1670 train_time:55743ms step_avg:96.94ms
+step:576/1670 train_time:55842ms step_avg:96.95ms
+step:577/1670 train_time:55939ms step_avg:96.95ms
+step:578/1670 train_time:56036ms step_avg:96.95ms
+step:579/1670 train_time:56133ms step_avg:96.95ms
+step:580/1670 train_time:56230ms step_avg:96.95ms
+step:581/1670 train_time:56327ms step_avg:96.95ms
+step:582/1670 train_time:56423ms step_avg:96.95ms
+step:583/1670 train_time:56520ms step_avg:96.95ms
+step:584/1670 train_time:56617ms step_avg:96.95ms
+step:585/1670 train_time:56715ms step_avg:96.95ms
+step:586/1670 train_time:56811ms step_avg:96.95ms
+step:587/1670 train_time:56908ms step_avg:96.95ms
+step:588/1670 train_time:57005ms step_avg:96.95ms
+step:589/1670 train_time:57102ms step_avg:96.95ms
+step:590/1670 train_time:57199ms step_avg:96.95ms
+step:591/1670 train_time:57297ms step_avg:96.95ms
+step:592/1670 train_time:57394ms step_avg:96.95ms
+step:593/1670 train_time:57491ms step_avg:96.95ms
+step:594/1670 train_time:57588ms step_avg:96.95ms
+step:595/1670 train_time:57685ms step_avg:96.95ms
+step:596/1670 train_time:57782ms step_avg:96.95ms
+step:597/1670 train_time:57880ms step_avg:96.95ms
+step:598/1670 train_time:57978ms step_avg:96.95ms
+step:599/1670 train_time:58074ms step_avg:96.95ms
+step:600/1670 train_time:58172ms step_avg:96.95ms
+step:601/1670 train_time:58268ms step_avg:96.95ms
+step:602/1670 train_time:58364ms step_avg:96.95ms
+step:603/1670 train_time:58461ms step_avg:96.95ms
+step:604/1670 train_time:58559ms step_avg:96.95ms
+step:605/1670 train_time:58656ms step_avg:96.95ms
+step:606/1670 train_time:58753ms step_avg:96.95ms
+step:607/1670 train_time:58850ms step_avg:96.95ms
+step:608/1670 train_time:58948ms step_avg:96.95ms
+step:609/1670 train_time:59045ms step_avg:96.95ms
+step:610/1670 train_time:59143ms step_avg:96.96ms
+step:611/1670 train_time:59241ms step_avg:96.96ms
+step:612/1670 train_time:59339ms step_avg:96.96ms
+step:613/1670 train_time:59436ms step_avg:96.96ms
+step:614/1670 train_time:59533ms step_avg:96.96ms
+step:615/1670 train_time:59630ms step_avg:96.96ms
+step:616/1670 train_time:59727ms step_avg:96.96ms
+step:617/1670 train_time:59824ms step_avg:96.96ms
+step:618/1670 train_time:59924ms step_avg:96.96ms
+step:619/1670 train_time:60021ms step_avg:96.96ms
+step:620/1670 train_time:60118ms step_avg:96.96ms
+step:621/1670 train_time:60216ms step_avg:96.97ms
+step:622/1670 train_time:60313ms step_avg:96.97ms
+step:623/1670 train_time:60409ms step_avg:96.96ms
+step:624/1670 train_time:60505ms step_avg:96.96ms
+step:625/1670 train_time:60603ms step_avg:96.97ms
+step:625/1670 val_loss:3.6149 train_time:60701ms step_avg:97.12ms
+step:626/1670 train_time:60722ms step_avg:97.00ms
+step:627/1670 train_time:60808ms step_avg:96.98ms
+step:628/1670 train_time:60908ms step_avg:96.99ms
+step:629/1670 train_time:61005ms step_avg:96.99ms
+step:630/1670 train_time:61101ms step_avg:96.99ms
+step:631/1670 train_time:61197ms step_avg:96.98ms
+step:632/1670 train_time:61293ms step_avg:96.98ms
+step:633/1670 train_time:61389ms step_avg:96.98ms
+step:634/1670 train_time:61485ms step_avg:96.98ms
+step:635/1670 train_time:61581ms step_avg:96.98ms
+step:636/1670 train_time:61678ms step_avg:96.98ms
+step:637/1670 train_time:61776ms step_avg:96.98ms
+step:638/1670 train_time:61874ms step_avg:96.98ms
+step:639/1670 train_time:62168ms step_avg:97.29ms
+step:640/1670 train_time:62353ms step_avg:97.43ms
+step:641/1670 train_time:62448ms step_avg:97.42ms
+step:642/1670 train_time:62544ms step_avg:97.42ms
+step:643/1670 train_time:62640ms step_avg:97.42ms
+step:644/1670 train_time:62736ms step_avg:97.42ms
+step:645/1670 train_time:62832ms step_avg:97.41ms
+step:646/1670 train_time:62927ms step_avg:97.41ms
+step:647/1670 train_time:63023ms step_avg:97.41ms
+step:648/1670 train_time:63119ms step_avg:97.41ms
+step:649/1670 train_time:63218ms step_avg:97.41ms
+step:650/1670 train_time:63318ms step_avg:97.41ms
+step:651/1670 train_time:63417ms step_avg:97.41ms
+step:652/1670 train_time:63513ms step_avg:97.41ms
+step:653/1670 train_time:63609ms step_avg:97.41ms
+step:654/1670 train_time:63706ms step_avg:97.41ms
+step:655/1670 train_time:63803ms step_avg:97.41ms
+step:656/1670 train_time:63898ms step_avg:97.41ms
+step:657/1670 train_time:63993ms step_avg:97.40ms
+step:658/1670 train_time:64089ms step_avg:97.40ms
+step:659/1670 train_time:64188ms step_avg:97.40ms
+step:660/1670 train_time:64287ms step_avg:97.41ms
+step:661/1670 train_time:64386ms step_avg:97.41ms
+step:662/1670 train_time:64484ms step_avg:97.41ms
+step:663/1670 train_time:64581ms step_avg:97.41ms
+step:664/1670 train_time:64678ms step_avg:97.41ms
+step:665/1670 train_time:64774ms step_avg:97.40ms
+step:666/1670 train_time:64870ms step_avg:97.40ms
+step:667/1670 train_time:64967ms step_avg:97.40ms
+step:668/1670 train_time:65064ms step_avg:97.40ms
+step:669/1670 train_time:65161ms step_avg:97.40ms
+step:670/1670 train_time:65258ms step_avg:97.40ms
+step:671/1670 train_time:65355ms step_avg:97.40ms
+step:672/1670 train_time:65451ms step_avg:97.40ms
+step:673/1670 train_time:65549ms step_avg:97.40ms
+step:674/1670 train_time:65646ms step_avg:97.40ms
+step:675/1670 train_time:65744ms step_avg:97.40ms
+step:676/1670 train_time:65841ms step_avg:97.40ms
+step:677/1670 train_time:65938ms step_avg:97.40ms
+step:678/1670 train_time:66034ms step_avg:97.40ms
+step:679/1670 train_time:66130ms step_avg:97.39ms
+step:680/1670 train_time:66228ms step_avg:97.39ms
+step:681/1670 train_time:66327ms step_avg:97.40ms
+step:682/1670 train_time:66425ms step_avg:97.40ms
+step:683/1670 train_time:66523ms step_avg:97.40ms
+step:684/1670 train_time:66621ms step_avg:97.40ms
+step:685/1670 train_time:66717ms step_avg:97.40ms
+step:686/1670 train_time:66814ms step_avg:97.40ms
+step:687/1670 train_time:66911ms step_avg:97.40ms
+step:688/1670 train_time:67008ms step_avg:97.39ms
+step:689/1670 train_time:67104ms step_avg:97.39ms
+step:690/1670 train_time:67201ms step_avg:97.39ms
+step:691/1670 train_time:67298ms step_avg:97.39ms
+step:692/1670 train_time:67394ms step_avg:97.39ms
+step:693/1670 train_time:67491ms step_avg:97.39ms
+step:694/1670 train_time:67589ms step_avg:97.39ms
+step:695/1670 train_time:67686ms step_avg:97.39ms
+step:696/1670 train_time:67784ms step_avg:97.39ms
+step:697/1670 train_time:67881ms step_avg:97.39ms
+step:698/1670 train_time:67979ms step_avg:97.39ms
+step:699/1670 train_time:68075ms step_avg:97.39ms
+step:700/1670 train_time:68170ms step_avg:97.39ms
+step:701/1670 train_time:68267ms step_avg:97.39ms
+step:702/1670 train_time:68365ms step_avg:97.39ms
+step:703/1670 train_time:68463ms step_avg:97.39ms
+step:704/1670 train_time:68560ms step_avg:97.39ms
+step:705/1670 train_time:68657ms step_avg:97.39ms
+step:706/1670 train_time:68754ms step_avg:97.38ms
+step:707/1670 train_time:68851ms step_avg:97.38ms
+step:708/1670 train_time:68948ms step_avg:97.38ms
+step:709/1670 train_time:69046ms step_avg:97.38ms
+step:710/1670 train_time:69143ms step_avg:97.38ms
+step:711/1670 train_time:69240ms step_avg:97.38ms
+step:712/1670 train_time:69336ms step_avg:97.38ms
+step:713/1670 train_time:69433ms step_avg:97.38ms
+step:714/1670 train_time:69530ms step_avg:97.38ms
+step:715/1670 train_time:69628ms step_avg:97.38ms
+step:716/1670 train_time:69726ms step_avg:97.38ms
+step:717/1670 train_time:69824ms step_avg:97.38ms
+step:718/1670 train_time:69922ms step_avg:97.38ms
+step:719/1670 train_time:70018ms step_avg:97.38ms
+step:720/1670 train_time:70114ms step_avg:97.38ms
+step:721/1670 train_time:70210ms step_avg:97.38ms
+step:722/1670 train_time:70309ms step_avg:97.38ms
+step:723/1670 train_time:70406ms step_avg:97.38ms
+step:724/1670 train_time:70504ms step_avg:97.38ms
+step:725/1670 train_time:70601ms step_avg:97.38ms
+step:726/1670 train_time:70698ms step_avg:97.38ms
+step:727/1670 train_time:70795ms step_avg:97.38ms
+step:728/1670 train_time:70892ms step_avg:97.38ms
+step:729/1670 train_time:70989ms step_avg:97.38ms
+step:730/1670 train_time:71087ms step_avg:97.38ms
+step:731/1670 train_time:71184ms step_avg:97.38ms
+step:732/1670 train_time:71282ms step_avg:97.38ms
+step:733/1670 train_time:71379ms step_avg:97.38ms
+step:734/1670 train_time:71475ms step_avg:97.38ms
+step:735/1670 train_time:71572ms step_avg:97.38ms
+step:736/1670 train_time:71669ms step_avg:97.38ms
+step:737/1670 train_time:71767ms step_avg:97.38ms
+step:738/1670 train_time:71865ms step_avg:97.38ms
+step:739/1670 train_time:71962ms step_avg:97.38ms
+step:740/1670 train_time:72059ms step_avg:97.38ms
+step:741/1670 train_time:72156ms step_avg:97.38ms
+step:742/1670 train_time:72252ms step_avg:97.38ms
+step:743/1670 train_time:72350ms step_avg:97.37ms
+step:744/1670 train_time:72447ms step_avg:97.37ms
+step:745/1670 train_time:72545ms step_avg:97.38ms
+step:746/1670 train_time:72643ms step_avg:97.38ms
+step:747/1670 train_time:72740ms step_avg:97.38ms
+step:748/1670 train_time:72836ms step_avg:97.37ms
+step:749/1670 train_time:72932ms step_avg:97.37ms
+step:750/1670 train_time:73030ms step_avg:97.37ms
+step:750/1670 val_loss:3.5624 train_time:73128ms step_avg:97.50ms
+step:751/1670 train_time:73148ms step_avg:97.40ms
+step:752/1670 train_time:73231ms step_avg:97.38ms
+step:753/1670 train_time:73332ms step_avg:97.39ms
+step:754/1670 train_time:73429ms step_avg:97.39ms
+step:755/1670 train_time:73525ms step_avg:97.38ms
+step:756/1670 train_time:73621ms step_avg:97.38ms
+step:757/1670 train_time:73717ms step_avg:97.38ms
+step:758/1670 train_time:73813ms step_avg:97.38ms
+step:759/1670 train_time:73910ms step_avg:97.38ms
+step:760/1670 train_time:74008ms step_avg:97.38ms
+step:761/1670 train_time:74107ms step_avg:97.38ms
+step:762/1670 train_time:74208ms step_avg:97.39ms
+step:763/1670 train_time:74308ms step_avg:97.39ms
+step:764/1670 train_time:74406ms step_avg:97.39ms
+step:765/1670 train_time:74503ms step_avg:97.39ms
+step:766/1670 train_time:74599ms step_avg:97.39ms
+step:767/1670 train_time:74695ms step_avg:97.39ms
+step:768/1670 train_time:74791ms step_avg:97.38ms
+step:769/1670 train_time:74887ms step_avg:97.38ms
+step:770/1670 train_time:74983ms step_avg:97.38ms
+step:771/1670 train_time:75081ms step_avg:97.38ms
+step:772/1670 train_time:75179ms step_avg:97.38ms
+step:773/1670 train_time:75277ms step_avg:97.38ms
+step:774/1670 train_time:75375ms step_avg:97.38ms
+step:775/1670 train_time:75472ms step_avg:97.38ms
+step:776/1670 train_time:75569ms step_avg:97.38ms
+step:777/1670 train_time:75667ms step_avg:97.38ms
+step:778/1670 train_time:75763ms step_avg:97.38ms
+step:779/1670 train_time:75859ms step_avg:97.38ms
+step:780/1670 train_time:75955ms step_avg:97.38ms
+step:781/1670 train_time:76051ms step_avg:97.38ms
+step:782/1670 train_time:76150ms step_avg:97.38ms
+step:783/1670 train_time:76248ms step_avg:97.38ms
+step:784/1670 train_time:76348ms step_avg:97.38ms
+step:785/1670 train_time:76446ms step_avg:97.38ms
+step:786/1670 train_time:76543ms step_avg:97.38ms
+step:787/1670 train_time:76641ms step_avg:97.38ms
+step:788/1670 train_time:76737ms step_avg:97.38ms
+step:789/1670 train_time:76834ms step_avg:97.38ms
+step:790/1670 train_time:76931ms step_avg:97.38ms
+step:791/1670 train_time:77028ms step_avg:97.38ms
+step:792/1670 train_time:77125ms step_avg:97.38ms
+step:793/1670 train_time:77222ms step_avg:97.38ms
+step:794/1670 train_time:77320ms step_avg:97.38ms
+step:795/1670 train_time:77417ms step_avg:97.38ms
+step:796/1670 train_time:77515ms step_avg:97.38ms
+step:797/1670 train_time:77612ms step_avg:97.38ms
+step:798/1670 train_time:77709ms step_avg:97.38ms
+step:799/1670 train_time:77806ms step_avg:97.38ms
+step:800/1670 train_time:77903ms step_avg:97.38ms
+step:801/1670 train_time:78000ms step_avg:97.38ms
+step:802/1670 train_time:78096ms step_avg:97.38ms
+step:803/1670 train_time:78193ms step_avg:97.38ms
+step:804/1670 train_time:78290ms step_avg:97.38ms
+step:805/1670 train_time:78387ms step_avg:97.38ms
+step:806/1670 train_time:78484ms step_avg:97.38ms
+step:807/1670 train_time:78582ms step_avg:97.38ms
+step:808/1670 train_time:78680ms step_avg:97.38ms
+step:809/1670 train_time:78776ms step_avg:97.37ms
+step:810/1670 train_time:78872ms step_avg:97.37ms
+step:811/1670 train_time:78968ms step_avg:97.37ms
+step:812/1670 train_time:79065ms step_avg:97.37ms
+step:813/1670 train_time:79163ms step_avg:97.37ms
+step:814/1670 train_time:79260ms step_avg:97.37ms
+step:815/1670 train_time:79357ms step_avg:97.37ms
+step:816/1670 train_time:79453ms step_avg:97.37ms
+step:817/1670 train_time:79551ms step_avg:97.37ms
+step:818/1670 train_time:79649ms step_avg:97.37ms
+step:819/1670 train_time:79747ms step_avg:97.37ms
+step:820/1670 train_time:79844ms step_avg:97.37ms
+step:821/1670 train_time:79941ms step_avg:97.37ms
+step:822/1670 train_time:80038ms step_avg:97.37ms
+step:823/1670 train_time:80135ms step_avg:97.37ms
+step:824/1670 train_time:80231ms step_avg:97.37ms
+step:825/1670 train_time:80329ms step_avg:97.37ms
+step:826/1670 train_time:80426ms step_avg:97.37ms
+step:827/1670 train_time:80523ms step_avg:97.37ms
+step:828/1670 train_time:80621ms step_avg:97.37ms
+step:829/1670 train_time:80718ms step_avg:97.37ms
+step:830/1670 train_time:80814ms step_avg:97.37ms
+step:831/1670 train_time:80910ms step_avg:97.37ms
+step:832/1670 train_time:81008ms step_avg:97.37ms
+step:833/1670 train_time:81105ms step_avg:97.37ms
+step:834/1670 train_time:81203ms step_avg:97.37ms
+step:835/1670 train_time:81300ms step_avg:97.37ms
+step:836/1670 train_time:81396ms step_avg:97.36ms
+step:837/1670 train_time:81493ms step_avg:97.36ms
+step:838/1670 train_time:81590ms step_avg:97.36ms
+step:839/1670 train_time:81688ms step_avg:97.36ms
+step:840/1670 train_time:81787ms step_avg:97.37ms
+step:841/1670 train_time:81885ms step_avg:97.37ms
+step:842/1670 train_time:81983ms step_avg:97.37ms
+step:843/1670 train_time:82080ms step_avg:97.37ms
+step:844/1670 train_time:82176ms step_avg:97.36ms
+step:845/1670 train_time:82272ms step_avg:97.36ms
+step:846/1670 train_time:82369ms step_avg:97.36ms
+step:847/1670 train_time:82467ms step_avg:97.36ms
+step:848/1670 train_time:82564ms step_avg:97.36ms
+step:849/1670 train_time:82662ms step_avg:97.36ms
+step:850/1670 train_time:82759ms step_avg:97.36ms
+step:851/1670 train_time:83018ms step_avg:97.55ms
+step:852/1670 train_time:83183ms step_avg:97.63ms
+step:853/1670 train_time:83279ms step_avg:97.63ms
+step:854/1670 train_time:83375ms step_avg:97.63ms
+step:855/1670 train_time:83470ms step_avg:97.63ms
+step:856/1670 train_time:83566ms step_avg:97.62ms
+step:857/1670 train_time:83662ms step_avg:97.62ms
+step:858/1670 train_time:83758ms step_avg:97.62ms
+step:859/1670 train_time:83854ms step_avg:97.62ms
+step:860/1670 train_time:83950ms step_avg:97.62ms
+step:861/1670 train_time:84053ms step_avg:97.62ms
+step:862/1670 train_time:84152ms step_avg:97.62ms
+step:863/1670 train_time:84250ms step_avg:97.62ms
+step:864/1670 train_time:84347ms step_avg:97.62ms
+step:865/1670 train_time:84445ms step_avg:97.62ms
+step:866/1670 train_time:84541ms step_avg:97.62ms
+step:867/1670 train_time:84636ms step_avg:97.62ms
+step:868/1670 train_time:84732ms step_avg:97.62ms
+step:869/1670 train_time:84828ms step_avg:97.62ms
+step:870/1670 train_time:84926ms step_avg:97.62ms
+step:871/1670 train_time:85026ms step_avg:97.62ms
+step:872/1670 train_time:85125ms step_avg:97.62ms
+step:873/1670 train_time:85223ms step_avg:97.62ms
+step:874/1670 train_time:85321ms step_avg:97.62ms
+step:875/1670 train_time:85419ms step_avg:97.62ms
+step:875/1670 val_loss:3.5190 train_time:85515ms step_avg:97.73ms
+step:876/1670 train_time:85536ms step_avg:97.64ms
+step:877/1670 train_time:85620ms step_avg:97.63ms
+step:878/1670 train_time:85718ms step_avg:97.63ms
+step:879/1670 train_time:85814ms step_avg:97.63ms
+step:880/1670 train_time:85911ms step_avg:97.63ms
+step:881/1670 train_time:86007ms step_avg:97.62ms
+step:882/1670 train_time:86102ms step_avg:97.62ms
+step:883/1670 train_time:86198ms step_avg:97.62ms
+step:884/1670 train_time:86294ms step_avg:97.62ms
+step:885/1670 train_time:86390ms step_avg:97.62ms
+step:886/1670 train_time:86489ms step_avg:97.62ms
+step:887/1670 train_time:86589ms step_avg:97.62ms
+step:888/1670 train_time:86690ms step_avg:97.62ms
+step:889/1670 train_time:86789ms step_avg:97.63ms
+step:890/1670 train_time:86886ms step_avg:97.62ms
+step:891/1670 train_time:86982ms step_avg:97.62ms
+step:892/1670 train_time:87078ms step_avg:97.62ms
+step:893/1670 train_time:87174ms step_avg:97.62ms
+step:894/1670 train_time:87270ms step_avg:97.62ms
+step:895/1670 train_time:87367ms step_avg:97.62ms
+step:896/1670 train_time:87464ms step_avg:97.62ms
+step:897/1670 train_time:87563ms step_avg:97.62ms
+step:898/1670 train_time:87661ms step_avg:97.62ms
+step:899/1670 train_time:87758ms step_avg:97.62ms
+step:900/1670 train_time:87855ms step_avg:97.62ms
+step:901/1670 train_time:87953ms step_avg:97.62ms
+step:902/1670 train_time:88050ms step_avg:97.62ms
+step:903/1670 train_time:88147ms step_avg:97.62ms
+step:904/1670 train_time:88243ms step_avg:97.61ms
+step:905/1670 train_time:88339ms step_avg:97.61ms
+step:906/1670 train_time:88435ms step_avg:97.61ms
+step:907/1670 train_time:88533ms step_avg:97.61ms
+step:908/1670 train_time:88632ms step_avg:97.61ms
+step:909/1670 train_time:88730ms step_avg:97.61ms
+step:910/1670 train_time:88829ms step_avg:97.61ms
+step:911/1670 train_time:88927ms step_avg:97.61ms
+step:912/1670 train_time:89023ms step_avg:97.61ms
+step:913/1670 train_time:89120ms step_avg:97.61ms
+step:914/1670 train_time:89216ms step_avg:97.61ms
+step:915/1670 train_time:89312ms step_avg:97.61ms
+step:916/1670 train_time:89409ms step_avg:97.61ms
+step:917/1670 train_time:89507ms step_avg:97.61ms
+step:918/1670 train_time:89605ms step_avg:97.61ms
+step:919/1670 train_time:89702ms step_avg:97.61ms
+step:920/1670 train_time:89799ms step_avg:97.61ms
+step:921/1670 train_time:89897ms step_avg:97.61ms
+step:922/1670 train_time:89994ms step_avg:97.61ms
+step:923/1670 train_time:90092ms step_avg:97.61ms
+step:924/1670 train_time:90189ms step_avg:97.61ms
+step:925/1670 train_time:90286ms step_avg:97.61ms
+step:926/1670 train_time:90382ms step_avg:97.61ms
+step:927/1670 train_time:90479ms step_avg:97.60ms
+step:928/1670 train_time:90575ms step_avg:97.60ms
+step:929/1670 train_time:90674ms step_avg:97.60ms
+step:930/1670 train_time:90773ms step_avg:97.60ms
+step:931/1670 train_time:90871ms step_avg:97.61ms
+step:932/1670 train_time:90969ms step_avg:97.61ms
+step:933/1670 train_time:91066ms step_avg:97.61ms
+step:934/1670 train_time:91163ms step_avg:97.61ms
+step:935/1670 train_time:91259ms step_avg:97.60ms
+step:936/1670 train_time:91355ms step_avg:97.60ms
+step:937/1670 train_time:91452ms step_avg:97.60ms
+step:938/1670 train_time:91550ms step_avg:97.60ms
+step:939/1670 train_time:91648ms step_avg:97.60ms
+step:940/1670 train_time:91746ms step_avg:97.60ms
+step:941/1670 train_time:91843ms step_avg:97.60ms
+step:942/1670 train_time:91941ms step_avg:97.60ms
+step:943/1670 train_time:92037ms step_avg:97.60ms
+step:944/1670 train_time:92134ms step_avg:97.60ms
+step:945/1670 train_time:92232ms step_avg:97.60ms
+step:946/1670 train_time:92329ms step_avg:97.60ms
+step:947/1670 train_time:92427ms step_avg:97.60ms
+step:948/1670 train_time:92525ms step_avg:97.60ms
+step:949/1670 train_time:92622ms step_avg:97.60ms
+step:950/1670 train_time:92717ms step_avg:97.60ms
+step:951/1670 train_time:92814ms step_avg:97.60ms
+step:952/1670 train_time:92912ms step_avg:97.60ms
+step:953/1670 train_time:93010ms step_avg:97.60ms
+step:954/1670 train_time:93107ms step_avg:97.60ms
+step:955/1670 train_time:93204ms step_avg:97.60ms
+step:956/1670 train_time:93300ms step_avg:97.59ms
+step:957/1670 train_time:93398ms step_avg:97.59ms
+step:958/1670 train_time:93495ms step_avg:97.59ms
+step:959/1670 train_time:93592ms step_avg:97.59ms
+step:960/1670 train_time:93689ms step_avg:97.59ms
+step:961/1670 train_time:93786ms step_avg:97.59ms
+step:962/1670 train_time:93884ms step_avg:97.59ms
+step:963/1670 train_time:93981ms step_avg:97.59ms
+step:964/1670 train_time:94077ms step_avg:97.59ms
+step:965/1670 train_time:94174ms step_avg:97.59ms
+step:966/1670 train_time:94273ms step_avg:97.59ms
+step:967/1670 train_time:94371ms step_avg:97.59ms
+step:968/1670 train_time:94469ms step_avg:97.59ms
+step:969/1670 train_time:94567ms step_avg:97.59ms
+step:970/1670 train_time:94663ms step_avg:97.59ms
+step:971/1670 train_time:94760ms step_avg:97.59ms
+step:972/1670 train_time:94856ms step_avg:97.59ms
+step:973/1670 train_time:94953ms step_avg:97.59ms
+step:974/1670 train_time:95051ms step_avg:97.59ms
+step:975/1670 train_time:95150ms step_avg:97.59ms
+step:976/1670 train_time:95248ms step_avg:97.59ms
+step:977/1670 train_time:95346ms step_avg:97.59ms
+step:978/1670 train_time:95443ms step_avg:97.59ms
+step:979/1670 train_time:95540ms step_avg:97.59ms
+step:980/1670 train_time:95636ms step_avg:97.59ms
+step:981/1670 train_time:95733ms step_avg:97.59ms
+step:982/1670 train_time:95831ms step_avg:97.59ms
+step:983/1670 train_time:95929ms step_avg:97.59ms
+step:984/1670 train_time:96026ms step_avg:97.59ms
+step:985/1670 train_time:96122ms step_avg:97.59ms
+step:986/1670 train_time:96219ms step_avg:97.58ms
+step:987/1670 train_time:96315ms step_avg:97.58ms
+step:988/1670 train_time:96413ms step_avg:97.58ms
+step:989/1670 train_time:96511ms step_avg:97.58ms
+step:990/1670 train_time:96609ms step_avg:97.59ms
+step:991/1670 train_time:96707ms step_avg:97.59ms
+step:992/1670 train_time:96804ms step_avg:97.58ms
+step:993/1670 train_time:96901ms step_avg:97.58ms
+step:994/1670 train_time:96997ms step_avg:97.58ms
+step:995/1670 train_time:97094ms step_avg:97.58ms
+step:996/1670 train_time:97192ms step_avg:97.58ms
+step:997/1670 train_time:97290ms step_avg:97.58ms
+step:998/1670 train_time:97387ms step_avg:97.58ms
+step:999/1670 train_time:97485ms step_avg:97.58ms
+step:1000/1670 train_time:97582ms step_avg:97.58ms
+step:1000/1670 val_loss:3.4774 train_time:97679ms step_avg:97.68ms
+step:1001/1670 train_time:97700ms step_avg:97.60ms
+step:1002/1670 train_time:97786ms step_avg:97.59ms
+step:1003/1670 train_time:97886ms step_avg:97.59ms
+step:1004/1670 train_time:97984ms step_avg:97.59ms
+step:1005/1670 train_time:98079ms step_avg:97.59ms
+step:1006/1670 train_time:98175ms step_avg:97.59ms
+step:1007/1670 train_time:98271ms step_avg:97.59ms
+step:1008/1670 train_time:98367ms step_avg:97.59ms
+step:1009/1670 train_time:98464ms step_avg:97.59ms
+step:1010/1670 train_time:98560ms step_avg:97.58ms
+step:1011/1670 train_time:98657ms step_avg:97.58ms
+step:1012/1670 train_time:98756ms step_avg:97.58ms
+step:1013/1670 train_time:98854ms step_avg:97.59ms
+step:1014/1670 train_time:98953ms step_avg:97.59ms
+step:1015/1670 train_time:99050ms step_avg:97.59ms
+step:1016/1670 train_time:99148ms step_avg:97.59ms
+step:1017/1670 train_time:99244ms step_avg:97.59ms
+step:1018/1670 train_time:99341ms step_avg:97.58ms
+step:1019/1670 train_time:99437ms step_avg:97.58ms
+step:1020/1670 train_time:99533ms step_avg:97.58ms
+step:1021/1670 train_time:99631ms step_avg:97.58ms
+step:1022/1670 train_time:99729ms step_avg:97.58ms
+step:1023/1670 train_time:99828ms step_avg:97.58ms
+step:1024/1670 train_time:99928ms step_avg:97.59ms
+step:1025/1670 train_time:100026ms step_avg:97.59ms
+step:1026/1670 train_time:100123ms step_avg:97.59ms
+step:1027/1670 train_time:100219ms step_avg:97.58ms
+step:1028/1670 train_time:100316ms step_avg:97.58ms
+step:1029/1670 train_time:100412ms step_avg:97.58ms
+step:1030/1670 train_time:100509ms step_avg:97.58ms
+step:1031/1670 train_time:100605ms step_avg:97.58ms
+step:1032/1670 train_time:100702ms step_avg:97.58ms
+step:1033/1670 train_time:100799ms step_avg:97.58ms
+step:1034/1670 train_time:100897ms step_avg:97.58ms
+step:1035/1670 train_time:100994ms step_avg:97.58ms
+step:1036/1670 train_time:101091ms step_avg:97.58ms
+step:1037/1670 train_time:101190ms step_avg:97.58ms
+step:1038/1670 train_time:101288ms step_avg:97.58ms
+step:1039/1670 train_time:101385ms step_avg:97.58ms
+step:1040/1670 train_time:101482ms step_avg:97.58ms
+step:1041/1670 train_time:101578ms step_avg:97.58ms
+step:1042/1670 train_time:101674ms step_avg:97.58ms
+step:1043/1670 train_time:101772ms step_avg:97.58ms
+step:1044/1670 train_time:101871ms step_avg:97.58ms
+step:1045/1670 train_time:101970ms step_avg:97.58ms
+step:1046/1670 train_time:102068ms step_avg:97.58ms
+step:1047/1670 train_time:102165ms step_avg:97.58ms
+step:1048/1670 train_time:102263ms step_avg:97.58ms
+step:1049/1670 train_time:102360ms step_avg:97.58ms
+step:1050/1670 train_time:102456ms step_avg:97.58ms
+step:1051/1670 train_time:102552ms step_avg:97.58ms
+step:1052/1670 train_time:102649ms step_avg:97.58ms
+step:1053/1670 train_time:102747ms step_avg:97.58ms
+step:1054/1670 train_time:102844ms step_avg:97.58ms
+step:1055/1670 train_time:102942ms step_avg:97.57ms
+step:1056/1670 train_time:103038ms step_avg:97.57ms
+step:1057/1670 train_time:103135ms step_avg:97.57ms
+step:1058/1670 train_time:103233ms step_avg:97.57ms
+step:1059/1670 train_time:103331ms step_avg:97.57ms
+step:1060/1670 train_time:103429ms step_avg:97.57ms
+step:1061/1670 train_time:103527ms step_avg:97.57ms
+step:1062/1670 train_time:103793ms step_avg:97.73ms
+step:1063/1670 train_time:103890ms step_avg:97.73ms
+step:1064/1670 train_time:103986ms step_avg:97.73ms
+step:1065/1670 train_time:104081ms step_avg:97.73ms
+step:1066/1670 train_time:104177ms step_avg:97.73ms
+step:1067/1670 train_time:104272ms step_avg:97.72ms
+step:1068/1670 train_time:104369ms step_avg:97.72ms
+step:1069/1670 train_time:104466ms step_avg:97.72ms
+step:1070/1670 train_time:104561ms step_avg:97.72ms
+step:1071/1670 train_time:104657ms step_avg:97.72ms
+step:1072/1670 train_time:104756ms step_avg:97.72ms
+step:1073/1670 train_time:104856ms step_avg:97.72ms
+step:1074/1670 train_time:104955ms step_avg:97.72ms
+step:1075/1670 train_time:105052ms step_avg:97.72ms
+step:1076/1670 train_time:105150ms step_avg:97.72ms
+step:1077/1670 train_time:105246ms step_avg:97.72ms
+step:1078/1670 train_time:105342ms step_avg:97.72ms
+step:1079/1670 train_time:105438ms step_avg:97.72ms
+step:1080/1670 train_time:105534ms step_avg:97.72ms
+step:1081/1670 train_time:105631ms step_avg:97.72ms
+step:1082/1670 train_time:105729ms step_avg:97.72ms
+step:1083/1670 train_time:105830ms step_avg:97.72ms
+step:1084/1670 train_time:105929ms step_avg:97.72ms
+step:1085/1670 train_time:106026ms step_avg:97.72ms
+step:1086/1670 train_time:106124ms step_avg:97.72ms
+step:1087/1670 train_time:106220ms step_avg:97.72ms
+step:1088/1670 train_time:106316ms step_avg:97.72ms
+step:1089/1670 train_time:106412ms step_avg:97.72ms
+step:1090/1670 train_time:106510ms step_avg:97.72ms
+step:1091/1670 train_time:106606ms step_avg:97.71ms
+step:1092/1670 train_time:106703ms step_avg:97.71ms
+step:1093/1670 train_time:106801ms step_avg:97.71ms
+step:1094/1670 train_time:106899ms step_avg:97.71ms
+step:1095/1670 train_time:106995ms step_avg:97.71ms
+step:1096/1670 train_time:107093ms step_avg:97.71ms
+step:1097/1670 train_time:107191ms step_avg:97.71ms
+step:1098/1670 train_time:107288ms step_avg:97.71ms
+step:1099/1670 train_time:107384ms step_avg:97.71ms
+step:1100/1670 train_time:107480ms step_avg:97.71ms
+step:1101/1670 train_time:107576ms step_avg:97.71ms
+step:1102/1670 train_time:107673ms step_avg:97.71ms
+step:1103/1670 train_time:107771ms step_avg:97.71ms
+step:1104/1670 train_time:107869ms step_avg:97.71ms
+step:1105/1670 train_time:107967ms step_avg:97.71ms
+step:1106/1670 train_time:108065ms step_avg:97.71ms
+step:1107/1670 train_time:108162ms step_avg:97.71ms
+step:1108/1670 train_time:108259ms step_avg:97.71ms
+step:1109/1670 train_time:108355ms step_avg:97.71ms
+step:1110/1670 train_time:108452ms step_avg:97.70ms
+step:1111/1670 train_time:108549ms step_avg:97.70ms
+step:1112/1670 train_time:108646ms step_avg:97.70ms
+step:1113/1670 train_time:108743ms step_avg:97.70ms
+step:1114/1670 train_time:108840ms step_avg:97.70ms
+step:1115/1670 train_time:108937ms step_avg:97.70ms
+step:1116/1670 train_time:109034ms step_avg:97.70ms
+step:1117/1670 train_time:109132ms step_avg:97.70ms
+step:1118/1670 train_time:109232ms step_avg:97.70ms
+step:1119/1670 train_time:109331ms step_avg:97.70ms
+step:1120/1670 train_time:109429ms step_avg:97.70ms
+step:1121/1670 train_time:109528ms step_avg:97.71ms
+step:1122/1670 train_time:109626ms step_avg:97.71ms
+step:1123/1670 train_time:109725ms step_avg:97.71ms
+step:1124/1670 train_time:109823ms step_avg:97.71ms
+step:1125/1670 train_time:109920ms step_avg:97.71ms
+step:1125/1670 val_loss:3.4232 train_time:110017ms step_avg:97.79ms
+step:1126/1670 train_time:110039ms step_avg:97.73ms
+step:1127/1670 train_time:110126ms step_avg:97.72ms
+step:1128/1670 train_time:110224ms step_avg:97.72ms
+step:1129/1670 train_time:110320ms step_avg:97.72ms
+step:1130/1670 train_time:110417ms step_avg:97.71ms
+step:1131/1670 train_time:110514ms step_avg:97.71ms
+step:1132/1670 train_time:110611ms step_avg:97.71ms
+step:1133/1670 train_time:110708ms step_avg:97.71ms
+step:1134/1670 train_time:110805ms step_avg:97.71ms
+step:1135/1670 train_time:110901ms step_avg:97.71ms
+step:1136/1670 train_time:111003ms step_avg:97.71ms
+step:1137/1670 train_time:111103ms step_avg:97.72ms
+step:1138/1670 train_time:111201ms step_avg:97.72ms
+step:1139/1670 train_time:111299ms step_avg:97.72ms
+step:1140/1670 train_time:111396ms step_avg:97.72ms
+step:1141/1670 train_time:111492ms step_avg:97.71ms
+step:1142/1670 train_time:111590ms step_avg:97.71ms
+step:1143/1670 train_time:111686ms step_avg:97.71ms
+step:1144/1670 train_time:111783ms step_avg:97.71ms
+step:1145/1670 train_time:111881ms step_avg:97.71ms
+step:1146/1670 train_time:111981ms step_avg:97.71ms
+step:1147/1670 train_time:112082ms step_avg:97.72ms
+step:1148/1670 train_time:112181ms step_avg:97.72ms
+step:1149/1670 train_time:112280ms step_avg:97.72ms
+step:1150/1670 train_time:112378ms step_avg:97.72ms
+step:1151/1670 train_time:112475ms step_avg:97.72ms
+step:1152/1670 train_time:112572ms step_avg:97.72ms
+step:1153/1670 train_time:112669ms step_avg:97.72ms
+step:1154/1670 train_time:112766ms step_avg:97.72ms
+step:1155/1670 train_time:112863ms step_avg:97.72ms
+step:1156/1670 train_time:112961ms step_avg:97.72ms
+step:1157/1670 train_time:113061ms step_avg:97.72ms
+step:1158/1670 train_time:113160ms step_avg:97.72ms
+step:1159/1670 train_time:113259ms step_avg:97.72ms
+step:1160/1670 train_time:113357ms step_avg:97.72ms
+step:1161/1670 train_time:113455ms step_avg:97.72ms
+step:1162/1670 train_time:113552ms step_avg:97.72ms
+step:1163/1670 train_time:113649ms step_avg:97.72ms
+step:1164/1670 train_time:113747ms step_avg:97.72ms
+step:1165/1670 train_time:113843ms step_avg:97.72ms
+step:1166/1670 train_time:113940ms step_avg:97.72ms
+step:1167/1670 train_time:114039ms step_avg:97.72ms
+step:1168/1670 train_time:114138ms step_avg:97.72ms
+step:1169/1670 train_time:114237ms step_avg:97.72ms
+step:1170/1670 train_time:114334ms step_avg:97.72ms
+step:1171/1670 train_time:114433ms step_avg:97.72ms
+step:1172/1670 train_time:114530ms step_avg:97.72ms
+step:1173/1670 train_time:114627ms step_avg:97.72ms
+step:1174/1670 train_time:114724ms step_avg:97.72ms
+step:1175/1670 train_time:114821ms step_avg:97.72ms
+step:1176/1670 train_time:114919ms step_avg:97.72ms
+step:1177/1670 train_time:115017ms step_avg:97.72ms
+step:1178/1670 train_time:115116ms step_avg:97.72ms
+step:1179/1670 train_time:115215ms step_avg:97.72ms
+step:1180/1670 train_time:115313ms step_avg:97.72ms
+step:1181/1670 train_time:115410ms step_avg:97.72ms
+step:1182/1670 train_time:115507ms step_avg:97.72ms
+step:1183/1670 train_time:115604ms step_avg:97.72ms
+step:1184/1670 train_time:115701ms step_avg:97.72ms
+step:1185/1670 train_time:115798ms step_avg:97.72ms
+step:1186/1670 train_time:115897ms step_avg:97.72ms
+step:1187/1670 train_time:115995ms step_avg:97.72ms
+step:1188/1670 train_time:116092ms step_avg:97.72ms
+step:1189/1670 train_time:116190ms step_avg:97.72ms
+step:1190/1670 train_time:116288ms step_avg:97.72ms
+step:1191/1670 train_time:116385ms step_avg:97.72ms
+step:1192/1670 train_time:116483ms step_avg:97.72ms
+step:1193/1670 train_time:116581ms step_avg:97.72ms
+step:1194/1670 train_time:116679ms step_avg:97.72ms
+step:1195/1670 train_time:116777ms step_avg:97.72ms
+step:1196/1670 train_time:116874ms step_avg:97.72ms
+step:1197/1670 train_time:116972ms step_avg:97.72ms
+step:1198/1670 train_time:117069ms step_avg:97.72ms
+step:1199/1670 train_time:117167ms step_avg:97.72ms
+step:1200/1670 train_time:117264ms step_avg:97.72ms
+step:1201/1670 train_time:117361ms step_avg:97.72ms
+step:1202/1670 train_time:117460ms step_avg:97.72ms
+step:1203/1670 train_time:117559ms step_avg:97.72ms
+step:1204/1670 train_time:117658ms step_avg:97.72ms
+step:1205/1670 train_time:117756ms step_avg:97.72ms
+step:1206/1670 train_time:117854ms step_avg:97.72ms
+step:1207/1670 train_time:117951ms step_avg:97.72ms
+step:1208/1670 train_time:118050ms step_avg:97.72ms
+step:1209/1670 train_time:118147ms step_avg:97.72ms
+step:1210/1670 train_time:118244ms step_avg:97.72ms
+step:1211/1670 train_time:118342ms step_avg:97.72ms
+step:1212/1670 train_time:118439ms step_avg:97.72ms
+step:1213/1670 train_time:118537ms step_avg:97.72ms
+step:1214/1670 train_time:118635ms step_avg:97.72ms
+step:1215/1670 train_time:118733ms step_avg:97.72ms
+step:1216/1670 train_time:118830ms step_avg:97.72ms
+step:1217/1670 train_time:118928ms step_avg:97.72ms
+step:1218/1670 train_time:119026ms step_avg:97.72ms
+step:1219/1670 train_time:119124ms step_avg:97.72ms
+step:1220/1670 train_time:119221ms step_avg:97.72ms
+step:1221/1670 train_time:119319ms step_avg:97.72ms
+step:1222/1670 train_time:119417ms step_avg:97.72ms
+step:1223/1670 train_time:119515ms step_avg:97.72ms
+step:1224/1670 train_time:119613ms step_avg:97.72ms
+step:1225/1670 train_time:119710ms step_avg:97.72ms
+step:1226/1670 train_time:119807ms step_avg:97.72ms
+step:1227/1670 train_time:119903ms step_avg:97.72ms
+step:1228/1670 train_time:120002ms step_avg:97.72ms
+step:1229/1670 train_time:120100ms step_avg:97.72ms
+step:1230/1670 train_time:120199ms step_avg:97.72ms
+step:1231/1670 train_time:120297ms step_avg:97.72ms
+step:1232/1670 train_time:120396ms step_avg:97.72ms
+step:1233/1670 train_time:120494ms step_avg:97.72ms
+step:1234/1670 train_time:120592ms step_avg:97.72ms
+step:1235/1670 train_time:120690ms step_avg:97.72ms
+step:1236/1670 train_time:120787ms step_avg:97.72ms
+step:1237/1670 train_time:120884ms step_avg:97.72ms
+step:1238/1670 train_time:120982ms step_avg:97.72ms
+step:1239/1670 train_time:121080ms step_avg:97.72ms
+step:1240/1670 train_time:121179ms step_avg:97.72ms
+step:1241/1670 train_time:121277ms step_avg:97.73ms
+step:1242/1670 train_time:121376ms step_avg:97.73ms
+step:1243/1670 train_time:121473ms step_avg:97.73ms
+step:1244/1670 train_time:121571ms step_avg:97.73ms
+step:1245/1670 train_time:121668ms step_avg:97.73ms
+step:1246/1670 train_time:121765ms step_avg:97.72ms
+step:1247/1670 train_time:121863ms step_avg:97.72ms
+step:1248/1670 train_time:121961ms step_avg:97.73ms
+step:1249/1670 train_time:122059ms step_avg:97.73ms
+step:1250/1670 train_time:122157ms step_avg:97.73ms
+step:1250/1670 val_loss:3.3799 train_time:122255ms step_avg:97.80ms
+step:1251/1670 train_time:122276ms step_avg:97.74ms
+step:1252/1670 train_time:122359ms step_avg:97.73ms
+step:1253/1670 train_time:122462ms step_avg:97.73ms
+step:1254/1670 train_time:122559ms step_avg:97.73ms
+step:1255/1670 train_time:122655ms step_avg:97.73ms
+step:1256/1670 train_time:122752ms step_avg:97.73ms
+step:1257/1670 train_time:122849ms step_avg:97.73ms
+step:1258/1670 train_time:122946ms step_avg:97.73ms
+step:1259/1670 train_time:123043ms step_avg:97.73ms
+step:1260/1670 train_time:123139ms step_avg:97.73ms
+step:1261/1670 train_time:123237ms step_avg:97.73ms
+step:1262/1670 train_time:123336ms step_avg:97.73ms
+step:1263/1670 train_time:123436ms step_avg:97.73ms
+step:1264/1670 train_time:123535ms step_avg:97.73ms
+step:1265/1670 train_time:123632ms step_avg:97.73ms
+step:1266/1670 train_time:123729ms step_avg:97.73ms
+step:1267/1670 train_time:123827ms step_avg:97.73ms
+step:1268/1670 train_time:123924ms step_avg:97.73ms
+step:1269/1670 train_time:124021ms step_avg:97.73ms
+step:1270/1670 train_time:124118ms step_avg:97.73ms
+step:1271/1670 train_time:124215ms step_avg:97.73ms
+step:1272/1670 train_time:124317ms step_avg:97.73ms
+step:1273/1670 train_time:124416ms step_avg:97.73ms
+step:1274/1670 train_time:124688ms step_avg:97.87ms
+step:1275/1670 train_time:124901ms step_avg:97.96ms
+step:1276/1670 train_time:124996ms step_avg:97.96ms
+step:1277/1670 train_time:125093ms step_avg:97.96ms
+step:1278/1670 train_time:125190ms step_avg:97.96ms
+step:1279/1670 train_time:125286ms step_avg:97.96ms
+step:1280/1670 train_time:125383ms step_avg:97.96ms
+step:1281/1670 train_time:125479ms step_avg:97.95ms
+step:1282/1670 train_time:125576ms step_avg:97.95ms
+step:1283/1670 train_time:125675ms step_avg:97.95ms
+step:1284/1670 train_time:125776ms step_avg:97.96ms
+step:1285/1670 train_time:125878ms step_avg:97.96ms
+step:1286/1670 train_time:125977ms step_avg:97.96ms
+step:1287/1670 train_time:126074ms step_avg:97.96ms
+step:1288/1670 train_time:126171ms step_avg:97.96ms
+step:1289/1670 train_time:126269ms step_avg:97.96ms
+step:1290/1670 train_time:126366ms step_avg:97.96ms
+step:1291/1670 train_time:126463ms step_avg:97.96ms
+step:1292/1670 train_time:126560ms step_avg:97.96ms
+step:1293/1670 train_time:126657ms step_avg:97.96ms
+step:1294/1670 train_time:126757ms step_avg:97.96ms
+step:1295/1670 train_time:126857ms step_avg:97.96ms
+step:1296/1670 train_time:126956ms step_avg:97.96ms
+step:1297/1670 train_time:127055ms step_avg:97.96ms
+step:1298/1670 train_time:127153ms step_avg:97.96ms
+step:1299/1670 train_time:127251ms step_avg:97.96ms
+step:1300/1670 train_time:127348ms step_avg:97.96ms
+step:1301/1670 train_time:127446ms step_avg:97.96ms
+step:1302/1670 train_time:127543ms step_avg:97.96ms
+step:1303/1670 train_time:127640ms step_avg:97.96ms
+step:1304/1670 train_time:127738ms step_avg:97.96ms
+step:1305/1670 train_time:127835ms step_avg:97.96ms
+step:1306/1670 train_time:127935ms step_avg:97.96ms
+step:1307/1670 train_time:128033ms step_avg:97.96ms
+step:1308/1670 train_time:128131ms step_avg:97.96ms
+step:1309/1670 train_time:128228ms step_avg:97.96ms
+step:1310/1670 train_time:128326ms step_avg:97.96ms
+step:1311/1670 train_time:128423ms step_avg:97.96ms
+step:1312/1670 train_time:128520ms step_avg:97.96ms
+step:1313/1670 train_time:128617ms step_avg:97.96ms
+step:1314/1670 train_time:128715ms step_avg:97.96ms
+step:1315/1670 train_time:128815ms step_avg:97.96ms
+step:1316/1670 train_time:128914ms step_avg:97.96ms
+step:1317/1670 train_time:129012ms step_avg:97.96ms
+step:1318/1670 train_time:129110ms step_avg:97.96ms
+step:1319/1670 train_time:129208ms step_avg:97.96ms
+step:1320/1670 train_time:129305ms step_avg:97.96ms
+step:1321/1670 train_time:129403ms step_avg:97.96ms
+step:1322/1670 train_time:129499ms step_avg:97.96ms
+step:1323/1670 train_time:129597ms step_avg:97.96ms
+step:1324/1670 train_time:129695ms step_avg:97.96ms
+step:1325/1670 train_time:129793ms step_avg:97.96ms
+step:1326/1670 train_time:129892ms step_avg:97.96ms
+step:1327/1670 train_time:129990ms step_avg:97.96ms
+step:1328/1670 train_time:130088ms step_avg:97.96ms
+step:1329/1670 train_time:130185ms step_avg:97.96ms
+step:1330/1670 train_time:130282ms step_avg:97.96ms
+step:1331/1670 train_time:130380ms step_avg:97.96ms
+step:1332/1670 train_time:130477ms step_avg:97.96ms
+step:1333/1670 train_time:130575ms step_avg:97.96ms
+step:1334/1670 train_time:130673ms step_avg:97.96ms
+step:1335/1670 train_time:130772ms step_avg:97.96ms
+step:1336/1670 train_time:130870ms step_avg:97.96ms
+step:1337/1670 train_time:130968ms step_avg:97.96ms
+step:1338/1670 train_time:131066ms step_avg:97.96ms
+step:1339/1670 train_time:131164ms step_avg:97.96ms
+step:1340/1670 train_time:131261ms step_avg:97.96ms
+step:1341/1670 train_time:131359ms step_avg:97.96ms
+step:1342/1670 train_time:131457ms step_avg:97.96ms
+step:1343/1670 train_time:131555ms step_avg:97.96ms
+step:1344/1670 train_time:131653ms step_avg:97.96ms
+step:1345/1670 train_time:131751ms step_avg:97.96ms
+step:1346/1670 train_time:131849ms step_avg:97.96ms
+step:1347/1670 train_time:131947ms step_avg:97.96ms
+step:1348/1670 train_time:132045ms step_avg:97.96ms
+step:1349/1670 train_time:132143ms step_avg:97.96ms
+step:1350/1670 train_time:132240ms step_avg:97.96ms
+step:1351/1670 train_time:132337ms step_avg:97.95ms
+step:1352/1670 train_time:132435ms step_avg:97.95ms
+step:1353/1670 train_time:132533ms step_avg:97.96ms
+step:1354/1670 train_time:132632ms step_avg:97.96ms
+step:1355/1670 train_time:132730ms step_avg:97.96ms
+step:1356/1670 train_time:132828ms step_avg:97.96ms
+step:1357/1670 train_time:132925ms step_avg:97.96ms
+step:1358/1670 train_time:133023ms step_avg:97.96ms
+step:1359/1670 train_time:133120ms step_avg:97.95ms
+step:1360/1670 train_time:133218ms step_avg:97.95ms
+step:1361/1670 train_time:133316ms step_avg:97.95ms
+step:1362/1670 train_time:133414ms step_avg:97.95ms
+step:1363/1670 train_time:133513ms step_avg:97.96ms
+step:1364/1670 train_time:133612ms step_avg:97.96ms
+step:1365/1670 train_time:133709ms step_avg:97.96ms
+step:1366/1670 train_time:133806ms step_avg:97.95ms
+step:1367/1670 train_time:133903ms step_avg:97.95ms
+step:1368/1670 train_time:134000ms step_avg:97.95ms
+step:1369/1670 train_time:134098ms step_avg:97.95ms
+step:1370/1670 train_time:134196ms step_avg:97.95ms
+step:1371/1670 train_time:134294ms step_avg:97.95ms
+step:1372/1670 train_time:134392ms step_avg:97.95ms
+step:1373/1670 train_time:134490ms step_avg:97.95ms
+step:1374/1670 train_time:134589ms step_avg:97.95ms
+step:1375/1670 train_time:134687ms step_avg:97.95ms
+step:1375/1670 val_loss:3.3433 train_time:134784ms step_avg:98.02ms
+step:1376/1670 train_time:134806ms step_avg:97.97ms
+step:1377/1670 train_time:134891ms step_avg:97.96ms
+step:1378/1670 train_time:134991ms step_avg:97.96ms
+step:1379/1670 train_time:135087ms step_avg:97.96ms
+step:1380/1670 train_time:135185ms step_avg:97.96ms
+step:1381/1670 train_time:135283ms step_avg:97.96ms
+step:1382/1670 train_time:135382ms step_avg:97.96ms
+step:1383/1670 train_time:135478ms step_avg:97.96ms
+step:1384/1670 train_time:135575ms step_avg:97.96ms
+step:1385/1670 train_time:135671ms step_avg:97.96ms
+step:1386/1670 train_time:135770ms step_avg:97.96ms
+step:1387/1670 train_time:135870ms step_avg:97.96ms
+step:1388/1670 train_time:135970ms step_avg:97.96ms
+step:1389/1670 train_time:136068ms step_avg:97.96ms
+step:1390/1670 train_time:136166ms step_avg:97.96ms
+step:1391/1670 train_time:136264ms step_avg:97.96ms
+step:1392/1670 train_time:136362ms step_avg:97.96ms
+step:1393/1670 train_time:136459ms step_avg:97.96ms
+step:1394/1670 train_time:136556ms step_avg:97.96ms
+step:1395/1670 train_time:136653ms step_avg:97.96ms
+step:1396/1670 train_time:136750ms step_avg:97.96ms
+step:1397/1670 train_time:136849ms step_avg:97.96ms
+step:1398/1670 train_time:136947ms step_avg:97.96ms
+step:1399/1670 train_time:137046ms step_avg:97.96ms
+step:1400/1670 train_time:137143ms step_avg:97.96ms
+step:1401/1670 train_time:137240ms step_avg:97.96ms
+step:1402/1670 train_time:137339ms step_avg:97.96ms
+step:1403/1670 train_time:137436ms step_avg:97.96ms
+step:1404/1670 train_time:137533ms step_avg:97.96ms
+step:1405/1670 train_time:137630ms step_avg:97.96ms
+step:1406/1670 train_time:137728ms step_avg:97.96ms
+step:1407/1670 train_time:137827ms step_avg:97.96ms
+step:1408/1670 train_time:137927ms step_avg:97.96ms
+step:1409/1670 train_time:138026ms step_avg:97.96ms
+step:1410/1670 train_time:138124ms step_avg:97.96ms
+step:1411/1670 train_time:138222ms step_avg:97.96ms
+step:1412/1670 train_time:138320ms step_avg:97.96ms
+step:1413/1670 train_time:138418ms step_avg:97.96ms
+step:1414/1670 train_time:138516ms step_avg:97.96ms
+step:1415/1670 train_time:138614ms step_avg:97.96ms
+step:1416/1670 train_time:138710ms step_avg:97.96ms
+step:1417/1670 train_time:138808ms step_avg:97.96ms
+step:1418/1670 train_time:138906ms step_avg:97.96ms
+step:1419/1670 train_time:139006ms step_avg:97.96ms
+step:1420/1670 train_time:139106ms step_avg:97.96ms
+step:1421/1670 train_time:139205ms step_avg:97.96ms
+step:1422/1670 train_time:139303ms step_avg:97.96ms
+step:1423/1670 train_time:139401ms step_avg:97.96ms
+step:1424/1670 train_time:139500ms step_avg:97.96ms
+step:1425/1670 train_time:139599ms step_avg:97.96ms
+step:1426/1670 train_time:139698ms step_avg:97.96ms
+step:1427/1670 train_time:139796ms step_avg:97.97ms
+step:1428/1670 train_time:139894ms step_avg:97.96ms
+step:1429/1670 train_time:139992ms step_avg:97.96ms
+step:1430/1670 train_time:140090ms step_avg:97.96ms
+step:1431/1670 train_time:140188ms step_avg:97.96ms
+step:1432/1670 train_time:140286ms step_avg:97.96ms
+step:1433/1670 train_time:140383ms step_avg:97.96ms
+step:1434/1670 train_time:140480ms step_avg:97.96ms
+step:1435/1670 train_time:140578ms step_avg:97.96ms
+step:1436/1670 train_time:140676ms step_avg:97.96ms
+step:1437/1670 train_time:140774ms step_avg:97.96ms
+step:1438/1670 train_time:140871ms step_avg:97.96ms
+step:1439/1670 train_time:140968ms step_avg:97.96ms
+step:1440/1670 train_time:141066ms step_avg:97.96ms
+step:1441/1670 train_time:141165ms step_avg:97.96ms
+step:1442/1670 train_time:141263ms step_avg:97.96ms
+step:1443/1670 train_time:141360ms step_avg:97.96ms
+step:1444/1670 train_time:141458ms step_avg:97.96ms
+step:1445/1670 train_time:141556ms step_avg:97.96ms
+step:1446/1670 train_time:141653ms step_avg:97.96ms
+step:1447/1670 train_time:141751ms step_avg:97.96ms
+step:1448/1670 train_time:141849ms step_avg:97.96ms
+step:1449/1670 train_time:141947ms step_avg:97.96ms
+step:1450/1670 train_time:142046ms step_avg:97.96ms
+step:1451/1670 train_time:142143ms step_avg:97.96ms
+step:1452/1670 train_time:142241ms step_avg:97.96ms
+step:1453/1670 train_time:142339ms step_avg:97.96ms
+step:1454/1670 train_time:142436ms step_avg:97.96ms
+step:1455/1670 train_time:142533ms step_avg:97.96ms
+step:1456/1670 train_time:142630ms step_avg:97.96ms
+step:1457/1670 train_time:142728ms step_avg:97.96ms
+step:1458/1670 train_time:142827ms step_avg:97.96ms
+step:1459/1670 train_time:142925ms step_avg:97.96ms
+step:1460/1670 train_time:143024ms step_avg:97.96ms
+step:1461/1670 train_time:143124ms step_avg:97.96ms
+step:1462/1670 train_time:143222ms step_avg:97.96ms
+step:1463/1670 train_time:143319ms step_avg:97.96ms
+step:1464/1670 train_time:143417ms step_avg:97.96ms
+step:1465/1670 train_time:143515ms step_avg:97.96ms
+step:1466/1670 train_time:143613ms step_avg:97.96ms
+step:1467/1670 train_time:143710ms step_avg:97.96ms
+step:1468/1670 train_time:143808ms step_avg:97.96ms
+step:1469/1670 train_time:143906ms step_avg:97.96ms
+step:1470/1670 train_time:144004ms step_avg:97.96ms
+step:1471/1670 train_time:144101ms step_avg:97.96ms
+step:1472/1670 train_time:144199ms step_avg:97.96ms
+step:1473/1670 train_time:144297ms step_avg:97.96ms
+step:1474/1670 train_time:144394ms step_avg:97.96ms
+step:1475/1670 train_time:144492ms step_avg:97.96ms
+step:1476/1670 train_time:144589ms step_avg:97.96ms
+step:1477/1670 train_time:144687ms step_avg:97.96ms
+step:1478/1670 train_time:144786ms step_avg:97.96ms
+step:1479/1670 train_time:144885ms step_avg:97.96ms
+step:1480/1670 train_time:144983ms step_avg:97.96ms
+step:1481/1670 train_time:145080ms step_avg:97.96ms
+step:1482/1670 train_time:145178ms step_avg:97.96ms
+step:1483/1670 train_time:145276ms step_avg:97.96ms
+step:1484/1670 train_time:145373ms step_avg:97.96ms
+step:1485/1670 train_time:145639ms step_avg:98.07ms
+step:1486/1670 train_time:145720ms step_avg:98.06ms
+step:1487/1670 train_time:145817ms step_avg:98.06ms
+step:1488/1670 train_time:145913ms step_avg:98.06ms
+step:1489/1670 train_time:146010ms step_avg:98.06ms
+step:1490/1670 train_time:146107ms step_avg:98.06ms
+step:1491/1670 train_time:146204ms step_avg:98.06ms
+step:1492/1670 train_time:146301ms step_avg:98.06ms
+step:1493/1670 train_time:146398ms step_avg:98.06ms
+step:1494/1670 train_time:146496ms step_avg:98.06ms
+step:1495/1670 train_time:146598ms step_avg:98.06ms
+step:1496/1670 train_time:146698ms step_avg:98.06ms
+step:1497/1670 train_time:146796ms step_avg:98.06ms
+step:1498/1670 train_time:146893ms step_avg:98.06ms
+step:1499/1670 train_time:146990ms step_avg:98.06ms
+step:1500/1670 train_time:147088ms step_avg:98.06ms
+step:1500/1670 val_loss:3.3107 train_time:147184ms step_avg:98.12ms
+step:1501/1670 train_time:147205ms step_avg:98.07ms
+step:1502/1670 train_time:147289ms step_avg:98.06ms
+step:1503/1670 train_time:147389ms step_avg:98.06ms
+step:1504/1670 train_time:147486ms step_avg:98.06ms
+step:1505/1670 train_time:147583ms step_avg:98.06ms
+step:1506/1670 train_time:147680ms step_avg:98.06ms
+step:1507/1670 train_time:147776ms step_avg:98.06ms
+step:1508/1670 train_time:147873ms step_avg:98.06ms
+step:1509/1670 train_time:147971ms step_avg:98.06ms +step:1510/1670 train_time:148069ms step_avg:98.06ms +step:1511/1670 train_time:148169ms step_avg:98.06ms +step:1512/1670 train_time:148269ms step_avg:98.06ms +step:1513/1670 train_time:148369ms step_avg:98.06ms +step:1514/1670 train_time:148468ms step_avg:98.06ms +step:1515/1670 train_time:148566ms step_avg:98.06ms +step:1516/1670 train_time:148665ms step_avg:98.06ms +step:1517/1670 train_time:148763ms step_avg:98.06ms +step:1518/1670 train_time:148860ms step_avg:98.06ms +step:1519/1670 train_time:148958ms step_avg:98.06ms +step:1520/1670 train_time:149055ms step_avg:98.06ms +step:1521/1670 train_time:149153ms step_avg:98.06ms +step:1522/1670 train_time:149251ms step_avg:98.06ms +step:1523/1670 train_time:149349ms step_avg:98.06ms +step:1524/1670 train_time:149446ms step_avg:98.06ms +step:1525/1670 train_time:149545ms step_avg:98.06ms +step:1526/1670 train_time:149643ms step_avg:98.06ms +step:1527/1670 train_time:149740ms step_avg:98.06ms +step:1528/1670 train_time:149838ms step_avg:98.06ms +step:1529/1670 train_time:149935ms step_avg:98.06ms +step:1530/1670 train_time:150032ms step_avg:98.06ms +step:1531/1670 train_time:150130ms step_avg:98.06ms +step:1532/1670 train_time:150229ms step_avg:98.06ms +step:1533/1670 train_time:150328ms step_avg:98.06ms +step:1534/1670 train_time:150427ms step_avg:98.06ms +step:1535/1670 train_time:150525ms step_avg:98.06ms +step:1536/1670 train_time:150623ms step_avg:98.06ms +step:1537/1670 train_time:150721ms step_avg:98.06ms +step:1538/1670 train_time:150819ms step_avg:98.06ms +step:1539/1670 train_time:150916ms step_avg:98.06ms +step:1540/1670 train_time:151013ms step_avg:98.06ms +step:1541/1670 train_time:151110ms step_avg:98.06ms +step:1542/1670 train_time:151208ms step_avg:98.06ms +step:1543/1670 train_time:151308ms step_avg:98.06ms +step:1544/1670 train_time:151406ms step_avg:98.06ms +step:1545/1670 train_time:151504ms step_avg:98.06ms +step:1546/1670 train_time:151602ms step_avg:98.06ms +step:1547/1670 train_time:151701ms step_avg:98.06ms +step:1548/1670 train_time:151800ms step_avg:98.06ms +step:1549/1670 train_time:151898ms step_avg:98.06ms +step:1550/1670 train_time:151995ms step_avg:98.06ms +step:1551/1670 train_time:152093ms step_avg:98.06ms +step:1552/1670 train_time:152190ms step_avg:98.06ms +step:1553/1670 train_time:152289ms step_avg:98.06ms +step:1554/1670 train_time:152387ms step_avg:98.06ms +step:1555/1670 train_time:152484ms step_avg:98.06ms +step:1556/1670 train_time:152582ms step_avg:98.06ms +step:1557/1670 train_time:152679ms step_avg:98.06ms +step:1558/1670 train_time:152777ms step_avg:98.06ms +step:1559/1670 train_time:152875ms step_avg:98.06ms +step:1560/1670 train_time:152973ms step_avg:98.06ms +step:1561/1670 train_time:153070ms step_avg:98.06ms +step:1562/1670 train_time:153168ms step_avg:98.06ms +step:1563/1670 train_time:153267ms step_avg:98.06ms +step:1564/1670 train_time:153365ms step_avg:98.06ms +step:1565/1670 train_time:153464ms step_avg:98.06ms +step:1566/1670 train_time:153562ms step_avg:98.06ms +step:1567/1670 train_time:153659ms step_avg:98.06ms +step:1568/1670 train_time:153756ms step_avg:98.06ms +step:1569/1670 train_time:153853ms step_avg:98.06ms +step:1570/1670 train_time:153951ms step_avg:98.06ms +step:1571/1670 train_time:154049ms step_avg:98.06ms +step:1572/1670 train_time:154148ms step_avg:98.06ms +step:1573/1670 train_time:154246ms step_avg:98.06ms +step:1574/1670 train_time:154344ms step_avg:98.06ms +step:1575/1670 train_time:154441ms step_avg:98.06ms 
+step:1576/1670 train_time:154539ms step_avg:98.06ms +step:1577/1670 train_time:154636ms step_avg:98.06ms +step:1578/1670 train_time:154733ms step_avg:98.06ms +step:1579/1670 train_time:154831ms step_avg:98.06ms +step:1580/1670 train_time:154929ms step_avg:98.06ms +step:1581/1670 train_time:155029ms step_avg:98.06ms +step:1582/1670 train_time:155128ms step_avg:98.06ms +step:1583/1670 train_time:155228ms step_avg:98.06ms +step:1584/1670 train_time:155326ms step_avg:98.06ms +step:1585/1670 train_time:155425ms step_avg:98.06ms +step:1586/1670 train_time:155523ms step_avg:98.06ms +step:1587/1670 train_time:155622ms step_avg:98.06ms +step:1588/1670 train_time:155720ms step_avg:98.06ms +step:1589/1670 train_time:155818ms step_avg:98.06ms +step:1590/1670 train_time:155915ms step_avg:98.06ms +step:1591/1670 train_time:156013ms step_avg:98.06ms +step:1592/1670 train_time:156111ms step_avg:98.06ms +step:1593/1670 train_time:156208ms step_avg:98.06ms +step:1594/1670 train_time:156307ms step_avg:98.06ms +step:1595/1670 train_time:156405ms step_avg:98.06ms +step:1596/1670 train_time:156504ms step_avg:98.06ms +step:1597/1670 train_time:156602ms step_avg:98.06ms +step:1598/1670 train_time:156699ms step_avg:98.06ms +step:1599/1670 train_time:156796ms step_avg:98.06ms +step:1600/1670 train_time:156894ms step_avg:98.06ms +step:1601/1670 train_time:156991ms step_avg:98.06ms +step:1602/1670 train_time:157089ms step_avg:98.06ms +step:1603/1670 train_time:157187ms step_avg:98.06ms +step:1604/1670 train_time:157285ms step_avg:98.06ms +step:1605/1670 train_time:157384ms step_avg:98.06ms +step:1606/1670 train_time:157482ms step_avg:98.06ms +step:1607/1670 train_time:157579ms step_avg:98.06ms +step:1608/1670 train_time:157676ms step_avg:98.06ms +step:1609/1670 train_time:157774ms step_avg:98.06ms +step:1610/1670 train_time:157872ms step_avg:98.06ms +step:1611/1670 train_time:157970ms step_avg:98.06ms +step:1612/1670 train_time:158068ms step_avg:98.06ms +step:1613/1670 train_time:158166ms step_avg:98.06ms +step:1614/1670 train_time:158265ms step_avg:98.06ms +step:1615/1670 train_time:158363ms step_avg:98.06ms +step:1616/1670 train_time:158461ms step_avg:98.06ms +step:1617/1670 train_time:158559ms step_avg:98.06ms +step:1618/1670 train_time:158656ms step_avg:98.06ms +step:1619/1670 train_time:158754ms step_avg:98.06ms +step:1620/1670 train_time:158851ms step_avg:98.06ms +step:1621/1670 train_time:158949ms step_avg:98.06ms +step:1622/1670 train_time:159047ms step_avg:98.06ms +step:1623/1670 train_time:159145ms step_avg:98.06ms +step:1624/1670 train_time:159243ms step_avg:98.06ms +step:1625/1670 train_time:159342ms step_avg:98.06ms +step:1625/1670 val_loss:3.2839 train_time:159438ms step_avg:98.12ms +step:1626/1670 train_time:159459ms step_avg:98.07ms +step:1627/1670 train_time:159543ms step_avg:98.06ms +step:1628/1670 train_time:159642ms step_avg:98.06ms +step:1629/1670 train_time:159739ms step_avg:98.06ms +step:1630/1670 train_time:159836ms step_avg:98.06ms +step:1631/1670 train_time:159933ms step_avg:98.06ms +step:1632/1670 train_time:160031ms step_avg:98.06ms +step:1633/1670 train_time:160128ms step_avg:98.06ms +step:1634/1670 train_time:160226ms step_avg:98.06ms +step:1635/1670 train_time:160323ms step_avg:98.06ms +step:1636/1670 train_time:160422ms step_avg:98.06ms +step:1637/1670 train_time:160523ms step_avg:98.06ms +step:1638/1670 train_time:160622ms step_avg:98.06ms +step:1639/1670 train_time:160720ms step_avg:98.06ms +step:1640/1670 train_time:160817ms step_avg:98.06ms +step:1641/1670 train_time:160914ms 
step_avg:98.06ms +step:1642/1670 train_time:161011ms step_avg:98.06ms +step:1643/1670 train_time:161109ms step_avg:98.06ms +step:1644/1670 train_time:161206ms step_avg:98.06ms +step:1645/1670 train_time:161303ms step_avg:98.06ms +step:1646/1670 train_time:161402ms step_avg:98.06ms +step:1647/1670 train_time:161501ms step_avg:98.06ms +step:1648/1670 train_time:161599ms step_avg:98.06ms +step:1649/1670 train_time:161696ms step_avg:98.06ms +step:1650/1670 train_time:161794ms step_avg:98.06ms +step:1651/1670 train_time:161891ms step_avg:98.06ms +step:1652/1670 train_time:161990ms step_avg:98.06ms +step:1653/1670 train_time:162087ms step_avg:98.06ms +step:1654/1670 train_time:162184ms step_avg:98.06ms +step:1655/1670 train_time:162282ms step_avg:98.06ms +step:1656/1670 train_time:162379ms step_avg:98.06ms +step:1657/1670 train_time:162477ms step_avg:98.06ms +step:1658/1670 train_time:162576ms step_avg:98.06ms +step:1659/1670 train_time:162675ms step_avg:98.06ms +step:1660/1670 train_time:162772ms step_avg:98.06ms +step:1661/1670 train_time:162870ms step_avg:98.06ms +step:1662/1670 train_time:162968ms step_avg:98.06ms +step:1663/1670 train_time:163067ms step_avg:98.06ms +step:1664/1670 train_time:163164ms step_avg:98.06ms +step:1665/1670 train_time:163262ms step_avg:98.06ms +step:1666/1670 train_time:163360ms step_avg:98.06ms +step:1667/1670 train_time:163458ms step_avg:98.06ms +step:1668/1670 train_time:163556ms step_avg:98.06ms +step:1669/1670 train_time:163654ms step_avg:98.06ms +step:1670/1670 train_time:163751ms step_avg:98.05ms +step:1670/1670 val_loss:3.2760 train_time:163848ms step_avg:98.11ms +peak memory allocated: 34001 MiB reserved: 49316 MiB diff --git a/records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt b/records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt new file mode 100644 index 000000000..34fc9dc52 --- /dev/null +++ b/records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return 
impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + 
GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) 
+ + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], 
X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
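+        # Sharding pattern: parameters in each group are dealt out round-robin, so rank r
+        # owns params[base_i + r]. Each owner receives the averaged gradient for its
+        # parameter via an async reduce_scatter, applies the momentum + Newton-Schulz
+        # update locally, and an async all_gather then broadcasts the updated weights
+        # back to every rank.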
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:48:32 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 41C P0 121W / 700W | 5858MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 40C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 40C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 50803 C /usr/bin/python 0MiB | +| 0 N/A N/A 50804 C /usr/bin/python 0MiB | +| 0 N/A N/A 50805 C /usr/bin/python 0MiB | +| 0 N/A N/A 50806 C /usr/bin/python 0MiB | +| 0 N/A N/A 50807 C /usr/bin/python 0MiB | +| 0 N/A N/A 50808 C /usr/bin/python 0MiB | +| 0 N/A N/A 50809 C /usr/bin/python 0MiB | +| 0 N/A N/A 50810 C /usr/bin/python 0MiB | +| 1 N/A N/A 50804 C /usr/bin/python 0MiB | +| 2 N/A N/A 50805 C /usr/bin/python 0MiB | +| 3 N/A N/A 50806 C /usr/bin/python 0MiB | +| 4 N/A N/A 50807 C /usr/bin/python 0MiB | +| 5 N/A N/A 50808 C /usr/bin/python 0MiB | +| 6 N/A N/A 50809 C /usr/bin/python 0MiB | +| 7 N/A N/A 50810 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:366ms step_avg:365.67ms +step:2/1670 train_time:386ms step_avg:193.15ms +step:3/1670 train_time:459ms step_avg:153.08ms +step:4/1670 train_time:553ms step_avg:138.26ms +step:5/1670 train_time:648ms step_avg:129.54ms +step:6/1670 train_time:742ms step_avg:123.73ms +step:7/1670 train_time:838ms step_avg:119.67ms +step:8/1670 train_time:933ms step_avg:116.57ms +step:9/1670 train_time:1028ms 
step_avg:114.20ms +step:10/1670 train_time:1123ms step_avg:112.27ms +step:11/1670 train_time:1219ms step_avg:110.77ms +step:12/1670 train_time:1316ms step_avg:109.64ms +step:13/1670 train_time:1414ms step_avg:108.77ms +step:14/1670 train_time:1512ms step_avg:108.01ms +step:15/1670 train_time:1608ms step_avg:107.19ms +step:16/1670 train_time:1704ms step_avg:106.47ms +step:17/1670 train_time:1799ms step_avg:105.83ms +step:18/1670 train_time:1894ms step_avg:105.25ms +step:19/1670 train_time:1990ms step_avg:104.72ms +step:20/1670 train_time:2085ms step_avg:104.27ms +step:21/1670 train_time:2181ms step_avg:103.88ms +step:22/1670 train_time:2277ms step_avg:103.50ms +step:23/1670 train_time:2373ms step_avg:103.18ms +step:24/1670 train_time:2470ms step_avg:102.92ms +step:25/1670 train_time:2567ms step_avg:102.68ms +step:26/1670 train_time:2664ms step_avg:102.46ms +step:27/1670 train_time:2761ms step_avg:102.27ms +step:28/1670 train_time:2856ms step_avg:102.02ms +step:29/1670 train_time:2952ms step_avg:101.78ms +step:30/1670 train_time:3048ms step_avg:101.60ms +step:31/1670 train_time:3143ms step_avg:101.40ms +step:32/1670 train_time:3240ms step_avg:101.25ms +step:33/1670 train_time:3335ms step_avg:101.07ms +step:34/1670 train_time:3432ms step_avg:100.93ms +step:35/1670 train_time:3530ms step_avg:100.86ms +step:36/1670 train_time:3627ms step_avg:100.75ms +step:37/1670 train_time:3724ms step_avg:100.65ms +step:38/1670 train_time:3821ms step_avg:100.55ms +step:39/1670 train_time:3917ms step_avg:100.42ms +step:40/1670 train_time:4012ms step_avg:100.31ms +step:41/1670 train_time:4108ms step_avg:100.19ms +step:42/1670 train_time:4204ms step_avg:100.09ms +step:43/1670 train_time:4300ms step_avg:100.01ms +step:44/1670 train_time:4397ms step_avg:99.92ms +step:45/1670 train_time:4493ms step_avg:99.84ms +step:46/1670 train_time:4589ms step_avg:99.75ms +step:47/1670 train_time:4686ms step_avg:99.70ms +step:48/1670 train_time:4783ms step_avg:99.65ms +step:49/1670 train_time:4880ms step_avg:99.59ms +step:50/1670 train_time:4976ms step_avg:99.52ms +step:51/1670 train_time:5071ms step_avg:99.43ms +step:52/1670 train_time:5166ms step_avg:99.35ms +step:53/1670 train_time:5262ms step_avg:99.29ms +step:54/1670 train_time:5359ms step_avg:99.23ms +step:55/1670 train_time:5454ms step_avg:99.16ms +step:56/1670 train_time:5549ms step_avg:99.09ms +step:57/1670 train_time:5645ms step_avg:99.04ms +step:58/1670 train_time:5741ms step_avg:98.98ms +step:59/1670 train_time:5837ms step_avg:98.93ms +step:60/1670 train_time:5933ms step_avg:98.88ms +step:61/1670 train_time:6029ms step_avg:98.83ms +step:62/1670 train_time:6125ms step_avg:98.79ms +step:63/1670 train_time:6221ms step_avg:98.74ms +step:64/1670 train_time:6316ms step_avg:98.69ms +step:65/1670 train_time:6412ms step_avg:98.64ms +step:66/1670 train_time:6508ms step_avg:98.60ms +step:67/1670 train_time:6603ms step_avg:98.56ms +step:68/1670 train_time:6700ms step_avg:98.53ms +step:69/1670 train_time:6795ms step_avg:98.48ms +step:70/1670 train_time:6891ms step_avg:98.45ms +step:71/1670 train_time:6987ms step_avg:98.41ms +step:72/1670 train_time:7083ms step_avg:98.37ms +step:73/1670 train_time:7179ms step_avg:98.34ms +step:74/1670 train_time:7275ms step_avg:98.30ms +step:75/1670 train_time:7370ms step_avg:98.27ms +step:76/1670 train_time:7466ms step_avg:98.23ms +step:77/1670 train_time:7561ms step_avg:98.20ms +step:78/1670 train_time:7657ms step_avg:98.17ms +step:79/1670 train_time:7753ms step_avg:98.13ms +step:80/1670 train_time:7849ms step_avg:98.11ms +step:81/1670 
train_time:7946ms step_avg:98.09ms +step:82/1670 train_time:8041ms step_avg:98.06ms +step:83/1670 train_time:8137ms step_avg:98.03ms +step:84/1670 train_time:8232ms step_avg:98.00ms +step:85/1670 train_time:8328ms step_avg:97.97ms +step:86/1670 train_time:8424ms step_avg:97.96ms +step:87/1670 train_time:8520ms step_avg:97.93ms +step:88/1670 train_time:8616ms step_avg:97.91ms +step:89/1670 train_time:8711ms step_avg:97.88ms +step:90/1670 train_time:8807ms step_avg:97.85ms +step:91/1670 train_time:8903ms step_avg:97.84ms +step:92/1670 train_time:8999ms step_avg:97.82ms +step:93/1670 train_time:9095ms step_avg:97.80ms +step:94/1670 train_time:9190ms step_avg:97.77ms +step:95/1670 train_time:9287ms step_avg:97.75ms +step:96/1670 train_time:9383ms step_avg:97.74ms +step:97/1670 train_time:9479ms step_avg:97.73ms +step:98/1670 train_time:9575ms step_avg:97.70ms +step:99/1670 train_time:9670ms step_avg:97.68ms +step:100/1670 train_time:9766ms step_avg:97.66ms +step:101/1670 train_time:9863ms step_avg:97.65ms +step:102/1670 train_time:9959ms step_avg:97.64ms +step:103/1670 train_time:10055ms step_avg:97.62ms +step:104/1670 train_time:10150ms step_avg:97.59ms +step:105/1670 train_time:10245ms step_avg:97.57ms +step:106/1670 train_time:10340ms step_avg:97.55ms +step:107/1670 train_time:10436ms step_avg:97.53ms +step:108/1670 train_time:10531ms step_avg:97.51ms +step:109/1670 train_time:10628ms step_avg:97.50ms +step:110/1670 train_time:10724ms step_avg:97.49ms +step:111/1670 train_time:10819ms step_avg:97.47ms +step:112/1670 train_time:10915ms step_avg:97.45ms +step:113/1670 train_time:11010ms step_avg:97.44ms +step:114/1670 train_time:11107ms step_avg:97.43ms +step:115/1670 train_time:11202ms step_avg:97.41ms +step:116/1670 train_time:11298ms step_avg:97.39ms +step:117/1670 train_time:11393ms step_avg:97.38ms +step:118/1670 train_time:11488ms step_avg:97.36ms +step:119/1670 train_time:11584ms step_avg:97.34ms +step:120/1670 train_time:11680ms step_avg:97.33ms +step:121/1670 train_time:11776ms step_avg:97.32ms +step:122/1670 train_time:11872ms step_avg:97.31ms +step:123/1670 train_time:11967ms step_avg:97.29ms +step:124/1670 train_time:12063ms step_avg:97.28ms +step:125/1670 train_time:12159ms step_avg:97.27ms +step:125/1670 val_loss:4.3294 train_time:12253ms step_avg:98.03ms +step:126/1670 train_time:12275ms step_avg:97.42ms +step:127/1670 train_time:12357ms step_avg:97.30ms +step:128/1670 train_time:12464ms step_avg:97.37ms +step:129/1670 train_time:12560ms step_avg:97.36ms +step:130/1670 train_time:12655ms step_avg:97.34ms +step:131/1670 train_time:12749ms step_avg:97.32ms +step:132/1670 train_time:12843ms step_avg:97.30ms +step:133/1670 train_time:12938ms step_avg:97.28ms +step:134/1670 train_time:13033ms step_avg:97.26ms +step:135/1670 train_time:13128ms step_avg:97.24ms +step:136/1670 train_time:13222ms step_avg:97.22ms +step:137/1670 train_time:13319ms step_avg:97.22ms +step:138/1670 train_time:13420ms step_avg:97.25ms +step:139/1670 train_time:13518ms step_avg:97.25ms +step:140/1670 train_time:13615ms step_avg:97.25ms +step:141/1670 train_time:13711ms step_avg:97.24ms +step:142/1670 train_time:13807ms step_avg:97.23ms +step:143/1670 train_time:13901ms step_avg:97.21ms +step:144/1670 train_time:13996ms step_avg:97.19ms +step:145/1670 train_time:14091ms step_avg:97.18ms +step:146/1670 train_time:14187ms step_avg:97.17ms +step:147/1670 train_time:14282ms step_avg:97.15ms +step:148/1670 train_time:14378ms step_avg:97.15ms +step:149/1670 train_time:14475ms step_avg:97.15ms +step:150/1670 
train_time:14571ms step_avg:97.14ms +step:151/1670 train_time:14667ms step_avg:97.13ms +step:152/1670 train_time:14763ms step_avg:97.12ms +step:153/1670 train_time:14858ms step_avg:97.11ms +step:154/1670 train_time:14954ms step_avg:97.10ms +step:155/1670 train_time:15049ms step_avg:97.09ms +step:156/1670 train_time:15144ms step_avg:97.08ms +step:157/1670 train_time:15239ms step_avg:97.06ms +step:158/1670 train_time:15335ms step_avg:97.06ms +step:159/1670 train_time:15432ms step_avg:97.06ms +step:160/1670 train_time:15529ms step_avg:97.05ms +step:161/1670 train_time:15625ms step_avg:97.05ms +step:162/1670 train_time:15721ms step_avg:97.05ms +step:163/1670 train_time:15817ms step_avg:97.04ms +step:164/1670 train_time:15912ms step_avg:97.02ms +step:165/1670 train_time:16007ms step_avg:97.01ms +step:166/1670 train_time:16102ms step_avg:97.00ms +step:167/1670 train_time:16197ms step_avg:96.99ms +step:168/1670 train_time:16293ms step_avg:96.98ms +step:169/1670 train_time:16389ms step_avg:96.98ms +step:170/1670 train_time:16485ms step_avg:96.97ms +step:171/1670 train_time:16580ms step_avg:96.96ms +step:172/1670 train_time:16678ms step_avg:96.96ms +step:173/1670 train_time:16774ms step_avg:96.96ms +step:174/1670 train_time:16869ms step_avg:96.95ms +step:175/1670 train_time:16964ms step_avg:96.94ms +step:176/1670 train_time:17059ms step_avg:96.93ms +step:177/1670 train_time:17154ms step_avg:96.92ms +step:178/1670 train_time:17250ms step_avg:96.91ms +step:179/1670 train_time:17345ms step_avg:96.90ms +step:180/1670 train_time:17441ms step_avg:96.89ms +step:181/1670 train_time:17537ms step_avg:96.89ms +step:182/1670 train_time:17634ms step_avg:96.89ms +step:183/1670 train_time:17730ms step_avg:96.89ms +step:184/1670 train_time:17826ms step_avg:96.88ms +step:185/1670 train_time:17922ms step_avg:96.87ms +step:186/1670 train_time:18017ms step_avg:96.86ms +step:187/1670 train_time:18112ms step_avg:96.86ms +step:188/1670 train_time:18207ms step_avg:96.85ms +step:189/1670 train_time:18303ms step_avg:96.84ms +step:190/1670 train_time:18399ms step_avg:96.84ms +step:191/1670 train_time:18495ms step_avg:96.83ms +step:192/1670 train_time:18591ms step_avg:96.83ms +step:193/1670 train_time:18688ms step_avg:96.83ms +step:194/1670 train_time:18783ms step_avg:96.82ms +step:195/1670 train_time:18879ms step_avg:96.81ms +step:196/1670 train_time:18975ms step_avg:96.81ms +step:197/1670 train_time:19070ms step_avg:96.80ms +step:198/1670 train_time:19165ms step_avg:96.79ms +step:199/1670 train_time:19261ms step_avg:96.79ms +step:200/1670 train_time:19356ms step_avg:96.78ms +step:201/1670 train_time:19453ms step_avg:96.78ms +step:202/1670 train_time:19549ms step_avg:96.78ms +step:203/1670 train_time:19645ms step_avg:96.77ms +step:204/1670 train_time:19741ms step_avg:96.77ms +step:205/1670 train_time:19837ms step_avg:96.77ms +step:206/1670 train_time:19933ms step_avg:96.76ms +step:207/1670 train_time:20029ms step_avg:96.76ms +step:208/1670 train_time:20125ms step_avg:96.75ms +step:209/1670 train_time:20219ms step_avg:96.74ms +step:210/1670 train_time:20315ms step_avg:96.74ms +step:211/1670 train_time:20412ms step_avg:96.74ms +step:212/1670 train_time:20508ms step_avg:96.74ms +step:213/1670 train_time:20795ms step_avg:97.63ms +step:214/1670 train_time:20891ms step_avg:97.62ms +step:215/1670 train_time:20984ms step_avg:97.60ms +step:216/1670 train_time:21078ms step_avg:97.59ms +step:217/1670 train_time:21173ms step_avg:97.57ms +step:218/1670 train_time:21269ms step_avg:97.56ms +step:219/1670 train_time:21363ms step_avg:97.55ms 
+step:220/1670 train_time:21457ms step_avg:97.53ms +step:221/1670 train_time:21552ms step_avg:97.52ms +step:222/1670 train_time:21647ms step_avg:97.51ms +step:223/1670 train_time:21743ms step_avg:97.50ms +step:224/1670 train_time:21841ms step_avg:97.50ms +step:225/1670 train_time:21939ms step_avg:97.51ms +step:226/1670 train_time:22036ms step_avg:97.50ms +step:227/1670 train_time:22131ms step_avg:97.49ms +step:228/1670 train_time:22226ms step_avg:97.48ms +step:229/1670 train_time:22321ms step_avg:97.47ms +step:230/1670 train_time:22416ms step_avg:97.46ms +step:231/1670 train_time:22511ms step_avg:97.45ms +step:232/1670 train_time:22607ms step_avg:97.44ms +step:233/1670 train_time:22702ms step_avg:97.43ms +step:234/1670 train_time:22798ms step_avg:97.43ms +step:235/1670 train_time:22896ms step_avg:97.43ms +step:236/1670 train_time:22993ms step_avg:97.43ms +step:237/1670 train_time:23089ms step_avg:97.42ms +step:238/1670 train_time:23184ms step_avg:97.41ms +step:239/1670 train_time:23279ms step_avg:97.40ms +step:240/1670 train_time:23375ms step_avg:97.39ms +step:241/1670 train_time:23470ms step_avg:97.39ms +step:242/1670 train_time:23565ms step_avg:97.37ms +step:243/1670 train_time:23660ms step_avg:97.37ms +step:244/1670 train_time:23756ms step_avg:97.36ms +step:245/1670 train_time:23852ms step_avg:97.35ms +step:246/1670 train_time:23949ms step_avg:97.35ms +step:247/1670 train_time:24044ms step_avg:97.35ms +step:248/1670 train_time:24140ms step_avg:97.34ms +step:249/1670 train_time:24236ms step_avg:97.33ms +step:250/1670 train_time:24333ms step_avg:97.33ms +step:250/1670 val_loss:3.9718 train_time:24427ms step_avg:97.71ms +step:251/1670 train_time:24450ms step_avg:97.41ms +step:252/1670 train_time:24530ms step_avg:97.34ms +step:253/1670 train_time:24633ms step_avg:97.36ms +step:254/1670 train_time:24730ms step_avg:97.36ms +step:255/1670 train_time:24825ms step_avg:97.35ms +step:256/1670 train_time:24919ms step_avg:97.34ms +step:257/1670 train_time:25014ms step_avg:97.33ms +step:258/1670 train_time:25109ms step_avg:97.32ms +step:259/1670 train_time:25204ms step_avg:97.31ms +step:260/1670 train_time:25298ms step_avg:97.30ms +step:261/1670 train_time:25393ms step_avg:97.29ms +step:262/1670 train_time:25491ms step_avg:97.30ms +step:263/1670 train_time:25591ms step_avg:97.31ms +step:264/1670 train_time:25690ms step_avg:97.31ms +step:265/1670 train_time:25787ms step_avg:97.31ms +step:266/1670 train_time:25883ms step_avg:97.30ms +step:267/1670 train_time:25978ms step_avg:97.30ms +step:268/1670 train_time:26073ms step_avg:97.29ms +step:269/1670 train_time:26168ms step_avg:97.28ms +step:270/1670 train_time:26263ms step_avg:97.27ms +step:271/1670 train_time:26358ms step_avg:97.26ms +step:272/1670 train_time:26453ms step_avg:97.25ms +step:273/1670 train_time:26551ms step_avg:97.25ms +step:274/1670 train_time:26648ms step_avg:97.26ms +step:275/1670 train_time:26745ms step_avg:97.26ms +step:276/1670 train_time:26841ms step_avg:97.25ms +step:277/1670 train_time:26936ms step_avg:97.24ms +step:278/1670 train_time:27031ms step_avg:97.23ms +step:279/1670 train_time:27126ms step_avg:97.23ms +step:280/1670 train_time:27221ms step_avg:97.22ms +step:281/1670 train_time:27316ms step_avg:97.21ms +step:282/1670 train_time:27411ms step_avg:97.20ms +step:283/1670 train_time:27508ms step_avg:97.20ms +step:284/1670 train_time:27604ms step_avg:97.20ms +step:285/1670 train_time:27700ms step_avg:97.19ms +step:286/1670 train_time:27796ms step_avg:97.19ms +step:287/1670 train_time:27892ms step_avg:97.18ms +step:288/1670 
train_time:27988ms step_avg:97.18ms +step:289/1670 train_time:28083ms step_avg:97.17ms +step:290/1670 train_time:28179ms step_avg:97.17ms +step:291/1670 train_time:28273ms step_avg:97.16ms +step:292/1670 train_time:28369ms step_avg:97.15ms +step:293/1670 train_time:28464ms step_avg:97.15ms +step:294/1670 train_time:28560ms step_avg:97.14ms +step:295/1670 train_time:28655ms step_avg:97.14ms +step:296/1670 train_time:28752ms step_avg:97.13ms +step:297/1670 train_time:28849ms step_avg:97.13ms +step:298/1670 train_time:28945ms step_avg:97.13ms +step:299/1670 train_time:29040ms step_avg:97.12ms +step:300/1670 train_time:29135ms step_avg:97.12ms +step:301/1670 train_time:29230ms step_avg:97.11ms +step:302/1670 train_time:29325ms step_avg:97.10ms +step:303/1670 train_time:29421ms step_avg:97.10ms +step:304/1670 train_time:29516ms step_avg:97.09ms +step:305/1670 train_time:29612ms step_avg:97.09ms +step:306/1670 train_time:29708ms step_avg:97.08ms +step:307/1670 train_time:29804ms step_avg:97.08ms +step:308/1670 train_time:29900ms step_avg:97.08ms +step:309/1670 train_time:29995ms step_avg:97.07ms +step:310/1670 train_time:30091ms step_avg:97.07ms +step:311/1670 train_time:30187ms step_avg:97.06ms +step:312/1670 train_time:30283ms step_avg:97.06ms +step:313/1670 train_time:30379ms step_avg:97.06ms +step:314/1670 train_time:30474ms step_avg:97.05ms +step:315/1670 train_time:30570ms step_avg:97.05ms +step:316/1670 train_time:30666ms step_avg:97.04ms +step:317/1670 train_time:30762ms step_avg:97.04ms +step:318/1670 train_time:30857ms step_avg:97.04ms +step:319/1670 train_time:30953ms step_avg:97.03ms +step:320/1670 train_time:31050ms step_avg:97.03ms +step:321/1670 train_time:31145ms step_avg:97.03ms +step:322/1670 train_time:31241ms step_avg:97.02ms +step:323/1670 train_time:31337ms step_avg:97.02ms +step:324/1670 train_time:31432ms step_avg:97.01ms +step:325/1670 train_time:31527ms step_avg:97.01ms +step:326/1670 train_time:31623ms step_avg:97.00ms +step:327/1670 train_time:31719ms step_avg:97.00ms +step:328/1670 train_time:31814ms step_avg:96.99ms +step:329/1670 train_time:31910ms step_avg:96.99ms +step:330/1670 train_time:32006ms step_avg:96.99ms +step:331/1670 train_time:32102ms step_avg:96.98ms +step:332/1670 train_time:32197ms step_avg:96.98ms +step:333/1670 train_time:32292ms step_avg:96.97ms +step:334/1670 train_time:32389ms step_avg:96.97ms +step:335/1670 train_time:32485ms step_avg:96.97ms +step:336/1670 train_time:32580ms step_avg:96.96ms +step:337/1670 train_time:32675ms step_avg:96.96ms +step:338/1670 train_time:32772ms step_avg:96.96ms +step:339/1670 train_time:32868ms step_avg:96.96ms +step:340/1670 train_time:32964ms step_avg:96.95ms +step:341/1670 train_time:33060ms step_avg:96.95ms +step:342/1670 train_time:33155ms step_avg:96.94ms +step:343/1670 train_time:33251ms step_avg:96.94ms +step:344/1670 train_time:33347ms step_avg:96.94ms +step:345/1670 train_time:33443ms step_avg:96.94ms +step:346/1670 train_time:33538ms step_avg:96.93ms +step:347/1670 train_time:33634ms step_avg:96.93ms +step:348/1670 train_time:33729ms step_avg:96.92ms +step:349/1670 train_time:33825ms step_avg:96.92ms +step:350/1670 train_time:33921ms step_avg:96.92ms +step:351/1670 train_time:34016ms step_avg:96.91ms +step:352/1670 train_time:34112ms step_avg:96.91ms +step:353/1670 train_time:34208ms step_avg:96.91ms +step:354/1670 train_time:34304ms step_avg:96.90ms +step:355/1670 train_time:34399ms step_avg:96.90ms +step:356/1670 train_time:34495ms step_avg:96.89ms +step:357/1670 train_time:34591ms step_avg:96.89ms 
+step:358/1670 train_time:34686ms step_avg:96.89ms +step:359/1670 train_time:34781ms step_avg:96.88ms +step:360/1670 train_time:34876ms step_avg:96.88ms +step:361/1670 train_time:34972ms step_avg:96.88ms +step:362/1670 train_time:35068ms step_avg:96.87ms +step:363/1670 train_time:35165ms step_avg:96.87ms +step:364/1670 train_time:35260ms step_avg:96.87ms +step:365/1670 train_time:35356ms step_avg:96.87ms +step:366/1670 train_time:35451ms step_avg:96.86ms +step:367/1670 train_time:35547ms step_avg:96.86ms +step:368/1670 train_time:35643ms step_avg:96.86ms +step:369/1670 train_time:35738ms step_avg:96.85ms +step:370/1670 train_time:35833ms step_avg:96.85ms +step:371/1670 train_time:35929ms step_avg:96.84ms +step:372/1670 train_time:36025ms step_avg:96.84ms +step:373/1670 train_time:36121ms step_avg:96.84ms +step:374/1670 train_time:36216ms step_avg:96.83ms +step:375/1670 train_time:36311ms step_avg:96.83ms +step:375/1670 val_loss:3.8133 train_time:36407ms step_avg:97.08ms +step:376/1670 train_time:36427ms step_avg:96.88ms +step:377/1670 train_time:36509ms step_avg:96.84ms +step:378/1670 train_time:36610ms step_avg:96.85ms +step:379/1670 train_time:36705ms step_avg:96.85ms +step:380/1670 train_time:36801ms step_avg:96.84ms +step:381/1670 train_time:36896ms step_avg:96.84ms +step:382/1670 train_time:36991ms step_avg:96.83ms +step:383/1670 train_time:37086ms step_avg:96.83ms +step:384/1670 train_time:37181ms step_avg:96.83ms +step:385/1670 train_time:37276ms step_avg:96.82ms +step:386/1670 train_time:37370ms step_avg:96.81ms +step:387/1670 train_time:37467ms step_avg:96.81ms +step:388/1670 train_time:37566ms step_avg:96.82ms +step:389/1670 train_time:37663ms step_avg:96.82ms +step:390/1670 train_time:37760ms step_avg:96.82ms +step:391/1670 train_time:37856ms step_avg:96.82ms +step:392/1670 train_time:37951ms step_avg:96.81ms +step:393/1670 train_time:38046ms step_avg:96.81ms +step:394/1670 train_time:38140ms step_avg:96.80ms +step:395/1670 train_time:38235ms step_avg:96.80ms +step:396/1670 train_time:38330ms step_avg:96.79ms +step:397/1670 train_time:38425ms step_avg:96.79ms +step:398/1670 train_time:38522ms step_avg:96.79ms +step:399/1670 train_time:38620ms step_avg:96.79ms +step:400/1670 train_time:38716ms step_avg:96.79ms +step:401/1670 train_time:38812ms step_avg:96.79ms +step:402/1670 train_time:38907ms step_avg:96.78ms +step:403/1670 train_time:39003ms step_avg:96.78ms +step:404/1670 train_time:39098ms step_avg:96.78ms +step:405/1670 train_time:39194ms step_avg:96.77ms +step:406/1670 train_time:39288ms step_avg:96.77ms +step:407/1670 train_time:39384ms step_avg:96.77ms +step:408/1670 train_time:39480ms step_avg:96.76ms +step:409/1670 train_time:39576ms step_avg:96.76ms +step:410/1670 train_time:39672ms step_avg:96.76ms +step:411/1670 train_time:39768ms step_avg:96.76ms +step:412/1670 train_time:39864ms step_avg:96.76ms +step:413/1670 train_time:39961ms step_avg:96.76ms +step:414/1670 train_time:40057ms step_avg:96.76ms +step:415/1670 train_time:40152ms step_avg:96.75ms +step:416/1670 train_time:40247ms step_avg:96.75ms +step:417/1670 train_time:40342ms step_avg:96.74ms +step:418/1670 train_time:40438ms step_avg:96.74ms +step:419/1670 train_time:40534ms step_avg:96.74ms +step:420/1670 train_time:40629ms step_avg:96.73ms +step:421/1670 train_time:40725ms step_avg:96.73ms +step:422/1670 train_time:40823ms step_avg:96.74ms +step:423/1670 train_time:40920ms step_avg:96.74ms +step:424/1670 train_time:41016ms step_avg:96.74ms +step:425/1670 train_time:41309ms step_avg:97.20ms +step:426/1670 
train_time:41449ms step_avg:97.30ms +step:427/1670 train_time:41542ms step_avg:97.29ms +step:428/1670 train_time:41637ms step_avg:97.28ms +step:429/1670 train_time:41732ms step_avg:97.28ms +step:430/1670 train_time:41826ms step_avg:97.27ms +step:431/1670 train_time:41921ms step_avg:97.26ms +step:432/1670 train_time:42015ms step_avg:97.26ms +step:433/1670 train_time:42110ms step_avg:97.25ms +step:434/1670 train_time:42204ms step_avg:97.24ms +step:435/1670 train_time:42301ms step_avg:97.24ms +step:436/1670 train_time:42402ms step_avg:97.25ms +step:437/1670 train_time:42501ms step_avg:97.26ms +step:438/1670 train_time:42597ms step_avg:97.25ms +step:439/1670 train_time:42693ms step_avg:97.25ms +step:440/1670 train_time:42788ms step_avg:97.24ms +step:441/1670 train_time:42882ms step_avg:97.24ms +step:442/1670 train_time:42978ms step_avg:97.23ms +step:443/1670 train_time:43072ms step_avg:97.23ms +step:444/1670 train_time:43167ms step_avg:97.22ms +step:445/1670 train_time:43262ms step_avg:97.22ms +step:446/1670 train_time:43358ms step_avg:97.22ms +step:447/1670 train_time:43456ms step_avg:97.22ms +step:448/1670 train_time:43552ms step_avg:97.21ms +step:449/1670 train_time:43648ms step_avg:97.21ms +step:450/1670 train_time:43744ms step_avg:97.21ms +step:451/1670 train_time:43839ms step_avg:97.20ms +step:452/1670 train_time:43934ms step_avg:97.20ms +step:453/1670 train_time:44029ms step_avg:97.19ms +step:454/1670 train_time:44123ms step_avg:97.19ms +step:455/1670 train_time:44219ms step_avg:97.18ms +step:456/1670 train_time:44315ms step_avg:97.18ms +step:457/1670 train_time:44411ms step_avg:97.18ms +step:458/1670 train_time:44508ms step_avg:97.18ms +step:459/1670 train_time:44604ms step_avg:97.18ms +step:460/1670 train_time:44700ms step_avg:97.17ms +step:461/1670 train_time:44795ms step_avg:97.17ms +step:462/1670 train_time:44890ms step_avg:97.17ms +step:463/1670 train_time:44985ms step_avg:97.16ms +step:464/1670 train_time:45081ms step_avg:97.16ms +step:465/1670 train_time:45176ms step_avg:97.15ms +step:466/1670 train_time:45271ms step_avg:97.15ms +step:467/1670 train_time:45367ms step_avg:97.15ms +step:468/1670 train_time:45464ms step_avg:97.14ms +step:469/1670 train_time:45561ms step_avg:97.14ms +step:470/1670 train_time:45657ms step_avg:97.14ms +step:471/1670 train_time:45753ms step_avg:97.14ms +step:472/1670 train_time:45849ms step_avg:97.14ms +step:473/1670 train_time:45943ms step_avg:97.13ms +step:474/1670 train_time:46039ms step_avg:97.13ms +step:475/1670 train_time:46134ms step_avg:97.12ms +step:476/1670 train_time:46229ms step_avg:97.12ms +step:477/1670 train_time:46324ms step_avg:97.12ms +step:478/1670 train_time:46420ms step_avg:97.11ms +step:479/1670 train_time:46517ms step_avg:97.11ms +step:480/1670 train_time:46614ms step_avg:97.11ms +step:481/1670 train_time:46710ms step_avg:97.11ms +step:482/1670 train_time:46805ms step_avg:97.11ms +step:483/1670 train_time:46901ms step_avg:97.10ms +step:484/1670 train_time:46997ms step_avg:97.10ms +step:485/1670 train_time:47092ms step_avg:97.10ms +step:486/1670 train_time:47187ms step_avg:97.09ms +step:487/1670 train_time:47283ms step_avg:97.09ms +step:488/1670 train_time:47378ms step_avg:97.09ms +step:489/1670 train_time:47474ms step_avg:97.08ms +step:490/1670 train_time:47569ms step_avg:97.08ms +step:491/1670 train_time:47664ms step_avg:97.08ms +step:492/1670 train_time:47761ms step_avg:97.07ms +step:493/1670 train_time:47857ms step_avg:97.07ms +step:494/1670 train_time:47952ms step_avg:97.07ms +step:495/1670 train_time:48048ms step_avg:97.07ms 
+step:496/1670 train_time:48143ms step_avg:97.06ms +step:497/1670 train_time:48238ms step_avg:97.06ms +step:498/1670 train_time:48334ms step_avg:97.06ms +step:499/1670 train_time:48429ms step_avg:97.05ms +step:500/1670 train_time:48525ms step_avg:97.05ms +step:500/1670 val_loss:3.7125 train_time:48621ms step_avg:97.24ms +step:501/1670 train_time:48642ms step_avg:97.09ms +step:502/1670 train_time:48723ms step_avg:97.06ms +step:503/1670 train_time:48827ms step_avg:97.07ms +step:504/1670 train_time:48922ms step_avg:97.07ms +step:505/1670 train_time:49016ms step_avg:97.06ms +step:506/1670 train_time:49111ms step_avg:97.06ms +step:507/1670 train_time:49206ms step_avg:97.05ms +step:508/1670 train_time:49300ms step_avg:97.05ms +step:509/1670 train_time:49395ms step_avg:97.04ms +step:510/1670 train_time:49489ms step_avg:97.04ms +step:511/1670 train_time:49585ms step_avg:97.03ms +step:512/1670 train_time:49681ms step_avg:97.03ms +step:513/1670 train_time:49780ms step_avg:97.04ms +step:514/1670 train_time:49877ms step_avg:97.04ms +step:515/1670 train_time:49973ms step_avg:97.04ms +step:516/1670 train_time:50069ms step_avg:97.03ms +step:517/1670 train_time:50164ms step_avg:97.03ms +step:518/1670 train_time:50258ms step_avg:97.02ms +step:519/1670 train_time:50353ms step_avg:97.02ms +step:520/1670 train_time:50448ms step_avg:97.01ms +step:521/1670 train_time:50542ms step_avg:97.01ms +step:522/1670 train_time:50638ms step_avg:97.01ms +step:523/1670 train_time:50736ms step_avg:97.01ms +step:524/1670 train_time:50834ms step_avg:97.01ms +step:525/1670 train_time:50931ms step_avg:97.01ms +step:526/1670 train_time:51027ms step_avg:97.01ms +step:527/1670 train_time:51122ms step_avg:97.01ms +step:528/1670 train_time:51217ms step_avg:97.00ms +step:529/1670 train_time:51312ms step_avg:97.00ms +step:530/1670 train_time:51407ms step_avg:96.99ms +step:531/1670 train_time:51502ms step_avg:96.99ms +step:532/1670 train_time:51597ms step_avg:96.99ms +step:533/1670 train_time:51694ms step_avg:96.99ms +step:534/1670 train_time:51791ms step_avg:96.99ms +step:535/1670 train_time:51887ms step_avg:96.99ms +step:536/1670 train_time:51983ms step_avg:96.98ms +step:537/1670 train_time:52079ms step_avg:96.98ms +step:538/1670 train_time:52174ms step_avg:96.98ms +step:539/1670 train_time:52270ms step_avg:96.98ms +step:540/1670 train_time:52366ms step_avg:96.97ms +step:541/1670 train_time:52460ms step_avg:96.97ms +step:542/1670 train_time:52556ms step_avg:96.97ms +step:543/1670 train_time:52651ms step_avg:96.96ms +step:544/1670 train_time:52747ms step_avg:96.96ms +step:545/1670 train_time:52843ms step_avg:96.96ms +step:546/1670 train_time:52939ms step_avg:96.96ms +step:547/1670 train_time:53035ms step_avg:96.96ms +step:548/1670 train_time:53131ms step_avg:96.95ms +step:549/1670 train_time:53226ms step_avg:96.95ms +step:550/1670 train_time:53321ms step_avg:96.95ms +step:551/1670 train_time:53416ms step_avg:96.94ms +step:552/1670 train_time:53512ms step_avg:96.94ms +step:553/1670 train_time:53608ms step_avg:96.94ms +step:554/1670 train_time:53703ms step_avg:96.94ms +step:555/1670 train_time:53799ms step_avg:96.93ms +step:556/1670 train_time:53895ms step_avg:96.93ms +step:557/1670 train_time:53991ms step_avg:96.93ms +step:558/1670 train_time:54087ms step_avg:96.93ms +step:559/1670 train_time:54183ms step_avg:96.93ms +step:560/1670 train_time:54280ms step_avg:96.93ms +step:561/1670 train_time:54376ms step_avg:96.93ms +step:562/1670 train_time:54474ms step_avg:96.93ms +step:563/1670 train_time:54571ms step_avg:96.93ms +step:564/1670 
train_time:54668ms step_avg:96.93ms +step:565/1670 train_time:54764ms step_avg:96.93ms +step:566/1670 train_time:54861ms step_avg:96.93ms +step:567/1670 train_time:54958ms step_avg:96.93ms +step:568/1670 train_time:55056ms step_avg:96.93ms +step:569/1670 train_time:55154ms step_avg:96.93ms +step:570/1670 train_time:55253ms step_avg:96.93ms +step:571/1670 train_time:55350ms step_avg:96.94ms +step:572/1670 train_time:55447ms step_avg:96.94ms +step:573/1670 train_time:55543ms step_avg:96.93ms +step:574/1670 train_time:55639ms step_avg:96.93ms +step:575/1670 train_time:55738ms step_avg:96.93ms +step:576/1670 train_time:55836ms step_avg:96.94ms +step:577/1670 train_time:55933ms step_avg:96.94ms +step:578/1670 train_time:56030ms step_avg:96.94ms +step:579/1670 train_time:56128ms step_avg:96.94ms +step:580/1670 train_time:56225ms step_avg:96.94ms +step:581/1670 train_time:56322ms step_avg:96.94ms +step:582/1670 train_time:56418ms step_avg:96.94ms +step:583/1670 train_time:56515ms step_avg:96.94ms +step:584/1670 train_time:56612ms step_avg:96.94ms +step:585/1670 train_time:56710ms step_avg:96.94ms +step:586/1670 train_time:56808ms step_avg:96.94ms +step:587/1670 train_time:56905ms step_avg:96.94ms +step:588/1670 train_time:57001ms step_avg:96.94ms +step:589/1670 train_time:57098ms step_avg:96.94ms +step:590/1670 train_time:57195ms step_avg:96.94ms +step:591/1670 train_time:57293ms step_avg:96.94ms +step:592/1670 train_time:57391ms step_avg:96.94ms +step:593/1670 train_time:57489ms step_avg:96.95ms +step:594/1670 train_time:57586ms step_avg:96.95ms +step:595/1670 train_time:57682ms step_avg:96.94ms +step:596/1670 train_time:57779ms step_avg:96.94ms +step:597/1670 train_time:57876ms step_avg:96.95ms +step:598/1670 train_time:57974ms step_avg:96.95ms +step:599/1670 train_time:58072ms step_avg:96.95ms +step:600/1670 train_time:58169ms step_avg:96.95ms +step:601/1670 train_time:58266ms step_avg:96.95ms +step:602/1670 train_time:58362ms step_avg:96.95ms +step:603/1670 train_time:58459ms step_avg:96.95ms +step:604/1670 train_time:58556ms step_avg:96.95ms +step:605/1670 train_time:58654ms step_avg:96.95ms +step:606/1670 train_time:58752ms step_avg:96.95ms +step:607/1670 train_time:58850ms step_avg:96.95ms +step:608/1670 train_time:58947ms step_avg:96.95ms +step:609/1670 train_time:59043ms step_avg:96.95ms +step:610/1670 train_time:59140ms step_avg:96.95ms +step:611/1670 train_time:59237ms step_avg:96.95ms +step:612/1670 train_time:59335ms step_avg:96.95ms +step:613/1670 train_time:59432ms step_avg:96.95ms +step:614/1670 train_time:59530ms step_avg:96.95ms +step:615/1670 train_time:59627ms step_avg:96.95ms +step:616/1670 train_time:59723ms step_avg:96.95ms +step:617/1670 train_time:59820ms step_avg:96.95ms +step:618/1670 train_time:59918ms step_avg:96.95ms +step:619/1670 train_time:60016ms step_avg:96.96ms +step:620/1670 train_time:60114ms step_avg:96.96ms +step:621/1670 train_time:60212ms step_avg:96.96ms +step:622/1670 train_time:60309ms step_avg:96.96ms +step:623/1670 train_time:60405ms step_avg:96.96ms +step:624/1670 train_time:60502ms step_avg:96.96ms +step:625/1670 train_time:60598ms step_avg:96.96ms +step:625/1670 val_loss:3.6145 train_time:60696ms step_avg:97.11ms +step:626/1670 train_time:60718ms step_avg:96.99ms +step:627/1670 train_time:60807ms step_avg:96.98ms +step:628/1670 train_time:60904ms step_avg:96.98ms +step:629/1670 train_time:61000ms step_avg:96.98ms +step:630/1670 train_time:61096ms step_avg:96.98ms +step:631/1670 train_time:61192ms step_avg:96.98ms +step:632/1670 train_time:61288ms 
step_avg:96.97ms +step:633/1670 train_time:61383ms step_avg:96.97ms +step:634/1670 train_time:61478ms step_avg:96.97ms +step:635/1670 train_time:61575ms step_avg:96.97ms +step:636/1670 train_time:61674ms step_avg:96.97ms +step:637/1670 train_time:61778ms step_avg:96.98ms +step:638/1670 train_time:61878ms step_avg:96.99ms +step:639/1670 train_time:62212ms step_avg:97.36ms +step:640/1670 train_time:62370ms step_avg:97.45ms +step:641/1670 train_time:62465ms step_avg:97.45ms +step:642/1670 train_time:62561ms step_avg:97.45ms +step:643/1670 train_time:62657ms step_avg:97.44ms +step:644/1670 train_time:62753ms step_avg:97.44ms +step:645/1670 train_time:62848ms step_avg:97.44ms +step:646/1670 train_time:62944ms step_avg:97.44ms +step:647/1670 train_time:63040ms step_avg:97.43ms +step:648/1670 train_time:63137ms step_avg:97.43ms +step:649/1670 train_time:63236ms step_avg:97.44ms +step:650/1670 train_time:63337ms step_avg:97.44ms +step:651/1670 train_time:63438ms step_avg:97.45ms +step:652/1670 train_time:63536ms step_avg:97.45ms +step:653/1670 train_time:63633ms step_avg:97.45ms +step:654/1670 train_time:63729ms step_avg:97.45ms +step:655/1670 train_time:63825ms step_avg:97.44ms +step:656/1670 train_time:63921ms step_avg:97.44ms +step:657/1670 train_time:64017ms step_avg:97.44ms +step:658/1670 train_time:64114ms step_avg:97.44ms +step:659/1670 train_time:64212ms step_avg:97.44ms +step:660/1670 train_time:64311ms step_avg:97.44ms +step:661/1670 train_time:64409ms step_avg:97.44ms +step:662/1670 train_time:64506ms step_avg:97.44ms +step:663/1670 train_time:64603ms step_avg:97.44ms +step:664/1670 train_time:64700ms step_avg:97.44ms +step:665/1670 train_time:64797ms step_avg:97.44ms +step:666/1670 train_time:64894ms step_avg:97.44ms +step:667/1670 train_time:64990ms step_avg:97.44ms +step:668/1670 train_time:65087ms step_avg:97.44ms +step:669/1670 train_time:65184ms step_avg:97.43ms +step:670/1670 train_time:65280ms step_avg:97.43ms +step:671/1670 train_time:65378ms step_avg:97.43ms +step:672/1670 train_time:65478ms step_avg:97.44ms +step:673/1670 train_time:65576ms step_avg:97.44ms +step:674/1670 train_time:65673ms step_avg:97.44ms +step:675/1670 train_time:65770ms step_avg:97.44ms +step:676/1670 train_time:65866ms step_avg:97.44ms +step:677/1670 train_time:65962ms step_avg:97.43ms +step:678/1670 train_time:66059ms step_avg:97.43ms +step:679/1670 train_time:66156ms step_avg:97.43ms +step:680/1670 train_time:66253ms step_avg:97.43ms +step:681/1670 train_time:66350ms step_avg:97.43ms +step:682/1670 train_time:66447ms step_avg:97.43ms +step:683/1670 train_time:66543ms step_avg:97.43ms +step:684/1670 train_time:66640ms step_avg:97.43ms +step:685/1670 train_time:66738ms step_avg:97.43ms +step:686/1670 train_time:66835ms step_avg:97.43ms +step:687/1670 train_time:66932ms step_avg:97.43ms +step:688/1670 train_time:67029ms step_avg:97.43ms +step:689/1670 train_time:67125ms step_avg:97.42ms +step:690/1670 train_time:67221ms step_avg:97.42ms +step:691/1670 train_time:67319ms step_avg:97.42ms +step:692/1670 train_time:67416ms step_avg:97.42ms +step:693/1670 train_time:67514ms step_avg:97.42ms +step:694/1670 train_time:67612ms step_avg:97.42ms +step:695/1670 train_time:67709ms step_avg:97.42ms +step:696/1670 train_time:67806ms step_avg:97.42ms +step:697/1670 train_time:67903ms step_avg:97.42ms +step:698/1670 train_time:68000ms step_avg:97.42ms +step:699/1670 train_time:68098ms step_avg:97.42ms +step:700/1670 train_time:68196ms step_avg:97.42ms +step:701/1670 train_time:68292ms step_avg:97.42ms +step:702/1670 
train_time:68389ms step_avg:97.42ms +step:703/1670 train_time:68488ms step_avg:97.42ms +step:704/1670 train_time:68584ms step_avg:97.42ms +step:705/1670 train_time:68681ms step_avg:97.42ms +step:706/1670 train_time:68778ms step_avg:97.42ms +step:707/1670 train_time:68875ms step_avg:97.42ms +step:708/1670 train_time:68972ms step_avg:97.42ms +step:709/1670 train_time:69069ms step_avg:97.42ms +step:710/1670 train_time:69167ms step_avg:97.42ms +step:711/1670 train_time:69265ms step_avg:97.42ms +step:712/1670 train_time:69361ms step_avg:97.42ms +step:713/1670 train_time:69459ms step_avg:97.42ms +step:714/1670 train_time:69556ms step_avg:97.42ms +step:715/1670 train_time:69654ms step_avg:97.42ms +step:716/1670 train_time:69751ms step_avg:97.42ms +step:717/1670 train_time:69847ms step_avg:97.42ms +step:718/1670 train_time:69943ms step_avg:97.41ms +step:719/1670 train_time:70040ms step_avg:97.41ms +step:720/1670 train_time:70138ms step_avg:97.41ms +step:721/1670 train_time:70235ms step_avg:97.41ms +step:722/1670 train_time:70333ms step_avg:97.41ms +step:723/1670 train_time:70431ms step_avg:97.42ms +step:724/1670 train_time:70528ms step_avg:97.42ms +step:725/1670 train_time:70626ms step_avg:97.41ms +step:726/1670 train_time:70722ms step_avg:97.41ms +step:727/1670 train_time:70820ms step_avg:97.41ms +step:728/1670 train_time:70918ms step_avg:97.41ms +step:729/1670 train_time:71016ms step_avg:97.42ms +step:730/1670 train_time:71113ms step_avg:97.41ms +step:731/1670 train_time:71211ms step_avg:97.42ms +step:732/1670 train_time:71307ms step_avg:97.41ms +step:733/1670 train_time:71404ms step_avg:97.41ms +step:734/1670 train_time:71502ms step_avg:97.41ms +step:735/1670 train_time:71600ms step_avg:97.41ms +step:736/1670 train_time:71697ms step_avg:97.41ms +step:737/1670 train_time:71795ms step_avg:97.41ms +step:738/1670 train_time:71891ms step_avg:97.41ms +step:739/1670 train_time:71989ms step_avg:97.41ms +step:740/1670 train_time:72085ms step_avg:97.41ms +step:741/1670 train_time:72181ms step_avg:97.41ms +step:742/1670 train_time:72279ms step_avg:97.41ms +step:743/1670 train_time:72377ms step_avg:97.41ms +step:744/1670 train_time:72474ms step_avg:97.41ms +step:745/1670 train_time:72572ms step_avg:97.41ms +step:746/1670 train_time:72668ms step_avg:97.41ms +step:747/1670 train_time:72764ms step_avg:97.41ms +step:748/1670 train_time:72860ms step_avg:97.41ms +step:749/1670 train_time:72958ms step_avg:97.41ms +step:750/1670 train_time:73057ms step_avg:97.41ms +step:750/1670 val_loss:3.5614 train_time:73154ms step_avg:97.54ms +step:751/1670 train_time:73175ms step_avg:97.44ms +step:752/1670 train_time:73257ms step_avg:97.42ms +step:753/1670 train_time:73358ms step_avg:97.42ms +step:754/1670 train_time:73456ms step_avg:97.42ms +step:755/1670 train_time:73553ms step_avg:97.42ms +step:756/1670 train_time:73650ms step_avg:97.42ms +step:757/1670 train_time:73746ms step_avg:97.42ms +step:758/1670 train_time:73842ms step_avg:97.42ms +step:759/1670 train_time:73938ms step_avg:97.42ms +step:760/1670 train_time:74035ms step_avg:97.41ms +step:761/1670 train_time:74133ms step_avg:97.41ms +step:762/1670 train_time:74232ms step_avg:97.42ms +step:763/1670 train_time:74332ms step_avg:97.42ms +step:764/1670 train_time:74429ms step_avg:97.42ms +step:765/1670 train_time:74526ms step_avg:97.42ms +step:766/1670 train_time:74623ms step_avg:97.42ms +step:767/1670 train_time:74719ms step_avg:97.42ms +step:768/1670 train_time:74816ms step_avg:97.42ms +step:769/1670 train_time:74913ms step_avg:97.42ms +step:770/1670 train_time:75009ms 
step_avg:97.41ms +step:771/1670 train_time:75106ms step_avg:97.41ms +step:772/1670 train_time:75204ms step_avg:97.41ms +step:773/1670 train_time:75301ms step_avg:97.41ms +step:774/1670 train_time:75398ms step_avg:97.41ms +step:775/1670 train_time:75497ms step_avg:97.42ms +step:776/1670 train_time:75595ms step_avg:97.42ms +step:777/1670 train_time:75693ms step_avg:97.42ms +step:778/1670 train_time:75790ms step_avg:97.42ms +step:779/1670 train_time:75886ms step_avg:97.41ms +step:780/1670 train_time:75982ms step_avg:97.41ms +step:781/1670 train_time:76078ms step_avg:97.41ms +step:782/1670 train_time:76177ms step_avg:97.41ms +step:783/1670 train_time:76275ms step_avg:97.41ms +step:784/1670 train_time:76372ms step_avg:97.41ms +step:785/1670 train_time:76470ms step_avg:97.41ms +step:786/1670 train_time:76567ms step_avg:97.41ms +step:787/1670 train_time:76664ms step_avg:97.41ms +step:788/1670 train_time:76761ms step_avg:97.41ms +step:789/1670 train_time:76858ms step_avg:97.41ms +step:790/1670 train_time:76955ms step_avg:97.41ms +step:791/1670 train_time:77052ms step_avg:97.41ms +step:792/1670 train_time:77149ms step_avg:97.41ms +step:793/1670 train_time:77246ms step_avg:97.41ms +step:794/1670 train_time:77343ms step_avg:97.41ms +step:795/1670 train_time:77440ms step_avg:97.41ms +step:796/1670 train_time:77539ms step_avg:97.41ms +step:797/1670 train_time:77637ms step_avg:97.41ms +step:798/1670 train_time:77734ms step_avg:97.41ms +step:799/1670 train_time:77831ms step_avg:97.41ms +step:800/1670 train_time:77929ms step_avg:97.41ms +step:801/1670 train_time:78026ms step_avg:97.41ms +step:802/1670 train_time:78122ms step_avg:97.41ms +step:803/1670 train_time:78219ms step_avg:97.41ms +step:804/1670 train_time:78316ms step_avg:97.41ms +step:805/1670 train_time:78413ms step_avg:97.41ms +step:806/1670 train_time:78510ms step_avg:97.41ms +step:807/1670 train_time:78608ms step_avg:97.41ms +step:808/1670 train_time:78704ms step_avg:97.41ms +step:809/1670 train_time:78801ms step_avg:97.40ms +step:810/1670 train_time:78898ms step_avg:97.40ms +step:811/1670 train_time:78995ms step_avg:97.40ms +step:812/1670 train_time:79093ms step_avg:97.40ms +step:813/1670 train_time:79190ms step_avg:97.40ms +step:814/1670 train_time:79287ms step_avg:97.40ms +step:815/1670 train_time:79384ms step_avg:97.40ms +step:816/1670 train_time:79480ms step_avg:97.40ms +step:817/1670 train_time:79578ms step_avg:97.40ms +step:818/1670 train_time:79677ms step_avg:97.40ms +step:819/1670 train_time:79775ms step_avg:97.41ms +step:820/1670 train_time:79872ms step_avg:97.41ms +step:821/1670 train_time:79970ms step_avg:97.41ms +step:822/1670 train_time:80067ms step_avg:97.40ms +step:823/1670 train_time:80163ms step_avg:97.40ms +step:824/1670 train_time:80259ms step_avg:97.40ms +step:825/1670 train_time:80356ms step_avg:97.40ms +step:826/1670 train_time:80454ms step_avg:97.40ms +step:827/1670 train_time:80551ms step_avg:97.40ms +step:828/1670 train_time:80649ms step_avg:97.40ms +step:829/1670 train_time:80747ms step_avg:97.40ms +step:830/1670 train_time:80843ms step_avg:97.40ms +step:831/1670 train_time:80940ms step_avg:97.40ms +step:832/1670 train_time:81038ms step_avg:97.40ms +step:833/1670 train_time:81136ms step_avg:97.40ms +step:834/1670 train_time:81233ms step_avg:97.40ms +step:835/1670 train_time:81330ms step_avg:97.40ms +step:836/1670 train_time:81427ms step_avg:97.40ms +step:837/1670 train_time:81524ms step_avg:97.40ms +step:838/1670 train_time:81621ms step_avg:97.40ms +step:839/1670 train_time:81718ms step_avg:97.40ms +step:840/1670 
train_time:81817ms step_avg:97.40ms +step:841/1670 train_time:81914ms step_avg:97.40ms +step:842/1670 train_time:82011ms step_avg:97.40ms +step:843/1670 train_time:82109ms step_avg:97.40ms +step:844/1670 train_time:82205ms step_avg:97.40ms +step:845/1670 train_time:82301ms step_avg:97.40ms +step:846/1670 train_time:82398ms step_avg:97.40ms +step:847/1670 train_time:82496ms step_avg:97.40ms +step:848/1670 train_time:82594ms step_avg:97.40ms +step:849/1670 train_time:82690ms step_avg:97.40ms +step:850/1670 train_time:82787ms step_avg:97.40ms +step:851/1670 train_time:83055ms step_avg:97.60ms +step:852/1670 train_time:83197ms step_avg:97.65ms +step:853/1670 train_time:83292ms step_avg:97.65ms +step:854/1670 train_time:83388ms step_avg:97.64ms +step:855/1670 train_time:83483ms step_avg:97.64ms +step:856/1670 train_time:83579ms step_avg:97.64ms +step:857/1670 train_time:83674ms step_avg:97.64ms +step:858/1670 train_time:83770ms step_avg:97.63ms +step:859/1670 train_time:83865ms step_avg:97.63ms +step:860/1670 train_time:83961ms step_avg:97.63ms +step:861/1670 train_time:84058ms step_avg:97.63ms +step:862/1670 train_time:84161ms step_avg:97.63ms +step:863/1670 train_time:84262ms step_avg:97.64ms +step:864/1670 train_time:84358ms step_avg:97.64ms +step:865/1670 train_time:84455ms step_avg:97.64ms +step:866/1670 train_time:84551ms step_avg:97.63ms +step:867/1670 train_time:84647ms step_avg:97.63ms +step:868/1670 train_time:84743ms step_avg:97.63ms +step:869/1670 train_time:84838ms step_avg:97.63ms +step:870/1670 train_time:84935ms step_avg:97.63ms +step:871/1670 train_time:85032ms step_avg:97.63ms +step:872/1670 train_time:85131ms step_avg:97.63ms +step:873/1670 train_time:85231ms step_avg:97.63ms +step:874/1670 train_time:85329ms step_avg:97.63ms +step:875/1670 train_time:85425ms step_avg:97.63ms +step:875/1670 val_loss:3.5194 train_time:85521ms step_avg:97.74ms +step:876/1670 train_time:85542ms step_avg:97.65ms +step:877/1670 train_time:85623ms step_avg:97.63ms +step:878/1670 train_time:85720ms step_avg:97.63ms +step:879/1670 train_time:85817ms step_avg:97.63ms +step:880/1670 train_time:85913ms step_avg:97.63ms +step:881/1670 train_time:86009ms step_avg:97.63ms +step:882/1670 train_time:86104ms step_avg:97.62ms +step:883/1670 train_time:86200ms step_avg:97.62ms +step:884/1670 train_time:86296ms step_avg:97.62ms +step:885/1670 train_time:86392ms step_avg:97.62ms +step:886/1670 train_time:86493ms step_avg:97.62ms +step:887/1670 train_time:86593ms step_avg:97.62ms +step:888/1670 train_time:86694ms step_avg:97.63ms +step:889/1670 train_time:86791ms step_avg:97.63ms +step:890/1670 train_time:86889ms step_avg:97.63ms +step:891/1670 train_time:86986ms step_avg:97.63ms +step:892/1670 train_time:87081ms step_avg:97.62ms +step:893/1670 train_time:87177ms step_avg:97.62ms +step:894/1670 train_time:87273ms step_avg:97.62ms +step:895/1670 train_time:87370ms step_avg:97.62ms +step:896/1670 train_time:87468ms step_avg:97.62ms +step:897/1670 train_time:87567ms step_avg:97.62ms +step:898/1670 train_time:87665ms step_avg:97.62ms +step:899/1670 train_time:87762ms step_avg:97.62ms +step:900/1670 train_time:87859ms step_avg:97.62ms +step:901/1670 train_time:87957ms step_avg:97.62ms +step:902/1670 train_time:88054ms step_avg:97.62ms +step:903/1670 train_time:88151ms step_avg:97.62ms +step:904/1670 train_time:88248ms step_avg:97.62ms +step:905/1670 train_time:88344ms step_avg:97.62ms +step:906/1670 train_time:88441ms step_avg:97.62ms +step:907/1670 train_time:88538ms step_avg:97.62ms +step:908/1670 train_time:88635ms 
step_avg:97.62ms +step:909/1670 train_time:88733ms step_avg:97.62ms +step:910/1670 train_time:88831ms step_avg:97.62ms +step:911/1670 train_time:88929ms step_avg:97.62ms +step:912/1670 train_time:89026ms step_avg:97.62ms +step:913/1670 train_time:89122ms step_avg:97.61ms +step:914/1670 train_time:89218ms step_avg:97.61ms +step:915/1670 train_time:89315ms step_avg:97.61ms +step:916/1670 train_time:89412ms step_avg:97.61ms +step:917/1670 train_time:89510ms step_avg:97.61ms +step:918/1670 train_time:89607ms step_avg:97.61ms +step:919/1670 train_time:89704ms step_avg:97.61ms +step:920/1670 train_time:89801ms step_avg:97.61ms +step:921/1670 train_time:89898ms step_avg:97.61ms +step:922/1670 train_time:89995ms step_avg:97.61ms +step:923/1670 train_time:90093ms step_avg:97.61ms +step:924/1670 train_time:90190ms step_avg:97.61ms +step:925/1670 train_time:90287ms step_avg:97.61ms +step:926/1670 train_time:90384ms step_avg:97.61ms +step:927/1670 train_time:90480ms step_avg:97.61ms +step:928/1670 train_time:90577ms step_avg:97.61ms +step:929/1670 train_time:90675ms step_avg:97.60ms +step:930/1670 train_time:90773ms step_avg:97.61ms +step:931/1670 train_time:90871ms step_avg:97.61ms +step:932/1670 train_time:90967ms step_avg:97.60ms +step:933/1670 train_time:91064ms step_avg:97.60ms +step:934/1670 train_time:91160ms step_avg:97.60ms +step:935/1670 train_time:91256ms step_avg:97.60ms +step:936/1670 train_time:91356ms step_avg:97.60ms +step:937/1670 train_time:91454ms step_avg:97.60ms +step:938/1670 train_time:91552ms step_avg:97.60ms +step:939/1670 train_time:91649ms step_avg:97.60ms +step:940/1670 train_time:91746ms step_avg:97.60ms +step:941/1670 train_time:91842ms step_avg:97.60ms +step:942/1670 train_time:91940ms step_avg:97.60ms +step:943/1670 train_time:92037ms step_avg:97.60ms +step:944/1670 train_time:92134ms step_avg:97.60ms +step:945/1670 train_time:92231ms step_avg:97.60ms +step:946/1670 train_time:92329ms step_avg:97.60ms +step:947/1670 train_time:92426ms step_avg:97.60ms +step:948/1670 train_time:92522ms step_avg:97.60ms +step:949/1670 train_time:92619ms step_avg:97.60ms +step:950/1670 train_time:92716ms step_avg:97.60ms +step:951/1670 train_time:92814ms step_avg:97.60ms +step:952/1670 train_time:92912ms step_avg:97.60ms +step:953/1670 train_time:93009ms step_avg:97.60ms +step:954/1670 train_time:93107ms step_avg:97.60ms +step:955/1670 train_time:93203ms step_avg:97.60ms +step:956/1670 train_time:93300ms step_avg:97.59ms +step:957/1670 train_time:93396ms step_avg:97.59ms +step:958/1670 train_time:93494ms step_avg:97.59ms +step:959/1670 train_time:93592ms step_avg:97.59ms +step:960/1670 train_time:93691ms step_avg:97.59ms +step:961/1670 train_time:93789ms step_avg:97.59ms +step:962/1670 train_time:93886ms step_avg:97.59ms +step:963/1670 train_time:93982ms step_avg:97.59ms +step:964/1670 train_time:94079ms step_avg:97.59ms +step:965/1670 train_time:94177ms step_avg:97.59ms +step:966/1670 train_time:94273ms step_avg:97.59ms +step:967/1670 train_time:94370ms step_avg:97.59ms +step:968/1670 train_time:94468ms step_avg:97.59ms +step:969/1670 train_time:94565ms step_avg:97.59ms +step:970/1670 train_time:94662ms step_avg:97.59ms +step:971/1670 train_time:94759ms step_avg:97.59ms +step:972/1670 train_time:94857ms step_avg:97.59ms +step:973/1670 train_time:94955ms step_avg:97.59ms +step:974/1670 train_time:95054ms step_avg:97.59ms +step:975/1670 train_time:95151ms step_avg:97.59ms +step:976/1670 train_time:95249ms step_avg:97.59ms +step:977/1670 train_time:95345ms step_avg:97.59ms +step:978/1670 
train_time:95442ms step_avg:97.59ms +step:979/1670 train_time:95539ms step_avg:97.59ms +step:980/1670 train_time:95636ms step_avg:97.59ms +step:981/1670 train_time:95734ms step_avg:97.59ms +step:982/1670 train_time:95833ms step_avg:97.59ms +step:983/1670 train_time:95931ms step_avg:97.59ms +step:984/1670 train_time:96029ms step_avg:97.59ms +step:985/1670 train_time:96126ms step_avg:97.59ms +step:986/1670 train_time:96222ms step_avg:97.59ms +step:987/1670 train_time:96319ms step_avg:97.59ms +step:988/1670 train_time:96416ms step_avg:97.59ms +step:989/1670 train_time:96514ms step_avg:97.59ms +step:990/1670 train_time:96611ms step_avg:97.59ms +step:991/1670 train_time:96708ms step_avg:97.59ms +step:992/1670 train_time:96805ms step_avg:97.59ms +step:993/1670 train_time:96901ms step_avg:97.58ms +step:994/1670 train_time:96998ms step_avg:97.58ms +step:995/1670 train_time:97096ms step_avg:97.58ms +step:996/1670 train_time:97195ms step_avg:97.59ms +step:997/1670 train_time:97292ms step_avg:97.58ms +step:998/1670 train_time:97390ms step_avg:97.59ms +step:999/1670 train_time:97488ms step_avg:97.59ms +step:1000/1670 train_time:97585ms step_avg:97.58ms +step:1000/1670 val_loss:3.4778 train_time:97680ms step_avg:97.68ms +step:1001/1670 train_time:97701ms step_avg:97.60ms +step:1002/1670 train_time:97783ms step_avg:97.59ms +step:1003/1670 train_time:97882ms step_avg:97.59ms +step:1004/1670 train_time:97980ms step_avg:97.59ms +step:1005/1670 train_time:98076ms step_avg:97.59ms +step:1006/1670 train_time:98172ms step_avg:97.59ms +step:1007/1670 train_time:98268ms step_avg:97.58ms +step:1008/1670 train_time:98364ms step_avg:97.58ms +step:1009/1670 train_time:98462ms step_avg:97.58ms +step:1010/1670 train_time:98559ms step_avg:97.58ms +step:1011/1670 train_time:98658ms step_avg:97.58ms +step:1012/1670 train_time:98757ms step_avg:97.59ms +step:1013/1670 train_time:98856ms step_avg:97.59ms +step:1014/1670 train_time:98954ms step_avg:97.59ms +step:1015/1670 train_time:99050ms step_avg:97.59ms +step:1016/1670 train_time:99146ms step_avg:97.58ms +step:1017/1670 train_time:99242ms step_avg:97.58ms +step:1018/1670 train_time:99339ms step_avg:97.58ms +step:1019/1670 train_time:99435ms step_avg:97.58ms +step:1020/1670 train_time:99531ms step_avg:97.58ms +step:1021/1670 train_time:99628ms step_avg:97.58ms +step:1022/1670 train_time:99725ms step_avg:97.58ms +step:1023/1670 train_time:99823ms step_avg:97.58ms +step:1024/1670 train_time:99922ms step_avg:97.58ms +step:1025/1670 train_time:100020ms step_avg:97.58ms +step:1026/1670 train_time:100117ms step_avg:97.58ms +step:1027/1670 train_time:100214ms step_avg:97.58ms +step:1028/1670 train_time:100311ms step_avg:97.58ms +step:1029/1670 train_time:100406ms step_avg:97.58ms +step:1030/1670 train_time:100503ms step_avg:97.58ms +step:1031/1670 train_time:100601ms step_avg:97.58ms +step:1032/1670 train_time:100698ms step_avg:97.58ms +step:1033/1670 train_time:100796ms step_avg:97.58ms +step:1034/1670 train_time:100894ms step_avg:97.58ms +step:1035/1670 train_time:100991ms step_avg:97.58ms +step:1036/1670 train_time:101088ms step_avg:97.58ms +step:1037/1670 train_time:101186ms step_avg:97.58ms +step:1038/1670 train_time:101282ms step_avg:97.57ms +step:1039/1670 train_time:101380ms step_avg:97.57ms +step:1040/1670 train_time:101477ms step_avg:97.57ms +step:1041/1670 train_time:101574ms step_avg:97.57ms +step:1042/1670 train_time:101671ms step_avg:97.57ms +step:1043/1670 train_time:101767ms step_avg:97.57ms +step:1044/1670 train_time:101865ms step_avg:97.57ms +step:1045/1670 
train_time:101964ms step_avg:97.57ms +step:1046/1670 train_time:102062ms step_avg:97.57ms +step:1047/1670 train_time:102160ms step_avg:97.57ms +step:1048/1670 train_time:102257ms step_avg:97.57ms +step:1049/1670 train_time:102353ms step_avg:97.57ms +step:1050/1670 train_time:102451ms step_avg:97.57ms +step:1051/1670 train_time:102547ms step_avg:97.57ms +step:1052/1670 train_time:102643ms step_avg:97.57ms +step:1053/1670 train_time:102741ms step_avg:97.57ms +step:1054/1670 train_time:102838ms step_avg:97.57ms +step:1055/1670 train_time:102936ms step_avg:97.57ms +step:1056/1670 train_time:103034ms step_avg:97.57ms +step:1057/1670 train_time:103132ms step_avg:97.57ms +step:1058/1670 train_time:103228ms step_avg:97.57ms +step:1059/1670 train_time:103325ms step_avg:97.57ms +step:1060/1670 train_time:103423ms step_avg:97.57ms +step:1061/1670 train_time:103520ms step_avg:97.57ms +step:1062/1670 train_time:103794ms step_avg:97.73ms +step:1063/1670 train_time:103894ms step_avg:97.74ms +step:1064/1670 train_time:103989ms step_avg:97.73ms +step:1065/1670 train_time:104086ms step_avg:97.73ms +step:1066/1670 train_time:104182ms step_avg:97.73ms +step:1067/1670 train_time:104278ms step_avg:97.73ms +step:1068/1670 train_time:104374ms step_avg:97.73ms +step:1069/1670 train_time:104470ms step_avg:97.73ms +step:1070/1670 train_time:104566ms step_avg:97.73ms +step:1071/1670 train_time:104662ms step_avg:97.72ms +step:1072/1670 train_time:104765ms step_avg:97.73ms +step:1073/1670 train_time:104865ms step_avg:97.73ms +step:1074/1670 train_time:104964ms step_avg:97.73ms +step:1075/1670 train_time:105062ms step_avg:97.73ms +step:1076/1670 train_time:105160ms step_avg:97.73ms +step:1077/1670 train_time:105256ms step_avg:97.73ms +step:1078/1670 train_time:105353ms step_avg:97.73ms +step:1079/1670 train_time:105449ms step_avg:97.73ms +step:1080/1670 train_time:105545ms step_avg:97.73ms +step:1081/1670 train_time:105643ms step_avg:97.73ms +step:1082/1670 train_time:105743ms step_avg:97.73ms +step:1083/1670 train_time:105843ms step_avg:97.73ms +step:1084/1670 train_time:105941ms step_avg:97.73ms +step:1085/1670 train_time:106039ms step_avg:97.73ms +step:1086/1670 train_time:106135ms step_avg:97.73ms +step:1087/1670 train_time:106231ms step_avg:97.73ms +step:1088/1670 train_time:106327ms step_avg:97.73ms +step:1089/1670 train_time:106423ms step_avg:97.73ms +step:1090/1670 train_time:106520ms step_avg:97.73ms +step:1091/1670 train_time:106619ms step_avg:97.73ms +step:1092/1670 train_time:106718ms step_avg:97.73ms +step:1093/1670 train_time:106816ms step_avg:97.73ms +step:1094/1670 train_time:106915ms step_avg:97.73ms +step:1095/1670 train_time:107014ms step_avg:97.73ms +step:1096/1670 train_time:107110ms step_avg:97.73ms +step:1097/1670 train_time:107206ms step_avg:97.73ms +step:1098/1670 train_time:107302ms step_avg:97.73ms +step:1099/1670 train_time:107399ms step_avg:97.72ms +step:1100/1670 train_time:107495ms step_avg:97.72ms +step:1101/1670 train_time:107592ms step_avg:97.72ms +step:1102/1670 train_time:107689ms step_avg:97.72ms +step:1103/1670 train_time:107786ms step_avg:97.72ms +step:1104/1670 train_time:107884ms step_avg:97.72ms +step:1105/1670 train_time:107982ms step_avg:97.72ms +step:1106/1670 train_time:108080ms step_avg:97.72ms +step:1107/1670 train_time:108178ms step_avg:97.72ms +step:1108/1670 train_time:108274ms step_avg:97.72ms +step:1109/1670 train_time:108371ms step_avg:97.72ms +step:1110/1670 train_time:108466ms step_avg:97.72ms +step:1111/1670 train_time:108563ms step_avg:97.72ms +step:1112/1670 
train_time:108661ms step_avg:97.72ms +step:1113/1670 train_time:108759ms step_avg:97.72ms +step:1114/1670 train_time:108857ms step_avg:97.72ms +step:1115/1670 train_time:108955ms step_avg:97.72ms +step:1116/1670 train_time:109054ms step_avg:97.72ms +step:1117/1670 train_time:109151ms step_avg:97.72ms +step:1118/1670 train_time:109249ms step_avg:97.72ms +step:1119/1670 train_time:109346ms step_avg:97.72ms +step:1120/1670 train_time:109443ms step_avg:97.72ms +step:1121/1670 train_time:109542ms step_avg:97.72ms +step:1122/1670 train_time:109640ms step_avg:97.72ms +step:1123/1670 train_time:109738ms step_avg:97.72ms +step:1124/1670 train_time:109837ms step_avg:97.72ms +step:1125/1670 train_time:109935ms step_avg:97.72ms +step:1125/1670 val_loss:3.4238 train_time:110032ms step_avg:97.81ms +step:1126/1670 train_time:110054ms step_avg:97.74ms +step:1127/1670 train_time:110142ms step_avg:97.73ms +step:1128/1670 train_time:110240ms step_avg:97.73ms +step:1129/1670 train_time:110337ms step_avg:97.73ms +step:1130/1670 train_time:110433ms step_avg:97.73ms +step:1131/1670 train_time:110529ms step_avg:97.73ms +step:1132/1670 train_time:110625ms step_avg:97.73ms +step:1133/1670 train_time:110722ms step_avg:97.72ms +step:1134/1670 train_time:110819ms step_avg:97.72ms +step:1135/1670 train_time:110918ms step_avg:97.73ms +step:1136/1670 train_time:111022ms step_avg:97.73ms +step:1137/1670 train_time:111122ms step_avg:97.73ms +step:1138/1670 train_time:111222ms step_avg:97.73ms +step:1139/1670 train_time:111321ms step_avg:97.74ms +step:1140/1670 train_time:111420ms step_avg:97.74ms +step:1141/1670 train_time:111517ms step_avg:97.74ms +step:1142/1670 train_time:111614ms step_avg:97.74ms +step:1143/1670 train_time:111712ms step_avg:97.74ms +step:1144/1670 train_time:111808ms step_avg:97.73ms +step:1145/1670 train_time:111905ms step_avg:97.73ms +step:1146/1670 train_time:112006ms step_avg:97.74ms +step:1147/1670 train_time:112104ms step_avg:97.74ms +step:1148/1670 train_time:112202ms step_avg:97.74ms +step:1149/1670 train_time:112302ms step_avg:97.74ms +step:1150/1670 train_time:112400ms step_avg:97.74ms +step:1151/1670 train_time:112498ms step_avg:97.74ms +step:1152/1670 train_time:112595ms step_avg:97.74ms +step:1153/1670 train_time:112691ms step_avg:97.74ms +step:1154/1670 train_time:112788ms step_avg:97.74ms +step:1155/1670 train_time:112885ms step_avg:97.74ms +step:1156/1670 train_time:112985ms step_avg:97.74ms +step:1157/1670 train_time:113085ms step_avg:97.74ms +step:1158/1670 train_time:113183ms step_avg:97.74ms +step:1159/1670 train_time:113282ms step_avg:97.74ms +step:1160/1670 train_time:113380ms step_avg:97.74ms +step:1161/1670 train_time:113479ms step_avg:97.74ms +step:1162/1670 train_time:113576ms step_avg:97.74ms +step:1163/1670 train_time:113673ms step_avg:97.74ms +step:1164/1670 train_time:113770ms step_avg:97.74ms +step:1165/1670 train_time:113867ms step_avg:97.74ms +step:1166/1670 train_time:113964ms step_avg:97.74ms +step:1167/1670 train_time:114062ms step_avg:97.74ms +step:1168/1670 train_time:114162ms step_avg:97.74ms +step:1169/1670 train_time:114261ms step_avg:97.74ms +step:1170/1670 train_time:114359ms step_avg:97.74ms +step:1171/1670 train_time:114457ms step_avg:97.74ms +step:1172/1670 train_time:114555ms step_avg:97.74ms +step:1173/1670 train_time:114652ms step_avg:97.74ms +step:1174/1670 train_time:114750ms step_avg:97.74ms +step:1175/1670 train_time:114847ms step_avg:97.74ms +step:1176/1670 train_time:114944ms step_avg:97.74ms +step:1177/1670 train_time:115042ms step_avg:97.74ms 
+step:1178/1670 train_time:115141ms step_avg:97.74ms +step:1179/1670 train_time:115240ms step_avg:97.74ms +step:1180/1670 train_time:115339ms step_avg:97.74ms +step:1181/1670 train_time:115437ms step_avg:97.75ms +step:1182/1670 train_time:115535ms step_avg:97.75ms +step:1183/1670 train_time:115633ms step_avg:97.75ms +step:1184/1670 train_time:115730ms step_avg:97.75ms +step:1185/1670 train_time:115827ms step_avg:97.74ms +step:1186/1670 train_time:115925ms step_avg:97.74ms +step:1187/1670 train_time:116022ms step_avg:97.74ms +step:1188/1670 train_time:116119ms step_avg:97.74ms +step:1189/1670 train_time:116217ms step_avg:97.74ms +step:1190/1670 train_time:116315ms step_avg:97.74ms +step:1191/1670 train_time:116412ms step_avg:97.74ms +step:1192/1670 train_time:116510ms step_avg:97.74ms +step:1193/1670 train_time:116607ms step_avg:97.74ms +step:1194/1670 train_time:116705ms step_avg:97.74ms +step:1195/1670 train_time:116803ms step_avg:97.74ms +step:1196/1670 train_time:116901ms step_avg:97.74ms +step:1197/1670 train_time:117000ms step_avg:97.74ms +step:1198/1670 train_time:117097ms step_avg:97.74ms +step:1199/1670 train_time:117195ms step_avg:97.74ms +step:1200/1670 train_time:117293ms step_avg:97.74ms +step:1201/1670 train_time:117390ms step_avg:97.74ms +step:1202/1670 train_time:117487ms step_avg:97.74ms +step:1203/1670 train_time:117585ms step_avg:97.74ms +step:1204/1670 train_time:117683ms step_avg:97.74ms +step:1205/1670 train_time:117783ms step_avg:97.75ms +step:1206/1670 train_time:117881ms step_avg:97.75ms +step:1207/1670 train_time:117979ms step_avg:97.75ms +step:1208/1670 train_time:118076ms step_avg:97.75ms +step:1209/1670 train_time:118173ms step_avg:97.74ms +step:1210/1670 train_time:118270ms step_avg:97.74ms +step:1211/1670 train_time:118367ms step_avg:97.74ms +step:1212/1670 train_time:118465ms step_avg:97.74ms +step:1213/1670 train_time:118563ms step_avg:97.74ms +step:1214/1670 train_time:118662ms step_avg:97.74ms +step:1215/1670 train_time:118760ms step_avg:97.75ms +step:1216/1670 train_time:118859ms step_avg:97.75ms +step:1217/1670 train_time:118957ms step_avg:97.75ms +step:1218/1670 train_time:119055ms step_avg:97.75ms +step:1219/1670 train_time:119152ms step_avg:97.75ms +step:1220/1670 train_time:119250ms step_avg:97.75ms +step:1221/1670 train_time:119347ms step_avg:97.75ms +step:1222/1670 train_time:119445ms step_avg:97.75ms +step:1223/1670 train_time:119543ms step_avg:97.75ms +step:1224/1670 train_time:119640ms step_avg:97.74ms +step:1225/1670 train_time:119737ms step_avg:97.74ms +step:1226/1670 train_time:119835ms step_avg:97.74ms +step:1227/1670 train_time:119932ms step_avg:97.74ms +step:1228/1670 train_time:120030ms step_avg:97.74ms +step:1229/1670 train_time:120128ms step_avg:97.74ms +step:1230/1670 train_time:120226ms step_avg:97.74ms +step:1231/1670 train_time:120324ms step_avg:97.74ms +step:1232/1670 train_time:120422ms step_avg:97.75ms +step:1233/1670 train_time:120521ms step_avg:97.75ms +step:1234/1670 train_time:120619ms step_avg:97.75ms +step:1235/1670 train_time:120717ms step_avg:97.75ms +step:1236/1670 train_time:120815ms step_avg:97.75ms +step:1237/1670 train_time:120913ms step_avg:97.75ms +step:1238/1670 train_time:121010ms step_avg:97.75ms +step:1239/1670 train_time:121107ms step_avg:97.75ms +step:1240/1670 train_time:121205ms step_avg:97.75ms +step:1241/1670 train_time:121304ms step_avg:97.75ms +step:1242/1670 train_time:121403ms step_avg:97.75ms +step:1243/1670 train_time:121501ms step_avg:97.75ms +step:1244/1670 train_time:121599ms step_avg:97.75ms 
+step:1245/1670 train_time:121697ms step_avg:97.75ms +step:1246/1670 train_time:121795ms step_avg:97.75ms +step:1247/1670 train_time:121892ms step_avg:97.75ms +step:1248/1670 train_time:121989ms step_avg:97.75ms +step:1249/1670 train_time:122087ms step_avg:97.75ms +step:1250/1670 train_time:122185ms step_avg:97.75ms +step:1250/1670 val_loss:3.3817 train_time:122282ms step_avg:97.83ms +step:1251/1670 train_time:122303ms step_avg:97.76ms +step:1252/1670 train_time:122386ms step_avg:97.75ms +step:1253/1670 train_time:122487ms step_avg:97.75ms +step:1254/1670 train_time:122585ms step_avg:97.75ms +step:1255/1670 train_time:122681ms step_avg:97.75ms +step:1256/1670 train_time:122778ms step_avg:97.75ms +step:1257/1670 train_time:122875ms step_avg:97.75ms +step:1258/1670 train_time:122972ms step_avg:97.75ms +step:1259/1670 train_time:123068ms step_avg:97.75ms +step:1260/1670 train_time:123165ms step_avg:97.75ms +step:1261/1670 train_time:123264ms step_avg:97.75ms +step:1262/1670 train_time:123362ms step_avg:97.75ms +step:1263/1670 train_time:123461ms step_avg:97.75ms +step:1264/1670 train_time:123560ms step_avg:97.75ms +step:1265/1670 train_time:123658ms step_avg:97.75ms +step:1266/1670 train_time:123755ms step_avg:97.75ms +step:1267/1670 train_time:123852ms step_avg:97.75ms +step:1268/1670 train_time:123949ms step_avg:97.75ms +step:1269/1670 train_time:124046ms step_avg:97.75ms +step:1270/1670 train_time:124143ms step_avg:97.75ms +step:1271/1670 train_time:124240ms step_avg:97.75ms +step:1272/1670 train_time:124339ms step_avg:97.75ms +step:1273/1670 train_time:124438ms step_avg:97.75ms +step:1274/1670 train_time:124725ms step_avg:97.90ms +step:1275/1670 train_time:124927ms step_avg:97.98ms +step:1276/1670 train_time:125024ms step_avg:97.98ms +step:1277/1670 train_time:125120ms step_avg:97.98ms +step:1278/1670 train_time:125217ms step_avg:97.98ms +step:1279/1670 train_time:125313ms step_avg:97.98ms +step:1280/1670 train_time:125411ms step_avg:97.98ms +step:1281/1670 train_time:125507ms step_avg:97.98ms +step:1282/1670 train_time:125604ms step_avg:97.97ms +step:1283/1670 train_time:125700ms step_avg:97.97ms +step:1284/1670 train_time:125800ms step_avg:97.97ms +step:1285/1670 train_time:125902ms step_avg:97.98ms +step:1286/1670 train_time:126001ms step_avg:97.98ms +step:1287/1670 train_time:126099ms step_avg:97.98ms +step:1288/1670 train_time:126197ms step_avg:97.98ms +step:1289/1670 train_time:126295ms step_avg:97.98ms +step:1290/1670 train_time:126393ms step_avg:97.98ms +step:1291/1670 train_time:126490ms step_avg:97.98ms +step:1292/1670 train_time:126588ms step_avg:97.98ms +step:1293/1670 train_time:126685ms step_avg:97.98ms +step:1294/1670 train_time:126783ms step_avg:97.98ms +step:1295/1670 train_time:126880ms step_avg:97.98ms +step:1296/1670 train_time:126980ms step_avg:97.98ms +step:1297/1670 train_time:127078ms step_avg:97.98ms +step:1298/1670 train_time:127175ms step_avg:97.98ms +step:1299/1670 train_time:127273ms step_avg:97.98ms +step:1300/1670 train_time:127370ms step_avg:97.98ms +step:1301/1670 train_time:127468ms step_avg:97.98ms +step:1302/1670 train_time:127565ms step_avg:97.98ms +step:1303/1670 train_time:127662ms step_avg:97.98ms +step:1304/1670 train_time:127760ms step_avg:97.98ms +step:1305/1670 train_time:127857ms step_avg:97.97ms +step:1306/1670 train_time:127956ms step_avg:97.98ms +step:1307/1670 train_time:128056ms step_avg:97.98ms +step:1308/1670 train_time:128154ms step_avg:97.98ms +step:1309/1670 train_time:128252ms step_avg:97.98ms +step:1310/1670 train_time:128350ms 
step_avg:97.98ms +step:1311/1670 train_time:128448ms step_avg:97.98ms +step:1312/1670 train_time:128544ms step_avg:97.98ms +step:1313/1670 train_time:128642ms step_avg:97.98ms +step:1314/1670 train_time:128739ms step_avg:97.98ms +step:1315/1670 train_time:128837ms step_avg:97.97ms +step:1316/1670 train_time:128934ms step_avg:97.97ms +step:1317/1670 train_time:129033ms step_avg:97.97ms +step:1318/1670 train_time:129133ms step_avg:97.98ms +step:1319/1670 train_time:129231ms step_avg:97.98ms +step:1320/1670 train_time:129329ms step_avg:97.98ms +step:1321/1670 train_time:129426ms step_avg:97.98ms +step:1322/1670 train_time:129523ms step_avg:97.98ms +step:1323/1670 train_time:129622ms step_avg:97.98ms +step:1324/1670 train_time:129718ms step_avg:97.97ms +step:1325/1670 train_time:129816ms step_avg:97.97ms +step:1326/1670 train_time:129914ms step_avg:97.97ms +step:1327/1670 train_time:130013ms step_avg:97.98ms +step:1328/1670 train_time:130111ms step_avg:97.97ms +step:1329/1670 train_time:130209ms step_avg:97.97ms +step:1330/1670 train_time:130306ms step_avg:97.97ms +step:1331/1670 train_time:130403ms step_avg:97.97ms +step:1332/1670 train_time:130500ms step_avg:97.97ms +step:1333/1670 train_time:130597ms step_avg:97.97ms +step:1334/1670 train_time:130695ms step_avg:97.97ms +step:1335/1670 train_time:130794ms step_avg:97.97ms +step:1336/1670 train_time:130892ms step_avg:97.97ms +step:1337/1670 train_time:130989ms step_avg:97.97ms +step:1338/1670 train_time:131087ms step_avg:97.97ms +step:1339/1670 train_time:131184ms step_avg:97.97ms +step:1340/1670 train_time:131281ms step_avg:97.97ms +step:1341/1670 train_time:131379ms step_avg:97.97ms +step:1342/1670 train_time:131477ms step_avg:97.97ms +step:1343/1670 train_time:131575ms step_avg:97.97ms +step:1344/1670 train_time:131674ms step_avg:97.97ms +step:1345/1670 train_time:131772ms step_avg:97.97ms +step:1346/1670 train_time:131871ms step_avg:97.97ms +step:1347/1670 train_time:131968ms step_avg:97.97ms +step:1348/1670 train_time:132066ms step_avg:97.97ms +step:1349/1670 train_time:132164ms step_avg:97.97ms +step:1350/1670 train_time:132261ms step_avg:97.97ms +step:1351/1670 train_time:132359ms step_avg:97.97ms +step:1352/1670 train_time:132457ms step_avg:97.97ms +step:1353/1670 train_time:132556ms step_avg:97.97ms +step:1354/1670 train_time:132654ms step_avg:97.97ms +step:1355/1670 train_time:132753ms step_avg:97.97ms +step:1356/1670 train_time:132852ms step_avg:97.97ms +step:1357/1670 train_time:132950ms step_avg:97.97ms +step:1358/1670 train_time:133048ms step_avg:97.97ms +step:1359/1670 train_time:133146ms step_avg:97.97ms +step:1360/1670 train_time:133245ms step_avg:97.97ms +step:1361/1670 train_time:133341ms step_avg:97.97ms +step:1362/1670 train_time:133439ms step_avg:97.97ms +step:1363/1670 train_time:133537ms step_avg:97.97ms +step:1364/1670 train_time:133635ms step_avg:97.97ms +step:1365/1670 train_time:133733ms step_avg:97.97ms +step:1366/1670 train_time:133831ms step_avg:97.97ms +step:1367/1670 train_time:133929ms step_avg:97.97ms +step:1368/1670 train_time:134027ms step_avg:97.97ms +step:1369/1670 train_time:134126ms step_avg:97.97ms +step:1370/1670 train_time:134224ms step_avg:97.97ms +step:1371/1670 train_time:134322ms step_avg:97.97ms +step:1372/1670 train_time:134419ms step_avg:97.97ms +step:1373/1670 train_time:134517ms step_avg:97.97ms +step:1374/1670 train_time:134614ms step_avg:97.97ms +step:1375/1670 train_time:134713ms step_avg:97.97ms +step:1375/1670 val_loss:3.3437 train_time:134810ms step_avg:98.04ms +step:1376/1670 
train_time:134830ms step_avg:97.99ms +step:1377/1670 train_time:134913ms step_avg:97.98ms +step:1378/1670 train_time:135013ms step_avg:97.98ms +step:1379/1670 train_time:135110ms step_avg:97.98ms +step:1380/1670 train_time:135208ms step_avg:97.98ms +step:1381/1670 train_time:135307ms step_avg:97.98ms +step:1382/1670 train_time:135404ms step_avg:97.98ms +step:1383/1670 train_time:135501ms step_avg:97.98ms +step:1384/1670 train_time:135598ms step_avg:97.98ms +step:1385/1670 train_time:135695ms step_avg:97.97ms +step:1386/1670 train_time:135793ms step_avg:97.97ms +step:1387/1670 train_time:135892ms step_avg:97.98ms +step:1388/1670 train_time:135991ms step_avg:97.98ms +step:1389/1670 train_time:136089ms step_avg:97.98ms +step:1390/1670 train_time:136187ms step_avg:97.98ms +step:1391/1670 train_time:136285ms step_avg:97.98ms +step:1392/1670 train_time:136382ms step_avg:97.98ms +step:1393/1670 train_time:136480ms step_avg:97.98ms +step:1394/1670 train_time:136577ms step_avg:97.97ms +step:1395/1670 train_time:136674ms step_avg:97.97ms +step:1396/1670 train_time:136772ms step_avg:97.97ms +step:1397/1670 train_time:136870ms step_avg:97.97ms +step:1398/1670 train_time:136968ms step_avg:97.97ms +step:1399/1670 train_time:137068ms step_avg:97.98ms +step:1400/1670 train_time:137167ms step_avg:97.98ms +step:1401/1670 train_time:137265ms step_avg:97.98ms +step:1402/1670 train_time:137363ms step_avg:97.98ms +step:1403/1670 train_time:137460ms step_avg:97.98ms +step:1404/1670 train_time:137557ms step_avg:97.98ms +step:1405/1670 train_time:137654ms step_avg:97.97ms +step:1406/1670 train_time:137751ms step_avg:97.97ms +step:1407/1670 train_time:137850ms step_avg:97.97ms +step:1408/1670 train_time:137950ms step_avg:97.98ms +step:1409/1670 train_time:138049ms step_avg:97.98ms +step:1410/1670 train_time:138148ms step_avg:97.98ms +step:1411/1670 train_time:138246ms step_avg:97.98ms +step:1412/1670 train_time:138344ms step_avg:97.98ms +step:1413/1670 train_time:138442ms step_avg:97.98ms +step:1414/1670 train_time:138541ms step_avg:97.98ms +step:1415/1670 train_time:138639ms step_avg:97.98ms +step:1416/1670 train_time:138737ms step_avg:97.98ms +step:1417/1670 train_time:138833ms step_avg:97.98ms +step:1418/1670 train_time:138932ms step_avg:97.98ms +step:1419/1670 train_time:139030ms step_avg:97.98ms +step:1420/1670 train_time:139128ms step_avg:97.98ms +step:1421/1670 train_time:139226ms step_avg:97.98ms +step:1422/1670 train_time:139325ms step_avg:97.98ms +step:1423/1670 train_time:139423ms step_avg:97.98ms +step:1424/1670 train_time:139521ms step_avg:97.98ms +step:1425/1670 train_time:139619ms step_avg:97.98ms +step:1426/1670 train_time:139717ms step_avg:97.98ms +step:1427/1670 train_time:139815ms step_avg:97.98ms +step:1428/1670 train_time:139912ms step_avg:97.98ms +step:1429/1670 train_time:140010ms step_avg:97.98ms +step:1430/1670 train_time:140109ms step_avg:97.98ms +step:1431/1670 train_time:140207ms step_avg:97.98ms +step:1432/1670 train_time:140305ms step_avg:97.98ms +step:1433/1670 train_time:140403ms step_avg:97.98ms +step:1434/1670 train_time:140501ms step_avg:97.98ms +step:1435/1670 train_time:140598ms step_avg:97.98ms +step:1436/1670 train_time:140695ms step_avg:97.98ms +step:1437/1670 train_time:140792ms step_avg:97.98ms +step:1438/1670 train_time:140890ms step_avg:97.98ms +step:1439/1670 train_time:140987ms step_avg:97.98ms +step:1440/1670 train_time:141086ms step_avg:97.98ms +step:1441/1670 train_time:141185ms step_avg:97.98ms +step:1442/1670 train_time:141285ms step_avg:97.98ms +step:1443/1670 
train_time:141383ms step_avg:97.98ms +step:1444/1670 train_time:141480ms step_avg:97.98ms +step:1445/1670 train_time:141578ms step_avg:97.98ms +step:1446/1670 train_time:141676ms step_avg:97.98ms +step:1447/1670 train_time:141773ms step_avg:97.98ms +step:1448/1670 train_time:141871ms step_avg:97.98ms +step:1449/1670 train_time:141969ms step_avg:97.98ms +step:1450/1670 train_time:142068ms step_avg:97.98ms +step:1451/1670 train_time:142165ms step_avg:97.98ms +step:1452/1670 train_time:142264ms step_avg:97.98ms +step:1453/1670 train_time:142362ms step_avg:97.98ms +step:1454/1670 train_time:142460ms step_avg:97.98ms +step:1455/1670 train_time:142559ms step_avg:97.98ms +step:1456/1670 train_time:142657ms step_avg:97.98ms +step:1457/1670 train_time:142755ms step_avg:97.98ms +step:1458/1670 train_time:142852ms step_avg:97.98ms +step:1459/1670 train_time:142949ms step_avg:97.98ms +step:1460/1670 train_time:143047ms step_avg:97.98ms +step:1461/1670 train_time:143146ms step_avg:97.98ms +step:1462/1670 train_time:143245ms step_avg:97.98ms +step:1463/1670 train_time:143344ms step_avg:97.98ms +step:1464/1670 train_time:143442ms step_avg:97.98ms +step:1465/1670 train_time:143541ms step_avg:97.98ms +step:1466/1670 train_time:143640ms step_avg:97.98ms +step:1467/1670 train_time:143738ms step_avg:97.98ms +step:1468/1670 train_time:143836ms step_avg:97.98ms +step:1469/1670 train_time:143933ms step_avg:97.98ms +step:1470/1670 train_time:144031ms step_avg:97.98ms +step:1471/1670 train_time:144129ms step_avg:97.98ms +step:1472/1670 train_time:144226ms step_avg:97.98ms +step:1473/1670 train_time:144324ms step_avg:97.98ms +step:1474/1670 train_time:144422ms step_avg:97.98ms +step:1475/1670 train_time:144521ms step_avg:97.98ms +step:1476/1670 train_time:144620ms step_avg:97.98ms +step:1477/1670 train_time:144718ms step_avg:97.98ms +step:1478/1670 train_time:144816ms step_avg:97.98ms +step:1479/1670 train_time:144913ms step_avg:97.98ms +step:1480/1670 train_time:145011ms step_avg:97.98ms +step:1481/1670 train_time:145108ms step_avg:97.98ms +step:1482/1670 train_time:145206ms step_avg:97.98ms +step:1483/1670 train_time:145303ms step_avg:97.98ms +step:1484/1670 train_time:145401ms step_avg:97.98ms +step:1485/1670 train_time:145673ms step_avg:98.10ms +step:1486/1670 train_time:145862ms step_avg:98.16ms +step:1487/1670 train_time:145958ms step_avg:98.16ms +step:1488/1670 train_time:146054ms step_avg:98.15ms +step:1489/1670 train_time:146150ms step_avg:98.15ms +step:1490/1670 train_time:146247ms step_avg:98.15ms +step:1491/1670 train_time:146344ms step_avg:98.15ms +step:1492/1670 train_time:146441ms step_avg:98.15ms +step:1493/1670 train_time:146537ms step_avg:98.15ms +step:1494/1670 train_time:146633ms step_avg:98.15ms +step:1495/1670 train_time:146730ms step_avg:98.15ms +step:1496/1670 train_time:146835ms step_avg:98.15ms +step:1497/1670 train_time:146937ms step_avg:98.15ms +step:1498/1670 train_time:147034ms step_avg:98.15ms +step:1499/1670 train_time:147130ms step_avg:98.15ms +step:1500/1670 train_time:147227ms step_avg:98.15ms +step:1500/1670 val_loss:3.3116 train_time:147324ms step_avg:98.22ms +step:1501/1670 train_time:147345ms step_avg:98.16ms +step:1502/1670 train_time:147430ms step_avg:98.16ms +step:1503/1670 train_time:147530ms step_avg:98.16ms +step:1504/1670 train_time:147628ms step_avg:98.16ms +step:1505/1670 train_time:147725ms step_avg:98.16ms +step:1506/1670 train_time:147821ms step_avg:98.15ms +step:1507/1670 train_time:147919ms step_avg:98.15ms +step:1508/1670 train_time:148016ms step_avg:98.15ms 
+step:1509/1670 train_time:148116ms step_avg:98.15ms +step:1510/1670 train_time:148214ms step_avg:98.16ms +step:1511/1670 train_time:148316ms step_avg:98.16ms +step:1512/1670 train_time:148417ms step_avg:98.16ms +step:1513/1670 train_time:148516ms step_avg:98.16ms +step:1514/1670 train_time:148616ms step_avg:98.16ms +step:1515/1670 train_time:148715ms step_avg:98.16ms +step:1516/1670 train_time:148813ms step_avg:98.16ms +step:1517/1670 train_time:148910ms step_avg:98.16ms +step:1518/1670 train_time:149007ms step_avg:98.16ms +step:1519/1670 train_time:149104ms step_avg:98.16ms +step:1520/1670 train_time:149201ms step_avg:98.16ms +step:1521/1670 train_time:149299ms step_avg:98.16ms +step:1522/1670 train_time:149398ms step_avg:98.16ms +step:1523/1670 train_time:149496ms step_avg:98.16ms +step:1524/1670 train_time:149595ms step_avg:98.16ms +step:1525/1670 train_time:149694ms step_avg:98.16ms +step:1526/1670 train_time:149792ms step_avg:98.16ms +step:1527/1670 train_time:149889ms step_avg:98.16ms +step:1528/1670 train_time:149987ms step_avg:98.16ms +step:1529/1670 train_time:150084ms step_avg:98.16ms +step:1530/1670 train_time:150182ms step_avg:98.16ms +step:1531/1670 train_time:150280ms step_avg:98.16ms +step:1532/1670 train_time:150378ms step_avg:98.16ms +step:1533/1670 train_time:150476ms step_avg:98.16ms +step:1534/1670 train_time:150576ms step_avg:98.16ms +step:1535/1670 train_time:150674ms step_avg:98.16ms +step:1536/1670 train_time:150773ms step_avg:98.16ms +step:1537/1670 train_time:150870ms step_avg:98.16ms +step:1538/1670 train_time:150969ms step_avg:98.16ms +step:1539/1670 train_time:151066ms step_avg:98.16ms +step:1540/1670 train_time:151163ms step_avg:98.16ms +step:1541/1670 train_time:151260ms step_avg:98.16ms +step:1542/1670 train_time:151357ms step_avg:98.16ms +step:1543/1670 train_time:151457ms step_avg:98.16ms +step:1544/1670 train_time:151557ms step_avg:98.16ms +step:1545/1670 train_time:151655ms step_avg:98.16ms +step:1546/1670 train_time:151753ms step_avg:98.16ms +step:1547/1670 train_time:151852ms step_avg:98.16ms +step:1548/1670 train_time:151950ms step_avg:98.16ms +step:1549/1670 train_time:152048ms step_avg:98.16ms +step:1550/1670 train_time:152146ms step_avg:98.16ms +step:1551/1670 train_time:152244ms step_avg:98.16ms +step:1552/1670 train_time:152341ms step_avg:98.16ms +step:1553/1670 train_time:152438ms step_avg:98.16ms +step:1554/1670 train_time:152536ms step_avg:98.16ms +step:1555/1670 train_time:152635ms step_avg:98.16ms +step:1556/1670 train_time:152732ms step_avg:98.16ms +step:1557/1670 train_time:152830ms step_avg:98.16ms +step:1558/1670 train_time:152929ms step_avg:98.16ms +step:1559/1670 train_time:153027ms step_avg:98.16ms +step:1560/1670 train_time:153125ms step_avg:98.16ms +step:1561/1670 train_time:153224ms step_avg:98.16ms +step:1562/1670 train_time:153321ms step_avg:98.16ms +step:1563/1670 train_time:153420ms step_avg:98.16ms +step:1564/1670 train_time:153518ms step_avg:98.16ms +step:1565/1670 train_time:153616ms step_avg:98.16ms +step:1566/1670 train_time:153713ms step_avg:98.16ms +step:1567/1670 train_time:153812ms step_avg:98.16ms +step:1568/1670 train_time:153909ms step_avg:98.16ms +step:1569/1670 train_time:154007ms step_avg:98.16ms +step:1570/1670 train_time:154105ms step_avg:98.16ms +step:1571/1670 train_time:154203ms step_avg:98.16ms +step:1572/1670 train_time:154300ms step_avg:98.16ms +step:1573/1670 train_time:154398ms step_avg:98.16ms +step:1574/1670 train_time:154497ms step_avg:98.16ms +step:1575/1670 train_time:154595ms step_avg:98.16ms 
+step:1576/1670 train_time:154693ms step_avg:98.16ms +step:1577/1670 train_time:154790ms step_avg:98.15ms +step:1578/1670 train_time:154888ms step_avg:98.15ms +step:1579/1670 train_time:154986ms step_avg:98.15ms +step:1580/1670 train_time:155084ms step_avg:98.15ms +step:1581/1670 train_time:155182ms step_avg:98.15ms +step:1582/1670 train_time:155280ms step_avg:98.15ms +step:1583/1670 train_time:155377ms step_avg:98.15ms +step:1584/1670 train_time:155475ms step_avg:98.15ms +step:1585/1670 train_time:155572ms step_avg:98.15ms +step:1586/1670 train_time:155670ms step_avg:98.15ms +step:1587/1670 train_time:155768ms step_avg:98.15ms +step:1588/1670 train_time:155865ms step_avg:98.15ms +step:1589/1670 train_time:155963ms step_avg:98.15ms +step:1590/1670 train_time:156061ms step_avg:98.15ms +step:1591/1670 train_time:156159ms step_avg:98.15ms +step:1592/1670 train_time:156257ms step_avg:98.15ms +step:1593/1670 train_time:156355ms step_avg:98.15ms +step:1594/1670 train_time:156453ms step_avg:98.15ms +step:1595/1670 train_time:156551ms step_avg:98.15ms +step:1596/1670 train_time:156648ms step_avg:98.15ms +step:1597/1670 train_time:156746ms step_avg:98.15ms +step:1598/1670 train_time:156843ms step_avg:98.15ms +step:1599/1670 train_time:156940ms step_avg:98.15ms +step:1600/1670 train_time:157039ms step_avg:98.15ms +step:1601/1670 train_time:157136ms step_avg:98.15ms +step:1602/1670 train_time:157235ms step_avg:98.15ms +step:1603/1670 train_time:157334ms step_avg:98.15ms +step:1604/1670 train_time:157432ms step_avg:98.15ms +step:1605/1670 train_time:157530ms step_avg:98.15ms +step:1606/1670 train_time:157628ms step_avg:98.15ms +step:1607/1670 train_time:157725ms step_avg:98.15ms +step:1608/1670 train_time:157823ms step_avg:98.15ms +step:1609/1670 train_time:157920ms step_avg:98.15ms +step:1610/1670 train_time:158018ms step_avg:98.15ms +step:1611/1670 train_time:158115ms step_avg:98.15ms +step:1612/1670 train_time:158214ms step_avg:98.15ms +step:1613/1670 train_time:158313ms step_avg:98.15ms +step:1614/1670 train_time:158411ms step_avg:98.15ms +step:1615/1670 train_time:158508ms step_avg:98.15ms +step:1616/1670 train_time:158606ms step_avg:98.15ms +step:1617/1670 train_time:158703ms step_avg:98.15ms +step:1618/1670 train_time:158801ms step_avg:98.15ms +step:1619/1670 train_time:158899ms step_avg:98.15ms +step:1620/1670 train_time:158996ms step_avg:98.15ms +step:1621/1670 train_time:159094ms step_avg:98.15ms +step:1622/1670 train_time:159193ms step_avg:98.15ms +step:1623/1670 train_time:159291ms step_avg:98.15ms +step:1624/1670 train_time:159389ms step_avg:98.15ms +step:1625/1670 train_time:159487ms step_avg:98.15ms +step:1625/1670 val_loss:3.2846 train_time:159584ms step_avg:98.21ms +step:1626/1670 train_time:159606ms step_avg:98.16ms +step:1627/1670 train_time:159689ms step_avg:98.15ms +step:1628/1670 train_time:159788ms step_avg:98.15ms +step:1629/1670 train_time:159886ms step_avg:98.15ms +step:1630/1670 train_time:159983ms step_avg:98.15ms +step:1631/1670 train_time:160080ms step_avg:98.15ms +step:1632/1670 train_time:160177ms step_avg:98.15ms +step:1633/1670 train_time:160274ms step_avg:98.15ms +step:1634/1670 train_time:160371ms step_avg:98.15ms +step:1635/1670 train_time:160467ms step_avg:98.15ms +step:1636/1670 train_time:160566ms step_avg:98.15ms +step:1637/1670 train_time:160668ms step_avg:98.15ms +step:1638/1670 train_time:160767ms step_avg:98.15ms +step:1639/1670 train_time:160865ms step_avg:98.15ms +step:1640/1670 train_time:160963ms step_avg:98.15ms +step:1641/1670 train_time:161061ms 
step_avg:98.15ms
+step:1642/1670 train_time:161160ms step_avg:98.15ms
+step:1643/1670 train_time:161259ms step_avg:98.15ms
+step:1644/1670 train_time:161357ms step_avg:98.15ms
+step:1645/1670 train_time:161454ms step_avg:98.15ms
+step:1646/1670 train_time:161552ms step_avg:98.15ms
+step:1647/1670 train_time:161650ms step_avg:98.15ms
+step:1648/1670 train_time:161748ms step_avg:98.15ms
+step:1649/1670 train_time:161845ms step_avg:98.15ms
+step:1650/1670 train_time:161944ms step_avg:98.15ms
+step:1651/1670 train_time:162041ms step_avg:98.15ms
+step:1652/1670 train_time:162140ms step_avg:98.15ms
+step:1653/1670 train_time:162238ms step_avg:98.15ms
+step:1654/1670 train_time:162335ms step_avg:98.15ms
+step:1655/1670 train_time:162433ms step_avg:98.15ms
+step:1656/1670 train_time:162532ms step_avg:98.15ms
+step:1657/1670 train_time:162630ms step_avg:98.15ms
+step:1658/1670 train_time:162728ms step_avg:98.15ms
+step:1659/1670 train_time:162825ms step_avg:98.15ms
+step:1660/1670 train_time:162922ms step_avg:98.15ms
+step:1661/1670 train_time:163021ms step_avg:98.15ms
+step:1662/1670 train_time:163119ms step_avg:98.15ms
+step:1663/1670 train_time:163216ms step_avg:98.15ms
+step:1664/1670 train_time:163313ms step_avg:98.14ms
+step:1665/1670 train_time:163410ms step_avg:98.14ms
+step:1666/1670 train_time:163508ms step_avg:98.14ms
+step:1667/1670 train_time:163607ms step_avg:98.14ms
+step:1668/1670 train_time:163706ms step_avg:98.14ms
+step:1669/1670 train_time:163803ms step_avg:98.14ms
+step:1670/1670 train_time:163901ms step_avg:98.14ms
+step:1670/1670 val_loss:3.2766 train_time:163998ms step_avg:98.20ms
+peak memory allocated: 34757 MiB reserved: 49516 MiB
diff --git a/records/090325_FA3/README.md b/records/090325_FA3/README.md
new file mode 100644
index 000000000..e45e52a3a
--- /dev/null
+++ b/records/090325_FA3/README.md
@@ -0,0 +1,133 @@
+# New record 09/03/25
+
+This submission includes recent WR changes by
+@ClassicLarry [(08/23/25)](https://github.com/ClassicLarry/modded-nanogpt/tree/master/records/082325_SparseAttnGate)
+and @byronxu99 [(07/18/25)](https://github.com/KellerJordan/modded-nanogpt/pull/109).
+
+Additionally, it has been updated after helpful discussion with @ClassicLarry and @YouJiacheng.
+
+The main idea of this record is to use [Flash Attention v3](https://github.com/Dao-AILab/flash-attention) instead of Flex Attention.
+The official version of this module is incompatible with `torch.compile` and causes graph breaks.
+However, a [recent PR](https://github.com/Dao-AILab/flash-attention/pull/1769) by
+@guilhermeleobas addresses this issue.
+
+
+## Timing and Validation
+
+In 1670 training steps, this run achieves a loss <3.28 (`p=0.0001`) in 163.84 seconds on average, validated over 7 runs.
+
+```
+import scipy.stats
+import numpy as np
+
+accs = [
+    3.2771, 3.2755, 3.2760, 3.2766, 3.2778, 3.2774, 3.2780
+]
+
+times = [
+    163.871, 163.621, 163.848, 163.998, 163.897, 164.016, 163.618
+]
+
+print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
+# p=0.0001
+
+print(f"{np.mean(times):.4f}")
+# 163.8384
+```
+
+In my timing, this is a 4.3 second mean improvement over https://github.com/KellerJordan/modded-nanogpt/pull/117.
+The number of steps can also probably be brought down by 5-10 while achieving loss <3.28.
+
+I used 8 x H100 (SXM5) via Prime Intellect for validation compute.
+
+## Further Details
+
+### Motivation
+
+Flash Attention v3 achieves greater SM utilization on Hopper GPUs than Flash Attention v2.
+Flash Attention 3 is significantly faster than Flex Attention on batched inputs, and this gap increases as we increase the number of sequences per batch:
+
+*(Figure: Flash vs Flex Attention, varying the number of sequences per batch.)*
+
+In order to train with document masking, we use Flash Attention's `flash_attn_varlen_func` (suggested by @YouJiacheng).
+We keep the number of tokens per step fixed (`393216`) but pack a variable number of sequences in each batch,
+clipping the maximum length of each sequence to `args.train_max_seq_len = 2048`.
+
+WR#26 by @ClassicLarry found that validation loss decreases when we train only on sequences beginning with the Beginning of Sequence token (`<|endoftext|>`, token id 50256).
+
+
+### Flash Attention 3
+
+
+As mentioned above, we need to use an unmerged PR in order to use FA3 with `torch.compile`.
+You can build the wheel like so:
+
+```
+pip install -U pip wheel setuptools ninja numpy packaging psutil
+
+git clone https://github.com/guilhermeleobas/flash-attention.git
+cd flash-attention/hopper
+git switch guilhermeleobas/fa3-compile
+
+export MAX_JOBS=32 # Can increase based on machine
+export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch
+export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only
+export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8
+export FLASH_ATTENTION_DISABLE_HDIM64=TRUE # NanoGPT only uses HDIM = 128
+export FLASH_ATTENTION_DISABLE_HDIM96=TRUE
+export FLASH_ATTENTION_DISABLE_HDIM192=TRUE
+export FLASH_ATTENTION_DISABLE_HDIM256=TRUE
+
+python setup.py bdist_wheel
+```
+
+Additionally, I have uploaded a prebuilt wheel [here](https://github.com/varunneal/flash-attention/releases/tag/v3.0.0b1-alpha).
+Downloading this wheel and installing it via pip is likely to be fairly fast.
+
+For exact reproduction, I recommend that you install Torch Nightly 2.9.0.dev20250718 and
+install the FA3 wheel afterward:
+
+```
+pip install --pre "torch==2.9.0.dev20250718+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126
+
+pip install /path/to/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
+```
+
+For me, Torch Nightly 2.9.0.dev20250713 was incompatible with PR#109.
+
+### Attention Masks
+
+Flash Attention exposes the parameter `window_size`, with which we can specify the number of tokens each query attends to.
+Unfortunately, it expects this value to be an int, so varying it will cause `torch.compile` to
+create a new graph. As such, I restrict the run to a small number of distinct window sizes.
+
+I kept the existing long-short sliding window block mask pattern, as well as the idea
+that the window sizes should linearly increase over the length of the training run.
+To aid with this, I created a hyperparameter `ws_schedule` and a helper function `get_ws(step)`.
+I additionally added the size of the blocks in a window as a hyperparameter, `block_size=128`.
+
+I have picked a linear schedule with three steps: `ws_schedule=(3, 7, 11)`.
+Each graph needs to be warmed up separately, so I have increased the number
+of warmup steps from `10` to `30`. The compile time is dominated by the first iteration,
+so this will take approximately `len(ws_schedule)` times longer than before.
+
+
+Document masks are implemented by specifying the start and end of each sequence in `cu_seqlens_*`.
+In order for the tensor sizes to be fixed, we pad `cu_seqlens_*` to a fixed length, chosen to be larger
+than the number of documents we would ever expect in a single input batch.
+
+At training time, sequences are clipped to `args.train_max_seq_len` tokens.
+This clipping helps pack a greater diversity of sequences per batch.
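+
+As a minimal sketch (not the code of this record itself), the two static-shape tricks above can
+be expressed as below: a discretized window-size schedule via `get_ws(step)`, and `cu_seqlens`
+boundaries built from clipped document lengths and padded to a fixed size. `NUM_STEPS`,
+`MAX_DOCS`, the exact schedule mapping, and the pad value (repeating the final boundary, i.e.
+empty trailing sequences) are illustrative assumptions.
+
+```
+import torch
+
+ws_schedule = (3, 7, 11)  # window sizes, in units of block_size blocks
+block_size = 128
+NUM_STEPS = 1670          # total training steps (drives the linear schedule)
+MAX_DOCS = 1024           # assumed upper bound on documents per batch
+
+def get_ws(step: int) -> int:
+    # advance linearly through the schedule; each distinct value is one compiled graph
+    idx = min(len(ws_schedule) - 1, step * len(ws_schedule) // NUM_STEPS)
+    return ws_schedule[idx]
+
+def make_cu_seqlens(doc_lens: list[int], max_seq_len: int = 2048) -> torch.Tensor:
+    # clip each document to the training max sequence length, then build the
+    # cumulative boundaries [0, l0, l0+l1, ...] that flash_attn_varlen_func expects
+    lens = torch.tensor(doc_lens, dtype=torch.int32).clamp(max=max_seq_len)
+    cu = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0).to(torch.int32)])
+    # pad with the final boundary so the tensor size is fixed under torch.compile;
+    # the extra entries correspond to empty trailing sequences
+    pad = cu[-1].repeat(MAX_DOCS + 1 - cu.numel())
+    return torch.cat([cu, pad])
+
+# the sliding window is then passed as window_size=(get_ws(step) * block_size, 0)
+```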
+I believe this clipping is responsible for the decrease of ~25 training steps.
+
+In order to implement the above, I have created the helper class `BOSFinder`.
+
+### Potential Improvements
+
+- Batch size scheduling: Previously, the block mask acted as a proxy for batch size.
+Now batch size can be controlled explicitly and scheduled according to critical batch size theory.
+I have added code in `distributed_data_generator` that allows for changing the
+batch size, max sequence length, and grad_accum_steps yielded after the generator is created.
+- The current block mask window schedule `(3, 7, 11)` can almost certainly be improved upon.
+- Hyperparameter tuning might change with smaller sequence length. Rotary base, validation sequence length,
+learning rates, etc. should be re-tuned. I haven't done that for this run.
diff --git a/records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt b/records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt
new file mode 100644
index 000000000..6f1aee1e5
--- /dev/null
+++ b/records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt
@@ -0,0 +1,2814 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import numpy as np
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+import torch._dynamo as dynamo
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr 
+ (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in 
tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D 
parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + 
reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, 
sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
+        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
+        y1 = x1 * cos + x2 * sin
+        y2 = x1 * (-sin) + x2 * cos
+        return torch.cat((y1, y2), 3).type_as(x_BTHD)
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+
+        # sparse gated attention to enable a context-based no-op by @classiclarryd
+        self.attn_gate_dim = 12
+        self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
+                seqlens: Tensor, bm_size: int):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size)
+        x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
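+        # worked example (illustrative): next_multiple_of_n(50257, n=128) == 50304,
+        # since 50304 = 393 * 128 is the smallest multiple of 128 that is >= 50257.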
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
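+                # (next_batch raises before committing its cursor and BOSFinder is
+                # rebuilt below, so no batch ever spans a shard boundary.)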
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # buf holds one extra token (to form _targets); trim the last sequence's end to match _inputs
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "new_num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
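+
+# sanity sketch (illustrative): grad_accum_steps * world_size == 8 for every allowed
+# world_size, so each rank consumes train_batch_size // 8 = 2048 * 24 = 49152 tokens
+# per micro-step and the global batch is identical for world_size in {1, 2, 4, 8}.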
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
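+
+# worked example of get_ws above (illustrative): with num_iterations=1670 and
+# ws_schedule=(3, 7, 11), steps 0-556 use ws=3, steps 557-1113 use ws=7, and later
+# steps use ws=11, so the sliding attention window (ws * block_size = 384, 896, or
+# 1408 tokens) widens as training progresses.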
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 20:09:36 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU 
Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 34C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 54231 C /usr/bin/python 0MiB | +| 0 N/A N/A 54232 C /usr/bin/python 0MiB | +| 0 N/A N/A 54233 C /usr/bin/python 0MiB | +| 0 N/A N/A 54234 C /usr/bin/python 0MiB | +| 0 N/A N/A 54235 C /usr/bin/python 0MiB | +| 0 N/A N/A 54236 C /usr/bin/python 0MiB | +| 0 N/A N/A 54237 C /usr/bin/python 0MiB | +| 0 N/A N/A 54238 C /usr/bin/python 0MiB | +| 1 N/A N/A 54232 C /usr/bin/python 0MiB | +| 2 N/A N/A 54233 C /usr/bin/python 0MiB | +| 3 N/A N/A 54234 C /usr/bin/python 0MiB | +| 4 N/A N/A 54235 C /usr/bin/python 0MiB | +| 5 N/A N/A 54236 C /usr/bin/python 0MiB | +| 6 N/A N/A 54237 C /usr/bin/python 0MiB | +| 7 N/A N/A 54238 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:385ms step_avg:385.21ms +step:2/1670 train_time:406ms step_avg:202.85ms +step:3/1670 train_time:479ms step_avg:159.59ms +step:4/1670 train_time:572ms step_avg:143.09ms +step:5/1670 train_time:667ms step_avg:133.32ms +step:6/1670 train_time:762ms 
step_avg:126.96ms +step:7/1670 train_time:856ms step_avg:122.35ms +step:8/1670 train_time:951ms step_avg:118.90ms +step:9/1670 train_time:1046ms step_avg:116.23ms +step:10/1670 train_time:1140ms step_avg:114.04ms +step:11/1670 train_time:1235ms step_avg:112.31ms +step:12/1670 train_time:1332ms step_avg:110.99ms +step:13/1670 train_time:1431ms step_avg:110.10ms +step:14/1670 train_time:1529ms step_avg:109.22ms +step:15/1670 train_time:1625ms step_avg:108.31ms +step:16/1670 train_time:1720ms step_avg:107.53ms +step:17/1670 train_time:1816ms step_avg:106.82ms +step:18/1670 train_time:1911ms step_avg:106.15ms +step:19/1670 train_time:2006ms step_avg:105.59ms +step:20/1670 train_time:2102ms step_avg:105.09ms +step:21/1670 train_time:2197ms step_avg:104.61ms +step:22/1670 train_time:2293ms step_avg:104.22ms +step:23/1670 train_time:2390ms step_avg:103.91ms +step:24/1670 train_time:2488ms step_avg:103.65ms +step:25/1670 train_time:2584ms step_avg:103.38ms +step:26/1670 train_time:2680ms step_avg:103.07ms +step:27/1670 train_time:2775ms step_avg:102.79ms +step:28/1670 train_time:2870ms step_avg:102.51ms +step:29/1670 train_time:2966ms step_avg:102.28ms +step:30/1670 train_time:3061ms step_avg:102.04ms +step:31/1670 train_time:3157ms step_avg:101.85ms +step:32/1670 train_time:3253ms step_avg:101.65ms +step:33/1670 train_time:3348ms step_avg:101.47ms +step:34/1670 train_time:3445ms step_avg:101.32ms +step:35/1670 train_time:3542ms step_avg:101.20ms +step:36/1670 train_time:3637ms step_avg:101.03ms +step:37/1670 train_time:3733ms step_avg:100.88ms +step:38/1670 train_time:3829ms step_avg:100.77ms +step:39/1670 train_time:3925ms step_avg:100.64ms +step:40/1670 train_time:4021ms step_avg:100.53ms +step:41/1670 train_time:4117ms step_avg:100.41ms +step:42/1670 train_time:4212ms step_avg:100.29ms +step:43/1670 train_time:4308ms step_avg:100.19ms +step:44/1670 train_time:4404ms step_avg:100.09ms +step:45/1670 train_time:4500ms step_avg:100.00ms +step:46/1670 train_time:4596ms step_avg:99.91ms +step:47/1670 train_time:4692ms step_avg:99.84ms +step:48/1670 train_time:4788ms step_avg:99.74ms +step:49/1670 train_time:4883ms step_avg:99.65ms +step:50/1670 train_time:4978ms step_avg:99.56ms +step:51/1670 train_time:5074ms step_avg:99.48ms +step:52/1670 train_time:5169ms step_avg:99.40ms +step:53/1670 train_time:5265ms step_avg:99.33ms +step:54/1670 train_time:5360ms step_avg:99.26ms +step:55/1670 train_time:5456ms step_avg:99.20ms +step:56/1670 train_time:5551ms step_avg:99.13ms +step:57/1670 train_time:5648ms step_avg:99.09ms +step:58/1670 train_time:5743ms step_avg:99.02ms +step:59/1670 train_time:5839ms step_avg:98.97ms +step:60/1670 train_time:5935ms step_avg:98.91ms +step:61/1670 train_time:6030ms step_avg:98.86ms +step:62/1670 train_time:6126ms step_avg:98.81ms +step:63/1670 train_time:6222ms step_avg:98.76ms +step:64/1670 train_time:6317ms step_avg:98.71ms +step:65/1670 train_time:6414ms step_avg:98.67ms +step:66/1670 train_time:6509ms step_avg:98.62ms +step:67/1670 train_time:6605ms step_avg:98.59ms +step:68/1670 train_time:6701ms step_avg:98.54ms +step:69/1670 train_time:6797ms step_avg:98.50ms +step:70/1670 train_time:6893ms step_avg:98.47ms +step:71/1670 train_time:6988ms step_avg:98.43ms +step:72/1670 train_time:7084ms step_avg:98.39ms +step:73/1670 train_time:7180ms step_avg:98.35ms +step:74/1670 train_time:7275ms step_avg:98.31ms +step:75/1670 train_time:7370ms step_avg:98.27ms +step:76/1670 train_time:7466ms step_avg:98.24ms +step:77/1670 train_time:7562ms step_avg:98.21ms +step:78/1670 
train_time:7658ms step_avg:98.18ms +step:79/1670 train_time:7754ms step_avg:98.16ms +step:80/1670 train_time:7850ms step_avg:98.13ms +step:81/1670 train_time:7946ms step_avg:98.10ms +step:82/1670 train_time:8041ms step_avg:98.06ms +step:83/1670 train_time:8136ms step_avg:98.03ms +step:84/1670 train_time:8232ms step_avg:98.00ms +step:85/1670 train_time:8328ms step_avg:97.98ms +step:86/1670 train_time:8423ms step_avg:97.95ms +step:87/1670 train_time:8519ms step_avg:97.92ms +step:88/1670 train_time:8615ms step_avg:97.90ms +step:89/1670 train_time:8711ms step_avg:97.87ms +step:90/1670 train_time:8807ms step_avg:97.85ms +step:91/1670 train_time:8903ms step_avg:97.84ms +step:92/1670 train_time:8998ms step_avg:97.81ms +step:93/1670 train_time:9094ms step_avg:97.78ms +step:94/1670 train_time:9189ms step_avg:97.75ms +step:95/1670 train_time:9285ms step_avg:97.73ms +step:96/1670 train_time:9380ms step_avg:97.70ms +step:97/1670 train_time:9475ms step_avg:97.68ms +step:98/1670 train_time:9571ms step_avg:97.66ms +step:99/1670 train_time:9667ms step_avg:97.64ms +step:100/1670 train_time:9763ms step_avg:97.63ms +step:101/1670 train_time:9859ms step_avg:97.61ms +step:102/1670 train_time:9955ms step_avg:97.60ms +step:103/1670 train_time:10051ms step_avg:97.58ms +step:104/1670 train_time:10147ms step_avg:97.57ms +step:105/1670 train_time:10243ms step_avg:97.56ms +step:106/1670 train_time:10339ms step_avg:97.54ms +step:107/1670 train_time:10434ms step_avg:97.51ms +step:108/1670 train_time:10530ms step_avg:97.50ms +step:109/1670 train_time:10626ms step_avg:97.49ms +step:110/1670 train_time:10722ms step_avg:97.47ms +step:111/1670 train_time:10817ms step_avg:97.45ms +step:112/1670 train_time:10912ms step_avg:97.43ms +step:113/1670 train_time:11008ms step_avg:97.41ms +step:114/1670 train_time:11103ms step_avg:97.40ms +step:115/1670 train_time:11199ms step_avg:97.38ms +step:116/1670 train_time:11295ms step_avg:97.37ms +step:117/1670 train_time:11390ms step_avg:97.35ms +step:118/1670 train_time:11487ms step_avg:97.34ms +step:119/1670 train_time:11583ms step_avg:97.34ms +step:120/1670 train_time:11678ms step_avg:97.32ms +step:121/1670 train_time:11775ms step_avg:97.31ms +step:122/1670 train_time:11871ms step_avg:97.30ms +step:123/1670 train_time:11966ms step_avg:97.29ms +step:124/1670 train_time:12062ms step_avg:97.28ms +step:125/1670 train_time:12158ms step_avg:97.26ms +step:125/1670 val_loss:4.2958 train_time:12253ms step_avg:98.02ms +step:126/1670 train_time:12274ms step_avg:97.41ms +step:127/1670 train_time:12356ms step_avg:97.29ms +step:128/1670 train_time:12461ms step_avg:97.36ms +step:129/1670 train_time:12560ms step_avg:97.36ms +step:130/1670 train_time:12655ms step_avg:97.35ms +step:131/1670 train_time:12750ms step_avg:97.33ms +step:132/1670 train_time:12845ms step_avg:97.31ms +step:133/1670 train_time:12940ms step_avg:97.29ms +step:134/1670 train_time:13035ms step_avg:97.28ms +step:135/1670 train_time:13130ms step_avg:97.26ms +step:136/1670 train_time:13225ms step_avg:97.24ms +step:137/1670 train_time:13322ms step_avg:97.24ms +step:138/1670 train_time:13420ms step_avg:97.25ms +step:139/1670 train_time:13518ms step_avg:97.25ms +step:140/1670 train_time:13613ms step_avg:97.24ms +step:141/1670 train_time:13710ms step_avg:97.23ms +step:142/1670 train_time:13805ms step_avg:97.22ms +step:143/1670 train_time:13900ms step_avg:97.20ms +step:144/1670 train_time:13994ms step_avg:97.18ms +step:145/1670 train_time:14089ms step_avg:97.17ms +step:146/1670 train_time:14184ms step_avg:97.15ms +step:147/1670 
train_time:14278ms step_avg:97.13ms +step:148/1670 train_time:14375ms step_avg:97.13ms +step:149/1670 train_time:14472ms step_avg:97.13ms +step:150/1670 train_time:14569ms step_avg:97.13ms +step:151/1670 train_time:14666ms step_avg:97.12ms +step:152/1670 train_time:14761ms step_avg:97.11ms +step:153/1670 train_time:14856ms step_avg:97.10ms +step:154/1670 train_time:14951ms step_avg:97.09ms +step:155/1670 train_time:15047ms step_avg:97.08ms +step:156/1670 train_time:15142ms step_avg:97.07ms +step:157/1670 train_time:15237ms step_avg:97.05ms +step:158/1670 train_time:15333ms step_avg:97.04ms +step:159/1670 train_time:15430ms step_avg:97.04ms +step:160/1670 train_time:15527ms step_avg:97.04ms +step:161/1670 train_time:15623ms step_avg:97.04ms +step:162/1670 train_time:15719ms step_avg:97.03ms +step:163/1670 train_time:15815ms step_avg:97.02ms +step:164/1670 train_time:15911ms step_avg:97.02ms +step:165/1670 train_time:16006ms step_avg:97.01ms +step:166/1670 train_time:16102ms step_avg:97.00ms +step:167/1670 train_time:16197ms step_avg:96.99ms +step:168/1670 train_time:16292ms step_avg:96.98ms +step:169/1670 train_time:16388ms step_avg:96.97ms +step:170/1670 train_time:16485ms step_avg:96.97ms +step:171/1670 train_time:16582ms step_avg:96.97ms +step:172/1670 train_time:16677ms step_avg:96.96ms +step:173/1670 train_time:16773ms step_avg:96.95ms +step:174/1670 train_time:16869ms step_avg:96.95ms +step:175/1670 train_time:16964ms step_avg:96.94ms +step:176/1670 train_time:17059ms step_avg:96.93ms +step:177/1670 train_time:17153ms step_avg:96.91ms +step:178/1670 train_time:17249ms step_avg:96.90ms +step:179/1670 train_time:17345ms step_avg:96.90ms +step:180/1670 train_time:17441ms step_avg:96.89ms +step:181/1670 train_time:17537ms step_avg:96.89ms +step:182/1670 train_time:17633ms step_avg:96.88ms +step:183/1670 train_time:17729ms step_avg:96.88ms +step:184/1670 train_time:17825ms step_avg:96.88ms +step:185/1670 train_time:17921ms step_avg:96.87ms +step:186/1670 train_time:18016ms step_avg:96.86ms +step:187/1670 train_time:18112ms step_avg:96.85ms +step:188/1670 train_time:18208ms step_avg:96.85ms +step:189/1670 train_time:18303ms step_avg:96.84ms +step:190/1670 train_time:18399ms step_avg:96.84ms +step:191/1670 train_time:18495ms step_avg:96.83ms +step:192/1670 train_time:18591ms step_avg:96.83ms +step:193/1670 train_time:18688ms step_avg:96.83ms +step:194/1670 train_time:18784ms step_avg:96.83ms +step:195/1670 train_time:18880ms step_avg:96.82ms +step:196/1670 train_time:18975ms step_avg:96.81ms +step:197/1670 train_time:19070ms step_avg:96.80ms +step:198/1670 train_time:19166ms step_avg:96.80ms +step:199/1670 train_time:19261ms step_avg:96.79ms +step:200/1670 train_time:19356ms step_avg:96.78ms +step:201/1670 train_time:19452ms step_avg:96.78ms +step:202/1670 train_time:19547ms step_avg:96.77ms +step:203/1670 train_time:19644ms step_avg:96.77ms +step:204/1670 train_time:19740ms step_avg:96.77ms +step:205/1670 train_time:19835ms step_avg:96.76ms +step:206/1670 train_time:19930ms step_avg:96.75ms +step:207/1670 train_time:20026ms step_avg:96.74ms +step:208/1670 train_time:20121ms step_avg:96.74ms +step:209/1670 train_time:20217ms step_avg:96.73ms +step:210/1670 train_time:20312ms step_avg:96.72ms +step:211/1670 train_time:20409ms step_avg:96.72ms +step:212/1670 train_time:20504ms step_avg:96.72ms +step:213/1670 train_time:20834ms step_avg:97.81ms +step:214/1670 train_time:20907ms step_avg:97.69ms +step:215/1670 train_time:21001ms step_avg:97.68ms +step:216/1670 train_time:21095ms step_avg:97.66ms 
+step:217/1670 train_time:21189ms step_avg:97.65ms +step:218/1670 train_time:21284ms step_avg:97.63ms +step:219/1670 train_time:21379ms step_avg:97.62ms +step:220/1670 train_time:21473ms step_avg:97.60ms +step:221/1670 train_time:21567ms step_avg:97.59ms +step:222/1670 train_time:21661ms step_avg:97.57ms +step:223/1670 train_time:21759ms step_avg:97.57ms +step:224/1670 train_time:21857ms step_avg:97.57ms +step:225/1670 train_time:21954ms step_avg:97.57ms +step:226/1670 train_time:22050ms step_avg:97.57ms +step:227/1670 train_time:22145ms step_avg:97.56ms +step:228/1670 train_time:22240ms step_avg:97.54ms +step:229/1670 train_time:22334ms step_avg:97.53ms +step:230/1670 train_time:22429ms step_avg:97.52ms +step:231/1670 train_time:22523ms step_avg:97.50ms +step:232/1670 train_time:22618ms step_avg:97.49ms +step:233/1670 train_time:22714ms step_avg:97.49ms +step:234/1670 train_time:22811ms step_avg:97.48ms +step:235/1670 train_time:22909ms step_avg:97.48ms +step:236/1670 train_time:23006ms step_avg:97.48ms +step:237/1670 train_time:23101ms step_avg:97.47ms +step:238/1670 train_time:23196ms step_avg:97.46ms +step:239/1670 train_time:23291ms step_avg:97.45ms +step:240/1670 train_time:23386ms step_avg:97.44ms +step:241/1670 train_time:23481ms step_avg:97.43ms +step:242/1670 train_time:23576ms step_avg:97.42ms +step:243/1670 train_time:23671ms step_avg:97.41ms +step:244/1670 train_time:23767ms step_avg:97.40ms +step:245/1670 train_time:23863ms step_avg:97.40ms +step:246/1670 train_time:23959ms step_avg:97.40ms +step:247/1670 train_time:24055ms step_avg:97.39ms +step:248/1670 train_time:24150ms step_avg:97.38ms +step:249/1670 train_time:24246ms step_avg:97.37ms +step:250/1670 train_time:24340ms step_avg:97.36ms +step:250/1670 val_loss:3.9752 train_time:24435ms step_avg:97.74ms +step:251/1670 train_time:24456ms step_avg:97.43ms +step:252/1670 train_time:24538ms step_avg:97.37ms +step:253/1670 train_time:24636ms step_avg:97.38ms +step:254/1670 train_time:24732ms step_avg:97.37ms +step:255/1670 train_time:24827ms step_avg:97.36ms +step:256/1670 train_time:24921ms step_avg:97.35ms +step:257/1670 train_time:25016ms step_avg:97.34ms +step:258/1670 train_time:25110ms step_avg:97.33ms +step:259/1670 train_time:25205ms step_avg:97.32ms +step:260/1670 train_time:25300ms step_avg:97.31ms +step:261/1670 train_time:25396ms step_avg:97.30ms +step:262/1670 train_time:25494ms step_avg:97.30ms +step:263/1670 train_time:25591ms step_avg:97.30ms +step:264/1670 train_time:25687ms step_avg:97.30ms +step:265/1670 train_time:25782ms step_avg:97.29ms +step:266/1670 train_time:25878ms step_avg:97.28ms +step:267/1670 train_time:25973ms step_avg:97.28ms +step:268/1670 train_time:26067ms step_avg:97.27ms +step:269/1670 train_time:26162ms step_avg:97.26ms +step:270/1670 train_time:26256ms step_avg:97.25ms +step:271/1670 train_time:26352ms step_avg:97.24ms +step:272/1670 train_time:26448ms step_avg:97.24ms +step:273/1670 train_time:26545ms step_avg:97.23ms +step:274/1670 train_time:26642ms step_avg:97.23ms +step:275/1670 train_time:26739ms step_avg:97.23ms +step:276/1670 train_time:26835ms step_avg:97.23ms +step:277/1670 train_time:26930ms step_avg:97.22ms +step:278/1670 train_time:27026ms step_avg:97.21ms +step:279/1670 train_time:27121ms step_avg:97.21ms +step:280/1670 train_time:27215ms step_avg:97.20ms +step:281/1670 train_time:27310ms step_avg:97.19ms +step:282/1670 train_time:27406ms step_avg:97.19ms +step:283/1670 train_time:27502ms step_avg:97.18ms +step:284/1670 train_time:27599ms step_avg:97.18ms +step:285/1670 
train_time:27696ms step_avg:97.18ms +step:286/1670 train_time:27791ms step_avg:97.17ms +step:287/1670 train_time:27886ms step_avg:97.16ms +step:288/1670 train_time:27982ms step_avg:97.16ms +step:289/1670 train_time:28077ms step_avg:97.15ms +step:290/1670 train_time:28172ms step_avg:97.15ms +step:291/1670 train_time:28267ms step_avg:97.14ms +step:292/1670 train_time:28363ms step_avg:97.13ms +step:293/1670 train_time:28459ms step_avg:97.13ms +step:294/1670 train_time:28556ms step_avg:97.13ms +step:295/1670 train_time:28652ms step_avg:97.12ms +step:296/1670 train_time:28749ms step_avg:97.12ms +step:297/1670 train_time:28844ms step_avg:97.12ms +step:298/1670 train_time:28940ms step_avg:97.11ms +step:299/1670 train_time:29035ms step_avg:97.11ms +step:300/1670 train_time:29130ms step_avg:97.10ms +step:301/1670 train_time:29225ms step_avg:97.09ms +step:302/1670 train_time:29321ms step_avg:97.09ms +step:303/1670 train_time:29416ms step_avg:97.08ms +step:304/1670 train_time:29512ms step_avg:97.08ms +step:305/1670 train_time:29608ms step_avg:97.07ms +step:306/1670 train_time:29703ms step_avg:97.07ms +step:307/1670 train_time:29800ms step_avg:97.07ms +step:308/1670 train_time:29895ms step_avg:97.06ms +step:309/1670 train_time:29991ms step_avg:97.06ms +step:310/1670 train_time:30086ms step_avg:97.05ms +step:311/1670 train_time:30182ms step_avg:97.05ms +step:312/1670 train_time:30277ms step_avg:97.04ms +step:313/1670 train_time:30371ms step_avg:97.03ms +step:314/1670 train_time:30467ms step_avg:97.03ms +step:315/1670 train_time:30564ms step_avg:97.03ms +step:316/1670 train_time:30661ms step_avg:97.03ms +step:317/1670 train_time:30758ms step_avg:97.03ms +step:318/1670 train_time:30854ms step_avg:97.02ms +step:319/1670 train_time:30949ms step_avg:97.02ms +step:320/1670 train_time:31044ms step_avg:97.01ms +step:321/1670 train_time:31139ms step_avg:97.01ms +step:322/1670 train_time:31234ms step_avg:97.00ms +step:323/1670 train_time:31329ms step_avg:96.99ms +step:324/1670 train_time:31424ms step_avg:96.99ms +step:325/1670 train_time:31520ms step_avg:96.99ms +step:326/1670 train_time:31616ms step_avg:96.98ms +step:327/1670 train_time:31711ms step_avg:96.98ms +step:328/1670 train_time:31807ms step_avg:96.97ms +step:329/1670 train_time:31903ms step_avg:96.97ms +step:330/1670 train_time:31999ms step_avg:96.97ms +step:331/1670 train_time:32095ms step_avg:96.96ms +step:332/1670 train_time:32190ms step_avg:96.96ms +step:333/1670 train_time:32285ms step_avg:96.95ms +step:334/1670 train_time:32381ms step_avg:96.95ms +step:335/1670 train_time:32477ms step_avg:96.94ms +step:336/1670 train_time:32573ms step_avg:96.94ms +step:337/1670 train_time:32668ms step_avg:96.94ms +step:338/1670 train_time:32764ms step_avg:96.93ms +step:339/1670 train_time:32861ms step_avg:96.94ms +step:340/1670 train_time:32958ms step_avg:96.93ms +step:341/1670 train_time:33053ms step_avg:96.93ms +step:342/1670 train_time:33148ms step_avg:96.92ms +step:343/1670 train_time:33243ms step_avg:96.92ms +step:344/1670 train_time:33339ms step_avg:96.92ms +step:345/1670 train_time:33434ms step_avg:96.91ms +step:346/1670 train_time:33530ms step_avg:96.91ms +step:347/1670 train_time:33625ms step_avg:96.90ms +step:348/1670 train_time:33721ms step_avg:96.90ms +step:349/1670 train_time:33817ms step_avg:96.90ms +step:350/1670 train_time:33912ms step_avg:96.89ms +step:351/1670 train_time:34008ms step_avg:96.89ms +step:352/1670 train_time:34103ms step_avg:96.88ms +step:353/1670 train_time:34199ms step_avg:96.88ms +step:354/1670 train_time:34294ms step_avg:96.88ms 
+step:355/1670 train_time:34389ms step_avg:96.87ms +step:356/1670 train_time:34484ms step_avg:96.87ms +step:357/1670 train_time:34580ms step_avg:96.86ms +step:358/1670 train_time:34677ms step_avg:96.86ms +step:359/1670 train_time:34772ms step_avg:96.86ms +step:360/1670 train_time:34868ms step_avg:96.85ms +step:361/1670 train_time:34964ms step_avg:96.85ms +step:362/1670 train_time:35059ms step_avg:96.85ms +step:363/1670 train_time:35155ms step_avg:96.84ms +step:364/1670 train_time:35250ms step_avg:96.84ms +step:365/1670 train_time:35346ms step_avg:96.84ms +step:366/1670 train_time:35442ms step_avg:96.84ms +step:367/1670 train_time:35538ms step_avg:96.83ms +step:368/1670 train_time:35633ms step_avg:96.83ms +step:369/1670 train_time:35728ms step_avg:96.83ms +step:370/1670 train_time:35824ms step_avg:96.82ms +step:371/1670 train_time:35919ms step_avg:96.82ms +step:372/1670 train_time:36015ms step_avg:96.81ms +step:373/1670 train_time:36111ms step_avg:96.81ms +step:374/1670 train_time:36206ms step_avg:96.81ms +step:375/1670 train_time:36302ms step_avg:96.81ms +step:375/1670 val_loss:3.8227 train_time:36397ms step_avg:97.06ms +step:376/1670 train_time:36420ms step_avg:96.86ms +step:377/1670 train_time:36501ms step_avg:96.82ms +step:378/1670 train_time:36601ms step_avg:96.83ms +step:379/1670 train_time:36698ms step_avg:96.83ms +step:380/1670 train_time:36792ms step_avg:96.82ms +step:381/1670 train_time:36886ms step_avg:96.81ms +step:382/1670 train_time:36981ms step_avg:96.81ms +step:383/1670 train_time:37076ms step_avg:96.80ms +step:384/1670 train_time:37171ms step_avg:96.80ms +step:385/1670 train_time:37265ms step_avg:96.79ms +step:386/1670 train_time:37361ms step_avg:96.79ms +step:387/1670 train_time:37457ms step_avg:96.79ms +step:388/1670 train_time:37555ms step_avg:96.79ms +step:389/1670 train_time:37652ms step_avg:96.79ms +step:390/1670 train_time:37749ms step_avg:96.79ms +step:391/1670 train_time:37845ms step_avg:96.79ms +step:392/1670 train_time:37940ms step_avg:96.78ms +step:393/1670 train_time:38034ms step_avg:96.78ms +step:394/1670 train_time:38129ms step_avg:96.77ms +step:395/1670 train_time:38223ms step_avg:96.77ms +step:396/1670 train_time:38318ms step_avg:96.76ms +step:397/1670 train_time:38414ms step_avg:96.76ms +step:398/1670 train_time:38512ms step_avg:96.76ms +step:399/1670 train_time:38610ms step_avg:96.77ms +step:400/1670 train_time:38707ms step_avg:96.77ms +step:401/1670 train_time:38803ms step_avg:96.77ms +step:402/1670 train_time:38898ms step_avg:96.76ms +step:403/1670 train_time:38994ms step_avg:96.76ms +step:404/1670 train_time:39088ms step_avg:96.75ms +step:405/1670 train_time:39184ms step_avg:96.75ms +step:406/1670 train_time:39279ms step_avg:96.75ms +step:407/1670 train_time:39374ms step_avg:96.74ms +step:408/1670 train_time:39471ms step_avg:96.74ms +step:409/1670 train_time:39568ms step_avg:96.74ms +step:410/1670 train_time:39665ms step_avg:96.74ms +step:411/1670 train_time:39761ms step_avg:96.74ms +step:412/1670 train_time:39856ms step_avg:96.74ms +step:413/1670 train_time:39952ms step_avg:96.74ms +step:414/1670 train_time:40047ms step_avg:96.73ms +step:415/1670 train_time:40142ms step_avg:96.73ms +step:416/1670 train_time:40237ms step_avg:96.72ms +step:417/1670 train_time:40332ms step_avg:96.72ms +step:418/1670 train_time:40428ms step_avg:96.72ms +step:419/1670 train_time:40523ms step_avg:96.71ms +step:420/1670 train_time:40620ms step_avg:96.71ms +step:421/1670 train_time:40716ms step_avg:96.71ms +step:422/1670 train_time:40812ms step_avg:96.71ms +step:423/1670 
train_time:40909ms step_avg:96.71ms +step:424/1670 train_time:41005ms step_avg:96.71ms +step:425/1670 train_time:41288ms step_avg:97.15ms +step:426/1670 train_time:41401ms step_avg:97.18ms +step:427/1670 train_time:41495ms step_avg:97.18ms +step:428/1670 train_time:41590ms step_avg:97.17ms +step:429/1670 train_time:41684ms step_avg:97.17ms +step:430/1670 train_time:41778ms step_avg:97.16ms +step:431/1670 train_time:41873ms step_avg:97.15ms +step:432/1670 train_time:41968ms step_avg:97.15ms +step:433/1670 train_time:42062ms step_avg:97.14ms +step:434/1670 train_time:42156ms step_avg:97.13ms +step:435/1670 train_time:42255ms step_avg:97.14ms +step:436/1670 train_time:42355ms step_avg:97.15ms +step:437/1670 train_time:42454ms step_avg:97.15ms +step:438/1670 train_time:42551ms step_avg:97.15ms +step:439/1670 train_time:42646ms step_avg:97.14ms +step:440/1670 train_time:42741ms step_avg:97.14ms +step:441/1670 train_time:42836ms step_avg:97.13ms +step:442/1670 train_time:42931ms step_avg:97.13ms +step:443/1670 train_time:43026ms step_avg:97.12ms +step:444/1670 train_time:43120ms step_avg:97.12ms +step:445/1670 train_time:43215ms step_avg:97.11ms +step:446/1670 train_time:43313ms step_avg:97.11ms +step:447/1670 train_time:43412ms step_avg:97.12ms +step:448/1670 train_time:43509ms step_avg:97.12ms +step:449/1670 train_time:43605ms step_avg:97.12ms +step:450/1670 train_time:43700ms step_avg:97.11ms +step:451/1670 train_time:43794ms step_avg:97.11ms +step:452/1670 train_time:43890ms step_avg:97.10ms +step:453/1670 train_time:43985ms step_avg:97.10ms +step:454/1670 train_time:44079ms step_avg:97.09ms +step:455/1670 train_time:44174ms step_avg:97.09ms +step:456/1670 train_time:44272ms step_avg:97.09ms +step:457/1670 train_time:44369ms step_avg:97.09ms +step:458/1670 train_time:44466ms step_avg:97.09ms +step:459/1670 train_time:44562ms step_avg:97.08ms +step:460/1670 train_time:44658ms step_avg:97.08ms +step:461/1670 train_time:44753ms step_avg:97.08ms +step:462/1670 train_time:44849ms step_avg:97.08ms +step:463/1670 train_time:44945ms step_avg:97.07ms +step:464/1670 train_time:45039ms step_avg:97.07ms +step:465/1670 train_time:45134ms step_avg:97.06ms +step:466/1670 train_time:45230ms step_avg:97.06ms +step:467/1670 train_time:45327ms step_avg:97.06ms +step:468/1670 train_time:45423ms step_avg:97.06ms +step:469/1670 train_time:45519ms step_avg:97.05ms +step:470/1670 train_time:45616ms step_avg:97.05ms +step:471/1670 train_time:45711ms step_avg:97.05ms +step:472/1670 train_time:45806ms step_avg:97.05ms +step:473/1670 train_time:45902ms step_avg:97.04ms +step:474/1670 train_time:45997ms step_avg:97.04ms +step:475/1670 train_time:46093ms step_avg:97.04ms +step:476/1670 train_time:46188ms step_avg:97.03ms +step:477/1670 train_time:46284ms step_avg:97.03ms +step:478/1670 train_time:46380ms step_avg:97.03ms +step:479/1670 train_time:46475ms step_avg:97.03ms +step:480/1670 train_time:46572ms step_avg:97.02ms +step:481/1670 train_time:46668ms step_avg:97.02ms +step:482/1670 train_time:46763ms step_avg:97.02ms +step:483/1670 train_time:46858ms step_avg:97.01ms +step:484/1670 train_time:46954ms step_avg:97.01ms +step:485/1670 train_time:47050ms step_avg:97.01ms +step:486/1670 train_time:47145ms step_avg:97.01ms +step:487/1670 train_time:47240ms step_avg:97.00ms +step:488/1670 train_time:47336ms step_avg:97.00ms +step:489/1670 train_time:47432ms step_avg:97.00ms +step:490/1670 train_time:47527ms step_avg:96.99ms +step:491/1670 train_time:47623ms step_avg:96.99ms +step:492/1670 train_time:47719ms step_avg:96.99ms 
+step:493/1670 train_time:47814ms step_avg:96.99ms +step:494/1670 train_time:47910ms step_avg:96.98ms +step:495/1670 train_time:48007ms step_avg:96.98ms +step:496/1670 train_time:48102ms step_avg:96.98ms +step:497/1670 train_time:48197ms step_avg:96.98ms +step:498/1670 train_time:48293ms step_avg:96.97ms +step:499/1670 train_time:48389ms step_avg:96.97ms +step:500/1670 train_time:48485ms step_avg:96.97ms +step:500/1670 val_loss:3.7166 train_time:48580ms step_avg:97.16ms +step:501/1670 train_time:48601ms step_avg:97.01ms +step:502/1670 train_time:48683ms step_avg:96.98ms +step:503/1670 train_time:48782ms step_avg:96.98ms +step:504/1670 train_time:48879ms step_avg:96.98ms +step:505/1670 train_time:48974ms step_avg:96.98ms +step:506/1670 train_time:49068ms step_avg:96.97ms +step:507/1670 train_time:49163ms step_avg:96.97ms +step:508/1670 train_time:49258ms step_avg:96.97ms +step:509/1670 train_time:49353ms step_avg:96.96ms +step:510/1670 train_time:49448ms step_avg:96.96ms +step:511/1670 train_time:49544ms step_avg:96.95ms +step:512/1670 train_time:49641ms step_avg:96.95ms +step:513/1670 train_time:49740ms step_avg:96.96ms +step:514/1670 train_time:49838ms step_avg:96.96ms +step:515/1670 train_time:49934ms step_avg:96.96ms +step:516/1670 train_time:50029ms step_avg:96.96ms +step:517/1670 train_time:50123ms step_avg:96.95ms +step:518/1670 train_time:50219ms step_avg:96.95ms +step:519/1670 train_time:50314ms step_avg:96.94ms +step:520/1670 train_time:50408ms step_avg:96.94ms +step:521/1670 train_time:50504ms step_avg:96.94ms +step:522/1670 train_time:50600ms step_avg:96.94ms +step:523/1670 train_time:50697ms step_avg:96.94ms +step:524/1670 train_time:50793ms step_avg:96.93ms +step:525/1670 train_time:50890ms step_avg:96.93ms +step:526/1670 train_time:50986ms step_avg:96.93ms +step:527/1670 train_time:51081ms step_avg:96.93ms +step:528/1670 train_time:51177ms step_avg:96.93ms +step:529/1670 train_time:51272ms step_avg:96.92ms +step:530/1670 train_time:51366ms step_avg:96.92ms +step:531/1670 train_time:51461ms step_avg:96.91ms +step:532/1670 train_time:51558ms step_avg:96.91ms +step:533/1670 train_time:51654ms step_avg:96.91ms +step:534/1670 train_time:51750ms step_avg:96.91ms +step:535/1670 train_time:51846ms step_avg:96.91ms +step:536/1670 train_time:51942ms step_avg:96.91ms +step:537/1670 train_time:52039ms step_avg:96.91ms +step:538/1670 train_time:52135ms step_avg:96.90ms +step:539/1670 train_time:52230ms step_avg:96.90ms +step:540/1670 train_time:52325ms step_avg:96.90ms +step:541/1670 train_time:52420ms step_avg:96.89ms +step:542/1670 train_time:52515ms step_avg:96.89ms +step:543/1670 train_time:52611ms step_avg:96.89ms +step:544/1670 train_time:52707ms step_avg:96.89ms +step:545/1670 train_time:52803ms step_avg:96.89ms +step:546/1670 train_time:52899ms step_avg:96.88ms +step:547/1670 train_time:52994ms step_avg:96.88ms +step:548/1670 train_time:53090ms step_avg:96.88ms +step:549/1670 train_time:53186ms step_avg:96.88ms +step:550/1670 train_time:53281ms step_avg:96.88ms +step:551/1670 train_time:53376ms step_avg:96.87ms +step:552/1670 train_time:53472ms step_avg:96.87ms +step:553/1670 train_time:53567ms step_avg:96.87ms +step:554/1670 train_time:53663ms step_avg:96.86ms +step:555/1670 train_time:53759ms step_avg:96.86ms +step:556/1670 train_time:53854ms step_avg:96.86ms +step:557/1670 train_time:53950ms step_avg:96.86ms +step:558/1670 train_time:54046ms step_avg:96.86ms +step:559/1670 train_time:54143ms step_avg:96.86ms +step:560/1670 train_time:54239ms step_avg:96.86ms +step:561/1670 
train_time:54337ms step_avg:96.86ms +step:562/1670 train_time:54434ms step_avg:96.86ms +step:563/1670 train_time:54531ms step_avg:96.86ms +step:564/1670 train_time:54627ms step_avg:96.86ms +step:565/1670 train_time:54724ms step_avg:96.86ms +step:566/1670 train_time:54822ms step_avg:96.86ms +step:567/1670 train_time:54919ms step_avg:96.86ms +step:568/1670 train_time:55016ms step_avg:96.86ms +step:569/1670 train_time:55114ms step_avg:96.86ms +step:570/1670 train_time:55211ms step_avg:96.86ms +step:571/1670 train_time:55308ms step_avg:96.86ms +step:572/1670 train_time:55405ms step_avg:96.86ms +step:573/1670 train_time:55502ms step_avg:96.86ms +step:574/1670 train_time:55600ms step_avg:96.86ms +step:575/1670 train_time:55697ms step_avg:96.87ms +step:576/1670 train_time:55794ms step_avg:96.87ms +step:577/1670 train_time:55891ms step_avg:96.87ms +step:578/1670 train_time:55988ms step_avg:96.86ms +step:579/1670 train_time:56087ms step_avg:96.87ms +step:580/1670 train_time:56184ms step_avg:96.87ms +step:581/1670 train_time:56281ms step_avg:96.87ms +step:582/1670 train_time:56380ms step_avg:96.87ms +step:583/1670 train_time:56478ms step_avg:96.87ms +step:584/1670 train_time:56575ms step_avg:96.87ms +step:585/1670 train_time:56671ms step_avg:96.87ms +step:586/1670 train_time:56768ms step_avg:96.87ms +step:587/1670 train_time:56864ms step_avg:96.87ms +step:588/1670 train_time:56962ms step_avg:96.87ms +step:589/1670 train_time:57059ms step_avg:96.87ms +step:590/1670 train_time:57155ms step_avg:96.87ms +step:591/1670 train_time:57253ms step_avg:96.87ms +step:592/1670 train_time:57350ms step_avg:96.88ms +step:593/1670 train_time:57447ms step_avg:96.88ms +step:594/1670 train_time:57544ms step_avg:96.88ms +step:595/1670 train_time:57642ms step_avg:96.88ms +step:596/1670 train_time:57740ms step_avg:96.88ms +step:597/1670 train_time:57838ms step_avg:96.88ms +step:598/1670 train_time:57935ms step_avg:96.88ms +step:599/1670 train_time:58032ms step_avg:96.88ms +step:600/1670 train_time:58129ms step_avg:96.88ms +step:601/1670 train_time:58226ms step_avg:96.88ms +step:602/1670 train_time:58323ms step_avg:96.88ms +step:603/1670 train_time:58420ms step_avg:96.88ms +step:604/1670 train_time:58517ms step_avg:96.88ms +step:605/1670 train_time:58615ms step_avg:96.88ms +step:606/1670 train_time:58712ms step_avg:96.88ms +step:607/1670 train_time:58808ms step_avg:96.88ms +step:608/1670 train_time:58906ms step_avg:96.88ms +step:609/1670 train_time:59004ms step_avg:96.89ms +step:610/1670 train_time:59101ms step_avg:96.89ms +step:611/1670 train_time:59199ms step_avg:96.89ms +step:612/1670 train_time:59295ms step_avg:96.89ms +step:613/1670 train_time:59392ms step_avg:96.89ms +step:614/1670 train_time:59490ms step_avg:96.89ms +step:615/1670 train_time:59586ms step_avg:96.89ms +step:616/1670 train_time:59684ms step_avg:96.89ms +step:617/1670 train_time:59782ms step_avg:96.89ms +step:618/1670 train_time:59879ms step_avg:96.89ms +step:619/1670 train_time:59977ms step_avg:96.89ms +step:620/1670 train_time:60074ms step_avg:96.89ms +step:621/1670 train_time:60171ms step_avg:96.89ms +step:622/1670 train_time:60268ms step_avg:96.89ms +step:623/1670 train_time:60364ms step_avg:96.89ms +step:624/1670 train_time:60462ms step_avg:96.89ms +step:625/1670 train_time:60559ms step_avg:96.90ms +step:625/1670 val_loss:3.6170 train_time:60656ms step_avg:97.05ms +step:626/1670 train_time:60679ms step_avg:96.93ms +step:627/1670 train_time:60765ms step_avg:96.91ms +step:628/1670 train_time:60863ms step_avg:96.92ms +step:629/1670 train_time:60960ms 
step_avg:96.92ms +step:630/1670 train_time:61056ms step_avg:96.91ms +step:631/1670 train_time:61152ms step_avg:96.91ms +step:632/1670 train_time:61248ms step_avg:96.91ms +step:633/1670 train_time:61343ms step_avg:96.91ms +step:634/1670 train_time:61438ms step_avg:96.91ms +step:635/1670 train_time:61534ms step_avg:96.90ms +step:636/1670 train_time:61632ms step_avg:96.91ms +step:637/1670 train_time:61733ms step_avg:96.91ms +step:638/1670 train_time:61833ms step_avg:96.92ms +step:639/1670 train_time:62113ms step_avg:97.20ms +step:640/1670 train_time:62303ms step_avg:97.35ms +step:641/1670 train_time:62398ms step_avg:97.34ms +step:642/1670 train_time:62493ms step_avg:97.34ms +step:643/1670 train_time:62589ms step_avg:97.34ms +step:644/1670 train_time:62685ms step_avg:97.34ms +step:645/1670 train_time:62780ms step_avg:97.33ms +step:646/1670 train_time:62876ms step_avg:97.33ms +step:647/1670 train_time:62973ms step_avg:97.33ms +step:648/1670 train_time:63071ms step_avg:97.33ms +step:649/1670 train_time:63176ms step_avg:97.34ms +step:650/1670 train_time:63278ms step_avg:97.35ms +step:651/1670 train_time:63375ms step_avg:97.35ms +step:652/1670 train_time:63472ms step_avg:97.35ms +step:653/1670 train_time:63568ms step_avg:97.35ms +step:654/1670 train_time:63664ms step_avg:97.35ms +step:655/1670 train_time:63760ms step_avg:97.34ms +step:656/1670 train_time:63856ms step_avg:97.34ms +step:657/1670 train_time:63952ms step_avg:97.34ms +step:658/1670 train_time:64049ms step_avg:97.34ms +step:659/1670 train_time:64148ms step_avg:97.34ms +step:660/1670 train_time:64247ms step_avg:97.34ms +step:661/1670 train_time:64345ms step_avg:97.34ms +step:662/1670 train_time:64442ms step_avg:97.34ms +step:663/1670 train_time:64539ms step_avg:97.34ms +step:664/1670 train_time:64635ms step_avg:97.34ms +step:665/1670 train_time:64731ms step_avg:97.34ms +step:666/1670 train_time:64828ms step_avg:97.34ms +step:667/1670 train_time:64923ms step_avg:97.34ms +step:668/1670 train_time:65019ms step_avg:97.33ms +step:669/1670 train_time:65116ms step_avg:97.33ms +step:670/1670 train_time:65214ms step_avg:97.33ms +step:671/1670 train_time:65313ms step_avg:97.34ms +step:672/1670 train_time:65411ms step_avg:97.34ms +step:673/1670 train_time:65509ms step_avg:97.34ms +step:674/1670 train_time:65607ms step_avg:97.34ms +step:675/1670 train_time:65703ms step_avg:97.34ms +step:676/1670 train_time:65799ms step_avg:97.34ms +step:677/1670 train_time:65896ms step_avg:97.34ms +step:678/1670 train_time:65992ms step_avg:97.33ms +step:679/1670 train_time:66089ms step_avg:97.33ms +step:680/1670 train_time:66186ms step_avg:97.33ms +step:681/1670 train_time:66284ms step_avg:97.33ms +step:682/1670 train_time:66382ms step_avg:97.33ms +step:683/1670 train_time:66479ms step_avg:97.33ms +step:684/1670 train_time:66576ms step_avg:97.33ms +step:685/1670 train_time:66673ms step_avg:97.33ms +step:686/1670 train_time:66770ms step_avg:97.33ms +step:687/1670 train_time:66869ms step_avg:97.33ms +step:688/1670 train_time:66965ms step_avg:97.33ms +step:689/1670 train_time:67062ms step_avg:97.33ms +step:690/1670 train_time:67158ms step_avg:97.33ms +step:691/1670 train_time:67256ms step_avg:97.33ms +step:692/1670 train_time:67354ms step_avg:97.33ms +step:693/1670 train_time:67451ms step_avg:97.33ms +step:694/1670 train_time:67548ms step_avg:97.33ms +step:695/1670 train_time:67645ms step_avg:97.33ms +step:696/1670 train_time:67742ms step_avg:97.33ms +step:697/1670 train_time:67839ms step_avg:97.33ms +step:698/1670 train_time:67936ms step_avg:97.33ms +step:699/1670 
train_time:68033ms step_avg:97.33ms +step:700/1670 train_time:68131ms step_avg:97.33ms +step:701/1670 train_time:68228ms step_avg:97.33ms +step:702/1670 train_time:68324ms step_avg:97.33ms +step:703/1670 train_time:68421ms step_avg:97.33ms +step:704/1670 train_time:68517ms step_avg:97.33ms +step:705/1670 train_time:68614ms step_avg:97.33ms +step:706/1670 train_time:68712ms step_avg:97.33ms +step:707/1670 train_time:68810ms step_avg:97.33ms +step:708/1670 train_time:68906ms step_avg:97.33ms +step:709/1670 train_time:69004ms step_avg:97.33ms +step:710/1670 train_time:69100ms step_avg:97.32ms +step:711/1670 train_time:69197ms step_avg:97.32ms +step:712/1670 train_time:69294ms step_avg:97.32ms +step:713/1670 train_time:69391ms step_avg:97.32ms +step:714/1670 train_time:69488ms step_avg:97.32ms +step:715/1670 train_time:69586ms step_avg:97.32ms +step:716/1670 train_time:69683ms step_avg:97.32ms +step:717/1670 train_time:69781ms step_avg:97.32ms +step:718/1670 train_time:69877ms step_avg:97.32ms +step:719/1670 train_time:69973ms step_avg:97.32ms +step:720/1670 train_time:70071ms step_avg:97.32ms +step:721/1670 train_time:70168ms step_avg:97.32ms +step:722/1670 train_time:70265ms step_avg:97.32ms +step:723/1670 train_time:70362ms step_avg:97.32ms +step:724/1670 train_time:70459ms step_avg:97.32ms +step:725/1670 train_time:70557ms step_avg:97.32ms +step:726/1670 train_time:70656ms step_avg:97.32ms +step:727/1670 train_time:70753ms step_avg:97.32ms +step:728/1670 train_time:70850ms step_avg:97.32ms +step:729/1670 train_time:70947ms step_avg:97.32ms +step:730/1670 train_time:71044ms step_avg:97.32ms +step:731/1670 train_time:71141ms step_avg:97.32ms +step:732/1670 train_time:71237ms step_avg:97.32ms +step:733/1670 train_time:71334ms step_avg:97.32ms +step:734/1670 train_time:71432ms step_avg:97.32ms +step:735/1670 train_time:71531ms step_avg:97.32ms +step:736/1670 train_time:71628ms step_avg:97.32ms +step:737/1670 train_time:71725ms step_avg:97.32ms +step:738/1670 train_time:71822ms step_avg:97.32ms +step:739/1670 train_time:71919ms step_avg:97.32ms +step:740/1670 train_time:72017ms step_avg:97.32ms +step:741/1670 train_time:72113ms step_avg:97.32ms +step:742/1670 train_time:72211ms step_avg:97.32ms +step:743/1670 train_time:72308ms step_avg:97.32ms +step:744/1670 train_time:72405ms step_avg:97.32ms +step:745/1670 train_time:72503ms step_avg:97.32ms +step:746/1670 train_time:72600ms step_avg:97.32ms +step:747/1670 train_time:72697ms step_avg:97.32ms +step:748/1670 train_time:72793ms step_avg:97.32ms +step:749/1670 train_time:72891ms step_avg:97.32ms +step:750/1670 train_time:72988ms step_avg:97.32ms +step:750/1670 val_loss:3.5615 train_time:73084ms step_avg:97.45ms +step:751/1670 train_time:73105ms step_avg:97.34ms +step:752/1670 train_time:73188ms step_avg:97.32ms +step:753/1670 train_time:73288ms step_avg:97.33ms +step:754/1670 train_time:73387ms step_avg:97.33ms +step:755/1670 train_time:73483ms step_avg:97.33ms +step:756/1670 train_time:73578ms step_avg:97.33ms +step:757/1670 train_time:73675ms step_avg:97.32ms +step:758/1670 train_time:73771ms step_avg:97.32ms +step:759/1670 train_time:73868ms step_avg:97.32ms +step:760/1670 train_time:73965ms step_avg:97.32ms +step:761/1670 train_time:74063ms step_avg:97.32ms +step:762/1670 train_time:74161ms step_avg:97.32ms +step:763/1670 train_time:74259ms step_avg:97.33ms +step:764/1670 train_time:74358ms step_avg:97.33ms +step:765/1670 train_time:74456ms step_avg:97.33ms +step:766/1670 train_time:74553ms step_avg:97.33ms +step:767/1670 train_time:74650ms 
step_avg:97.33ms +step:768/1670 train_time:74746ms step_avg:97.33ms +step:769/1670 train_time:74843ms step_avg:97.32ms +step:770/1670 train_time:74939ms step_avg:97.32ms +step:771/1670 train_time:75035ms step_avg:97.32ms +step:772/1670 train_time:75133ms step_avg:97.32ms +step:773/1670 train_time:75232ms step_avg:97.32ms +step:774/1670 train_time:75331ms step_avg:97.33ms +step:775/1670 train_time:75430ms step_avg:97.33ms +step:776/1670 train_time:75528ms step_avg:97.33ms +step:777/1670 train_time:75626ms step_avg:97.33ms +step:778/1670 train_time:75722ms step_avg:97.33ms +step:779/1670 train_time:75819ms step_avg:97.33ms +step:780/1670 train_time:75915ms step_avg:97.33ms +step:781/1670 train_time:76012ms step_avg:97.33ms +step:782/1670 train_time:76110ms step_avg:97.33ms +step:783/1670 train_time:76209ms step_avg:97.33ms +step:784/1670 train_time:76307ms step_avg:97.33ms +step:785/1670 train_time:76405ms step_avg:97.33ms +step:786/1670 train_time:76503ms step_avg:97.33ms +step:787/1670 train_time:76600ms step_avg:97.33ms +step:788/1670 train_time:76697ms step_avg:97.33ms +step:789/1670 train_time:76793ms step_avg:97.33ms +step:790/1670 train_time:76890ms step_avg:97.33ms +step:791/1670 train_time:76987ms step_avg:97.33ms +step:792/1670 train_time:77084ms step_avg:97.33ms +step:793/1670 train_time:77182ms step_avg:97.33ms +step:794/1670 train_time:77279ms step_avg:97.33ms +step:795/1670 train_time:77376ms step_avg:97.33ms +step:796/1670 train_time:77474ms step_avg:97.33ms +step:797/1670 train_time:77571ms step_avg:97.33ms +step:798/1670 train_time:77669ms step_avg:97.33ms +step:799/1670 train_time:77765ms step_avg:97.33ms +step:800/1670 train_time:77862ms step_avg:97.33ms +step:801/1670 train_time:77958ms step_avg:97.33ms +step:802/1670 train_time:78055ms step_avg:97.33ms +step:803/1670 train_time:78152ms step_avg:97.32ms +step:804/1670 train_time:78250ms step_avg:97.33ms +step:805/1670 train_time:78349ms step_avg:97.33ms +step:806/1670 train_time:78448ms step_avg:97.33ms +step:807/1670 train_time:78545ms step_avg:97.33ms +step:808/1670 train_time:78643ms step_avg:97.33ms +step:809/1670 train_time:78739ms step_avg:97.33ms +step:810/1670 train_time:78835ms step_avg:97.33ms +step:811/1670 train_time:78931ms step_avg:97.33ms +step:812/1670 train_time:79028ms step_avg:97.33ms +step:813/1670 train_time:79126ms step_avg:97.33ms +step:814/1670 train_time:79223ms step_avg:97.33ms +step:815/1670 train_time:79322ms step_avg:97.33ms +step:816/1670 train_time:79420ms step_avg:97.33ms +step:817/1670 train_time:79517ms step_avg:97.33ms +step:818/1670 train_time:79614ms step_avg:97.33ms +step:819/1670 train_time:79711ms step_avg:97.33ms +step:820/1670 train_time:79809ms step_avg:97.33ms +step:821/1670 train_time:79905ms step_avg:97.33ms +step:822/1670 train_time:80002ms step_avg:97.33ms +step:823/1670 train_time:80098ms step_avg:97.32ms +step:824/1670 train_time:80195ms step_avg:97.32ms +step:825/1670 train_time:80292ms step_avg:97.32ms +step:826/1670 train_time:80391ms step_avg:97.33ms +step:827/1670 train_time:80489ms step_avg:97.33ms +step:828/1670 train_time:80586ms step_avg:97.33ms +step:829/1670 train_time:80683ms step_avg:97.33ms +step:830/1670 train_time:80779ms step_avg:97.32ms +step:831/1670 train_time:80875ms step_avg:97.32ms +step:832/1670 train_time:80972ms step_avg:97.32ms +step:833/1670 train_time:81069ms step_avg:97.32ms +step:834/1670 train_time:81167ms step_avg:97.32ms +step:835/1670 train_time:81264ms step_avg:97.32ms +step:836/1670 train_time:81362ms step_avg:97.32ms +step:837/1670 
train_time:81459ms step_avg:97.32ms +step:838/1670 train_time:81556ms step_avg:97.32ms +step:839/1670 train_time:81653ms step_avg:97.32ms +step:840/1670 train_time:81751ms step_avg:97.32ms +step:841/1670 train_time:81849ms step_avg:97.32ms +step:842/1670 train_time:81946ms step_avg:97.32ms +step:843/1670 train_time:82044ms step_avg:97.32ms +step:844/1670 train_time:82140ms step_avg:97.32ms +step:845/1670 train_time:82236ms step_avg:97.32ms +step:846/1670 train_time:82334ms step_avg:97.32ms +step:847/1670 train_time:82432ms step_avg:97.32ms +step:848/1670 train_time:82529ms step_avg:97.32ms +step:849/1670 train_time:82626ms step_avg:97.32ms +step:850/1670 train_time:82723ms step_avg:97.32ms +step:851/1670 train_time:83006ms step_avg:97.54ms +step:852/1670 train_time:83168ms step_avg:97.61ms +step:853/1670 train_time:83263ms step_avg:97.61ms +step:854/1670 train_time:83358ms step_avg:97.61ms +step:855/1670 train_time:83454ms step_avg:97.61ms +step:856/1670 train_time:83550ms step_avg:97.60ms +step:857/1670 train_time:83646ms step_avg:97.60ms +step:858/1670 train_time:83741ms step_avg:97.60ms +step:859/1670 train_time:83837ms step_avg:97.60ms +step:860/1670 train_time:83933ms step_avg:97.60ms +step:861/1670 train_time:84033ms step_avg:97.60ms +step:862/1670 train_time:84134ms step_avg:97.60ms +step:863/1670 train_time:84233ms step_avg:97.60ms +step:864/1670 train_time:84330ms step_avg:97.60ms +step:865/1670 train_time:84428ms step_avg:97.60ms +step:866/1670 train_time:84525ms step_avg:97.60ms +step:867/1670 train_time:84620ms step_avg:97.60ms +step:868/1670 train_time:84716ms step_avg:97.60ms +step:869/1670 train_time:84812ms step_avg:97.60ms +step:870/1670 train_time:84908ms step_avg:97.60ms +step:871/1670 train_time:85004ms step_avg:97.59ms +step:872/1670 train_time:85102ms step_avg:97.59ms +step:873/1670 train_time:85200ms step_avg:97.59ms +step:874/1670 train_time:85299ms step_avg:97.60ms +step:875/1670 train_time:85396ms step_avg:97.59ms +step:875/1670 val_loss:3.5212 train_time:85492ms step_avg:97.70ms +step:876/1670 train_time:85513ms step_avg:97.62ms +step:877/1670 train_time:85594ms step_avg:97.60ms +step:878/1670 train_time:85693ms step_avg:97.60ms +step:879/1670 train_time:85790ms step_avg:97.60ms +step:880/1670 train_time:85886ms step_avg:97.60ms +step:881/1670 train_time:85982ms step_avg:97.60ms +step:882/1670 train_time:86077ms step_avg:97.59ms +step:883/1670 train_time:86173ms step_avg:97.59ms +step:884/1670 train_time:86270ms step_avg:97.59ms +step:885/1670 train_time:86365ms step_avg:97.59ms +step:886/1670 train_time:86464ms step_avg:97.59ms +step:887/1670 train_time:86563ms step_avg:97.59ms +step:888/1670 train_time:86664ms step_avg:97.59ms +step:889/1670 train_time:86762ms step_avg:97.60ms +step:890/1670 train_time:86859ms step_avg:97.59ms +step:891/1670 train_time:86955ms step_avg:97.59ms +step:892/1670 train_time:87051ms step_avg:97.59ms +step:893/1670 train_time:87147ms step_avg:97.59ms +step:894/1670 train_time:87243ms step_avg:97.59ms +step:895/1670 train_time:87340ms step_avg:97.59ms +step:896/1670 train_time:87438ms step_avg:97.59ms +step:897/1670 train_time:87538ms step_avg:97.59ms +step:898/1670 train_time:87637ms step_avg:97.59ms +step:899/1670 train_time:87735ms step_avg:97.59ms +step:900/1670 train_time:87832ms step_avg:97.59ms +step:901/1670 train_time:87928ms step_avg:97.59ms +step:902/1670 train_time:88024ms step_avg:97.59ms +step:903/1670 train_time:88120ms step_avg:97.59ms +step:904/1670 train_time:88216ms step_avg:97.58ms +step:905/1670 train_time:88313ms 
step_avg:97.58ms +step:906/1670 train_time:88409ms step_avg:97.58ms +step:907/1670 train_time:88507ms step_avg:97.58ms +step:908/1670 train_time:88605ms step_avg:97.58ms +step:909/1670 train_time:88702ms step_avg:97.58ms +step:910/1670 train_time:88799ms step_avg:97.58ms +step:911/1670 train_time:88898ms step_avg:97.58ms +step:912/1670 train_time:88995ms step_avg:97.58ms +step:913/1670 train_time:89092ms step_avg:97.58ms +step:914/1670 train_time:89188ms step_avg:97.58ms +step:915/1670 train_time:89284ms step_avg:97.58ms +step:916/1670 train_time:89380ms step_avg:97.58ms +step:917/1670 train_time:89479ms step_avg:97.58ms +step:918/1670 train_time:89577ms step_avg:97.58ms +step:919/1670 train_time:89674ms step_avg:97.58ms +step:920/1670 train_time:89771ms step_avg:97.58ms +step:921/1670 train_time:89869ms step_avg:97.58ms +step:922/1670 train_time:89966ms step_avg:97.58ms +step:923/1670 train_time:90063ms step_avg:97.58ms +step:924/1670 train_time:90160ms step_avg:97.58ms +step:925/1670 train_time:90258ms step_avg:97.58ms +step:926/1670 train_time:90354ms step_avg:97.57ms +step:927/1670 train_time:90452ms step_avg:97.58ms +step:928/1670 train_time:90549ms step_avg:97.57ms +step:929/1670 train_time:90647ms step_avg:97.57ms +step:930/1670 train_time:90745ms step_avg:97.58ms +step:931/1670 train_time:90842ms step_avg:97.57ms +step:932/1670 train_time:90940ms step_avg:97.58ms +step:933/1670 train_time:91038ms step_avg:97.58ms +step:934/1670 train_time:91135ms step_avg:97.58ms +step:935/1670 train_time:91232ms step_avg:97.57ms +step:936/1670 train_time:91328ms step_avg:97.57ms +step:937/1670 train_time:91425ms step_avg:97.57ms +step:938/1670 train_time:91522ms step_avg:97.57ms +step:939/1670 train_time:91620ms step_avg:97.57ms +step:940/1670 train_time:91717ms step_avg:97.57ms +step:941/1670 train_time:91815ms step_avg:97.57ms +step:942/1670 train_time:91912ms step_avg:97.57ms +step:943/1670 train_time:92009ms step_avg:97.57ms +step:944/1670 train_time:92106ms step_avg:97.57ms +step:945/1670 train_time:92202ms step_avg:97.57ms +step:946/1670 train_time:92300ms step_avg:97.57ms +step:947/1670 train_time:92398ms step_avg:97.57ms +step:948/1670 train_time:92495ms step_avg:97.57ms +step:949/1670 train_time:92593ms step_avg:97.57ms +step:950/1670 train_time:92689ms step_avg:97.57ms +step:951/1670 train_time:92786ms step_avg:97.57ms +step:952/1670 train_time:92883ms step_avg:97.57ms +step:953/1670 train_time:92981ms step_avg:97.57ms +step:954/1670 train_time:93078ms step_avg:97.57ms +step:955/1670 train_time:93175ms step_avg:97.56ms +step:956/1670 train_time:93271ms step_avg:97.56ms +step:957/1670 train_time:93368ms step_avg:97.56ms +step:958/1670 train_time:93465ms step_avg:97.56ms +step:959/1670 train_time:93563ms step_avg:97.56ms +step:960/1670 train_time:93660ms step_avg:97.56ms +step:961/1670 train_time:93759ms step_avg:97.56ms +step:962/1670 train_time:93856ms step_avg:97.56ms +step:963/1670 train_time:93953ms step_avg:97.56ms +step:964/1670 train_time:94049ms step_avg:97.56ms +step:965/1670 train_time:94146ms step_avg:97.56ms +step:966/1670 train_time:94243ms step_avg:97.56ms +step:967/1670 train_time:94339ms step_avg:97.56ms +step:968/1670 train_time:94437ms step_avg:97.56ms +step:969/1670 train_time:94534ms step_avg:97.56ms +step:970/1670 train_time:94631ms step_avg:97.56ms +step:971/1670 train_time:94729ms step_avg:97.56ms +step:972/1670 train_time:94825ms step_avg:97.56ms +step:973/1670 train_time:94922ms step_avg:97.56ms +step:974/1670 train_time:95021ms step_avg:97.56ms +step:975/1670 
train_time:95118ms step_avg:97.56ms +step:976/1670 train_time:95215ms step_avg:97.56ms +step:977/1670 train_time:95312ms step_avg:97.56ms +step:978/1670 train_time:95408ms step_avg:97.55ms +step:979/1670 train_time:95506ms step_avg:97.55ms +step:980/1670 train_time:95603ms step_avg:97.55ms +step:981/1670 train_time:95700ms step_avg:97.55ms +step:982/1670 train_time:95798ms step_avg:97.55ms +step:983/1670 train_time:95896ms step_avg:97.55ms +step:984/1670 train_time:95993ms step_avg:97.55ms +step:985/1670 train_time:96089ms step_avg:97.55ms +step:986/1670 train_time:96186ms step_avg:97.55ms +step:987/1670 train_time:96285ms step_avg:97.55ms +step:988/1670 train_time:96382ms step_avg:97.55ms +step:989/1670 train_time:96480ms step_avg:97.55ms +step:990/1670 train_time:96577ms step_avg:97.55ms +step:991/1670 train_time:96674ms step_avg:97.55ms +step:992/1670 train_time:96771ms step_avg:97.55ms +step:993/1670 train_time:96869ms step_avg:97.55ms +step:994/1670 train_time:96966ms step_avg:97.55ms +step:995/1670 train_time:97063ms step_avg:97.55ms +step:996/1670 train_time:97160ms step_avg:97.55ms +step:997/1670 train_time:97257ms step_avg:97.55ms +step:998/1670 train_time:97354ms step_avg:97.55ms +step:999/1670 train_time:97453ms step_avg:97.55ms +step:1000/1670 train_time:97550ms step_avg:97.55ms +step:1000/1670 val_loss:3.4775 train_time:97645ms step_avg:97.65ms +step:1001/1670 train_time:97667ms step_avg:97.57ms +step:1002/1670 train_time:97751ms step_avg:97.56ms +step:1003/1670 train_time:97852ms step_avg:97.56ms +step:1004/1670 train_time:97949ms step_avg:97.56ms +step:1005/1670 train_time:98046ms step_avg:97.56ms +step:1006/1670 train_time:98142ms step_avg:97.56ms +step:1007/1670 train_time:98238ms step_avg:97.56ms +step:1008/1670 train_time:98334ms step_avg:97.55ms +step:1009/1670 train_time:98430ms step_avg:97.55ms +step:1010/1670 train_time:98525ms step_avg:97.55ms +step:1011/1670 train_time:98623ms step_avg:97.55ms +step:1012/1670 train_time:98721ms step_avg:97.55ms +step:1013/1670 train_time:98820ms step_avg:97.55ms +step:1014/1670 train_time:98917ms step_avg:97.55ms +step:1015/1670 train_time:99015ms step_avg:97.55ms +step:1016/1670 train_time:99113ms step_avg:97.55ms +step:1017/1670 train_time:99209ms step_avg:97.55ms +step:1018/1670 train_time:99305ms step_avg:97.55ms +step:1019/1670 train_time:99402ms step_avg:97.55ms +step:1020/1670 train_time:99498ms step_avg:97.55ms +step:1021/1670 train_time:99595ms step_avg:97.55ms +step:1022/1670 train_time:99692ms step_avg:97.55ms +step:1023/1670 train_time:99790ms step_avg:97.55ms +step:1024/1670 train_time:99887ms step_avg:97.55ms +step:1025/1670 train_time:99985ms step_avg:97.55ms +step:1026/1670 train_time:100084ms step_avg:97.55ms +step:1027/1670 train_time:100181ms step_avg:97.55ms +step:1028/1670 train_time:100278ms step_avg:97.55ms +step:1029/1670 train_time:100375ms step_avg:97.55ms +step:1030/1670 train_time:100471ms step_avg:97.54ms +step:1031/1670 train_time:100568ms step_avg:97.54ms +step:1032/1670 train_time:100664ms step_avg:97.54ms +step:1033/1670 train_time:100761ms step_avg:97.54ms +step:1034/1670 train_time:100859ms step_avg:97.54ms +step:1035/1670 train_time:100957ms step_avg:97.54ms +step:1036/1670 train_time:101054ms step_avg:97.54ms +step:1037/1670 train_time:101152ms step_avg:97.54ms +step:1038/1670 train_time:101249ms step_avg:97.54ms +step:1039/1670 train_time:101345ms step_avg:97.54ms +step:1040/1670 train_time:101442ms step_avg:97.54ms +step:1041/1670 train_time:101539ms step_avg:97.54ms +step:1042/1670 
train_time:101636ms step_avg:97.54ms +step:1043/1670 train_time:101734ms step_avg:97.54ms +step:1044/1670 train_time:101831ms step_avg:97.54ms +step:1045/1670 train_time:101927ms step_avg:97.54ms +step:1046/1670 train_time:102024ms step_avg:97.54ms +step:1047/1670 train_time:102122ms step_avg:97.54ms +step:1048/1670 train_time:102219ms step_avg:97.54ms +step:1049/1670 train_time:102317ms step_avg:97.54ms +step:1050/1670 train_time:102415ms step_avg:97.54ms +step:1051/1670 train_time:102512ms step_avg:97.54ms +step:1052/1670 train_time:102609ms step_avg:97.54ms +step:1053/1670 train_time:102705ms step_avg:97.54ms +step:1054/1670 train_time:102803ms step_avg:97.54ms +step:1055/1670 train_time:102900ms step_avg:97.54ms +step:1056/1670 train_time:102997ms step_avg:97.54ms +step:1057/1670 train_time:103095ms step_avg:97.54ms +step:1058/1670 train_time:103192ms step_avg:97.53ms +step:1059/1670 train_time:103288ms step_avg:97.53ms +step:1060/1670 train_time:103385ms step_avg:97.53ms +step:1061/1670 train_time:103482ms step_avg:97.53ms +step:1062/1670 train_time:103761ms step_avg:97.70ms +step:1063/1670 train_time:103848ms step_avg:97.69ms +step:1064/1670 train_time:103943ms step_avg:97.69ms +step:1065/1670 train_time:104039ms step_avg:97.69ms +step:1066/1670 train_time:104135ms step_avg:97.69ms +step:1067/1670 train_time:104231ms step_avg:97.69ms +step:1068/1670 train_time:104326ms step_avg:97.68ms +step:1069/1670 train_time:104423ms step_avg:97.68ms +step:1070/1670 train_time:104519ms step_avg:97.68ms +step:1071/1670 train_time:104615ms step_avg:97.68ms +step:1072/1670 train_time:104716ms step_avg:97.68ms +step:1073/1670 train_time:104817ms step_avg:97.69ms +step:1074/1670 train_time:104917ms step_avg:97.69ms +step:1075/1670 train_time:105014ms step_avg:97.69ms +step:1076/1670 train_time:105110ms step_avg:97.69ms +step:1077/1670 train_time:105206ms step_avg:97.68ms +step:1078/1670 train_time:105302ms step_avg:97.68ms +step:1079/1670 train_time:105398ms step_avg:97.68ms +step:1080/1670 train_time:105494ms step_avg:97.68ms +step:1081/1670 train_time:105590ms step_avg:97.68ms +step:1082/1670 train_time:105688ms step_avg:97.68ms +step:1083/1670 train_time:105786ms step_avg:97.68ms +step:1084/1670 train_time:105884ms step_avg:97.68ms +step:1085/1670 train_time:105983ms step_avg:97.68ms +step:1086/1670 train_time:106080ms step_avg:97.68ms +step:1087/1670 train_time:106178ms step_avg:97.68ms +step:1088/1670 train_time:106275ms step_avg:97.68ms +step:1089/1670 train_time:106371ms step_avg:97.68ms +step:1090/1670 train_time:106467ms step_avg:97.68ms +step:1091/1670 train_time:106563ms step_avg:97.67ms +step:1092/1670 train_time:106660ms step_avg:97.67ms +step:1093/1670 train_time:106758ms step_avg:97.67ms +step:1094/1670 train_time:106857ms step_avg:97.68ms +step:1095/1670 train_time:106957ms step_avg:97.68ms +step:1096/1670 train_time:107055ms step_avg:97.68ms +step:1097/1670 train_time:107152ms step_avg:97.68ms +step:1098/1670 train_time:107249ms step_avg:97.68ms +step:1099/1670 train_time:107345ms step_avg:97.68ms +step:1100/1670 train_time:107442ms step_avg:97.67ms +step:1101/1670 train_time:107539ms step_avg:97.67ms +step:1102/1670 train_time:107635ms step_avg:97.67ms +step:1103/1670 train_time:107732ms step_avg:97.67ms +step:1104/1670 train_time:107829ms step_avg:97.67ms +step:1105/1670 train_time:107927ms step_avg:97.67ms +step:1106/1670 train_time:108025ms step_avg:97.67ms +step:1107/1670 train_time:108122ms step_avg:97.67ms +step:1108/1670 train_time:108220ms step_avg:97.67ms +step:1109/1670 
train_time:108317ms step_avg:97.67ms +step:1110/1670 train_time:108415ms step_avg:97.67ms +step:1111/1670 train_time:108512ms step_avg:97.67ms +step:1112/1670 train_time:108608ms step_avg:97.67ms +step:1113/1670 train_time:108704ms step_avg:97.67ms +step:1114/1670 train_time:108801ms step_avg:97.67ms +step:1115/1670 train_time:108898ms step_avg:97.67ms +step:1116/1670 train_time:108997ms step_avg:97.67ms +step:1117/1670 train_time:109095ms step_avg:97.67ms +step:1118/1670 train_time:109194ms step_avg:97.67ms +step:1119/1670 train_time:109291ms step_avg:97.67ms +step:1120/1670 train_time:109389ms step_avg:97.67ms +step:1121/1670 train_time:109486ms step_avg:97.67ms +step:1122/1670 train_time:109583ms step_avg:97.67ms +step:1123/1670 train_time:109680ms step_avg:97.67ms +step:1124/1670 train_time:109778ms step_avg:97.67ms +step:1125/1670 train_time:109875ms step_avg:97.67ms +step:1125/1670 val_loss:3.4241 train_time:109973ms step_avg:97.75ms +step:1126/1670 train_time:109994ms step_avg:97.69ms +step:1127/1670 train_time:110086ms step_avg:97.68ms +step:1128/1670 train_time:110188ms step_avg:97.68ms +step:1129/1670 train_time:110285ms step_avg:97.68ms +step:1130/1670 train_time:110382ms step_avg:97.68ms +step:1131/1670 train_time:110479ms step_avg:97.68ms +step:1132/1670 train_time:110576ms step_avg:97.68ms +step:1133/1670 train_time:110673ms step_avg:97.68ms +step:1134/1670 train_time:110769ms step_avg:97.68ms +step:1135/1670 train_time:110865ms step_avg:97.68ms +step:1136/1670 train_time:110965ms step_avg:97.68ms +step:1137/1670 train_time:111067ms step_avg:97.68ms +step:1138/1670 train_time:111166ms step_avg:97.69ms +step:1139/1670 train_time:111265ms step_avg:97.69ms +step:1140/1670 train_time:111363ms step_avg:97.69ms +step:1141/1670 train_time:111461ms step_avg:97.69ms +step:1142/1670 train_time:111558ms step_avg:97.69ms +step:1143/1670 train_time:111655ms step_avg:97.69ms +step:1144/1670 train_time:111752ms step_avg:97.68ms +step:1145/1670 train_time:111848ms step_avg:97.68ms +step:1146/1670 train_time:111945ms step_avg:97.68ms +step:1147/1670 train_time:112043ms step_avg:97.68ms +step:1148/1670 train_time:112142ms step_avg:97.68ms +step:1149/1670 train_time:112240ms step_avg:97.69ms +step:1150/1670 train_time:112339ms step_avg:97.69ms +step:1151/1670 train_time:112437ms step_avg:97.69ms +step:1152/1670 train_time:112534ms step_avg:97.69ms +step:1153/1670 train_time:112631ms step_avg:97.69ms +step:1154/1670 train_time:112728ms step_avg:97.68ms +step:1155/1670 train_time:112825ms step_avg:97.68ms +step:1156/1670 train_time:112923ms step_avg:97.68ms +step:1157/1670 train_time:113020ms step_avg:97.68ms +step:1158/1670 train_time:113119ms step_avg:97.68ms +step:1159/1670 train_time:113217ms step_avg:97.68ms +step:1160/1670 train_time:113315ms step_avg:97.68ms +step:1161/1670 train_time:113413ms step_avg:97.69ms +step:1162/1670 train_time:113511ms step_avg:97.69ms +step:1163/1670 train_time:113609ms step_avg:97.69ms +step:1164/1670 train_time:113705ms step_avg:97.68ms +step:1165/1670 train_time:113802ms step_avg:97.68ms +step:1166/1670 train_time:113900ms step_avg:97.68ms +step:1167/1670 train_time:113998ms step_avg:97.68ms +step:1168/1670 train_time:114095ms step_avg:97.68ms +step:1169/1670 train_time:114194ms step_avg:97.69ms +step:1170/1670 train_time:114292ms step_avg:97.69ms +step:1171/1670 train_time:114392ms step_avg:97.69ms +step:1172/1670 train_time:114492ms step_avg:97.69ms +step:1173/1670 train_time:114590ms step_avg:97.69ms +step:1174/1670 train_time:114688ms step_avg:97.69ms 
+step:1175/1670 train_time:114786ms step_avg:97.69ms +step:1176/1670 train_time:114883ms step_avg:97.69ms +step:1177/1670 train_time:114980ms step_avg:97.69ms +step:1178/1670 train_time:115077ms step_avg:97.69ms +step:1179/1670 train_time:115175ms step_avg:97.69ms +step:1180/1670 train_time:115273ms step_avg:97.69ms +step:1181/1670 train_time:115371ms step_avg:97.69ms +step:1182/1670 train_time:115470ms step_avg:97.69ms +step:1183/1670 train_time:115569ms step_avg:97.69ms +step:1184/1670 train_time:115668ms step_avg:97.69ms +step:1185/1670 train_time:115766ms step_avg:97.69ms +step:1186/1670 train_time:115865ms step_avg:97.69ms +step:1187/1670 train_time:115962ms step_avg:97.69ms +step:1188/1670 train_time:116059ms step_avg:97.69ms +step:1189/1670 train_time:116157ms step_avg:97.69ms +step:1190/1670 train_time:116254ms step_avg:97.69ms +step:1191/1670 train_time:116352ms step_avg:97.69ms +step:1192/1670 train_time:116451ms step_avg:97.69ms +step:1193/1670 train_time:116549ms step_avg:97.69ms +step:1194/1670 train_time:116647ms step_avg:97.69ms +step:1195/1670 train_time:116744ms step_avg:97.69ms +step:1196/1670 train_time:116842ms step_avg:97.69ms +step:1197/1670 train_time:116939ms step_avg:97.69ms +step:1198/1670 train_time:117036ms step_avg:97.69ms +step:1199/1670 train_time:117134ms step_avg:97.69ms +step:1200/1670 train_time:117231ms step_avg:97.69ms +step:1201/1670 train_time:117329ms step_avg:97.69ms +step:1202/1670 train_time:117429ms step_avg:97.69ms +step:1203/1670 train_time:117526ms step_avg:97.69ms +step:1204/1670 train_time:117623ms step_avg:97.69ms +step:1205/1670 train_time:117721ms step_avg:97.69ms +step:1206/1670 train_time:117819ms step_avg:97.69ms +step:1207/1670 train_time:117917ms step_avg:97.69ms +step:1208/1670 train_time:118015ms step_avg:97.69ms +step:1209/1670 train_time:118112ms step_avg:97.69ms +step:1210/1670 train_time:118210ms step_avg:97.69ms +step:1211/1670 train_time:118307ms step_avg:97.69ms +step:1212/1670 train_time:118405ms step_avg:97.69ms +step:1213/1670 train_time:118503ms step_avg:97.69ms +step:1214/1670 train_time:118600ms step_avg:97.69ms +step:1215/1670 train_time:118697ms step_avg:97.69ms +step:1216/1670 train_time:118795ms step_avg:97.69ms +step:1217/1670 train_time:118893ms step_avg:97.69ms +step:1218/1670 train_time:118992ms step_avg:97.69ms +step:1219/1670 train_time:119089ms step_avg:97.69ms +step:1220/1670 train_time:119186ms step_avg:97.69ms +step:1221/1670 train_time:119284ms step_avg:97.69ms +step:1222/1670 train_time:119382ms step_avg:97.69ms +step:1223/1670 train_time:119479ms step_avg:97.69ms +step:1224/1670 train_time:119576ms step_avg:97.69ms +step:1225/1670 train_time:119674ms step_avg:97.69ms +step:1226/1670 train_time:119772ms step_avg:97.69ms +step:1227/1670 train_time:119871ms step_avg:97.69ms +step:1228/1670 train_time:119972ms step_avg:97.70ms +step:1229/1670 train_time:120071ms step_avg:97.70ms +step:1230/1670 train_time:120169ms step_avg:97.70ms +step:1231/1670 train_time:120267ms step_avg:97.70ms +step:1232/1670 train_time:120365ms step_avg:97.70ms +step:1233/1670 train_time:120463ms step_avg:97.70ms +step:1234/1670 train_time:120560ms step_avg:97.70ms +step:1235/1670 train_time:120657ms step_avg:97.70ms +step:1236/1670 train_time:120755ms step_avg:97.70ms +step:1237/1670 train_time:120853ms step_avg:97.70ms +step:1238/1670 train_time:120949ms step_avg:97.70ms +step:1239/1670 train_time:121046ms step_avg:97.70ms +step:1240/1670 train_time:121144ms step_avg:97.70ms +step:1241/1670 train_time:121241ms step_avg:97.70ms 
+step:1242/1670 train_time:121339ms step_avg:97.70ms +step:1243/1670 train_time:121437ms step_avg:97.70ms +step:1244/1670 train_time:121535ms step_avg:97.70ms +step:1245/1670 train_time:121633ms step_avg:97.70ms +step:1246/1670 train_time:121733ms step_avg:97.70ms +step:1247/1670 train_time:121831ms step_avg:97.70ms +step:1248/1670 train_time:121929ms step_avg:97.70ms +step:1249/1670 train_time:122028ms step_avg:97.70ms +step:1250/1670 train_time:122126ms step_avg:97.70ms +step:1250/1670 val_loss:3.3814 train_time:122222ms step_avg:97.78ms +step:1251/1670 train_time:122244ms step_avg:97.72ms +step:1252/1670 train_time:122329ms step_avg:97.71ms +step:1253/1670 train_time:122428ms step_avg:97.71ms +step:1254/1670 train_time:122526ms step_avg:97.71ms +step:1255/1670 train_time:122623ms step_avg:97.71ms +step:1256/1670 train_time:122720ms step_avg:97.71ms +step:1257/1670 train_time:122816ms step_avg:97.71ms +step:1258/1670 train_time:122912ms step_avg:97.70ms +step:1259/1670 train_time:123008ms step_avg:97.70ms +step:1260/1670 train_time:123104ms step_avg:97.70ms +step:1261/1670 train_time:123202ms step_avg:97.70ms +step:1262/1670 train_time:123303ms step_avg:97.70ms +step:1263/1670 train_time:123402ms step_avg:97.71ms +step:1264/1670 train_time:123501ms step_avg:97.71ms +step:1265/1670 train_time:123600ms step_avg:97.71ms +step:1266/1670 train_time:123697ms step_avg:97.71ms +step:1267/1670 train_time:123795ms step_avg:97.71ms +step:1268/1670 train_time:123892ms step_avg:97.71ms +step:1269/1670 train_time:123989ms step_avg:97.71ms +step:1270/1670 train_time:124085ms step_avg:97.71ms +step:1271/1670 train_time:124182ms step_avg:97.70ms +step:1272/1670 train_time:124281ms step_avg:97.71ms +step:1273/1670 train_time:124382ms step_avg:97.71ms +step:1274/1670 train_time:124776ms step_avg:97.94ms +step:1275/1670 train_time:124877ms step_avg:97.94ms +step:1276/1670 train_time:124972ms step_avg:97.94ms +step:1277/1670 train_time:125069ms step_avg:97.94ms +step:1278/1670 train_time:125165ms step_avg:97.94ms +step:1279/1670 train_time:125262ms step_avg:97.94ms +step:1280/1670 train_time:125359ms step_avg:97.94ms +step:1281/1670 train_time:125456ms step_avg:97.94ms +step:1282/1670 train_time:125552ms step_avg:97.93ms +step:1283/1670 train_time:125649ms step_avg:97.93ms +step:1284/1670 train_time:125746ms step_avg:97.93ms +step:1285/1670 train_time:125847ms step_avg:97.94ms +step:1286/1670 train_time:125946ms step_avg:97.94ms +step:1287/1670 train_time:126044ms step_avg:97.94ms +step:1288/1670 train_time:126142ms step_avg:97.94ms +step:1289/1670 train_time:126240ms step_avg:97.94ms +step:1290/1670 train_time:126337ms step_avg:97.94ms +step:1291/1670 train_time:126434ms step_avg:97.93ms +step:1292/1670 train_time:126530ms step_avg:97.93ms +step:1293/1670 train_time:126627ms step_avg:97.93ms +step:1294/1670 train_time:126724ms step_avg:97.93ms +step:1295/1670 train_time:126823ms step_avg:97.93ms +step:1296/1670 train_time:126922ms step_avg:97.93ms +step:1297/1670 train_time:127021ms step_avg:97.93ms +step:1298/1670 train_time:127119ms step_avg:97.93ms +step:1299/1670 train_time:127217ms step_avg:97.93ms +step:1300/1670 train_time:127314ms step_avg:97.93ms +step:1301/1670 train_time:127411ms step_avg:97.93ms +step:1302/1670 train_time:127509ms step_avg:97.93ms +step:1303/1670 train_time:127605ms step_avg:97.93ms +step:1304/1670 train_time:127703ms step_avg:97.93ms +step:1305/1670 train_time:127800ms step_avg:97.93ms +step:1306/1670 train_time:127898ms step_avg:97.93ms +step:1307/1670 train_time:127997ms 
step_avg:97.93ms +step:1308/1670 train_time:128097ms step_avg:97.93ms +step:1309/1670 train_time:128195ms step_avg:97.93ms +step:1310/1670 train_time:128291ms step_avg:97.93ms +step:1311/1670 train_time:128388ms step_avg:97.93ms +step:1312/1670 train_time:128485ms step_avg:97.93ms +step:1313/1670 train_time:128582ms step_avg:97.93ms +step:1314/1670 train_time:128681ms step_avg:97.93ms +step:1315/1670 train_time:128778ms step_avg:97.93ms +step:1316/1670 train_time:128877ms step_avg:97.93ms +step:1317/1670 train_time:128975ms step_avg:97.93ms +step:1318/1670 train_time:129073ms step_avg:97.93ms +step:1319/1670 train_time:129171ms step_avg:97.93ms +step:1320/1670 train_time:129268ms step_avg:97.93ms +step:1321/1670 train_time:129366ms step_avg:97.93ms +step:1322/1670 train_time:129464ms step_avg:97.93ms +step:1323/1670 train_time:129562ms step_avg:97.93ms +step:1324/1670 train_time:129660ms step_avg:97.93ms +step:1325/1670 train_time:129759ms step_avg:97.93ms +step:1326/1670 train_time:129856ms step_avg:97.93ms +step:1327/1670 train_time:129954ms step_avg:97.93ms +step:1328/1670 train_time:130051ms step_avg:97.93ms +step:1329/1670 train_time:130149ms step_avg:97.93ms +step:1330/1670 train_time:130247ms step_avg:97.93ms +step:1331/1670 train_time:130345ms step_avg:97.93ms +step:1332/1670 train_time:130442ms step_avg:97.93ms +step:1333/1670 train_time:130540ms step_avg:97.93ms +step:1334/1670 train_time:130637ms step_avg:97.93ms +step:1335/1670 train_time:130736ms step_avg:97.93ms +step:1336/1670 train_time:130833ms step_avg:97.93ms +step:1337/1670 train_time:130931ms step_avg:97.93ms +step:1338/1670 train_time:131029ms step_avg:97.93ms +step:1339/1670 train_time:131129ms step_avg:97.93ms +step:1340/1670 train_time:131227ms step_avg:97.93ms +step:1341/1670 train_time:131324ms step_avg:97.93ms +step:1342/1670 train_time:131421ms step_avg:97.93ms +step:1343/1670 train_time:131518ms step_avg:97.93ms +step:1344/1670 train_time:131616ms step_avg:97.93ms +step:1345/1670 train_time:131714ms step_avg:97.93ms +step:1346/1670 train_time:131812ms step_avg:97.93ms +step:1347/1670 train_time:131909ms step_avg:97.93ms +step:1348/1670 train_time:132006ms step_avg:97.93ms +step:1349/1670 train_time:132106ms step_avg:97.93ms +step:1350/1670 train_time:132204ms step_avg:97.93ms +step:1351/1670 train_time:132301ms step_avg:97.93ms +step:1352/1670 train_time:132399ms step_avg:97.93ms +step:1353/1670 train_time:132496ms step_avg:97.93ms +step:1354/1670 train_time:132595ms step_avg:97.93ms +step:1355/1670 train_time:132693ms step_avg:97.93ms +step:1356/1670 train_time:132790ms step_avg:97.93ms +step:1357/1670 train_time:132887ms step_avg:97.93ms +step:1358/1670 train_time:132985ms step_avg:97.93ms +step:1359/1670 train_time:133082ms step_avg:97.93ms +step:1360/1670 train_time:133181ms step_avg:97.93ms +step:1361/1670 train_time:133278ms step_avg:97.93ms +step:1362/1670 train_time:133377ms step_avg:97.93ms +step:1363/1670 train_time:133475ms step_avg:97.93ms +step:1364/1670 train_time:133573ms step_avg:97.93ms +step:1365/1670 train_time:133671ms step_avg:97.93ms +step:1366/1670 train_time:133768ms step_avg:97.93ms +step:1367/1670 train_time:133866ms step_avg:97.93ms +step:1368/1670 train_time:133963ms step_avg:97.93ms +step:1369/1670 train_time:134062ms step_avg:97.93ms +step:1370/1670 train_time:134161ms step_avg:97.93ms +step:1371/1670 train_time:134259ms step_avg:97.93ms +step:1372/1670 train_time:134357ms step_avg:97.93ms +step:1373/1670 train_time:134455ms step_avg:97.93ms +step:1374/1670 train_time:134553ms 
step_avg:97.93ms +step:1375/1670 train_time:134651ms step_avg:97.93ms +step:1375/1670 val_loss:3.3443 train_time:134749ms step_avg:98.00ms +step:1376/1670 train_time:134770ms step_avg:97.94ms +step:1377/1670 train_time:134855ms step_avg:97.93ms +step:1378/1670 train_time:134959ms step_avg:97.94ms +step:1379/1670 train_time:135057ms step_avg:97.94ms +step:1380/1670 train_time:135154ms step_avg:97.94ms +step:1381/1670 train_time:135250ms step_avg:97.94ms +step:1382/1670 train_time:135347ms step_avg:97.94ms +step:1383/1670 train_time:135444ms step_avg:97.94ms +step:1384/1670 train_time:135541ms step_avg:97.93ms +step:1385/1670 train_time:135638ms step_avg:97.93ms +step:1386/1670 train_time:135737ms step_avg:97.93ms +step:1387/1670 train_time:135840ms step_avg:97.94ms +step:1388/1670 train_time:135941ms step_avg:97.94ms +step:1389/1670 train_time:136042ms step_avg:97.94ms +step:1390/1670 train_time:136140ms step_avg:97.94ms +step:1391/1670 train_time:136238ms step_avg:97.94ms +step:1392/1670 train_time:136335ms step_avg:97.94ms +step:1393/1670 train_time:136432ms step_avg:97.94ms +step:1394/1670 train_time:136529ms step_avg:97.94ms +step:1395/1670 train_time:136625ms step_avg:97.94ms +step:1396/1670 train_time:136723ms step_avg:97.94ms +step:1397/1670 train_time:136822ms step_avg:97.94ms +step:1398/1670 train_time:136922ms step_avg:97.94ms +step:1399/1670 train_time:137023ms step_avg:97.94ms +step:1400/1670 train_time:137122ms step_avg:97.94ms +step:1401/1670 train_time:137222ms step_avg:97.95ms +step:1402/1670 train_time:137320ms step_avg:97.95ms +step:1403/1670 train_time:137418ms step_avg:97.95ms +step:1404/1670 train_time:137515ms step_avg:97.94ms +step:1405/1670 train_time:137611ms step_avg:97.94ms +step:1406/1670 train_time:137708ms step_avg:97.94ms +step:1407/1670 train_time:137806ms step_avg:97.94ms +step:1408/1670 train_time:137905ms step_avg:97.94ms +step:1409/1670 train_time:138005ms step_avg:97.95ms +step:1410/1670 train_time:138106ms step_avg:97.95ms +step:1411/1670 train_time:138204ms step_avg:97.95ms +step:1412/1670 train_time:138302ms step_avg:97.95ms +step:1413/1670 train_time:138400ms step_avg:97.95ms +step:1414/1670 train_time:138497ms step_avg:97.95ms +step:1415/1670 train_time:138594ms step_avg:97.95ms +step:1416/1670 train_time:138691ms step_avg:97.95ms +step:1417/1670 train_time:138789ms step_avg:97.95ms +step:1418/1670 train_time:138886ms step_avg:97.95ms +step:1419/1670 train_time:138985ms step_avg:97.95ms +step:1420/1670 train_time:139085ms step_avg:97.95ms +step:1421/1670 train_time:139183ms step_avg:97.95ms +step:1422/1670 train_time:139282ms step_avg:97.95ms +step:1423/1670 train_time:139380ms step_avg:97.95ms +step:1424/1670 train_time:139477ms step_avg:97.95ms +step:1425/1670 train_time:139575ms step_avg:97.95ms +step:1426/1670 train_time:139673ms step_avg:97.95ms +step:1427/1670 train_time:139771ms step_avg:97.95ms +step:1428/1670 train_time:139868ms step_avg:97.95ms +step:1429/1670 train_time:139966ms step_avg:97.95ms +step:1430/1670 train_time:140065ms step_avg:97.95ms +step:1431/1670 train_time:140164ms step_avg:97.95ms +step:1432/1670 train_time:140262ms step_avg:97.95ms +step:1433/1670 train_time:140360ms step_avg:97.95ms +step:1434/1670 train_time:140458ms step_avg:97.95ms +step:1435/1670 train_time:140555ms step_avg:97.95ms +step:1436/1670 train_time:140652ms step_avg:97.95ms +step:1437/1670 train_time:140749ms step_avg:97.95ms +step:1438/1670 train_time:140846ms step_avg:97.95ms +step:1439/1670 train_time:140944ms step_avg:97.95ms +step:1440/1670 
train_time:141042ms step_avg:97.95ms +step:1441/1670 train_time:141141ms step_avg:97.95ms +step:1442/1670 train_time:141240ms step_avg:97.95ms +step:1443/1670 train_time:141336ms step_avg:97.95ms +step:1444/1670 train_time:141434ms step_avg:97.95ms +step:1445/1670 train_time:141532ms step_avg:97.95ms +step:1446/1670 train_time:141630ms step_avg:97.95ms +step:1447/1670 train_time:141727ms step_avg:97.95ms +step:1448/1670 train_time:141824ms step_avg:97.94ms +step:1449/1670 train_time:141922ms step_avg:97.94ms +step:1450/1670 train_time:142020ms step_avg:97.94ms +step:1451/1670 train_time:142117ms step_avg:97.94ms +step:1452/1670 train_time:142215ms step_avg:97.94ms +step:1453/1670 train_time:142312ms step_avg:97.94ms +step:1454/1670 train_time:142410ms step_avg:97.94ms +step:1455/1670 train_time:142509ms step_avg:97.94ms +step:1456/1670 train_time:142607ms step_avg:97.94ms +step:1457/1670 train_time:142705ms step_avg:97.94ms +step:1458/1670 train_time:142803ms step_avg:97.94ms +step:1459/1670 train_time:142902ms step_avg:97.95ms +step:1460/1670 train_time:142999ms step_avg:97.94ms +step:1461/1670 train_time:143097ms step_avg:97.94ms +step:1462/1670 train_time:143195ms step_avg:97.94ms +step:1463/1670 train_time:143292ms step_avg:97.94ms +step:1464/1670 train_time:143390ms step_avg:97.94ms +step:1465/1670 train_time:143487ms step_avg:97.94ms +step:1466/1670 train_time:143585ms step_avg:97.94ms +step:1467/1670 train_time:143684ms step_avg:97.94ms +step:1468/1670 train_time:143783ms step_avg:97.94ms +step:1469/1670 train_time:143881ms step_avg:97.94ms +step:1470/1670 train_time:143978ms step_avg:97.94ms +step:1471/1670 train_time:144076ms step_avg:97.94ms +step:1472/1670 train_time:144174ms step_avg:97.94ms +step:1473/1670 train_time:144272ms step_avg:97.94ms +step:1474/1670 train_time:144370ms step_avg:97.94ms +step:1475/1670 train_time:144467ms step_avg:97.94ms +step:1476/1670 train_time:144564ms step_avg:97.94ms +step:1477/1670 train_time:144663ms step_avg:97.94ms +step:1478/1670 train_time:144761ms step_avg:97.94ms +step:1479/1670 train_time:144859ms step_avg:97.94ms +step:1480/1670 train_time:144957ms step_avg:97.94ms +step:1481/1670 train_time:145054ms step_avg:97.94ms +step:1482/1670 train_time:145152ms step_avg:97.94ms +step:1483/1670 train_time:145250ms step_avg:97.94ms +step:1484/1670 train_time:145348ms step_avg:97.94ms +step:1485/1670 train_time:145693ms step_avg:98.11ms +step:1486/1670 train_time:145768ms step_avg:98.09ms +step:1487/1670 train_time:145865ms step_avg:98.09ms +step:1488/1670 train_time:145961ms step_avg:98.09ms +step:1489/1670 train_time:146058ms step_avg:98.09ms +step:1490/1670 train_time:146154ms step_avg:98.09ms +step:1491/1670 train_time:146251ms step_avg:98.09ms +step:1492/1670 train_time:146347ms step_avg:98.09ms +step:1493/1670 train_time:146444ms step_avg:98.09ms +step:1494/1670 train_time:146540ms step_avg:98.09ms +step:1495/1670 train_time:146643ms step_avg:98.09ms +step:1496/1670 train_time:146746ms step_avg:98.09ms +step:1497/1670 train_time:146847ms step_avg:98.09ms +step:1498/1670 train_time:146945ms step_avg:98.09ms +step:1499/1670 train_time:147042ms step_avg:98.09ms +step:1500/1670 train_time:147140ms step_avg:98.09ms +step:1500/1670 val_loss:3.3126 train_time:147236ms step_avg:98.16ms +step:1501/1670 train_time:147258ms step_avg:98.11ms +step:1502/1670 train_time:147342ms step_avg:98.10ms +step:1503/1670 train_time:147443ms step_avg:98.10ms +step:1504/1670 train_time:147541ms step_avg:98.10ms +step:1505/1670 train_time:147638ms step_avg:98.10ms 
+step:1506/1670 train_time:147735ms step_avg:98.10ms +step:1507/1670 train_time:147832ms step_avg:98.10ms +step:1508/1670 train_time:147929ms step_avg:98.10ms +step:1509/1670 train_time:148027ms step_avg:98.10ms +step:1510/1670 train_time:148124ms step_avg:98.10ms +step:1511/1670 train_time:148223ms step_avg:98.10ms +step:1512/1670 train_time:148324ms step_avg:98.10ms +step:1513/1670 train_time:148422ms step_avg:98.10ms +step:1514/1670 train_time:148521ms step_avg:98.10ms +step:1515/1670 train_time:148618ms step_avg:98.10ms +step:1516/1670 train_time:148715ms step_avg:98.10ms +step:1517/1670 train_time:148811ms step_avg:98.10ms +step:1518/1670 train_time:148908ms step_avg:98.09ms +step:1519/1670 train_time:149006ms step_avg:98.09ms +step:1520/1670 train_time:149104ms step_avg:98.09ms +step:1521/1670 train_time:149202ms step_avg:98.09ms +step:1522/1670 train_time:149301ms step_avg:98.10ms +step:1523/1670 train_time:149399ms step_avg:98.10ms +step:1524/1670 train_time:149498ms step_avg:98.10ms +step:1525/1670 train_time:149596ms step_avg:98.10ms +step:1526/1670 train_time:149693ms step_avg:98.09ms +step:1527/1670 train_time:149790ms step_avg:98.09ms +step:1528/1670 train_time:149888ms step_avg:98.09ms +step:1529/1670 train_time:149985ms step_avg:98.09ms +step:1530/1670 train_time:150083ms step_avg:98.09ms +step:1531/1670 train_time:150181ms step_avg:98.09ms +step:1532/1670 train_time:150279ms step_avg:98.09ms +step:1533/1670 train_time:150378ms step_avg:98.09ms +step:1534/1670 train_time:150477ms step_avg:98.09ms +step:1535/1670 train_time:150575ms step_avg:98.09ms +step:1536/1670 train_time:150673ms step_avg:98.09ms +step:1537/1670 train_time:150770ms step_avg:98.09ms +step:1538/1670 train_time:150868ms step_avg:98.09ms +step:1539/1670 train_time:150966ms step_avg:98.09ms +step:1540/1670 train_time:151063ms step_avg:98.09ms +step:1541/1670 train_time:151160ms step_avg:98.09ms +step:1542/1670 train_time:151258ms step_avg:98.09ms +step:1543/1670 train_time:151356ms step_avg:98.09ms +step:1544/1670 train_time:151454ms step_avg:98.09ms +step:1545/1670 train_time:151552ms step_avg:98.09ms +step:1546/1670 train_time:151650ms step_avg:98.09ms +step:1547/1670 train_time:151748ms step_avg:98.09ms +step:1548/1670 train_time:151847ms step_avg:98.09ms +step:1549/1670 train_time:151945ms step_avg:98.09ms +step:1550/1670 train_time:152043ms step_avg:98.09ms +step:1551/1670 train_time:152141ms step_avg:98.09ms +step:1552/1670 train_time:152240ms step_avg:98.09ms +step:1553/1670 train_time:152339ms step_avg:98.09ms +step:1554/1670 train_time:152437ms step_avg:98.09ms +step:1555/1670 train_time:152535ms step_avg:98.09ms +step:1556/1670 train_time:152633ms step_avg:98.09ms +step:1557/1670 train_time:152731ms step_avg:98.09ms +step:1558/1670 train_time:152829ms step_avg:98.09ms +step:1559/1670 train_time:152926ms step_avg:98.09ms +step:1560/1670 train_time:153024ms step_avg:98.09ms +step:1561/1670 train_time:153122ms step_avg:98.09ms +step:1562/1670 train_time:153220ms step_avg:98.09ms +step:1563/1670 train_time:153318ms step_avg:98.09ms +step:1564/1670 train_time:153415ms step_avg:98.09ms +step:1565/1670 train_time:153513ms step_avg:98.09ms +step:1566/1670 train_time:153611ms step_avg:98.09ms +step:1567/1670 train_time:153709ms step_avg:98.09ms +step:1568/1670 train_time:153808ms step_avg:98.09ms +step:1569/1670 train_time:153905ms step_avg:98.09ms +step:1570/1670 train_time:154003ms step_avg:98.09ms +step:1571/1670 train_time:154100ms step_avg:98.09ms +step:1572/1670 train_time:154197ms step_avg:98.09ms 
+step:1573/1670 train_time:154295ms step_avg:98.09ms +step:1574/1670 train_time:154392ms step_avg:98.09ms +step:1575/1670 train_time:154491ms step_avg:98.09ms +step:1576/1670 train_time:154590ms step_avg:98.09ms +step:1577/1670 train_time:154688ms step_avg:98.09ms +step:1578/1670 train_time:154785ms step_avg:98.09ms +step:1579/1670 train_time:154883ms step_avg:98.09ms +step:1580/1670 train_time:154980ms step_avg:98.09ms +step:1581/1670 train_time:155078ms step_avg:98.09ms +step:1582/1670 train_time:155176ms step_avg:98.09ms +step:1583/1670 train_time:155274ms step_avg:98.09ms +step:1584/1670 train_time:155372ms step_avg:98.09ms +step:1585/1670 train_time:155470ms step_avg:98.09ms +step:1586/1670 train_time:155569ms step_avg:98.09ms +step:1587/1670 train_time:155668ms step_avg:98.09ms +step:1588/1670 train_time:155766ms step_avg:98.09ms +step:1589/1670 train_time:155863ms step_avg:98.09ms +step:1590/1670 train_time:155961ms step_avg:98.09ms +step:1591/1670 train_time:156059ms step_avg:98.09ms +step:1592/1670 train_time:156157ms step_avg:98.09ms +step:1593/1670 train_time:156254ms step_avg:98.09ms +step:1594/1670 train_time:156352ms step_avg:98.09ms +step:1595/1670 train_time:156450ms step_avg:98.09ms +step:1596/1670 train_time:156547ms step_avg:98.09ms +step:1597/1670 train_time:156645ms step_avg:98.09ms +step:1598/1670 train_time:156742ms step_avg:98.09ms +step:1599/1670 train_time:156839ms step_avg:98.09ms +step:1600/1670 train_time:156937ms step_avg:98.09ms +step:1601/1670 train_time:157035ms step_avg:98.09ms +step:1602/1670 train_time:157133ms step_avg:98.09ms +step:1603/1670 train_time:157230ms step_avg:98.08ms +step:1604/1670 train_time:157329ms step_avg:98.09ms +step:1605/1670 train_time:157427ms step_avg:98.09ms +step:1606/1670 train_time:157524ms step_avg:98.08ms +step:1607/1670 train_time:157622ms step_avg:98.08ms +step:1608/1670 train_time:157719ms step_avg:98.08ms +step:1609/1670 train_time:157817ms step_avg:98.08ms +step:1610/1670 train_time:157915ms step_avg:98.08ms +step:1611/1670 train_time:158013ms step_avg:98.08ms +step:1612/1670 train_time:158110ms step_avg:98.08ms +step:1613/1670 train_time:158208ms step_avg:98.08ms +step:1614/1670 train_time:158307ms step_avg:98.08ms +step:1615/1670 train_time:158405ms step_avg:98.08ms +step:1616/1670 train_time:158502ms step_avg:98.08ms +step:1617/1670 train_time:158599ms step_avg:98.08ms +step:1618/1670 train_time:158697ms step_avg:98.08ms +step:1619/1670 train_time:158795ms step_avg:98.08ms +step:1620/1670 train_time:158893ms step_avg:98.08ms +step:1621/1670 train_time:158993ms step_avg:98.08ms +step:1622/1670 train_time:159092ms step_avg:98.08ms +step:1623/1670 train_time:159189ms step_avg:98.08ms +step:1624/1670 train_time:159287ms step_avg:98.08ms +step:1625/1670 train_time:159386ms step_avg:98.08ms +step:1625/1670 val_loss:3.2856 train_time:159483ms step_avg:98.14ms +step:1626/1670 train_time:159507ms step_avg:98.10ms +step:1627/1670 train_time:159588ms step_avg:98.09ms +step:1628/1670 train_time:159688ms step_avg:98.09ms +step:1629/1670 train_time:159786ms step_avg:98.09ms +step:1630/1670 train_time:159883ms step_avg:98.09ms +step:1631/1670 train_time:159980ms step_avg:98.09ms +step:1632/1670 train_time:160077ms step_avg:98.09ms +step:1633/1670 train_time:160175ms step_avg:98.09ms +step:1634/1670 train_time:160272ms step_avg:98.09ms +step:1635/1670 train_time:160368ms step_avg:98.08ms +step:1636/1670 train_time:160467ms step_avg:98.08ms +step:1637/1670 train_time:160567ms step_avg:98.09ms +step:1638/1670 train_time:160668ms 
step_avg:98.09ms +step:1639/1670 train_time:160767ms step_avg:98.09ms +step:1640/1670 train_time:160864ms step_avg:98.09ms +step:1641/1670 train_time:160962ms step_avg:98.09ms +step:1642/1670 train_time:161058ms step_avg:98.09ms +step:1643/1670 train_time:161156ms step_avg:98.09ms +step:1644/1670 train_time:161253ms step_avg:98.09ms +step:1645/1670 train_time:161350ms step_avg:98.09ms +step:1646/1670 train_time:161449ms step_avg:98.09ms +step:1647/1670 train_time:161548ms step_avg:98.09ms +step:1648/1670 train_time:161647ms step_avg:98.09ms +step:1649/1670 train_time:161746ms step_avg:98.09ms +step:1650/1670 train_time:161843ms step_avg:98.09ms +step:1651/1670 train_time:161940ms step_avg:98.09ms +step:1652/1670 train_time:162038ms step_avg:98.09ms +step:1653/1670 train_time:162136ms step_avg:98.09ms +step:1654/1670 train_time:162234ms step_avg:98.09ms +step:1655/1670 train_time:162332ms step_avg:98.09ms +step:1656/1670 train_time:162429ms step_avg:98.09ms +step:1657/1670 train_time:162528ms step_avg:98.09ms +step:1658/1670 train_time:162626ms step_avg:98.09ms +step:1659/1670 train_time:162725ms step_avg:98.09ms +step:1660/1670 train_time:162822ms step_avg:98.09ms +step:1661/1670 train_time:162920ms step_avg:98.09ms +step:1662/1670 train_time:163018ms step_avg:98.09ms +step:1663/1670 train_time:163117ms step_avg:98.09ms +step:1664/1670 train_time:163215ms step_avg:98.09ms +step:1665/1670 train_time:163312ms step_avg:98.09ms +step:1666/1670 train_time:163410ms step_avg:98.08ms +step:1667/1670 train_time:163507ms step_avg:98.08ms +step:1668/1670 train_time:163605ms step_avg:98.08ms +step:1669/1670 train_time:163703ms step_avg:98.08ms +step:1670/1670 train_time:163800ms step_avg:98.08ms +step:1670/1670 val_loss:3.2778 train_time:163897ms step_avg:98.14ms +peak memory allocated: 34217 MiB reserved: 49676 MiB diff --git a/records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt b/records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt new file mode 100644 index 000000000..f800a7966 --- /dev/null +++ b/records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + 
scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, 
a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we 
use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral 
norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
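+ # Sharding sketch (illustrative, with the 8-GPU default of this script): the
+ # MLP group holds 24 equally-shaped matrices (2 per layer x 12 layers), so the
+ # round-robin below gives rank r params r, r+8 and r+16, e.g. rank 3 ->
+ # params[3], params[11], params[19]. Gradients for owned params arrive via
+ # async reduce_scatter, each rank momentum-updates and runs Newton-Schulz on
+ # only its own matrices, and the refreshed weights are re-replicated with
+ # async all_gather, overlapping communication with orthogonalization.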
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
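# Per-rank Adam state sketch (sizes illustrative): moment buffers are allocated
+ # for a 1/world_size row-slice of each param, e.g. the (50304, 768) lm_head
+ # with 8 ranks keeps (6288, 768) exp_avg / exp_avg_sq buffers per rank, so
+ # optimizer-state memory scales down by world_size; this relies on p.shape[0]
+ # being divisible by world_size, as the slicing below assumes. +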
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:35:35 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 35C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 47396 C /usr/bin/python 0MiB | +| 0 N/A N/A 47397 C /usr/bin/python 0MiB | +| 0 N/A N/A 47398 C /usr/bin/python 0MiB | +| 0 N/A N/A 47399 C /usr/bin/python 0MiB | +| 0 N/A N/A 47400 C /usr/bin/python 0MiB | +| 0 N/A N/A 47401 C /usr/bin/python 0MiB | +| 0 N/A N/A 47402 C /usr/bin/python 0MiB | +| 0 N/A N/A 47403 C /usr/bin/python 0MiB | +| 1 N/A N/A 47397 C /usr/bin/python 0MiB | +| 2 N/A N/A 47398 C /usr/bin/python 0MiB | +| 3 N/A N/A 47399 C /usr/bin/python 0MiB | +| 4 N/A N/A 47400 C /usr/bin/python 0MiB | +| 5 N/A N/A 47401 C /usr/bin/python 0MiB | +| 6 N/A N/A 47402 C /usr/bin/python 0MiB | +| 7 N/A N/A 47403 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:473ms step_avg:472.59ms +step:2/1670 train_time:494ms step_avg:247.17ms +step:3/1670 train_time:566ms step_avg:188.74ms +step:4/1670 train_time:659ms step_avg:164.86ms +step:5/1670 train_time:754ms step_avg:150.77ms +step:6/1670 train_time:848ms step_avg:141.33ms +step:7/1670 train_time:943ms step_avg:134.68ms +step:8/1670 train_time:1038ms step_avg:129.80ms +step:9/1670 train_time:1133ms 
step_avg:125.94ms +step:10/1670 train_time:1228ms step_avg:122.75ms +step:11/1670 train_time:1322ms step_avg:120.19ms +step:12/1670 train_time:1421ms step_avg:118.45ms +step:13/1670 train_time:1523ms step_avg:117.16ms +step:14/1670 train_time:1621ms step_avg:115.75ms +step:15/1670 train_time:1716ms step_avg:114.40ms +step:16/1670 train_time:1811ms step_avg:113.19ms +step:17/1670 train_time:1906ms step_avg:112.12ms +step:18/1670 train_time:2001ms step_avg:111.15ms +step:19/1670 train_time:2096ms step_avg:110.34ms +step:20/1670 train_time:2191ms step_avg:109.57ms +step:21/1670 train_time:2287ms step_avg:108.89ms +step:22/1670 train_time:2383ms step_avg:108.31ms +step:23/1670 train_time:2479ms step_avg:107.79ms +step:24/1670 train_time:2576ms step_avg:107.35ms +step:25/1670 train_time:2673ms step_avg:106.92ms +step:26/1670 train_time:2768ms step_avg:106.47ms +step:27/1670 train_time:2863ms step_avg:106.06ms +step:28/1670 train_time:2959ms step_avg:105.70ms +step:29/1670 train_time:3055ms step_avg:105.34ms +step:30/1670 train_time:3151ms step_avg:105.03ms +step:31/1670 train_time:3247ms step_avg:104.73ms +step:32/1670 train_time:3342ms step_avg:104.43ms +step:33/1670 train_time:3438ms step_avg:104.17ms +step:34/1670 train_time:3534ms step_avg:103.94ms +step:35/1670 train_time:3630ms step_avg:103.72ms +step:36/1670 train_time:3727ms step_avg:103.52ms +step:37/1670 train_time:3823ms step_avg:103.31ms +step:38/1670 train_time:3918ms step_avg:103.11ms +step:39/1670 train_time:4014ms step_avg:102.91ms +step:40/1670 train_time:4109ms step_avg:102.73ms +step:41/1670 train_time:4204ms step_avg:102.55ms +step:42/1670 train_time:4300ms step_avg:102.38ms +step:43/1670 train_time:4396ms step_avg:102.23ms +step:44/1670 train_time:4492ms step_avg:102.09ms +step:45/1670 train_time:4587ms step_avg:101.94ms +step:46/1670 train_time:4684ms step_avg:101.83ms +step:47/1670 train_time:4780ms step_avg:101.71ms +step:48/1670 train_time:4877ms step_avg:101.59ms +step:49/1670 train_time:4973ms step_avg:101.48ms +step:50/1670 train_time:5068ms step_avg:101.36ms +step:51/1670 train_time:5164ms step_avg:101.25ms +step:52/1670 train_time:5259ms step_avg:101.13ms +step:53/1670 train_time:5355ms step_avg:101.03ms +step:54/1670 train_time:5451ms step_avg:100.94ms +step:55/1670 train_time:5547ms step_avg:100.85ms +step:56/1670 train_time:5643ms step_avg:100.77ms +step:57/1670 train_time:5739ms step_avg:100.68ms +step:58/1670 train_time:5834ms step_avg:100.59ms +step:59/1670 train_time:5930ms step_avg:100.51ms +step:60/1670 train_time:6026ms step_avg:100.43ms +step:61/1670 train_time:6122ms step_avg:100.36ms +step:62/1670 train_time:6218ms step_avg:100.29ms +step:63/1670 train_time:6313ms step_avg:100.21ms +step:64/1670 train_time:6408ms step_avg:100.13ms +step:65/1670 train_time:6504ms step_avg:100.05ms +step:66/1670 train_time:6600ms step_avg:100.00ms +step:67/1670 train_time:6695ms step_avg:99.93ms +step:68/1670 train_time:6791ms step_avg:99.87ms +step:69/1670 train_time:6887ms step_avg:99.81ms +step:70/1670 train_time:6983ms step_avg:99.76ms +step:71/1670 train_time:7079ms step_avg:99.70ms +step:72/1670 train_time:7176ms step_avg:99.66ms +step:73/1670 train_time:7272ms step_avg:99.62ms +step:74/1670 train_time:7367ms step_avg:99.56ms +step:75/1670 train_time:7463ms step_avg:99.51ms +step:76/1670 train_time:7559ms step_avg:99.46ms +step:77/1670 train_time:7655ms step_avg:99.42ms +step:78/1670 train_time:7752ms step_avg:99.38ms +step:79/1670 train_time:7847ms step_avg:99.32ms +step:80/1670 train_time:7943ms step_avg:99.28ms 
+step:81/1670 train_time:8038ms step_avg:99.24ms +step:82/1670 train_time:8135ms step_avg:99.21ms +step:83/1670 train_time:8230ms step_avg:99.16ms +step:84/1670 train_time:8325ms step_avg:99.11ms +step:85/1670 train_time:8421ms step_avg:99.07ms +step:86/1670 train_time:8517ms step_avg:99.04ms +step:87/1670 train_time:8613ms step_avg:99.00ms +step:88/1670 train_time:8709ms step_avg:98.97ms +step:89/1670 train_time:8804ms step_avg:98.93ms +step:90/1670 train_time:8900ms step_avg:98.89ms +step:91/1670 train_time:8995ms step_avg:98.85ms +step:92/1670 train_time:9091ms step_avg:98.81ms +step:93/1670 train_time:9186ms step_avg:98.78ms +step:94/1670 train_time:9282ms step_avg:98.74ms +step:95/1670 train_time:9379ms step_avg:98.72ms +step:96/1670 train_time:9476ms step_avg:98.70ms +step:97/1670 train_time:9572ms step_avg:98.68ms +step:98/1670 train_time:9667ms step_avg:98.65ms +step:99/1670 train_time:9763ms step_avg:98.62ms +step:100/1670 train_time:9859ms step_avg:98.59ms +step:101/1670 train_time:9955ms step_avg:98.57ms +step:102/1670 train_time:10051ms step_avg:98.54ms +step:103/1670 train_time:10146ms step_avg:98.51ms +step:104/1670 train_time:10241ms step_avg:98.48ms +step:105/1670 train_time:10337ms step_avg:98.45ms +step:106/1670 train_time:10433ms step_avg:98.43ms +step:107/1670 train_time:10529ms step_avg:98.40ms +step:108/1670 train_time:10624ms step_avg:98.37ms +step:109/1670 train_time:10720ms step_avg:98.35ms +step:110/1670 train_time:10816ms step_avg:98.33ms +step:111/1670 train_time:10912ms step_avg:98.31ms +step:112/1670 train_time:11009ms step_avg:98.29ms +step:113/1670 train_time:11104ms step_avg:98.27ms +step:114/1670 train_time:11199ms step_avg:98.24ms +step:115/1670 train_time:11295ms step_avg:98.21ms +step:116/1670 train_time:11390ms step_avg:98.19ms +step:117/1670 train_time:11485ms step_avg:98.17ms +step:118/1670 train_time:11581ms step_avg:98.14ms +step:119/1670 train_time:11677ms step_avg:98.12ms +step:120/1670 train_time:11773ms step_avg:98.10ms +step:121/1670 train_time:11868ms step_avg:98.08ms +step:122/1670 train_time:11964ms step_avg:98.06ms +step:123/1670 train_time:12060ms step_avg:98.05ms +step:124/1670 train_time:12156ms step_avg:98.03ms +step:125/1670 train_time:12251ms step_avg:98.01ms +step:125/1670 val_loss:4.3007 train_time:12346ms step_avg:98.77ms +step:126/1670 train_time:12367ms step_avg:98.15ms +step:127/1670 train_time:12446ms step_avg:98.00ms +step:128/1670 train_time:12552ms step_avg:98.06ms +step:129/1670 train_time:12649ms step_avg:98.06ms +step:130/1670 train_time:12745ms step_avg:98.04ms +step:131/1670 train_time:12840ms step_avg:98.01ms +step:132/1670 train_time:12934ms step_avg:97.99ms +step:133/1670 train_time:13029ms step_avg:97.96ms +step:134/1670 train_time:13123ms step_avg:97.94ms +step:135/1670 train_time:13219ms step_avg:97.92ms +step:136/1670 train_time:13313ms step_avg:97.89ms +step:137/1670 train_time:13409ms step_avg:97.88ms +step:138/1670 train_time:13507ms step_avg:97.88ms +step:139/1670 train_time:13605ms step_avg:97.88ms +step:140/1670 train_time:13703ms step_avg:97.88ms +step:141/1670 train_time:13797ms step_avg:97.85ms +step:142/1670 train_time:13892ms step_avg:97.83ms +step:143/1670 train_time:13988ms step_avg:97.82ms +step:144/1670 train_time:14082ms step_avg:97.79ms +step:145/1670 train_time:14177ms step_avg:97.77ms +step:146/1670 train_time:14272ms step_avg:97.75ms +step:147/1670 train_time:14367ms step_avg:97.74ms +step:148/1670 train_time:14463ms step_avg:97.73ms +step:149/1670 train_time:14560ms step_avg:97.72ms 
+step:150/1670 train_time:14657ms step_avg:97.71ms +step:151/1670 train_time:14752ms step_avg:97.70ms +step:152/1670 train_time:14848ms step_avg:97.68ms +step:153/1670 train_time:14943ms step_avg:97.67ms +step:154/1670 train_time:15038ms step_avg:97.65ms +step:155/1670 train_time:15133ms step_avg:97.63ms +step:156/1670 train_time:15228ms step_avg:97.61ms +step:157/1670 train_time:15323ms step_avg:97.60ms +step:158/1670 train_time:15419ms step_avg:97.59ms +step:159/1670 train_time:15515ms step_avg:97.58ms +step:160/1670 train_time:15611ms step_avg:97.57ms +step:161/1670 train_time:15707ms step_avg:97.56ms +step:162/1670 train_time:15803ms step_avg:97.55ms +step:163/1670 train_time:15899ms step_avg:97.54ms +step:164/1670 train_time:15995ms step_avg:97.53ms +step:165/1670 train_time:16090ms step_avg:97.51ms +step:166/1670 train_time:16184ms step_avg:97.50ms +step:167/1670 train_time:16279ms step_avg:97.48ms +step:168/1670 train_time:16374ms step_avg:97.47ms +step:169/1670 train_time:16470ms step_avg:97.46ms +step:170/1670 train_time:16567ms step_avg:97.45ms +step:171/1670 train_time:16663ms step_avg:97.45ms +step:172/1670 train_time:16759ms step_avg:97.44ms +step:173/1670 train_time:16855ms step_avg:97.43ms +step:174/1670 train_time:16950ms step_avg:97.41ms +step:175/1670 train_time:17046ms step_avg:97.40ms +step:176/1670 train_time:17141ms step_avg:97.39ms +step:177/1670 train_time:17237ms step_avg:97.38ms +step:178/1670 train_time:17331ms step_avg:97.37ms +step:179/1670 train_time:17427ms step_avg:97.36ms +step:180/1670 train_time:17522ms step_avg:97.35ms +step:181/1670 train_time:17618ms step_avg:97.34ms +step:182/1670 train_time:17714ms step_avg:97.33ms +step:183/1670 train_time:17809ms step_avg:97.32ms +step:184/1670 train_time:17906ms step_avg:97.31ms +step:185/1670 train_time:18001ms step_avg:97.30ms +step:186/1670 train_time:18097ms step_avg:97.29ms +step:187/1670 train_time:18192ms step_avg:97.28ms +step:188/1670 train_time:18287ms step_avg:97.27ms +step:189/1670 train_time:18382ms step_avg:97.26ms +step:190/1670 train_time:18478ms step_avg:97.25ms +step:191/1670 train_time:18574ms step_avg:97.25ms +step:192/1670 train_time:18669ms step_avg:97.23ms +step:193/1670 train_time:18766ms step_avg:97.23ms +step:194/1670 train_time:18862ms step_avg:97.23ms +step:195/1670 train_time:18957ms step_avg:97.22ms +step:196/1670 train_time:19053ms step_avg:97.21ms +step:197/1670 train_time:19148ms step_avg:97.20ms +step:198/1670 train_time:19244ms step_avg:97.19ms +step:199/1670 train_time:19340ms step_avg:97.18ms +step:200/1670 train_time:19435ms step_avg:97.17ms +step:201/1670 train_time:19530ms step_avg:97.17ms +step:202/1670 train_time:19625ms step_avg:97.15ms +step:203/1670 train_time:19721ms step_avg:97.15ms +step:204/1670 train_time:19817ms step_avg:97.14ms +step:205/1670 train_time:19913ms step_avg:97.14ms +step:206/1670 train_time:20007ms step_avg:97.12ms +step:207/1670 train_time:20103ms step_avg:97.12ms +step:208/1670 train_time:20199ms step_avg:97.11ms +step:209/1670 train_time:20294ms step_avg:97.10ms +step:210/1670 train_time:20389ms step_avg:97.09ms +step:211/1670 train_time:20485ms step_avg:97.09ms +step:212/1670 train_time:20579ms step_avg:97.07ms +step:213/1670 train_time:20860ms step_avg:97.93ms +step:214/1670 train_time:20955ms step_avg:97.92ms +step:215/1670 train_time:21048ms step_avg:97.90ms +step:216/1670 train_time:21143ms step_avg:97.88ms +step:217/1670 train_time:21238ms step_avg:97.87ms +step:218/1670 train_time:21333ms step_avg:97.86ms +step:219/1670 train_time:21427ms 
step_avg:97.84ms +step:220/1670 train_time:21522ms step_avg:97.83ms +step:221/1670 train_time:21617ms step_avg:97.81ms +step:222/1670 train_time:21711ms step_avg:97.80ms +step:223/1670 train_time:21808ms step_avg:97.79ms +step:224/1670 train_time:21908ms step_avg:97.80ms +step:225/1670 train_time:22005ms step_avg:97.80ms +step:226/1670 train_time:22101ms step_avg:97.79ms +step:227/1670 train_time:22196ms step_avg:97.78ms +step:228/1670 train_time:22291ms step_avg:97.77ms +step:229/1670 train_time:22385ms step_avg:97.75ms +step:230/1670 train_time:22481ms step_avg:97.74ms +step:231/1670 train_time:22575ms step_avg:97.73ms +step:232/1670 train_time:22671ms step_avg:97.72ms +step:233/1670 train_time:22766ms step_avg:97.71ms +step:234/1670 train_time:22863ms step_avg:97.71ms +step:235/1670 train_time:22961ms step_avg:97.71ms +step:236/1670 train_time:23058ms step_avg:97.70ms +step:237/1670 train_time:23154ms step_avg:97.70ms +step:238/1670 train_time:23250ms step_avg:97.69ms +step:239/1670 train_time:23345ms step_avg:97.68ms +step:240/1670 train_time:23440ms step_avg:97.67ms +step:241/1670 train_time:23535ms step_avg:97.65ms +step:242/1670 train_time:23630ms step_avg:97.64ms +step:243/1670 train_time:23725ms step_avg:97.63ms +step:244/1670 train_time:23821ms step_avg:97.63ms +step:245/1670 train_time:23918ms step_avg:97.62ms +step:246/1670 train_time:24014ms step_avg:97.62ms +step:247/1670 train_time:24109ms step_avg:97.61ms +step:248/1670 train_time:24205ms step_avg:97.60ms +step:249/1670 train_time:24301ms step_avg:97.59ms +step:250/1670 train_time:24396ms step_avg:97.58ms +step:250/1670 val_loss:3.9790 train_time:24490ms step_avg:97.96ms +step:251/1670 train_time:24515ms step_avg:97.67ms +step:252/1670 train_time:24593ms step_avg:97.59ms +step:253/1670 train_time:24693ms step_avg:97.60ms +step:254/1670 train_time:24789ms step_avg:97.60ms +step:255/1670 train_time:24884ms step_avg:97.59ms +step:256/1670 train_time:24979ms step_avg:97.57ms +step:257/1670 train_time:25074ms step_avg:97.56ms +step:258/1670 train_time:25168ms step_avg:97.55ms +step:259/1670 train_time:25263ms step_avg:97.54ms +step:260/1670 train_time:25358ms step_avg:97.53ms +step:261/1670 train_time:25454ms step_avg:97.52ms +step:262/1670 train_time:25552ms step_avg:97.53ms +step:263/1670 train_time:25650ms step_avg:97.53ms +step:264/1670 train_time:25747ms step_avg:97.53ms +step:265/1670 train_time:25842ms step_avg:97.52ms +step:266/1670 train_time:25938ms step_avg:97.51ms +step:267/1670 train_time:26032ms step_avg:97.50ms +step:268/1670 train_time:26127ms step_avg:97.49ms +step:269/1670 train_time:26222ms step_avg:97.48ms +step:270/1670 train_time:26316ms step_avg:97.47ms +step:271/1670 train_time:26412ms step_avg:97.46ms +step:272/1670 train_time:26510ms step_avg:97.46ms +step:273/1670 train_time:26605ms step_avg:97.45ms +step:274/1670 train_time:26701ms step_avg:97.45ms +step:275/1670 train_time:26798ms step_avg:97.45ms +step:276/1670 train_time:26894ms step_avg:97.44ms +step:277/1670 train_time:26990ms step_avg:97.44ms +step:278/1670 train_time:27084ms step_avg:97.42ms +step:279/1670 train_time:27180ms step_avg:97.42ms +step:280/1670 train_time:27275ms step_avg:97.41ms +step:281/1670 train_time:27370ms step_avg:97.40ms +step:282/1670 train_time:27465ms step_avg:97.39ms +step:283/1670 train_time:27561ms step_avg:97.39ms +step:284/1670 train_time:27658ms step_avg:97.39ms +step:285/1670 train_time:27755ms step_avg:97.39ms +step:286/1670 train_time:27853ms step_avg:97.39ms +step:287/1670 train_time:27949ms step_avg:97.38ms 
+step:288/1670 train_time:28044ms step_avg:97.37ms +step:289/1670 train_time:28139ms step_avg:97.37ms +step:290/1670 train_time:28234ms step_avg:97.36ms +step:291/1670 train_time:28329ms step_avg:97.35ms +step:292/1670 train_time:28425ms step_avg:97.34ms +step:293/1670 train_time:28520ms step_avg:97.34ms +step:294/1670 train_time:28616ms step_avg:97.33ms +step:295/1670 train_time:28713ms step_avg:97.33ms +step:296/1670 train_time:28809ms step_avg:97.33ms +step:297/1670 train_time:28904ms step_avg:97.32ms +step:298/1670 train_time:29000ms step_avg:97.32ms +step:299/1670 train_time:29097ms step_avg:97.31ms +step:300/1670 train_time:29192ms step_avg:97.31ms +step:301/1670 train_time:29287ms step_avg:97.30ms +step:302/1670 train_time:29383ms step_avg:97.29ms +step:303/1670 train_time:29478ms step_avg:97.29ms +step:304/1670 train_time:29574ms step_avg:97.28ms +step:305/1670 train_time:29671ms step_avg:97.28ms +step:306/1670 train_time:29767ms step_avg:97.28ms +step:307/1670 train_time:29863ms step_avg:97.27ms +step:308/1670 train_time:29959ms step_avg:97.27ms +step:309/1670 train_time:30056ms step_avg:97.27ms +step:310/1670 train_time:30152ms step_avg:97.27ms +step:311/1670 train_time:30248ms step_avg:97.26ms +step:312/1670 train_time:30343ms step_avg:97.25ms +step:313/1670 train_time:30437ms step_avg:97.24ms +step:314/1670 train_time:30532ms step_avg:97.24ms +step:315/1670 train_time:30628ms step_avg:97.23ms +step:316/1670 train_time:30724ms step_avg:97.23ms +step:317/1670 train_time:30819ms step_avg:97.22ms +step:318/1670 train_time:30916ms step_avg:97.22ms +step:319/1670 train_time:31013ms step_avg:97.22ms +step:320/1670 train_time:31109ms step_avg:97.22ms +step:321/1670 train_time:31204ms step_avg:97.21ms +step:322/1670 train_time:31299ms step_avg:97.20ms +step:323/1670 train_time:31395ms step_avg:97.20ms +step:324/1670 train_time:31491ms step_avg:97.19ms +step:325/1670 train_time:31586ms step_avg:97.19ms +step:326/1670 train_time:31681ms step_avg:97.18ms +step:327/1670 train_time:31777ms step_avg:97.18ms +step:328/1670 train_time:31872ms step_avg:97.17ms +step:329/1670 train_time:31968ms step_avg:97.17ms +step:330/1670 train_time:32063ms step_avg:97.16ms +step:331/1670 train_time:32160ms step_avg:97.16ms +step:332/1670 train_time:32255ms step_avg:97.16ms +step:333/1670 train_time:32351ms step_avg:97.15ms +step:334/1670 train_time:32447ms step_avg:97.15ms +step:335/1670 train_time:32542ms step_avg:97.14ms +step:336/1670 train_time:32638ms step_avg:97.14ms +step:337/1670 train_time:32733ms step_avg:97.13ms +step:338/1670 train_time:32829ms step_avg:97.13ms +step:339/1670 train_time:32925ms step_avg:97.12ms +step:340/1670 train_time:33020ms step_avg:97.12ms +step:341/1670 train_time:33116ms step_avg:97.11ms +step:342/1670 train_time:33212ms step_avg:97.11ms +step:343/1670 train_time:33308ms step_avg:97.11ms +step:344/1670 train_time:33402ms step_avg:97.10ms +step:345/1670 train_time:33497ms step_avg:97.09ms +step:346/1670 train_time:33593ms step_avg:97.09ms +step:347/1670 train_time:33689ms step_avg:97.09ms +step:348/1670 train_time:33784ms step_avg:97.08ms +step:349/1670 train_time:33880ms step_avg:97.08ms +step:350/1670 train_time:33976ms step_avg:97.07ms +step:351/1670 train_time:34072ms step_avg:97.07ms +step:352/1670 train_time:34167ms step_avg:97.07ms +step:353/1670 train_time:34263ms step_avg:97.06ms +step:354/1670 train_time:34358ms step_avg:97.06ms +step:355/1670 train_time:34454ms step_avg:97.05ms +step:356/1670 train_time:34550ms step_avg:97.05ms +step:357/1670 train_time:34645ms 
step_avg:97.05ms +step:358/1670 train_time:34741ms step_avg:97.04ms +step:359/1670 train_time:34836ms step_avg:97.04ms +step:360/1670 train_time:34932ms step_avg:97.03ms +step:361/1670 train_time:35029ms step_avg:97.03ms +step:362/1670 train_time:35124ms step_avg:97.03ms +step:363/1670 train_time:35220ms step_avg:97.02ms +step:364/1670 train_time:35316ms step_avg:97.02ms +step:365/1670 train_time:35412ms step_avg:97.02ms +step:366/1670 train_time:35507ms step_avg:97.01ms +step:367/1670 train_time:35602ms step_avg:97.01ms +step:368/1670 train_time:35697ms step_avg:97.00ms +step:369/1670 train_time:35794ms step_avg:97.00ms +step:370/1670 train_time:35889ms step_avg:97.00ms +step:371/1670 train_time:35984ms step_avg:96.99ms +step:372/1670 train_time:36080ms step_avg:96.99ms +step:373/1670 train_time:36177ms step_avg:96.99ms +step:374/1670 train_time:36273ms step_avg:96.99ms +step:375/1670 train_time:36369ms step_avg:96.98ms +step:375/1670 val_loss:3.8252 train_time:36465ms step_avg:97.24ms +step:376/1670 train_time:36489ms step_avg:97.04ms +step:377/1670 train_time:36567ms step_avg:96.99ms +step:378/1670 train_time:36663ms step_avg:96.99ms +step:379/1670 train_time:36759ms step_avg:96.99ms +step:380/1670 train_time:36854ms step_avg:96.98ms +step:381/1670 train_time:36949ms step_avg:96.98ms +step:382/1670 train_time:37043ms step_avg:96.97ms +step:383/1670 train_time:37138ms step_avg:96.97ms +step:384/1670 train_time:37233ms step_avg:96.96ms +step:385/1670 train_time:37328ms step_avg:96.95ms +step:386/1670 train_time:37425ms step_avg:96.95ms +step:387/1670 train_time:37522ms step_avg:96.96ms +step:388/1670 train_time:37620ms step_avg:96.96ms +step:389/1670 train_time:37717ms step_avg:96.96ms +step:390/1670 train_time:37812ms step_avg:96.95ms +step:391/1670 train_time:37907ms step_avg:96.95ms +step:392/1670 train_time:38002ms step_avg:96.94ms +step:393/1670 train_time:38098ms step_avg:96.94ms +step:394/1670 train_time:38193ms step_avg:96.94ms +step:395/1670 train_time:38287ms step_avg:96.93ms +step:396/1670 train_time:38382ms step_avg:96.92ms +step:397/1670 train_time:38479ms step_avg:96.92ms +step:398/1670 train_time:38576ms step_avg:96.93ms +step:399/1670 train_time:38673ms step_avg:96.92ms +step:400/1670 train_time:38768ms step_avg:96.92ms +step:401/1670 train_time:38864ms step_avg:96.92ms +step:402/1670 train_time:38959ms step_avg:96.91ms +step:403/1670 train_time:39054ms step_avg:96.91ms +step:404/1670 train_time:39150ms step_avg:96.90ms +step:405/1670 train_time:39245ms step_avg:96.90ms +step:406/1670 train_time:39340ms step_avg:96.90ms +step:407/1670 train_time:39436ms step_avg:96.89ms +step:408/1670 train_time:39533ms step_avg:96.89ms +step:409/1670 train_time:39629ms step_avg:96.89ms +step:410/1670 train_time:39725ms step_avg:96.89ms +step:411/1670 train_time:39820ms step_avg:96.89ms +step:412/1670 train_time:39917ms step_avg:96.89ms +step:413/1670 train_time:40013ms step_avg:96.88ms +step:414/1670 train_time:40108ms step_avg:96.88ms +step:415/1670 train_time:40203ms step_avg:96.87ms +step:416/1670 train_time:40299ms step_avg:96.87ms +step:417/1670 train_time:40395ms step_avg:96.87ms +step:418/1670 train_time:40491ms step_avg:96.87ms +step:419/1670 train_time:40587ms step_avg:96.87ms +step:420/1670 train_time:40682ms step_avg:96.86ms +step:421/1670 train_time:40777ms step_avg:96.86ms +step:422/1670 train_time:40873ms step_avg:96.86ms +step:423/1670 train_time:40969ms step_avg:96.85ms +step:424/1670 train_time:41065ms step_avg:96.85ms +step:425/1670 train_time:41353ms step_avg:97.30ms 
+step:426/1670 train_time:41446ms step_avg:97.29ms +step:427/1670 train_time:41540ms step_avg:97.28ms +step:428/1670 train_time:41634ms step_avg:97.28ms +step:429/1670 train_time:41729ms step_avg:97.27ms +step:430/1670 train_time:41824ms step_avg:97.26ms +step:431/1670 train_time:41918ms step_avg:97.26ms +step:432/1670 train_time:42013ms step_avg:97.25ms +step:433/1670 train_time:42108ms step_avg:97.25ms +step:434/1670 train_time:42203ms step_avg:97.24ms +step:435/1670 train_time:42304ms step_avg:97.25ms +step:436/1670 train_time:42401ms step_avg:97.25ms +step:437/1670 train_time:42498ms step_avg:97.25ms +step:438/1670 train_time:42594ms step_avg:97.25ms +step:439/1670 train_time:42689ms step_avg:97.24ms +step:440/1670 train_time:42784ms step_avg:97.24ms +step:441/1670 train_time:42879ms step_avg:97.23ms +step:442/1670 train_time:42974ms step_avg:97.23ms +step:443/1670 train_time:43068ms step_avg:97.22ms +step:444/1670 train_time:43163ms step_avg:97.21ms +step:445/1670 train_time:43260ms step_avg:97.21ms +step:446/1670 train_time:43357ms step_avg:97.21ms +step:447/1670 train_time:43453ms step_avg:97.21ms +step:448/1670 train_time:43549ms step_avg:97.21ms +step:449/1670 train_time:43644ms step_avg:97.20ms +step:450/1670 train_time:43740ms step_avg:97.20ms +step:451/1670 train_time:43835ms step_avg:97.19ms +step:452/1670 train_time:43930ms step_avg:97.19ms +step:453/1670 train_time:44024ms step_avg:97.18ms +step:454/1670 train_time:44120ms step_avg:97.18ms +step:455/1670 train_time:44216ms step_avg:97.18ms +step:456/1670 train_time:44313ms step_avg:97.18ms +step:457/1670 train_time:44409ms step_avg:97.18ms +step:458/1670 train_time:44506ms step_avg:97.17ms +step:459/1670 train_time:44601ms step_avg:97.17ms +step:460/1670 train_time:44697ms step_avg:97.17ms +step:461/1670 train_time:44793ms step_avg:97.16ms +step:462/1670 train_time:44887ms step_avg:97.16ms +step:463/1670 train_time:44982ms step_avg:97.15ms +step:464/1670 train_time:45078ms step_avg:97.15ms +step:465/1670 train_time:45174ms step_avg:97.15ms +step:466/1670 train_time:45270ms step_avg:97.15ms +step:467/1670 train_time:45366ms step_avg:97.14ms +step:468/1670 train_time:45461ms step_avg:97.14ms +step:469/1670 train_time:45557ms step_avg:97.14ms +step:470/1670 train_time:45653ms step_avg:97.13ms +step:471/1670 train_time:45749ms step_avg:97.13ms +step:472/1670 train_time:45845ms step_avg:97.13ms +step:473/1670 train_time:45940ms step_avg:97.12ms +step:474/1670 train_time:46035ms step_avg:97.12ms +step:475/1670 train_time:46130ms step_avg:97.12ms +step:476/1670 train_time:46226ms step_avg:97.11ms +step:477/1670 train_time:46322ms step_avg:97.11ms +step:478/1670 train_time:46418ms step_avg:97.11ms +step:479/1670 train_time:46515ms step_avg:97.11ms +step:480/1670 train_time:46611ms step_avg:97.11ms +step:481/1670 train_time:46706ms step_avg:97.10ms +step:482/1670 train_time:46801ms step_avg:97.10ms +step:483/1670 train_time:46897ms step_avg:97.10ms +step:484/1670 train_time:46993ms step_avg:97.09ms +step:485/1670 train_time:47087ms step_avg:97.09ms +step:486/1670 train_time:47183ms step_avg:97.08ms +step:487/1670 train_time:47278ms step_avg:97.08ms +step:488/1670 train_time:47375ms step_avg:97.08ms +step:489/1670 train_time:47471ms step_avg:97.08ms +step:490/1670 train_time:47567ms step_avg:97.08ms +step:491/1670 train_time:47662ms step_avg:97.07ms +step:492/1670 train_time:47758ms step_avg:97.07ms +step:493/1670 train_time:47854ms step_avg:97.07ms +step:494/1670 train_time:47950ms step_avg:97.06ms +step:495/1670 train_time:48045ms 
step_avg:97.06ms +step:496/1670 train_time:48140ms step_avg:97.06ms +step:497/1670 train_time:48236ms step_avg:97.05ms +step:498/1670 train_time:48332ms step_avg:97.05ms +step:499/1670 train_time:48428ms step_avg:97.05ms +step:500/1670 train_time:48523ms step_avg:97.05ms +step:500/1670 val_loss:3.7242 train_time:48619ms step_avg:97.24ms +step:501/1670 train_time:48643ms step_avg:97.09ms +step:502/1670 train_time:48722ms step_avg:97.06ms +step:503/1670 train_time:48821ms step_avg:97.06ms +step:504/1670 train_time:48916ms step_avg:97.06ms +step:505/1670 train_time:49011ms step_avg:97.05ms +step:506/1670 train_time:49106ms step_avg:97.05ms +step:507/1670 train_time:49201ms step_avg:97.04ms +step:508/1670 train_time:49295ms step_avg:97.04ms +step:509/1670 train_time:49390ms step_avg:97.03ms +step:510/1670 train_time:49485ms step_avg:97.03ms +step:511/1670 train_time:49581ms step_avg:97.03ms +step:512/1670 train_time:49678ms step_avg:97.03ms +step:513/1670 train_time:49775ms step_avg:97.03ms +step:514/1670 train_time:49873ms step_avg:97.03ms +step:515/1670 train_time:49968ms step_avg:97.03ms +step:516/1670 train_time:50064ms step_avg:97.02ms +step:517/1670 train_time:50159ms step_avg:97.02ms +step:518/1670 train_time:50254ms step_avg:97.01ms +step:519/1670 train_time:50348ms step_avg:97.01ms +step:520/1670 train_time:50443ms step_avg:97.01ms +step:521/1670 train_time:50537ms step_avg:97.00ms +step:522/1670 train_time:50634ms step_avg:97.00ms +step:523/1670 train_time:50731ms step_avg:97.00ms +step:524/1670 train_time:50828ms step_avg:97.00ms +step:525/1670 train_time:50925ms step_avg:97.00ms +step:526/1670 train_time:51021ms step_avg:97.00ms +step:527/1670 train_time:51116ms step_avg:97.00ms +step:528/1670 train_time:51211ms step_avg:96.99ms +step:529/1670 train_time:51306ms step_avg:96.99ms +step:530/1670 train_time:51401ms step_avg:96.98ms +step:531/1670 train_time:51496ms step_avg:96.98ms +step:532/1670 train_time:51591ms step_avg:96.98ms +step:533/1670 train_time:51688ms step_avg:96.98ms +step:534/1670 train_time:51785ms step_avg:96.98ms +step:535/1670 train_time:51880ms step_avg:96.97ms +step:536/1670 train_time:51976ms step_avg:96.97ms +step:537/1670 train_time:52072ms step_avg:96.97ms +step:538/1670 train_time:52168ms step_avg:96.97ms +step:539/1670 train_time:52264ms step_avg:96.96ms +step:540/1670 train_time:52359ms step_avg:96.96ms +step:541/1670 train_time:52453ms step_avg:96.96ms +step:542/1670 train_time:52549ms step_avg:96.95ms +step:543/1670 train_time:52645ms step_avg:96.95ms +step:544/1670 train_time:52740ms step_avg:96.95ms +step:545/1670 train_time:52836ms step_avg:96.95ms +step:546/1670 train_time:52932ms step_avg:96.95ms +step:547/1670 train_time:53029ms step_avg:96.95ms +step:548/1670 train_time:53125ms step_avg:96.94ms +step:549/1670 train_time:53221ms step_avg:96.94ms +step:550/1670 train_time:53316ms step_avg:96.94ms +step:551/1670 train_time:53411ms step_avg:96.93ms +step:552/1670 train_time:53506ms step_avg:96.93ms +step:553/1670 train_time:53602ms step_avg:96.93ms +step:554/1670 train_time:53698ms step_avg:96.93ms +step:555/1670 train_time:53794ms step_avg:96.93ms +step:556/1670 train_time:53890ms step_avg:96.92ms +step:557/1670 train_time:53986ms step_avg:96.92ms +step:558/1670 train_time:54082ms step_avg:96.92ms +step:559/1670 train_time:54179ms step_avg:96.92ms +step:560/1670 train_time:54276ms step_avg:96.92ms +step:561/1670 train_time:54372ms step_avg:96.92ms +step:562/1670 train_time:54469ms step_avg:96.92ms +step:563/1670 train_time:54565ms step_avg:96.92ms 
+step:564/1670 train_time:54663ms step_avg:96.92ms +step:565/1670 train_time:54759ms step_avg:96.92ms +step:566/1670 train_time:54856ms step_avg:96.92ms +step:567/1670 train_time:54954ms step_avg:96.92ms +step:568/1670 train_time:55052ms step_avg:96.92ms +step:569/1670 train_time:55150ms step_avg:96.93ms +step:570/1670 train_time:55249ms step_avg:96.93ms +step:571/1670 train_time:55346ms step_avg:96.93ms +step:572/1670 train_time:55443ms step_avg:96.93ms +step:573/1670 train_time:55540ms step_avg:96.93ms +step:574/1670 train_time:55636ms step_avg:96.93ms +step:575/1670 train_time:55732ms step_avg:96.93ms +step:576/1670 train_time:55830ms step_avg:96.93ms +step:577/1670 train_time:55927ms step_avg:96.93ms +step:578/1670 train_time:56024ms step_avg:96.93ms +step:579/1670 train_time:56122ms step_avg:96.93ms +step:580/1670 train_time:56218ms step_avg:96.93ms +step:581/1670 train_time:56315ms step_avg:96.93ms +step:582/1670 train_time:56413ms step_avg:96.93ms +step:583/1670 train_time:56510ms step_avg:96.93ms +step:584/1670 train_time:56608ms step_avg:96.93ms +step:585/1670 train_time:56705ms step_avg:96.93ms +step:586/1670 train_time:56802ms step_avg:96.93ms +step:587/1670 train_time:56900ms step_avg:96.93ms +step:588/1670 train_time:56996ms step_avg:96.93ms +step:589/1670 train_time:57093ms step_avg:96.93ms +step:590/1670 train_time:57191ms step_avg:96.93ms +step:591/1670 train_time:57289ms step_avg:96.94ms +step:592/1670 train_time:57386ms step_avg:96.93ms +step:593/1670 train_time:57483ms step_avg:96.94ms +step:594/1670 train_time:57580ms step_avg:96.94ms +step:595/1670 train_time:57676ms step_avg:96.93ms +step:596/1670 train_time:57773ms step_avg:96.93ms +step:597/1670 train_time:57871ms step_avg:96.94ms +step:598/1670 train_time:57969ms step_avg:96.94ms +step:599/1670 train_time:58066ms step_avg:96.94ms +step:600/1670 train_time:58163ms step_avg:96.94ms +step:601/1670 train_time:58260ms step_avg:96.94ms +step:602/1670 train_time:58356ms step_avg:96.94ms +step:603/1670 train_time:58454ms step_avg:96.94ms +step:604/1670 train_time:58552ms step_avg:96.94ms +step:605/1670 train_time:58650ms step_avg:96.94ms +step:606/1670 train_time:58748ms step_avg:96.94ms +step:607/1670 train_time:58845ms step_avg:96.94ms +step:608/1670 train_time:58942ms step_avg:96.94ms +step:609/1670 train_time:59038ms step_avg:96.94ms +step:610/1670 train_time:59135ms step_avg:96.94ms +step:611/1670 train_time:59232ms step_avg:96.94ms +step:612/1670 train_time:59330ms step_avg:96.94ms +step:613/1670 train_time:59428ms step_avg:96.95ms +step:614/1670 train_time:59525ms step_avg:96.95ms +step:615/1670 train_time:59622ms step_avg:96.95ms +step:616/1670 train_time:59719ms step_avg:96.95ms +step:617/1670 train_time:59816ms step_avg:96.95ms +step:618/1670 train_time:59913ms step_avg:96.95ms +step:619/1670 train_time:60010ms step_avg:96.95ms +step:620/1670 train_time:60107ms step_avg:96.95ms +step:621/1670 train_time:60205ms step_avg:96.95ms +step:622/1670 train_time:60302ms step_avg:96.95ms +step:623/1670 train_time:60398ms step_avg:96.95ms +step:624/1670 train_time:60496ms step_avg:96.95ms +step:625/1670 train_time:60593ms step_avg:96.95ms +step:625/1670 val_loss:3.6224 train_time:60691ms step_avg:97.11ms +step:626/1670 train_time:60717ms step_avg:96.99ms +step:627/1670 train_time:60794ms step_avg:96.96ms +step:628/1670 train_time:60891ms step_avg:96.96ms +step:629/1670 train_time:60987ms step_avg:96.96ms +step:630/1670 train_time:61082ms step_avg:96.96ms +step:631/1670 train_time:61177ms step_avg:96.95ms +step:632/1670 
train_time:61273ms step_avg:96.95ms +step:633/1670 train_time:61369ms step_avg:96.95ms +step:634/1670 train_time:61465ms step_avg:96.95ms +step:635/1670 train_time:61561ms step_avg:96.95ms +step:636/1670 train_time:61661ms step_avg:96.95ms +step:637/1670 train_time:61761ms step_avg:96.96ms +step:638/1670 train_time:61859ms step_avg:96.96ms +step:639/1670 train_time:62129ms step_avg:97.23ms +step:640/1670 train_time:62324ms step_avg:97.38ms +step:641/1670 train_time:62419ms step_avg:97.38ms +step:642/1670 train_time:62515ms step_avg:97.37ms +step:643/1670 train_time:62610ms step_avg:97.37ms +step:644/1670 train_time:62706ms step_avg:97.37ms +step:645/1670 train_time:62802ms step_avg:97.37ms +step:646/1670 train_time:62897ms step_avg:97.36ms +step:647/1670 train_time:62993ms step_avg:97.36ms +step:648/1670 train_time:63089ms step_avg:97.36ms +step:649/1670 train_time:63187ms step_avg:97.36ms +step:650/1670 train_time:63291ms step_avg:97.37ms +step:651/1670 train_time:63392ms step_avg:97.38ms +step:652/1670 train_time:63490ms step_avg:97.38ms +step:653/1670 train_time:63587ms step_avg:97.38ms +step:654/1670 train_time:63683ms step_avg:97.38ms +step:655/1670 train_time:63779ms step_avg:97.37ms +step:656/1670 train_time:63875ms step_avg:97.37ms +step:657/1670 train_time:63971ms step_avg:97.37ms +step:658/1670 train_time:64068ms step_avg:97.37ms +step:659/1670 train_time:64165ms step_avg:97.37ms +step:660/1670 train_time:64264ms step_avg:97.37ms +step:661/1670 train_time:64362ms step_avg:97.37ms +step:662/1670 train_time:64461ms step_avg:97.37ms +step:663/1670 train_time:64558ms step_avg:97.37ms +step:664/1670 train_time:64654ms step_avg:97.37ms +step:665/1670 train_time:64751ms step_avg:97.37ms +step:666/1670 train_time:64847ms step_avg:97.37ms +step:667/1670 train_time:64944ms step_avg:97.37ms +step:668/1670 train_time:65041ms step_avg:97.37ms +step:669/1670 train_time:65137ms step_avg:97.36ms +step:670/1670 train_time:65235ms step_avg:97.36ms +step:671/1670 train_time:65332ms step_avg:97.37ms +step:672/1670 train_time:65431ms step_avg:97.37ms +step:673/1670 train_time:65528ms step_avg:97.37ms +step:674/1670 train_time:65627ms step_avg:97.37ms +step:675/1670 train_time:65724ms step_avg:97.37ms +step:676/1670 train_time:65820ms step_avg:97.37ms +step:677/1670 train_time:65916ms step_avg:97.37ms +step:678/1670 train_time:66012ms step_avg:97.36ms +step:679/1670 train_time:66110ms step_avg:97.36ms +step:680/1670 train_time:66207ms step_avg:97.36ms +step:681/1670 train_time:66304ms step_avg:97.36ms +step:682/1670 train_time:66402ms step_avg:97.36ms +step:683/1670 train_time:66499ms step_avg:97.36ms +step:684/1670 train_time:66596ms step_avg:97.36ms +step:685/1670 train_time:66692ms step_avg:97.36ms +step:686/1670 train_time:66790ms step_avg:97.36ms +step:687/1670 train_time:66888ms step_avg:97.36ms +step:688/1670 train_time:66984ms step_avg:97.36ms +step:689/1670 train_time:67081ms step_avg:97.36ms +step:690/1670 train_time:67178ms step_avg:97.36ms +step:691/1670 train_time:67275ms step_avg:97.36ms +step:692/1670 train_time:67371ms step_avg:97.36ms +step:693/1670 train_time:67470ms step_avg:97.36ms +step:694/1670 train_time:67568ms step_avg:97.36ms +step:695/1670 train_time:67665ms step_avg:97.36ms +step:696/1670 train_time:67763ms step_avg:97.36ms +step:697/1670 train_time:67859ms step_avg:97.36ms +step:698/1670 train_time:67956ms step_avg:97.36ms +step:699/1670 train_time:68052ms step_avg:97.36ms +step:700/1670 train_time:68150ms step_avg:97.36ms +step:701/1670 train_time:68248ms step_avg:97.36ms 
+step:702/1670 train_time:68346ms step_avg:97.36ms +step:703/1670 train_time:68443ms step_avg:97.36ms +step:704/1670 train_time:68541ms step_avg:97.36ms +step:705/1670 train_time:68638ms step_avg:97.36ms +step:706/1670 train_time:68735ms step_avg:97.36ms +step:707/1670 train_time:68832ms step_avg:97.36ms +step:708/1670 train_time:68929ms step_avg:97.36ms +step:709/1670 train_time:69026ms step_avg:97.36ms +step:710/1670 train_time:69123ms step_avg:97.36ms +step:711/1670 train_time:69220ms step_avg:97.36ms +step:712/1670 train_time:69317ms step_avg:97.36ms +step:713/1670 train_time:69414ms step_avg:97.35ms +step:714/1670 train_time:69511ms step_avg:97.35ms +step:715/1670 train_time:69609ms step_avg:97.35ms +step:716/1670 train_time:69707ms step_avg:97.36ms +step:717/1670 train_time:69803ms step_avg:97.35ms +step:718/1670 train_time:69900ms step_avg:97.35ms +step:719/1670 train_time:69997ms step_avg:97.35ms +step:720/1670 train_time:70092ms step_avg:97.35ms +step:721/1670 train_time:70190ms step_avg:97.35ms +step:722/1670 train_time:70288ms step_avg:97.35ms +step:723/1670 train_time:70385ms step_avg:97.35ms +step:724/1670 train_time:70482ms step_avg:97.35ms +step:725/1670 train_time:70580ms step_avg:97.35ms +step:726/1670 train_time:70677ms step_avg:97.35ms +step:727/1670 train_time:70774ms step_avg:97.35ms +step:728/1670 train_time:70872ms step_avg:97.35ms +step:729/1670 train_time:70969ms step_avg:97.35ms +step:730/1670 train_time:71066ms step_avg:97.35ms +step:731/1670 train_time:71163ms step_avg:97.35ms +step:732/1670 train_time:71261ms step_avg:97.35ms +step:733/1670 train_time:71357ms step_avg:97.35ms +step:734/1670 train_time:71453ms step_avg:97.35ms +step:735/1670 train_time:71551ms step_avg:97.35ms +step:736/1670 train_time:71648ms step_avg:97.35ms +step:737/1670 train_time:71745ms step_avg:97.35ms +step:738/1670 train_time:71842ms step_avg:97.35ms +step:739/1670 train_time:71940ms step_avg:97.35ms +step:740/1670 train_time:72037ms step_avg:97.35ms +step:741/1670 train_time:72133ms step_avg:97.35ms +step:742/1670 train_time:72230ms step_avg:97.35ms +step:743/1670 train_time:72328ms step_avg:97.35ms +step:744/1670 train_time:72425ms step_avg:97.35ms +step:745/1670 train_time:72523ms step_avg:97.35ms +step:746/1670 train_time:72619ms step_avg:97.35ms +step:747/1670 train_time:72716ms step_avg:97.34ms +step:748/1670 train_time:72814ms step_avg:97.34ms +step:749/1670 train_time:72910ms step_avg:97.34ms +step:750/1670 train_time:73008ms step_avg:97.34ms +step:750/1670 val_loss:3.5662 train_time:73105ms step_avg:97.47ms +step:751/1670 train_time:73129ms step_avg:97.38ms +step:752/1670 train_time:73210ms step_avg:97.35ms +step:753/1670 train_time:73310ms step_avg:97.36ms +step:754/1670 train_time:73407ms step_avg:97.36ms +step:755/1670 train_time:73503ms step_avg:97.35ms +step:756/1670 train_time:73599ms step_avg:97.35ms +step:757/1670 train_time:73695ms step_avg:97.35ms +step:758/1670 train_time:73791ms step_avg:97.35ms +step:759/1670 train_time:73888ms step_avg:97.35ms +step:760/1670 train_time:73983ms step_avg:97.35ms +step:761/1670 train_time:74081ms step_avg:97.35ms +step:762/1670 train_time:74181ms step_avg:97.35ms +step:763/1670 train_time:74279ms step_avg:97.35ms +step:764/1670 train_time:74377ms step_avg:97.35ms +step:765/1670 train_time:74473ms step_avg:97.35ms +step:766/1670 train_time:74569ms step_avg:97.35ms +step:767/1670 train_time:74666ms step_avg:97.35ms +step:768/1670 train_time:74763ms step_avg:97.35ms +step:769/1670 train_time:74860ms step_avg:97.35ms +step:770/1670 
train_time:74956ms step_avg:97.35ms +step:771/1670 train_time:75053ms step_avg:97.34ms +step:772/1670 train_time:75150ms step_avg:97.34ms +step:773/1670 train_time:75249ms step_avg:97.35ms +step:774/1670 train_time:75347ms step_avg:97.35ms +step:775/1670 train_time:75444ms step_avg:97.35ms +step:776/1670 train_time:75541ms step_avg:97.35ms +step:777/1670 train_time:75637ms step_avg:97.35ms +step:778/1670 train_time:75733ms step_avg:97.34ms +step:779/1670 train_time:75830ms step_avg:97.34ms +step:780/1670 train_time:75927ms step_avg:97.34ms +step:781/1670 train_time:76025ms step_avg:97.34ms +step:782/1670 train_time:76123ms step_avg:97.34ms +step:783/1670 train_time:76221ms step_avg:97.34ms +step:784/1670 train_time:76318ms step_avg:97.34ms +step:785/1670 train_time:76414ms step_avg:97.34ms +step:786/1670 train_time:76511ms step_avg:97.34ms +step:787/1670 train_time:76610ms step_avg:97.34ms +step:788/1670 train_time:76707ms step_avg:97.34ms +step:789/1670 train_time:76804ms step_avg:97.34ms +step:790/1670 train_time:76901ms step_avg:97.34ms +step:791/1670 train_time:76998ms step_avg:97.34ms +step:792/1670 train_time:77094ms step_avg:97.34ms +step:793/1670 train_time:77191ms step_avg:97.34ms +step:794/1670 train_time:77289ms step_avg:97.34ms +step:795/1670 train_time:77386ms step_avg:97.34ms +step:796/1670 train_time:77484ms step_avg:97.34ms +step:797/1670 train_time:77581ms step_avg:97.34ms +step:798/1670 train_time:77678ms step_avg:97.34ms +step:799/1670 train_time:77774ms step_avg:97.34ms +step:800/1670 train_time:77870ms step_avg:97.34ms +step:801/1670 train_time:77969ms step_avg:97.34ms +step:802/1670 train_time:78068ms step_avg:97.34ms +step:803/1670 train_time:78165ms step_avg:97.34ms +step:804/1670 train_time:78263ms step_avg:97.34ms +step:805/1670 train_time:78360ms step_avg:97.34ms +step:806/1670 train_time:78457ms step_avg:97.34ms +step:807/1670 train_time:78555ms step_avg:97.34ms +step:808/1670 train_time:78651ms step_avg:97.34ms +step:809/1670 train_time:78748ms step_avg:97.34ms +step:810/1670 train_time:78846ms step_avg:97.34ms +step:811/1670 train_time:78942ms step_avg:97.34ms +step:812/1670 train_time:79040ms step_avg:97.34ms +step:813/1670 train_time:79136ms step_avg:97.34ms +step:814/1670 train_time:79234ms step_avg:97.34ms +step:815/1670 train_time:79332ms step_avg:97.34ms +step:816/1670 train_time:79429ms step_avg:97.34ms +step:817/1670 train_time:79526ms step_avg:97.34ms +step:818/1670 train_time:79624ms step_avg:97.34ms +step:819/1670 train_time:79720ms step_avg:97.34ms +step:820/1670 train_time:79817ms step_avg:97.34ms +step:821/1670 train_time:79914ms step_avg:97.34ms +step:822/1670 train_time:80011ms step_avg:97.34ms +step:823/1670 train_time:80109ms step_avg:97.34ms +step:824/1670 train_time:80206ms step_avg:97.34ms +step:825/1670 train_time:80303ms step_avg:97.34ms +step:826/1670 train_time:80400ms step_avg:97.34ms +step:827/1670 train_time:80497ms step_avg:97.34ms +step:828/1670 train_time:80594ms step_avg:97.34ms +step:829/1670 train_time:80690ms step_avg:97.33ms +step:830/1670 train_time:80787ms step_avg:97.33ms +step:831/1670 train_time:80884ms step_avg:97.33ms +step:832/1670 train_time:80981ms step_avg:97.33ms +step:833/1670 train_time:81079ms step_avg:97.33ms +step:834/1670 train_time:81176ms step_avg:97.33ms +step:835/1670 train_time:81273ms step_avg:97.33ms +step:836/1670 train_time:81370ms step_avg:97.33ms +step:837/1670 train_time:81467ms step_avg:97.33ms +step:838/1670 train_time:81565ms step_avg:97.33ms +step:839/1670 train_time:81662ms step_avg:97.33ms 
+step:840/1670 train_time:81759ms step_avg:97.33ms +step:841/1670 train_time:81855ms step_avg:97.33ms +step:842/1670 train_time:81952ms step_avg:97.33ms +step:843/1670 train_time:82050ms step_avg:97.33ms +step:844/1670 train_time:82147ms step_avg:97.33ms +step:845/1670 train_time:82245ms step_avg:97.33ms +step:846/1670 train_time:82343ms step_avg:97.33ms +step:847/1670 train_time:82441ms step_avg:97.33ms +step:848/1670 train_time:82537ms step_avg:97.33ms +step:849/1670 train_time:82634ms step_avg:97.33ms +step:850/1670 train_time:82731ms step_avg:97.33ms +step:851/1670 train_time:83001ms step_avg:97.53ms +step:852/1670 train_time:83142ms step_avg:97.58ms +step:853/1670 train_time:83238ms step_avg:97.58ms +step:854/1670 train_time:83333ms step_avg:97.58ms +step:855/1670 train_time:83429ms step_avg:97.58ms +step:856/1670 train_time:83524ms step_avg:97.57ms +step:857/1670 train_time:83620ms step_avg:97.57ms +step:858/1670 train_time:83716ms step_avg:97.57ms +step:859/1670 train_time:83812ms step_avg:97.57ms +step:860/1670 train_time:83908ms step_avg:97.57ms +step:861/1670 train_time:84007ms step_avg:97.57ms +step:862/1670 train_time:84112ms step_avg:97.58ms +step:863/1670 train_time:84210ms step_avg:97.58ms +step:864/1670 train_time:84308ms step_avg:97.58ms +step:865/1670 train_time:84405ms step_avg:97.58ms +step:866/1670 train_time:84501ms step_avg:97.58ms +step:867/1670 train_time:84597ms step_avg:97.57ms +step:868/1670 train_time:84692ms step_avg:97.57ms +step:869/1670 train_time:84788ms step_avg:97.57ms +step:870/1670 train_time:84884ms step_avg:97.57ms +step:871/1670 train_time:84982ms step_avg:97.57ms +step:872/1670 train_time:85081ms step_avg:97.57ms +step:873/1670 train_time:85179ms step_avg:97.57ms +step:874/1670 train_time:85277ms step_avg:97.57ms +step:875/1670 train_time:85373ms step_avg:97.57ms +step:875/1670 val_loss:3.5252 train_time:85468ms step_avg:97.68ms +step:876/1670 train_time:85492ms step_avg:97.59ms +step:877/1670 train_time:85572ms step_avg:97.57ms +step:878/1670 train_time:85671ms step_avg:97.58ms +step:879/1670 train_time:85768ms step_avg:97.57ms +step:880/1670 train_time:85863ms step_avg:97.57ms +step:881/1670 train_time:85959ms step_avg:97.57ms +step:882/1670 train_time:86055ms step_avg:97.57ms +step:883/1670 train_time:86150ms step_avg:97.57ms +step:884/1670 train_time:86247ms step_avg:97.56ms +step:885/1670 train_time:86343ms step_avg:97.56ms +step:886/1670 train_time:86441ms step_avg:97.56ms +step:887/1670 train_time:86540ms step_avg:97.57ms +step:888/1670 train_time:86639ms step_avg:97.57ms +step:889/1670 train_time:86736ms step_avg:97.57ms +step:890/1670 train_time:86834ms step_avg:97.57ms +step:891/1670 train_time:86930ms step_avg:97.56ms +step:892/1670 train_time:87026ms step_avg:97.56ms +step:893/1670 train_time:87121ms step_avg:97.56ms +step:894/1670 train_time:87217ms step_avg:97.56ms +step:895/1670 train_time:87315ms step_avg:97.56ms +step:896/1670 train_time:87413ms step_avg:97.56ms +step:897/1670 train_time:87512ms step_avg:97.56ms +step:898/1670 train_time:87610ms step_avg:97.56ms +step:899/1670 train_time:87708ms step_avg:97.56ms +step:900/1670 train_time:87805ms step_avg:97.56ms +step:901/1670 train_time:87902ms step_avg:97.56ms +step:902/1670 train_time:87998ms step_avg:97.56ms +step:903/1670 train_time:88095ms step_avg:97.56ms +step:904/1670 train_time:88192ms step_avg:97.56ms +step:905/1670 train_time:88289ms step_avg:97.56ms +step:906/1670 train_time:88386ms step_avg:97.56ms +step:907/1670 train_time:88483ms step_avg:97.56ms +step:908/1670 
train_time:88580ms step_avg:97.56ms +step:909/1670 train_time:88678ms step_avg:97.56ms +step:910/1670 train_time:88778ms step_avg:97.56ms +step:911/1670 train_time:88876ms step_avg:97.56ms +step:912/1670 train_time:88972ms step_avg:97.56ms +step:913/1670 train_time:89069ms step_avg:97.56ms +step:914/1670 train_time:89166ms step_avg:97.56ms +step:915/1670 train_time:89262ms step_avg:97.55ms +step:916/1670 train_time:89358ms step_avg:97.55ms +step:917/1670 train_time:89457ms step_avg:97.55ms +step:918/1670 train_time:89555ms step_avg:97.55ms +step:919/1670 train_time:89653ms step_avg:97.56ms +step:920/1670 train_time:89751ms step_avg:97.56ms +step:921/1670 train_time:89849ms step_avg:97.56ms +step:922/1670 train_time:89946ms step_avg:97.56ms +step:923/1670 train_time:90043ms step_avg:97.55ms +step:924/1670 train_time:90141ms step_avg:97.55ms +step:925/1670 train_time:90237ms step_avg:97.55ms +step:926/1670 train_time:90334ms step_avg:97.55ms +step:927/1670 train_time:90432ms step_avg:97.55ms +step:928/1670 train_time:90529ms step_avg:97.55ms +step:929/1670 train_time:90626ms step_avg:97.55ms +step:930/1670 train_time:90724ms step_avg:97.55ms +step:931/1670 train_time:90821ms step_avg:97.55ms +step:932/1670 train_time:90918ms step_avg:97.55ms +step:933/1670 train_time:91015ms step_avg:97.55ms +step:934/1670 train_time:91112ms step_avg:97.55ms +step:935/1670 train_time:91209ms step_avg:97.55ms +step:936/1670 train_time:91306ms step_avg:97.55ms +step:937/1670 train_time:91402ms step_avg:97.55ms +step:938/1670 train_time:91499ms step_avg:97.55ms +step:939/1670 train_time:91596ms step_avg:97.55ms +step:940/1670 train_time:91694ms step_avg:97.55ms +step:941/1670 train_time:91792ms step_avg:97.55ms +step:942/1670 train_time:91889ms step_avg:97.55ms +step:943/1670 train_time:91986ms step_avg:97.55ms +step:944/1670 train_time:92082ms step_avg:97.54ms +step:945/1670 train_time:92179ms step_avg:97.54ms +step:946/1670 train_time:92277ms step_avg:97.54ms +step:947/1670 train_time:92376ms step_avg:97.55ms +step:948/1670 train_time:92472ms step_avg:97.54ms +step:949/1670 train_time:92571ms step_avg:97.55ms +step:950/1670 train_time:92668ms step_avg:97.55ms +step:951/1670 train_time:92765ms step_avg:97.54ms +step:952/1670 train_time:92862ms step_avg:97.54ms +step:953/1670 train_time:92958ms step_avg:97.54ms +step:954/1670 train_time:93056ms step_avg:97.54ms +step:955/1670 train_time:93153ms step_avg:97.54ms +step:956/1670 train_time:93250ms step_avg:97.54ms +step:957/1670 train_time:93347ms step_avg:97.54ms +step:958/1670 train_time:93444ms step_avg:97.54ms +step:959/1670 train_time:93541ms step_avg:97.54ms +step:960/1670 train_time:93638ms step_avg:97.54ms +step:961/1670 train_time:93736ms step_avg:97.54ms +step:962/1670 train_time:93834ms step_avg:97.54ms +step:963/1670 train_time:93931ms step_avg:97.54ms +step:964/1670 train_time:94028ms step_avg:97.54ms +step:965/1670 train_time:94126ms step_avg:97.54ms +step:966/1670 train_time:94222ms step_avg:97.54ms +step:967/1670 train_time:94319ms step_avg:97.54ms +step:968/1670 train_time:94418ms step_avg:97.54ms +step:969/1670 train_time:94515ms step_avg:97.54ms +step:970/1670 train_time:94612ms step_avg:97.54ms +step:971/1670 train_time:94711ms step_avg:97.54ms +step:972/1670 train_time:94808ms step_avg:97.54ms +step:973/1670 train_time:94905ms step_avg:97.54ms +step:974/1670 train_time:95002ms step_avg:97.54ms +step:975/1670 train_time:95099ms step_avg:97.54ms +step:976/1670 train_time:95197ms step_avg:97.54ms +step:977/1670 train_time:95295ms step_avg:97.54ms 
+step:978/1670 train_time:95392ms step_avg:97.54ms +step:979/1670 train_time:95489ms step_avg:97.54ms +step:980/1670 train_time:95586ms step_avg:97.54ms +step:981/1670 train_time:95682ms step_avg:97.54ms +step:982/1670 train_time:95779ms step_avg:97.53ms +step:983/1670 train_time:95878ms step_avg:97.54ms +step:984/1670 train_time:95975ms step_avg:97.54ms +step:985/1670 train_time:96072ms step_avg:97.53ms +step:986/1670 train_time:96170ms step_avg:97.54ms +step:987/1670 train_time:96268ms step_avg:97.54ms +step:988/1670 train_time:96365ms step_avg:97.54ms +step:989/1670 train_time:96462ms step_avg:97.53ms +step:990/1670 train_time:96558ms step_avg:97.53ms +step:991/1670 train_time:96656ms step_avg:97.53ms +step:992/1670 train_time:96754ms step_avg:97.53ms +step:993/1670 train_time:96852ms step_avg:97.53ms +step:994/1670 train_time:96950ms step_avg:97.53ms +step:995/1670 train_time:97046ms step_avg:97.53ms +step:996/1670 train_time:97142ms step_avg:97.53ms +step:997/1670 train_time:97239ms step_avg:97.53ms +step:998/1670 train_time:97337ms step_avg:97.53ms +step:999/1670 train_time:97435ms step_avg:97.53ms +step:1000/1670 train_time:97533ms step_avg:97.53ms +step:1000/1670 val_loss:3.4801 train_time:97629ms step_avg:97.63ms +step:1001/1670 train_time:97658ms step_avg:97.56ms +step:1002/1670 train_time:97734ms step_avg:97.54ms +step:1003/1670 train_time:97835ms step_avg:97.54ms +step:1004/1670 train_time:97931ms step_avg:97.54ms +step:1005/1670 train_time:98027ms step_avg:97.54ms +step:1006/1670 train_time:98123ms step_avg:97.54ms +step:1007/1670 train_time:98219ms step_avg:97.54ms +step:1008/1670 train_time:98315ms step_avg:97.53ms +step:1009/1670 train_time:98410ms step_avg:97.53ms +step:1010/1670 train_time:98506ms step_avg:97.53ms +step:1011/1670 train_time:98602ms step_avg:97.53ms +step:1012/1670 train_time:98702ms step_avg:97.53ms +step:1013/1670 train_time:98801ms step_avg:97.53ms +step:1014/1670 train_time:98899ms step_avg:97.53ms +step:1015/1670 train_time:98996ms step_avg:97.53ms +step:1016/1670 train_time:99094ms step_avg:97.53ms +step:1017/1670 train_time:99191ms step_avg:97.53ms +step:1018/1670 train_time:99286ms step_avg:97.53ms +step:1019/1670 train_time:99382ms step_avg:97.53ms +step:1020/1670 train_time:99478ms step_avg:97.53ms +step:1021/1670 train_time:99575ms step_avg:97.53ms +step:1022/1670 train_time:99673ms step_avg:97.53ms +step:1023/1670 train_time:99771ms step_avg:97.53ms +step:1024/1670 train_time:99868ms step_avg:97.53ms +step:1025/1670 train_time:99965ms step_avg:97.53ms +step:1026/1670 train_time:100063ms step_avg:97.53ms +step:1027/1670 train_time:100160ms step_avg:97.53ms +step:1028/1670 train_time:100258ms step_avg:97.53ms +step:1029/1670 train_time:100354ms step_avg:97.53ms +step:1030/1670 train_time:100451ms step_avg:97.52ms +step:1031/1670 train_time:100548ms step_avg:97.52ms +step:1032/1670 train_time:100644ms step_avg:97.52ms +step:1033/1670 train_time:100740ms step_avg:97.52ms +step:1034/1670 train_time:100838ms step_avg:97.52ms +step:1035/1670 train_time:100935ms step_avg:97.52ms +step:1036/1670 train_time:101033ms step_avg:97.52ms +step:1037/1670 train_time:101131ms step_avg:97.52ms +step:1038/1670 train_time:101228ms step_avg:97.52ms +step:1039/1670 train_time:101324ms step_avg:97.52ms +step:1040/1670 train_time:101420ms step_avg:97.52ms +step:1041/1670 train_time:101517ms step_avg:97.52ms +step:1042/1670 train_time:101615ms step_avg:97.52ms +step:1043/1670 train_time:101712ms step_avg:97.52ms +step:1044/1670 train_time:101809ms step_avg:97.52ms 
+step:1045/1670 train_time:101905ms step_avg:97.52ms +step:1046/1670 train_time:102003ms step_avg:97.52ms +step:1047/1670 train_time:102100ms step_avg:97.52ms +step:1048/1670 train_time:102198ms step_avg:97.52ms +step:1049/1670 train_time:102296ms step_avg:97.52ms +step:1050/1670 train_time:102393ms step_avg:97.52ms +step:1051/1670 train_time:102490ms step_avg:97.52ms +step:1052/1670 train_time:102587ms step_avg:97.52ms +step:1053/1670 train_time:102684ms step_avg:97.52ms +step:1054/1670 train_time:102781ms step_avg:97.52ms +step:1055/1670 train_time:102880ms step_avg:97.52ms +step:1056/1670 train_time:102977ms step_avg:97.52ms +step:1057/1670 train_time:103075ms step_avg:97.52ms +step:1058/1670 train_time:103172ms step_avg:97.52ms +step:1059/1670 train_time:103269ms step_avg:97.52ms +step:1060/1670 train_time:103365ms step_avg:97.51ms +step:1061/1670 train_time:103462ms step_avg:97.51ms +step:1062/1670 train_time:103731ms step_avg:97.67ms +step:1063/1670 train_time:103969ms step_avg:97.81ms +step:1064/1670 train_time:104063ms step_avg:97.80ms +step:1065/1670 train_time:104159ms step_avg:97.80ms +step:1066/1670 train_time:104255ms step_avg:97.80ms +step:1067/1670 train_time:104350ms step_avg:97.80ms +step:1068/1670 train_time:104445ms step_avg:97.79ms +step:1069/1670 train_time:104541ms step_avg:97.79ms +step:1070/1670 train_time:104637ms step_avg:97.79ms +step:1071/1670 train_time:104733ms step_avg:97.79ms +step:1072/1670 train_time:104835ms step_avg:97.79ms +step:1073/1670 train_time:104935ms step_avg:97.80ms +step:1074/1670 train_time:105036ms step_avg:97.80ms +step:1075/1670 train_time:105134ms step_avg:97.80ms +step:1076/1670 train_time:105230ms step_avg:97.80ms +step:1077/1670 train_time:105327ms step_avg:97.80ms +step:1078/1670 train_time:105422ms step_avg:97.79ms +step:1079/1670 train_time:105518ms step_avg:97.79ms +step:1080/1670 train_time:105614ms step_avg:97.79ms +step:1081/1670 train_time:105710ms step_avg:97.79ms +step:1082/1670 train_time:105809ms step_avg:97.79ms +step:1083/1670 train_time:105907ms step_avg:97.79ms +step:1084/1670 train_time:106005ms step_avg:97.79ms +step:1085/1670 train_time:106104ms step_avg:97.79ms +step:1086/1670 train_time:106202ms step_avg:97.79ms +step:1087/1670 train_time:106300ms step_avg:97.79ms +step:1088/1670 train_time:106396ms step_avg:97.79ms +step:1089/1670 train_time:106493ms step_avg:97.79ms +step:1090/1670 train_time:106589ms step_avg:97.79ms +step:1091/1670 train_time:106684ms step_avg:97.79ms +step:1092/1670 train_time:106781ms step_avg:97.78ms +step:1093/1670 train_time:106880ms step_avg:97.79ms +step:1094/1670 train_time:106978ms step_avg:97.79ms +step:1095/1670 train_time:107076ms step_avg:97.79ms +step:1096/1670 train_time:107175ms step_avg:97.79ms +step:1097/1670 train_time:107272ms step_avg:97.79ms +step:1098/1670 train_time:107369ms step_avg:97.79ms +step:1099/1670 train_time:107466ms step_avg:97.79ms +step:1100/1670 train_time:107562ms step_avg:97.78ms +step:1101/1670 train_time:107659ms step_avg:97.78ms +step:1102/1670 train_time:107755ms step_avg:97.78ms +step:1103/1670 train_time:107853ms step_avg:97.78ms +step:1104/1670 train_time:107949ms step_avg:97.78ms +step:1105/1670 train_time:108046ms step_avg:97.78ms +step:1106/1670 train_time:108144ms step_avg:97.78ms +step:1107/1670 train_time:108243ms step_avg:97.78ms +step:1108/1670 train_time:108341ms step_avg:97.78ms +step:1109/1670 train_time:108439ms step_avg:97.78ms +step:1110/1670 train_time:108537ms step_avg:97.78ms +step:1111/1670 train_time:108635ms step_avg:97.78ms 
+step:1112/1670 train_time:108731ms step_avg:97.78ms +step:1113/1670 train_time:108828ms step_avg:97.78ms +step:1114/1670 train_time:108925ms step_avg:97.78ms +step:1115/1670 train_time:109021ms step_avg:97.78ms +step:1116/1670 train_time:109119ms step_avg:97.78ms +step:1117/1670 train_time:109217ms step_avg:97.78ms +step:1118/1670 train_time:109315ms step_avg:97.78ms +step:1119/1670 train_time:109413ms step_avg:97.78ms +step:1120/1670 train_time:109510ms step_avg:97.78ms +step:1121/1670 train_time:109608ms step_avg:97.78ms +step:1122/1670 train_time:109706ms step_avg:97.78ms +step:1123/1670 train_time:109804ms step_avg:97.78ms +step:1124/1670 train_time:109902ms step_avg:97.78ms +step:1125/1670 train_time:110000ms step_avg:97.78ms +step:1125/1670 val_loss:3.4271 train_time:110097ms step_avg:97.86ms +step:1126/1670 train_time:110124ms step_avg:97.80ms +step:1127/1670 train_time:110208ms step_avg:97.79ms +step:1128/1670 train_time:110305ms step_avg:97.79ms +step:1129/1670 train_time:110402ms step_avg:97.79ms +step:1130/1670 train_time:110499ms step_avg:97.79ms +step:1131/1670 train_time:110595ms step_avg:97.79ms +step:1132/1670 train_time:110693ms step_avg:97.78ms +step:1133/1670 train_time:110788ms step_avg:97.78ms +step:1134/1670 train_time:110884ms step_avg:97.78ms +step:1135/1670 train_time:110981ms step_avg:97.78ms +step:1136/1670 train_time:111083ms step_avg:97.78ms +step:1137/1670 train_time:111185ms step_avg:97.79ms +step:1138/1670 train_time:111284ms step_avg:97.79ms +step:1139/1670 train_time:111383ms step_avg:97.79ms +step:1140/1670 train_time:111480ms step_avg:97.79ms +step:1141/1670 train_time:111577ms step_avg:97.79ms +step:1142/1670 train_time:111674ms step_avg:97.79ms +step:1143/1670 train_time:111771ms step_avg:97.79ms +step:1144/1670 train_time:111868ms step_avg:97.79ms +step:1145/1670 train_time:111965ms step_avg:97.79ms +step:1146/1670 train_time:112065ms step_avg:97.79ms +step:1147/1670 train_time:112165ms step_avg:97.79ms +step:1148/1670 train_time:112266ms step_avg:97.79ms +step:1149/1670 train_time:112364ms step_avg:97.79ms +step:1150/1670 train_time:112461ms step_avg:97.79ms +step:1151/1670 train_time:112557ms step_avg:97.79ms +step:1152/1670 train_time:112655ms step_avg:97.79ms +step:1153/1670 train_time:112752ms step_avg:97.79ms +step:1154/1670 train_time:112849ms step_avg:97.79ms +step:1155/1670 train_time:112946ms step_avg:97.79ms +step:1156/1670 train_time:113043ms step_avg:97.79ms +step:1157/1670 train_time:113142ms step_avg:97.79ms +step:1158/1670 train_time:113240ms step_avg:97.79ms +step:1159/1670 train_time:113339ms step_avg:97.79ms +step:1160/1670 train_time:113437ms step_avg:97.79ms +step:1161/1670 train_time:113536ms step_avg:97.79ms +step:1162/1670 train_time:113633ms step_avg:97.79ms +step:1163/1670 train_time:113729ms step_avg:97.79ms +step:1164/1670 train_time:113826ms step_avg:97.79ms +step:1165/1670 train_time:113924ms step_avg:97.79ms +step:1166/1670 train_time:114022ms step_avg:97.79ms +step:1167/1670 train_time:114120ms step_avg:97.79ms +step:1168/1670 train_time:114218ms step_avg:97.79ms +step:1169/1670 train_time:114316ms step_avg:97.79ms +step:1170/1670 train_time:114413ms step_avg:97.79ms +step:1171/1670 train_time:114510ms step_avg:97.79ms +step:1172/1670 train_time:114608ms step_avg:97.79ms +step:1173/1670 train_time:114706ms step_avg:97.79ms +step:1174/1670 train_time:114803ms step_avg:97.79ms +step:1175/1670 train_time:114902ms step_avg:97.79ms +step:1176/1670 train_time:115001ms step_avg:97.79ms +step:1177/1670 train_time:115097ms 
step_avg:97.79ms +step:1178/1670 train_time:115195ms step_avg:97.79ms +step:1179/1670 train_time:115293ms step_avg:97.79ms +step:1180/1670 train_time:115391ms step_avg:97.79ms +step:1181/1670 train_time:115488ms step_avg:97.79ms +step:1182/1670 train_time:115586ms step_avg:97.79ms +step:1183/1670 train_time:115683ms step_avg:97.79ms +step:1184/1670 train_time:115780ms step_avg:97.79ms +step:1185/1670 train_time:115879ms step_avg:97.79ms +step:1186/1670 train_time:115976ms step_avg:97.79ms +step:1187/1670 train_time:116074ms step_avg:97.79ms +step:1188/1670 train_time:116170ms step_avg:97.79ms +step:1189/1670 train_time:116269ms step_avg:97.79ms +step:1190/1670 train_time:116367ms step_avg:97.79ms +step:1191/1670 train_time:116465ms step_avg:97.79ms +step:1192/1670 train_time:116565ms step_avg:97.79ms +step:1193/1670 train_time:116663ms step_avg:97.79ms +step:1194/1670 train_time:116759ms step_avg:97.79ms +step:1195/1670 train_time:116857ms step_avg:97.79ms +step:1196/1670 train_time:116955ms step_avg:97.79ms +step:1197/1670 train_time:117052ms step_avg:97.79ms +step:1198/1670 train_time:117150ms step_avg:97.79ms +step:1199/1670 train_time:117247ms step_avg:97.79ms +step:1200/1670 train_time:117345ms step_avg:97.79ms +step:1201/1670 train_time:117444ms step_avg:97.79ms +step:1202/1670 train_time:117543ms step_avg:97.79ms +step:1203/1670 train_time:117643ms step_avg:97.79ms +step:1204/1670 train_time:117740ms step_avg:97.79ms +step:1205/1670 train_time:117838ms step_avg:97.79ms +step:1206/1670 train_time:117935ms step_avg:97.79ms +step:1207/1670 train_time:118033ms step_avg:97.79ms +step:1208/1670 train_time:118131ms step_avg:97.79ms +step:1209/1670 train_time:118228ms step_avg:97.79ms +step:1210/1670 train_time:118326ms step_avg:97.79ms +step:1211/1670 train_time:118424ms step_avg:97.79ms +step:1212/1670 train_time:118522ms step_avg:97.79ms +step:1213/1670 train_time:118620ms step_avg:97.79ms +step:1214/1670 train_time:118718ms step_avg:97.79ms +step:1215/1670 train_time:118815ms step_avg:97.79ms +step:1216/1670 train_time:118913ms step_avg:97.79ms +step:1217/1670 train_time:119010ms step_avg:97.79ms +step:1218/1670 train_time:119107ms step_avg:97.79ms +step:1219/1670 train_time:119204ms step_avg:97.79ms +step:1220/1670 train_time:119302ms step_avg:97.79ms +step:1221/1670 train_time:119401ms step_avg:97.79ms +step:1222/1670 train_time:119499ms step_avg:97.79ms +step:1223/1670 train_time:119597ms step_avg:97.79ms +step:1224/1670 train_time:119695ms step_avg:97.79ms +step:1225/1670 train_time:119791ms step_avg:97.79ms +step:1226/1670 train_time:119889ms step_avg:97.79ms +step:1227/1670 train_time:119986ms step_avg:97.79ms +step:1228/1670 train_time:120085ms step_avg:97.79ms +step:1229/1670 train_time:120182ms step_avg:97.79ms +step:1230/1670 train_time:120281ms step_avg:97.79ms +step:1231/1670 train_time:120380ms step_avg:97.79ms +step:1232/1670 train_time:120478ms step_avg:97.79ms +step:1233/1670 train_time:120575ms step_avg:97.79ms +step:1234/1670 train_time:120673ms step_avg:97.79ms +step:1235/1670 train_time:120771ms step_avg:97.79ms +step:1236/1670 train_time:120868ms step_avg:97.79ms +step:1237/1670 train_time:120965ms step_avg:97.79ms +step:1238/1670 train_time:121063ms step_avg:97.79ms +step:1239/1670 train_time:121161ms step_avg:97.79ms +step:1240/1670 train_time:121260ms step_avg:97.79ms +step:1241/1670 train_time:121358ms step_avg:97.79ms +step:1242/1670 train_time:121457ms step_avg:97.79ms +step:1243/1670 train_time:121555ms step_avg:97.79ms +step:1244/1670 train_time:121653ms 
step_avg:97.79ms +step:1245/1670 train_time:121750ms step_avg:97.79ms +step:1246/1670 train_time:121847ms step_avg:97.79ms +step:1247/1670 train_time:121945ms step_avg:97.79ms +step:1248/1670 train_time:122042ms step_avg:97.79ms +step:1249/1670 train_time:122141ms step_avg:97.79ms +step:1250/1670 train_time:122238ms step_avg:97.79ms +step:1250/1670 val_loss:3.3828 train_time:122335ms step_avg:97.87ms +step:1251/1670 train_time:122358ms step_avg:97.81ms +step:1252/1670 train_time:122439ms step_avg:97.79ms +step:1253/1670 train_time:122539ms step_avg:97.80ms +step:1254/1670 train_time:122638ms step_avg:97.80ms +step:1255/1670 train_time:122736ms step_avg:97.80ms +step:1256/1670 train_time:122833ms step_avg:97.80ms +step:1257/1670 train_time:122930ms step_avg:97.80ms +step:1258/1670 train_time:123027ms step_avg:97.80ms +step:1259/1670 train_time:123123ms step_avg:97.79ms +step:1260/1670 train_time:123220ms step_avg:97.79ms +step:1261/1670 train_time:123319ms step_avg:97.79ms +step:1262/1670 train_time:123421ms step_avg:97.80ms +step:1263/1670 train_time:123520ms step_avg:97.80ms +step:1264/1670 train_time:123617ms step_avg:97.80ms +step:1265/1670 train_time:123715ms step_avg:97.80ms +step:1266/1670 train_time:123812ms step_avg:97.80ms +step:1267/1670 train_time:123910ms step_avg:97.80ms +step:1268/1670 train_time:124007ms step_avg:97.80ms +step:1269/1670 train_time:124104ms step_avg:97.80ms +step:1270/1670 train_time:124201ms step_avg:97.80ms +step:1271/1670 train_time:124299ms step_avg:97.80ms +step:1272/1670 train_time:124399ms step_avg:97.80ms +step:1273/1670 train_time:124499ms step_avg:97.80ms +step:1274/1670 train_time:124768ms step_avg:97.93ms +step:1275/1670 train_time:124952ms step_avg:98.00ms +step:1276/1670 train_time:125047ms step_avg:98.00ms +step:1277/1670 train_time:125143ms step_avg:98.00ms +step:1278/1670 train_time:125239ms step_avg:98.00ms +step:1279/1670 train_time:125336ms step_avg:98.00ms +step:1280/1670 train_time:125433ms step_avg:97.99ms +step:1281/1670 train_time:125531ms step_avg:97.99ms +step:1282/1670 train_time:125628ms step_avg:97.99ms +step:1283/1670 train_time:125724ms step_avg:97.99ms +step:1284/1670 train_time:125823ms step_avg:97.99ms +step:1285/1670 train_time:125925ms step_avg:98.00ms +step:1286/1670 train_time:126024ms step_avg:98.00ms +step:1287/1670 train_time:126123ms step_avg:98.00ms +step:1288/1670 train_time:126220ms step_avg:98.00ms +step:1289/1670 train_time:126318ms step_avg:98.00ms +step:1290/1670 train_time:126414ms step_avg:98.00ms +step:1291/1670 train_time:126511ms step_avg:97.99ms +step:1292/1670 train_time:126608ms step_avg:97.99ms +step:1293/1670 train_time:126705ms step_avg:97.99ms +step:1294/1670 train_time:126803ms step_avg:97.99ms +step:1295/1670 train_time:126902ms step_avg:97.99ms +step:1296/1670 train_time:127001ms step_avg:97.99ms +step:1297/1670 train_time:127099ms step_avg:97.99ms +step:1298/1670 train_time:127197ms step_avg:97.99ms +step:1299/1670 train_time:127295ms step_avg:97.99ms +step:1300/1670 train_time:127392ms step_avg:97.99ms +step:1301/1670 train_time:127489ms step_avg:97.99ms +step:1302/1670 train_time:127586ms step_avg:97.99ms +step:1303/1670 train_time:127682ms step_avg:97.99ms +step:1304/1670 train_time:127780ms step_avg:97.99ms +step:1305/1670 train_time:127879ms step_avg:97.99ms +step:1306/1670 train_time:127980ms step_avg:97.99ms +step:1307/1670 train_time:128079ms step_avg:97.99ms +step:1308/1670 train_time:128177ms step_avg:97.99ms +step:1309/1670 train_time:128275ms step_avg:97.99ms +step:1310/1670 
train_time:128372ms step_avg:97.99ms +step:1311/1670 train_time:128469ms step_avg:97.99ms +step:1312/1670 train_time:128567ms step_avg:97.99ms +step:1313/1670 train_time:128664ms step_avg:97.99ms +step:1314/1670 train_time:128760ms step_avg:97.99ms +step:1315/1670 train_time:128858ms step_avg:97.99ms +step:1316/1670 train_time:128957ms step_avg:97.99ms +step:1317/1670 train_time:129058ms step_avg:97.99ms +step:1318/1670 train_time:129157ms step_avg:97.99ms +step:1319/1670 train_time:129256ms step_avg:98.00ms +step:1320/1670 train_time:129354ms step_avg:98.00ms +step:1321/1670 train_time:129451ms step_avg:97.99ms +step:1322/1670 train_time:129549ms step_avg:97.99ms +step:1323/1670 train_time:129647ms step_avg:97.99ms +step:1324/1670 train_time:129744ms step_avg:97.99ms +step:1325/1670 train_time:129841ms step_avg:97.99ms +step:1326/1670 train_time:129938ms step_avg:97.99ms +step:1327/1670 train_time:130037ms step_avg:97.99ms +step:1328/1670 train_time:130136ms step_avg:97.99ms +step:1329/1670 train_time:130234ms step_avg:97.99ms +step:1330/1670 train_time:130332ms step_avg:97.99ms +step:1331/1670 train_time:130430ms step_avg:97.99ms +step:1332/1670 train_time:130527ms step_avg:97.99ms +step:1333/1670 train_time:130623ms step_avg:97.99ms +step:1334/1670 train_time:130720ms step_avg:97.99ms +step:1335/1670 train_time:130817ms step_avg:97.99ms +step:1336/1670 train_time:130915ms step_avg:97.99ms +step:1337/1670 train_time:131014ms step_avg:97.99ms +step:1338/1670 train_time:131114ms step_avg:97.99ms +step:1339/1670 train_time:131212ms step_avg:97.99ms +step:1340/1670 train_time:131310ms step_avg:97.99ms +step:1341/1670 train_time:131409ms step_avg:97.99ms +step:1342/1670 train_time:131507ms step_avg:97.99ms +step:1343/1670 train_time:131604ms step_avg:97.99ms +step:1344/1670 train_time:131701ms step_avg:97.99ms +step:1345/1670 train_time:131798ms step_avg:97.99ms +step:1346/1670 train_time:131897ms step_avg:97.99ms +step:1347/1670 train_time:131994ms step_avg:97.99ms +step:1348/1670 train_time:132093ms step_avg:97.99ms +step:1349/1670 train_time:132190ms step_avg:97.99ms +step:1350/1670 train_time:132288ms step_avg:97.99ms +step:1351/1670 train_time:132386ms step_avg:97.99ms +step:1352/1670 train_time:132483ms step_avg:97.99ms +step:1353/1670 train_time:132581ms step_avg:97.99ms +step:1354/1670 train_time:132678ms step_avg:97.99ms +step:1355/1670 train_time:132776ms step_avg:97.99ms +step:1356/1670 train_time:132875ms step_avg:97.99ms +step:1357/1670 train_time:132973ms step_avg:97.99ms +step:1358/1670 train_time:133071ms step_avg:97.99ms +step:1359/1670 train_time:133169ms step_avg:97.99ms +step:1360/1670 train_time:133267ms step_avg:97.99ms +step:1361/1670 train_time:133365ms step_avg:97.99ms +step:1362/1670 train_time:133462ms step_avg:97.99ms +step:1363/1670 train_time:133559ms step_avg:97.99ms +step:1364/1670 train_time:133657ms step_avg:97.99ms +step:1365/1670 train_time:133754ms step_avg:97.99ms +step:1366/1670 train_time:133852ms step_avg:97.99ms +step:1367/1670 train_time:133949ms step_avg:97.99ms +step:1368/1670 train_time:134047ms step_avg:97.99ms +step:1369/1670 train_time:134144ms step_avg:97.99ms +step:1370/1670 train_time:134241ms step_avg:97.99ms +step:1371/1670 train_time:134339ms step_avg:97.99ms +step:1372/1670 train_time:134437ms step_avg:97.99ms +step:1373/1670 train_time:134536ms step_avg:97.99ms +step:1374/1670 train_time:134634ms step_avg:97.99ms +step:1375/1670 train_time:134733ms step_avg:97.99ms +step:1375/1670 val_loss:3.3445 train_time:134829ms step_avg:98.06ms 
+step:1376/1670 train_time:134853ms step_avg:98.00ms +step:1377/1670 train_time:134934ms step_avg:97.99ms +step:1378/1670 train_time:135033ms step_avg:97.99ms +step:1379/1670 train_time:135130ms step_avg:97.99ms +step:1380/1670 train_time:135227ms step_avg:97.99ms +step:1381/1670 train_time:135324ms step_avg:97.99ms +step:1382/1670 train_time:135421ms step_avg:97.99ms +step:1383/1670 train_time:135518ms step_avg:97.99ms +step:1384/1670 train_time:135616ms step_avg:97.99ms +step:1385/1670 train_time:135714ms step_avg:97.99ms +step:1386/1670 train_time:135813ms step_avg:97.99ms +step:1387/1670 train_time:135913ms step_avg:97.99ms +step:1388/1670 train_time:136012ms step_avg:97.99ms +step:1389/1670 train_time:136110ms step_avg:97.99ms +step:1390/1670 train_time:136208ms step_avg:97.99ms +step:1391/1670 train_time:136305ms step_avg:97.99ms +step:1392/1670 train_time:136403ms step_avg:97.99ms +step:1393/1670 train_time:136500ms step_avg:97.99ms +step:1394/1670 train_time:136597ms step_avg:97.99ms +step:1395/1670 train_time:136695ms step_avg:97.99ms +step:1396/1670 train_time:136794ms step_avg:97.99ms +step:1397/1670 train_time:136893ms step_avg:97.99ms +step:1398/1670 train_time:136990ms step_avg:97.99ms +step:1399/1670 train_time:137089ms step_avg:97.99ms +step:1400/1670 train_time:137186ms step_avg:97.99ms +step:1401/1670 train_time:137284ms step_avg:97.99ms +step:1402/1670 train_time:137382ms step_avg:97.99ms +step:1403/1670 train_time:137479ms step_avg:97.99ms +step:1404/1670 train_time:137576ms step_avg:97.99ms +step:1405/1670 train_time:137674ms step_avg:97.99ms +step:1406/1670 train_time:137772ms step_avg:97.99ms +step:1407/1670 train_time:137870ms step_avg:97.99ms +step:1408/1670 train_time:137968ms step_avg:97.99ms +step:1409/1670 train_time:138066ms step_avg:97.99ms +step:1410/1670 train_time:138164ms step_avg:97.99ms +step:1411/1670 train_time:138263ms step_avg:97.99ms +step:1412/1670 train_time:138360ms step_avg:97.99ms +step:1413/1670 train_time:138457ms step_avg:97.99ms +step:1414/1670 train_time:138555ms step_avg:97.99ms +step:1415/1670 train_time:138652ms step_avg:97.99ms +step:1416/1670 train_time:138749ms step_avg:97.99ms +step:1417/1670 train_time:138847ms step_avg:97.99ms +step:1418/1670 train_time:138945ms step_avg:97.99ms +step:1419/1670 train_time:139043ms step_avg:97.99ms +step:1420/1670 train_time:139143ms step_avg:97.99ms +step:1421/1670 train_time:139241ms step_avg:97.99ms +step:1422/1670 train_time:139340ms step_avg:97.99ms +step:1423/1670 train_time:139438ms step_avg:97.99ms +step:1424/1670 train_time:139535ms step_avg:97.99ms +step:1425/1670 train_time:139632ms step_avg:97.99ms +step:1426/1670 train_time:139730ms step_avg:97.99ms +step:1427/1670 train_time:139828ms step_avg:97.99ms +step:1428/1670 train_time:139925ms step_avg:97.99ms +step:1429/1670 train_time:140024ms step_avg:97.99ms +step:1430/1670 train_time:140122ms step_avg:97.99ms +step:1431/1670 train_time:140221ms step_avg:97.99ms +step:1432/1670 train_time:140318ms step_avg:97.99ms +step:1433/1670 train_time:140416ms step_avg:97.99ms +step:1434/1670 train_time:140513ms step_avg:97.99ms +step:1435/1670 train_time:140610ms step_avg:97.99ms +step:1436/1670 train_time:140708ms step_avg:97.99ms +step:1437/1670 train_time:140806ms step_avg:97.99ms +step:1438/1670 train_time:140903ms step_avg:97.99ms +step:1439/1670 train_time:141002ms step_avg:97.99ms +step:1440/1670 train_time:141099ms step_avg:97.99ms +step:1441/1670 train_time:141198ms step_avg:97.99ms +step:1442/1670 train_time:141296ms step_avg:97.99ms 
+step:1443/1670 train_time:141393ms step_avg:97.99ms +step:1444/1670 train_time:141490ms step_avg:97.98ms +step:1445/1670 train_time:141588ms step_avg:97.98ms +step:1446/1670 train_time:141685ms step_avg:97.98ms +step:1447/1670 train_time:141784ms step_avg:97.99ms +step:1448/1670 train_time:141883ms step_avg:97.99ms +step:1449/1670 train_time:141981ms step_avg:97.99ms +step:1450/1670 train_time:142080ms step_avg:97.99ms +step:1451/1670 train_time:142178ms step_avg:97.99ms +step:1452/1670 train_time:142276ms step_avg:97.99ms +step:1453/1670 train_time:142373ms step_avg:97.99ms +step:1454/1670 train_time:142471ms step_avg:97.99ms +step:1455/1670 train_time:142568ms step_avg:97.99ms +step:1456/1670 train_time:142666ms step_avg:97.99ms +step:1457/1670 train_time:142765ms step_avg:97.99ms +step:1458/1670 train_time:142862ms step_avg:97.98ms +step:1459/1670 train_time:142960ms step_avg:97.98ms +step:1460/1670 train_time:143058ms step_avg:97.98ms +step:1461/1670 train_time:143156ms step_avg:97.98ms +step:1462/1670 train_time:143253ms step_avg:97.98ms +step:1463/1670 train_time:143350ms step_avg:97.98ms +step:1464/1670 train_time:143448ms step_avg:97.98ms +step:1465/1670 train_time:143546ms step_avg:97.98ms +step:1466/1670 train_time:143644ms step_avg:97.98ms +step:1467/1670 train_time:143743ms step_avg:97.98ms +step:1468/1670 train_time:143841ms step_avg:97.98ms +step:1469/1670 train_time:143939ms step_avg:97.98ms +step:1470/1670 train_time:144036ms step_avg:97.98ms +step:1471/1670 train_time:144133ms step_avg:97.98ms +step:1472/1670 train_time:144231ms step_avg:97.98ms +step:1473/1670 train_time:144328ms step_avg:97.98ms +step:1474/1670 train_time:144426ms step_avg:97.98ms +step:1475/1670 train_time:144525ms step_avg:97.98ms +step:1476/1670 train_time:144623ms step_avg:97.98ms +step:1477/1670 train_time:144722ms step_avg:97.98ms +step:1478/1670 train_time:144821ms step_avg:97.98ms +step:1479/1670 train_time:144918ms step_avg:97.98ms +step:1480/1670 train_time:145017ms step_avg:97.98ms +step:1481/1670 train_time:145114ms step_avg:97.98ms +step:1482/1670 train_time:145211ms step_avg:97.98ms +step:1483/1670 train_time:145309ms step_avg:97.98ms +step:1484/1670 train_time:145406ms step_avg:97.98ms +step:1485/1670 train_time:145675ms step_avg:98.10ms +step:1486/1670 train_time:145879ms step_avg:98.17ms +step:1487/1670 train_time:145976ms step_avg:98.17ms +step:1488/1670 train_time:146072ms step_avg:98.17ms +step:1489/1670 train_time:146168ms step_avg:98.17ms +step:1490/1670 train_time:146265ms step_avg:98.16ms +step:1491/1670 train_time:146361ms step_avg:98.16ms +step:1492/1670 train_time:146457ms step_avg:98.16ms +step:1493/1670 train_time:146554ms step_avg:98.16ms +step:1494/1670 train_time:146650ms step_avg:98.16ms +step:1495/1670 train_time:146750ms step_avg:98.16ms +step:1496/1670 train_time:146852ms step_avg:98.16ms +step:1497/1670 train_time:146951ms step_avg:98.16ms +step:1498/1670 train_time:147049ms step_avg:98.16ms +step:1499/1670 train_time:147146ms step_avg:98.16ms +step:1500/1670 train_time:147243ms step_avg:98.16ms +step:1500/1670 val_loss:3.3122 train_time:147340ms step_avg:98.23ms +step:1501/1670 train_time:147363ms step_avg:98.18ms +step:1502/1670 train_time:147444ms step_avg:98.17ms +step:1503/1670 train_time:147548ms step_avg:98.17ms +step:1504/1670 train_time:147646ms step_avg:98.17ms +step:1505/1670 train_time:147744ms step_avg:98.17ms +step:1506/1670 train_time:147841ms step_avg:98.17ms +step:1507/1670 train_time:147937ms step_avg:98.17ms +step:1508/1670 train_time:148034ms 
step_avg:98.17ms +step:1509/1670 train_time:148131ms step_avg:98.16ms +step:1510/1670 train_time:148228ms step_avg:98.16ms +step:1511/1670 train_time:148327ms step_avg:98.16ms +step:1512/1670 train_time:148430ms step_avg:98.17ms +step:1513/1670 train_time:148530ms step_avg:98.17ms +step:1514/1670 train_time:148630ms step_avg:98.17ms +step:1515/1670 train_time:148727ms step_avg:98.17ms +step:1516/1670 train_time:148825ms step_avg:98.17ms +step:1517/1670 train_time:148922ms step_avg:98.17ms +step:1518/1670 train_time:149019ms step_avg:98.17ms +step:1519/1670 train_time:149115ms step_avg:98.17ms +step:1520/1670 train_time:149211ms step_avg:98.17ms +step:1521/1670 train_time:149310ms step_avg:98.17ms +step:1522/1670 train_time:149411ms step_avg:98.17ms +step:1523/1670 train_time:149511ms step_avg:98.17ms +step:1524/1670 train_time:149610ms step_avg:98.17ms +step:1525/1670 train_time:149709ms step_avg:98.17ms +step:1526/1670 train_time:149808ms step_avg:98.17ms +step:1527/1670 train_time:149905ms step_avg:98.17ms +step:1528/1670 train_time:150003ms step_avg:98.17ms +step:1529/1670 train_time:150099ms step_avg:98.17ms +step:1530/1670 train_time:150196ms step_avg:98.17ms +step:1531/1670 train_time:150293ms step_avg:98.17ms +step:1532/1670 train_time:150392ms step_avg:98.17ms +step:1533/1670 train_time:150491ms step_avg:98.17ms +step:1534/1670 train_time:150590ms step_avg:98.17ms +step:1535/1670 train_time:150690ms step_avg:98.17ms +step:1536/1670 train_time:150788ms step_avg:98.17ms +step:1537/1670 train_time:150887ms step_avg:98.17ms +step:1538/1670 train_time:150985ms step_avg:98.17ms +step:1539/1670 train_time:151083ms step_avg:98.17ms +step:1540/1670 train_time:151180ms step_avg:98.17ms +step:1541/1670 train_time:151278ms step_avg:98.17ms +step:1542/1670 train_time:151376ms step_avg:98.17ms +step:1543/1670 train_time:151474ms step_avg:98.17ms +step:1544/1670 train_time:151570ms step_avg:98.17ms +step:1545/1670 train_time:151669ms step_avg:98.17ms +step:1546/1670 train_time:151767ms step_avg:98.17ms +step:1547/1670 train_time:151864ms step_avg:98.17ms +step:1548/1670 train_time:151963ms step_avg:98.17ms +step:1549/1670 train_time:152061ms step_avg:98.17ms +step:1550/1670 train_time:152159ms step_avg:98.17ms +step:1551/1670 train_time:152256ms step_avg:98.17ms +step:1552/1670 train_time:152354ms step_avg:98.17ms +step:1553/1670 train_time:152452ms step_avg:98.17ms +step:1554/1670 train_time:152550ms step_avg:98.17ms +step:1555/1670 train_time:152649ms step_avg:98.17ms +step:1556/1670 train_time:152747ms step_avg:98.17ms +step:1557/1670 train_time:152846ms step_avg:98.17ms +step:1558/1670 train_time:152943ms step_avg:98.17ms +step:1559/1670 train_time:153040ms step_avg:98.17ms +step:1560/1670 train_time:153138ms step_avg:98.17ms +step:1561/1670 train_time:153235ms step_avg:98.16ms +step:1562/1670 train_time:153332ms step_avg:98.16ms +step:1563/1670 train_time:153431ms step_avg:98.16ms +step:1564/1670 train_time:153529ms step_avg:98.16ms +step:1565/1670 train_time:153628ms step_avg:98.16ms +step:1566/1670 train_time:153726ms step_avg:98.16ms +step:1567/1670 train_time:153824ms step_avg:98.16ms +step:1568/1670 train_time:153921ms step_avg:98.16ms +step:1569/1670 train_time:154020ms step_avg:98.16ms +step:1570/1670 train_time:154116ms step_avg:98.16ms +step:1571/1670 train_time:154212ms step_avg:98.16ms +step:1572/1670 train_time:154310ms step_avg:98.16ms +step:1573/1670 train_time:154408ms step_avg:98.16ms +step:1574/1670 train_time:154507ms step_avg:98.16ms +step:1575/1670 train_time:154606ms 
step_avg:98.16ms +step:1576/1670 train_time:154704ms step_avg:98.16ms +step:1577/1670 train_time:154801ms step_avg:98.16ms +step:1578/1670 train_time:154898ms step_avg:98.16ms +step:1579/1670 train_time:154996ms step_avg:98.16ms +step:1580/1670 train_time:155094ms step_avg:98.16ms +step:1581/1670 train_time:155193ms step_avg:98.16ms +step:1582/1670 train_time:155290ms step_avg:98.16ms +step:1583/1670 train_time:155388ms step_avg:98.16ms +step:1584/1670 train_time:155486ms step_avg:98.16ms +step:1585/1670 train_time:155584ms step_avg:98.16ms +step:1586/1670 train_time:155682ms step_avg:98.16ms +step:1587/1670 train_time:155779ms step_avg:98.16ms +step:1588/1670 train_time:155876ms step_avg:98.16ms +step:1589/1670 train_time:155974ms step_avg:98.16ms +step:1590/1670 train_time:156071ms step_avg:98.16ms +step:1591/1670 train_time:156170ms step_avg:98.16ms +step:1592/1670 train_time:156268ms step_avg:98.16ms +step:1593/1670 train_time:156366ms step_avg:98.16ms +step:1594/1670 train_time:156463ms step_avg:98.16ms +step:1595/1670 train_time:156561ms step_avg:98.16ms +step:1596/1670 train_time:156658ms step_avg:98.16ms +step:1597/1670 train_time:156756ms step_avg:98.16ms +step:1598/1670 train_time:156854ms step_avg:98.16ms +step:1599/1670 train_time:156951ms step_avg:98.16ms +step:1600/1670 train_time:157050ms step_avg:98.16ms +step:1601/1670 train_time:157149ms step_avg:98.16ms +step:1602/1670 train_time:157246ms step_avg:98.16ms +step:1603/1670 train_time:157344ms step_avg:98.16ms +step:1604/1670 train_time:157441ms step_avg:98.16ms +step:1605/1670 train_time:157539ms step_avg:98.16ms +step:1606/1670 train_time:157637ms step_avg:98.15ms +step:1607/1670 train_time:157733ms step_avg:98.15ms +step:1608/1670 train_time:157831ms step_avg:98.15ms +step:1609/1670 train_time:157929ms step_avg:98.15ms +step:1610/1670 train_time:158028ms step_avg:98.15ms +step:1611/1670 train_time:158126ms step_avg:98.15ms +step:1612/1670 train_time:158224ms step_avg:98.15ms +step:1613/1670 train_time:158322ms step_avg:98.15ms +step:1614/1670 train_time:158420ms step_avg:98.15ms +step:1615/1670 train_time:158518ms step_avg:98.15ms +step:1616/1670 train_time:158616ms step_avg:98.15ms +step:1617/1670 train_time:158713ms step_avg:98.15ms +step:1618/1670 train_time:158812ms step_avg:98.15ms +step:1619/1670 train_time:158910ms step_avg:98.15ms +step:1620/1670 train_time:159008ms step_avg:98.15ms +step:1621/1670 train_time:159106ms step_avg:98.15ms +step:1622/1670 train_time:159204ms step_avg:98.15ms +step:1623/1670 train_time:159302ms step_avg:98.15ms +step:1624/1670 train_time:159399ms step_avg:98.15ms +step:1625/1670 train_time:159497ms step_avg:98.15ms +step:1625/1670 val_loss:3.2853 train_time:159593ms step_avg:98.21ms +step:1626/1670 train_time:159619ms step_avg:98.17ms +step:1627/1670 train_time:159701ms step_avg:98.16ms +step:1628/1670 train_time:159801ms step_avg:98.16ms +step:1629/1670 train_time:159899ms step_avg:98.16ms +step:1630/1670 train_time:159996ms step_avg:98.16ms +step:1631/1670 train_time:160093ms step_avg:98.16ms +step:1632/1670 train_time:160190ms step_avg:98.16ms +step:1633/1670 train_time:160286ms step_avg:98.15ms +step:1634/1670 train_time:160384ms step_avg:98.15ms +step:1635/1670 train_time:160481ms step_avg:98.15ms +step:1636/1670 train_time:160582ms step_avg:98.16ms +step:1637/1670 train_time:160683ms step_avg:98.16ms +step:1638/1670 train_time:160782ms step_avg:98.16ms +step:1639/1670 train_time:160879ms step_avg:98.16ms +step:1640/1670 train_time:160977ms step_avg:98.16ms +step:1641/1670 
train_time:161076ms step_avg:98.16ms +step:1642/1670 train_time:161175ms step_avg:98.16ms +step:1643/1670 train_time:161273ms step_avg:98.16ms +step:1644/1670 train_time:161371ms step_avg:98.16ms +step:1645/1670 train_time:161467ms step_avg:98.16ms +step:1646/1670 train_time:161564ms step_avg:98.16ms +step:1647/1670 train_time:161664ms step_avg:98.16ms +step:1648/1670 train_time:161763ms step_avg:98.16ms +step:1649/1670 train_time:161861ms step_avg:98.16ms +step:1650/1670 train_time:161959ms step_avg:98.16ms +step:1651/1670 train_time:162057ms step_avg:98.16ms +step:1652/1670 train_time:162156ms step_avg:98.16ms +step:1653/1670 train_time:162254ms step_avg:98.16ms +step:1654/1670 train_time:162351ms step_avg:98.16ms +step:1655/1670 train_time:162448ms step_avg:98.16ms +step:1656/1670 train_time:162547ms step_avg:98.16ms +step:1657/1670 train_time:162645ms step_avg:98.16ms +step:1658/1670 train_time:162743ms step_avg:98.16ms +step:1659/1670 train_time:162841ms step_avg:98.16ms +step:1660/1670 train_time:162939ms step_avg:98.16ms +step:1661/1670 train_time:163039ms step_avg:98.16ms +step:1662/1670 train_time:163136ms step_avg:98.16ms +step:1663/1670 train_time:163235ms step_avg:98.16ms +step:1664/1670 train_time:163333ms step_avg:98.16ms +step:1665/1670 train_time:163431ms step_avg:98.16ms +step:1666/1670 train_time:163529ms step_avg:98.16ms +step:1667/1670 train_time:163627ms step_avg:98.16ms +step:1668/1670 train_time:163725ms step_avg:98.16ms +step:1669/1670 train_time:163822ms step_avg:98.16ms +step:1670/1670 train_time:163919ms step_avg:98.15ms +step:1670/1670 val_loss:3.2774 train_time:164016ms step_avg:98.21ms +peak memory allocated: 34000 MiB reserved: 49496 MiB diff --git a/records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt b/records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt new file mode 100644 index 000000000..a0c9c3b35 --- /dev/null +++ b/records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, 
x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, 
GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = 
torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
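+ # Note on the sharding scheme: params in each group are processed in chunks of
+ # world_size, with rank r owning param (base_i + r) in each chunk. Gradients are
+ # averaged across ranks via async reduce_scatter; each rank then applies the
+ # momentum update and Newton-Schulz orthogonalization only to the params it owns,
+ # and the updated params are all_gathered so every rank ends the step with
+ # identical weights.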
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 20:04:19 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 29C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 33C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 28C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 53094 C /usr/bin/python 0MiB | +| 0 N/A N/A 53095 C /usr/bin/python 0MiB | +| 0 N/A N/A 53096 C /usr/bin/python 0MiB | +| 0 N/A N/A 53097 C /usr/bin/python 0MiB | +| 0 N/A N/A 53098 C /usr/bin/python 0MiB | +| 0 N/A N/A 53099 C /usr/bin/python 0MiB | +| 0 N/A N/A 53100 C /usr/bin/python 0MiB | +| 0 N/A N/A 53101 C /usr/bin/python 0MiB | +| 1 N/A N/A 53095 C /usr/bin/python 0MiB | +| 2 N/A N/A 53096 C /usr/bin/python 0MiB | +| 3 N/A N/A 53097 C /usr/bin/python 0MiB | +| 4 N/A N/A 53098 C /usr/bin/python 0MiB | +| 5 N/A N/A 53099 C /usr/bin/python 0MiB | +| 6 N/A N/A 53100 C /usr/bin/python 0MiB | +| 7 N/A N/A 53101 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:375ms step_avg:374.99ms +step:2/1670 train_time:398ms step_avg:199.22ms +step:3/1670 train_time:469ms step_avg:156.43ms +step:4/1670 train_time:563ms step_avg:140.76ms +step:5/1670 train_time:658ms step_avg:131.56ms +step:6/1670 train_time:752ms step_avg:125.32ms +step:7/1670 train_time:847ms step_avg:120.96ms +step:8/1670 train_time:942ms step_avg:117.75ms +step:9/1670 train_time:1037ms 
step_avg:115.24ms
[per-step timing entries for steps 10-1442 elided; each reads "step:N/1670 train_time:Tms step_avg:Ams", one entry per line, with step_avg holding near 97ms throughout; validation checkpoints retained below]
+step:125/1670 val_loss:4.2921 train_time:12238ms step_avg:97.91ms
+step:250/1670 val_loss:3.9671 train_time:24382ms step_avg:97.53ms
+step:375/1670 val_loss:3.8203 train_time:36343ms step_avg:96.91ms
+step:500/1670 val_loss:3.7189 train_time:48485ms step_avg:96.97ms
+step:625/1670 val_loss:3.6175 train_time:60550ms step_avg:96.88ms
+step:750/1670 val_loss:3.5636 train_time:72942ms step_avg:97.26ms
+step:875/1670 val_loss:3.5198 train_time:85341ms step_avg:97.53ms
+step:1000/1670 val_loss:3.4791 train_time:97506ms step_avg:97.51ms
+step:1125/1670 val_loss:3.4256 train_time:109807ms step_avg:97.61ms
+step:1250/1670 val_loss:3.3835 train_time:122044ms step_avg:97.63ms
+step:1375/1670 val_loss:3.3452 train_time:134549ms step_avg:97.85ms
+step:1443/1670
train_time:141118ms step_avg:97.79ms +step:1444/1670 train_time:141216ms step_avg:97.79ms +step:1445/1670 train_time:141313ms step_avg:97.79ms +step:1446/1670 train_time:141412ms step_avg:97.79ms +step:1447/1670 train_time:141509ms step_avg:97.80ms +step:1448/1670 train_time:141607ms step_avg:97.80ms +step:1449/1670 train_time:141705ms step_avg:97.80ms +step:1450/1670 train_time:141803ms step_avg:97.80ms +step:1451/1670 train_time:141900ms step_avg:97.79ms +step:1452/1670 train_time:141998ms step_avg:97.79ms +step:1453/1670 train_time:142096ms step_avg:97.79ms +step:1454/1670 train_time:142194ms step_avg:97.80ms +step:1455/1670 train_time:142292ms step_avg:97.79ms +step:1456/1670 train_time:142389ms step_avg:97.79ms +step:1457/1670 train_time:142488ms step_avg:97.80ms +step:1458/1670 train_time:142586ms step_avg:97.80ms +step:1459/1670 train_time:142687ms step_avg:97.80ms +step:1460/1670 train_time:142784ms step_avg:97.80ms +step:1461/1670 train_time:142883ms step_avg:97.80ms +step:1462/1670 train_time:142982ms step_avg:97.80ms +step:1463/1670 train_time:143080ms step_avg:97.80ms +step:1464/1670 train_time:143179ms step_avg:97.80ms +step:1465/1670 train_time:143276ms step_avg:97.80ms +step:1466/1670 train_time:143375ms step_avg:97.80ms +step:1467/1670 train_time:143473ms step_avg:97.80ms +step:1468/1670 train_time:143570ms step_avg:97.80ms +step:1469/1670 train_time:143668ms step_avg:97.80ms +step:1470/1670 train_time:143766ms step_avg:97.80ms +step:1471/1670 train_time:143865ms step_avg:97.80ms +step:1472/1670 train_time:143963ms step_avg:97.80ms +step:1473/1670 train_time:144061ms step_avg:97.80ms +step:1474/1670 train_time:144160ms step_avg:97.80ms +step:1475/1670 train_time:144258ms step_avg:97.80ms +step:1476/1670 train_time:144356ms step_avg:97.80ms +step:1477/1670 train_time:144453ms step_avg:97.80ms +step:1478/1670 train_time:144550ms step_avg:97.80ms +step:1479/1670 train_time:144648ms step_avg:97.80ms +step:1480/1670 train_time:144745ms step_avg:97.80ms +step:1481/1670 train_time:144843ms step_avg:97.80ms +step:1482/1670 train_time:144941ms step_avg:97.80ms +step:1483/1670 train_time:145039ms step_avg:97.80ms +step:1484/1670 train_time:145136ms step_avg:97.80ms +step:1485/1670 train_time:145417ms step_avg:97.92ms +step:1486/1670 train_time:145490ms step_avg:97.91ms +step:1487/1670 train_time:145587ms step_avg:97.91ms +step:1488/1670 train_time:145684ms step_avg:97.91ms +step:1489/1670 train_time:145781ms step_avg:97.91ms +step:1490/1670 train_time:145877ms step_avg:97.90ms +step:1491/1670 train_time:145974ms step_avg:97.90ms +step:1492/1670 train_time:146070ms step_avg:97.90ms +step:1493/1670 train_time:146166ms step_avg:97.90ms +step:1494/1670 train_time:146264ms step_avg:97.90ms +step:1495/1670 train_time:146370ms step_avg:97.91ms +step:1496/1670 train_time:146469ms step_avg:97.91ms +step:1497/1670 train_time:146568ms step_avg:97.91ms +step:1498/1670 train_time:146665ms step_avg:97.91ms +step:1499/1670 train_time:146762ms step_avg:97.91ms +step:1500/1670 train_time:146859ms step_avg:97.91ms +step:1500/1670 val_loss:3.3124 train_time:146954ms step_avg:97.97ms +step:1501/1670 train_time:146978ms step_avg:97.92ms +step:1502/1670 train_time:147059ms step_avg:97.91ms +step:1503/1670 train_time:147158ms step_avg:97.91ms +step:1504/1670 train_time:147255ms step_avg:97.91ms +step:1505/1670 train_time:147353ms step_avg:97.91ms +step:1506/1670 train_time:147451ms step_avg:97.91ms +step:1507/1670 train_time:147549ms step_avg:97.91ms +step:1508/1670 train_time:147646ms step_avg:97.91ms 
+step:1509/1670 train_time:147742ms step_avg:97.91ms +step:1510/1670 train_time:147839ms step_avg:97.91ms +step:1511/1670 train_time:147939ms step_avg:97.91ms +step:1512/1670 train_time:148038ms step_avg:97.91ms +step:1513/1670 train_time:148137ms step_avg:97.91ms +step:1514/1670 train_time:148235ms step_avg:97.91ms +step:1515/1670 train_time:148333ms step_avg:97.91ms +step:1516/1670 train_time:148431ms step_avg:97.91ms +step:1517/1670 train_time:148528ms step_avg:97.91ms +step:1518/1670 train_time:148624ms step_avg:97.91ms +step:1519/1670 train_time:148721ms step_avg:97.91ms +step:1520/1670 train_time:148818ms step_avg:97.91ms +step:1521/1670 train_time:148917ms step_avg:97.91ms +step:1522/1670 train_time:149017ms step_avg:97.91ms +step:1523/1670 train_time:149116ms step_avg:97.91ms +step:1524/1670 train_time:149214ms step_avg:97.91ms +step:1525/1670 train_time:149311ms step_avg:97.91ms +step:1526/1670 train_time:149409ms step_avg:97.91ms +step:1527/1670 train_time:149508ms step_avg:97.91ms +step:1528/1670 train_time:149605ms step_avg:97.91ms +step:1529/1670 train_time:149702ms step_avg:97.91ms +step:1530/1670 train_time:149799ms step_avg:97.91ms +step:1531/1670 train_time:149896ms step_avg:97.91ms +step:1532/1670 train_time:149995ms step_avg:97.91ms +step:1533/1670 train_time:150095ms step_avg:97.91ms +step:1534/1670 train_time:150194ms step_avg:97.91ms +step:1535/1670 train_time:150292ms step_avg:97.91ms +step:1536/1670 train_time:150391ms step_avg:97.91ms +step:1537/1670 train_time:150489ms step_avg:97.91ms +step:1538/1670 train_time:150587ms step_avg:97.91ms +step:1539/1670 train_time:150685ms step_avg:97.91ms +step:1540/1670 train_time:150782ms step_avg:97.91ms +step:1541/1670 train_time:150879ms step_avg:97.91ms +step:1542/1670 train_time:150978ms step_avg:97.91ms +step:1543/1670 train_time:151075ms step_avg:97.91ms +step:1544/1670 train_time:151173ms step_avg:97.91ms +step:1545/1670 train_time:151272ms step_avg:97.91ms +step:1546/1670 train_time:151368ms step_avg:97.91ms +step:1547/1670 train_time:151466ms step_avg:97.91ms +step:1548/1670 train_time:151565ms step_avg:97.91ms +step:1549/1670 train_time:151663ms step_avg:97.91ms +step:1550/1670 train_time:151760ms step_avg:97.91ms +step:1551/1670 train_time:151857ms step_avg:97.91ms +step:1552/1670 train_time:151955ms step_avg:97.91ms +step:1553/1670 train_time:152053ms step_avg:97.91ms +step:1554/1670 train_time:152152ms step_avg:97.91ms +step:1555/1670 train_time:152250ms step_avg:97.91ms +step:1556/1670 train_time:152348ms step_avg:97.91ms +step:1557/1670 train_time:152445ms step_avg:97.91ms +step:1558/1670 train_time:152544ms step_avg:97.91ms +step:1559/1670 train_time:152641ms step_avg:97.91ms +step:1560/1670 train_time:152739ms step_avg:97.91ms +step:1561/1670 train_time:152837ms step_avg:97.91ms +step:1562/1670 train_time:152935ms step_avg:97.91ms +step:1563/1670 train_time:153033ms step_avg:97.91ms +step:1564/1670 train_time:153132ms step_avg:97.91ms +step:1565/1670 train_time:153230ms step_avg:97.91ms +step:1566/1670 train_time:153327ms step_avg:97.91ms +step:1567/1670 train_time:153424ms step_avg:97.91ms +step:1568/1670 train_time:153522ms step_avg:97.91ms +step:1569/1670 train_time:153620ms step_avg:97.91ms +step:1570/1670 train_time:153717ms step_avg:97.91ms +step:1571/1670 train_time:153814ms step_avg:97.91ms +step:1572/1670 train_time:153913ms step_avg:97.91ms +step:1573/1670 train_time:154011ms step_avg:97.91ms +step:1574/1670 train_time:154108ms step_avg:97.91ms +step:1575/1670 train_time:154206ms step_avg:97.91ms 
+step:1576/1670 train_time:154303ms step_avg:97.91ms +step:1577/1670 train_time:154400ms step_avg:97.91ms +step:1578/1670 train_time:154498ms step_avg:97.91ms +step:1579/1670 train_time:154597ms step_avg:97.91ms +step:1580/1670 train_time:154696ms step_avg:97.91ms +step:1581/1670 train_time:154794ms step_avg:97.91ms +step:1582/1670 train_time:154893ms step_avg:97.91ms +step:1583/1670 train_time:154991ms step_avg:97.91ms +step:1584/1670 train_time:155089ms step_avg:97.91ms +step:1585/1670 train_time:155186ms step_avg:97.91ms +step:1586/1670 train_time:155284ms step_avg:97.91ms +step:1587/1670 train_time:155381ms step_avg:97.91ms +step:1588/1670 train_time:155479ms step_avg:97.91ms +step:1589/1670 train_time:155576ms step_avg:97.91ms +step:1590/1670 train_time:155674ms step_avg:97.91ms +step:1591/1670 train_time:155773ms step_avg:97.91ms +step:1592/1670 train_time:155871ms step_avg:97.91ms +step:1593/1670 train_time:155969ms step_avg:97.91ms +step:1594/1670 train_time:156067ms step_avg:97.91ms +step:1595/1670 train_time:156164ms step_avg:97.91ms +step:1596/1670 train_time:156260ms step_avg:97.91ms +step:1597/1670 train_time:156358ms step_avg:97.91ms +step:1598/1670 train_time:156456ms step_avg:97.91ms +step:1599/1670 train_time:156553ms step_avg:97.91ms +step:1600/1670 train_time:156652ms step_avg:97.91ms +step:1601/1670 train_time:156751ms step_avg:97.91ms +step:1602/1670 train_time:156849ms step_avg:97.91ms +step:1603/1670 train_time:156948ms step_avg:97.91ms +step:1604/1670 train_time:157046ms step_avg:97.91ms +step:1605/1670 train_time:157143ms step_avg:97.91ms +step:1606/1670 train_time:157241ms step_avg:97.91ms +step:1607/1670 train_time:157338ms step_avg:97.91ms +step:1608/1670 train_time:157435ms step_avg:97.91ms +step:1609/1670 train_time:157533ms step_avg:97.91ms +step:1610/1670 train_time:157631ms step_avg:97.91ms +step:1611/1670 train_time:157729ms step_avg:97.91ms +step:1612/1670 train_time:157827ms step_avg:97.91ms +step:1613/1670 train_time:157925ms step_avg:97.91ms +step:1614/1670 train_time:158022ms step_avg:97.91ms +step:1615/1670 train_time:158120ms step_avg:97.91ms +step:1616/1670 train_time:158217ms step_avg:97.91ms +step:1617/1670 train_time:158315ms step_avg:97.91ms +step:1618/1670 train_time:158414ms step_avg:97.91ms +step:1619/1670 train_time:158512ms step_avg:97.91ms +step:1620/1670 train_time:158611ms step_avg:97.91ms +step:1621/1670 train_time:158708ms step_avg:97.91ms +step:1622/1670 train_time:158806ms step_avg:97.91ms +step:1623/1670 train_time:158903ms step_avg:97.91ms +step:1624/1670 train_time:159001ms step_avg:97.91ms +step:1625/1670 train_time:159098ms step_avg:97.91ms +step:1625/1670 val_loss:3.2859 train_time:159195ms step_avg:97.97ms +step:1626/1670 train_time:159218ms step_avg:97.92ms +step:1627/1670 train_time:159303ms step_avg:97.91ms +step:1628/1670 train_time:159404ms step_avg:97.91ms +step:1629/1670 train_time:159502ms step_avg:97.91ms +step:1630/1670 train_time:159599ms step_avg:97.91ms +step:1631/1670 train_time:159695ms step_avg:97.91ms +step:1632/1670 train_time:159792ms step_avg:97.91ms +step:1633/1670 train_time:159889ms step_avg:97.91ms +step:1634/1670 train_time:159986ms step_avg:97.91ms +step:1635/1670 train_time:160083ms step_avg:97.91ms +step:1636/1670 train_time:160182ms step_avg:97.91ms +step:1637/1670 train_time:160285ms step_avg:97.91ms +step:1638/1670 train_time:160386ms step_avg:97.92ms +step:1639/1670 train_time:160484ms step_avg:97.92ms +step:1640/1670 train_time:160583ms step_avg:97.92ms +step:1641/1670 train_time:160681ms 
step_avg:97.92ms +step:1642/1670 train_time:160779ms step_avg:97.92ms +step:1643/1670 train_time:160877ms step_avg:97.92ms +step:1644/1670 train_time:160973ms step_avg:97.92ms +step:1645/1670 train_time:161070ms step_avg:97.92ms +step:1646/1670 train_time:161168ms step_avg:97.91ms +step:1647/1670 train_time:161267ms step_avg:97.92ms +step:1648/1670 train_time:161368ms step_avg:97.92ms +step:1649/1670 train_time:161466ms step_avg:97.92ms +step:1650/1670 train_time:161564ms step_avg:97.92ms +step:1651/1670 train_time:161662ms step_avg:97.92ms +step:1652/1670 train_time:161761ms step_avg:97.92ms +step:1653/1670 train_time:161858ms step_avg:97.92ms +step:1654/1670 train_time:161954ms step_avg:97.92ms +step:1655/1670 train_time:162051ms step_avg:97.92ms +step:1656/1670 train_time:162150ms step_avg:97.92ms +step:1657/1670 train_time:162248ms step_avg:97.92ms +step:1658/1670 train_time:162345ms step_avg:97.92ms +step:1659/1670 train_time:162444ms step_avg:97.92ms +step:1660/1670 train_time:162543ms step_avg:97.92ms +step:1661/1670 train_time:162641ms step_avg:97.92ms +step:1662/1670 train_time:162738ms step_avg:97.92ms +step:1663/1670 train_time:162836ms step_avg:97.92ms +step:1664/1670 train_time:162933ms step_avg:97.92ms +step:1665/1670 train_time:163030ms step_avg:97.92ms +step:1666/1670 train_time:163128ms step_avg:97.92ms +step:1667/1670 train_time:163226ms step_avg:97.92ms +step:1668/1670 train_time:163325ms step_avg:97.92ms +step:1669/1670 train_time:163423ms step_avg:97.92ms +step:1670/1670 train_time:163522ms step_avg:97.92ms +step:1670/1670 val_loss:3.2780 train_time:163618ms step_avg:97.97ms +peak memory allocated: 34000 MiB reserved: 49576 MiB
diff --git a/records/090325_FA3/media/attn_speed_vs_batch_s1024_ws384.png b/records/090325_FA3/media/attn_speed_vs_batch_s1024_ws384.png
new file mode 100644
index 0000000000000000000000000000000000000000..eb60b4b854662157febe6d2b870004dace5e7778
GIT binary patch literal 105002
[base85-encoded binary PNG payload omitted: the image added here is the media file attn_speed_vs_batch_s1024_ws384.png]
zzrWGxFu&HHrFqGIWn}ov0Wqy$e=KHyA-9_Avm|sdZR^USFQT04IiKpD5~*luX+6+{ z{RCCvZ zJ=J`x$aT}aEnaeTW({-r>3f{GTi(6TC&hNAUPkd5%2d)_b3i@BeS6&hx`IM_VOCyV zmu<{KZ?^C2*Hq|l%75)Gm5ffjye0AY^Sy$Pk|H8!pl@iDJ95XlEsH!md!<-{)7kV?zS#wQwrBggZhLZkSNkKJ#JGW^gK$&D7OmACpCtQ&!kcW<^f z3nODqINx2$qnJF?PWpZ;O$K~pu!)5_Vy~E%b+)$?b88nriH~Q12LQDb z{_-VvQ&NU`$P$!EU3c;|9L$RsF9bOSdG#u&4}Y(CWiJzvl3sIm7KXkW6cqG4Jltn! zNT+pueX_vh@Zh)e!OpCThK5;L=hv?k*;++D&`U2~y7cJRqGnn~5L{fv)`(kkTiYkN zcg>~zEd6Twm5Pl}>=*fP)!2b9|4~Dqg~I`K;|bWW_^X{8seKEN#_Nv0d;k6j-}$K# zuR{;`#J<*;XiZJc?XBv=keC>{Ru9+85nHaS$>e2aW!IIJiz|55sFE)S1_c#l8i^Tr zhV<`tPbexVC^-N6Zpz(BZXm&6IZ*f*_JFOO-HTVR@Z?nt1ZdW#U+XhpzI@8Ma#Jlf z%r_uFhls)(&GESpof01u4>p$usd%flnYD@^J}BcA7e5QPS~Jrgje~){-mv)n1HVOY zmZ;=%_&KH3AMfL}tioW>UAo`q&8iU5D-K&h&+E4Xb;wnW6ii&c2VYoWY-nt9DBQU5 z2_7GRAR1CV@A_f zDUvIHh#tm#s&cHkU4S)I6QCf!c+cJRmL~K;g%7@ReSQU7*l2y~b1AHUc83al(`haZ z?@_PAv)&@_DrN0#ZL`AF;HS%bY+ru9*JJTPyTrct9rU9mHoKq;curuTjulU$4TtYTf+1e!> z{wBwLyccrHNuX+*U@usI=$nRNy0+HDVB8rej%l0S_oKCvf)1433y^_fcXI*1Mkl{` z2$tN!pMnPmM`X=b9Mq5@zTf8P{)EomJW7Z-KWG6uHGHdLZr}>Cb3&a+6DIl&+W|q_!0QvYw8~mrj}Wxm7A4y;?bi={e=%+ z>VBTjxDmxyn?wH=ru*I08~H=M3!8#&XQ)n}o`8w-f%VYG*OwnEX2^9RTdnE&%a^|L z!KXDxj1krY=+(1KUk9^!Km+edKtS{77zqy=Sf0x1Ur;~+VG_H|96R`!#KgoXakt{0 z7#A3J(dTe((`IB$_>xNG?cO<@OAedpQP(eevk!eRxANE9b30C7Zwie=qm;|ltK{%F z6BZOyXBzL}SLLyXi+OgH1%R}1`9MT;bTB@>GX}OyBzB#YZ!2kUwVsrmn%b<)dibM* z;MND@CX&`sG_tWWsmvi)ed)zr4F(DyIG($6xVymHI$C_vJ3l`^lHLKjXVeuZtB8#~ zfL{hfrH?f7jpQThy+%FvES17CtcE|14_a3@)I2`I8Ji5ysV(*j70b2I!l4p_THI*y zVJLI7zm(Q_`S{56JvL7@hA~*%H9LK{-@e++5;is)?{(G2Ca>LbpuVZ8iG!&O5HU^3 zcX<3Mv7Vlu)O-fsJ?l^<+OBt8U0sdF>z-me!p=}qK7rmo2^>M%-roLfSY2^Ng%~tw z1XZAhN?fx8lLDYD#!!nmp~=h5-QaVSC}U_8>ObGCn|a7(HWo`#w)`!fLSko%Ldoa4 z3Z*-Y+)bz3ucv2bv>L8vs%5W2VNQl#S~;lob?efd}PIn6+_`yom+!uv|JBs`y#5X2_yczuFV}y%3+n{I^Wcp|=%> z2YZ0Wj~XcB!om*<>UTol$~#eS)tz*-3SiHW=iX98Xzl1I1N5pIxNhz(4;W5Z67KWN zNvTi&P@VH1j2ezBsH0D0Sm#GHq!uh?nlt=jr%j0Z#u(f&aGGvGPxp6h9>%0 z`un&H+i1yy?fNgtJnZZQmrOgJT)n$lgBx3#%5mFTzAVAxZ60q|6H|EMR_yKaWN;}^vl}(7dQsTj7MDaS4C{~SR5*9r?zIbz$^r1Pc=#T6-y7ES z5OS8Q8V3i!gI>ITt-iE-%Kwn$tZ3nv4b3&N606~X zRp0>3EG(QiEA&QFZ@e*b;c`2O*kWd4HV@)Ze&q$F*Qe3YGdw)Z(Vb_}$D_0s`j(Q4 z${5=FHUL3P)JstWs~SUS)t4Nn^@VZ9^bdDm>*_Dp*VSeD%`Oj@mo!>l4WYprf5|eM zSOWYSsTX((yZFW)egh3dAAPygO0igHpT1XRyQ5l+P`>*?_Q#Cvm{Dk41Vluck7!9r zNl(y=_1g~i+Uc#edcmV!jEl~aKrLWam-j(}MgBbQbofw#yD z+lfzsG;SYFOij__6!zQ~#@Jccm)XtsZnFeuWT;hInQtWDfcoLq7rYfCtTMR^Ei&^_ zEHKQ>;!~$j*Nu*jn#az~nJG=INFMHx?*IM`<1aJht>0m!q*WQ>CeY7o-YLAy8hDsn zS}NQ+>aqJX->p1|@a?*tZpXcgfZ6l*xKd(nCMF!zGNzk+qN1oP_SajnIrkQWgqAx*Te5A;KfLKRSn=7JbJ`W@3zR_I=QpAH<=tJQLum?`SYiJX5!m;k6&-v zkJlEfid6-j63V>u^|l};V&?3Jo&^guf2uvjVW@!UG6EwAs)g+L-qtmkhF#g(LWP!t zy~{?pK}lmP)rSC%jdVq*bjx74B>>2dWF7P)rd@>{nFyV1?(2Kfj}8ldIlAS&!k_M9 zv8?CLX<8k^oOEn|gRenqU9m0#eYNDoEZ*%EyYcUnzXq*KSTUq$MUl3Yp!wmx+Co}} zdC%RomI;RGy1F`>h3pb7*TUtIN?PN_XE2upc7H7e!{bSYPgB>uP|gPY-H~(<2zd56 z9F5_C;K`;i%|t@^eW8sePsvy7`GTwRm_euQJ%T3Zq1BcrAc z03<^XynU`WJ6blp`OTX*&cJ=S_`d_Gg~?koz`&s#D{8Xi2FT8lG%#nI-m24$pPal* zf(=9Hy%RjDuJ7>_*r=8#Hde?wuKuLAXlza+Z`kz9YkNC81`}fS3xs{bLTB}kVrx6#^kPYv%AnX83z_f|mKRngS4uopOaQ<6* zWa85TE*2i1j`fNM0yIF;&mwRL-OuX7M>9jiL@Y0_n`_YbJ9^?pLlLF;y8ElW-&>>5$DWC{&zf3Xch}WDv-hh?@_TEX$i>`yoFf(% zpx!u-Kc!lQ>8cesnU%$rrCWX)D(k2L&qvg8_80UI3-a+W81-fQud1^|Y;!uX6he7BN32Q5o_9$XJ+0m#RPs4L2S-|4vLr$X} zf^Ia^Q&R@e#+>)pn)n^&Rh2IGL_*zZ9`NLT)QW)4=Kg!L&v~;)o&4-s-SpRqpYPQH zZdrp5MuUY48jAV@2I9>a{TX?=&z&Ysfs-{}`T=wYPNdG4=W15I+^K4~}fg)EJHTwo6ju<4gvWA*P~ zu?X=*c769UCV^XOMbr9e*s@z2cANG4woL6R)lw^>Q zkbs>-yIi@Y_l!>D%knt(R0Jt8?qalRbU-r}yFZbh!4VW34D#)RzT#7EPCWvjRNj9b 
z2M}2gcqjx~o=l`k#qNA2@NJ-iMdcxwy!-7m0@fo#If8=4@La{VhU}1N3>p?WNNO+* zK%DF}ryhbwx(4ieA~UzD;LVe4;Y0nKgi9pzZ2uN`{VuQ9kNzH%JYuKn|L%5?!u>e^-jRJh>3Xyuv;Uv8u*KL z@x%NqZQYe>z+wi?&zYz%-Xj2E6!9}UF=w*9t(B&H<0h`aE`YKX>n)r}k?bpb&PkG) z)Bds8l?6!KR;1A-cv>0-_XuGyX*dPt>tTDE7g@vC2|8i2muxrdK=aH86pU^zGW{B;YCuqO4j4JED=B?&8(vRr z_g*Ru2?;6M7q=QJrFrRf?%cUe(G`7p_0w~uw#qfBGU2TDKfW9VrY#9@%O6y*F<^~S z(Snv-3;VDD0bya)^G&<(r!~OX8dhNpl2t7-!o{j|#VbL7+K=!J%$-W)4FLuTnVPyM zV9|R5&2jLNLZQ?rK;{F|P=NtC2viDVkQs{xwqUMWr3ZYTV37OvOcU;~qOhQdFVS^qs9vUH!=&q>eE>MrJ(ih4L*`mq%~r zKQ%RJ2T5RNc8GIT;|Pk>UAhI_R2Q-4oeK)d-3jWsdXa<3l9^YH`h-1-ho`ZrTKP7v z6l9uuXj5h))!_L>UA9n55#hABv-aD3Yk62L_dHl~K~b1@nUOIu56>`f&O}?QX=X3N zzlzF#y3w->eTU4jOsMfw5kJ^91kzG^Nn1~KO@b1Qc3iOI?A-t^MSN-=S9 zBoS%*)o=v)0JAXqy}K#c#~|rhF&#V)+&>=#c`eCQmL4`~%|{1?oU`DxwSM?ne1ecL zJ^#IlCH+>*g^M{cTh`EYBS)+O7HE{(Bo6Q8bbyP&T^~k7M6_V3_hDcilrpGZv$EfS zw~PQwMdAfkS65>T>LWkKnIS$v=5Y(l^db$nQTppBo7=otMDu5x_jlRTUqfSKA65s2k4Ge za=&YNM~{<6%$e`)Ws74qI`2$ixRLPYf!>+(m{97YIk6S!u><(;oZVG+7ZC-L+PK0kxb8Bf>A$tkyxhCl_PpRj=Cg1wHANm@&2C2$n3*+OD z!z2%=0Z(Gu2!l#ST!{e7WEWb&4QB^k^ni_w%wqG!?0bCFsJ6vcYmPXW8|#M93Xy;+ zmz)3kE8JRk29x^Y5?)lU%d~i$6$te-{XFL$`~D*A29!Bv9GR$Fj( zaT!N07gvdTj!qz8-ag@y9O&z?Nmp0A4hLt3>_LFYb8XMkE+K(_q*v4#QEO1s62ZA( z$x~;0%558vO#3{3{h(EOc4J1rzSr-|Ks|I>+1S*|$79>lydZZbEt2ctPo#yWyTbeOlBqiI1C`7$iJ)7*Js}qt(bOg&nGGrN47vK>G6Y z6Doao^4ZRKe1ql9-{7x*E3^H;s-RZyR|3ry{w{xjW5!!#?F@{3D)`K|XIgH`*=HM{ zGWcjd%}@XuuVsg>ad2h*!xsO;-R*6I{yakm3r6x6!H?|h>cOabid0sRcc2AF2|r@C zDqj*rgXQt#$58V1;Em*a{&q#8_RLNNOw}F^H>1J&;o-BeD4(E><#`=?vIqHqvP&D3 zq?*Aa;-K9A*>7BSE1({#(%5PZj#oc@`V?_EWI(}n zCPPy(8?6#YpYip_b6~xI)ev-5`&`yU@76YR0D@r{3}{V1ze8AAz@b7 z0#qN++H==L(n1ios;a6&?tHC5o}%OTOZ$BWdEjzQ}1uZ7eCyB`89B(ss{_MsEVrU6PTe!kR+hKbpOOh$ED-hYQdnljX@Q~ z_vCNRbtktxP)7Hr@40g{C<&N5HL52_c<`W{1PIur_CHRdx&o<{xPZ$T$j5a>#R3cV zIm?pa`wGtSVlF3vAzgz{8(&(oRIn7B=L-le2SW3dS2Y=enn z12iM??t5kA`2+g~s|+DdaII;N`PQ1pDVf^NburzD79@s=3%V;xPjpwf*g3no)!)0w z1P+YX;cp=X$dr7FU^6>peDe75#{$bfNPsKpJQ>WjB;gAghg6g++A%o$9Z+d#77&(`(F;vYR~h*jq}k8KP+tG$wZ0&*Oy z^O?C5;7eUaXm=z2P_@hob|Jsl!9EAYb8u{%E_I6eLS+zm@8I`=2uLyty)MNeIVp+j zvxpmLcGOS=B^9}^UcJ)TRK$De)v>X$QIB~3g}bSY0Ia=Tgl!y}ejJA!Oa7zDMjFn< z67!zpsIg|xE_^t5>C)FbiyidS*umVN;Aw-}#emou*s1V{z=u`O6k!SsNP93zhI*5J zC`4VrG&Sn01_idSbTo+JKxF@QK#n}hO9Y+bv;iqoJBUR`A)FO8q}JEGO;E0C)?wLw z3u}ls&#>Obo&xp14#d@aMR5R)Q3nAMpLi=?;<6_b(!XKlB%(hD`~8GQWCV4uCQteh zMr!T$+9XnQnK?LG%)i5C>_TY<Fqjh$s}Oucm!ABY<1A!WWCO{Vc*@WEp)gFo94;eEaVOdr?Hbw_!caW0aPm z>3_PiQsE1vFuGM-T%6x?-02Rd5+KFpabaU)27g%jkCfLL z9=bRP$@96aPiYmks^H?{!XNK}G5L*J3EX&RxWH9VMtZ5gh8QAEtjzY@)lfRkNjF_7 zDU2T}Lnv~W?ysF;W;S=lD@uR0l8@|ZN4&GxF{ z_THZUrg3IQ5Ik7sbLWnNY7JF-1=L%>9$D|+;UV4#5E_(8QV_u100^g`CCW&UHhyk$CNMA%Yz26#$ADTV_Lz?^ECd7WPJ*5e`53{H z*h{9NCgwpPk5V`dJ-U`ne+4{T^g~KYveMExkY{@T{yb>sAXS8TdwYwXy$eAH@C2TG z`=&YtUd=CB$VP$C#CY*yf%4u>;4|a({x4!(@Y*9q80x$Bo6U-}GsW%}Ls z&!P7K^=&RV0Z$%Dna~tTZjM2>9EZdY*a+V2K`6Ban9f}HXDu{82?=_*B6c1gKfv{% z^>t;c^Um3SJZ4_$0r<=pgglljS5O;Wr-6>2e3z~D;RRrc@L4uMD#6B2DlDWFjU|D2 zKtNEC3}Ep2l)iYDs2gk``QQP5metdv2U-LDfDEb=wq3GGYb0^_HXAeZYy%!6_VWmy zfsqmOF=84#4A_N&^iV&Uo{PY14B!JX-cY%kH?1HO_l+7n&&ME`fD;RDsBVu=FknYH zxzJBcHTFG-lH6-1!_=U>!5gAf4)A&}E8ozS*jFdRBRM?BoOIP`;uB$OQ6?C@Hx5u_ z^M#TjIavWhNIi0$r!vXu;euh#7YteTWJL5-NWxYwIAK`^M%gONKO+DtX5S0WVK(?; z5QhYOSp&nD&w8YS!(aIn6_qwyGCiMxRIEP_{O%qqKuQS7V1)HgqYo>iBf(({E?=pl zMfn7aVNjE_{pKRCc3~< z5)=S)2!e%6fr3eBzM=$TM+Zm~1Ox<9AoX`cgb{_JKu52H5*KH0dkd zwu)JCAqn~CEA0W8NPqwQ&!=SpP=iut5Bk!7lj$QH-%J2qF#okFjb|dvO-60xB9WX8 zbQIqAfY1w|10lyt0Xb*>y_`QEMr|h-I<3e*yMIB^mR(*2;w%#<&t(O^#ciZ} zVa)&EcVL?m1!7$RbL++R^>t9V%{~xiW@JFy$jA_sSqOMb4;8Yw_z({dPbzt5mRpky 
z?|L|%G#O34gWSzqH~U1GGxY!c|5&PXWS?^a2-L`7jfruINjI z0eJsa(B44^g5V~k5FWibE+m&Dr9#ObgsVcCJ)wxjQlB^mSHK9_Qoy@NAGm${HcN5j zKP#(4juwnk-r3n%<0N@TMiO9H4B5$(KG5G9AJVJ5OwPzShvPg=rY|HU1W#Btob{fm zDa-lu_yF+fj^D69l~d6{u2PSd|KG)pseU?jJDe4zcEVXfRW{R zsJh;4K$_uy)-Q+Bu5Ir5Db2_i(b27n)p8_B865XCa|9ugxaki|^QS!jy}tizz5U(1 zP?iEA!x&qhBzBqHz7J#9&Vo#P^1*`#JvkdgcAev3%O1hJOHDQUmZm~LO4-WSM#UMC9}GSQ#jd@)rx0@4eN9300HyM@e;}LP-DGzR{PVdjO9N-8?WE%*IjK zNRzaTfo5)&kAkbB|64DCNViBLT$BOW^8M9bk|0Ebn2XQAz+k^TbY`+KL>2^UOREwH z&t~aW(tvf#&c*cv6e~a}U6pQ5NZ*Ip^M^8Bls?^GSNhb@kO=GscvO9?(?}p#`_;YG zMo7@dWGo)llu(}-D6&8qO9X$eU%w91cW$KeN}S8g898}*9E_mVPyp!1WKC6LxX>={d@P0Pk>$x;eNLvc<>@Bia;~68xl}ryK~8<<9@Uciyv*T zj)PS6$FsOb@AB!~OwiwE-q&ay_fHh99>AY}LBT~xkCTHSjxxeXNrl)e4WQuM!gjBN z4Tc->5@!KNU4t3K4}%yfxJm32P)?1Iinf$cd-#xx-gDb9M%?W*T#`K0FjDq`_80&~ zDKamqHIRaD0Snm`sQA*yO!eGGp!1Ba7muBw#&~&o0T?)POKc4f>7Y-Es3t(g@&>93 z$wn|W0wMYF=Wl@ua4zwJVhvx_{f1d|o8Ii}Ul z1o4qXfq~{o2?U!GfzJp|^@;o@`VWho5TyT_dJ{p_)LZePCHf!s}O@MSil^_43XPxzjJSohdu)xOevn+<2?KSH< zINJ8+62IL{o8nf}r%z*`W93;e{K|@$GJSu!Q!WuxCI@l{7?(TeG9SwnK4$GnxL0EIRSACfIs>Tgmp*;5;Th{aEC&6Px zNtQxgm&5=61g-3W!o?nBLkO-| zLlg)-70BwB55~mC5@R6upZ!QzTcHD{Qh(8Yi{KLMOw1qL8g>dMJ^b%6AgPEMK`xjZ z^4pEzz&-*61}d)3fKP*C1Gl$xBjTaB6T~wDo;{O;ZVGAohisVE_;|3M34Crpu##am z7X~Eo4+x;Pn+j6JKGumLN!iAonxIT;6f>(Btl|FI*l48{0{Cw``(+jkk(ZzGLax@+ zD>w+a{+#0=GU&H}){IU9C@Lyu=)6x$Yo7lFMY9GAHx3Jf`Wn8zzCP;yh~x*5vt(p7 zTX0Gr#ehhKxjo%ws11Tnc5!xScGKqSYM#YT#KwBcyLT82z1RL44TI46^G)HH8=!xK zo8%2$LcQ!DqSH>^!DG4QH5DbLMtI2L(d$#x|9vAjrV3=(&ru~%>U!(M98VEGWvf8* zD%_qu<;QJBUg-;;z3)U4rxo-RFE7R01mY$91A!ZWsX=@eFcJ_;N6X|jtKftyVFEU- zG@*^7ZnM^>Y0g2@q63Nwpt|YB`W;Ls5IAy@7#CCu&cC zJnAyL^GM}Fxji*mtJa>`=vOfGyysuupEjRo){W-*Hr?3(2`5}9r!Dio>?i^cVDb>o ze=r2GKWCu(-H?C*@mm@q6=niSUWW&py^TRF6e%MkBj6_&Z~@ADUu{5k`KUt~fV<(p zFpKym_FnTjD&F&(JW>}Fp6a-h2c64zBDDQ+E-mHF?db_o$ zJtPq3R%U~C=)F{T7Q|W5Bg=>P*0;8rV9Mu%yAC-)@|~R>P%YC-iuKZAfMC2qNEPTe z?!RbIgMqzojM6UWF|99cKpPR99UIIul&vhzPkP5{*R!}j8~hKks!9*~EcUZ_4tFY? zlD*-<3{z>tgXc#6|U8T4^dfMhjre6Z_-=`Gke+uP2-ZJsem(gUg~89*xsCm={r5Coia zmWvlpdb?0tK(oS=u7ThLkwQ34;Sxw6;1AS;?1P+#=m>XO8K^AeTO*$YiMk;F_`+W^ zgQgFJN!V`Y1cESV5T5jg9B?UE3}Ct0ifw+o3(`401g{YYi4lH;8z%mX8l6n&b<>8t zy3PEz3wZeWEKsd?_I;u2|94x-QHp5gp=v3`l(8SRRp59vA;6vjb@@(TdEP|oLOUKM z{?A^!4ga+2nEtQ+GPa~o=g3J^7(7MNsR+A2h{y>YTI_DlDDs0!t=VlR5c4|gg$u-c z03K+-tcH>K6o6(OcnqjbLSh#ZAo>rYp{tC8HFpgd3L1w1w`8$cxI|+ZpVW|Zfn$JE z&wq`=BTLxGajjkT0L+2L)fL>pV*k)!CIVatq(cw}2LraSvVM;rO?*ZMy5o*` z$uIyMJsI=MMaoz#swJ?>njmTM48c9~4(xC9Pr|GjlK;&&XQ_kIDHZcY5>4s-ER^Se zzw9p`((Zgt$ZkSBM1$kg!H~2ZU$)7*#TT$__h#3iG4C&B5pg0uY*)$0xQu=KIgtT@bQ-jJW*x zFP%1dec;D$O#19QegdUBPo0JnGoCJn)q@YuBTpp#y-;Vi>9BUL2&+)MtGf;}oTyEE zjHKm|VZ8T1Vj;5rzR~I!QZ2}FZex$8#w4?s7oKf=HVl{Ue}Nb0_RXS^*9mJ+5{P`> z32pUKY`%6|aCsvAGvLV-Jz@8yli_Z{c=?7zmri2;Nt`31ZvJIzBA8ui)UV)~F5mrS z-wU&hs>Y9p2ZKKvnBls~(0-$(S#g~C6lcmR7dG>JsDdVMT5i)9p;Q7~U@o)FEVj?r zHH<;KIF|b_H*So1W5n^8pQTsKT7PCbu|j@!Y4mZLe{VHWp>&@#?GS@iqL9qGxkI~T zuIih=f>(04Mi2#=N<9T$?T_=p=gbxVLyL8T`QL-)Q+7=yvg$}<+j?Tfc}#=ZEIr=! 
z5z7SPXO9Kbc!y-Hyam&!h%BD#B4iZA2E1JB!13t`!55k6(*7f3IdFf9abaH}p}yo; z8&yQwxF~x&q|DKpxAdqKZP$^3gpn>kJ<^irw2VLYX>)47y(_!Z1_@a;o%AP{RqiGr zLBx+$LD{R8(fs&-uJ3YD)ZJ(2UYk@D#>>!kJ=stwpQ-7_^;FBzq(ky#pF`t!^U`QAVI>0sEmGKOoHrO z$+3r*Uo05>M@pnqVQ zn!Fx4hshKCro3#J&q2V|VfGjc`Odv%BJlZo>$K9hSO1YK%%4TQRN|?1{$c&fWYLG; zuG)n?m#KAm>1Z~uz=3WeaRrGL)9>-qu18qI0^rtEem_cC0G;&Dm;^7Hr9tC_ajz$= zDm)LW#zp7mqwjIjel1CLNlgxtzpaV0A(|($Y~JkJc6!RLNIYQ047zL%97kKwsT+tY ztA4E2!v~i>@8AVeiJX>}R?z~={*{75hry($r&C`xr-Wn?9M1meY8fn!)N8l9nhP%R1}Wpv6H2sZeY!5>UJlk)P&^z`(=bUy-SP9T_@WpY9bZ=soLMlva!lv8=Q zI74>XMcpozC20WFsG3J(Ohz-`biB;>ldt};mthJ#{~BN>OCF1_sUg1FlW*_bxwG1D9EO8I$28OLo(Ij51uP(B z-$Pm$MW#`9LivzQ2|d55MKERVyfvh6bbCPcOmok$MFcP9XZY|QnSkDF|Mg!!f~5@h zi9%98A9k~P7M?dh{^vu%4rpwC!9ok7+D$S%qWDln4v+m70BXPKI7u|}EZKHFqAOh6^0k6aRBLS;i zR&Bt$Tl=H+si=`*iv;HGGJjesGt}rue_4vIvsRD$;K?4_@=x2cu2G{<@qlg8Oo(#K zV3Rb!mA{CK3k64MdwV+p4#uJx2lr)RR}QH>gU$IKsN%e~H&kR~-XI~%RJK9gvjX{8>Z9ezPlZqY0cHsc12F?U(B2gsy{C|I+J_tv#v82cvM=k9bH|6Q zJQ6YCNG8VG+Ae}0G!;-?&O*J&Lsg=1AKM7xCCxw5gHEJy;WuVEN7+umPQtxA@+fNk zmNa*3dQX*{2fM(*@Si(_U5tNaI*kDgj~kfqQ~za)jJGE9-0z49{PTi?YQg%OP%Ey{ zTTTnT#DJQ4k)VWf*3g|uF$rmD1VFMtpUr%=f5dxrb(Mo*AvM~n0Es=IlXXM-620>S zF%2O4i{?%{I$qeo;9$f%)PEcdhzWs~B(11;91!>#9N^R7z5-Ys4&!6sj2ZA4Z6x)4 z{ee^;^mc%N;9%K1cTR%Pw&dF8_wJ7o0~kQknOp?(clZ|oD!hF&@}Z$2Bn`tS(ns`y zmwwD3hgm0L-GqvFK$83-!k8+$QErOl2OdM1gvYL`tC;$n!!MBQ z4B!j~h(d6bl?Rk^13Phra7@h??Jtq*CesUlpj8h(mS37j!T-V@2C%kai#YDxFzZ3!>Li^>ewRfe|fY*Mr&-ivsXz%)k~aJCmAcK;X%&J{;4 zdN^IE*xDH{sXtxpT1gLLKrQ6I;nhweBIt-cIx|qQ*{h9S`{7$uo`|dv2KUv+aJsji zSO@|@F4m0@=i%rpvl$0#9BWANXAGX5(er5Fo2LKJ)??Zte(dgn0(%42izGWd;qg=t z4b3P?*S3QnTf&t>9C)@xK55;3hd5;2m*II9Bz2obQ9ro1voj+i`S8>-7(v^hc`uE; zuCA73yA}5gNkGud;RPK2VPUr_i;5Cq3mk`1ftL=TNT+#USYY5|Rt1Owu%efMKrTd! z8$ux0z>L83b6x#HjVEoh``IY)?Mbi*kw2ZIjD7Yz_^UzN0O{RdOw@K^QXN-t?n*N= zzV3DPO5!01zvTw#QF@{8C}1tP@TYS#+Psy!e~sfN>)3k@r~G^c9j1@+10@cFvpSmX z4eEM0TZ&)@_QRV={7#&~<>g%n^aPFN7zWg#{KeL8cuslNqg7g#kV>FI=gu9^A7U6m zQ$U^&xJ*d=bN|_WEwYR>I5-$3#K4P7f(YUqSP8l$DI@>!HlK;uo`TMD`Lxb*4;B91 zUxw5hq0$|agemVx?5+h{J+By)n*2&mbJ)?f>*OoQcr096*=9pNYCaSW(ytmakA@IU1-wBh>XIqR zK#6r2oFu@^1l5AcMuvyeDY5s^x6|!4obCok=$mUfyrp3rGF8Px$_QZCQJ*;z2&o+- zV6uWGl?7?dkVgwc=Mw>Tq1O^Ug4j8DWOVSdAgDxeHAP0?{HIn)e&WBX@Z=pTQfBpr zg+DHR^O@96EqDo#7J@S|&)`is2H?wpS3*iDFAc|UQL+;3pTs-7pBq&@;T=a8ajs8c{))~t zFEco9x2!YG()sa21NhU2NkmpyDP^xNIkC?ttOfc^4R(vU_o3Fs|6Gijrchnq5+x(R zF;>2hh?ME#Z7-IpH%mq(2iFU_SX)8KVa{ z5;RP!#rqnEy>fk&2PSn95D)LxJf|3bnl1Mho77LU`#*Z#Z?4?W zsag1nFOdgd-qE;CS;7m<;N?5VFhg+o zv?ph;?NTR-YG&J(MGy9LUP%q`pAHOX1_pbPC)?b(g*bX;oj6vFx7m@~~5ykSC<*D^DRk$_n5LS(=}t$uVM zUGP65Tz3TJ-3wFuFZu2_)kb*`qz)xN-#E!g6^ozn&DNTkHTU#$cTt$ANA3T#7$-`( zY&Thf(I2QwLESCGQ+9mkL8HN{`qX(@;19ollrK$US;Z|s6)$iNFMD`)mZpVsHJb_J|GJ`nC3923-8yH z-s&rIx){9f%T?}IaQ}vh0DIjWf

GiE%K{X{}jG8bSJlGJnl(c=H<|GV>h}N>%D7 z{{irQ$DpWnoJkt9e$PD7&v7BJj-8pOV7;!Fov^lKz|%55dUn7&?s9wsgaoanL4_of!;b= z<*}Nsn+`sWhqk#@HF{mnMNb%YJ(y&fnKRB%c!#{J*kUj^Tpit*?bJSqG&utz!S?9k zfj^3=md?hx!8>M#T+{LX5%S?_VH4jfGE%X=@n4;xyLf(Y{g%uHQPKp*`hm7n&VvXP zhkL1;9WhQIp2Ja@?bP#-lM4wTthowV0FZx=V9W-Kxs8mBGIXmUvF;3SnnIa3l*|O{ z1`^lkwOtasb1dj3S@6oMCm9)#Z!pXrum!BAu9?pUxoozVpD#vMP8_O-T2tZS!0zd} zh`If$CE`;{tKF5m+PTvvHANnA9BLsPUi3Lb$|{ttzre)G?19jA)Qwcd&Nr+Ig0BgE zFd>}0P^_b(-%vUDq?5G}f&}D12P3gPp(8{Xd!ye?I2XBkH_uVxvX??aZ zr`0(wW~ideYE|~}@721Aaer6nN5DuJm~*ICdtdinTd07)1KrFvhXqX0<)8T7mrlJjSjSXo(- z<^Tv6;Oq1{VRq98I*nDyRY?PZaqpkrWeIIMyTuQtn_Q-ww5|puZ78xYQ+5Vr%8n6* z*Kr5_d>3$4+*?cF3mIeT2TDJQG~FK7pxg9SwILR{yieD4K8Y}>u4VM>KrFItfl zZQV=ZQblGUhV_5HUHfs#_*D8E?AiKve^l0kn2`c5AxLv@dMdovef%G0rT7pd6Ej8a zS^#JuPq|?CF2mFdnjumQIJO-DiF&17kEXi%CFqfHFi1n^<@K>^4T#S$(ZvzsW(FEC zE(W;0L5YcpI-4n$ZD)JT@QnxHFzM$;NVyaaLexThEKECqB{Q2o``^7H55709>u#^V z;6dE{vGH+S7|X9uyqIq4KAXbbBdG~=JRk-FXPEL*>~s!C>!T*ZI$b>RGdQo3yc$cx z(k58Kcx7PBua_lN8tSu9(%aQ^JIy$vPAv@~ry0$GPM;QYgE~wbQ zPS38{NNTparR;p_ac|WzFnBik?j=t|eCx407}FZ~AV?RBsUkZ+c&iJHYa|sv!tuUv zXDk2ZoA62z#>g&`@3ta9P^bcFovXih+^e4llm2E+&k}Inl+B2G-ud`&I_H$WG{xjV zyqOB-8g)nzyRNtJ&2Sq`u8H2=*t|BII4*;eN=!b_cp|W$=;q5FPB-ngKI*Uy=Kq>N z;frZ|bS{P4$ScKoUGjLO5?mj^KUL$rQVP~2bKc^p-4uH$PN0=UOQEx5@pMh&Y1qp1 z2nQ!KQJlCNBr2a+pe7=X^rx`B=Oee~Kqi-2xb}_&!pG-j_+7y4`~qGEQ?4o8(-$=L z-;u(jd9|*a@{Q(F(RZI#z(by5+&_BEcN>&7^n>=t6ksa&;tnqhko*o{PeWa-F#Qyd z1cBAm-KT(5JideXs^U=$N!yWZNuLbExQoys3hIw75~j&a zVGHD4X1KD;ozBSD2Jqj%t9E?%uv0?#j3XM9#ir)8rkok}^v43o-GNmGKn4K@As$sj zLzSjmcQ$6UlK`s2v;*|$)5qOq0;FEXVomPU(lr3R7_uxx;=n8WCZ|-XQ+(G{ z5zh_cB^lKAy}jzFnB<0u4EqyV!BThej@y*=gC}C#G^=!D@OVO>-@Y0PNUj}jq%*-G zYT1hmGZ3h;adPH3D!}E{(eg-pVDp^=oGpq2;P=1wOsPU`5?C4~3;-^tKg2|Mt+`(R z{C;Zgd7lki$mMT~61AP7z@-4~GN8TLMTtLkpEj>NO~9Dy8RL6VZhnLo5UF-} zhBH$Ha<29PC)o8pUe_~gTe-Y^>b6fO_6BNF!}uo!7lPMO?}hABx)k@Xq-C)rS_1{yodA|tOmbBHI6p>~(Y>$-5F;U_+Ww9tm z65w@g1@9_OKKFsq1CxG4sl7A~lF(QINm1;H6hK^hB{0T~1_nd-tc&OJFw&G`s*u6W zwr6gxC2$adPU>GgF9lOhFEC4)>--P_G9w_)Ra92ymj}6XD|VgApYq z9bG80*8&F1TeU{VzPCrApACa#1aMZlnu5n|7`it-;zrU&_#p*A^%v}F0G-98HM2%e zuCoH!jt*H%0`?jStT5m&K!6{Ab|Xa&yhiL>sEf z@8igf854lGa{u?_(ia+^e?dQPknk_6Z~iM9w_Hwz_y-*HFRKnar+WYcg~0gMb?msT zXqQu1207#M2RNAkR`d)dEx1ukQkv$zgCA2Ys6u^$CPtze1hhC$-_;bWVlC5hLU@cmdd6rKbHkJj803F9UUScH_+OD}E@7!MZ>mgeWj4 z0zD2HBotIaqrM(AD@EW)4RRV>Aqr)n0`emR4`eR^&|&P!QFjfsf~fZs9-ikxd79FD^zo7)#~{P6qt<=`Z_ zOf*sVK7~=i)UWvy=#Z8XTU19$bQIsead3E z!n^qHLzv8UKlQ!oVDmO&!?F}2-i>rt>Orn))r!BI@JB5R25|qx4SV~zh3E3YkHNe%GE}+7#$Tl(Vk!*f#|K#BoNw6vDVICeOTxIUN6B@Y zd<;pV+*-IXvvV)$zj|31d0L&%Y?WqU2$&k)q$dU*S}jl|PYG^tHp(#Woiu59fjFdQ9Z5sQe zn&PJbB(#F$vzo+4L>02 zT>|zv%z`+;I(VSC1RB^&Dz}Jo+)PwYdrqXYx0<1lUGm3;6d8e44%lL$M)TjTrn8r5w#Ket_Rr=K@Eev ztDQgp5gP`K18kCz%3VOZrfeA{CE$pkvXj86-SSc%^R@0M0a;QwqE- z0e0ZevzwKBw)yOK+UV#NYiHJoppUaCFe|NT1_Otr!l&YXK>Ft|S#3wW-ru1yJSe!b zul=a%-*0Qe8U(SyV9m%ZHW_6qGbd;pYT1lRg>BZgAq|*u+WW^uK<+ZO0BYolXL`IWnM4z5+7Z zL@=-uih#M14;C`uhlE7mw;&?H{2LLB!Q$^rn%q60DS&AmT7zP3Z9NJ`zuCHh*D^cI z!Q}2Eko*zdcw?@M1TOlAJQnD+&%j7ObaTGvB0mlx1~L`m=IZ5N3q(~az^NceQULWm z=naK^cWw>3eihIgK$I{5g{r~!=@#g!5%v2oizGVmFzII}szT-=;Ozkn6R|K>$S|VV zTuTw|-R9}pN0%|B?_IT{d>E-;GPs(n43H``D*@_&ElnAv!>ZpnG4Bg{eV3h|2c~4vSWJL(s=>= z$6&yQ#Ol8xzUg6a7kWU;2Gk$0f0w+wzk36Ky)JJVuuC-}d+MU1>whR9-9&)MSj+)$ zV|w$+0Lzj0oDmt-=8w~Vs8QDSZ+_HKKPmL$;HW|~KEpZ7yc7~G*D%3_Qz%veM>Oc# zn6kKBKsH6_46Bo$jP|j^sUt+;RX!CxMeX47@;ku%OA2NuM1Fcm3<&NCG~mL0bK#VD zB#b#Zp}Ym3PC8O~gPnB`&Bk7mX%X_uoE(XTCLm%T+2FhPYQ=bE>NLs`mz0@y!d(647v zsD;xho0po2gtiQX|0z`K*ARqy~ypFYF#);Sv4_&SQa! 
z({4G-}^|66wZBS z;Cq)PJbmEGkk+i<+ND4BQH+N#LQTEBT(ydpNVar^SxrZ*ItHM%)et^JB@GpHWpIzJ zVfR+lD@{q;44t7O(w;YPUgUNaMZ?m+RMv-pn4kvsNe9xaLdZjL)t^!PK@0}<YJ%fmi}7c zZ|>l^a*5}1-@}ae>hQXbn%~JT0&23+jsX0|Kw-*YMxU!H5=gAzfwMVs9cubhp@%lu z04}{~wxRXecow!QFMPqIP})ehF8uA~i`qUGxV4uDwDNU@7p`?y17j7QEF!KF#uZE@ zKQIhu9i#M+OOl4A0*xAoRizpSRJ6qntDgJMiRCD-VS*b2fwX($MQ+5&CVkTw1<`m+ z?qIKa9xWCNxQVUcR4gicDKJ{Q#N3E&XEJf?W5|%hLA!}u&qr_xjFBv+lSUlf00HG<>y&8YX(CpMw-I}0+oDM zsiv&T(x=L`jxmD@i~bn7;KO@^FE`OcpD}mHiG`A^A}EgbP;3tpq}$<~{st%QY$ve+ z2uBcxuc53>`v@DSDN2j$%A3}oa?siZx*`jeK0F94kd@$#LCMD;LBrY-R?~7m*OP*i zkkUO6-Wq#&f1ZPIM{(eP96k^g1O*)dBXsdU51wT6@@Qg+gR1b@V{zfbY_=Fb&g`r= z_Cvz6N}c94?Sv1sFL{{|)Ks-)930e5Ulz{MeugVe6w_1CY-uutG) z+JBbSc2+;t9F0=+otuSN^#Tl*{hCbZY2v zoT;0Y?0LRW7U}#)@6q4uwqV;3!XjaQPJsD)Q`q0QbCjCM>91zS@|zk8!V}OU9!I^> zBy+B-)E5CvS$H}efqV6ZZpj5*aXeLDGplb1YU6fDk!DCU1daYSO9(e6LwF?^Ng&0I zHqB9BE=nY~o;=&zn*HChd^0G}km(dA0DKNWQiBI0sTVI^m;=iOQ9!^o_B=vv3?2WR zvM~XgmKHW18*raFxTo~hCG3?0uH7rN4_rG7R%(l7MdPbt_UTFCynYLRT{fLLL2lA( zKEvr8CYF63Z*}gIAWsDGJ?+p|b&^PB9>r}TLQel3PCK6eeA)>%6O)r1LPG7cW)-b= z$l_Jp)!OQO%}uDSp+X!7?*LgSA7A^FFT$D-NNQgIE~2|Gy!1yV$ImEM@2iT>?-ZChrr2K>BH32$>5aMZbby)Io|T0IRUZFof0qdPGzy2 zc3|~cijM@!kk3Bktj@Qb1auM#S-0l6d|><9*22^bVdWdyvc4Q_K zLa2yZ3wMb5KLGaB*4i19@qKa7FMwM{Yzwxz7aZ8lcGF+^P%%-ONNaMMaMn z&(%(o4hf{@wYB%QxO%JIJVl!G`Pl>S36?xpDQ$XElL^PE7pa@(*p(rnb9e?AWl&_Xl@CA8Ioyx&{~s!7Z^e?J4oXIZpyObWz5UBcBn1T03G#tJ zVCz+rCRo=5do&sV9;O0}4E%t4ZOq=G9TAsW50od^D_TV*v0}2kU(Z3+`^b;KS=*V$ zilbR@|0X2YQu=_!X}xx@YJuN$hh-xJuAFUEH521)l_jVG?#9#oi?8=LA^{0}eEdgm z%dMOLg65o;DZ2nfMAlA7#g7(F%{CEWpc?`{9X|?e0M&6oF$tA91_MKMIH6Y6!9>W; z18E+?z5@9_!8VU-u=KX~V4igQgFOkHeGqb(Bu5{^An7ACS|I5gk5y(ZGV}y13&gPl z*EnGJf8cvv@W1k*CC^U7sw=Do4RF5TDexc8onq_6=bq&^ z?xrG^76|^&A9%lT46*x`XQ__UU&R1#zeZr;dbkt$4mIcB{a%mP;(r?8pj-+D^5@U@ zoK(AZF%JMXE7&68F)R&277T&`@;E+yfE@|CWzOwsa>y+1kSBxs;r4ukcZMZsWd&GC zTz|QzqocD56k_DTCH7=s94)v3oB7eHsZdaFvPt=DWWU$&{7`CV;o*43%yac;okVfX z@xz%7=p!I%K*?`hqeWU8L!$ZasfT58`>=mF8n6qPVd+?}UZjNB@e zLkjG33LN2Eo2EDZQJ($f*5rD2Zt{#K`g_7IH}Y;6gIO%N4BRaB1G~1Y=X=X&m%=dp*dnLLnU0(Y8r0wQ- zIFNKRKda496U@*J_)I&jF00zwQj9V8%Uu0qZsXsIDQ zPO^?cLj&lCkyps&wPje|l@{kKp^8Qb3 z?qVRHtzY^Bh9gi!qk#w|^FX<5`%!r)nAp3V_^6<4a}{hIl8K#R)Ez$I1T&|E;eai} zGIw6&ya4fK9Kf1-P+$Fv)&P(|erQ8|U7bZo#w8GgBgQWzIBjcGb(_H!n1REHD*F&d_7CFAR?;*Fa zun=O5gb9`$s6jyR3d1(wW}=N*`ljAT!7LoDxq8lo7PKkB0ta$7>4!+719eGB9t-km zeq%28dSC=xG$mmIk}v>-XfT?l)5W6Ynjf5oA^;HR2l{UYCn69d^cZA$qp*me7RI#% zs)xJD(0&w{o}tmR-H_4^2)-?LZHk>jUdiPWK=l@sBQ#5Kga2sN6rx{;)~GH< zd$Npi#P+flx5eH2nJb7?9>^qS|brb~N@cgfp zII)Pq9z>YxjKLrhZ36{j7BX~#sW!~@<)AD8+AXk^s#tChYUdm&Qp!Z;^e!;5v;u>a znXWu~d1z39JnYdWEq6x#x0`U8!5C`<{tizpMoNWEj*c2oq zV4F&Y0dto#a8>!gB9_|Qw?KIye62?USR`zy!N|hY{k>rWxdMmX9L=Rb{;x4yIn7Sv z4;3}sJh1h5zWf52X^0SmaB%ye)$DEm;IPv4p5 zXz?FW^RL~-;khsPl%B2~!_jfEI$w7)`xW)$mj@x=q5HVLa74FnD*vIDT-B|>; z4`BLEy)c`*o^6Arm+9EKeNWK!f;&}-y!*PF@d^9C8gFRiRGcD*7~!|raUpa&EV9BG61!DPIdiL0RD-!xrbulERD_Dg5#kU*}syNtYG&p!oDyp#q zJa{lqxu3>!akF?TDGI`k47r-r-P>YI`RcbKeCFzyG9oqAt_B|Ol5*p3<6;% zJWw~}aiP_=?cHD0!-Ckn)1{idHm8*lzzQ20@wUbKQ$WF?Faxy5Oyl9e10S47+!c8JA&*K zr&G*k)c!O2qR#>1QE0g5q=A-CLra?uVrp2-(b($5Xhs)qtO|A`VEPD-eu{!^-5yJZ zz#2b{d$}nDZM4~d{5SPOf*4Ok0_g{KLS)ed%fMk6Yku|_V`;pi`q%i%gXa4ggz%hU@NjllSy$7+WV+xa zPI4P2?epeL*cGi2XN6qrihe4#8Z*k3>c$AwqD>ivf_NolW7gh=(*PrFirQ6N;OlP=~K{m-3|Rp$@_1vTzmy3HvlRHd)JFG1n6O z1}IRF`nbc*{-oDbJ0^{6i2^bu9Gjptm)o`%nS$BnVJMSL5yCF^hZkG$65`zXGCmsD;l~Ita$tDK$A4s;S1qI(*Koc6kY(!h{dddlIVAlZ0>vL|y zN1)Gv46A^#=MG2%kPuFjp2p7@Td$epF{^vqdVmBn|7@$4dt5|So zr`99?jW^VxG2})daYO%8gs)T@e^cOA^}xAr}`rVwAvu%BUe&iuW>mjw>;f=jAh@7yI~VOtuyUW3p_LP{B1Zx9@FP_kQ@_ 
zDiHY6LLaAW{ez**SXmveQ4-ld&ukXY{1UXc$X5t1O%XtgJ|hP}-Vp-kAm0=HZF%wA zOUx-=KZ#&$45lqK&?KHhQy^e}Bf&eu%V1`B4_S)A66MWdPn7rXy1Fnv{@JPI2o5n9 zBB-VB30s}V4Aev)jpQDzb{2!+vT#(+)@OZ`$neZz-qhc&;fDjL?N|EhSOAc|k&!sU zAA*B}F&LP;_plSmLSzLsQI;dP1xdk&Q7G&g2fRrU2EG_%GK`jNjlv3hq-jQra?rpM zI0|S2ffmpG|07T6Fb2XZ_p@`9$G?(ykZa0R+xW{ZM)BhEXLzc~&s3$VeWc{6FA4xF3a) zZ76@D%#~10%- z^X;ILqq58QvovgGwjQaUHvMinyRmU?zfrP_zZY%@{y>zQU(|n!c2UBQP?a8f62si_ ze&kSqePok$P)2Xcv4#Y2)dwOY#}eQO@$=x}KoU?YqymDPV{rh(evPVF5Ke(?yKs)K1DWWEuDu54ZcW1L0Fi^Whu*K@M znn&-{;9tR9dqB%;#)LsIe{Wht4tF`lS%KcgKC?LQ3lXQH9~ct>Of92?I85{A6tV=W zpxuS*4pmZaE^0o2PK?z2$cIu9wga^+!b`@%z+Ap-Nrs0f`@5J2v^Izh#t>OAv3xIl zp#c?M`qAP*H|IT*?()drm3M7h{hh|X>e9;WyEn6^4nQ&_P00AUHAXT~cvqjKfEB)S zCl32)DA#v13`vQ zNKUQ<4(W(g3~K{u!1JscD2P#T_L;%c^n_xE1hFkJ((mq=-t~2N4kqP9p#a zkgLY!FTwWY3?5D~vF^FM+B5*s0O9}J&fuD#BP=!)z^vbL zz+zt~z!pLCF17YK zFPWRihNndw=V`tS<>k0rySsX~>aNo!h_h=MxuBw04rHyd7^9=m33yn_^#=OW)z=`APH^#?l z?|efFbk3teiD#fE%Z?*Ib+p}dzqi)w?Xq9t&yDk~G@6o{1*@ML)^+#R@j9itF)7^j zkN!-Dufk;B^eTG_ZVJ{4jvF_b|cczFHuAA>%&8 zlos38`7wCbBl(V!Hcg)6X;YE6Z}U$RL#J8|a$x*>=~tsK0X&U-Q^14+0*KyMP#M8M z091viH#Iu{5jf>;c)|X3P$+59y!?I^wiVBAdQG0yrlF#ywsbp=2@ZKS&DGr6TCj17 zj**cV(+G$q%>#U5742;pC zYBY$eoCIp~7{pNE4ymb{{Fkr^WGTy%lMr~Z*smc$(@o-S$HsGRCUJus2TVf&0fbub ziMS?5;4s)(#6Hm{{6R~U$1Wqqh0}QYY_;v%VZnA~f=DG%mOtiX~iI;G5@LHT*StYP(_!Y51*Ptl4W?Jg8e-a6SSI1^Q`Mu@0u+2FZ*AliBjC=Jv4p$AP z8uL`M@=s~jO+$$#`6h1Dte-oj4fCBP^Ap>9KR$n5tUPtRP{+a~OTLl21@lv4l%H3* zI_6ELXn>Zyf-!-m1ZB@j!o*`P3i9mFWw~q?ZZ6M_Y@d5?Ossg!_U3Haeu!_f#E4OD z?sJc1iNebqF}0sgnm#&7P2>RTVF3XF+nYDR+aNnTTlUhWqY!oPv_Eq$6}!>Oj3$Pd zH^6G|?ChNIsOaw2>+I<{!5TgOSV!Xv(3lcBWPonzyVZB&smM)g*soQGn{oqGK080@ zKNDhL_4mb<%|ib_0S#iJqUoPLv0_%xu2VpzXQgX|+Pk|$0U*im3lA}N88KtiVn}_s za*W?JYw6SQeZ?a!EiF7}&S?5ec_kAsO9-dpw|aaRA$$G$^{km-#Fzw;z!60;`5+wi z^KX`wC$E3Qbe>40GWC>F^u5y~JZ$I{T~7b@o&7UvgMHF+~$os3ZJ)5Z5aiJVS#iZJnHyz+OQKC#EOoO@WU8I&bWi>)$dS*c~BC zXU{!O$nCBuirb;v^RVJ+RMh8XQKmCp>y@hN+&C9}!pY9F&a0Qu&Ur*MjL&W-B#)FQ z;e338Y?1-QYe1@huwUkSwC4|M<&yF2%&2M@8%4isl^{9Btsx@*{rBlJXO#W?q|k!T zji}FiTez|?#VDEXgHh72HW+G1v6^AK9y&L+vhBR(XIo*yCQRo4G}rvar7AHhzTQEQ zHt?y^IlOcZf_+5B)9#)g`!Dml*Lqa7ALkoO_Rf1HKh-3to_AmQ)S5KUoEcO2D3(+< z+6k|HRj_isqw|k;+2W)2RZkxozqfde1WT8X>T^_V{Cda|*7Y`Xv@n&RZ)s$CUV2#6av`)b<{5g;28*Ekc*4-XzFR#uX zsonqOot*z!L!Pi!1UL4vtnMR;2-jK5laYi4N>sKei8&r)n9_V44I&|4BI#pG)=`E?sAuNr zD1i1TBRhswuipU|ah!+BME^{em6b_1Sib8chW+FnvJFWm&gu(KCdJx@unq_3-}F1B zwqzRAz9$phUe5aNE!m#r??=l)FBT4@R6AlFV}?#}TnydtR;87AW7#a85f~u;#9!6K z{k8oQ>Tp|pOWwd+xtZBdux|$392^P{h)aEMJi&=pU*h?>$lL1uy{@$Kf#=?S<@kjs z{i`%z%(btG-%!I@0jc!)rB8cKB4T2%9^CEf>bkJiYvuAyxWOF8>xg>?@=`d#&d$n4 zMob9*L`5*#WW>nCwA!J}hRE+oG3Nr*#c{CMICtWA1+0&!4~P;a*%H0^oeL9CFpscP z@P;fKaYNii*@pGSKfD;w)x~aZ>^L!u2_)3V*!s|>Q#j))_E?>XntL~vyzwKjh@{|! 
ztY>7v0wcbv#L9vyCnsZ)w`Bn94Ud96 zOkjT_O9u83n>Ol9O!dP3DD%1vPT1Es{Z}8DT`9z_fD+=y=c84MoPh{k$K3j^%Kc^1 zbyGa0vPWsoDOZeY65lPV(YP+g&By)XL)%9$g=@uHlqY_SSS%kUFex^>ccOH@G^tNd+rFg$Xcm1Bxn&E$xARQFg~xAFjN=ZP@VB{@3wg~ z+NYDh-n+%u?LoI90dMkK>$@hu5%$NwUOjacT<&-cW;e>rE_g5}j09f(_O7IdhX-U|J7u+i1(hYQO)VQc&>y4++jZ9!BUuV(L)(&)Q6(>LgT-28t zdr!t2)crujT}MYpOaRQ^-vtpiH=2))AC=PUbn{aRNK9kp<5my#z7n&yLe}1-|p}XT_R(pwP z!?|BcY7OIs$~S&7*JlNVr^9v7FI262m!RXr(M zqy9&_2mEKweZU#{RB3;^9KF#$UV>;7BR*`^o&Ee7dhCxg84oZRZ~{cjC!e}35wQ7E zqJcV$F(j|RQaczRg$^t#-MMoSd3qriJJ6&Z!6d;ts0#4r&3q`m_^kv(8i_ zzY7bLR0FrO2S)Ry2i|^km2-V82T|@SvtAhVXx`lhGXsoj7WU-C5GhIBc(d2|vR88$ zCUvx?bn!RKZtEtknPR?qte)wa{sbmRtB~>*OO_KyeqXO20lvcn@NU|i`;9gi%ys3J zOHF}g2O=<&Ut5I9s>nV#&5aL??294STUbyfG5-|yn2iD#dl~D1Iz%ZByp8tIhbErt z(Q=xc8y$`+P6uZCnF|5_YeR1+WIEn{{B`;5-bG^n+<`f*@Nk}nvg;c+dzm}mhz_pK zIQbL%)eI5yRT3+CV;^qiWpCv$%xZrWWF;dE`4iw8TMlnZ*r8Vl-G1=(Qnt50EuT(l zEJ)9H`m``Gtc_tkB*9>>l`||Y;zZ2h->cF7wLpDF>aq*T8c6tuhzB-LOgyVzw?c=O zEktA?OCcuvLY_e8g62bAhV*U!aNVzO-4m_XznC4|>1~dF%U`e`qf23}No$?hhCXw4 z>fvXm{{H>kul_xQnY^W>j%j`EajbAqh&1CUmo_N{pN@v;-!;V>pvLIx>Al+9+V>p& zg{!!G7_B)blZ4(!{Y<4R;>SPH5mfbkS%Jb`>NWet&}x}Y>FoVs$6h^38g81~wARWw zU9s(~A&=b@@5u>I8*_CP9apxxcA0+fm(GmN7R@gSwjq{9eeD)`w#jRd1WumABtCzB zt4(UVYV6i%&*yznYXj-gqOyWil~(36a|Pu&!!_#P)6 zQzACJ#gbOX|3EItCI)YZ_b8+{Lw4erDsc`8;19g>_~1=ZxR%Wn{E(>ODzI zhgrWuh@<`^!ji-urATFK!8|t&+oGsv{z}bpt2pdkYdAq)uPfM&xhFa{>h47EnXEb_ za5Tl3)Id~Lh%3c+<#;DC9gfd*+O^bHWT~CiD z))XF^=I6yAiWTNoEtN8S5-^KCf`PfIPs|B8q~XMmVu=+f#4z$D{Gqi`=P7@U>Q&7= z_p>dOWcifoPo(<6fKESJmD+k*tFa(9W}t&NRRKz2hqJBFc&5=H5!saAx3~e7w;A(8^u~ z-2|hV`J|>Eg9WSf=xm{DS24`6)DzjKAR9NZ`y7;u_J_Ix2Lr}l&9`murGO6{5G3W! z-{ANw`FveSI`^{$0F1Fy$Cgxy`@DkF7HqvP@2M~M!TU9N z$gjX{aZb0zGt)s$tt}TMu5Wrs)(=F>r;_TvVjEiADTE`&_#$Ff`i3gb3Tj5684rgH z0y{7K#DIzBT4e=36ljPHO9hf#n|Y^>imRx`;awWD>{q;Q0om~Dx1=AL7tIw*o}KO9 zCsa`Is#{cXxKos;{o%{)Z-rwFyGOtNl;G()T4!>!&(zX`>oIOc8`ZND1;=fpj@#a! 
z@1^QEcYSlSq+S0}N=sHof0p^v2O-iK8g=fekAjaH#lQdX!4e!5!<8}-(dO=Y;>aj0 zLHQ;D+Z```glC3UpMM;3LRM$+ptAeM$S;7hKfADy0E0@_`1tr)Y4snJgm|zz?COfq zwPR#>-{65fi+A&}W_s8gRgrvP9!#fjYqv#Z@$UJdQ#q_VQHi*FG2{FX@{A&IBkW;e zR!CQ@7Wt~Qawim@X(tOZZ8QT5-?jHppKE9S(ZP70vg_BreWf}+pA-v|i|#UZhG|KU zoqp#mm7?D1S9?h(JK~=_8NdB`kLA;L<+26Xo@RT0}J?&AG5WvAto3aROFv~(bL`U>hNw|UUoLJuDJ)R zZH98S&^;$tBgB!apyKSfZb0heKoE4fV7}MgnxfAnBra!LVYTxbUC&)rV(Xq;jRj+u zib#a1w4G-|WsRy;+6D)1dAe=X&az~ER>qr#I;X&!AO z13yd8Eb&rnO~q)B#R=EGOsmdG8!Vk}6Fh#mPL{7Tltbb7n?46|*5vSrXOqiikjMqg z7tR`c_SqcPGK!t{)$Qp7J#w~GUoBON{USyoO};oXD^~pb+RpqmU8`VO9A1+td7TN7 zYLO8O8C3@k5eKCjvBQBKZ}873eIXYk56X+p9<4+tjtYDvZ?_Z|H#6h9 zUHM>ijSa^~a_BnSU1Hycea5(-Z}2uD|FPTAV`Cs+YllP#tpTjp z8J@d^FQ6KC0h`pQh_2V<+3Dbv8J_*w`$R;>v55h}(Tafr_Sfyg06f zpbuigRFMSDTf${G>km)6i>6Gc7b?5+i?>O+T%|lH0(wp@dUKUzX6r9K7JlKjm2z3G zR3TKo>EPEL^jDxvHu%;0rnK}p5>uJLa24?JiN9_o^3&L{VVa#I?Jb7!_-Z_^Hj znOgmED$8YpVE%T<*^${=e0ZAYNnBauu)yuq2KP+Er3ZJMu+3hiCDnb{nfc$JYl458 zY6?w5xdF|lf#@PUoBDWt)BE5jd%N7k6`EwUgbdV<=3~_%Po@3J?rW1vhgz`5s#FO* z3_6%P$w+wkE-$qn&mr6xErTZSI<-7|+0+50y%$&JklRk=t;oHIlfsMYTb;5lWbY(M z`6%HfYQA(2?5@4b`gLgjAYB(e75Sz9bwVfoz(U`)gNrv-N;`k{n1cXCvH#qV3N6tNn2EI=L7vtI%_{imVcoc#6b}?O1B66^ zWlBlzzi#ssRU@uY9wEQc^@RU^&S~=F5w}U9*eQ0Q)Ry~B8n{TsGEmLvsl0s`-d`2H z%EbP3uDVk9DaPaPwvEq9=34*BzsbpY3^L$vSWWd1o)7Ye_JY_4*PJp^iNecw`gy=< zy?mk5iYVohDrJ!Qc%5i(c|@|ywQ)h7?Ke$Rva`D=pT-XN0o|4b_cdyr=PzVeduUsu zY80)XwaS}WT~n}9sp-~I?r^@N*fv{G8WEhYF2a}e?A^AnC*5{;#nKD=%{6W zY9-FtiaJFWV4Qh-i=Q{*I7@^sm0y#%)Q4kAFNCADJ=((0pLjDNF_Z2wE7m8@yblp| zGqRXp5EC#r$FtAtX+NG8RJUim`hzr(g^I-p3N^+WXrKR_-iVhV*BTr+cc<_hheE8u z;5-RB-x0ji{P%_8=X!+;vy4uL)mUHx1{Pb5u5^NhXYZo#F->ba^uPHe{+$vfmE7N~ zmTIGP6G~qdyOo zQNR1LXu4i8)`kB2tkr~)DOg#ZKQaI_NxqVK(Z}w4U*Q#O0tGE zjPvd|oBmZe7o{Nxg z*Cu94P}R8@?ac^N+05}nv3ybNXQNGeVRPOB?@m#>vZC$+zc;-&e7$c;;>X$~*GWJ9 z>MH#C)olN)^mb|6%e@0I<+K1VQ!fCT(GG3U9~*mgVZoRzh;|InsSb8_QWl?V98TH~ z{Gq*5{v2yf&KUPU-Al}OcqPL|Fsg4?=$%!AejX?5Gq5lrRlLVc0&anwYzcl3eW8tG zk5?O$IoPE@b6m;wZ|7z`tP8++0K*bJim5`2z~H729@Y=G0spjc`Nz!&YrCZ$&h*P% ze_zgROrpO?+g(u1-K6DP9J6?L-jbbBaZ6ysk;K0D`8|2F%@pO^7m4q(_(MAqZoBd( zA7fiGfZ&LqPyyDM?XS^8$1PcTQ3wONG62@;Ms{0o-f z9h5(lN$ro?gY0pnEL5+^Glcv8x>$Q&rWh5|FH)<-=Oiy%&Jinw%&}3ZdKfVg&H$m%oTr+a`*hftF7lFZJ@32d8wCy|c>7L7!k zo12&bAcDKxwgFu4eK6fJ&;DZZa7xZN6K}Beuj6s*g`T3#byDy!<-@Cc?B=Vld>k%`2xG|CNxpwylGY@*es`4)^#Gy#`zJa=QK>_cf2Bk< ziWurpHc-!=6n;7WW-oom?)8v+F#f|puhg&l=#*yEH^{SXw#j}I&yd%H{HTf}ME4*a zV97YHvG5a7jdT8>3Kxd;;z)lEdJu#~AsTb_q^kj?0Hh)^E+sxvsa`^C(RC{_sWO%? 
zCr9qF_Qg3p8WQCZ-@glG4ZJWealSM!Wux+0LhZnI(!k^nspy5N>~<(25b^-|tSauv zD3IN6L7!y>tp)Vaca;DX-j?y5HyR1zQl*qP2YAVACOr%}qDcZetQU;dRp17zya&_| z&jEDl^ri33aiQwH7u%ccq@#8!&APcVhGp_q!@Ki+qDLYc`j2qMVgb4(!yzME z;Knn?)9AZHoy`q_5!1ckpJ4Zg@QaO1fJ(@B#}5%cJ5hTq)`>lFPbHd$wY8EI1R?_) z`JZ#V5q*)KZc_7Wi>A9c>y1GE_AIx;jr2mFwuhe>QzJlDOo48%t~qq{$lVHqmc&2Y zf43n8wE-HbU)^|0#X8SZ)@BU^UIe>Fy>&Ed3+Jf zxxj4?%Ewllt8OFuW` zLv5&9;gM&FqgqtscP{@`86Sce5Qxye!RD_Q=_oyBw{(t?F-F~_~g+5TEs$AgePC7;Lc>4gQIp&fT= zJs)`DayQd|D&aaD0iC2RE-RBjVsqpL;nHgnhJ!Kn|Kst#xcCdq^`2TqHC?z1x;;!l zxH!dS&*)!b>e*D-KouAT0}`liM_VM@*vE@E=K~_D^dC^iwn0 zZA&&A>DwuU=zQ1`z5@mB5i3Aoa`OtRp=F2yOT;lv|6j)fqBMUolM^_>^5Fb&*gtOr zA1O!Ywcj{0_YJu`(*(dQYl*e*iQ|g5`Y6(hUqgAib#*6eki=+qETOhj)${*j>@A?G z-nKW;O?OJCl$5lBG!lX$B7zE%N~$yx(v8xff{L_?34$UBNQZzDBHbX;v6ap@7wWy| zp7(!ayfelbCx*H=d;iv4bAIy+%CIqZ>#OAInD}o5(ggDG+;2PCzMGs~ZKMT+@hO>K z;S*^Cw?U{&Ss?|FF;lJ|_NQ(zv8V)*4y9?g1hD9~*Bg}UAtSt9Th2+7(mF z4=1v$W_He&WxL5;aB5r-5?G)~s#giQ!-fAa3-@L98T&!{noSjAV^x-Avse8x10p$O ze(rqI(SmxTBPD+7M73*4fB`>!bSIShV=+OB7(>((&%M_#UW^&+ve{j`rV5+iz}VDe zVPRpBSop4>ax;O^4_47EUob;liAx|oPL)ff-s(XlP9EfSE}|p%yA^}Q ziybNaWPeWX-ZV%3=Gfp0e-5kCis8vcsgvPjFAJp!S=zV-@n)g1?;<*9*qzH2Kedrv z=E2Nl{1waOG#Q{~YEl&UtYUpA5q58P-EN9uS|C{ltukSID*;S`(<#!g1)iRuyik(M zLq`hdiE&#-tX-w@5zmY~PdgiL@ zMOT~ni82CjD=CZiQ-^Lpy$GGJZQzvmT#ud~W!W8n+>`IEZck7~W@oSG9D|!9@*x-g zd(Gw82Y&M`=wg%YDRtipFaEy#s$9aXGMzNb^**O`iby~Z;KO=FKX)lt0Jf?__FpK$ zgp5&C9rn7rFy4-XzkAphn;H;fLaZSDilwt(JGrUi@0fF;FL=cZK{g z%pju{Z4%IWaeXQ~?qH}^MG&$LMH@bpS;YUy8Q0)my3OP%&r-~6Nvz!w(xUz^g)FyE zqbEvMv4H-B{l&eO%fz)~EG8{Ouaysl#(v`WeRCj^xr=Ruhbc8@~Hybpz+ z-B#z`Ingz{GKN5F(6_ntUCpUloQ(C@X*b1z;r&?~F;_-}j2iE&o3Ur7PKGLN3nt1l z>^Io_!xURf2SjcQvWH*Cr=vvX2O9Cg9+q436dP7!oQ=dbk*f< z-OEzkjLS?~ytK~8*a#{@D1INoG-|@-Ho{w02l*$Z$b-~RMbtc9wqvy5p+7OTlM;XJ zrt;vS-tDFC){xiR&mPr&yI7Dz70Gdro^z+aOvu zuRkYi;6b_(8G_cMeviN(f_-yWRb~nIGsxdvJuu|AX8KZ%?ad8S&4bjyPtvFg_(h8* z?kv0m;gpK+IJoJXdSxlp4l$`bkh!1~YIPp$m@Tu2{8s!z}?h}$~ z@iFmq0=t(!NvFF@oE@49uk4Bs4-X6ve;?cH7xU4uhzfJUTdr%K)f{>v=sJ|>rNd^` zV6vs~ATY+zsb&CHsFY51CP!*e36fG^%Q{s332q62^w6|9fGb_#g@dot7`mF)KMdL5 z;C*B)`1mamn&1M87Xb>qoz|x0ccq}Sh-CKR7@hCA=c#{It$=ZM^WXGPtFT$+6zq9klWNJWz&!s=Ayx$6FZ>U6EsI{1G zPCl9ZMVZfG2Hsr8!;C<2Y_Zz|xS_45hdAX+_{Oiv?1BPX%*YKnY;wgm!QMt|fFb6y zRE>E2=|h)``lGUfbqwYJ(G3KS7?lG?U5r6_cGgh!;5|u;w~Y|40Yt8@J!SQa!|5)M zw%T5fzSs(X5_hqwB@-H?!>s8ABwTH5Qf+T+J%;<<(@7KPYb|Fv&c8j$o^s=EFU%y+ z_!Fl4jR4o-f8VV=>?;j?z)^*H9`QafXlHhsaSt;B=HImlB>@-JuVn{`1Wmi!hlb4P z{_MqA=AzYwmzN+x#U{POKRSo(UBsPdOEA3QCfxf)J-l9Pruk`!<0jdoY}<#ltwEm_ zPF?q8J%mv?cGbeqqP4`&LGTsTW9rkZ+diG5c$V7s$Ewb@z<7`LyB>~7NS>uFT6O6(`oj^ z(qlMH^XAJk(z`paWKO~X-q)_%7pbUXw;h1NFwpWZT1GZJCoVh|0%*R4RQx209@OG;FLyu-G&u4+(`9wt6$| z=$4$#x&z6I<6wEFD0AWTmGnpnf7VArQ7~+)8{CI!8>M=opyxF3%sUMh%9tgHN919n zteN+)eNEBx?HiU1wn=Jk3?bCG6ciMQ2ZV@cqhRoX4B+BVd(3SY{U(zxUylYOIV5_n zL(E9;m1}c-g=C4mm%+9a@!C*M72@a&rj4Vo9tNE5c?pa&fZk(3xNxr+gJ4X_5RR*% zH2qf_drlOv|Fo4+A@dXw=9U#YD>*@RR{PNY=X@b^edb`lLRX!5U}61=VjT8Rey(Xv zePDR;(64Lh-G?$^0>xUh%m3_3>KXB*TeVbzA6DH5L7)I~jC@=hs^N)K2%hgRLTS}r zKO6ahJ>R16>-4j0Hii&chJ!)uJK3nr{Ct|6N3N$@T8i=zFR5aCLTs(S2*ls^gC+Ec zriFgffakCufmEqplJ&9Fmf!%7RJ#m|mUav7&_IqM@&vnlD__O;c%x6ckCQr`E`k~3 z?rA?vWKJI{{2Iaib1ylo5o z++&+l+3kIOk$l=o-@iW^b!NgKM|H&hIR*wj60pwfv6qJLhkw0>b2kuv)kwwOju@Ee zf%jX_{08#d9p3&SLo`BrYT&#@Y3RX_PT zo38thwgj!iFe4tN*lf-nV&%>DxDv(O++50-m>5~k;n|fdh=Hf$m>QBSYhIXFN=rO{$TPum_OF=>}^AiPU&yMsF^I(=Y8G`Kkce+rv> zze$;@^}9S*zX`N4H(tp<^U`G8X4McbtaE=wWoOvkrMJ}2)J8q8QP}$W6jM=7k4*VD z)J!}%!xu->vAd)M#Gb$F9AV741{;CY_fEr(mwDl!ec+F#dso>PJ(ludEJlpGY>W{G zDBf>oejHbh${3amrNQRAPT#tBJxt&!a}RY}+nb!8A^_g5LlddbXK~lOj>{VGiCAah 
z!*5Xf!+aHPM%crYl#YPa4PQm?n5S3ctv?7b+~?)as?z=s-D$fgyn)6@Rt9#ZYwkVQ zD?$uEJ%sP#_hQn1TQ!3j7smLmqU(nJNvdi-r2RZ{-f~5{xZ>5>8My{Jd4n>2t?paU zr`#&hrzWUG#Q`EAZq;3(&e|#p08LKWNxg|-GO@Dw-Cy4u@02R8LiQ>Q_5)qb9hN2 zJrBd4^60p+I&n~{o}%fTad%pEQhd09n5f!wwG$4{Bh^;vw(9RabM9Y1E`ZunC&LSv z=(;P2^eYjt>kAe3rDk3NFA2C%HM&{-fN#GJ3LF^*rpM~5IjLRqSTOtlyuJmNnOKM~ z9et8)R?S$xM1TT$1q21BJxvM@*Ni$lhjS&Iyzh3Tjp@1gvV~f`Wswa{c=!^(Cs+PT z@BZ90)j#@b^IU^D9|>hjV_Z1GDMhS|TWz)&rM^(FKbGy)o4c-kSxrUlF?67^_gdVy z45Gyfv)>zN$ac>QU;EcX4%LUz#2E*IscyD02jC^-&KCc8cpcOOgqY%;)&5cExd4%@ zhYwj$^+MH2@BH~NFwo|vCbj@Iq8IF}mVixb35)CXA2JoU`O|$1x`%JrrrkV5XT9h2 zic!q$5!TZvDN}%yyap)H9~Ta&s&u^LSOdDy`SV)?M40CWt8;f+!eo!?G;i+epQlgc zUHPEf?amb5KCT3Q#m%niTbte@AXAu>gWr?aO;#iRHE+^gsBE^@cc#~!_Br)(+X;b! zsVOP^(1Z;+qlAN;&vcUxHV>{bhX>~Crt@mYYv5K<`3HY%ktvm<<1cI2WoQ}YucXEe zAHvGzy&Ru?#u~`2%^uLTCY!^1T#xb_kw9eGu##3_iD$Zf$V^@sB|o{2@U7Wzcglx; z03yY(s5SaLoXMlXTl%AJLg2N$0^5sb^yAHJOruE$CF`X!W|{78z%RbC@nLdp>t<@? z$Zvgxkcddw)2qz+Kjr9$+^;B6M9*3=ov2@Q+cL{h?PEesIXK}2!fi&5ao(KHu>}91 zXRlv}HCqol`r2&PiTMW=4g0fVoIB579E}F7U?d;Vf#z$hkjVkH)Eu1`JWX&!fsk)j_jQA!52RUI=NiNmK##)Q;H|*1?yJO0OM;IZ6s=U$PXRZTmE`T! z)zxUFoQ~VDb$EPE5I9yg%d%{gkZW_i(kkdOt_i3=b?WRkQZ##pU@+d1Tj z&!Yh=nh7VvlHo_WvU4~UCnW}Wa;>L+!h-^;ad?$ioW-f~yZ9l;d1-G)m;e0f3jB*u zRj%6c2V|qkNlDeyH@InUA7qKKS-yAoZsmk47fTE`ZR3Dq$Kt(DQ zeK!(-z}eDa#KVElt?bO_oDg$<==Ev(NmMdEPX-Pi;$qX`O8e}4;KEE{&E<>KF;{OL zS{0=AvUip>03?DN`&lWypRh<@L*Y3O-IwUP*R>e|9#ZeRs5%@}MS@&<2DNU+1s_lm zvd&O^H1O*LMZ9(8#96(uLY^li@4_T2y4j=&o+!!R0M6I_UR!E+%HkjBYygJb`wF5 z6`slqo^)TKG3>ElR|9tj4yZn^7%Dpx=97x~(V_C;%*4b500nl}uQ#tN#v$3vfnQfg zOmcqwu$d}J+x2rCp8a;ylr)I<@cm#<;Nf3Jkt66Pg&DGAZj{On5;*S57g<&<)kQI8 zKYe;=(sCQTksB8M*W~wSt03s7rUIpwyt4`m>0udpt6}`u#9e(9*Ur9t^e<@+c>*Di z!ID;Sa)9D9U%fhfUHWq7g@KJmEf-(O>e+XDd7j(}AgPgRYHsc@r!vc_E`_gVA@2@D zKLthb} z!r~F#9_DdHY?XyL|0IBgGbD;%izoBVe_ajg1So}ryLUw(;=d#cn4UiU88emBJ=BjV zd^i-m_MZREVorm9X((B}kMpG%GU>1|WyHQwh2;eJPTxif_4oF@!O+Wpjif$Z2oB!c zQ<}=b)VP|Qen)oUd7HfPv8KiJu3J8tIPeLo0EE+guTq-=kVbDAdFGl^a9c@NOg@Do z&TgZp+PT-px`rMXRbAo~)`K*jisX(@yVD#VU};d&wW)ccb*!i>Osb(|^m9k_DR>v$ zGs@aS&rN}_K&`rL_p#*48g$!Z;1Cvgnd#Wvo0=u)HSx^yWNasg!8;xHIg}V#!KbGZ z@l(De^Sx)0apBQ!6Fi-cv&P0z0lUA_L1;9Koh@=2PRl>uP)>Ur$yic{ zTx`MDYbVy|{CNWdgXJv+^>0$}TpdznYEzr#O6)o;OVhtCrmnO>@unN^S&K_@m0Qo6 zPuWmf){VK56sbh=IvL6WzmdLAfj?rViD|+gMe$ti+f%-2}i=E4e=yS`RuAB%V z`rlD}1XL4eAu$P(Jx0O6p&l7Y-Z>*j0?Kvesn}xW^(rb?hK7CBNzxDqX*OGt+xuqw zB@EF)-1%!B(fPG)DydynV{KV-MRh4{Vnc$Xh%5Om{Q`n;VV z7W65nLe2M=zE5G((wR+?N7*g2D;qaV=L*T3l^0TekFb0ZZ%IV15C^S2q+OqUDDF{L zKo+1$nLqlK#4K8Pv{hh6@t+VnP!3+acoCcz%@EuyR>6igoIfX;k^2H%{@k!4nsCho zAPbqTf&SqZ*xX-YhqyV4v#Po`==PRF2`u?(E@9nxQ_SYC-DO{5_u*&bDcQ3S$@@KC zet0wdojsZwmvjEiSbpbadx~kJokBh#QTfkNx9L5|vY;Hn-YhYcdTb=3UFGYyZ=su; z;QLTu-Fy}TNEWo8oYT<%cnFbL3_gW2%nIx>3wAU}zav&%+@pT}GV*pcMtAgo9#nmSVFQZv6@nY0;$ z`0oSe@2v^=c+0?7Qh3;0gbgKPQ0WRYH_1BlUcK_ zR6ojAwc-td1B;XJk;9~Xy8RG-vI&`D0pH-E((7kkw|CACoZsP8!8Opp^?u-5aQzsa zuKQj?(WjPYvHabkSVw)_Ev3!Tw-$MyxG!adv&!lAW5hF1o zr?=0#(g5%4N^MGJgp9GI$BPvn{Nx=BH~p!iUZEMZIgfW3dKA;~^3fe$p!UVHTg$Qz zA5OR%G})zz9gflCU3>89c~fO{yQuGn7QBkoFy$t(e|#v9tO+WSU=8uTsH5sic7kCY zdi?x(gsmWFPG=}K57=ii&Yk{)bqdh|pC(!>o6nNpvpidFS`a!Td}^Yf-4^(pbGfnklD-LGz=ZAXxs;cRPv_YjNvkv5Z9K zSQXsYAO3S+TMna02e4l`afC$!7Sm}^HeVU$j|DcQ1J{RIbp+&j;$qLZ8wXFe9d&^f;fNZQVyG5xFk-)fo+;wR3lL8bDeQuZ90WA*aoB``8}ktX zyU{PySR3v7*xWt&td|R4?A&7Z7zQGI`%5$h==crxyk;ZgK3Wk6)%73nx}w8Wn$?@~ zo=xv^&65LV&qPG#bRQ!ffItP{ndK0M0tN}7Hw=Hlp%4n!2oC1q!-qfFS)c{TIEYX= z;sm?5R&!G?d}#WYrp#W`DrjyK=Sp-r;2-KOLrlN#!96^$ja7+X^`zEXzO3n?{rYyZ z;oh@i`(Q>*a=}>D{nYP?WWGgbR#tfqsfS0do{Nk1EFHwO%^yDAL;;Q(&8}5qW+xFL 
z=@aRn`||j&25F!x@6yx8w0a&rCFI*MfTKAvy+NgdYh^TJNbR&}<0D8nY^Zf^)G0Fm z>Cv!;3lVRg>NL-=PU|^OdXI!3QFeQ-yeyMT3|CM;t%rXqy>;b#E@$ zyrK-TYqMwp6=bFw)lrql)j#51H_({)lbtJNc+$^XM=|VD5#l#xpnv5oJ+qp2Dzig@ zn}C(BPF00VKke70S|~6*_dd@EdqRP^MZ|wzo{Mq4B7~6fpN}{c@zI(Hv=p(kv#Y!- zs88t)L1n%e3`k>c>@SE}-h?IuqYat_l+VQD_{oEq*Yo2b5flpvXyUXGRvduJ*0^vA z?b2d~50qxfrI{S4m)9Y=<+oWTO(Hcra9^p9BBi<-cyfROrSr9Z04f0m23>x?G8#y^ zmH>-JN+5voHYgq1pm2kC57D{jcRMqwDM&xdN=5;p68HLH!QJV6Q{u9W$nQlTXcJmB16dd$anzJ28w!iZ$SHUF)ctiPWVcr9JZg|=P5d>6uO-1*<(tC!X z(8^Nt2#!7g3w;X(HU`Wyys|4tQQ85S6PXOTM&zodE%22`h(h;(I{{TnygZ`+U;~?YRuPcN@{eO|GXWd5&wUicp2=XCW zqVYtvkIz<(R^P|K!M_#x=*c3l)5nzyh_<*69Ajiuei@?jrppToB+<@ceYL0}2E43{ znE2rJu@bcNpqazuFTSVQNVfAP$fOD&yp?Md5j*2HlDd3At3bm^ale~5!j2)$#@G=D z@$FY0eEDqe3wyPstZ=Dom_Gj!JcVA~?qEXbn-Q51JSc9z*|tj?zD7GR*m;x=27e+< zf(aNrHt;Z@QweWsIu(WJ6DE>lB5gb=B7(bh(n_~K*{)%H9E?VPvpTd}mt|6{K)Vzh z!pm5n{T6DG-I+1$`OEf{){MLLo99hD?lWpqj{0z*QKX4NTPYj_(G4!Z@U|%|9ds-n zT-Yf;l%O@jI=oJ}3u@l)@(S>hSx*`~s&5^6sFX{M+g?b;^?pH&d48eBF|c&iwPSpm zSGFnBtHJPs??5*!_&AK|>u282rH-BOrt_UROZ7C`neH@@)Sz4=tso+WN?3iO22RLF z&4!$Fg1#V!s|BOt2a}()I*vdmmsM8A_M)jU{MOGHSR8yG5S#{ z$Q?ba@PPG}(LFa~H5M3l>WNAdjr=AknHf39s!n#E5s)McIP&Iyu;<%nI*ypQeL0GR z;1bqNcOBz&4&B*D1mXVuI8t#RF9-m)%Nc=bZuRC@$(0k*87w6*QO{|>KPte6l$U&arGwhhE1Mu{Rd839&cY^@buIN$(A$Qa zhTHr7YuQ_?Q-d%nwZT3ItuK;4r!4cqg@HB;A|S!2Ih^$~J5b@!Wc&n~Ym}bmbN0ySM zweYqgLI}s4ZN8eoonxIn$(In31k^+B@)4*&dK_9eL!KcQ3XXy`6df7EVHFUR1jcHH z0#D`|$_M}sZ^hQ&qBHP_FVUbHaGVyH9^Wttq$x`8Ij*V$Mj@VMni3ij`8*h-zSHtfflq{`*U4$v}+=IA4EetsF)z(FLiGMR?lBYrP+x8J|5;9H$ zM?k=V#Tl_LXy1AH=zyvP zb`^gYAQW~z#g}_;Tvhv}?=AHZoPV@3xd*f45o%IYu|omO zI6)BH`@@tze%*ZQVdrIvQ>w&v4=?;>2FwU%YRGeJ9W97Dv%xU@(=D4jUU2SW2=W>bQxS2CN2ujW1Fe^mJpk@$$MX5_@9?L z5EspEv}-6Tbxr=_0yiXN@B>FDy;JNYErf2TH!ne@Cr$gqp;OM(vE= z^B)84sa4=p3)fTfBKa5NjfIyF_a#-9FJ4iQCYkImk+6(r^$1(kMR(J)H{jlZ*!u;} z$iqZK1m$6SW5M{`CyvfYn#QYQNYI6)S6#-7$f!{ee>x!BFw=*xQKuHi_%9?Psz7! z)#$|Rqq*~2JGvtAeO>O7F-EUt)K*n<>goOy>mt2!fLe$x5mqR4FKZ669J%lIPTBh% z$%z}K)8j0S9{9;anAbZUH&VVoQ)#@btRnKj0yQ{wl$|mhi8G%Yr&p8RPMVHP=;$3b z%FlA&qAolB5Ylt{=ZhI)9+N+O%0*3_0WmxyMQ`?Df_Gy3q{HNm z@p_Mgu=t)D)d8&>QXf;PU=31Gm2g*E={}z5AS*vHB+K7OI%F3;`n*u@9bZAWp@PTn zI#J&+z&j)Sk88!`xOz8K_lf4@Y#<`Q!ptm>5z%TvM|$?c1siv-kVv$hej)Jw_uKe( z63L)RxZAI7_FS^t7lhCZKXBCFjzEHtrw+yWVvP$CW9&pHTju#+s1ozu5j@^5C*B)3k z793QC{$l&UAJy1>C65>L2hc~KFvDUjqRg4pFDgziVh|87`Fl8r2~9O zS{AU_tZj6YAkLtq-`w|HPihl?sf>mWhgF)KYEb10}3SBQdZO1bMZB7IG{gKw7*OHLgx9tIl#{*V+6S~<{J7k^LGb_OZvSeqJESi?B1 zlByua7P%1-=_+9HgR2gf1KD64ZG%sqyiLyqR}r|B{O05z^_%k=k0Q=nCpcJQGa`Q? 
zl?wF6Dv7U*J{hhWsJT$IPRp|XZgqZ|Hohu97ZDvr=`#N98YA6vagTqvo2aR+kIPdV?-lLuwC|f&9bKi>w>){r z;jA`gnEkJnEdz-!m3OXJ7q70^l)aNB^vi1kx?oP6>u9-tghiJZBkd>#xPrX@yhgQu z2ppm%!2~u>%%FGSULO0iL7W`+oKOKZ@@*&tvZ~;RFaEPxFY&RV_GZhZa@260G&YTi zB-?GQ2=nVXqHj6PGECB`%#nCyvHNObNo8|N*819}K=t$AW0!%Pb7Q6%xtr(ri$eDh zb?w$az(kO*K^}Q#&wB?};I~Nm0;sP%h22NS-QlUO`=mvf5h0DK7WW)50JYZJ%tHVB zj(6<&kSmP6wdQ>mZExRBygGBDlM7q@)PvMsZzW=Lx-{BYT6Z6ENPUadruuc|Zr)P2 z<$iiVlZ^Hdz3u}^0wn814iSh&>#-H2yaRrrDDndsbbtt)j!pznM&=KVZ?2SGBdsCk zf7Qg}oPjg!me&ra^7@d#t)q2r_yQM=s>O6HKJ@=pv39Gco79>Nd{@LdDFxrI$QbOG z2ArD2qkjf)KpXBqPJGL zNG|_$;`P}tfgr~FD*q&!b@d&ue|zO0)j!@ z`u!ap30M57Cy_Rc0L?Z$V(nR;@?dH|)X~t1`x0or?CJ#~gHMxm_7xaDw^3g=mm509 z(VNqq>!dr{zF)SiS18;Xi@(D-c3zHa?O+6g_t3-ehiEO^%uxRUTuaZInZ-gr40nQ- z1#Ii5gE`x<8u*xqnjjQE-EHC)b3T(&D3RBFOABn%7~C93zVaAfuJw4J--@O|_gjq;sZtv;2<$R1huR9939iW;cKqL?-Q?JNfDyJO@Gr$rd zOcjRbzx;GoufIlK&~e&%Q)I?D;zj2^aE*N-HKI5(xzp#I-O9@x>12{;%P!q;1C47% zMSO1-Q|O9>M$0bAw6M*Fn@;cKy!!Y1nlK)RwbqFI@CIf1uhd0DcG65svRN+$%D4vo zrQtGlJRCPi3mVmZ`PusvB@g>w9k+a4W!_wD{4Bd5=|#EapWZz%%V>mH#`eFK*ZHa* z2L-&2f!&;h-Il@boaE9I2}b~jDdJu)?r|Wh)ZcFkF;0IXGlO;&ByhG<1Txt|vFbq> z#&6~N31kl{j>;8ZtFZZUi5FydPC)bS&91S0xDP#7*9*{Y+0VVys(gO*jLDIqkaM<| zE1si6!7XF}+~t=qUvf1|q6Dbq*^yZq5VV8Qvs%(bv^y@E|ANGiOjU(PyRV&QzBqNu z*PCiACB5(E`EZbQR~xzILhtun6wO=TLl$dOz7m)XvDF2KqXHxtKpS|Khkd;XUo8BC z(K?zyxo4=FZ7{3W)5oKdg&}^ji5s!LE~&(&$mgc1Pi3d`eAwwkqxEpUXSKwb6_hvL zUf+Fqb%D$yoZ8FnI@_;qs4y}Gn{g~x8YUucN;I_+(eEEO*Gz#SBTt51Qa|m9Ajo?G zq9Mcf+gWuy0=L{^5=9)5VurT2p3lGVt3@4L-la;oX7)ixW6J`1kcZj+tNo#J)8D{3ON^41Eg*#&6WG^xDT;?> z9Yhhlw1SbpcX1{zlIJcXsLZPMN;AJ=Unn@rWNMal)!b9eJ}bGR1TtfEaa6EGUSG%A z>Mr8P#+|2mW&cI=E~R@TLT}Swx~fWFO;Sb#_rHe#`GK6mc>}1_OC91MHA>mX4U>SG zgUynMXie6RI5}Rs9i3?M%RfR)6Rboc&QvBv;3uP7DENUr22gXa0Nd5^cC@{3@Nw=FU^M6*m{9zC2L4tnebPXpzOw9`6mN+Ag%NbjczE-I zXbpP0=y}>np_>)IssF-YR+CtC3rAh(uH##Z2e=5les$@MLO62RP<_8qP;%*t9E`is z8vV3AMd~x>_#Vl(#JsKlFnYE)R}2kX2}Y*{G2q`UGroxS0sx85=9XF%ae8Gs0dF2? 
zL2NFyG~kr`$@4y3mPC&mJZp%ZWC5y5>e*mWZ4Mx3x_u+X?>pnD{AP0e-ED^?7nMH> zjC<{Dr&zqaP4V*_VvqOdC|yeqVzx9{cTFWSOR>N526dF``1cJk2#uB6Rq|Q2$5cYi z!z4EsRwRQ_Aw9Wb3EqhtDbwTiMY}|!aRK-8yE5FC!C$O#p_c7QDKzQgM zdx4AV%UV3qJ!d)^V*ayW;r!`PTpu93C_k|f_Xj^`VPx=Pc{5Z1DjExN4h`8 zGnf4qLX%eA=M)WUo_*RU`W}&~#COx1?(T*3HRq^46@0EfoN)WTTbMRXAq527L-s$5 z31O9^W?z%-Q+R5L%|*(@QuN5tq~)nqT7hzxcwS~G0?=k`T<5Wiv2_ZSyn2fwS!e0W z>Dym|wWwnM^-{vG;wBnhK|#n8Q$6p&s6mkOtVd#!^v++0i~$!4iqW`5l`L!+bar8j z(Anv2Xh3I2waA8BCq;^RW97r9`|7p@xUb0N=Bn3ON8EJr3L}`MUB&&^(MVpy9XWr| zGz^+r5^(YTJbhDnXIURCBZxpZsFmKz6{@19h6hc#n(uPECeo-OX+CfZz*|D78$6YH zhu~zP%+YmF@<#xZv}XDSN*zZwwax(1u1>L950<#EBd?jFA#XAWItXr^tN{;XBW>m$t=vhf2HsvKD;m77oc z0FMF-3zl)vludiu0|OZmv^M?HLX3cSQ@%6r#6wN|xF0Z21Zsno#3%^zt6%^i&Fl85 zbc38s44MSSuwEr107S!hR1LQOVStlzAi9OYy*qMHSx+&s0G&FNLBidHhgoG`HthGv z@y?tza?ASmwdQop6H8|jFjJJS+woc+mwsh`(wbeeR?xC5ljD)$M5qb=?>Q-}>U7H* z#&gh}5t@l5M8(O95=A|zGvryK{0boQ68ZTZW~hM!;(|8CO?L%R)*6!Vkuyu;!*9LV zGR6iv!NQ$Ty-=Sfp)xkuzLT%d5+a-aVjU|#Z&asTVOo7J!Ur3nFYJ7-Fscr5pg??ikc6wY zifT3;maE7j4@p{~+uNk!GD?y_YBI231=k9|?!u^X0mLK5$L#^(9Rvyz)l3;$nGoYZ z4~7^ch;|Ev^DX*|;QR*ky$~2L34xM{^JRyTrpF`q#1P$9D=^Twofw0ts(WOe+ETh6O=<^ezayw!DB^3OG1}sL3sG zVRp0x{W?-d|F<>%O<&8%%EFa4zAQR&=19EwBxy+70Wh(|C|;f^zT|ZPEH1)m`V(M* z=DyElb9?qYAER=;DK^K`;IiJGaKEB8=jXI79Fz6EmV?eO>Wr5`^}OF1|K+NmzG?Pt z96OR^>{wOUZ5D-yK>PmEt`n!GN@6=-UbFMN%DTm#OA94ZG%&;R3doZrJ)c13DL}lpTJZPMPERalV$so!L!4%P(`qKIhZh=Yon7x8cuTpG@!4qI$<90SUlr`FeZwt~c3he3 z$wy9xSI!gPTm!ZfX6hL+r{8irk%HsT?~TZv6bcm!Fham6^zoL`Zn34#GcY~O03K4m z8={%^l(-A`_4fl^voGg9MUBc~hLGlo;KPi2r>U}38y>xlO(0jLdQDVYPF4Ej;+6hp zV!vfKl^b8)xqW+tx0nE%&lES~F7lDiR5OL{@0wc8q6<3Aqm@r0yv4uDFG(l;9nt7*p1C|z{M}psM{Ur%02U|f8E%+ z#P`iUgSN#&Qn=n@({K4o6RC6;=k^JYq&Yiy`wjR?;xDZ_%Ip@t#(&*7SA@|t=y3+j zJz7L?Zbf@Dg9$w~xV7kuCFHp&OFtdu{8)CjZPj0@Xu ztO(@Ghw&b_6m=_y6LHoxerqx%R5Xhu=zSrAqXkLFmk#~;$Z zR-QvoiV_$85uI6=bR=aiZ1q!Vh*j8{8F--Cab&KP!V#W{uWBivK{Exq+;Ecr& zgiAoo_ezyp`5`yIvo%oO{XE`-&<3XspUCG`b96&FWvzI0GG9j$u~yu% zBSgw!-!^S7=FB0pvv}b%D%#?|hX0rQezGA#5&d|fnNQgLURUAc2i6zkwaQRcl^CvO z)mlpjNeL0m2$M&%je;DWz%BLXj+<`B5ja!CT(E9?piAJLx>}#^Q|Fg@L`S>f%dA0h zH;IK3#oa#<9q#2t_#gq9M@ot=_ZeTO=P$0oKg&eJRA{%7yof%?y8WNj_w*eP^95XJ_g2QSTHZy2%OyK zg@MKb2g7mmI}L2r|BrwT=AEqt^?~Q{!V6oocqbAr>vAxitGT*CT~$e@A ztY;{eE=|70G3S`9zIv4MnQH@)QrL~u&+8v-nF#kygJcZOTPp#z{qJK9Ed<-(}cZSX-S+`)=UI9-Z z2LmY}gB(SWNVc5b+|(qw+0@hoPxq?v=y3m#<(yCK!m}3-S14}2aXM?1B;|NWasCeP zbl_x*CTH7Dh4rFtt!rnF04W`xkLDL&&<HafQl z`=J@@2*`c;vVT!l+G)YRY7^A&Y)fn(+cyN~%kOz{H&OjC*+mveM^$+HPP1Q(JCe1N z>SP7`=2pAAz6*KNW9Fh^N0|eL4I+Cv&yUU9=$^pp-or0HA{n`*D_Es*po!9OoAva+ zhi>rbw)^_{Bu4Iim>I;Ezyn%dHjWj~>3DFyOn&BBHU@(z!@x2wFe_b@g+|T$D4w%9 zj9%2lRyrwsS=$yZ6A|J!;3}tX%y>S2t*Q{As^VEO;4piG`;^j}INe3Z?!8;Hv}umZ z(9;;t;(G9Phj8GSMa+S|jC}B}b7S6F`l=>Gm7E4g6un9pk_ko_cANbG! 
zIC|O7CSgnrUzv^or8ZfVkPQtuu1eAMJ)j2@R%;z6%MH>B(_xqCEMN=r ztRP63x>Qglz3g?DZSSC!&()UtcJh(Bd9&@2?1c5T_q4S!pN%)Z7YHn*y~Df6`P6di z4L05?clGx;-II^GuL*t#$0i@SVK8Xb1N}om>`joKflLkS)dz#ogHC10*kJRS%f|%5 za3FrYn+3D}#mOJI^+3KrhKE0v%vID4OhVHF84U64B=j3}=DhIW5j<{ZD*r%M$hH!g z#ECO4jZEQ~p@I4O7~vCR`6AZDS07w|%C-l<@mBD=%mTJsMRtvQo*yCa5+hdH!fIm({8G(iTj&Vp(!*7%5 zSp0rRcgQKa0;-vF6#*l1ut328$UFn-W+7>(9(h%LPFLo{I78EE0aNeJ`Za1Ada=wp z#3(^LGu++n$izjaj+|3Ya|74=1^yg1upwgsO;1;a(lrNenQ4Z0W2fE;8an6{Q*5k ziOJ2@-!8h1e!C3VUZLm{d|81wBzvaF#FF~N|C{nlO&^t@$eV%>t0-=PT7t_qpuuuUh7-sc) zj2M0Y7WS`xRydM5aNlDBdGj%7N-hBc4nrB}SLptDPR+{m4B;vk^bSCXgvP$n`obkS z)pLqV>@q~FTn#H5Wez0FzgV@G&b%cel;mXh_(WKKggewtNX=PjYulJ!-eA~E>)Pt+ z<9*8is$4C1aYx`xuJ!wam!oH?O)9vMgWDaTBM_AUHm~n|5!G1(h!CMPSC6=#=IM<^ zZk5;f4i6sM_ct(m#uu?V{fI`VTB`rl8|NHG!33jqeZC<5h%kw_Ld59~%5gF;|KRgZ zPc;zcryi0bNKgsWn5{^-XlNxENm6e1wea5+77}vaS?Qic))hh=`{weLZVl1%D$b3C zVAucEp`etM55-eFUs&0oEeV}0+FJVR!FlDx5B@Mo{4-3nbsC*G2FI3n%FPNkGC<7U z81}IZG*s{xJt2U$Acyv?9|n;v`N>3}eGJdjo%{`?qOJ!O7P#hrA3E!gU zme(>Vh($k8p76CT(ZO;CCkn2udN-JU!I%wB*SYEVhrXK|8_a*A_xS}$Mw8#jU^Jg5 zGu>j(qdCcao{4%$2>hT(Ts^+n*~$Kvj{|BhhXbzaj|6LGWEq%l%8#7K6?qWB8$ZGB zsQZ9He*;X=02{|UzCC|Be6d*eW8aM$>dOAoK`$F)l1qubMMEUz$3)pXV_k*}iY60G zYcjm9dGAtX9R3TG|J`Eor33=BqYORdo8Dqay-e_Y0dVC)#PK6`kfd{l7|+`dPhHY^ zDa1|S!z?@g@E(l?@7Mnl5Mnt})rc)_rX2WHl$0fy=GV#!(v-x99&APrds<&gHKtt? z$Q@POBWm;i_lbg?4rl{U2!}u60FSMEr>bn;VzH9k2gD~!`2V=$O1lqujR(+}xyT!lUKKQ-NCcF4sJ2X4i`^tquw z%#=TxZhQ)B)b^nBqD{B>Tgh8@A~X)JecuTJh(0F7er@fDuN?9J`;|t(g}8IF8_}+0 zqeEdqmGGPEwlfk}|HJ+?_&oY4Uzy!AKT&EsrsEv*@q0$zFpd6F>+2_mLqCF5Y}{M3 zr;g=F-eTsdezS7pCUq}iCv33Y*Ot>&_D{z<^Q2sy9^*vbe2kX*kxCMgAwXA3hN*dS zELm8p(+6(Xc1@Nb)5h!o9L>&T@Qg85t=Es?GyFCel&?G5%EfwL0?xiLoc*jx9N;7G zVinDm|30fOV~K~ z+S|S6eA>@zOtP)Ni!tt{+49p-Oib?5|Wl< zh%xC~1Cx*@U%sAImq70O@M-3gCj<{|hBl4WLwYS*!5^K>+w0ErI%SbJEc1*0cqd*7 zqTRh}BS}E`7)zSt0bH+&&M;#ojv!@+XDV)zZk^G33%vebQ=$VN>;Dg1-yKhN|Neih zkc5b`H)WI}dy^2d%a%k|_9!c47DC7>6e48r9a3h<-m+)*{#|eC^ZR`7`*Z&4emvZq z^FFU}UC-+|th#n~D-8n{sIcl=-t<1axB%DKjlWME<|=G~=NqVkgLMf(l4)hb2-ZJr z2&1iQ@a=rllrc_)2eA*SY*C-D+u38`YTRt$|KI6lD`%bFTN^r{8a2PPXX8Ev9i|{@ zZgM^i%Q1C>UZ6OM6~zf@B#6xjM=8ek9QdoeLZkf(F7+RGrEEDgOjQhB(~>W(X5fENJI+^ z0YMOCcbDoc5zz*i2R(5nmP+dGvN4KHxH;KbhGs8pwm|O0Kj^1RARCs$kOKLg@;T)# zSkId05@b4T9F1}G$%(u;$nZ>99n&gTL%SB}k2P*Y$|68Nj}v~nDIOBgI00pbFt|wI z(qV$D8GyRx687G)$T@AHqN%B2yXr#!?Fz_4CGiPAes(cB zP4!A?SsWCVx-Y}f!P{Gm#)g`dduO#OUhlosRKAG%Ilf1x7sxcQ#)>HTgJk|(z~5>J z5aUJ|wA0ByL{S6Xe`r%a3=>WRutFFEhuRu&kbDUq3U*!;lOQ%dgOSUAGA6_KjU)a; zWS)BYk%ee?X8M+uM?r?NW3^55<5e&BQ+irslDFH~_hkwBD(M4*GB7U8b=i@JOtN~9 zQwbVm+i+bx!!M(D>u)1@;+ZuHd>hO9&s`Yt+IKb8)T$p&Fqld0+<-i~6DL?+%iI5% z&i8kxM2;V|sB-U*5%wAw2DW8vpM+4k9(q z?=#baQB|A|sUN8p_Ka+21NS*g9L@@AEd%w&5jw-?)7Q%H*kZ@<^4P4QP;Y`{rNq@@ z^PIv&zr%CN`9@Q3U+rCVz)q5h1#>}; z7hg$nP#Kn~qb?(hy(za*DsLKp7KcUiT=q64EA1)VvZ5h)BDGli&VuO@QE3M ztO%HQAn`Z>#^*R{4z`-S!BKL$Gae9)@loKU~2zYNITpVsbOAH#5N+Hfmh zF2<4gHrUsbe1XsI0inr@Gc}wvql_gmQ%wVZWyHx+Gg^;pQ_{&dT~Yb)R@MaerSXm0SP!%mC$PZd}nng z0=(x*|JH&CX`<=zWf+Un0R9SbR9?sw1l~x~%+}rk!pgMC>48~;!0fDXMglAo)=>PCy` zZ*e(uFE3D+VELAm;r)K80i)T9|OeLcZjK^GBNwKuQ}&=2~FqC|M2; zI?U40Z3rjI6oO0qHRDe{$W-Nuo>eYgpqW87K`-NlXcg+lPce>X>|%;`{Ew|rbywaZ zHmrz};^B!|B+B~RrBn!3L((Ep5I2{Q3SUhc?_W-q3ynNkb%tPtvsEmCILBTgPgBXo zsIES)Z^^%eBP*JgUfmhgk$ippZJvG$yjmzIVNR3W*vX(?z|j2r_itBUXk&zFp=O7y z35&LQK3^0mPuc2iBixwqzM_oPkJ81~_UDMr8!)bxY)`{=$=}t*g-AX9{dd2GVFfAj zjbWmyA)V=oz3a3MXc51kB#8hUp|Q{Zwu%KmQYBmK!l)161?~9lRhGKdRQ&Vz>Z@Y9 z{EC@f&%4`7*((+^ z;K{JjB-tPPQ2!=vgs3W5f?y>!X!GZjHicdEBxb%{jo@MVMMXxt>$A?lJm?r4jI+57 z-Y1CF1(MAP<~NPN$7Vq|i|q94jQQ0u<>!xXNUQBX>?YiQvhn%TxcK03u 
ziMx&D5$;QdYe61?R1nzphv}HcdhZ&>=QZVFgRA~xw4n)?XSepoR-IuYKzQaw*YqlLTdr2^!nlO zJ?}G^=m4AFYotl5P~{RbSCCSB0hZd%3q4;KnJNCyWo}ZeD>v3PgTcpROA|*w*l}Z@ zaXYwzIBdN65hXUm2@~x>eb=97Y}*7}V(8v8FGPYW{NDNl$=S<0a8AY2Z6#0ns}X?N#~)03iW2X&x7G?nM!ulw zp}LE133ZP>%?E!S1Sb|fa5_)2h11yxzKTfUKqh|!jn3Sm2p5M4-|C4g?~9d4_qQ@Q zql*uDNiEoFmZUaK6bgnhIX|8?@DdID%2Fln`rRu_5lKybTXWA%qiSO2Jg`8iF#43O zmJl&1yj&VCdIH4>)qZ@c0X<|(JKni82lFM@Qw{ysYynVA@D;M3y@2Zzc|_7gmF+}D zD>y2;RP>o(Pv`jefyelB3(=L}wUggh|Lq650nlqe&3XwDMwK}&-()A1bRs^-BwJgy zTxova$2;W2?V+y3EWB_7_@0R_dE>ZLb3j<{!aJ=dj@KD15ulH`mvWio9uj((eC z%bDeS`hXD1Ny(wSEx%W$;s#9)?cU~{jZLkF2zwztyRgk>Jz(4u33O@V)cR5Vz@e20 z|HgkEJwWPS0RxO`M6ZVOgfu)^mUtf^2^Q=0RSBIVf_MaQm4?c8U~hewiVF+Ch5Pqn z@zCGUv=@_LK=E9Aj{!PYU0zAyAO6*0=7v%fkW1=xW?yyLZ~U6m1itJB5a^JAf^dQ3 z1(I}_J6V)~7^I`F!chiM>;gaAcXQUC8O*OhZ7i=o()Oc zTF7fd@~)~59Ff!KHuITG@Ku+7vQg^KY*lH&KvgR5Yc@DX6H@HX!U|I-{pW(2iMm-@Hz%jHmVvNA~P=}l!6oj=sQ*p z7o@y>`}h%i!chUSpp9)F8!Bg;GJ-2hD-7u-q=$?+w>g6hH53GImICr6l5^kyPTjZ$ zjGdOj-!86DAK=}}&3t?mDsn=zQlJVyjrUi=_LtBI_KgshSqnu8Am>0PFq6iyN5v6I zKH$6#b@53wtI&20iu?cX;SUcsl-3)$GE%4}%0FTxfKT}dj)yCYXVyKSR3PAZ;cUpr zx8p1orax33IH<)4FzuU`S+dYd|`K3St+*tU`*P-#x{-cR_}HugXE%GyKHcK ztdU_ggu@svK$`?icOgSEch}w}4ALVX@F?74LxL=xk{vtS-7#i5i zR1_Mz>(vk~oRR(GP3`mU#k4CS$z=qk07CCraE`|wmK?_&X^&BE{2D&O*p*a4bd3Y$ zQ^E3BL4V~Ae?U2%VLZPcV1)^1^XX`#zEAEB2c)P6UrWKo3ZXDkC=*bZy)x z4}O)p`Ri@F(;qhWqrXQV7V|PzHA0pfaHqort0ED^2^L#8icz*V*H}JFf`OjA9k{~l zD-FRWPY4WDAt~Cfh3y79_*A-2y(V~rg?~7n}TT>5NNVAvuzAc#{F-*XUethIF5=LnZDx)AXUe0S| z|5(LHKtrlRY~(T>s~@ARz>L5PsR+T6R&evjziA<(@$_!^%EAp%(xu<@h7X40NMsGP zq%p}fu&;jUIM8~w_e0{uRc5<2Yi)->yeZ--gbA-N(saoDy-ZbB$NoC!BCGhgmglTi zE9O>^1Os|mSzNW(Bh*qmk>BGwYh&rgTP|ASQf)6g0u5^AttFxNQ^Qg%8Gy&h7(Gstk8Dm;+*pP#NaIQTXT5}`U?`a+7iNxBsm z$pesJ?EhgIgH~SnJ!m?1-&D!igWjQ7G&m?@KzUkz<54rBWph1}@qBfws81&D^axj# z*ZT*{jB&hSnDy7mbMtD;xjv@-N1gbGlM4<&{%QV5Ja?wwEyw2C>t{yeN#TB9Gt#V$ zblAVEgifZ3ZT;X8eMWZb_{$?{h{*M9@3d@Hp02Z-w#+(=Qjfm4BjUbVJ%G5=kk_-)WXPjTo2m>9N%9* zZf#YlB|SD9{Vlh#>DfHqc1)2U;TIE>^bfw|eQ^HR%HDDqp34QfQwVX)I~+OM^SGOH9}6LF0euaOqQ=2GPas%$vsWbE z?%0FS$tqhf&r*9cD;|h#(ndxT;`&bqEQ`j#2WP3n-NU8ciK%2<}MI5--Bz#)@y2)(R$zs$0PheV=XUsEaxF;e!(a*-; zNfE;_M1s8vF5Fu#IrVQ37EGLBkaxWwx{A+WaVG=-ERr>>aOo43W3QnLO)?9lH^S?r z^Pe!ik!Fe~2=VzPlBCpxb;EyRSN!_2f0|4suNJmkf zGEcGx%QDlSdSK^z;I8jK4%;|J_PZZJtAO}*J+3;qfv_1T5b?|d{-Lj z{Zrt+#v-zPv^eDCtlhnXy|PTzA2~@RSQn6|iA#rJa#iAyXS>)lU$J+0?(ko=Tm)k&9CHjJiO)r;2AKyc?tW~D~Z`gFT?P&!fAXX53p_Nb~kWwauG{@Bv)Q6lHc3TnB&Z@qU56|LXy(%*LR z4!d?UNlKgLYb$jF>BYoP_5b=PoGM^MD2RGz)EJCQr5{e2;iC6a#BCRaD%X3QH@yP5 zEE)>L7*~M|{UqcfuN5UbDV_v)rtM4;BYrh4)TwmfjTNFt?+Vmv=EZWU2w^WGe3}KE z(JDdo6*jTOsy>`{F7Be!QTA)^Vs&Nbip_pmIY$nj&NTDLSFqjmGXD}DY7tOj#S#>) zir&;;DxvgeCF#nY^9)ouFmymBH$c*k6>^7v*|cRS=TUrlnDG>2KjqVL+Met7%Bhf? 
zyX+UZO$mX?PAx2QZ>lvO5B0UBrJjE}E_`JTm|YL6btn;PviS$*po`-3C>RN+#;m=weD%!blvLq z$$82CI!xZv+;K@5Fmrfpb~-IWCHR_qsn%)Z0OeAj8A_t}GP z`MzQs286F&ye;)Inqez_Sf}>!HLM-HJsZaC0ZSzAeB*fQ@7AKm|=DyAm(510Z7p@wJ zgiEyP^qtOl3sD#n;RA?%;WLG zt*3s`oGpRD6RyRP?}6r;v4oa?cf(e`XD#&lHZ@Mtvuy0Ji@A9Om?b-cWdY}E{PBCHofTa>Uh14eozt) zK-`bC;!pDitk0I}OA1j$Edi>LMu%^M4j_y7 zqP33y;TnHIsV`5FIXqxK0uA}u1c&sUo$e)!^q!9ob$^J%{)IeJ#J|vet8B#ZaDN9* zwQWeqUPWA7sWjI@mXRn8Br&vDt^$JR053q@+4?a$Wjd$#q^YuJ7*^zSD4EbuoF$WC z%po4|_~-K-W`AHIHYX88KRHgrhY`@aLFyH8^F1 zG3UhUmyNzf$6_5$pE^Te)qL5wm#Hy1H3cr(UluDd(}_R3xVYRQ`Oxaj9FB{Hg$0eE z;#>Cb9Yb##w*rzba8TO0(xB|;x^Trd6x%2Ct<6q%!f)xq{asrN`NNo?y`9D?diM?m z|2KH&#^h(cdi$0jukQ@(G=r-bdZJ=u@iMwkSB=m_MV~@}fjHa-Vzemu@1ALUfiY{o zJL>9KZ`;Uv;2FPg5J}~@v1cH_tdS+4Nj|L&HJY@(G8u=guBfmM_aiVl4e-9 z>OBZZ!+ms8FQ%tYMDtd*F$R5gy7z9@e`8j_r{ap_;F`){j@DKz;mmJ5OjjcRB+uHx zIUk)jyd+VL#3+(qn#)~mTh9|8tyWzn`K2z}rh3uWR_8~KWY+mGl=D836bZi9gXgB- z{q8JDi$A+$`7>ZO;^JyooCA34JQG6!zuK`Zem){)UX5bnQeKa2CpRb|$H&K?^mxrj zP|sstO~^Hlx3#lFAlNgfCPYN-;(wT*EE*Xo!w;=41>TQ$NWT|lPdn38nxl)C35As1 zRC5j3nSOR!*4sXC=c7DJLPnMh{B8KkA9<@k(Kb8S-RP*L*c$N=WmZZU9L)g%#N%~p zpV%9gW*cX#5ia0zB+t5Z-*J8zY1VatAwx8`Nn_0Y(`5B)&nxKq$^RxH{M zE|D~c)~;+-{07x0jjAuFK7}W-f`4();g4?)a%~Rj*&*{0Z|0#<0Sc1x-McvVB1;q> zl`1N1*u-BmiMQ8n1V!|#*jS2#zMgfP_#Y-G&k}s9cgu`h+I$326jz;B49-vs^|o`F zli~9}SvO5?{JE~pP6}0$#>@GQC##6Q1f?jrn6JTddi+?kP?KHYn$)WDo?#JvL^Wi=?t{JHet_I?nvRb2* zhnjZI)Xg3^X;1Cm>CFEZH(jv2f7;NxQ_oj-Cu zVtBqQ(hb$ql;#j9MQ;;@&O054>Iqm?6U|bsbK|K}{Yhpw8?le^gBC6NG<+=Ix-n01 z9sy1vxv(Qt4=7K)iYEoTwxh`MDdoBG>bnPVccLWr!eB{W>y5kga-g2B+Xu|NSND0l zP0%> zIaW`Km%emeX4JE}ct1$ktqnZSMlQq*)N-fpc(8lSq_!I_pOaq%0R%DvE!=hGfr|wF z-LAk>dN~4d={jsd^lej_^OqSMmh-QE`ualFt$8NRqCV&MWz1c@~3Si=y z$R%%9fs`RgTcn1@KM5zJyMJztuz)u`#c-^WaJX--u|{m@i#|cr3;8{Mc-m)lam7}x zu(+|3XshpA4e;XCmz^JGDw4pfUvjxzhF!}uZsQ-he>36DgQPQk-DF-CyjHi9CRPBT zx%=WzoA6qet7nrpF5NF{p5yS!z29qVTn=+;sAEkbOxM9;kKEYU7|E?-R*3Nd-aHBg zx-_VmD7W^ZKP9cNuZIEyF%R#b+}eZLFa=B~WC&hR@?T5jAGT`~Mzpxc{a=0Hc|ObYP6-}@|Lp@1r0$o-!5+D%(hUER52Sm3ptf0}z~mQh~~;bHM-^tMl{ z-Li)cj<<>({P749VuHq~IND#!r46@uYUs+!b=z<66QyYlkUR*9$n-Ml*sy65RrMpX ze(CLO(@AwVXYo(!^{c?Qwljgd;f(-5Z>W8n=GrSm;XWDyKIF1@&a4uYlgLm2c##MDeL0eb=nU=*~gec$kOfrQ&H}wJaA@yOBpgu45Huu(*>1;ftq6 z=bAZhwGI>dL|E{#95*lYD6fR(d9#79Fc^Q1uSc5S3XnVOiNYj zu-)ZAU!!5d(>kPIZRf)ulx8KH4P@Y0xL(T|8eTrtyQ@baZaQ$Ep#V1${TBZaK2>pfrouW%eqOvNoT_hw%~~y}=Cyeo~sK{cvdm3tqN| z@vlPIu&25d4zV_Qqxa=+==Koyu580%MGGcALOrdbIHT|-pinV1UO4*Ld6J%94&$PY zWXk^j97XqXZw^EK^(OZ)#+pUP)PsU~shcal+~s90Ay{SJIO(ggAEZ&q0@~J(LRf#xd`b#=IBRw@pn@ zuV0mB3rga(M0brmOIX?^_9)Gg%pAlOR{J@5H;J|S;iDn>oGi7PvT*``%0vBph0q1- zNRiWaem^y_m7;0IzRluxgqgq%(DLP!Iy%@{6BVw2Ii7knnU93Z#(QTKiAHL15$HWcGAdpfGh7-O3^{=Bs7R79P zr3&WsmIH;y32-&u*mOAPS>?awFdp_)K8iV2HA0Zl9kY4l91tem&CNbYQ!A*A&UfL? z{JN8qZ1MZYXBC%Bt=1D9OH(8^ET2<bqjCBps8 z{iW5yzc1b5XKE49%((UNYpV$1Wb^iSvVzC|TzDTC(Y9kp_L;{h>UwlGpMu}{eUKYj z@^jFenQ-ei{+2Xs9#RR=9WD*1QgkFD-sQ8_%E1cjb#7tQRek4NgpzxgnL;7+aS(TT zT>Bjsf%E{Q2~F`$5!@y%)320}eVQq-!@^U}2HsN0ONGri3dQRnDjHPp1r6kV;EG0b zKRO2oiBWsAKkEP-e;dX`awZbo{P=G$P&g@+x8T!$SwEB~L&l3l%1@)hFh_?go~m-O z6A=}vORa1^6xE=@?V0_~g}eQTK(KY$*jwF7U=2_u$y$bcj zLP)*{FVt>)h)`-muT|PSF=y)nqa1x@8_`8%H?cRs0e$V#R5a>6UlMEWnDzT7Q(8>QBk&XSYI*40E7gIKWIu)f$k{2D2G0?jNNu8MKLB;pK6mtoyCaF9LW@Wd zX}YIY2}XNC3W3#}PE4@D-RIrx-DP-pdK@rswA}3e88zNdCAy1mQITc3v5iZNc`JB{ zFcYicrEJ%!ZOBVG{`koK^zD7L{3CQ|AIkxEb%s#E&o&bc#kI6vvz?;neAQHtf=(w2$UVDs-y5 z3ePktRJnG#(|#MBDB3tj5&ce&x>}YSPU6s-(MvC%xbVX><~^&j7e(p3aGLt?lKL#+ z=&|(q=b;zDgd(63d6AJ_Np9Ygg>|Rs^5?91#37N(vY#9AYJlkCYp3z>RJqjYzY)ry zUw+8`u{u#~=)3LS^c(9(`q+vy24<^Fd1l>)ria3r@ed{Jc`!%5iPE{Q1l>(GeKtQ? 
zRi|U0PG^@!t?E0ufYA|gaiIT%fs~dN!H9zVMZ^^a5e~YR7dy2#*Z6ZtJ+nmKWg-)ps8z8$Qls#0}jqVodL<5*i!7J&#ye#?( zVkZLEJZ&I7Vn;BaT~*o21!XEVG9>bV*+en)ub&+kwGfwT5UaVr^14W@EYRrynG(MZ z9vido^mK5JPghX!RrI3YGaI5gt9C?fs_v&P*y3YalZ*hE=}va}3i*yc=9A%}@fBy_ z91ttiWH~zXiSvZvBW20JG%!2-bXge3%hNooBxQI~-x4MuRM1hM(d}7 zuHQ?k*n_E+rl!A@u9nf3?OorIxfk%SCDCjRP90!8c%@1xal+j3pj#-UKbTP#TsqiG zB`z9uj}Z>{PlmWJa^r4GKH^+R`B_muTY9Ryi0aAAUGvmlR#}ozBVO?>%bWE@JrZ~e zL7Pn_-lHPFXhJRG`|ZyjR~K;d`K)_jFqq-~^UP5BjoUW8hK3SwkXw+ngbEA|_4@T| z%Ns1|8mhHgo%w1?&(mbF4^C6-&q^!c(kY`-&N?Ml|12fb*Huk?QY1m3ng1kMp@W<1 zHFjim-i&YYgXE9a+=Iyp&rDk_*Z8{#E#qkgpXL|eI<3q+`maI~w|a0;2UQKRhQtEy zZ3%Pfu~tQ8nJH4D5PD7>f}O5a6nI%#0QWNa;bA;B8k2_jz9xpwMsD~XF-++6kURgcPabw`d4h(z zQn9WAvkSz?=P5N&!DAS}G^y(?WT>xCyvWjcF>h!-Wh&H{UT37k4Qpq&i(x zqTOigO8#a!CBLMJ7@-3EF)tKs{IJZg(W@ty+E?pKvM9yqt7y<}fHfG53Qxh=T+t{^ zPjuY8!4wBBdc^qC+eA)kb|ya4-PMV8v$W$o$HcC9tA=iAI-5A_u!1M~?FJGG8u=cg z-rw|>dqhyQh3bPMqojQ9+}*`{%eV;C?H?&*)Nt(RMxPXMOE#b}h_e(28Lw6zY-%Gu zCoq$tMw0&47YE4Ts9xDS=q-1-PRf4Q)2K0hIly7=W9m47khkmY16W^v;v@Mo*pfr_ z#469TOwqB9v4DAp%yWDQ-Bj*cO7HgF(o;vfZ*Z1e*?yiP+9W$Jy80gowe)>&BD7#W2>u z&jX7Wm1|@A-yXa^opAS85a~fE=4flU-K)ozE*%6MD!AQc=i3Jk6Yxt(#ebz&(dHk8 z!-?lRE3;Ljb*-?aF(m9$cavz71)RH&y#BGt_{FoEazA{9fJraJXZ?F52h4H<3c+L_ ziIO^Yh=BnKN$b_U|Da)DhLi}AQ07HB}8aT&>&`@Z-4f3Hp+omO<-GO-mPI zE&Z!p2&3}ECZP67V&xn2Ay(lMA~Fs&#`fOc`gyk_L?_5!P^tAfd!a>99629~cBmLY`6!m}^z6a1MX&P-$P}Tcl6(IB=g)*FgPgm^ zzu3P*1WPINw{%sc^eDJjd@aad_#nTr%)`Dzc1a(jt*ECFW<^+aW7G%ba7v4NMs#{N z9a}QdCpmI5;Z{TAi9*4QDk^X+fE8n#0nQ&?EUyC2mlal_#BR1wv_xgi-*FX8RHu@g zBB3A=$tQ#Xlz519vgMQ_SmBQ2`lnFB#V77Fxq#ncHmE51XGbW(j?jas143z{yoW7b zNvt2E(Rsv3r0-__%&&4`GUiq2S>2Yqau(};w@af#(O9!p;lsD`Jr1za+9eyq`0Gz$ zZMowtgDO7%PJWBlZ#zCeQB3jlviW<2Sa(~_`H<@U(7B@QJ2{8lv7k82ACQaruZH8z z7=+6Jctl|E>)UGt9r^rZV@~b1_Z$e(0i&jqRY^~kSnQ2caQ`H7j%}BVw#_HQy=y<- zt~3iu#%I)7x5gL9H3|;&({wgW#6pHpKYwls@F|v4c-`^1Et`aL3cje43lk5l+@~n1 za0dQqhHuyY)m^;D`N|ZfgO2+8{d*n2)7CHqX1H>NfQgCeb07Ke7y}MEV$YSF(#Hv< z|1C!J@Fq<0TjRluNthhNG+#Ul>jT{{~lz(cPzDYvGe9zRaxykWs7IXnq`pgPldjABgkcJ>hI*1kr-Co%7 zaW;gGKz=FDZn^!D;X@2*W$^3v_3}cGzI~fMG{KA`;eR^S~oPVf(~c4 zY9o&I^}p&^S#%^F@%9vvavTi3tFB~-u{v(BTM2ap=q*iE_ycHt1I;XL)>~Uv(|aYe zm|{#)zmc4YLWf24jNMq;4Vo%ngC;*%*4Y^+dqgc8sy`ag+p3+BmwF`E>Cmh-R`V(y zw%)3+C0~X4TA=$(%<1~F`q0U@q?@&SLu#zX8A+C2P53?jjfUSSx~3+2kh`2e5T61G zhJx2MB|ufBqzWo>63@*CODE~Sy<<@O8Y4z?9cs(8*8^xDF86h2Na9mJb851@J%)Dv)A2hYx&YSMV{pH2McPy2E|=YbTi`M0x21DemboW6uPQd!WuUvC~- zNQa3z=ZxXwPVB-ktr9$7DTrtE(^!VbRg99ass)AFcmYxw7UQ&&^{4yqSy=V&7yReF zkNlZ1M!G8@vRKy|6dz;e^=v2U$9k~tj&buY*Q;(Uds7QUF3ff(s#$oN%tFG;GKXM`T z%sGQF6HU$S8I|qGAzWR2sR0;q*YNDq;e-{`Pj@>|rKd}V`J@;ra*I1|Jg4Pd9`=_a zsrx73@$H-^%oeMV9s`PelqvVBCc}_NbRU?qQUXxNAa_@FN`;3iFY>81_l56=y6xY6 zRZF?AKYEr~ zXSg|uM(Ybffh~4uH4TAQve~rIcFzI6m<6SFwtlV1$_*SZt}PeB+x)P7wBs%QG!&bc znBFqOARb@0M_EVb1)07pk@qQlxq5DYR;j{D;XYnSAcwN>HZw8t*$DfDP+z~Q zpewsJ-(@-N?U)a*y$AGZSwjrJpyrIn1xft6A~S(nY|<~}(s*%oN2lc*xnyC9p~Bh6 zw?l~+b!qzqS=|2T`ea1Gl;eeNnIj#ESsB4<*R@>sn@4k%S@sn z-tFbewHxoYbI|ph8#t@GT2xO0No@af_VePBEd6iY$h+ROVzv=t`HrUH|7-WD*s6Yn zt{e?T%ApkzIQEeUP(Af-pY&_QOmnZ7-Fnri;mXNxe<8oiALDSgyNcpTZ}`GT=|_pz zoJof+MC}uEJhu*UPizVEs3zb_o$cmg@rQNrW}m@K@?Vz>Bbqa?NZzHT;Sh-KUj;gq z3Ky1{C)6LdlQ<+F7MHSbE`BF_n!YldWw1W4=!jcH;J}js+l~(`Oc#MNF7V6Q*;*5#Vj=STyS_wMg}^i&hpjH9d_)3V6}t zTmZt*utg_Kh;H>=vk>TPaiFU1(w5O~eWj$6I(DG>e8JStf-d zizB}}!8CS(+J8{>5UpwWLrcR;h8WtiT?uZU$C`{MIFmAc|30-RIFXwFjvhH!a>9+5 zC&XMBt|%)|rKiVG!&Dc_9mM#cp?9>I$VoEIc$k~~yaJH&egnsD`bLe~vIXAu*WXNI zKi1*!DY;)rvbb|iu~0I%46W(Mdd|t!A;r-?g)6o+$6)3`qR`0jT{sN~u3w!WJ;7sC z#7efaa;Wsjg6%i6=@Py5Yl_CL#DdN%p{zDjJlfe&f&%|yA3+UoY!WhcFfFs3k 
zG-2yb$<(vrHSIEe^u~9BbT}%pFI`z!MDv4O8U{?)l6-hRMCHBVj5;a|c0p74{R>o%*ZxA5IVds*;)fTGVkvJh$%NE^9ND z9K7)g1n}mS3h(^e2>Y8EqKEoxN%uIE7q|K?CiLnH-E_91SGJB*sOTcYg}zuNbohKpSJpOCjKO}1WG0l?00US9lD!K6`xK$udTkT&ps|44b* zkGieKTIZcT_@xWEo@H`&hOdclA4Y^_Ba9h|k$Y0YAM91_&}aDU2{ zAYV&b6l$E8!H$;_r|0De)14?nC+Wps{dQR4Jm6XPX8{@>!Sf9J6vM4i;^XFy4o9Kq z9r}}EjGn}{Co3f!;t>o~` zVs_G#iK;2Hxcp8-dZ3xA&C{+IgOJf&rP0jal*4CcCX&bl?WA)_+Dx2Xyd;}6|3u|5 zx$t~|l=aG}xD@>>z=wWBWX(IfXOan)L%*2nm@~Sc5OBVvo#%X@6W%-DaDnWVSM@8L z6%0T&opnTmYwDrS@3vtq6-C*+Fps~pVuVuqckTHuqFiB8Xiz}DMps@^x&K=p@eaFN z@q`DlIK6Yn^A&^4c#L^g4Ky-ZuEPPuEJ*5=0<{y+q9&eUQ=6_wAM9;SfZ1u?(ZR^k zSoJfE&sv!?14UO5umw!9syrE)9Q&A*&E4|zPLC|$Nmnf`Mc7P?!eKB|s;3(v`O)lm z?-UM=77n!aXj<$aP3I5Tda(!1niYkBTfoY-aOm~tY-};Si*DVOru@0@+yLE!_p_a; z6Nqrwe4xrh6auG0V5CTZ(56#vzX?m~Q`auGrn4#@JLN2t^E3v_Tbvy{8~}}|Lp*zX z^q%4aG;=AAsQ`OE*mt$x7pim0@M8pqo+XqN$uSk+)tR*JPgft91u6hZDf+4W%5ggbWbh`z` zfjNMS@=!eBL@VR%v_``oMQ# z0Z&}9O`GCakVZt*ZCWIV0-12ow{LUu7eVPls&Lp19u>u~!O>7pHy3IEO}E;n1Up2e z@<*Y7)<6P3;y9mKM9<4hhC)7I(UaTZO+Je80yPh=73}InPyw51p<8cZv;Lf=Q6L9j zk<9(5b@Rm20PoAjt6DF6jA^f?!o!=8sp8?>VHuz+!)9dKY5cog{_Y=yYhgv~`POao zA_`)yFeQD@eB6qL2Uv4I2Zcgd%E>9D8^B@~SI;1(guo8*ibi;NadxfzC{KikUSrVy zc}IC~rDpRoy8t2E?lB5UOw6{MD7_B{AWHVrDAnMa$Ysrgi=L#Gc4u~lEh;#b`Y*gB zYUcc!gfG8=+=w<1t?FjC*^g1QGdVYow)aS@MC;wW#E2xVjsSHcw07% z{I&{u&NAJFGxx#9I5S}aemId;t`l4Jr?wlTPY*8??8VLK?t+8(cIa?9^I&!too^>T z;vIPJbay(OdR=3Ay@+%SIR30tlat(>EbQVguIsIma(p=OpjmIkxnuxXP>@Rv> zvP9^xG1g)^cN72g5j-vyN$G`ExJRW`-DX6uu1|{GEjwU;WW;^b-v0UIt?VAIKj*(x z!S~D6;4ZscJVot+|FL4Nb$N?{g@t9z24QNbD)M8c#ygS+1_dE}JvcOgy>grvgT!n> z;MW8gtk=cy+s%>Ce|%$NH{DB##O^TGFV6snGKuv*!VX?cVwDJEG>svx;5D?{wb#wg z<0#Al>K!C82koI~&jobtB$!*Zeq@!4XRl!YPKcRU@0Onav!uB!M+BrQf%k*(M}=(isZwRac!>7kY>p* zbxdAi<#`tshu_)GlqBjwmN+jCEqrL)hqdCh2^x0S13UI*> z^TgyS7KpY{RVmJ@QN9O)oVod&ZRJm$C0_HNp9CzQKYxx`&~O1jj5u-Ac`j#D+)moD!EyQKQgENWs z-PNMf>}HCDg#CLKi^|DnBw+0 zOioS$>C!&l%t*+1Q{YnIxp>~5OK#vSEM)v$phPEY#-t=>gaa*0wKS7>;!@eqk5f+a z_af&%skW59#}IGx!FwHUflze6mHZ!y?v(_OLl;yv=nrBFFg{-b0vy=;OJBY4CB$Rj z2=NEAIXLF7W$T9k6J!F&Z)Ph(6gMoY@0rA7p^zaGxTD0}Y#f3D19TE-DBv+lz@?Pz zmc)AV`ch++>3h3|ib@erkBg72Mq-Rk9~L$fD>Rg}pfAzx{= zo62)+@e2Ltof56a@78DBa8i=EI)6rgm>)faghfCTxhC!wN9sf{u0I60>u}pn|X?ZXwmm~c6K@%;T*MG?~QQ} zsz>*$d$sO@$PR#+{ITepb&F7J%mh?&Ng->$f5-4(!?QcUz9g%%PF66{&4?R2z=9*; zeY8E0L?D*zR&Ycg`KhwO-O929GyrMYk3BwkHEK4H9Tcxk>X7goiPrhTUR}Aco{o`L zVzK4ieZ>oZSuOKGNEfDN`|Rn3<#)|%9~S!_QB4J>){82cZQn>yNp?d|-@5p-(rq4pl^)MN!9F_+8W10sF%r%o6mxe2VteByht)Xi+5E^5dte0J+sJ}d1fu=@A z#P2R|X$j>*fXGs3s#3;OGVAGJYGEywAcZ~-!jGK0%IX<2jRoFuGf_S3c#))=5*xhK zG66pvsGhx4qK?|nt5z4tQ~aIHfs>1oTjo*GR8?5}zsdnlVowx8gGN>VCEgkVOn~wv zRualln6lYkF+_Fu)X-!<3{FqK4=5#m0>UpurDt0{wnYE@5YJq&{95_Y4|^E5Z{DvG zEmYqH!c9YDZ;WeJ7UsL%!o=2sLXLOy0{bIv65WNGpolx4MHEt_!y3f*WpO$qUPICi{CD#zGPyZytDlgEXu6EfpZYXbXZ(*AEaGgxCHhrSQ5w*}Vxib|iU}b< zvjJ>vlwmdj?&U_n(E{B%9?AghWv<3fEaPs1m26UR@de=0B6c(LaUM{~Bkars=W$#_ zr;3c^dgkk41kVMsoj8zhZi8wS4Taq1?vOy(BiM!rQ>;sLAW0J;Wes~Axk1DVCgLxR zK_KJ_C=NOb87;@zG+dP3+qttjFb|F~-+C)pjERcZ9`M>`ZjAUB_-2`RM;)Z&k3?Nl zF3n1^Z8^YMa)-Y5D?bpo$y4^jwcHMG;-bHgTMAu>j}-w27v#n2jg%cch%GWAD}zz} zIb>!HM!#Rc#@GK*XUbUlTB~`-4zJaq|4Rzq8&8_IS0}jO9zEqmYxXB%Sm{z#DliRh3^%?Amo2_VYvRJ_QR0luotHzYqr2jY>rIxWGk#hn z#82MmcH6Z(oPV9MByBq40pr#tETC@FO`o;c*r8xTphw!6`QfLXrvJUZ|Mszq0%$l$ ztYvj?j?wBjwrm41)k6rbuw)SRHRSb>*7+7JocqF-U49P4zev^y!e55L)3&Qfuf+{} zQ25mO5j27W2AA#RSLQZ{{zs{XkU!$Q*=MqfR8{}KwPn5??a7u^oK*9>uZ0Om;zz6W ztZXMlpZ^H*eg=Ky+s)5ojYBb&9`ENyFoe~H=lEx0&_~i|ah8t|2LGUSNCJLecQ!1# z2x!wbpPAa`FY@93FMxq~{Whz*glqwno)M86WEx0VaH;+5xkXhHEZ}Td6;b34F0Iww 
zd(@gDSDwd1c<{ua64!(L&A_DFQCGSoAT-*-m06coGUI$#o*e}h;!e31%Lrrib(dcL zsSws28t+E;^=Bbf>czk4+J6uZpt$|`f6cvhSe4zjH@pZDL>fduLZrk(rBjqpK@ce= z1VsdCB&9rUpU*}5{8)mPzLq91bUtkl;scVc?v-1K`TN0LUjqTOAXF%CicYEWL%~8to_2~ zr-F0S${EUt;;(Jz|MsaAG1NIUg|2&|uQMN6xv>&E>Ra<;sb~+7VNHetHGUs$HVmNK0m56rPROm_2DFtS|N@Vq>c>z;3& zL0Z)Lc-p7oId4L@`I#fQ@YrN&$C5(Qr-uTtaJ@sFaWUS5qwPsHLb4fy=f3=`*NO?8RuWGZPcP)gDy|!FX4V`!qxA#fSh!y3CE3BzAI1^J zTF+<5@=VNhNlLTl$axHpz%3RXfm?iVYrGNNA{-2Ejkmrv;2`*0U|*?y zvX>L)7jNfypPf0jIJv>r-L>z?ii$Cfbrqk`YdO?nA#chfyPTvYkh?GOzQfP-*pT$V z+zZfPmy;l^)vM15$r~*PB@^E5Iz? zC3&YZYLjY~W-r|aB0{}jbrZ#uO6?i`k=(J)$5cn}NvW+2_3&>6NWZ|%1#-{TfU}u3 zfNotlvF?7&ics+pB~tY>$(>zE);X82+5M``6KLl`tc_oX{cbzJ9qzC1W_VVl5D^u# zbse6DlX&MtWYMV1SL|k`m}t9fr(`3UigEs^@4i_8Uf08wdh5pq^@bLVBO;y8!ZioY znhKaW{G@1CKbuDU-p0Kel58!q&$YF+041TdEadhYX@;}>RStk+*gQ2bS<1SqOQt^E z+Mw-Lf-m4m*e$F2YQ;geWAHR|NGBfVRxh9HW1QVD zQ(*3I4b>C0sR>V>zL+ZX;4f<|(BW&mF!1N$D1`>30v4Eo2>)do}c^{ZhrnA zP03ggbA#(|fqH!aYLV)$sr~^XG(}_p+E4U@Iqh z64ZIiYm2|VQb2<~82a?y-e6RQ_$&p$BYIzh&gegrR} z*xVyc$WhUNwpPw@Uk|IS5W=|YDeZC zzDgB!I-tnH^2(D(S)5v;XM?M ztZDspmv!gjvg0-=o9$9pvh)O6;mU@|y}EbBYy~LIv$tqC(v;6u9SF)Myh0jC0TZ{V zxTTHC{zV4T@0$X@q=EV&7PUs+73*v6yM6Wv=s+uZ(t{<@tlqINuBWVUA~1h2HfLnW zYR%Z>N5EabTz0>Ov`K91?(Lys5?WoKpr0Br?$WA71%u?I>|OK=77)V5dT2CL^F+$! zUzDP02hg2zub13+?VYT?^n(PSL!Ly&)ORo{csr(!RT8pQKII^>$P+`~W_Qd~QaQb| zf!OTDNNNkl2{S&1_|36ZzO2Cm3U|_J2M2sD;r83(rx_MDgtxxZ`{hLP1`MnndG zd3y%)!X6PI8XyLybGV zCf~>AMW@!)t|y&?F=pNWiY5rB429C^WrqX*-JW=R2e6~)!rBiI zW`=CMKKBUGu)bAnLHB^&n>{eOdNn@3AjE}4432DIMl+rFw(ssFa_%vWkeZeA0}_(1 z>(!rmEPt!QJ`ld-c?`+{^vx?7T{H@xquMLIemggpyt&%Ke%VDEuV~rboOo(=|MMg9 z1b#)AMcn3?TAg(3Cu~5NVyj?K2F%;mM@+Z|b=bv2tdl9_Evpb&o$_yW3b(3HzsU!l z%&Y}oqg7P#Tf#**k1y(r2ahicjXGq^=f9Kvofc{g(T}Wt@ce&WstsL=W@jct5r3^T zs-i37{5%KF7R;TQUa1KsfbR*8wt(_fRJDy{&P_{B*SH z`syp+FX0kOO@Vp6dbsnc`=~_DA>zXCaMv^+zlGy)^a{2wUO-XhAFPV)rgWgTEsqkH zdPX%7;G=AU_cvUWzJ2#5uTY-qNrN9d`)2%kp8ubz#|bK}=77+PXRQ}UqnNb=AqC~Pwz%q`xF zTMVjXBqaGBOHq2o#X{k&PAJp}Qxx^jLL-0Yoc{0DldJ|SN#Fcq2&~7mJ$lF;Ds+?q zP23NsyEwS4AdbNETkR)CLWb{s6lFhtRvinw%j`AbF*rsh&CfaHvOVPsaX}HPRJsJ3 zPEK)lj+}TeMLQpp=e`#rZ+gzA)}wJo*#+2WmNz5o`77Hpke3b5NlMpKS{cgnq)?y{iYvp75od2ZJwM*dXEJzO{(q!nI@myzgzD zq628p0sU7z8tu6Y3!1!+{SXjKn5X25bKva=SYe!IjLS!w?Z}>+vdCDqd@zmxfm6n1 zTE*paS|T=QGVEbNkV+5w^yQ*=^N-s`X&t>A8(-E{#<`9YQ%>q0Z1dN0VA&227Y}n`9=1PuJMMTp zhkljUz@1FE@S{)WJned2`Spa2E_5wgwX_w|v#V0jE8-FYqJ_@{OWpg2q|}?!U(NCF z=eF1EVRSHHaxyZ-)Q~@j(_5@P-8V*ackDp?cIDFYGdTzy!$0iqBZ!ghAhcmYtc{7W zj`tT7npyo_uWJ9-ZIL3JgbMAAypGfGZQid=T_&~>s@lZhDl!jy9c9s+)g+S!oX*@w z3&tc#Q08fX*vY2y@$#>AmIvh_=m^oEMl4$<7B>*Q7YYE}j7;&(mU~=nBg=N;1 zN8t4ORP66D6UuHwRVtBRfOL6h=d&6za!}-bj!QXqFC6sX{8&>fo%*+iZk2+^!)L}* zz*0OANVmSEQcO~E;3{lA)X>p+S|zD=@U=DfffSB!$rcBs_t(=vjVG8A{68&>3vLKfc=>&M^Lge18 z04kl@8CEBflR=Lp2xeWjP`@k!OxQQC3~A^H&x=T~Y5me(a516zTOL5G_rkL$5fKq% zcW&(ki2cV+h<#hhZ>M%Er%ZtS(^Vcn12erH|bqmIw#rSIiRFp z%EfnTQ1^UySp?2G{>mE#D-{w>_BFxoSk?%L-I_6DM_3t}~QWs6_5OfBip$|oS z7=O95^zbokIY61_jTY>rN(t2e;?m29dgf|CP8Wac1D3QUXlx8AjZ*0akEkY-mt`+; zZ34fmWDh)@g*Ss&*3p4VV;|f1Q zt+7?bNb2=s!~^2b0OKt;Y>!Wcx+tlandJr{=jmHU(=Q}Kw`=(QIf5~s6_^+pfC2h9$!UY)l6N+&b;Dm_*)QeW4ei zb{4u+g9Ra9qGpc2C(#-`|L&Wbt*A}~t{H1mgpE4_A?&tP(2u`&EvHq?4#6z;rHBm| zrZ+@YsSN0tk3f=+4Tdc{a3mm$6JU@=@coLu+fgNeHs9>S!-S2Wcza zjg2_{LaODtzFnoBZt}?9lm3=s5GWogY(P6?x#D}t?X*}#Jh<`eM8)5l3OOBwy`ZEV zS2za)W{nXi^KGfcGxMBZ^@27d*}VrAT84Y_l(WtN^~=~^X!_pv z-fx_G2fuF=<4W{ug}MvgUZe9@pt=Do3xkPq`E}vPNL`6LZghZPH#z@ZjyGRSnRsVi zuceGuWw4(<{*xZCM&pdM2;DB=#8R1^D7s94JSg|t`$*D}y48bz-wLVq_(TBzbmocq zHz$_u@;66$QV=g6Si+dnnZ61p@lQ81_3K_qTUT*6fJjtaP?VC<*;=rA3vXUBd>e8= 
zZz8e0O3Xl3>b)>Wry$2R5_i&r+IUX^o==dgNNraNZAck!~Iz#oMar1U3Z?$ zSpspPV9^uOl=`Q_eryasH#gtxrHw!@EqK>vFAajO8qQ+$hO(=2kbg+Jx5MnWxU@w-TvD z8_-(f#$aFq2y;KPxGGjrTjJHlX(-vj#vb&?%0Q^^QFSo)KvHs=S5<=}JFm(n832>j z>E`fK9T4R20YTm@SEKEle>sw7H7vn%9lCzw_v{s4o6&)`glC$XTt`aR3AyjRcPQ`Q zz8?)Mbiwo1k&A00At(QK^84ZLW6nnDr<--nCaA>cYHH?93bztgz&@t3=&7rBy8Qme zb0WQ%wG;MEN3Zz*5?^%s&KHO^2QSP|%V>XawC-%#MimjMt4DNcsvdD zD+6HK*ig{PRI+&3YNE1w(j>NXsq;l3iwE+BANGY6 zD+zEmYq!TInM5i4WExLizL)ndJu@aRDd{pT=W%;LP~-uEVnPc%$AK2#a`;=z&n3w; z)#0?@v~g>xymmkuWuyP#YJ7Oli6C47AXGMmYp{Chr?Pu{ZzgaLlNvOxUiCw+qaa*dfoxP~t5^pBby>N#{M&lAl3vL4 zc=+2&*B~!885-SxmJn=evo_A_=(}4$3HtadDpV~-Ia(mc7I2<}I&YOaL8)|jXh>GU zsO9#lFo5e~E;WqJRc9UOin`|#!{g7x%bS!(0fh50!dFi8!?P!Zr&%*WvS^}m?xPNK zd;x*5{ZS8)5+pX^{6b+V&q0mDUO}9>U7);hKk%q${wxLMyQ?=SP1BFOvMppVEVkO4!)221Sm1K{d=+C znvEITP=wE&se~CgtleBabhi%#l0fR#LY}(ia1l}%lo+=K9hA}O7gXeHt^x^zBII6!huN0@t7389i&!t1-gv3?!vKpdAadyC z#HOU|1&$F+@|R|V+#j{*0ea*=2>^-~k~D}6T-Rf*`%v=8rC8{GXmL=HFrAppn&E3w z!2(ff-h+Fa3K^2Hy-Nv+ZFu{|myYpjcJ69a+L+&L2D_sq9er;&OoI~y0H@!8v?6SiO98xeHCYE1DYAwi$k0|2 z;bNP01e^Xv0I>gU7V-5qN`;L!Lp+D)!{vs_l5(2oE{k_ybiFy6aO)2!j;$SBrguJ> zJ6}Y;;3lMDSw)ZWo1f-+4sAYVN;Z*;Y+7O;S2z;O+_{hG1}N&r-b* zv)7`Tt&P8ts36TiqFWf?1y5?8CTck>(el#~o*GtCs4R#cIz+W+QS^#3YfRjQW4OlA|GM*y+v4*)6^aGC(92Og%#Y2Ku1sMNA} znG0#IVfKsM<$w}X7O3a@V$yTNQ{w@#^zcxkq@gq&3|3xk%v6!%l}?`19i3P@xEzn> z2Lsg3Riw;o)ZNolkHkK8Tkbiq9?!8cEf;dIUeW8|9wRpB5Dk91K1a<`Y^6d_=dd4WC?Q0wpkpbrEI#V?9;Ois0$TTfTm z(yG`&F7oz=Y{@n6pl_^Vr?L-PTTam0%7|5Ev+c!g;o{9nb1Cxa02j|GN!3QwCP9f! zoR_`Lll?C5<)MoLe-_ywXO`B*g#n<^QB0Eh#GNx0O3{w_gAj$t*2Ma)sET z(|R}mr3%1W%%;1qA36N|z9|C0)P$bcUH6{D@Nsy=HNo6`PHQ6aEV#H_GqduAOy!pC*s1vEP|$tfHxgWHrOL z&g2b=x(YfnNvZtB5@-x@OD;Qmz(?{a;#1!$T#6{>e3Pd>?OQk_BhRMavs@55i z*)f?PGY39Ol{Wv*`pCrw;2w#CBL3R^?q66YIfwmz?E16pg0^4ihBIwP9XPN16*({V zzX`ow4jO)-?(dcFrb3Ijwg^o1O>&XR@&i>@m;>--5JKuPixjAlevJ@HS6cn_msFT;Rx|G!f<9^2usnMbk#Bn9p^M8mFh7dHvwat6d8MOKZJp4wnlS%CwH* z-z8G^&KbnQU2Zica?+CcciO@1a3Z6`AZ7_`gCnF8?)DlWE=|2xRxNk>Bvrk?sYuUO|U60BF873k>60J6BI(DV!^@=#M9| zU*~F3g=)akD~&0$nPqjc9v|q&n&uFr!DoiOm{caoXM1S&e&}J{z9$B`y3>7U4KF6a z@fAtfVlvFyP64=yvz;D@L?8-6OG{gN>(m<)e{;ZRi)Xj{rY}}btvZ;{lUgwK1AVgY zIWJH2CF{kxfcF$XHPRaIrW$K>9))8bFREQvSi;PRfS_{#kM&r8)Z& zf+zie6;Pqg{&lxM$X)}Y)8;s6$#*aY@PZxf`oB(OAhw@4l_@2RE+jDuq1n`o+M6;H zok`28_Fsk}a>)PLTLKW{W?%V*A@EM#I5A$xn=;;?Zth37ZWX?(Q>2qDfa6R{;~4xv z+G2s{F+kbgu}_f(HGlOS~r9=Y`95%RZ=L_6W&=sjH3Yht{9!SYWZFf_@qk=FI%`0~X%o#e1G) z%CAgTO*7@{@il)cz9fxE%RhP*brov;oH08+$pz?1W;O5bNS-NKj~#6(>O7SuRS#tA zR~pPd8g>caZESUYajT#@<9NE6=UN}H>)v5xpKH{wHo5GL_W}{v(s_TGt_?uMZuy#i9N?qdYwmIeVID|8G^2fS zn`bB!vKRV>%_5_kGoZ*}GPS}@Qflc_SYO5^z1eC|84-D(7JGM6kE&fl3UG|%NZwAh zyZ6)XkfQ)%WBc69;YPpx0c5Iyv{-pEOmVa^Q8g*P{wAo~#??s7wD7LXCbJId>lZV; zr%p**N;b8wtw(UfX?f%6)q26Q!TVqfz&Au@Vk-{+dvIcx^erBhnp;j^sSFwaC65JO zBmKe%v$?P)$%%!;`rOln*Q}Ksm~4;XR|0P)WNE}0X*34Z6$fp~IxHr{!%>=FG?$NhYx6g z5Y7^_c6y{7+L!}KO9c+8hvV;8OPS;-28v)6{|USUum|hiY?YvV5dK!WW&SON(xD?1 zzI;Z(QY&_8vL74!EIzX*omf@Jx2_u3A}pqBpp;y?hxnispz06G9b^<7=Ql^|x)*RP z39ernUmp6R_LprlOTI)WZmm-vlvvUEIq%|SUI%Fk=YhbMCO()+13WMtFl}!LbO6%^ zlY>4_jw81;S&5%=;Wg##RamNB${CU9z&m zHHR(b$#?t8B*w0Rsv=q7AyDm*J2X3(8x=i8h6k(~K*H9T$LCsLJq1Q;ze)8P~Z zjj}rY-TiP^Y#(~M9)5mW34Qx!=z5Hdw)YBSQ;7j~H@zsoeHn&_s83#VO4NSh<=7pL z!SzWb!yV|>B?bB{%AY;=Ex>*Y6w}a5>eUsbdumXFV0-(~@TIM=QOe^Vr2J(Xyf8W& znYMo|GYshT{SOhJohz|*xM+05DL=`kbyD)72I8Qex)Ng zps6szIq0mvT#FM7y#R5S>|f&UCXx=85u`il>>q!pkUt$FzLuX%hc;vfn92jK`L#Hr+A0`vhe;j>?i=t0b(I7Ctx%P zJY=asIO8~gS}U{Z>e%ZgkM$=jZK$UNSdW%^cGtsv^maUq)d!&GjQvXVgRho ze`3+W+JpfVgK~>NLN_q8ALZV;K>^K;2mNi7XbzAR?EYfSsrFF2sPtXhyk(bdn1SD; 
zkA!T7;u2~^y0h{n&E<5^d1u&^ox3{^HxOns9(@K(Gi)vSPVP*9L#ko&{U|<8qb%q$ zgLr_Y*I36s^I+h-!@yA{yKpL4mIg^ebNN!KZh1bkv!B{YHE;r=Jw46hs5N5@RHeJ} z?=F33CT#i4eK0)IE{1Ijk@ zhylAeiV$>O##4unz?5KUPCt|4Vu9jgct!sp(4k$1UeT(wHI!9QTA(Q{SyKr;2p@gq z#{KdwllWEb1lY+bm{|?tG*x0cZ=B>EeL?P8UVgGE*Y`g z=_SWdb8_^!uu?t^Mt{nY2n2O>gqDpC!vMHyaHpQh(e8v%B^Y@ClKYm(I_?g!7iOf| z-4~2oY|4m~kk$^-N(@NhV@%1yZ0zbsy67)etSzEtEQ_Ifgw$RR^X*y?nBA3{!? zPm_dHEBE;w#5UW1k8Nm@9TFM}q8Vz+Zwxs|VeAcv&LpF=Ax|C!iD36u{5hcBTGQ5G zkj6vn_4ZC;^u_G0NjY&=&QI@$KV<&8N`Bzj4Y7oN&1d}rEf_K^y|wNpvL`EfjYB_gAV;jK;j7hxvVW+Z72ito`3er-5i= z?x4p0j(R4kIK=(?McMEo>4eNa_CZx>N8w$$|gp>H_V!d`Wfn*gK5m2ZS-#hNUR5xu5>+7SDf$ z75he3d=Csy6Qc{EQ$YNUGH2!r1FZM&;1{tK=AM(63Iu$b6MG4 zd-m+%yE}MZ2XyK73z}U`ikAb$5c=yqMy{Qx=vAoNZEBD?c&p9cvB zOKmqM8HTu=lZ$9OrGBM8iX1|HC04q#=i2u`Z3etfj~|=HL3L^WME1Y8QNrVvHp73D9>aas{&;n9_ZMp7!2oIu-j1i$a80 z#RfG+&Sw}k+Fe3jwlo{y%+%arnWHs38W+@$4LEM{P3>XUSyKbR&SF1C*)?B-n?Dy< z1soi4kiX3bW*C!1V9dN<_=`s`6&7oqA^sx=ajYdVYDqv*uX5#z*QE%-KH&X-1iBrP zhzbb_LB@24qmQZBBWFN($DfH-%>^Sovk!i80aK1oPY#|#qH@s9;8^n9?(kHm5Ry18%oQ1NRr4 z{a$tfFShs?H&NtV5$W=)8jknS_O1Y3fSP zz6w24>vzz8y#Wq$2`MQb564G(iP#LEp*6_!L1Z`>FC-B-- zEpZ^tINK1jK?~+P&{C*^$GbkV->v!3MV(VOk9SE;XA++m69PJ$nuZ2G1}35V7JV>k zMsKD+{8qv=;a-lf$}7ppgzSOVJP0%C@Ke?6ptd9bCGVp(OJJCMHNFJgmhykTjlf{Q zuYnnUt|e)i+VFkzrvp@+Q_xMn!$*aOab1}!18e{W_1w#gU025rY`#wgxvz+6(@{_> zxw;o~85YLBO+p{4sKTx`QH4(9VT?YQHH`vAtf)%Is%)2CAe>|`?uN$-l9YroW0V_u zuC%xCTqofJr6Q*{#-jA6JE98Q*ujLv$n!Fvc#j;ZfC)7o{{~Ad>?q^APo_O(s;d^I zu`k4+1fR=n_?gmB3m?3?N$qPBcU1StRl?WIV(=5T3XT^pT!^azp(K4^fWocphQl@+ zv7F*l9@BN07?s0y`xaMk2u8rSqKI@``H6G}oE9d+#(q*676Eudmem2st@Q?lpalT0 z?gl72m`fbP+7BR^mW9am02U9c9 zz&@KQg{_qs2rKd5hnXdpMr0dVK&ARHqe7q6{s-5*!PRLg9E63b$1etStEuD5SlBPdGn2;rF8lR0(Flor3_A zmfyeU{4?5?otMOXf>~eJxhhTFf{zi5DR!Pl23 z^TCAVBrt7lr#czH3<@cMpqXlXw9_7#0l9;@jecH7oKP4icF?riS$}ebQssA;BG)5U zDHf!6O+dyZtY;Ej(+6%7OKLMZysK4!l9Qz^_#T1h>?! 
zK~$t!#0O&1Gw!D^34Y*yv(R`DVlG+y9u~C1HXi}w6d+n(i!V}+kwjLb))1~xq&o|a z>j=XDKQ7tltU#$!Gb{1r^M05L0S`U3!L6KxyZ1hg65=dcgYD!YkmqQBcV`KdOCmaN zNH=zaVBp>s9h*yYS>@>Yj)x}KkAT=;!caJk&CiW{L7|vD1N%Y=kcpsvp$WvK zu?s2aHa2xqAUPt2sVW4EiKb_ZrR}u`v4cz}ZxW_~6*Z<57E7i{Q&PxxU!E?QzUhEu zduxFWH`rJ~<5$J<3UjZK1ooD((pOIQrkKT-7_%zym5B>7)yz&*zEo0qvo8W)g^3Y% z_1wuLnN&GZuo_@^X*$;mtYQoXy5jA&N-~7Vva!C(g_XlU5kec1M29T{LImQ!J|9B1 zgw4IPD8Gn)dt4y%<`;UklULCU<-`o=;+V{}b`?5+CEom4<*|Ml^91{o_~|VXGp3LMcM817%SNW&SP)84Up)`sO!P+A|nSXhUFSrT!A`VUxfq}`xO1MdO${o z40AQ>)`iX!202Bzja7J+6dWebDDKAvg*|$;paf%lr}5M!@P@}=pw;i}-6TJbqT{!Q z_LmAWyi(IBnCUwGd{cjF8aeIqUkVuC@iBV!zH=UyA6d*HmE6`6clWq;q<;W9KX{me zf&ye{f!9#eCYbjR-K1B)`8FS2+Tz68T-O^VdLg>ALH@ zI;JUs$H5Aq z7g#6}AT7hpw-QAKNLDP6w=gp>n6IpRHX{>e0#^dIMtwTcq;dTK*Upb}1!R<8VaZkS zC$U0UGc;a+K!w4euL+f*M^m!q08D)=AtK*66*dFWugzoHhi5&9e;AxhsAZ zNXh;V8#|Rr?R1M&$@r9Nr ztTL~;xqDZEUE1fnYgNvFBvulrX?+6E-}wnSuneYBYL+n3e<^6r2>MnxXHxvKYG6th z@u5B$M3?}*5de(j@>kcW40E@Ub0o}S;AdCFKub&= zn^O1VspWq2b zO*4n#KN=1tMzgDYiU4fai{rYecv~apO07-`TxZ@-s!>XvM-EUBqVixMOB3N0N1UxCV;Rod)E_RJ@%(*B<>}pKN#j%6~pc{y+Q~o^RhkgA4FQ^BbB_?a(L*&{WP)JBr=V)a?ToVI&2y%S; z73bB^v+!K{3@2aompQ+iOh7Z9(PL5SmUA@Wq>-Q+J zS%{2%P>~>c7c|#`{tl`M5P;CGMb{T<#O8b5z~;|1uE&nHfQSP+It&Dc!OBEhVJ&a( zFrw3w3xpMPwgPrDKbdhz0_)3EuTqaf?P>DL{yk6&JrPia>yvr63#cv?~~+D zy&u;XjEGn(;J?+DW#|ZTW46norw<+$;{nq@+lf{xbSp?90ofS4lZWcuC1C+kXya8R z&9o-Ym~Gip?9?Fza$zu`qyR@z85rNDXdHo5_W+GpLD510rt{e)%IN6Hwe@;rCSPBH zmh@GJTHt>O?kFLUKhm-2k`INV)0>f)D*o zntJ971={7~^EUwui`=$1FRqcM^AK%xuuJP^;~C?FQ)vqDu{R z>3UL_(sL9N_aSV4jOw_VDZ}b4&Y}ZZZM$SS;WkUE7%5BzXbTwIp~4yx@wi+v3B2Lp z)Muc zd8rjDo#XUyzD8>RTw|!A-PC7;P{OF!z%jvGw&Ed5Vtq(#OK>6v@T&cmNrEJywrrie@4kQ{`L)pH93OxdOqa*)j=+A~P61EH(K zQPVfaa(mG{(cAbHO16NMMbK7A1Bds0G47z$j>6E)H=&+)O9v!lPQ1^ahJ<`C3fyGK zIKN@^p=h(SXTbM>Tq$hL-VEAK$bm;CM6T9GTvi2x9AtrO+BcFT1K1c8Dd0<#-p=le z!L1oZ)nvm#Ulvme|6J3^LH{wJCJ?p)ovY#YYd0I30;F4c;o40&h9Ll|H8S*oI^K9D zA2!iVNbX3!k(@=`5BvgkTLuC}Y)2^}p`nw^#J!`s?`dN%6Jx9(O$n|p3w#&<{gr4H zDRjRQEgPP71Qg$snR}*gf^eL1=(3}SKyw@%fDIQzMQZ~RoI~xLYq|q+EzLc2vFQbp8<{%_;9|f-w5xg0p&-$+qKcU(;|=#lSfpX z7bAP7eVa_o=S70QnF$UIM_7J>W&CKeBM+%Tp>&J{ci`{@gwYjxja01%04vP6wn$It z0rx`=IqX5&_{H3?-SmYw;TO*W%Y#Lmrv4VAWQac*<6<_YAUzIpf1Ce?ABZ^C%kjW;8LhG}zw z{dD(ZbcBKZVXFVC^N~YVM zoWu*8re}27)V|_QcysT`LvoM&rdAZ@+`NC5-#T>z7Il5PX-{1G;p^%2n!0v$u#|lO z{fQK4!K}h`P6U>+U-$S7Bv$jb*Qh)QFWuaNDlbqCkEMZ@0Q8CLi$SYV=7iAtV6gI- z>xMy)1j>d2qR1{CgRo0vEWuEK)EqC|i{W$Ma6x&Gj3t@pV%$K^19d zBL;Krb?~#GMu$XH(IR&s=u+Ta)>X;P$lzNMzwnF?AV`=}@DJy6U9mOm%%=i21gb%% z@(0MjMvBmYl%1*V8k{iYYcl~msHUzitwPJ=Pp$cAV6g-TrK1g{;)d*^glacoAcd9q21b!lyc+^}h*_cm9p(@j|M-qs0$o|t2ni^IWWI`Zv zDSZz#j<{{VuTGb+Zo-#)(UIc{RCGJ!mXCzDxl7bBdKIOZS?))X7}%SC@Hwr3W8Ak z5&^n#`@x3E<&1}sz|em{l`GLk=*l$>#^`k7o6*j+cEml6 z*&r4PE0AY10FXKO01BEZL30(wwM^TeSCB+6bUZ5IrYZM40IMVTa1?u)1ouMMLF1zR zB@|+B)6BKEm~5kg;Qwp%No-b^@XeqC|NrB$XY= n: x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) if i < n: skip_connections.append(x) @@ -723,8 +730,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: logits = self.lm_head(x).float() # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") + loss = 
F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") return loss # ----------------------------------------------------------------------------- @@ -742,103 +748,108 @@ def _load_data_shard(file: Path): assert nbytes == 2 * num_tokens, "number of tokens read does not match header" return tokens -class EOSBatchFinder: +BOS_ID = 50256 + +class BOSFinder: # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 self.i = idx - return starts, advance + return starts, ends -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // 
grad_accum_steps files = [Path(file) for file in sorted(glob.glob(filename_pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 if align_to_bos: try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) except StopIteration: # This shard is exhausted, load the next one in the next loop iteration. - tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) continue - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) - buf = tokens[start_pos_local: end_pos_local + 1] + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths new_params = yield ( _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) ) - pos += batch_span - if new_params is not None: - # makes it possible for generator to recieve new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, 
new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps # ----------------------------------------------------------------------------- @@ -850,18 +861,18 @@ class Hyperparameters: train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 # optimization - num_iterations: int = 1695 # number of iterations to run + num_iterations: int = 1670 # number of iterations to run cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate # evaluation and logging run_id: str = str(uuid.uuid4()) val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end save_checkpoint: bool = False # attention masking - bandwidth: int = 128 + block_size: int = 128 ws_schedule: tuple = (3, 7, 11) args = Hyperparameters() @@ -915,7 +926,7 @@ def nvidia_smi(): num_layers=12, num_heads=6, model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) ).cuda() for m in model.modules(): if isinstance(m, nn.Embedding): @@ -940,15 +951,20 @@ def nvidia_smi(): group["initial_lr"] = group["lr"] # learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training +def get_lr(step: int): + x = step / args.num_iterations assert 0 <= x < 1 lr = 1.0 if x >= 1 - args.cooldown_frac: w = (1 - x) / args.cooldown_frac lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] + return args.ws_schedule[ws_idx] model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) @@ -957,14 +973,14 @@ def get_lr_and_ws(step: int): ######################################## # Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 +warmup_steps = 30 initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up eachZ - model(inputs, targets, ws, ws // 2).backward() + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() for opt in optimizers: opt.step() 
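# A quick sanity check on the schedules defined above (every number here
# follows from this diff's own num_iterations=1670, cooldown_frac=0.45,
# ws_schedule=(3, 7, 11), and warmup_steps=30): get_ws steps through the
# window sizes in thirds, returning 3 for steps 0-556, 7 for 557-1113,
# and 11 for 1114-1670. Because the model is compiled with dynamic=False
# and each window size is a new graph, the 30 warmup steps cycle
# step % 3 so that all three sizes compile (10 passes each) before the
# timed run begins. get_lr holds the multiplier at 1.0 through step 918,
# then anneals linearly down to 0.1 at the final step.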
model.zero_grad(set_to_none=True) @@ -977,7 +993,7 @@ def get_lr_and_ws(step: int): # Training and validation # ######################################## -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) training_time_ms = 0 # start the clock torch.cuda.synchronize() @@ -986,7 +1002,7 @@ def get_lr_and_ws(step: int): train_steps = args.num_iterations for step in range(train_steps + 1): last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) + ws = get_ws(step) # --------------- VALIDATION SECTION ----------------- if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): @@ -994,14 +1010,14 @@ def get_lr_and_ws(step: int): torch.cuda.synchronize() training_time_ms += 1000 * (time.perf_counter() - t0) model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) val_loss = 0 with torch.no_grad(): for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) val_loss /= val_steps del val_loader dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) @@ -1021,12 +1037,12 @@ def get_lr_and_ws(step: int): # --------------- TRAINING SECTION ----------------- for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() # set optimization hyperparameters for opt in optimizers: for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr + group["lr"] = group["initial_lr"] * get_lr(step) for group in optimizer2.param_groups: frac = min(step / 300, 1) # momentum warmup for muon group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 From 05502bf95629f8142187ada9eca3b451604044e1 Mon Sep 17 00:00:00 2001 From: emelyanenkok Date: Fri, 5 Sep 2025 16:43:58 +0000 Subject: [PATCH 05/14] Add skip_mlp_block code and results --- README.md | 4 +- .../07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt | 2853 +++++++++++++++++ .../1858912a-2697-4461-9edb-e5ee4246ee3d.txt | 2853 +++++++++++++++++ .../3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt | 2853 +++++++++++++++++ .../56955462-7201-4627-91d9-b2426a1424e2.txt | 2853 +++++++++++++++++ .../5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt | 2853 +++++++++++++++++ .../70af20aa-f602-4cc1-85e9-430a1664f62e.txt | 2853 +++++++++++++++++ .../8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt | 2853 +++++++++++++++++ .../cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt | 2853 +++++++++++++++++ .../cf8c8a10-ea32-46a0-8276-241330023e83.txt | 2853 +++++++++++++++++ ...n_0f6c8eac-db39-49ce-bef8-08a34044625f.txt | 2815 ++++++++++++++++ ...n_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt | 2815 ++++++++++++++++ ...n_3f42c181-6303-4ade-9f64-556d44d54065.txt | 2815 ++++++++++++++++ ...n_50e5b966-21a9-4545-8c88-91308e140958.txt | 2815 
++++++++++++++++ ...n_803c2d15-4adb-42d2-958b-0b712cd9d062.txt | 2815 ++++++++++++++++ ...n_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt | 2815 ++++++++++++++++ ...n_adcc39f4-c919-420a-bd94-9d0035f0038c.txt | 2815 ++++++++++++++++ ...n_c753588f-47c7-4107-9087-3c5da90cc0f4.txt | 2815 ++++++++++++++++ ...n_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt | 2815 ++++++++++++++++ ...n_e501e1e9-39fa-473b-bded-39427f349f37.txt | 2815 ++++++++++++++++ .../f01447c9-da70-405a-8ed0-858caadd1194.txt | 2853 +++++++++++++++++ train_gpt.py | 11 +- 22 files changed, 56689 insertions(+), 6 deletions(-) create mode 100644 records/050925_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt create mode 100644 records/050925_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt create mode 100644 records/050925_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt create mode 100644 records/050925_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt create mode 100644 records/050925_SkipMLPBlocks/5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt create mode 100644 records/050925_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt create mode 100644 records/050925_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt create mode 100644 records/050925_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt create mode 100644 records/050925_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt create mode 100644 records/050925_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt create mode 100644 records/050925_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt diff --git a/README.md b/README.md index 9185383b7..e40ee0a1b 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,8 @@ To run the current record, run the following commands. 
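# (the numeric argument to cached_fineweb10B.py below is the number of
#  ~100M-token FineWeb shards to fetch: 8 -> ~800M tokens, 9 -> ~900M tokens,
#  matching the README's own token counts)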
git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt pip install -r requirements.txt pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade -# downloads only the first 800M training tokens to save time -python data/cached_fineweb10B.py 8 +# downloads only the first 900M training tokens to save time +python data/cached_fineweb10B.py 9 ./run.sh ``` diff --git a/records/050925_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt b/records/050925_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt new file mode 100644 index 000000000..6abacd0c3 --- /dev/null +++ b/records/050925_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + 
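# A few notes on the FP8 matmul backward above (these describe
# torch._scaled_mm behavior in recent PyTorch builds; treat the layout
# claim as an assumption if your version differs): gradients are cast to
# float8_e5m2, which keeps a wider exponent range than the float8_e4m3fn
# used for activations and weights, at the cost of mantissa precision.
# The w_f8.T.contiguous().T / grad_f8.T.contiguous().T pattern
# materializes a column-major copy of the operand, since
# torch._scaled_mm requires its second operand in column-major layout.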
+ return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + 
accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + 
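+    # Editor's annotation: this load supplies the `beta * A` term of
+    # C = alpha * A @ A.T + beta * A computed by this kernel.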
offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. 
+ + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad 
= params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, 
y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
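+        # (editor's annotation: 50257 rounds up to 50304 = 393 * 128, ~0.1% extra rows)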
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
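+        # Editor's annotation: BOS_ID tokens delimit documents in the FineWeb shards,
+        # so the nonzero() below indexes every document start; next_batch() walks this
+        # index so that each rank's sequences always begin at a document boundary.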
self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
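+                    # (editor's annotation: rebuilding the finder also resets its
+                    # document index to the start of the fresh shard)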
+ tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1705 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
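+# Editor's annotation (illustrative arithmetic, not part of the original record):
+# the global batch is pinned at args.train_batch_size = 2048 * 24 * 8 = 393216 tokens
+# per optimizer step, and grad_accum_steps = 8 // world_size keeps each micro-step's
+# per-rank shard constant:
+#   world_size=8 -> grad_accum_steps=1 -> 393216 / (1 * 8) = 49152 tokens per rank
+#   world_size=4 -> grad_accum_steps=2 -> 393216 / (2 * 4) = 49152 tokens per rank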
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
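+# Editor's annotation: restoring the snapshot above means the 30 warmup steps, which
+# exist only to trigger torch.compile / Triton autotuning for every window size, leave
+# no trace in model or optimizer state, so the timed run below starts from the true
+# initialization. For reference, get_lr holds 1.0 for the first 55% of steps
+# (cooldown_frac = 0.45) and then decays linearly to 0.1, while get_ws steps the long
+# attention window through 3 -> 7 -> 11 blocks (384 -> 896 -> 1408 tokens) over thirds
+# of training, with the short window at ws // 2 blocks.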
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:45:09 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 130W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 72774 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 72775 C /usr/bin/python3 610MiB | +| 0 N/A N/A 72776 C /usr/bin/python3 610MiB | +| 0 N/A N/A 72777 C /usr/bin/python3 610MiB | +| 0 N/A N/A 72778 C /usr/bin/python3 610MiB | +| 0 N/A N/A 72779 C /usr/bin/python3 610MiB | +| 0 N/A N/A 72780 C /usr/bin/python3 610MiB | +| 0 N/A N/A 72781 C /usr/bin/python3 610MiB | +| 1 N/A N/A 72775 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 72776 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 72777 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 72778 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 72779 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 72780 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 72781 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1705 train_time:390ms step_avg:390.50ms +step:2/1705 train_time:411ms step_avg:205.49ms +step:3/1705 train_time:481ms step_avg:160.24ms +step:4/1705 train_time:571ms step_avg:142.85ms +step:5/1705 
train_time:664ms step_avg:132.73ms +step:6/1705 train_time:756ms step_avg:126.05ms +step:7/1705 train_time:848ms step_avg:121.18ms +step:8/1705 train_time:940ms step_avg:117.56ms +step:9/1705 train_time:1033ms step_avg:114.77ms +step:10/1705 train_time:1125ms step_avg:112.53ms +step:11/1705 train_time:1217ms step_avg:110.68ms +step:12/1705 train_time:1311ms step_avg:109.24ms +step:13/1705 train_time:1407ms step_avg:108.25ms +step:14/1705 train_time:1503ms step_avg:107.35ms +step:15/1705 train_time:1596ms step_avg:106.38ms +step:16/1705 train_time:1688ms step_avg:105.52ms +step:17/1705 train_time:1782ms step_avg:104.82ms +step:18/1705 train_time:1875ms step_avg:104.15ms +step:19/1705 train_time:1967ms step_avg:103.53ms +step:20/1705 train_time:2060ms step_avg:102.98ms +step:21/1705 train_time:2153ms step_avg:102.50ms +step:22/1705 train_time:2246ms step_avg:102.09ms +step:23/1705 train_time:2340ms step_avg:101.73ms +step:24/1705 train_time:2433ms step_avg:101.38ms +step:25/1705 train_time:2527ms step_avg:101.09ms +step:26/1705 train_time:2620ms step_avg:100.78ms +step:27/1705 train_time:2713ms step_avg:100.47ms +step:28/1705 train_time:2807ms step_avg:100.24ms +step:29/1705 train_time:2900ms step_avg:99.99ms +step:30/1705 train_time:2992ms step_avg:99.74ms +step:31/1705 train_time:3085ms step_avg:99.51ms +step:32/1705 train_time:3178ms step_avg:99.31ms +step:33/1705 train_time:3271ms step_avg:99.12ms +step:34/1705 train_time:3364ms step_avg:98.95ms +step:35/1705 train_time:3458ms step_avg:98.79ms +step:36/1705 train_time:3550ms step_avg:98.62ms +step:37/1705 train_time:3644ms step_avg:98.49ms +step:38/1705 train_time:3737ms step_avg:98.35ms +step:39/1705 train_time:3830ms step_avg:98.21ms +step:40/1705 train_time:3924ms step_avg:98.10ms +step:41/1705 train_time:4017ms step_avg:97.98ms +step:42/1705 train_time:4110ms step_avg:97.85ms +step:43/1705 train_time:4204ms step_avg:97.76ms +step:44/1705 train_time:4297ms step_avg:97.65ms +step:45/1705 train_time:4389ms step_avg:97.54ms +step:46/1705 train_time:4483ms step_avg:97.45ms +step:47/1705 train_time:4576ms step_avg:97.36ms +step:48/1705 train_time:4669ms step_avg:97.27ms +step:49/1705 train_time:4763ms step_avg:97.20ms +step:50/1705 train_time:4856ms step_avg:97.13ms +step:51/1705 train_time:4950ms step_avg:97.05ms +step:52/1705 train_time:5044ms step_avg:96.99ms +step:53/1705 train_time:5137ms step_avg:96.92ms +step:54/1705 train_time:5230ms step_avg:96.85ms +step:55/1705 train_time:5323ms step_avg:96.78ms +step:56/1705 train_time:5415ms step_avg:96.70ms +step:57/1705 train_time:5509ms step_avg:96.64ms +step:58/1705 train_time:5602ms step_avg:96.59ms +step:59/1705 train_time:5695ms step_avg:96.52ms +step:60/1705 train_time:5788ms step_avg:96.47ms +step:61/1705 train_time:5882ms step_avg:96.42ms +step:62/1705 train_time:5975ms step_avg:96.36ms +step:63/1705 train_time:6068ms step_avg:96.32ms +step:64/1705 train_time:6162ms step_avg:96.28ms +step:65/1705 train_time:6256ms step_avg:96.24ms +step:66/1705 train_time:6348ms step_avg:96.19ms +step:67/1705 train_time:6442ms step_avg:96.14ms +step:68/1705 train_time:6535ms step_avg:96.10ms +step:69/1705 train_time:6628ms step_avg:96.05ms +step:70/1705 train_time:6721ms step_avg:96.01ms +step:71/1705 train_time:6814ms step_avg:95.97ms +step:72/1705 train_time:6907ms step_avg:95.93ms +step:73/1705 train_time:7002ms step_avg:95.91ms +step:74/1705 train_time:7094ms step_avg:95.87ms +step:75/1705 train_time:7188ms step_avg:95.84ms +step:76/1705 train_time:7281ms step_avg:95.80ms +step:77/1705 
train_time:7374ms step_avg:95.77ms +step:78/1705 train_time:7467ms step_avg:95.74ms +step:79/1705 train_time:7560ms step_avg:95.70ms +step:80/1705 train_time:7653ms step_avg:95.66ms +step:81/1705 train_time:7746ms step_avg:95.63ms +step:82/1705 train_time:7839ms step_avg:95.59ms +step:83/1705 train_time:7931ms step_avg:95.56ms +step:84/1705 train_time:8026ms step_avg:95.54ms +step:85/1705 train_time:8119ms step_avg:95.51ms +step:86/1705 train_time:8211ms step_avg:95.48ms +step:87/1705 train_time:8305ms step_avg:95.46ms +step:88/1705 train_time:8397ms step_avg:95.42ms +step:89/1705 train_time:8490ms step_avg:95.40ms +step:90/1705 train_time:8584ms step_avg:95.38ms +step:91/1705 train_time:8677ms step_avg:95.35ms +step:92/1705 train_time:8770ms step_avg:95.33ms +step:93/1705 train_time:8863ms step_avg:95.30ms +step:94/1705 train_time:8956ms step_avg:95.27ms +step:95/1705 train_time:9049ms step_avg:95.25ms +step:96/1705 train_time:9142ms step_avg:95.23ms +step:97/1705 train_time:9235ms step_avg:95.21ms +step:98/1705 train_time:9328ms step_avg:95.18ms +step:99/1705 train_time:9421ms step_avg:95.16ms +step:100/1705 train_time:9514ms step_avg:95.14ms +step:101/1705 train_time:9608ms step_avg:95.13ms +step:102/1705 train_time:9701ms step_avg:95.10ms +step:103/1705 train_time:9793ms step_avg:95.08ms +step:104/1705 train_time:9887ms step_avg:95.07ms +step:105/1705 train_time:9980ms step_avg:95.05ms +step:106/1705 train_time:10073ms step_avg:95.03ms +step:107/1705 train_time:10167ms step_avg:95.02ms +step:108/1705 train_time:10260ms step_avg:95.00ms +step:109/1705 train_time:10352ms step_avg:94.97ms +step:110/1705 train_time:10445ms step_avg:94.96ms +step:111/1705 train_time:10537ms step_avg:94.93ms +step:112/1705 train_time:10630ms step_avg:94.91ms +step:113/1705 train_time:10723ms step_avg:94.90ms +step:114/1705 train_time:10816ms step_avg:94.88ms +step:115/1705 train_time:10909ms step_avg:94.86ms +step:116/1705 train_time:11002ms step_avg:94.85ms +step:117/1705 train_time:11095ms step_avg:94.83ms +step:118/1705 train_time:11189ms step_avg:94.82ms +step:119/1705 train_time:11282ms step_avg:94.81ms +step:120/1705 train_time:11374ms step_avg:94.78ms +step:121/1705 train_time:11467ms step_avg:94.77ms +step:122/1705 train_time:11560ms step_avg:94.75ms +step:123/1705 train_time:11652ms step_avg:94.73ms +step:124/1705 train_time:11745ms step_avg:94.72ms +step:125/1705 train_time:11839ms step_avg:94.71ms +step:125/1705 val_loss:4.2939 train_time:11931ms step_avg:95.45ms +step:126/1705 train_time:11957ms step_avg:94.89ms +step:127/1705 train_time:12031ms step_avg:94.73ms +step:128/1705 train_time:12133ms step_avg:94.79ms +step:129/1705 train_time:12229ms step_avg:94.80ms +step:130/1705 train_time:12323ms step_avg:94.79ms +step:131/1705 train_time:12415ms step_avg:94.77ms +step:132/1705 train_time:12507ms step_avg:94.75ms +step:133/1705 train_time:12599ms step_avg:94.73ms +step:134/1705 train_time:12690ms step_avg:94.70ms +step:135/1705 train_time:12783ms step_avg:94.69ms +step:136/1705 train_time:12875ms step_avg:94.67ms +step:137/1705 train_time:12968ms step_avg:94.66ms +step:138/1705 train_time:13064ms step_avg:94.67ms +step:139/1705 train_time:13160ms step_avg:94.67ms +step:140/1705 train_time:13253ms step_avg:94.67ms +step:141/1705 train_time:13346ms step_avg:94.66ms +step:142/1705 train_time:13440ms step_avg:94.64ms +step:143/1705 train_time:13532ms step_avg:94.63ms +step:144/1705 train_time:13624ms step_avg:94.61ms +step:145/1705 train_time:13716ms step_avg:94.60ms +step:146/1705 train_time:13808ms 
step_avg:94.58ms +step:147/1705 train_time:13901ms step_avg:94.57ms +step:148/1705 train_time:13993ms step_avg:94.55ms +step:149/1705 train_time:14087ms step_avg:94.55ms +step:150/1705 train_time:14183ms step_avg:94.55ms +step:151/1705 train_time:14276ms step_avg:94.54ms +step:152/1705 train_time:14369ms step_avg:94.53ms +step:153/1705 train_time:14462ms step_avg:94.52ms +step:154/1705 train_time:14555ms step_avg:94.51ms +step:155/1705 train_time:14647ms step_avg:94.50ms +step:156/1705 train_time:14740ms step_avg:94.49ms +step:157/1705 train_time:14832ms step_avg:94.47ms +step:158/1705 train_time:14926ms step_avg:94.47ms +step:159/1705 train_time:15018ms step_avg:94.45ms +step:160/1705 train_time:15111ms step_avg:94.44ms +step:161/1705 train_time:15205ms step_avg:94.44ms +step:162/1705 train_time:15298ms step_avg:94.43ms +step:163/1705 train_time:15390ms step_avg:94.42ms +step:164/1705 train_time:15484ms step_avg:94.42ms +step:165/1705 train_time:15577ms step_avg:94.41ms +step:166/1705 train_time:15670ms step_avg:94.40ms +step:167/1705 train_time:15763ms step_avg:94.39ms +step:168/1705 train_time:15855ms step_avg:94.38ms +step:169/1705 train_time:15948ms step_avg:94.37ms +step:170/1705 train_time:16041ms step_avg:94.36ms +step:171/1705 train_time:16133ms step_avg:94.35ms +step:172/1705 train_time:16228ms step_avg:94.35ms +step:173/1705 train_time:16321ms step_avg:94.34ms +step:174/1705 train_time:16415ms step_avg:94.34ms +step:175/1705 train_time:16508ms step_avg:94.33ms +step:176/1705 train_time:16601ms step_avg:94.33ms +step:177/1705 train_time:16694ms step_avg:94.31ms +step:178/1705 train_time:16787ms step_avg:94.31ms +step:179/1705 train_time:16881ms step_avg:94.30ms +step:180/1705 train_time:16974ms step_avg:94.30ms +step:181/1705 train_time:17067ms step_avg:94.29ms +step:182/1705 train_time:17160ms step_avg:94.28ms +step:183/1705 train_time:17252ms step_avg:94.27ms +step:184/1705 train_time:17345ms step_avg:94.27ms +step:185/1705 train_time:17439ms step_avg:94.26ms +step:186/1705 train_time:17532ms step_avg:94.26ms +step:187/1705 train_time:17625ms step_avg:94.25ms +step:188/1705 train_time:17718ms step_avg:94.24ms +step:189/1705 train_time:17810ms step_avg:94.23ms +step:190/1705 train_time:17905ms step_avg:94.24ms +step:191/1705 train_time:17998ms step_avg:94.23ms +step:192/1705 train_time:18091ms step_avg:94.22ms +step:193/1705 train_time:18185ms step_avg:94.22ms +step:194/1705 train_time:18278ms step_avg:94.22ms +step:195/1705 train_time:18370ms step_avg:94.21ms +step:196/1705 train_time:18464ms step_avg:94.20ms +step:197/1705 train_time:18557ms step_avg:94.20ms +step:198/1705 train_time:18650ms step_avg:94.19ms +step:199/1705 train_time:18743ms step_avg:94.19ms +step:200/1705 train_time:18836ms step_avg:94.18ms +step:201/1705 train_time:18928ms step_avg:94.17ms +step:202/1705 train_time:19021ms step_avg:94.16ms +step:203/1705 train_time:19114ms step_avg:94.16ms +step:204/1705 train_time:19207ms step_avg:94.15ms +step:205/1705 train_time:19301ms step_avg:94.15ms +step:206/1705 train_time:19394ms step_avg:94.15ms +step:207/1705 train_time:19487ms step_avg:94.14ms +step:208/1705 train_time:19582ms step_avg:94.14ms +step:209/1705 train_time:19675ms step_avg:94.14ms +step:210/1705 train_time:19768ms step_avg:94.13ms +step:211/1705 train_time:19862ms step_avg:94.13ms +step:212/1705 train_time:19954ms step_avg:94.12ms +step:213/1705 train_time:20336ms step_avg:95.47ms +step:214/1705 train_time:20405ms step_avg:95.35ms +step:215/1705 train_time:20496ms step_avg:95.33ms +step:216/1705 
train_time:20588ms step_avg:95.31ms
[... per-step "train_time/step_avg" log lines for steps 217-1637 elided; validation checkpoints retained below ...]
+step:250/1705 val_loss:3.9797 train_time:23838ms step_avg:95.35ms
+step:375/1705 val_loss:3.8236 train_time:35452ms step_avg:94.54ms
+step:500/1705 val_loss:3.7233 train_time:47317ms step_avg:94.63ms
+step:625/1705 val_loss:3.6227 train_time:59033ms step_avg:94.45ms
+step:750/1705 val_loss:3.5689 train_time:71113ms step_avg:94.82ms
+step:875/1705 val_loss:3.5262 train_time:83086ms step_avg:94.96ms
+step:1000/1705 val_loss:3.4882 train_time:94899ms step_avg:94.90ms
+step:1125/1705 val_loss:3.4409 train_time:106969ms step_avg:95.08ms
+step:1250/1705 val_loss:3.3920 train_time:118886ms step_avg:95.11ms
+step:1375/1705 val_loss:3.3546 train_time:131071ms step_avg:95.32ms
+step:1500/1705 val_loss:3.3225 train_time:143261ms step_avg:95.51ms
+step:1625/1705 val_loss:3.2949 train_time:155190ms step_avg:95.50ms
+step:1638/1705 train_time:156347ms
step_avg:95.45ms +step:1639/1705 train_time:156442ms step_avg:95.45ms +step:1640/1705 train_time:156538ms step_avg:95.45ms +step:1641/1705 train_time:156634ms step_avg:95.45ms +step:1642/1705 train_time:156727ms step_avg:95.45ms +step:1643/1705 train_time:156821ms step_avg:95.45ms +step:1644/1705 train_time:156917ms step_avg:95.45ms +step:1645/1705 train_time:157011ms step_avg:95.45ms +step:1646/1705 train_time:157107ms step_avg:95.45ms +step:1647/1705 train_time:157203ms step_avg:95.45ms +step:1648/1705 train_time:157302ms step_avg:95.45ms +step:1649/1705 train_time:157399ms step_avg:95.45ms +step:1650/1705 train_time:157495ms step_avg:95.45ms +step:1651/1705 train_time:157590ms step_avg:95.45ms +step:1652/1705 train_time:157685ms step_avg:95.45ms +step:1653/1705 train_time:157779ms step_avg:95.45ms +step:1654/1705 train_time:157873ms step_avg:95.45ms +step:1655/1705 train_time:157968ms step_avg:95.45ms +step:1656/1705 train_time:158064ms step_avg:95.45ms +step:1657/1705 train_time:158159ms step_avg:95.45ms +step:1658/1705 train_time:158255ms step_avg:95.45ms +step:1659/1705 train_time:158350ms step_avg:95.45ms +step:1660/1705 train_time:158445ms step_avg:95.45ms +step:1661/1705 train_time:158542ms step_avg:95.45ms +step:1662/1705 train_time:158639ms step_avg:95.45ms +step:1663/1705 train_time:158734ms step_avg:95.45ms +step:1664/1705 train_time:158829ms step_avg:95.45ms +step:1665/1705 train_time:158924ms step_avg:95.45ms +step:1666/1705 train_time:159019ms step_avg:95.45ms +step:1667/1705 train_time:159114ms step_avg:95.45ms +step:1668/1705 train_time:159208ms step_avg:95.45ms +step:1669/1705 train_time:159305ms step_avg:95.45ms +step:1670/1705 train_time:159401ms step_avg:95.45ms +step:1671/1705 train_time:159497ms step_avg:95.45ms +step:1672/1705 train_time:159592ms step_avg:95.45ms +step:1673/1705 train_time:159687ms step_avg:95.45ms +step:1674/1705 train_time:159782ms step_avg:95.45ms +step:1675/1705 train_time:159879ms step_avg:95.45ms +step:1676/1705 train_time:159974ms step_avg:95.45ms +step:1677/1705 train_time:160068ms step_avg:95.45ms +step:1678/1705 train_time:160163ms step_avg:95.45ms +step:1679/1705 train_time:160260ms step_avg:95.45ms +step:1680/1705 train_time:160356ms step_avg:95.45ms +step:1681/1705 train_time:160451ms step_avg:95.45ms +step:1682/1705 train_time:160547ms step_avg:95.45ms +step:1683/1705 train_time:160643ms step_avg:95.45ms +step:1684/1705 train_time:160738ms step_avg:95.45ms +step:1685/1705 train_time:160834ms step_avg:95.45ms +step:1686/1705 train_time:160928ms step_avg:95.45ms +step:1687/1705 train_time:161024ms step_avg:95.45ms +step:1688/1705 train_time:161120ms step_avg:95.45ms +step:1689/1705 train_time:161215ms step_avg:95.45ms +step:1690/1705 train_time:161310ms step_avg:95.45ms +step:1691/1705 train_time:161406ms step_avg:95.45ms +step:1692/1705 train_time:161501ms step_avg:95.45ms +step:1693/1705 train_time:161597ms step_avg:95.45ms +step:1694/1705 train_time:161692ms step_avg:95.45ms +step:1695/1705 train_time:161787ms step_avg:95.45ms +step:1696/1705 train_time:161882ms step_avg:95.45ms +step:1697/1705 train_time:161978ms step_avg:95.45ms +step:1698/1705 train_time:162294ms step_avg:95.58ms +step:1699/1705 train_time:162421ms step_avg:95.60ms +step:1700/1705 train_time:162514ms step_avg:95.60ms +step:1701/1705 train_time:162607ms step_avg:95.59ms +step:1702/1705 train_time:162701ms step_avg:95.59ms +step:1703/1705 train_time:162795ms step_avg:95.59ms +step:1704/1705 train_time:162889ms step_avg:95.59ms +step:1705/1705 train_time:162983ms 
step_avg:95.59ms +step:1705/1705 val_loss:3.2806 train_time:163078ms step_avg:95.65ms +peak memory allocated: 34489 MiB reserved: 48516 MiB diff --git a/records/050925_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt b/records/050925_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt new file mode 100644 index 000000000..e269a2db6 --- /dev/null +++ b/records/050925_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, 
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
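+
+ Example (mirroring the instantiation further down in this script):
+
+ hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+ optimizer = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)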
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
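+ # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 393 * 128), so the lm_head's
+ # output dimension stays a multiple of 128.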
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
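+ # (BOSFinder.next_batch signals shard exhaustion via StopIteration; rebuilding
+ # the finder below restarts its cursor at the first BOS of the fresh shard)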
+ tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1705 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:36:47 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 128W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 123W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 33C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 69624 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 69625 C /usr/bin/python3 610MiB | +| 0 N/A N/A 69626 C /usr/bin/python3 610MiB | +| 0 N/A N/A 69627 C /usr/bin/python3 610MiB | +| 0 N/A N/A 69628 C /usr/bin/python3 610MiB | +| 0 N/A N/A 69629 C /usr/bin/python3 610MiB | +| 0 N/A N/A 69630 C /usr/bin/python3 610MiB | +| 0 N/A N/A 69631 C /usr/bin/python3 610MiB | +| 1 N/A N/A 69625 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 69626 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 69627 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 69628 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 69629 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 69630 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 69631 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1705 train_time:399ms step_avg:398.69ms +step:2/1705 train_time:419ms step_avg:209.65ms +step:3/1705 train_time:488ms step_avg:162.60ms +step:4/1705 train_time:579ms step_avg:144.80ms +step:5/1705 
train_time:671ms step_avg:134.24ms +step:6/1705 train_time:763ms step_avg:127.22ms +step:7/1705 train_time:856ms step_avg:122.29ms +step:8/1705 train_time:948ms step_avg:118.50ms +step:9/1705 train_time:1040ms step_avg:115.56ms +step:10/1705 train_time:1132ms step_avg:113.24ms +step:11/1705 train_time:1224ms step_avg:111.31ms +step:12/1705 train_time:1319ms step_avg:109.88ms +step:13/1705 train_time:1415ms step_avg:108.82ms +step:14/1705 train_time:1509ms step_avg:107.77ms +step:15/1705 train_time:1603ms step_avg:106.86ms +step:16/1705 train_time:1695ms step_avg:105.97ms +step:17/1705 train_time:1788ms step_avg:105.15ms +step:18/1705 train_time:1880ms step_avg:104.43ms +step:19/1705 train_time:1973ms step_avg:103.84ms +step:20/1705 train_time:2066ms step_avg:103.29ms +step:21/1705 train_time:2159ms step_avg:102.80ms +step:22/1705 train_time:2252ms step_avg:102.36ms +step:23/1705 train_time:2346ms step_avg:101.99ms +step:24/1705 train_time:2439ms step_avg:101.64ms +step:25/1705 train_time:2534ms step_avg:101.35ms +step:26/1705 train_time:2626ms step_avg:101.00ms +step:27/1705 train_time:2719ms step_avg:100.69ms +step:28/1705 train_time:2812ms step_avg:100.44ms +step:29/1705 train_time:2906ms step_avg:100.19ms +step:30/1705 train_time:2998ms step_avg:99.95ms +step:31/1705 train_time:3091ms step_avg:99.72ms +step:32/1705 train_time:3184ms step_avg:99.51ms +step:33/1705 train_time:3278ms step_avg:99.34ms +step:34/1705 train_time:3372ms step_avg:99.18ms +step:35/1705 train_time:3466ms step_avg:99.02ms +step:36/1705 train_time:3559ms step_avg:98.86ms +step:37/1705 train_time:3652ms step_avg:98.71ms +step:38/1705 train_time:3745ms step_avg:98.56ms +step:39/1705 train_time:3839ms step_avg:98.42ms +step:40/1705 train_time:3932ms step_avg:98.29ms +step:41/1705 train_time:4025ms step_avg:98.17ms +step:42/1705 train_time:4118ms step_avg:98.04ms +step:43/1705 train_time:4211ms step_avg:97.92ms +step:44/1705 train_time:4304ms step_avg:97.82ms +step:45/1705 train_time:4398ms step_avg:97.72ms +step:46/1705 train_time:4491ms step_avg:97.63ms +step:47/1705 train_time:4584ms step_avg:97.54ms +step:48/1705 train_time:4678ms step_avg:97.46ms +step:49/1705 train_time:4771ms step_avg:97.38ms +step:50/1705 train_time:4864ms step_avg:97.29ms +step:51/1705 train_time:4958ms step_avg:97.22ms +step:52/1705 train_time:5051ms step_avg:97.14ms +step:53/1705 train_time:5143ms step_avg:97.04ms +step:54/1705 train_time:5236ms step_avg:96.97ms +step:55/1705 train_time:5330ms step_avg:96.90ms +step:56/1705 train_time:5423ms step_avg:96.84ms +step:57/1705 train_time:5516ms step_avg:96.77ms +step:58/1705 train_time:5609ms step_avg:96.70ms +step:59/1705 train_time:5702ms step_avg:96.64ms +step:60/1705 train_time:5795ms step_avg:96.58ms +step:61/1705 train_time:5888ms step_avg:96.52ms +step:62/1705 train_time:5981ms step_avg:96.47ms +step:63/1705 train_time:6075ms step_avg:96.43ms +step:64/1705 train_time:6168ms step_avg:96.38ms +step:65/1705 train_time:6262ms step_avg:96.34ms +step:66/1705 train_time:6355ms step_avg:96.29ms +step:67/1705 train_time:6448ms step_avg:96.24ms +step:68/1705 train_time:6541ms step_avg:96.19ms +step:69/1705 train_time:6634ms step_avg:96.15ms +step:70/1705 train_time:6728ms step_avg:96.11ms +step:71/1705 train_time:6821ms step_avg:96.06ms +step:72/1705 train_time:6914ms step_avg:96.03ms +step:73/1705 train_time:7007ms step_avg:95.99ms +step:74/1705 train_time:7101ms step_avg:95.96ms +step:75/1705 train_time:7195ms step_avg:95.93ms +step:76/1705 train_time:7288ms step_avg:95.90ms +step:77/1705 
train_time:7381ms step_avg:95.86ms +step:78/1705 train_time:7475ms step_avg:95.83ms +step:79/1705 train_time:7567ms step_avg:95.78ms +step:80/1705 train_time:7660ms step_avg:95.75ms +step:81/1705 train_time:7753ms step_avg:95.72ms +step:82/1705 train_time:7846ms step_avg:95.69ms +step:83/1705 train_time:7939ms step_avg:95.65ms +step:84/1705 train_time:8032ms step_avg:95.62ms +step:85/1705 train_time:8125ms step_avg:95.59ms +step:86/1705 train_time:8219ms step_avg:95.56ms +step:87/1705 train_time:8312ms step_avg:95.54ms +step:88/1705 train_time:8404ms step_avg:95.50ms +step:89/1705 train_time:8497ms step_avg:95.47ms +step:90/1705 train_time:8590ms step_avg:95.44ms +step:91/1705 train_time:8683ms step_avg:95.41ms +step:92/1705 train_time:8776ms step_avg:95.39ms +step:93/1705 train_time:8870ms step_avg:95.37ms +step:94/1705 train_time:8962ms step_avg:95.34ms +step:95/1705 train_time:9055ms step_avg:95.32ms +step:96/1705 train_time:9150ms step_avg:95.31ms +step:97/1705 train_time:9241ms step_avg:95.27ms +step:98/1705 train_time:9335ms step_avg:95.25ms +step:99/1705 train_time:9427ms step_avg:95.22ms +step:100/1705 train_time:9521ms step_avg:95.21ms +step:101/1705 train_time:9614ms step_avg:95.19ms +step:102/1705 train_time:9707ms step_avg:95.16ms +step:103/1705 train_time:9800ms step_avg:95.15ms +step:104/1705 train_time:9894ms step_avg:95.13ms +step:105/1705 train_time:9986ms step_avg:95.11ms +step:106/1705 train_time:10079ms step_avg:95.09ms +step:107/1705 train_time:10173ms step_avg:95.07ms +step:108/1705 train_time:10266ms step_avg:95.06ms +step:109/1705 train_time:10360ms step_avg:95.04ms +step:110/1705 train_time:10453ms step_avg:95.03ms +step:111/1705 train_time:10545ms step_avg:95.00ms +step:112/1705 train_time:10638ms step_avg:94.99ms +step:113/1705 train_time:10731ms step_avg:94.97ms +step:114/1705 train_time:10824ms step_avg:94.95ms +step:115/1705 train_time:10918ms step_avg:94.94ms +step:116/1705 train_time:11010ms step_avg:94.92ms +step:117/1705 train_time:11104ms step_avg:94.90ms +step:118/1705 train_time:11196ms step_avg:94.88ms +step:119/1705 train_time:11289ms step_avg:94.87ms +step:120/1705 train_time:11383ms step_avg:94.86ms +step:121/1705 train_time:11476ms step_avg:94.84ms +step:122/1705 train_time:11568ms step_avg:94.82ms +step:123/1705 train_time:11661ms step_avg:94.80ms +step:124/1705 train_time:11754ms step_avg:94.79ms +step:125/1705 train_time:11847ms step_avg:94.78ms +step:125/1705 val_loss:4.3131 train_time:11940ms step_avg:95.52ms +step:126/1705 train_time:11964ms step_avg:94.95ms +step:127/1705 train_time:12041ms step_avg:94.81ms +step:128/1705 train_time:12140ms step_avg:94.85ms +step:129/1705 train_time:12236ms step_avg:94.85ms +step:130/1705 train_time:12329ms step_avg:94.84ms +step:131/1705 train_time:12422ms step_avg:94.82ms +step:132/1705 train_time:12514ms step_avg:94.80ms +step:133/1705 train_time:12606ms step_avg:94.78ms +step:134/1705 train_time:12698ms step_avg:94.76ms +step:135/1705 train_time:12790ms step_avg:94.74ms +step:136/1705 train_time:12882ms step_avg:94.72ms +step:137/1705 train_time:12975ms step_avg:94.71ms +step:138/1705 train_time:13070ms step_avg:94.71ms +step:139/1705 train_time:13165ms step_avg:94.71ms +step:140/1705 train_time:13259ms step_avg:94.71ms +step:141/1705 train_time:13352ms step_avg:94.69ms +step:142/1705 train_time:13445ms step_avg:94.68ms +step:143/1705 train_time:13537ms step_avg:94.67ms +step:144/1705 train_time:13629ms step_avg:94.65ms +step:145/1705 train_time:13722ms step_avg:94.64ms +step:146/1705 train_time:13815ms 
step_avg:94.62ms +step:147/1705 train_time:13907ms step_avg:94.61ms +step:148/1705 train_time:14001ms step_avg:94.60ms +step:149/1705 train_time:14094ms step_avg:94.59ms +step:150/1705 train_time:14188ms step_avg:94.58ms +step:151/1705 train_time:14281ms step_avg:94.58ms +step:152/1705 train_time:14375ms step_avg:94.57ms +step:153/1705 train_time:14468ms step_avg:94.56ms +step:154/1705 train_time:14560ms step_avg:94.55ms +step:155/1705 train_time:14653ms step_avg:94.54ms +step:156/1705 train_time:14746ms step_avg:94.52ms +step:157/1705 train_time:14838ms step_avg:94.51ms +step:158/1705 train_time:14931ms step_avg:94.50ms +step:159/1705 train_time:15025ms step_avg:94.50ms +step:160/1705 train_time:15119ms step_avg:94.49ms +step:161/1705 train_time:15211ms step_avg:94.48ms +step:162/1705 train_time:15305ms step_avg:94.48ms +step:163/1705 train_time:15398ms step_avg:94.47ms +step:164/1705 train_time:15491ms step_avg:94.45ms +step:165/1705 train_time:15583ms step_avg:94.45ms +step:166/1705 train_time:15676ms step_avg:94.43ms +step:167/1705 train_time:15769ms step_avg:94.42ms +step:168/1705 train_time:15861ms step_avg:94.41ms +step:169/1705 train_time:15954ms step_avg:94.40ms +step:170/1705 train_time:16047ms step_avg:94.39ms +step:171/1705 train_time:16141ms step_avg:94.39ms +step:172/1705 train_time:16234ms step_avg:94.38ms +step:173/1705 train_time:16328ms step_avg:94.38ms +step:174/1705 train_time:16422ms step_avg:94.38ms +step:175/1705 train_time:16515ms step_avg:94.37ms +step:176/1705 train_time:16608ms step_avg:94.36ms +step:177/1705 train_time:16701ms step_avg:94.35ms +step:178/1705 train_time:16793ms step_avg:94.34ms +step:179/1705 train_time:16886ms step_avg:94.33ms +step:180/1705 train_time:16978ms step_avg:94.32ms +step:181/1705 train_time:17071ms step_avg:94.32ms +step:182/1705 train_time:17164ms step_avg:94.31ms +step:183/1705 train_time:17257ms step_avg:94.30ms +step:184/1705 train_time:17350ms step_avg:94.29ms +step:185/1705 train_time:17443ms step_avg:94.29ms +step:186/1705 train_time:17536ms step_avg:94.28ms +step:187/1705 train_time:17628ms step_avg:94.27ms +step:188/1705 train_time:17722ms step_avg:94.26ms +step:189/1705 train_time:17814ms step_avg:94.26ms +step:190/1705 train_time:17907ms step_avg:94.25ms +step:191/1705 train_time:18001ms step_avg:94.25ms +step:192/1705 train_time:18094ms step_avg:94.24ms +step:193/1705 train_time:18186ms step_avg:94.23ms +step:194/1705 train_time:18279ms step_avg:94.22ms +step:195/1705 train_time:18372ms step_avg:94.22ms +step:196/1705 train_time:18466ms step_avg:94.21ms +step:197/1705 train_time:18559ms step_avg:94.21ms +step:198/1705 train_time:18652ms step_avg:94.20ms +step:199/1705 train_time:18745ms step_avg:94.20ms +step:200/1705 train_time:18837ms step_avg:94.19ms +step:201/1705 train_time:18930ms step_avg:94.18ms +step:202/1705 train_time:19024ms step_avg:94.18ms +step:203/1705 train_time:19117ms step_avg:94.17ms +step:204/1705 train_time:19210ms step_avg:94.17ms +step:205/1705 train_time:19303ms step_avg:94.16ms +step:206/1705 train_time:19396ms step_avg:94.16ms +step:207/1705 train_time:19489ms step_avg:94.15ms +step:208/1705 train_time:19583ms step_avg:94.15ms +step:209/1705 train_time:19675ms step_avg:94.14ms +step:210/1705 train_time:19768ms step_avg:94.13ms +step:211/1705 train_time:19861ms step_avg:94.13ms +step:212/1705 train_time:19953ms step_avg:94.12ms +step:213/1705 train_time:20278ms step_avg:95.20ms +step:214/1705 train_time:20382ms step_avg:95.24ms +step:215/1705 train_time:20473ms step_avg:95.22ms +step:216/1705 
train_time:20565ms step_avg:95.21ms +step:217/1705 train_time:20657ms step_avg:95.19ms +step:218/1705 train_time:20749ms step_avg:95.18ms +step:219/1705 train_time:20841ms step_avg:95.17ms +step:220/1705 train_time:20933ms step_avg:95.15ms +step:221/1705 train_time:21024ms step_avg:95.13ms +step:222/1705 train_time:21116ms step_avg:95.12ms +step:223/1705 train_time:21211ms step_avg:95.12ms +step:224/1705 train_time:21307ms step_avg:95.12ms +step:225/1705 train_time:21402ms step_avg:95.12ms +step:226/1705 train_time:21496ms step_avg:95.11ms +step:227/1705 train_time:21588ms step_avg:95.10ms +step:228/1705 train_time:21681ms step_avg:95.09ms +step:229/1705 train_time:21773ms step_avg:95.08ms +step:230/1705 train_time:21866ms step_avg:95.07ms +step:231/1705 train_time:21957ms step_avg:95.05ms +step:232/1705 train_time:22049ms step_avg:95.04ms +step:233/1705 train_time:22141ms step_avg:95.03ms +step:234/1705 train_time:22234ms step_avg:95.02ms +step:235/1705 train_time:22328ms step_avg:95.01ms +step:236/1705 train_time:22422ms step_avg:95.01ms +step:237/1705 train_time:22516ms step_avg:95.00ms +step:238/1705 train_time:22608ms step_avg:94.99ms +step:239/1705 train_time:22701ms step_avg:94.98ms +step:240/1705 train_time:22793ms step_avg:94.97ms +step:241/1705 train_time:22885ms step_avg:94.96ms +step:242/1705 train_time:22977ms step_avg:94.95ms +step:243/1705 train_time:23070ms step_avg:94.94ms +step:244/1705 train_time:23162ms step_avg:94.93ms +step:245/1705 train_time:23255ms step_avg:94.92ms +step:246/1705 train_time:23349ms step_avg:94.91ms +step:247/1705 train_time:23443ms step_avg:94.91ms +step:248/1705 train_time:23536ms step_avg:94.90ms +step:249/1705 train_time:23629ms step_avg:94.90ms +step:250/1705 train_time:23722ms step_avg:94.89ms +step:250/1705 val_loss:3.9764 train_time:23815ms step_avg:95.26ms +step:251/1705 train_time:23838ms step_avg:94.97ms +step:252/1705 train_time:23910ms step_avg:94.88ms +step:253/1705 train_time:24002ms step_avg:94.87ms +step:254/1705 train_time:24098ms step_avg:94.87ms +step:255/1705 train_time:24194ms step_avg:94.88ms +step:256/1705 train_time:24286ms step_avg:94.87ms +step:257/1705 train_time:24378ms step_avg:94.86ms +step:258/1705 train_time:24470ms step_avg:94.85ms +step:259/1705 train_time:24562ms step_avg:94.83ms +step:260/1705 train_time:24653ms step_avg:94.82ms +step:261/1705 train_time:24749ms step_avg:94.82ms +step:262/1705 train_time:24843ms step_avg:94.82ms +step:263/1705 train_time:24936ms step_avg:94.82ms +step:264/1705 train_time:25029ms step_avg:94.81ms +step:265/1705 train_time:25123ms step_avg:94.80ms +step:266/1705 train_time:25216ms step_avg:94.80ms +step:267/1705 train_time:25308ms step_avg:94.79ms +step:268/1705 train_time:25400ms step_avg:94.78ms +step:269/1705 train_time:25492ms step_avg:94.76ms +step:270/1705 train_time:25584ms step_avg:94.75ms +step:271/1705 train_time:25677ms step_avg:94.75ms +step:272/1705 train_time:25771ms step_avg:94.75ms +step:273/1705 train_time:25865ms step_avg:94.74ms +step:274/1705 train_time:25958ms step_avg:94.74ms +step:275/1705 train_time:26052ms step_avg:94.73ms +step:276/1705 train_time:26146ms step_avg:94.73ms +step:277/1705 train_time:26238ms step_avg:94.72ms +step:278/1705 train_time:26331ms step_avg:94.72ms +step:279/1705 train_time:26423ms step_avg:94.71ms +step:280/1705 train_time:26515ms step_avg:94.70ms +step:281/1705 train_time:26608ms step_avg:94.69ms +step:282/1705 train_time:26700ms step_avg:94.68ms +step:283/1705 train_time:26793ms step_avg:94.67ms +step:284/1705 train_time:26886ms 
step_avg:94.67ms +step:285/1705 train_time:26979ms step_avg:94.66ms +step:286/1705 train_time:27072ms step_avg:94.66ms +step:287/1705 train_time:27166ms step_avg:94.66ms +step:288/1705 train_time:27259ms step_avg:94.65ms +step:289/1705 train_time:27352ms step_avg:94.64ms +step:290/1705 train_time:27445ms step_avg:94.64ms +step:291/1705 train_time:27538ms step_avg:94.63ms +step:292/1705 train_time:27631ms step_avg:94.63ms +step:293/1705 train_time:27725ms step_avg:94.62ms +step:294/1705 train_time:27817ms step_avg:94.62ms +step:295/1705 train_time:27910ms step_avg:94.61ms +step:296/1705 train_time:28002ms step_avg:94.60ms +step:297/1705 train_time:28095ms step_avg:94.60ms +step:298/1705 train_time:28188ms step_avg:94.59ms +step:299/1705 train_time:28280ms step_avg:94.58ms +step:300/1705 train_time:28373ms step_avg:94.58ms +step:301/1705 train_time:28467ms step_avg:94.57ms +step:302/1705 train_time:28559ms step_avg:94.57ms +step:303/1705 train_time:28652ms step_avg:94.56ms +step:304/1705 train_time:28745ms step_avg:94.55ms +step:305/1705 train_time:28837ms step_avg:94.55ms +step:306/1705 train_time:28930ms step_avg:94.54ms +step:307/1705 train_time:29023ms step_avg:94.54ms +step:308/1705 train_time:29115ms step_avg:94.53ms +step:309/1705 train_time:29209ms step_avg:94.53ms +step:310/1705 train_time:29302ms step_avg:94.52ms +step:311/1705 train_time:29394ms step_avg:94.51ms +step:312/1705 train_time:29488ms step_avg:94.51ms +step:313/1705 train_time:29581ms step_avg:94.51ms +step:314/1705 train_time:29675ms step_avg:94.51ms +step:315/1705 train_time:29768ms step_avg:94.50ms +step:316/1705 train_time:29860ms step_avg:94.49ms +step:317/1705 train_time:29953ms step_avg:94.49ms +step:318/1705 train_time:30046ms step_avg:94.48ms +step:319/1705 train_time:30138ms step_avg:94.48ms +step:320/1705 train_time:30232ms step_avg:94.47ms +step:321/1705 train_time:30324ms step_avg:94.47ms +step:322/1705 train_time:30417ms step_avg:94.46ms +step:323/1705 train_time:30511ms step_avg:94.46ms +step:324/1705 train_time:30604ms step_avg:94.46ms +step:325/1705 train_time:30696ms step_avg:94.45ms +step:326/1705 train_time:30789ms step_avg:94.45ms +step:327/1705 train_time:30881ms step_avg:94.44ms +step:328/1705 train_time:30975ms step_avg:94.43ms +step:329/1705 train_time:31067ms step_avg:94.43ms +step:330/1705 train_time:31159ms step_avg:94.42ms +step:331/1705 train_time:31252ms step_avg:94.42ms +step:332/1705 train_time:31345ms step_avg:94.41ms +step:333/1705 train_time:31438ms step_avg:94.41ms +step:334/1705 train_time:31531ms step_avg:94.41ms +step:335/1705 train_time:31624ms step_avg:94.40ms +step:336/1705 train_time:31717ms step_avg:94.40ms +step:337/1705 train_time:31811ms step_avg:94.39ms +step:338/1705 train_time:31903ms step_avg:94.39ms +step:339/1705 train_time:31995ms step_avg:94.38ms +step:340/1705 train_time:32088ms step_avg:94.38ms +step:341/1705 train_time:32181ms step_avg:94.37ms +step:342/1705 train_time:32274ms step_avg:94.37ms +step:343/1705 train_time:32366ms step_avg:94.36ms +step:344/1705 train_time:32459ms step_avg:94.36ms +step:345/1705 train_time:32553ms step_avg:94.36ms +step:346/1705 train_time:32647ms step_avg:94.35ms +step:347/1705 train_time:32739ms step_avg:94.35ms +step:348/1705 train_time:32831ms step_avg:94.34ms +step:349/1705 train_time:32924ms step_avg:94.34ms +step:350/1705 train_time:33016ms step_avg:94.33ms +step:351/1705 train_time:33109ms step_avg:94.33ms +step:352/1705 train_time:33202ms step_avg:94.32ms +step:353/1705 train_time:33294ms step_avg:94.32ms +step:354/1705 
train_time:33388ms step_avg:94.32ms +step:355/1705 train_time:33481ms step_avg:94.31ms +step:356/1705 train_time:33574ms step_avg:94.31ms +step:357/1705 train_time:33667ms step_avg:94.31ms +step:358/1705 train_time:33761ms step_avg:94.30ms +step:359/1705 train_time:33854ms step_avg:94.30ms +step:360/1705 train_time:33947ms step_avg:94.30ms +step:361/1705 train_time:34039ms step_avg:94.29ms +step:362/1705 train_time:34132ms step_avg:94.29ms +step:363/1705 train_time:34225ms step_avg:94.28ms +step:364/1705 train_time:34318ms step_avg:94.28ms +step:365/1705 train_time:34412ms step_avg:94.28ms +step:366/1705 train_time:34505ms step_avg:94.28ms +step:367/1705 train_time:34598ms step_avg:94.27ms +step:368/1705 train_time:34691ms step_avg:94.27ms +step:369/1705 train_time:34783ms step_avg:94.26ms +step:370/1705 train_time:34876ms step_avg:94.26ms +step:371/1705 train_time:34968ms step_avg:94.25ms +step:372/1705 train_time:35061ms step_avg:94.25ms +step:373/1705 train_time:35153ms step_avg:94.25ms +step:374/1705 train_time:35247ms step_avg:94.24ms +step:375/1705 train_time:35339ms step_avg:94.24ms +step:375/1705 val_loss:3.8205 train_time:35433ms step_avg:94.49ms +step:376/1705 train_time:35458ms step_avg:94.30ms +step:377/1705 train_time:35530ms step_avg:94.24ms +step:378/1705 train_time:35630ms step_avg:94.26ms +step:379/1705 train_time:35724ms step_avg:94.26ms +step:380/1705 train_time:35816ms step_avg:94.25ms +step:381/1705 train_time:35909ms step_avg:94.25ms +step:382/1705 train_time:36001ms step_avg:94.24ms +step:383/1705 train_time:36093ms step_avg:94.24ms +step:384/1705 train_time:36185ms step_avg:94.23ms +step:385/1705 train_time:36277ms step_avg:94.23ms +step:386/1705 train_time:36369ms step_avg:94.22ms +step:387/1705 train_time:36463ms step_avg:94.22ms +step:388/1705 train_time:36557ms step_avg:94.22ms +step:389/1705 train_time:36652ms step_avg:94.22ms +step:390/1705 train_time:36746ms step_avg:94.22ms +step:391/1705 train_time:36839ms step_avg:94.22ms +step:392/1705 train_time:36932ms step_avg:94.21ms +step:393/1705 train_time:37024ms step_avg:94.21ms +step:394/1705 train_time:37117ms step_avg:94.20ms +step:395/1705 train_time:37208ms step_avg:94.20ms +step:396/1705 train_time:37300ms step_avg:94.19ms +step:397/1705 train_time:37393ms step_avg:94.19ms +step:398/1705 train_time:37486ms step_avg:94.19ms +step:399/1705 train_time:37580ms step_avg:94.19ms +step:400/1705 train_time:37673ms step_avg:94.18ms +step:401/1705 train_time:37768ms step_avg:94.18ms +step:402/1705 train_time:37862ms step_avg:94.18ms +step:403/1705 train_time:37954ms step_avg:94.18ms +step:404/1705 train_time:38047ms step_avg:94.18ms +step:405/1705 train_time:38139ms step_avg:94.17ms +step:406/1705 train_time:38231ms step_avg:94.17ms +step:407/1705 train_time:38324ms step_avg:94.16ms +step:408/1705 train_time:38416ms step_avg:94.16ms +step:409/1705 train_time:38509ms step_avg:94.15ms +step:410/1705 train_time:38602ms step_avg:94.15ms +step:411/1705 train_time:38695ms step_avg:94.15ms +step:412/1705 train_time:38790ms step_avg:94.15ms +step:413/1705 train_time:38884ms step_avg:94.15ms +step:414/1705 train_time:38976ms step_avg:94.14ms +step:415/1705 train_time:39069ms step_avg:94.14ms +step:416/1705 train_time:39161ms step_avg:94.14ms +step:417/1705 train_time:39254ms step_avg:94.13ms +step:418/1705 train_time:39346ms step_avg:94.13ms +step:419/1705 train_time:39439ms step_avg:94.13ms +step:420/1705 train_time:39532ms step_avg:94.12ms +step:421/1705 train_time:39625ms step_avg:94.12ms +step:422/1705 train_time:39719ms 
step_avg:94.12ms +step:423/1705 train_time:39812ms step_avg:94.12ms +step:424/1705 train_time:39906ms step_avg:94.12ms +step:425/1705 train_time:40189ms step_avg:94.56ms +step:426/1705 train_time:40360ms step_avg:94.74ms +step:427/1705 train_time:40451ms step_avg:94.73ms +step:428/1705 train_time:40543ms step_avg:94.73ms +step:429/1705 train_time:40635ms step_avg:94.72ms +step:430/1705 train_time:40727ms step_avg:94.71ms +step:431/1705 train_time:40818ms step_avg:94.71ms +step:432/1705 train_time:40910ms step_avg:94.70ms +step:433/1705 train_time:41002ms step_avg:94.69ms +step:434/1705 train_time:41093ms step_avg:94.69ms +step:435/1705 train_time:41188ms step_avg:94.68ms +step:436/1705 train_time:41285ms step_avg:94.69ms +step:437/1705 train_time:41382ms step_avg:94.69ms +step:438/1705 train_time:41474ms step_avg:94.69ms +step:439/1705 train_time:41568ms step_avg:94.69ms +step:440/1705 train_time:41660ms step_avg:94.68ms +step:441/1705 train_time:41753ms step_avg:94.68ms +step:442/1705 train_time:41845ms step_avg:94.67ms +step:443/1705 train_time:41937ms step_avg:94.66ms +step:444/1705 train_time:42029ms step_avg:94.66ms +step:445/1705 train_time:42121ms step_avg:94.65ms +step:446/1705 train_time:42214ms step_avg:94.65ms +step:447/1705 train_time:42309ms step_avg:94.65ms +step:448/1705 train_time:42403ms step_avg:94.65ms +step:449/1705 train_time:42496ms step_avg:94.65ms +step:450/1705 train_time:42589ms step_avg:94.64ms +step:451/1705 train_time:42683ms step_avg:94.64ms +step:452/1705 train_time:42776ms step_avg:94.64ms +step:453/1705 train_time:42868ms step_avg:94.63ms +step:454/1705 train_time:42961ms step_avg:94.63ms +step:455/1705 train_time:43052ms step_avg:94.62ms +step:456/1705 train_time:43145ms step_avg:94.62ms +step:457/1705 train_time:43238ms step_avg:94.61ms +step:458/1705 train_time:43332ms step_avg:94.61ms +step:459/1705 train_time:43426ms step_avg:94.61ms +step:460/1705 train_time:43519ms step_avg:94.61ms +step:461/1705 train_time:43613ms step_avg:94.60ms +step:462/1705 train_time:43705ms step_avg:94.60ms +step:463/1705 train_time:43798ms step_avg:94.60ms +step:464/1705 train_time:43891ms step_avg:94.59ms +step:465/1705 train_time:43983ms step_avg:94.59ms +step:466/1705 train_time:44076ms step_avg:94.58ms +step:467/1705 train_time:44169ms step_avg:94.58ms +step:468/1705 train_time:44263ms step_avg:94.58ms +step:469/1705 train_time:44356ms step_avg:94.58ms +step:470/1705 train_time:44449ms step_avg:94.57ms +step:471/1705 train_time:44542ms step_avg:94.57ms +step:472/1705 train_time:44635ms step_avg:94.57ms +step:473/1705 train_time:44728ms step_avg:94.56ms +step:474/1705 train_time:44821ms step_avg:94.56ms +step:475/1705 train_time:44913ms step_avg:94.55ms +step:476/1705 train_time:45006ms step_avg:94.55ms +step:477/1705 train_time:45099ms step_avg:94.55ms +step:478/1705 train_time:45192ms step_avg:94.54ms +step:479/1705 train_time:45286ms step_avg:94.54ms +step:480/1705 train_time:45379ms step_avg:94.54ms +step:481/1705 train_time:45473ms step_avg:94.54ms +step:482/1705 train_time:45567ms step_avg:94.54ms +step:483/1705 train_time:45659ms step_avg:94.53ms +step:484/1705 train_time:45751ms step_avg:94.53ms +step:485/1705 train_time:45845ms step_avg:94.52ms +step:486/1705 train_time:45937ms step_avg:94.52ms +step:487/1705 train_time:46031ms step_avg:94.52ms +step:488/1705 train_time:46124ms step_avg:94.52ms +step:489/1705 train_time:46217ms step_avg:94.51ms +step:490/1705 train_time:46311ms step_avg:94.51ms +step:491/1705 train_time:46405ms step_avg:94.51ms +step:492/1705 
train_time:46498ms step_avg:94.51ms +step:493/1705 train_time:46591ms step_avg:94.51ms +step:494/1705 train_time:46684ms step_avg:94.50ms +step:495/1705 train_time:46777ms step_avg:94.50ms +step:496/1705 train_time:46870ms step_avg:94.50ms +step:497/1705 train_time:46962ms step_avg:94.49ms +step:498/1705 train_time:47055ms step_avg:94.49ms +step:499/1705 train_time:47148ms step_avg:94.49ms +step:500/1705 train_time:47241ms step_avg:94.48ms +step:500/1705 val_loss:3.7190 train_time:47334ms step_avg:94.67ms +step:501/1705 train_time:47357ms step_avg:94.53ms +step:502/1705 train_time:47432ms step_avg:94.49ms +step:503/1705 train_time:47530ms step_avg:94.49ms +step:504/1705 train_time:47623ms step_avg:94.49ms +step:505/1705 train_time:47716ms step_avg:94.49ms +step:506/1705 train_time:47807ms step_avg:94.48ms +step:507/1705 train_time:47900ms step_avg:94.48ms +step:508/1705 train_time:47992ms step_avg:94.47ms +step:509/1705 train_time:48084ms step_avg:94.47ms +step:510/1705 train_time:48176ms step_avg:94.46ms +step:511/1705 train_time:48269ms step_avg:94.46ms +step:512/1705 train_time:48363ms step_avg:94.46ms +step:513/1705 train_time:48459ms step_avg:94.46ms +step:514/1705 train_time:48553ms step_avg:94.46ms +step:515/1705 train_time:48646ms step_avg:94.46ms +step:516/1705 train_time:48739ms step_avg:94.45ms +step:517/1705 train_time:48831ms step_avg:94.45ms +step:518/1705 train_time:48924ms step_avg:94.45ms +step:519/1705 train_time:49016ms step_avg:94.44ms +step:520/1705 train_time:49107ms step_avg:94.44ms +step:521/1705 train_time:49201ms step_avg:94.44ms +step:522/1705 train_time:49295ms step_avg:94.43ms +step:523/1705 train_time:49388ms step_avg:94.43ms +step:524/1705 train_time:49482ms step_avg:94.43ms +step:525/1705 train_time:49576ms step_avg:94.43ms +step:526/1705 train_time:49670ms step_avg:94.43ms +step:527/1705 train_time:49763ms step_avg:94.43ms +step:528/1705 train_time:49855ms step_avg:94.42ms +step:529/1705 train_time:49947ms step_avg:94.42ms +step:530/1705 train_time:50040ms step_avg:94.41ms +step:531/1705 train_time:50132ms step_avg:94.41ms +step:532/1705 train_time:50225ms step_avg:94.41ms +step:533/1705 train_time:50318ms step_avg:94.40ms +step:534/1705 train_time:50411ms step_avg:94.40ms +step:535/1705 train_time:50504ms step_avg:94.40ms +step:536/1705 train_time:50598ms step_avg:94.40ms +step:537/1705 train_time:50691ms step_avg:94.40ms +step:538/1705 train_time:50784ms step_avg:94.39ms +step:539/1705 train_time:50877ms step_avg:94.39ms +step:540/1705 train_time:50970ms step_avg:94.39ms +step:541/1705 train_time:51062ms step_avg:94.38ms +step:542/1705 train_time:51155ms step_avg:94.38ms +step:543/1705 train_time:51248ms step_avg:94.38ms +step:544/1705 train_time:51341ms step_avg:94.38ms +step:545/1705 train_time:51434ms step_avg:94.37ms +step:546/1705 train_time:51527ms step_avg:94.37ms +step:547/1705 train_time:51621ms step_avg:94.37ms +step:548/1705 train_time:51715ms step_avg:94.37ms +step:549/1705 train_time:51808ms step_avg:94.37ms +step:550/1705 train_time:51901ms step_avg:94.37ms +step:551/1705 train_time:51993ms step_avg:94.36ms +step:552/1705 train_time:52086ms step_avg:94.36ms +step:553/1705 train_time:52179ms step_avg:94.36ms +step:554/1705 train_time:52272ms step_avg:94.35ms +step:555/1705 train_time:52365ms step_avg:94.35ms +step:556/1705 train_time:52458ms step_avg:94.35ms +step:557/1705 train_time:52551ms step_avg:94.35ms +step:558/1705 train_time:52644ms step_avg:94.34ms +step:559/1705 train_time:52737ms step_avg:94.34ms +step:560/1705 train_time:52830ms 
step_avg:94.34ms +step:561/1705 train_time:52923ms step_avg:94.34ms +step:562/1705 train_time:53016ms step_avg:94.33ms +step:563/1705 train_time:53108ms step_avg:94.33ms +step:564/1705 train_time:53202ms step_avg:94.33ms +step:565/1705 train_time:53295ms step_avg:94.33ms +step:566/1705 train_time:53387ms step_avg:94.32ms +step:567/1705 train_time:53480ms step_avg:94.32ms +step:568/1705 train_time:53573ms step_avg:94.32ms +step:569/1705 train_time:53666ms step_avg:94.32ms +step:570/1705 train_time:53759ms step_avg:94.31ms +step:571/1705 train_time:53854ms step_avg:94.32ms +step:572/1705 train_time:53949ms step_avg:94.32ms +step:573/1705 train_time:54043ms step_avg:94.32ms +step:574/1705 train_time:54137ms step_avg:94.32ms +step:575/1705 train_time:54231ms step_avg:94.32ms +step:576/1705 train_time:54326ms step_avg:94.32ms +step:577/1705 train_time:54421ms step_avg:94.32ms +step:578/1705 train_time:54514ms step_avg:94.31ms +step:579/1705 train_time:54607ms step_avg:94.31ms +step:580/1705 train_time:54702ms step_avg:94.31ms +step:581/1705 train_time:54796ms step_avg:94.31ms +step:582/1705 train_time:54889ms step_avg:94.31ms +step:583/1705 train_time:54983ms step_avg:94.31ms +step:584/1705 train_time:55078ms step_avg:94.31ms +step:585/1705 train_time:55172ms step_avg:94.31ms +step:586/1705 train_time:55266ms step_avg:94.31ms +step:587/1705 train_time:55361ms step_avg:94.31ms +step:588/1705 train_time:55455ms step_avg:94.31ms +step:589/1705 train_time:55549ms step_avg:94.31ms +step:590/1705 train_time:55643ms step_avg:94.31ms +step:591/1705 train_time:55738ms step_avg:94.31ms +step:592/1705 train_time:55832ms step_avg:94.31ms +step:593/1705 train_time:55926ms step_avg:94.31ms +step:594/1705 train_time:56020ms step_avg:94.31ms +step:595/1705 train_time:56115ms step_avg:94.31ms +step:596/1705 train_time:56209ms step_avg:94.31ms +step:597/1705 train_time:56303ms step_avg:94.31ms +step:598/1705 train_time:56398ms step_avg:94.31ms +step:599/1705 train_time:56491ms step_avg:94.31ms +step:600/1705 train_time:56586ms step_avg:94.31ms +step:601/1705 train_time:56680ms step_avg:94.31ms +step:602/1705 train_time:56774ms step_avg:94.31ms +step:603/1705 train_time:56869ms step_avg:94.31ms +step:604/1705 train_time:56964ms step_avg:94.31ms +step:605/1705 train_time:57058ms step_avg:94.31ms +step:606/1705 train_time:57152ms step_avg:94.31ms +step:607/1705 train_time:57246ms step_avg:94.31ms +step:608/1705 train_time:57341ms step_avg:94.31ms +step:609/1705 train_time:57435ms step_avg:94.31ms +step:610/1705 train_time:57529ms step_avg:94.31ms +step:611/1705 train_time:57624ms step_avg:94.31ms +step:612/1705 train_time:57718ms step_avg:94.31ms +step:613/1705 train_time:57812ms step_avg:94.31ms +step:614/1705 train_time:57906ms step_avg:94.31ms +step:615/1705 train_time:58001ms step_avg:94.31ms +step:616/1705 train_time:58095ms step_avg:94.31ms +step:617/1705 train_time:58189ms step_avg:94.31ms +step:618/1705 train_time:58283ms step_avg:94.31ms +step:619/1705 train_time:58378ms step_avg:94.31ms +step:620/1705 train_time:58472ms step_avg:94.31ms +step:621/1705 train_time:58566ms step_avg:94.31ms +step:622/1705 train_time:58660ms step_avg:94.31ms +step:623/1705 train_time:58755ms step_avg:94.31ms +step:624/1705 train_time:58849ms step_avg:94.31ms +step:625/1705 train_time:58944ms step_avg:94.31ms +step:625/1705 val_loss:3.6190 train_time:59040ms step_avg:94.46ms +step:626/1705 train_time:59063ms step_avg:94.35ms +step:627/1705 train_time:59137ms step_avg:94.32ms +step:628/1705 train_time:59235ms step_avg:94.32ms 
+step:629/1705 train_time:59341ms step_avg:94.34ms +step:630/1705 train_time:59438ms step_avg:94.35ms +step:631/1705 train_time:59531ms step_avg:94.34ms +step:632/1705 train_time:59624ms step_avg:94.34ms +step:633/1705 train_time:59717ms step_avg:94.34ms +step:634/1705 train_time:59811ms step_avg:94.34ms +step:635/1705 train_time:59903ms step_avg:94.34ms +step:636/1705 train_time:59998ms step_avg:94.34ms +step:637/1705 train_time:60093ms step_avg:94.34ms +step:638/1705 train_time:60188ms step_avg:94.34ms +step:639/1705 train_time:60561ms step_avg:94.77ms +step:640/1705 train_time:60636ms step_avg:94.74ms +step:641/1705 train_time:60729ms step_avg:94.74ms +step:642/1705 train_time:60821ms step_avg:94.74ms +step:643/1705 train_time:60915ms step_avg:94.74ms +step:644/1705 train_time:61008ms step_avg:94.73ms +step:645/1705 train_time:61101ms step_avg:94.73ms +step:646/1705 train_time:61194ms step_avg:94.73ms +step:647/1705 train_time:61287ms step_avg:94.72ms +step:648/1705 train_time:61381ms step_avg:94.72ms +step:649/1705 train_time:61478ms step_avg:94.73ms +step:650/1705 train_time:61576ms step_avg:94.73ms +step:651/1705 train_time:61672ms step_avg:94.73ms +step:652/1705 train_time:61766ms step_avg:94.73ms +step:653/1705 train_time:61860ms step_avg:94.73ms +step:654/1705 train_time:61955ms step_avg:94.73ms +step:655/1705 train_time:62048ms step_avg:94.73ms +step:656/1705 train_time:62142ms step_avg:94.73ms +step:657/1705 train_time:62235ms step_avg:94.73ms +step:658/1705 train_time:62329ms step_avg:94.73ms +step:659/1705 train_time:62425ms step_avg:94.73ms +step:660/1705 train_time:62519ms step_avg:94.73ms +step:661/1705 train_time:62615ms step_avg:94.73ms +step:662/1705 train_time:62712ms step_avg:94.73ms +step:663/1705 train_time:62806ms step_avg:94.73ms +step:664/1705 train_time:62900ms step_avg:94.73ms +step:665/1705 train_time:62994ms step_avg:94.73ms +step:666/1705 train_time:63088ms step_avg:94.73ms +step:667/1705 train_time:63182ms step_avg:94.73ms +step:668/1705 train_time:63276ms step_avg:94.72ms +step:669/1705 train_time:63370ms step_avg:94.72ms +step:670/1705 train_time:63463ms step_avg:94.72ms +step:671/1705 train_time:63559ms step_avg:94.72ms +step:672/1705 train_time:63654ms step_avg:94.72ms +step:673/1705 train_time:63750ms step_avg:94.72ms +step:674/1705 train_time:63843ms step_avg:94.72ms +step:675/1705 train_time:63938ms step_avg:94.72ms +step:676/1705 train_time:64032ms step_avg:94.72ms +step:677/1705 train_time:64125ms step_avg:94.72ms +step:678/1705 train_time:64219ms step_avg:94.72ms +step:679/1705 train_time:64313ms step_avg:94.72ms +step:680/1705 train_time:64407ms step_avg:94.72ms +step:681/1705 train_time:64501ms step_avg:94.71ms +step:682/1705 train_time:64595ms step_avg:94.71ms +step:683/1705 train_time:64691ms step_avg:94.72ms +step:684/1705 train_time:64786ms step_avg:94.72ms +step:685/1705 train_time:64880ms step_avg:94.72ms +step:686/1705 train_time:64976ms step_avg:94.72ms +step:687/1705 train_time:65069ms step_avg:94.72ms +step:688/1705 train_time:65163ms step_avg:94.71ms +step:689/1705 train_time:65257ms step_avg:94.71ms +step:690/1705 train_time:65352ms step_avg:94.71ms +step:691/1705 train_time:65446ms step_avg:94.71ms +step:692/1705 train_time:65540ms step_avg:94.71ms +step:693/1705 train_time:65635ms step_avg:94.71ms +step:694/1705 train_time:65729ms step_avg:94.71ms +step:695/1705 train_time:65824ms step_avg:94.71ms +step:696/1705 train_time:65918ms step_avg:94.71ms +step:697/1705 train_time:66013ms step_avg:94.71ms +step:698/1705 train_time:66107ms 
step_avg:94.71ms +step:699/1705 train_time:66201ms step_avg:94.71ms +step:700/1705 train_time:66296ms step_avg:94.71ms +step:701/1705 train_time:66390ms step_avg:94.71ms +step:702/1705 train_time:66484ms step_avg:94.71ms +step:703/1705 train_time:66578ms step_avg:94.71ms +step:704/1705 train_time:66673ms step_avg:94.71ms +step:705/1705 train_time:66767ms step_avg:94.71ms +step:706/1705 train_time:66861ms step_avg:94.70ms +step:707/1705 train_time:66956ms step_avg:94.70ms +step:708/1705 train_time:67051ms step_avg:94.70ms +step:709/1705 train_time:67145ms step_avg:94.70ms +step:710/1705 train_time:67238ms step_avg:94.70ms +step:711/1705 train_time:67334ms step_avg:94.70ms +step:712/1705 train_time:67429ms step_avg:94.70ms +step:713/1705 train_time:67523ms step_avg:94.70ms +step:714/1705 train_time:67617ms step_avg:94.70ms +step:715/1705 train_time:67712ms step_avg:94.70ms +step:716/1705 train_time:67806ms step_avg:94.70ms +step:717/1705 train_time:67900ms step_avg:94.70ms +step:718/1705 train_time:67994ms step_avg:94.70ms +step:719/1705 train_time:68088ms step_avg:94.70ms +step:720/1705 train_time:68182ms step_avg:94.70ms +step:721/1705 train_time:68276ms step_avg:94.70ms +step:722/1705 train_time:68372ms step_avg:94.70ms +step:723/1705 train_time:68466ms step_avg:94.70ms +step:724/1705 train_time:68560ms step_avg:94.70ms +step:725/1705 train_time:68655ms step_avg:94.70ms +step:726/1705 train_time:68751ms step_avg:94.70ms +step:727/1705 train_time:68845ms step_avg:94.70ms +step:728/1705 train_time:68939ms step_avg:94.70ms +step:729/1705 train_time:69034ms step_avg:94.70ms +step:730/1705 train_time:69128ms step_avg:94.70ms +step:731/1705 train_time:69221ms step_avg:94.69ms +step:732/1705 train_time:69316ms step_avg:94.69ms +step:733/1705 train_time:69410ms step_avg:94.69ms +step:734/1705 train_time:69505ms step_avg:94.69ms +step:735/1705 train_time:69599ms step_avg:94.69ms +step:736/1705 train_time:69694ms step_avg:94.69ms +step:737/1705 train_time:69789ms step_avg:94.69ms +step:738/1705 train_time:69884ms step_avg:94.69ms +step:739/1705 train_time:69978ms step_avg:94.69ms +step:740/1705 train_time:70073ms step_avg:94.69ms +step:741/1705 train_time:70167ms step_avg:94.69ms +step:742/1705 train_time:70261ms step_avg:94.69ms +step:743/1705 train_time:70356ms step_avg:94.69ms +step:744/1705 train_time:70451ms step_avg:94.69ms +step:745/1705 train_time:70546ms step_avg:94.69ms +step:746/1705 train_time:70640ms step_avg:94.69ms +step:747/1705 train_time:70736ms step_avg:94.69ms +step:748/1705 train_time:70830ms step_avg:94.69ms +step:749/1705 train_time:70924ms step_avg:94.69ms +step:750/1705 train_time:71018ms step_avg:94.69ms +step:750/1705 val_loss:3.5638 train_time:71114ms step_avg:94.82ms +step:751/1705 train_time:71137ms step_avg:94.72ms +step:752/1705 train_time:71213ms step_avg:94.70ms +step:753/1705 train_time:71310ms step_avg:94.70ms +step:754/1705 train_time:71404ms step_avg:94.70ms +step:755/1705 train_time:71498ms step_avg:94.70ms +step:756/1705 train_time:71591ms step_avg:94.70ms +step:757/1705 train_time:71684ms step_avg:94.70ms +step:758/1705 train_time:71778ms step_avg:94.69ms +step:759/1705 train_time:71871ms step_avg:94.69ms +step:760/1705 train_time:71964ms step_avg:94.69ms +step:761/1705 train_time:72059ms step_avg:94.69ms +step:762/1705 train_time:72154ms step_avg:94.69ms +step:763/1705 train_time:72249ms step_avg:94.69ms +step:764/1705 train_time:72344ms step_avg:94.69ms +step:765/1705 train_time:72438ms step_avg:94.69ms +step:766/1705 train_time:72533ms step_avg:94.69ms 
+step:767/1705 train_time:72626ms step_avg:94.69ms +step:768/1705 train_time:72720ms step_avg:94.69ms +step:769/1705 train_time:72815ms step_avg:94.69ms +step:770/1705 train_time:72908ms step_avg:94.69ms +step:771/1705 train_time:73001ms step_avg:94.68ms +step:772/1705 train_time:73097ms step_avg:94.69ms +step:773/1705 train_time:73192ms step_avg:94.69ms +step:774/1705 train_time:73286ms step_avg:94.68ms +step:775/1705 train_time:73381ms step_avg:94.69ms +step:776/1705 train_time:73475ms step_avg:94.68ms +step:777/1705 train_time:73570ms step_avg:94.68ms +step:778/1705 train_time:73663ms step_avg:94.68ms +step:779/1705 train_time:73758ms step_avg:94.68ms +step:780/1705 train_time:73851ms step_avg:94.68ms +step:781/1705 train_time:73944ms step_avg:94.68ms +step:782/1705 train_time:74038ms step_avg:94.68ms +step:783/1705 train_time:74133ms step_avg:94.68ms +step:784/1705 train_time:74227ms step_avg:94.68ms +step:785/1705 train_time:74322ms step_avg:94.68ms +step:786/1705 train_time:74416ms step_avg:94.68ms +step:787/1705 train_time:74511ms step_avg:94.68ms +step:788/1705 train_time:74604ms step_avg:94.68ms +step:789/1705 train_time:74699ms step_avg:94.68ms +step:790/1705 train_time:74793ms step_avg:94.68ms +step:791/1705 train_time:74887ms step_avg:94.67ms +step:792/1705 train_time:74980ms step_avg:94.67ms +step:793/1705 train_time:75076ms step_avg:94.67ms +step:794/1705 train_time:75169ms step_avg:94.67ms +step:795/1705 train_time:75264ms step_avg:94.67ms +step:796/1705 train_time:75359ms step_avg:94.67ms +step:797/1705 train_time:75453ms step_avg:94.67ms +step:798/1705 train_time:75548ms step_avg:94.67ms +step:799/1705 train_time:75642ms step_avg:94.67ms +step:800/1705 train_time:75736ms step_avg:94.67ms +step:801/1705 train_time:75830ms step_avg:94.67ms +step:802/1705 train_time:75924ms step_avg:94.67ms +step:803/1705 train_time:76019ms step_avg:94.67ms +step:804/1705 train_time:76113ms step_avg:94.67ms +step:805/1705 train_time:76206ms step_avg:94.67ms +step:806/1705 train_time:76300ms step_avg:94.67ms +step:807/1705 train_time:76395ms step_avg:94.67ms +step:808/1705 train_time:76489ms step_avg:94.67ms +step:809/1705 train_time:76584ms step_avg:94.67ms +step:810/1705 train_time:76678ms step_avg:94.66ms +step:811/1705 train_time:76772ms step_avg:94.66ms +step:812/1705 train_time:76867ms step_avg:94.66ms +step:813/1705 train_time:76961ms step_avg:94.66ms +step:814/1705 train_time:77056ms step_avg:94.66ms +step:815/1705 train_time:77150ms step_avg:94.66ms +step:816/1705 train_time:77244ms step_avg:94.66ms +step:817/1705 train_time:77339ms step_avg:94.66ms +step:818/1705 train_time:77434ms step_avg:94.66ms +step:819/1705 train_time:77528ms step_avg:94.66ms +step:820/1705 train_time:77622ms step_avg:94.66ms +step:821/1705 train_time:77717ms step_avg:94.66ms +step:822/1705 train_time:77811ms step_avg:94.66ms +step:823/1705 train_time:77905ms step_avg:94.66ms +step:824/1705 train_time:77999ms step_avg:94.66ms +step:825/1705 train_time:78094ms step_avg:94.66ms +step:826/1705 train_time:78187ms step_avg:94.66ms +step:827/1705 train_time:78281ms step_avg:94.66ms +step:828/1705 train_time:78376ms step_avg:94.66ms +step:829/1705 train_time:78471ms step_avg:94.66ms +step:830/1705 train_time:78564ms step_avg:94.66ms +step:831/1705 train_time:78658ms step_avg:94.66ms +step:832/1705 train_time:78753ms step_avg:94.65ms +step:833/1705 train_time:78847ms step_avg:94.65ms +step:834/1705 train_time:78941ms step_avg:94.65ms +step:835/1705 train_time:79035ms step_avg:94.65ms +step:836/1705 train_time:79130ms 
step_avg:94.65ms +step:837/1705 train_time:79225ms step_avg:94.65ms +step:838/1705 train_time:79319ms step_avg:94.65ms +step:839/1705 train_time:79414ms step_avg:94.65ms +step:840/1705 train_time:79508ms step_avg:94.65ms +step:841/1705 train_time:79602ms step_avg:94.65ms +step:842/1705 train_time:79697ms step_avg:94.65ms +step:843/1705 train_time:79791ms step_avg:94.65ms +step:844/1705 train_time:79885ms step_avg:94.65ms +step:845/1705 train_time:79979ms step_avg:94.65ms +step:846/1705 train_time:80075ms step_avg:94.65ms +step:847/1705 train_time:80169ms step_avg:94.65ms +step:848/1705 train_time:80263ms step_avg:94.65ms +step:849/1705 train_time:80357ms step_avg:94.65ms +step:850/1705 train_time:80451ms step_avg:94.65ms +step:851/1705 train_time:80701ms step_avg:94.83ms +step:852/1705 train_time:80860ms step_avg:94.91ms +step:853/1705 train_time:80952ms step_avg:94.90ms +step:854/1705 train_time:81045ms step_avg:94.90ms +step:855/1705 train_time:81138ms step_avg:94.90ms +step:856/1705 train_time:81231ms step_avg:94.90ms +step:857/1705 train_time:81325ms step_avg:94.89ms +step:858/1705 train_time:81418ms step_avg:94.89ms +step:859/1705 train_time:81511ms step_avg:94.89ms +step:860/1705 train_time:81604ms step_avg:94.89ms +step:861/1705 train_time:81699ms step_avg:94.89ms +step:862/1705 train_time:81798ms step_avg:94.89ms +step:863/1705 train_time:81896ms step_avg:94.90ms +step:864/1705 train_time:81992ms step_avg:94.90ms +step:865/1705 train_time:82085ms step_avg:94.90ms +step:866/1705 train_time:82179ms step_avg:94.89ms +step:867/1705 train_time:82272ms step_avg:94.89ms +step:868/1705 train_time:82365ms step_avg:94.89ms +step:869/1705 train_time:82459ms step_avg:94.89ms +step:870/1705 train_time:82553ms step_avg:94.89ms +step:871/1705 train_time:82646ms step_avg:94.89ms +step:872/1705 train_time:82743ms step_avg:94.89ms +step:873/1705 train_time:82839ms step_avg:94.89ms +step:874/1705 train_time:82934ms step_avg:94.89ms +step:875/1705 train_time:83029ms step_avg:94.89ms +step:875/1705 val_loss:3.5229 train_time:83123ms step_avg:95.00ms +step:876/1705 train_time:83145ms step_avg:94.91ms +step:877/1705 train_time:83224ms step_avg:94.90ms +step:878/1705 train_time:83324ms step_avg:94.90ms +step:879/1705 train_time:83419ms step_avg:94.90ms +step:880/1705 train_time:83512ms step_avg:94.90ms +step:881/1705 train_time:83605ms step_avg:94.90ms +step:882/1705 train_time:83698ms step_avg:94.90ms +step:883/1705 train_time:83791ms step_avg:94.89ms +step:884/1705 train_time:83885ms step_avg:94.89ms +step:885/1705 train_time:83978ms step_avg:94.89ms +step:886/1705 train_time:84072ms step_avg:94.89ms +step:887/1705 train_time:84168ms step_avg:94.89ms +step:888/1705 train_time:84265ms step_avg:94.89ms +step:889/1705 train_time:84363ms step_avg:94.90ms +step:890/1705 train_time:84458ms step_avg:94.90ms +step:891/1705 train_time:84552ms step_avg:94.90ms +step:892/1705 train_time:84647ms step_avg:94.90ms +step:893/1705 train_time:84740ms step_avg:94.89ms +step:894/1705 train_time:84833ms step_avg:94.89ms +step:895/1705 train_time:84927ms step_avg:94.89ms +step:896/1705 train_time:85021ms step_avg:94.89ms +step:897/1705 train_time:85115ms step_avg:94.89ms +step:898/1705 train_time:85211ms step_avg:94.89ms +step:899/1705 train_time:85307ms step_avg:94.89ms +step:900/1705 train_time:85403ms step_avg:94.89ms +step:901/1705 train_time:85497ms step_avg:94.89ms +step:902/1705 train_time:85592ms step_avg:94.89ms +step:903/1705 train_time:85686ms step_avg:94.89ms +step:904/1705 train_time:85779ms step_avg:94.89ms 
+step:905/1705 train_time:85872ms step_avg:94.89ms +step:906/1705 train_time:85967ms step_avg:94.89ms +step:907/1705 train_time:86061ms step_avg:94.89ms +step:908/1705 train_time:86155ms step_avg:94.88ms +step:909/1705 train_time:86250ms step_avg:94.88ms +step:910/1705 train_time:86345ms step_avg:94.88ms +step:911/1705 train_time:86441ms step_avg:94.89ms +step:912/1705 train_time:86536ms step_avg:94.89ms +step:913/1705 train_time:86631ms step_avg:94.89ms +step:914/1705 train_time:86725ms step_avg:94.89ms +step:915/1705 train_time:86819ms step_avg:94.88ms +step:916/1705 train_time:86913ms step_avg:94.88ms +step:917/1705 train_time:87007ms step_avg:94.88ms +step:918/1705 train_time:87102ms step_avg:94.88ms +step:919/1705 train_time:87197ms step_avg:94.88ms +step:920/1705 train_time:87291ms step_avg:94.88ms +step:921/1705 train_time:87386ms step_avg:94.88ms +step:922/1705 train_time:87482ms step_avg:94.88ms +step:923/1705 train_time:87576ms step_avg:94.88ms +step:924/1705 train_time:87670ms step_avg:94.88ms +step:925/1705 train_time:87765ms step_avg:94.88ms +step:926/1705 train_time:87859ms step_avg:94.88ms +step:927/1705 train_time:87952ms step_avg:94.88ms +step:928/1705 train_time:88047ms step_avg:94.88ms +step:929/1705 train_time:88142ms step_avg:94.88ms +step:930/1705 train_time:88236ms step_avg:94.88ms +step:931/1705 train_time:88331ms step_avg:94.88ms +step:932/1705 train_time:88426ms step_avg:94.88ms +step:933/1705 train_time:88522ms step_avg:94.88ms +step:934/1705 train_time:88616ms step_avg:94.88ms +step:935/1705 train_time:88710ms step_avg:94.88ms +step:936/1705 train_time:88805ms step_avg:94.88ms +step:937/1705 train_time:88899ms step_avg:94.88ms +step:938/1705 train_time:88993ms step_avg:94.87ms +step:939/1705 train_time:89088ms step_avg:94.88ms +step:940/1705 train_time:89182ms step_avg:94.87ms +step:941/1705 train_time:89277ms step_avg:94.87ms +step:942/1705 train_time:89372ms step_avg:94.87ms +step:943/1705 train_time:89467ms step_avg:94.87ms +step:944/1705 train_time:89561ms step_avg:94.87ms +step:945/1705 train_time:89655ms step_avg:94.87ms +step:946/1705 train_time:89750ms step_avg:94.87ms +step:947/1705 train_time:89845ms step_avg:94.87ms +step:948/1705 train_time:89940ms step_avg:94.87ms +step:949/1705 train_time:90034ms step_avg:94.87ms +step:950/1705 train_time:90128ms step_avg:94.87ms +step:951/1705 train_time:90223ms step_avg:94.87ms +step:952/1705 train_time:90318ms step_avg:94.87ms +step:953/1705 train_time:90412ms step_avg:94.87ms +step:954/1705 train_time:90506ms step_avg:94.87ms +step:955/1705 train_time:90601ms step_avg:94.87ms +step:956/1705 train_time:90696ms step_avg:94.87ms +step:957/1705 train_time:90790ms step_avg:94.87ms +step:958/1705 train_time:90885ms step_avg:94.87ms +step:959/1705 train_time:90979ms step_avg:94.87ms +step:960/1705 train_time:91073ms step_avg:94.87ms +step:961/1705 train_time:91168ms step_avg:94.87ms +step:962/1705 train_time:91263ms step_avg:94.87ms +step:963/1705 train_time:91358ms step_avg:94.87ms +step:964/1705 train_time:91453ms step_avg:94.87ms +step:965/1705 train_time:91548ms step_avg:94.87ms +step:966/1705 train_time:91643ms step_avg:94.87ms +step:967/1705 train_time:91737ms step_avg:94.87ms +step:968/1705 train_time:91831ms step_avg:94.87ms +step:969/1705 train_time:91926ms step_avg:94.87ms +step:970/1705 train_time:92020ms step_avg:94.87ms +step:971/1705 train_time:92114ms step_avg:94.87ms +step:972/1705 train_time:92209ms step_avg:94.86ms +step:973/1705 train_time:92303ms step_avg:94.86ms +step:974/1705 train_time:92399ms 
step_avg:94.87ms +step:975/1705 train_time:92493ms step_avg:94.86ms +step:976/1705 train_time:92587ms step_avg:94.86ms +step:977/1705 train_time:92682ms step_avg:94.86ms +step:978/1705 train_time:92776ms step_avg:94.86ms +step:979/1705 train_time:92871ms step_avg:94.86ms +step:980/1705 train_time:92965ms step_avg:94.86ms +step:981/1705 train_time:93060ms step_avg:94.86ms +step:982/1705 train_time:93154ms step_avg:94.86ms +step:983/1705 train_time:93248ms step_avg:94.86ms +step:984/1705 train_time:93343ms step_avg:94.86ms +step:985/1705 train_time:93437ms step_avg:94.86ms +step:986/1705 train_time:93532ms step_avg:94.86ms +step:987/1705 train_time:93626ms step_avg:94.86ms +step:988/1705 train_time:93720ms step_avg:94.86ms +step:989/1705 train_time:93814ms step_avg:94.86ms +step:990/1705 train_time:93909ms step_avg:94.86ms +step:991/1705 train_time:94003ms step_avg:94.86ms +step:992/1705 train_time:94098ms step_avg:94.86ms +step:993/1705 train_time:94192ms step_avg:94.86ms +step:994/1705 train_time:94287ms step_avg:94.86ms +step:995/1705 train_time:94381ms step_avg:94.86ms +step:996/1705 train_time:94475ms step_avg:94.85ms +step:997/1705 train_time:94570ms step_avg:94.85ms +step:998/1705 train_time:94665ms step_avg:94.85ms +step:999/1705 train_time:94759ms step_avg:94.85ms +step:1000/1705 train_time:94853ms step_avg:94.85ms +step:1000/1705 val_loss:3.4830 train_time:94948ms step_avg:94.95ms +step:1001/1705 train_time:94970ms step_avg:94.87ms +step:1002/1705 train_time:95049ms step_avg:94.86ms +step:1003/1705 train_time:95148ms step_avg:94.86ms +step:1004/1705 train_time:95244ms step_avg:94.86ms +step:1005/1705 train_time:95337ms step_avg:94.86ms +step:1006/1705 train_time:95431ms step_avg:94.86ms +step:1007/1705 train_time:95524ms step_avg:94.86ms +step:1008/1705 train_time:95617ms step_avg:94.86ms +step:1009/1705 train_time:95711ms step_avg:94.86ms +step:1010/1705 train_time:95804ms step_avg:94.85ms +step:1011/1705 train_time:95898ms step_avg:94.85ms +step:1012/1705 train_time:95993ms step_avg:94.86ms +step:1013/1705 train_time:96091ms step_avg:94.86ms +step:1014/1705 train_time:96187ms step_avg:94.86ms +step:1015/1705 train_time:96282ms step_avg:94.86ms +step:1016/1705 train_time:96376ms step_avg:94.86ms +step:1017/1705 train_time:96469ms step_avg:94.86ms +step:1018/1705 train_time:96563ms step_avg:94.86ms +step:1019/1705 train_time:96657ms step_avg:94.85ms +step:1020/1705 train_time:96750ms step_avg:94.85ms +step:1021/1705 train_time:96844ms step_avg:94.85ms +step:1022/1705 train_time:96938ms step_avg:94.85ms +step:1023/1705 train_time:97034ms step_avg:94.85ms +step:1024/1705 train_time:97130ms step_avg:94.85ms +step:1025/1705 train_time:97225ms step_avg:94.85ms +step:1026/1705 train_time:97320ms step_avg:94.85ms +step:1027/1705 train_time:97415ms step_avg:94.85ms +step:1028/1705 train_time:97508ms step_avg:94.85ms +step:1029/1705 train_time:97603ms step_avg:94.85ms +step:1030/1705 train_time:97696ms step_avg:94.85ms +step:1031/1705 train_time:97790ms step_avg:94.85ms +step:1032/1705 train_time:97885ms step_avg:94.85ms +step:1033/1705 train_time:97980ms step_avg:94.85ms +step:1034/1705 train_time:98075ms step_avg:94.85ms +step:1035/1705 train_time:98170ms step_avg:94.85ms +step:1036/1705 train_time:98265ms step_avg:94.85ms +step:1037/1705 train_time:98361ms step_avg:94.85ms +step:1038/1705 train_time:98454ms step_avg:94.85ms +step:1039/1705 train_time:98548ms step_avg:94.85ms +step:1040/1705 train_time:98643ms step_avg:94.85ms +step:1041/1705 train_time:98738ms step_avg:94.85ms 
+step:1042/1705 train_time:98831ms step_avg:94.85ms +step:1043/1705 train_time:98926ms step_avg:94.85ms +step:1044/1705 train_time:99022ms step_avg:94.85ms +step:1045/1705 train_time:99116ms step_avg:94.85ms +step:1046/1705 train_time:99212ms step_avg:94.85ms +step:1047/1705 train_time:99307ms step_avg:94.85ms +step:1048/1705 train_time:99402ms step_avg:94.85ms +step:1049/1705 train_time:99496ms step_avg:94.85ms +step:1050/1705 train_time:99590ms step_avg:94.85ms +step:1051/1705 train_time:99685ms step_avg:94.85ms +step:1052/1705 train_time:99780ms step_avg:94.85ms +step:1053/1705 train_time:99875ms step_avg:94.85ms +step:1054/1705 train_time:99969ms step_avg:94.85ms +step:1055/1705 train_time:100064ms step_avg:94.85ms +step:1056/1705 train_time:100159ms step_avg:94.85ms +step:1057/1705 train_time:100253ms step_avg:94.85ms +step:1058/1705 train_time:100349ms step_avg:94.85ms +step:1059/1705 train_time:100444ms step_avg:94.85ms +step:1060/1705 train_time:100538ms step_avg:94.85ms +step:1061/1705 train_time:100632ms step_avg:94.85ms +step:1062/1705 train_time:100975ms step_avg:95.08ms +step:1063/1705 train_time:101067ms step_avg:95.08ms +step:1064/1705 train_time:101161ms step_avg:95.08ms +step:1065/1705 train_time:101254ms step_avg:95.07ms +step:1066/1705 train_time:101348ms step_avg:95.07ms +step:1067/1705 train_time:101441ms step_avg:95.07ms +step:1068/1705 train_time:101534ms step_avg:95.07ms +step:1069/1705 train_time:101628ms step_avg:95.07ms +step:1070/1705 train_time:101721ms step_avg:95.07ms +step:1071/1705 train_time:101814ms step_avg:95.06ms +step:1072/1705 train_time:101911ms step_avg:95.07ms +step:1073/1705 train_time:102010ms step_avg:95.07ms +step:1074/1705 train_time:102109ms step_avg:95.07ms +step:1075/1705 train_time:102205ms step_avg:95.07ms +step:1076/1705 train_time:102299ms step_avg:95.07ms +step:1077/1705 train_time:102392ms step_avg:95.07ms +step:1078/1705 train_time:102486ms step_avg:95.07ms +step:1079/1705 train_time:102579ms step_avg:95.07ms +step:1080/1705 train_time:102673ms step_avg:95.07ms +step:1081/1705 train_time:102766ms step_avg:95.07ms +step:1082/1705 train_time:102862ms step_avg:95.07ms +step:1083/1705 train_time:102957ms step_avg:95.07ms +step:1084/1705 train_time:103053ms step_avg:95.07ms +step:1085/1705 train_time:103149ms step_avg:95.07ms +step:1086/1705 train_time:103245ms step_avg:95.07ms +step:1087/1705 train_time:103339ms step_avg:95.07ms +step:1088/1705 train_time:103432ms step_avg:95.07ms +step:1089/1705 train_time:103526ms step_avg:95.07ms +step:1090/1705 train_time:103620ms step_avg:95.06ms +step:1091/1705 train_time:103713ms step_avg:95.06ms +step:1092/1705 train_time:103808ms step_avg:95.06ms +step:1093/1705 train_time:103903ms step_avg:95.06ms +step:1094/1705 train_time:103998ms step_avg:95.06ms +step:1095/1705 train_time:104093ms step_avg:95.06ms +step:1096/1705 train_time:104187ms step_avg:95.06ms +step:1097/1705 train_time:104283ms step_avg:95.06ms +step:1098/1705 train_time:104378ms step_avg:95.06ms +step:1099/1705 train_time:104472ms step_avg:95.06ms +step:1100/1705 train_time:104565ms step_avg:95.06ms +step:1101/1705 train_time:104659ms step_avg:95.06ms +step:1102/1705 train_time:104753ms step_avg:95.06ms +step:1103/1705 train_time:104848ms step_avg:95.06ms +step:1104/1705 train_time:104942ms step_avg:95.06ms +step:1105/1705 train_time:105037ms step_avg:95.06ms +step:1106/1705 train_time:105132ms step_avg:95.06ms +step:1107/1705 train_time:105227ms step_avg:95.06ms +step:1108/1705 train_time:105322ms step_avg:95.06ms +step:1109/1705 
train_time:105416ms step_avg:95.06ms +step:1110/1705 train_time:105511ms step_avg:95.05ms +step:1111/1705 train_time:105605ms step_avg:95.05ms +step:1112/1705 train_time:105698ms step_avg:95.05ms +step:1113/1705 train_time:105792ms step_avg:95.05ms +step:1114/1705 train_time:105887ms step_avg:95.05ms +step:1115/1705 train_time:105982ms step_avg:95.05ms +step:1116/1705 train_time:106076ms step_avg:95.05ms +step:1117/1705 train_time:106171ms step_avg:95.05ms +step:1118/1705 train_time:106267ms step_avg:95.05ms +step:1119/1705 train_time:106361ms step_avg:95.05ms +step:1120/1705 train_time:106455ms step_avg:95.05ms +step:1121/1705 train_time:106548ms step_avg:95.05ms +step:1122/1705 train_time:106643ms step_avg:95.05ms +step:1123/1705 train_time:106737ms step_avg:95.05ms +step:1124/1705 train_time:106831ms step_avg:95.05ms +step:1125/1705 train_time:106925ms step_avg:95.04ms +step:1125/1705 val_loss:3.4374 train_time:107020ms step_avg:95.13ms +step:1126/1705 train_time:107043ms step_avg:95.06ms +step:1127/1705 train_time:107121ms step_avg:95.05ms +step:1128/1705 train_time:107217ms step_avg:95.05ms +step:1129/1705 train_time:107312ms step_avg:95.05ms +step:1130/1705 train_time:107405ms step_avg:95.05ms +step:1131/1705 train_time:107498ms step_avg:95.05ms +step:1132/1705 train_time:107592ms step_avg:95.05ms +step:1133/1705 train_time:107686ms step_avg:95.05ms +step:1134/1705 train_time:107780ms step_avg:95.04ms +step:1135/1705 train_time:107873ms step_avg:95.04ms +step:1136/1705 train_time:107968ms step_avg:95.04ms +step:1137/1705 train_time:108065ms step_avg:95.04ms +step:1138/1705 train_time:108161ms step_avg:95.05ms +step:1139/1705 train_time:108257ms step_avg:95.05ms +step:1140/1705 train_time:108352ms step_avg:95.05ms +step:1141/1705 train_time:108447ms step_avg:95.05ms +step:1142/1705 train_time:108541ms step_avg:95.04ms +step:1143/1705 train_time:108634ms step_avg:95.04ms +step:1144/1705 train_time:108730ms step_avg:95.04ms +step:1145/1705 train_time:108824ms step_avg:95.04ms +step:1146/1705 train_time:108919ms step_avg:95.04ms +step:1147/1705 train_time:109015ms step_avg:95.04ms +step:1148/1705 train_time:109111ms step_avg:95.04ms +step:1149/1705 train_time:109209ms step_avg:95.05ms +step:1150/1705 train_time:109305ms step_avg:95.05ms +step:1151/1705 train_time:109399ms step_avg:95.05ms +step:1152/1705 train_time:109494ms step_avg:95.05ms +step:1153/1705 train_time:109589ms step_avg:95.05ms +step:1154/1705 train_time:109683ms step_avg:95.05ms +step:1155/1705 train_time:109778ms step_avg:95.05ms +step:1156/1705 train_time:109872ms step_avg:95.04ms +step:1157/1705 train_time:109967ms step_avg:95.05ms +step:1158/1705 train_time:110064ms step_avg:95.05ms +step:1159/1705 train_time:110160ms step_avg:95.05ms +step:1160/1705 train_time:110255ms step_avg:95.05ms +step:1161/1705 train_time:110351ms step_avg:95.05ms +step:1162/1705 train_time:110446ms step_avg:95.05ms +step:1163/1705 train_time:110541ms step_avg:95.05ms +step:1164/1705 train_time:110635ms step_avg:95.05ms +step:1165/1705 train_time:110730ms step_avg:95.05ms +step:1166/1705 train_time:110826ms step_avg:95.05ms +step:1167/1705 train_time:110921ms step_avg:95.05ms +step:1168/1705 train_time:111016ms step_avg:95.05ms +step:1169/1705 train_time:111112ms step_avg:95.05ms +step:1170/1705 train_time:111208ms step_avg:95.05ms +step:1171/1705 train_time:111305ms step_avg:95.05ms +step:1172/1705 train_time:111400ms step_avg:95.05ms +step:1173/1705 train_time:111495ms step_avg:95.05ms +step:1174/1705 train_time:111589ms step_avg:95.05ms 
+step:1175/1705 train_time:111685ms step_avg:95.05ms +step:1176/1705 train_time:111781ms step_avg:95.05ms +step:1177/1705 train_time:111875ms step_avg:95.05ms +step:1178/1705 train_time:111971ms step_avg:95.05ms +step:1179/1705 train_time:112065ms step_avg:95.05ms +step:1180/1705 train_time:112162ms step_avg:95.05ms +step:1181/1705 train_time:112258ms step_avg:95.05ms +step:1182/1705 train_time:112353ms step_avg:95.05ms +step:1183/1705 train_time:112449ms step_avg:95.05ms +step:1184/1705 train_time:112544ms step_avg:95.05ms +step:1185/1705 train_time:112638ms step_avg:95.05ms +step:1186/1705 train_time:112733ms step_avg:95.05ms +step:1187/1705 train_time:112828ms step_avg:95.05ms +step:1188/1705 train_time:112923ms step_avg:95.05ms +step:1189/1705 train_time:113018ms step_avg:95.05ms +step:1190/1705 train_time:113114ms step_avg:95.05ms +step:1191/1705 train_time:113210ms step_avg:95.05ms +step:1192/1705 train_time:113306ms step_avg:95.06ms +step:1193/1705 train_time:113401ms step_avg:95.06ms +step:1194/1705 train_time:113495ms step_avg:95.05ms +step:1195/1705 train_time:113591ms step_avg:95.05ms +step:1196/1705 train_time:113686ms step_avg:95.06ms +step:1197/1705 train_time:113781ms step_avg:95.05ms +step:1198/1705 train_time:113876ms step_avg:95.05ms +step:1199/1705 train_time:113970ms step_avg:95.05ms +step:1200/1705 train_time:114065ms step_avg:95.05ms +step:1201/1705 train_time:114161ms step_avg:95.06ms +step:1202/1705 train_time:114256ms step_avg:95.06ms +step:1203/1705 train_time:114352ms step_avg:95.06ms +step:1204/1705 train_time:114448ms step_avg:95.06ms +step:1205/1705 train_time:114544ms step_avg:95.06ms +step:1206/1705 train_time:114639ms step_avg:95.06ms +step:1207/1705 train_time:114734ms step_avg:95.06ms +step:1208/1705 train_time:114830ms step_avg:95.06ms +step:1209/1705 train_time:114926ms step_avg:95.06ms +step:1210/1705 train_time:115021ms step_avg:95.06ms +step:1211/1705 train_time:115116ms step_avg:95.06ms +step:1212/1705 train_time:115212ms step_avg:95.06ms +step:1213/1705 train_time:115307ms step_avg:95.06ms +step:1214/1705 train_time:115402ms step_avg:95.06ms +step:1215/1705 train_time:115499ms step_avg:95.06ms +step:1216/1705 train_time:115594ms step_avg:95.06ms +step:1217/1705 train_time:115689ms step_avg:95.06ms +step:1218/1705 train_time:115785ms step_avg:95.06ms +step:1219/1705 train_time:115880ms step_avg:95.06ms +step:1220/1705 train_time:115975ms step_avg:95.06ms +step:1221/1705 train_time:116070ms step_avg:95.06ms +step:1222/1705 train_time:116166ms step_avg:95.06ms +step:1223/1705 train_time:116260ms step_avg:95.06ms +step:1224/1705 train_time:116354ms step_avg:95.06ms +step:1225/1705 train_time:116451ms step_avg:95.06ms +step:1226/1705 train_time:116547ms step_avg:95.06ms +step:1227/1705 train_time:116643ms step_avg:95.06ms +step:1228/1705 train_time:116737ms step_avg:95.06ms +step:1229/1705 train_time:116833ms step_avg:95.06ms +step:1230/1705 train_time:116929ms step_avg:95.06ms +step:1231/1705 train_time:117024ms step_avg:95.06ms +step:1232/1705 train_time:117120ms step_avg:95.07ms +step:1233/1705 train_time:117215ms step_avg:95.06ms +step:1234/1705 train_time:117310ms step_avg:95.06ms +step:1235/1705 train_time:117405ms step_avg:95.06ms +step:1236/1705 train_time:117500ms step_avg:95.07ms +step:1237/1705 train_time:117595ms step_avg:95.06ms +step:1238/1705 train_time:117691ms step_avg:95.07ms +step:1239/1705 train_time:117787ms step_avg:95.07ms +step:1240/1705 train_time:117884ms step_avg:95.07ms +step:1241/1705 train_time:117978ms step_avg:95.07ms 
+step:1242/1705 train_time:118074ms step_avg:95.07ms +step:1243/1705 train_time:118169ms step_avg:95.07ms +step:1244/1705 train_time:118263ms step_avg:95.07ms +step:1245/1705 train_time:118358ms step_avg:95.07ms +step:1246/1705 train_time:118454ms step_avg:95.07ms +step:1247/1705 train_time:118549ms step_avg:95.07ms +step:1248/1705 train_time:118645ms step_avg:95.07ms +step:1249/1705 train_time:118741ms step_avg:95.07ms +step:1250/1705 train_time:118837ms step_avg:95.07ms +step:1250/1705 val_loss:3.3884 train_time:118932ms step_avg:95.15ms +step:1251/1705 train_time:118955ms step_avg:95.09ms +step:1252/1705 train_time:119040ms step_avg:95.08ms +step:1253/1705 train_time:119138ms step_avg:95.08ms +step:1254/1705 train_time:119233ms step_avg:95.08ms +step:1255/1705 train_time:119327ms step_avg:95.08ms +step:1256/1705 train_time:119421ms step_avg:95.08ms +step:1257/1705 train_time:119514ms step_avg:95.08ms +step:1258/1705 train_time:119609ms step_avg:95.08ms +step:1259/1705 train_time:119703ms step_avg:95.08ms +step:1260/1705 train_time:119796ms step_avg:95.08ms +step:1261/1705 train_time:119895ms step_avg:95.08ms +step:1262/1705 train_time:119994ms step_avg:95.08ms +step:1263/1705 train_time:120092ms step_avg:95.08ms +step:1264/1705 train_time:120189ms step_avg:95.09ms +step:1265/1705 train_time:120284ms step_avg:95.09ms +step:1266/1705 train_time:120378ms step_avg:95.08ms +step:1267/1705 train_time:120472ms step_avg:95.08ms +step:1268/1705 train_time:120567ms step_avg:95.08ms +step:1269/1705 train_time:120660ms step_avg:95.08ms +step:1270/1705 train_time:120754ms step_avg:95.08ms +step:1271/1705 train_time:120850ms step_avg:95.08ms +step:1272/1705 train_time:120949ms step_avg:95.09ms +step:1273/1705 train_time:121045ms step_avg:95.09ms +step:1274/1705 train_time:121439ms step_avg:95.32ms +step:1275/1705 train_time:121521ms step_avg:95.31ms +step:1276/1705 train_time:121614ms step_avg:95.31ms +step:1277/1705 train_time:121708ms step_avg:95.31ms +step:1278/1705 train_time:121802ms step_avg:95.31ms +step:1279/1705 train_time:121895ms step_avg:95.31ms +step:1280/1705 train_time:121990ms step_avg:95.30ms +step:1281/1705 train_time:122083ms step_avg:95.30ms +step:1282/1705 train_time:122177ms step_avg:95.30ms +step:1283/1705 train_time:122271ms step_avg:95.30ms +step:1284/1705 train_time:122369ms step_avg:95.30ms +step:1285/1705 train_time:122469ms step_avg:95.31ms +step:1286/1705 train_time:122565ms step_avg:95.31ms +step:1287/1705 train_time:122660ms step_avg:95.31ms +step:1288/1705 train_time:122755ms step_avg:95.31ms +step:1289/1705 train_time:122849ms step_avg:95.31ms +step:1290/1705 train_time:122944ms step_avg:95.31ms +step:1291/1705 train_time:123038ms step_avg:95.30ms +step:1292/1705 train_time:123132ms step_avg:95.30ms +step:1293/1705 train_time:123227ms step_avg:95.30ms +step:1294/1705 train_time:123323ms step_avg:95.30ms +step:1295/1705 train_time:123420ms step_avg:95.30ms +step:1296/1705 train_time:123516ms step_avg:95.31ms +step:1297/1705 train_time:123614ms step_avg:95.31ms +step:1298/1705 train_time:123710ms step_avg:95.31ms +step:1299/1705 train_time:123805ms step_avg:95.31ms +step:1300/1705 train_time:123899ms step_avg:95.31ms +step:1301/1705 train_time:123995ms step_avg:95.31ms +step:1302/1705 train_time:124089ms step_avg:95.31ms +step:1303/1705 train_time:124184ms step_avg:95.31ms +step:1304/1705 train_time:124279ms step_avg:95.31ms +step:1305/1705 train_time:124375ms step_avg:95.31ms +step:1306/1705 train_time:124471ms step_avg:95.31ms +step:1307/1705 train_time:124569ms 
step_avg:95.31ms +step:1308/1705 train_time:124665ms step_avg:95.31ms +step:1309/1705 train_time:124759ms step_avg:95.31ms +step:1310/1705 train_time:124854ms step_avg:95.31ms +step:1311/1705 train_time:124950ms step_avg:95.31ms +step:1312/1705 train_time:125044ms step_avg:95.31ms +step:1313/1705 train_time:125138ms step_avg:95.31ms +step:1314/1705 train_time:125233ms step_avg:95.31ms +step:1315/1705 train_time:125328ms step_avg:95.31ms +step:1316/1705 train_time:125423ms step_avg:95.31ms +step:1317/1705 train_time:125518ms step_avg:95.31ms +step:1318/1705 train_time:125616ms step_avg:95.31ms +step:1319/1705 train_time:125711ms step_avg:95.31ms +step:1320/1705 train_time:125806ms step_avg:95.31ms +step:1321/1705 train_time:125901ms step_avg:95.31ms +step:1322/1705 train_time:125996ms step_avg:95.31ms +step:1323/1705 train_time:126091ms step_avg:95.31ms +step:1324/1705 train_time:126186ms step_avg:95.31ms +step:1325/1705 train_time:126281ms step_avg:95.31ms +step:1326/1705 train_time:126376ms step_avg:95.31ms +step:1327/1705 train_time:126471ms step_avg:95.31ms +step:1328/1705 train_time:126568ms step_avg:95.31ms +step:1329/1705 train_time:126664ms step_avg:95.31ms +step:1330/1705 train_time:126758ms step_avg:95.31ms +step:1331/1705 train_time:126854ms step_avg:95.31ms +step:1332/1705 train_time:126949ms step_avg:95.31ms +step:1333/1705 train_time:127044ms step_avg:95.31ms +step:1334/1705 train_time:127138ms step_avg:95.31ms +step:1335/1705 train_time:127233ms step_avg:95.31ms +step:1336/1705 train_time:127328ms step_avg:95.31ms +step:1337/1705 train_time:127423ms step_avg:95.31ms +step:1338/1705 train_time:127518ms step_avg:95.31ms +step:1339/1705 train_time:127614ms step_avg:95.31ms +step:1340/1705 train_time:127709ms step_avg:95.31ms +step:1341/1705 train_time:127805ms step_avg:95.31ms +step:1342/1705 train_time:127899ms step_avg:95.30ms +step:1343/1705 train_time:127994ms step_avg:95.30ms +step:1344/1705 train_time:128089ms step_avg:95.30ms +step:1345/1705 train_time:128184ms step_avg:95.30ms +step:1346/1705 train_time:128279ms step_avg:95.30ms +step:1347/1705 train_time:128375ms step_avg:95.30ms +step:1348/1705 train_time:128470ms step_avg:95.30ms +step:1349/1705 train_time:128565ms step_avg:95.30ms +step:1350/1705 train_time:128661ms step_avg:95.30ms +step:1351/1705 train_time:128756ms step_avg:95.30ms +step:1352/1705 train_time:128852ms step_avg:95.30ms +step:1353/1705 train_time:128948ms step_avg:95.31ms +step:1354/1705 train_time:129044ms step_avg:95.31ms +step:1355/1705 train_time:129138ms step_avg:95.30ms +step:1356/1705 train_time:129233ms step_avg:95.30ms +step:1357/1705 train_time:129329ms step_avg:95.30ms +step:1358/1705 train_time:129424ms step_avg:95.31ms +step:1359/1705 train_time:129519ms step_avg:95.30ms +step:1360/1705 train_time:129614ms step_avg:95.30ms +step:1361/1705 train_time:129711ms step_avg:95.31ms +step:1362/1705 train_time:129805ms step_avg:95.30ms +step:1363/1705 train_time:129900ms step_avg:95.30ms +step:1364/1705 train_time:129995ms step_avg:95.30ms +step:1365/1705 train_time:130091ms step_avg:95.30ms +step:1366/1705 train_time:130186ms step_avg:95.30ms +step:1367/1705 train_time:130280ms step_avg:95.30ms +step:1368/1705 train_time:130376ms step_avg:95.30ms +step:1369/1705 train_time:130471ms step_avg:95.30ms +step:1370/1705 train_time:130566ms step_avg:95.30ms +step:1371/1705 train_time:130662ms step_avg:95.30ms +step:1372/1705 train_time:130756ms step_avg:95.30ms +step:1373/1705 train_time:130852ms step_avg:95.30ms +step:1374/1705 train_time:130947ms 
step_avg:95.30ms +step:1375/1705 train_time:131042ms step_avg:95.30ms +step:1375/1705 val_loss:3.3506 train_time:131137ms step_avg:95.37ms +step:1376/1705 train_time:131160ms step_avg:95.32ms +step:1377/1705 train_time:131236ms step_avg:95.31ms +step:1378/1705 train_time:131335ms step_avg:95.31ms +step:1379/1705 train_time:131429ms step_avg:95.31ms +step:1380/1705 train_time:131525ms step_avg:95.31ms +step:1381/1705 train_time:131620ms step_avg:95.31ms +step:1382/1705 train_time:131714ms step_avg:95.31ms +step:1383/1705 train_time:131808ms step_avg:95.31ms +step:1384/1705 train_time:131903ms step_avg:95.31ms +step:1385/1705 train_time:131998ms step_avg:95.31ms +step:1386/1705 train_time:132092ms step_avg:95.30ms +step:1387/1705 train_time:132190ms step_avg:95.31ms +step:1388/1705 train_time:132287ms step_avg:95.31ms +step:1389/1705 train_time:132384ms step_avg:95.31ms +step:1390/1705 train_time:132480ms step_avg:95.31ms +step:1391/1705 train_time:132574ms step_avg:95.31ms +step:1392/1705 train_time:132668ms step_avg:95.31ms +step:1393/1705 train_time:132764ms step_avg:95.31ms +step:1394/1705 train_time:132858ms step_avg:95.31ms +step:1395/1705 train_time:132952ms step_avg:95.31ms +step:1396/1705 train_time:133047ms step_avg:95.31ms +step:1397/1705 train_time:133143ms step_avg:95.31ms +step:1398/1705 train_time:133239ms step_avg:95.31ms +step:1399/1705 train_time:133334ms step_avg:95.31ms +step:1400/1705 train_time:133429ms step_avg:95.31ms +step:1401/1705 train_time:133524ms step_avg:95.31ms +step:1402/1705 train_time:133620ms step_avg:95.31ms +step:1403/1705 train_time:133715ms step_avg:95.31ms +step:1404/1705 train_time:133809ms step_avg:95.31ms +step:1405/1705 train_time:133904ms step_avg:95.31ms +step:1406/1705 train_time:133998ms step_avg:95.30ms +step:1407/1705 train_time:134094ms step_avg:95.30ms +step:1408/1705 train_time:134190ms step_avg:95.31ms +step:1409/1705 train_time:134286ms step_avg:95.31ms +step:1410/1705 train_time:134383ms step_avg:95.31ms +step:1411/1705 train_time:134477ms step_avg:95.31ms +step:1412/1705 train_time:134572ms step_avg:95.31ms +step:1413/1705 train_time:134667ms step_avg:95.31ms +step:1414/1705 train_time:134763ms step_avg:95.31ms +step:1415/1705 train_time:134858ms step_avg:95.31ms +step:1416/1705 train_time:134953ms step_avg:95.31ms +step:1417/1705 train_time:135047ms step_avg:95.31ms +step:1418/1705 train_time:135143ms step_avg:95.31ms +step:1419/1705 train_time:135238ms step_avg:95.31ms +step:1420/1705 train_time:135335ms step_avg:95.31ms +step:1421/1705 train_time:135430ms step_avg:95.31ms +step:1422/1705 train_time:135526ms step_avg:95.31ms +step:1423/1705 train_time:135621ms step_avg:95.31ms +step:1424/1705 train_time:135715ms step_avg:95.31ms +step:1425/1705 train_time:135810ms step_avg:95.31ms +step:1426/1705 train_time:135905ms step_avg:95.31ms +step:1427/1705 train_time:136000ms step_avg:95.30ms +step:1428/1705 train_time:136095ms step_avg:95.30ms +step:1429/1705 train_time:136190ms step_avg:95.30ms +step:1430/1705 train_time:136285ms step_avg:95.30ms +step:1431/1705 train_time:136381ms step_avg:95.30ms +step:1432/1705 train_time:136476ms step_avg:95.30ms +step:1433/1705 train_time:136571ms step_avg:95.30ms +step:1434/1705 train_time:136666ms step_avg:95.30ms +step:1435/1705 train_time:136762ms step_avg:95.30ms +step:1436/1705 train_time:136856ms step_avg:95.30ms +step:1437/1705 train_time:136951ms step_avg:95.30ms +step:1438/1705 train_time:137046ms step_avg:95.30ms +step:1439/1705 train_time:137141ms step_avg:95.30ms +step:1440/1705 
train_time:137237ms step_avg:95.30ms +step:1441/1705 train_time:137333ms step_avg:95.30ms +step:1442/1705 train_time:137428ms step_avg:95.30ms +step:1443/1705 train_time:137523ms step_avg:95.30ms +step:1444/1705 train_time:137619ms step_avg:95.30ms +step:1445/1705 train_time:137714ms step_avg:95.30ms +step:1446/1705 train_time:137809ms step_avg:95.30ms +step:1447/1705 train_time:137905ms step_avg:95.30ms +step:1448/1705 train_time:138001ms step_avg:95.30ms +step:1449/1705 train_time:138097ms step_avg:95.31ms +step:1450/1705 train_time:138192ms step_avg:95.30ms +step:1451/1705 train_time:138287ms step_avg:95.30ms +step:1452/1705 train_time:138382ms step_avg:95.30ms +step:1453/1705 train_time:138477ms step_avg:95.30ms +step:1454/1705 train_time:138572ms step_avg:95.30ms +step:1455/1705 train_time:138667ms step_avg:95.30ms +step:1456/1705 train_time:138762ms step_avg:95.30ms +step:1457/1705 train_time:138857ms step_avg:95.30ms +step:1458/1705 train_time:138951ms step_avg:95.30ms +step:1459/1705 train_time:139047ms step_avg:95.30ms +step:1460/1705 train_time:139143ms step_avg:95.30ms +step:1461/1705 train_time:139239ms step_avg:95.30ms +step:1462/1705 train_time:139333ms step_avg:95.30ms +step:1463/1705 train_time:139428ms step_avg:95.30ms +step:1464/1705 train_time:139523ms step_avg:95.30ms +step:1465/1705 train_time:139618ms step_avg:95.30ms +step:1466/1705 train_time:139715ms step_avg:95.30ms +step:1467/1705 train_time:139810ms step_avg:95.30ms +step:1468/1705 train_time:139905ms step_avg:95.30ms +step:1469/1705 train_time:140001ms step_avg:95.30ms +step:1470/1705 train_time:140095ms step_avg:95.30ms +step:1471/1705 train_time:140191ms step_avg:95.30ms +step:1472/1705 train_time:140285ms step_avg:95.30ms +step:1473/1705 train_time:140382ms step_avg:95.30ms +step:1474/1705 train_time:140477ms step_avg:95.30ms +step:1475/1705 train_time:140572ms step_avg:95.30ms +step:1476/1705 train_time:140667ms step_avg:95.30ms +step:1477/1705 train_time:140763ms step_avg:95.30ms +step:1478/1705 train_time:140858ms step_avg:95.30ms +step:1479/1705 train_time:140952ms step_avg:95.30ms +step:1480/1705 train_time:141048ms step_avg:95.30ms +step:1481/1705 train_time:141144ms step_avg:95.30ms +step:1482/1705 train_time:141241ms step_avg:95.30ms +step:1483/1705 train_time:141336ms step_avg:95.30ms +step:1484/1705 train_time:141430ms step_avg:95.30ms +step:1485/1705 train_time:141679ms step_avg:95.41ms +step:1486/1705 train_time:141890ms step_avg:95.48ms +step:1487/1705 train_time:141983ms step_avg:95.48ms +step:1488/1705 train_time:142077ms step_avg:95.48ms +step:1489/1705 train_time:142171ms step_avg:95.48ms +step:1490/1705 train_time:142265ms step_avg:95.48ms +step:1491/1705 train_time:142359ms step_avg:95.48ms +step:1492/1705 train_time:142453ms step_avg:95.48ms +step:1493/1705 train_time:142547ms step_avg:95.48ms +step:1494/1705 train_time:142642ms step_avg:95.48ms +step:1495/1705 train_time:142742ms step_avg:95.48ms +step:1496/1705 train_time:142842ms step_avg:95.48ms +step:1497/1705 train_time:142939ms step_avg:95.48ms +step:1498/1705 train_time:143034ms step_avg:95.48ms +step:1499/1705 train_time:143128ms step_avg:95.48ms +step:1500/1705 train_time:143222ms step_avg:95.48ms +step:1500/1705 val_loss:3.3182 train_time:143317ms step_avg:95.54ms +step:1501/1705 train_time:143339ms step_avg:95.50ms +step:1502/1705 train_time:143420ms step_avg:95.49ms +step:1503/1705 train_time:143520ms step_avg:95.49ms +step:1504/1705 train_time:143616ms step_avg:95.49ms +step:1505/1705 train_time:143709ms step_avg:95.49ms 
+step:1506/1705 train_time:143803ms step_avg:95.49ms +step:1507/1705 train_time:143896ms step_avg:95.49ms +step:1508/1705 train_time:143990ms step_avg:95.48ms +step:1509/1705 train_time:144084ms step_avg:95.48ms +step:1510/1705 train_time:144178ms step_avg:95.48ms +step:1511/1705 train_time:144274ms step_avg:95.48ms +step:1512/1705 train_time:144371ms step_avg:95.48ms +step:1513/1705 train_time:144469ms step_avg:95.48ms +step:1514/1705 train_time:144566ms step_avg:95.49ms +step:1515/1705 train_time:144662ms step_avg:95.49ms +step:1516/1705 train_time:144756ms step_avg:95.49ms +step:1517/1705 train_time:144850ms step_avg:95.48ms +step:1518/1705 train_time:144944ms step_avg:95.48ms +step:1519/1705 train_time:145039ms step_avg:95.48ms +step:1520/1705 train_time:145133ms step_avg:95.48ms +step:1521/1705 train_time:145228ms step_avg:95.48ms +step:1522/1705 train_time:145324ms step_avg:95.48ms +step:1523/1705 train_time:145421ms step_avg:95.48ms +step:1524/1705 train_time:145518ms step_avg:95.48ms +step:1525/1705 train_time:145614ms step_avg:95.48ms +step:1526/1705 train_time:145708ms step_avg:95.48ms +step:1527/1705 train_time:145803ms step_avg:95.48ms +step:1528/1705 train_time:145898ms step_avg:95.48ms +step:1529/1705 train_time:145992ms step_avg:95.48ms +step:1530/1705 train_time:146086ms step_avg:95.48ms +step:1531/1705 train_time:146181ms step_avg:95.48ms +step:1532/1705 train_time:146277ms step_avg:95.48ms +step:1533/1705 train_time:146372ms step_avg:95.48ms +step:1534/1705 train_time:146468ms step_avg:95.48ms +step:1535/1705 train_time:146564ms step_avg:95.48ms +step:1536/1705 train_time:146660ms step_avg:95.48ms +step:1537/1705 train_time:146757ms step_avg:95.48ms +step:1538/1705 train_time:146851ms step_avg:95.48ms +step:1539/1705 train_time:146945ms step_avg:95.48ms +step:1540/1705 train_time:147040ms step_avg:95.48ms +step:1541/1705 train_time:147135ms step_avg:95.48ms +step:1542/1705 train_time:147230ms step_avg:95.48ms +step:1543/1705 train_time:147325ms step_avg:95.48ms +step:1544/1705 train_time:147422ms step_avg:95.48ms +step:1545/1705 train_time:147518ms step_avg:95.48ms +step:1546/1705 train_time:147615ms step_avg:95.48ms +step:1547/1705 train_time:147711ms step_avg:95.48ms +step:1548/1705 train_time:147805ms step_avg:95.48ms +step:1549/1705 train_time:147901ms step_avg:95.48ms +step:1550/1705 train_time:147995ms step_avg:95.48ms +step:1551/1705 train_time:148090ms step_avg:95.48ms +step:1552/1705 train_time:148185ms step_avg:95.48ms +step:1553/1705 train_time:148281ms step_avg:95.48ms +step:1554/1705 train_time:148376ms step_avg:95.48ms +step:1555/1705 train_time:148471ms step_avg:95.48ms +step:1556/1705 train_time:148567ms step_avg:95.48ms +step:1557/1705 train_time:148664ms step_avg:95.48ms +step:1558/1705 train_time:148759ms step_avg:95.48ms +step:1559/1705 train_time:148855ms step_avg:95.48ms +step:1560/1705 train_time:148949ms step_avg:95.48ms +step:1561/1705 train_time:149045ms step_avg:95.48ms +step:1562/1705 train_time:149140ms step_avg:95.48ms +step:1563/1705 train_time:149235ms step_avg:95.48ms +step:1564/1705 train_time:149331ms step_avg:95.48ms +step:1565/1705 train_time:149425ms step_avg:95.48ms +step:1566/1705 train_time:149522ms step_avg:95.48ms +step:1567/1705 train_time:149618ms step_avg:95.48ms +step:1568/1705 train_time:149714ms step_avg:95.48ms +step:1569/1705 train_time:149809ms step_avg:95.48ms +step:1570/1705 train_time:149904ms step_avg:95.48ms +step:1571/1705 train_time:149998ms step_avg:95.48ms +step:1572/1705 train_time:150094ms step_avg:95.48ms 
+step:1573/1705 train_time:150188ms step_avg:95.48ms +step:1574/1705 train_time:150283ms step_avg:95.48ms +step:1575/1705 train_time:150378ms step_avg:95.48ms +step:1576/1705 train_time:150473ms step_avg:95.48ms +step:1577/1705 train_time:150568ms step_avg:95.48ms +step:1578/1705 train_time:150664ms step_avg:95.48ms +step:1579/1705 train_time:150761ms step_avg:95.48ms +step:1580/1705 train_time:150858ms step_avg:95.48ms +step:1581/1705 train_time:150954ms step_avg:95.48ms +step:1582/1705 train_time:151050ms step_avg:95.48ms +step:1583/1705 train_time:151144ms step_avg:95.48ms +step:1584/1705 train_time:151239ms step_avg:95.48ms +step:1585/1705 train_time:151334ms step_avg:95.48ms +step:1586/1705 train_time:151429ms step_avg:95.48ms +step:1587/1705 train_time:151524ms step_avg:95.48ms +step:1588/1705 train_time:151619ms step_avg:95.48ms +step:1589/1705 train_time:151714ms step_avg:95.48ms +step:1590/1705 train_time:151809ms step_avg:95.48ms +step:1591/1705 train_time:151905ms step_avg:95.48ms +step:1592/1705 train_time:152001ms step_avg:95.48ms +step:1593/1705 train_time:152097ms step_avg:95.48ms +step:1594/1705 train_time:152192ms step_avg:95.48ms +step:1595/1705 train_time:152286ms step_avg:95.48ms +step:1596/1705 train_time:152381ms step_avg:95.48ms +step:1597/1705 train_time:152477ms step_avg:95.48ms +step:1598/1705 train_time:152571ms step_avg:95.48ms +step:1599/1705 train_time:152667ms step_avg:95.48ms +step:1600/1705 train_time:152763ms step_avg:95.48ms +step:1601/1705 train_time:152859ms step_avg:95.48ms +step:1602/1705 train_time:152954ms step_avg:95.48ms +step:1603/1705 train_time:153048ms step_avg:95.48ms +step:1604/1705 train_time:153145ms step_avg:95.48ms +step:1605/1705 train_time:153241ms step_avg:95.48ms +step:1606/1705 train_time:153337ms step_avg:95.48ms +step:1607/1705 train_time:153434ms step_avg:95.48ms +step:1608/1705 train_time:153531ms step_avg:95.48ms +step:1609/1705 train_time:153626ms step_avg:95.48ms +step:1610/1705 train_time:153721ms step_avg:95.48ms +step:1611/1705 train_time:153817ms step_avg:95.48ms +step:1612/1705 train_time:153912ms step_avg:95.48ms +step:1613/1705 train_time:154008ms step_avg:95.48ms +step:1614/1705 train_time:154103ms step_avg:95.48ms +step:1615/1705 train_time:154197ms step_avg:95.48ms +step:1616/1705 train_time:154292ms step_avg:95.48ms +step:1617/1705 train_time:154387ms step_avg:95.48ms +step:1618/1705 train_time:154484ms step_avg:95.48ms +step:1619/1705 train_time:154580ms step_avg:95.48ms +step:1620/1705 train_time:154676ms step_avg:95.48ms +step:1621/1705 train_time:154771ms step_avg:95.48ms +step:1622/1705 train_time:154865ms step_avg:95.48ms +step:1623/1705 train_time:154961ms step_avg:95.48ms +step:1624/1705 train_time:155056ms step_avg:95.48ms +step:1625/1705 train_time:155151ms step_avg:95.48ms +step:1625/1705 val_loss:3.2906 train_time:155246ms step_avg:95.54ms +step:1626/1705 train_time:155269ms step_avg:95.49ms +step:1627/1705 train_time:155350ms step_avg:95.48ms +step:1628/1705 train_time:155447ms step_avg:95.48ms +step:1629/1705 train_time:155543ms step_avg:95.48ms +step:1630/1705 train_time:155637ms step_avg:95.48ms +step:1631/1705 train_time:155732ms step_avg:95.48ms +step:1632/1705 train_time:155827ms step_avg:95.48ms +step:1633/1705 train_time:155922ms step_avg:95.48ms +step:1634/1705 train_time:156016ms step_avg:95.48ms +step:1635/1705 train_time:156110ms step_avg:95.48ms +step:1636/1705 train_time:156205ms step_avg:95.48ms +step:1637/1705 train_time:156302ms step_avg:95.48ms +step:1638/1705 train_time:156399ms 
step_avg:95.48ms +step:1639/1705 train_time:156495ms step_avg:95.48ms +step:1640/1705 train_time:156591ms step_avg:95.48ms +step:1641/1705 train_time:156686ms step_avg:95.48ms +step:1642/1705 train_time:156782ms step_avg:95.48ms +step:1643/1705 train_time:156876ms step_avg:95.48ms +step:1644/1705 train_time:156970ms step_avg:95.48ms +step:1645/1705 train_time:157065ms step_avg:95.48ms +step:1646/1705 train_time:157160ms step_avg:95.48ms +step:1647/1705 train_time:157255ms step_avg:95.48ms +step:1648/1705 train_time:157352ms step_avg:95.48ms +step:1649/1705 train_time:157450ms step_avg:95.48ms +step:1650/1705 train_time:157546ms step_avg:95.48ms +step:1651/1705 train_time:157640ms step_avg:95.48ms +step:1652/1705 train_time:157735ms step_avg:95.48ms +step:1653/1705 train_time:157830ms step_avg:95.48ms +step:1654/1705 train_time:157925ms step_avg:95.48ms +step:1655/1705 train_time:158019ms step_avg:95.48ms +step:1656/1705 train_time:158114ms step_avg:95.48ms +step:1657/1705 train_time:158209ms step_avg:95.48ms +step:1658/1705 train_time:158306ms step_avg:95.48ms +step:1659/1705 train_time:158402ms step_avg:95.48ms +step:1660/1705 train_time:158497ms step_avg:95.48ms +step:1661/1705 train_time:158594ms step_avg:95.48ms +step:1662/1705 train_time:158690ms step_avg:95.48ms +step:1663/1705 train_time:158784ms step_avg:95.48ms +step:1664/1705 train_time:158879ms step_avg:95.48ms +step:1665/1705 train_time:158974ms step_avg:95.48ms +step:1666/1705 train_time:159069ms step_avg:95.48ms +step:1667/1705 train_time:159163ms step_avg:95.48ms +step:1668/1705 train_time:159259ms step_avg:95.48ms +step:1669/1705 train_time:159354ms step_avg:95.48ms +step:1670/1705 train_time:159450ms step_avg:95.48ms +step:1671/1705 train_time:159545ms step_avg:95.48ms +step:1672/1705 train_time:159642ms step_avg:95.48ms +step:1673/1705 train_time:159737ms step_avg:95.48ms +step:1674/1705 train_time:159832ms step_avg:95.48ms +step:1675/1705 train_time:159927ms step_avg:95.48ms +step:1676/1705 train_time:160022ms step_avg:95.48ms +step:1677/1705 train_time:160117ms step_avg:95.48ms +step:1678/1705 train_time:160212ms step_avg:95.48ms +step:1679/1705 train_time:160308ms step_avg:95.48ms +step:1680/1705 train_time:160403ms step_avg:95.48ms +step:1681/1705 train_time:160499ms step_avg:95.48ms +step:1682/1705 train_time:160595ms step_avg:95.48ms +step:1683/1705 train_time:160691ms step_avg:95.48ms +step:1684/1705 train_time:160786ms step_avg:95.48ms +step:1685/1705 train_time:160881ms step_avg:95.48ms +step:1686/1705 train_time:160976ms step_avg:95.48ms +step:1687/1705 train_time:161072ms step_avg:95.48ms +step:1688/1705 train_time:161167ms step_avg:95.48ms +step:1689/1705 train_time:161262ms step_avg:95.48ms +step:1690/1705 train_time:161357ms step_avg:95.48ms +step:1691/1705 train_time:161452ms step_avg:95.48ms +step:1692/1705 train_time:161548ms step_avg:95.48ms +step:1693/1705 train_time:161645ms step_avg:95.48ms +step:1694/1705 train_time:161739ms step_avg:95.48ms +step:1695/1705 train_time:161834ms step_avg:95.48ms +step:1696/1705 train_time:161930ms step_avg:95.48ms +step:1697/1705 train_time:162025ms step_avg:95.48ms +step:1698/1705 train_time:162305ms step_avg:95.59ms +step:1699/1705 train_time:162471ms step_avg:95.63ms +step:1700/1705 train_time:162565ms step_avg:95.63ms +step:1701/1705 train_time:162658ms step_avg:95.63ms +step:1702/1705 train_time:162753ms step_avg:95.62ms +step:1703/1705 train_time:162847ms step_avg:95.62ms +step:1704/1705 train_time:162941ms step_avg:95.62ms +step:1705/1705 train_time:163035ms 
step_avg:95.62ms +step:1705/1705 val_loss:3.2764 train_time:163129ms step_avg:95.68ms +peak memory allocated: 33992 MiB reserved: 49096 MiB diff --git a/records/050925_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt b/records/050925_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt new file mode 100644 index 000000000..fc4774250 --- /dev/null +++ b/records/050925_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, 
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
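+
+    A minimal usage sketch (illustrative only; assumes torch.distributed is
+    already initialized, and uses a simplified version of the parameter split
+    done later in this script):
+
+        hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim >= 2]
+        optimizer = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+        loss.backward()
+        optimizer.step()
+        model.zero_grad(set_to_none=True)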
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
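+        # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 128 * 393), which is the padded vocab size the fp8 lm_head below is built with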
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was one token too long, to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1705 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
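+
+# A minimal illustrative sketch of the BOS-aligned slicing contract above:
+# every (start, end) slice begins at a BOS token, is truncated to max_seq_len,
+# and one rank's slices sum to num_tokens_local + 1 tokens (the extra token
+# feeds the shifted targets). Gated off by default; DEMO_BOSFINDER and the toy
+# tensor are assumptions for illustration, not part of the recorded run.
+if os.environ.get("DEMO_BOSFINDER") and master_process:
+    _toy = torch.tensor([BOS_ID, 5, 6, BOS_ID, 7, BOS_ID, 8, 9, 10], dtype=torch.uint16)
+    _starts, _ends = BOSFinder(_toy, world_size=1).next_batch(num_tokens_local=4, max_seq_len=8)
+    print(_starts, _ends) # two slices, [0, 3) and [3, 5): 3 + 2 = num_tokens_local + 1 tokens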
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
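+
+# An illustrative sketch of the two schedules above (gated off by default;
+# DEMO_SCHEDULES is an assumption for illustration, not part of the recorded
+# run): get_lr holds the multiplier at 1.0 for the first 1 - cooldown_frac of
+# training and then decays it linearly toward 0.1, while get_ws steps the
+# attention window size through ws_schedule in equal thirds of training.
+if os.environ.get("DEMO_SCHEDULES") and master_process:
+    for s in (0, 850, 1200, 1500, 1704):
+        print0(f"step {s}: lr_mult={get_lr(s):.3f} ws={get_ws(s)}", console=True)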
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:10:06 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 130W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 126W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 82357 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 82358 C /usr/bin/python3 610MiB | +| 0 N/A N/A 82359 C /usr/bin/python3 610MiB | +| 0 N/A N/A 82360 C /usr/bin/python3 610MiB | +| 0 N/A N/A 82361 C /usr/bin/python3 610MiB | +| 0 N/A N/A 82362 C /usr/bin/python3 610MiB | +| 0 N/A N/A 82363 C /usr/bin/python3 610MiB | +| 0 N/A N/A 82364 C /usr/bin/python3 610MiB | +| 1 N/A N/A 82358 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 82359 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 82360 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 82361 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 82362 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 82363 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 82364 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1705 train_time:399ms step_avg:399.21ms +step:2/1705 train_time:419ms step_avg:209.29ms +step:3/1705 train_time:488ms step_avg:162.77ms +step:4/1705 train_time:579ms step_avg:144.74ms +step:5/1705 
train_time:671ms step_avg:134.13ms +step:6/1705 train_time:762ms step_avg:127.05ms +step:7/1705 train_time:854ms step_avg:122.05ms +step:8/1705 train_time:947ms step_avg:118.40ms +step:9/1705 train_time:1040ms step_avg:115.51ms +step:10/1705 train_time:1132ms step_avg:113.18ms +step:11/1705 train_time:1225ms step_avg:111.33ms +step:12/1705 train_time:1319ms step_avg:109.94ms +step:13/1705 train_time:1414ms step_avg:108.78ms +step:14/1705 train_time:1510ms step_avg:107.83ms +step:15/1705 train_time:1604ms step_avg:106.90ms +step:16/1705 train_time:1697ms step_avg:106.04ms +step:17/1705 train_time:1789ms step_avg:105.26ms +step:18/1705 train_time:1882ms step_avg:104.57ms +step:19/1705 train_time:1974ms step_avg:103.91ms +step:20/1705 train_time:2067ms step_avg:103.35ms +step:21/1705 train_time:2160ms step_avg:102.85ms +step:22/1705 train_time:2252ms step_avg:102.38ms +step:23/1705 train_time:2346ms step_avg:102.01ms +step:24/1705 train_time:2440ms step_avg:101.66ms +step:25/1705 train_time:2533ms step_avg:101.32ms +step:26/1705 train_time:2628ms step_avg:101.06ms +step:27/1705 train_time:2721ms step_avg:100.79ms +step:28/1705 train_time:2815ms step_avg:100.52ms +step:29/1705 train_time:2908ms step_avg:100.27ms +step:30/1705 train_time:3001ms step_avg:100.04ms +step:31/1705 train_time:3094ms step_avg:99.82ms +step:32/1705 train_time:3188ms step_avg:99.62ms +step:33/1705 train_time:3280ms step_avg:99.40ms +step:34/1705 train_time:3373ms step_avg:99.20ms +step:35/1705 train_time:3467ms step_avg:99.05ms +step:36/1705 train_time:3560ms step_avg:98.89ms +step:37/1705 train_time:3653ms step_avg:98.73ms +step:38/1705 train_time:3748ms step_avg:98.62ms +step:39/1705 train_time:3841ms step_avg:98.48ms +step:40/1705 train_time:3933ms step_avg:98.33ms +step:41/1705 train_time:4028ms step_avg:98.24ms +step:42/1705 train_time:4121ms step_avg:98.11ms +step:43/1705 train_time:4213ms step_avg:97.98ms +step:44/1705 train_time:4307ms step_avg:97.88ms +step:45/1705 train_time:4401ms step_avg:97.79ms +step:46/1705 train_time:4493ms step_avg:97.67ms +step:47/1705 train_time:4587ms step_avg:97.59ms +step:48/1705 train_time:4680ms step_avg:97.51ms +step:49/1705 train_time:4773ms step_avg:97.40ms +step:50/1705 train_time:4866ms step_avg:97.33ms +step:51/1705 train_time:4960ms step_avg:97.25ms +step:52/1705 train_time:5053ms step_avg:97.16ms +step:53/1705 train_time:5147ms step_avg:97.10ms +step:54/1705 train_time:5239ms step_avg:97.02ms +step:55/1705 train_time:5332ms step_avg:96.95ms +step:56/1705 train_time:5426ms step_avg:96.90ms +step:57/1705 train_time:5520ms step_avg:96.84ms +step:58/1705 train_time:5612ms step_avg:96.77ms +step:59/1705 train_time:5706ms step_avg:96.72ms +step:60/1705 train_time:5799ms step_avg:96.66ms +step:61/1705 train_time:5893ms step_avg:96.60ms +step:62/1705 train_time:5986ms step_avg:96.55ms +step:63/1705 train_time:6079ms step_avg:96.48ms +step:64/1705 train_time:6171ms step_avg:96.42ms +step:65/1705 train_time:6265ms step_avg:96.38ms +step:66/1705 train_time:6358ms step_avg:96.33ms +step:67/1705 train_time:6451ms step_avg:96.29ms +step:68/1705 train_time:6545ms step_avg:96.25ms +step:69/1705 train_time:6638ms step_avg:96.20ms +step:70/1705 train_time:6731ms step_avg:96.16ms +step:71/1705 train_time:6825ms step_avg:96.13ms +step:72/1705 train_time:6919ms step_avg:96.09ms +step:73/1705 train_time:7012ms step_avg:96.05ms +step:74/1705 train_time:7106ms step_avg:96.02ms +step:75/1705 train_time:7198ms step_avg:95.98ms +step:76/1705 train_time:7291ms step_avg:95.94ms +step:77/1705 
train_time:7385ms step_avg:95.91ms +step:78/1705 train_time:7477ms step_avg:95.87ms +step:79/1705 train_time:7570ms step_avg:95.83ms +step:80/1705 train_time:7664ms step_avg:95.80ms +step:81/1705 train_time:7757ms step_avg:95.76ms +step:82/1705 train_time:7851ms step_avg:95.74ms +step:83/1705 train_time:7945ms step_avg:95.72ms +step:84/1705 train_time:8038ms step_avg:95.69ms +step:85/1705 train_time:8131ms step_avg:95.66ms +step:86/1705 train_time:8225ms step_avg:95.64ms +step:87/1705 train_time:8318ms step_avg:95.61ms +step:88/1705 train_time:8411ms step_avg:95.58ms +step:89/1705 train_time:8504ms step_avg:95.56ms +step:90/1705 train_time:8598ms step_avg:95.53ms +step:91/1705 train_time:8691ms step_avg:95.50ms +step:92/1705 train_time:8785ms step_avg:95.48ms +step:93/1705 train_time:8877ms step_avg:95.45ms +step:94/1705 train_time:8970ms step_avg:95.43ms +step:95/1705 train_time:9063ms step_avg:95.40ms +step:96/1705 train_time:9155ms step_avg:95.37ms +step:97/1705 train_time:9249ms step_avg:95.35ms +step:98/1705 train_time:9343ms step_avg:95.33ms +step:99/1705 train_time:9435ms step_avg:95.30ms +step:100/1705 train_time:9529ms step_avg:95.29ms +step:101/1705 train_time:9622ms step_avg:95.27ms +step:102/1705 train_time:9715ms step_avg:95.24ms +step:103/1705 train_time:9808ms step_avg:95.22ms +step:104/1705 train_time:9901ms step_avg:95.20ms +step:105/1705 train_time:9993ms step_avg:95.18ms +step:106/1705 train_time:10087ms step_avg:95.16ms +step:107/1705 train_time:10179ms step_avg:95.13ms +step:108/1705 train_time:10271ms step_avg:95.10ms +step:109/1705 train_time:10365ms step_avg:95.09ms +step:110/1705 train_time:10459ms step_avg:95.08ms +step:111/1705 train_time:10551ms step_avg:95.06ms +step:112/1705 train_time:10645ms step_avg:95.05ms +step:113/1705 train_time:10737ms step_avg:95.02ms +step:114/1705 train_time:10830ms step_avg:95.00ms +step:115/1705 train_time:10923ms step_avg:94.98ms +step:116/1705 train_time:11017ms step_avg:94.97ms +step:117/1705 train_time:11109ms step_avg:94.95ms +step:118/1705 train_time:11201ms step_avg:94.93ms +step:119/1705 train_time:11294ms step_avg:94.91ms +step:120/1705 train_time:11387ms step_avg:94.90ms +step:121/1705 train_time:11480ms step_avg:94.88ms +step:122/1705 train_time:11573ms step_avg:94.86ms +step:123/1705 train_time:11666ms step_avg:94.84ms +step:124/1705 train_time:11759ms step_avg:94.83ms +step:125/1705 train_time:11852ms step_avg:94.81ms +step:125/1705 val_loss:4.3053 train_time:11946ms step_avg:95.57ms +step:126/1705 train_time:11967ms step_avg:94.98ms +step:127/1705 train_time:12044ms step_avg:94.84ms +step:128/1705 train_time:12145ms step_avg:94.88ms +step:129/1705 train_time:12239ms step_avg:94.88ms +step:130/1705 train_time:12332ms step_avg:94.86ms +step:131/1705 train_time:12424ms step_avg:94.84ms +step:132/1705 train_time:12515ms step_avg:94.81ms +step:133/1705 train_time:12607ms step_avg:94.79ms +step:134/1705 train_time:12699ms step_avg:94.77ms +step:135/1705 train_time:12792ms step_avg:94.76ms +step:136/1705 train_time:12883ms step_avg:94.73ms +step:137/1705 train_time:12977ms step_avg:94.72ms +step:138/1705 train_time:13073ms step_avg:94.73ms +step:139/1705 train_time:13166ms step_avg:94.72ms +step:140/1705 train_time:13260ms step_avg:94.72ms +step:141/1705 train_time:13354ms step_avg:94.71ms +step:142/1705 train_time:13447ms step_avg:94.70ms +step:143/1705 train_time:13539ms step_avg:94.68ms +step:144/1705 train_time:13631ms step_avg:94.66ms +step:145/1705 train_time:13723ms step_avg:94.64ms +step:146/1705 train_time:13815ms 
step_avg:94.62ms +step:147/1705 train_time:13907ms step_avg:94.61ms +step:148/1705 train_time:14000ms step_avg:94.60ms +step:149/1705 train_time:14095ms step_avg:94.60ms +step:150/1705 train_time:14189ms step_avg:94.59ms +step:151/1705 train_time:14282ms step_avg:94.58ms +step:152/1705 train_time:14376ms step_avg:94.58ms +step:153/1705 train_time:14469ms step_avg:94.57ms +step:154/1705 train_time:14562ms step_avg:94.56ms +step:155/1705 train_time:14655ms step_avg:94.55ms +step:156/1705 train_time:14748ms step_avg:94.54ms +step:157/1705 train_time:14839ms step_avg:94.52ms +step:158/1705 train_time:14932ms step_avg:94.51ms +step:159/1705 train_time:15024ms step_avg:94.49ms +step:160/1705 train_time:15118ms step_avg:94.49ms +step:161/1705 train_time:15212ms step_avg:94.48ms +step:162/1705 train_time:15305ms step_avg:94.47ms +step:163/1705 train_time:15398ms step_avg:94.47ms +step:164/1705 train_time:15491ms step_avg:94.46ms +step:165/1705 train_time:15583ms step_avg:94.44ms +step:166/1705 train_time:15675ms step_avg:94.43ms +step:167/1705 train_time:15768ms step_avg:94.42ms +step:168/1705 train_time:15860ms step_avg:94.41ms +step:169/1705 train_time:15953ms step_avg:94.40ms +step:170/1705 train_time:16045ms step_avg:94.38ms +step:171/1705 train_time:16138ms step_avg:94.37ms +step:172/1705 train_time:16231ms step_avg:94.37ms +step:173/1705 train_time:16323ms step_avg:94.35ms +step:174/1705 train_time:16416ms step_avg:94.34ms +step:175/1705 train_time:16509ms step_avg:94.34ms +step:176/1705 train_time:16601ms step_avg:94.33ms +step:177/1705 train_time:16695ms step_avg:94.32ms +step:178/1705 train_time:16788ms step_avg:94.31ms +step:179/1705 train_time:16880ms step_avg:94.30ms +step:180/1705 train_time:16973ms step_avg:94.29ms +step:181/1705 train_time:17065ms step_avg:94.28ms +step:182/1705 train_time:17159ms step_avg:94.28ms +step:183/1705 train_time:17252ms step_avg:94.27ms +step:184/1705 train_time:17345ms step_avg:94.27ms +step:185/1705 train_time:17438ms step_avg:94.26ms +step:186/1705 train_time:17532ms step_avg:94.26ms +step:187/1705 train_time:17624ms step_avg:94.25ms +step:188/1705 train_time:17717ms step_avg:94.24ms +step:189/1705 train_time:17810ms step_avg:94.23ms +step:190/1705 train_time:17902ms step_avg:94.22ms +step:191/1705 train_time:17995ms step_avg:94.22ms +step:192/1705 train_time:18087ms step_avg:94.21ms +step:193/1705 train_time:18180ms step_avg:94.20ms +step:194/1705 train_time:18274ms step_avg:94.20ms +step:195/1705 train_time:18367ms step_avg:94.19ms +step:196/1705 train_time:18459ms step_avg:94.18ms +step:197/1705 train_time:18552ms step_avg:94.17ms +step:198/1705 train_time:18644ms step_avg:94.16ms +step:199/1705 train_time:18738ms step_avg:94.16ms +step:200/1705 train_time:18830ms step_avg:94.15ms +step:201/1705 train_time:18922ms step_avg:94.14ms +step:202/1705 train_time:19015ms step_avg:94.14ms +step:203/1705 train_time:19108ms step_avg:94.13ms +step:204/1705 train_time:19200ms step_avg:94.12ms +step:205/1705 train_time:19294ms step_avg:94.12ms +step:206/1705 train_time:19387ms step_avg:94.11ms +step:207/1705 train_time:19479ms step_avg:94.10ms +step:208/1705 train_time:19573ms step_avg:94.10ms +step:209/1705 train_time:19665ms step_avg:94.09ms +step:210/1705 train_time:19759ms step_avg:94.09ms +step:211/1705 train_time:19852ms step_avg:94.09ms +step:212/1705 train_time:19945ms step_avg:94.08ms +step:213/1705 train_time:20259ms step_avg:95.11ms +step:214/1705 train_time:20329ms step_avg:95.00ms +step:215/1705 train_time:20420ms step_avg:94.98ms +step:216/1705 
train_time:20513ms step_avg:94.97ms +step:217/1705 train_time:20605ms step_avg:94.95ms +step:218/1705 train_time:20697ms step_avg:94.94ms +step:219/1705 train_time:20788ms step_avg:94.92ms +step:220/1705 train_time:20880ms step_avg:94.91ms +step:221/1705 train_time:20972ms step_avg:94.90ms +step:222/1705 train_time:21063ms step_avg:94.88ms +step:223/1705 train_time:21158ms step_avg:94.88ms +step:224/1705 train_time:21254ms step_avg:94.88ms +step:225/1705 train_time:21350ms step_avg:94.89ms +step:226/1705 train_time:21443ms step_avg:94.88ms +step:227/1705 train_time:21536ms step_avg:94.87ms +step:228/1705 train_time:21627ms step_avg:94.86ms +step:229/1705 train_time:21720ms step_avg:94.85ms +step:230/1705 train_time:21813ms step_avg:94.84ms +step:231/1705 train_time:21904ms step_avg:94.82ms +step:232/1705 train_time:21997ms step_avg:94.81ms +step:233/1705 train_time:22088ms step_avg:94.80ms +step:234/1705 train_time:22181ms step_avg:94.79ms +step:235/1705 train_time:22276ms step_avg:94.79ms +step:236/1705 train_time:22370ms step_avg:94.79ms +step:237/1705 train_time:22463ms step_avg:94.78ms +step:238/1705 train_time:22556ms step_avg:94.77ms +step:239/1705 train_time:22649ms step_avg:94.77ms +step:240/1705 train_time:22741ms step_avg:94.75ms +step:241/1705 train_time:22833ms step_avg:94.74ms +step:242/1705 train_time:22926ms step_avg:94.73ms +step:243/1705 train_time:23018ms step_avg:94.72ms +step:244/1705 train_time:23110ms step_avg:94.71ms +step:245/1705 train_time:23203ms step_avg:94.71ms +step:246/1705 train_time:23297ms step_avg:94.70ms +step:247/1705 train_time:23391ms step_avg:94.70ms +step:248/1705 train_time:23484ms step_avg:94.69ms +step:249/1705 train_time:23578ms step_avg:94.69ms +step:250/1705 train_time:23671ms step_avg:94.68ms +step:250/1705 val_loss:3.9711 train_time:23763ms step_avg:95.05ms +step:251/1705 train_time:23785ms step_avg:94.76ms +step:252/1705 train_time:23856ms step_avg:94.67ms +step:253/1705 train_time:23948ms step_avg:94.66ms +step:254/1705 train_time:24050ms step_avg:94.68ms +step:255/1705 train_time:24148ms step_avg:94.70ms +step:256/1705 train_time:24240ms step_avg:94.69ms +step:257/1705 train_time:24331ms step_avg:94.68ms +step:258/1705 train_time:24423ms step_avg:94.66ms +step:259/1705 train_time:24515ms step_avg:94.65ms +step:260/1705 train_time:24607ms step_avg:94.64ms +step:261/1705 train_time:24700ms step_avg:94.64ms +step:262/1705 train_time:24793ms step_avg:94.63ms +step:263/1705 train_time:24886ms step_avg:94.62ms +step:264/1705 train_time:24978ms step_avg:94.62ms +step:265/1705 train_time:25072ms step_avg:94.61ms +step:266/1705 train_time:25165ms step_avg:94.61ms +step:267/1705 train_time:25257ms step_avg:94.60ms +step:268/1705 train_time:25350ms step_avg:94.59ms +step:269/1705 train_time:25442ms step_avg:94.58ms +step:270/1705 train_time:25534ms step_avg:94.57ms +step:271/1705 train_time:25627ms step_avg:94.57ms +step:272/1705 train_time:25719ms step_avg:94.56ms +step:273/1705 train_time:25812ms step_avg:94.55ms +step:274/1705 train_time:25905ms step_avg:94.54ms +step:275/1705 train_time:25998ms step_avg:94.54ms +step:276/1705 train_time:26092ms step_avg:94.54ms +step:277/1705 train_time:26186ms step_avg:94.54ms +step:278/1705 train_time:26279ms step_avg:94.53ms +step:279/1705 train_time:26372ms step_avg:94.52ms +step:280/1705 train_time:26464ms step_avg:94.52ms +step:281/1705 train_time:26556ms step_avg:94.51ms +step:282/1705 train_time:26649ms step_avg:94.50ms +step:283/1705 train_time:26743ms step_avg:94.50ms +step:284/1705 train_time:26834ms 
step_avg:94.49ms +step:285/1705 train_time:26927ms step_avg:94.48ms +step:286/1705 train_time:27020ms step_avg:94.47ms +step:287/1705 train_time:27113ms step_avg:94.47ms +step:288/1705 train_time:27207ms step_avg:94.47ms +step:289/1705 train_time:27300ms step_avg:94.46ms +step:290/1705 train_time:27393ms step_avg:94.46ms +step:291/1705 train_time:27485ms step_avg:94.45ms +step:292/1705 train_time:27578ms step_avg:94.44ms +step:293/1705 train_time:27671ms step_avg:94.44ms +step:294/1705 train_time:27763ms step_avg:94.43ms +step:295/1705 train_time:27855ms step_avg:94.42ms +step:296/1705 train_time:27948ms step_avg:94.42ms +step:297/1705 train_time:28041ms step_avg:94.41ms +step:298/1705 train_time:28135ms step_avg:94.41ms +step:299/1705 train_time:28229ms step_avg:94.41ms +step:300/1705 train_time:28321ms step_avg:94.40ms +step:301/1705 train_time:28414ms step_avg:94.40ms +step:302/1705 train_time:28507ms step_avg:94.39ms +step:303/1705 train_time:28599ms step_avg:94.39ms +step:304/1705 train_time:28692ms step_avg:94.38ms +step:305/1705 train_time:28785ms step_avg:94.38ms +step:306/1705 train_time:28878ms step_avg:94.37ms +step:307/1705 train_time:28971ms step_avg:94.37ms +step:308/1705 train_time:29063ms step_avg:94.36ms +step:309/1705 train_time:29156ms step_avg:94.36ms +step:310/1705 train_time:29249ms step_avg:94.35ms +step:311/1705 train_time:29342ms step_avg:94.35ms +step:312/1705 train_time:29435ms step_avg:94.34ms +step:313/1705 train_time:29527ms step_avg:94.34ms +step:314/1705 train_time:29620ms step_avg:94.33ms +step:315/1705 train_time:29713ms step_avg:94.33ms +step:316/1705 train_time:29806ms step_avg:94.32ms +step:317/1705 train_time:29898ms step_avg:94.31ms +step:318/1705 train_time:29991ms step_avg:94.31ms +step:319/1705 train_time:30084ms step_avg:94.31ms +step:320/1705 train_time:30176ms step_avg:94.30ms +step:321/1705 train_time:30270ms step_avg:94.30ms +step:322/1705 train_time:30363ms step_avg:94.30ms +step:323/1705 train_time:30455ms step_avg:94.29ms +step:324/1705 train_time:30549ms step_avg:94.29ms +step:325/1705 train_time:30643ms step_avg:94.29ms +step:326/1705 train_time:30735ms step_avg:94.28ms +step:327/1705 train_time:30828ms step_avg:94.28ms +step:328/1705 train_time:30921ms step_avg:94.27ms +step:329/1705 train_time:31013ms step_avg:94.27ms +step:330/1705 train_time:31107ms step_avg:94.26ms +step:331/1705 train_time:31199ms step_avg:94.26ms +step:332/1705 train_time:31293ms step_avg:94.26ms +step:333/1705 train_time:31386ms step_avg:94.25ms +step:334/1705 train_time:31478ms step_avg:94.25ms +step:335/1705 train_time:31571ms step_avg:94.24ms +step:336/1705 train_time:31664ms step_avg:94.24ms +step:337/1705 train_time:31756ms step_avg:94.23ms +step:338/1705 train_time:31849ms step_avg:94.23ms +step:339/1705 train_time:31942ms step_avg:94.22ms +step:340/1705 train_time:32034ms step_avg:94.22ms +step:341/1705 train_time:32127ms step_avg:94.22ms +step:342/1705 train_time:32220ms step_avg:94.21ms +step:343/1705 train_time:32313ms step_avg:94.21ms +step:344/1705 train_time:32406ms step_avg:94.20ms +step:345/1705 train_time:32499ms step_avg:94.20ms +step:346/1705 train_time:32592ms step_avg:94.20ms +step:347/1705 train_time:32685ms step_avg:94.19ms +step:348/1705 train_time:32778ms step_avg:94.19ms +step:349/1705 train_time:32870ms step_avg:94.18ms +step:350/1705 train_time:32963ms step_avg:94.18ms +step:351/1705 train_time:33056ms step_avg:94.18ms +step:352/1705 train_time:33149ms step_avg:94.17ms +step:353/1705 train_time:33242ms step_avg:94.17ms +step:354/1705 
train_time:33334ms step_avg:94.17ms +step:355/1705 train_time:33429ms step_avg:94.17ms +step:356/1705 train_time:33521ms step_avg:94.16ms +step:357/1705 train_time:33614ms step_avg:94.16ms +step:358/1705 train_time:33707ms step_avg:94.15ms +step:359/1705 train_time:33799ms step_avg:94.15ms +step:360/1705 train_time:33892ms step_avg:94.14ms +step:361/1705 train_time:33985ms step_avg:94.14ms +step:362/1705 train_time:34077ms step_avg:94.14ms +step:363/1705 train_time:34171ms step_avg:94.13ms +step:364/1705 train_time:34263ms step_avg:94.13ms +step:365/1705 train_time:34355ms step_avg:94.12ms +step:366/1705 train_time:34448ms step_avg:94.12ms +step:367/1705 train_time:34541ms step_avg:94.12ms +step:368/1705 train_time:34635ms step_avg:94.12ms +step:369/1705 train_time:34728ms step_avg:94.11ms +step:370/1705 train_time:34821ms step_avg:94.11ms +step:371/1705 train_time:34914ms step_avg:94.11ms +step:372/1705 train_time:35007ms step_avg:94.11ms +step:373/1705 train_time:35100ms step_avg:94.10ms +step:374/1705 train_time:35193ms step_avg:94.10ms +step:375/1705 train_time:35287ms step_avg:94.10ms +step:375/1705 val_loss:3.8190 train_time:35380ms step_avg:94.35ms +step:376/1705 train_time:35401ms step_avg:94.15ms +step:377/1705 train_time:35475ms step_avg:94.10ms +step:378/1705 train_time:35572ms step_avg:94.11ms +step:379/1705 train_time:35665ms step_avg:94.10ms +step:380/1705 train_time:35758ms step_avg:94.10ms +step:381/1705 train_time:35850ms step_avg:94.09ms +step:382/1705 train_time:35943ms step_avg:94.09ms +step:383/1705 train_time:36035ms step_avg:94.09ms +step:384/1705 train_time:36126ms step_avg:94.08ms +step:385/1705 train_time:36219ms step_avg:94.08ms +step:386/1705 train_time:36311ms step_avg:94.07ms +step:387/1705 train_time:36405ms step_avg:94.07ms +step:388/1705 train_time:36499ms step_avg:94.07ms +step:389/1705 train_time:36594ms step_avg:94.07ms +step:390/1705 train_time:36687ms step_avg:94.07ms +step:391/1705 train_time:36780ms step_avg:94.07ms +step:392/1705 train_time:36873ms step_avg:94.06ms +step:393/1705 train_time:36965ms step_avg:94.06ms +step:394/1705 train_time:37059ms step_avg:94.06ms +step:395/1705 train_time:37150ms step_avg:94.05ms +step:396/1705 train_time:37243ms step_avg:94.05ms +step:397/1705 train_time:37336ms step_avg:94.05ms +step:398/1705 train_time:37429ms step_avg:94.04ms +step:399/1705 train_time:37523ms step_avg:94.04ms +step:400/1705 train_time:37616ms step_avg:94.04ms +step:401/1705 train_time:37709ms step_avg:94.04ms +step:402/1705 train_time:37803ms step_avg:94.04ms +step:403/1705 train_time:37895ms step_avg:94.03ms +step:404/1705 train_time:37987ms step_avg:94.03ms +step:405/1705 train_time:38080ms step_avg:94.02ms +step:406/1705 train_time:38173ms step_avg:94.02ms +step:407/1705 train_time:38265ms step_avg:94.02ms +step:408/1705 train_time:38359ms step_avg:94.02ms +step:409/1705 train_time:38453ms step_avg:94.02ms +step:410/1705 train_time:38546ms step_avg:94.01ms +step:411/1705 train_time:38639ms step_avg:94.01ms +step:412/1705 train_time:38732ms step_avg:94.01ms +step:413/1705 train_time:38825ms step_avg:94.01ms +step:414/1705 train_time:38917ms step_avg:94.00ms +step:415/1705 train_time:39009ms step_avg:94.00ms +step:416/1705 train_time:39102ms step_avg:93.99ms +step:417/1705 train_time:39194ms step_avg:93.99ms +step:418/1705 train_time:39286ms step_avg:93.99ms +step:419/1705 train_time:39379ms step_avg:93.98ms +step:420/1705 train_time:39472ms step_avg:93.98ms +step:421/1705 train_time:39565ms step_avg:93.98ms +step:422/1705 train_time:39660ms 
step_avg:93.98ms +step:423/1705 train_time:39753ms step_avg:93.98ms +step:424/1705 train_time:39846ms step_avg:93.98ms +step:425/1705 train_time:40124ms step_avg:94.41ms +step:426/1705 train_time:40241ms step_avg:94.46ms +step:427/1705 train_time:40332ms step_avg:94.45ms +step:428/1705 train_time:40424ms step_avg:94.45ms +step:429/1705 train_time:40516ms step_avg:94.44ms +step:430/1705 train_time:40608ms step_avg:94.44ms +step:431/1705 train_time:40700ms step_avg:94.43ms +step:432/1705 train_time:40792ms step_avg:94.43ms +step:433/1705 train_time:40884ms step_avg:94.42ms +step:434/1705 train_time:40976ms step_avg:94.41ms +step:435/1705 train_time:41068ms step_avg:94.41ms +step:436/1705 train_time:41165ms step_avg:94.42ms +step:437/1705 train_time:41262ms step_avg:94.42ms +step:438/1705 train_time:41356ms step_avg:94.42ms +step:439/1705 train_time:41448ms step_avg:94.42ms +step:440/1705 train_time:41542ms step_avg:94.41ms +step:441/1705 train_time:41635ms step_avg:94.41ms +step:442/1705 train_time:41726ms step_avg:94.40ms +step:443/1705 train_time:41819ms step_avg:94.40ms +step:444/1705 train_time:41910ms step_avg:94.39ms +step:445/1705 train_time:42003ms step_avg:94.39ms +step:446/1705 train_time:42096ms step_avg:94.39ms +step:447/1705 train_time:42190ms step_avg:94.38ms +step:448/1705 train_time:42284ms step_avg:94.38ms +step:449/1705 train_time:42379ms step_avg:94.38ms +step:450/1705 train_time:42472ms step_avg:94.38ms +step:451/1705 train_time:42565ms step_avg:94.38ms +step:452/1705 train_time:42658ms step_avg:94.38ms +step:453/1705 train_time:42751ms step_avg:94.37ms +step:454/1705 train_time:42844ms step_avg:94.37ms +step:455/1705 train_time:42937ms step_avg:94.37ms +step:456/1705 train_time:43030ms step_avg:94.36ms +step:457/1705 train_time:43123ms step_avg:94.36ms +step:458/1705 train_time:43217ms step_avg:94.36ms +step:459/1705 train_time:43310ms step_avg:94.36ms +step:460/1705 train_time:43404ms step_avg:94.36ms +step:461/1705 train_time:43497ms step_avg:94.35ms +step:462/1705 train_time:43590ms step_avg:94.35ms +step:463/1705 train_time:43684ms step_avg:94.35ms +step:464/1705 train_time:43776ms step_avg:94.34ms +step:465/1705 train_time:43868ms step_avg:94.34ms +step:466/1705 train_time:43961ms step_avg:94.34ms +step:467/1705 train_time:44054ms step_avg:94.33ms +step:468/1705 train_time:44148ms step_avg:94.33ms +step:469/1705 train_time:44242ms step_avg:94.33ms +step:470/1705 train_time:44334ms step_avg:94.33ms +step:471/1705 train_time:44428ms step_avg:94.33ms +step:472/1705 train_time:44522ms step_avg:94.33ms +step:473/1705 train_time:44615ms step_avg:94.32ms +step:474/1705 train_time:44707ms step_avg:94.32ms +step:475/1705 train_time:44800ms step_avg:94.32ms +step:476/1705 train_time:44894ms step_avg:94.31ms +step:477/1705 train_time:44986ms step_avg:94.31ms +step:478/1705 train_time:45079ms step_avg:94.31ms +step:479/1705 train_time:45172ms step_avg:94.30ms +step:480/1705 train_time:45265ms step_avg:94.30ms +step:481/1705 train_time:45358ms step_avg:94.30ms +step:482/1705 train_time:45450ms step_avg:94.30ms +step:483/1705 train_time:45544ms step_avg:94.29ms +step:484/1705 train_time:45638ms step_avg:94.29ms +step:485/1705 train_time:45730ms step_avg:94.29ms +step:486/1705 train_time:45823ms step_avg:94.29ms +step:487/1705 train_time:45916ms step_avg:94.28ms +step:488/1705 train_time:46008ms step_avg:94.28ms +step:489/1705 train_time:46101ms step_avg:94.28ms +step:490/1705 train_time:46194ms step_avg:94.27ms +step:491/1705 train_time:46286ms step_avg:94.27ms +step:492/1705 
train_time:46380ms step_avg:94.27ms +step:493/1705 train_time:46472ms step_avg:94.26ms +step:494/1705 train_time:46565ms step_avg:94.26ms +step:495/1705 train_time:46658ms step_avg:94.26ms +step:496/1705 train_time:46751ms step_avg:94.26ms +step:497/1705 train_time:46844ms step_avg:94.25ms +step:498/1705 train_time:46937ms step_avg:94.25ms +step:499/1705 train_time:47030ms step_avg:94.25ms +step:500/1705 train_time:47123ms step_avg:94.25ms +step:500/1705 val_loss:3.7215 train_time:47217ms step_avg:94.43ms +step:501/1705 train_time:47239ms step_avg:94.29ms +step:502/1705 train_time:47314ms step_avg:94.25ms +step:503/1705 train_time:47412ms step_avg:94.26ms +step:504/1705 train_time:47506ms step_avg:94.26ms +step:505/1705 train_time:47598ms step_avg:94.25ms +step:506/1705 train_time:47689ms step_avg:94.25ms +step:507/1705 train_time:47781ms step_avg:94.24ms +step:508/1705 train_time:47874ms step_avg:94.24ms +step:509/1705 train_time:47966ms step_avg:94.24ms +step:510/1705 train_time:48058ms step_avg:94.23ms +step:511/1705 train_time:48151ms step_avg:94.23ms +step:512/1705 train_time:48245ms step_avg:94.23ms +step:513/1705 train_time:48339ms step_avg:94.23ms +step:514/1705 train_time:48434ms step_avg:94.23ms +step:515/1705 train_time:48528ms step_avg:94.23ms +step:516/1705 train_time:48621ms step_avg:94.23ms +step:517/1705 train_time:48713ms step_avg:94.22ms +step:518/1705 train_time:48807ms step_avg:94.22ms +step:519/1705 train_time:48899ms step_avg:94.22ms +step:520/1705 train_time:48991ms step_avg:94.21ms +step:521/1705 train_time:49084ms step_avg:94.21ms +step:522/1705 train_time:49177ms step_avg:94.21ms +step:523/1705 train_time:49271ms step_avg:94.21ms +step:524/1705 train_time:49364ms step_avg:94.21ms +step:525/1705 train_time:49457ms step_avg:94.20ms +step:526/1705 train_time:49551ms step_avg:94.20ms +step:527/1705 train_time:49644ms step_avg:94.20ms +step:528/1705 train_time:49736ms step_avg:94.20ms +step:529/1705 train_time:49829ms step_avg:94.20ms +step:530/1705 train_time:49922ms step_avg:94.19ms +step:531/1705 train_time:50014ms step_avg:94.19ms +step:532/1705 train_time:50108ms step_avg:94.19ms +step:533/1705 train_time:50201ms step_avg:94.19ms +step:534/1705 train_time:50294ms step_avg:94.18ms +step:535/1705 train_time:50387ms step_avg:94.18ms +step:536/1705 train_time:50480ms step_avg:94.18ms +step:537/1705 train_time:50573ms step_avg:94.18ms +step:538/1705 train_time:50667ms step_avg:94.18ms +step:539/1705 train_time:50760ms step_avg:94.17ms +step:540/1705 train_time:50852ms step_avg:94.17ms +step:541/1705 train_time:50945ms step_avg:94.17ms +step:542/1705 train_time:51038ms step_avg:94.17ms +step:543/1705 train_time:51131ms step_avg:94.16ms +step:544/1705 train_time:51224ms step_avg:94.16ms +step:545/1705 train_time:51317ms step_avg:94.16ms +step:546/1705 train_time:51410ms step_avg:94.16ms +step:547/1705 train_time:51503ms step_avg:94.16ms +step:548/1705 train_time:51596ms step_avg:94.15ms +step:549/1705 train_time:51690ms step_avg:94.15ms +step:550/1705 train_time:51783ms step_avg:94.15ms +step:551/1705 train_time:51875ms step_avg:94.15ms +step:552/1705 train_time:51969ms step_avg:94.15ms +step:553/1705 train_time:52062ms step_avg:94.14ms +step:554/1705 train_time:52154ms step_avg:94.14ms +step:555/1705 train_time:52248ms step_avg:94.14ms +step:556/1705 train_time:52341ms step_avg:94.14ms +step:557/1705 train_time:52435ms step_avg:94.14ms +step:558/1705 train_time:52529ms step_avg:94.14ms +step:559/1705 train_time:52622ms step_avg:94.14ms +step:560/1705 train_time:52715ms 
step_avg:94.13ms +step:561/1705 train_time:52809ms step_avg:94.13ms +step:562/1705 train_time:52902ms step_avg:94.13ms +step:563/1705 train_time:52994ms step_avg:94.13ms +step:564/1705 train_time:53087ms step_avg:94.13ms +step:565/1705 train_time:53180ms step_avg:94.12ms +step:566/1705 train_time:53273ms step_avg:94.12ms +step:567/1705 train_time:53367ms step_avg:94.12ms +step:568/1705 train_time:53460ms step_avg:94.12ms +step:569/1705 train_time:53553ms step_avg:94.12ms +step:570/1705 train_time:53646ms step_avg:94.12ms +step:571/1705 train_time:53741ms step_avg:94.12ms +step:572/1705 train_time:53835ms step_avg:94.12ms +step:573/1705 train_time:53930ms step_avg:94.12ms +step:574/1705 train_time:54023ms step_avg:94.12ms +step:575/1705 train_time:54117ms step_avg:94.12ms +step:576/1705 train_time:54212ms step_avg:94.12ms +step:577/1705 train_time:54306ms step_avg:94.12ms +step:578/1705 train_time:54401ms step_avg:94.12ms +step:579/1705 train_time:54495ms step_avg:94.12ms +step:580/1705 train_time:54589ms step_avg:94.12ms +step:581/1705 train_time:54684ms step_avg:94.12ms +step:582/1705 train_time:54777ms step_avg:94.12ms +step:583/1705 train_time:54872ms step_avg:94.12ms +step:584/1705 train_time:54966ms step_avg:94.12ms +step:585/1705 train_time:55060ms step_avg:94.12ms +step:586/1705 train_time:55154ms step_avg:94.12ms +step:587/1705 train_time:55249ms step_avg:94.12ms +step:588/1705 train_time:55343ms step_avg:94.12ms +step:589/1705 train_time:55438ms step_avg:94.12ms +step:590/1705 train_time:55532ms step_avg:94.12ms +step:591/1705 train_time:55627ms step_avg:94.12ms +step:592/1705 train_time:55721ms step_avg:94.12ms +step:593/1705 train_time:55815ms step_avg:94.12ms +step:594/1705 train_time:55910ms step_avg:94.13ms +step:595/1705 train_time:56005ms step_avg:94.13ms +step:596/1705 train_time:56098ms step_avg:94.12ms +step:597/1705 train_time:56193ms step_avg:94.13ms +step:598/1705 train_time:56287ms step_avg:94.13ms +step:599/1705 train_time:56382ms step_avg:94.13ms +step:600/1705 train_time:56476ms step_avg:94.13ms +step:601/1705 train_time:56571ms step_avg:94.13ms +step:602/1705 train_time:56666ms step_avg:94.13ms +step:603/1705 train_time:56760ms step_avg:94.13ms +step:604/1705 train_time:56854ms step_avg:94.13ms +step:605/1705 train_time:56949ms step_avg:94.13ms +step:606/1705 train_time:57044ms step_avg:94.13ms +step:607/1705 train_time:57138ms step_avg:94.13ms +step:608/1705 train_time:57232ms step_avg:94.13ms +step:609/1705 train_time:57327ms step_avg:94.13ms +step:610/1705 train_time:57421ms step_avg:94.13ms +step:611/1705 train_time:57516ms step_avg:94.13ms +step:612/1705 train_time:57611ms step_avg:94.13ms +step:613/1705 train_time:57706ms step_avg:94.14ms +step:614/1705 train_time:57800ms step_avg:94.14ms +step:615/1705 train_time:57894ms step_avg:94.14ms +step:616/1705 train_time:57988ms step_avg:94.14ms +step:617/1705 train_time:58083ms step_avg:94.14ms +step:618/1705 train_time:58177ms step_avg:94.14ms +step:619/1705 train_time:58272ms step_avg:94.14ms +step:620/1705 train_time:58366ms step_avg:94.14ms +step:621/1705 train_time:58461ms step_avg:94.14ms +step:622/1705 train_time:58554ms step_avg:94.14ms +step:623/1705 train_time:58648ms step_avg:94.14ms +step:624/1705 train_time:58742ms step_avg:94.14ms +step:625/1705 train_time:58836ms step_avg:94.14ms +step:625/1705 val_loss:3.6203 train_time:58932ms step_avg:94.29ms +step:626/1705 train_time:58955ms step_avg:94.18ms +step:627/1705 train_time:59038ms step_avg:94.16ms +step:628/1705 train_time:59137ms step_avg:94.17ms 
+step:629/1705 train_time:59231ms step_avg:94.17ms +step:630/1705 train_time:59325ms step_avg:94.17ms +step:631/1705 train_time:59418ms step_avg:94.16ms +step:632/1705 train_time:59511ms step_avg:94.16ms +step:633/1705 train_time:59604ms step_avg:94.16ms +step:634/1705 train_time:59697ms step_avg:94.16ms +step:635/1705 train_time:59790ms step_avg:94.16ms +step:636/1705 train_time:59885ms step_avg:94.16ms +step:637/1705 train_time:59980ms step_avg:94.16ms +step:638/1705 train_time:60078ms step_avg:94.17ms +step:639/1705 train_time:60438ms step_avg:94.58ms +step:640/1705 train_time:60529ms step_avg:94.58ms +step:641/1705 train_time:60621ms step_avg:94.57ms +step:642/1705 train_time:60715ms step_avg:94.57ms +step:643/1705 train_time:60809ms step_avg:94.57ms +step:644/1705 train_time:60902ms step_avg:94.57ms +step:645/1705 train_time:60995ms step_avg:94.57ms +step:646/1705 train_time:61088ms step_avg:94.56ms +step:647/1705 train_time:61181ms step_avg:94.56ms +step:648/1705 train_time:61274ms step_avg:94.56ms +step:649/1705 train_time:61372ms step_avg:94.56ms +step:650/1705 train_time:61469ms step_avg:94.57ms +step:651/1705 train_time:61565ms step_avg:94.57ms +step:652/1705 train_time:61660ms step_avg:94.57ms +step:653/1705 train_time:61754ms step_avg:94.57ms +step:654/1705 train_time:61848ms step_avg:94.57ms +step:655/1705 train_time:61941ms step_avg:94.57ms +step:656/1705 train_time:62035ms step_avg:94.57ms +step:657/1705 train_time:62129ms step_avg:94.56ms +step:658/1705 train_time:62222ms step_avg:94.56ms +step:659/1705 train_time:62316ms step_avg:94.56ms +step:660/1705 train_time:62413ms step_avg:94.56ms +step:661/1705 train_time:62508ms step_avg:94.57ms +step:662/1705 train_time:62602ms step_avg:94.57ms +step:663/1705 train_time:62697ms step_avg:94.57ms +step:664/1705 train_time:62792ms step_avg:94.57ms +step:665/1705 train_time:62886ms step_avg:94.57ms +step:666/1705 train_time:62979ms step_avg:94.56ms +step:667/1705 train_time:63073ms step_avg:94.56ms +step:668/1705 train_time:63166ms step_avg:94.56ms +step:669/1705 train_time:63260ms step_avg:94.56ms +step:670/1705 train_time:63355ms step_avg:94.56ms +step:671/1705 train_time:63450ms step_avg:94.56ms +step:672/1705 train_time:63545ms step_avg:94.56ms +step:673/1705 train_time:63639ms step_avg:94.56ms +step:674/1705 train_time:63733ms step_avg:94.56ms +step:675/1705 train_time:63828ms step_avg:94.56ms +step:676/1705 train_time:63922ms step_avg:94.56ms +step:677/1705 train_time:64016ms step_avg:94.56ms +step:678/1705 train_time:64110ms step_avg:94.56ms +step:679/1705 train_time:64204ms step_avg:94.56ms +step:680/1705 train_time:64298ms step_avg:94.56ms +step:681/1705 train_time:64392ms step_avg:94.56ms +step:682/1705 train_time:64488ms step_avg:94.56ms +step:683/1705 train_time:64585ms step_avg:94.56ms +step:684/1705 train_time:64677ms step_avg:94.56ms +step:685/1705 train_time:64772ms step_avg:94.56ms +step:686/1705 train_time:64868ms step_avg:94.56ms +step:687/1705 train_time:64962ms step_avg:94.56ms +step:688/1705 train_time:65055ms step_avg:94.56ms +step:689/1705 train_time:65150ms step_avg:94.56ms +step:690/1705 train_time:65245ms step_avg:94.56ms +step:691/1705 train_time:65339ms step_avg:94.56ms +step:692/1705 train_time:65434ms step_avg:94.56ms +step:693/1705 train_time:65529ms step_avg:94.56ms +step:694/1705 train_time:65623ms step_avg:94.56ms +step:695/1705 train_time:65717ms step_avg:94.56ms +step:696/1705 train_time:65813ms step_avg:94.56ms +step:697/1705 train_time:65907ms step_avg:94.56ms +step:698/1705 train_time:66000ms 
step_avg:94.56ms +step:699/1705 train_time:66094ms step_avg:94.56ms +step:700/1705 train_time:66190ms step_avg:94.56ms +step:701/1705 train_time:66283ms step_avg:94.56ms +step:702/1705 train_time:66378ms step_avg:94.56ms +step:703/1705 train_time:66473ms step_avg:94.56ms +step:704/1705 train_time:66567ms step_avg:94.56ms +step:705/1705 train_time:66661ms step_avg:94.55ms +step:706/1705 train_time:66755ms step_avg:94.55ms +step:707/1705 train_time:66850ms step_avg:94.55ms +step:708/1705 train_time:66944ms step_avg:94.55ms +step:709/1705 train_time:67037ms step_avg:94.55ms +step:710/1705 train_time:67132ms step_avg:94.55ms +step:711/1705 train_time:67226ms step_avg:94.55ms +step:712/1705 train_time:67320ms step_avg:94.55ms +step:713/1705 train_time:67414ms step_avg:94.55ms +step:714/1705 train_time:67510ms step_avg:94.55ms +step:715/1705 train_time:67604ms step_avg:94.55ms +step:716/1705 train_time:67698ms step_avg:94.55ms +step:717/1705 train_time:67793ms step_avg:94.55ms +step:718/1705 train_time:67889ms step_avg:94.55ms +step:719/1705 train_time:67983ms step_avg:94.55ms +step:720/1705 train_time:68077ms step_avg:94.55ms +step:721/1705 train_time:68172ms step_avg:94.55ms +step:722/1705 train_time:68266ms step_avg:94.55ms +step:723/1705 train_time:68359ms step_avg:94.55ms +step:724/1705 train_time:68454ms step_avg:94.55ms +step:725/1705 train_time:68549ms step_avg:94.55ms +step:726/1705 train_time:68643ms step_avg:94.55ms +step:727/1705 train_time:68738ms step_avg:94.55ms +step:728/1705 train_time:68833ms step_avg:94.55ms +step:729/1705 train_time:68928ms step_avg:94.55ms +step:730/1705 train_time:69023ms step_avg:94.55ms +step:731/1705 train_time:69117ms step_avg:94.55ms +step:732/1705 train_time:69212ms step_avg:94.55ms +step:733/1705 train_time:69306ms step_avg:94.55ms +step:734/1705 train_time:69400ms step_avg:94.55ms +step:735/1705 train_time:69495ms step_avg:94.55ms +step:736/1705 train_time:69591ms step_avg:94.55ms +step:737/1705 train_time:69685ms step_avg:94.55ms +step:738/1705 train_time:69779ms step_avg:94.55ms +step:739/1705 train_time:69874ms step_avg:94.55ms +step:740/1705 train_time:69969ms step_avg:94.55ms +step:741/1705 train_time:70063ms step_avg:94.55ms +step:742/1705 train_time:70157ms step_avg:94.55ms +step:743/1705 train_time:70252ms step_avg:94.55ms +step:744/1705 train_time:70346ms step_avg:94.55ms +step:745/1705 train_time:70440ms step_avg:94.55ms +step:746/1705 train_time:70535ms step_avg:94.55ms +step:747/1705 train_time:70630ms step_avg:94.55ms +step:748/1705 train_time:70724ms step_avg:94.55ms +step:749/1705 train_time:70817ms step_avg:94.55ms +step:750/1705 train_time:70912ms step_avg:94.55ms +step:750/1705 val_loss:3.5658 train_time:71007ms step_avg:94.68ms +step:751/1705 train_time:71029ms step_avg:94.58ms +step:752/1705 train_time:71106ms step_avg:94.56ms +step:753/1705 train_time:71203ms step_avg:94.56ms +step:754/1705 train_time:71299ms step_avg:94.56ms +step:755/1705 train_time:71393ms step_avg:94.56ms +step:756/1705 train_time:71487ms step_avg:94.56ms +step:757/1705 train_time:71580ms step_avg:94.56ms +step:758/1705 train_time:71674ms step_avg:94.56ms +step:759/1705 train_time:71767ms step_avg:94.55ms +step:760/1705 train_time:71860ms step_avg:94.55ms +step:761/1705 train_time:71955ms step_avg:94.55ms +step:762/1705 train_time:72051ms step_avg:94.56ms +step:763/1705 train_time:72148ms step_avg:94.56ms +step:764/1705 train_time:72242ms step_avg:94.56ms +step:765/1705 train_time:72337ms step_avg:94.56ms +step:766/1705 train_time:72432ms step_avg:94.56ms 
+step:767/1705 train_time:72526ms step_avg:94.56ms +step:768/1705 train_time:72620ms step_avg:94.56ms +step:769/1705 train_time:72715ms step_avg:94.56ms +step:770/1705 train_time:72809ms step_avg:94.56ms +step:771/1705 train_time:72902ms step_avg:94.55ms +step:772/1705 train_time:72997ms step_avg:94.56ms +step:773/1705 train_time:73093ms step_avg:94.56ms +step:774/1705 train_time:73188ms step_avg:94.56ms +step:775/1705 train_time:73282ms step_avg:94.56ms +step:776/1705 train_time:73376ms step_avg:94.56ms +step:777/1705 train_time:73472ms step_avg:94.56ms +step:778/1705 train_time:73565ms step_avg:94.56ms +step:779/1705 train_time:73659ms step_avg:94.56ms +step:780/1705 train_time:73754ms step_avg:94.56ms +step:781/1705 train_time:73848ms step_avg:94.56ms +step:782/1705 train_time:73942ms step_avg:94.55ms +step:783/1705 train_time:74038ms step_avg:94.56ms +step:784/1705 train_time:74133ms step_avg:94.56ms +step:785/1705 train_time:74228ms step_avg:94.56ms +step:786/1705 train_time:74322ms step_avg:94.56ms +step:787/1705 train_time:74418ms step_avg:94.56ms +step:788/1705 train_time:74512ms step_avg:94.56ms +step:789/1705 train_time:74606ms step_avg:94.56ms +step:790/1705 train_time:74699ms step_avg:94.56ms +step:791/1705 train_time:74795ms step_avg:94.56ms +step:792/1705 train_time:74889ms step_avg:94.56ms +step:793/1705 train_time:74983ms step_avg:94.56ms +step:794/1705 train_time:75077ms step_avg:94.56ms +step:795/1705 train_time:75172ms step_avg:94.56ms +step:796/1705 train_time:75267ms step_avg:94.56ms +step:797/1705 train_time:75361ms step_avg:94.56ms +step:798/1705 train_time:75456ms step_avg:94.56ms +step:799/1705 train_time:75550ms step_avg:94.56ms +step:800/1705 train_time:75644ms step_avg:94.56ms +step:801/1705 train_time:75739ms step_avg:94.56ms +step:802/1705 train_time:75833ms step_avg:94.55ms +step:803/1705 train_time:75928ms step_avg:94.55ms +step:804/1705 train_time:76022ms step_avg:94.55ms +step:805/1705 train_time:76116ms step_avg:94.55ms +step:806/1705 train_time:76211ms step_avg:94.55ms +step:807/1705 train_time:76305ms step_avg:94.55ms +step:808/1705 train_time:76400ms step_avg:94.55ms +step:809/1705 train_time:76496ms step_avg:94.56ms +step:810/1705 train_time:76590ms step_avg:94.56ms +step:811/1705 train_time:76684ms step_avg:94.55ms +step:812/1705 train_time:76778ms step_avg:94.55ms +step:813/1705 train_time:76873ms step_avg:94.55ms +step:814/1705 train_time:76967ms step_avg:94.55ms +step:815/1705 train_time:77061ms step_avg:94.55ms +step:816/1705 train_time:77156ms step_avg:94.55ms +step:817/1705 train_time:77250ms step_avg:94.55ms +step:818/1705 train_time:77344ms step_avg:94.55ms +step:819/1705 train_time:77438ms step_avg:94.55ms +step:820/1705 train_time:77533ms step_avg:94.55ms +step:821/1705 train_time:77628ms step_avg:94.55ms +step:822/1705 train_time:77722ms step_avg:94.55ms +step:823/1705 train_time:77816ms step_avg:94.55ms +step:824/1705 train_time:77911ms step_avg:94.55ms +step:825/1705 train_time:78004ms step_avg:94.55ms +step:826/1705 train_time:78098ms step_avg:94.55ms +step:827/1705 train_time:78193ms step_avg:94.55ms +step:828/1705 train_time:78288ms step_avg:94.55ms +step:829/1705 train_time:78382ms step_avg:94.55ms +step:830/1705 train_time:78477ms step_avg:94.55ms +step:831/1705 train_time:78573ms step_avg:94.55ms +step:832/1705 train_time:78667ms step_avg:94.55ms +step:833/1705 train_time:78761ms step_avg:94.55ms +step:834/1705 train_time:78856ms step_avg:94.55ms +step:835/1705 train_time:78951ms step_avg:94.55ms +step:836/1705 train_time:79045ms 
step_avg:94.55ms +step:837/1705 train_time:79140ms step_avg:94.55ms +step:838/1705 train_time:79234ms step_avg:94.55ms +step:839/1705 train_time:79329ms step_avg:94.55ms +step:840/1705 train_time:79423ms step_avg:94.55ms +step:841/1705 train_time:79518ms step_avg:94.55ms +step:842/1705 train_time:79613ms step_avg:94.55ms +step:843/1705 train_time:79707ms step_avg:94.55ms +step:844/1705 train_time:79801ms step_avg:94.55ms +step:845/1705 train_time:79896ms step_avg:94.55ms +step:846/1705 train_time:79990ms step_avg:94.55ms +step:847/1705 train_time:80083ms step_avg:94.55ms +step:848/1705 train_time:80178ms step_avg:94.55ms +step:849/1705 train_time:80273ms step_avg:94.55ms +step:850/1705 train_time:80368ms step_avg:94.55ms +step:851/1705 train_time:80626ms step_avg:94.74ms +step:852/1705 train_time:80734ms step_avg:94.76ms +step:853/1705 train_time:80826ms step_avg:94.75ms +step:854/1705 train_time:80919ms step_avg:94.75ms +step:855/1705 train_time:81012ms step_avg:94.75ms +step:856/1705 train_time:81105ms step_avg:94.75ms +step:857/1705 train_time:81199ms step_avg:94.75ms +step:858/1705 train_time:81292ms step_avg:94.75ms +step:859/1705 train_time:81386ms step_avg:94.75ms +step:860/1705 train_time:81480ms step_avg:94.74ms +step:861/1705 train_time:81576ms step_avg:94.75ms +step:862/1705 train_time:81675ms step_avg:94.75ms +step:863/1705 train_time:81774ms step_avg:94.76ms +step:864/1705 train_time:81868ms step_avg:94.75ms +step:865/1705 train_time:81962ms step_avg:94.75ms +step:866/1705 train_time:82056ms step_avg:94.75ms +step:867/1705 train_time:82149ms step_avg:94.75ms +step:868/1705 train_time:82242ms step_avg:94.75ms +step:869/1705 train_time:82336ms step_avg:94.75ms +step:870/1705 train_time:82429ms step_avg:94.75ms +step:871/1705 train_time:82523ms step_avg:94.74ms +step:872/1705 train_time:82620ms step_avg:94.75ms +step:873/1705 train_time:82718ms step_avg:94.75ms +step:874/1705 train_time:82815ms step_avg:94.75ms +step:875/1705 train_time:82910ms step_avg:94.75ms +step:875/1705 val_loss:3.5261 train_time:83004ms step_avg:94.86ms +step:876/1705 train_time:83025ms step_avg:94.78ms +step:877/1705 train_time:83103ms step_avg:94.76ms +step:878/1705 train_time:83204ms step_avg:94.76ms +step:879/1705 train_time:83298ms step_avg:94.76ms +step:880/1705 train_time:83392ms step_avg:94.76ms +step:881/1705 train_time:83484ms step_avg:94.76ms +step:882/1705 train_time:83578ms step_avg:94.76ms +step:883/1705 train_time:83672ms step_avg:94.76ms +step:884/1705 train_time:83764ms step_avg:94.76ms +step:885/1705 train_time:83858ms step_avg:94.75ms +step:886/1705 train_time:83952ms step_avg:94.75ms +step:887/1705 train_time:84048ms step_avg:94.75ms +step:888/1705 train_time:84145ms step_avg:94.76ms +step:889/1705 train_time:84242ms step_avg:94.76ms +step:890/1705 train_time:84337ms step_avg:94.76ms +step:891/1705 train_time:84431ms step_avg:94.76ms +step:892/1705 train_time:84525ms step_avg:94.76ms +step:893/1705 train_time:84620ms step_avg:94.76ms +step:894/1705 train_time:84713ms step_avg:94.76ms +step:895/1705 train_time:84806ms step_avg:94.76ms +step:896/1705 train_time:84900ms step_avg:94.75ms +step:897/1705 train_time:84995ms step_avg:94.76ms +step:898/1705 train_time:85091ms step_avg:94.76ms +step:899/1705 train_time:85185ms step_avg:94.76ms +step:900/1705 train_time:85280ms step_avg:94.76ms +step:901/1705 train_time:85375ms step_avg:94.76ms +step:902/1705 train_time:85470ms step_avg:94.76ms +step:903/1705 train_time:85564ms step_avg:94.75ms +step:904/1705 train_time:85658ms step_avg:94.75ms 
+step:905/1705 train_time:85751ms step_avg:94.75ms +step:906/1705 train_time:85844ms step_avg:94.75ms +step:907/1705 train_time:85938ms step_avg:94.75ms +step:908/1705 train_time:86034ms step_avg:94.75ms +step:909/1705 train_time:86128ms step_avg:94.75ms +step:910/1705 train_time:86223ms step_avg:94.75ms +step:911/1705 train_time:86318ms step_avg:94.75ms +step:912/1705 train_time:86414ms step_avg:94.75ms +step:913/1705 train_time:86508ms step_avg:94.75ms +step:914/1705 train_time:86603ms step_avg:94.75ms +step:915/1705 train_time:86698ms step_avg:94.75ms +step:916/1705 train_time:86792ms step_avg:94.75ms +step:917/1705 train_time:86886ms step_avg:94.75ms +step:918/1705 train_time:86980ms step_avg:94.75ms +step:919/1705 train_time:87076ms step_avg:94.75ms +step:920/1705 train_time:87171ms step_avg:94.75ms +step:921/1705 train_time:87265ms step_avg:94.75ms +step:922/1705 train_time:87360ms step_avg:94.75ms +step:923/1705 train_time:87455ms step_avg:94.75ms +step:924/1705 train_time:87550ms step_avg:94.75ms +step:925/1705 train_time:87644ms step_avg:94.75ms +step:926/1705 train_time:87739ms step_avg:94.75ms +step:927/1705 train_time:87833ms step_avg:94.75ms +step:928/1705 train_time:87927ms step_avg:94.75ms +step:929/1705 train_time:88021ms step_avg:94.75ms +step:930/1705 train_time:88117ms step_avg:94.75ms +step:931/1705 train_time:88212ms step_avg:94.75ms +step:932/1705 train_time:88306ms step_avg:94.75ms +step:933/1705 train_time:88400ms step_avg:94.75ms +step:934/1705 train_time:88496ms step_avg:94.75ms +step:935/1705 train_time:88591ms step_avg:94.75ms +step:936/1705 train_time:88685ms step_avg:94.75ms +step:937/1705 train_time:88779ms step_avg:94.75ms +step:938/1705 train_time:88874ms step_avg:94.75ms +step:939/1705 train_time:88968ms step_avg:94.75ms +step:940/1705 train_time:89063ms step_avg:94.75ms +step:941/1705 train_time:89158ms step_avg:94.75ms +step:942/1705 train_time:89252ms step_avg:94.75ms +step:943/1705 train_time:89345ms step_avg:94.75ms +step:944/1705 train_time:89440ms step_avg:94.75ms +step:945/1705 train_time:89536ms step_avg:94.75ms +step:946/1705 train_time:89632ms step_avg:94.75ms +step:947/1705 train_time:89724ms step_avg:94.75ms +step:948/1705 train_time:89819ms step_avg:94.75ms +step:949/1705 train_time:89914ms step_avg:94.75ms +step:950/1705 train_time:90008ms step_avg:94.75ms +step:951/1705 train_time:90102ms step_avg:94.74ms +step:952/1705 train_time:90198ms step_avg:94.75ms +step:953/1705 train_time:90293ms step_avg:94.75ms +step:954/1705 train_time:90386ms step_avg:94.74ms +step:955/1705 train_time:90481ms step_avg:94.74ms +step:956/1705 train_time:90576ms step_avg:94.75ms +step:957/1705 train_time:90671ms step_avg:94.74ms +step:958/1705 train_time:90764ms step_avg:94.74ms +step:959/1705 train_time:90859ms step_avg:94.74ms +step:960/1705 train_time:90954ms step_avg:94.74ms +step:961/1705 train_time:91047ms step_avg:94.74ms +step:962/1705 train_time:91142ms step_avg:94.74ms +step:963/1705 train_time:91237ms step_avg:94.74ms +step:964/1705 train_time:91331ms step_avg:94.74ms +step:965/1705 train_time:91425ms step_avg:94.74ms +step:966/1705 train_time:91520ms step_avg:94.74ms +step:967/1705 train_time:91615ms step_avg:94.74ms +step:968/1705 train_time:91709ms step_avg:94.74ms +step:969/1705 train_time:91803ms step_avg:94.74ms +step:970/1705 train_time:91897ms step_avg:94.74ms +step:971/1705 train_time:91992ms step_avg:94.74ms +step:972/1705 train_time:92087ms step_avg:94.74ms +step:973/1705 train_time:92181ms step_avg:94.74ms +step:974/1705 train_time:92276ms 
step_avg:94.74ms +step:975/1705 train_time:92370ms step_avg:94.74ms +step:976/1705 train_time:92465ms step_avg:94.74ms +step:977/1705 train_time:92559ms step_avg:94.74ms +step:978/1705 train_time:92654ms step_avg:94.74ms +step:979/1705 train_time:92748ms step_avg:94.74ms +step:980/1705 train_time:92842ms step_avg:94.74ms +step:981/1705 train_time:92938ms step_avg:94.74ms +step:982/1705 train_time:93032ms step_avg:94.74ms +step:983/1705 train_time:93126ms step_avg:94.74ms +step:984/1705 train_time:93221ms step_avg:94.74ms +step:985/1705 train_time:93316ms step_avg:94.74ms +step:986/1705 train_time:93411ms step_avg:94.74ms +step:987/1705 train_time:93505ms step_avg:94.74ms +step:988/1705 train_time:93601ms step_avg:94.74ms +step:989/1705 train_time:93696ms step_avg:94.74ms +step:990/1705 train_time:93790ms step_avg:94.74ms +step:991/1705 train_time:93884ms step_avg:94.74ms +step:992/1705 train_time:93979ms step_avg:94.74ms +step:993/1705 train_time:94073ms step_avg:94.74ms +step:994/1705 train_time:94168ms step_avg:94.74ms +step:995/1705 train_time:94262ms step_avg:94.74ms +step:996/1705 train_time:94358ms step_avg:94.74ms +step:997/1705 train_time:94454ms step_avg:94.74ms +step:998/1705 train_time:94548ms step_avg:94.74ms +step:999/1705 train_time:94642ms step_avg:94.74ms +step:1000/1705 train_time:94737ms step_avg:94.74ms +step:1000/1705 val_loss:3.4867 train_time:94832ms step_avg:94.83ms +step:1001/1705 train_time:94853ms step_avg:94.76ms +step:1002/1705 train_time:94931ms step_avg:94.74ms +step:1003/1705 train_time:95027ms step_avg:94.74ms +step:1004/1705 train_time:95122ms step_avg:94.74ms +step:1005/1705 train_time:95216ms step_avg:94.74ms +step:1006/1705 train_time:95309ms step_avg:94.74ms +step:1007/1705 train_time:95403ms step_avg:94.74ms +step:1008/1705 train_time:95497ms step_avg:94.74ms +step:1009/1705 train_time:95590ms step_avg:94.74ms +step:1010/1705 train_time:95684ms step_avg:94.74ms +step:1011/1705 train_time:95779ms step_avg:94.74ms +step:1012/1705 train_time:95877ms step_avg:94.74ms +step:1013/1705 train_time:95975ms step_avg:94.74ms +step:1014/1705 train_time:96069ms step_avg:94.74ms +step:1015/1705 train_time:96163ms step_avg:94.74ms +step:1016/1705 train_time:96257ms step_avg:94.74ms +step:1017/1705 train_time:96351ms step_avg:94.74ms +step:1018/1705 train_time:96444ms step_avg:94.74ms +step:1019/1705 train_time:96538ms step_avg:94.74ms +step:1020/1705 train_time:96632ms step_avg:94.74ms +step:1021/1705 train_time:96725ms step_avg:94.74ms +step:1022/1705 train_time:96821ms step_avg:94.74ms +step:1023/1705 train_time:96917ms step_avg:94.74ms +step:1024/1705 train_time:97013ms step_avg:94.74ms +step:1025/1705 train_time:97107ms step_avg:94.74ms +step:1026/1705 train_time:97201ms step_avg:94.74ms +step:1027/1705 train_time:97297ms step_avg:94.74ms +step:1028/1705 train_time:97391ms step_avg:94.74ms +step:1029/1705 train_time:97485ms step_avg:94.74ms +step:1030/1705 train_time:97579ms step_avg:94.74ms +step:1031/1705 train_time:97673ms step_avg:94.74ms +step:1032/1705 train_time:97767ms step_avg:94.74ms +step:1033/1705 train_time:97863ms step_avg:94.74ms +step:1034/1705 train_time:97959ms step_avg:94.74ms +step:1035/1705 train_time:98055ms step_avg:94.74ms +step:1036/1705 train_time:98149ms step_avg:94.74ms +step:1037/1705 train_time:98243ms step_avg:94.74ms +step:1038/1705 train_time:98339ms step_avg:94.74ms +step:1039/1705 train_time:98432ms step_avg:94.74ms +step:1040/1705 train_time:98526ms step_avg:94.74ms +step:1041/1705 train_time:98620ms step_avg:94.74ms 
+step:1042/1705 train_time:98715ms step_avg:94.74ms +step:1043/1705 train_time:98809ms step_avg:94.74ms +step:1044/1705 train_time:98904ms step_avg:94.74ms +step:1045/1705 train_time:99000ms step_avg:94.74ms +step:1046/1705 train_time:99096ms step_avg:94.74ms +step:1047/1705 train_time:99190ms step_avg:94.74ms +step:1048/1705 train_time:99284ms step_avg:94.74ms +step:1049/1705 train_time:99379ms step_avg:94.74ms +step:1050/1705 train_time:99473ms step_avg:94.74ms +step:1051/1705 train_time:99566ms step_avg:94.73ms +step:1052/1705 train_time:99661ms step_avg:94.73ms +step:1053/1705 train_time:99755ms step_avg:94.73ms +step:1054/1705 train_time:99851ms step_avg:94.73ms +step:1055/1705 train_time:99945ms step_avg:94.73ms +step:1056/1705 train_time:100040ms step_avg:94.73ms +step:1057/1705 train_time:100136ms step_avg:94.74ms +step:1058/1705 train_time:100230ms step_avg:94.74ms +step:1059/1705 train_time:100324ms step_avg:94.73ms +step:1060/1705 train_time:100420ms step_avg:94.74ms +step:1061/1705 train_time:100515ms step_avg:94.74ms +step:1062/1705 train_time:100779ms step_avg:94.90ms +step:1063/1705 train_time:100964ms step_avg:94.98ms +step:1064/1705 train_time:101056ms step_avg:94.98ms +step:1065/1705 train_time:101150ms step_avg:94.98ms +step:1066/1705 train_time:101243ms step_avg:94.97ms +step:1067/1705 train_time:101337ms step_avg:94.97ms +step:1068/1705 train_time:101430ms step_avg:94.97ms +step:1069/1705 train_time:101524ms step_avg:94.97ms +step:1070/1705 train_time:101617ms step_avg:94.97ms +step:1071/1705 train_time:101711ms step_avg:94.97ms +step:1072/1705 train_time:101805ms step_avg:94.97ms +step:1073/1705 train_time:101905ms step_avg:94.97ms +step:1074/1705 train_time:102003ms step_avg:94.97ms +step:1075/1705 train_time:102099ms step_avg:94.98ms +step:1076/1705 train_time:102194ms step_avg:94.98ms +step:1077/1705 train_time:102288ms step_avg:94.98ms +step:1078/1705 train_time:102381ms step_avg:94.97ms +step:1079/1705 train_time:102477ms step_avg:94.97ms +step:1080/1705 train_time:102570ms step_avg:94.97ms +step:1081/1705 train_time:102663ms step_avg:94.97ms +step:1082/1705 train_time:102757ms step_avg:94.97ms +step:1083/1705 train_time:102853ms step_avg:94.97ms +step:1084/1705 train_time:102948ms step_avg:94.97ms +step:1085/1705 train_time:103043ms step_avg:94.97ms +step:1086/1705 train_time:103138ms step_avg:94.97ms +step:1087/1705 train_time:103233ms step_avg:94.97ms +step:1088/1705 train_time:103326ms step_avg:94.97ms +step:1089/1705 train_time:103420ms step_avg:94.97ms +step:1090/1705 train_time:103515ms step_avg:94.97ms +step:1091/1705 train_time:103609ms step_avg:94.97ms +step:1092/1705 train_time:103703ms step_avg:94.97ms +step:1093/1705 train_time:103798ms step_avg:94.97ms +step:1094/1705 train_time:103893ms step_avg:94.97ms +step:1095/1705 train_time:103988ms step_avg:94.97ms +step:1096/1705 train_time:104082ms step_avg:94.97ms +step:1097/1705 train_time:104177ms step_avg:94.97ms +step:1098/1705 train_time:104271ms step_avg:94.96ms +step:1099/1705 train_time:104365ms step_avg:94.96ms +step:1100/1705 train_time:104459ms step_avg:94.96ms +step:1101/1705 train_time:104554ms step_avg:94.96ms +step:1102/1705 train_time:104648ms step_avg:94.96ms +step:1103/1705 train_time:104741ms step_avg:94.96ms +step:1104/1705 train_time:104836ms step_avg:94.96ms +step:1105/1705 train_time:104931ms step_avg:94.96ms +step:1106/1705 train_time:105026ms step_avg:94.96ms +step:1107/1705 train_time:105121ms step_avg:94.96ms +step:1108/1705 train_time:105216ms step_avg:94.96ms +step:1109/1705 
train_time:105310ms step_avg:94.96ms +step:1110/1705 train_time:105404ms step_avg:94.96ms +step:1111/1705 train_time:105499ms step_avg:94.96ms +step:1112/1705 train_time:105594ms step_avg:94.96ms +step:1113/1705 train_time:105687ms step_avg:94.96ms +step:1114/1705 train_time:105781ms step_avg:94.96ms +step:1115/1705 train_time:105877ms step_avg:94.96ms +step:1116/1705 train_time:105972ms step_avg:94.96ms +step:1117/1705 train_time:106066ms step_avg:94.96ms +step:1118/1705 train_time:106161ms step_avg:94.96ms +step:1119/1705 train_time:106256ms step_avg:94.96ms +step:1120/1705 train_time:106351ms step_avg:94.96ms +step:1121/1705 train_time:106444ms step_avg:94.95ms +step:1122/1705 train_time:106539ms step_avg:94.95ms +step:1123/1705 train_time:106633ms step_avg:94.95ms +step:1124/1705 train_time:106728ms step_avg:94.95ms +step:1125/1705 train_time:106822ms step_avg:94.95ms +step:1125/1705 val_loss:3.4396 train_time:106917ms step_avg:95.04ms +step:1126/1705 train_time:106938ms step_avg:94.97ms +step:1127/1705 train_time:107018ms step_avg:94.96ms +step:1128/1705 train_time:107116ms step_avg:94.96ms +step:1129/1705 train_time:107211ms step_avg:94.96ms +step:1130/1705 train_time:107304ms step_avg:94.96ms +step:1131/1705 train_time:107398ms step_avg:94.96ms +step:1132/1705 train_time:107492ms step_avg:94.96ms +step:1133/1705 train_time:107585ms step_avg:94.96ms +step:1134/1705 train_time:107679ms step_avg:94.95ms +step:1135/1705 train_time:107773ms step_avg:94.95ms +step:1136/1705 train_time:107867ms step_avg:94.95ms +step:1137/1705 train_time:107963ms step_avg:94.95ms +step:1138/1705 train_time:108060ms step_avg:94.96ms +step:1139/1705 train_time:108157ms step_avg:94.96ms +step:1140/1705 train_time:108252ms step_avg:94.96ms +step:1141/1705 train_time:108348ms step_avg:94.96ms +step:1142/1705 train_time:108441ms step_avg:94.96ms +step:1143/1705 train_time:108537ms step_avg:94.96ms +step:1144/1705 train_time:108631ms step_avg:94.96ms +step:1145/1705 train_time:108726ms step_avg:94.96ms +step:1146/1705 train_time:108821ms step_avg:94.96ms +step:1147/1705 train_time:108917ms step_avg:94.96ms +step:1148/1705 train_time:109014ms step_avg:94.96ms +step:1149/1705 train_time:109110ms step_avg:94.96ms +step:1150/1705 train_time:109205ms step_avg:94.96ms +step:1151/1705 train_time:109301ms step_avg:94.96ms +step:1152/1705 train_time:109396ms step_avg:94.96ms +step:1153/1705 train_time:109492ms step_avg:94.96ms +step:1154/1705 train_time:109585ms step_avg:94.96ms +step:1155/1705 train_time:109681ms step_avg:94.96ms +step:1156/1705 train_time:109776ms step_avg:94.96ms +step:1157/1705 train_time:109871ms step_avg:94.96ms +step:1158/1705 train_time:109966ms step_avg:94.96ms +step:1159/1705 train_time:110063ms step_avg:94.96ms +step:1160/1705 train_time:110159ms step_avg:94.96ms +step:1161/1705 train_time:110256ms step_avg:94.97ms +step:1162/1705 train_time:110351ms step_avg:94.97ms +step:1163/1705 train_time:110447ms step_avg:94.97ms +step:1164/1705 train_time:110541ms step_avg:94.97ms +step:1165/1705 train_time:110636ms step_avg:94.97ms +step:1166/1705 train_time:110732ms step_avg:94.97ms +step:1167/1705 train_time:110827ms step_avg:94.97ms +step:1168/1705 train_time:110922ms step_avg:94.97ms +step:1169/1705 train_time:111017ms step_avg:94.97ms +step:1170/1705 train_time:111112ms step_avg:94.97ms +step:1171/1705 train_time:111208ms step_avg:94.97ms +step:1172/1705 train_time:111303ms step_avg:94.97ms +step:1173/1705 train_time:111399ms step_avg:94.97ms +step:1174/1705 train_time:111496ms step_avg:94.97ms 
+step:1175/1705 train_time:111591ms step_avg:94.97ms +step:1176/1705 train_time:111686ms step_avg:94.97ms +step:1177/1705 train_time:111781ms step_avg:94.97ms +step:1178/1705 train_time:111876ms step_avg:94.97ms +step:1179/1705 train_time:111971ms step_avg:94.97ms +step:1180/1705 train_time:112066ms step_avg:94.97ms +step:1181/1705 train_time:112162ms step_avg:94.97ms +step:1182/1705 train_time:112258ms step_avg:94.97ms +step:1183/1705 train_time:112354ms step_avg:94.97ms +step:1184/1705 train_time:112449ms step_avg:94.97ms +step:1185/1705 train_time:112544ms step_avg:94.97ms +step:1186/1705 train_time:112639ms step_avg:94.97ms +step:1187/1705 train_time:112735ms step_avg:94.97ms +step:1188/1705 train_time:112831ms step_avg:94.98ms +step:1189/1705 train_time:112925ms step_avg:94.98ms +step:1190/1705 train_time:113021ms step_avg:94.98ms +step:1191/1705 train_time:113118ms step_avg:94.98ms +step:1192/1705 train_time:113214ms step_avg:94.98ms +step:1193/1705 train_time:113309ms step_avg:94.98ms +step:1194/1705 train_time:113404ms step_avg:94.98ms +step:1195/1705 train_time:113500ms step_avg:94.98ms +step:1196/1705 train_time:113595ms step_avg:94.98ms +step:1197/1705 train_time:113691ms step_avg:94.98ms +step:1198/1705 train_time:113785ms step_avg:94.98ms +step:1199/1705 train_time:113880ms step_avg:94.98ms +step:1200/1705 train_time:113976ms step_avg:94.98ms +step:1201/1705 train_time:114070ms step_avg:94.98ms +step:1202/1705 train_time:114165ms step_avg:94.98ms +step:1203/1705 train_time:114260ms step_avg:94.98ms +step:1204/1705 train_time:114356ms step_avg:94.98ms +step:1205/1705 train_time:114453ms step_avg:94.98ms +step:1206/1705 train_time:114548ms step_avg:94.98ms +step:1207/1705 train_time:114643ms step_avg:94.98ms +step:1208/1705 train_time:114738ms step_avg:94.98ms +step:1209/1705 train_time:114834ms step_avg:94.98ms +step:1210/1705 train_time:114929ms step_avg:94.98ms +step:1211/1705 train_time:115023ms step_avg:94.98ms +step:1212/1705 train_time:115119ms step_avg:94.98ms +step:1213/1705 train_time:115214ms step_avg:94.98ms +step:1214/1705 train_time:115310ms step_avg:94.98ms +step:1215/1705 train_time:115405ms step_avg:94.98ms +step:1216/1705 train_time:115500ms step_avg:94.98ms +step:1217/1705 train_time:115596ms step_avg:94.98ms +step:1218/1705 train_time:115692ms step_avg:94.99ms +step:1219/1705 train_time:115789ms step_avg:94.99ms +step:1220/1705 train_time:115883ms step_avg:94.99ms +step:1221/1705 train_time:115979ms step_avg:94.99ms +step:1222/1705 train_time:116074ms step_avg:94.99ms +step:1223/1705 train_time:116169ms step_avg:94.99ms +step:1224/1705 train_time:116265ms step_avg:94.99ms +step:1225/1705 train_time:116360ms step_avg:94.99ms +step:1226/1705 train_time:116456ms step_avg:94.99ms +step:1227/1705 train_time:116552ms step_avg:94.99ms +step:1228/1705 train_time:116647ms step_avg:94.99ms +step:1229/1705 train_time:116741ms step_avg:94.99ms +step:1230/1705 train_time:116837ms step_avg:94.99ms +step:1231/1705 train_time:116933ms step_avg:94.99ms +step:1232/1705 train_time:117028ms step_avg:94.99ms +step:1233/1705 train_time:117123ms step_avg:94.99ms +step:1234/1705 train_time:117219ms step_avg:94.99ms +step:1235/1705 train_time:117314ms step_avg:94.99ms +step:1236/1705 train_time:117409ms step_avg:94.99ms +step:1237/1705 train_time:117504ms step_avg:94.99ms +step:1238/1705 train_time:117600ms step_avg:94.99ms +step:1239/1705 train_time:117696ms step_avg:94.99ms +step:1240/1705 train_time:117792ms step_avg:94.99ms +step:1241/1705 train_time:117887ms step_avg:94.99ms 
+step:1242/1705 train_time:117982ms step_avg:94.99ms +step:1243/1705 train_time:118078ms step_avg:94.99ms +step:1244/1705 train_time:118174ms step_avg:94.99ms +step:1245/1705 train_time:118269ms step_avg:94.99ms +step:1246/1705 train_time:118363ms step_avg:94.99ms +step:1247/1705 train_time:118458ms step_avg:94.99ms +step:1248/1705 train_time:118554ms step_avg:95.00ms +step:1249/1705 train_time:118650ms step_avg:95.00ms +step:1250/1705 train_time:118745ms step_avg:95.00ms +step:1250/1705 val_loss:3.3903 train_time:118841ms step_avg:95.07ms +step:1251/1705 train_time:118864ms step_avg:95.02ms +step:1252/1705 train_time:118946ms step_avg:95.00ms +step:1253/1705 train_time:119041ms step_avg:95.00ms +step:1254/1705 train_time:119136ms step_avg:95.01ms +step:1255/1705 train_time:119231ms step_avg:95.00ms +step:1256/1705 train_time:119325ms step_avg:95.00ms +step:1257/1705 train_time:119419ms step_avg:95.00ms +step:1258/1705 train_time:119514ms step_avg:95.00ms +step:1259/1705 train_time:119607ms step_avg:95.00ms +step:1260/1705 train_time:119701ms step_avg:95.00ms +step:1261/1705 train_time:119797ms step_avg:95.00ms +step:1262/1705 train_time:119896ms step_avg:95.00ms +step:1263/1705 train_time:119995ms step_avg:95.01ms +step:1264/1705 train_time:120090ms step_avg:95.01ms +step:1265/1705 train_time:120185ms step_avg:95.01ms +step:1266/1705 train_time:120279ms step_avg:95.01ms +step:1267/1705 train_time:120373ms step_avg:95.01ms +step:1268/1705 train_time:120468ms step_avg:95.01ms +step:1269/1705 train_time:120562ms step_avg:95.01ms +step:1270/1705 train_time:120656ms step_avg:95.01ms +step:1271/1705 train_time:120751ms step_avg:95.00ms +step:1272/1705 train_time:120847ms step_avg:95.01ms +step:1273/1705 train_time:120943ms step_avg:95.01ms +step:1274/1705 train_time:121206ms step_avg:95.14ms +step:1275/1705 train_time:121398ms step_avg:95.21ms +step:1276/1705 train_time:121492ms step_avg:95.21ms +step:1277/1705 train_time:121586ms step_avg:95.21ms +step:1278/1705 train_time:121680ms step_avg:95.21ms +step:1279/1705 train_time:121775ms step_avg:95.21ms +step:1280/1705 train_time:121869ms step_avg:95.21ms +step:1281/1705 train_time:121962ms step_avg:95.21ms +step:1282/1705 train_time:122057ms step_avg:95.21ms +step:1283/1705 train_time:122151ms step_avg:95.21ms +step:1284/1705 train_time:122252ms step_avg:95.21ms +step:1285/1705 train_time:122352ms step_avg:95.22ms +step:1286/1705 train_time:122447ms step_avg:95.22ms +step:1287/1705 train_time:122542ms step_avg:95.21ms +step:1288/1705 train_time:122637ms step_avg:95.22ms +step:1289/1705 train_time:122731ms step_avg:95.21ms +step:1290/1705 train_time:122826ms step_avg:95.21ms +step:1291/1705 train_time:122921ms step_avg:95.21ms +step:1292/1705 train_time:123015ms step_avg:95.21ms +step:1293/1705 train_time:123110ms step_avg:95.21ms +step:1294/1705 train_time:123205ms step_avg:95.21ms +step:1295/1705 train_time:123302ms step_avg:95.21ms +step:1296/1705 train_time:123398ms step_avg:95.21ms +step:1297/1705 train_time:123495ms step_avg:95.22ms +step:1298/1705 train_time:123590ms step_avg:95.22ms +step:1299/1705 train_time:123684ms step_avg:95.22ms +step:1300/1705 train_time:123779ms step_avg:95.21ms +step:1301/1705 train_time:123875ms step_avg:95.22ms +step:1302/1705 train_time:123969ms step_avg:95.21ms +step:1303/1705 train_time:124064ms step_avg:95.21ms +step:1304/1705 train_time:124159ms step_avg:95.21ms +step:1305/1705 train_time:124256ms step_avg:95.22ms +step:1306/1705 train_time:124353ms step_avg:95.22ms +step:1307/1705 train_time:124450ms 
step_avg:95.22ms +step:1308/1705 train_time:124544ms step_avg:95.22ms +step:1309/1705 train_time:124639ms step_avg:95.22ms +step:1310/1705 train_time:124734ms step_avg:95.22ms +step:1311/1705 train_time:124829ms step_avg:95.22ms +step:1312/1705 train_time:124923ms step_avg:95.22ms +step:1313/1705 train_time:125019ms step_avg:95.22ms +step:1314/1705 train_time:125114ms step_avg:95.22ms +step:1315/1705 train_time:125209ms step_avg:95.22ms +step:1316/1705 train_time:125304ms step_avg:95.22ms +step:1317/1705 train_time:125400ms step_avg:95.22ms +step:1318/1705 train_time:125496ms step_avg:95.22ms +step:1319/1705 train_time:125593ms step_avg:95.22ms +step:1320/1705 train_time:125688ms step_avg:95.22ms +step:1321/1705 train_time:125782ms step_avg:95.22ms +step:1322/1705 train_time:125878ms step_avg:95.22ms +step:1323/1705 train_time:125973ms step_avg:95.22ms +step:1324/1705 train_time:126068ms step_avg:95.22ms +step:1325/1705 train_time:126162ms step_avg:95.22ms +step:1326/1705 train_time:126257ms step_avg:95.22ms +step:1327/1705 train_time:126352ms step_avg:95.22ms +step:1328/1705 train_time:126447ms step_avg:95.22ms +step:1329/1705 train_time:126542ms step_avg:95.22ms +step:1330/1705 train_time:126638ms step_avg:95.22ms +step:1331/1705 train_time:126734ms step_avg:95.22ms +step:1332/1705 train_time:126829ms step_avg:95.22ms +step:1333/1705 train_time:126923ms step_avg:95.22ms +step:1334/1705 train_time:127018ms step_avg:95.22ms +step:1335/1705 train_time:127114ms step_avg:95.22ms +step:1336/1705 train_time:127210ms step_avg:95.22ms +step:1337/1705 train_time:127306ms step_avg:95.22ms +step:1338/1705 train_time:127400ms step_avg:95.22ms +step:1339/1705 train_time:127496ms step_avg:95.22ms +step:1340/1705 train_time:127592ms step_avg:95.22ms +step:1341/1705 train_time:127687ms step_avg:95.22ms +step:1342/1705 train_time:127781ms step_avg:95.22ms +step:1343/1705 train_time:127877ms step_avg:95.22ms +step:1344/1705 train_time:127971ms step_avg:95.22ms +step:1345/1705 train_time:128065ms step_avg:95.22ms +step:1346/1705 train_time:128160ms step_avg:95.22ms +step:1347/1705 train_time:128256ms step_avg:95.22ms +step:1348/1705 train_time:128351ms step_avg:95.22ms +step:1349/1705 train_time:128446ms step_avg:95.22ms +step:1350/1705 train_time:128541ms step_avg:95.22ms +step:1351/1705 train_time:128637ms step_avg:95.22ms +step:1352/1705 train_time:128734ms step_avg:95.22ms +step:1353/1705 train_time:128829ms step_avg:95.22ms +step:1354/1705 train_time:128923ms step_avg:95.22ms +step:1355/1705 train_time:129018ms step_avg:95.22ms +step:1356/1705 train_time:129114ms step_avg:95.22ms +step:1357/1705 train_time:129209ms step_avg:95.22ms +step:1358/1705 train_time:129303ms step_avg:95.22ms +step:1359/1705 train_time:129399ms step_avg:95.22ms +step:1360/1705 train_time:129495ms step_avg:95.22ms +step:1361/1705 train_time:129590ms step_avg:95.22ms +step:1362/1705 train_time:129686ms step_avg:95.22ms +step:1363/1705 train_time:129781ms step_avg:95.22ms +step:1364/1705 train_time:129876ms step_avg:95.22ms +step:1365/1705 train_time:129971ms step_avg:95.22ms +step:1366/1705 train_time:130067ms step_avg:95.22ms +step:1367/1705 train_time:130161ms step_avg:95.22ms +step:1368/1705 train_time:130257ms step_avg:95.22ms +step:1369/1705 train_time:130352ms step_avg:95.22ms +step:1370/1705 train_time:130448ms step_avg:95.22ms +step:1371/1705 train_time:130543ms step_avg:95.22ms +step:1372/1705 train_time:130638ms step_avg:95.22ms +step:1373/1705 train_time:130734ms step_avg:95.22ms +step:1374/1705 train_time:130829ms 
step_avg:95.22ms +step:1375/1705 train_time:130923ms step_avg:95.22ms +step:1375/1705 val_loss:3.3528 train_time:131019ms step_avg:95.29ms +step:1376/1705 train_time:131040ms step_avg:95.23ms +step:1377/1705 train_time:131119ms step_avg:95.22ms +step:1378/1705 train_time:131219ms step_avg:95.22ms +step:1379/1705 train_time:131313ms step_avg:95.22ms +step:1380/1705 train_time:131407ms step_avg:95.22ms +step:1381/1705 train_time:131501ms step_avg:95.22ms +step:1382/1705 train_time:131595ms step_avg:95.22ms +step:1383/1705 train_time:131689ms step_avg:95.22ms +step:1384/1705 train_time:131784ms step_avg:95.22ms +step:1385/1705 train_time:131879ms step_avg:95.22ms +step:1386/1705 train_time:131974ms step_avg:95.22ms +step:1387/1705 train_time:132072ms step_avg:95.22ms +step:1388/1705 train_time:132170ms step_avg:95.22ms +step:1389/1705 train_time:132266ms step_avg:95.22ms +step:1390/1705 train_time:132363ms step_avg:95.23ms +step:1391/1705 train_time:132459ms step_avg:95.23ms +step:1392/1705 train_time:132553ms step_avg:95.22ms +step:1393/1705 train_time:132647ms step_avg:95.22ms +step:1394/1705 train_time:132742ms step_avg:95.22ms +step:1395/1705 train_time:132836ms step_avg:95.22ms +step:1396/1705 train_time:132931ms step_avg:95.22ms +step:1397/1705 train_time:133028ms step_avg:95.22ms +step:1398/1705 train_time:133125ms step_avg:95.23ms +step:1399/1705 train_time:133221ms step_avg:95.23ms +step:1400/1705 train_time:133318ms step_avg:95.23ms +step:1401/1705 train_time:133412ms step_avg:95.23ms +step:1402/1705 train_time:133507ms step_avg:95.23ms +step:1403/1705 train_time:133602ms step_avg:95.23ms +step:1404/1705 train_time:133696ms step_avg:95.22ms +step:1405/1705 train_time:133790ms step_avg:95.22ms +step:1406/1705 train_time:133886ms step_avg:95.22ms +step:1407/1705 train_time:133983ms step_avg:95.23ms +step:1408/1705 train_time:134079ms step_avg:95.23ms +step:1409/1705 train_time:134175ms step_avg:95.23ms +step:1410/1705 train_time:134270ms step_avg:95.23ms +step:1411/1705 train_time:134366ms step_avg:95.23ms +step:1412/1705 train_time:134462ms step_avg:95.23ms +step:1413/1705 train_time:134557ms step_avg:95.23ms +step:1414/1705 train_time:134651ms step_avg:95.23ms +step:1415/1705 train_time:134746ms step_avg:95.23ms +step:1416/1705 train_time:134842ms step_avg:95.23ms +step:1417/1705 train_time:134937ms step_avg:95.23ms +step:1418/1705 train_time:135031ms step_avg:95.23ms +step:1419/1705 train_time:135127ms step_avg:95.23ms +step:1420/1705 train_time:135224ms step_avg:95.23ms +step:1421/1705 train_time:135320ms step_avg:95.23ms +step:1422/1705 train_time:135415ms step_avg:95.23ms +step:1423/1705 train_time:135510ms step_avg:95.23ms +step:1424/1705 train_time:135605ms step_avg:95.23ms +step:1425/1705 train_time:135701ms step_avg:95.23ms +step:1426/1705 train_time:135796ms step_avg:95.23ms +step:1427/1705 train_time:135890ms step_avg:95.23ms +step:1428/1705 train_time:135985ms step_avg:95.23ms +step:1429/1705 train_time:136081ms step_avg:95.23ms +step:1430/1705 train_time:136177ms step_avg:95.23ms +step:1431/1705 train_time:136272ms step_avg:95.23ms +step:1432/1705 train_time:136369ms step_avg:95.23ms +step:1433/1705 train_time:136465ms step_avg:95.23ms +step:1434/1705 train_time:136561ms step_avg:95.23ms +step:1435/1705 train_time:136656ms step_avg:95.23ms +step:1436/1705 train_time:136751ms step_avg:95.23ms +step:1437/1705 train_time:136846ms step_avg:95.23ms +step:1438/1705 train_time:136941ms step_avg:95.23ms +step:1439/1705 train_time:137036ms step_avg:95.23ms +step:1440/1705 
train_time:137131ms step_avg:95.23ms +step:1441/1705 train_time:137226ms step_avg:95.23ms +step:1442/1705 train_time:137324ms step_avg:95.23ms +step:1443/1705 train_time:137419ms step_avg:95.23ms +step:1444/1705 train_time:137514ms step_avg:95.23ms +step:1445/1705 train_time:137609ms step_avg:95.23ms +step:1446/1705 train_time:137704ms step_avg:95.23ms +step:1447/1705 train_time:137800ms step_avg:95.23ms +step:1448/1705 train_time:137896ms step_avg:95.23ms +step:1449/1705 train_time:137992ms step_avg:95.23ms +step:1450/1705 train_time:138088ms step_avg:95.23ms +step:1451/1705 train_time:138184ms step_avg:95.23ms +step:1452/1705 train_time:138278ms step_avg:95.23ms +step:1453/1705 train_time:138373ms step_avg:95.23ms +step:1454/1705 train_time:138469ms step_avg:95.23ms +step:1455/1705 train_time:138564ms step_avg:95.23ms +step:1456/1705 train_time:138659ms step_avg:95.23ms +step:1457/1705 train_time:138754ms step_avg:95.23ms +step:1458/1705 train_time:138849ms step_avg:95.23ms +step:1459/1705 train_time:138945ms step_avg:95.23ms +step:1460/1705 train_time:139041ms step_avg:95.23ms +step:1461/1705 train_time:139135ms step_avg:95.23ms +step:1462/1705 train_time:139230ms step_avg:95.23ms +step:1463/1705 train_time:139326ms step_avg:95.23ms +step:1464/1705 train_time:139423ms step_avg:95.23ms +step:1465/1705 train_time:139519ms step_avg:95.23ms +step:1466/1705 train_time:139615ms step_avg:95.24ms +step:1467/1705 train_time:139709ms step_avg:95.23ms +step:1468/1705 train_time:139804ms step_avg:95.23ms +step:1469/1705 train_time:139900ms step_avg:95.23ms +step:1470/1705 train_time:139996ms step_avg:95.24ms +step:1471/1705 train_time:140091ms step_avg:95.23ms +step:1472/1705 train_time:140186ms step_avg:95.23ms +step:1473/1705 train_time:140281ms step_avg:95.23ms +step:1474/1705 train_time:140375ms step_avg:95.23ms +step:1475/1705 train_time:140471ms step_avg:95.23ms +step:1476/1705 train_time:140566ms step_avg:95.23ms +step:1477/1705 train_time:140663ms step_avg:95.24ms +step:1478/1705 train_time:140757ms step_avg:95.23ms +step:1479/1705 train_time:140852ms step_avg:95.23ms +step:1480/1705 train_time:140949ms step_avg:95.24ms +step:1481/1705 train_time:141045ms step_avg:95.24ms +step:1482/1705 train_time:141141ms step_avg:95.24ms +step:1483/1705 train_time:141236ms step_avg:95.24ms +step:1484/1705 train_time:141330ms step_avg:95.24ms +step:1485/1705 train_time:141592ms step_avg:95.35ms +step:1486/1705 train_time:141783ms step_avg:95.41ms +step:1487/1705 train_time:141877ms step_avg:95.41ms +step:1488/1705 train_time:141971ms step_avg:95.41ms +step:1489/1705 train_time:142065ms step_avg:95.41ms +step:1490/1705 train_time:142160ms step_avg:95.41ms +step:1491/1705 train_time:142254ms step_avg:95.41ms +step:1492/1705 train_time:142348ms step_avg:95.41ms +step:1493/1705 train_time:142443ms step_avg:95.41ms +step:1494/1705 train_time:142537ms step_avg:95.41ms +step:1495/1705 train_time:142636ms step_avg:95.41ms +step:1496/1705 train_time:142735ms step_avg:95.41ms +step:1497/1705 train_time:142833ms step_avg:95.41ms +step:1498/1705 train_time:142929ms step_avg:95.41ms +step:1499/1705 train_time:143023ms step_avg:95.41ms +step:1500/1705 train_time:143118ms step_avg:95.41ms +step:1500/1705 val_loss:3.3205 train_time:143213ms step_avg:95.48ms +step:1501/1705 train_time:143234ms step_avg:95.43ms +step:1502/1705 train_time:143312ms step_avg:95.41ms +step:1503/1705 train_time:143410ms step_avg:95.42ms +step:1504/1705 train_time:143506ms step_avg:95.42ms +step:1505/1705 train_time:143600ms step_avg:95.42ms 
+step:1506/1705 train_time:143694ms step_avg:95.41ms +step:1507/1705 train_time:143789ms step_avg:95.41ms +step:1508/1705 train_time:143883ms step_avg:95.41ms +step:1509/1705 train_time:143978ms step_avg:95.41ms +step:1510/1705 train_time:144072ms step_avg:95.41ms +step:1511/1705 train_time:144168ms step_avg:95.41ms +step:1512/1705 train_time:144264ms step_avg:95.41ms +step:1513/1705 train_time:144360ms step_avg:95.41ms +step:1514/1705 train_time:144458ms step_avg:95.41ms +step:1515/1705 train_time:144554ms step_avg:95.42ms +step:1516/1705 train_time:144648ms step_avg:95.41ms +step:1517/1705 train_time:144743ms step_avg:95.41ms +step:1518/1705 train_time:144838ms step_avg:95.41ms +step:1519/1705 train_time:144933ms step_avg:95.41ms +step:1520/1705 train_time:145027ms step_avg:95.41ms +step:1521/1705 train_time:145123ms step_avg:95.41ms +step:1522/1705 train_time:145218ms step_avg:95.41ms +step:1523/1705 train_time:145315ms step_avg:95.41ms +step:1524/1705 train_time:145412ms step_avg:95.41ms +step:1525/1705 train_time:145506ms step_avg:95.41ms +step:1526/1705 train_time:145602ms step_avg:95.41ms +step:1527/1705 train_time:145697ms step_avg:95.41ms +step:1528/1705 train_time:145793ms step_avg:95.41ms +step:1529/1705 train_time:145887ms step_avg:95.41ms +step:1530/1705 train_time:145982ms step_avg:95.41ms +step:1531/1705 train_time:146076ms step_avg:95.41ms +step:1532/1705 train_time:146171ms step_avg:95.41ms +step:1533/1705 train_time:146267ms step_avg:95.41ms +step:1534/1705 train_time:146362ms step_avg:95.41ms +step:1535/1705 train_time:146458ms step_avg:95.41ms +step:1536/1705 train_time:146554ms step_avg:95.41ms +step:1537/1705 train_time:146649ms step_avg:95.41ms +step:1538/1705 train_time:146744ms step_avg:95.41ms +step:1539/1705 train_time:146840ms step_avg:95.41ms +step:1540/1705 train_time:146935ms step_avg:95.41ms +step:1541/1705 train_time:147030ms step_avg:95.41ms +step:1542/1705 train_time:147125ms step_avg:95.41ms +step:1543/1705 train_time:147220ms step_avg:95.41ms +step:1544/1705 train_time:147316ms step_avg:95.41ms +step:1545/1705 train_time:147412ms step_avg:95.41ms +step:1546/1705 train_time:147507ms step_avg:95.41ms +step:1547/1705 train_time:147601ms step_avg:95.41ms +step:1548/1705 train_time:147698ms step_avg:95.41ms +step:1549/1705 train_time:147793ms step_avg:95.41ms +step:1550/1705 train_time:147888ms step_avg:95.41ms +step:1551/1705 train_time:147982ms step_avg:95.41ms +step:1552/1705 train_time:148078ms step_avg:95.41ms +step:1553/1705 train_time:148173ms step_avg:95.41ms +step:1554/1705 train_time:148269ms step_avg:95.41ms +step:1555/1705 train_time:148364ms step_avg:95.41ms +step:1556/1705 train_time:148459ms step_avg:95.41ms +step:1557/1705 train_time:148555ms step_avg:95.41ms +step:1558/1705 train_time:148651ms step_avg:95.41ms +step:1559/1705 train_time:148746ms step_avg:95.41ms +step:1560/1705 train_time:148842ms step_avg:95.41ms +step:1561/1705 train_time:148937ms step_avg:95.41ms +step:1562/1705 train_time:149032ms step_avg:95.41ms +step:1563/1705 train_time:149127ms step_avg:95.41ms +step:1564/1705 train_time:149222ms step_avg:95.41ms +step:1565/1705 train_time:149318ms step_avg:95.41ms +step:1566/1705 train_time:149413ms step_avg:95.41ms +step:1567/1705 train_time:149508ms step_avg:95.41ms +step:1568/1705 train_time:149603ms step_avg:95.41ms +step:1569/1705 train_time:149698ms step_avg:95.41ms +step:1570/1705 train_time:149795ms step_avg:95.41ms +step:1571/1705 train_time:149890ms step_avg:95.41ms +step:1572/1705 train_time:149985ms step_avg:95.41ms 
+step:1573/1705 train_time:150080ms step_avg:95.41ms +step:1574/1705 train_time:150176ms step_avg:95.41ms +step:1575/1705 train_time:150272ms step_avg:95.41ms +step:1576/1705 train_time:150367ms step_avg:95.41ms +step:1577/1705 train_time:150462ms step_avg:95.41ms +step:1578/1705 train_time:150557ms step_avg:95.41ms +step:1579/1705 train_time:150653ms step_avg:95.41ms +step:1580/1705 train_time:150748ms step_avg:95.41ms +step:1581/1705 train_time:150843ms step_avg:95.41ms +step:1582/1705 train_time:150938ms step_avg:95.41ms +step:1583/1705 train_time:151034ms step_avg:95.41ms +step:1584/1705 train_time:151129ms step_avg:95.41ms +step:1585/1705 train_time:151224ms step_avg:95.41ms +step:1586/1705 train_time:151319ms step_avg:95.41ms +step:1587/1705 train_time:151416ms step_avg:95.41ms +step:1588/1705 train_time:151511ms step_avg:95.41ms +step:1589/1705 train_time:151606ms step_avg:95.41ms +step:1590/1705 train_time:151701ms step_avg:95.41ms +step:1591/1705 train_time:151797ms step_avg:95.41ms +step:1592/1705 train_time:151893ms step_avg:95.41ms +step:1593/1705 train_time:151989ms step_avg:95.41ms +step:1594/1705 train_time:152083ms step_avg:95.41ms +step:1595/1705 train_time:152179ms step_avg:95.41ms +step:1596/1705 train_time:152275ms step_avg:95.41ms +step:1597/1705 train_time:152371ms step_avg:95.41ms +step:1598/1705 train_time:152467ms step_avg:95.41ms +step:1599/1705 train_time:152562ms step_avg:95.41ms +step:1600/1705 train_time:152657ms step_avg:95.41ms +step:1601/1705 train_time:152752ms step_avg:95.41ms +step:1602/1705 train_time:152847ms step_avg:95.41ms +step:1603/1705 train_time:152943ms step_avg:95.41ms +step:1604/1705 train_time:153038ms step_avg:95.41ms +step:1605/1705 train_time:153134ms step_avg:95.41ms +step:1606/1705 train_time:153229ms step_avg:95.41ms +step:1607/1705 train_time:153323ms step_avg:95.41ms +step:1608/1705 train_time:153419ms step_avg:95.41ms +step:1609/1705 train_time:153515ms step_avg:95.41ms +step:1610/1705 train_time:153610ms step_avg:95.41ms +step:1611/1705 train_time:153705ms step_avg:95.41ms +step:1612/1705 train_time:153801ms step_avg:95.41ms +step:1613/1705 train_time:153897ms step_avg:95.41ms +step:1614/1705 train_time:153992ms step_avg:95.41ms +step:1615/1705 train_time:154087ms step_avg:95.41ms +step:1616/1705 train_time:154182ms step_avg:95.41ms +step:1617/1705 train_time:154278ms step_avg:95.41ms +step:1618/1705 train_time:154373ms step_avg:95.41ms +step:1619/1705 train_time:154469ms step_avg:95.41ms +step:1620/1705 train_time:154564ms step_avg:95.41ms +step:1621/1705 train_time:154660ms step_avg:95.41ms +step:1622/1705 train_time:154755ms step_avg:95.41ms +step:1623/1705 train_time:154850ms step_avg:95.41ms +step:1624/1705 train_time:154946ms step_avg:95.41ms +step:1625/1705 train_time:155042ms step_avg:95.41ms +step:1625/1705 val_loss:3.2927 train_time:155138ms step_avg:95.47ms +step:1626/1705 train_time:155159ms step_avg:95.42ms +step:1627/1705 train_time:155239ms step_avg:95.41ms +step:1628/1705 train_time:155336ms step_avg:95.42ms +step:1629/1705 train_time:155431ms step_avg:95.41ms +step:1630/1705 train_time:155525ms step_avg:95.41ms +step:1631/1705 train_time:155620ms step_avg:95.41ms +step:1632/1705 train_time:155714ms step_avg:95.41ms +step:1633/1705 train_time:155808ms step_avg:95.41ms +step:1634/1705 train_time:155903ms step_avg:95.41ms +step:1635/1705 train_time:155996ms step_avg:95.41ms +step:1636/1705 train_time:156093ms step_avg:95.41ms +step:1637/1705 train_time:156193ms step_avg:95.41ms +step:1638/1705 train_time:156291ms 
step_avg:95.42ms +step:1639/1705 train_time:156388ms step_avg:95.42ms +step:1640/1705 train_time:156483ms step_avg:95.42ms +step:1641/1705 train_time:156577ms step_avg:95.42ms +step:1642/1705 train_time:156673ms step_avg:95.42ms +step:1643/1705 train_time:156767ms step_avg:95.42ms +step:1644/1705 train_time:156862ms step_avg:95.41ms +step:1645/1705 train_time:156956ms step_avg:95.41ms +step:1646/1705 train_time:157052ms step_avg:95.41ms +step:1647/1705 train_time:157149ms step_avg:95.42ms +step:1648/1705 train_time:157247ms step_avg:95.42ms +step:1649/1705 train_time:157342ms step_avg:95.42ms +step:1650/1705 train_time:157437ms step_avg:95.42ms +step:1651/1705 train_time:157532ms step_avg:95.42ms +step:1652/1705 train_time:157627ms step_avg:95.42ms +step:1653/1705 train_time:157722ms step_avg:95.42ms +step:1654/1705 train_time:157816ms step_avg:95.41ms +step:1655/1705 train_time:157911ms step_avg:95.41ms +step:1656/1705 train_time:158007ms step_avg:95.41ms +step:1657/1705 train_time:158103ms step_avg:95.42ms +step:1658/1705 train_time:158198ms step_avg:95.42ms +step:1659/1705 train_time:158295ms step_avg:95.42ms +step:1660/1705 train_time:158391ms step_avg:95.42ms +step:1661/1705 train_time:158487ms step_avg:95.42ms +step:1662/1705 train_time:158582ms step_avg:95.42ms +step:1663/1705 train_time:158678ms step_avg:95.42ms +step:1664/1705 train_time:158772ms step_avg:95.42ms +step:1665/1705 train_time:158867ms step_avg:95.42ms +step:1666/1705 train_time:158963ms step_avg:95.42ms +step:1667/1705 train_time:159057ms step_avg:95.42ms +step:1668/1705 train_time:159152ms step_avg:95.42ms +step:1669/1705 train_time:159248ms step_avg:95.42ms +step:1670/1705 train_time:159344ms step_avg:95.42ms +step:1671/1705 train_time:159439ms step_avg:95.42ms +step:1672/1705 train_time:159534ms step_avg:95.42ms +step:1673/1705 train_time:159631ms step_avg:95.42ms +step:1674/1705 train_time:159727ms step_avg:95.42ms +step:1675/1705 train_time:159822ms step_avg:95.42ms +step:1676/1705 train_time:159916ms step_avg:95.42ms +step:1677/1705 train_time:160012ms step_avg:95.42ms +step:1678/1705 train_time:160108ms step_avg:95.42ms +step:1679/1705 train_time:160202ms step_avg:95.42ms +step:1680/1705 train_time:160297ms step_avg:95.41ms +step:1681/1705 train_time:160394ms step_avg:95.42ms +step:1682/1705 train_time:160491ms step_avg:95.42ms +step:1683/1705 train_time:160587ms step_avg:95.42ms +step:1684/1705 train_time:160683ms step_avg:95.42ms +step:1685/1705 train_time:160776ms step_avg:95.42ms +step:1686/1705 train_time:160872ms step_avg:95.42ms +step:1687/1705 train_time:160967ms step_avg:95.42ms +step:1688/1705 train_time:161062ms step_avg:95.42ms +step:1689/1705 train_time:161158ms step_avg:95.42ms +step:1690/1705 train_time:161252ms step_avg:95.42ms +step:1691/1705 train_time:161348ms step_avg:95.42ms +step:1692/1705 train_time:161444ms step_avg:95.42ms +step:1693/1705 train_time:161540ms step_avg:95.42ms +step:1694/1705 train_time:161635ms step_avg:95.42ms +step:1695/1705 train_time:161730ms step_avg:95.42ms +step:1696/1705 train_time:161826ms step_avg:95.42ms +step:1697/1705 train_time:161921ms step_avg:95.42ms +step:1698/1705 train_time:162181ms step_avg:95.51ms +step:1699/1705 train_time:162371ms step_avg:95.57ms +step:1700/1705 train_time:162464ms step_avg:95.57ms +step:1701/1705 train_time:162557ms step_avg:95.57ms +step:1702/1705 train_time:162652ms step_avg:95.56ms +step:1703/1705 train_time:162746ms step_avg:95.56ms +step:1704/1705 train_time:162840ms step_avg:95.56ms +step:1705/1705 train_time:162935ms 
step_avg:95.56ms +step:1705/1705 val_loss:3.2787 train_time:163029ms step_avg:95.62ms +peak memory allocated: 33992 MiB reserved: 48836 MiB diff --git a/records/050925_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt b/records/050925_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt new file mode 100644 index 000000000..953806534 --- /dev/null +++ b/records/050925_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, 
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-Schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
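+ + A sketch, in symbols, of the update that step() below applies to each 2D parameter p + with gradient g (ignoring the optional per-parameter lr_mul and wd_mul multipliers): + + buf <- momentum * buf + (1 - momentum) * g # momentum buffer + u <- newton_schulz_triton((1 - momentum) * g + momentum * buf) # Nesterov-style blend, then orthogonalize + p <- (1 - lr * weight_decay) * p - lr * max(1, p.size(-2) / p.size(-1)) ** 0.5 * u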
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable a context-based no-op by @classiclarryd + # (zero-init linear read of the first attn_gate_dim channels of x -> one sigmoid gate per head, 0.5 at init, so heads can learn to mute their output) + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences require B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal.
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
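+        # fp8 head: 448 is the largest finite float8_e4m3fn value; the x_s/w_s/grad_s
+        # scales below map activations, weights, and grads into float8 dynamic range
+        # before the scaled matmul in CastedLinear (and are undone via scale_a/scale_b)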
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
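# (BOS_ID marks each document start; next_batch walks these positions to
+        # pack per-rank token budgets from whole documents, truncating each to max_seq_len)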
self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration.
+ tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # the last document contributed one extra token (for the _targets shift), so trim its end index + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1705 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc.
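+
+# Minimal sketch (not used by the run) of the packing invariant in
+# BOSFinder.next_batch above: the spans chosen for one rank always concatenate
+# to exactly budget + 1 tokens, so buf[:-1] and buf[1:] line up as inputs/targets.
+def _bos_packing_sketch(bos_positions: list[int], total_tokens: int, budget: int, max_seq_len: int) -> list[tuple[int, int]]:
+    spans, used, i = [], 0, 0
+    while used <= budget:  # mirrors `while cur_len <= num_tokens_local`
+        start = bos_positions[i]  # the real class raises StopIteration when positions run out
+        next_bos = bos_positions[i + 1] if i + 1 < len(bos_positions) else total_tokens
+        end = min(next_bos, start + max_seq_len, start + budget - used + 1)
+        spans.append((start, end))
+        used += end - start
+        i += 1
+    assert used == budget + 1  # same invariant asserted in next_batch
+    return spans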
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
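+
+# Minimal sketch (never called) tabulating the two schedules defined above:
+# get_lr holds the multiplier at 1.0 for the first 1 - cooldown_frac of
+# training and then decays linearly toward 0.1, while get_ws steps the
+# attention window through ws_schedule = (3, 7, 11) in equal thirds of training.
+def _schedule_sketch():
+    for frac in (0.0, 0.40, 0.55, 0.775, 0.999):
+        step = int(frac * args.num_iterations)
+        print(f"step={step} lr_mul={get_lr(step):.3f} ws={get_ws(step)}")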
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:35:03 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 128W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 93208 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 93209 C /usr/bin/python3 610MiB | +| 0 N/A N/A 93210 C /usr/bin/python3 610MiB | +| 0 N/A N/A 93211 C /usr/bin/python3 610MiB | +| 0 N/A N/A 93212 C /usr/bin/python3 610MiB | +| 0 N/A N/A 93213 C /usr/bin/python3 610MiB | +| 0 N/A N/A 93214 C /usr/bin/python3 610MiB | +| 0 N/A N/A 93215 C /usr/bin/python3 610MiB | +| 1 N/A N/A 93209 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 93210 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 93211 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 93212 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 93213 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 93214 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 93215 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1705 train_time:392ms step_avg:391.78ms +step:2/1705 train_time:412ms step_avg:205.99ms +step:3/1705 train_time:481ms step_avg:160.40ms +step:4/1705 train_time:572ms step_avg:143.00ms +step:5/1705 
train_time:664ms step_avg:132.81ms +step:6/1705 train_time:756ms step_avg:125.98ms +step:7/1705 train_time:848ms step_avg:121.20ms +step:8/1705 train_time:940ms step_avg:117.56ms +step:9/1705 train_time:1033ms step_avg:114.76ms +step:10/1705 train_time:1126ms step_avg:112.55ms +step:11/1705 train_time:1218ms step_avg:110.70ms +step:12/1705 train_time:1311ms step_avg:109.29ms +step:13/1705 train_time:1409ms step_avg:108.35ms +step:14/1705 train_time:1503ms step_avg:107.34ms +step:15/1705 train_time:1595ms step_avg:106.34ms +step:16/1705 train_time:1688ms step_avg:105.51ms +step:17/1705 train_time:1781ms step_avg:104.77ms +step:18/1705 train_time:1874ms step_avg:104.11ms +step:19/1705 train_time:1967ms step_avg:103.51ms +step:20/1705 train_time:2059ms step_avg:102.96ms +step:21/1705 train_time:2152ms step_avg:102.46ms +step:22/1705 train_time:2245ms step_avg:102.03ms +step:23/1705 train_time:2339ms step_avg:101.69ms +step:24/1705 train_time:2432ms step_avg:101.34ms +step:25/1705 train_time:2527ms step_avg:101.09ms +step:26/1705 train_time:2621ms step_avg:100.81ms +step:27/1705 train_time:2713ms step_avg:100.49ms +step:28/1705 train_time:2807ms step_avg:100.24ms +step:29/1705 train_time:2900ms step_avg:99.99ms +step:30/1705 train_time:2992ms step_avg:99.73ms +step:31/1705 train_time:3084ms step_avg:99.49ms +step:32/1705 train_time:3177ms step_avg:99.28ms +step:33/1705 train_time:3270ms step_avg:99.10ms +step:34/1705 train_time:3365ms step_avg:98.96ms +step:35/1705 train_time:3458ms step_avg:98.81ms +step:36/1705 train_time:3551ms step_avg:98.65ms +step:37/1705 train_time:3645ms step_avg:98.53ms +step:38/1705 train_time:3739ms step_avg:98.39ms +step:39/1705 train_time:3832ms step_avg:98.26ms +step:40/1705 train_time:3925ms step_avg:98.13ms +step:41/1705 train_time:4017ms step_avg:97.99ms +step:42/1705 train_time:4110ms step_avg:97.86ms +step:43/1705 train_time:4203ms step_avg:97.74ms +step:44/1705 train_time:4296ms step_avg:97.63ms +step:45/1705 train_time:4389ms step_avg:97.54ms +step:46/1705 train_time:4483ms step_avg:97.46ms +step:47/1705 train_time:4575ms step_avg:97.34ms +step:48/1705 train_time:4669ms step_avg:97.28ms +step:49/1705 train_time:4763ms step_avg:97.20ms +step:50/1705 train_time:4856ms step_avg:97.13ms +step:51/1705 train_time:4950ms step_avg:97.06ms +step:52/1705 train_time:5043ms step_avg:96.97ms +step:53/1705 train_time:5135ms step_avg:96.88ms +step:54/1705 train_time:5227ms step_avg:96.80ms +step:55/1705 train_time:5321ms step_avg:96.75ms +step:56/1705 train_time:5414ms step_avg:96.68ms +step:57/1705 train_time:5507ms step_avg:96.61ms +step:58/1705 train_time:5599ms step_avg:96.54ms +step:59/1705 train_time:5693ms step_avg:96.49ms +step:60/1705 train_time:5787ms step_avg:96.45ms +step:61/1705 train_time:5880ms step_avg:96.40ms +step:62/1705 train_time:5973ms step_avg:96.33ms +step:63/1705 train_time:6066ms step_avg:96.29ms +step:64/1705 train_time:6160ms step_avg:96.25ms +step:65/1705 train_time:6253ms step_avg:96.19ms +step:66/1705 train_time:6345ms step_avg:96.14ms +step:67/1705 train_time:6438ms step_avg:96.09ms +step:68/1705 train_time:6531ms step_avg:96.05ms +step:69/1705 train_time:6625ms step_avg:96.02ms +step:70/1705 train_time:6719ms step_avg:95.98ms +step:71/1705 train_time:6812ms step_avg:95.94ms +step:72/1705 train_time:6906ms step_avg:95.91ms +step:73/1705 train_time:6999ms step_avg:95.88ms +step:74/1705 train_time:7092ms step_avg:95.84ms +step:75/1705 train_time:7186ms step_avg:95.81ms +step:76/1705 train_time:7279ms step_avg:95.78ms +step:77/1705 
train_time:7372ms step_avg:95.73ms +step:78/1705 train_time:7464ms step_avg:95.69ms +step:79/1705 train_time:7556ms step_avg:95.65ms +step:80/1705 train_time:7649ms step_avg:95.62ms +step:81/1705 train_time:7743ms step_avg:95.60ms +step:82/1705 train_time:7836ms step_avg:95.56ms +step:83/1705 train_time:7929ms step_avg:95.53ms +step:84/1705 train_time:8023ms step_avg:95.51ms +step:85/1705 train_time:8116ms step_avg:95.48ms +step:86/1705 train_time:8209ms step_avg:95.45ms +step:87/1705 train_time:8302ms step_avg:95.43ms +step:88/1705 train_time:8395ms step_avg:95.39ms +step:89/1705 train_time:8487ms step_avg:95.36ms +step:90/1705 train_time:8580ms step_avg:95.33ms +step:91/1705 train_time:8673ms step_avg:95.30ms +step:92/1705 train_time:8766ms step_avg:95.28ms +step:93/1705 train_time:8860ms step_avg:95.27ms +step:94/1705 train_time:8952ms step_avg:95.24ms +step:95/1705 train_time:9045ms step_avg:95.21ms +step:96/1705 train_time:9138ms step_avg:95.19ms +step:97/1705 train_time:9231ms step_avg:95.17ms +step:98/1705 train_time:9325ms step_avg:95.15ms +step:99/1705 train_time:9418ms step_avg:95.13ms +step:100/1705 train_time:9511ms step_avg:95.11ms +step:101/1705 train_time:9604ms step_avg:95.09ms +step:102/1705 train_time:9697ms step_avg:95.07ms +step:103/1705 train_time:9791ms step_avg:95.05ms +step:104/1705 train_time:9884ms step_avg:95.04ms +step:105/1705 train_time:9976ms step_avg:95.01ms +step:106/1705 train_time:10070ms step_avg:95.00ms +step:107/1705 train_time:10164ms step_avg:94.99ms +step:108/1705 train_time:10257ms step_avg:94.97ms +step:109/1705 train_time:10351ms step_avg:94.96ms +step:110/1705 train_time:10444ms step_avg:94.95ms +step:111/1705 train_time:10537ms step_avg:94.93ms +step:112/1705 train_time:10630ms step_avg:94.91ms +step:113/1705 train_time:10723ms step_avg:94.90ms +step:114/1705 train_time:10815ms step_avg:94.87ms +step:115/1705 train_time:10908ms step_avg:94.86ms +step:116/1705 train_time:11001ms step_avg:94.84ms +step:117/1705 train_time:11094ms step_avg:94.82ms +step:118/1705 train_time:11187ms step_avg:94.81ms +step:119/1705 train_time:11281ms step_avg:94.80ms +step:120/1705 train_time:11374ms step_avg:94.78ms +step:121/1705 train_time:11467ms step_avg:94.77ms +step:122/1705 train_time:11559ms step_avg:94.75ms +step:123/1705 train_time:11652ms step_avg:94.73ms +step:124/1705 train_time:11745ms step_avg:94.72ms +step:125/1705 train_time:11838ms step_avg:94.71ms +step:125/1705 val_loss:4.3025 train_time:11931ms step_avg:95.45ms +step:126/1705 train_time:11953ms step_avg:94.87ms +step:127/1705 train_time:12031ms step_avg:94.73ms +step:128/1705 train_time:12133ms step_avg:94.79ms +step:129/1705 train_time:12229ms step_avg:94.80ms +step:130/1705 train_time:12322ms step_avg:94.78ms +step:131/1705 train_time:12413ms step_avg:94.76ms +step:132/1705 train_time:12505ms step_avg:94.74ms +step:133/1705 train_time:12597ms step_avg:94.72ms +step:134/1705 train_time:12689ms step_avg:94.70ms +step:135/1705 train_time:12781ms step_avg:94.68ms +step:136/1705 train_time:12873ms step_avg:94.66ms +step:137/1705 train_time:12966ms step_avg:94.64ms +step:138/1705 train_time:13061ms step_avg:94.65ms +step:139/1705 train_time:13157ms step_avg:94.65ms +step:140/1705 train_time:13250ms step_avg:94.64ms +step:141/1705 train_time:13343ms step_avg:94.63ms +step:142/1705 train_time:13436ms step_avg:94.62ms +step:143/1705 train_time:13528ms step_avg:94.60ms +step:144/1705 train_time:13621ms step_avg:94.59ms +step:145/1705 train_time:13712ms step_avg:94.57ms +step:146/1705 train_time:13805ms 
step_avg:94.55ms +step:147/1705 train_time:13897ms step_avg:94.53ms +step:148/1705 train_time:13989ms step_avg:94.52ms +step:149/1705 train_time:14083ms step_avg:94.52ms +step:150/1705 train_time:14177ms step_avg:94.52ms +step:151/1705 train_time:14270ms step_avg:94.50ms +step:152/1705 train_time:14364ms step_avg:94.50ms +step:153/1705 train_time:14457ms step_avg:94.49ms +step:154/1705 train_time:14549ms step_avg:94.47ms +step:155/1705 train_time:14641ms step_avg:94.46ms +step:156/1705 train_time:14734ms step_avg:94.45ms +step:157/1705 train_time:14826ms step_avg:94.43ms +step:158/1705 train_time:14919ms step_avg:94.43ms +step:159/1705 train_time:15012ms step_avg:94.41ms +step:160/1705 train_time:15106ms step_avg:94.41ms +step:161/1705 train_time:15201ms step_avg:94.41ms +step:162/1705 train_time:15294ms step_avg:94.41ms +step:163/1705 train_time:15387ms step_avg:94.40ms +step:164/1705 train_time:15480ms step_avg:94.39ms +step:165/1705 train_time:15573ms step_avg:94.38ms +step:166/1705 train_time:15666ms step_avg:94.37ms +step:167/1705 train_time:15758ms step_avg:94.36ms +step:168/1705 train_time:15850ms step_avg:94.35ms +step:169/1705 train_time:15943ms step_avg:94.34ms +step:170/1705 train_time:16036ms step_avg:94.33ms +step:171/1705 train_time:16129ms step_avg:94.32ms +step:172/1705 train_time:16222ms step_avg:94.32ms +step:173/1705 train_time:16315ms step_avg:94.31ms +step:174/1705 train_time:16409ms step_avg:94.31ms +step:175/1705 train_time:16502ms step_avg:94.30ms +step:176/1705 train_time:16595ms step_avg:94.29ms +step:177/1705 train_time:16687ms step_avg:94.28ms +step:178/1705 train_time:16780ms step_avg:94.27ms +step:179/1705 train_time:16872ms step_avg:94.26ms +step:180/1705 train_time:16965ms step_avg:94.25ms +step:181/1705 train_time:17058ms step_avg:94.25ms +step:182/1705 train_time:17151ms step_avg:94.24ms +step:183/1705 train_time:17245ms step_avg:94.24ms +step:184/1705 train_time:17338ms step_avg:94.23ms +step:185/1705 train_time:17430ms step_avg:94.22ms +step:186/1705 train_time:17523ms step_avg:94.21ms +step:187/1705 train_time:17617ms step_avg:94.21ms +step:188/1705 train_time:17710ms step_avg:94.20ms +step:189/1705 train_time:17803ms step_avg:94.19ms +step:190/1705 train_time:17895ms step_avg:94.19ms +step:191/1705 train_time:17988ms step_avg:94.18ms +step:192/1705 train_time:18081ms step_avg:94.17ms +step:193/1705 train_time:18173ms step_avg:94.16ms +step:194/1705 train_time:18267ms step_avg:94.16ms +step:195/1705 train_time:18360ms step_avg:94.16ms +step:196/1705 train_time:18453ms step_avg:94.15ms +step:197/1705 train_time:18546ms step_avg:94.14ms +step:198/1705 train_time:18639ms step_avg:94.14ms +step:199/1705 train_time:18731ms step_avg:94.13ms +step:200/1705 train_time:18824ms step_avg:94.12ms +step:201/1705 train_time:18917ms step_avg:94.11ms +step:202/1705 train_time:19009ms step_avg:94.11ms +step:203/1705 train_time:19102ms step_avg:94.10ms +step:204/1705 train_time:19195ms step_avg:94.09ms +step:205/1705 train_time:19288ms step_avg:94.09ms +step:206/1705 train_time:19380ms step_avg:94.08ms +step:207/1705 train_time:19473ms step_avg:94.07ms +step:208/1705 train_time:19566ms step_avg:94.07ms +step:209/1705 train_time:19658ms step_avg:94.06ms +step:210/1705 train_time:19750ms step_avg:94.05ms +step:211/1705 train_time:19843ms step_avg:94.04ms +step:212/1705 train_time:19936ms step_avg:94.04ms +step:213/1705 train_time:20234ms step_avg:94.99ms +step:214/1705 train_time:20366ms step_avg:95.17ms +step:215/1705 train_time:20457ms step_avg:95.15ms +step:216/1705 
train_time:20549ms step_avg:95.13ms +step:217/1705 train_time:20640ms step_avg:95.12ms +step:218/1705 train_time:20732ms step_avg:95.10ms +step:219/1705 train_time:20824ms step_avg:95.09ms +step:220/1705 train_time:20916ms step_avg:95.07ms +step:221/1705 train_time:21008ms step_avg:95.06ms +step:222/1705 train_time:21100ms step_avg:95.04ms +step:223/1705 train_time:21192ms step_avg:95.03ms +step:224/1705 train_time:21289ms step_avg:95.04ms +step:225/1705 train_time:21385ms step_avg:95.04ms +step:226/1705 train_time:21478ms step_avg:95.04ms +step:227/1705 train_time:21571ms step_avg:95.03ms +step:228/1705 train_time:21664ms step_avg:95.02ms +step:229/1705 train_time:21756ms step_avg:95.00ms +step:230/1705 train_time:21848ms step_avg:94.99ms +step:231/1705 train_time:21940ms step_avg:94.98ms +step:232/1705 train_time:22032ms step_avg:94.97ms +step:233/1705 train_time:22124ms step_avg:94.95ms +step:234/1705 train_time:22218ms step_avg:94.95ms +step:235/1705 train_time:22311ms step_avg:94.94ms +step:236/1705 train_time:22405ms step_avg:94.94ms +step:237/1705 train_time:22498ms step_avg:94.93ms +step:238/1705 train_time:22592ms step_avg:94.92ms +step:239/1705 train_time:22684ms step_avg:94.91ms +step:240/1705 train_time:22776ms step_avg:94.90ms +step:241/1705 train_time:22868ms step_avg:94.89ms +step:242/1705 train_time:22960ms step_avg:94.88ms +step:243/1705 train_time:23052ms step_avg:94.86ms +step:244/1705 train_time:23144ms step_avg:94.85ms +step:245/1705 train_time:23237ms step_avg:94.85ms +step:246/1705 train_time:23330ms step_avg:94.84ms +step:247/1705 train_time:23425ms step_avg:94.84ms +step:248/1705 train_time:23518ms step_avg:94.83ms +step:249/1705 train_time:23610ms step_avg:94.82ms +step:250/1705 train_time:23703ms step_avg:94.81ms +step:250/1705 val_loss:3.9743 train_time:23796ms step_avg:95.19ms +step:251/1705 train_time:23818ms step_avg:94.89ms +step:252/1705 train_time:23893ms step_avg:94.81ms +step:253/1705 train_time:23990ms step_avg:94.82ms +step:254/1705 train_time:24085ms step_avg:94.82ms +step:255/1705 train_time:24177ms step_avg:94.81ms +step:256/1705 train_time:24269ms step_avg:94.80ms +step:257/1705 train_time:24362ms step_avg:94.79ms +step:258/1705 train_time:24453ms step_avg:94.78ms +step:259/1705 train_time:24544ms step_avg:94.77ms +step:260/1705 train_time:24636ms step_avg:94.75ms +step:261/1705 train_time:24729ms step_avg:94.75ms +step:262/1705 train_time:24823ms step_avg:94.74ms +step:263/1705 train_time:24917ms step_avg:94.74ms +step:264/1705 train_time:25011ms step_avg:94.74ms +step:265/1705 train_time:25105ms step_avg:94.74ms +step:266/1705 train_time:25198ms step_avg:94.73ms +step:267/1705 train_time:25290ms step_avg:94.72ms +step:268/1705 train_time:25383ms step_avg:94.71ms +step:269/1705 train_time:25476ms step_avg:94.70ms +step:270/1705 train_time:25568ms step_avg:94.70ms +step:271/1705 train_time:25661ms step_avg:94.69ms +step:272/1705 train_time:25753ms step_avg:94.68ms +step:273/1705 train_time:25847ms step_avg:94.68ms +step:274/1705 train_time:25940ms step_avg:94.67ms +step:275/1705 train_time:26033ms step_avg:94.67ms +step:276/1705 train_time:26126ms step_avg:94.66ms +step:277/1705 train_time:26219ms step_avg:94.65ms +step:278/1705 train_time:26311ms step_avg:94.65ms +step:279/1705 train_time:26405ms step_avg:94.64ms +step:280/1705 train_time:26497ms step_avg:94.63ms +step:281/1705 train_time:26589ms step_avg:94.62ms +step:282/1705 train_time:26682ms step_avg:94.62ms +step:283/1705 train_time:26776ms step_avg:94.61ms +step:284/1705 train_time:26868ms 
step_avg:94.61ms +step:285/1705 train_time:26962ms step_avg:94.60ms +step:286/1705 train_time:27056ms step_avg:94.60ms +step:287/1705 train_time:27148ms step_avg:94.59ms +step:288/1705 train_time:27242ms step_avg:94.59ms +step:289/1705 train_time:27334ms step_avg:94.58ms +step:290/1705 train_time:27427ms step_avg:94.57ms +step:291/1705 train_time:27519ms step_avg:94.57ms +step:292/1705 train_time:27611ms step_avg:94.56ms +step:293/1705 train_time:27704ms step_avg:94.55ms +step:294/1705 train_time:27797ms step_avg:94.55ms +step:295/1705 train_time:27889ms step_avg:94.54ms +step:296/1705 train_time:27983ms step_avg:94.54ms +step:297/1705 train_time:28076ms step_avg:94.53ms +step:298/1705 train_time:28168ms step_avg:94.53ms +step:299/1705 train_time:28261ms step_avg:94.52ms +step:300/1705 train_time:28355ms step_avg:94.52ms +step:301/1705 train_time:28447ms step_avg:94.51ms +step:302/1705 train_time:28540ms step_avg:94.50ms +step:303/1705 train_time:28632ms step_avg:94.50ms +step:304/1705 train_time:28725ms step_avg:94.49ms +step:305/1705 train_time:28818ms step_avg:94.49ms +step:306/1705 train_time:28910ms step_avg:94.48ms +step:307/1705 train_time:29004ms step_avg:94.47ms +step:308/1705 train_time:29097ms step_avg:94.47ms +step:309/1705 train_time:29189ms step_avg:94.46ms +step:310/1705 train_time:29282ms step_avg:94.46ms +step:311/1705 train_time:29375ms step_avg:94.45ms +step:312/1705 train_time:29468ms step_avg:94.45ms +step:313/1705 train_time:29561ms step_avg:94.44ms +step:314/1705 train_time:29653ms step_avg:94.44ms +step:315/1705 train_time:29746ms step_avg:94.43ms +step:316/1705 train_time:29839ms step_avg:94.43ms +step:317/1705 train_time:29931ms step_avg:94.42ms +step:318/1705 train_time:30025ms step_avg:94.42ms +step:319/1705 train_time:30118ms step_avg:94.42ms +step:320/1705 train_time:30211ms step_avg:94.41ms +step:321/1705 train_time:30304ms step_avg:94.41ms +step:322/1705 train_time:30397ms step_avg:94.40ms +step:323/1705 train_time:30489ms step_avg:94.39ms +step:324/1705 train_time:30582ms step_avg:94.39ms +step:325/1705 train_time:30675ms step_avg:94.38ms +step:326/1705 train_time:30767ms step_avg:94.38ms +step:327/1705 train_time:30860ms step_avg:94.37ms +step:328/1705 train_time:30952ms step_avg:94.37ms +step:329/1705 train_time:31045ms step_avg:94.36ms +step:330/1705 train_time:31138ms step_avg:94.36ms +step:331/1705 train_time:31231ms step_avg:94.35ms +step:332/1705 train_time:31324ms step_avg:94.35ms +step:333/1705 train_time:31417ms step_avg:94.35ms +step:334/1705 train_time:31510ms step_avg:94.34ms +step:335/1705 train_time:31602ms step_avg:94.34ms +step:336/1705 train_time:31695ms step_avg:94.33ms +step:337/1705 train_time:31788ms step_avg:94.33ms +step:338/1705 train_time:31881ms step_avg:94.32ms +step:339/1705 train_time:31974ms step_avg:94.32ms +step:340/1705 train_time:32067ms step_avg:94.31ms +step:341/1705 train_time:32160ms step_avg:94.31ms +step:342/1705 train_time:32253ms step_avg:94.31ms +step:343/1705 train_time:32345ms step_avg:94.30ms +step:344/1705 train_time:32439ms step_avg:94.30ms +step:345/1705 train_time:32531ms step_avg:94.29ms +step:346/1705 train_time:32625ms step_avg:94.29ms +step:347/1705 train_time:32718ms step_avg:94.29ms +step:348/1705 train_time:32810ms step_avg:94.28ms +step:349/1705 train_time:32903ms step_avg:94.28ms +step:350/1705 train_time:32996ms step_avg:94.27ms +step:351/1705 train_time:33088ms step_avg:94.27ms +step:352/1705 train_time:33182ms step_avg:94.27ms +step:353/1705 train_time:33275ms step_avg:94.26ms +step:354/1705 
train_time:33368ms step_avg:94.26ms +step:355/1705 train_time:33462ms step_avg:94.26ms +step:356/1705 train_time:33554ms step_avg:94.25ms +step:357/1705 train_time:33647ms step_avg:94.25ms +step:358/1705 train_time:33739ms step_avg:94.24ms +step:359/1705 train_time:33831ms step_avg:94.24ms +step:360/1705 train_time:33924ms step_avg:94.23ms +step:361/1705 train_time:34017ms step_avg:94.23ms +step:362/1705 train_time:34110ms step_avg:94.23ms +step:363/1705 train_time:34203ms step_avg:94.22ms +step:364/1705 train_time:34296ms step_avg:94.22ms +step:365/1705 train_time:34389ms step_avg:94.22ms +step:366/1705 train_time:34483ms step_avg:94.22ms +step:367/1705 train_time:34576ms step_avg:94.21ms +step:368/1705 train_time:34668ms step_avg:94.21ms +step:369/1705 train_time:34761ms step_avg:94.20ms +step:370/1705 train_time:34853ms step_avg:94.20ms +step:371/1705 train_time:34946ms step_avg:94.19ms +step:372/1705 train_time:35039ms step_avg:94.19ms +step:373/1705 train_time:35132ms step_avg:94.19ms +step:374/1705 train_time:35225ms step_avg:94.18ms +step:375/1705 train_time:35318ms step_avg:94.18ms +step:375/1705 val_loss:3.8206 train_time:35411ms step_avg:94.43ms +step:376/1705 train_time:35433ms step_avg:94.24ms +step:377/1705 train_time:35512ms step_avg:94.20ms +step:378/1705 train_time:35608ms step_avg:94.20ms +step:379/1705 train_time:35702ms step_avg:94.20ms +step:380/1705 train_time:35795ms step_avg:94.20ms +step:381/1705 train_time:35887ms step_avg:94.19ms +step:382/1705 train_time:35978ms step_avg:94.18ms +step:383/1705 train_time:36070ms step_avg:94.18ms +step:384/1705 train_time:36162ms step_avg:94.17ms +step:385/1705 train_time:36254ms step_avg:94.17ms +step:386/1705 train_time:36346ms step_avg:94.16ms +step:387/1705 train_time:36442ms step_avg:94.16ms +step:388/1705 train_time:36536ms step_avg:94.17ms +step:389/1705 train_time:36630ms step_avg:94.17ms +step:390/1705 train_time:36725ms step_avg:94.17ms +step:391/1705 train_time:36818ms step_avg:94.16ms +step:392/1705 train_time:36910ms step_avg:94.16ms +step:393/1705 train_time:37002ms step_avg:94.15ms +step:394/1705 train_time:37094ms step_avg:94.15ms +step:395/1705 train_time:37186ms step_avg:94.14ms +step:396/1705 train_time:37279ms step_avg:94.14ms +step:397/1705 train_time:37373ms step_avg:94.14ms +step:398/1705 train_time:37466ms step_avg:94.14ms +step:399/1705 train_time:37561ms step_avg:94.14ms +step:400/1705 train_time:37654ms step_avg:94.14ms +step:401/1705 train_time:37747ms step_avg:94.13ms +step:402/1705 train_time:37842ms step_avg:94.13ms +step:403/1705 train_time:37935ms step_avg:94.13ms +step:404/1705 train_time:38027ms step_avg:94.13ms +step:405/1705 train_time:38119ms step_avg:94.12ms +step:406/1705 train_time:38211ms step_avg:94.12ms +step:407/1705 train_time:38304ms step_avg:94.11ms +step:408/1705 train_time:38397ms step_avg:94.11ms +step:409/1705 train_time:38490ms step_avg:94.11ms +step:410/1705 train_time:38583ms step_avg:94.10ms +step:411/1705 train_time:38677ms step_avg:94.10ms +step:412/1705 train_time:38770ms step_avg:94.10ms +step:413/1705 train_time:38863ms step_avg:94.10ms +step:414/1705 train_time:38956ms step_avg:94.10ms +step:415/1705 train_time:39049ms step_avg:94.09ms +step:416/1705 train_time:39142ms step_avg:94.09ms +step:417/1705 train_time:39234ms step_avg:94.09ms +step:418/1705 train_time:39326ms step_avg:94.08ms +step:419/1705 train_time:39419ms step_avg:94.08ms +step:420/1705 train_time:39512ms step_avg:94.08ms +step:421/1705 train_time:39605ms step_avg:94.07ms +step:422/1705 train_time:39698ms 
step_avg:94.07ms +step:423/1705 train_time:39791ms step_avg:94.07ms +step:424/1705 train_time:39884ms step_avg:94.07ms +step:425/1705 train_time:40174ms step_avg:94.53ms +step:426/1705 train_time:40314ms step_avg:94.63ms +step:427/1705 train_time:40405ms step_avg:94.63ms +step:428/1705 train_time:40497ms step_avg:94.62ms +step:429/1705 train_time:40589ms step_avg:94.61ms +step:430/1705 train_time:40681ms step_avg:94.61ms +step:431/1705 train_time:40773ms step_avg:94.60ms +step:432/1705 train_time:40865ms step_avg:94.60ms +step:433/1705 train_time:40957ms step_avg:94.59ms +step:434/1705 train_time:41049ms step_avg:94.58ms +step:435/1705 train_time:41144ms step_avg:94.58ms +step:436/1705 train_time:41239ms step_avg:94.59ms +step:437/1705 train_time:41334ms step_avg:94.59ms +step:438/1705 train_time:41427ms step_avg:94.58ms +step:439/1705 train_time:41520ms step_avg:94.58ms +step:440/1705 train_time:41613ms step_avg:94.57ms +step:441/1705 train_time:41705ms step_avg:94.57ms +step:442/1705 train_time:41797ms step_avg:94.56ms +step:443/1705 train_time:41889ms step_avg:94.56ms +step:444/1705 train_time:41981ms step_avg:94.55ms +step:445/1705 train_time:42074ms step_avg:94.55ms +step:446/1705 train_time:42168ms step_avg:94.55ms +step:447/1705 train_time:42262ms step_avg:94.55ms +step:448/1705 train_time:42356ms step_avg:94.54ms +step:449/1705 train_time:42449ms step_avg:94.54ms +step:450/1705 train_time:42543ms step_avg:94.54ms +step:451/1705 train_time:42637ms step_avg:94.54ms +step:452/1705 train_time:42729ms step_avg:94.53ms +step:453/1705 train_time:42822ms step_avg:94.53ms +step:454/1705 train_time:42914ms step_avg:94.52ms +step:455/1705 train_time:43007ms step_avg:94.52ms +step:456/1705 train_time:43100ms step_avg:94.52ms +step:457/1705 train_time:43192ms step_avg:94.51ms +step:458/1705 train_time:43286ms step_avg:94.51ms +step:459/1705 train_time:43379ms step_avg:94.51ms +step:460/1705 train_time:43472ms step_avg:94.50ms +step:461/1705 train_time:43566ms step_avg:94.50ms +step:462/1705 train_time:43659ms step_avg:94.50ms +step:463/1705 train_time:43751ms step_avg:94.49ms +step:464/1705 train_time:43844ms step_avg:94.49ms +step:465/1705 train_time:43937ms step_avg:94.49ms +step:466/1705 train_time:44029ms step_avg:94.48ms +step:467/1705 train_time:44123ms step_avg:94.48ms +step:468/1705 train_time:44216ms step_avg:94.48ms +step:469/1705 train_time:44308ms step_avg:94.47ms +step:470/1705 train_time:44401ms step_avg:94.47ms +step:471/1705 train_time:44494ms step_avg:94.47ms +step:472/1705 train_time:44587ms step_avg:94.47ms +step:473/1705 train_time:44680ms step_avg:94.46ms +step:474/1705 train_time:44772ms step_avg:94.46ms +step:475/1705 train_time:44865ms step_avg:94.45ms +step:476/1705 train_time:44959ms step_avg:94.45ms +step:477/1705 train_time:45051ms step_avg:94.45ms +step:478/1705 train_time:45144ms step_avg:94.44ms +step:479/1705 train_time:45237ms step_avg:94.44ms +step:480/1705 train_time:45329ms step_avg:94.44ms +step:481/1705 train_time:45422ms step_avg:94.43ms +step:482/1705 train_time:45515ms step_avg:94.43ms +step:483/1705 train_time:45608ms step_avg:94.43ms +step:484/1705 train_time:45702ms step_avg:94.42ms +step:485/1705 train_time:45794ms step_avg:94.42ms +step:486/1705 train_time:45887ms step_avg:94.42ms +step:487/1705 train_time:45980ms step_avg:94.41ms +step:488/1705 train_time:46072ms step_avg:94.41ms +step:489/1705 train_time:46165ms step_avg:94.41ms +step:490/1705 train_time:46258ms step_avg:94.40ms +step:491/1705 train_time:46350ms step_avg:94.40ms +step:492/1705 
train_time:46445ms step_avg:94.40ms +step:493/1705 train_time:46538ms step_avg:94.40ms +step:494/1705 train_time:46631ms step_avg:94.39ms +step:495/1705 train_time:46725ms step_avg:94.39ms +step:496/1705 train_time:46817ms step_avg:94.39ms +step:497/1705 train_time:46910ms step_avg:94.39ms +step:498/1705 train_time:47003ms step_avg:94.38ms +step:499/1705 train_time:47096ms step_avg:94.38ms +step:500/1705 train_time:47188ms step_avg:94.38ms +step:500/1705 val_loss:3.7177 train_time:47281ms step_avg:94.56ms +step:501/1705 train_time:47303ms step_avg:94.42ms +step:502/1705 train_time:47380ms step_avg:94.38ms +step:503/1705 train_time:47479ms step_avg:94.39ms +step:504/1705 train_time:47573ms step_avg:94.39ms +step:505/1705 train_time:47665ms step_avg:94.39ms +step:506/1705 train_time:47757ms step_avg:94.38ms +step:507/1705 train_time:47849ms step_avg:94.38ms +step:508/1705 train_time:47941ms step_avg:94.37ms +step:509/1705 train_time:48033ms step_avg:94.37ms +step:510/1705 train_time:48125ms step_avg:94.36ms +step:511/1705 train_time:48218ms step_avg:94.36ms +step:512/1705 train_time:48314ms step_avg:94.36ms +step:513/1705 train_time:48409ms step_avg:94.36ms +step:514/1705 train_time:48503ms step_avg:94.36ms +step:515/1705 train_time:48597ms step_avg:94.36ms +step:516/1705 train_time:48690ms step_avg:94.36ms +step:517/1705 train_time:48782ms step_avg:94.36ms +step:518/1705 train_time:48875ms step_avg:94.35ms +step:519/1705 train_time:48967ms step_avg:94.35ms +step:520/1705 train_time:49060ms step_avg:94.35ms +step:521/1705 train_time:49152ms step_avg:94.34ms +step:522/1705 train_time:49244ms step_avg:94.34ms +step:523/1705 train_time:49338ms step_avg:94.34ms +step:524/1705 train_time:49433ms step_avg:94.34ms +step:525/1705 train_time:49526ms step_avg:94.33ms +step:526/1705 train_time:49620ms step_avg:94.33ms +step:527/1705 train_time:49713ms step_avg:94.33ms +step:528/1705 train_time:49806ms step_avg:94.33ms +step:529/1705 train_time:49898ms step_avg:94.32ms +step:530/1705 train_time:49991ms step_avg:94.32ms +step:531/1705 train_time:50083ms step_avg:94.32ms +step:532/1705 train_time:50175ms step_avg:94.31ms +step:533/1705 train_time:50268ms step_avg:94.31ms +step:534/1705 train_time:50361ms step_avg:94.31ms +step:535/1705 train_time:50455ms step_avg:94.31ms +step:536/1705 train_time:50549ms step_avg:94.31ms +step:537/1705 train_time:50642ms step_avg:94.30ms +step:538/1705 train_time:50735ms step_avg:94.30ms +step:539/1705 train_time:50828ms step_avg:94.30ms +step:540/1705 train_time:50921ms step_avg:94.30ms +step:541/1705 train_time:51015ms step_avg:94.30ms +step:542/1705 train_time:51106ms step_avg:94.29ms +step:543/1705 train_time:51199ms step_avg:94.29ms +step:544/1705 train_time:51293ms step_avg:94.29ms +step:545/1705 train_time:51386ms step_avg:94.29ms +step:546/1705 train_time:51480ms step_avg:94.29ms +step:547/1705 train_time:51573ms step_avg:94.28ms +step:548/1705 train_time:51666ms step_avg:94.28ms +step:549/1705 train_time:51759ms step_avg:94.28ms +step:550/1705 train_time:51853ms step_avg:94.28ms +step:551/1705 train_time:51945ms step_avg:94.27ms +step:552/1705 train_time:52037ms step_avg:94.27ms +step:553/1705 train_time:52130ms step_avg:94.27ms +step:554/1705 train_time:52222ms step_avg:94.26ms +step:555/1705 train_time:52315ms step_avg:94.26ms +step:556/1705 train_time:52408ms step_avg:94.26ms +step:557/1705 train_time:52501ms step_avg:94.26ms +step:558/1705 train_time:52595ms step_avg:94.26ms +step:559/1705 train_time:52688ms step_avg:94.25ms +step:560/1705 train_time:52781ms 
step_avg:94.25ms +step:561/1705 train_time:52875ms step_avg:94.25ms +step:562/1705 train_time:52968ms step_avg:94.25ms +step:563/1705 train_time:53060ms step_avg:94.25ms +step:564/1705 train_time:53153ms step_avg:94.24ms +step:565/1705 train_time:53245ms step_avg:94.24ms +step:566/1705 train_time:53338ms step_avg:94.24ms +step:567/1705 train_time:53432ms step_avg:94.24ms +step:568/1705 train_time:53525ms step_avg:94.23ms +step:569/1705 train_time:53618ms step_avg:94.23ms +step:570/1705 train_time:53712ms step_avg:94.23ms +step:571/1705 train_time:53806ms step_avg:94.23ms +step:572/1705 train_time:53900ms step_avg:94.23ms +step:573/1705 train_time:53995ms step_avg:94.23ms +step:574/1705 train_time:54089ms step_avg:94.23ms +step:575/1705 train_time:54182ms step_avg:94.23ms +step:576/1705 train_time:54277ms step_avg:94.23ms +step:577/1705 train_time:54371ms step_avg:94.23ms +step:578/1705 train_time:54465ms step_avg:94.23ms +step:579/1705 train_time:54560ms step_avg:94.23ms +step:580/1705 train_time:54655ms step_avg:94.23ms +step:581/1705 train_time:54749ms step_avg:94.23ms +step:582/1705 train_time:54843ms step_avg:94.23ms +step:583/1705 train_time:54938ms step_avg:94.23ms +step:584/1705 train_time:55032ms step_avg:94.23ms +step:585/1705 train_time:55125ms step_avg:94.23ms +step:586/1705 train_time:55219ms step_avg:94.23ms +step:587/1705 train_time:55313ms step_avg:94.23ms +step:588/1705 train_time:55408ms step_avg:94.23ms +step:589/1705 train_time:55502ms step_avg:94.23ms +step:590/1705 train_time:55597ms step_avg:94.23ms +step:591/1705 train_time:55693ms step_avg:94.23ms +step:592/1705 train_time:55786ms step_avg:94.23ms +step:593/1705 train_time:55881ms step_avg:94.23ms +step:594/1705 train_time:55976ms step_avg:94.24ms +step:595/1705 train_time:56070ms step_avg:94.24ms +step:596/1705 train_time:56164ms step_avg:94.23ms +step:597/1705 train_time:56258ms step_avg:94.24ms +step:598/1705 train_time:56353ms step_avg:94.24ms +step:599/1705 train_time:56447ms step_avg:94.24ms +step:600/1705 train_time:56541ms step_avg:94.24ms +step:601/1705 train_time:56637ms step_avg:94.24ms +step:602/1705 train_time:56732ms step_avg:94.24ms +step:603/1705 train_time:56827ms step_avg:94.24ms +step:604/1705 train_time:56921ms step_avg:94.24ms +step:605/1705 train_time:57016ms step_avg:94.24ms +step:606/1705 train_time:57110ms step_avg:94.24ms +step:607/1705 train_time:57203ms step_avg:94.24ms +step:608/1705 train_time:57298ms step_avg:94.24ms +step:609/1705 train_time:57393ms step_avg:94.24ms +step:610/1705 train_time:57488ms step_avg:94.24ms +step:611/1705 train_time:57582ms step_avg:94.24ms +step:612/1705 train_time:57676ms step_avg:94.24ms +step:613/1705 train_time:57770ms step_avg:94.24ms +step:614/1705 train_time:57864ms step_avg:94.24ms +step:615/1705 train_time:57958ms step_avg:94.24ms +step:616/1705 train_time:58052ms step_avg:94.24ms +step:617/1705 train_time:58146ms step_avg:94.24ms +step:618/1705 train_time:58240ms step_avg:94.24ms +step:619/1705 train_time:58335ms step_avg:94.24ms +step:620/1705 train_time:58429ms step_avg:94.24ms +step:621/1705 train_time:58523ms step_avg:94.24ms +step:622/1705 train_time:58617ms step_avg:94.24ms +step:623/1705 train_time:58712ms step_avg:94.24ms +step:624/1705 train_time:58806ms step_avg:94.24ms +step:625/1705 train_time:58900ms step_avg:94.24ms +step:625/1705 val_loss:3.6185 train_time:58995ms step_avg:94.39ms +step:626/1705 train_time:59017ms step_avg:94.28ms +step:627/1705 train_time:59103ms step_avg:94.26ms +step:628/1705 train_time:59201ms step_avg:94.27ms 
+step:629/1705 train_time:59297ms step_avg:94.27ms +step:630/1705 train_time:59390ms step_avg:94.27ms +step:631/1705 train_time:59483ms step_avg:94.27ms +step:632/1705 train_time:59576ms step_avg:94.27ms +step:633/1705 train_time:59669ms step_avg:94.26ms +step:634/1705 train_time:59762ms step_avg:94.26ms +step:635/1705 train_time:59855ms step_avg:94.26ms +step:636/1705 train_time:59950ms step_avg:94.26ms +step:637/1705 train_time:60046ms step_avg:94.26ms +step:638/1705 train_time:60142ms step_avg:94.27ms +step:639/1705 train_time:60412ms step_avg:94.54ms +step:640/1705 train_time:60592ms step_avg:94.67ms +step:641/1705 train_time:60684ms step_avg:94.67ms +step:642/1705 train_time:60778ms step_avg:94.67ms +step:643/1705 train_time:60871ms step_avg:94.67ms +step:644/1705 train_time:60963ms step_avg:94.66ms +step:645/1705 train_time:61057ms step_avg:94.66ms +step:646/1705 train_time:61150ms step_avg:94.66ms +step:647/1705 train_time:61243ms step_avg:94.66ms +step:648/1705 train_time:61336ms step_avg:94.65ms +step:649/1705 train_time:61437ms step_avg:94.66ms +step:650/1705 train_time:61536ms step_avg:94.67ms +step:651/1705 train_time:61632ms step_avg:94.67ms +step:652/1705 train_time:61726ms step_avg:94.67ms +step:653/1705 train_time:61819ms step_avg:94.67ms +step:654/1705 train_time:61913ms step_avg:94.67ms +step:655/1705 train_time:62006ms step_avg:94.67ms +step:656/1705 train_time:62099ms step_avg:94.66ms +step:657/1705 train_time:62193ms step_avg:94.66ms +step:658/1705 train_time:62286ms step_avg:94.66ms +step:659/1705 train_time:62382ms step_avg:94.66ms +step:660/1705 train_time:62478ms step_avg:94.66ms +step:661/1705 train_time:62574ms step_avg:94.67ms +step:662/1705 train_time:62670ms step_avg:94.67ms +step:663/1705 train_time:62763ms step_avg:94.67ms +step:664/1705 train_time:62858ms step_avg:94.67ms +step:665/1705 train_time:62952ms step_avg:94.66ms +step:666/1705 train_time:63045ms step_avg:94.66ms +step:667/1705 train_time:63139ms step_avg:94.66ms +step:668/1705 train_time:63232ms step_avg:94.66ms +step:669/1705 train_time:63327ms step_avg:94.66ms +step:670/1705 train_time:63421ms step_avg:94.66ms +step:671/1705 train_time:63517ms step_avg:94.66ms +step:672/1705 train_time:63612ms step_avg:94.66ms +step:673/1705 train_time:63707ms step_avg:94.66ms +step:674/1705 train_time:63800ms step_avg:94.66ms +step:675/1705 train_time:63895ms step_avg:94.66ms +step:676/1705 train_time:63990ms step_avg:94.66ms +step:677/1705 train_time:64083ms step_avg:94.66ms +step:678/1705 train_time:64177ms step_avg:94.66ms +step:679/1705 train_time:64271ms step_avg:94.65ms +step:680/1705 train_time:64364ms step_avg:94.65ms +step:681/1705 train_time:64459ms step_avg:94.65ms +step:682/1705 train_time:64555ms step_avg:94.66ms +step:683/1705 train_time:64650ms step_avg:94.66ms +step:684/1705 train_time:64744ms step_avg:94.66ms +step:685/1705 train_time:64839ms step_avg:94.66ms +step:686/1705 train_time:64934ms step_avg:94.66ms +step:687/1705 train_time:65027ms step_avg:94.65ms +step:688/1705 train_time:65120ms step_avg:94.65ms +step:689/1705 train_time:65215ms step_avg:94.65ms +step:690/1705 train_time:65309ms step_avg:94.65ms +step:691/1705 train_time:65403ms step_avg:94.65ms +step:692/1705 train_time:65498ms step_avg:94.65ms +step:693/1705 train_time:65593ms step_avg:94.65ms +step:694/1705 train_time:65688ms step_avg:94.65ms +step:695/1705 train_time:65782ms step_avg:94.65ms +step:696/1705 train_time:65876ms step_avg:94.65ms +step:697/1705 train_time:65970ms step_avg:94.65ms +step:698/1705 train_time:66063ms 
step_avg:94.65ms +step:699/1705 train_time:66158ms step_avg:94.65ms +step:700/1705 train_time:66253ms step_avg:94.65ms +step:701/1705 train_time:66346ms step_avg:94.64ms +step:702/1705 train_time:66440ms step_avg:94.64ms +step:703/1705 train_time:66535ms step_avg:94.64ms +step:704/1705 train_time:66630ms step_avg:94.64ms +step:705/1705 train_time:66724ms step_avg:94.64ms +step:706/1705 train_time:66819ms step_avg:94.64ms +step:707/1705 train_time:66914ms step_avg:94.64ms +step:708/1705 train_time:67008ms step_avg:94.64ms +step:709/1705 train_time:67102ms step_avg:94.64ms +step:710/1705 train_time:67196ms step_avg:94.64ms +step:711/1705 train_time:67291ms step_avg:94.64ms +step:712/1705 train_time:67385ms step_avg:94.64ms +step:713/1705 train_time:67479ms step_avg:94.64ms +step:714/1705 train_time:67574ms step_avg:94.64ms +step:715/1705 train_time:67669ms step_avg:94.64ms +step:716/1705 train_time:67763ms step_avg:94.64ms +step:717/1705 train_time:67858ms step_avg:94.64ms +step:718/1705 train_time:67952ms step_avg:94.64ms +step:719/1705 train_time:68047ms step_avg:94.64ms +step:720/1705 train_time:68140ms step_avg:94.64ms +step:721/1705 train_time:68235ms step_avg:94.64ms +step:722/1705 train_time:68329ms step_avg:94.64ms +step:723/1705 train_time:68422ms step_avg:94.64ms +step:724/1705 train_time:68517ms step_avg:94.64ms +step:725/1705 train_time:68611ms step_avg:94.64ms +step:726/1705 train_time:68705ms step_avg:94.63ms +step:727/1705 train_time:68799ms step_avg:94.63ms +step:728/1705 train_time:68894ms step_avg:94.63ms +step:729/1705 train_time:68989ms step_avg:94.63ms +step:730/1705 train_time:69082ms step_avg:94.63ms +step:731/1705 train_time:69177ms step_avg:94.63ms +step:732/1705 train_time:69272ms step_avg:94.63ms +step:733/1705 train_time:69367ms step_avg:94.63ms +step:734/1705 train_time:69461ms step_avg:94.63ms +step:735/1705 train_time:69555ms step_avg:94.63ms +step:736/1705 train_time:69651ms step_avg:94.64ms +step:737/1705 train_time:69745ms step_avg:94.63ms +step:738/1705 train_time:69839ms step_avg:94.63ms +step:739/1705 train_time:69934ms step_avg:94.63ms +step:740/1705 train_time:70029ms step_avg:94.63ms +step:741/1705 train_time:70122ms step_avg:94.63ms +step:742/1705 train_time:70217ms step_avg:94.63ms +step:743/1705 train_time:70311ms step_avg:94.63ms +step:744/1705 train_time:70405ms step_avg:94.63ms +step:745/1705 train_time:70500ms step_avg:94.63ms +step:746/1705 train_time:70596ms step_avg:94.63ms +step:747/1705 train_time:70690ms step_avg:94.63ms +step:748/1705 train_time:70783ms step_avg:94.63ms +step:749/1705 train_time:70878ms step_avg:94.63ms +step:750/1705 train_time:70972ms step_avg:94.63ms +step:750/1705 val_loss:3.5654 train_time:71067ms step_avg:94.76ms +step:751/1705 train_time:71089ms step_avg:94.66ms +step:752/1705 train_time:71168ms step_avg:94.64ms +step:753/1705 train_time:71267ms step_avg:94.64ms +step:754/1705 train_time:71362ms step_avg:94.64ms +step:755/1705 train_time:71455ms step_avg:94.64ms +step:756/1705 train_time:71548ms step_avg:94.64ms +step:757/1705 train_time:71642ms step_avg:94.64ms +step:758/1705 train_time:71735ms step_avg:94.64ms +step:759/1705 train_time:71828ms step_avg:94.64ms +step:760/1705 train_time:71922ms step_avg:94.63ms +step:761/1705 train_time:72017ms step_avg:94.64ms +step:762/1705 train_time:72113ms step_avg:94.64ms +step:763/1705 train_time:72210ms step_avg:94.64ms +step:764/1705 train_time:72306ms step_avg:94.64ms +step:765/1705 train_time:72402ms step_avg:94.64ms +step:766/1705 train_time:72496ms step_avg:94.64ms 
+step:767/1705 train_time:72590ms step_avg:94.64ms +step:768/1705 train_time:72684ms step_avg:94.64ms +step:769/1705 train_time:72777ms step_avg:94.64ms +step:770/1705 train_time:72871ms step_avg:94.64ms +step:771/1705 train_time:72965ms step_avg:94.64ms +step:772/1705 train_time:73060ms step_avg:94.64ms +step:773/1705 train_time:73156ms step_avg:94.64ms +step:774/1705 train_time:73251ms step_avg:94.64ms +step:775/1705 train_time:73347ms step_avg:94.64ms +step:776/1705 train_time:73441ms step_avg:94.64ms +step:777/1705 train_time:73536ms step_avg:94.64ms +step:778/1705 train_time:73629ms step_avg:94.64ms +step:779/1705 train_time:73723ms step_avg:94.64ms +step:780/1705 train_time:73817ms step_avg:94.64ms +step:781/1705 train_time:73910ms step_avg:94.63ms +step:782/1705 train_time:74004ms step_avg:94.63ms +step:783/1705 train_time:74100ms step_avg:94.64ms +step:784/1705 train_time:74196ms step_avg:94.64ms +step:785/1705 train_time:74290ms step_avg:94.64ms +step:786/1705 train_time:74385ms step_avg:94.64ms +step:787/1705 train_time:74480ms step_avg:94.64ms +step:788/1705 train_time:74575ms step_avg:94.64ms +step:789/1705 train_time:74668ms step_avg:94.64ms +step:790/1705 train_time:74762ms step_avg:94.64ms +step:791/1705 train_time:74857ms step_avg:94.64ms +step:792/1705 train_time:74950ms step_avg:94.63ms +step:793/1705 train_time:75045ms step_avg:94.63ms +step:794/1705 train_time:75140ms step_avg:94.63ms +step:795/1705 train_time:75235ms step_avg:94.63ms +step:796/1705 train_time:75329ms step_avg:94.63ms +step:797/1705 train_time:75424ms step_avg:94.63ms +step:798/1705 train_time:75519ms step_avg:94.64ms +step:799/1705 train_time:75613ms step_avg:94.63ms +step:800/1705 train_time:75707ms step_avg:94.63ms +step:801/1705 train_time:75801ms step_avg:94.63ms +step:802/1705 train_time:75895ms step_avg:94.63ms +step:803/1705 train_time:75989ms step_avg:94.63ms +step:804/1705 train_time:76084ms step_avg:94.63ms +step:805/1705 train_time:76180ms step_avg:94.63ms +step:806/1705 train_time:76276ms step_avg:94.63ms +step:807/1705 train_time:76370ms step_avg:94.63ms +step:808/1705 train_time:76465ms step_avg:94.63ms +step:809/1705 train_time:76560ms step_avg:94.64ms +step:810/1705 train_time:76654ms step_avg:94.63ms +step:811/1705 train_time:76748ms step_avg:94.63ms +step:812/1705 train_time:76842ms step_avg:94.63ms +step:813/1705 train_time:76936ms step_avg:94.63ms +step:814/1705 train_time:77030ms step_avg:94.63ms +step:815/1705 train_time:77124ms step_avg:94.63ms +step:816/1705 train_time:77220ms step_avg:94.63ms +step:817/1705 train_time:77315ms step_avg:94.63ms +step:818/1705 train_time:77409ms step_avg:94.63ms +step:819/1705 train_time:77504ms step_avg:94.63ms +step:820/1705 train_time:77599ms step_avg:94.63ms +step:821/1705 train_time:77693ms step_avg:94.63ms +step:822/1705 train_time:77788ms step_avg:94.63ms +step:823/1705 train_time:77882ms step_avg:94.63ms +step:824/1705 train_time:77977ms step_avg:94.63ms +step:825/1705 train_time:78070ms step_avg:94.63ms +step:826/1705 train_time:78165ms step_avg:94.63ms +step:827/1705 train_time:78260ms step_avg:94.63ms +step:828/1705 train_time:78354ms step_avg:94.63ms +step:829/1705 train_time:78448ms step_avg:94.63ms +step:830/1705 train_time:78543ms step_avg:94.63ms +step:831/1705 train_time:78638ms step_avg:94.63ms +step:832/1705 train_time:78731ms step_avg:94.63ms +step:833/1705 train_time:78825ms step_avg:94.63ms +step:834/1705 train_time:78921ms step_avg:94.63ms +step:835/1705 train_time:79015ms step_avg:94.63ms +step:836/1705 train_time:79109ms 
step_avg:94.63ms +step:837/1705 train_time:79204ms step_avg:94.63ms +step:838/1705 train_time:79298ms step_avg:94.63ms +step:839/1705 train_time:79393ms step_avg:94.63ms +step:840/1705 train_time:79487ms step_avg:94.63ms +step:841/1705 train_time:79581ms step_avg:94.63ms +step:842/1705 train_time:79675ms step_avg:94.63ms +step:843/1705 train_time:79769ms step_avg:94.63ms +step:844/1705 train_time:79864ms step_avg:94.63ms +step:845/1705 train_time:79959ms step_avg:94.63ms +step:846/1705 train_time:80053ms step_avg:94.63ms +step:847/1705 train_time:80148ms step_avg:94.63ms +step:848/1705 train_time:80242ms step_avg:94.63ms +step:849/1705 train_time:80336ms step_avg:94.62ms +step:850/1705 train_time:80431ms step_avg:94.62ms +step:851/1705 train_time:80698ms step_avg:94.83ms +step:852/1705 train_time:80879ms step_avg:94.93ms +step:853/1705 train_time:80971ms step_avg:94.92ms +step:854/1705 train_time:81065ms step_avg:94.92ms +step:855/1705 train_time:81159ms step_avg:94.92ms +step:856/1705 train_time:81252ms step_avg:94.92ms +step:857/1705 train_time:81346ms step_avg:94.92ms +step:858/1705 train_time:81439ms step_avg:94.92ms +step:859/1705 train_time:81532ms step_avg:94.91ms +step:860/1705 train_time:81625ms step_avg:94.91ms +step:861/1705 train_time:81722ms step_avg:94.92ms +step:862/1705 train_time:81819ms step_avg:94.92ms +step:863/1705 train_time:81915ms step_avg:94.92ms +step:864/1705 train_time:82009ms step_avg:94.92ms +step:865/1705 train_time:82103ms step_avg:94.92ms +step:866/1705 train_time:82198ms step_avg:94.92ms +step:867/1705 train_time:82291ms step_avg:94.91ms +step:868/1705 train_time:82384ms step_avg:94.91ms +step:869/1705 train_time:82479ms step_avg:94.91ms +step:870/1705 train_time:82573ms step_avg:94.91ms +step:871/1705 train_time:82666ms step_avg:94.91ms +step:872/1705 train_time:82762ms step_avg:94.91ms +step:873/1705 train_time:82859ms step_avg:94.91ms +step:874/1705 train_time:82953ms step_avg:94.91ms +step:875/1705 train_time:83047ms step_avg:94.91ms +step:875/1705 val_loss:3.5228 train_time:83143ms step_avg:95.02ms +step:876/1705 train_time:83164ms step_avg:94.94ms +step:877/1705 train_time:83242ms step_avg:94.92ms +step:878/1705 train_time:83341ms step_avg:94.92ms +step:879/1705 train_time:83437ms step_avg:94.92ms +step:880/1705 train_time:83529ms step_avg:94.92ms +step:881/1705 train_time:83623ms step_avg:94.92ms +step:882/1705 train_time:83716ms step_avg:94.92ms +step:883/1705 train_time:83809ms step_avg:94.91ms +step:884/1705 train_time:83902ms step_avg:94.91ms +step:885/1705 train_time:83996ms step_avg:94.91ms +step:886/1705 train_time:84090ms step_avg:94.91ms +step:887/1705 train_time:84187ms step_avg:94.91ms +step:888/1705 train_time:84284ms step_avg:94.91ms +step:889/1705 train_time:84380ms step_avg:94.92ms +step:890/1705 train_time:84475ms step_avg:94.92ms +step:891/1705 train_time:84569ms step_avg:94.91ms +step:892/1705 train_time:84664ms step_avg:94.91ms +step:893/1705 train_time:84758ms step_avg:94.91ms +step:894/1705 train_time:84850ms step_avg:94.91ms +step:895/1705 train_time:84944ms step_avg:94.91ms +step:896/1705 train_time:85040ms step_avg:94.91ms +step:897/1705 train_time:85136ms step_avg:94.91ms +step:898/1705 train_time:85230ms step_avg:94.91ms +step:899/1705 train_time:85325ms step_avg:94.91ms +step:900/1705 train_time:85421ms step_avg:94.91ms +step:901/1705 train_time:85517ms step_avg:94.91ms +step:902/1705 train_time:85610ms step_avg:94.91ms +step:903/1705 train_time:85704ms step_avg:94.91ms +step:904/1705 train_time:85798ms step_avg:94.91ms 
+step:905/1705 train_time:85891ms step_avg:94.91ms +step:906/1705 train_time:85985ms step_avg:94.91ms +step:907/1705 train_time:86080ms step_avg:94.91ms +step:908/1705 train_time:86175ms step_avg:94.91ms +step:909/1705 train_time:86269ms step_avg:94.91ms +step:910/1705 train_time:86366ms step_avg:94.91ms +step:911/1705 train_time:86462ms step_avg:94.91ms +step:912/1705 train_time:86557ms step_avg:94.91ms +step:913/1705 train_time:86650ms step_avg:94.91ms +step:914/1705 train_time:86745ms step_avg:94.91ms +step:915/1705 train_time:86840ms step_avg:94.91ms +step:916/1705 train_time:86933ms step_avg:94.91ms +step:917/1705 train_time:87027ms step_avg:94.90ms +step:918/1705 train_time:87123ms step_avg:94.90ms +step:919/1705 train_time:87218ms step_avg:94.91ms +step:920/1705 train_time:87313ms step_avg:94.91ms +step:921/1705 train_time:87407ms step_avg:94.90ms +step:922/1705 train_time:87503ms step_avg:94.91ms +step:923/1705 train_time:87597ms step_avg:94.90ms +step:924/1705 train_time:87691ms step_avg:94.90ms +step:925/1705 train_time:87785ms step_avg:94.90ms +step:926/1705 train_time:87880ms step_avg:94.90ms +step:927/1705 train_time:87974ms step_avg:94.90ms +step:928/1705 train_time:88069ms step_avg:94.90ms +step:929/1705 train_time:88163ms step_avg:94.90ms +step:930/1705 train_time:88258ms step_avg:94.90ms +step:931/1705 train_time:88352ms step_avg:94.90ms +step:932/1705 train_time:88446ms step_avg:94.90ms +step:933/1705 train_time:88542ms step_avg:94.90ms +step:934/1705 train_time:88637ms step_avg:94.90ms +step:935/1705 train_time:88731ms step_avg:94.90ms +step:936/1705 train_time:88826ms step_avg:94.90ms +step:937/1705 train_time:88920ms step_avg:94.90ms +step:938/1705 train_time:89015ms step_avg:94.90ms +step:939/1705 train_time:89109ms step_avg:94.90ms +step:940/1705 train_time:89204ms step_avg:94.90ms +step:941/1705 train_time:89299ms step_avg:94.90ms +step:942/1705 train_time:89392ms step_avg:94.90ms +step:943/1705 train_time:89486ms step_avg:94.90ms +step:944/1705 train_time:89581ms step_avg:94.90ms +step:945/1705 train_time:89676ms step_avg:94.90ms +step:946/1705 train_time:89770ms step_avg:94.89ms +step:947/1705 train_time:89864ms step_avg:94.89ms +step:948/1705 train_time:89960ms step_avg:94.89ms +step:949/1705 train_time:90054ms step_avg:94.89ms +step:950/1705 train_time:90148ms step_avg:94.89ms +step:951/1705 train_time:90243ms step_avg:94.89ms +step:952/1705 train_time:90338ms step_avg:94.89ms +step:953/1705 train_time:90433ms step_avg:94.89ms +step:954/1705 train_time:90527ms step_avg:94.89ms +step:955/1705 train_time:90621ms step_avg:94.89ms +step:956/1705 train_time:90716ms step_avg:94.89ms +step:957/1705 train_time:90810ms step_avg:94.89ms +step:958/1705 train_time:90904ms step_avg:94.89ms +step:959/1705 train_time:90998ms step_avg:94.89ms +step:960/1705 train_time:91092ms step_avg:94.89ms +step:961/1705 train_time:91187ms step_avg:94.89ms +step:962/1705 train_time:91282ms step_avg:94.89ms +step:963/1705 train_time:91376ms step_avg:94.89ms +step:964/1705 train_time:91471ms step_avg:94.89ms +step:965/1705 train_time:91565ms step_avg:94.89ms +step:966/1705 train_time:91661ms step_avg:94.89ms +step:967/1705 train_time:91755ms step_avg:94.89ms +step:968/1705 train_time:91848ms step_avg:94.88ms +step:969/1705 train_time:91943ms step_avg:94.88ms +step:970/1705 train_time:92038ms step_avg:94.88ms +step:971/1705 train_time:92133ms step_avg:94.88ms +step:972/1705 train_time:92227ms step_avg:94.88ms +step:973/1705 train_time:92322ms step_avg:94.88ms +step:974/1705 train_time:92417ms 
step_avg:94.88ms +step:975/1705 train_time:92511ms step_avg:94.88ms +step:976/1705 train_time:92605ms step_avg:94.88ms +step:977/1705 train_time:92700ms step_avg:94.88ms +step:978/1705 train_time:92794ms step_avg:94.88ms +step:979/1705 train_time:92888ms step_avg:94.88ms +step:980/1705 train_time:92983ms step_avg:94.88ms +step:981/1705 train_time:93077ms step_avg:94.88ms +step:982/1705 train_time:93171ms step_avg:94.88ms +step:983/1705 train_time:93266ms step_avg:94.88ms +step:984/1705 train_time:93362ms step_avg:94.88ms +step:985/1705 train_time:93457ms step_avg:94.88ms +step:986/1705 train_time:93551ms step_avg:94.88ms +step:987/1705 train_time:93646ms step_avg:94.88ms +step:988/1705 train_time:93740ms step_avg:94.88ms +step:989/1705 train_time:93834ms step_avg:94.88ms +step:990/1705 train_time:93928ms step_avg:94.88ms +step:991/1705 train_time:94024ms step_avg:94.88ms +step:992/1705 train_time:94118ms step_avg:94.88ms +step:993/1705 train_time:94212ms step_avg:94.88ms +step:994/1705 train_time:94307ms step_avg:94.88ms +step:995/1705 train_time:94402ms step_avg:94.88ms +step:996/1705 train_time:94498ms step_avg:94.88ms +step:997/1705 train_time:94591ms step_avg:94.88ms +step:998/1705 train_time:94685ms step_avg:94.87ms +step:999/1705 train_time:94779ms step_avg:94.87ms +step:1000/1705 train_time:94874ms step_avg:94.87ms +step:1000/1705 val_loss:3.4836 train_time:94968ms step_avg:94.97ms +step:1001/1705 train_time:94989ms step_avg:94.89ms +step:1002/1705 train_time:95069ms step_avg:94.88ms +step:1003/1705 train_time:95166ms step_avg:94.88ms +step:1004/1705 train_time:95263ms step_avg:94.88ms +step:1005/1705 train_time:95356ms step_avg:94.88ms +step:1006/1705 train_time:95450ms step_avg:94.88ms +step:1007/1705 train_time:95543ms step_avg:94.88ms +step:1008/1705 train_time:95637ms step_avg:94.88ms +step:1009/1705 train_time:95730ms step_avg:94.88ms +step:1010/1705 train_time:95824ms step_avg:94.88ms +step:1011/1705 train_time:95920ms step_avg:94.88ms +step:1012/1705 train_time:96017ms step_avg:94.88ms +step:1013/1705 train_time:96114ms step_avg:94.88ms +step:1014/1705 train_time:96210ms step_avg:94.88ms +step:1015/1705 train_time:96305ms step_avg:94.88ms +step:1016/1705 train_time:96399ms step_avg:94.88ms +step:1017/1705 train_time:96494ms step_avg:94.88ms +step:1018/1705 train_time:96587ms step_avg:94.88ms +step:1019/1705 train_time:96681ms step_avg:94.88ms +step:1020/1705 train_time:96776ms step_avg:94.88ms +step:1021/1705 train_time:96871ms step_avg:94.88ms +step:1022/1705 train_time:96965ms step_avg:94.88ms +step:1023/1705 train_time:97061ms step_avg:94.88ms +step:1024/1705 train_time:97157ms step_avg:94.88ms +step:1025/1705 train_time:97253ms step_avg:94.88ms +step:1026/1705 train_time:97348ms step_avg:94.88ms +step:1027/1705 train_time:97442ms step_avg:94.88ms +step:1028/1705 train_time:97536ms step_avg:94.88ms +step:1029/1705 train_time:97630ms step_avg:94.88ms +step:1030/1705 train_time:97724ms step_avg:94.88ms +step:1031/1705 train_time:97820ms step_avg:94.88ms +step:1032/1705 train_time:97914ms step_avg:94.88ms +step:1033/1705 train_time:98008ms step_avg:94.88ms +step:1034/1705 train_time:98103ms step_avg:94.88ms +step:1035/1705 train_time:98199ms step_avg:94.88ms +step:1036/1705 train_time:98294ms step_avg:94.88ms +step:1037/1705 train_time:98389ms step_avg:94.88ms +step:1038/1705 train_time:98483ms step_avg:94.88ms +step:1039/1705 train_time:98578ms step_avg:94.88ms +step:1040/1705 train_time:98673ms step_avg:94.88ms +step:1041/1705 train_time:98766ms step_avg:94.88ms 
+step:1042/1705 train_time:98861ms step_avg:94.88ms +step:1043/1705 train_time:98956ms step_avg:94.88ms +step:1044/1705 train_time:99052ms step_avg:94.88ms +step:1045/1705 train_time:99147ms step_avg:94.88ms +step:1046/1705 train_time:99242ms step_avg:94.88ms +step:1047/1705 train_time:99336ms step_avg:94.88ms +step:1048/1705 train_time:99431ms step_avg:94.88ms +step:1049/1705 train_time:99525ms step_avg:94.88ms +step:1050/1705 train_time:99619ms step_avg:94.87ms +step:1051/1705 train_time:99714ms step_avg:94.87ms +step:1052/1705 train_time:99807ms step_avg:94.87ms +step:1053/1705 train_time:99902ms step_avg:94.87ms +step:1054/1705 train_time:99997ms step_avg:94.87ms +step:1055/1705 train_time:100092ms step_avg:94.87ms +step:1056/1705 train_time:100185ms step_avg:94.87ms +step:1057/1705 train_time:100281ms step_avg:94.87ms +step:1058/1705 train_time:100377ms step_avg:94.87ms +step:1059/1705 train_time:100471ms step_avg:94.87ms +step:1060/1705 train_time:100565ms step_avg:94.87ms +step:1061/1705 train_time:100660ms step_avg:94.87ms +step:1062/1705 train_time:100930ms step_avg:95.04ms +step:1063/1705 train_time:101033ms step_avg:95.04ms +step:1064/1705 train_time:101125ms step_avg:95.04ms +step:1065/1705 train_time:101218ms step_avg:95.04ms +step:1066/1705 train_time:101311ms step_avg:95.04ms +step:1067/1705 train_time:101404ms step_avg:95.04ms +step:1068/1705 train_time:101498ms step_avg:95.04ms +step:1069/1705 train_time:101592ms step_avg:95.03ms +step:1070/1705 train_time:101685ms step_avg:95.03ms +step:1071/1705 train_time:101778ms step_avg:95.03ms +step:1072/1705 train_time:101877ms step_avg:95.03ms +step:1073/1705 train_time:101975ms step_avg:95.04ms +step:1074/1705 train_time:102073ms step_avg:95.04ms +step:1075/1705 train_time:102166ms step_avg:95.04ms +step:1076/1705 train_time:102261ms step_avg:95.04ms +step:1077/1705 train_time:102356ms step_avg:95.04ms +step:1078/1705 train_time:102450ms step_avg:95.04ms +step:1079/1705 train_time:102543ms step_avg:95.04ms +step:1080/1705 train_time:102637ms step_avg:95.03ms +step:1081/1705 train_time:102730ms step_avg:95.03ms +step:1082/1705 train_time:102824ms step_avg:95.03ms +step:1083/1705 train_time:102921ms step_avg:95.03ms +step:1084/1705 train_time:103017ms step_avg:95.03ms +step:1085/1705 train_time:103112ms step_avg:95.03ms +step:1086/1705 train_time:103207ms step_avg:95.03ms +step:1087/1705 train_time:103302ms step_avg:95.03ms +step:1088/1705 train_time:103397ms step_avg:95.03ms +step:1089/1705 train_time:103491ms step_avg:95.03ms +step:1090/1705 train_time:103585ms step_avg:95.03ms +step:1091/1705 train_time:103679ms step_avg:95.03ms +step:1092/1705 train_time:103775ms step_avg:95.03ms +step:1093/1705 train_time:103869ms step_avg:95.03ms +step:1094/1705 train_time:103964ms step_avg:95.03ms +step:1095/1705 train_time:104059ms step_avg:95.03ms +step:1096/1705 train_time:104155ms step_avg:95.03ms +step:1097/1705 train_time:104250ms step_avg:95.03ms +step:1098/1705 train_time:104344ms step_avg:95.03ms +step:1099/1705 train_time:104439ms step_avg:95.03ms +step:1100/1705 train_time:104533ms step_avg:95.03ms +step:1101/1705 train_time:104626ms step_avg:95.03ms +step:1102/1705 train_time:104721ms step_avg:95.03ms +step:1103/1705 train_time:104816ms step_avg:95.03ms +step:1104/1705 train_time:104910ms step_avg:95.03ms +step:1105/1705 train_time:105004ms step_avg:95.03ms +step:1106/1705 train_time:105100ms step_avg:95.03ms +step:1107/1705 train_time:105195ms step_avg:95.03ms +step:1108/1705 train_time:105289ms step_avg:95.03ms +step:1109/1705 
train_time:105383ms step_avg:95.03ms +step:1110/1705 train_time:105478ms step_avg:95.02ms +step:1111/1705 train_time:105572ms step_avg:95.02ms +step:1112/1705 train_time:105666ms step_avg:95.02ms +step:1113/1705 train_time:105760ms step_avg:95.02ms +step:1114/1705 train_time:105855ms step_avg:95.02ms +step:1115/1705 train_time:105950ms step_avg:95.02ms +step:1116/1705 train_time:106045ms step_avg:95.02ms +step:1117/1705 train_time:106140ms step_avg:95.02ms +step:1118/1705 train_time:106235ms step_avg:95.02ms +step:1119/1705 train_time:106329ms step_avg:95.02ms +step:1120/1705 train_time:106423ms step_avg:95.02ms +step:1121/1705 train_time:106518ms step_avg:95.02ms +step:1122/1705 train_time:106611ms step_avg:95.02ms +step:1123/1705 train_time:106705ms step_avg:95.02ms +step:1124/1705 train_time:106800ms step_avg:95.02ms +step:1125/1705 train_time:106895ms step_avg:95.02ms +step:1125/1705 val_loss:3.4373 train_time:106990ms step_avg:95.10ms +step:1126/1705 train_time:107012ms step_avg:95.04ms +step:1127/1705 train_time:107091ms step_avg:95.02ms +step:1128/1705 train_time:107191ms step_avg:95.03ms +step:1129/1705 train_time:107287ms step_avg:95.03ms +step:1130/1705 train_time:107381ms step_avg:95.03ms +step:1131/1705 train_time:107475ms step_avg:95.03ms +step:1132/1705 train_time:107568ms step_avg:95.02ms +step:1133/1705 train_time:107662ms step_avg:95.02ms +step:1134/1705 train_time:107755ms step_avg:95.02ms +step:1135/1705 train_time:107848ms step_avg:95.02ms +step:1136/1705 train_time:107943ms step_avg:95.02ms +step:1137/1705 train_time:108039ms step_avg:95.02ms +step:1138/1705 train_time:108137ms step_avg:95.02ms +step:1139/1705 train_time:108233ms step_avg:95.02ms +step:1140/1705 train_time:108328ms step_avg:95.02ms +step:1141/1705 train_time:108424ms step_avg:95.03ms +step:1142/1705 train_time:108518ms step_avg:95.02ms +step:1143/1705 train_time:108612ms step_avg:95.02ms +step:1144/1705 train_time:108708ms step_avg:95.02ms +step:1145/1705 train_time:108802ms step_avg:95.02ms +step:1146/1705 train_time:108897ms step_avg:95.02ms +step:1147/1705 train_time:108992ms step_avg:95.02ms +step:1148/1705 train_time:109088ms step_avg:95.02ms +step:1149/1705 train_time:109184ms step_avg:95.03ms +step:1150/1705 train_time:109281ms step_avg:95.03ms +step:1151/1705 train_time:109376ms step_avg:95.03ms +step:1152/1705 train_time:109470ms step_avg:95.03ms +step:1153/1705 train_time:109565ms step_avg:95.03ms +step:1154/1705 train_time:109660ms step_avg:95.03ms +step:1155/1705 train_time:109754ms step_avg:95.03ms +step:1156/1705 train_time:109850ms step_avg:95.03ms +step:1157/1705 train_time:109945ms step_avg:95.03ms +step:1158/1705 train_time:110042ms step_avg:95.03ms +step:1159/1705 train_time:110137ms step_avg:95.03ms +step:1160/1705 train_time:110233ms step_avg:95.03ms +step:1161/1705 train_time:110329ms step_avg:95.03ms +step:1162/1705 train_time:110425ms step_avg:95.03ms +step:1163/1705 train_time:110520ms step_avg:95.03ms +step:1164/1705 train_time:110615ms step_avg:95.03ms +step:1165/1705 train_time:110709ms step_avg:95.03ms +step:1166/1705 train_time:110805ms step_avg:95.03ms +step:1167/1705 train_time:110900ms step_avg:95.03ms +step:1168/1705 train_time:110995ms step_avg:95.03ms +step:1169/1705 train_time:111091ms step_avg:95.03ms +step:1170/1705 train_time:111187ms step_avg:95.03ms +step:1171/1705 train_time:111283ms step_avg:95.03ms +step:1172/1705 train_time:111380ms step_avg:95.03ms +step:1173/1705 train_time:111475ms step_avg:95.03ms +step:1174/1705 train_time:111569ms step_avg:95.03ms 
+step:1175/1705 train_time:111664ms step_avg:95.03ms +step:1176/1705 train_time:111760ms step_avg:95.03ms +step:1177/1705 train_time:111855ms step_avg:95.03ms +step:1178/1705 train_time:111950ms step_avg:95.03ms +step:1179/1705 train_time:112045ms step_avg:95.03ms +step:1180/1705 train_time:112140ms step_avg:95.03ms +step:1181/1705 train_time:112236ms step_avg:95.04ms +step:1182/1705 train_time:112331ms step_avg:95.03ms +step:1183/1705 train_time:112427ms step_avg:95.04ms +step:1184/1705 train_time:112523ms step_avg:95.04ms +step:1185/1705 train_time:112618ms step_avg:95.04ms +step:1186/1705 train_time:112713ms step_avg:95.04ms +step:1187/1705 train_time:112808ms step_avg:95.04ms +step:1188/1705 train_time:112903ms step_avg:95.04ms +step:1189/1705 train_time:112999ms step_avg:95.04ms +step:1190/1705 train_time:113094ms step_avg:95.04ms +step:1191/1705 train_time:113189ms step_avg:95.04ms +step:1192/1705 train_time:113285ms step_avg:95.04ms +step:1193/1705 train_time:113380ms step_avg:95.04ms +step:1194/1705 train_time:113476ms step_avg:95.04ms +step:1195/1705 train_time:113570ms step_avg:95.04ms +step:1196/1705 train_time:113665ms step_avg:95.04ms +step:1197/1705 train_time:113761ms step_avg:95.04ms +step:1198/1705 train_time:113856ms step_avg:95.04ms +step:1199/1705 train_time:113950ms step_avg:95.04ms +step:1200/1705 train_time:114046ms step_avg:95.04ms +step:1201/1705 train_time:114142ms step_avg:95.04ms +step:1202/1705 train_time:114237ms step_avg:95.04ms +step:1203/1705 train_time:114332ms step_avg:95.04ms +step:1204/1705 train_time:114427ms step_avg:95.04ms +step:1205/1705 train_time:114523ms step_avg:95.04ms +step:1206/1705 train_time:114620ms step_avg:95.04ms +step:1207/1705 train_time:114715ms step_avg:95.04ms +step:1208/1705 train_time:114810ms step_avg:95.04ms +step:1209/1705 train_time:114905ms step_avg:95.04ms +step:1210/1705 train_time:115001ms step_avg:95.04ms +step:1211/1705 train_time:115096ms step_avg:95.04ms +step:1212/1705 train_time:115191ms step_avg:95.04ms +step:1213/1705 train_time:115286ms step_avg:95.04ms +step:1214/1705 train_time:115381ms step_avg:95.04ms +step:1215/1705 train_time:115477ms step_avg:95.04ms +step:1216/1705 train_time:115571ms step_avg:95.04ms +step:1217/1705 train_time:115667ms step_avg:95.04ms +step:1218/1705 train_time:115764ms step_avg:95.04ms +step:1219/1705 train_time:115860ms step_avg:95.05ms +step:1220/1705 train_time:115956ms step_avg:95.05ms +step:1221/1705 train_time:116051ms step_avg:95.05ms +step:1222/1705 train_time:116146ms step_avg:95.05ms +step:1223/1705 train_time:116242ms step_avg:95.05ms +step:1224/1705 train_time:116336ms step_avg:95.05ms +step:1225/1705 train_time:116431ms step_avg:95.05ms +step:1226/1705 train_time:116527ms step_avg:95.05ms +step:1227/1705 train_time:116622ms step_avg:95.05ms +step:1228/1705 train_time:116718ms step_avg:95.05ms +step:1229/1705 train_time:116813ms step_avg:95.05ms +step:1230/1705 train_time:116908ms step_avg:95.05ms +step:1231/1705 train_time:117004ms step_avg:95.05ms +step:1232/1705 train_time:117100ms step_avg:95.05ms +step:1233/1705 train_time:117195ms step_avg:95.05ms +step:1234/1705 train_time:117289ms step_avg:95.05ms +step:1235/1705 train_time:117385ms step_avg:95.05ms +step:1236/1705 train_time:117481ms step_avg:95.05ms +step:1237/1705 train_time:117576ms step_avg:95.05ms +step:1238/1705 train_time:117672ms step_avg:95.05ms +step:1239/1705 train_time:117768ms step_avg:95.05ms +step:1240/1705 train_time:117863ms step_avg:95.05ms +step:1241/1705 train_time:117958ms step_avg:95.05ms 
+step:1242/1705 train_time:118052ms step_avg:95.05ms +step:1243/1705 train_time:118147ms step_avg:95.05ms +step:1244/1705 train_time:118243ms step_avg:95.05ms +step:1245/1705 train_time:118338ms step_avg:95.05ms +step:1246/1705 train_time:118433ms step_avg:95.05ms +step:1247/1705 train_time:118529ms step_avg:95.05ms +step:1248/1705 train_time:118625ms step_avg:95.05ms +step:1249/1705 train_time:118720ms step_avg:95.05ms +step:1250/1705 train_time:118815ms step_avg:95.05ms +step:1250/1705 val_loss:3.3888 train_time:118911ms step_avg:95.13ms +step:1251/1705 train_time:118932ms step_avg:95.07ms +step:1252/1705 train_time:119014ms step_avg:95.06ms +step:1253/1705 train_time:119112ms step_avg:95.06ms +step:1254/1705 train_time:119206ms step_avg:95.06ms +step:1255/1705 train_time:119300ms step_avg:95.06ms +step:1256/1705 train_time:119394ms step_avg:95.06ms +step:1257/1705 train_time:119488ms step_avg:95.06ms +step:1258/1705 train_time:119583ms step_avg:95.06ms +step:1259/1705 train_time:119677ms step_avg:95.06ms +step:1260/1705 train_time:119770ms step_avg:95.06ms +step:1261/1705 train_time:119867ms step_avg:95.06ms +step:1262/1705 train_time:119966ms step_avg:95.06ms +step:1263/1705 train_time:120065ms step_avg:95.06ms +step:1264/1705 train_time:120161ms step_avg:95.06ms +step:1265/1705 train_time:120256ms step_avg:95.06ms +step:1266/1705 train_time:120350ms step_avg:95.06ms +step:1267/1705 train_time:120445ms step_avg:95.06ms +step:1268/1705 train_time:120539ms step_avg:95.06ms +step:1269/1705 train_time:120632ms step_avg:95.06ms +step:1270/1705 train_time:120727ms step_avg:95.06ms +step:1271/1705 train_time:120823ms step_avg:95.06ms +step:1272/1705 train_time:120918ms step_avg:95.06ms +step:1273/1705 train_time:121015ms step_avg:95.06ms +step:1274/1705 train_time:121423ms step_avg:95.31ms +step:1275/1705 train_time:121493ms step_avg:95.29ms +step:1276/1705 train_time:121588ms step_avg:95.29ms +step:1277/1705 train_time:121682ms step_avg:95.29ms +step:1278/1705 train_time:121776ms step_avg:95.29ms +step:1279/1705 train_time:121869ms step_avg:95.28ms +step:1280/1705 train_time:121963ms step_avg:95.28ms +step:1281/1705 train_time:122057ms step_avg:95.28ms +step:1282/1705 train_time:122151ms step_avg:95.28ms +step:1283/1705 train_time:122245ms step_avg:95.28ms +step:1284/1705 train_time:122347ms step_avg:95.29ms +step:1285/1705 train_time:122445ms step_avg:95.29ms +step:1286/1705 train_time:122544ms step_avg:95.29ms +step:1287/1705 train_time:122640ms step_avg:95.29ms +step:1288/1705 train_time:122735ms step_avg:95.29ms +step:1289/1705 train_time:122829ms step_avg:95.29ms +step:1290/1705 train_time:122924ms step_avg:95.29ms +step:1291/1705 train_time:123018ms step_avg:95.29ms +step:1292/1705 train_time:123112ms step_avg:95.29ms +step:1293/1705 train_time:123207ms step_avg:95.29ms +step:1294/1705 train_time:123303ms step_avg:95.29ms +step:1295/1705 train_time:123400ms step_avg:95.29ms +step:1296/1705 train_time:123497ms step_avg:95.29ms +step:1297/1705 train_time:123592ms step_avg:95.29ms +step:1298/1705 train_time:123688ms step_avg:95.29ms +step:1299/1705 train_time:123785ms step_avg:95.29ms +step:1300/1705 train_time:123880ms step_avg:95.29ms +step:1301/1705 train_time:123976ms step_avg:95.29ms +step:1302/1705 train_time:124069ms step_avg:95.29ms +step:1303/1705 train_time:124165ms step_avg:95.29ms +step:1304/1705 train_time:124260ms step_avg:95.29ms +step:1305/1705 train_time:124355ms step_avg:95.29ms +step:1306/1705 train_time:124452ms step_avg:95.29ms +step:1307/1705 train_time:124549ms 
step_avg:95.29ms +step:1308/1705 train_time:124645ms step_avg:95.29ms +step:1309/1705 train_time:124741ms step_avg:95.29ms +step:1310/1705 train_time:124836ms step_avg:95.29ms +step:1311/1705 train_time:124931ms step_avg:95.29ms +step:1312/1705 train_time:125025ms step_avg:95.29ms +step:1313/1705 train_time:125121ms step_avg:95.29ms +step:1314/1705 train_time:125216ms step_avg:95.29ms +step:1315/1705 train_time:125310ms step_avg:95.29ms +step:1316/1705 train_time:125406ms step_avg:95.29ms +step:1317/1705 train_time:125502ms step_avg:95.29ms +step:1318/1705 train_time:125597ms step_avg:95.29ms +step:1319/1705 train_time:125692ms step_avg:95.29ms +step:1320/1705 train_time:125788ms step_avg:95.29ms +step:1321/1705 train_time:125884ms step_avg:95.29ms +step:1322/1705 train_time:125979ms step_avg:95.29ms +step:1323/1705 train_time:126073ms step_avg:95.29ms +step:1324/1705 train_time:126168ms step_avg:95.29ms +step:1325/1705 train_time:126263ms step_avg:95.29ms +step:1326/1705 train_time:126359ms step_avg:95.29ms +step:1327/1705 train_time:126454ms step_avg:95.29ms +step:1328/1705 train_time:126550ms step_avg:95.29ms +step:1329/1705 train_time:126647ms step_avg:95.30ms +step:1330/1705 train_time:126742ms step_avg:95.30ms +step:1331/1705 train_time:126838ms step_avg:95.29ms +step:1332/1705 train_time:126932ms step_avg:95.29ms +step:1333/1705 train_time:127028ms step_avg:95.29ms +step:1334/1705 train_time:127123ms step_avg:95.29ms +step:1335/1705 train_time:127218ms step_avg:95.29ms +step:1336/1705 train_time:127312ms step_avg:95.29ms +step:1337/1705 train_time:127408ms step_avg:95.29ms +step:1338/1705 train_time:127504ms step_avg:95.29ms +step:1339/1705 train_time:127600ms step_avg:95.30ms +step:1340/1705 train_time:127696ms step_avg:95.30ms +step:1341/1705 train_time:127791ms step_avg:95.30ms +step:1342/1705 train_time:127886ms step_avg:95.30ms +step:1343/1705 train_time:127982ms step_avg:95.30ms +step:1344/1705 train_time:128076ms step_avg:95.29ms +step:1345/1705 train_time:128170ms step_avg:95.29ms +step:1346/1705 train_time:128266ms step_avg:95.29ms +step:1347/1705 train_time:128362ms step_avg:95.29ms +step:1348/1705 train_time:128457ms step_avg:95.29ms +step:1349/1705 train_time:128553ms step_avg:95.30ms +step:1350/1705 train_time:128649ms step_avg:95.30ms +step:1351/1705 train_time:128745ms step_avg:95.30ms +step:1352/1705 train_time:128841ms step_avg:95.30ms +step:1353/1705 train_time:128937ms step_avg:95.30ms +step:1354/1705 train_time:129032ms step_avg:95.30ms +step:1355/1705 train_time:129127ms step_avg:95.30ms +step:1356/1705 train_time:129223ms step_avg:95.30ms +step:1357/1705 train_time:129319ms step_avg:95.30ms +step:1358/1705 train_time:129413ms step_avg:95.30ms +step:1359/1705 train_time:129509ms step_avg:95.30ms +step:1360/1705 train_time:129605ms step_avg:95.30ms +step:1361/1705 train_time:129699ms step_avg:95.30ms +step:1362/1705 train_time:129795ms step_avg:95.30ms +step:1363/1705 train_time:129890ms step_avg:95.30ms +step:1364/1705 train_time:129985ms step_avg:95.30ms +step:1365/1705 train_time:130080ms step_avg:95.30ms +step:1366/1705 train_time:130174ms step_avg:95.30ms +step:1367/1705 train_time:130269ms step_avg:95.30ms +step:1368/1705 train_time:130365ms step_avg:95.30ms +step:1369/1705 train_time:130461ms step_avg:95.30ms +step:1370/1705 train_time:130557ms step_avg:95.30ms +step:1371/1705 train_time:130653ms step_avg:95.30ms +step:1372/1705 train_time:130748ms step_avg:95.30ms +step:1373/1705 train_time:130844ms step_avg:95.30ms +step:1374/1705 train_time:130939ms 
step_avg:95.30ms +step:1375/1705 train_time:131034ms step_avg:95.30ms +step:1375/1705 val_loss:3.3520 train_time:131130ms step_avg:95.37ms +step:1376/1705 train_time:131151ms step_avg:95.31ms +step:1377/1705 train_time:131233ms step_avg:95.30ms +step:1378/1705 train_time:131335ms step_avg:95.31ms +step:1379/1705 train_time:131431ms step_avg:95.31ms +step:1380/1705 train_time:131524ms step_avg:95.31ms +step:1381/1705 train_time:131619ms step_avg:95.31ms +step:1382/1705 train_time:131713ms step_avg:95.31ms +step:1383/1705 train_time:131808ms step_avg:95.31ms +step:1384/1705 train_time:131901ms step_avg:95.30ms +step:1385/1705 train_time:131996ms step_avg:95.30ms +step:1386/1705 train_time:132091ms step_avg:95.30ms +step:1387/1705 train_time:132188ms step_avg:95.30ms +step:1388/1705 train_time:132286ms step_avg:95.31ms +step:1389/1705 train_time:132382ms step_avg:95.31ms +step:1390/1705 train_time:132478ms step_avg:95.31ms +step:1391/1705 train_time:132573ms step_avg:95.31ms +step:1392/1705 train_time:132667ms step_avg:95.31ms +step:1393/1705 train_time:132761ms step_avg:95.31ms +step:1394/1705 train_time:132856ms step_avg:95.31ms +step:1395/1705 train_time:132951ms step_avg:95.31ms +step:1396/1705 train_time:133045ms step_avg:95.30ms +step:1397/1705 train_time:133140ms step_avg:95.30ms +step:1398/1705 train_time:133237ms step_avg:95.31ms +step:1399/1705 train_time:133335ms step_avg:95.31ms +step:1400/1705 train_time:133431ms step_avg:95.31ms +step:1401/1705 train_time:133527ms step_avg:95.31ms +step:1402/1705 train_time:133621ms step_avg:95.31ms +step:1403/1705 train_time:133715ms step_avg:95.31ms +step:1404/1705 train_time:133810ms step_avg:95.31ms +step:1405/1705 train_time:133905ms step_avg:95.31ms +step:1406/1705 train_time:134000ms step_avg:95.31ms +step:1407/1705 train_time:134096ms step_avg:95.31ms +step:1408/1705 train_time:134192ms step_avg:95.31ms +step:1409/1705 train_time:134288ms step_avg:95.31ms +step:1410/1705 train_time:134384ms step_avg:95.31ms +step:1411/1705 train_time:134479ms step_avg:95.31ms +step:1412/1705 train_time:134575ms step_avg:95.31ms +step:1413/1705 train_time:134671ms step_avg:95.31ms +step:1414/1705 train_time:134765ms step_avg:95.31ms +step:1415/1705 train_time:134860ms step_avg:95.31ms +step:1416/1705 train_time:134955ms step_avg:95.31ms +step:1417/1705 train_time:135050ms step_avg:95.31ms +step:1418/1705 train_time:135144ms step_avg:95.31ms +step:1419/1705 train_time:135240ms step_avg:95.31ms +step:1420/1705 train_time:135337ms step_avg:95.31ms +step:1421/1705 train_time:135433ms step_avg:95.31ms +step:1422/1705 train_time:135527ms step_avg:95.31ms +step:1423/1705 train_time:135622ms step_avg:95.31ms +step:1424/1705 train_time:135718ms step_avg:95.31ms +step:1425/1705 train_time:135813ms step_avg:95.31ms +step:1426/1705 train_time:135907ms step_avg:95.31ms +step:1427/1705 train_time:136002ms step_avg:95.31ms +step:1428/1705 train_time:136097ms step_avg:95.31ms +step:1429/1705 train_time:136193ms step_avg:95.31ms +step:1430/1705 train_time:136290ms step_avg:95.31ms +step:1431/1705 train_time:136385ms step_avg:95.31ms +step:1432/1705 train_time:136480ms step_avg:95.31ms +step:1433/1705 train_time:136576ms step_avg:95.31ms +step:1434/1705 train_time:136672ms step_avg:95.31ms +step:1435/1705 train_time:136767ms step_avg:95.31ms +step:1436/1705 train_time:136862ms step_avg:95.31ms +step:1437/1705 train_time:136958ms step_avg:95.31ms +step:1438/1705 train_time:137052ms step_avg:95.31ms +step:1439/1705 train_time:137147ms step_avg:95.31ms +step:1440/1705 
train_time:137243ms step_avg:95.31ms +step:1441/1705 train_time:137338ms step_avg:95.31ms +step:1442/1705 train_time:137434ms step_avg:95.31ms +step:1443/1705 train_time:137530ms step_avg:95.31ms +step:1444/1705 train_time:137625ms step_avg:95.31ms +step:1445/1705 train_time:137720ms step_avg:95.31ms +step:1446/1705 train_time:137815ms step_avg:95.31ms +step:1447/1705 train_time:137911ms step_avg:95.31ms +step:1448/1705 train_time:138007ms step_avg:95.31ms +step:1449/1705 train_time:138102ms step_avg:95.31ms +step:1450/1705 train_time:138198ms step_avg:95.31ms +step:1451/1705 train_time:138293ms step_avg:95.31ms +step:1452/1705 train_time:138388ms step_avg:95.31ms +step:1453/1705 train_time:138483ms step_avg:95.31ms +step:1454/1705 train_time:138579ms step_avg:95.31ms +step:1455/1705 train_time:138675ms step_avg:95.31ms +step:1456/1705 train_time:138772ms step_avg:95.31ms +step:1457/1705 train_time:138869ms step_avg:95.31ms +step:1458/1705 train_time:138964ms step_avg:95.31ms +step:1459/1705 train_time:139059ms step_avg:95.31ms +step:1460/1705 train_time:139154ms step_avg:95.31ms +step:1461/1705 train_time:139250ms step_avg:95.31ms +step:1462/1705 train_time:139345ms step_avg:95.31ms +step:1463/1705 train_time:139440ms step_avg:95.31ms +step:1464/1705 train_time:139535ms step_avg:95.31ms +step:1465/1705 train_time:139630ms step_avg:95.31ms +step:1466/1705 train_time:139725ms step_avg:95.31ms +step:1467/1705 train_time:139821ms step_avg:95.31ms +step:1468/1705 train_time:139917ms step_avg:95.31ms +step:1469/1705 train_time:140013ms step_avg:95.31ms +step:1470/1705 train_time:140108ms step_avg:95.31ms +step:1471/1705 train_time:140203ms step_avg:95.31ms +step:1472/1705 train_time:140298ms step_avg:95.31ms +step:1473/1705 train_time:140395ms step_avg:95.31ms +step:1474/1705 train_time:140490ms step_avg:95.31ms +step:1475/1705 train_time:140585ms step_avg:95.31ms +step:1476/1705 train_time:140681ms step_avg:95.31ms +step:1477/1705 train_time:140776ms step_avg:95.31ms +step:1478/1705 train_time:140872ms step_avg:95.31ms +step:1479/1705 train_time:140967ms step_avg:95.31ms +step:1480/1705 train_time:141063ms step_avg:95.31ms +step:1481/1705 train_time:141159ms step_avg:95.31ms +step:1482/1705 train_time:141255ms step_avg:95.31ms +step:1483/1705 train_time:141351ms step_avg:95.31ms +step:1484/1705 train_time:141445ms step_avg:95.31ms +step:1485/1705 train_time:141723ms step_avg:95.44ms +step:1486/1705 train_time:141887ms step_avg:95.48ms +step:1487/1705 train_time:141980ms step_avg:95.48ms +step:1488/1705 train_time:142074ms step_avg:95.48ms +step:1489/1705 train_time:142168ms step_avg:95.48ms +step:1490/1705 train_time:142262ms step_avg:95.48ms +step:1491/1705 train_time:142357ms step_avg:95.48ms +step:1492/1705 train_time:142451ms step_avg:95.48ms +step:1493/1705 train_time:142545ms step_avg:95.48ms +step:1494/1705 train_time:142640ms step_avg:95.48ms +step:1495/1705 train_time:142740ms step_avg:95.48ms +step:1496/1705 train_time:142842ms step_avg:95.48ms +step:1497/1705 train_time:142939ms step_avg:95.48ms +step:1498/1705 train_time:143035ms step_avg:95.48ms +step:1499/1705 train_time:143131ms step_avg:95.48ms +step:1500/1705 train_time:143224ms step_avg:95.48ms +step:1500/1705 val_loss:3.3197 train_time:143319ms step_avg:95.55ms +step:1501/1705 train_time:143340ms step_avg:95.50ms +step:1502/1705 train_time:143420ms step_avg:95.49ms +step:1503/1705 train_time:143518ms step_avg:95.49ms +step:1504/1705 train_time:143613ms step_avg:95.49ms +step:1505/1705 train_time:143709ms step_avg:95.49ms 
+step:1506/1705 train_time:143803ms step_avg:95.49ms +step:1507/1705 train_time:143896ms step_avg:95.49ms +step:1508/1705 train_time:143990ms step_avg:95.48ms +step:1509/1705 train_time:144085ms step_avg:95.48ms +step:1510/1705 train_time:144179ms step_avg:95.48ms +step:1511/1705 train_time:144274ms step_avg:95.48ms +step:1512/1705 train_time:144372ms step_avg:95.48ms +step:1513/1705 train_time:144470ms step_avg:95.49ms +step:1514/1705 train_time:144569ms step_avg:95.49ms +step:1515/1705 train_time:144665ms step_avg:95.49ms +step:1516/1705 train_time:144760ms step_avg:95.49ms +step:1517/1705 train_time:144853ms step_avg:95.49ms +step:1518/1705 train_time:144948ms step_avg:95.49ms +step:1519/1705 train_time:145042ms step_avg:95.49ms +step:1520/1705 train_time:145137ms step_avg:95.48ms +step:1521/1705 train_time:145231ms step_avg:95.48ms +step:1522/1705 train_time:145327ms step_avg:95.48ms +step:1523/1705 train_time:145425ms step_avg:95.49ms +step:1524/1705 train_time:145522ms step_avg:95.49ms +step:1525/1705 train_time:145618ms step_avg:95.49ms +step:1526/1705 train_time:145713ms step_avg:95.49ms +step:1527/1705 train_time:145808ms step_avg:95.49ms +step:1528/1705 train_time:145904ms step_avg:95.49ms +step:1529/1705 train_time:145999ms step_avg:95.49ms +step:1530/1705 train_time:146093ms step_avg:95.49ms +step:1531/1705 train_time:146187ms step_avg:95.48ms +step:1532/1705 train_time:146283ms step_avg:95.48ms +step:1533/1705 train_time:146379ms step_avg:95.49ms +step:1534/1705 train_time:146474ms step_avg:95.49ms +step:1535/1705 train_time:146571ms step_avg:95.49ms +step:1536/1705 train_time:146666ms step_avg:95.49ms +step:1537/1705 train_time:146762ms step_avg:95.49ms +step:1538/1705 train_time:146857ms step_avg:95.49ms +step:1539/1705 train_time:146951ms step_avg:95.48ms +step:1540/1705 train_time:147046ms step_avg:95.48ms +step:1541/1705 train_time:147142ms step_avg:95.48ms +step:1542/1705 train_time:147237ms step_avg:95.48ms +step:1543/1705 train_time:147332ms step_avg:95.48ms +step:1544/1705 train_time:147428ms step_avg:95.48ms +step:1545/1705 train_time:147525ms step_avg:95.49ms +step:1546/1705 train_time:147622ms step_avg:95.49ms +step:1547/1705 train_time:147717ms step_avg:95.49ms +step:1548/1705 train_time:147811ms step_avg:95.49ms +step:1549/1705 train_time:147907ms step_avg:95.49ms +step:1550/1705 train_time:148003ms step_avg:95.49ms +step:1551/1705 train_time:148098ms step_avg:95.49ms +step:1552/1705 train_time:148192ms step_avg:95.48ms +step:1553/1705 train_time:148289ms step_avg:95.49ms +step:1554/1705 train_time:148384ms step_avg:95.49ms +step:1555/1705 train_time:148480ms step_avg:95.49ms +step:1556/1705 train_time:148575ms step_avg:95.49ms +step:1557/1705 train_time:148671ms step_avg:95.49ms +step:1558/1705 train_time:148766ms step_avg:95.49ms +step:1559/1705 train_time:148862ms step_avg:95.49ms +step:1560/1705 train_time:148958ms step_avg:95.49ms +step:1561/1705 train_time:149053ms step_avg:95.49ms +step:1562/1705 train_time:149148ms step_avg:95.49ms +step:1563/1705 train_time:149244ms step_avg:95.49ms +step:1564/1705 train_time:149340ms step_avg:95.49ms +step:1565/1705 train_time:149436ms step_avg:95.49ms +step:1566/1705 train_time:149531ms step_avg:95.49ms +step:1567/1705 train_time:149626ms step_avg:95.49ms +step:1568/1705 train_time:149721ms step_avg:95.49ms +step:1569/1705 train_time:149816ms step_avg:95.49ms +step:1570/1705 train_time:149912ms step_avg:95.49ms +step:1571/1705 train_time:150007ms step_avg:95.49ms +step:1572/1705 train_time:150102ms step_avg:95.48ms 
+step:1573/1705 train_time:150197ms step_avg:95.48ms +step:1574/1705 train_time:150292ms step_avg:95.48ms +step:1575/1705 train_time:150387ms step_avg:95.48ms +step:1576/1705 train_time:150482ms step_avg:95.48ms +step:1577/1705 train_time:150577ms step_avg:95.48ms +step:1578/1705 train_time:150673ms step_avg:95.48ms +step:1579/1705 train_time:150768ms step_avg:95.48ms +step:1580/1705 train_time:150864ms step_avg:95.48ms +step:1581/1705 train_time:150959ms step_avg:95.48ms +step:1582/1705 train_time:151054ms step_avg:95.48ms +step:1583/1705 train_time:151149ms step_avg:95.48ms +step:1584/1705 train_time:151245ms step_avg:95.48ms +step:1585/1705 train_time:151342ms step_avg:95.48ms +step:1586/1705 train_time:151437ms step_avg:95.48ms +step:1587/1705 train_time:151531ms step_avg:95.48ms +step:1588/1705 train_time:151627ms step_avg:95.48ms +step:1589/1705 train_time:151723ms step_avg:95.48ms +step:1590/1705 train_time:151818ms step_avg:95.48ms +step:1591/1705 train_time:151913ms step_avg:95.48ms +step:1592/1705 train_time:152009ms step_avg:95.48ms +step:1593/1705 train_time:152104ms step_avg:95.48ms +step:1594/1705 train_time:152200ms step_avg:95.48ms +step:1595/1705 train_time:152294ms step_avg:95.48ms +step:1596/1705 train_time:152389ms step_avg:95.48ms +step:1597/1705 train_time:152485ms step_avg:95.48ms +step:1598/1705 train_time:152581ms step_avg:95.48ms +step:1599/1705 train_time:152677ms step_avg:95.48ms +step:1600/1705 train_time:152772ms step_avg:95.48ms +step:1601/1705 train_time:152867ms step_avg:95.48ms +step:1602/1705 train_time:152963ms step_avg:95.48ms +step:1603/1705 train_time:153058ms step_avg:95.48ms +step:1604/1705 train_time:153153ms step_avg:95.48ms +step:1605/1705 train_time:153248ms step_avg:95.48ms +step:1606/1705 train_time:153344ms step_avg:95.48ms +step:1607/1705 train_time:153440ms step_avg:95.48ms +step:1608/1705 train_time:153534ms step_avg:95.48ms +step:1609/1705 train_time:153629ms step_avg:95.48ms +step:1610/1705 train_time:153726ms step_avg:95.48ms +step:1611/1705 train_time:153821ms step_avg:95.48ms +step:1612/1705 train_time:153916ms step_avg:95.48ms +step:1613/1705 train_time:154012ms step_avg:95.48ms +step:1614/1705 train_time:154107ms step_avg:95.48ms +step:1615/1705 train_time:154203ms step_avg:95.48ms +step:1616/1705 train_time:154298ms step_avg:95.48ms +step:1617/1705 train_time:154393ms step_avg:95.48ms +step:1618/1705 train_time:154488ms step_avg:95.48ms +step:1619/1705 train_time:154585ms step_avg:95.48ms +step:1620/1705 train_time:154680ms step_avg:95.48ms +step:1621/1705 train_time:154775ms step_avg:95.48ms +step:1622/1705 train_time:154870ms step_avg:95.48ms +step:1623/1705 train_time:154966ms step_avg:95.48ms +step:1624/1705 train_time:155062ms step_avg:95.48ms +step:1625/1705 train_time:155158ms step_avg:95.48ms +step:1625/1705 val_loss:3.2922 train_time:155253ms step_avg:95.54ms +step:1626/1705 train_time:155275ms step_avg:95.50ms +step:1627/1705 train_time:155355ms step_avg:95.49ms +step:1628/1705 train_time:155457ms step_avg:95.49ms +step:1629/1705 train_time:155553ms step_avg:95.49ms +step:1630/1705 train_time:155647ms step_avg:95.49ms +step:1631/1705 train_time:155741ms step_avg:95.49ms +step:1632/1705 train_time:155836ms step_avg:95.49ms +step:1633/1705 train_time:155930ms step_avg:95.49ms +step:1634/1705 train_time:156024ms step_avg:95.49ms +step:1635/1705 train_time:156119ms step_avg:95.49ms +step:1636/1705 train_time:156215ms step_avg:95.49ms +step:1637/1705 train_time:156313ms step_avg:95.49ms +step:1638/1705 train_time:156411ms 
step_avg:95.49ms +step:1639/1705 train_time:156508ms step_avg:95.49ms +step:1640/1705 train_time:156604ms step_avg:95.49ms +step:1641/1705 train_time:156700ms step_avg:95.49ms +step:1642/1705 train_time:156795ms step_avg:95.49ms +step:1643/1705 train_time:156889ms step_avg:95.49ms +step:1644/1705 train_time:156984ms step_avg:95.49ms +step:1645/1705 train_time:157079ms step_avg:95.49ms +step:1646/1705 train_time:157174ms step_avg:95.49ms +step:1647/1705 train_time:157269ms step_avg:95.49ms +step:1648/1705 train_time:157365ms step_avg:95.49ms +step:1649/1705 train_time:157461ms step_avg:95.49ms +step:1650/1705 train_time:157557ms step_avg:95.49ms +step:1651/1705 train_time:157652ms step_avg:95.49ms +step:1652/1705 train_time:157747ms step_avg:95.49ms +step:1653/1705 train_time:157843ms step_avg:95.49ms +step:1654/1705 train_time:157937ms step_avg:95.49ms +step:1655/1705 train_time:158032ms step_avg:95.49ms +step:1656/1705 train_time:158126ms step_avg:95.49ms +step:1657/1705 train_time:158222ms step_avg:95.49ms +step:1658/1705 train_time:158318ms step_avg:95.49ms +step:1659/1705 train_time:158416ms step_avg:95.49ms +step:1660/1705 train_time:158512ms step_avg:95.49ms +step:1661/1705 train_time:158607ms step_avg:95.49ms +step:1662/1705 train_time:158702ms step_avg:95.49ms +step:1663/1705 train_time:158798ms step_avg:95.49ms +step:1664/1705 train_time:158894ms step_avg:95.49ms +step:1665/1705 train_time:158988ms step_avg:95.49ms +step:1666/1705 train_time:159083ms step_avg:95.49ms +step:1667/1705 train_time:159179ms step_avg:95.49ms +step:1668/1705 train_time:159274ms step_avg:95.49ms +step:1669/1705 train_time:159369ms step_avg:95.49ms +step:1670/1705 train_time:159465ms step_avg:95.49ms +step:1671/1705 train_time:159560ms step_avg:95.49ms +step:1672/1705 train_time:159656ms step_avg:95.49ms +step:1673/1705 train_time:159752ms step_avg:95.49ms +step:1674/1705 train_time:159847ms step_avg:95.49ms +step:1675/1705 train_time:159943ms step_avg:95.49ms +step:1676/1705 train_time:160039ms step_avg:95.49ms +step:1677/1705 train_time:160134ms step_avg:95.49ms +step:1678/1705 train_time:160229ms step_avg:95.49ms +step:1679/1705 train_time:160324ms step_avg:95.49ms +step:1680/1705 train_time:160421ms step_avg:95.49ms +step:1681/1705 train_time:160516ms step_avg:95.49ms +step:1682/1705 train_time:160612ms step_avg:95.49ms +step:1683/1705 train_time:160707ms step_avg:95.49ms +step:1684/1705 train_time:160803ms step_avg:95.49ms +step:1685/1705 train_time:160900ms step_avg:95.49ms +step:1686/1705 train_time:160995ms step_avg:95.49ms +step:1687/1705 train_time:161090ms step_avg:95.49ms +step:1688/1705 train_time:161185ms step_avg:95.49ms +step:1689/1705 train_time:161280ms step_avg:95.49ms +step:1690/1705 train_time:161376ms step_avg:95.49ms +step:1691/1705 train_time:161471ms step_avg:95.49ms +step:1692/1705 train_time:161566ms step_avg:95.49ms +step:1693/1705 train_time:161661ms step_avg:95.49ms +step:1694/1705 train_time:161757ms step_avg:95.49ms +step:1695/1705 train_time:161853ms step_avg:95.49ms +step:1696/1705 train_time:161948ms step_avg:95.49ms +step:1697/1705 train_time:162043ms step_avg:95.49ms +step:1698/1705 train_time:162363ms step_avg:95.62ms +step:1699/1705 train_time:162483ms step_avg:95.63ms +step:1700/1705 train_time:162575ms step_avg:95.63ms +step:1701/1705 train_time:162669ms step_avg:95.63ms +step:1702/1705 train_time:162763ms step_avg:95.63ms +step:1703/1705 train_time:162857ms step_avg:95.63ms +step:1704/1705 train_time:162952ms step_avg:95.63ms +step:1705/1705 train_time:163046ms 
step_avg:95.63ms
+step:1705/1705 val_loss:3.2778 train_time:163141ms step_avg:95.68ms
+peak memory allocated: 33848 MiB reserved: 48936 MiB
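The per-step lines in these records follow one fixed layout, step:<i>/<n> [val_loss:<v>] train_time:<t>ms step_avg:<a>ms, so the loss and timing curves can be recovered mechanically. A minimal parsing sketch in Python, assuming exactly that layout; the parse_record helper and its name are ours, not part of any record file:

import re

STEP_RE = re.compile(
    r"step:(\d+)/(\d+)"         # step index / total steps
    r"(?: val_loss:([\d.]+))?"  # present only on validation steps
    r" train_time:(\d+)ms step_avg:([\d.]+)ms"
)

def parse_record(path):
    """Yield (step, total_steps, val_loss or None, train_time_ms, step_avg_ms)."""
    with open(path) as f:
        for line in f:
            # finditer tolerates records that pack several entries onto one physical line
            for m in STEP_RE.finditer(line):
                step, total, val, t_ms, avg = m.groups()
                yield int(step), int(total), float(val) if val else None, int(t_ms), float(avg)

# e.g. the validation curve of the TritonMuon run above:
# vals = [(s, v) for s, _, v, _, _ in parse_record("records/071825_TritonMuon/record.txt") if v is not None]
# ...ends with (1705, 3.2778) at train_time 163141ms.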
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
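+
+    A minimal usage sketch (illustrative only; it mirrors how this optimizer is
+    constructed and stepped later in this file):
+
+        hidden_matrix_params = [p for n, p in model.blocks.named_parameters()
+                                if p.ndim >= 2 and "embed" not in n]
+        opt = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+        model(inputs, targets, cum_seqlens, ws).backward()
+        opt.step()
+        model.zero_grad(set_to_none=True)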
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
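+        # e.g. next_multiple_of_n(50257, n=128) == 50304 == 393 * 128; this padded
+        # vocab size is shared by the embeddings above and the lm_head below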
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
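+        # e.g. tokens == [BOS, t1, t2, BOS, t3, ...] gives bos_idx == [0, 3, ...];
+        # next_batch then slices [start, end) windows that each begin at a BOS
+        # token and are capped at max_seq_len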
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
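+                # The partially assembled batch is discarded; the fresh BOSFinder
+                # restarts scanning from the beginning of the new shard.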
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1705 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
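+
+# Batch accounting (illustrative): train_batch_size = 2048 * 24 * 8 = 393,216 tokens
+# per optimizer step. With world_size == 8, grad_accum_steps == 1 and each rank
+# consumes 393,216 / 8 == 49,152 tokens per micro-step; smaller world sizes reach
+# the same global batch via 8 // world_size accumulation steps.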
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
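+
+# Schedule shape for the run below (illustrative): get_lr holds 1.0 until
+# step / num_iterations reaches 1 - cooldown_frac = 0.55 (~step 938), then decays
+# linearly to ~0.1; get_ws walks the attention window through (3, 7, 11) in equal
+# thirds of training, giving long windows of 3*128, 7*128, then 11*128 tokens.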
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:18:13 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 36C P0 123W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 37C P0 122W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 35C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 30C P0 122W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 47592 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 47593 C /usr/bin/python3 610MiB | +| 0 N/A N/A 47594 C /usr/bin/python3 610MiB | +| 0 N/A N/A 47595 C /usr/bin/python3 610MiB | +| 0 N/A N/A 47596 C /usr/bin/python3 610MiB | +| 0 N/A N/A 47597 C /usr/bin/python3 610MiB | +| 0 N/A N/A 47598 C /usr/bin/python3 610MiB | +| 0 N/A N/A 47599 C /usr/bin/python3 610MiB | +| 1 N/A N/A 47593 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 47594 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 47595 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 47596 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 47597 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 47598 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 47599 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1705 train_time:408ms step_avg:407.67ms +step:2/1705 train_time:429ms step_avg:214.72ms +step:3/1705 train_time:497ms step_avg:165.70ms +step:4/1705 train_time:588ms step_avg:146.96ms +step:5/1705 
train_time:679ms step_avg:135.84ms +step:6/1705 train_time:772ms step_avg:128.59ms +step:7/1705 train_time:863ms step_avg:123.24ms +step:8/1705 train_time:955ms step_avg:119.35ms +step:9/1705 train_time:1047ms step_avg:116.30ms +step:10/1705 train_time:1139ms step_avg:113.93ms +step:11/1705 train_time:1231ms step_avg:111.95ms +step:12/1705 train_time:1325ms step_avg:110.38ms +step:13/1705 train_time:1419ms step_avg:109.15ms +step:14/1705 train_time:1513ms step_avg:108.09ms +step:15/1705 train_time:1607ms step_avg:107.14ms +step:16/1705 train_time:1699ms step_avg:106.18ms +step:17/1705 train_time:1791ms step_avg:105.38ms +step:18/1705 train_time:1884ms step_avg:104.66ms +step:19/1705 train_time:1976ms step_avg:103.99ms +step:20/1705 train_time:2069ms step_avg:103.45ms +step:21/1705 train_time:2161ms step_avg:102.93ms +step:22/1705 train_time:2255ms step_avg:102.50ms +step:23/1705 train_time:2349ms step_avg:102.15ms +step:24/1705 train_time:2443ms step_avg:101.78ms +step:25/1705 train_time:2537ms step_avg:101.47ms +step:26/1705 train_time:2630ms step_avg:101.16ms +step:27/1705 train_time:2722ms step_avg:100.83ms +step:28/1705 train_time:2815ms step_avg:100.54ms +step:29/1705 train_time:2908ms step_avg:100.27ms +step:30/1705 train_time:3000ms step_avg:99.99ms +step:31/1705 train_time:3092ms step_avg:99.76ms +step:32/1705 train_time:3186ms step_avg:99.56ms +step:33/1705 train_time:3279ms step_avg:99.35ms +step:34/1705 train_time:3373ms step_avg:99.20ms +step:35/1705 train_time:3466ms step_avg:99.04ms +step:36/1705 train_time:3559ms step_avg:98.87ms +step:37/1705 train_time:3653ms step_avg:98.73ms +step:38/1705 train_time:3746ms step_avg:98.59ms +step:39/1705 train_time:3839ms step_avg:98.42ms +step:40/1705 train_time:3932ms step_avg:98.29ms +step:41/1705 train_time:4024ms step_avg:98.15ms +step:42/1705 train_time:4117ms step_avg:98.03ms +step:43/1705 train_time:4211ms step_avg:97.92ms +step:44/1705 train_time:4303ms step_avg:97.80ms +step:45/1705 train_time:4397ms step_avg:97.70ms +step:46/1705 train_time:4490ms step_avg:97.61ms +step:47/1705 train_time:4583ms step_avg:97.51ms +step:48/1705 train_time:4676ms step_avg:97.42ms +step:49/1705 train_time:4770ms step_avg:97.34ms +step:50/1705 train_time:4862ms step_avg:97.24ms +step:51/1705 train_time:4955ms step_avg:97.17ms +step:52/1705 train_time:5049ms step_avg:97.09ms +step:53/1705 train_time:5142ms step_avg:97.02ms +step:54/1705 train_time:5235ms step_avg:96.95ms +step:55/1705 train_time:5328ms step_avg:96.87ms +step:56/1705 train_time:5420ms step_avg:96.79ms +step:57/1705 train_time:5514ms step_avg:96.74ms +step:58/1705 train_time:5607ms step_avg:96.67ms +step:59/1705 train_time:5699ms step_avg:96.59ms +step:60/1705 train_time:5792ms step_avg:96.54ms +step:61/1705 train_time:5885ms step_avg:96.47ms +step:62/1705 train_time:5977ms step_avg:96.40ms +step:63/1705 train_time:6070ms step_avg:96.36ms +step:64/1705 train_time:6162ms step_avg:96.28ms +step:65/1705 train_time:6255ms step_avg:96.23ms +step:66/1705 train_time:6349ms step_avg:96.19ms +step:67/1705 train_time:6441ms step_avg:96.14ms +step:68/1705 train_time:6535ms step_avg:96.10ms +step:69/1705 train_time:6628ms step_avg:96.06ms +step:70/1705 train_time:6720ms step_avg:96.01ms +step:71/1705 train_time:6813ms step_avg:95.96ms +step:72/1705 train_time:6906ms step_avg:95.92ms +step:73/1705 train_time:6999ms step_avg:95.87ms +step:74/1705 train_time:7092ms step_avg:95.84ms +step:75/1705 train_time:7185ms step_avg:95.80ms +step:76/1705 train_time:7278ms step_avg:95.76ms +step:77/1705 
train_time:7371ms step_avg:95.73ms +step:78/1705 train_time:7464ms step_avg:95.70ms +step:79/1705 train_time:7557ms step_avg:95.66ms +step:80/1705 train_time:7651ms step_avg:95.64ms +step:81/1705 train_time:7746ms step_avg:95.63ms +step:82/1705 train_time:7838ms step_avg:95.59ms +step:83/1705 train_time:7931ms step_avg:95.55ms +step:84/1705 train_time:8024ms step_avg:95.52ms +step:85/1705 train_time:8117ms step_avg:95.49ms +step:86/1705 train_time:8210ms step_avg:95.47ms +step:87/1705 train_time:8302ms step_avg:95.43ms +step:88/1705 train_time:8395ms step_avg:95.40ms +step:89/1705 train_time:8489ms step_avg:95.38ms +step:90/1705 train_time:8582ms step_avg:95.35ms +step:91/1705 train_time:8676ms step_avg:95.35ms +step:92/1705 train_time:8768ms step_avg:95.31ms +step:93/1705 train_time:8860ms step_avg:95.27ms +step:94/1705 train_time:8953ms step_avg:95.25ms +step:95/1705 train_time:9046ms step_avg:95.22ms +step:96/1705 train_time:9138ms step_avg:95.19ms +step:97/1705 train_time:9230ms step_avg:95.16ms +step:98/1705 train_time:9322ms step_avg:95.13ms +step:99/1705 train_time:9415ms step_avg:95.10ms +step:100/1705 train_time:9509ms step_avg:95.09ms +step:101/1705 train_time:9601ms step_avg:95.06ms +step:102/1705 train_time:9695ms step_avg:95.05ms +step:103/1705 train_time:9788ms step_avg:95.03ms +step:104/1705 train_time:9881ms step_avg:95.01ms +step:105/1705 train_time:9974ms step_avg:94.99ms +step:106/1705 train_time:10067ms step_avg:94.97ms +step:107/1705 train_time:10159ms step_avg:94.94ms +step:108/1705 train_time:10253ms step_avg:94.94ms +step:109/1705 train_time:10346ms step_avg:94.92ms +step:110/1705 train_time:10438ms step_avg:94.90ms +step:111/1705 train_time:10531ms step_avg:94.88ms +step:112/1705 train_time:10624ms step_avg:94.85ms +step:113/1705 train_time:10716ms step_avg:94.84ms +step:114/1705 train_time:10810ms step_avg:94.82ms +step:115/1705 train_time:10902ms step_avg:94.80ms +step:116/1705 train_time:10995ms step_avg:94.79ms +step:117/1705 train_time:11088ms step_avg:94.77ms +step:118/1705 train_time:11180ms step_avg:94.75ms +step:119/1705 train_time:11273ms step_avg:94.73ms +step:120/1705 train_time:11366ms step_avg:94.72ms +step:121/1705 train_time:11458ms step_avg:94.69ms +step:122/1705 train_time:11552ms step_avg:94.69ms +step:123/1705 train_time:11646ms step_avg:94.68ms +step:124/1705 train_time:11738ms step_avg:94.66ms +step:125/1705 train_time:11831ms step_avg:94.65ms +step:125/1705 val_loss:4.3026 train_time:11924ms step_avg:95.39ms +step:126/1705 train_time:11948ms step_avg:94.83ms +step:127/1705 train_time:12023ms step_avg:94.67ms +step:128/1705 train_time:12126ms step_avg:94.74ms +step:129/1705 train_time:12222ms step_avg:94.74ms +step:130/1705 train_time:12314ms step_avg:94.72ms +step:131/1705 train_time:12406ms step_avg:94.70ms +step:132/1705 train_time:12498ms step_avg:94.68ms +step:133/1705 train_time:12590ms step_avg:94.66ms +step:134/1705 train_time:12682ms step_avg:94.64ms +step:135/1705 train_time:12774ms step_avg:94.62ms +step:136/1705 train_time:12867ms step_avg:94.61ms +step:137/1705 train_time:12959ms step_avg:94.59ms +step:138/1705 train_time:13054ms step_avg:94.59ms +step:139/1705 train_time:13151ms step_avg:94.61ms +step:140/1705 train_time:13244ms step_avg:94.60ms +step:141/1705 train_time:13337ms step_avg:94.59ms +step:142/1705 train_time:13429ms step_avg:94.57ms +step:143/1705 train_time:13522ms step_avg:94.56ms +step:144/1705 train_time:13614ms step_avg:94.54ms +step:145/1705 train_time:13706ms step_avg:94.53ms +step:146/1705 train_time:13798ms 
step_avg:94.51ms +step:147/1705 train_time:13890ms step_avg:94.49ms +step:148/1705 train_time:13982ms step_avg:94.48ms +step:149/1705 train_time:14075ms step_avg:94.46ms +step:150/1705 train_time:14169ms step_avg:94.46ms +step:151/1705 train_time:14263ms step_avg:94.46ms +step:152/1705 train_time:14356ms step_avg:94.45ms +step:153/1705 train_time:14449ms step_avg:94.44ms +step:154/1705 train_time:14542ms step_avg:94.43ms +step:155/1705 train_time:14634ms step_avg:94.41ms +step:156/1705 train_time:14727ms step_avg:94.40ms +step:157/1705 train_time:14819ms step_avg:94.39ms +step:158/1705 train_time:14911ms step_avg:94.37ms +step:159/1705 train_time:15004ms step_avg:94.37ms +step:160/1705 train_time:15098ms step_avg:94.36ms +step:161/1705 train_time:15191ms step_avg:94.36ms +step:162/1705 train_time:15285ms step_avg:94.35ms +step:163/1705 train_time:15378ms step_avg:94.35ms +step:164/1705 train_time:15471ms step_avg:94.33ms +step:165/1705 train_time:15564ms step_avg:94.33ms +step:166/1705 train_time:15656ms step_avg:94.31ms +step:167/1705 train_time:15749ms step_avg:94.30ms +step:168/1705 train_time:15841ms step_avg:94.29ms +step:169/1705 train_time:15933ms step_avg:94.28ms +step:170/1705 train_time:16026ms step_avg:94.27ms +step:171/1705 train_time:16119ms step_avg:94.26ms +step:172/1705 train_time:16212ms step_avg:94.25ms +step:173/1705 train_time:16306ms step_avg:94.25ms +step:174/1705 train_time:16398ms step_avg:94.24ms +step:175/1705 train_time:16491ms step_avg:94.24ms +step:176/1705 train_time:16584ms step_avg:94.23ms +step:177/1705 train_time:16677ms step_avg:94.22ms +step:178/1705 train_time:16769ms step_avg:94.21ms +step:179/1705 train_time:16862ms step_avg:94.20ms +step:180/1705 train_time:16954ms step_avg:94.19ms +step:181/1705 train_time:17048ms step_avg:94.19ms +step:182/1705 train_time:17140ms step_avg:94.18ms +step:183/1705 train_time:17232ms step_avg:94.16ms +step:184/1705 train_time:17326ms step_avg:94.16ms +step:185/1705 train_time:17418ms step_avg:94.15ms +step:186/1705 train_time:17511ms step_avg:94.15ms +step:187/1705 train_time:17604ms step_avg:94.14ms +step:188/1705 train_time:17697ms step_avg:94.13ms +step:189/1705 train_time:17789ms step_avg:94.12ms +step:190/1705 train_time:17882ms step_avg:94.11ms +step:191/1705 train_time:17973ms step_avg:94.10ms +step:192/1705 train_time:18067ms step_avg:94.10ms +step:193/1705 train_time:18160ms step_avg:94.09ms +step:194/1705 train_time:18251ms step_avg:94.08ms +step:195/1705 train_time:18344ms step_avg:94.07ms +step:196/1705 train_time:18437ms step_avg:94.07ms +step:197/1705 train_time:18530ms step_avg:94.06ms +step:198/1705 train_time:18623ms step_avg:94.05ms +step:199/1705 train_time:18715ms step_avg:94.04ms +step:200/1705 train_time:18807ms step_avg:94.03ms +step:201/1705 train_time:18900ms step_avg:94.03ms +step:202/1705 train_time:18992ms step_avg:94.02ms +step:203/1705 train_time:19084ms step_avg:94.01ms +step:204/1705 train_time:19177ms step_avg:94.00ms +step:205/1705 train_time:19270ms step_avg:94.00ms +step:206/1705 train_time:19363ms step_avg:94.00ms +step:207/1705 train_time:19456ms step_avg:93.99ms +step:208/1705 train_time:19548ms step_avg:93.98ms +step:209/1705 train_time:19641ms step_avg:93.97ms +step:210/1705 train_time:19733ms step_avg:93.97ms +step:211/1705 train_time:19826ms step_avg:93.96ms +step:212/1705 train_time:19918ms step_avg:93.95ms +step:213/1705 train_time:20202ms step_avg:94.84ms +step:214/1705 train_time:20330ms step_avg:95.00ms +step:215/1705 train_time:20421ms step_avg:94.98ms +step:216/1705 
train_time:20513ms step_avg:94.97ms
[... ~1,400 per-step training-log lines elided: steps 217-1638 proceed at a steady step_avg of roughly 94-95 ms, apart from occasional ~300 ms hiccups; the validation-loss milestones are kept verbatim below ...]
+step:250/1705 val_loss:3.9817 train_time:23766ms step_avg:95.06ms
+step:375/1705 val_loss:3.8213 train_time:35375ms step_avg:94.33ms
+step:500/1705 val_loss:3.7198 train_time:47174ms step_avg:94.35ms
+step:625/1705 val_loss:3.6221 train_time:58864ms step_avg:94.18ms
+step:750/1705 val_loss:3.5651 train_time:70852ms step_avg:94.47ms
+step:875/1705 val_loss:3.5242 train_time:82854ms step_avg:94.69ms
+step:1000/1705 val_loss:3.4846 train_time:94644ms step_avg:94.64ms
+step:1125/1705 val_loss:3.4377 train_time:106684ms step_avg:94.83ms
+step:1250/1705 val_loss:3.3891 train_time:118588ms step_avg:94.87ms
+step:1375/1705 val_loss:3.3522 train_time:130680ms step_avg:95.04ms
+step:1500/1705 val_loss:3.3196 train_time:142753ms step_avg:95.17ms
+step:1625/1705 val_loss:3.2916 train_time:154650ms step_avg:95.17ms
+step:1638/1705 train_time:155809ms
step_avg:95.12ms +step:1639/1705 train_time:155906ms step_avg:95.12ms +step:1640/1705 train_time:156003ms step_avg:95.12ms +step:1641/1705 train_time:156099ms step_avg:95.12ms +step:1642/1705 train_time:156193ms step_avg:95.12ms +step:1643/1705 train_time:156288ms step_avg:95.12ms +step:1644/1705 train_time:156382ms step_avg:95.12ms +step:1645/1705 train_time:156476ms step_avg:95.12ms +step:1646/1705 train_time:156570ms step_avg:95.12ms +step:1647/1705 train_time:156666ms step_avg:95.12ms +step:1648/1705 train_time:156764ms step_avg:95.12ms +step:1649/1705 train_time:156861ms step_avg:95.12ms +step:1650/1705 train_time:156956ms step_avg:95.12ms +step:1651/1705 train_time:157053ms step_avg:95.13ms +step:1652/1705 train_time:157148ms step_avg:95.13ms +step:1653/1705 train_time:157243ms step_avg:95.13ms +step:1654/1705 train_time:157337ms step_avg:95.13ms +step:1655/1705 train_time:157431ms step_avg:95.12ms +step:1656/1705 train_time:157526ms step_avg:95.12ms +step:1657/1705 train_time:157621ms step_avg:95.12ms +step:1658/1705 train_time:157716ms step_avg:95.12ms +step:1659/1705 train_time:157812ms step_avg:95.13ms +step:1660/1705 train_time:157908ms step_avg:95.13ms +step:1661/1705 train_time:158004ms step_avg:95.13ms +step:1662/1705 train_time:158100ms step_avg:95.13ms +step:1663/1705 train_time:158194ms step_avg:95.13ms +step:1664/1705 train_time:158289ms step_avg:95.13ms +step:1665/1705 train_time:158384ms step_avg:95.13ms +step:1666/1705 train_time:158479ms step_avg:95.13ms +step:1667/1705 train_time:158573ms step_avg:95.12ms +step:1668/1705 train_time:158669ms step_avg:95.13ms +step:1669/1705 train_time:158765ms step_avg:95.13ms +step:1670/1705 train_time:158860ms step_avg:95.13ms +step:1671/1705 train_time:158956ms step_avg:95.13ms +step:1672/1705 train_time:159051ms step_avg:95.13ms +step:1673/1705 train_time:159147ms step_avg:95.13ms +step:1674/1705 train_time:159242ms step_avg:95.13ms +step:1675/1705 train_time:159337ms step_avg:95.13ms +step:1676/1705 train_time:159431ms step_avg:95.13ms +step:1677/1705 train_time:159527ms step_avg:95.13ms +step:1678/1705 train_time:159621ms step_avg:95.13ms +step:1679/1705 train_time:159715ms step_avg:95.13ms +step:1680/1705 train_time:159811ms step_avg:95.13ms +step:1681/1705 train_time:159907ms step_avg:95.13ms +step:1682/1705 train_time:160005ms step_avg:95.13ms +step:1683/1705 train_time:160101ms step_avg:95.13ms +step:1684/1705 train_time:160195ms step_avg:95.13ms +step:1685/1705 train_time:160291ms step_avg:95.13ms +step:1686/1705 train_time:160386ms step_avg:95.13ms +step:1687/1705 train_time:160481ms step_avg:95.13ms +step:1688/1705 train_time:160576ms step_avg:95.13ms +step:1689/1705 train_time:160670ms step_avg:95.13ms +step:1690/1705 train_time:160767ms step_avg:95.13ms +step:1691/1705 train_time:160862ms step_avg:95.13ms +step:1692/1705 train_time:160957ms step_avg:95.13ms +step:1693/1705 train_time:161053ms step_avg:95.13ms +step:1694/1705 train_time:161149ms step_avg:95.13ms +step:1695/1705 train_time:161244ms step_avg:95.13ms +step:1696/1705 train_time:161339ms step_avg:95.13ms +step:1697/1705 train_time:161435ms step_avg:95.13ms +step:1698/1705 train_time:161785ms step_avg:95.28ms +step:1699/1705 train_time:161876ms step_avg:95.28ms +step:1700/1705 train_time:161969ms step_avg:95.28ms +step:1701/1705 train_time:162063ms step_avg:95.28ms +step:1702/1705 train_time:162156ms step_avg:95.27ms +step:1703/1705 train_time:162251ms step_avg:95.27ms +step:1704/1705 train_time:162345ms step_avg:95.27ms +step:1705/1705 train_time:162440ms 
step_avg:95.27ms +step:1705/1705 val_loss:3.2776 train_time:162534ms step_avg:95.33ms +peak memory allocated: 33992 MiB reserved: 49376 MiB diff --git a/records/050925_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt b/records/050925_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt new file mode 100644 index 000000000..0bbc9fad7 --- /dev/null +++ b/records/050925_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, 
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
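+
+    A minimal single-device sketch of the update applied to each 2D parameter p
+    (for orientation only; the actual step() below shards the work across ranks
+    with reduce_scatter/all_gather and orthogonalizes via newton_schulz_triton):
+
+        p.mul_(1 - lr * weight_decay)        # decoupled weight decay
+        buf.lerp_(grad, 1 - momentum)        # momentum buffer update
+        update = grad.lerp_(buf, momentum)   # Nesterov-style blend
+        update = newton_schulz(update)       # replace update with nearest orthogonal matrix
+        p.add_(update, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)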
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
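+        # e.g., next_multiple_of_n(50257, n=128) == 50304, so the lm_head below
+        # carries 47 padding logit columns that no real GPT-2 token ever maps to.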
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
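+        # e.g., with BOS_ID == 50256, bos_idx becomes a sorted int64 array of every
+        # document-start position in the shard; next_batch() walks this array so
+        # each emitted training sequence begins exactly at a document boundary.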
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last sequence grabbed one extra token for the target shift; trim its end back
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1705 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
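+# e.g., WORLD_SIZE=8 -> grad_accum_steps = 8 // 8 = 1 (no accumulation), while
+# WORLD_SIZE=4 -> grad_accum_steps = 2: each rank then runs two micro-batches per
+# optimizer step, keeping the global tokens-per-step constant across world sizes.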
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
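+# schedule note: with num_iterations=1705 and cooldown_frac=0.45, get_lr holds 1.0
+# through step ~937 and then decays linearly toward 0.1; get_ws steps the window
+# schedule (3, 7, 11) up at roughly one-third and two-thirds of training.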
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:01:46 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 130W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 45C P0 128W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 44C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 123W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 79057 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 79058 C /usr/bin/python3 610MiB | +| 0 N/A N/A 79059 C /usr/bin/python3 610MiB | +| 0 N/A N/A 79060 C /usr/bin/python3 610MiB | +| 0 N/A N/A 79061 C /usr/bin/python3 610MiB | +| 0 N/A N/A 79062 C /usr/bin/python3 610MiB | +| 0 N/A N/A 79063 C /usr/bin/python3 610MiB | +| 0 N/A N/A 79064 C /usr/bin/python3 610MiB | +| 1 N/A N/A 79058 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 79059 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 79060 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 79061 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 79062 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 79063 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 79064 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1705 train_time:399ms step_avg:398.52ms +step:2/1705 train_time:418ms step_avg:209.22ms +step:3/1705 train_time:488ms step_avg:162.52ms +step:4/1705 train_time:579ms step_avg:144.70ms +step:5/1705 
train_time:670ms step_avg:134.07ms +step:6/1705 train_time:763ms step_avg:127.11ms +step:7/1705 train_time:855ms step_avg:122.10ms +step:8/1705 train_time:948ms step_avg:118.47ms +step:9/1705 train_time:1040ms step_avg:115.56ms +step:10/1705 train_time:1133ms step_avg:113.27ms +step:11/1705 train_time:1226ms step_avg:111.42ms +step:12/1705 train_time:1320ms step_avg:110.01ms +step:13/1705 train_time:1415ms step_avg:108.87ms +step:14/1705 train_time:1509ms step_avg:107.77ms +step:15/1705 train_time:1603ms step_avg:106.85ms +step:16/1705 train_time:1696ms step_avg:105.97ms +step:17/1705 train_time:1788ms step_avg:105.16ms +step:18/1705 train_time:1880ms step_avg:104.47ms +step:19/1705 train_time:1973ms step_avg:103.85ms +step:20/1705 train_time:2066ms step_avg:103.28ms +step:21/1705 train_time:2158ms step_avg:102.78ms +step:22/1705 train_time:2251ms step_avg:102.31ms +step:23/1705 train_time:2346ms step_avg:101.98ms +step:24/1705 train_time:2440ms step_avg:101.66ms +step:25/1705 train_time:2534ms step_avg:101.37ms +step:26/1705 train_time:2628ms step_avg:101.07ms +step:27/1705 train_time:2721ms step_avg:100.79ms +step:28/1705 train_time:2814ms step_avg:100.50ms +step:29/1705 train_time:2907ms step_avg:100.25ms +step:30/1705 train_time:3000ms step_avg:100.01ms +step:31/1705 train_time:3093ms step_avg:99.77ms +step:32/1705 train_time:3186ms step_avg:99.55ms +step:33/1705 train_time:3279ms step_avg:99.37ms +step:34/1705 train_time:3373ms step_avg:99.19ms +step:35/1705 train_time:3467ms step_avg:99.05ms +step:36/1705 train_time:3561ms step_avg:98.92ms +step:37/1705 train_time:3654ms step_avg:98.77ms +step:38/1705 train_time:3748ms step_avg:98.63ms +step:39/1705 train_time:3842ms step_avg:98.50ms +step:40/1705 train_time:3935ms step_avg:98.37ms +step:41/1705 train_time:4028ms step_avg:98.23ms +step:42/1705 train_time:4121ms step_avg:98.12ms +step:43/1705 train_time:4213ms step_avg:97.98ms +step:44/1705 train_time:4307ms step_avg:97.88ms +step:45/1705 train_time:4400ms step_avg:97.77ms +step:46/1705 train_time:4493ms step_avg:97.67ms +step:47/1705 train_time:4586ms step_avg:97.58ms +step:48/1705 train_time:4680ms step_avg:97.49ms +step:49/1705 train_time:4773ms step_avg:97.41ms +step:50/1705 train_time:4866ms step_avg:97.33ms +step:51/1705 train_time:4960ms step_avg:97.25ms +step:52/1705 train_time:5052ms step_avg:97.15ms +step:53/1705 train_time:5146ms step_avg:97.09ms +step:54/1705 train_time:5239ms step_avg:97.02ms +step:55/1705 train_time:5332ms step_avg:96.94ms +step:56/1705 train_time:5426ms step_avg:96.89ms +step:57/1705 train_time:5519ms step_avg:96.82ms +step:58/1705 train_time:5613ms step_avg:96.77ms +step:59/1705 train_time:5706ms step_avg:96.71ms +step:60/1705 train_time:5799ms step_avg:96.66ms +step:61/1705 train_time:5892ms step_avg:96.59ms +step:62/1705 train_time:5985ms step_avg:96.53ms +step:63/1705 train_time:6078ms step_avg:96.48ms +step:64/1705 train_time:6171ms step_avg:96.42ms +step:65/1705 train_time:6264ms step_avg:96.37ms +step:66/1705 train_time:6358ms step_avg:96.33ms +step:67/1705 train_time:6450ms step_avg:96.27ms +step:68/1705 train_time:6544ms step_avg:96.24ms +step:69/1705 train_time:6637ms step_avg:96.19ms +step:70/1705 train_time:6731ms step_avg:96.15ms +step:71/1705 train_time:6825ms step_avg:96.12ms +step:72/1705 train_time:6917ms step_avg:96.07ms +step:73/1705 train_time:7010ms step_avg:96.02ms +step:74/1705 train_time:7104ms step_avg:96.00ms +step:75/1705 train_time:7197ms step_avg:95.96ms +step:76/1705 train_time:7290ms step_avg:95.92ms +step:77/1705 
train_time:7383ms step_avg:95.89ms +step:78/1705 train_time:7477ms step_avg:95.85ms +step:79/1705 train_time:7570ms step_avg:95.82ms +step:80/1705 train_time:7663ms step_avg:95.79ms +step:81/1705 train_time:7757ms step_avg:95.76ms +step:82/1705 train_time:7849ms step_avg:95.72ms +step:83/1705 train_time:7943ms step_avg:95.70ms +step:84/1705 train_time:8036ms step_avg:95.66ms +step:85/1705 train_time:8129ms step_avg:95.64ms +step:86/1705 train_time:8223ms step_avg:95.62ms +step:87/1705 train_time:8317ms step_avg:95.60ms +step:88/1705 train_time:8410ms step_avg:95.56ms +step:89/1705 train_time:8503ms step_avg:95.54ms +step:90/1705 train_time:8596ms step_avg:95.51ms +step:91/1705 train_time:8689ms step_avg:95.49ms +step:92/1705 train_time:8783ms step_avg:95.47ms +step:93/1705 train_time:8876ms step_avg:95.44ms +step:94/1705 train_time:8969ms step_avg:95.41ms +step:95/1705 train_time:9063ms step_avg:95.40ms +step:96/1705 train_time:9155ms step_avg:95.37ms +step:97/1705 train_time:9249ms step_avg:95.35ms +step:98/1705 train_time:9343ms step_avg:95.33ms +step:99/1705 train_time:9437ms step_avg:95.32ms +step:100/1705 train_time:9529ms step_avg:95.29ms +step:101/1705 train_time:9622ms step_avg:95.27ms +step:102/1705 train_time:9715ms step_avg:95.24ms +step:103/1705 train_time:9807ms step_avg:95.22ms +step:104/1705 train_time:9901ms step_avg:95.20ms +step:105/1705 train_time:9993ms step_avg:95.18ms +step:106/1705 train_time:10086ms step_avg:95.15ms +step:107/1705 train_time:10179ms step_avg:95.13ms +step:108/1705 train_time:10272ms step_avg:95.11ms +step:109/1705 train_time:10366ms step_avg:95.10ms +step:110/1705 train_time:10458ms step_avg:95.07ms +step:111/1705 train_time:10551ms step_avg:95.05ms +step:112/1705 train_time:10645ms step_avg:95.04ms +step:113/1705 train_time:10737ms step_avg:95.02ms +step:114/1705 train_time:10830ms step_avg:95.00ms +step:115/1705 train_time:10924ms step_avg:94.99ms +step:116/1705 train_time:11017ms step_avg:94.97ms +step:117/1705 train_time:11109ms step_avg:94.95ms +step:118/1705 train_time:11203ms step_avg:94.94ms +step:119/1705 train_time:11296ms step_avg:94.92ms +step:120/1705 train_time:11388ms step_avg:94.90ms +step:121/1705 train_time:11481ms step_avg:94.89ms +step:122/1705 train_time:11574ms step_avg:94.87ms +step:123/1705 train_time:11666ms step_avg:94.85ms +step:124/1705 train_time:11759ms step_avg:94.83ms +step:125/1705 train_time:11853ms step_avg:94.83ms +step:125/1705 val_loss:4.2975 train_time:11947ms step_avg:95.57ms +step:126/1705 train_time:11976ms step_avg:95.05ms +step:127/1705 train_time:12046ms step_avg:94.85ms +step:128/1705 train_time:12150ms step_avg:94.92ms +step:129/1705 train_time:12243ms step_avg:94.90ms +step:130/1705 train_time:12335ms step_avg:94.88ms +step:131/1705 train_time:12427ms step_avg:94.86ms +step:132/1705 train_time:12519ms step_avg:94.84ms +step:133/1705 train_time:12611ms step_avg:94.82ms +step:134/1705 train_time:12703ms step_avg:94.80ms +step:135/1705 train_time:12795ms step_avg:94.78ms +step:136/1705 train_time:12887ms step_avg:94.76ms +step:137/1705 train_time:12979ms step_avg:94.74ms +step:138/1705 train_time:13075ms step_avg:94.75ms +step:139/1705 train_time:13171ms step_avg:94.76ms +step:140/1705 train_time:13265ms step_avg:94.75ms +step:141/1705 train_time:13358ms step_avg:94.74ms +step:142/1705 train_time:13451ms step_avg:94.73ms +step:143/1705 train_time:13543ms step_avg:94.71ms +step:144/1705 train_time:13635ms step_avg:94.69ms +step:145/1705 train_time:13728ms step_avg:94.67ms +step:146/1705 train_time:13819ms 
step_avg:94.65ms +step:147/1705 train_time:13912ms step_avg:94.64ms +step:148/1705 train_time:14005ms step_avg:94.63ms +step:149/1705 train_time:14099ms step_avg:94.62ms +step:150/1705 train_time:14193ms step_avg:94.62ms +step:151/1705 train_time:14286ms step_avg:94.61ms +step:152/1705 train_time:14379ms step_avg:94.60ms +step:153/1705 train_time:14472ms step_avg:94.59ms +step:154/1705 train_time:14566ms step_avg:94.58ms +step:155/1705 train_time:14658ms step_avg:94.57ms +step:156/1705 train_time:14751ms step_avg:94.56ms +step:157/1705 train_time:14843ms step_avg:94.54ms +step:158/1705 train_time:14936ms step_avg:94.53ms +step:159/1705 train_time:15028ms step_avg:94.52ms +step:160/1705 train_time:15121ms step_avg:94.51ms +step:161/1705 train_time:15214ms step_avg:94.50ms +step:162/1705 train_time:15308ms step_avg:94.50ms +step:163/1705 train_time:15402ms step_avg:94.49ms +step:164/1705 train_time:15495ms step_avg:94.48ms +step:165/1705 train_time:15587ms step_avg:94.47ms +step:166/1705 train_time:15679ms step_avg:94.45ms +step:167/1705 train_time:15772ms step_avg:94.44ms +step:168/1705 train_time:15864ms step_avg:94.43ms +step:169/1705 train_time:15956ms step_avg:94.42ms +step:170/1705 train_time:16049ms step_avg:94.41ms +step:171/1705 train_time:16143ms step_avg:94.40ms +step:172/1705 train_time:16237ms step_avg:94.40ms +step:173/1705 train_time:16330ms step_avg:94.39ms +step:174/1705 train_time:16422ms step_avg:94.38ms +step:175/1705 train_time:16515ms step_avg:94.37ms +step:176/1705 train_time:16608ms step_avg:94.37ms +step:177/1705 train_time:16701ms step_avg:94.35ms +step:178/1705 train_time:16794ms step_avg:94.35ms +step:179/1705 train_time:16886ms step_avg:94.34ms +step:180/1705 train_time:16978ms step_avg:94.32ms +step:181/1705 train_time:17072ms step_avg:94.32ms +step:182/1705 train_time:17164ms step_avg:94.31ms +step:183/1705 train_time:17257ms step_avg:94.30ms +step:184/1705 train_time:17351ms step_avg:94.30ms +step:185/1705 train_time:17444ms step_avg:94.29ms +step:186/1705 train_time:17536ms step_avg:94.28ms +step:187/1705 train_time:17629ms step_avg:94.27ms +step:188/1705 train_time:17722ms step_avg:94.26ms +step:189/1705 train_time:17815ms step_avg:94.26ms +step:190/1705 train_time:17907ms step_avg:94.25ms +step:191/1705 train_time:18000ms step_avg:94.24ms +step:192/1705 train_time:18093ms step_avg:94.23ms +step:193/1705 train_time:18186ms step_avg:94.23ms +step:194/1705 train_time:18279ms step_avg:94.22ms +step:195/1705 train_time:18372ms step_avg:94.22ms +step:196/1705 train_time:18465ms step_avg:94.21ms +step:197/1705 train_time:18558ms step_avg:94.20ms +step:198/1705 train_time:18650ms step_avg:94.19ms +step:199/1705 train_time:18742ms step_avg:94.18ms +step:200/1705 train_time:18836ms step_avg:94.18ms +step:201/1705 train_time:18929ms step_avg:94.17ms +step:202/1705 train_time:19021ms step_avg:94.16ms +step:203/1705 train_time:19114ms step_avg:94.16ms +step:204/1705 train_time:19207ms step_avg:94.15ms +step:205/1705 train_time:19300ms step_avg:94.15ms +step:206/1705 train_time:19393ms step_avg:94.14ms +step:207/1705 train_time:19485ms step_avg:94.13ms +step:208/1705 train_time:19578ms step_avg:94.13ms +step:209/1705 train_time:19671ms step_avg:94.12ms +step:210/1705 train_time:19764ms step_avg:94.11ms +step:211/1705 train_time:19857ms step_avg:94.11ms +step:212/1705 train_time:19951ms step_avg:94.11ms +step:213/1705 train_time:20230ms step_avg:94.98ms +step:214/1705 train_time:20369ms step_avg:95.18ms +step:215/1705 train_time:20460ms step_avg:95.16ms +step:216/1705 
train_time:20552ms step_avg:95.15ms +step:217/1705 train_time:20644ms step_avg:95.13ms +step:218/1705 train_time:20736ms step_avg:95.12ms +step:219/1705 train_time:20827ms step_avg:95.10ms +step:220/1705 train_time:20919ms step_avg:95.09ms +step:221/1705 train_time:21011ms step_avg:95.07ms +step:222/1705 train_time:21103ms step_avg:95.06ms +step:223/1705 train_time:21196ms step_avg:95.05ms +step:224/1705 train_time:21292ms step_avg:95.06ms +step:225/1705 train_time:21389ms step_avg:95.06ms +step:226/1705 train_time:21482ms step_avg:95.05ms +step:227/1705 train_time:21574ms step_avg:95.04ms +step:228/1705 train_time:21666ms step_avg:95.03ms +step:229/1705 train_time:21758ms step_avg:95.01ms +step:230/1705 train_time:21851ms step_avg:95.00ms +step:231/1705 train_time:21943ms step_avg:94.99ms +step:232/1705 train_time:22035ms step_avg:94.98ms +step:233/1705 train_time:22127ms step_avg:94.96ms +step:234/1705 train_time:22219ms step_avg:94.95ms +step:235/1705 train_time:22313ms step_avg:94.95ms +step:236/1705 train_time:22407ms step_avg:94.94ms +step:237/1705 train_time:22500ms step_avg:94.94ms +step:238/1705 train_time:22593ms step_avg:94.93ms +step:239/1705 train_time:22685ms step_avg:94.92ms +step:240/1705 train_time:22777ms step_avg:94.90ms +step:241/1705 train_time:22870ms step_avg:94.89ms +step:242/1705 train_time:22961ms step_avg:94.88ms +step:243/1705 train_time:23054ms step_avg:94.87ms +step:244/1705 train_time:23146ms step_avg:94.86ms +step:245/1705 train_time:23238ms step_avg:94.85ms +step:246/1705 train_time:23332ms step_avg:94.84ms +step:247/1705 train_time:23425ms step_avg:94.84ms +step:248/1705 train_time:23519ms step_avg:94.83ms +step:249/1705 train_time:23612ms step_avg:94.83ms +step:250/1705 train_time:23705ms step_avg:94.82ms +step:250/1705 val_loss:3.9686 train_time:23798ms step_avg:95.19ms +step:251/1705 train_time:23819ms step_avg:94.90ms +step:252/1705 train_time:23896ms step_avg:94.82ms +step:253/1705 train_time:23995ms step_avg:94.84ms +step:254/1705 train_time:24089ms step_avg:94.84ms +step:255/1705 train_time:24181ms step_avg:94.83ms +step:256/1705 train_time:24273ms step_avg:94.82ms +step:257/1705 train_time:24365ms step_avg:94.80ms +step:258/1705 train_time:24457ms step_avg:94.79ms +step:259/1705 train_time:24549ms step_avg:94.78ms +step:260/1705 train_time:24641ms step_avg:94.77ms +step:261/1705 train_time:24733ms step_avg:94.76ms +step:262/1705 train_time:24827ms step_avg:94.76ms +step:263/1705 train_time:24921ms step_avg:94.76ms +step:264/1705 train_time:25016ms step_avg:94.76ms +step:265/1705 train_time:25110ms step_avg:94.76ms +step:266/1705 train_time:25203ms step_avg:94.75ms +step:267/1705 train_time:25295ms step_avg:94.74ms +step:268/1705 train_time:25388ms step_avg:94.73ms +step:269/1705 train_time:25480ms step_avg:94.72ms +step:270/1705 train_time:25573ms step_avg:94.71ms +step:271/1705 train_time:25664ms step_avg:94.70ms +step:272/1705 train_time:25758ms step_avg:94.70ms +step:273/1705 train_time:25852ms step_avg:94.69ms +step:274/1705 train_time:25945ms step_avg:94.69ms +step:275/1705 train_time:26040ms step_avg:94.69ms +step:276/1705 train_time:26134ms step_avg:94.69ms +step:277/1705 train_time:26226ms step_avg:94.68ms +step:278/1705 train_time:26319ms step_avg:94.67ms +step:279/1705 train_time:26412ms step_avg:94.67ms +step:280/1705 train_time:26504ms step_avg:94.66ms +step:281/1705 train_time:26596ms step_avg:94.65ms +step:282/1705 train_time:26689ms step_avg:94.64ms +step:283/1705 train_time:26782ms step_avg:94.64ms +step:284/1705 train_time:26875ms 
step_avg:94.63ms +step:285/1705 train_time:26969ms step_avg:94.63ms +step:286/1705 train_time:27064ms step_avg:94.63ms +step:287/1705 train_time:27157ms step_avg:94.62ms +step:288/1705 train_time:27251ms step_avg:94.62ms +step:289/1705 train_time:27344ms step_avg:94.62ms +step:290/1705 train_time:27437ms step_avg:94.61ms +step:291/1705 train_time:27530ms step_avg:94.61ms +step:292/1705 train_time:27623ms step_avg:94.60ms +step:293/1705 train_time:27715ms step_avg:94.59ms +step:294/1705 train_time:27808ms step_avg:94.58ms +step:295/1705 train_time:27901ms step_avg:94.58ms +step:296/1705 train_time:27994ms step_avg:94.57ms +step:297/1705 train_time:28088ms step_avg:94.57ms +step:298/1705 train_time:28180ms step_avg:94.57ms +step:299/1705 train_time:28274ms step_avg:94.56ms +step:300/1705 train_time:28367ms step_avg:94.56ms +step:301/1705 train_time:28459ms step_avg:94.55ms +step:302/1705 train_time:28552ms step_avg:94.54ms +step:303/1705 train_time:28644ms step_avg:94.54ms +step:304/1705 train_time:28737ms step_avg:94.53ms +step:305/1705 train_time:28830ms step_avg:94.52ms +step:306/1705 train_time:28923ms step_avg:94.52ms +step:307/1705 train_time:29016ms step_avg:94.51ms +step:308/1705 train_time:29110ms step_avg:94.51ms +step:309/1705 train_time:29202ms step_avg:94.51ms +step:310/1705 train_time:29295ms step_avg:94.50ms +step:311/1705 train_time:29388ms step_avg:94.50ms +step:312/1705 train_time:29481ms step_avg:94.49ms +step:313/1705 train_time:29574ms step_avg:94.48ms +step:314/1705 train_time:29667ms step_avg:94.48ms +step:315/1705 train_time:29759ms step_avg:94.47ms +step:316/1705 train_time:29853ms step_avg:94.47ms +step:317/1705 train_time:29945ms step_avg:94.46ms +step:318/1705 train_time:30038ms step_avg:94.46ms +step:319/1705 train_time:30132ms step_avg:94.46ms +step:320/1705 train_time:30225ms step_avg:94.45ms +step:321/1705 train_time:30318ms step_avg:94.45ms +step:322/1705 train_time:30411ms step_avg:94.44ms +step:323/1705 train_time:30504ms step_avg:94.44ms +step:324/1705 train_time:30597ms step_avg:94.43ms +step:325/1705 train_time:30689ms step_avg:94.43ms +step:326/1705 train_time:30783ms step_avg:94.42ms +step:327/1705 train_time:30875ms step_avg:94.42ms +step:328/1705 train_time:30968ms step_avg:94.41ms +step:329/1705 train_time:31061ms step_avg:94.41ms +step:330/1705 train_time:31154ms step_avg:94.40ms +step:331/1705 train_time:31247ms step_avg:94.40ms +step:332/1705 train_time:31340ms step_avg:94.40ms +step:333/1705 train_time:31433ms step_avg:94.39ms +step:334/1705 train_time:31526ms step_avg:94.39ms +step:335/1705 train_time:31618ms step_avg:94.38ms +step:336/1705 train_time:31711ms step_avg:94.38ms +step:337/1705 train_time:31803ms step_avg:94.37ms +step:338/1705 train_time:31896ms step_avg:94.37ms +step:339/1705 train_time:31989ms step_avg:94.36ms +step:340/1705 train_time:32082ms step_avg:94.36ms +step:341/1705 train_time:32175ms step_avg:94.36ms +step:342/1705 train_time:32268ms step_avg:94.35ms +step:343/1705 train_time:32361ms step_avg:94.35ms +step:344/1705 train_time:32454ms step_avg:94.34ms +step:345/1705 train_time:32546ms step_avg:94.34ms +step:346/1705 train_time:32640ms step_avg:94.33ms +step:347/1705 train_time:32733ms step_avg:94.33ms +step:348/1705 train_time:32825ms step_avg:94.33ms +step:349/1705 train_time:32918ms step_avg:94.32ms +step:350/1705 train_time:33011ms step_avg:94.32ms +step:351/1705 train_time:33104ms step_avg:94.31ms +step:352/1705 train_time:33197ms step_avg:94.31ms +step:353/1705 train_time:33291ms step_avg:94.31ms +step:354/1705 
train_time:33384ms step_avg:94.30ms +step:355/1705 train_time:33477ms step_avg:94.30ms +step:356/1705 train_time:33570ms step_avg:94.30ms +step:357/1705 train_time:33663ms step_avg:94.29ms +step:358/1705 train_time:33756ms step_avg:94.29ms +step:359/1705 train_time:33850ms step_avg:94.29ms +step:360/1705 train_time:33943ms step_avg:94.29ms +step:361/1705 train_time:34036ms step_avg:94.28ms +step:362/1705 train_time:34129ms step_avg:94.28ms +step:363/1705 train_time:34222ms step_avg:94.28ms +step:364/1705 train_time:34315ms step_avg:94.27ms +step:365/1705 train_time:34408ms step_avg:94.27ms +step:366/1705 train_time:34501ms step_avg:94.27ms +step:367/1705 train_time:34594ms step_avg:94.26ms +step:368/1705 train_time:34687ms step_avg:94.26ms +step:369/1705 train_time:34780ms step_avg:94.25ms +step:370/1705 train_time:34873ms step_avg:94.25ms +step:371/1705 train_time:34966ms step_avg:94.25ms +step:372/1705 train_time:35059ms step_avg:94.24ms +step:373/1705 train_time:35152ms step_avg:94.24ms +step:374/1705 train_time:35245ms step_avg:94.24ms +step:375/1705 train_time:35339ms step_avg:94.24ms +step:375/1705 val_loss:3.8161 train_time:35433ms step_avg:94.49ms +step:376/1705 train_time:35453ms step_avg:94.29ms +step:377/1705 train_time:35530ms step_avg:94.24ms +step:378/1705 train_time:35628ms step_avg:94.25ms +step:379/1705 train_time:35722ms step_avg:94.25ms +step:380/1705 train_time:35814ms step_avg:94.25ms +step:381/1705 train_time:35906ms step_avg:94.24ms +step:382/1705 train_time:35998ms step_avg:94.24ms +step:383/1705 train_time:36091ms step_avg:94.23ms +step:384/1705 train_time:36183ms step_avg:94.23ms +step:385/1705 train_time:36275ms step_avg:94.22ms +step:386/1705 train_time:36368ms step_avg:94.22ms +step:387/1705 train_time:36461ms step_avg:94.22ms +step:388/1705 train_time:36556ms step_avg:94.22ms +step:389/1705 train_time:36651ms step_avg:94.22ms +step:390/1705 train_time:36744ms step_avg:94.22ms +step:391/1705 train_time:36837ms step_avg:94.21ms +step:392/1705 train_time:36929ms step_avg:94.21ms +step:393/1705 train_time:37022ms step_avg:94.20ms +step:394/1705 train_time:37114ms step_avg:94.20ms +step:395/1705 train_time:37207ms step_avg:94.20ms +step:396/1705 train_time:37300ms step_avg:94.19ms +step:397/1705 train_time:37393ms step_avg:94.19ms +step:398/1705 train_time:37487ms step_avg:94.19ms +step:399/1705 train_time:37581ms step_avg:94.19ms +step:400/1705 train_time:37675ms step_avg:94.19ms +step:401/1705 train_time:37769ms step_avg:94.19ms +step:402/1705 train_time:37861ms step_avg:94.18ms +step:403/1705 train_time:37953ms step_avg:94.18ms +step:404/1705 train_time:38046ms step_avg:94.17ms +step:405/1705 train_time:38138ms step_avg:94.17ms +step:406/1705 train_time:38231ms step_avg:94.17ms +step:407/1705 train_time:38324ms step_avg:94.16ms +step:408/1705 train_time:38417ms step_avg:94.16ms +step:409/1705 train_time:38511ms step_avg:94.16ms +step:410/1705 train_time:38604ms step_avg:94.16ms +step:411/1705 train_time:38698ms step_avg:94.15ms +step:412/1705 train_time:38792ms step_avg:94.15ms +step:413/1705 train_time:38885ms step_avg:94.15ms +step:414/1705 train_time:38978ms step_avg:94.15ms +step:415/1705 train_time:39070ms step_avg:94.14ms +step:416/1705 train_time:39162ms step_avg:94.14ms +step:417/1705 train_time:39255ms step_avg:94.14ms +step:418/1705 train_time:39349ms step_avg:94.14ms +step:419/1705 train_time:39441ms step_avg:94.13ms +step:420/1705 train_time:39535ms step_avg:94.13ms +step:421/1705 train_time:39629ms step_avg:94.13ms +step:422/1705 train_time:39722ms 
step_avg:94.13ms +step:423/1705 train_time:39815ms step_avg:94.13ms +step:424/1705 train_time:39909ms step_avg:94.12ms +step:425/1705 train_time:40236ms step_avg:94.67ms +step:426/1705 train_time:40305ms step_avg:94.61ms +step:427/1705 train_time:40396ms step_avg:94.60ms +step:428/1705 train_time:40489ms step_avg:94.60ms +step:429/1705 train_time:40581ms step_avg:94.60ms +step:430/1705 train_time:40673ms step_avg:94.59ms +step:431/1705 train_time:40766ms step_avg:94.58ms +step:432/1705 train_time:40858ms step_avg:94.58ms +step:433/1705 train_time:40949ms step_avg:94.57ms +step:434/1705 train_time:41042ms step_avg:94.57ms +step:435/1705 train_time:41135ms step_avg:94.56ms +step:436/1705 train_time:41233ms step_avg:94.57ms +step:437/1705 train_time:41331ms step_avg:94.58ms +step:438/1705 train_time:41424ms step_avg:94.58ms +step:439/1705 train_time:41518ms step_avg:94.57ms +step:440/1705 train_time:41611ms step_avg:94.57ms +step:441/1705 train_time:41703ms step_avg:94.56ms +step:442/1705 train_time:41795ms step_avg:94.56ms +step:443/1705 train_time:41888ms step_avg:94.55ms +step:444/1705 train_time:41979ms step_avg:94.55ms +step:445/1705 train_time:42071ms step_avg:94.54ms +step:446/1705 train_time:42164ms step_avg:94.54ms +step:447/1705 train_time:42258ms step_avg:94.54ms +step:448/1705 train_time:42352ms step_avg:94.54ms +step:449/1705 train_time:42445ms step_avg:94.53ms +step:450/1705 train_time:42539ms step_avg:94.53ms +step:451/1705 train_time:42632ms step_avg:94.53ms +step:452/1705 train_time:42725ms step_avg:94.52ms +step:453/1705 train_time:42819ms step_avg:94.52ms +step:454/1705 train_time:42911ms step_avg:94.52ms +step:455/1705 train_time:43003ms step_avg:94.51ms +step:456/1705 train_time:43096ms step_avg:94.51ms +step:457/1705 train_time:43190ms step_avg:94.51ms +step:458/1705 train_time:43283ms step_avg:94.50ms +step:459/1705 train_time:43377ms step_avg:94.50ms +step:460/1705 train_time:43471ms step_avg:94.50ms +step:461/1705 train_time:43564ms step_avg:94.50ms +step:462/1705 train_time:43657ms step_avg:94.50ms +step:463/1705 train_time:43750ms step_avg:94.49ms +step:464/1705 train_time:43843ms step_avg:94.49ms +step:465/1705 train_time:43935ms step_avg:94.48ms +step:466/1705 train_time:44028ms step_avg:94.48ms +step:467/1705 train_time:44120ms step_avg:94.48ms +step:468/1705 train_time:44213ms step_avg:94.47ms +step:469/1705 train_time:44307ms step_avg:94.47ms +step:470/1705 train_time:44400ms step_avg:94.47ms +step:471/1705 train_time:44493ms step_avg:94.47ms +step:472/1705 train_time:44587ms step_avg:94.46ms +step:473/1705 train_time:44680ms step_avg:94.46ms +step:474/1705 train_time:44773ms step_avg:94.46ms +step:475/1705 train_time:44866ms step_avg:94.46ms +step:476/1705 train_time:44959ms step_avg:94.45ms +step:477/1705 train_time:45051ms step_avg:94.45ms +step:478/1705 train_time:45144ms step_avg:94.44ms +step:479/1705 train_time:45237ms step_avg:94.44ms +step:480/1705 train_time:45330ms step_avg:94.44ms +step:481/1705 train_time:45423ms step_avg:94.43ms +step:482/1705 train_time:45517ms step_avg:94.43ms +step:483/1705 train_time:45611ms step_avg:94.43ms +step:484/1705 train_time:45704ms step_avg:94.43ms +step:485/1705 train_time:45797ms step_avg:94.43ms +step:486/1705 train_time:45890ms step_avg:94.42ms +step:487/1705 train_time:45983ms step_avg:94.42ms +step:488/1705 train_time:46076ms step_avg:94.42ms +step:489/1705 train_time:46169ms step_avg:94.42ms +step:490/1705 train_time:46262ms step_avg:94.41ms +step:491/1705 train_time:46355ms step_avg:94.41ms +step:492/1705 
train_time:46448ms step_avg:94.41ms +step:493/1705 train_time:46541ms step_avg:94.40ms +step:494/1705 train_time:46635ms step_avg:94.40ms +step:495/1705 train_time:46728ms step_avg:94.40ms +step:496/1705 train_time:46821ms step_avg:94.40ms +step:497/1705 train_time:46914ms step_avg:94.39ms +step:498/1705 train_time:47008ms step_avg:94.39ms +step:499/1705 train_time:47100ms step_avg:94.39ms +step:500/1705 train_time:47193ms step_avg:94.39ms +step:500/1705 val_loss:3.7129 train_time:47287ms step_avg:94.57ms +step:501/1705 train_time:47309ms step_avg:94.43ms +step:502/1705 train_time:47385ms step_avg:94.39ms +step:503/1705 train_time:47482ms step_avg:94.40ms +step:504/1705 train_time:47576ms step_avg:94.40ms +step:505/1705 train_time:47669ms step_avg:94.39ms +step:506/1705 train_time:47761ms step_avg:94.39ms +step:507/1705 train_time:47853ms step_avg:94.39ms +step:508/1705 train_time:47945ms step_avg:94.38ms +step:509/1705 train_time:48037ms step_avg:94.38ms +step:510/1705 train_time:48129ms step_avg:94.37ms +step:511/1705 train_time:48221ms step_avg:94.37ms +step:512/1705 train_time:48314ms step_avg:94.36ms +step:513/1705 train_time:48409ms step_avg:94.36ms +step:514/1705 train_time:48504ms step_avg:94.37ms +step:515/1705 train_time:48597ms step_avg:94.36ms +step:516/1705 train_time:48690ms step_avg:94.36ms +step:517/1705 train_time:48783ms step_avg:94.36ms +step:518/1705 train_time:48875ms step_avg:94.35ms +step:519/1705 train_time:48968ms step_avg:94.35ms +step:520/1705 train_time:49062ms step_avg:94.35ms +step:521/1705 train_time:49154ms step_avg:94.35ms +step:522/1705 train_time:49247ms step_avg:94.34ms +step:523/1705 train_time:49341ms step_avg:94.34ms +step:524/1705 train_time:49434ms step_avg:94.34ms +step:525/1705 train_time:49528ms step_avg:94.34ms +step:526/1705 train_time:49623ms step_avg:94.34ms +step:527/1705 train_time:49716ms step_avg:94.34ms +step:528/1705 train_time:49808ms step_avg:94.33ms +step:529/1705 train_time:49901ms step_avg:94.33ms +step:530/1705 train_time:49993ms step_avg:94.33ms +step:531/1705 train_time:50086ms step_avg:94.32ms +step:532/1705 train_time:50178ms step_avg:94.32ms +step:533/1705 train_time:50271ms step_avg:94.32ms +step:534/1705 train_time:50365ms step_avg:94.32ms +step:535/1705 train_time:50459ms step_avg:94.32ms +step:536/1705 train_time:50552ms step_avg:94.31ms +step:537/1705 train_time:50645ms step_avg:94.31ms +step:538/1705 train_time:50739ms step_avg:94.31ms +step:539/1705 train_time:50831ms step_avg:94.31ms +step:540/1705 train_time:50924ms step_avg:94.30ms +step:541/1705 train_time:51016ms step_avg:94.30ms +step:542/1705 train_time:51108ms step_avg:94.30ms +step:543/1705 train_time:51201ms step_avg:94.29ms +step:544/1705 train_time:51294ms step_avg:94.29ms +step:545/1705 train_time:51387ms step_avg:94.29ms +step:546/1705 train_time:51480ms step_avg:94.29ms +step:547/1705 train_time:51573ms step_avg:94.28ms +step:548/1705 train_time:51666ms step_avg:94.28ms +step:549/1705 train_time:51760ms step_avg:94.28ms +step:550/1705 train_time:51852ms step_avg:94.28ms +step:551/1705 train_time:51945ms step_avg:94.27ms +step:552/1705 train_time:52038ms step_avg:94.27ms +step:553/1705 train_time:52131ms step_avg:94.27ms +step:554/1705 train_time:52224ms step_avg:94.27ms +step:555/1705 train_time:52316ms step_avg:94.26ms +step:556/1705 train_time:52409ms step_avg:94.26ms +step:557/1705 train_time:52502ms step_avg:94.26ms +step:558/1705 train_time:52596ms step_avg:94.26ms +step:559/1705 train_time:52689ms step_avg:94.26ms +step:560/1705 train_time:52782ms 
step_avg:94.25ms +step:561/1705 train_time:52876ms step_avg:94.25ms +step:562/1705 train_time:52969ms step_avg:94.25ms +step:563/1705 train_time:53064ms step_avg:94.25ms +step:564/1705 train_time:53157ms step_avg:94.25ms +step:565/1705 train_time:53249ms step_avg:94.25ms +step:566/1705 train_time:53343ms step_avg:94.25ms +step:567/1705 train_time:53435ms step_avg:94.24ms +step:568/1705 train_time:53529ms step_avg:94.24ms +step:569/1705 train_time:53624ms step_avg:94.24ms +step:570/1705 train_time:53717ms step_avg:94.24ms +step:571/1705 train_time:53811ms step_avg:94.24ms +step:572/1705 train_time:53905ms step_avg:94.24ms +step:573/1705 train_time:53999ms step_avg:94.24ms +step:574/1705 train_time:54093ms step_avg:94.24ms +step:575/1705 train_time:54187ms step_avg:94.24ms +step:576/1705 train_time:54282ms step_avg:94.24ms +step:577/1705 train_time:54377ms step_avg:94.24ms +step:578/1705 train_time:54471ms step_avg:94.24ms +step:579/1705 train_time:54565ms step_avg:94.24ms +step:580/1705 train_time:54659ms step_avg:94.24ms +step:581/1705 train_time:54753ms step_avg:94.24ms +step:582/1705 train_time:54847ms step_avg:94.24ms +step:583/1705 train_time:54942ms step_avg:94.24ms +step:584/1705 train_time:55037ms step_avg:94.24ms +step:585/1705 train_time:55130ms step_avg:94.24ms +step:586/1705 train_time:55224ms step_avg:94.24ms +step:587/1705 train_time:55319ms step_avg:94.24ms +step:588/1705 train_time:55412ms step_avg:94.24ms +step:589/1705 train_time:55505ms step_avg:94.24ms +step:590/1705 train_time:55599ms step_avg:94.24ms +step:591/1705 train_time:55693ms step_avg:94.23ms +step:592/1705 train_time:55787ms step_avg:94.23ms +step:593/1705 train_time:55881ms step_avg:94.23ms +step:594/1705 train_time:55976ms step_avg:94.24ms +step:595/1705 train_time:56070ms step_avg:94.23ms +step:596/1705 train_time:56165ms step_avg:94.24ms +step:597/1705 train_time:56260ms step_avg:94.24ms +step:598/1705 train_time:56354ms step_avg:94.24ms +step:599/1705 train_time:56447ms step_avg:94.24ms +step:600/1705 train_time:56543ms step_avg:94.24ms +step:601/1705 train_time:56638ms step_avg:94.24ms +step:602/1705 train_time:56732ms step_avg:94.24ms +step:603/1705 train_time:56827ms step_avg:94.24ms +step:604/1705 train_time:56921ms step_avg:94.24ms +step:605/1705 train_time:57015ms step_avg:94.24ms +step:606/1705 train_time:57109ms step_avg:94.24ms +step:607/1705 train_time:57204ms step_avg:94.24ms +step:608/1705 train_time:57299ms step_avg:94.24ms +step:609/1705 train_time:57392ms step_avg:94.24ms +step:610/1705 train_time:57487ms step_avg:94.24ms +step:611/1705 train_time:57581ms step_avg:94.24ms +step:612/1705 train_time:57676ms step_avg:94.24ms +step:613/1705 train_time:57770ms step_avg:94.24ms +step:614/1705 train_time:57865ms step_avg:94.24ms +step:615/1705 train_time:57960ms step_avg:94.24ms +step:616/1705 train_time:58054ms step_avg:94.24ms +step:617/1705 train_time:58148ms step_avg:94.24ms +step:618/1705 train_time:58244ms step_avg:94.25ms +step:619/1705 train_time:58339ms step_avg:94.25ms +step:620/1705 train_time:58433ms step_avg:94.25ms +step:621/1705 train_time:58527ms step_avg:94.25ms +step:622/1705 train_time:58622ms step_avg:94.25ms +step:623/1705 train_time:58715ms step_avg:94.25ms +step:624/1705 train_time:58810ms step_avg:94.25ms +step:625/1705 train_time:58904ms step_avg:94.25ms +step:625/1705 val_loss:3.6156 train_time:58999ms step_avg:94.40ms +step:626/1705 train_time:59021ms step_avg:94.28ms +step:627/1705 train_time:59101ms step_avg:94.26ms +step:628/1705 train_time:59200ms step_avg:94.27ms 
+step:629/1705 train_time:59296ms step_avg:94.27ms +step:630/1705 train_time:59391ms step_avg:94.27ms +step:631/1705 train_time:59484ms step_avg:94.27ms +step:632/1705 train_time:59578ms step_avg:94.27ms +step:633/1705 train_time:59672ms step_avg:94.27ms +step:634/1705 train_time:59765ms step_avg:94.27ms +step:635/1705 train_time:59858ms step_avg:94.27ms +step:636/1705 train_time:59952ms step_avg:94.26ms +step:637/1705 train_time:60049ms step_avg:94.27ms +step:638/1705 train_time:60144ms step_avg:94.27ms +step:639/1705 train_time:60502ms step_avg:94.68ms +step:640/1705 train_time:60599ms step_avg:94.69ms +step:641/1705 train_time:60692ms step_avg:94.68ms +step:642/1705 train_time:60785ms step_avg:94.68ms +step:643/1705 train_time:60879ms step_avg:94.68ms +step:644/1705 train_time:60973ms step_avg:94.68ms +step:645/1705 train_time:61066ms step_avg:94.68ms +step:646/1705 train_time:61158ms step_avg:94.67ms +step:647/1705 train_time:61251ms step_avg:94.67ms +step:648/1705 train_time:61345ms step_avg:94.67ms +step:649/1705 train_time:61442ms step_avg:94.67ms +step:650/1705 train_time:61539ms step_avg:94.68ms +step:651/1705 train_time:61636ms step_avg:94.68ms +step:652/1705 train_time:61730ms step_avg:94.68ms +step:653/1705 train_time:61825ms step_avg:94.68ms +step:654/1705 train_time:61919ms step_avg:94.68ms +step:655/1705 train_time:62013ms step_avg:94.68ms +step:656/1705 train_time:62106ms step_avg:94.67ms +step:657/1705 train_time:62199ms step_avg:94.67ms +step:658/1705 train_time:62294ms step_avg:94.67ms +step:659/1705 train_time:62389ms step_avg:94.67ms +step:660/1705 train_time:62484ms step_avg:94.67ms +step:661/1705 train_time:62580ms step_avg:94.67ms +step:662/1705 train_time:62676ms step_avg:94.68ms +step:663/1705 train_time:62770ms step_avg:94.68ms +step:664/1705 train_time:62864ms step_avg:94.68ms +step:665/1705 train_time:62959ms step_avg:94.67ms +step:666/1705 train_time:63053ms step_avg:94.67ms +step:667/1705 train_time:63146ms step_avg:94.67ms +step:668/1705 train_time:63241ms step_avg:94.67ms +step:669/1705 train_time:63335ms step_avg:94.67ms +step:670/1705 train_time:63430ms step_avg:94.67ms +step:671/1705 train_time:63523ms step_avg:94.67ms +step:672/1705 train_time:63619ms step_avg:94.67ms +step:673/1705 train_time:63715ms step_avg:94.67ms +step:674/1705 train_time:63809ms step_avg:94.67ms +step:675/1705 train_time:63904ms step_avg:94.67ms +step:676/1705 train_time:63999ms step_avg:94.67ms +step:677/1705 train_time:64093ms step_avg:94.67ms +step:678/1705 train_time:64187ms step_avg:94.67ms +step:679/1705 train_time:64280ms step_avg:94.67ms +step:680/1705 train_time:64375ms step_avg:94.67ms +step:681/1705 train_time:64470ms step_avg:94.67ms +step:682/1705 train_time:64564ms step_avg:94.67ms +step:683/1705 train_time:64659ms step_avg:94.67ms +step:684/1705 train_time:64754ms step_avg:94.67ms +step:685/1705 train_time:64849ms step_avg:94.67ms +step:686/1705 train_time:64943ms step_avg:94.67ms +step:687/1705 train_time:65037ms step_avg:94.67ms +step:688/1705 train_time:65130ms step_avg:94.67ms +step:689/1705 train_time:65224ms step_avg:94.66ms +step:690/1705 train_time:65318ms step_avg:94.66ms +step:691/1705 train_time:65414ms step_avg:94.67ms +step:692/1705 train_time:65508ms step_avg:94.67ms +step:693/1705 train_time:65603ms step_avg:94.67ms +step:694/1705 train_time:65698ms step_avg:94.67ms +step:695/1705 train_time:65794ms step_avg:94.67ms +step:696/1705 train_time:65889ms step_avg:94.67ms +step:697/1705 train_time:65983ms step_avg:94.67ms +step:698/1705 train_time:66078ms 
step_avg:94.67ms +step:699/1705 train_time:66172ms step_avg:94.67ms +step:700/1705 train_time:66266ms step_avg:94.67ms +step:701/1705 train_time:66360ms step_avg:94.66ms +step:702/1705 train_time:66456ms step_avg:94.67ms +step:703/1705 train_time:66551ms step_avg:94.67ms +step:704/1705 train_time:66645ms step_avg:94.67ms +step:705/1705 train_time:66740ms step_avg:94.67ms +step:706/1705 train_time:66834ms step_avg:94.67ms +step:707/1705 train_time:66928ms step_avg:94.66ms +step:708/1705 train_time:67022ms step_avg:94.66ms +step:709/1705 train_time:67116ms step_avg:94.66ms +step:710/1705 train_time:67211ms step_avg:94.66ms +step:711/1705 train_time:67304ms step_avg:94.66ms +step:712/1705 train_time:67398ms step_avg:94.66ms +step:713/1705 train_time:67493ms step_avg:94.66ms +step:714/1705 train_time:67587ms step_avg:94.66ms +step:715/1705 train_time:67681ms step_avg:94.66ms +step:716/1705 train_time:67776ms step_avg:94.66ms +step:717/1705 train_time:67870ms step_avg:94.66ms +step:718/1705 train_time:67964ms step_avg:94.66ms +step:719/1705 train_time:68057ms step_avg:94.66ms +step:720/1705 train_time:68152ms step_avg:94.66ms +step:721/1705 train_time:68245ms step_avg:94.65ms +step:722/1705 train_time:68339ms step_avg:94.65ms +step:723/1705 train_time:68434ms step_avg:94.65ms +step:724/1705 train_time:68528ms step_avg:94.65ms +step:725/1705 train_time:68623ms step_avg:94.65ms +step:726/1705 train_time:68717ms step_avg:94.65ms +step:727/1705 train_time:68812ms step_avg:94.65ms +step:728/1705 train_time:68906ms step_avg:94.65ms +step:729/1705 train_time:69000ms step_avg:94.65ms +step:730/1705 train_time:69096ms step_avg:94.65ms +step:731/1705 train_time:69191ms step_avg:94.65ms +step:732/1705 train_time:69285ms step_avg:94.65ms +step:733/1705 train_time:69380ms step_avg:94.65ms +step:734/1705 train_time:69475ms step_avg:94.65ms +step:735/1705 train_time:69569ms step_avg:94.65ms +step:736/1705 train_time:69663ms step_avg:94.65ms +step:737/1705 train_time:69758ms step_avg:94.65ms +step:738/1705 train_time:69853ms step_avg:94.65ms +step:739/1705 train_time:69948ms step_avg:94.65ms +step:740/1705 train_time:70042ms step_avg:94.65ms +step:741/1705 train_time:70136ms step_avg:94.65ms +step:742/1705 train_time:70230ms step_avg:94.65ms +step:743/1705 train_time:70324ms step_avg:94.65ms +step:744/1705 train_time:70419ms step_avg:94.65ms +step:745/1705 train_time:70513ms step_avg:94.65ms +step:746/1705 train_time:70607ms step_avg:94.65ms +step:747/1705 train_time:70702ms step_avg:94.65ms +step:748/1705 train_time:70796ms step_avg:94.65ms +step:749/1705 train_time:70891ms step_avg:94.65ms +step:750/1705 train_time:70985ms step_avg:94.65ms +step:750/1705 val_loss:3.5630 train_time:71080ms step_avg:94.77ms +step:751/1705 train_time:71101ms step_avg:94.67ms +step:752/1705 train_time:71181ms step_avg:94.66ms +step:753/1705 train_time:71279ms step_avg:94.66ms +step:754/1705 train_time:71375ms step_avg:94.66ms +step:755/1705 train_time:71469ms step_avg:94.66ms +step:756/1705 train_time:71561ms step_avg:94.66ms +step:757/1705 train_time:71655ms step_avg:94.66ms +step:758/1705 train_time:71749ms step_avg:94.66ms +step:759/1705 train_time:71842ms step_avg:94.65ms +step:760/1705 train_time:71935ms step_avg:94.65ms +step:761/1705 train_time:72030ms step_avg:94.65ms +step:762/1705 train_time:72127ms step_avg:94.65ms +step:763/1705 train_time:72223ms step_avg:94.66ms +step:764/1705 train_time:72318ms step_avg:94.66ms +step:765/1705 train_time:72414ms step_avg:94.66ms +step:766/1705 train_time:72508ms step_avg:94.66ms 
+step:767/1705 train_time:72601ms step_avg:94.66ms +step:768/1705 train_time:72695ms step_avg:94.65ms +step:769/1705 train_time:72788ms step_avg:94.65ms +step:770/1705 train_time:72881ms step_avg:94.65ms +step:771/1705 train_time:72975ms step_avg:94.65ms +step:772/1705 train_time:73069ms step_avg:94.65ms +step:773/1705 train_time:73164ms step_avg:94.65ms +step:774/1705 train_time:73260ms step_avg:94.65ms +step:775/1705 train_time:73355ms step_avg:94.65ms +step:776/1705 train_time:73450ms step_avg:94.65ms +step:777/1705 train_time:73544ms step_avg:94.65ms +step:778/1705 train_time:73638ms step_avg:94.65ms +step:779/1705 train_time:73732ms step_avg:94.65ms +step:780/1705 train_time:73826ms step_avg:94.65ms +step:781/1705 train_time:73920ms step_avg:94.65ms +step:782/1705 train_time:74014ms step_avg:94.65ms +step:783/1705 train_time:74109ms step_avg:94.65ms +step:784/1705 train_time:74203ms step_avg:94.65ms +step:785/1705 train_time:74299ms step_avg:94.65ms +step:786/1705 train_time:74394ms step_avg:94.65ms +step:787/1705 train_time:74490ms step_avg:94.65ms +step:788/1705 train_time:74584ms step_avg:94.65ms +step:789/1705 train_time:74679ms step_avg:94.65ms +step:790/1705 train_time:74772ms step_avg:94.65ms +step:791/1705 train_time:74866ms step_avg:94.65ms +step:792/1705 train_time:74960ms step_avg:94.65ms +step:793/1705 train_time:75055ms step_avg:94.65ms +step:794/1705 train_time:75151ms step_avg:94.65ms +step:795/1705 train_time:75244ms step_avg:94.65ms +step:796/1705 train_time:75339ms step_avg:94.65ms +step:797/1705 train_time:75435ms step_avg:94.65ms +step:798/1705 train_time:75530ms step_avg:94.65ms +step:799/1705 train_time:75624ms step_avg:94.65ms +step:800/1705 train_time:75718ms step_avg:94.65ms +step:801/1705 train_time:75813ms step_avg:94.65ms +step:802/1705 train_time:75908ms step_avg:94.65ms +step:803/1705 train_time:76001ms step_avg:94.65ms +step:804/1705 train_time:76095ms step_avg:94.65ms +step:805/1705 train_time:76190ms step_avg:94.65ms +step:806/1705 train_time:76283ms step_avg:94.64ms +step:807/1705 train_time:76378ms step_avg:94.64ms +step:808/1705 train_time:76473ms step_avg:94.65ms +step:809/1705 train_time:76568ms step_avg:94.65ms +step:810/1705 train_time:76662ms step_avg:94.64ms +step:811/1705 train_time:76756ms step_avg:94.64ms +step:812/1705 train_time:76850ms step_avg:94.64ms +step:813/1705 train_time:76944ms step_avg:94.64ms +step:814/1705 train_time:77038ms step_avg:94.64ms +step:815/1705 train_time:77133ms step_avg:94.64ms +step:816/1705 train_time:77227ms step_avg:94.64ms +step:817/1705 train_time:77321ms step_avg:94.64ms +step:818/1705 train_time:77417ms step_avg:94.64ms +step:819/1705 train_time:77513ms step_avg:94.64ms +step:820/1705 train_time:77607ms step_avg:94.64ms +step:821/1705 train_time:77700ms step_avg:94.64ms +step:822/1705 train_time:77794ms step_avg:94.64ms +step:823/1705 train_time:77889ms step_avg:94.64ms +step:824/1705 train_time:77982ms step_avg:94.64ms +step:825/1705 train_time:78077ms step_avg:94.64ms +step:826/1705 train_time:78171ms step_avg:94.64ms +step:827/1705 train_time:78266ms step_avg:94.64ms +step:828/1705 train_time:78360ms step_avg:94.64ms +step:829/1705 train_time:78456ms step_avg:94.64ms +step:830/1705 train_time:78550ms step_avg:94.64ms +step:831/1705 train_time:78644ms step_avg:94.64ms +step:832/1705 train_time:78739ms step_avg:94.64ms +step:833/1705 train_time:78833ms step_avg:94.64ms +step:834/1705 train_time:78928ms step_avg:94.64ms +step:835/1705 train_time:79021ms step_avg:94.64ms +step:836/1705 train_time:79116ms 
step_avg:94.64ms +step:837/1705 train_time:79211ms step_avg:94.64ms +step:838/1705 train_time:79304ms step_avg:94.64ms +step:839/1705 train_time:79398ms step_avg:94.63ms +step:840/1705 train_time:79494ms step_avg:94.64ms +step:841/1705 train_time:79589ms step_avg:94.64ms +step:842/1705 train_time:79683ms step_avg:94.64ms +step:843/1705 train_time:79777ms step_avg:94.63ms +step:844/1705 train_time:79872ms step_avg:94.63ms +step:845/1705 train_time:79966ms step_avg:94.63ms +step:846/1705 train_time:80060ms step_avg:94.63ms +step:847/1705 train_time:80156ms step_avg:94.63ms +step:848/1705 train_time:80250ms step_avg:94.63ms +step:849/1705 train_time:80344ms step_avg:94.63ms +step:850/1705 train_time:80439ms step_avg:94.63ms +step:851/1705 train_time:80722ms step_avg:94.86ms +step:852/1705 train_time:80792ms step_avg:94.83ms +step:853/1705 train_time:80885ms step_avg:94.82ms +step:854/1705 train_time:80978ms step_avg:94.82ms +step:855/1705 train_time:81071ms step_avg:94.82ms +step:856/1705 train_time:81165ms step_avg:94.82ms +step:857/1705 train_time:81258ms step_avg:94.82ms +step:858/1705 train_time:81352ms step_avg:94.82ms +step:859/1705 train_time:81445ms step_avg:94.81ms +step:860/1705 train_time:81539ms step_avg:94.81ms +step:861/1705 train_time:81639ms step_avg:94.82ms +step:862/1705 train_time:81738ms step_avg:94.82ms +step:863/1705 train_time:81837ms step_avg:94.83ms +step:864/1705 train_time:81932ms step_avg:94.83ms +step:865/1705 train_time:82026ms step_avg:94.83ms +step:866/1705 train_time:82119ms step_avg:94.83ms +step:867/1705 train_time:82213ms step_avg:94.82ms +step:868/1705 train_time:82306ms step_avg:94.82ms +step:869/1705 train_time:82399ms step_avg:94.82ms +step:870/1705 train_time:82494ms step_avg:94.82ms +step:871/1705 train_time:82589ms step_avg:94.82ms +step:872/1705 train_time:82685ms step_avg:94.82ms +step:873/1705 train_time:82780ms step_avg:94.82ms +step:874/1705 train_time:82875ms step_avg:94.82ms +step:875/1705 train_time:82970ms step_avg:94.82ms +step:875/1705 val_loss:3.5205 train_time:83065ms step_avg:94.93ms +step:876/1705 train_time:83086ms step_avg:94.85ms +step:877/1705 train_time:83165ms step_avg:94.83ms +step:878/1705 train_time:83262ms step_avg:94.83ms +step:879/1705 train_time:83357ms step_avg:94.83ms +step:880/1705 train_time:83451ms step_avg:94.83ms +step:881/1705 train_time:83544ms step_avg:94.83ms +step:882/1705 train_time:83638ms step_avg:94.83ms +step:883/1705 train_time:83731ms step_avg:94.83ms +step:884/1705 train_time:83824ms step_avg:94.82ms +step:885/1705 train_time:83919ms step_avg:94.82ms +step:886/1705 train_time:84014ms step_avg:94.82ms +step:887/1705 train_time:84112ms step_avg:94.83ms +step:888/1705 train_time:84210ms step_avg:94.83ms +step:889/1705 train_time:84304ms step_avg:94.83ms +step:890/1705 train_time:84399ms step_avg:94.83ms +step:891/1705 train_time:84495ms step_avg:94.83ms +step:892/1705 train_time:84588ms step_avg:94.83ms +step:893/1705 train_time:84681ms step_avg:94.83ms +step:894/1705 train_time:84775ms step_avg:94.83ms +step:895/1705 train_time:84868ms step_avg:94.82ms +step:896/1705 train_time:84962ms step_avg:94.82ms +step:897/1705 train_time:85057ms step_avg:94.82ms +step:898/1705 train_time:85153ms step_avg:94.83ms +step:899/1705 train_time:85248ms step_avg:94.83ms +step:900/1705 train_time:85342ms step_avg:94.82ms +step:901/1705 train_time:85437ms step_avg:94.82ms +step:902/1705 train_time:85532ms step_avg:94.82ms +step:903/1705 train_time:85627ms step_avg:94.82ms +step:904/1705 train_time:85720ms step_avg:94.82ms 
+step:905/1705 train_time:85814ms step_avg:94.82ms +step:906/1705 train_time:85907ms step_avg:94.82ms +step:907/1705 train_time:86000ms step_avg:94.82ms +step:908/1705 train_time:86095ms step_avg:94.82ms +step:909/1705 train_time:86191ms step_avg:94.82ms +step:910/1705 train_time:86288ms step_avg:94.82ms +step:911/1705 train_time:86381ms step_avg:94.82ms +step:912/1705 train_time:86477ms step_avg:94.82ms +step:913/1705 train_time:86572ms step_avg:94.82ms +step:914/1705 train_time:86666ms step_avg:94.82ms +step:915/1705 train_time:86760ms step_avg:94.82ms +step:916/1705 train_time:86855ms step_avg:94.82ms +step:917/1705 train_time:86950ms step_avg:94.82ms +step:918/1705 train_time:87044ms step_avg:94.82ms +step:919/1705 train_time:87138ms step_avg:94.82ms +step:920/1705 train_time:87233ms step_avg:94.82ms +step:921/1705 train_time:87328ms step_avg:94.82ms +step:922/1705 train_time:87422ms step_avg:94.82ms +step:923/1705 train_time:87516ms step_avg:94.82ms +step:924/1705 train_time:87610ms step_avg:94.82ms +step:925/1705 train_time:87704ms step_avg:94.81ms +step:926/1705 train_time:87798ms step_avg:94.81ms +step:927/1705 train_time:87892ms step_avg:94.81ms +step:928/1705 train_time:87986ms step_avg:94.81ms +step:929/1705 train_time:88081ms step_avg:94.81ms +step:930/1705 train_time:88176ms step_avg:94.81ms +step:931/1705 train_time:88271ms step_avg:94.81ms +step:932/1705 train_time:88366ms step_avg:94.81ms +step:933/1705 train_time:88460ms step_avg:94.81ms +step:934/1705 train_time:88555ms step_avg:94.81ms +step:935/1705 train_time:88651ms step_avg:94.81ms +step:936/1705 train_time:88745ms step_avg:94.81ms +step:937/1705 train_time:88839ms step_avg:94.81ms +step:938/1705 train_time:88934ms step_avg:94.81ms +step:939/1705 train_time:89029ms step_avg:94.81ms +step:940/1705 train_time:89122ms step_avg:94.81ms +step:941/1705 train_time:89217ms step_avg:94.81ms +step:942/1705 train_time:89312ms step_avg:94.81ms +step:943/1705 train_time:89406ms step_avg:94.81ms +step:944/1705 train_time:89500ms step_avg:94.81ms +step:945/1705 train_time:89594ms step_avg:94.81ms +step:946/1705 train_time:89689ms step_avg:94.81ms +step:947/1705 train_time:89783ms step_avg:94.81ms +step:948/1705 train_time:89877ms step_avg:94.81ms +step:949/1705 train_time:89972ms step_avg:94.81ms +step:950/1705 train_time:90067ms step_avg:94.81ms +step:951/1705 train_time:90161ms step_avg:94.81ms +step:952/1705 train_time:90255ms step_avg:94.81ms +step:953/1705 train_time:90350ms step_avg:94.81ms +step:954/1705 train_time:90444ms step_avg:94.81ms +step:955/1705 train_time:90538ms step_avg:94.80ms +step:956/1705 train_time:90633ms step_avg:94.80ms +step:957/1705 train_time:90728ms step_avg:94.80ms +step:958/1705 train_time:90822ms step_avg:94.80ms +step:959/1705 train_time:90917ms step_avg:94.80ms +step:960/1705 train_time:91011ms step_avg:94.80ms +step:961/1705 train_time:91105ms step_avg:94.80ms +step:962/1705 train_time:91200ms step_avg:94.80ms +step:963/1705 train_time:91295ms step_avg:94.80ms +step:964/1705 train_time:91390ms step_avg:94.80ms +step:965/1705 train_time:91484ms step_avg:94.80ms +step:966/1705 train_time:91578ms step_avg:94.80ms +step:967/1705 train_time:91674ms step_avg:94.80ms +step:968/1705 train_time:91769ms step_avg:94.80ms +step:969/1705 train_time:91864ms step_avg:94.80ms +step:970/1705 train_time:91958ms step_avg:94.80ms +step:971/1705 train_time:92053ms step_avg:94.80ms +step:972/1705 train_time:92147ms step_avg:94.80ms +step:973/1705 train_time:92240ms step_avg:94.80ms +step:974/1705 train_time:92335ms 
step_avg:94.80ms +step:975/1705 train_time:92428ms step_avg:94.80ms +step:976/1705 train_time:92522ms step_avg:94.80ms +step:977/1705 train_time:92617ms step_avg:94.80ms +step:978/1705 train_time:92711ms step_avg:94.80ms +step:979/1705 train_time:92806ms step_avg:94.80ms +step:980/1705 train_time:92900ms step_avg:94.80ms +step:981/1705 train_time:92995ms step_avg:94.80ms +step:982/1705 train_time:93091ms step_avg:94.80ms +step:983/1705 train_time:93185ms step_avg:94.80ms +step:984/1705 train_time:93279ms step_avg:94.80ms +step:985/1705 train_time:93373ms step_avg:94.80ms +step:986/1705 train_time:93468ms step_avg:94.79ms +step:987/1705 train_time:93561ms step_avg:94.79ms +step:988/1705 train_time:93655ms step_avg:94.79ms +step:989/1705 train_time:93750ms step_avg:94.79ms +step:990/1705 train_time:93844ms step_avg:94.79ms +step:991/1705 train_time:93938ms step_avg:94.79ms +step:992/1705 train_time:94033ms step_avg:94.79ms +step:993/1705 train_time:94127ms step_avg:94.79ms +step:994/1705 train_time:94221ms step_avg:94.79ms +step:995/1705 train_time:94315ms step_avg:94.79ms +step:996/1705 train_time:94410ms step_avg:94.79ms +step:997/1705 train_time:94504ms step_avg:94.79ms +step:998/1705 train_time:94598ms step_avg:94.79ms +step:999/1705 train_time:94694ms step_avg:94.79ms +step:1000/1705 train_time:94789ms step_avg:94.79ms +step:1000/1705 val_loss:3.4826 train_time:94884ms step_avg:94.88ms +step:1001/1705 train_time:94905ms step_avg:94.81ms +step:1002/1705 train_time:94983ms step_avg:94.79ms +step:1003/1705 train_time:95080ms step_avg:94.80ms +step:1004/1705 train_time:95174ms step_avg:94.79ms +step:1005/1705 train_time:95268ms step_avg:94.79ms +step:1006/1705 train_time:95363ms step_avg:94.79ms +step:1007/1705 train_time:95456ms step_avg:94.79ms +step:1008/1705 train_time:95550ms step_avg:94.79ms +step:1009/1705 train_time:95643ms step_avg:94.79ms +step:1010/1705 train_time:95737ms step_avg:94.79ms +step:1011/1705 train_time:95833ms step_avg:94.79ms +step:1012/1705 train_time:95931ms step_avg:94.79ms +step:1013/1705 train_time:96029ms step_avg:94.80ms +step:1014/1705 train_time:96124ms step_avg:94.80ms +step:1015/1705 train_time:96217ms step_avg:94.80ms +step:1016/1705 train_time:96312ms step_avg:94.80ms +step:1017/1705 train_time:96406ms step_avg:94.79ms +step:1018/1705 train_time:96499ms step_avg:94.79ms +step:1019/1705 train_time:96593ms step_avg:94.79ms +step:1020/1705 train_time:96687ms step_avg:94.79ms +step:1021/1705 train_time:96781ms step_avg:94.79ms +step:1022/1705 train_time:96876ms step_avg:94.79ms +step:1023/1705 train_time:96972ms step_avg:94.79ms +step:1024/1705 train_time:97069ms step_avg:94.79ms +step:1025/1705 train_time:97164ms step_avg:94.79ms +step:1026/1705 train_time:97259ms step_avg:94.79ms +step:1027/1705 train_time:97353ms step_avg:94.79ms +step:1028/1705 train_time:97448ms step_avg:94.79ms +step:1029/1705 train_time:97542ms step_avg:94.79ms +step:1030/1705 train_time:97636ms step_avg:94.79ms +step:1031/1705 train_time:97730ms step_avg:94.79ms +step:1032/1705 train_time:97826ms step_avg:94.79ms +step:1033/1705 train_time:97920ms step_avg:94.79ms +step:1034/1705 train_time:98015ms step_avg:94.79ms +step:1035/1705 train_time:98111ms step_avg:94.79ms +step:1036/1705 train_time:98207ms step_avg:94.79ms +step:1037/1705 train_time:98301ms step_avg:94.79ms +step:1038/1705 train_time:98395ms step_avg:94.79ms +step:1039/1705 train_time:98489ms step_avg:94.79ms +step:1040/1705 train_time:98583ms step_avg:94.79ms +step:1041/1705 train_time:98677ms step_avg:94.79ms 
+step:1042/1705 train_time:98771ms step_avg:94.79ms +step:1043/1705 train_time:98866ms step_avg:94.79ms +step:1044/1705 train_time:98960ms step_avg:94.79ms +step:1045/1705 train_time:99055ms step_avg:94.79ms +step:1046/1705 train_time:99151ms step_avg:94.79ms +step:1047/1705 train_time:99246ms step_avg:94.79ms +step:1048/1705 train_time:99340ms step_avg:94.79ms +step:1049/1705 train_time:99434ms step_avg:94.79ms +step:1050/1705 train_time:99530ms step_avg:94.79ms +step:1051/1705 train_time:99624ms step_avg:94.79ms +step:1052/1705 train_time:99718ms step_avg:94.79ms +step:1053/1705 train_time:99812ms step_avg:94.79ms +step:1054/1705 train_time:99908ms step_avg:94.79ms +step:1055/1705 train_time:100002ms step_avg:94.79ms +step:1056/1705 train_time:100096ms step_avg:94.79ms +step:1057/1705 train_time:100192ms step_avg:94.79ms +step:1058/1705 train_time:100287ms step_avg:94.79ms +step:1059/1705 train_time:100381ms step_avg:94.79ms +step:1060/1705 train_time:100475ms step_avg:94.79ms +step:1061/1705 train_time:100569ms step_avg:94.79ms +step:1062/1705 train_time:100811ms step_avg:94.93ms +step:1063/1705 train_time:100990ms step_avg:95.00ms +step:1064/1705 train_time:101083ms step_avg:95.00ms +step:1065/1705 train_time:101176ms step_avg:95.00ms +step:1066/1705 train_time:101270ms step_avg:95.00ms +step:1067/1705 train_time:101363ms step_avg:95.00ms +step:1068/1705 train_time:101456ms step_avg:95.00ms +step:1069/1705 train_time:101550ms step_avg:95.00ms +step:1070/1705 train_time:101643ms step_avg:94.99ms +step:1071/1705 train_time:101736ms step_avg:94.99ms +step:1072/1705 train_time:101835ms step_avg:95.00ms +step:1073/1705 train_time:101934ms step_avg:95.00ms +step:1074/1705 train_time:102033ms step_avg:95.00ms +step:1075/1705 train_time:102130ms step_avg:95.00ms +step:1076/1705 train_time:102225ms step_avg:95.00ms +step:1077/1705 train_time:102318ms step_avg:95.00ms +step:1078/1705 train_time:102411ms step_avg:95.00ms +step:1079/1705 train_time:102504ms step_avg:95.00ms +step:1080/1705 train_time:102598ms step_avg:95.00ms +step:1081/1705 train_time:102691ms step_avg:95.00ms +step:1082/1705 train_time:102786ms step_avg:95.00ms +step:1083/1705 train_time:102881ms step_avg:95.00ms +step:1084/1705 train_time:102976ms step_avg:95.00ms +step:1085/1705 train_time:103071ms step_avg:95.00ms +step:1086/1705 train_time:103166ms step_avg:95.00ms +step:1087/1705 train_time:103260ms step_avg:95.00ms +step:1088/1705 train_time:103354ms step_avg:94.99ms +step:1089/1705 train_time:103447ms step_avg:94.99ms +step:1090/1705 train_time:103541ms step_avg:94.99ms +step:1091/1705 train_time:103635ms step_avg:94.99ms +step:1092/1705 train_time:103730ms step_avg:94.99ms +step:1093/1705 train_time:103825ms step_avg:94.99ms +step:1094/1705 train_time:103920ms step_avg:94.99ms +step:1095/1705 train_time:104015ms step_avg:94.99ms +step:1096/1705 train_time:104111ms step_avg:94.99ms +step:1097/1705 train_time:104207ms step_avg:94.99ms +step:1098/1705 train_time:104300ms step_avg:94.99ms +step:1099/1705 train_time:104395ms step_avg:94.99ms +step:1100/1705 train_time:104490ms step_avg:94.99ms +step:1101/1705 train_time:104584ms step_avg:94.99ms +step:1102/1705 train_time:104678ms step_avg:94.99ms +step:1103/1705 train_time:104772ms step_avg:94.99ms +step:1104/1705 train_time:104867ms step_avg:94.99ms +step:1105/1705 train_time:104961ms step_avg:94.99ms +step:1106/1705 train_time:105055ms step_avg:94.99ms +step:1107/1705 train_time:105151ms step_avg:94.99ms +step:1108/1705 train_time:105246ms step_avg:94.99ms +step:1109/1705 
train_time:105340ms step_avg:94.99ms +step:1110/1705 train_time:105434ms step_avg:94.99ms +step:1111/1705 train_time:105528ms step_avg:94.99ms +step:1112/1705 train_time:105624ms step_avg:94.99ms +step:1113/1705 train_time:105717ms step_avg:94.98ms +step:1114/1705 train_time:105813ms step_avg:94.98ms +step:1115/1705 train_time:105909ms step_avg:94.99ms +step:1116/1705 train_time:106004ms step_avg:94.99ms +step:1117/1705 train_time:106098ms step_avg:94.98ms +step:1118/1705 train_time:106194ms step_avg:94.99ms +step:1119/1705 train_time:106289ms step_avg:94.99ms +step:1120/1705 train_time:106382ms step_avg:94.98ms +step:1121/1705 train_time:106476ms step_avg:94.98ms +step:1122/1705 train_time:106571ms step_avg:94.98ms +step:1123/1705 train_time:106665ms step_avg:94.98ms +step:1124/1705 train_time:106758ms step_avg:94.98ms +step:1125/1705 train_time:106852ms step_avg:94.98ms +step:1125/1705 val_loss:3.4350 train_time:106948ms step_avg:95.07ms +step:1126/1705 train_time:106969ms step_avg:95.00ms +step:1127/1705 train_time:107048ms step_avg:94.98ms +step:1128/1705 train_time:107146ms step_avg:94.99ms +step:1129/1705 train_time:107241ms step_avg:94.99ms +step:1130/1705 train_time:107335ms step_avg:94.99ms +step:1131/1705 train_time:107429ms step_avg:94.99ms +step:1132/1705 train_time:107522ms step_avg:94.98ms +step:1133/1705 train_time:107615ms step_avg:94.98ms +step:1134/1705 train_time:107709ms step_avg:94.98ms +step:1135/1705 train_time:107802ms step_avg:94.98ms +step:1136/1705 train_time:107897ms step_avg:94.98ms +step:1137/1705 train_time:107993ms step_avg:94.98ms +step:1138/1705 train_time:108090ms step_avg:94.98ms +step:1139/1705 train_time:108187ms step_avg:94.98ms +step:1140/1705 train_time:108282ms step_avg:94.98ms +step:1141/1705 train_time:108377ms step_avg:94.98ms +step:1142/1705 train_time:108472ms step_avg:94.98ms +step:1143/1705 train_time:108567ms step_avg:94.98ms +step:1144/1705 train_time:108661ms step_avg:94.98ms +step:1145/1705 train_time:108756ms step_avg:94.98ms +step:1146/1705 train_time:108850ms step_avg:94.98ms +step:1147/1705 train_time:108945ms step_avg:94.98ms +step:1148/1705 train_time:109042ms step_avg:94.98ms +step:1149/1705 train_time:109137ms step_avg:94.98ms +step:1150/1705 train_time:109233ms step_avg:94.99ms +step:1151/1705 train_time:109329ms step_avg:94.99ms +step:1152/1705 train_time:109424ms step_avg:94.99ms +step:1153/1705 train_time:109518ms step_avg:94.99ms +step:1154/1705 train_time:109613ms step_avg:94.99ms +step:1155/1705 train_time:109708ms step_avg:94.99ms +step:1156/1705 train_time:109803ms step_avg:94.99ms +step:1157/1705 train_time:109898ms step_avg:94.99ms +step:1158/1705 train_time:109995ms step_avg:94.99ms +step:1159/1705 train_time:110092ms step_avg:94.99ms +step:1160/1705 train_time:110188ms step_avg:94.99ms +step:1161/1705 train_time:110283ms step_avg:94.99ms +step:1162/1705 train_time:110379ms step_avg:94.99ms +step:1163/1705 train_time:110474ms step_avg:94.99ms +step:1164/1705 train_time:110569ms step_avg:94.99ms +step:1165/1705 train_time:110664ms step_avg:94.99ms +step:1166/1705 train_time:110759ms step_avg:94.99ms +step:1167/1705 train_time:110854ms step_avg:94.99ms +step:1168/1705 train_time:110950ms step_avg:94.99ms +step:1169/1705 train_time:111045ms step_avg:94.99ms +step:1170/1705 train_time:111141ms step_avg:94.99ms +step:1171/1705 train_time:111236ms step_avg:94.99ms +step:1172/1705 train_time:111332ms step_avg:94.99ms +step:1173/1705 train_time:111429ms step_avg:94.99ms +step:1174/1705 train_time:111524ms step_avg:95.00ms 
+step:1175/1705 train_time:111619ms step_avg:94.99ms +step:1176/1705 train_time:111714ms step_avg:94.99ms +step:1177/1705 train_time:111808ms step_avg:94.99ms +step:1178/1705 train_time:111903ms step_avg:94.99ms +step:1179/1705 train_time:111999ms step_avg:95.00ms +step:1180/1705 train_time:112095ms step_avg:95.00ms +step:1181/1705 train_time:112191ms step_avg:95.00ms +step:1182/1705 train_time:112287ms step_avg:95.00ms +step:1183/1705 train_time:112382ms step_avg:95.00ms +step:1184/1705 train_time:112478ms step_avg:95.00ms +step:1185/1705 train_time:112573ms step_avg:95.00ms +step:1186/1705 train_time:112668ms step_avg:95.00ms +step:1187/1705 train_time:112762ms step_avg:95.00ms +step:1188/1705 train_time:112857ms step_avg:95.00ms +step:1189/1705 train_time:112953ms step_avg:95.00ms +step:1190/1705 train_time:113049ms step_avg:95.00ms +step:1191/1705 train_time:113144ms step_avg:95.00ms +step:1192/1705 train_time:113240ms step_avg:95.00ms +step:1193/1705 train_time:113335ms step_avg:95.00ms +step:1194/1705 train_time:113432ms step_avg:95.00ms +step:1195/1705 train_time:113528ms step_avg:95.00ms +step:1196/1705 train_time:113623ms step_avg:95.00ms +step:1197/1705 train_time:113718ms step_avg:95.00ms +step:1198/1705 train_time:113813ms step_avg:95.00ms +step:1199/1705 train_time:113909ms step_avg:95.00ms +step:1200/1705 train_time:114005ms step_avg:95.00ms +step:1201/1705 train_time:114100ms step_avg:95.00ms +step:1202/1705 train_time:114196ms step_avg:95.00ms +step:1203/1705 train_time:114291ms step_avg:95.01ms +step:1204/1705 train_time:114388ms step_avg:95.01ms +step:1205/1705 train_time:114484ms step_avg:95.01ms +step:1206/1705 train_time:114578ms step_avg:95.01ms +step:1207/1705 train_time:114673ms step_avg:95.01ms +step:1208/1705 train_time:114769ms step_avg:95.01ms +step:1209/1705 train_time:114864ms step_avg:95.01ms +step:1210/1705 train_time:114959ms step_avg:95.01ms +step:1211/1705 train_time:115054ms step_avg:95.01ms +step:1212/1705 train_time:115151ms step_avg:95.01ms +step:1213/1705 train_time:115247ms step_avg:95.01ms +step:1214/1705 train_time:115342ms step_avg:95.01ms +step:1215/1705 train_time:115438ms step_avg:95.01ms +step:1216/1705 train_time:115533ms step_avg:95.01ms +step:1217/1705 train_time:115630ms step_avg:95.01ms +step:1218/1705 train_time:115725ms step_avg:95.01ms +step:1219/1705 train_time:115820ms step_avg:95.01ms +step:1220/1705 train_time:115915ms step_avg:95.01ms +step:1221/1705 train_time:116011ms step_avg:95.01ms +step:1222/1705 train_time:116106ms step_avg:95.01ms +step:1223/1705 train_time:116201ms step_avg:95.01ms +step:1224/1705 train_time:116296ms step_avg:95.01ms +step:1225/1705 train_time:116392ms step_avg:95.01ms +step:1226/1705 train_time:116488ms step_avg:95.01ms +step:1227/1705 train_time:116582ms step_avg:95.01ms +step:1228/1705 train_time:116677ms step_avg:95.01ms +step:1229/1705 train_time:116772ms step_avg:95.01ms +step:1230/1705 train_time:116867ms step_avg:95.01ms +step:1231/1705 train_time:116962ms step_avg:95.01ms +step:1232/1705 train_time:117057ms step_avg:95.01ms +step:1233/1705 train_time:117153ms step_avg:95.01ms +step:1234/1705 train_time:117249ms step_avg:95.02ms +step:1235/1705 train_time:117344ms step_avg:95.02ms +step:1236/1705 train_time:117439ms step_avg:95.02ms +step:1237/1705 train_time:117534ms step_avg:95.02ms +step:1238/1705 train_time:117631ms step_avg:95.02ms +step:1239/1705 train_time:117726ms step_avg:95.02ms +step:1240/1705 train_time:117822ms step_avg:95.02ms +step:1241/1705 train_time:117916ms step_avg:95.02ms 
+step:1242/1705 train_time:118011ms step_avg:95.02ms +step:1243/1705 train_time:118107ms step_avg:95.02ms +step:1244/1705 train_time:118202ms step_avg:95.02ms +step:1245/1705 train_time:118297ms step_avg:95.02ms +step:1246/1705 train_time:118393ms step_avg:95.02ms +step:1247/1705 train_time:118488ms step_avg:95.02ms +step:1248/1705 train_time:118583ms step_avg:95.02ms +step:1249/1705 train_time:118679ms step_avg:95.02ms +step:1250/1705 train_time:118774ms step_avg:95.02ms +step:1250/1705 val_loss:3.3868 train_time:118871ms step_avg:95.10ms +step:1251/1705 train_time:118891ms step_avg:95.04ms +step:1252/1705 train_time:118979ms step_avg:95.03ms +step:1253/1705 train_time:119076ms step_avg:95.03ms +step:1254/1705 train_time:119170ms step_avg:95.03ms +step:1255/1705 train_time:119264ms step_avg:95.03ms +step:1256/1705 train_time:119358ms step_avg:95.03ms +step:1257/1705 train_time:119452ms step_avg:95.03ms +step:1258/1705 train_time:119547ms step_avg:95.03ms +step:1259/1705 train_time:119642ms step_avg:95.03ms +step:1260/1705 train_time:119735ms step_avg:95.03ms +step:1261/1705 train_time:119831ms step_avg:95.03ms +step:1262/1705 train_time:119930ms step_avg:95.03ms +step:1263/1705 train_time:120028ms step_avg:95.03ms +step:1264/1705 train_time:120123ms step_avg:95.03ms +step:1265/1705 train_time:120219ms step_avg:95.03ms +step:1266/1705 train_time:120313ms step_avg:95.03ms +step:1267/1705 train_time:120407ms step_avg:95.03ms +step:1268/1705 train_time:120501ms step_avg:95.03ms +step:1269/1705 train_time:120596ms step_avg:95.03ms +step:1270/1705 train_time:120690ms step_avg:95.03ms +step:1271/1705 train_time:120785ms step_avg:95.03ms +step:1272/1705 train_time:120882ms step_avg:95.03ms +step:1273/1705 train_time:120979ms step_avg:95.03ms +step:1274/1705 train_time:121347ms step_avg:95.25ms +step:1275/1705 train_time:121432ms step_avg:95.24ms +step:1276/1705 train_time:121526ms step_avg:95.24ms +step:1277/1705 train_time:121621ms step_avg:95.24ms +step:1278/1705 train_time:121715ms step_avg:95.24ms +step:1279/1705 train_time:121808ms step_avg:95.24ms +step:1280/1705 train_time:121902ms step_avg:95.24ms +step:1281/1705 train_time:121996ms step_avg:95.23ms +step:1282/1705 train_time:122090ms step_avg:95.23ms +step:1283/1705 train_time:122184ms step_avg:95.23ms +step:1284/1705 train_time:122288ms step_avg:95.24ms +step:1285/1705 train_time:122386ms step_avg:95.24ms +step:1286/1705 train_time:122482ms step_avg:95.24ms +step:1287/1705 train_time:122578ms step_avg:95.24ms +step:1288/1705 train_time:122672ms step_avg:95.24ms +step:1289/1705 train_time:122766ms step_avg:95.24ms +step:1290/1705 train_time:122861ms step_avg:95.24ms +step:1291/1705 train_time:122955ms step_avg:95.24ms +step:1292/1705 train_time:123049ms step_avg:95.24ms +step:1293/1705 train_time:123144ms step_avg:95.24ms +step:1294/1705 train_time:123241ms step_avg:95.24ms +step:1295/1705 train_time:123338ms step_avg:95.24ms +step:1296/1705 train_time:123436ms step_avg:95.24ms +step:1297/1705 train_time:123530ms step_avg:95.24ms +step:1298/1705 train_time:123626ms step_avg:95.24ms +step:1299/1705 train_time:123721ms step_avg:95.24ms +step:1300/1705 train_time:123815ms step_avg:95.24ms +step:1301/1705 train_time:123910ms step_avg:95.24ms +step:1302/1705 train_time:124005ms step_avg:95.24ms +step:1303/1705 train_time:124099ms step_avg:95.24ms +step:1304/1705 train_time:124194ms step_avg:95.24ms +step:1305/1705 train_time:124290ms step_avg:95.24ms +step:1306/1705 train_time:124386ms step_avg:95.24ms +step:1307/1705 train_time:124483ms 
step_avg:95.24ms +step:1308/1705 train_time:124579ms step_avg:95.24ms +step:1309/1705 train_time:124674ms step_avg:95.24ms +step:1310/1705 train_time:124768ms step_avg:95.24ms +step:1311/1705 train_time:124864ms step_avg:95.24ms +step:1312/1705 train_time:124958ms step_avg:95.24ms +step:1313/1705 train_time:125053ms step_avg:95.24ms +step:1314/1705 train_time:125148ms step_avg:95.24ms +step:1315/1705 train_time:125244ms step_avg:95.24ms +step:1316/1705 train_time:125340ms step_avg:95.24ms +step:1317/1705 train_time:125436ms step_avg:95.24ms +step:1318/1705 train_time:125531ms step_avg:95.24ms +step:1319/1705 train_time:125626ms step_avg:95.24ms +step:1320/1705 train_time:125721ms step_avg:95.24ms +step:1321/1705 train_time:125816ms step_avg:95.24ms +step:1322/1705 train_time:125910ms step_avg:95.24ms +step:1323/1705 train_time:126006ms step_avg:95.24ms +step:1324/1705 train_time:126101ms step_avg:95.24ms +step:1325/1705 train_time:126196ms step_avg:95.24ms +step:1326/1705 train_time:126291ms step_avg:95.24ms +step:1327/1705 train_time:126386ms step_avg:95.24ms +step:1328/1705 train_time:126483ms step_avg:95.24ms +step:1329/1705 train_time:126579ms step_avg:95.24ms +step:1330/1705 train_time:126674ms step_avg:95.24ms +step:1331/1705 train_time:126769ms step_avg:95.24ms +step:1332/1705 train_time:126865ms step_avg:95.24ms +step:1333/1705 train_time:126961ms step_avg:95.24ms +step:1334/1705 train_time:127056ms step_avg:95.24ms +step:1335/1705 train_time:127151ms step_avg:95.24ms +step:1336/1705 train_time:127246ms step_avg:95.24ms +step:1337/1705 train_time:127342ms step_avg:95.24ms +step:1338/1705 train_time:127437ms step_avg:95.24ms +step:1339/1705 train_time:127532ms step_avg:95.24ms +step:1340/1705 train_time:127628ms step_avg:95.24ms +step:1341/1705 train_time:127722ms step_avg:95.24ms +step:1342/1705 train_time:127818ms step_avg:95.24ms +step:1343/1705 train_time:127913ms step_avg:95.24ms +step:1344/1705 train_time:128007ms step_avg:95.24ms +step:1345/1705 train_time:128102ms step_avg:95.24ms +step:1346/1705 train_time:128198ms step_avg:95.24ms +step:1347/1705 train_time:128293ms step_avg:95.24ms +step:1348/1705 train_time:128388ms step_avg:95.24ms +step:1349/1705 train_time:128485ms step_avg:95.24ms +step:1350/1705 train_time:128582ms step_avg:95.25ms +step:1351/1705 train_time:128677ms step_avg:95.25ms +step:1352/1705 train_time:128772ms step_avg:95.25ms +step:1353/1705 train_time:128867ms step_avg:95.25ms +step:1354/1705 train_time:128963ms step_avg:95.25ms +step:1355/1705 train_time:129058ms step_avg:95.25ms +step:1356/1705 train_time:129152ms step_avg:95.25ms +step:1357/1705 train_time:129249ms step_avg:95.25ms +step:1358/1705 train_time:129343ms step_avg:95.25ms +step:1359/1705 train_time:129438ms step_avg:95.25ms +step:1360/1705 train_time:129533ms step_avg:95.24ms +step:1361/1705 train_time:129628ms step_avg:95.24ms +step:1362/1705 train_time:129723ms step_avg:95.24ms +step:1363/1705 train_time:129819ms step_avg:95.24ms +step:1364/1705 train_time:129914ms step_avg:95.25ms +step:1365/1705 train_time:130009ms step_avg:95.24ms +step:1366/1705 train_time:130104ms step_avg:95.24ms +step:1367/1705 train_time:130200ms step_avg:95.25ms +step:1368/1705 train_time:130295ms step_avg:95.24ms +step:1369/1705 train_time:130389ms step_avg:95.24ms +step:1370/1705 train_time:130484ms step_avg:95.24ms +step:1371/1705 train_time:130581ms step_avg:95.25ms +step:1372/1705 train_time:130677ms step_avg:95.25ms +step:1373/1705 train_time:130772ms step_avg:95.25ms +step:1374/1705 train_time:130867ms 
step_avg:95.25ms +step:1375/1705 train_time:130964ms step_avg:95.25ms +step:1375/1705 val_loss:3.3495 train_time:131060ms step_avg:95.32ms +step:1376/1705 train_time:131081ms step_avg:95.26ms +step:1377/1705 train_time:131162ms step_avg:95.25ms +step:1378/1705 train_time:131259ms step_avg:95.25ms +step:1379/1705 train_time:131355ms step_avg:95.25ms +step:1380/1705 train_time:131450ms step_avg:95.25ms +step:1381/1705 train_time:131544ms step_avg:95.25ms +step:1382/1705 train_time:131638ms step_avg:95.25ms +step:1383/1705 train_time:131733ms step_avg:95.25ms +step:1384/1705 train_time:131827ms step_avg:95.25ms +step:1385/1705 train_time:131922ms step_avg:95.25ms +step:1386/1705 train_time:132018ms step_avg:95.25ms +step:1387/1705 train_time:132117ms step_avg:95.25ms +step:1388/1705 train_time:132215ms step_avg:95.26ms +step:1389/1705 train_time:132313ms step_avg:95.26ms +step:1390/1705 train_time:132409ms step_avg:95.26ms +step:1391/1705 train_time:132504ms step_avg:95.26ms +step:1392/1705 train_time:132598ms step_avg:95.26ms +step:1393/1705 train_time:132692ms step_avg:95.26ms +step:1394/1705 train_time:132787ms step_avg:95.26ms +step:1395/1705 train_time:132881ms step_avg:95.26ms +step:1396/1705 train_time:132976ms step_avg:95.26ms +step:1397/1705 train_time:133073ms step_avg:95.26ms +step:1398/1705 train_time:133170ms step_avg:95.26ms +step:1399/1705 train_time:133268ms step_avg:95.26ms +step:1400/1705 train_time:133364ms step_avg:95.26ms +step:1401/1705 train_time:133458ms step_avg:95.26ms +step:1402/1705 train_time:133554ms step_avg:95.26ms +step:1403/1705 train_time:133649ms step_avg:95.26ms +step:1404/1705 train_time:133744ms step_avg:95.26ms +step:1405/1705 train_time:133838ms step_avg:95.26ms +step:1406/1705 train_time:133933ms step_avg:95.26ms +step:1407/1705 train_time:134028ms step_avg:95.26ms +step:1408/1705 train_time:134123ms step_avg:95.26ms +step:1409/1705 train_time:134221ms step_avg:95.26ms +step:1410/1705 train_time:134318ms step_avg:95.26ms +step:1411/1705 train_time:134414ms step_avg:95.26ms +step:1412/1705 train_time:134508ms step_avg:95.26ms +step:1413/1705 train_time:134602ms step_avg:95.26ms +step:1414/1705 train_time:134697ms step_avg:95.26ms +step:1415/1705 train_time:134793ms step_avg:95.26ms +step:1416/1705 train_time:134888ms step_avg:95.26ms +step:1417/1705 train_time:134983ms step_avg:95.26ms +step:1418/1705 train_time:135078ms step_avg:95.26ms +step:1419/1705 train_time:135175ms step_avg:95.26ms +step:1420/1705 train_time:135271ms step_avg:95.26ms +step:1421/1705 train_time:135366ms step_avg:95.26ms +step:1422/1705 train_time:135461ms step_avg:95.26ms +step:1423/1705 train_time:135556ms step_avg:95.26ms +step:1424/1705 train_time:135650ms step_avg:95.26ms +step:1425/1705 train_time:135746ms step_avg:95.26ms +step:1426/1705 train_time:135840ms step_avg:95.26ms +step:1427/1705 train_time:135935ms step_avg:95.26ms +step:1428/1705 train_time:136030ms step_avg:95.26ms +step:1429/1705 train_time:136126ms step_avg:95.26ms +step:1430/1705 train_time:136221ms step_avg:95.26ms +step:1431/1705 train_time:136317ms step_avg:95.26ms +step:1432/1705 train_time:136413ms step_avg:95.26ms +step:1433/1705 train_time:136509ms step_avg:95.26ms +step:1434/1705 train_time:136603ms step_avg:95.26ms +step:1435/1705 train_time:136699ms step_avg:95.26ms +step:1436/1705 train_time:136794ms step_avg:95.26ms +step:1437/1705 train_time:136889ms step_avg:95.26ms +step:1438/1705 train_time:136983ms step_avg:95.26ms +step:1439/1705 train_time:137079ms step_avg:95.26ms +step:1440/1705 
train_time:137175ms step_avg:95.26ms +step:1441/1705 train_time:137272ms step_avg:95.26ms +step:1442/1705 train_time:137367ms step_avg:95.26ms +step:1443/1705 train_time:137462ms step_avg:95.26ms +step:1444/1705 train_time:137557ms step_avg:95.26ms +step:1445/1705 train_time:137652ms step_avg:95.26ms +step:1446/1705 train_time:137747ms step_avg:95.26ms +step:1447/1705 train_time:137842ms step_avg:95.26ms +step:1448/1705 train_time:137937ms step_avg:95.26ms +step:1449/1705 train_time:138033ms step_avg:95.26ms +step:1450/1705 train_time:138129ms step_avg:95.26ms +step:1451/1705 train_time:138224ms step_avg:95.26ms +step:1452/1705 train_time:138319ms step_avg:95.26ms +step:1453/1705 train_time:138416ms step_avg:95.26ms +step:1454/1705 train_time:138512ms step_avg:95.26ms +step:1455/1705 train_time:138607ms step_avg:95.26ms +step:1456/1705 train_time:138701ms step_avg:95.26ms +step:1457/1705 train_time:138797ms step_avg:95.26ms +step:1458/1705 train_time:138892ms step_avg:95.26ms +step:1459/1705 train_time:138987ms step_avg:95.26ms +step:1460/1705 train_time:139082ms step_avg:95.26ms +step:1461/1705 train_time:139178ms step_avg:95.26ms +step:1462/1705 train_time:139274ms step_avg:95.26ms +step:1463/1705 train_time:139371ms step_avg:95.26ms +step:1464/1705 train_time:139466ms step_avg:95.26ms +step:1465/1705 train_time:139561ms step_avg:95.26ms +step:1466/1705 train_time:139656ms step_avg:95.26ms +step:1467/1705 train_time:139752ms step_avg:95.26ms +step:1468/1705 train_time:139848ms step_avg:95.26ms +step:1469/1705 train_time:139943ms step_avg:95.26ms +step:1470/1705 train_time:140040ms step_avg:95.27ms +step:1471/1705 train_time:140133ms step_avg:95.26ms +step:1472/1705 train_time:140228ms step_avg:95.26ms +step:1473/1705 train_time:140324ms step_avg:95.26ms +step:1474/1705 train_time:140420ms step_avg:95.26ms +step:1475/1705 train_time:140516ms step_avg:95.26ms +step:1476/1705 train_time:140611ms step_avg:95.26ms +step:1477/1705 train_time:140706ms step_avg:95.26ms +step:1478/1705 train_time:140801ms step_avg:95.26ms +step:1479/1705 train_time:140897ms step_avg:95.26ms +step:1480/1705 train_time:140993ms step_avg:95.27ms +step:1481/1705 train_time:141090ms step_avg:95.27ms +step:1482/1705 train_time:141185ms step_avg:95.27ms +step:1483/1705 train_time:141279ms step_avg:95.27ms +step:1484/1705 train_time:141376ms step_avg:95.27ms +step:1485/1705 train_time:141646ms step_avg:95.38ms +step:1486/1705 train_time:141838ms step_avg:95.45ms +step:1487/1705 train_time:141931ms step_avg:95.45ms +step:1488/1705 train_time:142025ms step_avg:95.45ms +step:1489/1705 train_time:142118ms step_avg:95.45ms +step:1490/1705 train_time:142213ms step_avg:95.45ms +step:1491/1705 train_time:142307ms step_avg:95.44ms +step:1492/1705 train_time:142401ms step_avg:95.44ms +step:1493/1705 train_time:142495ms step_avg:95.44ms +step:1494/1705 train_time:142590ms step_avg:95.44ms +step:1495/1705 train_time:142688ms step_avg:95.44ms +step:1496/1705 train_time:142787ms step_avg:95.45ms +step:1497/1705 train_time:142885ms step_avg:95.45ms +step:1498/1705 train_time:142980ms step_avg:95.45ms +step:1499/1705 train_time:143075ms step_avg:95.45ms +step:1500/1705 train_time:143169ms step_avg:95.45ms +step:1500/1705 val_loss:3.3171 train_time:143265ms step_avg:95.51ms +step:1501/1705 train_time:143286ms step_avg:95.46ms +step:1502/1705 train_time:143367ms step_avg:95.45ms +step:1503/1705 train_time:143467ms step_avg:95.45ms +step:1504/1705 train_time:143562ms step_avg:95.45ms +step:1505/1705 train_time:143656ms step_avg:95.45ms 
+step:1506/1705 train_time:143750ms step_avg:95.45ms +step:1507/1705 train_time:143845ms step_avg:95.45ms +step:1508/1705 train_time:143940ms step_avg:95.45ms +step:1509/1705 train_time:144033ms step_avg:95.45ms +step:1510/1705 train_time:144127ms step_avg:95.45ms +step:1511/1705 train_time:144225ms step_avg:95.45ms +step:1512/1705 train_time:144323ms step_avg:95.45ms +step:1513/1705 train_time:144420ms step_avg:95.45ms +step:1514/1705 train_time:144516ms step_avg:95.45ms +step:1515/1705 train_time:144611ms step_avg:95.45ms +step:1516/1705 train_time:144706ms step_avg:95.45ms +step:1517/1705 train_time:144800ms step_avg:95.45ms +step:1518/1705 train_time:144894ms step_avg:95.45ms +step:1519/1705 train_time:144989ms step_avg:95.45ms +step:1520/1705 train_time:145084ms step_avg:95.45ms +step:1521/1705 train_time:145179ms step_avg:95.45ms +step:1522/1705 train_time:145275ms step_avg:95.45ms +step:1523/1705 train_time:145371ms step_avg:95.45ms +step:1524/1705 train_time:145467ms step_avg:95.45ms +step:1525/1705 train_time:145565ms step_avg:95.45ms +step:1526/1705 train_time:145661ms step_avg:95.45ms +step:1527/1705 train_time:145757ms step_avg:95.45ms +step:1528/1705 train_time:145851ms step_avg:95.45ms +step:1529/1705 train_time:145946ms step_avg:95.45ms +step:1530/1705 train_time:146041ms step_avg:95.45ms +step:1531/1705 train_time:146136ms step_avg:95.45ms +step:1532/1705 train_time:146230ms step_avg:95.45ms +step:1533/1705 train_time:146326ms step_avg:95.45ms +step:1534/1705 train_time:146423ms step_avg:95.45ms +step:1535/1705 train_time:146519ms step_avg:95.45ms +step:1536/1705 train_time:146614ms step_avg:95.45ms +step:1537/1705 train_time:146709ms step_avg:95.45ms +step:1538/1705 train_time:146805ms step_avg:95.45ms +step:1539/1705 train_time:146899ms step_avg:95.45ms +step:1540/1705 train_time:146993ms step_avg:95.45ms +step:1541/1705 train_time:147088ms step_avg:95.45ms +step:1542/1705 train_time:147184ms step_avg:95.45ms +step:1543/1705 train_time:147280ms step_avg:95.45ms +step:1544/1705 train_time:147376ms step_avg:95.45ms +step:1545/1705 train_time:147472ms step_avg:95.45ms +step:1546/1705 train_time:147568ms step_avg:95.45ms +step:1547/1705 train_time:147663ms step_avg:95.45ms +step:1548/1705 train_time:147761ms step_avg:95.45ms +step:1549/1705 train_time:147857ms step_avg:95.45ms +step:1550/1705 train_time:147951ms step_avg:95.45ms +step:1551/1705 train_time:148047ms step_avg:95.45ms +step:1552/1705 train_time:148143ms step_avg:95.45ms +step:1553/1705 train_time:148238ms step_avg:95.45ms +step:1554/1705 train_time:148333ms step_avg:95.45ms +step:1555/1705 train_time:148428ms step_avg:95.45ms +step:1556/1705 train_time:148525ms step_avg:95.45ms +step:1557/1705 train_time:148621ms step_avg:95.45ms +step:1558/1705 train_time:148716ms step_avg:95.45ms +step:1559/1705 train_time:148811ms step_avg:95.45ms +step:1560/1705 train_time:148907ms step_avg:95.45ms +step:1561/1705 train_time:149003ms step_avg:95.45ms +step:1562/1705 train_time:149097ms step_avg:95.45ms +step:1563/1705 train_time:149193ms step_avg:95.45ms +step:1564/1705 train_time:149288ms step_avg:95.45ms +step:1565/1705 train_time:149385ms step_avg:95.45ms +step:1566/1705 train_time:149481ms step_avg:95.45ms +step:1567/1705 train_time:149576ms step_avg:95.45ms +step:1568/1705 train_time:149670ms step_avg:95.45ms +step:1569/1705 train_time:149766ms step_avg:95.45ms +step:1570/1705 train_time:149861ms step_avg:95.45ms +step:1571/1705 train_time:149956ms step_avg:95.45ms +step:1572/1705 train_time:150051ms step_avg:95.45ms 
+step:1573/1705 train_time:150147ms step_avg:95.45ms +step:1574/1705 train_time:150242ms step_avg:95.45ms +step:1575/1705 train_time:150337ms step_avg:95.45ms +step:1576/1705 train_time:150431ms step_avg:95.45ms +step:1577/1705 train_time:150527ms step_avg:95.45ms +step:1578/1705 train_time:150623ms step_avg:95.45ms +step:1579/1705 train_time:150717ms step_avg:95.45ms +step:1580/1705 train_time:150812ms step_avg:95.45ms +step:1581/1705 train_time:150908ms step_avg:95.45ms +step:1582/1705 train_time:151004ms step_avg:95.45ms +step:1583/1705 train_time:151101ms step_avg:95.45ms +step:1584/1705 train_time:151195ms step_avg:95.45ms +step:1585/1705 train_time:151290ms step_avg:95.45ms +step:1586/1705 train_time:151385ms step_avg:95.45ms +step:1587/1705 train_time:151481ms step_avg:95.45ms +step:1588/1705 train_time:151576ms step_avg:95.45ms +step:1589/1705 train_time:151671ms step_avg:95.45ms +step:1590/1705 train_time:151767ms step_avg:95.45ms +step:1591/1705 train_time:151863ms step_avg:95.45ms +step:1592/1705 train_time:151960ms step_avg:95.45ms +step:1593/1705 train_time:152055ms step_avg:95.45ms +step:1594/1705 train_time:152151ms step_avg:95.45ms +step:1595/1705 train_time:152247ms step_avg:95.45ms +step:1596/1705 train_time:152344ms step_avg:95.45ms +step:1597/1705 train_time:152439ms step_avg:95.45ms +step:1598/1705 train_time:152534ms step_avg:95.45ms +step:1599/1705 train_time:152629ms step_avg:95.45ms +step:1600/1705 train_time:152725ms step_avg:95.45ms +step:1601/1705 train_time:152819ms step_avg:95.45ms +step:1602/1705 train_time:152914ms step_avg:95.45ms +step:1603/1705 train_time:153010ms step_avg:95.45ms +step:1604/1705 train_time:153105ms step_avg:95.45ms +step:1605/1705 train_time:153201ms step_avg:95.45ms +step:1606/1705 train_time:153297ms step_avg:95.45ms +step:1607/1705 train_time:153391ms step_avg:95.45ms +step:1608/1705 train_time:153487ms step_avg:95.45ms +step:1609/1705 train_time:153583ms step_avg:95.45ms +step:1610/1705 train_time:153678ms step_avg:95.45ms +step:1611/1705 train_time:153773ms step_avg:95.45ms +step:1612/1705 train_time:153869ms step_avg:95.45ms +step:1613/1705 train_time:153965ms step_avg:95.45ms +step:1614/1705 train_time:154061ms step_avg:95.45ms +step:1615/1705 train_time:154157ms step_avg:95.45ms +step:1616/1705 train_time:154252ms step_avg:95.45ms +step:1617/1705 train_time:154348ms step_avg:95.45ms +step:1618/1705 train_time:154443ms step_avg:95.45ms +step:1619/1705 train_time:154539ms step_avg:95.45ms +step:1620/1705 train_time:154633ms step_avg:95.45ms +step:1621/1705 train_time:154728ms step_avg:95.45ms +step:1622/1705 train_time:154824ms step_avg:95.45ms +step:1623/1705 train_time:154919ms step_avg:95.45ms +step:1624/1705 train_time:155014ms step_avg:95.45ms +step:1625/1705 train_time:155111ms step_avg:95.45ms +step:1625/1705 val_loss:3.2893 train_time:155207ms step_avg:95.51ms +step:1626/1705 train_time:155228ms step_avg:95.47ms +step:1627/1705 train_time:155309ms step_avg:95.46ms +step:1628/1705 train_time:155410ms step_avg:95.46ms +step:1629/1705 train_time:155504ms step_avg:95.46ms +step:1630/1705 train_time:155598ms step_avg:95.46ms +step:1631/1705 train_time:155693ms step_avg:95.46ms +step:1632/1705 train_time:155787ms step_avg:95.46ms +step:1633/1705 train_time:155881ms step_avg:95.46ms +step:1634/1705 train_time:155975ms step_avg:95.46ms +step:1635/1705 train_time:156070ms step_avg:95.46ms +step:1636/1705 train_time:156165ms step_avg:95.46ms +step:1637/1705 train_time:156263ms step_avg:95.46ms +step:1638/1705 train_time:156363ms 
step_avg:95.46ms +step:1639/1705 train_time:156459ms step_avg:95.46ms +step:1640/1705 train_time:156555ms step_avg:95.46ms +step:1641/1705 train_time:156650ms step_avg:95.46ms +step:1642/1705 train_time:156745ms step_avg:95.46ms +step:1643/1705 train_time:156839ms step_avg:95.46ms +step:1644/1705 train_time:156934ms step_avg:95.46ms +step:1645/1705 train_time:157028ms step_avg:95.46ms +step:1646/1705 train_time:157123ms step_avg:95.46ms +step:1647/1705 train_time:157219ms step_avg:95.46ms +step:1648/1705 train_time:157317ms step_avg:95.46ms +step:1649/1705 train_time:157414ms step_avg:95.46ms +step:1650/1705 train_time:157510ms step_avg:95.46ms +step:1651/1705 train_time:157604ms step_avg:95.46ms +step:1652/1705 train_time:157699ms step_avg:95.46ms +step:1653/1705 train_time:157793ms step_avg:95.46ms +step:1654/1705 train_time:157887ms step_avg:95.46ms +step:1655/1705 train_time:157982ms step_avg:95.46ms +step:1656/1705 train_time:158077ms step_avg:95.46ms +step:1657/1705 train_time:158172ms step_avg:95.46ms +step:1658/1705 train_time:158268ms step_avg:95.46ms +step:1659/1705 train_time:158363ms step_avg:95.46ms +step:1660/1705 train_time:158461ms step_avg:95.46ms +step:1661/1705 train_time:158556ms step_avg:95.46ms +step:1662/1705 train_time:158652ms step_avg:95.46ms +step:1663/1705 train_time:158746ms step_avg:95.46ms +step:1664/1705 train_time:158841ms step_avg:95.46ms +step:1665/1705 train_time:158937ms step_avg:95.46ms +step:1666/1705 train_time:159032ms step_avg:95.46ms +step:1667/1705 train_time:159127ms step_avg:95.46ms +step:1668/1705 train_time:159222ms step_avg:95.46ms +step:1669/1705 train_time:159317ms step_avg:95.46ms +step:1670/1705 train_time:159413ms step_avg:95.46ms +step:1671/1705 train_time:159509ms step_avg:95.46ms +step:1672/1705 train_time:159604ms step_avg:95.46ms +step:1673/1705 train_time:159699ms step_avg:95.46ms +step:1674/1705 train_time:159794ms step_avg:95.46ms +step:1675/1705 train_time:159889ms step_avg:95.46ms +step:1676/1705 train_time:159983ms step_avg:95.46ms +step:1677/1705 train_time:160078ms step_avg:95.46ms +step:1678/1705 train_time:160173ms step_avg:95.45ms +step:1679/1705 train_time:160270ms step_avg:95.46ms +step:1680/1705 train_time:160366ms step_avg:95.46ms +step:1681/1705 train_time:160461ms step_avg:95.46ms +step:1682/1705 train_time:160557ms step_avg:95.46ms +step:1683/1705 train_time:160654ms step_avg:95.46ms +step:1684/1705 train_time:160749ms step_avg:95.46ms +step:1685/1705 train_time:160844ms step_avg:95.46ms +step:1686/1705 train_time:160939ms step_avg:95.46ms +step:1687/1705 train_time:161034ms step_avg:95.46ms +step:1688/1705 train_time:161129ms step_avg:95.46ms +step:1689/1705 train_time:161224ms step_avg:95.46ms +step:1690/1705 train_time:161319ms step_avg:95.46ms +step:1691/1705 train_time:161415ms step_avg:95.46ms +step:1692/1705 train_time:161512ms step_avg:95.46ms +step:1693/1705 train_time:161606ms step_avg:95.46ms +step:1694/1705 train_time:161701ms step_avg:95.46ms +step:1695/1705 train_time:161797ms step_avg:95.46ms +step:1696/1705 train_time:161893ms step_avg:95.46ms +step:1697/1705 train_time:161988ms step_avg:95.46ms +step:1698/1705 train_time:162279ms step_avg:95.57ms +step:1699/1705 train_time:162442ms step_avg:95.61ms +step:1700/1705 train_time:162536ms step_avg:95.61ms +step:1701/1705 train_time:162630ms step_avg:95.61ms +step:1702/1705 train_time:162723ms step_avg:95.61ms +step:1703/1705 train_time:162817ms step_avg:95.61ms +step:1704/1705 train_time:162912ms step_avg:95.61ms +step:1705/1705 train_time:163006ms 
step_avg:95.60ms +step:1705/1705 val_loss:3.2754 train_time:163100ms step_avg:95.66ms +peak memory allocated: 33750 MiB reserved: 48696 MiB diff --git a/records/050925_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt b/records/050925_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt new file mode 100644 index 000000000..e2f905a6f --- /dev/null +++ b/records/050925_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, 
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
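+
+    Concretely, each of the 5 Newton-Schulz iterations used here (see
+    newton_schulz_triton above) computes the quintic map
+        A = X @ X.T;   X <- a*X + (b*A + c*A @ A) @ X
+    with (a, b, c) = (3.4445, -4.7750, 2.0315), driving the singular values of the
+    spectrally-normalized update toward ~1 without an explicit SVD.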
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
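+        # worked example of the padding above: next_multiple_of_n(50257, n=128) == 50304,
+        # i.e. 47 extra rows that real GPT-2 tokens never index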
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # report idx rather than a token position: cur is unbound on the first pass through this loop
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last sequence carried one extra token for the _targets shift; trim it from the length bookkeeping
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1705 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
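+
+# Quick arithmetic check of the batch bookkeeping above (illustrative; the
+# constants mirror the defaults in Hyperparameters). Because grad_accum_steps
+# scales inversely with world_size, each rank sees the same number of tokens
+# per micro-batch no matter how many GPUs are used:
+#   8 GPUs: 2048*24*8 // (8//8) // 8 == 49152
+#   1 GPU:  2048*24*8 // (8//1) // 1 == 49152
+assert args.train_batch_size // grad_accum_steps // world_size == 49152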
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
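+
+# Spot-checks of the schedules defined above (illustrative, written for the
+# default hyperparameters): get_lr holds 1.0x for the first 55% of training
+# (1 - cooldown_frac) and then decays linearly toward 0.1x, while get_ws steps
+# through ws_schedule = (3, 7, 11) in three equal phases of training. The last
+# assert verifies, at an arbitrary test point, the softcap identity
+# 2*sigmoid(2*x) = tanh(x) + 1 used on the logits in GPT.forward:
+# 30*sigmoid(x/7.5) == 15*tanh(x/15) + 15.
+assert get_lr(0) == 1.0 and abs(get_lr(args.num_iterations - 1) - 0.1) < 1e-2
+assert (get_ws(0), get_ws(args.num_iterations // 2), get_ws(args.num_iterations - 1)) == (3, 7, 11)
+assert torch.allclose(30 * torch.sigmoid(torch.tensor(2.0) / 7.5), 15 * torch.tanh(torch.tensor(2.0) / 15) + 15)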
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:26:47 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 42C P0 129W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 33C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 89808 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 89809 C /usr/bin/python3 610MiB | +| 0 N/A N/A 89810 C /usr/bin/python3 610MiB | +| 0 N/A N/A 89811 C /usr/bin/python3 610MiB | +| 0 N/A N/A 89812 C /usr/bin/python3 610MiB | +| 0 N/A N/A 89813 C /usr/bin/python3 610MiB | +| 0 N/A N/A 89814 C /usr/bin/python3 610MiB | +| 0 N/A N/A 89815 C /usr/bin/python3 610MiB | +| 1 N/A N/A 89809 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 89810 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 89811 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 89812 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 89813 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 89814 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 89815 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1705 train_time:368ms step_avg:367.62ms +step:2/1705 train_time:389ms step_avg:194.68ms +step:3/1705 train_time:457ms step_avg:152.28ms +step:4/1705 train_time:548ms step_avg:137.00ms +step:5/1705 
train_time:640ms step_avg:128.06ms +step:6/1705 train_time:733ms step_avg:122.12ms +step:7/1705 train_time:825ms step_avg:117.86ms +step:8/1705 train_time:917ms step_avg:114.68ms +step:9/1705 train_time:1010ms step_avg:112.19ms +step:10/1705 train_time:1102ms step_avg:110.23ms +step:11/1705 train_time:1194ms step_avg:108.59ms +step:12/1705 train_time:1288ms step_avg:107.36ms +step:13/1705 train_time:1385ms step_avg:106.54ms +step:14/1705 train_time:1481ms step_avg:105.80ms +step:15/1705 train_time:1575ms step_avg:104.99ms +step:16/1705 train_time:1667ms step_avg:104.19ms +step:17/1705 train_time:1760ms step_avg:103.55ms +step:18/1705 train_time:1852ms step_avg:102.91ms +step:19/1705 train_time:1946ms step_avg:102.40ms +step:20/1705 train_time:2038ms step_avg:101.91ms +step:21/1705 train_time:2131ms step_avg:101.46ms +step:22/1705 train_time:2224ms step_avg:101.07ms +step:23/1705 train_time:2318ms step_avg:100.77ms +step:24/1705 train_time:2412ms step_avg:100.50ms +step:25/1705 train_time:2506ms step_avg:100.23ms +step:26/1705 train_time:2600ms step_avg:100.00ms +step:27/1705 train_time:2693ms step_avg:99.74ms +step:28/1705 train_time:2786ms step_avg:99.49ms +step:29/1705 train_time:2878ms step_avg:99.25ms +step:30/1705 train_time:2971ms step_avg:99.03ms +step:31/1705 train_time:3064ms step_avg:98.85ms +step:32/1705 train_time:3158ms step_avg:98.69ms +step:33/1705 train_time:3251ms step_avg:98.51ms +step:34/1705 train_time:3346ms step_avg:98.40ms +step:35/1705 train_time:3439ms step_avg:98.25ms +step:36/1705 train_time:3532ms step_avg:98.12ms +step:37/1705 train_time:3626ms step_avg:98.00ms +step:38/1705 train_time:3719ms step_avg:97.87ms +step:39/1705 train_time:3813ms step_avg:97.76ms +step:40/1705 train_time:3905ms step_avg:97.63ms +step:41/1705 train_time:3998ms step_avg:97.51ms +step:42/1705 train_time:4091ms step_avg:97.39ms +step:43/1705 train_time:4183ms step_avg:97.29ms +step:44/1705 train_time:4276ms step_avg:97.18ms +step:45/1705 train_time:4369ms step_avg:97.10ms +step:46/1705 train_time:4463ms step_avg:97.03ms +step:47/1705 train_time:4557ms step_avg:96.96ms +step:48/1705 train_time:4651ms step_avg:96.90ms +step:49/1705 train_time:4744ms step_avg:96.83ms +step:50/1705 train_time:4838ms step_avg:96.77ms +step:51/1705 train_time:4931ms step_avg:96.69ms +step:52/1705 train_time:5025ms step_avg:96.63ms +step:53/1705 train_time:5118ms step_avg:96.57ms +step:54/1705 train_time:5212ms step_avg:96.51ms +step:55/1705 train_time:5305ms step_avg:96.45ms +step:56/1705 train_time:5397ms step_avg:96.38ms +step:57/1705 train_time:5490ms step_avg:96.32ms +step:58/1705 train_time:5584ms step_avg:96.27ms +step:59/1705 train_time:5677ms step_avg:96.22ms +step:60/1705 train_time:5770ms step_avg:96.16ms +step:61/1705 train_time:5864ms step_avg:96.13ms +step:62/1705 train_time:5958ms step_avg:96.09ms +step:63/1705 train_time:6050ms step_avg:96.03ms +step:64/1705 train_time:6143ms step_avg:95.98ms +step:65/1705 train_time:6236ms step_avg:95.94ms +step:66/1705 train_time:6329ms step_avg:95.90ms +step:67/1705 train_time:6423ms step_avg:95.87ms +step:68/1705 train_time:6517ms step_avg:95.83ms +step:69/1705 train_time:6610ms step_avg:95.79ms +step:70/1705 train_time:6703ms step_avg:95.75ms +step:71/1705 train_time:6796ms step_avg:95.71ms +step:72/1705 train_time:6889ms step_avg:95.68ms +step:73/1705 train_time:6983ms step_avg:95.65ms +step:74/1705 train_time:7075ms step_avg:95.61ms +step:75/1705 train_time:7169ms step_avg:95.58ms +step:76/1705 train_time:7262ms step_avg:95.55ms +step:77/1705 
train_time:7355ms step_avg:95.51ms +step:78/1705 train_time:7447ms step_avg:95.48ms +step:79/1705 train_time:7541ms step_avg:95.46ms +step:80/1705 train_time:7634ms step_avg:95.43ms +step:81/1705 train_time:7727ms step_avg:95.40ms +step:82/1705 train_time:7821ms step_avg:95.38ms +step:83/1705 train_time:7914ms step_avg:95.35ms +step:84/1705 train_time:8007ms step_avg:95.32ms +step:85/1705 train_time:8101ms step_avg:95.30ms +step:86/1705 train_time:8192ms step_avg:95.26ms +step:87/1705 train_time:8286ms step_avg:95.24ms +step:88/1705 train_time:8378ms step_avg:95.21ms +step:89/1705 train_time:8472ms step_avg:95.19ms +step:90/1705 train_time:8565ms step_avg:95.17ms +step:91/1705 train_time:8659ms step_avg:95.16ms +step:92/1705 train_time:8752ms step_avg:95.13ms +step:93/1705 train_time:8845ms step_avg:95.11ms +step:94/1705 train_time:8939ms step_avg:95.10ms +step:95/1705 train_time:9032ms step_avg:95.08ms +step:96/1705 train_time:9126ms step_avg:95.06ms +step:97/1705 train_time:9219ms step_avg:95.04ms +step:98/1705 train_time:9312ms step_avg:95.02ms +step:99/1705 train_time:9404ms step_avg:94.99ms +step:100/1705 train_time:9497ms step_avg:94.97ms +step:101/1705 train_time:9591ms step_avg:94.96ms +step:102/1705 train_time:9684ms step_avg:94.94ms +step:103/1705 train_time:9777ms step_avg:94.93ms +step:104/1705 train_time:9869ms step_avg:94.90ms +step:105/1705 train_time:9963ms step_avg:94.89ms +step:106/1705 train_time:10057ms step_avg:94.88ms +step:107/1705 train_time:10149ms step_avg:94.85ms +step:108/1705 train_time:10243ms step_avg:94.84ms +step:109/1705 train_time:10336ms step_avg:94.82ms +step:110/1705 train_time:10429ms step_avg:94.81ms +step:111/1705 train_time:10523ms step_avg:94.80ms +step:112/1705 train_time:10616ms step_avg:94.78ms +step:113/1705 train_time:10709ms step_avg:94.77ms +step:114/1705 train_time:10803ms step_avg:94.76ms +step:115/1705 train_time:10896ms step_avg:94.75ms +step:116/1705 train_time:10989ms step_avg:94.74ms +step:117/1705 train_time:11084ms step_avg:94.73ms +step:118/1705 train_time:11176ms step_avg:94.71ms +step:119/1705 train_time:11268ms step_avg:94.69ms +step:120/1705 train_time:11361ms step_avg:94.68ms +step:121/1705 train_time:11455ms step_avg:94.67ms +step:122/1705 train_time:11548ms step_avg:94.65ms +step:123/1705 train_time:11641ms step_avg:94.64ms +step:124/1705 train_time:11733ms step_avg:94.62ms +step:125/1705 train_time:11827ms step_avg:94.61ms +step:125/1705 val_loss:4.2950 train_time:11920ms step_avg:95.36ms +step:126/1705 train_time:11943ms step_avg:94.79ms +step:127/1705 train_time:12020ms step_avg:94.65ms +step:128/1705 train_time:12125ms step_avg:94.73ms +step:129/1705 train_time:12219ms step_avg:94.72ms +step:130/1705 train_time:12312ms step_avg:94.71ms +step:131/1705 train_time:12404ms step_avg:94.69ms +step:132/1705 train_time:12497ms step_avg:94.67ms +step:133/1705 train_time:12589ms step_avg:94.65ms +step:134/1705 train_time:12680ms step_avg:94.63ms +step:135/1705 train_time:12773ms step_avg:94.61ms +step:136/1705 train_time:12864ms step_avg:94.59ms +step:137/1705 train_time:12957ms step_avg:94.57ms +step:138/1705 train_time:13052ms step_avg:94.58ms +step:139/1705 train_time:13146ms step_avg:94.58ms +step:140/1705 train_time:13240ms step_avg:94.57ms +step:141/1705 train_time:13334ms step_avg:94.56ms +step:142/1705 train_time:13426ms step_avg:94.55ms +step:143/1705 train_time:13518ms step_avg:94.53ms +step:144/1705 train_time:13611ms step_avg:94.52ms +step:145/1705 train_time:13703ms step_avg:94.50ms +step:146/1705 train_time:13795ms 
step_avg:94.49ms +step:147/1705 train_time:13887ms step_avg:94.47ms +step:148/1705 train_time:13980ms step_avg:94.46ms +step:149/1705 train_time:14074ms step_avg:94.46ms +step:150/1705 train_time:14168ms step_avg:94.46ms +step:151/1705 train_time:14262ms step_avg:94.45ms +step:152/1705 train_time:14355ms step_avg:94.44ms +step:153/1705 train_time:14448ms step_avg:94.43ms +step:154/1705 train_time:14541ms step_avg:94.42ms +step:155/1705 train_time:14633ms step_avg:94.41ms +step:156/1705 train_time:14726ms step_avg:94.40ms +step:157/1705 train_time:14818ms step_avg:94.38ms +step:158/1705 train_time:14910ms step_avg:94.37ms +step:159/1705 train_time:15003ms step_avg:94.36ms +step:160/1705 train_time:15097ms step_avg:94.36ms +step:161/1705 train_time:15192ms step_avg:94.36ms +step:162/1705 train_time:15286ms step_avg:94.36ms +step:163/1705 train_time:15378ms step_avg:94.34ms +step:164/1705 train_time:15472ms step_avg:94.34ms +step:165/1705 train_time:15565ms step_avg:94.33ms +step:166/1705 train_time:15656ms step_avg:94.31ms +step:167/1705 train_time:15749ms step_avg:94.30ms +step:168/1705 train_time:15841ms step_avg:94.29ms +step:169/1705 train_time:15934ms step_avg:94.29ms +step:170/1705 train_time:16027ms step_avg:94.28ms +step:171/1705 train_time:16120ms step_avg:94.27ms +step:172/1705 train_time:16213ms step_avg:94.26ms +step:173/1705 train_time:16306ms step_avg:94.26ms +step:174/1705 train_time:16399ms step_avg:94.25ms +step:175/1705 train_time:16493ms step_avg:94.25ms +step:176/1705 train_time:16586ms step_avg:94.24ms +step:177/1705 train_time:16678ms step_avg:94.23ms +step:178/1705 train_time:16771ms step_avg:94.22ms +step:179/1705 train_time:16863ms step_avg:94.21ms +step:180/1705 train_time:16956ms step_avg:94.20ms +step:181/1705 train_time:17049ms step_avg:94.19ms +step:182/1705 train_time:17142ms step_avg:94.19ms +step:183/1705 train_time:17236ms step_avg:94.18ms +step:184/1705 train_time:17329ms step_avg:94.18ms +step:185/1705 train_time:17422ms step_avg:94.17ms +step:186/1705 train_time:17515ms step_avg:94.17ms +step:187/1705 train_time:17608ms step_avg:94.16ms +step:188/1705 train_time:17700ms step_avg:94.15ms +step:189/1705 train_time:17793ms step_avg:94.14ms +step:190/1705 train_time:17885ms step_avg:94.13ms +step:191/1705 train_time:17978ms step_avg:94.12ms +step:192/1705 train_time:18070ms step_avg:94.11ms +step:193/1705 train_time:18163ms step_avg:94.11ms +step:194/1705 train_time:18256ms step_avg:94.10ms +step:195/1705 train_time:18349ms step_avg:94.10ms +step:196/1705 train_time:18442ms step_avg:94.09ms +step:197/1705 train_time:18535ms step_avg:94.09ms +step:198/1705 train_time:18628ms step_avg:94.08ms +step:199/1705 train_time:18722ms step_avg:94.08ms +step:200/1705 train_time:18815ms step_avg:94.07ms +step:201/1705 train_time:18908ms step_avg:94.07ms +step:202/1705 train_time:19000ms step_avg:94.06ms +step:203/1705 train_time:19093ms step_avg:94.06ms +step:204/1705 train_time:19186ms step_avg:94.05ms +step:205/1705 train_time:19278ms step_avg:94.04ms +step:206/1705 train_time:19371ms step_avg:94.03ms +step:207/1705 train_time:19464ms step_avg:94.03ms +step:208/1705 train_time:19556ms step_avg:94.02ms +step:209/1705 train_time:19649ms step_avg:94.01ms +step:210/1705 train_time:19742ms step_avg:94.01ms +step:211/1705 train_time:19835ms step_avg:94.00ms +step:212/1705 train_time:19928ms step_avg:94.00ms +step:213/1705 train_time:20218ms step_avg:94.92ms +step:214/1705 train_time:20310ms step_avg:94.91ms +step:215/1705 train_time:20402ms step_avg:94.89ms +step:216/1705 
train_time:20495ms step_avg:94.88ms +step:217/1705 train_time:20587ms step_avg:94.87ms +step:218/1705 train_time:20679ms step_avg:94.86ms +step:219/1705 train_time:20771ms step_avg:94.84ms +step:220/1705 train_time:20863ms step_avg:94.83ms +step:221/1705 train_time:20955ms step_avg:94.82ms +step:222/1705 train_time:21047ms step_avg:94.81ms +step:223/1705 train_time:21143ms step_avg:94.81ms +step:224/1705 train_time:21238ms step_avg:94.81ms +step:225/1705 train_time:21332ms step_avg:94.81ms +step:226/1705 train_time:21426ms step_avg:94.80ms +step:227/1705 train_time:21518ms step_avg:94.79ms +step:228/1705 train_time:21612ms step_avg:94.79ms +step:229/1705 train_time:21704ms step_avg:94.78ms +step:230/1705 train_time:21796ms step_avg:94.76ms +step:231/1705 train_time:21887ms step_avg:94.75ms +step:232/1705 train_time:21979ms step_avg:94.74ms +step:233/1705 train_time:22072ms step_avg:94.73ms +step:234/1705 train_time:22166ms step_avg:94.73ms +step:235/1705 train_time:22260ms step_avg:94.72ms +step:236/1705 train_time:22354ms step_avg:94.72ms +step:237/1705 train_time:22446ms step_avg:94.71ms +step:238/1705 train_time:22539ms step_avg:94.70ms +step:239/1705 train_time:22632ms step_avg:94.70ms +step:240/1705 train_time:22724ms step_avg:94.68ms +step:241/1705 train_time:22816ms step_avg:94.67ms +step:242/1705 train_time:22908ms step_avg:94.66ms +step:243/1705 train_time:23001ms step_avg:94.65ms +step:244/1705 train_time:23094ms step_avg:94.65ms +step:245/1705 train_time:23186ms step_avg:94.64ms +step:246/1705 train_time:23280ms step_avg:94.64ms +step:247/1705 train_time:23374ms step_avg:94.63ms +step:248/1705 train_time:23468ms step_avg:94.63ms +step:249/1705 train_time:23560ms step_avg:94.62ms +step:250/1705 train_time:23654ms step_avg:94.61ms +step:250/1705 val_loss:3.9737 train_time:23746ms step_avg:94.98ms +step:251/1705 train_time:23769ms step_avg:94.70ms +step:252/1705 train_time:23843ms step_avg:94.62ms +step:253/1705 train_time:23943ms step_avg:94.64ms +step:254/1705 train_time:24038ms step_avg:94.64ms +step:255/1705 train_time:24131ms step_avg:94.63ms +step:256/1705 train_time:24224ms step_avg:94.62ms +step:257/1705 train_time:24315ms step_avg:94.61ms +step:258/1705 train_time:24407ms step_avg:94.60ms +step:259/1705 train_time:24499ms step_avg:94.59ms +step:260/1705 train_time:24590ms step_avg:94.58ms +step:261/1705 train_time:24683ms step_avg:94.57ms +step:262/1705 train_time:24777ms step_avg:94.57ms +step:263/1705 train_time:24873ms step_avg:94.58ms +step:264/1705 train_time:24968ms step_avg:94.57ms +step:265/1705 train_time:25062ms step_avg:94.57ms +step:266/1705 train_time:25154ms step_avg:94.56ms +step:267/1705 train_time:25246ms step_avg:94.56ms +step:268/1705 train_time:25339ms step_avg:94.55ms +step:269/1705 train_time:25432ms step_avg:94.54ms +step:270/1705 train_time:25524ms step_avg:94.53ms +step:271/1705 train_time:25616ms step_avg:94.52ms +step:272/1705 train_time:25708ms step_avg:94.51ms +step:273/1705 train_time:25801ms step_avg:94.51ms +step:274/1705 train_time:25894ms step_avg:94.50ms +step:275/1705 train_time:25987ms step_avg:94.50ms +step:276/1705 train_time:26081ms step_avg:94.50ms +step:277/1705 train_time:26174ms step_avg:94.49ms +step:278/1705 train_time:26267ms step_avg:94.49ms +step:279/1705 train_time:26360ms step_avg:94.48ms +step:280/1705 train_time:26453ms step_avg:94.47ms +step:281/1705 train_time:26544ms step_avg:94.46ms +step:282/1705 train_time:26636ms step_avg:94.46ms +step:283/1705 train_time:26729ms step_avg:94.45ms +step:284/1705 train_time:26822ms 
step_avg:94.44ms +step:285/1705 train_time:26914ms step_avg:94.44ms +step:286/1705 train_time:27007ms step_avg:94.43ms +step:287/1705 train_time:27101ms step_avg:94.43ms +step:288/1705 train_time:27194ms step_avg:94.42ms +step:289/1705 train_time:27286ms step_avg:94.42ms +step:290/1705 train_time:27379ms step_avg:94.41ms +step:291/1705 train_time:27472ms step_avg:94.41ms +step:292/1705 train_time:27565ms step_avg:94.40ms +step:293/1705 train_time:27658ms step_avg:94.40ms +step:294/1705 train_time:27751ms step_avg:94.39ms +step:295/1705 train_time:27844ms step_avg:94.38ms +step:296/1705 train_time:27937ms step_avg:94.38ms +step:297/1705 train_time:28031ms step_avg:94.38ms +step:298/1705 train_time:28124ms step_avg:94.37ms +step:299/1705 train_time:28216ms step_avg:94.37ms +step:300/1705 train_time:28309ms step_avg:94.36ms +step:301/1705 train_time:28402ms step_avg:94.36ms +step:302/1705 train_time:28495ms step_avg:94.35ms +step:303/1705 train_time:28587ms step_avg:94.35ms +step:304/1705 train_time:28680ms step_avg:94.34ms +step:305/1705 train_time:28773ms step_avg:94.34ms +step:306/1705 train_time:28866ms step_avg:94.33ms +step:307/1705 train_time:28959ms step_avg:94.33ms +step:308/1705 train_time:29052ms step_avg:94.32ms +step:309/1705 train_time:29145ms step_avg:94.32ms +step:310/1705 train_time:29238ms step_avg:94.32ms +step:311/1705 train_time:29330ms step_avg:94.31ms +step:312/1705 train_time:29422ms step_avg:94.30ms +step:313/1705 train_time:29515ms step_avg:94.30ms +step:314/1705 train_time:29608ms step_avg:94.29ms +step:315/1705 train_time:29700ms step_avg:94.29ms +step:316/1705 train_time:29793ms step_avg:94.28ms +step:317/1705 train_time:29886ms step_avg:94.28ms +step:318/1705 train_time:29979ms step_avg:94.27ms +step:319/1705 train_time:30072ms step_avg:94.27ms +step:320/1705 train_time:30165ms step_avg:94.27ms +step:321/1705 train_time:30258ms step_avg:94.26ms +step:322/1705 train_time:30350ms step_avg:94.26ms +step:323/1705 train_time:30444ms step_avg:94.25ms +step:324/1705 train_time:30536ms step_avg:94.25ms +step:325/1705 train_time:30629ms step_avg:94.24ms +step:326/1705 train_time:30722ms step_avg:94.24ms +step:327/1705 train_time:30815ms step_avg:94.23ms +step:328/1705 train_time:30908ms step_avg:94.23ms +step:329/1705 train_time:31000ms step_avg:94.23ms +step:330/1705 train_time:31093ms step_avg:94.22ms +step:331/1705 train_time:31186ms step_avg:94.22ms +step:332/1705 train_time:31279ms step_avg:94.21ms +step:333/1705 train_time:31373ms step_avg:94.21ms +step:334/1705 train_time:31465ms step_avg:94.21ms +step:335/1705 train_time:31558ms step_avg:94.20ms +step:336/1705 train_time:31651ms step_avg:94.20ms +step:337/1705 train_time:31744ms step_avg:94.20ms +step:338/1705 train_time:31837ms step_avg:94.19ms +step:339/1705 train_time:31930ms step_avg:94.19ms +step:340/1705 train_time:32023ms step_avg:94.19ms +step:341/1705 train_time:32116ms step_avg:94.18ms +step:342/1705 train_time:32209ms step_avg:94.18ms +step:343/1705 train_time:32302ms step_avg:94.17ms +step:344/1705 train_time:32394ms step_avg:94.17ms +step:345/1705 train_time:32486ms step_avg:94.16ms +step:346/1705 train_time:32579ms step_avg:94.16ms +step:347/1705 train_time:32672ms step_avg:94.16ms +step:348/1705 train_time:32766ms step_avg:94.15ms +step:349/1705 train_time:32858ms step_avg:94.15ms +step:350/1705 train_time:32952ms step_avg:94.15ms +step:351/1705 train_time:33045ms step_avg:94.14ms +step:352/1705 train_time:33137ms step_avg:94.14ms +step:353/1705 train_time:33230ms step_avg:94.14ms +step:354/1705 
train_time:33323ms step_avg:94.13ms +step:355/1705 train_time:33416ms step_avg:94.13ms +step:356/1705 train_time:33509ms step_avg:94.13ms +step:357/1705 train_time:33602ms step_avg:94.12ms +step:358/1705 train_time:33695ms step_avg:94.12ms +step:359/1705 train_time:33788ms step_avg:94.12ms +step:360/1705 train_time:33881ms step_avg:94.11ms +step:361/1705 train_time:33975ms step_avg:94.11ms +step:362/1705 train_time:34068ms step_avg:94.11ms +step:363/1705 train_time:34161ms step_avg:94.11ms +step:364/1705 train_time:34254ms step_avg:94.11ms +step:365/1705 train_time:34346ms step_avg:94.10ms +step:366/1705 train_time:34439ms step_avg:94.10ms +step:367/1705 train_time:34532ms step_avg:94.09ms +step:368/1705 train_time:34625ms step_avg:94.09ms +step:369/1705 train_time:34718ms step_avg:94.09ms +step:370/1705 train_time:34811ms step_avg:94.08ms +step:371/1705 train_time:34904ms step_avg:94.08ms +step:372/1705 train_time:34997ms step_avg:94.08ms +step:373/1705 train_time:35090ms step_avg:94.08ms +step:374/1705 train_time:35183ms step_avg:94.07ms +step:375/1705 train_time:35276ms step_avg:94.07ms +step:375/1705 val_loss:3.8244 train_time:35369ms step_avg:94.32ms +step:376/1705 train_time:35392ms step_avg:94.13ms +step:377/1705 train_time:35468ms step_avg:94.08ms +step:378/1705 train_time:35566ms step_avg:94.09ms +step:379/1705 train_time:35659ms step_avg:94.09ms +step:380/1705 train_time:35752ms step_avg:94.08ms +step:381/1705 train_time:35844ms step_avg:94.08ms +step:382/1705 train_time:35936ms step_avg:94.07ms +step:383/1705 train_time:36029ms step_avg:94.07ms +step:384/1705 train_time:36121ms step_avg:94.06ms +step:385/1705 train_time:36213ms step_avg:94.06ms +step:386/1705 train_time:36307ms step_avg:94.06ms +step:387/1705 train_time:36401ms step_avg:94.06ms +step:388/1705 train_time:36497ms step_avg:94.06ms +step:389/1705 train_time:36591ms step_avg:94.06ms +step:390/1705 train_time:36684ms step_avg:94.06ms +step:391/1705 train_time:36776ms step_avg:94.06ms +step:392/1705 train_time:36869ms step_avg:94.05ms +step:393/1705 train_time:36961ms step_avg:94.05ms +step:394/1705 train_time:37054ms step_avg:94.05ms +step:395/1705 train_time:37147ms step_avg:94.04ms +step:396/1705 train_time:37239ms step_avg:94.04ms +step:397/1705 train_time:37332ms step_avg:94.04ms +step:398/1705 train_time:37426ms step_avg:94.03ms +step:399/1705 train_time:37520ms step_avg:94.03ms +step:400/1705 train_time:37613ms step_avg:94.03ms +step:401/1705 train_time:37707ms step_avg:94.03ms +step:402/1705 train_time:37800ms step_avg:94.03ms +step:403/1705 train_time:37893ms step_avg:94.03ms +step:404/1705 train_time:37987ms step_avg:94.03ms +step:405/1705 train_time:38079ms step_avg:94.02ms +step:406/1705 train_time:38172ms step_avg:94.02ms +step:407/1705 train_time:38264ms step_avg:94.01ms +step:408/1705 train_time:38357ms step_avg:94.01ms +step:409/1705 train_time:38450ms step_avg:94.01ms +step:410/1705 train_time:38543ms step_avg:94.01ms +step:411/1705 train_time:38637ms step_avg:94.01ms +step:412/1705 train_time:38730ms step_avg:94.01ms +step:413/1705 train_time:38823ms step_avg:94.00ms +step:414/1705 train_time:38918ms step_avg:94.01ms +step:415/1705 train_time:39011ms step_avg:94.00ms +step:416/1705 train_time:39104ms step_avg:94.00ms +step:417/1705 train_time:39197ms step_avg:94.00ms +step:418/1705 train_time:39290ms step_avg:93.99ms +step:419/1705 train_time:39382ms step_avg:93.99ms +step:420/1705 train_time:39474ms step_avg:93.99ms +step:421/1705 train_time:39568ms step_avg:93.99ms +step:422/1705 train_time:39661ms 
step_avg:93.98ms +step:423/1705 train_time:39755ms step_avg:93.98ms +step:424/1705 train_time:39848ms step_avg:93.98ms +step:425/1705 train_time:40130ms step_avg:94.42ms +step:426/1705 train_time:40211ms step_avg:94.39ms +step:427/1705 train_time:40302ms step_avg:94.38ms +step:428/1705 train_time:40394ms step_avg:94.38ms +step:429/1705 train_time:40486ms step_avg:94.37ms +step:430/1705 train_time:40579ms step_avg:94.37ms +step:431/1705 train_time:40671ms step_avg:94.36ms +step:432/1705 train_time:40763ms step_avg:94.36ms +step:433/1705 train_time:40854ms step_avg:94.35ms +step:434/1705 train_time:40947ms step_avg:94.35ms +step:435/1705 train_time:41043ms step_avg:94.35ms +step:436/1705 train_time:41138ms step_avg:94.35ms +step:437/1705 train_time:41232ms step_avg:94.35ms +step:438/1705 train_time:41327ms step_avg:94.35ms +step:439/1705 train_time:41421ms step_avg:94.35ms +step:440/1705 train_time:41513ms step_avg:94.35ms +step:441/1705 train_time:41605ms step_avg:94.34ms +step:442/1705 train_time:41697ms step_avg:94.34ms +step:443/1705 train_time:41789ms step_avg:94.33ms +step:444/1705 train_time:41882ms step_avg:94.33ms +step:445/1705 train_time:41975ms step_avg:94.33ms +step:446/1705 train_time:42069ms step_avg:94.32ms +step:447/1705 train_time:42162ms step_avg:94.32ms +step:448/1705 train_time:42256ms step_avg:94.32ms +step:449/1705 train_time:42350ms step_avg:94.32ms +step:450/1705 train_time:42442ms step_avg:94.32ms +step:451/1705 train_time:42536ms step_avg:94.31ms +step:452/1705 train_time:42629ms step_avg:94.31ms +step:453/1705 train_time:42721ms step_avg:94.31ms +step:454/1705 train_time:42813ms step_avg:94.30ms +step:455/1705 train_time:42906ms step_avg:94.30ms +step:456/1705 train_time:42999ms step_avg:94.30ms +step:457/1705 train_time:43091ms step_avg:94.29ms +step:458/1705 train_time:43185ms step_avg:94.29ms +step:459/1705 train_time:43278ms step_avg:94.29ms +step:460/1705 train_time:43372ms step_avg:94.29ms +step:461/1705 train_time:43465ms step_avg:94.28ms +step:462/1705 train_time:43557ms step_avg:94.28ms +step:463/1705 train_time:43651ms step_avg:94.28ms +step:464/1705 train_time:43744ms step_avg:94.28ms +step:465/1705 train_time:43836ms step_avg:94.27ms +step:466/1705 train_time:43930ms step_avg:94.27ms +step:467/1705 train_time:44023ms step_avg:94.27ms +step:468/1705 train_time:44116ms step_avg:94.27ms +step:469/1705 train_time:44210ms step_avg:94.26ms +step:470/1705 train_time:44303ms step_avg:94.26ms +step:471/1705 train_time:44396ms step_avg:94.26ms +step:472/1705 train_time:44489ms step_avg:94.26ms +step:473/1705 train_time:44582ms step_avg:94.25ms +step:474/1705 train_time:44675ms step_avg:94.25ms +step:475/1705 train_time:44768ms step_avg:94.25ms +step:476/1705 train_time:44861ms step_avg:94.25ms +step:477/1705 train_time:44954ms step_avg:94.24ms +step:478/1705 train_time:45047ms step_avg:94.24ms +step:479/1705 train_time:45140ms step_avg:94.24ms +step:480/1705 train_time:45233ms step_avg:94.23ms +step:481/1705 train_time:45325ms step_avg:94.23ms +step:482/1705 train_time:45418ms step_avg:94.23ms +step:483/1705 train_time:45511ms step_avg:94.23ms +step:484/1705 train_time:45605ms step_avg:94.23ms +step:485/1705 train_time:45698ms step_avg:94.22ms +step:486/1705 train_time:45790ms step_avg:94.22ms +step:487/1705 train_time:45883ms step_avg:94.22ms +step:488/1705 train_time:45975ms step_avg:94.21ms +step:489/1705 train_time:46069ms step_avg:94.21ms +step:490/1705 train_time:46161ms step_avg:94.21ms +step:491/1705 train_time:46255ms step_avg:94.20ms +step:492/1705 
train_time:46348ms step_avg:94.20ms +step:493/1705 train_time:46440ms step_avg:94.20ms +step:494/1705 train_time:46534ms step_avg:94.20ms +step:495/1705 train_time:46628ms step_avg:94.20ms +step:496/1705 train_time:46721ms step_avg:94.19ms +step:497/1705 train_time:46814ms step_avg:94.19ms +step:498/1705 train_time:46908ms step_avg:94.19ms +step:499/1705 train_time:47000ms step_avg:94.19ms +step:500/1705 train_time:47093ms step_avg:94.19ms +step:500/1705 val_loss:3.7168 train_time:47187ms step_avg:94.37ms +step:501/1705 train_time:47210ms step_avg:94.23ms +step:502/1705 train_time:47285ms step_avg:94.19ms +step:503/1705 train_time:47381ms step_avg:94.20ms +step:504/1705 train_time:47473ms step_avg:94.19ms +step:505/1705 train_time:47566ms step_avg:94.19ms +step:506/1705 train_time:47658ms step_avg:94.19ms +step:507/1705 train_time:47750ms step_avg:94.18ms +step:508/1705 train_time:47842ms step_avg:94.18ms +step:509/1705 train_time:47934ms step_avg:94.17ms +step:510/1705 train_time:48027ms step_avg:94.17ms +step:511/1705 train_time:48120ms step_avg:94.17ms +step:512/1705 train_time:48215ms step_avg:94.17ms +step:513/1705 train_time:48311ms step_avg:94.17ms +step:514/1705 train_time:48404ms step_avg:94.17ms +step:515/1705 train_time:48498ms step_avg:94.17ms +step:516/1705 train_time:48592ms step_avg:94.17ms +step:517/1705 train_time:48683ms step_avg:94.16ms +step:518/1705 train_time:48775ms step_avg:94.16ms +step:519/1705 train_time:48867ms step_avg:94.16ms +step:520/1705 train_time:48960ms step_avg:94.15ms +step:521/1705 train_time:49053ms step_avg:94.15ms +step:522/1705 train_time:49146ms step_avg:94.15ms +step:523/1705 train_time:49239ms step_avg:94.15ms +step:524/1705 train_time:49334ms step_avg:94.15ms +step:525/1705 train_time:49428ms step_avg:94.15ms +step:526/1705 train_time:49521ms step_avg:94.15ms +step:527/1705 train_time:49614ms step_avg:94.14ms +step:528/1705 train_time:49708ms step_avg:94.14ms +step:529/1705 train_time:49800ms step_avg:94.14ms +step:530/1705 train_time:49893ms step_avg:94.14ms +step:531/1705 train_time:49985ms step_avg:94.13ms +step:532/1705 train_time:50078ms step_avg:94.13ms +step:533/1705 train_time:50172ms step_avg:94.13ms +step:534/1705 train_time:50265ms step_avg:94.13ms +step:535/1705 train_time:50359ms step_avg:94.13ms +step:536/1705 train_time:50452ms step_avg:94.13ms +step:537/1705 train_time:50546ms step_avg:94.13ms +step:538/1705 train_time:50639ms step_avg:94.13ms +step:539/1705 train_time:50733ms step_avg:94.12ms +step:540/1705 train_time:50825ms step_avg:94.12ms +step:541/1705 train_time:50918ms step_avg:94.12ms +step:542/1705 train_time:51011ms step_avg:94.12ms +step:543/1705 train_time:51103ms step_avg:94.11ms +step:544/1705 train_time:51197ms step_avg:94.11ms +step:545/1705 train_time:51290ms step_avg:94.11ms +step:546/1705 train_time:51383ms step_avg:94.11ms +step:547/1705 train_time:51477ms step_avg:94.11ms +step:548/1705 train_time:51571ms step_avg:94.11ms +step:549/1705 train_time:51664ms step_avg:94.11ms +step:550/1705 train_time:51757ms step_avg:94.10ms +step:551/1705 train_time:51850ms step_avg:94.10ms +step:552/1705 train_time:51943ms step_avg:94.10ms +step:553/1705 train_time:52036ms step_avg:94.10ms +step:554/1705 train_time:52129ms step_avg:94.10ms +step:555/1705 train_time:52222ms step_avg:94.09ms +step:556/1705 train_time:52315ms step_avg:94.09ms +step:557/1705 train_time:52409ms step_avg:94.09ms +step:558/1705 train_time:52502ms step_avg:94.09ms +step:559/1705 train_time:52596ms step_avg:94.09ms +step:560/1705 train_time:52688ms 
step_avg:94.09ms +step:561/1705 train_time:52781ms step_avg:94.08ms +step:562/1705 train_time:52875ms step_avg:94.08ms +step:563/1705 train_time:52967ms step_avg:94.08ms +step:564/1705 train_time:53060ms step_avg:94.08ms +step:565/1705 train_time:53153ms step_avg:94.08ms +step:566/1705 train_time:53247ms step_avg:94.08ms +step:567/1705 train_time:53340ms step_avg:94.07ms +step:568/1705 train_time:53433ms step_avg:94.07ms +step:569/1705 train_time:53526ms step_avg:94.07ms +step:570/1705 train_time:53619ms step_avg:94.07ms +step:571/1705 train_time:53713ms step_avg:94.07ms +step:572/1705 train_time:53807ms step_avg:94.07ms +step:573/1705 train_time:53903ms step_avg:94.07ms +step:574/1705 train_time:53997ms step_avg:94.07ms +step:575/1705 train_time:54091ms step_avg:94.07ms +step:576/1705 train_time:54186ms step_avg:94.07ms +step:577/1705 train_time:54280ms step_avg:94.07ms +step:578/1705 train_time:54375ms step_avg:94.07ms +step:579/1705 train_time:54470ms step_avg:94.08ms +step:580/1705 train_time:54564ms step_avg:94.08ms +step:581/1705 train_time:54658ms step_avg:94.08ms +step:582/1705 train_time:54753ms step_avg:94.08ms +step:583/1705 train_time:54847ms step_avg:94.08ms +step:584/1705 train_time:54941ms step_avg:94.08ms +step:585/1705 train_time:55035ms step_avg:94.08ms +step:586/1705 train_time:55129ms step_avg:94.08ms +step:587/1705 train_time:55223ms step_avg:94.08ms +step:588/1705 train_time:55318ms step_avg:94.08ms +step:589/1705 train_time:55413ms step_avg:94.08ms +step:590/1705 train_time:55507ms step_avg:94.08ms +step:591/1705 train_time:55601ms step_avg:94.08ms +step:592/1705 train_time:55695ms step_avg:94.08ms +step:593/1705 train_time:55790ms step_avg:94.08ms +step:594/1705 train_time:55883ms step_avg:94.08ms +step:595/1705 train_time:55977ms step_avg:94.08ms +step:596/1705 train_time:56072ms step_avg:94.08ms +step:597/1705 train_time:56167ms step_avg:94.08ms +step:598/1705 train_time:56261ms step_avg:94.08ms +step:599/1705 train_time:56356ms step_avg:94.08ms +step:600/1705 train_time:56451ms step_avg:94.08ms +step:601/1705 train_time:56545ms step_avg:94.09ms +step:602/1705 train_time:56640ms step_avg:94.09ms +step:603/1705 train_time:56735ms step_avg:94.09ms +step:604/1705 train_time:56830ms step_avg:94.09ms +step:605/1705 train_time:56924ms step_avg:94.09ms +step:606/1705 train_time:57018ms step_avg:94.09ms +step:607/1705 train_time:57113ms step_avg:94.09ms +step:608/1705 train_time:57206ms step_avg:94.09ms +step:609/1705 train_time:57300ms step_avg:94.09ms +step:610/1705 train_time:57396ms step_avg:94.09ms +step:611/1705 train_time:57490ms step_avg:94.09ms +step:612/1705 train_time:57584ms step_avg:94.09ms +step:613/1705 train_time:57679ms step_avg:94.09ms +step:614/1705 train_time:57773ms step_avg:94.09ms +step:615/1705 train_time:57868ms step_avg:94.09ms +step:616/1705 train_time:57962ms step_avg:94.09ms +step:617/1705 train_time:58057ms step_avg:94.10ms +step:618/1705 train_time:58151ms step_avg:94.10ms +step:619/1705 train_time:58245ms step_avg:94.10ms +step:620/1705 train_time:58340ms step_avg:94.10ms +step:621/1705 train_time:58435ms step_avg:94.10ms +step:622/1705 train_time:58529ms step_avg:94.10ms +step:623/1705 train_time:58623ms step_avg:94.10ms +step:624/1705 train_time:58718ms step_avg:94.10ms +step:625/1705 train_time:58812ms step_avg:94.10ms +step:625/1705 val_loss:3.6168 train_time:58906ms step_avg:94.25ms +step:626/1705 train_time:58929ms step_avg:94.14ms +step:627/1705 train_time:59011ms step_avg:94.12ms +step:628/1705 train_time:59110ms step_avg:94.12ms 
+step:629/1705 train_time:59205ms step_avg:94.13ms +step:630/1705 train_time:59298ms step_avg:94.12ms +step:631/1705 train_time:59392ms step_avg:94.12ms +step:632/1705 train_time:59485ms step_avg:94.12ms +step:633/1705 train_time:59579ms step_avg:94.12ms +step:634/1705 train_time:59671ms step_avg:94.12ms +step:635/1705 train_time:59764ms step_avg:94.12ms +step:636/1705 train_time:59857ms step_avg:94.12ms +step:637/1705 train_time:59956ms step_avg:94.12ms +step:638/1705 train_time:60055ms step_avg:94.13ms +step:639/1705 train_time:60428ms step_avg:94.57ms +step:640/1705 train_time:60507ms step_avg:94.54ms +step:641/1705 train_time:60600ms step_avg:94.54ms +step:642/1705 train_time:60694ms step_avg:94.54ms +step:643/1705 train_time:60787ms step_avg:94.54ms +step:644/1705 train_time:60881ms step_avg:94.54ms +step:645/1705 train_time:60974ms step_avg:94.53ms +step:646/1705 train_time:61067ms step_avg:94.53ms +step:647/1705 train_time:61160ms step_avg:94.53ms +step:648/1705 train_time:61254ms step_avg:94.53ms +step:649/1705 train_time:61351ms step_avg:94.53ms +step:650/1705 train_time:61448ms step_avg:94.54ms +step:651/1705 train_time:61543ms step_avg:94.54ms +step:652/1705 train_time:61638ms step_avg:94.54ms +step:653/1705 train_time:61733ms step_avg:94.54ms +step:654/1705 train_time:61827ms step_avg:94.54ms +step:655/1705 train_time:61921ms step_avg:94.54ms +step:656/1705 train_time:62014ms step_avg:94.53ms +step:657/1705 train_time:62107ms step_avg:94.53ms +step:658/1705 train_time:62201ms step_avg:94.53ms +step:659/1705 train_time:62296ms step_avg:94.53ms +step:660/1705 train_time:62392ms step_avg:94.53ms +step:661/1705 train_time:62487ms step_avg:94.53ms +step:662/1705 train_time:62581ms step_avg:94.53ms +step:663/1705 train_time:62676ms step_avg:94.53ms +step:664/1705 train_time:62771ms step_avg:94.53ms +step:665/1705 train_time:62865ms step_avg:94.53ms +step:666/1705 train_time:62959ms step_avg:94.53ms +step:667/1705 train_time:63053ms step_avg:94.53ms +step:668/1705 train_time:63146ms step_avg:94.53ms +step:669/1705 train_time:63240ms step_avg:94.53ms +step:670/1705 train_time:63336ms step_avg:94.53ms +step:671/1705 train_time:63431ms step_avg:94.53ms +step:672/1705 train_time:63526ms step_avg:94.53ms +step:673/1705 train_time:63621ms step_avg:94.53ms +step:674/1705 train_time:63715ms step_avg:94.53ms +step:675/1705 train_time:63810ms step_avg:94.53ms +step:676/1705 train_time:63904ms step_avg:94.53ms +step:677/1705 train_time:63997ms step_avg:94.53ms +step:678/1705 train_time:64091ms step_avg:94.53ms +step:679/1705 train_time:64185ms step_avg:94.53ms +step:680/1705 train_time:64279ms step_avg:94.53ms +step:681/1705 train_time:64374ms step_avg:94.53ms +step:682/1705 train_time:64469ms step_avg:94.53ms +step:683/1705 train_time:64563ms step_avg:94.53ms +step:684/1705 train_time:64658ms step_avg:94.53ms +step:685/1705 train_time:64753ms step_avg:94.53ms +step:686/1705 train_time:64847ms step_avg:94.53ms +step:687/1705 train_time:64941ms step_avg:94.53ms +step:688/1705 train_time:65035ms step_avg:94.53ms +step:689/1705 train_time:65130ms step_avg:94.53ms +step:690/1705 train_time:65224ms step_avg:94.53ms +step:691/1705 train_time:65319ms step_avg:94.53ms +step:692/1705 train_time:65413ms step_avg:94.53ms +step:693/1705 train_time:65508ms step_avg:94.53ms +step:694/1705 train_time:65603ms step_avg:94.53ms +step:695/1705 train_time:65698ms step_avg:94.53ms +step:696/1705 train_time:65793ms step_avg:94.53ms +step:697/1705 train_time:65886ms step_avg:94.53ms +step:698/1705 train_time:65980ms 
step_avg:94.53ms +step:699/1705 train_time:66075ms step_avg:94.53ms +step:700/1705 train_time:66170ms step_avg:94.53ms +step:701/1705 train_time:66263ms step_avg:94.53ms +step:702/1705 train_time:66358ms step_avg:94.53ms +step:703/1705 train_time:66454ms step_avg:94.53ms +step:704/1705 train_time:66548ms step_avg:94.53ms +step:705/1705 train_time:66642ms step_avg:94.53ms +step:706/1705 train_time:66736ms step_avg:94.53ms +step:707/1705 train_time:66831ms step_avg:94.53ms +step:708/1705 train_time:66925ms step_avg:94.53ms +step:709/1705 train_time:67019ms step_avg:94.53ms +step:710/1705 train_time:67113ms step_avg:94.53ms +step:711/1705 train_time:67207ms step_avg:94.52ms +step:712/1705 train_time:67301ms step_avg:94.52ms +step:713/1705 train_time:67397ms step_avg:94.53ms +step:714/1705 train_time:67491ms step_avg:94.53ms +step:715/1705 train_time:67586ms step_avg:94.53ms +step:716/1705 train_time:67680ms step_avg:94.52ms +step:717/1705 train_time:67774ms step_avg:94.52ms +step:718/1705 train_time:67868ms step_avg:94.52ms +step:719/1705 train_time:67962ms step_avg:94.52ms +step:720/1705 train_time:68057ms step_avg:94.52ms +step:721/1705 train_time:68152ms step_avg:94.52ms +step:722/1705 train_time:68246ms step_avg:94.52ms +step:723/1705 train_time:68340ms step_avg:94.52ms +step:724/1705 train_time:68435ms step_avg:94.52ms +step:725/1705 train_time:68529ms step_avg:94.52ms +step:726/1705 train_time:68624ms step_avg:94.52ms +step:727/1705 train_time:68719ms step_avg:94.52ms +step:728/1705 train_time:68813ms step_avg:94.52ms +step:729/1705 train_time:68908ms step_avg:94.52ms +step:730/1705 train_time:69002ms step_avg:94.52ms +step:731/1705 train_time:69096ms step_avg:94.52ms +step:732/1705 train_time:69192ms step_avg:94.52ms +step:733/1705 train_time:69286ms step_avg:94.52ms +step:734/1705 train_time:69380ms step_avg:94.52ms +step:735/1705 train_time:69475ms step_avg:94.52ms +step:736/1705 train_time:69570ms step_avg:94.52ms +step:737/1705 train_time:69664ms step_avg:94.52ms +step:738/1705 train_time:69759ms step_avg:94.52ms +step:739/1705 train_time:69854ms step_avg:94.52ms +step:740/1705 train_time:69948ms step_avg:94.52ms +step:741/1705 train_time:70043ms step_avg:94.52ms +step:742/1705 train_time:70138ms step_avg:94.53ms +step:743/1705 train_time:70233ms step_avg:94.53ms +step:744/1705 train_time:70327ms step_avg:94.53ms +step:745/1705 train_time:70422ms step_avg:94.53ms +step:746/1705 train_time:70517ms step_avg:94.53ms +step:747/1705 train_time:70612ms step_avg:94.53ms +step:748/1705 train_time:70706ms step_avg:94.53ms +step:749/1705 train_time:70800ms step_avg:94.53ms +step:750/1705 train_time:70895ms step_avg:94.53ms +step:750/1705 val_loss:3.5649 train_time:70990ms step_avg:94.65ms +step:751/1705 train_time:71013ms step_avg:94.56ms +step:752/1705 train_time:71089ms step_avg:94.53ms +step:753/1705 train_time:71189ms step_avg:94.54ms +step:754/1705 train_time:71285ms step_avg:94.54ms +step:755/1705 train_time:71378ms step_avg:94.54ms +step:756/1705 train_time:71472ms step_avg:94.54ms +step:757/1705 train_time:71566ms step_avg:94.54ms +step:758/1705 train_time:71659ms step_avg:94.54ms +step:759/1705 train_time:71753ms step_avg:94.54ms +step:760/1705 train_time:71846ms step_avg:94.53ms +step:761/1705 train_time:71940ms step_avg:94.53ms +step:762/1705 train_time:72036ms step_avg:94.54ms +step:763/1705 train_time:72134ms step_avg:94.54ms +step:764/1705 train_time:72230ms step_avg:94.54ms +step:765/1705 train_time:72324ms step_avg:94.54ms +step:766/1705 train_time:72419ms step_avg:94.54ms 
+step:767/1705 train_time:72513ms step_avg:94.54ms +step:768/1705 train_time:72606ms step_avg:94.54ms +step:769/1705 train_time:72700ms step_avg:94.54ms +step:770/1705 train_time:72794ms step_avg:94.54ms +step:771/1705 train_time:72888ms step_avg:94.54ms +step:772/1705 train_time:72982ms step_avg:94.54ms +step:773/1705 train_time:73080ms step_avg:94.54ms +step:774/1705 train_time:73177ms step_avg:94.54ms +step:775/1705 train_time:73273ms step_avg:94.55ms +step:776/1705 train_time:73367ms step_avg:94.54ms +step:777/1705 train_time:73461ms step_avg:94.54ms +step:778/1705 train_time:73555ms step_avg:94.54ms +step:779/1705 train_time:73648ms step_avg:94.54ms +step:780/1705 train_time:73742ms step_avg:94.54ms +step:781/1705 train_time:73837ms step_avg:94.54ms +step:782/1705 train_time:73931ms step_avg:94.54ms +step:783/1705 train_time:74025ms step_avg:94.54ms +step:784/1705 train_time:74122ms step_avg:94.54ms +step:785/1705 train_time:74217ms step_avg:94.54ms +step:786/1705 train_time:74312ms step_avg:94.54ms +step:787/1705 train_time:74406ms step_avg:94.54ms +step:788/1705 train_time:74500ms step_avg:94.54ms +step:789/1705 train_time:74595ms step_avg:94.54ms +step:790/1705 train_time:74688ms step_avg:94.54ms +step:791/1705 train_time:74783ms step_avg:94.54ms +step:792/1705 train_time:74878ms step_avg:94.54ms +step:793/1705 train_time:74973ms step_avg:94.54ms +step:794/1705 train_time:75068ms step_avg:94.54ms +step:795/1705 train_time:75163ms step_avg:94.54ms +step:796/1705 train_time:75258ms step_avg:94.55ms +step:797/1705 train_time:75352ms step_avg:94.54ms +step:798/1705 train_time:75446ms step_avg:94.54ms +step:799/1705 train_time:75540ms step_avg:94.54ms +step:800/1705 train_time:75636ms step_avg:94.54ms +step:801/1705 train_time:75730ms step_avg:94.54ms +step:802/1705 train_time:75824ms step_avg:94.54ms +step:803/1705 train_time:75919ms step_avg:94.54ms +step:804/1705 train_time:76014ms step_avg:94.54ms +step:805/1705 train_time:76108ms step_avg:94.54ms +step:806/1705 train_time:76204ms step_avg:94.55ms +step:807/1705 train_time:76299ms step_avg:94.55ms +step:808/1705 train_time:76393ms step_avg:94.55ms +step:809/1705 train_time:76487ms step_avg:94.55ms +step:810/1705 train_time:76582ms step_avg:94.55ms +step:811/1705 train_time:76676ms step_avg:94.55ms +step:812/1705 train_time:76770ms step_avg:94.54ms +step:813/1705 train_time:76865ms step_avg:94.54ms +step:814/1705 train_time:76960ms step_avg:94.55ms +step:815/1705 train_time:77055ms step_avg:94.55ms +step:816/1705 train_time:77149ms step_avg:94.55ms +step:817/1705 train_time:77244ms step_avg:94.55ms +step:818/1705 train_time:77338ms step_avg:94.55ms +step:819/1705 train_time:77432ms step_avg:94.55ms +step:820/1705 train_time:77526ms step_avg:94.54ms +step:821/1705 train_time:77621ms step_avg:94.54ms +step:822/1705 train_time:77715ms step_avg:94.54ms +step:823/1705 train_time:77809ms step_avg:94.54ms +step:824/1705 train_time:77903ms step_avg:94.54ms +step:825/1705 train_time:77997ms step_avg:94.54ms +step:826/1705 train_time:78093ms step_avg:94.54ms +step:827/1705 train_time:78186ms step_avg:94.54ms +step:828/1705 train_time:78281ms step_avg:94.54ms +step:829/1705 train_time:78375ms step_avg:94.54ms +step:830/1705 train_time:78470ms step_avg:94.54ms +step:831/1705 train_time:78565ms step_avg:94.54ms +step:832/1705 train_time:78659ms step_avg:94.54ms +step:833/1705 train_time:78753ms step_avg:94.54ms +step:834/1705 train_time:78847ms step_avg:94.54ms +step:835/1705 train_time:78941ms step_avg:94.54ms +step:836/1705 train_time:79037ms 
step_avg:94.54ms +step:837/1705 train_time:79131ms step_avg:94.54ms +step:838/1705 train_time:79224ms step_avg:94.54ms +step:839/1705 train_time:79319ms step_avg:94.54ms +step:840/1705 train_time:79413ms step_avg:94.54ms +step:841/1705 train_time:79507ms step_avg:94.54ms +step:842/1705 train_time:79601ms step_avg:94.54ms +step:843/1705 train_time:79697ms step_avg:94.54ms +step:844/1705 train_time:79791ms step_avg:94.54ms +step:845/1705 train_time:79885ms step_avg:94.54ms +step:846/1705 train_time:79980ms step_avg:94.54ms +step:847/1705 train_time:80075ms step_avg:94.54ms +step:848/1705 train_time:80169ms step_avg:94.54ms +step:849/1705 train_time:80263ms step_avg:94.54ms +step:850/1705 train_time:80359ms step_avg:94.54ms +step:851/1705 train_time:80626ms step_avg:94.74ms +step:852/1705 train_time:80792ms step_avg:94.83ms +step:853/1705 train_time:80886ms step_avg:94.82ms +step:854/1705 train_time:80979ms step_avg:94.82ms +step:855/1705 train_time:81073ms step_avg:94.82ms +step:856/1705 train_time:81165ms step_avg:94.82ms +step:857/1705 train_time:81259ms step_avg:94.82ms +step:858/1705 train_time:81353ms step_avg:94.82ms +step:859/1705 train_time:81446ms step_avg:94.82ms +step:860/1705 train_time:81540ms step_avg:94.81ms +step:861/1705 train_time:81636ms step_avg:94.82ms +step:862/1705 train_time:81736ms step_avg:94.82ms +step:863/1705 train_time:81834ms step_avg:94.83ms +step:864/1705 train_time:81928ms step_avg:94.82ms +step:865/1705 train_time:82023ms step_avg:94.82ms +step:866/1705 train_time:82116ms step_avg:94.82ms +step:867/1705 train_time:82210ms step_avg:94.82ms +step:868/1705 train_time:82303ms step_avg:94.82ms +step:869/1705 train_time:82397ms step_avg:94.82ms +step:870/1705 train_time:82490ms step_avg:94.82ms +step:871/1705 train_time:82585ms step_avg:94.82ms +step:872/1705 train_time:82681ms step_avg:94.82ms +step:873/1705 train_time:82780ms step_avg:94.82ms +step:874/1705 train_time:82877ms step_avg:94.82ms +step:875/1705 train_time:82971ms step_avg:94.82ms +step:875/1705 val_loss:3.5231 train_time:83065ms step_avg:94.93ms +step:876/1705 train_time:83088ms step_avg:94.85ms +step:877/1705 train_time:83167ms step_avg:94.83ms +step:878/1705 train_time:83265ms step_avg:94.83ms +step:879/1705 train_time:83360ms step_avg:94.84ms +step:880/1705 train_time:83454ms step_avg:94.83ms +step:881/1705 train_time:83547ms step_avg:94.83ms +step:882/1705 train_time:83640ms step_avg:94.83ms +step:883/1705 train_time:83734ms step_avg:94.83ms +step:884/1705 train_time:83828ms step_avg:94.83ms +step:885/1705 train_time:83921ms step_avg:94.83ms +step:886/1705 train_time:84016ms step_avg:94.83ms +step:887/1705 train_time:84113ms step_avg:94.83ms +step:888/1705 train_time:84210ms step_avg:94.83ms +step:889/1705 train_time:84306ms step_avg:94.83ms +step:890/1705 train_time:84401ms step_avg:94.83ms +step:891/1705 train_time:84495ms step_avg:94.83ms +step:892/1705 train_time:84588ms step_avg:94.83ms +step:893/1705 train_time:84681ms step_avg:94.83ms +step:894/1705 train_time:84775ms step_avg:94.83ms +step:895/1705 train_time:84869ms step_avg:94.83ms +step:896/1705 train_time:84962ms step_avg:94.82ms +step:897/1705 train_time:85058ms step_avg:94.82ms +step:898/1705 train_time:85154ms step_avg:94.83ms +step:899/1705 train_time:85249ms step_avg:94.83ms +step:900/1705 train_time:85343ms step_avg:94.83ms +step:901/1705 train_time:85439ms step_avg:94.83ms +step:902/1705 train_time:85533ms step_avg:94.83ms +step:903/1705 train_time:85627ms step_avg:94.83ms +step:904/1705 train_time:85721ms step_avg:94.82ms 
+step:905/1705 train_time:85815ms step_avg:94.82ms +step:906/1705 train_time:85908ms step_avg:94.82ms +step:907/1705 train_time:86002ms step_avg:94.82ms +step:908/1705 train_time:86098ms step_avg:94.82ms +step:909/1705 train_time:86194ms step_avg:94.82ms +step:910/1705 train_time:86289ms step_avg:94.82ms +step:911/1705 train_time:86384ms step_avg:94.82ms +step:912/1705 train_time:86478ms step_avg:94.82ms +step:913/1705 train_time:86573ms step_avg:94.82ms +step:914/1705 train_time:86667ms step_avg:94.82ms +step:915/1705 train_time:86761ms step_avg:94.82ms +step:916/1705 train_time:86855ms step_avg:94.82ms +step:917/1705 train_time:86949ms step_avg:94.82ms +step:918/1705 train_time:87043ms step_avg:94.82ms +step:919/1705 train_time:87138ms step_avg:94.82ms +step:920/1705 train_time:87234ms step_avg:94.82ms +step:921/1705 train_time:87329ms step_avg:94.82ms +step:922/1705 train_time:87423ms step_avg:94.82ms +step:923/1705 train_time:87518ms step_avg:94.82ms +step:924/1705 train_time:87613ms step_avg:94.82ms +step:925/1705 train_time:87708ms step_avg:94.82ms +step:926/1705 train_time:87801ms step_avg:94.82ms +step:927/1705 train_time:87896ms step_avg:94.82ms +step:928/1705 train_time:87990ms step_avg:94.82ms +step:929/1705 train_time:88084ms step_avg:94.82ms +step:930/1705 train_time:88179ms step_avg:94.82ms +step:931/1705 train_time:88275ms step_avg:94.82ms +step:932/1705 train_time:88370ms step_avg:94.82ms +step:933/1705 train_time:88463ms step_avg:94.82ms +step:934/1705 train_time:88558ms step_avg:94.82ms +step:935/1705 train_time:88653ms step_avg:94.82ms +step:936/1705 train_time:88748ms step_avg:94.82ms +step:937/1705 train_time:88842ms step_avg:94.81ms +step:938/1705 train_time:88936ms step_avg:94.81ms +step:939/1705 train_time:89030ms step_avg:94.81ms +step:940/1705 train_time:89125ms step_avg:94.81ms +step:941/1705 train_time:89219ms step_avg:94.81ms +step:942/1705 train_time:89315ms step_avg:94.81ms +step:943/1705 train_time:89410ms step_avg:94.81ms +step:944/1705 train_time:89504ms step_avg:94.81ms +step:945/1705 train_time:89599ms step_avg:94.81ms +step:946/1705 train_time:89694ms step_avg:94.81ms +step:947/1705 train_time:89788ms step_avg:94.81ms +step:948/1705 train_time:89881ms step_avg:94.81ms +step:949/1705 train_time:89976ms step_avg:94.81ms +step:950/1705 train_time:90070ms step_avg:94.81ms +step:951/1705 train_time:90164ms step_avg:94.81ms +step:952/1705 train_time:90260ms step_avg:94.81ms +step:953/1705 train_time:90354ms step_avg:94.81ms +step:954/1705 train_time:90449ms step_avg:94.81ms +step:955/1705 train_time:90544ms step_avg:94.81ms +step:956/1705 train_time:90639ms step_avg:94.81ms +step:957/1705 train_time:90734ms step_avg:94.81ms +step:958/1705 train_time:90829ms step_avg:94.81ms +step:959/1705 train_time:90922ms step_avg:94.81ms +step:960/1705 train_time:91016ms step_avg:94.81ms +step:961/1705 train_time:91112ms step_avg:94.81ms +step:962/1705 train_time:91206ms step_avg:94.81ms +step:963/1705 train_time:91300ms step_avg:94.81ms +step:964/1705 train_time:91395ms step_avg:94.81ms +step:965/1705 train_time:91490ms step_avg:94.81ms +step:966/1705 train_time:91584ms step_avg:94.81ms +step:967/1705 train_time:91679ms step_avg:94.81ms +step:968/1705 train_time:91773ms step_avg:94.81ms +step:969/1705 train_time:91868ms step_avg:94.81ms +step:970/1705 train_time:91961ms step_avg:94.81ms +step:971/1705 train_time:92055ms step_avg:94.80ms +step:972/1705 train_time:92150ms step_avg:94.80ms +step:973/1705 train_time:92244ms step_avg:94.80ms +step:974/1705 train_time:92339ms 
step_avg:94.80ms +step:975/1705 train_time:92434ms step_avg:94.80ms +step:976/1705 train_time:92528ms step_avg:94.80ms +step:977/1705 train_time:92622ms step_avg:94.80ms +step:978/1705 train_time:92717ms step_avg:94.80ms +step:979/1705 train_time:92812ms step_avg:94.80ms +step:980/1705 train_time:92908ms step_avg:94.80ms +step:981/1705 train_time:93002ms step_avg:94.80ms +step:982/1705 train_time:93096ms step_avg:94.80ms +step:983/1705 train_time:93192ms step_avg:94.80ms +step:984/1705 train_time:93286ms step_avg:94.80ms +step:985/1705 train_time:93380ms step_avg:94.80ms +step:986/1705 train_time:93475ms step_avg:94.80ms +step:987/1705 train_time:93569ms step_avg:94.80ms +step:988/1705 train_time:93663ms step_avg:94.80ms +step:989/1705 train_time:93758ms step_avg:94.80ms +step:990/1705 train_time:93853ms step_avg:94.80ms +step:991/1705 train_time:93948ms step_avg:94.80ms +step:992/1705 train_time:94042ms step_avg:94.80ms +step:993/1705 train_time:94137ms step_avg:94.80ms +step:994/1705 train_time:94231ms step_avg:94.80ms +step:995/1705 train_time:94325ms step_avg:94.80ms +step:996/1705 train_time:94419ms step_avg:94.80ms +step:997/1705 train_time:94515ms step_avg:94.80ms +step:998/1705 train_time:94610ms step_avg:94.80ms +step:999/1705 train_time:94704ms step_avg:94.80ms +step:1000/1705 train_time:94798ms step_avg:94.80ms +step:1000/1705 val_loss:3.4852 train_time:94893ms step_avg:94.89ms +step:1001/1705 train_time:94916ms step_avg:94.82ms +step:1002/1705 train_time:94993ms step_avg:94.80ms +step:1003/1705 train_time:95089ms step_avg:94.80ms +step:1004/1705 train_time:95184ms step_avg:94.80ms +step:1005/1705 train_time:95277ms step_avg:94.80ms +step:1006/1705 train_time:95371ms step_avg:94.80ms +step:1007/1705 train_time:95464ms step_avg:94.80ms +step:1008/1705 train_time:95557ms step_avg:94.80ms +step:1009/1705 train_time:95651ms step_avg:94.80ms +step:1010/1705 train_time:95745ms step_avg:94.80ms +step:1011/1705 train_time:95840ms step_avg:94.80ms +step:1012/1705 train_time:95936ms step_avg:94.80ms +step:1013/1705 train_time:96034ms step_avg:94.80ms +step:1014/1705 train_time:96130ms step_avg:94.80ms +step:1015/1705 train_time:96224ms step_avg:94.80ms +step:1016/1705 train_time:96318ms step_avg:94.80ms +step:1017/1705 train_time:96412ms step_avg:94.80ms +step:1018/1705 train_time:96505ms step_avg:94.80ms +step:1019/1705 train_time:96599ms step_avg:94.80ms +step:1020/1705 train_time:96693ms step_avg:94.80ms +step:1021/1705 train_time:96788ms step_avg:94.80ms +step:1022/1705 train_time:96882ms step_avg:94.80ms +step:1023/1705 train_time:96979ms step_avg:94.80ms +step:1024/1705 train_time:97076ms step_avg:94.80ms +step:1025/1705 train_time:97170ms step_avg:94.80ms +step:1026/1705 train_time:97264ms step_avg:94.80ms +step:1027/1705 train_time:97358ms step_avg:94.80ms +step:1028/1705 train_time:97454ms step_avg:94.80ms +step:1029/1705 train_time:97547ms step_avg:94.80ms +step:1030/1705 train_time:97640ms step_avg:94.80ms +step:1031/1705 train_time:97735ms step_avg:94.80ms +step:1032/1705 train_time:97829ms step_avg:94.80ms +step:1033/1705 train_time:97923ms step_avg:94.79ms +step:1034/1705 train_time:98019ms step_avg:94.80ms +step:1035/1705 train_time:98114ms step_avg:94.80ms +step:1036/1705 train_time:98209ms step_avg:94.80ms +step:1037/1705 train_time:98303ms step_avg:94.80ms +step:1038/1705 train_time:98398ms step_avg:94.80ms +step:1039/1705 train_time:98492ms step_avg:94.80ms +step:1040/1705 train_time:98587ms step_avg:94.80ms +step:1041/1705 train_time:98681ms step_avg:94.79ms 
+step:1042/1705 train_time:98776ms step_avg:94.79ms +step:1043/1705 train_time:98870ms step_avg:94.79ms +step:1044/1705 train_time:98965ms step_avg:94.79ms +step:1045/1705 train_time:99060ms step_avg:94.79ms +step:1046/1705 train_time:99156ms step_avg:94.80ms +step:1047/1705 train_time:99250ms step_avg:94.79ms +step:1048/1705 train_time:99344ms step_avg:94.79ms +step:1049/1705 train_time:99439ms step_avg:94.79ms +step:1050/1705 train_time:99533ms step_avg:94.79ms +step:1051/1705 train_time:99627ms step_avg:94.79ms +step:1052/1705 train_time:99722ms step_avg:94.79ms +step:1053/1705 train_time:99816ms step_avg:94.79ms +step:1054/1705 train_time:99911ms step_avg:94.79ms +step:1055/1705 train_time:100005ms step_avg:94.79ms +step:1056/1705 train_time:100100ms step_avg:94.79ms +step:1057/1705 train_time:100195ms step_avg:94.79ms +step:1058/1705 train_time:100290ms step_avg:94.79ms +step:1059/1705 train_time:100384ms step_avg:94.79ms +step:1060/1705 train_time:100479ms step_avg:94.79ms +step:1061/1705 train_time:100574ms step_avg:94.79ms +step:1062/1705 train_time:100835ms step_avg:94.95ms +step:1063/1705 train_time:100938ms step_avg:94.96ms +step:1064/1705 train_time:101031ms step_avg:94.95ms +step:1065/1705 train_time:101124ms step_avg:94.95ms +step:1066/1705 train_time:101217ms step_avg:94.95ms +step:1067/1705 train_time:101311ms step_avg:94.95ms +step:1068/1705 train_time:101404ms step_avg:94.95ms +step:1069/1705 train_time:101498ms step_avg:94.95ms +step:1070/1705 train_time:101591ms step_avg:94.95ms +step:1071/1705 train_time:101685ms step_avg:94.94ms +step:1072/1705 train_time:101780ms step_avg:94.94ms +step:1073/1705 train_time:101879ms step_avg:94.95ms +step:1074/1705 train_time:101976ms step_avg:94.95ms +step:1075/1705 train_time:102072ms step_avg:94.95ms +step:1076/1705 train_time:102165ms step_avg:94.95ms +step:1077/1705 train_time:102259ms step_avg:94.95ms +step:1078/1705 train_time:102353ms step_avg:94.95ms +step:1079/1705 train_time:102447ms step_avg:94.95ms +step:1080/1705 train_time:102541ms step_avg:94.95ms +step:1081/1705 train_time:102634ms step_avg:94.94ms +step:1082/1705 train_time:102729ms step_avg:94.94ms +step:1083/1705 train_time:102824ms step_avg:94.94ms +step:1084/1705 train_time:102921ms step_avg:94.95ms +step:1085/1705 train_time:103018ms step_avg:94.95ms +step:1086/1705 train_time:103113ms step_avg:94.95ms +step:1087/1705 train_time:103207ms step_avg:94.95ms +step:1088/1705 train_time:103302ms step_avg:94.95ms +step:1089/1705 train_time:103396ms step_avg:94.95ms +step:1090/1705 train_time:103491ms step_avg:94.95ms +step:1091/1705 train_time:103585ms step_avg:94.94ms +step:1092/1705 train_time:103679ms step_avg:94.94ms +step:1093/1705 train_time:103774ms step_avg:94.94ms +step:1094/1705 train_time:103869ms step_avg:94.94ms +step:1095/1705 train_time:103964ms step_avg:94.94ms +step:1096/1705 train_time:104059ms step_avg:94.94ms +step:1097/1705 train_time:104155ms step_avg:94.95ms +step:1098/1705 train_time:104249ms step_avg:94.94ms +step:1099/1705 train_time:104343ms step_avg:94.94ms +step:1100/1705 train_time:104437ms step_avg:94.94ms +step:1101/1705 train_time:104531ms step_avg:94.94ms +step:1102/1705 train_time:104625ms step_avg:94.94ms +step:1103/1705 train_time:104719ms step_avg:94.94ms +step:1104/1705 train_time:104814ms step_avg:94.94ms +step:1105/1705 train_time:104909ms step_avg:94.94ms +step:1106/1705 train_time:105004ms step_avg:94.94ms +step:1107/1705 train_time:105098ms step_avg:94.94ms +step:1108/1705 train_time:105193ms step_avg:94.94ms +step:1109/1705 
train_time:105287ms step_avg:94.94ms +step:1110/1705 train_time:105382ms step_avg:94.94ms +step:1111/1705 train_time:105477ms step_avg:94.94ms +step:1112/1705 train_time:105571ms step_avg:94.94ms +step:1113/1705 train_time:105665ms step_avg:94.94ms +step:1114/1705 train_time:105760ms step_avg:94.94ms +step:1115/1705 train_time:105855ms step_avg:94.94ms +step:1116/1705 train_time:105949ms step_avg:94.94ms +step:1117/1705 train_time:106044ms step_avg:94.94ms +step:1118/1705 train_time:106139ms step_avg:94.94ms +step:1119/1705 train_time:106234ms step_avg:94.94ms +step:1120/1705 train_time:106328ms step_avg:94.94ms +step:1121/1705 train_time:106422ms step_avg:94.93ms +step:1122/1705 train_time:106517ms step_avg:94.94ms +step:1123/1705 train_time:106612ms step_avg:94.93ms +step:1124/1705 train_time:106705ms step_avg:94.93ms +step:1125/1705 train_time:106800ms step_avg:94.93ms +step:1125/1705 val_loss:3.4369 train_time:106895ms step_avg:95.02ms +step:1126/1705 train_time:106918ms step_avg:94.95ms +step:1127/1705 train_time:106995ms step_avg:94.94ms +step:1128/1705 train_time:107092ms step_avg:94.94ms +step:1129/1705 train_time:107187ms step_avg:94.94ms +step:1130/1705 train_time:107281ms step_avg:94.94ms +step:1131/1705 train_time:107374ms step_avg:94.94ms +step:1132/1705 train_time:107468ms step_avg:94.94ms +step:1133/1705 train_time:107562ms step_avg:94.94ms +step:1134/1705 train_time:107656ms step_avg:94.94ms +step:1135/1705 train_time:107750ms step_avg:94.93ms +step:1136/1705 train_time:107844ms step_avg:94.93ms +step:1137/1705 train_time:107940ms step_avg:94.93ms +step:1138/1705 train_time:108038ms step_avg:94.94ms +step:1139/1705 train_time:108134ms step_avg:94.94ms +step:1140/1705 train_time:108229ms step_avg:94.94ms +step:1141/1705 train_time:108324ms step_avg:94.94ms +step:1142/1705 train_time:108419ms step_avg:94.94ms +step:1143/1705 train_time:108513ms step_avg:94.94ms +step:1144/1705 train_time:108607ms step_avg:94.94ms +step:1145/1705 train_time:108701ms step_avg:94.94ms +step:1146/1705 train_time:108796ms step_avg:94.94ms +step:1147/1705 train_time:108891ms step_avg:94.94ms +step:1148/1705 train_time:108988ms step_avg:94.94ms +step:1149/1705 train_time:109084ms step_avg:94.94ms +step:1150/1705 train_time:109180ms step_avg:94.94ms +step:1151/1705 train_time:109277ms step_avg:94.94ms +step:1152/1705 train_time:109372ms step_avg:94.94ms +step:1153/1705 train_time:109466ms step_avg:94.94ms +step:1154/1705 train_time:109561ms step_avg:94.94ms +step:1155/1705 train_time:109656ms step_avg:94.94ms +step:1156/1705 train_time:109750ms step_avg:94.94ms +step:1157/1705 train_time:109845ms step_avg:94.94ms +step:1158/1705 train_time:109941ms step_avg:94.94ms +step:1159/1705 train_time:110038ms step_avg:94.94ms +step:1160/1705 train_time:110134ms step_avg:94.94ms +step:1161/1705 train_time:110230ms step_avg:94.94ms +step:1162/1705 train_time:110325ms step_avg:94.94ms +step:1163/1705 train_time:110421ms step_avg:94.95ms +step:1164/1705 train_time:110517ms step_avg:94.95ms +step:1165/1705 train_time:110611ms step_avg:94.95ms +step:1166/1705 train_time:110706ms step_avg:94.94ms +step:1167/1705 train_time:110801ms step_avg:94.95ms +step:1168/1705 train_time:110896ms step_avg:94.95ms +step:1169/1705 train_time:110992ms step_avg:94.95ms +step:1170/1705 train_time:111087ms step_avg:94.95ms +step:1171/1705 train_time:111183ms step_avg:94.95ms +step:1172/1705 train_time:111280ms step_avg:94.95ms +step:1173/1705 train_time:111376ms step_avg:94.95ms +step:1174/1705 train_time:111472ms step_avg:94.95ms 
+step:1175/1705 train_time:111567ms step_avg:94.95ms +step:1176/1705 train_time:111662ms step_avg:94.95ms +step:1177/1705 train_time:111756ms step_avg:94.95ms +step:1178/1705 train_time:111851ms step_avg:94.95ms +step:1179/1705 train_time:111946ms step_avg:94.95ms +step:1180/1705 train_time:112042ms step_avg:94.95ms +step:1181/1705 train_time:112137ms step_avg:94.95ms +step:1182/1705 train_time:112232ms step_avg:94.95ms +step:1183/1705 train_time:112328ms step_avg:94.95ms +step:1184/1705 train_time:112423ms step_avg:94.95ms +step:1185/1705 train_time:112519ms step_avg:94.95ms +step:1186/1705 train_time:112614ms step_avg:94.95ms +step:1187/1705 train_time:112709ms step_avg:94.95ms +step:1188/1705 train_time:112804ms step_avg:94.95ms +step:1189/1705 train_time:112900ms step_avg:94.95ms +step:1190/1705 train_time:112996ms step_avg:94.95ms +step:1191/1705 train_time:113091ms step_avg:94.95ms +step:1192/1705 train_time:113186ms step_avg:94.95ms +step:1193/1705 train_time:113282ms step_avg:94.96ms +step:1194/1705 train_time:113379ms step_avg:94.96ms +step:1195/1705 train_time:113474ms step_avg:94.96ms +step:1196/1705 train_time:113569ms step_avg:94.96ms +step:1197/1705 train_time:113664ms step_avg:94.96ms +step:1198/1705 train_time:113759ms step_avg:94.96ms +step:1199/1705 train_time:113854ms step_avg:94.96ms +step:1200/1705 train_time:113949ms step_avg:94.96ms +step:1201/1705 train_time:114044ms step_avg:94.96ms +step:1202/1705 train_time:114140ms step_avg:94.96ms +step:1203/1705 train_time:114235ms step_avg:94.96ms +step:1204/1705 train_time:114331ms step_avg:94.96ms +step:1205/1705 train_time:114426ms step_avg:94.96ms +step:1206/1705 train_time:114522ms step_avg:94.96ms +step:1207/1705 train_time:114617ms step_avg:94.96ms +step:1208/1705 train_time:114712ms step_avg:94.96ms +step:1209/1705 train_time:114807ms step_avg:94.96ms +step:1210/1705 train_time:114903ms step_avg:94.96ms +step:1211/1705 train_time:114998ms step_avg:94.96ms +step:1212/1705 train_time:115094ms step_avg:94.96ms +step:1213/1705 train_time:115189ms step_avg:94.96ms +step:1214/1705 train_time:115284ms step_avg:94.96ms +step:1215/1705 train_time:115381ms step_avg:94.96ms +step:1216/1705 train_time:115477ms step_avg:94.96ms +step:1217/1705 train_time:115573ms step_avg:94.97ms +step:1218/1705 train_time:115667ms step_avg:94.97ms +step:1219/1705 train_time:115762ms step_avg:94.96ms +step:1220/1705 train_time:115858ms step_avg:94.97ms +step:1221/1705 train_time:115953ms step_avg:94.97ms +step:1222/1705 train_time:116048ms step_avg:94.97ms +step:1223/1705 train_time:116142ms step_avg:94.97ms +step:1224/1705 train_time:116238ms step_avg:94.97ms +step:1225/1705 train_time:116335ms step_avg:94.97ms +step:1226/1705 train_time:116431ms step_avg:94.97ms +step:1227/1705 train_time:116526ms step_avg:94.97ms +step:1228/1705 train_time:116621ms step_avg:94.97ms +step:1229/1705 train_time:116717ms step_avg:94.97ms +step:1230/1705 train_time:116812ms step_avg:94.97ms +step:1231/1705 train_time:116907ms step_avg:94.97ms +step:1232/1705 train_time:117002ms step_avg:94.97ms +step:1233/1705 train_time:117098ms step_avg:94.97ms +step:1234/1705 train_time:117193ms step_avg:94.97ms +step:1235/1705 train_time:117287ms step_avg:94.97ms +step:1236/1705 train_time:117383ms step_avg:94.97ms +step:1237/1705 train_time:117479ms step_avg:94.97ms +step:1238/1705 train_time:117574ms step_avg:94.97ms +step:1239/1705 train_time:117669ms step_avg:94.97ms +step:1240/1705 train_time:117764ms step_avg:94.97ms +step:1241/1705 train_time:117859ms step_avg:94.97ms 
+step:1242/1705 train_time:117954ms step_avg:94.97ms +step:1243/1705 train_time:118049ms step_avg:94.97ms +step:1244/1705 train_time:118145ms step_avg:94.97ms +step:1245/1705 train_time:118240ms step_avg:94.97ms +step:1246/1705 train_time:118335ms step_avg:94.97ms +step:1247/1705 train_time:118431ms step_avg:94.97ms +step:1248/1705 train_time:118526ms step_avg:94.97ms +step:1249/1705 train_time:118621ms step_avg:94.97ms +step:1250/1705 train_time:118717ms step_avg:94.97ms +step:1250/1705 val_loss:3.3885 train_time:118813ms step_avg:95.05ms +step:1251/1705 train_time:118836ms step_avg:94.99ms +step:1252/1705 train_time:118909ms step_avg:94.98ms +step:1253/1705 train_time:119006ms step_avg:94.98ms +step:1254/1705 train_time:119102ms step_avg:94.98ms +step:1255/1705 train_time:119197ms step_avg:94.98ms +step:1256/1705 train_time:119291ms step_avg:94.98ms +step:1257/1705 train_time:119386ms step_avg:94.98ms +step:1258/1705 train_time:119480ms step_avg:94.98ms +step:1259/1705 train_time:119573ms step_avg:94.97ms +step:1260/1705 train_time:119667ms step_avg:94.97ms +step:1261/1705 train_time:119767ms step_avg:94.98ms +step:1262/1705 train_time:119865ms step_avg:94.98ms +step:1263/1705 train_time:119962ms step_avg:94.98ms +step:1264/1705 train_time:120057ms step_avg:94.98ms +step:1265/1705 train_time:120152ms step_avg:94.98ms +step:1266/1705 train_time:120246ms step_avg:94.98ms +step:1267/1705 train_time:120341ms step_avg:94.98ms +step:1268/1705 train_time:120436ms step_avg:94.98ms +step:1269/1705 train_time:120530ms step_avg:94.98ms +step:1270/1705 train_time:120625ms step_avg:94.98ms +step:1271/1705 train_time:120720ms step_avg:94.98ms +step:1272/1705 train_time:120817ms step_avg:94.98ms +step:1273/1705 train_time:120914ms step_avg:94.98ms +step:1274/1705 train_time:121289ms step_avg:95.20ms +step:1275/1705 train_time:121374ms step_avg:95.20ms +step:1276/1705 train_time:121468ms step_avg:95.19ms +step:1277/1705 train_time:121562ms step_avg:95.19ms +step:1278/1705 train_time:121656ms step_avg:95.19ms +step:1279/1705 train_time:121749ms step_avg:95.19ms +step:1280/1705 train_time:121844ms step_avg:95.19ms +step:1281/1705 train_time:121938ms step_avg:95.19ms +step:1282/1705 train_time:122033ms step_avg:95.19ms +step:1283/1705 train_time:122127ms step_avg:95.19ms +step:1284/1705 train_time:122226ms step_avg:95.19ms +step:1285/1705 train_time:122327ms step_avg:95.20ms +step:1286/1705 train_time:122423ms step_avg:95.20ms +step:1287/1705 train_time:122519ms step_avg:95.20ms +step:1288/1705 train_time:122614ms step_avg:95.20ms +step:1289/1705 train_time:122708ms step_avg:95.20ms +step:1290/1705 train_time:122803ms step_avg:95.20ms +step:1291/1705 train_time:122898ms step_avg:95.20ms +step:1292/1705 train_time:122991ms step_avg:95.19ms +step:1293/1705 train_time:123086ms step_avg:95.19ms +step:1294/1705 train_time:123183ms step_avg:95.20ms +step:1295/1705 train_time:123281ms step_avg:95.20ms +step:1296/1705 train_time:123377ms step_avg:95.20ms +step:1297/1705 train_time:123472ms step_avg:95.20ms +step:1298/1705 train_time:123568ms step_avg:95.20ms +step:1299/1705 train_time:123664ms step_avg:95.20ms +step:1300/1705 train_time:123760ms step_avg:95.20ms +step:1301/1705 train_time:123855ms step_avg:95.20ms +step:1302/1705 train_time:123948ms step_avg:95.20ms +step:1303/1705 train_time:124042ms step_avg:95.20ms +step:1304/1705 train_time:124138ms step_avg:95.20ms +step:1305/1705 train_time:124235ms step_avg:95.20ms +step:1306/1705 train_time:124331ms step_avg:95.20ms +step:1307/1705 train_time:124427ms 
step_avg:95.20ms +step:1308/1705 train_time:124523ms step_avg:95.20ms +step:1309/1705 train_time:124618ms step_avg:95.20ms +step:1310/1705 train_time:124713ms step_avg:95.20ms +step:1311/1705 train_time:124809ms step_avg:95.20ms +step:1312/1705 train_time:124904ms step_avg:95.20ms +step:1313/1705 train_time:124998ms step_avg:95.20ms +step:1314/1705 train_time:125093ms step_avg:95.20ms +step:1315/1705 train_time:125187ms step_avg:95.20ms +step:1316/1705 train_time:125285ms step_avg:95.20ms +step:1317/1705 train_time:125382ms step_avg:95.20ms +step:1318/1705 train_time:125478ms step_avg:95.20ms +step:1319/1705 train_time:125572ms step_avg:95.20ms +step:1320/1705 train_time:125668ms step_avg:95.20ms +step:1321/1705 train_time:125763ms step_avg:95.20ms +step:1322/1705 train_time:125859ms step_avg:95.20ms +step:1323/1705 train_time:125953ms step_avg:95.20ms +step:1324/1705 train_time:126048ms step_avg:95.20ms +step:1325/1705 train_time:126143ms step_avg:95.20ms +step:1326/1705 train_time:126238ms step_avg:95.20ms +step:1327/1705 train_time:126333ms step_avg:95.20ms +step:1328/1705 train_time:126429ms step_avg:95.20ms +step:1329/1705 train_time:126525ms step_avg:95.20ms +step:1330/1705 train_time:126621ms step_avg:95.20ms +step:1331/1705 train_time:126716ms step_avg:95.20ms +step:1332/1705 train_time:126811ms step_avg:95.20ms +step:1333/1705 train_time:126906ms step_avg:95.20ms +step:1334/1705 train_time:127001ms step_avg:95.20ms +step:1335/1705 train_time:127097ms step_avg:95.20ms +step:1336/1705 train_time:127192ms step_avg:95.20ms +step:1337/1705 train_time:127288ms step_avg:95.20ms +step:1338/1705 train_time:127383ms step_avg:95.20ms +step:1339/1705 train_time:127480ms step_avg:95.21ms +step:1340/1705 train_time:127576ms step_avg:95.21ms +step:1341/1705 train_time:127671ms step_avg:95.21ms +step:1342/1705 train_time:127766ms step_avg:95.21ms +step:1343/1705 train_time:127861ms step_avg:95.21ms +step:1344/1705 train_time:127955ms step_avg:95.20ms +step:1345/1705 train_time:128050ms step_avg:95.20ms +step:1346/1705 train_time:128146ms step_avg:95.21ms +step:1347/1705 train_time:128242ms step_avg:95.21ms +step:1348/1705 train_time:128337ms step_avg:95.21ms +step:1349/1705 train_time:128433ms step_avg:95.21ms +step:1350/1705 train_time:128528ms step_avg:95.21ms +step:1351/1705 train_time:128625ms step_avg:95.21ms +step:1352/1705 train_time:128720ms step_avg:95.21ms +step:1353/1705 train_time:128816ms step_avg:95.21ms +step:1354/1705 train_time:128911ms step_avg:95.21ms +step:1355/1705 train_time:129005ms step_avg:95.21ms +step:1356/1705 train_time:129101ms step_avg:95.21ms +step:1357/1705 train_time:129196ms step_avg:95.21ms +step:1358/1705 train_time:129291ms step_avg:95.21ms +step:1359/1705 train_time:129387ms step_avg:95.21ms +step:1360/1705 train_time:129483ms step_avg:95.21ms +step:1361/1705 train_time:129577ms step_avg:95.21ms +step:1362/1705 train_time:129672ms step_avg:95.21ms +step:1363/1705 train_time:129768ms step_avg:95.21ms +step:1364/1705 train_time:129863ms step_avg:95.21ms +step:1365/1705 train_time:129958ms step_avg:95.21ms +step:1366/1705 train_time:130053ms step_avg:95.21ms +step:1367/1705 train_time:130149ms step_avg:95.21ms +step:1368/1705 train_time:130244ms step_avg:95.21ms +step:1369/1705 train_time:130341ms step_avg:95.21ms +step:1370/1705 train_time:130436ms step_avg:95.21ms +step:1371/1705 train_time:130531ms step_avg:95.21ms +step:1372/1705 train_time:130626ms step_avg:95.21ms +step:1373/1705 train_time:130721ms step_avg:95.21ms +step:1374/1705 train_time:130817ms 
step_avg:95.21ms +step:1375/1705 train_time:130912ms step_avg:95.21ms +step:1375/1705 val_loss:3.3516 train_time:131007ms step_avg:95.28ms +step:1376/1705 train_time:131030ms step_avg:95.23ms +step:1377/1705 train_time:131111ms step_avg:95.22ms +step:1378/1705 train_time:131212ms step_avg:95.22ms +step:1379/1705 train_time:131307ms step_avg:95.22ms +step:1380/1705 train_time:131402ms step_avg:95.22ms +step:1381/1705 train_time:131497ms step_avg:95.22ms +step:1382/1705 train_time:131590ms step_avg:95.22ms +step:1383/1705 train_time:131685ms step_avg:95.22ms +step:1384/1705 train_time:131780ms step_avg:95.22ms +step:1385/1705 train_time:131874ms step_avg:95.22ms +step:1386/1705 train_time:131970ms step_avg:95.22ms +step:1387/1705 train_time:132070ms step_avg:95.22ms +step:1388/1705 train_time:132167ms step_avg:95.22ms +step:1389/1705 train_time:132264ms step_avg:95.22ms +step:1390/1705 train_time:132359ms step_avg:95.22ms +step:1391/1705 train_time:132454ms step_avg:95.22ms +step:1392/1705 train_time:132549ms step_avg:95.22ms +step:1393/1705 train_time:132643ms step_avg:95.22ms +step:1394/1705 train_time:132737ms step_avg:95.22ms +step:1395/1705 train_time:132832ms step_avg:95.22ms +step:1396/1705 train_time:132927ms step_avg:95.22ms +step:1397/1705 train_time:133023ms step_avg:95.22ms +step:1398/1705 train_time:133119ms step_avg:95.22ms +step:1399/1705 train_time:133216ms step_avg:95.22ms +step:1400/1705 train_time:133313ms step_avg:95.22ms +step:1401/1705 train_time:133408ms step_avg:95.22ms +step:1402/1705 train_time:133503ms step_avg:95.22ms +step:1403/1705 train_time:133598ms step_avg:95.22ms +step:1404/1705 train_time:133693ms step_avg:95.22ms +step:1405/1705 train_time:133788ms step_avg:95.22ms +step:1406/1705 train_time:133882ms step_avg:95.22ms +step:1407/1705 train_time:133977ms step_avg:95.22ms +step:1408/1705 train_time:134074ms step_avg:95.22ms +step:1409/1705 train_time:134170ms step_avg:95.22ms +step:1410/1705 train_time:134267ms step_avg:95.22ms +step:1411/1705 train_time:134362ms step_avg:95.22ms +step:1412/1705 train_time:134458ms step_avg:95.22ms +step:1413/1705 train_time:134552ms step_avg:95.22ms +step:1414/1705 train_time:134647ms step_avg:95.22ms +step:1415/1705 train_time:134743ms step_avg:95.22ms +step:1416/1705 train_time:134838ms step_avg:95.22ms +step:1417/1705 train_time:134932ms step_avg:95.22ms +step:1418/1705 train_time:135028ms step_avg:95.22ms +step:1419/1705 train_time:135124ms step_avg:95.22ms +step:1420/1705 train_time:135219ms step_avg:95.22ms +step:1421/1705 train_time:135316ms step_avg:95.23ms +step:1422/1705 train_time:135412ms step_avg:95.23ms +step:1423/1705 train_time:135507ms step_avg:95.23ms +step:1424/1705 train_time:135601ms step_avg:95.23ms +step:1425/1705 train_time:135696ms step_avg:95.23ms +step:1426/1705 train_time:135791ms step_avg:95.23ms +step:1427/1705 train_time:135887ms step_avg:95.23ms +step:1428/1705 train_time:135982ms step_avg:95.23ms +step:1429/1705 train_time:136077ms step_avg:95.23ms +step:1430/1705 train_time:136173ms step_avg:95.23ms +step:1431/1705 train_time:136271ms step_avg:95.23ms +step:1432/1705 train_time:136367ms step_avg:95.23ms +step:1433/1705 train_time:136462ms step_avg:95.23ms +step:1434/1705 train_time:136557ms step_avg:95.23ms +step:1435/1705 train_time:136652ms step_avg:95.23ms +step:1436/1705 train_time:136747ms step_avg:95.23ms +step:1437/1705 train_time:136844ms step_avg:95.23ms +step:1438/1705 train_time:136937ms step_avg:95.23ms +step:1439/1705 train_time:137031ms step_avg:95.23ms +step:1440/1705 
train_time:137127ms step_avg:95.23ms +step:1441/1705 train_time:137223ms step_avg:95.23ms +step:1442/1705 train_time:137319ms step_avg:95.23ms +step:1443/1705 train_time:137416ms step_avg:95.23ms +step:1444/1705 train_time:137511ms step_avg:95.23ms +step:1445/1705 train_time:137606ms step_avg:95.23ms +step:1446/1705 train_time:137701ms step_avg:95.23ms +step:1447/1705 train_time:137797ms step_avg:95.23ms +step:1448/1705 train_time:137892ms step_avg:95.23ms +step:1449/1705 train_time:137988ms step_avg:95.23ms +step:1450/1705 train_time:138084ms step_avg:95.23ms +step:1451/1705 train_time:138179ms step_avg:95.23ms +step:1452/1705 train_time:138275ms step_avg:95.23ms +step:1453/1705 train_time:138370ms step_avg:95.23ms +step:1454/1705 train_time:138466ms step_avg:95.23ms +step:1455/1705 train_time:138561ms step_avg:95.23ms +step:1456/1705 train_time:138655ms step_avg:95.23ms +step:1457/1705 train_time:138751ms step_avg:95.23ms +step:1458/1705 train_time:138847ms step_avg:95.23ms +step:1459/1705 train_time:138941ms step_avg:95.23ms +step:1460/1705 train_time:139037ms step_avg:95.23ms +step:1461/1705 train_time:139132ms step_avg:95.23ms +step:1462/1705 train_time:139230ms step_avg:95.23ms +step:1463/1705 train_time:139325ms step_avg:95.23ms +step:1464/1705 train_time:139420ms step_avg:95.23ms +step:1465/1705 train_time:139516ms step_avg:95.23ms +step:1466/1705 train_time:139611ms step_avg:95.23ms +step:1467/1705 train_time:139707ms step_avg:95.23ms +step:1468/1705 train_time:139802ms step_avg:95.23ms +step:1469/1705 train_time:139897ms step_avg:95.23ms +step:1470/1705 train_time:139992ms step_avg:95.23ms +step:1471/1705 train_time:140087ms step_avg:95.23ms +step:1472/1705 train_time:140183ms step_avg:95.23ms +step:1473/1705 train_time:140278ms step_avg:95.23ms +step:1474/1705 train_time:140372ms step_avg:95.23ms +step:1475/1705 train_time:140468ms step_avg:95.23ms +step:1476/1705 train_time:140563ms step_avg:95.23ms +step:1477/1705 train_time:140657ms step_avg:95.23ms +step:1478/1705 train_time:140754ms step_avg:95.23ms +step:1479/1705 train_time:140849ms step_avg:95.23ms +step:1480/1705 train_time:140944ms step_avg:95.23ms +step:1481/1705 train_time:141039ms step_avg:95.23ms +step:1482/1705 train_time:141134ms step_avg:95.23ms +step:1483/1705 train_time:141231ms step_avg:95.23ms +step:1484/1705 train_time:141326ms step_avg:95.23ms +step:1485/1705 train_time:141701ms step_avg:95.42ms +step:1486/1705 train_time:141777ms step_avg:95.41ms +step:1487/1705 train_time:141871ms step_avg:95.41ms +step:1488/1705 train_time:141965ms step_avg:95.41ms +step:1489/1705 train_time:142059ms step_avg:95.41ms +step:1490/1705 train_time:142153ms step_avg:95.40ms +step:1491/1705 train_time:142247ms step_avg:95.40ms +step:1492/1705 train_time:142340ms step_avg:95.40ms +step:1493/1705 train_time:142435ms step_avg:95.40ms +step:1494/1705 train_time:142529ms step_avg:95.40ms +step:1495/1705 train_time:142628ms step_avg:95.40ms +step:1496/1705 train_time:142729ms step_avg:95.41ms +step:1497/1705 train_time:142826ms step_avg:95.41ms +step:1498/1705 train_time:142921ms step_avg:95.41ms +step:1499/1705 train_time:143017ms step_avg:95.41ms +step:1500/1705 train_time:143112ms step_avg:95.41ms +step:1500/1705 val_loss:3.3198 train_time:143207ms step_avg:95.47ms +step:1501/1705 train_time:143230ms step_avg:95.42ms +step:1502/1705 train_time:143309ms step_avg:95.41ms +step:1503/1705 train_time:143408ms step_avg:95.41ms +step:1504/1705 train_time:143504ms step_avg:95.41ms +step:1505/1705 train_time:143598ms step_avg:95.41ms 
+step:1506/1705 train_time:143692ms step_avg:95.41ms +step:1507/1705 train_time:143787ms step_avg:95.41ms +step:1508/1705 train_time:143882ms step_avg:95.41ms +step:1509/1705 train_time:143976ms step_avg:95.41ms +step:1510/1705 train_time:144070ms step_avg:95.41ms +step:1511/1705 train_time:144167ms step_avg:95.41ms +step:1512/1705 train_time:144264ms step_avg:95.41ms +step:1513/1705 train_time:144363ms step_avg:95.41ms +step:1514/1705 train_time:144459ms step_avg:95.42ms +step:1515/1705 train_time:144555ms step_avg:95.42ms +step:1516/1705 train_time:144649ms step_avg:95.41ms +step:1517/1705 train_time:144744ms step_avg:95.41ms +step:1518/1705 train_time:144838ms step_avg:95.41ms +step:1519/1705 train_time:144932ms step_avg:95.41ms +step:1520/1705 train_time:145027ms step_avg:95.41ms +step:1521/1705 train_time:145123ms step_avg:95.41ms +step:1522/1705 train_time:145222ms step_avg:95.42ms +step:1523/1705 train_time:145319ms step_avg:95.42ms +step:1524/1705 train_time:145416ms step_avg:95.42ms +step:1525/1705 train_time:145511ms step_avg:95.42ms +step:1526/1705 train_time:145607ms step_avg:95.42ms +step:1527/1705 train_time:145702ms step_avg:95.42ms +step:1528/1705 train_time:145797ms step_avg:95.42ms +step:1529/1705 train_time:145891ms step_avg:95.42ms +step:1530/1705 train_time:145985ms step_avg:95.42ms +step:1531/1705 train_time:146080ms step_avg:95.41ms +step:1532/1705 train_time:146176ms step_avg:95.42ms +step:1533/1705 train_time:146272ms step_avg:95.42ms +step:1534/1705 train_time:146368ms step_avg:95.42ms +step:1535/1705 train_time:146463ms step_avg:95.42ms +step:1536/1705 train_time:146561ms step_avg:95.42ms +step:1537/1705 train_time:146656ms step_avg:95.42ms +step:1538/1705 train_time:146751ms step_avg:95.42ms +step:1539/1705 train_time:146846ms step_avg:95.42ms +step:1540/1705 train_time:146941ms step_avg:95.42ms +step:1541/1705 train_time:147036ms step_avg:95.42ms +step:1542/1705 train_time:147130ms step_avg:95.42ms +step:1543/1705 train_time:147226ms step_avg:95.42ms +step:1544/1705 train_time:147323ms step_avg:95.42ms +step:1545/1705 train_time:147420ms step_avg:95.42ms +step:1546/1705 train_time:147516ms step_avg:95.42ms +step:1547/1705 train_time:147613ms step_avg:95.42ms +step:1548/1705 train_time:147707ms step_avg:95.42ms +step:1549/1705 train_time:147803ms step_avg:95.42ms +step:1550/1705 train_time:147898ms step_avg:95.42ms +step:1551/1705 train_time:147994ms step_avg:95.42ms +step:1552/1705 train_time:148089ms step_avg:95.42ms +step:1553/1705 train_time:148183ms step_avg:95.42ms +step:1554/1705 train_time:148278ms step_avg:95.42ms +step:1555/1705 train_time:148374ms step_avg:95.42ms +step:1556/1705 train_time:148470ms step_avg:95.42ms +step:1557/1705 train_time:148567ms step_avg:95.42ms +step:1558/1705 train_time:148663ms step_avg:95.42ms +step:1559/1705 train_time:148759ms step_avg:95.42ms +step:1560/1705 train_time:148854ms step_avg:95.42ms +step:1561/1705 train_time:148948ms step_avg:95.42ms +step:1562/1705 train_time:149043ms step_avg:95.42ms +step:1563/1705 train_time:149138ms step_avg:95.42ms +step:1564/1705 train_time:149233ms step_avg:95.42ms +step:1565/1705 train_time:149328ms step_avg:95.42ms +step:1566/1705 train_time:149424ms step_avg:95.42ms +step:1567/1705 train_time:149520ms step_avg:95.42ms +step:1568/1705 train_time:149617ms step_avg:95.42ms +step:1569/1705 train_time:149713ms step_avg:95.42ms +step:1570/1705 train_time:149808ms step_avg:95.42ms +step:1571/1705 train_time:149902ms step_avg:95.42ms +step:1572/1705 train_time:149997ms step_avg:95.42ms 
+step:1573/1705 train_time:150092ms step_avg:95.42ms +step:1574/1705 train_time:150187ms step_avg:95.42ms +step:1575/1705 train_time:150282ms step_avg:95.42ms +step:1576/1705 train_time:150377ms step_avg:95.42ms +step:1577/1705 train_time:150472ms step_avg:95.42ms +step:1578/1705 train_time:150568ms step_avg:95.42ms +step:1579/1705 train_time:150664ms step_avg:95.42ms +step:1580/1705 train_time:150760ms step_avg:95.42ms +step:1581/1705 train_time:150855ms step_avg:95.42ms +step:1582/1705 train_time:150950ms step_avg:95.42ms +step:1583/1705 train_time:151045ms step_avg:95.42ms +step:1584/1705 train_time:151140ms step_avg:95.42ms +step:1585/1705 train_time:151236ms step_avg:95.42ms +step:1586/1705 train_time:151330ms step_avg:95.42ms +step:1587/1705 train_time:151426ms step_avg:95.42ms +step:1588/1705 train_time:151522ms step_avg:95.42ms +step:1589/1705 train_time:151618ms step_avg:95.42ms +step:1590/1705 train_time:151714ms step_avg:95.42ms +step:1591/1705 train_time:151809ms step_avg:95.42ms +step:1592/1705 train_time:151904ms step_avg:95.42ms +step:1593/1705 train_time:152000ms step_avg:95.42ms +step:1594/1705 train_time:152096ms step_avg:95.42ms +step:1595/1705 train_time:152191ms step_avg:95.42ms +step:1596/1705 train_time:152285ms step_avg:95.42ms +step:1597/1705 train_time:152381ms step_avg:95.42ms +step:1598/1705 train_time:152477ms step_avg:95.42ms +step:1599/1705 train_time:152572ms step_avg:95.42ms +step:1600/1705 train_time:152667ms step_avg:95.42ms +step:1601/1705 train_time:152763ms step_avg:95.42ms +step:1602/1705 train_time:152860ms step_avg:95.42ms +step:1603/1705 train_time:152955ms step_avg:95.42ms +step:1604/1705 train_time:153050ms step_avg:95.42ms +step:1605/1705 train_time:153145ms step_avg:95.42ms +step:1606/1705 train_time:153241ms step_avg:95.42ms +step:1607/1705 train_time:153338ms step_avg:95.42ms +step:1608/1705 train_time:153434ms step_avg:95.42ms +step:1609/1705 train_time:153529ms step_avg:95.42ms +step:1610/1705 train_time:153625ms step_avg:95.42ms +step:1611/1705 train_time:153721ms step_avg:95.42ms +step:1612/1705 train_time:153816ms step_avg:95.42ms +step:1613/1705 train_time:153913ms step_avg:95.42ms +step:1614/1705 train_time:154008ms step_avg:95.42ms +step:1615/1705 train_time:154103ms step_avg:95.42ms +step:1616/1705 train_time:154198ms step_avg:95.42ms +step:1617/1705 train_time:154293ms step_avg:95.42ms +step:1618/1705 train_time:154389ms step_avg:95.42ms +step:1619/1705 train_time:154484ms step_avg:95.42ms +step:1620/1705 train_time:154579ms step_avg:95.42ms +step:1621/1705 train_time:154674ms step_avg:95.42ms +step:1622/1705 train_time:154768ms step_avg:95.42ms +step:1623/1705 train_time:154864ms step_avg:95.42ms +step:1624/1705 train_time:154961ms step_avg:95.42ms +step:1625/1705 train_time:155057ms step_avg:95.42ms +step:1625/1705 val_loss:3.2921 train_time:155153ms step_avg:95.48ms +step:1626/1705 train_time:155176ms step_avg:95.43ms +step:1627/1705 train_time:155256ms step_avg:95.42ms +step:1628/1705 train_time:155355ms step_avg:95.43ms +step:1629/1705 train_time:155451ms step_avg:95.43ms +step:1630/1705 train_time:155545ms step_avg:95.43ms +step:1631/1705 train_time:155640ms step_avg:95.43ms +step:1632/1705 train_time:155734ms step_avg:95.43ms +step:1633/1705 train_time:155828ms step_avg:95.42ms +step:1634/1705 train_time:155922ms step_avg:95.42ms +step:1635/1705 train_time:156016ms step_avg:95.42ms +step:1636/1705 train_time:156113ms step_avg:95.42ms +step:1637/1705 train_time:156210ms step_avg:95.42ms +step:1638/1705 train_time:156309ms 
step_avg:95.43ms +step:1639/1705 train_time:156406ms step_avg:95.43ms +step:1640/1705 train_time:156502ms step_avg:95.43ms +step:1641/1705 train_time:156596ms step_avg:95.43ms +step:1642/1705 train_time:156691ms step_avg:95.43ms +step:1643/1705 train_time:156786ms step_avg:95.43ms +step:1644/1705 train_time:156881ms step_avg:95.43ms +step:1645/1705 train_time:156975ms step_avg:95.43ms +step:1646/1705 train_time:157071ms step_avg:95.43ms +step:1647/1705 train_time:157167ms step_avg:95.43ms +step:1648/1705 train_time:157264ms step_avg:95.43ms +step:1649/1705 train_time:157360ms step_avg:95.43ms +step:1650/1705 train_time:157456ms step_avg:95.43ms +step:1651/1705 train_time:157551ms step_avg:95.43ms +step:1652/1705 train_time:157646ms step_avg:95.43ms +step:1653/1705 train_time:157740ms step_avg:95.43ms +step:1654/1705 train_time:157835ms step_avg:95.43ms +step:1655/1705 train_time:157930ms step_avg:95.43ms +step:1656/1705 train_time:158025ms step_avg:95.43ms +step:1657/1705 train_time:158120ms step_avg:95.43ms +step:1658/1705 train_time:158216ms step_avg:95.43ms +step:1659/1705 train_time:158314ms step_avg:95.43ms +step:1660/1705 train_time:158410ms step_avg:95.43ms +step:1661/1705 train_time:158505ms step_avg:95.43ms +step:1662/1705 train_time:158600ms step_avg:95.43ms +step:1663/1705 train_time:158696ms step_avg:95.43ms +step:1664/1705 train_time:158791ms step_avg:95.43ms +step:1665/1705 train_time:158887ms step_avg:95.43ms +step:1666/1705 train_time:158982ms step_avg:95.43ms +step:1667/1705 train_time:159077ms step_avg:95.43ms +step:1668/1705 train_time:159173ms step_avg:95.43ms +step:1669/1705 train_time:159269ms step_avg:95.43ms +step:1670/1705 train_time:159364ms step_avg:95.43ms +step:1671/1705 train_time:159459ms step_avg:95.43ms +step:1672/1705 train_time:159554ms step_avg:95.43ms +step:1673/1705 train_time:159650ms step_avg:95.43ms +step:1674/1705 train_time:159746ms step_avg:95.43ms +step:1675/1705 train_time:159841ms step_avg:95.43ms +step:1676/1705 train_time:159936ms step_avg:95.43ms +step:1677/1705 train_time:160031ms step_avg:95.43ms +step:1678/1705 train_time:160127ms step_avg:95.43ms +step:1679/1705 train_time:160222ms step_avg:95.43ms +step:1680/1705 train_time:160317ms step_avg:95.43ms +step:1681/1705 train_time:160413ms step_avg:95.43ms +step:1682/1705 train_time:160509ms step_avg:95.43ms +step:1683/1705 train_time:160605ms step_avg:95.43ms +step:1684/1705 train_time:160700ms step_avg:95.43ms +step:1685/1705 train_time:160796ms step_avg:95.43ms +step:1686/1705 train_time:160892ms step_avg:95.43ms +step:1687/1705 train_time:160987ms step_avg:95.43ms +step:1688/1705 train_time:161082ms step_avg:95.43ms +step:1689/1705 train_time:161177ms step_avg:95.43ms +step:1690/1705 train_time:161273ms step_avg:95.43ms +step:1691/1705 train_time:161369ms step_avg:95.43ms +step:1692/1705 train_time:161464ms step_avg:95.43ms +step:1693/1705 train_time:161559ms step_avg:95.43ms +step:1694/1705 train_time:161655ms step_avg:95.43ms +step:1695/1705 train_time:161751ms step_avg:95.43ms +step:1696/1705 train_time:161847ms step_avg:95.43ms +step:1697/1705 train_time:161943ms step_avg:95.43ms +step:1698/1705 train_time:162176ms step_avg:95.51ms +step:1699/1705 train_time:162387ms step_avg:95.58ms +step:1700/1705 train_time:162480ms step_avg:95.58ms +step:1701/1705 train_time:162574ms step_avg:95.58ms +step:1702/1705 train_time:162669ms step_avg:95.58ms +step:1703/1705 train_time:162763ms step_avg:95.57ms +step:1704/1705 train_time:162857ms step_avg:95.57ms +step:1705/1705 train_time:162951ms 
step_avg:95.57ms +step:1705/1705 val_loss:3.2779 train_time:163046ms step_avg:95.63ms +peak memory allocated: 33750 MiB reserved: 49456 MiB diff --git a/records/050925_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt b/records/050925_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt new file mode 100644 index 000000000..8fa21c9f2 --- /dev/null +++ b/records/050925_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, 
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+
+        # sparse gated attention to enable a context-based no-op by @classiclarryd
+        self.attn_gate_dim = 12
+        self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal.
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
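+        # e.g. next_multiple_of_n(50257, n=128) scans 128, 256, ... and returns 50304,
+        # the first multiple of 128 >= 50257, so 47 padded token ids never occur in the data.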
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
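+                # Illustrative toy example (assumed data): tokens = [BOS, a, b, BOS, c]
+                # gives bos_idx = [0, 3]; next_batch() then cuts document-aligned slices
+                # until it has gathered num_tokens_local + 1 tokens, raising StopIteration
+                # (caught here) once it runs past the final BOS of the shard.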
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1705 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
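+# Sharding arithmetic sketch (illustrative values): WORLD_SIZE=8 gives
+# grad_accum_steps=1, so each optimizer step is a single micro-batch, while
+# WORLD_SIZE=2 gives grad_accum_steps=4, so each rank runs 4 forward/backward
+# passes per step and the data generator divides num_tokens by grad_accum_steps.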
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
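+# Each ws in args.ws_schedule compiles to its own graph under dynamic=False, which is
+# why the warmup loop above cycles through all of them; restoring the saved state here
+# discards the 30 warmup updates so the timed run below starts from a fresh model.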
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:53:26 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 128W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 128W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 123W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 75912 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 75913 C /usr/bin/python3 610MiB | +| 0 N/A N/A 75914 C /usr/bin/python3 610MiB | +| 0 N/A N/A 75915 C /usr/bin/python3 610MiB | +| 0 N/A N/A 75916 C /usr/bin/python3 610MiB | +| 0 N/A N/A 75917 C /usr/bin/python3 610MiB | +| 0 N/A N/A 75918 C /usr/bin/python3 610MiB | +| 0 N/A N/A 75919 C /usr/bin/python3 610MiB | +| 1 N/A N/A 75913 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 75914 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 75915 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 75916 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 75917 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 75918 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 75919 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1705 train_time:355ms step_avg:355.30ms +step:2/1705 train_time:382ms step_avg:190.80ms +step:3/1705 train_time:444ms step_avg:148.16ms +step:4/1705 train_time:536ms step_avg:133.88ms +step:5/1705 
train_time:627ms step_avg:125.44ms +step:6/1705 train_time:719ms step_avg:119.82ms +step:7/1705 train_time:810ms step_avg:115.76ms +step:8/1705 train_time:903ms step_avg:112.87ms +step:9/1705 train_time:995ms step_avg:110.52ms +step:10/1705 train_time:1087ms step_avg:108.69ms +step:11/1705 train_time:1179ms step_avg:107.22ms +step:12/1705 train_time:1274ms step_avg:106.15ms +step:13/1705 train_time:1370ms step_avg:105.36ms +step:14/1705 train_time:1465ms step_avg:104.65ms +step:15/1705 train_time:1558ms step_avg:103.87ms +step:16/1705 train_time:1651ms step_avg:103.17ms +step:17/1705 train_time:1743ms step_avg:102.53ms +step:18/1705 train_time:1835ms step_avg:101.94ms +step:19/1705 train_time:1927ms step_avg:101.43ms +step:20/1705 train_time:2020ms step_avg:101.00ms +step:21/1705 train_time:2112ms step_avg:100.58ms +step:22/1705 train_time:2206ms step_avg:100.29ms +step:23/1705 train_time:2301ms step_avg:100.05ms +step:24/1705 train_time:2395ms step_avg:99.80ms +step:25/1705 train_time:2489ms step_avg:99.57ms +step:26/1705 train_time:2583ms step_avg:99.36ms +step:27/1705 train_time:2677ms step_avg:99.14ms +step:28/1705 train_time:2769ms step_avg:98.91ms +step:29/1705 train_time:2862ms step_avg:98.69ms +step:30/1705 train_time:2955ms step_avg:98.50ms +step:31/1705 train_time:3047ms step_avg:98.30ms +step:32/1705 train_time:3141ms step_avg:98.16ms +step:33/1705 train_time:3234ms step_avg:98.01ms +step:34/1705 train_time:3327ms step_avg:97.86ms +step:35/1705 train_time:3421ms step_avg:97.75ms +step:36/1705 train_time:3515ms step_avg:97.64ms +step:37/1705 train_time:3608ms step_avg:97.51ms +step:38/1705 train_time:3702ms step_avg:97.43ms +step:39/1705 train_time:3796ms step_avg:97.34ms +step:40/1705 train_time:3889ms step_avg:97.22ms +step:41/1705 train_time:3981ms step_avg:97.10ms +step:42/1705 train_time:4073ms step_avg:96.99ms +step:43/1705 train_time:4166ms step_avg:96.89ms +step:44/1705 train_time:4260ms step_avg:96.82ms +step:45/1705 train_time:4353ms step_avg:96.73ms +step:46/1705 train_time:4446ms step_avg:96.66ms +step:47/1705 train_time:4540ms step_avg:96.59ms +step:48/1705 train_time:4634ms step_avg:96.54ms +step:49/1705 train_time:4727ms step_avg:96.47ms +step:50/1705 train_time:4821ms step_avg:96.41ms +step:51/1705 train_time:4913ms step_avg:96.34ms +step:52/1705 train_time:5006ms step_avg:96.27ms +step:53/1705 train_time:5099ms step_avg:96.21ms +step:54/1705 train_time:5192ms step_avg:96.14ms +step:55/1705 train_time:5285ms step_avg:96.10ms +step:56/1705 train_time:5379ms step_avg:96.05ms +step:57/1705 train_time:5471ms step_avg:95.99ms +step:58/1705 train_time:5565ms step_avg:95.94ms +step:59/1705 train_time:5659ms step_avg:95.91ms +step:60/1705 train_time:5751ms step_avg:95.85ms +step:61/1705 train_time:5845ms step_avg:95.81ms +step:62/1705 train_time:5938ms step_avg:95.78ms +step:63/1705 train_time:6030ms step_avg:95.72ms +step:64/1705 train_time:6124ms step_avg:95.68ms +step:65/1705 train_time:6218ms step_avg:95.66ms +step:66/1705 train_time:6311ms step_avg:95.62ms +step:67/1705 train_time:6404ms step_avg:95.58ms +step:68/1705 train_time:6498ms step_avg:95.55ms +step:69/1705 train_time:6590ms step_avg:95.51ms +step:70/1705 train_time:6684ms step_avg:95.49ms +step:71/1705 train_time:6778ms step_avg:95.46ms +step:72/1705 train_time:6870ms step_avg:95.42ms +step:73/1705 train_time:6964ms step_avg:95.40ms +step:74/1705 train_time:7057ms step_avg:95.37ms +step:75/1705 train_time:7150ms step_avg:95.33ms +step:76/1705 train_time:7243ms step_avg:95.31ms +step:77/1705 
train_time:7337ms step_avg:95.28ms +step:78/1705 train_time:7429ms step_avg:95.25ms +step:79/1705 train_time:7523ms step_avg:95.22ms +step:80/1705 train_time:7616ms step_avg:95.20ms +step:81/1705 train_time:7709ms step_avg:95.17ms +step:82/1705 train_time:7802ms step_avg:95.15ms +step:83/1705 train_time:7896ms step_avg:95.13ms +step:84/1705 train_time:7989ms step_avg:95.10ms +step:85/1705 train_time:8084ms step_avg:95.10ms +step:86/1705 train_time:8177ms step_avg:95.08ms +step:87/1705 train_time:8270ms step_avg:95.06ms +step:88/1705 train_time:8364ms step_avg:95.04ms +step:89/1705 train_time:8457ms step_avg:95.02ms +step:90/1705 train_time:8550ms step_avg:95.00ms +step:91/1705 train_time:8643ms step_avg:94.98ms +step:92/1705 train_time:8736ms step_avg:94.96ms +step:93/1705 train_time:8829ms step_avg:94.93ms +step:94/1705 train_time:8922ms step_avg:94.92ms +step:95/1705 train_time:9016ms step_avg:94.90ms +step:96/1705 train_time:9108ms step_avg:94.88ms +step:97/1705 train_time:9201ms step_avg:94.86ms +step:98/1705 train_time:9294ms step_avg:94.84ms +step:99/1705 train_time:9388ms step_avg:94.83ms +step:100/1705 train_time:9481ms step_avg:94.81ms +step:101/1705 train_time:9574ms step_avg:94.79ms +step:102/1705 train_time:9667ms step_avg:94.78ms +step:103/1705 train_time:9761ms step_avg:94.77ms +step:104/1705 train_time:9853ms step_avg:94.74ms +step:105/1705 train_time:9947ms step_avg:94.73ms +step:106/1705 train_time:10040ms step_avg:94.71ms +step:107/1705 train_time:10132ms step_avg:94.69ms +step:108/1705 train_time:10226ms step_avg:94.68ms +step:109/1705 train_time:10319ms step_avg:94.67ms +step:110/1705 train_time:10412ms step_avg:94.65ms +step:111/1705 train_time:10505ms step_avg:94.64ms +step:112/1705 train_time:10599ms step_avg:94.63ms +step:113/1705 train_time:10692ms step_avg:94.62ms +step:114/1705 train_time:10785ms step_avg:94.61ms +step:115/1705 train_time:10878ms step_avg:94.60ms +step:116/1705 train_time:10971ms step_avg:94.58ms +step:117/1705 train_time:11064ms step_avg:94.56ms +step:118/1705 train_time:11156ms step_avg:94.55ms +step:119/1705 train_time:11249ms step_avg:94.53ms +step:120/1705 train_time:11342ms step_avg:94.52ms +step:121/1705 train_time:11435ms step_avg:94.51ms +step:122/1705 train_time:11528ms step_avg:94.49ms +step:123/1705 train_time:11621ms step_avg:94.48ms +step:124/1705 train_time:11715ms step_avg:94.47ms +step:125/1705 train_time:11807ms step_avg:94.46ms +step:125/1705 val_loss:4.3039 train_time:11901ms step_avg:95.21ms +step:126/1705 train_time:11925ms step_avg:94.65ms +step:127/1705 train_time:11999ms step_avg:94.48ms +step:128/1705 train_time:12101ms step_avg:94.54ms +step:129/1705 train_time:12195ms step_avg:94.54ms +step:130/1705 train_time:12289ms step_avg:94.53ms +step:131/1705 train_time:12383ms step_avg:94.52ms +step:132/1705 train_time:12473ms step_avg:94.50ms +step:133/1705 train_time:12565ms step_avg:94.48ms +step:134/1705 train_time:12657ms step_avg:94.46ms +step:135/1705 train_time:12749ms step_avg:94.44ms +step:136/1705 train_time:12842ms step_avg:94.42ms +step:137/1705 train_time:12935ms step_avg:94.41ms +step:138/1705 train_time:13030ms step_avg:94.42ms +step:139/1705 train_time:13126ms step_avg:94.43ms +step:140/1705 train_time:13220ms step_avg:94.43ms +step:141/1705 train_time:13313ms step_avg:94.42ms +step:142/1705 train_time:13406ms step_avg:94.41ms +step:143/1705 train_time:13498ms step_avg:94.39ms +step:144/1705 train_time:13591ms step_avg:94.38ms +step:145/1705 train_time:13683ms step_avg:94.37ms +step:146/1705 train_time:13775ms 
step_avg:94.35ms +step:147/1705 train_time:13869ms step_avg:94.35ms +step:148/1705 train_time:13962ms step_avg:94.33ms +step:149/1705 train_time:14055ms step_avg:94.33ms +step:150/1705 train_time:14149ms step_avg:94.33ms +step:151/1705 train_time:14243ms step_avg:94.33ms +step:152/1705 train_time:14336ms step_avg:94.31ms +step:153/1705 train_time:14429ms step_avg:94.31ms +step:154/1705 train_time:14522ms step_avg:94.30ms +step:155/1705 train_time:14614ms step_avg:94.28ms +step:156/1705 train_time:14707ms step_avg:94.28ms +step:157/1705 train_time:14800ms step_avg:94.26ms +step:158/1705 train_time:14892ms step_avg:94.25ms +step:159/1705 train_time:14985ms step_avg:94.25ms +step:160/1705 train_time:15079ms step_avg:94.24ms +step:161/1705 train_time:15172ms step_avg:94.24ms +step:162/1705 train_time:15266ms step_avg:94.24ms +step:163/1705 train_time:15360ms step_avg:94.23ms +step:164/1705 train_time:15452ms step_avg:94.22ms +step:165/1705 train_time:15545ms step_avg:94.21ms +step:166/1705 train_time:15637ms step_avg:94.20ms +step:167/1705 train_time:15730ms step_avg:94.19ms +step:168/1705 train_time:15822ms step_avg:94.18ms +step:169/1705 train_time:15915ms step_avg:94.17ms +step:170/1705 train_time:16008ms step_avg:94.17ms +step:171/1705 train_time:16101ms step_avg:94.16ms +step:172/1705 train_time:16194ms step_avg:94.15ms +step:173/1705 train_time:16289ms step_avg:94.15ms +step:174/1705 train_time:16382ms step_avg:94.15ms +step:175/1705 train_time:16475ms step_avg:94.14ms +step:176/1705 train_time:16567ms step_avg:94.13ms +step:177/1705 train_time:16660ms step_avg:94.12ms +step:178/1705 train_time:16753ms step_avg:94.12ms +step:179/1705 train_time:16846ms step_avg:94.11ms +step:180/1705 train_time:16938ms step_avg:94.10ms +step:181/1705 train_time:17032ms step_avg:94.10ms +step:182/1705 train_time:17125ms step_avg:94.10ms +step:183/1705 train_time:17218ms step_avg:94.09ms +step:184/1705 train_time:17311ms step_avg:94.08ms +step:185/1705 train_time:17404ms step_avg:94.07ms +step:186/1705 train_time:17496ms step_avg:94.07ms +step:187/1705 train_time:17589ms step_avg:94.06ms +step:188/1705 train_time:17682ms step_avg:94.05ms +step:189/1705 train_time:17775ms step_avg:94.05ms +step:190/1705 train_time:17867ms step_avg:94.04ms +step:191/1705 train_time:17960ms step_avg:94.03ms +step:192/1705 train_time:18053ms step_avg:94.03ms +step:193/1705 train_time:18147ms step_avg:94.03ms +step:194/1705 train_time:18241ms step_avg:94.02ms +step:195/1705 train_time:18334ms step_avg:94.02ms +step:196/1705 train_time:18427ms step_avg:94.02ms +step:197/1705 train_time:18519ms step_avg:94.01ms +step:198/1705 train_time:18612ms step_avg:94.00ms +step:199/1705 train_time:18705ms step_avg:94.00ms +step:200/1705 train_time:18797ms step_avg:93.99ms +step:201/1705 train_time:18890ms step_avg:93.98ms +step:202/1705 train_time:18983ms step_avg:93.97ms +step:203/1705 train_time:19076ms step_avg:93.97ms +step:204/1705 train_time:19169ms step_avg:93.97ms +step:205/1705 train_time:19263ms step_avg:93.96ms +step:206/1705 train_time:19356ms step_avg:93.96ms +step:207/1705 train_time:19450ms step_avg:93.96ms +step:208/1705 train_time:19544ms step_avg:93.96ms +step:209/1705 train_time:19637ms step_avg:93.96ms +step:210/1705 train_time:19730ms step_avg:93.95ms +step:211/1705 train_time:19823ms step_avg:93.95ms +step:212/1705 train_time:19916ms step_avg:93.94ms +step:213/1705 train_time:20170ms step_avg:94.70ms +step:214/1705 train_time:20340ms step_avg:95.05ms +step:215/1705 train_time:20432ms step_avg:95.03ms +step:216/1705 
train_time:20524ms step_avg:95.02ms +step:217/1705 train_time:20615ms step_avg:95.00ms +step:218/1705 train_time:20708ms step_avg:94.99ms +step:219/1705 train_time:20800ms step_avg:94.98ms +step:220/1705 train_time:20892ms step_avg:94.96ms +step:221/1705 train_time:20984ms step_avg:94.95ms +step:222/1705 train_time:21075ms step_avg:94.93ms +step:223/1705 train_time:21168ms step_avg:94.93ms +step:224/1705 train_time:21262ms step_avg:94.92ms +step:225/1705 train_time:21357ms step_avg:94.92ms +step:226/1705 train_time:21451ms step_avg:94.92ms +step:227/1705 train_time:21545ms step_avg:94.91ms +step:228/1705 train_time:21637ms step_avg:94.90ms +step:229/1705 train_time:21730ms step_avg:94.89ms +step:230/1705 train_time:21823ms step_avg:94.88ms +step:231/1705 train_time:21915ms step_avg:94.87ms +step:232/1705 train_time:22008ms step_avg:94.86ms +step:233/1705 train_time:22099ms step_avg:94.85ms +step:234/1705 train_time:22192ms step_avg:94.84ms +step:235/1705 train_time:22285ms step_avg:94.83ms +step:236/1705 train_time:22379ms step_avg:94.83ms +step:237/1705 train_time:22472ms step_avg:94.82ms +step:238/1705 train_time:22565ms step_avg:94.81ms +step:239/1705 train_time:22658ms step_avg:94.80ms +step:240/1705 train_time:22751ms step_avg:94.79ms +step:241/1705 train_time:22844ms step_avg:94.79ms +step:242/1705 train_time:22936ms step_avg:94.78ms +step:243/1705 train_time:23029ms step_avg:94.77ms +step:244/1705 train_time:23121ms step_avg:94.76ms +step:245/1705 train_time:23214ms step_avg:94.75ms +step:246/1705 train_time:23307ms step_avg:94.74ms +step:247/1705 train_time:23400ms step_avg:94.74ms +step:248/1705 train_time:23493ms step_avg:94.73ms +step:249/1705 train_time:23587ms step_avg:94.73ms +step:250/1705 train_time:23680ms step_avg:94.72ms +step:250/1705 val_loss:3.9663 train_time:23774ms step_avg:95.10ms +step:251/1705 train_time:23798ms step_avg:94.81ms +step:252/1705 train_time:23869ms step_avg:94.72ms +step:253/1705 train_time:23965ms step_avg:94.72ms +step:254/1705 train_time:24066ms step_avg:94.75ms +step:255/1705 train_time:24160ms step_avg:94.75ms +step:256/1705 train_time:24252ms step_avg:94.73ms +step:257/1705 train_time:24343ms step_avg:94.72ms +step:258/1705 train_time:24436ms step_avg:94.71ms +step:259/1705 train_time:24527ms step_avg:94.70ms +step:260/1705 train_time:24619ms step_avg:94.69ms +step:261/1705 train_time:24714ms step_avg:94.69ms +step:262/1705 train_time:24808ms step_avg:94.69ms +step:263/1705 train_time:24902ms step_avg:94.68ms +step:264/1705 train_time:24996ms step_avg:94.68ms +step:265/1705 train_time:25090ms step_avg:94.68ms +step:266/1705 train_time:25183ms step_avg:94.67ms +step:267/1705 train_time:25275ms step_avg:94.66ms +step:268/1705 train_time:25367ms step_avg:94.65ms +step:269/1705 train_time:25459ms step_avg:94.64ms +step:270/1705 train_time:25551ms step_avg:94.63ms +step:271/1705 train_time:25643ms step_avg:94.62ms +step:272/1705 train_time:25736ms step_avg:94.62ms +step:273/1705 train_time:25830ms step_avg:94.61ms +step:274/1705 train_time:25922ms step_avg:94.61ms +step:275/1705 train_time:26016ms step_avg:94.60ms +step:276/1705 train_time:26109ms step_avg:94.60ms +step:277/1705 train_time:26202ms step_avg:94.59ms +step:278/1705 train_time:26296ms step_avg:94.59ms +step:279/1705 train_time:26388ms step_avg:94.58ms +step:280/1705 train_time:26481ms step_avg:94.57ms +step:281/1705 train_time:26573ms step_avg:94.57ms +step:282/1705 train_time:26666ms step_avg:94.56ms +step:283/1705 train_time:26758ms step_avg:94.55ms +step:284/1705 train_time:26851ms 
step_avg:94.54ms +step:285/1705 train_time:26944ms step_avg:94.54ms +step:286/1705 train_time:27037ms step_avg:94.54ms +step:287/1705 train_time:27131ms step_avg:94.53ms +step:288/1705 train_time:27223ms step_avg:94.52ms +step:289/1705 train_time:27316ms step_avg:94.52ms +step:290/1705 train_time:27409ms step_avg:94.51ms +step:291/1705 train_time:27501ms step_avg:94.51ms +step:292/1705 train_time:27594ms step_avg:94.50ms +step:293/1705 train_time:27685ms step_avg:94.49ms +step:294/1705 train_time:27778ms step_avg:94.48ms +step:295/1705 train_time:27871ms step_avg:94.48ms +step:296/1705 train_time:27965ms step_avg:94.48ms +step:297/1705 train_time:28058ms step_avg:94.47ms +step:298/1705 train_time:28151ms step_avg:94.46ms +step:299/1705 train_time:28243ms step_avg:94.46ms +step:300/1705 train_time:28336ms step_avg:94.45ms +step:301/1705 train_time:28429ms step_avg:94.45ms +step:302/1705 train_time:28521ms step_avg:94.44ms +step:303/1705 train_time:28614ms step_avg:94.44ms +step:304/1705 train_time:28707ms step_avg:94.43ms +step:305/1705 train_time:28800ms step_avg:94.43ms +step:306/1705 train_time:28893ms step_avg:94.42ms +step:307/1705 train_time:28985ms step_avg:94.41ms +step:308/1705 train_time:29078ms step_avg:94.41ms +step:309/1705 train_time:29171ms step_avg:94.40ms +step:310/1705 train_time:29264ms step_avg:94.40ms +step:311/1705 train_time:29357ms step_avg:94.39ms +step:312/1705 train_time:29449ms step_avg:94.39ms +step:313/1705 train_time:29541ms step_avg:94.38ms +step:314/1705 train_time:29635ms step_avg:94.38ms +step:315/1705 train_time:29728ms step_avg:94.37ms +step:316/1705 train_time:29820ms step_avg:94.37ms +step:317/1705 train_time:29914ms step_avg:94.37ms +step:318/1705 train_time:30007ms step_avg:94.36ms +step:319/1705 train_time:30099ms step_avg:94.35ms +step:320/1705 train_time:30193ms step_avg:94.35ms +step:321/1705 train_time:30285ms step_avg:94.35ms +step:322/1705 train_time:30378ms step_avg:94.34ms +step:323/1705 train_time:30471ms step_avg:94.34ms +step:324/1705 train_time:30564ms step_avg:94.33ms +step:325/1705 train_time:30656ms step_avg:94.33ms +step:326/1705 train_time:30750ms step_avg:94.32ms +step:327/1705 train_time:30842ms step_avg:94.32ms +step:328/1705 train_time:30936ms step_avg:94.32ms +step:329/1705 train_time:31029ms step_avg:94.31ms +step:330/1705 train_time:31121ms step_avg:94.31ms +step:331/1705 train_time:31215ms step_avg:94.30ms +step:332/1705 train_time:31308ms step_avg:94.30ms +step:333/1705 train_time:31401ms step_avg:94.30ms +step:334/1705 train_time:31494ms step_avg:94.29ms +step:335/1705 train_time:31587ms step_avg:94.29ms +step:336/1705 train_time:31679ms step_avg:94.28ms +step:337/1705 train_time:31773ms step_avg:94.28ms +step:338/1705 train_time:31865ms step_avg:94.28ms +step:339/1705 train_time:31959ms step_avg:94.27ms +step:340/1705 train_time:32052ms step_avg:94.27ms +step:341/1705 train_time:32145ms step_avg:94.27ms +step:342/1705 train_time:32237ms step_avg:94.26ms +step:343/1705 train_time:32331ms step_avg:94.26ms +step:344/1705 train_time:32424ms step_avg:94.25ms +step:345/1705 train_time:32517ms step_avg:94.25ms +step:346/1705 train_time:32610ms step_avg:94.25ms +step:347/1705 train_time:32702ms step_avg:94.24ms +step:348/1705 train_time:32795ms step_avg:94.24ms +step:349/1705 train_time:32888ms step_avg:94.24ms +step:350/1705 train_time:32982ms step_avg:94.23ms +step:351/1705 train_time:33075ms step_avg:94.23ms +step:352/1705 train_time:33168ms step_avg:94.23ms +step:353/1705 train_time:33260ms step_avg:94.22ms +step:354/1705 
train_time:33354ms step_avg:94.22ms +step:355/1705 train_time:33446ms step_avg:94.22ms +step:356/1705 train_time:33539ms step_avg:94.21ms +step:357/1705 train_time:33633ms step_avg:94.21ms +step:358/1705 train_time:33726ms step_avg:94.21ms +step:359/1705 train_time:33819ms step_avg:94.20ms +step:360/1705 train_time:33912ms step_avg:94.20ms +step:361/1705 train_time:34005ms step_avg:94.20ms +step:362/1705 train_time:34098ms step_avg:94.19ms +step:363/1705 train_time:34191ms step_avg:94.19ms +step:364/1705 train_time:34283ms step_avg:94.19ms +step:365/1705 train_time:34376ms step_avg:94.18ms +step:366/1705 train_time:34469ms step_avg:94.18ms +step:367/1705 train_time:34561ms step_avg:94.17ms +step:368/1705 train_time:34655ms step_avg:94.17ms +step:369/1705 train_time:34748ms step_avg:94.17ms +step:370/1705 train_time:34840ms step_avg:94.16ms +step:371/1705 train_time:34935ms step_avg:94.16ms +step:372/1705 train_time:35028ms step_avg:94.16ms +step:373/1705 train_time:35120ms step_avg:94.16ms +step:374/1705 train_time:35214ms step_avg:94.15ms +step:375/1705 train_time:35307ms step_avg:94.15ms +step:375/1705 val_loss:3.8168 train_time:35400ms step_avg:94.40ms +step:376/1705 train_time:35422ms step_avg:94.21ms +step:377/1705 train_time:35499ms step_avg:94.16ms +step:378/1705 train_time:35597ms step_avg:94.17ms +step:379/1705 train_time:35691ms step_avg:94.17ms +step:380/1705 train_time:35784ms step_avg:94.17ms +step:381/1705 train_time:35875ms step_avg:94.16ms +step:382/1705 train_time:35967ms step_avg:94.16ms +step:383/1705 train_time:36059ms step_avg:94.15ms +step:384/1705 train_time:36151ms step_avg:94.14ms +step:385/1705 train_time:36243ms step_avg:94.14ms +step:386/1705 train_time:36336ms step_avg:94.13ms +step:387/1705 train_time:36430ms step_avg:94.13ms +step:388/1705 train_time:36524ms step_avg:94.13ms +step:389/1705 train_time:36617ms step_avg:94.13ms +step:390/1705 train_time:36711ms step_avg:94.13ms +step:391/1705 train_time:36805ms step_avg:94.13ms +step:392/1705 train_time:36897ms step_avg:94.12ms +step:393/1705 train_time:36990ms step_avg:94.12ms +step:394/1705 train_time:37083ms step_avg:94.12ms +step:395/1705 train_time:37175ms step_avg:94.11ms +step:396/1705 train_time:37267ms step_avg:94.11ms +step:397/1705 train_time:37360ms step_avg:94.11ms +step:398/1705 train_time:37453ms step_avg:94.10ms +step:399/1705 train_time:37548ms step_avg:94.10ms +step:400/1705 train_time:37641ms step_avg:94.10ms +step:401/1705 train_time:37734ms step_avg:94.10ms +step:402/1705 train_time:37828ms step_avg:94.10ms +step:403/1705 train_time:37921ms step_avg:94.10ms +step:404/1705 train_time:38013ms step_avg:94.09ms +step:405/1705 train_time:38106ms step_avg:94.09ms +step:406/1705 train_time:38197ms step_avg:94.08ms +step:407/1705 train_time:38290ms step_avg:94.08ms +step:408/1705 train_time:38382ms step_avg:94.07ms +step:409/1705 train_time:38475ms step_avg:94.07ms +step:410/1705 train_time:38569ms step_avg:94.07ms +step:411/1705 train_time:38663ms step_avg:94.07ms +step:412/1705 train_time:38756ms step_avg:94.07ms +step:413/1705 train_time:38849ms step_avg:94.07ms +step:414/1705 train_time:38942ms step_avg:94.06ms +step:415/1705 train_time:39035ms step_avg:94.06ms +step:416/1705 train_time:39127ms step_avg:94.06ms +step:417/1705 train_time:39219ms step_avg:94.05ms +step:418/1705 train_time:39312ms step_avg:94.05ms +step:419/1705 train_time:39405ms step_avg:94.05ms +step:420/1705 train_time:39498ms step_avg:94.04ms +step:421/1705 train_time:39592ms step_avg:94.04ms +step:422/1705 train_time:39685ms 
step_avg:94.04ms +step:423/1705 train_time:39778ms step_avg:94.04ms +step:424/1705 train_time:39871ms step_avg:94.04ms +step:425/1705 train_time:40155ms step_avg:94.48ms +step:426/1705 train_time:40268ms step_avg:94.53ms +step:427/1705 train_time:40359ms step_avg:94.52ms +step:428/1705 train_time:40451ms step_avg:94.51ms +step:429/1705 train_time:40543ms step_avg:94.51ms +step:430/1705 train_time:40635ms step_avg:94.50ms +step:431/1705 train_time:40727ms step_avg:94.49ms +step:432/1705 train_time:40819ms step_avg:94.49ms +step:433/1705 train_time:40911ms step_avg:94.48ms +step:434/1705 train_time:41003ms step_avg:94.48ms +step:435/1705 train_time:41097ms step_avg:94.48ms +step:436/1705 train_time:41192ms step_avg:94.48ms +step:437/1705 train_time:41288ms step_avg:94.48ms +step:438/1705 train_time:41381ms step_avg:94.48ms +step:439/1705 train_time:41474ms step_avg:94.47ms +step:440/1705 train_time:41567ms step_avg:94.47ms +step:441/1705 train_time:41659ms step_avg:94.47ms +step:442/1705 train_time:41752ms step_avg:94.46ms +step:443/1705 train_time:41844ms step_avg:94.46ms +step:444/1705 train_time:41936ms step_avg:94.45ms +step:445/1705 train_time:42029ms step_avg:94.45ms +step:446/1705 train_time:42123ms step_avg:94.45ms +step:447/1705 train_time:42216ms step_avg:94.44ms +step:448/1705 train_time:42310ms step_avg:94.44ms +step:449/1705 train_time:42404ms step_avg:94.44ms +step:450/1705 train_time:42496ms step_avg:94.44ms +step:451/1705 train_time:42590ms step_avg:94.43ms +step:452/1705 train_time:42683ms step_avg:94.43ms +step:453/1705 train_time:42775ms step_avg:94.43ms +step:454/1705 train_time:42867ms step_avg:94.42ms +step:455/1705 train_time:42960ms step_avg:94.42ms +step:456/1705 train_time:43053ms step_avg:94.41ms +step:457/1705 train_time:43146ms step_avg:94.41ms +step:458/1705 train_time:43239ms step_avg:94.41ms +step:459/1705 train_time:43332ms step_avg:94.41ms +step:460/1705 train_time:43426ms step_avg:94.40ms +step:461/1705 train_time:43518ms step_avg:94.40ms +step:462/1705 train_time:43612ms step_avg:94.40ms +step:463/1705 train_time:43705ms step_avg:94.40ms +step:464/1705 train_time:43797ms step_avg:94.39ms +step:465/1705 train_time:43890ms step_avg:94.39ms +step:466/1705 train_time:43983ms step_avg:94.38ms +step:467/1705 train_time:44076ms step_avg:94.38ms +step:468/1705 train_time:44170ms step_avg:94.38ms +step:469/1705 train_time:44263ms step_avg:94.38ms +step:470/1705 train_time:44356ms step_avg:94.37ms +step:471/1705 train_time:44449ms step_avg:94.37ms +step:472/1705 train_time:44542ms step_avg:94.37ms +step:473/1705 train_time:44635ms step_avg:94.37ms +step:474/1705 train_time:44728ms step_avg:94.36ms +step:475/1705 train_time:44822ms step_avg:94.36ms +step:476/1705 train_time:44914ms step_avg:94.36ms +step:477/1705 train_time:45008ms step_avg:94.36ms +step:478/1705 train_time:45101ms step_avg:94.35ms +step:479/1705 train_time:45194ms step_avg:94.35ms +step:480/1705 train_time:45287ms step_avg:94.35ms +step:481/1705 train_time:45381ms step_avg:94.35ms +step:482/1705 train_time:45474ms step_avg:94.34ms +step:483/1705 train_time:45567ms step_avg:94.34ms +step:484/1705 train_time:45660ms step_avg:94.34ms +step:485/1705 train_time:45752ms step_avg:94.33ms +step:486/1705 train_time:45845ms step_avg:94.33ms +step:487/1705 train_time:45938ms step_avg:94.33ms +step:488/1705 train_time:46031ms step_avg:94.33ms +step:489/1705 train_time:46125ms step_avg:94.32ms +step:490/1705 train_time:46217ms step_avg:94.32ms +step:491/1705 train_time:46311ms step_avg:94.32ms +step:492/1705 
train_time:46404ms step_avg:94.32ms +step:493/1705 train_time:46496ms step_avg:94.31ms +step:494/1705 train_time:46589ms step_avg:94.31ms +step:495/1705 train_time:46683ms step_avg:94.31ms +step:496/1705 train_time:46776ms step_avg:94.31ms +step:497/1705 train_time:46868ms step_avg:94.30ms +step:498/1705 train_time:46961ms step_avg:94.30ms +step:499/1705 train_time:47054ms step_avg:94.30ms +step:500/1705 train_time:47148ms step_avg:94.30ms +step:500/1705 val_loss:3.7144 train_time:47241ms step_avg:94.48ms +step:501/1705 train_time:47264ms step_avg:94.34ms +step:502/1705 train_time:47341ms step_avg:94.31ms +step:503/1705 train_time:47438ms step_avg:94.31ms +step:504/1705 train_time:47533ms step_avg:94.31ms +step:505/1705 train_time:47625ms step_avg:94.31ms +step:506/1705 train_time:47718ms step_avg:94.30ms +step:507/1705 train_time:47809ms step_avg:94.30ms +step:508/1705 train_time:47902ms step_avg:94.29ms +step:509/1705 train_time:47993ms step_avg:94.29ms +step:510/1705 train_time:48085ms step_avg:94.29ms +step:511/1705 train_time:48178ms step_avg:94.28ms +step:512/1705 train_time:48273ms step_avg:94.28ms +step:513/1705 train_time:48368ms step_avg:94.28ms +step:514/1705 train_time:48462ms step_avg:94.28ms +step:515/1705 train_time:48556ms step_avg:94.28ms +step:516/1705 train_time:48649ms step_avg:94.28ms +step:517/1705 train_time:48743ms step_avg:94.28ms +step:518/1705 train_time:48835ms step_avg:94.28ms +step:519/1705 train_time:48927ms step_avg:94.27ms +step:520/1705 train_time:49019ms step_avg:94.27ms +step:521/1705 train_time:49112ms step_avg:94.26ms +step:522/1705 train_time:49205ms step_avg:94.26ms +step:523/1705 train_time:49298ms step_avg:94.26ms +step:524/1705 train_time:49391ms step_avg:94.26ms +step:525/1705 train_time:49485ms step_avg:94.26ms +step:526/1705 train_time:49580ms step_avg:94.26ms +step:527/1705 train_time:49673ms step_avg:94.26ms +step:528/1705 train_time:49766ms step_avg:94.25ms +step:529/1705 train_time:49858ms step_avg:94.25ms +step:530/1705 train_time:49951ms step_avg:94.25ms +step:531/1705 train_time:50043ms step_avg:94.24ms +step:532/1705 train_time:50135ms step_avg:94.24ms +step:533/1705 train_time:50229ms step_avg:94.24ms +step:534/1705 train_time:50322ms step_avg:94.24ms +step:535/1705 train_time:50415ms step_avg:94.23ms +step:536/1705 train_time:50509ms step_avg:94.23ms +step:537/1705 train_time:50602ms step_avg:94.23ms +step:538/1705 train_time:50694ms step_avg:94.23ms +step:539/1705 train_time:50788ms step_avg:94.23ms +step:540/1705 train_time:50881ms step_avg:94.22ms +step:541/1705 train_time:50973ms step_avg:94.22ms +step:542/1705 train_time:51066ms step_avg:94.22ms +step:543/1705 train_time:51159ms step_avg:94.21ms +step:544/1705 train_time:51251ms step_avg:94.21ms +step:545/1705 train_time:51345ms step_avg:94.21ms +step:546/1705 train_time:51439ms step_avg:94.21ms +step:547/1705 train_time:51532ms step_avg:94.21ms +step:548/1705 train_time:51626ms step_avg:94.21ms +step:549/1705 train_time:51721ms step_avg:94.21ms +step:550/1705 train_time:51813ms step_avg:94.21ms +step:551/1705 train_time:51907ms step_avg:94.20ms +step:552/1705 train_time:52000ms step_avg:94.20ms +step:553/1705 train_time:52093ms step_avg:94.20ms +step:554/1705 train_time:52186ms step_avg:94.20ms +step:555/1705 train_time:52279ms step_avg:94.20ms +step:556/1705 train_time:52372ms step_avg:94.19ms +step:557/1705 train_time:52466ms step_avg:94.19ms +step:558/1705 train_time:52560ms step_avg:94.19ms +step:559/1705 train_time:52653ms step_avg:94.19ms +step:560/1705 train_time:52746ms 
step_avg:94.19ms +step:561/1705 train_time:52839ms step_avg:94.19ms +step:562/1705 train_time:52932ms step_avg:94.18ms +step:563/1705 train_time:53025ms step_avg:94.18ms +step:564/1705 train_time:53119ms step_avg:94.18ms +step:565/1705 train_time:53211ms step_avg:94.18ms +step:566/1705 train_time:53304ms step_avg:94.18ms +step:567/1705 train_time:53397ms step_avg:94.17ms +step:568/1705 train_time:53490ms step_avg:94.17ms +step:569/1705 train_time:53584ms step_avg:94.17ms +step:570/1705 train_time:53677ms step_avg:94.17ms +step:571/1705 train_time:53772ms step_avg:94.17ms +step:572/1705 train_time:53866ms step_avg:94.17ms +step:573/1705 train_time:53960ms step_avg:94.17ms +step:574/1705 train_time:54054ms step_avg:94.17ms +step:575/1705 train_time:54148ms step_avg:94.17ms +step:576/1705 train_time:54242ms step_avg:94.17ms +step:577/1705 train_time:54337ms step_avg:94.17ms +step:578/1705 train_time:54431ms step_avg:94.17ms +step:579/1705 train_time:54526ms step_avg:94.17ms +step:580/1705 train_time:54621ms step_avg:94.17ms +step:581/1705 train_time:54715ms step_avg:94.17ms +step:582/1705 train_time:54809ms step_avg:94.17ms +step:583/1705 train_time:54904ms step_avg:94.17ms +step:584/1705 train_time:54999ms step_avg:94.18ms +step:585/1705 train_time:55092ms step_avg:94.18ms +step:586/1705 train_time:55187ms step_avg:94.18ms +step:587/1705 train_time:55282ms step_avg:94.18ms +step:588/1705 train_time:55376ms step_avg:94.18ms +step:589/1705 train_time:55471ms step_avg:94.18ms +step:590/1705 train_time:55566ms step_avg:94.18ms +step:591/1705 train_time:55661ms step_avg:94.18ms +step:592/1705 train_time:55755ms step_avg:94.18ms +step:593/1705 train_time:55849ms step_avg:94.18ms +step:594/1705 train_time:55944ms step_avg:94.18ms +step:595/1705 train_time:56039ms step_avg:94.18ms +step:596/1705 train_time:56133ms step_avg:94.18ms +step:597/1705 train_time:56228ms step_avg:94.18ms +step:598/1705 train_time:56322ms step_avg:94.18ms +step:599/1705 train_time:56416ms step_avg:94.18ms +step:600/1705 train_time:56510ms step_avg:94.18ms +step:601/1705 train_time:56605ms step_avg:94.18ms +step:602/1705 train_time:56700ms step_avg:94.19ms +step:603/1705 train_time:56794ms step_avg:94.19ms +step:604/1705 train_time:56888ms step_avg:94.19ms +step:605/1705 train_time:56983ms step_avg:94.19ms +step:606/1705 train_time:57078ms step_avg:94.19ms +step:607/1705 train_time:57172ms step_avg:94.19ms +step:608/1705 train_time:57266ms step_avg:94.19ms +step:609/1705 train_time:57361ms step_avg:94.19ms +step:610/1705 train_time:57455ms step_avg:94.19ms +step:611/1705 train_time:57550ms step_avg:94.19ms +step:612/1705 train_time:57645ms step_avg:94.19ms +step:613/1705 train_time:57740ms step_avg:94.19ms +step:614/1705 train_time:57834ms step_avg:94.19ms +step:615/1705 train_time:57928ms step_avg:94.19ms +step:616/1705 train_time:58023ms step_avg:94.19ms +step:617/1705 train_time:58118ms step_avg:94.19ms +step:618/1705 train_time:58211ms step_avg:94.19ms +step:619/1705 train_time:58307ms step_avg:94.20ms +step:620/1705 train_time:58402ms step_avg:94.20ms +step:621/1705 train_time:58497ms step_avg:94.20ms +step:622/1705 train_time:58591ms step_avg:94.20ms +step:623/1705 train_time:58685ms step_avg:94.20ms +step:624/1705 train_time:58779ms step_avg:94.20ms +step:625/1705 train_time:58873ms step_avg:94.20ms +step:625/1705 val_loss:3.6166 train_time:58967ms step_avg:94.35ms +step:626/1705 train_time:58991ms step_avg:94.23ms +step:627/1705 train_time:59073ms step_avg:94.21ms +step:628/1705 train_time:59170ms step_avg:94.22ms 
+step:629/1705 train_time:59265ms step_avg:94.22ms +step:630/1705 train_time:59358ms step_avg:94.22ms +step:631/1705 train_time:59451ms step_avg:94.22ms +step:632/1705 train_time:59544ms step_avg:94.22ms +step:633/1705 train_time:59638ms step_avg:94.21ms +step:634/1705 train_time:59731ms step_avg:94.21ms +step:635/1705 train_time:59824ms step_avg:94.21ms +step:636/1705 train_time:59918ms step_avg:94.21ms +step:637/1705 train_time:60014ms step_avg:94.21ms +step:638/1705 train_time:60112ms step_avg:94.22ms +step:639/1705 train_time:60466ms step_avg:94.63ms +step:640/1705 train_time:60573ms step_avg:94.65ms +step:641/1705 train_time:60666ms step_avg:94.64ms +step:642/1705 train_time:60759ms step_avg:94.64ms +step:643/1705 train_time:60853ms step_avg:94.64ms +step:644/1705 train_time:60946ms step_avg:94.64ms +step:645/1705 train_time:61039ms step_avg:94.63ms +step:646/1705 train_time:61132ms step_avg:94.63ms +step:647/1705 train_time:61225ms step_avg:94.63ms +step:648/1705 train_time:61319ms step_avg:94.63ms +step:649/1705 train_time:61415ms step_avg:94.63ms +step:650/1705 train_time:61514ms step_avg:94.64ms +step:651/1705 train_time:61611ms step_avg:94.64ms +step:652/1705 train_time:61705ms step_avg:94.64ms +step:653/1705 train_time:61798ms step_avg:94.64ms +step:654/1705 train_time:61892ms step_avg:94.64ms +step:655/1705 train_time:61986ms step_avg:94.64ms +step:656/1705 train_time:62079ms step_avg:94.63ms +step:657/1705 train_time:62173ms step_avg:94.63ms +step:658/1705 train_time:62266ms step_avg:94.63ms +step:659/1705 train_time:62360ms step_avg:94.63ms +step:660/1705 train_time:62455ms step_avg:94.63ms +step:661/1705 train_time:62551ms step_avg:94.63ms +step:662/1705 train_time:62646ms step_avg:94.63ms +step:663/1705 train_time:62740ms step_avg:94.63ms +step:664/1705 train_time:62834ms step_avg:94.63ms +step:665/1705 train_time:62928ms step_avg:94.63ms +step:666/1705 train_time:63021ms step_avg:94.63ms +step:667/1705 train_time:63115ms step_avg:94.63ms +step:668/1705 train_time:63210ms step_avg:94.63ms +step:669/1705 train_time:63304ms step_avg:94.62ms +step:670/1705 train_time:63398ms step_avg:94.62ms +step:671/1705 train_time:63493ms step_avg:94.62ms +step:672/1705 train_time:63589ms step_avg:94.63ms +step:673/1705 train_time:63685ms step_avg:94.63ms +step:674/1705 train_time:63779ms step_avg:94.63ms +step:675/1705 train_time:63873ms step_avg:94.63ms +step:676/1705 train_time:63968ms step_avg:94.63ms +step:677/1705 train_time:64061ms step_avg:94.63ms +step:678/1705 train_time:64155ms step_avg:94.62ms +step:679/1705 train_time:64249ms step_avg:94.62ms +step:680/1705 train_time:64343ms step_avg:94.62ms +step:681/1705 train_time:64437ms step_avg:94.62ms +step:682/1705 train_time:64533ms step_avg:94.62ms +step:683/1705 train_time:64627ms step_avg:94.62ms +step:684/1705 train_time:64722ms step_avg:94.62ms +step:685/1705 train_time:64816ms step_avg:94.62ms +step:686/1705 train_time:64911ms step_avg:94.62ms +step:687/1705 train_time:65006ms step_avg:94.62ms +step:688/1705 train_time:65100ms step_avg:94.62ms +step:689/1705 train_time:65194ms step_avg:94.62ms +step:690/1705 train_time:65289ms step_avg:94.62ms +step:691/1705 train_time:65383ms step_avg:94.62ms +step:692/1705 train_time:65477ms step_avg:94.62ms +step:693/1705 train_time:65572ms step_avg:94.62ms +step:694/1705 train_time:65666ms step_avg:94.62ms +step:695/1705 train_time:65761ms step_avg:94.62ms +step:696/1705 train_time:65855ms step_avg:94.62ms +step:697/1705 train_time:65950ms step_avg:94.62ms +step:698/1705 train_time:66045ms 
step_avg:94.62ms +step:699/1705 train_time:66138ms step_avg:94.62ms +step:700/1705 train_time:66233ms step_avg:94.62ms +step:701/1705 train_time:66328ms step_avg:94.62ms +step:702/1705 train_time:66422ms step_avg:94.62ms +step:703/1705 train_time:66516ms step_avg:94.62ms +step:704/1705 train_time:66611ms step_avg:94.62ms +step:705/1705 train_time:66704ms step_avg:94.62ms +step:706/1705 train_time:66798ms step_avg:94.62ms +step:707/1705 train_time:66893ms step_avg:94.62ms +step:708/1705 train_time:66988ms step_avg:94.62ms +step:709/1705 train_time:67081ms step_avg:94.61ms +step:710/1705 train_time:67176ms step_avg:94.61ms +step:711/1705 train_time:67270ms step_avg:94.61ms +step:712/1705 train_time:67366ms step_avg:94.61ms +step:713/1705 train_time:67459ms step_avg:94.61ms +step:714/1705 train_time:67554ms step_avg:94.61ms +step:715/1705 train_time:67648ms step_avg:94.61ms +step:716/1705 train_time:67742ms step_avg:94.61ms +step:717/1705 train_time:67836ms step_avg:94.61ms +step:718/1705 train_time:67931ms step_avg:94.61ms +step:719/1705 train_time:68025ms step_avg:94.61ms +step:720/1705 train_time:68118ms step_avg:94.61ms +step:721/1705 train_time:68212ms step_avg:94.61ms +step:722/1705 train_time:68307ms step_avg:94.61ms +step:723/1705 train_time:68401ms step_avg:94.61ms +step:724/1705 train_time:68495ms step_avg:94.61ms +step:725/1705 train_time:68590ms step_avg:94.61ms +step:726/1705 train_time:68683ms step_avg:94.61ms +step:727/1705 train_time:68778ms step_avg:94.60ms +step:728/1705 train_time:68873ms step_avg:94.61ms +step:729/1705 train_time:68967ms step_avg:94.60ms +step:730/1705 train_time:69061ms step_avg:94.60ms +step:731/1705 train_time:69156ms step_avg:94.60ms +step:732/1705 train_time:69251ms step_avg:94.61ms +step:733/1705 train_time:69346ms step_avg:94.61ms +step:734/1705 train_time:69439ms step_avg:94.60ms +step:735/1705 train_time:69534ms step_avg:94.60ms +step:736/1705 train_time:69627ms step_avg:94.60ms +step:737/1705 train_time:69722ms step_avg:94.60ms +step:738/1705 train_time:69816ms step_avg:94.60ms +step:739/1705 train_time:69911ms step_avg:94.60ms +step:740/1705 train_time:70005ms step_avg:94.60ms +step:741/1705 train_time:70099ms step_avg:94.60ms +step:742/1705 train_time:70194ms step_avg:94.60ms +step:743/1705 train_time:70288ms step_avg:94.60ms +step:744/1705 train_time:70383ms step_avg:94.60ms +step:745/1705 train_time:70478ms step_avg:94.60ms +step:746/1705 train_time:70572ms step_avg:94.60ms +step:747/1705 train_time:70666ms step_avg:94.60ms +step:748/1705 train_time:70760ms step_avg:94.60ms +step:749/1705 train_time:70854ms step_avg:94.60ms +step:750/1705 train_time:70949ms step_avg:94.60ms +step:750/1705 val_loss:3.5638 train_time:71043ms step_avg:94.72ms +step:751/1705 train_time:71066ms step_avg:94.63ms +step:752/1705 train_time:71142ms step_avg:94.60ms +step:753/1705 train_time:71239ms step_avg:94.61ms +step:754/1705 train_time:71332ms step_avg:94.61ms +step:755/1705 train_time:71426ms step_avg:94.60ms +step:756/1705 train_time:71520ms step_avg:94.60ms +step:757/1705 train_time:71614ms step_avg:94.60ms +step:758/1705 train_time:71707ms step_avg:94.60ms +step:759/1705 train_time:71800ms step_avg:94.60ms +step:760/1705 train_time:71893ms step_avg:94.60ms +step:761/1705 train_time:71988ms step_avg:94.60ms +step:762/1705 train_time:72086ms step_avg:94.60ms +step:763/1705 train_time:72182ms step_avg:94.60ms +step:764/1705 train_time:72278ms step_avg:94.60ms +step:765/1705 train_time:72372ms step_avg:94.60ms +step:766/1705 train_time:72466ms step_avg:94.60ms 
+step:767/1705 train_time:72561ms step_avg:94.60ms +step:768/1705 train_time:72654ms step_avg:94.60ms +step:769/1705 train_time:72749ms step_avg:94.60ms +step:770/1705 train_time:72843ms step_avg:94.60ms +step:771/1705 train_time:72936ms step_avg:94.60ms +step:772/1705 train_time:73031ms step_avg:94.60ms +step:773/1705 train_time:73127ms step_avg:94.60ms +step:774/1705 train_time:73225ms step_avg:94.61ms +step:775/1705 train_time:73321ms step_avg:94.61ms +step:776/1705 train_time:73415ms step_avg:94.61ms +step:777/1705 train_time:73508ms step_avg:94.61ms +step:778/1705 train_time:73602ms step_avg:94.60ms +step:779/1705 train_time:73695ms step_avg:94.60ms +step:780/1705 train_time:73789ms step_avg:94.60ms +step:781/1705 train_time:73883ms step_avg:94.60ms +step:782/1705 train_time:73977ms step_avg:94.60ms +step:783/1705 train_time:74072ms step_avg:94.60ms +step:784/1705 train_time:74168ms step_avg:94.60ms +step:785/1705 train_time:74263ms step_avg:94.60ms +step:786/1705 train_time:74358ms step_avg:94.60ms +step:787/1705 train_time:74452ms step_avg:94.60ms +step:788/1705 train_time:74546ms step_avg:94.60ms +step:789/1705 train_time:74640ms step_avg:94.60ms +step:790/1705 train_time:74733ms step_avg:94.60ms +step:791/1705 train_time:74827ms step_avg:94.60ms +step:792/1705 train_time:74921ms step_avg:94.60ms +step:793/1705 train_time:75015ms step_avg:94.60ms +step:794/1705 train_time:75110ms step_avg:94.60ms +step:795/1705 train_time:75206ms step_avg:94.60ms +step:796/1705 train_time:75301ms step_avg:94.60ms +step:797/1705 train_time:75396ms step_avg:94.60ms +step:798/1705 train_time:75490ms step_avg:94.60ms +step:799/1705 train_time:75584ms step_avg:94.60ms +step:800/1705 train_time:75679ms step_avg:94.60ms +step:801/1705 train_time:75773ms step_avg:94.60ms +step:802/1705 train_time:75866ms step_avg:94.60ms +step:803/1705 train_time:75961ms step_avg:94.60ms +step:804/1705 train_time:76054ms step_avg:94.60ms +step:805/1705 train_time:76149ms step_avg:94.60ms +step:806/1705 train_time:76244ms step_avg:94.60ms +step:807/1705 train_time:76339ms step_avg:94.60ms +step:808/1705 train_time:76433ms step_avg:94.60ms +step:809/1705 train_time:76528ms step_avg:94.60ms +step:810/1705 train_time:76624ms step_avg:94.60ms +step:811/1705 train_time:76718ms step_avg:94.60ms +step:812/1705 train_time:76812ms step_avg:94.60ms +step:813/1705 train_time:76906ms step_avg:94.60ms +step:814/1705 train_time:77001ms step_avg:94.60ms +step:815/1705 train_time:77095ms step_avg:94.59ms +step:816/1705 train_time:77189ms step_avg:94.59ms +step:817/1705 train_time:77284ms step_avg:94.60ms +step:818/1705 train_time:77378ms step_avg:94.59ms +step:819/1705 train_time:77472ms step_avg:94.59ms +step:820/1705 train_time:77568ms step_avg:94.59ms +step:821/1705 train_time:77663ms step_avg:94.60ms +step:822/1705 train_time:77757ms step_avg:94.59ms +step:823/1705 train_time:77851ms step_avg:94.59ms +step:824/1705 train_time:77946ms step_avg:94.59ms +step:825/1705 train_time:78041ms step_avg:94.59ms +step:826/1705 train_time:78135ms step_avg:94.59ms +step:827/1705 train_time:78230ms step_avg:94.59ms +step:828/1705 train_time:78324ms step_avg:94.59ms +step:829/1705 train_time:78418ms step_avg:94.59ms +step:830/1705 train_time:78512ms step_avg:94.59ms +step:831/1705 train_time:78607ms step_avg:94.59ms +step:832/1705 train_time:78701ms step_avg:94.59ms +step:833/1705 train_time:78795ms step_avg:94.59ms +step:834/1705 train_time:78889ms step_avg:94.59ms +step:835/1705 train_time:78983ms step_avg:94.59ms +step:836/1705 train_time:79079ms 
step_avg:94.59ms +step:837/1705 train_time:79173ms step_avg:94.59ms +step:838/1705 train_time:79267ms step_avg:94.59ms +step:839/1705 train_time:79362ms step_avg:94.59ms +step:840/1705 train_time:79457ms step_avg:94.59ms +step:841/1705 train_time:79551ms step_avg:94.59ms +step:842/1705 train_time:79646ms step_avg:94.59ms +step:843/1705 train_time:79740ms step_avg:94.59ms +step:844/1705 train_time:79834ms step_avg:94.59ms +step:845/1705 train_time:79928ms step_avg:94.59ms +step:846/1705 train_time:80023ms step_avg:94.59ms +step:847/1705 train_time:80118ms step_avg:94.59ms +step:848/1705 train_time:80212ms step_avg:94.59ms +step:849/1705 train_time:80306ms step_avg:94.59ms +step:850/1705 train_time:80401ms step_avg:94.59ms +step:851/1705 train_time:80688ms step_avg:94.82ms +step:852/1705 train_time:80847ms step_avg:94.89ms +step:853/1705 train_time:80940ms step_avg:94.89ms +step:854/1705 train_time:81033ms step_avg:94.89ms +step:855/1705 train_time:81126ms step_avg:94.88ms +step:856/1705 train_time:81219ms step_avg:94.88ms +step:857/1705 train_time:81313ms step_avg:94.88ms +step:858/1705 train_time:81407ms step_avg:94.88ms +step:859/1705 train_time:81501ms step_avg:94.88ms +step:860/1705 train_time:81594ms step_avg:94.88ms +step:861/1705 train_time:81690ms step_avg:94.88ms +step:862/1705 train_time:81789ms step_avg:94.88ms +step:863/1705 train_time:81889ms step_avg:94.89ms +step:864/1705 train_time:81984ms step_avg:94.89ms +step:865/1705 train_time:82078ms step_avg:94.89ms +step:866/1705 train_time:82171ms step_avg:94.89ms +step:867/1705 train_time:82265ms step_avg:94.88ms +step:868/1705 train_time:82359ms step_avg:94.88ms +step:869/1705 train_time:82452ms step_avg:94.88ms +step:870/1705 train_time:82545ms step_avg:94.88ms +step:871/1705 train_time:82641ms step_avg:94.88ms +step:872/1705 train_time:82736ms step_avg:94.88ms +step:873/1705 train_time:82832ms step_avg:94.88ms +step:874/1705 train_time:82927ms step_avg:94.88ms +step:875/1705 train_time:83023ms step_avg:94.88ms +step:875/1705 val_loss:3.5241 train_time:83118ms step_avg:94.99ms +step:876/1705 train_time:83140ms step_avg:94.91ms +step:877/1705 train_time:83221ms step_avg:94.89ms +step:878/1705 train_time:83319ms step_avg:94.90ms +step:879/1705 train_time:83413ms step_avg:94.90ms +step:880/1705 train_time:83506ms step_avg:94.89ms +step:881/1705 train_time:83599ms step_avg:94.89ms +step:882/1705 train_time:83692ms step_avg:94.89ms +step:883/1705 train_time:83786ms step_avg:94.89ms +step:884/1705 train_time:83878ms step_avg:94.89ms +step:885/1705 train_time:83972ms step_avg:94.88ms +step:886/1705 train_time:84067ms step_avg:94.88ms +step:887/1705 train_time:84164ms step_avg:94.89ms +step:888/1705 train_time:84261ms step_avg:94.89ms +step:889/1705 train_time:84359ms step_avg:94.89ms +step:890/1705 train_time:84453ms step_avg:94.89ms +step:891/1705 train_time:84547ms step_avg:94.89ms +step:892/1705 train_time:84641ms step_avg:94.89ms +step:893/1705 train_time:84733ms step_avg:94.89ms +step:894/1705 train_time:84827ms step_avg:94.89ms +step:895/1705 train_time:84921ms step_avg:94.88ms +step:896/1705 train_time:85015ms step_avg:94.88ms +step:897/1705 train_time:85109ms step_avg:94.88ms +step:898/1705 train_time:85205ms step_avg:94.88ms +step:899/1705 train_time:85301ms step_avg:94.88ms +step:900/1705 train_time:85396ms step_avg:94.88ms +step:901/1705 train_time:85490ms step_avg:94.88ms +step:902/1705 train_time:85585ms step_avg:94.88ms +step:903/1705 train_time:85679ms step_avg:94.88ms +step:904/1705 train_time:85772ms step_avg:94.88ms 
+step:905/1705 train_time:85866ms step_avg:94.88ms +step:906/1705 train_time:85960ms step_avg:94.88ms +step:907/1705 train_time:86055ms step_avg:94.88ms +step:908/1705 train_time:86149ms step_avg:94.88ms +step:909/1705 train_time:86245ms step_avg:94.88ms +step:910/1705 train_time:86340ms step_avg:94.88ms +step:911/1705 train_time:86434ms step_avg:94.88ms +step:912/1705 train_time:86529ms step_avg:94.88ms +step:913/1705 train_time:86624ms step_avg:94.88ms +step:914/1705 train_time:86718ms step_avg:94.88ms +step:915/1705 train_time:86812ms step_avg:94.88ms +step:916/1705 train_time:86906ms step_avg:94.88ms +step:917/1705 train_time:87000ms step_avg:94.87ms +step:918/1705 train_time:87094ms step_avg:94.87ms +step:919/1705 train_time:87189ms step_avg:94.87ms +step:920/1705 train_time:87284ms step_avg:94.87ms +step:921/1705 train_time:87379ms step_avg:94.87ms +step:922/1705 train_time:87473ms step_avg:94.87ms +step:923/1705 train_time:87568ms step_avg:94.87ms +step:924/1705 train_time:87662ms step_avg:94.87ms +step:925/1705 train_time:87757ms step_avg:94.87ms +step:926/1705 train_time:87850ms step_avg:94.87ms +step:927/1705 train_time:87945ms step_avg:94.87ms +step:928/1705 train_time:88040ms step_avg:94.87ms +step:929/1705 train_time:88134ms step_avg:94.87ms +step:930/1705 train_time:88229ms step_avg:94.87ms +step:931/1705 train_time:88323ms step_avg:94.87ms +step:932/1705 train_time:88418ms step_avg:94.87ms +step:933/1705 train_time:88512ms step_avg:94.87ms +step:934/1705 train_time:88607ms step_avg:94.87ms +step:935/1705 train_time:88702ms step_avg:94.87ms +step:936/1705 train_time:88796ms step_avg:94.87ms +step:937/1705 train_time:88890ms step_avg:94.87ms +step:938/1705 train_time:88984ms step_avg:94.87ms +step:939/1705 train_time:89079ms step_avg:94.87ms +step:940/1705 train_time:89174ms step_avg:94.87ms +step:941/1705 train_time:89268ms step_avg:94.87ms +step:942/1705 train_time:89363ms step_avg:94.87ms +step:943/1705 train_time:89458ms step_avg:94.86ms +step:944/1705 train_time:89553ms step_avg:94.87ms +step:945/1705 train_time:89649ms step_avg:94.87ms +step:946/1705 train_time:89744ms step_avg:94.87ms +step:947/1705 train_time:89838ms step_avg:94.87ms +step:948/1705 train_time:89931ms step_avg:94.86ms +step:949/1705 train_time:90026ms step_avg:94.86ms +step:950/1705 train_time:90120ms step_avg:94.86ms +step:951/1705 train_time:90214ms step_avg:94.86ms +step:952/1705 train_time:90309ms step_avg:94.86ms +step:953/1705 train_time:90405ms step_avg:94.86ms +step:954/1705 train_time:90500ms step_avg:94.86ms +step:955/1705 train_time:90593ms step_avg:94.86ms +step:956/1705 train_time:90688ms step_avg:94.86ms +step:957/1705 train_time:90784ms step_avg:94.86ms +step:958/1705 train_time:90878ms step_avg:94.86ms +step:959/1705 train_time:90972ms step_avg:94.86ms +step:960/1705 train_time:91067ms step_avg:94.86ms +step:961/1705 train_time:91161ms step_avg:94.86ms +step:962/1705 train_time:91255ms step_avg:94.86ms +step:963/1705 train_time:91350ms step_avg:94.86ms +step:964/1705 train_time:91446ms step_avg:94.86ms +step:965/1705 train_time:91540ms step_avg:94.86ms +step:966/1705 train_time:91633ms step_avg:94.86ms +step:967/1705 train_time:91728ms step_avg:94.86ms +step:968/1705 train_time:91822ms step_avg:94.86ms +step:969/1705 train_time:91916ms step_avg:94.86ms +step:970/1705 train_time:92011ms step_avg:94.86ms +step:971/1705 train_time:92105ms step_avg:94.86ms +step:972/1705 train_time:92200ms step_avg:94.86ms +step:973/1705 train_time:92294ms step_avg:94.85ms +step:974/1705 train_time:92388ms 
step_avg:94.85ms +step:975/1705 train_time:92483ms step_avg:94.85ms +step:976/1705 train_time:92578ms step_avg:94.85ms +step:977/1705 train_time:92672ms step_avg:94.85ms +step:978/1705 train_time:92767ms step_avg:94.85ms +step:979/1705 train_time:92862ms step_avg:94.85ms +step:980/1705 train_time:92956ms step_avg:94.85ms +step:981/1705 train_time:93051ms step_avg:94.85ms +step:982/1705 train_time:93145ms step_avg:94.85ms +step:983/1705 train_time:93240ms step_avg:94.85ms +step:984/1705 train_time:93333ms step_avg:94.85ms +step:985/1705 train_time:93428ms step_avg:94.85ms +step:986/1705 train_time:93523ms step_avg:94.85ms +step:987/1705 train_time:93618ms step_avg:94.85ms +step:988/1705 train_time:93713ms step_avg:94.85ms +step:989/1705 train_time:93809ms step_avg:94.85ms +step:990/1705 train_time:93903ms step_avg:94.85ms +step:991/1705 train_time:93999ms step_avg:94.85ms +step:992/1705 train_time:94093ms step_avg:94.85ms +step:993/1705 train_time:94188ms step_avg:94.85ms +step:994/1705 train_time:94282ms step_avg:94.85ms +step:995/1705 train_time:94377ms step_avg:94.85ms +step:996/1705 train_time:94471ms step_avg:94.85ms +step:997/1705 train_time:94565ms step_avg:94.85ms +step:998/1705 train_time:94660ms step_avg:94.85ms +step:999/1705 train_time:94753ms step_avg:94.85ms +step:1000/1705 train_time:94848ms step_avg:94.85ms +step:1000/1705 val_loss:3.4849 train_time:94943ms step_avg:94.94ms +step:1001/1705 train_time:94967ms step_avg:94.87ms +step:1002/1705 train_time:95041ms step_avg:94.85ms +step:1003/1705 train_time:95142ms step_avg:94.86ms +step:1004/1705 train_time:95238ms step_avg:94.86ms +step:1005/1705 train_time:95332ms step_avg:94.86ms +step:1006/1705 train_time:95425ms step_avg:94.86ms +step:1007/1705 train_time:95519ms step_avg:94.85ms +step:1008/1705 train_time:95612ms step_avg:94.85ms +step:1009/1705 train_time:95706ms step_avg:94.85ms +step:1010/1705 train_time:95799ms step_avg:94.85ms +step:1011/1705 train_time:95893ms step_avg:94.85ms +step:1012/1705 train_time:95989ms step_avg:94.85ms +step:1013/1705 train_time:96085ms step_avg:94.85ms +step:1014/1705 train_time:96182ms step_avg:94.85ms +step:1015/1705 train_time:96279ms step_avg:94.86ms +step:1016/1705 train_time:96373ms step_avg:94.86ms +step:1017/1705 train_time:96467ms step_avg:94.85ms +step:1018/1705 train_time:96560ms step_avg:94.85ms +step:1019/1705 train_time:96653ms step_avg:94.85ms +step:1020/1705 train_time:96746ms step_avg:94.85ms +step:1021/1705 train_time:96840ms step_avg:94.85ms +step:1022/1705 train_time:96935ms step_avg:94.85ms +step:1023/1705 train_time:97030ms step_avg:94.85ms +step:1024/1705 train_time:97126ms step_avg:94.85ms +step:1025/1705 train_time:97223ms step_avg:94.85ms +step:1026/1705 train_time:97319ms step_avg:94.85ms +step:1027/1705 train_time:97413ms step_avg:94.85ms +step:1028/1705 train_time:97507ms step_avg:94.85ms +step:1029/1705 train_time:97601ms step_avg:94.85ms +step:1030/1705 train_time:97695ms step_avg:94.85ms +step:1031/1705 train_time:97788ms step_avg:94.85ms +step:1032/1705 train_time:97883ms step_avg:94.85ms +step:1033/1705 train_time:97978ms step_avg:94.85ms +step:1034/1705 train_time:98074ms step_avg:94.85ms +step:1035/1705 train_time:98168ms step_avg:94.85ms +step:1036/1705 train_time:98264ms step_avg:94.85ms +step:1037/1705 train_time:98359ms step_avg:94.85ms +step:1038/1705 train_time:98453ms step_avg:94.85ms +step:1039/1705 train_time:98547ms step_avg:94.85ms +step:1040/1705 train_time:98641ms step_avg:94.85ms +step:1041/1705 train_time:98735ms step_avg:94.85ms 
+step:1042/1705 train_time:98829ms step_avg:94.85ms +step:1043/1705 train_time:98923ms step_avg:94.84ms +step:1044/1705 train_time:99018ms step_avg:94.84ms +step:1045/1705 train_time:99113ms step_avg:94.85ms +step:1046/1705 train_time:99208ms step_avg:94.85ms +step:1047/1705 train_time:99303ms step_avg:94.85ms +step:1048/1705 train_time:99399ms step_avg:94.85ms +step:1049/1705 train_time:99494ms step_avg:94.85ms +step:1050/1705 train_time:99587ms step_avg:94.85ms +step:1051/1705 train_time:99682ms step_avg:94.84ms +step:1052/1705 train_time:99776ms step_avg:94.84ms +step:1053/1705 train_time:99871ms step_avg:94.84ms +step:1054/1705 train_time:99965ms step_avg:94.84ms +step:1055/1705 train_time:100061ms step_avg:94.84ms +step:1056/1705 train_time:100156ms step_avg:94.84ms +step:1057/1705 train_time:100251ms step_avg:94.84ms +step:1058/1705 train_time:100345ms step_avg:94.84ms +step:1059/1705 train_time:100440ms step_avg:94.84ms +step:1060/1705 train_time:100535ms step_avg:94.84ms +step:1061/1705 train_time:100628ms step_avg:94.84ms +step:1062/1705 train_time:100934ms step_avg:95.04ms +step:1063/1705 train_time:101051ms step_avg:95.06ms +step:1064/1705 train_time:101145ms step_avg:95.06ms +step:1065/1705 train_time:101239ms step_avg:95.06ms +step:1066/1705 train_time:101332ms step_avg:95.06ms +step:1067/1705 train_time:101426ms step_avg:95.06ms +step:1068/1705 train_time:101519ms step_avg:95.06ms +step:1069/1705 train_time:101612ms step_avg:95.05ms +step:1070/1705 train_time:101706ms step_avg:95.05ms +step:1071/1705 train_time:101799ms step_avg:95.05ms +step:1072/1705 train_time:101897ms step_avg:95.05ms +step:1073/1705 train_time:101996ms step_avg:95.06ms +step:1074/1705 train_time:102094ms step_avg:95.06ms +step:1075/1705 train_time:102189ms step_avg:95.06ms +step:1076/1705 train_time:102283ms step_avg:95.06ms +step:1077/1705 train_time:102377ms step_avg:95.06ms +step:1078/1705 train_time:102471ms step_avg:95.06ms +step:1079/1705 train_time:102564ms step_avg:95.05ms +step:1080/1705 train_time:102658ms step_avg:95.05ms +step:1081/1705 train_time:102752ms step_avg:95.05ms +step:1082/1705 train_time:102847ms step_avg:95.05ms +step:1083/1705 train_time:102942ms step_avg:95.05ms +step:1084/1705 train_time:103039ms step_avg:95.05ms +step:1085/1705 train_time:103136ms step_avg:95.06ms +step:1086/1705 train_time:103231ms step_avg:95.06ms +step:1087/1705 train_time:103325ms step_avg:95.06ms +step:1088/1705 train_time:103419ms step_avg:95.05ms +step:1089/1705 train_time:103513ms step_avg:95.05ms +step:1090/1705 train_time:103607ms step_avg:95.05ms +step:1091/1705 train_time:103701ms step_avg:95.05ms +step:1092/1705 train_time:103795ms step_avg:95.05ms +step:1093/1705 train_time:103888ms step_avg:95.05ms +step:1094/1705 train_time:103984ms step_avg:95.05ms +step:1095/1705 train_time:104080ms step_avg:95.05ms +step:1096/1705 train_time:104176ms step_avg:95.05ms +step:1097/1705 train_time:104270ms step_avg:95.05ms +step:1098/1705 train_time:104364ms step_avg:95.05ms +step:1099/1705 train_time:104459ms step_avg:95.05ms +step:1100/1705 train_time:104552ms step_avg:95.05ms +step:1101/1705 train_time:104646ms step_avg:95.05ms +step:1102/1705 train_time:104740ms step_avg:95.05ms +step:1103/1705 train_time:104834ms step_avg:95.04ms +step:1104/1705 train_time:104928ms step_avg:95.04ms +step:1105/1705 train_time:105024ms step_avg:95.04ms +step:1106/1705 train_time:105119ms step_avg:95.04ms +step:1107/1705 train_time:105214ms step_avg:95.04ms +step:1108/1705 train_time:105308ms step_avg:95.04ms +step:1109/1705 
train_time:105403ms step_avg:95.04ms +step:1110/1705 train_time:105497ms step_avg:95.04ms +step:1111/1705 train_time:105591ms step_avg:95.04ms +step:1112/1705 train_time:105685ms step_avg:95.04ms +step:1113/1705 train_time:105780ms step_avg:95.04ms +step:1114/1705 train_time:105874ms step_avg:95.04ms +step:1115/1705 train_time:105968ms step_avg:95.04ms +step:1116/1705 train_time:106064ms step_avg:95.04ms +step:1117/1705 train_time:106159ms step_avg:95.04ms +step:1118/1705 train_time:106254ms step_avg:95.04ms +step:1119/1705 train_time:106348ms step_avg:95.04ms +step:1120/1705 train_time:106442ms step_avg:95.04ms +step:1121/1705 train_time:106537ms step_avg:95.04ms +step:1122/1705 train_time:106632ms step_avg:95.04ms +step:1123/1705 train_time:106726ms step_avg:95.04ms +step:1124/1705 train_time:106821ms step_avg:95.04ms +step:1125/1705 train_time:106915ms step_avg:95.04ms +step:1125/1705 val_loss:3.4376 train_time:107009ms step_avg:95.12ms +step:1126/1705 train_time:107033ms step_avg:95.06ms +step:1127/1705 train_time:107110ms step_avg:95.04ms +step:1128/1705 train_time:107206ms step_avg:95.04ms +step:1129/1705 train_time:107301ms step_avg:95.04ms +step:1130/1705 train_time:107394ms step_avg:95.04ms +step:1131/1705 train_time:107488ms step_avg:95.04ms +step:1132/1705 train_time:107581ms step_avg:95.04ms +step:1133/1705 train_time:107675ms step_avg:95.03ms +step:1134/1705 train_time:107768ms step_avg:95.03ms +step:1135/1705 train_time:107862ms step_avg:95.03ms +step:1136/1705 train_time:107957ms step_avg:95.03ms +step:1137/1705 train_time:108055ms step_avg:95.03ms +step:1138/1705 train_time:108151ms step_avg:95.04ms +step:1139/1705 train_time:108247ms step_avg:95.04ms +step:1140/1705 train_time:108342ms step_avg:95.04ms +step:1141/1705 train_time:108437ms step_avg:95.04ms +step:1142/1705 train_time:108532ms step_avg:95.04ms +step:1143/1705 train_time:108626ms step_avg:95.04ms +step:1144/1705 train_time:108721ms step_avg:95.04ms +step:1145/1705 train_time:108816ms step_avg:95.04ms +step:1146/1705 train_time:108910ms step_avg:95.04ms +step:1147/1705 train_time:109006ms step_avg:95.04ms +step:1148/1705 train_time:109104ms step_avg:95.04ms +step:1149/1705 train_time:109200ms step_avg:95.04ms +step:1150/1705 train_time:109297ms step_avg:95.04ms +step:1151/1705 train_time:109393ms step_avg:95.04ms +step:1152/1705 train_time:109488ms step_avg:95.04ms +step:1153/1705 train_time:109582ms step_avg:95.04ms +step:1154/1705 train_time:109677ms step_avg:95.04ms +step:1155/1705 train_time:109772ms step_avg:95.04ms +step:1156/1705 train_time:109866ms step_avg:95.04ms +step:1157/1705 train_time:109961ms step_avg:95.04ms +step:1158/1705 train_time:110058ms step_avg:95.04ms +step:1159/1705 train_time:110155ms step_avg:95.04ms +step:1160/1705 train_time:110253ms step_avg:95.05ms +step:1161/1705 train_time:110347ms step_avg:95.04ms +step:1162/1705 train_time:110443ms step_avg:95.05ms +step:1163/1705 train_time:110537ms step_avg:95.05ms +step:1164/1705 train_time:110632ms step_avg:95.04ms +step:1165/1705 train_time:110726ms step_avg:95.04ms +step:1166/1705 train_time:110821ms step_avg:95.04ms +step:1167/1705 train_time:110917ms step_avg:95.04ms +step:1168/1705 train_time:111012ms step_avg:95.04ms +step:1169/1705 train_time:111107ms step_avg:95.04ms +step:1170/1705 train_time:111204ms step_avg:95.05ms +step:1171/1705 train_time:111300ms step_avg:95.05ms +step:1172/1705 train_time:111396ms step_avg:95.05ms +step:1173/1705 train_time:111491ms step_avg:95.05ms +step:1174/1705 train_time:111586ms step_avg:95.05ms 
+step:1175/1705 train_time:111681ms step_avg:95.05ms +step:1176/1705 train_time:111776ms step_avg:95.05ms +step:1177/1705 train_time:111871ms step_avg:95.05ms +step:1178/1705 train_time:111965ms step_avg:95.05ms +step:1179/1705 train_time:112061ms step_avg:95.05ms +step:1180/1705 train_time:112157ms step_avg:95.05ms +step:1181/1705 train_time:112253ms step_avg:95.05ms +step:1182/1705 train_time:112349ms step_avg:95.05ms +step:1183/1705 train_time:112444ms step_avg:95.05ms +step:1184/1705 train_time:112540ms step_avg:95.05ms +step:1185/1705 train_time:112636ms step_avg:95.05ms +step:1186/1705 train_time:112730ms step_avg:95.05ms +step:1187/1705 train_time:112825ms step_avg:95.05ms +step:1188/1705 train_time:112920ms step_avg:95.05ms +step:1189/1705 train_time:113015ms step_avg:95.05ms +step:1190/1705 train_time:113111ms step_avg:95.05ms +step:1191/1705 train_time:113206ms step_avg:95.05ms +step:1192/1705 train_time:113302ms step_avg:95.05ms +step:1193/1705 train_time:113399ms step_avg:95.05ms +step:1194/1705 train_time:113494ms step_avg:95.05ms +step:1195/1705 train_time:113589ms step_avg:95.05ms +step:1196/1705 train_time:113684ms step_avg:95.05ms +step:1197/1705 train_time:113779ms step_avg:95.05ms +step:1198/1705 train_time:113874ms step_avg:95.05ms +step:1199/1705 train_time:113969ms step_avg:95.05ms +step:1200/1705 train_time:114064ms step_avg:95.05ms +step:1201/1705 train_time:114161ms step_avg:95.05ms +step:1202/1705 train_time:114258ms step_avg:95.06ms +step:1203/1705 train_time:114353ms step_avg:95.06ms +step:1204/1705 train_time:114448ms step_avg:95.06ms +step:1205/1705 train_time:114544ms step_avg:95.06ms +step:1206/1705 train_time:114640ms step_avg:95.06ms +step:1207/1705 train_time:114735ms step_avg:95.06ms +step:1208/1705 train_time:114831ms step_avg:95.06ms +step:1209/1705 train_time:114925ms step_avg:95.06ms +step:1210/1705 train_time:115021ms step_avg:95.06ms +step:1211/1705 train_time:115116ms step_avg:95.06ms +step:1212/1705 train_time:115212ms step_avg:95.06ms +step:1213/1705 train_time:115306ms step_avg:95.06ms +step:1214/1705 train_time:115403ms step_avg:95.06ms +step:1215/1705 train_time:115498ms step_avg:95.06ms +step:1216/1705 train_time:115594ms step_avg:95.06ms +step:1217/1705 train_time:115689ms step_avg:95.06ms +step:1218/1705 train_time:115784ms step_avg:95.06ms +step:1219/1705 train_time:115880ms step_avg:95.06ms +step:1220/1705 train_time:115974ms step_avg:95.06ms +step:1221/1705 train_time:116071ms step_avg:95.06ms +step:1222/1705 train_time:116165ms step_avg:95.06ms +step:1223/1705 train_time:116260ms step_avg:95.06ms +step:1224/1705 train_time:116357ms step_avg:95.06ms +step:1225/1705 train_time:116452ms step_avg:95.06ms +step:1226/1705 train_time:116547ms step_avg:95.06ms +step:1227/1705 train_time:116642ms step_avg:95.06ms +step:1228/1705 train_time:116738ms step_avg:95.06ms +step:1229/1705 train_time:116832ms step_avg:95.06ms +step:1230/1705 train_time:116927ms step_avg:95.06ms +step:1231/1705 train_time:117023ms step_avg:95.06ms +step:1232/1705 train_time:117119ms step_avg:95.06ms +step:1233/1705 train_time:117215ms step_avg:95.06ms +step:1234/1705 train_time:117310ms step_avg:95.07ms +step:1235/1705 train_time:117406ms step_avg:95.07ms +step:1236/1705 train_time:117501ms step_avg:95.07ms +step:1237/1705 train_time:117596ms step_avg:95.07ms +step:1238/1705 train_time:117691ms step_avg:95.07ms +step:1239/1705 train_time:117785ms step_avg:95.06ms +step:1240/1705 train_time:117881ms step_avg:95.07ms +step:1241/1705 train_time:117975ms step_avg:95.06ms 
+step:1242/1705 train_time:118071ms step_avg:95.07ms +step:1243/1705 train_time:118165ms step_avg:95.06ms +step:1244/1705 train_time:118262ms step_avg:95.07ms +step:1245/1705 train_time:118358ms step_avg:95.07ms +step:1246/1705 train_time:118453ms step_avg:95.07ms +step:1247/1705 train_time:118547ms step_avg:95.07ms +step:1248/1705 train_time:118642ms step_avg:95.07ms +step:1249/1705 train_time:118739ms step_avg:95.07ms +step:1250/1705 train_time:118834ms step_avg:95.07ms +step:1250/1705 val_loss:3.3887 train_time:118929ms step_avg:95.14ms +step:1251/1705 train_time:118952ms step_avg:95.09ms +step:1252/1705 train_time:119033ms step_avg:95.07ms +step:1253/1705 train_time:119131ms step_avg:95.08ms +step:1254/1705 train_time:119224ms step_avg:95.08ms +step:1255/1705 train_time:119319ms step_avg:95.07ms +step:1256/1705 train_time:119413ms step_avg:95.07ms +step:1257/1705 train_time:119507ms step_avg:95.07ms +step:1258/1705 train_time:119601ms step_avg:95.07ms +step:1259/1705 train_time:119695ms step_avg:95.07ms +step:1260/1705 train_time:119788ms step_avg:95.07ms +step:1261/1705 train_time:119885ms step_avg:95.07ms +step:1262/1705 train_time:119984ms step_avg:95.07ms +step:1263/1705 train_time:120081ms step_avg:95.08ms +step:1264/1705 train_time:120177ms step_avg:95.08ms +step:1265/1705 train_time:120273ms step_avg:95.08ms +step:1266/1705 train_time:120367ms step_avg:95.08ms +step:1267/1705 train_time:120461ms step_avg:95.08ms +step:1268/1705 train_time:120555ms step_avg:95.08ms +step:1269/1705 train_time:120650ms step_avg:95.07ms +step:1270/1705 train_time:120743ms step_avg:95.07ms +step:1271/1705 train_time:120839ms step_avg:95.07ms +step:1272/1705 train_time:120936ms step_avg:95.08ms +step:1273/1705 train_time:121033ms step_avg:95.08ms +step:1274/1705 train_time:121405ms step_avg:95.29ms +step:1275/1705 train_time:121488ms step_avg:95.29ms +step:1276/1705 train_time:121582ms step_avg:95.28ms +step:1277/1705 train_time:121676ms step_avg:95.28ms +step:1278/1705 train_time:121770ms step_avg:95.28ms +step:1279/1705 train_time:121864ms step_avg:95.28ms +step:1280/1705 train_time:121959ms step_avg:95.28ms +step:1281/1705 train_time:122053ms step_avg:95.28ms +step:1282/1705 train_time:122147ms step_avg:95.28ms +step:1283/1705 train_time:122240ms step_avg:95.28ms +step:1284/1705 train_time:122343ms step_avg:95.28ms +step:1285/1705 train_time:122442ms step_avg:95.29ms +step:1286/1705 train_time:122538ms step_avg:95.29ms +step:1287/1705 train_time:122633ms step_avg:95.29ms +step:1288/1705 train_time:122727ms step_avg:95.28ms +step:1289/1705 train_time:122821ms step_avg:95.28ms +step:1290/1705 train_time:122916ms step_avg:95.28ms +step:1291/1705 train_time:123010ms step_avg:95.28ms +step:1292/1705 train_time:123104ms step_avg:95.28ms +step:1293/1705 train_time:123198ms step_avg:95.28ms +step:1294/1705 train_time:123295ms step_avg:95.28ms +step:1295/1705 train_time:123394ms step_avg:95.29ms +step:1296/1705 train_time:123491ms step_avg:95.29ms +step:1297/1705 train_time:123586ms step_avg:95.29ms +step:1298/1705 train_time:123681ms step_avg:95.29ms +step:1299/1705 train_time:123776ms step_avg:95.29ms +step:1300/1705 train_time:123871ms step_avg:95.29ms +step:1301/1705 train_time:123966ms step_avg:95.29ms +step:1302/1705 train_time:124060ms step_avg:95.28ms +step:1303/1705 train_time:124154ms step_avg:95.28ms +step:1304/1705 train_time:124250ms step_avg:95.28ms +step:1305/1705 train_time:124347ms step_avg:95.28ms +step:1306/1705 train_time:124442ms step_avg:95.29ms +step:1307/1705 train_time:124538ms 
step_avg:95.29ms +step:1308/1705 train_time:124634ms step_avg:95.29ms +step:1309/1705 train_time:124729ms step_avg:95.29ms +step:1310/1705 train_time:124824ms step_avg:95.29ms +step:1311/1705 train_time:124918ms step_avg:95.28ms +step:1312/1705 train_time:125012ms step_avg:95.28ms +step:1313/1705 train_time:125107ms step_avg:95.28ms +step:1314/1705 train_time:125201ms step_avg:95.28ms +step:1315/1705 train_time:125297ms step_avg:95.28ms +step:1316/1705 train_time:125394ms step_avg:95.28ms +step:1317/1705 train_time:125490ms step_avg:95.28ms +step:1318/1705 train_time:125585ms step_avg:95.28ms +step:1319/1705 train_time:125680ms step_avg:95.28ms +step:1320/1705 train_time:125775ms step_avg:95.28ms +step:1321/1705 train_time:125870ms step_avg:95.28ms +step:1322/1705 train_time:125965ms step_avg:95.28ms +step:1323/1705 train_time:126060ms step_avg:95.28ms +step:1324/1705 train_time:126154ms step_avg:95.28ms +step:1325/1705 train_time:126249ms step_avg:95.28ms +step:1326/1705 train_time:126346ms step_avg:95.28ms +step:1327/1705 train_time:126441ms step_avg:95.28ms +step:1328/1705 train_time:126537ms step_avg:95.28ms +step:1329/1705 train_time:126632ms step_avg:95.28ms +step:1330/1705 train_time:126728ms step_avg:95.28ms +step:1331/1705 train_time:126822ms step_avg:95.28ms +step:1332/1705 train_time:126918ms step_avg:95.28ms +step:1333/1705 train_time:127013ms step_avg:95.28ms +step:1334/1705 train_time:127108ms step_avg:95.28ms +step:1335/1705 train_time:127202ms step_avg:95.28ms +step:1336/1705 train_time:127298ms step_avg:95.28ms +step:1337/1705 train_time:127395ms step_avg:95.28ms +step:1338/1705 train_time:127491ms step_avg:95.28ms +step:1339/1705 train_time:127586ms step_avg:95.28ms +step:1340/1705 train_time:127681ms step_avg:95.28ms +step:1341/1705 train_time:127776ms step_avg:95.28ms +step:1342/1705 train_time:127871ms step_avg:95.28ms +step:1343/1705 train_time:127966ms step_avg:95.28ms +step:1344/1705 train_time:128062ms step_avg:95.28ms +step:1345/1705 train_time:128156ms step_avg:95.28ms +step:1346/1705 train_time:128253ms step_avg:95.28ms +step:1347/1705 train_time:128349ms step_avg:95.29ms +step:1348/1705 train_time:128445ms step_avg:95.29ms +step:1349/1705 train_time:128541ms step_avg:95.29ms +step:1350/1705 train_time:128636ms step_avg:95.29ms +step:1351/1705 train_time:128732ms step_avg:95.29ms +step:1352/1705 train_time:128827ms step_avg:95.29ms +step:1353/1705 train_time:128922ms step_avg:95.29ms +step:1354/1705 train_time:129017ms step_avg:95.29ms +step:1355/1705 train_time:129113ms step_avg:95.29ms +step:1356/1705 train_time:129208ms step_avg:95.29ms +step:1357/1705 train_time:129304ms step_avg:95.29ms +step:1358/1705 train_time:129399ms step_avg:95.29ms +step:1359/1705 train_time:129494ms step_avg:95.29ms +step:1360/1705 train_time:129590ms step_avg:95.29ms +step:1361/1705 train_time:129685ms step_avg:95.29ms +step:1362/1705 train_time:129779ms step_avg:95.29ms +step:1363/1705 train_time:129875ms step_avg:95.29ms +step:1364/1705 train_time:129970ms step_avg:95.29ms +step:1365/1705 train_time:130065ms step_avg:95.29ms +step:1366/1705 train_time:130160ms step_avg:95.29ms +step:1367/1705 train_time:130255ms step_avg:95.29ms +step:1368/1705 train_time:130352ms step_avg:95.29ms +step:1369/1705 train_time:130447ms step_avg:95.29ms +step:1370/1705 train_time:130543ms step_avg:95.29ms +step:1371/1705 train_time:130638ms step_avg:95.29ms +step:1372/1705 train_time:130733ms step_avg:95.29ms +step:1373/1705 train_time:130828ms step_avg:95.29ms +step:1374/1705 train_time:130924ms 
step_avg:95.29ms +step:1375/1705 train_time:131019ms step_avg:95.29ms +step:1375/1705 val_loss:3.3517 train_time:131115ms step_avg:95.36ms +step:1376/1705 train_time:131138ms step_avg:95.30ms +step:1377/1705 train_time:131221ms step_avg:95.29ms +step:1378/1705 train_time:131321ms step_avg:95.30ms +step:1379/1705 train_time:131416ms step_avg:95.30ms +step:1380/1705 train_time:131511ms step_avg:95.30ms +step:1381/1705 train_time:131604ms step_avg:95.30ms +step:1382/1705 train_time:131698ms step_avg:95.30ms +step:1383/1705 train_time:131792ms step_avg:95.29ms +step:1384/1705 train_time:131886ms step_avg:95.29ms +step:1385/1705 train_time:131981ms step_avg:95.29ms +step:1386/1705 train_time:132076ms step_avg:95.29ms +step:1387/1705 train_time:132175ms step_avg:95.30ms +step:1388/1705 train_time:132273ms step_avg:95.30ms +step:1389/1705 train_time:132368ms step_avg:95.30ms +step:1390/1705 train_time:132464ms step_avg:95.30ms +step:1391/1705 train_time:132559ms step_avg:95.30ms +step:1392/1705 train_time:132654ms step_avg:95.30ms +step:1393/1705 train_time:132748ms step_avg:95.30ms +step:1394/1705 train_time:132842ms step_avg:95.30ms +step:1395/1705 train_time:132937ms step_avg:95.30ms +step:1396/1705 train_time:133031ms step_avg:95.29ms +step:1397/1705 train_time:133127ms step_avg:95.30ms +step:1398/1705 train_time:133224ms step_avg:95.30ms +step:1399/1705 train_time:133322ms step_avg:95.30ms +step:1400/1705 train_time:133417ms step_avg:95.30ms +step:1401/1705 train_time:133511ms step_avg:95.30ms +step:1402/1705 train_time:133607ms step_avg:95.30ms +step:1403/1705 train_time:133701ms step_avg:95.30ms +step:1404/1705 train_time:133797ms step_avg:95.30ms +step:1405/1705 train_time:133891ms step_avg:95.30ms +step:1406/1705 train_time:133986ms step_avg:95.30ms +step:1407/1705 train_time:134081ms step_avg:95.30ms +step:1408/1705 train_time:134177ms step_avg:95.30ms +step:1409/1705 train_time:134273ms step_avg:95.30ms +step:1410/1705 train_time:134369ms step_avg:95.30ms +step:1411/1705 train_time:134464ms step_avg:95.30ms +step:1412/1705 train_time:134560ms step_avg:95.30ms +step:1413/1705 train_time:134655ms step_avg:95.30ms +step:1414/1705 train_time:134749ms step_avg:95.30ms +step:1415/1705 train_time:134845ms step_avg:95.30ms +step:1416/1705 train_time:134940ms step_avg:95.30ms +step:1417/1705 train_time:135035ms step_avg:95.30ms +step:1418/1705 train_time:135130ms step_avg:95.30ms +step:1419/1705 train_time:135225ms step_avg:95.30ms +step:1420/1705 train_time:135322ms step_avg:95.30ms +step:1421/1705 train_time:135418ms step_avg:95.30ms +step:1422/1705 train_time:135513ms step_avg:95.30ms +step:1423/1705 train_time:135609ms step_avg:95.30ms +step:1424/1705 train_time:135703ms step_avg:95.30ms +step:1425/1705 train_time:135798ms step_avg:95.30ms +step:1426/1705 train_time:135892ms step_avg:95.30ms +step:1427/1705 train_time:135987ms step_avg:95.30ms +step:1428/1705 train_time:136083ms step_avg:95.30ms +step:1429/1705 train_time:136179ms step_avg:95.30ms +step:1430/1705 train_time:136274ms step_avg:95.30ms +step:1431/1705 train_time:136369ms step_avg:95.30ms +step:1432/1705 train_time:136467ms step_avg:95.30ms +step:1433/1705 train_time:136563ms step_avg:95.30ms +step:1434/1705 train_time:136658ms step_avg:95.30ms +step:1435/1705 train_time:136753ms step_avg:95.30ms +step:1436/1705 train_time:136848ms step_avg:95.30ms +step:1437/1705 train_time:136943ms step_avg:95.30ms +step:1438/1705 train_time:137038ms step_avg:95.30ms +step:1439/1705 train_time:137132ms step_avg:95.30ms +step:1440/1705 
train_time:137228ms step_avg:95.30ms +step:1441/1705 train_time:137323ms step_avg:95.30ms +step:1442/1705 train_time:137420ms step_avg:95.30ms +step:1443/1705 train_time:137515ms step_avg:95.30ms +step:1444/1705 train_time:137610ms step_avg:95.30ms +step:1445/1705 train_time:137706ms step_avg:95.30ms +step:1446/1705 train_time:137802ms step_avg:95.30ms +step:1447/1705 train_time:137898ms step_avg:95.30ms +step:1448/1705 train_time:137992ms step_avg:95.30ms +step:1449/1705 train_time:138087ms step_avg:95.30ms +step:1450/1705 train_time:138183ms step_avg:95.30ms +step:1451/1705 train_time:138279ms step_avg:95.30ms +step:1452/1705 train_time:138375ms step_avg:95.30ms +step:1453/1705 train_time:138469ms step_avg:95.30ms +step:1454/1705 train_time:138565ms step_avg:95.30ms +step:1455/1705 train_time:138661ms step_avg:95.30ms +step:1456/1705 train_time:138757ms step_avg:95.30ms +step:1457/1705 train_time:138852ms step_avg:95.30ms +step:1458/1705 train_time:138947ms step_avg:95.30ms +step:1459/1705 train_time:139043ms step_avg:95.30ms +step:1460/1705 train_time:139138ms step_avg:95.30ms +step:1461/1705 train_time:139233ms step_avg:95.30ms +step:1462/1705 train_time:139328ms step_avg:95.30ms +step:1463/1705 train_time:139424ms step_avg:95.30ms +step:1464/1705 train_time:139519ms step_avg:95.30ms +step:1465/1705 train_time:139617ms step_avg:95.30ms +step:1466/1705 train_time:139709ms step_avg:95.30ms +step:1467/1705 train_time:139805ms step_avg:95.30ms +step:1468/1705 train_time:139901ms step_avg:95.30ms +step:1469/1705 train_time:139997ms step_avg:95.30ms +step:1470/1705 train_time:140092ms step_avg:95.30ms +step:1471/1705 train_time:140186ms step_avg:95.30ms +step:1472/1705 train_time:140283ms step_avg:95.30ms +step:1473/1705 train_time:140378ms step_avg:95.30ms +step:1474/1705 train_time:140474ms step_avg:95.30ms +step:1475/1705 train_time:140569ms step_avg:95.30ms +step:1476/1705 train_time:140664ms step_avg:95.30ms +step:1477/1705 train_time:140760ms step_avg:95.30ms +step:1478/1705 train_time:140855ms step_avg:95.30ms +step:1479/1705 train_time:140950ms step_avg:95.30ms +step:1480/1705 train_time:141047ms step_avg:95.30ms +step:1481/1705 train_time:141142ms step_avg:95.30ms +step:1482/1705 train_time:141238ms step_avg:95.30ms +step:1483/1705 train_time:141333ms step_avg:95.30ms +step:1484/1705 train_time:141428ms step_avg:95.30ms +step:1485/1705 train_time:141667ms step_avg:95.40ms +step:1486/1705 train_time:141898ms step_avg:95.49ms +step:1487/1705 train_time:141992ms step_avg:95.49ms +step:1488/1705 train_time:142086ms step_avg:95.49ms +step:1489/1705 train_time:142181ms step_avg:95.49ms +step:1490/1705 train_time:142275ms step_avg:95.49ms +step:1491/1705 train_time:142370ms step_avg:95.49ms +step:1492/1705 train_time:142464ms step_avg:95.49ms +step:1493/1705 train_time:142559ms step_avg:95.48ms +step:1494/1705 train_time:142653ms step_avg:95.48ms +step:1495/1705 train_time:142752ms step_avg:95.49ms +step:1496/1705 train_time:142851ms step_avg:95.49ms +step:1497/1705 train_time:142948ms step_avg:95.49ms +step:1498/1705 train_time:143043ms step_avg:95.49ms +step:1499/1705 train_time:143138ms step_avg:95.49ms +step:1500/1705 train_time:143232ms step_avg:95.49ms +step:1500/1705 val_loss:3.3196 train_time:143326ms step_avg:95.55ms +step:1501/1705 train_time:143349ms step_avg:95.50ms +step:1502/1705 train_time:143430ms step_avg:95.49ms +step:1503/1705 train_time:143527ms step_avg:95.49ms +step:1504/1705 train_time:143622ms step_avg:95.49ms +step:1505/1705 train_time:143716ms step_avg:95.49ms 
+step:1506/1705 train_time:143810ms step_avg:95.49ms +step:1507/1705 train_time:143905ms step_avg:95.49ms +step:1508/1705 train_time:144000ms step_avg:95.49ms +step:1509/1705 train_time:144094ms step_avg:95.49ms +step:1510/1705 train_time:144188ms step_avg:95.49ms +step:1511/1705 train_time:144284ms step_avg:95.49ms +step:1512/1705 train_time:144382ms step_avg:95.49ms +step:1513/1705 train_time:144480ms step_avg:95.49ms +step:1514/1705 train_time:144578ms step_avg:95.49ms +step:1515/1705 train_time:144674ms step_avg:95.49ms +step:1516/1705 train_time:144768ms step_avg:95.49ms +step:1517/1705 train_time:144862ms step_avg:95.49ms +step:1518/1705 train_time:144956ms step_avg:95.49ms +step:1519/1705 train_time:145051ms step_avg:95.49ms +step:1520/1705 train_time:145146ms step_avg:95.49ms +step:1521/1705 train_time:145240ms step_avg:95.49ms +step:1522/1705 train_time:145337ms step_avg:95.49ms +step:1523/1705 train_time:145433ms step_avg:95.49ms +step:1524/1705 train_time:145529ms step_avg:95.49ms +step:1525/1705 train_time:145625ms step_avg:95.49ms +step:1526/1705 train_time:145720ms step_avg:95.49ms +step:1527/1705 train_time:145816ms step_avg:95.49ms +step:1528/1705 train_time:145911ms step_avg:95.49ms +step:1529/1705 train_time:146005ms step_avg:95.49ms +step:1530/1705 train_time:146099ms step_avg:95.49ms +step:1531/1705 train_time:146195ms step_avg:95.49ms +step:1532/1705 train_time:146290ms step_avg:95.49ms +step:1533/1705 train_time:146386ms step_avg:95.49ms +step:1534/1705 train_time:146482ms step_avg:95.49ms +step:1535/1705 train_time:146579ms step_avg:95.49ms +step:1536/1705 train_time:146675ms step_avg:95.49ms +step:1537/1705 train_time:146770ms step_avg:95.49ms +step:1538/1705 train_time:146865ms step_avg:95.49ms +step:1539/1705 train_time:146959ms step_avg:95.49ms +step:1540/1705 train_time:147055ms step_avg:95.49ms +step:1541/1705 train_time:147150ms step_avg:95.49ms +step:1542/1705 train_time:147246ms step_avg:95.49ms +step:1543/1705 train_time:147341ms step_avg:95.49ms +step:1544/1705 train_time:147437ms step_avg:95.49ms +step:1545/1705 train_time:147533ms step_avg:95.49ms +step:1546/1705 train_time:147630ms step_avg:95.49ms +step:1547/1705 train_time:147725ms step_avg:95.49ms +step:1548/1705 train_time:147820ms step_avg:95.49ms +step:1549/1705 train_time:147915ms step_avg:95.49ms +step:1550/1705 train_time:148010ms step_avg:95.49ms +step:1551/1705 train_time:148105ms step_avg:95.49ms +step:1552/1705 train_time:148200ms step_avg:95.49ms +step:1553/1705 train_time:148296ms step_avg:95.49ms +step:1554/1705 train_time:148391ms step_avg:95.49ms +step:1555/1705 train_time:148486ms step_avg:95.49ms +step:1556/1705 train_time:148582ms step_avg:95.49ms +step:1557/1705 train_time:148678ms step_avg:95.49ms +step:1558/1705 train_time:148774ms step_avg:95.49ms +step:1559/1705 train_time:148868ms step_avg:95.49ms +step:1560/1705 train_time:148963ms step_avg:95.49ms +step:1561/1705 train_time:149058ms step_avg:95.49ms +step:1562/1705 train_time:149154ms step_avg:95.49ms +step:1563/1705 train_time:149249ms step_avg:95.49ms +step:1564/1705 train_time:149344ms step_avg:95.49ms +step:1565/1705 train_time:149439ms step_avg:95.49ms +step:1566/1705 train_time:149535ms step_avg:95.49ms +step:1567/1705 train_time:149631ms step_avg:95.49ms +step:1568/1705 train_time:149727ms step_avg:95.49ms +step:1569/1705 train_time:149821ms step_avg:95.49ms +step:1570/1705 train_time:149917ms step_avg:95.49ms +step:1571/1705 train_time:150013ms step_avg:95.49ms +step:1572/1705 train_time:150108ms step_avg:95.49ms 
+step:1573/1705 train_time:150203ms step_avg:95.49ms +step:1574/1705 train_time:150298ms step_avg:95.49ms +step:1575/1705 train_time:150395ms step_avg:95.49ms +step:1576/1705 train_time:150490ms step_avg:95.49ms +step:1577/1705 train_time:150586ms step_avg:95.49ms +step:1578/1705 train_time:150681ms step_avg:95.49ms +step:1579/1705 train_time:150777ms step_avg:95.49ms +step:1580/1705 train_time:150872ms step_avg:95.49ms +step:1581/1705 train_time:150967ms step_avg:95.49ms +step:1582/1705 train_time:151062ms step_avg:95.49ms +step:1583/1705 train_time:151158ms step_avg:95.49ms +step:1584/1705 train_time:151253ms step_avg:95.49ms +step:1585/1705 train_time:151348ms step_avg:95.49ms +step:1586/1705 train_time:151443ms step_avg:95.49ms +step:1587/1705 train_time:151539ms step_avg:95.49ms +step:1588/1705 train_time:151634ms step_avg:95.49ms +step:1589/1705 train_time:151730ms step_avg:95.49ms +step:1590/1705 train_time:151826ms step_avg:95.49ms +step:1591/1705 train_time:151921ms step_avg:95.49ms +step:1592/1705 train_time:152016ms step_avg:95.49ms +step:1593/1705 train_time:152112ms step_avg:95.49ms +step:1594/1705 train_time:152207ms step_avg:95.49ms +step:1595/1705 train_time:152301ms step_avg:95.49ms +step:1596/1705 train_time:152398ms step_avg:95.49ms +step:1597/1705 train_time:152494ms step_avg:95.49ms +step:1598/1705 train_time:152590ms step_avg:95.49ms +step:1599/1705 train_time:152685ms step_avg:95.49ms +step:1600/1705 train_time:152780ms step_avg:95.49ms +step:1601/1705 train_time:152874ms step_avg:95.49ms +step:1602/1705 train_time:152970ms step_avg:95.49ms +step:1603/1705 train_time:153065ms step_avg:95.49ms +step:1604/1705 train_time:153160ms step_avg:95.49ms +step:1605/1705 train_time:153255ms step_avg:95.49ms +step:1606/1705 train_time:153350ms step_avg:95.49ms +step:1607/1705 train_time:153446ms step_avg:95.49ms +step:1608/1705 train_time:153541ms step_avg:95.49ms +step:1609/1705 train_time:153637ms step_avg:95.49ms +step:1610/1705 train_time:153733ms step_avg:95.49ms +step:1611/1705 train_time:153829ms step_avg:95.49ms +step:1612/1705 train_time:153923ms step_avg:95.49ms +step:1613/1705 train_time:154018ms step_avg:95.49ms +step:1614/1705 train_time:154115ms step_avg:95.49ms +step:1615/1705 train_time:154210ms step_avg:95.49ms +step:1616/1705 train_time:154306ms step_avg:95.49ms +step:1617/1705 train_time:154401ms step_avg:95.49ms +step:1618/1705 train_time:154497ms step_avg:95.49ms +step:1619/1705 train_time:154593ms step_avg:95.49ms +step:1620/1705 train_time:154688ms step_avg:95.49ms +step:1621/1705 train_time:154783ms step_avg:95.49ms +step:1622/1705 train_time:154878ms step_avg:95.49ms +step:1623/1705 train_time:154974ms step_avg:95.49ms +step:1624/1705 train_time:155069ms step_avg:95.49ms +step:1625/1705 train_time:155166ms step_avg:95.49ms +step:1625/1705 val_loss:3.2917 train_time:155260ms step_avg:95.54ms +step:1626/1705 train_time:155285ms step_avg:95.50ms +step:1627/1705 train_time:155363ms step_avg:95.49ms +step:1628/1705 train_time:155462ms step_avg:95.49ms +step:1629/1705 train_time:155558ms step_avg:95.49ms +step:1630/1705 train_time:155652ms step_avg:95.49ms +step:1631/1705 train_time:155746ms step_avg:95.49ms +step:1632/1705 train_time:155841ms step_avg:95.49ms +step:1633/1705 train_time:155936ms step_avg:95.49ms +step:1634/1705 train_time:156029ms step_avg:95.49ms +step:1635/1705 train_time:156123ms step_avg:95.49ms +step:1636/1705 train_time:156219ms step_avg:95.49ms +step:1637/1705 train_time:156316ms step_avg:95.49ms +step:1638/1705 train_time:156412ms 
step_avg:95.49ms +step:1639/1705 train_time:156509ms step_avg:95.49ms +step:1640/1705 train_time:156606ms step_avg:95.49ms +step:1641/1705 train_time:156703ms step_avg:95.49ms +step:1642/1705 train_time:156799ms step_avg:95.49ms +step:1643/1705 train_time:156893ms step_avg:95.49ms +step:1644/1705 train_time:156987ms step_avg:95.49ms +step:1645/1705 train_time:157082ms step_avg:95.49ms +step:1646/1705 train_time:157177ms step_avg:95.49ms +step:1647/1705 train_time:157272ms step_avg:95.49ms +step:1648/1705 train_time:157369ms step_avg:95.49ms +step:1649/1705 train_time:157465ms step_avg:95.49ms +step:1650/1705 train_time:157562ms step_avg:95.49ms +step:1651/1705 train_time:157657ms step_avg:95.49ms +step:1652/1705 train_time:157752ms step_avg:95.49ms +step:1653/1705 train_time:157847ms step_avg:95.49ms +step:1654/1705 train_time:157943ms step_avg:95.49ms +step:1655/1705 train_time:158037ms step_avg:95.49ms +step:1656/1705 train_time:158131ms step_avg:95.49ms +step:1657/1705 train_time:158227ms step_avg:95.49ms +step:1658/1705 train_time:158324ms step_avg:95.49ms +step:1659/1705 train_time:158421ms step_avg:95.49ms +step:1660/1705 train_time:158517ms step_avg:95.49ms +step:1661/1705 train_time:158612ms step_avg:95.49ms +step:1662/1705 train_time:158707ms step_avg:95.49ms +step:1663/1705 train_time:158803ms step_avg:95.49ms +step:1664/1705 train_time:158899ms step_avg:95.49ms +step:1665/1705 train_time:158993ms step_avg:95.49ms +step:1666/1705 train_time:159088ms step_avg:95.49ms +step:1667/1705 train_time:159183ms step_avg:95.49ms +step:1668/1705 train_time:159279ms step_avg:95.49ms +step:1669/1705 train_time:159374ms step_avg:95.49ms +step:1670/1705 train_time:159469ms step_avg:95.49ms +step:1671/1705 train_time:159566ms step_avg:95.49ms +step:1672/1705 train_time:159662ms step_avg:95.49ms +step:1673/1705 train_time:159758ms step_avg:95.49ms +step:1674/1705 train_time:159854ms step_avg:95.49ms +step:1675/1705 train_time:159949ms step_avg:95.49ms +step:1676/1705 train_time:160044ms step_avg:95.49ms +step:1677/1705 train_time:160139ms step_avg:95.49ms +step:1678/1705 train_time:160234ms step_avg:95.49ms +step:1679/1705 train_time:160329ms step_avg:95.49ms +step:1680/1705 train_time:160425ms step_avg:95.49ms +step:1681/1705 train_time:160522ms step_avg:95.49ms +step:1682/1705 train_time:160617ms step_avg:95.49ms +step:1683/1705 train_time:160713ms step_avg:95.49ms +step:1684/1705 train_time:160808ms step_avg:95.49ms +step:1685/1705 train_time:160904ms step_avg:95.49ms +step:1686/1705 train_time:160998ms step_avg:95.49ms +step:1687/1705 train_time:161093ms step_avg:95.49ms +step:1688/1705 train_time:161188ms step_avg:95.49ms +step:1689/1705 train_time:161283ms step_avg:95.49ms +step:1690/1705 train_time:161379ms step_avg:95.49ms +step:1691/1705 train_time:161474ms step_avg:95.49ms +step:1692/1705 train_time:161570ms step_avg:95.49ms +step:1693/1705 train_time:161665ms step_avg:95.49ms +step:1694/1705 train_time:161762ms step_avg:95.49ms +step:1695/1705 train_time:161859ms step_avg:95.49ms +step:1696/1705 train_time:161954ms step_avg:95.49ms +step:1697/1705 train_time:162048ms step_avg:95.49ms +step:1698/1705 train_time:162285ms step_avg:95.57ms +step:1699/1705 train_time:162486ms step_avg:95.64ms +step:1700/1705 train_time:162580ms step_avg:95.64ms +step:1701/1705 train_time:162673ms step_avg:95.63ms +step:1702/1705 train_time:162768ms step_avg:95.63ms +step:1703/1705 train_time:162862ms step_avg:95.63ms +step:1704/1705 train_time:162956ms step_avg:95.63ms +step:1705/1705 train_time:163050ms 
step_avg:95.63ms
+step:1705/1705 val_loss:3.2775 train_time:163144ms step_avg:95.69ms
+peak memory allocated: 34489 MiB reserved: 49496 MiB
diff --git a/records/050925_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt b/records/050925_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt
new file mode 100644
index 000000000..4d203736f
--- /dev/null
+++ b/records/050925_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt
@@ -0,0 +1,2853 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import numpy as np
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+import torch._dynamo as dynamo
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
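
The forward op above follows a quantize / matmul / de-scale pattern: both operands are divided by per-tensor scales so their values fit float8_e4m3fn's narrow range, and torch._scaled_mm multiplies those scales back in during the matmul. A minimal pure-PyTorch sketch of the same round trip, not part of the patch (scaled_mm_reference is a made-up name, and the explicit float32 de-scale stands in for what _scaled_mm fuses on FP8-capable GPUs):

import torch

def scaled_mm_reference(x: torch.Tensor, w: torch.Tensor, x_s: float, w_s: float) -> torch.Tensor:
    # Quantize: divide by a per-tensor scale, then round to float8_e4m3fn,
    # mirroring what mm_op does before calling torch._scaled_mm.
    x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
    w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
    # De-scale and multiply in float32; _scaled_mm fuses this step via
    # its scale_a/scale_b arguments and accumulates on tensor cores.
    return (x_f8.float() * x_s) @ (w_f8.float() * w_s).T

x = torch.randn(64, 128)
w = torch.randn(256, 128)
out = scaled_mm_reference(x, w, x_s=2.0, w_s=2.0)
ref = x @ w.T
# The only discrepancy vs. the float32 product is float8 rounding error.
print((out - ref).abs().max() / ref.abs().max())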
grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
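+
+ A reference (unfused) sketch of the quintic Newton-Schulz step that the Triton kernels
+ above fuse, with coefficients (a, b, c) = (3.4445, -4.7750, 2.0315):
+
+ X = G / (G.norm() + 1e-7) # Frobenius norm bounds the spectral norm, so ||X||_2 <= 1
+ for _ in range(5):
+ A = X @ X.mT # ns_line_1
+ B = b * A + c * A @ A # ns_line_2
+ X = a * X + B @ X # ns_line_3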
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
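+ # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 128 * 393)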
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
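+ # (BOS_ID == 50256 is the GPT-2 <|endoftext|> token; the FineWeb .bin shards use it as the start-of-document marker)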
self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
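+ # (recovery sketch: the partially assembled batch is discarded and rebuilt from the start of the new shard, so no sequence is ever spliced across shard boundaries)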
+ tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "new_num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1705 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
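+# e.g. with the full 8-GPU setup world_size == 8 and grad_accum_steps == 1 (one micro-batch per step); the same config on 4 GPUs would run grad_accum_steps == 2 micro-batches per optimizer step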
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
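+# (restoring the saved state above means the timed run starts from the same initialization as an un-warmed run; only the compiled graphs and autotuned Triton configs stay warm)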
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:28:29 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 66237 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 66238 C /usr/bin/python3 610MiB | +| 0 N/A N/A 66239 C /usr/bin/python3 610MiB | +| 0 N/A N/A 66240 C /usr/bin/python3 610MiB | +| 0 N/A N/A 66241 C /usr/bin/python3 610MiB | +| 0 N/A N/A 66242 C /usr/bin/python3 610MiB | +| 0 N/A N/A 66243 C /usr/bin/python3 610MiB | +| 0 N/A N/A 66244 C /usr/bin/python3 610MiB | +| 1 N/A N/A 66238 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 66239 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 66240 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 66241 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 66242 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 66243 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 66244 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1705 train_time:382ms step_avg:382.22ms +step:2/1705 train_time:404ms step_avg:201.94ms +step:3/1705 train_time:472ms step_avg:157.27ms +step:4/1705 train_time:563ms step_avg:140.71ms +step:5/1705 
train_time:655ms step_avg:130.90ms +step:6/1705 train_time:747ms step_avg:124.47ms +step:7/1705 train_time:839ms step_avg:119.87ms +step:8/1705 train_time:931ms step_avg:116.36ms +step:9/1705 train_time:1023ms step_avg:113.66ms +step:10/1705 train_time:1115ms step_avg:111.53ms +step:11/1705 train_time:1208ms step_avg:109.78ms +step:12/1705 train_time:1303ms step_avg:108.57ms +step:13/1705 train_time:1401ms step_avg:107.80ms +step:14/1705 train_time:1496ms step_avg:106.84ms +step:15/1705 train_time:1588ms step_avg:105.87ms +step:16/1705 train_time:1680ms step_avg:105.03ms +step:17/1705 train_time:1772ms step_avg:104.26ms +step:18/1705 train_time:1865ms step_avg:103.60ms +step:19/1705 train_time:1957ms step_avg:103.00ms +step:20/1705 train_time:2049ms step_avg:102.47ms +step:21/1705 train_time:2142ms step_avg:102.00ms +step:22/1705 train_time:2235ms step_avg:101.57ms +step:23/1705 train_time:2329ms step_avg:101.28ms +step:24/1705 train_time:2426ms step_avg:101.10ms +step:25/1705 train_time:2521ms step_avg:100.84ms +step:26/1705 train_time:2614ms step_avg:100.55ms +step:27/1705 train_time:2707ms step_avg:100.27ms +step:28/1705 train_time:2801ms step_avg:100.03ms +step:29/1705 train_time:2894ms step_avg:99.78ms +step:30/1705 train_time:2986ms step_avg:99.55ms +step:31/1705 train_time:3079ms step_avg:99.32ms +step:32/1705 train_time:3171ms step_avg:99.10ms +step:33/1705 train_time:3266ms step_avg:98.96ms +step:34/1705 train_time:3360ms step_avg:98.82ms +step:35/1705 train_time:3453ms step_avg:98.67ms +step:36/1705 train_time:3548ms step_avg:98.56ms +step:37/1705 train_time:3642ms step_avg:98.42ms +step:38/1705 train_time:3735ms step_avg:98.28ms +step:39/1705 train_time:3827ms step_avg:98.13ms +step:40/1705 train_time:3920ms step_avg:98.00ms +step:41/1705 train_time:4013ms step_avg:97.88ms +step:42/1705 train_time:4106ms step_avg:97.75ms +step:43/1705 train_time:4198ms step_avg:97.63ms +step:44/1705 train_time:4291ms step_avg:97.53ms +step:45/1705 train_time:4385ms step_avg:97.45ms +step:46/1705 train_time:4479ms step_avg:97.38ms +step:47/1705 train_time:4572ms step_avg:97.29ms +step:48/1705 train_time:4666ms step_avg:97.21ms +step:49/1705 train_time:4759ms step_avg:97.13ms +step:50/1705 train_time:4852ms step_avg:97.04ms +step:51/1705 train_time:4945ms step_avg:96.97ms +step:52/1705 train_time:5039ms step_avg:96.90ms +step:53/1705 train_time:5131ms step_avg:96.82ms +step:54/1705 train_time:5225ms step_avg:96.75ms +step:55/1705 train_time:5318ms step_avg:96.69ms +step:56/1705 train_time:5412ms step_avg:96.64ms +step:57/1705 train_time:5505ms step_avg:96.59ms +step:58/1705 train_time:5599ms step_avg:96.53ms +step:59/1705 train_time:5692ms step_avg:96.47ms +step:60/1705 train_time:5785ms step_avg:96.42ms +step:61/1705 train_time:5878ms step_avg:96.37ms +step:62/1705 train_time:5971ms step_avg:96.30ms +step:63/1705 train_time:6063ms step_avg:96.24ms +step:64/1705 train_time:6156ms step_avg:96.19ms +step:65/1705 train_time:6249ms step_avg:96.14ms +step:66/1705 train_time:6343ms step_avg:96.10ms +step:67/1705 train_time:6436ms step_avg:96.06ms +step:68/1705 train_time:6529ms step_avg:96.02ms +step:69/1705 train_time:6623ms step_avg:95.98ms +step:70/1705 train_time:6717ms step_avg:95.96ms +step:71/1705 train_time:6810ms step_avg:95.92ms +step:72/1705 train_time:6904ms step_avg:95.89ms +step:73/1705 train_time:6996ms step_avg:95.84ms +step:74/1705 train_time:7089ms step_avg:95.80ms +step:75/1705 train_time:7183ms step_avg:95.77ms +step:76/1705 train_time:7275ms step_avg:95.73ms +step:77/1705 
train_time:7368ms step_avg:95.69ms +step:78/1705 train_time:7462ms step_avg:95.67ms +step:79/1705 train_time:7555ms step_avg:95.63ms +step:80/1705 train_time:7649ms step_avg:95.61ms +step:81/1705 train_time:7743ms step_avg:95.59ms +step:82/1705 train_time:7836ms step_avg:95.56ms +step:83/1705 train_time:7928ms step_avg:95.52ms +step:84/1705 train_time:8021ms step_avg:95.49ms +step:85/1705 train_time:8114ms step_avg:95.46ms +step:86/1705 train_time:8207ms step_avg:95.43ms +step:87/1705 train_time:8300ms step_avg:95.40ms +step:88/1705 train_time:8392ms step_avg:95.37ms +step:89/1705 train_time:8485ms step_avg:95.34ms +step:90/1705 train_time:8579ms step_avg:95.32ms +step:91/1705 train_time:8672ms step_avg:95.29ms +step:92/1705 train_time:8765ms step_avg:95.28ms +step:93/1705 train_time:8859ms step_avg:95.25ms +step:94/1705 train_time:8951ms step_avg:95.22ms +step:95/1705 train_time:9044ms step_avg:95.20ms +step:96/1705 train_time:9136ms step_avg:95.17ms +step:97/1705 train_time:9229ms step_avg:95.14ms +step:98/1705 train_time:9322ms step_avg:95.13ms +step:99/1705 train_time:9416ms step_avg:95.11ms +step:100/1705 train_time:9509ms step_avg:95.09ms +step:101/1705 train_time:9602ms step_avg:95.07ms +step:102/1705 train_time:9695ms step_avg:95.05ms +step:103/1705 train_time:9788ms step_avg:95.03ms +step:104/1705 train_time:9883ms step_avg:95.03ms +step:105/1705 train_time:9977ms step_avg:95.02ms +step:106/1705 train_time:10070ms step_avg:95.00ms +step:107/1705 train_time:10162ms step_avg:94.98ms +step:108/1705 train_time:10255ms step_avg:94.95ms +step:109/1705 train_time:10348ms step_avg:94.94ms +step:110/1705 train_time:10441ms step_avg:94.92ms +step:111/1705 train_time:10534ms step_avg:94.90ms +step:112/1705 train_time:10627ms step_avg:94.88ms +step:113/1705 train_time:10721ms step_avg:94.88ms +step:114/1705 train_time:10814ms step_avg:94.86ms +step:115/1705 train_time:10908ms step_avg:94.85ms +step:116/1705 train_time:11002ms step_avg:94.84ms +step:117/1705 train_time:11094ms step_avg:94.82ms +step:118/1705 train_time:11188ms step_avg:94.81ms +step:119/1705 train_time:11281ms step_avg:94.80ms +step:120/1705 train_time:11374ms step_avg:94.78ms +step:121/1705 train_time:11467ms step_avg:94.77ms +step:122/1705 train_time:11560ms step_avg:94.75ms +step:123/1705 train_time:11652ms step_avg:94.73ms +step:124/1705 train_time:11746ms step_avg:94.73ms +step:125/1705 train_time:11839ms step_avg:94.71ms +step:125/1705 val_loss:4.3144 train_time:11932ms step_avg:95.45ms +step:126/1705 train_time:11955ms step_avg:94.88ms +step:127/1705 train_time:12029ms step_avg:94.72ms +step:128/1705 train_time:12131ms step_avg:94.77ms +step:129/1705 train_time:12226ms step_avg:94.78ms +step:130/1705 train_time:12320ms step_avg:94.77ms +step:131/1705 train_time:12412ms step_avg:94.75ms +step:132/1705 train_time:12504ms step_avg:94.72ms +step:133/1705 train_time:12596ms step_avg:94.71ms +step:134/1705 train_time:12687ms step_avg:94.68ms +step:135/1705 train_time:12779ms step_avg:94.66ms +step:136/1705 train_time:12872ms step_avg:94.64ms +step:137/1705 train_time:12963ms step_avg:94.62ms +step:138/1705 train_time:13058ms step_avg:94.62ms +step:139/1705 train_time:13153ms step_avg:94.62ms +step:140/1705 train_time:13247ms step_avg:94.62ms +step:141/1705 train_time:13340ms step_avg:94.61ms +step:142/1705 train_time:13433ms step_avg:94.60ms +step:143/1705 train_time:13526ms step_avg:94.59ms +step:144/1705 train_time:13619ms step_avg:94.57ms +step:145/1705 train_time:13711ms step_avg:94.56ms +step:146/1705 train_time:13803ms 
step_avg:94.54ms +step:147/1705 train_time:13895ms step_avg:94.53ms +step:148/1705 train_time:13987ms step_avg:94.51ms +step:149/1705 train_time:14080ms step_avg:94.50ms +step:150/1705 train_time:14174ms step_avg:94.49ms +step:151/1705 train_time:14268ms step_avg:94.49ms +step:152/1705 train_time:14361ms step_avg:94.48ms +step:153/1705 train_time:14455ms step_avg:94.47ms +step:154/1705 train_time:14547ms step_avg:94.46ms +step:155/1705 train_time:14640ms step_avg:94.45ms +step:156/1705 train_time:14732ms step_avg:94.44ms +step:157/1705 train_time:14824ms step_avg:94.42ms +step:158/1705 train_time:14916ms step_avg:94.41ms +step:159/1705 train_time:15009ms step_avg:94.40ms +step:160/1705 train_time:15102ms step_avg:94.39ms +step:161/1705 train_time:15196ms step_avg:94.38ms +step:162/1705 train_time:15289ms step_avg:94.38ms +step:163/1705 train_time:15382ms step_avg:94.37ms +step:164/1705 train_time:15476ms step_avg:94.36ms +step:165/1705 train_time:15569ms step_avg:94.35ms +step:166/1705 train_time:15661ms step_avg:94.34ms +step:167/1705 train_time:15754ms step_avg:94.33ms +step:168/1705 train_time:15846ms step_avg:94.32ms +step:169/1705 train_time:15939ms step_avg:94.31ms +step:170/1705 train_time:16032ms step_avg:94.31ms +step:171/1705 train_time:16125ms step_avg:94.30ms +step:172/1705 train_time:16217ms step_avg:94.29ms +step:173/1705 train_time:16311ms step_avg:94.28ms +step:174/1705 train_time:16404ms step_avg:94.27ms +step:175/1705 train_time:16497ms step_avg:94.27ms +step:176/1705 train_time:16589ms step_avg:94.26ms +step:177/1705 train_time:16682ms step_avg:94.25ms +step:178/1705 train_time:16776ms step_avg:94.24ms +step:179/1705 train_time:16868ms step_avg:94.23ms +step:180/1705 train_time:16961ms step_avg:94.23ms +step:181/1705 train_time:17053ms step_avg:94.22ms +step:182/1705 train_time:17146ms step_avg:94.21ms +step:183/1705 train_time:17239ms step_avg:94.20ms +step:184/1705 train_time:17332ms step_avg:94.20ms +step:185/1705 train_time:17425ms step_avg:94.19ms +step:186/1705 train_time:17518ms step_avg:94.18ms +step:187/1705 train_time:17611ms step_avg:94.18ms +step:188/1705 train_time:17705ms step_avg:94.17ms +step:189/1705 train_time:17797ms step_avg:94.16ms +step:190/1705 train_time:17890ms step_avg:94.16ms +step:191/1705 train_time:17983ms step_avg:94.15ms +step:192/1705 train_time:18076ms step_avg:94.14ms +step:193/1705 train_time:18168ms step_avg:94.14ms +step:194/1705 train_time:18262ms step_avg:94.13ms +step:195/1705 train_time:18355ms step_avg:94.13ms +step:196/1705 train_time:18448ms step_avg:94.12ms +step:197/1705 train_time:18541ms step_avg:94.12ms +step:198/1705 train_time:18634ms step_avg:94.11ms +step:199/1705 train_time:18727ms step_avg:94.11ms +step:200/1705 train_time:18820ms step_avg:94.10ms +step:201/1705 train_time:18913ms step_avg:94.09ms +step:202/1705 train_time:19006ms step_avg:94.09ms +step:203/1705 train_time:19099ms step_avg:94.08ms +step:204/1705 train_time:19192ms step_avg:94.08ms +step:205/1705 train_time:19284ms step_avg:94.07ms +step:206/1705 train_time:19378ms step_avg:94.07ms +step:207/1705 train_time:19471ms step_avg:94.06ms +step:208/1705 train_time:19564ms step_avg:94.06ms +step:209/1705 train_time:19657ms step_avg:94.05ms +step:210/1705 train_time:19750ms step_avg:94.05ms +step:211/1705 train_time:19842ms step_avg:94.04ms +step:212/1705 train_time:19935ms step_avg:94.03ms +step:213/1705 train_time:20186ms step_avg:94.77ms +step:214/1705 train_time:20354ms step_avg:95.11ms +step:215/1705 train_time:20445ms step_avg:95.09ms +step:216/1705 
train_time:20538ms step_avg:95.08ms +step:217/1705 train_time:20630ms step_avg:95.07ms +step:218/1705 train_time:20722ms step_avg:95.06ms +step:219/1705 train_time:20814ms step_avg:95.04ms +step:220/1705 train_time:20906ms step_avg:95.03ms +step:221/1705 train_time:20998ms step_avg:95.01ms +step:222/1705 train_time:21090ms step_avg:95.00ms +step:223/1705 train_time:21182ms step_avg:94.99ms +step:224/1705 train_time:21279ms step_avg:95.00ms +step:225/1705 train_time:21375ms step_avg:95.00ms +step:226/1705 train_time:21468ms step_avg:94.99ms +step:227/1705 train_time:21562ms step_avg:94.99ms +step:228/1705 train_time:21654ms step_avg:94.98ms +step:229/1705 train_time:21747ms step_avg:94.96ms +step:230/1705 train_time:21839ms step_avg:94.95ms +step:231/1705 train_time:21932ms step_avg:94.94ms +step:232/1705 train_time:22023ms step_avg:94.93ms +step:233/1705 train_time:22116ms step_avg:94.92ms +step:234/1705 train_time:22208ms step_avg:94.91ms +step:235/1705 train_time:22304ms step_avg:94.91ms +step:236/1705 train_time:22398ms step_avg:94.91ms +step:237/1705 train_time:22491ms step_avg:94.90ms +step:238/1705 train_time:22583ms step_avg:94.89ms +step:239/1705 train_time:22676ms step_avg:94.88ms +step:240/1705 train_time:22769ms step_avg:94.87ms +step:241/1705 train_time:22862ms step_avg:94.86ms +step:242/1705 train_time:22954ms step_avg:94.85ms +step:243/1705 train_time:23047ms step_avg:94.84ms +step:244/1705 train_time:23139ms step_avg:94.83ms +step:245/1705 train_time:23233ms step_avg:94.83ms +step:246/1705 train_time:23326ms step_avg:94.82ms +step:247/1705 train_time:23420ms step_avg:94.82ms +step:248/1705 train_time:23513ms step_avg:94.81ms +step:249/1705 train_time:23606ms step_avg:94.80ms +step:250/1705 train_time:23699ms step_avg:94.80ms +step:250/1705 val_loss:3.9838 train_time:23793ms step_avg:95.17ms +step:251/1705 train_time:23815ms step_avg:94.88ms +step:252/1705 train_time:23889ms step_avg:94.80ms +step:253/1705 train_time:23990ms step_avg:94.82ms +step:254/1705 train_time:24084ms step_avg:94.82ms +step:255/1705 train_time:24175ms step_avg:94.81ms +step:256/1705 train_time:24267ms step_avg:94.79ms +step:257/1705 train_time:24359ms step_avg:94.78ms +step:258/1705 train_time:24451ms step_avg:94.77ms +step:259/1705 train_time:24542ms step_avg:94.76ms +step:260/1705 train_time:24634ms step_avg:94.75ms +step:261/1705 train_time:24727ms step_avg:94.74ms +step:262/1705 train_time:24820ms step_avg:94.73ms +step:263/1705 train_time:24916ms step_avg:94.74ms +step:264/1705 train_time:25010ms step_avg:94.74ms +step:265/1705 train_time:25104ms step_avg:94.73ms +step:266/1705 train_time:25197ms step_avg:94.73ms +step:267/1705 train_time:25290ms step_avg:94.72ms +step:268/1705 train_time:25382ms step_avg:94.71ms +step:269/1705 train_time:25474ms step_avg:94.70ms +step:270/1705 train_time:25566ms step_avg:94.69ms +step:271/1705 train_time:25658ms step_avg:94.68ms +step:272/1705 train_time:25751ms step_avg:94.67ms +step:273/1705 train_time:25844ms step_avg:94.67ms +step:274/1705 train_time:25937ms step_avg:94.66ms +step:275/1705 train_time:26032ms step_avg:94.66ms +step:276/1705 train_time:26126ms step_avg:94.66ms +step:277/1705 train_time:26218ms step_avg:94.65ms +step:278/1705 train_time:26311ms step_avg:94.64ms +step:279/1705 train_time:26404ms step_avg:94.64ms +step:280/1705 train_time:26497ms step_avg:94.63ms +step:281/1705 train_time:26590ms step_avg:94.62ms +step:282/1705 train_time:26681ms step_avg:94.62ms +step:283/1705 train_time:26775ms step_avg:94.61ms +step:284/1705 train_time:26868ms 
step_avg:94.61ms +step:285/1705 train_time:26961ms step_avg:94.60ms +step:286/1705 train_time:27055ms step_avg:94.60ms +step:287/1705 train_time:27148ms step_avg:94.59ms +step:288/1705 train_time:27240ms step_avg:94.58ms +step:289/1705 train_time:27334ms step_avg:94.58ms +step:290/1705 train_time:27427ms step_avg:94.57ms +step:291/1705 train_time:27518ms step_avg:94.56ms +step:292/1705 train_time:27611ms step_avg:94.56ms +step:293/1705 train_time:27703ms step_avg:94.55ms +step:294/1705 train_time:27797ms step_avg:94.55ms +step:295/1705 train_time:27890ms step_avg:94.54ms +step:296/1705 train_time:27984ms step_avg:94.54ms +step:297/1705 train_time:28076ms step_avg:94.53ms +step:298/1705 train_time:28170ms step_avg:94.53ms +step:299/1705 train_time:28263ms step_avg:94.52ms +step:300/1705 train_time:28356ms step_avg:94.52ms +step:301/1705 train_time:28449ms step_avg:94.51ms +step:302/1705 train_time:28542ms step_avg:94.51ms +step:303/1705 train_time:28635ms step_avg:94.50ms +step:304/1705 train_time:28728ms step_avg:94.50ms +step:305/1705 train_time:28820ms step_avg:94.49ms +step:306/1705 train_time:28913ms step_avg:94.49ms +step:307/1705 train_time:29006ms step_avg:94.48ms +step:308/1705 train_time:29099ms step_avg:94.48ms +step:309/1705 train_time:29192ms step_avg:94.47ms +step:310/1705 train_time:29286ms step_avg:94.47ms +step:311/1705 train_time:29378ms step_avg:94.46ms +step:312/1705 train_time:29471ms step_avg:94.46ms +step:313/1705 train_time:29565ms step_avg:94.46ms +step:314/1705 train_time:29658ms step_avg:94.45ms +step:315/1705 train_time:29750ms step_avg:94.45ms +step:316/1705 train_time:29843ms step_avg:94.44ms +step:317/1705 train_time:29936ms step_avg:94.44ms +step:318/1705 train_time:30030ms step_avg:94.44ms +step:319/1705 train_time:30122ms step_avg:94.43ms +step:320/1705 train_time:30217ms step_avg:94.43ms +step:321/1705 train_time:30309ms step_avg:94.42ms +step:322/1705 train_time:30402ms step_avg:94.41ms +step:323/1705 train_time:30495ms step_avg:94.41ms +step:324/1705 train_time:30588ms step_avg:94.41ms +step:325/1705 train_time:30680ms step_avg:94.40ms +step:326/1705 train_time:30773ms step_avg:94.40ms +step:327/1705 train_time:30867ms step_avg:94.39ms +step:328/1705 train_time:30960ms step_avg:94.39ms +step:329/1705 train_time:31053ms step_avg:94.39ms +step:330/1705 train_time:31145ms step_avg:94.38ms +step:331/1705 train_time:31238ms step_avg:94.37ms +step:332/1705 train_time:31331ms step_avg:94.37ms +step:333/1705 train_time:31424ms step_avg:94.37ms +step:334/1705 train_time:31517ms step_avg:94.36ms +step:335/1705 train_time:31610ms step_avg:94.36ms +step:336/1705 train_time:31702ms step_avg:94.35ms +step:337/1705 train_time:31795ms step_avg:94.35ms +step:338/1705 train_time:31888ms step_avg:94.34ms +step:339/1705 train_time:31981ms step_avg:94.34ms +step:340/1705 train_time:32074ms step_avg:94.34ms +step:341/1705 train_time:32168ms step_avg:94.33ms +step:342/1705 train_time:32260ms step_avg:94.33ms +step:343/1705 train_time:32353ms step_avg:94.32ms +step:344/1705 train_time:32446ms step_avg:94.32ms +step:345/1705 train_time:32538ms step_avg:94.31ms +step:346/1705 train_time:32632ms step_avg:94.31ms +step:347/1705 train_time:32725ms step_avg:94.31ms +step:348/1705 train_time:32818ms step_avg:94.30ms +step:349/1705 train_time:32911ms step_avg:94.30ms +step:350/1705 train_time:33004ms step_avg:94.30ms +step:351/1705 train_time:33097ms step_avg:94.29ms +step:352/1705 train_time:33189ms step_avg:94.29ms +step:353/1705 train_time:33282ms step_avg:94.28ms +step:354/1705 
train_time:33376ms step_avg:94.28ms +step:355/1705 train_time:33469ms step_avg:94.28ms +step:356/1705 train_time:33561ms step_avg:94.27ms +step:357/1705 train_time:33655ms step_avg:94.27ms +step:358/1705 train_time:33748ms step_avg:94.27ms +step:359/1705 train_time:33841ms step_avg:94.26ms +step:360/1705 train_time:33935ms step_avg:94.26ms +step:361/1705 train_time:34028ms step_avg:94.26ms +step:362/1705 train_time:34120ms step_avg:94.26ms +step:363/1705 train_time:34213ms step_avg:94.25ms +step:364/1705 train_time:34306ms step_avg:94.25ms +step:365/1705 train_time:34400ms step_avg:94.25ms +step:366/1705 train_time:34493ms step_avg:94.24ms +step:367/1705 train_time:34586ms step_avg:94.24ms +step:368/1705 train_time:34678ms step_avg:94.23ms +step:369/1705 train_time:34771ms step_avg:94.23ms +step:370/1705 train_time:34864ms step_avg:94.23ms +step:371/1705 train_time:34957ms step_avg:94.22ms +step:372/1705 train_time:35050ms step_avg:94.22ms +step:373/1705 train_time:35143ms step_avg:94.22ms +step:374/1705 train_time:35236ms step_avg:94.21ms +step:375/1705 train_time:35330ms step_avg:94.21ms +step:375/1705 val_loss:3.8258 train_time:35423ms step_avg:94.46ms +step:376/1705 train_time:35445ms step_avg:94.27ms +step:377/1705 train_time:35520ms step_avg:94.22ms +step:378/1705 train_time:35617ms step_avg:94.23ms +step:379/1705 train_time:35712ms step_avg:94.23ms +step:380/1705 train_time:35804ms step_avg:94.22ms +step:381/1705 train_time:35896ms step_avg:94.22ms +step:382/1705 train_time:35989ms step_avg:94.21ms +step:383/1705 train_time:36080ms step_avg:94.20ms +step:384/1705 train_time:36173ms step_avg:94.20ms +step:385/1705 train_time:36264ms step_avg:94.19ms +step:386/1705 train_time:36356ms step_avg:94.19ms +step:387/1705 train_time:36451ms step_avg:94.19ms +step:388/1705 train_time:36547ms step_avg:94.19ms +step:389/1705 train_time:36642ms step_avg:94.19ms +step:390/1705 train_time:36736ms step_avg:94.19ms +step:391/1705 train_time:36828ms step_avg:94.19ms +step:392/1705 train_time:36921ms step_avg:94.19ms +step:393/1705 train_time:37013ms step_avg:94.18ms +step:394/1705 train_time:37106ms step_avg:94.18ms +step:395/1705 train_time:37198ms step_avg:94.17ms +step:396/1705 train_time:37291ms step_avg:94.17ms +step:397/1705 train_time:37384ms step_avg:94.17ms +step:398/1705 train_time:37478ms step_avg:94.16ms +step:399/1705 train_time:37572ms step_avg:94.17ms +step:400/1705 train_time:37667ms step_avg:94.17ms +step:401/1705 train_time:37760ms step_avg:94.16ms +step:402/1705 train_time:37853ms step_avg:94.16ms +step:403/1705 train_time:37945ms step_avg:94.16ms +step:404/1705 train_time:38037ms step_avg:94.15ms +step:405/1705 train_time:38130ms step_avg:94.15ms +step:406/1705 train_time:38222ms step_avg:94.14ms +step:407/1705 train_time:38315ms step_avg:94.14ms +step:408/1705 train_time:38408ms step_avg:94.14ms +step:409/1705 train_time:38501ms step_avg:94.14ms +step:410/1705 train_time:38596ms step_avg:94.14ms +step:411/1705 train_time:38690ms step_avg:94.14ms +step:412/1705 train_time:38785ms step_avg:94.14ms +step:413/1705 train_time:38877ms step_avg:94.13ms +step:414/1705 train_time:38971ms step_avg:94.13ms +step:415/1705 train_time:39063ms step_avg:94.13ms +step:416/1705 train_time:39155ms step_avg:94.12ms +step:417/1705 train_time:39248ms step_avg:94.12ms +step:418/1705 train_time:39340ms step_avg:94.12ms +step:419/1705 train_time:39433ms step_avg:94.11ms +step:420/1705 train_time:39526ms step_avg:94.11ms +step:421/1705 train_time:39620ms step_avg:94.11ms +step:422/1705 train_time:39714ms 
step_avg:94.11ms +step:423/1705 train_time:39808ms step_avg:94.11ms +step:424/1705 train_time:39901ms step_avg:94.11ms +step:425/1705 train_time:40149ms step_avg:94.47ms +step:426/1705 train_time:40259ms step_avg:94.50ms +step:427/1705 train_time:40350ms step_avg:94.50ms +step:428/1705 train_time:40442ms step_avg:94.49ms +step:429/1705 train_time:40534ms step_avg:94.49ms +step:430/1705 train_time:40626ms step_avg:94.48ms +step:431/1705 train_time:40717ms step_avg:94.47ms +step:432/1705 train_time:40809ms step_avg:94.47ms +step:433/1705 train_time:40902ms step_avg:94.46ms +step:434/1705 train_time:40993ms step_avg:94.45ms +step:435/1705 train_time:41090ms step_avg:94.46ms +step:436/1705 train_time:41186ms step_avg:94.46ms +step:437/1705 train_time:41281ms step_avg:94.46ms +step:438/1705 train_time:41374ms step_avg:94.46ms +step:439/1705 train_time:41467ms step_avg:94.46ms +step:440/1705 train_time:41559ms step_avg:94.45ms +step:441/1705 train_time:41651ms step_avg:94.45ms +step:442/1705 train_time:41743ms step_avg:94.44ms +step:443/1705 train_time:41835ms step_avg:94.44ms +step:444/1705 train_time:41927ms step_avg:94.43ms +step:445/1705 train_time:42020ms step_avg:94.43ms +step:446/1705 train_time:42115ms step_avg:94.43ms +step:447/1705 train_time:42210ms step_avg:94.43ms +step:448/1705 train_time:42305ms step_avg:94.43ms +step:449/1705 train_time:42397ms step_avg:94.43ms +step:450/1705 train_time:42490ms step_avg:94.42ms +step:451/1705 train_time:42583ms step_avg:94.42ms +step:452/1705 train_time:42676ms step_avg:94.42ms +step:453/1705 train_time:42767ms step_avg:94.41ms +step:454/1705 train_time:42860ms step_avg:94.41ms +step:455/1705 train_time:42952ms step_avg:94.40ms +step:456/1705 train_time:43046ms step_avg:94.40ms +step:457/1705 train_time:43139ms step_avg:94.40ms +step:458/1705 train_time:43233ms step_avg:94.40ms +step:459/1705 train_time:43327ms step_avg:94.39ms +step:460/1705 train_time:43420ms step_avg:94.39ms +step:461/1705 train_time:43513ms step_avg:94.39ms +step:462/1705 train_time:43606ms step_avg:94.39ms +step:463/1705 train_time:43699ms step_avg:94.38ms +step:464/1705 train_time:43791ms step_avg:94.38ms +step:465/1705 train_time:43884ms step_avg:94.37ms +step:466/1705 train_time:43976ms step_avg:94.37ms +step:467/1705 train_time:44070ms step_avg:94.37ms +step:468/1705 train_time:44163ms step_avg:94.37ms +step:469/1705 train_time:44256ms step_avg:94.36ms +step:470/1705 train_time:44350ms step_avg:94.36ms +step:471/1705 train_time:44443ms step_avg:94.36ms +step:472/1705 train_time:44535ms step_avg:94.35ms +step:473/1705 train_time:44629ms step_avg:94.35ms +step:474/1705 train_time:44721ms step_avg:94.35ms +step:475/1705 train_time:44815ms step_avg:94.35ms +step:476/1705 train_time:44908ms step_avg:94.34ms +step:477/1705 train_time:45001ms step_avg:94.34ms +step:478/1705 train_time:45095ms step_avg:94.34ms +step:479/1705 train_time:45189ms step_avg:94.34ms +step:480/1705 train_time:45282ms step_avg:94.34ms +step:481/1705 train_time:45375ms step_avg:94.33ms +step:482/1705 train_time:45469ms step_avg:94.33ms +step:483/1705 train_time:45562ms step_avg:94.33ms +step:484/1705 train_time:45655ms step_avg:94.33ms +step:485/1705 train_time:45747ms step_avg:94.32ms +step:486/1705 train_time:45840ms step_avg:94.32ms +step:487/1705 train_time:45933ms step_avg:94.32ms +step:488/1705 train_time:46026ms step_avg:94.31ms +step:489/1705 train_time:46118ms step_avg:94.31ms +step:490/1705 train_time:46212ms step_avg:94.31ms +step:491/1705 train_time:46305ms step_avg:94.31ms +step:492/1705 
train_time:46397ms step_avg:94.30ms +step:493/1705 train_time:46491ms step_avg:94.30ms +step:494/1705 train_time:46585ms step_avg:94.30ms +step:495/1705 train_time:46677ms step_avg:94.30ms +step:496/1705 train_time:46770ms step_avg:94.29ms +step:497/1705 train_time:46862ms step_avg:94.29ms +step:498/1705 train_time:46955ms step_avg:94.29ms +step:499/1705 train_time:47048ms step_avg:94.28ms +step:500/1705 train_time:47141ms step_avg:94.28ms +step:500/1705 val_loss:3.7219 train_time:47234ms step_avg:94.47ms +step:501/1705 train_time:47256ms step_avg:94.32ms +step:502/1705 train_time:47332ms step_avg:94.29ms +step:503/1705 train_time:47430ms step_avg:94.30ms +step:504/1705 train_time:47525ms step_avg:94.30ms +step:505/1705 train_time:47617ms step_avg:94.29ms +step:506/1705 train_time:47709ms step_avg:94.29ms +step:507/1705 train_time:47802ms step_avg:94.28ms +step:508/1705 train_time:47894ms step_avg:94.28ms +step:509/1705 train_time:47986ms step_avg:94.27ms +step:510/1705 train_time:48078ms step_avg:94.27ms +step:511/1705 train_time:48170ms step_avg:94.27ms +step:512/1705 train_time:48265ms step_avg:94.27ms +step:513/1705 train_time:48360ms step_avg:94.27ms +step:514/1705 train_time:48454ms step_avg:94.27ms +step:515/1705 train_time:48548ms step_avg:94.27ms +step:516/1705 train_time:48641ms step_avg:94.27ms +step:517/1705 train_time:48734ms step_avg:94.26ms +step:518/1705 train_time:48827ms step_avg:94.26ms +step:519/1705 train_time:48920ms step_avg:94.26ms +step:520/1705 train_time:49012ms step_avg:94.25ms +step:521/1705 train_time:49105ms step_avg:94.25ms +step:522/1705 train_time:49197ms step_avg:94.25ms +step:523/1705 train_time:49290ms step_avg:94.24ms +step:524/1705 train_time:49384ms step_avg:94.24ms +step:525/1705 train_time:49479ms step_avg:94.25ms +step:526/1705 train_time:49571ms step_avg:94.24ms +step:527/1705 train_time:49664ms step_avg:94.24ms +step:528/1705 train_time:49757ms step_avg:94.24ms +step:529/1705 train_time:49850ms step_avg:94.23ms +step:530/1705 train_time:49943ms step_avg:94.23ms +step:531/1705 train_time:50036ms step_avg:94.23ms +step:532/1705 train_time:50128ms step_avg:94.23ms +step:533/1705 train_time:50221ms step_avg:94.22ms +step:534/1705 train_time:50314ms step_avg:94.22ms +step:535/1705 train_time:50408ms step_avg:94.22ms +step:536/1705 train_time:50501ms step_avg:94.22ms +step:537/1705 train_time:50594ms step_avg:94.22ms +step:538/1705 train_time:50687ms step_avg:94.21ms +step:539/1705 train_time:50781ms step_avg:94.21ms +step:540/1705 train_time:50874ms step_avg:94.21ms +step:541/1705 train_time:50966ms step_avg:94.21ms +step:542/1705 train_time:51059ms step_avg:94.20ms +step:543/1705 train_time:51152ms step_avg:94.20ms +step:544/1705 train_time:51245ms step_avg:94.20ms +step:545/1705 train_time:51338ms step_avg:94.20ms +step:546/1705 train_time:51432ms step_avg:94.20ms +step:547/1705 train_time:51526ms step_avg:94.20ms +step:548/1705 train_time:51619ms step_avg:94.19ms +step:549/1705 train_time:51711ms step_avg:94.19ms +step:550/1705 train_time:51805ms step_avg:94.19ms +step:551/1705 train_time:51898ms step_avg:94.19ms +step:552/1705 train_time:51990ms step_avg:94.18ms +step:553/1705 train_time:52083ms step_avg:94.18ms +step:554/1705 train_time:52175ms step_avg:94.18ms +step:555/1705 train_time:52267ms step_avg:94.18ms +step:556/1705 train_time:52361ms step_avg:94.17ms +step:557/1705 train_time:52454ms step_avg:94.17ms +step:558/1705 train_time:52547ms step_avg:94.17ms +step:559/1705 train_time:52640ms step_avg:94.17ms +step:560/1705 train_time:52733ms 
step_avg:94.17ms +step:561/1705 train_time:52827ms step_avg:94.17ms +step:562/1705 train_time:52920ms step_avg:94.16ms +step:563/1705 train_time:53013ms step_avg:94.16ms +step:564/1705 train_time:53105ms step_avg:94.16ms +step:565/1705 train_time:53198ms step_avg:94.16ms +step:566/1705 train_time:53291ms step_avg:94.15ms +step:567/1705 train_time:53384ms step_avg:94.15ms +step:568/1705 train_time:53477ms step_avg:94.15ms +step:569/1705 train_time:53570ms step_avg:94.15ms +step:570/1705 train_time:53663ms step_avg:94.15ms +step:571/1705 train_time:53758ms step_avg:94.15ms +step:572/1705 train_time:53852ms step_avg:94.15ms +step:573/1705 train_time:53946ms step_avg:94.15ms +step:574/1705 train_time:54040ms step_avg:94.15ms +step:575/1705 train_time:54135ms step_avg:94.15ms +step:576/1705 train_time:54229ms step_avg:94.15ms +step:577/1705 train_time:54323ms step_avg:94.15ms +step:578/1705 train_time:54418ms step_avg:94.15ms +step:579/1705 train_time:54512ms step_avg:94.15ms +step:580/1705 train_time:54606ms step_avg:94.15ms +step:581/1705 train_time:54701ms step_avg:94.15ms +step:582/1705 train_time:54796ms step_avg:94.15ms +step:583/1705 train_time:54889ms step_avg:94.15ms +step:584/1705 train_time:54984ms step_avg:94.15ms +step:585/1705 train_time:55078ms step_avg:94.15ms +step:586/1705 train_time:55171ms step_avg:94.15ms +step:587/1705 train_time:55266ms step_avg:94.15ms +step:588/1705 train_time:55361ms step_avg:94.15ms +step:589/1705 train_time:55456ms step_avg:94.15ms +step:590/1705 train_time:55550ms step_avg:94.15ms +step:591/1705 train_time:55645ms step_avg:94.15ms +step:592/1705 train_time:55740ms step_avg:94.16ms +step:593/1705 train_time:55834ms step_avg:94.15ms +step:594/1705 train_time:55928ms step_avg:94.15ms +step:595/1705 train_time:56023ms step_avg:94.16ms +step:596/1705 train_time:56117ms step_avg:94.16ms +step:597/1705 train_time:56211ms step_avg:94.16ms +step:598/1705 train_time:56305ms step_avg:94.16ms +step:599/1705 train_time:56400ms step_avg:94.16ms +step:600/1705 train_time:56495ms step_avg:94.16ms +step:601/1705 train_time:56589ms step_avg:94.16ms +step:602/1705 train_time:56683ms step_avg:94.16ms +step:603/1705 train_time:56777ms step_avg:94.16ms +step:604/1705 train_time:56871ms step_avg:94.16ms +step:605/1705 train_time:56966ms step_avg:94.16ms +step:606/1705 train_time:57060ms step_avg:94.16ms +step:607/1705 train_time:57153ms step_avg:94.16ms +step:608/1705 train_time:57247ms step_avg:94.16ms +step:609/1705 train_time:57342ms step_avg:94.16ms +step:610/1705 train_time:57436ms step_avg:94.16ms +step:611/1705 train_time:57530ms step_avg:94.16ms +step:612/1705 train_time:57625ms step_avg:94.16ms +step:613/1705 train_time:57719ms step_avg:94.16ms +step:614/1705 train_time:57814ms step_avg:94.16ms +step:615/1705 train_time:57908ms step_avg:94.16ms +step:616/1705 train_time:58003ms step_avg:94.16ms +step:617/1705 train_time:58097ms step_avg:94.16ms +step:618/1705 train_time:58191ms step_avg:94.16ms +step:619/1705 train_time:58285ms step_avg:94.16ms +step:620/1705 train_time:58380ms step_avg:94.16ms +step:621/1705 train_time:58475ms step_avg:94.16ms +step:622/1705 train_time:58569ms step_avg:94.16ms +step:623/1705 train_time:58664ms step_avg:94.16ms +step:624/1705 train_time:58759ms step_avg:94.17ms +step:625/1705 train_time:58854ms step_avg:94.17ms +step:625/1705 val_loss:3.6232 train_time:58948ms step_avg:94.32ms +step:626/1705 train_time:58970ms step_avg:94.20ms +step:627/1705 train_time:59057ms step_avg:94.19ms +step:628/1705 train_time:59153ms step_avg:94.19ms 
+step:629/1705 train_time:59248ms step_avg:94.19ms +step:630/1705 train_time:59342ms step_avg:94.19ms +step:631/1705 train_time:59434ms step_avg:94.19ms +step:632/1705 train_time:59528ms step_avg:94.19ms +step:633/1705 train_time:59621ms step_avg:94.19ms +step:634/1705 train_time:59713ms step_avg:94.18ms +step:635/1705 train_time:59807ms step_avg:94.18ms +step:636/1705 train_time:59901ms step_avg:94.18ms +step:637/1705 train_time:59998ms step_avg:94.19ms +step:638/1705 train_time:60095ms step_avg:94.19ms +step:639/1705 train_time:60456ms step_avg:94.61ms +step:640/1705 train_time:60543ms step_avg:94.60ms +step:641/1705 train_time:60635ms step_avg:94.59ms +step:642/1705 train_time:60728ms step_avg:94.59ms +step:643/1705 train_time:60822ms step_avg:94.59ms +step:644/1705 train_time:60915ms step_avg:94.59ms +step:645/1705 train_time:61008ms step_avg:94.59ms +step:646/1705 train_time:61102ms step_avg:94.58ms +step:647/1705 train_time:61194ms step_avg:94.58ms +step:648/1705 train_time:61288ms step_avg:94.58ms +step:649/1705 train_time:61386ms step_avg:94.59ms +step:650/1705 train_time:61484ms step_avg:94.59ms +step:651/1705 train_time:61578ms step_avg:94.59ms +step:652/1705 train_time:61671ms step_avg:94.59ms +step:653/1705 train_time:61766ms step_avg:94.59ms +step:654/1705 train_time:61860ms step_avg:94.59ms +step:655/1705 train_time:61953ms step_avg:94.58ms +step:656/1705 train_time:62047ms step_avg:94.58ms +step:657/1705 train_time:62142ms step_avg:94.58ms +step:658/1705 train_time:62235ms step_avg:94.58ms +step:659/1705 train_time:62330ms step_avg:94.58ms +step:660/1705 train_time:62427ms step_avg:94.59ms +step:661/1705 train_time:62522ms step_avg:94.59ms +step:662/1705 train_time:62616ms step_avg:94.59ms +step:663/1705 train_time:62711ms step_avg:94.59ms +step:664/1705 train_time:62805ms step_avg:94.59ms +step:665/1705 train_time:62899ms step_avg:94.58ms +step:666/1705 train_time:62992ms step_avg:94.58ms +step:667/1705 train_time:63086ms step_avg:94.58ms +step:668/1705 train_time:63180ms step_avg:94.58ms +step:669/1705 train_time:63274ms step_avg:94.58ms +step:670/1705 train_time:63369ms step_avg:94.58ms +step:671/1705 train_time:63465ms step_avg:94.58ms +step:672/1705 train_time:63560ms step_avg:94.58ms +step:673/1705 train_time:63655ms step_avg:94.58ms +step:674/1705 train_time:63749ms step_avg:94.58ms +step:675/1705 train_time:63843ms step_avg:94.58ms +step:676/1705 train_time:63937ms step_avg:94.58ms +step:677/1705 train_time:64030ms step_avg:94.58ms +step:678/1705 train_time:64124ms step_avg:94.58ms +step:679/1705 train_time:64218ms step_avg:94.58ms +step:680/1705 train_time:64312ms step_avg:94.58ms +step:681/1705 train_time:64407ms step_avg:94.58ms +step:682/1705 train_time:64503ms step_avg:94.58ms +step:683/1705 train_time:64597ms step_avg:94.58ms +step:684/1705 train_time:64691ms step_avg:94.58ms +step:685/1705 train_time:64787ms step_avg:94.58ms +step:686/1705 train_time:64881ms step_avg:94.58ms +step:687/1705 train_time:64974ms step_avg:94.58ms +step:688/1705 train_time:65068ms step_avg:94.58ms +step:689/1705 train_time:65162ms step_avg:94.57ms +step:690/1705 train_time:65256ms step_avg:94.57ms +step:691/1705 train_time:65350ms step_avg:94.57ms +step:692/1705 train_time:65445ms step_avg:94.57ms +step:693/1705 train_time:65540ms step_avg:94.57ms +step:694/1705 train_time:65634ms step_avg:94.57ms +step:695/1705 train_time:65728ms step_avg:94.57ms +step:696/1705 train_time:65823ms step_avg:94.57ms +step:697/1705 train_time:65916ms step_avg:94.57ms +step:698/1705 train_time:66010ms 
step_avg:94.57ms +step:699/1705 train_time:66105ms step_avg:94.57ms +step:700/1705 train_time:66199ms step_avg:94.57ms +step:701/1705 train_time:66294ms step_avg:94.57ms +step:702/1705 train_time:66388ms step_avg:94.57ms +step:703/1705 train_time:66483ms step_avg:94.57ms +step:704/1705 train_time:66577ms step_avg:94.57ms +step:705/1705 train_time:66672ms step_avg:94.57ms +step:706/1705 train_time:66767ms step_avg:94.57ms +step:707/1705 train_time:66861ms step_avg:94.57ms +step:708/1705 train_time:66955ms step_avg:94.57ms +step:709/1705 train_time:67049ms step_avg:94.57ms +step:710/1705 train_time:67143ms step_avg:94.57ms +step:711/1705 train_time:67237ms step_avg:94.57ms +step:712/1705 train_time:67331ms step_avg:94.57ms +step:713/1705 train_time:67425ms step_avg:94.57ms +step:714/1705 train_time:67520ms step_avg:94.57ms +step:715/1705 train_time:67615ms step_avg:94.57ms +step:716/1705 train_time:67710ms step_avg:94.57ms +step:717/1705 train_time:67805ms step_avg:94.57ms +step:718/1705 train_time:67899ms step_avg:94.57ms +step:719/1705 train_time:67993ms step_avg:94.57ms +step:720/1705 train_time:68087ms step_avg:94.57ms +step:721/1705 train_time:68181ms step_avg:94.56ms +step:722/1705 train_time:68275ms step_avg:94.56ms +step:723/1705 train_time:68369ms step_avg:94.56ms +step:724/1705 train_time:68464ms step_avg:94.56ms +step:725/1705 train_time:68558ms step_avg:94.56ms +step:726/1705 train_time:68652ms step_avg:94.56ms +step:727/1705 train_time:68748ms step_avg:94.56ms +step:728/1705 train_time:68843ms step_avg:94.56ms +step:729/1705 train_time:68936ms step_avg:94.56ms +step:730/1705 train_time:69030ms step_avg:94.56ms +step:731/1705 train_time:69125ms step_avg:94.56ms +step:732/1705 train_time:69219ms step_avg:94.56ms +step:733/1705 train_time:69313ms step_avg:94.56ms +step:734/1705 train_time:69409ms step_avg:94.56ms +step:735/1705 train_time:69504ms step_avg:94.56ms +step:736/1705 train_time:69598ms step_avg:94.56ms +step:737/1705 train_time:69692ms step_avg:94.56ms +step:738/1705 train_time:69788ms step_avg:94.56ms +step:739/1705 train_time:69882ms step_avg:94.56ms +step:740/1705 train_time:69976ms step_avg:94.56ms +step:741/1705 train_time:70069ms step_avg:94.56ms +step:742/1705 train_time:70164ms step_avg:94.56ms +step:743/1705 train_time:70258ms step_avg:94.56ms +step:744/1705 train_time:70352ms step_avg:94.56ms +step:745/1705 train_time:70448ms step_avg:94.56ms +step:746/1705 train_time:70542ms step_avg:94.56ms +step:747/1705 train_time:70636ms step_avg:94.56ms +step:748/1705 train_time:70730ms step_avg:94.56ms +step:749/1705 train_time:70825ms step_avg:94.56ms +step:750/1705 train_time:70919ms step_avg:94.56ms +step:750/1705 val_loss:3.5678 train_time:71013ms step_avg:94.68ms +step:751/1705 train_time:71037ms step_avg:94.59ms +step:752/1705 train_time:71115ms step_avg:94.57ms +step:753/1705 train_time:71215ms step_avg:94.57ms +step:754/1705 train_time:71310ms step_avg:94.58ms +step:755/1705 train_time:71404ms step_avg:94.58ms +step:756/1705 train_time:71497ms step_avg:94.57ms +step:757/1705 train_time:71591ms step_avg:94.57ms +step:758/1705 train_time:71684ms step_avg:94.57ms +step:759/1705 train_time:71778ms step_avg:94.57ms +step:760/1705 train_time:71871ms step_avg:94.57ms +step:761/1705 train_time:71967ms step_avg:94.57ms +step:762/1705 train_time:72063ms step_avg:94.57ms +step:763/1705 train_time:72158ms step_avg:94.57ms +step:764/1705 train_time:72254ms step_avg:94.57ms +step:765/1705 train_time:72351ms step_avg:94.58ms +step:766/1705 train_time:72444ms step_avg:94.57ms 
+step:767/1705 train_time:72537ms step_avg:94.57ms +step:768/1705 train_time:72631ms step_avg:94.57ms +step:769/1705 train_time:72724ms step_avg:94.57ms +step:770/1705 train_time:72818ms step_avg:94.57ms +step:771/1705 train_time:72911ms step_avg:94.57ms +step:772/1705 train_time:73006ms step_avg:94.57ms +step:773/1705 train_time:73101ms step_avg:94.57ms +step:774/1705 train_time:73196ms step_avg:94.57ms +step:775/1705 train_time:73291ms step_avg:94.57ms +step:776/1705 train_time:73386ms step_avg:94.57ms +step:777/1705 train_time:73480ms step_avg:94.57ms +step:778/1705 train_time:73574ms step_avg:94.57ms +step:779/1705 train_time:73668ms step_avg:94.57ms +step:780/1705 train_time:73761ms step_avg:94.57ms +step:781/1705 train_time:73856ms step_avg:94.57ms +step:782/1705 train_time:73950ms step_avg:94.56ms +step:783/1705 train_time:74045ms step_avg:94.57ms +step:784/1705 train_time:74140ms step_avg:94.57ms +step:785/1705 train_time:74234ms step_avg:94.57ms +step:786/1705 train_time:74330ms step_avg:94.57ms +step:787/1705 train_time:74425ms step_avg:94.57ms +step:788/1705 train_time:74520ms step_avg:94.57ms +step:789/1705 train_time:74614ms step_avg:94.57ms +step:790/1705 train_time:74709ms step_avg:94.57ms +step:791/1705 train_time:74803ms step_avg:94.57ms +step:792/1705 train_time:74896ms step_avg:94.57ms +step:793/1705 train_time:74991ms step_avg:94.57ms +step:794/1705 train_time:75085ms step_avg:94.57ms +step:795/1705 train_time:75180ms step_avg:94.57ms +step:796/1705 train_time:75275ms step_avg:94.57ms +step:797/1705 train_time:75370ms step_avg:94.57ms +step:798/1705 train_time:75465ms step_avg:94.57ms +step:799/1705 train_time:75559ms step_avg:94.57ms +step:800/1705 train_time:75653ms step_avg:94.57ms +step:801/1705 train_time:75748ms step_avg:94.57ms +step:802/1705 train_time:75842ms step_avg:94.57ms +step:803/1705 train_time:75935ms step_avg:94.56ms +step:804/1705 train_time:76030ms step_avg:94.56ms +step:805/1705 train_time:76125ms step_avg:94.57ms +step:806/1705 train_time:76219ms step_avg:94.56ms +step:807/1705 train_time:76313ms step_avg:94.56ms +step:808/1705 train_time:76408ms step_avg:94.56ms +step:809/1705 train_time:76503ms step_avg:94.57ms +step:810/1705 train_time:76597ms step_avg:94.56ms +step:811/1705 train_time:76691ms step_avg:94.56ms +step:812/1705 train_time:76786ms step_avg:94.56ms +step:813/1705 train_time:76879ms step_avg:94.56ms +step:814/1705 train_time:76973ms step_avg:94.56ms +step:815/1705 train_time:77068ms step_avg:94.56ms +step:816/1705 train_time:77163ms step_avg:94.56ms +step:817/1705 train_time:77257ms step_avg:94.56ms +step:818/1705 train_time:77352ms step_avg:94.56ms +step:819/1705 train_time:77447ms step_avg:94.56ms +step:820/1705 train_time:77541ms step_avg:94.56ms +step:821/1705 train_time:77635ms step_avg:94.56ms +step:822/1705 train_time:77729ms step_avg:94.56ms +step:823/1705 train_time:77824ms step_avg:94.56ms +step:824/1705 train_time:77918ms step_avg:94.56ms +step:825/1705 train_time:78012ms step_avg:94.56ms +step:826/1705 train_time:78107ms step_avg:94.56ms +step:827/1705 train_time:78201ms step_avg:94.56ms +step:828/1705 train_time:78295ms step_avg:94.56ms +step:829/1705 train_time:78389ms step_avg:94.56ms +step:830/1705 train_time:78484ms step_avg:94.56ms +step:831/1705 train_time:78577ms step_avg:94.56ms +step:832/1705 train_time:78672ms step_avg:94.56ms +step:833/1705 train_time:78768ms step_avg:94.56ms +step:834/1705 train_time:78861ms step_avg:94.56ms +step:835/1705 train_time:78955ms step_avg:94.56ms +step:836/1705 train_time:79050ms 
step_avg:94.56ms +step:837/1705 train_time:79145ms step_avg:94.56ms +step:838/1705 train_time:79239ms step_avg:94.56ms +step:839/1705 train_time:79333ms step_avg:94.56ms +step:840/1705 train_time:79429ms step_avg:94.56ms +step:841/1705 train_time:79525ms step_avg:94.56ms +step:842/1705 train_time:79619ms step_avg:94.56ms +step:843/1705 train_time:79713ms step_avg:94.56ms +step:844/1705 train_time:79808ms step_avg:94.56ms +step:845/1705 train_time:79902ms step_avg:94.56ms +step:846/1705 train_time:79997ms step_avg:94.56ms +step:847/1705 train_time:80092ms step_avg:94.56ms +step:848/1705 train_time:80187ms step_avg:94.56ms +step:849/1705 train_time:80281ms step_avg:94.56ms +step:850/1705 train_time:80375ms step_avg:94.56ms +step:851/1705 train_time:80637ms step_avg:94.76ms +step:852/1705 train_time:80715ms step_avg:94.74ms +step:853/1705 train_time:80808ms step_avg:94.73ms +step:854/1705 train_time:80901ms step_avg:94.73ms +step:855/1705 train_time:80994ms step_avg:94.73ms +step:856/1705 train_time:81088ms step_avg:94.73ms +step:857/1705 train_time:81181ms step_avg:94.73ms +step:858/1705 train_time:81275ms step_avg:94.73ms +step:859/1705 train_time:81368ms step_avg:94.72ms +step:860/1705 train_time:81461ms step_avg:94.72ms +step:861/1705 train_time:81561ms step_avg:94.73ms +step:862/1705 train_time:81658ms step_avg:94.73ms +step:863/1705 train_time:81755ms step_avg:94.73ms +step:864/1705 train_time:81850ms step_avg:94.73ms +step:865/1705 train_time:81944ms step_avg:94.73ms +step:866/1705 train_time:82038ms step_avg:94.73ms +step:867/1705 train_time:82131ms step_avg:94.73ms +step:868/1705 train_time:82225ms step_avg:94.73ms +step:869/1705 train_time:82318ms step_avg:94.73ms +step:870/1705 train_time:82412ms step_avg:94.73ms +step:871/1705 train_time:82508ms step_avg:94.73ms +step:872/1705 train_time:82604ms step_avg:94.73ms +step:873/1705 train_time:82699ms step_avg:94.73ms +step:874/1705 train_time:82794ms step_avg:94.73ms +step:875/1705 train_time:82889ms step_avg:94.73ms +step:875/1705 val_loss:3.5253 train_time:82984ms step_avg:94.84ms +step:876/1705 train_time:83006ms step_avg:94.76ms +step:877/1705 train_time:83083ms step_avg:94.74ms +step:878/1705 train_time:83183ms step_avg:94.74ms +step:879/1705 train_time:83277ms step_avg:94.74ms +step:880/1705 train_time:83371ms step_avg:94.74ms +step:881/1705 train_time:83465ms step_avg:94.74ms +step:882/1705 train_time:83558ms step_avg:94.74ms +step:883/1705 train_time:83652ms step_avg:94.74ms +step:884/1705 train_time:83746ms step_avg:94.74ms +step:885/1705 train_time:83841ms step_avg:94.74ms +step:886/1705 train_time:83937ms step_avg:94.74ms +step:887/1705 train_time:84032ms step_avg:94.74ms +step:888/1705 train_time:84129ms step_avg:94.74ms +step:889/1705 train_time:84224ms step_avg:94.74ms +step:890/1705 train_time:84319ms step_avg:94.74ms +step:891/1705 train_time:84412ms step_avg:94.74ms +step:892/1705 train_time:84507ms step_avg:94.74ms +step:893/1705 train_time:84601ms step_avg:94.74ms +step:894/1705 train_time:84695ms step_avg:94.74ms +step:895/1705 train_time:84789ms step_avg:94.74ms +step:896/1705 train_time:84884ms step_avg:94.74ms +step:897/1705 train_time:84979ms step_avg:94.74ms +step:898/1705 train_time:85074ms step_avg:94.74ms +step:899/1705 train_time:85170ms step_avg:94.74ms +step:900/1705 train_time:85265ms step_avg:94.74ms +step:901/1705 train_time:85361ms step_avg:94.74ms +step:902/1705 train_time:85454ms step_avg:94.74ms +step:903/1705 train_time:85548ms step_avg:94.74ms +step:904/1705 train_time:85642ms step_avg:94.74ms 
+step:905/1705 train_time:85736ms step_avg:94.74ms +step:906/1705 train_time:85829ms step_avg:94.73ms +step:907/1705 train_time:85924ms step_avg:94.73ms +step:908/1705 train_time:86019ms step_avg:94.73ms +step:909/1705 train_time:86114ms step_avg:94.73ms +step:910/1705 train_time:86209ms step_avg:94.73ms +step:911/1705 train_time:86304ms step_avg:94.74ms +step:912/1705 train_time:86399ms step_avg:94.74ms +step:913/1705 train_time:86493ms step_avg:94.74ms +step:914/1705 train_time:86588ms step_avg:94.73ms +step:915/1705 train_time:86682ms step_avg:94.73ms +step:916/1705 train_time:86776ms step_avg:94.73ms +step:917/1705 train_time:86870ms step_avg:94.73ms +step:918/1705 train_time:86965ms step_avg:94.73ms +step:919/1705 train_time:87060ms step_avg:94.73ms +step:920/1705 train_time:87155ms step_avg:94.73ms +step:921/1705 train_time:87250ms step_avg:94.73ms +step:922/1705 train_time:87345ms step_avg:94.73ms +step:923/1705 train_time:87439ms step_avg:94.73ms +step:924/1705 train_time:87532ms step_avg:94.73ms +step:925/1705 train_time:87627ms step_avg:94.73ms +step:926/1705 train_time:87721ms step_avg:94.73ms +step:927/1705 train_time:87815ms step_avg:94.73ms +step:928/1705 train_time:87910ms step_avg:94.73ms +step:929/1705 train_time:88005ms step_avg:94.73ms +step:930/1705 train_time:88099ms step_avg:94.73ms +step:931/1705 train_time:88193ms step_avg:94.73ms +step:932/1705 train_time:88288ms step_avg:94.73ms +step:933/1705 train_time:88382ms step_avg:94.73ms +step:934/1705 train_time:88476ms step_avg:94.73ms +step:935/1705 train_time:88570ms step_avg:94.73ms +step:936/1705 train_time:88665ms step_avg:94.73ms +step:937/1705 train_time:88759ms step_avg:94.73ms +step:938/1705 train_time:88853ms step_avg:94.73ms +step:939/1705 train_time:88948ms step_avg:94.73ms +step:940/1705 train_time:89044ms step_avg:94.73ms +step:941/1705 train_time:89139ms step_avg:94.73ms +step:942/1705 train_time:89233ms step_avg:94.73ms +step:943/1705 train_time:89328ms step_avg:94.73ms +step:944/1705 train_time:89423ms step_avg:94.73ms +step:945/1705 train_time:89516ms step_avg:94.73ms +step:946/1705 train_time:89611ms step_avg:94.73ms +step:947/1705 train_time:89706ms step_avg:94.73ms +step:948/1705 train_time:89800ms step_avg:94.73ms +step:949/1705 train_time:89894ms step_avg:94.73ms +step:950/1705 train_time:89989ms step_avg:94.72ms +step:951/1705 train_time:90083ms step_avg:94.72ms +step:952/1705 train_time:90178ms step_avg:94.72ms +step:953/1705 train_time:90272ms step_avg:94.72ms +step:954/1705 train_time:90367ms step_avg:94.72ms +step:955/1705 train_time:90462ms step_avg:94.72ms +step:956/1705 train_time:90556ms step_avg:94.72ms +step:957/1705 train_time:90650ms step_avg:94.72ms +step:958/1705 train_time:90745ms step_avg:94.72ms +step:959/1705 train_time:90839ms step_avg:94.72ms +step:960/1705 train_time:90933ms step_avg:94.72ms +step:961/1705 train_time:91027ms step_avg:94.72ms +step:962/1705 train_time:91122ms step_avg:94.72ms +step:963/1705 train_time:91217ms step_avg:94.72ms +step:964/1705 train_time:91311ms step_avg:94.72ms +step:965/1705 train_time:91407ms step_avg:94.72ms +step:966/1705 train_time:91501ms step_avg:94.72ms +step:967/1705 train_time:91595ms step_avg:94.72ms +step:968/1705 train_time:91689ms step_avg:94.72ms +step:969/1705 train_time:91783ms step_avg:94.72ms +step:970/1705 train_time:91878ms step_avg:94.72ms +step:971/1705 train_time:91971ms step_avg:94.72ms +step:972/1705 train_time:92067ms step_avg:94.72ms +step:973/1705 train_time:92162ms step_avg:94.72ms +step:974/1705 train_time:92257ms 
step_avg:94.72ms +step:975/1705 train_time:92351ms step_avg:94.72ms +step:976/1705 train_time:92446ms step_avg:94.72ms +step:977/1705 train_time:92540ms step_avg:94.72ms +step:978/1705 train_time:92634ms step_avg:94.72ms +step:979/1705 train_time:92729ms step_avg:94.72ms +step:980/1705 train_time:92824ms step_avg:94.72ms +step:981/1705 train_time:92918ms step_avg:94.72ms +step:982/1705 train_time:93012ms step_avg:94.72ms +step:983/1705 train_time:93109ms step_avg:94.72ms +step:984/1705 train_time:93204ms step_avg:94.72ms +step:985/1705 train_time:93298ms step_avg:94.72ms +step:986/1705 train_time:93392ms step_avg:94.72ms +step:987/1705 train_time:93487ms step_avg:94.72ms +step:988/1705 train_time:93582ms step_avg:94.72ms +step:989/1705 train_time:93676ms step_avg:94.72ms +step:990/1705 train_time:93771ms step_avg:94.72ms +step:991/1705 train_time:93866ms step_avg:94.72ms +step:992/1705 train_time:93960ms step_avg:94.72ms +step:993/1705 train_time:94054ms step_avg:94.72ms +step:994/1705 train_time:94148ms step_avg:94.72ms +step:995/1705 train_time:94244ms step_avg:94.72ms +step:996/1705 train_time:94338ms step_avg:94.72ms +step:997/1705 train_time:94432ms step_avg:94.72ms +step:998/1705 train_time:94526ms step_avg:94.72ms +step:999/1705 train_time:94621ms step_avg:94.72ms +step:1000/1705 train_time:94716ms step_avg:94.72ms +step:1000/1705 val_loss:3.4858 train_time:94810ms step_avg:94.81ms +step:1001/1705 train_time:94832ms step_avg:94.74ms +step:1002/1705 train_time:94912ms step_avg:94.72ms +step:1003/1705 train_time:95012ms step_avg:94.73ms +step:1004/1705 train_time:95108ms step_avg:94.73ms +step:1005/1705 train_time:95201ms step_avg:94.73ms +step:1006/1705 train_time:95295ms step_avg:94.73ms +step:1007/1705 train_time:95389ms step_avg:94.73ms +step:1008/1705 train_time:95481ms step_avg:94.72ms +step:1009/1705 train_time:95574ms step_avg:94.72ms +step:1010/1705 train_time:95667ms step_avg:94.72ms +step:1011/1705 train_time:95762ms step_avg:94.72ms +step:1012/1705 train_time:95858ms step_avg:94.72ms +step:1013/1705 train_time:95954ms step_avg:94.72ms +step:1014/1705 train_time:96051ms step_avg:94.72ms +step:1015/1705 train_time:96147ms step_avg:94.73ms +step:1016/1705 train_time:96241ms step_avg:94.73ms +step:1017/1705 train_time:96334ms step_avg:94.72ms +step:1018/1705 train_time:96428ms step_avg:94.72ms +step:1019/1705 train_time:96522ms step_avg:94.72ms +step:1020/1705 train_time:96615ms step_avg:94.72ms +step:1021/1705 train_time:96709ms step_avg:94.72ms +step:1022/1705 train_time:96803ms step_avg:94.72ms +step:1023/1705 train_time:96899ms step_avg:94.72ms +step:1024/1705 train_time:96995ms step_avg:94.72ms +step:1025/1705 train_time:97090ms step_avg:94.72ms +step:1026/1705 train_time:97186ms step_avg:94.72ms +step:1027/1705 train_time:97280ms step_avg:94.72ms +step:1028/1705 train_time:97374ms step_avg:94.72ms +step:1029/1705 train_time:97468ms step_avg:94.72ms +step:1030/1705 train_time:97561ms step_avg:94.72ms +step:1031/1705 train_time:97654ms step_avg:94.72ms +step:1032/1705 train_time:97749ms step_avg:94.72ms +step:1033/1705 train_time:97844ms step_avg:94.72ms +step:1034/1705 train_time:97939ms step_avg:94.72ms +step:1035/1705 train_time:98033ms step_avg:94.72ms +step:1036/1705 train_time:98130ms step_avg:94.72ms +step:1037/1705 train_time:98226ms step_avg:94.72ms +step:1038/1705 train_time:98321ms step_avg:94.72ms +step:1039/1705 train_time:98414ms step_avg:94.72ms +step:1040/1705 train_time:98509ms step_avg:94.72ms +step:1041/1705 train_time:98603ms step_avg:94.72ms 
+step:1042/1705 train_time:98698ms step_avg:94.72ms +step:1043/1705 train_time:98791ms step_avg:94.72ms +step:1044/1705 train_time:98888ms step_avg:94.72ms +step:1045/1705 train_time:98982ms step_avg:94.72ms +step:1046/1705 train_time:99076ms step_avg:94.72ms +step:1047/1705 train_time:99171ms step_avg:94.72ms +step:1048/1705 train_time:99266ms step_avg:94.72ms +step:1049/1705 train_time:99361ms step_avg:94.72ms +step:1050/1705 train_time:99455ms step_avg:94.72ms +step:1051/1705 train_time:99549ms step_avg:94.72ms +step:1052/1705 train_time:99644ms step_avg:94.72ms +step:1053/1705 train_time:99738ms step_avg:94.72ms +step:1054/1705 train_time:99832ms step_avg:94.72ms +step:1055/1705 train_time:99927ms step_avg:94.72ms +step:1056/1705 train_time:100022ms step_avg:94.72ms +step:1057/1705 train_time:100117ms step_avg:94.72ms +step:1058/1705 train_time:100211ms step_avg:94.72ms +step:1059/1705 train_time:100306ms step_avg:94.72ms +step:1060/1705 train_time:100401ms step_avg:94.72ms +step:1061/1705 train_time:100494ms step_avg:94.72ms +step:1062/1705 train_time:100743ms step_avg:94.86ms +step:1063/1705 train_time:100846ms step_avg:94.87ms +step:1064/1705 train_time:100938ms step_avg:94.87ms +step:1065/1705 train_time:101032ms step_avg:94.87ms +step:1066/1705 train_time:101126ms step_avg:94.87ms +step:1067/1705 train_time:101219ms step_avg:94.86ms +step:1068/1705 train_time:101313ms step_avg:94.86ms +step:1069/1705 train_time:101407ms step_avg:94.86ms +step:1070/1705 train_time:101500ms step_avg:94.86ms +step:1071/1705 train_time:101593ms step_avg:94.86ms +step:1072/1705 train_time:101693ms step_avg:94.86ms +step:1073/1705 train_time:101791ms step_avg:94.87ms +step:1074/1705 train_time:101888ms step_avg:94.87ms +step:1075/1705 train_time:101982ms step_avg:94.87ms +step:1076/1705 train_time:102076ms step_avg:94.87ms +step:1077/1705 train_time:102170ms step_avg:94.87ms +step:1078/1705 train_time:102264ms step_avg:94.86ms +step:1079/1705 train_time:102357ms step_avg:94.86ms +step:1080/1705 train_time:102451ms step_avg:94.86ms +step:1081/1705 train_time:102544ms step_avg:94.86ms +step:1082/1705 train_time:102639ms step_avg:94.86ms +step:1083/1705 train_time:102734ms step_avg:94.86ms +step:1084/1705 train_time:102831ms step_avg:94.86ms +step:1085/1705 train_time:102928ms step_avg:94.86ms +step:1086/1705 train_time:103023ms step_avg:94.86ms +step:1087/1705 train_time:103117ms step_avg:94.86ms +step:1088/1705 train_time:103211ms step_avg:94.86ms +step:1089/1705 train_time:103305ms step_avg:94.86ms +step:1090/1705 train_time:103399ms step_avg:94.86ms +step:1091/1705 train_time:103493ms step_avg:94.86ms +step:1092/1705 train_time:103587ms step_avg:94.86ms +step:1093/1705 train_time:103682ms step_avg:94.86ms +step:1094/1705 train_time:103778ms step_avg:94.86ms +step:1095/1705 train_time:103873ms step_avg:94.86ms +step:1096/1705 train_time:103967ms step_avg:94.86ms +step:1097/1705 train_time:104062ms step_avg:94.86ms +step:1098/1705 train_time:104156ms step_avg:94.86ms +step:1099/1705 train_time:104250ms step_avg:94.86ms +step:1100/1705 train_time:104345ms step_avg:94.86ms +step:1101/1705 train_time:104439ms step_avg:94.86ms +step:1102/1705 train_time:104533ms step_avg:94.86ms +step:1103/1705 train_time:104628ms step_avg:94.86ms +step:1104/1705 train_time:104723ms step_avg:94.86ms +step:1105/1705 train_time:104818ms step_avg:94.86ms +step:1106/1705 train_time:104912ms step_avg:94.86ms +step:1107/1705 train_time:105007ms step_avg:94.86ms +step:1108/1705 train_time:105102ms step_avg:94.86ms +step:1109/1705 
train_time:105196ms step_avg:94.86ms +step:1110/1705 train_time:105290ms step_avg:94.86ms +step:1111/1705 train_time:105385ms step_avg:94.86ms +step:1112/1705 train_time:105480ms step_avg:94.86ms +step:1113/1705 train_time:105573ms step_avg:94.85ms +step:1114/1705 train_time:105668ms step_avg:94.85ms +step:1115/1705 train_time:105763ms step_avg:94.86ms +step:1116/1705 train_time:105858ms step_avg:94.85ms +step:1117/1705 train_time:105953ms step_avg:94.85ms +step:1118/1705 train_time:106047ms step_avg:94.85ms +step:1119/1705 train_time:106142ms step_avg:94.85ms +step:1120/1705 train_time:106237ms step_avg:94.85ms +step:1121/1705 train_time:106331ms step_avg:94.85ms +step:1122/1705 train_time:106427ms step_avg:94.85ms +step:1123/1705 train_time:106521ms step_avg:94.85ms +step:1124/1705 train_time:106615ms step_avg:94.85ms +step:1125/1705 train_time:106710ms step_avg:94.85ms +step:1125/1705 val_loss:3.4384 train_time:106806ms step_avg:94.94ms +step:1126/1705 train_time:106828ms step_avg:94.87ms +step:1127/1705 train_time:106907ms step_avg:94.86ms +step:1128/1705 train_time:107005ms step_avg:94.86ms +step:1129/1705 train_time:107099ms step_avg:94.86ms +step:1130/1705 train_time:107193ms step_avg:94.86ms +step:1131/1705 train_time:107287ms step_avg:94.86ms +step:1132/1705 train_time:107380ms step_avg:94.86ms +step:1133/1705 train_time:107474ms step_avg:94.86ms +step:1134/1705 train_time:107568ms step_avg:94.86ms +step:1135/1705 train_time:107662ms step_avg:94.86ms +step:1136/1705 train_time:107757ms step_avg:94.86ms +step:1137/1705 train_time:107854ms step_avg:94.86ms +step:1138/1705 train_time:107950ms step_avg:94.86ms +step:1139/1705 train_time:108047ms step_avg:94.86ms +step:1140/1705 train_time:108142ms step_avg:94.86ms +step:1141/1705 train_time:108236ms step_avg:94.86ms +step:1142/1705 train_time:108330ms step_avg:94.86ms +step:1143/1705 train_time:108425ms step_avg:94.86ms +step:1144/1705 train_time:108520ms step_avg:94.86ms +step:1145/1705 train_time:108614ms step_avg:94.86ms +step:1146/1705 train_time:108709ms step_avg:94.86ms +step:1147/1705 train_time:108805ms step_avg:94.86ms +step:1148/1705 train_time:108901ms step_avg:94.86ms +step:1149/1705 train_time:108997ms step_avg:94.86ms +step:1150/1705 train_time:109093ms step_avg:94.86ms +step:1151/1705 train_time:109188ms step_avg:94.86ms +step:1152/1705 train_time:109284ms step_avg:94.86ms +step:1153/1705 train_time:109378ms step_avg:94.86ms +step:1154/1705 train_time:109473ms step_avg:94.86ms +step:1155/1705 train_time:109569ms step_avg:94.87ms +step:1156/1705 train_time:109664ms step_avg:94.86ms +step:1157/1705 train_time:109759ms step_avg:94.86ms +step:1158/1705 train_time:109855ms step_avg:94.87ms +step:1159/1705 train_time:109951ms step_avg:94.87ms +step:1160/1705 train_time:110047ms step_avg:94.87ms +step:1161/1705 train_time:110142ms step_avg:94.87ms +step:1162/1705 train_time:110236ms step_avg:94.87ms +step:1163/1705 train_time:110331ms step_avg:94.87ms +step:1164/1705 train_time:110427ms step_avg:94.87ms +step:1165/1705 train_time:110522ms step_avg:94.87ms +step:1166/1705 train_time:110617ms step_avg:94.87ms +step:1167/1705 train_time:110713ms step_avg:94.87ms +step:1168/1705 train_time:110809ms step_avg:94.87ms +step:1169/1705 train_time:110905ms step_avg:94.87ms +step:1170/1705 train_time:111000ms step_avg:94.87ms +step:1171/1705 train_time:111097ms step_avg:94.87ms +step:1172/1705 train_time:111193ms step_avg:94.87ms +step:1173/1705 train_time:111289ms step_avg:94.88ms +step:1174/1705 train_time:111384ms step_avg:94.88ms 
+step:1175/1705 train_time:111479ms step_avg:94.88ms +step:1176/1705 train_time:111574ms step_avg:94.88ms +step:1177/1705 train_time:111669ms step_avg:94.88ms +step:1178/1705 train_time:111764ms step_avg:94.88ms +step:1179/1705 train_time:111861ms step_avg:94.88ms +step:1180/1705 train_time:111955ms step_avg:94.88ms +step:1181/1705 train_time:112050ms step_avg:94.88ms +step:1182/1705 train_time:112146ms step_avg:94.88ms +step:1183/1705 train_time:112241ms step_avg:94.88ms +step:1184/1705 train_time:112336ms step_avg:94.88ms +step:1185/1705 train_time:112432ms step_avg:94.88ms +step:1186/1705 train_time:112527ms step_avg:94.88ms +step:1187/1705 train_time:112622ms step_avg:94.88ms +step:1188/1705 train_time:112717ms step_avg:94.88ms +step:1189/1705 train_time:112813ms step_avg:94.88ms +step:1190/1705 train_time:112909ms step_avg:94.88ms +step:1191/1705 train_time:113005ms step_avg:94.88ms +step:1192/1705 train_time:113100ms step_avg:94.88ms +step:1193/1705 train_time:113196ms step_avg:94.88ms +step:1194/1705 train_time:113291ms step_avg:94.88ms +step:1195/1705 train_time:113386ms step_avg:94.88ms +step:1196/1705 train_time:113482ms step_avg:94.88ms +step:1197/1705 train_time:113576ms step_avg:94.88ms +step:1198/1705 train_time:113673ms step_avg:94.89ms +step:1199/1705 train_time:113769ms step_avg:94.89ms +step:1200/1705 train_time:113864ms step_avg:94.89ms +step:1201/1705 train_time:113959ms step_avg:94.89ms +step:1202/1705 train_time:114054ms step_avg:94.89ms +step:1203/1705 train_time:114150ms step_avg:94.89ms +step:1204/1705 train_time:114246ms step_avg:94.89ms +step:1205/1705 train_time:114341ms step_avg:94.89ms +step:1206/1705 train_time:114436ms step_avg:94.89ms +step:1207/1705 train_time:114532ms step_avg:94.89ms +step:1208/1705 train_time:114628ms step_avg:94.89ms +step:1209/1705 train_time:114723ms step_avg:94.89ms +step:1210/1705 train_time:114818ms step_avg:94.89ms +step:1211/1705 train_time:114914ms step_avg:94.89ms +step:1212/1705 train_time:115010ms step_avg:94.89ms +step:1213/1705 train_time:115105ms step_avg:94.89ms +step:1214/1705 train_time:115200ms step_avg:94.89ms +step:1215/1705 train_time:115295ms step_avg:94.89ms +step:1216/1705 train_time:115391ms step_avg:94.89ms +step:1217/1705 train_time:115486ms step_avg:94.89ms +step:1218/1705 train_time:115582ms step_avg:94.90ms +step:1219/1705 train_time:115677ms step_avg:94.89ms +step:1220/1705 train_time:115772ms step_avg:94.90ms +step:1221/1705 train_time:115868ms step_avg:94.90ms +step:1222/1705 train_time:115963ms step_avg:94.90ms +step:1223/1705 train_time:116058ms step_avg:94.90ms +step:1224/1705 train_time:116154ms step_avg:94.90ms +step:1225/1705 train_time:116250ms step_avg:94.90ms +step:1226/1705 train_time:116347ms step_avg:94.90ms +step:1227/1705 train_time:116442ms step_avg:94.90ms +step:1228/1705 train_time:116536ms step_avg:94.90ms +step:1229/1705 train_time:116631ms step_avg:94.90ms +step:1230/1705 train_time:116726ms step_avg:94.90ms +step:1231/1705 train_time:116822ms step_avg:94.90ms +step:1232/1705 train_time:116916ms step_avg:94.90ms +step:1233/1705 train_time:117011ms step_avg:94.90ms +step:1234/1705 train_time:117107ms step_avg:94.90ms +step:1235/1705 train_time:117202ms step_avg:94.90ms +step:1236/1705 train_time:117297ms step_avg:94.90ms +step:1237/1705 train_time:117393ms step_avg:94.90ms +step:1238/1705 train_time:117489ms step_avg:94.90ms +step:1239/1705 train_time:117585ms step_avg:94.90ms +step:1240/1705 train_time:117680ms step_avg:94.90ms +step:1241/1705 train_time:117775ms step_avg:94.90ms 
+step:1242/1705 train_time:117870ms step_avg:94.90ms +step:1243/1705 train_time:117966ms step_avg:94.90ms +step:1244/1705 train_time:118062ms step_avg:94.90ms +step:1245/1705 train_time:118156ms step_avg:94.90ms +step:1246/1705 train_time:118252ms step_avg:94.91ms +step:1247/1705 train_time:118347ms step_avg:94.91ms +step:1248/1705 train_time:118443ms step_avg:94.91ms +step:1249/1705 train_time:118537ms step_avg:94.91ms +step:1250/1705 train_time:118633ms step_avg:94.91ms +step:1250/1705 val_loss:3.3898 train_time:118729ms step_avg:94.98ms +step:1251/1705 train_time:118752ms step_avg:94.93ms +step:1252/1705 train_time:118837ms step_avg:94.92ms +step:1253/1705 train_time:118932ms step_avg:94.92ms +step:1254/1705 train_time:119026ms step_avg:94.92ms +step:1255/1705 train_time:119120ms step_avg:94.92ms +step:1256/1705 train_time:119214ms step_avg:94.92ms +step:1257/1705 train_time:119309ms step_avg:94.92ms +step:1258/1705 train_time:119402ms step_avg:94.91ms +step:1259/1705 train_time:119496ms step_avg:94.91ms +step:1260/1705 train_time:119590ms step_avg:94.91ms +step:1261/1705 train_time:119689ms step_avg:94.92ms +step:1262/1705 train_time:119788ms step_avg:94.92ms +step:1263/1705 train_time:119887ms step_avg:94.92ms +step:1264/1705 train_time:119983ms step_avg:94.92ms +step:1265/1705 train_time:120078ms step_avg:94.92ms +step:1266/1705 train_time:120172ms step_avg:94.92ms +step:1267/1705 train_time:120268ms step_avg:94.92ms +step:1268/1705 train_time:120362ms step_avg:94.92ms +step:1269/1705 train_time:120456ms step_avg:94.92ms +step:1270/1705 train_time:120551ms step_avg:94.92ms +step:1271/1705 train_time:120646ms step_avg:94.92ms +step:1272/1705 train_time:120743ms step_avg:94.92ms +step:1273/1705 train_time:120840ms step_avg:94.93ms +step:1274/1705 train_time:121238ms step_avg:95.16ms +step:1275/1705 train_time:121309ms step_avg:95.14ms +step:1276/1705 train_time:121402ms step_avg:95.14ms +step:1277/1705 train_time:121496ms step_avg:95.14ms +step:1278/1705 train_time:121589ms step_avg:95.14ms +step:1279/1705 train_time:121683ms step_avg:95.14ms +step:1280/1705 train_time:121777ms step_avg:95.14ms +step:1281/1705 train_time:121870ms step_avg:95.14ms +step:1282/1705 train_time:121964ms step_avg:95.14ms +step:1283/1705 train_time:122058ms step_avg:95.13ms +step:1284/1705 train_time:122159ms step_avg:95.14ms +step:1285/1705 train_time:122257ms step_avg:95.14ms +step:1286/1705 train_time:122353ms step_avg:95.14ms +step:1287/1705 train_time:122448ms step_avg:95.14ms +step:1288/1705 train_time:122545ms step_avg:95.14ms +step:1289/1705 train_time:122639ms step_avg:95.14ms +step:1290/1705 train_time:122733ms step_avg:95.14ms +step:1291/1705 train_time:122827ms step_avg:95.14ms +step:1292/1705 train_time:122922ms step_avg:95.14ms +step:1293/1705 train_time:123016ms step_avg:95.14ms +step:1294/1705 train_time:123112ms step_avg:95.14ms +step:1295/1705 train_time:123209ms step_avg:95.14ms +step:1296/1705 train_time:123307ms step_avg:95.14ms +step:1297/1705 train_time:123403ms step_avg:95.14ms +step:1298/1705 train_time:123497ms step_avg:95.14ms +step:1299/1705 train_time:123593ms step_avg:95.14ms +step:1300/1705 train_time:123687ms step_avg:95.14ms +step:1301/1705 train_time:123783ms step_avg:95.14ms +step:1302/1705 train_time:123876ms step_avg:95.14ms +step:1303/1705 train_time:123970ms step_avg:95.14ms +step:1304/1705 train_time:124065ms step_avg:95.14ms +step:1305/1705 train_time:124161ms step_avg:95.14ms +step:1306/1705 train_time:124258ms step_avg:95.14ms +step:1307/1705 train_time:124354ms 
step_avg:95.14ms +step:1308/1705 train_time:124451ms step_avg:95.15ms +step:1309/1705 train_time:124546ms step_avg:95.15ms +step:1310/1705 train_time:124643ms step_avg:95.15ms +step:1311/1705 train_time:124739ms step_avg:95.15ms +step:1312/1705 train_time:124832ms step_avg:95.15ms +step:1313/1705 train_time:124927ms step_avg:95.15ms +step:1314/1705 train_time:125022ms step_avg:95.15ms +step:1315/1705 train_time:125116ms step_avg:95.15ms +step:1316/1705 train_time:125212ms step_avg:95.15ms +step:1317/1705 train_time:125308ms step_avg:95.15ms +step:1318/1705 train_time:125404ms step_avg:95.15ms +step:1319/1705 train_time:125500ms step_avg:95.15ms +step:1320/1705 train_time:125596ms step_avg:95.15ms +step:1321/1705 train_time:125692ms step_avg:95.15ms +step:1322/1705 train_time:125788ms step_avg:95.15ms +step:1323/1705 train_time:125884ms step_avg:95.15ms +step:1324/1705 train_time:125977ms step_avg:95.15ms +step:1325/1705 train_time:126072ms step_avg:95.15ms +step:1326/1705 train_time:126167ms step_avg:95.15ms +step:1327/1705 train_time:126263ms step_avg:95.15ms +step:1328/1705 train_time:126358ms step_avg:95.15ms +step:1329/1705 train_time:126454ms step_avg:95.15ms +step:1330/1705 train_time:126550ms step_avg:95.15ms +step:1331/1705 train_time:126646ms step_avg:95.15ms +step:1332/1705 train_time:126741ms step_avg:95.15ms +step:1333/1705 train_time:126835ms step_avg:95.15ms +step:1334/1705 train_time:126930ms step_avg:95.15ms +step:1335/1705 train_time:127025ms step_avg:95.15ms +step:1336/1705 train_time:127120ms step_avg:95.15ms +step:1337/1705 train_time:127215ms step_avg:95.15ms +step:1338/1705 train_time:127310ms step_avg:95.15ms +step:1339/1705 train_time:127406ms step_avg:95.15ms +step:1340/1705 train_time:127502ms step_avg:95.15ms +step:1341/1705 train_time:127596ms step_avg:95.15ms +step:1342/1705 train_time:127691ms step_avg:95.15ms +step:1343/1705 train_time:127788ms step_avg:95.15ms +step:1344/1705 train_time:127882ms step_avg:95.15ms +step:1345/1705 train_time:127977ms step_avg:95.15ms +step:1346/1705 train_time:128072ms step_avg:95.15ms +step:1347/1705 train_time:128167ms step_avg:95.15ms +step:1348/1705 train_time:128263ms step_avg:95.15ms +step:1349/1705 train_time:128359ms step_avg:95.15ms +step:1350/1705 train_time:128454ms step_avg:95.15ms +step:1351/1705 train_time:128551ms step_avg:95.15ms +step:1352/1705 train_time:128647ms step_avg:95.15ms +step:1353/1705 train_time:128743ms step_avg:95.15ms +step:1354/1705 train_time:128838ms step_avg:95.15ms +step:1355/1705 train_time:128933ms step_avg:95.15ms +step:1356/1705 train_time:129028ms step_avg:95.15ms +step:1357/1705 train_time:129123ms step_avg:95.15ms +step:1358/1705 train_time:129218ms step_avg:95.15ms +step:1359/1705 train_time:129313ms step_avg:95.15ms +step:1360/1705 train_time:129408ms step_avg:95.15ms +step:1361/1705 train_time:129505ms step_avg:95.15ms +step:1362/1705 train_time:129601ms step_avg:95.16ms +step:1363/1705 train_time:129696ms step_avg:95.15ms +step:1364/1705 train_time:129791ms step_avg:95.15ms +step:1365/1705 train_time:129886ms step_avg:95.15ms +step:1366/1705 train_time:129980ms step_avg:95.15ms +step:1367/1705 train_time:130075ms step_avg:95.15ms +step:1368/1705 train_time:130171ms step_avg:95.15ms +step:1369/1705 train_time:130266ms step_avg:95.15ms +step:1370/1705 train_time:130362ms step_avg:95.15ms +step:1371/1705 train_time:130457ms step_avg:95.15ms +step:1372/1705 train_time:130553ms step_avg:95.15ms +step:1373/1705 train_time:130649ms step_avg:95.16ms +step:1374/1705 train_time:130745ms 
step_avg:95.16ms +step:1375/1705 train_time:130840ms step_avg:95.16ms +step:1375/1705 val_loss:3.3524 train_time:130935ms step_avg:95.23ms +step:1376/1705 train_time:130957ms step_avg:95.17ms +step:1377/1705 train_time:131036ms step_avg:95.16ms +step:1378/1705 train_time:131134ms step_avg:95.16ms +step:1379/1705 train_time:131229ms step_avg:95.16ms +step:1380/1705 train_time:131324ms step_avg:95.16ms +step:1381/1705 train_time:131419ms step_avg:95.16ms +step:1382/1705 train_time:131513ms step_avg:95.16ms +step:1383/1705 train_time:131607ms step_avg:95.16ms +step:1384/1705 train_time:131703ms step_avg:95.16ms +step:1385/1705 train_time:131797ms step_avg:95.16ms +step:1386/1705 train_time:131893ms step_avg:95.16ms +step:1387/1705 train_time:131991ms step_avg:95.16ms +step:1388/1705 train_time:132089ms step_avg:95.16ms +step:1389/1705 train_time:132184ms step_avg:95.16ms +step:1390/1705 train_time:132279ms step_avg:95.16ms +step:1391/1705 train_time:132374ms step_avg:95.16ms +step:1392/1705 train_time:132468ms step_avg:95.16ms +step:1393/1705 train_time:132562ms step_avg:95.16ms +step:1394/1705 train_time:132657ms step_avg:95.16ms +step:1395/1705 train_time:132752ms step_avg:95.16ms +step:1396/1705 train_time:132847ms step_avg:95.16ms +step:1397/1705 train_time:132942ms step_avg:95.16ms +step:1398/1705 train_time:133039ms step_avg:95.16ms +step:1399/1705 train_time:133134ms step_avg:95.16ms +step:1400/1705 train_time:133230ms step_avg:95.16ms +step:1401/1705 train_time:133325ms step_avg:95.16ms +step:1402/1705 train_time:133420ms step_avg:95.16ms +step:1403/1705 train_time:133515ms step_avg:95.16ms +step:1404/1705 train_time:133611ms step_avg:95.16ms +step:1405/1705 train_time:133705ms step_avg:95.16ms +step:1406/1705 train_time:133801ms step_avg:95.16ms +step:1407/1705 train_time:133897ms step_avg:95.16ms +step:1408/1705 train_time:133993ms step_avg:95.17ms +step:1409/1705 train_time:134089ms step_avg:95.17ms +step:1410/1705 train_time:134184ms step_avg:95.17ms +step:1411/1705 train_time:134280ms step_avg:95.17ms +step:1412/1705 train_time:134375ms step_avg:95.17ms +step:1413/1705 train_time:134470ms step_avg:95.17ms +step:1414/1705 train_time:134565ms step_avg:95.17ms +step:1415/1705 train_time:134660ms step_avg:95.17ms +step:1416/1705 train_time:134755ms step_avg:95.17ms +step:1417/1705 train_time:134850ms step_avg:95.17ms +step:1418/1705 train_time:134944ms step_avg:95.17ms +step:1419/1705 train_time:135040ms step_avg:95.17ms +step:1420/1705 train_time:135136ms step_avg:95.17ms +step:1421/1705 train_time:135231ms step_avg:95.17ms +step:1422/1705 train_time:135326ms step_avg:95.17ms +step:1423/1705 train_time:135422ms step_avg:95.17ms +step:1424/1705 train_time:135517ms step_avg:95.17ms +step:1425/1705 train_time:135613ms step_avg:95.17ms +step:1426/1705 train_time:135707ms step_avg:95.17ms +step:1427/1705 train_time:135802ms step_avg:95.17ms +step:1428/1705 train_time:135898ms step_avg:95.17ms +step:1429/1705 train_time:135993ms step_avg:95.17ms +step:1430/1705 train_time:136089ms step_avg:95.17ms +step:1431/1705 train_time:136183ms step_avg:95.17ms +step:1432/1705 train_time:136279ms step_avg:95.17ms +step:1433/1705 train_time:136375ms step_avg:95.17ms +step:1434/1705 train_time:136470ms step_avg:95.17ms +step:1435/1705 train_time:136565ms step_avg:95.17ms +step:1436/1705 train_time:136660ms step_avg:95.17ms +step:1437/1705 train_time:136756ms step_avg:95.17ms +step:1438/1705 train_time:136850ms step_avg:95.17ms +step:1439/1705 train_time:136945ms step_avg:95.17ms +step:1440/1705 
train_time:137041ms step_avg:95.17ms +step:1441/1705 train_time:137137ms step_avg:95.17ms +step:1442/1705 train_time:137232ms step_avg:95.17ms +step:1443/1705 train_time:137328ms step_avg:95.17ms +step:1444/1705 train_time:137423ms step_avg:95.17ms +step:1445/1705 train_time:137519ms step_avg:95.17ms +step:1446/1705 train_time:137614ms step_avg:95.17ms +step:1447/1705 train_time:137709ms step_avg:95.17ms +step:1448/1705 train_time:137804ms step_avg:95.17ms +step:1449/1705 train_time:137901ms step_avg:95.17ms +step:1450/1705 train_time:137996ms step_avg:95.17ms +step:1451/1705 train_time:138091ms step_avg:95.17ms +step:1452/1705 train_time:138186ms step_avg:95.17ms +step:1453/1705 train_time:138280ms step_avg:95.17ms +step:1454/1705 train_time:138377ms step_avg:95.17ms +step:1455/1705 train_time:138474ms step_avg:95.17ms +step:1456/1705 train_time:138570ms step_avg:95.17ms +step:1457/1705 train_time:138664ms step_avg:95.17ms +step:1458/1705 train_time:138759ms step_avg:95.17ms +step:1459/1705 train_time:138855ms step_avg:95.17ms +step:1460/1705 train_time:138950ms step_avg:95.17ms +step:1461/1705 train_time:139045ms step_avg:95.17ms +step:1462/1705 train_time:139141ms step_avg:95.17ms +step:1463/1705 train_time:139237ms step_avg:95.17ms +step:1464/1705 train_time:139332ms step_avg:95.17ms +step:1465/1705 train_time:139426ms step_avg:95.17ms +step:1466/1705 train_time:139522ms step_avg:95.17ms +step:1467/1705 train_time:139618ms step_avg:95.17ms +step:1468/1705 train_time:139714ms step_avg:95.17ms +step:1469/1705 train_time:139808ms step_avg:95.17ms +step:1470/1705 train_time:139902ms step_avg:95.17ms +step:1471/1705 train_time:139998ms step_avg:95.17ms +step:1472/1705 train_time:140094ms step_avg:95.17ms +step:1473/1705 train_time:140189ms step_avg:95.17ms +step:1474/1705 train_time:140284ms step_avg:95.17ms +step:1475/1705 train_time:140381ms step_avg:95.17ms +step:1476/1705 train_time:140476ms step_avg:95.17ms +step:1477/1705 train_time:140571ms step_avg:95.17ms +step:1478/1705 train_time:140666ms step_avg:95.17ms +step:1479/1705 train_time:140761ms step_avg:95.17ms +step:1480/1705 train_time:140857ms step_avg:95.17ms +step:1481/1705 train_time:140953ms step_avg:95.17ms +step:1482/1705 train_time:141050ms step_avg:95.18ms +step:1483/1705 train_time:141145ms step_avg:95.18ms +step:1484/1705 train_time:141240ms step_avg:95.17ms +step:1485/1705 train_time:141597ms step_avg:95.35ms +step:1486/1705 train_time:141697ms step_avg:95.35ms +step:1487/1705 train_time:141790ms step_avg:95.35ms +step:1488/1705 train_time:141884ms step_avg:95.35ms +step:1489/1705 train_time:141978ms step_avg:95.35ms +step:1490/1705 train_time:142072ms step_avg:95.35ms +step:1491/1705 train_time:142166ms step_avg:95.35ms +step:1492/1705 train_time:142260ms step_avg:95.35ms +step:1493/1705 train_time:142354ms step_avg:95.35ms +step:1494/1705 train_time:142449ms step_avg:95.35ms +step:1495/1705 train_time:142546ms step_avg:95.35ms +step:1496/1705 train_time:142645ms step_avg:95.35ms +step:1497/1705 train_time:142741ms step_avg:95.35ms +step:1498/1705 train_time:142836ms step_avg:95.35ms +step:1499/1705 train_time:142931ms step_avg:95.35ms +step:1500/1705 train_time:143026ms step_avg:95.35ms +step:1500/1705 val_loss:3.3200 train_time:143120ms step_avg:95.41ms +step:1501/1705 train_time:143143ms step_avg:95.37ms +step:1502/1705 train_time:143221ms step_avg:95.35ms +step:1503/1705 train_time:143318ms step_avg:95.35ms +step:1504/1705 train_time:143414ms step_avg:95.35ms +step:1505/1705 train_time:143508ms step_avg:95.35ms 
+step:1506/1705 train_time:143602ms step_avg:95.35ms +step:1507/1705 train_time:143696ms step_avg:95.35ms +step:1508/1705 train_time:143791ms step_avg:95.35ms +step:1509/1705 train_time:143885ms step_avg:95.35ms +step:1510/1705 train_time:143979ms step_avg:95.35ms +step:1511/1705 train_time:144075ms step_avg:95.35ms +step:1512/1705 train_time:144176ms step_avg:95.35ms +step:1513/1705 train_time:144273ms step_avg:95.36ms +step:1514/1705 train_time:144370ms step_avg:95.36ms +step:1515/1705 train_time:144466ms step_avg:95.36ms +step:1516/1705 train_time:144560ms step_avg:95.36ms +step:1517/1705 train_time:144654ms step_avg:95.36ms +step:1518/1705 train_time:144748ms step_avg:95.35ms +step:1519/1705 train_time:144843ms step_avg:95.35ms +step:1520/1705 train_time:144937ms step_avg:95.35ms +step:1521/1705 train_time:145032ms step_avg:95.35ms +step:1522/1705 train_time:145129ms step_avg:95.35ms +step:1523/1705 train_time:145227ms step_avg:95.36ms +step:1524/1705 train_time:145323ms step_avg:95.36ms +step:1525/1705 train_time:145419ms step_avg:95.36ms +step:1526/1705 train_time:145514ms step_avg:95.36ms +step:1527/1705 train_time:145609ms step_avg:95.36ms +step:1528/1705 train_time:145704ms step_avg:95.36ms +step:1529/1705 train_time:145798ms step_avg:95.36ms +step:1530/1705 train_time:145892ms step_avg:95.35ms +step:1531/1705 train_time:145986ms step_avg:95.35ms +step:1532/1705 train_time:146081ms step_avg:95.35ms +step:1533/1705 train_time:146178ms step_avg:95.35ms +step:1534/1705 train_time:146275ms step_avg:95.35ms +step:1535/1705 train_time:146371ms step_avg:95.36ms +step:1536/1705 train_time:146467ms step_avg:95.36ms +step:1537/1705 train_time:146562ms step_avg:95.36ms +step:1538/1705 train_time:146657ms step_avg:95.36ms +step:1539/1705 train_time:146751ms step_avg:95.36ms +step:1540/1705 train_time:146847ms step_avg:95.36ms +step:1541/1705 train_time:146941ms step_avg:95.35ms +step:1542/1705 train_time:147036ms step_avg:95.35ms +step:1543/1705 train_time:147132ms step_avg:95.35ms +step:1544/1705 train_time:147228ms step_avg:95.36ms +step:1545/1705 train_time:147324ms step_avg:95.36ms +step:1546/1705 train_time:147419ms step_avg:95.36ms +step:1547/1705 train_time:147514ms step_avg:95.36ms +step:1548/1705 train_time:147610ms step_avg:95.36ms +step:1549/1705 train_time:147705ms step_avg:95.36ms +step:1550/1705 train_time:147800ms step_avg:95.36ms +step:1551/1705 train_time:147895ms step_avg:95.35ms +step:1552/1705 train_time:147990ms step_avg:95.35ms +step:1553/1705 train_time:148085ms step_avg:95.35ms +step:1554/1705 train_time:148181ms step_avg:95.35ms +step:1555/1705 train_time:148277ms step_avg:95.35ms +step:1556/1705 train_time:148372ms step_avg:95.36ms +step:1557/1705 train_time:148468ms step_avg:95.36ms +step:1558/1705 train_time:148563ms step_avg:95.36ms +step:1559/1705 train_time:148658ms step_avg:95.35ms +step:1560/1705 train_time:148754ms step_avg:95.35ms +step:1561/1705 train_time:148850ms step_avg:95.36ms +step:1562/1705 train_time:148944ms step_avg:95.35ms +step:1563/1705 train_time:149039ms step_avg:95.35ms +step:1564/1705 train_time:149135ms step_avg:95.35ms +step:1565/1705 train_time:149230ms step_avg:95.35ms +step:1566/1705 train_time:149327ms step_avg:95.36ms +step:1567/1705 train_time:149424ms step_avg:95.36ms +step:1568/1705 train_time:149518ms step_avg:95.36ms +step:1569/1705 train_time:149613ms step_avg:95.36ms +step:1570/1705 train_time:149709ms step_avg:95.36ms +step:1571/1705 train_time:149804ms step_avg:95.36ms +step:1572/1705 train_time:149899ms step_avg:95.36ms 
+step:1573/1705 train_time:149994ms step_avg:95.36ms +step:1574/1705 train_time:150089ms step_avg:95.36ms +step:1575/1705 train_time:150185ms step_avg:95.36ms +step:1576/1705 train_time:150280ms step_avg:95.36ms +step:1577/1705 train_time:150375ms step_avg:95.35ms +step:1578/1705 train_time:150471ms step_avg:95.36ms +step:1579/1705 train_time:150567ms step_avg:95.36ms +step:1580/1705 train_time:150662ms step_avg:95.36ms +step:1581/1705 train_time:150757ms step_avg:95.36ms +step:1582/1705 train_time:150852ms step_avg:95.36ms +step:1583/1705 train_time:150949ms step_avg:95.36ms +step:1584/1705 train_time:151045ms step_avg:95.36ms +step:1585/1705 train_time:151140ms step_avg:95.36ms +step:1586/1705 train_time:151234ms step_avg:95.36ms +step:1587/1705 train_time:151330ms step_avg:95.36ms +step:1588/1705 train_time:151426ms step_avg:95.36ms +step:1589/1705 train_time:151521ms step_avg:95.36ms +step:1590/1705 train_time:151616ms step_avg:95.36ms +step:1591/1705 train_time:151712ms step_avg:95.36ms +step:1592/1705 train_time:151809ms step_avg:95.36ms +step:1593/1705 train_time:151905ms step_avg:95.36ms +step:1594/1705 train_time:152001ms step_avg:95.36ms +step:1595/1705 train_time:152095ms step_avg:95.36ms +step:1596/1705 train_time:152190ms step_avg:95.36ms +step:1597/1705 train_time:152286ms step_avg:95.36ms +step:1598/1705 train_time:152381ms step_avg:95.36ms +step:1599/1705 train_time:152475ms step_avg:95.36ms +step:1600/1705 train_time:152571ms step_avg:95.36ms +step:1601/1705 train_time:152666ms step_avg:95.36ms +step:1602/1705 train_time:152762ms step_avg:95.36ms +step:1603/1705 train_time:152858ms step_avg:95.36ms +step:1604/1705 train_time:152954ms step_avg:95.36ms +step:1605/1705 train_time:153050ms step_avg:95.36ms +step:1606/1705 train_time:153146ms step_avg:95.36ms +step:1607/1705 train_time:153242ms step_avg:95.36ms +step:1608/1705 train_time:153337ms step_avg:95.36ms +step:1609/1705 train_time:153432ms step_avg:95.36ms +step:1610/1705 train_time:153527ms step_avg:95.36ms +step:1611/1705 train_time:153623ms step_avg:95.36ms +step:1612/1705 train_time:153717ms step_avg:95.36ms +step:1613/1705 train_time:153813ms step_avg:95.36ms +step:1614/1705 train_time:153909ms step_avg:95.36ms +step:1615/1705 train_time:154004ms step_avg:95.36ms +step:1616/1705 train_time:154099ms step_avg:95.36ms +step:1617/1705 train_time:154194ms step_avg:95.36ms +step:1618/1705 train_time:154289ms step_avg:95.36ms +step:1619/1705 train_time:154384ms step_avg:95.36ms +step:1620/1705 train_time:154479ms step_avg:95.36ms +step:1621/1705 train_time:154574ms step_avg:95.36ms +step:1622/1705 train_time:154670ms step_avg:95.36ms +step:1623/1705 train_time:154766ms step_avg:95.36ms +step:1624/1705 train_time:154863ms step_avg:95.36ms +step:1625/1705 train_time:154959ms step_avg:95.36ms +step:1625/1705 val_loss:3.2920 train_time:155055ms step_avg:95.42ms +step:1626/1705 train_time:155078ms step_avg:95.37ms +step:1627/1705 train_time:155158ms step_avg:95.36ms +step:1628/1705 train_time:155257ms step_avg:95.37ms +step:1629/1705 train_time:155351ms step_avg:95.37ms +step:1630/1705 train_time:155446ms step_avg:95.37ms +step:1631/1705 train_time:155540ms step_avg:95.36ms +step:1632/1705 train_time:155634ms step_avg:95.36ms +step:1633/1705 train_time:155728ms step_avg:95.36ms +step:1634/1705 train_time:155823ms step_avg:95.36ms +step:1635/1705 train_time:155917ms step_avg:95.36ms +step:1636/1705 train_time:156012ms step_avg:95.36ms +step:1637/1705 train_time:156111ms step_avg:95.36ms +step:1638/1705 train_time:156210ms 
step_avg:95.37ms +step:1639/1705 train_time:156307ms step_avg:95.37ms +step:1640/1705 train_time:156402ms step_avg:95.37ms +step:1641/1705 train_time:156497ms step_avg:95.37ms +step:1642/1705 train_time:156592ms step_avg:95.37ms +step:1643/1705 train_time:156686ms step_avg:95.37ms +step:1644/1705 train_time:156781ms step_avg:95.37ms +step:1645/1705 train_time:156875ms step_avg:95.36ms +step:1646/1705 train_time:156970ms step_avg:95.36ms +step:1647/1705 train_time:157067ms step_avg:95.37ms +step:1648/1705 train_time:157165ms step_avg:95.37ms +step:1649/1705 train_time:157262ms step_avg:95.37ms +step:1650/1705 train_time:157357ms step_avg:95.37ms +step:1651/1705 train_time:157452ms step_avg:95.37ms +step:1652/1705 train_time:157546ms step_avg:95.37ms +step:1653/1705 train_time:157640ms step_avg:95.37ms +step:1654/1705 train_time:157734ms step_avg:95.37ms +step:1655/1705 train_time:157829ms step_avg:95.37ms +step:1656/1705 train_time:157925ms step_avg:95.37ms +step:1657/1705 train_time:158021ms step_avg:95.37ms +step:1658/1705 train_time:158118ms step_avg:95.37ms +step:1659/1705 train_time:158213ms step_avg:95.37ms +step:1660/1705 train_time:158309ms step_avg:95.37ms +step:1661/1705 train_time:158406ms step_avg:95.37ms +step:1662/1705 train_time:158502ms step_avg:95.37ms +step:1663/1705 train_time:158598ms step_avg:95.37ms +step:1664/1705 train_time:158692ms step_avg:95.37ms +step:1665/1705 train_time:158786ms step_avg:95.37ms +step:1666/1705 train_time:158881ms step_avg:95.37ms +step:1667/1705 train_time:158976ms step_avg:95.37ms +step:1668/1705 train_time:159072ms step_avg:95.37ms +step:1669/1705 train_time:159168ms step_avg:95.37ms +step:1670/1705 train_time:159264ms step_avg:95.37ms +step:1671/1705 train_time:159360ms step_avg:95.37ms +step:1672/1705 train_time:159455ms step_avg:95.37ms +step:1673/1705 train_time:159551ms step_avg:95.37ms +step:1674/1705 train_time:159646ms step_avg:95.37ms +step:1675/1705 train_time:159741ms step_avg:95.37ms +step:1676/1705 train_time:159836ms step_avg:95.37ms +step:1677/1705 train_time:159930ms step_avg:95.37ms +step:1678/1705 train_time:160026ms step_avg:95.37ms +step:1679/1705 train_time:160122ms step_avg:95.37ms +step:1680/1705 train_time:160217ms step_avg:95.37ms +step:1681/1705 train_time:160312ms step_avg:95.37ms +step:1682/1705 train_time:160408ms step_avg:95.37ms +step:1683/1705 train_time:160504ms step_avg:95.37ms +step:1684/1705 train_time:160599ms step_avg:95.37ms +step:1685/1705 train_time:160693ms step_avg:95.37ms +step:1686/1705 train_time:160788ms step_avg:95.37ms +step:1687/1705 train_time:160883ms step_avg:95.37ms +step:1688/1705 train_time:160979ms step_avg:95.37ms +step:1689/1705 train_time:161074ms step_avg:95.37ms +step:1690/1705 train_time:161170ms step_avg:95.37ms +step:1691/1705 train_time:161266ms step_avg:95.37ms +step:1692/1705 train_time:161361ms step_avg:95.37ms +step:1693/1705 train_time:161457ms step_avg:95.37ms +step:1694/1705 train_time:161552ms step_avg:95.37ms +step:1695/1705 train_time:161648ms step_avg:95.37ms +step:1696/1705 train_time:161744ms step_avg:95.37ms +step:1697/1705 train_time:161839ms step_avg:95.37ms +step:1698/1705 train_time:162087ms step_avg:95.46ms +step:1699/1705 train_time:162268ms step_avg:95.51ms +step:1700/1705 train_time:162361ms step_avg:95.51ms +step:1701/1705 train_time:162455ms step_avg:95.51ms +step:1702/1705 train_time:162549ms step_avg:95.50ms +step:1703/1705 train_time:162643ms step_avg:95.50ms +step:1704/1705 train_time:162738ms step_avg:95.50ms +step:1705/1705 train_time:162832ms 
step_avg:95.50ms +step:1705/1705 val_loss:3.2782 train_time:162926ms step_avg:95.56ms +peak memory allocated: 33992 MiB reserved: 49496 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt b/records/050925_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt new file mode 100644 index 000000000..3762cf210 --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 
= ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, 
BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add 
= tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
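+
+ For reference, the quintic Newton-Schulz iteration implemented in newton_schulz_triton above is
+ X <- a*X + b*(X @ X.T) @ X + c*(X @ X.T)^2 @ X with (a, b, c) = (3.4445, -4.7750, 2.0315),
+ applied 5 times to the momentum-smoothed update after normalizing it so its spectral norm is at most 1.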
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = 
num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
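+ # e.g. next_multiple_of_n(50257, n=128) == 50304, since 392*128 = 50176 < 50257 <= 393*128 = 50304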
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration.
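+ # Note: the partially assembled batch is discarded here, along with any tokens after the last
+ # usable BOS in the old shard; a fresh BOSFinder is then built for the new shard.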
+ tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc.
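+# for illustration: the global batch is fixed at train_batch_size = 2048*24*8 = 393216 tokens per step,
+# so each rank sees 393216/8 = 49152 tokens per forward/backward regardless of world_size
+# (world_size=8 -> 1 micro-batch per step, world_size=4 -> 2, world_size=1 -> 8)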
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
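+# for reference, the schedules above work out to: get_lr holds 1.0 for the first 55% of steps, then
+# decays linearly toward 0.1, e.g. at step 1503 (x = 0.9): w = 0.1/0.45 ~ 0.222, lr ~ 0.222 + 0.778*0.1 = 0.30;
+# get_ws walks through (3, 7, 11) in equal thirds, i.e. long attention windows of 3*128=384, 7*128=896, 11*128=1408 tokens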
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:39:10 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| 
GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 128W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 45C P0 128W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 44C P0 132W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 94903 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 94904 C /usr/bin/python3 610MiB | +| 0 N/A N/A 94905 C /usr/bin/python3 610MiB | +| 0 N/A N/A 94906 C /usr/bin/python3 610MiB | +| 0 N/A N/A 94907 C /usr/bin/python3 610MiB | +| 0 N/A N/A 94908 C /usr/bin/python3 610MiB | +| 0 N/A N/A 94909 C /usr/bin/python3 610MiB | +| 0 N/A N/A 94910 C /usr/bin/python3 610MiB | +| 1 N/A N/A 94904 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 94905 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 94906 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 94907 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 94908 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 94909 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 94910 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:358ms step_avg:358.46ms +step:2/1670 train_time:378ms step_avg:189.16ms +step:3/1670 train_time:452ms step_avg:150.71ms +step:4/1670 train_time:546ms step_avg:136.40ms +step:5/1670 
train_time:640ms step_avg:128.05ms
+step:6/1670 train_time:735ms step_avg:122.48ms
+step:7/1670 train_time:829ms step_avg:118.46ms
+step:8/1670 train_time:924ms step_avg:115.53ms
+step:9/1670 train_time:1020ms step_avg:113.32ms
+step:10/1670 train_time:1115ms step_avg:111.52ms
+step:11/1670 train_time:1210ms step_avg:109.99ms
+step:12/1670 train_time:1308ms step_avg:109.01ms
+step:13/1670 train_time:1407ms step_avg:108.21ms
+step:14/1670 train_time:1504ms step_avg:107.45ms
+step:15/1670 train_time:1601ms step_avg:106.71ms
+step:16/1670 train_time:1697ms step_avg:106.05ms
+step:17/1670 train_time:1792ms step_avg:105.42ms
+step:18/1670 train_time:1887ms step_avg:104.84ms
+step:19/1670 train_time:1982ms step_avg:104.34ms
+step:20/1670 train_time:2078ms step_avg:103.89ms
+step:21/1670 train_time:2174ms step_avg:103.52ms
+step:22/1670 train_time:2270ms step_avg:103.20ms
+step:23/1670 train_time:2366ms step_avg:102.88ms
+step:24/1670 train_time:2463ms step_avg:102.62ms
+step:25/1670 train_time:2560ms step_avg:102.40ms
+step:26/1670 train_time:2656ms step_avg:102.16ms
+step:27/1670 train_time:2752ms step_avg:101.93ms
+step:28/1670 train_time:2847ms step_avg:101.69ms
+step:29/1670 train_time:2943ms step_avg:101.47ms
+step:30/1670 train_time:3038ms step_avg:101.27ms
+step:31/1670 train_time:3133ms step_avg:101.07ms
+step:32/1670 train_time:3229ms step_avg:100.92ms
+step:33/1670 train_time:3326ms step_avg:100.78ms
+step:34/1670 train_time:3422ms step_avg:100.65ms
+step:35/1670 train_time:3519ms step_avg:100.55ms
+step:36/1670 train_time:3615ms step_avg:100.42ms
+step:37/1670 train_time:3712ms step_avg:100.31ms
+step:38/1670 train_time:3807ms step_avg:100.18ms
+step:39/1670 train_time:3903ms step_avg:100.07ms
+step:40/1670 train_time:3999ms step_avg:99.97ms
+step:41/1670 train_time:4095ms step_avg:99.87ms
+step:42/1670 train_time:4191ms step_avg:99.77ms
+step:43/1670 train_time:4286ms step_avg:99.68ms
+step:44/1670 train_time:4382ms step_avg:99.58ms
+step:45/1670 train_time:4479ms step_avg:99.54ms
+step:46/1670 train_time:4575ms step_avg:99.46ms
+step:47/1670 train_time:4673ms step_avg:99.43ms
+step:48/1670 train_time:4768ms step_avg:99.33ms
+step:49/1670 train_time:4863ms step_avg:99.24ms
+step:50/1670 train_time:4959ms step_avg:99.19ms
+step:51/1670 train_time:5055ms step_avg:99.12ms
+step:52/1670 train_time:5151ms step_avg:99.05ms
+step:53/1670 train_time:5246ms step_avg:98.99ms
+step:54/1670 train_time:5342ms step_avg:98.93ms
+step:55/1670 train_time:5438ms step_avg:98.87ms
+step:56/1670 train_time:5534ms step_avg:98.83ms
+step:57/1670 train_time:5630ms step_avg:98.77ms
+step:58/1670 train_time:5726ms step_avg:98.72ms
+step:59/1670 train_time:5821ms step_avg:98.66ms
+step:60/1670 train_time:5917ms step_avg:98.62ms
+step:61/1670 train_time:6014ms step_avg:98.59ms
+step:62/1670 train_time:6109ms step_avg:98.54ms
+step:63/1670 train_time:6205ms step_avg:98.49ms
+step:64/1670 train_time:6301ms step_avg:98.45ms
+step:65/1670 train_time:6396ms step_avg:98.41ms
+step:66/1670 train_time:6493ms step_avg:98.37ms
+step:67/1670 train_time:6588ms step_avg:98.33ms
+step:68/1670 train_time:6684ms step_avg:98.29ms
+step:69/1670 train_time:6780ms step_avg:98.27ms
+step:70/1670 train_time:6876ms step_avg:98.23ms
+step:71/1670 train_time:6972ms step_avg:98.19ms
+step:72/1670 train_time:7067ms step_avg:98.15ms
+step:73/1670 train_time:7163ms step_avg:98.12ms
+step:74/1670 train_time:7259ms step_avg:98.09ms
+step:75/1670 train_time:7354ms step_avg:98.06ms
+step:76/1670 train_time:7450ms step_avg:98.02ms
+step:77/1670 train_time:7545ms step_avg:97.99ms
+step:78/1670 train_time:7641ms step_avg:97.97ms
+step:79/1670 train_time:7737ms step_avg:97.94ms
+step:80/1670 train_time:7834ms step_avg:97.92ms
+step:81/1670 train_time:7930ms step_avg:97.90ms
+step:82/1670 train_time:8026ms step_avg:97.88ms
+step:83/1670 train_time:8122ms step_avg:97.86ms
+step:84/1670 train_time:8219ms step_avg:97.85ms
+step:85/1670 train_time:8314ms step_avg:97.82ms
+step:86/1670 train_time:8410ms step_avg:97.79ms
+step:87/1670 train_time:8505ms step_avg:97.76ms
+step:88/1670 train_time:8601ms step_avg:97.74ms
+step:89/1670 train_time:8696ms step_avg:97.71ms
+step:90/1670 train_time:8793ms step_avg:97.70ms
+step:91/1670 train_time:8888ms step_avg:97.67ms
+step:92/1670 train_time:8984ms step_avg:97.66ms
+step:93/1670 train_time:9081ms step_avg:97.64ms
+step:94/1670 train_time:9177ms step_avg:97.62ms
+step:95/1670 train_time:9272ms step_avg:97.60ms
+step:96/1670 train_time:9367ms step_avg:97.57ms
+step:97/1670 train_time:9463ms step_avg:97.56ms
+step:98/1670 train_time:9560ms step_avg:97.55ms
+step:99/1670 train_time:9656ms step_avg:97.53ms
+step:100/1670 train_time:9751ms step_avg:97.51ms
+step:101/1670 train_time:9847ms step_avg:97.49ms
+step:102/1670 train_time:9943ms step_avg:97.48ms
+step:103/1670 train_time:10040ms step_avg:97.47ms
+step:104/1670 train_time:10136ms step_avg:97.46ms
+step:105/1670 train_time:10232ms step_avg:97.45ms
+step:106/1670 train_time:10328ms step_avg:97.43ms
+step:107/1670 train_time:10423ms step_avg:97.41ms
+step:108/1670 train_time:10519ms step_avg:97.40ms
+step:109/1670 train_time:10615ms step_avg:97.39ms
+step:110/1670 train_time:10711ms step_avg:97.37ms
+step:111/1670 train_time:10806ms step_avg:97.35ms
+step:112/1670 train_time:10902ms step_avg:97.34ms
+step:113/1670 train_time:10998ms step_avg:97.33ms
+step:114/1670 train_time:11094ms step_avg:97.32ms
+step:115/1670 train_time:11191ms step_avg:97.31ms
+step:116/1670 train_time:11286ms step_avg:97.29ms
+step:117/1670 train_time:11381ms step_avg:97.28ms
+step:118/1670 train_time:11477ms step_avg:97.26ms
+step:119/1670 train_time:11574ms step_avg:97.26ms
+step:120/1670 train_time:11670ms step_avg:97.25ms
+step:121/1670 train_time:11766ms step_avg:97.24ms
+step:122/1670 train_time:11861ms step_avg:97.22ms
+step:123/1670 train_time:11957ms step_avg:97.21ms
+step:124/1670 train_time:12053ms step_avg:97.20ms
+step:125/1670 train_time:12150ms step_avg:97.20ms
+step:125/1670 val_loss:4.3007 train_time:12245ms step_avg:97.96ms
+step:126/1670 train_time:12266ms step_avg:97.35ms
+step:127/1670 train_time:12342ms step_avg:97.18ms
+step:128/1670 train_time:12441ms step_avg:97.20ms
+step:129/1670 train_time:12549ms step_avg:97.28ms
+step:130/1670 train_time:12646ms step_avg:97.28ms
+step:131/1670 train_time:12740ms step_avg:97.25ms
+step:132/1670 train_time:12834ms step_avg:97.23ms
+step:133/1670 train_time:12930ms step_avg:97.21ms
+step:134/1670 train_time:13024ms step_avg:97.19ms
+step:135/1670 train_time:13118ms step_avg:97.17ms
+step:136/1670 train_time:13215ms step_avg:97.17ms
+step:137/1670 train_time:13312ms step_avg:97.17ms
+step:138/1670 train_time:13408ms step_avg:97.16ms
+step:139/1670 train_time:13505ms step_avg:97.16ms
+step:140/1670 train_time:13602ms step_avg:97.16ms
+step:141/1670 train_time:13697ms step_avg:97.14ms
+step:142/1670 train_time:13793ms step_avg:97.14ms
+step:143/1670 train_time:13887ms step_avg:97.11ms
+step:144/1670 train_time:13982ms step_avg:97.10ms
+step:145/1670 train_time:14077ms step_avg:97.08ms
+step:146/1670 train_time:14172ms step_avg:97.07ms
+step:147/1670 train_time:14268ms step_avg:97.06ms
+step:148/1670 train_time:14364ms step_avg:97.05ms
+step:149/1670 train_time:14460ms step_avg:97.05ms
+step:150/1670 train_time:14556ms step_avg:97.04ms
+step:151/1670 train_time:14652ms step_avg:97.04ms
+step:152/1670 train_time:14749ms step_avg:97.03ms
+step:153/1670 train_time:14844ms step_avg:97.02ms
+step:154/1670 train_time:14939ms step_avg:97.01ms
+step:155/1670 train_time:15034ms step_avg:96.99ms
+step:156/1670 train_time:15130ms step_avg:96.99ms
+step:157/1670 train_time:15225ms step_avg:96.98ms
+step:158/1670 train_time:15321ms step_avg:96.97ms
+step:159/1670 train_time:15416ms step_avg:96.96ms
+step:160/1670 train_time:15512ms step_avg:96.95ms
+step:161/1670 train_time:15608ms step_avg:96.94ms
+step:162/1670 train_time:15704ms step_avg:96.94ms
+step:163/1670 train_time:15799ms step_avg:96.93ms
+step:164/1670 train_time:15894ms step_avg:96.91ms
+step:165/1670 train_time:15990ms step_avg:96.91ms
+step:166/1670 train_time:16085ms step_avg:96.90ms
+step:167/1670 train_time:16179ms step_avg:96.88ms
+step:168/1670 train_time:16275ms step_avg:96.88ms
+step:169/1670 train_time:16372ms step_avg:96.88ms
+step:170/1670 train_time:16468ms step_avg:96.87ms
+step:171/1670 train_time:16563ms step_avg:96.86ms
+step:172/1670 train_time:16659ms step_avg:96.85ms
+step:173/1670 train_time:16755ms step_avg:96.85ms
+step:174/1670 train_time:16850ms step_avg:96.84ms
+step:175/1670 train_time:16946ms step_avg:96.83ms
+step:176/1670 train_time:17040ms step_avg:96.82ms
+step:177/1670 train_time:17136ms step_avg:96.81ms
+step:178/1670 train_time:17231ms step_avg:96.80ms
+step:179/1670 train_time:17326ms step_avg:96.80ms
+step:180/1670 train_time:17421ms step_avg:96.78ms
+step:181/1670 train_time:17517ms step_avg:96.78ms
+step:182/1670 train_time:17613ms step_avg:96.78ms
+step:183/1670 train_time:17709ms step_avg:96.77ms
+step:184/1670 train_time:17805ms step_avg:96.77ms
+step:185/1670 train_time:17900ms step_avg:96.75ms
+step:186/1670 train_time:17995ms step_avg:96.75ms
+step:187/1670 train_time:18091ms step_avg:96.74ms
+step:188/1670 train_time:18187ms step_avg:96.74ms
+step:189/1670 train_time:18282ms step_avg:96.73ms
+step:190/1670 train_time:18377ms step_avg:96.72ms
+step:191/1670 train_time:18472ms step_avg:96.71ms
+step:192/1670 train_time:18569ms step_avg:96.71ms
+step:193/1670 train_time:18664ms step_avg:96.71ms
+step:194/1670 train_time:18760ms step_avg:96.70ms
+step:195/1670 train_time:18854ms step_avg:96.69ms
+step:196/1670 train_time:18951ms step_avg:96.69ms
+step:197/1670 train_time:19047ms step_avg:96.68ms
+step:198/1670 train_time:19142ms step_avg:96.68ms
+step:199/1670 train_time:19237ms step_avg:96.67ms
+step:200/1670 train_time:19333ms step_avg:96.66ms
+step:201/1670 train_time:19429ms step_avg:96.66ms
+step:202/1670 train_time:19524ms step_avg:96.65ms
+step:203/1670 train_time:19619ms step_avg:96.65ms
+step:204/1670 train_time:19716ms step_avg:96.65ms
+step:205/1670 train_time:19811ms step_avg:96.64ms
+step:206/1670 train_time:19907ms step_avg:96.64ms
+step:207/1670 train_time:20003ms step_avg:96.63ms
+step:208/1670 train_time:20098ms step_avg:96.62ms
+step:209/1670 train_time:20193ms step_avg:96.62ms
+step:210/1670 train_time:20289ms step_avg:96.62ms
+step:211/1670 train_time:20384ms step_avg:96.61ms
+step:212/1670 train_time:20479ms step_avg:96.60ms
+step:213/1670 train_time:20775ms step_avg:97.54ms
+step:214/1670 train_time:20905ms step_avg:97.69ms
+step:215/1670 train_time:20999ms step_avg:97.67ms
+step:216/1670 train_time:21093ms step_avg:97.65ms
+step:217/1670 train_time:21188ms step_avg:97.64ms
+step:218/1670 train_time:21282ms step_avg:97.62ms
+step:219/1670 train_time:21377ms step_avg:97.61ms
+step:220/1670 train_time:21472ms step_avg:97.60ms
+step:221/1670 train_time:21566ms step_avg:97.58ms
+step:222/1670 train_time:21660ms step_avg:97.57ms
+step:223/1670 train_time:21759ms step_avg:97.57ms
+step:224/1670 train_time:21856ms step_avg:97.57ms
+step:225/1670 train_time:21954ms step_avg:97.57ms
+step:226/1670 train_time:22051ms step_avg:97.57ms
+step:227/1670 train_time:22147ms step_avg:97.56ms
+step:228/1670 train_time:22242ms step_avg:97.55ms
+step:229/1670 train_time:22337ms step_avg:97.54ms
+step:230/1670 train_time:22432ms step_avg:97.53ms
+step:231/1670 train_time:22527ms step_avg:97.52ms
+step:232/1670 train_time:22622ms step_avg:97.51ms
+step:233/1670 train_time:22717ms step_avg:97.50ms
+step:234/1670 train_time:22813ms step_avg:97.49ms
+step:235/1670 train_time:22911ms step_avg:97.49ms
+step:236/1670 train_time:23008ms step_avg:97.49ms
+step:237/1670 train_time:23104ms step_avg:97.48ms
+step:238/1670 train_time:23199ms step_avg:97.47ms
+step:239/1670 train_time:23294ms step_avg:97.47ms
+step:240/1670 train_time:23390ms step_avg:97.46ms
+step:241/1670 train_time:23486ms step_avg:97.45ms
+step:242/1670 train_time:23580ms step_avg:97.44ms
+step:243/1670 train_time:23675ms step_avg:97.43ms
+step:244/1670 train_time:23771ms step_avg:97.42ms
+step:245/1670 train_time:23867ms step_avg:97.42ms
+step:246/1670 train_time:23964ms step_avg:97.41ms
+step:247/1670 train_time:24060ms step_avg:97.41ms
+step:248/1670 train_time:24155ms step_avg:97.40ms
+step:249/1670 train_time:24251ms step_avg:97.39ms
+step:250/1670 train_time:24346ms step_avg:97.39ms
+step:250/1670 val_loss:3.9722 train_time:24441ms step_avg:97.76ms
+step:251/1670 train_time:24464ms step_avg:97.46ms
+step:252/1670 train_time:24544ms step_avg:97.40ms
+step:253/1670 train_time:24643ms step_avg:97.40ms
+step:254/1670 train_time:24739ms step_avg:97.40ms
+step:255/1670 train_time:24834ms step_avg:97.39ms
+step:256/1670 train_time:24928ms step_avg:97.37ms
+step:257/1670 train_time:25023ms step_avg:97.36ms
+step:258/1670 train_time:25118ms step_avg:97.36ms
+step:259/1670 train_time:25213ms step_avg:97.35ms
+step:260/1670 train_time:25308ms step_avg:97.34ms
+step:261/1670 train_time:25403ms step_avg:97.33ms
+step:262/1670 train_time:25501ms step_avg:97.33ms
+step:263/1670 train_time:25599ms step_avg:97.34ms
+step:264/1670 train_time:25695ms step_avg:97.33ms
+step:265/1670 train_time:25791ms step_avg:97.32ms
+step:266/1670 train_time:25885ms step_avg:97.31ms
+step:267/1670 train_time:25981ms step_avg:97.31ms
+step:268/1670 train_time:26075ms step_avg:97.30ms
+step:269/1670 train_time:26170ms step_avg:97.29ms
+step:270/1670 train_time:26265ms step_avg:97.28ms
+step:271/1670 train_time:26361ms step_avg:97.27ms
+step:272/1670 train_time:26458ms step_avg:97.27ms
+step:273/1670 train_time:26555ms step_avg:97.27ms
+step:274/1670 train_time:26651ms step_avg:97.27ms
+step:275/1670 train_time:26747ms step_avg:97.26ms
+step:276/1670 train_time:26842ms step_avg:97.25ms
+step:277/1670 train_time:26938ms step_avg:97.25ms
+step:278/1670 train_time:27033ms step_avg:97.24ms
+step:279/1670 train_time:27128ms step_avg:97.23ms
+step:280/1670 train_time:27223ms step_avg:97.23ms
+step:281/1670 train_time:27319ms step_avg:97.22ms
+step:282/1670 train_time:27415ms step_avg:97.22ms
+step:283/1670 train_time:27511ms step_avg:97.21ms
+step:284/1670 train_time:27608ms step_avg:97.21ms
+step:285/1670 train_time:27704ms step_avg:97.21ms
+step:286/1670 train_time:27800ms step_avg:97.20ms
+step:287/1670 train_time:27895ms step_avg:97.20ms
+step:288/1670 train_time:27991ms step_avg:97.19ms
+step:289/1670 train_time:28086ms step_avg:97.18ms
+step:290/1670 train_time:28180ms step_avg:97.17ms
+step:291/1670 train_time:28276ms step_avg:97.17ms
+step:292/1670 train_time:28371ms step_avg:97.16ms
+step:293/1670 train_time:28466ms step_avg:97.15ms
+step:294/1670 train_time:28562ms step_avg:97.15ms
+step:295/1670 train_time:28659ms step_avg:97.15ms
+step:296/1670 train_time:28755ms step_avg:97.15ms
+step:297/1670 train_time:28851ms step_avg:97.14ms
+step:298/1670 train_time:28946ms step_avg:97.14ms
+step:299/1670 train_time:29042ms step_avg:97.13ms
+step:300/1670 train_time:29138ms step_avg:97.13ms
+step:301/1670 train_time:29233ms step_avg:97.12ms
+step:302/1670 train_time:29328ms step_avg:97.11ms
+step:303/1670 train_time:29423ms step_avg:97.11ms
+step:304/1670 train_time:29519ms step_avg:97.10ms
+step:305/1670 train_time:29614ms step_avg:97.10ms
+step:306/1670 train_time:29710ms step_avg:97.09ms
+step:307/1670 train_time:29806ms step_avg:97.09ms
+step:308/1670 train_time:29902ms step_avg:97.08ms
+step:309/1670 train_time:29998ms step_avg:97.08ms
+step:310/1670 train_time:30094ms step_avg:97.08ms
+step:311/1670 train_time:30189ms step_avg:97.07ms
+step:312/1670 train_time:30284ms step_avg:97.07ms
+step:313/1670 train_time:30380ms step_avg:97.06ms
+step:314/1670 train_time:30476ms step_avg:97.06ms
+step:315/1670 train_time:30572ms step_avg:97.05ms
+step:316/1670 train_time:30667ms step_avg:97.05ms
+step:317/1670 train_time:30763ms step_avg:97.04ms
+step:318/1670 train_time:30859ms step_avg:97.04ms
+step:319/1670 train_time:30954ms step_avg:97.04ms
+step:320/1670 train_time:31050ms step_avg:97.03ms
+step:321/1670 train_time:31146ms step_avg:97.03ms
+step:322/1670 train_time:31241ms step_avg:97.02ms
+step:323/1670 train_time:31336ms step_avg:97.02ms
+step:324/1670 train_time:31432ms step_avg:97.01ms
+step:325/1670 train_time:31527ms step_avg:97.01ms
+step:326/1670 train_time:31624ms step_avg:97.00ms
+step:327/1670 train_time:31719ms step_avg:97.00ms
+step:328/1670 train_time:31815ms step_avg:97.00ms
+step:329/1670 train_time:31910ms step_avg:96.99ms
+step:330/1670 train_time:32005ms step_avg:96.98ms
+step:331/1670 train_time:32101ms step_avg:96.98ms
+step:332/1670 train_time:32196ms step_avg:96.98ms
+step:333/1670 train_time:32292ms step_avg:96.97ms
+step:334/1670 train_time:32387ms step_avg:96.97ms
+step:335/1670 train_time:32482ms step_avg:96.96ms
+step:336/1670 train_time:32578ms step_avg:96.96ms
+step:337/1670 train_time:32674ms step_avg:96.96ms
+step:338/1670 train_time:32770ms step_avg:96.95ms
+step:339/1670 train_time:32865ms step_avg:96.95ms
+step:340/1670 train_time:32961ms step_avg:96.94ms
+step:341/1670 train_time:33057ms step_avg:96.94ms
+step:342/1670 train_time:33152ms step_avg:96.94ms
+step:343/1670 train_time:33248ms step_avg:96.93ms
+step:344/1670 train_time:33343ms step_avg:96.93ms
+step:345/1670 train_time:33439ms step_avg:96.92ms
+step:346/1670 train_time:33534ms step_avg:96.92ms
+step:347/1670 train_time:33630ms step_avg:96.92ms
+step:348/1670 train_time:33725ms step_avg:96.91ms
+step:349/1670 train_time:33821ms step_avg:96.91ms
+step:350/1670 train_time:33918ms step_avg:96.91ms
+step:351/1670 train_time:34014ms step_avg:96.91ms
+step:352/1670 train_time:34110ms step_avg:96.90ms
+step:353/1670 train_time:34206ms step_avg:96.90ms
+step:354/1670 train_time:34302ms step_avg:96.90ms
+step:355/1670 train_time:34398ms step_avg:96.89ms
+step:356/1670 train_time:34493ms step_avg:96.89ms
+step:357/1670 train_time:34589ms step_avg:96.89ms
+step:358/1670 train_time:34684ms step_avg:96.88ms
+step:359/1670 train_time:34779ms step_avg:96.88ms
+step:360/1670 train_time:34875ms step_avg:96.88ms
+step:361/1670 train_time:34971ms step_avg:96.87ms
+step:362/1670 train_time:35066ms step_avg:96.87ms
+step:363/1670 train_time:35162ms step_avg:96.87ms
+step:364/1670 train_time:35258ms step_avg:96.86ms
+step:365/1670 train_time:35353ms step_avg:96.86ms
+step:366/1670 train_time:35449ms step_avg:96.85ms
+step:367/1670 train_time:35544ms step_avg:96.85ms
+step:368/1670 train_time:35640ms step_avg:96.85ms
+step:369/1670 train_time:35736ms step_avg:96.85ms
+step:370/1670 train_time:35831ms step_avg:96.84ms
+step:371/1670 train_time:35927ms step_avg:96.84ms
+step:372/1670 train_time:36022ms step_avg:96.83ms
+step:373/1670 train_time:36119ms step_avg:96.83ms
+step:374/1670 train_time:36214ms step_avg:96.83ms
+step:375/1670 train_time:36310ms step_avg:96.83ms
+step:375/1670 val_loss:3.8222 train_time:36405ms step_avg:97.08ms
+step:376/1670 train_time:36428ms step_avg:96.88ms
+step:377/1670 train_time:36510ms step_avg:96.84ms
+step:378/1670 train_time:36609ms step_avg:96.85ms
+step:379/1670 train_time:36704ms step_avg:96.85ms
+step:380/1670 train_time:36800ms step_avg:96.84ms
+step:381/1670 train_time:36894ms step_avg:96.84ms
+step:382/1670 train_time:36989ms step_avg:96.83ms
+step:383/1670 train_time:37084ms step_avg:96.82ms
+step:384/1670 train_time:37179ms step_avg:96.82ms
+step:385/1670 train_time:37273ms step_avg:96.81ms
+step:386/1670 train_time:37369ms step_avg:96.81ms
+step:387/1670 train_time:37467ms step_avg:96.81ms
+step:388/1670 train_time:37564ms step_avg:96.82ms
+step:389/1670 train_time:37660ms step_avg:96.81ms
+step:390/1670 train_time:37756ms step_avg:96.81ms
+step:391/1670 train_time:37851ms step_avg:96.81ms
+step:392/1670 train_time:37947ms step_avg:96.80ms
+step:393/1670 train_time:38041ms step_avg:96.80ms
+step:394/1670 train_time:38136ms step_avg:96.79ms
+step:395/1670 train_time:38231ms step_avg:96.79ms
+step:396/1670 train_time:38327ms step_avg:96.79ms
+step:397/1670 train_time:38423ms step_avg:96.78ms
+step:398/1670 train_time:38520ms step_avg:96.78ms
+step:399/1670 train_time:38616ms step_avg:96.78ms
+step:400/1670 train_time:38712ms step_avg:96.78ms
+step:401/1670 train_time:38808ms step_avg:96.78ms
+step:402/1670 train_time:38905ms step_avg:96.78ms
+step:403/1670 train_time:39001ms step_avg:96.78ms
+step:404/1670 train_time:39095ms step_avg:96.77ms
+step:405/1670 train_time:39190ms step_avg:96.77ms
+step:406/1670 train_time:39285ms step_avg:96.76ms
+step:407/1670 train_time:39381ms step_avg:96.76ms
+step:408/1670 train_time:39477ms step_avg:96.76ms
+step:409/1670 train_time:39573ms step_avg:96.75ms
+step:410/1670 train_time:39670ms step_avg:96.75ms
+step:411/1670 train_time:39766ms step_avg:96.75ms
+step:412/1670 train_time:39861ms step_avg:96.75ms
+step:413/1670 train_time:39956ms step_avg:96.75ms
+step:414/1670 train_time:40052ms step_avg:96.74ms
+step:415/1670 train_time:40147ms step_avg:96.74ms
+step:416/1670 train_time:40242ms step_avg:96.74ms
+step:417/1670 train_time:40338ms step_avg:96.73ms
+step:418/1670 train_time:40434ms step_avg:96.73ms
+step:419/1670 train_time:40530ms step_avg:96.73ms
+step:420/1670 train_time:40627ms step_avg:96.73ms
+step:421/1670 train_time:40723ms step_avg:96.73ms
+step:422/1670 train_time:40819ms step_avg:96.73ms
+step:423/1670 train_time:40914ms step_avg:96.72ms
+step:424/1670 train_time:41009ms step_avg:96.72ms
+step:425/1670 train_time:41307ms step_avg:97.19ms
+step:426/1670 train_time:41454ms step_avg:97.31ms
+step:427/1670 train_time:41548ms step_avg:97.30ms
+step:428/1670 train_time:41643ms step_avg:97.30ms
+step:429/1670 train_time:41737ms step_avg:97.29ms
+step:430/1670 train_time:41832ms step_avg:97.28ms
+step:431/1670 train_time:41927ms step_avg:97.28ms
+step:432/1670 train_time:42022ms step_avg:97.27ms
+step:433/1670 train_time:42117ms step_avg:97.27ms
+step:434/1670 train_time:42211ms step_avg:97.26ms
+step:435/1670 train_time:42307ms step_avg:97.26ms
+step:436/1670 train_time:42407ms step_avg:97.26ms
+step:437/1670 train_time:42506ms step_avg:97.27ms
+step:438/1670 train_time:42602ms step_avg:97.26ms
+step:439/1670 train_time:42697ms step_avg:97.26ms
+step:440/1670 train_time:42792ms step_avg:97.25ms
+step:441/1670 train_time:42888ms step_avg:97.25ms
+step:442/1670 train_time:42983ms step_avg:97.25ms
+step:443/1670 train_time:43078ms step_avg:97.24ms
+step:444/1670 train_time:43172ms step_avg:97.24ms
+step:445/1670 train_time:43268ms step_avg:97.23ms
+step:446/1670 train_time:43365ms step_avg:97.23ms
+step:447/1670 train_time:43463ms step_avg:97.23ms
+step:448/1670 train_time:43559ms step_avg:97.23ms
+step:449/1670 train_time:43655ms step_avg:97.23ms
+step:450/1670 train_time:43750ms step_avg:97.22ms
+step:451/1670 train_time:43847ms step_avg:97.22ms
+step:452/1670 train_time:43942ms step_avg:97.22ms
+step:453/1670 train_time:44037ms step_avg:97.21ms
+step:454/1670 train_time:44132ms step_avg:97.21ms
+step:455/1670 train_time:44228ms step_avg:97.20ms
+step:456/1670 train_time:44323ms step_avg:97.20ms
+step:457/1670 train_time:44419ms step_avg:97.20ms
+step:458/1670 train_time:44516ms step_avg:97.20ms
+step:459/1670 train_time:44612ms step_avg:97.19ms
+step:460/1670 train_time:44709ms step_avg:97.19ms
+step:461/1670 train_time:44805ms step_avg:97.19ms
+step:462/1670 train_time:44900ms step_avg:97.19ms
+step:463/1670 train_time:44996ms step_avg:97.18ms
+step:464/1670 train_time:45090ms step_avg:97.18ms
+step:465/1670 train_time:45186ms step_avg:97.17ms
+step:466/1670 train_time:45281ms step_avg:97.17ms
+step:467/1670 train_time:45376ms step_avg:97.17ms
+step:468/1670 train_time:45473ms step_avg:97.16ms
+step:469/1670 train_time:45569ms step_avg:97.16ms
+step:470/1670 train_time:45665ms step_avg:97.16ms
+step:471/1670 train_time:45761ms step_avg:97.16ms
+step:472/1670 train_time:45856ms step_avg:97.15ms
+step:473/1670 train_time:45952ms step_avg:97.15ms
+step:474/1670 train_time:46048ms step_avg:97.15ms
+step:475/1670 train_time:46143ms step_avg:97.14ms
+step:476/1670 train_time:46239ms step_avg:97.14ms
+step:477/1670 train_time:46334ms step_avg:97.14ms
+step:478/1670 train_time:46430ms step_avg:97.13ms
+step:479/1670 train_time:46526ms step_avg:97.13ms
+step:480/1670 train_time:46622ms step_avg:97.13ms
+step:481/1670 train_time:46717ms step_avg:97.13ms
+step:482/1670 train_time:46814ms step_avg:97.12ms
+step:483/1670 train_time:46910ms step_avg:97.12ms
+step:484/1670 train_time:47005ms step_avg:97.12ms
+step:485/1670 train_time:47102ms step_avg:97.12ms
+step:486/1670 train_time:47196ms step_avg:97.11ms
+step:487/1670 train_time:47293ms step_avg:97.11ms
+step:488/1670 train_time:47388ms step_avg:97.11ms
+step:489/1670 train_time:47484ms step_avg:97.11ms
+step:490/1670 train_time:47580ms step_avg:97.10ms
+step:491/1670 train_time:47676ms step_avg:97.10ms
+step:492/1670 train_time:47771ms step_avg:97.10ms
+step:493/1670 train_time:47868ms step_avg:97.10ms
+step:494/1670 train_time:47964ms step_avg:97.09ms
+step:495/1670 train_time:48060ms step_avg:97.09ms
+step:496/1670 train_time:48155ms step_avg:97.09ms
+step:497/1670 train_time:48250ms step_avg:97.08ms
+step:498/1670 train_time:48347ms step_avg:97.08ms
+step:499/1670 train_time:48443ms step_avg:97.08ms
+step:500/1670 train_time:48540ms step_avg:97.08ms
+step:500/1670 val_loss:3.7170 train_time:48634ms step_avg:97.27ms
+step:501/1670 train_time:48655ms step_avg:97.12ms
+step:502/1670 train_time:48738ms step_avg:97.09ms
+step:503/1670 train_time:48838ms step_avg:97.09ms
+step:504/1670 train_time:48935ms step_avg:97.09ms
+step:505/1670 train_time:49030ms step_avg:97.09ms
+step:506/1670 train_time:49125ms step_avg:97.09ms
+step:507/1670 train_time:49220ms step_avg:97.08ms
+step:508/1670 train_time:49315ms step_avg:97.08ms
+step:509/1670 train_time:49410ms step_avg:97.07ms
+step:510/1670 train_time:49504ms step_avg:97.07ms
+step:511/1670 train_time:49600ms step_avg:97.06ms
+step:512/1670 train_time:49697ms step_avg:97.06ms
+step:513/1670 train_time:49795ms step_avg:97.07ms
+step:514/1670 train_time:49893ms step_avg:97.07ms
+step:515/1670 train_time:49989ms step_avg:97.07ms
+step:516/1670 train_time:50085ms step_avg:97.06ms
+step:517/1670 train_time:50180ms step_avg:97.06ms
+step:518/1670 train_time:50275ms step_avg:97.06ms
+step:519/1670 train_time:50370ms step_avg:97.05ms
+step:520/1670 train_time:50465ms step_avg:97.05ms
+step:521/1670 train_time:50560ms step_avg:97.05ms
+step:522/1670 train_time:50657ms step_avg:97.04ms
+step:523/1670 train_time:50754ms step_avg:97.04ms
+step:524/1670 train_time:50852ms step_avg:97.05ms
+step:525/1670 train_time:50948ms step_avg:97.04ms
+step:526/1670 train_time:51044ms step_avg:97.04ms
+step:527/1670 train_time:51139ms step_avg:97.04ms
+step:528/1670 train_time:51234ms step_avg:97.03ms
+step:529/1670 train_time:51330ms step_avg:97.03ms
+step:530/1670 train_time:51426ms step_avg:97.03ms
+step:531/1670 train_time:51521ms step_avg:97.03ms
+step:532/1670 train_time:51616ms step_avg:97.02ms
+step:533/1670 train_time:51713ms step_avg:97.02ms
+step:534/1670 train_time:51810ms step_avg:97.02ms
+step:535/1670 train_time:51906ms step_avg:97.02ms
+step:536/1670 train_time:52002ms step_avg:97.02ms
+step:537/1670 train_time:52097ms step_avg:97.02ms
+step:538/1670 train_time:52193ms step_avg:97.01ms
+step:539/1670 train_time:52289ms step_avg:97.01ms
+step:540/1670 train_time:52385ms step_avg:97.01ms
+step:541/1670 train_time:52480ms step_avg:97.00ms
+step:542/1670 train_time:52575ms step_avg:97.00ms
+step:543/1670 train_time:52671ms step_avg:97.00ms
+step:544/1670 train_time:52767ms step_avg:97.00ms
+step:545/1670 train_time:52863ms step_avg:97.00ms
+step:546/1670 train_time:52959ms step_avg:96.99ms
+step:547/1670 train_time:53055ms step_avg:96.99ms
+step:548/1670 train_time:53151ms step_avg:96.99ms
+step:549/1670 train_time:53247ms step_avg:96.99ms
+step:550/1670 train_time:53342ms step_avg:96.98ms
+step:551/1670 train_time:53437ms step_avg:96.98ms
+step:552/1670 train_time:53533ms step_avg:96.98ms
+step:553/1670 train_time:53628ms step_avg:96.98ms
+step:554/1670 train_time:53724ms step_avg:96.98ms
+step:555/1670 train_time:53820ms step_avg:96.97ms
+step:556/1670 train_time:53916ms step_avg:96.97ms
+step:557/1670 train_time:54013ms step_avg:96.97ms
+step:558/1670 train_time:54109ms step_avg:96.97ms
+step:559/1670 train_time:54206ms step_avg:96.97ms
+step:560/1670 train_time:54304ms step_avg:96.97ms
+step:561/1670 train_time:54400ms step_avg:96.97ms
+step:562/1670 train_time:54496ms step_avg:96.97ms
+step:563/1670 train_time:54593ms step_avg:96.97ms
+step:564/1670 train_time:54691ms step_avg:96.97ms
+step:565/1670 train_time:54790ms step_avg:96.97ms
+step:566/1670 train_time:54888ms step_avg:96.97ms
+step:567/1670 train_time:54986ms step_avg:96.98ms
+step:568/1670 train_time:55082ms step_avg:96.98ms
+step:569/1670 train_time:55180ms step_avg:96.98ms
+step:570/1670 train_time:55277ms step_avg:96.98ms
+step:571/1670 train_time:55374ms step_avg:96.98ms
+step:572/1670 train_time:55471ms step_avg:96.98ms
+step:573/1670 train_time:55568ms step_avg:96.98ms
+step:574/1670 train_time:55666ms step_avg:96.98ms
+step:575/1670 train_time:55763ms step_avg:96.98ms
+step:576/1670 train_time:55860ms step_avg:96.98ms
+step:577/1670 train_time:55958ms step_avg:96.98ms
+step:578/1670 train_time:56056ms step_avg:96.98ms
+step:579/1670 train_time:56154ms step_avg:96.98ms
+step:580/1670 train_time:56254ms step_avg:96.99ms
+step:581/1670 train_time:56352ms step_avg:96.99ms
+step:582/1670 train_time:56448ms step_avg:96.99ms
+step:583/1670 train_time:56545ms step_avg:96.99ms
+step:584/1670 train_time:56642ms step_avg:96.99ms
+step:585/1670 train_time:56738ms step_avg:96.99ms
+step:586/1670 train_time:56836ms step_avg:96.99ms
+step:587/1670 train_time:56933ms step_avg:96.99ms
+step:588/1670 train_time:57031ms step_avg:96.99ms
+step:589/1670 train_time:57129ms step_avg:96.99ms
+step:590/1670 train_time:57226ms step_avg:96.99ms
+step:591/1670 train_time:57323ms step_avg:96.99ms
+step:592/1670 train_time:57419ms step_avg:96.99ms
+step:593/1670 train_time:57517ms step_avg:96.99ms
+step:594/1670 train_time:57615ms step_avg:96.99ms
+step:595/1670 train_time:57713ms step_avg:97.00ms
+step:596/1670 train_time:57811ms step_avg:97.00ms
+step:597/1670 train_time:57908ms step_avg:97.00ms
+step:598/1670 train_time:58005ms step_avg:97.00ms
+step:599/1670 train_time:58101ms step_avg:97.00ms
+step:600/1670 train_time:58199ms step_avg:97.00ms
+step:601/1670 train_time:58296ms step_avg:97.00ms
+step:602/1670 train_time:58394ms step_avg:97.00ms
+step:603/1670 train_time:58492ms step_avg:97.00ms
+step:604/1670 train_time:58589ms step_avg:97.00ms
+step:605/1670 train_time:58685ms step_avg:97.00ms
+step:606/1670 train_time:58781ms step_avg:97.00ms
+step:607/1670 train_time:58879ms step_avg:97.00ms
+step:608/1670 train_time:58976ms step_avg:97.00ms
+step:609/1670 train_time:59074ms step_avg:97.00ms
+step:610/1670 train_time:59172ms step_avg:97.00ms
+step:611/1670 train_time:59270ms step_avg:97.01ms
+step:612/1670 train_time:59368ms step_avg:97.01ms
+step:613/1670 train_time:59464ms step_avg:97.01ms
+step:614/1670 train_time:59561ms step_avg:97.00ms
+step:615/1670 train_time:59657ms step_avg:97.00ms
+step:616/1670 train_time:59755ms step_avg:97.00ms
+step:617/1670 train_time:59853ms step_avg:97.01ms
+step:618/1670 train_time:59951ms step_avg:97.01ms
+step:619/1670 train_time:60048ms step_avg:97.01ms
+step:620/1670 train_time:60147ms step_avg:97.01ms
+step:621/1670 train_time:60243ms step_avg:97.01ms
+step:622/1670 train_time:60340ms step_avg:97.01ms
+step:623/1670 train_time:60437ms step_avg:97.01ms
+step:624/1670 train_time:60535ms step_avg:97.01ms
+step:625/1670 train_time:60631ms step_avg:97.01ms
+step:625/1670 val_loss:3.6163 train_time:60728ms step_avg:97.16ms
+step:626/1670 train_time:60750ms step_avg:97.04ms
+step:627/1670 train_time:60838ms step_avg:97.03ms
+step:628/1670 train_time:60936ms step_avg:97.03ms
+step:629/1670 train_time:61034ms step_avg:97.03ms
+step:630/1670 train_time:61130ms step_avg:97.03ms
+step:631/1670 train_time:61226ms step_avg:97.03ms
+step:632/1670 train_time:61323ms step_avg:97.03ms
+step:633/1670 train_time:61418ms step_avg:97.03ms
+step:634/1670 train_time:61514ms step_avg:97.03ms
+step:635/1670 train_time:61611ms step_avg:97.02ms
+step:636/1670 train_time:61710ms step_avg:97.03ms
+step:637/1670 train_time:61810ms step_avg:97.03ms
+step:638/1670 train_time:61912ms step_avg:97.04ms
+step:639/1670 train_time:62157ms step_avg:97.27ms
+step:640/1670 train_time:62360ms step_avg:97.44ms
+step:641/1670 train_time:62456ms step_avg:97.44ms
+step:642/1670 train_time:62551ms step_avg:97.43ms
+step:643/1670 train_time:62647ms step_avg:97.43ms
+step:644/1670 train_time:62743ms step_avg:97.43ms
+step:645/1670 train_time:62839ms step_avg:97.42ms
+step:646/1670 train_time:62934ms step_avg:97.42ms
+step:647/1670 train_time:63030ms step_avg:97.42ms
+step:648/1670 train_time:63127ms step_avg:97.42ms
+step:649/1670 train_time:63229ms step_avg:97.43ms
+step:650/1670 train_time:63330ms step_avg:97.43ms
+step:651/1670 train_time:63429ms step_avg:97.43ms
+step:652/1670 train_time:63527ms step_avg:97.43ms
+step:653/1670 train_time:63623ms step_avg:97.43ms
+step:654/1670 train_time:63719ms step_avg:97.43ms
+step:655/1670 train_time:63815ms step_avg:97.43ms
+step:656/1670 train_time:63911ms step_avg:97.42ms
+step:657/1670 train_time:64007ms step_avg:97.42ms
+step:658/1670 train_time:64105ms step_avg:97.42ms
+step:659/1670 train_time:64203ms step_avg:97.43ms
+step:660/1670 train_time:64301ms step_avg:97.43ms
+step:661/1670 train_time:64399ms step_avg:97.43ms
+step:662/1670 train_time:64497ms step_avg:97.43ms
+step:663/1670 train_time:64594ms step_avg:97.43ms
+step:664/1670 train_time:64691ms step_avg:97.43ms
+step:665/1670 train_time:64788ms step_avg:97.43ms
+step:666/1670 train_time:64884ms step_avg:97.42ms
+step:667/1670 train_time:64980ms step_avg:97.42ms
+step:668/1670 train_time:65077ms step_avg:97.42ms
+step:669/1670 train_time:65175ms step_avg:97.42ms
+step:670/1670 train_time:65273ms step_avg:97.42ms
+step:671/1670 train_time:65372ms step_avg:97.42ms
+step:672/1670 train_time:65471ms step_avg:97.43ms
+step:673/1670 train_time:65568ms step_avg:97.43ms
+step:674/1670 train_time:65665ms step_avg:97.43ms
+step:675/1670 train_time:65761ms step_avg:97.42ms
+step:676/1670 train_time:65857ms step_avg:97.42ms
+step:677/1670 train_time:65954ms step_avg:97.42ms
+step:678/1670 train_time:66051ms step_avg:97.42ms
+step:679/1670 train_time:66148ms step_avg:97.42ms
+step:680/1670 train_time:66246ms step_avg:97.42ms
+step:681/1670 train_time:66345ms step_avg:97.42ms
+step:682/1670 train_time:66443ms step_avg:97.42ms
+step:683/1670 train_time:66540ms step_avg:97.42ms
+step:684/1670 train_time:66637ms step_avg:97.42ms
+step:685/1670 train_time:66735ms step_avg:97.42ms
+step:686/1670 train_time:66831ms step_avg:97.42ms
+step:687/1670 train_time:66928ms step_avg:97.42ms
+step:688/1670 train_time:67025ms step_avg:97.42ms
+step:689/1670 train_time:67121ms step_avg:97.42ms
+step:690/1670 train_time:67218ms step_avg:97.42ms
+step:691/1670 train_time:67316ms step_avg:97.42ms
+step:692/1670 train_time:67415ms step_avg:97.42ms
+step:693/1670 train_time:67512ms step_avg:97.42ms
+step:694/1670 train_time:67611ms step_avg:97.42ms
+step:695/1670 train_time:67709ms step_avg:97.42ms
+step:696/1670 train_time:67806ms step_avg:97.42ms
+step:697/1670 train_time:67903ms step_avg:97.42ms
+step:698/1670 train_time:67999ms step_avg:97.42ms
+step:699/1670 train_time:68096ms step_avg:97.42ms
+step:700/1670 train_time:68194ms step_avg:97.42ms
+step:701/1670 train_time:68291ms step_avg:97.42ms
+step:702/1670 train_time:68390ms step_avg:97.42ms
+step:703/1670 train_time:68488ms step_avg:97.42ms
+step:704/1670 train_time:68585ms step_avg:97.42ms
+step:705/1670 train_time:68682ms step_avg:97.42ms
+step:706/1670 train_time:68778ms step_avg:97.42ms
+step:707/1670 train_time:68876ms step_avg:97.42ms
+step:708/1670 train_time:68974ms step_avg:97.42ms
+step:709/1670 train_time:69071ms step_avg:97.42ms
+step:710/1670 train_time:69168ms step_avg:97.42ms
+step:711/1670 train_time:69266ms step_avg:97.42ms
+step:712/1670 train_time:69364ms step_avg:97.42ms
+step:713/1670 train_time:69460ms step_avg:97.42ms
+step:714/1670 train_time:69557ms step_avg:97.42ms
+step:715/1670 train_time:69656ms step_avg:97.42ms
+step:716/1670 train_time:69753ms step_avg:97.42ms
+step:717/1670 train_time:69850ms step_avg:97.42ms
+step:718/1670 train_time:69948ms step_avg:97.42ms
+step:719/1670 train_time:70045ms step_avg:97.42ms
+step:720/1670 train_time:70143ms step_avg:97.42ms
+step:721/1670 train_time:70239ms step_avg:97.42ms
+step:722/1670 train_time:70336ms step_avg:97.42ms
+step:723/1670 train_time:70433ms step_avg:97.42ms
+step:724/1670 train_time:70532ms step_avg:97.42ms
+step:725/1670 train_time:70630ms step_avg:97.42ms
+step:726/1670 train_time:70727ms step_avg:97.42ms
+step:727/1670 train_time:70824ms step_avg:97.42ms
+step:728/1670 train_time:70920ms step_avg:97.42ms
+step:729/1670 train_time:71017ms step_avg:97.42ms
+step:730/1670 train_time:71114ms step_avg:97.42ms
+step:731/1670 train_time:71212ms step_avg:97.42ms
+step:732/1670 train_time:71310ms step_avg:97.42ms
+step:733/1670 train_time:71409ms step_avg:97.42ms
+step:734/1670 train_time:71507ms step_avg:97.42ms
+step:735/1670 train_time:71604ms step_avg:97.42ms
+step:736/1670 train_time:71701ms step_avg:97.42ms
+step:737/1670 train_time:71798ms step_avg:97.42ms
+step:738/1670 train_time:71895ms step_avg:97.42ms
+step:739/1670 train_time:71992ms step_avg:97.42ms
+step:740/1670 train_time:72089ms step_avg:97.42ms
+step:741/1670 train_time:72187ms step_avg:97.42ms
+step:742/1670 train_time:72284ms step_avg:97.42ms
+step:743/1670 train_time:72381ms step_avg:97.42ms
+step:744/1670 train_time:72478ms step_avg:97.42ms
+step:745/1670 train_time:72575ms step_avg:97.42ms
+step:746/1670 train_time:72673ms step_avg:97.42ms
+step:747/1670 train_time:72772ms step_avg:97.42ms
+step:748/1670 train_time:72869ms step_avg:97.42ms
+step:749/1670 train_time:72966ms step_avg:97.42ms
+step:750/1670 train_time:73063ms step_avg:97.42ms
+step:750/1670 val_loss:3.5631 train_time:73158ms step_avg:97.54ms
+step:751/1670 train_time:73183ms step_avg:97.45ms
+step:752/1670 train_time:73263ms step_avg:97.42ms
+step:753/1670 train_time:73364ms step_avg:97.43ms
+step:754/1670 train_time:73461ms step_avg:97.43ms
+step:755/1670 train_time:73557ms step_avg:97.43ms
+step:756/1670 train_time:73652ms step_avg:97.42ms
+step:757/1670 train_time:73749ms step_avg:97.42ms
+step:758/1670 train_time:73845ms step_avg:97.42ms
+step:759/1670 train_time:73942ms step_avg:97.42ms
+step:760/1670 train_time:74037ms step_avg:97.42ms
+step:761/1670 train_time:74135ms step_avg:97.42ms
+step:762/1670 train_time:74235ms step_avg:97.42ms
+step:763/1670 train_time:74334ms step_avg:97.42ms
+step:764/1670 train_time:74431ms step_avg:97.42ms
+step:765/1670 train_time:74529ms step_avg:97.42ms
+step:766/1670 train_time:74627ms step_avg:97.42ms
+step:767/1670 train_time:74723ms step_avg:97.42ms
+step:768/1670 train_time:74820ms step_avg:97.42ms
+step:769/1670 train_time:74916ms step_avg:97.42ms
+step:770/1670 train_time:75013ms step_avg:97.42ms
+step:771/1670 train_time:75111ms step_avg:97.42ms
+step:772/1670 train_time:75209ms step_avg:97.42ms
+step:773/1670 train_time:75308ms step_avg:97.42ms
+step:774/1670 train_time:75408ms step_avg:97.43ms
+step:775/1670 train_time:75506ms step_avg:97.43ms
+step:776/1670 train_time:75603ms step_avg:97.43ms
+step:777/1670 train_time:75699ms step_avg:97.43ms
+step:778/1670 train_time:75795ms step_avg:97.42ms
+step:779/1670 train_time:75892ms step_avg:97.42ms
+step:780/1670 train_time:75989ms step_avg:97.42ms
+step:781/1670 train_time:76086ms step_avg:97.42ms
+step:782/1670 train_time:76185ms step_avg:97.42ms
+step:783/1670 train_time:76283ms step_avg:97.42ms
+step:784/1670 train_time:76381ms step_avg:97.43ms
+step:785/1670 train_time:76479ms step_avg:97.43ms
+step:786/1670 train_time:76576ms step_avg:97.42ms
+step:787/1670 train_time:76673ms step_avg:97.42ms
+step:788/1670 train_time:76770ms step_avg:97.42ms
+step:789/1670 train_time:76867ms step_avg:97.42ms
+step:790/1670 train_time:76964ms step_avg:97.42ms
+step:791/1670 train_time:77063ms step_avg:97.42ms
+step:792/1670 train_time:77160ms step_avg:97.42ms
+step:793/1670 train_time:77258ms step_avg:97.42ms
+step:794/1670 train_time:77355ms step_avg:97.42ms
+step:795/1670 train_time:77452ms step_avg:97.42ms
+step:796/1670 train_time:77549ms step_avg:97.42ms
+step:797/1670 train_time:77648ms step_avg:97.42ms
+step:798/1670 train_time:77745ms step_avg:97.43ms
+step:799/1670 train_time:77842ms step_avg:97.42ms
+step:800/1670 train_time:77939ms step_avg:97.42ms
+step:801/1670 train_time:78036ms step_avg:97.42ms
+step:802/1670 train_time:78133ms step_avg:97.42ms
+step:803/1670 train_time:78231ms step_avg:97.42ms
+step:804/1670 train_time:78329ms step_avg:97.42ms
+step:805/1670 train_time:78428ms step_avg:97.43ms
+step:806/1670 train_time:78526ms step_avg:97.43ms
+step:807/1670 train_time:78623ms step_avg:97.43ms
+step:808/1670 train_time:78720ms step_avg:97.43ms
+step:809/1670 train_time:78816ms step_avg:97.42ms
+step:810/1670 train_time:78914ms step_avg:97.42ms
+step:811/1670 train_time:79011ms step_avg:97.42ms
+step:812/1670 train_time:79109ms step_avg:97.42ms
+step:813/1670 train_time:79206ms step_avg:97.42ms
+step:814/1670 train_time:79304ms step_avg:97.42ms
+step:815/1670 train_time:79402ms step_avg:97.43ms
+step:816/1670 train_time:79499ms step_avg:97.42ms
+step:817/1670 train_time:79596ms step_avg:97.42ms
+step:818/1670 train_time:79692ms step_avg:97.42ms
+step:819/1670 train_time:79789ms step_avg:97.42ms
+step:820/1670 train_time:79886ms step_avg:97.42ms
+step:821/1670 train_time:79985ms step_avg:97.42ms
+step:822/1670 train_time:80083ms step_avg:97.42ms
+step:823/1670 train_time:80181ms step_avg:97.42ms
+step:824/1670 train_time:80279ms step_avg:97.43ms
+step:825/1670 train_time:80376ms step_avg:97.43ms
+step:826/1670 train_time:80472ms step_avg:97.42ms
+step:827/1670 train_time:80570ms step_avg:97.42ms
+step:828/1670 train_time:80667ms step_avg:97.42ms
+step:829/1670 train_time:80764ms step_avg:97.42ms
+step:830/1670 train_time:80862ms step_avg:97.42ms
+step:831/1670 train_time:80960ms step_avg:97.42ms
+step:832/1670 train_time:81056ms step_avg:97.42ms
+step:833/1670 train_time:81153ms step_avg:97.42ms
+step:834/1670 train_time:81251ms step_avg:97.42ms
+step:835/1670 train_time:81348ms step_avg:97.42ms
+step:836/1670 train_time:81446ms step_avg:97.42ms
+step:837/1670 train_time:81544ms step_avg:97.42ms
+step:838/1670 train_time:81642ms step_avg:97.42ms
+step:839/1670 train_time:81738ms step_avg:97.42ms
+step:840/1670 train_time:81835ms step_avg:97.42ms
+step:841/1670 train_time:81932ms step_avg:97.42ms
+step:842/1670 train_time:82030ms step_avg:97.42ms
+step:843/1670 train_time:82127ms step_avg:97.42ms
+step:844/1670 train_time:82226ms step_avg:97.42ms
+step:845/1670 train_time:82323ms step_avg:97.42ms
+step:846/1670 train_time:82420ms step_avg:97.42ms
+step:847/1670 train_time:82518ms step_avg:97.42ms
+step:848/1670 train_time:82614ms step_avg:97.42ms
+step:849/1670 train_time:82712ms step_avg:97.42ms
+step:850/1670 train_time:82809ms step_avg:97.42ms
+step:851/1670 train_time:83081ms step_avg:97.63ms
+step:852/1670 train_time:83264ms step_avg:97.73ms
+step:853/1670 train_time:83359ms step_avg:97.72ms
+step:854/1670 train_time:83455ms step_avg:97.72ms
+step:855/1670 train_time:83551ms step_avg:97.72ms
+step:856/1670 train_time:83647ms step_avg:97.72ms
+step:857/1670 train_time:83743ms step_avg:97.72ms
+step:858/1670 train_time:83840ms step_avg:97.72ms
+step:859/1670 train_time:83936ms step_avg:97.71ms
+step:860/1670 train_time:84032ms step_avg:97.71ms
+step:861/1670 train_time:84137ms step_avg:97.72ms
+step:862/1670 train_time:84236ms step_avg:97.72ms
+step:863/1670 train_time:84335ms step_avg:97.72ms
+step:864/1670 train_time:84432ms step_avg:97.72ms
+step:865/1670 train_time:84528ms step_avg:97.72ms
+step:866/1670 train_time:84626ms step_avg:97.72ms
+step:867/1670 train_time:84722ms step_avg:97.72ms
+step:868/1670 train_time:84818ms step_avg:97.72ms
+step:869/1670 train_time:84914ms step_avg:97.71ms
+step:870/1670 train_time:85010ms step_avg:97.71ms
+step:871/1670 train_time:85110ms step_avg:97.72ms
+step:872/1670 train_time:85210ms step_avg:97.72ms
+step:873/1670 train_time:85309ms step_avg:97.72ms
+step:874/1670 train_time:85407ms step_avg:97.72ms
+step:875/1670 train_time:85505ms step_avg:97.72ms
+step:875/1670 val_loss:3.5229 train_time:85602ms step_avg:97.83ms
+step:876/1670 train_time:85623ms step_avg:97.74ms
+step:877/1670 train_time:85707ms step_avg:97.73ms
+step:878/1670 train_time:85807ms step_avg:97.73ms
+step:879/1670 train_time:85905ms step_avg:97.73ms
+step:880/1670 train_time:86002ms step_avg:97.73ms
+step:881/1670 train_time:86098ms step_avg:97.73ms
+step:882/1670 train_time:86194ms step_avg:97.73ms
+step:883/1670 train_time:86290ms step_avg:97.72ms
+step:884/1670 train_time:86386ms step_avg:97.72ms
+step:885/1670 train_time:86483ms step_avg:97.72ms
+step:886/1670 train_time:86582ms step_avg:97.72ms
+step:887/1670 train_time:86684ms step_avg:97.73ms
+step:888/1670 train_time:86784ms step_avg:97.73ms
+step:889/1670 train_time:86883ms step_avg:97.73ms
+step:890/1670 train_time:86980ms step_avg:97.73ms
+step:891/1670 train_time:87077ms step_avg:97.73ms
+step:892/1670 train_time:87173ms step_avg:97.73ms
+step:893/1670 train_time:87271ms step_avg:97.73ms
+step:894/1670 train_time:87366ms step_avg:97.72ms
+step:895/1670 train_time:87463ms step_avg:97.72ms
+step:896/1670 train_time:87562ms step_avg:97.73ms
+step:897/1670 train_time:87661ms step_avg:97.73ms
+step:898/1670 train_time:87762ms step_avg:97.73ms
+step:899/1670 train_time:87861ms step_avg:97.73ms
+step:900/1670 train_time:87959ms step_avg:97.73ms
+step:901/1670 train_time:88056ms step_avg:97.73ms
+step:902/1670 train_time:88151ms step_avg:97.73ms
+step:903/1670 train_time:88248ms step_avg:97.73ms
+step:904/1670 train_time:88344ms step_avg:97.73ms
+step:905/1670 train_time:88441ms step_avg:97.72ms
+step:906/1670 train_time:88538ms step_avg:97.72ms
+step:907/1670 train_time:88637ms step_avg:97.73ms
+step:908/1670 train_time:88735ms step_avg:97.73ms
+step:909/1670 train_time:88834ms step_avg:97.73ms
+step:910/1670 train_time:88931ms step_avg:97.73ms
+step:911/1670 train_time:89027ms step_avg:97.72ms
+step:912/1670 train_time:89124ms step_avg:97.72ms
+step:913/1670 train_time:89221ms step_avg:97.72ms
+step:914/1670 train_time:89318ms step_avg:97.72ms
+step:915/1670 train_time:89414ms step_avg:97.72ms
+step:916/1670 train_time:89511ms step_avg:97.72ms
+step:917/1670 train_time:89607ms step_avg:97.72ms
+step:918/1670 train_time:89706ms step_avg:97.72ms
+step:919/1670 train_time:89805ms step_avg:97.72ms
+step:920/1670 train_time:89903ms step_avg:97.72ms
+step:921/1670 train_time:90001ms step_avg:97.72ms
+step:922/1670 train_time:90099ms step_avg:97.72ms
+step:923/1670 train_time:90196ms step_avg:97.72ms
+step:924/1670 train_time:90294ms step_avg:97.72ms
+step:925/1670 train_time:90391ms step_avg:97.72ms
+step:926/1670 train_time:90487ms step_avg:97.72ms
+step:927/1670 train_time:90584ms step_avg:97.72ms
+step:928/1670 train_time:90682ms step_avg:97.72ms
+step:929/1670 train_time:90780ms step_avg:97.72ms
+step:930/1670 train_time:90878ms step_avg:97.72ms
+step:931/1670 train_time:90976ms step_avg:97.72ms
+step:932/1670 train_time:91073ms step_avg:97.72ms
+step:933/1670 train_time:91170ms step_avg:97.72ms
+step:934/1670 train_time:91267ms step_avg:97.72ms
+step:935/1670 train_time:91365ms step_avg:97.72ms
+step:936/1670 train_time:91463ms step_avg:97.72ms
+step:937/1670 train_time:91560ms step_avg:97.72ms
+step:938/1670 train_time:91658ms step_avg:97.72ms
+step:939/1670 train_time:91756ms step_avg:97.72ms
+step:940/1670 train_time:91853ms step_avg:97.72ms
+step:941/1670 train_time:91950ms step_avg:97.72ms
+step:942/1670 train_time:92047ms step_avg:97.71ms
+step:943/1670 train_time:92144ms step_avg:97.71ms
+step:944/1670 train_time:92242ms step_avg:97.71ms
+step:945/1670 train_time:92340ms step_avg:97.71ms
+step:946/1670 train_time:92437ms step_avg:97.71ms
+step:947/1670 train_time:92534ms step_avg:97.71ms
+step:948/1670 train_time:92631ms step_avg:97.71ms
+step:949/1670 train_time:92728ms step_avg:97.71ms
+step:950/1670 train_time:92827ms step_avg:97.71ms
+step:951/1670 train_time:92924ms step_avg:97.71ms
+step:952/1670 train_time:93022ms step_avg:97.71ms
+step:953/1670 train_time:93120ms step_avg:97.71ms
+step:954/1670 train_time:93218ms step_avg:97.71ms
+step:955/1670 train_time:93315ms step_avg:97.71ms
+step:956/1670 train_time:93411ms step_avg:97.71ms
+step:957/1670 train_time:93508ms step_avg:97.71ms
+step:958/1670 train_time:93605ms step_avg:97.71ms
+step:959/1670 train_time:93702ms step_avg:97.71ms
+step:960/1670 train_time:93800ms step_avg:97.71ms
+step:961/1670 train_time:93898ms step_avg:97.71ms
+step:962/1670 train_time:93996ms step_avg:97.71ms
+step:963/1670 train_time:94093ms step_avg:97.71ms
+step:964/1670 train_time:94190ms step_avg:97.71ms
+step:965/1670 train_time:94287ms step_avg:97.71ms
+step:966/1670 train_time:94384ms step_avg:97.71ms
+step:967/1670 train_time:94481ms step_avg:97.71ms
+step:968/1670 train_time:94578ms step_avg:97.70ms
+step:969/1670 train_time:94676ms step_avg:97.71ms
+step:970/1670 train_time:94774ms step_avg:97.71ms
+step:971/1670 train_time:94871ms step_avg:97.70ms
+step:972/1670 train_time:94968ms step_avg:97.70ms
+step:973/1670 train_time:95065ms step_avg:97.70ms
+step:974/1670 train_time:95163ms step_avg:97.70ms
+step:975/1670 train_time:95261ms step_avg:97.70ms
+step:976/1670 train_time:95358ms step_avg:97.70ms
+step:977/1670 train_time:95457ms step_avg:97.70ms
+step:978/1670 train_time:95554ms step_avg:97.70ms
+step:979/1670 train_time:95651ms step_avg:97.70ms
+step:980/1670 train_time:95748ms step_avg:97.70ms
+step:981/1670 train_time:95846ms step_avg:97.70ms
+step:982/1670 train_time:95944ms step_avg:97.70ms
+step:983/1670 train_time:96041ms step_avg:97.70ms
+step:984/1670 train_time:96139ms step_avg:97.70ms
+step:985/1670 train_time:96236ms step_avg:97.70ms
+step:986/1670 train_time:96333ms step_avg:97.70ms
+step:987/1670 train_time:96429ms step_avg:97.70ms
+step:988/1670 train_time:96526ms step_avg:97.70ms
+step:989/1670 train_time:96624ms step_avg:97.70ms
+step:990/1670 train_time:96722ms step_avg:97.70ms
+step:991/1670 train_time:96820ms step_avg:97.70ms
+step:992/1670 train_time:96917ms step_avg:97.70ms
+step:993/1670 train_time:97015ms step_avg:97.70ms
+step:994/1670 train_time:97112ms step_avg:97.70ms
+step:995/1670 train_time:97208ms step_avg:97.70ms
+step:996/1670 train_time:97306ms step_avg:97.70ms
+step:997/1670 train_time:97403ms step_avg:97.70ms
+step:998/1670 train_time:97501ms step_avg:97.70ms
+step:999/1670 train_time:97598ms step_avg:97.70ms
+step:1000/1670 train_time:97696ms step_avg:97.70ms
+step:1000/1670 val_loss:3.4804 train_time:97792ms step_avg:97.79ms
+step:1001/1670 train_time:97814ms step_avg:97.72ms
+step:1002/1670 train_time:97899ms step_avg:97.70ms
+step:1003/1670 train_time:97998ms step_avg:97.71ms
+step:1004/1670 train_time:98097ms step_avg:97.71ms
+step:1005/1670 train_time:98194ms step_avg:97.71ms
+step:1006/1670 train_time:98291ms step_avg:97.70ms
+step:1007/1670 train_time:98387ms step_avg:97.70ms
+step:1008/1670 train_time:98483ms step_avg:97.70ms
+step:1009/1670 train_time:98579ms step_avg:97.70ms
+step:1010/1670 train_time:98675ms step_avg:97.70ms
+step:1011/1670 train_time:98774ms step_avg:97.70ms
+step:1012/1670 train_time:98873ms step_avg:97.70ms
+step:1013/1670 train_time:98974ms step_avg:97.70ms
+step:1014/1670 train_time:99073ms step_avg:97.71ms
+step:1015/1670 train_time:99171ms step_avg:97.71ms
+step:1016/1670 train_time:99268ms step_avg:97.70ms
+step:1017/1670 train_time:99365ms step_avg:97.70ms
+step:1018/1670 train_time:99461ms step_avg:97.70ms
+step:1019/1670 train_time:99557ms step_avg:97.70ms
+step:1020/1670 train_time:99654ms step_avg:97.70ms
+step:1021/1670 train_time:99752ms step_avg:97.70ms
+step:1022/1670 train_time:99851ms step_avg:97.70ms
+step:1023/1670 train_time:99951ms step_avg:97.70ms
+step:1024/1670 train_time:100050ms step_avg:97.71ms
+step:1025/1670 train_time:100148ms step_avg:97.71ms
+step:1026/1670 train_time:100245ms step_avg:97.70ms
+step:1027/1670 train_time:100341ms step_avg:97.70ms
+step:1028/1670 train_time:100439ms step_avg:97.70ms
+step:1029/1670 train_time:100536ms step_avg:97.70ms
+step:1030/1670 train_time:100632ms step_avg:97.70ms
+step:1031/1670 train_time:100729ms step_avg:97.70ms
+step:1032/1670 train_time:100826ms step_avg:97.70ms
+step:1033/1670 train_time:100924ms step_avg:97.70ms
+step:1034/1670 train_time:101022ms step_avg:97.70ms
+step:1035/1670 train_time:101120ms step_avg:97.70ms
+step:1036/1670 train_time:101217ms step_avg:97.70ms
+step:1037/1670 train_time:101315ms step_avg:97.70ms
+step:1038/1670 train_time:101414ms step_avg:97.70ms
+step:1039/1670 train_time:101511ms step_avg:97.70ms
+step:1040/1670 train_time:101608ms step_avg:97.70ms
+step:1041/1670 train_time:101705ms step_avg:97.70ms
+step:1042/1670 train_time:101801ms step_avg:97.70ms
+step:1043/1670 train_time:101899ms step_avg:97.70ms
+step:1044/1670 train_time:101997ms step_avg:97.70ms
+step:1045/1670 train_time:102095ms step_avg:97.70ms
+step:1046/1670 train_time:102193ms step_avg:97.70ms
+step:1047/1670 train_time:102291ms step_avg:97.70ms
+step:1048/1670 train_time:102389ms step_avg:97.70ms
+step:1049/1670 train_time:102486ms step_avg:97.70ms
+step:1050/1670 train_time:102583ms step_avg:97.70ms
+step:1051/1670 train_time:102679ms step_avg:97.70ms
+step:1052/1670 train_time:102776ms step_avg:97.70ms
+step:1053/1670 train_time:102873ms step_avg:97.70ms
+step:1054/1670 train_time:102971ms step_avg:97.70ms
+step:1055/1670 train_time:103069ms step_avg:97.70ms
+step:1056/1670 train_time:103167ms step_avg:97.70ms
+step:1057/1670 train_time:103265ms step_avg:97.70ms
+step:1058/1670 train_time:103362ms step_avg:97.70ms
+step:1059/1670 train_time:103459ms step_avg:97.70ms
+step:1060/1670 train_time:103557ms step_avg:97.69ms
+step:1061/1670 train_time:103654ms step_avg:97.70ms
+step:1062/1670 train_time:103922ms step_avg:97.86ms
+step:1063/1670 train_time:104096ms step_avg:97.93ms
+step:1064/1670 train_time:104191ms step_avg:97.92ms
+step:1065/1670 train_time:104287ms step_avg:97.92ms
+step:1066/1670 train_time:104383ms step_avg:97.92ms
+step:1067/1670 train_time:104479ms step_avg:97.92ms
+step:1068/1670 train_time:104575ms step_avg:97.92ms
+step:1069/1670 train_time:104671ms step_avg:97.92ms
+step:1070/1670 train_time:104768ms step_avg:97.91ms
+step:1071/1670 train_time:104864ms step_avg:97.91ms
+step:1072/1670 train_time:104969ms step_avg:97.92ms
+step:1073/1670 train_time:105071ms step_avg:97.92ms
+step:1074/1670 train_time:105169ms step_avg:97.92ms
+step:1075/1670 train_time:105266ms step_avg:97.92ms
+step:1076/1670 train_time:105363ms step_avg:97.92ms
+step:1077/1670 train_time:105459ms step_avg:97.92ms
+step:1078/1670 train_time:105555ms step_avg:97.92ms
+step:1079/1670 train_time:105651ms step_avg:97.92ms
+step:1080/1670 train_time:105748ms step_avg:97.92ms
+step:1081/1670 train_time:105845ms step_avg:97.91ms
+step:1082/1670 train_time:105943ms step_avg:97.91ms
+step:1083/1670 train_time:106043ms step_avg:97.92ms
+step:1084/1670 train_time:106140ms step_avg:97.92ms
+step:1085/1670 train_time:106240ms step_avg:97.92ms
+step:1086/1670 train_time:106337ms step_avg:97.92ms
+step:1087/1670 train_time:106434ms step_avg:97.92ms
+step:1088/1670 train_time:106531ms step_avg:97.91ms
+step:1089/1670 train_time:106628ms step_avg:97.91ms
+step:1090/1670 train_time:106725ms step_avg:97.91ms
+step:1091/1670 train_time:106821ms step_avg:97.91ms
+step:1092/1670 train_time:106919ms step_avg:97.91ms
+step:1093/1670 train_time:107017ms step_avg:97.91ms
+step:1094/1670 train_time:107115ms step_avg:97.91ms
+step:1095/1670 train_time:107215ms step_avg:97.91ms
+step:1096/1670 train_time:107312ms step_avg:97.91ms
+step:1097/1670 train_time:107409ms step_avg:97.91ms
+step:1098/1670 train_time:107506ms step_avg:97.91ms
+step:1099/1670 train_time:107602ms step_avg:97.91ms
+step:1100/1670 train_time:107699ms step_avg:97.91ms
+step:1101/1670 train_time:107796ms step_avg:97.91ms
+step:1102/1670 train_time:107894ms step_avg:97.91ms
+step:1103/1670 train_time:107992ms step_avg:97.91ms
+step:1104/1670 train_time:108091ms step_avg:97.91ms
+step:1105/1670 train_time:108190ms step_avg:97.91ms
+step:1106/1670 train_time:108289ms step_avg:97.91ms
+step:1107/1670 train_time:108386ms step_avg:97.91ms
+step:1108/1670 train_time:108483ms step_avg:97.91ms
+step:1109/1670 train_time:108580ms step_avg:97.91ms
+step:1110/1670 train_time:108677ms step_avg:97.91ms
+step:1111/1670 train_time:108774ms step_avg:97.91ms
+step:1112/1670 train_time:108872ms step_avg:97.91ms
+step:1113/1670 train_time:108970ms step_avg:97.91ms
+step:1114/1670 train_time:109069ms step_avg:97.91ms
+step:1115/1670 train_time:109167ms step_avg:97.91ms
+step:1116/1670 train_time:109265ms step_avg:97.91ms
+step:1117/1670 train_time:109363ms step_avg:97.91ms
+step:1118/1670 train_time:109460ms step_avg:97.91ms
+step:1119/1670 train_time:109557ms step_avg:97.91ms
+step:1120/1670 train_time:109655ms step_avg:97.91ms
+step:1121/1670 train_time:109753ms step_avg:97.91ms
+step:1122/1670 train_time:109852ms step_avg:97.91ms
+step:1123/1670 train_time:109951ms step_avg:97.91ms
+step:1124/1670 train_time:110050ms step_avg:97.91ms
+step:1125/1670 train_time:110150ms step_avg:97.91ms
+step:1125/1670 val_loss:3.4262 train_time:110248ms step_avg:98.00ms
+step:1126/1670 train_time:110269ms step_avg:97.93ms
+step:1127/1670 train_time:110356ms step_avg:97.92ms
+step:1128/1670 train_time:110456ms step_avg:97.92ms
+step:1129/1670 train_time:110553ms step_avg:97.92ms
+step:1130/1670 train_time:110649ms step_avg:97.92ms
+step:1131/1670 train_time:110746ms step_avg:97.92ms
+step:1132/1670 train_time:110843ms step_avg:97.92ms
+step:1133/1670 train_time:110940ms step_avg:97.92ms
+step:1134/1670 train_time:111037ms step_avg:97.92ms
+step:1135/1670 train_time:111134ms step_avg:97.92ms
+step:1136/1670 train_time:111235ms step_avg:97.92ms
+step:1137/1670 train_time:111335ms step_avg:97.92ms
+step:1138/1670 train_time:111434ms step_avg:97.92ms
+step:1139/1670 train_time:111532ms step_avg:97.92ms
+step:1140/1670 train_time:111630ms step_avg:97.92ms
+step:1141/1670 train_time:111728ms step_avg:97.92ms
+step:1142/1670 train_time:111825ms step_avg:97.92ms
+step:1143/1670 train_time:111922ms step_avg:97.92ms
+step:1144/1670 train_time:112020ms step_avg:97.92ms
+step:1145/1670 train_time:112116ms step_avg:97.92ms
+step:1146/1670 train_time:112215ms step_avg:97.92ms
+step:1147/1670 train_time:112314ms step_avg:97.92ms
+step:1148/1670 train_time:112414ms step_avg:97.92ms
+step:1149/1670 train_time:112512ms step_avg:97.92ms
+step:1150/1670 train_time:112610ms step_avg:97.92ms
+step:1151/1670 train_time:112707ms step_avg:97.92ms
+step:1152/1670 train_time:112804ms step_avg:97.92ms
+step:1153/1670 train_time:112902ms step_avg:97.92ms
+step:1154/1670 train_time:112999ms step_avg:97.92ms
+step:1155/1670 train_time:113096ms step_avg:97.92ms
+step:1156/1670 train_time:113195ms step_avg:97.92ms
+step:1157/1670 train_time:113292ms step_avg:97.92ms
+step:1158/1670 train_time:113392ms step_avg:97.92ms
+step:1159/1670 train_time:113490ms step_avg:97.92ms
+step:1160/1670 train_time:113588ms step_avg:97.92ms
+step:1161/1670 train_time:113687ms step_avg:97.92ms
+step:1162/1670 train_time:113783ms step_avg:97.92ms
+step:1163/1670 train_time:113881ms step_avg:97.92ms
+step:1164/1670 train_time:113980ms step_avg:97.92ms
+step:1165/1670 train_time:114078ms step_avg:97.92ms
+step:1166/1670 train_time:114176ms step_avg:97.92ms
+step:1167/1670 train_time:114273ms step_avg:97.92ms
+step:1168/1670 train_time:114371ms step_avg:97.92ms
+step:1169/1670 train_time:114469ms step_avg:97.92ms
+step:1170/1670 train_time:114568ms step_avg:97.92ms
+step:1171/1670 train_time:114666ms step_avg:97.92ms
+step:1172/1670 train_time:114765ms step_avg:97.92ms
+step:1173/1670 train_time:114862ms step_avg:97.92ms
+step:1174/1670 train_time:114960ms step_avg:97.92ms
+step:1175/1670 train_time:115058ms step_avg:97.92ms
+step:1176/1670 train_time:115156ms step_avg:97.92ms
+step:1177/1670 train_time:115254ms step_avg:97.92ms
+step:1178/1670 train_time:115352ms step_avg:97.92ms
+step:1179/1670 train_time:115450ms step_avg:97.92ms
+step:1180/1670 train_time:115548ms step_avg:97.92ms
+step:1181/1670 train_time:115647ms step_avg:97.92ms
+step:1182/1670 train_time:115745ms step_avg:97.92ms
+step:1183/1670 train_time:115843ms step_avg:97.92ms
+step:1184/1670 train_time:115941ms step_avg:97.92ms
+step:1185/1670 train_time:116040ms step_avg:97.92ms
+step:1186/1670 train_time:116137ms step_avg:97.92ms
+step:1187/1670 train_time:116235ms step_avg:97.92ms
+step:1188/1670 train_time:116332ms step_avg:97.92ms
+step:1189/1670 train_time:116430ms step_avg:97.92ms
+step:1190/1670 train_time:116528ms step_avg:97.92ms
+step:1191/1670 train_time:116626ms step_avg:97.92ms
+step:1192/1670 train_time:116724ms step_avg:97.92ms
+step:1193/1670 train_time:116822ms step_avg:97.92ms
+step:1194/1670 train_time:116920ms step_avg:97.92ms
+step:1195/1670 train_time:117018ms step_avg:97.92ms
+step:1196/1670 train_time:117116ms step_avg:97.92ms
+step:1197/1670 train_time:117214ms step_avg:97.92ms
+step:1198/1670 train_time:117311ms step_avg:97.92ms
+step:1199/1670 train_time:117409ms step_avg:97.92ms
+step:1200/1670 train_time:117507ms step_avg:97.92ms
+step:1201/1670 train_time:117605ms step_avg:97.92ms
+step:1202/1670 train_time:117703ms step_avg:97.92ms
+step:1203/1670 train_time:117802ms step_avg:97.92ms
+step:1204/1670 train_time:117900ms step_avg:97.92ms
+step:1205/1670 train_time:117998ms step_avg:97.92ms
+step:1206/1670 train_time:118095ms step_avg:97.92ms
+step:1207/1670 train_time:118194ms step_avg:97.92ms
+step:1208/1670 train_time:118292ms step_avg:97.92ms
+step:1209/1670 train_time:118390ms step_avg:97.92ms
+step:1210/1670 train_time:118488ms step_avg:97.92ms
+step:1211/1670 train_time:118586ms step_avg:97.92ms
+step:1212/1670 train_time:118684ms step_avg:97.92ms
+step:1213/1670 train_time:118781ms step_avg:97.92ms
+step:1214/1670 train_time:118879ms step_avg:97.92ms
+step:1215/1670 train_time:118977ms step_avg:97.92ms
+step:1216/1670 train_time:119075ms step_avg:97.92ms
+step:1217/1670 train_time:119173ms step_avg:97.92ms
+step:1218/1670 train_time:119271ms step_avg:97.92ms
+step:1219/1670 train_time:119369ms step_avg:97.92ms
+step:1220/1670 train_time:119468ms step_avg:97.92ms
+step:1221/1670 train_time:119567ms step_avg:97.93ms
+step:1222/1670 train_time:119665ms step_avg:97.93ms
+step:1223/1670 train_time:119762ms step_avg:97.92ms
+step:1224/1670 train_time:119862ms step_avg:97.93ms
+step:1225/1670 train_time:119958ms step_avg:97.92ms
+step:1226/1670 train_time:120056ms step_avg:97.92ms
+step:1227/1670 train_time:120155ms step_avg:97.93ms
+step:1228/1670 train_time:120253ms step_avg:97.93ms
+step:1229/1670 train_time:120350ms step_avg:97.93ms
+step:1230/1670 train_time:120448ms step_avg:97.93ms
+step:1231/1670 train_time:120547ms step_avg:97.93ms
+step:1232/1670 train_time:120645ms step_avg:97.93ms
+step:1233/1670 train_time:120744ms step_avg:97.93ms
+step:1234/1670 train_time:120843ms step_avg:97.93ms
+step:1235/1670 train_time:120940ms step_avg:97.93ms
+step:1236/1670 train_time:121039ms step_avg:97.93ms
+step:1237/1670 train_time:121137ms step_avg:97.93ms
+step:1238/1670 train_time:121234ms step_avg:97.93ms
+step:1239/1670 train_time:121332ms step_avg:97.93ms
+step:1240/1670 train_time:121430ms step_avg:97.93ms
+step:1241/1670 train_time:121528ms step_avg:97.93ms +step:1242/1670 train_time:121625ms step_avg:97.93ms +step:1243/1670 train_time:121723ms step_avg:97.93ms +step:1244/1670 train_time:121822ms step_avg:97.93ms +step:1245/1670 train_time:121920ms step_avg:97.93ms +step:1246/1670 train_time:122019ms step_avg:97.93ms +step:1247/1670 train_time:122118ms step_avg:97.93ms +step:1248/1670 train_time:122219ms step_avg:97.93ms +step:1249/1670 train_time:122318ms step_avg:97.93ms +step:1250/1670 train_time:122417ms step_avg:97.93ms +step:1250/1670 val_loss:3.3827 train_time:122514ms step_avg:98.01ms +step:1251/1670 train_time:122536ms step_avg:97.95ms +step:1252/1670 train_time:122619ms step_avg:97.94ms +step:1253/1670 train_time:122718ms step_avg:97.94ms +step:1254/1670 train_time:122816ms step_avg:97.94ms +step:1255/1670 train_time:122913ms step_avg:97.94ms +step:1256/1670 train_time:123010ms step_avg:97.94ms +step:1257/1670 train_time:123107ms step_avg:97.94ms +step:1258/1670 train_time:123204ms step_avg:97.94ms +step:1259/1670 train_time:123302ms step_avg:97.94ms +step:1260/1670 train_time:123397ms step_avg:97.93ms +step:1261/1670 train_time:123497ms step_avg:97.94ms +step:1262/1670 train_time:123597ms step_avg:97.94ms +step:1263/1670 train_time:123696ms step_avg:97.94ms +step:1264/1670 train_time:123795ms step_avg:97.94ms +step:1265/1670 train_time:123893ms step_avg:97.94ms +step:1266/1670 train_time:123990ms step_avg:97.94ms +step:1267/1670 train_time:124088ms step_avg:97.94ms +step:1268/1670 train_time:124185ms step_avg:97.94ms +step:1269/1670 train_time:124283ms step_avg:97.94ms +step:1270/1670 train_time:124380ms step_avg:97.94ms +step:1271/1670 train_time:124478ms step_avg:97.94ms +step:1272/1670 train_time:124576ms step_avg:97.94ms +step:1273/1670 train_time:124676ms step_avg:97.94ms +step:1274/1670 train_time:125042ms step_avg:98.15ms +step:1275/1670 train_time:125134ms step_avg:98.14ms +step:1276/1670 train_time:125231ms step_avg:98.14ms +step:1277/1670 train_time:125328ms step_avg:98.14ms +step:1278/1670 train_time:125426ms step_avg:98.14ms +step:1279/1670 train_time:125523ms step_avg:98.14ms +step:1280/1670 train_time:125620ms step_avg:98.14ms +step:1281/1670 train_time:125716ms step_avg:98.14ms +step:1282/1670 train_time:125813ms step_avg:98.14ms +step:1283/1670 train_time:125911ms step_avg:98.14ms +step:1284/1670 train_time:126015ms step_avg:98.14ms +step:1285/1670 train_time:126117ms step_avg:98.15ms +step:1286/1670 train_time:126216ms step_avg:98.15ms +step:1287/1670 train_time:126314ms step_avg:98.15ms +step:1288/1670 train_time:126412ms step_avg:98.15ms +step:1289/1670 train_time:126509ms step_avg:98.15ms +step:1290/1670 train_time:126608ms step_avg:98.15ms +step:1291/1670 train_time:126706ms step_avg:98.15ms +step:1292/1670 train_time:126803ms step_avg:98.15ms +step:1293/1670 train_time:126900ms step_avg:98.14ms +step:1294/1670 train_time:127000ms step_avg:98.15ms +step:1295/1670 train_time:127100ms step_avg:98.15ms +step:1296/1670 train_time:127199ms step_avg:98.15ms +step:1297/1670 train_time:127298ms step_avg:98.15ms +step:1298/1670 train_time:127395ms step_avg:98.15ms +step:1299/1670 train_time:127493ms step_avg:98.15ms +step:1300/1670 train_time:127590ms step_avg:98.15ms +step:1301/1670 train_time:127688ms step_avg:98.15ms +step:1302/1670 train_time:127786ms step_avg:98.15ms +step:1303/1670 train_time:127883ms step_avg:98.15ms +step:1304/1670 train_time:127982ms step_avg:98.15ms +step:1305/1670 train_time:128083ms step_avg:98.15ms +step:1306/1670 train_time:128181ms 
step_avg:98.15ms +step:1307/1670 train_time:128280ms step_avg:98.15ms +step:1308/1670 train_time:128379ms step_avg:98.15ms +step:1309/1670 train_time:128477ms step_avg:98.15ms +step:1310/1670 train_time:128574ms step_avg:98.15ms +step:1311/1670 train_time:128672ms step_avg:98.15ms +step:1312/1670 train_time:128769ms step_avg:98.15ms +step:1313/1670 train_time:128868ms step_avg:98.15ms +step:1314/1670 train_time:128966ms step_avg:98.15ms +step:1315/1670 train_time:129065ms step_avg:98.15ms +step:1316/1670 train_time:129165ms step_avg:98.15ms +step:1317/1670 train_time:129265ms step_avg:98.15ms +step:1318/1670 train_time:129364ms step_avg:98.15ms +step:1319/1670 train_time:129463ms step_avg:98.15ms +step:1320/1670 train_time:129561ms step_avg:98.15ms +step:1321/1670 train_time:129658ms step_avg:98.15ms +step:1322/1670 train_time:129756ms step_avg:98.15ms +step:1323/1670 train_time:129853ms step_avg:98.15ms +step:1324/1670 train_time:129951ms step_avg:98.15ms +step:1325/1670 train_time:130049ms step_avg:98.15ms +step:1326/1670 train_time:130148ms step_avg:98.15ms +step:1327/1670 train_time:130247ms step_avg:98.15ms +step:1328/1670 train_time:130347ms step_avg:98.15ms +step:1329/1670 train_time:130447ms step_avg:98.15ms +step:1330/1670 train_time:130545ms step_avg:98.15ms +step:1331/1670 train_time:130644ms step_avg:98.15ms +step:1332/1670 train_time:130742ms step_avg:98.15ms +step:1333/1670 train_time:130840ms step_avg:98.15ms +step:1334/1670 train_time:130937ms step_avg:98.15ms +step:1335/1670 train_time:131034ms step_avg:98.15ms +step:1336/1670 train_time:131132ms step_avg:98.15ms +step:1337/1670 train_time:131232ms step_avg:98.15ms +step:1338/1670 train_time:131332ms step_avg:98.16ms +step:1339/1670 train_time:131431ms step_avg:98.16ms +step:1340/1670 train_time:131529ms step_avg:98.16ms +step:1341/1670 train_time:131628ms step_avg:98.16ms +step:1342/1670 train_time:131727ms step_avg:98.16ms +step:1343/1670 train_time:131826ms step_avg:98.16ms +step:1344/1670 train_time:131924ms step_avg:98.16ms +step:1345/1670 train_time:132022ms step_avg:98.16ms +step:1346/1670 train_time:132120ms step_avg:98.16ms +step:1347/1670 train_time:132218ms step_avg:98.16ms +step:1348/1670 train_time:132315ms step_avg:98.16ms +step:1349/1670 train_time:132414ms step_avg:98.16ms +step:1350/1670 train_time:132512ms step_avg:98.16ms +step:1351/1670 train_time:132610ms step_avg:98.16ms +step:1352/1670 train_time:132709ms step_avg:98.16ms +step:1353/1670 train_time:132810ms step_avg:98.16ms +step:1354/1670 train_time:132909ms step_avg:98.16ms +step:1355/1670 train_time:133007ms step_avg:98.16ms +step:1356/1670 train_time:133106ms step_avg:98.16ms +step:1357/1670 train_time:133205ms step_avg:98.16ms +step:1358/1670 train_time:133303ms step_avg:98.16ms +step:1359/1670 train_time:133401ms step_avg:98.16ms +step:1360/1670 train_time:133500ms step_avg:98.16ms +step:1361/1670 train_time:133596ms step_avg:98.16ms +step:1362/1670 train_time:133695ms step_avg:98.16ms +step:1363/1670 train_time:133794ms step_avg:98.16ms +step:1364/1670 train_time:133892ms step_avg:98.16ms +step:1365/1670 train_time:133990ms step_avg:98.16ms +step:1366/1670 train_time:134089ms step_avg:98.16ms +step:1367/1670 train_time:134187ms step_avg:98.16ms +step:1368/1670 train_time:134286ms step_avg:98.16ms +step:1369/1670 train_time:134385ms step_avg:98.16ms +step:1370/1670 train_time:134484ms step_avg:98.16ms +step:1371/1670 train_time:134582ms step_avg:98.16ms +step:1372/1670 train_time:134681ms step_avg:98.16ms +step:1373/1670 train_time:134781ms 
step_avg:98.17ms +step:1374/1670 train_time:134879ms step_avg:98.16ms +step:1375/1670 train_time:134976ms step_avg:98.16ms +step:1375/1670 val_loss:3.3459 train_time:135073ms step_avg:98.23ms +step:1376/1670 train_time:135095ms step_avg:98.18ms +step:1377/1670 train_time:135179ms step_avg:98.17ms +step:1378/1670 train_time:135277ms step_avg:98.17ms +step:1379/1670 train_time:135375ms step_avg:98.17ms +step:1380/1670 train_time:135473ms step_avg:98.17ms +step:1381/1670 train_time:135570ms step_avg:98.17ms +step:1382/1670 train_time:135666ms step_avg:98.17ms +step:1383/1670 train_time:135764ms step_avg:98.17ms +step:1384/1670 train_time:135861ms step_avg:98.17ms +step:1385/1670 train_time:135958ms step_avg:98.16ms +step:1386/1670 train_time:136058ms step_avg:98.17ms +step:1387/1670 train_time:136159ms step_avg:98.17ms +step:1388/1670 train_time:136258ms step_avg:98.17ms +step:1389/1670 train_time:136357ms step_avg:98.17ms +step:1390/1670 train_time:136455ms step_avg:98.17ms +step:1391/1670 train_time:136552ms step_avg:98.17ms +step:1392/1670 train_time:136650ms step_avg:98.17ms +step:1393/1670 train_time:136747ms step_avg:98.17ms +step:1394/1670 train_time:136844ms step_avg:98.17ms +step:1395/1670 train_time:136941ms step_avg:98.17ms +step:1396/1670 train_time:137039ms step_avg:98.17ms +step:1397/1670 train_time:137138ms step_avg:98.17ms +step:1398/1670 train_time:137237ms step_avg:98.17ms +step:1399/1670 train_time:137335ms step_avg:98.17ms +step:1400/1670 train_time:137434ms step_avg:98.17ms +step:1401/1670 train_time:137532ms step_avg:98.17ms +step:1402/1670 train_time:137630ms step_avg:98.17ms +step:1403/1670 train_time:137728ms step_avg:98.17ms +step:1404/1670 train_time:137826ms step_avg:98.17ms +step:1405/1670 train_time:137923ms step_avg:98.17ms +step:1406/1670 train_time:138022ms step_avg:98.17ms +step:1407/1670 train_time:138120ms step_avg:98.17ms +step:1408/1670 train_time:138222ms step_avg:98.17ms +step:1409/1670 train_time:138321ms step_avg:98.17ms +step:1410/1670 train_time:138419ms step_avg:98.17ms +step:1411/1670 train_time:138517ms step_avg:98.17ms +step:1412/1670 train_time:138615ms step_avg:98.17ms +step:1413/1670 train_time:138712ms step_avg:98.17ms +step:1414/1670 train_time:138810ms step_avg:98.17ms +step:1415/1670 train_time:138909ms step_avg:98.17ms +step:1416/1670 train_time:139009ms step_avg:98.17ms +step:1417/1670 train_time:139110ms step_avg:98.17ms +step:1418/1670 train_time:139209ms step_avg:98.17ms +step:1419/1670 train_time:139310ms step_avg:98.17ms +step:1420/1670 train_time:139409ms step_avg:98.18ms +step:1421/1670 train_time:139508ms step_avg:98.18ms +step:1422/1670 train_time:139606ms step_avg:98.18ms +step:1423/1670 train_time:139703ms step_avg:98.17ms +step:1424/1670 train_time:139800ms step_avg:98.17ms +step:1425/1670 train_time:139897ms step_avg:98.17ms +step:1426/1670 train_time:139996ms step_avg:98.17ms +step:1427/1670 train_time:140096ms step_avg:98.18ms +step:1428/1670 train_time:140197ms step_avg:98.18ms +step:1429/1670 train_time:140296ms step_avg:98.18ms +step:1430/1670 train_time:140394ms step_avg:98.18ms +step:1431/1670 train_time:140493ms step_avg:98.18ms +step:1432/1670 train_time:140591ms step_avg:98.18ms +step:1433/1670 train_time:140689ms step_avg:98.18ms +step:1434/1670 train_time:140787ms step_avg:98.18ms +step:1435/1670 train_time:140885ms step_avg:98.18ms +step:1436/1670 train_time:140983ms step_avg:98.18ms +step:1437/1670 train_time:141082ms step_avg:98.18ms +step:1438/1670 train_time:141179ms step_avg:98.18ms +step:1439/1670 
train_time:141277ms step_avg:98.18ms +step:1440/1670 train_time:141375ms step_avg:98.18ms +step:1441/1670 train_time:141473ms step_avg:98.18ms +step:1442/1670 train_time:141572ms step_avg:98.18ms +step:1443/1670 train_time:141671ms step_avg:98.18ms +step:1444/1670 train_time:141770ms step_avg:98.18ms +step:1445/1670 train_time:141868ms step_avg:98.18ms +step:1446/1670 train_time:141966ms step_avg:98.18ms +step:1447/1670 train_time:142066ms step_avg:98.18ms +step:1448/1670 train_time:142167ms step_avg:98.18ms +step:1449/1670 train_time:142266ms step_avg:98.18ms +step:1450/1670 train_time:142366ms step_avg:98.18ms +step:1451/1670 train_time:142464ms step_avg:98.18ms +step:1452/1670 train_time:142561ms step_avg:98.18ms +step:1453/1670 train_time:142658ms step_avg:98.18ms +step:1454/1670 train_time:142756ms step_avg:98.18ms +step:1455/1670 train_time:142854ms step_avg:98.18ms +step:1456/1670 train_time:142953ms step_avg:98.18ms +step:1457/1670 train_time:143051ms step_avg:98.18ms +step:1458/1670 train_time:143150ms step_avg:98.18ms +step:1459/1670 train_time:143251ms step_avg:98.18ms +step:1460/1670 train_time:143351ms step_avg:98.19ms +step:1461/1670 train_time:143450ms step_avg:98.19ms +step:1462/1670 train_time:143548ms step_avg:98.19ms +step:1463/1670 train_time:143646ms step_avg:98.19ms +step:1464/1670 train_time:143745ms step_avg:98.19ms +step:1465/1670 train_time:143843ms step_avg:98.19ms +step:1466/1670 train_time:143941ms step_avg:98.19ms +step:1467/1670 train_time:144038ms step_avg:98.19ms +step:1468/1670 train_time:144136ms step_avg:98.19ms +step:1469/1670 train_time:144236ms step_avg:98.19ms +step:1470/1670 train_time:144335ms step_avg:98.19ms +step:1471/1670 train_time:144433ms step_avg:98.19ms +step:1472/1670 train_time:144533ms step_avg:98.19ms +step:1473/1670 train_time:144631ms step_avg:98.19ms +step:1474/1670 train_time:144730ms step_avg:98.19ms +step:1475/1670 train_time:144828ms step_avg:98.19ms +step:1476/1670 train_time:144927ms step_avg:98.19ms +step:1477/1670 train_time:145027ms step_avg:98.19ms +step:1478/1670 train_time:145125ms step_avg:98.19ms +step:1479/1670 train_time:145226ms step_avg:98.19ms +step:1480/1670 train_time:145325ms step_avg:98.19ms +step:1481/1670 train_time:145424ms step_avg:98.19ms +step:1482/1670 train_time:145522ms step_avg:98.19ms +step:1483/1670 train_time:145620ms step_avg:98.19ms +step:1484/1670 train_time:145717ms step_avg:98.19ms +step:1485/1670 train_time:146077ms step_avg:98.37ms +step:1486/1670 train_time:146178ms step_avg:98.37ms +step:1487/1670 train_time:146274ms step_avg:98.37ms +step:1488/1670 train_time:146370ms step_avg:98.37ms +step:1489/1670 train_time:146467ms step_avg:98.37ms +step:1490/1670 train_time:146565ms step_avg:98.37ms +step:1491/1670 train_time:146662ms step_avg:98.36ms +step:1492/1670 train_time:146759ms step_avg:98.36ms +step:1493/1670 train_time:146856ms step_avg:98.36ms +step:1494/1670 train_time:146954ms step_avg:98.36ms +step:1495/1670 train_time:147060ms step_avg:98.37ms +step:1496/1670 train_time:147160ms step_avg:98.37ms +step:1497/1670 train_time:147258ms step_avg:98.37ms +step:1498/1670 train_time:147356ms step_avg:98.37ms +step:1499/1670 train_time:147453ms step_avg:98.37ms +step:1500/1670 train_time:147551ms step_avg:98.37ms +step:1500/1670 val_loss:3.3130 train_time:147648ms step_avg:98.43ms +step:1501/1670 train_time:147670ms step_avg:98.38ms +step:1502/1670 train_time:147753ms step_avg:98.37ms +step:1503/1670 train_time:147853ms step_avg:98.37ms +step:1504/1670 train_time:147952ms step_avg:98.37ms 
+step:1505/1670 train_time:148049ms step_avg:98.37ms +step:1506/1670 train_time:148146ms step_avg:98.37ms +step:1507/1670 train_time:148244ms step_avg:98.37ms +step:1508/1670 train_time:148340ms step_avg:98.37ms +step:1509/1670 train_time:148438ms step_avg:98.37ms +step:1510/1670 train_time:148535ms step_avg:98.37ms +step:1511/1670 train_time:148635ms step_avg:98.37ms +step:1512/1670 train_time:148736ms step_avg:98.37ms +step:1513/1670 train_time:148836ms step_avg:98.37ms +step:1514/1670 train_time:148936ms step_avg:98.37ms +step:1515/1670 train_time:149036ms step_avg:98.37ms +step:1516/1670 train_time:149135ms step_avg:98.37ms +step:1517/1670 train_time:149233ms step_avg:98.37ms +step:1518/1670 train_time:149331ms step_avg:98.37ms +step:1519/1670 train_time:149428ms step_avg:98.37ms +step:1520/1670 train_time:149526ms step_avg:98.37ms +step:1521/1670 train_time:149625ms step_avg:98.37ms +step:1522/1670 train_time:149723ms step_avg:98.37ms +step:1523/1670 train_time:149823ms step_avg:98.37ms +step:1524/1670 train_time:149922ms step_avg:98.37ms +step:1525/1670 train_time:150021ms step_avg:98.37ms +step:1526/1670 train_time:150119ms step_avg:98.37ms +step:1527/1670 train_time:150218ms step_avg:98.37ms +step:1528/1670 train_time:150316ms step_avg:98.37ms +step:1529/1670 train_time:150414ms step_avg:98.37ms +step:1530/1670 train_time:150513ms step_avg:98.37ms +step:1531/1670 train_time:150614ms step_avg:98.38ms +step:1532/1670 train_time:150713ms step_avg:98.38ms +step:1533/1670 train_time:150813ms step_avg:98.38ms +step:1534/1670 train_time:150913ms step_avg:98.38ms +step:1535/1670 train_time:151013ms step_avg:98.38ms +step:1536/1670 train_time:151112ms step_avg:98.38ms +step:1537/1670 train_time:151211ms step_avg:98.38ms +step:1538/1670 train_time:151310ms step_avg:98.38ms +step:1539/1670 train_time:151407ms step_avg:98.38ms +step:1540/1670 train_time:151504ms step_avg:98.38ms +step:1541/1670 train_time:151602ms step_avg:98.38ms +step:1542/1670 train_time:151701ms step_avg:98.38ms +step:1543/1670 train_time:151801ms step_avg:98.38ms +step:1544/1670 train_time:151900ms step_avg:98.38ms +step:1545/1670 train_time:152000ms step_avg:98.38ms +step:1546/1670 train_time:152098ms step_avg:98.38ms +step:1547/1670 train_time:152198ms step_avg:98.38ms +step:1548/1670 train_time:152298ms step_avg:98.38ms +step:1549/1670 train_time:152397ms step_avg:98.38ms +step:1550/1670 train_time:152495ms step_avg:98.38ms +step:1551/1670 train_time:152594ms step_avg:98.38ms +step:1552/1670 train_time:152694ms step_avg:98.39ms +step:1553/1670 train_time:152793ms step_avg:98.39ms +step:1554/1670 train_time:152892ms step_avg:98.39ms +step:1555/1670 train_time:152991ms step_avg:98.39ms +step:1556/1670 train_time:153089ms step_avg:98.39ms +step:1557/1670 train_time:153187ms step_avg:98.39ms +step:1558/1670 train_time:153285ms step_avg:98.39ms +step:1559/1670 train_time:153382ms step_avg:98.38ms +step:1560/1670 train_time:153480ms step_avg:98.38ms +step:1561/1670 train_time:153579ms step_avg:98.38ms +step:1562/1670 train_time:153677ms step_avg:98.38ms +step:1563/1670 train_time:153776ms step_avg:98.39ms +step:1564/1670 train_time:153876ms step_avg:98.39ms +step:1565/1670 train_time:153974ms step_avg:98.39ms +step:1566/1670 train_time:154073ms step_avg:98.39ms +step:1567/1670 train_time:154171ms step_avg:98.39ms +step:1568/1670 train_time:154269ms step_avg:98.39ms +step:1569/1670 train_time:154369ms step_avg:98.39ms +step:1570/1670 train_time:154466ms step_avg:98.39ms +step:1571/1670 train_time:154563ms step_avg:98.39ms 
+step:1572/1670 train_time:154661ms step_avg:98.38ms +step:1573/1670 train_time:154759ms step_avg:98.38ms +step:1574/1670 train_time:154857ms step_avg:98.38ms +step:1575/1670 train_time:154956ms step_avg:98.38ms +step:1576/1670 train_time:155054ms step_avg:98.38ms +step:1577/1670 train_time:155152ms step_avg:98.38ms +step:1578/1670 train_time:155250ms step_avg:98.38ms +step:1579/1670 train_time:155350ms step_avg:98.39ms +step:1580/1670 train_time:155448ms step_avg:98.39ms +step:1581/1670 train_time:155546ms step_avg:98.38ms +step:1582/1670 train_time:155644ms step_avg:98.38ms +step:1583/1670 train_time:155742ms step_avg:98.38ms +step:1584/1670 train_time:155840ms step_avg:98.38ms +step:1585/1670 train_time:155938ms step_avg:98.38ms +step:1586/1670 train_time:156037ms step_avg:98.38ms +step:1587/1670 train_time:156136ms step_avg:98.38ms +step:1588/1670 train_time:156235ms step_avg:98.38ms +step:1589/1670 train_time:156333ms step_avg:98.38ms +step:1590/1670 train_time:156433ms step_avg:98.39ms +step:1591/1670 train_time:156532ms step_avg:98.39ms +step:1592/1670 train_time:156631ms step_avg:98.39ms +step:1593/1670 train_time:156730ms step_avg:98.39ms +step:1594/1670 train_time:156828ms step_avg:98.39ms +step:1595/1670 train_time:156927ms step_avg:98.39ms +step:1596/1670 train_time:157026ms step_avg:98.39ms +step:1597/1670 train_time:157123ms step_avg:98.39ms +step:1598/1670 train_time:157221ms step_avg:98.39ms +step:1599/1670 train_time:157319ms step_avg:98.39ms +step:1600/1670 train_time:157418ms step_avg:98.39ms +step:1601/1670 train_time:157517ms step_avg:98.39ms +step:1602/1670 train_time:157617ms step_avg:98.39ms +step:1603/1670 train_time:157717ms step_avg:98.39ms +step:1604/1670 train_time:157815ms step_avg:98.39ms +step:1605/1670 train_time:157913ms step_avg:98.39ms +step:1606/1670 train_time:158012ms step_avg:98.39ms +step:1607/1670 train_time:158112ms step_avg:98.39ms +step:1608/1670 train_time:158210ms step_avg:98.39ms +step:1609/1670 train_time:158308ms step_avg:98.39ms +step:1610/1670 train_time:158405ms step_avg:98.39ms +step:1611/1670 train_time:158503ms step_avg:98.39ms +step:1612/1670 train_time:158602ms step_avg:98.39ms +step:1613/1670 train_time:158701ms step_avg:98.39ms +step:1614/1670 train_time:158799ms step_avg:98.39ms +step:1615/1670 train_time:158899ms step_avg:98.39ms +step:1616/1670 train_time:158998ms step_avg:98.39ms +step:1617/1670 train_time:159097ms step_avg:98.39ms +step:1618/1670 train_time:159197ms step_avg:98.39ms +step:1619/1670 train_time:159297ms step_avg:98.39ms +step:1620/1670 train_time:159397ms step_avg:98.39ms +step:1621/1670 train_time:159496ms step_avg:98.39ms +step:1622/1670 train_time:159596ms step_avg:98.39ms +step:1623/1670 train_time:159695ms step_avg:98.40ms +step:1624/1670 train_time:159794ms step_avg:98.40ms +step:1625/1670 train_time:159893ms step_avg:98.40ms +step:1625/1670 val_loss:3.2865 train_time:159990ms step_avg:98.46ms +step:1626/1670 train_time:160012ms step_avg:98.41ms +step:1627/1670 train_time:160095ms step_avg:98.40ms +step:1628/1670 train_time:160195ms step_avg:98.40ms +step:1629/1670 train_time:160293ms step_avg:98.40ms +step:1630/1670 train_time:160389ms step_avg:98.40ms +step:1631/1670 train_time:160487ms step_avg:98.40ms +step:1632/1670 train_time:160585ms step_avg:98.40ms +step:1633/1670 train_time:160682ms step_avg:98.40ms +step:1634/1670 train_time:160781ms step_avg:98.40ms +step:1635/1670 train_time:160879ms step_avg:98.40ms +step:1636/1670 train_time:160979ms step_avg:98.40ms +step:1637/1670 train_time:161080ms 
step_avg:98.40ms +step:1638/1670 train_time:161181ms step_avg:98.40ms +step:1639/1670 train_time:161281ms step_avg:98.40ms +step:1640/1670 train_time:161380ms step_avg:98.40ms +step:1641/1670 train_time:161478ms step_avg:98.40ms +step:1642/1670 train_time:161575ms step_avg:98.40ms +step:1643/1670 train_time:161673ms step_avg:98.40ms +step:1644/1670 train_time:161770ms step_avg:98.40ms +step:1645/1670 train_time:161868ms step_avg:98.40ms +step:1646/1670 train_time:161967ms step_avg:98.40ms +step:1647/1670 train_time:162069ms step_avg:98.40ms +step:1648/1670 train_time:162170ms step_avg:98.40ms +step:1649/1670 train_time:162270ms step_avg:98.41ms +step:1650/1670 train_time:162369ms step_avg:98.41ms +step:1651/1670 train_time:162468ms step_avg:98.41ms +step:1652/1670 train_time:162567ms step_avg:98.41ms +step:1653/1670 train_time:162665ms step_avg:98.41ms +step:1654/1670 train_time:162763ms step_avg:98.41ms +step:1655/1670 train_time:162861ms step_avg:98.41ms +step:1656/1670 train_time:162959ms step_avg:98.41ms +step:1657/1670 train_time:163059ms step_avg:98.41ms +step:1658/1670 train_time:163160ms step_avg:98.41ms +step:1659/1670 train_time:163261ms step_avg:98.41ms +step:1660/1670 train_time:163361ms step_avg:98.41ms +step:1661/1670 train_time:163460ms step_avg:98.41ms +step:1662/1670 train_time:163559ms step_avg:98.41ms +step:1663/1670 train_time:163657ms step_avg:98.41ms +step:1664/1670 train_time:163754ms step_avg:98.41ms +step:1665/1670 train_time:163851ms step_avg:98.41ms +step:1666/1670 train_time:163949ms step_avg:98.41ms +step:1667/1670 train_time:164048ms step_avg:98.41ms +step:1668/1670 train_time:164146ms step_avg:98.41ms +step:1669/1670 train_time:164245ms step_avg:98.41ms +step:1670/1670 train_time:164344ms step_avg:98.41ms +step:1670/1670 val_loss:3.2784 train_time:164443ms step_avg:98.47ms +peak memory allocated: 34613 MiB reserved: 49116 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt b/records/050925_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt new file mode 100644 index 000000000..5d5d3441a --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + 
w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", 
"a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated 
from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): 
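+    # Note: each of the 5 iterations in the loop below applies the quintic
+    # Newton-Schulz step used by Muon (a reading of the ns_line_* helpers
+    # above, not a statement from the original record):
+    #     A = X @ X.mT
+    #     B = b * A + c * A @ A      # A is symmetric, so A @ A == A @ A.mT
+    #     X = a * X + B @ X
+    # i.e. X <- a*X + b*(X X^T)X + c*(X X^T)^2 X. A dense one-iteration
+    # sketch in plain PyTorch, assuming a 2D bfloat16 X, would be:
+    #     A = X @ X.mT
+    #     X = a * X + (b * A + c * A @ A) @ X
+    # With these coefficients the singular values of X are pushed toward ~1
+    # in a handful of iterations, approximately orthogonalizing G without an SVD.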
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
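+        # Note on the communication pattern (a reading of the collectives
+        # below, rather than anything stated in the original): parameters of
+        # equal shape are grouped, and within a group parameter base_i + rank
+        # is owned by this rank. reduce_scatter averages each gradient across
+        # ranks and delivers the result only to the owner; the owner applies
+        # weight decay, momentum, and Newton-Schulz locally; all_gather then
+        # re-replicates the updated weights to all ranks. grad_pad/params_pad
+        # append dummy tensors so every rank joins each collective even when
+        # the group size is not a multiple of world_size.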
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
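+# --- Editor's illustrative sketch, not part of the original run: a hypothetical,
+# never-called check that the four groups above cover model.parameters() exactly once.
+def _check_param_partition():
+    grouped = hidden_matrix_params + embed_params + scalar_params + head_params
+    assert len(grouped) == len({id(p) for p in grouped}), "some parameter landed in two groups"
+    assert {id(p) for p in grouped} == {id(p) for p in model.parameters()}, "some parameter is unoptimized"
+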
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
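+# --- Editor's illustrative sketch, not part of the original run: a hypothetical,
+# never-called probe of the schedules above. get_lr holds the multiplier at 1.0 for the
+# first ~55% of training, then decays linearly toward 0.1; get_ws steps the attention
+# window size through ws_schedule as training progresses.
+def _demo_schedules():
+    # with num_iterations=1670: {0: (1.0, 3), 800: (1.0, 7), 1200: (0.663, 11), 1600: (0.184, 11)}
+    return {step: (round(get_lr(step), 3), get_ws(step)) for step in (0, 800, 1200, 1600)}
+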
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Fri Sep 5 16:05:54 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 80638 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 80639 C /usr/bin/python3 610MiB | +| 0 N/A N/A 80640 C /usr/bin/python3 610MiB | +| 0 N/A N/A 80641 C /usr/bin/python3 610MiB | +| 0 N/A N/A 80642 C /usr/bin/python3 610MiB | +| 0 N/A N/A 80643 C /usr/bin/python3 610MiB | +| 0 N/A N/A 80644 C /usr/bin/python3 610MiB | +| 0 N/A N/A 80645 C /usr/bin/python3 610MiB | +| 1 N/A N/A 80639 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 80640 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 80641 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 80642 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 80643 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 80644 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 80645 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:428ms step_avg:428.15ms +step:2/1670 train_time:448ms step_avg:224.23ms +step:3/1670 train_time:522ms step_avg:174.06ms +step:4/1670 train_time:616ms step_avg:154.12ms +step:5/1670 train_time:711ms step_avg:142.12ms +step:6/1670 train_time:805ms step_avg:134.20ms +step:7/1670 train_time:901ms step_avg:128.72ms +step:8/1670 
train_time:996ms step_avg:124.49ms +step:9/1670 train_time:1090ms step_avg:121.14ms +step:10/1670 train_time:1186ms step_avg:118.56ms +step:11/1670 train_time:1281ms step_avg:116.45ms +step:12/1670 train_time:1380ms step_avg:114.98ms +step:13/1670 train_time:1479ms step_avg:113.78ms +step:14/1670 train_time:1577ms step_avg:112.64ms +step:15/1670 train_time:1673ms step_avg:111.52ms +step:16/1670 train_time:1768ms step_avg:110.48ms +step:17/1670 train_time:1864ms step_avg:109.62ms +step:18/1670 train_time:1959ms step_avg:108.84ms +step:19/1670 train_time:2055ms step_avg:108.17ms +step:20/1670 train_time:2150ms step_avg:107.51ms +step:21/1670 train_time:2245ms step_avg:106.93ms +step:22/1670 train_time:2342ms step_avg:106.46ms +step:23/1670 train_time:2439ms step_avg:106.05ms +step:24/1670 train_time:2536ms step_avg:105.68ms +step:25/1670 train_time:2633ms step_avg:105.31ms +step:26/1670 train_time:2729ms step_avg:104.95ms +step:27/1670 train_time:2825ms step_avg:104.62ms +step:28/1670 train_time:2921ms step_avg:104.31ms +step:29/1670 train_time:3016ms step_avg:104.00ms +step:30/1670 train_time:3111ms step_avg:103.71ms +step:31/1670 train_time:3206ms step_avg:103.43ms +step:32/1670 train_time:3302ms step_avg:103.20ms +step:33/1670 train_time:3399ms step_avg:102.99ms +step:34/1670 train_time:3496ms step_avg:102.82ms +step:35/1670 train_time:3593ms step_avg:102.65ms +step:36/1670 train_time:3689ms step_avg:102.46ms +step:37/1670 train_time:3784ms step_avg:102.27ms +step:38/1670 train_time:3881ms step_avg:102.14ms +step:39/1670 train_time:3976ms step_avg:101.96ms +step:40/1670 train_time:4072ms step_avg:101.81ms +step:41/1670 train_time:4168ms step_avg:101.65ms +step:42/1670 train_time:4263ms step_avg:101.51ms +step:43/1670 train_time:4359ms step_avg:101.37ms +step:44/1670 train_time:4455ms step_avg:101.24ms +step:45/1670 train_time:4550ms step_avg:101.12ms +step:46/1670 train_time:4647ms step_avg:101.02ms +step:47/1670 train_time:4744ms step_avg:100.93ms +step:48/1670 train_time:4840ms step_avg:100.84ms +step:49/1670 train_time:4937ms step_avg:100.76ms +step:50/1670 train_time:5033ms step_avg:100.66ms +step:51/1670 train_time:5128ms step_avg:100.55ms +step:52/1670 train_time:5224ms step_avg:100.47ms +step:53/1670 train_time:5320ms step_avg:100.38ms +step:54/1670 train_time:5417ms step_avg:100.31ms +step:55/1670 train_time:5511ms step_avg:100.21ms +step:56/1670 train_time:5608ms step_avg:100.14ms +step:57/1670 train_time:5704ms step_avg:100.08ms +step:58/1670 train_time:5801ms step_avg:100.01ms +step:59/1670 train_time:5897ms step_avg:99.95ms +step:60/1670 train_time:5993ms step_avg:99.89ms +step:61/1670 train_time:6088ms step_avg:99.81ms +step:62/1670 train_time:6184ms step_avg:99.74ms +step:63/1670 train_time:6280ms step_avg:99.69ms +step:64/1670 train_time:6376ms step_avg:99.63ms +step:65/1670 train_time:6473ms step_avg:99.58ms +step:66/1670 train_time:6568ms step_avg:99.52ms +step:67/1670 train_time:6663ms step_avg:99.45ms +step:68/1670 train_time:6760ms step_avg:99.40ms +step:69/1670 train_time:6856ms step_avg:99.36ms +step:70/1670 train_time:6951ms step_avg:99.30ms +step:71/1670 train_time:7046ms step_avg:99.24ms +step:72/1670 train_time:7142ms step_avg:99.19ms +step:73/1670 train_time:7238ms step_avg:99.14ms +step:74/1670 train_time:7333ms step_avg:99.10ms +step:75/1670 train_time:7430ms step_avg:99.06ms +step:76/1670 train_time:7526ms step_avg:99.02ms +step:77/1670 train_time:7622ms step_avg:98.99ms +step:78/1670 train_time:7718ms step_avg:98.95ms +step:79/1670 train_time:7814ms 
step_avg:98.91ms +step:80/1670 train_time:7909ms step_avg:98.87ms +step:81/1670 train_time:8005ms step_avg:98.83ms +step:82/1670 train_time:8102ms step_avg:98.80ms +step:83/1670 train_time:8198ms step_avg:98.78ms +step:84/1670 train_time:8294ms step_avg:98.73ms +step:85/1670 train_time:8388ms step_avg:98.69ms +step:86/1670 train_time:8484ms step_avg:98.65ms +step:87/1670 train_time:8580ms step_avg:98.62ms +step:88/1670 train_time:8676ms step_avg:98.59ms +step:89/1670 train_time:8771ms step_avg:98.55ms +step:90/1670 train_time:8867ms step_avg:98.52ms +step:91/1670 train_time:8963ms step_avg:98.49ms +step:92/1670 train_time:9060ms step_avg:98.48ms +step:93/1670 train_time:9156ms step_avg:98.45ms +step:94/1670 train_time:9252ms step_avg:98.42ms +step:95/1670 train_time:9346ms step_avg:98.38ms +step:96/1670 train_time:9442ms step_avg:98.36ms +step:97/1670 train_time:9539ms step_avg:98.34ms +step:98/1670 train_time:9635ms step_avg:98.32ms +step:99/1670 train_time:9731ms step_avg:98.29ms +step:100/1670 train_time:9826ms step_avg:98.26ms +step:101/1670 train_time:9922ms step_avg:98.24ms +step:102/1670 train_time:10017ms step_avg:98.21ms +step:103/1670 train_time:10114ms step_avg:98.19ms +step:104/1670 train_time:10209ms step_avg:98.17ms +step:105/1670 train_time:10305ms step_avg:98.14ms +step:106/1670 train_time:10402ms step_avg:98.13ms +step:107/1670 train_time:10500ms step_avg:98.13ms +step:108/1670 train_time:10595ms step_avg:98.10ms +step:109/1670 train_time:10691ms step_avg:98.08ms +step:110/1670 train_time:10787ms step_avg:98.06ms +step:111/1670 train_time:10883ms step_avg:98.04ms +step:112/1670 train_time:10979ms step_avg:98.03ms +step:113/1670 train_time:11075ms step_avg:98.01ms +step:114/1670 train_time:11171ms step_avg:97.99ms +step:115/1670 train_time:11266ms step_avg:97.97ms +step:116/1670 train_time:11362ms step_avg:97.95ms +step:117/1670 train_time:11458ms step_avg:97.93ms +step:118/1670 train_time:11554ms step_avg:97.91ms +step:119/1670 train_time:11649ms step_avg:97.89ms +step:120/1670 train_time:11744ms step_avg:97.87ms +step:121/1670 train_time:11840ms step_avg:97.85ms +step:122/1670 train_time:11937ms step_avg:97.84ms +step:123/1670 train_time:12033ms step_avg:97.83ms +step:124/1670 train_time:12130ms step_avg:97.82ms +step:125/1670 train_time:12225ms step_avg:97.80ms +step:125/1670 val_loss:4.3028 train_time:12321ms step_avg:98.57ms +step:126/1670 train_time:12342ms step_avg:97.95ms +step:127/1670 train_time:12429ms step_avg:97.86ms +step:128/1670 train_time:12529ms step_avg:97.89ms +step:129/1670 train_time:12625ms step_avg:97.87ms +step:130/1670 train_time:12720ms step_avg:97.85ms +step:131/1670 train_time:12816ms step_avg:97.83ms +step:132/1670 train_time:12911ms step_avg:97.81ms +step:133/1670 train_time:13005ms step_avg:97.79ms +step:134/1670 train_time:13100ms step_avg:97.76ms +step:135/1670 train_time:13196ms step_avg:97.75ms +step:136/1670 train_time:13290ms step_avg:97.72ms +step:137/1670 train_time:13389ms step_avg:97.73ms +step:138/1670 train_time:13486ms step_avg:97.72ms +step:139/1670 train_time:13582ms step_avg:97.71ms +step:140/1670 train_time:13679ms step_avg:97.71ms +step:141/1670 train_time:13775ms step_avg:97.69ms +step:142/1670 train_time:13870ms step_avg:97.68ms +step:143/1670 train_time:13964ms step_avg:97.65ms +step:144/1670 train_time:14059ms step_avg:97.63ms +step:145/1670 train_time:14155ms step_avg:97.62ms +step:146/1670 train_time:14250ms step_avg:97.60ms +step:147/1670 train_time:14346ms step_avg:97.59ms +step:148/1670 train_time:14442ms 
step_avg:97.58ms +step:149/1670 train_time:14539ms step_avg:97.58ms +step:150/1670 train_time:14637ms step_avg:97.58ms +step:151/1670 train_time:14733ms step_avg:97.57ms +step:152/1670 train_time:14830ms step_avg:97.56ms +step:153/1670 train_time:14925ms step_avg:97.55ms +step:154/1670 train_time:15019ms step_avg:97.53ms +step:155/1670 train_time:15116ms step_avg:97.52ms +step:156/1670 train_time:15211ms step_avg:97.51ms +step:157/1670 train_time:15307ms step_avg:97.49ms +step:158/1670 train_time:15403ms step_avg:97.48ms +step:159/1670 train_time:15498ms step_avg:97.47ms +step:160/1670 train_time:15595ms step_avg:97.47ms +step:161/1670 train_time:15691ms step_avg:97.46ms +step:162/1670 train_time:15788ms step_avg:97.46ms +step:163/1670 train_time:15884ms step_avg:97.45ms +step:164/1670 train_time:15979ms step_avg:97.44ms +step:165/1670 train_time:16074ms step_avg:97.42ms +step:166/1670 train_time:16170ms step_avg:97.41ms +step:167/1670 train_time:16265ms step_avg:97.40ms +step:168/1670 train_time:16360ms step_avg:97.38ms +step:169/1670 train_time:16457ms step_avg:97.38ms +step:170/1670 train_time:16553ms step_avg:97.37ms +step:171/1670 train_time:16650ms step_avg:97.37ms +step:172/1670 train_time:16746ms step_avg:97.36ms +step:173/1670 train_time:16841ms step_avg:97.35ms +step:174/1670 train_time:16937ms step_avg:97.34ms +step:175/1670 train_time:17033ms step_avg:97.33ms +step:176/1670 train_time:17128ms step_avg:97.32ms +step:177/1670 train_time:17223ms step_avg:97.30ms +step:178/1670 train_time:17318ms step_avg:97.29ms +step:179/1670 train_time:17415ms step_avg:97.29ms +step:180/1670 train_time:17511ms step_avg:97.29ms +step:181/1670 train_time:17607ms step_avg:97.28ms +step:182/1670 train_time:17703ms step_avg:97.27ms +step:183/1670 train_time:17798ms step_avg:97.26ms +step:184/1670 train_time:17895ms step_avg:97.25ms +step:185/1670 train_time:17991ms step_avg:97.25ms +step:186/1670 train_time:18087ms step_avg:97.24ms +step:187/1670 train_time:18182ms step_avg:97.23ms +step:188/1670 train_time:18278ms step_avg:97.22ms +step:189/1670 train_time:18373ms step_avg:97.21ms +step:190/1670 train_time:18469ms step_avg:97.21ms +step:191/1670 train_time:18565ms step_avg:97.20ms +step:192/1670 train_time:18661ms step_avg:97.19ms +step:193/1670 train_time:18757ms step_avg:97.18ms +step:194/1670 train_time:18853ms step_avg:97.18ms +step:195/1670 train_time:18949ms step_avg:97.17ms +step:196/1670 train_time:19044ms step_avg:97.16ms +step:197/1670 train_time:19140ms step_avg:97.16ms +step:198/1670 train_time:19236ms step_avg:97.15ms +step:199/1670 train_time:19332ms step_avg:97.14ms +step:200/1670 train_time:19428ms step_avg:97.14ms +step:201/1670 train_time:19523ms step_avg:97.13ms +step:202/1670 train_time:19619ms step_avg:97.12ms +step:203/1670 train_time:19715ms step_avg:97.12ms +step:204/1670 train_time:19810ms step_avg:97.11ms +step:205/1670 train_time:19906ms step_avg:97.10ms +step:206/1670 train_time:20001ms step_avg:97.09ms +step:207/1670 train_time:20097ms step_avg:97.09ms +step:208/1670 train_time:20193ms step_avg:97.08ms +step:209/1670 train_time:20288ms step_avg:97.07ms +step:210/1670 train_time:20383ms step_avg:97.06ms +step:211/1670 train_time:20479ms step_avg:97.06ms +step:212/1670 train_time:20575ms step_avg:97.05ms +step:213/1670 train_time:20938ms step_avg:98.30ms +step:214/1670 train_time:21010ms step_avg:98.18ms +step:215/1670 train_time:21104ms step_avg:98.16ms +step:216/1670 train_time:21199ms step_avg:98.14ms +step:217/1670 train_time:21294ms step_avg:98.13ms +step:218/1670 
train_time:21388ms step_avg:98.11ms +step:219/1670 train_time:21482ms step_avg:98.09ms +step:220/1670 train_time:21577ms step_avg:98.08ms +step:221/1670 train_time:21671ms step_avg:98.06ms +step:222/1670 train_time:21766ms step_avg:98.04ms +step:223/1670 train_time:21862ms step_avg:98.04ms +step:224/1670 train_time:21961ms step_avg:98.04ms +step:225/1670 train_time:22060ms step_avg:98.04ms +step:226/1670 train_time:22156ms step_avg:98.03ms +step:227/1670 train_time:22251ms step_avg:98.02ms +step:228/1670 train_time:22347ms step_avg:98.01ms +step:229/1670 train_time:22441ms step_avg:98.00ms +step:230/1670 train_time:22536ms step_avg:97.98ms +step:231/1670 train_time:22631ms step_avg:97.97ms +step:232/1670 train_time:22725ms step_avg:97.95ms +step:233/1670 train_time:22822ms step_avg:97.95ms +step:234/1670 train_time:22919ms step_avg:97.95ms +step:235/1670 train_time:23016ms step_avg:97.94ms +step:236/1670 train_time:23113ms step_avg:97.94ms +step:237/1670 train_time:23208ms step_avg:97.93ms +step:238/1670 train_time:23303ms step_avg:97.91ms +step:239/1670 train_time:23399ms step_avg:97.90ms +step:240/1670 train_time:23494ms step_avg:97.89ms +step:241/1670 train_time:23589ms step_avg:97.88ms +step:242/1670 train_time:23684ms step_avg:97.87ms +step:243/1670 train_time:23779ms step_avg:97.86ms +step:244/1670 train_time:23875ms step_avg:97.85ms +step:245/1670 train_time:23971ms step_avg:97.84ms +step:246/1670 train_time:24066ms step_avg:97.83ms +step:247/1670 train_time:24162ms step_avg:97.82ms +step:248/1670 train_time:24259ms step_avg:97.82ms +step:249/1670 train_time:24354ms step_avg:97.81ms +step:250/1670 train_time:24450ms step_avg:97.80ms +step:250/1670 val_loss:3.9672 train_time:24544ms step_avg:98.17ms +step:251/1670 train_time:24565ms step_avg:97.87ms +step:252/1670 train_time:24646ms step_avg:97.80ms +step:253/1670 train_time:24746ms step_avg:97.81ms +step:254/1670 train_time:24845ms step_avg:97.81ms +step:255/1670 train_time:24942ms step_avg:97.81ms +step:256/1670 train_time:25036ms step_avg:97.80ms +step:257/1670 train_time:25131ms step_avg:97.79ms +step:258/1670 train_time:25226ms step_avg:97.78ms +step:259/1670 train_time:25321ms step_avg:97.76ms +step:260/1670 train_time:25416ms step_avg:97.75ms +step:261/1670 train_time:25511ms step_avg:97.74ms +step:262/1670 train_time:25607ms step_avg:97.74ms +step:263/1670 train_time:25705ms step_avg:97.74ms +step:264/1670 train_time:25803ms step_avg:97.74ms +step:265/1670 train_time:25899ms step_avg:97.73ms +step:266/1670 train_time:25995ms step_avg:97.73ms +step:267/1670 train_time:26089ms step_avg:97.71ms +step:268/1670 train_time:26185ms step_avg:97.71ms +step:269/1670 train_time:26280ms step_avg:97.70ms +step:270/1670 train_time:26376ms step_avg:97.69ms +step:271/1670 train_time:26470ms step_avg:97.68ms +step:272/1670 train_time:26565ms step_avg:97.67ms +step:273/1670 train_time:26662ms step_avg:97.66ms +step:274/1670 train_time:26758ms step_avg:97.66ms +step:275/1670 train_time:26854ms step_avg:97.65ms +step:276/1670 train_time:26950ms step_avg:97.65ms +step:277/1670 train_time:27047ms step_avg:97.64ms +step:278/1670 train_time:27144ms step_avg:97.64ms +step:279/1670 train_time:27238ms step_avg:97.63ms +step:280/1670 train_time:27334ms step_avg:97.62ms +step:281/1670 train_time:27429ms step_avg:97.61ms +step:282/1670 train_time:27524ms step_avg:97.60ms +step:283/1670 train_time:27619ms step_avg:97.59ms +step:284/1670 train_time:27715ms step_avg:97.59ms +step:285/1670 train_time:27810ms step_avg:97.58ms +step:286/1670 train_time:27906ms 
step_avg:97.57ms +step:287/1670 train_time:28003ms step_avg:97.57ms +step:288/1670 train_time:28099ms step_avg:97.57ms +step:289/1670 train_time:28194ms step_avg:97.56ms +step:290/1670 train_time:28289ms step_avg:97.55ms +step:291/1670 train_time:28384ms step_avg:97.54ms +step:292/1670 train_time:28480ms step_avg:97.53ms +step:293/1670 train_time:28574ms step_avg:97.52ms +step:294/1670 train_time:28669ms step_avg:97.52ms +step:295/1670 train_time:28766ms step_avg:97.51ms +step:296/1670 train_time:28862ms step_avg:97.51ms +step:297/1670 train_time:28958ms step_avg:97.50ms +step:298/1670 train_time:29054ms step_avg:97.50ms +step:299/1670 train_time:29149ms step_avg:97.49ms +step:300/1670 train_time:29245ms step_avg:97.48ms +step:301/1670 train_time:29341ms step_avg:97.48ms +step:302/1670 train_time:29437ms step_avg:97.47ms +step:303/1670 train_time:29532ms step_avg:97.47ms +step:304/1670 train_time:29628ms step_avg:97.46ms +step:305/1670 train_time:29724ms step_avg:97.46ms +step:306/1670 train_time:29820ms step_avg:97.45ms +step:307/1670 train_time:29917ms step_avg:97.45ms +step:308/1670 train_time:30012ms step_avg:97.44ms +step:309/1670 train_time:30108ms step_avg:97.44ms +step:310/1670 train_time:30203ms step_avg:97.43ms +step:311/1670 train_time:30299ms step_avg:97.42ms +step:312/1670 train_time:30393ms step_avg:97.41ms +step:313/1670 train_time:30488ms step_avg:97.41ms +step:314/1670 train_time:30584ms step_avg:97.40ms +step:315/1670 train_time:30679ms step_avg:97.40ms +step:316/1670 train_time:30775ms step_avg:97.39ms +step:317/1670 train_time:30871ms step_avg:97.38ms +step:318/1670 train_time:30967ms step_avg:97.38ms +step:319/1670 train_time:31063ms step_avg:97.38ms +step:320/1670 train_time:31159ms step_avg:97.37ms +step:321/1670 train_time:31254ms step_avg:97.36ms +step:322/1670 train_time:31350ms step_avg:97.36ms +step:323/1670 train_time:31446ms step_avg:97.36ms +step:324/1670 train_time:31541ms step_avg:97.35ms +step:325/1670 train_time:31636ms step_avg:97.34ms +step:326/1670 train_time:31732ms step_avg:97.34ms +step:327/1670 train_time:31827ms step_avg:97.33ms +step:328/1670 train_time:31925ms step_avg:97.33ms +step:329/1670 train_time:32021ms step_avg:97.33ms +step:330/1670 train_time:32117ms step_avg:97.33ms +step:331/1670 train_time:32213ms step_avg:97.32ms +step:332/1670 train_time:32308ms step_avg:97.31ms +step:333/1670 train_time:32404ms step_avg:97.31ms +step:334/1670 train_time:32500ms step_avg:97.31ms +step:335/1670 train_time:32595ms step_avg:97.30ms +step:336/1670 train_time:32691ms step_avg:97.29ms +step:337/1670 train_time:32786ms step_avg:97.29ms +step:338/1670 train_time:32882ms step_avg:97.28ms +step:339/1670 train_time:32978ms step_avg:97.28ms +step:340/1670 train_time:33073ms step_avg:97.27ms +step:341/1670 train_time:33169ms step_avg:97.27ms +step:342/1670 train_time:33265ms step_avg:97.27ms +step:343/1670 train_time:33361ms step_avg:97.26ms +step:344/1670 train_time:33456ms step_avg:97.25ms +step:345/1670 train_time:33552ms step_avg:97.25ms +step:346/1670 train_time:33647ms step_avg:97.24ms +step:347/1670 train_time:33743ms step_avg:97.24ms +step:348/1670 train_time:33839ms step_avg:97.24ms +step:349/1670 train_time:33934ms step_avg:97.23ms +step:350/1670 train_time:34030ms step_avg:97.23ms +step:351/1670 train_time:34125ms step_avg:97.22ms +step:352/1670 train_time:34221ms step_avg:97.22ms +step:353/1670 train_time:34317ms step_avg:97.21ms +step:354/1670 train_time:34413ms step_avg:97.21ms +step:355/1670 train_time:34509ms step_avg:97.21ms +step:356/1670 
train_time:34604ms step_avg:97.20ms +step:357/1670 train_time:34700ms step_avg:97.20ms +step:358/1670 train_time:34796ms step_avg:97.20ms +step:359/1670 train_time:34891ms step_avg:97.19ms +step:360/1670 train_time:34987ms step_avg:97.19ms +step:361/1670 train_time:35083ms step_avg:97.18ms +step:362/1670 train_time:35179ms step_avg:97.18ms +step:363/1670 train_time:35275ms step_avg:97.18ms +step:364/1670 train_time:35371ms step_avg:97.17ms +step:365/1670 train_time:35466ms step_avg:97.17ms +step:366/1670 train_time:35562ms step_avg:97.17ms +step:367/1670 train_time:35658ms step_avg:97.16ms +step:368/1670 train_time:35754ms step_avg:97.16ms +step:369/1670 train_time:35850ms step_avg:97.15ms +step:370/1670 train_time:35946ms step_avg:97.15ms +step:371/1670 train_time:36042ms step_avg:97.15ms +step:372/1670 train_time:36138ms step_avg:97.14ms +step:373/1670 train_time:36233ms step_avg:97.14ms +step:374/1670 train_time:36329ms step_avg:97.14ms +step:375/1670 train_time:36425ms step_avg:97.13ms +step:375/1670 val_loss:3.8186 train_time:36520ms step_avg:97.39ms +step:376/1670 train_time:36542ms step_avg:97.19ms +step:377/1670 train_time:36624ms step_avg:97.15ms +step:378/1670 train_time:36725ms step_avg:97.16ms +step:379/1670 train_time:36820ms step_avg:97.15ms +step:380/1670 train_time:36915ms step_avg:97.14ms +step:381/1670 train_time:37009ms step_avg:97.14ms +step:382/1670 train_time:37104ms step_avg:97.13ms +step:383/1670 train_time:37198ms step_avg:97.12ms +step:384/1670 train_time:37293ms step_avg:97.12ms +step:385/1670 train_time:37388ms step_avg:97.11ms +step:386/1670 train_time:37485ms step_avg:97.11ms +step:387/1670 train_time:37582ms step_avg:97.11ms +step:388/1670 train_time:37680ms step_avg:97.11ms +step:389/1670 train_time:37776ms step_avg:97.11ms +step:390/1670 train_time:37872ms step_avg:97.11ms +step:391/1670 train_time:37968ms step_avg:97.11ms +step:392/1670 train_time:38063ms step_avg:97.10ms +step:393/1670 train_time:38158ms step_avg:97.09ms +step:394/1670 train_time:38253ms step_avg:97.09ms +step:395/1670 train_time:38349ms step_avg:97.09ms +step:396/1670 train_time:38444ms step_avg:97.08ms +step:397/1670 train_time:38540ms step_avg:97.08ms +step:398/1670 train_time:38636ms step_avg:97.08ms +step:399/1670 train_time:38733ms step_avg:97.08ms +step:400/1670 train_time:38830ms step_avg:97.07ms +step:401/1670 train_time:38926ms step_avg:97.07ms +step:402/1670 train_time:39021ms step_avg:97.07ms +step:403/1670 train_time:39116ms step_avg:97.06ms +step:404/1670 train_time:39211ms step_avg:97.06ms +step:405/1670 train_time:39307ms step_avg:97.05ms +step:406/1670 train_time:39402ms step_avg:97.05ms +step:407/1670 train_time:39498ms step_avg:97.05ms +step:408/1670 train_time:39594ms step_avg:97.04ms +step:409/1670 train_time:39690ms step_avg:97.04ms +step:410/1670 train_time:39787ms step_avg:97.04ms +step:411/1670 train_time:39882ms step_avg:97.04ms +step:412/1670 train_time:39977ms step_avg:97.03ms +step:413/1670 train_time:40073ms step_avg:97.03ms +step:414/1670 train_time:40168ms step_avg:97.03ms +step:415/1670 train_time:40264ms step_avg:97.02ms +step:416/1670 train_time:40359ms step_avg:97.02ms +step:417/1670 train_time:40455ms step_avg:97.01ms +step:418/1670 train_time:40552ms step_avg:97.02ms +step:419/1670 train_time:40649ms step_avg:97.02ms +step:420/1670 train_time:40746ms step_avg:97.02ms +step:421/1670 train_time:40843ms step_avg:97.01ms +step:422/1670 train_time:40938ms step_avg:97.01ms +step:423/1670 train_time:41033ms step_avg:97.01ms +step:424/1670 train_time:41129ms 
step_avg:97.00ms +step:425/1670 train_time:41421ms step_avg:97.46ms +step:426/1670 train_time:41530ms step_avg:97.49ms +step:427/1670 train_time:41624ms step_avg:97.48ms +step:428/1670 train_time:41718ms step_avg:97.47ms +step:429/1670 train_time:41813ms step_avg:97.47ms +step:430/1670 train_time:41908ms step_avg:97.46ms +step:431/1670 train_time:42003ms step_avg:97.45ms +step:432/1670 train_time:42097ms step_avg:97.45ms +step:433/1670 train_time:42192ms step_avg:97.44ms +step:434/1670 train_time:42287ms step_avg:97.44ms +step:435/1670 train_time:42384ms step_avg:97.43ms +step:436/1670 train_time:42484ms step_avg:97.44ms +step:437/1670 train_time:42581ms step_avg:97.44ms +step:438/1670 train_time:42676ms step_avg:97.43ms +step:439/1670 train_time:42772ms step_avg:97.43ms +step:440/1670 train_time:42868ms step_avg:97.43ms +step:441/1670 train_time:42963ms step_avg:97.42ms +step:442/1670 train_time:43058ms step_avg:97.42ms +step:443/1670 train_time:43152ms step_avg:97.41ms +step:444/1670 train_time:43249ms step_avg:97.41ms +step:445/1670 train_time:43344ms step_avg:97.40ms +step:446/1670 train_time:43439ms step_avg:97.40ms +step:447/1670 train_time:43536ms step_avg:97.40ms +step:448/1670 train_time:43633ms step_avg:97.40ms +step:449/1670 train_time:43729ms step_avg:97.39ms +step:450/1670 train_time:43825ms step_avg:97.39ms +step:451/1670 train_time:43920ms step_avg:97.38ms +step:452/1670 train_time:44015ms step_avg:97.38ms +step:453/1670 train_time:44110ms step_avg:97.37ms +step:454/1670 train_time:44205ms step_avg:97.37ms +step:455/1670 train_time:44300ms step_avg:97.36ms +step:456/1670 train_time:44396ms step_avg:97.36ms +step:457/1670 train_time:44492ms step_avg:97.36ms +step:458/1670 train_time:44589ms step_avg:97.36ms +step:459/1670 train_time:44685ms step_avg:97.35ms +step:460/1670 train_time:44781ms step_avg:97.35ms +step:461/1670 train_time:44876ms step_avg:97.35ms +step:462/1670 train_time:44972ms step_avg:97.34ms +step:463/1670 train_time:45067ms step_avg:97.34ms +step:464/1670 train_time:45163ms step_avg:97.33ms +step:465/1670 train_time:45257ms step_avg:97.33ms +step:466/1670 train_time:45353ms step_avg:97.32ms +step:467/1670 train_time:45450ms step_avg:97.32ms +step:468/1670 train_time:45547ms step_avg:97.32ms +step:469/1670 train_time:45643ms step_avg:97.32ms +step:470/1670 train_time:45738ms step_avg:97.31ms +step:471/1670 train_time:45834ms step_avg:97.31ms +step:472/1670 train_time:45929ms step_avg:97.31ms +step:473/1670 train_time:46025ms step_avg:97.30ms +step:474/1670 train_time:46120ms step_avg:97.30ms +step:475/1670 train_time:46215ms step_avg:97.30ms +step:476/1670 train_time:46310ms step_avg:97.29ms +step:477/1670 train_time:46406ms step_avg:97.29ms +step:478/1670 train_time:46501ms step_avg:97.28ms +step:479/1670 train_time:46597ms step_avg:97.28ms +step:480/1670 train_time:46693ms step_avg:97.28ms +step:481/1670 train_time:46790ms step_avg:97.28ms +step:482/1670 train_time:46886ms step_avg:97.27ms +step:483/1670 train_time:46981ms step_avg:97.27ms +step:484/1670 train_time:47077ms step_avg:97.27ms +step:485/1670 train_time:47173ms step_avg:97.26ms +step:486/1670 train_time:47268ms step_avg:97.26ms +step:487/1670 train_time:47364ms step_avg:97.26ms +step:488/1670 train_time:47459ms step_avg:97.25ms +step:489/1670 train_time:47555ms step_avg:97.25ms +step:490/1670 train_time:47652ms step_avg:97.25ms +step:491/1670 train_time:47748ms step_avg:97.25ms +step:492/1670 train_time:47844ms step_avg:97.24ms +step:493/1670 train_time:47940ms step_avg:97.24ms +step:494/1670 
train_time:48035ms step_avg:97.24ms +step:495/1670 train_time:48131ms step_avg:97.24ms +step:496/1670 train_time:48227ms step_avg:97.23ms +step:497/1670 train_time:48323ms step_avg:97.23ms +step:498/1670 train_time:48419ms step_avg:97.23ms +step:499/1670 train_time:48514ms step_avg:97.22ms +step:500/1670 train_time:48609ms step_avg:97.22ms +step:500/1670 val_loss:3.7198 train_time:48705ms step_avg:97.41ms +step:501/1670 train_time:48726ms step_avg:97.26ms +step:502/1670 train_time:48808ms step_avg:97.23ms +step:503/1670 train_time:48906ms step_avg:97.23ms +step:504/1670 train_time:49002ms step_avg:97.23ms +step:505/1670 train_time:49098ms step_avg:97.22ms +step:506/1670 train_time:49193ms step_avg:97.22ms +step:507/1670 train_time:49288ms step_avg:97.21ms +step:508/1670 train_time:49383ms step_avg:97.21ms +step:509/1670 train_time:49478ms step_avg:97.21ms +step:510/1670 train_time:49573ms step_avg:97.20ms +step:511/1670 train_time:49669ms step_avg:97.20ms +step:512/1670 train_time:49766ms step_avg:97.20ms +step:513/1670 train_time:49863ms step_avg:97.20ms +step:514/1670 train_time:49960ms step_avg:97.20ms +step:515/1670 train_time:50056ms step_avg:97.20ms +step:516/1670 train_time:50152ms step_avg:97.19ms +step:517/1670 train_time:50247ms step_avg:97.19ms +step:518/1670 train_time:50342ms step_avg:97.19ms +step:519/1670 train_time:50438ms step_avg:97.18ms +step:520/1670 train_time:50534ms step_avg:97.18ms +step:521/1670 train_time:50630ms step_avg:97.18ms +step:522/1670 train_time:50726ms step_avg:97.18ms +step:523/1670 train_time:50823ms step_avg:97.18ms +step:524/1670 train_time:50919ms step_avg:97.17ms +step:525/1670 train_time:51015ms step_avg:97.17ms +step:526/1670 train_time:51112ms step_avg:97.17ms +step:527/1670 train_time:51206ms step_avg:97.17ms +step:528/1670 train_time:51301ms step_avg:97.16ms +step:529/1670 train_time:51397ms step_avg:97.16ms +step:530/1670 train_time:51493ms step_avg:97.16ms +step:531/1670 train_time:51589ms step_avg:97.15ms +step:532/1670 train_time:51684ms step_avg:97.15ms +step:533/1670 train_time:51781ms step_avg:97.15ms +step:534/1670 train_time:51877ms step_avg:97.15ms +step:535/1670 train_time:51973ms step_avg:97.15ms +step:536/1670 train_time:52069ms step_avg:97.14ms +step:537/1670 train_time:52165ms step_avg:97.14ms +step:538/1670 train_time:52260ms step_avg:97.14ms +step:539/1670 train_time:52355ms step_avg:97.13ms +step:540/1670 train_time:52450ms step_avg:97.13ms +step:541/1670 train_time:52546ms step_avg:97.13ms +step:542/1670 train_time:52641ms step_avg:97.12ms +step:543/1670 train_time:52737ms step_avg:97.12ms +step:544/1670 train_time:52833ms step_avg:97.12ms +step:545/1670 train_time:52930ms step_avg:97.12ms +step:546/1670 train_time:53025ms step_avg:97.12ms +step:547/1670 train_time:53121ms step_avg:97.11ms +step:548/1670 train_time:53218ms step_avg:97.11ms +step:549/1670 train_time:53313ms step_avg:97.11ms +step:550/1670 train_time:53408ms step_avg:97.11ms +step:551/1670 train_time:53504ms step_avg:97.10ms +step:552/1670 train_time:53600ms step_avg:97.10ms +step:553/1670 train_time:53697ms step_avg:97.10ms +step:554/1670 train_time:53793ms step_avg:97.10ms +step:555/1670 train_time:53888ms step_avg:97.10ms +step:556/1670 train_time:53984ms step_avg:97.09ms +step:557/1670 train_time:54080ms step_avg:97.09ms +step:558/1670 train_time:54176ms step_avg:97.09ms +step:559/1670 train_time:54273ms step_avg:97.09ms +step:560/1670 train_time:54370ms step_avg:97.09ms +step:561/1670 train_time:54466ms step_avg:97.09ms +step:562/1670 train_time:54563ms 
step_avg:97.09ms +step:563/1670 train_time:54660ms step_avg:97.09ms +step:564/1670 train_time:54757ms step_avg:97.09ms +step:565/1670 train_time:54856ms step_avg:97.09ms +step:566/1670 train_time:54953ms step_avg:97.09ms +step:567/1670 train_time:55050ms step_avg:97.09ms +step:568/1670 train_time:55146ms step_avg:97.09ms +step:569/1670 train_time:55243ms step_avg:97.09ms +step:570/1670 train_time:55340ms step_avg:97.09ms +step:571/1670 train_time:55438ms step_avg:97.09ms +step:572/1670 train_time:55535ms step_avg:97.09ms +step:573/1670 train_time:55632ms step_avg:97.09ms +step:574/1670 train_time:55729ms step_avg:97.09ms +step:575/1670 train_time:55827ms step_avg:97.09ms +step:576/1670 train_time:55926ms step_avg:97.09ms +step:577/1670 train_time:56024ms step_avg:97.10ms +step:578/1670 train_time:56121ms step_avg:97.09ms +step:579/1670 train_time:56217ms step_avg:97.09ms +step:580/1670 train_time:56314ms step_avg:97.09ms +step:581/1670 train_time:56411ms step_avg:97.09ms +step:582/1670 train_time:56507ms step_avg:97.09ms +step:583/1670 train_time:56604ms step_avg:97.09ms +step:584/1670 train_time:56702ms step_avg:97.09ms +step:585/1670 train_time:56801ms step_avg:97.10ms +step:586/1670 train_time:56898ms step_avg:97.10ms +step:587/1670 train_time:56996ms step_avg:97.10ms +step:588/1670 train_time:57092ms step_avg:97.10ms +step:589/1670 train_time:57188ms step_avg:97.09ms +step:590/1670 train_time:57285ms step_avg:97.09ms +step:591/1670 train_time:57383ms step_avg:97.09ms +step:592/1670 train_time:57479ms step_avg:97.09ms +step:593/1670 train_time:57578ms step_avg:97.10ms +step:594/1670 train_time:57676ms step_avg:97.10ms +step:595/1670 train_time:57773ms step_avg:97.10ms +step:596/1670 train_time:57870ms step_avg:97.10ms +step:597/1670 train_time:57966ms step_avg:97.10ms +step:598/1670 train_time:58064ms step_avg:97.10ms +step:599/1670 train_time:58161ms step_avg:97.10ms +step:600/1670 train_time:58258ms step_avg:97.10ms +step:601/1670 train_time:58356ms step_avg:97.10ms +step:602/1670 train_time:58452ms step_avg:97.10ms +step:603/1670 train_time:58549ms step_avg:97.10ms +step:604/1670 train_time:58646ms step_avg:97.10ms +step:605/1670 train_time:58744ms step_avg:97.10ms +step:606/1670 train_time:58841ms step_avg:97.10ms +step:607/1670 train_time:58940ms step_avg:97.10ms +step:608/1670 train_time:59038ms step_avg:97.10ms +step:609/1670 train_time:59135ms step_avg:97.10ms +step:610/1670 train_time:59231ms step_avg:97.10ms +step:611/1670 train_time:59328ms step_avg:97.10ms +step:612/1670 train_time:59426ms step_avg:97.10ms +step:613/1670 train_time:59523ms step_avg:97.10ms +step:614/1670 train_time:59621ms step_avg:97.10ms +step:615/1670 train_time:59718ms step_avg:97.10ms +step:616/1670 train_time:59816ms step_avg:97.10ms +step:617/1670 train_time:59914ms step_avg:97.10ms +step:618/1670 train_time:60010ms step_avg:97.10ms +step:619/1670 train_time:60107ms step_avg:97.10ms +step:620/1670 train_time:60205ms step_avg:97.10ms +step:621/1670 train_time:60302ms step_avg:97.10ms +step:622/1670 train_time:60401ms step_avg:97.11ms +step:623/1670 train_time:60498ms step_avg:97.11ms +step:624/1670 train_time:60595ms step_avg:97.11ms +step:625/1670 train_time:60692ms step_avg:97.11ms +step:625/1670 val_loss:3.6177 train_time:60788ms step_avg:97.26ms +step:626/1670 train_time:60810ms step_avg:97.14ms +step:627/1670 train_time:60894ms step_avg:97.12ms +step:628/1670 train_time:60990ms step_avg:97.12ms +step:629/1670 train_time:61086ms step_avg:97.12ms +step:630/1670 train_time:61182ms step_avg:97.12ms 
+step:631/1670 train_time:61279ms step_avg:97.11ms +step:632/1670 train_time:61375ms step_avg:97.11ms +step:633/1670 train_time:61471ms step_avg:97.11ms +step:634/1670 train_time:61567ms step_avg:97.11ms +step:635/1670 train_time:61663ms step_avg:97.11ms +step:636/1670 train_time:61765ms step_avg:97.11ms +step:637/1670 train_time:61866ms step_avg:97.12ms +step:638/1670 train_time:61963ms step_avg:97.12ms +step:639/1670 train_time:62351ms step_avg:97.58ms +step:640/1670 train_time:62426ms step_avg:97.54ms +step:641/1670 train_time:62522ms step_avg:97.54ms +step:642/1670 train_time:62618ms step_avg:97.54ms +step:643/1670 train_time:62714ms step_avg:97.53ms +step:644/1670 train_time:62810ms step_avg:97.53ms +step:645/1670 train_time:62906ms step_avg:97.53ms +step:646/1670 train_time:63001ms step_avg:97.53ms +step:647/1670 train_time:63098ms step_avg:97.52ms +step:648/1670 train_time:63195ms step_avg:97.52ms +step:649/1670 train_time:63299ms step_avg:97.53ms +step:650/1670 train_time:63401ms step_avg:97.54ms +step:651/1670 train_time:63500ms step_avg:97.54ms +step:652/1670 train_time:63597ms step_avg:97.54ms +step:653/1670 train_time:63694ms step_avg:97.54ms +step:654/1670 train_time:63790ms step_avg:97.54ms +step:655/1670 train_time:63885ms step_avg:97.53ms +step:656/1670 train_time:63981ms step_avg:97.53ms +step:657/1670 train_time:64078ms step_avg:97.53ms +step:658/1670 train_time:64174ms step_avg:97.53ms +step:659/1670 train_time:64272ms step_avg:97.53ms +step:660/1670 train_time:64372ms step_avg:97.53ms +step:661/1670 train_time:64469ms step_avg:97.53ms +step:662/1670 train_time:64566ms step_avg:97.53ms +step:663/1670 train_time:64664ms step_avg:97.53ms +step:664/1670 train_time:64762ms step_avg:97.53ms +step:665/1670 train_time:64859ms step_avg:97.53ms +step:666/1670 train_time:64955ms step_avg:97.53ms +step:667/1670 train_time:65051ms step_avg:97.53ms +step:668/1670 train_time:65147ms step_avg:97.53ms +step:669/1670 train_time:65244ms step_avg:97.53ms +step:670/1670 train_time:65342ms step_avg:97.53ms +step:671/1670 train_time:65441ms step_avg:97.53ms +step:672/1670 train_time:65540ms step_avg:97.53ms +step:673/1670 train_time:65639ms step_avg:97.53ms +step:674/1670 train_time:65737ms step_avg:97.53ms +step:675/1670 train_time:65835ms step_avg:97.53ms +step:676/1670 train_time:65931ms step_avg:97.53ms +step:677/1670 train_time:66027ms step_avg:97.53ms +step:678/1670 train_time:66124ms step_avg:97.53ms +step:679/1670 train_time:66221ms step_avg:97.53ms +step:680/1670 train_time:66318ms step_avg:97.53ms +step:681/1670 train_time:66416ms step_avg:97.53ms +step:682/1670 train_time:66513ms step_avg:97.53ms +step:683/1670 train_time:66610ms step_avg:97.53ms +step:684/1670 train_time:66707ms step_avg:97.52ms +step:685/1670 train_time:66805ms step_avg:97.53ms +step:686/1670 train_time:66902ms step_avg:97.52ms +step:687/1670 train_time:66999ms step_avg:97.52ms +step:688/1670 train_time:67097ms step_avg:97.52ms +step:689/1670 train_time:67194ms step_avg:97.52ms +step:690/1670 train_time:67290ms step_avg:97.52ms +step:691/1670 train_time:67387ms step_avg:97.52ms +step:692/1670 train_time:67485ms step_avg:97.52ms +step:693/1670 train_time:67583ms step_avg:97.52ms +step:694/1670 train_time:67681ms step_avg:97.52ms +step:695/1670 train_time:67779ms step_avg:97.52ms +step:696/1670 train_time:67877ms step_avg:97.52ms +step:697/1670 train_time:67974ms step_avg:97.52ms +step:698/1670 train_time:68071ms step_avg:97.52ms +step:699/1670 train_time:68167ms step_avg:97.52ms +step:700/1670 train_time:68264ms 
step_avg:97.52ms +step:701/1670 train_time:68361ms step_avg:97.52ms +step:702/1670 train_time:68459ms step_avg:97.52ms +step:703/1670 train_time:68557ms step_avg:97.52ms +step:704/1670 train_time:68654ms step_avg:97.52ms +step:705/1670 train_time:68750ms step_avg:97.52ms +step:706/1670 train_time:68846ms step_avg:97.52ms +step:707/1670 train_time:68944ms step_avg:97.52ms +step:708/1670 train_time:69042ms step_avg:97.52ms +step:709/1670 train_time:69141ms step_avg:97.52ms +step:710/1670 train_time:69238ms step_avg:97.52ms +step:711/1670 train_time:69336ms step_avg:97.52ms +step:712/1670 train_time:69434ms step_avg:97.52ms +step:713/1670 train_time:69531ms step_avg:97.52ms +step:714/1670 train_time:69628ms step_avg:97.52ms +step:715/1670 train_time:69725ms step_avg:97.52ms +step:716/1670 train_time:69821ms step_avg:97.52ms +step:717/1670 train_time:69918ms step_avg:97.52ms +step:718/1670 train_time:70016ms step_avg:97.51ms +step:719/1670 train_time:70112ms step_avg:97.51ms +step:720/1670 train_time:70208ms step_avg:97.51ms +step:721/1670 train_time:70305ms step_avg:97.51ms +step:722/1670 train_time:70402ms step_avg:97.51ms +step:723/1670 train_time:70500ms step_avg:97.51ms +step:724/1670 train_time:70598ms step_avg:97.51ms +step:725/1670 train_time:70697ms step_avg:97.51ms +step:726/1670 train_time:70794ms step_avg:97.51ms +step:727/1670 train_time:70891ms step_avg:97.51ms +step:728/1670 train_time:70987ms step_avg:97.51ms +step:729/1670 train_time:71085ms step_avg:97.51ms +step:730/1670 train_time:71181ms step_avg:97.51ms +step:731/1670 train_time:71279ms step_avg:97.51ms +step:732/1670 train_time:71376ms step_avg:97.51ms +step:733/1670 train_time:71474ms step_avg:97.51ms +step:734/1670 train_time:71572ms step_avg:97.51ms +step:735/1670 train_time:71669ms step_avg:97.51ms +step:736/1670 train_time:71765ms step_avg:97.51ms +step:737/1670 train_time:71863ms step_avg:97.51ms +step:738/1670 train_time:71960ms step_avg:97.51ms +step:739/1670 train_time:72058ms step_avg:97.51ms +step:740/1670 train_time:72155ms step_avg:97.51ms +step:741/1670 train_time:72251ms step_avg:97.50ms +step:742/1670 train_time:72347ms step_avg:97.50ms +step:743/1670 train_time:72444ms step_avg:97.50ms +step:744/1670 train_time:72542ms step_avg:97.50ms +step:745/1670 train_time:72641ms step_avg:97.50ms +step:746/1670 train_time:72740ms step_avg:97.51ms +step:747/1670 train_time:72838ms step_avg:97.51ms +step:748/1670 train_time:72935ms step_avg:97.51ms +step:749/1670 train_time:73033ms step_avg:97.51ms +step:750/1670 train_time:73129ms step_avg:97.50ms +step:750/1670 val_loss:3.5642 train_time:73224ms step_avg:97.63ms +step:751/1670 train_time:73245ms step_avg:97.53ms +step:752/1670 train_time:73328ms step_avg:97.51ms +step:753/1670 train_time:73427ms step_avg:97.51ms +step:754/1670 train_time:73524ms step_avg:97.51ms +step:755/1670 train_time:73620ms step_avg:97.51ms +step:756/1670 train_time:73717ms step_avg:97.51ms +step:757/1670 train_time:73813ms step_avg:97.51ms +step:758/1670 train_time:73910ms step_avg:97.51ms +step:759/1670 train_time:74005ms step_avg:97.50ms +step:760/1670 train_time:74102ms step_avg:97.50ms +step:761/1670 train_time:74200ms step_avg:97.50ms +step:762/1670 train_time:74301ms step_avg:97.51ms +step:763/1670 train_time:74402ms step_avg:97.51ms +step:764/1670 train_time:74501ms step_avg:97.51ms +step:765/1670 train_time:74598ms step_avg:97.51ms +step:766/1670 train_time:74696ms step_avg:97.51ms +step:767/1670 train_time:74792ms step_avg:97.51ms +step:768/1670 train_time:74889ms step_avg:97.51ms 
+step:769/1670 train_time:74986ms step_avg:97.51ms +step:770/1670 train_time:75083ms step_avg:97.51ms +step:771/1670 train_time:75180ms step_avg:97.51ms +step:772/1670 train_time:75280ms step_avg:97.51ms +step:773/1670 train_time:75380ms step_avg:97.52ms +step:774/1670 train_time:75479ms step_avg:97.52ms +step:775/1670 train_time:75577ms step_avg:97.52ms +step:776/1670 train_time:75674ms step_avg:97.52ms +step:777/1670 train_time:75770ms step_avg:97.52ms +step:778/1670 train_time:75866ms step_avg:97.51ms +step:779/1670 train_time:75963ms step_avg:97.51ms +step:780/1670 train_time:76060ms step_avg:97.51ms +step:781/1670 train_time:76156ms step_avg:97.51ms +step:782/1670 train_time:76254ms step_avg:97.51ms +step:783/1670 train_time:76353ms step_avg:97.51ms +step:784/1670 train_time:76452ms step_avg:97.51ms +step:785/1670 train_time:76549ms step_avg:97.51ms +step:786/1670 train_time:76645ms step_avg:97.51ms +step:787/1670 train_time:76743ms step_avg:97.51ms +step:788/1670 train_time:76840ms step_avg:97.51ms +step:789/1670 train_time:76937ms step_avg:97.51ms +step:790/1670 train_time:77034ms step_avg:97.51ms +step:791/1670 train_time:77131ms step_avg:97.51ms +step:792/1670 train_time:77227ms step_avg:97.51ms +step:793/1670 train_time:77324ms step_avg:97.51ms +step:794/1670 train_time:77422ms step_avg:97.51ms +step:795/1670 train_time:77521ms step_avg:97.51ms +step:796/1670 train_time:77619ms step_avg:97.51ms +step:797/1670 train_time:77717ms step_avg:97.51ms +step:798/1670 train_time:77816ms step_avg:97.51ms +step:799/1670 train_time:77913ms step_avg:97.51ms +step:800/1670 train_time:78010ms step_avg:97.51ms +step:801/1670 train_time:78106ms step_avg:97.51ms +step:802/1670 train_time:78203ms step_avg:97.51ms +step:803/1670 train_time:78300ms step_avg:97.51ms +step:804/1670 train_time:78398ms step_avg:97.51ms +step:805/1670 train_time:78497ms step_avg:97.51ms +step:806/1670 train_time:78595ms step_avg:97.51ms +step:807/1670 train_time:78693ms step_avg:97.51ms +step:808/1670 train_time:78790ms step_avg:97.51ms +step:809/1670 train_time:78886ms step_avg:97.51ms +step:810/1670 train_time:78983ms step_avg:97.51ms +step:811/1670 train_time:79080ms step_avg:97.51ms +step:812/1670 train_time:79177ms step_avg:97.51ms +step:813/1670 train_time:79274ms step_avg:97.51ms +step:814/1670 train_time:79372ms step_avg:97.51ms +step:815/1670 train_time:79469ms step_avg:97.51ms +step:816/1670 train_time:79566ms step_avg:97.51ms +step:817/1670 train_time:79663ms step_avg:97.51ms +step:818/1670 train_time:79762ms step_avg:97.51ms +step:819/1670 train_time:79860ms step_avg:97.51ms +step:820/1670 train_time:79958ms step_avg:97.51ms +step:821/1670 train_time:80055ms step_avg:97.51ms +step:822/1670 train_time:80151ms step_avg:97.51ms +step:823/1670 train_time:80248ms step_avg:97.51ms +step:824/1670 train_time:80345ms step_avg:97.51ms +step:825/1670 train_time:80443ms step_avg:97.51ms +step:826/1670 train_time:80540ms step_avg:97.51ms +step:827/1670 train_time:80637ms step_avg:97.51ms +step:828/1670 train_time:80735ms step_avg:97.51ms +step:829/1670 train_time:80832ms step_avg:97.51ms +step:830/1670 train_time:80929ms step_avg:97.50ms +step:831/1670 train_time:81025ms step_avg:97.50ms +step:832/1670 train_time:81122ms step_avg:97.50ms +step:833/1670 train_time:81219ms step_avg:97.50ms +step:834/1670 train_time:81318ms step_avg:97.50ms +step:835/1670 train_time:81416ms step_avg:97.50ms +step:836/1670 train_time:81513ms step_avg:97.50ms +step:837/1670 train_time:81612ms step_avg:97.51ms +step:838/1670 train_time:81710ms 
step_avg:97.51ms +step:839/1670 train_time:81807ms step_avg:97.51ms +step:840/1670 train_time:81903ms step_avg:97.50ms +step:841/1670 train_time:82001ms step_avg:97.50ms +step:842/1670 train_time:82098ms step_avg:97.50ms +step:843/1670 train_time:82196ms step_avg:97.50ms +step:844/1670 train_time:82295ms step_avg:97.51ms +step:845/1670 train_time:82393ms step_avg:97.51ms +step:846/1670 train_time:82490ms step_avg:97.51ms +step:847/1670 train_time:82587ms step_avg:97.51ms +step:848/1670 train_time:82684ms step_avg:97.50ms +step:849/1670 train_time:82781ms step_avg:97.50ms +step:850/1670 train_time:82879ms step_avg:97.50ms +step:851/1670 train_time:83243ms step_avg:97.82ms +step:852/1670 train_time:83317ms step_avg:97.79ms +step:853/1670 train_time:83412ms step_avg:97.79ms +step:854/1670 train_time:83508ms step_avg:97.78ms +step:855/1670 train_time:83603ms step_avg:97.78ms +step:856/1670 train_time:83700ms step_avg:97.78ms +step:857/1670 train_time:83797ms step_avg:97.78ms +step:858/1670 train_time:83893ms step_avg:97.78ms +step:859/1670 train_time:83990ms step_avg:97.78ms +step:860/1670 train_time:84086ms step_avg:97.77ms +step:861/1670 train_time:84186ms step_avg:97.78ms +step:862/1670 train_time:84286ms step_avg:97.78ms +step:863/1670 train_time:84384ms step_avg:97.78ms +step:864/1670 train_time:84481ms step_avg:97.78ms +step:865/1670 train_time:84579ms step_avg:97.78ms +step:866/1670 train_time:84676ms step_avg:97.78ms +step:867/1670 train_time:84773ms step_avg:97.78ms +step:868/1670 train_time:84869ms step_avg:97.78ms +step:869/1670 train_time:84964ms step_avg:97.77ms +step:870/1670 train_time:85061ms step_avg:97.77ms +step:871/1670 train_time:85159ms step_avg:97.77ms +step:872/1670 train_time:85259ms step_avg:97.77ms +step:873/1670 train_time:85358ms step_avg:97.78ms +step:874/1670 train_time:85457ms step_avg:97.78ms +step:875/1670 train_time:85555ms step_avg:97.78ms +step:875/1670 val_loss:3.5231 train_time:85652ms step_avg:97.89ms +step:876/1670 train_time:85674ms step_avg:97.80ms +step:877/1670 train_time:85756ms step_avg:97.78ms +step:878/1670 train_time:85857ms step_avg:97.79ms +step:879/1670 train_time:85955ms step_avg:97.79ms +step:880/1670 train_time:86051ms step_avg:97.79ms +step:881/1670 train_time:86149ms step_avg:97.79ms +step:882/1670 train_time:86244ms step_avg:97.78ms +step:883/1670 train_time:86340ms step_avg:97.78ms +step:884/1670 train_time:86436ms step_avg:97.78ms +step:885/1670 train_time:86533ms step_avg:97.78ms +step:886/1670 train_time:86631ms step_avg:97.78ms +step:887/1670 train_time:86730ms step_avg:97.78ms +step:888/1670 train_time:86829ms step_avg:97.78ms +step:889/1670 train_time:86927ms step_avg:97.78ms +step:890/1670 train_time:87023ms step_avg:97.78ms +step:891/1670 train_time:87120ms step_avg:97.78ms +step:892/1670 train_time:87217ms step_avg:97.78ms +step:893/1670 train_time:87314ms step_avg:97.78ms +step:894/1670 train_time:87410ms step_avg:97.77ms +step:895/1670 train_time:87507ms step_avg:97.77ms +step:896/1670 train_time:87604ms step_avg:97.77ms +step:897/1670 train_time:87701ms step_avg:97.77ms +step:898/1670 train_time:87799ms step_avg:97.77ms +step:899/1670 train_time:87897ms step_avg:97.77ms +step:900/1670 train_time:87997ms step_avg:97.77ms +step:901/1670 train_time:88095ms step_avg:97.77ms +step:902/1670 train_time:88192ms step_avg:97.77ms +step:903/1670 train_time:88289ms step_avg:97.77ms +step:904/1670 train_time:88385ms step_avg:97.77ms +step:905/1670 train_time:88481ms step_avg:97.77ms +step:906/1670 train_time:88578ms step_avg:97.77ms 
+step:907/1670 train_time:88676ms step_avg:97.77ms +step:908/1670 train_time:88775ms step_avg:97.77ms +step:909/1670 train_time:88873ms step_avg:97.77ms +step:910/1670 train_time:88971ms step_avg:97.77ms +step:911/1670 train_time:89068ms step_avg:97.77ms +step:912/1670 train_time:89164ms step_avg:97.77ms +step:913/1670 train_time:89262ms step_avg:97.77ms +step:914/1670 train_time:89359ms step_avg:97.77ms +step:915/1670 train_time:89456ms step_avg:97.77ms +step:916/1670 train_time:89553ms step_avg:97.77ms +step:917/1670 train_time:89650ms step_avg:97.76ms +step:918/1670 train_time:89747ms step_avg:97.76ms +step:919/1670 train_time:89844ms step_avg:97.76ms +step:920/1670 train_time:89942ms step_avg:97.76ms +step:921/1670 train_time:90040ms step_avg:97.76ms +step:922/1670 train_time:90138ms step_avg:97.76ms +step:923/1670 train_time:90236ms step_avg:97.76ms +step:924/1670 train_time:90334ms step_avg:97.76ms +step:925/1670 train_time:90431ms step_avg:97.76ms +step:926/1670 train_time:90528ms step_avg:97.76ms +step:927/1670 train_time:90624ms step_avg:97.76ms +step:928/1670 train_time:90721ms step_avg:97.76ms +step:929/1670 train_time:90819ms step_avg:97.76ms +step:930/1670 train_time:90917ms step_avg:97.76ms +step:931/1670 train_time:91015ms step_avg:97.76ms +step:932/1670 train_time:91114ms step_avg:97.76ms +step:933/1670 train_time:91212ms step_avg:97.76ms +step:934/1670 train_time:91310ms step_avg:97.76ms +step:935/1670 train_time:91406ms step_avg:97.76ms +step:936/1670 train_time:91502ms step_avg:97.76ms +step:937/1670 train_time:91599ms step_avg:97.76ms +step:938/1670 train_time:91696ms step_avg:97.76ms +step:939/1670 train_time:91794ms step_avg:97.76ms +step:940/1670 train_time:91892ms step_avg:97.76ms +step:941/1670 train_time:91989ms step_avg:97.76ms +step:942/1670 train_time:92086ms step_avg:97.76ms +step:943/1670 train_time:92183ms step_avg:97.76ms +step:944/1670 train_time:92280ms step_avg:97.75ms +step:945/1670 train_time:92379ms step_avg:97.76ms +step:946/1670 train_time:92476ms step_avg:97.76ms +step:947/1670 train_time:92574ms step_avg:97.75ms +step:948/1670 train_time:92671ms step_avg:97.75ms +step:949/1670 train_time:92769ms step_avg:97.75ms +step:950/1670 train_time:92867ms step_avg:97.75ms +step:951/1670 train_time:92963ms step_avg:97.75ms +step:952/1670 train_time:93061ms step_avg:97.75ms +step:953/1670 train_time:93159ms step_avg:97.75ms +step:954/1670 train_time:93257ms step_avg:97.75ms +step:955/1670 train_time:93355ms step_avg:97.75ms +step:956/1670 train_time:93451ms step_avg:97.75ms +step:957/1670 train_time:93548ms step_avg:97.75ms +step:958/1670 train_time:93644ms step_avg:97.75ms +step:959/1670 train_time:93742ms step_avg:97.75ms +step:960/1670 train_time:93840ms step_avg:97.75ms +step:961/1670 train_time:93937ms step_avg:97.75ms +step:962/1670 train_time:94035ms step_avg:97.75ms +step:963/1670 train_time:94133ms step_avg:97.75ms +step:964/1670 train_time:94232ms step_avg:97.75ms +step:965/1670 train_time:94328ms step_avg:97.75ms +step:966/1670 train_time:94425ms step_avg:97.75ms +step:967/1670 train_time:94521ms step_avg:97.75ms +step:968/1670 train_time:94619ms step_avg:97.75ms +step:969/1670 train_time:94716ms step_avg:97.75ms +step:970/1670 train_time:94815ms step_avg:97.75ms +step:971/1670 train_time:94913ms step_avg:97.75ms +step:972/1670 train_time:95011ms step_avg:97.75ms +step:973/1670 train_time:95108ms step_avg:97.75ms +step:974/1670 train_time:95206ms step_avg:97.75ms +step:975/1670 train_time:95303ms step_avg:97.75ms +step:976/1670 train_time:95400ms 
step_avg:97.75ms +step:977/1670 train_time:95497ms step_avg:97.74ms +step:978/1670 train_time:95594ms step_avg:97.74ms +step:979/1670 train_time:95691ms step_avg:97.74ms +step:980/1670 train_time:95789ms step_avg:97.74ms +step:981/1670 train_time:95886ms step_avg:97.74ms +step:982/1670 train_time:95982ms step_avg:97.74ms +step:983/1670 train_time:96080ms step_avg:97.74ms +step:984/1670 train_time:96178ms step_avg:97.74ms +step:985/1670 train_time:96277ms step_avg:97.74ms +step:986/1670 train_time:96375ms step_avg:97.74ms +step:987/1670 train_time:96472ms step_avg:97.74ms +step:988/1670 train_time:96569ms step_avg:97.74ms +step:989/1670 train_time:96666ms step_avg:97.74ms +step:990/1670 train_time:96763ms step_avg:97.74ms +step:991/1670 train_time:96861ms step_avg:97.74ms +step:992/1670 train_time:96959ms step_avg:97.74ms +step:993/1670 train_time:97056ms step_avg:97.74ms +step:994/1670 train_time:97154ms step_avg:97.74ms +step:995/1670 train_time:97252ms step_avg:97.74ms +step:996/1670 train_time:97349ms step_avg:97.74ms +step:997/1670 train_time:97445ms step_avg:97.74ms +step:998/1670 train_time:97542ms step_avg:97.74ms +step:999/1670 train_time:97639ms step_avg:97.74ms +step:1000/1670 train_time:97737ms step_avg:97.74ms +step:1000/1670 val_loss:3.4783 train_time:97834ms step_avg:97.83ms +step:1001/1670 train_time:97856ms step_avg:97.76ms +step:1002/1670 train_time:97940ms step_avg:97.74ms +step:1003/1670 train_time:98039ms step_avg:97.75ms +step:1004/1670 train_time:98135ms step_avg:97.74ms +step:1005/1670 train_time:98231ms step_avg:97.74ms +step:1006/1670 train_time:98328ms step_avg:97.74ms +step:1007/1670 train_time:98424ms step_avg:97.74ms +step:1008/1670 train_time:98520ms step_avg:97.74ms +step:1009/1670 train_time:98617ms step_avg:97.74ms +step:1010/1670 train_time:98713ms step_avg:97.74ms +step:1011/1670 train_time:98812ms step_avg:97.74ms +step:1012/1670 train_time:98910ms step_avg:97.74ms +step:1013/1670 train_time:99010ms step_avg:97.74ms +step:1014/1670 train_time:99109ms step_avg:97.74ms +step:1015/1670 train_time:99207ms step_avg:97.74ms +step:1016/1670 train_time:99304ms step_avg:97.74ms +step:1017/1670 train_time:99401ms step_avg:97.74ms +step:1018/1670 train_time:99497ms step_avg:97.74ms +step:1019/1670 train_time:99593ms step_avg:97.74ms +step:1020/1670 train_time:99690ms step_avg:97.74ms +step:1021/1670 train_time:99788ms step_avg:97.74ms +step:1022/1670 train_time:99886ms step_avg:97.74ms +step:1023/1670 train_time:99986ms step_avg:97.74ms +step:1024/1670 train_time:100084ms step_avg:97.74ms +step:1025/1670 train_time:100182ms step_avg:97.74ms +step:1026/1670 train_time:100280ms step_avg:97.74ms +step:1027/1670 train_time:100376ms step_avg:97.74ms +step:1028/1670 train_time:100472ms step_avg:97.74ms +step:1029/1670 train_time:100569ms step_avg:97.73ms +step:1030/1670 train_time:100666ms step_avg:97.73ms +step:1031/1670 train_time:100763ms step_avg:97.73ms +step:1032/1670 train_time:100861ms step_avg:97.73ms +step:1033/1670 train_time:100959ms step_avg:97.73ms +step:1034/1670 train_time:101058ms step_avg:97.73ms +step:1035/1670 train_time:101155ms step_avg:97.73ms +step:1036/1670 train_time:101252ms step_avg:97.73ms +step:1037/1670 train_time:101350ms step_avg:97.73ms +step:1038/1670 train_time:101447ms step_avg:97.73ms +step:1039/1670 train_time:101545ms step_avg:97.73ms +step:1040/1670 train_time:101642ms step_avg:97.73ms +step:1041/1670 train_time:101739ms step_avg:97.73ms +step:1042/1670 train_time:101835ms step_avg:97.73ms +step:1043/1670 train_time:101933ms 
step_avg:97.73ms +step:1044/1670 train_time:102030ms step_avg:97.73ms +step:1045/1670 train_time:102129ms step_avg:97.73ms +step:1046/1670 train_time:102227ms step_avg:97.73ms +step:1047/1670 train_time:102325ms step_avg:97.73ms +step:1048/1670 train_time:102422ms step_avg:97.73ms +step:1049/1670 train_time:102520ms step_avg:97.73ms +step:1050/1670 train_time:102617ms step_avg:97.73ms +step:1051/1670 train_time:102714ms step_avg:97.73ms +step:1052/1670 train_time:102811ms step_avg:97.73ms +step:1053/1670 train_time:102908ms step_avg:97.73ms +step:1054/1670 train_time:103006ms step_avg:97.73ms +step:1055/1670 train_time:103105ms step_avg:97.73ms +step:1056/1670 train_time:103203ms step_avg:97.73ms +step:1057/1670 train_time:103301ms step_avg:97.73ms +step:1058/1670 train_time:103398ms step_avg:97.73ms +step:1059/1670 train_time:103494ms step_avg:97.73ms +step:1060/1670 train_time:103591ms step_avg:97.73ms +step:1061/1670 train_time:103688ms step_avg:97.73ms +step:1062/1670 train_time:103939ms step_avg:97.87ms +step:1063/1670 train_time:104148ms step_avg:97.98ms +step:1064/1670 train_time:104243ms step_avg:97.97ms +step:1065/1670 train_time:104339ms step_avg:97.97ms +step:1066/1670 train_time:104434ms step_avg:97.97ms +step:1067/1670 train_time:104530ms step_avg:97.97ms +step:1068/1670 train_time:104626ms step_avg:97.96ms +step:1069/1670 train_time:104723ms step_avg:97.96ms +step:1070/1670 train_time:104819ms step_avg:97.96ms +step:1071/1670 train_time:104915ms step_avg:97.96ms +step:1072/1670 train_time:105015ms step_avg:97.96ms +step:1073/1670 train_time:105115ms step_avg:97.96ms +step:1074/1670 train_time:105213ms step_avg:97.96ms +step:1075/1670 train_time:105311ms step_avg:97.96ms +step:1076/1670 train_time:105409ms step_avg:97.96ms +step:1077/1670 train_time:105506ms step_avg:97.96ms +step:1078/1670 train_time:105603ms step_avg:97.96ms +step:1079/1670 train_time:105700ms step_avg:97.96ms +step:1080/1670 train_time:105796ms step_avg:97.96ms +step:1081/1670 train_time:105892ms step_avg:97.96ms +step:1082/1670 train_time:105990ms step_avg:97.96ms +step:1083/1670 train_time:106089ms step_avg:97.96ms +step:1084/1670 train_time:106187ms step_avg:97.96ms +step:1085/1670 train_time:106286ms step_avg:97.96ms +step:1086/1670 train_time:106383ms step_avg:97.96ms +step:1087/1670 train_time:106482ms step_avg:97.96ms +step:1088/1670 train_time:106578ms step_avg:97.96ms +step:1089/1670 train_time:106674ms step_avg:97.96ms +step:1090/1670 train_time:106771ms step_avg:97.95ms +step:1091/1670 train_time:106868ms step_avg:97.95ms +step:1092/1670 train_time:106966ms step_avg:97.95ms +step:1093/1670 train_time:107065ms step_avg:97.96ms +step:1094/1670 train_time:107164ms step_avg:97.96ms +step:1095/1670 train_time:107263ms step_avg:97.96ms +step:1096/1670 train_time:107360ms step_avg:97.96ms +step:1097/1670 train_time:107457ms step_avg:97.96ms +step:1098/1670 train_time:107553ms step_avg:97.95ms +step:1099/1670 train_time:107649ms step_avg:97.95ms +step:1100/1670 train_time:107747ms step_avg:97.95ms +step:1101/1670 train_time:107845ms step_avg:97.95ms +step:1102/1670 train_time:107942ms step_avg:97.95ms +step:1103/1670 train_time:108040ms step_avg:97.95ms +step:1104/1670 train_time:108137ms step_avg:97.95ms +step:1105/1670 train_time:108234ms step_avg:97.95ms +step:1106/1670 train_time:108331ms step_avg:97.95ms +step:1107/1670 train_time:108429ms step_avg:97.95ms +step:1108/1670 train_time:108527ms step_avg:97.95ms +step:1109/1670 train_time:108625ms step_avg:97.95ms +step:1110/1670 train_time:108722ms 
step_avg:97.95ms +step:1111/1670 train_time:108819ms step_avg:97.95ms +step:1112/1670 train_time:108916ms step_avg:97.95ms +step:1113/1670 train_time:109013ms step_avg:97.95ms +step:1114/1670 train_time:109110ms step_avg:97.94ms +step:1115/1670 train_time:109209ms step_avg:97.95ms +step:1116/1670 train_time:109308ms step_avg:97.95ms +step:1117/1670 train_time:109406ms step_avg:97.95ms +step:1118/1670 train_time:109504ms step_avg:97.95ms +step:1119/1670 train_time:109601ms step_avg:97.95ms +step:1120/1670 train_time:109699ms step_avg:97.95ms +step:1121/1670 train_time:109797ms step_avg:97.95ms +step:1122/1670 train_time:109895ms step_avg:97.95ms +step:1123/1670 train_time:109992ms step_avg:97.94ms +step:1124/1670 train_time:110089ms step_avg:97.94ms +step:1125/1670 train_time:110187ms step_avg:97.94ms +step:1125/1670 val_loss:3.4256 train_time:110285ms step_avg:98.03ms +step:1126/1670 train_time:110307ms step_avg:97.96ms +step:1127/1670 train_time:110391ms step_avg:97.95ms +step:1128/1670 train_time:110489ms step_avg:97.95ms +step:1129/1670 train_time:110585ms step_avg:97.95ms +step:1130/1670 train_time:110682ms step_avg:97.95ms +step:1131/1670 train_time:110778ms step_avg:97.95ms +step:1132/1670 train_time:110875ms step_avg:97.95ms +step:1133/1670 train_time:110971ms step_avg:97.94ms +step:1134/1670 train_time:111068ms step_avg:97.94ms +step:1135/1670 train_time:111165ms step_avg:97.94ms +step:1136/1670 train_time:111266ms step_avg:97.95ms +step:1137/1670 train_time:111367ms step_avg:97.95ms +step:1138/1670 train_time:111466ms step_avg:97.95ms +step:1139/1670 train_time:111565ms step_avg:97.95ms +step:1140/1670 train_time:111663ms step_avg:97.95ms +step:1141/1670 train_time:111759ms step_avg:97.95ms +step:1142/1670 train_time:111856ms step_avg:97.95ms +step:1143/1670 train_time:111953ms step_avg:97.95ms +step:1144/1670 train_time:112050ms step_avg:97.95ms +step:1145/1670 train_time:112147ms step_avg:97.95ms +step:1146/1670 train_time:112245ms step_avg:97.95ms +step:1147/1670 train_time:112346ms step_avg:97.95ms +step:1148/1670 train_time:112445ms step_avg:97.95ms +step:1149/1670 train_time:112544ms step_avg:97.95ms +step:1150/1670 train_time:112642ms step_avg:97.95ms +step:1151/1670 train_time:112739ms step_avg:97.95ms +step:1152/1670 train_time:112836ms step_avg:97.95ms +step:1153/1670 train_time:112933ms step_avg:97.95ms +step:1154/1670 train_time:113030ms step_avg:97.95ms +step:1155/1670 train_time:113128ms step_avg:97.95ms +step:1156/1670 train_time:113226ms step_avg:97.95ms +step:1157/1670 train_time:113325ms step_avg:97.95ms +step:1158/1670 train_time:113425ms step_avg:97.95ms +step:1159/1670 train_time:113524ms step_avg:97.95ms +step:1160/1670 train_time:113621ms step_avg:97.95ms +step:1161/1670 train_time:113720ms step_avg:97.95ms +step:1162/1670 train_time:113817ms step_avg:97.95ms +step:1163/1670 train_time:113915ms step_avg:97.95ms +step:1164/1670 train_time:114012ms step_avg:97.95ms +step:1165/1670 train_time:114110ms step_avg:97.95ms +step:1166/1670 train_time:114208ms step_avg:97.95ms +step:1167/1670 train_time:114307ms step_avg:97.95ms +step:1168/1670 train_time:114405ms step_avg:97.95ms +step:1169/1670 train_time:114503ms step_avg:97.95ms +step:1170/1670 train_time:114602ms step_avg:97.95ms +step:1171/1670 train_time:114700ms step_avg:97.95ms +step:1172/1670 train_time:114798ms step_avg:97.95ms +step:1173/1670 train_time:114897ms step_avg:97.95ms +step:1174/1670 train_time:114994ms step_avg:97.95ms +step:1175/1670 train_time:115093ms step_avg:97.95ms +step:1176/1670 
train_time:115192ms step_avg:97.95ms +step:1177/1670 train_time:115289ms step_avg:97.95ms +step:1178/1670 train_time:115387ms step_avg:97.95ms +step:1179/1670 train_time:115484ms step_avg:97.95ms +step:1180/1670 train_time:115582ms step_avg:97.95ms +step:1181/1670 train_time:115681ms step_avg:97.95ms +step:1182/1670 train_time:115778ms step_avg:97.95ms +step:1183/1670 train_time:115875ms step_avg:97.95ms +step:1184/1670 train_time:115973ms step_avg:97.95ms +step:1185/1670 train_time:116072ms step_avg:97.95ms +step:1186/1670 train_time:116170ms step_avg:97.95ms +step:1187/1670 train_time:116267ms step_avg:97.95ms +step:1188/1670 train_time:116365ms step_avg:97.95ms +step:1189/1670 train_time:116463ms step_avg:97.95ms +step:1190/1670 train_time:116561ms step_avg:97.95ms +step:1191/1670 train_time:116660ms step_avg:97.95ms +step:1192/1670 train_time:116757ms step_avg:97.95ms +step:1193/1670 train_time:116855ms step_avg:97.95ms +step:1194/1670 train_time:116953ms step_avg:97.95ms +step:1195/1670 train_time:117050ms step_avg:97.95ms +step:1196/1670 train_time:117148ms step_avg:97.95ms +step:1197/1670 train_time:117246ms step_avg:97.95ms +step:1198/1670 train_time:117344ms step_avg:97.95ms +step:1199/1670 train_time:117441ms step_avg:97.95ms +step:1200/1670 train_time:117539ms step_avg:97.95ms +step:1201/1670 train_time:117638ms step_avg:97.95ms +step:1202/1670 train_time:117736ms step_avg:97.95ms +step:1203/1670 train_time:117833ms step_avg:97.95ms +step:1204/1670 train_time:117930ms step_avg:97.95ms +step:1205/1670 train_time:118027ms step_avg:97.95ms +step:1206/1670 train_time:118125ms step_avg:97.95ms +step:1207/1670 train_time:118224ms step_avg:97.95ms +step:1208/1670 train_time:118322ms step_avg:97.95ms +step:1209/1670 train_time:118421ms step_avg:97.95ms +step:1210/1670 train_time:118520ms step_avg:97.95ms +step:1211/1670 train_time:118618ms step_avg:97.95ms +step:1212/1670 train_time:118716ms step_avg:97.95ms +step:1213/1670 train_time:118815ms step_avg:97.95ms +step:1214/1670 train_time:118913ms step_avg:97.95ms +step:1215/1670 train_time:119012ms step_avg:97.95ms +step:1216/1670 train_time:119110ms step_avg:97.95ms +step:1217/1670 train_time:119207ms step_avg:97.95ms +step:1218/1670 train_time:119305ms step_avg:97.95ms +step:1219/1670 train_time:119404ms step_avg:97.95ms +step:1220/1670 train_time:119502ms step_avg:97.95ms +step:1221/1670 train_time:119600ms step_avg:97.95ms +step:1222/1670 train_time:119698ms step_avg:97.95ms +step:1223/1670 train_time:119796ms step_avg:97.95ms +step:1224/1670 train_time:119895ms step_avg:97.95ms +step:1225/1670 train_time:119993ms step_avg:97.95ms +step:1226/1670 train_time:120092ms step_avg:97.95ms +step:1227/1670 train_time:120191ms step_avg:97.95ms +step:1228/1670 train_time:120289ms step_avg:97.95ms +step:1229/1670 train_time:120386ms step_avg:97.95ms +step:1230/1670 train_time:120484ms step_avg:97.95ms +step:1231/1670 train_time:120582ms step_avg:97.95ms +step:1232/1670 train_time:120680ms step_avg:97.95ms +step:1233/1670 train_time:120778ms step_avg:97.95ms +step:1234/1670 train_time:120877ms step_avg:97.96ms +step:1235/1670 train_time:120975ms step_avg:97.96ms +step:1236/1670 train_time:121074ms step_avg:97.96ms +step:1237/1670 train_time:121173ms step_avg:97.96ms +step:1238/1670 train_time:121272ms step_avg:97.96ms +step:1239/1670 train_time:121369ms step_avg:97.96ms +step:1240/1670 train_time:121467ms step_avg:97.96ms +step:1241/1670 train_time:121564ms step_avg:97.96ms +step:1242/1670 train_time:121662ms step_avg:97.96ms +step:1243/1670 
train_time:121760ms step_avg:97.96ms +step:1244/1670 train_time:121858ms step_avg:97.96ms +step:1245/1670 train_time:121957ms step_avg:97.96ms +step:1246/1670 train_time:122054ms step_avg:97.96ms +step:1247/1670 train_time:122153ms step_avg:97.96ms +step:1248/1670 train_time:122252ms step_avg:97.96ms +step:1249/1670 train_time:122350ms step_avg:97.96ms +step:1250/1670 train_time:122447ms step_avg:97.96ms +step:1250/1670 val_loss:3.3820 train_time:122544ms step_avg:98.04ms +step:1251/1670 train_time:122565ms step_avg:97.97ms +step:1252/1670 train_time:122654ms step_avg:97.97ms +step:1253/1670 train_time:122754ms step_avg:97.97ms +step:1254/1670 train_time:122852ms step_avg:97.97ms +step:1255/1670 train_time:122949ms step_avg:97.97ms +step:1256/1670 train_time:123047ms step_avg:97.97ms +step:1257/1670 train_time:123143ms step_avg:97.97ms +step:1258/1670 train_time:123240ms step_avg:97.97ms +step:1259/1670 train_time:123337ms step_avg:97.96ms +step:1260/1670 train_time:123434ms step_avg:97.96ms +step:1261/1670 train_time:123535ms step_avg:97.97ms +step:1262/1670 train_time:123636ms step_avg:97.97ms +step:1263/1670 train_time:123736ms step_avg:97.97ms +step:1264/1670 train_time:123834ms step_avg:97.97ms +step:1265/1670 train_time:123931ms step_avg:97.97ms +step:1266/1670 train_time:124030ms step_avg:97.97ms +step:1267/1670 train_time:124127ms step_avg:97.97ms +step:1268/1670 train_time:124223ms step_avg:97.97ms +step:1269/1670 train_time:124320ms step_avg:97.97ms +step:1270/1670 train_time:124417ms step_avg:97.97ms +step:1271/1670 train_time:124515ms step_avg:97.97ms +step:1272/1670 train_time:124615ms step_avg:97.97ms +step:1273/1670 train_time:124714ms step_avg:97.97ms +step:1274/1670 train_time:125098ms step_avg:98.19ms +step:1275/1670 train_time:125175ms step_avg:98.18ms +step:1276/1670 train_time:125272ms step_avg:98.18ms +step:1277/1670 train_time:125368ms step_avg:98.17ms +step:1278/1670 train_time:125464ms step_avg:98.17ms +step:1279/1670 train_time:125561ms step_avg:98.17ms +step:1280/1670 train_time:125657ms step_avg:98.17ms +step:1281/1670 train_time:125754ms step_avg:98.17ms +step:1282/1670 train_time:125850ms step_avg:98.17ms +step:1283/1670 train_time:125948ms step_avg:98.17ms +step:1284/1670 train_time:126053ms step_avg:98.17ms +step:1285/1670 train_time:126156ms step_avg:98.18ms +step:1286/1670 train_time:126254ms step_avg:98.18ms +step:1287/1670 train_time:126352ms step_avg:98.18ms +step:1288/1670 train_time:126450ms step_avg:98.18ms +step:1289/1670 train_time:126548ms step_avg:98.18ms +step:1290/1670 train_time:126646ms step_avg:98.17ms +step:1291/1670 train_time:126743ms step_avg:98.17ms +step:1292/1670 train_time:126840ms step_avg:98.17ms +step:1293/1670 train_time:126938ms step_avg:98.17ms +step:1294/1670 train_time:127036ms step_avg:98.17ms +step:1295/1670 train_time:127136ms step_avg:98.17ms +step:1296/1670 train_time:127235ms step_avg:98.18ms +step:1297/1670 train_time:127333ms step_avg:98.18ms +step:1298/1670 train_time:127431ms step_avg:98.18ms +step:1299/1670 train_time:127529ms step_avg:98.17ms +step:1300/1670 train_time:127627ms step_avg:98.17ms +step:1301/1670 train_time:127725ms step_avg:98.17ms +step:1302/1670 train_time:127822ms step_avg:98.17ms +step:1303/1670 train_time:127919ms step_avg:98.17ms +step:1304/1670 train_time:128018ms step_avg:98.17ms +step:1305/1670 train_time:128118ms step_avg:98.17ms +step:1306/1670 train_time:128216ms step_avg:98.17ms +step:1307/1670 train_time:128313ms step_avg:98.17ms +step:1308/1670 train_time:128411ms step_avg:98.17ms 
+step:1309/1670 train_time:128509ms step_avg:98.17ms +step:1310/1670 train_time:128608ms step_avg:98.17ms +step:1311/1670 train_time:128708ms step_avg:98.18ms +step:1312/1670 train_time:128805ms step_avg:98.17ms +step:1313/1670 train_time:128903ms step_avg:98.17ms +step:1314/1670 train_time:129002ms step_avg:98.17ms +step:1315/1670 train_time:129102ms step_avg:98.18ms +step:1316/1670 train_time:129201ms step_avg:98.18ms +step:1317/1670 train_time:129298ms step_avg:98.18ms +step:1318/1670 train_time:129395ms step_avg:98.18ms +step:1319/1670 train_time:129492ms step_avg:98.17ms +step:1320/1670 train_time:129589ms step_avg:98.17ms +step:1321/1670 train_time:129688ms step_avg:98.17ms +step:1322/1670 train_time:129786ms step_avg:98.17ms +step:1323/1670 train_time:129884ms step_avg:98.17ms +step:1324/1670 train_time:129983ms step_avg:98.17ms +step:1325/1670 train_time:130083ms step_avg:98.18ms +step:1326/1670 train_time:130183ms step_avg:98.18ms +step:1327/1670 train_time:130281ms step_avg:98.18ms +step:1328/1670 train_time:130378ms step_avg:98.18ms +step:1329/1670 train_time:130476ms step_avg:98.18ms +step:1330/1670 train_time:130572ms step_avg:98.17ms +step:1331/1670 train_time:130672ms step_avg:98.18ms +step:1332/1670 train_time:130772ms step_avg:98.18ms +step:1333/1670 train_time:130871ms step_avg:98.18ms +step:1334/1670 train_time:130971ms step_avg:98.18ms +step:1335/1670 train_time:131071ms step_avg:98.18ms +step:1336/1670 train_time:131171ms step_avg:98.18ms +step:1337/1670 train_time:131271ms step_avg:98.18ms +step:1338/1670 train_time:131371ms step_avg:98.18ms +step:1339/1670 train_time:131469ms step_avg:98.18ms +step:1340/1670 train_time:131567ms step_avg:98.18ms +step:1341/1670 train_time:131664ms step_avg:98.18ms +step:1342/1670 train_time:131762ms step_avg:98.18ms +step:1343/1670 train_time:131859ms step_avg:98.18ms +step:1344/1670 train_time:131956ms step_avg:98.18ms +step:1345/1670 train_time:132054ms step_avg:98.18ms +step:1346/1670 train_time:132153ms step_avg:98.18ms +step:1347/1670 train_time:132251ms step_avg:98.18ms +step:1348/1670 train_time:132351ms step_avg:98.18ms +step:1349/1670 train_time:132449ms step_avg:98.18ms +step:1350/1670 train_time:132548ms step_avg:98.18ms +step:1351/1670 train_time:132647ms step_avg:98.18ms +step:1352/1670 train_time:132746ms step_avg:98.19ms +step:1353/1670 train_time:132844ms step_avg:98.18ms +step:1354/1670 train_time:132942ms step_avg:98.18ms +step:1355/1670 train_time:133039ms step_avg:98.18ms +step:1356/1670 train_time:133137ms step_avg:98.18ms +step:1357/1670 train_time:133235ms step_avg:98.18ms +step:1358/1670 train_time:133333ms step_avg:98.18ms +step:1359/1670 train_time:133431ms step_avg:98.18ms +step:1360/1670 train_time:133529ms step_avg:98.18ms +step:1361/1670 train_time:133627ms step_avg:98.18ms +step:1362/1670 train_time:133726ms step_avg:98.18ms +step:1363/1670 train_time:133824ms step_avg:98.18ms +step:1364/1670 train_time:133922ms step_avg:98.18ms +step:1365/1670 train_time:134020ms step_avg:98.18ms +step:1366/1670 train_time:134118ms step_avg:98.18ms +step:1367/1670 train_time:134215ms step_avg:98.18ms +step:1368/1670 train_time:134314ms step_avg:98.18ms +step:1369/1670 train_time:134411ms step_avg:98.18ms +step:1370/1670 train_time:134510ms step_avg:98.18ms +step:1371/1670 train_time:134608ms step_avg:98.18ms +step:1372/1670 train_time:134706ms step_avg:98.18ms +step:1373/1670 train_time:134805ms step_avg:98.18ms +step:1374/1670 train_time:134902ms step_avg:98.18ms +step:1375/1670 train_time:135001ms step_avg:98.18ms 
+step:1375/1670 val_loss:3.3446 train_time:135098ms step_avg:98.25ms +step:1376/1670 train_time:135120ms step_avg:98.20ms +step:1377/1670 train_time:135206ms step_avg:98.19ms +step:1378/1670 train_time:135308ms step_avg:98.19ms +step:1379/1670 train_time:135406ms step_avg:98.19ms +step:1380/1670 train_time:135503ms step_avg:98.19ms +step:1381/1670 train_time:135600ms step_avg:98.19ms +step:1382/1670 train_time:135697ms step_avg:98.19ms +step:1383/1670 train_time:135794ms step_avg:98.19ms +step:1384/1670 train_time:135892ms step_avg:98.19ms +step:1385/1670 train_time:135989ms step_avg:98.19ms +step:1386/1670 train_time:136087ms step_avg:98.19ms +step:1387/1670 train_time:136188ms step_avg:98.19ms +step:1388/1670 train_time:136288ms step_avg:98.19ms +step:1389/1670 train_time:136385ms step_avg:98.19ms +step:1390/1670 train_time:136483ms step_avg:98.19ms +step:1391/1670 train_time:136581ms step_avg:98.19ms +step:1392/1670 train_time:136678ms step_avg:98.19ms +step:1393/1670 train_time:136776ms step_avg:98.19ms +step:1394/1670 train_time:136872ms step_avg:98.19ms +step:1395/1670 train_time:136969ms step_avg:98.19ms +step:1396/1670 train_time:137067ms step_avg:98.19ms +step:1397/1670 train_time:137167ms step_avg:98.19ms +step:1398/1670 train_time:137266ms step_avg:98.19ms +step:1399/1670 train_time:137364ms step_avg:98.19ms +step:1400/1670 train_time:137463ms step_avg:98.19ms +step:1401/1670 train_time:137561ms step_avg:98.19ms +step:1402/1670 train_time:137659ms step_avg:98.19ms +step:1403/1670 train_time:137757ms step_avg:98.19ms +step:1404/1670 train_time:137856ms step_avg:98.19ms +step:1405/1670 train_time:137953ms step_avg:98.19ms +step:1406/1670 train_time:138051ms step_avg:98.19ms +step:1407/1670 train_time:138150ms step_avg:98.19ms +step:1408/1670 train_time:138248ms step_avg:98.19ms +step:1409/1670 train_time:138346ms step_avg:98.19ms +step:1410/1670 train_time:138444ms step_avg:98.19ms +step:1411/1670 train_time:138542ms step_avg:98.19ms +step:1412/1670 train_time:138640ms step_avg:98.19ms +step:1413/1670 train_time:138738ms step_avg:98.19ms +step:1414/1670 train_time:138836ms step_avg:98.19ms +step:1415/1670 train_time:138934ms step_avg:98.19ms +step:1416/1670 train_time:139032ms step_avg:98.19ms +step:1417/1670 train_time:139131ms step_avg:98.19ms +step:1418/1670 train_time:139230ms step_avg:98.19ms +step:1419/1670 train_time:139328ms step_avg:98.19ms +step:1420/1670 train_time:139426ms step_avg:98.19ms +step:1421/1670 train_time:139524ms step_avg:98.19ms +step:1422/1670 train_time:139622ms step_avg:98.19ms +step:1423/1670 train_time:139721ms step_avg:98.19ms +step:1424/1670 train_time:139819ms step_avg:98.19ms +step:1425/1670 train_time:139917ms step_avg:98.19ms +step:1426/1670 train_time:140015ms step_avg:98.19ms +step:1427/1670 train_time:140115ms step_avg:98.19ms +step:1428/1670 train_time:140216ms step_avg:98.19ms +step:1429/1670 train_time:140315ms step_avg:98.19ms +step:1430/1670 train_time:140414ms step_avg:98.19ms +step:1431/1670 train_time:140511ms step_avg:98.19ms +step:1432/1670 train_time:140609ms step_avg:98.19ms +step:1433/1670 train_time:140705ms step_avg:98.19ms +step:1434/1670 train_time:140802ms step_avg:98.19ms +step:1435/1670 train_time:140901ms step_avg:98.19ms +step:1436/1670 train_time:140999ms step_avg:98.19ms +step:1437/1670 train_time:141098ms step_avg:98.19ms +step:1438/1670 train_time:141197ms step_avg:98.19ms +step:1439/1670 train_time:141298ms step_avg:98.19ms +step:1440/1670 train_time:141398ms step_avg:98.19ms +step:1441/1670 train_time:141497ms 
step_avg:98.19ms +step:1442/1670 train_time:141597ms step_avg:98.19ms +step:1443/1670 train_time:141696ms step_avg:98.20ms +step:1444/1670 train_time:141794ms step_avg:98.20ms +step:1445/1670 train_time:141891ms step_avg:98.19ms +step:1446/1670 train_time:141988ms step_avg:98.19ms +step:1447/1670 train_time:142086ms step_avg:98.19ms +step:1448/1670 train_time:142184ms step_avg:98.19ms +step:1449/1670 train_time:142283ms step_avg:98.19ms +step:1450/1670 train_time:142382ms step_avg:98.19ms +step:1451/1670 train_time:142481ms step_avg:98.20ms +step:1452/1670 train_time:142580ms step_avg:98.20ms +step:1453/1670 train_time:142679ms step_avg:98.20ms +step:1454/1670 train_time:142777ms step_avg:98.20ms +step:1455/1670 train_time:142876ms step_avg:98.20ms +step:1456/1670 train_time:142973ms step_avg:98.20ms +step:1457/1670 train_time:143071ms step_avg:98.20ms +step:1458/1670 train_time:143168ms step_avg:98.19ms +step:1459/1670 train_time:143265ms step_avg:98.19ms +step:1460/1670 train_time:143363ms step_avg:98.19ms +step:1461/1670 train_time:143463ms step_avg:98.19ms +step:1462/1670 train_time:143561ms step_avg:98.20ms +step:1463/1670 train_time:143660ms step_avg:98.20ms +step:1464/1670 train_time:143759ms step_avg:98.20ms +step:1465/1670 train_time:143858ms step_avg:98.20ms +step:1466/1670 train_time:143957ms step_avg:98.20ms +step:1467/1670 train_time:144056ms step_avg:98.20ms +step:1468/1670 train_time:144156ms step_avg:98.20ms +step:1469/1670 train_time:144253ms step_avg:98.20ms +step:1470/1670 train_time:144351ms step_avg:98.20ms +step:1471/1670 train_time:144449ms step_avg:98.20ms +step:1472/1670 train_time:144546ms step_avg:98.20ms +step:1473/1670 train_time:144644ms step_avg:98.20ms +step:1474/1670 train_time:144742ms step_avg:98.20ms +step:1475/1670 train_time:144841ms step_avg:98.20ms +step:1476/1670 train_time:144940ms step_avg:98.20ms +step:1477/1670 train_time:145039ms step_avg:98.20ms +step:1478/1670 train_time:145138ms step_avg:98.20ms +step:1479/1670 train_time:145238ms step_avg:98.20ms +step:1480/1670 train_time:145337ms step_avg:98.20ms +step:1481/1670 train_time:145438ms step_avg:98.20ms +step:1482/1670 train_time:145538ms step_avg:98.20ms +step:1483/1670 train_time:145637ms step_avg:98.20ms +step:1484/1670 train_time:145734ms step_avg:98.20ms +step:1485/1670 train_time:146009ms step_avg:98.32ms +step:1486/1670 train_time:146190ms step_avg:98.38ms +step:1487/1670 train_time:146285ms step_avg:98.38ms +step:1488/1670 train_time:146382ms step_avg:98.37ms +step:1489/1670 train_time:146479ms step_avg:98.37ms +step:1490/1670 train_time:146577ms step_avg:98.37ms +step:1491/1670 train_time:146673ms step_avg:98.37ms +step:1492/1670 train_time:146770ms step_avg:98.37ms +step:1493/1670 train_time:146866ms step_avg:98.37ms +step:1494/1670 train_time:146964ms step_avg:98.37ms +step:1495/1670 train_time:147069ms step_avg:98.37ms +step:1496/1670 train_time:147171ms step_avg:98.38ms +step:1497/1670 train_time:147269ms step_avg:98.38ms +step:1498/1670 train_time:147366ms step_avg:98.37ms +step:1499/1670 train_time:147463ms step_avg:98.37ms +step:1500/1670 train_time:147561ms step_avg:98.37ms +step:1500/1670 val_loss:3.3128 train_time:147658ms step_avg:98.44ms +step:1501/1670 train_time:147679ms step_avg:98.39ms +step:1502/1670 train_time:147764ms step_avg:98.38ms +step:1503/1670 train_time:147864ms step_avg:98.38ms +step:1504/1670 train_time:147962ms step_avg:98.38ms +step:1505/1670 train_time:148059ms step_avg:98.38ms +step:1506/1670 train_time:148157ms step_avg:98.38ms +step:1507/1670 
train_time:148255ms step_avg:98.38ms +step:1508/1670 train_time:148352ms step_avg:98.38ms +step:1509/1670 train_time:148449ms step_avg:98.38ms +step:1510/1670 train_time:148546ms step_avg:98.37ms +step:1511/1670 train_time:148645ms step_avg:98.38ms +step:1512/1670 train_time:148745ms step_avg:98.38ms +step:1513/1670 train_time:148844ms step_avg:98.38ms +step:1514/1670 train_time:148943ms step_avg:98.38ms +step:1515/1670 train_time:149040ms step_avg:98.38ms +step:1516/1670 train_time:149138ms step_avg:98.38ms +step:1517/1670 train_time:149236ms step_avg:98.38ms +step:1518/1670 train_time:149334ms step_avg:98.38ms +step:1519/1670 train_time:149432ms step_avg:98.38ms +step:1520/1670 train_time:149530ms step_avg:98.37ms +step:1521/1670 train_time:149628ms step_avg:98.37ms +step:1522/1670 train_time:149727ms step_avg:98.38ms +step:1523/1670 train_time:149826ms step_avg:98.38ms +step:1524/1670 train_time:149925ms step_avg:98.38ms +step:1525/1670 train_time:150022ms step_avg:98.38ms +step:1526/1670 train_time:150120ms step_avg:98.37ms +step:1527/1670 train_time:150218ms step_avg:98.37ms +step:1528/1670 train_time:150317ms step_avg:98.37ms +step:1529/1670 train_time:150415ms step_avg:98.37ms +step:1530/1670 train_time:150514ms step_avg:98.37ms +step:1531/1670 train_time:150613ms step_avg:98.38ms +step:1532/1670 train_time:150712ms step_avg:98.38ms +step:1533/1670 train_time:150810ms step_avg:98.38ms +step:1534/1670 train_time:150910ms step_avg:98.38ms +step:1535/1670 train_time:151008ms step_avg:98.38ms +step:1536/1670 train_time:151106ms step_avg:98.38ms +step:1537/1670 train_time:151202ms step_avg:98.38ms +step:1538/1670 train_time:151299ms step_avg:98.37ms +step:1539/1670 train_time:151396ms step_avg:98.37ms +step:1540/1670 train_time:151495ms step_avg:98.37ms +step:1541/1670 train_time:151593ms step_avg:98.37ms +step:1542/1670 train_time:151693ms step_avg:98.37ms +step:1543/1670 train_time:151793ms step_avg:98.38ms +step:1544/1670 train_time:151893ms step_avg:98.38ms +step:1545/1670 train_time:151991ms step_avg:98.38ms +step:1546/1670 train_time:152090ms step_avg:98.38ms +step:1547/1670 train_time:152188ms step_avg:98.38ms +step:1548/1670 train_time:152284ms step_avg:98.37ms +step:1549/1670 train_time:152382ms step_avg:98.37ms +step:1550/1670 train_time:152480ms step_avg:98.37ms +step:1551/1670 train_time:152578ms step_avg:98.37ms +step:1552/1670 train_time:152677ms step_avg:98.37ms +step:1553/1670 train_time:152778ms step_avg:98.38ms +step:1554/1670 train_time:152877ms step_avg:98.38ms +step:1555/1670 train_time:152979ms step_avg:98.38ms +step:1556/1670 train_time:153078ms step_avg:98.38ms +step:1557/1670 train_time:153179ms step_avg:98.38ms +step:1558/1670 train_time:153278ms step_avg:98.38ms +step:1559/1670 train_time:153377ms step_avg:98.38ms +step:1560/1670 train_time:153476ms step_avg:98.38ms +step:1561/1670 train_time:153574ms step_avg:98.38ms +step:1562/1670 train_time:153672ms step_avg:98.38ms +step:1563/1670 train_time:153771ms step_avg:98.38ms +step:1564/1670 train_time:153870ms step_avg:98.38ms +step:1565/1670 train_time:153969ms step_avg:98.38ms +step:1566/1670 train_time:154067ms step_avg:98.38ms +step:1567/1670 train_time:154164ms step_avg:98.38ms +step:1568/1670 train_time:154262ms step_avg:98.38ms +step:1569/1670 train_time:154361ms step_avg:98.38ms +step:1570/1670 train_time:154459ms step_avg:98.38ms +step:1571/1670 train_time:154557ms step_avg:98.38ms +step:1572/1670 train_time:154655ms step_avg:98.38ms +step:1573/1670 train_time:154754ms step_avg:98.38ms +step:1574/1670 
train_time:154853ms step_avg:98.38ms +step:1575/1670 train_time:154951ms step_avg:98.38ms +step:1576/1670 train_time:155050ms step_avg:98.38ms +step:1577/1670 train_time:155149ms step_avg:98.38ms +step:1578/1670 train_time:155246ms step_avg:98.38ms +step:1579/1670 train_time:155343ms step_avg:98.38ms +step:1580/1670 train_time:155441ms step_avg:98.38ms +step:1581/1670 train_time:155539ms step_avg:98.38ms +step:1582/1670 train_time:155637ms step_avg:98.38ms +step:1583/1670 train_time:155735ms step_avg:98.38ms +step:1584/1670 train_time:155832ms step_avg:98.38ms +step:1585/1670 train_time:155931ms step_avg:98.38ms +step:1586/1670 train_time:156029ms step_avg:98.38ms +step:1587/1670 train_time:156127ms step_avg:98.38ms +step:1588/1670 train_time:156225ms step_avg:98.38ms +step:1589/1670 train_time:156323ms step_avg:98.38ms +step:1590/1670 train_time:156420ms step_avg:98.38ms +step:1591/1670 train_time:156519ms step_avg:98.38ms +step:1592/1670 train_time:156618ms step_avg:98.38ms +step:1593/1670 train_time:156716ms step_avg:98.38ms +step:1594/1670 train_time:156814ms step_avg:98.38ms +step:1595/1670 train_time:156914ms step_avg:98.38ms +step:1596/1670 train_time:157012ms step_avg:98.38ms +step:1597/1670 train_time:157111ms step_avg:98.38ms +step:1598/1670 train_time:157211ms step_avg:98.38ms +step:1599/1670 train_time:157310ms step_avg:98.38ms +step:1600/1670 train_time:157407ms step_avg:98.38ms +step:1601/1670 train_time:157504ms step_avg:98.38ms +step:1602/1670 train_time:157601ms step_avg:98.38ms +step:1603/1670 train_time:157699ms step_avg:98.38ms +step:1604/1670 train_time:157797ms step_avg:98.38ms +step:1605/1670 train_time:157896ms step_avg:98.38ms +step:1606/1670 train_time:157996ms step_avg:98.38ms +step:1607/1670 train_time:158096ms step_avg:98.38ms +step:1608/1670 train_time:158196ms step_avg:98.38ms +step:1609/1670 train_time:158294ms step_avg:98.38ms +step:1610/1670 train_time:158394ms step_avg:98.38ms +step:1611/1670 train_time:158491ms step_avg:98.38ms +step:1612/1670 train_time:158589ms step_avg:98.38ms +step:1613/1670 train_time:158686ms step_avg:98.38ms +step:1614/1670 train_time:158783ms step_avg:98.38ms +step:1615/1670 train_time:158880ms step_avg:98.38ms +step:1616/1670 train_time:158979ms step_avg:98.38ms +step:1617/1670 train_time:159078ms step_avg:98.38ms +step:1618/1670 train_time:159178ms step_avg:98.38ms +step:1619/1670 train_time:159276ms step_avg:98.38ms +step:1620/1670 train_time:159375ms step_avg:98.38ms +step:1621/1670 train_time:159474ms step_avg:98.38ms +step:1622/1670 train_time:159572ms step_avg:98.38ms +step:1623/1670 train_time:159670ms step_avg:98.38ms +step:1624/1670 train_time:159768ms step_avg:98.38ms +step:1625/1670 train_time:159865ms step_avg:98.38ms +step:1625/1670 val_loss:3.2856 train_time:159961ms step_avg:98.44ms +step:1626/1670 train_time:159983ms step_avg:98.39ms +step:1627/1670 train_time:160067ms step_avg:98.38ms +step:1628/1670 train_time:160167ms step_avg:98.38ms +step:1629/1670 train_time:160265ms step_avg:98.38ms +step:1630/1670 train_time:160363ms step_avg:98.38ms +step:1631/1670 train_time:160462ms step_avg:98.38ms +step:1632/1670 train_time:160559ms step_avg:98.38ms +step:1633/1670 train_time:160657ms step_avg:98.38ms +step:1634/1670 train_time:160756ms step_avg:98.38ms +step:1635/1670 train_time:160854ms step_avg:98.38ms +step:1636/1670 train_time:160954ms step_avg:98.38ms +step:1637/1670 train_time:161054ms step_avg:98.38ms +step:1638/1670 train_time:161154ms step_avg:98.38ms +step:1639/1670 train_time:161253ms step_avg:98.38ms 
+step:1640/1670 train_time:161350ms step_avg:98.38ms +step:1641/1670 train_time:161447ms step_avg:98.38ms +step:1642/1670 train_time:161545ms step_avg:98.38ms +step:1643/1670 train_time:161642ms step_avg:98.38ms +step:1644/1670 train_time:161739ms step_avg:98.38ms +step:1645/1670 train_time:161837ms step_avg:98.38ms +step:1646/1670 train_time:161936ms step_avg:98.38ms +step:1647/1670 train_time:162036ms step_avg:98.38ms +step:1648/1670 train_time:162138ms step_avg:98.38ms +step:1649/1670 train_time:162237ms step_avg:98.39ms +step:1650/1670 train_time:162335ms step_avg:98.38ms +step:1651/1670 train_time:162434ms step_avg:98.39ms +step:1652/1670 train_time:162532ms step_avg:98.39ms +step:1653/1670 train_time:162629ms step_avg:98.38ms +step:1654/1670 train_time:162726ms step_avg:98.38ms +step:1655/1670 train_time:162823ms step_avg:98.38ms +step:1656/1670 train_time:162921ms step_avg:98.38ms +step:1657/1670 train_time:163021ms step_avg:98.38ms +step:1658/1670 train_time:163121ms step_avg:98.38ms +step:1659/1670 train_time:163222ms step_avg:98.39ms +step:1660/1670 train_time:163321ms step_avg:98.39ms +step:1661/1670 train_time:163422ms step_avg:98.39ms +step:1662/1670 train_time:163521ms step_avg:98.39ms +step:1663/1670 train_time:163619ms step_avg:98.39ms +step:1664/1670 train_time:163717ms step_avg:98.39ms +step:1665/1670 train_time:163815ms step_avg:98.39ms +step:1666/1670 train_time:163913ms step_avg:98.39ms +step:1667/1670 train_time:164010ms step_avg:98.39ms +step:1668/1670 train_time:164108ms step_avg:98.39ms +step:1669/1670 train_time:164206ms step_avg:98.39ms +step:1670/1670 train_time:164304ms step_avg:98.39ms +step:1670/1670 val_loss:3.2778 train_time:164403ms step_avg:98.44ms +peak memory allocated: 34000 MiB reserved: 49576 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt b/records/050925_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt new file mode 100644 index 000000000..3fe1483c6 --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + 
scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, 
a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we 
use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral 
norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
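+        # Reading guide for the step below (a descriptive summary of the code
+        # that follows, not a specification):
+        # 1) reduce-scatter: gradients are averaged across ranks with
+        #    ReduceOp.AVG, and each rank ends up owning the averaged gradient
+        #    of every world_size-th parameter in its group (the lists are
+        #    padded with dummy tensors so they divide evenly by world_size).
+        # 2) local update: the owning rank applies decoupled weight decay,
+        #    a Nesterov-style momentum blend via lerp_, Newton-Schulz
+        #    orthogonalization of the update, and a learning rate scaled by
+        #    max(1, rows/cols)**0.5 to account for the parameter's shape.
+        # 3) all-gather: updated parameters are redistributed to all ranks;
+        #    the futures are collected at the end so that communication for
+        #    one parameter group overlaps compute on the next.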
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
a given constant, instead of the default head_dim**-0.5, by @leloykun
+ # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+ self.attn_scale = 0.12
+
+ # sparse gated attention to enable a context-based no-op by @classiclarryd
+ self.attn_gate_dim = 12
+ self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
+ self.attn_gate.weight.detach().zero_()
+
+ def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int):
+ B, T = x.size(0), x.size(1) # batch size, sequence length
+ assert B == 1, "varlen sequences require B == 1"
+ assert T % 16 == 0
+
+ q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+ q, k = norm(q), norm(k) # QK norm @Grad62304977
+ q, k = self.rotary(q), self.rotary(k)
+ if ve is not None:
+ v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+ else: # skip mid-layer token value embeddings by @YouJiacheng
+ v = lambdas[0] * v
+
+ max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+ # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+ y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+ causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0))
+ y = y.view(B, T, self.num_heads, self.head_dim)
+ y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1)
+ y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+ y = F.linear(y, self.qkvo_w[3].type_as(y))
+ return y
+
+class MLP(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ hdim = 4 * dim
+ # make both matrices have the same shape because the optimizer sorts params by shape
+ # 2 matrices x 12 layers = 24 total, which is divisible by the 8-GPU world size
+ self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+ self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+ std = 0.5 * (dim ** -0.5)
+ bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+ with torch.no_grad():
+ self.c_fc.uniform_(-bound, bound)
+ self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+ def forward(self, x: Tensor):
+ x = F.linear(x, self.c_fc.T.type_as(x))
+ x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+ x = F.linear(x, self.c_proj.type_as(x))
+ return x
+
+class Block(nn.Module):
+ def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+ super().__init__()
+ # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+ self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+ self.mlp = MLP(dim)
+
+ def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
+ seqlens: Tensor, bm_size: int):
+ x = lambdas[0] * x + lambdas[1] * x0
+ if self.attn is not None:
+ x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size)
+ x = x + self.mlp(norm(x))
+ return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+ return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim:
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+ ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+ assert len(ve) == len(self.blocks)
+
+ long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+ bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+ assert len(bm_sizes) == len(self.blocks)
+
+ x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+ # U-net design by @brendanh0gan
+ skip_connections = []
+ skip_weights = self.scalars[:(len(self.blocks) // 2)]
+ lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+ sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+ n = len(self.blocks) // 2
+
+ for i in range(len(self.blocks)):
+ if i >= n:
+ x = x + skip_weights[i - n] * skip_connections.pop()
+ x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i])
+ if i < n:
+ skip_connections.append(x)
+
+ x = norm(x)
+ logits = self.lm_head(x).float()
+ # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+ logits = 30 * torch.sigmoid(logits / 7.5)
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+ return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+ header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+ assert header[1] == 1, "unsupported version"
+ num_tokens = int(header[2]) # number of tokens (claimed)
+ with file.open("rb", buffering=0) as f:
+ tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+ f.seek(256 * 4)
+ nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+ assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+ return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+ # Helper for getting sequences that start at the beginning of documents by @varunneal, based on work by @classiclarryd
+ def __init__(self, tokens: Tensor, world_size: int = 1):
+ # Precompute BOS positions once per shard
+ self.size = tokens.numel()
+ self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.i = 0
+ self.world_size = world_size
+
+ def next_batch(self, num_tokens_local: int, max_seq_len: int):
+ n = len(self.bos_idx)
+ starts = [[] for _ in range(self.world_size)]
+ ends = [[] for _ in range(self.world_size)]
+
+ idx = self.i
+ for r in range(self.world_size):
+ cur_len = 0
+ while cur_len <= num_tokens_local:
+ if idx >= n:
+ raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+ cur = self.bos_idx[idx]
+ starts[r].append(cur)
+ end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+ cur + max_seq_len,
+ cur + num_tokens_local - cur_len + 1)
+ ends[r].append(end)
+ cur_len += end - cur
+ idx += 1
+
+ assert cur_len == num_tokens_local + 1
+ self.i = idx
+
+ return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+ # align_to_bos: each
sequence begins with a BOS (Beginning of Sequence) token; sequences are truncated to max_seq_len
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+ num_tokens = num_tokens // grad_accum_steps
+
+ files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+ if not files:
+ raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+ file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+ tokens = _load_data_shard(next(file_iter))
+ finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+ pos = 0 # for unaligned case
+
+ while True:
+ num_tokens_local = num_tokens // world_size
+ max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+ if align_to_bos:
+ try:
+ seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+ start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+ except StopIteration:
+ # This shard is exhausted, load the next one in the next loop iteration.
+ tokens = _load_data_shard(next(file_iter))
+ finder = BOSFinder(tokens, world_size=world_size)
+ continue
+
+ buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+ _inputs = buf[:-1]
+ _targets = buf[1:]
+ end_idxs[-1] -= 1 # trim the last document by one token to account for the _targets offset
+ cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+ else:
+ if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+
+ pos_local = pos + rank * num_tokens_local
+ buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+ _inputs = buf[:-1].view(num_tokens_local, )
+ _targets = buf[1:].view(num_tokens_local, )
+
+ cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+ pos += num_tokens
+
+
+ _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+ _cum_lengths[0] = 0
+ _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+ new_params = yield (
+ _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+ _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+ _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+ )
+
+ if new_params is not None:
+ # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+ new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+ assert new_num_tokens % (world_size * grad_accum_steps) == 0, "new_num_tokens must be divisible by world_size * grad_accum_steps"
+ num_tokens = new_num_tokens
+ max_seq_len = new_max_seq_len
+ grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+ # data
+ train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+ val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+ val_tokens: int = 10485760 # how many tokens of validation data?
it's important to keep this fixed for consistent comparisons
+ train_batch_size: int = 2048 * 24 * 8
+ train_max_seq_len: int = 128 * 16
+ val_batch_size: int = 4 * 64 * 1024 * 8
+ # optimization
+ num_iterations: int = 1670 # number of iterations to run
+ cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+ # evaluation and logging
+ run_id: str = str(uuid.uuid4())
+ val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+ save_checkpoint: bool = False
+ # attention masking
+ block_size: int = 128
+ ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+ run_id = args.run_id
+ os.makedirs("logs", exist_ok=True)
+ logfile = f"logs/{run_id}.txt"
+ print(logfile)
+def print0(s, console=False):
+ if master_process:
+ with open(logfile, "a") as f:
+ if console:
+ print(s)
+ print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+ import subprocess # avoid top level import
+ return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+ vocab_size=50257,
+ num_layers=12,
+ num_heads=6,
+ model_dim=768,
+ max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+ if isinstance(m, nn.Embedding):
+ m.bfloat16()
+for param in model.parameters():
+ dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:14:13 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 128W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 44C P0 132W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 84065 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 84066 C /usr/bin/python3 610MiB | +| 0 N/A N/A 84067 C /usr/bin/python3 610MiB | +| 0 N/A N/A 84068 C /usr/bin/python3 610MiB | +| 0 N/A N/A 84069 C /usr/bin/python3 610MiB | +| 0 N/A N/A 84070 C /usr/bin/python3 610MiB | +| 0 N/A N/A 84071 C /usr/bin/python3 610MiB | +| 0 N/A N/A 84072 C /usr/bin/python3 610MiB | +| 1 N/A N/A 84066 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 84067 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 84068 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 84069 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 84070 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 84071 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 84072 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:387ms step_avg:387.37ms +step:2/1670 train_time:407ms step_avg:203.50ms +step:3/1670 train_time:480ms step_avg:160.12ms +step:4/1670 train_time:574ms step_avg:143.57ms +step:5/1670 train_time:669ms step_avg:133.77ms +step:6/1670 train_time:764ms step_avg:127.36ms +step:7/1670 train_time:859ms step_avg:122.71ms +step:8/1670 
train_time:954ms step_avg:119.23ms +step:9/1670 train_time:1049ms step_avg:116.52ms +step:10/1670 train_time:1144ms step_avg:114.45ms +step:11/1670 train_time:1240ms step_avg:112.72ms +step:12/1670 train_time:1338ms step_avg:111.54ms +step:13/1670 train_time:1437ms step_avg:110.50ms +step:14/1670 train_time:1533ms step_avg:109.50ms +step:15/1670 train_time:1629ms step_avg:108.58ms +step:16/1670 train_time:1724ms step_avg:107.76ms +step:17/1670 train_time:1821ms step_avg:107.12ms +step:18/1670 train_time:1916ms step_avg:106.45ms +step:19/1670 train_time:2012ms step_avg:105.87ms +step:20/1670 train_time:2107ms step_avg:105.34ms +step:21/1670 train_time:2203ms step_avg:104.88ms +step:22/1670 train_time:2299ms step_avg:104.51ms +step:23/1670 train_time:2396ms step_avg:104.19ms +step:24/1670 train_time:2493ms step_avg:103.86ms +step:25/1670 train_time:2588ms step_avg:103.54ms +step:26/1670 train_time:2684ms step_avg:103.24ms +step:27/1670 train_time:2780ms step_avg:102.96ms +step:28/1670 train_time:2876ms step_avg:102.70ms +step:29/1670 train_time:2971ms step_avg:102.46ms +step:30/1670 train_time:3067ms step_avg:102.22ms +step:31/1670 train_time:3163ms step_avg:102.02ms +step:32/1670 train_time:3259ms step_avg:101.86ms +step:33/1670 train_time:3356ms step_avg:101.71ms +step:34/1670 train_time:3452ms step_avg:101.54ms +step:35/1670 train_time:3548ms step_avg:101.38ms +step:36/1670 train_time:3645ms step_avg:101.25ms +step:37/1670 train_time:3741ms step_avg:101.12ms +step:38/1670 train_time:3838ms step_avg:100.99ms +step:39/1670 train_time:3933ms step_avg:100.84ms +step:40/1670 train_time:4028ms step_avg:100.70ms +step:41/1670 train_time:4124ms step_avg:100.59ms +step:42/1670 train_time:4220ms step_avg:100.48ms +step:43/1670 train_time:4316ms step_avg:100.37ms +step:44/1670 train_time:4411ms step_avg:100.26ms +step:45/1670 train_time:4508ms step_avg:100.17ms +step:46/1670 train_time:4605ms step_avg:100.10ms +step:47/1670 train_time:4701ms step_avg:100.03ms +step:48/1670 train_time:4797ms step_avg:99.94ms +step:49/1670 train_time:4893ms step_avg:99.85ms +step:50/1670 train_time:4988ms step_avg:99.77ms +step:51/1670 train_time:5085ms step_avg:99.70ms +step:52/1670 train_time:5181ms step_avg:99.64ms +step:53/1670 train_time:5277ms step_avg:99.56ms +step:54/1670 train_time:5373ms step_avg:99.49ms +step:55/1670 train_time:5469ms step_avg:99.44ms +step:56/1670 train_time:5566ms step_avg:99.40ms +step:57/1670 train_time:5663ms step_avg:99.35ms +step:58/1670 train_time:5759ms step_avg:99.29ms +step:59/1670 train_time:5855ms step_avg:99.24ms +step:60/1670 train_time:5950ms step_avg:99.17ms +step:61/1670 train_time:6047ms step_avg:99.13ms +step:62/1670 train_time:6143ms step_avg:99.08ms +step:63/1670 train_time:6240ms step_avg:99.04ms +step:64/1670 train_time:6335ms step_avg:98.99ms +step:65/1670 train_time:6430ms step_avg:98.93ms +step:66/1670 train_time:6526ms step_avg:98.88ms +step:67/1670 train_time:6622ms step_avg:98.84ms +step:68/1670 train_time:6718ms step_avg:98.80ms +step:69/1670 train_time:6814ms step_avg:98.75ms +step:70/1670 train_time:6909ms step_avg:98.70ms +step:71/1670 train_time:7005ms step_avg:98.66ms +step:72/1670 train_time:7101ms step_avg:98.63ms +step:73/1670 train_time:7197ms step_avg:98.60ms +step:74/1670 train_time:7294ms step_avg:98.56ms +step:75/1670 train_time:7389ms step_avg:98.52ms +step:76/1670 train_time:7485ms step_avg:98.49ms +step:77/1670 train_time:7581ms step_avg:98.45ms +step:78/1670 train_time:7676ms step_avg:98.42ms +step:79/1670 train_time:7772ms step_avg:98.38ms 
+step:80/1670 train_time:7868ms step_avg:98.35ms +step:81/1670 train_time:7964ms step_avg:98.32ms +step:82/1670 train_time:8060ms step_avg:98.29ms +step:83/1670 train_time:8156ms step_avg:98.27ms +step:84/1670 train_time:8252ms step_avg:98.24ms +step:85/1670 train_time:8348ms step_avg:98.21ms +step:86/1670 train_time:8444ms step_avg:98.19ms +step:87/1670 train_time:8540ms step_avg:98.17ms +step:88/1670 train_time:8637ms step_avg:98.14ms +step:89/1670 train_time:8732ms step_avg:98.11ms +step:90/1670 train_time:8828ms step_avg:98.09ms +step:91/1670 train_time:8924ms step_avg:98.06ms +step:92/1670 train_time:9020ms step_avg:98.04ms +step:93/1670 train_time:9115ms step_avg:98.01ms +step:94/1670 train_time:9212ms step_avg:98.00ms +step:95/1670 train_time:9308ms step_avg:97.97ms +step:96/1670 train_time:9404ms step_avg:97.96ms +step:97/1670 train_time:9499ms step_avg:97.93ms +step:98/1670 train_time:9595ms step_avg:97.91ms +step:99/1670 train_time:9691ms step_avg:97.89ms +step:100/1670 train_time:9787ms step_avg:97.87ms +step:101/1670 train_time:9883ms step_avg:97.85ms +step:102/1670 train_time:9979ms step_avg:97.83ms +step:103/1670 train_time:10075ms step_avg:97.82ms +step:104/1670 train_time:10170ms step_avg:97.79ms +step:105/1670 train_time:10267ms step_avg:97.78ms +step:106/1670 train_time:10364ms step_avg:97.77ms +step:107/1670 train_time:10460ms step_avg:97.76ms +step:108/1670 train_time:10555ms step_avg:97.73ms +step:109/1670 train_time:10651ms step_avg:97.71ms +step:110/1670 train_time:10746ms step_avg:97.69ms +step:111/1670 train_time:10843ms step_avg:97.68ms +step:112/1670 train_time:10939ms step_avg:97.67ms +step:113/1670 train_time:11035ms step_avg:97.65ms +step:114/1670 train_time:11130ms step_avg:97.63ms +step:115/1670 train_time:11226ms step_avg:97.62ms +step:116/1670 train_time:11323ms step_avg:97.61ms +step:117/1670 train_time:11419ms step_avg:97.60ms +step:118/1670 train_time:11514ms step_avg:97.58ms +step:119/1670 train_time:11609ms step_avg:97.56ms +step:120/1670 train_time:11706ms step_avg:97.55ms +step:121/1670 train_time:11801ms step_avg:97.53ms +step:122/1670 train_time:11897ms step_avg:97.52ms +step:123/1670 train_time:11992ms step_avg:97.50ms +step:124/1670 train_time:12089ms step_avg:97.49ms +step:125/1670 train_time:12184ms step_avg:97.47ms +step:125/1670 val_loss:4.3072 train_time:12279ms step_avg:98.23ms +step:126/1670 train_time:12301ms step_avg:97.63ms +step:127/1670 train_time:12384ms step_avg:97.51ms +step:128/1670 train_time:12490ms step_avg:97.58ms +step:129/1670 train_time:12587ms step_avg:97.57ms +step:130/1670 train_time:12682ms step_avg:97.55ms +step:131/1670 train_time:12777ms step_avg:97.54ms +step:132/1670 train_time:12873ms step_avg:97.52ms +step:133/1670 train_time:12967ms step_avg:97.49ms +step:134/1670 train_time:13061ms step_avg:97.47ms +step:135/1670 train_time:13156ms step_avg:97.45ms +step:136/1670 train_time:13251ms step_avg:97.44ms +step:137/1670 train_time:13348ms step_avg:97.43ms +step:138/1670 train_time:13445ms step_avg:97.43ms +step:139/1670 train_time:13542ms step_avg:97.42ms +step:140/1670 train_time:13638ms step_avg:97.41ms +step:141/1670 train_time:13734ms step_avg:97.40ms +step:142/1670 train_time:13830ms step_avg:97.39ms +step:143/1670 train_time:13924ms step_avg:97.37ms +step:144/1670 train_time:14019ms step_avg:97.35ms +step:145/1670 train_time:14114ms step_avg:97.34ms +step:146/1670 train_time:14209ms step_avg:97.32ms +step:147/1670 train_time:14304ms step_avg:97.31ms +step:148/1670 train_time:14400ms step_avg:97.30ms 
+step:149/1670 train_time:14498ms step_avg:97.30ms +step:150/1670 train_time:14595ms step_avg:97.30ms +step:151/1670 train_time:14691ms step_avg:97.29ms +step:152/1670 train_time:14786ms step_avg:97.28ms +step:153/1670 train_time:14881ms step_avg:97.26ms +step:154/1670 train_time:14976ms step_avg:97.25ms +step:155/1670 train_time:15071ms step_avg:97.23ms +step:156/1670 train_time:15166ms step_avg:97.22ms +step:157/1670 train_time:15260ms step_avg:97.20ms +step:158/1670 train_time:15356ms step_avg:97.19ms +step:159/1670 train_time:15453ms step_avg:97.19ms +step:160/1670 train_time:15549ms step_avg:97.18ms +step:161/1670 train_time:15646ms step_avg:97.18ms +step:162/1670 train_time:15740ms step_avg:97.16ms +step:163/1670 train_time:15837ms step_avg:97.16ms +step:164/1670 train_time:15932ms step_avg:97.15ms +step:165/1670 train_time:16027ms step_avg:97.13ms +step:166/1670 train_time:16122ms step_avg:97.12ms +step:167/1670 train_time:16217ms step_avg:97.11ms +step:168/1670 train_time:16313ms step_avg:97.10ms +step:169/1670 train_time:16408ms step_avg:97.09ms +step:170/1670 train_time:16503ms step_avg:97.08ms +step:171/1670 train_time:16599ms step_avg:97.07ms +step:172/1670 train_time:16695ms step_avg:97.07ms +step:173/1670 train_time:16792ms step_avg:97.06ms +step:174/1670 train_time:16887ms step_avg:97.05ms +step:175/1670 train_time:16982ms step_avg:97.04ms +step:176/1670 train_time:17077ms step_avg:97.03ms +step:177/1670 train_time:17173ms step_avg:97.02ms +step:178/1670 train_time:17268ms step_avg:97.01ms +step:179/1670 train_time:17363ms step_avg:97.00ms +step:180/1670 train_time:17459ms step_avg:96.99ms +step:181/1670 train_time:17554ms step_avg:96.99ms +step:182/1670 train_time:17651ms step_avg:96.98ms +step:183/1670 train_time:17746ms step_avg:96.97ms +step:184/1670 train_time:17841ms step_avg:96.96ms +step:185/1670 train_time:17937ms step_avg:96.96ms +step:186/1670 train_time:18032ms step_avg:96.95ms +step:187/1670 train_time:18128ms step_avg:96.94ms +step:188/1670 train_time:18222ms step_avg:96.93ms +step:189/1670 train_time:18318ms step_avg:96.92ms +step:190/1670 train_time:18413ms step_avg:96.91ms +step:191/1670 train_time:18509ms step_avg:96.91ms +step:192/1670 train_time:18604ms step_avg:96.90ms +step:193/1670 train_time:18700ms step_avg:96.89ms +step:194/1670 train_time:18796ms step_avg:96.89ms +step:195/1670 train_time:18892ms step_avg:96.88ms +step:196/1670 train_time:18988ms step_avg:96.88ms +step:197/1670 train_time:19083ms step_avg:96.87ms +step:198/1670 train_time:19178ms step_avg:96.86ms +step:199/1670 train_time:19273ms step_avg:96.85ms +step:200/1670 train_time:19368ms step_avg:96.84ms +step:201/1670 train_time:19464ms step_avg:96.84ms +step:202/1670 train_time:19559ms step_avg:96.83ms +step:203/1670 train_time:19655ms step_avg:96.82ms +step:204/1670 train_time:19751ms step_avg:96.82ms +step:205/1670 train_time:19846ms step_avg:96.81ms +step:206/1670 train_time:19941ms step_avg:96.80ms +step:207/1670 train_time:20037ms step_avg:96.80ms +step:208/1670 train_time:20133ms step_avg:96.79ms +step:209/1670 train_time:20229ms step_avg:96.79ms +step:210/1670 train_time:20324ms step_avg:96.78ms +step:211/1670 train_time:20419ms step_avg:96.77ms +step:212/1670 train_time:20515ms step_avg:96.77ms +step:213/1670 train_time:20798ms step_avg:97.64ms +step:214/1670 train_time:20918ms step_avg:97.75ms +step:215/1670 train_time:21012ms step_avg:97.73ms +step:216/1670 train_time:21106ms step_avg:97.71ms +step:217/1670 train_time:21200ms step_avg:97.70ms +step:218/1670 train_time:21295ms 
step_avg:97.68ms +step:219/1670 train_time:21389ms step_avg:97.67ms +step:220/1670 train_time:21483ms step_avg:97.65ms +step:221/1670 train_time:21578ms step_avg:97.64ms +step:222/1670 train_time:21673ms step_avg:97.63ms +step:223/1670 train_time:21771ms step_avg:97.63ms +step:224/1670 train_time:21870ms step_avg:97.63ms +step:225/1670 train_time:21969ms step_avg:97.64ms +step:226/1670 train_time:22065ms step_avg:97.63ms +step:227/1670 train_time:22159ms step_avg:97.62ms +step:228/1670 train_time:22254ms step_avg:97.60ms +step:229/1670 train_time:22349ms step_avg:97.60ms +step:230/1670 train_time:22443ms step_avg:97.58ms +step:231/1670 train_time:22538ms step_avg:97.57ms +step:232/1670 train_time:22633ms step_avg:97.55ms +step:233/1670 train_time:22728ms step_avg:97.55ms +step:234/1670 train_time:22824ms step_avg:97.54ms +step:235/1670 train_time:22920ms step_avg:97.53ms +step:236/1670 train_time:23017ms step_avg:97.53ms +step:237/1670 train_time:23114ms step_avg:97.53ms +step:238/1670 train_time:23211ms step_avg:97.52ms +step:239/1670 train_time:23306ms step_avg:97.52ms +step:240/1670 train_time:23400ms step_avg:97.50ms +step:241/1670 train_time:23495ms step_avg:97.49ms +step:242/1670 train_time:23590ms step_avg:97.48ms +step:243/1670 train_time:23685ms step_avg:97.47ms +step:244/1670 train_time:23781ms step_avg:97.46ms +step:245/1670 train_time:23877ms step_avg:97.46ms +step:246/1670 train_time:23973ms step_avg:97.45ms +step:247/1670 train_time:24071ms step_avg:97.45ms +step:248/1670 train_time:24167ms step_avg:97.45ms +step:249/1670 train_time:24263ms step_avg:97.44ms +step:250/1670 train_time:24358ms step_avg:97.43ms +step:250/1670 val_loss:3.9738 train_time:24453ms step_avg:97.81ms +step:251/1670 train_time:24474ms step_avg:97.51ms +step:252/1670 train_time:24557ms step_avg:97.45ms +step:253/1670 train_time:24654ms step_avg:97.45ms +step:254/1670 train_time:24751ms step_avg:97.45ms +step:255/1670 train_time:24847ms step_avg:97.44ms +step:256/1670 train_time:24942ms step_avg:97.43ms +step:257/1670 train_time:25037ms step_avg:97.42ms +step:258/1670 train_time:25131ms step_avg:97.41ms +step:259/1670 train_time:25226ms step_avg:97.40ms +step:260/1670 train_time:25321ms step_avg:97.39ms +step:261/1670 train_time:25416ms step_avg:97.38ms +step:262/1670 train_time:25513ms step_avg:97.38ms +step:263/1670 train_time:25610ms step_avg:97.38ms +step:264/1670 train_time:25707ms step_avg:97.38ms +step:265/1670 train_time:25804ms step_avg:97.37ms +step:266/1670 train_time:25899ms step_avg:97.36ms +step:267/1670 train_time:25994ms step_avg:97.35ms +step:268/1670 train_time:26088ms step_avg:97.34ms +step:269/1670 train_time:26183ms step_avg:97.34ms +step:270/1670 train_time:26278ms step_avg:97.33ms +step:271/1670 train_time:26372ms step_avg:97.31ms +step:272/1670 train_time:26469ms step_avg:97.31ms +step:273/1670 train_time:26565ms step_avg:97.31ms +step:274/1670 train_time:26662ms step_avg:97.31ms +step:275/1670 train_time:26758ms step_avg:97.30ms +step:276/1670 train_time:26853ms step_avg:97.29ms +step:277/1670 train_time:26948ms step_avg:97.29ms +step:278/1670 train_time:27044ms step_avg:97.28ms +step:279/1670 train_time:27140ms step_avg:97.27ms +step:280/1670 train_time:27234ms step_avg:97.26ms +step:281/1670 train_time:27329ms step_avg:97.25ms +step:282/1670 train_time:27424ms step_avg:97.25ms +step:283/1670 train_time:27521ms step_avg:97.25ms +step:284/1670 train_time:27616ms step_avg:97.24ms +step:285/1670 train_time:27712ms step_avg:97.24ms +step:286/1670 train_time:27809ms step_avg:97.23ms 
+step:287/1670 train_time:27904ms step_avg:97.23ms +step:288/1670 train_time:28000ms step_avg:97.22ms +step:289/1670 train_time:28095ms step_avg:97.21ms +step:290/1670 train_time:28190ms step_avg:97.21ms +step:291/1670 train_time:28285ms step_avg:97.20ms +step:292/1670 train_time:28381ms step_avg:97.19ms +step:293/1670 train_time:28476ms step_avg:97.19ms +step:294/1670 train_time:28571ms step_avg:97.18ms +step:295/1670 train_time:28667ms step_avg:97.17ms +step:296/1670 train_time:28763ms step_avg:97.17ms +step:297/1670 train_time:28859ms step_avg:97.17ms +step:298/1670 train_time:28954ms step_avg:97.16ms +step:299/1670 train_time:29050ms step_avg:97.16ms +step:300/1670 train_time:29145ms step_avg:97.15ms +step:301/1670 train_time:29241ms step_avg:97.15ms +step:302/1670 train_time:29336ms step_avg:97.14ms +step:303/1670 train_time:29431ms step_avg:97.13ms +step:304/1670 train_time:29527ms step_avg:97.13ms +step:305/1670 train_time:29623ms step_avg:97.12ms +step:306/1670 train_time:29718ms step_avg:97.12ms +step:307/1670 train_time:29813ms step_avg:97.11ms +step:308/1670 train_time:29909ms step_avg:97.11ms +step:309/1670 train_time:30006ms step_avg:97.11ms +step:310/1670 train_time:30101ms step_avg:97.10ms +step:311/1670 train_time:30196ms step_avg:97.09ms +step:312/1670 train_time:30291ms step_avg:97.09ms +step:313/1670 train_time:30386ms step_avg:97.08ms +step:314/1670 train_time:30482ms step_avg:97.08ms +step:315/1670 train_time:30577ms step_avg:97.07ms +step:316/1670 train_time:30672ms step_avg:97.06ms +step:317/1670 train_time:30768ms step_avg:97.06ms +step:318/1670 train_time:30864ms step_avg:97.06ms +step:319/1670 train_time:30960ms step_avg:97.05ms +step:320/1670 train_time:31057ms step_avg:97.05ms +step:321/1670 train_time:31152ms step_avg:97.05ms +step:322/1670 train_time:31248ms step_avg:97.04ms +step:323/1670 train_time:31344ms step_avg:97.04ms +step:324/1670 train_time:31439ms step_avg:97.03ms +step:325/1670 train_time:31533ms step_avg:97.03ms +step:326/1670 train_time:31629ms step_avg:97.02ms +step:327/1670 train_time:31725ms step_avg:97.02ms +step:328/1670 train_time:31820ms step_avg:97.01ms +step:329/1670 train_time:31916ms step_avg:97.01ms +step:330/1670 train_time:32011ms step_avg:97.00ms +step:331/1670 train_time:32107ms step_avg:97.00ms +step:332/1670 train_time:32204ms step_avg:97.00ms +step:333/1670 train_time:32300ms step_avg:97.00ms +step:334/1670 train_time:32396ms step_avg:96.99ms +step:335/1670 train_time:32491ms step_avg:96.99ms +step:336/1670 train_time:32586ms step_avg:96.98ms +step:337/1670 train_time:32681ms step_avg:96.98ms +step:338/1670 train_time:32777ms step_avg:96.97ms +step:339/1670 train_time:32872ms step_avg:96.97ms +step:340/1670 train_time:32969ms step_avg:96.97ms +step:341/1670 train_time:33064ms step_avg:96.96ms +step:342/1670 train_time:33160ms step_avg:96.96ms +step:343/1670 train_time:33256ms step_avg:96.96ms +step:344/1670 train_time:33351ms step_avg:96.95ms +step:345/1670 train_time:33446ms step_avg:96.95ms +step:346/1670 train_time:33542ms step_avg:96.94ms +step:347/1670 train_time:33637ms step_avg:96.94ms +step:348/1670 train_time:33731ms step_avg:96.93ms +step:349/1670 train_time:33826ms step_avg:96.92ms +step:350/1670 train_time:33923ms step_avg:96.92ms +step:351/1670 train_time:34019ms step_avg:96.92ms +step:352/1670 train_time:34114ms step_avg:96.91ms +step:353/1670 train_time:34210ms step_avg:96.91ms +step:354/1670 train_time:34306ms step_avg:96.91ms +step:355/1670 train_time:34403ms step_avg:96.91ms +step:356/1670 train_time:34499ms 
step_avg:96.91ms +step:357/1670 train_time:34593ms step_avg:96.90ms +step:358/1670 train_time:34689ms step_avg:96.90ms +step:359/1670 train_time:34784ms step_avg:96.89ms +step:360/1670 train_time:34879ms step_avg:96.89ms +step:361/1670 train_time:34975ms step_avg:96.88ms +step:362/1670 train_time:35070ms step_avg:96.88ms +step:363/1670 train_time:35166ms step_avg:96.88ms +step:364/1670 train_time:35263ms step_avg:96.88ms +step:365/1670 train_time:35359ms step_avg:96.87ms +step:366/1670 train_time:35454ms step_avg:96.87ms +step:367/1670 train_time:35549ms step_avg:96.86ms +step:368/1670 train_time:35645ms step_avg:96.86ms +step:369/1670 train_time:35740ms step_avg:96.86ms +step:370/1670 train_time:35836ms step_avg:96.85ms +step:371/1670 train_time:35931ms step_avg:96.85ms +step:372/1670 train_time:36027ms step_avg:96.85ms +step:373/1670 train_time:36123ms step_avg:96.84ms +step:374/1670 train_time:36219ms step_avg:96.84ms +step:375/1670 train_time:36315ms step_avg:96.84ms +step:375/1670 val_loss:3.8193 train_time:36410ms step_avg:97.09ms +step:376/1670 train_time:36431ms step_avg:96.89ms +step:377/1670 train_time:36513ms step_avg:96.85ms +step:378/1670 train_time:36614ms step_avg:96.86ms +step:379/1670 train_time:36710ms step_avg:96.86ms +step:380/1670 train_time:36805ms step_avg:96.86ms +step:381/1670 train_time:36900ms step_avg:96.85ms +step:382/1670 train_time:36994ms step_avg:96.84ms +step:383/1670 train_time:37089ms step_avg:96.84ms +step:384/1670 train_time:37184ms step_avg:96.83ms +step:385/1670 train_time:37279ms step_avg:96.83ms +step:386/1670 train_time:37375ms step_avg:96.83ms +step:387/1670 train_time:37474ms step_avg:96.83ms +step:388/1670 train_time:37573ms step_avg:96.84ms +step:389/1670 train_time:37671ms step_avg:96.84ms +step:390/1670 train_time:37767ms step_avg:96.84ms +step:391/1670 train_time:37862ms step_avg:96.83ms +step:392/1670 train_time:37957ms step_avg:96.83ms +step:393/1670 train_time:38052ms step_avg:96.82ms +step:394/1670 train_time:38147ms step_avg:96.82ms +step:395/1670 train_time:38242ms step_avg:96.81ms +step:396/1670 train_time:38336ms step_avg:96.81ms +step:397/1670 train_time:38433ms step_avg:96.81ms +step:398/1670 train_time:38530ms step_avg:96.81ms +step:399/1670 train_time:38626ms step_avg:96.81ms +step:400/1670 train_time:38723ms step_avg:96.81ms +step:401/1670 train_time:38818ms step_avg:96.80ms +step:402/1670 train_time:38913ms step_avg:96.80ms +step:403/1670 train_time:39008ms step_avg:96.79ms +step:404/1670 train_time:39103ms step_avg:96.79ms +step:405/1670 train_time:39198ms step_avg:96.79ms +step:406/1670 train_time:39293ms step_avg:96.78ms +step:407/1670 train_time:39389ms step_avg:96.78ms +step:408/1670 train_time:39484ms step_avg:96.78ms +step:409/1670 train_time:39580ms step_avg:96.77ms +step:410/1670 train_time:39678ms step_avg:96.77ms +step:411/1670 train_time:39774ms step_avg:96.77ms +step:412/1670 train_time:39871ms step_avg:96.77ms +step:413/1670 train_time:39967ms step_avg:96.77ms +step:414/1670 train_time:40062ms step_avg:96.77ms +step:415/1670 train_time:40157ms step_avg:96.76ms +step:416/1670 train_time:40252ms step_avg:96.76ms +step:417/1670 train_time:40348ms step_avg:96.76ms +step:418/1670 train_time:40443ms step_avg:96.75ms +step:419/1670 train_time:40539ms step_avg:96.75ms +step:420/1670 train_time:40636ms step_avg:96.75ms +step:421/1670 train_time:40732ms step_avg:96.75ms +step:422/1670 train_time:40827ms step_avg:96.75ms +step:423/1670 train_time:40923ms step_avg:96.75ms +step:424/1670 train_time:41018ms step_avg:96.74ms 
+step:425/1670 train_time:41298ms step_avg:97.17ms +step:426/1670 train_time:41486ms step_avg:97.39ms +step:427/1670 train_time:41580ms step_avg:97.38ms +step:428/1670 train_time:41675ms step_avg:97.37ms +step:429/1670 train_time:41770ms step_avg:97.36ms +step:430/1670 train_time:41863ms step_avg:97.36ms +step:431/1670 train_time:41958ms step_avg:97.35ms +step:432/1670 train_time:42053ms step_avg:97.35ms +step:433/1670 train_time:42148ms step_avg:97.34ms +step:434/1670 train_time:42243ms step_avg:97.33ms +step:435/1670 train_time:42339ms step_avg:97.33ms +step:436/1670 train_time:42440ms step_avg:97.34ms +step:437/1670 train_time:42538ms step_avg:97.34ms +step:438/1670 train_time:42635ms step_avg:97.34ms +step:439/1670 train_time:42731ms step_avg:97.34ms +step:440/1670 train_time:42827ms step_avg:97.33ms +step:441/1670 train_time:42921ms step_avg:97.33ms +step:442/1670 train_time:43016ms step_avg:97.32ms +step:443/1670 train_time:43111ms step_avg:97.32ms +step:444/1670 train_time:43206ms step_avg:97.31ms +step:445/1670 train_time:43301ms step_avg:97.31ms +step:446/1670 train_time:43397ms step_avg:97.30ms +step:447/1670 train_time:43494ms step_avg:97.30ms +step:448/1670 train_time:43591ms step_avg:97.30ms +step:449/1670 train_time:43688ms step_avg:97.30ms +step:450/1670 train_time:43783ms step_avg:97.30ms +step:451/1670 train_time:43879ms step_avg:97.29ms +step:452/1670 train_time:43975ms step_avg:97.29ms +step:453/1670 train_time:44070ms step_avg:97.29ms +step:454/1670 train_time:44166ms step_avg:97.28ms +step:455/1670 train_time:44260ms step_avg:97.28ms +step:456/1670 train_time:44355ms step_avg:97.27ms +step:457/1670 train_time:44452ms step_avg:97.27ms +step:458/1670 train_time:44549ms step_avg:97.27ms +step:459/1670 train_time:44645ms step_avg:97.27ms +step:460/1670 train_time:44740ms step_avg:97.26ms +step:461/1670 train_time:44836ms step_avg:97.26ms +step:462/1670 train_time:44933ms step_avg:97.26ms +step:463/1670 train_time:45029ms step_avg:97.25ms +step:464/1670 train_time:45124ms step_avg:97.25ms +step:465/1670 train_time:45219ms step_avg:97.24ms +step:466/1670 train_time:45314ms step_avg:97.24ms +step:467/1670 train_time:45410ms step_avg:97.24ms +step:468/1670 train_time:45506ms step_avg:97.24ms +step:469/1670 train_time:45601ms step_avg:97.23ms +step:470/1670 train_time:45697ms step_avg:97.23ms +step:471/1670 train_time:45793ms step_avg:97.23ms +step:472/1670 train_time:45890ms step_avg:97.22ms +step:473/1670 train_time:45985ms step_avg:97.22ms +step:474/1670 train_time:46080ms step_avg:97.22ms +step:475/1670 train_time:46176ms step_avg:97.21ms +step:476/1670 train_time:46272ms step_avg:97.21ms +step:477/1670 train_time:46368ms step_avg:97.21ms +step:478/1670 train_time:46464ms step_avg:97.21ms +step:479/1670 train_time:46559ms step_avg:97.20ms +step:480/1670 train_time:46655ms step_avg:97.20ms +step:481/1670 train_time:46751ms step_avg:97.19ms +step:482/1670 train_time:46847ms step_avg:97.19ms +step:483/1670 train_time:46943ms step_avg:97.19ms +step:484/1670 train_time:47038ms step_avg:97.19ms +step:485/1670 train_time:47134ms step_avg:97.18ms +step:486/1670 train_time:47230ms step_avg:97.18ms +step:487/1670 train_time:47325ms step_avg:97.18ms +step:488/1670 train_time:47420ms step_avg:97.17ms +step:489/1670 train_time:47516ms step_avg:97.17ms +step:490/1670 train_time:47612ms step_avg:97.17ms +step:491/1670 train_time:47709ms step_avg:97.17ms +step:492/1670 train_time:47804ms step_avg:97.16ms +step:493/1670 train_time:47899ms step_avg:97.16ms +step:494/1670 train_time:47995ms 
step_avg:97.16ms +step:495/1670 train_time:48091ms step_avg:97.15ms +step:496/1670 train_time:48186ms step_avg:97.15ms +step:497/1670 train_time:48281ms step_avg:97.15ms +step:498/1670 train_time:48378ms step_avg:97.15ms +step:499/1670 train_time:48475ms step_avg:97.14ms +step:500/1670 train_time:48570ms step_avg:97.14ms +step:500/1670 val_loss:3.7143 train_time:48665ms step_avg:97.33ms +step:501/1670 train_time:48686ms step_avg:97.18ms +step:502/1670 train_time:48768ms step_avg:97.15ms +step:503/1670 train_time:48871ms step_avg:97.16ms +step:504/1670 train_time:48969ms step_avg:97.16ms +step:505/1670 train_time:49064ms step_avg:97.16ms +step:506/1670 train_time:49159ms step_avg:97.15ms +step:507/1670 train_time:49253ms step_avg:97.15ms +step:508/1670 train_time:49348ms step_avg:97.14ms +step:509/1670 train_time:49443ms step_avg:97.14ms +step:510/1670 train_time:49538ms step_avg:97.13ms +step:511/1670 train_time:49634ms step_avg:97.13ms +step:512/1670 train_time:49731ms step_avg:97.13ms +step:513/1670 train_time:49829ms step_avg:97.13ms +step:514/1670 train_time:49927ms step_avg:97.14ms +step:515/1670 train_time:50024ms step_avg:97.13ms +step:516/1670 train_time:50120ms step_avg:97.13ms +step:517/1670 train_time:50214ms step_avg:97.13ms +step:518/1670 train_time:50309ms step_avg:97.12ms +step:519/1670 train_time:50404ms step_avg:97.12ms +step:520/1670 train_time:50499ms step_avg:97.11ms +step:521/1670 train_time:50595ms step_avg:97.11ms +step:522/1670 train_time:50691ms step_avg:97.11ms +step:523/1670 train_time:50788ms step_avg:97.11ms +step:524/1670 train_time:50885ms step_avg:97.11ms +step:525/1670 train_time:50982ms step_avg:97.11ms +step:526/1670 train_time:51077ms step_avg:97.11ms +step:527/1670 train_time:51173ms step_avg:97.10ms +step:528/1670 train_time:51268ms step_avg:97.10ms +step:529/1670 train_time:51363ms step_avg:97.09ms +step:530/1670 train_time:51459ms step_avg:97.09ms +step:531/1670 train_time:51554ms step_avg:97.09ms +step:532/1670 train_time:51649ms step_avg:97.08ms +step:533/1670 train_time:51745ms step_avg:97.08ms +step:534/1670 train_time:51842ms step_avg:97.08ms +step:535/1670 train_time:51938ms step_avg:97.08ms +step:536/1670 train_time:52034ms step_avg:97.08ms +step:537/1670 train_time:52129ms step_avg:97.07ms +step:538/1670 train_time:52225ms step_avg:97.07ms +step:539/1670 train_time:52320ms step_avg:97.07ms +step:540/1670 train_time:52416ms step_avg:97.07ms +step:541/1670 train_time:52512ms step_avg:97.06ms +step:542/1670 train_time:52608ms step_avg:97.06ms +step:543/1670 train_time:52703ms step_avg:97.06ms +step:544/1670 train_time:52799ms step_avg:97.06ms +step:545/1670 train_time:52895ms step_avg:97.06ms +step:546/1670 train_time:52991ms step_avg:97.05ms +step:547/1670 train_time:53087ms step_avg:97.05ms +step:548/1670 train_time:53183ms step_avg:97.05ms +step:549/1670 train_time:53279ms step_avg:97.05ms +step:550/1670 train_time:53374ms step_avg:97.04ms +step:551/1670 train_time:53470ms step_avg:97.04ms +step:552/1670 train_time:53566ms step_avg:97.04ms +step:553/1670 train_time:53662ms step_avg:97.04ms +step:554/1670 train_time:53759ms step_avg:97.04ms +step:555/1670 train_time:53855ms step_avg:97.04ms +step:556/1670 train_time:53951ms step_avg:97.03ms +step:557/1670 train_time:54047ms step_avg:97.03ms +step:558/1670 train_time:54143ms step_avg:97.03ms +step:559/1670 train_time:54240ms step_avg:97.03ms +step:560/1670 train_time:54337ms step_avg:97.03ms +step:561/1670 train_time:54434ms step_avg:97.03ms +step:562/1670 train_time:54530ms step_avg:97.03ms 
+step:563/1670 train_time:54626ms step_avg:97.03ms +step:564/1670 train_time:54723ms step_avg:97.03ms +step:565/1670 train_time:54821ms step_avg:97.03ms +step:566/1670 train_time:54918ms step_avg:97.03ms +step:567/1670 train_time:55017ms step_avg:97.03ms +step:568/1670 train_time:55113ms step_avg:97.03ms +step:569/1670 train_time:55209ms step_avg:97.03ms +step:570/1670 train_time:55306ms step_avg:97.03ms +step:571/1670 train_time:55404ms step_avg:97.03ms +step:572/1670 train_time:55501ms step_avg:97.03ms +step:573/1670 train_time:55598ms step_avg:97.03ms +step:574/1670 train_time:55695ms step_avg:97.03ms +step:575/1670 train_time:55792ms step_avg:97.03ms +step:576/1670 train_time:55889ms step_avg:97.03ms +step:577/1670 train_time:55987ms step_avg:97.03ms +step:578/1670 train_time:56085ms step_avg:97.03ms +step:579/1670 train_time:56183ms step_avg:97.04ms +step:580/1670 train_time:56281ms step_avg:97.04ms +step:581/1670 train_time:56379ms step_avg:97.04ms +step:582/1670 train_time:56476ms step_avg:97.04ms +step:583/1670 train_time:56572ms step_avg:97.04ms +step:584/1670 train_time:56669ms step_avg:97.04ms +step:585/1670 train_time:56767ms step_avg:97.04ms +step:586/1670 train_time:56864ms step_avg:97.04ms +step:587/1670 train_time:56962ms step_avg:97.04ms +step:588/1670 train_time:57060ms step_avg:97.04ms +step:589/1670 train_time:57158ms step_avg:97.04ms +step:590/1670 train_time:57255ms step_avg:97.04ms +step:591/1670 train_time:57352ms step_avg:97.04ms +step:592/1670 train_time:57449ms step_avg:97.04ms +step:593/1670 train_time:57546ms step_avg:97.04ms +step:594/1670 train_time:57643ms step_avg:97.04ms +step:595/1670 train_time:57741ms step_avg:97.04ms +step:596/1670 train_time:57838ms step_avg:97.04ms +step:597/1670 train_time:57934ms step_avg:97.04ms +step:598/1670 train_time:58031ms step_avg:97.04ms +step:599/1670 train_time:58128ms step_avg:97.04ms +step:600/1670 train_time:58227ms step_avg:97.05ms +step:601/1670 train_time:58324ms step_avg:97.04ms +step:602/1670 train_time:58421ms step_avg:97.04ms +step:603/1670 train_time:58518ms step_avg:97.05ms +step:604/1670 train_time:58616ms step_avg:97.05ms +step:605/1670 train_time:58712ms step_avg:97.04ms +step:606/1670 train_time:58809ms step_avg:97.04ms +step:607/1670 train_time:58906ms step_avg:97.05ms +step:608/1670 train_time:59004ms step_avg:97.05ms +step:609/1670 train_time:59102ms step_avg:97.05ms +step:610/1670 train_time:59199ms step_avg:97.05ms +step:611/1670 train_time:59296ms step_avg:97.05ms +step:612/1670 train_time:59393ms step_avg:97.05ms +step:613/1670 train_time:59490ms step_avg:97.05ms +step:614/1670 train_time:59587ms step_avg:97.05ms +step:615/1670 train_time:59685ms step_avg:97.05ms +step:616/1670 train_time:59782ms step_avg:97.05ms +step:617/1670 train_time:59880ms step_avg:97.05ms +step:618/1670 train_time:59977ms step_avg:97.05ms +step:619/1670 train_time:60074ms step_avg:97.05ms +step:620/1670 train_time:60170ms step_avg:97.05ms +step:621/1670 train_time:60268ms step_avg:97.05ms +step:622/1670 train_time:60366ms step_avg:97.05ms +step:623/1670 train_time:60463ms step_avg:97.05ms +step:624/1670 train_time:60561ms step_avg:97.05ms +step:625/1670 train_time:60658ms step_avg:97.05ms +step:625/1670 val_loss:3.6151 train_time:60754ms step_avg:97.21ms +step:626/1670 train_time:60776ms step_avg:97.09ms +step:627/1670 train_time:60861ms step_avg:97.07ms +step:628/1670 train_time:60959ms step_avg:97.07ms +step:629/1670 train_time:61057ms step_avg:97.07ms +step:630/1670 train_time:61153ms step_avg:97.07ms +step:631/1670 
train_time:61249ms step_avg:97.07ms +step:632/1670 train_time:61344ms step_avg:97.06ms +step:633/1670 train_time:61440ms step_avg:97.06ms +step:634/1670 train_time:61536ms step_avg:97.06ms +step:635/1670 train_time:61632ms step_avg:97.06ms +step:636/1670 train_time:61731ms step_avg:97.06ms +step:637/1670 train_time:61829ms step_avg:97.06ms +step:638/1670 train_time:61927ms step_avg:97.06ms +step:639/1670 train_time:62299ms step_avg:97.49ms +step:640/1670 train_time:62381ms step_avg:97.47ms +step:641/1670 train_time:62477ms step_avg:97.47ms +step:642/1670 train_time:62573ms step_avg:97.47ms +step:643/1670 train_time:62669ms step_avg:97.46ms +step:644/1670 train_time:62765ms step_avg:97.46ms +step:645/1670 train_time:62862ms step_avg:97.46ms +step:646/1670 train_time:62958ms step_avg:97.46ms +step:647/1670 train_time:63055ms step_avg:97.46ms +step:648/1670 train_time:63151ms step_avg:97.45ms +step:649/1670 train_time:63251ms step_avg:97.46ms +step:650/1670 train_time:63350ms step_avg:97.46ms +step:651/1670 train_time:63447ms step_avg:97.46ms +step:652/1670 train_time:63543ms step_avg:97.46ms +step:653/1670 train_time:63640ms step_avg:97.46ms +step:654/1670 train_time:63738ms step_avg:97.46ms +step:655/1670 train_time:63834ms step_avg:97.46ms +step:656/1670 train_time:63930ms step_avg:97.46ms +step:657/1670 train_time:64026ms step_avg:97.45ms +step:658/1670 train_time:64124ms step_avg:97.45ms +step:659/1670 train_time:64223ms step_avg:97.46ms +step:660/1670 train_time:64324ms step_avg:97.46ms +step:661/1670 train_time:64424ms step_avg:97.46ms +step:662/1670 train_time:64521ms step_avg:97.46ms +step:663/1670 train_time:64619ms step_avg:97.46ms +step:664/1670 train_time:64715ms step_avg:97.46ms +step:665/1670 train_time:64812ms step_avg:97.46ms +step:666/1670 train_time:64908ms step_avg:97.46ms +step:667/1670 train_time:65004ms step_avg:97.46ms +step:668/1670 train_time:65102ms step_avg:97.46ms +step:669/1670 train_time:65200ms step_avg:97.46ms +step:670/1670 train_time:65299ms step_avg:97.46ms +step:671/1670 train_time:65398ms step_avg:97.46ms +step:672/1670 train_time:65496ms step_avg:97.46ms +step:673/1670 train_time:65594ms step_avg:97.47ms +step:674/1670 train_time:65691ms step_avg:97.46ms +step:675/1670 train_time:65787ms step_avg:97.46ms +step:676/1670 train_time:65884ms step_avg:97.46ms +step:677/1670 train_time:65980ms step_avg:97.46ms +step:678/1670 train_time:66077ms step_avg:97.46ms +step:679/1670 train_time:66175ms step_avg:97.46ms +step:680/1670 train_time:66272ms step_avg:97.46ms +step:681/1670 train_time:66370ms step_avg:97.46ms +step:682/1670 train_time:66467ms step_avg:97.46ms +step:683/1670 train_time:66565ms step_avg:97.46ms +step:684/1670 train_time:66663ms step_avg:97.46ms +step:685/1670 train_time:66759ms step_avg:97.46ms +step:686/1670 train_time:66856ms step_avg:97.46ms +step:687/1670 train_time:66953ms step_avg:97.46ms +step:688/1670 train_time:67050ms step_avg:97.46ms +step:689/1670 train_time:67145ms step_avg:97.45ms +step:690/1670 train_time:67242ms step_avg:97.45ms +step:691/1670 train_time:67341ms step_avg:97.45ms +step:692/1670 train_time:67438ms step_avg:97.45ms +step:693/1670 train_time:67537ms step_avg:97.46ms +step:694/1670 train_time:67634ms step_avg:97.46ms +step:695/1670 train_time:67731ms step_avg:97.46ms +step:696/1670 train_time:67828ms step_avg:97.45ms +step:697/1670 train_time:67924ms step_avg:97.45ms +step:698/1670 train_time:68021ms step_avg:97.45ms +step:699/1670 train_time:68120ms step_avg:97.45ms +step:700/1670 train_time:68218ms step_avg:97.45ms 
+step:701/1670 train_time:68315ms step_avg:97.45ms +step:702/1670 train_time:68412ms step_avg:97.45ms +step:703/1670 train_time:68509ms step_avg:97.45ms +step:704/1670 train_time:68605ms step_avg:97.45ms +step:705/1670 train_time:68702ms step_avg:97.45ms +step:706/1670 train_time:68800ms step_avg:97.45ms +step:707/1670 train_time:68898ms step_avg:97.45ms +step:708/1670 train_time:68995ms step_avg:97.45ms +step:709/1670 train_time:69092ms step_avg:97.45ms +step:710/1670 train_time:69188ms step_avg:97.45ms +step:711/1670 train_time:69284ms step_avg:97.45ms +step:712/1670 train_time:69382ms step_avg:97.45ms +step:713/1670 train_time:69480ms step_avg:97.45ms +step:714/1670 train_time:69578ms step_avg:97.45ms +step:715/1670 train_time:69675ms step_avg:97.45ms +step:716/1670 train_time:69773ms step_avg:97.45ms +step:717/1670 train_time:69869ms step_avg:97.45ms +step:718/1670 train_time:69966ms step_avg:97.45ms +step:719/1670 train_time:70062ms step_avg:97.44ms +step:720/1670 train_time:70160ms step_avg:97.44ms +step:721/1670 train_time:70258ms step_avg:97.44ms +step:722/1670 train_time:70355ms step_avg:97.44ms +step:723/1670 train_time:70452ms step_avg:97.44ms +step:724/1670 train_time:70548ms step_avg:97.44ms +step:725/1670 train_time:70645ms step_avg:97.44ms +step:726/1670 train_time:70743ms step_avg:97.44ms +step:727/1670 train_time:70840ms step_avg:97.44ms +step:728/1670 train_time:70939ms step_avg:97.44ms +step:729/1670 train_time:71037ms step_avg:97.44ms +step:730/1670 train_time:71135ms step_avg:97.44ms +step:731/1670 train_time:71231ms step_avg:97.44ms +step:732/1670 train_time:71328ms step_avg:97.44ms +step:733/1670 train_time:71426ms step_avg:97.44ms +step:734/1670 train_time:71522ms step_avg:97.44ms +step:735/1670 train_time:71620ms step_avg:97.44ms +step:736/1670 train_time:71718ms step_avg:97.44ms +step:737/1670 train_time:71816ms step_avg:97.44ms +step:738/1670 train_time:71913ms step_avg:97.44ms +step:739/1670 train_time:72008ms step_avg:97.44ms +step:740/1670 train_time:72106ms step_avg:97.44ms +step:741/1670 train_time:72204ms step_avg:97.44ms +step:742/1670 train_time:72302ms step_avg:97.44ms +step:743/1670 train_time:72400ms step_avg:97.44ms +step:744/1670 train_time:72498ms step_avg:97.44ms +step:745/1670 train_time:72595ms step_avg:97.44ms +step:746/1670 train_time:72692ms step_avg:97.44ms +step:747/1670 train_time:72788ms step_avg:97.44ms +step:748/1670 train_time:72885ms step_avg:97.44ms +step:749/1670 train_time:72982ms step_avg:97.44ms +step:750/1670 train_time:73081ms step_avg:97.44ms +step:750/1670 val_loss:3.5616 train_time:73177ms step_avg:97.57ms +step:751/1670 train_time:73200ms step_avg:97.47ms +step:752/1670 train_time:73283ms step_avg:97.45ms +step:753/1670 train_time:73381ms step_avg:97.45ms +step:754/1670 train_time:73478ms step_avg:97.45ms +step:755/1670 train_time:73575ms step_avg:97.45ms +step:756/1670 train_time:73672ms step_avg:97.45ms +step:757/1670 train_time:73768ms step_avg:97.45ms +step:758/1670 train_time:73865ms step_avg:97.45ms +step:759/1670 train_time:73961ms step_avg:97.45ms +step:760/1670 train_time:74057ms step_avg:97.44ms +step:761/1670 train_time:74157ms step_avg:97.45ms +step:762/1670 train_time:74259ms step_avg:97.45ms +step:763/1670 train_time:74358ms step_avg:97.45ms +step:764/1670 train_time:74456ms step_avg:97.46ms +step:765/1670 train_time:74553ms step_avg:97.46ms +step:766/1670 train_time:74650ms step_avg:97.45ms +step:767/1670 train_time:74747ms step_avg:97.45ms +step:768/1670 train_time:74843ms step_avg:97.45ms +step:769/1670 
train_time:74939ms step_avg:97.45ms +step:770/1670 train_time:75035ms step_avg:97.45ms +step:771/1670 train_time:75133ms step_avg:97.45ms +step:772/1670 train_time:75233ms step_avg:97.45ms +step:773/1670 train_time:75333ms step_avg:97.46ms +step:774/1670 train_time:75431ms step_avg:97.46ms +step:775/1670 train_time:75529ms step_avg:97.46ms +step:776/1670 train_time:75626ms step_avg:97.46ms +step:777/1670 train_time:75722ms step_avg:97.45ms +step:778/1670 train_time:75818ms step_avg:97.45ms +step:779/1670 train_time:75915ms step_avg:97.45ms +step:780/1670 train_time:76012ms step_avg:97.45ms +step:781/1670 train_time:76110ms step_avg:97.45ms +step:782/1670 train_time:76207ms step_avg:97.45ms +step:783/1670 train_time:76305ms step_avg:97.45ms +step:784/1670 train_time:76402ms step_avg:97.45ms +step:785/1670 train_time:76499ms step_avg:97.45ms +step:786/1670 train_time:76596ms step_avg:97.45ms +step:787/1670 train_time:76694ms step_avg:97.45ms +step:788/1670 train_time:76792ms step_avg:97.45ms +step:789/1670 train_time:76888ms step_avg:97.45ms +step:790/1670 train_time:76985ms step_avg:97.45ms +step:791/1670 train_time:77081ms step_avg:97.45ms +step:792/1670 train_time:77178ms step_avg:97.45ms +step:793/1670 train_time:77276ms step_avg:97.45ms +step:794/1670 train_time:77374ms step_avg:97.45ms +step:795/1670 train_time:77472ms step_avg:97.45ms +step:796/1670 train_time:77569ms step_avg:97.45ms +step:797/1670 train_time:77666ms step_avg:97.45ms +step:798/1670 train_time:77762ms step_avg:97.45ms +step:799/1670 train_time:77858ms step_avg:97.44ms +step:800/1670 train_time:77955ms step_avg:97.44ms +step:801/1670 train_time:78052ms step_avg:97.44ms +step:802/1670 train_time:78150ms step_avg:97.44ms +step:803/1670 train_time:78248ms step_avg:97.44ms +step:804/1670 train_time:78345ms step_avg:97.44ms +step:805/1670 train_time:78442ms step_avg:97.44ms +step:806/1670 train_time:78539ms step_avg:97.44ms +step:807/1670 train_time:78637ms step_avg:97.44ms +step:808/1670 train_time:78734ms step_avg:97.44ms +step:809/1670 train_time:78832ms step_avg:97.44ms +step:810/1670 train_time:78929ms step_avg:97.44ms +step:811/1670 train_time:79026ms step_avg:97.44ms +step:812/1670 train_time:79123ms step_avg:97.44ms +step:813/1670 train_time:79219ms step_avg:97.44ms +step:814/1670 train_time:79316ms step_avg:97.44ms +step:815/1670 train_time:79414ms step_avg:97.44ms +step:816/1670 train_time:79513ms step_avg:97.44ms +step:817/1670 train_time:79612ms step_avg:97.44ms +step:818/1670 train_time:79708ms step_avg:97.44ms +step:819/1670 train_time:79806ms step_avg:97.44ms +step:820/1670 train_time:79902ms step_avg:97.44ms +step:821/1670 train_time:79999ms step_avg:97.44ms +step:822/1670 train_time:80096ms step_avg:97.44ms +step:823/1670 train_time:80195ms step_avg:97.44ms +step:824/1670 train_time:80292ms step_avg:97.44ms +step:825/1670 train_time:80391ms step_avg:97.44ms +step:826/1670 train_time:80489ms step_avg:97.44ms +step:827/1670 train_time:80586ms step_avg:97.44ms +step:828/1670 train_time:80683ms step_avg:97.44ms +step:829/1670 train_time:80780ms step_avg:97.44ms +step:830/1670 train_time:80877ms step_avg:97.44ms +step:831/1670 train_time:80974ms step_avg:97.44ms +step:832/1670 train_time:81072ms step_avg:97.44ms +step:833/1670 train_time:81170ms step_avg:97.44ms +step:834/1670 train_time:81267ms step_avg:97.44ms +step:835/1670 train_time:81364ms step_avg:97.44ms +step:836/1670 train_time:81461ms step_avg:97.44ms +step:837/1670 train_time:81557ms step_avg:97.44ms +step:838/1670 train_time:81655ms step_avg:97.44ms 
+step:839/1670 train_time:81753ms step_avg:97.44ms +step:840/1670 train_time:81850ms step_avg:97.44ms +step:841/1670 train_time:81948ms step_avg:97.44ms +step:842/1670 train_time:82045ms step_avg:97.44ms +step:843/1670 train_time:82143ms step_avg:97.44ms +step:844/1670 train_time:82240ms step_avg:97.44ms +step:845/1670 train_time:82338ms step_avg:97.44ms +step:846/1670 train_time:82435ms step_avg:97.44ms +step:847/1670 train_time:82533ms step_avg:97.44ms +step:848/1670 train_time:82631ms step_avg:97.44ms +step:849/1670 train_time:82728ms step_avg:97.44ms +step:850/1670 train_time:82825ms step_avg:97.44ms +step:851/1670 train_time:83093ms step_avg:97.64ms +step:852/1670 train_time:83284ms step_avg:97.75ms +step:853/1670 train_time:83379ms step_avg:97.75ms +step:854/1670 train_time:83476ms step_avg:97.75ms +step:855/1670 train_time:83572ms step_avg:97.75ms +step:856/1670 train_time:83669ms step_avg:97.74ms +step:857/1670 train_time:83765ms step_avg:97.74ms +step:858/1670 train_time:83860ms step_avg:97.74ms +step:859/1670 train_time:83957ms step_avg:97.74ms +step:860/1670 train_time:84053ms step_avg:97.74ms +step:861/1670 train_time:84155ms step_avg:97.74ms +step:862/1670 train_time:84256ms step_avg:97.75ms +step:863/1670 train_time:84356ms step_avg:97.75ms +step:864/1670 train_time:84455ms step_avg:97.75ms +step:865/1670 train_time:84553ms step_avg:97.75ms +step:866/1670 train_time:84649ms step_avg:97.75ms +step:867/1670 train_time:84746ms step_avg:97.75ms +step:868/1670 train_time:84842ms step_avg:97.74ms +step:869/1670 train_time:84938ms step_avg:97.74ms +step:870/1670 train_time:85035ms step_avg:97.74ms +step:871/1670 train_time:85133ms step_avg:97.74ms +step:872/1670 train_time:85233ms step_avg:97.74ms +step:873/1670 train_time:85333ms step_avg:97.75ms +step:874/1670 train_time:85432ms step_avg:97.75ms +step:875/1670 train_time:85531ms step_avg:97.75ms +step:875/1670 val_loss:3.5198 train_time:85627ms step_avg:97.86ms +step:876/1670 train_time:85648ms step_avg:97.77ms +step:877/1670 train_time:85733ms step_avg:97.76ms +step:878/1670 train_time:85833ms step_avg:97.76ms +step:879/1670 train_time:85930ms step_avg:97.76ms +step:880/1670 train_time:86026ms step_avg:97.76ms +step:881/1670 train_time:86122ms step_avg:97.75ms +step:882/1670 train_time:86218ms step_avg:97.75ms +step:883/1670 train_time:86314ms step_avg:97.75ms +step:884/1670 train_time:86410ms step_avg:97.75ms +step:885/1670 train_time:86508ms step_avg:97.75ms +step:886/1670 train_time:86608ms step_avg:97.75ms +step:887/1670 train_time:86709ms step_avg:97.76ms +step:888/1670 train_time:86809ms step_avg:97.76ms +step:889/1670 train_time:86907ms step_avg:97.76ms +step:890/1670 train_time:87004ms step_avg:97.76ms +step:891/1670 train_time:87101ms step_avg:97.76ms +step:892/1670 train_time:87197ms step_avg:97.75ms +step:893/1670 train_time:87293ms step_avg:97.75ms +step:894/1670 train_time:87389ms step_avg:97.75ms +step:895/1670 train_time:87486ms step_avg:97.75ms +step:896/1670 train_time:87584ms step_avg:97.75ms +step:897/1670 train_time:87683ms step_avg:97.75ms +step:898/1670 train_time:87780ms step_avg:97.75ms +step:899/1670 train_time:87877ms step_avg:97.75ms +step:900/1670 train_time:87975ms step_avg:97.75ms +step:901/1670 train_time:88072ms step_avg:97.75ms +step:902/1670 train_time:88170ms step_avg:97.75ms +step:903/1670 train_time:88266ms step_avg:97.75ms +step:904/1670 train_time:88363ms step_avg:97.75ms +step:905/1670 train_time:88460ms step_avg:97.75ms +step:906/1670 train_time:88556ms step_avg:97.74ms +step:907/1670 
train_time:88654ms step_avg:97.74ms +step:908/1670 train_time:88753ms step_avg:97.75ms +step:909/1670 train_time:88850ms step_avg:97.74ms +step:910/1670 train_time:88949ms step_avg:97.75ms +step:911/1670 train_time:89045ms step_avg:97.74ms +step:912/1670 train_time:89143ms step_avg:97.74ms +step:913/1670 train_time:89239ms step_avg:97.74ms +step:914/1670 train_time:89335ms step_avg:97.74ms +step:915/1670 train_time:89433ms step_avg:97.74ms +step:916/1670 train_time:89530ms step_avg:97.74ms +step:917/1670 train_time:89629ms step_avg:97.74ms +step:918/1670 train_time:89727ms step_avg:97.74ms +step:919/1670 train_time:89826ms step_avg:97.74ms +step:920/1670 train_time:89925ms step_avg:97.74ms +step:921/1670 train_time:90022ms step_avg:97.74ms +step:922/1670 train_time:90118ms step_avg:97.74ms +step:923/1670 train_time:90215ms step_avg:97.74ms +step:924/1670 train_time:90312ms step_avg:97.74ms +step:925/1670 train_time:90410ms step_avg:97.74ms +step:926/1670 train_time:90507ms step_avg:97.74ms +step:927/1670 train_time:90605ms step_avg:97.74ms +step:928/1670 train_time:90702ms step_avg:97.74ms +step:929/1670 train_time:90800ms step_avg:97.74ms +step:930/1670 train_time:90897ms step_avg:97.74ms +step:931/1670 train_time:90994ms step_avg:97.74ms +step:932/1670 train_time:91091ms step_avg:97.74ms +step:933/1670 train_time:91189ms step_avg:97.74ms +step:934/1670 train_time:91285ms step_avg:97.74ms +step:935/1670 train_time:91382ms step_avg:97.73ms +step:936/1670 train_time:91479ms step_avg:97.73ms +step:937/1670 train_time:91575ms step_avg:97.73ms +step:938/1670 train_time:91673ms step_avg:97.73ms +step:939/1670 train_time:91771ms step_avg:97.73ms +step:940/1670 train_time:91869ms step_avg:97.73ms +step:941/1670 train_time:91969ms step_avg:97.74ms +step:942/1670 train_time:92068ms step_avg:97.74ms +step:943/1670 train_time:92165ms step_avg:97.74ms +step:944/1670 train_time:92262ms step_avg:97.74ms +step:945/1670 train_time:92358ms step_avg:97.73ms +step:946/1670 train_time:92455ms step_avg:97.73ms +step:947/1670 train_time:92552ms step_avg:97.73ms +step:948/1670 train_time:92649ms step_avg:97.73ms +step:949/1670 train_time:92746ms step_avg:97.73ms +step:950/1670 train_time:92844ms step_avg:97.73ms +step:951/1670 train_time:92942ms step_avg:97.73ms +step:952/1670 train_time:93038ms step_avg:97.73ms +step:953/1670 train_time:93135ms step_avg:97.73ms +step:954/1670 train_time:93233ms step_avg:97.73ms +step:955/1670 train_time:93330ms step_avg:97.73ms +step:956/1670 train_time:93429ms step_avg:97.73ms +step:957/1670 train_time:93526ms step_avg:97.73ms +step:958/1670 train_time:93623ms step_avg:97.73ms +step:959/1670 train_time:93720ms step_avg:97.73ms +step:960/1670 train_time:93817ms step_avg:97.73ms +step:961/1670 train_time:93915ms step_avg:97.73ms +step:962/1670 train_time:94014ms step_avg:97.73ms +step:963/1670 train_time:94111ms step_avg:97.73ms +step:964/1670 train_time:94209ms step_avg:97.73ms +step:965/1670 train_time:94307ms step_avg:97.73ms +step:966/1670 train_time:94404ms step_avg:97.73ms +step:967/1670 train_time:94501ms step_avg:97.73ms +step:968/1670 train_time:94598ms step_avg:97.73ms +step:969/1670 train_time:94695ms step_avg:97.72ms +step:970/1670 train_time:94792ms step_avg:97.72ms +step:971/1670 train_time:94891ms step_avg:97.72ms +step:972/1670 train_time:94988ms step_avg:97.72ms +step:973/1670 train_time:95085ms step_avg:97.72ms +step:974/1670 train_time:95183ms step_avg:97.72ms +step:975/1670 train_time:95280ms step_avg:97.72ms +step:976/1670 train_time:95376ms step_avg:97.72ms 
+step:977/1670 train_time:95473ms step_avg:97.72ms +step:978/1670 train_time:95571ms step_avg:97.72ms +step:979/1670 train_time:95669ms step_avg:97.72ms +step:980/1670 train_time:95767ms step_avg:97.72ms +step:981/1670 train_time:95865ms step_avg:97.72ms +step:982/1670 train_time:95962ms step_avg:97.72ms +step:983/1670 train_time:96059ms step_avg:97.72ms +step:984/1670 train_time:96156ms step_avg:97.72ms +step:985/1670 train_time:96254ms step_avg:97.72ms +step:986/1670 train_time:96351ms step_avg:97.72ms +step:987/1670 train_time:96448ms step_avg:97.72ms +step:988/1670 train_time:96546ms step_avg:97.72ms +step:989/1670 train_time:96643ms step_avg:97.72ms +step:990/1670 train_time:96739ms step_avg:97.72ms +step:991/1670 train_time:96836ms step_avg:97.72ms +step:992/1670 train_time:96934ms step_avg:97.72ms +step:993/1670 train_time:97032ms step_avg:97.72ms +step:994/1670 train_time:97130ms step_avg:97.72ms +step:995/1670 train_time:97227ms step_avg:97.72ms +step:996/1670 train_time:97325ms step_avg:97.72ms +step:997/1670 train_time:97423ms step_avg:97.72ms +step:998/1670 train_time:97520ms step_avg:97.71ms +step:999/1670 train_time:97616ms step_avg:97.71ms +step:1000/1670 train_time:97713ms step_avg:97.71ms +step:1000/1670 val_loss:3.4779 train_time:97810ms step_avg:97.81ms +step:1001/1670 train_time:97831ms step_avg:97.73ms +step:1002/1670 train_time:97916ms step_avg:97.72ms +step:1003/1670 train_time:98017ms step_avg:97.72ms +step:1004/1670 train_time:98113ms step_avg:97.72ms +step:1005/1670 train_time:98209ms step_avg:97.72ms +step:1006/1670 train_time:98306ms step_avg:97.72ms +step:1007/1670 train_time:98401ms step_avg:97.72ms +step:1008/1670 train_time:98497ms step_avg:97.72ms +step:1009/1670 train_time:98594ms step_avg:97.71ms +step:1010/1670 train_time:98691ms step_avg:97.71ms +step:1011/1670 train_time:98788ms step_avg:97.71ms +step:1012/1670 train_time:98889ms step_avg:97.72ms +step:1013/1670 train_time:98991ms step_avg:97.72ms +step:1014/1670 train_time:99089ms step_avg:97.72ms +step:1015/1670 train_time:99187ms step_avg:97.72ms +step:1016/1670 train_time:99283ms step_avg:97.72ms +step:1017/1670 train_time:99380ms step_avg:97.72ms +step:1018/1670 train_time:99477ms step_avg:97.72ms +step:1019/1670 train_time:99574ms step_avg:97.72ms +step:1020/1670 train_time:99670ms step_avg:97.72ms +step:1021/1670 train_time:99767ms step_avg:97.71ms +step:1022/1670 train_time:99865ms step_avg:97.72ms +step:1023/1670 train_time:99966ms step_avg:97.72ms +step:1024/1670 train_time:100065ms step_avg:97.72ms +step:1025/1670 train_time:100163ms step_avg:97.72ms +step:1026/1670 train_time:100260ms step_avg:97.72ms +step:1027/1670 train_time:100356ms step_avg:97.72ms +step:1028/1670 train_time:100452ms step_avg:97.72ms +step:1029/1670 train_time:100549ms step_avg:97.72ms +step:1030/1670 train_time:100646ms step_avg:97.71ms +step:1031/1670 train_time:100743ms step_avg:97.71ms +step:1032/1670 train_time:100841ms step_avg:97.71ms +step:1033/1670 train_time:100940ms step_avg:97.72ms +step:1034/1670 train_time:101038ms step_avg:97.72ms +step:1035/1670 train_time:101136ms step_avg:97.72ms +step:1036/1670 train_time:101232ms step_avg:97.71ms +step:1037/1670 train_time:101329ms step_avg:97.71ms +step:1038/1670 train_time:101426ms step_avg:97.71ms +step:1039/1670 train_time:101523ms step_avg:97.71ms +step:1040/1670 train_time:101619ms step_avg:97.71ms +step:1041/1670 train_time:101716ms step_avg:97.71ms +step:1042/1670 train_time:101813ms step_avg:97.71ms +step:1043/1670 train_time:101911ms step_avg:97.71ms 
+step:1044/1670 train_time:102009ms step_avg:97.71ms +step:1045/1670 train_time:102107ms step_avg:97.71ms +step:1046/1670 train_time:102205ms step_avg:97.71ms +step:1047/1670 train_time:102304ms step_avg:97.71ms +step:1048/1670 train_time:102400ms step_avg:97.71ms +step:1049/1670 train_time:102496ms step_avg:97.71ms +step:1050/1670 train_time:102593ms step_avg:97.71ms +step:1051/1670 train_time:102689ms step_avg:97.71ms +step:1052/1670 train_time:102787ms step_avg:97.71ms +step:1053/1670 train_time:102885ms step_avg:97.71ms +step:1054/1670 train_time:102983ms step_avg:97.71ms +step:1055/1670 train_time:103082ms step_avg:97.71ms +step:1056/1670 train_time:103180ms step_avg:97.71ms +step:1057/1670 train_time:103278ms step_avg:97.71ms +step:1058/1670 train_time:103375ms step_avg:97.71ms +step:1059/1670 train_time:103472ms step_avg:97.71ms +step:1060/1670 train_time:103569ms step_avg:97.71ms +step:1061/1670 train_time:103666ms step_avg:97.71ms +step:1062/1670 train_time:103930ms step_avg:97.86ms +step:1063/1670 train_time:104020ms step_avg:97.86ms +step:1064/1670 train_time:104116ms step_avg:97.85ms +step:1065/1670 train_time:104212ms step_avg:97.85ms +step:1066/1670 train_time:104309ms step_avg:97.85ms +step:1067/1670 train_time:104405ms step_avg:97.85ms +step:1068/1670 train_time:104501ms step_avg:97.85ms +step:1069/1670 train_time:104597ms step_avg:97.85ms +step:1070/1670 train_time:104693ms step_avg:97.84ms +step:1071/1670 train_time:104789ms step_avg:97.84ms +step:1072/1670 train_time:104891ms step_avg:97.85ms +step:1073/1670 train_time:104990ms step_avg:97.85ms +step:1074/1670 train_time:105088ms step_avg:97.85ms +step:1075/1670 train_time:105187ms step_avg:97.85ms +step:1076/1670 train_time:105284ms step_avg:97.85ms +step:1077/1670 train_time:105383ms step_avg:97.85ms +step:1078/1670 train_time:105480ms step_avg:97.85ms +step:1079/1670 train_time:105576ms step_avg:97.85ms +step:1080/1670 train_time:105673ms step_avg:97.85ms +step:1081/1670 train_time:105770ms step_avg:97.84ms +step:1082/1670 train_time:105868ms step_avg:97.84ms +step:1083/1670 train_time:105966ms step_avg:97.85ms +step:1084/1670 train_time:106066ms step_avg:97.85ms +step:1085/1670 train_time:106164ms step_avg:97.85ms +step:1086/1670 train_time:106261ms step_avg:97.85ms +step:1087/1670 train_time:106358ms step_avg:97.85ms +step:1088/1670 train_time:106454ms step_avg:97.84ms +step:1089/1670 train_time:106550ms step_avg:97.84ms +step:1090/1670 train_time:106648ms step_avg:97.84ms +step:1091/1670 train_time:106745ms step_avg:97.84ms +step:1092/1670 train_time:106843ms step_avg:97.84ms +step:1093/1670 train_time:106940ms step_avg:97.84ms +step:1094/1670 train_time:107038ms step_avg:97.84ms +step:1095/1670 train_time:107135ms step_avg:97.84ms +step:1096/1670 train_time:107232ms step_avg:97.84ms +step:1097/1670 train_time:107329ms step_avg:97.84ms +step:1098/1670 train_time:107427ms step_avg:97.84ms +step:1099/1670 train_time:107524ms step_avg:97.84ms +step:1100/1670 train_time:107621ms step_avg:97.84ms +step:1101/1670 train_time:107718ms step_avg:97.84ms +step:1102/1670 train_time:107815ms step_avg:97.84ms +step:1103/1670 train_time:107911ms step_avg:97.83ms +step:1104/1670 train_time:108008ms step_avg:97.83ms +step:1105/1670 train_time:108106ms step_avg:97.83ms +step:1106/1670 train_time:108204ms step_avg:97.83ms +step:1107/1670 train_time:108302ms step_avg:97.83ms +step:1108/1670 train_time:108399ms step_avg:97.83ms +step:1109/1670 train_time:108496ms step_avg:97.83ms +step:1110/1670 train_time:108592ms step_avg:97.83ms 
+step:1111/1670 train_time:108689ms step_avg:97.83ms +step:1112/1670 train_time:108787ms step_avg:97.83ms +step:1113/1670 train_time:108885ms step_avg:97.83ms +step:1114/1670 train_time:108984ms step_avg:97.83ms +step:1115/1670 train_time:109081ms step_avg:97.83ms +step:1116/1670 train_time:109178ms step_avg:97.83ms +step:1117/1670 train_time:109276ms step_avg:97.83ms +step:1118/1670 train_time:109373ms step_avg:97.83ms +step:1119/1670 train_time:109473ms step_avg:97.83ms +step:1120/1670 train_time:109571ms step_avg:97.83ms +step:1121/1670 train_time:109668ms step_avg:97.83ms +step:1122/1670 train_time:109767ms step_avg:97.83ms +step:1123/1670 train_time:109865ms step_avg:97.83ms +step:1124/1670 train_time:109963ms step_avg:97.83ms +step:1125/1670 train_time:110062ms step_avg:97.83ms +step:1125/1670 val_loss:3.4247 train_time:110161ms step_avg:97.92ms +step:1126/1670 train_time:110183ms step_avg:97.85ms +step:1127/1670 train_time:110267ms step_avg:97.84ms +step:1128/1670 train_time:110367ms step_avg:97.84ms +step:1129/1670 train_time:110465ms step_avg:97.84ms +step:1130/1670 train_time:110561ms step_avg:97.84ms +step:1131/1670 train_time:110657ms step_avg:97.84ms +step:1132/1670 train_time:110754ms step_avg:97.84ms +step:1133/1670 train_time:110852ms step_avg:97.84ms +step:1134/1670 train_time:110949ms step_avg:97.84ms +step:1135/1670 train_time:111046ms step_avg:97.84ms +step:1136/1670 train_time:111150ms step_avg:97.84ms +step:1137/1670 train_time:111253ms step_avg:97.85ms +step:1138/1670 train_time:111353ms step_avg:97.85ms +step:1139/1670 train_time:111453ms step_avg:97.85ms +step:1140/1670 train_time:111553ms step_avg:97.85ms +step:1141/1670 train_time:111651ms step_avg:97.85ms +step:1142/1670 train_time:111748ms step_avg:97.85ms +step:1143/1670 train_time:111844ms step_avg:97.85ms +step:1144/1670 train_time:111941ms step_avg:97.85ms +step:1145/1670 train_time:112039ms step_avg:97.85ms +step:1146/1670 train_time:112139ms step_avg:97.85ms +step:1147/1670 train_time:112236ms step_avg:97.85ms +step:1148/1670 train_time:112334ms step_avg:97.85ms +step:1149/1670 train_time:112433ms step_avg:97.85ms +step:1150/1670 train_time:112532ms step_avg:97.85ms +step:1151/1670 train_time:112631ms step_avg:97.85ms +step:1152/1670 train_time:112728ms step_avg:97.85ms +step:1153/1670 train_time:112825ms step_avg:97.85ms +step:1154/1670 train_time:112924ms step_avg:97.85ms +step:1155/1670 train_time:113021ms step_avg:97.85ms +step:1156/1670 train_time:113119ms step_avg:97.85ms +step:1157/1670 train_time:113217ms step_avg:97.85ms +step:1158/1670 train_time:113316ms step_avg:97.85ms +step:1159/1670 train_time:113414ms step_avg:97.86ms +step:1160/1670 train_time:113512ms step_avg:97.85ms +step:1161/1670 train_time:113609ms step_avg:97.85ms +step:1162/1670 train_time:113708ms step_avg:97.86ms +step:1163/1670 train_time:113806ms step_avg:97.86ms +step:1164/1670 train_time:113904ms step_avg:97.86ms +step:1165/1670 train_time:114001ms step_avg:97.86ms +step:1166/1670 train_time:114099ms step_avg:97.86ms +step:1167/1670 train_time:114197ms step_avg:97.86ms +step:1168/1670 train_time:114296ms step_avg:97.86ms +step:1169/1670 train_time:114394ms step_avg:97.86ms +step:1170/1670 train_time:114492ms step_avg:97.86ms +step:1171/1670 train_time:114591ms step_avg:97.86ms +step:1172/1670 train_time:114689ms step_avg:97.86ms +step:1173/1670 train_time:114786ms step_avg:97.86ms +step:1174/1670 train_time:114884ms step_avg:97.86ms +step:1175/1670 train_time:114983ms step_avg:97.86ms +step:1176/1670 train_time:115080ms 
step_avg:97.86ms +step:1177/1670 train_time:115178ms step_avg:97.86ms +step:1178/1670 train_time:115276ms step_avg:97.86ms +step:1179/1670 train_time:115374ms step_avg:97.86ms +step:1180/1670 train_time:115473ms step_avg:97.86ms +step:1181/1670 train_time:115571ms step_avg:97.86ms +step:1182/1670 train_time:115668ms step_avg:97.86ms +step:1183/1670 train_time:115767ms step_avg:97.86ms +step:1184/1670 train_time:115865ms step_avg:97.86ms +step:1185/1670 train_time:115962ms step_avg:97.86ms +step:1186/1670 train_time:116059ms step_avg:97.86ms +step:1187/1670 train_time:116157ms step_avg:97.86ms +step:1188/1670 train_time:116255ms step_avg:97.86ms +step:1189/1670 train_time:116353ms step_avg:97.86ms +step:1190/1670 train_time:116451ms step_avg:97.86ms +step:1191/1670 train_time:116550ms step_avg:97.86ms +step:1192/1670 train_time:116649ms step_avg:97.86ms +step:1193/1670 train_time:116747ms step_avg:97.86ms +step:1194/1670 train_time:116845ms step_avg:97.86ms +step:1195/1670 train_time:116942ms step_avg:97.86ms +step:1196/1670 train_time:117040ms step_avg:97.86ms +step:1197/1670 train_time:117138ms step_avg:97.86ms +step:1198/1670 train_time:117236ms step_avg:97.86ms +step:1199/1670 train_time:117334ms step_avg:97.86ms +step:1200/1670 train_time:117431ms step_avg:97.86ms +step:1201/1670 train_time:117529ms step_avg:97.86ms +step:1202/1670 train_time:117627ms step_avg:97.86ms +step:1203/1670 train_time:117726ms step_avg:97.86ms +step:1204/1670 train_time:117824ms step_avg:97.86ms +step:1205/1670 train_time:117921ms step_avg:97.86ms +step:1206/1670 train_time:118019ms step_avg:97.86ms +step:1207/1670 train_time:118117ms step_avg:97.86ms +step:1208/1670 train_time:118215ms step_avg:97.86ms +step:1209/1670 train_time:118313ms step_avg:97.86ms +step:1210/1670 train_time:118410ms step_avg:97.86ms +step:1211/1670 train_time:118508ms step_avg:97.86ms +step:1212/1670 train_time:118607ms step_avg:97.86ms +step:1213/1670 train_time:118705ms step_avg:97.86ms +step:1214/1670 train_time:118802ms step_avg:97.86ms +step:1215/1670 train_time:118899ms step_avg:97.86ms +step:1216/1670 train_time:118997ms step_avg:97.86ms +step:1217/1670 train_time:119095ms step_avg:97.86ms +step:1218/1670 train_time:119194ms step_avg:97.86ms +step:1219/1670 train_time:119293ms step_avg:97.86ms +step:1220/1670 train_time:119391ms step_avg:97.86ms +step:1221/1670 train_time:119489ms step_avg:97.86ms +step:1222/1670 train_time:119587ms step_avg:97.86ms +step:1223/1670 train_time:119685ms step_avg:97.86ms +step:1224/1670 train_time:119783ms step_avg:97.86ms +step:1225/1670 train_time:119881ms step_avg:97.86ms +step:1226/1670 train_time:119979ms step_avg:97.86ms +step:1227/1670 train_time:120077ms step_avg:97.86ms +step:1228/1670 train_time:120176ms step_avg:97.86ms +step:1229/1670 train_time:120272ms step_avg:97.86ms +step:1230/1670 train_time:120371ms step_avg:97.86ms +step:1231/1670 train_time:120469ms step_avg:97.86ms +step:1232/1670 train_time:120567ms step_avg:97.86ms +step:1233/1670 train_time:120666ms step_avg:97.86ms +step:1234/1670 train_time:120765ms step_avg:97.86ms +step:1235/1670 train_time:120863ms step_avg:97.86ms +step:1236/1670 train_time:120961ms step_avg:97.86ms +step:1237/1670 train_time:121059ms step_avg:97.87ms +step:1238/1670 train_time:121157ms step_avg:97.86ms +step:1239/1670 train_time:121254ms step_avg:97.86ms +step:1240/1670 train_time:121351ms step_avg:97.86ms +step:1241/1670 train_time:121449ms step_avg:97.86ms +step:1242/1670 train_time:121548ms step_avg:97.86ms +step:1243/1670 train_time:121645ms 
step_avg:97.86ms +step:1244/1670 train_time:121743ms step_avg:97.86ms +step:1245/1670 train_time:121841ms step_avg:97.86ms +step:1246/1670 train_time:121939ms step_avg:97.86ms +step:1247/1670 train_time:122038ms step_avg:97.87ms +step:1248/1670 train_time:122136ms step_avg:97.87ms +step:1249/1670 train_time:122234ms step_avg:97.87ms +step:1250/1670 train_time:122332ms step_avg:97.87ms +step:1250/1670 val_loss:3.3807 train_time:122428ms step_avg:97.94ms +step:1251/1670 train_time:122450ms step_avg:97.88ms +step:1252/1670 train_time:122535ms step_avg:97.87ms +step:1253/1670 train_time:122634ms step_avg:97.87ms +step:1254/1670 train_time:122732ms step_avg:97.87ms +step:1255/1670 train_time:122830ms step_avg:97.87ms +step:1256/1670 train_time:122929ms step_avg:97.87ms +step:1257/1670 train_time:123026ms step_avg:97.87ms +step:1258/1670 train_time:123122ms step_avg:97.87ms +step:1259/1670 train_time:123219ms step_avg:97.87ms +step:1260/1670 train_time:123316ms step_avg:97.87ms +step:1261/1670 train_time:123415ms step_avg:97.87ms +step:1262/1670 train_time:123515ms step_avg:97.87ms +step:1263/1670 train_time:123615ms step_avg:97.87ms +step:1264/1670 train_time:123713ms step_avg:97.87ms +step:1265/1670 train_time:123812ms step_avg:97.87ms +step:1266/1670 train_time:123909ms step_avg:97.87ms +step:1267/1670 train_time:124006ms step_avg:97.87ms +step:1268/1670 train_time:124104ms step_avg:97.87ms +step:1269/1670 train_time:124201ms step_avg:97.87ms +step:1270/1670 train_time:124298ms step_avg:97.87ms +step:1271/1670 train_time:124398ms step_avg:97.87ms +step:1272/1670 train_time:124497ms step_avg:97.88ms +step:1273/1670 train_time:124596ms step_avg:97.88ms +step:1274/1670 train_time:124857ms step_avg:98.00ms +step:1275/1670 train_time:125059ms step_avg:98.09ms +step:1276/1670 train_time:125155ms step_avg:98.08ms +step:1277/1670 train_time:125252ms step_avg:98.08ms +step:1278/1670 train_time:125349ms step_avg:98.08ms +step:1279/1670 train_time:125445ms step_avg:98.08ms +step:1280/1670 train_time:125542ms step_avg:98.08ms +step:1281/1670 train_time:125638ms step_avg:98.08ms +step:1282/1670 train_time:125735ms step_avg:98.08ms +step:1283/1670 train_time:125834ms step_avg:98.08ms +step:1284/1670 train_time:125937ms step_avg:98.08ms +step:1285/1670 train_time:126039ms step_avg:98.09ms +step:1286/1670 train_time:126138ms step_avg:98.09ms +step:1287/1670 train_time:126236ms step_avg:98.09ms +step:1288/1670 train_time:126334ms step_avg:98.09ms +step:1289/1670 train_time:126432ms step_avg:98.09ms +step:1290/1670 train_time:126530ms step_avg:98.09ms +step:1291/1670 train_time:126628ms step_avg:98.09ms +step:1292/1670 train_time:126725ms step_avg:98.08ms +step:1293/1670 train_time:126823ms step_avg:98.08ms +step:1294/1670 train_time:126924ms step_avg:98.09ms +step:1295/1670 train_time:127025ms step_avg:98.09ms +step:1296/1670 train_time:127124ms step_avg:98.09ms +step:1297/1670 train_time:127221ms step_avg:98.09ms +step:1298/1670 train_time:127318ms step_avg:98.09ms +step:1299/1670 train_time:127416ms step_avg:98.09ms +step:1300/1670 train_time:127512ms step_avg:98.09ms +step:1301/1670 train_time:127611ms step_avg:98.09ms +step:1302/1670 train_time:127708ms step_avg:98.09ms +step:1303/1670 train_time:127806ms step_avg:98.09ms +step:1304/1670 train_time:127906ms step_avg:98.09ms +step:1305/1670 train_time:128006ms step_avg:98.09ms +step:1306/1670 train_time:128104ms step_avg:98.09ms +step:1307/1670 train_time:128204ms step_avg:98.09ms +step:1308/1670 train_time:128303ms step_avg:98.09ms +step:1309/1670 
train_time:128400ms step_avg:98.09ms +step:1310/1670 train_time:128497ms step_avg:98.09ms +step:1311/1670 train_time:128594ms step_avg:98.09ms +step:1312/1670 train_time:128692ms step_avg:98.09ms +step:1313/1670 train_time:128789ms step_avg:98.09ms +step:1314/1670 train_time:128887ms step_avg:98.09ms +step:1315/1670 train_time:128986ms step_avg:98.09ms +step:1316/1670 train_time:129085ms step_avg:98.09ms +step:1317/1670 train_time:129184ms step_avg:98.09ms +step:1318/1670 train_time:129283ms step_avg:98.09ms +step:1319/1670 train_time:129382ms step_avg:98.09ms +step:1320/1670 train_time:129479ms step_avg:98.09ms +step:1321/1670 train_time:129576ms step_avg:98.09ms +step:1322/1670 train_time:129673ms step_avg:98.09ms +step:1323/1670 train_time:129772ms step_avg:98.09ms +step:1324/1670 train_time:129869ms step_avg:98.09ms +step:1325/1670 train_time:129968ms step_avg:98.09ms +step:1326/1670 train_time:130066ms step_avg:98.09ms +step:1327/1670 train_time:130164ms step_avg:98.09ms +step:1328/1670 train_time:130262ms step_avg:98.09ms +step:1329/1670 train_time:130362ms step_avg:98.09ms +step:1330/1670 train_time:130460ms step_avg:98.09ms +step:1331/1670 train_time:130558ms step_avg:98.09ms +step:1332/1670 train_time:130656ms step_avg:98.09ms +step:1333/1670 train_time:130752ms step_avg:98.09ms +step:1334/1670 train_time:130850ms step_avg:98.09ms +step:1335/1670 train_time:130948ms step_avg:98.09ms +step:1336/1670 train_time:131046ms step_avg:98.09ms +step:1337/1670 train_time:131146ms step_avg:98.09ms +step:1338/1670 train_time:131246ms step_avg:98.09ms +step:1339/1670 train_time:131346ms step_avg:98.09ms +step:1340/1670 train_time:131446ms step_avg:98.09ms +step:1341/1670 train_time:131545ms step_avg:98.09ms +step:1342/1670 train_time:131643ms step_avg:98.09ms +step:1343/1670 train_time:131741ms step_avg:98.09ms +step:1344/1670 train_time:131837ms step_avg:98.09ms +step:1345/1670 train_time:131934ms step_avg:98.09ms +step:1346/1670 train_time:132032ms step_avg:98.09ms +step:1347/1670 train_time:132132ms step_avg:98.09ms +step:1348/1670 train_time:132232ms step_avg:98.10ms +step:1349/1670 train_time:132331ms step_avg:98.10ms +step:1350/1670 train_time:132431ms step_avg:98.10ms +step:1351/1670 train_time:132531ms step_avg:98.10ms +step:1352/1670 train_time:132629ms step_avg:98.10ms +step:1353/1670 train_time:132729ms step_avg:98.10ms +step:1354/1670 train_time:132829ms step_avg:98.10ms +step:1355/1670 train_time:132927ms step_avg:98.10ms +step:1356/1670 train_time:133025ms step_avg:98.10ms +step:1357/1670 train_time:133122ms step_avg:98.10ms +step:1358/1670 train_time:133220ms step_avg:98.10ms +step:1359/1670 train_time:133318ms step_avg:98.10ms +step:1360/1670 train_time:133417ms step_avg:98.10ms +step:1361/1670 train_time:133515ms step_avg:98.10ms +step:1362/1670 train_time:133613ms step_avg:98.10ms +step:1363/1670 train_time:133711ms step_avg:98.10ms +step:1364/1670 train_time:133809ms step_avg:98.10ms +step:1365/1670 train_time:133907ms step_avg:98.10ms +step:1366/1670 train_time:134005ms step_avg:98.10ms +step:1367/1670 train_time:134102ms step_avg:98.10ms +step:1368/1670 train_time:134200ms step_avg:98.10ms +step:1369/1670 train_time:134298ms step_avg:98.10ms +step:1370/1670 train_time:134395ms step_avg:98.10ms +step:1371/1670 train_time:134494ms step_avg:98.10ms +step:1372/1670 train_time:134592ms step_avg:98.10ms +step:1373/1670 train_time:134690ms step_avg:98.10ms +step:1374/1670 train_time:134788ms step_avg:98.10ms +step:1375/1670 train_time:134886ms step_avg:98.10ms +step:1375/1670 
val_loss:3.3438 train_time:134983ms step_avg:98.17ms +step:1376/1670 train_time:135005ms step_avg:98.11ms +step:1377/1670 train_time:135088ms step_avg:98.10ms +step:1378/1670 train_time:135191ms step_avg:98.11ms +step:1379/1670 train_time:135288ms step_avg:98.11ms +step:1380/1670 train_time:135386ms step_avg:98.11ms +step:1381/1670 train_time:135483ms step_avg:98.11ms +step:1382/1670 train_time:135580ms step_avg:98.10ms +step:1383/1670 train_time:135677ms step_avg:98.10ms +step:1384/1670 train_time:135774ms step_avg:98.10ms +step:1385/1670 train_time:135871ms step_avg:98.10ms +step:1386/1670 train_time:135971ms step_avg:98.10ms +step:1387/1670 train_time:136071ms step_avg:98.10ms +step:1388/1670 train_time:136172ms step_avg:98.11ms +step:1389/1670 train_time:136272ms step_avg:98.11ms +step:1390/1670 train_time:136373ms step_avg:98.11ms +step:1391/1670 train_time:136472ms step_avg:98.11ms +step:1392/1670 train_time:136570ms step_avg:98.11ms +step:1393/1670 train_time:136668ms step_avg:98.11ms +step:1394/1670 train_time:136765ms step_avg:98.11ms +step:1395/1670 train_time:136862ms step_avg:98.11ms +step:1396/1670 train_time:136959ms step_avg:98.11ms +step:1397/1670 train_time:137058ms step_avg:98.11ms +step:1398/1670 train_time:137156ms step_avg:98.11ms +step:1399/1670 train_time:137255ms step_avg:98.11ms +step:1400/1670 train_time:137354ms step_avg:98.11ms +step:1401/1670 train_time:137452ms step_avg:98.11ms +step:1402/1670 train_time:137550ms step_avg:98.11ms +step:1403/1670 train_time:137648ms step_avg:98.11ms +step:1404/1670 train_time:137746ms step_avg:98.11ms +step:1405/1670 train_time:137844ms step_avg:98.11ms +step:1406/1670 train_time:137942ms step_avg:98.11ms +step:1407/1670 train_time:138039ms step_avg:98.11ms +step:1408/1670 train_time:138137ms step_avg:98.11ms +step:1409/1670 train_time:138237ms step_avg:98.11ms +step:1410/1670 train_time:138337ms step_avg:98.11ms +step:1411/1670 train_time:138436ms step_avg:98.11ms +step:1412/1670 train_time:138535ms step_avg:98.11ms +step:1413/1670 train_time:138633ms step_avg:98.11ms +step:1414/1670 train_time:138731ms step_avg:98.11ms +step:1415/1670 train_time:138829ms step_avg:98.11ms +step:1416/1670 train_time:138928ms step_avg:98.11ms +step:1417/1670 train_time:139027ms step_avg:98.11ms +step:1418/1670 train_time:139125ms step_avg:98.11ms +step:1419/1670 train_time:139223ms step_avg:98.11ms +step:1420/1670 train_time:139320ms step_avg:98.11ms +step:1421/1670 train_time:139418ms step_avg:98.11ms +step:1422/1670 train_time:139516ms step_avg:98.11ms +step:1423/1670 train_time:139614ms step_avg:98.11ms +step:1424/1670 train_time:139711ms step_avg:98.11ms +step:1425/1670 train_time:139810ms step_avg:98.11ms +step:1426/1670 train_time:139908ms step_avg:98.11ms +step:1427/1670 train_time:140006ms step_avg:98.11ms +step:1428/1670 train_time:140103ms step_avg:98.11ms +step:1429/1670 train_time:140201ms step_avg:98.11ms +step:1430/1670 train_time:140299ms step_avg:98.11ms +step:1431/1670 train_time:140397ms step_avg:98.11ms +step:1432/1670 train_time:140495ms step_avg:98.11ms +step:1433/1670 train_time:140594ms step_avg:98.11ms +step:1434/1670 train_time:140693ms step_avg:98.11ms +step:1435/1670 train_time:140793ms step_avg:98.11ms +step:1436/1670 train_time:140892ms step_avg:98.11ms +step:1437/1670 train_time:140992ms step_avg:98.12ms +step:1438/1670 train_time:141090ms step_avg:98.12ms +step:1439/1670 train_time:141190ms step_avg:98.12ms +step:1440/1670 train_time:141291ms step_avg:98.12ms +step:1441/1670 train_time:141389ms step_avg:98.12ms 
+step:1442/1670 train_time:141489ms step_avg:98.12ms +step:1443/1670 train_time:141586ms step_avg:98.12ms +step:1444/1670 train_time:141684ms step_avg:98.12ms +step:1445/1670 train_time:141781ms step_avg:98.12ms +step:1446/1670 train_time:141878ms step_avg:98.12ms +step:1447/1670 train_time:141977ms step_avg:98.12ms +step:1448/1670 train_time:142076ms step_avg:98.12ms +step:1449/1670 train_time:142176ms step_avg:98.12ms +step:1450/1670 train_time:142275ms step_avg:98.12ms +step:1451/1670 train_time:142373ms step_avg:98.12ms +step:1452/1670 train_time:142472ms step_avg:98.12ms +step:1453/1670 train_time:142571ms step_avg:98.12ms +step:1454/1670 train_time:142670ms step_avg:98.12ms +step:1455/1670 train_time:142768ms step_avg:98.12ms +step:1456/1670 train_time:142866ms step_avg:98.12ms +step:1457/1670 train_time:142964ms step_avg:98.12ms +step:1458/1670 train_time:143061ms step_avg:98.12ms +step:1459/1670 train_time:143159ms step_avg:98.12ms +step:1460/1670 train_time:143258ms step_avg:98.12ms +step:1461/1670 train_time:143358ms step_avg:98.12ms +step:1462/1670 train_time:143456ms step_avg:98.12ms +step:1463/1670 train_time:143554ms step_avg:98.12ms +step:1464/1670 train_time:143653ms step_avg:98.12ms +step:1465/1670 train_time:143753ms step_avg:98.12ms +step:1466/1670 train_time:143852ms step_avg:98.13ms +step:1467/1670 train_time:143951ms step_avg:98.13ms +step:1468/1670 train_time:144051ms step_avg:98.13ms +step:1469/1670 train_time:144150ms step_avg:98.13ms +step:1470/1670 train_time:144249ms step_avg:98.13ms +step:1471/1670 train_time:144347ms step_avg:98.13ms +step:1472/1670 train_time:144445ms step_avg:98.13ms +step:1473/1670 train_time:144541ms step_avg:98.13ms +step:1474/1670 train_time:144639ms step_avg:98.13ms +step:1475/1670 train_time:144737ms step_avg:98.13ms +step:1476/1670 train_time:144837ms step_avg:98.13ms +step:1477/1670 train_time:144935ms step_avg:98.13ms +step:1478/1670 train_time:145034ms step_avg:98.13ms +step:1479/1670 train_time:145133ms step_avg:98.13ms +step:1480/1670 train_time:145233ms step_avg:98.13ms +step:1481/1670 train_time:145333ms step_avg:98.13ms +step:1482/1670 train_time:145432ms step_avg:98.13ms +step:1483/1670 train_time:145530ms step_avg:98.13ms +step:1484/1670 train_time:145628ms step_avg:98.13ms +step:1485/1670 train_time:145886ms step_avg:98.24ms +step:1486/1670 train_time:146095ms step_avg:98.31ms +step:1487/1670 train_time:146192ms step_avg:98.31ms +step:1488/1670 train_time:146289ms step_avg:98.31ms +step:1489/1670 train_time:146385ms step_avg:98.31ms +step:1490/1670 train_time:146482ms step_avg:98.31ms +step:1491/1670 train_time:146579ms step_avg:98.31ms +step:1492/1670 train_time:146676ms step_avg:98.31ms +step:1493/1670 train_time:146773ms step_avg:98.31ms +step:1494/1670 train_time:146873ms step_avg:98.31ms +step:1495/1670 train_time:146979ms step_avg:98.31ms +step:1496/1670 train_time:147080ms step_avg:98.32ms +step:1497/1670 train_time:147178ms step_avg:98.32ms +step:1498/1670 train_time:147276ms step_avg:98.32ms +step:1499/1670 train_time:147375ms step_avg:98.32ms +step:1500/1670 train_time:147473ms step_avg:98.32ms +step:1500/1670 val_loss:3.3113 train_time:147570ms step_avg:98.38ms +step:1501/1670 train_time:147592ms step_avg:98.33ms +step:1502/1670 train_time:147675ms step_avg:98.32ms +step:1503/1670 train_time:147776ms step_avg:98.32ms +step:1504/1670 train_time:147873ms step_avg:98.32ms +step:1505/1670 train_time:147970ms step_avg:98.32ms +step:1506/1670 train_time:148067ms step_avg:98.32ms +step:1507/1670 train_time:148164ms 
step_avg:98.32ms +step:1508/1670 train_time:148260ms step_avg:98.32ms +step:1509/1670 train_time:148357ms step_avg:98.31ms +step:1510/1670 train_time:148455ms step_avg:98.31ms +step:1511/1670 train_time:148555ms step_avg:98.32ms +step:1512/1670 train_time:148656ms step_avg:98.32ms +step:1513/1670 train_time:148754ms step_avg:98.32ms +step:1514/1670 train_time:148854ms step_avg:98.32ms +step:1515/1670 train_time:148952ms step_avg:98.32ms +step:1516/1670 train_time:149051ms step_avg:98.32ms +step:1517/1670 train_time:149148ms step_avg:98.32ms +step:1518/1670 train_time:149246ms step_avg:98.32ms +step:1519/1670 train_time:149344ms step_avg:98.32ms +step:1520/1670 train_time:149442ms step_avg:98.32ms +step:1521/1670 train_time:149541ms step_avg:98.32ms +step:1522/1670 train_time:149641ms step_avg:98.32ms +step:1523/1670 train_time:149739ms step_avg:98.32ms +step:1524/1670 train_time:149837ms step_avg:98.32ms +step:1525/1670 train_time:149934ms step_avg:98.32ms +step:1526/1670 train_time:150032ms step_avg:98.32ms +step:1527/1670 train_time:150131ms step_avg:98.32ms +step:1528/1670 train_time:150230ms step_avg:98.32ms +step:1529/1670 train_time:150328ms step_avg:98.32ms +step:1530/1670 train_time:150428ms step_avg:98.32ms +step:1531/1670 train_time:150528ms step_avg:98.32ms +step:1532/1670 train_time:150626ms step_avg:98.32ms +step:1533/1670 train_time:150726ms step_avg:98.32ms +step:1534/1670 train_time:150825ms step_avg:98.32ms +step:1535/1670 train_time:150922ms step_avg:98.32ms +step:1536/1670 train_time:151021ms step_avg:98.32ms +step:1537/1670 train_time:151117ms step_avg:98.32ms +step:1538/1670 train_time:151214ms step_avg:98.32ms +step:1539/1670 train_time:151311ms step_avg:98.32ms +step:1540/1670 train_time:151411ms step_avg:98.32ms +step:1541/1670 train_time:151511ms step_avg:98.32ms +step:1542/1670 train_time:151610ms step_avg:98.32ms +step:1543/1670 train_time:151709ms step_avg:98.32ms +step:1544/1670 train_time:151809ms step_avg:98.32ms +step:1545/1670 train_time:151910ms step_avg:98.32ms +step:1546/1670 train_time:152011ms step_avg:98.33ms +step:1547/1670 train_time:152111ms step_avg:98.33ms +step:1548/1670 train_time:152209ms step_avg:98.33ms +step:1549/1670 train_time:152306ms step_avg:98.33ms +step:1550/1670 train_time:152405ms step_avg:98.33ms +step:1551/1670 train_time:152503ms step_avg:98.33ms +step:1552/1670 train_time:152601ms step_avg:98.33ms +step:1553/1670 train_time:152698ms step_avg:98.32ms +step:1554/1670 train_time:152796ms step_avg:98.32ms +step:1555/1670 train_time:152895ms step_avg:98.32ms +step:1556/1670 train_time:152995ms step_avg:98.33ms +step:1557/1670 train_time:153094ms step_avg:98.33ms +step:1558/1670 train_time:153193ms step_avg:98.33ms +step:1559/1670 train_time:153292ms step_avg:98.33ms +step:1560/1670 train_time:153391ms step_avg:98.33ms +step:1561/1670 train_time:153489ms step_avg:98.33ms +step:1562/1670 train_time:153589ms step_avg:98.33ms +step:1563/1670 train_time:153687ms step_avg:98.33ms +step:1564/1670 train_time:153788ms step_avg:98.33ms +step:1565/1670 train_time:153889ms step_avg:98.33ms +step:1566/1670 train_time:153988ms step_avg:98.33ms +step:1567/1670 train_time:154088ms step_avg:98.33ms +step:1568/1670 train_time:154186ms step_avg:98.33ms +step:1569/1670 train_time:154284ms step_avg:98.33ms +step:1570/1670 train_time:154382ms step_avg:98.33ms +step:1571/1670 train_time:154479ms step_avg:98.33ms +step:1572/1670 train_time:154576ms step_avg:98.33ms +step:1573/1670 train_time:154676ms step_avg:98.33ms +step:1574/1670 train_time:154774ms 
step_avg:98.33ms +step:1575/1670 train_time:154872ms step_avg:98.33ms +step:1576/1670 train_time:154971ms step_avg:98.33ms +step:1577/1670 train_time:155069ms step_avg:98.33ms +step:1578/1670 train_time:155168ms step_avg:98.33ms +step:1579/1670 train_time:155267ms step_avg:98.33ms +step:1580/1670 train_time:155365ms step_avg:98.33ms +step:1581/1670 train_time:155465ms step_avg:98.33ms +step:1582/1670 train_time:155563ms step_avg:98.33ms +step:1583/1670 train_time:155662ms step_avg:98.33ms +step:1584/1670 train_time:155760ms step_avg:98.33ms +step:1585/1670 train_time:155857ms step_avg:98.33ms +step:1586/1670 train_time:155954ms step_avg:98.33ms +step:1587/1670 train_time:156052ms step_avg:98.33ms +step:1588/1670 train_time:156151ms step_avg:98.33ms +step:1589/1670 train_time:156250ms step_avg:98.33ms +step:1590/1670 train_time:156349ms step_avg:98.33ms +step:1591/1670 train_time:156449ms step_avg:98.33ms +step:1592/1670 train_time:156549ms step_avg:98.33ms +step:1593/1670 train_time:156649ms step_avg:98.34ms +step:1594/1670 train_time:156748ms step_avg:98.34ms +step:1595/1670 train_time:156847ms step_avg:98.34ms +step:1596/1670 train_time:156946ms step_avg:98.34ms +step:1597/1670 train_time:157045ms step_avg:98.34ms +step:1598/1670 train_time:157142ms step_avg:98.34ms +step:1599/1670 train_time:157239ms step_avg:98.34ms +step:1600/1670 train_time:157336ms step_avg:98.33ms +step:1601/1670 train_time:157434ms step_avg:98.33ms +step:1602/1670 train_time:157532ms step_avg:98.33ms +step:1603/1670 train_time:157631ms step_avg:98.33ms +step:1604/1670 train_time:157730ms step_avg:98.34ms +step:1605/1670 train_time:157830ms step_avg:98.34ms +step:1606/1670 train_time:157929ms step_avg:98.34ms +step:1607/1670 train_time:158028ms step_avg:98.34ms +step:1608/1670 train_time:158128ms step_avg:98.34ms +step:1609/1670 train_time:158228ms step_avg:98.34ms +step:1610/1670 train_time:158326ms step_avg:98.34ms +step:1611/1670 train_time:158424ms step_avg:98.34ms +step:1612/1670 train_time:158522ms step_avg:98.34ms +step:1613/1670 train_time:158620ms step_avg:98.34ms +step:1614/1670 train_time:158719ms step_avg:98.34ms +step:1615/1670 train_time:158816ms step_avg:98.34ms +step:1616/1670 train_time:158916ms step_avg:98.34ms +step:1617/1670 train_time:159015ms step_avg:98.34ms +step:1618/1670 train_time:159114ms step_avg:98.34ms +step:1619/1670 train_time:159213ms step_avg:98.34ms +step:1620/1670 train_time:159311ms step_avg:98.34ms +step:1621/1670 train_time:159409ms step_avg:98.34ms +step:1622/1670 train_time:159508ms step_avg:98.34ms +step:1623/1670 train_time:159606ms step_avg:98.34ms +step:1624/1670 train_time:159705ms step_avg:98.34ms +step:1625/1670 train_time:159805ms step_avg:98.34ms +step:1625/1670 val_loss:3.2846 train_time:159905ms step_avg:98.40ms +step:1626/1670 train_time:159926ms step_avg:98.36ms +step:1627/1670 train_time:160012ms step_avg:98.35ms +step:1628/1670 train_time:160113ms step_avg:98.35ms +step:1629/1670 train_time:160212ms step_avg:98.35ms +step:1630/1670 train_time:160310ms step_avg:98.35ms +step:1631/1670 train_time:160407ms step_avg:98.35ms +step:1632/1670 train_time:160504ms step_avg:98.35ms +step:1633/1670 train_time:160601ms step_avg:98.35ms +step:1634/1670 train_time:160698ms step_avg:98.35ms +step:1635/1670 train_time:160795ms step_avg:98.35ms +step:1636/1670 train_time:160895ms step_avg:98.35ms +step:1637/1670 train_time:160999ms step_avg:98.35ms +step:1638/1670 train_time:161099ms step_avg:98.35ms +step:1639/1670 train_time:161198ms step_avg:98.35ms +step:1640/1670 
train_time:161297ms step_avg:98.35ms +step:1641/1670 train_time:161396ms step_avg:98.35ms +step:1642/1670 train_time:161493ms step_avg:98.35ms +step:1643/1670 train_time:161590ms step_avg:98.35ms +step:1644/1670 train_time:161687ms step_avg:98.35ms +step:1645/1670 train_time:161785ms step_avg:98.35ms +step:1646/1670 train_time:161882ms step_avg:98.35ms +step:1647/1670 train_time:161981ms step_avg:98.35ms +step:1648/1670 train_time:162079ms step_avg:98.35ms +step:1649/1670 train_time:162178ms step_avg:98.35ms +step:1650/1670 train_time:162276ms step_avg:98.35ms +step:1651/1670 train_time:162375ms step_avg:98.35ms +step:1652/1670 train_time:162473ms step_avg:98.35ms +step:1653/1670 train_time:162571ms step_avg:98.35ms +step:1654/1670 train_time:162670ms step_avg:98.35ms +step:1655/1670 train_time:162768ms step_avg:98.35ms +step:1656/1670 train_time:162867ms step_avg:98.35ms +step:1657/1670 train_time:162966ms step_avg:98.35ms +step:1658/1670 train_time:163064ms step_avg:98.35ms +step:1659/1670 train_time:163162ms step_avg:98.35ms +step:1660/1670 train_time:163260ms step_avg:98.35ms +step:1661/1670 train_time:163358ms step_avg:98.35ms +step:1662/1670 train_time:163456ms step_avg:98.35ms +step:1663/1670 train_time:163553ms step_avg:98.35ms +step:1664/1670 train_time:163652ms step_avg:98.35ms +step:1665/1670 train_time:163750ms step_avg:98.35ms +step:1666/1670 train_time:163849ms step_avg:98.35ms +step:1667/1670 train_time:163949ms step_avg:98.35ms +step:1668/1670 train_time:164048ms step_avg:98.35ms +step:1669/1670 train_time:164149ms step_avg:98.35ms +step:1670/1670 train_time:164249ms step_avg:98.35ms +step:1670/1670 val_loss:3.2766 train_time:164348ms step_avg:98.41ms +peak memory allocated: 34613 MiB reserved: 50216 MiB
diff --git a/records/050925_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt b/records/050925_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt
new file mode 100644
index 000000000..b02e2101d
--- /dev/null
+++ b/records/050925_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt
@@ -0,0 +1,2815 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import numpy as np
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+import torch._dynamo as dynamo
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s,
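+            # note (added): torch._scaled_mm expects the scales as float32 *tensors*,
+            # not Python floats, which is why x_s/w_s are wrapped via new_tensor(...)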
+            dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
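+    # (added) batch/row/column strides, in elements, of the input A above
+    # and of the output C on the next line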
c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels 
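+    # What this kernel computes, as a plain-PyTorch sketch (matching the
+    # ns_line_2 docstring below): C = alpha * (A @ A.mT) + beta * A.
+    # A is symmetric here (it is X @ X.mT from ns_line_1), so the output is
+    # symmetric too, and only half of the blocks need to be computed before
+    # being mirrored across the diagonal.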
+ pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
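+        # Per-parameter math implemented below, as a minimal single-GPU sketch
+        # (the surrounding code shards parameters across ranks and overlaps the
+        # reduce-scatter / all-gather communication with the orthogonalization;
+        # per-tensor lr_mul / wd_mul multipliers are omitted here for brevity):
+        #   buf = momentum * buf + (1 - momentum) * grad        # momentum_buffer.lerp_
+        #   g   = (1 - momentum) * grad + momentum * buf        # Nesterov-style blend
+        #   p   = (1 - lr * wd) * p - lr * max(1, rows / cols)**0.5 * NS(g)
+        # where NS is the Newton-Schulz orthogonalization defined above.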
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
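+                    # exp_avg / exp_avg_sq are Adam's first/second-moment EMAs over
+                    # this rank's shard of the parameter; the update applied below
+                    # is, per element (a sketch):
+                    #   p -= lr * (sqrt(1 - beta2**t) / (1 - beta1**t)) * m / (sqrt(v) + eps)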
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i])
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:57:36 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 130W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 126W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 33C P0 123W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 77482 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 77483 C /usr/bin/python3 610MiB | +| 0 N/A N/A 77484 C /usr/bin/python3 610MiB | +| 0 N/A N/A 77485 C /usr/bin/python3 610MiB | +| 0 N/A N/A 77486 C /usr/bin/python3 610MiB | +| 0 N/A N/A 77487 C /usr/bin/python3 610MiB | +| 0 N/A N/A 77488 C /usr/bin/python3 610MiB | +| 0 N/A N/A 77489 C /usr/bin/python3 610MiB | +| 1 N/A N/A 77483 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 77484 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 77485 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 77486 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 77487 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 77488 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 77489 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:418ms step_avg:417.70ms +step:2/1670 train_time:438ms step_avg:219.23ms +step:3/1670 train_time:511ms step_avg:170.23ms +step:4/1670 train_time:605ms step_avg:151.19ms +step:5/1670 train_time:699ms step_avg:139.84ms +step:6/1670 train_time:794ms step_avg:132.29ms +step:7/1670 train_time:889ms step_avg:126.98ms +step:8/1670 
train_time:984ms step_avg:122.99ms +step:9/1670 train_time:1079ms step_avg:119.87ms +step:10/1670 train_time:1174ms step_avg:117.36ms +step:11/1670 train_time:1269ms step_avg:115.34ms +step:12/1670 train_time:1369ms step_avg:114.10ms +step:13/1670 train_time:1468ms step_avg:112.90ms +step:14/1670 train_time:1565ms step_avg:111.77ms +step:15/1670 train_time:1661ms step_avg:110.70ms +step:16/1670 train_time:1756ms step_avg:109.75ms +step:17/1670 train_time:1852ms step_avg:108.96ms +step:18/1670 train_time:1948ms step_avg:108.23ms +step:19/1670 train_time:2044ms step_avg:107.60ms +step:20/1670 train_time:2141ms step_avg:107.04ms +step:21/1670 train_time:2236ms step_avg:106.47ms +step:22/1670 train_time:2332ms step_avg:105.99ms +step:23/1670 train_time:2430ms step_avg:105.64ms +step:24/1670 train_time:2527ms step_avg:105.30ms +step:25/1670 train_time:2625ms step_avg:104.98ms +step:26/1670 train_time:2721ms step_avg:104.64ms +step:27/1670 train_time:2816ms step_avg:104.31ms +step:28/1670 train_time:2912ms step_avg:104.01ms +step:29/1670 train_time:3008ms step_avg:103.74ms +step:30/1670 train_time:3105ms step_avg:103.50ms +step:31/1670 train_time:3202ms step_avg:103.28ms +step:32/1670 train_time:3298ms step_avg:103.07ms +step:33/1670 train_time:3394ms step_avg:102.85ms +step:34/1670 train_time:3490ms step_avg:102.65ms +step:35/1670 train_time:3587ms step_avg:102.48ms +step:36/1670 train_time:3683ms step_avg:102.31ms +step:37/1670 train_time:3779ms step_avg:102.14ms +step:38/1670 train_time:3874ms step_avg:101.96ms +step:39/1670 train_time:3970ms step_avg:101.78ms +step:40/1670 train_time:4066ms step_avg:101.65ms +step:41/1670 train_time:4162ms step_avg:101.51ms +step:42/1670 train_time:4259ms step_avg:101.41ms +step:43/1670 train_time:4355ms step_avg:101.28ms +step:44/1670 train_time:4451ms step_avg:101.16ms +step:45/1670 train_time:4548ms step_avg:101.07ms +step:46/1670 train_time:4643ms step_avg:100.94ms +step:47/1670 train_time:4740ms step_avg:100.84ms +step:48/1670 train_time:4835ms step_avg:100.73ms +step:49/1670 train_time:4931ms step_avg:100.63ms +step:50/1670 train_time:5027ms step_avg:100.53ms +step:51/1670 train_time:5122ms step_avg:100.44ms +step:52/1670 train_time:5218ms step_avg:100.35ms +step:53/1670 train_time:5314ms step_avg:100.27ms +step:54/1670 train_time:5411ms step_avg:100.20ms +step:55/1670 train_time:5508ms step_avg:100.14ms +step:56/1670 train_time:5604ms step_avg:100.07ms +step:57/1670 train_time:5700ms step_avg:100.01ms +step:58/1670 train_time:5796ms step_avg:99.93ms +step:59/1670 train_time:5891ms step_avg:99.85ms +step:60/1670 train_time:5987ms step_avg:99.78ms +step:61/1670 train_time:6082ms step_avg:99.70ms +step:62/1670 train_time:6178ms step_avg:99.65ms +step:63/1670 train_time:6274ms step_avg:99.58ms +step:64/1670 train_time:6370ms step_avg:99.53ms +step:65/1670 train_time:6466ms step_avg:99.48ms +step:66/1670 train_time:6564ms step_avg:99.45ms +step:67/1670 train_time:6661ms step_avg:99.41ms +step:68/1670 train_time:6757ms step_avg:99.37ms +step:69/1670 train_time:6853ms step_avg:99.31ms +step:70/1670 train_time:6949ms step_avg:99.27ms +step:71/1670 train_time:7045ms step_avg:99.22ms +step:72/1670 train_time:7141ms step_avg:99.18ms +step:73/1670 train_time:7237ms step_avg:99.13ms +step:74/1670 train_time:7332ms step_avg:99.08ms +step:75/1670 train_time:7428ms step_avg:99.04ms +step:76/1670 train_time:7524ms step_avg:99.01ms +step:77/1670 train_time:7621ms step_avg:98.98ms +step:78/1670 train_time:7716ms step_avg:98.93ms +step:79/1670 train_time:7811ms 
step_avg:98.88ms +step:80/1670 train_time:7907ms step_avg:98.84ms +step:81/1670 train_time:8004ms step_avg:98.81ms +step:82/1670 train_time:8100ms step_avg:98.78ms +step:83/1670 train_time:8196ms step_avg:98.74ms +step:84/1670 train_time:8291ms step_avg:98.70ms +step:85/1670 train_time:8387ms step_avg:98.67ms +step:86/1670 train_time:8482ms step_avg:98.63ms +step:87/1670 train_time:8579ms step_avg:98.61ms +step:88/1670 train_time:8674ms step_avg:98.57ms +step:89/1670 train_time:8770ms step_avg:98.53ms +step:90/1670 train_time:8866ms step_avg:98.51ms +step:91/1670 train_time:8963ms step_avg:98.49ms +step:92/1670 train_time:9059ms step_avg:98.47ms +step:93/1670 train_time:9154ms step_avg:98.43ms +step:94/1670 train_time:9250ms step_avg:98.41ms +step:95/1670 train_time:9345ms step_avg:98.37ms +step:96/1670 train_time:9440ms step_avg:98.34ms +step:97/1670 train_time:9536ms step_avg:98.31ms +step:98/1670 train_time:9631ms step_avg:98.28ms +step:99/1670 train_time:9727ms step_avg:98.26ms +step:100/1670 train_time:9823ms step_avg:98.23ms +step:101/1670 train_time:9919ms step_avg:98.21ms +step:102/1670 train_time:10015ms step_avg:98.18ms +step:103/1670 train_time:10111ms step_avg:98.16ms +step:104/1670 train_time:10207ms step_avg:98.15ms +step:105/1670 train_time:10303ms step_avg:98.13ms +step:106/1670 train_time:10400ms step_avg:98.11ms +step:107/1670 train_time:10495ms step_avg:98.09ms +step:108/1670 train_time:10591ms step_avg:98.06ms +step:109/1670 train_time:10686ms step_avg:98.04ms +step:110/1670 train_time:10781ms step_avg:98.01ms +step:111/1670 train_time:10877ms step_avg:97.99ms +step:112/1670 train_time:10972ms step_avg:97.96ms +step:113/1670 train_time:11069ms step_avg:97.95ms +step:114/1670 train_time:11164ms step_avg:97.93ms +step:115/1670 train_time:11260ms step_avg:97.91ms +step:116/1670 train_time:11356ms step_avg:97.90ms +step:117/1670 train_time:11452ms step_avg:97.88ms +step:118/1670 train_time:11548ms step_avg:97.86ms +step:119/1670 train_time:11643ms step_avg:97.84ms +step:120/1670 train_time:11739ms step_avg:97.82ms +step:121/1670 train_time:11834ms step_avg:97.80ms +step:122/1670 train_time:11929ms step_avg:97.78ms +step:123/1670 train_time:12025ms step_avg:97.77ms +step:124/1670 train_time:12122ms step_avg:97.75ms +step:125/1670 train_time:12217ms step_avg:97.74ms +step:125/1670 val_loss:4.2904 train_time:12312ms step_avg:98.49ms +step:126/1670 train_time:12333ms step_avg:97.88ms +step:127/1670 train_time:12418ms step_avg:97.78ms +step:128/1670 train_time:12523ms step_avg:97.84ms +step:129/1670 train_time:12621ms step_avg:97.83ms +step:130/1670 train_time:12717ms step_avg:97.82ms +step:131/1670 train_time:12812ms step_avg:97.80ms +step:132/1670 train_time:12906ms step_avg:97.78ms +step:133/1670 train_time:13001ms step_avg:97.75ms +step:134/1670 train_time:13095ms step_avg:97.73ms +step:135/1670 train_time:13190ms step_avg:97.70ms +step:136/1670 train_time:13284ms step_avg:97.68ms +step:137/1670 train_time:13380ms step_avg:97.67ms +step:138/1670 train_time:13478ms step_avg:97.67ms +step:139/1670 train_time:13575ms step_avg:97.66ms +step:140/1670 train_time:13672ms step_avg:97.66ms +step:141/1670 train_time:13769ms step_avg:97.65ms +step:142/1670 train_time:13865ms step_avg:97.64ms +step:143/1670 train_time:13960ms step_avg:97.62ms +step:144/1670 train_time:14054ms step_avg:97.60ms +step:145/1670 train_time:14149ms step_avg:97.58ms +step:146/1670 train_time:14244ms step_avg:97.56ms +step:147/1670 train_time:14338ms step_avg:97.54ms +step:148/1670 train_time:14434ms 
step_avg:97.53ms +step:149/1670 train_time:14530ms step_avg:97.52ms +step:150/1670 train_time:14627ms step_avg:97.52ms +step:151/1670 train_time:14724ms step_avg:97.51ms +step:152/1670 train_time:14820ms step_avg:97.50ms +step:153/1670 train_time:14916ms step_avg:97.49ms +step:154/1670 train_time:15011ms step_avg:97.48ms +step:155/1670 train_time:15107ms step_avg:97.46ms +step:156/1670 train_time:15202ms step_avg:97.45ms +step:157/1670 train_time:15298ms step_avg:97.44ms +step:158/1670 train_time:15394ms step_avg:97.43ms +step:159/1670 train_time:15489ms step_avg:97.42ms +step:160/1670 train_time:15585ms step_avg:97.41ms +step:161/1670 train_time:15681ms step_avg:97.39ms +step:162/1670 train_time:15776ms step_avg:97.38ms +step:163/1670 train_time:15872ms step_avg:97.37ms +step:164/1670 train_time:15968ms step_avg:97.36ms +step:165/1670 train_time:16063ms step_avg:97.35ms +step:166/1670 train_time:16158ms step_avg:97.34ms +step:167/1670 train_time:16254ms step_avg:97.33ms +step:168/1670 train_time:16348ms step_avg:97.31ms +step:169/1670 train_time:16445ms step_avg:97.31ms +step:170/1670 train_time:16540ms step_avg:97.30ms +step:171/1670 train_time:16636ms step_avg:97.29ms +step:172/1670 train_time:16732ms step_avg:97.28ms +step:173/1670 train_time:16828ms step_avg:97.27ms +step:174/1670 train_time:16924ms step_avg:97.26ms +step:175/1670 train_time:17019ms step_avg:97.25ms +step:176/1670 train_time:17114ms step_avg:97.24ms +step:177/1670 train_time:17209ms step_avg:97.23ms +step:178/1670 train_time:17304ms step_avg:97.22ms +step:179/1670 train_time:17400ms step_avg:97.21ms +step:180/1670 train_time:17495ms step_avg:97.19ms +step:181/1670 train_time:17591ms step_avg:97.19ms +step:182/1670 train_time:17687ms step_avg:97.18ms +step:183/1670 train_time:17783ms step_avg:97.17ms +step:184/1670 train_time:17879ms step_avg:97.17ms +step:185/1670 train_time:17974ms step_avg:97.16ms +step:186/1670 train_time:18069ms step_avg:97.14ms +step:187/1670 train_time:18164ms step_avg:97.13ms +step:188/1670 train_time:18260ms step_avg:97.13ms +step:189/1670 train_time:18355ms step_avg:97.12ms +step:190/1670 train_time:18451ms step_avg:97.11ms +step:191/1670 train_time:18547ms step_avg:97.11ms +step:192/1670 train_time:18642ms step_avg:97.09ms +step:193/1670 train_time:18738ms step_avg:97.09ms +step:194/1670 train_time:18834ms step_avg:97.08ms +step:195/1670 train_time:18930ms step_avg:97.08ms +step:196/1670 train_time:19026ms step_avg:97.07ms +step:197/1670 train_time:19122ms step_avg:97.06ms +step:198/1670 train_time:19217ms step_avg:97.05ms +step:199/1670 train_time:19313ms step_avg:97.05ms +step:200/1670 train_time:19408ms step_avg:97.04ms +step:201/1670 train_time:19505ms step_avg:97.04ms +step:202/1670 train_time:19600ms step_avg:97.03ms +step:203/1670 train_time:19695ms step_avg:97.02ms +step:204/1670 train_time:19790ms step_avg:97.01ms +step:205/1670 train_time:19886ms step_avg:97.01ms +step:206/1670 train_time:19982ms step_avg:97.00ms +step:207/1670 train_time:20078ms step_avg:96.99ms +step:208/1670 train_time:20173ms step_avg:96.99ms +step:209/1670 train_time:20268ms step_avg:96.98ms +step:210/1670 train_time:20363ms step_avg:96.97ms +step:211/1670 train_time:20459ms step_avg:96.96ms +step:212/1670 train_time:20555ms step_avg:96.96ms +step:213/1670 train_time:20884ms step_avg:98.05ms +step:214/1670 train_time:20972ms step_avg:98.00ms +step:215/1670 train_time:21066ms step_avg:97.98ms +step:216/1670 train_time:21161ms step_avg:97.97ms +step:217/1670 train_time:21256ms step_avg:97.95ms +step:218/1670 
train_time:21351ms step_avg:97.94ms +step:219/1670 train_time:21446ms step_avg:97.93ms +step:220/1670 train_time:21540ms step_avg:97.91ms +step:221/1670 train_time:21635ms step_avg:97.89ms +step:222/1670 train_time:21729ms step_avg:97.88ms +step:223/1670 train_time:21828ms step_avg:97.88ms +step:224/1670 train_time:21928ms step_avg:97.89ms +step:225/1670 train_time:22026ms step_avg:97.89ms +step:226/1670 train_time:22122ms step_avg:97.88ms +step:227/1670 train_time:22216ms step_avg:97.87ms +step:228/1670 train_time:22311ms step_avg:97.86ms +step:229/1670 train_time:22406ms step_avg:97.84ms +step:230/1670 train_time:22501ms step_avg:97.83ms +step:231/1670 train_time:22595ms step_avg:97.81ms +step:232/1670 train_time:22690ms step_avg:97.80ms +step:233/1670 train_time:22786ms step_avg:97.79ms +step:234/1670 train_time:22883ms step_avg:97.79ms +step:235/1670 train_time:22979ms step_avg:97.78ms +step:236/1670 train_time:23075ms step_avg:97.77ms +step:237/1670 train_time:23171ms step_avg:97.77ms +step:238/1670 train_time:23268ms step_avg:97.76ms +step:239/1670 train_time:23362ms step_avg:97.75ms +step:240/1670 train_time:23457ms step_avg:97.74ms +step:241/1670 train_time:23551ms step_avg:97.72ms +step:242/1670 train_time:23646ms step_avg:97.71ms +step:243/1670 train_time:23741ms step_avg:97.70ms +step:244/1670 train_time:23837ms step_avg:97.69ms +step:245/1670 train_time:23934ms step_avg:97.69ms +step:246/1670 train_time:24030ms step_avg:97.68ms +step:247/1670 train_time:24126ms step_avg:97.68ms +step:248/1670 train_time:24222ms step_avg:97.67ms +step:249/1670 train_time:24317ms step_avg:97.66ms +step:250/1670 train_time:24412ms step_avg:97.65ms +step:250/1670 val_loss:3.9654 train_time:24506ms step_avg:98.02ms +step:251/1670 train_time:24527ms step_avg:97.72ms +step:252/1670 train_time:24608ms step_avg:97.65ms +step:253/1670 train_time:24708ms step_avg:97.66ms +step:254/1670 train_time:24803ms step_avg:97.65ms +step:255/1670 train_time:24898ms step_avg:97.64ms +step:256/1670 train_time:24993ms step_avg:97.63ms +step:257/1670 train_time:25088ms step_avg:97.62ms +step:258/1670 train_time:25183ms step_avg:97.61ms +step:259/1670 train_time:25277ms step_avg:97.60ms +step:260/1670 train_time:25372ms step_avg:97.59ms +step:261/1670 train_time:25467ms step_avg:97.57ms +step:262/1670 train_time:25564ms step_avg:97.57ms +step:263/1670 train_time:25662ms step_avg:97.57ms +step:264/1670 train_time:25758ms step_avg:97.57ms +step:265/1670 train_time:25854ms step_avg:97.56ms +step:266/1670 train_time:25950ms step_avg:97.56ms +step:267/1670 train_time:26045ms step_avg:97.55ms +step:268/1670 train_time:26140ms step_avg:97.54ms +step:269/1670 train_time:26234ms step_avg:97.52ms +step:270/1670 train_time:26329ms step_avg:97.52ms +step:271/1670 train_time:26424ms step_avg:97.51ms +step:272/1670 train_time:26520ms step_avg:97.50ms +step:273/1670 train_time:26617ms step_avg:97.50ms +step:274/1670 train_time:26713ms step_avg:97.49ms +step:275/1670 train_time:26809ms step_avg:97.49ms +step:276/1670 train_time:26905ms step_avg:97.48ms +step:277/1670 train_time:27000ms step_avg:97.47ms +step:278/1670 train_time:27095ms step_avg:97.47ms +step:279/1670 train_time:27190ms step_avg:97.46ms +step:280/1670 train_time:27286ms step_avg:97.45ms +step:281/1670 train_time:27380ms step_avg:97.44ms +step:282/1670 train_time:27475ms step_avg:97.43ms +step:283/1670 train_time:27571ms step_avg:97.43ms +step:284/1670 train_time:27668ms step_avg:97.42ms +step:285/1670 train_time:27763ms step_avg:97.41ms +step:286/1670 train_time:27859ms 
step_avg:97.41ms +step:287/1670 train_time:27955ms step_avg:97.40ms +step:288/1670 train_time:28051ms step_avg:97.40ms +step:289/1670 train_time:28146ms step_avg:97.39ms +step:290/1670 train_time:28240ms step_avg:97.38ms +step:291/1670 train_time:28335ms step_avg:97.37ms +step:292/1670 train_time:28430ms step_avg:97.36ms +step:293/1670 train_time:28525ms step_avg:97.36ms +step:294/1670 train_time:28621ms step_avg:97.35ms +step:295/1670 train_time:28717ms step_avg:97.35ms +step:296/1670 train_time:28813ms step_avg:97.34ms +step:297/1670 train_time:28909ms step_avg:97.34ms +step:298/1670 train_time:29005ms step_avg:97.33ms +step:299/1670 train_time:29100ms step_avg:97.32ms +step:300/1670 train_time:29196ms step_avg:97.32ms +step:301/1670 train_time:29292ms step_avg:97.31ms +step:302/1670 train_time:29387ms step_avg:97.31ms +step:303/1670 train_time:29481ms step_avg:97.30ms +step:304/1670 train_time:29577ms step_avg:97.29ms +step:305/1670 train_time:29674ms step_avg:97.29ms +step:306/1670 train_time:29769ms step_avg:97.28ms +step:307/1670 train_time:29864ms step_avg:97.28ms +step:308/1670 train_time:29960ms step_avg:97.27ms +step:309/1670 train_time:30056ms step_avg:97.27ms +step:310/1670 train_time:30151ms step_avg:97.26ms +step:311/1670 train_time:30247ms step_avg:97.26ms +step:312/1670 train_time:30343ms step_avg:97.25ms +step:313/1670 train_time:30438ms step_avg:97.25ms +step:314/1670 train_time:30533ms step_avg:97.24ms +step:315/1670 train_time:30629ms step_avg:97.23ms +step:316/1670 train_time:30724ms step_avg:97.23ms +step:317/1670 train_time:30819ms step_avg:97.22ms +step:318/1670 train_time:30916ms step_avg:97.22ms +step:319/1670 train_time:31011ms step_avg:97.21ms +step:320/1670 train_time:31106ms step_avg:97.21ms +step:321/1670 train_time:31203ms step_avg:97.20ms +step:322/1670 train_time:31298ms step_avg:97.20ms +step:323/1670 train_time:31394ms step_avg:97.20ms +step:324/1670 train_time:31490ms step_avg:97.19ms +step:325/1670 train_time:31585ms step_avg:97.19ms +step:326/1670 train_time:31680ms step_avg:97.18ms +step:327/1670 train_time:31776ms step_avg:97.17ms +step:328/1670 train_time:31872ms step_avg:97.17ms +step:329/1670 train_time:31967ms step_avg:97.17ms +step:330/1670 train_time:32063ms step_avg:97.16ms +step:331/1670 train_time:32158ms step_avg:97.15ms +step:332/1670 train_time:32255ms step_avg:97.15ms +step:333/1670 train_time:32350ms step_avg:97.15ms +step:334/1670 train_time:32446ms step_avg:97.14ms +step:335/1670 train_time:32541ms step_avg:97.14ms +step:336/1670 train_time:32638ms step_avg:97.14ms +step:337/1670 train_time:32733ms step_avg:97.13ms +step:338/1670 train_time:32829ms step_avg:97.13ms +step:339/1670 train_time:32924ms step_avg:97.12ms +step:340/1670 train_time:33019ms step_avg:97.12ms +step:341/1670 train_time:33115ms step_avg:97.11ms +step:342/1670 train_time:33211ms step_avg:97.11ms +step:343/1670 train_time:33307ms step_avg:97.10ms +step:344/1670 train_time:33401ms step_avg:97.10ms +step:345/1670 train_time:33497ms step_avg:97.09ms +step:346/1670 train_time:33594ms step_avg:97.09ms +step:347/1670 train_time:33690ms step_avg:97.09ms +step:348/1670 train_time:33787ms step_avg:97.09ms +step:349/1670 train_time:33882ms step_avg:97.08ms +step:350/1670 train_time:33977ms step_avg:97.08ms +step:351/1670 train_time:34072ms step_avg:97.07ms +step:352/1670 train_time:34168ms step_avg:97.07ms +step:353/1670 train_time:34264ms step_avg:97.06ms +step:354/1670 train_time:34359ms step_avg:97.06ms +step:355/1670 train_time:34455ms step_avg:97.06ms +step:356/1670 
train_time:34550ms step_avg:97.05ms +step:357/1670 train_time:34645ms step_avg:97.05ms +step:358/1670 train_time:34741ms step_avg:97.04ms +step:359/1670 train_time:34837ms step_avg:97.04ms +step:360/1670 train_time:34933ms step_avg:97.04ms +step:361/1670 train_time:35029ms step_avg:97.03ms +step:362/1670 train_time:35124ms step_avg:97.03ms +step:363/1670 train_time:35220ms step_avg:97.03ms +step:364/1670 train_time:35315ms step_avg:97.02ms +step:365/1670 train_time:35411ms step_avg:97.02ms +step:366/1670 train_time:35506ms step_avg:97.01ms +step:367/1670 train_time:35601ms step_avg:97.01ms +step:368/1670 train_time:35697ms step_avg:97.00ms +step:369/1670 train_time:35793ms step_avg:97.00ms +step:370/1670 train_time:35889ms step_avg:97.00ms +step:371/1670 train_time:35985ms step_avg:97.00ms +step:372/1670 train_time:36080ms step_avg:96.99ms +step:373/1670 train_time:36176ms step_avg:96.99ms +step:374/1670 train_time:36272ms step_avg:96.98ms +step:375/1670 train_time:36368ms step_avg:96.98ms +step:375/1670 val_loss:3.8173 train_time:36462ms step_avg:97.23ms +step:376/1670 train_time:36485ms step_avg:97.03ms +step:377/1670 train_time:36566ms step_avg:96.99ms +step:378/1670 train_time:36662ms step_avg:96.99ms +step:379/1670 train_time:36759ms step_avg:96.99ms +step:380/1670 train_time:36855ms step_avg:96.99ms +step:381/1670 train_time:36950ms step_avg:96.98ms +step:382/1670 train_time:37045ms step_avg:96.98ms +step:383/1670 train_time:37140ms step_avg:96.97ms +step:384/1670 train_time:37235ms step_avg:96.97ms +step:385/1670 train_time:37329ms step_avg:96.96ms +step:386/1670 train_time:37425ms step_avg:96.96ms +step:387/1670 train_time:37523ms step_avg:96.96ms +step:388/1670 train_time:37621ms step_avg:96.96ms +step:389/1670 train_time:37718ms step_avg:96.96ms +step:390/1670 train_time:37813ms step_avg:96.96ms +step:391/1670 train_time:37908ms step_avg:96.95ms +step:392/1670 train_time:38003ms step_avg:96.95ms +step:393/1670 train_time:38098ms step_avg:96.94ms +step:394/1670 train_time:38193ms step_avg:96.94ms +step:395/1670 train_time:38287ms step_avg:96.93ms +step:396/1670 train_time:38382ms step_avg:96.93ms +step:397/1670 train_time:38479ms step_avg:96.92ms +step:398/1670 train_time:38575ms step_avg:96.92ms +step:399/1670 train_time:38672ms step_avg:96.92ms +step:400/1670 train_time:38768ms step_avg:96.92ms +step:401/1670 train_time:38864ms step_avg:96.92ms +step:402/1670 train_time:38959ms step_avg:96.91ms +step:403/1670 train_time:39055ms step_avg:96.91ms +step:404/1670 train_time:39150ms step_avg:96.91ms +step:405/1670 train_time:39245ms step_avg:96.90ms +step:406/1670 train_time:39340ms step_avg:96.90ms +step:407/1670 train_time:39436ms step_avg:96.89ms +step:408/1670 train_time:39532ms step_avg:96.89ms +step:409/1670 train_time:39629ms step_avg:96.89ms +step:410/1670 train_time:39724ms step_avg:96.89ms +step:411/1670 train_time:39820ms step_avg:96.89ms +step:412/1670 train_time:39916ms step_avg:96.88ms +step:413/1670 train_time:40011ms step_avg:96.88ms +step:414/1670 train_time:40107ms step_avg:96.88ms +step:415/1670 train_time:40202ms step_avg:96.87ms +step:416/1670 train_time:40298ms step_avg:96.87ms +step:417/1670 train_time:40393ms step_avg:96.87ms +step:418/1670 train_time:40488ms step_avg:96.86ms +step:419/1670 train_time:40584ms step_avg:96.86ms +step:420/1670 train_time:40680ms step_avg:96.86ms +step:421/1670 train_time:40776ms step_avg:96.85ms +step:422/1670 train_time:40872ms step_avg:96.85ms +step:423/1670 train_time:40967ms step_avg:96.85ms +step:424/1670 train_time:41063ms 
step_avg:96.85ms +step:425/1670 train_time:41326ms step_avg:97.24ms +step:426/1670 train_time:41451ms step_avg:97.30ms +step:427/1670 train_time:41544ms step_avg:97.29ms +step:428/1670 train_time:41639ms step_avg:97.29ms +step:429/1670 train_time:41734ms step_avg:97.28ms +step:430/1670 train_time:41829ms step_avg:97.28ms +step:431/1670 train_time:41923ms step_avg:97.27ms +step:432/1670 train_time:42018ms step_avg:97.26ms +step:433/1670 train_time:42113ms step_avg:97.26ms +step:434/1670 train_time:42209ms step_avg:97.25ms +step:435/1670 train_time:42305ms step_avg:97.25ms +step:436/1670 train_time:42404ms step_avg:97.26ms +step:437/1670 train_time:42502ms step_avg:97.26ms +step:438/1670 train_time:42597ms step_avg:97.25ms +step:439/1670 train_time:42692ms step_avg:97.25ms +step:440/1670 train_time:42788ms step_avg:97.24ms +step:441/1670 train_time:42883ms step_avg:97.24ms +step:442/1670 train_time:42978ms step_avg:97.24ms +step:443/1670 train_time:43073ms step_avg:97.23ms +step:444/1670 train_time:43168ms step_avg:97.23ms +step:445/1670 train_time:43263ms step_avg:97.22ms +step:446/1670 train_time:43361ms step_avg:97.22ms +step:447/1670 train_time:43459ms step_avg:97.22ms +step:448/1670 train_time:43556ms step_avg:97.22ms +step:449/1670 train_time:43652ms step_avg:97.22ms +step:450/1670 train_time:43747ms step_avg:97.22ms +step:451/1670 train_time:43843ms step_avg:97.21ms +step:452/1670 train_time:43937ms step_avg:97.21ms +step:453/1670 train_time:44033ms step_avg:97.20ms +step:454/1670 train_time:44127ms step_avg:97.20ms +step:455/1670 train_time:44223ms step_avg:97.19ms +step:456/1670 train_time:44319ms step_avg:97.19ms +step:457/1670 train_time:44415ms step_avg:97.19ms +step:458/1670 train_time:44512ms step_avg:97.19ms +step:459/1670 train_time:44608ms step_avg:97.18ms +step:460/1670 train_time:44703ms step_avg:97.18ms +step:461/1670 train_time:44798ms step_avg:97.18ms +step:462/1670 train_time:44894ms step_avg:97.17ms +step:463/1670 train_time:44989ms step_avg:97.17ms +step:464/1670 train_time:45083ms step_avg:97.16ms +step:465/1670 train_time:45179ms step_avg:97.16ms +step:466/1670 train_time:45274ms step_avg:97.16ms +step:467/1670 train_time:45371ms step_avg:97.15ms +step:468/1670 train_time:45466ms step_avg:97.15ms +step:469/1670 train_time:45562ms step_avg:97.15ms +step:470/1670 train_time:45659ms step_avg:97.15ms +step:471/1670 train_time:45755ms step_avg:97.14ms +step:472/1670 train_time:45850ms step_avg:97.14ms +step:473/1670 train_time:45946ms step_avg:97.14ms +step:474/1670 train_time:46041ms step_avg:97.13ms +step:475/1670 train_time:46136ms step_avg:97.13ms +step:476/1670 train_time:46232ms step_avg:97.13ms +step:477/1670 train_time:46328ms step_avg:97.12ms +step:478/1670 train_time:46423ms step_avg:97.12ms +step:479/1670 train_time:46520ms step_avg:97.12ms +step:480/1670 train_time:46616ms step_avg:97.12ms +step:481/1670 train_time:46712ms step_avg:97.11ms +step:482/1670 train_time:46808ms step_avg:97.11ms +step:483/1670 train_time:46903ms step_avg:97.11ms +step:484/1670 train_time:46998ms step_avg:97.10ms +step:485/1670 train_time:47093ms step_avg:97.10ms +step:486/1670 train_time:47189ms step_avg:97.10ms +step:487/1670 train_time:47283ms step_avg:97.09ms +step:488/1670 train_time:47379ms step_avg:97.09ms +step:489/1670 train_time:47476ms step_avg:97.09ms +step:490/1670 train_time:47572ms step_avg:97.09ms +step:491/1670 train_time:47668ms step_avg:97.08ms +step:492/1670 train_time:47764ms step_avg:97.08ms +step:493/1670 train_time:47860ms step_avg:97.08ms +step:494/1670 
train_time:47957ms step_avg:97.08ms +step:495/1670 train_time:48053ms step_avg:97.08ms +step:496/1670 train_time:48147ms step_avg:97.07ms +step:497/1670 train_time:48243ms step_avg:97.07ms +step:498/1670 train_time:48339ms step_avg:97.07ms +step:499/1670 train_time:48435ms step_avg:97.06ms +step:500/1670 train_time:48532ms step_avg:97.06ms +step:500/1670 val_loss:3.7160 train_time:48627ms step_avg:97.25ms +step:501/1670 train_time:48648ms step_avg:97.10ms +step:502/1670 train_time:48732ms step_avg:97.08ms +step:503/1670 train_time:48834ms step_avg:97.09ms +step:504/1670 train_time:48931ms step_avg:97.08ms +step:505/1670 train_time:49026ms step_avg:97.08ms +step:506/1670 train_time:49120ms step_avg:97.08ms +step:507/1670 train_time:49215ms step_avg:97.07ms +step:508/1670 train_time:49309ms step_avg:97.07ms +step:509/1670 train_time:49404ms step_avg:97.06ms +step:510/1670 train_time:49498ms step_avg:97.05ms +step:511/1670 train_time:49594ms step_avg:97.05ms +step:512/1670 train_time:49692ms step_avg:97.05ms +step:513/1670 train_time:49789ms step_avg:97.06ms +step:514/1670 train_time:49886ms step_avg:97.06ms +step:515/1670 train_time:49982ms step_avg:97.05ms +step:516/1670 train_time:50078ms step_avg:97.05ms +step:517/1670 train_time:50173ms step_avg:97.05ms +step:518/1670 train_time:50269ms step_avg:97.04ms +step:519/1670 train_time:50364ms step_avg:97.04ms +step:520/1670 train_time:50459ms step_avg:97.04ms +step:521/1670 train_time:50554ms step_avg:97.03ms +step:522/1670 train_time:50651ms step_avg:97.03ms +step:523/1670 train_time:50747ms step_avg:97.03ms +step:524/1670 train_time:50844ms step_avg:97.03ms +step:525/1670 train_time:50940ms step_avg:97.03ms +step:526/1670 train_time:51035ms step_avg:97.03ms +step:527/1670 train_time:51131ms step_avg:97.02ms +step:528/1670 train_time:51227ms step_avg:97.02ms +step:529/1670 train_time:51322ms step_avg:97.02ms +step:530/1670 train_time:51417ms step_avg:97.01ms +step:531/1670 train_time:51512ms step_avg:97.01ms +step:532/1670 train_time:51608ms step_avg:97.01ms +step:533/1670 train_time:51704ms step_avg:97.01ms +step:534/1670 train_time:51801ms step_avg:97.00ms +step:535/1670 train_time:51897ms step_avg:97.00ms +step:536/1670 train_time:51993ms step_avg:97.00ms +step:537/1670 train_time:52089ms step_avg:97.00ms +step:538/1670 train_time:52184ms step_avg:97.00ms +step:539/1670 train_time:52280ms step_avg:96.99ms +step:540/1670 train_time:52375ms step_avg:96.99ms +step:541/1670 train_time:52470ms step_avg:96.99ms +step:542/1670 train_time:52566ms step_avg:96.98ms +step:543/1670 train_time:52661ms step_avg:96.98ms +step:544/1670 train_time:52756ms step_avg:96.98ms +step:545/1670 train_time:52852ms step_avg:96.98ms +step:546/1670 train_time:52949ms step_avg:96.98ms +step:547/1670 train_time:53045ms step_avg:96.97ms +step:548/1670 train_time:53141ms step_avg:96.97ms +step:549/1670 train_time:53237ms step_avg:96.97ms +step:550/1670 train_time:53333ms step_avg:96.97ms +step:551/1670 train_time:53429ms step_avg:96.97ms +step:552/1670 train_time:53524ms step_avg:96.96ms +step:553/1670 train_time:53619ms step_avg:96.96ms +step:554/1670 train_time:53714ms step_avg:96.96ms +step:555/1670 train_time:53810ms step_avg:96.95ms +step:556/1670 train_time:53906ms step_avg:96.95ms +step:557/1670 train_time:54003ms step_avg:96.95ms +step:558/1670 train_time:54099ms step_avg:96.95ms +step:559/1670 train_time:54196ms step_avg:96.95ms +step:560/1670 train_time:54293ms step_avg:96.95ms +step:561/1670 train_time:54390ms step_avg:96.95ms +step:562/1670 train_time:54487ms 
step_avg:96.95ms +step:563/1670 train_time:54584ms step_avg:96.95ms +step:564/1670 train_time:54681ms step_avg:96.95ms +step:565/1670 train_time:54778ms step_avg:96.95ms +step:566/1670 train_time:54875ms step_avg:96.95ms +step:567/1670 train_time:54973ms step_avg:96.95ms +step:568/1670 train_time:55071ms step_avg:96.96ms +step:569/1670 train_time:55168ms step_avg:96.96ms +step:570/1670 train_time:55265ms step_avg:96.96ms +step:571/1670 train_time:55362ms step_avg:96.96ms +step:572/1670 train_time:55459ms step_avg:96.96ms +step:573/1670 train_time:55556ms step_avg:96.96ms +step:574/1670 train_time:55652ms step_avg:96.96ms +step:575/1670 train_time:55751ms step_avg:96.96ms +step:576/1670 train_time:55850ms step_avg:96.96ms +step:577/1670 train_time:55948ms step_avg:96.96ms +step:578/1670 train_time:56045ms step_avg:96.96ms +step:579/1670 train_time:56142ms step_avg:96.96ms +step:580/1670 train_time:56240ms step_avg:96.97ms +step:581/1670 train_time:56336ms step_avg:96.96ms +step:582/1670 train_time:56432ms step_avg:96.96ms +step:583/1670 train_time:56529ms step_avg:96.96ms +step:584/1670 train_time:56626ms step_avg:96.96ms +step:585/1670 train_time:56723ms step_avg:96.96ms +step:586/1670 train_time:56820ms step_avg:96.96ms +step:587/1670 train_time:56918ms step_avg:96.96ms +step:588/1670 train_time:57015ms step_avg:96.96ms +step:589/1670 train_time:57113ms step_avg:96.97ms +step:590/1670 train_time:57211ms step_avg:96.97ms +step:591/1670 train_time:57309ms step_avg:96.97ms +step:592/1670 train_time:57406ms step_avg:96.97ms +step:593/1670 train_time:57503ms step_avg:96.97ms +step:594/1670 train_time:57600ms step_avg:96.97ms +step:595/1670 train_time:57697ms step_avg:96.97ms +step:596/1670 train_time:57793ms step_avg:96.97ms +step:597/1670 train_time:57892ms step_avg:96.97ms +step:598/1670 train_time:57989ms step_avg:96.97ms +step:599/1670 train_time:58087ms step_avg:96.97ms +step:600/1670 train_time:58184ms step_avg:96.97ms +step:601/1670 train_time:58281ms step_avg:96.97ms +step:602/1670 train_time:58378ms step_avg:96.97ms +step:603/1670 train_time:58475ms step_avg:96.97ms +step:604/1670 train_time:58572ms step_avg:96.97ms +step:605/1670 train_time:58670ms step_avg:96.98ms +step:606/1670 train_time:58768ms step_avg:96.98ms +step:607/1670 train_time:58865ms step_avg:96.98ms +step:608/1670 train_time:58962ms step_avg:96.98ms +step:609/1670 train_time:59059ms step_avg:96.98ms +step:610/1670 train_time:59156ms step_avg:96.98ms +step:611/1670 train_time:59254ms step_avg:96.98ms +step:612/1670 train_time:59351ms step_avg:96.98ms +step:613/1670 train_time:59448ms step_avg:96.98ms +step:614/1670 train_time:59545ms step_avg:96.98ms +step:615/1670 train_time:59642ms step_avg:96.98ms +step:616/1670 train_time:59739ms step_avg:96.98ms +step:617/1670 train_time:59836ms step_avg:96.98ms +step:618/1670 train_time:59933ms step_avg:96.98ms +step:619/1670 train_time:60030ms step_avg:96.98ms +step:620/1670 train_time:60129ms step_avg:96.98ms +step:621/1670 train_time:60225ms step_avg:96.98ms +step:622/1670 train_time:60322ms step_avg:96.98ms +step:623/1670 train_time:60418ms step_avg:96.98ms +step:624/1670 train_time:60515ms step_avg:96.98ms +step:625/1670 train_time:60612ms step_avg:96.98ms +step:625/1670 val_loss:3.6152 train_time:60710ms step_avg:97.14ms +step:626/1670 train_time:60731ms step_avg:97.01ms +step:627/1670 train_time:60820ms step_avg:97.00ms +step:628/1670 train_time:60918ms step_avg:97.00ms +step:629/1670 train_time:61014ms step_avg:97.00ms +step:630/1670 train_time:61110ms step_avg:97.00ms 
+step:631/1670 train_time:61206ms step_avg:97.00ms +step:632/1670 train_time:61301ms step_avg:97.00ms +step:633/1670 train_time:61397ms step_avg:96.99ms +step:634/1670 train_time:61493ms step_avg:96.99ms +step:635/1670 train_time:61589ms step_avg:96.99ms +step:636/1670 train_time:61687ms step_avg:96.99ms +step:637/1670 train_time:61786ms step_avg:97.00ms +step:638/1670 train_time:61886ms step_avg:97.00ms +step:639/1670 train_time:62255ms step_avg:97.43ms +step:640/1670 train_time:62349ms step_avg:97.42ms +step:641/1670 train_time:62444ms step_avg:97.42ms +step:642/1670 train_time:62540ms step_avg:97.41ms +step:643/1670 train_time:62636ms step_avg:97.41ms +step:644/1670 train_time:62733ms step_avg:97.41ms +step:645/1670 train_time:62828ms step_avg:97.41ms +step:646/1670 train_time:62924ms step_avg:97.41ms +step:647/1670 train_time:63020ms step_avg:97.40ms +step:648/1670 train_time:63116ms step_avg:97.40ms +step:649/1670 train_time:63220ms step_avg:97.41ms +step:650/1670 train_time:63320ms step_avg:97.42ms +step:651/1670 train_time:63418ms step_avg:97.42ms +step:652/1670 train_time:63515ms step_avg:97.42ms +step:653/1670 train_time:63612ms step_avg:97.42ms +step:654/1670 train_time:63709ms step_avg:97.42ms +step:655/1670 train_time:63805ms step_avg:97.41ms +step:656/1670 train_time:63900ms step_avg:97.41ms +step:657/1670 train_time:63996ms step_avg:97.41ms +step:658/1670 train_time:64093ms step_avg:97.41ms +step:659/1670 train_time:64195ms step_avg:97.41ms +step:660/1670 train_time:64295ms step_avg:97.42ms +step:661/1670 train_time:64394ms step_avg:97.42ms +step:662/1670 train_time:64491ms step_avg:97.42ms +step:663/1670 train_time:64589ms step_avg:97.42ms +step:664/1670 train_time:64686ms step_avg:97.42ms +step:665/1670 train_time:64782ms step_avg:97.42ms +step:666/1670 train_time:64878ms step_avg:97.41ms +step:667/1670 train_time:64974ms step_avg:97.41ms +step:668/1670 train_time:65070ms step_avg:97.41ms +step:669/1670 train_time:65168ms step_avg:97.41ms +step:670/1670 train_time:65267ms step_avg:97.41ms +step:671/1670 train_time:65366ms step_avg:97.42ms +step:672/1670 train_time:65464ms step_avg:97.42ms +step:673/1670 train_time:65562ms step_avg:97.42ms +step:674/1670 train_time:65659ms step_avg:97.42ms +step:675/1670 train_time:65757ms step_avg:97.42ms +step:676/1670 train_time:65854ms step_avg:97.42ms +step:677/1670 train_time:65951ms step_avg:97.42ms +step:678/1670 train_time:66048ms step_avg:97.42ms +step:679/1670 train_time:66145ms step_avg:97.42ms +step:680/1670 train_time:66242ms step_avg:97.41ms +step:681/1670 train_time:66340ms step_avg:97.42ms +step:682/1670 train_time:66439ms step_avg:97.42ms +step:683/1670 train_time:66536ms step_avg:97.42ms +step:684/1670 train_time:66634ms step_avg:97.42ms +step:685/1670 train_time:66731ms step_avg:97.42ms +step:686/1670 train_time:66829ms step_avg:97.42ms +step:687/1670 train_time:66926ms step_avg:97.42ms +step:688/1670 train_time:67022ms step_avg:97.41ms +step:689/1670 train_time:67118ms step_avg:97.41ms +step:690/1670 train_time:67215ms step_avg:97.41ms +step:691/1670 train_time:67314ms step_avg:97.42ms +step:692/1670 train_time:67413ms step_avg:97.42ms +step:693/1670 train_time:67511ms step_avg:97.42ms +step:694/1670 train_time:67608ms step_avg:97.42ms +step:695/1670 train_time:67705ms step_avg:97.42ms +step:696/1670 train_time:67802ms step_avg:97.42ms +step:697/1670 train_time:67899ms step_avg:97.42ms +step:698/1670 train_time:67995ms step_avg:97.41ms +step:699/1670 train_time:68094ms step_avg:97.42ms +step:700/1670 train_time:68192ms 
step_avg:97.42ms +step:701/1670 train_time:68290ms step_avg:97.42ms +step:702/1670 train_time:68388ms step_avg:97.42ms +step:703/1670 train_time:68486ms step_avg:97.42ms +step:704/1670 train_time:68584ms step_avg:97.42ms +step:705/1670 train_time:68681ms step_avg:97.42ms +step:706/1670 train_time:68778ms step_avg:97.42ms +step:707/1670 train_time:68875ms step_avg:97.42ms +step:708/1670 train_time:68973ms step_avg:97.42ms +step:709/1670 train_time:69070ms step_avg:97.42ms +step:710/1670 train_time:69166ms step_avg:97.42ms +step:711/1670 train_time:69264ms step_avg:97.42ms +step:712/1670 train_time:69361ms step_avg:97.42ms +step:713/1670 train_time:69458ms step_avg:97.42ms +step:714/1670 train_time:69555ms step_avg:97.42ms +step:715/1670 train_time:69654ms step_avg:97.42ms +step:716/1670 train_time:69752ms step_avg:97.42ms +step:717/1670 train_time:69849ms step_avg:97.42ms +step:718/1670 train_time:69946ms step_avg:97.42ms +step:719/1670 train_time:70042ms step_avg:97.42ms +step:720/1670 train_time:70139ms step_avg:97.41ms +step:721/1670 train_time:70236ms step_avg:97.42ms +step:722/1670 train_time:70334ms step_avg:97.41ms +step:723/1670 train_time:70431ms step_avg:97.42ms +step:724/1670 train_time:70528ms step_avg:97.41ms +step:725/1670 train_time:70626ms step_avg:97.42ms +step:726/1670 train_time:70724ms step_avg:97.42ms +step:727/1670 train_time:70821ms step_avg:97.42ms +step:728/1670 train_time:70918ms step_avg:97.42ms +step:729/1670 train_time:71015ms step_avg:97.41ms +step:730/1670 train_time:71113ms step_avg:97.42ms +step:731/1670 train_time:71211ms step_avg:97.42ms +step:732/1670 train_time:71308ms step_avg:97.42ms +step:733/1670 train_time:71406ms step_avg:97.42ms +step:734/1670 train_time:71503ms step_avg:97.41ms +step:735/1670 train_time:71599ms step_avg:97.41ms +step:736/1670 train_time:71697ms step_avg:97.41ms +step:737/1670 train_time:71795ms step_avg:97.42ms +step:738/1670 train_time:71894ms step_avg:97.42ms +step:739/1670 train_time:71991ms step_avg:97.42ms +step:740/1670 train_time:72088ms step_avg:97.42ms +step:741/1670 train_time:72185ms step_avg:97.42ms +step:742/1670 train_time:72282ms step_avg:97.42ms +step:743/1670 train_time:72378ms step_avg:97.41ms +step:744/1670 train_time:72476ms step_avg:97.41ms +step:745/1670 train_time:72574ms step_avg:97.41ms +step:746/1670 train_time:72671ms step_avg:97.41ms +step:747/1670 train_time:72769ms step_avg:97.41ms +step:748/1670 train_time:72866ms step_avg:97.42ms +step:749/1670 train_time:72964ms step_avg:97.42ms +step:750/1670 train_time:73061ms step_avg:97.41ms +step:750/1670 val_loss:3.5627 train_time:73157ms step_avg:97.54ms +step:751/1670 train_time:73179ms step_avg:97.44ms +step:752/1670 train_time:73261ms step_avg:97.42ms +step:753/1670 train_time:73364ms step_avg:97.43ms +step:754/1670 train_time:73461ms step_avg:97.43ms +step:755/1670 train_time:73558ms step_avg:97.43ms +step:756/1670 train_time:73654ms step_avg:97.43ms +step:757/1670 train_time:73750ms step_avg:97.42ms +step:758/1670 train_time:73846ms step_avg:97.42ms +step:759/1670 train_time:73943ms step_avg:97.42ms +step:760/1670 train_time:74039ms step_avg:97.42ms +step:761/1670 train_time:74136ms step_avg:97.42ms +step:762/1670 train_time:74235ms step_avg:97.42ms +step:763/1670 train_time:74335ms step_avg:97.42ms +step:764/1670 train_time:74434ms step_avg:97.43ms +step:765/1670 train_time:74530ms step_avg:97.43ms +step:766/1670 train_time:74627ms step_avg:97.42ms +step:767/1670 train_time:74724ms step_avg:97.42ms +step:768/1670 train_time:74820ms step_avg:97.42ms 
+step:769/1670 train_time:74916ms step_avg:97.42ms +step:770/1670 train_time:75012ms step_avg:97.42ms +step:771/1670 train_time:75109ms step_avg:97.42ms +step:772/1670 train_time:75208ms step_avg:97.42ms +step:773/1670 train_time:75307ms step_avg:97.42ms +step:774/1670 train_time:75405ms step_avg:97.42ms +step:775/1670 train_time:75504ms step_avg:97.42ms +step:776/1670 train_time:75601ms step_avg:97.42ms +step:777/1670 train_time:75698ms step_avg:97.42ms +step:778/1670 train_time:75794ms step_avg:97.42ms +step:779/1670 train_time:75890ms step_avg:97.42ms +step:780/1670 train_time:75987ms step_avg:97.42ms +step:781/1670 train_time:76084ms step_avg:97.42ms +step:782/1670 train_time:76183ms step_avg:97.42ms +step:783/1670 train_time:76281ms step_avg:97.42ms +step:784/1670 train_time:76380ms step_avg:97.42ms +step:785/1670 train_time:76477ms step_avg:97.42ms +step:786/1670 train_time:76576ms step_avg:97.42ms +step:787/1670 train_time:76673ms step_avg:97.42ms +step:788/1670 train_time:76770ms step_avg:97.42ms +step:789/1670 train_time:76867ms step_avg:97.42ms +step:790/1670 train_time:76962ms step_avg:97.42ms +step:791/1670 train_time:77059ms step_avg:97.42ms +step:792/1670 train_time:77157ms step_avg:97.42ms +step:793/1670 train_time:77255ms step_avg:97.42ms +step:794/1670 train_time:77353ms step_avg:97.42ms +step:795/1670 train_time:77451ms step_avg:97.42ms +step:796/1670 train_time:77548ms step_avg:97.42ms +step:797/1670 train_time:77645ms step_avg:97.42ms +step:798/1670 train_time:77744ms step_avg:97.42ms +step:799/1670 train_time:77841ms step_avg:97.42ms +step:800/1670 train_time:77937ms step_avg:97.42ms +step:801/1670 train_time:78034ms step_avg:97.42ms +step:802/1670 train_time:78131ms step_avg:97.42ms +step:803/1670 train_time:78228ms step_avg:97.42ms +step:804/1670 train_time:78326ms step_avg:97.42ms +step:805/1670 train_time:78424ms step_avg:97.42ms +step:806/1670 train_time:78522ms step_avg:97.42ms +step:807/1670 train_time:78619ms step_avg:97.42ms +step:808/1670 train_time:78716ms step_avg:97.42ms +step:809/1670 train_time:78813ms step_avg:97.42ms +step:810/1670 train_time:78909ms step_avg:97.42ms +step:811/1670 train_time:79006ms step_avg:97.42ms +step:812/1670 train_time:79103ms step_avg:97.42ms +step:813/1670 train_time:79200ms step_avg:97.42ms +step:814/1670 train_time:79298ms step_avg:97.42ms +step:815/1670 train_time:79396ms step_avg:97.42ms +step:816/1670 train_time:79495ms step_avg:97.42ms +step:817/1670 train_time:79592ms step_avg:97.42ms +step:818/1670 train_time:79688ms step_avg:97.42ms +step:819/1670 train_time:79785ms step_avg:97.42ms +step:820/1670 train_time:79883ms step_avg:97.42ms +step:821/1670 train_time:79980ms step_avg:97.42ms +step:822/1670 train_time:80077ms step_avg:97.42ms +step:823/1670 train_time:80173ms step_avg:97.42ms +step:824/1670 train_time:80270ms step_avg:97.41ms +step:825/1670 train_time:80368ms step_avg:97.42ms +step:826/1670 train_time:80466ms step_avg:97.42ms +step:827/1670 train_time:80565ms step_avg:97.42ms +step:828/1670 train_time:80663ms step_avg:97.42ms +step:829/1670 train_time:80759ms step_avg:97.42ms +step:830/1670 train_time:80857ms step_avg:97.42ms +step:831/1670 train_time:80953ms step_avg:97.42ms +step:832/1670 train_time:81050ms step_avg:97.42ms +step:833/1670 train_time:81148ms step_avg:97.42ms +step:834/1670 train_time:81245ms step_avg:97.42ms +step:835/1670 train_time:81343ms step_avg:97.42ms +step:836/1670 train_time:81441ms step_avg:97.42ms +step:837/1670 train_time:81539ms step_avg:97.42ms +step:838/1670 train_time:81635ms 
step_avg:97.42ms +step:839/1670 train_time:81732ms step_avg:97.42ms +step:840/1670 train_time:81830ms step_avg:97.42ms +step:841/1670 train_time:81926ms step_avg:97.42ms +step:842/1670 train_time:82024ms step_avg:97.42ms +step:843/1670 train_time:82122ms step_avg:97.42ms +step:844/1670 train_time:82219ms step_avg:97.42ms +step:845/1670 train_time:82316ms step_avg:97.42ms +step:846/1670 train_time:82413ms step_avg:97.42ms +step:847/1670 train_time:82511ms step_avg:97.42ms +step:848/1670 train_time:82608ms step_avg:97.41ms +step:849/1670 train_time:82705ms step_avg:97.42ms +step:850/1670 train_time:82804ms step_avg:97.42ms +step:851/1670 train_time:83075ms step_avg:97.62ms +step:852/1670 train_time:83271ms step_avg:97.74ms +step:853/1670 train_time:83366ms step_avg:97.73ms +step:854/1670 train_time:83462ms step_avg:97.73ms +step:855/1670 train_time:83558ms step_avg:97.73ms +step:856/1670 train_time:83654ms step_avg:97.73ms +step:857/1670 train_time:83750ms step_avg:97.72ms +step:858/1670 train_time:83846ms step_avg:97.72ms +step:859/1670 train_time:83943ms step_avg:97.72ms +step:860/1670 train_time:84039ms step_avg:97.72ms +step:861/1670 train_time:84138ms step_avg:97.72ms +step:862/1670 train_time:84241ms step_avg:97.73ms +step:863/1670 train_time:84340ms step_avg:97.73ms +step:864/1670 train_time:84438ms step_avg:97.73ms +step:865/1670 train_time:84535ms step_avg:97.73ms +step:866/1670 train_time:84631ms step_avg:97.73ms +step:867/1670 train_time:84727ms step_avg:97.72ms +step:868/1670 train_time:84823ms step_avg:97.72ms +step:869/1670 train_time:84920ms step_avg:97.72ms +step:870/1670 train_time:85016ms step_avg:97.72ms +step:871/1670 train_time:85113ms step_avg:97.72ms +step:872/1670 train_time:85212ms step_avg:97.72ms +step:873/1670 train_time:85310ms step_avg:97.72ms +step:874/1670 train_time:85409ms step_avg:97.72ms +step:875/1670 train_time:85507ms step_avg:97.72ms +step:875/1670 val_loss:3.5202 train_time:85604ms step_avg:97.83ms +step:876/1670 train_time:85625ms step_avg:97.75ms +step:877/1670 train_time:85710ms step_avg:97.73ms +step:878/1670 train_time:85810ms step_avg:97.73ms +step:879/1670 train_time:85908ms step_avg:97.73ms +step:880/1670 train_time:86005ms step_avg:97.73ms +step:881/1670 train_time:86101ms step_avg:97.73ms +step:882/1670 train_time:86197ms step_avg:97.73ms +step:883/1670 train_time:86294ms step_avg:97.73ms +step:884/1670 train_time:86389ms step_avg:97.73ms +step:885/1670 train_time:86485ms step_avg:97.72ms +step:886/1670 train_time:86583ms step_avg:97.72ms +step:887/1670 train_time:86685ms step_avg:97.73ms +step:888/1670 train_time:86784ms step_avg:97.73ms +step:889/1670 train_time:86882ms step_avg:97.73ms +step:890/1670 train_time:86979ms step_avg:97.73ms +step:891/1670 train_time:87077ms step_avg:97.73ms +step:892/1670 train_time:87174ms step_avg:97.73ms +step:893/1670 train_time:87270ms step_avg:97.73ms +step:894/1670 train_time:87366ms step_avg:97.72ms +step:895/1670 train_time:87463ms step_avg:97.72ms +step:896/1670 train_time:87560ms step_avg:97.72ms +step:897/1670 train_time:87659ms step_avg:97.73ms +step:898/1670 train_time:87759ms step_avg:97.73ms +step:899/1670 train_time:87858ms step_avg:97.73ms +step:900/1670 train_time:87956ms step_avg:97.73ms +step:901/1670 train_time:88054ms step_avg:97.73ms +step:902/1670 train_time:88150ms step_avg:97.73ms +step:903/1670 train_time:88247ms step_avg:97.73ms +step:904/1670 train_time:88343ms step_avg:97.72ms +step:905/1670 train_time:88439ms step_avg:97.72ms +step:906/1670 train_time:88536ms step_avg:97.72ms 
+step:907/1670 train_time:88634ms step_avg:97.72ms +step:908/1670 train_time:88732ms step_avg:97.72ms +step:909/1670 train_time:88829ms step_avg:97.72ms +step:910/1670 train_time:88926ms step_avg:97.72ms +step:911/1670 train_time:89024ms step_avg:97.72ms +step:912/1670 train_time:89121ms step_avg:97.72ms +step:913/1670 train_time:89219ms step_avg:97.72ms +step:914/1670 train_time:89317ms step_avg:97.72ms +step:915/1670 train_time:89413ms step_avg:97.72ms +step:916/1670 train_time:89510ms step_avg:97.72ms +step:917/1670 train_time:89607ms step_avg:97.72ms +step:918/1670 train_time:89704ms step_avg:97.72ms +step:919/1670 train_time:89803ms step_avg:97.72ms +step:920/1670 train_time:89901ms step_avg:97.72ms +step:921/1670 train_time:89999ms step_avg:97.72ms +step:922/1670 train_time:90096ms step_avg:97.72ms +step:923/1670 train_time:90193ms step_avg:97.72ms +step:924/1670 train_time:90290ms step_avg:97.72ms +step:925/1670 train_time:90387ms step_avg:97.72ms +step:926/1670 train_time:90483ms step_avg:97.71ms +step:927/1670 train_time:90581ms step_avg:97.71ms +step:928/1670 train_time:90679ms step_avg:97.71ms +step:929/1670 train_time:90777ms step_avg:97.72ms +step:930/1670 train_time:90876ms step_avg:97.72ms +step:931/1670 train_time:90973ms step_avg:97.71ms +step:932/1670 train_time:91070ms step_avg:97.71ms +step:933/1670 train_time:91167ms step_avg:97.71ms +step:934/1670 train_time:91265ms step_avg:97.71ms +step:935/1670 train_time:91361ms step_avg:97.71ms +step:936/1670 train_time:91458ms step_avg:97.71ms +step:937/1670 train_time:91556ms step_avg:97.71ms +step:938/1670 train_time:91655ms step_avg:97.71ms +step:939/1670 train_time:91752ms step_avg:97.71ms +step:940/1670 train_time:91849ms step_avg:97.71ms +step:941/1670 train_time:91947ms step_avg:97.71ms +step:942/1670 train_time:92043ms step_avg:97.71ms +step:943/1670 train_time:92141ms step_avg:97.71ms +step:944/1670 train_time:92238ms step_avg:97.71ms +step:945/1670 train_time:92335ms step_avg:97.71ms +step:946/1670 train_time:92432ms step_avg:97.71ms +step:947/1670 train_time:92529ms step_avg:97.71ms +step:948/1670 train_time:92626ms step_avg:97.71ms +step:949/1670 train_time:92724ms step_avg:97.71ms +step:950/1670 train_time:92822ms step_avg:97.71ms +step:951/1670 train_time:92919ms step_avg:97.71ms +step:952/1670 train_time:93017ms step_avg:97.71ms +step:953/1670 train_time:93115ms step_avg:97.71ms +step:954/1670 train_time:93213ms step_avg:97.71ms +step:955/1670 train_time:93311ms step_avg:97.71ms +step:956/1670 train_time:93407ms step_avg:97.71ms +step:957/1670 train_time:93504ms step_avg:97.71ms +step:958/1670 train_time:93601ms step_avg:97.70ms +step:959/1670 train_time:93699ms step_avg:97.70ms +step:960/1670 train_time:93797ms step_avg:97.70ms +step:961/1670 train_time:93894ms step_avg:97.70ms +step:962/1670 train_time:93993ms step_avg:97.71ms +step:963/1670 train_time:94090ms step_avg:97.71ms +step:964/1670 train_time:94187ms step_avg:97.70ms +step:965/1670 train_time:94285ms step_avg:97.70ms +step:966/1670 train_time:94382ms step_avg:97.70ms +step:967/1670 train_time:94479ms step_avg:97.70ms +step:968/1670 train_time:94577ms step_avg:97.70ms +step:969/1670 train_time:94674ms step_avg:97.70ms +step:970/1670 train_time:94772ms step_avg:97.70ms +step:971/1670 train_time:94868ms step_avg:97.70ms +step:972/1670 train_time:94965ms step_avg:97.70ms +step:973/1670 train_time:95062ms step_avg:97.70ms +step:974/1670 train_time:95160ms step_avg:97.70ms +step:975/1670 train_time:95257ms step_avg:97.70ms +step:976/1670 train_time:95355ms 
step_avg:97.70ms +step:977/1670 train_time:95452ms step_avg:97.70ms +step:978/1670 train_time:95548ms step_avg:97.70ms +step:979/1670 train_time:95646ms step_avg:97.70ms +step:980/1670 train_time:95743ms step_avg:97.70ms +step:981/1670 train_time:95841ms step_avg:97.70ms +step:982/1670 train_time:95939ms step_avg:97.70ms +step:983/1670 train_time:96037ms step_avg:97.70ms +step:984/1670 train_time:96134ms step_avg:97.70ms +step:985/1670 train_time:96231ms step_avg:97.70ms +step:986/1670 train_time:96328ms step_avg:97.70ms +step:987/1670 train_time:96425ms step_avg:97.70ms +step:988/1670 train_time:96523ms step_avg:97.69ms +step:989/1670 train_time:96620ms step_avg:97.69ms +step:990/1670 train_time:96718ms step_avg:97.69ms +step:991/1670 train_time:96816ms step_avg:97.69ms +step:992/1670 train_time:96913ms step_avg:97.69ms +step:993/1670 train_time:97009ms step_avg:97.69ms +step:994/1670 train_time:97106ms step_avg:97.69ms +step:995/1670 train_time:97203ms step_avg:97.69ms +step:996/1670 train_time:97300ms step_avg:97.69ms +step:997/1670 train_time:97398ms step_avg:97.69ms +step:998/1670 train_time:97497ms step_avg:97.69ms +step:999/1670 train_time:97594ms step_avg:97.69ms +step:1000/1670 train_time:97692ms step_avg:97.69ms +step:1000/1670 val_loss:3.4785 train_time:97788ms step_avg:97.79ms +step:1001/1670 train_time:97810ms step_avg:97.71ms +step:1002/1670 train_time:97893ms step_avg:97.70ms +step:1003/1670 train_time:97994ms step_avg:97.70ms +step:1004/1670 train_time:98093ms step_avg:97.70ms +step:1005/1670 train_time:98190ms step_avg:97.70ms +step:1006/1670 train_time:98288ms step_avg:97.70ms +step:1007/1670 train_time:98384ms step_avg:97.70ms +step:1008/1670 train_time:98480ms step_avg:97.70ms +step:1009/1670 train_time:98576ms step_avg:97.70ms +step:1010/1670 train_time:98672ms step_avg:97.70ms +step:1011/1670 train_time:98770ms step_avg:97.70ms +step:1012/1670 train_time:98870ms step_avg:97.70ms +step:1013/1670 train_time:98969ms step_avg:97.70ms +step:1014/1670 train_time:99067ms step_avg:97.70ms +step:1015/1670 train_time:99164ms step_avg:97.70ms +step:1016/1670 train_time:99261ms step_avg:97.70ms +step:1017/1670 train_time:99357ms step_avg:97.70ms +step:1018/1670 train_time:99454ms step_avg:97.70ms +step:1019/1670 train_time:99551ms step_avg:97.69ms +step:1020/1670 train_time:99648ms step_avg:97.69ms +step:1021/1670 train_time:99746ms step_avg:97.69ms +step:1022/1670 train_time:99844ms step_avg:97.69ms +step:1023/1670 train_time:99941ms step_avg:97.69ms +step:1024/1670 train_time:100039ms step_avg:97.69ms +step:1025/1670 train_time:100137ms step_avg:97.69ms +step:1026/1670 train_time:100234ms step_avg:97.69ms +step:1027/1670 train_time:100331ms step_avg:97.69ms +step:1028/1670 train_time:100428ms step_avg:97.69ms +step:1029/1670 train_time:100525ms step_avg:97.69ms +step:1030/1670 train_time:100621ms step_avg:97.69ms +step:1031/1670 train_time:100717ms step_avg:97.69ms +step:1032/1670 train_time:100816ms step_avg:97.69ms +step:1033/1670 train_time:100915ms step_avg:97.69ms +step:1034/1670 train_time:101013ms step_avg:97.69ms +step:1035/1670 train_time:101112ms step_avg:97.69ms +step:1036/1670 train_time:101209ms step_avg:97.69ms +step:1037/1670 train_time:101308ms step_avg:97.69ms +step:1038/1670 train_time:101404ms step_avg:97.69ms +step:1039/1670 train_time:101500ms step_avg:97.69ms +step:1040/1670 train_time:101597ms step_avg:97.69ms +step:1041/1670 train_time:101693ms step_avg:97.69ms +step:1042/1670 train_time:101791ms step_avg:97.69ms +step:1043/1670 train_time:101889ms 
step_avg:97.69ms +step:1044/1670 train_time:101988ms step_avg:97.69ms +step:1045/1670 train_time:102085ms step_avg:97.69ms +step:1046/1670 train_time:102182ms step_avg:97.69ms +step:1047/1670 train_time:102279ms step_avg:97.69ms +step:1048/1670 train_time:102377ms step_avg:97.69ms +step:1049/1670 train_time:102475ms step_avg:97.69ms +step:1050/1670 train_time:102573ms step_avg:97.69ms +step:1051/1670 train_time:102670ms step_avg:97.69ms +step:1052/1670 train_time:102767ms step_avg:97.69ms +step:1053/1670 train_time:102866ms step_avg:97.69ms +step:1054/1670 train_time:102963ms step_avg:97.69ms +step:1055/1670 train_time:103060ms step_avg:97.69ms +step:1056/1670 train_time:103158ms step_avg:97.69ms +step:1057/1670 train_time:103255ms step_avg:97.69ms +step:1058/1670 train_time:103352ms step_avg:97.69ms +step:1059/1670 train_time:103449ms step_avg:97.69ms +step:1060/1670 train_time:103547ms step_avg:97.69ms +step:1061/1670 train_time:103644ms step_avg:97.69ms +step:1062/1670 train_time:103896ms step_avg:97.83ms +step:1063/1670 train_time:104054ms step_avg:97.89ms +step:1064/1670 train_time:104149ms step_avg:97.88ms +step:1065/1670 train_time:104245ms step_avg:97.88ms +step:1066/1670 train_time:104341ms step_avg:97.88ms +step:1067/1670 train_time:104436ms step_avg:97.88ms +step:1068/1670 train_time:104532ms step_avg:97.88ms +step:1069/1670 train_time:104628ms step_avg:97.87ms +step:1070/1670 train_time:104724ms step_avg:97.87ms +step:1071/1670 train_time:104820ms step_avg:97.87ms +step:1072/1670 train_time:104920ms step_avg:97.87ms +step:1073/1670 train_time:105022ms step_avg:97.88ms +step:1074/1670 train_time:105120ms step_avg:97.88ms +step:1075/1670 train_time:105218ms step_avg:97.88ms +step:1076/1670 train_time:105316ms step_avg:97.88ms +step:1077/1670 train_time:105414ms step_avg:97.88ms +step:1078/1670 train_time:105510ms step_avg:97.88ms +step:1079/1670 train_time:105607ms step_avg:97.87ms +step:1080/1670 train_time:105704ms step_avg:97.87ms +step:1081/1670 train_time:105800ms step_avg:97.87ms +step:1082/1670 train_time:105897ms step_avg:97.87ms +step:1083/1670 train_time:105997ms step_avg:97.87ms +step:1084/1670 train_time:106097ms step_avg:97.88ms +step:1085/1670 train_time:106195ms step_avg:97.88ms +step:1086/1670 train_time:106293ms step_avg:97.88ms +step:1087/1670 train_time:106390ms step_avg:97.88ms +step:1088/1670 train_time:106487ms step_avg:97.87ms +step:1089/1670 train_time:106584ms step_avg:97.87ms +step:1090/1670 train_time:106680ms step_avg:97.87ms +step:1091/1670 train_time:106777ms step_avg:97.87ms +step:1092/1670 train_time:106874ms step_avg:97.87ms +step:1093/1670 train_time:106972ms step_avg:97.87ms +step:1094/1670 train_time:107071ms step_avg:97.87ms +step:1095/1670 train_time:107170ms step_avg:97.87ms +step:1096/1670 train_time:107267ms step_avg:97.87ms +step:1097/1670 train_time:107364ms step_avg:97.87ms +step:1098/1670 train_time:107460ms step_avg:97.87ms +step:1099/1670 train_time:107558ms step_avg:97.87ms +step:1100/1670 train_time:107655ms step_avg:97.87ms +step:1101/1670 train_time:107752ms step_avg:97.87ms +step:1102/1670 train_time:107850ms step_avg:97.87ms +step:1103/1670 train_time:107947ms step_avg:97.87ms +step:1104/1670 train_time:108044ms step_avg:97.87ms +step:1105/1670 train_time:108142ms step_avg:97.87ms +step:1106/1670 train_time:108241ms step_avg:97.87ms +step:1107/1670 train_time:108339ms step_avg:97.87ms +step:1108/1670 train_time:108436ms step_avg:97.87ms +step:1109/1670 train_time:108533ms step_avg:97.87ms +step:1110/1670 train_time:108631ms 
step_avg:97.87ms +step:1111/1670 train_time:108727ms step_avg:97.86ms +step:1112/1670 train_time:108825ms step_avg:97.86ms +step:1113/1670 train_time:108921ms step_avg:97.86ms +step:1114/1670 train_time:109019ms step_avg:97.86ms +step:1115/1670 train_time:109117ms step_avg:97.86ms +step:1116/1670 train_time:109216ms step_avg:97.86ms +step:1117/1670 train_time:109315ms step_avg:97.87ms +step:1118/1670 train_time:109413ms step_avg:97.87ms +step:1119/1670 train_time:109512ms step_avg:97.87ms +step:1120/1670 train_time:109610ms step_avg:97.87ms +step:1121/1670 train_time:109708ms step_avg:97.87ms +step:1122/1670 train_time:109805ms step_avg:97.87ms +step:1123/1670 train_time:109903ms step_avg:97.87ms +step:1124/1670 train_time:110001ms step_avg:97.87ms +step:1125/1670 train_time:110099ms step_avg:97.87ms +step:1125/1670 val_loss:3.4246 train_time:110197ms step_avg:97.95ms +step:1126/1670 train_time:110219ms step_avg:97.89ms +step:1127/1670 train_time:110310ms step_avg:97.88ms +step:1128/1670 train_time:110408ms step_avg:97.88ms +step:1129/1670 train_time:110506ms step_avg:97.88ms +step:1130/1670 train_time:110602ms step_avg:97.88ms +step:1131/1670 train_time:110698ms step_avg:97.88ms +step:1132/1670 train_time:110795ms step_avg:97.88ms +step:1133/1670 train_time:110892ms step_avg:97.87ms +step:1134/1670 train_time:110989ms step_avg:97.87ms +step:1135/1670 train_time:111087ms step_avg:97.87ms +step:1136/1670 train_time:111187ms step_avg:97.88ms +step:1137/1670 train_time:111288ms step_avg:97.88ms +step:1138/1670 train_time:111388ms step_avg:97.88ms +step:1139/1670 train_time:111487ms step_avg:97.88ms +step:1140/1670 train_time:111584ms step_avg:97.88ms +step:1141/1670 train_time:111681ms step_avg:97.88ms +step:1142/1670 train_time:111776ms step_avg:97.88ms +step:1143/1670 train_time:111873ms step_avg:97.88ms +step:1144/1670 train_time:111971ms step_avg:97.88ms +step:1145/1670 train_time:112070ms step_avg:97.88ms +step:1146/1670 train_time:112171ms step_avg:97.88ms +step:1147/1670 train_time:112271ms step_avg:97.88ms +step:1148/1670 train_time:112371ms step_avg:97.88ms +step:1149/1670 train_time:112471ms step_avg:97.89ms +step:1150/1670 train_time:112571ms step_avg:97.89ms +step:1151/1670 train_time:112670ms step_avg:97.89ms +step:1152/1670 train_time:112770ms step_avg:97.89ms +step:1153/1670 train_time:112867ms step_avg:97.89ms +step:1154/1670 train_time:112964ms step_avg:97.89ms +step:1155/1670 train_time:113061ms step_avg:97.89ms +step:1156/1670 train_time:113158ms step_avg:97.89ms +step:1157/1670 train_time:113256ms step_avg:97.89ms +step:1158/1670 train_time:113356ms step_avg:97.89ms +step:1159/1670 train_time:113456ms step_avg:97.89ms +step:1160/1670 train_time:113556ms step_avg:97.89ms +step:1161/1670 train_time:113656ms step_avg:97.89ms +step:1162/1670 train_time:113756ms step_avg:97.90ms +step:1163/1670 train_time:113854ms step_avg:97.90ms +step:1164/1670 train_time:113952ms step_avg:97.90ms +step:1165/1670 train_time:114050ms step_avg:97.90ms +step:1166/1670 train_time:114148ms step_avg:97.90ms +step:1167/1670 train_time:114246ms step_avg:97.90ms +step:1168/1670 train_time:114343ms step_avg:97.90ms +step:1169/1670 train_time:114440ms step_avg:97.90ms +step:1170/1670 train_time:114538ms step_avg:97.90ms +step:1171/1670 train_time:114637ms step_avg:97.90ms +step:1172/1670 train_time:114735ms step_avg:97.90ms +step:1173/1670 train_time:114834ms step_avg:97.90ms +step:1174/1670 train_time:114931ms step_avg:97.90ms +step:1175/1670 train_time:115030ms step_avg:97.90ms +step:1176/1670 
train_time:115128ms step_avg:97.90ms +step:1177/1670 train_time:115227ms step_avg:97.90ms +step:1178/1670 train_time:115325ms step_avg:97.90ms +step:1179/1670 train_time:115423ms step_avg:97.90ms +step:1180/1670 train_time:115520ms step_avg:97.90ms +step:1181/1670 train_time:115618ms step_avg:97.90ms +step:1182/1670 train_time:115715ms step_avg:97.90ms +step:1183/1670 train_time:115813ms step_avg:97.90ms +step:1184/1670 train_time:115911ms step_avg:97.90ms +step:1185/1670 train_time:116009ms step_avg:97.90ms +step:1186/1670 train_time:116106ms step_avg:97.90ms +step:1187/1670 train_time:116204ms step_avg:97.90ms +step:1188/1670 train_time:116301ms step_avg:97.90ms +step:1189/1670 train_time:116399ms step_avg:97.90ms +step:1190/1670 train_time:116497ms step_avg:97.90ms +step:1191/1670 train_time:116596ms step_avg:97.90ms +step:1192/1670 train_time:116694ms step_avg:97.90ms +step:1193/1670 train_time:116792ms step_avg:97.90ms +step:1194/1670 train_time:116890ms step_avg:97.90ms +step:1195/1670 train_time:116987ms step_avg:97.90ms +step:1196/1670 train_time:117084ms step_avg:97.90ms +step:1197/1670 train_time:117182ms step_avg:97.90ms +step:1198/1670 train_time:117280ms step_avg:97.90ms +step:1199/1670 train_time:117378ms step_avg:97.90ms +step:1200/1670 train_time:117476ms step_avg:97.90ms +step:1201/1670 train_time:117575ms step_avg:97.90ms +step:1202/1670 train_time:117673ms step_avg:97.90ms +step:1203/1670 train_time:117772ms step_avg:97.90ms +step:1204/1670 train_time:117869ms step_avg:97.90ms +step:1205/1670 train_time:117966ms step_avg:97.90ms +step:1206/1670 train_time:118064ms step_avg:97.90ms +step:1207/1670 train_time:118162ms step_avg:97.90ms +step:1208/1670 train_time:118260ms step_avg:97.90ms +step:1209/1670 train_time:118357ms step_avg:97.90ms +step:1210/1670 train_time:118455ms step_avg:97.90ms +step:1211/1670 train_time:118554ms step_avg:97.90ms +step:1212/1670 train_time:118652ms step_avg:97.90ms +step:1213/1670 train_time:118750ms step_avg:97.90ms +step:1214/1670 train_time:118847ms step_avg:97.90ms +step:1215/1670 train_time:118944ms step_avg:97.90ms +step:1216/1670 train_time:119041ms step_avg:97.90ms +step:1217/1670 train_time:119139ms step_avg:97.90ms +step:1218/1670 train_time:119236ms step_avg:97.90ms +step:1219/1670 train_time:119335ms step_avg:97.90ms +step:1220/1670 train_time:119434ms step_avg:97.90ms +step:1221/1670 train_time:119533ms step_avg:97.90ms +step:1222/1670 train_time:119631ms step_avg:97.90ms +step:1223/1670 train_time:119730ms step_avg:97.90ms +step:1224/1670 train_time:119828ms step_avg:97.90ms +step:1225/1670 train_time:119927ms step_avg:97.90ms +step:1226/1670 train_time:120025ms step_avg:97.90ms +step:1227/1670 train_time:120123ms step_avg:97.90ms +step:1228/1670 train_time:120220ms step_avg:97.90ms +step:1229/1670 train_time:120317ms step_avg:97.90ms +step:1230/1670 train_time:120414ms step_avg:97.90ms +step:1231/1670 train_time:120513ms step_avg:97.90ms +step:1232/1670 train_time:120613ms step_avg:97.90ms +step:1233/1670 train_time:120711ms step_avg:97.90ms +step:1234/1670 train_time:120809ms step_avg:97.90ms +step:1235/1670 train_time:120908ms step_avg:97.90ms +step:1236/1670 train_time:121006ms step_avg:97.90ms +step:1237/1670 train_time:121104ms step_avg:97.90ms +step:1238/1670 train_time:121202ms step_avg:97.90ms +step:1239/1670 train_time:121300ms step_avg:97.90ms +step:1240/1670 train_time:121397ms step_avg:97.90ms +step:1241/1670 train_time:121494ms step_avg:97.90ms +step:1242/1670 train_time:121592ms step_avg:97.90ms +step:1243/1670 
train_time:121690ms step_avg:97.90ms +step:1244/1670 train_time:121787ms step_avg:97.90ms +step:1245/1670 train_time:121885ms step_avg:97.90ms +step:1246/1670 train_time:121983ms step_avg:97.90ms +step:1247/1670 train_time:122081ms step_avg:97.90ms +step:1248/1670 train_time:122178ms step_avg:97.90ms +step:1249/1670 train_time:122276ms step_avg:97.90ms +step:1250/1670 train_time:122374ms step_avg:97.90ms +step:1250/1670 val_loss:3.3811 train_time:122471ms step_avg:97.98ms +step:1251/1670 train_time:122493ms step_avg:97.92ms +step:1252/1670 train_time:122579ms step_avg:97.91ms +step:1253/1670 train_time:122678ms step_avg:97.91ms +step:1254/1670 train_time:122776ms step_avg:97.91ms +step:1255/1670 train_time:122875ms step_avg:97.91ms +step:1256/1670 train_time:122971ms step_avg:97.91ms +step:1257/1670 train_time:123068ms step_avg:97.91ms +step:1258/1670 train_time:123165ms step_avg:97.91ms +step:1259/1670 train_time:123261ms step_avg:97.90ms +step:1260/1670 train_time:123358ms step_avg:97.90ms +step:1261/1670 train_time:123458ms step_avg:97.90ms +step:1262/1670 train_time:123560ms step_avg:97.91ms +step:1263/1670 train_time:123660ms step_avg:97.91ms +step:1264/1670 train_time:123758ms step_avg:97.91ms +step:1265/1670 train_time:123856ms step_avg:97.91ms +step:1266/1670 train_time:123955ms step_avg:97.91ms +step:1267/1670 train_time:124053ms step_avg:97.91ms +step:1268/1670 train_time:124151ms step_avg:97.91ms +step:1269/1670 train_time:124247ms step_avg:97.91ms +step:1270/1670 train_time:124344ms step_avg:97.91ms +step:1271/1670 train_time:124442ms step_avg:97.91ms +step:1272/1670 train_time:124541ms step_avg:97.91ms +step:1273/1670 train_time:124640ms step_avg:97.91ms +step:1274/1670 train_time:125029ms step_avg:98.14ms +step:1275/1670 train_time:125103ms step_avg:98.12ms +step:1276/1670 train_time:125199ms step_avg:98.12ms +step:1277/1670 train_time:125295ms step_avg:98.12ms +step:1278/1670 train_time:125392ms step_avg:98.12ms +step:1279/1670 train_time:125488ms step_avg:98.11ms +step:1280/1670 train_time:125585ms step_avg:98.11ms +step:1281/1670 train_time:125681ms step_avg:98.11ms +step:1282/1670 train_time:125779ms step_avg:98.11ms +step:1283/1670 train_time:125876ms step_avg:98.11ms +step:1284/1670 train_time:125982ms step_avg:98.12ms +step:1285/1670 train_time:126086ms step_avg:98.12ms +step:1286/1670 train_time:126184ms step_avg:98.12ms +step:1287/1670 train_time:126281ms step_avg:98.12ms +step:1288/1670 train_time:126379ms step_avg:98.12ms +step:1289/1670 train_time:126477ms step_avg:98.12ms +step:1290/1670 train_time:126575ms step_avg:98.12ms +step:1291/1670 train_time:126672ms step_avg:98.12ms +step:1292/1670 train_time:126769ms step_avg:98.12ms +step:1293/1670 train_time:126866ms step_avg:98.12ms +step:1294/1670 train_time:126966ms step_avg:98.12ms +step:1295/1670 train_time:127065ms step_avg:98.12ms +step:1296/1670 train_time:127163ms step_avg:98.12ms +step:1297/1670 train_time:127261ms step_avg:98.12ms +step:1298/1670 train_time:127359ms step_avg:98.12ms +step:1299/1670 train_time:127457ms step_avg:98.12ms +step:1300/1670 train_time:127555ms step_avg:98.12ms +step:1301/1670 train_time:127654ms step_avg:98.12ms +step:1302/1670 train_time:127750ms step_avg:98.12ms +step:1303/1670 train_time:127847ms step_avg:98.12ms +step:1304/1670 train_time:127947ms step_avg:98.12ms +step:1305/1670 train_time:128044ms step_avg:98.12ms +step:1306/1670 train_time:128143ms step_avg:98.12ms +step:1307/1670 train_time:128240ms step_avg:98.12ms +step:1308/1670 train_time:128338ms step_avg:98.12ms 
+step:1309/1670 train_time:128436ms step_avg:98.12ms +step:1310/1670 train_time:128535ms step_avg:98.12ms +step:1311/1670 train_time:128635ms step_avg:98.12ms +step:1312/1670 train_time:128732ms step_avg:98.12ms +step:1313/1670 train_time:128830ms step_avg:98.12ms +step:1314/1670 train_time:128930ms step_avg:98.12ms +step:1315/1670 train_time:129031ms step_avg:98.12ms +step:1316/1670 train_time:129130ms step_avg:98.12ms +step:1317/1670 train_time:129228ms step_avg:98.12ms +step:1318/1670 train_time:129327ms step_avg:98.12ms +step:1319/1670 train_time:129424ms step_avg:98.12ms +step:1320/1670 train_time:129522ms step_avg:98.12ms +step:1321/1670 train_time:129621ms step_avg:98.12ms +step:1322/1670 train_time:129719ms step_avg:98.12ms +step:1323/1670 train_time:129819ms step_avg:98.12ms +step:1324/1670 train_time:129918ms step_avg:98.13ms +step:1325/1670 train_time:130018ms step_avg:98.13ms +step:1326/1670 train_time:130117ms step_avg:98.13ms +step:1327/1670 train_time:130216ms step_avg:98.13ms +step:1328/1670 train_time:130316ms step_avg:98.13ms +step:1329/1670 train_time:130416ms step_avg:98.13ms +step:1330/1670 train_time:130516ms step_avg:98.13ms +step:1331/1670 train_time:130616ms step_avg:98.13ms +step:1332/1670 train_time:130714ms step_avg:98.13ms +step:1333/1670 train_time:130812ms step_avg:98.13ms +step:1334/1670 train_time:130910ms step_avg:98.13ms +step:1335/1670 train_time:131008ms step_avg:98.13ms +step:1336/1670 train_time:131107ms step_avg:98.13ms +step:1337/1670 train_time:131205ms step_avg:98.13ms +step:1338/1670 train_time:131303ms step_avg:98.13ms +step:1339/1670 train_time:131401ms step_avg:98.13ms +step:1340/1670 train_time:131500ms step_avg:98.13ms +step:1341/1670 train_time:131598ms step_avg:98.13ms +step:1342/1670 train_time:131695ms step_avg:98.13ms +step:1343/1670 train_time:131794ms step_avg:98.13ms +step:1344/1670 train_time:131892ms step_avg:98.13ms +step:1345/1670 train_time:131989ms step_avg:98.13ms +step:1346/1670 train_time:132087ms step_avg:98.13ms +step:1347/1670 train_time:132186ms step_avg:98.13ms +step:1348/1670 train_time:132283ms step_avg:98.13ms +step:1349/1670 train_time:132382ms step_avg:98.13ms +step:1350/1670 train_time:132480ms step_avg:98.13ms +step:1351/1670 train_time:132579ms step_avg:98.13ms +step:1352/1670 train_time:132677ms step_avg:98.13ms +step:1353/1670 train_time:132775ms step_avg:98.13ms +step:1354/1670 train_time:132874ms step_avg:98.13ms +step:1355/1670 train_time:132973ms step_avg:98.13ms +step:1356/1670 train_time:133072ms step_avg:98.14ms +step:1357/1670 train_time:133173ms step_avg:98.14ms +step:1358/1670 train_time:133272ms step_avg:98.14ms +step:1359/1670 train_time:133373ms step_avg:98.14ms +step:1360/1670 train_time:133472ms step_avg:98.14ms +step:1361/1670 train_time:133570ms step_avg:98.14ms +step:1362/1670 train_time:133668ms step_avg:98.14ms +step:1363/1670 train_time:133765ms step_avg:98.14ms +step:1364/1670 train_time:133863ms step_avg:98.14ms +step:1365/1670 train_time:133961ms step_avg:98.14ms +step:1366/1670 train_time:134060ms step_avg:98.14ms +step:1367/1670 train_time:134159ms step_avg:98.14ms +step:1368/1670 train_time:134259ms step_avg:98.14ms +step:1369/1670 train_time:134357ms step_avg:98.14ms +step:1370/1670 train_time:134457ms step_avg:98.14ms +step:1371/1670 train_time:134556ms step_avg:98.14ms +step:1372/1670 train_time:134655ms step_avg:98.14ms +step:1373/1670 train_time:134754ms step_avg:98.15ms +step:1374/1670 train_time:134852ms step_avg:98.15ms +step:1375/1670 train_time:134950ms step_avg:98.15ms 
+step:1375/1670 val_loss:3.3439 train_time:135046ms step_avg:98.22ms +step:1376/1670 train_time:135068ms step_avg:98.16ms +step:1377/1670 train_time:135154ms step_avg:98.15ms +step:1378/1670 train_time:135252ms step_avg:98.15ms +step:1379/1670 train_time:135350ms step_avg:98.15ms +step:1380/1670 train_time:135448ms step_avg:98.15ms +step:1381/1670 train_time:135545ms step_avg:98.15ms +step:1382/1670 train_time:135641ms step_avg:98.15ms +step:1383/1670 train_time:135740ms step_avg:98.15ms +step:1384/1670 train_time:135837ms step_avg:98.15ms +step:1385/1670 train_time:135934ms step_avg:98.15ms +step:1386/1670 train_time:136033ms step_avg:98.15ms +step:1387/1670 train_time:136134ms step_avg:98.15ms +step:1388/1670 train_time:136233ms step_avg:98.15ms +step:1389/1670 train_time:136331ms step_avg:98.15ms +step:1390/1670 train_time:136429ms step_avg:98.15ms +step:1391/1670 train_time:136527ms step_avg:98.15ms +step:1392/1670 train_time:136623ms step_avg:98.15ms +step:1393/1670 train_time:136720ms step_avg:98.15ms +step:1394/1670 train_time:136818ms step_avg:98.15ms +step:1395/1670 train_time:136916ms step_avg:98.15ms +step:1396/1670 train_time:137014ms step_avg:98.15ms +step:1397/1670 train_time:137113ms step_avg:98.15ms +step:1398/1670 train_time:137213ms step_avg:98.15ms +step:1399/1670 train_time:137312ms step_avg:98.15ms +step:1400/1670 train_time:137410ms step_avg:98.15ms +step:1401/1670 train_time:137507ms step_avg:98.15ms +step:1402/1670 train_time:137604ms step_avg:98.15ms +step:1403/1670 train_time:137701ms step_avg:98.15ms +step:1404/1670 train_time:137798ms step_avg:98.15ms +step:1405/1670 train_time:137895ms step_avg:98.15ms +step:1406/1670 train_time:137993ms step_avg:98.15ms +step:1407/1670 train_time:138092ms step_avg:98.15ms +step:1408/1670 train_time:138190ms step_avg:98.15ms +step:1409/1670 train_time:138288ms step_avg:98.15ms +step:1410/1670 train_time:138386ms step_avg:98.15ms +step:1411/1670 train_time:138485ms step_avg:98.15ms +step:1412/1670 train_time:138582ms step_avg:98.15ms +step:1413/1670 train_time:138680ms step_avg:98.15ms +step:1414/1670 train_time:138777ms step_avg:98.15ms +step:1415/1670 train_time:138875ms step_avg:98.15ms +step:1416/1670 train_time:138973ms step_avg:98.14ms +step:1417/1670 train_time:139071ms step_avg:98.14ms +step:1418/1670 train_time:139169ms step_avg:98.14ms +step:1419/1670 train_time:139267ms step_avg:98.14ms +step:1420/1670 train_time:139365ms step_avg:98.14ms +step:1421/1670 train_time:139464ms step_avg:98.14ms +step:1422/1670 train_time:139562ms step_avg:98.15ms +step:1423/1670 train_time:139660ms step_avg:98.15ms +step:1424/1670 train_time:139758ms step_avg:98.14ms +step:1425/1670 train_time:139856ms step_avg:98.14ms +step:1426/1670 train_time:139954ms step_avg:98.14ms +step:1427/1670 train_time:140052ms step_avg:98.14ms +step:1428/1670 train_time:140151ms step_avg:98.14ms +step:1429/1670 train_time:140250ms step_avg:98.15ms +step:1430/1670 train_time:140348ms step_avg:98.15ms +step:1431/1670 train_time:140446ms step_avg:98.15ms +step:1432/1670 train_time:140543ms step_avg:98.14ms +step:1433/1670 train_time:140641ms step_avg:98.14ms +step:1434/1670 train_time:140740ms step_avg:98.14ms +step:1435/1670 train_time:140839ms step_avg:98.15ms +step:1436/1670 train_time:140937ms step_avg:98.15ms +step:1437/1670 train_time:141036ms step_avg:98.15ms +step:1438/1670 train_time:141136ms step_avg:98.15ms +step:1439/1670 train_time:141235ms step_avg:98.15ms +step:1440/1670 train_time:141335ms step_avg:98.15ms +step:1441/1670 train_time:141434ms 
step_avg:98.15ms +step:1442/1670 train_time:141532ms step_avg:98.15ms +step:1443/1670 train_time:141630ms step_avg:98.15ms +step:1444/1670 train_time:141728ms step_avg:98.15ms +step:1445/1670 train_time:141828ms step_avg:98.15ms +step:1446/1670 train_time:141926ms step_avg:98.15ms +step:1447/1670 train_time:142026ms step_avg:98.15ms +step:1448/1670 train_time:142123ms step_avg:98.15ms +step:1449/1670 train_time:142223ms step_avg:98.15ms +step:1450/1670 train_time:142323ms step_avg:98.15ms +step:1451/1670 train_time:142423ms step_avg:98.15ms +step:1452/1670 train_time:142522ms step_avg:98.16ms +step:1453/1670 train_time:142621ms step_avg:98.16ms +step:1454/1670 train_time:142720ms step_avg:98.16ms +step:1455/1670 train_time:142820ms step_avg:98.16ms +step:1456/1670 train_time:142919ms step_avg:98.16ms +step:1457/1670 train_time:143017ms step_avg:98.16ms +step:1458/1670 train_time:143115ms step_avg:98.16ms +step:1459/1670 train_time:143213ms step_avg:98.16ms +step:1460/1670 train_time:143312ms step_avg:98.16ms +step:1461/1670 train_time:143409ms step_avg:98.16ms +step:1462/1670 train_time:143507ms step_avg:98.16ms +step:1463/1670 train_time:143604ms step_avg:98.16ms +step:1464/1670 train_time:143703ms step_avg:98.16ms +step:1465/1670 train_time:143802ms step_avg:98.16ms +step:1466/1670 train_time:143901ms step_avg:98.16ms +step:1467/1670 train_time:143999ms step_avg:98.16ms +step:1468/1670 train_time:144098ms step_avg:98.16ms +step:1469/1670 train_time:144196ms step_avg:98.16ms +step:1470/1670 train_time:144294ms step_avg:98.16ms +step:1471/1670 train_time:144392ms step_avg:98.16ms +step:1472/1670 train_time:144491ms step_avg:98.16ms +step:1473/1670 train_time:144588ms step_avg:98.16ms +step:1474/1670 train_time:144685ms step_avg:98.16ms +step:1475/1670 train_time:144783ms step_avg:98.16ms +step:1476/1670 train_time:144882ms step_avg:98.16ms +step:1477/1670 train_time:144980ms step_avg:98.16ms +step:1478/1670 train_time:145079ms step_avg:98.16ms +step:1479/1670 train_time:145177ms step_avg:98.16ms +step:1480/1670 train_time:145277ms step_avg:98.16ms +step:1481/1670 train_time:145377ms step_avg:98.16ms +step:1482/1670 train_time:145476ms step_avg:98.16ms +step:1483/1670 train_time:145574ms step_avg:98.16ms +step:1484/1670 train_time:145673ms step_avg:98.16ms +step:1485/1670 train_time:146044ms step_avg:98.35ms +step:1486/1670 train_time:146118ms step_avg:98.33ms +step:1487/1670 train_time:146214ms step_avg:98.33ms +step:1488/1670 train_time:146310ms step_avg:98.33ms +step:1489/1670 train_time:146406ms step_avg:98.33ms +step:1490/1670 train_time:146503ms step_avg:98.32ms +step:1491/1670 train_time:146600ms step_avg:98.32ms +step:1492/1670 train_time:146698ms step_avg:98.32ms +step:1493/1670 train_time:146795ms step_avg:98.32ms +step:1494/1670 train_time:146892ms step_avg:98.32ms +step:1495/1670 train_time:146996ms step_avg:98.32ms +step:1496/1670 train_time:147098ms step_avg:98.33ms +step:1497/1670 train_time:147197ms step_avg:98.33ms +step:1498/1670 train_time:147297ms step_avg:98.33ms +step:1499/1670 train_time:147397ms step_avg:98.33ms +step:1500/1670 train_time:147494ms step_avg:98.33ms +step:1500/1670 val_loss:3.3119 train_time:147591ms step_avg:98.39ms +step:1501/1670 train_time:147612ms step_avg:98.34ms +step:1502/1670 train_time:147699ms step_avg:98.33ms +step:1503/1670 train_time:147800ms step_avg:98.34ms +step:1504/1670 train_time:147898ms step_avg:98.34ms +step:1505/1670 train_time:147996ms step_avg:98.34ms +step:1506/1670 train_time:148092ms step_avg:98.33ms +step:1507/1670 
train_time:148189ms step_avg:98.33ms +step:1508/1670 train_time:148286ms step_avg:98.33ms +step:1509/1670 train_time:148384ms step_avg:98.33ms +step:1510/1670 train_time:148481ms step_avg:98.33ms +step:1511/1670 train_time:148579ms step_avg:98.33ms +step:1512/1670 train_time:148679ms step_avg:98.33ms +step:1513/1670 train_time:148779ms step_avg:98.33ms +step:1514/1670 train_time:148877ms step_avg:98.33ms +step:1515/1670 train_time:148975ms step_avg:98.33ms +step:1516/1670 train_time:149072ms step_avg:98.33ms +step:1517/1670 train_time:149170ms step_avg:98.33ms +step:1518/1670 train_time:149267ms step_avg:98.33ms +step:1519/1670 train_time:149365ms step_avg:98.33ms +step:1520/1670 train_time:149463ms step_avg:98.33ms +step:1521/1670 train_time:149561ms step_avg:98.33ms +step:1522/1670 train_time:149660ms step_avg:98.33ms +step:1523/1670 train_time:149759ms step_avg:98.33ms +step:1524/1670 train_time:149859ms step_avg:98.33ms +step:1525/1670 train_time:149957ms step_avg:98.33ms +step:1526/1670 train_time:150055ms step_avg:98.33ms +step:1527/1670 train_time:150153ms step_avg:98.33ms +step:1528/1670 train_time:150250ms step_avg:98.33ms +step:1529/1670 train_time:150348ms step_avg:98.33ms +step:1530/1670 train_time:150445ms step_avg:98.33ms +step:1531/1670 train_time:150544ms step_avg:98.33ms +step:1532/1670 train_time:150643ms step_avg:98.33ms +step:1533/1670 train_time:150742ms step_avg:98.33ms +step:1534/1670 train_time:150840ms step_avg:98.33ms +step:1535/1670 train_time:150939ms step_avg:98.33ms +step:1536/1670 train_time:151037ms step_avg:98.33ms +step:1537/1670 train_time:151134ms step_avg:98.33ms +step:1538/1670 train_time:151232ms step_avg:98.33ms +step:1539/1670 train_time:151331ms step_avg:98.33ms +step:1540/1670 train_time:151428ms step_avg:98.33ms +step:1541/1670 train_time:151526ms step_avg:98.33ms +step:1542/1670 train_time:151625ms step_avg:98.33ms +step:1543/1670 train_time:151725ms step_avg:98.33ms +step:1544/1670 train_time:151825ms step_avg:98.33ms +step:1545/1670 train_time:151925ms step_avg:98.33ms +step:1546/1670 train_time:152024ms step_avg:98.33ms +step:1547/1670 train_time:152123ms step_avg:98.33ms +step:1548/1670 train_time:152222ms step_avg:98.33ms +step:1549/1670 train_time:152321ms step_avg:98.33ms +step:1550/1670 train_time:152417ms step_avg:98.33ms +step:1551/1670 train_time:152515ms step_avg:98.33ms +step:1552/1670 train_time:152614ms step_avg:98.33ms +step:1553/1670 train_time:152712ms step_avg:98.33ms +step:1554/1670 train_time:152813ms step_avg:98.34ms +step:1555/1670 train_time:152913ms step_avg:98.34ms +step:1556/1670 train_time:153013ms step_avg:98.34ms +step:1557/1670 train_time:153113ms step_avg:98.34ms +step:1558/1670 train_time:153212ms step_avg:98.34ms +step:1559/1670 train_time:153313ms step_avg:98.34ms +step:1560/1670 train_time:153412ms step_avg:98.34ms +step:1561/1670 train_time:153510ms step_avg:98.34ms +step:1562/1670 train_time:153608ms step_avg:98.34ms +step:1563/1670 train_time:153707ms step_avg:98.34ms +step:1564/1670 train_time:153807ms step_avg:98.34ms +step:1565/1670 train_time:153906ms step_avg:98.34ms +step:1566/1670 train_time:154006ms step_avg:98.34ms +step:1567/1670 train_time:154106ms step_avg:98.34ms +step:1568/1670 train_time:154203ms step_avg:98.34ms +step:1569/1670 train_time:154301ms step_avg:98.34ms +step:1570/1670 train_time:154399ms step_avg:98.34ms +step:1571/1670 train_time:154496ms step_avg:98.34ms +step:1572/1670 train_time:154594ms step_avg:98.34ms +step:1573/1670 train_time:154692ms step_avg:98.34ms +step:1574/1670 
train_time:154791ms step_avg:98.34ms +step:1575/1670 train_time:154890ms step_avg:98.34ms +step:1576/1670 train_time:154989ms step_avg:98.34ms +step:1577/1670 train_time:155088ms step_avg:98.34ms +step:1578/1670 train_time:155187ms step_avg:98.34ms +step:1579/1670 train_time:155286ms step_avg:98.34ms +step:1580/1670 train_time:155385ms step_avg:98.34ms +step:1581/1670 train_time:155483ms step_avg:98.34ms +step:1582/1670 train_time:155581ms step_avg:98.34ms +step:1583/1670 train_time:155678ms step_avg:98.34ms +step:1584/1670 train_time:155776ms step_avg:98.34ms +step:1585/1670 train_time:155874ms step_avg:98.34ms +step:1586/1670 train_time:155973ms step_avg:98.34ms +step:1587/1670 train_time:156072ms step_avg:98.34ms +step:1588/1670 train_time:156173ms step_avg:98.35ms +step:1589/1670 train_time:156273ms step_avg:98.35ms +step:1590/1670 train_time:156372ms step_avg:98.35ms +step:1591/1670 train_time:156471ms step_avg:98.35ms +step:1592/1670 train_time:156570ms step_avg:98.35ms +step:1593/1670 train_time:156669ms step_avg:98.35ms +step:1594/1670 train_time:156767ms step_avg:98.35ms +step:1595/1670 train_time:156868ms step_avg:98.35ms +step:1596/1670 train_time:156967ms step_avg:98.35ms +step:1597/1670 train_time:157067ms step_avg:98.35ms +step:1598/1670 train_time:157166ms step_avg:98.35ms +step:1599/1670 train_time:157266ms step_avg:98.35ms +step:1600/1670 train_time:157364ms step_avg:98.35ms +step:1601/1670 train_time:157463ms step_avg:98.35ms +step:1602/1670 train_time:157562ms step_avg:98.35ms +step:1603/1670 train_time:157659ms step_avg:98.35ms +step:1604/1670 train_time:157757ms step_avg:98.35ms +step:1605/1670 train_time:157854ms step_avg:98.35ms +step:1606/1670 train_time:157952ms step_avg:98.35ms +step:1607/1670 train_time:158051ms step_avg:98.35ms +step:1608/1670 train_time:158150ms step_avg:98.35ms +step:1609/1670 train_time:158250ms step_avg:98.35ms +step:1610/1670 train_time:158349ms step_avg:98.35ms +step:1611/1670 train_time:158451ms step_avg:98.36ms +step:1612/1670 train_time:158550ms step_avg:98.36ms +step:1613/1670 train_time:158648ms step_avg:98.36ms +step:1614/1670 train_time:158747ms step_avg:98.36ms +step:1615/1670 train_time:158846ms step_avg:98.36ms +step:1616/1670 train_time:158945ms step_avg:98.36ms +step:1617/1670 train_time:159045ms step_avg:98.36ms +step:1618/1670 train_time:159143ms step_avg:98.36ms +step:1619/1670 train_time:159241ms step_avg:98.36ms +step:1620/1670 train_time:159338ms step_avg:98.36ms +step:1621/1670 train_time:159436ms step_avg:98.36ms +step:1622/1670 train_time:159535ms step_avg:98.36ms +step:1623/1670 train_time:159634ms step_avg:98.36ms +step:1624/1670 train_time:159733ms step_avg:98.36ms +step:1625/1670 train_time:159832ms step_avg:98.36ms +step:1625/1670 val_loss:3.2852 train_time:159928ms step_avg:98.42ms +step:1626/1670 train_time:159952ms step_avg:98.37ms +step:1627/1670 train_time:160035ms step_avg:98.36ms +step:1628/1670 train_time:160135ms step_avg:98.36ms +step:1629/1670 train_time:160233ms step_avg:98.36ms +step:1630/1670 train_time:160330ms step_avg:98.36ms +step:1631/1670 train_time:160428ms step_avg:98.36ms +step:1632/1670 train_time:160525ms step_avg:98.36ms +step:1633/1670 train_time:160622ms step_avg:98.36ms +step:1634/1670 train_time:160719ms step_avg:98.36ms +step:1635/1670 train_time:160817ms step_avg:98.36ms +step:1636/1670 train_time:160917ms step_avg:98.36ms +step:1637/1670 train_time:161017ms step_avg:98.36ms +step:1638/1670 train_time:161116ms step_avg:98.36ms +step:1639/1670 train_time:161215ms step_avg:98.36ms 
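A note for readers of these records: the step_avg field is simply cumulative train_time divided by the step index, with the clock stopped while validation runs (which is why val_loss lines show a slightly higher step_avg than the surrounding training lines). A minimal sketch, assuming log lines of the exact form shown here (the path and tolerance are illustrative, not part of the run), that re-derives step_avg from a record file:

import re

STEP_RE = re.compile(r"step:(\d+)/\d+(?: val_loss:[\d.]+)? train_time:(\d+)ms step_avg:([\d.]+)ms")

def check_step_avg(path: str, rel_tol: float = 0.01):
    # Re-derive step_avg = train_time / step for every log line and flag drift.
    with open(path) as f:
        for line in f:
            m = STEP_RE.search(line)
            if m is None:
                continue  # not a step log line (code, tables, summaries)
            step, train_ms, step_avg = int(m[1]), int(m[2]), float(m[3])
            if step == 0:
                continue  # step 0 is logged before any timed training step
            derived = train_ms / step
            if abs(derived - step_avg) > rel_tol * step_avg:
                print(f"mismatch at step {step}: derived {derived:.2f}ms vs logged {step_avg}ms")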
+step:1640/1670 train_time:161314ms step_avg:98.36ms +step:1641/1670 train_time:161412ms step_avg:98.36ms +step:1642/1670 train_time:161512ms step_avg:98.36ms +step:1643/1670 train_time:161610ms step_avg:98.36ms +step:1644/1670 train_time:161708ms step_avg:98.36ms +step:1645/1670 train_time:161805ms step_avg:98.36ms +step:1646/1670 train_time:161904ms step_avg:98.36ms +step:1647/1670 train_time:162003ms step_avg:98.36ms +step:1648/1670 train_time:162103ms step_avg:98.36ms +step:1649/1670 train_time:162202ms step_avg:98.36ms +step:1650/1670 train_time:162300ms step_avg:98.36ms +step:1651/1670 train_time:162398ms step_avg:98.36ms +step:1652/1670 train_time:162497ms step_avg:98.36ms +step:1653/1670 train_time:162594ms step_avg:98.36ms +step:1654/1670 train_time:162692ms step_avg:98.36ms +step:1655/1670 train_time:162791ms step_avg:98.36ms +step:1656/1670 train_time:162890ms step_avg:98.36ms +step:1657/1670 train_time:162990ms step_avg:98.36ms +step:1658/1670 train_time:163090ms step_avg:98.37ms +step:1659/1670 train_time:163190ms step_avg:98.37ms +step:1660/1670 train_time:163290ms step_avg:98.37ms +step:1661/1670 train_time:163391ms step_avg:98.37ms +step:1662/1670 train_time:163489ms step_avg:98.37ms +step:1663/1670 train_time:163587ms step_avg:98.37ms +step:1664/1670 train_time:163684ms step_avg:98.37ms +step:1665/1670 train_time:163781ms step_avg:98.37ms +step:1666/1670 train_time:163878ms step_avg:98.37ms +step:1667/1670 train_time:163977ms step_avg:98.37ms +step:1668/1670 train_time:164075ms step_avg:98.37ms +step:1669/1670 train_time:164174ms step_avg:98.37ms +step:1670/1670 train_time:164273ms step_avg:98.37ms +step:1670/1670 val_loss:3.2771 train_time:164370ms step_avg:98.43ms +peak memory allocated: 34000 MiB reserved: 49676 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt b/records/050925_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt new file mode 100644 index 000000000..0aa7a41df --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + 
scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, 
a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we 
use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral 
norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
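For reference, the five quintic Newton-Schulz iterations that newton_schulz_triton and the ns_line_* kernels above implement are equivalent to the following plain-PyTorch sketch. This is a reading aid using the same coefficients and Frobenius pre-normalization, not the optimized code path the run used:

import torch

@torch.no_grad()
def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Same quintic iteration as the Triton path:
    #   A = X @ X^T           (ns_line_1)
    #   B = b*A + c*(A @ A)   (ns_line_2 with alpha=c, beta=b; A is symmetric)
    #   X = a*X + B @ X       (ns_line_3 via addmm/baddbmm)
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT  # iterate on the wide orientation
    # The Frobenius norm bounds the spectral norm, so this guarantees ||X||_2 <= 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.mT if transposed else X

The coefficients are tuned to inflate small singular values quickly rather than to converge them exactly to 1, so the result is only approximately orthogonal; that approximation is sufficient for producing the update direction.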
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:30:52 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 130W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 44C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 91422 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 91423 C /usr/bin/python3 610MiB | +| 0 N/A N/A 91424 C /usr/bin/python3 610MiB | +| 0 N/A N/A 91425 C /usr/bin/python3 610MiB | +| 0 N/A N/A 91426 C /usr/bin/python3 610MiB | +| 0 N/A N/A 91427 C /usr/bin/python3 610MiB | +| 0 N/A N/A 91428 C /usr/bin/python3 610MiB | +| 0 N/A N/A 91429 C /usr/bin/python3 610MiB | +| 1 N/A N/A 91423 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 91424 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 91425 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 91426 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 91427 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 91428 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 91429 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:382ms step_avg:381.61ms +step:2/1670 train_time:403ms step_avg:201.51ms +step:3/1670 train_time:475ms step_avg:158.40ms +step:4/1670 train_time:569ms step_avg:142.15ms +step:5/1670 train_time:663ms step_avg:132.68ms +step:6/1670 train_time:758ms step_avg:126.35ms +step:7/1670 train_time:853ms step_avg:121.83ms +step:8/1670 
train_time:948ms step_avg:118.50ms +step:9/1670 train_time:1043ms step_avg:115.90ms +step:10/1670 train_time:1138ms step_avg:113.82ms +step:11/1670 train_time:1233ms step_avg:112.09ms +step:12/1670 train_time:1332ms step_avg:111.01ms +step:13/1670 train_time:1431ms step_avg:110.08ms +step:14/1670 train_time:1528ms step_avg:109.16ms +step:15/1670 train_time:1625ms step_avg:108.33ms +step:16/1670 train_time:1721ms step_avg:107.58ms +step:17/1670 train_time:1817ms step_avg:106.86ms +step:18/1670 train_time:1912ms step_avg:106.23ms +step:19/1670 train_time:2007ms step_avg:105.63ms +step:20/1670 train_time:2102ms step_avg:105.11ms +step:21/1670 train_time:2198ms step_avg:104.66ms +step:22/1670 train_time:2294ms step_avg:104.27ms +step:23/1670 train_time:2390ms step_avg:103.90ms +step:24/1670 train_time:2487ms step_avg:103.61ms +step:25/1670 train_time:2584ms step_avg:103.35ms +step:26/1670 train_time:2681ms step_avg:103.10ms +step:27/1670 train_time:2777ms step_avg:102.84ms +step:28/1670 train_time:2872ms step_avg:102.56ms +step:29/1670 train_time:2967ms step_avg:102.31ms +step:30/1670 train_time:3063ms step_avg:102.11ms +step:31/1670 train_time:3159ms step_avg:101.89ms +step:32/1670 train_time:3255ms step_avg:101.71ms +step:33/1670 train_time:3350ms step_avg:101.53ms +step:34/1670 train_time:3447ms step_avg:101.37ms +step:35/1670 train_time:3543ms step_avg:101.24ms +step:36/1670 train_time:3640ms step_avg:101.12ms +step:37/1670 train_time:3736ms step_avg:100.97ms +step:38/1670 train_time:3832ms step_avg:100.84ms +step:39/1670 train_time:3927ms step_avg:100.70ms +step:40/1670 train_time:4024ms step_avg:100.60ms +step:41/1670 train_time:4121ms step_avg:100.52ms +step:42/1670 train_time:4217ms step_avg:100.39ms +step:43/1670 train_time:4311ms step_avg:100.26ms +step:44/1670 train_time:4408ms step_avg:100.17ms +step:45/1670 train_time:4505ms step_avg:100.10ms +step:46/1670 train_time:4601ms step_avg:100.02ms +step:47/1670 train_time:4697ms step_avg:99.95ms +step:48/1670 train_time:4793ms step_avg:99.86ms +step:49/1670 train_time:4889ms step_avg:99.77ms +step:50/1670 train_time:4986ms step_avg:99.71ms +step:51/1670 train_time:5081ms step_avg:99.63ms +step:52/1670 train_time:5177ms step_avg:99.55ms +step:53/1670 train_time:5273ms step_avg:99.48ms +step:54/1670 train_time:5368ms step_avg:99.41ms +step:55/1670 train_time:5464ms step_avg:99.35ms +step:56/1670 train_time:5561ms step_avg:99.31ms +step:57/1670 train_time:5657ms step_avg:99.24ms +step:58/1670 train_time:5753ms step_avg:99.18ms +step:59/1670 train_time:5849ms step_avg:99.13ms +step:60/1670 train_time:5945ms step_avg:99.08ms +step:61/1670 train_time:6041ms step_avg:99.03ms +step:62/1670 train_time:6137ms step_avg:98.98ms +step:63/1670 train_time:6233ms step_avg:98.93ms +step:64/1670 train_time:6328ms step_avg:98.88ms +step:65/1670 train_time:6425ms step_avg:98.85ms +step:66/1670 train_time:6521ms step_avg:98.80ms +step:67/1670 train_time:6616ms step_avg:98.75ms +step:68/1670 train_time:6712ms step_avg:98.70ms +step:69/1670 train_time:6808ms step_avg:98.66ms +step:70/1670 train_time:6905ms step_avg:98.64ms +step:71/1670 train_time:7001ms step_avg:98.61ms +step:72/1670 train_time:7097ms step_avg:98.56ms +step:73/1670 train_time:7192ms step_avg:98.53ms +step:74/1670 train_time:7288ms step_avg:98.48ms +step:75/1670 train_time:7384ms step_avg:98.46ms +step:76/1670 train_time:7481ms step_avg:98.43ms +step:77/1670 train_time:7577ms step_avg:98.40ms +step:78/1670 train_time:7672ms step_avg:98.36ms +step:79/1670 train_time:7768ms step_avg:98.32ms 
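For orientation while reading these step logs: with the hyperparameters printed in the code above (num_iterations=1670, cooldown_frac=0.45, ws_schedule=(3, 7, 11)), the get_lr and get_ws schedules evaluate as in this small sketch (the sample steps are illustrative):

NUM_ITERS, COOLDOWN_FRAC = 1670, 0.45
WS_SCHEDULE = (3, 7, 11)

def lr_mult(step: int) -> float:
    # constant 1.0, then linear decay toward 0.1 over the final 45% of training
    x = step / NUM_ITERS
    if x < 1 - COOLDOWN_FRAC:
        return 1.0
    w = (1 - x) / COOLDOWN_FRAC
    return w * 1.0 + (1 - w) * 0.1

def window_size(step: int) -> int:
    # three equal phases: the attention window widens 3 -> 7 -> 11 blocks
    x = step / (1 + NUM_ITERS)
    return WS_SCHEDULE[int(len(WS_SCHEDULE) * x)]

for s in (0, 600, 918, 1200, 1669):
    print(s, round(lr_mult(s), 3), window_size(s))
# prints: 0 1.0 3 / 600 1.0 7 / 918 1.0 7 / 1200 0.663 11 / 1669 0.101 11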
+step:80/1670 train_time:7864ms step_avg:98.30ms +step:81/1670 train_time:7961ms step_avg:98.29ms +step:82/1670 train_time:8058ms step_avg:98.26ms +step:83/1670 train_time:8154ms step_avg:98.24ms +step:84/1670 train_time:8249ms step_avg:98.20ms +step:85/1670 train_time:8345ms step_avg:98.18ms +step:86/1670 train_time:8441ms step_avg:98.15ms +step:87/1670 train_time:8537ms step_avg:98.13ms +step:88/1670 train_time:8633ms step_avg:98.10ms +step:89/1670 train_time:8730ms step_avg:98.09ms +step:90/1670 train_time:8825ms step_avg:98.06ms +step:91/1670 train_time:8922ms step_avg:98.05ms +step:92/1670 train_time:9018ms step_avg:98.02ms +step:93/1670 train_time:9114ms step_avg:98.00ms +step:94/1670 train_time:9211ms step_avg:97.98ms +step:95/1670 train_time:9307ms step_avg:97.97ms +step:96/1670 train_time:9404ms step_avg:97.95ms +step:97/1670 train_time:9500ms step_avg:97.93ms +step:98/1670 train_time:9596ms step_avg:97.92ms +step:99/1670 train_time:9691ms step_avg:97.89ms +step:100/1670 train_time:9787ms step_avg:97.87ms +step:101/1670 train_time:9883ms step_avg:97.85ms +step:102/1670 train_time:9979ms step_avg:97.83ms +step:103/1670 train_time:10075ms step_avg:97.81ms +step:104/1670 train_time:10170ms step_avg:97.79ms +step:105/1670 train_time:10267ms step_avg:97.78ms +step:106/1670 train_time:10363ms step_avg:97.77ms +step:107/1670 train_time:10460ms step_avg:97.75ms +step:108/1670 train_time:10555ms step_avg:97.73ms +step:109/1670 train_time:10650ms step_avg:97.71ms +step:110/1670 train_time:10745ms step_avg:97.69ms +step:111/1670 train_time:10841ms step_avg:97.66ms +step:112/1670 train_time:10937ms step_avg:97.65ms +step:113/1670 train_time:11032ms step_avg:97.63ms +step:114/1670 train_time:11129ms step_avg:97.62ms +step:115/1670 train_time:11225ms step_avg:97.61ms +step:116/1670 train_time:11321ms step_avg:97.59ms +step:117/1670 train_time:11417ms step_avg:97.58ms +step:118/1670 train_time:11513ms step_avg:97.57ms +step:119/1670 train_time:11608ms step_avg:97.55ms +step:120/1670 train_time:11704ms step_avg:97.53ms +step:121/1670 train_time:11800ms step_avg:97.52ms +step:122/1670 train_time:11895ms step_avg:97.50ms +step:123/1670 train_time:11991ms step_avg:97.49ms +step:124/1670 train_time:12087ms step_avg:97.48ms +step:125/1670 train_time:12183ms step_avg:97.47ms +step:125/1670 val_loss:4.3009 train_time:12278ms step_avg:98.22ms +step:126/1670 train_time:12303ms step_avg:97.64ms +step:127/1670 train_time:12385ms step_avg:97.52ms +step:128/1670 train_time:12488ms step_avg:97.56ms +step:129/1670 train_time:12585ms step_avg:97.55ms +step:130/1670 train_time:12680ms step_avg:97.53ms +step:131/1670 train_time:12775ms step_avg:97.52ms +step:132/1670 train_time:12870ms step_avg:97.50ms +step:133/1670 train_time:12964ms step_avg:97.48ms +step:134/1670 train_time:13059ms step_avg:97.46ms +step:135/1670 train_time:13154ms step_avg:97.44ms +step:136/1670 train_time:13249ms step_avg:97.42ms +step:137/1670 train_time:13347ms step_avg:97.42ms +step:138/1670 train_time:13444ms step_avg:97.42ms +step:139/1670 train_time:13541ms step_avg:97.42ms +step:140/1670 train_time:13637ms step_avg:97.41ms +step:141/1670 train_time:13733ms step_avg:97.40ms +step:142/1670 train_time:13829ms step_avg:97.38ms +step:143/1670 train_time:13924ms step_avg:97.37ms +step:144/1670 train_time:14019ms step_avg:97.35ms +step:145/1670 train_time:14114ms step_avg:97.34ms +step:146/1670 train_time:14210ms step_avg:97.33ms +step:147/1670 train_time:14306ms step_avg:97.32ms +step:148/1670 train_time:14404ms step_avg:97.32ms 
+step:149/1670 train_time:14500ms step_avg:97.32ms +step:150/1670 train_time:14596ms step_avg:97.31ms +step:151/1670 train_time:14692ms step_avg:97.30ms +step:152/1670 train_time:14788ms step_avg:97.29ms +step:153/1670 train_time:14884ms step_avg:97.28ms +step:154/1670 train_time:14980ms step_avg:97.27ms +step:155/1670 train_time:15075ms step_avg:97.26ms +step:156/1670 train_time:15170ms step_avg:97.25ms +step:157/1670 train_time:15266ms step_avg:97.24ms +step:158/1670 train_time:15362ms step_avg:97.23ms +step:159/1670 train_time:15458ms step_avg:97.22ms +step:160/1670 train_time:15555ms step_avg:97.22ms +step:161/1670 train_time:15652ms step_avg:97.21ms +step:162/1670 train_time:15748ms step_avg:97.21ms +step:163/1670 train_time:15844ms step_avg:97.20ms +step:164/1670 train_time:15940ms step_avg:97.20ms +step:165/1670 train_time:16035ms step_avg:97.18ms +step:166/1670 train_time:16130ms step_avg:97.17ms +step:167/1670 train_time:16226ms step_avg:97.16ms +step:168/1670 train_time:16321ms step_avg:97.15ms +step:169/1670 train_time:16417ms step_avg:97.14ms +step:170/1670 train_time:16513ms step_avg:97.14ms +step:171/1670 train_time:16609ms step_avg:97.13ms +step:172/1670 train_time:16705ms step_avg:97.12ms +step:173/1670 train_time:16801ms step_avg:97.11ms +step:174/1670 train_time:16897ms step_avg:97.11ms +step:175/1670 train_time:16992ms step_avg:97.10ms +step:176/1670 train_time:17088ms step_avg:97.09ms +step:177/1670 train_time:17184ms step_avg:97.09ms +step:178/1670 train_time:17279ms step_avg:97.08ms +step:179/1670 train_time:17375ms step_avg:97.07ms +step:180/1670 train_time:17471ms step_avg:97.06ms +step:181/1670 train_time:17568ms step_avg:97.06ms +step:182/1670 train_time:17663ms step_avg:97.05ms +step:183/1670 train_time:17758ms step_avg:97.04ms +step:184/1670 train_time:17854ms step_avg:97.03ms +step:185/1670 train_time:17950ms step_avg:97.03ms +step:186/1670 train_time:18045ms step_avg:97.02ms +step:187/1670 train_time:18141ms step_avg:97.01ms +step:188/1670 train_time:18236ms step_avg:97.00ms +step:189/1670 train_time:18332ms step_avg:97.00ms +step:190/1670 train_time:18428ms step_avg:96.99ms +step:191/1670 train_time:18524ms step_avg:96.98ms +step:192/1670 train_time:18619ms step_avg:96.98ms +step:193/1670 train_time:18715ms step_avg:96.97ms +step:194/1670 train_time:18811ms step_avg:96.96ms +step:195/1670 train_time:18906ms step_avg:96.96ms +step:196/1670 train_time:19002ms step_avg:96.95ms +step:197/1670 train_time:19097ms step_avg:96.94ms +step:198/1670 train_time:19192ms step_avg:96.93ms +step:199/1670 train_time:19288ms step_avg:96.92ms +step:200/1670 train_time:19384ms step_avg:96.92ms +step:201/1670 train_time:19479ms step_avg:96.91ms +step:202/1670 train_time:19575ms step_avg:96.91ms +step:203/1670 train_time:19671ms step_avg:96.90ms +step:204/1670 train_time:19767ms step_avg:96.90ms +step:205/1670 train_time:19862ms step_avg:96.89ms +step:206/1670 train_time:19957ms step_avg:96.88ms +step:207/1670 train_time:20053ms step_avg:96.87ms +step:208/1670 train_time:20148ms step_avg:96.87ms +step:209/1670 train_time:20243ms step_avg:96.86ms +step:210/1670 train_time:20338ms step_avg:96.85ms +step:211/1670 train_time:20434ms step_avg:96.84ms +step:212/1670 train_time:20529ms step_avg:96.84ms +step:213/1670 train_time:20838ms step_avg:97.83ms +step:214/1670 train_time:20943ms step_avg:97.86ms +step:215/1670 train_time:21037ms step_avg:97.84ms +step:216/1670 train_time:21132ms step_avg:97.83ms +step:217/1670 train_time:21226ms step_avg:97.82ms +step:218/1670 train_time:21320ms 
step_avg:97.80ms +step:219/1670 train_time:21415ms step_avg:97.78ms +step:220/1670 train_time:21510ms step_avg:97.77ms +step:221/1670 train_time:21605ms step_avg:97.76ms +step:222/1670 train_time:21699ms step_avg:97.74ms +step:223/1670 train_time:21796ms step_avg:97.74ms +step:224/1670 train_time:21896ms step_avg:97.75ms +step:225/1670 train_time:21993ms step_avg:97.75ms +step:226/1670 train_time:22089ms step_avg:97.74ms +step:227/1670 train_time:22185ms step_avg:97.73ms +step:228/1670 train_time:22280ms step_avg:97.72ms +step:229/1670 train_time:22374ms step_avg:97.70ms +step:230/1670 train_time:22469ms step_avg:97.69ms +step:231/1670 train_time:22563ms step_avg:97.68ms +step:232/1670 train_time:22658ms step_avg:97.66ms +step:233/1670 train_time:22753ms step_avg:97.65ms +step:234/1670 train_time:22850ms step_avg:97.65ms +step:235/1670 train_time:22948ms step_avg:97.65ms +step:236/1670 train_time:23044ms step_avg:97.64ms +step:237/1670 train_time:23139ms step_avg:97.63ms +step:238/1670 train_time:23235ms step_avg:97.63ms +step:239/1670 train_time:23331ms step_avg:97.62ms +step:240/1670 train_time:23426ms step_avg:97.61ms +step:241/1670 train_time:23522ms step_avg:97.60ms +step:242/1670 train_time:23617ms step_avg:97.59ms +step:243/1670 train_time:23712ms step_avg:97.58ms +step:244/1670 train_time:23808ms step_avg:97.57ms +step:245/1670 train_time:23905ms step_avg:97.57ms +step:246/1670 train_time:24000ms step_avg:97.56ms +step:247/1670 train_time:24097ms step_avg:97.56ms +step:248/1670 train_time:24192ms step_avg:97.55ms +step:249/1670 train_time:24288ms step_avg:97.54ms +step:250/1670 train_time:24384ms step_avg:97.54ms +step:250/1670 val_loss:3.9677 train_time:24478ms step_avg:97.91ms +step:251/1670 train_time:24500ms step_avg:97.61ms +step:252/1670 train_time:24581ms step_avg:97.54ms +step:253/1670 train_time:24680ms step_avg:97.55ms +step:254/1670 train_time:24776ms step_avg:97.54ms +step:255/1670 train_time:24871ms step_avg:97.53ms +step:256/1670 train_time:24966ms step_avg:97.52ms +step:257/1670 train_time:25061ms step_avg:97.51ms +step:258/1670 train_time:25155ms step_avg:97.50ms +step:259/1670 train_time:25250ms step_avg:97.49ms +step:260/1670 train_time:25345ms step_avg:97.48ms +step:261/1670 train_time:25441ms step_avg:97.48ms +step:262/1670 train_time:25540ms step_avg:97.48ms +step:263/1670 train_time:25637ms step_avg:97.48ms +step:264/1670 train_time:25733ms step_avg:97.47ms +step:265/1670 train_time:25828ms step_avg:97.46ms +step:266/1670 train_time:25923ms step_avg:97.45ms +step:267/1670 train_time:26019ms step_avg:97.45ms +step:268/1670 train_time:26114ms step_avg:97.44ms +step:269/1670 train_time:26209ms step_avg:97.43ms +step:270/1670 train_time:26304ms step_avg:97.42ms +step:271/1670 train_time:26399ms step_avg:97.41ms +step:272/1670 train_time:26495ms step_avg:97.41ms +step:273/1670 train_time:26591ms step_avg:97.40ms +step:274/1670 train_time:26688ms step_avg:97.40ms +step:275/1670 train_time:26784ms step_avg:97.40ms +step:276/1670 train_time:26881ms step_avg:97.39ms +step:277/1670 train_time:26977ms step_avg:97.39ms +step:278/1670 train_time:27072ms step_avg:97.38ms +step:279/1670 train_time:27168ms step_avg:97.38ms +step:280/1670 train_time:27264ms step_avg:97.37ms +step:281/1670 train_time:27358ms step_avg:97.36ms +step:282/1670 train_time:27454ms step_avg:97.35ms +step:283/1670 train_time:27549ms step_avg:97.35ms +step:284/1670 train_time:27645ms step_avg:97.34ms +step:285/1670 train_time:27742ms step_avg:97.34ms +step:286/1670 train_time:27838ms step_avg:97.34ms 
+step:287/1670 train_time:27934ms step_avg:97.33ms +step:288/1670 train_time:28030ms step_avg:97.32ms +step:289/1670 train_time:28125ms step_avg:97.32ms +step:290/1670 train_time:28221ms step_avg:97.31ms +step:291/1670 train_time:28317ms step_avg:97.31ms +step:292/1670 train_time:28412ms step_avg:97.30ms +step:293/1670 train_time:28508ms step_avg:97.30ms +step:294/1670 train_time:28603ms step_avg:97.29ms +step:295/1670 train_time:28699ms step_avg:97.29ms +step:296/1670 train_time:28794ms step_avg:97.28ms +step:297/1670 train_time:28890ms step_avg:97.27ms +step:298/1670 train_time:28986ms step_avg:97.27ms +step:299/1670 train_time:29083ms step_avg:97.27ms +step:300/1670 train_time:29179ms step_avg:97.26ms +step:301/1670 train_time:29274ms step_avg:97.26ms +step:302/1670 train_time:29369ms step_avg:97.25ms +step:303/1670 train_time:29465ms step_avg:97.24ms +step:304/1670 train_time:29561ms step_avg:97.24ms +step:305/1670 train_time:29656ms step_avg:97.23ms +step:306/1670 train_time:29752ms step_avg:97.23ms +step:307/1670 train_time:29848ms step_avg:97.22ms +step:308/1670 train_time:29943ms step_avg:97.22ms +step:309/1670 train_time:30039ms step_avg:97.21ms +step:310/1670 train_time:30134ms step_avg:97.21ms +step:311/1670 train_time:30229ms step_avg:97.20ms +step:312/1670 train_time:30325ms step_avg:97.20ms +step:313/1670 train_time:30421ms step_avg:97.19ms +step:314/1670 train_time:30517ms step_avg:97.19ms +step:315/1670 train_time:30613ms step_avg:97.18ms +step:316/1670 train_time:30708ms step_avg:97.18ms +step:317/1670 train_time:30804ms step_avg:97.17ms +step:318/1670 train_time:30900ms step_avg:97.17ms +step:319/1670 train_time:30996ms step_avg:97.17ms +step:320/1670 train_time:31091ms step_avg:97.16ms +step:321/1670 train_time:31186ms step_avg:97.15ms +step:322/1670 train_time:31282ms step_avg:97.15ms +step:323/1670 train_time:31378ms step_avg:97.15ms +step:324/1670 train_time:31474ms step_avg:97.14ms +step:325/1670 train_time:31569ms step_avg:97.14ms +step:326/1670 train_time:31665ms step_avg:97.13ms +step:327/1670 train_time:31762ms step_avg:97.13ms +step:328/1670 train_time:31857ms step_avg:97.13ms +step:329/1670 train_time:31954ms step_avg:97.12ms +step:330/1670 train_time:32048ms step_avg:97.12ms +step:331/1670 train_time:32144ms step_avg:97.11ms +step:332/1670 train_time:32240ms step_avg:97.11ms +step:333/1670 train_time:32337ms step_avg:97.11ms +step:334/1670 train_time:32432ms step_avg:97.10ms +step:335/1670 train_time:32528ms step_avg:97.10ms +step:336/1670 train_time:32623ms step_avg:97.09ms +step:337/1670 train_time:32719ms step_avg:97.09ms +step:338/1670 train_time:32816ms step_avg:97.09ms +step:339/1670 train_time:32912ms step_avg:97.09ms +step:340/1670 train_time:33008ms step_avg:97.08ms +step:341/1670 train_time:33103ms step_avg:97.08ms +step:342/1670 train_time:33199ms step_avg:97.07ms +step:343/1670 train_time:33294ms step_avg:97.07ms +step:344/1670 train_time:33390ms step_avg:97.06ms +step:345/1670 train_time:33486ms step_avg:97.06ms +step:346/1670 train_time:33581ms step_avg:97.06ms +step:347/1670 train_time:33677ms step_avg:97.05ms +step:348/1670 train_time:33772ms step_avg:97.05ms +step:349/1670 train_time:33868ms step_avg:97.04ms +step:350/1670 train_time:33964ms step_avg:97.04ms +step:351/1670 train_time:34061ms step_avg:97.04ms +step:352/1670 train_time:34156ms step_avg:97.04ms +step:353/1670 train_time:34252ms step_avg:97.03ms +step:354/1670 train_time:34347ms step_avg:97.03ms +step:355/1670 train_time:34443ms step_avg:97.02ms +step:356/1670 train_time:34538ms 
step_avg:97.02ms +step:357/1670 train_time:34633ms step_avg:97.01ms +step:358/1670 train_time:34729ms step_avg:97.01ms +step:359/1670 train_time:34825ms step_avg:97.00ms +step:360/1670 train_time:34920ms step_avg:97.00ms +step:361/1670 train_time:35016ms step_avg:97.00ms +step:362/1670 train_time:35111ms step_avg:96.99ms +step:363/1670 train_time:35207ms step_avg:96.99ms +step:364/1670 train_time:35303ms step_avg:96.99ms +step:365/1670 train_time:35399ms step_avg:96.98ms +step:366/1670 train_time:35495ms step_avg:96.98ms +step:367/1670 train_time:35590ms step_avg:96.97ms +step:368/1670 train_time:35686ms step_avg:96.97ms +step:369/1670 train_time:35781ms step_avg:96.97ms +step:370/1670 train_time:35877ms step_avg:96.96ms +step:371/1670 train_time:35972ms step_avg:96.96ms +step:372/1670 train_time:36068ms step_avg:96.96ms +step:373/1670 train_time:36164ms step_avg:96.96ms +step:374/1670 train_time:36260ms step_avg:96.95ms +step:375/1670 train_time:36356ms step_avg:96.95ms +step:375/1670 val_loss:3.8164 train_time:36451ms step_avg:97.20ms +step:376/1670 train_time:36472ms step_avg:97.00ms +step:377/1670 train_time:36563ms step_avg:96.99ms +step:378/1670 train_time:36663ms step_avg:96.99ms +step:379/1670 train_time:36759ms step_avg:96.99ms +step:380/1670 train_time:36854ms step_avg:96.98ms +step:381/1670 train_time:36948ms step_avg:96.98ms +step:382/1670 train_time:37043ms step_avg:96.97ms +step:383/1670 train_time:37137ms step_avg:96.96ms +step:384/1670 train_time:37233ms step_avg:96.96ms +step:385/1670 train_time:37327ms step_avg:96.95ms +step:386/1670 train_time:37423ms step_avg:96.95ms +step:387/1670 train_time:37520ms step_avg:96.95ms +step:388/1670 train_time:37620ms step_avg:96.96ms +step:389/1670 train_time:37717ms step_avg:96.96ms +step:390/1670 train_time:37813ms step_avg:96.96ms +step:391/1670 train_time:37910ms step_avg:96.96ms +step:392/1670 train_time:38005ms step_avg:96.95ms +step:393/1670 train_time:38101ms step_avg:96.95ms +step:394/1670 train_time:38196ms step_avg:96.94ms +step:395/1670 train_time:38291ms step_avg:96.94ms +step:396/1670 train_time:38386ms step_avg:96.93ms +step:397/1670 train_time:38482ms step_avg:96.93ms +step:398/1670 train_time:38579ms step_avg:96.93ms +step:399/1670 train_time:38676ms step_avg:96.93ms +step:400/1670 train_time:38773ms step_avg:96.93ms +step:401/1670 train_time:38869ms step_avg:96.93ms +step:402/1670 train_time:38965ms step_avg:96.93ms +step:403/1670 train_time:39060ms step_avg:96.92ms +step:404/1670 train_time:39156ms step_avg:96.92ms +step:405/1670 train_time:39251ms step_avg:96.92ms +step:406/1670 train_time:39346ms step_avg:96.91ms +step:407/1670 train_time:39441ms step_avg:96.91ms +step:408/1670 train_time:39537ms step_avg:96.91ms +step:409/1670 train_time:39634ms step_avg:96.91ms +step:410/1670 train_time:39730ms step_avg:96.90ms +step:411/1670 train_time:39826ms step_avg:96.90ms +step:412/1670 train_time:39922ms step_avg:96.90ms +step:413/1670 train_time:40018ms step_avg:96.90ms +step:414/1670 train_time:40114ms step_avg:96.89ms +step:415/1670 train_time:40210ms step_avg:96.89ms +step:416/1670 train_time:40306ms step_avg:96.89ms +step:417/1670 train_time:40401ms step_avg:96.88ms +step:418/1670 train_time:40497ms step_avg:96.88ms +step:419/1670 train_time:40593ms step_avg:96.88ms +step:420/1670 train_time:40689ms step_avg:96.88ms +step:421/1670 train_time:40785ms step_avg:96.88ms +step:422/1670 train_time:40881ms step_avg:96.87ms +step:423/1670 train_time:40976ms step_avg:96.87ms +step:424/1670 train_time:41072ms step_avg:96.87ms 
+step:425/1670 train_time:41348ms step_avg:97.29ms +step:426/1670 train_time:41559ms step_avg:97.56ms +step:427/1670 train_time:41653ms step_avg:97.55ms +step:428/1670 train_time:41748ms step_avg:97.54ms +step:429/1670 train_time:41842ms step_avg:97.53ms +step:430/1670 train_time:41937ms step_avg:97.53ms +step:431/1670 train_time:42032ms step_avg:97.52ms +step:432/1670 train_time:42126ms step_avg:97.51ms +step:433/1670 train_time:42221ms step_avg:97.51ms +step:434/1670 train_time:42315ms step_avg:97.50ms +step:435/1670 train_time:42413ms step_avg:97.50ms +step:436/1670 train_time:42516ms step_avg:97.51ms +step:437/1670 train_time:42615ms step_avg:97.52ms +step:438/1670 train_time:42712ms step_avg:97.52ms +step:439/1670 train_time:42809ms step_avg:97.51ms +step:440/1670 train_time:42904ms step_avg:97.51ms +step:441/1670 train_time:42999ms step_avg:97.50ms +step:442/1670 train_time:43094ms step_avg:97.50ms +step:443/1670 train_time:43189ms step_avg:97.49ms +step:444/1670 train_time:43284ms step_avg:97.49ms +step:445/1670 train_time:43379ms step_avg:97.48ms +step:446/1670 train_time:43476ms step_avg:97.48ms +step:447/1670 train_time:43573ms step_avg:97.48ms +step:448/1670 train_time:43671ms step_avg:97.48ms +step:449/1670 train_time:43767ms step_avg:97.48ms +step:450/1670 train_time:43862ms step_avg:97.47ms +step:451/1670 train_time:43957ms step_avg:97.47ms +step:452/1670 train_time:44053ms step_avg:97.46ms +step:453/1670 train_time:44148ms step_avg:97.46ms +step:454/1670 train_time:44243ms step_avg:97.45ms +step:455/1670 train_time:44338ms step_avg:97.45ms +step:456/1670 train_time:44434ms step_avg:97.44ms +step:457/1670 train_time:44532ms step_avg:97.44ms +step:458/1670 train_time:44628ms step_avg:97.44ms +step:459/1670 train_time:44724ms step_avg:97.44ms +step:460/1670 train_time:44820ms step_avg:97.43ms +step:461/1670 train_time:44916ms step_avg:97.43ms +step:462/1670 train_time:45012ms step_avg:97.43ms +step:463/1670 train_time:45108ms step_avg:97.42ms +step:464/1670 train_time:45203ms step_avg:97.42ms +step:465/1670 train_time:45298ms step_avg:97.41ms +step:466/1670 train_time:45393ms step_avg:97.41ms +step:467/1670 train_time:45489ms step_avg:97.41ms +step:468/1670 train_time:45586ms step_avg:97.41ms +step:469/1670 train_time:45682ms step_avg:97.40ms +step:470/1670 train_time:45777ms step_avg:97.40ms +step:471/1670 train_time:45873ms step_avg:97.39ms +step:472/1670 train_time:45970ms step_avg:97.39ms +step:473/1670 train_time:46066ms step_avg:97.39ms +step:474/1670 train_time:46161ms step_avg:97.39ms +step:475/1670 train_time:46256ms step_avg:97.38ms +step:476/1670 train_time:46353ms step_avg:97.38ms +step:477/1670 train_time:46449ms step_avg:97.38ms +step:478/1670 train_time:46544ms step_avg:97.37ms +step:479/1670 train_time:46640ms step_avg:97.37ms +step:480/1670 train_time:46736ms step_avg:97.37ms +step:481/1670 train_time:46832ms step_avg:97.36ms +step:482/1670 train_time:46928ms step_avg:97.36ms +step:483/1670 train_time:47025ms step_avg:97.36ms +step:484/1670 train_time:47119ms step_avg:97.35ms +step:485/1670 train_time:47215ms step_avg:97.35ms +step:486/1670 train_time:47311ms step_avg:97.35ms +step:487/1670 train_time:47407ms step_avg:97.34ms +step:488/1670 train_time:47502ms step_avg:97.34ms +step:489/1670 train_time:47597ms step_avg:97.34ms +step:490/1670 train_time:47694ms step_avg:97.33ms +step:491/1670 train_time:47790ms step_avg:97.33ms +step:492/1670 train_time:47886ms step_avg:97.33ms +step:493/1670 train_time:47981ms step_avg:97.33ms +step:494/1670 train_time:48077ms 
step_avg:97.32ms +step:495/1670 train_time:48173ms step_avg:97.32ms +step:496/1670 train_time:48269ms step_avg:97.32ms +step:497/1670 train_time:48365ms step_avg:97.31ms +step:498/1670 train_time:48460ms step_avg:97.31ms +step:499/1670 train_time:48556ms step_avg:97.31ms +step:500/1670 train_time:48652ms step_avg:97.30ms +step:500/1670 val_loss:3.7161 train_time:48748ms step_avg:97.50ms +step:501/1670 train_time:48770ms step_avg:97.35ms +step:502/1670 train_time:48851ms step_avg:97.31ms +step:503/1670 train_time:48950ms step_avg:97.32ms +step:504/1670 train_time:49047ms step_avg:97.32ms +step:505/1670 train_time:49143ms step_avg:97.31ms +step:506/1670 train_time:49238ms step_avg:97.31ms +step:507/1670 train_time:49333ms step_avg:97.30ms +step:508/1670 train_time:49428ms step_avg:97.30ms +step:509/1670 train_time:49525ms step_avg:97.30ms +step:510/1670 train_time:49620ms step_avg:97.29ms +step:511/1670 train_time:49716ms step_avg:97.29ms +step:512/1670 train_time:49813ms step_avg:97.29ms +step:513/1670 train_time:49910ms step_avg:97.29ms +step:514/1670 train_time:50007ms step_avg:97.29ms +step:515/1670 train_time:50104ms step_avg:97.29ms +step:516/1670 train_time:50200ms step_avg:97.29ms +step:517/1670 train_time:50296ms step_avg:97.28ms +step:518/1670 train_time:50390ms step_avg:97.28ms +step:519/1670 train_time:50485ms step_avg:97.27ms +step:520/1670 train_time:50580ms step_avg:97.27ms +step:521/1670 train_time:50676ms step_avg:97.27ms +step:522/1670 train_time:50772ms step_avg:97.26ms +step:523/1670 train_time:50869ms step_avg:97.26ms +step:524/1670 train_time:50966ms step_avg:97.26ms +step:525/1670 train_time:51062ms step_avg:97.26ms +step:526/1670 train_time:51157ms step_avg:97.26ms +step:527/1670 train_time:51253ms step_avg:97.25ms +step:528/1670 train_time:51348ms step_avg:97.25ms +step:529/1670 train_time:51444ms step_avg:97.25ms +step:530/1670 train_time:51539ms step_avg:97.24ms +step:531/1670 train_time:51635ms step_avg:97.24ms +step:532/1670 train_time:51730ms step_avg:97.24ms +step:533/1670 train_time:51826ms step_avg:97.23ms +step:534/1670 train_time:51923ms step_avg:97.23ms +step:535/1670 train_time:52019ms step_avg:97.23ms +step:536/1670 train_time:52116ms step_avg:97.23ms +step:537/1670 train_time:52211ms step_avg:97.23ms +step:538/1670 train_time:52307ms step_avg:97.22ms +step:539/1670 train_time:52402ms step_avg:97.22ms +step:540/1670 train_time:52498ms step_avg:97.22ms +step:541/1670 train_time:52593ms step_avg:97.21ms +step:542/1670 train_time:52688ms step_avg:97.21ms +step:543/1670 train_time:52785ms step_avg:97.21ms +step:544/1670 train_time:52881ms step_avg:97.21ms +step:545/1670 train_time:52978ms step_avg:97.21ms +step:546/1670 train_time:53075ms step_avg:97.21ms +step:547/1670 train_time:53170ms step_avg:97.20ms +step:548/1670 train_time:53266ms step_avg:97.20ms +step:549/1670 train_time:53361ms step_avg:97.20ms +step:550/1670 train_time:53457ms step_avg:97.19ms +step:551/1670 train_time:53552ms step_avg:97.19ms +step:552/1670 train_time:53647ms step_avg:97.19ms +step:553/1670 train_time:53743ms step_avg:97.19ms +step:554/1670 train_time:53839ms step_avg:97.18ms +step:555/1670 train_time:53934ms step_avg:97.18ms +step:556/1670 train_time:54031ms step_avg:97.18ms +step:557/1670 train_time:54127ms step_avg:97.18ms +step:558/1670 train_time:54224ms step_avg:97.18ms +step:559/1670 train_time:54321ms step_avg:97.18ms +step:560/1670 train_time:54418ms step_avg:97.18ms +step:561/1670 train_time:54515ms step_avg:97.18ms +step:562/1670 train_time:54612ms step_avg:97.17ms 
+step:563/1670 train_time:54709ms step_avg:97.17ms +step:564/1670 train_time:54807ms step_avg:97.17ms +step:565/1670 train_time:54905ms step_avg:97.18ms +step:566/1670 train_time:55004ms step_avg:97.18ms +step:567/1670 train_time:55102ms step_avg:97.18ms +step:568/1670 train_time:55200ms step_avg:97.18ms +step:569/1670 train_time:55297ms step_avg:97.18ms +step:570/1670 train_time:55393ms step_avg:97.18ms +step:571/1670 train_time:55490ms step_avg:97.18ms +step:572/1670 train_time:55587ms step_avg:97.18ms +step:573/1670 train_time:55683ms step_avg:97.18ms +step:574/1670 train_time:55781ms step_avg:97.18ms +step:575/1670 train_time:55880ms step_avg:97.18ms +step:576/1670 train_time:55977ms step_avg:97.18ms +step:577/1670 train_time:56074ms step_avg:97.18ms +step:578/1670 train_time:56171ms step_avg:97.18ms +step:579/1670 train_time:56268ms step_avg:97.18ms +step:580/1670 train_time:56367ms step_avg:97.18ms +step:581/1670 train_time:56465ms step_avg:97.19ms +step:582/1670 train_time:56563ms step_avg:97.19ms +step:583/1670 train_time:56659ms step_avg:97.19ms +step:584/1670 train_time:56756ms step_avg:97.19ms +step:585/1670 train_time:56852ms step_avg:97.18ms +step:586/1670 train_time:56950ms step_avg:97.18ms +step:587/1670 train_time:57047ms step_avg:97.18ms +step:588/1670 train_time:57146ms step_avg:97.19ms +step:589/1670 train_time:57244ms step_avg:97.19ms +step:590/1670 train_time:57341ms step_avg:97.19ms +step:591/1670 train_time:57438ms step_avg:97.19ms +step:592/1670 train_time:57535ms step_avg:97.19ms +step:593/1670 train_time:57631ms step_avg:97.19ms +step:594/1670 train_time:57730ms step_avg:97.19ms +step:595/1670 train_time:57827ms step_avg:97.19ms +step:596/1670 train_time:57925ms step_avg:97.19ms +step:597/1670 train_time:58023ms step_avg:97.19ms +step:598/1670 train_time:58121ms step_avg:97.19ms +step:599/1670 train_time:58219ms step_avg:97.19ms +step:600/1670 train_time:58316ms step_avg:97.19ms +step:601/1670 train_time:58413ms step_avg:97.19ms +step:602/1670 train_time:58509ms step_avg:97.19ms +step:603/1670 train_time:58606ms step_avg:97.19ms +step:604/1670 train_time:58704ms step_avg:97.19ms +step:605/1670 train_time:58801ms step_avg:97.19ms +step:606/1670 train_time:58898ms step_avg:97.19ms +step:607/1670 train_time:58995ms step_avg:97.19ms +step:608/1670 train_time:59092ms step_avg:97.19ms +step:609/1670 train_time:59191ms step_avg:97.19ms +step:610/1670 train_time:59288ms step_avg:97.19ms +step:611/1670 train_time:59385ms step_avg:97.19ms +step:612/1670 train_time:59483ms step_avg:97.19ms +step:613/1670 train_time:59580ms step_avg:97.19ms +step:614/1670 train_time:59677ms step_avg:97.19ms +step:615/1670 train_time:59774ms step_avg:97.19ms +step:616/1670 train_time:59870ms step_avg:97.19ms +step:617/1670 train_time:59968ms step_avg:97.19ms +step:618/1670 train_time:60065ms step_avg:97.19ms +step:619/1670 train_time:60163ms step_avg:97.19ms +step:620/1670 train_time:60260ms step_avg:97.19ms +step:621/1670 train_time:60357ms step_avg:97.19ms +step:622/1670 train_time:60453ms step_avg:97.19ms +step:623/1670 train_time:60551ms step_avg:97.19ms +step:624/1670 train_time:60649ms step_avg:97.19ms +step:625/1670 train_time:60747ms step_avg:97.19ms +step:625/1670 val_loss:3.6123 train_time:60843ms step_avg:97.35ms +step:626/1670 train_time:60865ms step_avg:97.23ms +step:627/1670 train_time:60947ms step_avg:97.20ms +step:628/1670 train_time:61043ms step_avg:97.20ms +step:629/1670 train_time:61139ms step_avg:97.20ms +step:630/1670 train_time:61236ms step_avg:97.20ms +step:631/1670 
train_time:61332ms step_avg:97.20ms +step:632/1670 train_time:61427ms step_avg:97.20ms +step:633/1670 train_time:61523ms step_avg:97.19ms +step:634/1670 train_time:61618ms step_avg:97.19ms +step:635/1670 train_time:61715ms step_avg:97.19ms +step:636/1670 train_time:61816ms step_avg:97.19ms +step:637/1670 train_time:61917ms step_avg:97.20ms +step:638/1670 train_time:62017ms step_avg:97.21ms +step:639/1670 train_time:62397ms step_avg:97.65ms +step:640/1670 train_time:62476ms step_avg:97.62ms +step:641/1670 train_time:62572ms step_avg:97.62ms +step:642/1670 train_time:62669ms step_avg:97.61ms +step:643/1670 train_time:62765ms step_avg:97.61ms +step:644/1670 train_time:62860ms step_avg:97.61ms +step:645/1670 train_time:62957ms step_avg:97.61ms +step:646/1670 train_time:63053ms step_avg:97.60ms +step:647/1670 train_time:63148ms step_avg:97.60ms +step:648/1670 train_time:63245ms step_avg:97.60ms +step:649/1670 train_time:63347ms step_avg:97.61ms +step:650/1670 train_time:63447ms step_avg:97.61ms +step:651/1670 train_time:63544ms step_avg:97.61ms +step:652/1670 train_time:63641ms step_avg:97.61ms +step:653/1670 train_time:63738ms step_avg:97.61ms +step:654/1670 train_time:63835ms step_avg:97.61ms +step:655/1670 train_time:63932ms step_avg:97.61ms +step:656/1670 train_time:64028ms step_avg:97.60ms +step:657/1670 train_time:64124ms step_avg:97.60ms +step:658/1670 train_time:64220ms step_avg:97.60ms +step:659/1670 train_time:64318ms step_avg:97.60ms +step:660/1670 train_time:64418ms step_avg:97.60ms +step:661/1670 train_time:64515ms step_avg:97.60ms +step:662/1670 train_time:64614ms step_avg:97.60ms +step:663/1670 train_time:64711ms step_avg:97.60ms +step:664/1670 train_time:64809ms step_avg:97.60ms +step:665/1670 train_time:64905ms step_avg:97.60ms +step:666/1670 train_time:65001ms step_avg:97.60ms +step:667/1670 train_time:65098ms step_avg:97.60ms +step:668/1670 train_time:65194ms step_avg:97.60ms +step:669/1670 train_time:65294ms step_avg:97.60ms +step:670/1670 train_time:65394ms step_avg:97.60ms +step:671/1670 train_time:65492ms step_avg:97.60ms +step:672/1670 train_time:65591ms step_avg:97.61ms +step:673/1670 train_time:65689ms step_avg:97.61ms +step:674/1670 train_time:65786ms step_avg:97.61ms +step:675/1670 train_time:65883ms step_avg:97.60ms +step:676/1670 train_time:65979ms step_avg:97.60ms +step:677/1670 train_time:66075ms step_avg:97.60ms +step:678/1670 train_time:66172ms step_avg:97.60ms +step:679/1670 train_time:66269ms step_avg:97.60ms +step:680/1670 train_time:66366ms step_avg:97.60ms +step:681/1670 train_time:66464ms step_avg:97.60ms +step:682/1670 train_time:66561ms step_avg:97.60ms +step:683/1670 train_time:66659ms step_avg:97.60ms +step:684/1670 train_time:66757ms step_avg:97.60ms +step:685/1670 train_time:66855ms step_avg:97.60ms +step:686/1670 train_time:66953ms step_avg:97.60ms +step:687/1670 train_time:67050ms step_avg:97.60ms +step:688/1670 train_time:67146ms step_avg:97.60ms +step:689/1670 train_time:67243ms step_avg:97.60ms +step:690/1670 train_time:67340ms step_avg:97.59ms +step:691/1670 train_time:67437ms step_avg:97.59ms +step:692/1670 train_time:67535ms step_avg:97.59ms +step:693/1670 train_time:67634ms step_avg:97.60ms +step:694/1670 train_time:67733ms step_avg:97.60ms +step:695/1670 train_time:67830ms step_avg:97.60ms +step:696/1670 train_time:67927ms step_avg:97.60ms +step:697/1670 train_time:68023ms step_avg:97.59ms +step:698/1670 train_time:68121ms step_avg:97.59ms +step:699/1670 train_time:68217ms step_avg:97.59ms +step:700/1670 train_time:68315ms step_avg:97.59ms 
+step:701/1670 train_time:68413ms step_avg:97.59ms +step:702/1670 train_time:68511ms step_avg:97.59ms +step:703/1670 train_time:68608ms step_avg:97.59ms +step:704/1670 train_time:68705ms step_avg:97.59ms +step:705/1670 train_time:68802ms step_avg:97.59ms +step:706/1670 train_time:68900ms step_avg:97.59ms +step:707/1670 train_time:68998ms step_avg:97.59ms +step:708/1670 train_time:69095ms step_avg:97.59ms +step:709/1670 train_time:69193ms step_avg:97.59ms +step:710/1670 train_time:69290ms step_avg:97.59ms +step:711/1670 train_time:69387ms step_avg:97.59ms +step:712/1670 train_time:69484ms step_avg:97.59ms +step:713/1670 train_time:69581ms step_avg:97.59ms +step:714/1670 train_time:69678ms step_avg:97.59ms +step:715/1670 train_time:69775ms step_avg:97.59ms +step:716/1670 train_time:69872ms step_avg:97.59ms +step:717/1670 train_time:69970ms step_avg:97.59ms +step:718/1670 train_time:70068ms step_avg:97.59ms +step:719/1670 train_time:70165ms step_avg:97.59ms +step:720/1670 train_time:70261ms step_avg:97.58ms +step:721/1670 train_time:70357ms step_avg:97.58ms +step:722/1670 train_time:70456ms step_avg:97.58ms +step:723/1670 train_time:70553ms step_avg:97.58ms +step:724/1670 train_time:70651ms step_avg:97.58ms +step:725/1670 train_time:70748ms step_avg:97.58ms +step:726/1670 train_time:70845ms step_avg:97.58ms +step:727/1670 train_time:70942ms step_avg:97.58ms +step:728/1670 train_time:71040ms step_avg:97.58ms +step:729/1670 train_time:71139ms step_avg:97.58ms +step:730/1670 train_time:71236ms step_avg:97.58ms +step:731/1670 train_time:71334ms step_avg:97.58ms +step:732/1670 train_time:71431ms step_avg:97.58ms +step:733/1670 train_time:71529ms step_avg:97.58ms +step:734/1670 train_time:71626ms step_avg:97.58ms +step:735/1670 train_time:71722ms step_avg:97.58ms +step:736/1670 train_time:71819ms step_avg:97.58ms +step:737/1670 train_time:71917ms step_avg:97.58ms +step:738/1670 train_time:72015ms step_avg:97.58ms +step:739/1670 train_time:72113ms step_avg:97.58ms +step:740/1670 train_time:72210ms step_avg:97.58ms +step:741/1670 train_time:72307ms step_avg:97.58ms +step:742/1670 train_time:72404ms step_avg:97.58ms +step:743/1670 train_time:72501ms step_avg:97.58ms +step:744/1670 train_time:72599ms step_avg:97.58ms +step:745/1670 train_time:72697ms step_avg:97.58ms +step:746/1670 train_time:72793ms step_avg:97.58ms +step:747/1670 train_time:72891ms step_avg:97.58ms +step:748/1670 train_time:72989ms step_avg:97.58ms +step:749/1670 train_time:73087ms step_avg:97.58ms +step:750/1670 train_time:73184ms step_avg:97.58ms +step:750/1670 val_loss:3.5623 train_time:73280ms step_avg:97.71ms +step:751/1670 train_time:73302ms step_avg:97.61ms +step:752/1670 train_time:73387ms step_avg:97.59ms +step:753/1670 train_time:73485ms step_avg:97.59ms +step:754/1670 train_time:73582ms step_avg:97.59ms +step:755/1670 train_time:73678ms step_avg:97.59ms +step:756/1670 train_time:73774ms step_avg:97.58ms +step:757/1670 train_time:73870ms step_avg:97.58ms +step:758/1670 train_time:73967ms step_avg:97.58ms +step:759/1670 train_time:74064ms step_avg:97.58ms +step:760/1670 train_time:74159ms step_avg:97.58ms +step:761/1670 train_time:74257ms step_avg:97.58ms +step:762/1670 train_time:74358ms step_avg:97.58ms +step:763/1670 train_time:74459ms step_avg:97.59ms +step:764/1670 train_time:74558ms step_avg:97.59ms +step:765/1670 train_time:74655ms step_avg:97.59ms +step:766/1670 train_time:74752ms step_avg:97.59ms +step:767/1670 train_time:74849ms step_avg:97.59ms +step:768/1670 train_time:74944ms step_avg:97.58ms +step:769/1670 
train_time:75040ms step_avg:97.58ms +step:770/1670 train_time:75136ms step_avg:97.58ms +step:771/1670 train_time:75234ms step_avg:97.58ms +step:772/1670 train_time:75333ms step_avg:97.58ms +step:773/1670 train_time:75434ms step_avg:97.59ms +step:774/1670 train_time:75533ms step_avg:97.59ms +step:775/1670 train_time:75632ms step_avg:97.59ms +step:776/1670 train_time:75730ms step_avg:97.59ms +step:777/1670 train_time:75826ms step_avg:97.59ms +step:778/1670 train_time:75923ms step_avg:97.59ms +step:779/1670 train_time:76018ms step_avg:97.58ms +step:780/1670 train_time:76115ms step_avg:97.58ms +step:781/1670 train_time:76213ms step_avg:97.58ms +step:782/1670 train_time:76311ms step_avg:97.58ms +step:783/1670 train_time:76411ms step_avg:97.59ms +step:784/1670 train_time:76511ms step_avg:97.59ms +step:785/1670 train_time:76610ms step_avg:97.59ms +step:786/1670 train_time:76709ms step_avg:97.59ms +step:787/1670 train_time:76805ms step_avg:97.59ms +step:788/1670 train_time:76901ms step_avg:97.59ms +step:789/1670 train_time:76997ms step_avg:97.59ms +step:790/1670 train_time:77094ms step_avg:97.59ms +step:791/1670 train_time:77192ms step_avg:97.59ms +step:792/1670 train_time:77291ms step_avg:97.59ms +step:793/1670 train_time:77390ms step_avg:97.59ms +step:794/1670 train_time:77489ms step_avg:97.59ms +step:795/1670 train_time:77587ms step_avg:97.59ms +step:796/1670 train_time:77685ms step_avg:97.59ms +step:797/1670 train_time:77781ms step_avg:97.59ms +step:798/1670 train_time:77878ms step_avg:97.59ms +step:799/1670 train_time:77974ms step_avg:97.59ms +step:800/1670 train_time:78070ms step_avg:97.59ms +step:801/1670 train_time:78167ms step_avg:97.59ms +step:802/1670 train_time:78265ms step_avg:97.59ms +step:803/1670 train_time:78363ms step_avg:97.59ms +step:804/1670 train_time:78459ms step_avg:97.59ms +step:805/1670 train_time:78557ms step_avg:97.59ms +step:806/1670 train_time:78655ms step_avg:97.59ms +step:807/1670 train_time:78752ms step_avg:97.59ms +step:808/1670 train_time:78851ms step_avg:97.59ms +step:809/1670 train_time:78949ms step_avg:97.59ms +step:810/1670 train_time:79045ms step_avg:97.59ms +step:811/1670 train_time:79142ms step_avg:97.59ms +step:812/1670 train_time:79238ms step_avg:97.58ms +step:813/1670 train_time:79335ms step_avg:97.58ms +step:814/1670 train_time:79434ms step_avg:97.58ms +step:815/1670 train_time:79533ms step_avg:97.59ms +step:816/1670 train_time:79632ms step_avg:97.59ms +step:817/1670 train_time:79729ms step_avg:97.59ms +step:818/1670 train_time:79826ms step_avg:97.59ms +step:819/1670 train_time:79924ms step_avg:97.59ms +step:820/1670 train_time:80021ms step_avg:97.59ms +step:821/1670 train_time:80117ms step_avg:97.58ms +step:822/1670 train_time:80214ms step_avg:97.58ms +step:823/1670 train_time:80311ms step_avg:97.58ms +step:824/1670 train_time:80408ms step_avg:97.58ms +step:825/1670 train_time:80507ms step_avg:97.58ms +step:826/1670 train_time:80604ms step_avg:97.58ms +step:827/1670 train_time:80701ms step_avg:97.58ms +step:828/1670 train_time:80798ms step_avg:97.58ms +step:829/1670 train_time:80897ms step_avg:97.58ms +step:830/1670 train_time:80994ms step_avg:97.58ms +step:831/1670 train_time:81092ms step_avg:97.58ms +step:832/1670 train_time:81189ms step_avg:97.58ms +step:833/1670 train_time:81286ms step_avg:97.58ms +step:834/1670 train_time:81382ms step_avg:97.58ms +step:835/1670 train_time:81479ms step_avg:97.58ms +step:836/1670 train_time:81577ms step_avg:97.58ms +step:837/1670 train_time:81674ms step_avg:97.58ms +step:838/1670 train_time:81772ms step_avg:97.58ms 
+step:839/1670 train_time:81870ms step_avg:97.58ms +step:840/1670 train_time:81969ms step_avg:97.58ms +step:841/1670 train_time:82068ms step_avg:97.58ms +step:842/1670 train_time:82165ms step_avg:97.58ms +step:843/1670 train_time:82262ms step_avg:97.58ms +step:844/1670 train_time:82358ms step_avg:97.58ms +step:845/1670 train_time:82455ms step_avg:97.58ms +step:846/1670 train_time:82553ms step_avg:97.58ms +step:847/1670 train_time:82650ms step_avg:97.58ms +step:848/1670 train_time:82747ms step_avg:97.58ms +step:849/1670 train_time:82844ms step_avg:97.58ms +step:850/1670 train_time:82941ms step_avg:97.58ms +step:851/1670 train_time:83209ms step_avg:97.78ms +step:852/1670 train_time:83379ms step_avg:97.86ms +step:853/1670 train_time:83473ms step_avg:97.86ms +step:854/1670 train_time:83570ms step_avg:97.86ms +step:855/1670 train_time:83666ms step_avg:97.86ms +step:856/1670 train_time:83762ms step_avg:97.85ms +step:857/1670 train_time:83858ms step_avg:97.85ms +step:858/1670 train_time:83955ms step_avg:97.85ms +step:859/1670 train_time:84051ms step_avg:97.85ms +step:860/1670 train_time:84148ms step_avg:97.85ms +step:861/1670 train_time:84245ms step_avg:97.85ms +step:862/1670 train_time:84351ms step_avg:97.85ms +step:863/1670 train_time:84450ms step_avg:97.86ms +step:864/1670 train_time:84548ms step_avg:97.86ms +step:865/1670 train_time:84644ms step_avg:97.85ms +step:866/1670 train_time:84740ms step_avg:97.85ms +step:867/1670 train_time:84836ms step_avg:97.85ms +step:868/1670 train_time:84932ms step_avg:97.85ms +step:869/1670 train_time:85029ms step_avg:97.85ms +step:870/1670 train_time:85125ms step_avg:97.85ms +step:871/1670 train_time:85222ms step_avg:97.84ms +step:872/1670 train_time:85321ms step_avg:97.84ms +step:873/1670 train_time:85420ms step_avg:97.85ms +step:874/1670 train_time:85518ms step_avg:97.85ms +step:875/1670 train_time:85617ms step_avg:97.85ms +step:875/1670 val_loss:3.5191 train_time:85714ms step_avg:97.96ms +step:876/1670 train_time:85735ms step_avg:97.87ms +step:877/1670 train_time:85818ms step_avg:97.85ms +step:878/1670 train_time:85917ms step_avg:97.86ms +step:879/1670 train_time:86014ms step_avg:97.85ms +step:880/1670 train_time:86111ms step_avg:97.85ms +step:881/1670 train_time:86207ms step_avg:97.85ms +step:882/1670 train_time:86303ms step_avg:97.85ms +step:883/1670 train_time:86399ms step_avg:97.85ms +step:884/1670 train_time:86496ms step_avg:97.85ms +step:885/1670 train_time:86592ms step_avg:97.84ms +step:886/1670 train_time:86690ms step_avg:97.84ms +step:887/1670 train_time:86792ms step_avg:97.85ms +step:888/1670 train_time:86893ms step_avg:97.85ms +step:889/1670 train_time:86992ms step_avg:97.85ms +step:890/1670 train_time:87090ms step_avg:97.85ms +step:891/1670 train_time:87186ms step_avg:97.85ms +step:892/1670 train_time:87284ms step_avg:97.85ms +step:893/1670 train_time:87382ms step_avg:97.85ms +step:894/1670 train_time:87479ms step_avg:97.85ms +step:895/1670 train_time:87576ms step_avg:97.85ms +step:896/1670 train_time:87672ms step_avg:97.85ms +step:897/1670 train_time:87770ms step_avg:97.85ms +step:898/1670 train_time:87870ms step_avg:97.85ms +step:899/1670 train_time:87968ms step_avg:97.85ms +step:900/1670 train_time:88066ms step_avg:97.85ms +step:901/1670 train_time:88163ms step_avg:97.85ms +step:902/1670 train_time:88259ms step_avg:97.85ms +step:903/1670 train_time:88356ms step_avg:97.85ms +step:904/1670 train_time:88452ms step_avg:97.85ms +step:905/1670 train_time:88550ms step_avg:97.85ms +step:906/1670 train_time:88647ms step_avg:97.84ms +step:907/1670 
train_time:88747ms step_avg:97.85ms +step:908/1670 train_time:88845ms step_avg:97.85ms +step:909/1670 train_time:88944ms step_avg:97.85ms +step:910/1670 train_time:89042ms step_avg:97.85ms +step:911/1670 train_time:89139ms step_avg:97.85ms +step:912/1670 train_time:89236ms step_avg:97.85ms +step:913/1670 train_time:89333ms step_avg:97.85ms +step:914/1670 train_time:89430ms step_avg:97.84ms +step:915/1670 train_time:89527ms step_avg:97.84ms +step:916/1670 train_time:89625ms step_avg:97.84ms +step:917/1670 train_time:89723ms step_avg:97.84ms +step:918/1670 train_time:89820ms step_avg:97.84ms +step:919/1670 train_time:89918ms step_avg:97.84ms +step:920/1670 train_time:90015ms step_avg:97.84ms +step:921/1670 train_time:90112ms step_avg:97.84ms +step:922/1670 train_time:90209ms step_avg:97.84ms +step:923/1670 train_time:90306ms step_avg:97.84ms +step:924/1670 train_time:90403ms step_avg:97.84ms +step:925/1670 train_time:90500ms step_avg:97.84ms +step:926/1670 train_time:90597ms step_avg:97.84ms +step:927/1670 train_time:90694ms step_avg:97.84ms +step:928/1670 train_time:90791ms step_avg:97.84ms +step:929/1670 train_time:90889ms step_avg:97.84ms +step:930/1670 train_time:90987ms step_avg:97.84ms +step:931/1670 train_time:91086ms step_avg:97.84ms +step:932/1670 train_time:91185ms step_avg:97.84ms +step:933/1670 train_time:91283ms step_avg:97.84ms +step:934/1670 train_time:91380ms step_avg:97.84ms +step:935/1670 train_time:91477ms step_avg:97.84ms +step:936/1670 train_time:91574ms step_avg:97.84ms +step:937/1670 train_time:91670ms step_avg:97.83ms +step:938/1670 train_time:91767ms step_avg:97.83ms +step:939/1670 train_time:91866ms step_avg:97.83ms +step:940/1670 train_time:91963ms step_avg:97.83ms +step:941/1670 train_time:92060ms step_avg:97.83ms +step:942/1670 train_time:92158ms step_avg:97.83ms +step:943/1670 train_time:92255ms step_avg:97.83ms +step:944/1670 train_time:92351ms step_avg:97.83ms +step:945/1670 train_time:92449ms step_avg:97.83ms +step:946/1670 train_time:92548ms step_avg:97.83ms +step:947/1670 train_time:92646ms step_avg:97.83ms +step:948/1670 train_time:92743ms step_avg:97.83ms +step:949/1670 train_time:92840ms step_avg:97.83ms +step:950/1670 train_time:92937ms step_avg:97.83ms +step:951/1670 train_time:93033ms step_avg:97.83ms +step:952/1670 train_time:93132ms step_avg:97.83ms +step:953/1670 train_time:93230ms step_avg:97.83ms +step:954/1670 train_time:93327ms step_avg:97.83ms +step:955/1670 train_time:93425ms step_avg:97.83ms +step:956/1670 train_time:93523ms step_avg:97.83ms +step:957/1670 train_time:93620ms step_avg:97.83ms +step:958/1670 train_time:93717ms step_avg:97.83ms +step:959/1670 train_time:93814ms step_avg:97.82ms +step:960/1670 train_time:93911ms step_avg:97.82ms +step:961/1670 train_time:94008ms step_avg:97.82ms +step:962/1670 train_time:94107ms step_avg:97.82ms +step:963/1670 train_time:94205ms step_avg:97.82ms +step:964/1670 train_time:94303ms step_avg:97.82ms +step:965/1670 train_time:94400ms step_avg:97.82ms +step:966/1670 train_time:94497ms step_avg:97.82ms +step:967/1670 train_time:94594ms step_avg:97.82ms +step:968/1670 train_time:94692ms step_avg:97.82ms +step:969/1670 train_time:94789ms step_avg:97.82ms +step:970/1670 train_time:94886ms step_avg:97.82ms +step:971/1670 train_time:94984ms step_avg:97.82ms +step:972/1670 train_time:95082ms step_avg:97.82ms +step:973/1670 train_time:95178ms step_avg:97.82ms +step:974/1670 train_time:95275ms step_avg:97.82ms +step:975/1670 train_time:95372ms step_avg:97.82ms +step:976/1670 train_time:95470ms step_avg:97.82ms 
+step:977/1670 train_time:95568ms step_avg:97.82ms +step:978/1670 train_time:95666ms step_avg:97.82ms +step:979/1670 train_time:95763ms step_avg:97.82ms +step:980/1670 train_time:95861ms step_avg:97.82ms +step:981/1670 train_time:95958ms step_avg:97.82ms +step:982/1670 train_time:96055ms step_avg:97.82ms +step:983/1670 train_time:96153ms step_avg:97.82ms +step:984/1670 train_time:96251ms step_avg:97.82ms +step:985/1670 train_time:96349ms step_avg:97.82ms +step:986/1670 train_time:96447ms step_avg:97.82ms +step:987/1670 train_time:96545ms step_avg:97.82ms +step:988/1670 train_time:96641ms step_avg:97.82ms +step:989/1670 train_time:96739ms step_avg:97.81ms +step:990/1670 train_time:96835ms step_avg:97.81ms +step:991/1670 train_time:96933ms step_avg:97.81ms +step:992/1670 train_time:97030ms step_avg:97.81ms +step:993/1670 train_time:97128ms step_avg:97.81ms +step:994/1670 train_time:97227ms step_avg:97.81ms +step:995/1670 train_time:97324ms step_avg:97.81ms +step:996/1670 train_time:97422ms step_avg:97.81ms +step:997/1670 train_time:97519ms step_avg:97.81ms +step:998/1670 train_time:97616ms step_avg:97.81ms +step:999/1670 train_time:97713ms step_avg:97.81ms +step:1000/1670 train_time:97811ms step_avg:97.81ms +step:1000/1670 val_loss:3.4757 train_time:97908ms step_avg:97.91ms +step:1001/1670 train_time:97929ms step_avg:97.83ms +step:1002/1670 train_time:98015ms step_avg:97.82ms +step:1003/1670 train_time:98114ms step_avg:97.82ms +step:1004/1670 train_time:98213ms step_avg:97.82ms +step:1005/1670 train_time:98309ms step_avg:97.82ms +step:1006/1670 train_time:98405ms step_avg:97.82ms +step:1007/1670 train_time:98500ms step_avg:97.82ms +step:1008/1670 train_time:98597ms step_avg:97.81ms +step:1009/1670 train_time:98693ms step_avg:97.81ms +step:1010/1670 train_time:98788ms step_avg:97.81ms +step:1011/1670 train_time:98886ms step_avg:97.81ms +step:1012/1670 train_time:98986ms step_avg:97.81ms +step:1013/1670 train_time:99085ms step_avg:97.81ms +step:1014/1670 train_time:99186ms step_avg:97.82ms +step:1015/1670 train_time:99283ms step_avg:97.82ms +step:1016/1670 train_time:99381ms step_avg:97.82ms +step:1017/1670 train_time:99478ms step_avg:97.81ms +step:1018/1670 train_time:99574ms step_avg:97.81ms +step:1019/1670 train_time:99670ms step_avg:97.81ms +step:1020/1670 train_time:99767ms step_avg:97.81ms +step:1021/1670 train_time:99864ms step_avg:97.81ms +step:1022/1670 train_time:99963ms step_avg:97.81ms +step:1023/1670 train_time:100062ms step_avg:97.81ms +step:1024/1670 train_time:100162ms step_avg:97.81ms +step:1025/1670 train_time:100260ms step_avg:97.81ms +step:1026/1670 train_time:100357ms step_avg:97.81ms +step:1027/1670 train_time:100455ms step_avg:97.81ms +step:1028/1670 train_time:100553ms step_avg:97.81ms +step:1029/1670 train_time:100649ms step_avg:97.81ms +step:1030/1670 train_time:100745ms step_avg:97.81ms +step:1031/1670 train_time:100842ms step_avg:97.81ms +step:1032/1670 train_time:100941ms step_avg:97.81ms +step:1033/1670 train_time:101039ms step_avg:97.81ms +step:1034/1670 train_time:101138ms step_avg:97.81ms +step:1035/1670 train_time:101236ms step_avg:97.81ms +step:1036/1670 train_time:101334ms step_avg:97.81ms +step:1037/1670 train_time:101430ms step_avg:97.81ms +step:1038/1670 train_time:101528ms step_avg:97.81ms +step:1039/1670 train_time:101625ms step_avg:97.81ms +step:1040/1670 train_time:101722ms step_avg:97.81ms +step:1041/1670 train_time:101819ms step_avg:97.81ms +step:1042/1670 train_time:101917ms step_avg:97.81ms +step:1043/1670 train_time:102015ms step_avg:97.81ms 
+step:1044/1670 train_time:102113ms step_avg:97.81ms +step:1045/1670 train_time:102211ms step_avg:97.81ms +step:1046/1670 train_time:102308ms step_avg:97.81ms +step:1047/1670 train_time:102697ms step_avg:98.09ms +step:1048/1670 train_time:102792ms step_avg:98.08ms +step:1049/1670 train_time:102888ms step_avg:98.08ms +step:1050/1670 train_time:102984ms step_avg:98.08ms +step:1051/1670 train_time:103080ms step_avg:98.08ms +step:1052/1670 train_time:103177ms step_avg:98.08ms +step:1053/1670 train_time:103273ms step_avg:98.08ms +step:1054/1670 train_time:103369ms step_avg:98.07ms +step:1055/1670 train_time:103464ms step_avg:98.07ms +step:1056/1670 train_time:103563ms step_avg:98.07ms +step:1057/1670 train_time:103666ms step_avg:98.08ms +step:1058/1670 train_time:103765ms step_avg:98.08ms +step:1059/1670 train_time:103863ms step_avg:98.08ms +step:1060/1670 train_time:103962ms step_avg:98.08ms +step:1061/1670 train_time:104059ms step_avg:98.08ms +step:1062/1670 train_time:104340ms step_avg:98.25ms +step:1063/1670 train_time:104413ms step_avg:98.23ms +step:1064/1670 train_time:104508ms step_avg:98.22ms +step:1065/1670 train_time:104605ms step_avg:98.22ms +step:1066/1670 train_time:104701ms step_avg:98.22ms +step:1067/1670 train_time:104797ms step_avg:98.22ms +step:1068/1670 train_time:104892ms step_avg:98.21ms +step:1069/1670 train_time:104989ms step_avg:98.21ms +step:1070/1670 train_time:105085ms step_avg:98.21ms +step:1071/1670 train_time:105181ms step_avg:98.21ms +step:1072/1670 train_time:105282ms step_avg:98.21ms +step:1073/1670 train_time:105383ms step_avg:98.21ms +step:1074/1670 train_time:105483ms step_avg:98.22ms +step:1075/1670 train_time:105580ms step_avg:98.21ms +step:1076/1670 train_time:105678ms step_avg:98.21ms +step:1077/1670 train_time:105775ms step_avg:98.21ms +step:1078/1670 train_time:105871ms step_avg:98.21ms +step:1079/1670 train_time:105968ms step_avg:98.21ms +step:1080/1670 train_time:106064ms step_avg:98.21ms +step:1081/1670 train_time:106161ms step_avg:98.21ms +step:1082/1670 train_time:106261ms step_avg:98.21ms +step:1083/1670 train_time:106362ms step_avg:98.21ms +step:1084/1670 train_time:106461ms step_avg:98.21ms +step:1085/1670 train_time:106559ms step_avg:98.21ms +step:1086/1670 train_time:106656ms step_avg:98.21ms +step:1087/1670 train_time:106753ms step_avg:98.21ms +step:1088/1670 train_time:106850ms step_avg:98.21ms +step:1089/1670 train_time:106946ms step_avg:98.21ms +step:1090/1670 train_time:107042ms step_avg:98.20ms +step:1091/1670 train_time:107140ms step_avg:98.20ms +step:1092/1670 train_time:107238ms step_avg:98.20ms +step:1093/1670 train_time:107337ms step_avg:98.20ms +step:1094/1670 train_time:107436ms step_avg:98.20ms +step:1095/1670 train_time:107534ms step_avg:98.21ms +step:1096/1670 train_time:107633ms step_avg:98.21ms +step:1097/1670 train_time:107730ms step_avg:98.20ms +step:1098/1670 train_time:107826ms step_avg:98.20ms +step:1099/1670 train_time:107923ms step_avg:98.20ms +step:1100/1670 train_time:108021ms step_avg:98.20ms +step:1101/1670 train_time:108119ms step_avg:98.20ms +step:1102/1670 train_time:108216ms step_avg:98.20ms +step:1103/1670 train_time:108313ms step_avg:98.20ms +step:1104/1670 train_time:108410ms step_avg:98.20ms +step:1105/1670 train_time:108507ms step_avg:98.20ms +step:1106/1670 train_time:108605ms step_avg:98.20ms +step:1107/1670 train_time:108703ms step_avg:98.20ms +step:1108/1670 train_time:108800ms step_avg:98.20ms +step:1109/1670 train_time:108899ms step_avg:98.20ms +step:1110/1670 train_time:108996ms step_avg:98.19ms 
+step:1111/1670 train_time:109093ms step_avg:98.19ms +step:1112/1670 train_time:109190ms step_avg:98.19ms +step:1113/1670 train_time:109287ms step_avg:98.19ms +step:1114/1670 train_time:109385ms step_avg:98.19ms +step:1115/1670 train_time:109483ms step_avg:98.19ms +step:1116/1670 train_time:109583ms step_avg:98.19ms +step:1117/1670 train_time:109682ms step_avg:98.19ms +step:1118/1670 train_time:109780ms step_avg:98.19ms +step:1119/1670 train_time:109879ms step_avg:98.19ms +step:1120/1670 train_time:109978ms step_avg:98.20ms +step:1121/1670 train_time:110077ms step_avg:98.20ms +step:1122/1670 train_time:110176ms step_avg:98.20ms +step:1123/1670 train_time:110274ms step_avg:98.20ms +step:1124/1670 train_time:110372ms step_avg:98.20ms +step:1125/1670 train_time:110471ms step_avg:98.20ms +step:1125/1670 val_loss:3.4234 train_time:110568ms step_avg:98.28ms +step:1126/1670 train_time:110591ms step_avg:98.22ms +step:1127/1670 train_time:110678ms step_avg:98.21ms +step:1128/1670 train_time:110775ms step_avg:98.20ms +step:1129/1670 train_time:110871ms step_avg:98.20ms +step:1130/1670 train_time:110968ms step_avg:98.20ms +step:1131/1670 train_time:111064ms step_avg:98.20ms +step:1132/1670 train_time:111161ms step_avg:98.20ms +step:1133/1670 train_time:111259ms step_avg:98.20ms +step:1134/1670 train_time:111356ms step_avg:98.20ms +step:1135/1670 train_time:111455ms step_avg:98.20ms +step:1136/1670 train_time:111558ms step_avg:98.20ms +step:1137/1670 train_time:111660ms step_avg:98.21ms +step:1138/1670 train_time:111759ms step_avg:98.21ms +step:1139/1670 train_time:111857ms step_avg:98.21ms +step:1140/1670 train_time:111955ms step_avg:98.21ms +step:1141/1670 train_time:112052ms step_avg:98.21ms +step:1142/1670 train_time:112150ms step_avg:98.20ms +step:1143/1670 train_time:112246ms step_avg:98.20ms +step:1144/1670 train_time:112343ms step_avg:98.20ms +step:1145/1670 train_time:112441ms step_avg:98.20ms +step:1146/1670 train_time:112541ms step_avg:98.20ms +step:1147/1670 train_time:112640ms step_avg:98.20ms +step:1148/1670 train_time:112741ms step_avg:98.21ms +step:1149/1670 train_time:112839ms step_avg:98.21ms +step:1150/1670 train_time:112938ms step_avg:98.21ms +step:1151/1670 train_time:113037ms step_avg:98.21ms +step:1152/1670 train_time:113136ms step_avg:98.21ms +step:1153/1670 train_time:113234ms step_avg:98.21ms +step:1154/1670 train_time:113331ms step_avg:98.21ms +step:1155/1670 train_time:113430ms step_avg:98.21ms +step:1156/1670 train_time:113529ms step_avg:98.21ms +step:1157/1670 train_time:113628ms step_avg:98.21ms +step:1158/1670 train_time:113727ms step_avg:98.21ms +step:1159/1670 train_time:113825ms step_avg:98.21ms +step:1160/1670 train_time:113923ms step_avg:98.21ms +step:1161/1670 train_time:114022ms step_avg:98.21ms +step:1162/1670 train_time:114120ms step_avg:98.21ms +step:1163/1670 train_time:114218ms step_avg:98.21ms +step:1164/1670 train_time:114316ms step_avg:98.21ms +step:1165/1670 train_time:114414ms step_avg:98.21ms +step:1166/1670 train_time:114514ms step_avg:98.21ms +step:1167/1670 train_time:114614ms step_avg:98.21ms +step:1168/1670 train_time:114714ms step_avg:98.21ms +step:1169/1670 train_time:114812ms step_avg:98.21ms +step:1170/1670 train_time:114911ms step_avg:98.21ms +step:1171/1670 train_time:115008ms step_avg:98.21ms +step:1172/1670 train_time:115105ms step_avg:98.21ms +step:1173/1670 train_time:115203ms step_avg:98.21ms +step:1174/1670 train_time:115300ms step_avg:98.21ms +step:1175/1670 train_time:115399ms step_avg:98.21ms +step:1176/1670 train_time:115498ms 
step_avg:98.21ms +step:1177/1670 train_time:115596ms step_avg:98.21ms +step:1178/1670 train_time:115696ms step_avg:98.21ms +step:1179/1670 train_time:115795ms step_avg:98.21ms +step:1180/1670 train_time:115894ms step_avg:98.22ms +step:1181/1670 train_time:115993ms step_avg:98.22ms +step:1182/1670 train_time:116093ms step_avg:98.22ms +step:1183/1670 train_time:116190ms step_avg:98.22ms +step:1184/1670 train_time:116287ms step_avg:98.22ms +step:1185/1670 train_time:116385ms step_avg:98.21ms +step:1186/1670 train_time:116482ms step_avg:98.21ms +step:1187/1670 train_time:116581ms step_avg:98.21ms +step:1188/1670 train_time:116680ms step_avg:98.22ms +step:1189/1670 train_time:116779ms step_avg:98.22ms +step:1190/1670 train_time:116878ms step_avg:98.22ms +step:1191/1670 train_time:116977ms step_avg:98.22ms +step:1192/1670 train_time:117077ms step_avg:98.22ms +step:1193/1670 train_time:117177ms step_avg:98.22ms +step:1194/1670 train_time:117275ms step_avg:98.22ms +step:1195/1670 train_time:117374ms step_avg:98.22ms +step:1196/1670 train_time:117472ms step_avg:98.22ms +step:1197/1670 train_time:117570ms step_avg:98.22ms +step:1198/1670 train_time:117668ms step_avg:98.22ms +step:1199/1670 train_time:117764ms step_avg:98.22ms +step:1200/1670 train_time:117861ms step_avg:98.22ms +step:1201/1670 train_time:117960ms step_avg:98.22ms +step:1202/1670 train_time:118059ms step_avg:98.22ms +step:1203/1670 train_time:118159ms step_avg:98.22ms +step:1204/1670 train_time:118256ms step_avg:98.22ms +step:1205/1670 train_time:118356ms step_avg:98.22ms +step:1206/1670 train_time:118454ms step_avg:98.22ms +step:1207/1670 train_time:118553ms step_avg:98.22ms +step:1208/1670 train_time:118651ms step_avg:98.22ms +step:1209/1670 train_time:118750ms step_avg:98.22ms +step:1210/1670 train_time:118848ms step_avg:98.22ms +step:1211/1670 train_time:118946ms step_avg:98.22ms +step:1212/1670 train_time:119044ms step_avg:98.22ms +step:1213/1670 train_time:119141ms step_avg:98.22ms +step:1214/1670 train_time:119239ms step_avg:98.22ms +step:1215/1670 train_time:119339ms step_avg:98.22ms +step:1216/1670 train_time:119439ms step_avg:98.22ms +step:1217/1670 train_time:119538ms step_avg:98.22ms +step:1218/1670 train_time:119638ms step_avg:98.23ms +step:1219/1670 train_time:119736ms step_avg:98.23ms +step:1220/1670 train_time:119836ms step_avg:98.23ms +step:1221/1670 train_time:119936ms step_avg:98.23ms +step:1222/1670 train_time:120036ms step_avg:98.23ms +step:1223/1670 train_time:120135ms step_avg:98.23ms +step:1224/1670 train_time:120233ms step_avg:98.23ms +step:1225/1670 train_time:120331ms step_avg:98.23ms +step:1226/1670 train_time:120429ms step_avg:98.23ms +step:1227/1670 train_time:120527ms step_avg:98.23ms +step:1228/1670 train_time:120624ms step_avg:98.23ms +step:1229/1670 train_time:120722ms step_avg:98.23ms +step:1230/1670 train_time:120820ms step_avg:98.23ms +step:1231/1670 train_time:120919ms step_avg:98.23ms +step:1232/1670 train_time:121018ms step_avg:98.23ms +step:1233/1670 train_time:121117ms step_avg:98.23ms +step:1234/1670 train_time:121215ms step_avg:98.23ms +step:1235/1670 train_time:121313ms step_avg:98.23ms +step:1236/1670 train_time:121412ms step_avg:98.23ms +step:1237/1670 train_time:121510ms step_avg:98.23ms +step:1238/1670 train_time:121608ms step_avg:98.23ms +step:1239/1670 train_time:121707ms step_avg:98.23ms +step:1240/1670 train_time:121804ms step_avg:98.23ms +step:1241/1670 train_time:121902ms step_avg:98.23ms +step:1242/1670 train_time:121999ms step_avg:98.23ms +step:1243/1670 train_time:122097ms 
step_avg:98.23ms +step:1244/1670 train_time:122196ms step_avg:98.23ms +step:1245/1670 train_time:122295ms step_avg:98.23ms +step:1246/1670 train_time:122393ms step_avg:98.23ms +step:1247/1670 train_time:122492ms step_avg:98.23ms +step:1248/1670 train_time:122590ms step_avg:98.23ms +step:1249/1670 train_time:122688ms step_avg:98.23ms +step:1250/1670 train_time:122785ms step_avg:98.23ms +step:1250/1670 val_loss:3.3809 train_time:122882ms step_avg:98.31ms +step:1251/1670 train_time:122905ms step_avg:98.25ms +step:1252/1670 train_time:122989ms step_avg:98.23ms +step:1253/1670 train_time:123088ms step_avg:98.23ms +step:1254/1670 train_time:123185ms step_avg:98.23ms +step:1255/1670 train_time:123283ms step_avg:98.23ms +step:1256/1670 train_time:123380ms step_avg:98.23ms +step:1257/1670 train_time:123477ms step_avg:98.23ms +step:1258/1670 train_time:123574ms step_avg:98.23ms +step:1259/1670 train_time:123671ms step_avg:98.23ms +step:1260/1670 train_time:123767ms step_avg:98.23ms +step:1261/1670 train_time:123866ms step_avg:98.23ms +step:1262/1670 train_time:123967ms step_avg:98.23ms +step:1263/1670 train_time:124066ms step_avg:98.23ms +step:1264/1670 train_time:124165ms step_avg:98.23ms +step:1265/1670 train_time:124263ms step_avg:98.23ms +step:1266/1670 train_time:124361ms step_avg:98.23ms +step:1267/1670 train_time:124459ms step_avg:98.23ms +step:1268/1670 train_time:124556ms step_avg:98.23ms +step:1269/1670 train_time:124653ms step_avg:98.23ms +step:1270/1670 train_time:124750ms step_avg:98.23ms +step:1271/1670 train_time:124848ms step_avg:98.23ms +step:1272/1670 train_time:124946ms step_avg:98.23ms +step:1273/1670 train_time:125045ms step_avg:98.23ms +step:1274/1670 train_time:125323ms step_avg:98.37ms +step:1275/1670 train_time:125511ms step_avg:98.44ms +step:1276/1670 train_time:125607ms step_avg:98.44ms +step:1277/1670 train_time:125704ms step_avg:98.44ms +step:1278/1670 train_time:125801ms step_avg:98.44ms +step:1279/1670 train_time:125898ms step_avg:98.43ms +step:1280/1670 train_time:125995ms step_avg:98.43ms +step:1281/1670 train_time:126091ms step_avg:98.43ms +step:1282/1670 train_time:126188ms step_avg:98.43ms +step:1283/1670 train_time:126284ms step_avg:98.43ms +step:1284/1670 train_time:126384ms step_avg:98.43ms +step:1285/1670 train_time:126488ms step_avg:98.43ms +step:1286/1670 train_time:126587ms step_avg:98.43ms +step:1287/1670 train_time:126685ms step_avg:98.43ms +step:1288/1670 train_time:126782ms step_avg:98.43ms +step:1289/1670 train_time:126880ms step_avg:98.43ms +step:1290/1670 train_time:126977ms step_avg:98.43ms +step:1291/1670 train_time:127075ms step_avg:98.43ms +step:1292/1670 train_time:127173ms step_avg:98.43ms +step:1293/1670 train_time:127269ms step_avg:98.43ms +step:1294/1670 train_time:127367ms step_avg:98.43ms +step:1295/1670 train_time:127466ms step_avg:98.43ms +step:1296/1670 train_time:127566ms step_avg:98.43ms +step:1297/1670 train_time:127664ms step_avg:98.43ms +step:1298/1670 train_time:127764ms step_avg:98.43ms +step:1299/1670 train_time:127862ms step_avg:98.43ms +step:1300/1670 train_time:127959ms step_avg:98.43ms +step:1301/1670 train_time:128058ms step_avg:98.43ms +step:1302/1670 train_time:128155ms step_avg:98.43ms +step:1303/1670 train_time:128252ms step_avg:98.43ms +step:1304/1670 train_time:128351ms step_avg:98.43ms +step:1305/1670 train_time:128450ms step_avg:98.43ms +step:1306/1670 train_time:128549ms step_avg:98.43ms +step:1307/1670 train_time:128647ms step_avg:98.43ms +step:1308/1670 train_time:128745ms step_avg:98.43ms +step:1309/1670 
train_time:128843ms step_avg:98.43ms +step:1310/1670 train_time:128940ms step_avg:98.43ms +step:1311/1670 train_time:129039ms step_avg:98.43ms +step:1312/1670 train_time:129137ms step_avg:98.43ms +step:1313/1670 train_time:129234ms step_avg:98.43ms +step:1314/1670 train_time:129332ms step_avg:98.43ms +step:1315/1670 train_time:129430ms step_avg:98.43ms +step:1316/1670 train_time:129528ms step_avg:98.43ms +step:1317/1670 train_time:129626ms step_avg:98.43ms +step:1318/1670 train_time:129725ms step_avg:98.43ms +step:1319/1670 train_time:129822ms step_avg:98.42ms +step:1320/1670 train_time:129921ms step_avg:98.43ms +step:1321/1670 train_time:130019ms step_avg:98.42ms +step:1322/1670 train_time:130116ms step_avg:98.42ms +step:1323/1670 train_time:130214ms step_avg:98.42ms +step:1324/1670 train_time:130313ms step_avg:98.42ms +step:1325/1670 train_time:130411ms step_avg:98.42ms +step:1326/1670 train_time:130509ms step_avg:98.42ms +step:1327/1670 train_time:130607ms step_avg:98.42ms +step:1328/1670 train_time:130705ms step_avg:98.42ms +step:1329/1670 train_time:130803ms step_avg:98.42ms +step:1330/1670 train_time:130901ms step_avg:98.42ms +step:1331/1670 train_time:130999ms step_avg:98.42ms +step:1332/1670 train_time:131098ms step_avg:98.42ms +step:1333/1670 train_time:131196ms step_avg:98.42ms +step:1334/1670 train_time:131294ms step_avg:98.42ms +step:1335/1670 train_time:131393ms step_avg:98.42ms +step:1336/1670 train_time:131494ms step_avg:98.42ms +step:1337/1670 train_time:131593ms step_avg:98.42ms +step:1338/1670 train_time:131692ms step_avg:98.42ms +step:1339/1670 train_time:131792ms step_avg:98.43ms +step:1340/1670 train_time:131891ms step_avg:98.43ms +step:1341/1670 train_time:131990ms step_avg:98.43ms +step:1342/1670 train_time:132089ms step_avg:98.43ms +step:1343/1670 train_time:132186ms step_avg:98.43ms +step:1344/1670 train_time:132284ms step_avg:98.43ms +step:1345/1670 train_time:132382ms step_avg:98.43ms +step:1346/1670 train_time:132482ms step_avg:98.43ms +step:1347/1670 train_time:132581ms step_avg:98.43ms +step:1348/1670 train_time:132679ms step_avg:98.43ms +step:1349/1670 train_time:132779ms step_avg:98.43ms +step:1350/1670 train_time:132880ms step_avg:98.43ms +step:1351/1670 train_time:132981ms step_avg:98.43ms +step:1352/1670 train_time:133080ms step_avg:98.43ms +step:1353/1670 train_time:133179ms step_avg:98.43ms +step:1354/1670 train_time:133277ms step_avg:98.43ms +step:1355/1670 train_time:133375ms step_avg:98.43ms +step:1356/1670 train_time:133473ms step_avg:98.43ms +step:1357/1670 train_time:133572ms step_avg:98.43ms +step:1358/1670 train_time:133669ms step_avg:98.43ms +step:1359/1670 train_time:133766ms step_avg:98.43ms +step:1360/1670 train_time:133864ms step_avg:98.43ms +step:1361/1670 train_time:133962ms step_avg:98.43ms +step:1362/1670 train_time:134060ms step_avg:98.43ms +step:1363/1670 train_time:134159ms step_avg:98.43ms +step:1364/1670 train_time:134257ms step_avg:98.43ms +step:1365/1670 train_time:134354ms step_avg:98.43ms +step:1366/1670 train_time:134452ms step_avg:98.43ms +step:1367/1670 train_time:134550ms step_avg:98.43ms +step:1368/1670 train_time:134648ms step_avg:98.43ms +step:1369/1670 train_time:134747ms step_avg:98.43ms +step:1370/1670 train_time:134845ms step_avg:98.43ms +step:1371/1670 train_time:134943ms step_avg:98.43ms +step:1372/1670 train_time:135042ms step_avg:98.43ms +step:1373/1670 train_time:135140ms step_avg:98.43ms +step:1374/1670 train_time:135238ms step_avg:98.43ms +step:1375/1670 train_time:135337ms step_avg:98.43ms +step:1375/1670 
val_loss:3.3427 train_time:135436ms step_avg:98.50ms +step:1376/1670 train_time:135458ms step_avg:98.44ms +step:1377/1670 train_time:135542ms step_avg:98.43ms +step:1378/1670 train_time:135644ms step_avg:98.44ms +step:1379/1670 train_time:135744ms step_avg:98.44ms +step:1380/1670 train_time:135842ms step_avg:98.44ms +step:1381/1670 train_time:135939ms step_avg:98.44ms +step:1382/1670 train_time:136035ms step_avg:98.43ms +step:1383/1670 train_time:136132ms step_avg:98.43ms +step:1384/1670 train_time:136229ms step_avg:98.43ms +step:1385/1670 train_time:136327ms step_avg:98.43ms +step:1386/1670 train_time:136426ms step_avg:98.43ms +step:1387/1670 train_time:136527ms step_avg:98.43ms +step:1388/1670 train_time:136629ms step_avg:98.44ms +step:1389/1670 train_time:136729ms step_avg:98.44ms +step:1390/1670 train_time:136827ms step_avg:98.44ms +step:1391/1670 train_time:136926ms step_avg:98.44ms +step:1392/1670 train_time:137023ms step_avg:98.44ms +step:1393/1670 train_time:137120ms step_avg:98.44ms +step:1394/1670 train_time:137217ms step_avg:98.43ms +step:1395/1670 train_time:137314ms step_avg:98.43ms +step:1396/1670 train_time:137411ms step_avg:98.43ms +step:1397/1670 train_time:137510ms step_avg:98.43ms +step:1398/1670 train_time:137610ms step_avg:98.43ms +step:1399/1670 train_time:137709ms step_avg:98.43ms +step:1400/1670 train_time:137810ms step_avg:98.44ms +step:1401/1670 train_time:137908ms step_avg:98.44ms +step:1402/1670 train_time:138006ms step_avg:98.44ms +step:1403/1670 train_time:138104ms step_avg:98.43ms +step:1404/1670 train_time:138202ms step_avg:98.43ms +step:1405/1670 train_time:138301ms step_avg:98.44ms +step:1406/1670 train_time:138400ms step_avg:98.44ms +step:1407/1670 train_time:138498ms step_avg:98.44ms +step:1408/1670 train_time:138597ms step_avg:98.44ms +step:1409/1670 train_time:138694ms step_avg:98.43ms +step:1410/1670 train_time:138793ms step_avg:98.43ms +step:1411/1670 train_time:138891ms step_avg:98.43ms +step:1412/1670 train_time:138989ms step_avg:98.43ms +step:1413/1670 train_time:139087ms step_avg:98.43ms +step:1414/1670 train_time:139185ms step_avg:98.43ms +step:1415/1670 train_time:139283ms step_avg:98.43ms +step:1416/1670 train_time:139382ms step_avg:98.43ms +step:1417/1670 train_time:139482ms step_avg:98.43ms +step:1418/1670 train_time:139580ms step_avg:98.43ms +step:1419/1670 train_time:139678ms step_avg:98.43ms +step:1420/1670 train_time:139776ms step_avg:98.43ms +step:1421/1670 train_time:139873ms step_avg:98.43ms +step:1422/1670 train_time:139971ms step_avg:98.43ms +step:1423/1670 train_time:140069ms step_avg:98.43ms +step:1424/1670 train_time:140167ms step_avg:98.43ms +step:1425/1670 train_time:140265ms step_avg:98.43ms +step:1426/1670 train_time:140364ms step_avg:98.43ms +step:1427/1670 train_time:140465ms step_avg:98.43ms +step:1428/1670 train_time:140565ms step_avg:98.43ms +step:1429/1670 train_time:140665ms step_avg:98.44ms +step:1430/1670 train_time:140764ms step_avg:98.44ms +step:1431/1670 train_time:140864ms step_avg:98.44ms +step:1432/1670 train_time:140964ms step_avg:98.44ms +step:1433/1670 train_time:141063ms step_avg:98.44ms +step:1434/1670 train_time:141160ms step_avg:98.44ms +step:1435/1670 train_time:141258ms step_avg:98.44ms +step:1436/1670 train_time:141355ms step_avg:98.44ms +step:1437/1670 train_time:141453ms step_avg:98.44ms +step:1438/1670 train_time:141552ms step_avg:98.44ms +step:1439/1670 train_time:141653ms step_avg:98.44ms +step:1440/1670 train_time:141750ms step_avg:98.44ms +step:1441/1670 train_time:141849ms step_avg:98.44ms 
+step:1442/1670 train_time:141948ms step_avg:98.44ms +step:1443/1670 train_time:142047ms step_avg:98.44ms +step:1444/1670 train_time:142146ms step_avg:98.44ms +step:1445/1670 train_time:142245ms step_avg:98.44ms +step:1446/1670 train_time:142343ms step_avg:98.44ms +step:1447/1670 train_time:142442ms step_avg:98.44ms +step:1448/1670 train_time:142541ms step_avg:98.44ms +step:1449/1670 train_time:142640ms step_avg:98.44ms +step:1450/1670 train_time:142739ms step_avg:98.44ms +step:1451/1670 train_time:142837ms step_avg:98.44ms +step:1452/1670 train_time:142934ms step_avg:98.44ms +step:1453/1670 train_time:143031ms step_avg:98.44ms +step:1454/1670 train_time:143129ms step_avg:98.44ms +step:1455/1670 train_time:143228ms step_avg:98.44ms +step:1456/1670 train_time:143327ms step_avg:98.44ms +step:1457/1670 train_time:143426ms step_avg:98.44ms +step:1458/1670 train_time:143524ms step_avg:98.44ms +step:1459/1670 train_time:143623ms step_avg:98.44ms +step:1460/1670 train_time:143722ms step_avg:98.44ms +step:1461/1670 train_time:143821ms step_avg:98.44ms +step:1462/1670 train_time:143920ms step_avg:98.44ms +step:1463/1670 train_time:144018ms step_avg:98.44ms +step:1464/1670 train_time:144115ms step_avg:98.44ms +step:1465/1670 train_time:144213ms step_avg:98.44ms +step:1466/1670 train_time:144310ms step_avg:98.44ms +step:1467/1670 train_time:144409ms step_avg:98.44ms +step:1468/1670 train_time:144508ms step_avg:98.44ms +step:1469/1670 train_time:144608ms step_avg:98.44ms +step:1470/1670 train_time:144707ms step_avg:98.44ms +step:1471/1670 train_time:144807ms step_avg:98.44ms +step:1472/1670 train_time:144907ms step_avg:98.44ms +step:1473/1670 train_time:145007ms step_avg:98.44ms +step:1474/1670 train_time:145106ms step_avg:98.44ms +step:1475/1670 train_time:145204ms step_avg:98.44ms +step:1476/1670 train_time:145302ms step_avg:98.44ms +step:1477/1670 train_time:145399ms step_avg:98.44ms +step:1478/1670 train_time:145497ms step_avg:98.44ms +step:1479/1670 train_time:145595ms step_avg:98.44ms +step:1480/1670 train_time:145693ms step_avg:98.44ms +step:1481/1670 train_time:145792ms step_avg:98.44ms +step:1482/1670 train_time:145892ms step_avg:98.44ms +step:1483/1670 train_time:145992ms step_avg:98.44ms +step:1484/1670 train_time:146092ms step_avg:98.44ms +step:1485/1670 train_time:146373ms step_avg:98.57ms +step:1486/1670 train_time:146575ms step_avg:98.64ms +step:1487/1670 train_time:146670ms step_avg:98.63ms +step:1488/1670 train_time:146766ms step_avg:98.63ms +step:1489/1670 train_time:146863ms step_avg:98.63ms +step:1490/1670 train_time:146960ms step_avg:98.63ms +step:1491/1670 train_time:147056ms step_avg:98.63ms +step:1492/1670 train_time:147153ms step_avg:98.63ms +step:1493/1670 train_time:147250ms step_avg:98.63ms +step:1494/1670 train_time:147348ms step_avg:98.63ms +step:1495/1670 train_time:147452ms step_avg:98.63ms +step:1496/1670 train_time:147554ms step_avg:98.63ms +step:1497/1670 train_time:147654ms step_avg:98.63ms +step:1498/1670 train_time:147753ms step_avg:98.63ms +step:1499/1670 train_time:147851ms step_avg:98.63ms +step:1500/1670 train_time:147949ms step_avg:98.63ms +step:1500/1670 val_loss:3.3106 train_time:148046ms step_avg:98.70ms +step:1501/1670 train_time:148068ms step_avg:98.65ms +step:1502/1670 train_time:148151ms step_avg:98.64ms +step:1503/1670 train_time:148256ms step_avg:98.64ms +step:1504/1670 train_time:148354ms step_avg:98.64ms +step:1505/1670 train_time:148451ms step_avg:98.64ms +step:1506/1670 train_time:148549ms step_avg:98.64ms +step:1507/1670 train_time:148646ms 
step_avg:98.64ms +step:1508/1670 train_time:148743ms step_avg:98.64ms +step:1509/1670 train_time:148839ms step_avg:98.63ms +step:1510/1670 train_time:148937ms step_avg:98.63ms +step:1511/1670 train_time:149036ms step_avg:98.63ms +step:1512/1670 train_time:149138ms step_avg:98.64ms +step:1513/1670 train_time:149237ms step_avg:98.64ms +step:1514/1670 train_time:149336ms step_avg:98.64ms +step:1515/1670 train_time:149433ms step_avg:98.64ms +step:1516/1670 train_time:149530ms step_avg:98.63ms +step:1517/1670 train_time:149628ms step_avg:98.63ms +step:1518/1670 train_time:149725ms step_avg:98.63ms +step:1519/1670 train_time:149823ms step_avg:98.63ms +step:1520/1670 train_time:149920ms step_avg:98.63ms +step:1521/1670 train_time:150019ms step_avg:98.63ms +step:1522/1670 train_time:150117ms step_avg:98.63ms +step:1523/1670 train_time:150217ms step_avg:98.63ms +step:1524/1670 train_time:150315ms step_avg:98.63ms +step:1525/1670 train_time:150413ms step_avg:98.63ms +step:1526/1670 train_time:150510ms step_avg:98.63ms +step:1527/1670 train_time:150608ms step_avg:98.63ms +step:1528/1670 train_time:150706ms step_avg:98.63ms +step:1529/1670 train_time:150804ms step_avg:98.63ms +step:1530/1670 train_time:150902ms step_avg:98.63ms +step:1531/1670 train_time:151000ms step_avg:98.63ms +step:1532/1670 train_time:151099ms step_avg:98.63ms +step:1533/1670 train_time:151198ms step_avg:98.63ms +step:1534/1670 train_time:151296ms step_avg:98.63ms +step:1535/1670 train_time:151394ms step_avg:98.63ms +step:1536/1670 train_time:151492ms step_avg:98.63ms +step:1537/1670 train_time:151590ms step_avg:98.63ms +step:1538/1670 train_time:151688ms step_avg:98.63ms +step:1539/1670 train_time:151786ms step_avg:98.63ms +step:1540/1670 train_time:151884ms step_avg:98.63ms +step:1541/1670 train_time:151982ms step_avg:98.63ms +step:1542/1670 train_time:152081ms step_avg:98.63ms +step:1543/1670 train_time:152181ms step_avg:98.63ms +step:1544/1670 train_time:152280ms step_avg:98.63ms +step:1545/1670 train_time:152379ms step_avg:98.63ms +step:1546/1670 train_time:152476ms step_avg:98.63ms +step:1547/1670 train_time:152574ms step_avg:98.63ms +step:1548/1670 train_time:152671ms step_avg:98.62ms +step:1549/1670 train_time:152769ms step_avg:98.62ms +step:1550/1670 train_time:152867ms step_avg:98.62ms +step:1551/1670 train_time:152966ms step_avg:98.62ms +step:1552/1670 train_time:153066ms step_avg:98.63ms +step:1553/1670 train_time:153166ms step_avg:98.63ms +step:1554/1670 train_time:153266ms step_avg:98.63ms +step:1555/1670 train_time:153366ms step_avg:98.63ms +step:1556/1670 train_time:153466ms step_avg:98.63ms +step:1557/1670 train_time:153565ms step_avg:98.63ms +step:1558/1670 train_time:153663ms step_avg:98.63ms +step:1559/1670 train_time:153761ms step_avg:98.63ms +step:1560/1670 train_time:153858ms step_avg:98.63ms +step:1561/1670 train_time:153955ms step_avg:98.63ms +step:1562/1670 train_time:154054ms step_avg:98.63ms +step:1563/1670 train_time:154153ms step_avg:98.63ms +step:1564/1670 train_time:154252ms step_avg:98.63ms +step:1565/1670 train_time:154351ms step_avg:98.63ms +step:1566/1670 train_time:154450ms step_avg:98.63ms +step:1567/1670 train_time:154549ms step_avg:98.63ms +step:1568/1670 train_time:154647ms step_avg:98.63ms +step:1569/1670 train_time:154746ms step_avg:98.63ms +step:1570/1670 train_time:154845ms step_avg:98.63ms +step:1571/1670 train_time:154944ms step_avg:98.63ms +step:1572/1670 train_time:155044ms step_avg:98.63ms +step:1573/1670 train_time:155141ms step_avg:98.63ms +step:1574/1670 train_time:155240ms 
step_avg:98.63ms +step:1575/1670 train_time:155338ms step_avg:98.63ms +step:1576/1670 train_time:155435ms step_avg:98.63ms +step:1577/1670 train_time:155533ms step_avg:98.63ms +step:1578/1670 train_time:155631ms step_avg:98.63ms +step:1579/1670 train_time:155731ms step_avg:98.63ms +step:1580/1670 train_time:155830ms step_avg:98.63ms +step:1581/1670 train_time:155929ms step_avg:98.63ms +step:1582/1670 train_time:156027ms step_avg:98.63ms +step:1583/1670 train_time:156127ms step_avg:98.63ms +step:1584/1670 train_time:156227ms step_avg:98.63ms +step:1585/1670 train_time:156325ms step_avg:98.63ms +step:1586/1670 train_time:156425ms step_avg:98.63ms +step:1587/1670 train_time:156523ms step_avg:98.63ms +step:1588/1670 train_time:156622ms step_avg:98.63ms +step:1589/1670 train_time:156719ms step_avg:98.63ms +step:1590/1670 train_time:156816ms step_avg:98.63ms +step:1591/1670 train_time:156914ms step_avg:98.63ms +step:1592/1670 train_time:157013ms step_avg:98.63ms +step:1593/1670 train_time:157112ms step_avg:98.63ms +step:1594/1670 train_time:157211ms step_avg:98.63ms +step:1595/1670 train_time:157310ms step_avg:98.63ms +step:1596/1670 train_time:157410ms step_avg:98.63ms +step:1597/1670 train_time:157508ms step_avg:98.63ms +step:1598/1670 train_time:157608ms step_avg:98.63ms +step:1599/1670 train_time:157708ms step_avg:98.63ms +step:1600/1670 train_time:157807ms step_avg:98.63ms +step:1601/1670 train_time:157905ms step_avg:98.63ms +step:1602/1670 train_time:158003ms step_avg:98.63ms +step:1603/1670 train_time:158101ms step_avg:98.63ms +step:1604/1670 train_time:158199ms step_avg:98.63ms +step:1605/1670 train_time:158296ms step_avg:98.63ms +step:1606/1670 train_time:158395ms step_avg:98.63ms +step:1607/1670 train_time:158494ms step_avg:98.63ms +step:1608/1670 train_time:158593ms step_avg:98.63ms +step:1609/1670 train_time:158691ms step_avg:98.63ms +step:1610/1670 train_time:158790ms step_avg:98.63ms +step:1611/1670 train_time:158890ms step_avg:98.63ms +step:1612/1670 train_time:158989ms step_avg:98.63ms +step:1613/1670 train_time:159088ms step_avg:98.63ms +step:1614/1670 train_time:159186ms step_avg:98.63ms +step:1615/1670 train_time:159285ms step_avg:98.63ms +step:1616/1670 train_time:159384ms step_avg:98.63ms +step:1617/1670 train_time:159484ms step_avg:98.63ms +step:1618/1670 train_time:159583ms step_avg:98.63ms +step:1619/1670 train_time:159681ms step_avg:98.63ms +step:1620/1670 train_time:159779ms step_avg:98.63ms +step:1621/1670 train_time:159877ms step_avg:98.63ms +step:1622/1670 train_time:159974ms step_avg:98.63ms +step:1623/1670 train_time:160074ms step_avg:98.63ms +step:1624/1670 train_time:160173ms step_avg:98.63ms +step:1625/1670 train_time:160273ms step_avg:98.63ms +step:1625/1670 val_loss:3.2843 train_time:160371ms step_avg:98.69ms +step:1626/1670 train_time:160393ms step_avg:98.64ms +step:1627/1670 train_time:160478ms step_avg:98.63ms +step:1628/1670 train_time:160582ms step_avg:98.64ms +step:1629/1670 train_time:160679ms step_avg:98.64ms +step:1630/1670 train_time:160778ms step_avg:98.64ms +step:1631/1670 train_time:160875ms step_avg:98.64ms +step:1632/1670 train_time:160972ms step_avg:98.64ms +step:1633/1670 train_time:161069ms step_avg:98.63ms +step:1634/1670 train_time:161166ms step_avg:98.63ms +step:1635/1670 train_time:161263ms step_avg:98.63ms +step:1636/1670 train_time:161361ms step_avg:98.63ms +step:1637/1670 train_time:161461ms step_avg:98.63ms +step:1638/1670 train_time:161562ms step_avg:98.63ms +step:1639/1670 train_time:161661ms step_avg:98.63ms +step:1640/1670 
train_time:161760ms step_avg:98.63ms +step:1641/1670 train_time:161858ms step_avg:98.63ms +step:1642/1670 train_time:161956ms step_avg:98.63ms +step:1643/1670 train_time:162054ms step_avg:98.63ms +step:1644/1670 train_time:162152ms step_avg:98.63ms +step:1645/1670 train_time:162250ms step_avg:98.63ms +step:1646/1670 train_time:162348ms step_avg:98.63ms +step:1647/1670 train_time:162447ms step_avg:98.63ms +step:1648/1670 train_time:162548ms step_avg:98.63ms +step:1649/1670 train_time:162648ms step_avg:98.63ms +step:1650/1670 train_time:162747ms step_avg:98.63ms +step:1651/1670 train_time:162846ms step_avg:98.63ms +step:1652/1670 train_time:162944ms step_avg:98.63ms +step:1653/1670 train_time:163040ms step_avg:98.63ms +step:1654/1670 train_time:163137ms step_avg:98.63ms +step:1655/1670 train_time:163235ms step_avg:98.63ms +step:1656/1670 train_time:163333ms step_avg:98.63ms +step:1657/1670 train_time:163433ms step_avg:98.63ms +step:1658/1670 train_time:163533ms step_avg:98.63ms +step:1659/1670 train_time:163634ms step_avg:98.63ms +step:1660/1670 train_time:163734ms step_avg:98.64ms +step:1661/1670 train_time:163835ms step_avg:98.64ms +step:1662/1670 train_time:163933ms step_avg:98.64ms +step:1663/1670 train_time:164032ms step_avg:98.64ms +step:1664/1670 train_time:164131ms step_avg:98.64ms +step:1665/1670 train_time:164228ms step_avg:98.64ms +step:1666/1670 train_time:164325ms step_avg:98.63ms +step:1667/1670 train_time:164423ms step_avg:98.63ms +step:1668/1670 train_time:164520ms step_avg:98.63ms +step:1669/1670 train_time:164619ms step_avg:98.63ms +step:1670/1670 train_time:164718ms step_avg:98.63ms +step:1670/1670 val_loss:3.2764 train_time:164817ms step_avg:98.69ms +peak memory allocated: 34000 MiB reserved: 49056 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt b/records/050925_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt new file mode 100644 index 000000000..784030ba9 --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, 
dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + 
c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels 
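+    # Where this fits (a sketch restating the call sites in newton_schulz_triton below):
+    # ns_line_1 first forms A = X @ X.mT, then this kernel runs as ns_line_2(A, alpha=c, beta=b)
+    # to produce B = b * A + c * (A @ A.mT); A is symmetric, so A @ A.mT == A @ A.
+    # The final addmm/baddbmm call (a * X + B @ X) then completes one quintic Newton-Schulz step:
+    #     X <- a * X + b * (X @ X.mT) @ X + c * (X @ X.mT)^2 @ X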
+ pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
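+        # Sharding sketch (describing the two passes below): params within a group share
+        # one shape, and rank r owns params[base_i + r] in each chunk of world_size params.
+        # The first pass reduce-scatters the gradients (AVG), leaving each rank with the
+        # averaged gradient for only the params it owns; the second pass applies momentum
+        # and Newton-Schulz locally, then all-gathers the updated params back to every rank.
+        # The zero/empty padding tensors exist only to keep each collective's tensor list
+        # at exactly world_size entries.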
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
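+                    # Descriptive note: both Adam moment buffers are allocated only for this
+                    # rank's slice of the parameter's leading dim, so optimizer state is
+                    # sharded across ranks rather than replicated, cutting its memory by
+                    # roughly a factor of world_size.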
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. This is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable, then linear decay to 0.1x over the final cooldown_frac of training
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# attention window-size schedule: step through ws_schedule in equal phases of training
+def get_ws(step: int):
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
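+
+# A quick sketch of the two schedules above (illustrative; never called in this
+# run): with the defaults in this file, get_lr holds 1.0 for the first ~55% of
+# training and then anneals linearly toward 0.1, while get_ws walks the window
+# size through ws_schedule = (3, 7, 11) in three equal phases.
+def _sketch_schedules():
+    lrs = [get_lr(s) for s in range(args.num_iterations)]
+    assert lrs[0] == 1.0 and min(lrs) == lrs[-1]  # flat, then monotone cooldown
+    assert get_ws(0) == 3 and get_ws(args.num_iterations - 1) == 11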
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Fri Sep  5 16:22:32 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 128W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 87344 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 87345 C /usr/bin/python3 610MiB | +| 0 N/A N/A 87346 C /usr/bin/python3 610MiB | +| 0 N/A N/A 87347 C /usr/bin/python3 610MiB | +| 0 N/A N/A 87348 C /usr/bin/python3 610MiB | +| 0 N/A N/A 87349 C /usr/bin/python3 610MiB | +| 0 N/A N/A 87350 C /usr/bin/python3 610MiB | +| 0 N/A N/A 87351 C /usr/bin/python3 610MiB | +| 1 N/A N/A 87345 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 87346 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 87347 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 87348 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 87349 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 87350 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 87351 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:354ms step_avg:354.08ms +step:2/1670 train_time:375ms step_avg:187.34ms +step:3/1670 train_time:448ms step_avg:149.22ms +step:4/1670 train_time:541ms step_avg:135.35ms +step:5/1670 train_time:636ms step_avg:127.26ms +step:6/1670 train_time:731ms step_avg:121.81ms +step:7/1670 train_time:826ms step_avg:117.95ms +step:8/1670 
train_time:921ms step_avg:115.12ms +step:9/1670 train_time:1016ms step_avg:112.91ms +step:10/1670 train_time:1112ms step_avg:111.19ms +step:11/1670 train_time:1207ms step_avg:109.77ms +step:12/1670 train_time:1305ms step_avg:108.78ms +step:13/1670 train_time:1404ms step_avg:107.98ms +step:14/1670 train_time:1499ms step_avg:107.10ms +step:15/1670 train_time:1595ms step_avg:106.36ms +step:16/1670 train_time:1691ms step_avg:105.69ms +step:17/1670 train_time:1787ms step_avg:105.14ms +step:18/1670 train_time:1883ms step_avg:104.61ms +step:19/1670 train_time:1979ms step_avg:104.15ms +step:20/1670 train_time:2074ms step_avg:103.69ms +step:21/1670 train_time:2169ms step_avg:103.30ms +step:22/1670 train_time:2266ms step_avg:102.99ms +step:23/1670 train_time:2363ms step_avg:102.74ms +step:24/1670 train_time:2459ms step_avg:102.44ms +step:25/1670 train_time:2555ms step_avg:102.19ms +step:26/1670 train_time:2651ms step_avg:101.96ms +step:27/1670 train_time:2747ms step_avg:101.74ms +step:28/1670 train_time:2843ms step_avg:101.55ms +step:29/1670 train_time:2939ms step_avg:101.33ms +step:30/1670 train_time:3034ms step_avg:101.12ms +step:31/1670 train_time:3130ms step_avg:100.97ms +step:32/1670 train_time:3226ms step_avg:100.81ms +step:33/1670 train_time:3322ms step_avg:100.67ms +step:34/1670 train_time:3418ms step_avg:100.54ms +step:35/1670 train_time:3514ms step_avg:100.41ms +step:36/1670 train_time:3611ms step_avg:100.30ms +step:37/1670 train_time:3707ms step_avg:100.20ms +step:38/1670 train_time:3803ms step_avg:100.09ms +step:39/1670 train_time:3899ms step_avg:99.98ms +step:40/1670 train_time:3995ms step_avg:99.86ms +step:41/1670 train_time:4090ms step_avg:99.75ms +step:42/1670 train_time:4185ms step_avg:99.65ms +step:43/1670 train_time:4282ms step_avg:99.58ms +step:44/1670 train_time:4378ms step_avg:99.50ms +step:45/1670 train_time:4474ms step_avg:99.42ms +step:46/1670 train_time:4572ms step_avg:99.39ms +step:47/1670 train_time:4668ms step_avg:99.32ms +step:48/1670 train_time:4765ms step_avg:99.28ms +step:49/1670 train_time:4861ms step_avg:99.21ms +step:50/1670 train_time:4957ms step_avg:99.14ms +step:51/1670 train_time:5053ms step_avg:99.08ms +step:52/1670 train_time:5149ms step_avg:99.02ms +step:53/1670 train_time:5245ms step_avg:98.97ms +step:54/1670 train_time:5341ms step_avg:98.91ms +step:55/1670 train_time:5437ms step_avg:98.86ms +step:56/1670 train_time:5534ms step_avg:98.82ms +step:57/1670 train_time:5630ms step_avg:98.77ms +step:58/1670 train_time:5726ms step_avg:98.73ms +step:59/1670 train_time:5823ms step_avg:98.69ms +step:60/1670 train_time:5918ms step_avg:98.64ms +step:61/1670 train_time:6014ms step_avg:98.60ms +step:62/1670 train_time:6110ms step_avg:98.55ms +step:63/1670 train_time:6207ms step_avg:98.52ms +step:64/1670 train_time:6304ms step_avg:98.50ms +step:65/1670 train_time:6399ms step_avg:98.45ms +step:66/1670 train_time:6495ms step_avg:98.41ms +step:67/1670 train_time:6592ms step_avg:98.38ms +step:68/1670 train_time:6687ms step_avg:98.34ms +step:69/1670 train_time:6784ms step_avg:98.32ms +step:70/1670 train_time:6880ms step_avg:98.29ms +step:71/1670 train_time:6976ms step_avg:98.25ms +step:72/1670 train_time:7072ms step_avg:98.22ms +step:73/1670 train_time:7167ms step_avg:98.18ms +step:74/1670 train_time:7264ms step_avg:98.16ms +step:75/1670 train_time:7360ms step_avg:98.14ms +step:76/1670 train_time:7456ms step_avg:98.11ms +step:77/1670 train_time:7552ms step_avg:98.08ms +step:78/1670 train_time:7648ms step_avg:98.05ms +step:79/1670 train_time:7745ms step_avg:98.03ms 
+step:80/1670 train_time:7840ms step_avg:98.00ms +step:81/1670 train_time:7935ms step_avg:97.97ms +step:82/1670 train_time:8032ms step_avg:97.95ms +step:83/1670 train_time:8128ms step_avg:97.93ms +step:84/1670 train_time:8224ms step_avg:97.91ms +step:85/1670 train_time:8321ms step_avg:97.90ms +step:86/1670 train_time:8416ms step_avg:97.87ms +step:87/1670 train_time:8513ms step_avg:97.85ms +step:88/1670 train_time:8610ms step_avg:97.84ms +step:89/1670 train_time:8705ms step_avg:97.81ms +step:90/1670 train_time:8801ms step_avg:97.79ms +step:91/1670 train_time:8897ms step_avg:97.77ms +step:92/1670 train_time:8993ms step_avg:97.75ms +step:93/1670 train_time:9088ms step_avg:97.73ms +step:94/1670 train_time:9184ms step_avg:97.71ms +step:95/1670 train_time:9280ms step_avg:97.69ms +step:96/1670 train_time:9376ms step_avg:97.66ms +step:97/1670 train_time:9472ms step_avg:97.65ms +step:98/1670 train_time:9568ms step_avg:97.63ms +step:99/1670 train_time:9664ms step_avg:97.62ms +step:100/1670 train_time:9760ms step_avg:97.60ms +step:101/1670 train_time:9857ms step_avg:97.59ms +step:102/1670 train_time:9952ms step_avg:97.57ms +step:103/1670 train_time:10048ms step_avg:97.55ms +step:104/1670 train_time:10143ms step_avg:97.53ms +step:105/1670 train_time:10238ms step_avg:97.51ms +step:106/1670 train_time:10334ms step_avg:97.49ms +step:107/1670 train_time:10431ms step_avg:97.49ms +step:108/1670 train_time:10527ms step_avg:97.47ms +step:109/1670 train_time:10624ms step_avg:97.47ms +step:110/1670 train_time:10719ms step_avg:97.45ms +step:111/1670 train_time:10815ms step_avg:97.43ms +step:112/1670 train_time:10912ms step_avg:97.42ms +step:113/1670 train_time:11008ms step_avg:97.41ms +step:114/1670 train_time:11105ms step_avg:97.41ms +step:115/1670 train_time:11201ms step_avg:97.40ms +step:116/1670 train_time:11296ms step_avg:97.38ms +step:117/1670 train_time:11393ms step_avg:97.37ms +step:118/1670 train_time:11489ms step_avg:97.37ms +step:119/1670 train_time:11585ms step_avg:97.35ms +step:120/1670 train_time:11681ms step_avg:97.34ms +step:121/1670 train_time:11776ms step_avg:97.32ms +step:122/1670 train_time:11872ms step_avg:97.31ms +step:123/1670 train_time:11968ms step_avg:97.30ms +step:124/1670 train_time:12064ms step_avg:97.29ms +step:125/1670 train_time:12160ms step_avg:97.28ms +step:125/1670 val_loss:4.3131 train_time:12254ms step_avg:98.03ms +step:126/1670 train_time:12277ms step_avg:97.44ms +step:127/1670 train_time:12363ms step_avg:97.34ms +step:128/1670 train_time:12465ms step_avg:97.38ms +step:129/1670 train_time:12561ms step_avg:97.37ms +step:130/1670 train_time:12656ms step_avg:97.36ms +step:131/1670 train_time:12751ms step_avg:97.33ms +step:132/1670 train_time:12845ms step_avg:97.31ms +step:133/1670 train_time:12940ms step_avg:97.29ms +step:134/1670 train_time:13035ms step_avg:97.28ms +step:135/1670 train_time:13130ms step_avg:97.26ms +step:136/1670 train_time:13225ms step_avg:97.24ms +step:137/1670 train_time:13321ms step_avg:97.24ms +step:138/1670 train_time:13419ms step_avg:97.24ms +step:139/1670 train_time:13516ms step_avg:97.24ms +step:140/1670 train_time:13614ms step_avg:97.24ms +step:141/1670 train_time:13710ms step_avg:97.23ms +step:142/1670 train_time:13804ms step_avg:97.21ms +step:143/1670 train_time:13899ms step_avg:97.20ms +step:144/1670 train_time:13994ms step_avg:97.18ms +step:145/1670 train_time:14088ms step_avg:97.16ms +step:146/1670 train_time:14183ms step_avg:97.15ms +step:147/1670 train_time:14280ms step_avg:97.14ms +step:148/1670 train_time:14375ms step_avg:97.13ms 
+step:149/1670 train_time:14472ms step_avg:97.13ms +step:150/1670 train_time:14569ms step_avg:97.13ms +step:151/1670 train_time:14666ms step_avg:97.12ms +step:152/1670 train_time:14761ms step_avg:97.11ms +step:153/1670 train_time:14856ms step_avg:97.10ms +step:154/1670 train_time:14952ms step_avg:97.09ms +step:155/1670 train_time:15047ms step_avg:97.08ms +step:156/1670 train_time:15142ms step_avg:97.07ms +step:157/1670 train_time:15237ms step_avg:97.05ms +step:158/1670 train_time:15333ms step_avg:97.05ms +step:159/1670 train_time:15429ms step_avg:97.04ms +step:160/1670 train_time:15525ms step_avg:97.03ms +step:161/1670 train_time:15621ms step_avg:97.02ms +step:162/1670 train_time:15716ms step_avg:97.01ms +step:163/1670 train_time:15811ms step_avg:97.00ms +step:164/1670 train_time:15907ms step_avg:96.99ms +step:165/1670 train_time:16002ms step_avg:96.98ms +step:166/1670 train_time:16097ms step_avg:96.97ms +step:167/1670 train_time:16193ms step_avg:96.96ms +step:168/1670 train_time:16288ms step_avg:96.95ms +step:169/1670 train_time:16384ms step_avg:96.95ms +step:170/1670 train_time:16480ms step_avg:96.94ms +step:171/1670 train_time:16576ms step_avg:96.93ms +step:172/1670 train_time:16672ms step_avg:96.93ms +step:173/1670 train_time:16768ms step_avg:96.93ms +step:174/1670 train_time:16864ms step_avg:96.92ms +step:175/1670 train_time:16959ms step_avg:96.91ms +step:176/1670 train_time:17054ms step_avg:96.90ms +step:177/1670 train_time:17149ms step_avg:96.89ms +step:178/1670 train_time:17244ms step_avg:96.88ms +step:179/1670 train_time:17340ms step_avg:96.87ms +step:180/1670 train_time:17435ms step_avg:96.86ms +step:181/1670 train_time:17531ms step_avg:96.86ms +step:182/1670 train_time:17628ms step_avg:96.86ms +step:183/1670 train_time:17724ms step_avg:96.85ms +step:184/1670 train_time:17819ms step_avg:96.84ms +step:185/1670 train_time:17914ms step_avg:96.83ms +step:186/1670 train_time:18010ms step_avg:96.83ms +step:187/1670 train_time:18106ms step_avg:96.82ms +step:188/1670 train_time:18201ms step_avg:96.82ms +step:189/1670 train_time:18298ms step_avg:96.81ms +step:190/1670 train_time:18393ms step_avg:96.81ms +step:191/1670 train_time:18488ms step_avg:96.80ms +step:192/1670 train_time:18584ms step_avg:96.79ms +step:193/1670 train_time:18679ms step_avg:96.78ms +step:194/1670 train_time:18776ms step_avg:96.78ms +step:195/1670 train_time:18872ms step_avg:96.78ms +step:196/1670 train_time:18968ms step_avg:96.78ms +step:197/1670 train_time:19063ms step_avg:96.76ms +step:198/1670 train_time:19158ms step_avg:96.76ms +step:199/1670 train_time:19253ms step_avg:96.75ms +step:200/1670 train_time:19348ms step_avg:96.74ms +step:201/1670 train_time:19444ms step_avg:96.73ms +step:202/1670 train_time:19539ms step_avg:96.73ms +step:203/1670 train_time:19635ms step_avg:96.72ms +step:204/1670 train_time:19731ms step_avg:96.72ms +step:205/1670 train_time:19827ms step_avg:96.72ms +step:206/1670 train_time:19922ms step_avg:96.71ms +step:207/1670 train_time:20018ms step_avg:96.71ms +step:208/1670 train_time:20114ms step_avg:96.70ms +step:209/1670 train_time:20209ms step_avg:96.70ms +step:210/1670 train_time:20305ms step_avg:96.69ms +step:211/1670 train_time:20401ms step_avg:96.69ms +step:212/1670 train_time:20496ms step_avg:96.68ms +step:213/1670 train_time:20771ms step_avg:97.51ms +step:214/1670 train_time:20917ms step_avg:97.74ms +step:215/1670 train_time:21011ms step_avg:97.73ms +step:216/1670 train_time:21106ms step_avg:97.71ms +step:217/1670 train_time:21201ms step_avg:97.70ms +step:218/1670 train_time:21296ms 
step_avg:97.69ms +step:219/1670 train_time:21390ms step_avg:97.67ms +step:220/1670 train_time:21485ms step_avg:97.66ms +step:221/1670 train_time:21580ms step_avg:97.65ms +step:222/1670 train_time:21675ms step_avg:97.63ms +step:223/1670 train_time:21776ms step_avg:97.65ms +step:224/1670 train_time:21875ms step_avg:97.66ms +step:225/1670 train_time:21972ms step_avg:97.65ms +step:226/1670 train_time:22068ms step_avg:97.65ms +step:227/1670 train_time:22163ms step_avg:97.64ms +step:228/1670 train_time:22259ms step_avg:97.63ms +step:229/1670 train_time:22354ms step_avg:97.61ms +step:230/1670 train_time:22448ms step_avg:97.60ms +step:231/1670 train_time:22544ms step_avg:97.59ms +step:232/1670 train_time:22638ms step_avg:97.58ms +step:233/1670 train_time:22734ms step_avg:97.57ms +step:234/1670 train_time:22831ms step_avg:97.57ms +step:235/1670 train_time:22927ms step_avg:97.56ms +step:236/1670 train_time:23023ms step_avg:97.56ms +step:237/1670 train_time:23119ms step_avg:97.55ms +step:238/1670 train_time:23214ms step_avg:97.54ms +step:239/1670 train_time:23309ms step_avg:97.53ms +step:240/1670 train_time:23405ms step_avg:97.52ms +step:241/1670 train_time:23500ms step_avg:97.51ms +step:242/1670 train_time:23595ms step_avg:97.50ms +step:243/1670 train_time:23690ms step_avg:97.49ms +step:244/1670 train_time:23786ms step_avg:97.48ms +step:245/1670 train_time:23882ms step_avg:97.48ms +step:246/1670 train_time:23978ms step_avg:97.47ms +step:247/1670 train_time:24075ms step_avg:97.47ms +step:248/1670 train_time:24171ms step_avg:97.46ms +step:249/1670 train_time:24266ms step_avg:97.46ms +step:250/1670 train_time:24361ms step_avg:97.44ms +step:250/1670 val_loss:3.9721 train_time:24455ms step_avg:97.82ms +step:251/1670 train_time:24478ms step_avg:97.52ms +step:252/1670 train_time:24558ms step_avg:97.45ms +step:253/1670 train_time:24657ms step_avg:97.46ms +step:254/1670 train_time:24753ms step_avg:97.45ms +step:255/1670 train_time:24849ms step_avg:97.45ms +step:256/1670 train_time:24943ms step_avg:97.43ms +step:257/1670 train_time:25037ms step_avg:97.42ms +step:258/1670 train_time:25132ms step_avg:97.41ms +step:259/1670 train_time:25226ms step_avg:97.40ms +step:260/1670 train_time:25322ms step_avg:97.39ms +step:261/1670 train_time:25417ms step_avg:97.38ms +step:262/1670 train_time:25513ms step_avg:97.38ms +step:263/1670 train_time:25614ms step_avg:97.39ms +step:264/1670 train_time:25711ms step_avg:97.39ms +step:265/1670 train_time:25808ms step_avg:97.39ms +step:266/1670 train_time:25903ms step_avg:97.38ms +step:267/1670 train_time:25998ms step_avg:97.37ms +step:268/1670 train_time:26093ms step_avg:97.36ms +step:269/1670 train_time:26188ms step_avg:97.35ms +step:270/1670 train_time:26283ms step_avg:97.34ms +step:271/1670 train_time:26378ms step_avg:97.34ms +step:272/1670 train_time:26474ms step_avg:97.33ms +step:273/1670 train_time:26570ms step_avg:97.33ms +step:274/1670 train_time:26667ms step_avg:97.32ms +step:275/1670 train_time:26763ms step_avg:97.32ms +step:276/1670 train_time:26859ms step_avg:97.31ms +step:277/1670 train_time:26954ms step_avg:97.31ms +step:278/1670 train_time:27050ms step_avg:97.30ms +step:279/1670 train_time:27146ms step_avg:97.30ms +step:280/1670 train_time:27241ms step_avg:97.29ms +step:281/1670 train_time:27336ms step_avg:97.28ms +step:282/1670 train_time:27431ms step_avg:97.27ms +step:283/1670 train_time:27528ms step_avg:97.27ms +step:284/1670 train_time:27623ms step_avg:97.26ms +step:285/1670 train_time:27719ms step_avg:97.26ms +step:286/1670 train_time:27816ms step_avg:97.26ms 
+step:287/1670 train_time:27912ms step_avg:97.25ms +step:288/1670 train_time:28007ms step_avg:97.25ms +step:289/1670 train_time:28103ms step_avg:97.24ms +step:290/1670 train_time:28198ms step_avg:97.23ms +step:291/1670 train_time:28293ms step_avg:97.23ms +step:292/1670 train_time:28389ms step_avg:97.22ms +step:293/1670 train_time:28485ms step_avg:97.22ms +step:294/1670 train_time:28581ms step_avg:97.22ms +step:295/1670 train_time:28677ms step_avg:97.21ms +step:296/1670 train_time:28773ms step_avg:97.21ms +step:297/1670 train_time:28870ms step_avg:97.20ms +step:298/1670 train_time:28966ms step_avg:97.20ms +step:299/1670 train_time:29061ms step_avg:97.20ms +step:300/1670 train_time:29157ms step_avg:97.19ms +step:301/1670 train_time:29252ms step_avg:97.18ms +step:302/1670 train_time:29347ms step_avg:97.18ms +step:303/1670 train_time:29443ms step_avg:97.17ms +step:304/1670 train_time:29539ms step_avg:97.17ms +step:305/1670 train_time:29636ms step_avg:97.17ms +step:306/1670 train_time:29731ms step_avg:97.16ms +step:307/1670 train_time:29827ms step_avg:97.16ms +step:308/1670 train_time:29922ms step_avg:97.15ms +step:309/1670 train_time:30017ms step_avg:97.14ms +step:310/1670 train_time:30113ms step_avg:97.14ms +step:311/1670 train_time:30209ms step_avg:97.14ms +step:312/1670 train_time:30304ms step_avg:97.13ms +step:313/1670 train_time:30400ms step_avg:97.12ms +step:314/1670 train_time:30495ms step_avg:97.12ms +step:315/1670 train_time:30592ms step_avg:97.12ms +step:316/1670 train_time:30688ms step_avg:97.11ms +step:317/1670 train_time:30784ms step_avg:97.11ms +step:318/1670 train_time:30880ms step_avg:97.11ms +step:319/1670 train_time:30976ms step_avg:97.10ms +step:320/1670 train_time:31071ms step_avg:97.10ms +step:321/1670 train_time:31167ms step_avg:97.09ms +step:322/1670 train_time:31263ms step_avg:97.09ms +step:323/1670 train_time:31359ms step_avg:97.09ms +step:324/1670 train_time:31454ms step_avg:97.08ms +step:325/1670 train_time:31550ms step_avg:97.08ms +step:326/1670 train_time:31646ms step_avg:97.07ms +step:327/1670 train_time:31742ms step_avg:97.07ms +step:328/1670 train_time:31837ms step_avg:97.07ms +step:329/1670 train_time:31933ms step_avg:97.06ms +step:330/1670 train_time:32028ms step_avg:97.06ms +step:331/1670 train_time:32124ms step_avg:97.05ms +step:332/1670 train_time:32219ms step_avg:97.05ms +step:333/1670 train_time:32316ms step_avg:97.04ms +step:334/1670 train_time:32412ms step_avg:97.04ms +step:335/1670 train_time:32509ms step_avg:97.04ms +step:336/1670 train_time:32604ms step_avg:97.04ms +step:337/1670 train_time:32700ms step_avg:97.03ms +step:338/1670 train_time:32795ms step_avg:97.03ms +step:339/1670 train_time:32892ms step_avg:97.03ms +step:340/1670 train_time:32988ms step_avg:97.02ms +step:341/1670 train_time:33083ms step_avg:97.02ms +step:342/1670 train_time:33178ms step_avg:97.01ms +step:343/1670 train_time:33274ms step_avg:97.01ms +step:344/1670 train_time:33369ms step_avg:97.00ms +step:345/1670 train_time:33465ms step_avg:97.00ms +step:346/1670 train_time:33561ms step_avg:97.00ms +step:347/1670 train_time:33656ms step_avg:96.99ms +step:348/1670 train_time:33752ms step_avg:96.99ms +step:349/1670 train_time:33849ms step_avg:96.99ms +step:350/1670 train_time:33946ms step_avg:96.99ms +step:351/1670 train_time:34041ms step_avg:96.98ms +step:352/1670 train_time:34136ms step_avg:96.98ms +step:353/1670 train_time:34232ms step_avg:96.97ms +step:354/1670 train_time:34327ms step_avg:96.97ms +step:355/1670 train_time:34423ms step_avg:96.96ms +step:356/1670 train_time:34518ms 
step_avg:96.96ms +step:357/1670 train_time:34614ms step_avg:96.96ms +step:358/1670 train_time:34710ms step_avg:96.95ms +step:359/1670 train_time:34806ms step_avg:96.95ms +step:360/1670 train_time:34902ms step_avg:96.95ms +step:361/1670 train_time:34997ms step_avg:96.95ms +step:362/1670 train_time:35093ms step_avg:96.94ms +step:363/1670 train_time:35189ms step_avg:96.94ms +step:364/1670 train_time:35285ms step_avg:96.94ms +step:365/1670 train_time:35380ms step_avg:96.93ms +step:366/1670 train_time:35475ms step_avg:96.93ms +step:367/1670 train_time:35572ms step_avg:96.93ms +step:368/1670 train_time:35667ms step_avg:96.92ms +step:369/1670 train_time:35763ms step_avg:96.92ms +step:370/1670 train_time:35859ms step_avg:96.92ms +step:371/1670 train_time:35955ms step_avg:96.91ms +step:372/1670 train_time:36051ms step_avg:96.91ms +step:373/1670 train_time:36146ms step_avg:96.91ms +step:374/1670 train_time:36242ms step_avg:96.90ms +step:375/1670 train_time:36338ms step_avg:96.90ms +step:375/1670 val_loss:3.8167 train_time:36432ms step_avg:97.15ms +step:376/1670 train_time:36455ms step_avg:96.95ms +step:377/1670 train_time:36534ms step_avg:96.91ms +step:378/1670 train_time:36633ms step_avg:96.91ms +step:379/1670 train_time:36729ms step_avg:96.91ms +step:380/1670 train_time:36824ms step_avg:96.91ms +step:381/1670 train_time:36919ms step_avg:96.90ms +step:382/1670 train_time:37014ms step_avg:96.90ms +step:383/1670 train_time:37109ms step_avg:96.89ms +step:384/1670 train_time:37204ms step_avg:96.89ms +step:385/1670 train_time:37299ms step_avg:96.88ms +step:386/1670 train_time:37394ms step_avg:96.88ms +step:387/1670 train_time:37491ms step_avg:96.88ms +step:388/1670 train_time:37588ms step_avg:96.88ms +step:389/1670 train_time:37685ms step_avg:96.88ms +step:390/1670 train_time:37781ms step_avg:96.88ms +step:391/1670 train_time:37876ms step_avg:96.87ms +step:392/1670 train_time:37972ms step_avg:96.87ms +step:393/1670 train_time:38067ms step_avg:96.86ms +step:394/1670 train_time:38162ms step_avg:96.86ms +step:395/1670 train_time:38257ms step_avg:96.85ms +step:396/1670 train_time:38352ms step_avg:96.85ms +step:397/1670 train_time:38448ms step_avg:96.85ms +step:398/1670 train_time:38545ms step_avg:96.85ms +step:399/1670 train_time:38642ms step_avg:96.85ms +step:400/1670 train_time:38739ms step_avg:96.85ms +step:401/1670 train_time:38835ms step_avg:96.84ms +step:402/1670 train_time:38930ms step_avg:96.84ms +step:403/1670 train_time:39025ms step_avg:96.84ms +step:404/1670 train_time:39121ms step_avg:96.83ms +step:405/1670 train_time:39216ms step_avg:96.83ms +step:406/1670 train_time:39311ms step_avg:96.83ms +step:407/1670 train_time:39407ms step_avg:96.82ms +step:408/1670 train_time:39503ms step_avg:96.82ms +step:409/1670 train_time:39599ms step_avg:96.82ms +step:410/1670 train_time:39696ms step_avg:96.82ms +step:411/1670 train_time:39790ms step_avg:96.81ms +step:412/1670 train_time:39886ms step_avg:96.81ms +step:413/1670 train_time:39982ms step_avg:96.81ms +step:414/1670 train_time:40079ms step_avg:96.81ms +step:415/1670 train_time:40175ms step_avg:96.81ms +step:416/1670 train_time:40270ms step_avg:96.80ms +step:417/1670 train_time:40365ms step_avg:96.80ms +step:418/1670 train_time:40461ms step_avg:96.80ms +step:419/1670 train_time:40558ms step_avg:96.80ms +step:420/1670 train_time:40654ms step_avg:96.80ms +step:421/1670 train_time:40749ms step_avg:96.79ms +step:422/1670 train_time:40846ms step_avg:96.79ms +step:423/1670 train_time:40941ms step_avg:96.79ms +step:424/1670 train_time:41037ms step_avg:96.78ms 
+step:425/1670 train_time:41303ms step_avg:97.18ms +step:426/1670 train_time:41427ms step_avg:97.25ms +step:427/1670 train_time:41521ms step_avg:97.24ms +step:428/1670 train_time:41615ms step_avg:97.23ms +step:429/1670 train_time:41709ms step_avg:97.22ms +step:430/1670 train_time:41804ms step_avg:97.22ms +step:431/1670 train_time:41900ms step_avg:97.21ms +step:432/1670 train_time:41994ms step_avg:97.21ms +step:433/1670 train_time:42089ms step_avg:97.20ms +step:434/1670 train_time:42184ms step_avg:97.20ms +step:435/1670 train_time:42279ms step_avg:97.19ms +step:436/1670 train_time:42378ms step_avg:97.20ms +step:437/1670 train_time:42476ms step_avg:97.20ms +step:438/1670 train_time:42573ms step_avg:97.20ms +step:439/1670 train_time:42668ms step_avg:97.19ms +step:440/1670 train_time:42763ms step_avg:97.19ms +step:441/1670 train_time:42859ms step_avg:97.19ms +step:442/1670 train_time:42954ms step_avg:97.18ms +step:443/1670 train_time:43049ms step_avg:97.18ms +step:444/1670 train_time:43144ms step_avg:97.17ms +step:445/1670 train_time:43240ms step_avg:97.17ms +step:446/1670 train_time:43336ms step_avg:97.17ms +step:447/1670 train_time:43433ms step_avg:97.16ms +step:448/1670 train_time:43529ms step_avg:97.16ms +step:449/1670 train_time:43625ms step_avg:97.16ms +step:450/1670 train_time:43721ms step_avg:97.16ms +step:451/1670 train_time:43817ms step_avg:97.16ms +step:452/1670 train_time:43913ms step_avg:97.15ms +step:453/1670 train_time:44008ms step_avg:97.15ms +step:454/1670 train_time:44103ms step_avg:97.14ms +step:455/1670 train_time:44199ms step_avg:97.14ms +step:456/1670 train_time:44295ms step_avg:97.14ms +step:457/1670 train_time:44390ms step_avg:97.13ms +step:458/1670 train_time:44486ms step_avg:97.13ms +step:459/1670 train_time:44583ms step_avg:97.13ms +step:460/1670 train_time:44679ms step_avg:97.13ms +step:461/1670 train_time:44775ms step_avg:97.13ms +step:462/1670 train_time:44871ms step_avg:97.12ms +step:463/1670 train_time:44966ms step_avg:97.12ms +step:464/1670 train_time:45061ms step_avg:97.11ms +step:465/1670 train_time:45156ms step_avg:97.11ms +step:466/1670 train_time:45252ms step_avg:97.11ms +step:467/1670 train_time:45348ms step_avg:97.10ms +step:468/1670 train_time:45444ms step_avg:97.10ms +step:469/1670 train_time:45540ms step_avg:97.10ms +step:470/1670 train_time:45635ms step_avg:97.10ms +step:471/1670 train_time:45730ms step_avg:97.09ms +step:472/1670 train_time:45826ms step_avg:97.09ms +step:473/1670 train_time:45922ms step_avg:97.09ms +step:474/1670 train_time:46018ms step_avg:97.08ms +step:475/1670 train_time:46114ms step_avg:97.08ms +step:476/1670 train_time:46210ms step_avg:97.08ms +step:477/1670 train_time:46305ms step_avg:97.08ms +step:478/1670 train_time:46401ms step_avg:97.07ms +step:479/1670 train_time:46497ms step_avg:97.07ms +step:480/1670 train_time:46592ms step_avg:97.07ms +step:481/1670 train_time:46688ms step_avg:97.06ms +step:482/1670 train_time:46784ms step_avg:97.06ms +step:483/1670 train_time:46881ms step_avg:97.06ms +step:484/1670 train_time:46976ms step_avg:97.06ms +step:485/1670 train_time:47072ms step_avg:97.05ms +step:486/1670 train_time:47167ms step_avg:97.05ms +step:487/1670 train_time:47263ms step_avg:97.05ms +step:488/1670 train_time:47359ms step_avg:97.05ms +step:489/1670 train_time:47455ms step_avg:97.05ms +step:490/1670 train_time:47550ms step_avg:97.04ms +step:491/1670 train_time:47646ms step_avg:97.04ms +step:492/1670 train_time:47743ms step_avg:97.04ms +step:493/1670 train_time:47839ms step_avg:97.04ms +step:494/1670 train_time:47935ms 
step_avg:97.03ms +step:495/1670 train_time:48030ms step_avg:97.03ms +step:496/1670 train_time:48126ms step_avg:97.03ms +step:497/1670 train_time:48222ms step_avg:97.03ms +step:498/1670 train_time:48318ms step_avg:97.02ms +step:499/1670 train_time:48414ms step_avg:97.02ms +step:500/1670 train_time:48510ms step_avg:97.02ms +step:500/1670 val_loss:3.7126 train_time:48604ms step_avg:97.21ms +step:501/1670 train_time:48627ms step_avg:97.06ms +step:502/1670 train_time:48708ms step_avg:97.03ms +step:503/1670 train_time:48807ms step_avg:97.03ms +step:504/1670 train_time:48903ms step_avg:97.03ms +step:505/1670 train_time:49000ms step_avg:97.03ms +step:506/1670 train_time:49096ms step_avg:97.03ms +step:507/1670 train_time:49191ms step_avg:97.02ms +step:508/1670 train_time:49285ms step_avg:97.02ms +step:509/1670 train_time:49381ms step_avg:97.02ms +step:510/1670 train_time:49476ms step_avg:97.01ms +step:511/1670 train_time:49572ms step_avg:97.01ms +step:512/1670 train_time:49669ms step_avg:97.01ms +step:513/1670 train_time:49766ms step_avg:97.01ms +step:514/1670 train_time:49862ms step_avg:97.01ms +step:515/1670 train_time:49960ms step_avg:97.01ms +step:516/1670 train_time:50056ms step_avg:97.01ms +step:517/1670 train_time:50151ms step_avg:97.00ms +step:518/1670 train_time:50246ms step_avg:97.00ms +step:519/1670 train_time:50341ms step_avg:97.00ms +step:520/1670 train_time:50436ms step_avg:96.99ms +step:521/1670 train_time:50531ms step_avg:96.99ms +step:522/1670 train_time:50627ms step_avg:96.99ms +step:523/1670 train_time:50723ms step_avg:96.99ms +step:524/1670 train_time:50820ms step_avg:96.99ms +step:525/1670 train_time:50917ms step_avg:96.99ms +step:526/1670 train_time:51014ms step_avg:96.98ms +step:527/1670 train_time:51109ms step_avg:96.98ms +step:528/1670 train_time:51205ms step_avg:96.98ms +step:529/1670 train_time:51300ms step_avg:96.98ms +step:530/1670 train_time:51395ms step_avg:96.97ms +step:531/1670 train_time:51490ms step_avg:96.97ms +step:532/1670 train_time:51586ms step_avg:96.97ms +step:533/1670 train_time:51682ms step_avg:96.96ms +step:534/1670 train_time:51778ms step_avg:96.96ms +step:535/1670 train_time:51876ms step_avg:96.96ms +step:536/1670 train_time:51972ms step_avg:96.96ms +step:537/1670 train_time:52068ms step_avg:96.96ms +step:538/1670 train_time:52164ms step_avg:96.96ms +step:539/1670 train_time:52260ms step_avg:96.96ms +step:540/1670 train_time:52356ms step_avg:96.96ms +step:541/1670 train_time:52452ms step_avg:96.95ms +step:542/1670 train_time:52547ms step_avg:96.95ms +step:543/1670 train_time:52642ms step_avg:96.95ms +step:544/1670 train_time:52738ms step_avg:96.94ms +step:545/1670 train_time:52834ms step_avg:96.94ms +step:546/1670 train_time:52930ms step_avg:96.94ms +step:547/1670 train_time:53025ms step_avg:96.94ms +step:548/1670 train_time:53121ms step_avg:96.94ms +step:549/1670 train_time:53217ms step_avg:96.94ms +step:550/1670 train_time:53313ms step_avg:96.93ms +step:551/1670 train_time:53409ms step_avg:96.93ms +step:552/1670 train_time:53505ms step_avg:96.93ms +step:553/1670 train_time:53600ms step_avg:96.93ms +step:554/1670 train_time:53697ms step_avg:96.93ms +step:555/1670 train_time:53793ms step_avg:96.92ms +step:556/1670 train_time:53889ms step_avg:96.92ms +step:557/1670 train_time:53985ms step_avg:96.92ms +step:558/1670 train_time:54081ms step_avg:96.92ms +step:559/1670 train_time:54178ms step_avg:96.92ms +step:560/1670 train_time:54276ms step_avg:96.92ms +step:561/1670 train_time:54373ms step_avg:96.92ms +step:562/1670 train_time:54470ms step_avg:96.92ms 
+step:563/1670 train_time:54566ms step_avg:96.92ms +step:564/1670 train_time:54663ms step_avg:96.92ms +step:565/1670 train_time:54761ms step_avg:96.92ms +step:566/1670 train_time:54859ms step_avg:96.92ms +step:567/1670 train_time:54958ms step_avg:96.93ms +step:568/1670 train_time:55056ms step_avg:96.93ms +step:569/1670 train_time:55154ms step_avg:96.93ms +step:570/1670 train_time:55251ms step_avg:96.93ms +step:571/1670 train_time:55348ms step_avg:96.93ms +step:572/1670 train_time:55445ms step_avg:96.93ms +step:573/1670 train_time:55542ms step_avg:96.93ms +step:574/1670 train_time:55640ms step_avg:96.93ms +step:575/1670 train_time:55737ms step_avg:96.93ms +step:576/1670 train_time:55836ms step_avg:96.94ms +step:577/1670 train_time:55933ms step_avg:96.94ms +step:578/1670 train_time:56030ms step_avg:96.94ms +step:579/1670 train_time:56127ms step_avg:96.94ms +step:580/1670 train_time:56224ms step_avg:96.94ms +step:581/1670 train_time:56321ms step_avg:96.94ms +step:582/1670 train_time:56419ms step_avg:96.94ms +step:583/1670 train_time:56516ms step_avg:96.94ms +step:584/1670 train_time:56613ms step_avg:96.94ms +step:585/1670 train_time:56710ms step_avg:96.94ms +step:586/1670 train_time:56807ms step_avg:96.94ms +step:587/1670 train_time:56903ms step_avg:96.94ms +step:588/1670 train_time:57002ms step_avg:96.94ms +step:589/1670 train_time:57099ms step_avg:96.94ms +step:590/1670 train_time:57197ms step_avg:96.94ms +step:591/1670 train_time:57294ms step_avg:96.94ms +step:592/1670 train_time:57392ms step_avg:96.95ms +step:593/1670 train_time:57489ms step_avg:96.95ms +step:594/1670 train_time:57585ms step_avg:96.94ms +step:595/1670 train_time:57682ms step_avg:96.94ms +step:596/1670 train_time:57780ms step_avg:96.95ms +step:597/1670 train_time:57879ms step_avg:96.95ms +step:598/1670 train_time:57977ms step_avg:96.95ms +step:599/1670 train_time:58074ms step_avg:96.95ms +step:600/1670 train_time:58171ms step_avg:96.95ms +step:601/1670 train_time:58268ms step_avg:96.95ms +step:602/1670 train_time:58365ms step_avg:96.95ms +step:603/1670 train_time:58463ms step_avg:96.95ms +step:604/1670 train_time:58561ms step_avg:96.96ms +step:605/1670 train_time:58658ms step_avg:96.96ms +step:606/1670 train_time:58755ms step_avg:96.96ms +step:607/1670 train_time:58853ms step_avg:96.96ms +step:608/1670 train_time:58949ms step_avg:96.96ms +step:609/1670 train_time:59046ms step_avg:96.96ms +step:610/1670 train_time:59143ms step_avg:96.96ms +step:611/1670 train_time:59240ms step_avg:96.96ms +step:612/1670 train_time:59338ms step_avg:96.96ms +step:613/1670 train_time:59436ms step_avg:96.96ms +step:614/1670 train_time:59533ms step_avg:96.96ms +step:615/1670 train_time:59630ms step_avg:96.96ms +step:616/1670 train_time:59726ms step_avg:96.96ms +step:617/1670 train_time:59823ms step_avg:96.96ms +step:618/1670 train_time:59921ms step_avg:96.96ms +step:619/1670 train_time:60020ms step_avg:96.96ms +step:620/1670 train_time:60117ms step_avg:96.96ms +step:621/1670 train_time:60215ms step_avg:96.96ms +step:622/1670 train_time:60312ms step_avg:96.97ms +step:623/1670 train_time:60409ms step_avg:96.96ms +step:624/1670 train_time:60506ms step_avg:96.96ms +step:625/1670 train_time:60604ms step_avg:96.97ms +step:625/1670 val_loss:3.6154 train_time:60700ms step_avg:97.12ms +step:626/1670 train_time:60724ms step_avg:97.00ms +step:627/1670 train_time:60813ms step_avg:96.99ms +step:628/1670 train_time:60912ms step_avg:96.99ms +step:629/1670 train_time:61008ms step_avg:96.99ms +step:630/1670 train_time:61104ms step_avg:96.99ms +step:631/1670 
train_time:61200ms step_avg:96.99ms +step:632/1670 train_time:61296ms step_avg:96.99ms +step:633/1670 train_time:61391ms step_avg:96.98ms +step:634/1670 train_time:61487ms step_avg:96.98ms +step:635/1670 train_time:61583ms step_avg:96.98ms +step:636/1670 train_time:61683ms step_avg:96.99ms +step:637/1670 train_time:61785ms step_avg:96.99ms +step:638/1670 train_time:61885ms step_avg:97.00ms +step:639/1670 train_time:62253ms step_avg:97.42ms +step:640/1670 train_time:62343ms step_avg:97.41ms +step:641/1670 train_time:62440ms step_avg:97.41ms +step:642/1670 train_time:62536ms step_avg:97.41ms +step:643/1670 train_time:62632ms step_avg:97.41ms +step:644/1670 train_time:62728ms step_avg:97.40ms +step:645/1670 train_time:62824ms step_avg:97.40ms +step:646/1670 train_time:62920ms step_avg:97.40ms +step:647/1670 train_time:63017ms step_avg:97.40ms +step:648/1670 train_time:63114ms step_avg:97.40ms +step:649/1670 train_time:63218ms step_avg:97.41ms +step:650/1670 train_time:63317ms step_avg:97.41ms +step:651/1670 train_time:63416ms step_avg:97.41ms +step:652/1670 train_time:63512ms step_avg:97.41ms +step:653/1670 train_time:63609ms step_avg:97.41ms +step:654/1670 train_time:63706ms step_avg:97.41ms +step:655/1670 train_time:63802ms step_avg:97.41ms +step:656/1670 train_time:63898ms step_avg:97.41ms +step:657/1670 train_time:63994ms step_avg:97.40ms +step:658/1670 train_time:64091ms step_avg:97.40ms +step:659/1670 train_time:64188ms step_avg:97.40ms +step:660/1670 train_time:64288ms step_avg:97.41ms +step:661/1670 train_time:64387ms step_avg:97.41ms +step:662/1670 train_time:64485ms step_avg:97.41ms +step:663/1670 train_time:64583ms step_avg:97.41ms +step:664/1670 train_time:64680ms step_avg:97.41ms +step:665/1670 train_time:64778ms step_avg:97.41ms +step:666/1670 train_time:64874ms step_avg:97.41ms +step:667/1670 train_time:64970ms step_avg:97.41ms +step:668/1670 train_time:65067ms step_avg:97.41ms +step:669/1670 train_time:65164ms step_avg:97.41ms +step:670/1670 train_time:65263ms step_avg:97.41ms +step:671/1670 train_time:65363ms step_avg:97.41ms +step:672/1670 train_time:65461ms step_avg:97.41ms +step:673/1670 train_time:65559ms step_avg:97.41ms +step:674/1670 train_time:65656ms step_avg:97.41ms +step:675/1670 train_time:65753ms step_avg:97.41ms +step:676/1670 train_time:65849ms step_avg:97.41ms +step:677/1670 train_time:65946ms step_avg:97.41ms +step:678/1670 train_time:66044ms step_avg:97.41ms +step:679/1670 train_time:66141ms step_avg:97.41ms +step:680/1670 train_time:66238ms step_avg:97.41ms +step:681/1670 train_time:66336ms step_avg:97.41ms +step:682/1670 train_time:66434ms step_avg:97.41ms +step:683/1670 train_time:66531ms step_avg:97.41ms +step:684/1670 train_time:66628ms step_avg:97.41ms +step:685/1670 train_time:66726ms step_avg:97.41ms +step:686/1670 train_time:66822ms step_avg:97.41ms +step:687/1670 train_time:66919ms step_avg:97.41ms +step:688/1670 train_time:67017ms step_avg:97.41ms +step:689/1670 train_time:67114ms step_avg:97.41ms +step:690/1670 train_time:67211ms step_avg:97.41ms +step:691/1670 train_time:67308ms step_avg:97.41ms +step:692/1670 train_time:67407ms step_avg:97.41ms +step:693/1670 train_time:67505ms step_avg:97.41ms +step:694/1670 train_time:67603ms step_avg:97.41ms +step:695/1670 train_time:67702ms step_avg:97.41ms +step:696/1670 train_time:67799ms step_avg:97.41ms +step:697/1670 train_time:67896ms step_avg:97.41ms +step:698/1670 train_time:67992ms step_avg:97.41ms +step:699/1670 train_time:68088ms step_avg:97.41ms +step:700/1670 train_time:68186ms step_avg:97.41ms 
+step:701/1670 train_time:68283ms step_avg:97.41ms +step:702/1670 train_time:68382ms step_avg:97.41ms +step:703/1670 train_time:68480ms step_avg:97.41ms +step:704/1670 train_time:68578ms step_avg:97.41ms +step:705/1670 train_time:68676ms step_avg:97.41ms +step:706/1670 train_time:68773ms step_avg:97.41ms +step:707/1670 train_time:68870ms step_avg:97.41ms +step:708/1670 train_time:68967ms step_avg:97.41ms +step:709/1670 train_time:69064ms step_avg:97.41ms +step:710/1670 train_time:69161ms step_avg:97.41ms +step:711/1670 train_time:69259ms step_avg:97.41ms +step:712/1670 train_time:69356ms step_avg:97.41ms +step:713/1670 train_time:69453ms step_avg:97.41ms +step:714/1670 train_time:69550ms step_avg:97.41ms +step:715/1670 train_time:69648ms step_avg:97.41ms +step:716/1670 train_time:69746ms step_avg:97.41ms +step:717/1670 train_time:69844ms step_avg:97.41ms +step:718/1670 train_time:69941ms step_avg:97.41ms +step:719/1670 train_time:70039ms step_avg:97.41ms +step:720/1670 train_time:70136ms step_avg:97.41ms +step:721/1670 train_time:70233ms step_avg:97.41ms +step:722/1670 train_time:70330ms step_avg:97.41ms +step:723/1670 train_time:70428ms step_avg:97.41ms +step:724/1670 train_time:70525ms step_avg:97.41ms +step:725/1670 train_time:70624ms step_avg:97.41ms +step:726/1670 train_time:70721ms step_avg:97.41ms +step:727/1670 train_time:70818ms step_avg:97.41ms +step:728/1670 train_time:70916ms step_avg:97.41ms +step:729/1670 train_time:71013ms step_avg:97.41ms +step:730/1670 train_time:71110ms step_avg:97.41ms +step:731/1670 train_time:71208ms step_avg:97.41ms +step:732/1670 train_time:71306ms step_avg:97.41ms +step:733/1670 train_time:71403ms step_avg:97.41ms +step:734/1670 train_time:71500ms step_avg:97.41ms +step:735/1670 train_time:71599ms step_avg:97.41ms +step:736/1670 train_time:71696ms step_avg:97.41ms +step:737/1670 train_time:71794ms step_avg:97.41ms +step:738/1670 train_time:71890ms step_avg:97.41ms +step:739/1670 train_time:71987ms step_avg:97.41ms +step:740/1670 train_time:72085ms step_avg:97.41ms +step:741/1670 train_time:72183ms step_avg:97.41ms +step:742/1670 train_time:72281ms step_avg:97.41ms +step:743/1670 train_time:72379ms step_avg:97.41ms +step:744/1670 train_time:72476ms step_avg:97.41ms +step:745/1670 train_time:72573ms step_avg:97.41ms +step:746/1670 train_time:72670ms step_avg:97.41ms +step:747/1670 train_time:72768ms step_avg:97.41ms +step:748/1670 train_time:72864ms step_avg:97.41ms +step:749/1670 train_time:72963ms step_avg:97.41ms +step:750/1670 train_time:73061ms step_avg:97.41ms +step:750/1670 val_loss:3.5609 train_time:73158ms step_avg:97.54ms +step:751/1670 train_time:73181ms step_avg:97.44ms +step:752/1670 train_time:73262ms step_avg:97.42ms +step:753/1670 train_time:73361ms step_avg:97.43ms +step:754/1670 train_time:73460ms step_avg:97.43ms +step:755/1670 train_time:73557ms step_avg:97.43ms +step:756/1670 train_time:73653ms step_avg:97.42ms +step:757/1670 train_time:73749ms step_avg:97.42ms +step:758/1670 train_time:73846ms step_avg:97.42ms +step:759/1670 train_time:73942ms step_avg:97.42ms +step:760/1670 train_time:74038ms step_avg:97.42ms +step:761/1670 train_time:74137ms step_avg:97.42ms +step:762/1670 train_time:74237ms step_avg:97.42ms +step:763/1670 train_time:74337ms step_avg:97.43ms +step:764/1670 train_time:74435ms step_avg:97.43ms +step:765/1670 train_time:74533ms step_avg:97.43ms +step:766/1670 train_time:74631ms step_avg:97.43ms +step:767/1670 train_time:74728ms step_avg:97.43ms +step:768/1670 train_time:74825ms step_avg:97.43ms +step:769/1670 
train_time:74922ms step_avg:97.43ms +step:770/1670 train_time:75018ms step_avg:97.43ms +step:771/1670 train_time:75116ms step_avg:97.43ms +step:772/1670 train_time:75214ms step_avg:97.43ms +step:773/1670 train_time:75313ms step_avg:97.43ms +step:774/1670 train_time:75412ms step_avg:97.43ms +step:775/1670 train_time:75510ms step_avg:97.43ms +step:776/1670 train_time:75607ms step_avg:97.43ms +step:777/1670 train_time:75704ms step_avg:97.43ms +step:778/1670 train_time:75801ms step_avg:97.43ms +step:779/1670 train_time:75897ms step_avg:97.43ms +step:780/1670 train_time:75995ms step_avg:97.43ms +step:781/1670 train_time:76091ms step_avg:97.43ms +step:782/1670 train_time:76189ms step_avg:97.43ms +step:783/1670 train_time:76288ms step_avg:97.43ms +step:784/1670 train_time:76386ms step_avg:97.43ms +step:785/1670 train_time:76483ms step_avg:97.43ms +step:786/1670 train_time:76581ms step_avg:97.43ms +step:787/1670 train_time:76679ms step_avg:97.43ms +step:788/1670 train_time:76776ms step_avg:97.43ms +step:789/1670 train_time:76874ms step_avg:97.43ms +step:790/1670 train_time:76972ms step_avg:97.43ms +step:791/1670 train_time:77069ms step_avg:97.43ms +step:792/1670 train_time:77167ms step_avg:97.43ms +step:793/1670 train_time:77267ms step_avg:97.44ms +step:794/1670 train_time:77364ms step_avg:97.44ms +step:795/1670 train_time:77462ms step_avg:97.44ms +step:796/1670 train_time:77558ms step_avg:97.44ms +step:797/1670 train_time:77656ms step_avg:97.43ms +step:798/1670 train_time:77752ms step_avg:97.43ms +step:799/1670 train_time:77849ms step_avg:97.43ms +step:800/1670 train_time:77946ms step_avg:97.43ms +step:801/1670 train_time:78043ms step_avg:97.43ms +step:802/1670 train_time:78140ms step_avg:97.43ms +step:803/1670 train_time:78238ms step_avg:97.43ms +step:804/1670 train_time:78336ms step_avg:97.43ms +step:805/1670 train_time:78436ms step_avg:97.44ms +step:806/1670 train_time:78535ms step_avg:97.44ms +step:807/1670 train_time:78633ms step_avg:97.44ms +step:808/1670 train_time:78730ms step_avg:97.44ms +step:809/1670 train_time:78827ms step_avg:97.44ms +step:810/1670 train_time:78923ms step_avg:97.44ms +step:811/1670 train_time:79020ms step_avg:97.44ms +step:812/1670 train_time:79118ms step_avg:97.44ms +step:813/1670 train_time:79216ms step_avg:97.44ms +step:814/1670 train_time:79314ms step_avg:97.44ms +step:815/1670 train_time:79412ms step_avg:97.44ms +step:816/1670 train_time:79510ms step_avg:97.44ms +step:817/1670 train_time:79608ms step_avg:97.44ms +step:818/1670 train_time:79705ms step_avg:97.44ms +step:819/1670 train_time:79802ms step_avg:97.44ms +step:820/1670 train_time:79899ms step_avg:97.44ms +step:821/1670 train_time:79996ms step_avg:97.44ms +step:822/1670 train_time:80093ms step_avg:97.44ms +step:823/1670 train_time:80192ms step_avg:97.44ms +step:824/1670 train_time:80291ms step_avg:97.44ms +step:825/1670 train_time:80388ms step_avg:97.44ms +step:826/1670 train_time:80486ms step_avg:97.44ms +step:827/1670 train_time:80583ms step_avg:97.44ms +step:828/1670 train_time:80680ms step_avg:97.44ms +step:829/1670 train_time:80777ms step_avg:97.44ms +step:830/1670 train_time:80875ms step_avg:97.44ms +step:831/1670 train_time:80973ms step_avg:97.44ms +step:832/1670 train_time:81071ms step_avg:97.44ms +step:833/1670 train_time:81168ms step_avg:97.44ms +step:834/1670 train_time:81265ms step_avg:97.44ms +step:835/1670 train_time:81363ms step_avg:97.44ms +step:836/1670 train_time:81460ms step_avg:97.44ms +step:837/1670 train_time:81558ms step_avg:97.44ms +step:838/1670 train_time:81655ms step_avg:97.44ms 
+step:839/1670 train_time:81753ms step_avg:97.44ms +step:840/1670 train_time:81851ms step_avg:97.44ms +step:841/1670 train_time:81949ms step_avg:97.44ms +step:842/1670 train_time:82045ms step_avg:97.44ms +step:843/1670 train_time:82142ms step_avg:97.44ms +step:844/1670 train_time:82239ms step_avg:97.44ms +step:845/1670 train_time:82335ms step_avg:97.44ms +step:846/1670 train_time:82434ms step_avg:97.44ms +step:847/1670 train_time:82534ms step_avg:97.44ms +step:848/1670 train_time:82631ms step_avg:97.44ms +step:849/1670 train_time:82729ms step_avg:97.44ms +step:850/1670 train_time:82826ms step_avg:97.44ms +step:851/1670 train_time:83088ms step_avg:97.64ms +step:852/1670 train_time:83261ms step_avg:97.72ms +step:853/1670 train_time:83356ms step_avg:97.72ms +step:854/1670 train_time:83453ms step_avg:97.72ms +step:855/1670 train_time:83550ms step_avg:97.72ms +step:856/1670 train_time:83646ms step_avg:97.72ms +step:857/1670 train_time:83742ms step_avg:97.71ms +step:858/1670 train_time:83838ms step_avg:97.71ms +step:859/1670 train_time:83935ms step_avg:97.71ms +step:860/1670 train_time:84031ms step_avg:97.71ms +step:861/1670 train_time:84135ms step_avg:97.72ms +step:862/1670 train_time:84238ms step_avg:97.72ms +step:863/1670 train_time:84337ms step_avg:97.73ms +step:864/1670 train_time:84435ms step_avg:97.73ms +step:865/1670 train_time:84532ms step_avg:97.73ms +step:866/1670 train_time:84630ms step_avg:97.73ms +step:867/1670 train_time:84727ms step_avg:97.72ms +step:868/1670 train_time:84823ms step_avg:97.72ms +step:869/1670 train_time:84920ms step_avg:97.72ms +step:870/1670 train_time:85016ms step_avg:97.72ms +step:871/1670 train_time:85114ms step_avg:97.72ms +step:872/1670 train_time:85215ms step_avg:97.72ms +step:873/1670 train_time:85314ms step_avg:97.73ms +step:874/1670 train_time:85413ms step_avg:97.73ms +step:875/1670 train_time:85511ms step_avg:97.73ms +step:875/1670 val_loss:3.5199 train_time:85608ms step_avg:97.84ms +step:876/1670 train_time:85630ms step_avg:97.75ms +step:877/1670 train_time:85714ms step_avg:97.74ms +step:878/1670 train_time:85817ms step_avg:97.74ms +step:879/1670 train_time:85915ms step_avg:97.74ms +step:880/1670 train_time:86012ms step_avg:97.74ms +step:881/1670 train_time:86108ms step_avg:97.74ms +step:882/1670 train_time:86205ms step_avg:97.74ms +step:883/1670 train_time:86301ms step_avg:97.74ms +step:884/1670 train_time:86398ms step_avg:97.73ms +step:885/1670 train_time:86495ms step_avg:97.73ms +step:886/1670 train_time:86594ms step_avg:97.74ms +step:887/1670 train_time:86695ms step_avg:97.74ms +step:888/1670 train_time:86796ms step_avg:97.74ms +step:889/1670 train_time:86894ms step_avg:97.74ms +step:890/1670 train_time:86992ms step_avg:97.74ms +step:891/1670 train_time:87088ms step_avg:97.74ms +step:892/1670 train_time:87184ms step_avg:97.74ms +step:893/1670 train_time:87281ms step_avg:97.74ms +step:894/1670 train_time:87377ms step_avg:97.74ms +step:895/1670 train_time:87474ms step_avg:97.74ms +step:896/1670 train_time:87572ms step_avg:97.74ms +step:897/1670 train_time:87671ms step_avg:97.74ms +step:898/1670 train_time:87769ms step_avg:97.74ms +step:899/1670 train_time:87867ms step_avg:97.74ms +step:900/1670 train_time:87964ms step_avg:97.74ms +step:901/1670 train_time:88061ms step_avg:97.74ms +step:902/1670 train_time:88158ms step_avg:97.74ms +step:903/1670 train_time:88255ms step_avg:97.74ms +step:904/1670 train_time:88352ms step_avg:97.73ms +step:905/1670 train_time:88449ms step_avg:97.73ms +step:906/1670 train_time:88546ms step_avg:97.73ms +step:907/1670 
train_time:88644ms step_avg:97.73ms +step:908/1670 train_time:88742ms step_avg:97.73ms +step:909/1670 train_time:88841ms step_avg:97.73ms +step:910/1670 train_time:88940ms step_avg:97.74ms +step:911/1670 train_time:89038ms step_avg:97.74ms +step:912/1670 train_time:89136ms step_avg:97.74ms +step:913/1670 train_time:89234ms step_avg:97.74ms +step:914/1670 train_time:89330ms step_avg:97.74ms +step:915/1670 train_time:89427ms step_avg:97.73ms +step:916/1670 train_time:89524ms step_avg:97.73ms +step:917/1670 train_time:89621ms step_avg:97.73ms +step:918/1670 train_time:89719ms step_avg:97.73ms +step:919/1670 train_time:89817ms step_avg:97.73ms +step:920/1670 train_time:89915ms step_avg:97.73ms +step:921/1670 train_time:90014ms step_avg:97.74ms +step:922/1670 train_time:90112ms step_avg:97.73ms +step:923/1670 train_time:90209ms step_avg:97.73ms +step:924/1670 train_time:90306ms step_avg:97.73ms +step:925/1670 train_time:90403ms step_avg:97.73ms +step:926/1670 train_time:90500ms step_avg:97.73ms +step:927/1670 train_time:90597ms step_avg:97.73ms +step:928/1670 train_time:90694ms step_avg:97.73ms +step:929/1670 train_time:90793ms step_avg:97.73ms +step:930/1670 train_time:90892ms step_avg:97.73ms +step:931/1670 train_time:90989ms step_avg:97.73ms +step:932/1670 train_time:91087ms step_avg:97.73ms +step:933/1670 train_time:91185ms step_avg:97.73ms +step:934/1670 train_time:91282ms step_avg:97.73ms +step:935/1670 train_time:91380ms step_avg:97.73ms +step:936/1670 train_time:91476ms step_avg:97.73ms +step:937/1670 train_time:91574ms step_avg:97.73ms +step:938/1670 train_time:91671ms step_avg:97.73ms +step:939/1670 train_time:91770ms step_avg:97.73ms +step:940/1670 train_time:91868ms step_avg:97.73ms +step:941/1670 train_time:91965ms step_avg:97.73ms +step:942/1670 train_time:92062ms step_avg:97.73ms +step:943/1670 train_time:92159ms step_avg:97.73ms +step:944/1670 train_time:92257ms step_avg:97.73ms +step:945/1670 train_time:92355ms step_avg:97.73ms +step:946/1670 train_time:92452ms step_avg:97.73ms +step:947/1670 train_time:92549ms step_avg:97.73ms +step:948/1670 train_time:92646ms step_avg:97.73ms +step:949/1670 train_time:92743ms step_avg:97.73ms +step:950/1670 train_time:92841ms step_avg:97.73ms +step:951/1670 train_time:92939ms step_avg:97.73ms +step:952/1670 train_time:93037ms step_avg:97.73ms +step:953/1670 train_time:93136ms step_avg:97.73ms +step:954/1670 train_time:93233ms step_avg:97.73ms +step:955/1670 train_time:93331ms step_avg:97.73ms +step:956/1670 train_time:93428ms step_avg:97.73ms +step:957/1670 train_time:93524ms step_avg:97.73ms +step:958/1670 train_time:93621ms step_avg:97.73ms +step:959/1670 train_time:93718ms step_avg:97.72ms +step:960/1670 train_time:93816ms step_avg:97.72ms +step:961/1670 train_time:93914ms step_avg:97.73ms +step:962/1670 train_time:94013ms step_avg:97.73ms +step:963/1670 train_time:94112ms step_avg:97.73ms +step:964/1670 train_time:94210ms step_avg:97.73ms +step:965/1670 train_time:94308ms step_avg:97.73ms +step:966/1670 train_time:94406ms step_avg:97.73ms +step:967/1670 train_time:94502ms step_avg:97.73ms +step:968/1670 train_time:94598ms step_avg:97.73ms +step:969/1670 train_time:94696ms step_avg:97.73ms +step:970/1670 train_time:94794ms step_avg:97.73ms +step:971/1670 train_time:94892ms step_avg:97.73ms +step:972/1670 train_time:94989ms step_avg:97.73ms +step:973/1670 train_time:95087ms step_avg:97.73ms +step:974/1670 train_time:95184ms step_avg:97.73ms +step:975/1670 train_time:95282ms step_avg:97.72ms +step:976/1670 train_time:95378ms step_avg:97.72ms 
+step:977/1670 train_time:95476ms step_avg:97.72ms +step:978/1670 train_time:95574ms step_avg:97.72ms +step:979/1670 train_time:95672ms step_avg:97.72ms +step:980/1670 train_time:95769ms step_avg:97.72ms +step:981/1670 train_time:95867ms step_avg:97.72ms +step:982/1670 train_time:95965ms step_avg:97.72ms +step:983/1670 train_time:96063ms step_avg:97.72ms +step:984/1670 train_time:96160ms step_avg:97.72ms +step:985/1670 train_time:96258ms step_avg:97.72ms +step:986/1670 train_time:96355ms step_avg:97.72ms +step:987/1670 train_time:96452ms step_avg:97.72ms +step:988/1670 train_time:96550ms step_avg:97.72ms +step:989/1670 train_time:96647ms step_avg:97.72ms +step:990/1670 train_time:96744ms step_avg:97.72ms +step:991/1670 train_time:96842ms step_avg:97.72ms +step:992/1670 train_time:96939ms step_avg:97.72ms +step:993/1670 train_time:97036ms step_avg:97.72ms +step:994/1670 train_time:97135ms step_avg:97.72ms +step:995/1670 train_time:97233ms step_avg:97.72ms +step:996/1670 train_time:97331ms step_avg:97.72ms +step:997/1670 train_time:97428ms step_avg:97.72ms +step:998/1670 train_time:97524ms step_avg:97.72ms +step:999/1670 train_time:97622ms step_avg:97.72ms +step:1000/1670 train_time:97719ms step_avg:97.72ms +step:1000/1670 val_loss:3.4777 train_time:97816ms step_avg:97.82ms +step:1001/1670 train_time:97839ms step_avg:97.74ms +step:1002/1670 train_time:97921ms step_avg:97.73ms +step:1003/1670 train_time:98021ms step_avg:97.73ms +step:1004/1670 train_time:98120ms step_avg:97.73ms +step:1005/1670 train_time:98217ms step_avg:97.73ms +step:1006/1670 train_time:98313ms step_avg:97.73ms +step:1007/1670 train_time:98409ms step_avg:97.73ms +step:1008/1670 train_time:98506ms step_avg:97.72ms +step:1009/1670 train_time:98602ms step_avg:97.72ms +step:1010/1670 train_time:98699ms step_avg:97.72ms +step:1011/1670 train_time:98798ms step_avg:97.72ms +step:1012/1670 train_time:98898ms step_avg:97.72ms +step:1013/1670 train_time:98998ms step_avg:97.73ms +step:1014/1670 train_time:99097ms step_avg:97.73ms +step:1015/1670 train_time:99194ms step_avg:97.73ms +step:1016/1670 train_time:99292ms step_avg:97.73ms +step:1017/1670 train_time:99388ms step_avg:97.73ms +step:1018/1670 train_time:99485ms step_avg:97.73ms +step:1019/1670 train_time:99582ms step_avg:97.72ms +step:1020/1670 train_time:99678ms step_avg:97.72ms +step:1021/1670 train_time:99776ms step_avg:97.72ms +step:1022/1670 train_time:99873ms step_avg:97.72ms +step:1023/1670 train_time:99972ms step_avg:97.72ms +step:1024/1670 train_time:100070ms step_avg:97.72ms +step:1025/1670 train_time:100168ms step_avg:97.72ms +step:1026/1670 train_time:100266ms step_avg:97.73ms +step:1027/1670 train_time:100364ms step_avg:97.73ms +step:1028/1670 train_time:100461ms step_avg:97.72ms +step:1029/1670 train_time:100557ms step_avg:97.72ms +step:1030/1670 train_time:100654ms step_avg:97.72ms +step:1031/1670 train_time:100751ms step_avg:97.72ms +step:1032/1670 train_time:100850ms step_avg:97.72ms +step:1033/1670 train_time:100950ms step_avg:97.72ms +step:1034/1670 train_time:101049ms step_avg:97.73ms +step:1035/1670 train_time:101147ms step_avg:97.73ms +step:1036/1670 train_time:101245ms step_avg:97.73ms +step:1037/1670 train_time:101343ms step_avg:97.73ms +step:1038/1670 train_time:101439ms step_avg:97.73ms +step:1039/1670 train_time:101536ms step_avg:97.72ms +step:1040/1670 train_time:101632ms step_avg:97.72ms +step:1041/1670 train_time:101730ms step_avg:97.72ms +step:1042/1670 train_time:101828ms step_avg:97.72ms +step:1043/1670 train_time:101927ms step_avg:97.72ms 
+step:1044/1670 train_time:102026ms step_avg:97.73ms +step:1045/1670 train_time:102123ms step_avg:97.73ms +step:1046/1670 train_time:102222ms step_avg:97.73ms +step:1047/1670 train_time:102319ms step_avg:97.73ms +step:1048/1670 train_time:102417ms step_avg:97.73ms +step:1049/1670 train_time:102514ms step_avg:97.73ms +step:1050/1670 train_time:102611ms step_avg:97.72ms +step:1051/1670 train_time:102708ms step_avg:97.72ms +step:1052/1670 train_time:102806ms step_avg:97.72ms +step:1053/1670 train_time:103228ms step_avg:98.03ms +step:1054/1670 train_time:103323ms step_avg:98.03ms +step:1055/1670 train_time:103419ms step_avg:98.03ms +step:1056/1670 train_time:103515ms step_avg:98.03ms +step:1057/1670 train_time:103610ms step_avg:98.02ms +step:1058/1670 train_time:103707ms step_avg:98.02ms +step:1059/1670 train_time:103803ms step_avg:98.02ms +step:1060/1670 train_time:103900ms step_avg:98.02ms +step:1061/1670 train_time:103997ms step_avg:98.02ms +step:1062/1670 train_time:104337ms step_avg:98.25ms +step:1063/1670 train_time:104437ms step_avg:98.25ms +step:1064/1670 train_time:104532ms step_avg:98.24ms +step:1065/1670 train_time:104629ms step_avg:98.24ms +step:1066/1670 train_time:104725ms step_avg:98.24ms +step:1067/1670 train_time:104821ms step_avg:98.24ms +step:1068/1670 train_time:104917ms step_avg:98.24ms +step:1069/1670 train_time:105014ms step_avg:98.24ms +step:1070/1670 train_time:105110ms step_avg:98.23ms +step:1071/1670 train_time:105207ms step_avg:98.23ms +step:1072/1670 train_time:105308ms step_avg:98.23ms +step:1073/1670 train_time:105410ms step_avg:98.24ms +step:1074/1670 train_time:105510ms step_avg:98.24ms +step:1075/1670 train_time:105608ms step_avg:98.24ms +step:1076/1670 train_time:105705ms step_avg:98.24ms +step:1077/1670 train_time:105802ms step_avg:98.24ms +step:1078/1670 train_time:105899ms step_avg:98.24ms +step:1079/1670 train_time:105995ms step_avg:98.23ms +step:1080/1670 train_time:106092ms step_avg:98.23ms +step:1081/1670 train_time:106188ms step_avg:98.23ms +step:1082/1670 train_time:106286ms step_avg:98.23ms +step:1083/1670 train_time:106387ms step_avg:98.23ms +step:1084/1670 train_time:106486ms step_avg:98.23ms +step:1085/1670 train_time:106585ms step_avg:98.24ms +step:1086/1670 train_time:106682ms step_avg:98.23ms +step:1087/1670 train_time:106779ms step_avg:98.23ms +step:1088/1670 train_time:106876ms step_avg:98.23ms +step:1089/1670 train_time:106972ms step_avg:98.23ms +step:1090/1670 train_time:107069ms step_avg:98.23ms +step:1091/1670 train_time:107166ms step_avg:98.23ms +step:1092/1670 train_time:107263ms step_avg:98.23ms +step:1093/1670 train_time:107363ms step_avg:98.23ms +step:1094/1670 train_time:107461ms step_avg:98.23ms +step:1095/1670 train_time:107560ms step_avg:98.23ms +step:1096/1670 train_time:107658ms step_avg:98.23ms +step:1097/1670 train_time:107755ms step_avg:98.23ms +step:1098/1670 train_time:107851ms step_avg:98.23ms +step:1099/1670 train_time:107948ms step_avg:98.22ms +step:1100/1670 train_time:108045ms step_avg:98.22ms +step:1101/1670 train_time:108143ms step_avg:98.22ms +step:1102/1670 train_time:108240ms step_avg:98.22ms +step:1103/1670 train_time:108336ms step_avg:98.22ms +step:1104/1670 train_time:108434ms step_avg:98.22ms +step:1105/1670 train_time:108531ms step_avg:98.22ms +step:1106/1670 train_time:108629ms step_avg:98.22ms +step:1107/1670 train_time:108727ms step_avg:98.22ms +step:1108/1670 train_time:108825ms step_avg:98.22ms +step:1109/1670 train_time:108924ms step_avg:98.22ms +step:1110/1670 train_time:109021ms step_avg:98.22ms 
+step:1111/1670 train_time:109117ms step_avg:98.22ms +step:1112/1670 train_time:109215ms step_avg:98.21ms +step:1113/1670 train_time:109311ms step_avg:98.21ms +step:1114/1670 train_time:109409ms step_avg:98.21ms +step:1115/1670 train_time:109507ms step_avg:98.21ms +step:1116/1670 train_time:109607ms step_avg:98.21ms +step:1117/1670 train_time:109705ms step_avg:98.21ms +step:1118/1670 train_time:109804ms step_avg:98.21ms +step:1119/1670 train_time:109904ms step_avg:98.22ms +step:1120/1670 train_time:110002ms step_avg:98.22ms +step:1121/1670 train_time:110101ms step_avg:98.22ms +step:1122/1670 train_time:110200ms step_avg:98.22ms +step:1123/1670 train_time:110299ms step_avg:98.22ms +step:1124/1670 train_time:110398ms step_avg:98.22ms +step:1125/1670 train_time:110497ms step_avg:98.22ms +step:1125/1670 val_loss:3.4234 train_time:110594ms step_avg:98.31ms +step:1126/1670 train_time:110618ms step_avg:98.24ms +step:1127/1670 train_time:110705ms step_avg:98.23ms +step:1128/1670 train_time:110804ms step_avg:98.23ms +step:1129/1670 train_time:110901ms step_avg:98.23ms +step:1130/1670 train_time:110998ms step_avg:98.23ms +step:1131/1670 train_time:111095ms step_avg:98.23ms +step:1132/1670 train_time:111192ms step_avg:98.23ms +step:1133/1670 train_time:111289ms step_avg:98.22ms +step:1134/1670 train_time:111386ms step_avg:98.22ms +step:1135/1670 train_time:111485ms step_avg:98.22ms +step:1136/1670 train_time:111587ms step_avg:98.23ms +step:1137/1670 train_time:111689ms step_avg:98.23ms +step:1138/1670 train_time:111788ms step_avg:98.23ms +step:1139/1670 train_time:111887ms step_avg:98.23ms +step:1140/1670 train_time:111986ms step_avg:98.23ms +step:1141/1670 train_time:112084ms step_avg:98.23ms +step:1142/1670 train_time:112181ms step_avg:98.23ms +step:1143/1670 train_time:112279ms step_avg:98.23ms +step:1144/1670 train_time:112376ms step_avg:98.23ms +step:1145/1670 train_time:112476ms step_avg:98.23ms +step:1146/1670 train_time:112576ms step_avg:98.23ms +step:1147/1670 train_time:112676ms step_avg:98.24ms +step:1148/1670 train_time:112776ms step_avg:98.24ms +step:1149/1670 train_time:112874ms step_avg:98.24ms +step:1150/1670 train_time:112972ms step_avg:98.24ms +step:1151/1670 train_time:113069ms step_avg:98.24ms +step:1152/1670 train_time:113166ms step_avg:98.23ms +step:1153/1670 train_time:113264ms step_avg:98.23ms +step:1154/1670 train_time:113362ms step_avg:98.23ms +step:1155/1670 train_time:113460ms step_avg:98.23ms +step:1156/1670 train_time:113562ms step_avg:98.24ms +step:1157/1670 train_time:113662ms step_avg:98.24ms +step:1158/1670 train_time:113763ms step_avg:98.24ms +step:1159/1670 train_time:113863ms step_avg:98.24ms +step:1160/1670 train_time:113961ms step_avg:98.24ms +step:1161/1670 train_time:114060ms step_avg:98.24ms +step:1162/1670 train_time:114158ms step_avg:98.24ms +step:1163/1670 train_time:114256ms step_avg:98.24ms +step:1164/1670 train_time:114353ms step_avg:98.24ms +step:1165/1670 train_time:114449ms step_avg:98.24ms +step:1166/1670 train_time:114548ms step_avg:98.24ms +step:1167/1670 train_time:114646ms step_avg:98.24ms +step:1168/1670 train_time:114747ms step_avg:98.24ms +step:1169/1670 train_time:114846ms step_avg:98.24ms +step:1170/1670 train_time:114945ms step_avg:98.24ms +step:1171/1670 train_time:115044ms step_avg:98.24ms +step:1172/1670 train_time:115143ms step_avg:98.25ms +step:1173/1670 train_time:115242ms step_avg:98.25ms +step:1174/1670 train_time:115341ms step_avg:98.25ms +step:1175/1670 train_time:115441ms step_avg:98.25ms +step:1176/1670 train_time:115540ms 
step_avg:98.25ms +step:1177/1670 train_time:115638ms step_avg:98.25ms +step:1178/1670 train_time:115736ms step_avg:98.25ms +step:1179/1670 train_time:115834ms step_avg:98.25ms +step:1180/1670 train_time:115933ms step_avg:98.25ms +step:1181/1670 train_time:116031ms step_avg:98.25ms +step:1182/1670 train_time:116129ms step_avg:98.25ms +step:1183/1670 train_time:116227ms step_avg:98.25ms +step:1184/1670 train_time:116325ms step_avg:98.25ms +step:1185/1670 train_time:116423ms step_avg:98.25ms +step:1186/1670 train_time:116523ms step_avg:98.25ms +step:1187/1670 train_time:116621ms step_avg:98.25ms +step:1188/1670 train_time:116720ms step_avg:98.25ms +step:1189/1670 train_time:116819ms step_avg:98.25ms +step:1190/1670 train_time:116920ms step_avg:98.25ms +step:1191/1670 train_time:117021ms step_avg:98.25ms +step:1192/1670 train_time:117120ms step_avg:98.26ms +step:1193/1670 train_time:117219ms step_avg:98.26ms +step:1194/1670 train_time:117317ms step_avg:98.26ms +step:1195/1670 train_time:117416ms step_avg:98.26ms +step:1196/1670 train_time:117514ms step_avg:98.26ms +step:1197/1670 train_time:117611ms step_avg:98.25ms +step:1198/1670 train_time:117708ms step_avg:98.25ms +step:1199/1670 train_time:117807ms step_avg:98.25ms +step:1200/1670 train_time:117904ms step_avg:98.25ms +step:1201/1670 train_time:118004ms step_avg:98.25ms +step:1202/1670 train_time:118102ms step_avg:98.25ms +step:1203/1670 train_time:118201ms step_avg:98.26ms +step:1204/1670 train_time:118299ms step_avg:98.25ms +step:1205/1670 train_time:118397ms step_avg:98.25ms +step:1206/1670 train_time:118496ms step_avg:98.26ms +step:1207/1670 train_time:118594ms step_avg:98.26ms +step:1208/1670 train_time:118691ms step_avg:98.25ms +step:1209/1670 train_time:118789ms step_avg:98.25ms +step:1210/1670 train_time:118887ms step_avg:98.25ms +step:1211/1670 train_time:118984ms step_avg:98.25ms +step:1212/1670 train_time:119083ms step_avg:98.25ms +step:1213/1670 train_time:119182ms step_avg:98.25ms +step:1214/1670 train_time:119282ms step_avg:98.26ms +step:1215/1670 train_time:119381ms step_avg:98.26ms +step:1216/1670 train_time:119481ms step_avg:98.26ms +step:1217/1670 train_time:119580ms step_avg:98.26ms +step:1218/1670 train_time:119681ms step_avg:98.26ms +step:1219/1670 train_time:119779ms step_avg:98.26ms +step:1220/1670 train_time:119878ms step_avg:98.26ms +step:1221/1670 train_time:119975ms step_avg:98.26ms +step:1222/1670 train_time:120072ms step_avg:98.26ms +step:1223/1670 train_time:120170ms step_avg:98.26ms +step:1224/1670 train_time:120268ms step_avg:98.26ms +step:1225/1670 train_time:120366ms step_avg:98.26ms +step:1226/1670 train_time:120465ms step_avg:98.26ms +step:1227/1670 train_time:120563ms step_avg:98.26ms +step:1228/1670 train_time:120663ms step_avg:98.26ms +step:1229/1670 train_time:120763ms step_avg:98.26ms +step:1230/1670 train_time:120861ms step_avg:98.26ms +step:1231/1670 train_time:120960ms step_avg:98.26ms +step:1232/1670 train_time:121058ms step_avg:98.26ms +step:1233/1670 train_time:121156ms step_avg:98.26ms +step:1234/1670 train_time:121255ms step_avg:98.26ms +step:1235/1670 train_time:121353ms step_avg:98.26ms +step:1236/1670 train_time:121450ms step_avg:98.26ms +step:1237/1670 train_time:121548ms step_avg:98.26ms +step:1238/1670 train_time:121646ms step_avg:98.26ms +step:1239/1670 train_time:121745ms step_avg:98.26ms +step:1240/1670 train_time:121843ms step_avg:98.26ms +step:1241/1670 train_time:121941ms step_avg:98.26ms +step:1242/1670 train_time:122041ms step_avg:98.26ms +step:1243/1670 train_time:122140ms 
step_avg:98.26ms +step:1244/1670 train_time:122238ms step_avg:98.26ms +step:1245/1670 train_time:122336ms step_avg:98.26ms +step:1246/1670 train_time:122434ms step_avg:98.26ms +step:1247/1670 train_time:122532ms step_avg:98.26ms +step:1248/1670 train_time:122629ms step_avg:98.26ms +step:1249/1670 train_time:122727ms step_avg:98.26ms +step:1250/1670 train_time:122824ms step_avg:98.26ms +step:1250/1670 val_loss:3.3805 train_time:122921ms step_avg:98.34ms +step:1251/1670 train_time:122944ms step_avg:98.28ms +step:1252/1670 train_time:123027ms step_avg:98.26ms +step:1253/1670 train_time:123127ms step_avg:98.27ms +step:1254/1670 train_time:123226ms step_avg:98.27ms +step:1255/1670 train_time:123322ms step_avg:98.26ms +step:1256/1670 train_time:123419ms step_avg:98.26ms +step:1257/1670 train_time:123516ms step_avg:98.26ms +step:1258/1670 train_time:123613ms step_avg:98.26ms +step:1259/1670 train_time:123710ms step_avg:98.26ms +step:1260/1670 train_time:123807ms step_avg:98.26ms +step:1261/1670 train_time:123908ms step_avg:98.26ms +step:1262/1670 train_time:124010ms step_avg:98.26ms +step:1263/1670 train_time:124111ms step_avg:98.27ms +step:1264/1670 train_time:124209ms step_avg:98.27ms +step:1265/1670 train_time:124308ms step_avg:98.27ms +step:1266/1670 train_time:124407ms step_avg:98.27ms +step:1267/1670 train_time:124505ms step_avg:98.27ms +step:1268/1670 train_time:124602ms step_avg:98.27ms +step:1269/1670 train_time:124698ms step_avg:98.27ms +step:1270/1670 train_time:124796ms step_avg:98.26ms +step:1271/1670 train_time:124894ms step_avg:98.26ms +step:1272/1670 train_time:124994ms step_avg:98.27ms +step:1273/1670 train_time:125094ms step_avg:98.27ms +step:1274/1670 train_time:125352ms step_avg:98.39ms +step:1275/1670 train_time:125536ms step_avg:98.46ms +step:1276/1670 train_time:125634ms step_avg:98.46ms +step:1277/1670 train_time:125731ms step_avg:98.46ms +step:1278/1670 train_time:125828ms step_avg:98.46ms +step:1279/1670 train_time:125925ms step_avg:98.46ms +step:1280/1670 train_time:126022ms step_avg:98.45ms +step:1281/1670 train_time:126119ms step_avg:98.45ms +step:1282/1670 train_time:126216ms step_avg:98.45ms +step:1283/1670 train_time:126314ms step_avg:98.45ms +step:1284/1670 train_time:126419ms step_avg:98.46ms +step:1285/1670 train_time:126520ms step_avg:98.46ms +step:1286/1670 train_time:126620ms step_avg:98.46ms +step:1287/1670 train_time:126719ms step_avg:98.46ms +step:1288/1670 train_time:126817ms step_avg:98.46ms +step:1289/1670 train_time:126914ms step_avg:98.46ms +step:1290/1670 train_time:127011ms step_avg:98.46ms +step:1291/1670 train_time:127109ms step_avg:98.46ms +step:1292/1670 train_time:127206ms step_avg:98.46ms +step:1293/1670 train_time:127304ms step_avg:98.46ms +step:1294/1670 train_time:127404ms step_avg:98.46ms +step:1295/1670 train_time:127504ms step_avg:98.46ms +step:1296/1670 train_time:127604ms step_avg:98.46ms +step:1297/1670 train_time:127703ms step_avg:98.46ms +step:1298/1670 train_time:127801ms step_avg:98.46ms +step:1299/1670 train_time:127899ms step_avg:98.46ms +step:1300/1670 train_time:127996ms step_avg:98.46ms +step:1301/1670 train_time:128094ms step_avg:98.46ms +step:1302/1670 train_time:128191ms step_avg:98.46ms +step:1303/1670 train_time:128289ms step_avg:98.46ms +step:1304/1670 train_time:128388ms step_avg:98.46ms +step:1305/1670 train_time:128488ms step_avg:98.46ms +step:1306/1670 train_time:128589ms step_avg:98.46ms +step:1307/1670 train_time:128690ms step_avg:98.46ms +step:1308/1670 train_time:128787ms step_avg:98.46ms +step:1309/1670 
train_time:128886ms step_avg:98.46ms +step:1310/1670 train_time:128986ms step_avg:98.46ms +step:1311/1670 train_time:129086ms step_avg:98.46ms +step:1312/1670 train_time:129183ms step_avg:98.46ms +step:1313/1670 train_time:129280ms step_avg:98.46ms +step:1314/1670 train_time:129378ms step_avg:98.46ms +step:1315/1670 train_time:129475ms step_avg:98.46ms +step:1316/1670 train_time:129575ms step_avg:98.46ms +step:1317/1670 train_time:129674ms step_avg:98.46ms +step:1318/1670 train_time:129774ms step_avg:98.46ms +step:1319/1670 train_time:129874ms step_avg:98.46ms +step:1320/1670 train_time:129973ms step_avg:98.46ms +step:1321/1670 train_time:130072ms step_avg:98.46ms +step:1322/1670 train_time:130170ms step_avg:98.46ms +step:1323/1670 train_time:130268ms step_avg:98.46ms +step:1324/1670 train_time:130366ms step_avg:98.46ms +step:1325/1670 train_time:130465ms step_avg:98.46ms +step:1326/1670 train_time:130563ms step_avg:98.46ms +step:1327/1670 train_time:130661ms step_avg:98.46ms +step:1328/1670 train_time:130759ms step_avg:98.46ms +step:1329/1670 train_time:130857ms step_avg:98.46ms +step:1330/1670 train_time:130957ms step_avg:98.46ms +step:1331/1670 train_time:131055ms step_avg:98.46ms +step:1332/1670 train_time:131154ms step_avg:98.46ms +step:1333/1670 train_time:131253ms step_avg:98.46ms +step:1334/1670 train_time:131351ms step_avg:98.46ms +step:1335/1670 train_time:131449ms step_avg:98.46ms +step:1336/1670 train_time:131547ms step_avg:98.46ms +step:1337/1670 train_time:131646ms step_avg:98.46ms +step:1338/1670 train_time:131746ms step_avg:98.46ms +step:1339/1670 train_time:131847ms step_avg:98.47ms +step:1340/1670 train_time:131947ms step_avg:98.47ms +step:1341/1670 train_time:132047ms step_avg:98.47ms +step:1342/1670 train_time:132145ms step_avg:98.47ms +step:1343/1670 train_time:132244ms step_avg:98.47ms +step:1344/1670 train_time:132341ms step_avg:98.47ms +step:1345/1670 train_time:132439ms step_avg:98.47ms +step:1346/1670 train_time:132536ms step_avg:98.47ms +step:1347/1670 train_time:132635ms step_avg:98.47ms +step:1348/1670 train_time:132733ms step_avg:98.47ms +step:1349/1670 train_time:132833ms step_avg:98.47ms +step:1350/1670 train_time:132932ms step_avg:98.47ms +step:1351/1670 train_time:133031ms step_avg:98.47ms +step:1352/1670 train_time:133129ms step_avg:98.47ms +step:1353/1670 train_time:133229ms step_avg:98.47ms +step:1354/1670 train_time:133327ms step_avg:98.47ms +step:1355/1670 train_time:133426ms step_avg:98.47ms +step:1356/1670 train_time:133524ms step_avg:98.47ms +step:1357/1670 train_time:133623ms step_avg:98.47ms +step:1358/1670 train_time:133721ms step_avg:98.47ms +step:1359/1670 train_time:133820ms step_avg:98.47ms +step:1360/1670 train_time:133918ms step_avg:98.47ms +step:1361/1670 train_time:134016ms step_avg:98.47ms +step:1362/1670 train_time:134115ms step_avg:98.47ms +step:1363/1670 train_time:134213ms step_avg:98.47ms +step:1364/1670 train_time:134311ms step_avg:98.47ms +step:1365/1670 train_time:134409ms step_avg:98.47ms +step:1366/1670 train_time:134508ms step_avg:98.47ms +step:1367/1670 train_time:134606ms step_avg:98.47ms +step:1368/1670 train_time:134705ms step_avg:98.47ms +step:1369/1670 train_time:134803ms step_avg:98.47ms +step:1370/1670 train_time:134903ms step_avg:98.47ms +step:1371/1670 train_time:135001ms step_avg:98.47ms +step:1372/1670 train_time:135100ms step_avg:98.47ms +step:1373/1670 train_time:135199ms step_avg:98.47ms +step:1374/1670 train_time:135297ms step_avg:98.47ms +step:1375/1670 train_time:135396ms step_avg:98.47ms +step:1375/1670 
val_loss:3.3434 train_time:135493ms step_avg:98.54ms +step:1376/1670 train_time:135517ms step_avg:98.49ms +step:1377/1670 train_time:135602ms step_avg:98.48ms +step:1378/1670 train_time:135701ms step_avg:98.48ms +step:1379/1670 train_time:135799ms step_avg:98.48ms +step:1380/1670 train_time:135897ms step_avg:98.48ms +step:1381/1670 train_time:135993ms step_avg:98.47ms +step:1382/1670 train_time:136091ms step_avg:98.47ms +step:1383/1670 train_time:136189ms step_avg:98.47ms +step:1384/1670 train_time:136286ms step_avg:98.47ms +step:1385/1670 train_time:136383ms step_avg:98.47ms +step:1386/1670 train_time:136484ms step_avg:98.47ms +step:1387/1670 train_time:136584ms step_avg:98.47ms +step:1388/1670 train_time:136686ms step_avg:98.48ms +step:1389/1670 train_time:136784ms step_avg:98.48ms +step:1390/1670 train_time:136882ms step_avg:98.48ms +step:1391/1670 train_time:136981ms step_avg:98.48ms +step:1392/1670 train_time:137079ms step_avg:98.48ms +step:1393/1670 train_time:137176ms step_avg:98.48ms +step:1394/1670 train_time:137273ms step_avg:98.47ms +step:1395/1670 train_time:137372ms step_avg:98.47ms +step:1396/1670 train_time:137472ms step_avg:98.48ms +step:1397/1670 train_time:137571ms step_avg:98.48ms +step:1398/1670 train_time:137672ms step_avg:98.48ms +step:1399/1670 train_time:137771ms step_avg:98.48ms +step:1400/1670 train_time:137870ms step_avg:98.48ms +step:1401/1670 train_time:137968ms step_avg:98.48ms +step:1402/1670 train_time:138067ms step_avg:98.48ms +step:1403/1670 train_time:138166ms step_avg:98.48ms +step:1404/1670 train_time:138264ms step_avg:98.48ms +step:1405/1670 train_time:138362ms step_avg:98.48ms +step:1406/1670 train_time:138461ms step_avg:98.48ms +step:1407/1670 train_time:138561ms step_avg:98.48ms +step:1408/1670 train_time:138660ms step_avg:98.48ms +step:1409/1670 train_time:138760ms step_avg:98.48ms +step:1410/1670 train_time:138858ms step_avg:98.48ms +step:1411/1670 train_time:138957ms step_avg:98.48ms +step:1412/1670 train_time:139056ms step_avg:98.48ms +step:1413/1670 train_time:139155ms step_avg:98.48ms +step:1414/1670 train_time:139254ms step_avg:98.48ms +step:1415/1670 train_time:139353ms step_avg:98.48ms +step:1416/1670 train_time:139451ms step_avg:98.48ms +step:1417/1670 train_time:139550ms step_avg:98.48ms +step:1418/1670 train_time:139649ms step_avg:98.48ms +step:1419/1670 train_time:139747ms step_avg:98.48ms +step:1420/1670 train_time:139845ms step_avg:98.48ms +step:1421/1670 train_time:139943ms step_avg:98.48ms +step:1422/1670 train_time:140042ms step_avg:98.48ms +step:1423/1670 train_time:140141ms step_avg:98.48ms +step:1424/1670 train_time:140239ms step_avg:98.48ms +step:1425/1670 train_time:140337ms step_avg:98.48ms +step:1426/1670 train_time:140436ms step_avg:98.48ms +step:1427/1670 train_time:140536ms step_avg:98.48ms +step:1428/1670 train_time:140634ms step_avg:98.48ms +step:1429/1670 train_time:140734ms step_avg:98.48ms +step:1430/1670 train_time:140834ms step_avg:98.49ms +step:1431/1670 train_time:140934ms step_avg:98.49ms +step:1432/1670 train_time:141032ms step_avg:98.49ms +step:1433/1670 train_time:141131ms step_avg:98.49ms +step:1434/1670 train_time:141230ms step_avg:98.49ms +step:1435/1670 train_time:141327ms step_avg:98.49ms +step:1436/1670 train_time:141425ms step_avg:98.49ms +step:1437/1670 train_time:141522ms step_avg:98.48ms +step:1438/1670 train_time:141620ms step_avg:98.48ms +step:1439/1670 train_time:141718ms step_avg:98.48ms +step:1440/1670 train_time:141817ms step_avg:98.48ms +step:1441/1670 train_time:141917ms step_avg:98.49ms 
+step:1442/1670 train_time:142017ms step_avg:98.49ms +step:1443/1670 train_time:142115ms step_avg:98.49ms +step:1444/1670 train_time:142215ms step_avg:98.49ms +step:1445/1670 train_time:142316ms step_avg:98.49ms +step:1446/1670 train_time:142415ms step_avg:98.49ms +step:1447/1670 train_time:142514ms step_avg:98.49ms +step:1448/1670 train_time:142613ms step_avg:98.49ms +step:1449/1670 train_time:142712ms step_avg:98.49ms +step:1450/1670 train_time:142810ms step_avg:98.49ms +step:1451/1670 train_time:142907ms step_avg:98.49ms +step:1452/1670 train_time:143007ms step_avg:98.49ms +step:1453/1670 train_time:143106ms step_avg:98.49ms +step:1454/1670 train_time:143203ms step_avg:98.49ms +step:1455/1670 train_time:143302ms step_avg:98.49ms +step:1456/1670 train_time:143401ms step_avg:98.49ms +step:1457/1670 train_time:143500ms step_avg:98.49ms +step:1458/1670 train_time:143598ms step_avg:98.49ms +step:1459/1670 train_time:143698ms step_avg:98.49ms +step:1460/1670 train_time:143797ms step_avg:98.49ms +step:1461/1670 train_time:143895ms step_avg:98.49ms +step:1462/1670 train_time:143994ms step_avg:98.49ms +step:1463/1670 train_time:144094ms step_avg:98.49ms +step:1464/1670 train_time:144193ms step_avg:98.49ms +step:1465/1670 train_time:144293ms step_avg:98.49ms +step:1466/1670 train_time:144392ms step_avg:98.49ms +step:1467/1670 train_time:144490ms step_avg:98.49ms +step:1468/1670 train_time:144588ms step_avg:98.49ms +step:1469/1670 train_time:144686ms step_avg:98.49ms +step:1470/1670 train_time:144784ms step_avg:98.49ms +step:1471/1670 train_time:144882ms step_avg:98.49ms +step:1472/1670 train_time:144980ms step_avg:98.49ms +step:1473/1670 train_time:145079ms step_avg:98.49ms +step:1474/1670 train_time:145179ms step_avg:98.49ms +step:1475/1670 train_time:145278ms step_avg:98.49ms +step:1476/1670 train_time:145378ms step_avg:98.49ms +step:1477/1670 train_time:145477ms step_avg:98.49ms +step:1478/1670 train_time:145576ms step_avg:98.50ms +step:1479/1670 train_time:145675ms step_avg:98.50ms +step:1480/1670 train_time:145774ms step_avg:98.50ms +step:1481/1670 train_time:145876ms step_avg:98.50ms +step:1482/1670 train_time:145975ms step_avg:98.50ms +step:1483/1670 train_time:146074ms step_avg:98.50ms +step:1484/1670 train_time:146172ms step_avg:98.50ms +step:1485/1670 train_time:146500ms step_avg:98.65ms +step:1486/1670 train_time:146574ms step_avg:98.64ms +step:1487/1670 train_time:146670ms step_avg:98.63ms +step:1488/1670 train_time:146766ms step_avg:98.63ms +step:1489/1670 train_time:146863ms step_avg:98.63ms +step:1490/1670 train_time:146961ms step_avg:98.63ms +step:1491/1670 train_time:147058ms step_avg:98.63ms +step:1492/1670 train_time:147156ms step_avg:98.63ms +step:1493/1670 train_time:147253ms step_avg:98.63ms +step:1494/1670 train_time:147351ms step_avg:98.63ms +step:1495/1670 train_time:147451ms step_avg:98.63ms +step:1496/1670 train_time:147554ms step_avg:98.63ms +step:1497/1670 train_time:147655ms step_avg:98.63ms +step:1498/1670 train_time:147753ms step_avg:98.63ms +step:1499/1670 train_time:147852ms step_avg:98.63ms +step:1500/1670 train_time:147950ms step_avg:98.63ms +step:1500/1670 val_loss:3.3113 train_time:148046ms step_avg:98.70ms +step:1501/1670 train_time:148069ms step_avg:98.65ms +step:1502/1670 train_time:148151ms step_avg:98.64ms +step:1503/1670 train_time:148251ms step_avg:98.64ms +step:1504/1670 train_time:148348ms step_avg:98.64ms +step:1505/1670 train_time:148445ms step_avg:98.63ms +step:1506/1670 train_time:148541ms step_avg:98.63ms +step:1507/1670 train_time:148638ms 
step_avg:98.63ms +step:1508/1670 train_time:148736ms step_avg:98.63ms +step:1509/1670 train_time:148833ms step_avg:98.63ms +step:1510/1670 train_time:148931ms step_avg:98.63ms +step:1511/1670 train_time:149031ms step_avg:98.63ms +step:1512/1670 train_time:149132ms step_avg:98.63ms +step:1513/1670 train_time:149232ms step_avg:98.63ms +step:1514/1670 train_time:149331ms step_avg:98.63ms +step:1515/1670 train_time:149430ms step_avg:98.63ms +step:1516/1670 train_time:149528ms step_avg:98.63ms +step:1517/1670 train_time:149625ms step_avg:98.63ms +step:1518/1670 train_time:149722ms step_avg:98.63ms +step:1519/1670 train_time:149820ms step_avg:98.63ms +step:1520/1670 train_time:149918ms step_avg:98.63ms +step:1521/1670 train_time:150018ms step_avg:98.63ms +step:1522/1670 train_time:150118ms step_avg:98.63ms +step:1523/1670 train_time:150218ms step_avg:98.63ms +step:1524/1670 train_time:150318ms step_avg:98.63ms +step:1525/1670 train_time:150419ms step_avg:98.64ms +step:1526/1670 train_time:150518ms step_avg:98.64ms +step:1527/1670 train_time:150618ms step_avg:98.64ms +step:1528/1670 train_time:150716ms step_avg:98.64ms +step:1529/1670 train_time:150813ms step_avg:98.64ms +step:1530/1670 train_time:150910ms step_avg:98.63ms +step:1531/1670 train_time:151008ms step_avg:98.63ms +step:1532/1670 train_time:151105ms step_avg:98.63ms +step:1533/1670 train_time:151204ms step_avg:98.63ms +step:1534/1670 train_time:151303ms step_avg:98.63ms +step:1535/1670 train_time:151401ms step_avg:98.63ms +step:1536/1670 train_time:151500ms step_avg:98.63ms +step:1537/1670 train_time:151599ms step_avg:98.63ms +step:1538/1670 train_time:151697ms step_avg:98.63ms +step:1539/1670 train_time:151795ms step_avg:98.63ms +step:1540/1670 train_time:151894ms step_avg:98.63ms +step:1541/1670 train_time:151994ms step_avg:98.63ms +step:1542/1670 train_time:152094ms step_avg:98.63ms +step:1543/1670 train_time:152194ms step_avg:98.64ms +step:1544/1670 train_time:152294ms step_avg:98.64ms +step:1545/1670 train_time:152394ms step_avg:98.64ms +step:1546/1670 train_time:152491ms step_avg:98.64ms +step:1547/1670 train_time:152589ms step_avg:98.64ms +step:1548/1670 train_time:152687ms step_avg:98.64ms +step:1549/1670 train_time:152785ms step_avg:98.63ms +step:1550/1670 train_time:152882ms step_avg:98.63ms +step:1551/1670 train_time:152981ms step_avg:98.63ms +step:1552/1670 train_time:153080ms step_avg:98.63ms +step:1553/1670 train_time:153179ms step_avg:98.63ms +step:1554/1670 train_time:153278ms step_avg:98.63ms +step:1555/1670 train_time:153378ms step_avg:98.64ms +step:1556/1670 train_time:153477ms step_avg:98.64ms +step:1557/1670 train_time:153576ms step_avg:98.64ms +step:1558/1670 train_time:153674ms step_avg:98.64ms +step:1559/1670 train_time:153772ms step_avg:98.64ms +step:1560/1670 train_time:153871ms step_avg:98.64ms +step:1561/1670 train_time:153968ms step_avg:98.63ms +step:1562/1670 train_time:154066ms step_avg:98.63ms +step:1563/1670 train_time:154164ms step_avg:98.63ms +step:1564/1670 train_time:154262ms step_avg:98.63ms +step:1565/1670 train_time:154362ms step_avg:98.63ms +step:1566/1670 train_time:154461ms step_avg:98.63ms +step:1567/1670 train_time:154559ms step_avg:98.63ms +step:1568/1670 train_time:154658ms step_avg:98.63ms +step:1569/1670 train_time:154757ms step_avg:98.63ms +step:1570/1670 train_time:154855ms step_avg:98.63ms +step:1571/1670 train_time:154953ms step_avg:98.63ms +step:1572/1670 train_time:155053ms step_avg:98.63ms +step:1573/1670 train_time:155151ms step_avg:98.63ms +step:1574/1670 train_time:155250ms 
step_avg:98.63ms +step:1575/1670 train_time:155350ms step_avg:98.63ms +step:1576/1670 train_time:155447ms step_avg:98.63ms +step:1577/1670 train_time:155545ms step_avg:98.63ms +step:1578/1670 train_time:155643ms step_avg:98.63ms +step:1579/1670 train_time:155742ms step_avg:98.63ms +step:1580/1670 train_time:155841ms step_avg:98.63ms +step:1581/1670 train_time:155941ms step_avg:98.63ms +step:1582/1670 train_time:156039ms step_avg:98.63ms +step:1583/1670 train_time:156138ms step_avg:98.63ms +step:1584/1670 train_time:156236ms step_avg:98.63ms +step:1585/1670 train_time:156335ms step_avg:98.63ms +step:1586/1670 train_time:156436ms step_avg:98.64ms +step:1587/1670 train_time:156536ms step_avg:98.64ms +step:1588/1670 train_time:156635ms step_avg:98.64ms +step:1589/1670 train_time:156735ms step_avg:98.64ms +step:1590/1670 train_time:156833ms step_avg:98.64ms +step:1591/1670 train_time:156932ms step_avg:98.64ms +step:1592/1670 train_time:157030ms step_avg:98.64ms +step:1593/1670 train_time:157129ms step_avg:98.64ms +step:1594/1670 train_time:157226ms step_avg:98.64ms +step:1595/1670 train_time:157323ms step_avg:98.64ms +step:1596/1670 train_time:157421ms step_avg:98.63ms +step:1597/1670 train_time:157520ms step_avg:98.63ms +step:1598/1670 train_time:157619ms step_avg:98.64ms +step:1599/1670 train_time:157719ms step_avg:98.64ms +step:1600/1670 train_time:157816ms step_avg:98.64ms +step:1601/1670 train_time:157917ms step_avg:98.64ms +step:1602/1670 train_time:158016ms step_avg:98.64ms +step:1603/1670 train_time:158115ms step_avg:98.64ms +step:1604/1670 train_time:158214ms step_avg:98.64ms +step:1605/1670 train_time:158315ms step_avg:98.64ms +step:1606/1670 train_time:158415ms step_avg:98.64ms +step:1607/1670 train_time:158514ms step_avg:98.64ms +step:1608/1670 train_time:158614ms step_avg:98.64ms +step:1609/1670 train_time:158712ms step_avg:98.64ms +step:1610/1670 train_time:158810ms step_avg:98.64ms +step:1611/1670 train_time:158909ms step_avg:98.64ms +step:1612/1670 train_time:159006ms step_avg:98.64ms +step:1613/1670 train_time:159104ms step_avg:98.64ms +step:1614/1670 train_time:159202ms step_avg:98.64ms +step:1615/1670 train_time:159300ms step_avg:98.64ms +step:1616/1670 train_time:159399ms step_avg:98.64ms +step:1617/1670 train_time:159498ms step_avg:98.64ms +step:1618/1670 train_time:159598ms step_avg:98.64ms +step:1619/1670 train_time:159698ms step_avg:98.64ms +step:1620/1670 train_time:159798ms step_avg:98.64ms +step:1621/1670 train_time:159897ms step_avg:98.64ms +step:1622/1670 train_time:159996ms step_avg:98.64ms +step:1623/1670 train_time:160095ms step_avg:98.64ms +step:1624/1670 train_time:160196ms step_avg:98.64ms +step:1625/1670 train_time:160296ms step_avg:98.64ms +step:1625/1670 val_loss:3.2842 train_time:160393ms step_avg:98.70ms +step:1626/1670 train_time:160416ms step_avg:98.66ms +step:1627/1670 train_time:160497ms step_avg:98.65ms +step:1628/1670 train_time:160596ms step_avg:98.65ms +step:1629/1670 train_time:160694ms step_avg:98.65ms +step:1630/1670 train_time:160792ms step_avg:98.65ms +step:1631/1670 train_time:160890ms step_avg:98.64ms +step:1632/1670 train_time:160987ms step_avg:98.64ms +step:1633/1670 train_time:161084ms step_avg:98.64ms +step:1634/1670 train_time:161182ms step_avg:98.64ms +step:1635/1670 train_time:161279ms step_avg:98.64ms +step:1636/1670 train_time:161378ms step_avg:98.64ms +step:1637/1670 train_time:161478ms step_avg:98.64ms +step:1638/1670 train_time:161577ms step_avg:98.64ms +step:1639/1670 train_time:161676ms step_avg:98.64ms +step:1640/1670 
train_time:161774ms step_avg:98.64ms +step:1641/1670 train_time:161872ms step_avg:98.64ms +step:1642/1670 train_time:161970ms step_avg:98.64ms +step:1643/1670 train_time:162068ms step_avg:98.64ms +step:1644/1670 train_time:162166ms step_avg:98.64ms +step:1645/1670 train_time:162265ms step_avg:98.64ms +step:1646/1670 train_time:162364ms step_avg:98.64ms +step:1647/1670 train_time:162464ms step_avg:98.64ms +step:1648/1670 train_time:162565ms step_avg:98.64ms +step:1649/1670 train_time:162666ms step_avg:98.65ms +step:1650/1670 train_time:162767ms step_avg:98.65ms +step:1651/1670 train_time:162865ms step_avg:98.65ms +step:1652/1670 train_time:162964ms step_avg:98.65ms +step:1653/1670 train_time:163063ms step_avg:98.65ms +step:1654/1670 train_time:163160ms step_avg:98.65ms +step:1655/1670 train_time:163257ms step_avg:98.64ms +step:1656/1670 train_time:163354ms step_avg:98.64ms +step:1657/1670 train_time:163453ms step_avg:98.64ms +step:1658/1670 train_time:163553ms step_avg:98.64ms +step:1659/1670 train_time:163653ms step_avg:98.65ms +step:1660/1670 train_time:163752ms step_avg:98.65ms +step:1661/1670 train_time:163850ms step_avg:98.65ms +step:1662/1670 train_time:163949ms step_avg:98.65ms +step:1663/1670 train_time:164047ms step_avg:98.65ms +step:1664/1670 train_time:164146ms step_avg:98.65ms +step:1665/1670 train_time:164244ms step_avg:98.64ms +step:1666/1670 train_time:164341ms step_avg:98.64ms +step:1667/1670 train_time:164440ms step_avg:98.64ms +step:1668/1670 train_time:164539ms step_avg:98.64ms +step:1669/1670 train_time:164639ms step_avg:98.65ms +step:1670/1670 train_time:164737ms step_avg:98.64ms +step:1670/1670 val_loss:3.2768 train_time:164833ms step_avg:98.70ms +peak memory allocated: 34073 MiB reserved: 49756 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt b/records/050925_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt new file mode 100644 index 000000000..58bbe8537 --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, 
dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + 
c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels 
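+ # For reference: newton_schulz_triton below chains ns_line_1 (A = X @ X.T), this
+ # kernel (B = beta*A + alpha*(A @ A), valid since A is symmetric so A @ A.T == A @ A),
+ # and one addmm/baddbmm (C = a*X + B @ X) to form a single quintic Newton-Schulz
+ # iteration X <- a*X + b*(X @ X.T) @ X + c*(X @ X.T)^2 @ X.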
+ pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
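+ # Sharding scheme, per group of same-shape params: (1) async reduce_scatter leaves
+ # rank r holding the world-averaged gradient of params[base_i + r]; (2) each rank
+ # applies weight decay, momentum, and Newton-Schulz orthogonalization to the one
+ # parameter it owns in each chunk of world_size params; (3) async all_gather
+ # redistributes the updated parameters to all ranks.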
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
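# Both Adam moment buffers are sized to the rank-local slice (p_slice), so the
+ # optimizer state is sharded across ranks rather than replicated on every GPU. +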
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:49:16 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 74357 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 74358 C /usr/bin/python3 610MiB | +| 0 N/A N/A 74359 C /usr/bin/python3 610MiB | +| 0 N/A N/A 74360 C /usr/bin/python3 610MiB | +| 0 N/A N/A 74361 C /usr/bin/python3 610MiB | +| 0 N/A N/A 74362 C /usr/bin/python3 610MiB | +| 0 N/A N/A 74363 C /usr/bin/python3 610MiB | +| 0 N/A N/A 74364 C /usr/bin/python3 610MiB | +| 1 N/A N/A 74358 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 74359 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 74360 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 74361 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 74362 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 74363 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 74364 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:380ms step_avg:380.20ms +step:2/1670 train_time:402ms step_avg:201.05ms +step:3/1670 train_time:475ms step_avg:158.24ms +step:4/1670 train_time:568ms step_avg:142.12ms +step:5/1670 train_time:663ms step_avg:132.63ms +step:6/1670 train_time:759ms step_avg:126.42ms +step:7/1670 train_time:854ms step_avg:122.05ms +step:8/1670 
train_time:949ms step_avg:118.60ms +step:9/1670 train_time:1044ms step_avg:116.01ms +step:10/1670 train_time:1140ms step_avg:114.00ms +step:11/1670 train_time:1235ms step_avg:112.28ms +step:12/1670 train_time:1332ms step_avg:111.00ms +step:13/1670 train_time:1429ms step_avg:109.94ms +step:14/1670 train_time:1526ms step_avg:108.99ms +step:15/1670 train_time:1622ms step_avg:108.16ms +step:16/1670 train_time:1718ms step_avg:107.39ms +step:17/1670 train_time:1814ms step_avg:106.72ms +step:18/1670 train_time:1909ms step_avg:106.08ms +step:19/1670 train_time:2004ms step_avg:105.49ms +step:20/1670 train_time:2100ms step_avg:105.00ms +step:21/1670 train_time:2196ms step_avg:104.56ms +step:22/1670 train_time:2292ms step_avg:104.17ms +step:23/1670 train_time:2388ms step_avg:103.84ms +step:24/1670 train_time:2485ms step_avg:103.52ms +step:25/1670 train_time:2582ms step_avg:103.27ms +step:26/1670 train_time:2679ms step_avg:103.04ms +step:27/1670 train_time:2776ms step_avg:102.82ms +step:28/1670 train_time:2871ms step_avg:102.55ms +step:29/1670 train_time:2967ms step_avg:102.30ms +step:30/1670 train_time:3062ms step_avg:102.07ms +step:31/1670 train_time:3158ms step_avg:101.86ms +step:32/1670 train_time:3254ms step_avg:101.68ms +step:33/1670 train_time:3350ms step_avg:101.52ms +step:34/1670 train_time:3446ms step_avg:101.36ms +step:35/1670 train_time:3543ms step_avg:101.21ms +step:36/1670 train_time:3640ms step_avg:101.10ms +step:37/1670 train_time:3736ms step_avg:100.97ms +step:38/1670 train_time:3832ms step_avg:100.84ms +step:39/1670 train_time:3927ms step_avg:100.70ms +step:40/1670 train_time:4022ms step_avg:100.56ms +step:41/1670 train_time:4118ms step_avg:100.44ms +step:42/1670 train_time:4214ms step_avg:100.32ms +step:43/1670 train_time:4309ms step_avg:100.21ms +step:44/1670 train_time:4405ms step_avg:100.12ms +step:45/1670 train_time:4501ms step_avg:100.03ms +step:46/1670 train_time:4599ms step_avg:99.98ms +step:47/1670 train_time:4695ms step_avg:99.90ms +step:48/1670 train_time:4792ms step_avg:99.84ms +step:49/1670 train_time:4888ms step_avg:99.75ms +step:50/1670 train_time:4984ms step_avg:99.68ms +step:51/1670 train_time:5079ms step_avg:99.60ms +step:52/1670 train_time:5175ms step_avg:99.53ms +step:53/1670 train_time:5271ms step_avg:99.45ms +step:54/1670 train_time:5366ms step_avg:99.37ms +step:55/1670 train_time:5462ms step_avg:99.32ms +step:56/1670 train_time:5560ms step_avg:99.29ms +step:57/1670 train_time:5656ms step_avg:99.23ms +step:58/1670 train_time:5753ms step_avg:99.20ms +step:59/1670 train_time:5849ms step_avg:99.14ms +step:60/1670 train_time:5945ms step_avg:99.08ms +step:61/1670 train_time:6041ms step_avg:99.03ms +step:62/1670 train_time:6137ms step_avg:98.99ms +step:63/1670 train_time:6234ms step_avg:98.95ms +step:64/1670 train_time:6329ms step_avg:98.89ms +step:65/1670 train_time:6424ms step_avg:98.84ms +step:66/1670 train_time:6522ms step_avg:98.82ms +step:67/1670 train_time:6618ms step_avg:98.78ms +step:68/1670 train_time:6716ms step_avg:98.77ms +step:69/1670 train_time:6813ms step_avg:98.74ms +step:70/1670 train_time:6907ms step_avg:98.68ms +step:71/1670 train_time:7002ms step_avg:98.63ms +step:72/1670 train_time:7098ms step_avg:98.59ms +step:73/1670 train_time:7195ms step_avg:98.56ms +step:74/1670 train_time:7291ms step_avg:98.53ms +step:75/1670 train_time:7386ms step_avg:98.48ms +step:76/1670 train_time:7482ms step_avg:98.45ms +step:77/1670 train_time:7578ms step_avg:98.41ms +step:78/1670 train_time:7674ms step_avg:98.38ms +step:79/1670 train_time:7769ms step_avg:98.34ms 
+step:80/1670 train_time:7865ms step_avg:98.31ms +step:81/1670 train_time:7962ms step_avg:98.30ms +step:82/1670 train_time:8059ms step_avg:98.28ms +step:83/1670 train_time:8155ms step_avg:98.25ms +step:84/1670 train_time:8250ms step_avg:98.21ms +step:85/1670 train_time:8345ms step_avg:98.18ms +step:86/1670 train_time:8442ms step_avg:98.16ms +step:87/1670 train_time:8537ms step_avg:98.13ms +step:88/1670 train_time:8633ms step_avg:98.11ms +step:89/1670 train_time:8729ms step_avg:98.07ms +step:90/1670 train_time:8824ms step_avg:98.05ms +step:91/1670 train_time:8922ms step_avg:98.04ms +step:92/1670 train_time:9018ms step_avg:98.02ms +step:93/1670 train_time:9115ms step_avg:98.01ms +step:94/1670 train_time:9209ms step_avg:97.97ms +step:95/1670 train_time:9305ms step_avg:97.95ms +step:96/1670 train_time:9401ms step_avg:97.92ms +step:97/1670 train_time:9496ms step_avg:97.90ms +step:98/1670 train_time:9593ms step_avg:97.88ms +step:99/1670 train_time:9688ms step_avg:97.86ms +step:100/1670 train_time:9784ms step_avg:97.84ms +step:101/1670 train_time:9881ms step_avg:97.83ms +step:102/1670 train_time:9977ms step_avg:97.81ms +step:103/1670 train_time:10072ms step_avg:97.79ms +step:104/1670 train_time:10168ms step_avg:97.77ms +step:105/1670 train_time:10263ms step_avg:97.74ms +step:106/1670 train_time:10359ms step_avg:97.73ms +step:107/1670 train_time:10455ms step_avg:97.71ms +step:108/1670 train_time:10551ms step_avg:97.69ms +step:109/1670 train_time:10646ms step_avg:97.67ms +step:110/1670 train_time:10741ms step_avg:97.65ms +step:111/1670 train_time:10837ms step_avg:97.63ms +step:112/1670 train_time:10932ms step_avg:97.61ms +step:113/1670 train_time:11028ms step_avg:97.60ms +step:114/1670 train_time:11124ms step_avg:97.58ms +step:115/1670 train_time:11220ms step_avg:97.57ms +step:116/1670 train_time:11317ms step_avg:97.56ms +step:117/1670 train_time:11412ms step_avg:97.54ms +step:118/1670 train_time:11508ms step_avg:97.52ms +step:119/1670 train_time:11604ms step_avg:97.51ms +step:120/1670 train_time:11700ms step_avg:97.50ms +step:121/1670 train_time:11796ms step_avg:97.49ms +step:122/1670 train_time:11891ms step_avg:97.47ms +step:123/1670 train_time:11987ms step_avg:97.45ms +step:124/1670 train_time:12083ms step_avg:97.44ms +step:125/1670 train_time:12179ms step_avg:97.44ms +step:125/1670 val_loss:4.2999 train_time:12274ms step_avg:98.19ms +step:126/1670 train_time:12298ms step_avg:97.60ms +step:127/1670 train_time:12382ms step_avg:97.50ms +step:128/1670 train_time:12484ms step_avg:97.53ms +step:129/1670 train_time:12582ms step_avg:97.53ms +step:130/1670 train_time:12677ms step_avg:97.52ms +step:131/1670 train_time:12772ms step_avg:97.50ms +step:132/1670 train_time:12868ms step_avg:97.48ms +step:133/1670 train_time:12962ms step_avg:97.46ms +step:134/1670 train_time:13057ms step_avg:97.44ms +step:135/1670 train_time:13152ms step_avg:97.42ms +step:136/1670 train_time:13247ms step_avg:97.40ms +step:137/1670 train_time:13344ms step_avg:97.40ms +step:138/1670 train_time:13442ms step_avg:97.41ms +step:139/1670 train_time:13541ms step_avg:97.42ms +step:140/1670 train_time:13638ms step_avg:97.41ms +step:141/1670 train_time:13734ms step_avg:97.40ms +step:142/1670 train_time:13829ms step_avg:97.39ms +step:143/1670 train_time:13924ms step_avg:97.37ms +step:144/1670 train_time:14019ms step_avg:97.35ms +step:145/1670 train_time:14114ms step_avg:97.34ms +step:146/1670 train_time:14210ms step_avg:97.33ms +step:147/1670 train_time:14305ms step_avg:97.31ms +step:148/1670 train_time:14402ms step_avg:97.31ms 
+step:149/1670 train_time:14500ms step_avg:97.31ms +step:150/1670 train_time:14597ms step_avg:97.31ms +step:151/1670 train_time:14693ms step_avg:97.30ms +step:152/1670 train_time:14788ms step_avg:97.29ms +step:153/1670 train_time:14884ms step_avg:97.28ms +step:154/1670 train_time:14979ms step_avg:97.27ms +step:155/1670 train_time:15074ms step_avg:97.25ms +step:156/1670 train_time:15170ms step_avg:97.25ms +step:157/1670 train_time:15265ms step_avg:97.23ms +step:158/1670 train_time:15361ms step_avg:97.22ms +step:159/1670 train_time:15458ms step_avg:97.22ms +step:160/1670 train_time:15555ms step_avg:97.22ms +step:161/1670 train_time:15651ms step_avg:97.21ms +step:162/1670 train_time:15746ms step_avg:97.20ms +step:163/1670 train_time:15841ms step_avg:97.19ms +step:164/1670 train_time:15938ms step_avg:97.18ms +step:165/1670 train_time:16032ms step_avg:97.16ms +step:166/1670 train_time:16127ms step_avg:97.15ms +step:167/1670 train_time:16223ms step_avg:97.14ms +step:168/1670 train_time:16319ms step_avg:97.14ms +step:169/1670 train_time:16415ms step_avg:97.13ms +step:170/1670 train_time:16512ms step_avg:97.13ms +step:171/1670 train_time:16608ms step_avg:97.12ms +step:172/1670 train_time:16704ms step_avg:97.12ms +step:173/1670 train_time:16800ms step_avg:97.11ms +step:174/1670 train_time:16897ms step_avg:97.11ms +step:175/1670 train_time:16993ms step_avg:97.10ms +step:176/1670 train_time:17087ms step_avg:97.09ms +step:177/1670 train_time:17183ms step_avg:97.08ms +step:178/1670 train_time:17279ms step_avg:97.07ms +step:179/1670 train_time:17375ms step_avg:97.07ms +step:180/1670 train_time:17471ms step_avg:97.06ms +step:181/1670 train_time:17565ms step_avg:97.05ms +step:182/1670 train_time:17662ms step_avg:97.04ms +step:183/1670 train_time:17758ms step_avg:97.04ms +step:184/1670 train_time:17854ms step_avg:97.03ms +step:185/1670 train_time:17949ms step_avg:97.02ms +step:186/1670 train_time:18044ms step_avg:97.01ms +step:187/1670 train_time:18140ms step_avg:97.00ms +step:188/1670 train_time:18235ms step_avg:97.00ms +step:189/1670 train_time:18330ms step_avg:96.99ms +step:190/1670 train_time:18425ms step_avg:96.98ms +step:191/1670 train_time:18521ms step_avg:96.97ms +step:192/1670 train_time:18618ms step_avg:96.97ms +step:193/1670 train_time:18715ms step_avg:96.97ms +step:194/1670 train_time:18811ms step_avg:96.97ms +step:195/1670 train_time:18907ms step_avg:96.96ms +step:196/1670 train_time:19003ms step_avg:96.95ms +step:197/1670 train_time:19099ms step_avg:96.95ms +step:198/1670 train_time:19195ms step_avg:96.94ms +step:199/1670 train_time:19290ms step_avg:96.94ms +step:200/1670 train_time:19385ms step_avg:96.93ms +step:201/1670 train_time:19481ms step_avg:96.92ms +step:202/1670 train_time:19578ms step_avg:96.92ms +step:203/1670 train_time:19674ms step_avg:96.92ms +step:204/1670 train_time:19769ms step_avg:96.91ms +step:205/1670 train_time:19865ms step_avg:96.90ms +step:206/1670 train_time:19961ms step_avg:96.90ms +step:207/1670 train_time:20057ms step_avg:96.89ms +step:208/1670 train_time:20153ms step_avg:96.89ms +step:209/1670 train_time:20248ms step_avg:96.88ms +step:210/1670 train_time:20343ms step_avg:96.87ms +step:211/1670 train_time:20439ms step_avg:96.87ms +step:212/1670 train_time:20535ms step_avg:96.86ms +step:213/1670 train_time:20810ms step_avg:97.70ms +step:214/1670 train_time:20968ms step_avg:97.98ms +step:215/1670 train_time:21062ms step_avg:97.96ms +step:216/1670 train_time:21157ms step_avg:97.95ms +step:217/1670 train_time:21252ms step_avg:97.93ms +step:218/1670 train_time:21346ms 
step_avg:97.92ms +step:219/1670 train_time:21441ms step_avg:97.90ms +step:220/1670 train_time:21537ms step_avg:97.89ms +step:221/1670 train_time:21631ms step_avg:97.88ms +step:222/1670 train_time:21725ms step_avg:97.86ms +step:223/1670 train_time:21823ms step_avg:97.86ms +step:224/1670 train_time:21922ms step_avg:97.87ms +step:225/1670 train_time:22021ms step_avg:97.87ms +step:226/1670 train_time:22116ms step_avg:97.86ms +step:227/1670 train_time:22211ms step_avg:97.85ms +step:228/1670 train_time:22307ms step_avg:97.84ms +step:229/1670 train_time:22401ms step_avg:97.82ms +step:230/1670 train_time:22496ms step_avg:97.81ms +step:231/1670 train_time:22592ms step_avg:97.80ms +step:232/1670 train_time:22687ms step_avg:97.79ms +step:233/1670 train_time:22782ms step_avg:97.78ms +step:234/1670 train_time:22880ms step_avg:97.78ms +step:235/1670 train_time:22977ms step_avg:97.78ms +step:236/1670 train_time:23074ms step_avg:97.77ms +step:237/1670 train_time:23170ms step_avg:97.76ms +step:238/1670 train_time:23265ms step_avg:97.75ms +step:239/1670 train_time:23360ms step_avg:97.74ms +step:240/1670 train_time:23455ms step_avg:97.73ms +step:241/1670 train_time:23550ms step_avg:97.72ms +step:242/1670 train_time:23645ms step_avg:97.71ms +step:243/1670 train_time:23740ms step_avg:97.69ms +step:244/1670 train_time:23836ms step_avg:97.69ms +step:245/1670 train_time:23932ms step_avg:97.68ms +step:246/1670 train_time:24028ms step_avg:97.67ms +step:247/1670 train_time:24124ms step_avg:97.67ms +step:248/1670 train_time:24219ms step_avg:97.66ms +step:249/1670 train_time:24316ms step_avg:97.65ms +step:250/1670 train_time:24411ms step_avg:97.64ms +step:250/1670 val_loss:3.9672 train_time:24505ms step_avg:98.02ms +step:251/1670 train_time:24529ms step_avg:97.73ms +step:252/1670 train_time:24607ms step_avg:97.65ms +step:253/1670 train_time:24706ms step_avg:97.65ms +step:254/1670 train_time:24801ms step_avg:97.64ms +step:255/1670 train_time:24896ms step_avg:97.63ms +step:256/1670 train_time:24991ms step_avg:97.62ms +step:257/1670 train_time:25086ms step_avg:97.61ms +step:258/1670 train_time:25180ms step_avg:97.60ms +step:259/1670 train_time:25275ms step_avg:97.59ms +step:260/1670 train_time:25370ms step_avg:97.58ms +step:261/1670 train_time:25465ms step_avg:97.57ms +step:262/1670 train_time:25563ms step_avg:97.57ms +step:263/1670 train_time:25659ms step_avg:97.56ms +step:264/1670 train_time:25757ms step_avg:97.56ms +step:265/1670 train_time:25853ms step_avg:97.56ms +step:266/1670 train_time:25948ms step_avg:97.55ms +step:267/1670 train_time:26043ms step_avg:97.54ms +step:268/1670 train_time:26139ms step_avg:97.53ms +step:269/1670 train_time:26234ms step_avg:97.52ms +step:270/1670 train_time:26329ms step_avg:97.51ms +step:271/1670 train_time:26424ms step_avg:97.51ms +step:272/1670 train_time:26520ms step_avg:97.50ms +step:273/1670 train_time:26616ms step_avg:97.50ms +step:274/1670 train_time:26714ms step_avg:97.50ms +step:275/1670 train_time:26810ms step_avg:97.49ms +step:276/1670 train_time:26905ms step_avg:97.48ms +step:277/1670 train_time:27002ms step_avg:97.48ms +step:278/1670 train_time:27095ms step_avg:97.47ms +step:279/1670 train_time:27191ms step_avg:97.46ms +step:280/1670 train_time:27286ms step_avg:97.45ms +step:281/1670 train_time:27382ms step_avg:97.44ms +step:282/1670 train_time:27477ms step_avg:97.44ms +step:283/1670 train_time:27573ms step_avg:97.43ms +step:284/1670 train_time:27669ms step_avg:97.43ms +step:285/1670 train_time:27765ms step_avg:97.42ms +step:286/1670 train_time:27861ms step_avg:97.41ms 
+step:287/1670 train_time:27956ms step_avg:97.41ms +step:288/1670 train_time:28053ms step_avg:97.40ms +step:289/1670 train_time:28148ms step_avg:97.40ms +step:290/1670 train_time:28242ms step_avg:97.39ms +step:291/1670 train_time:28338ms step_avg:97.38ms +step:292/1670 train_time:28434ms step_avg:97.38ms +step:293/1670 train_time:28530ms step_avg:97.37ms +step:294/1670 train_time:28625ms step_avg:97.36ms +step:295/1670 train_time:28720ms step_avg:97.36ms +step:296/1670 train_time:28815ms step_avg:97.35ms +step:297/1670 train_time:28912ms step_avg:97.35ms +step:298/1670 train_time:29007ms step_avg:97.34ms +step:299/1670 train_time:29102ms step_avg:97.33ms +step:300/1670 train_time:29197ms step_avg:97.32ms +step:301/1670 train_time:29293ms step_avg:97.32ms +step:302/1670 train_time:29389ms step_avg:97.31ms +step:303/1670 train_time:29484ms step_avg:97.31ms +step:304/1670 train_time:29580ms step_avg:97.30ms +step:305/1670 train_time:29676ms step_avg:97.30ms +step:306/1670 train_time:29773ms step_avg:97.30ms +step:307/1670 train_time:29868ms step_avg:97.29ms +step:308/1670 train_time:29963ms step_avg:97.28ms +step:309/1670 train_time:30059ms step_avg:97.28ms +step:310/1670 train_time:30155ms step_avg:97.27ms +step:311/1670 train_time:30251ms step_avg:97.27ms +step:312/1670 train_time:30346ms step_avg:97.26ms +step:313/1670 train_time:30441ms step_avg:97.26ms +step:314/1670 train_time:30536ms step_avg:97.25ms +step:315/1670 train_time:30633ms step_avg:97.25ms +step:316/1670 train_time:30729ms step_avg:97.24ms +step:317/1670 train_time:30824ms step_avg:97.24ms +step:318/1670 train_time:30920ms step_avg:97.23ms +step:319/1670 train_time:31016ms step_avg:97.23ms +step:320/1670 train_time:31112ms step_avg:97.22ms +step:321/1670 train_time:31208ms step_avg:97.22ms +step:322/1670 train_time:31303ms step_avg:97.22ms +step:323/1670 train_time:31399ms step_avg:97.21ms +step:324/1670 train_time:31494ms step_avg:97.20ms +step:325/1670 train_time:31590ms step_avg:97.20ms +step:326/1670 train_time:31685ms step_avg:97.19ms +step:327/1670 train_time:31780ms step_avg:97.19ms +step:328/1670 train_time:31875ms step_avg:97.18ms +step:329/1670 train_time:31971ms step_avg:97.18ms +step:330/1670 train_time:32067ms step_avg:97.17ms +step:331/1670 train_time:32163ms step_avg:97.17ms +step:332/1670 train_time:32258ms step_avg:97.16ms +step:333/1670 train_time:32354ms step_avg:97.16ms +step:334/1670 train_time:32450ms step_avg:97.16ms +step:335/1670 train_time:32546ms step_avg:97.15ms +step:336/1670 train_time:32642ms step_avg:97.15ms +step:337/1670 train_time:32738ms step_avg:97.15ms +step:338/1670 train_time:32834ms step_avg:97.14ms +step:339/1670 train_time:32929ms step_avg:97.14ms +step:340/1670 train_time:33025ms step_avg:97.13ms +step:341/1670 train_time:33120ms step_avg:97.13ms +step:342/1670 train_time:33216ms step_avg:97.12ms +step:343/1670 train_time:33312ms step_avg:97.12ms +step:344/1670 train_time:33409ms step_avg:97.12ms +step:345/1670 train_time:33504ms step_avg:97.11ms +step:346/1670 train_time:33600ms step_avg:97.11ms +step:347/1670 train_time:33696ms step_avg:97.11ms +step:348/1670 train_time:33793ms step_avg:97.11ms +step:349/1670 train_time:33889ms step_avg:97.10ms +step:350/1670 train_time:33985ms step_avg:97.10ms +step:351/1670 train_time:34080ms step_avg:97.09ms +step:352/1670 train_time:34176ms step_avg:97.09ms +step:353/1670 train_time:34272ms step_avg:97.09ms +step:354/1670 train_time:34368ms step_avg:97.08ms +step:355/1670 train_time:34463ms step_avg:97.08ms +step:356/1670 train_time:34558ms 
step_avg:97.07ms +step:357/1670 train_time:34654ms step_avg:97.07ms +step:358/1670 train_time:34750ms step_avg:97.07ms +step:359/1670 train_time:34846ms step_avg:97.06ms +step:360/1670 train_time:34941ms step_avg:97.06ms +step:361/1670 train_time:35037ms step_avg:97.05ms +step:362/1670 train_time:35132ms step_avg:97.05ms +step:363/1670 train_time:35228ms step_avg:97.05ms +step:364/1670 train_time:35324ms step_avg:97.04ms +step:365/1670 train_time:35419ms step_avg:97.04ms +step:366/1670 train_time:35515ms step_avg:97.04ms +step:367/1670 train_time:35611ms step_avg:97.03ms +step:368/1670 train_time:35707ms step_avg:97.03ms +step:369/1670 train_time:35803ms step_avg:97.03ms +step:370/1670 train_time:35899ms step_avg:97.02ms +step:371/1670 train_time:35995ms step_avg:97.02ms +step:372/1670 train_time:36092ms step_avg:97.02ms +step:373/1670 train_time:36188ms step_avg:97.02ms +step:374/1670 train_time:36284ms step_avg:97.02ms +step:375/1670 train_time:36379ms step_avg:97.01ms +step:375/1670 val_loss:3.8161 train_time:36473ms step_avg:97.26ms +step:376/1670 train_time:36499ms step_avg:97.07ms +step:377/1670 train_time:36577ms step_avg:97.02ms +step:378/1670 train_time:36675ms step_avg:97.02ms +step:379/1670 train_time:36771ms step_avg:97.02ms +step:380/1670 train_time:36866ms step_avg:97.01ms +step:381/1670 train_time:36961ms step_avg:97.01ms +step:382/1670 train_time:37055ms step_avg:97.00ms +step:383/1670 train_time:37149ms step_avg:97.00ms +step:384/1670 train_time:37245ms step_avg:96.99ms +step:385/1670 train_time:37340ms step_avg:96.99ms +step:386/1670 train_time:37437ms step_avg:96.99ms +step:387/1670 train_time:37535ms step_avg:96.99ms +step:388/1670 train_time:37632ms step_avg:96.99ms +step:389/1670 train_time:37728ms step_avg:96.99ms +step:390/1670 train_time:37824ms step_avg:96.98ms +step:391/1670 train_time:37919ms step_avg:96.98ms +step:392/1670 train_time:38015ms step_avg:96.98ms +step:393/1670 train_time:38109ms step_avg:96.97ms +step:394/1670 train_time:38205ms step_avg:96.97ms +step:395/1670 train_time:38300ms step_avg:96.96ms +step:396/1670 train_time:38394ms step_avg:96.96ms +step:397/1670 train_time:38490ms step_avg:96.95ms +step:398/1670 train_time:38587ms step_avg:96.95ms +step:399/1670 train_time:38684ms step_avg:96.95ms +step:400/1670 train_time:38780ms step_avg:96.95ms +step:401/1670 train_time:38877ms step_avg:96.95ms +step:402/1670 train_time:38972ms step_avg:96.94ms +step:403/1670 train_time:39067ms step_avg:96.94ms +step:404/1670 train_time:39162ms step_avg:96.94ms +step:405/1670 train_time:39257ms step_avg:96.93ms +step:406/1670 train_time:39352ms step_avg:96.93ms +step:407/1670 train_time:39447ms step_avg:96.92ms +step:408/1670 train_time:39544ms step_avg:96.92ms +step:409/1670 train_time:39640ms step_avg:96.92ms +step:410/1670 train_time:39737ms step_avg:96.92ms +step:411/1670 train_time:39832ms step_avg:96.92ms +step:412/1670 train_time:39928ms step_avg:96.91ms +step:413/1670 train_time:40024ms step_avg:96.91ms +step:414/1670 train_time:40120ms step_avg:96.91ms +step:415/1670 train_time:40215ms step_avg:96.90ms +step:416/1670 train_time:40310ms step_avg:96.90ms +step:417/1670 train_time:40405ms step_avg:96.89ms +step:418/1670 train_time:40501ms step_avg:96.89ms +step:419/1670 train_time:40596ms step_avg:96.89ms +step:420/1670 train_time:40691ms step_avg:96.88ms +step:421/1670 train_time:40787ms step_avg:96.88ms +step:422/1670 train_time:40884ms step_avg:96.88ms +step:423/1670 train_time:40980ms step_avg:96.88ms +step:424/1670 train_time:41076ms step_avg:96.88ms 
+step:425/1670 train_time:41338ms step_avg:97.27ms +step:426/1670 train_time:41461ms step_avg:97.33ms +step:427/1670 train_time:41556ms step_avg:97.32ms +step:428/1670 train_time:41650ms step_avg:97.31ms +step:429/1670 train_time:41745ms step_avg:97.31ms +step:430/1670 train_time:41839ms step_avg:97.30ms +step:431/1670 train_time:41934ms step_avg:97.29ms +step:432/1670 train_time:42029ms step_avg:97.29ms +step:433/1670 train_time:42124ms step_avg:97.28ms +step:434/1670 train_time:42219ms step_avg:97.28ms +step:435/1670 train_time:42319ms step_avg:97.29ms +step:436/1670 train_time:42417ms step_avg:97.29ms +step:437/1670 train_time:42514ms step_avg:97.29ms +step:438/1670 train_time:42610ms step_avg:97.28ms +step:439/1670 train_time:42706ms step_avg:97.28ms +step:440/1670 train_time:42802ms step_avg:97.28ms +step:441/1670 train_time:42897ms step_avg:97.27ms +step:442/1670 train_time:42992ms step_avg:97.27ms +step:443/1670 train_time:43087ms step_avg:97.26ms +step:444/1670 train_time:43182ms step_avg:97.26ms +step:445/1670 train_time:43278ms step_avg:97.25ms +step:446/1670 train_time:43374ms step_avg:97.25ms +step:447/1670 train_time:43471ms step_avg:97.25ms +step:448/1670 train_time:43567ms step_avg:97.25ms +step:449/1670 train_time:43664ms step_avg:97.25ms +step:450/1670 train_time:43759ms step_avg:97.24ms +step:451/1670 train_time:43855ms step_avg:97.24ms +step:452/1670 train_time:43950ms step_avg:97.23ms +step:453/1670 train_time:44044ms step_avg:97.23ms +step:454/1670 train_time:44139ms step_avg:97.22ms +step:455/1670 train_time:44234ms step_avg:97.22ms +step:456/1670 train_time:44330ms step_avg:97.22ms +step:457/1670 train_time:44427ms step_avg:97.21ms +step:458/1670 train_time:44524ms step_avg:97.21ms +step:459/1670 train_time:44621ms step_avg:97.21ms +step:460/1670 train_time:44716ms step_avg:97.21ms +step:461/1670 train_time:44811ms step_avg:97.20ms +step:462/1670 train_time:44906ms step_avg:97.20ms +step:463/1670 train_time:45002ms step_avg:97.20ms +step:464/1670 train_time:45097ms step_avg:97.19ms +step:465/1670 train_time:45192ms step_avg:97.19ms +step:466/1670 train_time:45288ms step_avg:97.18ms +step:467/1670 train_time:45385ms step_avg:97.19ms +step:468/1670 train_time:45483ms step_avg:97.19ms +step:469/1670 train_time:45579ms step_avg:97.18ms +step:470/1670 train_time:45676ms step_avg:97.18ms +step:471/1670 train_time:45771ms step_avg:97.18ms +step:472/1670 train_time:45867ms step_avg:97.18ms +step:473/1670 train_time:45963ms step_avg:97.17ms +step:474/1670 train_time:46058ms step_avg:97.17ms +step:475/1670 train_time:46153ms step_avg:97.16ms +step:476/1670 train_time:46249ms step_avg:97.16ms +step:477/1670 train_time:46345ms step_avg:97.16ms +step:478/1670 train_time:46441ms step_avg:97.16ms +step:479/1670 train_time:46536ms step_avg:97.15ms +step:480/1670 train_time:46632ms step_avg:97.15ms +step:481/1670 train_time:46728ms step_avg:97.15ms +step:482/1670 train_time:46825ms step_avg:97.15ms +step:483/1670 train_time:46921ms step_avg:97.15ms +step:484/1670 train_time:47017ms step_avg:97.14ms +step:485/1670 train_time:47112ms step_avg:97.14ms +step:486/1670 train_time:47207ms step_avg:97.13ms +step:487/1670 train_time:47303ms step_avg:97.13ms +step:488/1670 train_time:47399ms step_avg:97.13ms +step:489/1670 train_time:47494ms step_avg:97.12ms +step:490/1670 train_time:47590ms step_avg:97.12ms +step:491/1670 train_time:47687ms step_avg:97.12ms +step:492/1670 train_time:47783ms step_avg:97.12ms +step:493/1670 train_time:47879ms step_avg:97.12ms +step:494/1670 train_time:47975ms 
step_avg:97.11ms +step:495/1670 train_time:48070ms step_avg:97.11ms +step:496/1670 train_time:48166ms step_avg:97.11ms +step:497/1670 train_time:48261ms step_avg:97.10ms +step:498/1670 train_time:48357ms step_avg:97.10ms +step:499/1670 train_time:48452ms step_avg:97.10ms +step:500/1670 train_time:48549ms step_avg:97.10ms +step:500/1670 val_loss:3.7116 train_time:48644ms step_avg:97.29ms +step:501/1670 train_time:48668ms step_avg:97.14ms +step:502/1670 train_time:48747ms step_avg:97.10ms +step:503/1670 train_time:48846ms step_avg:97.11ms +step:504/1670 train_time:48942ms step_avg:97.11ms +step:505/1670 train_time:49037ms step_avg:97.10ms +step:506/1670 train_time:49131ms step_avg:97.10ms +step:507/1670 train_time:49226ms step_avg:97.09ms +step:508/1670 train_time:49320ms step_avg:97.09ms +step:509/1670 train_time:49415ms step_avg:97.08ms +step:510/1670 train_time:49510ms step_avg:97.08ms +step:511/1670 train_time:49606ms step_avg:97.08ms +step:512/1670 train_time:49703ms step_avg:97.08ms +step:513/1670 train_time:49801ms step_avg:97.08ms +step:514/1670 train_time:49897ms step_avg:97.08ms +step:515/1670 train_time:49994ms step_avg:97.08ms +step:516/1670 train_time:50089ms step_avg:97.07ms +step:517/1670 train_time:50184ms step_avg:97.07ms +step:518/1670 train_time:50279ms step_avg:97.06ms +step:519/1670 train_time:50374ms step_avg:97.06ms +step:520/1670 train_time:50469ms step_avg:97.06ms +step:521/1670 train_time:50565ms step_avg:97.05ms +step:522/1670 train_time:50660ms step_avg:97.05ms +step:523/1670 train_time:50757ms step_avg:97.05ms +step:524/1670 train_time:50855ms step_avg:97.05ms +step:525/1670 train_time:50951ms step_avg:97.05ms +step:526/1670 train_time:51046ms step_avg:97.05ms +step:527/1670 train_time:51142ms step_avg:97.04ms +step:528/1670 train_time:51237ms step_avg:97.04ms +step:529/1670 train_time:51332ms step_avg:97.04ms +step:530/1670 train_time:51428ms step_avg:97.03ms +step:531/1670 train_time:51523ms step_avg:97.03ms +step:532/1670 train_time:51619ms step_avg:97.03ms +step:533/1670 train_time:51716ms step_avg:97.03ms +step:534/1670 train_time:51812ms step_avg:97.03ms +step:535/1670 train_time:51908ms step_avg:97.03ms +step:536/1670 train_time:52004ms step_avg:97.02ms +step:537/1670 train_time:52100ms step_avg:97.02ms +step:538/1670 train_time:52196ms step_avg:97.02ms +step:539/1670 train_time:52291ms step_avg:97.01ms +step:540/1670 train_time:52386ms step_avg:97.01ms +step:541/1670 train_time:52482ms step_avg:97.01ms +step:542/1670 train_time:52578ms step_avg:97.01ms +step:543/1670 train_time:52674ms step_avg:97.01ms +step:544/1670 train_time:52771ms step_avg:97.00ms +step:545/1670 train_time:52866ms step_avg:97.00ms +step:546/1670 train_time:52963ms step_avg:97.00ms +step:547/1670 train_time:53059ms step_avg:97.00ms +step:548/1670 train_time:53155ms step_avg:97.00ms +step:549/1670 train_time:53251ms step_avg:97.00ms +step:550/1670 train_time:53345ms step_avg:96.99ms +step:551/1670 train_time:53441ms step_avg:96.99ms +step:552/1670 train_time:53537ms step_avg:96.99ms +step:553/1670 train_time:53633ms step_avg:96.99ms +step:554/1670 train_time:53730ms step_avg:96.99ms +step:555/1670 train_time:53826ms step_avg:96.98ms +step:556/1670 train_time:53921ms step_avg:96.98ms +step:557/1670 train_time:54016ms step_avg:96.98ms +step:558/1670 train_time:54113ms step_avg:96.98ms +step:559/1670 train_time:54210ms step_avg:96.98ms +step:560/1670 train_time:54306ms step_avg:96.98ms +step:561/1670 train_time:54402ms step_avg:96.97ms +step:562/1670 train_time:54499ms step_avg:96.97ms 
+step:563/1670 train_time:54597ms step_avg:96.98ms +step:564/1670 train_time:54695ms step_avg:96.98ms +step:565/1670 train_time:54794ms step_avg:96.98ms +step:566/1670 train_time:54891ms step_avg:96.98ms +step:567/1670 train_time:54988ms step_avg:96.98ms +step:568/1670 train_time:55085ms step_avg:96.98ms +step:569/1670 train_time:55181ms step_avg:96.98ms +step:570/1670 train_time:55279ms step_avg:96.98ms +step:571/1670 train_time:55377ms step_avg:96.98ms +step:572/1670 train_time:55473ms step_avg:96.98ms +step:573/1670 train_time:55570ms step_avg:96.98ms +step:574/1670 train_time:55667ms step_avg:96.98ms +step:575/1670 train_time:55764ms step_avg:96.98ms +step:576/1670 train_time:55861ms step_avg:96.98ms +step:577/1670 train_time:55959ms step_avg:96.98ms +step:578/1670 train_time:56057ms step_avg:96.99ms +step:579/1670 train_time:56155ms step_avg:96.99ms +step:580/1670 train_time:56253ms step_avg:96.99ms +step:581/1670 train_time:56350ms step_avg:96.99ms +step:582/1670 train_time:56446ms step_avg:96.99ms +step:583/1670 train_time:56543ms step_avg:96.99ms +step:584/1670 train_time:56640ms step_avg:96.99ms +step:585/1670 train_time:56738ms step_avg:96.99ms +step:586/1670 train_time:56835ms step_avg:96.99ms +step:587/1670 train_time:56934ms step_avg:96.99ms +step:588/1670 train_time:57032ms step_avg:96.99ms +step:589/1670 train_time:57129ms step_avg:96.99ms +step:590/1670 train_time:57225ms step_avg:96.99ms +step:591/1670 train_time:57322ms step_avg:96.99ms +step:592/1670 train_time:57419ms step_avg:96.99ms +step:593/1670 train_time:57518ms step_avg:96.99ms +step:594/1670 train_time:57617ms step_avg:97.00ms +step:595/1670 train_time:57714ms step_avg:97.00ms +step:596/1670 train_time:57812ms step_avg:97.00ms +step:597/1670 train_time:57909ms step_avg:97.00ms +step:598/1670 train_time:58006ms step_avg:97.00ms +step:599/1670 train_time:58103ms step_avg:97.00ms +step:600/1670 train_time:58200ms step_avg:97.00ms +step:601/1670 train_time:58298ms step_avg:97.00ms +step:602/1670 train_time:58395ms step_avg:97.00ms +step:603/1670 train_time:58493ms step_avg:97.00ms +step:604/1670 train_time:58590ms step_avg:97.00ms +step:605/1670 train_time:58687ms step_avg:97.00ms +step:606/1670 train_time:58783ms step_avg:97.00ms +step:607/1670 train_time:58881ms step_avg:97.00ms +step:608/1670 train_time:58979ms step_avg:97.00ms +step:609/1670 train_time:59077ms step_avg:97.01ms +step:610/1670 train_time:59175ms step_avg:97.01ms +step:611/1670 train_time:59272ms step_avg:97.01ms +step:612/1670 train_time:59368ms step_avg:97.01ms +step:613/1670 train_time:59465ms step_avg:97.01ms +step:614/1670 train_time:59563ms step_avg:97.01ms +step:615/1670 train_time:59660ms step_avg:97.01ms +step:616/1670 train_time:59757ms step_avg:97.01ms +step:617/1670 train_time:59856ms step_avg:97.01ms +step:618/1670 train_time:59953ms step_avg:97.01ms +step:619/1670 train_time:60051ms step_avg:97.01ms +step:620/1670 train_time:60149ms step_avg:97.01ms +step:621/1670 train_time:60246ms step_avg:97.02ms +step:622/1670 train_time:60342ms step_avg:97.01ms +step:623/1670 train_time:60439ms step_avg:97.01ms +step:624/1670 train_time:60538ms step_avg:97.02ms +step:625/1670 train_time:60635ms step_avg:97.02ms +step:625/1670 val_loss:3.6117 train_time:60731ms step_avg:97.17ms +step:626/1670 train_time:60755ms step_avg:97.05ms +step:627/1670 train_time:60839ms step_avg:97.03ms +step:628/1670 train_time:60939ms step_avg:97.04ms +step:629/1670 train_time:61035ms step_avg:97.04ms +step:630/1670 train_time:61131ms step_avg:97.03ms +step:631/1670 
train_time:61227ms step_avg:97.03ms +step:632/1670 train_time:61323ms step_avg:97.03ms +step:633/1670 train_time:61419ms step_avg:97.03ms +step:634/1670 train_time:61515ms step_avg:97.03ms +step:635/1670 train_time:61611ms step_avg:97.02ms +step:636/1670 train_time:61710ms step_avg:97.03ms +step:637/1670 train_time:61808ms step_avg:97.03ms +step:638/1670 train_time:61906ms step_avg:97.03ms +step:639/1670 train_time:62281ms step_avg:97.47ms +step:640/1670 train_time:62353ms step_avg:97.43ms +step:641/1670 train_time:62449ms step_avg:97.42ms +step:642/1670 train_time:62544ms step_avg:97.42ms +step:643/1670 train_time:62641ms step_avg:97.42ms +step:644/1670 train_time:62737ms step_avg:97.42ms +step:645/1670 train_time:62833ms step_avg:97.42ms +step:646/1670 train_time:62929ms step_avg:97.41ms +step:647/1670 train_time:63024ms step_avg:97.41ms +step:648/1670 train_time:63121ms step_avg:97.41ms +step:649/1670 train_time:63222ms step_avg:97.41ms +step:650/1670 train_time:63323ms step_avg:97.42ms +step:651/1670 train_time:63422ms step_avg:97.42ms +step:652/1670 train_time:63520ms step_avg:97.42ms +step:653/1670 train_time:63617ms step_avg:97.42ms +step:654/1670 train_time:63713ms step_avg:97.42ms +step:655/1670 train_time:63808ms step_avg:97.42ms +step:656/1670 train_time:63904ms step_avg:97.41ms +step:657/1670 train_time:64000ms step_avg:97.41ms +step:658/1670 train_time:64098ms step_avg:97.41ms +step:659/1670 train_time:64196ms step_avg:97.41ms +step:660/1670 train_time:64296ms step_avg:97.42ms +step:661/1670 train_time:64395ms step_avg:97.42ms +step:662/1670 train_time:64493ms step_avg:97.42ms +step:663/1670 train_time:64590ms step_avg:97.42ms +step:664/1670 train_time:64687ms step_avg:97.42ms +step:665/1670 train_time:64783ms step_avg:97.42ms +step:666/1670 train_time:64879ms step_avg:97.42ms +step:667/1670 train_time:64976ms step_avg:97.42ms +step:668/1670 train_time:65074ms step_avg:97.42ms +step:669/1670 train_time:65171ms step_avg:97.41ms +step:670/1670 train_time:65268ms step_avg:97.41ms +step:671/1670 train_time:65365ms step_avg:97.41ms +step:672/1670 train_time:65463ms step_avg:97.41ms +step:673/1670 train_time:65560ms step_avg:97.41ms +step:674/1670 train_time:65658ms step_avg:97.42ms +step:675/1670 train_time:65756ms step_avg:97.42ms +step:676/1670 train_time:65853ms step_avg:97.42ms +step:677/1670 train_time:65949ms step_avg:97.41ms +step:678/1670 train_time:66045ms step_avg:97.41ms +step:679/1670 train_time:66142ms step_avg:97.41ms +step:680/1670 train_time:66240ms step_avg:97.41ms +step:681/1670 train_time:66338ms step_avg:97.41ms +step:682/1670 train_time:66436ms step_avg:97.41ms +step:683/1670 train_time:66534ms step_avg:97.41ms +step:684/1670 train_time:66632ms step_avg:97.41ms +step:685/1670 train_time:66729ms step_avg:97.41ms +step:686/1670 train_time:66826ms step_avg:97.41ms +step:687/1670 train_time:66923ms step_avg:97.41ms +step:688/1670 train_time:67019ms step_avg:97.41ms +step:689/1670 train_time:67117ms step_avg:97.41ms +step:690/1670 train_time:67214ms step_avg:97.41ms +step:691/1670 train_time:67311ms step_avg:97.41ms +step:692/1670 train_time:67409ms step_avg:97.41ms +step:693/1670 train_time:67505ms step_avg:97.41ms +step:694/1670 train_time:67602ms step_avg:97.41ms +step:695/1670 train_time:67699ms step_avg:97.41ms +step:696/1670 train_time:67797ms step_avg:97.41ms +step:697/1670 train_time:67895ms step_avg:97.41ms +step:698/1670 train_time:67993ms step_avg:97.41ms +step:699/1670 train_time:68089ms step_avg:97.41ms +step:700/1670 train_time:68186ms step_avg:97.41ms 
+step:701/1670 train_time:68282ms step_avg:97.41ms +step:702/1670 train_time:68383ms step_avg:97.41ms +step:703/1670 train_time:68479ms step_avg:97.41ms +step:704/1670 train_time:68578ms step_avg:97.41ms +step:705/1670 train_time:68675ms step_avg:97.41ms +step:706/1670 train_time:68772ms step_avg:97.41ms +step:707/1670 train_time:68869ms step_avg:97.41ms +step:708/1670 train_time:68965ms step_avg:97.41ms +step:709/1670 train_time:69062ms step_avg:97.41ms +step:710/1670 train_time:69159ms step_avg:97.41ms +step:711/1670 train_time:69257ms step_avg:97.41ms +step:712/1670 train_time:69354ms step_avg:97.41ms +step:713/1670 train_time:69451ms step_avg:97.41ms +step:714/1670 train_time:69547ms step_avg:97.41ms +step:715/1670 train_time:69644ms step_avg:97.40ms +step:716/1670 train_time:69741ms step_avg:97.40ms +step:717/1670 train_time:69839ms step_avg:97.40ms +step:718/1670 train_time:69937ms step_avg:97.40ms +step:719/1670 train_time:70034ms step_avg:97.40ms +step:720/1670 train_time:70131ms step_avg:97.40ms +step:721/1670 train_time:70228ms step_avg:97.40ms +step:722/1670 train_time:70324ms step_avg:97.40ms +step:723/1670 train_time:70421ms step_avg:97.40ms +step:724/1670 train_time:70518ms step_avg:97.40ms +step:725/1670 train_time:70616ms step_avg:97.40ms +step:726/1670 train_time:70714ms step_avg:97.40ms +step:727/1670 train_time:70811ms step_avg:97.40ms +step:728/1670 train_time:70908ms step_avg:97.40ms +step:729/1670 train_time:71005ms step_avg:97.40ms +step:730/1670 train_time:71102ms step_avg:97.40ms +step:731/1670 train_time:71200ms step_avg:97.40ms +step:732/1670 train_time:71297ms step_avg:97.40ms +step:733/1670 train_time:71395ms step_avg:97.40ms +step:734/1670 train_time:71492ms step_avg:97.40ms +step:735/1670 train_time:71589ms step_avg:97.40ms +step:736/1670 train_time:71686ms step_avg:97.40ms +step:737/1670 train_time:71783ms step_avg:97.40ms +step:738/1670 train_time:71880ms step_avg:97.40ms +step:739/1670 train_time:71977ms step_avg:97.40ms +step:740/1670 train_time:72076ms step_avg:97.40ms +step:741/1670 train_time:72172ms step_avg:97.40ms +step:742/1670 train_time:72269ms step_avg:97.40ms +step:743/1670 train_time:72366ms step_avg:97.40ms +step:744/1670 train_time:72463ms step_avg:97.40ms +step:745/1670 train_time:72560ms step_avg:97.40ms +step:746/1670 train_time:72658ms step_avg:97.40ms +step:747/1670 train_time:72757ms step_avg:97.40ms +step:748/1670 train_time:72853ms step_avg:97.40ms +step:749/1670 train_time:72950ms step_avg:97.40ms +step:750/1670 train_time:73046ms step_avg:97.39ms +step:750/1670 val_loss:3.5598 train_time:73142ms step_avg:97.52ms +step:751/1670 train_time:73166ms step_avg:97.42ms +step:752/1670 train_time:73249ms step_avg:97.41ms +step:753/1670 train_time:73351ms step_avg:97.41ms +step:754/1670 train_time:73448ms step_avg:97.41ms +step:755/1670 train_time:73544ms step_avg:97.41ms +step:756/1670 train_time:73640ms step_avg:97.41ms +step:757/1670 train_time:73737ms step_avg:97.41ms +step:758/1670 train_time:73833ms step_avg:97.41ms +step:759/1670 train_time:73930ms step_avg:97.40ms +step:760/1670 train_time:74025ms step_avg:97.40ms +step:761/1670 train_time:74122ms step_avg:97.40ms +step:762/1670 train_time:74221ms step_avg:97.40ms +step:763/1670 train_time:74320ms step_avg:97.41ms +step:764/1670 train_time:74420ms step_avg:97.41ms +step:765/1670 train_time:74517ms step_avg:97.41ms +step:766/1670 train_time:74614ms step_avg:97.41ms +step:767/1670 train_time:74711ms step_avg:97.41ms +step:768/1670 train_time:74807ms step_avg:97.41ms +step:769/1670 
train_time:74903ms step_avg:97.40ms +step:770/1670 train_time:75000ms step_avg:97.40ms +step:771/1670 train_time:75096ms step_avg:97.40ms +step:772/1670 train_time:75195ms step_avg:97.40ms +step:773/1670 train_time:75295ms step_avg:97.41ms +step:774/1670 train_time:75395ms step_avg:97.41ms +step:775/1670 train_time:75495ms step_avg:97.41ms +step:776/1670 train_time:75593ms step_avg:97.41ms +step:777/1670 train_time:75691ms step_avg:97.41ms +step:778/1670 train_time:75788ms step_avg:97.41ms +step:779/1670 train_time:75884ms step_avg:97.41ms +step:780/1670 train_time:75980ms step_avg:97.41ms +step:781/1670 train_time:76077ms step_avg:97.41ms +step:782/1670 train_time:76175ms step_avg:97.41ms +step:783/1670 train_time:76273ms step_avg:97.41ms +step:784/1670 train_time:76371ms step_avg:97.41ms +step:785/1670 train_time:76469ms step_avg:97.41ms +step:786/1670 train_time:76566ms step_avg:97.41ms +step:787/1670 train_time:76663ms step_avg:97.41ms +step:788/1670 train_time:76760ms step_avg:97.41ms +step:789/1670 train_time:76857ms step_avg:97.41ms +step:790/1670 train_time:76955ms step_avg:97.41ms +step:791/1670 train_time:77053ms step_avg:97.41ms +step:792/1670 train_time:77151ms step_avg:97.41ms +step:793/1670 train_time:77248ms step_avg:97.41ms +step:794/1670 train_time:77344ms step_avg:97.41ms +step:795/1670 train_time:77442ms step_avg:97.41ms +step:796/1670 train_time:77539ms step_avg:97.41ms +step:797/1670 train_time:77637ms step_avg:97.41ms +step:798/1670 train_time:77735ms step_avg:97.41ms +step:799/1670 train_time:77832ms step_avg:97.41ms +step:800/1670 train_time:77929ms step_avg:97.41ms +step:801/1670 train_time:78025ms step_avg:97.41ms +step:802/1670 train_time:78122ms step_avg:97.41ms +step:803/1670 train_time:78219ms step_avg:97.41ms +step:804/1670 train_time:78317ms step_avg:97.41ms +step:805/1670 train_time:78416ms step_avg:97.41ms +step:806/1670 train_time:78514ms step_avg:97.41ms +step:807/1670 train_time:78612ms step_avg:97.41ms +step:808/1670 train_time:78709ms step_avg:97.41ms +step:809/1670 train_time:78805ms step_avg:97.41ms +step:810/1670 train_time:78902ms step_avg:97.41ms +step:811/1670 train_time:78999ms step_avg:97.41ms +step:812/1670 train_time:79097ms step_avg:97.41ms +step:813/1670 train_time:79194ms step_avg:97.41ms +step:814/1670 train_time:79292ms step_avg:97.41ms +step:815/1670 train_time:79389ms step_avg:97.41ms +step:816/1670 train_time:79487ms step_avg:97.41ms +step:817/1670 train_time:79583ms step_avg:97.41ms +step:818/1670 train_time:79680ms step_avg:97.41ms +step:819/1670 train_time:79778ms step_avg:97.41ms +step:820/1670 train_time:79875ms step_avg:97.41ms +step:821/1670 train_time:79973ms step_avg:97.41ms +step:822/1670 train_time:80070ms step_avg:97.41ms +step:823/1670 train_time:80167ms step_avg:97.41ms +step:824/1670 train_time:80265ms step_avg:97.41ms +step:825/1670 train_time:80362ms step_avg:97.41ms +step:826/1670 train_time:80459ms step_avg:97.41ms +step:827/1670 train_time:80556ms step_avg:97.41ms +step:828/1670 train_time:80655ms step_avg:97.41ms +step:829/1670 train_time:80754ms step_avg:97.41ms +step:830/1670 train_time:80852ms step_avg:97.41ms +step:831/1670 train_time:80949ms step_avg:97.41ms +step:832/1670 train_time:81046ms step_avg:97.41ms +step:833/1670 train_time:81142ms step_avg:97.41ms +step:834/1670 train_time:81239ms step_avg:97.41ms +step:835/1670 train_time:81337ms step_avg:97.41ms +step:836/1670 train_time:81434ms step_avg:97.41ms +step:837/1670 train_time:81532ms step_avg:97.41ms +step:838/1670 train_time:81629ms step_avg:97.41ms 
+step:839/1670 train_time:81726ms step_avg:97.41ms +step:840/1670 train_time:81823ms step_avg:97.41ms +step:841/1670 train_time:81920ms step_avg:97.41ms +step:842/1670 train_time:82019ms step_avg:97.41ms +step:843/1670 train_time:82116ms step_avg:97.41ms +step:844/1670 train_time:82214ms step_avg:97.41ms +step:845/1670 train_time:82310ms step_avg:97.41ms +step:846/1670 train_time:82407ms step_avg:97.41ms +step:847/1670 train_time:82505ms step_avg:97.41ms +step:848/1670 train_time:82601ms step_avg:97.41ms +step:849/1670 train_time:82699ms step_avg:97.41ms +step:850/1670 train_time:82796ms step_avg:97.41ms +step:851/1670 train_time:83066ms step_avg:97.61ms +step:852/1670 train_time:83206ms step_avg:97.66ms +step:853/1670 train_time:83302ms step_avg:97.66ms +step:854/1670 train_time:83398ms step_avg:97.66ms +step:855/1670 train_time:83495ms step_avg:97.65ms +step:856/1670 train_time:83591ms step_avg:97.65ms +step:857/1670 train_time:83688ms step_avg:97.65ms +step:858/1670 train_time:83783ms step_avg:97.65ms +step:859/1670 train_time:83880ms step_avg:97.65ms +step:860/1670 train_time:83976ms step_avg:97.65ms +step:861/1670 train_time:84078ms step_avg:97.65ms +step:862/1670 train_time:84183ms step_avg:97.66ms +step:863/1670 train_time:84281ms step_avg:97.66ms +step:864/1670 train_time:84378ms step_avg:97.66ms +step:865/1670 train_time:84475ms step_avg:97.66ms +step:866/1670 train_time:84572ms step_avg:97.66ms +step:867/1670 train_time:84668ms step_avg:97.66ms +step:868/1670 train_time:84764ms step_avg:97.65ms +step:869/1670 train_time:84859ms step_avg:97.65ms +step:870/1670 train_time:84956ms step_avg:97.65ms +step:871/1670 train_time:85055ms step_avg:97.65ms +step:872/1670 train_time:85158ms step_avg:97.66ms +step:873/1670 train_time:85258ms step_avg:97.66ms +step:874/1670 train_time:85357ms step_avg:97.66ms +step:875/1670 train_time:85455ms step_avg:97.66ms +step:875/1670 val_loss:3.5197 train_time:85551ms step_avg:97.77ms +step:876/1670 train_time:85574ms step_avg:97.69ms +step:877/1670 train_time:85657ms step_avg:97.67ms +step:878/1670 train_time:85754ms step_avg:97.67ms +step:879/1670 train_time:85853ms step_avg:97.67ms +step:880/1670 train_time:85950ms step_avg:97.67ms +step:881/1670 train_time:86046ms step_avg:97.67ms +step:882/1670 train_time:86142ms step_avg:97.67ms +step:883/1670 train_time:86238ms step_avg:97.66ms +step:884/1670 train_time:86335ms step_avg:97.66ms +step:885/1670 train_time:86431ms step_avg:97.66ms +step:886/1670 train_time:86532ms step_avg:97.67ms +step:887/1670 train_time:86632ms step_avg:97.67ms +step:888/1670 train_time:86730ms step_avg:97.67ms +step:889/1670 train_time:86828ms step_avg:97.67ms +step:890/1670 train_time:86926ms step_avg:97.67ms +step:891/1670 train_time:87023ms step_avg:97.67ms +step:892/1670 train_time:87119ms step_avg:97.67ms +step:893/1670 train_time:87215ms step_avg:97.67ms +step:894/1670 train_time:87312ms step_avg:97.66ms +step:895/1670 train_time:87409ms step_avg:97.66ms +step:896/1670 train_time:87508ms step_avg:97.66ms +step:897/1670 train_time:87608ms step_avg:97.67ms +step:898/1670 train_time:87707ms step_avg:97.67ms +step:899/1670 train_time:87804ms step_avg:97.67ms +step:900/1670 train_time:87902ms step_avg:97.67ms +step:901/1670 train_time:87999ms step_avg:97.67ms +step:902/1670 train_time:88095ms step_avg:97.67ms +step:903/1670 train_time:88192ms step_avg:97.67ms +step:904/1670 train_time:88289ms step_avg:97.66ms +step:905/1670 train_time:88386ms step_avg:97.66ms +step:906/1670 train_time:88482ms step_avg:97.66ms +step:907/1670 
train_time:88580ms step_avg:97.66ms +step:908/1670 train_time:88677ms step_avg:97.66ms +step:909/1670 train_time:88774ms step_avg:97.66ms +step:910/1670 train_time:88872ms step_avg:97.66ms +step:911/1670 train_time:88970ms step_avg:97.66ms +step:912/1670 train_time:89068ms step_avg:97.66ms +step:913/1670 train_time:89166ms step_avg:97.66ms +step:914/1670 train_time:89262ms step_avg:97.66ms +step:915/1670 train_time:89359ms step_avg:97.66ms +step:916/1670 train_time:89455ms step_avg:97.66ms +step:917/1670 train_time:89552ms step_avg:97.66ms +step:918/1670 train_time:89650ms step_avg:97.66ms +step:919/1670 train_time:89748ms step_avg:97.66ms +step:920/1670 train_time:89846ms step_avg:97.66ms +step:921/1670 train_time:89944ms step_avg:97.66ms +step:922/1670 train_time:90041ms step_avg:97.66ms +step:923/1670 train_time:90137ms step_avg:97.66ms +step:924/1670 train_time:90234ms step_avg:97.66ms +step:925/1670 train_time:90331ms step_avg:97.65ms +step:926/1670 train_time:90428ms step_avg:97.65ms +step:927/1670 train_time:90526ms step_avg:97.65ms +step:928/1670 train_time:90624ms step_avg:97.66ms +step:929/1670 train_time:90721ms step_avg:97.65ms +step:930/1670 train_time:90819ms step_avg:97.65ms +step:931/1670 train_time:90915ms step_avg:97.65ms +step:932/1670 train_time:91013ms step_avg:97.65ms +step:933/1670 train_time:91110ms step_avg:97.65ms +step:934/1670 train_time:91208ms step_avg:97.65ms +step:935/1670 train_time:91306ms step_avg:97.65ms +step:936/1670 train_time:91404ms step_avg:97.65ms +step:937/1670 train_time:91502ms step_avg:97.65ms +step:938/1670 train_time:91598ms step_avg:97.65ms +step:939/1670 train_time:91695ms step_avg:97.65ms +step:940/1670 train_time:91793ms step_avg:97.65ms +step:941/1670 train_time:91891ms step_avg:97.65ms +step:942/1670 train_time:91988ms step_avg:97.65ms +step:943/1670 train_time:92085ms step_avg:97.65ms +step:944/1670 train_time:92182ms step_avg:97.65ms +step:945/1670 train_time:92279ms step_avg:97.65ms +step:946/1670 train_time:92375ms step_avg:97.65ms +step:947/1670 train_time:92473ms step_avg:97.65ms +step:948/1670 train_time:92570ms step_avg:97.65ms +step:949/1670 train_time:92669ms step_avg:97.65ms +step:950/1670 train_time:92766ms step_avg:97.65ms +step:951/1670 train_time:92864ms step_avg:97.65ms +step:952/1670 train_time:92962ms step_avg:97.65ms +step:953/1670 train_time:93058ms step_avg:97.65ms +step:954/1670 train_time:93155ms step_avg:97.65ms +step:955/1670 train_time:93253ms step_avg:97.65ms +step:956/1670 train_time:93350ms step_avg:97.65ms +step:957/1670 train_time:93448ms step_avg:97.65ms +step:958/1670 train_time:93545ms step_avg:97.65ms +step:959/1670 train_time:93643ms step_avg:97.65ms +step:960/1670 train_time:93740ms step_avg:97.65ms +step:961/1670 train_time:93836ms step_avg:97.64ms +step:962/1670 train_time:93933ms step_avg:97.64ms +step:963/1670 train_time:94030ms step_avg:97.64ms +step:964/1670 train_time:94128ms step_avg:97.64ms +step:965/1670 train_time:94226ms step_avg:97.64ms +step:966/1670 train_time:94324ms step_avg:97.64ms +step:967/1670 train_time:94421ms step_avg:97.64ms +step:968/1670 train_time:94517ms step_avg:97.64ms +step:969/1670 train_time:94614ms step_avg:97.64ms +step:970/1670 train_time:94713ms step_avg:97.64ms +step:971/1670 train_time:94810ms step_avg:97.64ms +step:972/1670 train_time:94907ms step_avg:97.64ms +step:973/1670 train_time:95004ms step_avg:97.64ms +step:974/1670 train_time:95102ms step_avg:97.64ms +step:975/1670 train_time:95199ms step_avg:97.64ms +step:976/1670 train_time:95296ms step_avg:97.64ms 
+step:977/1670 train_time:95393ms step_avg:97.64ms +step:978/1670 train_time:95490ms step_avg:97.64ms +step:979/1670 train_time:95588ms step_avg:97.64ms +step:980/1670 train_time:95687ms step_avg:97.64ms +step:981/1670 train_time:95785ms step_avg:97.64ms +step:982/1670 train_time:95882ms step_avg:97.64ms +step:983/1670 train_time:95979ms step_avg:97.64ms +step:984/1670 train_time:96076ms step_avg:97.64ms +step:985/1670 train_time:96173ms step_avg:97.64ms +step:986/1670 train_time:96270ms step_avg:97.64ms +step:987/1670 train_time:96368ms step_avg:97.64ms +step:988/1670 train_time:96465ms step_avg:97.64ms +step:989/1670 train_time:96564ms step_avg:97.64ms +step:990/1670 train_time:96661ms step_avg:97.64ms +step:991/1670 train_time:96758ms step_avg:97.64ms +step:992/1670 train_time:96854ms step_avg:97.64ms +step:993/1670 train_time:96953ms step_avg:97.64ms +step:994/1670 train_time:97051ms step_avg:97.64ms +step:995/1670 train_time:97148ms step_avg:97.64ms +step:996/1670 train_time:97246ms step_avg:97.64ms +step:997/1670 train_time:97343ms step_avg:97.64ms +step:998/1670 train_time:97439ms step_avg:97.63ms +step:999/1670 train_time:97536ms step_avg:97.63ms +step:1000/1670 train_time:97633ms step_avg:97.63ms +step:1000/1670 val_loss:3.4772 train_time:97729ms step_avg:97.73ms +step:1001/1670 train_time:97752ms step_avg:97.65ms +step:1002/1670 train_time:97833ms step_avg:97.64ms +step:1003/1670 train_time:97935ms step_avg:97.64ms +step:1004/1670 train_time:98032ms step_avg:97.64ms +step:1005/1670 train_time:98128ms step_avg:97.64ms +step:1006/1670 train_time:98224ms step_avg:97.64ms +step:1007/1670 train_time:98321ms step_avg:97.64ms +step:1008/1670 train_time:98417ms step_avg:97.64ms +step:1009/1670 train_time:98514ms step_avg:97.63ms +step:1010/1670 train_time:98609ms step_avg:97.63ms +step:1011/1670 train_time:98707ms step_avg:97.63ms +step:1012/1670 train_time:98806ms step_avg:97.63ms +step:1013/1670 train_time:98905ms step_avg:97.64ms +step:1014/1670 train_time:99003ms step_avg:97.64ms +step:1015/1670 train_time:99102ms step_avg:97.64ms +step:1016/1670 train_time:99199ms step_avg:97.64ms +step:1017/1670 train_time:99295ms step_avg:97.64ms +step:1018/1670 train_time:99391ms step_avg:97.63ms +step:1019/1670 train_time:99488ms step_avg:97.63ms +step:1020/1670 train_time:99584ms step_avg:97.63ms +step:1021/1670 train_time:99683ms step_avg:97.63ms +step:1022/1670 train_time:99781ms step_avg:97.63ms +step:1023/1670 train_time:99882ms step_avg:97.64ms +step:1024/1670 train_time:99981ms step_avg:97.64ms +step:1025/1670 train_time:100079ms step_avg:97.64ms +step:1026/1670 train_time:100176ms step_avg:97.64ms +step:1027/1670 train_time:100274ms step_avg:97.64ms +step:1028/1670 train_time:100371ms step_avg:97.64ms +step:1029/1670 train_time:100467ms step_avg:97.64ms +step:1030/1670 train_time:100564ms step_avg:97.63ms +step:1031/1670 train_time:100661ms step_avg:97.63ms +step:1032/1670 train_time:100760ms step_avg:97.64ms +step:1033/1670 train_time:100859ms step_avg:97.64ms +step:1034/1670 train_time:100957ms step_avg:97.64ms +step:1035/1670 train_time:101055ms step_avg:97.64ms +step:1036/1670 train_time:101153ms step_avg:97.64ms +step:1037/1670 train_time:101249ms step_avg:97.64ms +step:1038/1670 train_time:101346ms step_avg:97.64ms +step:1039/1670 train_time:101443ms step_avg:97.63ms +step:1040/1670 train_time:101540ms step_avg:97.63ms +step:1041/1670 train_time:101639ms step_avg:97.64ms +step:1042/1670 train_time:101736ms step_avg:97.64ms +step:1043/1670 train_time:101835ms step_avg:97.64ms 
+step:1044/1670 train_time:101933ms step_avg:97.64ms +step:1045/1670 train_time:102030ms step_avg:97.64ms +step:1046/1670 train_time:102128ms step_avg:97.64ms +step:1047/1670 train_time:102225ms step_avg:97.64ms +step:1048/1670 train_time:102323ms step_avg:97.64ms +step:1049/1670 train_time:102421ms step_avg:97.64ms +step:1050/1670 train_time:102518ms step_avg:97.64ms +step:1051/1670 train_time:102616ms step_avg:97.64ms +step:1052/1670 train_time:102713ms step_avg:97.64ms +step:1053/1670 train_time:102810ms step_avg:97.63ms +step:1054/1670 train_time:102907ms step_avg:97.63ms +step:1055/1670 train_time:103005ms step_avg:97.63ms +step:1056/1670 train_time:103103ms step_avg:97.64ms +step:1057/1670 train_time:103202ms step_avg:97.64ms +step:1058/1670 train_time:103299ms step_avg:97.64ms +step:1059/1670 train_time:103396ms step_avg:97.64ms +step:1060/1670 train_time:103494ms step_avg:97.64ms +step:1061/1670 train_time:103592ms step_avg:97.64ms +step:1062/1670 train_time:103856ms step_avg:97.79ms +step:1063/1670 train_time:104047ms step_avg:97.88ms +step:1064/1670 train_time:104142ms step_avg:97.88ms +step:1065/1670 train_time:104238ms step_avg:97.88ms +step:1066/1670 train_time:104334ms step_avg:97.87ms +step:1067/1670 train_time:104431ms step_avg:97.87ms +step:1068/1670 train_time:104526ms step_avg:97.87ms +step:1069/1670 train_time:104623ms step_avg:97.87ms +step:1070/1670 train_time:104720ms step_avg:97.87ms +step:1071/1670 train_time:104817ms step_avg:97.87ms +step:1072/1670 train_time:104917ms step_avg:97.87ms +step:1073/1670 train_time:105021ms step_avg:97.88ms +step:1074/1670 train_time:105122ms step_avg:97.88ms +step:1075/1670 train_time:105219ms step_avg:97.88ms +step:1076/1670 train_time:105316ms step_avg:97.88ms +step:1077/1670 train_time:105414ms step_avg:97.88ms +step:1078/1670 train_time:105512ms step_avg:97.88ms +step:1079/1670 train_time:105609ms step_avg:97.88ms +step:1080/1670 train_time:105705ms step_avg:97.87ms +step:1081/1670 train_time:105801ms step_avg:97.87ms +step:1082/1670 train_time:105898ms step_avg:97.87ms +step:1083/1670 train_time:105998ms step_avg:97.87ms +step:1084/1670 train_time:106095ms step_avg:97.87ms +step:1085/1670 train_time:106194ms step_avg:97.87ms +step:1086/1670 train_time:106290ms step_avg:97.87ms +step:1087/1670 train_time:106387ms step_avg:97.87ms +step:1088/1670 train_time:106484ms step_avg:97.87ms +step:1089/1670 train_time:106581ms step_avg:97.87ms +step:1090/1670 train_time:106678ms step_avg:97.87ms +step:1091/1670 train_time:106776ms step_avg:97.87ms +step:1092/1670 train_time:106874ms step_avg:97.87ms +step:1093/1670 train_time:106971ms step_avg:97.87ms +step:1094/1670 train_time:107068ms step_avg:97.87ms +step:1095/1670 train_time:107166ms step_avg:97.87ms +step:1096/1670 train_time:107264ms step_avg:97.87ms +step:1097/1670 train_time:107363ms step_avg:97.87ms +step:1098/1670 train_time:107460ms step_avg:97.87ms +step:1099/1670 train_time:107558ms step_avg:97.87ms +step:1100/1670 train_time:107655ms step_avg:97.87ms +step:1101/1670 train_time:107752ms step_avg:97.87ms +step:1102/1670 train_time:107848ms step_avg:97.87ms +step:1103/1670 train_time:107945ms step_avg:97.87ms +step:1104/1670 train_time:108043ms step_avg:97.86ms +step:1105/1670 train_time:108142ms step_avg:97.87ms +step:1106/1670 train_time:108240ms step_avg:97.87ms +step:1107/1670 train_time:108338ms step_avg:97.87ms +step:1108/1670 train_time:108435ms step_avg:97.87ms +step:1109/1670 train_time:108532ms step_avg:97.86ms +step:1110/1670 train_time:108629ms step_avg:97.86ms 
+step:1111/1670 train_time:108726ms step_avg:97.86ms +step:1112/1670 train_time:108823ms step_avg:97.86ms +step:1113/1670 train_time:108921ms step_avg:97.86ms +step:1114/1670 train_time:109019ms step_avg:97.86ms +step:1115/1670 train_time:109118ms step_avg:97.86ms +step:1116/1670 train_time:109216ms step_avg:97.86ms +step:1117/1670 train_time:109314ms step_avg:97.86ms +step:1118/1670 train_time:109412ms step_avg:97.86ms +step:1119/1670 train_time:109509ms step_avg:97.86ms +step:1120/1670 train_time:109607ms step_avg:97.86ms +step:1121/1670 train_time:109705ms step_avg:97.86ms +step:1122/1670 train_time:109802ms step_avg:97.86ms +step:1123/1670 train_time:109901ms step_avg:97.86ms +step:1124/1670 train_time:110000ms step_avg:97.87ms +step:1125/1670 train_time:110100ms step_avg:97.87ms +step:1125/1670 val_loss:3.4233 train_time:110198ms step_avg:97.95ms +step:1126/1670 train_time:110221ms step_avg:97.89ms +step:1127/1670 train_time:110303ms step_avg:97.87ms +step:1128/1670 train_time:110402ms step_avg:97.87ms +step:1129/1670 train_time:110499ms step_avg:97.87ms +step:1130/1670 train_time:110595ms step_avg:97.87ms +step:1131/1670 train_time:110691ms step_avg:97.87ms +step:1132/1670 train_time:110788ms step_avg:97.87ms +step:1133/1670 train_time:110885ms step_avg:97.87ms +step:1134/1670 train_time:110983ms step_avg:97.87ms +step:1135/1670 train_time:111081ms step_avg:97.87ms +step:1136/1670 train_time:111184ms step_avg:97.87ms +step:1137/1670 train_time:111288ms step_avg:97.88ms +step:1138/1670 train_time:111387ms step_avg:97.88ms +step:1139/1670 train_time:111486ms step_avg:97.88ms +step:1140/1670 train_time:111585ms step_avg:97.88ms +step:1141/1670 train_time:111683ms step_avg:97.88ms +step:1142/1670 train_time:111780ms step_avg:97.88ms +step:1143/1670 train_time:111876ms step_avg:97.88ms +step:1144/1670 train_time:111974ms step_avg:97.88ms +step:1145/1670 train_time:112071ms step_avg:97.88ms +step:1146/1670 train_time:112170ms step_avg:97.88ms +step:1147/1670 train_time:112268ms step_avg:97.88ms +step:1148/1670 train_time:112368ms step_avg:97.88ms +step:1149/1670 train_time:112466ms step_avg:97.88ms +step:1150/1670 train_time:112564ms step_avg:97.88ms +step:1151/1670 train_time:112664ms step_avg:97.88ms +step:1152/1670 train_time:112762ms step_avg:97.88ms +step:1153/1670 train_time:112860ms step_avg:97.88ms +step:1154/1670 train_time:112956ms step_avg:97.88ms +step:1155/1670 train_time:113054ms step_avg:97.88ms +step:1156/1670 train_time:113153ms step_avg:97.88ms +step:1157/1670 train_time:113251ms step_avg:97.88ms +step:1158/1670 train_time:113349ms step_avg:97.88ms +step:1159/1670 train_time:113448ms step_avg:97.88ms +step:1160/1670 train_time:113546ms step_avg:97.88ms +step:1161/1670 train_time:113644ms step_avg:97.88ms +step:1162/1670 train_time:113741ms step_avg:97.88ms +step:1163/1670 train_time:113840ms step_avg:97.88ms +step:1164/1670 train_time:113938ms step_avg:97.88ms +step:1165/1670 train_time:114037ms step_avg:97.89ms +step:1166/1670 train_time:114136ms step_avg:97.89ms +step:1167/1670 train_time:114236ms step_avg:97.89ms +step:1168/1670 train_time:114335ms step_avg:97.89ms +step:1169/1670 train_time:114434ms step_avg:97.89ms +step:1170/1670 train_time:114531ms step_avg:97.89ms +step:1171/1670 train_time:114629ms step_avg:97.89ms +step:1172/1670 train_time:114727ms step_avg:97.89ms +step:1173/1670 train_time:114825ms step_avg:97.89ms +step:1174/1670 train_time:114923ms step_avg:97.89ms +step:1175/1670 train_time:115022ms step_avg:97.89ms +step:1176/1670 train_time:115123ms 
step_avg:97.89ms +step:1177/1670 train_time:115222ms step_avg:97.89ms +step:1178/1670 train_time:115319ms step_avg:97.89ms +step:1179/1670 train_time:115418ms step_avg:97.90ms +step:1180/1670 train_time:115517ms step_avg:97.90ms +step:1181/1670 train_time:115614ms step_avg:97.89ms +step:1182/1670 train_time:115710ms step_avg:97.89ms +step:1183/1670 train_time:115808ms step_avg:97.89ms +step:1184/1670 train_time:115906ms step_avg:97.89ms +step:1185/1670 train_time:116005ms step_avg:97.89ms +step:1186/1670 train_time:116104ms step_avg:97.90ms +step:1187/1670 train_time:116202ms step_avg:97.90ms +step:1188/1670 train_time:116301ms step_avg:97.90ms +step:1189/1670 train_time:116400ms step_avg:97.90ms +step:1190/1670 train_time:116498ms step_avg:97.90ms +step:1191/1670 train_time:116597ms step_avg:97.90ms +step:1192/1670 train_time:116695ms step_avg:97.90ms +step:1193/1670 train_time:116792ms step_avg:97.90ms +step:1194/1670 train_time:116888ms step_avg:97.90ms +step:1195/1670 train_time:116986ms step_avg:97.90ms +step:1196/1670 train_time:117085ms step_avg:97.90ms +step:1197/1670 train_time:117183ms step_avg:97.90ms +step:1198/1670 train_time:117283ms step_avg:97.90ms +step:1199/1670 train_time:117382ms step_avg:97.90ms +step:1200/1670 train_time:117482ms step_avg:97.90ms +step:1201/1670 train_time:117581ms step_avg:97.90ms +step:1202/1670 train_time:117680ms step_avg:97.90ms +step:1203/1670 train_time:117779ms step_avg:97.90ms +step:1204/1670 train_time:117877ms step_avg:97.90ms +step:1205/1670 train_time:117974ms step_avg:97.90ms +step:1206/1670 train_time:118071ms step_avg:97.90ms +step:1207/1670 train_time:118168ms step_avg:97.90ms +step:1208/1670 train_time:118267ms step_avg:97.90ms +step:1209/1670 train_time:118366ms step_avg:97.90ms +step:1210/1670 train_time:118465ms step_avg:97.90ms +step:1211/1670 train_time:118565ms step_avg:97.91ms +step:1212/1670 train_time:118665ms step_avg:97.91ms +step:1213/1670 train_time:118765ms step_avg:97.91ms +step:1214/1670 train_time:118865ms step_avg:97.91ms +step:1215/1670 train_time:118965ms step_avg:97.91ms +step:1216/1670 train_time:119064ms step_avg:97.91ms +step:1217/1670 train_time:119162ms step_avg:97.91ms +step:1218/1670 train_time:119260ms step_avg:97.91ms +step:1219/1670 train_time:119358ms step_avg:97.91ms +step:1220/1670 train_time:119456ms step_avg:97.92ms +step:1221/1670 train_time:119553ms step_avg:97.91ms +step:1222/1670 train_time:119651ms step_avg:97.91ms +step:1223/1670 train_time:119749ms step_avg:97.91ms +step:1224/1670 train_time:119848ms step_avg:97.91ms +step:1225/1670 train_time:119946ms step_avg:97.91ms +step:1226/1670 train_time:120044ms step_avg:97.92ms +step:1227/1670 train_time:120142ms step_avg:97.92ms +step:1228/1670 train_time:120240ms step_avg:97.92ms +step:1229/1670 train_time:120338ms step_avg:97.92ms +step:1230/1670 train_time:120437ms step_avg:97.92ms +step:1231/1670 train_time:120535ms step_avg:97.92ms +step:1232/1670 train_time:120633ms step_avg:97.92ms +step:1233/1670 train_time:120730ms step_avg:97.92ms +step:1234/1670 train_time:120827ms step_avg:97.92ms +step:1235/1670 train_time:120925ms step_avg:97.92ms +step:1236/1670 train_time:121023ms step_avg:97.92ms +step:1237/1670 train_time:121122ms step_avg:97.92ms +step:1238/1670 train_time:121219ms step_avg:97.92ms +step:1239/1670 train_time:121317ms step_avg:97.92ms +step:1240/1670 train_time:121416ms step_avg:97.92ms +step:1241/1670 train_time:121514ms step_avg:97.92ms +step:1242/1670 train_time:121611ms step_avg:97.92ms +step:1243/1670 train_time:121708ms 
step_avg:97.91ms +step:1244/1670 train_time:121805ms step_avg:97.91ms +step:1245/1670 train_time:121904ms step_avg:97.91ms +step:1246/1670 train_time:122002ms step_avg:97.92ms +step:1247/1670 train_time:122100ms step_avg:97.92ms +step:1248/1670 train_time:122198ms step_avg:97.92ms +step:1249/1670 train_time:122297ms step_avg:97.92ms +step:1250/1670 train_time:122395ms step_avg:97.92ms +step:1250/1670 val_loss:3.3801 train_time:122492ms step_avg:97.99ms +step:1251/1670 train_time:122515ms step_avg:97.93ms +step:1252/1670 train_time:122598ms step_avg:97.92ms +step:1253/1670 train_time:122696ms step_avg:97.92ms +step:1254/1670 train_time:122794ms step_avg:97.92ms +step:1255/1670 train_time:122892ms step_avg:97.92ms +step:1256/1670 train_time:122989ms step_avg:97.92ms +step:1257/1670 train_time:123087ms step_avg:97.92ms +step:1258/1670 train_time:123184ms step_avg:97.92ms +step:1259/1670 train_time:123282ms step_avg:97.92ms +step:1260/1670 train_time:123378ms step_avg:97.92ms +step:1261/1670 train_time:123478ms step_avg:97.92ms +step:1262/1670 train_time:123578ms step_avg:97.92ms +step:1263/1670 train_time:123676ms step_avg:97.92ms +step:1264/1670 train_time:123775ms step_avg:97.92ms +step:1265/1670 train_time:123874ms step_avg:97.92ms +step:1266/1670 train_time:123971ms step_avg:97.92ms +step:1267/1670 train_time:124068ms step_avg:97.92ms +step:1268/1670 train_time:124166ms step_avg:97.92ms +step:1269/1670 train_time:124263ms step_avg:97.92ms +step:1270/1670 train_time:124360ms step_avg:97.92ms +step:1271/1670 train_time:124458ms step_avg:97.92ms +step:1272/1670 train_time:124557ms step_avg:97.92ms +step:1273/1670 train_time:124655ms step_avg:97.92ms +step:1274/1670 train_time:125017ms step_avg:98.13ms +step:1275/1670 train_time:125117ms step_avg:98.13ms +step:1276/1670 train_time:125214ms step_avg:98.13ms +step:1277/1670 train_time:125310ms step_avg:98.13ms +step:1278/1670 train_time:125408ms step_avg:98.13ms +step:1279/1670 train_time:125505ms step_avg:98.13ms +step:1280/1670 train_time:125602ms step_avg:98.13ms +step:1281/1670 train_time:125698ms step_avg:98.12ms +step:1282/1670 train_time:125795ms step_avg:98.12ms +step:1283/1670 train_time:125892ms step_avg:98.12ms +step:1284/1670 train_time:125998ms step_avg:98.13ms +step:1285/1670 train_time:126098ms step_avg:98.13ms +step:1286/1670 train_time:126196ms step_avg:98.13ms +step:1287/1670 train_time:126293ms step_avg:98.13ms +step:1288/1670 train_time:126390ms step_avg:98.13ms +step:1289/1670 train_time:126487ms step_avg:98.13ms +step:1290/1670 train_time:126585ms step_avg:98.13ms +step:1291/1670 train_time:126683ms step_avg:98.13ms +step:1292/1670 train_time:126780ms step_avg:98.13ms +step:1293/1670 train_time:126878ms step_avg:98.13ms +step:1294/1670 train_time:126977ms step_avg:98.13ms +step:1295/1670 train_time:127077ms step_avg:98.13ms +step:1296/1670 train_time:127175ms step_avg:98.13ms +step:1297/1670 train_time:127273ms step_avg:98.13ms +step:1298/1670 train_time:127371ms step_avg:98.13ms +step:1299/1670 train_time:127469ms step_avg:98.13ms +step:1300/1670 train_time:127566ms step_avg:98.13ms +step:1301/1670 train_time:127664ms step_avg:98.13ms +step:1302/1670 train_time:127761ms step_avg:98.13ms +step:1303/1670 train_time:127859ms step_avg:98.13ms +step:1304/1670 train_time:127956ms step_avg:98.13ms +step:1305/1670 train_time:128055ms step_avg:98.13ms +step:1306/1670 train_time:128154ms step_avg:98.13ms +step:1307/1670 train_time:128253ms step_avg:98.13ms +step:1308/1670 train_time:128350ms step_avg:98.13ms +step:1309/1670 
train_time:128448ms step_avg:98.13ms +step:1310/1670 train_time:128547ms step_avg:98.13ms +step:1311/1670 train_time:128644ms step_avg:98.13ms +step:1312/1670 train_time:128742ms step_avg:98.13ms +step:1313/1670 train_time:128839ms step_avg:98.13ms +step:1314/1670 train_time:128937ms step_avg:98.13ms +step:1315/1670 train_time:129035ms step_avg:98.13ms +step:1316/1670 train_time:129133ms step_avg:98.13ms +step:1317/1670 train_time:129230ms step_avg:98.12ms +step:1318/1670 train_time:129329ms step_avg:98.13ms +step:1319/1670 train_time:129427ms step_avg:98.13ms +step:1320/1670 train_time:129525ms step_avg:98.13ms +step:1321/1670 train_time:129623ms step_avg:98.12ms +step:1322/1670 train_time:129721ms step_avg:98.12ms +step:1323/1670 train_time:129818ms step_avg:98.12ms +step:1324/1670 train_time:129916ms step_avg:98.12ms +step:1325/1670 train_time:130015ms step_avg:98.12ms +step:1326/1670 train_time:130114ms step_avg:98.13ms +step:1327/1670 train_time:130212ms step_avg:98.12ms +step:1328/1670 train_time:130310ms step_avg:98.12ms +step:1329/1670 train_time:130407ms step_avg:98.12ms +step:1330/1670 train_time:130506ms step_avg:98.12ms +step:1331/1670 train_time:130604ms step_avg:98.12ms +step:1332/1670 train_time:130701ms step_avg:98.12ms +step:1333/1670 train_time:130799ms step_avg:98.12ms +step:1334/1670 train_time:130899ms step_avg:98.12ms +step:1335/1670 train_time:130997ms step_avg:98.13ms +step:1336/1670 train_time:131095ms step_avg:98.12ms +step:1337/1670 train_time:131192ms step_avg:98.12ms +step:1338/1670 train_time:131290ms step_avg:98.12ms +step:1339/1670 train_time:131388ms step_avg:98.12ms +step:1340/1670 train_time:131486ms step_avg:98.12ms +step:1341/1670 train_time:131583ms step_avg:98.12ms +step:1342/1670 train_time:131681ms step_avg:98.12ms +step:1343/1670 train_time:131780ms step_avg:98.12ms +step:1344/1670 train_time:131878ms step_avg:98.12ms +step:1345/1670 train_time:131975ms step_avg:98.12ms +step:1346/1670 train_time:132073ms step_avg:98.12ms +step:1347/1670 train_time:132171ms step_avg:98.12ms +step:1348/1670 train_time:132269ms step_avg:98.12ms +step:1349/1670 train_time:132367ms step_avg:98.12ms +step:1350/1670 train_time:132465ms step_avg:98.12ms +step:1351/1670 train_time:132564ms step_avg:98.12ms +step:1352/1670 train_time:132661ms step_avg:98.12ms +step:1353/1670 train_time:132760ms step_avg:98.12ms +step:1354/1670 train_time:132858ms step_avg:98.12ms +step:1355/1670 train_time:132956ms step_avg:98.12ms +step:1356/1670 train_time:133054ms step_avg:98.12ms +step:1357/1670 train_time:133152ms step_avg:98.12ms +step:1358/1670 train_time:133249ms step_avg:98.12ms +step:1359/1670 train_time:133347ms step_avg:98.12ms +step:1360/1670 train_time:133444ms step_avg:98.12ms +step:1361/1670 train_time:133542ms step_avg:98.12ms +step:1362/1670 train_time:133639ms step_avg:98.12ms +step:1363/1670 train_time:133737ms step_avg:98.12ms +step:1364/1670 train_time:133835ms step_avg:98.12ms +step:1365/1670 train_time:133934ms step_avg:98.12ms +step:1366/1670 train_time:134032ms step_avg:98.12ms +step:1367/1670 train_time:134130ms step_avg:98.12ms +step:1368/1670 train_time:134229ms step_avg:98.12ms +step:1369/1670 train_time:134328ms step_avg:98.12ms +step:1370/1670 train_time:134427ms step_avg:98.12ms +step:1371/1670 train_time:134525ms step_avg:98.12ms +step:1372/1670 train_time:134622ms step_avg:98.12ms +step:1373/1670 train_time:134721ms step_avg:98.12ms +step:1374/1670 train_time:134820ms step_avg:98.12ms +step:1375/1670 train_time:134917ms step_avg:98.12ms +step:1375/1670 
val_loss:3.3425 train_time:135014ms step_avg:98.19ms +step:1376/1670 train_time:135036ms step_avg:98.14ms +step:1377/1670 train_time:135122ms step_avg:98.13ms +step:1378/1670 train_time:135224ms step_avg:98.13ms +step:1379/1670 train_time:135323ms step_avg:98.13ms +step:1380/1670 train_time:135421ms step_avg:98.13ms +step:1381/1670 train_time:135518ms step_avg:98.13ms +step:1382/1670 train_time:135615ms step_avg:98.13ms +step:1383/1670 train_time:135712ms step_avg:98.13ms +step:1384/1670 train_time:135809ms step_avg:98.13ms +step:1385/1670 train_time:135906ms step_avg:98.13ms +step:1386/1670 train_time:136006ms step_avg:98.13ms +step:1387/1670 train_time:136108ms step_avg:98.13ms +step:1388/1670 train_time:136208ms step_avg:98.13ms +step:1389/1670 train_time:136307ms step_avg:98.13ms +step:1390/1670 train_time:136406ms step_avg:98.13ms +step:1391/1670 train_time:136503ms step_avg:98.13ms +step:1392/1670 train_time:136602ms step_avg:98.13ms +step:1393/1670 train_time:136699ms step_avg:98.13ms +step:1394/1670 train_time:136796ms step_avg:98.13ms +step:1395/1670 train_time:136894ms step_avg:98.13ms +step:1396/1670 train_time:136992ms step_avg:98.13ms +step:1397/1670 train_time:137090ms step_avg:98.13ms +step:1398/1670 train_time:137189ms step_avg:98.13ms +step:1399/1670 train_time:137287ms step_avg:98.13ms +step:1400/1670 train_time:137386ms step_avg:98.13ms +step:1401/1670 train_time:137484ms step_avg:98.13ms +step:1402/1670 train_time:137581ms step_avg:98.13ms +step:1403/1670 train_time:137679ms step_avg:98.13ms +step:1404/1670 train_time:137777ms step_avg:98.13ms +step:1405/1670 train_time:137874ms step_avg:98.13ms +step:1406/1670 train_time:137973ms step_avg:98.13ms +step:1407/1670 train_time:138071ms step_avg:98.13ms +step:1408/1670 train_time:138169ms step_avg:98.13ms +step:1409/1670 train_time:138267ms step_avg:98.13ms +step:1410/1670 train_time:138367ms step_avg:98.13ms +step:1411/1670 train_time:138465ms step_avg:98.13ms +step:1412/1670 train_time:138563ms step_avg:98.13ms +step:1413/1670 train_time:138661ms step_avg:98.13ms +step:1414/1670 train_time:138759ms step_avg:98.13ms +step:1415/1670 train_time:138858ms step_avg:98.13ms +step:1416/1670 train_time:138956ms step_avg:98.13ms +step:1417/1670 train_time:139055ms step_avg:98.13ms +step:1418/1670 train_time:139152ms step_avg:98.13ms +step:1419/1670 train_time:139251ms step_avg:98.13ms +step:1420/1670 train_time:139349ms step_avg:98.13ms +step:1421/1670 train_time:139446ms step_avg:98.13ms +step:1422/1670 train_time:139544ms step_avg:98.13ms +step:1423/1670 train_time:139642ms step_avg:98.13ms +step:1424/1670 train_time:139740ms step_avg:98.13ms +step:1425/1670 train_time:139838ms step_avg:98.13ms +step:1426/1670 train_time:139936ms step_avg:98.13ms +step:1427/1670 train_time:140035ms step_avg:98.13ms +step:1428/1670 train_time:140133ms step_avg:98.13ms +step:1429/1670 train_time:140231ms step_avg:98.13ms +step:1430/1670 train_time:140330ms step_avg:98.13ms +step:1431/1670 train_time:140427ms step_avg:98.13ms +step:1432/1670 train_time:140525ms step_avg:98.13ms +step:1433/1670 train_time:140623ms step_avg:98.13ms +step:1434/1670 train_time:140720ms step_avg:98.13ms +step:1435/1670 train_time:140819ms step_avg:98.13ms +step:1436/1670 train_time:140917ms step_avg:98.13ms +step:1437/1670 train_time:141015ms step_avg:98.13ms +step:1438/1670 train_time:141114ms step_avg:98.13ms +step:1439/1670 train_time:141213ms step_avg:98.13ms +step:1440/1670 train_time:141312ms step_avg:98.13ms +step:1441/1670 train_time:141409ms step_avg:98.13ms 
+step:1442/1670 train_time:141506ms step_avg:98.13ms +step:1443/1670 train_time:141604ms step_avg:98.13ms +step:1444/1670 train_time:141702ms step_avg:98.13ms +step:1445/1670 train_time:141800ms step_avg:98.13ms +step:1446/1670 train_time:141899ms step_avg:98.13ms +step:1447/1670 train_time:141998ms step_avg:98.13ms +step:1448/1670 train_time:142097ms step_avg:98.13ms +step:1449/1670 train_time:142197ms step_avg:98.13ms +step:1450/1670 train_time:142296ms step_avg:98.14ms +step:1451/1670 train_time:142394ms step_avg:98.14ms +step:1452/1670 train_time:142493ms step_avg:98.14ms +step:1453/1670 train_time:142590ms step_avg:98.13ms +step:1454/1670 train_time:142688ms step_avg:98.13ms +step:1455/1670 train_time:142786ms step_avg:98.13ms +step:1456/1670 train_time:142885ms step_avg:98.14ms +step:1457/1670 train_time:142984ms step_avg:98.14ms +step:1458/1670 train_time:143085ms step_avg:98.14ms +step:1459/1670 train_time:143184ms step_avg:98.14ms +step:1460/1670 train_time:143284ms step_avg:98.14ms +step:1461/1670 train_time:143384ms step_avg:98.14ms +step:1462/1670 train_time:143483ms step_avg:98.14ms +step:1463/1670 train_time:143581ms step_avg:98.14ms +step:1464/1670 train_time:143680ms step_avg:98.14ms +step:1465/1670 train_time:143778ms step_avg:98.14ms +step:1466/1670 train_time:143875ms step_avg:98.14ms +step:1467/1670 train_time:143972ms step_avg:98.14ms +step:1468/1670 train_time:144070ms step_avg:98.14ms +step:1469/1670 train_time:144168ms step_avg:98.14ms +step:1470/1670 train_time:144266ms step_avg:98.14ms +step:1471/1670 train_time:144365ms step_avg:98.14ms +step:1472/1670 train_time:144465ms step_avg:98.14ms +step:1473/1670 train_time:144565ms step_avg:98.14ms +step:1474/1670 train_time:144662ms step_avg:98.14ms +step:1475/1670 train_time:144760ms step_avg:98.14ms +step:1476/1670 train_time:144859ms step_avg:98.14ms +step:1477/1670 train_time:144957ms step_avg:98.14ms +step:1478/1670 train_time:145055ms step_avg:98.14ms +step:1479/1670 train_time:145153ms step_avg:98.14ms +step:1480/1670 train_time:145251ms step_avg:98.14ms +step:1481/1670 train_time:145349ms step_avg:98.14ms +step:1482/1670 train_time:145447ms step_avg:98.14ms +step:1483/1670 train_time:145545ms step_avg:98.14ms +step:1484/1670 train_time:145642ms step_avg:98.14ms +step:1485/1670 train_time:145908ms step_avg:98.25ms +step:1486/1670 train_time:146091ms step_avg:98.31ms +step:1487/1670 train_time:146187ms step_avg:98.31ms +step:1488/1670 train_time:146284ms step_avg:98.31ms +step:1489/1670 train_time:146381ms step_avg:98.31ms +step:1490/1670 train_time:146479ms step_avg:98.31ms +step:1491/1670 train_time:146576ms step_avg:98.31ms +step:1492/1670 train_time:146673ms step_avg:98.31ms +step:1493/1670 train_time:146769ms step_avg:98.30ms +step:1494/1670 train_time:146867ms step_avg:98.30ms +step:1495/1670 train_time:146973ms step_avg:98.31ms +step:1496/1670 train_time:147073ms step_avg:98.31ms +step:1497/1670 train_time:147170ms step_avg:98.31ms +step:1498/1670 train_time:147267ms step_avg:98.31ms +step:1499/1670 train_time:147365ms step_avg:98.31ms +step:1500/1670 train_time:147462ms step_avg:98.31ms +step:1500/1670 val_loss:3.3106 train_time:147559ms step_avg:98.37ms +step:1501/1670 train_time:147582ms step_avg:98.32ms +step:1502/1670 train_time:147665ms step_avg:98.31ms +step:1503/1670 train_time:147766ms step_avg:98.31ms +step:1504/1670 train_time:147865ms step_avg:98.31ms +step:1505/1670 train_time:147962ms step_avg:98.31ms +step:1506/1670 train_time:148059ms step_avg:98.31ms +step:1507/1670 train_time:148155ms 
step_avg:98.31ms +step:1508/1670 train_time:148252ms step_avg:98.31ms +step:1509/1670 train_time:148349ms step_avg:98.31ms +step:1510/1670 train_time:148447ms step_avg:98.31ms +step:1511/1670 train_time:148547ms step_avg:98.31ms +step:1512/1670 train_time:148648ms step_avg:98.31ms +step:1513/1670 train_time:148748ms step_avg:98.31ms +step:1514/1670 train_time:148847ms step_avg:98.31ms +step:1515/1670 train_time:148945ms step_avg:98.31ms +step:1516/1670 train_time:149044ms step_avg:98.31ms +step:1517/1670 train_time:149141ms step_avg:98.31ms +step:1518/1670 train_time:149238ms step_avg:98.31ms +step:1519/1670 train_time:149335ms step_avg:98.31ms +step:1520/1670 train_time:149432ms step_avg:98.31ms +step:1521/1670 train_time:149530ms step_avg:98.31ms +step:1522/1670 train_time:149629ms step_avg:98.31ms +step:1523/1670 train_time:149729ms step_avg:98.31ms +step:1524/1670 train_time:149828ms step_avg:98.31ms +step:1525/1670 train_time:149927ms step_avg:98.31ms +step:1526/1670 train_time:150025ms step_avg:98.31ms +step:1527/1670 train_time:150124ms step_avg:98.31ms +step:1528/1670 train_time:150221ms step_avg:98.31ms +step:1529/1670 train_time:150319ms step_avg:98.31ms +step:1530/1670 train_time:150416ms step_avg:98.31ms +step:1531/1670 train_time:150513ms step_avg:98.31ms +step:1532/1670 train_time:150611ms step_avg:98.31ms +step:1533/1670 train_time:150710ms step_avg:98.31ms +step:1534/1670 train_time:150810ms step_avg:98.31ms +step:1535/1670 train_time:150908ms step_avg:98.31ms +step:1536/1670 train_time:151006ms step_avg:98.31ms +step:1537/1670 train_time:151104ms step_avg:98.31ms +step:1538/1670 train_time:151204ms step_avg:98.31ms +step:1539/1670 train_time:151302ms step_avg:98.31ms +step:1540/1670 train_time:151400ms step_avg:98.31ms +step:1541/1670 train_time:151499ms step_avg:98.31ms +step:1542/1670 train_time:151597ms step_avg:98.31ms +step:1543/1670 train_time:151695ms step_avg:98.31ms +step:1544/1670 train_time:151792ms step_avg:98.31ms +step:1545/1670 train_time:151889ms step_avg:98.31ms +step:1546/1670 train_time:151987ms step_avg:98.31ms +step:1547/1670 train_time:152085ms step_avg:98.31ms +step:1548/1670 train_time:152184ms step_avg:98.31ms +step:1549/1670 train_time:152282ms step_avg:98.31ms +step:1550/1670 train_time:152380ms step_avg:98.31ms +step:1551/1670 train_time:152478ms step_avg:98.31ms +step:1552/1670 train_time:152577ms step_avg:98.31ms +step:1553/1670 train_time:152674ms step_avg:98.31ms +step:1554/1670 train_time:152772ms step_avg:98.31ms +step:1555/1670 train_time:152870ms step_avg:98.31ms +step:1556/1670 train_time:152968ms step_avg:98.31ms +step:1557/1670 train_time:153066ms step_avg:98.31ms +step:1558/1670 train_time:153165ms step_avg:98.31ms +step:1559/1670 train_time:153264ms step_avg:98.31ms +step:1560/1670 train_time:153362ms step_avg:98.31ms +step:1561/1670 train_time:153461ms step_avg:98.31ms +step:1562/1670 train_time:153560ms step_avg:98.31ms +step:1563/1670 train_time:153660ms step_avg:98.31ms +step:1564/1670 train_time:153760ms step_avg:98.31ms +step:1565/1670 train_time:153858ms step_avg:98.31ms +step:1566/1670 train_time:153956ms step_avg:98.31ms +step:1567/1670 train_time:154052ms step_avg:98.31ms +step:1568/1670 train_time:154149ms step_avg:98.31ms +step:1569/1670 train_time:154247ms step_avg:98.31ms +step:1570/1670 train_time:154345ms step_avg:98.31ms +step:1571/1670 train_time:154443ms step_avg:98.31ms +step:1572/1670 train_time:154542ms step_avg:98.31ms +step:1573/1670 train_time:154642ms step_avg:98.31ms +step:1574/1670 train_time:154742ms 
step_avg:98.31ms +step:1575/1670 train_time:154840ms step_avg:98.31ms +step:1576/1670 train_time:154939ms step_avg:98.31ms +step:1577/1670 train_time:155037ms step_avg:98.31ms +step:1578/1670 train_time:155134ms step_avg:98.31ms +step:1579/1670 train_time:155231ms step_avg:98.31ms +step:1580/1670 train_time:155329ms step_avg:98.31ms +step:1581/1670 train_time:155428ms step_avg:98.31ms +step:1582/1670 train_time:155527ms step_avg:98.31ms +step:1583/1670 train_time:155627ms step_avg:98.31ms +step:1584/1670 train_time:155727ms step_avg:98.31ms +step:1585/1670 train_time:155827ms step_avg:98.31ms +step:1586/1670 train_time:155926ms step_avg:98.31ms +step:1587/1670 train_time:156025ms step_avg:98.31ms +step:1588/1670 train_time:156126ms step_avg:98.32ms +step:1589/1670 train_time:156225ms step_avg:98.32ms +step:1590/1670 train_time:156322ms step_avg:98.32ms +step:1591/1670 train_time:156421ms step_avg:98.32ms +step:1592/1670 train_time:156519ms step_avg:98.32ms +step:1593/1670 train_time:156617ms step_avg:98.32ms +step:1594/1670 train_time:156715ms step_avg:98.32ms +step:1595/1670 train_time:156813ms step_avg:98.32ms +step:1596/1670 train_time:156910ms step_avg:98.31ms +step:1597/1670 train_time:157009ms step_avg:98.32ms +step:1598/1670 train_time:157108ms step_avg:98.32ms +step:1599/1670 train_time:157207ms step_avg:98.32ms +step:1600/1670 train_time:157305ms step_avg:98.32ms +step:1601/1670 train_time:157404ms step_avg:98.32ms +step:1602/1670 train_time:157503ms step_avg:98.32ms +step:1603/1670 train_time:157601ms step_avg:98.32ms +step:1604/1670 train_time:157699ms step_avg:98.32ms +step:1605/1670 train_time:157798ms step_avg:98.32ms +step:1606/1670 train_time:157895ms step_avg:98.32ms +step:1607/1670 train_time:157993ms step_avg:98.32ms +step:1608/1670 train_time:158090ms step_avg:98.31ms +step:1609/1670 train_time:158188ms step_avg:98.31ms +step:1610/1670 train_time:158287ms step_avg:98.32ms +step:1611/1670 train_time:158386ms step_avg:98.32ms +step:1612/1670 train_time:158485ms step_avg:98.32ms +step:1613/1670 train_time:158583ms step_avg:98.32ms +step:1614/1670 train_time:158683ms step_avg:98.32ms +step:1615/1670 train_time:158783ms step_avg:98.32ms +step:1616/1670 train_time:158883ms step_avg:98.32ms +step:1617/1670 train_time:158983ms step_avg:98.32ms +step:1618/1670 train_time:159082ms step_avg:98.32ms +step:1619/1670 train_time:159181ms step_avg:98.32ms +step:1620/1670 train_time:159279ms step_avg:98.32ms +step:1621/1670 train_time:159376ms step_avg:98.32ms +step:1622/1670 train_time:159474ms step_avg:98.32ms +step:1623/1670 train_time:159571ms step_avg:98.32ms +step:1624/1670 train_time:159670ms step_avg:98.32ms +step:1625/1670 train_time:159769ms step_avg:98.32ms +step:1625/1670 val_loss:3.2839 train_time:159868ms step_avg:98.38ms +step:1626/1670 train_time:159893ms step_avg:98.33ms +step:1627/1670 train_time:159977ms step_avg:98.33ms +step:1628/1670 train_time:160077ms step_avg:98.33ms +step:1629/1670 train_time:160177ms step_avg:98.33ms +step:1630/1670 train_time:160275ms step_avg:98.33ms +step:1631/1670 train_time:160372ms step_avg:98.33ms +step:1632/1670 train_time:160468ms step_avg:98.33ms +step:1633/1670 train_time:160564ms step_avg:98.32ms +step:1634/1670 train_time:160661ms step_avg:98.32ms +step:1635/1670 train_time:160759ms step_avg:98.32ms +step:1636/1670 train_time:160858ms step_avg:98.32ms +step:1637/1670 train_time:160960ms step_avg:98.33ms +step:1638/1670 train_time:161060ms step_avg:98.33ms +step:1639/1670 train_time:161159ms step_avg:98.33ms +step:1640/1670 
train_time:161257ms step_avg:98.33ms +step:1641/1670 train_time:161354ms step_avg:98.33ms +step:1642/1670 train_time:161451ms step_avg:98.33ms +step:1643/1670 train_time:161549ms step_avg:98.33ms +step:1644/1670 train_time:161647ms step_avg:98.33ms +step:1645/1670 train_time:161744ms step_avg:98.32ms +step:1646/1670 train_time:161841ms step_avg:98.32ms +step:1647/1670 train_time:161940ms step_avg:98.32ms +step:1648/1670 train_time:162041ms step_avg:98.33ms +step:1649/1670 train_time:162139ms step_avg:98.33ms +step:1650/1670 train_time:162237ms step_avg:98.33ms +step:1651/1670 train_time:162336ms step_avg:98.33ms +step:1652/1670 train_time:162433ms step_avg:98.32ms +step:1653/1670 train_time:162531ms step_avg:98.32ms +step:1654/1670 train_time:162628ms step_avg:98.32ms +step:1655/1670 train_time:162726ms step_avg:98.32ms +step:1656/1670 train_time:162824ms step_avg:98.32ms +step:1657/1670 train_time:162923ms step_avg:98.32ms +step:1658/1670 train_time:163021ms step_avg:98.32ms +step:1659/1670 train_time:163119ms step_avg:98.32ms +step:1660/1670 train_time:163217ms step_avg:98.32ms +step:1661/1670 train_time:163315ms step_avg:98.32ms +step:1662/1670 train_time:163414ms step_avg:98.32ms +step:1663/1670 train_time:163513ms step_avg:98.32ms +step:1664/1670 train_time:163612ms step_avg:98.32ms +step:1665/1670 train_time:163711ms step_avg:98.32ms +step:1666/1670 train_time:163809ms step_avg:98.32ms +step:1667/1670 train_time:163908ms step_avg:98.33ms +step:1668/1670 train_time:164007ms step_avg:98.33ms +step:1669/1670 train_time:164106ms step_avg:98.33ms +step:1670/1670 train_time:164204ms step_avg:98.33ms +step:1670/1670 val_loss:3.2760 train_time:164301ms step_avg:98.38ms +peak memory allocated: 34000 MiB reserved: 49796 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt b/records/050925_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt new file mode 100644 index 000000000..c6951e93c --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, 
+            dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_1(A: torch.Tensor, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = A @ A.T
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert out.size(-2) == M, "Output matrix has incorrect shape"
+    assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_1_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        K=K,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+    )
+    return out
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_2_kernel(
+    A_ptr, C_ptr,
+    M,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    alpha, beta,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A
+    # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels
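+    # How the two kernels fit together (inferred from this file, not stated in it):
+    # ns_line_1 produces the symmetric matrix A = X @ X.T, and this kernel then forms
+    # B = alpha * A @ A.T + beta * A. With (alpha, beta) presumably set to (c, b) from
+    # the quintic coefficients (a, b, c) = (3.4445, -4.7750, 2.0315) used in
+    # newton_schulz_triton below, one Newton-Schulz step X <- a*X + b*(A @ X) + c*(A @ A @ X)
+    # reduces to X <- a*X + B @ X, i.e. a single further matmul per iteration. Since A and
+    # B are symmetric, both kernels compute only one triangle of the output and mirror it
+    # across the diagonal.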
+ pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
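+        # Sketch of the sharded step below: parameters in each group are dealt
+        # out round-robin across ranks. Every rank reduce-scatters the averaged
+        # gradients for the parameters it owns, applies decoupled weight decay,
+        # then a Nesterov-style momentum blend
+        #     buf <- lerp(buf, grad, 1 - momentum);  g <- lerp(grad, buf, momentum)
+        # orthogonalizes g with the 5-iteration Newton-Schulz kernel, steps
+        # p <- p - eff_lr * NS(g) with eff_lr = lr * max(1, rows/cols) ** 0.5
+        # (times any per-param lr_mul), and finally all-gathers the updated
+        # parameters back to every rank.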
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
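# exp_avg / exp_avg_sq hold this rank's 1/world_size shard (along dim 0)
+                    # of the Adam first- and second-moment state; the full moment tensors
+                    # never exist on any single rank, only updated parameters are gathered back. +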
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:23:20 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 43C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 43C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 33C P0 120W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 41C P0 122W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 42C P0 129W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 33C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 56619 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 56620 C /usr/bin/python3 610MiB | +| 0 N/A N/A 56621 C /usr/bin/python3 610MiB | +| 0 N/A N/A 56622 C /usr/bin/python3 610MiB | +| 0 N/A N/A 56623 C /usr/bin/python3 610MiB | +| 0 N/A N/A 56624 C /usr/bin/python3 610MiB | +| 0 N/A N/A 56625 C /usr/bin/python3 610MiB | +| 0 N/A N/A 56626 C /usr/bin/python3 610MiB | +| 1 N/A N/A 56620 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 56621 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 56622 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 56623 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 56624 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 56625 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 56626 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:377ms step_avg:376.56ms +step:2/1670 train_time:397ms step_avg:198.70ms +step:3/1670 train_time:470ms step_avg:156.65ms +step:4/1670 train_time:563ms step_avg:140.83ms +step:5/1670 train_time:658ms step_avg:131.53ms +step:6/1670 train_time:752ms step_avg:125.40ms +step:7/1670 train_time:847ms step_avg:121.00ms +step:8/1670 
train_time:942ms step_avg:117.72ms +step:9/1670 train_time:1194ms step_avg:132.64ms +step:10/1670 train_time:1267ms step_avg:126.71ms +step:11/1670 train_time:1361ms step_avg:123.70ms +step:12/1670 train_time:1455ms step_avg:121.28ms +step:13/1670 train_time:1550ms step_avg:119.25ms +step:14/1670 train_time:1645ms step_avg:117.49ms +step:15/1670 train_time:1740ms step_avg:115.98ms +step:16/1670 train_time:1835ms step_avg:114.69ms +step:17/1670 train_time:1930ms step_avg:113.55ms +step:18/1670 train_time:2025ms step_avg:112.49ms +step:19/1670 train_time:2121ms step_avg:111.62ms +step:20/1670 train_time:2222ms step_avg:111.12ms +step:21/1670 train_time:2321ms step_avg:110.54ms +step:22/1670 train_time:2418ms step_avg:109.92ms +step:23/1670 train_time:2515ms step_avg:109.35ms +step:24/1670 train_time:2611ms step_avg:108.77ms +step:25/1670 train_time:2705ms step_avg:108.20ms +step:26/1670 train_time:2800ms step_avg:107.71ms +step:27/1670 train_time:2896ms step_avg:107.27ms +step:28/1670 train_time:2991ms step_avg:106.81ms +step:29/1670 train_time:3086ms step_avg:106.41ms +step:30/1670 train_time:3182ms step_avg:106.07ms +step:31/1670 train_time:3279ms step_avg:105.77ms +step:32/1670 train_time:3376ms step_avg:105.49ms +step:33/1670 train_time:3473ms step_avg:105.23ms +step:34/1670 train_time:3568ms step_avg:104.95ms +step:35/1670 train_time:3663ms step_avg:104.67ms +step:36/1670 train_time:3759ms step_avg:104.42ms +step:37/1670 train_time:3855ms step_avg:104.18ms +step:38/1670 train_time:3949ms step_avg:103.93ms +step:39/1670 train_time:4045ms step_avg:103.71ms +step:40/1670 train_time:4140ms step_avg:103.51ms +step:41/1670 train_time:4238ms step_avg:103.36ms +step:42/1670 train_time:4332ms step_avg:103.15ms +step:43/1670 train_time:4428ms step_avg:102.98ms +step:44/1670 train_time:4524ms step_avg:102.82ms +step:45/1670 train_time:4620ms step_avg:102.67ms +step:46/1670 train_time:4717ms step_avg:102.54ms +step:47/1670 train_time:4813ms step_avg:102.41ms +step:48/1670 train_time:4909ms step_avg:102.27ms +step:49/1670 train_time:5004ms step_avg:102.12ms +step:50/1670 train_time:5100ms step_avg:102.00ms +step:51/1670 train_time:5196ms step_avg:101.88ms +step:52/1670 train_time:5291ms step_avg:101.75ms +step:53/1670 train_time:5387ms step_avg:101.65ms +step:54/1670 train_time:5483ms step_avg:101.55ms +step:55/1670 train_time:5579ms step_avg:101.44ms +step:56/1670 train_time:5675ms step_avg:101.34ms +step:57/1670 train_time:5771ms step_avg:101.24ms +step:58/1670 train_time:5866ms step_avg:101.14ms +step:59/1670 train_time:5961ms step_avg:101.04ms +step:60/1670 train_time:6057ms step_avg:100.96ms +step:61/1670 train_time:6153ms step_avg:100.87ms +step:62/1670 train_time:6249ms step_avg:100.78ms +step:63/1670 train_time:6344ms step_avg:100.70ms +step:64/1670 train_time:6439ms step_avg:100.61ms +step:65/1670 train_time:6535ms step_avg:100.54ms +step:66/1670 train_time:6631ms step_avg:100.47ms +step:67/1670 train_time:6727ms step_avg:100.40ms +step:68/1670 train_time:6823ms step_avg:100.34ms +step:69/1670 train_time:6918ms step_avg:100.26ms +step:70/1670 train_time:7015ms step_avg:100.21ms +step:71/1670 train_time:7111ms step_avg:100.15ms +step:72/1670 train_time:7206ms step_avg:100.08ms +step:73/1670 train_time:7301ms step_avg:100.02ms +step:74/1670 train_time:7398ms step_avg:99.97ms +step:75/1670 train_time:7494ms step_avg:99.92ms +step:76/1670 train_time:7590ms step_avg:99.87ms +step:77/1670 train_time:7685ms step_avg:99.81ms +step:78/1670 train_time:7781ms step_avg:99.76ms +step:79/1670 
train_time:7877ms step_avg:99.71ms +step:80/1670 train_time:7973ms step_avg:99.66ms +step:81/1670 train_time:8068ms step_avg:99.61ms +step:82/1670 train_time:8163ms step_avg:99.55ms +step:83/1670 train_time:8259ms step_avg:99.51ms +step:84/1670 train_time:8355ms step_avg:99.47ms +step:85/1670 train_time:8451ms step_avg:99.43ms +step:86/1670 train_time:8546ms step_avg:99.37ms +step:87/1670 train_time:8642ms step_avg:99.33ms +step:88/1670 train_time:8737ms step_avg:99.29ms +step:89/1670 train_time:8833ms step_avg:99.25ms +step:90/1670 train_time:8928ms step_avg:99.21ms +step:91/1670 train_time:9024ms step_avg:99.16ms +step:92/1670 train_time:9119ms step_avg:99.12ms +step:93/1670 train_time:9215ms step_avg:99.09ms +step:94/1670 train_time:9311ms step_avg:99.05ms +step:95/1670 train_time:9406ms step_avg:99.01ms +step:96/1670 train_time:9502ms step_avg:98.97ms +step:97/1670 train_time:9598ms step_avg:98.95ms +step:98/1670 train_time:9695ms step_avg:98.93ms +step:99/1670 train_time:9791ms step_avg:98.90ms +step:100/1670 train_time:9887ms step_avg:98.87ms +step:101/1670 train_time:9982ms step_avg:98.83ms +step:102/1670 train_time:10078ms step_avg:98.80ms +step:103/1670 train_time:10174ms step_avg:98.77ms +step:104/1670 train_time:10269ms step_avg:98.74ms +step:105/1670 train_time:10364ms step_avg:98.71ms +step:106/1670 train_time:10460ms step_avg:98.68ms +step:107/1670 train_time:10556ms step_avg:98.66ms +step:108/1670 train_time:10652ms step_avg:98.63ms +step:109/1670 train_time:10748ms step_avg:98.61ms +step:110/1670 train_time:10843ms step_avg:98.57ms +step:111/1670 train_time:10939ms step_avg:98.55ms +step:112/1670 train_time:11035ms step_avg:98.52ms +step:113/1670 train_time:11131ms step_avg:98.50ms +step:114/1670 train_time:11226ms step_avg:98.48ms +step:115/1670 train_time:11322ms step_avg:98.45ms +step:116/1670 train_time:11418ms step_avg:98.43ms +step:117/1670 train_time:11515ms step_avg:98.41ms +step:118/1670 train_time:11610ms step_avg:98.39ms +step:119/1670 train_time:11705ms step_avg:98.37ms +step:120/1670 train_time:11801ms step_avg:98.34ms +step:121/1670 train_time:11897ms step_avg:98.32ms +step:122/1670 train_time:11993ms step_avg:98.30ms +step:123/1670 train_time:12089ms step_avg:98.28ms +step:124/1670 train_time:12184ms step_avg:98.26ms +step:125/1670 train_time:12279ms step_avg:98.23ms +step:125/1670 val_loss:4.3009 train_time:12375ms step_avg:99.00ms +step:126/1670 train_time:12397ms step_avg:98.39ms +step:127/1670 train_time:12481ms step_avg:98.28ms +step:128/1670 train_time:12586ms step_avg:98.33ms +step:129/1670 train_time:12683ms step_avg:98.32ms +step:130/1670 train_time:12778ms step_avg:98.29ms +step:131/1670 train_time:12873ms step_avg:98.27ms +step:132/1670 train_time:12967ms step_avg:98.24ms +step:133/1670 train_time:13062ms step_avg:98.21ms +step:134/1670 train_time:13156ms step_avg:98.18ms +step:135/1670 train_time:13251ms step_avg:98.15ms +step:136/1670 train_time:13345ms step_avg:98.13ms +step:137/1670 train_time:13441ms step_avg:98.11ms +step:138/1670 train_time:13540ms step_avg:98.11ms +step:139/1670 train_time:13637ms step_avg:98.11ms +step:140/1670 train_time:13733ms step_avg:98.10ms +step:141/1670 train_time:13829ms step_avg:98.08ms +step:142/1670 train_time:13923ms step_avg:98.05ms +step:143/1670 train_time:14018ms step_avg:98.03ms +step:144/1670 train_time:14113ms step_avg:98.01ms +step:145/1670 train_time:14208ms step_avg:97.99ms +step:146/1670 train_time:14302ms step_avg:97.96ms +step:147/1670 train_time:14398ms step_avg:97.94ms +step:148/1670 
train_time:14495ms step_avg:97.94ms +step:149/1670 train_time:14592ms step_avg:97.94ms +step:150/1670 train_time:14689ms step_avg:97.93ms +step:151/1670 train_time:14784ms step_avg:97.91ms +step:152/1670 train_time:14879ms step_avg:97.89ms +step:153/1670 train_time:14976ms step_avg:97.88ms +step:154/1670 train_time:15070ms step_avg:97.86ms +step:155/1670 train_time:15165ms step_avg:97.84ms +step:156/1670 train_time:15808ms step_avg:101.33ms +step:157/1670 train_time:15881ms step_avg:101.15ms +step:158/1670 train_time:15974ms step_avg:101.10ms +step:159/1670 train_time:16069ms step_avg:101.06ms +step:160/1670 train_time:16164ms step_avg:101.02ms +step:161/1670 train_time:16258ms step_avg:100.98ms +step:162/1670 train_time:16353ms step_avg:100.94ms +step:163/1670 train_time:16447ms step_avg:100.90ms +step:164/1670 train_time:16542ms step_avg:100.86ms +step:165/1670 train_time:16637ms step_avg:100.83ms +step:166/1670 train_time:16733ms step_avg:100.80ms +step:167/1670 train_time:16833ms step_avg:100.80ms +step:168/1670 train_time:16929ms step_avg:100.77ms +step:169/1670 train_time:17025ms step_avg:100.74ms +step:170/1670 train_time:17120ms step_avg:100.71ms +step:171/1670 train_time:17215ms step_avg:100.67ms +step:172/1670 train_time:17311ms step_avg:100.64ms +step:173/1670 train_time:17405ms step_avg:100.61ms +step:174/1670 train_time:17500ms step_avg:100.57ms +step:175/1670 train_time:17595ms step_avg:100.54ms +step:176/1670 train_time:17690ms step_avg:100.51ms +step:177/1670 train_time:17785ms step_avg:100.48ms +step:178/1670 train_time:17883ms step_avg:100.47ms +step:179/1670 train_time:17979ms step_avg:100.44ms +step:180/1670 train_time:18075ms step_avg:100.42ms +step:181/1670 train_time:18170ms step_avg:100.39ms +step:182/1670 train_time:18266ms step_avg:100.36ms +step:183/1670 train_time:18361ms step_avg:100.33ms +step:184/1670 train_time:18456ms step_avg:100.31ms +step:185/1670 train_time:18551ms step_avg:100.27ms +step:186/1670 train_time:18646ms step_avg:100.25ms +step:187/1670 train_time:18742ms step_avg:100.22ms +step:188/1670 train_time:18838ms step_avg:100.20ms +step:189/1670 train_time:18934ms step_avg:100.18ms +step:190/1670 train_time:19029ms step_avg:100.15ms +step:191/1670 train_time:19124ms step_avg:100.12ms +step:192/1670 train_time:19219ms step_avg:100.10ms +step:193/1670 train_time:19315ms step_avg:100.08ms +step:194/1670 train_time:19410ms step_avg:100.05ms +step:195/1670 train_time:19505ms step_avg:100.03ms +step:196/1670 train_time:19601ms step_avg:100.01ms +step:197/1670 train_time:19697ms step_avg:99.98ms +step:198/1670 train_time:19792ms step_avg:99.96ms +step:199/1670 train_time:19889ms step_avg:99.94ms +step:200/1670 train_time:19984ms step_avg:99.92ms +step:201/1670 train_time:20079ms step_avg:99.90ms +step:202/1670 train_time:20176ms step_avg:99.88ms +step:203/1670 train_time:20271ms step_avg:99.86ms +step:204/1670 train_time:20366ms step_avg:99.83ms +step:205/1670 train_time:20462ms step_avg:99.81ms +step:206/1670 train_time:20557ms step_avg:99.79ms +step:207/1670 train_time:20652ms step_avg:99.77ms +step:208/1670 train_time:20748ms step_avg:99.75ms +step:209/1670 train_time:20843ms step_avg:99.73ms +step:210/1670 train_time:20938ms step_avg:99.71ms +step:211/1670 train_time:21034ms step_avg:99.69ms +step:212/1670 train_time:21129ms step_avg:99.67ms +step:213/1670 train_time:21417ms step_avg:100.55ms +step:214/1670 train_time:21512ms step_avg:100.52ms +step:215/1670 train_time:21606ms step_avg:100.49ms +step:216/1670 train_time:21700ms step_avg:100.46ms 
+step:217/1670 train_time:21795ms step_avg:100.44ms +step:218/1670 train_time:21890ms step_avg:100.41ms +step:219/1670 train_time:21984ms step_avg:100.38ms +step:220/1670 train_time:22079ms step_avg:100.36ms +step:221/1670 train_time:22174ms step_avg:100.33ms +step:222/1670 train_time:22269ms step_avg:100.31ms +step:223/1670 train_time:22366ms step_avg:100.30ms +step:224/1670 train_time:22464ms step_avg:100.29ms +step:225/1670 train_time:22562ms step_avg:100.28ms +step:226/1670 train_time:22658ms step_avg:100.26ms +step:227/1670 train_time:22753ms step_avg:100.23ms +step:228/1670 train_time:22848ms step_avg:100.21ms +step:229/1670 train_time:22943ms step_avg:100.19ms +step:230/1670 train_time:23038ms step_avg:100.16ms +step:231/1670 train_time:23132ms step_avg:100.14ms +step:232/1670 train_time:23227ms step_avg:100.12ms +step:233/1670 train_time:23322ms step_avg:100.10ms +step:234/1670 train_time:23419ms step_avg:100.08ms +step:235/1670 train_time:23517ms step_avg:100.07ms +step:236/1670 train_time:23613ms step_avg:100.06ms +step:237/1670 train_time:23709ms step_avg:100.04ms +step:238/1670 train_time:23803ms step_avg:100.01ms +step:239/1670 train_time:23898ms step_avg:99.99ms +step:240/1670 train_time:23993ms step_avg:99.97ms +step:241/1670 train_time:24088ms step_avg:99.95ms +step:242/1670 train_time:24182ms step_avg:99.93ms +step:243/1670 train_time:24278ms step_avg:99.91ms +step:244/1670 train_time:24375ms step_avg:99.90ms +step:245/1670 train_time:24472ms step_avg:99.88ms +step:246/1670 train_time:24567ms step_avg:99.86ms +step:247/1670 train_time:24662ms step_avg:99.85ms +step:248/1670 train_time:24758ms step_avg:99.83ms +step:249/1670 train_time:24854ms step_avg:99.81ms +step:250/1670 train_time:24948ms step_avg:99.79ms +step:250/1670 val_loss:3.9641 train_time:25042ms step_avg:100.17ms +step:251/1670 train_time:25065ms step_avg:99.86ms +step:252/1670 train_time:25145ms step_avg:99.78ms +step:253/1670 train_time:25244ms step_avg:99.78ms +step:254/1670 train_time:25340ms step_avg:99.76ms +step:255/1670 train_time:25435ms step_avg:99.75ms +step:256/1670 train_time:25530ms step_avg:99.73ms +step:257/1670 train_time:25624ms step_avg:99.71ms +step:258/1670 train_time:25719ms step_avg:99.69ms +step:259/1670 train_time:25813ms step_avg:99.67ms +step:260/1670 train_time:25909ms step_avg:99.65ms +step:261/1670 train_time:26004ms step_avg:99.63ms +step:262/1670 train_time:26101ms step_avg:99.62ms +step:263/1670 train_time:26197ms step_avg:99.61ms +step:264/1670 train_time:26294ms step_avg:99.60ms +step:265/1670 train_time:26391ms step_avg:99.59ms +step:266/1670 train_time:26486ms step_avg:99.57ms +step:267/1670 train_time:26581ms step_avg:99.55ms +step:268/1670 train_time:26675ms step_avg:99.54ms +step:269/1670 train_time:26770ms step_avg:99.52ms +step:270/1670 train_time:26866ms step_avg:99.50ms +step:271/1670 train_time:26960ms step_avg:99.48ms +step:272/1670 train_time:27056ms step_avg:99.47ms +step:273/1670 train_time:27152ms step_avg:99.46ms +step:274/1670 train_time:27250ms step_avg:99.45ms +step:275/1670 train_time:27347ms step_avg:99.44ms +step:276/1670 train_time:27443ms step_avg:99.43ms +step:277/1670 train_time:27538ms step_avg:99.41ms +step:278/1670 train_time:27632ms step_avg:99.40ms +step:279/1670 train_time:27727ms step_avg:99.38ms +step:280/1670 train_time:27822ms step_avg:99.37ms +step:281/1670 train_time:27917ms step_avg:99.35ms +step:282/1670 train_time:28011ms step_avg:99.33ms +step:283/1670 train_time:28108ms step_avg:99.32ms +step:284/1670 train_time:28205ms 
step_avg:99.31ms +step:285/1670 train_time:28301ms step_avg:99.30ms +step:286/1670 train_time:28396ms step_avg:99.29ms +step:287/1670 train_time:28492ms step_avg:99.28ms +step:288/1670 train_time:28588ms step_avg:99.26ms +step:289/1670 train_time:28683ms step_avg:99.25ms +step:290/1670 train_time:28778ms step_avg:99.23ms +step:291/1670 train_time:28873ms step_avg:99.22ms +step:292/1670 train_time:28967ms step_avg:99.20ms +step:293/1670 train_time:29062ms step_avg:99.19ms +step:294/1670 train_time:29157ms step_avg:99.17ms +step:295/1670 train_time:29253ms step_avg:99.16ms +step:296/1670 train_time:29350ms step_avg:99.15ms +step:297/1670 train_time:29446ms step_avg:99.15ms +step:298/1670 train_time:29542ms step_avg:99.13ms +step:299/1670 train_time:29637ms step_avg:99.12ms +step:300/1670 train_time:29732ms step_avg:99.11ms +step:301/1670 train_time:29828ms step_avg:99.10ms +step:302/1670 train_time:29924ms step_avg:99.08ms +step:303/1670 train_time:30019ms step_avg:99.07ms +step:304/1670 train_time:30114ms step_avg:99.06ms +step:305/1670 train_time:30210ms step_avg:99.05ms +step:306/1670 train_time:30305ms step_avg:99.04ms +step:307/1670 train_time:30401ms step_avg:99.03ms +step:308/1670 train_time:30496ms step_avg:99.01ms +step:309/1670 train_time:30592ms step_avg:99.00ms +step:310/1670 train_time:30687ms step_avg:98.99ms +step:311/1670 train_time:30782ms step_avg:98.98ms +step:312/1670 train_time:30877ms step_avg:98.96ms +step:313/1670 train_time:30972ms step_avg:98.95ms +step:314/1670 train_time:31068ms step_avg:98.94ms +step:315/1670 train_time:31163ms step_avg:98.93ms +step:316/1670 train_time:31259ms step_avg:98.92ms +step:317/1670 train_time:31354ms step_avg:98.91ms +step:318/1670 train_time:31450ms step_avg:98.90ms +step:319/1670 train_time:31547ms step_avg:98.89ms +step:320/1670 train_time:31642ms step_avg:98.88ms +step:321/1670 train_time:31738ms step_avg:98.87ms +step:322/1670 train_time:31834ms step_avg:98.86ms +step:323/1670 train_time:31929ms step_avg:98.85ms +step:324/1670 train_time:32024ms step_avg:98.84ms +step:325/1670 train_time:32120ms step_avg:98.83ms +step:326/1670 train_time:32215ms step_avg:98.82ms +step:327/1670 train_time:32310ms step_avg:98.81ms +step:328/1670 train_time:32406ms step_avg:98.80ms +step:329/1670 train_time:32500ms step_avg:98.79ms +step:330/1670 train_time:32595ms step_avg:98.77ms +step:331/1670 train_time:32691ms step_avg:98.76ms +step:332/1670 train_time:32788ms step_avg:98.76ms +step:333/1670 train_time:32883ms step_avg:98.75ms +step:334/1670 train_time:32979ms step_avg:98.74ms +step:335/1670 train_time:33074ms step_avg:98.73ms +step:336/1670 train_time:33169ms step_avg:98.72ms +step:337/1670 train_time:33265ms step_avg:98.71ms +step:338/1670 train_time:33360ms step_avg:98.70ms +step:339/1670 train_time:33455ms step_avg:98.69ms +step:340/1670 train_time:33550ms step_avg:98.68ms +step:341/1670 train_time:33646ms step_avg:98.67ms +step:342/1670 train_time:33742ms step_avg:98.66ms +step:343/1670 train_time:33838ms step_avg:98.65ms +step:344/1670 train_time:33934ms step_avg:98.64ms +step:345/1670 train_time:34029ms step_avg:98.64ms +step:346/1670 train_time:34125ms step_avg:98.63ms +step:347/1670 train_time:34220ms step_avg:98.62ms +step:348/1670 train_time:34315ms step_avg:98.61ms +step:349/1670 train_time:34411ms step_avg:98.60ms +step:350/1670 train_time:34507ms step_avg:98.59ms +step:351/1670 train_time:34602ms step_avg:98.58ms +step:352/1670 train_time:34697ms step_avg:98.57ms +step:353/1670 train_time:34793ms step_avg:98.56ms +step:354/1670 
train_time:34890ms step_avg:98.56ms +step:355/1670 train_time:34986ms step_avg:98.55ms +step:356/1670 train_time:35082ms step_avg:98.54ms +step:357/1670 train_time:35177ms step_avg:98.54ms +step:358/1670 train_time:35273ms step_avg:98.53ms +step:359/1670 train_time:35368ms step_avg:98.52ms +step:360/1670 train_time:35463ms step_avg:98.51ms +step:361/1670 train_time:35558ms step_avg:98.50ms +step:362/1670 train_time:35654ms step_avg:98.49ms +step:363/1670 train_time:35749ms step_avg:98.48ms +step:364/1670 train_time:35845ms step_avg:98.48ms +step:365/1670 train_time:35941ms step_avg:98.47ms +step:366/1670 train_time:36037ms step_avg:98.46ms +step:367/1670 train_time:36133ms step_avg:98.45ms +step:368/1670 train_time:36229ms step_avg:98.45ms +step:369/1670 train_time:36325ms step_avg:98.44ms +step:370/1670 train_time:36420ms step_avg:98.43ms +step:371/1670 train_time:36516ms step_avg:98.42ms +step:372/1670 train_time:36611ms step_avg:98.42ms +step:373/1670 train_time:36707ms step_avg:98.41ms +step:374/1670 train_time:36803ms step_avg:98.40ms +step:375/1670 train_time:36898ms step_avg:98.39ms +step:375/1670 val_loss:3.8163 train_time:36993ms step_avg:98.65ms +step:376/1670 train_time:37015ms step_avg:98.44ms +step:377/1670 train_time:37090ms step_avg:98.38ms +step:378/1670 train_time:37186ms step_avg:98.38ms +step:379/1670 train_time:37288ms step_avg:98.38ms +step:380/1670 train_time:37382ms step_avg:98.37ms +step:381/1670 train_time:37477ms step_avg:98.36ms +step:382/1670 train_time:37571ms step_avg:98.35ms +step:383/1670 train_time:37666ms step_avg:98.34ms +step:384/1670 train_time:37761ms step_avg:98.33ms +step:385/1670 train_time:37854ms step_avg:98.32ms +step:386/1670 train_time:37953ms step_avg:98.32ms +step:387/1670 train_time:38051ms step_avg:98.32ms +step:388/1670 train_time:38146ms step_avg:98.31ms +step:389/1670 train_time:38241ms step_avg:98.31ms +step:390/1670 train_time:38337ms step_avg:98.30ms +step:391/1670 train_time:38432ms step_avg:98.29ms +step:392/1670 train_time:38527ms step_avg:98.28ms +step:393/1670 train_time:38622ms step_avg:98.27ms +step:394/1670 train_time:38717ms step_avg:98.27ms +step:395/1670 train_time:38812ms step_avg:98.26ms +step:396/1670 train_time:38908ms step_avg:98.25ms +step:397/1670 train_time:39004ms step_avg:98.25ms +step:398/1670 train_time:39100ms step_avg:98.24ms +step:399/1670 train_time:39196ms step_avg:98.24ms +step:400/1670 train_time:39293ms step_avg:98.23ms +step:401/1670 train_time:39388ms step_avg:98.23ms +step:402/1670 train_time:39484ms step_avg:98.22ms +step:403/1670 train_time:39578ms step_avg:98.21ms +step:404/1670 train_time:39673ms step_avg:98.20ms +step:405/1670 train_time:39768ms step_avg:98.19ms +step:406/1670 train_time:39863ms step_avg:98.18ms +step:407/1670 train_time:39958ms step_avg:98.18ms +step:408/1670 train_time:40054ms step_avg:98.17ms +step:409/1670 train_time:40150ms step_avg:98.17ms +step:410/1670 train_time:40246ms step_avg:98.16ms +step:411/1670 train_time:40341ms step_avg:98.15ms +step:412/1670 train_time:40437ms step_avg:98.15ms +step:413/1670 train_time:40533ms step_avg:98.14ms +step:414/1670 train_time:40629ms step_avg:98.14ms +step:415/1670 train_time:40724ms step_avg:98.13ms +step:416/1670 train_time:40819ms step_avg:98.12ms +step:417/1670 train_time:40915ms step_avg:98.12ms +step:418/1670 train_time:41011ms step_avg:98.11ms +step:419/1670 train_time:41107ms step_avg:98.11ms +step:420/1670 train_time:41203ms step_avg:98.10ms +step:421/1670 train_time:41298ms step_avg:98.09ms +step:422/1670 train_time:41394ms 
step_avg:98.09ms +step:423/1670 train_time:41489ms step_avg:98.08ms +step:424/1670 train_time:41584ms step_avg:98.08ms +step:425/1670 train_time:41875ms step_avg:98.53ms +step:426/1670 train_time:41994ms step_avg:98.58ms +step:427/1670 train_time:42088ms step_avg:98.57ms +step:428/1670 train_time:42182ms step_avg:98.56ms +step:429/1670 train_time:42277ms step_avg:98.55ms +step:430/1670 train_time:42372ms step_avg:98.54ms +step:431/1670 train_time:42467ms step_avg:98.53ms +step:432/1670 train_time:42561ms step_avg:98.52ms +step:433/1670 train_time:42655ms step_avg:98.51ms +step:434/1670 train_time:42750ms step_avg:98.50ms +step:435/1670 train_time:42849ms step_avg:98.50ms +step:436/1670 train_time:42948ms step_avg:98.51ms +step:437/1670 train_time:43045ms step_avg:98.50ms +step:438/1670 train_time:43140ms step_avg:98.49ms +step:439/1670 train_time:43235ms step_avg:98.49ms +step:440/1670 train_time:43330ms step_avg:98.48ms +step:441/1670 train_time:43425ms step_avg:98.47ms +step:442/1670 train_time:43519ms step_avg:98.46ms +step:443/1670 train_time:43614ms step_avg:98.45ms +step:444/1670 train_time:43709ms step_avg:98.44ms +step:445/1670 train_time:43806ms step_avg:98.44ms +step:446/1670 train_time:43901ms step_avg:98.43ms +step:447/1670 train_time:43999ms step_avg:98.43ms +step:448/1670 train_time:44095ms step_avg:98.43ms +step:449/1670 train_time:44191ms step_avg:98.42ms +step:450/1670 train_time:44287ms step_avg:98.42ms +step:451/1670 train_time:44383ms step_avg:98.41ms +step:452/1670 train_time:44478ms step_avg:98.40ms +step:453/1670 train_time:44573ms step_avg:98.39ms +step:454/1670 train_time:44668ms step_avg:98.39ms +step:455/1670 train_time:44762ms step_avg:98.38ms +step:456/1670 train_time:44858ms step_avg:98.37ms +step:457/1670 train_time:44955ms step_avg:98.37ms +step:458/1670 train_time:45052ms step_avg:98.37ms +step:459/1670 train_time:45147ms step_avg:98.36ms +step:460/1670 train_time:45242ms step_avg:98.35ms +step:461/1670 train_time:45339ms step_avg:98.35ms +step:462/1670 train_time:45434ms step_avg:98.34ms +step:463/1670 train_time:45530ms step_avg:98.34ms +step:464/1670 train_time:45625ms step_avg:98.33ms +step:465/1670 train_time:45719ms step_avg:98.32ms +step:466/1670 train_time:45815ms step_avg:98.31ms +step:467/1670 train_time:45911ms step_avg:98.31ms +step:468/1670 train_time:46007ms step_avg:98.31ms +step:469/1670 train_time:46102ms step_avg:98.30ms +step:470/1670 train_time:46198ms step_avg:98.29ms +step:471/1670 train_time:46294ms step_avg:98.29ms +step:472/1670 train_time:46389ms step_avg:98.28ms +step:473/1670 train_time:46484ms step_avg:98.28ms +step:474/1670 train_time:46579ms step_avg:98.27ms +step:475/1670 train_time:46674ms step_avg:98.26ms +step:476/1670 train_time:46769ms step_avg:98.25ms +step:477/1670 train_time:46864ms step_avg:98.25ms +step:478/1670 train_time:46961ms step_avg:98.24ms +step:479/1670 train_time:47057ms step_avg:98.24ms +step:480/1670 train_time:47153ms step_avg:98.24ms +step:481/1670 train_time:47249ms step_avg:98.23ms +step:482/1670 train_time:47345ms step_avg:98.23ms +step:483/1670 train_time:47440ms step_avg:98.22ms +step:484/1670 train_time:47536ms step_avg:98.21ms +step:485/1670 train_time:47631ms step_avg:98.21ms +step:486/1670 train_time:47727ms step_avg:98.20ms +step:487/1670 train_time:47822ms step_avg:98.20ms +step:488/1670 train_time:47918ms step_avg:98.19ms +step:489/1670 train_time:48013ms step_avg:98.19ms +step:490/1670 train_time:48109ms step_avg:98.18ms +step:491/1670 train_time:48204ms step_avg:98.18ms +step:492/1670 
train_time:48300ms step_avg:98.17ms +step:493/1670 train_time:48395ms step_avg:98.16ms +step:494/1670 train_time:48491ms step_avg:98.16ms +step:495/1670 train_time:48587ms step_avg:98.16ms +step:496/1670 train_time:48683ms step_avg:98.15ms +step:497/1670 train_time:48779ms step_avg:98.15ms +step:498/1670 train_time:48874ms step_avg:98.14ms +step:499/1670 train_time:48970ms step_avg:98.14ms +step:500/1670 train_time:49065ms step_avg:98.13ms +step:500/1670 val_loss:3.7142 train_time:49160ms step_avg:98.32ms +step:501/1670 train_time:49181ms step_avg:98.17ms +step:502/1670 train_time:49263ms step_avg:98.13ms +step:503/1670 train_time:49366ms step_avg:98.14ms +step:504/1670 train_time:49461ms step_avg:98.14ms +step:505/1670 train_time:49557ms step_avg:98.13ms +step:506/1670 train_time:49652ms step_avg:98.13ms +step:507/1670 train_time:49746ms step_avg:98.12ms +step:508/1670 train_time:49841ms step_avg:98.11ms +step:509/1670 train_time:49936ms step_avg:98.11ms +step:510/1670 train_time:50030ms step_avg:98.10ms +step:511/1670 train_time:50125ms step_avg:98.09ms +step:512/1670 train_time:50223ms step_avg:98.09ms +step:513/1670 train_time:50321ms step_avg:98.09ms +step:514/1670 train_time:50417ms step_avg:98.09ms +step:515/1670 train_time:50512ms step_avg:98.08ms +step:516/1670 train_time:50608ms step_avg:98.08ms +step:517/1670 train_time:50704ms step_avg:98.07ms +step:518/1670 train_time:50798ms step_avg:98.07ms +step:519/1670 train_time:50893ms step_avg:98.06ms +step:520/1670 train_time:50988ms step_avg:98.05ms +step:521/1670 train_time:51083ms step_avg:98.05ms +step:522/1670 train_time:51179ms step_avg:98.04ms +step:523/1670 train_time:51276ms step_avg:98.04ms +step:524/1670 train_time:51372ms step_avg:98.04ms +step:525/1670 train_time:51469ms step_avg:98.04ms +step:526/1670 train_time:51565ms step_avg:98.03ms +step:527/1670 train_time:51660ms step_avg:98.03ms +step:528/1670 train_time:51755ms step_avg:98.02ms +step:529/1670 train_time:51850ms step_avg:98.02ms +step:530/1670 train_time:51945ms step_avg:98.01ms +step:531/1670 train_time:52041ms step_avg:98.00ms +step:532/1670 train_time:52135ms step_avg:98.00ms +step:533/1670 train_time:52231ms step_avg:97.99ms +step:534/1670 train_time:52327ms step_avg:97.99ms +step:535/1670 train_time:52424ms step_avg:97.99ms +step:536/1670 train_time:52520ms step_avg:97.99ms +step:537/1670 train_time:52617ms step_avg:97.98ms +step:538/1670 train_time:52712ms step_avg:97.98ms +step:539/1670 train_time:52807ms step_avg:97.97ms +step:540/1670 train_time:52902ms step_avg:97.97ms +step:541/1670 train_time:52997ms step_avg:97.96ms +step:542/1670 train_time:53092ms step_avg:97.96ms +step:543/1670 train_time:53187ms step_avg:97.95ms +step:544/1670 train_time:53283ms step_avg:97.95ms +step:545/1670 train_time:53380ms step_avg:97.94ms +step:546/1670 train_time:53476ms step_avg:97.94ms +step:547/1670 train_time:53572ms step_avg:97.94ms +step:548/1670 train_time:53668ms step_avg:97.93ms +step:549/1670 train_time:53765ms step_avg:97.93ms +step:550/1670 train_time:53859ms step_avg:97.93ms +step:551/1670 train_time:53955ms step_avg:97.92ms +step:552/1670 train_time:54050ms step_avg:97.92ms +step:553/1670 train_time:54145ms step_avg:97.91ms +step:554/1670 train_time:54241ms step_avg:97.91ms +step:555/1670 train_time:54336ms step_avg:97.90ms +step:556/1670 train_time:54432ms step_avg:97.90ms +step:557/1670 train_time:54529ms step_avg:97.90ms +step:558/1670 train_time:54625ms step_avg:97.89ms +step:559/1670 train_time:54721ms step_avg:97.89ms +step:560/1670 train_time:54818ms 
step_avg:97.89ms +step:561/1670 train_time:54915ms step_avg:97.89ms +step:562/1670 train_time:55011ms step_avg:97.88ms +step:563/1670 train_time:55107ms step_avg:97.88ms +step:564/1670 train_time:55204ms step_avg:97.88ms +step:565/1670 train_time:55301ms step_avg:97.88ms +step:566/1670 train_time:55399ms step_avg:97.88ms +step:567/1670 train_time:55496ms step_avg:97.88ms +step:568/1670 train_time:55593ms step_avg:97.87ms +step:569/1670 train_time:55690ms step_avg:97.87ms +step:570/1670 train_time:55787ms step_avg:97.87ms +step:571/1670 train_time:55886ms step_avg:97.87ms +step:572/1670 train_time:55983ms step_avg:97.87ms +step:573/1670 train_time:56079ms step_avg:97.87ms +step:574/1670 train_time:56176ms step_avg:97.87ms +step:575/1670 train_time:56273ms step_avg:97.87ms +step:576/1670 train_time:56370ms step_avg:97.87ms +step:577/1670 train_time:56468ms step_avg:97.87ms +step:578/1670 train_time:56566ms step_avg:97.87ms +step:579/1670 train_time:56664ms step_avg:97.86ms +step:580/1670 train_time:56761ms step_avg:97.86ms +step:581/1670 train_time:56859ms step_avg:97.86ms +step:582/1670 train_time:56957ms step_avg:97.86ms +step:583/1670 train_time:57053ms step_avg:97.86ms +step:584/1670 train_time:57150ms step_avg:97.86ms +step:585/1670 train_time:57247ms step_avg:97.86ms +step:586/1670 train_time:57344ms step_avg:97.86ms +step:587/1670 train_time:57441ms step_avg:97.85ms +step:588/1670 train_time:57538ms step_avg:97.85ms +step:589/1670 train_time:57634ms step_avg:97.85ms +step:590/1670 train_time:57731ms step_avg:97.85ms +step:591/1670 train_time:57829ms step_avg:97.85ms +step:592/1670 train_time:57927ms step_avg:97.85ms +step:593/1670 train_time:58025ms step_avg:97.85ms +step:594/1670 train_time:58123ms step_avg:97.85ms +step:595/1670 train_time:58221ms step_avg:97.85ms +step:596/1670 train_time:58318ms step_avg:97.85ms +step:597/1670 train_time:58415ms step_avg:97.85ms +step:598/1670 train_time:58512ms step_avg:97.85ms +step:599/1670 train_time:58609ms step_avg:97.84ms +step:600/1670 train_time:58706ms step_avg:97.84ms +step:601/1670 train_time:58804ms step_avg:97.84ms +step:602/1670 train_time:58903ms step_avg:97.84ms +step:603/1670 train_time:58999ms step_avg:97.84ms +step:604/1670 train_time:59095ms step_avg:97.84ms +step:605/1670 train_time:59192ms step_avg:97.84ms +step:606/1670 train_time:59289ms step_avg:97.84ms +step:607/1670 train_time:59387ms step_avg:97.84ms +step:608/1670 train_time:59485ms step_avg:97.84ms +step:609/1670 train_time:59583ms step_avg:97.84ms +step:610/1670 train_time:59680ms step_avg:97.84ms +step:611/1670 train_time:59776ms step_avg:97.83ms +step:612/1670 train_time:59872ms step_avg:97.83ms +step:613/1670 train_time:59970ms step_avg:97.83ms +step:614/1670 train_time:60067ms step_avg:97.83ms +step:615/1670 train_time:60165ms step_avg:97.83ms +step:616/1670 train_time:60263ms step_avg:97.83ms +step:617/1670 train_time:60361ms step_avg:97.83ms +step:618/1670 train_time:60457ms step_avg:97.83ms +step:619/1670 train_time:60554ms step_avg:97.83ms +step:620/1670 train_time:60651ms step_avg:97.82ms +step:621/1670 train_time:60748ms step_avg:97.82ms +step:622/1670 train_time:60845ms step_avg:97.82ms +step:623/1670 train_time:60942ms step_avg:97.82ms +step:624/1670 train_time:61040ms step_avg:97.82ms +step:625/1670 train_time:61136ms step_avg:97.82ms +step:625/1670 val_loss:3.6151 train_time:61232ms step_avg:97.97ms +step:626/1670 train_time:61256ms step_avg:97.85ms +step:627/1670 train_time:61334ms step_avg:97.82ms +step:628/1670 train_time:61429ms step_avg:97.82ms 
+step:629/1670 train_time:61528ms step_avg:97.82ms +step:630/1670 train_time:61624ms step_avg:97.82ms +step:631/1670 train_time:61720ms step_avg:97.81ms +step:632/1670 train_time:61815ms step_avg:97.81ms +step:633/1670 train_time:61910ms step_avg:97.80ms +step:634/1670 train_time:62007ms step_avg:97.80ms +step:635/1670 train_time:62103ms step_avg:97.80ms +step:636/1670 train_time:62208ms step_avg:97.81ms +step:637/1670 train_time:62306ms step_avg:97.81ms +step:638/1670 train_time:62405ms step_avg:97.81ms +step:639/1670 train_time:62688ms step_avg:98.10ms +step:640/1670 train_time:62833ms step_avg:98.18ms +step:641/1670 train_time:62927ms step_avg:98.17ms +step:642/1670 train_time:63023ms step_avg:98.17ms +step:643/1670 train_time:63120ms step_avg:98.16ms +step:644/1670 train_time:63216ms step_avg:98.16ms +step:645/1670 train_time:63311ms step_avg:98.16ms +step:646/1670 train_time:63406ms step_avg:98.15ms +step:647/1670 train_time:63502ms step_avg:98.15ms +step:648/1670 train_time:63597ms step_avg:98.14ms +step:649/1670 train_time:63696ms step_avg:98.14ms +step:650/1670 train_time:63796ms step_avg:98.15ms +step:651/1670 train_time:63894ms step_avg:98.15ms +step:652/1670 train_time:63990ms step_avg:98.14ms +step:653/1670 train_time:64086ms step_avg:98.14ms +step:654/1670 train_time:64182ms step_avg:98.14ms +step:655/1670 train_time:64279ms step_avg:98.14ms +step:656/1670 train_time:64375ms step_avg:98.13ms +step:657/1670 train_time:64470ms step_avg:98.13ms +step:658/1670 train_time:64567ms step_avg:98.13ms +step:659/1670 train_time:64665ms step_avg:98.13ms +step:660/1670 train_time:64764ms step_avg:98.13ms +step:661/1670 train_time:64863ms step_avg:98.13ms +step:662/1670 train_time:64962ms step_avg:98.13ms +step:663/1670 train_time:65059ms step_avg:98.13ms +step:664/1670 train_time:65156ms step_avg:98.13ms +step:665/1670 train_time:65253ms step_avg:98.12ms +step:666/1670 train_time:65349ms step_avg:98.12ms +step:667/1670 train_time:65444ms step_avg:98.12ms +step:668/1670 train_time:65541ms step_avg:98.11ms +step:669/1670 train_time:65637ms step_avg:98.11ms +step:670/1670 train_time:65734ms step_avg:98.11ms +step:671/1670 train_time:65830ms step_avg:98.11ms +step:672/1670 train_time:65928ms step_avg:98.11ms +step:673/1670 train_time:66026ms step_avg:98.11ms +step:674/1670 train_time:66124ms step_avg:98.11ms +step:675/1670 train_time:66221ms step_avg:98.11ms +step:676/1670 train_time:66319ms step_avg:98.10ms +step:677/1670 train_time:66414ms step_avg:98.10ms +step:678/1670 train_time:66510ms step_avg:98.10ms +step:679/1670 train_time:66606ms step_avg:98.09ms +step:680/1670 train_time:66704ms step_avg:98.09ms +step:681/1670 train_time:66801ms step_avg:98.09ms +step:682/1670 train_time:66900ms step_avg:98.09ms +step:683/1670 train_time:66998ms step_avg:98.09ms +step:684/1670 train_time:67094ms step_avg:98.09ms +step:685/1670 train_time:67190ms step_avg:98.09ms +step:686/1670 train_time:67287ms step_avg:98.09ms +step:687/1670 train_time:67384ms step_avg:98.08ms +step:688/1670 train_time:67481ms step_avg:98.08ms +step:689/1670 train_time:67578ms step_avg:98.08ms +step:690/1670 train_time:67674ms step_avg:98.08ms +step:691/1670 train_time:67771ms step_avg:98.08ms +step:692/1670 train_time:67868ms step_avg:98.07ms +step:693/1670 train_time:67965ms step_avg:98.07ms +step:694/1670 train_time:68063ms step_avg:98.07ms +step:695/1670 train_time:68161ms step_avg:98.07ms +step:696/1670 train_time:68259ms step_avg:98.07ms +step:697/1670 train_time:68356ms step_avg:98.07ms +step:698/1670 train_time:68453ms 
step_avg:98.07ms +step:699/1670 train_time:68549ms step_avg:98.07ms +step:700/1670 train_time:68646ms step_avg:98.07ms +step:701/1670 train_time:68743ms step_avg:98.06ms +step:702/1670 train_time:68840ms step_avg:98.06ms +step:703/1670 train_time:68937ms step_avg:98.06ms +step:704/1670 train_time:69033ms step_avg:98.06ms +step:705/1670 train_time:69129ms step_avg:98.05ms +step:706/1670 train_time:69226ms step_avg:98.05ms +step:707/1670 train_time:69324ms step_avg:98.05ms +step:708/1670 train_time:69422ms step_avg:98.05ms +step:709/1670 train_time:69519ms step_avg:98.05ms +step:710/1670 train_time:69616ms step_avg:98.05ms +step:711/1670 train_time:69712ms step_avg:98.05ms +step:712/1670 train_time:69809ms step_avg:98.05ms +step:713/1670 train_time:69906ms step_avg:98.04ms +step:714/1670 train_time:70003ms step_avg:98.04ms +step:715/1670 train_time:70100ms step_avg:98.04ms +step:716/1670 train_time:70197ms step_avg:98.04ms +step:717/1670 train_time:70293ms step_avg:98.04ms +step:718/1670 train_time:70389ms step_avg:98.03ms +step:719/1670 train_time:70485ms step_avg:98.03ms +step:720/1670 train_time:70582ms step_avg:98.03ms +step:721/1670 train_time:70680ms step_avg:98.03ms +step:722/1670 train_time:70776ms step_avg:98.03ms +step:723/1670 train_time:70873ms step_avg:98.03ms +step:724/1670 train_time:70969ms step_avg:98.02ms +step:725/1670 train_time:71067ms step_avg:98.02ms +step:726/1670 train_time:71164ms step_avg:98.02ms +step:727/1670 train_time:71261ms step_avg:98.02ms +step:728/1670 train_time:71359ms step_avg:98.02ms +step:729/1670 train_time:71457ms step_avg:98.02ms +step:730/1670 train_time:71554ms step_avg:98.02ms +step:731/1670 train_time:71650ms step_avg:98.02ms +step:732/1670 train_time:71748ms step_avg:98.02ms +step:733/1670 train_time:71845ms step_avg:98.01ms +step:734/1670 train_time:71941ms step_avg:98.01ms +step:735/1670 train_time:72039ms step_avg:98.01ms +step:736/1670 train_time:72135ms step_avg:98.01ms +step:737/1670 train_time:72232ms step_avg:98.01ms +step:738/1670 train_time:72329ms step_avg:98.01ms +step:739/1670 train_time:72426ms step_avg:98.01ms +step:740/1670 train_time:72524ms step_avg:98.01ms +step:741/1670 train_time:72621ms step_avg:98.00ms +step:742/1670 train_time:72718ms step_avg:98.00ms +step:743/1670 train_time:72814ms step_avg:98.00ms +step:744/1670 train_time:72910ms step_avg:98.00ms +step:745/1670 train_time:73008ms step_avg:98.00ms +step:746/1670 train_time:73106ms step_avg:98.00ms +step:747/1670 train_time:73203ms step_avg:98.00ms +step:748/1670 train_time:73301ms step_avg:98.00ms +step:749/1670 train_time:73398ms step_avg:98.00ms +step:750/1670 train_time:73496ms step_avg:97.99ms +step:750/1670 val_loss:3.5640 train_time:73591ms step_avg:98.12ms +step:751/1670 train_time:73613ms step_avg:98.02ms +step:752/1670 train_time:73695ms step_avg:98.00ms +step:753/1670 train_time:73795ms step_avg:98.00ms +step:754/1670 train_time:73892ms step_avg:98.00ms +step:755/1670 train_time:73989ms step_avg:98.00ms +step:756/1670 train_time:74085ms step_avg:98.00ms +step:757/1670 train_time:74181ms step_avg:97.99ms +step:758/1670 train_time:74277ms step_avg:97.99ms +step:759/1670 train_time:74373ms step_avg:97.99ms +step:760/1670 train_time:74469ms step_avg:97.98ms +step:761/1670 train_time:74567ms step_avg:97.99ms +step:762/1670 train_time:74665ms step_avg:97.99ms +step:763/1670 train_time:74764ms step_avg:97.99ms +step:764/1670 train_time:74863ms step_avg:97.99ms +step:765/1670 train_time:74961ms step_avg:97.99ms +step:766/1670 train_time:75059ms step_avg:97.99ms 
+step:767/1670 train_time:75156ms step_avg:97.99ms +step:768/1670 train_time:75252ms step_avg:97.98ms +step:769/1670 train_time:75348ms step_avg:97.98ms +step:770/1670 train_time:75444ms step_avg:97.98ms +step:771/1670 train_time:75541ms step_avg:97.98ms +step:772/1670 train_time:75640ms step_avg:97.98ms +step:773/1670 train_time:75739ms step_avg:97.98ms +step:774/1670 train_time:75837ms step_avg:97.98ms +step:775/1670 train_time:75933ms step_avg:97.98ms +step:776/1670 train_time:76030ms step_avg:97.98ms +step:777/1670 train_time:76126ms step_avg:97.97ms +step:778/1670 train_time:76223ms step_avg:97.97ms +step:779/1670 train_time:76319ms step_avg:97.97ms +step:780/1670 train_time:76416ms step_avg:97.97ms +step:781/1670 train_time:76513ms step_avg:97.97ms +step:782/1670 train_time:76610ms step_avg:97.97ms +step:783/1670 train_time:76706ms step_avg:97.96ms +step:784/1670 train_time:76803ms step_avg:97.96ms +step:785/1670 train_time:76901ms step_avg:97.96ms +step:786/1670 train_time:76999ms step_avg:97.96ms +step:787/1670 train_time:77097ms step_avg:97.96ms +step:788/1670 train_time:77194ms step_avg:97.96ms +step:789/1670 train_time:77290ms step_avg:97.96ms +step:790/1670 train_time:77386ms step_avg:97.96ms +step:791/1670 train_time:77482ms step_avg:97.95ms +step:792/1670 train_time:77580ms step_avg:97.95ms +step:793/1670 train_time:77678ms step_avg:97.95ms +step:794/1670 train_time:77776ms step_avg:97.95ms +step:795/1670 train_time:77873ms step_avg:97.95ms +step:796/1670 train_time:77970ms step_avg:97.95ms +step:797/1670 train_time:78067ms step_avg:97.95ms +step:798/1670 train_time:78163ms step_avg:97.95ms +step:799/1670 train_time:78260ms step_avg:97.95ms +step:800/1670 train_time:78357ms step_avg:97.95ms +step:801/1670 train_time:78453ms step_avg:97.94ms +step:802/1670 train_time:78550ms step_avg:97.94ms +step:803/1670 train_time:78646ms step_avg:97.94ms +step:804/1670 train_time:78743ms step_avg:97.94ms +step:805/1670 train_time:78842ms step_avg:97.94ms +step:806/1670 train_time:78939ms step_avg:97.94ms +step:807/1670 train_time:79037ms step_avg:97.94ms +step:808/1670 train_time:79134ms step_avg:97.94ms +step:809/1670 train_time:79231ms step_avg:97.94ms +step:810/1670 train_time:79327ms step_avg:97.93ms +step:811/1670 train_time:79425ms step_avg:97.93ms +step:812/1670 train_time:79521ms step_avg:97.93ms +step:813/1670 train_time:79618ms step_avg:97.93ms +step:814/1670 train_time:79716ms step_avg:97.93ms +step:815/1670 train_time:79814ms step_avg:97.93ms +step:816/1670 train_time:79911ms step_avg:97.93ms +step:817/1670 train_time:80008ms step_avg:97.93ms +step:818/1670 train_time:80105ms step_avg:97.93ms +step:819/1670 train_time:80203ms step_avg:97.93ms +step:820/1670 train_time:80301ms step_avg:97.93ms +step:821/1670 train_time:80398ms step_avg:97.93ms +step:822/1670 train_time:80494ms step_avg:97.93ms +step:823/1670 train_time:80591ms step_avg:97.92ms +step:824/1670 train_time:80688ms step_avg:97.92ms +step:825/1670 train_time:80785ms step_avg:97.92ms +step:826/1670 train_time:80882ms step_avg:97.92ms +step:827/1670 train_time:80981ms step_avg:97.92ms +step:828/1670 train_time:81079ms step_avg:97.92ms +step:829/1670 train_time:81176ms step_avg:97.92ms +step:830/1670 train_time:81272ms step_avg:97.92ms +step:831/1670 train_time:81369ms step_avg:97.92ms +step:832/1670 train_time:81465ms step_avg:97.91ms +step:833/1670 train_time:81562ms step_avg:97.91ms +step:834/1670 train_time:81660ms step_avg:97.91ms +step:835/1670 train_time:81757ms step_avg:97.91ms +step:836/1670 train_time:81854ms 
step_avg:97.91ms +step:837/1670 train_time:81951ms step_avg:97.91ms +step:838/1670 train_time:82047ms step_avg:97.91ms +step:839/1670 train_time:82144ms step_avg:97.91ms +step:840/1670 train_time:82242ms step_avg:97.91ms +step:841/1670 train_time:82340ms step_avg:97.91ms +step:842/1670 train_time:82437ms step_avg:97.91ms +step:843/1670 train_time:82534ms step_avg:97.90ms +step:844/1670 train_time:82630ms step_avg:97.90ms +step:845/1670 train_time:82726ms step_avg:97.90ms +step:846/1670 train_time:82823ms step_avg:97.90ms +step:847/1670 train_time:82921ms step_avg:97.90ms +step:848/1670 train_time:83018ms step_avg:97.90ms +step:849/1670 train_time:83117ms step_avg:97.90ms +step:850/1670 train_time:83214ms step_avg:97.90ms +step:851/1670 train_time:83564ms step_avg:98.19ms +step:852/1670 train_time:83637ms step_avg:98.17ms +step:853/1670 train_time:83732ms step_avg:98.16ms +step:854/1670 train_time:83827ms step_avg:98.16ms +step:855/1670 train_time:83923ms step_avg:98.16ms +step:856/1670 train_time:84020ms step_avg:98.15ms +step:857/1670 train_time:84116ms step_avg:98.15ms +step:858/1670 train_time:84212ms step_avg:98.15ms +step:859/1670 train_time:84307ms step_avg:98.15ms +step:860/1670 train_time:84404ms step_avg:98.14ms +step:861/1670 train_time:84503ms step_avg:98.15ms +step:862/1670 train_time:84605ms step_avg:98.15ms +step:863/1670 train_time:84704ms step_avg:98.15ms +step:864/1670 train_time:84802ms step_avg:98.15ms +step:865/1670 train_time:84898ms step_avg:98.15ms +step:866/1670 train_time:84996ms step_avg:98.15ms +step:867/1670 train_time:85092ms step_avg:98.15ms +step:868/1670 train_time:85187ms step_avg:98.14ms +step:869/1670 train_time:85284ms step_avg:98.14ms +step:870/1670 train_time:85381ms step_avg:98.14ms +step:871/1670 train_time:85478ms step_avg:98.14ms +step:872/1670 train_time:85578ms step_avg:98.14ms +step:873/1670 train_time:85678ms step_avg:98.14ms +step:874/1670 train_time:85777ms step_avg:98.14ms +step:875/1670 train_time:85873ms step_avg:98.14ms +step:875/1670 val_loss:3.5224 train_time:85970ms step_avg:98.25ms +step:876/1670 train_time:85992ms step_avg:98.16ms +step:877/1670 train_time:86075ms step_avg:98.15ms +step:878/1670 train_time:86174ms step_avg:98.15ms +step:879/1670 train_time:86271ms step_avg:98.15ms +step:880/1670 train_time:86366ms step_avg:98.14ms +step:881/1670 train_time:86462ms step_avg:98.14ms +step:882/1670 train_time:86558ms step_avg:98.14ms +step:883/1670 train_time:86653ms step_avg:98.14ms +step:884/1670 train_time:86749ms step_avg:98.13ms +step:885/1670 train_time:86845ms step_avg:98.13ms +step:886/1670 train_time:86944ms step_avg:98.13ms +step:887/1670 train_time:87043ms step_avg:98.13ms +step:888/1670 train_time:87142ms step_avg:98.13ms +step:889/1670 train_time:87241ms step_avg:98.13ms +step:890/1670 train_time:87339ms step_avg:98.13ms +step:891/1670 train_time:87435ms step_avg:98.13ms +step:892/1670 train_time:87532ms step_avg:98.13ms +step:893/1670 train_time:87628ms step_avg:98.13ms +step:894/1670 train_time:87724ms step_avg:98.12ms +step:895/1670 train_time:87820ms step_avg:98.12ms +step:896/1670 train_time:87917ms step_avg:98.12ms +step:897/1670 train_time:88015ms step_avg:98.12ms +step:898/1670 train_time:88113ms step_avg:98.12ms +step:899/1670 train_time:88210ms step_avg:98.12ms +step:900/1670 train_time:88307ms step_avg:98.12ms +step:901/1670 train_time:88404ms step_avg:98.12ms +step:902/1670 train_time:88502ms step_avg:98.12ms +step:903/1670 train_time:88599ms step_avg:98.12ms +step:904/1670 train_time:88696ms step_avg:98.12ms 
+step:905/1670 train_time:88793ms step_avg:98.11ms +step:906/1670 train_time:88888ms step_avg:98.11ms +step:907/1670 train_time:88986ms step_avg:98.11ms +step:908/1670 train_time:89085ms step_avg:98.11ms +step:909/1670 train_time:89183ms step_avg:98.11ms +step:910/1670 train_time:89280ms step_avg:98.11ms +step:911/1670 train_time:89377ms step_avg:98.11ms +step:912/1670 train_time:89474ms step_avg:98.11ms +step:913/1670 train_time:89570ms step_avg:98.11ms +step:914/1670 train_time:89667ms step_avg:98.10ms +step:915/1670 train_time:89764ms step_avg:98.10ms +step:916/1670 train_time:89861ms step_avg:98.10ms +step:917/1670 train_time:89960ms step_avg:98.10ms +step:918/1670 train_time:90058ms step_avg:98.10ms +step:919/1670 train_time:90155ms step_avg:98.10ms +step:920/1670 train_time:90253ms step_avg:98.10ms +step:921/1670 train_time:90349ms step_avg:98.10ms +step:922/1670 train_time:90446ms step_avg:98.10ms +step:923/1670 train_time:90543ms step_avg:98.10ms +step:924/1670 train_time:90640ms step_avg:98.10ms +step:925/1670 train_time:90738ms step_avg:98.09ms +step:926/1670 train_time:90834ms step_avg:98.09ms +step:927/1670 train_time:90930ms step_avg:98.09ms +step:928/1670 train_time:91027ms step_avg:98.09ms +step:929/1670 train_time:91125ms step_avg:98.09ms +step:930/1670 train_time:91225ms step_avg:98.09ms +step:931/1670 train_time:91323ms step_avg:98.09ms +step:932/1670 train_time:91421ms step_avg:98.09ms +step:933/1670 train_time:91517ms step_avg:98.09ms +step:934/1670 train_time:91614ms step_avg:98.09ms +step:935/1670 train_time:91710ms step_avg:98.09ms +step:936/1670 train_time:91807ms step_avg:98.08ms +step:937/1670 train_time:91904ms step_avg:98.08ms +step:938/1670 train_time:92001ms step_avg:98.08ms +step:939/1670 train_time:92100ms step_avg:98.08ms +step:940/1670 train_time:92198ms step_avg:98.08ms +step:941/1670 train_time:92297ms step_avg:98.08ms +step:942/1670 train_time:92395ms step_avg:98.08ms +step:943/1670 train_time:92491ms step_avg:98.08ms +step:944/1670 train_time:92587ms step_avg:98.08ms +step:945/1670 train_time:92685ms step_avg:98.08ms +step:946/1670 train_time:92783ms step_avg:98.08ms +step:947/1670 train_time:92880ms step_avg:98.08ms +step:948/1670 train_time:92977ms step_avg:98.08ms +step:949/1670 train_time:93074ms step_avg:98.08ms +step:950/1670 train_time:93170ms step_avg:98.07ms +step:951/1670 train_time:93267ms step_avg:98.07ms +step:952/1670 train_time:93366ms step_avg:98.07ms +step:953/1670 train_time:93463ms step_avg:98.07ms +step:954/1670 train_time:93561ms step_avg:98.07ms +step:955/1670 train_time:93658ms step_avg:98.07ms +step:956/1670 train_time:93756ms step_avg:98.07ms +step:957/1670 train_time:93852ms step_avg:98.07ms +step:958/1670 train_time:93948ms step_avg:98.07ms +step:959/1670 train_time:94045ms step_avg:98.07ms +step:960/1670 train_time:94144ms step_avg:98.07ms +step:961/1670 train_time:94241ms step_avg:98.07ms +step:962/1670 train_time:94340ms step_avg:98.07ms +step:963/1670 train_time:94437ms step_avg:98.07ms +step:964/1670 train_time:94535ms step_avg:98.07ms +step:965/1670 train_time:94632ms step_avg:98.06ms +step:966/1670 train_time:94728ms step_avg:98.06ms +step:967/1670 train_time:94825ms step_avg:98.06ms +step:968/1670 train_time:94922ms step_avg:98.06ms +step:969/1670 train_time:95018ms step_avg:98.06ms +step:970/1670 train_time:95116ms step_avg:98.06ms +step:971/1670 train_time:95213ms step_avg:98.06ms +step:972/1670 train_time:95310ms step_avg:98.06ms +step:973/1670 train_time:95407ms step_avg:98.05ms +step:974/1670 train_time:95503ms 
step_avg:98.05ms +step:975/1670 train_time:95600ms step_avg:98.05ms +step:976/1670 train_time:95698ms step_avg:98.05ms +step:977/1670 train_time:95796ms step_avg:98.05ms +step:978/1670 train_time:95894ms step_avg:98.05ms +step:979/1670 train_time:95990ms step_avg:98.05ms +step:980/1670 train_time:96087ms step_avg:98.05ms +step:981/1670 train_time:96185ms step_avg:98.05ms +step:982/1670 train_time:96284ms step_avg:98.05ms +step:983/1670 train_time:96381ms step_avg:98.05ms +step:984/1670 train_time:96478ms step_avg:98.05ms +step:985/1670 train_time:96574ms step_avg:98.05ms +step:986/1670 train_time:96671ms step_avg:98.04ms +step:987/1670 train_time:96767ms step_avg:98.04ms +step:988/1670 train_time:96865ms step_avg:98.04ms +step:989/1670 train_time:96962ms step_avg:98.04ms +step:990/1670 train_time:97060ms step_avg:98.04ms +step:991/1670 train_time:97158ms step_avg:98.04ms +step:992/1670 train_time:97255ms step_avg:98.04ms +step:993/1670 train_time:97353ms step_avg:98.04ms +step:994/1670 train_time:97449ms step_avg:98.04ms +step:995/1670 train_time:97546ms step_avg:98.04ms +step:996/1670 train_time:97643ms step_avg:98.04ms +step:997/1670 train_time:97740ms step_avg:98.03ms +step:998/1670 train_time:97838ms step_avg:98.03ms +step:999/1670 train_time:97936ms step_avg:98.03ms +step:1000/1670 train_time:98034ms step_avg:98.03ms +step:1000/1670 val_loss:3.4803 train_time:98129ms step_avg:98.13ms +step:1001/1670 train_time:98151ms step_avg:98.05ms +step:1002/1670 train_time:98235ms step_avg:98.04ms +step:1003/1670 train_time:98338ms step_avg:98.04ms +step:1004/1670 train_time:98437ms step_avg:98.05ms +step:1005/1670 train_time:98534ms step_avg:98.04ms +step:1006/1670 train_time:98631ms step_avg:98.04ms +step:1007/1670 train_time:98726ms step_avg:98.04ms +step:1008/1670 train_time:98822ms step_avg:98.04ms +step:1009/1670 train_time:98918ms step_avg:98.04ms +step:1010/1670 train_time:99014ms step_avg:98.03ms +step:1011/1670 train_time:99113ms step_avg:98.03ms +step:1012/1670 train_time:99212ms step_avg:98.04ms +step:1013/1670 train_time:99312ms step_avg:98.04ms +step:1014/1670 train_time:99411ms step_avg:98.04ms +step:1015/1670 train_time:99507ms step_avg:98.04ms +step:1016/1670 train_time:99603ms step_avg:98.03ms +step:1017/1670 train_time:99700ms step_avg:98.03ms +step:1018/1670 train_time:99797ms step_avg:98.03ms +step:1019/1670 train_time:99893ms step_avg:98.03ms +step:1020/1670 train_time:99989ms step_avg:98.03ms +step:1021/1670 train_time:100085ms step_avg:98.03ms +step:1022/1670 train_time:100183ms step_avg:98.03ms +step:1023/1670 train_time:100284ms step_avg:98.03ms +step:1024/1670 train_time:100383ms step_avg:98.03ms +step:1025/1670 train_time:100480ms step_avg:98.03ms +step:1026/1670 train_time:100578ms step_avg:98.03ms +step:1027/1670 train_time:100675ms step_avg:98.03ms +step:1028/1670 train_time:100772ms step_avg:98.03ms +step:1029/1670 train_time:100867ms step_avg:98.02ms +step:1030/1670 train_time:100963ms step_avg:98.02ms +step:1031/1670 train_time:101060ms step_avg:98.02ms +step:1032/1670 train_time:101159ms step_avg:98.02ms +step:1033/1670 train_time:101258ms step_avg:98.02ms +step:1034/1670 train_time:101358ms step_avg:98.03ms +step:1035/1670 train_time:101456ms step_avg:98.03ms +step:1036/1670 train_time:101555ms step_avg:98.03ms +step:1037/1670 train_time:101652ms step_avg:98.03ms +step:1038/1670 train_time:101750ms step_avg:98.02ms +step:1039/1670 train_time:101846ms step_avg:98.02ms +step:1040/1670 train_time:101943ms step_avg:98.02ms +step:1041/1670 train_time:102040ms 
step_avg:98.02ms +step:1042/1670 train_time:102137ms step_avg:98.02ms +step:1043/1670 train_time:102234ms step_avg:98.02ms +step:1044/1670 train_time:102332ms step_avg:98.02ms +step:1045/1670 train_time:102430ms step_avg:98.02ms +step:1046/1670 train_time:102526ms step_avg:98.02ms +step:1047/1670 train_time:102623ms step_avg:98.02ms +step:1048/1670 train_time:102722ms step_avg:98.02ms +step:1049/1670 train_time:102819ms step_avg:98.02ms +step:1050/1670 train_time:102916ms step_avg:98.02ms +step:1051/1670 train_time:103013ms step_avg:98.01ms +step:1052/1670 train_time:103109ms step_avg:98.01ms +step:1053/1670 train_time:103206ms step_avg:98.01ms +step:1054/1670 train_time:103303ms step_avg:98.01ms +step:1055/1670 train_time:103401ms step_avg:98.01ms +step:1056/1670 train_time:103500ms step_avg:98.01ms +step:1057/1670 train_time:103597ms step_avg:98.01ms +step:1058/1670 train_time:103694ms step_avg:98.01ms +step:1059/1670 train_time:103791ms step_avg:98.01ms +step:1060/1670 train_time:103889ms step_avg:98.01ms +step:1061/1670 train_time:103985ms step_avg:98.01ms +step:1062/1670 train_time:104305ms step_avg:98.22ms +step:1063/1670 train_time:104444ms step_avg:98.25ms +step:1064/1670 train_time:104540ms step_avg:98.25ms +step:1065/1670 train_time:104636ms step_avg:98.25ms +step:1066/1670 train_time:104732ms step_avg:98.25ms +step:1067/1670 train_time:104827ms step_avg:98.24ms +step:1068/1670 train_time:104923ms step_avg:98.24ms +step:1069/1670 train_time:105019ms step_avg:98.24ms +step:1070/1670 train_time:105115ms step_avg:98.24ms +step:1071/1670 train_time:105211ms step_avg:98.24ms +step:1072/1670 train_time:105310ms step_avg:98.24ms +step:1073/1670 train_time:105413ms step_avg:98.24ms +step:1074/1670 train_time:105512ms step_avg:98.24ms +step:1075/1670 train_time:105610ms step_avg:98.24ms +step:1076/1670 train_time:105706ms step_avg:98.24ms +step:1077/1670 train_time:105802ms step_avg:98.24ms +step:1078/1670 train_time:105899ms step_avg:98.24ms +step:1079/1670 train_time:105996ms step_avg:98.23ms +step:1080/1670 train_time:106092ms step_avg:98.23ms +step:1081/1670 train_time:106187ms step_avg:98.23ms +step:1082/1670 train_time:106285ms step_avg:98.23ms +step:1083/1670 train_time:106384ms step_avg:98.23ms +step:1084/1670 train_time:106484ms step_avg:98.23ms +step:1085/1670 train_time:106583ms step_avg:98.23ms +step:1086/1670 train_time:106681ms step_avg:98.23ms +step:1087/1670 train_time:106778ms step_avg:98.23ms +step:1088/1670 train_time:106875ms step_avg:98.23ms +step:1089/1670 train_time:106972ms step_avg:98.23ms +step:1090/1670 train_time:107067ms step_avg:98.23ms +step:1091/1670 train_time:107163ms step_avg:98.22ms +step:1092/1670 train_time:107262ms step_avg:98.23ms +step:1093/1670 train_time:107360ms step_avg:98.23ms +step:1094/1670 train_time:107459ms step_avg:98.23ms +step:1095/1670 train_time:107558ms step_avg:98.23ms +step:1096/1670 train_time:107656ms step_avg:98.23ms +step:1097/1670 train_time:107753ms step_avg:98.22ms +step:1098/1670 train_time:107849ms step_avg:98.22ms +step:1099/1670 train_time:107945ms step_avg:98.22ms +step:1100/1670 train_time:108042ms step_avg:98.22ms +step:1101/1670 train_time:108139ms step_avg:98.22ms +step:1102/1670 train_time:108236ms step_avg:98.22ms +step:1103/1670 train_time:108333ms step_avg:98.22ms +step:1104/1670 train_time:108431ms step_avg:98.22ms +step:1105/1670 train_time:108530ms step_avg:98.22ms +step:1106/1670 train_time:108628ms step_avg:98.22ms +step:1107/1670 train_time:108725ms step_avg:98.22ms +step:1108/1670 train_time:108822ms 
step_avg:98.21ms +step:1109/1670 train_time:108920ms step_avg:98.21ms +step:1110/1670 train_time:109017ms step_avg:98.21ms +step:1111/1670 train_time:109113ms step_avg:98.21ms +step:1112/1670 train_time:109211ms step_avg:98.21ms +step:1113/1670 train_time:109308ms step_avg:98.21ms +step:1114/1670 train_time:109404ms step_avg:98.21ms +step:1115/1670 train_time:109502ms step_avg:98.21ms +step:1116/1670 train_time:109601ms step_avg:98.21ms +step:1117/1670 train_time:109699ms step_avg:98.21ms +step:1118/1670 train_time:109798ms step_avg:98.21ms +step:1119/1670 train_time:109896ms step_avg:98.21ms +step:1120/1670 train_time:109992ms step_avg:98.21ms +step:1121/1670 train_time:110090ms step_avg:98.21ms +step:1122/1670 train_time:110187ms step_avg:98.21ms +step:1123/1670 train_time:110284ms step_avg:98.21ms +step:1124/1670 train_time:110382ms step_avg:98.20ms +step:1125/1670 train_time:110481ms step_avg:98.20ms +step:1125/1670 val_loss:3.4260 train_time:110580ms step_avg:98.29ms +step:1126/1670 train_time:110603ms step_avg:98.23ms +step:1127/1670 train_time:110693ms step_avg:98.22ms +step:1128/1670 train_time:110791ms step_avg:98.22ms +step:1129/1670 train_time:110887ms step_avg:98.22ms +step:1130/1670 train_time:110984ms step_avg:98.22ms +step:1131/1670 train_time:111080ms step_avg:98.21ms +step:1132/1670 train_time:111176ms step_avg:98.21ms +step:1133/1670 train_time:111273ms step_avg:98.21ms +step:1134/1670 train_time:111369ms step_avg:98.21ms +step:1135/1670 train_time:111466ms step_avg:98.21ms +step:1136/1670 train_time:111569ms step_avg:98.21ms +step:1137/1670 train_time:111670ms step_avg:98.21ms +step:1138/1670 train_time:111769ms step_avg:98.22ms +step:1139/1670 train_time:111866ms step_avg:98.21ms +step:1140/1670 train_time:111962ms step_avg:98.21ms +step:1141/1670 train_time:112058ms step_avg:98.21ms +step:1142/1670 train_time:112155ms step_avg:98.21ms +step:1143/1670 train_time:112252ms step_avg:98.21ms +step:1144/1670 train_time:112348ms step_avg:98.21ms +step:1145/1670 train_time:112445ms step_avg:98.21ms +step:1146/1670 train_time:112544ms step_avg:98.21ms +step:1147/1670 train_time:112644ms step_avg:98.21ms +step:1148/1670 train_time:112745ms step_avg:98.21ms +step:1149/1670 train_time:112843ms step_avg:98.21ms +step:1150/1670 train_time:112940ms step_avg:98.21ms +step:1151/1670 train_time:113037ms step_avg:98.21ms +step:1152/1670 train_time:113133ms step_avg:98.21ms +step:1153/1670 train_time:113230ms step_avg:98.20ms +step:1154/1670 train_time:113326ms step_avg:98.20ms +step:1155/1670 train_time:113424ms step_avg:98.20ms +step:1156/1670 train_time:113522ms step_avg:98.20ms +step:1157/1670 train_time:113621ms step_avg:98.20ms +step:1158/1670 train_time:113721ms step_avg:98.20ms +step:1159/1670 train_time:113821ms step_avg:98.21ms +step:1160/1670 train_time:113919ms step_avg:98.21ms +step:1161/1670 train_time:114016ms step_avg:98.20ms +step:1162/1670 train_time:114112ms step_avg:98.20ms +step:1163/1670 train_time:114209ms step_avg:98.20ms +step:1164/1670 train_time:114306ms step_avg:98.20ms +step:1165/1670 train_time:114404ms step_avg:98.20ms +step:1166/1670 train_time:114501ms step_avg:98.20ms +step:1167/1670 train_time:114599ms step_avg:98.20ms +step:1168/1670 train_time:114698ms step_avg:98.20ms +step:1169/1670 train_time:114796ms step_avg:98.20ms +step:1170/1670 train_time:114894ms step_avg:98.20ms +step:1171/1670 train_time:114991ms step_avg:98.20ms +step:1172/1670 train_time:115088ms step_avg:98.20ms +step:1173/1670 train_time:115186ms step_avg:98.20ms +step:1174/1670 
train_time:115285ms step_avg:98.20ms +step:1175/1670 train_time:115384ms step_avg:98.20ms +step:1176/1670 train_time:115483ms step_avg:98.20ms +step:1177/1670 train_time:115581ms step_avg:98.20ms +step:1178/1670 train_time:115678ms step_avg:98.20ms +step:1179/1670 train_time:115778ms step_avg:98.20ms +step:1180/1670 train_time:115876ms step_avg:98.20ms +step:1181/1670 train_time:115973ms step_avg:98.20ms +step:1182/1670 train_time:116071ms step_avg:98.20ms +step:1183/1670 train_time:116168ms step_avg:98.20ms +step:1184/1670 train_time:116265ms step_avg:98.20ms +step:1185/1670 train_time:116364ms step_avg:98.20ms +step:1186/1670 train_time:116462ms step_avg:98.20ms +step:1187/1670 train_time:116560ms step_avg:98.20ms +step:1188/1670 train_time:116657ms step_avg:98.20ms +step:1189/1670 train_time:116756ms step_avg:98.20ms +step:1190/1670 train_time:116855ms step_avg:98.20ms +step:1191/1670 train_time:116950ms step_avg:98.20ms +step:1192/1670 train_time:117048ms step_avg:98.19ms +step:1193/1670 train_time:117146ms step_avg:98.19ms +step:1194/1670 train_time:117243ms step_avg:98.19ms +step:1195/1670 train_time:117340ms step_avg:98.19ms +step:1196/1670 train_time:117438ms step_avg:98.19ms +step:1197/1670 train_time:117536ms step_avg:98.19ms +step:1198/1670 train_time:117634ms step_avg:98.19ms +step:1199/1670 train_time:117731ms step_avg:98.19ms +step:1200/1670 train_time:117828ms step_avg:98.19ms +step:1201/1670 train_time:117927ms step_avg:98.19ms +step:1202/1670 train_time:118025ms step_avg:98.19ms +step:1203/1670 train_time:118124ms step_avg:98.19ms +step:1204/1670 train_time:118222ms step_avg:98.19ms +step:1205/1670 train_time:118319ms step_avg:98.19ms +step:1206/1670 train_time:118417ms step_avg:98.19ms +step:1207/1670 train_time:118514ms step_avg:98.19ms +step:1208/1670 train_time:118611ms step_avg:98.19ms +step:1209/1670 train_time:118708ms step_avg:98.19ms +step:1210/1670 train_time:118807ms step_avg:98.19ms +step:1211/1670 train_time:118905ms step_avg:98.19ms +step:1212/1670 train_time:119004ms step_avg:98.19ms +step:1213/1670 train_time:119103ms step_avg:98.19ms +step:1214/1670 train_time:119202ms step_avg:98.19ms +step:1215/1670 train_time:119299ms step_avg:98.19ms +step:1216/1670 train_time:119396ms step_avg:98.19ms +step:1217/1670 train_time:119494ms step_avg:98.19ms +step:1218/1670 train_time:119591ms step_avg:98.19ms +step:1219/1670 train_time:119689ms step_avg:98.19ms +step:1220/1670 train_time:119788ms step_avg:98.19ms +step:1221/1670 train_time:119886ms step_avg:98.19ms +step:1222/1670 train_time:119984ms step_avg:98.19ms +step:1223/1670 train_time:120081ms step_avg:98.19ms +step:1224/1670 train_time:120180ms step_avg:98.19ms +step:1225/1670 train_time:120279ms step_avg:98.19ms +step:1226/1670 train_time:120376ms step_avg:98.19ms +step:1227/1670 train_time:120473ms step_avg:98.19ms +step:1228/1670 train_time:120570ms step_avg:98.18ms +step:1229/1670 train_time:120667ms step_avg:98.18ms +step:1230/1670 train_time:120766ms step_avg:98.18ms +step:1231/1670 train_time:120864ms step_avg:98.18ms +step:1232/1670 train_time:120962ms step_avg:98.18ms +step:1233/1670 train_time:121059ms step_avg:98.18ms +step:1234/1670 train_time:121156ms step_avg:98.18ms +step:1235/1670 train_time:121254ms step_avg:98.18ms +step:1236/1670 train_time:121351ms step_avg:98.18ms +step:1237/1670 train_time:121449ms step_avg:98.18ms +step:1238/1670 train_time:121547ms step_avg:98.18ms +step:1239/1670 train_time:121645ms step_avg:98.18ms +step:1240/1670 train_time:121742ms step_avg:98.18ms +step:1241/1670 
train_time:121840ms step_avg:98.18ms +step:1242/1670 train_time:121938ms step_avg:98.18ms +step:1243/1670 train_time:122035ms step_avg:98.18ms +step:1244/1670 train_time:122132ms step_avg:98.18ms +step:1245/1670 train_time:122230ms step_avg:98.18ms +step:1246/1670 train_time:122328ms step_avg:98.18ms +step:1247/1670 train_time:122425ms step_avg:98.18ms +step:1248/1670 train_time:122523ms step_avg:98.18ms +step:1249/1670 train_time:122623ms step_avg:98.18ms +step:1250/1670 train_time:122721ms step_avg:98.18ms +step:1250/1670 val_loss:3.3835 train_time:122818ms step_avg:98.25ms +step:1251/1670 train_time:122841ms step_avg:98.19ms +step:1252/1670 train_time:122921ms step_avg:98.18ms +step:1253/1670 train_time:123021ms step_avg:98.18ms +step:1254/1670 train_time:123119ms step_avg:98.18ms +step:1255/1670 train_time:123215ms step_avg:98.18ms +step:1256/1670 train_time:123312ms step_avg:98.18ms +step:1257/1670 train_time:123409ms step_avg:98.18ms +step:1258/1670 train_time:123505ms step_avg:98.18ms +step:1259/1670 train_time:123602ms step_avg:98.17ms +step:1260/1670 train_time:123698ms step_avg:98.17ms +step:1261/1670 train_time:123796ms step_avg:98.17ms +step:1262/1670 train_time:123895ms step_avg:98.17ms +step:1263/1670 train_time:123996ms step_avg:98.18ms +step:1264/1670 train_time:124095ms step_avg:98.18ms +step:1265/1670 train_time:124192ms step_avg:98.18ms +step:1266/1670 train_time:124290ms step_avg:98.18ms +step:1267/1670 train_time:124388ms step_avg:98.18ms +step:1268/1670 train_time:124485ms step_avg:98.17ms +step:1269/1670 train_time:124582ms step_avg:98.17ms +step:1270/1670 train_time:124678ms step_avg:98.17ms +step:1271/1670 train_time:124776ms step_avg:98.17ms +step:1272/1670 train_time:124875ms step_avg:98.17ms +step:1273/1670 train_time:124974ms step_avg:98.17ms +step:1274/1670 train_time:125233ms step_avg:98.30ms +step:1275/1670 train_time:125395ms step_avg:98.35ms +step:1276/1670 train_time:125490ms step_avg:98.35ms +step:1277/1670 train_time:125587ms step_avg:98.35ms +step:1278/1670 train_time:125683ms step_avg:98.34ms +step:1279/1670 train_time:125780ms step_avg:98.34ms +step:1280/1670 train_time:125876ms step_avg:98.34ms +step:1281/1670 train_time:125973ms step_avg:98.34ms +step:1282/1670 train_time:126070ms step_avg:98.34ms +step:1283/1670 train_time:126167ms step_avg:98.34ms +step:1284/1670 train_time:126269ms step_avg:98.34ms +step:1285/1670 train_time:126372ms step_avg:98.34ms +step:1286/1670 train_time:126472ms step_avg:98.35ms +step:1287/1670 train_time:126571ms step_avg:98.35ms +step:1288/1670 train_time:126670ms step_avg:98.35ms +step:1289/1670 train_time:126767ms step_avg:98.35ms +step:1290/1670 train_time:126865ms step_avg:98.34ms +step:1291/1670 train_time:126962ms step_avg:98.34ms +step:1292/1670 train_time:127059ms step_avg:98.34ms +step:1293/1670 train_time:127155ms step_avg:98.34ms +step:1294/1670 train_time:127253ms step_avg:98.34ms +step:1295/1670 train_time:127351ms step_avg:98.34ms +step:1296/1670 train_time:127450ms step_avg:98.34ms +step:1297/1670 train_time:127549ms step_avg:98.34ms +step:1298/1670 train_time:127646ms step_avg:98.34ms +step:1299/1670 train_time:127744ms step_avg:98.34ms +step:1300/1670 train_time:127843ms step_avg:98.34ms +step:1301/1670 train_time:127942ms step_avg:98.34ms +step:1302/1670 train_time:128038ms step_avg:98.34ms +step:1303/1670 train_time:128135ms step_avg:98.34ms +step:1304/1670 train_time:128232ms step_avg:98.34ms +step:1305/1670 train_time:128330ms step_avg:98.34ms +step:1306/1670 train_time:128429ms step_avg:98.34ms 
+step:1307/1670 train_time:128527ms step_avg:98.34ms +step:1308/1670 train_time:128625ms step_avg:98.34ms +step:1309/1670 train_time:128723ms step_avg:98.34ms +step:1310/1670 train_time:128820ms step_avg:98.34ms +step:1311/1670 train_time:128917ms step_avg:98.34ms +step:1312/1670 train_time:129015ms step_avg:98.33ms +step:1313/1670 train_time:129112ms step_avg:98.33ms +step:1314/1670 train_time:129210ms step_avg:98.33ms +step:1315/1670 train_time:129308ms step_avg:98.33ms +step:1316/1670 train_time:129407ms step_avg:98.33ms +step:1317/1670 train_time:129505ms step_avg:98.33ms +step:1318/1670 train_time:129603ms step_avg:98.33ms +step:1319/1670 train_time:129700ms step_avg:98.33ms +step:1320/1670 train_time:129799ms step_avg:98.33ms +step:1321/1670 train_time:129896ms step_avg:98.33ms +step:1322/1670 train_time:129994ms step_avg:98.33ms +step:1323/1670 train_time:130091ms step_avg:98.33ms +step:1324/1670 train_time:130189ms step_avg:98.33ms +step:1325/1670 train_time:130286ms step_avg:98.33ms +step:1326/1670 train_time:130384ms step_avg:98.33ms +step:1327/1670 train_time:130482ms step_avg:98.33ms +step:1328/1670 train_time:130579ms step_avg:98.33ms +step:1329/1670 train_time:130677ms step_avg:98.33ms +step:1330/1670 train_time:130775ms step_avg:98.33ms +step:1331/1670 train_time:130874ms step_avg:98.33ms +step:1332/1670 train_time:130972ms step_avg:98.33ms +step:1333/1670 train_time:131071ms step_avg:98.33ms +step:1334/1670 train_time:131170ms step_avg:98.33ms +step:1335/1670 train_time:131267ms step_avg:98.33ms +step:1336/1670 train_time:131365ms step_avg:98.33ms +step:1337/1670 train_time:131463ms step_avg:98.33ms +step:1338/1670 train_time:131560ms step_avg:98.33ms +step:1339/1670 train_time:131658ms step_avg:98.33ms +step:1340/1670 train_time:131756ms step_avg:98.33ms +step:1341/1670 train_time:131854ms step_avg:98.32ms +step:1342/1670 train_time:131951ms step_avg:98.32ms +step:1343/1670 train_time:132050ms step_avg:98.32ms +step:1344/1670 train_time:132148ms step_avg:98.32ms +step:1345/1670 train_time:132246ms step_avg:98.32ms +step:1346/1670 train_time:132344ms step_avg:98.32ms +step:1347/1670 train_time:132442ms step_avg:98.32ms +step:1348/1670 train_time:132539ms step_avg:98.32ms +step:1349/1670 train_time:132637ms step_avg:98.32ms +step:1350/1670 train_time:132734ms step_avg:98.32ms +step:1351/1670 train_time:132833ms step_avg:98.32ms +step:1352/1670 train_time:132931ms step_avg:98.32ms +step:1353/1670 train_time:133029ms step_avg:98.32ms +step:1354/1670 train_time:133127ms step_avg:98.32ms +step:1355/1670 train_time:133225ms step_avg:98.32ms +step:1356/1670 train_time:133323ms step_avg:98.32ms +step:1357/1670 train_time:133421ms step_avg:98.32ms +step:1358/1670 train_time:133518ms step_avg:98.32ms +step:1359/1670 train_time:133616ms step_avg:98.32ms +step:1360/1670 train_time:133713ms step_avg:98.32ms +step:1361/1670 train_time:133811ms step_avg:98.32ms +step:1362/1670 train_time:133910ms step_avg:98.32ms +step:1363/1670 train_time:134008ms step_avg:98.32ms +step:1364/1670 train_time:134105ms step_avg:98.32ms +step:1365/1670 train_time:134203ms step_avg:98.32ms +step:1366/1670 train_time:134299ms step_avg:98.32ms +step:1367/1670 train_time:134397ms step_avg:98.32ms +step:1368/1670 train_time:134496ms step_avg:98.32ms +step:1369/1670 train_time:134594ms step_avg:98.32ms +step:1370/1670 train_time:134691ms step_avg:98.31ms +step:1371/1670 train_time:134790ms step_avg:98.32ms +step:1372/1670 train_time:134890ms step_avg:98.32ms +step:1373/1670 train_time:134988ms step_avg:98.32ms 
+step:1374/1670 train_time:135085ms step_avg:98.32ms +step:1375/1670 train_time:135183ms step_avg:98.32ms +step:1375/1670 val_loss:3.3460 train_time:135280ms step_avg:98.39ms +step:1376/1670 train_time:135302ms step_avg:98.33ms +step:1377/1670 train_time:135387ms step_avg:98.32ms +step:1378/1670 train_time:135485ms step_avg:98.32ms +step:1379/1670 train_time:135582ms step_avg:98.32ms +step:1380/1670 train_time:135679ms step_avg:98.32ms +step:1381/1670 train_time:135775ms step_avg:98.32ms +step:1382/1670 train_time:135872ms step_avg:98.32ms +step:1383/1670 train_time:135969ms step_avg:98.31ms +step:1384/1670 train_time:136065ms step_avg:98.31ms +step:1385/1670 train_time:136162ms step_avg:98.31ms +step:1386/1670 train_time:136263ms step_avg:98.31ms +step:1387/1670 train_time:136365ms step_avg:98.32ms +step:1388/1670 train_time:136464ms step_avg:98.32ms +step:1389/1670 train_time:136563ms step_avg:98.32ms +step:1390/1670 train_time:136661ms step_avg:98.32ms +step:1391/1670 train_time:136758ms step_avg:98.32ms +step:1392/1670 train_time:136855ms step_avg:98.32ms +step:1393/1670 train_time:136952ms step_avg:98.31ms +step:1394/1670 train_time:137049ms step_avg:98.31ms +step:1395/1670 train_time:137146ms step_avg:98.31ms +step:1396/1670 train_time:137244ms step_avg:98.31ms +step:1397/1670 train_time:137343ms step_avg:98.31ms +step:1398/1670 train_time:137443ms step_avg:98.31ms +step:1399/1670 train_time:137541ms step_avg:98.31ms +step:1400/1670 train_time:137639ms step_avg:98.31ms +step:1401/1670 train_time:137737ms step_avg:98.31ms +step:1402/1670 train_time:137834ms step_avg:98.31ms +step:1403/1670 train_time:137932ms step_avg:98.31ms +step:1404/1670 train_time:138029ms step_avg:98.31ms +step:1405/1670 train_time:138126ms step_avg:98.31ms +step:1406/1670 train_time:138224ms step_avg:98.31ms +step:1407/1670 train_time:138322ms step_avg:98.31ms +step:1408/1670 train_time:138422ms step_avg:98.31ms +step:1409/1670 train_time:138521ms step_avg:98.31ms +step:1410/1670 train_time:138620ms step_avg:98.31ms +step:1411/1670 train_time:138718ms step_avg:98.31ms +step:1412/1670 train_time:138815ms step_avg:98.31ms +step:1413/1670 train_time:138912ms step_avg:98.31ms +step:1414/1670 train_time:139009ms step_avg:98.31ms +step:1415/1670 train_time:139106ms step_avg:98.31ms +step:1416/1670 train_time:139203ms step_avg:98.31ms +step:1417/1670 train_time:139301ms step_avg:98.31ms +step:1418/1670 train_time:139400ms step_avg:98.31ms +step:1419/1670 train_time:139499ms step_avg:98.31ms +step:1420/1670 train_time:139598ms step_avg:98.31ms +step:1421/1670 train_time:139697ms step_avg:98.31ms +step:1422/1670 train_time:139794ms step_avg:98.31ms +step:1423/1670 train_time:139892ms step_avg:98.31ms +step:1424/1670 train_time:139989ms step_avg:98.31ms +step:1425/1670 train_time:140087ms step_avg:98.31ms +step:1426/1670 train_time:140185ms step_avg:98.31ms +step:1427/1670 train_time:140282ms step_avg:98.31ms +step:1428/1670 train_time:140381ms step_avg:98.31ms +step:1429/1670 train_time:140480ms step_avg:98.31ms +step:1430/1670 train_time:140578ms step_avg:98.31ms +step:1431/1670 train_time:140676ms step_avg:98.31ms +step:1432/1670 train_time:140774ms step_avg:98.31ms +step:1433/1670 train_time:140872ms step_avg:98.31ms +step:1434/1670 train_time:140969ms step_avg:98.30ms +step:1435/1670 train_time:141067ms step_avg:98.30ms +step:1436/1670 train_time:141165ms step_avg:98.30ms +step:1437/1670 train_time:141263ms step_avg:98.30ms +step:1438/1670 train_time:141360ms step_avg:98.30ms +step:1439/1670 train_time:141457ms 
step_avg:98.30ms +step:1440/1670 train_time:141555ms step_avg:98.30ms +step:1441/1670 train_time:141653ms step_avg:98.30ms +step:1442/1670 train_time:141751ms step_avg:98.30ms +step:1443/1670 train_time:141849ms step_avg:98.30ms +step:1444/1670 train_time:141946ms step_avg:98.30ms +step:1445/1670 train_time:142044ms step_avg:98.30ms +step:1446/1670 train_time:142142ms step_avg:98.30ms +step:1447/1670 train_time:142241ms step_avg:98.30ms +step:1448/1670 train_time:142338ms step_avg:98.30ms +step:1449/1670 train_time:142436ms step_avg:98.30ms +step:1450/1670 train_time:142535ms step_avg:98.30ms +step:1451/1670 train_time:142634ms step_avg:98.30ms +step:1452/1670 train_time:142731ms step_avg:98.30ms +step:1453/1670 train_time:142829ms step_avg:98.30ms +step:1454/1670 train_time:142926ms step_avg:98.30ms +step:1455/1670 train_time:143024ms step_avg:98.30ms +step:1456/1670 train_time:143122ms step_avg:98.30ms +step:1457/1670 train_time:143219ms step_avg:98.30ms +step:1458/1670 train_time:143318ms step_avg:98.30ms +step:1459/1670 train_time:143415ms step_avg:98.30ms +step:1460/1670 train_time:143514ms step_avg:98.30ms +step:1461/1670 train_time:143613ms step_avg:98.30ms +step:1462/1670 train_time:143711ms step_avg:98.30ms +step:1463/1670 train_time:143808ms step_avg:98.30ms +step:1464/1670 train_time:143906ms step_avg:98.30ms +step:1465/1670 train_time:144004ms step_avg:98.30ms +step:1466/1670 train_time:144102ms step_avg:98.30ms +step:1467/1670 train_time:144200ms step_avg:98.30ms +step:1468/1670 train_time:144299ms step_avg:98.30ms +step:1469/1670 train_time:144397ms step_avg:98.30ms +step:1470/1670 train_time:144496ms step_avg:98.30ms +step:1471/1670 train_time:144593ms step_avg:98.30ms +step:1472/1670 train_time:144691ms step_avg:98.30ms +step:1473/1670 train_time:144789ms step_avg:98.30ms +step:1474/1670 train_time:144887ms step_avg:98.29ms +step:1475/1670 train_time:144984ms step_avg:98.29ms +step:1476/1670 train_time:145081ms step_avg:98.29ms +step:1477/1670 train_time:145179ms step_avg:98.29ms +step:1478/1670 train_time:145277ms step_avg:98.29ms +step:1479/1670 train_time:145374ms step_avg:98.29ms +step:1480/1670 train_time:145472ms step_avg:98.29ms +step:1481/1670 train_time:145570ms step_avg:98.29ms +step:1482/1670 train_time:145668ms step_avg:98.29ms +step:1483/1670 train_time:145766ms step_avg:98.29ms +step:1484/1670 train_time:145863ms step_avg:98.29ms +step:1485/1670 train_time:146203ms step_avg:98.45ms +step:1486/1670 train_time:146277ms step_avg:98.44ms +step:1487/1670 train_time:146373ms step_avg:98.44ms +step:1488/1670 train_time:146470ms step_avg:98.43ms +step:1489/1670 train_time:146566ms step_avg:98.43ms +step:1490/1670 train_time:146663ms step_avg:98.43ms +step:1491/1670 train_time:146760ms step_avg:98.43ms +step:1492/1670 train_time:146857ms step_avg:98.43ms +step:1493/1670 train_time:146954ms step_avg:98.43ms +step:1494/1670 train_time:147051ms step_avg:98.43ms +step:1495/1670 train_time:147155ms step_avg:98.43ms +step:1496/1670 train_time:147259ms step_avg:98.44ms +step:1497/1670 train_time:147358ms step_avg:98.44ms +step:1498/1670 train_time:147457ms step_avg:98.44ms +step:1499/1670 train_time:147555ms step_avg:98.44ms +step:1500/1670 train_time:147653ms step_avg:98.44ms +step:1500/1670 val_loss:3.3137 train_time:147749ms step_avg:98.50ms +step:1501/1670 train_time:147772ms step_avg:98.45ms +step:1502/1670 train_time:147855ms step_avg:98.44ms +step:1503/1670 train_time:147954ms step_avg:98.44ms +step:1504/1670 train_time:148052ms step_avg:98.44ms +step:1505/1670 
train_time:148150ms step_avg:98.44ms +step:1506/1670 train_time:148247ms step_avg:98.44ms +step:1507/1670 train_time:148344ms step_avg:98.44ms +step:1508/1670 train_time:148441ms step_avg:98.44ms +step:1509/1670 train_time:148538ms step_avg:98.43ms +step:1510/1670 train_time:148635ms step_avg:98.43ms +step:1511/1670 train_time:148735ms step_avg:98.43ms +step:1512/1670 train_time:148834ms step_avg:98.44ms +step:1513/1670 train_time:148932ms step_avg:98.44ms +step:1514/1670 train_time:149031ms step_avg:98.44ms +step:1515/1670 train_time:149129ms step_avg:98.43ms +step:1516/1670 train_time:149226ms step_avg:98.43ms +step:1517/1670 train_time:149323ms step_avg:98.43ms +step:1518/1670 train_time:149420ms step_avg:98.43ms +step:1519/1670 train_time:149517ms step_avg:98.43ms +step:1520/1670 train_time:149614ms step_avg:98.43ms +step:1521/1670 train_time:149712ms step_avg:98.43ms +step:1522/1670 train_time:149811ms step_avg:98.43ms +step:1523/1670 train_time:149910ms step_avg:98.43ms +step:1524/1670 train_time:150009ms step_avg:98.43ms +step:1525/1670 train_time:150107ms step_avg:98.43ms +step:1526/1670 train_time:150205ms step_avg:98.43ms +step:1527/1670 train_time:150303ms step_avg:98.43ms +step:1528/1670 train_time:150400ms step_avg:98.43ms +step:1529/1670 train_time:150498ms step_avg:98.43ms +step:1530/1670 train_time:150595ms step_avg:98.43ms +step:1531/1670 train_time:150692ms step_avg:98.43ms +step:1532/1670 train_time:150791ms step_avg:98.43ms +step:1533/1670 train_time:150889ms step_avg:98.43ms +step:1534/1670 train_time:150988ms step_avg:98.43ms +step:1535/1670 train_time:151087ms step_avg:98.43ms +step:1536/1670 train_time:151185ms step_avg:98.43ms +step:1537/1670 train_time:151282ms step_avg:98.43ms +step:1538/1670 train_time:151380ms step_avg:98.43ms +step:1539/1670 train_time:151477ms step_avg:98.43ms +step:1540/1670 train_time:151575ms step_avg:98.43ms +step:1541/1670 train_time:151673ms step_avg:98.42ms +step:1542/1670 train_time:151770ms step_avg:98.42ms +step:1543/1670 train_time:151869ms step_avg:98.42ms +step:1544/1670 train_time:151968ms step_avg:98.42ms +step:1545/1670 train_time:152067ms step_avg:98.43ms +step:1546/1670 train_time:152166ms step_avg:98.43ms +step:1547/1670 train_time:152265ms step_avg:98.43ms +step:1548/1670 train_time:152363ms step_avg:98.43ms +step:1549/1670 train_time:152461ms step_avg:98.43ms +step:1550/1670 train_time:152559ms step_avg:98.43ms +step:1551/1670 train_time:152658ms step_avg:98.43ms +step:1552/1670 train_time:152756ms step_avg:98.43ms +step:1553/1670 train_time:152855ms step_avg:98.43ms +step:1554/1670 train_time:152953ms step_avg:98.43ms +step:1555/1670 train_time:153052ms step_avg:98.43ms +step:1556/1670 train_time:153150ms step_avg:98.43ms +step:1557/1670 train_time:153247ms step_avg:98.42ms +step:1558/1670 train_time:153345ms step_avg:98.42ms +step:1559/1670 train_time:153444ms step_avg:98.42ms +step:1560/1670 train_time:153544ms step_avg:98.43ms +step:1561/1670 train_time:153643ms step_avg:98.43ms +step:1562/1670 train_time:153742ms step_avg:98.43ms +step:1563/1670 train_time:153841ms step_avg:98.43ms +step:1564/1670 train_time:153940ms step_avg:98.43ms +step:1565/1670 train_time:154039ms step_avg:98.43ms +step:1566/1670 train_time:154138ms step_avg:98.43ms +step:1567/1670 train_time:154236ms step_avg:98.43ms +step:1568/1670 train_time:154334ms step_avg:98.43ms +step:1569/1670 train_time:154431ms step_avg:98.43ms +step:1570/1670 train_time:154528ms step_avg:98.43ms +step:1571/1670 train_time:154626ms step_avg:98.43ms +step:1572/1670 
train_time:154725ms step_avg:98.43ms +step:1573/1670 train_time:154824ms step_avg:98.43ms +step:1574/1670 train_time:154923ms step_avg:98.43ms +step:1575/1670 train_time:155022ms step_avg:98.43ms +step:1576/1670 train_time:155121ms step_avg:98.43ms +step:1577/1670 train_time:155220ms step_avg:98.43ms +step:1578/1670 train_time:155317ms step_avg:98.43ms +step:1579/1670 train_time:155415ms step_avg:98.43ms +step:1580/1670 train_time:155512ms step_avg:98.43ms +step:1581/1670 train_time:155610ms step_avg:98.43ms +step:1582/1670 train_time:155708ms step_avg:98.43ms +step:1583/1670 train_time:155807ms step_avg:98.43ms +step:1584/1670 train_time:155906ms step_avg:98.43ms +step:1585/1670 train_time:156005ms step_avg:98.43ms +step:1586/1670 train_time:156103ms step_avg:98.43ms +step:1587/1670 train_time:156202ms step_avg:98.43ms +step:1588/1670 train_time:156300ms step_avg:98.43ms +step:1589/1670 train_time:156398ms step_avg:98.43ms +step:1590/1670 train_time:156496ms step_avg:98.43ms +step:1591/1670 train_time:156594ms step_avg:98.42ms +step:1592/1670 train_time:156692ms step_avg:98.42ms +step:1593/1670 train_time:156790ms step_avg:98.42ms +step:1594/1670 train_time:156889ms step_avg:98.42ms +step:1595/1670 train_time:156987ms step_avg:98.42ms +step:1596/1670 train_time:157085ms step_avg:98.42ms +step:1597/1670 train_time:157184ms step_avg:98.42ms +step:1598/1670 train_time:157283ms step_avg:98.42ms +step:1599/1670 train_time:157381ms step_avg:98.42ms +step:1600/1670 train_time:157478ms step_avg:98.42ms +step:1601/1670 train_time:157577ms step_avg:98.42ms +step:1602/1670 train_time:157675ms step_avg:98.42ms +step:1603/1670 train_time:157772ms step_avg:98.42ms +step:1604/1670 train_time:157870ms step_avg:98.42ms +step:1605/1670 train_time:157968ms step_avg:98.42ms +step:1606/1670 train_time:158067ms step_avg:98.42ms +step:1607/1670 train_time:158165ms step_avg:98.42ms +step:1608/1670 train_time:158263ms step_avg:98.42ms +step:1609/1670 train_time:158361ms step_avg:98.42ms +step:1610/1670 train_time:158459ms step_avg:98.42ms +step:1611/1670 train_time:158558ms step_avg:98.42ms +step:1612/1670 train_time:158656ms step_avg:98.42ms +step:1613/1670 train_time:158753ms step_avg:98.42ms +step:1614/1670 train_time:158851ms step_avg:98.42ms +step:1615/1670 train_time:158948ms step_avg:98.42ms +step:1616/1670 train_time:159046ms step_avg:98.42ms +step:1617/1670 train_time:159144ms step_avg:98.42ms +step:1618/1670 train_time:159243ms step_avg:98.42ms +step:1619/1670 train_time:159342ms step_avg:98.42ms +step:1620/1670 train_time:159441ms step_avg:98.42ms +step:1621/1670 train_time:159538ms step_avg:98.42ms +step:1622/1670 train_time:159636ms step_avg:98.42ms +step:1623/1670 train_time:159734ms step_avg:98.42ms +step:1624/1670 train_time:159832ms step_avg:98.42ms +step:1625/1670 train_time:159930ms step_avg:98.42ms +step:1625/1670 val_loss:3.2871 train_time:160026ms step_avg:98.48ms +step:1626/1670 train_time:160048ms step_avg:98.43ms +step:1627/1670 train_time:160133ms step_avg:98.42ms +step:1628/1670 train_time:160233ms step_avg:98.42ms +step:1629/1670 train_time:160331ms step_avg:98.42ms +step:1630/1670 train_time:160428ms step_avg:98.42ms +step:1631/1670 train_time:160526ms step_avg:98.42ms +step:1632/1670 train_time:160623ms step_avg:98.42ms +step:1633/1670 train_time:160720ms step_avg:98.42ms +step:1634/1670 train_time:160816ms step_avg:98.42ms +step:1635/1670 train_time:160913ms step_avg:98.42ms +step:1636/1670 train_time:161013ms step_avg:98.42ms +step:1637/1670 train_time:161113ms step_avg:98.42ms 
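The step_avg column in these records is simply the cumulative train_time (wall-clock with validation time excluded, per the timing code later in this script) divided by the step index, so any entry can be re-derived from the line itself. A minimal sketch using Python's standard re module, with the sample line copied from this log:

import re

# sample entry from this record; step_avg = train_time / step
line = "step:1640/1670 train_time:161410ms step_avg:98.42ms"
m = re.match(r"step:(\d+)/\d+ train_time:(\d+)ms", line)
step, train_time_ms = int(m.group(1)), int(m.group(2))
print(f"step_avg:{train_time_ms / step:.2f}ms")  # -> step_avg:98.42ms, matching the logged value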
+step:1638/1670 train_time:161214ms step_avg:98.42ms +step:1639/1670 train_time:161312ms step_avg:98.42ms +step:1640/1670 train_time:161410ms step_avg:98.42ms +step:1641/1670 train_time:161507ms step_avg:98.42ms +step:1642/1670 train_time:161605ms step_avg:98.42ms +step:1643/1670 train_time:161702ms step_avg:98.42ms +step:1644/1670 train_time:161799ms step_avg:98.42ms +step:1645/1670 train_time:161896ms step_avg:98.42ms +step:1646/1670 train_time:161995ms step_avg:98.42ms +step:1647/1670 train_time:162093ms step_avg:98.42ms +step:1648/1670 train_time:162193ms step_avg:98.42ms +step:1649/1670 train_time:162291ms step_avg:98.42ms +step:1650/1670 train_time:162390ms step_avg:98.42ms +step:1651/1670 train_time:162487ms step_avg:98.42ms +step:1652/1670 train_time:162585ms step_avg:98.42ms +step:1653/1670 train_time:162681ms step_avg:98.42ms +step:1654/1670 train_time:162779ms step_avg:98.42ms +step:1655/1670 train_time:162876ms step_avg:98.41ms +step:1656/1670 train_time:162974ms step_avg:98.41ms +step:1657/1670 train_time:163073ms step_avg:98.41ms +step:1658/1670 train_time:163171ms step_avg:98.41ms +step:1659/1670 train_time:163271ms step_avg:98.42ms +step:1660/1670 train_time:163371ms step_avg:98.42ms +step:1661/1670 train_time:163468ms step_avg:98.42ms +step:1662/1670 train_time:163566ms step_avg:98.42ms +step:1663/1670 train_time:163664ms step_avg:98.42ms +step:1664/1670 train_time:163763ms step_avg:98.42ms +step:1665/1670 train_time:163863ms step_avg:98.42ms +step:1666/1670 train_time:163962ms step_avg:98.42ms +step:1667/1670 train_time:164061ms step_avg:98.42ms +step:1668/1670 train_time:164161ms step_avg:98.42ms +step:1669/1670 train_time:164262ms step_avg:98.42ms +step:1670/1670 train_time:164361ms step_avg:98.42ms +step:1670/1670 val_loss:3.2790 train_time:164459ms step_avg:98.48ms +peak memory allocated: 34000 MiB reserved: 49496 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt b/records/050925_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt new file mode 100644 index 000000000..09040c23f --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = 
w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", 
"c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from 
ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, 
b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
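+ # Sharding scheme used below, sketched with illustrative numbers (assuming the
+ # 8-GPU world_size this script targets): params in each same-shape group are
+ # dealt out round-robin, so rank r owns params r, r + world_size, r + 2*world_size, ...
+ # e.g. for the 24 identically-shaped MLP matrices, rank 3 updates params 3, 11, 19.
+ # Each rank reduce-scatters to obtain the cross-rank-averaged gradients of the
+ # params it owns, applies momentum + Newton-Schulz locally, then all-gathers the
+ # updated params so every rank finishes the step with identical weights.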
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 15:40:57 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 44C P0 130W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 42C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 43C P0 130W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 71199 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 71200 C /usr/bin/python3 610MiB | +| 0 N/A N/A 71201 C /usr/bin/python3 610MiB | +| 0 N/A N/A 71202 C /usr/bin/python3 610MiB | +| 0 N/A N/A 71203 C /usr/bin/python3 610MiB | +| 0 N/A N/A 71204 C /usr/bin/python3 610MiB | +| 0 N/A N/A 71205 C /usr/bin/python3 610MiB | +| 0 N/A N/A 71206 C /usr/bin/python3 610MiB | +| 1 N/A N/A 71200 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 71201 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 71202 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 71203 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 71204 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 71205 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 71206 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:419ms step_avg:418.68ms +step:2/1670 train_time:439ms step_avg:219.57ms +step:3/1670 train_time:512ms step_avg:170.52ms +step:4/1670 train_time:606ms step_avg:151.42ms +step:5/1670 train_time:700ms step_avg:140.09ms +step:6/1670 train_time:796ms step_avg:132.63ms +step:7/1670 train_time:891ms step_avg:127.22ms +step:8/1670 
train_time:985ms step_avg:123.18ms +step:9/1670 train_time:1081ms step_avg:120.10ms +step:10/1670 train_time:1176ms step_avg:117.60ms +step:11/1670 train_time:1271ms step_avg:115.54ms +step:12/1670 train_time:1370ms step_avg:114.14ms +step:13/1670 train_time:1469ms step_avg:112.98ms +step:14/1670 train_time:1565ms step_avg:111.79ms +step:15/1670 train_time:1661ms step_avg:110.75ms +step:16/1670 train_time:1758ms step_avg:109.86ms +step:17/1670 train_time:1853ms step_avg:108.98ms +step:18/1670 train_time:1948ms step_avg:108.23ms +step:19/1670 train_time:2044ms step_avg:107.57ms +step:20/1670 train_time:2139ms step_avg:106.97ms +step:21/1670 train_time:2236ms step_avg:106.46ms +step:22/1670 train_time:2332ms step_avg:106.00ms +step:23/1670 train_time:2429ms step_avg:105.59ms +step:24/1670 train_time:2525ms step_avg:105.23ms +step:25/1670 train_time:2622ms step_avg:104.88ms +step:26/1670 train_time:2717ms step_avg:104.51ms +step:27/1670 train_time:2813ms step_avg:104.19ms +step:28/1670 train_time:2908ms step_avg:103.86ms +step:29/1670 train_time:3004ms step_avg:103.58ms +step:30/1670 train_time:3100ms step_avg:103.34ms +step:31/1670 train_time:3196ms step_avg:103.10ms +step:32/1670 train_time:3291ms step_avg:102.86ms +step:33/1670 train_time:3387ms step_avg:102.65ms +step:34/1670 train_time:3484ms step_avg:102.48ms +step:35/1670 train_time:3581ms step_avg:102.31ms +step:36/1670 train_time:3678ms step_avg:102.16ms +step:37/1670 train_time:3774ms step_avg:102.00ms +step:38/1670 train_time:3869ms step_avg:101.81ms +step:39/1670 train_time:3965ms step_avg:101.66ms +step:40/1670 train_time:4060ms step_avg:101.51ms +step:41/1670 train_time:4156ms step_avg:101.37ms +step:42/1670 train_time:4251ms step_avg:101.22ms +step:43/1670 train_time:4347ms step_avg:101.10ms +step:44/1670 train_time:4443ms step_avg:100.99ms +step:45/1670 train_time:4540ms step_avg:100.89ms +step:46/1670 train_time:4636ms step_avg:100.79ms +step:47/1670 train_time:4732ms step_avg:100.69ms +step:48/1670 train_time:4829ms step_avg:100.60ms +step:49/1670 train_time:4924ms step_avg:100.49ms +step:50/1670 train_time:5020ms step_avg:100.40ms +step:51/1670 train_time:5117ms step_avg:100.33ms +step:52/1670 train_time:5213ms step_avg:100.25ms +step:53/1670 train_time:5308ms step_avg:100.15ms +step:54/1670 train_time:5404ms step_avg:100.07ms +step:55/1670 train_time:5500ms step_avg:100.00ms +step:56/1670 train_time:5596ms step_avg:99.93ms +step:57/1670 train_time:5692ms step_avg:99.86ms +step:58/1670 train_time:5788ms step_avg:99.79ms +step:59/1670 train_time:5884ms step_avg:99.74ms +step:60/1670 train_time:5981ms step_avg:99.68ms +step:61/1670 train_time:6077ms step_avg:99.62ms +step:62/1670 train_time:6173ms step_avg:99.56ms +step:63/1670 train_time:6269ms step_avg:99.51ms +step:64/1670 train_time:6365ms step_avg:99.45ms +step:65/1670 train_time:6460ms step_avg:99.39ms +step:66/1670 train_time:6556ms step_avg:99.33ms +step:67/1670 train_time:6651ms step_avg:99.27ms +step:68/1670 train_time:6747ms step_avg:99.22ms +step:69/1670 train_time:6844ms step_avg:99.19ms +step:70/1670 train_time:6941ms step_avg:99.15ms +step:71/1670 train_time:7037ms step_avg:99.11ms +step:72/1670 train_time:7132ms step_avg:99.06ms +step:73/1670 train_time:7228ms step_avg:99.01ms +step:74/1670 train_time:7324ms step_avg:98.97ms +step:75/1670 train_time:7420ms step_avg:98.93ms +step:76/1670 train_time:7515ms step_avg:98.89ms +step:77/1670 train_time:7610ms step_avg:98.84ms +step:78/1670 train_time:7706ms step_avg:98.79ms +step:79/1670 train_time:7802ms 
step_avg:98.76ms +step:80/1670 train_time:7898ms step_avg:98.72ms +step:81/1670 train_time:7993ms step_avg:98.68ms +step:82/1670 train_time:8089ms step_avg:98.65ms +step:83/1670 train_time:8185ms step_avg:98.61ms +step:84/1670 train_time:8281ms step_avg:98.58ms +step:85/1670 train_time:8377ms step_avg:98.55ms +step:86/1670 train_time:8472ms step_avg:98.52ms +step:87/1670 train_time:8568ms step_avg:98.48ms +step:88/1670 train_time:8663ms step_avg:98.45ms +step:89/1670 train_time:8759ms step_avg:98.41ms +step:90/1670 train_time:8853ms step_avg:98.37ms +step:91/1670 train_time:8950ms step_avg:98.35ms +step:92/1670 train_time:9045ms step_avg:98.32ms +step:93/1670 train_time:9142ms step_avg:98.30ms +step:94/1670 train_time:9239ms step_avg:98.29ms +step:95/1670 train_time:9335ms step_avg:98.26ms +step:96/1670 train_time:9431ms step_avg:98.23ms +step:97/1670 train_time:9526ms step_avg:98.21ms +step:98/1670 train_time:9623ms step_avg:98.19ms +step:99/1670 train_time:9719ms step_avg:98.17ms +step:100/1670 train_time:9814ms step_avg:98.14ms +step:101/1670 train_time:9909ms step_avg:98.11ms +step:102/1670 train_time:10005ms step_avg:98.09ms +step:103/1670 train_time:10101ms step_avg:98.07ms +step:104/1670 train_time:10197ms step_avg:98.05ms +step:105/1670 train_time:10293ms step_avg:98.03ms +step:106/1670 train_time:10389ms step_avg:98.01ms +step:107/1670 train_time:10485ms step_avg:97.99ms +step:108/1670 train_time:10581ms step_avg:97.97ms +step:109/1670 train_time:10676ms step_avg:97.95ms +step:110/1670 train_time:10772ms step_avg:97.93ms +step:111/1670 train_time:10867ms step_avg:97.90ms +step:112/1670 train_time:10962ms step_avg:97.88ms +step:113/1670 train_time:11058ms step_avg:97.86ms +step:114/1670 train_time:11153ms step_avg:97.84ms +step:115/1670 train_time:11249ms step_avg:97.82ms +step:116/1670 train_time:11346ms step_avg:97.81ms +step:117/1670 train_time:11441ms step_avg:97.79ms +step:118/1670 train_time:11536ms step_avg:97.77ms +step:119/1670 train_time:11632ms step_avg:97.75ms +step:120/1670 train_time:11727ms step_avg:97.73ms +step:121/1670 train_time:11824ms step_avg:97.72ms +step:122/1670 train_time:11919ms step_avg:97.70ms +step:123/1670 train_time:12016ms step_avg:97.69ms +step:124/1670 train_time:12110ms step_avg:97.66ms +step:125/1670 train_time:12206ms step_avg:97.65ms +step:125/1670 val_loss:4.3041 train_time:12301ms step_avg:98.41ms +step:126/1670 train_time:12323ms step_avg:97.80ms +step:127/1670 train_time:12404ms step_avg:97.67ms +step:128/1670 train_time:12505ms step_avg:97.70ms +step:129/1670 train_time:12603ms step_avg:97.70ms +step:130/1670 train_time:12699ms step_avg:97.69ms +step:131/1670 train_time:12794ms step_avg:97.67ms +step:132/1670 train_time:12888ms step_avg:97.64ms +step:133/1670 train_time:12983ms step_avg:97.62ms +step:134/1670 train_time:13078ms step_avg:97.60ms +step:135/1670 train_time:13172ms step_avg:97.57ms +step:136/1670 train_time:13267ms step_avg:97.55ms +step:137/1670 train_time:13364ms step_avg:97.55ms +step:138/1670 train_time:13462ms step_avg:97.55ms +step:139/1670 train_time:13560ms step_avg:97.55ms +step:140/1670 train_time:13656ms step_avg:97.54ms +step:141/1670 train_time:13752ms step_avg:97.53ms +step:142/1670 train_time:13846ms step_avg:97.51ms +step:143/1670 train_time:13941ms step_avg:97.49ms +step:144/1670 train_time:14036ms step_avg:97.47ms +step:145/1670 train_time:14130ms step_avg:97.45ms +step:146/1670 train_time:14225ms step_avg:97.43ms +step:147/1670 train_time:14322ms step_avg:97.43ms +step:148/1670 train_time:14418ms 
step_avg:97.42ms +step:149/1670 train_time:14515ms step_avg:97.42ms +step:150/1670 train_time:14611ms step_avg:97.41ms +step:151/1670 train_time:14706ms step_avg:97.39ms +step:152/1670 train_time:14802ms step_avg:97.38ms +step:153/1670 train_time:14897ms step_avg:97.37ms +step:154/1670 train_time:14993ms step_avg:97.35ms +step:155/1670 train_time:15087ms step_avg:97.34ms +step:156/1670 train_time:15182ms step_avg:97.32ms +step:157/1670 train_time:15277ms step_avg:97.31ms +step:158/1670 train_time:15373ms step_avg:97.30ms +step:159/1670 train_time:15469ms step_avg:97.29ms +step:160/1670 train_time:15565ms step_avg:97.28ms +step:161/1670 train_time:15662ms step_avg:97.28ms +step:162/1670 train_time:15759ms step_avg:97.28ms +step:163/1670 train_time:15854ms step_avg:97.26ms +step:164/1670 train_time:15948ms step_avg:97.25ms +step:165/1670 train_time:16044ms step_avg:97.23ms +step:166/1670 train_time:16140ms step_avg:97.23ms +step:167/1670 train_time:16234ms step_avg:97.21ms +step:168/1670 train_time:16330ms step_avg:97.20ms +step:169/1670 train_time:16425ms step_avg:97.19ms +step:170/1670 train_time:16522ms step_avg:97.19ms +step:171/1670 train_time:16618ms step_avg:97.18ms +step:172/1670 train_time:16713ms step_avg:97.17ms +step:173/1670 train_time:16809ms step_avg:97.16ms +step:174/1670 train_time:16904ms step_avg:97.15ms +step:175/1670 train_time:17000ms step_avg:97.14ms +step:176/1670 train_time:17095ms step_avg:97.13ms +step:177/1670 train_time:17190ms step_avg:97.12ms +step:178/1670 train_time:17285ms step_avg:97.11ms +step:179/1670 train_time:17381ms step_avg:97.10ms +step:180/1670 train_time:17478ms step_avg:97.10ms +step:181/1670 train_time:17573ms step_avg:97.09ms +step:182/1670 train_time:17669ms step_avg:97.08ms +step:183/1670 train_time:17765ms step_avg:97.08ms +step:184/1670 train_time:17862ms step_avg:97.07ms +step:185/1670 train_time:17957ms step_avg:97.06ms +step:186/1670 train_time:18052ms step_avg:97.05ms +step:187/1670 train_time:18147ms step_avg:97.04ms +step:188/1670 train_time:18242ms step_avg:97.03ms +step:189/1670 train_time:18338ms step_avg:97.03ms +step:190/1670 train_time:18434ms step_avg:97.02ms +step:191/1670 train_time:18529ms step_avg:97.01ms +step:192/1670 train_time:18625ms step_avg:97.00ms +step:193/1670 train_time:18721ms step_avg:97.00ms +step:194/1670 train_time:18818ms step_avg:97.00ms +step:195/1670 train_time:18914ms step_avg:96.99ms +step:196/1670 train_time:19009ms step_avg:96.99ms +step:197/1670 train_time:19104ms step_avg:96.98ms +step:198/1670 train_time:19200ms step_avg:96.97ms +step:199/1670 train_time:19295ms step_avg:96.96ms +step:200/1670 train_time:19391ms step_avg:96.95ms +step:201/1670 train_time:19486ms step_avg:96.95ms +step:202/1670 train_time:19582ms step_avg:96.94ms +step:203/1670 train_time:19677ms step_avg:96.93ms +step:204/1670 train_time:19774ms step_avg:96.93ms +step:205/1670 train_time:19869ms step_avg:96.92ms +step:206/1670 train_time:19965ms step_avg:96.92ms +step:207/1670 train_time:20061ms step_avg:96.91ms +step:208/1670 train_time:20157ms step_avg:96.91ms +step:209/1670 train_time:20252ms step_avg:96.90ms +step:210/1670 train_time:20347ms step_avg:96.89ms +step:211/1670 train_time:20442ms step_avg:96.88ms +step:212/1670 train_time:20539ms step_avg:96.88ms +step:213/1670 train_time:20842ms step_avg:97.85ms +step:214/1670 train_time:20947ms step_avg:97.88ms +step:215/1670 train_time:21042ms step_avg:97.87ms +step:216/1670 train_time:21137ms step_avg:97.85ms +step:217/1670 train_time:21231ms step_avg:97.84ms +step:218/1670 
train_time:21326ms step_avg:97.82ms +step:219/1670 train_time:21420ms step_avg:97.81ms +step:220/1670 train_time:21515ms step_avg:97.80ms +step:221/1670 train_time:21610ms step_avg:97.78ms +step:222/1670 train_time:21704ms step_avg:97.77ms +step:223/1670 train_time:21802ms step_avg:97.77ms +step:224/1670 train_time:21901ms step_avg:97.77ms +step:225/1670 train_time:21999ms step_avg:97.77ms +step:226/1670 train_time:22095ms step_avg:97.77ms +step:227/1670 train_time:22190ms step_avg:97.75ms +step:228/1670 train_time:22285ms step_avg:97.74ms +step:229/1670 train_time:22380ms step_avg:97.73ms +step:230/1670 train_time:22475ms step_avg:97.72ms +step:231/1670 train_time:22569ms step_avg:97.70ms +step:232/1670 train_time:22664ms step_avg:97.69ms +step:233/1670 train_time:22760ms step_avg:97.68ms +step:234/1670 train_time:22857ms step_avg:97.68ms +step:235/1670 train_time:22954ms step_avg:97.67ms +step:236/1670 train_time:23050ms step_avg:97.67ms +step:237/1670 train_time:23145ms step_avg:97.66ms +step:238/1670 train_time:23241ms step_avg:97.65ms +step:239/1670 train_time:23337ms step_avg:97.64ms +step:240/1670 train_time:23432ms step_avg:97.63ms +step:241/1670 train_time:23526ms step_avg:97.62ms +step:242/1670 train_time:23622ms step_avg:97.61ms +step:243/1670 train_time:23717ms step_avg:97.60ms +step:244/1670 train_time:23813ms step_avg:97.59ms +step:245/1670 train_time:23909ms step_avg:97.59ms +step:246/1670 train_time:24005ms step_avg:97.58ms +step:247/1670 train_time:24101ms step_avg:97.57ms +step:248/1670 train_time:24197ms step_avg:97.57ms +step:249/1670 train_time:24293ms step_avg:97.56ms +step:250/1670 train_time:24387ms step_avg:97.55ms +step:250/1670 val_loss:3.9656 train_time:24482ms step_avg:97.93ms +step:251/1670 train_time:24503ms step_avg:97.62ms +step:252/1670 train_time:24579ms step_avg:97.54ms +step:253/1670 train_time:24680ms step_avg:97.55ms +step:254/1670 train_time:24783ms step_avg:97.57ms +step:255/1670 train_time:24881ms step_avg:97.57ms +step:256/1670 train_time:24976ms step_avg:97.56ms +step:257/1670 train_time:25071ms step_avg:97.55ms +step:258/1670 train_time:25165ms step_avg:97.54ms +step:259/1670 train_time:25260ms step_avg:97.53ms +step:260/1670 train_time:25355ms step_avg:97.52ms +step:261/1670 train_time:25452ms step_avg:97.52ms +step:262/1670 train_time:25548ms step_avg:97.51ms +step:263/1670 train_time:25643ms step_avg:97.50ms +step:264/1670 train_time:25740ms step_avg:97.50ms +step:265/1670 train_time:25838ms step_avg:97.50ms +step:266/1670 train_time:25934ms step_avg:97.49ms +step:267/1670 train_time:26029ms step_avg:97.49ms +step:268/1670 train_time:26124ms step_avg:97.48ms +step:269/1670 train_time:26219ms step_avg:97.47ms +step:270/1670 train_time:26314ms step_avg:97.46ms +step:271/1670 train_time:26410ms step_avg:97.45ms +step:272/1670 train_time:26505ms step_avg:97.45ms +step:273/1670 train_time:26601ms step_avg:97.44ms +step:274/1670 train_time:26697ms step_avg:97.43ms +step:275/1670 train_time:26793ms step_avg:97.43ms +step:276/1670 train_time:26889ms step_avg:97.43ms +step:277/1670 train_time:26985ms step_avg:97.42ms +step:278/1670 train_time:27080ms step_avg:97.41ms +step:279/1670 train_time:27175ms step_avg:97.40ms +step:280/1670 train_time:27270ms step_avg:97.39ms +step:281/1670 train_time:27365ms step_avg:97.38ms +step:282/1670 train_time:27461ms step_avg:97.38ms +step:283/1670 train_time:27557ms step_avg:97.37ms +step:284/1670 train_time:27653ms step_avg:97.37ms +step:285/1670 train_time:27749ms step_avg:97.36ms +step:286/1670 train_time:27845ms 
step_avg:97.36ms +step:287/1670 train_time:27940ms step_avg:97.35ms +step:288/1670 train_time:28036ms step_avg:97.35ms +step:289/1670 train_time:28132ms step_avg:97.34ms +step:290/1670 train_time:28228ms step_avg:97.34ms +step:291/1670 train_time:28322ms step_avg:97.33ms +step:292/1670 train_time:28418ms step_avg:97.32ms +step:293/1670 train_time:28513ms step_avg:97.32ms +step:294/1670 train_time:28609ms step_avg:97.31ms +step:295/1670 train_time:28705ms step_avg:97.31ms +step:296/1670 train_time:28800ms step_avg:97.30ms +step:297/1670 train_time:28896ms step_avg:97.29ms +step:298/1670 train_time:28992ms step_avg:97.29ms +step:299/1670 train_time:29087ms step_avg:97.28ms +step:300/1670 train_time:29182ms step_avg:97.27ms +step:301/1670 train_time:29277ms step_avg:97.27ms +step:302/1670 train_time:29373ms step_avg:97.26ms +step:303/1670 train_time:29468ms step_avg:97.26ms +step:304/1670 train_time:29563ms step_avg:97.25ms +step:305/1670 train_time:29659ms step_avg:97.24ms +step:306/1670 train_time:29755ms step_avg:97.24ms +step:307/1670 train_time:29851ms step_avg:97.23ms +step:308/1670 train_time:29946ms step_avg:97.23ms +step:309/1670 train_time:30041ms step_avg:97.22ms +step:310/1670 train_time:30137ms step_avg:97.22ms +step:311/1670 train_time:30233ms step_avg:97.21ms +step:312/1670 train_time:30329ms step_avg:97.21ms +step:313/1670 train_time:30423ms step_avg:97.20ms +step:314/1670 train_time:30519ms step_avg:97.19ms +step:315/1670 train_time:30615ms step_avg:97.19ms +step:316/1670 train_time:30711ms step_avg:97.19ms +step:317/1670 train_time:30806ms step_avg:97.18ms +step:318/1670 train_time:30901ms step_avg:97.17ms +step:319/1670 train_time:30997ms step_avg:97.17ms +step:320/1670 train_time:31093ms step_avg:97.17ms +step:321/1670 train_time:31188ms step_avg:97.16ms +step:322/1670 train_time:31284ms step_avg:97.15ms +step:323/1670 train_time:31379ms step_avg:97.15ms +step:324/1670 train_time:31475ms step_avg:97.14ms +step:325/1670 train_time:31570ms step_avg:97.14ms +step:326/1670 train_time:31665ms step_avg:97.13ms +step:327/1670 train_time:31761ms step_avg:97.13ms +step:328/1670 train_time:31857ms step_avg:97.12ms +step:329/1670 train_time:31953ms step_avg:97.12ms +step:330/1670 train_time:32048ms step_avg:97.12ms +step:331/1670 train_time:32144ms step_avg:97.11ms +step:332/1670 train_time:32240ms step_avg:97.11ms +step:333/1670 train_time:32336ms step_avg:97.11ms +step:334/1670 train_time:32432ms step_avg:97.10ms +step:335/1670 train_time:32527ms step_avg:97.09ms +step:336/1670 train_time:32622ms step_avg:97.09ms +step:337/1670 train_time:32718ms step_avg:97.09ms +step:338/1670 train_time:32813ms step_avg:97.08ms +step:339/1670 train_time:32909ms step_avg:97.08ms +step:340/1670 train_time:33004ms step_avg:97.07ms +step:341/1670 train_time:33099ms step_avg:97.07ms +step:342/1670 train_time:33195ms step_avg:97.06ms +step:343/1670 train_time:33291ms step_avg:97.06ms +step:344/1670 train_time:33388ms step_avg:97.06ms +step:345/1670 train_time:33483ms step_avg:97.05ms +step:346/1670 train_time:33579ms step_avg:97.05ms +step:347/1670 train_time:33674ms step_avg:97.04ms +step:348/1670 train_time:33770ms step_avg:97.04ms +step:349/1670 train_time:33865ms step_avg:97.03ms +step:350/1670 train_time:33960ms step_avg:97.03ms +step:351/1670 train_time:34057ms step_avg:97.03ms +step:352/1670 train_time:34153ms step_avg:97.02ms +step:353/1670 train_time:34248ms step_avg:97.02ms +step:354/1670 train_time:34343ms step_avg:97.01ms +step:355/1670 train_time:34439ms step_avg:97.01ms +step:356/1670 
train_time:34535ms step_avg:97.01ms +step:357/1670 train_time:34631ms step_avg:97.00ms +step:358/1670 train_time:34726ms step_avg:97.00ms +step:359/1670 train_time:34822ms step_avg:97.00ms +step:360/1670 train_time:34918ms step_avg:96.99ms +step:361/1670 train_time:35014ms step_avg:96.99ms +step:362/1670 train_time:35109ms step_avg:96.99ms +step:363/1670 train_time:35204ms step_avg:96.98ms +step:364/1670 train_time:35300ms step_avg:96.98ms +step:365/1670 train_time:35396ms step_avg:96.97ms +step:366/1670 train_time:35492ms step_avg:96.97ms +step:367/1670 train_time:35588ms step_avg:96.97ms +step:368/1670 train_time:35683ms step_avg:96.96ms +step:369/1670 train_time:35778ms step_avg:96.96ms +step:370/1670 train_time:35874ms step_avg:96.96ms +step:371/1670 train_time:35971ms step_avg:96.96ms +step:372/1670 train_time:36066ms step_avg:96.95ms +step:373/1670 train_time:36161ms step_avg:96.95ms +step:374/1670 train_time:36256ms step_avg:96.94ms +step:375/1670 train_time:36352ms step_avg:96.94ms +step:375/1670 val_loss:3.8117 train_time:36447ms step_avg:97.19ms +step:376/1670 train_time:36468ms step_avg:96.99ms +step:377/1670 train_time:36549ms step_avg:96.95ms +step:378/1670 train_time:36651ms step_avg:96.96ms +step:379/1670 train_time:36747ms step_avg:96.96ms +step:380/1670 train_time:36842ms step_avg:96.95ms +step:381/1670 train_time:36937ms step_avg:96.95ms +step:382/1670 train_time:37032ms step_avg:96.94ms +step:383/1670 train_time:37127ms step_avg:96.94ms +step:384/1670 train_time:37223ms step_avg:96.93ms +step:385/1670 train_time:37317ms step_avg:96.93ms +step:386/1670 train_time:37412ms step_avg:96.92ms +step:387/1670 train_time:37510ms step_avg:96.93ms +step:388/1670 train_time:37609ms step_avg:96.93ms +step:389/1670 train_time:37706ms step_avg:96.93ms +step:390/1670 train_time:37802ms step_avg:96.93ms +step:391/1670 train_time:37898ms step_avg:96.92ms +step:392/1670 train_time:37992ms step_avg:96.92ms +step:393/1670 train_time:38088ms step_avg:96.92ms +step:394/1670 train_time:38183ms step_avg:96.91ms +step:395/1670 train_time:38277ms step_avg:96.90ms +step:396/1670 train_time:38373ms step_avg:96.90ms +step:397/1670 train_time:38470ms step_avg:96.90ms +step:398/1670 train_time:38566ms step_avg:96.90ms +step:399/1670 train_time:38663ms step_avg:96.90ms +step:400/1670 train_time:38759ms step_avg:96.90ms +step:401/1670 train_time:38854ms step_avg:96.89ms +step:402/1670 train_time:38949ms step_avg:96.89ms +step:403/1670 train_time:39045ms step_avg:96.89ms +step:404/1670 train_time:39140ms step_avg:96.88ms +step:405/1670 train_time:39235ms step_avg:96.88ms +step:406/1670 train_time:39330ms step_avg:96.87ms +step:407/1670 train_time:39426ms step_avg:96.87ms +step:408/1670 train_time:39522ms step_avg:96.87ms +step:409/1670 train_time:39618ms step_avg:96.86ms +step:410/1670 train_time:39713ms step_avg:96.86ms +step:411/1670 train_time:39810ms step_avg:96.86ms +step:412/1670 train_time:39905ms step_avg:96.86ms +step:413/1670 train_time:40001ms step_avg:96.85ms +step:414/1670 train_time:40096ms step_avg:96.85ms +step:415/1670 train_time:40191ms step_avg:96.85ms +step:416/1670 train_time:40286ms step_avg:96.84ms +step:417/1670 train_time:40382ms step_avg:96.84ms +step:418/1670 train_time:40478ms step_avg:96.84ms +step:419/1670 train_time:40573ms step_avg:96.83ms +step:420/1670 train_time:40669ms step_avg:96.83ms +step:421/1670 train_time:40765ms step_avg:96.83ms +step:422/1670 train_time:40862ms step_avg:96.83ms +step:423/1670 train_time:40957ms step_avg:96.83ms +step:424/1670 train_time:41052ms 
step_avg:96.82ms +step:425/1670 train_time:41347ms step_avg:97.29ms +step:426/1670 train_time:41488ms step_avg:97.39ms +step:427/1670 train_time:41582ms step_avg:97.38ms +step:428/1670 train_time:41677ms step_avg:97.38ms +step:429/1670 train_time:41771ms step_avg:97.37ms +step:430/1670 train_time:41865ms step_avg:97.36ms +step:431/1670 train_time:41960ms step_avg:97.36ms +step:432/1670 train_time:42055ms step_avg:97.35ms +step:433/1670 train_time:42149ms step_avg:97.34ms +step:434/1670 train_time:42244ms step_avg:97.34ms +step:435/1670 train_time:42344ms step_avg:97.34ms +step:436/1670 train_time:42444ms step_avg:97.35ms +step:437/1670 train_time:42542ms step_avg:97.35ms +step:438/1670 train_time:42638ms step_avg:97.35ms +step:439/1670 train_time:42733ms step_avg:97.34ms +step:440/1670 train_time:42828ms step_avg:97.34ms +step:441/1670 train_time:42923ms step_avg:97.33ms +step:442/1670 train_time:43018ms step_avg:97.33ms +step:443/1670 train_time:43112ms step_avg:97.32ms +step:444/1670 train_time:43207ms step_avg:97.31ms +step:445/1670 train_time:43304ms step_avg:97.31ms +step:446/1670 train_time:43400ms step_avg:97.31ms +step:447/1670 train_time:43496ms step_avg:97.31ms +step:448/1670 train_time:43592ms step_avg:97.30ms +step:449/1670 train_time:43688ms step_avg:97.30ms +step:450/1670 train_time:43784ms step_avg:97.30ms +step:451/1670 train_time:43879ms step_avg:97.29ms +step:452/1670 train_time:43975ms step_avg:97.29ms +step:453/1670 train_time:44069ms step_avg:97.28ms +step:454/1670 train_time:44163ms step_avg:97.28ms +step:455/1670 train_time:44258ms step_avg:97.27ms +step:456/1670 train_time:44354ms step_avg:97.27ms +step:457/1670 train_time:44451ms step_avg:97.27ms +step:458/1670 train_time:44548ms step_avg:97.27ms +step:459/1670 train_time:44644ms step_avg:97.26ms +step:460/1670 train_time:44742ms step_avg:97.26ms +step:461/1670 train_time:44837ms step_avg:97.26ms +step:462/1670 train_time:44932ms step_avg:97.26ms +step:463/1670 train_time:45028ms step_avg:97.25ms +step:464/1670 train_time:45123ms step_avg:97.25ms +step:465/1670 train_time:45218ms step_avg:97.24ms +step:466/1670 train_time:45314ms step_avg:97.24ms +step:467/1670 train_time:45409ms step_avg:97.24ms +step:468/1670 train_time:45506ms step_avg:97.23ms +step:469/1670 train_time:45602ms step_avg:97.23ms +step:470/1670 train_time:45698ms step_avg:97.23ms +step:471/1670 train_time:45793ms step_avg:97.23ms +step:472/1670 train_time:45889ms step_avg:97.22ms +step:473/1670 train_time:45985ms step_avg:97.22ms +step:474/1670 train_time:46080ms step_avg:97.22ms +step:475/1670 train_time:46176ms step_avg:97.21ms +step:476/1670 train_time:46271ms step_avg:97.21ms +step:477/1670 train_time:46367ms step_avg:97.20ms +step:478/1670 train_time:46463ms step_avg:97.20ms +step:479/1670 train_time:46559ms step_avg:97.20ms +step:480/1670 train_time:46654ms step_avg:97.20ms +step:481/1670 train_time:46750ms step_avg:97.19ms +step:482/1670 train_time:46846ms step_avg:97.19ms +step:483/1670 train_time:46943ms step_avg:97.19ms +step:484/1670 train_time:47039ms step_avg:97.19ms +step:485/1670 train_time:47134ms step_avg:97.18ms +step:486/1670 train_time:47229ms step_avg:97.18ms +step:487/1670 train_time:47325ms step_avg:97.18ms +step:488/1670 train_time:47421ms step_avg:97.17ms +step:489/1670 train_time:47517ms step_avg:97.17ms +step:490/1670 train_time:47613ms step_avg:97.17ms +step:491/1670 train_time:47708ms step_avg:97.17ms +step:492/1670 train_time:47804ms step_avg:97.16ms +step:493/1670 train_time:47900ms step_avg:97.16ms +step:494/1670 
train_time:47995ms step_avg:97.16ms +step:495/1670 train_time:48091ms step_avg:97.15ms +step:496/1670 train_time:48187ms step_avg:97.15ms +step:497/1670 train_time:48282ms step_avg:97.15ms +step:498/1670 train_time:48378ms step_avg:97.14ms +step:499/1670 train_time:48474ms step_avg:97.14ms +step:500/1670 train_time:48569ms step_avg:97.14ms +step:500/1670 val_loss:3.7096 train_time:48665ms step_avg:97.33ms +step:501/1670 train_time:48687ms step_avg:97.18ms +step:502/1670 train_time:48767ms step_avg:97.15ms +step:503/1670 train_time:48867ms step_avg:97.15ms +step:504/1670 train_time:48963ms step_avg:97.15ms +step:505/1670 train_time:49059ms step_avg:97.15ms +step:506/1670 train_time:49154ms step_avg:97.14ms +step:507/1670 train_time:49249ms step_avg:97.14ms +step:508/1670 train_time:49344ms step_avg:97.13ms +step:509/1670 train_time:49439ms step_avg:97.13ms +step:510/1670 train_time:49534ms step_avg:97.13ms +step:511/1670 train_time:49629ms step_avg:97.12ms +step:512/1670 train_time:49725ms step_avg:97.12ms +step:513/1670 train_time:49823ms step_avg:97.12ms +step:514/1670 train_time:49920ms step_avg:97.12ms +step:515/1670 train_time:50017ms step_avg:97.12ms +step:516/1670 train_time:50113ms step_avg:97.12ms +step:517/1670 train_time:50207ms step_avg:97.11ms +step:518/1670 train_time:50303ms step_avg:97.11ms +step:519/1670 train_time:50398ms step_avg:97.11ms +step:520/1670 train_time:50494ms step_avg:97.10ms +step:521/1670 train_time:50588ms step_avg:97.10ms +step:522/1670 train_time:50684ms step_avg:97.10ms +step:523/1670 train_time:50781ms step_avg:97.10ms +step:524/1670 train_time:50878ms step_avg:97.10ms +step:525/1670 train_time:50974ms step_avg:97.09ms +step:526/1670 train_time:51070ms step_avg:97.09ms +step:527/1670 train_time:51165ms step_avg:97.09ms +step:528/1670 train_time:51261ms step_avg:97.08ms +step:529/1670 train_time:51356ms step_avg:97.08ms +step:530/1670 train_time:51452ms step_avg:97.08ms +step:531/1670 train_time:51548ms step_avg:97.08ms +step:532/1670 train_time:51643ms step_avg:97.07ms +step:533/1670 train_time:51740ms step_avg:97.07ms +step:534/1670 train_time:51837ms step_avg:97.07ms +step:535/1670 train_time:51933ms step_avg:97.07ms +step:536/1670 train_time:52029ms step_avg:97.07ms +step:537/1670 train_time:52124ms step_avg:97.06ms +step:538/1670 train_time:52219ms step_avg:97.06ms +step:539/1670 train_time:52315ms step_avg:97.06ms +step:540/1670 train_time:52410ms step_avg:97.06ms +step:541/1670 train_time:52505ms step_avg:97.05ms +step:542/1670 train_time:52601ms step_avg:97.05ms +step:543/1670 train_time:52698ms step_avg:97.05ms +step:544/1670 train_time:52794ms step_avg:97.05ms +step:545/1670 train_time:52890ms step_avg:97.05ms +step:546/1670 train_time:52985ms step_avg:97.04ms +step:547/1670 train_time:53081ms step_avg:97.04ms +step:548/1670 train_time:53178ms step_avg:97.04ms +step:549/1670 train_time:53273ms step_avg:97.04ms +step:550/1670 train_time:53369ms step_avg:97.03ms +step:551/1670 train_time:53464ms step_avg:97.03ms +step:552/1670 train_time:53560ms step_avg:97.03ms +step:553/1670 train_time:53656ms step_avg:97.03ms +step:554/1670 train_time:53752ms step_avg:97.02ms +step:555/1670 train_time:53848ms step_avg:97.02ms +step:556/1670 train_time:53944ms step_avg:97.02ms +step:557/1670 train_time:54040ms step_avg:97.02ms +step:558/1670 train_time:54135ms step_avg:97.02ms +step:559/1670 train_time:54233ms step_avg:97.02ms +step:560/1670 train_time:54329ms step_avg:97.02ms +step:561/1670 train_time:54426ms step_avg:97.02ms +step:562/1670 train_time:54522ms 
step_avg:97.02ms +step:563/1670 train_time:54621ms step_avg:97.02ms +step:564/1670 train_time:54718ms step_avg:97.02ms +step:565/1670 train_time:54817ms step_avg:97.02ms +step:566/1670 train_time:54914ms step_avg:97.02ms +step:567/1670 train_time:55011ms step_avg:97.02ms +step:568/1670 train_time:55107ms step_avg:97.02ms +step:569/1670 train_time:55205ms step_avg:97.02ms +step:570/1670 train_time:55302ms step_avg:97.02ms +step:571/1670 train_time:55399ms step_avg:97.02ms +step:572/1670 train_time:55497ms step_avg:97.02ms +step:573/1670 train_time:55593ms step_avg:97.02ms +step:574/1670 train_time:55690ms step_avg:97.02ms +step:575/1670 train_time:55786ms step_avg:97.02ms +step:576/1670 train_time:55884ms step_avg:97.02ms +step:577/1670 train_time:55982ms step_avg:97.02ms +step:578/1670 train_time:56080ms step_avg:97.02ms +step:579/1670 train_time:56178ms step_avg:97.03ms +step:580/1670 train_time:56276ms step_avg:97.03ms +step:581/1670 train_time:56373ms step_avg:97.03ms +step:582/1670 train_time:56471ms step_avg:97.03ms +step:583/1670 train_time:56567ms step_avg:97.03ms +step:584/1670 train_time:56665ms step_avg:97.03ms +step:585/1670 train_time:56762ms step_avg:97.03ms +step:586/1670 train_time:56860ms step_avg:97.03ms +step:587/1670 train_time:56958ms step_avg:97.03ms +step:588/1670 train_time:57056ms step_avg:97.03ms +step:589/1670 train_time:57153ms step_avg:97.03ms +step:590/1670 train_time:57251ms step_avg:97.04ms +step:591/1670 train_time:57347ms step_avg:97.03ms +step:592/1670 train_time:57444ms step_avg:97.03ms +step:593/1670 train_time:57541ms step_avg:97.03ms +step:594/1670 train_time:57638ms step_avg:97.03ms +step:595/1670 train_time:57735ms step_avg:97.03ms +step:596/1670 train_time:57832ms step_avg:97.03ms +step:597/1670 train_time:57928ms step_avg:97.03ms +step:598/1670 train_time:58026ms step_avg:97.03ms +step:599/1670 train_time:58124ms step_avg:97.03ms +step:600/1670 train_time:58222ms step_avg:97.04ms +step:601/1670 train_time:58319ms step_avg:97.04ms +step:602/1670 train_time:58418ms step_avg:97.04ms +step:603/1670 train_time:58516ms step_avg:97.04ms +step:604/1670 train_time:58612ms step_avg:97.04ms +step:605/1670 train_time:58709ms step_avg:97.04ms +step:606/1670 train_time:58805ms step_avg:97.04ms +step:607/1670 train_time:58902ms step_avg:97.04ms +step:608/1670 train_time:59000ms step_avg:97.04ms +step:609/1670 train_time:59098ms step_avg:97.04ms +step:610/1670 train_time:59195ms step_avg:97.04ms +step:611/1670 train_time:59292ms step_avg:97.04ms +step:612/1670 train_time:59388ms step_avg:97.04ms +step:613/1670 train_time:59485ms step_avg:97.04ms +step:614/1670 train_time:59583ms step_avg:97.04ms +step:615/1670 train_time:59682ms step_avg:97.04ms +step:616/1670 train_time:59780ms step_avg:97.05ms +step:617/1670 train_time:59879ms step_avg:97.05ms +step:618/1670 train_time:59976ms step_avg:97.05ms +step:619/1670 train_time:60073ms step_avg:97.05ms +step:620/1670 train_time:60170ms step_avg:97.05ms +step:621/1670 train_time:60267ms step_avg:97.05ms +step:622/1670 train_time:60364ms step_avg:97.05ms +step:623/1670 train_time:60461ms step_avg:97.05ms +step:624/1670 train_time:60559ms step_avg:97.05ms +step:625/1670 train_time:60656ms step_avg:97.05ms +step:625/1670 val_loss:3.6099 train_time:60752ms step_avg:97.20ms +step:626/1670 train_time:60774ms step_avg:97.08ms +step:627/1670 train_time:60860ms step_avg:97.07ms +step:628/1670 train_time:60956ms step_avg:97.06ms +step:629/1670 train_time:61052ms step_avg:97.06ms +step:630/1670 train_time:61148ms step_avg:97.06ms 
+step:631/1670 train_time:61244ms step_avg:97.06ms +step:632/1670 train_time:61340ms step_avg:97.06ms +step:633/1670 train_time:61436ms step_avg:97.06ms +step:634/1670 train_time:61532ms step_avg:97.05ms +step:635/1670 train_time:61628ms step_avg:97.05ms +step:636/1670 train_time:61727ms step_avg:97.06ms +step:637/1670 train_time:61829ms step_avg:97.06ms +step:638/1670 train_time:61929ms step_avg:97.07ms +step:639/1670 train_time:62303ms step_avg:97.50ms +step:640/1670 train_time:62392ms step_avg:97.49ms +step:641/1670 train_time:62488ms step_avg:97.48ms +step:642/1670 train_time:62584ms step_avg:97.48ms +step:643/1670 train_time:62680ms step_avg:97.48ms +step:644/1670 train_time:62776ms step_avg:97.48ms +step:645/1670 train_time:62872ms step_avg:97.48ms +step:646/1670 train_time:62968ms step_avg:97.47ms +step:647/1670 train_time:63064ms step_avg:97.47ms +step:648/1670 train_time:63161ms step_avg:97.47ms +step:649/1670 train_time:63263ms step_avg:97.48ms +step:650/1670 train_time:63361ms step_avg:97.48ms +step:651/1670 train_time:63459ms step_avg:97.48ms +step:652/1670 train_time:63556ms step_avg:97.48ms +step:653/1670 train_time:63653ms step_avg:97.48ms +step:654/1670 train_time:63750ms step_avg:97.48ms +step:655/1670 train_time:63847ms step_avg:97.48ms +step:656/1670 train_time:63944ms step_avg:97.47ms +step:657/1670 train_time:64040ms step_avg:97.47ms +step:658/1670 train_time:64137ms step_avg:97.47ms +step:659/1670 train_time:64234ms step_avg:97.47ms +step:660/1670 train_time:64333ms step_avg:97.47ms +step:661/1670 train_time:64432ms step_avg:97.48ms +step:662/1670 train_time:64529ms step_avg:97.48ms +step:663/1670 train_time:64627ms step_avg:97.48ms +step:664/1670 train_time:64725ms step_avg:97.48ms +step:665/1670 train_time:64821ms step_avg:97.47ms +step:666/1670 train_time:64916ms step_avg:97.47ms +step:667/1670 train_time:65013ms step_avg:97.47ms +step:668/1670 train_time:65110ms step_avg:97.47ms +step:669/1670 train_time:65209ms step_avg:97.47ms +step:670/1670 train_time:65306ms step_avg:97.47ms +step:671/1670 train_time:65405ms step_avg:97.47ms +step:672/1670 train_time:65503ms step_avg:97.47ms +step:673/1670 train_time:65600ms step_avg:97.47ms +step:674/1670 train_time:65697ms step_avg:97.47ms +step:675/1670 train_time:65793ms step_avg:97.47ms +step:676/1670 train_time:65890ms step_avg:97.47ms +step:677/1670 train_time:65987ms step_avg:97.47ms +step:678/1670 train_time:66086ms step_avg:97.47ms +step:679/1670 train_time:66183ms step_avg:97.47ms +step:680/1670 train_time:66280ms step_avg:97.47ms +step:681/1670 train_time:66377ms step_avg:97.47ms +step:682/1670 train_time:66475ms step_avg:97.47ms +step:683/1670 train_time:66572ms step_avg:97.47ms +step:684/1670 train_time:66670ms step_avg:97.47ms +step:685/1670 train_time:66767ms step_avg:97.47ms +step:686/1670 train_time:66864ms step_avg:97.47ms +step:687/1670 train_time:66962ms step_avg:97.47ms +step:688/1670 train_time:67060ms step_avg:97.47ms +step:689/1670 train_time:67156ms step_avg:97.47ms +step:690/1670 train_time:67253ms step_avg:97.47ms +step:691/1670 train_time:67351ms step_avg:97.47ms +step:692/1670 train_time:67449ms step_avg:97.47ms +step:693/1670 train_time:67547ms step_avg:97.47ms +step:694/1670 train_time:67644ms step_avg:97.47ms +step:695/1670 train_time:67742ms step_avg:97.47ms +step:696/1670 train_time:67838ms step_avg:97.47ms +step:697/1670 train_time:67934ms step_avg:97.47ms +step:698/1670 train_time:68031ms step_avg:97.47ms +step:699/1670 train_time:68129ms step_avg:97.47ms +step:700/1670 train_time:68227ms 
step_avg:97.47ms +step:701/1670 train_time:68325ms step_avg:97.47ms +step:702/1670 train_time:68422ms step_avg:97.47ms +step:703/1670 train_time:68519ms step_avg:97.47ms +step:704/1670 train_time:68615ms step_avg:97.46ms +step:705/1670 train_time:68713ms step_avg:97.46ms +step:706/1670 train_time:68811ms step_avg:97.47ms +step:707/1670 train_time:68909ms step_avg:97.47ms +step:708/1670 train_time:69006ms step_avg:97.47ms +step:709/1670 train_time:69103ms step_avg:97.47ms +step:710/1670 train_time:69199ms step_avg:97.46ms +step:711/1670 train_time:69296ms step_avg:97.46ms +step:712/1670 train_time:69393ms step_avg:97.46ms +step:713/1670 train_time:69490ms step_avg:97.46ms +step:714/1670 train_time:69588ms step_avg:97.46ms +step:715/1670 train_time:69686ms step_avg:97.46ms +step:716/1670 train_time:69784ms step_avg:97.46ms +step:717/1670 train_time:69881ms step_avg:97.46ms +step:718/1670 train_time:69977ms step_avg:97.46ms +step:719/1670 train_time:70073ms step_avg:97.46ms +step:720/1670 train_time:70171ms step_avg:97.46ms +step:721/1670 train_time:70268ms step_avg:97.46ms +step:722/1670 train_time:70366ms step_avg:97.46ms +step:723/1670 train_time:70464ms step_avg:97.46ms +step:724/1670 train_time:70562ms step_avg:97.46ms +step:725/1670 train_time:70659ms step_avg:97.46ms +step:726/1670 train_time:70754ms step_avg:97.46ms +step:727/1670 train_time:70852ms step_avg:97.46ms +step:728/1670 train_time:70949ms step_avg:97.46ms +step:729/1670 train_time:71048ms step_avg:97.46ms +step:730/1670 train_time:71146ms step_avg:97.46ms +step:731/1670 train_time:71243ms step_avg:97.46ms +step:732/1670 train_time:71339ms step_avg:97.46ms +step:733/1670 train_time:71436ms step_avg:97.46ms +step:734/1670 train_time:71533ms step_avg:97.46ms +step:735/1670 train_time:71630ms step_avg:97.46ms +step:736/1670 train_time:71728ms step_avg:97.46ms +step:737/1670 train_time:71826ms step_avg:97.46ms +step:738/1670 train_time:71924ms step_avg:97.46ms +step:739/1670 train_time:72022ms step_avg:97.46ms +step:740/1670 train_time:72118ms step_avg:97.46ms +step:741/1670 train_time:72215ms step_avg:97.46ms +step:742/1670 train_time:72313ms step_avg:97.46ms +step:743/1670 train_time:72410ms step_avg:97.46ms +step:744/1670 train_time:72507ms step_avg:97.46ms +step:745/1670 train_time:72605ms step_avg:97.46ms +step:746/1670 train_time:72703ms step_avg:97.46ms +step:747/1670 train_time:72801ms step_avg:97.46ms +step:748/1670 train_time:72897ms step_avg:97.46ms +step:749/1670 train_time:72993ms step_avg:97.45ms +step:750/1670 train_time:73091ms step_avg:97.45ms +step:750/1670 val_loss:3.5575 train_time:73188ms step_avg:97.58ms +step:751/1670 train_time:73211ms step_avg:97.48ms +step:752/1670 train_time:73293ms step_avg:97.46ms +step:753/1670 train_time:73392ms step_avg:97.47ms +step:754/1670 train_time:73489ms step_avg:97.47ms +step:755/1670 train_time:73586ms step_avg:97.47ms +step:756/1670 train_time:73683ms step_avg:97.46ms +step:757/1670 train_time:73779ms step_avg:97.46ms +step:758/1670 train_time:73875ms step_avg:97.46ms +step:759/1670 train_time:73971ms step_avg:97.46ms +step:760/1670 train_time:74067ms step_avg:97.46ms +step:761/1670 train_time:74165ms step_avg:97.46ms +step:762/1670 train_time:74268ms step_avg:97.46ms +step:763/1670 train_time:74367ms step_avg:97.47ms +step:764/1670 train_time:74466ms step_avg:97.47ms +step:765/1670 train_time:74563ms step_avg:97.47ms +step:766/1670 train_time:74660ms step_avg:97.47ms +step:767/1670 train_time:74756ms step_avg:97.47ms +step:768/1670 train_time:74852ms step_avg:97.46ms 
+step:769/1670 train_time:74948ms step_avg:97.46ms +step:770/1670 train_time:75046ms step_avg:97.46ms +step:771/1670 train_time:75144ms step_avg:97.46ms +step:772/1670 train_time:75243ms step_avg:97.47ms +step:773/1670 train_time:75343ms step_avg:97.47ms +step:774/1670 train_time:75441ms step_avg:97.47ms +step:775/1670 train_time:75538ms step_avg:97.47ms +step:776/1670 train_time:75635ms step_avg:97.47ms +step:777/1670 train_time:75732ms step_avg:97.47ms +step:778/1670 train_time:75829ms step_avg:97.47ms +step:779/1670 train_time:75925ms step_avg:97.47ms +step:780/1670 train_time:76022ms step_avg:97.46ms +step:781/1670 train_time:76118ms step_avg:97.46ms +step:782/1670 train_time:76215ms step_avg:97.46ms +step:783/1670 train_time:76313ms step_avg:97.46ms +step:784/1670 train_time:76410ms step_avg:97.46ms +step:785/1670 train_time:76508ms step_avg:97.46ms +step:786/1670 train_time:76606ms step_avg:97.46ms +step:787/1670 train_time:76705ms step_avg:97.47ms +step:788/1670 train_time:76804ms step_avg:97.47ms +step:789/1670 train_time:76901ms step_avg:97.47ms +step:790/1670 train_time:76998ms step_avg:97.47ms +step:791/1670 train_time:77093ms step_avg:97.46ms +step:792/1670 train_time:77190ms step_avg:97.46ms +step:793/1670 train_time:77290ms step_avg:97.47ms +step:794/1670 train_time:77388ms step_avg:97.47ms +step:795/1670 train_time:77486ms step_avg:97.47ms +step:796/1670 train_time:77584ms step_avg:97.47ms +step:797/1670 train_time:77681ms step_avg:97.47ms +step:798/1670 train_time:77779ms step_avg:97.47ms +step:799/1670 train_time:77875ms step_avg:97.47ms +step:800/1670 train_time:77971ms step_avg:97.46ms +step:801/1670 train_time:78068ms step_avg:97.46ms +step:802/1670 train_time:78165ms step_avg:97.46ms +step:803/1670 train_time:78264ms step_avg:97.46ms +step:804/1670 train_time:78362ms step_avg:97.46ms +step:805/1670 train_time:78459ms step_avg:97.46ms +step:806/1670 train_time:78555ms step_avg:97.46ms +step:807/1670 train_time:78653ms step_avg:97.46ms +step:808/1670 train_time:78751ms step_avg:97.46ms +step:809/1670 train_time:78848ms step_avg:97.46ms +step:810/1670 train_time:78945ms step_avg:97.46ms +step:811/1670 train_time:79042ms step_avg:97.46ms +step:812/1670 train_time:79140ms step_avg:97.46ms +step:813/1670 train_time:79237ms step_avg:97.46ms +step:814/1670 train_time:79334ms step_avg:97.46ms +step:815/1670 train_time:79431ms step_avg:97.46ms +step:816/1670 train_time:79529ms step_avg:97.46ms +step:817/1670 train_time:79627ms step_avg:97.46ms +step:818/1670 train_time:79725ms step_avg:97.46ms +step:819/1670 train_time:79823ms step_avg:97.46ms +step:820/1670 train_time:79920ms step_avg:97.46ms +step:821/1670 train_time:80017ms step_avg:97.46ms +step:822/1670 train_time:80113ms step_avg:97.46ms +step:823/1670 train_time:80210ms step_avg:97.46ms +step:824/1670 train_time:80308ms step_avg:97.46ms +step:825/1670 train_time:80407ms step_avg:97.46ms +step:826/1670 train_time:80505ms step_avg:97.46ms +step:827/1670 train_time:80602ms step_avg:97.46ms +step:828/1670 train_time:80698ms step_avg:97.46ms +step:829/1670 train_time:80795ms step_avg:97.46ms +step:830/1670 train_time:80891ms step_avg:97.46ms +step:831/1670 train_time:80989ms step_avg:97.46ms +step:832/1670 train_time:81086ms step_avg:97.46ms +step:833/1670 train_time:81184ms step_avg:97.46ms +step:834/1670 train_time:81281ms step_avg:97.46ms +step:835/1670 train_time:81378ms step_avg:97.46ms +step:836/1670 train_time:81475ms step_avg:97.46ms +step:837/1670 train_time:81572ms step_avg:97.46ms +step:838/1670 train_time:81670ms 
step_avg:97.46ms +step:839/1670 train_time:81767ms step_avg:97.46ms +step:840/1670 train_time:81865ms step_avg:97.46ms +step:841/1670 train_time:81963ms step_avg:97.46ms +step:842/1670 train_time:82060ms step_avg:97.46ms +step:843/1670 train_time:82157ms step_avg:97.46ms +step:844/1670 train_time:82254ms step_avg:97.46ms +step:845/1670 train_time:82351ms step_avg:97.46ms +step:846/1670 train_time:82449ms step_avg:97.46ms +step:847/1670 train_time:82546ms step_avg:97.46ms +step:848/1670 train_time:82644ms step_avg:97.46ms +step:849/1670 train_time:82742ms step_avg:97.46ms +step:850/1670 train_time:82839ms step_avg:97.46ms +step:851/1670 train_time:83087ms step_avg:97.63ms +step:852/1670 train_time:83289ms step_avg:97.76ms +step:853/1670 train_time:83384ms step_avg:97.75ms +step:854/1670 train_time:83480ms step_avg:97.75ms +step:855/1670 train_time:83576ms step_avg:97.75ms +step:856/1670 train_time:83672ms step_avg:97.75ms +step:857/1670 train_time:83768ms step_avg:97.75ms +step:858/1670 train_time:83865ms step_avg:97.74ms +step:859/1670 train_time:83961ms step_avg:97.74ms +step:860/1670 train_time:84060ms step_avg:97.74ms +step:861/1670 train_time:84162ms step_avg:97.75ms +step:862/1670 train_time:84262ms step_avg:97.75ms +step:863/1670 train_time:84360ms step_avg:97.75ms +step:864/1670 train_time:84458ms step_avg:97.75ms +step:865/1670 train_time:84554ms step_avg:97.75ms +step:866/1670 train_time:84650ms step_avg:97.75ms +step:867/1670 train_time:84746ms step_avg:97.75ms +step:868/1670 train_time:84843ms step_avg:97.75ms +step:869/1670 train_time:84939ms step_avg:97.74ms +step:870/1670 train_time:85035ms step_avg:97.74ms +step:871/1670 train_time:85132ms step_avg:97.74ms +step:872/1670 train_time:85231ms step_avg:97.74ms +step:873/1670 train_time:85330ms step_avg:97.74ms +step:874/1670 train_time:85428ms step_avg:97.74ms +step:875/1670 train_time:85527ms step_avg:97.74ms +step:875/1670 val_loss:3.5179 train_time:85623ms step_avg:97.86ms +step:876/1670 train_time:85645ms step_avg:97.77ms +step:877/1670 train_time:85729ms step_avg:97.75ms +step:878/1670 train_time:85827ms step_avg:97.75ms +step:879/1670 train_time:85926ms step_avg:97.75ms +step:880/1670 train_time:86022ms step_avg:97.75ms +step:881/1670 train_time:86118ms step_avg:97.75ms +step:882/1670 train_time:86214ms step_avg:97.75ms +step:883/1670 train_time:86310ms step_avg:97.75ms +step:884/1670 train_time:86406ms step_avg:97.74ms +step:885/1670 train_time:86504ms step_avg:97.74ms +step:886/1670 train_time:86604ms step_avg:97.75ms +step:887/1670 train_time:86704ms step_avg:97.75ms +step:888/1670 train_time:86803ms step_avg:97.75ms +step:889/1670 train_time:86901ms step_avg:97.75ms +step:890/1670 train_time:86998ms step_avg:97.75ms +step:891/1670 train_time:87095ms step_avg:97.75ms +step:892/1670 train_time:87191ms step_avg:97.75ms +step:893/1670 train_time:87288ms step_avg:97.75ms +step:894/1670 train_time:87384ms step_avg:97.74ms +step:895/1670 train_time:87481ms step_avg:97.74ms +step:896/1670 train_time:87578ms step_avg:97.74ms +step:897/1670 train_time:87676ms step_avg:97.74ms +step:898/1670 train_time:87774ms step_avg:97.74ms +step:899/1670 train_time:87871ms step_avg:97.74ms +step:900/1670 train_time:87969ms step_avg:97.74ms +step:901/1670 train_time:88066ms step_avg:97.74ms +step:902/1670 train_time:88164ms step_avg:97.74ms +step:903/1670 train_time:88261ms step_avg:97.74ms +step:904/1670 train_time:88358ms step_avg:97.74ms +step:905/1670 train_time:88455ms step_avg:97.74ms +step:906/1670 train_time:88552ms step_avg:97.74ms 
+step:907/1670 train_time:88649ms step_avg:97.74ms +step:908/1670 train_time:88747ms step_avg:97.74ms +step:909/1670 train_time:88845ms step_avg:97.74ms +step:910/1670 train_time:88944ms step_avg:97.74ms +step:911/1670 train_time:89042ms step_avg:97.74ms +step:912/1670 train_time:89139ms step_avg:97.74ms +step:913/1670 train_time:89236ms step_avg:97.74ms +step:914/1670 train_time:89332ms step_avg:97.74ms +step:915/1670 train_time:89428ms step_avg:97.74ms +step:916/1670 train_time:89525ms step_avg:97.73ms +step:917/1670 train_time:89623ms step_avg:97.73ms +step:918/1670 train_time:89721ms step_avg:97.74ms +step:919/1670 train_time:89819ms step_avg:97.74ms +step:920/1670 train_time:89918ms step_avg:97.74ms +step:921/1670 train_time:90015ms step_avg:97.74ms +step:922/1670 train_time:90111ms step_avg:97.73ms +step:923/1670 train_time:90209ms step_avg:97.73ms +step:924/1670 train_time:90306ms step_avg:97.73ms +step:925/1670 train_time:90404ms step_avg:97.73ms +step:926/1670 train_time:90500ms step_avg:97.73ms +step:927/1670 train_time:90597ms step_avg:97.73ms +step:928/1670 train_time:90695ms step_avg:97.73ms +step:929/1670 train_time:90793ms step_avg:97.73ms +step:930/1670 train_time:90891ms step_avg:97.73ms +step:931/1670 train_time:90988ms step_avg:97.73ms +step:932/1670 train_time:91085ms step_avg:97.73ms +step:933/1670 train_time:91183ms step_avg:97.73ms +step:934/1670 train_time:91280ms step_avg:97.73ms +step:935/1670 train_time:91377ms step_avg:97.73ms +step:936/1670 train_time:91474ms step_avg:97.73ms +step:937/1670 train_time:91571ms step_avg:97.73ms +step:938/1670 train_time:91667ms step_avg:97.73ms +step:939/1670 train_time:91764ms step_avg:97.73ms +step:940/1670 train_time:91862ms step_avg:97.73ms +step:941/1670 train_time:91960ms step_avg:97.73ms +step:942/1670 train_time:92057ms step_avg:97.73ms +step:943/1670 train_time:92155ms step_avg:97.73ms +step:944/1670 train_time:92251ms step_avg:97.72ms +step:945/1670 train_time:92349ms step_avg:97.72ms +step:946/1670 train_time:92446ms step_avg:97.72ms +step:947/1670 train_time:92543ms step_avg:97.72ms +step:948/1670 train_time:92642ms step_avg:97.72ms +step:949/1670 train_time:92739ms step_avg:97.72ms +step:950/1670 train_time:92836ms step_avg:97.72ms +step:951/1670 train_time:92933ms step_avg:97.72ms +step:952/1670 train_time:93030ms step_avg:97.72ms +step:953/1670 train_time:93127ms step_avg:97.72ms +step:954/1670 train_time:93225ms step_avg:97.72ms +step:955/1670 train_time:93322ms step_avg:97.72ms +step:956/1670 train_time:93419ms step_avg:97.72ms +step:957/1670 train_time:93517ms step_avg:97.72ms +step:958/1670 train_time:93614ms step_avg:97.72ms +step:959/1670 train_time:93710ms step_avg:97.72ms +step:960/1670 train_time:93807ms step_avg:97.72ms +step:961/1670 train_time:93906ms step_avg:97.72ms +step:962/1670 train_time:94004ms step_avg:97.72ms +step:963/1670 train_time:94102ms step_avg:97.72ms +step:964/1670 train_time:94201ms step_avg:97.72ms +step:965/1670 train_time:94298ms step_avg:97.72ms +step:966/1670 train_time:94395ms step_avg:97.72ms +step:967/1670 train_time:94492ms step_avg:97.72ms +step:968/1670 train_time:94588ms step_avg:97.72ms +step:969/1670 train_time:94685ms step_avg:97.71ms +step:970/1670 train_time:94783ms step_avg:97.71ms +step:971/1670 train_time:94881ms step_avg:97.71ms +step:972/1670 train_time:94979ms step_avg:97.71ms +step:973/1670 train_time:95076ms step_avg:97.71ms +step:974/1670 train_time:95173ms step_avg:97.71ms +step:975/1670 train_time:95269ms step_avg:97.71ms +step:976/1670 train_time:95366ms 
step_avg:97.71ms +step:977/1670 train_time:95465ms step_avg:97.71ms +step:978/1670 train_time:95563ms step_avg:97.71ms +step:979/1670 train_time:95660ms step_avg:97.71ms +step:980/1670 train_time:95758ms step_avg:97.71ms +step:981/1670 train_time:95855ms step_avg:97.71ms +step:982/1670 train_time:95952ms step_avg:97.71ms +step:983/1670 train_time:96049ms step_avg:97.71ms +step:984/1670 train_time:96148ms step_avg:97.71ms +step:985/1670 train_time:96245ms step_avg:97.71ms +step:986/1670 train_time:96342ms step_avg:97.71ms +step:987/1670 train_time:96440ms step_avg:97.71ms +step:988/1670 train_time:96537ms step_avg:97.71ms +step:989/1670 train_time:96634ms step_avg:97.71ms +step:990/1670 train_time:96731ms step_avg:97.71ms +step:991/1670 train_time:96827ms step_avg:97.71ms +step:992/1670 train_time:96925ms step_avg:97.71ms +step:993/1670 train_time:97022ms step_avg:97.71ms +step:994/1670 train_time:97121ms step_avg:97.71ms +step:995/1670 train_time:97219ms step_avg:97.71ms +step:996/1670 train_time:97316ms step_avg:97.71ms +step:997/1670 train_time:97413ms step_avg:97.71ms +step:998/1670 train_time:97510ms step_avg:97.71ms +step:999/1670 train_time:97607ms step_avg:97.70ms +step:1000/1670 train_time:97704ms step_avg:97.70ms +step:1000/1670 val_loss:3.4737 train_time:97801ms step_avg:97.80ms +step:1001/1670 train_time:97823ms step_avg:97.73ms +step:1002/1670 train_time:97906ms step_avg:97.71ms +step:1003/1670 train_time:98008ms step_avg:97.71ms +step:1004/1670 train_time:98104ms step_avg:97.71ms +step:1005/1670 train_time:98201ms step_avg:97.71ms +step:1006/1670 train_time:98297ms step_avg:97.71ms +step:1007/1670 train_time:98393ms step_avg:97.71ms +step:1008/1670 train_time:98491ms step_avg:97.71ms +step:1009/1670 train_time:98585ms step_avg:97.71ms +step:1010/1670 train_time:98682ms step_avg:97.70ms +step:1011/1670 train_time:98780ms step_avg:97.71ms +step:1012/1670 train_time:98879ms step_avg:97.71ms +step:1013/1670 train_time:98980ms step_avg:97.71ms +step:1014/1670 train_time:99078ms step_avg:97.71ms +step:1015/1670 train_time:99176ms step_avg:97.71ms +step:1016/1670 train_time:99272ms step_avg:97.71ms +step:1017/1670 train_time:99369ms step_avg:97.71ms +step:1018/1670 train_time:99465ms step_avg:97.71ms +step:1019/1670 train_time:99561ms step_avg:97.70ms +step:1020/1670 train_time:99658ms step_avg:97.70ms +step:1021/1670 train_time:99755ms step_avg:97.70ms +step:1022/1670 train_time:99854ms step_avg:97.70ms +step:1023/1670 train_time:99953ms step_avg:97.71ms +step:1024/1670 train_time:100051ms step_avg:97.71ms +step:1025/1670 train_time:100148ms step_avg:97.71ms +step:1026/1670 train_time:100245ms step_avg:97.70ms +step:1027/1670 train_time:100341ms step_avg:97.70ms +step:1028/1670 train_time:100438ms step_avg:97.70ms +step:1029/1670 train_time:100535ms step_avg:97.70ms +step:1030/1670 train_time:100632ms step_avg:97.70ms +step:1031/1670 train_time:100728ms step_avg:97.70ms +step:1032/1670 train_time:100825ms step_avg:97.70ms +step:1033/1670 train_time:100923ms step_avg:97.70ms +step:1034/1670 train_time:101022ms step_avg:97.70ms +step:1035/1670 train_time:101121ms step_avg:97.70ms +step:1036/1670 train_time:101218ms step_avg:97.70ms +step:1037/1670 train_time:101316ms step_avg:97.70ms +step:1038/1670 train_time:101413ms step_avg:97.70ms +step:1039/1670 train_time:101510ms step_avg:97.70ms +step:1040/1670 train_time:101606ms step_avg:97.70ms +step:1041/1670 train_time:101703ms step_avg:97.70ms +step:1042/1670 train_time:101800ms step_avg:97.70ms +step:1043/1670 train_time:101899ms 
step_avg:97.70ms +step:1044/1670 train_time:101998ms step_avg:97.70ms +step:1045/1670 train_time:102096ms step_avg:97.70ms +step:1046/1670 train_time:102193ms step_avg:97.70ms +step:1047/1670 train_time:102290ms step_avg:97.70ms +step:1048/1670 train_time:102387ms step_avg:97.70ms +step:1049/1670 train_time:102483ms step_avg:97.70ms +step:1050/1670 train_time:102580ms step_avg:97.70ms +step:1051/1670 train_time:102677ms step_avg:97.69ms +step:1052/1670 train_time:102775ms step_avg:97.69ms +step:1053/1670 train_time:102873ms step_avg:97.70ms +step:1054/1670 train_time:102971ms step_avg:97.70ms +step:1055/1670 train_time:103068ms step_avg:97.70ms +step:1056/1670 train_time:103165ms step_avg:97.69ms +step:1057/1670 train_time:103263ms step_avg:97.69ms +step:1058/1670 train_time:103361ms step_avg:97.69ms +step:1059/1670 train_time:103458ms step_avg:97.69ms +step:1060/1670 train_time:103556ms step_avg:97.69ms +step:1061/1670 train_time:103653ms step_avg:97.69ms +step:1062/1670 train_time:103921ms step_avg:97.85ms +step:1063/1670 train_time:104001ms step_avg:97.84ms +step:1064/1670 train_time:104098ms step_avg:97.84ms +step:1065/1670 train_time:104195ms step_avg:97.84ms +step:1066/1670 train_time:104290ms step_avg:97.83ms +step:1067/1670 train_time:104386ms step_avg:97.83ms +step:1068/1670 train_time:104482ms step_avg:97.83ms +step:1069/1670 train_time:104578ms step_avg:97.83ms +step:1070/1670 train_time:104674ms step_avg:97.83ms +step:1071/1670 train_time:104771ms step_avg:97.83ms +step:1072/1670 train_time:104874ms step_avg:97.83ms +step:1073/1670 train_time:104974ms step_avg:97.83ms +step:1074/1670 train_time:105072ms step_avg:97.83ms +step:1075/1670 train_time:105169ms step_avg:97.83ms +step:1076/1670 train_time:105265ms step_avg:97.83ms +step:1077/1670 train_time:105362ms step_avg:97.83ms +step:1078/1670 train_time:105457ms step_avg:97.83ms +step:1079/1670 train_time:105554ms step_avg:97.83ms +step:1080/1670 train_time:105650ms step_avg:97.82ms +step:1081/1670 train_time:105746ms step_avg:97.82ms +step:1082/1670 train_time:105845ms step_avg:97.82ms +step:1083/1670 train_time:105944ms step_avg:97.82ms +step:1084/1670 train_time:106043ms step_avg:97.83ms +step:1085/1670 train_time:106141ms step_avg:97.83ms +step:1086/1670 train_time:106239ms step_avg:97.83ms +step:1087/1670 train_time:106337ms step_avg:97.83ms +step:1088/1670 train_time:106434ms step_avg:97.83ms +step:1089/1670 train_time:106531ms step_avg:97.82ms +step:1090/1670 train_time:106627ms step_avg:97.82ms +step:1091/1670 train_time:106723ms step_avg:97.82ms +step:1092/1670 train_time:106821ms step_avg:97.82ms +step:1093/1670 train_time:106919ms step_avg:97.82ms +step:1094/1670 train_time:107019ms step_avg:97.82ms +step:1095/1670 train_time:107118ms step_avg:97.82ms +step:1096/1670 train_time:107216ms step_avg:97.82ms +step:1097/1670 train_time:107314ms step_avg:97.82ms +step:1098/1670 train_time:107410ms step_avg:97.82ms +step:1099/1670 train_time:107506ms step_avg:97.82ms +step:1100/1670 train_time:107602ms step_avg:97.82ms +step:1101/1670 train_time:107700ms step_avg:97.82ms +step:1102/1670 train_time:107797ms step_avg:97.82ms +step:1103/1670 train_time:107895ms step_avg:97.82ms +step:1104/1670 train_time:107993ms step_avg:97.82ms +step:1105/1670 train_time:108090ms step_avg:97.82ms +step:1106/1670 train_time:108187ms step_avg:97.82ms +step:1107/1670 train_time:108285ms step_avg:97.82ms +step:1108/1670 train_time:108382ms step_avg:97.82ms +step:1109/1670 train_time:108480ms step_avg:97.82ms +step:1110/1670 train_time:108578ms 
step_avg:97.82ms +step:1111/1670 train_time:108675ms step_avg:97.82ms +step:1112/1670 train_time:108772ms step_avg:97.82ms +step:1113/1670 train_time:108868ms step_avg:97.82ms +step:1114/1670 train_time:108965ms step_avg:97.81ms +step:1115/1670 train_time:109063ms step_avg:97.81ms +step:1116/1670 train_time:109163ms step_avg:97.82ms +step:1117/1670 train_time:109262ms step_avg:97.82ms +step:1118/1670 train_time:109359ms step_avg:97.82ms +step:1119/1670 train_time:109457ms step_avg:97.82ms +step:1120/1670 train_time:109555ms step_avg:97.82ms +step:1121/1670 train_time:109653ms step_avg:97.82ms +step:1122/1670 train_time:109751ms step_avg:97.82ms +step:1123/1670 train_time:109848ms step_avg:97.82ms +step:1124/1670 train_time:109945ms step_avg:97.82ms +step:1125/1670 train_time:110043ms step_avg:97.82ms +step:1125/1670 val_loss:3.4212 train_time:110141ms step_avg:97.90ms +step:1126/1670 train_time:110163ms step_avg:97.84ms +step:1127/1670 train_time:110245ms step_avg:97.82ms +step:1128/1670 train_time:110342ms step_avg:97.82ms +step:1129/1670 train_time:110440ms step_avg:97.82ms +step:1130/1670 train_time:110537ms step_avg:97.82ms +step:1131/1670 train_time:110634ms step_avg:97.82ms +step:1132/1670 train_time:110730ms step_avg:97.82ms +step:1133/1670 train_time:110827ms step_avg:97.82ms +step:1134/1670 train_time:110924ms step_avg:97.82ms +step:1135/1670 train_time:111022ms step_avg:97.82ms +step:1136/1670 train_time:111125ms step_avg:97.82ms +step:1137/1670 train_time:111225ms step_avg:97.82ms +step:1138/1670 train_time:111323ms step_avg:97.82ms +step:1139/1670 train_time:111422ms step_avg:97.82ms +step:1140/1670 train_time:111519ms step_avg:97.82ms +step:1141/1670 train_time:111617ms step_avg:97.82ms +step:1142/1670 train_time:111713ms step_avg:97.82ms +step:1143/1670 train_time:111810ms step_avg:97.82ms +step:1144/1670 train_time:111907ms step_avg:97.82ms +step:1145/1670 train_time:112004ms step_avg:97.82ms +step:1146/1670 train_time:112102ms step_avg:97.82ms +step:1147/1670 train_time:112202ms step_avg:97.82ms +step:1148/1670 train_time:112301ms step_avg:97.82ms +step:1149/1670 train_time:112401ms step_avg:97.83ms +step:1150/1670 train_time:112499ms step_avg:97.82ms +step:1151/1670 train_time:112597ms step_avg:97.82ms +step:1152/1670 train_time:112694ms step_avg:97.82ms +step:1153/1670 train_time:112792ms step_avg:97.82ms +step:1154/1670 train_time:112889ms step_avg:97.82ms +step:1155/1670 train_time:112986ms step_avg:97.82ms +step:1156/1670 train_time:113085ms step_avg:97.82ms +step:1157/1670 train_time:113183ms step_avg:97.82ms +step:1158/1670 train_time:113282ms step_avg:97.83ms +step:1159/1670 train_time:113381ms step_avg:97.83ms +step:1160/1670 train_time:113479ms step_avg:97.83ms +step:1161/1670 train_time:113577ms step_avg:97.83ms +step:1162/1670 train_time:113675ms step_avg:97.83ms +step:1163/1670 train_time:113773ms step_avg:97.83ms +step:1164/1670 train_time:113871ms step_avg:97.83ms +step:1165/1670 train_time:113968ms step_avg:97.83ms +step:1166/1670 train_time:114066ms step_avg:97.83ms +step:1167/1670 train_time:114163ms step_avg:97.83ms +step:1168/1670 train_time:114261ms step_avg:97.83ms +step:1169/1670 train_time:114360ms step_avg:97.83ms +step:1170/1670 train_time:114459ms step_avg:97.83ms +step:1171/1670 train_time:114557ms step_avg:97.83ms +step:1172/1670 train_time:114655ms step_avg:97.83ms +step:1173/1670 train_time:114753ms step_avg:97.83ms +step:1174/1670 train_time:114850ms step_avg:97.83ms +step:1175/1670 train_time:114947ms step_avg:97.83ms +step:1176/1670 
train_time:115045ms step_avg:97.83ms +step:1177/1670 train_time:115143ms step_avg:97.83ms +step:1178/1670 train_time:115242ms step_avg:97.83ms +step:1179/1670 train_time:115341ms step_avg:97.83ms +step:1180/1670 train_time:115439ms step_avg:97.83ms +step:1181/1670 train_time:115537ms step_avg:97.83ms +step:1182/1670 train_time:115635ms step_avg:97.83ms +step:1183/1670 train_time:115734ms step_avg:97.83ms +step:1184/1670 train_time:115832ms step_avg:97.83ms +step:1185/1670 train_time:115929ms step_avg:97.83ms +step:1186/1670 train_time:116027ms step_avg:97.83ms +step:1187/1670 train_time:116125ms step_avg:97.83ms +step:1188/1670 train_time:116223ms step_avg:97.83ms +step:1189/1670 train_time:116321ms step_avg:97.83ms +step:1190/1670 train_time:116419ms step_avg:97.83ms +step:1191/1670 train_time:116517ms step_avg:97.83ms +step:1192/1670 train_time:116615ms step_avg:97.83ms +step:1193/1670 train_time:116713ms step_avg:97.83ms +step:1194/1670 train_time:116811ms step_avg:97.83ms +step:1195/1670 train_time:116909ms step_avg:97.83ms +step:1196/1670 train_time:117006ms step_avg:97.83ms +step:1197/1670 train_time:117104ms step_avg:97.83ms +step:1198/1670 train_time:117202ms step_avg:97.83ms +step:1199/1670 train_time:117299ms step_avg:97.83ms +step:1200/1670 train_time:117397ms step_avg:97.83ms +step:1201/1670 train_time:117494ms step_avg:97.83ms +step:1202/1670 train_time:117593ms step_avg:97.83ms +step:1203/1670 train_time:117691ms step_avg:97.83ms +step:1204/1670 train_time:117789ms step_avg:97.83ms +step:1205/1670 train_time:117887ms step_avg:97.83ms +step:1206/1670 train_time:117984ms step_avg:97.83ms +step:1207/1670 train_time:118082ms step_avg:97.83ms +step:1208/1670 train_time:118180ms step_avg:97.83ms +step:1209/1670 train_time:118279ms step_avg:97.83ms +step:1210/1670 train_time:118377ms step_avg:97.83ms +step:1211/1670 train_time:118474ms step_avg:97.83ms +step:1212/1670 train_time:118571ms step_avg:97.83ms +step:1213/1670 train_time:118670ms step_avg:97.83ms +step:1214/1670 train_time:118767ms step_avg:97.83ms +step:1215/1670 train_time:118865ms step_avg:97.83ms +step:1216/1670 train_time:118963ms step_avg:97.83ms +step:1217/1670 train_time:119061ms step_avg:97.83ms +step:1218/1670 train_time:119160ms step_avg:97.83ms +step:1219/1670 train_time:119259ms step_avg:97.83ms +step:1220/1670 train_time:119357ms step_avg:97.83ms +step:1221/1670 train_time:119454ms step_avg:97.83ms +step:1222/1670 train_time:119551ms step_avg:97.83ms +step:1223/1670 train_time:119649ms step_avg:97.83ms +step:1224/1670 train_time:119747ms step_avg:97.83ms +step:1225/1670 train_time:119846ms step_avg:97.83ms +step:1226/1670 train_time:119944ms step_avg:97.83ms +step:1227/1670 train_time:120042ms step_avg:97.83ms +step:1228/1670 train_time:120140ms step_avg:97.83ms +step:1229/1670 train_time:120238ms step_avg:97.83ms +step:1230/1670 train_time:120337ms step_avg:97.84ms +step:1231/1670 train_time:120435ms step_avg:97.84ms +step:1232/1670 train_time:120534ms step_avg:97.84ms +step:1233/1670 train_time:120632ms step_avg:97.84ms +step:1234/1670 train_time:120730ms step_avg:97.84ms +step:1235/1670 train_time:120830ms step_avg:97.84ms +step:1236/1670 train_time:120927ms step_avg:97.84ms +step:1237/1670 train_time:121025ms step_avg:97.84ms +step:1238/1670 train_time:121122ms step_avg:97.84ms +step:1239/1670 train_time:121221ms step_avg:97.84ms +step:1240/1670 train_time:121318ms step_avg:97.84ms +step:1241/1670 train_time:121416ms step_avg:97.84ms +step:1242/1670 train_time:121514ms step_avg:97.84ms +step:1243/1670 
train_time:121612ms step_avg:97.84ms +step:1244/1670 train_time:121711ms step_avg:97.84ms +step:1245/1670 train_time:121809ms step_avg:97.84ms +step:1246/1670 train_time:121907ms step_avg:97.84ms +step:1247/1670 train_time:122004ms step_avg:97.84ms +step:1248/1670 train_time:122101ms step_avg:97.84ms +step:1249/1670 train_time:122199ms step_avg:97.84ms +step:1250/1670 train_time:122297ms step_avg:97.84ms +step:1250/1670 val_loss:3.3783 train_time:122395ms step_avg:97.92ms +step:1251/1670 train_time:122416ms step_avg:97.85ms +step:1252/1670 train_time:122502ms step_avg:97.85ms +step:1253/1670 train_time:122604ms step_avg:97.85ms +step:1254/1670 train_time:122703ms step_avg:97.85ms +step:1255/1670 train_time:122801ms step_avg:97.85ms +step:1256/1670 train_time:122898ms step_avg:97.85ms +step:1257/1670 train_time:122996ms step_avg:97.85ms +step:1258/1670 train_time:123092ms step_avg:97.85ms +step:1259/1670 train_time:123189ms step_avg:97.85ms +step:1260/1670 train_time:123285ms step_avg:97.85ms +step:1261/1670 train_time:123384ms step_avg:97.85ms +step:1262/1670 train_time:123483ms step_avg:97.85ms +step:1263/1670 train_time:123587ms step_avg:97.85ms +step:1264/1670 train_time:123686ms step_avg:97.85ms +step:1265/1670 train_time:123784ms step_avg:97.85ms +step:1266/1670 train_time:123881ms step_avg:97.85ms +step:1267/1670 train_time:123979ms step_avg:97.85ms +step:1268/1670 train_time:124077ms step_avg:97.85ms +step:1269/1670 train_time:124174ms step_avg:97.85ms +step:1270/1670 train_time:124271ms step_avg:97.85ms +step:1271/1670 train_time:124369ms step_avg:97.85ms +step:1272/1670 train_time:124468ms step_avg:97.85ms +step:1273/1670 train_time:124567ms step_avg:97.85ms +step:1274/1670 train_time:124951ms step_avg:98.08ms +step:1275/1670 train_time:125025ms step_avg:98.06ms +step:1276/1670 train_time:125122ms step_avg:98.06ms +step:1277/1670 train_time:125218ms step_avg:98.06ms +step:1278/1670 train_time:125315ms step_avg:98.06ms +step:1279/1670 train_time:125412ms step_avg:98.05ms +step:1280/1670 train_time:125509ms step_avg:98.05ms +step:1281/1670 train_time:125605ms step_avg:98.05ms +step:1282/1670 train_time:125702ms step_avg:98.05ms +step:1283/1670 train_time:125799ms step_avg:98.05ms +step:1284/1670 train_time:125900ms step_avg:98.05ms +step:1285/1670 train_time:126005ms step_avg:98.06ms +step:1286/1670 train_time:126104ms step_avg:98.06ms +step:1287/1670 train_time:126201ms step_avg:98.06ms +step:1288/1670 train_time:126300ms step_avg:98.06ms +step:1289/1670 train_time:126398ms step_avg:98.06ms +step:1290/1670 train_time:126495ms step_avg:98.06ms +step:1291/1670 train_time:126593ms step_avg:98.06ms +step:1292/1670 train_time:126690ms step_avg:98.06ms +step:1293/1670 train_time:126787ms step_avg:98.06ms +step:1294/1670 train_time:126886ms step_avg:98.06ms +step:1295/1670 train_time:126986ms step_avg:98.06ms +step:1296/1670 train_time:127086ms step_avg:98.06ms +step:1297/1670 train_time:127183ms step_avg:98.06ms +step:1298/1670 train_time:127281ms step_avg:98.06ms +step:1299/1670 train_time:127379ms step_avg:98.06ms +step:1300/1670 train_time:127477ms step_avg:98.06ms +step:1301/1670 train_time:127575ms step_avg:98.06ms +step:1302/1670 train_time:127672ms step_avg:98.06ms +step:1303/1670 train_time:127769ms step_avg:98.06ms +step:1304/1670 train_time:127867ms step_avg:98.06ms +step:1305/1670 train_time:127966ms step_avg:98.06ms +step:1306/1670 train_time:128065ms step_avg:98.06ms +step:1307/1670 train_time:128162ms step_avg:98.06ms +step:1308/1670 train_time:128261ms step_avg:98.06ms 
+step:1309/1670 train_time:128359ms step_avg:98.06ms +step:1310/1670 train_time:128458ms step_avg:98.06ms +step:1311/1670 train_time:128557ms step_avg:98.06ms +step:1312/1670 train_time:128654ms step_avg:98.06ms +step:1313/1670 train_time:128754ms step_avg:98.06ms +step:1314/1670 train_time:128854ms step_avg:98.06ms +step:1315/1670 train_time:128956ms step_avg:98.07ms +step:1316/1670 train_time:129055ms step_avg:98.07ms +step:1317/1670 train_time:129154ms step_avg:98.07ms +step:1318/1670 train_time:129253ms step_avg:98.07ms +step:1319/1670 train_time:129351ms step_avg:98.07ms +step:1320/1670 train_time:129447ms step_avg:98.07ms +step:1321/1670 train_time:129544ms step_avg:98.06ms +step:1322/1670 train_time:129641ms step_avg:98.06ms +step:1323/1670 train_time:129740ms step_avg:98.06ms +step:1324/1670 train_time:129839ms step_avg:98.07ms +step:1325/1670 train_time:129939ms step_avg:98.07ms +step:1326/1670 train_time:130039ms step_avg:98.07ms +step:1327/1670 train_time:130138ms step_avg:98.07ms +step:1328/1670 train_time:130237ms step_avg:98.07ms +step:1329/1670 train_time:130336ms step_avg:98.07ms +step:1330/1670 train_time:130435ms step_avg:98.07ms +step:1331/1670 train_time:130533ms step_avg:98.07ms +step:1332/1670 train_time:130632ms step_avg:98.07ms +step:1333/1670 train_time:130728ms step_avg:98.07ms +step:1334/1670 train_time:130825ms step_avg:98.07ms +step:1335/1670 train_time:130924ms step_avg:98.07ms +step:1336/1670 train_time:131022ms step_avg:98.07ms +step:1337/1670 train_time:131121ms step_avg:98.07ms +step:1338/1670 train_time:131220ms step_avg:98.07ms +step:1339/1670 train_time:131319ms step_avg:98.07ms +step:1340/1670 train_time:131416ms step_avg:98.07ms +step:1341/1670 train_time:131515ms step_avg:98.07ms +step:1342/1670 train_time:131612ms step_avg:98.07ms +step:1343/1670 train_time:131711ms step_avg:98.07ms +step:1344/1670 train_time:131808ms step_avg:98.07ms +step:1345/1670 train_time:131905ms step_avg:98.07ms +step:1346/1670 train_time:132003ms step_avg:98.07ms +step:1347/1670 train_time:132102ms step_avg:98.07ms +step:1348/1670 train_time:132201ms step_avg:98.07ms +step:1349/1670 train_time:132300ms step_avg:98.07ms +step:1350/1670 train_time:132399ms step_avg:98.07ms +step:1351/1670 train_time:132497ms step_avg:98.07ms +step:1352/1670 train_time:132595ms step_avg:98.07ms +step:1353/1670 train_time:132694ms step_avg:98.07ms +step:1354/1670 train_time:132793ms step_avg:98.07ms +step:1355/1670 train_time:132892ms step_avg:98.08ms +step:1356/1670 train_time:132991ms step_avg:98.08ms +step:1357/1670 train_time:133089ms step_avg:98.08ms +step:1358/1670 train_time:133187ms step_avg:98.08ms +step:1359/1670 train_time:133285ms step_avg:98.08ms +step:1360/1670 train_time:133383ms step_avg:98.08ms +step:1361/1670 train_time:133481ms step_avg:98.08ms +step:1362/1670 train_time:133579ms step_avg:98.08ms +step:1363/1670 train_time:133678ms step_avg:98.08ms +step:1364/1670 train_time:133777ms step_avg:98.08ms +step:1365/1670 train_time:133876ms step_avg:98.08ms +step:1366/1670 train_time:133974ms step_avg:98.08ms +step:1367/1670 train_time:134073ms step_avg:98.08ms +step:1368/1670 train_time:134174ms step_avg:98.08ms +step:1369/1670 train_time:134273ms step_avg:98.08ms +step:1370/1670 train_time:134371ms step_avg:98.08ms +step:1371/1670 train_time:134469ms step_avg:98.08ms +step:1372/1670 train_time:134566ms step_avg:98.08ms +step:1373/1670 train_time:134664ms step_avg:98.08ms +step:1374/1670 train_time:134762ms step_avg:98.08ms +step:1375/1670 train_time:134860ms step_avg:98.08ms 
+step:1375/1670 val_loss:3.3416 train_time:134958ms step_avg:98.15ms +step:1376/1670 train_time:134979ms step_avg:98.10ms +step:1377/1670 train_time:135063ms step_avg:98.09ms +step:1378/1670 train_time:135166ms step_avg:98.09ms +step:1379/1670 train_time:135267ms step_avg:98.09ms +step:1380/1670 train_time:135365ms step_avg:98.09ms +step:1381/1670 train_time:135462ms step_avg:98.09ms +step:1382/1670 train_time:135559ms step_avg:98.09ms +step:1383/1670 train_time:135657ms step_avg:98.09ms +step:1384/1670 train_time:135755ms step_avg:98.09ms +step:1385/1670 train_time:135851ms step_avg:98.09ms +step:1386/1670 train_time:135949ms step_avg:98.09ms +step:1387/1670 train_time:136050ms step_avg:98.09ms +step:1388/1670 train_time:136149ms step_avg:98.09ms +step:1389/1670 train_time:136248ms step_avg:98.09ms +step:1390/1670 train_time:136347ms step_avg:98.09ms +step:1391/1670 train_time:136445ms step_avg:98.09ms +step:1392/1670 train_time:136543ms step_avg:98.09ms +step:1393/1670 train_time:136641ms step_avg:98.09ms +step:1394/1670 train_time:136739ms step_avg:98.09ms +step:1395/1670 train_time:136837ms step_avg:98.09ms +step:1396/1670 train_time:136935ms step_avg:98.09ms +step:1397/1670 train_time:137034ms step_avg:98.09ms +step:1398/1670 train_time:137133ms step_avg:98.09ms +step:1399/1670 train_time:137231ms step_avg:98.09ms +step:1400/1670 train_time:137329ms step_avg:98.09ms +step:1401/1670 train_time:137426ms step_avg:98.09ms +step:1402/1670 train_time:137523ms step_avg:98.09ms +step:1403/1670 train_time:137621ms step_avg:98.09ms +step:1404/1670 train_time:137719ms step_avg:98.09ms +step:1405/1670 train_time:137817ms step_avg:98.09ms +step:1406/1670 train_time:137915ms step_avg:98.09ms +step:1407/1670 train_time:138013ms step_avg:98.09ms +step:1408/1670 train_time:138111ms step_avg:98.09ms +step:1409/1670 train_time:138209ms step_avg:98.09ms +step:1410/1670 train_time:138307ms step_avg:98.09ms +step:1411/1670 train_time:138405ms step_avg:98.09ms +step:1412/1670 train_time:138502ms step_avg:98.09ms +step:1413/1670 train_time:138600ms step_avg:98.09ms +step:1414/1670 train_time:138697ms step_avg:98.09ms +step:1415/1670 train_time:138795ms step_avg:98.09ms +step:1416/1670 train_time:138893ms step_avg:98.09ms +step:1417/1670 train_time:138991ms step_avg:98.09ms +step:1418/1670 train_time:139089ms step_avg:98.09ms +step:1419/1670 train_time:139188ms step_avg:98.09ms +step:1420/1670 train_time:139286ms step_avg:98.09ms +step:1421/1670 train_time:139383ms step_avg:98.09ms +step:1422/1670 train_time:139481ms step_avg:98.09ms +step:1423/1670 train_time:139579ms step_avg:98.09ms +step:1424/1670 train_time:139676ms step_avg:98.09ms +step:1425/1670 train_time:139774ms step_avg:98.09ms +step:1426/1670 train_time:139872ms step_avg:98.09ms +step:1427/1670 train_time:139969ms step_avg:98.09ms +step:1428/1670 train_time:140067ms step_avg:98.09ms +step:1429/1670 train_time:140167ms step_avg:98.09ms +step:1430/1670 train_time:140266ms step_avg:98.09ms +step:1431/1670 train_time:140364ms step_avg:98.09ms +step:1432/1670 train_time:140462ms step_avg:98.09ms +step:1433/1670 train_time:140561ms step_avg:98.09ms +step:1434/1670 train_time:140658ms step_avg:98.09ms +step:1435/1670 train_time:140756ms step_avg:98.09ms +step:1436/1670 train_time:140855ms step_avg:98.09ms +step:1437/1670 train_time:140954ms step_avg:98.09ms +step:1438/1670 train_time:141051ms step_avg:98.09ms +step:1439/1670 train_time:141148ms step_avg:98.09ms +step:1440/1670 train_time:141247ms step_avg:98.09ms +step:1441/1670 train_time:141345ms 
step_avg:98.09ms +step:1442/1670 train_time:141444ms step_avg:98.09ms +step:1443/1670 train_time:141542ms step_avg:98.09ms +step:1444/1670 train_time:141640ms step_avg:98.09ms +step:1445/1670 train_time:141739ms step_avg:98.09ms +step:1446/1670 train_time:141838ms step_avg:98.09ms +step:1447/1670 train_time:141937ms step_avg:98.09ms +step:1448/1670 train_time:142036ms step_avg:98.09ms +step:1449/1670 train_time:142134ms step_avg:98.09ms +step:1450/1670 train_time:142231ms step_avg:98.09ms +step:1451/1670 train_time:142328ms step_avg:98.09ms +step:1452/1670 train_time:142425ms step_avg:98.09ms +step:1453/1670 train_time:142523ms step_avg:98.09ms +step:1454/1670 train_time:142621ms step_avg:98.09ms +step:1455/1670 train_time:142718ms step_avg:98.09ms +step:1456/1670 train_time:142817ms step_avg:98.09ms +step:1457/1670 train_time:142915ms step_avg:98.09ms +step:1458/1670 train_time:143012ms step_avg:98.09ms +step:1459/1670 train_time:143110ms step_avg:98.09ms +step:1460/1670 train_time:143208ms step_avg:98.09ms +step:1461/1670 train_time:143306ms step_avg:98.09ms +step:1462/1670 train_time:143404ms step_avg:98.09ms +step:1463/1670 train_time:143502ms step_avg:98.09ms +step:1464/1670 train_time:143600ms step_avg:98.09ms +step:1465/1670 train_time:143697ms step_avg:98.09ms +step:1466/1670 train_time:143795ms step_avg:98.09ms +step:1467/1670 train_time:143893ms step_avg:98.09ms +step:1468/1670 train_time:143991ms step_avg:98.09ms +step:1469/1670 train_time:144090ms step_avg:98.09ms +step:1470/1670 train_time:144187ms step_avg:98.09ms +step:1471/1670 train_time:144285ms step_avg:98.09ms +step:1472/1670 train_time:144384ms step_avg:98.09ms +step:1473/1670 train_time:144483ms step_avg:98.09ms +step:1474/1670 train_time:144581ms step_avg:98.09ms +step:1475/1670 train_time:144679ms step_avg:98.09ms +step:1476/1670 train_time:144777ms step_avg:98.09ms +step:1477/1670 train_time:144876ms step_avg:98.09ms +step:1478/1670 train_time:144974ms step_avg:98.09ms +step:1479/1670 train_time:145072ms step_avg:98.09ms +step:1480/1670 train_time:145169ms step_avg:98.09ms +step:1481/1670 train_time:145268ms step_avg:98.09ms +step:1482/1670 train_time:145365ms step_avg:98.09ms +step:1483/1670 train_time:145464ms step_avg:98.09ms +step:1484/1670 train_time:145561ms step_avg:98.09ms +step:1485/1670 train_time:145838ms step_avg:98.21ms +step:1486/1670 train_time:146020ms step_avg:98.26ms +step:1487/1670 train_time:146116ms step_avg:98.26ms +step:1488/1670 train_time:146213ms step_avg:98.26ms +step:1489/1670 train_time:146309ms step_avg:98.26ms +step:1490/1670 train_time:146407ms step_avg:98.26ms +step:1491/1670 train_time:146504ms step_avg:98.26ms +step:1492/1670 train_time:146601ms step_avg:98.26ms +step:1493/1670 train_time:146698ms step_avg:98.26ms +step:1494/1670 train_time:146796ms step_avg:98.26ms +step:1495/1670 train_time:146894ms step_avg:98.26ms +step:1496/1670 train_time:146997ms step_avg:98.26ms +step:1497/1670 train_time:147098ms step_avg:98.26ms +step:1498/1670 train_time:147197ms step_avg:98.26ms +step:1499/1670 train_time:147294ms step_avg:98.26ms +step:1500/1670 train_time:147391ms step_avg:98.26ms +step:1500/1670 val_loss:3.3089 train_time:147487ms step_avg:98.32ms +step:1501/1670 train_time:147508ms step_avg:98.27ms +step:1502/1670 train_time:147591ms step_avg:98.26ms +step:1503/1670 train_time:147690ms step_avg:98.26ms +step:1504/1670 train_time:147787ms step_avg:98.26ms +step:1505/1670 train_time:147884ms step_avg:98.26ms +step:1506/1670 train_time:147981ms step_avg:98.26ms +step:1507/1670 
train_time:148078ms step_avg:98.26ms +step:1508/1670 train_time:148176ms step_avg:98.26ms +step:1509/1670 train_time:148273ms step_avg:98.26ms +step:1510/1670 train_time:148371ms step_avg:98.26ms +step:1511/1670 train_time:148469ms step_avg:98.26ms +step:1512/1670 train_time:148572ms step_avg:98.26ms +step:1513/1670 train_time:148671ms step_avg:98.26ms +step:1514/1670 train_time:148769ms step_avg:98.26ms +step:1515/1670 train_time:148867ms step_avg:98.26ms +step:1516/1670 train_time:148964ms step_avg:98.26ms +step:1517/1670 train_time:149061ms step_avg:98.26ms +step:1518/1670 train_time:149159ms step_avg:98.26ms +step:1519/1670 train_time:149257ms step_avg:98.26ms +step:1520/1670 train_time:149355ms step_avg:98.26ms +step:1521/1670 train_time:149455ms step_avg:98.26ms +step:1522/1670 train_time:149555ms step_avg:98.26ms +step:1523/1670 train_time:149654ms step_avg:98.26ms +step:1524/1670 train_time:149752ms step_avg:98.26ms +step:1525/1670 train_time:149849ms step_avg:98.26ms +step:1526/1670 train_time:149946ms step_avg:98.26ms +step:1527/1670 train_time:150044ms step_avg:98.26ms +step:1528/1670 train_time:150142ms step_avg:98.26ms +step:1529/1670 train_time:150239ms step_avg:98.26ms +step:1530/1670 train_time:150336ms step_avg:98.26ms +step:1531/1670 train_time:150435ms step_avg:98.26ms +step:1532/1670 train_time:150534ms step_avg:98.26ms +step:1533/1670 train_time:150633ms step_avg:98.26ms +step:1534/1670 train_time:150731ms step_avg:98.26ms +step:1535/1670 train_time:150829ms step_avg:98.26ms +step:1536/1670 train_time:150926ms step_avg:98.26ms +step:1537/1670 train_time:151024ms step_avg:98.26ms +step:1538/1670 train_time:151121ms step_avg:98.26ms +step:1539/1670 train_time:151220ms step_avg:98.26ms +step:1540/1670 train_time:151318ms step_avg:98.26ms +step:1541/1670 train_time:151416ms step_avg:98.26ms +step:1542/1670 train_time:151515ms step_avg:98.26ms +step:1543/1670 train_time:151615ms step_avg:98.26ms +step:1544/1670 train_time:151713ms step_avg:98.26ms +step:1545/1670 train_time:151811ms step_avg:98.26ms +step:1546/1670 train_time:151908ms step_avg:98.26ms +step:1547/1670 train_time:152004ms step_avg:98.26ms +step:1548/1670 train_time:152103ms step_avg:98.26ms +step:1549/1670 train_time:152202ms step_avg:98.26ms +step:1550/1670 train_time:152301ms step_avg:98.26ms +step:1551/1670 train_time:152399ms step_avg:98.26ms +step:1552/1670 train_time:152499ms step_avg:98.26ms +step:1553/1670 train_time:152599ms step_avg:98.26ms +step:1554/1670 train_time:152698ms step_avg:98.26ms +step:1555/1670 train_time:152798ms step_avg:98.26ms +step:1556/1670 train_time:152896ms step_avg:98.26ms +step:1557/1670 train_time:152994ms step_avg:98.26ms +step:1558/1670 train_time:153092ms step_avg:98.26ms +step:1559/1670 train_time:153189ms step_avg:98.26ms +step:1560/1670 train_time:153287ms step_avg:98.26ms +step:1561/1670 train_time:153384ms step_avg:98.26ms +step:1562/1670 train_time:153483ms step_avg:98.26ms +step:1563/1670 train_time:153583ms step_avg:98.26ms +step:1564/1670 train_time:153685ms step_avg:98.26ms +step:1565/1670 train_time:153784ms step_avg:98.26ms +step:1566/1670 train_time:153884ms step_avg:98.27ms +step:1567/1670 train_time:153982ms step_avg:98.27ms +step:1568/1670 train_time:154081ms step_avg:98.27ms +step:1569/1670 train_time:154180ms step_avg:98.27ms +step:1570/1670 train_time:154278ms step_avg:98.27ms +step:1571/1670 train_time:154375ms step_avg:98.27ms +step:1572/1670 train_time:154473ms step_avg:98.27ms +step:1573/1670 train_time:154570ms step_avg:98.26ms +step:1574/1670 
train_time:154668ms step_avg:98.26ms +step:1575/1670 train_time:154766ms step_avg:98.26ms +step:1576/1670 train_time:154865ms step_avg:98.26ms +step:1577/1670 train_time:154964ms step_avg:98.27ms +step:1578/1670 train_time:155065ms step_avg:98.27ms +step:1579/1670 train_time:155164ms step_avg:98.27ms +step:1580/1670 train_time:155263ms step_avg:98.27ms +step:1581/1670 train_time:155361ms step_avg:98.27ms +step:1582/1670 train_time:155460ms step_avg:98.27ms +step:1583/1670 train_time:155560ms step_avg:98.27ms +step:1584/1670 train_time:155658ms step_avg:98.27ms +step:1585/1670 train_time:155755ms step_avg:98.27ms +step:1586/1670 train_time:155853ms step_avg:98.27ms +step:1587/1670 train_time:155951ms step_avg:98.27ms +step:1588/1670 train_time:156048ms step_avg:98.27ms +step:1589/1670 train_time:156146ms step_avg:98.27ms +step:1590/1670 train_time:156245ms step_avg:98.27ms +step:1591/1670 train_time:156343ms step_avg:98.27ms +step:1592/1670 train_time:156442ms step_avg:98.27ms +step:1593/1670 train_time:156542ms step_avg:98.27ms +step:1594/1670 train_time:156643ms step_avg:98.27ms +step:1595/1670 train_time:156743ms step_avg:98.27ms +step:1596/1670 train_time:156843ms step_avg:98.27ms +step:1597/1670 train_time:156941ms step_avg:98.27ms +step:1598/1670 train_time:157040ms step_avg:98.27ms +step:1599/1670 train_time:157138ms step_avg:98.27ms +step:1600/1670 train_time:157237ms step_avg:98.27ms +step:1601/1670 train_time:157334ms step_avg:98.27ms +step:1602/1670 train_time:157431ms step_avg:98.27ms +step:1603/1670 train_time:157529ms step_avg:98.27ms +step:1604/1670 train_time:157628ms step_avg:98.27ms +step:1605/1670 train_time:157726ms step_avg:98.27ms +step:1606/1670 train_time:157825ms step_avg:98.27ms +step:1607/1670 train_time:157924ms step_avg:98.27ms +step:1608/1670 train_time:158022ms step_avg:98.27ms +step:1609/1670 train_time:158121ms step_avg:98.27ms +step:1610/1670 train_time:158220ms step_avg:98.27ms +step:1611/1670 train_time:158318ms step_avg:98.27ms +step:1612/1670 train_time:158417ms step_avg:98.27ms +step:1613/1670 train_time:158514ms step_avg:98.27ms +step:1614/1670 train_time:158612ms step_avg:98.27ms +step:1615/1670 train_time:158710ms step_avg:98.27ms +step:1616/1670 train_time:158808ms step_avg:98.27ms +step:1617/1670 train_time:158906ms step_avg:98.27ms +step:1618/1670 train_time:159005ms step_avg:98.27ms +step:1619/1670 train_time:159104ms step_avg:98.27ms +step:1620/1670 train_time:159202ms step_avg:98.27ms +step:1621/1670 train_time:159301ms step_avg:98.27ms +step:1622/1670 train_time:159401ms step_avg:98.27ms +step:1623/1670 train_time:159500ms step_avg:98.28ms +step:1624/1670 train_time:159601ms step_avg:98.28ms +step:1625/1670 train_time:159700ms step_avg:98.28ms +step:1625/1670 val_loss:3.2825 train_time:159800ms step_avg:98.34ms +step:1626/1670 train_time:159821ms step_avg:98.29ms +step:1627/1670 train_time:159906ms step_avg:98.28ms +step:1628/1670 train_time:160009ms step_avg:98.29ms +step:1629/1670 train_time:160108ms step_avg:98.29ms +step:1630/1670 train_time:160205ms step_avg:98.29ms +step:1631/1670 train_time:160302ms step_avg:98.28ms +step:1632/1670 train_time:160399ms step_avg:98.28ms +step:1633/1670 train_time:160495ms step_avg:98.28ms +step:1634/1670 train_time:160592ms step_avg:98.28ms +step:1635/1670 train_time:160690ms step_avg:98.28ms +step:1636/1670 train_time:160791ms step_avg:98.28ms +step:1637/1670 train_time:160893ms step_avg:98.29ms +step:1638/1670 train_time:160994ms step_avg:98.29ms +step:1639/1670 train_time:161093ms step_avg:98.29ms 
+step:1640/1670 train_time:161192ms step_avg:98.29ms +step:1641/1670 train_time:161290ms step_avg:98.29ms +step:1642/1670 train_time:161389ms step_avg:98.29ms +step:1643/1670 train_time:161487ms step_avg:98.29ms +step:1644/1670 train_time:161585ms step_avg:98.29ms +step:1645/1670 train_time:161682ms step_avg:98.29ms +step:1646/1670 train_time:161781ms step_avg:98.29ms +step:1647/1670 train_time:161880ms step_avg:98.29ms +step:1648/1670 train_time:161980ms step_avg:98.29ms +step:1649/1670 train_time:162081ms step_avg:98.29ms +step:1650/1670 train_time:162181ms step_avg:98.29ms +step:1651/1670 train_time:162278ms step_avg:98.29ms +step:1652/1670 train_time:162376ms step_avg:98.29ms +step:1653/1670 train_time:162474ms step_avg:98.29ms +step:1654/1670 train_time:162571ms step_avg:98.29ms +step:1655/1670 train_time:162670ms step_avg:98.29ms +step:1656/1670 train_time:162767ms step_avg:98.29ms +step:1657/1670 train_time:162866ms step_avg:98.29ms +step:1658/1670 train_time:162966ms step_avg:98.29ms +step:1659/1670 train_time:163067ms step_avg:98.29ms +step:1660/1670 train_time:163166ms step_avg:98.29ms +step:1661/1670 train_time:163264ms step_avg:98.29ms +step:1662/1670 train_time:163363ms step_avg:98.29ms +step:1663/1670 train_time:163460ms step_avg:98.29ms +step:1664/1670 train_time:163558ms step_avg:98.29ms +step:1665/1670 train_time:163656ms step_avg:98.29ms +step:1666/1670 train_time:163752ms step_avg:98.29ms +step:1667/1670 train_time:163850ms step_avg:98.29ms +step:1668/1670 train_time:163948ms step_avg:98.29ms +step:1669/1670 train_time:164047ms step_avg:98.29ms +step:1670/1670 train_time:164147ms step_avg:98.29ms +step:1670/1670 val_loss:3.2747 train_time:164244ms step_avg:98.35ms +peak memory allocated: 34000 MiB reserved: 49296 MiB diff --git a/records/050925_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt b/records/050925_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt new file mode 100644 index 000000000..64e30d506 --- /dev/null +++ b/records/050925_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt @@ -0,0 +1,2815 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + 
scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, 
a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we 
use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral 
norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
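+ # Sharding scheme: within each same-shape parameter group, params[base_i + rank] is owned by this rank. + # Each rank reduce-scatters to obtain the averaged gradient for the parameters it owns, applies the + # momentum update and Newton-Schulz orthogonalization locally, then all-gathers the updated + # parameters so that every rank ends the step with identical weights.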
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Fri Sep  5 15:32:36 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 44C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 67814 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 67815 C /usr/bin/python3 610MiB | +| 0 N/A N/A 67816 C /usr/bin/python3 610MiB | +| 0 N/A N/A 67817 C /usr/bin/python3 610MiB | +| 0 N/A N/A 67818 C /usr/bin/python3 610MiB | +| 0 N/A N/A 67819 C /usr/bin/python3 610MiB | +| 0 N/A N/A 67820 C /usr/bin/python3 610MiB | +| 0 N/A N/A 67821 C /usr/bin/python3 610MiB | +| 1 N/A N/A 67815 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 67816 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 67817 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 67818 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 67819 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 67820 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 67821 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:357ms step_avg:357.26ms +step:2/1670 train_time:378ms step_avg:189.13ms +step:3/1670 train_time:451ms step_avg:150.28ms +step:4/1670 train_time:545ms step_avg:136.24ms +step:5/1670 train_time:641ms step_avg:128.13ms +step:6/1670 train_time:735ms step_avg:122.58ms +step:7/1670 train_time:830ms step_avg:118.59ms +step:8/1670 
train_time:925ms step_avg:115.65ms +step:9/1670 train_time:1021ms step_avg:113.42ms +step:10/1670 train_time:1116ms step_avg:111.61ms +step:11/1670 train_time:1213ms step_avg:110.25ms +step:12/1670 train_time:1309ms step_avg:109.12ms +step:13/1670 train_time:1407ms step_avg:108.21ms +step:14/1670 train_time:1504ms step_avg:107.40ms +step:15/1670 train_time:1601ms step_avg:106.73ms +step:16/1670 train_time:1697ms step_avg:106.07ms +step:17/1670 train_time:1793ms step_avg:105.45ms +step:18/1670 train_time:1888ms step_avg:104.87ms +step:19/1670 train_time:1983ms step_avg:104.36ms +step:20/1670 train_time:2078ms step_avg:103.91ms +step:21/1670 train_time:2175ms step_avg:103.56ms +step:22/1670 train_time:2270ms step_avg:103.19ms +step:23/1670 train_time:2367ms step_avg:102.91ms +step:24/1670 train_time:2463ms step_avg:102.64ms +step:25/1670 train_time:2560ms step_avg:102.41ms +step:26/1670 train_time:2657ms step_avg:102.18ms +step:27/1670 train_time:2754ms step_avg:101.99ms +step:28/1670 train_time:2850ms step_avg:101.79ms +step:29/1670 train_time:2946ms step_avg:101.57ms +step:30/1670 train_time:3042ms step_avg:101.39ms +step:31/1670 train_time:3137ms step_avg:101.20ms +step:32/1670 train_time:3234ms step_avg:101.06ms +step:33/1670 train_time:3330ms step_avg:100.90ms +step:34/1670 train_time:3426ms step_avg:100.76ms +step:35/1670 train_time:3522ms step_avg:100.62ms +step:36/1670 train_time:3618ms step_avg:100.51ms +step:37/1670 train_time:3716ms step_avg:100.43ms +step:38/1670 train_time:3812ms step_avg:100.33ms +step:39/1670 train_time:3908ms step_avg:100.21ms +step:40/1670 train_time:4003ms step_avg:100.08ms +step:41/1670 train_time:4099ms step_avg:99.98ms +step:42/1670 train_time:4196ms step_avg:99.90ms +step:43/1670 train_time:4292ms step_avg:99.80ms +step:44/1670 train_time:4387ms step_avg:99.70ms +step:45/1670 train_time:4484ms step_avg:99.63ms +step:46/1670 train_time:4580ms step_avg:99.56ms +step:47/1670 train_time:4676ms step_avg:99.49ms +step:48/1670 train_time:4773ms step_avg:99.43ms +step:49/1670 train_time:4869ms step_avg:99.37ms +step:50/1670 train_time:4965ms step_avg:99.30ms +step:51/1670 train_time:5061ms step_avg:99.24ms +step:52/1670 train_time:5157ms step_avg:99.17ms +step:53/1670 train_time:5253ms step_avg:99.11ms +step:54/1670 train_time:5348ms step_avg:99.04ms +step:55/1670 train_time:5444ms step_avg:98.98ms +step:56/1670 train_time:5541ms step_avg:98.94ms +step:57/1670 train_time:5638ms step_avg:98.91ms +step:58/1670 train_time:5735ms step_avg:98.88ms +step:59/1670 train_time:5832ms step_avg:98.84ms +step:60/1670 train_time:5927ms step_avg:98.79ms +step:61/1670 train_time:6022ms step_avg:98.73ms +step:62/1670 train_time:6119ms step_avg:98.70ms +step:63/1670 train_time:6216ms step_avg:98.66ms +step:64/1670 train_time:6312ms step_avg:98.62ms +step:65/1670 train_time:6407ms step_avg:98.58ms +step:66/1670 train_time:6503ms step_avg:98.53ms +step:67/1670 train_time:6600ms step_avg:98.51ms +step:68/1670 train_time:6696ms step_avg:98.48ms +step:69/1670 train_time:6793ms step_avg:98.45ms +step:70/1670 train_time:6888ms step_avg:98.41ms +step:71/1670 train_time:6984ms step_avg:98.36ms +step:72/1670 train_time:7080ms step_avg:98.33ms +step:73/1670 train_time:7176ms step_avg:98.30ms +step:74/1670 train_time:7272ms step_avg:98.27ms +step:75/1670 train_time:7368ms step_avg:98.24ms +step:76/1670 train_time:7465ms step_avg:98.22ms +step:77/1670 train_time:7561ms step_avg:98.19ms +step:78/1670 train_time:7657ms step_avg:98.17ms +step:79/1670 train_time:7753ms step_avg:98.14ms 
+step:80/1670 train_time:7850ms step_avg:98.12ms +step:81/1670 train_time:7946ms step_avg:98.10ms +step:82/1670 train_time:8041ms step_avg:98.06ms +step:83/1670 train_time:8137ms step_avg:98.04ms +step:84/1670 train_time:8233ms step_avg:98.01ms +step:85/1670 train_time:8328ms step_avg:97.98ms +step:86/1670 train_time:8424ms step_avg:97.96ms +step:87/1670 train_time:8520ms step_avg:97.93ms +step:88/1670 train_time:8616ms step_avg:97.91ms +step:89/1670 train_time:8712ms step_avg:97.89ms +step:90/1670 train_time:8808ms step_avg:97.87ms +step:91/1670 train_time:8904ms step_avg:97.85ms +step:92/1670 train_time:9001ms step_avg:97.83ms +step:93/1670 train_time:9097ms step_avg:97.82ms +step:94/1670 train_time:9194ms step_avg:97.81ms +step:95/1670 train_time:9289ms step_avg:97.78ms +step:96/1670 train_time:9385ms step_avg:97.76ms +step:97/1670 train_time:9480ms step_avg:97.74ms +step:98/1670 train_time:9577ms step_avg:97.72ms +step:99/1670 train_time:9672ms step_avg:97.70ms +step:100/1670 train_time:9768ms step_avg:97.68ms +step:101/1670 train_time:9864ms step_avg:97.66ms +step:102/1670 train_time:9960ms step_avg:97.65ms +step:103/1670 train_time:10056ms step_avg:97.63ms +step:104/1670 train_time:10152ms step_avg:97.62ms +step:105/1670 train_time:10248ms step_avg:97.60ms +step:106/1670 train_time:10344ms step_avg:97.58ms +step:107/1670 train_time:10439ms step_avg:97.56ms +step:108/1670 train_time:10535ms step_avg:97.55ms +step:109/1670 train_time:10630ms step_avg:97.53ms +step:110/1670 train_time:10726ms step_avg:97.51ms +step:111/1670 train_time:10821ms step_avg:97.49ms +step:112/1670 train_time:10918ms step_avg:97.48ms +step:113/1670 train_time:11014ms step_avg:97.47ms +step:114/1670 train_time:11110ms step_avg:97.46ms +step:115/1670 train_time:11206ms step_avg:97.44ms +step:116/1670 train_time:11302ms step_avg:97.43ms +step:117/1670 train_time:11397ms step_avg:97.41ms +step:118/1670 train_time:11493ms step_avg:97.40ms +step:119/1670 train_time:11588ms step_avg:97.38ms +step:120/1670 train_time:11684ms step_avg:97.36ms +step:121/1670 train_time:11780ms step_avg:97.35ms +step:122/1670 train_time:11875ms step_avg:97.34ms +step:123/1670 train_time:11971ms step_avg:97.33ms +step:124/1670 train_time:12067ms step_avg:97.31ms +step:125/1670 train_time:12163ms step_avg:97.30ms +step:125/1670 val_loss:4.2943 train_time:12258ms step_avg:98.07ms +step:126/1670 train_time:12281ms step_avg:97.47ms +step:127/1670 train_time:12362ms step_avg:97.34ms +step:128/1670 train_time:12464ms step_avg:97.37ms +step:129/1670 train_time:12560ms step_avg:97.36ms +step:130/1670 train_time:12655ms step_avg:97.35ms +step:131/1670 train_time:12750ms step_avg:97.33ms +step:132/1670 train_time:12845ms step_avg:97.31ms +step:133/1670 train_time:12939ms step_avg:97.29ms +step:134/1670 train_time:13033ms step_avg:97.26ms +step:135/1670 train_time:13128ms step_avg:97.24ms +step:136/1670 train_time:13223ms step_avg:97.23ms +step:137/1670 train_time:13319ms step_avg:97.22ms +step:138/1670 train_time:13417ms step_avg:97.23ms +step:139/1670 train_time:13515ms step_avg:97.23ms +step:140/1670 train_time:13611ms step_avg:97.22ms +step:141/1670 train_time:13707ms step_avg:97.21ms +step:142/1670 train_time:13802ms step_avg:97.20ms +step:143/1670 train_time:13898ms step_avg:97.19ms +step:144/1670 train_time:13993ms step_avg:97.17ms +step:145/1670 train_time:14088ms step_avg:97.16ms +step:146/1670 train_time:14183ms step_avg:97.14ms +step:147/1670 train_time:14278ms step_avg:97.13ms +step:148/1670 train_time:14374ms step_avg:97.12ms 
+step:149/1670 train_time:14471ms step_avg:97.12ms +step:150/1670 train_time:14568ms step_avg:97.12ms +step:151/1670 train_time:14664ms step_avg:97.11ms +step:152/1670 train_time:14760ms step_avg:97.11ms +step:153/1670 train_time:14855ms step_avg:97.09ms +step:154/1670 train_time:14950ms step_avg:97.08ms +step:155/1670 train_time:15045ms step_avg:97.07ms +step:156/1670 train_time:15140ms step_avg:97.05ms +step:157/1670 train_time:15235ms step_avg:97.04ms +step:158/1670 train_time:15331ms step_avg:97.03ms +step:159/1670 train_time:15427ms step_avg:97.03ms +step:160/1670 train_time:15524ms step_avg:97.03ms +step:161/1670 train_time:15621ms step_avg:97.02ms +step:162/1670 train_time:15716ms step_avg:97.01ms +step:163/1670 train_time:15812ms step_avg:97.00ms +step:164/1670 train_time:15908ms step_avg:97.00ms +step:165/1670 train_time:16004ms step_avg:96.99ms +step:166/1670 train_time:16099ms step_avg:96.98ms +step:167/1670 train_time:16194ms step_avg:96.97ms +step:168/1670 train_time:16289ms step_avg:96.96ms +step:169/1670 train_time:16385ms step_avg:96.95ms +step:170/1670 train_time:16481ms step_avg:96.94ms +step:171/1670 train_time:16577ms step_avg:96.94ms +step:172/1670 train_time:16673ms step_avg:96.93ms +step:173/1670 train_time:16768ms step_avg:96.93ms +step:174/1670 train_time:16864ms step_avg:96.92ms +step:175/1670 train_time:16958ms step_avg:96.90ms +step:176/1670 train_time:17053ms step_avg:96.89ms +step:177/1670 train_time:17149ms step_avg:96.89ms +step:178/1670 train_time:17244ms step_avg:96.88ms +step:179/1670 train_time:17340ms step_avg:96.87ms +step:180/1670 train_time:17435ms step_avg:96.86ms +step:181/1670 train_time:17531ms step_avg:96.85ms +step:182/1670 train_time:17628ms step_avg:96.86ms +step:183/1670 train_time:17724ms step_avg:96.85ms +step:184/1670 train_time:17819ms step_avg:96.84ms +step:185/1670 train_time:17914ms step_avg:96.83ms +step:186/1670 train_time:18010ms step_avg:96.83ms +step:187/1670 train_time:18107ms step_avg:96.83ms +step:188/1670 train_time:18202ms step_avg:96.82ms +step:189/1670 train_time:18297ms step_avg:96.81ms +step:190/1670 train_time:18393ms step_avg:96.80ms +step:191/1670 train_time:18487ms step_avg:96.79ms +step:192/1670 train_time:18583ms step_avg:96.79ms +step:193/1670 train_time:18680ms step_avg:96.79ms +step:194/1670 train_time:18776ms step_avg:96.78ms +step:195/1670 train_time:18871ms step_avg:96.78ms +step:196/1670 train_time:18967ms step_avg:96.77ms +step:197/1670 train_time:19063ms step_avg:96.77ms +step:198/1670 train_time:19158ms step_avg:96.76ms +step:199/1670 train_time:19253ms step_avg:96.75ms +step:200/1670 train_time:19348ms step_avg:96.74ms +step:201/1670 train_time:19445ms step_avg:96.74ms +step:202/1670 train_time:19539ms step_avg:96.73ms +step:203/1670 train_time:19635ms step_avg:96.72ms +step:204/1670 train_time:19731ms step_avg:96.72ms +step:205/1670 train_time:19828ms step_avg:96.72ms +step:206/1670 train_time:19923ms step_avg:96.71ms +step:207/1670 train_time:20018ms step_avg:96.71ms +step:208/1670 train_time:20114ms step_avg:96.70ms +step:209/1670 train_time:20210ms step_avg:96.70ms +step:210/1670 train_time:20305ms step_avg:96.69ms +step:211/1670 train_time:20400ms step_avg:96.68ms +step:212/1670 train_time:20496ms step_avg:96.68ms +step:213/1670 train_time:20752ms step_avg:97.43ms +step:214/1670 train_time:20931ms step_avg:97.81ms +step:215/1670 train_time:21025ms step_avg:97.79ms +step:216/1670 train_time:21120ms step_avg:97.78ms +step:217/1670 train_time:21214ms step_avg:97.76ms +step:218/1670 train_time:21309ms 
step_avg:97.75ms +step:219/1670 train_time:21404ms step_avg:97.73ms +step:220/1670 train_time:21498ms step_avg:97.72ms +step:221/1670 train_time:21593ms step_avg:97.70ms +step:222/1670 train_time:21687ms step_avg:97.69ms +step:223/1670 train_time:21787ms step_avg:97.70ms +step:224/1670 train_time:21887ms step_avg:97.71ms +step:225/1670 train_time:21984ms step_avg:97.71ms +step:226/1670 train_time:22080ms step_avg:97.70ms +step:227/1670 train_time:22174ms step_avg:97.68ms +step:228/1670 train_time:22270ms step_avg:97.67ms +step:229/1670 train_time:22365ms step_avg:97.66ms +step:230/1670 train_time:22460ms step_avg:97.65ms +step:231/1670 train_time:22554ms step_avg:97.64ms +step:232/1670 train_time:22649ms step_avg:97.63ms +step:233/1670 train_time:22745ms step_avg:97.62ms +step:234/1670 train_time:22842ms step_avg:97.61ms +step:235/1670 train_time:22939ms step_avg:97.61ms +step:236/1670 train_time:23035ms step_avg:97.61ms +step:237/1670 train_time:23132ms step_avg:97.60ms +step:238/1670 train_time:23227ms step_avg:97.59ms +step:239/1670 train_time:23322ms step_avg:97.58ms +step:240/1670 train_time:23417ms step_avg:97.57ms +step:241/1670 train_time:23512ms step_avg:97.56ms +step:242/1670 train_time:23606ms step_avg:97.55ms +step:243/1670 train_time:23701ms step_avg:97.53ms +step:244/1670 train_time:23797ms step_avg:97.53ms +step:245/1670 train_time:23893ms step_avg:97.52ms +step:246/1670 train_time:23989ms step_avg:97.52ms +step:247/1670 train_time:24086ms step_avg:97.51ms +step:248/1670 train_time:24181ms step_avg:97.51ms +step:249/1670 train_time:24277ms step_avg:97.50ms +step:250/1670 train_time:24372ms step_avg:97.49ms +step:250/1670 val_loss:3.9672 train_time:24467ms step_avg:97.87ms +step:251/1670 train_time:24488ms step_avg:97.56ms +step:252/1670 train_time:24572ms step_avg:97.51ms +step:253/1670 train_time:24673ms step_avg:97.52ms +step:254/1670 train_time:24769ms step_avg:97.51ms +step:255/1670 train_time:24864ms step_avg:97.50ms +step:256/1670 train_time:24959ms step_avg:97.50ms +step:257/1670 train_time:25053ms step_avg:97.48ms +step:258/1670 train_time:25148ms step_avg:97.47ms +step:259/1670 train_time:25242ms step_avg:97.46ms +step:260/1670 train_time:25337ms step_avg:97.45ms +step:261/1670 train_time:25432ms step_avg:97.44ms +step:262/1670 train_time:25528ms step_avg:97.44ms +step:263/1670 train_time:25627ms step_avg:97.44ms +step:264/1670 train_time:25725ms step_avg:97.44ms +step:265/1670 train_time:25820ms step_avg:97.44ms +step:266/1670 train_time:25916ms step_avg:97.43ms +step:267/1670 train_time:26010ms step_avg:97.42ms +step:268/1670 train_time:26106ms step_avg:97.41ms +step:269/1670 train_time:26201ms step_avg:97.40ms +step:270/1670 train_time:26295ms step_avg:97.39ms +step:271/1670 train_time:26390ms step_avg:97.38ms +step:272/1670 train_time:26486ms step_avg:97.37ms +step:273/1670 train_time:26582ms step_avg:97.37ms +step:274/1670 train_time:26679ms step_avg:97.37ms +step:275/1670 train_time:26775ms step_avg:97.36ms +step:276/1670 train_time:26871ms step_avg:97.36ms +step:277/1670 train_time:26967ms step_avg:97.35ms +step:278/1670 train_time:27062ms step_avg:97.35ms +step:279/1670 train_time:27158ms step_avg:97.34ms +step:280/1670 train_time:27252ms step_avg:97.33ms +step:281/1670 train_time:27347ms step_avg:97.32ms +step:282/1670 train_time:27443ms step_avg:97.31ms +step:283/1670 train_time:27538ms step_avg:97.31ms +step:284/1670 train_time:27635ms step_avg:97.31ms +step:285/1670 train_time:27732ms step_avg:97.30ms +step:286/1670 train_time:27827ms step_avg:97.30ms 
+step:287/1670 train_time:27924ms step_avg:97.29ms +step:288/1670 train_time:28019ms step_avg:97.29ms +step:289/1670 train_time:28114ms step_avg:97.28ms +step:290/1670 train_time:28209ms step_avg:97.27ms +step:291/1670 train_time:28305ms step_avg:97.27ms +step:292/1670 train_time:28401ms step_avg:97.26ms +step:293/1670 train_time:28497ms step_avg:97.26ms +step:294/1670 train_time:28592ms step_avg:97.25ms +step:295/1670 train_time:28687ms step_avg:97.24ms +step:296/1670 train_time:28783ms step_avg:97.24ms +step:297/1670 train_time:28879ms step_avg:97.24ms +step:298/1670 train_time:28974ms step_avg:97.23ms +step:299/1670 train_time:29070ms step_avg:97.22ms +step:300/1670 train_time:29165ms step_avg:97.22ms +step:301/1670 train_time:29260ms step_avg:97.21ms +step:302/1670 train_time:29356ms step_avg:97.20ms +step:303/1670 train_time:29451ms step_avg:97.20ms +step:304/1670 train_time:29546ms step_avg:97.19ms +step:305/1670 train_time:29643ms step_avg:97.19ms +step:306/1670 train_time:29738ms step_avg:97.18ms +step:307/1670 train_time:29835ms step_avg:97.18ms +step:308/1670 train_time:29931ms step_avg:97.18ms +step:309/1670 train_time:30027ms step_avg:97.17ms +step:310/1670 train_time:30122ms step_avg:97.17ms +step:311/1670 train_time:30217ms step_avg:97.16ms +step:312/1670 train_time:30312ms step_avg:97.15ms +step:313/1670 train_time:30407ms step_avg:97.15ms +step:314/1670 train_time:30503ms step_avg:97.14ms +step:315/1670 train_time:30599ms step_avg:97.14ms +step:316/1670 train_time:30695ms step_avg:97.14ms +step:317/1670 train_time:30790ms step_avg:97.13ms +step:318/1670 train_time:30886ms step_avg:97.13ms +step:319/1670 train_time:30983ms step_avg:97.12ms +step:320/1670 train_time:31079ms step_avg:97.12ms +step:321/1670 train_time:31174ms step_avg:97.11ms +step:322/1670 train_time:31269ms step_avg:97.11ms +step:323/1670 train_time:31364ms step_avg:97.10ms +step:324/1670 train_time:31460ms step_avg:97.10ms +step:325/1670 train_time:31556ms step_avg:97.09ms +step:326/1670 train_time:31651ms step_avg:97.09ms +step:327/1670 train_time:31746ms step_avg:97.08ms +step:328/1670 train_time:31842ms step_avg:97.08ms +step:329/1670 train_time:31938ms step_avg:97.08ms +step:330/1670 train_time:32034ms step_avg:97.07ms +step:331/1670 train_time:32129ms step_avg:97.07ms +step:332/1670 train_time:32225ms step_avg:97.06ms +step:333/1670 train_time:32321ms step_avg:97.06ms +step:334/1670 train_time:32417ms step_avg:97.06ms +step:335/1670 train_time:32513ms step_avg:97.05ms +step:336/1670 train_time:32608ms step_avg:97.05ms +step:337/1670 train_time:32704ms step_avg:97.05ms +step:338/1670 train_time:32800ms step_avg:97.04ms +step:339/1670 train_time:32896ms step_avg:97.04ms +step:340/1670 train_time:32991ms step_avg:97.03ms +step:341/1670 train_time:33087ms step_avg:97.03ms +step:342/1670 train_time:33182ms step_avg:97.02ms +step:343/1670 train_time:33278ms step_avg:97.02ms +step:344/1670 train_time:33374ms step_avg:97.02ms +step:345/1670 train_time:33470ms step_avg:97.01ms +step:346/1670 train_time:33566ms step_avg:97.01ms +step:347/1670 train_time:33661ms step_avg:97.01ms +step:348/1670 train_time:33757ms step_avg:97.00ms +step:349/1670 train_time:33853ms step_avg:97.00ms +step:350/1670 train_time:33948ms step_avg:96.99ms +step:351/1670 train_time:34044ms step_avg:96.99ms +step:352/1670 train_time:34139ms step_avg:96.99ms +step:353/1670 train_time:34236ms step_avg:96.98ms +step:354/1670 train_time:34331ms step_avg:96.98ms +step:355/1670 train_time:34427ms step_avg:96.98ms +step:356/1670 train_time:34523ms 
step_avg:96.98ms +step:357/1670 train_time:34619ms step_avg:96.97ms +step:358/1670 train_time:34715ms step_avg:96.97ms +step:359/1670 train_time:34810ms step_avg:96.96ms +step:360/1670 train_time:34905ms step_avg:96.96ms +step:361/1670 train_time:35001ms step_avg:96.96ms +step:362/1670 train_time:35097ms step_avg:96.95ms +step:363/1670 train_time:35193ms step_avg:96.95ms +step:364/1670 train_time:35288ms step_avg:96.95ms +step:365/1670 train_time:35384ms step_avg:96.94ms +step:366/1670 train_time:35481ms step_avg:96.94ms +step:367/1670 train_time:35577ms step_avg:96.94ms +step:368/1670 train_time:35672ms step_avg:96.93ms +step:369/1670 train_time:35767ms step_avg:96.93ms +step:370/1670 train_time:35862ms step_avg:96.92ms +step:371/1670 train_time:35958ms step_avg:96.92ms +step:372/1670 train_time:36053ms step_avg:96.92ms +step:373/1670 train_time:36149ms step_avg:96.91ms +step:374/1670 train_time:36244ms step_avg:96.91ms +step:375/1670 train_time:36340ms step_avg:96.91ms +step:375/1670 val_loss:3.8109 train_time:36436ms step_avg:97.16ms +step:376/1670 train_time:36459ms step_avg:96.96ms +step:377/1670 train_time:36544ms step_avg:96.93ms +step:378/1670 train_time:36643ms step_avg:96.94ms +step:379/1670 train_time:36739ms step_avg:96.94ms +step:380/1670 train_time:36833ms step_avg:96.93ms +step:381/1670 train_time:36928ms step_avg:96.93ms +step:382/1670 train_time:37024ms step_avg:96.92ms +step:383/1670 train_time:37119ms step_avg:96.92ms +step:384/1670 train_time:37213ms step_avg:96.91ms +step:385/1670 train_time:37308ms step_avg:96.90ms +step:386/1670 train_time:37404ms step_avg:96.90ms +step:387/1670 train_time:37501ms step_avg:96.90ms +step:388/1670 train_time:37599ms step_avg:96.90ms +step:389/1670 train_time:37695ms step_avg:96.90ms +step:390/1670 train_time:37791ms step_avg:96.90ms +step:391/1670 train_time:37886ms step_avg:96.90ms +step:392/1670 train_time:37981ms step_avg:96.89ms +step:393/1670 train_time:38077ms step_avg:96.89ms +step:394/1670 train_time:38172ms step_avg:96.88ms +step:395/1670 train_time:38267ms step_avg:96.88ms +step:396/1670 train_time:38362ms step_avg:96.87ms +step:397/1670 train_time:38459ms step_avg:96.87ms +step:398/1670 train_time:38555ms step_avg:96.87ms +step:399/1670 train_time:38651ms step_avg:96.87ms +step:400/1670 train_time:38747ms step_avg:96.87ms +step:401/1670 train_time:38843ms step_avg:96.87ms +step:402/1670 train_time:38939ms step_avg:96.86ms +step:403/1670 train_time:39033ms step_avg:96.86ms +step:404/1670 train_time:39129ms step_avg:96.85ms +step:405/1670 train_time:39225ms step_avg:96.85ms +step:406/1670 train_time:39319ms step_avg:96.85ms +step:407/1670 train_time:39415ms step_avg:96.84ms +step:408/1670 train_time:39511ms step_avg:96.84ms +step:409/1670 train_time:39608ms step_avg:96.84ms +step:410/1670 train_time:39704ms step_avg:96.84ms +step:411/1670 train_time:39801ms step_avg:96.84ms +step:412/1670 train_time:39896ms step_avg:96.84ms +step:413/1670 train_time:39992ms step_avg:96.83ms +step:414/1670 train_time:40087ms step_avg:96.83ms +step:415/1670 train_time:40182ms step_avg:96.83ms +step:416/1670 train_time:40277ms step_avg:96.82ms +step:417/1670 train_time:40372ms step_avg:96.82ms +step:418/1670 train_time:40469ms step_avg:96.81ms +step:419/1670 train_time:40564ms step_avg:96.81ms +step:420/1670 train_time:40661ms step_avg:96.81ms +step:421/1670 train_time:40757ms step_avg:96.81ms +step:422/1670 train_time:40853ms step_avg:96.81ms +step:423/1670 train_time:40949ms step_avg:96.81ms +step:424/1670 train_time:41045ms step_avg:96.80ms 
+step:425/1670 train_time:41302ms step_avg:97.18ms +step:426/1670 train_time:41508ms step_avg:97.44ms +step:427/1670 train_time:41601ms step_avg:97.43ms +step:428/1670 train_time:41696ms step_avg:97.42ms +step:429/1670 train_time:41790ms step_avg:97.41ms +step:430/1670 train_time:41886ms step_avg:97.41ms +step:431/1670 train_time:41981ms step_avg:97.40ms +step:432/1670 train_time:42075ms step_avg:97.40ms +step:433/1670 train_time:42170ms step_avg:97.39ms +step:434/1670 train_time:42265ms step_avg:97.39ms +step:435/1670 train_time:42365ms step_avg:97.39ms +step:436/1670 train_time:42463ms step_avg:97.39ms +step:437/1670 train_time:42562ms step_avg:97.40ms +step:438/1670 train_time:42657ms step_avg:97.39ms +step:439/1670 train_time:42753ms step_avg:97.39ms +step:440/1670 train_time:42849ms step_avg:97.38ms +step:441/1670 train_time:42944ms step_avg:97.38ms +step:442/1670 train_time:43039ms step_avg:97.37ms +step:443/1670 train_time:43134ms step_avg:97.37ms +step:444/1670 train_time:43229ms step_avg:97.36ms +step:445/1670 train_time:43325ms step_avg:97.36ms +step:446/1670 train_time:43421ms step_avg:97.36ms +step:447/1670 train_time:43518ms step_avg:97.36ms +step:448/1670 train_time:43614ms step_avg:97.35ms +step:449/1670 train_time:43710ms step_avg:97.35ms +step:450/1670 train_time:43805ms step_avg:97.35ms +step:451/1670 train_time:43901ms step_avg:97.34ms +step:452/1670 train_time:43995ms step_avg:97.34ms +step:453/1670 train_time:44091ms step_avg:97.33ms +step:454/1670 train_time:44187ms step_avg:97.33ms +step:455/1670 train_time:44282ms step_avg:97.32ms +step:456/1670 train_time:44378ms step_avg:97.32ms +step:457/1670 train_time:44474ms step_avg:97.32ms +step:458/1670 train_time:44571ms step_avg:97.32ms +step:459/1670 train_time:44667ms step_avg:97.31ms +step:460/1670 train_time:44763ms step_avg:97.31ms +step:461/1670 train_time:44859ms step_avg:97.31ms +step:462/1670 train_time:44955ms step_avg:97.30ms +step:463/1670 train_time:45050ms step_avg:97.30ms +step:464/1670 train_time:45146ms step_avg:97.30ms +step:465/1670 train_time:45241ms step_avg:97.29ms +step:466/1670 train_time:45336ms step_avg:97.29ms +step:467/1670 train_time:45432ms step_avg:97.29ms +step:468/1670 train_time:45529ms step_avg:97.28ms +step:469/1670 train_time:45625ms step_avg:97.28ms +step:470/1670 train_time:45722ms step_avg:97.28ms +step:471/1670 train_time:45817ms step_avg:97.28ms +step:472/1670 train_time:45913ms step_avg:97.27ms +step:473/1670 train_time:46009ms step_avg:97.27ms +step:474/1670 train_time:46105ms step_avg:97.27ms +step:475/1670 train_time:46200ms step_avg:97.26ms +step:476/1670 train_time:46296ms step_avg:97.26ms +step:477/1670 train_time:46391ms step_avg:97.26ms +step:478/1670 train_time:46487ms step_avg:97.25ms +step:479/1670 train_time:46584ms step_avg:97.25ms +step:480/1670 train_time:46681ms step_avg:97.25ms +step:481/1670 train_time:46776ms step_avg:97.25ms +step:482/1670 train_time:46872ms step_avg:97.24ms +step:483/1670 train_time:46968ms step_avg:97.24ms +step:484/1670 train_time:47063ms step_avg:97.24ms +step:485/1670 train_time:47159ms step_avg:97.23ms +step:486/1670 train_time:47254ms step_avg:97.23ms +step:487/1670 train_time:47350ms step_avg:97.23ms +step:488/1670 train_time:47446ms step_avg:97.23ms +step:489/1670 train_time:47542ms step_avg:97.22ms +step:490/1670 train_time:47638ms step_avg:97.22ms +step:491/1670 train_time:47733ms step_avg:97.22ms +step:492/1670 train_time:47829ms step_avg:97.21ms +step:493/1670 train_time:47925ms step_avg:97.21ms +step:494/1670 train_time:48021ms 
step_avg:97.21ms +step:495/1670 train_time:48117ms step_avg:97.21ms +step:496/1670 train_time:48212ms step_avg:97.20ms +step:497/1670 train_time:48309ms step_avg:97.20ms +step:498/1670 train_time:48405ms step_avg:97.20ms +step:499/1670 train_time:48501ms step_avg:97.20ms +step:500/1670 train_time:48597ms step_avg:97.19ms +step:500/1670 val_loss:3.7118 train_time:48692ms step_avg:97.38ms +step:501/1670 train_time:48715ms step_avg:97.24ms +step:502/1670 train_time:48798ms step_avg:97.21ms +step:503/1670 train_time:48897ms step_avg:97.21ms +step:504/1670 train_time:48994ms step_avg:97.21ms +step:505/1670 train_time:49090ms step_avg:97.21ms +step:506/1670 train_time:49185ms step_avg:97.20ms +step:507/1670 train_time:49280ms step_avg:97.20ms +step:508/1670 train_time:49375ms step_avg:97.19ms +step:509/1670 train_time:49470ms step_avg:97.19ms +step:510/1670 train_time:49565ms step_avg:97.19ms +step:511/1670 train_time:49661ms step_avg:97.18ms +step:512/1670 train_time:49759ms step_avg:97.18ms +step:513/1670 train_time:49857ms step_avg:97.19ms +step:514/1670 train_time:49954ms step_avg:97.19ms +step:515/1670 train_time:50050ms step_avg:97.18ms +step:516/1670 train_time:50145ms step_avg:97.18ms +step:517/1670 train_time:50241ms step_avg:97.18ms +step:518/1670 train_time:50336ms step_avg:97.17ms +step:519/1670 train_time:50431ms step_avg:97.17ms +step:520/1670 train_time:50527ms step_avg:97.17ms +step:521/1670 train_time:50621ms step_avg:97.16ms +step:522/1670 train_time:50717ms step_avg:97.16ms +step:523/1670 train_time:50814ms step_avg:97.16ms +step:524/1670 train_time:50910ms step_avg:97.16ms +step:525/1670 train_time:51006ms step_avg:97.15ms +step:526/1670 train_time:51102ms step_avg:97.15ms +step:527/1670 train_time:51198ms step_avg:97.15ms +step:528/1670 train_time:51293ms step_avg:97.15ms +step:529/1670 train_time:51388ms step_avg:97.14ms +step:530/1670 train_time:51483ms step_avg:97.14ms +step:531/1670 train_time:51579ms step_avg:97.13ms +step:532/1670 train_time:51674ms step_avg:97.13ms +step:533/1670 train_time:51772ms step_avg:97.13ms +step:534/1670 train_time:51868ms step_avg:97.13ms +step:535/1670 train_time:51964ms step_avg:97.13ms +step:536/1670 train_time:52061ms step_avg:97.13ms +step:537/1670 train_time:52157ms step_avg:97.13ms +step:538/1670 train_time:52252ms step_avg:97.12ms +step:539/1670 train_time:52348ms step_avg:97.12ms +step:540/1670 train_time:52444ms step_avg:97.12ms +step:541/1670 train_time:52539ms step_avg:97.11ms +step:542/1670 train_time:52634ms step_avg:97.11ms +step:543/1670 train_time:52730ms step_avg:97.11ms +step:544/1670 train_time:52825ms step_avg:97.11ms +step:545/1670 train_time:52922ms step_avg:97.10ms +step:546/1670 train_time:53019ms step_avg:97.10ms +step:547/1670 train_time:53115ms step_avg:97.10ms +step:548/1670 train_time:53211ms step_avg:97.10ms +step:549/1670 train_time:53306ms step_avg:97.10ms +step:550/1670 train_time:53402ms step_avg:97.09ms +step:551/1670 train_time:53497ms step_avg:97.09ms +step:552/1670 train_time:53592ms step_avg:97.09ms +step:553/1670 train_time:53687ms step_avg:97.08ms +step:554/1670 train_time:53783ms step_avg:97.08ms +step:555/1670 train_time:53879ms step_avg:97.08ms +step:556/1670 train_time:53976ms step_avg:97.08ms +step:557/1670 train_time:54072ms step_avg:97.08ms +step:558/1670 train_time:54168ms step_avg:97.08ms +step:559/1670 train_time:54265ms step_avg:97.07ms +step:560/1670 train_time:54363ms step_avg:97.08ms +step:561/1670 train_time:54460ms step_avg:97.08ms +step:562/1670 train_time:54557ms step_avg:97.08ms 
+step:563/1670 train_time:54655ms step_avg:97.08ms +step:564/1670 train_time:54752ms step_avg:97.08ms +step:565/1670 train_time:54849ms step_avg:97.08ms +step:566/1670 train_time:54947ms step_avg:97.08ms +step:567/1670 train_time:55044ms step_avg:97.08ms +step:568/1670 train_time:55142ms step_avg:97.08ms +step:569/1670 train_time:55240ms step_avg:97.08ms +step:570/1670 train_time:55338ms step_avg:97.08ms +step:571/1670 train_time:55434ms step_avg:97.08ms +step:572/1670 train_time:55531ms step_avg:97.08ms +step:573/1670 train_time:55628ms step_avg:97.08ms +step:574/1670 train_time:55724ms step_avg:97.08ms +step:575/1670 train_time:55821ms step_avg:97.08ms +step:576/1670 train_time:55919ms step_avg:97.08ms +step:577/1670 train_time:56016ms step_avg:97.08ms +step:578/1670 train_time:56114ms step_avg:97.08ms +step:579/1670 train_time:56212ms step_avg:97.08ms +step:580/1670 train_time:56309ms step_avg:97.09ms +step:581/1670 train_time:56406ms step_avg:97.08ms +step:582/1670 train_time:56504ms step_avg:97.09ms +step:583/1670 train_time:56601ms step_avg:97.09ms +step:584/1670 train_time:56699ms step_avg:97.09ms +step:585/1670 train_time:56796ms step_avg:97.09ms +step:586/1670 train_time:56893ms step_avg:97.09ms +step:587/1670 train_time:56990ms step_avg:97.09ms +step:588/1670 train_time:57088ms step_avg:97.09ms +step:589/1670 train_time:57185ms step_avg:97.09ms +step:590/1670 train_time:57283ms step_avg:97.09ms +step:591/1670 train_time:57381ms step_avg:97.09ms +step:592/1670 train_time:57478ms step_avg:97.09ms +step:593/1670 train_time:57577ms step_avg:97.09ms +step:594/1670 train_time:57673ms step_avg:97.09ms +step:595/1670 train_time:57770ms step_avg:97.09ms +step:596/1670 train_time:57867ms step_avg:97.09ms +step:597/1670 train_time:57964ms step_avg:97.09ms +step:598/1670 train_time:58062ms step_avg:97.09ms +step:599/1670 train_time:58161ms step_avg:97.10ms +step:600/1670 train_time:58259ms step_avg:97.10ms +step:601/1670 train_time:58357ms step_avg:97.10ms +step:602/1670 train_time:58455ms step_avg:97.10ms +step:603/1670 train_time:58553ms step_avg:97.10ms +step:604/1670 train_time:58649ms step_avg:97.10ms +step:605/1670 train_time:58745ms step_avg:97.10ms +step:606/1670 train_time:58842ms step_avg:97.10ms +step:607/1670 train_time:58940ms step_avg:97.10ms +step:608/1670 train_time:59038ms step_avg:97.10ms +step:609/1670 train_time:59135ms step_avg:97.10ms +step:610/1670 train_time:59232ms step_avg:97.10ms +step:611/1670 train_time:59329ms step_avg:97.10ms +step:612/1670 train_time:59428ms step_avg:97.10ms +step:613/1670 train_time:59525ms step_avg:97.10ms +step:614/1670 train_time:59622ms step_avg:97.10ms +step:615/1670 train_time:59719ms step_avg:97.10ms +step:616/1670 train_time:59816ms step_avg:97.10ms +step:617/1670 train_time:59914ms step_avg:97.11ms +step:618/1670 train_time:60012ms step_avg:97.11ms +step:619/1670 train_time:60109ms step_avg:97.11ms +step:620/1670 train_time:60206ms step_avg:97.11ms +step:621/1670 train_time:60304ms step_avg:97.11ms +step:622/1670 train_time:60401ms step_avg:97.11ms +step:623/1670 train_time:60499ms step_avg:97.11ms +step:624/1670 train_time:60597ms step_avg:97.11ms +step:625/1670 train_time:60694ms step_avg:97.11ms +step:625/1670 val_loss:3.6134 train_time:60791ms step_avg:97.27ms +step:626/1670 train_time:60813ms step_avg:97.14ms +step:627/1670 train_time:60900ms step_avg:97.13ms +step:628/1670 train_time:60996ms step_avg:97.13ms +step:629/1670 train_time:61092ms step_avg:97.13ms +step:630/1670 train_time:61188ms step_avg:97.12ms +step:631/1670 
train_time:61283ms step_avg:97.12ms +step:632/1670 train_time:61380ms step_avg:97.12ms +step:633/1670 train_time:61475ms step_avg:97.12ms +step:634/1670 train_time:61571ms step_avg:97.12ms +step:635/1670 train_time:61667ms step_avg:97.11ms +step:636/1670 train_time:61766ms step_avg:97.12ms +step:637/1670 train_time:61866ms step_avg:97.12ms +step:638/1670 train_time:61964ms step_avg:97.12ms +step:639/1670 train_time:62342ms step_avg:97.56ms +step:640/1670 train_time:62429ms step_avg:97.54ms +step:641/1670 train_time:62523ms step_avg:97.54ms +step:642/1670 train_time:62620ms step_avg:97.54ms +step:643/1670 train_time:62716ms step_avg:97.54ms +step:644/1670 train_time:62812ms step_avg:97.53ms +step:645/1670 train_time:62908ms step_avg:97.53ms +step:646/1670 train_time:63005ms step_avg:97.53ms +step:647/1670 train_time:63101ms step_avg:97.53ms +step:648/1670 train_time:63197ms step_avg:97.53ms +step:649/1670 train_time:63297ms step_avg:97.53ms +step:650/1670 train_time:63399ms step_avg:97.54ms +step:651/1670 train_time:63498ms step_avg:97.54ms +step:652/1670 train_time:63597ms step_avg:97.54ms +step:653/1670 train_time:63692ms step_avg:97.54ms +step:654/1670 train_time:63788ms step_avg:97.54ms +step:655/1670 train_time:63884ms step_avg:97.53ms +step:656/1670 train_time:63980ms step_avg:97.53ms +step:657/1670 train_time:64077ms step_avg:97.53ms +step:658/1670 train_time:64173ms step_avg:97.53ms +step:659/1670 train_time:64271ms step_avg:97.53ms +step:660/1670 train_time:64370ms step_avg:97.53ms +step:661/1670 train_time:64469ms step_avg:97.53ms +step:662/1670 train_time:64567ms step_avg:97.53ms +step:663/1670 train_time:64665ms step_avg:97.53ms +step:664/1670 train_time:64763ms step_avg:97.54ms +step:665/1670 train_time:64859ms step_avg:97.53ms +step:666/1670 train_time:64955ms step_avg:97.53ms +step:667/1670 train_time:65052ms step_avg:97.53ms +step:668/1670 train_time:65148ms step_avg:97.53ms +step:669/1670 train_time:65245ms step_avg:97.53ms +step:670/1670 train_time:65344ms step_avg:97.53ms +step:671/1670 train_time:65443ms step_avg:97.53ms +step:672/1670 train_time:65541ms step_avg:97.53ms +step:673/1670 train_time:65640ms step_avg:97.53ms +step:674/1670 train_time:65738ms step_avg:97.53ms +step:675/1670 train_time:65834ms step_avg:97.53ms +step:676/1670 train_time:65931ms step_avg:97.53ms +step:677/1670 train_time:66027ms step_avg:97.53ms +step:678/1670 train_time:66124ms step_avg:97.53ms +step:679/1670 train_time:66221ms step_avg:97.53ms +step:680/1670 train_time:66319ms step_avg:97.53ms +step:681/1670 train_time:66418ms step_avg:97.53ms +step:682/1670 train_time:66515ms step_avg:97.53ms +step:683/1670 train_time:66612ms step_avg:97.53ms +step:684/1670 train_time:66710ms step_avg:97.53ms +step:685/1670 train_time:66807ms step_avg:97.53ms +step:686/1670 train_time:66903ms step_avg:97.53ms +step:687/1670 train_time:67000ms step_avg:97.53ms +step:688/1670 train_time:67097ms step_avg:97.52ms +step:689/1670 train_time:67194ms step_avg:97.52ms +step:690/1670 train_time:67291ms step_avg:97.52ms +step:691/1670 train_time:67388ms step_avg:97.52ms +step:692/1670 train_time:67486ms step_avg:97.52ms +step:693/1670 train_time:67584ms step_avg:97.52ms +step:694/1670 train_time:67682ms step_avg:97.52ms +step:695/1670 train_time:67780ms step_avg:97.53ms +step:696/1670 train_time:67878ms step_avg:97.53ms +step:697/1670 train_time:67975ms step_avg:97.53ms +step:698/1670 train_time:68072ms step_avg:97.52ms +step:699/1670 train_time:68169ms step_avg:97.52ms +step:700/1670 train_time:68266ms step_avg:97.52ms 
+step:701/1670 train_time:68363ms step_avg:97.52ms +step:702/1670 train_time:68462ms step_avg:97.52ms +step:703/1670 train_time:68560ms step_avg:97.52ms +step:704/1670 train_time:68658ms step_avg:97.53ms +step:705/1670 train_time:68756ms step_avg:97.53ms +step:706/1670 train_time:68854ms step_avg:97.53ms +step:707/1670 train_time:68950ms step_avg:97.52ms +step:708/1670 train_time:69047ms step_avg:97.52ms +step:709/1670 train_time:69143ms step_avg:97.52ms +step:710/1670 train_time:69240ms step_avg:97.52ms +step:711/1670 train_time:69340ms step_avg:97.52ms +step:712/1670 train_time:69438ms step_avg:97.52ms +step:713/1670 train_time:69535ms step_avg:97.52ms +step:714/1670 train_time:69632ms step_avg:97.52ms +step:715/1670 train_time:69729ms step_avg:97.52ms +step:716/1670 train_time:69827ms step_avg:97.52ms +step:717/1670 train_time:69924ms step_avg:97.52ms +step:718/1670 train_time:70021ms step_avg:97.52ms +step:719/1670 train_time:70119ms step_avg:97.52ms +step:720/1670 train_time:70215ms step_avg:97.52ms +step:721/1670 train_time:70313ms step_avg:97.52ms +step:722/1670 train_time:70411ms step_avg:97.52ms +step:723/1670 train_time:70508ms step_avg:97.52ms +step:724/1670 train_time:70605ms step_avg:97.52ms +step:725/1670 train_time:70702ms step_avg:97.52ms +step:726/1670 train_time:70800ms step_avg:97.52ms +step:727/1670 train_time:70898ms step_avg:97.52ms +step:728/1670 train_time:70995ms step_avg:97.52ms +step:729/1670 train_time:71094ms step_avg:97.52ms +step:730/1670 train_time:71191ms step_avg:97.52ms +step:731/1670 train_time:71288ms step_avg:97.52ms +step:732/1670 train_time:71385ms step_avg:97.52ms +step:733/1670 train_time:71482ms step_avg:97.52ms +step:734/1670 train_time:71580ms step_avg:97.52ms +step:735/1670 train_time:71677ms step_avg:97.52ms +step:736/1670 train_time:71774ms step_avg:97.52ms +step:737/1670 train_time:71871ms step_avg:97.52ms +step:738/1670 train_time:71968ms step_avg:97.52ms +step:739/1670 train_time:72065ms step_avg:97.52ms +step:740/1670 train_time:72163ms step_avg:97.52ms +step:741/1670 train_time:72261ms step_avg:97.52ms +step:742/1670 train_time:72358ms step_avg:97.52ms +step:743/1670 train_time:72456ms step_avg:97.52ms +step:744/1670 train_time:72553ms step_avg:97.52ms +step:745/1670 train_time:72650ms step_avg:97.52ms +step:746/1670 train_time:72748ms step_avg:97.52ms +step:747/1670 train_time:72845ms step_avg:97.52ms +step:748/1670 train_time:72943ms step_avg:97.52ms +step:749/1670 train_time:73040ms step_avg:97.52ms +step:750/1670 train_time:73138ms step_avg:97.52ms +step:750/1670 val_loss:3.5618 train_time:73234ms step_avg:97.65ms +step:751/1670 train_time:73257ms step_avg:97.55ms +step:752/1670 train_time:73340ms step_avg:97.53ms +step:753/1670 train_time:73440ms step_avg:97.53ms +step:754/1670 train_time:73537ms step_avg:97.53ms +step:755/1670 train_time:73634ms step_avg:97.53ms +step:756/1670 train_time:73731ms step_avg:97.53ms +step:757/1670 train_time:73827ms step_avg:97.53ms +step:758/1670 train_time:73924ms step_avg:97.53ms +step:759/1670 train_time:74021ms step_avg:97.52ms +step:760/1670 train_time:74116ms step_avg:97.52ms +step:761/1670 train_time:74216ms step_avg:97.52ms +step:762/1670 train_time:74317ms step_avg:97.53ms +step:763/1670 train_time:74416ms step_avg:97.53ms +step:764/1670 train_time:74514ms step_avg:97.53ms +step:765/1670 train_time:74611ms step_avg:97.53ms +step:766/1670 train_time:74708ms step_avg:97.53ms +step:767/1670 train_time:74804ms step_avg:97.53ms +step:768/1670 train_time:74900ms step_avg:97.53ms +step:769/1670 
train_time:74996ms step_avg:97.52ms +step:770/1670 train_time:75093ms step_avg:97.52ms +step:771/1670 train_time:75191ms step_avg:97.52ms +step:772/1670 train_time:75290ms step_avg:97.53ms +step:773/1670 train_time:75390ms step_avg:97.53ms +step:774/1670 train_time:75487ms step_avg:97.53ms +step:775/1670 train_time:75586ms step_avg:97.53ms +step:776/1670 train_time:75683ms step_avg:97.53ms +step:777/1670 train_time:75780ms step_avg:97.53ms +step:778/1670 train_time:75877ms step_avg:97.53ms +step:779/1670 train_time:75973ms step_avg:97.53ms +step:780/1670 train_time:76071ms step_avg:97.53ms +step:781/1670 train_time:76169ms step_avg:97.53ms +step:782/1670 train_time:76267ms step_avg:97.53ms +step:783/1670 train_time:76365ms step_avg:97.53ms +step:784/1670 train_time:76463ms step_avg:97.53ms +step:785/1670 train_time:76560ms step_avg:97.53ms +step:786/1670 train_time:76658ms step_avg:97.53ms +step:787/1670 train_time:76755ms step_avg:97.53ms +step:788/1670 train_time:76852ms step_avg:97.53ms +step:789/1670 train_time:76949ms step_avg:97.53ms +step:790/1670 train_time:77044ms step_avg:97.52ms +step:791/1670 train_time:77142ms step_avg:97.52ms +step:792/1670 train_time:77241ms step_avg:97.53ms +step:793/1670 train_time:77338ms step_avg:97.53ms +step:794/1670 train_time:77436ms step_avg:97.53ms +step:795/1670 train_time:77534ms step_avg:97.53ms +step:796/1670 train_time:77633ms step_avg:97.53ms +step:797/1670 train_time:77731ms step_avg:97.53ms +step:798/1670 train_time:77827ms step_avg:97.53ms +step:799/1670 train_time:77924ms step_avg:97.53ms +step:800/1670 train_time:78020ms step_avg:97.53ms +step:801/1670 train_time:78117ms step_avg:97.52ms +step:802/1670 train_time:78215ms step_avg:97.52ms +step:803/1670 train_time:78312ms step_avg:97.52ms +step:804/1670 train_time:78411ms step_avg:97.53ms +step:805/1670 train_time:78510ms step_avg:97.53ms +step:806/1670 train_time:78607ms step_avg:97.53ms +step:807/1670 train_time:78705ms step_avg:97.53ms +step:808/1670 train_time:78803ms step_avg:97.53ms +step:809/1670 train_time:78899ms step_avg:97.53ms +step:810/1670 train_time:78996ms step_avg:97.53ms +step:811/1670 train_time:79093ms step_avg:97.53ms +step:812/1670 train_time:79190ms step_avg:97.53ms +step:813/1670 train_time:79288ms step_avg:97.52ms +step:814/1670 train_time:79385ms step_avg:97.52ms +step:815/1670 train_time:79483ms step_avg:97.53ms +step:816/1670 train_time:79581ms step_avg:97.53ms +step:817/1670 train_time:79678ms step_avg:97.53ms +step:818/1670 train_time:79777ms step_avg:97.53ms +step:819/1670 train_time:79875ms step_avg:97.53ms +step:820/1670 train_time:79972ms step_avg:97.53ms +step:821/1670 train_time:80069ms step_avg:97.53ms +step:822/1670 train_time:80167ms step_avg:97.53ms +step:823/1670 train_time:80263ms step_avg:97.53ms +step:824/1670 train_time:80360ms step_avg:97.52ms +step:825/1670 train_time:80458ms step_avg:97.52ms +step:826/1670 train_time:80556ms step_avg:97.53ms +step:827/1670 train_time:80653ms step_avg:97.53ms +step:828/1670 train_time:80752ms step_avg:97.53ms +step:829/1670 train_time:80849ms step_avg:97.53ms +step:830/1670 train_time:80947ms step_avg:97.53ms +step:831/1670 train_time:81044ms step_avg:97.53ms +step:832/1670 train_time:81141ms step_avg:97.52ms +step:833/1670 train_time:81237ms step_avg:97.52ms +step:834/1670 train_time:81335ms step_avg:97.52ms +step:835/1670 train_time:81432ms step_avg:97.52ms +step:836/1670 train_time:81531ms step_avg:97.53ms +step:837/1670 train_time:81628ms step_avg:97.52ms +step:838/1670 train_time:81726ms step_avg:97.52ms 
+step:839/1670 train_time:81823ms step_avg:97.52ms +step:840/1670 train_time:81920ms step_avg:97.52ms +step:841/1670 train_time:82018ms step_avg:97.52ms +step:842/1670 train_time:82115ms step_avg:97.52ms +step:843/1670 train_time:82212ms step_avg:97.52ms +step:844/1670 train_time:82310ms step_avg:97.52ms +step:845/1670 train_time:82407ms step_avg:97.52ms +step:846/1670 train_time:82505ms step_avg:97.52ms +step:847/1670 train_time:82602ms step_avg:97.52ms +step:848/1670 train_time:82699ms step_avg:97.52ms +step:849/1670 train_time:82796ms step_avg:97.52ms +step:850/1670 train_time:82894ms step_avg:97.52ms +step:851/1670 train_time:83168ms step_avg:97.73ms +step:852/1670 train_time:83242ms step_avg:97.70ms +step:853/1670 train_time:83337ms step_avg:97.70ms +step:854/1670 train_time:83434ms step_avg:97.70ms +step:855/1670 train_time:83530ms step_avg:97.70ms +step:856/1670 train_time:83627ms step_avg:97.69ms +step:857/1670 train_time:83723ms step_avg:97.69ms +step:858/1670 train_time:83820ms step_avg:97.69ms +step:859/1670 train_time:83916ms step_avg:97.69ms +step:860/1670 train_time:84013ms step_avg:97.69ms +step:861/1670 train_time:84114ms step_avg:97.69ms +step:862/1670 train_time:84217ms step_avg:97.70ms +step:863/1670 train_time:84316ms step_avg:97.70ms +step:864/1670 train_time:84413ms step_avg:97.70ms +step:865/1670 train_time:84510ms step_avg:97.70ms +step:866/1670 train_time:84606ms step_avg:97.70ms +step:867/1670 train_time:84703ms step_avg:97.70ms +step:868/1670 train_time:84799ms step_avg:97.70ms +step:869/1670 train_time:84895ms step_avg:97.69ms +step:870/1670 train_time:84992ms step_avg:97.69ms +step:871/1670 train_time:85091ms step_avg:97.69ms +step:872/1670 train_time:85193ms step_avg:97.70ms +step:873/1670 train_time:85292ms step_avg:97.70ms +step:874/1670 train_time:85391ms step_avg:97.70ms +step:875/1670 train_time:85489ms step_avg:97.70ms +step:875/1670 val_loss:3.5195 train_time:85584ms step_avg:97.81ms +step:876/1670 train_time:85607ms step_avg:97.72ms +step:877/1670 train_time:85689ms step_avg:97.71ms +step:878/1670 train_time:85788ms step_avg:97.71ms +step:879/1670 train_time:85886ms step_avg:97.71ms +step:880/1670 train_time:85983ms step_avg:97.71ms +step:881/1670 train_time:86080ms step_avg:97.71ms +step:882/1670 train_time:86176ms step_avg:97.70ms +step:883/1670 train_time:86271ms step_avg:97.70ms +step:884/1670 train_time:86368ms step_avg:97.70ms +step:885/1670 train_time:86464ms step_avg:97.70ms +step:886/1670 train_time:86564ms step_avg:97.70ms +step:887/1670 train_time:86665ms step_avg:97.71ms +step:888/1670 train_time:86764ms step_avg:97.71ms +step:889/1670 train_time:86862ms step_avg:97.71ms +step:890/1670 train_time:86960ms step_avg:97.71ms +step:891/1670 train_time:87057ms step_avg:97.71ms +step:892/1670 train_time:87153ms step_avg:97.70ms +step:893/1670 train_time:87249ms step_avg:97.70ms +step:894/1670 train_time:87346ms step_avg:97.70ms +step:895/1670 train_time:87442ms step_avg:97.70ms +step:896/1670 train_time:87540ms step_avg:97.70ms +step:897/1670 train_time:87638ms step_avg:97.70ms +step:898/1670 train_time:87736ms step_avg:97.70ms +step:899/1670 train_time:87834ms step_avg:97.70ms +step:900/1670 train_time:87931ms step_avg:97.70ms +step:901/1670 train_time:88028ms step_avg:97.70ms +step:902/1670 train_time:88126ms step_avg:97.70ms +step:903/1670 train_time:88223ms step_avg:97.70ms +step:904/1670 train_time:88319ms step_avg:97.70ms +step:905/1670 train_time:88416ms step_avg:97.70ms +step:906/1670 train_time:88513ms step_avg:97.70ms +step:907/1670 
train_time:88611ms step_avg:97.70ms +step:908/1670 train_time:88709ms step_avg:97.70ms +step:909/1670 train_time:88808ms step_avg:97.70ms +step:910/1670 train_time:88908ms step_avg:97.70ms +step:911/1670 train_time:89006ms step_avg:97.70ms +step:912/1670 train_time:89105ms step_avg:97.70ms +step:913/1670 train_time:89202ms step_avg:97.70ms +step:914/1670 train_time:89298ms step_avg:97.70ms +step:915/1670 train_time:89395ms step_avg:97.70ms +step:916/1670 train_time:89491ms step_avg:97.70ms +step:917/1670 train_time:89589ms step_avg:97.70ms +step:918/1670 train_time:89687ms step_avg:97.70ms +step:919/1670 train_time:89785ms step_avg:97.70ms +step:920/1670 train_time:89883ms step_avg:97.70ms +step:921/1670 train_time:89980ms step_avg:97.70ms +step:922/1670 train_time:90077ms step_avg:97.70ms +step:923/1670 train_time:90175ms step_avg:97.70ms +step:924/1670 train_time:90272ms step_avg:97.70ms +step:925/1670 train_time:90369ms step_avg:97.70ms +step:926/1670 train_time:90466ms step_avg:97.70ms +step:927/1670 train_time:90564ms step_avg:97.70ms +step:928/1670 train_time:90662ms step_avg:97.70ms +step:929/1670 train_time:90760ms step_avg:97.70ms +step:930/1670 train_time:90858ms step_avg:97.70ms +step:931/1670 train_time:90955ms step_avg:97.70ms +step:932/1670 train_time:91053ms step_avg:97.70ms +step:933/1670 train_time:91150ms step_avg:97.70ms +step:934/1670 train_time:91247ms step_avg:97.70ms +step:935/1670 train_time:91345ms step_avg:97.70ms +step:936/1670 train_time:91442ms step_avg:97.69ms +step:937/1670 train_time:91539ms step_avg:97.69ms +step:938/1670 train_time:91636ms step_avg:97.69ms +step:939/1670 train_time:91733ms step_avg:97.69ms +step:940/1670 train_time:91831ms step_avg:97.69ms +step:941/1670 train_time:91928ms step_avg:97.69ms +step:942/1670 train_time:92027ms step_avg:97.69ms +step:943/1670 train_time:92124ms step_avg:97.69ms +step:944/1670 train_time:92221ms step_avg:97.69ms +step:945/1670 train_time:92319ms step_avg:97.69ms +step:946/1670 train_time:92416ms step_avg:97.69ms +step:947/1670 train_time:92514ms step_avg:97.69ms +step:948/1670 train_time:92610ms step_avg:97.69ms +step:949/1670 train_time:92708ms step_avg:97.69ms +step:950/1670 train_time:92805ms step_avg:97.69ms +step:951/1670 train_time:92903ms step_avg:97.69ms +step:952/1670 train_time:93000ms step_avg:97.69ms +step:953/1670 train_time:93098ms step_avg:97.69ms +step:954/1670 train_time:93195ms step_avg:97.69ms +step:955/1670 train_time:93292ms step_avg:97.69ms +step:956/1670 train_time:93390ms step_avg:97.69ms +step:957/1670 train_time:93488ms step_avg:97.69ms +step:958/1670 train_time:93586ms step_avg:97.69ms +step:959/1670 train_time:93683ms step_avg:97.69ms +step:960/1670 train_time:93781ms step_avg:97.69ms +step:961/1670 train_time:93878ms step_avg:97.69ms +step:962/1670 train_time:93975ms step_avg:97.69ms +step:963/1670 train_time:94072ms step_avg:97.69ms +step:964/1670 train_time:94169ms step_avg:97.69ms +step:965/1670 train_time:94267ms step_avg:97.69ms +step:966/1670 train_time:94366ms step_avg:97.69ms +step:967/1670 train_time:94463ms step_avg:97.69ms +step:968/1670 train_time:94560ms step_avg:97.69ms +step:969/1670 train_time:94657ms step_avg:97.69ms +step:970/1670 train_time:94754ms step_avg:97.68ms +step:971/1670 train_time:94851ms step_avg:97.68ms +step:972/1670 train_time:94949ms step_avg:97.68ms +step:973/1670 train_time:95048ms step_avg:97.69ms +step:974/1670 train_time:95145ms step_avg:97.69ms +step:975/1670 train_time:95243ms step_avg:97.68ms +step:976/1670 train_time:95340ms step_avg:97.68ms 
+step:977/1670 train_time:95437ms step_avg:97.68ms +step:978/1670 train_time:95534ms step_avg:97.68ms +step:979/1670 train_time:95632ms step_avg:97.68ms +step:980/1670 train_time:95729ms step_avg:97.68ms +step:981/1670 train_time:95827ms step_avg:97.68ms +step:982/1670 train_time:95924ms step_avg:97.68ms +step:983/1670 train_time:96022ms step_avg:97.68ms +step:984/1670 train_time:96119ms step_avg:97.68ms +step:985/1670 train_time:96215ms step_avg:97.68ms +step:986/1670 train_time:96312ms step_avg:97.68ms +step:987/1670 train_time:96410ms step_avg:97.68ms +step:988/1670 train_time:96507ms step_avg:97.68ms +step:989/1670 train_time:96606ms step_avg:97.68ms +step:990/1670 train_time:96704ms step_avg:97.68ms +step:991/1670 train_time:96802ms step_avg:97.68ms +step:992/1670 train_time:96899ms step_avg:97.68ms +step:993/1670 train_time:96997ms step_avg:97.68ms +step:994/1670 train_time:97093ms step_avg:97.68ms +step:995/1670 train_time:97190ms step_avg:97.68ms +step:996/1670 train_time:97287ms step_avg:97.68ms +step:997/1670 train_time:97385ms step_avg:97.68ms +step:998/1670 train_time:97483ms step_avg:97.68ms +step:999/1670 train_time:97580ms step_avg:97.68ms +step:1000/1670 train_time:97676ms step_avg:97.68ms +step:1000/1670 val_loss:3.4765 train_time:97773ms step_avg:97.77ms +step:1001/1670 train_time:97795ms step_avg:97.70ms +step:1002/1670 train_time:97877ms step_avg:97.68ms +step:1003/1670 train_time:97977ms step_avg:97.68ms +step:1004/1670 train_time:98075ms step_avg:97.68ms +step:1005/1670 train_time:98172ms step_avg:97.68ms +step:1006/1670 train_time:98268ms step_avg:97.68ms +step:1007/1670 train_time:98364ms step_avg:97.68ms +step:1008/1670 train_time:98460ms step_avg:97.68ms +step:1009/1670 train_time:98557ms step_avg:97.68ms +step:1010/1670 train_time:98654ms step_avg:97.68ms +step:1011/1670 train_time:98752ms step_avg:97.68ms +step:1012/1670 train_time:98851ms step_avg:97.68ms +step:1013/1670 train_time:98950ms step_avg:97.68ms +step:1014/1670 train_time:99048ms step_avg:97.68ms +step:1015/1670 train_time:99145ms step_avg:97.68ms +step:1016/1670 train_time:99242ms step_avg:97.68ms +step:1017/1670 train_time:99340ms step_avg:97.68ms +step:1018/1670 train_time:99437ms step_avg:97.68ms +step:1019/1670 train_time:99533ms step_avg:97.68ms +step:1020/1670 train_time:99629ms step_avg:97.68ms +step:1021/1670 train_time:99726ms step_avg:97.68ms +step:1022/1670 train_time:99824ms step_avg:97.68ms +step:1023/1670 train_time:99924ms step_avg:97.68ms +step:1024/1670 train_time:100024ms step_avg:97.68ms +step:1025/1670 train_time:100121ms step_avg:97.68ms +step:1026/1670 train_time:100220ms step_avg:97.68ms +step:1027/1670 train_time:100317ms step_avg:97.68ms +step:1028/1670 train_time:100414ms step_avg:97.68ms +step:1029/1670 train_time:100511ms step_avg:97.68ms +step:1030/1670 train_time:100607ms step_avg:97.68ms +step:1031/1670 train_time:100704ms step_avg:97.68ms +step:1032/1670 train_time:100802ms step_avg:97.68ms +step:1033/1670 train_time:100900ms step_avg:97.68ms +step:1034/1670 train_time:100999ms step_avg:97.68ms +step:1035/1670 train_time:101096ms step_avg:97.68ms +step:1036/1670 train_time:101194ms step_avg:97.68ms +step:1037/1670 train_time:101291ms step_avg:97.68ms +step:1038/1670 train_time:101388ms step_avg:97.68ms +step:1039/1670 train_time:101484ms step_avg:97.67ms +step:1040/1670 train_time:101581ms step_avg:97.67ms +step:1041/1670 train_time:101678ms step_avg:97.67ms +step:1042/1670 train_time:101778ms step_avg:97.68ms +step:1043/1670 train_time:101877ms step_avg:97.68ms 
+step:1044/1670 train_time:101976ms step_avg:97.68ms +step:1045/1670 train_time:102073ms step_avg:97.68ms +step:1046/1670 train_time:102171ms step_avg:97.68ms +step:1047/1670 train_time:102267ms step_avg:97.68ms +step:1048/1670 train_time:102364ms step_avg:97.68ms +step:1049/1670 train_time:102462ms step_avg:97.68ms +step:1050/1670 train_time:102559ms step_avg:97.68ms +step:1051/1670 train_time:102657ms step_avg:97.68ms +step:1052/1670 train_time:102755ms step_avg:97.68ms +step:1053/1670 train_time:102852ms step_avg:97.68ms +step:1054/1670 train_time:102949ms step_avg:97.68ms +step:1055/1670 train_time:103047ms step_avg:97.68ms +step:1056/1670 train_time:103145ms step_avg:97.68ms +step:1057/1670 train_time:103243ms step_avg:97.68ms +step:1058/1670 train_time:103340ms step_avg:97.67ms +step:1059/1670 train_time:103438ms step_avg:97.68ms +step:1060/1670 train_time:103535ms step_avg:97.67ms +step:1061/1670 train_time:103632ms step_avg:97.67ms +step:1062/1670 train_time:103903ms step_avg:97.84ms +step:1063/1670 train_time:104062ms step_avg:97.89ms +step:1064/1670 train_time:104157ms step_avg:97.89ms +step:1065/1670 train_time:104253ms step_avg:97.89ms +step:1066/1670 train_time:104349ms step_avg:97.89ms +step:1067/1670 train_time:104444ms step_avg:97.89ms +step:1068/1670 train_time:104540ms step_avg:97.88ms +step:1069/1670 train_time:104637ms step_avg:97.88ms +step:1070/1670 train_time:104733ms step_avg:97.88ms +step:1071/1670 train_time:104828ms step_avg:97.88ms +step:1072/1670 train_time:104930ms step_avg:97.88ms +step:1073/1670 train_time:105033ms step_avg:97.89ms +step:1074/1670 train_time:105131ms step_avg:97.89ms +step:1075/1670 train_time:105228ms step_avg:97.89ms +step:1076/1670 train_time:105324ms step_avg:97.88ms +step:1077/1670 train_time:105421ms step_avg:97.88ms +step:1078/1670 train_time:105518ms step_avg:97.88ms +step:1079/1670 train_time:105615ms step_avg:97.88ms +step:1080/1670 train_time:105711ms step_avg:97.88ms +step:1081/1670 train_time:105807ms step_avg:97.88ms +step:1082/1670 train_time:105905ms step_avg:97.88ms +step:1083/1670 train_time:106005ms step_avg:97.88ms +step:1084/1670 train_time:106104ms step_avg:97.88ms +step:1085/1670 train_time:106203ms step_avg:97.88ms +step:1086/1670 train_time:106300ms step_avg:97.88ms +step:1087/1670 train_time:106397ms step_avg:97.88ms +step:1088/1670 train_time:106493ms step_avg:97.88ms +step:1089/1670 train_time:106589ms step_avg:97.88ms +step:1090/1670 train_time:106686ms step_avg:97.88ms +step:1091/1670 train_time:106783ms step_avg:97.88ms +step:1092/1670 train_time:106882ms step_avg:97.88ms +step:1093/1670 train_time:106981ms step_avg:97.88ms +step:1094/1670 train_time:107080ms step_avg:97.88ms +step:1095/1670 train_time:107179ms step_avg:97.88ms +step:1096/1670 train_time:107276ms step_avg:97.88ms +step:1097/1670 train_time:107373ms step_avg:97.88ms +step:1098/1670 train_time:107470ms step_avg:97.88ms +step:1099/1670 train_time:107566ms step_avg:97.88ms +step:1100/1670 train_time:107663ms step_avg:97.88ms +step:1101/1670 train_time:107761ms step_avg:97.88ms +step:1102/1670 train_time:107859ms step_avg:97.88ms +step:1103/1670 train_time:107957ms step_avg:97.88ms +step:1104/1670 train_time:108054ms step_avg:97.88ms +step:1105/1670 train_time:108153ms step_avg:97.88ms +step:1106/1670 train_time:108251ms step_avg:97.88ms +step:1107/1670 train_time:108348ms step_avg:97.88ms +step:1108/1670 train_time:108445ms step_avg:97.87ms +step:1109/1670 train_time:108542ms step_avg:97.87ms +step:1110/1670 train_time:108640ms step_avg:97.87ms 
+step:1111/1670 train_time:108737ms step_avg:97.87ms +step:1112/1670 train_time:108834ms step_avg:97.87ms +step:1113/1670 train_time:108930ms step_avg:97.87ms +step:1114/1670 train_time:109028ms step_avg:97.87ms +step:1115/1670 train_time:109126ms step_avg:97.87ms +step:1116/1670 train_time:109224ms step_avg:97.87ms +step:1117/1670 train_time:109325ms step_avg:97.87ms +step:1118/1670 train_time:109422ms step_avg:97.87ms +step:1119/1670 train_time:109520ms step_avg:97.87ms +step:1120/1670 train_time:109619ms step_avg:97.87ms +step:1121/1670 train_time:109717ms step_avg:97.87ms +step:1122/1670 train_time:109815ms step_avg:97.87ms +step:1123/1670 train_time:109913ms step_avg:97.87ms +step:1124/1670 train_time:110011ms step_avg:97.87ms +step:1125/1670 train_time:110108ms step_avg:97.87ms +step:1125/1670 val_loss:3.4250 train_time:110206ms step_avg:97.96ms +step:1126/1670 train_time:110228ms step_avg:97.89ms +step:1127/1670 train_time:110312ms step_avg:97.88ms +step:1128/1670 train_time:110409ms step_avg:97.88ms +step:1129/1670 train_time:110506ms step_avg:97.88ms +step:1130/1670 train_time:110602ms step_avg:97.88ms +step:1131/1670 train_time:110698ms step_avg:97.88ms +step:1132/1670 train_time:110795ms step_avg:97.88ms +step:1133/1670 train_time:110892ms step_avg:97.88ms +step:1134/1670 train_time:110990ms step_avg:97.87ms +step:1135/1670 train_time:111087ms step_avg:97.87ms +step:1136/1670 train_time:111192ms step_avg:97.88ms +step:1137/1670 train_time:111294ms step_avg:97.88ms +step:1138/1670 train_time:111393ms step_avg:97.88ms +step:1139/1670 train_time:111491ms step_avg:97.88ms +step:1140/1670 train_time:111588ms step_avg:97.88ms +step:1141/1670 train_time:111686ms step_avg:97.88ms +step:1142/1670 train_time:111784ms step_avg:97.88ms +step:1143/1670 train_time:111880ms step_avg:97.88ms +step:1144/1670 train_time:111977ms step_avg:97.88ms +step:1145/1670 train_time:112075ms step_avg:97.88ms +step:1146/1670 train_time:112174ms step_avg:97.88ms +step:1147/1670 train_time:112274ms step_avg:97.89ms +step:1148/1670 train_time:112375ms step_avg:97.89ms +step:1149/1670 train_time:112474ms step_avg:97.89ms +step:1150/1670 train_time:112573ms step_avg:97.89ms +step:1151/1670 train_time:112671ms step_avg:97.89ms +step:1152/1670 train_time:112769ms step_avg:97.89ms +step:1153/1670 train_time:112866ms step_avg:97.89ms +step:1154/1670 train_time:112964ms step_avg:97.89ms +step:1155/1670 train_time:113061ms step_avg:97.89ms +step:1156/1670 train_time:113158ms step_avg:97.89ms +step:1157/1670 train_time:113255ms step_avg:97.89ms +step:1158/1670 train_time:113354ms step_avg:97.89ms +step:1159/1670 train_time:113453ms step_avg:97.89ms +step:1160/1670 train_time:113552ms step_avg:97.89ms +step:1161/1670 train_time:113651ms step_avg:97.89ms +step:1162/1670 train_time:113750ms step_avg:97.89ms +step:1163/1670 train_time:113847ms step_avg:97.89ms +step:1164/1670 train_time:113945ms step_avg:97.89ms +step:1165/1670 train_time:114042ms step_avg:97.89ms +step:1166/1670 train_time:114140ms step_avg:97.89ms +step:1167/1670 train_time:114237ms step_avg:97.89ms +step:1168/1670 train_time:114335ms step_avg:97.89ms +step:1169/1670 train_time:114434ms step_avg:97.89ms +step:1170/1670 train_time:114533ms step_avg:97.89ms +step:1171/1670 train_time:114632ms step_avg:97.89ms +step:1172/1670 train_time:114732ms step_avg:97.89ms +step:1173/1670 train_time:114832ms step_avg:97.90ms +step:1174/1670 train_time:114932ms step_avg:97.90ms +step:1175/1670 train_time:115033ms step_avg:97.90ms +step:1176/1670 train_time:115133ms 
step_avg:97.90ms +step:1177/1670 train_time:115233ms step_avg:97.90ms +step:1178/1670 train_time:115331ms step_avg:97.90ms +step:1179/1670 train_time:115429ms step_avg:97.90ms +step:1180/1670 train_time:115527ms step_avg:97.90ms +step:1181/1670 train_time:115624ms step_avg:97.90ms +step:1182/1670 train_time:115721ms step_avg:97.90ms +step:1183/1670 train_time:115819ms step_avg:97.90ms +step:1184/1670 train_time:115916ms step_avg:97.90ms +step:1185/1670 train_time:116015ms step_avg:97.90ms +step:1186/1670 train_time:116113ms step_avg:97.90ms +step:1187/1670 train_time:116211ms step_avg:97.90ms +step:1188/1670 train_time:116310ms step_avg:97.90ms +step:1189/1670 train_time:116408ms step_avg:97.90ms +step:1190/1670 train_time:116505ms step_avg:97.90ms +step:1191/1670 train_time:116603ms step_avg:97.90ms +step:1192/1670 train_time:116700ms step_avg:97.90ms +step:1193/1670 train_time:116798ms step_avg:97.90ms +step:1194/1670 train_time:116896ms step_avg:97.90ms +step:1195/1670 train_time:116994ms step_avg:97.90ms +step:1196/1670 train_time:117094ms step_avg:97.90ms +step:1197/1670 train_time:117191ms step_avg:97.90ms +step:1198/1670 train_time:117289ms step_avg:97.90ms +step:1199/1670 train_time:117387ms step_avg:97.90ms +step:1200/1670 train_time:117485ms step_avg:97.90ms +step:1201/1670 train_time:117583ms step_avg:97.90ms +step:1202/1670 train_time:117681ms step_avg:97.90ms +step:1203/1670 train_time:117779ms step_avg:97.90ms +step:1204/1670 train_time:117876ms step_avg:97.90ms +step:1205/1670 train_time:117975ms step_avg:97.90ms +step:1206/1670 train_time:118073ms step_avg:97.90ms +step:1207/1670 train_time:118171ms step_avg:97.90ms +step:1208/1670 train_time:118270ms step_avg:97.91ms +step:1209/1670 train_time:118368ms step_avg:97.91ms +step:1210/1670 train_time:118466ms step_avg:97.91ms +step:1211/1670 train_time:118563ms step_avg:97.91ms +step:1212/1670 train_time:118661ms step_avg:97.91ms +step:1213/1670 train_time:118759ms step_avg:97.91ms +step:1214/1670 train_time:118856ms step_avg:97.90ms +step:1215/1670 train_time:118953ms step_avg:97.90ms +step:1216/1670 train_time:119051ms step_avg:97.90ms +step:1217/1670 train_time:119149ms step_avg:97.90ms +step:1218/1670 train_time:119248ms step_avg:97.90ms +step:1219/1670 train_time:119346ms step_avg:97.90ms +step:1220/1670 train_time:119443ms step_avg:97.90ms +step:1221/1670 train_time:119541ms step_avg:97.90ms +step:1222/1670 train_time:119639ms step_avg:97.90ms +step:1223/1670 train_time:119736ms step_avg:97.90ms +step:1224/1670 train_time:119835ms step_avg:97.90ms +step:1225/1670 train_time:119934ms step_avg:97.91ms +step:1226/1670 train_time:120031ms step_avg:97.90ms +step:1227/1670 train_time:120129ms step_avg:97.90ms +step:1228/1670 train_time:120227ms step_avg:97.90ms +step:1229/1670 train_time:120325ms step_avg:97.90ms +step:1230/1670 train_time:120423ms step_avg:97.90ms +step:1231/1670 train_time:120521ms step_avg:97.91ms +step:1232/1670 train_time:120619ms step_avg:97.90ms +step:1233/1670 train_time:120717ms step_avg:97.91ms +step:1234/1670 train_time:120815ms step_avg:97.91ms +step:1235/1670 train_time:120913ms step_avg:97.91ms +step:1236/1670 train_time:121011ms step_avg:97.91ms +step:1237/1670 train_time:121108ms step_avg:97.90ms +step:1238/1670 train_time:121206ms step_avg:97.90ms +step:1239/1670 train_time:121303ms step_avg:97.90ms +step:1240/1670 train_time:121402ms step_avg:97.90ms +step:1241/1670 train_time:121499ms step_avg:97.90ms +step:1242/1670 train_time:121597ms step_avg:97.90ms +step:1243/1670 train_time:121695ms 
step_avg:97.90ms +step:1244/1670 train_time:121793ms step_avg:97.90ms +step:1245/1670 train_time:121891ms step_avg:97.90ms +step:1246/1670 train_time:121989ms step_avg:97.90ms +step:1247/1670 train_time:122086ms step_avg:97.90ms +step:1248/1670 train_time:122185ms step_avg:97.90ms +step:1249/1670 train_time:122282ms step_avg:97.90ms +step:1250/1670 train_time:122380ms step_avg:97.90ms +step:1250/1670 val_loss:3.3821 train_time:122477ms step_avg:97.98ms +step:1251/1670 train_time:122499ms step_avg:97.92ms +step:1252/1670 train_time:122586ms step_avg:97.91ms +step:1253/1670 train_time:122689ms step_avg:97.92ms +step:1254/1670 train_time:122786ms step_avg:97.92ms +step:1255/1670 train_time:122883ms step_avg:97.92ms +step:1256/1670 train_time:122980ms step_avg:97.91ms +step:1257/1670 train_time:123077ms step_avg:97.91ms +step:1258/1670 train_time:123175ms step_avg:97.91ms +step:1259/1670 train_time:123272ms step_avg:97.91ms +step:1260/1670 train_time:123369ms step_avg:97.91ms +step:1261/1670 train_time:123469ms step_avg:97.91ms +step:1262/1670 train_time:123570ms step_avg:97.92ms +step:1263/1670 train_time:123669ms step_avg:97.92ms +step:1264/1670 train_time:123767ms step_avg:97.92ms +step:1265/1670 train_time:123865ms step_avg:97.92ms +step:1266/1670 train_time:123961ms step_avg:97.92ms +step:1267/1670 train_time:124059ms step_avg:97.92ms +step:1268/1670 train_time:124157ms step_avg:97.92ms +step:1269/1670 train_time:124254ms step_avg:97.91ms +step:1270/1670 train_time:124352ms step_avg:97.91ms +step:1271/1670 train_time:124451ms step_avg:97.92ms +step:1272/1670 train_time:124551ms step_avg:97.92ms +step:1273/1670 train_time:124651ms step_avg:97.92ms +step:1274/1670 train_time:124995ms step_avg:98.11ms +step:1275/1670 train_time:125113ms step_avg:98.13ms +step:1276/1670 train_time:125208ms step_avg:98.13ms +step:1277/1670 train_time:125305ms step_avg:98.12ms +step:1278/1670 train_time:125402ms step_avg:98.12ms +step:1279/1670 train_time:125499ms step_avg:98.12ms +step:1280/1670 train_time:125596ms step_avg:98.12ms +step:1281/1670 train_time:125692ms step_avg:98.12ms +step:1282/1670 train_time:125790ms step_avg:98.12ms +step:1283/1670 train_time:125887ms step_avg:98.12ms +step:1284/1670 train_time:125990ms step_avg:98.12ms +step:1285/1670 train_time:126090ms step_avg:98.12ms +step:1286/1670 train_time:126191ms step_avg:98.13ms +step:1287/1670 train_time:126291ms step_avg:98.13ms +step:1288/1670 train_time:126389ms step_avg:98.13ms +step:1289/1670 train_time:126486ms step_avg:98.13ms +step:1290/1670 train_time:126584ms step_avg:98.13ms +step:1291/1670 train_time:126681ms step_avg:98.13ms +step:1292/1670 train_time:126779ms step_avg:98.13ms +step:1293/1670 train_time:126876ms step_avg:98.12ms +step:1294/1670 train_time:126974ms step_avg:98.13ms +step:1295/1670 train_time:127075ms step_avg:98.13ms +step:1296/1670 train_time:127176ms step_avg:98.13ms +step:1297/1670 train_time:127276ms step_avg:98.13ms +step:1298/1670 train_time:127376ms step_avg:98.13ms +step:1299/1670 train_time:127476ms step_avg:98.13ms +step:1300/1670 train_time:127575ms step_avg:98.13ms +step:1301/1670 train_time:127675ms step_avg:98.14ms +step:1302/1670 train_time:127773ms step_avg:98.14ms +step:1303/1670 train_time:127870ms step_avg:98.13ms +step:1304/1670 train_time:127967ms step_avg:98.13ms +step:1305/1670 train_time:128065ms step_avg:98.13ms +step:1306/1670 train_time:128163ms step_avg:98.13ms +step:1307/1670 train_time:128260ms step_avg:98.13ms +step:1308/1670 train_time:128359ms step_avg:98.13ms +step:1309/1670 
train_time:128458ms step_avg:98.13ms +step:1310/1670 train_time:128557ms step_avg:98.14ms +step:1311/1670 train_time:128656ms step_avg:98.14ms +step:1312/1670 train_time:128756ms step_avg:98.14ms +step:1313/1670 train_time:128854ms step_avg:98.14ms +step:1314/1670 train_time:128952ms step_avg:98.14ms +step:1315/1670 train_time:129050ms step_avg:98.14ms +step:1316/1670 train_time:129150ms step_avg:98.14ms +step:1317/1670 train_time:129249ms step_avg:98.14ms +step:1318/1670 train_time:129347ms step_avg:98.14ms +step:1319/1670 train_time:129445ms step_avg:98.14ms +step:1320/1670 train_time:129543ms step_avg:98.14ms +step:1321/1670 train_time:129642ms step_avg:98.14ms +step:1322/1670 train_time:129741ms step_avg:98.14ms +step:1323/1670 train_time:129840ms step_avg:98.14ms +step:1324/1670 train_time:129938ms step_avg:98.14ms +step:1325/1670 train_time:130037ms step_avg:98.14ms +step:1326/1670 train_time:130136ms step_avg:98.14ms +step:1327/1670 train_time:130235ms step_avg:98.14ms +step:1328/1670 train_time:130334ms step_avg:98.14ms +step:1329/1670 train_time:130434ms step_avg:98.14ms +step:1330/1670 train_time:130534ms step_avg:98.15ms +step:1331/1670 train_time:130632ms step_avg:98.15ms +step:1332/1670 train_time:130731ms step_avg:98.15ms +step:1333/1670 train_time:130830ms step_avg:98.15ms +step:1334/1670 train_time:130929ms step_avg:98.15ms +step:1335/1670 train_time:131026ms step_avg:98.15ms +step:1336/1670 train_time:131124ms step_avg:98.15ms +step:1337/1670 train_time:131222ms step_avg:98.15ms +step:1338/1670 train_time:131320ms step_avg:98.15ms +step:1339/1670 train_time:131419ms step_avg:98.15ms +step:1340/1670 train_time:131518ms step_avg:98.15ms +step:1341/1670 train_time:131617ms step_avg:98.15ms +step:1342/1670 train_time:131716ms step_avg:98.15ms +step:1343/1670 train_time:131815ms step_avg:98.15ms +step:1344/1670 train_time:131915ms step_avg:98.15ms +step:1345/1670 train_time:132012ms step_avg:98.15ms +step:1346/1670 train_time:132112ms step_avg:98.15ms +step:1347/1670 train_time:132210ms step_avg:98.15ms +step:1348/1670 train_time:132307ms step_avg:98.15ms +step:1349/1670 train_time:132405ms step_avg:98.15ms +step:1350/1670 train_time:132503ms step_avg:98.15ms +step:1351/1670 train_time:132602ms step_avg:98.15ms +step:1352/1670 train_time:132700ms step_avg:98.15ms +step:1353/1670 train_time:132799ms step_avg:98.15ms +step:1354/1670 train_time:132900ms step_avg:98.15ms +step:1355/1670 train_time:132999ms step_avg:98.15ms +step:1356/1670 train_time:133098ms step_avg:98.15ms +step:1357/1670 train_time:133196ms step_avg:98.15ms +step:1358/1670 train_time:133294ms step_avg:98.15ms +step:1359/1670 train_time:133392ms step_avg:98.15ms +step:1360/1670 train_time:133490ms step_avg:98.15ms +step:1361/1670 train_time:133588ms step_avg:98.15ms +step:1362/1670 train_time:133686ms step_avg:98.15ms +step:1363/1670 train_time:133784ms step_avg:98.15ms +step:1364/1670 train_time:133882ms step_avg:98.15ms +step:1365/1670 train_time:133981ms step_avg:98.15ms +step:1366/1670 train_time:134079ms step_avg:98.15ms +step:1367/1670 train_time:134177ms step_avg:98.15ms +step:1368/1670 train_time:134276ms step_avg:98.15ms +step:1369/1670 train_time:134376ms step_avg:98.16ms +step:1370/1670 train_time:134476ms step_avg:98.16ms +step:1371/1670 train_time:134574ms step_avg:98.16ms +step:1372/1670 train_time:134672ms step_avg:98.16ms +step:1373/1670 train_time:134770ms step_avg:98.16ms +step:1374/1670 train_time:134869ms step_avg:98.16ms +step:1375/1670 train_time:134966ms step_avg:98.16ms +step:1375/1670 
val_loss:3.3440 train_time:135063ms step_avg:98.23ms +step:1376/1670 train_time:135085ms step_avg:98.17ms +step:1377/1670 train_time:135170ms step_avg:98.16ms +step:1378/1670 train_time:135271ms step_avg:98.16ms +step:1379/1670 train_time:135371ms step_avg:98.17ms +step:1380/1670 train_time:135468ms step_avg:98.17ms +step:1381/1670 train_time:135565ms step_avg:98.16ms +step:1382/1670 train_time:135662ms step_avg:98.16ms +step:1383/1670 train_time:135760ms step_avg:98.16ms +step:1384/1670 train_time:135858ms step_avg:98.16ms +step:1385/1670 train_time:135955ms step_avg:98.16ms +step:1386/1670 train_time:136054ms step_avg:98.16ms +step:1387/1670 train_time:136155ms step_avg:98.17ms +step:1388/1670 train_time:136254ms step_avg:98.17ms +step:1389/1670 train_time:136353ms step_avg:98.17ms +step:1390/1670 train_time:136450ms step_avg:98.17ms +step:1391/1670 train_time:136548ms step_avg:98.16ms +step:1392/1670 train_time:136645ms step_avg:98.16ms +step:1393/1670 train_time:136743ms step_avg:98.16ms +step:1394/1670 train_time:136841ms step_avg:98.16ms +step:1395/1670 train_time:136939ms step_avg:98.16ms +step:1396/1670 train_time:137039ms step_avg:98.17ms +step:1397/1670 train_time:137139ms step_avg:98.17ms +step:1398/1670 train_time:137239ms step_avg:98.17ms +step:1399/1670 train_time:137339ms step_avg:98.17ms +step:1400/1670 train_time:137436ms step_avg:98.17ms +step:1401/1670 train_time:137535ms step_avg:98.17ms +step:1402/1670 train_time:137633ms step_avg:98.17ms +step:1403/1670 train_time:137731ms step_avg:98.17ms +step:1404/1670 train_time:137829ms step_avg:98.17ms +step:1405/1670 train_time:137928ms step_avg:98.17ms +step:1406/1670 train_time:138028ms step_avg:98.17ms +step:1407/1670 train_time:138127ms step_avg:98.17ms +step:1408/1670 train_time:138226ms step_avg:98.17ms +step:1409/1670 train_time:138325ms step_avg:98.17ms +step:1410/1670 train_time:138426ms step_avg:98.17ms +step:1411/1670 train_time:138526ms step_avg:98.18ms +step:1412/1670 train_time:138625ms step_avg:98.18ms +step:1413/1670 train_time:138724ms step_avg:98.18ms +step:1414/1670 train_time:138823ms step_avg:98.18ms +step:1415/1670 train_time:138920ms step_avg:98.18ms +step:1416/1670 train_time:139018ms step_avg:98.18ms +step:1417/1670 train_time:139116ms step_avg:98.18ms +step:1418/1670 train_time:139214ms step_avg:98.18ms +step:1419/1670 train_time:139313ms step_avg:98.18ms +step:1420/1670 train_time:139410ms step_avg:98.18ms +step:1421/1670 train_time:139509ms step_avg:98.18ms +step:1422/1670 train_time:139608ms step_avg:98.18ms +step:1423/1670 train_time:139707ms step_avg:98.18ms +step:1424/1670 train_time:139804ms step_avg:98.18ms +step:1425/1670 train_time:139903ms step_avg:98.18ms +step:1426/1670 train_time:140001ms step_avg:98.18ms +step:1427/1670 train_time:140099ms step_avg:98.18ms +step:1428/1670 train_time:140197ms step_avg:98.18ms +step:1429/1670 train_time:140294ms step_avg:98.18ms +step:1430/1670 train_time:140393ms step_avg:98.18ms +step:1431/1670 train_time:140491ms step_avg:98.18ms +step:1432/1670 train_time:140590ms step_avg:98.18ms +step:1433/1670 train_time:140688ms step_avg:98.18ms +step:1434/1670 train_time:140787ms step_avg:98.18ms +step:1435/1670 train_time:140885ms step_avg:98.18ms +step:1436/1670 train_time:140983ms step_avg:98.18ms +step:1437/1670 train_time:141082ms step_avg:98.18ms +step:1438/1670 train_time:141180ms step_avg:98.18ms +step:1439/1670 train_time:141278ms step_avg:98.18ms +step:1440/1670 train_time:141378ms step_avg:98.18ms +step:1441/1670 train_time:141476ms step_avg:98.18ms 
+step:1442/1670 train_time:141574ms step_avg:98.18ms +step:1443/1670 train_time:141671ms step_avg:98.18ms +step:1444/1670 train_time:141769ms step_avg:98.18ms +step:1445/1670 train_time:141867ms step_avg:98.18ms +step:1446/1670 train_time:141967ms step_avg:98.18ms +step:1447/1670 train_time:142066ms step_avg:98.18ms +step:1448/1670 train_time:142166ms step_avg:98.18ms +step:1449/1670 train_time:142266ms step_avg:98.18ms +step:1450/1670 train_time:142367ms step_avg:98.18ms +step:1451/1670 train_time:142466ms step_avg:98.18ms +step:1452/1670 train_time:142564ms step_avg:98.18ms +step:1453/1670 train_time:142664ms step_avg:98.19ms +step:1454/1670 train_time:142762ms step_avg:98.19ms +step:1455/1670 train_time:142861ms step_avg:98.19ms +step:1456/1670 train_time:142958ms step_avg:98.19ms +step:1457/1670 train_time:143055ms step_avg:98.18ms +step:1458/1670 train_time:143154ms step_avg:98.19ms +step:1459/1670 train_time:143253ms step_avg:98.19ms +step:1460/1670 train_time:143351ms step_avg:98.19ms +step:1461/1670 train_time:143451ms step_avg:98.19ms +step:1462/1670 train_time:143549ms step_avg:98.19ms +step:1463/1670 train_time:143648ms step_avg:98.19ms +step:1464/1670 train_time:143746ms step_avg:98.19ms +step:1465/1670 train_time:143845ms step_avg:98.19ms +step:1466/1670 train_time:143943ms step_avg:98.19ms +step:1467/1670 train_time:144040ms step_avg:98.19ms +step:1468/1670 train_time:144138ms step_avg:98.19ms +step:1469/1670 train_time:144237ms step_avg:98.19ms +step:1470/1670 train_time:144335ms step_avg:98.19ms +step:1471/1670 train_time:144433ms step_avg:98.19ms +step:1472/1670 train_time:144531ms step_avg:98.19ms +step:1473/1670 train_time:144629ms step_avg:98.19ms +step:1474/1670 train_time:144728ms step_avg:98.19ms +step:1475/1670 train_time:144825ms step_avg:98.19ms +step:1476/1670 train_time:144924ms step_avg:98.19ms +step:1477/1670 train_time:145021ms step_avg:98.19ms +step:1478/1670 train_time:145120ms step_avg:98.19ms +step:1479/1670 train_time:145218ms step_avg:98.19ms +step:1480/1670 train_time:145317ms step_avg:98.19ms +step:1481/1670 train_time:145415ms step_avg:98.19ms +step:1482/1670 train_time:145513ms step_avg:98.19ms +step:1483/1670 train_time:145611ms step_avg:98.19ms +step:1484/1670 train_time:145710ms step_avg:98.19ms +step:1485/1670 train_time:145969ms step_avg:98.30ms +step:1486/1670 train_time:146064ms step_avg:98.29ms +step:1487/1670 train_time:146160ms step_avg:98.29ms +step:1488/1670 train_time:146256ms step_avg:98.29ms +step:1489/1670 train_time:146353ms step_avg:98.29ms +step:1490/1670 train_time:146450ms step_avg:98.29ms +step:1491/1670 train_time:146548ms step_avg:98.29ms +step:1492/1670 train_time:146646ms step_avg:98.29ms +step:1493/1670 train_time:146742ms step_avg:98.29ms +step:1494/1670 train_time:146841ms step_avg:98.29ms +step:1495/1670 train_time:146943ms step_avg:98.29ms +step:1496/1670 train_time:147045ms step_avg:98.29ms +step:1497/1670 train_time:147144ms step_avg:98.29ms +step:1498/1670 train_time:147244ms step_avg:98.29ms +step:1499/1670 train_time:147342ms step_avg:98.29ms +step:1500/1670 train_time:147440ms step_avg:98.29ms +step:1500/1670 val_loss:3.3119 train_time:147536ms step_avg:98.36ms +step:1501/1670 train_time:147558ms step_avg:98.31ms +step:1502/1670 train_time:147641ms step_avg:98.30ms +step:1503/1670 train_time:147742ms step_avg:98.30ms +step:1504/1670 train_time:147839ms step_avg:98.30ms +step:1505/1670 train_time:147937ms step_avg:98.30ms +step:1506/1670 train_time:148034ms step_avg:98.30ms +step:1507/1670 train_time:148131ms 
step_avg:98.30ms +step:1508/1670 train_time:148228ms step_avg:98.29ms +step:1509/1670 train_time:148325ms step_avg:98.29ms +step:1510/1670 train_time:148423ms step_avg:98.29ms +step:1511/1670 train_time:148522ms step_avg:98.29ms +step:1512/1670 train_time:148622ms step_avg:98.30ms +step:1513/1670 train_time:148720ms step_avg:98.29ms +step:1514/1670 train_time:148819ms step_avg:98.29ms +step:1515/1670 train_time:148917ms step_avg:98.29ms +step:1516/1670 train_time:149014ms step_avg:98.29ms +step:1517/1670 train_time:149111ms step_avg:98.29ms +step:1518/1670 train_time:149208ms step_avg:98.29ms +step:1519/1670 train_time:149305ms step_avg:98.29ms +step:1520/1670 train_time:149404ms step_avg:98.29ms +step:1521/1670 train_time:149502ms step_avg:98.29ms +step:1522/1670 train_time:149602ms step_avg:98.29ms +step:1523/1670 train_time:149700ms step_avg:98.29ms +step:1524/1670 train_time:149799ms step_avg:98.29ms +step:1525/1670 train_time:149897ms step_avg:98.29ms +step:1526/1670 train_time:149995ms step_avg:98.29ms +step:1527/1670 train_time:150092ms step_avg:98.29ms +step:1528/1670 train_time:150190ms step_avg:98.29ms +step:1529/1670 train_time:150287ms step_avg:98.29ms +step:1530/1670 train_time:150385ms step_avg:98.29ms +step:1531/1670 train_time:150483ms step_avg:98.29ms +step:1532/1670 train_time:150581ms step_avg:98.29ms +step:1533/1670 train_time:150679ms step_avg:98.29ms +step:1534/1670 train_time:150777ms step_avg:98.29ms +step:1535/1670 train_time:150876ms step_avg:98.29ms +step:1536/1670 train_time:150974ms step_avg:98.29ms +step:1537/1670 train_time:151071ms step_avg:98.29ms +step:1538/1670 train_time:151169ms step_avg:98.29ms +step:1539/1670 train_time:151267ms step_avg:98.29ms +step:1540/1670 train_time:151365ms step_avg:98.29ms +step:1541/1670 train_time:151463ms step_avg:98.29ms +step:1542/1670 train_time:151561ms step_avg:98.29ms +step:1543/1670 train_time:151659ms step_avg:98.29ms +step:1544/1670 train_time:151757ms step_avg:98.29ms +step:1545/1670 train_time:151856ms step_avg:98.29ms +step:1546/1670 train_time:151953ms step_avg:98.29ms +step:1547/1670 train_time:152051ms step_avg:98.29ms +step:1548/1670 train_time:152149ms step_avg:98.29ms +step:1549/1670 train_time:152247ms step_avg:98.29ms +step:1550/1670 train_time:152345ms step_avg:98.29ms +step:1551/1670 train_time:152444ms step_avg:98.29ms +step:1552/1670 train_time:152543ms step_avg:98.29ms +step:1553/1670 train_time:152643ms step_avg:98.29ms +step:1554/1670 train_time:152742ms step_avg:98.29ms +step:1555/1670 train_time:152841ms step_avg:98.29ms +step:1556/1670 train_time:152938ms step_avg:98.29ms +step:1557/1670 train_time:153036ms step_avg:98.29ms +step:1558/1670 train_time:153134ms step_avg:98.29ms +step:1559/1670 train_time:153233ms step_avg:98.29ms +step:1560/1670 train_time:153331ms step_avg:98.29ms +step:1561/1670 train_time:153431ms step_avg:98.29ms +step:1562/1670 train_time:153531ms step_avg:98.29ms +step:1563/1670 train_time:153631ms step_avg:98.29ms +step:1564/1670 train_time:153731ms step_avg:98.29ms +step:1565/1670 train_time:153832ms step_avg:98.29ms +step:1566/1670 train_time:153932ms step_avg:98.30ms +step:1567/1670 train_time:154031ms step_avg:98.30ms +step:1568/1670 train_time:154130ms step_avg:98.30ms +step:1569/1670 train_time:154228ms step_avg:98.30ms +step:1570/1670 train_time:154326ms step_avg:98.30ms +step:1571/1670 train_time:154424ms step_avg:98.30ms +step:1572/1670 train_time:154522ms step_avg:98.30ms +step:1573/1670 train_time:154620ms step_avg:98.30ms +step:1574/1670 train_time:154719ms 
step_avg:98.30ms +step:1575/1670 train_time:154817ms step_avg:98.30ms +step:1576/1670 train_time:154915ms step_avg:98.30ms +step:1577/1670 train_time:155013ms step_avg:98.30ms +step:1578/1670 train_time:155112ms step_avg:98.30ms +step:1579/1670 train_time:155210ms step_avg:98.30ms +step:1580/1670 train_time:155308ms step_avg:98.30ms +step:1581/1670 train_time:155406ms step_avg:98.30ms +step:1582/1670 train_time:155504ms step_avg:98.30ms +step:1583/1670 train_time:155603ms step_avg:98.30ms +step:1584/1670 train_time:155701ms step_avg:98.30ms +step:1585/1670 train_time:155799ms step_avg:98.30ms +step:1586/1670 train_time:155897ms step_avg:98.30ms +step:1587/1670 train_time:155996ms step_avg:98.30ms +step:1588/1670 train_time:156093ms step_avg:98.30ms +step:1589/1670 train_time:156191ms step_avg:98.30ms +step:1590/1670 train_time:156290ms step_avg:98.30ms +step:1591/1670 train_time:156389ms step_avg:98.30ms +step:1592/1670 train_time:156489ms step_avg:98.30ms +step:1593/1670 train_time:156587ms step_avg:98.30ms +step:1594/1670 train_time:156688ms step_avg:98.30ms +step:1595/1670 train_time:156789ms step_avg:98.30ms +step:1596/1670 train_time:156890ms step_avg:98.30ms +step:1597/1670 train_time:156989ms step_avg:98.30ms +step:1598/1670 train_time:157088ms step_avg:98.30ms +step:1599/1670 train_time:157185ms step_avg:98.30ms +step:1600/1670 train_time:157282ms step_avg:98.30ms +step:1601/1670 train_time:157380ms step_avg:98.30ms +step:1602/1670 train_time:157477ms step_avg:98.30ms +step:1603/1670 train_time:157574ms step_avg:98.30ms +step:1604/1670 train_time:157674ms step_avg:98.30ms +step:1605/1670 train_time:157774ms step_avg:98.30ms +step:1606/1670 train_time:157874ms step_avg:98.30ms +step:1607/1670 train_time:157972ms step_avg:98.30ms +step:1608/1670 train_time:158071ms step_avg:98.30ms +step:1609/1670 train_time:158170ms step_avg:98.30ms +step:1610/1670 train_time:158268ms step_avg:98.30ms +step:1611/1670 train_time:158367ms step_avg:98.30ms +step:1612/1670 train_time:158464ms step_avg:98.30ms +step:1613/1670 train_time:158563ms step_avg:98.30ms +step:1614/1670 train_time:158662ms step_avg:98.30ms +step:1615/1670 train_time:158761ms step_avg:98.30ms +step:1616/1670 train_time:158859ms step_avg:98.30ms +step:1617/1670 train_time:158957ms step_avg:98.30ms +step:1618/1670 train_time:159056ms step_avg:98.30ms +step:1619/1670 train_time:159153ms step_avg:98.30ms +step:1620/1670 train_time:159252ms step_avg:98.30ms +step:1621/1670 train_time:159351ms step_avg:98.30ms +step:1622/1670 train_time:159449ms step_avg:98.30ms +step:1623/1670 train_time:159549ms step_avg:98.30ms +step:1624/1670 train_time:159648ms step_avg:98.31ms +step:1625/1670 train_time:159748ms step_avg:98.31ms +step:1625/1670 val_loss:3.2852 train_time:159848ms step_avg:98.37ms +step:1626/1670 train_time:159872ms step_avg:98.32ms +step:1627/1670 train_time:159956ms step_avg:98.31ms +step:1628/1670 train_time:160056ms step_avg:98.31ms +step:1629/1670 train_time:160154ms step_avg:98.31ms +step:1630/1670 train_time:160252ms step_avg:98.31ms +step:1631/1670 train_time:160349ms step_avg:98.31ms +step:1632/1670 train_time:160446ms step_avg:98.31ms +step:1633/1670 train_time:160542ms step_avg:98.31ms +step:1634/1670 train_time:160640ms step_avg:98.31ms +step:1635/1670 train_time:160738ms step_avg:98.31ms +step:1636/1670 train_time:160840ms step_avg:98.31ms +step:1637/1670 train_time:160942ms step_avg:98.31ms +step:1638/1670 train_time:161042ms step_avg:98.32ms +step:1639/1670 train_time:161142ms step_avg:98.32ms +step:1640/1670 
train_time:161240ms step_avg:98.32ms +step:1641/1670 train_time:161340ms step_avg:98.32ms +step:1642/1670 train_time:161438ms step_avg:98.32ms +step:1643/1670 train_time:161536ms step_avg:98.32ms +step:1644/1670 train_time:161634ms step_avg:98.32ms +step:1645/1670 train_time:161732ms step_avg:98.32ms +step:1646/1670 train_time:161830ms step_avg:98.32ms +step:1647/1670 train_time:161931ms step_avg:98.32ms +step:1648/1670 train_time:162031ms step_avg:98.32ms +step:1649/1670 train_time:162132ms step_avg:98.32ms +step:1650/1670 train_time:162232ms step_avg:98.32ms +step:1651/1670 train_time:162330ms step_avg:98.32ms +step:1652/1670 train_time:162428ms step_avg:98.32ms +step:1653/1670 train_time:162526ms step_avg:98.32ms +step:1654/1670 train_time:162624ms step_avg:98.32ms +step:1655/1670 train_time:162723ms step_avg:98.32ms +step:1656/1670 train_time:162821ms step_avg:98.32ms +step:1657/1670 train_time:162922ms step_avg:98.32ms +step:1658/1670 train_time:163021ms step_avg:98.32ms +step:1659/1670 train_time:163121ms step_avg:98.33ms +step:1660/1670 train_time:163221ms step_avg:98.33ms +step:1661/1670 train_time:163320ms step_avg:98.33ms +step:1662/1670 train_time:163418ms step_avg:98.33ms +step:1663/1670 train_time:163516ms step_avg:98.33ms +step:1664/1670 train_time:163616ms step_avg:98.33ms +step:1665/1670 train_time:163714ms step_avg:98.33ms +step:1666/1670 train_time:163814ms step_avg:98.33ms +step:1667/1670 train_time:163912ms step_avg:98.33ms +step:1668/1670 train_time:164012ms step_avg:98.33ms +step:1669/1670 train_time:164110ms step_avg:98.33ms +step:1670/1670 train_time:164209ms step_avg:98.33ms +step:1670/1670 val_loss:3.2772 train_time:164306ms step_avg:98.39ms +peak memory allocated: 34001 MiB reserved: 49136 MiB diff --git a/records/050925_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt b/records/050925_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt new file mode 100644 index 000000000..1d6a73fd4 --- /dev/null +++ b/records/050925_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt @@ -0,0 +1,2853 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + 
scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, 
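    # Reader's note (added commentary, not in the original record): the stride arguments
    # follow the ns_line_1() launcher below. *_stride_b is the batch stride, while
    # *_stride_r and *_stride_c are the row/column strides of A (the M x K input) and
    # C (the M x M output, C = A @ A.T). LOWER_UPPER selects which triangle each program
    # computes directly; the mirrored tl.store at the end of the kernel fills in the
    # opposite triangle, producing the full symmetric product.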
c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = 
tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
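        # Reader's note (added commentary, inferred from the code below, not in the
        # original record): parameters in each size-group are assigned to ranks in
        # blocks of world_size, so rank r owns params[base_i + r]. Each owner receives
        # its averaged gradient via an async reduce_scatter, applies decoupled weight
        # decay, Nesterov-style momentum, and Newton-Schulz orthogonalization (with the
        # shape-aware learning-rate scale max(1, rows/cols)**0.5), and an async
        # all_gather then redistributes the updated parameters to every rank.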
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK + self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return 
next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1705 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Fri Sep 5 16:18:22 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 45C P0 129W / 700W | 5826MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 35C P0 119W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 45C P0 127W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 35C P0 117W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 44C P0 131W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1516MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 85777 C /usr/bin/python3 1506MiB | +| 0 N/A N/A 85778 C /usr/bin/python3 610MiB | +| 0 N/A N/A 85779 C /usr/bin/python3 610MiB | +| 0 N/A N/A 85780 C /usr/bin/python3 610MiB | +| 0 N/A N/A 85781 C /usr/bin/python3 610MiB | +| 0 N/A N/A 85782 C /usr/bin/python3 610MiB | +| 0 N/A N/A 85783 C /usr/bin/python3 610MiB | +| 0 N/A N/A 85784 C /usr/bin/python3 610MiB | +| 1 N/A N/A 85778 C /usr/bin/python3 1506MiB | +| 2 N/A N/A 85779 C /usr/bin/python3 1506MiB | +| 3 N/A N/A 85780 C /usr/bin/python3 1506MiB | +| 4 N/A N/A 85781 C /usr/bin/python3 1506MiB | +| 5 N/A N/A 85782 C /usr/bin/python3 1506MiB | +| 6 N/A N/A 85783 C /usr/bin/python3 1506MiB | +| 7 N/A N/A 85784 C /usr/bin/python3 1506MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1705 val_loss:10.8258 train_time:0ms step_avg:0.04ms +step:1/1705 train_time:431ms step_avg:431.48ms +step:2/1705 train_time:450ms step_avg:225.20ms +step:3/1705 train_time:520ms step_avg:173.49ms +step:4/1705 train_time:612ms step_avg:152.98ms +step:5/1705 train_time:704ms step_avg:140.78ms +step:6/1705 train_time:796ms step_avg:132.63ms +step:7/1705 train_time:888ms step_avg:126.83ms +step:8/1705 
train_time:980ms step_avg:122.52ms +step:9/1705 train_time:1073ms step_avg:119.21ms +step:10/1705 train_time:1165ms step_avg:116.51ms +step:11/1705 train_time:1258ms step_avg:114.32ms +step:12/1705 train_time:1351ms step_avg:112.62ms +step:13/1705 train_time:1446ms step_avg:111.26ms +step:14/1705 train_time:1540ms step_avg:110.03ms +step:15/1705 train_time:1634ms step_avg:108.94ms +step:16/1705 train_time:1727ms step_avg:107.91ms +step:17/1705 train_time:1820ms step_avg:107.06ms +step:18/1705 train_time:1913ms step_avg:106.27ms +step:19/1705 train_time:2006ms step_avg:105.57ms +step:20/1705 train_time:2099ms step_avg:104.95ms +step:21/1705 train_time:2192ms step_avg:104.36ms +step:22/1705 train_time:2284ms step_avg:103.84ms +step:23/1705 train_time:2379ms step_avg:103.44ms +step:24/1705 train_time:2475ms step_avg:103.14ms +step:25/1705 train_time:2567ms step_avg:102.68ms +step:26/1705 train_time:2661ms step_avg:102.34ms +step:27/1705 train_time:2754ms step_avg:102.01ms +step:28/1705 train_time:2847ms step_avg:101.67ms +step:29/1705 train_time:2941ms step_avg:101.40ms +step:30/1705 train_time:3034ms step_avg:101.13ms +step:31/1705 train_time:3126ms step_avg:100.84ms +step:32/1705 train_time:3220ms step_avg:100.63ms +step:33/1705 train_time:3313ms step_avg:100.38ms +step:34/1705 train_time:3406ms step_avg:100.17ms +step:35/1705 train_time:3500ms step_avg:100.00ms +step:36/1705 train_time:3594ms step_avg:99.84ms +step:37/1705 train_time:3687ms step_avg:99.65ms +step:38/1705 train_time:3780ms step_avg:99.48ms +step:39/1705 train_time:3873ms step_avg:99.30ms +step:40/1705 train_time:3966ms step_avg:99.14ms +step:41/1705 train_time:4059ms step_avg:99.00ms +step:42/1705 train_time:4151ms step_avg:98.83ms +step:43/1705 train_time:4244ms step_avg:98.69ms +step:44/1705 train_time:4338ms step_avg:98.59ms +step:45/1705 train_time:4431ms step_avg:98.46ms +step:46/1705 train_time:4525ms step_avg:98.37ms +step:47/1705 train_time:4618ms step_avg:98.26ms +step:48/1705 train_time:4712ms step_avg:98.16ms +step:49/1705 train_time:4805ms step_avg:98.05ms +step:50/1705 train_time:4899ms step_avg:97.97ms +step:51/1705 train_time:4991ms step_avg:97.87ms +step:52/1705 train_time:5084ms step_avg:97.77ms +step:53/1705 train_time:5177ms step_avg:97.68ms +step:54/1705 train_time:5270ms step_avg:97.58ms +step:55/1705 train_time:5363ms step_avg:97.51ms +step:56/1705 train_time:5457ms step_avg:97.45ms +step:57/1705 train_time:5549ms step_avg:97.36ms +step:58/1705 train_time:5643ms step_avg:97.29ms +step:59/1705 train_time:5737ms step_avg:97.24ms +step:60/1705 train_time:5831ms step_avg:97.18ms +step:61/1705 train_time:5923ms step_avg:97.11ms +step:62/1705 train_time:6016ms step_avg:97.04ms +step:63/1705 train_time:6109ms step_avg:96.97ms +step:64/1705 train_time:6202ms step_avg:96.91ms +step:65/1705 train_time:6295ms step_avg:96.85ms +step:66/1705 train_time:6388ms step_avg:96.78ms +step:67/1705 train_time:6481ms step_avg:96.74ms +step:68/1705 train_time:6575ms step_avg:96.69ms +step:69/1705 train_time:6668ms step_avg:96.63ms +step:70/1705 train_time:6761ms step_avg:96.59ms +step:71/1705 train_time:6855ms step_avg:96.55ms +step:72/1705 train_time:6947ms step_avg:96.49ms +step:73/1705 train_time:7041ms step_avg:96.46ms +step:74/1705 train_time:7135ms step_avg:96.42ms +step:75/1705 train_time:7227ms step_avg:96.36ms +step:76/1705 train_time:7320ms step_avg:96.31ms +step:77/1705 train_time:7413ms step_avg:96.27ms +step:78/1705 train_time:7505ms step_avg:96.22ms +step:79/1705 train_time:7599ms step_avg:96.18ms +step:80/1705 
train_time:7692ms step_avg:96.14ms +step:81/1705 train_time:7785ms step_avg:96.11ms +step:82/1705 train_time:7878ms step_avg:96.08ms +step:83/1705 train_time:7972ms step_avg:96.04ms +step:84/1705 train_time:8065ms step_avg:96.02ms +step:85/1705 train_time:8159ms step_avg:95.99ms +step:86/1705 train_time:8253ms step_avg:95.96ms +step:87/1705 train_time:8345ms step_avg:95.92ms +step:88/1705 train_time:8440ms step_avg:95.91ms +step:89/1705 train_time:8532ms step_avg:95.87ms +step:90/1705 train_time:8625ms step_avg:95.84ms +step:91/1705 train_time:8719ms step_avg:95.81ms +step:92/1705 train_time:8813ms step_avg:95.79ms +step:93/1705 train_time:8906ms step_avg:95.76ms +step:94/1705 train_time:8999ms step_avg:95.74ms +step:95/1705 train_time:9092ms step_avg:95.70ms +step:96/1705 train_time:9184ms step_avg:95.67ms +step:97/1705 train_time:9278ms step_avg:95.65ms +step:98/1705 train_time:9371ms step_avg:95.62ms +step:99/1705 train_time:9465ms step_avg:95.60ms +step:100/1705 train_time:9558ms step_avg:95.58ms +step:101/1705 train_time:9650ms step_avg:95.55ms +step:102/1705 train_time:9744ms step_avg:95.53ms +step:103/1705 train_time:9836ms step_avg:95.50ms +step:104/1705 train_time:9930ms step_avg:95.48ms +step:105/1705 train_time:10024ms step_avg:95.46ms +step:106/1705 train_time:10117ms step_avg:95.44ms +step:107/1705 train_time:10209ms step_avg:95.41ms +step:108/1705 train_time:10302ms step_avg:95.39ms +step:109/1705 train_time:10395ms step_avg:95.37ms +step:110/1705 train_time:10488ms step_avg:95.34ms +step:111/1705 train_time:10581ms step_avg:95.32ms +step:112/1705 train_time:10674ms step_avg:95.30ms +step:113/1705 train_time:10766ms step_avg:95.28ms +step:114/1705 train_time:10860ms step_avg:95.27ms +step:115/1705 train_time:10953ms step_avg:95.25ms +step:116/1705 train_time:11046ms step_avg:95.23ms +step:117/1705 train_time:11139ms step_avg:95.21ms +step:118/1705 train_time:11233ms step_avg:95.19ms +step:119/1705 train_time:11326ms step_avg:95.17ms +step:120/1705 train_time:11419ms step_avg:95.16ms +step:121/1705 train_time:11512ms step_avg:95.14ms +step:122/1705 train_time:11604ms step_avg:95.12ms +step:123/1705 train_time:11697ms step_avg:95.10ms +step:124/1705 train_time:11790ms step_avg:95.08ms +step:125/1705 train_time:11884ms step_avg:95.07ms +step:125/1705 val_loss:4.3069 train_time:11977ms step_avg:95.82ms +step:126/1705 train_time:11999ms step_avg:95.23ms +step:127/1705 train_time:12076ms step_avg:95.09ms +step:128/1705 train_time:12179ms step_avg:95.15ms +step:129/1705 train_time:12276ms step_avg:95.16ms +step:130/1705 train_time:12369ms step_avg:95.15ms +step:131/1705 train_time:12461ms step_avg:95.12ms +step:132/1705 train_time:12553ms step_avg:95.10ms +step:133/1705 train_time:12645ms step_avg:95.08ms +step:134/1705 train_time:12737ms step_avg:95.05ms +step:135/1705 train_time:12829ms step_avg:95.03ms +step:136/1705 train_time:12921ms step_avg:95.01ms +step:137/1705 train_time:13014ms step_avg:94.99ms +step:138/1705 train_time:13110ms step_avg:95.00ms +step:139/1705 train_time:13203ms step_avg:94.99ms +step:140/1705 train_time:13297ms step_avg:94.98ms +step:141/1705 train_time:13391ms step_avg:94.97ms +step:142/1705 train_time:13484ms step_avg:94.96ms +step:143/1705 train_time:13577ms step_avg:94.94ms +step:144/1705 train_time:13669ms step_avg:94.92ms +step:145/1705 train_time:13761ms step_avg:94.90ms +step:146/1705 train_time:13853ms step_avg:94.88ms +step:147/1705 train_time:13945ms step_avg:94.87ms +step:148/1705 train_time:14038ms step_avg:94.85ms +step:149/1705 
train_time:14132ms step_avg:94.84ms +step:150/1705 train_time:14226ms step_avg:94.84ms +step:151/1705 train_time:14320ms step_avg:94.83ms +step:152/1705 train_time:14413ms step_avg:94.82ms +step:153/1705 train_time:14506ms step_avg:94.81ms +step:154/1705 train_time:14598ms step_avg:94.79ms +step:155/1705 train_time:14691ms step_avg:94.78ms +step:156/1705 train_time:14784ms step_avg:94.77ms +step:157/1705 train_time:14877ms step_avg:94.76ms +step:158/1705 train_time:14971ms step_avg:94.75ms +step:159/1705 train_time:15063ms step_avg:94.74ms +step:160/1705 train_time:15157ms step_avg:94.73ms +step:161/1705 train_time:15250ms step_avg:94.72ms +step:162/1705 train_time:15343ms step_avg:94.71ms +step:163/1705 train_time:15437ms step_avg:94.70ms +step:164/1705 train_time:15530ms step_avg:94.70ms +step:165/1705 train_time:15622ms step_avg:94.68ms +step:166/1705 train_time:15715ms step_avg:94.67ms +step:167/1705 train_time:15808ms step_avg:94.66ms +step:168/1705 train_time:15900ms step_avg:94.64ms +step:169/1705 train_time:15994ms step_avg:94.64ms +step:170/1705 train_time:16087ms step_avg:94.63ms +step:171/1705 train_time:16180ms step_avg:94.62ms +step:172/1705 train_time:16274ms step_avg:94.62ms +step:173/1705 train_time:16368ms step_avg:94.61ms +step:174/1705 train_time:16460ms step_avg:94.60ms +step:175/1705 train_time:16553ms step_avg:94.59ms +step:176/1705 train_time:16646ms step_avg:94.58ms +step:177/1705 train_time:16739ms step_avg:94.57ms +step:178/1705 train_time:16832ms step_avg:94.56ms +step:179/1705 train_time:16925ms step_avg:94.55ms +step:180/1705 train_time:17017ms step_avg:94.54ms +step:181/1705 train_time:17111ms step_avg:94.53ms +step:182/1705 train_time:17202ms step_avg:94.52ms +step:183/1705 train_time:17296ms step_avg:94.51ms +step:184/1705 train_time:17390ms step_avg:94.51ms +step:185/1705 train_time:17483ms step_avg:94.50ms +step:186/1705 train_time:17576ms step_avg:94.49ms +step:187/1705 train_time:17669ms step_avg:94.49ms +step:188/1705 train_time:17762ms step_avg:94.48ms +step:189/1705 train_time:17854ms step_avg:94.47ms +step:190/1705 train_time:17947ms step_avg:94.46ms +step:191/1705 train_time:18040ms step_avg:94.45ms +step:192/1705 train_time:18133ms step_avg:94.44ms +step:193/1705 train_time:18227ms step_avg:94.44ms +step:194/1705 train_time:18318ms step_avg:94.42ms +step:195/1705 train_time:18412ms step_avg:94.42ms +step:196/1705 train_time:18505ms step_avg:94.41ms +step:197/1705 train_time:18598ms step_avg:94.41ms +step:198/1705 train_time:18691ms step_avg:94.40ms +step:199/1705 train_time:18785ms step_avg:94.40ms +step:200/1705 train_time:18878ms step_avg:94.39ms +step:201/1705 train_time:18970ms step_avg:94.38ms +step:202/1705 train_time:19062ms step_avg:94.37ms +step:203/1705 train_time:19156ms step_avg:94.36ms +step:204/1705 train_time:19249ms step_avg:94.36ms +step:205/1705 train_time:19342ms step_avg:94.35ms +step:206/1705 train_time:19434ms step_avg:94.34ms +step:207/1705 train_time:19527ms step_avg:94.33ms +step:208/1705 train_time:19620ms step_avg:94.33ms +step:209/1705 train_time:19714ms step_avg:94.33ms +step:210/1705 train_time:19807ms step_avg:94.32ms +step:211/1705 train_time:19900ms step_avg:94.31ms +step:212/1705 train_time:19992ms step_avg:94.30ms +step:213/1705 train_time:20316ms step_avg:95.38ms +step:214/1705 train_time:20421ms step_avg:95.43ms +step:215/1705 train_time:20513ms step_avg:95.41ms +step:216/1705 train_time:20605ms step_avg:95.39ms +step:217/1705 train_time:20697ms step_avg:95.38ms +step:218/1705 train_time:20789ms step_avg:95.36ms 
+step:219/1705 train_time:20881ms step_avg:95.35ms +step:220/1705 train_time:20974ms step_avg:95.33ms +step:221/1705 train_time:21065ms step_avg:95.32ms +step:222/1705 train_time:21157ms step_avg:95.30ms +step:223/1705 train_time:21252ms step_avg:95.30ms +step:224/1705 train_time:21348ms step_avg:95.30ms +step:225/1705 train_time:21443ms step_avg:95.30ms +step:226/1705 train_time:21535ms step_avg:95.29ms +step:227/1705 train_time:21629ms step_avg:95.28ms +step:228/1705 train_time:21721ms step_avg:95.27ms +step:229/1705 train_time:21813ms step_avg:95.25ms +step:230/1705 train_time:21906ms step_avg:95.24ms +step:231/1705 train_time:21997ms step_avg:95.23ms +step:232/1705 train_time:22089ms step_avg:95.21ms +step:233/1705 train_time:22182ms step_avg:95.20ms +step:234/1705 train_time:22277ms step_avg:95.20ms +step:235/1705 train_time:22372ms step_avg:95.20ms +step:236/1705 train_time:22466ms step_avg:95.19ms +step:237/1705 train_time:22558ms step_avg:95.18ms +step:238/1705 train_time:22651ms step_avg:95.17ms +step:239/1705 train_time:22744ms step_avg:95.16ms +step:240/1705 train_time:22837ms step_avg:95.15ms +step:241/1705 train_time:22929ms step_avg:95.14ms +step:242/1705 train_time:23022ms step_avg:95.13ms +step:243/1705 train_time:23115ms step_avg:95.12ms +step:244/1705 train_time:23208ms step_avg:95.11ms +step:245/1705 train_time:23300ms step_avg:95.10ms +step:246/1705 train_time:23396ms step_avg:95.10ms +step:247/1705 train_time:23490ms step_avg:95.10ms +step:248/1705 train_time:23583ms step_avg:95.09ms +step:249/1705 train_time:23675ms step_avg:95.08ms +step:250/1705 train_time:23768ms step_avg:95.07ms +step:250/1705 val_loss:3.9798 train_time:23860ms step_avg:95.44ms +step:251/1705 train_time:23882ms step_avg:95.15ms +step:252/1705 train_time:23957ms step_avg:95.07ms +step:253/1705 train_time:24055ms step_avg:95.08ms +step:254/1705 train_time:24149ms step_avg:95.07ms +step:255/1705 train_time:24241ms step_avg:95.06ms +step:256/1705 train_time:24334ms step_avg:95.05ms +step:257/1705 train_time:24426ms step_avg:95.04ms +step:258/1705 train_time:24518ms step_avg:95.03ms +step:259/1705 train_time:24610ms step_avg:95.02ms +step:260/1705 train_time:24702ms step_avg:95.01ms +step:261/1705 train_time:24794ms step_avg:95.00ms +step:262/1705 train_time:24888ms step_avg:94.99ms +step:263/1705 train_time:24982ms step_avg:94.99ms +step:264/1705 train_time:25077ms step_avg:94.99ms +step:265/1705 train_time:25171ms step_avg:94.98ms +step:266/1705 train_time:25264ms step_avg:94.98ms +step:267/1705 train_time:25356ms step_avg:94.97ms +step:268/1705 train_time:25449ms step_avg:94.96ms +step:269/1705 train_time:25541ms step_avg:94.95ms +step:270/1705 train_time:25633ms step_avg:94.94ms +step:271/1705 train_time:25725ms step_avg:94.93ms +step:272/1705 train_time:25818ms step_avg:94.92ms +step:273/1705 train_time:25912ms step_avg:94.92ms +step:274/1705 train_time:26006ms step_avg:94.91ms +step:275/1705 train_time:26100ms step_avg:94.91ms +step:276/1705 train_time:26193ms step_avg:94.90ms +step:277/1705 train_time:26286ms step_avg:94.89ms +step:278/1705 train_time:26378ms step_avg:94.88ms +step:279/1705 train_time:26471ms step_avg:94.88ms +step:280/1705 train_time:26563ms step_avg:94.87ms +step:281/1705 train_time:26656ms step_avg:94.86ms +step:282/1705 train_time:26749ms step_avg:94.85ms +step:283/1705 train_time:26841ms step_avg:94.85ms +step:284/1705 train_time:26934ms step_avg:94.84ms +step:285/1705 train_time:27028ms step_avg:94.84ms +step:286/1705 train_time:27122ms step_avg:94.83ms +step:287/1705 
train_time:27216ms step_avg:94.83ms +step:288/1705 train_time:27309ms step_avg:94.82ms +step:289/1705 train_time:27401ms step_avg:94.81ms +step:290/1705 train_time:27494ms step_avg:94.81ms +step:291/1705 train_time:27586ms step_avg:94.80ms +step:292/1705 train_time:27679ms step_avg:94.79ms +step:293/1705 train_time:27772ms step_avg:94.79ms +step:294/1705 train_time:27865ms step_avg:94.78ms +step:295/1705 train_time:27958ms step_avg:94.77ms +step:296/1705 train_time:28051ms step_avg:94.77ms +step:297/1705 train_time:28145ms step_avg:94.77ms +step:298/1705 train_time:28239ms step_avg:94.76ms +step:299/1705 train_time:28332ms step_avg:94.76ms +step:300/1705 train_time:28425ms step_avg:94.75ms +step:301/1705 train_time:28518ms step_avg:94.74ms +step:302/1705 train_time:28610ms step_avg:94.73ms +step:303/1705 train_time:28702ms step_avg:94.73ms +step:304/1705 train_time:28795ms step_avg:94.72ms +step:305/1705 train_time:28888ms step_avg:94.71ms +step:306/1705 train_time:28980ms step_avg:94.71ms +step:307/1705 train_time:29074ms step_avg:94.70ms +step:308/1705 train_time:29169ms step_avg:94.70ms +step:309/1705 train_time:29262ms step_avg:94.70ms +step:310/1705 train_time:29355ms step_avg:94.69ms +step:311/1705 train_time:29449ms step_avg:94.69ms +step:312/1705 train_time:29543ms step_avg:94.69ms +step:313/1705 train_time:29635ms step_avg:94.68ms +step:314/1705 train_time:29728ms step_avg:94.68ms +step:315/1705 train_time:29821ms step_avg:94.67ms +step:316/1705 train_time:29913ms step_avg:94.66ms +step:317/1705 train_time:30006ms step_avg:94.66ms +step:318/1705 train_time:30099ms step_avg:94.65ms +step:319/1705 train_time:30192ms step_avg:94.65ms +step:320/1705 train_time:30285ms step_avg:94.64ms +step:321/1705 train_time:30378ms step_avg:94.64ms +step:322/1705 train_time:30471ms step_avg:94.63ms +step:323/1705 train_time:30564ms step_avg:94.63ms +step:324/1705 train_time:30656ms step_avg:94.62ms +step:325/1705 train_time:30749ms step_avg:94.61ms +step:326/1705 train_time:30843ms step_avg:94.61ms +step:327/1705 train_time:30935ms step_avg:94.60ms +step:328/1705 train_time:31029ms step_avg:94.60ms +step:329/1705 train_time:31120ms step_avg:94.59ms +step:330/1705 train_time:31213ms step_avg:94.58ms +step:331/1705 train_time:31307ms step_avg:94.58ms +step:332/1705 train_time:31399ms step_avg:94.57ms +step:333/1705 train_time:31492ms step_avg:94.57ms +step:334/1705 train_time:31585ms step_avg:94.57ms +step:335/1705 train_time:31677ms step_avg:94.56ms +step:336/1705 train_time:31770ms step_avg:94.55ms +step:337/1705 train_time:31863ms step_avg:94.55ms +step:338/1705 train_time:31956ms step_avg:94.54ms +step:339/1705 train_time:32050ms step_avg:94.54ms +step:340/1705 train_time:32143ms step_avg:94.54ms +step:341/1705 train_time:32236ms step_avg:94.53ms +step:342/1705 train_time:32329ms step_avg:94.53ms +step:343/1705 train_time:32423ms step_avg:94.53ms +step:344/1705 train_time:32515ms step_avg:94.52ms +step:345/1705 train_time:32609ms step_avg:94.52ms +step:346/1705 train_time:32702ms step_avg:94.51ms +step:347/1705 train_time:32794ms step_avg:94.51ms +step:348/1705 train_time:32887ms step_avg:94.50ms +step:349/1705 train_time:32980ms step_avg:94.50ms +step:350/1705 train_time:33073ms step_avg:94.50ms +step:351/1705 train_time:33167ms step_avg:94.49ms +step:352/1705 train_time:33259ms step_avg:94.48ms +step:353/1705 train_time:33352ms step_avg:94.48ms +step:354/1705 train_time:33445ms step_avg:94.48ms +step:355/1705 train_time:33538ms step_avg:94.47ms +step:356/1705 train_time:33630ms step_avg:94.47ms 
+step:357/1705 train_time:33723ms step_avg:94.46ms +step:358/1705 train_time:33816ms step_avg:94.46ms +step:359/1705 train_time:33909ms step_avg:94.45ms +step:360/1705 train_time:34001ms step_avg:94.45ms +step:361/1705 train_time:34094ms step_avg:94.44ms +step:362/1705 train_time:34187ms step_avg:94.44ms +step:363/1705 train_time:34280ms step_avg:94.43ms +step:364/1705 train_time:34373ms step_avg:94.43ms +step:365/1705 train_time:34466ms step_avg:94.43ms +step:366/1705 train_time:34558ms step_avg:94.42ms +step:367/1705 train_time:34652ms step_avg:94.42ms +step:368/1705 train_time:34745ms step_avg:94.42ms +step:369/1705 train_time:34837ms step_avg:94.41ms +step:370/1705 train_time:34930ms step_avg:94.41ms +step:371/1705 train_time:35023ms step_avg:94.40ms +step:372/1705 train_time:35116ms step_avg:94.40ms +step:373/1705 train_time:35209ms step_avg:94.39ms +step:374/1705 train_time:35302ms step_avg:94.39ms +step:375/1705 train_time:35395ms step_avg:94.39ms +step:375/1705 val_loss:3.8258 train_time:35489ms step_avg:94.64ms +step:376/1705 train_time:35510ms step_avg:94.44ms +step:377/1705 train_time:35587ms step_avg:94.39ms +step:378/1705 train_time:35685ms step_avg:94.41ms +step:379/1705 train_time:35779ms step_avg:94.40ms +step:380/1705 train_time:35871ms step_avg:94.40ms +step:381/1705 train_time:35963ms step_avg:94.39ms +step:382/1705 train_time:36055ms step_avg:94.39ms +step:383/1705 train_time:36147ms step_avg:94.38ms +step:384/1705 train_time:36239ms step_avg:94.37ms +step:385/1705 train_time:36331ms step_avg:94.37ms +step:386/1705 train_time:36424ms step_avg:94.36ms +step:387/1705 train_time:36519ms step_avg:94.36ms +step:388/1705 train_time:36614ms step_avg:94.37ms +step:389/1705 train_time:36709ms step_avg:94.37ms +step:390/1705 train_time:36802ms step_avg:94.36ms +step:391/1705 train_time:36895ms step_avg:94.36ms +step:392/1705 train_time:36988ms step_avg:94.36ms +step:393/1705 train_time:37081ms step_avg:94.35ms +step:394/1705 train_time:37173ms step_avg:94.35ms +step:395/1705 train_time:37265ms step_avg:94.34ms +step:396/1705 train_time:37358ms step_avg:94.34ms +step:397/1705 train_time:37450ms step_avg:94.33ms +step:398/1705 train_time:37544ms step_avg:94.33ms +step:399/1705 train_time:37639ms step_avg:94.33ms +step:400/1705 train_time:37732ms step_avg:94.33ms +step:401/1705 train_time:37825ms step_avg:94.33ms +step:402/1705 train_time:37919ms step_avg:94.33ms +step:403/1705 train_time:38011ms step_avg:94.32ms +step:404/1705 train_time:38104ms step_avg:94.32ms +step:405/1705 train_time:38196ms step_avg:94.31ms +step:406/1705 train_time:38290ms step_avg:94.31ms +step:407/1705 train_time:38383ms step_avg:94.31ms +step:408/1705 train_time:38476ms step_avg:94.30ms +step:409/1705 train_time:38569ms step_avg:94.30ms +step:410/1705 train_time:38662ms step_avg:94.30ms +step:411/1705 train_time:38755ms step_avg:94.29ms +step:412/1705 train_time:38848ms step_avg:94.29ms +step:413/1705 train_time:38941ms step_avg:94.29ms +step:414/1705 train_time:39034ms step_avg:94.29ms +step:415/1705 train_time:39126ms step_avg:94.28ms +step:416/1705 train_time:39219ms step_avg:94.28ms +step:417/1705 train_time:39312ms step_avg:94.27ms +step:418/1705 train_time:39405ms step_avg:94.27ms +step:419/1705 train_time:39499ms step_avg:94.27ms +step:420/1705 train_time:39593ms step_avg:94.27ms +step:421/1705 train_time:39686ms step_avg:94.27ms +step:422/1705 train_time:39780ms step_avg:94.27ms +step:423/1705 train_time:39874ms step_avg:94.26ms +step:424/1705 train_time:39967ms step_avg:94.26ms +step:425/1705 
train_time:40241ms step_avg:94.68ms +step:426/1705 train_time:40348ms step_avg:94.71ms +step:427/1705 train_time:40439ms step_avg:94.70ms +step:428/1705 train_time:40530ms step_avg:94.70ms +step:429/1705 train_time:40622ms step_avg:94.69ms +step:430/1705 train_time:40715ms step_avg:94.69ms +step:431/1705 train_time:40807ms step_avg:94.68ms +step:432/1705 train_time:40899ms step_avg:94.67ms +step:433/1705 train_time:40991ms step_avg:94.67ms +step:434/1705 train_time:41083ms step_avg:94.66ms +step:435/1705 train_time:41177ms step_avg:94.66ms +step:436/1705 train_time:41272ms step_avg:94.66ms +step:437/1705 train_time:41367ms step_avg:94.66ms +step:438/1705 train_time:41461ms step_avg:94.66ms +step:439/1705 train_time:41554ms step_avg:94.66ms +step:440/1705 train_time:41646ms step_avg:94.65ms +step:441/1705 train_time:41738ms step_avg:94.64ms +step:442/1705 train_time:41830ms step_avg:94.64ms +step:443/1705 train_time:41922ms step_avg:94.63ms +step:444/1705 train_time:42015ms step_avg:94.63ms +step:445/1705 train_time:42107ms step_avg:94.62ms +step:446/1705 train_time:42201ms step_avg:94.62ms +step:447/1705 train_time:42296ms step_avg:94.62ms +step:448/1705 train_time:42389ms step_avg:94.62ms +step:449/1705 train_time:42484ms step_avg:94.62ms +step:450/1705 train_time:42577ms step_avg:94.62ms +step:451/1705 train_time:42669ms step_avg:94.61ms +step:452/1705 train_time:42762ms step_avg:94.61ms +step:453/1705 train_time:42854ms step_avg:94.60ms +step:454/1705 train_time:42946ms step_avg:94.59ms +step:455/1705 train_time:43038ms step_avg:94.59ms +step:456/1705 train_time:43130ms step_avg:94.58ms +step:457/1705 train_time:43224ms step_avg:94.58ms +step:458/1705 train_time:43319ms step_avg:94.58ms +step:459/1705 train_time:43412ms step_avg:94.58ms +step:460/1705 train_time:43505ms step_avg:94.58ms +step:461/1705 train_time:43599ms step_avg:94.58ms +step:462/1705 train_time:43692ms step_avg:94.57ms +step:463/1705 train_time:43785ms step_avg:94.57ms +step:464/1705 train_time:43877ms step_avg:94.56ms +step:465/1705 train_time:43969ms step_avg:94.56ms +step:466/1705 train_time:44061ms step_avg:94.55ms +step:467/1705 train_time:44155ms step_avg:94.55ms +step:468/1705 train_time:44248ms step_avg:94.55ms +step:469/1705 train_time:44341ms step_avg:94.54ms +step:470/1705 train_time:44435ms step_avg:94.54ms +step:471/1705 train_time:44528ms step_avg:94.54ms +step:472/1705 train_time:44621ms step_avg:94.54ms +step:473/1705 train_time:44714ms step_avg:94.53ms +step:474/1705 train_time:44806ms step_avg:94.53ms +step:475/1705 train_time:44900ms step_avg:94.53ms +step:476/1705 train_time:44993ms step_avg:94.52ms +step:477/1705 train_time:45085ms step_avg:94.52ms +step:478/1705 train_time:45179ms step_avg:94.52ms +step:479/1705 train_time:45272ms step_avg:94.51ms +step:480/1705 train_time:45365ms step_avg:94.51ms +step:481/1705 train_time:45459ms step_avg:94.51ms +step:482/1705 train_time:45552ms step_avg:94.51ms +step:483/1705 train_time:45644ms step_avg:94.50ms +step:484/1705 train_time:45737ms step_avg:94.50ms +step:485/1705 train_time:45831ms step_avg:94.50ms +step:486/1705 train_time:45923ms step_avg:94.49ms +step:487/1705 train_time:46017ms step_avg:94.49ms +step:488/1705 train_time:46109ms step_avg:94.49ms +step:489/1705 train_time:46202ms step_avg:94.48ms +step:490/1705 train_time:46296ms step_avg:94.48ms +step:491/1705 train_time:46389ms step_avg:94.48ms +step:492/1705 train_time:46483ms step_avg:94.48ms +step:493/1705 train_time:46576ms step_avg:94.47ms +step:494/1705 train_time:46669ms step_avg:94.47ms 
+step:495/1705 train_time:46762ms step_avg:94.47ms +step:496/1705 train_time:46855ms step_avg:94.47ms +step:497/1705 train_time:46948ms step_avg:94.46ms +step:498/1705 train_time:47040ms step_avg:94.46ms +step:499/1705 train_time:47133ms step_avg:94.46ms +step:500/1705 train_time:47227ms step_avg:94.45ms +step:500/1705 val_loss:3.7225 train_time:47320ms step_avg:94.64ms +step:501/1705 train_time:47342ms step_avg:94.49ms +step:502/1705 train_time:47419ms step_avg:94.46ms +step:503/1705 train_time:47516ms step_avg:94.47ms +step:504/1705 train_time:47610ms step_avg:94.46ms +step:505/1705 train_time:47702ms step_avg:94.46ms +step:506/1705 train_time:47795ms step_avg:94.46ms +step:507/1705 train_time:47887ms step_avg:94.45ms +step:508/1705 train_time:47979ms step_avg:94.45ms +step:509/1705 train_time:48071ms step_avg:94.44ms +step:510/1705 train_time:48163ms step_avg:94.44ms +step:511/1705 train_time:48255ms step_avg:94.43ms +step:512/1705 train_time:48351ms step_avg:94.44ms +step:513/1705 train_time:48447ms step_avg:94.44ms +step:514/1705 train_time:48541ms step_avg:94.44ms +step:515/1705 train_time:48635ms step_avg:94.44ms +step:516/1705 train_time:48727ms step_avg:94.43ms +step:517/1705 train_time:48820ms step_avg:94.43ms +step:518/1705 train_time:48912ms step_avg:94.42ms +step:519/1705 train_time:49005ms step_avg:94.42ms +step:520/1705 train_time:49097ms step_avg:94.42ms +step:521/1705 train_time:49189ms step_avg:94.41ms +step:522/1705 train_time:49282ms step_avg:94.41ms +step:523/1705 train_time:49377ms step_avg:94.41ms +step:524/1705 train_time:49471ms step_avg:94.41ms +step:525/1705 train_time:49565ms step_avg:94.41ms +step:526/1705 train_time:49658ms step_avg:94.41ms +step:527/1705 train_time:49751ms step_avg:94.40ms +step:528/1705 train_time:49844ms step_avg:94.40ms +step:529/1705 train_time:49936ms step_avg:94.40ms +step:530/1705 train_time:50029ms step_avg:94.39ms +step:531/1705 train_time:50121ms step_avg:94.39ms +step:532/1705 train_time:50214ms step_avg:94.39ms +step:533/1705 train_time:50306ms step_avg:94.38ms +step:534/1705 train_time:50400ms step_avg:94.38ms +step:535/1705 train_time:50493ms step_avg:94.38ms +step:536/1705 train_time:50587ms step_avg:94.38ms +step:537/1705 train_time:50680ms step_avg:94.38ms +step:538/1705 train_time:50774ms step_avg:94.37ms +step:539/1705 train_time:50867ms step_avg:94.37ms +step:540/1705 train_time:50959ms step_avg:94.37ms +step:541/1705 train_time:51052ms step_avg:94.37ms +step:542/1705 train_time:51145ms step_avg:94.36ms +step:543/1705 train_time:51237ms step_avg:94.36ms +step:544/1705 train_time:51330ms step_avg:94.36ms +step:545/1705 train_time:51423ms step_avg:94.35ms +step:546/1705 train_time:51516ms step_avg:94.35ms +step:547/1705 train_time:51609ms step_avg:94.35ms +step:548/1705 train_time:51702ms step_avg:94.35ms +step:549/1705 train_time:51795ms step_avg:94.34ms +step:550/1705 train_time:51888ms step_avg:94.34ms +step:551/1705 train_time:51980ms step_avg:94.34ms +step:552/1705 train_time:52073ms step_avg:94.34ms +step:553/1705 train_time:52167ms step_avg:94.33ms +step:554/1705 train_time:52259ms step_avg:94.33ms +step:555/1705 train_time:52353ms step_avg:94.33ms +step:556/1705 train_time:52446ms step_avg:94.33ms +step:557/1705 train_time:52539ms step_avg:94.32ms +step:558/1705 train_time:52632ms step_avg:94.32ms +step:559/1705 train_time:52725ms step_avg:94.32ms +step:560/1705 train_time:52817ms step_avg:94.32ms +step:561/1705 train_time:52911ms step_avg:94.32ms +step:562/1705 train_time:53004ms step_avg:94.31ms +step:563/1705 
train_time:53096ms step_avg:94.31ms +step:564/1705 train_time:53189ms step_avg:94.31ms +step:565/1705 train_time:53282ms step_avg:94.30ms +step:566/1705 train_time:53374ms step_avg:94.30ms +step:567/1705 train_time:53468ms step_avg:94.30ms +step:568/1705 train_time:53561ms step_avg:94.30ms +step:569/1705 train_time:53654ms step_avg:94.30ms +step:570/1705 train_time:53748ms step_avg:94.29ms +step:571/1705 train_time:53842ms step_avg:94.29ms +step:572/1705 train_time:53937ms step_avg:94.29ms +step:573/1705 train_time:54031ms step_avg:94.29ms +step:574/1705 train_time:54126ms step_avg:94.30ms +step:575/1705 train_time:54220ms step_avg:94.30ms +step:576/1705 train_time:54314ms step_avg:94.30ms +step:577/1705 train_time:54409ms step_avg:94.30ms +step:578/1705 train_time:54504ms step_avg:94.30ms +step:579/1705 train_time:54597ms step_avg:94.30ms +step:580/1705 train_time:54692ms step_avg:94.30ms +step:581/1705 train_time:54788ms step_avg:94.30ms +step:582/1705 train_time:54882ms step_avg:94.30ms +step:583/1705 train_time:54977ms step_avg:94.30ms +step:584/1705 train_time:55072ms step_avg:94.30ms +step:585/1705 train_time:55166ms step_avg:94.30ms +step:586/1705 train_time:55259ms step_avg:94.30ms +step:587/1705 train_time:55353ms step_avg:94.30ms +step:588/1705 train_time:55448ms step_avg:94.30ms +step:589/1705 train_time:55542ms step_avg:94.30ms +step:590/1705 train_time:55636ms step_avg:94.30ms +step:591/1705 train_time:55732ms step_avg:94.30ms +step:592/1705 train_time:55827ms step_avg:94.30ms +step:593/1705 train_time:55920ms step_avg:94.30ms +step:594/1705 train_time:56015ms step_avg:94.30ms +step:595/1705 train_time:56111ms step_avg:94.30ms +step:596/1705 train_time:56205ms step_avg:94.30ms +step:597/1705 train_time:56299ms step_avg:94.30ms +step:598/1705 train_time:56394ms step_avg:94.30ms +step:599/1705 train_time:56488ms step_avg:94.30ms +step:600/1705 train_time:56582ms step_avg:94.30ms +step:601/1705 train_time:56676ms step_avg:94.30ms +step:602/1705 train_time:56772ms step_avg:94.31ms +step:603/1705 train_time:56866ms step_avg:94.31ms +step:604/1705 train_time:56960ms step_avg:94.30ms +step:605/1705 train_time:57054ms step_avg:94.30ms +step:606/1705 train_time:57150ms step_avg:94.31ms +step:607/1705 train_time:57245ms step_avg:94.31ms +step:608/1705 train_time:57339ms step_avg:94.31ms +step:609/1705 train_time:57434ms step_avg:94.31ms +step:610/1705 train_time:57528ms step_avg:94.31ms +step:611/1705 train_time:57622ms step_avg:94.31ms +step:612/1705 train_time:57716ms step_avg:94.31ms +step:613/1705 train_time:57811ms step_avg:94.31ms +step:614/1705 train_time:57906ms step_avg:94.31ms +step:615/1705 train_time:57999ms step_avg:94.31ms +step:616/1705 train_time:58093ms step_avg:94.31ms +step:617/1705 train_time:58187ms step_avg:94.31ms +step:618/1705 train_time:58281ms step_avg:94.31ms +step:619/1705 train_time:58375ms step_avg:94.31ms +step:620/1705 train_time:58470ms step_avg:94.31ms +step:621/1705 train_time:58564ms step_avg:94.31ms +step:622/1705 train_time:58659ms step_avg:94.31ms +step:623/1705 train_time:58753ms step_avg:94.31ms +step:624/1705 train_time:58848ms step_avg:94.31ms +step:625/1705 train_time:58942ms step_avg:94.31ms +step:625/1705 val_loss:3.6215 train_time:59037ms step_avg:94.46ms +step:626/1705 train_time:59060ms step_avg:94.35ms +step:627/1705 train_time:59132ms step_avg:94.31ms +step:628/1705 train_time:59227ms step_avg:94.31ms +step:629/1705 train_time:59331ms step_avg:94.33ms +step:630/1705 train_time:59427ms step_avg:94.33ms +step:631/1705 train_time:59521ms 
step_avg:94.33ms +step:632/1705 train_time:59614ms step_avg:94.33ms +step:633/1705 train_time:59708ms step_avg:94.32ms +step:634/1705 train_time:59802ms step_avg:94.32ms +step:635/1705 train_time:59895ms step_avg:94.32ms +step:636/1705 train_time:59990ms step_avg:94.32ms +step:637/1705 train_time:60085ms step_avg:94.33ms +step:638/1705 train_time:60180ms step_avg:94.33ms +step:639/1705 train_time:60537ms step_avg:94.74ms +step:640/1705 train_time:60635ms step_avg:94.74ms +step:641/1705 train_time:60728ms step_avg:94.74ms +step:642/1705 train_time:60821ms step_avg:94.74ms +step:643/1705 train_time:60915ms step_avg:94.74ms +step:644/1705 train_time:61008ms step_avg:94.73ms +step:645/1705 train_time:61102ms step_avg:94.73ms +step:646/1705 train_time:61194ms step_avg:94.73ms +step:647/1705 train_time:61287ms step_avg:94.73ms +step:648/1705 train_time:61381ms step_avg:94.72ms +step:649/1705 train_time:61478ms step_avg:94.73ms +step:650/1705 train_time:61575ms step_avg:94.73ms +step:651/1705 train_time:61669ms step_avg:94.73ms +step:652/1705 train_time:61764ms step_avg:94.73ms +step:653/1705 train_time:61858ms step_avg:94.73ms +step:654/1705 train_time:61952ms step_avg:94.73ms +step:655/1705 train_time:62046ms step_avg:94.73ms +step:656/1705 train_time:62140ms step_avg:94.73ms +step:657/1705 train_time:62232ms step_avg:94.72ms +step:658/1705 train_time:62326ms step_avg:94.72ms +step:659/1705 train_time:62422ms step_avg:94.72ms +step:660/1705 train_time:62518ms step_avg:94.72ms +step:661/1705 train_time:62614ms step_avg:94.73ms +step:662/1705 train_time:62708ms step_avg:94.73ms +step:663/1705 train_time:62803ms step_avg:94.73ms +step:664/1705 train_time:62898ms step_avg:94.73ms +step:665/1705 train_time:62992ms step_avg:94.72ms +step:666/1705 train_time:63085ms step_avg:94.72ms +step:667/1705 train_time:63180ms step_avg:94.72ms +step:668/1705 train_time:63274ms step_avg:94.72ms +step:669/1705 train_time:63367ms step_avg:94.72ms +step:670/1705 train_time:63462ms step_avg:94.72ms +step:671/1705 train_time:63557ms step_avg:94.72ms +step:672/1705 train_time:63652ms step_avg:94.72ms +step:673/1705 train_time:63747ms step_avg:94.72ms +step:674/1705 train_time:63841ms step_avg:94.72ms +step:675/1705 train_time:63937ms step_avg:94.72ms +step:676/1705 train_time:64031ms step_avg:94.72ms +step:677/1705 train_time:64124ms step_avg:94.72ms +step:678/1705 train_time:64219ms step_avg:94.72ms +step:679/1705 train_time:64311ms step_avg:94.71ms +step:680/1705 train_time:64406ms step_avg:94.71ms +step:681/1705 train_time:64501ms step_avg:94.71ms +step:682/1705 train_time:64595ms step_avg:94.71ms +step:683/1705 train_time:64690ms step_avg:94.71ms +step:684/1705 train_time:64784ms step_avg:94.71ms +step:685/1705 train_time:64880ms step_avg:94.71ms +step:686/1705 train_time:64975ms step_avg:94.72ms +step:687/1705 train_time:65068ms step_avg:94.71ms +step:688/1705 train_time:65162ms step_avg:94.71ms +step:689/1705 train_time:65256ms step_avg:94.71ms +step:690/1705 train_time:65350ms step_avg:94.71ms +step:691/1705 train_time:65444ms step_avg:94.71ms +step:692/1705 train_time:65539ms step_avg:94.71ms +step:693/1705 train_time:65634ms step_avg:94.71ms +step:694/1705 train_time:65728ms step_avg:94.71ms +step:695/1705 train_time:65825ms step_avg:94.71ms +step:696/1705 train_time:65919ms step_avg:94.71ms +step:697/1705 train_time:66015ms step_avg:94.71ms +step:698/1705 train_time:66108ms step_avg:94.71ms +step:699/1705 train_time:66202ms step_avg:94.71ms +step:700/1705 train_time:66296ms step_avg:94.71ms +step:701/1705 
train_time:66390ms step_avg:94.71ms +step:702/1705 train_time:66485ms step_avg:94.71ms +step:703/1705 train_time:66579ms step_avg:94.71ms +step:704/1705 train_time:66674ms step_avg:94.71ms +step:705/1705 train_time:66768ms step_avg:94.71ms +step:706/1705 train_time:66863ms step_avg:94.71ms +step:707/1705 train_time:66958ms step_avg:94.71ms +step:708/1705 train_time:67052ms step_avg:94.71ms +step:709/1705 train_time:67146ms step_avg:94.71ms +step:710/1705 train_time:67242ms step_avg:94.71ms +step:711/1705 train_time:67337ms step_avg:94.71ms +step:712/1705 train_time:67431ms step_avg:94.71ms +step:713/1705 train_time:67525ms step_avg:94.71ms +step:714/1705 train_time:67620ms step_avg:94.71ms +step:715/1705 train_time:67715ms step_avg:94.71ms +step:716/1705 train_time:67808ms step_avg:94.70ms +step:717/1705 train_time:67904ms step_avg:94.71ms +step:718/1705 train_time:67999ms step_avg:94.71ms +step:719/1705 train_time:68093ms step_avg:94.71ms +step:720/1705 train_time:68187ms step_avg:94.70ms +step:721/1705 train_time:68281ms step_avg:94.70ms +step:722/1705 train_time:68376ms step_avg:94.70ms +step:723/1705 train_time:68470ms step_avg:94.70ms +step:724/1705 train_time:68564ms step_avg:94.70ms +step:725/1705 train_time:68659ms step_avg:94.70ms +step:726/1705 train_time:68754ms step_avg:94.70ms +step:727/1705 train_time:68849ms step_avg:94.70ms +step:728/1705 train_time:68943ms step_avg:94.70ms +step:729/1705 train_time:69038ms step_avg:94.70ms +step:730/1705 train_time:69132ms step_avg:94.70ms +step:731/1705 train_time:69226ms step_avg:94.70ms +step:732/1705 train_time:69320ms step_avg:94.70ms +step:733/1705 train_time:69416ms step_avg:94.70ms +step:734/1705 train_time:69509ms step_avg:94.70ms +step:735/1705 train_time:69604ms step_avg:94.70ms +step:736/1705 train_time:69699ms step_avg:94.70ms +step:737/1705 train_time:69794ms step_avg:94.70ms +step:738/1705 train_time:69888ms step_avg:94.70ms +step:739/1705 train_time:69982ms step_avg:94.70ms +step:740/1705 train_time:70077ms step_avg:94.70ms +step:741/1705 train_time:70171ms step_avg:94.70ms +step:742/1705 train_time:70265ms step_avg:94.70ms +step:743/1705 train_time:70360ms step_avg:94.70ms +step:744/1705 train_time:70455ms step_avg:94.70ms +step:745/1705 train_time:70549ms step_avg:94.70ms +step:746/1705 train_time:70644ms step_avg:94.70ms +step:747/1705 train_time:70739ms step_avg:94.70ms +step:748/1705 train_time:70834ms step_avg:94.70ms +step:749/1705 train_time:70928ms step_avg:94.70ms +step:750/1705 train_time:71023ms step_avg:94.70ms +step:750/1705 val_loss:3.5671 train_time:71119ms step_avg:94.82ms +step:751/1705 train_time:71139ms step_avg:94.73ms +step:752/1705 train_time:71217ms step_avg:94.70ms +step:753/1705 train_time:71315ms step_avg:94.71ms +step:754/1705 train_time:71412ms step_avg:94.71ms +step:755/1705 train_time:71506ms step_avg:94.71ms +step:756/1705 train_time:71599ms step_avg:94.71ms +step:757/1705 train_time:71692ms step_avg:94.71ms +step:758/1705 train_time:71786ms step_avg:94.70ms +step:759/1705 train_time:71879ms step_avg:94.70ms +step:760/1705 train_time:71973ms step_avg:94.70ms +step:761/1705 train_time:72067ms step_avg:94.70ms +step:762/1705 train_time:72162ms step_avg:94.70ms +step:763/1705 train_time:72259ms step_avg:94.70ms +step:764/1705 train_time:72355ms step_avg:94.71ms +step:765/1705 train_time:72450ms step_avg:94.71ms +step:766/1705 train_time:72544ms step_avg:94.71ms +step:767/1705 train_time:72638ms step_avg:94.70ms +step:768/1705 train_time:72732ms step_avg:94.70ms +step:769/1705 train_time:72826ms 
step_avg:94.70ms +step:770/1705 train_time:72919ms step_avg:94.70ms +step:771/1705 train_time:73012ms step_avg:94.70ms +step:772/1705 train_time:73107ms step_avg:94.70ms +step:773/1705 train_time:73202ms step_avg:94.70ms +step:774/1705 train_time:73299ms step_avg:94.70ms +step:775/1705 train_time:73394ms step_avg:94.70ms +step:776/1705 train_time:73489ms step_avg:94.70ms +step:777/1705 train_time:73584ms step_avg:94.70ms +step:778/1705 train_time:73677ms step_avg:94.70ms +step:779/1705 train_time:73771ms step_avg:94.70ms +step:780/1705 train_time:73865ms step_avg:94.70ms +step:781/1705 train_time:73960ms step_avg:94.70ms +step:782/1705 train_time:74053ms step_avg:94.70ms +step:783/1705 train_time:74149ms step_avg:94.70ms +step:784/1705 train_time:74246ms step_avg:94.70ms +step:785/1705 train_time:74342ms step_avg:94.70ms +step:786/1705 train_time:74437ms step_avg:94.70ms +step:787/1705 train_time:74532ms step_avg:94.70ms +step:788/1705 train_time:74626ms step_avg:94.70ms +step:789/1705 train_time:74721ms step_avg:94.70ms +step:790/1705 train_time:74815ms step_avg:94.70ms +step:791/1705 train_time:74909ms step_avg:94.70ms +step:792/1705 train_time:75004ms step_avg:94.70ms +step:793/1705 train_time:75097ms step_avg:94.70ms +step:794/1705 train_time:75191ms step_avg:94.70ms +step:795/1705 train_time:75287ms step_avg:94.70ms +step:796/1705 train_time:75382ms step_avg:94.70ms +step:797/1705 train_time:75477ms step_avg:94.70ms +step:798/1705 train_time:75571ms step_avg:94.70ms +step:799/1705 train_time:75665ms step_avg:94.70ms +step:800/1705 train_time:75761ms step_avg:94.70ms +step:801/1705 train_time:75855ms step_avg:94.70ms +step:802/1705 train_time:75950ms step_avg:94.70ms +step:803/1705 train_time:76045ms step_avg:94.70ms +step:804/1705 train_time:76139ms step_avg:94.70ms +step:805/1705 train_time:76234ms step_avg:94.70ms +step:806/1705 train_time:76328ms step_avg:94.70ms +step:807/1705 train_time:76425ms step_avg:94.70ms +step:808/1705 train_time:76519ms step_avg:94.70ms +step:809/1705 train_time:76613ms step_avg:94.70ms +step:810/1705 train_time:76708ms step_avg:94.70ms +step:811/1705 train_time:76802ms step_avg:94.70ms +step:812/1705 train_time:76896ms step_avg:94.70ms +step:813/1705 train_time:76991ms step_avg:94.70ms +step:814/1705 train_time:77086ms step_avg:94.70ms +step:815/1705 train_time:77181ms step_avg:94.70ms +step:816/1705 train_time:77275ms step_avg:94.70ms +step:817/1705 train_time:77369ms step_avg:94.70ms +step:818/1705 train_time:77464ms step_avg:94.70ms +step:819/1705 train_time:77559ms step_avg:94.70ms +step:820/1705 train_time:77653ms step_avg:94.70ms +step:821/1705 train_time:77747ms step_avg:94.70ms +step:822/1705 train_time:77842ms step_avg:94.70ms +step:823/1705 train_time:77936ms step_avg:94.70ms +step:824/1705 train_time:78030ms step_avg:94.70ms +step:825/1705 train_time:78125ms step_avg:94.70ms +step:826/1705 train_time:78220ms step_avg:94.70ms +step:827/1705 train_time:78314ms step_avg:94.70ms +step:828/1705 train_time:78408ms step_avg:94.70ms +step:829/1705 train_time:78503ms step_avg:94.70ms +step:830/1705 train_time:78598ms step_avg:94.70ms +step:831/1705 train_time:78693ms step_avg:94.70ms +step:832/1705 train_time:78788ms step_avg:94.70ms +step:833/1705 train_time:78883ms step_avg:94.70ms +step:834/1705 train_time:78978ms step_avg:94.70ms +step:835/1705 train_time:79071ms step_avg:94.70ms +step:836/1705 train_time:79168ms step_avg:94.70ms +step:837/1705 train_time:79262ms step_avg:94.70ms +step:838/1705 train_time:79356ms step_avg:94.70ms +step:839/1705 
train_time:79450ms step_avg:94.70ms +step:840/1705 train_time:79547ms step_avg:94.70ms +step:841/1705 train_time:79641ms step_avg:94.70ms +step:842/1705 train_time:79735ms step_avg:94.70ms +step:843/1705 train_time:79829ms step_avg:94.70ms +step:844/1705 train_time:79925ms step_avg:94.70ms +step:845/1705 train_time:80020ms step_avg:94.70ms +step:846/1705 train_time:80113ms step_avg:94.70ms +step:847/1705 train_time:80207ms step_avg:94.70ms +step:848/1705 train_time:80302ms step_avg:94.70ms +step:849/1705 train_time:80396ms step_avg:94.69ms +step:850/1705 train_time:80490ms step_avg:94.69ms +step:851/1705 train_time:80770ms step_avg:94.91ms +step:852/1705 train_time:80872ms step_avg:94.92ms +step:853/1705 train_time:80966ms step_avg:94.92ms +step:854/1705 train_time:81059ms step_avg:94.92ms +step:855/1705 train_time:81152ms step_avg:94.91ms +step:856/1705 train_time:81246ms step_avg:94.91ms +step:857/1705 train_time:81339ms step_avg:94.91ms +step:858/1705 train_time:81433ms step_avg:94.91ms +step:859/1705 train_time:81526ms step_avg:94.91ms +step:860/1705 train_time:81620ms step_avg:94.91ms +step:861/1705 train_time:81716ms step_avg:94.91ms +step:862/1705 train_time:81814ms step_avg:94.91ms +step:863/1705 train_time:81910ms step_avg:94.91ms +step:864/1705 train_time:82005ms step_avg:94.91ms +step:865/1705 train_time:82099ms step_avg:94.91ms +step:866/1705 train_time:82193ms step_avg:94.91ms +step:867/1705 train_time:82286ms step_avg:94.91ms +step:868/1705 train_time:82380ms step_avg:94.91ms +step:869/1705 train_time:82473ms step_avg:94.91ms +step:870/1705 train_time:82567ms step_avg:94.90ms +step:871/1705 train_time:82661ms step_avg:94.90ms +step:872/1705 train_time:82756ms step_avg:94.90ms +step:873/1705 train_time:82851ms step_avg:94.90ms +step:874/1705 train_time:82947ms step_avg:94.90ms +step:875/1705 train_time:83043ms step_avg:94.91ms +step:875/1705 val_loss:3.5249 train_time:83139ms step_avg:95.02ms +step:876/1705 train_time:83160ms step_avg:94.93ms +step:877/1705 train_time:83238ms step_avg:94.91ms +step:878/1705 train_time:83337ms step_avg:94.92ms +step:879/1705 train_time:83432ms step_avg:94.92ms +step:880/1705 train_time:83528ms step_avg:94.92ms +step:881/1705 train_time:83621ms step_avg:94.92ms +step:882/1705 train_time:83714ms step_avg:94.91ms +step:883/1705 train_time:83808ms step_avg:94.91ms +step:884/1705 train_time:83901ms step_avg:94.91ms +step:885/1705 train_time:83995ms step_avg:94.91ms +step:886/1705 train_time:84090ms step_avg:94.91ms +step:887/1705 train_time:84187ms step_avg:94.91ms +step:888/1705 train_time:84284ms step_avg:94.91ms +step:889/1705 train_time:84380ms step_avg:94.92ms +step:890/1705 train_time:84474ms step_avg:94.91ms +step:891/1705 train_time:84568ms step_avg:94.91ms +step:892/1705 train_time:84662ms step_avg:94.91ms +step:893/1705 train_time:84756ms step_avg:94.91ms +step:894/1705 train_time:84849ms step_avg:94.91ms +step:895/1705 train_time:84943ms step_avg:94.91ms +step:896/1705 train_time:85037ms step_avg:94.91ms +step:897/1705 train_time:85131ms step_avg:94.91ms +step:898/1705 train_time:85227ms step_avg:94.91ms +step:899/1705 train_time:85323ms step_avg:94.91ms +step:900/1705 train_time:85417ms step_avg:94.91ms +step:901/1705 train_time:85512ms step_avg:94.91ms +step:902/1705 train_time:85607ms step_avg:94.91ms +step:903/1705 train_time:85701ms step_avg:94.91ms +step:904/1705 train_time:85795ms step_avg:94.91ms +step:905/1705 train_time:85889ms step_avg:94.91ms +step:906/1705 train_time:85984ms step_avg:94.91ms +step:907/1705 train_time:86077ms 
step_avg:94.90ms +step:908/1705 train_time:86172ms step_avg:94.90ms +step:909/1705 train_time:86267ms step_avg:94.90ms +step:910/1705 train_time:86363ms step_avg:94.90ms +step:911/1705 train_time:86457ms step_avg:94.90ms +step:912/1705 train_time:86552ms step_avg:94.90ms +step:913/1705 train_time:86647ms step_avg:94.90ms +step:914/1705 train_time:86742ms step_avg:94.90ms +step:915/1705 train_time:86836ms step_avg:94.90ms +step:916/1705 train_time:86930ms step_avg:94.90ms +step:917/1705 train_time:87024ms step_avg:94.90ms +step:918/1705 train_time:87118ms step_avg:94.90ms +step:919/1705 train_time:87212ms step_avg:94.90ms +step:920/1705 train_time:87308ms step_avg:94.90ms +step:921/1705 train_time:87405ms step_avg:94.90ms +step:922/1705 train_time:87499ms step_avg:94.90ms +step:923/1705 train_time:87594ms step_avg:94.90ms +step:924/1705 train_time:87689ms step_avg:94.90ms +step:925/1705 train_time:87783ms step_avg:94.90ms +step:926/1705 train_time:87876ms step_avg:94.90ms +step:927/1705 train_time:87970ms step_avg:94.90ms +step:928/1705 train_time:88064ms step_avg:94.90ms +step:929/1705 train_time:88158ms step_avg:94.90ms +step:930/1705 train_time:88253ms step_avg:94.90ms +step:931/1705 train_time:88348ms step_avg:94.90ms +step:932/1705 train_time:88443ms step_avg:94.90ms +step:933/1705 train_time:88538ms step_avg:94.90ms +step:934/1705 train_time:88633ms step_avg:94.90ms +step:935/1705 train_time:88728ms step_avg:94.90ms +step:936/1705 train_time:88823ms step_avg:94.90ms +step:937/1705 train_time:88916ms step_avg:94.89ms +step:938/1705 train_time:89010ms step_avg:94.89ms +step:939/1705 train_time:89105ms step_avg:94.89ms +step:940/1705 train_time:89200ms step_avg:94.89ms +step:941/1705 train_time:89294ms step_avg:94.89ms +step:942/1705 train_time:89388ms step_avg:94.89ms +step:943/1705 train_time:89483ms step_avg:94.89ms +step:944/1705 train_time:89578ms step_avg:94.89ms +step:945/1705 train_time:89672ms step_avg:94.89ms +step:946/1705 train_time:89768ms step_avg:94.89ms +step:947/1705 train_time:89862ms step_avg:94.89ms +step:948/1705 train_time:89956ms step_avg:94.89ms +step:949/1705 train_time:90050ms step_avg:94.89ms +step:950/1705 train_time:90145ms step_avg:94.89ms +step:951/1705 train_time:90239ms step_avg:94.89ms +step:952/1705 train_time:90334ms step_avg:94.89ms +step:953/1705 train_time:90429ms step_avg:94.89ms +step:954/1705 train_time:90524ms step_avg:94.89ms +step:955/1705 train_time:90618ms step_avg:94.89ms +step:956/1705 train_time:90712ms step_avg:94.89ms +step:957/1705 train_time:90808ms step_avg:94.89ms +step:958/1705 train_time:90903ms step_avg:94.89ms +step:959/1705 train_time:90997ms step_avg:94.89ms +step:960/1705 train_time:91091ms step_avg:94.89ms +step:961/1705 train_time:91186ms step_avg:94.89ms +step:962/1705 train_time:91281ms step_avg:94.89ms +step:963/1705 train_time:91375ms step_avg:94.89ms +step:964/1705 train_time:91469ms step_avg:94.89ms +step:965/1705 train_time:91564ms step_avg:94.88ms +step:966/1705 train_time:91658ms step_avg:94.88ms +step:967/1705 train_time:91753ms step_avg:94.88ms +step:968/1705 train_time:91848ms step_avg:94.88ms +step:969/1705 train_time:91943ms step_avg:94.88ms +step:970/1705 train_time:92036ms step_avg:94.88ms +step:971/1705 train_time:92131ms step_avg:94.88ms +step:972/1705 train_time:92225ms step_avg:94.88ms +step:973/1705 train_time:92320ms step_avg:94.88ms +step:974/1705 train_time:92414ms step_avg:94.88ms +step:975/1705 train_time:92509ms step_avg:94.88ms +step:976/1705 train_time:92603ms step_avg:94.88ms +step:977/1705 
train_time:92697ms step_avg:94.88ms +step:978/1705 train_time:92791ms step_avg:94.88ms +step:979/1705 train_time:92886ms step_avg:94.88ms +step:980/1705 train_time:92981ms step_avg:94.88ms +step:981/1705 train_time:93075ms step_avg:94.88ms +step:982/1705 train_time:93170ms step_avg:94.88ms +step:983/1705 train_time:93265ms step_avg:94.88ms +step:984/1705 train_time:93358ms step_avg:94.88ms +step:985/1705 train_time:93453ms step_avg:94.88ms +step:986/1705 train_time:93548ms step_avg:94.88ms +step:987/1705 train_time:93643ms step_avg:94.88ms +step:988/1705 train_time:93738ms step_avg:94.88ms +step:989/1705 train_time:93833ms step_avg:94.88ms +step:990/1705 train_time:93928ms step_avg:94.88ms +step:991/1705 train_time:94023ms step_avg:94.88ms +step:992/1705 train_time:94117ms step_avg:94.88ms +step:993/1705 train_time:94212ms step_avg:94.88ms +step:994/1705 train_time:94306ms step_avg:94.88ms +step:995/1705 train_time:94400ms step_avg:94.87ms +step:996/1705 train_time:94494ms step_avg:94.87ms +step:997/1705 train_time:94590ms step_avg:94.87ms +step:998/1705 train_time:94685ms step_avg:94.87ms +step:999/1705 train_time:94779ms step_avg:94.87ms +step:1000/1705 train_time:94873ms step_avg:94.87ms +step:1000/1705 val_loss:3.4856 train_time:94969ms step_avg:94.97ms +step:1001/1705 train_time:94990ms step_avg:94.90ms +step:1002/1705 train_time:95068ms step_avg:94.88ms +step:1003/1705 train_time:95169ms step_avg:94.88ms +step:1004/1705 train_time:95263ms step_avg:94.88ms +step:1005/1705 train_time:95357ms step_avg:94.88ms +step:1006/1705 train_time:95450ms step_avg:94.88ms +step:1007/1705 train_time:95544ms step_avg:94.88ms +step:1008/1705 train_time:95636ms step_avg:94.88ms +step:1009/1705 train_time:95730ms step_avg:94.88ms +step:1010/1705 train_time:95824ms step_avg:94.88ms +step:1011/1705 train_time:95919ms step_avg:94.87ms +step:1012/1705 train_time:96015ms step_avg:94.88ms +step:1013/1705 train_time:96112ms step_avg:94.88ms +step:1014/1705 train_time:96208ms step_avg:94.88ms +step:1015/1705 train_time:96304ms step_avg:94.88ms +step:1016/1705 train_time:96397ms step_avg:94.88ms +step:1017/1705 train_time:96492ms step_avg:94.88ms +step:1018/1705 train_time:96585ms step_avg:94.88ms +step:1019/1705 train_time:96679ms step_avg:94.88ms +step:1020/1705 train_time:96772ms step_avg:94.87ms +step:1021/1705 train_time:96866ms step_avg:94.87ms +step:1022/1705 train_time:96962ms step_avg:94.87ms +step:1023/1705 train_time:97056ms step_avg:94.87ms +step:1024/1705 train_time:97153ms step_avg:94.88ms +step:1025/1705 train_time:97249ms step_avg:94.88ms +step:1026/1705 train_time:97345ms step_avg:94.88ms +step:1027/1705 train_time:97439ms step_avg:94.88ms +step:1028/1705 train_time:97533ms step_avg:94.88ms +step:1029/1705 train_time:97627ms step_avg:94.88ms +step:1030/1705 train_time:97721ms step_avg:94.87ms +step:1031/1705 train_time:97815ms step_avg:94.87ms +step:1032/1705 train_time:97909ms step_avg:94.87ms +step:1033/1705 train_time:98004ms step_avg:94.87ms +step:1034/1705 train_time:98100ms step_avg:94.87ms +step:1035/1705 train_time:98195ms step_avg:94.87ms +step:1036/1705 train_time:98290ms step_avg:94.87ms +step:1037/1705 train_time:98385ms step_avg:94.87ms +step:1038/1705 train_time:98479ms step_avg:94.87ms +step:1039/1705 train_time:98573ms step_avg:94.87ms +step:1040/1705 train_time:98668ms step_avg:94.87ms +step:1041/1705 train_time:98762ms step_avg:94.87ms +step:1042/1705 train_time:98856ms step_avg:94.87ms +step:1043/1705 train_time:98950ms step_avg:94.87ms +step:1044/1705 train_time:99045ms 
step_avg:94.87ms +step:1045/1705 train_time:99141ms step_avg:94.87ms +step:1046/1705 train_time:99236ms step_avg:94.87ms +step:1047/1705 train_time:99331ms step_avg:94.87ms +step:1048/1705 train_time:99426ms step_avg:94.87ms +step:1049/1705 train_time:99521ms step_avg:94.87ms +step:1050/1705 train_time:99615ms step_avg:94.87ms +step:1051/1705 train_time:99710ms step_avg:94.87ms +step:1052/1705 train_time:99805ms step_avg:94.87ms +step:1053/1705 train_time:99899ms step_avg:94.87ms +step:1054/1705 train_time:99993ms step_avg:94.87ms +step:1055/1705 train_time:100089ms step_avg:94.87ms +step:1056/1705 train_time:100184ms step_avg:94.87ms +step:1057/1705 train_time:100278ms step_avg:94.87ms +step:1058/1705 train_time:100372ms step_avg:94.87ms +step:1059/1705 train_time:100467ms step_avg:94.87ms +step:1060/1705 train_time:100562ms step_avg:94.87ms +step:1061/1705 train_time:100655ms step_avg:94.87ms +step:1062/1705 train_time:100895ms step_avg:95.00ms +step:1063/1705 train_time:101100ms step_avg:95.11ms +step:1064/1705 train_time:101192ms step_avg:95.11ms +step:1065/1705 train_time:101286ms step_avg:95.10ms +step:1066/1705 train_time:101379ms step_avg:95.10ms +step:1067/1705 train_time:101472ms step_avg:95.10ms +step:1068/1705 train_time:101565ms step_avg:95.10ms +step:1069/1705 train_time:101659ms step_avg:95.10ms +step:1070/1705 train_time:101752ms step_avg:95.10ms +step:1071/1705 train_time:101846ms step_avg:95.09ms +step:1072/1705 train_time:101949ms step_avg:95.10ms +step:1073/1705 train_time:102047ms step_avg:95.10ms +step:1074/1705 train_time:102145ms step_avg:95.11ms +step:1075/1705 train_time:102240ms step_avg:95.11ms +step:1076/1705 train_time:102334ms step_avg:95.11ms +step:1077/1705 train_time:102428ms step_avg:95.10ms +step:1078/1705 train_time:102521ms step_avg:95.10ms +step:1079/1705 train_time:102615ms step_avg:95.10ms +step:1080/1705 train_time:102708ms step_avg:95.10ms +step:1081/1705 train_time:102802ms step_avg:95.10ms +step:1082/1705 train_time:102896ms step_avg:95.10ms +step:1083/1705 train_time:102993ms step_avg:95.10ms +step:1084/1705 train_time:103090ms step_avg:95.10ms +step:1085/1705 train_time:103187ms step_avg:95.10ms +step:1086/1705 train_time:103281ms step_avg:95.10ms +step:1087/1705 train_time:103375ms step_avg:95.10ms +step:1088/1705 train_time:103469ms step_avg:95.10ms +step:1089/1705 train_time:103563ms step_avg:95.10ms +step:1090/1705 train_time:103656ms step_avg:95.10ms +step:1091/1705 train_time:103750ms step_avg:95.10ms +step:1092/1705 train_time:103845ms step_avg:95.10ms +step:1093/1705 train_time:103940ms step_avg:95.10ms +step:1094/1705 train_time:104035ms step_avg:95.10ms +step:1095/1705 train_time:104131ms step_avg:95.10ms +step:1096/1705 train_time:104227ms step_avg:95.10ms +step:1097/1705 train_time:104322ms step_avg:95.10ms +step:1098/1705 train_time:104416ms step_avg:95.10ms +step:1099/1705 train_time:104510ms step_avg:95.10ms +step:1100/1705 train_time:104604ms step_avg:95.09ms +step:1101/1705 train_time:104698ms step_avg:95.09ms +step:1102/1705 train_time:104792ms step_avg:95.09ms +step:1103/1705 train_time:104888ms step_avg:95.09ms +step:1104/1705 train_time:104982ms step_avg:95.09ms +step:1105/1705 train_time:105076ms step_avg:95.09ms +step:1106/1705 train_time:105172ms step_avg:95.09ms +step:1107/1705 train_time:105268ms step_avg:95.09ms +step:1108/1705 train_time:105363ms step_avg:95.09ms +step:1109/1705 train_time:105456ms step_avg:95.09ms +step:1110/1705 train_time:105551ms step_avg:95.09ms +step:1111/1705 train_time:105645ms 
step_avg:95.09ms +step:1112/1705 train_time:105739ms step_avg:95.09ms +step:1113/1705 train_time:105834ms step_avg:95.09ms +step:1114/1705 train_time:105930ms step_avg:95.09ms +step:1115/1705 train_time:106025ms step_avg:95.09ms +step:1116/1705 train_time:106119ms step_avg:95.09ms +step:1117/1705 train_time:106214ms step_avg:95.09ms +step:1118/1705 train_time:106309ms step_avg:95.09ms +step:1119/1705 train_time:106404ms step_avg:95.09ms +step:1120/1705 train_time:106498ms step_avg:95.09ms +step:1121/1705 train_time:106592ms step_avg:95.09ms +step:1122/1705 train_time:106686ms step_avg:95.09ms +step:1123/1705 train_time:106781ms step_avg:95.09ms +step:1124/1705 train_time:106875ms step_avg:95.08ms +step:1125/1705 train_time:106970ms step_avg:95.08ms +step:1125/1705 val_loss:3.4374 train_time:107065ms step_avg:95.17ms +step:1126/1705 train_time:107086ms step_avg:95.10ms +step:1127/1705 train_time:107165ms step_avg:95.09ms +step:1128/1705 train_time:107264ms step_avg:95.09ms +step:1129/1705 train_time:107359ms step_avg:95.09ms +step:1130/1705 train_time:107453ms step_avg:95.09ms +step:1131/1705 train_time:107547ms step_avg:95.09ms +step:1132/1705 train_time:107641ms step_avg:95.09ms +step:1133/1705 train_time:107734ms step_avg:95.09ms +step:1134/1705 train_time:107827ms step_avg:95.09ms +step:1135/1705 train_time:107920ms step_avg:95.08ms +step:1136/1705 train_time:108015ms step_avg:95.08ms +step:1137/1705 train_time:108112ms step_avg:95.09ms +step:1138/1705 train_time:108210ms step_avg:95.09ms +step:1139/1705 train_time:108306ms step_avg:95.09ms +step:1140/1705 train_time:108402ms step_avg:95.09ms +step:1141/1705 train_time:108496ms step_avg:95.09ms +step:1142/1705 train_time:108591ms step_avg:95.09ms +step:1143/1705 train_time:108686ms step_avg:95.09ms +step:1144/1705 train_time:108780ms step_avg:95.09ms +step:1145/1705 train_time:108875ms step_avg:95.09ms +step:1146/1705 train_time:108970ms step_avg:95.09ms +step:1147/1705 train_time:109065ms step_avg:95.09ms +step:1148/1705 train_time:109161ms step_avg:95.09ms +step:1149/1705 train_time:109257ms step_avg:95.09ms +step:1150/1705 train_time:109353ms step_avg:95.09ms +step:1151/1705 train_time:109449ms step_avg:95.09ms +step:1152/1705 train_time:109545ms step_avg:95.09ms +step:1153/1705 train_time:109639ms step_avg:95.09ms +step:1154/1705 train_time:109733ms step_avg:95.09ms +step:1155/1705 train_time:109828ms step_avg:95.09ms +step:1156/1705 train_time:109923ms step_avg:95.09ms +step:1157/1705 train_time:110017ms step_avg:95.09ms +step:1158/1705 train_time:110114ms step_avg:95.09ms +step:1159/1705 train_time:110210ms step_avg:95.09ms +step:1160/1705 train_time:110308ms step_avg:95.09ms +step:1161/1705 train_time:110404ms step_avg:95.09ms +step:1162/1705 train_time:110498ms step_avg:95.09ms +step:1163/1705 train_time:110594ms step_avg:95.09ms +step:1164/1705 train_time:110689ms step_avg:95.09ms +step:1165/1705 train_time:110784ms step_avg:95.09ms +step:1166/1705 train_time:110879ms step_avg:95.09ms +step:1167/1705 train_time:110974ms step_avg:95.09ms +step:1168/1705 train_time:111070ms step_avg:95.09ms +step:1169/1705 train_time:111165ms step_avg:95.09ms +step:1170/1705 train_time:111260ms step_avg:95.09ms +step:1171/1705 train_time:111356ms step_avg:95.09ms +step:1172/1705 train_time:111452ms step_avg:95.10ms +step:1173/1705 train_time:111548ms step_avg:95.10ms +step:1174/1705 train_time:111643ms step_avg:95.10ms +step:1175/1705 train_time:111737ms step_avg:95.10ms +step:1176/1705 train_time:111832ms step_avg:95.10ms +step:1177/1705 
train_time:111927ms step_avg:95.10ms +step:1178/1705 train_time:112022ms step_avg:95.10ms +step:1179/1705 train_time:112117ms step_avg:95.09ms +step:1180/1705 train_time:112213ms step_avg:95.10ms +step:1181/1705 train_time:112310ms step_avg:95.10ms +step:1182/1705 train_time:112406ms step_avg:95.10ms +step:1183/1705 train_time:112501ms step_avg:95.10ms +step:1184/1705 train_time:112596ms step_avg:95.10ms +step:1185/1705 train_time:112691ms step_avg:95.10ms +step:1186/1705 train_time:112786ms step_avg:95.10ms +step:1187/1705 train_time:112881ms step_avg:95.10ms +step:1188/1705 train_time:112976ms step_avg:95.10ms +step:1189/1705 train_time:113072ms step_avg:95.10ms +step:1190/1705 train_time:113168ms step_avg:95.10ms +step:1191/1705 train_time:113265ms step_avg:95.10ms +step:1192/1705 train_time:113362ms step_avg:95.10ms +step:1193/1705 train_time:113458ms step_avg:95.10ms +step:1194/1705 train_time:113553ms step_avg:95.10ms +step:1195/1705 train_time:113649ms step_avg:95.10ms +step:1196/1705 train_time:113743ms step_avg:95.10ms +step:1197/1705 train_time:113838ms step_avg:95.10ms +step:1198/1705 train_time:113933ms step_avg:95.10ms +step:1199/1705 train_time:114028ms step_avg:95.10ms +step:1200/1705 train_time:114123ms step_avg:95.10ms +step:1201/1705 train_time:114218ms step_avg:95.10ms +step:1202/1705 train_time:114315ms step_avg:95.10ms +step:1203/1705 train_time:114410ms step_avg:95.10ms +step:1204/1705 train_time:114506ms step_avg:95.10ms +step:1205/1705 train_time:114601ms step_avg:95.10ms +step:1206/1705 train_time:114696ms step_avg:95.10ms +step:1207/1705 train_time:114791ms step_avg:95.10ms +step:1208/1705 train_time:114887ms step_avg:95.11ms +step:1209/1705 train_time:114982ms step_avg:95.11ms +step:1210/1705 train_time:115076ms step_avg:95.10ms +step:1211/1705 train_time:115171ms step_avg:95.10ms +step:1212/1705 train_time:115268ms step_avg:95.11ms +step:1213/1705 train_time:115364ms step_avg:95.11ms +step:1214/1705 train_time:115459ms step_avg:95.11ms +step:1215/1705 train_time:115554ms step_avg:95.11ms +step:1216/1705 train_time:115650ms step_avg:95.11ms +step:1217/1705 train_time:115745ms step_avg:95.11ms +step:1218/1705 train_time:115840ms step_avg:95.11ms +step:1219/1705 train_time:115935ms step_avg:95.11ms +step:1220/1705 train_time:116031ms step_avg:95.11ms +step:1221/1705 train_time:116125ms step_avg:95.11ms +step:1222/1705 train_time:116221ms step_avg:95.11ms +step:1223/1705 train_time:116315ms step_avg:95.11ms +step:1224/1705 train_time:116412ms step_avg:95.11ms +step:1225/1705 train_time:116510ms step_avg:95.11ms +step:1226/1705 train_time:116605ms step_avg:95.11ms +step:1227/1705 train_time:116701ms step_avg:95.11ms +step:1228/1705 train_time:116796ms step_avg:95.11ms +step:1229/1705 train_time:116891ms step_avg:95.11ms +step:1230/1705 train_time:116986ms step_avg:95.11ms +step:1231/1705 train_time:117081ms step_avg:95.11ms +step:1232/1705 train_time:117176ms step_avg:95.11ms +step:1233/1705 train_time:117272ms step_avg:95.11ms +step:1234/1705 train_time:117367ms step_avg:95.11ms +step:1235/1705 train_time:117462ms step_avg:95.11ms +step:1236/1705 train_time:117558ms step_avg:95.11ms +step:1237/1705 train_time:117654ms step_avg:95.11ms +step:1238/1705 train_time:117751ms step_avg:95.11ms +step:1239/1705 train_time:117848ms step_avg:95.12ms +step:1240/1705 train_time:117943ms step_avg:95.12ms +step:1241/1705 train_time:118039ms step_avg:95.12ms +step:1242/1705 train_time:118133ms step_avg:95.12ms +step:1243/1705 train_time:118228ms step_avg:95.11ms +step:1244/1705 
train_time:118323ms step_avg:95.11ms +step:1245/1705 train_time:118418ms step_avg:95.11ms +step:1246/1705 train_time:118513ms step_avg:95.11ms +step:1247/1705 train_time:118609ms step_avg:95.12ms +step:1248/1705 train_time:118705ms step_avg:95.12ms +step:1249/1705 train_time:118799ms step_avg:95.12ms +step:1250/1705 train_time:118894ms step_avg:95.11ms +step:1250/1705 val_loss:3.3898 train_time:118991ms step_avg:95.19ms +step:1251/1705 train_time:119012ms step_avg:95.13ms +step:1252/1705 train_time:119098ms step_avg:95.13ms +step:1253/1705 train_time:119194ms step_avg:95.13ms +step:1254/1705 train_time:119288ms step_avg:95.13ms +step:1255/1705 train_time:119382ms step_avg:95.13ms +step:1256/1705 train_time:119477ms step_avg:95.12ms +step:1257/1705 train_time:119571ms step_avg:95.12ms +step:1258/1705 train_time:119665ms step_avg:95.12ms +step:1259/1705 train_time:119759ms step_avg:95.12ms +step:1260/1705 train_time:119853ms step_avg:95.12ms +step:1261/1705 train_time:119951ms step_avg:95.12ms +step:1262/1705 train_time:120051ms step_avg:95.13ms +step:1263/1705 train_time:120149ms step_avg:95.13ms +step:1264/1705 train_time:120244ms step_avg:95.13ms +step:1265/1705 train_time:120339ms step_avg:95.13ms +step:1266/1705 train_time:120433ms step_avg:95.13ms +step:1267/1705 train_time:120528ms step_avg:95.13ms +step:1268/1705 train_time:120622ms step_avg:95.13ms +step:1269/1705 train_time:120716ms step_avg:95.13ms +step:1270/1705 train_time:120810ms step_avg:95.13ms +step:1271/1705 train_time:120906ms step_avg:95.13ms +step:1272/1705 train_time:121003ms step_avg:95.13ms +step:1273/1705 train_time:121100ms step_avg:95.13ms +step:1274/1705 train_time:121476ms step_avg:95.35ms +step:1275/1705 train_time:121552ms step_avg:95.33ms +step:1276/1705 train_time:121647ms step_avg:95.33ms +step:1277/1705 train_time:121742ms step_avg:95.33ms +step:1278/1705 train_time:121836ms step_avg:95.33ms +step:1279/1705 train_time:121930ms step_avg:95.33ms +step:1280/1705 train_time:122024ms step_avg:95.33ms +step:1281/1705 train_time:122118ms step_avg:95.33ms +step:1282/1705 train_time:122212ms step_avg:95.33ms +step:1283/1705 train_time:122306ms step_avg:95.33ms +step:1284/1705 train_time:122408ms step_avg:95.33ms +step:1285/1705 train_time:122507ms step_avg:95.34ms +step:1286/1705 train_time:122603ms step_avg:95.34ms +step:1287/1705 train_time:122699ms step_avg:95.34ms +step:1288/1705 train_time:122794ms step_avg:95.34ms +step:1289/1705 train_time:122889ms step_avg:95.34ms +step:1290/1705 train_time:122983ms step_avg:95.34ms +step:1291/1705 train_time:123077ms step_avg:95.33ms +step:1292/1705 train_time:123171ms step_avg:95.33ms +step:1293/1705 train_time:123266ms step_avg:95.33ms +step:1294/1705 train_time:123364ms step_avg:95.34ms +step:1295/1705 train_time:123462ms step_avg:95.34ms +step:1296/1705 train_time:123559ms step_avg:95.34ms +step:1297/1705 train_time:123654ms step_avg:95.34ms +step:1298/1705 train_time:123749ms step_avg:95.34ms +step:1299/1705 train_time:123845ms step_avg:95.34ms +step:1300/1705 train_time:123940ms step_avg:95.34ms +step:1301/1705 train_time:124035ms step_avg:95.34ms +step:1302/1705 train_time:124130ms step_avg:95.34ms +step:1303/1705 train_time:124224ms step_avg:95.34ms +step:1304/1705 train_time:124320ms step_avg:95.34ms +step:1305/1705 train_time:124417ms step_avg:95.34ms +step:1306/1705 train_time:124512ms step_avg:95.34ms +step:1307/1705 train_time:124607ms step_avg:95.34ms +step:1308/1705 train_time:124703ms step_avg:95.34ms +step:1309/1705 train_time:124799ms step_avg:95.34ms 
+step:1310/1705 train_time:124894ms step_avg:95.34ms +step:1311/1705 train_time:124988ms step_avg:95.34ms +step:1312/1705 train_time:125083ms step_avg:95.34ms +step:1313/1705 train_time:125179ms step_avg:95.34ms +step:1314/1705 train_time:125274ms step_avg:95.34ms +step:1315/1705 train_time:125370ms step_avg:95.34ms +step:1316/1705 train_time:125466ms step_avg:95.34ms +step:1317/1705 train_time:125562ms step_avg:95.34ms +step:1318/1705 train_time:125658ms step_avg:95.34ms +step:1319/1705 train_time:125755ms step_avg:95.34ms +step:1320/1705 train_time:125849ms step_avg:95.34ms +step:1321/1705 train_time:125944ms step_avg:95.34ms +step:1322/1705 train_time:126039ms step_avg:95.34ms +step:1323/1705 train_time:126134ms step_avg:95.34ms +step:1324/1705 train_time:126229ms step_avg:95.34ms +step:1325/1705 train_time:126324ms step_avg:95.34ms +step:1326/1705 train_time:126420ms step_avg:95.34ms +step:1327/1705 train_time:126516ms step_avg:95.34ms +step:1328/1705 train_time:126610ms step_avg:95.34ms +step:1329/1705 train_time:126706ms step_avg:95.34ms +step:1330/1705 train_time:126801ms step_avg:95.34ms +step:1331/1705 train_time:126896ms step_avg:95.34ms +step:1332/1705 train_time:126991ms step_avg:95.34ms +step:1333/1705 train_time:127087ms step_avg:95.34ms +step:1334/1705 train_time:127182ms step_avg:95.34ms +step:1335/1705 train_time:127277ms step_avg:95.34ms +step:1336/1705 train_time:127372ms step_avg:95.34ms +step:1337/1705 train_time:127468ms step_avg:95.34ms +step:1338/1705 train_time:127565ms step_avg:95.34ms +step:1339/1705 train_time:127661ms step_avg:95.34ms +step:1340/1705 train_time:127756ms step_avg:95.34ms +step:1341/1705 train_time:127851ms step_avg:95.34ms +step:1342/1705 train_time:127946ms step_avg:95.34ms +step:1343/1705 train_time:128041ms step_avg:95.34ms +step:1344/1705 train_time:128136ms step_avg:95.34ms +step:1345/1705 train_time:128230ms step_avg:95.34ms +step:1346/1705 train_time:128326ms step_avg:95.34ms +step:1347/1705 train_time:128422ms step_avg:95.34ms +step:1348/1705 train_time:128519ms step_avg:95.34ms +step:1349/1705 train_time:128613ms step_avg:95.34ms +step:1350/1705 train_time:128708ms step_avg:95.34ms +step:1351/1705 train_time:128804ms step_avg:95.34ms +step:1352/1705 train_time:128899ms step_avg:95.34ms +step:1353/1705 train_time:128994ms step_avg:95.34ms +step:1354/1705 train_time:129089ms step_avg:95.34ms +step:1355/1705 train_time:129184ms step_avg:95.34ms +step:1356/1705 train_time:129279ms step_avg:95.34ms +step:1357/1705 train_time:129374ms step_avg:95.34ms +step:1358/1705 train_time:129470ms step_avg:95.34ms +step:1359/1705 train_time:129567ms step_avg:95.34ms +step:1360/1705 train_time:129662ms step_avg:95.34ms +step:1361/1705 train_time:129758ms step_avg:95.34ms +step:1362/1705 train_time:129852ms step_avg:95.34ms +step:1363/1705 train_time:129949ms step_avg:95.34ms +step:1364/1705 train_time:130044ms step_avg:95.34ms +step:1365/1705 train_time:130139ms step_avg:95.34ms +step:1366/1705 train_time:130233ms step_avg:95.34ms +step:1367/1705 train_time:130329ms step_avg:95.34ms +step:1368/1705 train_time:130425ms step_avg:95.34ms +step:1369/1705 train_time:130520ms step_avg:95.34ms +step:1370/1705 train_time:130616ms step_avg:95.34ms +step:1371/1705 train_time:130712ms step_avg:95.34ms +step:1372/1705 train_time:130807ms step_avg:95.34ms +step:1373/1705 train_time:130903ms step_avg:95.34ms +step:1374/1705 train_time:130998ms step_avg:95.34ms +step:1375/1705 train_time:131093ms step_avg:95.34ms +step:1375/1705 val_loss:3.3522 train_time:131189ms 
step_avg:95.41ms +step:1376/1705 train_time:131210ms step_avg:95.36ms +step:1377/1705 train_time:131288ms step_avg:95.34ms +step:1378/1705 train_time:131389ms step_avg:95.35ms +step:1379/1705 train_time:131483ms step_avg:95.35ms +step:1380/1705 train_time:131578ms step_avg:95.35ms +step:1381/1705 train_time:131672ms step_avg:95.35ms +step:1382/1705 train_time:131767ms step_avg:95.34ms +step:1383/1705 train_time:131861ms step_avg:95.34ms +step:1384/1705 train_time:131955ms step_avg:95.34ms +step:1385/1705 train_time:132049ms step_avg:95.34ms +step:1386/1705 train_time:132146ms step_avg:95.34ms +step:1387/1705 train_time:132242ms step_avg:95.34ms +step:1388/1705 train_time:132338ms step_avg:95.34ms +step:1389/1705 train_time:132434ms step_avg:95.35ms +step:1390/1705 train_time:132530ms step_avg:95.35ms +step:1391/1705 train_time:132625ms step_avg:95.35ms +step:1392/1705 train_time:132720ms step_avg:95.34ms +step:1393/1705 train_time:132815ms step_avg:95.34ms +step:1394/1705 train_time:132909ms step_avg:95.34ms +step:1395/1705 train_time:133004ms step_avg:95.34ms +step:1396/1705 train_time:133098ms step_avg:95.34ms +step:1397/1705 train_time:133194ms step_avg:95.34ms +step:1398/1705 train_time:133290ms step_avg:95.34ms +step:1399/1705 train_time:133386ms step_avg:95.34ms +step:1400/1705 train_time:133481ms step_avg:95.34ms +step:1401/1705 train_time:133576ms step_avg:95.34ms +step:1402/1705 train_time:133672ms step_avg:95.34ms +step:1403/1705 train_time:133768ms step_avg:95.34ms +step:1404/1705 train_time:133863ms step_avg:95.34ms +step:1405/1705 train_time:133957ms step_avg:95.34ms +step:1406/1705 train_time:134053ms step_avg:95.34ms +step:1407/1705 train_time:134149ms step_avg:95.34ms +step:1408/1705 train_time:134246ms step_avg:95.34ms +step:1409/1705 train_time:134341ms step_avg:95.35ms +step:1410/1705 train_time:134436ms step_avg:95.34ms +step:1411/1705 train_time:134532ms step_avg:95.34ms +step:1412/1705 train_time:134627ms step_avg:95.35ms +step:1413/1705 train_time:134723ms step_avg:95.35ms +step:1414/1705 train_time:134818ms step_avg:95.35ms +step:1415/1705 train_time:134913ms step_avg:95.34ms +step:1416/1705 train_time:135007ms step_avg:95.34ms +step:1417/1705 train_time:135102ms step_avg:95.34ms +step:1418/1705 train_time:135197ms step_avg:95.34ms +step:1419/1705 train_time:135293ms step_avg:95.34ms +step:1420/1705 train_time:135390ms step_avg:95.35ms +step:1421/1705 train_time:135485ms step_avg:95.35ms +step:1422/1705 train_time:135581ms step_avg:95.35ms +step:1423/1705 train_time:135677ms step_avg:95.35ms +step:1424/1705 train_time:135772ms step_avg:95.35ms +step:1425/1705 train_time:135868ms step_avg:95.35ms +step:1426/1705 train_time:135963ms step_avg:95.35ms +step:1427/1705 train_time:136058ms step_avg:95.35ms +step:1428/1705 train_time:136154ms step_avg:95.35ms +step:1429/1705 train_time:136249ms step_avg:95.35ms +step:1430/1705 train_time:136346ms step_avg:95.35ms +step:1431/1705 train_time:136441ms step_avg:95.35ms +step:1432/1705 train_time:136536ms step_avg:95.35ms +step:1433/1705 train_time:136632ms step_avg:95.35ms +step:1434/1705 train_time:136728ms step_avg:95.35ms +step:1435/1705 train_time:136823ms step_avg:95.35ms +step:1436/1705 train_time:136918ms step_avg:95.35ms +step:1437/1705 train_time:137013ms step_avg:95.35ms +step:1438/1705 train_time:137108ms step_avg:95.35ms +step:1439/1705 train_time:137204ms step_avg:95.35ms +step:1440/1705 train_time:137299ms step_avg:95.35ms +step:1441/1705 train_time:137396ms step_avg:95.35ms +step:1442/1705 train_time:137491ms 
step_avg:95.35ms +step:1443/1705 train_time:137587ms step_avg:95.35ms +step:1444/1705 train_time:137681ms step_avg:95.35ms +step:1445/1705 train_time:137776ms step_avg:95.35ms +step:1446/1705 train_time:137872ms step_avg:95.35ms +step:1447/1705 train_time:137968ms step_avg:95.35ms +step:1448/1705 train_time:138064ms step_avg:95.35ms +step:1449/1705 train_time:138158ms step_avg:95.35ms +step:1450/1705 train_time:138255ms step_avg:95.35ms +step:1451/1705 train_time:138351ms step_avg:95.35ms +step:1452/1705 train_time:138447ms step_avg:95.35ms +step:1453/1705 train_time:138543ms step_avg:95.35ms +step:1454/1705 train_time:138638ms step_avg:95.35ms +step:1455/1705 train_time:138732ms step_avg:95.35ms +step:1456/1705 train_time:138829ms step_avg:95.35ms +step:1457/1705 train_time:138924ms step_avg:95.35ms +step:1458/1705 train_time:139020ms step_avg:95.35ms +step:1459/1705 train_time:139115ms step_avg:95.35ms +step:1460/1705 train_time:139211ms step_avg:95.35ms +step:1461/1705 train_time:139306ms step_avg:95.35ms +step:1462/1705 train_time:139401ms step_avg:95.35ms +step:1463/1705 train_time:139496ms step_avg:95.35ms +step:1464/1705 train_time:139592ms step_avg:95.35ms +step:1465/1705 train_time:139687ms step_avg:95.35ms +step:1466/1705 train_time:139782ms step_avg:95.35ms +step:1467/1705 train_time:139877ms step_avg:95.35ms +step:1468/1705 train_time:139973ms step_avg:95.35ms +step:1469/1705 train_time:140068ms step_avg:95.35ms +step:1470/1705 train_time:140164ms step_avg:95.35ms +step:1471/1705 train_time:140259ms step_avg:95.35ms +step:1472/1705 train_time:140355ms step_avg:95.35ms +step:1473/1705 train_time:140451ms step_avg:95.35ms +step:1474/1705 train_time:140547ms step_avg:95.35ms +step:1475/1705 train_time:140642ms step_avg:95.35ms +step:1476/1705 train_time:140736ms step_avg:95.35ms +step:1477/1705 train_time:140833ms step_avg:95.35ms +step:1478/1705 train_time:140928ms step_avg:95.35ms +step:1479/1705 train_time:141024ms step_avg:95.35ms +step:1480/1705 train_time:141120ms step_avg:95.35ms +step:1481/1705 train_time:141215ms step_avg:95.35ms +step:1482/1705 train_time:141311ms step_avg:95.35ms +step:1483/1705 train_time:141408ms step_avg:95.35ms +step:1484/1705 train_time:141503ms step_avg:95.35ms +step:1485/1705 train_time:141768ms step_avg:95.47ms +step:1486/1705 train_time:141957ms step_avg:95.53ms +step:1487/1705 train_time:142051ms step_avg:95.53ms +step:1488/1705 train_time:142145ms step_avg:95.53ms +step:1489/1705 train_time:142239ms step_avg:95.53ms +step:1490/1705 train_time:142334ms step_avg:95.53ms +step:1491/1705 train_time:142428ms step_avg:95.53ms +step:1492/1705 train_time:142523ms step_avg:95.52ms +step:1493/1705 train_time:142617ms step_avg:95.52ms +step:1494/1705 train_time:142711ms step_avg:95.52ms +step:1495/1705 train_time:142811ms step_avg:95.53ms +step:1496/1705 train_time:142911ms step_avg:95.53ms +step:1497/1705 train_time:143008ms step_avg:95.53ms +step:1498/1705 train_time:143103ms step_avg:95.53ms +step:1499/1705 train_time:143197ms step_avg:95.53ms +step:1500/1705 train_time:143293ms step_avg:95.53ms +step:1500/1705 val_loss:3.3199 train_time:143387ms step_avg:95.59ms +step:1501/1705 train_time:143409ms step_avg:95.54ms +step:1502/1705 train_time:143489ms step_avg:95.53ms +step:1503/1705 train_time:143588ms step_avg:95.53ms +step:1504/1705 train_time:143683ms step_avg:95.53ms +step:1505/1705 train_time:143778ms step_avg:95.53ms +step:1506/1705 train_time:143871ms step_avg:95.53ms +step:1507/1705 train_time:143965ms step_avg:95.53ms +step:1508/1705 
train_time:144061ms step_avg:95.53ms +step:1509/1705 train_time:144154ms step_avg:95.53ms +step:1510/1705 train_time:144249ms step_avg:95.53ms +step:1511/1705 train_time:144345ms step_avg:95.53ms +step:1512/1705 train_time:144443ms step_avg:95.53ms +step:1513/1705 train_time:144539ms step_avg:95.53ms +step:1514/1705 train_time:144636ms step_avg:95.53ms +step:1515/1705 train_time:144731ms step_avg:95.53ms +step:1516/1705 train_time:144825ms step_avg:95.53ms +step:1517/1705 train_time:144920ms step_avg:95.53ms +step:1518/1705 train_time:145014ms step_avg:95.53ms +step:1519/1705 train_time:145108ms step_avg:95.53ms +step:1520/1705 train_time:145204ms step_avg:95.53ms +step:1521/1705 train_time:145300ms step_avg:95.53ms +step:1522/1705 train_time:145397ms step_avg:95.53ms +step:1523/1705 train_time:145494ms step_avg:95.53ms +step:1524/1705 train_time:145589ms step_avg:95.53ms +step:1525/1705 train_time:145687ms step_avg:95.53ms +step:1526/1705 train_time:145783ms step_avg:95.53ms +step:1527/1705 train_time:145877ms step_avg:95.53ms +step:1528/1705 train_time:145972ms step_avg:95.53ms +step:1529/1705 train_time:146066ms step_avg:95.53ms +step:1530/1705 train_time:146160ms step_avg:95.53ms +step:1531/1705 train_time:146255ms step_avg:95.53ms +step:1532/1705 train_time:146351ms step_avg:95.53ms +step:1533/1705 train_time:146449ms step_avg:95.53ms +step:1534/1705 train_time:146545ms step_avg:95.53ms +step:1535/1705 train_time:146641ms step_avg:95.53ms +step:1536/1705 train_time:146737ms step_avg:95.53ms +step:1537/1705 train_time:146831ms step_avg:95.53ms +step:1538/1705 train_time:146927ms step_avg:95.53ms +step:1539/1705 train_time:147022ms step_avg:95.53ms +step:1540/1705 train_time:147116ms step_avg:95.53ms +step:1541/1705 train_time:147210ms step_avg:95.53ms +step:1542/1705 train_time:147306ms step_avg:95.53ms +step:1543/1705 train_time:147403ms step_avg:95.53ms +step:1544/1705 train_time:147498ms step_avg:95.53ms +step:1545/1705 train_time:147593ms step_avg:95.53ms +step:1546/1705 train_time:147688ms step_avg:95.53ms +step:1547/1705 train_time:147785ms step_avg:95.53ms +step:1548/1705 train_time:147882ms step_avg:95.53ms +step:1549/1705 train_time:147979ms step_avg:95.53ms +step:1550/1705 train_time:148074ms step_avg:95.53ms +step:1551/1705 train_time:148168ms step_avg:95.53ms +step:1552/1705 train_time:148264ms step_avg:95.53ms +step:1553/1705 train_time:148360ms step_avg:95.53ms +step:1554/1705 train_time:148455ms step_avg:95.53ms +step:1555/1705 train_time:148550ms step_avg:95.53ms +step:1556/1705 train_time:148647ms step_avg:95.53ms +step:1557/1705 train_time:148742ms step_avg:95.53ms +step:1558/1705 train_time:148837ms step_avg:95.53ms +step:1559/1705 train_time:148932ms step_avg:95.53ms +step:1560/1705 train_time:149028ms step_avg:95.53ms +step:1561/1705 train_time:149124ms step_avg:95.53ms +step:1562/1705 train_time:149219ms step_avg:95.53ms +step:1563/1705 train_time:149313ms step_avg:95.53ms +step:1564/1705 train_time:149408ms step_avg:95.53ms +step:1565/1705 train_time:149505ms step_avg:95.53ms +step:1566/1705 train_time:149601ms step_avg:95.53ms +step:1567/1705 train_time:149697ms step_avg:95.53ms +step:1568/1705 train_time:149790ms step_avg:95.53ms +step:1569/1705 train_time:149887ms step_avg:95.53ms +step:1570/1705 train_time:149984ms step_avg:95.53ms +step:1571/1705 train_time:150079ms step_avg:95.53ms +step:1572/1705 train_time:150175ms step_avg:95.53ms +step:1573/1705 train_time:150270ms step_avg:95.53ms +step:1574/1705 train_time:150366ms step_avg:95.53ms +step:1575/1705 
train_time:150462ms step_avg:95.53ms +step:1576/1705 train_time:150557ms step_avg:95.53ms +step:1577/1705 train_time:150652ms step_avg:95.53ms +step:1578/1705 train_time:150747ms step_avg:95.53ms +step:1579/1705 train_time:150842ms step_avg:95.53ms +step:1580/1705 train_time:150938ms step_avg:95.53ms +step:1581/1705 train_time:151035ms step_avg:95.53ms +step:1582/1705 train_time:151128ms step_avg:95.53ms +step:1583/1705 train_time:151224ms step_avg:95.53ms +step:1584/1705 train_time:151320ms step_avg:95.53ms +step:1585/1705 train_time:151415ms step_avg:95.53ms +step:1586/1705 train_time:151510ms step_avg:95.53ms +step:1587/1705 train_time:151606ms step_avg:95.53ms +step:1588/1705 train_time:151702ms step_avg:95.53ms +step:1589/1705 train_time:151798ms step_avg:95.53ms +step:1590/1705 train_time:151892ms step_avg:95.53ms +step:1591/1705 train_time:151988ms step_avg:95.53ms +step:1592/1705 train_time:152084ms step_avg:95.53ms +step:1593/1705 train_time:152179ms step_avg:95.53ms +step:1594/1705 train_time:152274ms step_avg:95.53ms +step:1595/1705 train_time:152369ms step_avg:95.53ms +step:1596/1705 train_time:152464ms step_avg:95.53ms +step:1597/1705 train_time:152560ms step_avg:95.53ms +step:1598/1705 train_time:152655ms step_avg:95.53ms +step:1599/1705 train_time:152750ms step_avg:95.53ms +step:1600/1705 train_time:152845ms step_avg:95.53ms +step:1601/1705 train_time:152941ms step_avg:95.53ms +step:1602/1705 train_time:153037ms step_avg:95.53ms +step:1603/1705 train_time:153131ms step_avg:95.53ms +step:1604/1705 train_time:153227ms step_avg:95.53ms +step:1605/1705 train_time:153323ms step_avg:95.53ms +step:1606/1705 train_time:153418ms step_avg:95.53ms +step:1607/1705 train_time:153513ms step_avg:95.53ms +step:1608/1705 train_time:153608ms step_avg:95.53ms +step:1609/1705 train_time:153704ms step_avg:95.53ms +step:1610/1705 train_time:153800ms step_avg:95.53ms +step:1611/1705 train_time:153895ms step_avg:95.53ms +step:1612/1705 train_time:153990ms step_avg:95.53ms +step:1613/1705 train_time:154086ms step_avg:95.53ms +step:1614/1705 train_time:154183ms step_avg:95.53ms +step:1615/1705 train_time:154278ms step_avg:95.53ms +step:1616/1705 train_time:154373ms step_avg:95.53ms +step:1617/1705 train_time:154467ms step_avg:95.53ms +step:1618/1705 train_time:154564ms step_avg:95.53ms +step:1619/1705 train_time:154660ms step_avg:95.53ms +step:1620/1705 train_time:154756ms step_avg:95.53ms +step:1621/1705 train_time:154851ms step_avg:95.53ms +step:1622/1705 train_time:154946ms step_avg:95.53ms +step:1623/1705 train_time:155041ms step_avg:95.53ms +step:1624/1705 train_time:155137ms step_avg:95.53ms +step:1625/1705 train_time:155232ms step_avg:95.53ms +step:1625/1705 val_loss:3.2922 train_time:155329ms step_avg:95.59ms +step:1626/1705 train_time:155351ms step_avg:95.54ms +step:1627/1705 train_time:155430ms step_avg:95.53ms +step:1628/1705 train_time:155530ms step_avg:95.53ms +step:1629/1705 train_time:155625ms step_avg:95.53ms +step:1630/1705 train_time:155720ms step_avg:95.53ms +step:1631/1705 train_time:155814ms step_avg:95.53ms +step:1632/1705 train_time:155909ms step_avg:95.53ms +step:1633/1705 train_time:156002ms step_avg:95.53ms +step:1634/1705 train_time:156097ms step_avg:95.53ms +step:1635/1705 train_time:156191ms step_avg:95.53ms +step:1636/1705 train_time:156287ms step_avg:95.53ms +step:1637/1705 train_time:156383ms step_avg:95.53ms +step:1638/1705 train_time:156483ms step_avg:95.53ms +step:1639/1705 train_time:156580ms step_avg:95.53ms +step:1640/1705 train_time:156675ms step_avg:95.53ms 
+step:1641/1705 train_time:156770ms step_avg:95.53ms +step:1642/1705 train_time:156865ms step_avg:95.53ms +step:1643/1705 train_time:156958ms step_avg:95.53ms +step:1644/1705 train_time:157054ms step_avg:95.53ms +step:1645/1705 train_time:157148ms step_avg:95.53ms +step:1646/1705 train_time:157243ms step_avg:95.53ms +step:1647/1705 train_time:157339ms step_avg:95.53ms +step:1648/1705 train_time:157437ms step_avg:95.53ms +step:1649/1705 train_time:157535ms step_avg:95.53ms +step:1650/1705 train_time:157632ms step_avg:95.53ms +step:1651/1705 train_time:157728ms step_avg:95.54ms +step:1652/1705 train_time:157823ms step_avg:95.53ms +step:1653/1705 train_time:157918ms step_avg:95.53ms +step:1654/1705 train_time:158012ms step_avg:95.53ms +step:1655/1705 train_time:158107ms step_avg:95.53ms +step:1656/1705 train_time:158201ms step_avg:95.53ms +step:1657/1705 train_time:158296ms step_avg:95.53ms +step:1658/1705 train_time:158392ms step_avg:95.53ms +step:1659/1705 train_time:158488ms step_avg:95.53ms +step:1660/1705 train_time:158585ms step_avg:95.53ms +step:1661/1705 train_time:158680ms step_avg:95.53ms +step:1662/1705 train_time:158776ms step_avg:95.53ms +step:1663/1705 train_time:158872ms step_avg:95.53ms +step:1664/1705 train_time:158967ms step_avg:95.53ms +step:1665/1705 train_time:159061ms step_avg:95.53ms +step:1666/1705 train_time:159156ms step_avg:95.53ms +step:1667/1705 train_time:159251ms step_avg:95.53ms +step:1668/1705 train_time:159345ms step_avg:95.53ms +step:1669/1705 train_time:159441ms step_avg:95.53ms +step:1670/1705 train_time:159537ms step_avg:95.53ms +step:1671/1705 train_time:159634ms step_avg:95.53ms +step:1672/1705 train_time:159730ms step_avg:95.53ms +step:1673/1705 train_time:159826ms step_avg:95.53ms +step:1674/1705 train_time:159920ms step_avg:95.53ms +step:1675/1705 train_time:160015ms step_avg:95.53ms +step:1676/1705 train_time:160110ms step_avg:95.53ms +step:1677/1705 train_time:160205ms step_avg:95.53ms +step:1678/1705 train_time:160300ms step_avg:95.53ms +step:1679/1705 train_time:160395ms step_avg:95.53ms +step:1680/1705 train_time:160492ms step_avg:95.53ms +step:1681/1705 train_time:160589ms step_avg:95.53ms +step:1682/1705 train_time:160685ms step_avg:95.53ms +step:1683/1705 train_time:160779ms step_avg:95.53ms +step:1684/1705 train_time:160874ms step_avg:95.53ms +step:1685/1705 train_time:160970ms step_avg:95.53ms +step:1686/1705 train_time:161065ms step_avg:95.53ms +step:1687/1705 train_time:161160ms step_avg:95.53ms +step:1688/1705 train_time:161255ms step_avg:95.53ms +step:1689/1705 train_time:161350ms step_avg:95.53ms +step:1690/1705 train_time:161445ms step_avg:95.53ms +step:1691/1705 train_time:161541ms step_avg:95.53ms +step:1692/1705 train_time:161637ms step_avg:95.53ms +step:1693/1705 train_time:161733ms step_avg:95.53ms +step:1694/1705 train_time:161830ms step_avg:95.53ms +step:1695/1705 train_time:161924ms step_avg:95.53ms +step:1696/1705 train_time:162020ms step_avg:95.53ms +step:1697/1705 train_time:162115ms step_avg:95.53ms +step:1698/1705 train_time:162370ms step_avg:95.62ms +step:1699/1705 train_time:162565ms step_avg:95.68ms +step:1700/1705 train_time:162658ms step_avg:95.68ms +step:1701/1705 train_time:162753ms step_avg:95.68ms +step:1702/1705 train_time:162847ms step_avg:95.68ms +step:1703/1705 train_time:162941ms step_avg:95.68ms +step:1704/1705 train_time:163035ms step_avg:95.68ms +step:1705/1705 train_time:163129ms step_avg:95.68ms +step:1705/1705 val_loss:3.2779 train_time:163224ms step_avg:95.73ms +peak memory allocated: 33750 MiB 
reserved: 49456 MiB

diff --git a/train_gpt.py b/train_gpt.py
index 8c5ea5788..ebd4b092a 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -644,19 +644,22 @@ def forward(self, x: Tensor):
         x = F.linear(x, self.c_proj.type_as(x))
         return x
 
+
 class Block(nn.Module):
     def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
         super().__init__()
         # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
         self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
-        self.mlp = MLP(dim)
+        SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK
+        self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim)
 
     def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
                 seqlens: Tensor, bm_size: int):
         x = lambdas[0] * x + lambdas[1] * x0
         if self.attn is not None:
             x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size)
-        x = x + self.mlp(norm(x))
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
         return x
 
 # -----------------------------------------------------------------------------
@@ -865,7 +868,7 @@ class Hyperparameters:
     train_max_seq_len: int = 128 * 16
     val_batch_size: int = 4 * 64 * 1024 * 8
     # optimization
-    num_iterations: int = 1670 # number of iterations to run
+    num_iterations: int = 1705 # number of iterations to run
     cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate
     # evaluation and logging
     run_id: str = str(uuid.uuid4())
@@ -1057,4 +1060,4 @@ def get_ws(step: int):
 print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
 
-dist.destroy_process_group()
\ No newline at end of file
+dist.destroy_process_group()

From 091cc978cd774c524906f5af1593c1547405390e Mon Sep 17 00:00:00 2001
From: EmelyanenkoK
Date: Mon, 8 Sep 2025 21:10:37 +0300
Subject: [PATCH 06/14] Fix senseless removing of 12 (non-existent) block mlp
 layer

---
 train_gpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_gpt.py b/train_gpt.py
index ebd4b092a..257dfe5ef 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -650,7 +650,7 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
         super().__init__()
         # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
         self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
-        SKIPPED_MLP_BLOCKS = [0, 12] # skip MLP blocks for first and last layers by @EmelyanenkoK
+        SKIPPED_MLP_BLOCKS = [0] # skip MLP blocks for first MLP layer by @EmelyanenkoK
         self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim)
 
     def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
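The two train_gpt.py changes above follow one pattern: a Block only instantiates the submodules it will actually use (attention is dropped for layer_idx 7, the MLP for every index in SKIPPED_MLP_BLOCKS), and the forward pass treats a missing submodule as a pure residual pass-through. The follow-up fix exists because the blocks evidently run 0 through 11 (the fix's subject line calls block 12 non-existent), so the original [0, 12] never matched a block 12 and only ever disabled the first MLP. Below is a minimal, self-contained sketch of the pattern; TinyBlock, the toy MLP, and dim=64 are illustrative stand-ins, not the record's actual modules:

import torch
from torch import Tensor, nn

class MLP(nn.Module):
    # toy stand-in for the record's MLP
    def __init__(self, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(torch.relu(self.fc1(x)))

class TinyBlock(nn.Module):
    # instantiate submodules conditionally, as in the hunks above
    def __init__(self, dim: int, layer_idx: int, skipped_mlp_blocks: tuple[int, ...] = (0,)):
        super().__init__()
        self.mlp = None if layer_idx in skipped_mlp_blocks else MLP(dim)

    def forward(self, x: Tensor) -> Tensor:
        if self.mlp is not None:  # a skipped block is a pure residual pass-through
            x = x + self.mlp(x)
        return x

blocks = nn.ModuleList(TinyBlock(dim=64, layer_idx=i) for i in range(12))  # indices 0..11
x = torch.randn(8, 64)
for block in blocks:
    x = block(x)
print(x.shape)  # torch.Size([8, 64]); block 0 applied no MLP update

An index outside range(12), like the original 12, never matches any layer_idx, which is why [0, 12] silently behaved exactly like [0].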
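The Hyperparameters hunk in the first diff above is also worth a note: num_iterations moves from 1670 to 1705, and cooldown_frac stays at 0.45 (the annotation reads int, but the value is a float; Python does not enforce dataclass annotations at runtime). The get_lr schedule that consumes cooldown_frac is outside this excerpt, so the following is only a sketch of the conventional reading of 'fraction of training spent cooling down the learning rate': a constant multiplier, then a linear decay over the final cooldown_frac of steps. The helper name lr_multiplier and the exact decay shape are assumptions, not the record's code:

def lr_multiplier(step: int, num_iterations: int = 1705, cooldown_frac: float = 0.45) -> float:
    # hypothetical helper, not from the record
    x = step / num_iterations       # fraction of training completed
    if x < 1 - cooldown_frac:
        return 1.0                  # constant phase, roughly the first 55% of steps
    return (1 - x) / cooldown_frac  # linear cooldown, reaching 0 at the final step

assert lr_multiplier(0) == 1.0
assert lr_multiplier(1705) == 0.0
assert 0.0 < lr_multiplier(1500) < 1.0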
From 34ae835a2aa8c3597a7c99378e53b108cfbae8b8 Mon Sep 17 00:00:00 2001
From: larry dial
Date: Wed, 10 Sep 2025 23:28:30 -0700
Subject: [PATCH 07/14] yarn

---
 .../07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt | 0
 .../1858912a-2697-4461-9edb-e5ee4246ee3d.txt | 0
 .../3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt | 0
 .../56955462-7201-4627-91d9-b2426a1424e2.txt | 0
 .../5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt | 0
 .../70af20aa-f602-4cc1-85e9-430a1664f62e.txt | 0
 .../8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt | 0
 .../cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt | 0
 .../cf8c8a10-ea32-46a0-8276-241330023e83.txt | 0
 ...n_0f6c8eac-db39-49ce-bef8-08a34044625f.txt | 0
 ...n_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt | 0
 ...n_3f42c181-6303-4ade-9f64-556d44d54065.txt | 0
 ...n_50e5b966-21a9-4545-8c88-91308e140958.txt | 0
 ...n_803c2d15-4adb-42d2-958b-0b712cd9d062.txt | 0
 ...n_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt | 0
 ...n_adcc39f4-c919-420a-bd94-9d0035f0038c.txt | 0
 ...n_c753588f-47c7-4107-9087-3c5da90cc0f4.txt | 0
 ...n_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt | 0
 ...n_e501e1e9-39fa-473b-bded-39427f349f37.txt | 0
 .../f01447c9-da70-405a-8ed0-858caadd1194.txt | 0
 .../0ecdb695-510b-4c3b-b030-09861a162ce8.txt | 2863 +++++++++++++++++
 .../132fe599-bc5a-4237-ad14-ee33cbbd5fc0.txt | 2863 +++++++++++++++++
 .../61b04c65-2c0f-4d24-83e2-6035dfea1582.txt | 2863 +++++++++++++++++
 .../6297777d-03bd-4955-9c3a-c854246b928a.txt | 2863 +++++++++++++++++
 .../783d22ec-c441-4d93-9fd7-cd00d2c473e8.txt | 2863 +++++++++++++++++
 .../9121a353-d3ce-4f54-98de-0b466773fe0b.txt | 2863 +++++++++++++++++
 records/091025_Yarn/ReadMe.md | 94 +
 .../ef66c943-e262-400f-822f-068d397a1dc9.txt | 2863 +++++++++++++++++
 train_gpt.py | 137 +-
 29 files changed, 20226 insertions(+), 46 deletions(-)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/1858912a-2697-4461-9edb-e5ee4246ee3d.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/56955462-7201-4627-91d9-b2426a1424e2.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/70af20aa-f602-4cc1-85e9-430a1664f62e.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/cf8c8a10-ea32-46a0-8276-241330023e83.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt (100%)
 rename records/{050925_SkipMLPBlocks => 090525_SkipMLPBlocks}/f01447c9-da70-405a-8ed0-858caadd1194.txt (100%)
 create mode 100644 records/091025_Yarn/0ecdb695-510b-4c3b-b030-09861a162ce8.txt
 create mode 100644 records/091025_Yarn/132fe599-bc5a-4237-ad14-ee33cbbd5fc0.txt
 create mode 100644 records/091025_Yarn/61b04c65-2c0f-4d24-83e2-6035dfea1582.txt
 create mode 100644 records/091025_Yarn/6297777d-03bd-4955-9c3a-c854246b928a.txt
 create mode 100644 records/091025_Yarn/783d22ec-c441-4d93-9fd7-cd00d2c473e8.txt
 create mode 100644 records/091025_Yarn/9121a353-d3ce-4f54-98de-0b466773fe0b.txt
 create mode 100644 records/091025_Yarn/ReadMe.md
 create mode 100644 records/091025_Yarn/ef66c943-e262-400f-822f-068d397a1dc9.txt

diff --git a/records/050925_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt b/records/090525_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt
rename to records/090525_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt
diff --git a/records/050925_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt b/records/090525_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt
rename to records/090525_SkipMLPBlocks/1858912a-2697-4461-9edb-e5ee4246ee3d.txt
diff --git a/records/050925_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt b/records/090525_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt
rename to records/090525_SkipMLPBlocks/3a3f4c61-475d-4fcb-a606-65aa3784d7af.txt
diff --git a/records/050925_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt b/records/090525_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt
rename to records/090525_SkipMLPBlocks/56955462-7201-4627-91d9-b2426a1424e2.txt
diff --git a/records/050925_SkipMLPBlocks/5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt b/records/090525_SkipMLPBlocks/5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt
rename to records/090525_SkipMLPBlocks/5ab34e6e-f1db-4ceb-a639-9186a26a48f5.txt
diff --git a/records/050925_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt b/records/090525_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt
rename to records/090525_SkipMLPBlocks/70af20aa-f602-4cc1-85e9-430a1664f62e.txt
diff --git a/records/050925_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt b/records/090525_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt
rename to records/090525_SkipMLPBlocks/8ac310eb-aa6a-4f5b-b298-8a0cbcb01398.txt
diff --git a/records/050925_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt b/records/090525_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt
rename to records/090525_SkipMLPBlocks/cf25c17a-ae33-4c45-8478-3e4f177a9f26.txt
diff --git a/records/050925_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt b/records/090525_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt
similarity index 100%
rename from records/050925_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt
rename to records/090525_SkipMLPBlocks/cf8c8a10-ea32-46a0-8276-241330023e83.txt
a/records/050925_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt b/records/090525_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt rename to records/090525_SkipMLPBlocks/comparison_0f6c8eac-db39-49ce-bef8-08a34044625f.txt diff --git a/records/050925_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt b/records/090525_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt rename to records/090525_SkipMLPBlocks/comparison_1b9374fc-2a63-47a1-b144-2fc8ad635792.txt diff --git a/records/050925_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt b/records/090525_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt rename to records/090525_SkipMLPBlocks/comparison_3f42c181-6303-4ade-9f64-556d44d54065.txt diff --git a/records/050925_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt b/records/090525_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt rename to records/090525_SkipMLPBlocks/comparison_50e5b966-21a9-4545-8c88-91308e140958.txt diff --git a/records/050925_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt b/records/090525_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt rename to records/090525_SkipMLPBlocks/comparison_803c2d15-4adb-42d2-958b-0b712cd9d062.txt diff --git a/records/050925_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt b/records/090525_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt rename to records/090525_SkipMLPBlocks/comparison_9a9ac5ac-514a-43e0-ab92-1319bf013a3b.txt diff --git a/records/050925_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt b/records/090525_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt rename to records/090525_SkipMLPBlocks/comparison_adcc39f4-c919-420a-bd94-9d0035f0038c.txt diff --git a/records/050925_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt b/records/090525_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt rename to records/090525_SkipMLPBlocks/comparison_c753588f-47c7-4107-9087-3c5da90cc0f4.txt diff --git a/records/050925_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt b/records/090525_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt rename to records/090525_SkipMLPBlocks/comparison_d3bc9a09-09e9-450c-a8d7-f53a4f5aed01.txt diff --git a/records/050925_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt 
b/records/090525_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt similarity index 100% rename from records/050925_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt rename to records/090525_SkipMLPBlocks/comparison_e501e1e9-39fa-473b-bded-39427f349f37.txt diff --git a/records/050925_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt b/records/090525_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt similarity index 100% rename from records/050925_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt rename to records/090525_SkipMLPBlocks/f01447c9-da70-405a-8ed0-858caadd1194.txt diff --git a/records/091025_Yarn/0ecdb695-510b-4c3b-b030-09861a162ce8.txt b/records/091025_Yarn/0ecdb695-510b-4c3b-b030-09861a162ce8.txt new file mode 100644 index 000000000..cbf11baca --- /dev/null +++ b/records/091025_Yarn/0ecdb695-510b-4c3b-b030-09861a162ce8.txt @@ -0,0 +1,2863 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +import math + +from dataclasses import dataclass +from functools import lru_cache +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for 
(d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate 
over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] 
< M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. 
To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: 
list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + 
self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
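+        # Explanatory note on the fp8 scales passed to CastedLinear below: mm_op casts
+        # x and w to torch.float8_e4m3fn (whose largest normal value is 448) and the
+        # backward pass casts gradients to torch.float8_e5m2 (see mm_backward_op above),
+        # so x_s, w_s and grad_s are chosen to keep the scaled tensors inside those
+        # ranges; the exact constants look like empirical tuning for this model size
+        # rather than a hard requirement.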
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws]
+            )
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration("Insufficient BOS tokens ahead; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
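+# Illustrative sketch (added for exposition; defined but never called by this
+# script): how the per-rank (starts, ends) lists produced by BOSFinder.next_batch
+# turn into a packed batch plus the cumulative-sequence-lengths tensor that
+# flash_attn_varlen_func consumes. This mirrors the align_to_bos branch of
+# distributed_data_generator below; the _demo_* name is hypothetical.
+def _demo_pack_bos_batch(tokens: Tensor, starts: list[int], ends: list[int]):
+    # concatenate the chosen [start, end) document slices into one flat buffer
+    buf = torch.cat([tokens[i:j] for i, j in zip(starts, ends)])
+    inputs, targets = buf[:-1], buf[1:] # next-token targets are the buffer shifted by one
+    lengths = torch.tensor([j - i for i, j in zip(starts, ends)], dtype=torch.int64)
+    lengths[-1] -= 1 # the final document loses one token to the target shift
+    # cu_seqlens = [0, len0, len0+len1, ...] marks document boundaries for varlen attention
+    cu_seqlens = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]).to(torch.int32)
+    return inputs, targets, cu_seqlens
+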
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"yarn/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+ws = get_ws(0)
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws = get_ws(step)
+    if new_ws != ws:
+        model.apply_yarn(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 05:40:13 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 31C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 33C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 30C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 33C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 32C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 97165 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 97166 C /usr/bin/python3 614MiB | +| 0 N/A N/A 97167 C /usr/bin/python3 614MiB | +| 0 N/A N/A 97168 C /usr/bin/python3 614MiB | +| 0 N/A N/A 97169 C /usr/bin/python3 614MiB | +| 0 N/A N/A 97170 C /usr/bin/python3 614MiB | +| 0 N/A N/A 97171 C /usr/bin/python3 614MiB | +| 0 N/A N/A 97172 C /usr/bin/python3 614MiB | +| 1 N/A N/A 97166 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 97167 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 97168 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 97169 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 97170 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 97171 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 97172 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:482ms step_avg:481.84ms +step:2/1670 train_time:506ms step_avg:252.88ms +step:3/1670 train_time:574ms step_avg:191.30ms +step:4/1670 train_time:664ms step_avg:166.06ms +step:5/1670 train_time:756ms step_avg:151.16ms +step:6/1670 train_time:848ms step_avg:141.26ms +step:7/1670 train_time:940ms step_avg:134.29ms 
+step:8/1670 train_time:1032ms step_avg:128.96ms +step:9/1670 train_time:1123ms step_avg:124.79ms +step:10/1670 train_time:1215ms step_avg:121.54ms +step:11/1670 train_time:1306ms step_avg:118.73ms +step:12/1670 train_time:1398ms step_avg:116.54ms +step:13/1670 train_time:1498ms step_avg:115.20ms +step:14/1670 train_time:1593ms step_avg:113.80ms +step:15/1670 train_time:1687ms step_avg:112.44ms +step:16/1670 train_time:1778ms step_avg:111.12ms +step:17/1670 train_time:1870ms step_avg:110.00ms +step:18/1670 train_time:1963ms step_avg:109.06ms +step:19/1670 train_time:2055ms step_avg:108.18ms +step:20/1670 train_time:2147ms step_avg:107.37ms +step:21/1670 train_time:2239ms step_avg:106.64ms +step:22/1670 train_time:2331ms step_avg:105.98ms +step:23/1670 train_time:2424ms step_avg:105.39ms +step:24/1670 train_time:2517ms step_avg:104.87ms +step:25/1670 train_time:2609ms step_avg:104.38ms +step:26/1670 train_time:2702ms step_avg:103.91ms +step:27/1670 train_time:2795ms step_avg:103.51ms +step:28/1670 train_time:2887ms step_avg:103.12ms +step:29/1670 train_time:2980ms step_avg:102.76ms +step:30/1670 train_time:3072ms step_avg:102.42ms +step:31/1670 train_time:3165ms step_avg:102.09ms +step:32/1670 train_time:3257ms step_avg:101.77ms +step:33/1670 train_time:3350ms step_avg:101.51ms +step:34/1670 train_time:3442ms step_avg:101.25ms +step:35/1670 train_time:3535ms step_avg:100.99ms +step:36/1670 train_time:3627ms step_avg:100.76ms +step:37/1670 train_time:3720ms step_avg:100.53ms +step:38/1670 train_time:3813ms step_avg:100.35ms +step:39/1670 train_time:3905ms step_avg:100.13ms +step:40/1670 train_time:3997ms step_avg:99.92ms +step:41/1670 train_time:4090ms step_avg:99.75ms +step:42/1670 train_time:4182ms step_avg:99.58ms +step:43/1670 train_time:4274ms step_avg:99.41ms +step:44/1670 train_time:4367ms step_avg:99.26ms +step:45/1670 train_time:4460ms step_avg:99.12ms +step:46/1670 train_time:4553ms step_avg:98.97ms +step:47/1670 train_time:4646ms step_avg:98.84ms +step:48/1670 train_time:4738ms step_avg:98.71ms +step:49/1670 train_time:4831ms step_avg:98.59ms +step:50/1670 train_time:4924ms step_avg:98.47ms +step:51/1670 train_time:5016ms step_avg:98.35ms +step:52/1670 train_time:5109ms step_avg:98.25ms +step:53/1670 train_time:5201ms step_avg:98.14ms +step:54/1670 train_time:5295ms step_avg:98.05ms +step:55/1670 train_time:5388ms step_avg:97.96ms +step:56/1670 train_time:5480ms step_avg:97.86ms +step:57/1670 train_time:5572ms step_avg:97.76ms +step:58/1670 train_time:5665ms step_avg:97.67ms +step:59/1670 train_time:5757ms step_avg:97.58ms +step:60/1670 train_time:5850ms step_avg:97.50ms +step:61/1670 train_time:5942ms step_avg:97.41ms +step:62/1670 train_time:6034ms step_avg:97.33ms +step:63/1670 train_time:6127ms step_avg:97.26ms +step:64/1670 train_time:6220ms step_avg:97.19ms +step:65/1670 train_time:6312ms step_avg:97.11ms +step:66/1670 train_time:6405ms step_avg:97.04ms +step:67/1670 train_time:6497ms step_avg:96.98ms +step:68/1670 train_time:6590ms step_avg:96.91ms +step:69/1670 train_time:6682ms step_avg:96.84ms +step:70/1670 train_time:6774ms step_avg:96.78ms +step:71/1670 train_time:6868ms step_avg:96.73ms +step:72/1670 train_time:6960ms step_avg:96.67ms +step:73/1670 train_time:7053ms step_avg:96.62ms +step:74/1670 train_time:7145ms step_avg:96.55ms +step:75/1670 train_time:7237ms step_avg:96.49ms +step:76/1670 train_time:7329ms step_avg:96.43ms +step:77/1670 train_time:7421ms step_avg:96.38ms +step:78/1670 train_time:7514ms step_avg:96.33ms +step:79/1670 train_time:7606ms 
step_avg:96.28ms +step:80/1670 train_time:7698ms step_avg:96.22ms +step:81/1670 train_time:7791ms step_avg:96.18ms +step:82/1670 train_time:7883ms step_avg:96.13ms +step:83/1670 train_time:7975ms step_avg:96.08ms +step:84/1670 train_time:8068ms step_avg:96.05ms +step:85/1670 train_time:8160ms step_avg:96.00ms +step:86/1670 train_time:8254ms step_avg:95.97ms +step:87/1670 train_time:8346ms step_avg:95.93ms +step:88/1670 train_time:8438ms step_avg:95.88ms +step:89/1670 train_time:8530ms step_avg:95.84ms +step:90/1670 train_time:8622ms step_avg:95.80ms +step:91/1670 train_time:8714ms step_avg:95.76ms +step:92/1670 train_time:8807ms step_avg:95.73ms +step:93/1670 train_time:8899ms step_avg:95.68ms +step:94/1670 train_time:8991ms step_avg:95.65ms +step:95/1670 train_time:9084ms step_avg:95.62ms +step:96/1670 train_time:9176ms step_avg:95.58ms +step:97/1670 train_time:9268ms step_avg:95.55ms +step:98/1670 train_time:9360ms step_avg:95.51ms +step:99/1670 train_time:9453ms step_avg:95.48ms +step:100/1670 train_time:9545ms step_avg:95.45ms +step:101/1670 train_time:9637ms step_avg:95.41ms +step:102/1670 train_time:9730ms step_avg:95.39ms +step:103/1670 train_time:9822ms step_avg:95.36ms +step:104/1670 train_time:9914ms step_avg:95.33ms +step:105/1670 train_time:10006ms step_avg:95.30ms +step:106/1670 train_time:10099ms step_avg:95.27ms +step:107/1670 train_time:10192ms step_avg:95.25ms +step:108/1670 train_time:10284ms step_avg:95.23ms +step:109/1670 train_time:10377ms step_avg:95.20ms +step:110/1670 train_time:10469ms step_avg:95.17ms +step:111/1670 train_time:10561ms step_avg:95.15ms +step:112/1670 train_time:10654ms step_avg:95.12ms +step:113/1670 train_time:10747ms step_avg:95.11ms +step:114/1670 train_time:10838ms step_avg:95.07ms +step:115/1670 train_time:10932ms step_avg:95.06ms +step:116/1670 train_time:11025ms step_avg:95.04ms +step:117/1670 train_time:11116ms step_avg:95.01ms +step:118/1670 train_time:11209ms step_avg:94.99ms +step:119/1670 train_time:11301ms step_avg:94.96ms +step:120/1670 train_time:11394ms step_avg:94.95ms +step:121/1670 train_time:11487ms step_avg:94.93ms +step:122/1670 train_time:11578ms step_avg:94.90ms +step:123/1670 train_time:11671ms step_avg:94.89ms +step:124/1670 train_time:11763ms step_avg:94.86ms +step:125/1670 train_time:11855ms step_avg:94.84ms +step:125/1670 val_loss:4.2857 train_time:11946ms step_avg:95.56ms +step:126/1670 train_time:11971ms step_avg:95.01ms +step:127/1670 train_time:12046ms step_avg:94.85ms +step:128/1670 train_time:12147ms step_avg:94.90ms +step:129/1670 train_time:12242ms step_avg:94.90ms +step:130/1670 train_time:12334ms step_avg:94.88ms +step:131/1670 train_time:12427ms step_avg:94.86ms +step:132/1670 train_time:12518ms step_avg:94.84ms +step:133/1670 train_time:12610ms step_avg:94.81ms +step:134/1670 train_time:12702ms step_avg:94.79ms +step:135/1670 train_time:12793ms step_avg:94.77ms +step:136/1670 train_time:12885ms step_avg:94.74ms +step:137/1670 train_time:12977ms step_avg:94.72ms +step:138/1670 train_time:13071ms step_avg:94.72ms +step:139/1670 train_time:13165ms step_avg:94.71ms +step:140/1670 train_time:13259ms step_avg:94.71ms +step:141/1670 train_time:13351ms step_avg:94.69ms +step:142/1670 train_time:13443ms step_avg:94.67ms +step:143/1670 train_time:13535ms step_avg:94.65ms +step:144/1670 train_time:13628ms step_avg:94.64ms +step:145/1670 train_time:13719ms step_avg:94.61ms +step:146/1670 train_time:13810ms step_avg:94.59ms +step:147/1670 train_time:13902ms step_avg:94.57ms +step:148/1670 train_time:13994ms 
step_avg:94.55ms +step:149/1670 train_time:14087ms step_avg:94.54ms +step:150/1670 train_time:14180ms step_avg:94.53ms +step:151/1670 train_time:14272ms step_avg:94.52ms +step:152/1670 train_time:14365ms step_avg:94.51ms +step:153/1670 train_time:14457ms step_avg:94.49ms +step:154/1670 train_time:14549ms step_avg:94.47ms +step:155/1670 train_time:14641ms step_avg:94.46ms +step:156/1670 train_time:14732ms step_avg:94.44ms +step:157/1670 train_time:14824ms step_avg:94.42ms +step:158/1670 train_time:14916ms step_avg:94.40ms +step:159/1670 train_time:15008ms step_avg:94.39ms +step:160/1670 train_time:15101ms step_avg:94.38ms +step:161/1670 train_time:15193ms step_avg:94.36ms +step:162/1670 train_time:15286ms step_avg:94.36ms +step:163/1670 train_time:15378ms step_avg:94.34ms +step:164/1670 train_time:15470ms step_avg:94.33ms +step:165/1670 train_time:15563ms step_avg:94.32ms +step:166/1670 train_time:15655ms step_avg:94.30ms +step:167/1670 train_time:15747ms step_avg:94.29ms +step:168/1670 train_time:15839ms step_avg:94.28ms +step:169/1670 train_time:15932ms step_avg:94.27ms +step:170/1670 train_time:16025ms step_avg:94.26ms +step:171/1670 train_time:16117ms step_avg:94.25ms +step:172/1670 train_time:16209ms step_avg:94.24ms +step:173/1670 train_time:16304ms step_avg:94.24ms +step:174/1670 train_time:16396ms step_avg:94.23ms +step:175/1670 train_time:16488ms step_avg:94.22ms +step:176/1670 train_time:16580ms step_avg:94.21ms +step:177/1670 train_time:16672ms step_avg:94.19ms +step:178/1670 train_time:16764ms step_avg:94.18ms +step:179/1670 train_time:16857ms step_avg:94.17ms +step:180/1670 train_time:16949ms step_avg:94.16ms +step:181/1670 train_time:17042ms step_avg:94.15ms +step:182/1670 train_time:17134ms step_avg:94.14ms +step:183/1670 train_time:17227ms step_avg:94.14ms +step:184/1670 train_time:17319ms step_avg:94.13ms +step:185/1670 train_time:17411ms step_avg:94.11ms +step:186/1670 train_time:17504ms step_avg:94.11ms +step:187/1670 train_time:17596ms step_avg:94.10ms +step:188/1670 train_time:17688ms step_avg:94.09ms +step:189/1670 train_time:17780ms step_avg:94.07ms +step:190/1670 train_time:17872ms step_avg:94.06ms +step:191/1670 train_time:17964ms step_avg:94.05ms +step:192/1670 train_time:18056ms step_avg:94.04ms +step:193/1670 train_time:18149ms step_avg:94.04ms +step:194/1670 train_time:18242ms step_avg:94.03ms +step:195/1670 train_time:18334ms step_avg:94.02ms +step:196/1670 train_time:18427ms step_avg:94.01ms +step:197/1670 train_time:18520ms step_avg:94.01ms +step:198/1670 train_time:18611ms step_avg:94.00ms +step:199/1670 train_time:18703ms step_avg:93.99ms +step:200/1670 train_time:18796ms step_avg:93.98ms +step:201/1670 train_time:18888ms step_avg:93.97ms +step:202/1670 train_time:18981ms step_avg:93.96ms +step:203/1670 train_time:19073ms step_avg:93.96ms +step:204/1670 train_time:19166ms step_avg:93.95ms +step:205/1670 train_time:19258ms step_avg:93.94ms +step:206/1670 train_time:19350ms step_avg:93.93ms +step:207/1670 train_time:19443ms step_avg:93.93ms +step:208/1670 train_time:19535ms step_avg:93.92ms +step:209/1670 train_time:19628ms step_avg:93.92ms +step:210/1670 train_time:19720ms step_avg:93.91ms +step:211/1670 train_time:19813ms step_avg:93.90ms +step:212/1670 train_time:19906ms step_avg:93.89ms +step:213/1670 train_time:20226ms step_avg:94.96ms +step:214/1670 train_time:20368ms step_avg:95.18ms +step:215/1670 train_time:20459ms step_avg:95.16ms +step:216/1670 train_time:20551ms step_avg:95.14ms +step:217/1670 train_time:20642ms step_avg:95.12ms +step:218/1670 
train_time:20733ms step_avg:95.11ms +step:219/1670 train_time:20825ms step_avg:95.09ms +step:220/1670 train_time:20917ms step_avg:95.08ms +step:221/1670 train_time:21009ms step_avg:95.06ms +step:222/1670 train_time:21100ms step_avg:95.04ms +step:223/1670 train_time:21192ms step_avg:95.03ms +step:224/1670 train_time:21287ms step_avg:95.03ms +step:225/1670 train_time:21382ms step_avg:95.03ms +step:226/1670 train_time:21475ms step_avg:95.02ms +step:227/1670 train_time:21567ms step_avg:95.01ms +step:228/1670 train_time:21658ms step_avg:94.99ms +step:229/1670 train_time:21750ms step_avg:94.98ms +step:230/1670 train_time:21841ms step_avg:94.96ms +step:231/1670 train_time:21933ms step_avg:94.95ms +step:232/1670 train_time:22025ms step_avg:94.93ms +step:233/1670 train_time:22116ms step_avg:94.92ms +step:234/1670 train_time:22209ms step_avg:94.91ms +step:235/1670 train_time:22303ms step_avg:94.90ms +step:236/1670 train_time:22395ms step_avg:94.90ms +step:237/1670 train_time:22488ms step_avg:94.89ms +step:238/1670 train_time:22580ms step_avg:94.88ms +step:239/1670 train_time:22672ms step_avg:94.86ms +step:240/1670 train_time:22765ms step_avg:94.85ms +step:241/1670 train_time:22857ms step_avg:94.84ms +step:242/1670 train_time:22949ms step_avg:94.83ms +step:243/1670 train_time:23041ms step_avg:94.82ms +step:244/1670 train_time:23133ms step_avg:94.81ms +step:245/1670 train_time:23225ms step_avg:94.80ms +step:246/1670 train_time:23317ms step_avg:94.78ms +step:247/1670 train_time:23409ms step_avg:94.77ms +step:248/1670 train_time:23502ms step_avg:94.77ms +step:249/1670 train_time:23594ms step_avg:94.76ms +step:250/1670 train_time:23687ms step_avg:94.75ms +step:250/1670 val_loss:3.9662 train_time:23777ms step_avg:95.11ms +step:251/1670 train_time:23802ms step_avg:94.83ms +step:252/1670 train_time:23877ms step_avg:94.75ms +step:253/1670 train_time:23975ms step_avg:94.76ms +step:254/1670 train_time:24070ms step_avg:94.76ms +step:255/1670 train_time:24162ms step_avg:94.75ms +step:256/1670 train_time:24254ms step_avg:94.74ms +step:257/1670 train_time:24346ms step_avg:94.73ms +step:258/1670 train_time:24437ms step_avg:94.72ms +step:259/1670 train_time:24528ms step_avg:94.70ms +step:260/1670 train_time:24620ms step_avg:94.69ms +step:261/1670 train_time:24711ms step_avg:94.68ms +step:262/1670 train_time:24803ms step_avg:94.67ms +step:263/1670 train_time:24897ms step_avg:94.66ms +step:264/1670 train_time:24991ms step_avg:94.66ms +step:265/1670 train_time:25084ms step_avg:94.66ms +step:266/1670 train_time:25178ms step_avg:94.65ms +step:267/1670 train_time:25270ms step_avg:94.64ms +step:268/1670 train_time:25362ms step_avg:94.63ms +step:269/1670 train_time:25454ms step_avg:94.62ms +step:270/1670 train_time:25546ms step_avg:94.61ms +step:271/1670 train_time:25637ms step_avg:94.60ms +step:272/1670 train_time:25728ms step_avg:94.59ms +step:273/1670 train_time:25820ms step_avg:94.58ms +step:274/1670 train_time:25913ms step_avg:94.57ms +step:275/1670 train_time:26005ms step_avg:94.56ms +step:276/1670 train_time:26098ms step_avg:94.56ms +step:277/1670 train_time:26191ms step_avg:94.55ms +step:278/1670 train_time:26284ms step_avg:94.55ms +step:279/1670 train_time:26376ms step_avg:94.54ms +step:280/1670 train_time:26468ms step_avg:94.53ms +step:281/1670 train_time:26560ms step_avg:94.52ms +step:282/1670 train_time:26652ms step_avg:94.51ms +step:283/1670 train_time:26744ms step_avg:94.50ms +step:284/1670 train_time:26836ms step_avg:94.49ms +step:285/1670 train_time:26927ms step_avg:94.48ms +step:286/1670 train_time:27020ms 
step_avg:94.47ms +step:287/1670 train_time:27112ms step_avg:94.47ms +step:288/1670 train_time:27204ms step_avg:94.46ms +step:289/1670 train_time:27296ms step_avg:94.45ms +step:290/1670 train_time:27388ms step_avg:94.44ms +step:291/1670 train_time:27481ms step_avg:94.43ms +step:292/1670 train_time:27573ms step_avg:94.43ms +step:293/1670 train_time:27664ms step_avg:94.42ms +step:294/1670 train_time:27757ms step_avg:94.41ms +step:295/1670 train_time:27849ms step_avg:94.40ms +step:296/1670 train_time:27942ms step_avg:94.40ms +step:297/1670 train_time:28034ms step_avg:94.39ms +step:298/1670 train_time:28126ms step_avg:94.38ms +step:299/1670 train_time:28219ms step_avg:94.38ms +step:300/1670 train_time:28311ms step_avg:94.37ms +step:301/1670 train_time:28403ms step_avg:94.36ms +step:302/1670 train_time:28496ms step_avg:94.36ms +step:303/1670 train_time:28589ms step_avg:94.35ms +step:304/1670 train_time:28681ms step_avg:94.34ms +step:305/1670 train_time:28773ms step_avg:94.34ms +step:306/1670 train_time:28865ms step_avg:94.33ms +step:307/1670 train_time:28957ms step_avg:94.32ms +step:308/1670 train_time:29050ms step_avg:94.32ms +step:309/1670 train_time:29142ms step_avg:94.31ms +step:310/1670 train_time:29235ms step_avg:94.31ms +step:311/1670 train_time:29327ms step_avg:94.30ms +step:312/1670 train_time:29420ms step_avg:94.29ms +step:313/1670 train_time:29511ms step_avg:94.29ms +step:314/1670 train_time:29604ms step_avg:94.28ms +step:315/1670 train_time:29696ms step_avg:94.27ms +step:316/1670 train_time:29788ms step_avg:94.27ms +step:317/1670 train_time:29881ms step_avg:94.26ms +step:318/1670 train_time:29973ms step_avg:94.26ms +step:319/1670 train_time:30065ms step_avg:94.25ms +step:320/1670 train_time:30158ms step_avg:94.25ms +step:321/1670 train_time:30251ms step_avg:94.24ms +step:322/1670 train_time:30343ms step_avg:94.23ms +step:323/1670 train_time:30435ms step_avg:94.23ms +step:324/1670 train_time:30528ms step_avg:94.22ms +step:325/1670 train_time:30620ms step_avg:94.22ms +step:326/1670 train_time:30712ms step_avg:94.21ms +step:327/1670 train_time:30804ms step_avg:94.20ms +step:328/1670 train_time:30896ms step_avg:94.20ms +step:329/1670 train_time:30988ms step_avg:94.19ms +step:330/1670 train_time:31081ms step_avg:94.18ms +step:331/1670 train_time:31173ms step_avg:94.18ms +step:332/1670 train_time:31265ms step_avg:94.17ms +step:333/1670 train_time:31358ms step_avg:94.17ms +step:334/1670 train_time:31450ms step_avg:94.16ms +step:335/1670 train_time:31542ms step_avg:94.16ms +step:336/1670 train_time:31634ms step_avg:94.15ms +step:337/1670 train_time:31726ms step_avg:94.14ms +step:338/1670 train_time:31818ms step_avg:94.14ms +step:339/1670 train_time:31910ms step_avg:94.13ms +step:340/1670 train_time:32002ms step_avg:94.12ms +step:341/1670 train_time:32095ms step_avg:94.12ms +step:342/1670 train_time:32187ms step_avg:94.11ms +step:343/1670 train_time:32280ms step_avg:94.11ms +step:344/1670 train_time:32372ms step_avg:94.11ms +step:345/1670 train_time:32465ms step_avg:94.10ms +step:346/1670 train_time:32556ms step_avg:94.09ms +step:347/1670 train_time:32648ms step_avg:94.09ms +step:348/1670 train_time:32741ms step_avg:94.08ms +step:349/1670 train_time:32833ms step_avg:94.08ms +step:350/1670 train_time:32925ms step_avg:94.07ms +step:351/1670 train_time:33017ms step_avg:94.07ms +step:352/1670 train_time:33109ms step_avg:94.06ms +step:353/1670 train_time:33202ms step_avg:94.06ms +step:354/1670 train_time:33294ms step_avg:94.05ms +step:355/1670 train_time:33386ms step_avg:94.04ms +step:356/1670 
train_time:33479ms step_avg:94.04ms +step:357/1670 train_time:33571ms step_avg:94.04ms +step:358/1670 train_time:33663ms step_avg:94.03ms +step:359/1670 train_time:33756ms step_avg:94.03ms +step:360/1670 train_time:33848ms step_avg:94.02ms +step:361/1670 train_time:33941ms step_avg:94.02ms +step:362/1670 train_time:34033ms step_avg:94.01ms +step:363/1670 train_time:34125ms step_avg:94.01ms +step:364/1670 train_time:34218ms step_avg:94.00ms +step:365/1670 train_time:34310ms step_avg:94.00ms +step:366/1670 train_time:34402ms step_avg:94.00ms +step:367/1670 train_time:34495ms step_avg:93.99ms +step:368/1670 train_time:34586ms step_avg:93.98ms +step:369/1670 train_time:34679ms step_avg:93.98ms +step:370/1670 train_time:34771ms step_avg:93.98ms +step:371/1670 train_time:34863ms step_avg:93.97ms +step:372/1670 train_time:34956ms step_avg:93.97ms +step:373/1670 train_time:35048ms step_avg:93.96ms +step:374/1670 train_time:35140ms step_avg:93.96ms +step:375/1670 train_time:35232ms step_avg:93.95ms +step:375/1670 val_loss:3.8152 train_time:35322ms step_avg:94.19ms +step:376/1670 train_time:35348ms step_avg:94.01ms +step:377/1670 train_time:35422ms step_avg:93.96ms +step:378/1670 train_time:35521ms step_avg:93.97ms +step:379/1670 train_time:35614ms step_avg:93.97ms +step:380/1670 train_time:35707ms step_avg:93.97ms +step:381/1670 train_time:35798ms step_avg:93.96ms +step:382/1670 train_time:35890ms step_avg:93.95ms +step:383/1670 train_time:35980ms step_avg:93.94ms +step:384/1670 train_time:36072ms step_avg:93.94ms +step:385/1670 train_time:36164ms step_avg:93.93ms +step:386/1670 train_time:36255ms step_avg:93.93ms +step:387/1670 train_time:36347ms step_avg:93.92ms +step:388/1670 train_time:36441ms step_avg:93.92ms +step:389/1670 train_time:36536ms step_avg:93.92ms +step:390/1670 train_time:36629ms step_avg:93.92ms +step:391/1670 train_time:36721ms step_avg:93.92ms +step:392/1670 train_time:36813ms step_avg:93.91ms +step:393/1670 train_time:36905ms step_avg:93.91ms +step:394/1670 train_time:36997ms step_avg:93.90ms +step:395/1670 train_time:37089ms step_avg:93.90ms +step:396/1670 train_time:37181ms step_avg:93.89ms +step:397/1670 train_time:37273ms step_avg:93.89ms +step:398/1670 train_time:37366ms step_avg:93.88ms +step:399/1670 train_time:37459ms step_avg:93.88ms +step:400/1670 train_time:37552ms step_avg:93.88ms +step:401/1670 train_time:37645ms step_avg:93.88ms +step:402/1670 train_time:37736ms step_avg:93.87ms +step:403/1670 train_time:37829ms step_avg:93.87ms +step:404/1670 train_time:37921ms step_avg:93.86ms +step:405/1670 train_time:38013ms step_avg:93.86ms +step:406/1670 train_time:38106ms step_avg:93.86ms +step:407/1670 train_time:38197ms step_avg:93.85ms +step:408/1670 train_time:38290ms step_avg:93.85ms +step:409/1670 train_time:38382ms step_avg:93.84ms +step:410/1670 train_time:38474ms step_avg:93.84ms +step:411/1670 train_time:38567ms step_avg:93.84ms +step:412/1670 train_time:38659ms step_avg:93.83ms +step:413/1670 train_time:38752ms step_avg:93.83ms +step:414/1670 train_time:38844ms step_avg:93.83ms +step:415/1670 train_time:38936ms step_avg:93.82ms +step:416/1670 train_time:39029ms step_avg:93.82ms +step:417/1670 train_time:39121ms step_avg:93.81ms +step:418/1670 train_time:39213ms step_avg:93.81ms +step:419/1670 train_time:39306ms step_avg:93.81ms +step:420/1670 train_time:39398ms step_avg:93.80ms +step:421/1670 train_time:39490ms step_avg:93.80ms +step:422/1670 train_time:39584ms step_avg:93.80ms +step:423/1670 train_time:39676ms step_avg:93.80ms +step:424/1670 train_time:39769ms 
step_avg:93.79ms +step:425/1670 train_time:40096ms step_avg:94.34ms +step:426/1670 train_time:40293ms step_avg:94.58ms +step:427/1670 train_time:40384ms step_avg:94.57ms +step:428/1670 train_time:40475ms step_avg:94.57ms +step:429/1670 train_time:40566ms step_avg:94.56ms +step:430/1670 train_time:40658ms step_avg:94.55ms +step:431/1670 train_time:40749ms step_avg:94.54ms +step:432/1670 train_time:40840ms step_avg:94.54ms +step:433/1670 train_time:40931ms step_avg:94.53ms +step:434/1670 train_time:41023ms step_avg:94.52ms +step:435/1670 train_time:41115ms step_avg:94.52ms +step:436/1670 train_time:41210ms step_avg:94.52ms +step:437/1670 train_time:41306ms step_avg:94.52ms +step:438/1670 train_time:41399ms step_avg:94.52ms +step:439/1670 train_time:41492ms step_avg:94.51ms +step:440/1670 train_time:41584ms step_avg:94.51ms +step:441/1670 train_time:41676ms step_avg:94.50ms +step:442/1670 train_time:41768ms step_avg:94.50ms +step:443/1670 train_time:41860ms step_avg:94.49ms +step:444/1670 train_time:41951ms step_avg:94.48ms +step:445/1670 train_time:42043ms step_avg:94.48ms +step:446/1670 train_time:42135ms step_avg:94.47ms +step:447/1670 train_time:42228ms step_avg:94.47ms +step:448/1670 train_time:42321ms step_avg:94.47ms +step:449/1670 train_time:42414ms step_avg:94.46ms +step:450/1670 train_time:42508ms step_avg:94.46ms +step:451/1670 train_time:42600ms step_avg:94.46ms +step:452/1670 train_time:42692ms step_avg:94.45ms +step:453/1670 train_time:42784ms step_avg:94.45ms +step:454/1670 train_time:42876ms step_avg:94.44ms +step:455/1670 train_time:42968ms step_avg:94.44ms +step:456/1670 train_time:43060ms step_avg:94.43ms +step:457/1670 train_time:43153ms step_avg:94.43ms +step:458/1670 train_time:43245ms step_avg:94.42ms +step:459/1670 train_time:43339ms step_avg:94.42ms +step:460/1670 train_time:43433ms step_avg:94.42ms +step:461/1670 train_time:43526ms step_avg:94.42ms +step:462/1670 train_time:43619ms step_avg:94.41ms +step:463/1670 train_time:43711ms step_avg:94.41ms +step:464/1670 train_time:43803ms step_avg:94.40ms +step:465/1670 train_time:43895ms step_avg:94.40ms +step:466/1670 train_time:43988ms step_avg:94.39ms +step:467/1670 train_time:44080ms step_avg:94.39ms +step:468/1670 train_time:44172ms step_avg:94.39ms +step:469/1670 train_time:44266ms step_avg:94.38ms +step:470/1670 train_time:44359ms step_avg:94.38ms +step:471/1670 train_time:44453ms step_avg:94.38ms +step:472/1670 train_time:44546ms step_avg:94.38ms +step:473/1670 train_time:44638ms step_avg:94.37ms +step:474/1670 train_time:44730ms step_avg:94.37ms +step:475/1670 train_time:44824ms step_avg:94.37ms +step:476/1670 train_time:44915ms step_avg:94.36ms +step:477/1670 train_time:45008ms step_avg:94.36ms +step:478/1670 train_time:45100ms step_avg:94.35ms +step:479/1670 train_time:45192ms step_avg:94.35ms +step:480/1670 train_time:45284ms step_avg:94.34ms +step:481/1670 train_time:45377ms step_avg:94.34ms +step:482/1670 train_time:45470ms step_avg:94.34ms +step:483/1670 train_time:45562ms step_avg:94.33ms +step:484/1670 train_time:45654ms step_avg:94.33ms +step:485/1670 train_time:45746ms step_avg:94.32ms +step:486/1670 train_time:45839ms step_avg:94.32ms +step:487/1670 train_time:45931ms step_avg:94.31ms +step:488/1670 train_time:46024ms step_avg:94.31ms +step:489/1670 train_time:46115ms step_avg:94.30ms +step:490/1670 train_time:46208ms step_avg:94.30ms +step:491/1670 train_time:46300ms step_avg:94.30ms +step:492/1670 train_time:46393ms step_avg:94.29ms +step:493/1670 train_time:46485ms step_avg:94.29ms +step:494/1670 
train_time:46577ms step_avg:94.29ms +step:495/1670 train_time:46670ms step_avg:94.28ms +step:496/1670 train_time:46762ms step_avg:94.28ms +step:497/1670 train_time:46854ms step_avg:94.27ms +step:498/1670 train_time:46946ms step_avg:94.27ms +step:499/1670 train_time:47039ms step_avg:94.27ms +step:500/1670 train_time:47131ms step_avg:94.26ms +step:500/1670 val_loss:3.7135 train_time:47222ms step_avg:94.44ms +step:501/1670 train_time:47247ms step_avg:94.31ms +step:502/1670 train_time:47321ms step_avg:94.26ms +step:503/1670 train_time:47421ms step_avg:94.28ms +step:504/1670 train_time:47517ms step_avg:94.28ms +step:505/1670 train_time:47608ms step_avg:94.27ms +step:506/1670 train_time:47699ms step_avg:94.27ms +step:507/1670 train_time:47791ms step_avg:94.26ms +step:508/1670 train_time:47882ms step_avg:94.26ms +step:509/1670 train_time:47974ms step_avg:94.25ms +step:510/1670 train_time:48065ms step_avg:94.25ms +step:511/1670 train_time:48156ms step_avg:94.24ms +step:512/1670 train_time:48248ms step_avg:94.23ms +step:513/1670 train_time:48343ms step_avg:94.24ms +step:514/1670 train_time:48438ms step_avg:94.24ms +step:515/1670 train_time:48532ms step_avg:94.24ms +step:516/1670 train_time:48625ms step_avg:94.23ms +step:517/1670 train_time:48717ms step_avg:94.23ms +step:518/1670 train_time:48809ms step_avg:94.23ms +step:519/1670 train_time:48900ms step_avg:94.22ms +step:520/1670 train_time:48991ms step_avg:94.21ms +step:521/1670 train_time:49084ms step_avg:94.21ms +step:522/1670 train_time:49176ms step_avg:94.21ms +step:523/1670 train_time:49268ms step_avg:94.20ms +step:524/1670 train_time:49361ms step_avg:94.20ms +step:525/1670 train_time:49456ms step_avg:94.20ms +step:526/1670 train_time:49549ms step_avg:94.20ms +step:527/1670 train_time:49642ms step_avg:94.20ms +step:528/1670 train_time:49734ms step_avg:94.19ms +step:529/1670 train_time:49826ms step_avg:94.19ms +step:530/1670 train_time:49918ms step_avg:94.19ms +step:531/1670 train_time:50010ms step_avg:94.18ms +step:532/1670 train_time:50102ms step_avg:94.18ms +step:533/1670 train_time:50193ms step_avg:94.17ms +step:534/1670 train_time:50286ms step_avg:94.17ms +step:535/1670 train_time:50379ms step_avg:94.17ms +step:536/1670 train_time:50473ms step_avg:94.17ms +step:537/1670 train_time:50565ms step_avg:94.16ms +step:538/1670 train_time:50659ms step_avg:94.16ms +step:539/1670 train_time:50752ms step_avg:94.16ms +step:540/1670 train_time:50844ms step_avg:94.16ms +step:541/1670 train_time:50937ms step_avg:94.15ms +step:542/1670 train_time:51028ms step_avg:94.15ms +step:543/1670 train_time:51120ms step_avg:94.14ms +step:544/1670 train_time:51212ms step_avg:94.14ms +step:545/1670 train_time:51304ms step_avg:94.14ms +step:546/1670 train_time:51397ms step_avg:94.13ms +step:547/1670 train_time:51489ms step_avg:94.13ms +step:548/1670 train_time:51582ms step_avg:94.13ms +step:549/1670 train_time:51674ms step_avg:94.12ms +step:550/1670 train_time:51766ms step_avg:94.12ms +step:551/1670 train_time:51859ms step_avg:94.12ms +step:552/1670 train_time:51952ms step_avg:94.12ms +step:553/1670 train_time:52044ms step_avg:94.11ms +step:554/1670 train_time:52136ms step_avg:94.11ms +step:555/1670 train_time:52228ms step_avg:94.10ms +step:556/1670 train_time:52320ms step_avg:94.10ms +step:557/1670 train_time:52412ms step_avg:94.10ms +step:558/1670 train_time:52615ms step_avg:94.29ms +step:559/1670 train_time:52683ms step_avg:94.25ms +step:560/1670 train_time:52775ms step_avg:94.24ms +step:561/1670 train_time:52868ms step_avg:94.24ms +step:562/1670 train_time:52961ms 
step_avg:94.24ms +step:563/1670 train_time:53054ms step_avg:94.23ms +step:564/1670 train_time:53146ms step_avg:94.23ms +step:565/1670 train_time:53239ms step_avg:94.23ms +step:566/1670 train_time:53332ms step_avg:94.23ms +step:567/1670 train_time:53424ms step_avg:94.22ms +step:568/1670 train_time:53520ms step_avg:94.23ms +step:569/1670 train_time:53619ms step_avg:94.23ms +step:570/1670 train_time:53713ms step_avg:94.23ms +step:571/1670 train_time:53806ms step_avg:94.23ms +step:572/1670 train_time:53899ms step_avg:94.23ms +step:573/1670 train_time:53993ms step_avg:94.23ms +step:574/1670 train_time:54085ms step_avg:94.23ms +step:575/1670 train_time:54178ms step_avg:94.22ms +step:576/1670 train_time:54270ms step_avg:94.22ms +step:577/1670 train_time:54363ms step_avg:94.22ms +step:578/1670 train_time:54457ms step_avg:94.22ms +step:579/1670 train_time:54553ms step_avg:94.22ms +step:580/1670 train_time:54647ms step_avg:94.22ms +step:581/1670 train_time:54742ms step_avg:94.22ms +step:582/1670 train_time:54836ms step_avg:94.22ms +step:583/1670 train_time:54929ms step_avg:94.22ms +step:584/1670 train_time:55023ms step_avg:94.22ms +step:585/1670 train_time:55116ms step_avg:94.22ms +step:586/1670 train_time:55209ms step_avg:94.21ms +step:587/1670 train_time:55302ms step_avg:94.21ms +step:588/1670 train_time:55395ms step_avg:94.21ms +step:589/1670 train_time:55490ms step_avg:94.21ms +step:590/1670 train_time:55585ms step_avg:94.21ms +step:591/1670 train_time:55679ms step_avg:94.21ms +step:592/1670 train_time:55773ms step_avg:94.21ms +step:593/1670 train_time:55866ms step_avg:94.21ms +step:594/1670 train_time:55961ms step_avg:94.21ms +step:595/1670 train_time:56055ms step_avg:94.21ms +step:596/1670 train_time:56148ms step_avg:94.21ms +step:597/1670 train_time:56242ms step_avg:94.21ms +step:598/1670 train_time:56334ms step_avg:94.20ms +step:599/1670 train_time:56427ms step_avg:94.20ms +step:600/1670 train_time:56521ms step_avg:94.20ms +step:601/1670 train_time:56615ms step_avg:94.20ms +step:602/1670 train_time:56709ms step_avg:94.20ms +step:603/1670 train_time:56803ms step_avg:94.20ms +step:604/1670 train_time:56897ms step_avg:94.20ms +step:605/1670 train_time:56991ms step_avg:94.20ms +step:606/1670 train_time:57085ms step_avg:94.20ms +step:607/1670 train_time:57178ms step_avg:94.20ms +step:608/1670 train_time:57271ms step_avg:94.20ms +step:609/1670 train_time:57365ms step_avg:94.19ms +step:610/1670 train_time:57459ms step_avg:94.19ms +step:611/1670 train_time:57553ms step_avg:94.20ms +step:612/1670 train_time:57647ms step_avg:94.19ms +step:613/1670 train_time:57741ms step_avg:94.19ms +step:614/1670 train_time:57834ms step_avg:94.19ms +step:615/1670 train_time:57928ms step_avg:94.19ms +step:616/1670 train_time:58022ms step_avg:94.19ms +step:617/1670 train_time:58117ms step_avg:94.19ms +step:618/1670 train_time:58210ms step_avg:94.19ms +step:619/1670 train_time:58303ms step_avg:94.19ms +step:620/1670 train_time:58397ms step_avg:94.19ms +step:621/1670 train_time:58490ms step_avg:94.19ms +step:622/1670 train_time:58585ms step_avg:94.19ms +step:623/1670 train_time:58679ms step_avg:94.19ms +step:624/1670 train_time:58772ms step_avg:94.19ms +step:625/1670 train_time:58866ms step_avg:94.19ms +step:625/1670 val_loss:3.6129 train_time:58958ms step_avg:94.33ms +step:626/1670 train_time:58984ms step_avg:94.22ms +step:627/1670 train_time:59066ms step_avg:94.20ms +step:628/1670 train_time:59167ms step_avg:94.21ms +step:629/1670 train_time:59261ms step_avg:94.21ms +step:630/1670 train_time:59353ms step_avg:94.21ms 
+step:631/1670 train_time:59446ms step_avg:94.21ms +step:632/1670 train_time:59539ms step_avg:94.21ms +step:633/1670 train_time:59631ms step_avg:94.20ms +step:634/1670 train_time:59724ms step_avg:94.20ms +step:635/1670 train_time:59816ms step_avg:94.20ms +step:636/1670 train_time:59909ms step_avg:94.20ms +step:637/1670 train_time:60004ms step_avg:94.20ms +step:638/1670 train_time:60100ms step_avg:94.20ms +step:639/1670 train_time:60550ms step_avg:94.76ms +step:640/1670 train_time:60621ms step_avg:94.72ms +step:641/1670 train_time:60714ms step_avg:94.72ms +step:642/1670 train_time:60806ms step_avg:94.71ms +step:643/1670 train_time:60898ms step_avg:94.71ms +step:644/1670 train_time:60991ms step_avg:94.71ms +step:645/1670 train_time:61083ms step_avg:94.70ms +step:646/1670 train_time:61176ms step_avg:94.70ms +step:647/1670 train_time:61269ms step_avg:94.70ms +step:648/1670 train_time:61361ms step_avg:94.69ms +step:649/1670 train_time:61457ms step_avg:94.70ms +step:650/1670 train_time:61555ms step_avg:94.70ms +step:651/1670 train_time:61650ms step_avg:94.70ms +step:652/1670 train_time:61745ms step_avg:94.70ms +step:653/1670 train_time:61838ms step_avg:94.70ms +step:654/1670 train_time:61931ms step_avg:94.70ms +step:655/1670 train_time:62024ms step_avg:94.69ms +step:656/1670 train_time:62116ms step_avg:94.69ms +step:657/1670 train_time:62209ms step_avg:94.69ms +step:658/1670 train_time:62302ms step_avg:94.68ms +step:659/1670 train_time:62396ms step_avg:94.68ms +step:660/1670 train_time:62490ms step_avg:94.68ms +step:661/1670 train_time:62585ms step_avg:94.68ms +step:662/1670 train_time:62680ms step_avg:94.68ms +step:663/1670 train_time:62774ms step_avg:94.68ms +step:664/1670 train_time:62868ms step_avg:94.68ms +step:665/1670 train_time:62961ms step_avg:94.68ms +step:666/1670 train_time:63054ms step_avg:94.68ms +step:667/1670 train_time:63148ms step_avg:94.67ms +step:668/1670 train_time:63241ms step_avg:94.67ms +step:669/1670 train_time:63334ms step_avg:94.67ms +step:670/1670 train_time:63427ms step_avg:94.67ms +step:671/1670 train_time:63520ms step_avg:94.67ms +step:672/1670 train_time:63615ms step_avg:94.67ms +step:673/1670 train_time:63710ms step_avg:94.67ms +step:674/1670 train_time:63804ms step_avg:94.66ms +step:675/1670 train_time:63897ms step_avg:94.66ms +step:676/1670 train_time:63990ms step_avg:94.66ms +step:677/1670 train_time:64084ms step_avg:94.66ms +step:678/1670 train_time:64177ms step_avg:94.66ms +step:679/1670 train_time:64270ms step_avg:94.65ms +step:680/1670 train_time:64363ms step_avg:94.65ms +step:681/1670 train_time:64457ms step_avg:94.65ms +step:682/1670 train_time:64552ms step_avg:94.65ms +step:683/1670 train_time:64646ms step_avg:94.65ms +step:684/1670 train_time:64740ms step_avg:94.65ms +step:685/1670 train_time:64833ms step_avg:94.65ms +step:686/1670 train_time:64927ms step_avg:94.65ms +step:687/1670 train_time:65020ms step_avg:94.64ms +step:688/1670 train_time:65114ms step_avg:94.64ms +step:689/1670 train_time:65207ms step_avg:94.64ms +step:690/1670 train_time:65301ms step_avg:94.64ms +step:691/1670 train_time:65394ms step_avg:94.64ms +step:692/1670 train_time:65488ms step_avg:94.64ms +step:693/1670 train_time:65582ms step_avg:94.63ms +step:694/1670 train_time:65675ms step_avg:94.63ms +step:695/1670 train_time:65769ms step_avg:94.63ms +step:696/1670 train_time:65863ms step_avg:94.63ms +step:697/1670 train_time:65957ms step_avg:94.63ms +step:698/1670 train_time:66051ms step_avg:94.63ms +step:699/1670 train_time:66144ms step_avg:94.63ms +step:700/1670 train_time:66237ms 
step_avg:94.62ms +step:701/1670 train_time:66330ms step_avg:94.62ms +step:702/1670 train_time:66424ms step_avg:94.62ms +step:703/1670 train_time:66518ms step_avg:94.62ms +step:704/1670 train_time:66613ms step_avg:94.62ms +step:705/1670 train_time:66707ms step_avg:94.62ms +step:706/1670 train_time:66799ms step_avg:94.62ms +step:707/1670 train_time:66893ms step_avg:94.62ms +step:708/1670 train_time:66987ms step_avg:94.61ms +step:709/1670 train_time:67080ms step_avg:94.61ms +step:710/1670 train_time:67174ms step_avg:94.61ms +step:711/1670 train_time:67267ms step_avg:94.61ms +step:712/1670 train_time:67360ms step_avg:94.61ms +step:713/1670 train_time:67453ms step_avg:94.61ms +step:714/1670 train_time:67547ms step_avg:94.60ms +step:715/1670 train_time:67641ms step_avg:94.60ms +step:716/1670 train_time:67734ms step_avg:94.60ms +step:717/1670 train_time:67829ms step_avg:94.60ms +step:718/1670 train_time:67923ms step_avg:94.60ms +step:719/1670 train_time:68017ms step_avg:94.60ms +step:720/1670 train_time:68110ms step_avg:94.60ms +step:721/1670 train_time:68203ms step_avg:94.60ms +step:722/1670 train_time:68297ms step_avg:94.59ms +step:723/1670 train_time:68390ms step_avg:94.59ms +step:724/1670 train_time:68485ms step_avg:94.59ms +step:725/1670 train_time:68578ms step_avg:94.59ms +step:726/1670 train_time:68672ms step_avg:94.59ms +step:727/1670 train_time:68766ms step_avg:94.59ms +step:728/1670 train_time:68859ms step_avg:94.59ms +step:729/1670 train_time:68954ms step_avg:94.59ms +step:730/1670 train_time:69048ms step_avg:94.59ms +step:731/1670 train_time:69141ms step_avg:94.58ms +step:732/1670 train_time:69235ms step_avg:94.58ms +step:733/1670 train_time:69328ms step_avg:94.58ms +step:734/1670 train_time:69422ms step_avg:94.58ms +step:735/1670 train_time:69515ms step_avg:94.58ms +step:736/1670 train_time:69609ms step_avg:94.58ms +step:737/1670 train_time:69702ms step_avg:94.58ms +step:738/1670 train_time:69797ms step_avg:94.58ms +step:739/1670 train_time:69890ms step_avg:94.57ms +step:740/1670 train_time:69983ms step_avg:94.57ms +step:741/1670 train_time:70077ms step_avg:94.57ms +step:742/1670 train_time:70171ms step_avg:94.57ms +step:743/1670 train_time:70264ms step_avg:94.57ms +step:744/1670 train_time:70357ms step_avg:94.57ms +step:745/1670 train_time:70452ms step_avg:94.57ms +step:746/1670 train_time:70546ms step_avg:94.57ms +step:747/1670 train_time:70639ms step_avg:94.56ms +step:748/1670 train_time:70733ms step_avg:94.56ms +step:749/1670 train_time:70827ms step_avg:94.56ms +step:750/1670 train_time:70920ms step_avg:94.56ms +step:750/1670 val_loss:3.5634 train_time:71012ms step_avg:94.68ms +step:751/1670 train_time:71037ms step_avg:94.59ms +step:752/1670 train_time:71114ms step_avg:94.57ms +step:753/1670 train_time:71216ms step_avg:94.58ms +step:754/1670 train_time:71311ms step_avg:94.58ms +step:755/1670 train_time:71404ms step_avg:94.57ms +step:756/1670 train_time:71496ms step_avg:94.57ms +step:757/1670 train_time:71589ms step_avg:94.57ms +step:758/1670 train_time:71681ms step_avg:94.57ms +step:759/1670 train_time:71774ms step_avg:94.56ms +step:760/1670 train_time:71866ms step_avg:94.56ms +step:761/1670 train_time:71959ms step_avg:94.56ms +step:762/1670 train_time:72054ms step_avg:94.56ms +step:763/1670 train_time:72150ms step_avg:94.56ms +step:764/1670 train_time:72246ms step_avg:94.56ms +step:765/1670 train_time:72340ms step_avg:94.56ms +step:766/1670 train_time:72433ms step_avg:94.56ms +step:767/1670 train_time:72526ms step_avg:94.56ms +step:768/1670 train_time:72618ms step_avg:94.55ms 
+step:769/1670 train_time:72711ms step_avg:94.55ms +step:770/1670 train_time:72804ms step_avg:94.55ms +step:771/1670 train_time:72897ms step_avg:94.55ms +step:772/1670 train_time:72990ms step_avg:94.55ms +step:773/1670 train_time:73085ms step_avg:94.55ms +step:774/1670 train_time:73180ms step_avg:94.55ms +step:775/1670 train_time:73275ms step_avg:94.55ms +step:776/1670 train_time:73369ms step_avg:94.55ms +step:777/1670 train_time:73462ms step_avg:94.55ms +step:778/1670 train_time:73556ms step_avg:94.55ms +step:779/1670 train_time:73649ms step_avg:94.54ms +step:780/1670 train_time:73742ms step_avg:94.54ms +step:781/1670 train_time:73835ms step_avg:94.54ms +step:782/1670 train_time:73928ms step_avg:94.54ms +step:783/1670 train_time:74022ms step_avg:94.54ms +step:784/1670 train_time:74117ms step_avg:94.54ms +step:785/1670 train_time:74213ms step_avg:94.54ms +step:786/1670 train_time:74306ms step_avg:94.54ms +step:787/1670 train_time:74400ms step_avg:94.54ms +step:788/1670 train_time:74494ms step_avg:94.53ms +step:789/1670 train_time:74588ms step_avg:94.53ms +step:790/1670 train_time:74681ms step_avg:94.53ms +step:791/1670 train_time:74774ms step_avg:94.53ms +step:792/1670 train_time:74868ms step_avg:94.53ms +step:793/1670 train_time:74962ms step_avg:94.53ms +step:794/1670 train_time:75056ms step_avg:94.53ms +step:795/1670 train_time:75150ms step_avg:94.53ms +step:796/1670 train_time:75244ms step_avg:94.53ms +step:797/1670 train_time:75338ms step_avg:94.53ms +step:798/1670 train_time:75432ms step_avg:94.53ms +step:799/1670 train_time:75526ms step_avg:94.53ms +step:800/1670 train_time:75619ms step_avg:94.52ms +step:801/1670 train_time:75713ms step_avg:94.52ms +step:802/1670 train_time:75806ms step_avg:94.52ms +step:803/1670 train_time:75899ms step_avg:94.52ms +step:804/1670 train_time:75993ms step_avg:94.52ms +step:805/1670 train_time:76087ms step_avg:94.52ms +step:806/1670 train_time:76180ms step_avg:94.52ms +step:807/1670 train_time:76274ms step_avg:94.52ms +step:808/1670 train_time:76368ms step_avg:94.51ms +step:809/1670 train_time:76462ms step_avg:94.51ms +step:810/1670 train_time:76555ms step_avg:94.51ms +step:811/1670 train_time:76648ms step_avg:94.51ms +step:812/1670 train_time:76742ms step_avg:94.51ms +step:813/1670 train_time:76835ms step_avg:94.51ms +step:814/1670 train_time:76929ms step_avg:94.51ms +step:815/1670 train_time:77022ms step_avg:94.51ms +step:816/1670 train_time:77116ms step_avg:94.50ms +step:817/1670 train_time:77210ms step_avg:94.50ms +step:818/1670 train_time:77303ms step_avg:94.50ms +step:819/1670 train_time:77397ms step_avg:94.50ms +step:820/1670 train_time:77492ms step_avg:94.50ms +step:821/1670 train_time:77585ms step_avg:94.50ms +step:822/1670 train_time:77679ms step_avg:94.50ms +step:823/1670 train_time:77772ms step_avg:94.50ms +step:824/1670 train_time:77866ms step_avg:94.50ms +step:825/1670 train_time:77959ms step_avg:94.50ms +step:826/1670 train_time:78053ms step_avg:94.49ms +step:827/1670 train_time:78146ms step_avg:94.49ms +step:828/1670 train_time:78239ms step_avg:94.49ms +step:829/1670 train_time:78333ms step_avg:94.49ms +step:830/1670 train_time:78426ms step_avg:94.49ms +step:831/1670 train_time:78519ms step_avg:94.49ms +step:832/1670 train_time:78614ms step_avg:94.49ms +step:833/1670 train_time:78707ms step_avg:94.49ms +step:834/1670 train_time:78801ms step_avg:94.49ms +step:835/1670 train_time:78894ms step_avg:94.48ms +step:836/1670 train_time:78988ms step_avg:94.48ms +step:837/1670 train_time:79082ms step_avg:94.48ms +step:838/1670 train_time:79175ms 
step_avg:94.48ms +step:839/1670 train_time:79269ms step_avg:94.48ms +step:840/1670 train_time:79364ms step_avg:94.48ms +step:841/1670 train_time:79457ms step_avg:94.48ms +step:842/1670 train_time:79550ms step_avg:94.48ms +step:843/1670 train_time:79644ms step_avg:94.48ms +step:844/1670 train_time:79737ms step_avg:94.48ms +step:845/1670 train_time:79832ms step_avg:94.48ms +step:846/1670 train_time:79926ms step_avg:94.48ms +step:847/1670 train_time:80020ms step_avg:94.47ms +step:848/1670 train_time:80113ms step_avg:94.47ms +step:849/1670 train_time:80207ms step_avg:94.47ms +step:850/1670 train_time:80300ms step_avg:94.47ms +step:851/1670 train_time:80657ms step_avg:94.78ms +step:852/1670 train_time:80823ms step_avg:94.86ms +step:853/1670 train_time:80915ms step_avg:94.86ms +step:854/1670 train_time:81007ms step_avg:94.86ms +step:855/1670 train_time:81099ms step_avg:94.85ms +step:856/1670 train_time:81192ms step_avg:94.85ms +step:857/1670 train_time:81285ms step_avg:94.85ms +step:858/1670 train_time:81377ms step_avg:94.85ms +step:859/1670 train_time:81470ms step_avg:94.84ms +step:860/1670 train_time:81563ms step_avg:94.84ms +step:861/1670 train_time:81656ms step_avg:94.84ms +step:862/1670 train_time:81753ms step_avg:94.84ms +step:863/1670 train_time:81851ms step_avg:94.84ms +step:864/1670 train_time:81946ms step_avg:94.84ms +step:865/1670 train_time:82039ms step_avg:94.84ms +step:866/1670 train_time:82132ms step_avg:94.84ms +step:867/1670 train_time:82225ms step_avg:94.84ms +step:868/1670 train_time:82317ms step_avg:94.84ms +step:869/1670 train_time:82411ms step_avg:94.83ms +step:870/1670 train_time:82503ms step_avg:94.83ms +step:871/1670 train_time:82596ms step_avg:94.83ms +step:872/1670 train_time:82691ms step_avg:94.83ms +step:873/1670 train_time:82787ms step_avg:94.83ms +step:874/1670 train_time:82882ms step_avg:94.83ms +step:875/1670 train_time:82976ms step_avg:94.83ms +step:875/1670 val_loss:3.5179 train_time:83069ms step_avg:94.94ms +step:876/1670 train_time:83094ms step_avg:94.86ms +step:877/1670 train_time:83171ms step_avg:94.84ms +step:878/1670 train_time:83269ms step_avg:94.84ms +step:879/1670 train_time:83364ms step_avg:94.84ms +step:880/1670 train_time:83457ms step_avg:94.84ms +step:881/1670 train_time:83550ms step_avg:94.83ms +step:882/1670 train_time:83642ms step_avg:94.83ms +step:883/1670 train_time:83735ms step_avg:94.83ms +step:884/1670 train_time:83828ms step_avg:94.83ms +step:885/1670 train_time:83920ms step_avg:94.83ms +step:886/1670 train_time:84013ms step_avg:94.82ms +step:887/1670 train_time:84110ms step_avg:94.83ms +step:888/1670 train_time:84207ms step_avg:94.83ms +step:889/1670 train_time:84303ms step_avg:94.83ms +step:890/1670 train_time:84396ms step_avg:94.83ms +step:891/1670 train_time:84490ms step_avg:94.83ms +step:892/1670 train_time:84583ms step_avg:94.82ms +step:893/1670 train_time:84676ms step_avg:94.82ms +step:894/1670 train_time:84769ms step_avg:94.82ms +step:895/1670 train_time:84862ms step_avg:94.82ms +step:896/1670 train_time:84956ms step_avg:94.82ms +step:897/1670 train_time:85050ms step_avg:94.82ms +step:898/1670 train_time:85146ms step_avg:94.82ms +step:899/1670 train_time:85240ms step_avg:94.82ms +step:900/1670 train_time:85334ms step_avg:94.82ms +step:901/1670 train_time:85428ms step_avg:94.81ms +step:902/1670 train_time:85521ms step_avg:94.81ms +step:903/1670 train_time:85614ms step_avg:94.81ms +step:904/1670 train_time:85707ms step_avg:94.81ms +step:905/1670 train_time:85800ms step_avg:94.81ms +step:906/1670 train_time:85892ms step_avg:94.80ms 
+step:907/1670 train_time:85985ms step_avg:94.80ms +step:908/1670 train_time:86079ms step_avg:94.80ms +step:909/1670 train_time:86173ms step_avg:94.80ms +step:910/1670 train_time:86268ms step_avg:94.80ms +step:911/1670 train_time:86363ms step_avg:94.80ms +step:912/1670 train_time:86457ms step_avg:94.80ms +step:913/1670 train_time:86551ms step_avg:94.80ms +step:914/1670 train_time:86644ms step_avg:94.80ms +step:915/1670 train_time:86737ms step_avg:94.80ms +step:916/1670 train_time:86831ms step_avg:94.79ms +step:917/1670 train_time:86925ms step_avg:94.79ms +step:918/1670 train_time:87019ms step_avg:94.79ms +step:919/1670 train_time:87113ms step_avg:94.79ms +step:920/1670 train_time:87206ms step_avg:94.79ms +step:921/1670 train_time:87300ms step_avg:94.79ms +step:922/1670 train_time:87395ms step_avg:94.79ms +step:923/1670 train_time:87489ms step_avg:94.79ms +step:924/1670 train_time:87582ms step_avg:94.79ms +step:925/1670 train_time:87676ms step_avg:94.78ms +step:926/1670 train_time:87769ms step_avg:94.78ms +step:927/1670 train_time:87863ms step_avg:94.78ms +step:928/1670 train_time:87956ms step_avg:94.78ms +step:929/1670 train_time:88051ms step_avg:94.78ms +step:930/1670 train_time:88145ms step_avg:94.78ms +step:931/1670 train_time:88238ms step_avg:94.78ms +step:932/1670 train_time:88332ms step_avg:94.78ms +step:933/1670 train_time:88427ms step_avg:94.78ms +step:934/1670 train_time:88520ms step_avg:94.78ms +step:935/1670 train_time:88613ms step_avg:94.77ms +step:936/1670 train_time:88707ms step_avg:94.77ms +step:937/1670 train_time:88800ms step_avg:94.77ms +step:938/1670 train_time:88893ms step_avg:94.77ms +step:939/1670 train_time:88987ms step_avg:94.77ms +step:940/1670 train_time:89081ms step_avg:94.77ms +step:941/1670 train_time:89175ms step_avg:94.77ms +step:942/1670 train_time:89268ms step_avg:94.76ms +step:943/1670 train_time:89362ms step_avg:94.76ms +step:944/1670 train_time:89456ms step_avg:94.76ms +step:945/1670 train_time:89549ms step_avg:94.76ms +step:946/1670 train_time:89643ms step_avg:94.76ms +step:947/1670 train_time:89736ms step_avg:94.76ms +step:948/1670 train_time:89830ms step_avg:94.76ms +step:949/1670 train_time:89925ms step_avg:94.76ms +step:950/1670 train_time:90019ms step_avg:94.76ms +step:951/1670 train_time:90112ms step_avg:94.75ms +step:952/1670 train_time:90205ms step_avg:94.75ms +step:953/1670 train_time:90300ms step_avg:94.75ms +step:954/1670 train_time:90393ms step_avg:94.75ms +step:955/1670 train_time:90487ms step_avg:94.75ms +step:956/1670 train_time:90580ms step_avg:94.75ms +step:957/1670 train_time:90673ms step_avg:94.75ms +step:958/1670 train_time:90767ms step_avg:94.75ms +step:959/1670 train_time:90861ms step_avg:94.75ms +step:960/1670 train_time:90955ms step_avg:94.74ms +step:961/1670 train_time:91049ms step_avg:94.74ms +step:962/1670 train_time:91143ms step_avg:94.74ms +step:963/1670 train_time:91236ms step_avg:94.74ms +step:964/1670 train_time:91331ms step_avg:94.74ms +step:965/1670 train_time:91424ms step_avg:94.74ms +step:966/1670 train_time:91518ms step_avg:94.74ms +step:967/1670 train_time:91612ms step_avg:94.74ms +step:968/1670 train_time:91705ms step_avg:94.74ms +step:969/1670 train_time:91798ms step_avg:94.73ms +step:970/1670 train_time:91892ms step_avg:94.73ms +step:971/1670 train_time:91986ms step_avg:94.73ms +step:972/1670 train_time:92080ms step_avg:94.73ms +step:973/1670 train_time:92173ms step_avg:94.73ms +step:974/1670 train_time:92266ms step_avg:94.73ms +step:975/1670 train_time:92359ms step_avg:94.73ms +step:976/1670 train_time:92454ms 
step_avg:94.73ms +step:977/1670 train_time:92548ms step_avg:94.73ms +step:978/1670 train_time:92641ms step_avg:94.73ms +step:979/1670 train_time:92735ms step_avg:94.72ms +step:980/1670 train_time:92829ms step_avg:94.72ms +step:981/1670 train_time:92923ms step_avg:94.72ms +step:982/1670 train_time:93017ms step_avg:94.72ms +step:983/1670 train_time:93111ms step_avg:94.72ms +step:984/1670 train_time:93204ms step_avg:94.72ms +step:985/1670 train_time:93297ms step_avg:94.72ms +step:986/1670 train_time:93390ms step_avg:94.72ms +step:987/1670 train_time:93484ms step_avg:94.72ms +step:988/1670 train_time:93577ms step_avg:94.71ms +step:989/1670 train_time:93671ms step_avg:94.71ms +step:990/1670 train_time:93764ms step_avg:94.71ms +step:991/1670 train_time:93859ms step_avg:94.71ms +step:992/1670 train_time:93953ms step_avg:94.71ms +step:993/1670 train_time:94048ms step_avg:94.71ms +step:994/1670 train_time:94142ms step_avg:94.71ms +step:995/1670 train_time:94235ms step_avg:94.71ms +step:996/1670 train_time:94329ms step_avg:94.71ms +step:997/1670 train_time:94423ms step_avg:94.71ms +step:998/1670 train_time:94516ms step_avg:94.71ms +step:999/1670 train_time:94609ms step_avg:94.70ms +step:1000/1670 train_time:94703ms step_avg:94.70ms +step:1000/1670 val_loss:3.4686 train_time:94795ms step_avg:94.79ms +step:1001/1670 train_time:94820ms step_avg:94.73ms +step:1002/1670 train_time:94896ms step_avg:94.71ms +step:1003/1670 train_time:94996ms step_avg:94.71ms +step:1004/1670 train_time:95091ms step_avg:94.71ms +step:1005/1670 train_time:95183ms step_avg:94.71ms +step:1006/1670 train_time:95276ms step_avg:94.71ms +step:1007/1670 train_time:95368ms step_avg:94.71ms +step:1008/1670 train_time:95461ms step_avg:94.70ms +step:1009/1670 train_time:95553ms step_avg:94.70ms +step:1010/1670 train_time:95646ms step_avg:94.70ms +step:1011/1670 train_time:95740ms step_avg:94.70ms +step:1012/1670 train_time:95834ms step_avg:94.70ms +step:1013/1670 train_time:95929ms step_avg:94.70ms +step:1014/1670 train_time:96025ms step_avg:94.70ms +step:1015/1670 train_time:96119ms step_avg:94.70ms +step:1016/1670 train_time:96212ms step_avg:94.70ms +step:1017/1670 train_time:96305ms step_avg:94.70ms +step:1018/1670 train_time:96398ms step_avg:94.69ms +step:1019/1670 train_time:96490ms step_avg:94.69ms +step:1020/1670 train_time:96584ms step_avg:94.69ms +step:1021/1670 train_time:96678ms step_avg:94.69ms +step:1022/1670 train_time:96772ms step_avg:94.69ms +step:1023/1670 train_time:96866ms step_avg:94.69ms +step:1024/1670 train_time:96963ms step_avg:94.69ms +step:1025/1670 train_time:97057ms step_avg:94.69ms +step:1026/1670 train_time:97151ms step_avg:94.69ms +step:1027/1670 train_time:97244ms step_avg:94.69ms +step:1028/1670 train_time:97338ms step_avg:94.69ms +step:1029/1670 train_time:97431ms step_avg:94.69ms +step:1030/1670 train_time:97524ms step_avg:94.68ms +step:1031/1670 train_time:97617ms step_avg:94.68ms +step:1032/1670 train_time:97710ms step_avg:94.68ms +step:1033/1670 train_time:97805ms step_avg:94.68ms +step:1034/1670 train_time:97900ms step_avg:94.68ms +step:1035/1670 train_time:97994ms step_avg:94.68ms +step:1036/1670 train_time:98088ms step_avg:94.68ms +step:1037/1670 train_time:98182ms step_avg:94.68ms +step:1038/1670 train_time:98276ms step_avg:94.68ms +step:1039/1670 train_time:98368ms step_avg:94.68ms +step:1040/1670 train_time:98462ms step_avg:94.67ms +step:1041/1670 train_time:98555ms step_avg:94.67ms +step:1042/1670 train_time:98648ms step_avg:94.67ms +step:1043/1670 train_time:98742ms step_avg:94.67ms 
+step:1044/1670 train_time:98835ms step_avg:94.67ms +step:1045/1670 train_time:98929ms step_avg:94.67ms +step:1046/1670 train_time:99024ms step_avg:94.67ms +step:1047/1670 train_time:99117ms step_avg:94.67ms +step:1048/1670 train_time:99210ms step_avg:94.67ms +step:1049/1670 train_time:99304ms step_avg:94.67ms +step:1050/1670 train_time:99398ms step_avg:94.66ms +step:1051/1670 train_time:99491ms step_avg:94.66ms +step:1052/1670 train_time:99585ms step_avg:94.66ms +step:1053/1670 train_time:99678ms step_avg:94.66ms +step:1054/1670 train_time:99771ms step_avg:94.66ms +step:1055/1670 train_time:99864ms step_avg:94.66ms +step:1056/1670 train_time:99959ms step_avg:94.66ms +step:1057/1670 train_time:100054ms step_avg:94.66ms +step:1058/1670 train_time:100147ms step_avg:94.66ms +step:1059/1670 train_time:100241ms step_avg:94.66ms +step:1060/1670 train_time:100335ms step_avg:94.66ms +step:1061/1670 train_time:100428ms step_avg:94.65ms +step:1062/1670 train_time:100757ms step_avg:94.87ms +step:1063/1670 train_time:100947ms step_avg:94.96ms +step:1064/1670 train_time:101038ms step_avg:94.96ms +step:1065/1670 train_time:101131ms step_avg:94.96ms +step:1066/1670 train_time:101223ms step_avg:94.96ms +step:1067/1670 train_time:101315ms step_avg:94.95ms +step:1068/1670 train_time:101408ms step_avg:94.95ms +step:1069/1670 train_time:101500ms step_avg:94.95ms +step:1070/1670 train_time:101592ms step_avg:94.95ms +step:1071/1670 train_time:101685ms step_avg:94.94ms +step:1072/1670 train_time:101782ms step_avg:94.95ms +step:1073/1670 train_time:101880ms step_avg:94.95ms +step:1074/1670 train_time:101975ms step_avg:94.95ms +step:1075/1670 train_time:102068ms step_avg:94.95ms +step:1076/1670 train_time:102162ms step_avg:94.95ms +step:1077/1670 train_time:102255ms step_avg:94.94ms +step:1078/1670 train_time:102348ms step_avg:94.94ms +step:1079/1670 train_time:102441ms step_avg:94.94ms +step:1080/1670 train_time:102534ms step_avg:94.94ms +step:1081/1670 train_time:102627ms step_avg:94.94ms +step:1082/1670 train_time:102720ms step_avg:94.94ms +step:1083/1670 train_time:102814ms step_avg:94.93ms +step:1084/1670 train_time:102909ms step_avg:94.93ms +step:1085/1670 train_time:103003ms step_avg:94.93ms +step:1086/1670 train_time:103097ms step_avg:94.93ms +step:1087/1670 train_time:103191ms step_avg:94.93ms +step:1088/1670 train_time:103284ms step_avg:94.93ms +step:1089/1670 train_time:103378ms step_avg:94.93ms +step:1090/1670 train_time:103470ms step_avg:94.93ms +step:1091/1670 train_time:103564ms step_avg:94.93ms +step:1092/1670 train_time:103657ms step_avg:94.92ms +step:1093/1670 train_time:103750ms step_avg:94.92ms +step:1094/1670 train_time:103845ms step_avg:94.92ms +step:1095/1670 train_time:103939ms step_avg:94.92ms +step:1096/1670 train_time:104032ms step_avg:94.92ms +step:1097/1670 train_time:104126ms step_avg:94.92ms +step:1098/1670 train_time:104221ms step_avg:94.92ms +step:1099/1670 train_time:104314ms step_avg:94.92ms +step:1100/1670 train_time:104407ms step_avg:94.92ms +step:1101/1670 train_time:104500ms step_avg:94.91ms +step:1102/1670 train_time:104593ms step_avg:94.91ms +step:1103/1670 train_time:104686ms step_avg:94.91ms +step:1104/1670 train_time:104780ms step_avg:94.91ms +step:1105/1670 train_time:104874ms step_avg:94.91ms +step:1106/1670 train_time:104967ms step_avg:94.91ms +step:1107/1670 train_time:105062ms step_avg:94.91ms +step:1108/1670 train_time:105156ms step_avg:94.91ms +step:1109/1670 train_time:105250ms step_avg:94.90ms +step:1110/1670 train_time:105344ms step_avg:94.90ms +step:1111/1670 
train_time:105437ms step_avg:94.90ms +step:1112/1670 train_time:105531ms step_avg:94.90ms +step:1113/1670 train_time:105624ms step_avg:94.90ms +step:1114/1670 train_time:105718ms step_avg:94.90ms +step:1115/1670 train_time:105922ms step_avg:95.00ms +step:1116/1670 train_time:105990ms step_avg:94.97ms +step:1117/1670 train_time:106084ms step_avg:94.97ms +step:1118/1670 train_time:106177ms step_avg:94.97ms +step:1119/1670 train_time:106269ms step_avg:94.97ms +step:1120/1670 train_time:106363ms step_avg:94.97ms +step:1121/1670 train_time:106456ms step_avg:94.96ms +step:1122/1670 train_time:106549ms step_avg:94.96ms +step:1123/1670 train_time:106642ms step_avg:94.96ms +step:1124/1670 train_time:106735ms step_avg:94.96ms +step:1125/1670 train_time:106833ms step_avg:94.96ms +step:1125/1670 val_loss:3.4160 train_time:106930ms step_avg:95.05ms +step:1126/1670 train_time:106955ms step_avg:94.99ms +step:1127/1670 train_time:107041ms step_avg:94.98ms +step:1128/1670 train_time:107142ms step_avg:94.98ms +step:1129/1670 train_time:107236ms step_avg:94.98ms +step:1130/1670 train_time:107328ms step_avg:94.98ms +step:1131/1670 train_time:107422ms step_avg:94.98ms +step:1132/1670 train_time:107515ms step_avg:94.98ms +step:1133/1670 train_time:107608ms step_avg:94.98ms +step:1134/1670 train_time:107702ms step_avg:94.98ms +step:1135/1670 train_time:107795ms step_avg:94.97ms +step:1136/1670 train_time:107888ms step_avg:94.97ms +step:1137/1670 train_time:107986ms step_avg:94.97ms +step:1138/1670 train_time:108084ms step_avg:94.98ms +step:1139/1670 train_time:108181ms step_avg:94.98ms +step:1140/1670 train_time:108276ms step_avg:94.98ms +step:1141/1670 train_time:108369ms step_avg:94.98ms +step:1142/1670 train_time:108462ms step_avg:94.98ms +step:1143/1670 train_time:108556ms step_avg:94.97ms +step:1144/1670 train_time:108650ms step_avg:94.97ms +step:1145/1670 train_time:108743ms step_avg:94.97ms +step:1146/1670 train_time:108836ms step_avg:94.97ms +step:1147/1670 train_time:108930ms step_avg:94.97ms +step:1148/1670 train_time:109026ms step_avg:94.97ms +step:1149/1670 train_time:109122ms step_avg:94.97ms +step:1150/1670 train_time:109219ms step_avg:94.97ms +step:1151/1670 train_time:109313ms step_avg:94.97ms +step:1152/1670 train_time:109407ms step_avg:94.97ms +step:1153/1670 train_time:109501ms step_avg:94.97ms +step:1154/1670 train_time:109595ms step_avg:94.97ms +step:1155/1670 train_time:109689ms step_avg:94.97ms +step:1156/1670 train_time:109783ms step_avg:94.97ms +step:1157/1670 train_time:109876ms step_avg:94.97ms +step:1158/1670 train_time:109971ms step_avg:94.97ms +step:1159/1670 train_time:110066ms step_avg:94.97ms +step:1160/1670 train_time:110162ms step_avg:94.97ms +step:1161/1670 train_time:110258ms step_avg:94.97ms +step:1162/1670 train_time:110353ms step_avg:94.97ms +step:1163/1670 train_time:110447ms step_avg:94.97ms +step:1164/1670 train_time:110541ms step_avg:94.97ms +step:1165/1670 train_time:110635ms step_avg:94.97ms +step:1166/1670 train_time:110728ms step_avg:94.96ms +step:1167/1670 train_time:110822ms step_avg:94.96ms +step:1168/1670 train_time:110916ms step_avg:94.96ms +step:1169/1670 train_time:111009ms step_avg:94.96ms +step:1170/1670 train_time:111105ms step_avg:94.96ms +step:1171/1670 train_time:111200ms step_avg:94.96ms +step:1172/1670 train_time:111296ms step_avg:94.96ms +step:1173/1670 train_time:111391ms step_avg:94.96ms +step:1174/1670 train_time:111485ms step_avg:94.96ms +step:1175/1670 train_time:111579ms step_avg:94.96ms +step:1176/1670 train_time:111673ms step_avg:94.96ms 
+step:1177/1670 train_time:111767ms step_avg:94.96ms +step:1178/1670 train_time:111861ms step_avg:94.96ms +step:1179/1670 train_time:111955ms step_avg:94.96ms +step:1180/1670 train_time:112050ms step_avg:94.96ms +step:1181/1670 train_time:112145ms step_avg:94.96ms +step:1182/1670 train_time:112240ms step_avg:94.96ms +step:1183/1670 train_time:112336ms step_avg:94.96ms +step:1184/1670 train_time:112430ms step_avg:94.96ms +step:1185/1670 train_time:112524ms step_avg:94.96ms +step:1186/1670 train_time:112619ms step_avg:94.96ms +step:1187/1670 train_time:112713ms step_avg:94.96ms +step:1188/1670 train_time:112808ms step_avg:94.96ms +step:1189/1670 train_time:112902ms step_avg:94.96ms +step:1190/1670 train_time:112996ms step_avg:94.95ms +step:1191/1670 train_time:113091ms step_avg:94.95ms +step:1192/1670 train_time:113186ms step_avg:94.95ms +step:1193/1670 train_time:113282ms step_avg:94.96ms +step:1194/1670 train_time:113377ms step_avg:94.96ms +step:1195/1670 train_time:113471ms step_avg:94.95ms +step:1196/1670 train_time:113565ms step_avg:94.95ms +step:1197/1670 train_time:113660ms step_avg:94.95ms +step:1198/1670 train_time:113755ms step_avg:94.95ms +step:1199/1670 train_time:113848ms step_avg:94.95ms +step:1200/1670 train_time:113943ms step_avg:94.95ms +step:1201/1670 train_time:114037ms step_avg:94.95ms +step:1202/1670 train_time:114131ms step_avg:94.95ms +step:1203/1670 train_time:114226ms step_avg:94.95ms +step:1204/1670 train_time:114321ms step_avg:94.95ms +step:1205/1670 train_time:114416ms step_avg:94.95ms +step:1206/1670 train_time:114509ms step_avg:94.95ms +step:1207/1670 train_time:114604ms step_avg:94.95ms +step:1208/1670 train_time:114698ms step_avg:94.95ms +step:1209/1670 train_time:114793ms step_avg:94.95ms +step:1210/1670 train_time:114887ms step_avg:94.95ms +step:1211/1670 train_time:114981ms step_avg:94.95ms +step:1212/1670 train_time:115075ms step_avg:94.95ms +step:1213/1670 train_time:115169ms step_avg:94.95ms +step:1214/1670 train_time:115264ms step_avg:94.95ms +step:1215/1670 train_time:115359ms step_avg:94.95ms +step:1216/1670 train_time:115453ms step_avg:94.95ms +step:1217/1670 train_time:115548ms step_avg:94.94ms +step:1218/1670 train_time:115642ms step_avg:94.94ms +step:1219/1670 train_time:115736ms step_avg:94.94ms +step:1220/1670 train_time:115830ms step_avg:94.94ms +step:1221/1670 train_time:115924ms step_avg:94.94ms +step:1222/1670 train_time:116018ms step_avg:94.94ms +step:1223/1670 train_time:116113ms step_avg:94.94ms +step:1224/1670 train_time:116207ms step_avg:94.94ms +step:1225/1670 train_time:116302ms step_avg:94.94ms +step:1226/1670 train_time:116397ms step_avg:94.94ms +step:1227/1670 train_time:116490ms step_avg:94.94ms +step:1228/1670 train_time:116585ms step_avg:94.94ms +step:1229/1670 train_time:116679ms step_avg:94.94ms +step:1230/1670 train_time:116774ms step_avg:94.94ms +step:1231/1670 train_time:116867ms step_avg:94.94ms +step:1232/1670 train_time:116962ms step_avg:94.94ms +step:1233/1670 train_time:117057ms step_avg:94.94ms +step:1234/1670 train_time:117151ms step_avg:94.94ms +step:1235/1670 train_time:117246ms step_avg:94.94ms +step:1236/1670 train_time:117341ms step_avg:94.94ms +step:1237/1670 train_time:117436ms step_avg:94.94ms +step:1238/1670 train_time:117530ms step_avg:94.94ms +step:1239/1670 train_time:117625ms step_avg:94.94ms +step:1240/1670 train_time:117719ms step_avg:94.93ms +step:1241/1670 train_time:117813ms step_avg:94.93ms +step:1242/1670 train_time:117907ms step_avg:94.93ms +step:1243/1670 train_time:118002ms step_avg:94.93ms 
+step:1244/1670 train_time:118096ms step_avg:94.93ms +step:1245/1670 train_time:118190ms step_avg:94.93ms +step:1246/1670 train_time:118285ms step_avg:94.93ms +step:1247/1670 train_time:118379ms step_avg:94.93ms +step:1248/1670 train_time:118473ms step_avg:94.93ms +step:1249/1670 train_time:118567ms step_avg:94.93ms +step:1250/1670 train_time:118662ms step_avg:94.93ms +step:1250/1670 val_loss:3.3767 train_time:118755ms step_avg:95.00ms +step:1251/1670 train_time:118781ms step_avg:94.95ms +step:1252/1670 train_time:118858ms step_avg:94.93ms +step:1253/1670 train_time:118958ms step_avg:94.94ms +step:1254/1670 train_time:119053ms step_avg:94.94ms +step:1255/1670 train_time:119147ms step_avg:94.94ms +step:1256/1670 train_time:119240ms step_avg:94.94ms +step:1257/1670 train_time:119334ms step_avg:94.94ms +step:1258/1670 train_time:119427ms step_avg:94.93ms +step:1259/1670 train_time:119520ms step_avg:94.93ms +step:1260/1670 train_time:119613ms step_avg:94.93ms +step:1261/1670 train_time:119707ms step_avg:94.93ms +step:1262/1670 train_time:119804ms step_avg:94.93ms +step:1263/1670 train_time:119900ms step_avg:94.93ms +step:1264/1670 train_time:119995ms step_avg:94.93ms +step:1265/1670 train_time:120089ms step_avg:94.93ms +step:1266/1670 train_time:120185ms step_avg:94.93ms +step:1267/1670 train_time:120279ms step_avg:94.93ms +step:1268/1670 train_time:120373ms step_avg:94.93ms +step:1269/1670 train_time:120467ms step_avg:94.93ms +step:1270/1670 train_time:120561ms step_avg:94.93ms +step:1271/1670 train_time:120654ms step_avg:94.93ms +step:1272/1670 train_time:120749ms step_avg:94.93ms +step:1273/1670 train_time:120846ms step_avg:94.93ms +step:1274/1670 train_time:121303ms step_avg:95.21ms +step:1275/1670 train_time:121372ms step_avg:95.19ms +step:1276/1670 train_time:121465ms step_avg:95.19ms +step:1277/1670 train_time:121559ms step_avg:95.19ms +step:1278/1670 train_time:121652ms step_avg:95.19ms +step:1279/1670 train_time:121745ms step_avg:95.19ms +step:1280/1670 train_time:121839ms step_avg:95.19ms +step:1281/1670 train_time:121932ms step_avg:95.19ms +step:1282/1670 train_time:122025ms step_avg:95.18ms +step:1283/1670 train_time:122118ms step_avg:95.18ms +step:1284/1670 train_time:122215ms step_avg:95.18ms +step:1285/1670 train_time:122312ms step_avg:95.18ms +step:1286/1670 train_time:122408ms step_avg:95.18ms +step:1287/1670 train_time:122502ms step_avg:95.18ms +step:1288/1670 train_time:122596ms step_avg:95.18ms +step:1289/1670 train_time:122691ms step_avg:95.18ms +step:1290/1670 train_time:122786ms step_avg:95.18ms +step:1291/1670 train_time:122879ms step_avg:95.18ms +step:1292/1670 train_time:122973ms step_avg:95.18ms +step:1293/1670 train_time:123067ms step_avg:95.18ms +step:1294/1670 train_time:123162ms step_avg:95.18ms +step:1295/1670 train_time:123259ms step_avg:95.18ms +step:1296/1670 train_time:123353ms step_avg:95.18ms +step:1297/1670 train_time:123449ms step_avg:95.18ms +step:1298/1670 train_time:123544ms step_avg:95.18ms +step:1299/1670 train_time:123639ms step_avg:95.18ms +step:1300/1670 train_time:123734ms step_avg:95.18ms +step:1301/1670 train_time:123828ms step_avg:95.18ms +step:1302/1670 train_time:123921ms step_avg:95.18ms +step:1303/1670 train_time:124015ms step_avg:95.18ms +step:1304/1670 train_time:124109ms step_avg:95.18ms +step:1305/1670 train_time:124205ms step_avg:95.18ms +step:1306/1670 train_time:124299ms step_avg:95.18ms +step:1307/1670 train_time:124393ms step_avg:95.17ms +step:1308/1670 train_time:124488ms step_avg:95.17ms +step:1309/1670 train_time:124583ms 
step_avg:95.17ms +step:1310/1670 train_time:124679ms step_avg:95.17ms +step:1311/1670 train_time:124772ms step_avg:95.17ms +step:1312/1670 train_time:124867ms step_avg:95.17ms +step:1313/1670 train_time:124960ms step_avg:95.17ms +step:1314/1670 train_time:125054ms step_avg:95.17ms +step:1315/1670 train_time:125149ms step_avg:95.17ms +step:1316/1670 train_time:125245ms step_avg:95.17ms +step:1317/1670 train_time:125339ms step_avg:95.17ms +step:1318/1670 train_time:125434ms step_avg:95.17ms +step:1319/1670 train_time:125529ms step_avg:95.17ms +step:1320/1670 train_time:125624ms step_avg:95.17ms +step:1321/1670 train_time:125718ms step_avg:95.17ms +step:1322/1670 train_time:125812ms step_avg:95.17ms +step:1323/1670 train_time:125907ms step_avg:95.17ms +step:1324/1670 train_time:126001ms step_avg:95.17ms +step:1325/1670 train_time:126095ms step_avg:95.17ms +step:1326/1670 train_time:126190ms step_avg:95.17ms +step:1327/1670 train_time:126285ms step_avg:95.17ms +step:1328/1670 train_time:126380ms step_avg:95.17ms +step:1329/1670 train_time:126475ms step_avg:95.17ms +step:1330/1670 train_time:126570ms step_avg:95.17ms +step:1331/1670 train_time:126665ms step_avg:95.17ms +step:1332/1670 train_time:126760ms step_avg:95.17ms +step:1333/1670 train_time:126854ms step_avg:95.16ms +step:1334/1670 train_time:126949ms step_avg:95.16ms +step:1335/1670 train_time:127043ms step_avg:95.16ms +step:1336/1670 train_time:127138ms step_avg:95.16ms +step:1337/1670 train_time:127232ms step_avg:95.16ms +step:1338/1670 train_time:127326ms step_avg:95.16ms +step:1339/1670 train_time:127421ms step_avg:95.16ms +step:1340/1670 train_time:127515ms step_avg:95.16ms +step:1341/1670 train_time:127611ms step_avg:95.16ms +step:1342/1670 train_time:127705ms step_avg:95.16ms +step:1343/1670 train_time:127800ms step_avg:95.16ms +step:1344/1670 train_time:127894ms step_avg:95.16ms +step:1345/1670 train_time:127988ms step_avg:95.16ms +step:1346/1670 train_time:128084ms step_avg:95.16ms +step:1347/1670 train_time:128178ms step_avg:95.16ms +step:1348/1670 train_time:128272ms step_avg:95.16ms +step:1349/1670 train_time:128366ms step_avg:95.16ms +step:1350/1670 train_time:128461ms step_avg:95.16ms +step:1351/1670 train_time:128555ms step_avg:95.16ms +step:1352/1670 train_time:128650ms step_avg:95.16ms +step:1353/1670 train_time:128746ms step_avg:95.16ms +step:1354/1670 train_time:128840ms step_avg:95.16ms +step:1355/1670 train_time:128933ms step_avg:95.15ms +step:1356/1670 train_time:129028ms step_avg:95.15ms +step:1357/1670 train_time:129123ms step_avg:95.15ms +step:1358/1670 train_time:129217ms step_avg:95.15ms +step:1359/1670 train_time:129311ms step_avg:95.15ms +step:1360/1670 train_time:129405ms step_avg:95.15ms +step:1361/1670 train_time:129499ms step_avg:95.15ms +step:1362/1670 train_time:129593ms step_avg:95.15ms +step:1363/1670 train_time:129688ms step_avg:95.15ms +step:1364/1670 train_time:129783ms step_avg:95.15ms +step:1365/1670 train_time:129877ms step_avg:95.15ms +step:1366/1670 train_time:129971ms step_avg:95.15ms +step:1367/1670 train_time:130066ms step_avg:95.15ms +step:1368/1670 train_time:130161ms step_avg:95.15ms +step:1369/1670 train_time:130255ms step_avg:95.15ms +step:1370/1670 train_time:130350ms step_avg:95.15ms +step:1371/1670 train_time:130445ms step_avg:95.15ms +step:1372/1670 train_time:130539ms step_avg:95.14ms +step:1373/1670 train_time:130633ms step_avg:95.14ms +step:1374/1670 train_time:130728ms step_avg:95.14ms +step:1375/1670 train_time:130823ms step_avg:95.14ms +step:1375/1670 val_loss:3.3425 
train_time:130915ms step_avg:95.21ms +step:1376/1670 train_time:130940ms step_avg:95.16ms +step:1377/1670 train_time:131019ms step_avg:95.15ms +step:1378/1670 train_time:131117ms step_avg:95.15ms +step:1379/1670 train_time:131210ms step_avg:95.15ms +step:1380/1670 train_time:131303ms step_avg:95.15ms +step:1381/1670 train_time:131397ms step_avg:95.15ms +step:1382/1670 train_time:131491ms step_avg:95.15ms +step:1383/1670 train_time:131586ms step_avg:95.15ms +step:1384/1670 train_time:131680ms step_avg:95.14ms +step:1385/1670 train_time:131773ms step_avg:95.14ms +step:1386/1670 train_time:131868ms step_avg:95.14ms +step:1387/1670 train_time:131964ms step_avg:95.14ms +step:1388/1670 train_time:132061ms step_avg:95.14ms +step:1389/1670 train_time:132157ms step_avg:95.15ms +step:1390/1670 train_time:132252ms step_avg:95.15ms +step:1391/1670 train_time:132345ms step_avg:95.14ms +step:1392/1670 train_time:132439ms step_avg:95.14ms +step:1393/1670 train_time:132533ms step_avg:95.14ms +step:1394/1670 train_time:132626ms step_avg:95.14ms +step:1395/1670 train_time:132720ms step_avg:95.14ms +step:1396/1670 train_time:132814ms step_avg:95.14ms +step:1397/1670 train_time:132910ms step_avg:95.14ms +step:1398/1670 train_time:133005ms step_avg:95.14ms +step:1399/1670 train_time:133101ms step_avg:95.14ms +step:1400/1670 train_time:133197ms step_avg:95.14ms +step:1401/1670 train_time:133291ms step_avg:95.14ms +step:1402/1670 train_time:133384ms step_avg:95.14ms +step:1403/1670 train_time:133480ms step_avg:95.14ms +step:1404/1670 train_time:133574ms step_avg:95.14ms +step:1405/1670 train_time:133667ms step_avg:95.14ms +step:1406/1670 train_time:133762ms step_avg:95.14ms +step:1407/1670 train_time:133856ms step_avg:95.14ms +step:1408/1670 train_time:133951ms step_avg:95.14ms +step:1409/1670 train_time:134046ms step_avg:95.14ms +step:1410/1670 train_time:134142ms step_avg:95.14ms +step:1411/1670 train_time:134237ms step_avg:95.14ms +step:1412/1670 train_time:134331ms step_avg:95.13ms +step:1413/1670 train_time:134425ms step_avg:95.13ms +step:1414/1670 train_time:134519ms step_avg:95.13ms +step:1415/1670 train_time:134614ms step_avg:95.13ms +step:1416/1670 train_time:134708ms step_avg:95.13ms +step:1417/1670 train_time:134802ms step_avg:95.13ms +step:1418/1670 train_time:134897ms step_avg:95.13ms +step:1419/1670 train_time:134993ms step_avg:95.13ms +step:1420/1670 train_time:135088ms step_avg:95.13ms +step:1421/1670 train_time:135183ms step_avg:95.13ms +step:1422/1670 train_time:135277ms step_avg:95.13ms +step:1423/1670 train_time:135371ms step_avg:95.13ms +step:1424/1670 train_time:135465ms step_avg:95.13ms +step:1425/1670 train_time:135560ms step_avg:95.13ms +step:1426/1670 train_time:135655ms step_avg:95.13ms +step:1427/1670 train_time:135749ms step_avg:95.13ms +step:1428/1670 train_time:135843ms step_avg:95.13ms +step:1429/1670 train_time:135939ms step_avg:95.13ms +step:1430/1670 train_time:136035ms step_avg:95.13ms +step:1431/1670 train_time:136130ms step_avg:95.13ms +step:1432/1670 train_time:136224ms step_avg:95.13ms +step:1433/1670 train_time:136319ms step_avg:95.13ms +step:1434/1670 train_time:136413ms step_avg:95.13ms +step:1435/1670 train_time:136508ms step_avg:95.13ms +step:1436/1670 train_time:136602ms step_avg:95.13ms +step:1437/1670 train_time:136697ms step_avg:95.13ms +step:1438/1670 train_time:136791ms step_avg:95.13ms +step:1439/1670 train_time:136886ms step_avg:95.13ms +step:1440/1670 train_time:136981ms step_avg:95.13ms +step:1441/1670 train_time:137077ms step_avg:95.13ms +step:1442/1670 
train_time:137172ms step_avg:95.13ms +step:1443/1670 train_time:137265ms step_avg:95.13ms +step:1444/1670 train_time:137360ms step_avg:95.12ms +step:1445/1670 train_time:137454ms step_avg:95.12ms +step:1446/1670 train_time:137548ms step_avg:95.12ms +step:1447/1670 train_time:137643ms step_avg:95.12ms +step:1448/1670 train_time:137738ms step_avg:95.12ms +step:1449/1670 train_time:137832ms step_avg:95.12ms +step:1450/1670 train_time:137927ms step_avg:95.12ms +step:1451/1670 train_time:138022ms step_avg:95.12ms +step:1452/1670 train_time:138118ms step_avg:95.12ms +step:1453/1670 train_time:138213ms step_avg:95.12ms +step:1454/1670 train_time:138307ms step_avg:95.12ms +step:1455/1670 train_time:138401ms step_avg:95.12ms +step:1456/1670 train_time:138496ms step_avg:95.12ms +step:1457/1670 train_time:138591ms step_avg:95.12ms +step:1458/1670 train_time:138684ms step_avg:95.12ms +step:1459/1670 train_time:138779ms step_avg:95.12ms +step:1460/1670 train_time:138874ms step_avg:95.12ms +step:1461/1670 train_time:138968ms step_avg:95.12ms +step:1462/1670 train_time:139063ms step_avg:95.12ms +step:1463/1670 train_time:139158ms step_avg:95.12ms +step:1464/1670 train_time:139253ms step_avg:95.12ms +step:1465/1670 train_time:139347ms step_avg:95.12ms +step:1466/1670 train_time:139442ms step_avg:95.12ms +step:1467/1670 train_time:139537ms step_avg:95.12ms +step:1468/1670 train_time:139632ms step_avg:95.12ms +step:1469/1670 train_time:139725ms step_avg:95.12ms +step:1470/1670 train_time:139820ms step_avg:95.12ms +step:1471/1670 train_time:139914ms step_avg:95.12ms +step:1472/1670 train_time:140009ms step_avg:95.11ms +step:1473/1670 train_time:140104ms step_avg:95.11ms +step:1474/1670 train_time:140199ms step_avg:95.11ms +step:1475/1670 train_time:140294ms step_avg:95.11ms +step:1476/1670 train_time:140388ms step_avg:95.11ms +step:1477/1670 train_time:140482ms step_avg:95.11ms +step:1478/1670 train_time:140577ms step_avg:95.11ms +step:1479/1670 train_time:140671ms step_avg:95.11ms +step:1480/1670 train_time:140765ms step_avg:95.11ms +step:1481/1670 train_time:140860ms step_avg:95.11ms +step:1482/1670 train_time:140955ms step_avg:95.11ms +step:1483/1670 train_time:141049ms step_avg:95.11ms +step:1484/1670 train_time:141144ms step_avg:95.11ms +step:1485/1670 train_time:141581ms step_avg:95.34ms +step:1486/1670 train_time:141650ms step_avg:95.32ms +step:1487/1670 train_time:141742ms step_avg:95.32ms +step:1488/1670 train_time:141835ms step_avg:95.32ms +step:1489/1670 train_time:141929ms step_avg:95.32ms +step:1490/1670 train_time:142022ms step_avg:95.32ms +step:1491/1670 train_time:142116ms step_avg:95.32ms +step:1492/1670 train_time:142209ms step_avg:95.31ms +step:1493/1670 train_time:142303ms step_avg:95.31ms +step:1494/1670 train_time:142397ms step_avg:95.31ms +step:1495/1670 train_time:142494ms step_avg:95.31ms +step:1496/1670 train_time:142590ms step_avg:95.31ms +step:1497/1670 train_time:142687ms step_avg:95.32ms +step:1498/1670 train_time:142783ms step_avg:95.32ms +step:1499/1670 train_time:142877ms step_avg:95.31ms +step:1500/1670 train_time:142970ms step_avg:95.31ms +step:1500/1670 val_loss:3.3124 train_time:143061ms step_avg:95.37ms +step:1501/1670 train_time:143087ms step_avg:95.33ms +step:1502/1670 train_time:143166ms step_avg:95.32ms +step:1503/1670 train_time:143269ms step_avg:95.32ms +step:1504/1670 train_time:143364ms step_avg:95.32ms +step:1505/1670 train_time:143458ms step_avg:95.32ms +step:1506/1670 train_time:143551ms step_avg:95.32ms +step:1507/1670 train_time:143644ms step_avg:95.32ms 
+step:1508/1670 train_time:143737ms step_avg:95.32ms +step:1509/1670 train_time:143830ms step_avg:95.31ms +step:1510/1670 train_time:143924ms step_avg:95.31ms +step:1511/1670 train_time:144018ms step_avg:95.31ms +step:1512/1670 train_time:144113ms step_avg:95.31ms +step:1513/1670 train_time:144211ms step_avg:95.31ms +step:1514/1670 train_time:144307ms step_avg:95.32ms +step:1515/1670 train_time:144405ms step_avg:95.32ms +step:1516/1670 train_time:144499ms step_avg:95.32ms +step:1517/1670 train_time:144592ms step_avg:95.31ms +step:1518/1670 train_time:144685ms step_avg:95.31ms +step:1519/1670 train_time:144778ms step_avg:95.31ms +step:1520/1670 train_time:144872ms step_avg:95.31ms +step:1521/1670 train_time:144966ms step_avg:95.31ms +step:1522/1670 train_time:145060ms step_avg:95.31ms +step:1523/1670 train_time:145154ms step_avg:95.31ms +step:1524/1670 train_time:145250ms step_avg:95.31ms +step:1525/1670 train_time:145347ms step_avg:95.31ms +step:1526/1670 train_time:145444ms step_avg:95.31ms +step:1527/1670 train_time:145538ms step_avg:95.31ms +step:1528/1670 train_time:145631ms step_avg:95.31ms +step:1529/1670 train_time:145725ms step_avg:95.31ms +step:1530/1670 train_time:145819ms step_avg:95.31ms +step:1531/1670 train_time:145912ms step_avg:95.31ms +step:1532/1670 train_time:146006ms step_avg:95.30ms +step:1533/1670 train_time:146101ms step_avg:95.30ms +step:1534/1670 train_time:146196ms step_avg:95.30ms +step:1535/1670 train_time:146291ms step_avg:95.30ms +step:1536/1670 train_time:146387ms step_avg:95.30ms +step:1537/1670 train_time:146482ms step_avg:95.30ms +step:1538/1670 train_time:146576ms step_avg:95.30ms +step:1539/1670 train_time:146670ms step_avg:95.30ms +step:1540/1670 train_time:146765ms step_avg:95.30ms +step:1541/1670 train_time:146859ms step_avg:95.30ms +step:1542/1670 train_time:146953ms step_avg:95.30ms +step:1543/1670 train_time:147047ms step_avg:95.30ms +step:1544/1670 train_time:147142ms step_avg:95.30ms +step:1545/1670 train_time:147237ms step_avg:95.30ms +step:1546/1670 train_time:147332ms step_avg:95.30ms +step:1547/1670 train_time:147427ms step_avg:95.30ms +step:1548/1670 train_time:147523ms step_avg:95.30ms +step:1549/1670 train_time:147617ms step_avg:95.30ms +step:1550/1670 train_time:147710ms step_avg:95.30ms +step:1551/1670 train_time:147805ms step_avg:95.30ms +step:1552/1670 train_time:147900ms step_avg:95.30ms +step:1553/1670 train_time:147993ms step_avg:95.30ms +step:1554/1670 train_time:148088ms step_avg:95.29ms +step:1555/1670 train_time:148183ms step_avg:95.29ms +step:1556/1670 train_time:148278ms step_avg:95.29ms +step:1557/1670 train_time:148373ms step_avg:95.29ms +step:1558/1670 train_time:148469ms step_avg:95.29ms +step:1559/1670 train_time:148564ms step_avg:95.29ms +step:1560/1670 train_time:148658ms step_avg:95.29ms +step:1561/1670 train_time:148752ms step_avg:95.29ms +step:1562/1670 train_time:148847ms step_avg:95.29ms +step:1563/1670 train_time:148940ms step_avg:95.29ms +step:1564/1670 train_time:149034ms step_avg:95.29ms +step:1565/1670 train_time:149129ms step_avg:95.29ms +step:1566/1670 train_time:149224ms step_avg:95.29ms +step:1567/1670 train_time:149319ms step_avg:95.29ms +step:1568/1670 train_time:149413ms step_avg:95.29ms +step:1569/1670 train_time:149508ms step_avg:95.29ms +step:1570/1670 train_time:149603ms step_avg:95.29ms +step:1571/1670 train_time:149697ms step_avg:95.29ms +step:1572/1670 train_time:149791ms step_avg:95.29ms +step:1573/1670 train_time:149885ms step_avg:95.29ms +step:1574/1670 train_time:149980ms step_avg:95.29ms 
+step:1575/1670 train_time:150074ms step_avg:95.29ms +step:1576/1670 train_time:150170ms step_avg:95.29ms +step:1577/1670 train_time:150265ms step_avg:95.29ms +step:1578/1670 train_time:150360ms step_avg:95.28ms +step:1579/1670 train_time:150454ms step_avg:95.28ms +step:1580/1670 train_time:150549ms step_avg:95.28ms +step:1581/1670 train_time:150644ms step_avg:95.28ms +step:1582/1670 train_time:150739ms step_avg:95.28ms +step:1583/1670 train_time:150832ms step_avg:95.28ms +step:1584/1670 train_time:150927ms step_avg:95.28ms +step:1585/1670 train_time:151021ms step_avg:95.28ms +step:1586/1670 train_time:151115ms step_avg:95.28ms +step:1587/1670 train_time:151210ms step_avg:95.28ms +step:1588/1670 train_time:151305ms step_avg:95.28ms +step:1589/1670 train_time:151399ms step_avg:95.28ms +step:1590/1670 train_time:151493ms step_avg:95.28ms +step:1591/1670 train_time:151589ms step_avg:95.28ms +step:1592/1670 train_time:151683ms step_avg:95.28ms +step:1593/1670 train_time:151778ms step_avg:95.28ms +step:1594/1670 train_time:151872ms step_avg:95.28ms +step:1595/1670 train_time:151967ms step_avg:95.28ms +step:1596/1670 train_time:152061ms step_avg:95.28ms +step:1597/1670 train_time:152155ms step_avg:95.28ms +step:1598/1670 train_time:152249ms step_avg:95.28ms +step:1599/1670 train_time:152344ms step_avg:95.27ms +step:1600/1670 train_time:152439ms step_avg:95.27ms +step:1601/1670 train_time:152533ms step_avg:95.27ms +step:1602/1670 train_time:152627ms step_avg:95.27ms +step:1603/1670 train_time:152722ms step_avg:95.27ms +step:1604/1670 train_time:152817ms step_avg:95.27ms +step:1605/1670 train_time:152911ms step_avg:95.27ms +step:1606/1670 train_time:153006ms step_avg:95.27ms +step:1607/1670 train_time:153101ms step_avg:95.27ms +step:1608/1670 train_time:153195ms step_avg:95.27ms +step:1609/1670 train_time:153290ms step_avg:95.27ms +step:1610/1670 train_time:153384ms step_avg:95.27ms +step:1611/1670 train_time:153478ms step_avg:95.27ms +step:1612/1670 train_time:153572ms step_avg:95.27ms +step:1613/1670 train_time:153667ms step_avg:95.27ms +step:1614/1670 train_time:153762ms step_avg:95.27ms +step:1615/1670 train_time:153857ms step_avg:95.27ms +step:1616/1670 train_time:153951ms step_avg:95.27ms +step:1617/1670 train_time:154046ms step_avg:95.27ms +step:1618/1670 train_time:154140ms step_avg:95.27ms +step:1619/1670 train_time:154234ms step_avg:95.27ms +step:1620/1670 train_time:154329ms step_avg:95.26ms +step:1621/1670 train_time:154424ms step_avg:95.26ms +step:1622/1670 train_time:154519ms step_avg:95.26ms +step:1623/1670 train_time:154613ms step_avg:95.26ms +step:1624/1670 train_time:154708ms step_avg:95.26ms +step:1625/1670 train_time:154803ms step_avg:95.26ms +step:1625/1670 val_loss:3.2875 train_time:154895ms step_avg:95.32ms +step:1626/1670 train_time:154920ms step_avg:95.28ms +step:1627/1670 train_time:154998ms step_avg:95.27ms +step:1628/1670 train_time:155103ms step_avg:95.27ms +step:1629/1670 train_time:155199ms step_avg:95.27ms +step:1630/1670 train_time:155294ms step_avg:95.27ms +step:1631/1670 train_time:155389ms step_avg:95.27ms +step:1632/1670 train_time:155482ms step_avg:95.27ms +step:1633/1670 train_time:155576ms step_avg:95.27ms +step:1634/1670 train_time:155669ms step_avg:95.27ms +step:1635/1670 train_time:155762ms step_avg:95.27ms +step:1636/1670 train_time:155856ms step_avg:95.27ms +step:1637/1670 train_time:155950ms step_avg:95.27ms +step:1638/1670 train_time:156047ms step_avg:95.27ms +step:1639/1670 train_time:156145ms step_avg:95.27ms +step:1640/1670 train_time:156240ms 
step_avg:95.27ms +step:1641/1670 train_time:156335ms step_avg:95.27ms +step:1642/1670 train_time:156430ms step_avg:95.27ms +step:1643/1670 train_time:156524ms step_avg:95.27ms +step:1644/1670 train_time:156618ms step_avg:95.27ms +step:1645/1670 train_time:156712ms step_avg:95.27ms +step:1646/1670 train_time:156805ms step_avg:95.26ms +step:1647/1670 train_time:156899ms step_avg:95.26ms +step:1648/1670 train_time:156996ms step_avg:95.26ms +step:1649/1670 train_time:157094ms step_avg:95.27ms +step:1650/1670 train_time:157190ms step_avg:95.27ms +step:1651/1670 train_time:157285ms step_avg:95.27ms +step:1652/1670 train_time:157380ms step_avg:95.27ms +step:1653/1670 train_time:157474ms step_avg:95.27ms +step:1654/1670 train_time:157568ms step_avg:95.26ms +step:1655/1670 train_time:157662ms step_avg:95.26ms +step:1656/1670 train_time:157756ms step_avg:95.26ms +step:1657/1670 train_time:157849ms step_avg:95.26ms +step:1658/1670 train_time:157944ms step_avg:95.26ms +step:1659/1670 train_time:158040ms step_avg:95.26ms +step:1660/1670 train_time:158135ms step_avg:95.26ms +step:1661/1670 train_time:158230ms step_avg:95.26ms +step:1662/1670 train_time:158324ms step_avg:95.26ms +step:1663/1670 train_time:158419ms step_avg:95.26ms +step:1664/1670 train_time:158513ms step_avg:95.26ms +step:1665/1670 train_time:158607ms step_avg:95.26ms +step:1666/1670 train_time:158702ms step_avg:95.26ms +step:1667/1670 train_time:158796ms step_avg:95.26ms +step:1668/1670 train_time:158891ms step_avg:95.26ms +step:1669/1670 train_time:158985ms step_avg:95.26ms +step:1670/1670 train_time:159082ms step_avg:95.26ms +step:1670/1670 val_loss:3.2785 train_time:159259ms step_avg:95.36ms +peak memory allocated: 32470 MiB reserved: 47536 MiB diff --git a/records/091025_Yarn/132fe599-bc5a-4237-ad14-ee33cbbd5fc0.txt b/records/091025_Yarn/132fe599-bc5a-4237-ad14-ee33cbbd5fc0.txt new file mode 100644 index 000000000..41ab4a052 --- /dev/null +++ b/records/091025_Yarn/132fe599-bc5a-4237-ad14-ee33cbbd5fc0.txt @@ -0,0 +1,2863 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +import math + +from dataclasses import dataclass +from functools import lru_cache +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + 
scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, 
c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = 
tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
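+        # What follows, for each parameter owned by this rank:
+        #   p   <- p * (1 - lr * weight_decay)            decoupled weight decay
+        #   buf <- buf + (1 - momentum) * (grad - buf)    momentum buffer lerp
+        #   g   <- grad + momentum * (buf - grad)         Nesterov-style blend
+        #   p   <- p - eff_lr * NewtonSchulz(g)           orthogonalized update
+        # with eff_lr scaled by max(1, rows/cols) ** 0.5; gradients arrive via an
+        # averaging reduce-scatter and updated parameters leave via an all-gather.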
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, 
sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + 
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws]
+            )
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
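+# Illustrative sketch of how BOSFinder packs sequences (a hypothetical helper,
+# not used by the run above; the token values are made up for the example).
+# With BOS tokens at positions 0, 5, and 9 of a 12-token shard, a single rank
+# asking for 8 tokens gets the document slices [0, 5) and [5, 9): the second
+# slice is cut at the next BOS, and the total length lands at
+# num_tokens_local + 1 to cover the one-token input/target offset.
+def _bos_finder_example():
+    tokens = torch.full((12,), 1, dtype=torch.uint16)
+    tokens[0] = tokens[5] = tokens[9] = BOS_ID
+    finder = BOSFinder(tokens, world_size=1)
+    starts, ends = finder.next_batch(num_tokens_local=8, max_seq_len=16)
+    assert starts[0] == [0, 5] and ends[0] == [5, 9]
+    return starts, ends
+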
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"yarn/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window size schedule: step through ws_schedule during training, widen to ws_validate for the final validation
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
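+
+# Worked example (editor's addition): with num_iterations=1670 and
+# cooldown_frac=0.5, get_lr holds the multiplier at 1.0 through step 835,
+# then anneals linearly toward 0.1:
+#   get_lr(0) == 1.0;  get_lr(835) == 1.0;  get_lr(1252) ~= 0.55;  get_lr(1669) ~= 0.101
+# get_ws walks ws_schedule=(3, 7, 11) in equal thirds of training, switching
+# 3 -> 7 at step 557 and 7 -> 11 at step 1114, then returns ws_validate=13
+# on the final (validation-only) step 1670.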
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 04:13:30 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 42C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 44C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 36C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 43C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 41C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 67852 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 67853 C /usr/bin/python3 614MiB | +| 0 N/A N/A 67854 C /usr/bin/python3 614MiB | +| 0 N/A N/A 67855 C /usr/bin/python3 614MiB | +| 0 N/A N/A 67856 C /usr/bin/python3 614MiB | +| 0 N/A N/A 67857 C /usr/bin/python3 614MiB | +| 0 N/A N/A 67858 C /usr/bin/python3 614MiB | +| 0 N/A N/A 67859 C /usr/bin/python3 614MiB | +| 1 N/A N/A 67853 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 67854 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 67855 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 67856 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 67857 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 67858 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 67859 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:464ms step_avg:464.33ms +step:2/1670 train_time:488ms step_avg:244.05ms +step:3/1670 train_time:557ms step_avg:185.53ms +step:4/1670 train_time:647ms step_avg:161.72ms +step:5/1670 train_time:738ms step_avg:147.57ms +step:6/1670 train_time:830ms step_avg:138.30ms +step:7/1670 train_time:921ms step_avg:131.60ms 
+step:8/1670 train_time:1013ms step_avg:126.58ms +step:9/1670 train_time:1104ms step_avg:122.67ms +step:10/1670 train_time:1196ms step_avg:119.56ms +step:11/1670 train_time:1288ms step_avg:117.05ms +step:12/1670 train_time:1381ms step_avg:115.12ms +step:13/1670 train_time:1477ms step_avg:113.60ms +step:14/1670 train_time:1571ms step_avg:112.20ms +step:15/1670 train_time:1664ms step_avg:110.91ms +step:16/1670 train_time:1756ms step_avg:109.77ms +step:17/1670 train_time:1849ms step_avg:108.74ms +step:18/1670 train_time:1941ms step_avg:107.86ms +step:19/1670 train_time:2034ms step_avg:107.03ms +step:20/1670 train_time:2126ms step_avg:106.28ms +step:21/1670 train_time:2218ms step_avg:105.61ms +step:22/1670 train_time:2310ms step_avg:105.02ms +step:23/1670 train_time:2403ms step_avg:104.49ms +step:24/1670 train_time:2497ms step_avg:104.05ms +step:25/1670 train_time:2591ms step_avg:103.62ms +step:26/1670 train_time:2684ms step_avg:103.21ms +step:27/1670 train_time:2776ms step_avg:102.81ms +step:28/1670 train_time:2869ms step_avg:102.46ms +step:29/1670 train_time:2962ms step_avg:102.13ms +step:30/1670 train_time:3054ms step_avg:101.81ms +step:31/1670 train_time:3146ms step_avg:101.49ms +step:32/1670 train_time:3238ms step_avg:101.20ms +step:33/1670 train_time:3331ms step_avg:100.93ms +step:34/1670 train_time:3424ms step_avg:100.70ms +step:35/1670 train_time:3517ms step_avg:100.48ms +step:36/1670 train_time:3610ms step_avg:100.28ms +step:37/1670 train_time:3703ms step_avg:100.07ms +step:38/1670 train_time:3796ms step_avg:99.88ms +step:39/1670 train_time:3888ms step_avg:99.70ms +step:40/1670 train_time:3980ms step_avg:99.50ms +step:41/1670 train_time:4073ms step_avg:99.33ms +step:42/1670 train_time:4165ms step_avg:99.17ms +step:43/1670 train_time:4258ms step_avg:99.03ms +step:44/1670 train_time:4351ms step_avg:98.88ms +step:45/1670 train_time:4444ms step_avg:98.75ms +step:46/1670 train_time:4537ms step_avg:98.62ms +step:47/1670 train_time:4629ms step_avg:98.49ms +step:48/1670 train_time:4722ms step_avg:98.38ms +step:49/1670 train_time:4815ms step_avg:98.27ms +step:50/1670 train_time:4908ms step_avg:98.16ms +step:51/1670 train_time:5000ms step_avg:98.05ms +step:52/1670 train_time:5093ms step_avg:97.93ms +step:53/1670 train_time:5185ms step_avg:97.82ms +step:54/1670 train_time:5277ms step_avg:97.73ms +step:55/1670 train_time:5370ms step_avg:97.63ms +step:56/1670 train_time:5462ms step_avg:97.54ms +step:57/1670 train_time:5555ms step_avg:97.45ms +step:58/1670 train_time:5648ms step_avg:97.38ms +step:59/1670 train_time:5742ms step_avg:97.32ms +step:60/1670 train_time:5833ms step_avg:97.22ms +step:61/1670 train_time:5926ms step_avg:97.15ms +step:62/1670 train_time:6019ms step_avg:97.08ms +step:63/1670 train_time:6112ms step_avg:97.01ms +step:64/1670 train_time:6204ms step_avg:96.94ms +step:65/1670 train_time:6297ms step_avg:96.87ms +step:66/1670 train_time:6389ms step_avg:96.80ms +step:67/1670 train_time:6482ms step_avg:96.74ms +step:68/1670 train_time:6575ms step_avg:96.70ms +step:69/1670 train_time:6668ms step_avg:96.63ms +step:70/1670 train_time:6760ms step_avg:96.57ms +step:71/1670 train_time:6853ms step_avg:96.52ms +step:72/1670 train_time:6946ms step_avg:96.47ms +step:73/1670 train_time:7039ms step_avg:96.43ms +step:74/1670 train_time:7132ms step_avg:96.37ms +step:75/1670 train_time:7223ms step_avg:96.31ms +step:76/1670 train_time:7316ms step_avg:96.26ms +step:77/1670 train_time:7408ms step_avg:96.21ms +step:78/1670 train_time:7501ms step_avg:96.17ms +step:79/1670 train_time:7594ms 
step_avg:96.13ms +step:80/1670 train_time:7686ms step_avg:96.08ms +step:81/1670 train_time:7778ms step_avg:96.03ms +step:82/1670 train_time:7871ms step_avg:95.99ms +step:83/1670 train_time:7963ms step_avg:95.94ms +step:84/1670 train_time:8056ms step_avg:95.91ms +step:85/1670 train_time:8149ms step_avg:95.87ms +step:86/1670 train_time:8242ms step_avg:95.84ms +step:87/1670 train_time:8334ms step_avg:95.79ms +step:88/1670 train_time:8427ms step_avg:95.76ms +step:89/1670 train_time:8520ms step_avg:95.73ms +step:90/1670 train_time:8613ms step_avg:95.69ms +step:91/1670 train_time:8705ms step_avg:95.66ms +step:92/1670 train_time:8797ms step_avg:95.62ms +step:93/1670 train_time:8890ms step_avg:95.59ms +step:94/1670 train_time:8982ms step_avg:95.55ms +step:95/1670 train_time:9075ms step_avg:95.52ms +step:96/1670 train_time:9167ms step_avg:95.49ms +step:97/1670 train_time:9259ms step_avg:95.46ms +step:98/1670 train_time:9352ms step_avg:95.43ms +step:99/1670 train_time:9445ms step_avg:95.40ms +step:100/1670 train_time:9537ms step_avg:95.37ms +step:101/1670 train_time:9630ms step_avg:95.35ms +step:102/1670 train_time:9723ms step_avg:95.32ms +step:103/1670 train_time:9816ms step_avg:95.30ms +step:104/1670 train_time:9908ms step_avg:95.27ms +step:105/1670 train_time:10000ms step_avg:95.24ms +step:106/1670 train_time:10094ms step_avg:95.22ms +step:107/1670 train_time:10186ms step_avg:95.20ms +step:108/1670 train_time:10279ms step_avg:95.17ms +step:109/1670 train_time:10371ms step_avg:95.15ms +step:110/1670 train_time:10463ms step_avg:95.12ms +step:111/1670 train_time:10555ms step_avg:95.09ms +step:112/1670 train_time:10648ms step_avg:95.07ms +step:113/1670 train_time:10741ms step_avg:95.05ms +step:114/1670 train_time:10834ms step_avg:95.03ms +step:115/1670 train_time:10926ms step_avg:95.01ms +step:116/1670 train_time:11018ms step_avg:94.98ms +step:117/1670 train_time:11111ms step_avg:94.96ms +step:118/1670 train_time:11204ms step_avg:94.95ms +step:119/1670 train_time:11296ms step_avg:94.93ms +step:120/1670 train_time:11388ms step_avg:94.90ms +step:121/1670 train_time:11481ms step_avg:94.88ms +step:122/1670 train_time:11573ms step_avg:94.86ms +step:123/1670 train_time:11665ms step_avg:94.84ms +step:124/1670 train_time:11758ms step_avg:94.83ms +step:125/1670 train_time:11852ms step_avg:94.81ms +step:125/1670 val_loss:4.2892 train_time:11942ms step_avg:95.53ms +step:126/1670 train_time:11968ms step_avg:94.99ms +step:127/1670 train_time:12040ms step_avg:94.81ms +step:128/1670 train_time:12145ms step_avg:94.89ms +step:129/1670 train_time:12240ms step_avg:94.88ms +step:130/1670 train_time:12332ms step_avg:94.86ms +step:131/1670 train_time:12424ms step_avg:94.84ms +step:132/1670 train_time:12516ms step_avg:94.82ms +step:133/1670 train_time:12608ms step_avg:94.79ms +step:134/1670 train_time:12699ms step_avg:94.77ms +step:135/1670 train_time:12791ms step_avg:94.74ms +step:136/1670 train_time:12882ms step_avg:94.72ms +step:137/1670 train_time:12974ms step_avg:94.70ms +step:138/1670 train_time:13067ms step_avg:94.69ms +step:139/1670 train_time:13161ms step_avg:94.69ms +step:140/1670 train_time:13256ms step_avg:94.69ms +step:141/1670 train_time:13349ms step_avg:94.67ms +step:142/1670 train_time:13442ms step_avg:94.66ms +step:143/1670 train_time:13534ms step_avg:94.64ms +step:144/1670 train_time:13626ms step_avg:94.63ms +step:145/1670 train_time:13719ms step_avg:94.61ms +step:146/1670 train_time:13811ms step_avg:94.59ms +step:147/1670 train_time:13903ms step_avg:94.58ms +step:148/1670 train_time:13995ms 
step_avg:94.56ms +step:149/1670 train_time:14089ms step_avg:94.56ms +step:150/1670 train_time:14183ms step_avg:94.55ms +step:151/1670 train_time:14276ms step_avg:94.55ms +step:152/1670 train_time:14369ms step_avg:94.53ms +step:153/1670 train_time:14462ms step_avg:94.52ms +step:154/1670 train_time:14555ms step_avg:94.51ms +step:155/1670 train_time:14648ms step_avg:94.50ms +step:156/1670 train_time:14740ms step_avg:94.49ms +step:157/1670 train_time:14832ms step_avg:94.47ms +step:158/1670 train_time:14924ms step_avg:94.46ms +step:159/1670 train_time:15016ms step_avg:94.44ms +step:160/1670 train_time:15110ms step_avg:94.44ms +step:161/1670 train_time:15203ms step_avg:94.43ms +step:162/1670 train_time:15296ms step_avg:94.42ms +step:163/1670 train_time:15389ms step_avg:94.41ms +step:164/1670 train_time:15481ms step_avg:94.40ms +step:165/1670 train_time:15574ms step_avg:94.39ms +step:166/1670 train_time:15666ms step_avg:94.37ms +step:167/1670 train_time:15758ms step_avg:94.36ms +step:168/1670 train_time:15850ms step_avg:94.35ms +step:169/1670 train_time:15942ms step_avg:94.33ms +step:170/1670 train_time:16035ms step_avg:94.32ms +step:171/1670 train_time:16128ms step_avg:94.31ms +step:172/1670 train_time:16220ms step_avg:94.30ms +step:173/1670 train_time:16314ms step_avg:94.30ms +step:174/1670 train_time:16407ms step_avg:94.29ms +step:175/1670 train_time:16500ms step_avg:94.28ms +step:176/1670 train_time:16593ms step_avg:94.28ms +step:177/1670 train_time:16686ms step_avg:94.27ms +step:178/1670 train_time:16779ms step_avg:94.26ms +step:179/1670 train_time:16871ms step_avg:94.25ms +step:180/1670 train_time:16963ms step_avg:94.24ms +step:181/1670 train_time:17056ms step_avg:94.23ms +step:182/1670 train_time:17148ms step_avg:94.22ms +step:183/1670 train_time:17240ms step_avg:94.21ms +step:184/1670 train_time:17333ms step_avg:94.20ms +step:185/1670 train_time:17426ms step_avg:94.19ms +step:186/1670 train_time:17519ms step_avg:94.19ms +step:187/1670 train_time:17612ms step_avg:94.18ms +step:188/1670 train_time:17705ms step_avg:94.17ms +step:189/1670 train_time:17798ms step_avg:94.17ms +step:190/1670 train_time:17890ms step_avg:94.16ms +step:191/1670 train_time:17982ms step_avg:94.15ms +step:192/1670 train_time:18074ms step_avg:94.14ms +step:193/1670 train_time:18166ms step_avg:94.13ms +step:194/1670 train_time:18259ms step_avg:94.12ms +step:195/1670 train_time:18351ms step_avg:94.11ms +step:196/1670 train_time:18444ms step_avg:94.10ms +step:197/1670 train_time:18537ms step_avg:94.10ms +step:198/1670 train_time:18630ms step_avg:94.09ms +step:199/1670 train_time:18723ms step_avg:94.08ms +step:200/1670 train_time:18815ms step_avg:94.07ms +step:201/1670 train_time:18908ms step_avg:94.07ms +step:202/1670 train_time:19000ms step_avg:94.06ms +step:203/1670 train_time:19093ms step_avg:94.05ms +step:204/1670 train_time:19185ms step_avg:94.05ms +step:205/1670 train_time:19278ms step_avg:94.04ms +step:206/1670 train_time:19371ms step_avg:94.03ms +step:207/1670 train_time:19463ms step_avg:94.03ms +step:208/1670 train_time:19556ms step_avg:94.02ms +step:209/1670 train_time:19649ms step_avg:94.01ms +step:210/1670 train_time:19742ms step_avg:94.01ms +step:211/1670 train_time:19835ms step_avg:94.00ms +step:212/1670 train_time:19928ms step_avg:94.00ms +step:213/1670 train_time:20289ms step_avg:95.25ms +step:214/1670 train_time:20401ms step_avg:95.33ms +step:215/1670 train_time:20492ms step_avg:95.31ms +step:216/1670 train_time:20583ms step_avg:95.29ms +step:217/1670 train_time:20674ms step_avg:95.27ms +step:218/1670 
train_time:20766ms step_avg:95.26ms +step:219/1670 train_time:20858ms step_avg:95.24ms +step:220/1670 train_time:20949ms step_avg:95.22ms +step:221/1670 train_time:21041ms step_avg:95.21ms +step:222/1670 train_time:21133ms step_avg:95.19ms +step:223/1670 train_time:21224ms step_avg:95.18ms +step:224/1670 train_time:21320ms step_avg:95.18ms +step:225/1670 train_time:21417ms step_avg:95.18ms +step:226/1670 train_time:21510ms step_avg:95.18ms +step:227/1670 train_time:21603ms step_avg:95.17ms +step:228/1670 train_time:21694ms step_avg:95.15ms +step:229/1670 train_time:21787ms step_avg:95.14ms +step:230/1670 train_time:21878ms step_avg:95.12ms +step:231/1670 train_time:21970ms step_avg:95.11ms +step:232/1670 train_time:22061ms step_avg:95.09ms +step:233/1670 train_time:22153ms step_avg:95.08ms +step:234/1670 train_time:22246ms step_avg:95.07ms +step:235/1670 train_time:22340ms step_avg:95.06ms +step:236/1670 train_time:22433ms step_avg:95.06ms +step:237/1670 train_time:22527ms step_avg:95.05ms +step:238/1670 train_time:22620ms step_avg:95.04ms +step:239/1670 train_time:22712ms step_avg:95.03ms +step:240/1670 train_time:22804ms step_avg:95.02ms +step:241/1670 train_time:22896ms step_avg:95.01ms +step:242/1670 train_time:22988ms step_avg:94.99ms +step:243/1670 train_time:23081ms step_avg:94.98ms +step:244/1670 train_time:23173ms step_avg:94.97ms +step:245/1670 train_time:23265ms step_avg:94.96ms +step:246/1670 train_time:23358ms step_avg:94.95ms +step:247/1670 train_time:23452ms step_avg:94.95ms +step:248/1670 train_time:23545ms step_avg:94.94ms +step:249/1670 train_time:23638ms step_avg:94.93ms +step:250/1670 train_time:23730ms step_avg:94.92ms +step:250/1670 val_loss:3.9642 train_time:23821ms step_avg:95.28ms +step:251/1670 train_time:23845ms step_avg:95.00ms +step:252/1670 train_time:23918ms step_avg:94.91ms +step:253/1670 train_time:24018ms step_avg:94.93ms +step:254/1670 train_time:24115ms step_avg:94.94ms +step:255/1670 train_time:24207ms step_avg:94.93ms +step:256/1670 train_time:24299ms step_avg:94.92ms +step:257/1670 train_time:24391ms step_avg:94.91ms +step:258/1670 train_time:24482ms step_avg:94.89ms +step:259/1670 train_time:24573ms step_avg:94.88ms +step:260/1670 train_time:24666ms step_avg:94.87ms +step:261/1670 train_time:24758ms step_avg:94.86ms +step:262/1670 train_time:24850ms step_avg:94.85ms +step:263/1670 train_time:24944ms step_avg:94.84ms +step:264/1670 train_time:25039ms step_avg:94.85ms +step:265/1670 train_time:25133ms step_avg:94.84ms +step:266/1670 train_time:25226ms step_avg:94.83ms +step:267/1670 train_time:25318ms step_avg:94.82ms +step:268/1670 train_time:25410ms step_avg:94.81ms +step:269/1670 train_time:25502ms step_avg:94.80ms +step:270/1670 train_time:25593ms step_avg:94.79ms +step:271/1670 train_time:25685ms step_avg:94.78ms +step:272/1670 train_time:25777ms step_avg:94.77ms +step:273/1670 train_time:25869ms step_avg:94.76ms +step:274/1670 train_time:25963ms step_avg:94.75ms +step:275/1670 train_time:26056ms step_avg:94.75ms +step:276/1670 train_time:26149ms step_avg:94.74ms +step:277/1670 train_time:26242ms step_avg:94.73ms +step:278/1670 train_time:26335ms step_avg:94.73ms +step:279/1670 train_time:26427ms step_avg:94.72ms +step:280/1670 train_time:26519ms step_avg:94.71ms +step:281/1670 train_time:26611ms step_avg:94.70ms +step:282/1670 train_time:26704ms step_avg:94.69ms +step:283/1670 train_time:26796ms step_avg:94.68ms +step:284/1670 train_time:26888ms step_avg:94.68ms +step:285/1670 train_time:26981ms step_avg:94.67ms +step:286/1670 train_time:27074ms 
step_avg:94.66ms +step:287/1670 train_time:27167ms step_avg:94.66ms +step:288/1670 train_time:27261ms step_avg:94.66ms +step:289/1670 train_time:27353ms step_avg:94.65ms +step:290/1670 train_time:27445ms step_avg:94.64ms +step:291/1670 train_time:27538ms step_avg:94.63ms +step:292/1670 train_time:27630ms step_avg:94.62ms +step:293/1670 train_time:27722ms step_avg:94.61ms +step:294/1670 train_time:27814ms step_avg:94.61ms +step:295/1670 train_time:27907ms step_avg:94.60ms +step:296/1670 train_time:28000ms step_avg:94.59ms +step:297/1670 train_time:28092ms step_avg:94.59ms +step:298/1670 train_time:28185ms step_avg:94.58ms +step:299/1670 train_time:28278ms step_avg:94.58ms +step:300/1670 train_time:28371ms step_avg:94.57ms +step:301/1670 train_time:28463ms step_avg:94.56ms +step:302/1670 train_time:28555ms step_avg:94.55ms +step:303/1670 train_time:28648ms step_avg:94.55ms +step:304/1670 train_time:28740ms step_avg:94.54ms +step:305/1670 train_time:28833ms step_avg:94.53ms +step:306/1670 train_time:28926ms step_avg:94.53ms +step:307/1670 train_time:29018ms step_avg:94.52ms +step:308/1670 train_time:29111ms step_avg:94.52ms +step:309/1670 train_time:29205ms step_avg:94.51ms +step:310/1670 train_time:29298ms step_avg:94.51ms +step:311/1670 train_time:29390ms step_avg:94.50ms +step:312/1670 train_time:29483ms step_avg:94.50ms +step:313/1670 train_time:29575ms step_avg:94.49ms +step:314/1670 train_time:29668ms step_avg:94.48ms +step:315/1670 train_time:29760ms step_avg:94.48ms +step:316/1670 train_time:29853ms step_avg:94.47ms +step:317/1670 train_time:29946ms step_avg:94.47ms +step:318/1670 train_time:30039ms step_avg:94.46ms +step:319/1670 train_time:30131ms step_avg:94.46ms +step:320/1670 train_time:30223ms step_avg:94.45ms +step:321/1670 train_time:30316ms step_avg:94.44ms +step:322/1670 train_time:30409ms step_avg:94.44ms +step:323/1670 train_time:30502ms step_avg:94.43ms +step:324/1670 train_time:30594ms step_avg:94.43ms +step:325/1670 train_time:30687ms step_avg:94.42ms +step:326/1670 train_time:30779ms step_avg:94.41ms +step:327/1670 train_time:30871ms step_avg:94.41ms +step:328/1670 train_time:30965ms step_avg:94.41ms +step:329/1670 train_time:31057ms step_avg:94.40ms +step:330/1670 train_time:31149ms step_avg:94.39ms +step:331/1670 train_time:31242ms step_avg:94.39ms +step:332/1670 train_time:31334ms step_avg:94.38ms +step:333/1670 train_time:31427ms step_avg:94.38ms +step:334/1670 train_time:31520ms step_avg:94.37ms +step:335/1670 train_time:31612ms step_avg:94.36ms +step:336/1670 train_time:31705ms step_avg:94.36ms +step:337/1670 train_time:31797ms step_avg:94.35ms +step:338/1670 train_time:31890ms step_avg:94.35ms +step:339/1670 train_time:31982ms step_avg:94.34ms +step:340/1670 train_time:32074ms step_avg:94.34ms +step:341/1670 train_time:32168ms step_avg:94.33ms +step:342/1670 train_time:32260ms step_avg:94.33ms +step:343/1670 train_time:32353ms step_avg:94.32ms +step:344/1670 train_time:32446ms step_avg:94.32ms +step:345/1670 train_time:32538ms step_avg:94.31ms +step:346/1670 train_time:32631ms step_avg:94.31ms +step:347/1670 train_time:32724ms step_avg:94.30ms +step:348/1670 train_time:32816ms step_avg:94.30ms +step:349/1670 train_time:32909ms step_avg:94.30ms +step:350/1670 train_time:33002ms step_avg:94.29ms +step:351/1670 train_time:33094ms step_avg:94.29ms +step:352/1670 train_time:33187ms step_avg:94.28ms +step:353/1670 train_time:33280ms step_avg:94.28ms +step:354/1670 train_time:33372ms step_avg:94.27ms +step:355/1670 train_time:33465ms step_avg:94.27ms +step:356/1670 
train_time:33557ms step_avg:94.26ms +step:357/1670 train_time:33649ms step_avg:94.25ms +step:358/1670 train_time:33742ms step_avg:94.25ms +step:359/1670 train_time:33834ms step_avg:94.25ms +step:360/1670 train_time:33927ms step_avg:94.24ms +step:361/1670 train_time:34019ms step_avg:94.24ms +step:362/1670 train_time:34112ms step_avg:94.23ms +step:363/1670 train_time:34205ms step_avg:94.23ms +step:364/1670 train_time:34298ms step_avg:94.22ms +step:365/1670 train_time:34390ms step_avg:94.22ms +step:366/1670 train_time:34483ms step_avg:94.22ms +step:367/1670 train_time:34576ms step_avg:94.21ms +step:368/1670 train_time:34668ms step_avg:94.21ms +step:369/1670 train_time:34760ms step_avg:94.20ms +step:370/1670 train_time:34852ms step_avg:94.20ms +step:371/1670 train_time:34945ms step_avg:94.19ms +step:372/1670 train_time:35038ms step_avg:94.19ms +step:373/1670 train_time:35130ms step_avg:94.18ms +step:374/1670 train_time:35223ms step_avg:94.18ms +step:375/1670 train_time:35315ms step_avg:94.17ms +step:375/1670 val_loss:3.8157 train_time:35407ms step_avg:94.42ms +step:376/1670 train_time:35432ms step_avg:94.23ms +step:377/1670 train_time:35508ms step_avg:94.18ms +step:378/1670 train_time:35608ms step_avg:94.20ms +step:379/1670 train_time:35706ms step_avg:94.21ms +step:380/1670 train_time:35798ms step_avg:94.21ms +step:381/1670 train_time:35890ms step_avg:94.20ms +step:382/1670 train_time:35981ms step_avg:94.19ms +step:383/1670 train_time:36073ms step_avg:94.19ms +step:384/1670 train_time:36164ms step_avg:94.18ms +step:385/1670 train_time:36256ms step_avg:94.17ms +step:386/1670 train_time:36348ms step_avg:94.16ms +step:387/1670 train_time:36440ms step_avg:94.16ms +step:388/1670 train_time:36535ms step_avg:94.16ms +step:389/1670 train_time:36630ms step_avg:94.16ms +step:390/1670 train_time:36724ms step_avg:94.16ms +step:391/1670 train_time:36816ms step_avg:94.16ms +step:392/1670 train_time:36908ms step_avg:94.15ms +step:393/1670 train_time:37000ms step_avg:94.15ms +step:394/1670 train_time:37092ms step_avg:94.14ms +step:395/1670 train_time:37183ms step_avg:94.13ms +step:396/1670 train_time:37275ms step_avg:94.13ms +step:397/1670 train_time:37367ms step_avg:94.12ms +step:398/1670 train_time:37459ms step_avg:94.12ms +step:399/1670 train_time:37553ms step_avg:94.12ms +step:400/1670 train_time:37648ms step_avg:94.12ms +step:401/1670 train_time:37740ms step_avg:94.12ms +step:402/1670 train_time:37834ms step_avg:94.11ms +step:403/1670 train_time:37926ms step_avg:94.11ms +step:404/1670 train_time:38018ms step_avg:94.10ms +step:405/1670 train_time:38110ms step_avg:94.10ms +step:406/1670 train_time:38201ms step_avg:94.09ms +step:407/1670 train_time:38294ms step_avg:94.09ms +step:408/1670 train_time:38386ms step_avg:94.08ms +step:409/1670 train_time:38478ms step_avg:94.08ms +step:410/1670 train_time:38571ms step_avg:94.08ms +step:411/1670 train_time:38664ms step_avg:94.07ms +step:412/1670 train_time:38757ms step_avg:94.07ms +step:413/1670 train_time:38851ms step_avg:94.07ms +step:414/1670 train_time:38944ms step_avg:94.07ms +step:415/1670 train_time:39036ms step_avg:94.06ms +step:416/1670 train_time:39129ms step_avg:94.06ms +step:417/1670 train_time:39221ms step_avg:94.05ms +step:418/1670 train_time:39313ms step_avg:94.05ms +step:419/1670 train_time:39405ms step_avg:94.04ms +step:420/1670 train_time:39497ms step_avg:94.04ms +step:421/1670 train_time:39589ms step_avg:94.04ms +step:422/1670 train_time:39683ms step_avg:94.04ms +step:423/1670 train_time:39777ms step_avg:94.04ms +step:424/1670 train_time:39869ms 
step_avg:94.03ms +step:425/1670 train_time:40193ms step_avg:94.57ms +step:426/1670 train_time:40386ms step_avg:94.80ms +step:427/1670 train_time:40477ms step_avg:94.79ms +step:428/1670 train_time:40567ms step_avg:94.78ms +step:429/1670 train_time:40659ms step_avg:94.78ms +step:430/1670 train_time:40750ms step_avg:94.77ms +step:431/1670 train_time:40842ms step_avg:94.76ms +step:432/1670 train_time:40934ms step_avg:94.75ms +step:433/1670 train_time:41025ms step_avg:94.75ms +step:434/1670 train_time:41117ms step_avg:94.74ms +step:435/1670 train_time:41209ms step_avg:94.73ms +step:436/1670 train_time:41303ms step_avg:94.73ms +step:437/1670 train_time:41399ms step_avg:94.73ms +step:438/1670 train_time:41494ms step_avg:94.74ms +step:439/1670 train_time:41587ms step_avg:94.73ms +step:440/1670 train_time:41679ms step_avg:94.73ms +step:441/1670 train_time:41771ms step_avg:94.72ms +step:442/1670 train_time:41863ms step_avg:94.71ms +step:443/1670 train_time:41955ms step_avg:94.71ms +step:444/1670 train_time:42047ms step_avg:94.70ms +step:445/1670 train_time:42139ms step_avg:94.69ms +step:446/1670 train_time:42232ms step_avg:94.69ms +step:447/1670 train_time:42326ms step_avg:94.69ms +step:448/1670 train_time:42420ms step_avg:94.69ms +step:449/1670 train_time:42514ms step_avg:94.68ms +step:450/1670 train_time:42607ms step_avg:94.68ms +step:451/1670 train_time:42699ms step_avg:94.68ms +step:452/1670 train_time:42791ms step_avg:94.67ms +step:453/1670 train_time:42884ms step_avg:94.67ms +step:454/1670 train_time:42976ms step_avg:94.66ms +step:455/1670 train_time:43067ms step_avg:94.65ms +step:456/1670 train_time:43159ms step_avg:94.65ms +step:457/1670 train_time:43252ms step_avg:94.64ms +step:458/1670 train_time:43346ms step_avg:94.64ms +step:459/1670 train_time:43438ms step_avg:94.64ms +step:460/1670 train_time:43532ms step_avg:94.63ms +step:461/1670 train_time:43624ms step_avg:94.63ms +step:462/1670 train_time:43716ms step_avg:94.62ms +step:463/1670 train_time:43808ms step_avg:94.62ms +step:464/1670 train_time:43901ms step_avg:94.61ms +step:465/1670 train_time:43993ms step_avg:94.61ms +step:466/1670 train_time:44085ms step_avg:94.60ms +step:467/1670 train_time:44177ms step_avg:94.60ms +step:468/1670 train_time:44270ms step_avg:94.59ms +step:469/1670 train_time:44363ms step_avg:94.59ms +step:470/1670 train_time:44456ms step_avg:94.59ms +step:471/1670 train_time:44549ms step_avg:94.58ms +step:472/1670 train_time:44642ms step_avg:94.58ms +step:473/1670 train_time:44735ms step_avg:94.58ms +step:474/1670 train_time:44827ms step_avg:94.57ms +step:475/1670 train_time:44918ms step_avg:94.57ms +step:476/1670 train_time:45011ms step_avg:94.56ms +step:477/1670 train_time:45103ms step_avg:94.56ms +step:478/1670 train_time:45196ms step_avg:94.55ms +step:479/1670 train_time:45288ms step_avg:94.55ms +step:480/1670 train_time:45381ms step_avg:94.54ms +step:481/1670 train_time:45474ms step_avg:94.54ms +step:482/1670 train_time:45567ms step_avg:94.54ms +step:483/1670 train_time:45660ms step_avg:94.53ms +step:484/1670 train_time:45753ms step_avg:94.53ms +step:485/1670 train_time:45845ms step_avg:94.53ms +step:486/1670 train_time:45938ms step_avg:94.52ms +step:487/1670 train_time:46030ms step_avg:94.52ms +step:488/1670 train_time:46123ms step_avg:94.51ms +step:489/1670 train_time:46215ms step_avg:94.51ms +step:490/1670 train_time:46307ms step_avg:94.50ms +step:491/1670 train_time:46399ms step_avg:94.50ms +step:492/1670 train_time:46492ms step_avg:94.50ms +step:493/1670 train_time:46584ms step_avg:94.49ms +step:494/1670 
train_time:46677ms step_avg:94.49ms +step:495/1670 train_time:46770ms step_avg:94.48ms +step:496/1670 train_time:46862ms step_avg:94.48ms +step:497/1670 train_time:46955ms step_avg:94.48ms +step:498/1670 train_time:47047ms step_avg:94.47ms +step:499/1670 train_time:47140ms step_avg:94.47ms +step:500/1670 train_time:47232ms step_avg:94.46ms +step:500/1670 val_loss:3.7170 train_time:47323ms step_avg:94.65ms +step:501/1670 train_time:47348ms step_avg:94.51ms +step:502/1670 train_time:47420ms step_avg:94.46ms +step:503/1670 train_time:47520ms step_avg:94.47ms +step:504/1670 train_time:47615ms step_avg:94.47ms +step:505/1670 train_time:47707ms step_avg:94.47ms +step:506/1670 train_time:47799ms step_avg:94.46ms +step:507/1670 train_time:47891ms step_avg:94.46ms +step:508/1670 train_time:47982ms step_avg:94.45ms +step:509/1670 train_time:48074ms step_avg:94.45ms +step:510/1670 train_time:48165ms step_avg:94.44ms +step:511/1670 train_time:48256ms step_avg:94.43ms +step:512/1670 train_time:48348ms step_avg:94.43ms +step:513/1670 train_time:48443ms step_avg:94.43ms +step:514/1670 train_time:48537ms step_avg:94.43ms +step:515/1670 train_time:48631ms step_avg:94.43ms +step:516/1670 train_time:48724ms step_avg:94.43ms +step:517/1670 train_time:48816ms step_avg:94.42ms +step:518/1670 train_time:48908ms step_avg:94.42ms +step:519/1670 train_time:49000ms step_avg:94.41ms +step:520/1670 train_time:49093ms step_avg:94.41ms +step:521/1670 train_time:49184ms step_avg:94.40ms +step:522/1670 train_time:49276ms step_avg:94.40ms +step:523/1670 train_time:49368ms step_avg:94.39ms +step:524/1670 train_time:49462ms step_avg:94.39ms +step:525/1670 train_time:49555ms step_avg:94.39ms +step:526/1670 train_time:49648ms step_avg:94.39ms +step:527/1670 train_time:49741ms step_avg:94.38ms +step:528/1670 train_time:49833ms step_avg:94.38ms +step:529/1670 train_time:49925ms step_avg:94.38ms +step:530/1670 train_time:50018ms step_avg:94.37ms +step:531/1670 train_time:50110ms step_avg:94.37ms +step:532/1670 train_time:50203ms step_avg:94.37ms +step:533/1670 train_time:50296ms step_avg:94.36ms +step:534/1670 train_time:50388ms step_avg:94.36ms +step:535/1670 train_time:50481ms step_avg:94.36ms +step:536/1670 train_time:50574ms step_avg:94.35ms +step:537/1670 train_time:50666ms step_avg:94.35ms +step:538/1670 train_time:50760ms step_avg:94.35ms +step:539/1670 train_time:50853ms step_avg:94.35ms +step:540/1670 train_time:50945ms step_avg:94.34ms +step:541/1670 train_time:51037ms step_avg:94.34ms +step:542/1670 train_time:51129ms step_avg:94.33ms +step:543/1670 train_time:51221ms step_avg:94.33ms +step:544/1670 train_time:51313ms step_avg:94.33ms +step:545/1670 train_time:51406ms step_avg:94.32ms +step:546/1670 train_time:51499ms step_avg:94.32ms +step:547/1670 train_time:51595ms step_avg:94.32ms +step:548/1670 train_time:51687ms step_avg:94.32ms +step:549/1670 train_time:51778ms step_avg:94.31ms +step:550/1670 train_time:51871ms step_avg:94.31ms +step:551/1670 train_time:51964ms step_avg:94.31ms +step:552/1670 train_time:52056ms step_avg:94.30ms +step:553/1670 train_time:52148ms step_avg:94.30ms +step:554/1670 train_time:52242ms step_avg:94.30ms +step:555/1670 train_time:52334ms step_avg:94.30ms +step:556/1670 train_time:52427ms step_avg:94.29ms +step:557/1670 train_time:52520ms step_avg:94.29ms +step:558/1670 train_time:52713ms step_avg:94.47ms +step:559/1670 train_time:52790ms step_avg:94.44ms +step:560/1670 train_time:52882ms step_avg:94.43ms +step:561/1670 train_time:52975ms step_avg:94.43ms +step:562/1670 train_time:53068ms 
step_avg:94.43ms +step:563/1670 train_time:53161ms step_avg:94.42ms +step:564/1670 train_time:53253ms step_avg:94.42ms +step:565/1670 train_time:53346ms step_avg:94.42ms +step:566/1670 train_time:53438ms step_avg:94.41ms +step:567/1670 train_time:53531ms step_avg:94.41ms +step:568/1670 train_time:53629ms step_avg:94.42ms +step:569/1670 train_time:53727ms step_avg:94.42ms +step:570/1670 train_time:53822ms step_avg:94.43ms +step:571/1670 train_time:53916ms step_avg:94.42ms +step:572/1670 train_time:54009ms step_avg:94.42ms +step:573/1670 train_time:54102ms step_avg:94.42ms +step:574/1670 train_time:54196ms step_avg:94.42ms +step:575/1670 train_time:54289ms step_avg:94.42ms +step:576/1670 train_time:54382ms step_avg:94.41ms +step:577/1670 train_time:54475ms step_avg:94.41ms +step:578/1670 train_time:54568ms step_avg:94.41ms +step:579/1670 train_time:54663ms step_avg:94.41ms +step:580/1670 train_time:54758ms step_avg:94.41ms +step:581/1670 train_time:54852ms step_avg:94.41ms +step:582/1670 train_time:54946ms step_avg:94.41ms +step:583/1670 train_time:55039ms step_avg:94.41ms +step:584/1670 train_time:55133ms step_avg:94.41ms +step:585/1670 train_time:55228ms step_avg:94.41ms +step:586/1670 train_time:55321ms step_avg:94.40ms +step:587/1670 train_time:55413ms step_avg:94.40ms +step:588/1670 train_time:55507ms step_avg:94.40ms +step:589/1670 train_time:55602ms step_avg:94.40ms +step:590/1670 train_time:55697ms step_avg:94.40ms +step:591/1670 train_time:55791ms step_avg:94.40ms +step:592/1670 train_time:55885ms step_avg:94.40ms +step:593/1670 train_time:55979ms step_avg:94.40ms +step:594/1670 train_time:56073ms step_avg:94.40ms +step:595/1670 train_time:56167ms step_avg:94.40ms +step:596/1670 train_time:56260ms step_avg:94.40ms +step:597/1670 train_time:56353ms step_avg:94.39ms +step:598/1670 train_time:56446ms step_avg:94.39ms +step:599/1670 train_time:56539ms step_avg:94.39ms +step:600/1670 train_time:56634ms step_avg:94.39ms +step:601/1670 train_time:56728ms step_avg:94.39ms +step:602/1670 train_time:56822ms step_avg:94.39ms +step:603/1670 train_time:56916ms step_avg:94.39ms +step:604/1670 train_time:57010ms step_avg:94.39ms +step:605/1670 train_time:57105ms step_avg:94.39ms +step:606/1670 train_time:57198ms step_avg:94.39ms +step:607/1670 train_time:57292ms step_avg:94.39ms +step:608/1670 train_time:57385ms step_avg:94.38ms +step:609/1670 train_time:57479ms step_avg:94.38ms +step:610/1670 train_time:57572ms step_avg:94.38ms +step:611/1670 train_time:57666ms step_avg:94.38ms +step:612/1670 train_time:57761ms step_avg:94.38ms +step:613/1670 train_time:57855ms step_avg:94.38ms +step:614/1670 train_time:57949ms step_avg:94.38ms +step:615/1670 train_time:58043ms step_avg:94.38ms +step:616/1670 train_time:58137ms step_avg:94.38ms +step:617/1670 train_time:58232ms step_avg:94.38ms +step:618/1670 train_time:58325ms step_avg:94.38ms +step:619/1670 train_time:58418ms step_avg:94.37ms +step:620/1670 train_time:58511ms step_avg:94.37ms +step:621/1670 train_time:58605ms step_avg:94.37ms +step:622/1670 train_time:58700ms step_avg:94.37ms +step:623/1670 train_time:58793ms step_avg:94.37ms +step:624/1670 train_time:58888ms step_avg:94.37ms +step:625/1670 train_time:58982ms step_avg:94.37ms +step:625/1670 val_loss:3.6148 train_time:59074ms step_avg:94.52ms +step:626/1670 train_time:59099ms step_avg:94.41ms +step:627/1670 train_time:59181ms step_avg:94.39ms +step:628/1670 train_time:59281ms step_avg:94.40ms +step:629/1670 train_time:59376ms step_avg:94.40ms +step:630/1670 train_time:59468ms step_avg:94.39ms 
+step:631/1670 train_time:59561ms step_avg:94.39ms +step:632/1670 train_time:59654ms step_avg:94.39ms +step:633/1670 train_time:59747ms step_avg:94.39ms +step:634/1670 train_time:59840ms step_avg:94.38ms +step:635/1670 train_time:59932ms step_avg:94.38ms +step:636/1670 train_time:60024ms step_avg:94.38ms +step:637/1670 train_time:60120ms step_avg:94.38ms +step:638/1670 train_time:60216ms step_avg:94.38ms +step:639/1670 train_time:60666ms step_avg:94.94ms +step:640/1670 train_time:60738ms step_avg:94.90ms +step:641/1670 train_time:60830ms step_avg:94.90ms +step:642/1670 train_time:60923ms step_avg:94.90ms +step:643/1670 train_time:61016ms step_avg:94.89ms +step:644/1670 train_time:61108ms step_avg:94.89ms +step:645/1670 train_time:61201ms step_avg:94.88ms +step:646/1670 train_time:61293ms step_avg:94.88ms +step:647/1670 train_time:61386ms step_avg:94.88ms +step:648/1670 train_time:61478ms step_avg:94.87ms +step:649/1670 train_time:61574ms step_avg:94.88ms +step:650/1670 train_time:61676ms step_avg:94.89ms +step:651/1670 train_time:61773ms step_avg:94.89ms +step:652/1670 train_time:61867ms step_avg:94.89ms +step:653/1670 train_time:61960ms step_avg:94.88ms +step:654/1670 train_time:62053ms step_avg:94.88ms +step:655/1670 train_time:62145ms step_avg:94.88ms +step:656/1670 train_time:62238ms step_avg:94.87ms +step:657/1670 train_time:62331ms step_avg:94.87ms +step:658/1670 train_time:62423ms step_avg:94.87ms +step:659/1670 train_time:62517ms step_avg:94.87ms +step:660/1670 train_time:62612ms step_avg:94.87ms +step:661/1670 train_time:62706ms step_avg:94.87ms +step:662/1670 train_time:62801ms step_avg:94.87ms +step:663/1670 train_time:62896ms step_avg:94.87ms +step:664/1670 train_time:62990ms step_avg:94.86ms +step:665/1670 train_time:63082ms step_avg:94.86ms +step:666/1670 train_time:63176ms step_avg:94.86ms +step:667/1670 train_time:63268ms step_avg:94.85ms +step:668/1670 train_time:63361ms step_avg:94.85ms +step:669/1670 train_time:63454ms step_avg:94.85ms +step:670/1670 train_time:63548ms step_avg:94.85ms +step:671/1670 train_time:63643ms step_avg:94.85ms +step:672/1670 train_time:63737ms step_avg:94.85ms +step:673/1670 train_time:63832ms step_avg:94.85ms +step:674/1670 train_time:63926ms step_avg:94.85ms +step:675/1670 train_time:64020ms step_avg:94.85ms +step:676/1670 train_time:64114ms step_avg:94.84ms +step:677/1670 train_time:64207ms step_avg:94.84ms +step:678/1670 train_time:64301ms step_avg:94.84ms +step:679/1670 train_time:64393ms step_avg:94.84ms +step:680/1670 train_time:64486ms step_avg:94.83ms +step:681/1670 train_time:64580ms step_avg:94.83ms +step:682/1670 train_time:64674ms step_avg:94.83ms +step:683/1670 train_time:64768ms step_avg:94.83ms +step:684/1670 train_time:64863ms step_avg:94.83ms +step:685/1670 train_time:64957ms step_avg:94.83ms +step:686/1670 train_time:65051ms step_avg:94.83ms +step:687/1670 train_time:65144ms step_avg:94.82ms +step:688/1670 train_time:65237ms step_avg:94.82ms +step:689/1670 train_time:65331ms step_avg:94.82ms +step:690/1670 train_time:65424ms step_avg:94.82ms +step:691/1670 train_time:65517ms step_avg:94.82ms +step:692/1670 train_time:65611ms step_avg:94.81ms +step:693/1670 train_time:65705ms step_avg:94.81ms +step:694/1670 train_time:65799ms step_avg:94.81ms +step:695/1670 train_time:65893ms step_avg:94.81ms +step:696/1670 train_time:65987ms step_avg:94.81ms +step:697/1670 train_time:66081ms step_avg:94.81ms +step:698/1670 train_time:66175ms step_avg:94.81ms +step:699/1670 train_time:66269ms step_avg:94.81ms +step:700/1670 train_time:66362ms 
step_avg:94.80ms +step:701/1670 train_time:66455ms step_avg:94.80ms +step:702/1670 train_time:66549ms step_avg:94.80ms +step:703/1670 train_time:66642ms step_avg:94.80ms +step:704/1670 train_time:66736ms step_avg:94.80ms +step:705/1670 train_time:66831ms step_avg:94.80ms +step:706/1670 train_time:66925ms step_avg:94.79ms +step:707/1670 train_time:67019ms step_avg:94.79ms +step:708/1670 train_time:67113ms step_avg:94.79ms +step:709/1670 train_time:67206ms step_avg:94.79ms +step:710/1670 train_time:67300ms step_avg:94.79ms +step:711/1670 train_time:67393ms step_avg:94.79ms +step:712/1670 train_time:67487ms step_avg:94.78ms +step:713/1670 train_time:67581ms step_avg:94.78ms +step:714/1670 train_time:67675ms step_avg:94.78ms +step:715/1670 train_time:67768ms step_avg:94.78ms +step:716/1670 train_time:67862ms step_avg:94.78ms +step:717/1670 train_time:67956ms step_avg:94.78ms +step:718/1670 train_time:68049ms step_avg:94.78ms +step:719/1670 train_time:68143ms step_avg:94.77ms +step:720/1670 train_time:68237ms step_avg:94.77ms +step:721/1670 train_time:68331ms step_avg:94.77ms +step:722/1670 train_time:68424ms step_avg:94.77ms +step:723/1670 train_time:68518ms step_avg:94.77ms +step:724/1670 train_time:68612ms step_avg:94.77ms +step:725/1670 train_time:68706ms step_avg:94.77ms +step:726/1670 train_time:68799ms step_avg:94.77ms +step:727/1670 train_time:68893ms step_avg:94.76ms +step:728/1670 train_time:68987ms step_avg:94.76ms +step:729/1670 train_time:69081ms step_avg:94.76ms +step:730/1670 train_time:69175ms step_avg:94.76ms +step:731/1670 train_time:69268ms step_avg:94.76ms +step:732/1670 train_time:69362ms step_avg:94.76ms +step:733/1670 train_time:69456ms step_avg:94.76ms +step:734/1670 train_time:69550ms step_avg:94.75ms +step:735/1670 train_time:69643ms step_avg:94.75ms +step:736/1670 train_time:69737ms step_avg:94.75ms +step:737/1670 train_time:69831ms step_avg:94.75ms +step:738/1670 train_time:69926ms step_avg:94.75ms +step:739/1670 train_time:70019ms step_avg:94.75ms +step:740/1670 train_time:70112ms step_avg:94.75ms +step:741/1670 train_time:70206ms step_avg:94.74ms +step:742/1670 train_time:70299ms step_avg:94.74ms +step:743/1670 train_time:70393ms step_avg:94.74ms +step:744/1670 train_time:70487ms step_avg:94.74ms +step:745/1670 train_time:70580ms step_avg:94.74ms +step:746/1670 train_time:70673ms step_avg:94.74ms +step:747/1670 train_time:70767ms step_avg:94.73ms +step:748/1670 train_time:70861ms step_avg:94.73ms +step:749/1670 train_time:70956ms step_avg:94.73ms +step:750/1670 train_time:71049ms step_avg:94.73ms +step:750/1670 val_loss:3.5637 train_time:71141ms step_avg:94.85ms +step:751/1670 train_time:71167ms step_avg:94.76ms +step:752/1670 train_time:71242ms step_avg:94.74ms +step:753/1670 train_time:71342ms step_avg:94.74ms +step:754/1670 train_time:71439ms step_avg:94.75ms +step:755/1670 train_time:71532ms step_avg:94.74ms +step:756/1670 train_time:71625ms step_avg:94.74ms +step:757/1670 train_time:71717ms step_avg:94.74ms +step:758/1670 train_time:71810ms step_avg:94.74ms +step:759/1670 train_time:71902ms step_avg:94.73ms +step:760/1670 train_time:71995ms step_avg:94.73ms +step:761/1670 train_time:72088ms step_avg:94.73ms +step:762/1670 train_time:72182ms step_avg:94.73ms +step:763/1670 train_time:72277ms step_avg:94.73ms +step:764/1670 train_time:72374ms step_avg:94.73ms +step:765/1670 train_time:72468ms step_avg:94.73ms +step:766/1670 train_time:72563ms step_avg:94.73ms +step:767/1670 train_time:72656ms step_avg:94.73ms +step:768/1670 train_time:72749ms step_avg:94.72ms 
+step:769/1670 train_time:72841ms step_avg:94.72ms +step:770/1670 train_time:72934ms step_avg:94.72ms +step:771/1670 train_time:73027ms step_avg:94.72ms +step:772/1670 train_time:73120ms step_avg:94.72ms +step:773/1670 train_time:73214ms step_avg:94.71ms +step:774/1670 train_time:73309ms step_avg:94.71ms +step:775/1670 train_time:73404ms step_avg:94.71ms +step:776/1670 train_time:73499ms step_avg:94.72ms +step:777/1670 train_time:73593ms step_avg:94.71ms +step:778/1670 train_time:73686ms step_avg:94.71ms +step:779/1670 train_time:73779ms step_avg:94.71ms +step:780/1670 train_time:73872ms step_avg:94.71ms +step:781/1670 train_time:73966ms step_avg:94.71ms +step:782/1670 train_time:74059ms step_avg:94.70ms +step:783/1670 train_time:74153ms step_avg:94.70ms +step:784/1670 train_time:74246ms step_avg:94.70ms +step:785/1670 train_time:74342ms step_avg:94.70ms +step:786/1670 train_time:74437ms step_avg:94.70ms +step:787/1670 train_time:74531ms step_avg:94.70ms +step:788/1670 train_time:74624ms step_avg:94.70ms +step:789/1670 train_time:74719ms step_avg:94.70ms +step:790/1670 train_time:74811ms step_avg:94.70ms +step:791/1670 train_time:74905ms step_avg:94.70ms +step:792/1670 train_time:74999ms step_avg:94.70ms +step:793/1670 train_time:75092ms step_avg:94.69ms +step:794/1670 train_time:75186ms step_avg:94.69ms +step:795/1670 train_time:75279ms step_avg:94.69ms +step:796/1670 train_time:75374ms step_avg:94.69ms +step:797/1670 train_time:75469ms step_avg:94.69ms +step:798/1670 train_time:75562ms step_avg:94.69ms +step:799/1670 train_time:75656ms step_avg:94.69ms +step:800/1670 train_time:75750ms step_avg:94.69ms +step:801/1670 train_time:75843ms step_avg:94.69ms +step:802/1670 train_time:75936ms step_avg:94.68ms +step:803/1670 train_time:76030ms step_avg:94.68ms +step:804/1670 train_time:76124ms step_avg:94.68ms +step:805/1670 train_time:76217ms step_avg:94.68ms +step:806/1670 train_time:76311ms step_avg:94.68ms +step:807/1670 train_time:76405ms step_avg:94.68ms +step:808/1670 train_time:76500ms step_avg:94.68ms +step:809/1670 train_time:76594ms step_avg:94.68ms +step:810/1670 train_time:76688ms step_avg:94.68ms +step:811/1670 train_time:76781ms step_avg:94.67ms +step:812/1670 train_time:76875ms step_avg:94.67ms +step:813/1670 train_time:76969ms step_avg:94.67ms +step:814/1670 train_time:77061ms step_avg:94.67ms +step:815/1670 train_time:77155ms step_avg:94.67ms +step:816/1670 train_time:77249ms step_avg:94.67ms +step:817/1670 train_time:77343ms step_avg:94.67ms +step:818/1670 train_time:77437ms step_avg:94.67ms +step:819/1670 train_time:77532ms step_avg:94.67ms +step:820/1670 train_time:77625ms step_avg:94.67ms +step:821/1670 train_time:77719ms step_avg:94.66ms +step:822/1670 train_time:77813ms step_avg:94.66ms +step:823/1670 train_time:77906ms step_avg:94.66ms +step:824/1670 train_time:78000ms step_avg:94.66ms +step:825/1670 train_time:78093ms step_avg:94.66ms +step:826/1670 train_time:78186ms step_avg:94.66ms +step:827/1670 train_time:78280ms step_avg:94.65ms +step:828/1670 train_time:78374ms step_avg:94.65ms +step:829/1670 train_time:78467ms step_avg:94.65ms +step:830/1670 train_time:78561ms step_avg:94.65ms +step:831/1670 train_time:78654ms step_avg:94.65ms +step:832/1670 train_time:78748ms step_avg:94.65ms +step:833/1670 train_time:78842ms step_avg:94.65ms +step:834/1670 train_time:78936ms step_avg:94.65ms +step:835/1670 train_time:79030ms step_avg:94.65ms +step:836/1670 train_time:79124ms step_avg:94.65ms +step:837/1670 train_time:79217ms step_avg:94.64ms +step:838/1670 train_time:79312ms 
step_avg:94.64ms +step:839/1670 train_time:79406ms step_avg:94.64ms +step:840/1670 train_time:79501ms step_avg:94.64ms +step:841/1670 train_time:79594ms step_avg:94.64ms +step:842/1670 train_time:79688ms step_avg:94.64ms +step:843/1670 train_time:79782ms step_avg:94.64ms +step:844/1670 train_time:79876ms step_avg:94.64ms +step:845/1670 train_time:79970ms step_avg:94.64ms +step:846/1670 train_time:80064ms step_avg:94.64ms +step:847/1670 train_time:80157ms step_avg:94.64ms +step:848/1670 train_time:80251ms step_avg:94.64ms +step:849/1670 train_time:80344ms step_avg:94.63ms +step:850/1670 train_time:80439ms step_avg:94.63ms +step:851/1670 train_time:80791ms step_avg:94.94ms +step:852/1670 train_time:80960ms step_avg:95.02ms +step:853/1670 train_time:81052ms step_avg:95.02ms +step:854/1670 train_time:81144ms step_avg:95.02ms +step:855/1670 train_time:81237ms step_avg:95.01ms +step:856/1670 train_time:81330ms step_avg:95.01ms +step:857/1670 train_time:81423ms step_avg:95.01ms +step:858/1670 train_time:81516ms step_avg:95.01ms +step:859/1670 train_time:81608ms step_avg:95.00ms +step:860/1670 train_time:81701ms step_avg:95.00ms +step:861/1670 train_time:81799ms step_avg:95.00ms +step:862/1670 train_time:81896ms step_avg:95.01ms +step:863/1670 train_time:81992ms step_avg:95.01ms +step:864/1670 train_time:82086ms step_avg:95.01ms +step:865/1670 train_time:82179ms step_avg:95.00ms +step:866/1670 train_time:82271ms step_avg:95.00ms +step:867/1670 train_time:82365ms step_avg:95.00ms +step:868/1670 train_time:82458ms step_avg:95.00ms +step:869/1670 train_time:82550ms step_avg:94.99ms +step:870/1670 train_time:82643ms step_avg:94.99ms +step:871/1670 train_time:82736ms step_avg:94.99ms +step:872/1670 train_time:82832ms step_avg:94.99ms +step:873/1670 train_time:82927ms step_avg:94.99ms +step:874/1670 train_time:83021ms step_avg:94.99ms +step:875/1670 train_time:83114ms step_avg:94.99ms +step:875/1670 val_loss:3.5185 train_time:83206ms step_avg:95.09ms +step:876/1670 train_time:83232ms step_avg:95.01ms +step:877/1670 train_time:83306ms step_avg:94.99ms +step:878/1670 train_time:83409ms step_avg:95.00ms +step:879/1670 train_time:83503ms step_avg:95.00ms +step:880/1670 train_time:83596ms step_avg:95.00ms +step:881/1670 train_time:83689ms step_avg:94.99ms +step:882/1670 train_time:83781ms step_avg:94.99ms +step:883/1670 train_time:83873ms step_avg:94.99ms +step:884/1670 train_time:83966ms step_avg:94.98ms +step:885/1670 train_time:84059ms step_avg:94.98ms +step:886/1670 train_time:84152ms step_avg:94.98ms +step:887/1670 train_time:84247ms step_avg:94.98ms +step:888/1670 train_time:84344ms step_avg:94.98ms +step:889/1670 train_time:84439ms step_avg:94.98ms +step:890/1670 train_time:84534ms step_avg:94.98ms +step:891/1670 train_time:84628ms step_avg:94.98ms +step:892/1670 train_time:84721ms step_avg:94.98ms +step:893/1670 train_time:84813ms step_avg:94.98ms +step:894/1670 train_time:84906ms step_avg:94.97ms +step:895/1670 train_time:84999ms step_avg:94.97ms +step:896/1670 train_time:85093ms step_avg:94.97ms +step:897/1670 train_time:85187ms step_avg:94.97ms +step:898/1670 train_time:85281ms step_avg:94.97ms +step:899/1670 train_time:85377ms step_avg:94.97ms +step:900/1670 train_time:85473ms step_avg:94.97ms +step:901/1670 train_time:85568ms step_avg:94.97ms +step:902/1670 train_time:85661ms step_avg:94.97ms +step:903/1670 train_time:85754ms step_avg:94.97ms +step:904/1670 train_time:85847ms step_avg:94.96ms +step:905/1670 train_time:85940ms step_avg:94.96ms +step:906/1670 train_time:86033ms step_avg:94.96ms 
+step:907/1670 train_time:86126ms step_avg:94.96ms +step:908/1670 train_time:86219ms step_avg:94.96ms +step:909/1670 train_time:86314ms step_avg:94.96ms +step:910/1670 train_time:86409ms step_avg:94.96ms +step:911/1670 train_time:86504ms step_avg:94.96ms +step:912/1670 train_time:86598ms step_avg:94.95ms +step:913/1670 train_time:86691ms step_avg:94.95ms +step:914/1670 train_time:86785ms step_avg:94.95ms +step:915/1670 train_time:86878ms step_avg:94.95ms +step:916/1670 train_time:86971ms step_avg:94.95ms +step:917/1670 train_time:87065ms step_avg:94.95ms +step:918/1670 train_time:87157ms step_avg:94.94ms +step:919/1670 train_time:87251ms step_avg:94.94ms +step:920/1670 train_time:87346ms step_avg:94.94ms +step:921/1670 train_time:87440ms step_avg:94.94ms +step:922/1670 train_time:87535ms step_avg:94.94ms +step:923/1670 train_time:87629ms step_avg:94.94ms +step:924/1670 train_time:87722ms step_avg:94.94ms +step:925/1670 train_time:87816ms step_avg:94.94ms +step:926/1670 train_time:87909ms step_avg:94.93ms +step:927/1670 train_time:88002ms step_avg:94.93ms +step:928/1670 train_time:88096ms step_avg:94.93ms +step:929/1670 train_time:88189ms step_avg:94.93ms +step:930/1670 train_time:88283ms step_avg:94.93ms +step:931/1670 train_time:88377ms step_avg:94.93ms +step:932/1670 train_time:88471ms step_avg:94.93ms +step:933/1670 train_time:88566ms step_avg:94.93ms +step:934/1670 train_time:88660ms step_avg:94.93ms +step:935/1670 train_time:88753ms step_avg:94.92ms +step:936/1670 train_time:88847ms step_avg:94.92ms +step:937/1670 train_time:88940ms step_avg:94.92ms +step:938/1670 train_time:89033ms step_avg:94.92ms +step:939/1670 train_time:89128ms step_avg:94.92ms +step:940/1670 train_time:89221ms step_avg:94.92ms +step:941/1670 train_time:89315ms step_avg:94.91ms +step:942/1670 train_time:89409ms step_avg:94.91ms +step:943/1670 train_time:89504ms step_avg:94.91ms +step:944/1670 train_time:89597ms step_avg:94.91ms +step:945/1670 train_time:89691ms step_avg:94.91ms +step:946/1670 train_time:89784ms step_avg:94.91ms +step:947/1670 train_time:89879ms step_avg:94.91ms +step:948/1670 train_time:89972ms step_avg:94.91ms +step:949/1670 train_time:90066ms step_avg:94.91ms +step:950/1670 train_time:90159ms step_avg:94.90ms +step:951/1670 train_time:90253ms step_avg:94.90ms +step:952/1670 train_time:90347ms step_avg:94.90ms +step:953/1670 train_time:90441ms step_avg:94.90ms +step:954/1670 train_time:90535ms step_avg:94.90ms +step:955/1670 train_time:90629ms step_avg:94.90ms +step:956/1670 train_time:90722ms step_avg:94.90ms +step:957/1670 train_time:90815ms step_avg:94.90ms +step:958/1670 train_time:90909ms step_avg:94.89ms +step:959/1670 train_time:91002ms step_avg:94.89ms +step:960/1670 train_time:91096ms step_avg:94.89ms +step:961/1670 train_time:91189ms step_avg:94.89ms +step:962/1670 train_time:91283ms step_avg:94.89ms +step:963/1670 train_time:91377ms step_avg:94.89ms +step:964/1670 train_time:91472ms step_avg:94.89ms +step:965/1670 train_time:91565ms step_avg:94.89ms +step:966/1670 train_time:91659ms step_avg:94.88ms +step:967/1670 train_time:91752ms step_avg:94.88ms +step:968/1670 train_time:91846ms step_avg:94.88ms +step:969/1670 train_time:91940ms step_avg:94.88ms +step:970/1670 train_time:92034ms step_avg:94.88ms +step:971/1670 train_time:92127ms step_avg:94.88ms +step:972/1670 train_time:92220ms step_avg:94.88ms +step:973/1670 train_time:92315ms step_avg:94.88ms +step:974/1670 train_time:92409ms step_avg:94.88ms +step:975/1670 train_time:92503ms step_avg:94.87ms +step:976/1670 train_time:92597ms 
step_avg:94.87ms +step:977/1670 train_time:92690ms step_avg:94.87ms +step:978/1670 train_time:92784ms step_avg:94.87ms +step:979/1670 train_time:92878ms step_avg:94.87ms +step:980/1670 train_time:92971ms step_avg:94.87ms +step:981/1670 train_time:93065ms step_avg:94.87ms +step:982/1670 train_time:93158ms step_avg:94.87ms +step:983/1670 train_time:93252ms step_avg:94.86ms +step:984/1670 train_time:93346ms step_avg:94.86ms +step:985/1670 train_time:93439ms step_avg:94.86ms +step:986/1670 train_time:93533ms step_avg:94.86ms +step:987/1670 train_time:93626ms step_avg:94.86ms +step:988/1670 train_time:93720ms step_avg:94.86ms +step:989/1670 train_time:93814ms step_avg:94.86ms +step:990/1670 train_time:93908ms step_avg:94.86ms +step:991/1670 train_time:94002ms step_avg:94.86ms +step:992/1670 train_time:94095ms step_avg:94.85ms +step:993/1670 train_time:94189ms step_avg:94.85ms +step:994/1670 train_time:94283ms step_avg:94.85ms +step:995/1670 train_time:94377ms step_avg:94.85ms +step:996/1670 train_time:94472ms step_avg:94.85ms +step:997/1670 train_time:94566ms step_avg:94.85ms +step:998/1670 train_time:94659ms step_avg:94.85ms +step:999/1670 train_time:94753ms step_avg:94.85ms +step:1000/1670 train_time:94846ms step_avg:94.85ms +step:1000/1670 val_loss:3.4706 train_time:94939ms step_avg:94.94ms +step:1001/1670 train_time:94965ms step_avg:94.87ms +step:1002/1670 train_time:95040ms step_avg:94.85ms +step:1003/1670 train_time:95141ms step_avg:94.86ms +step:1004/1670 train_time:95236ms step_avg:94.86ms +step:1005/1670 train_time:95329ms step_avg:94.85ms +step:1006/1670 train_time:95422ms step_avg:94.85ms +step:1007/1670 train_time:95515ms step_avg:94.85ms +step:1008/1670 train_time:95607ms step_avg:94.85ms +step:1009/1670 train_time:95700ms step_avg:94.85ms +step:1010/1670 train_time:95793ms step_avg:94.84ms +step:1011/1670 train_time:95885ms step_avg:94.84ms +step:1012/1670 train_time:95980ms step_avg:94.84ms +step:1013/1670 train_time:96077ms step_avg:94.84ms +step:1014/1670 train_time:96172ms step_avg:94.84ms +step:1015/1670 train_time:96267ms step_avg:94.84ms +step:1016/1670 train_time:96360ms step_avg:94.84ms +step:1017/1670 train_time:96454ms step_avg:94.84ms +step:1018/1670 train_time:96548ms step_avg:94.84ms +step:1019/1670 train_time:96641ms step_avg:94.84ms +step:1020/1670 train_time:96734ms step_avg:94.84ms +step:1021/1670 train_time:96827ms step_avg:94.84ms +step:1022/1670 train_time:96919ms step_avg:94.83ms +step:1023/1670 train_time:97014ms step_avg:94.83ms +step:1024/1670 train_time:97109ms step_avg:94.83ms +step:1025/1670 train_time:97204ms step_avg:94.83ms +step:1026/1670 train_time:97299ms step_avg:94.83ms +step:1027/1670 train_time:97392ms step_avg:94.83ms +step:1028/1670 train_time:97486ms step_avg:94.83ms +step:1029/1670 train_time:97579ms step_avg:94.83ms +step:1030/1670 train_time:97672ms step_avg:94.83ms +step:1031/1670 train_time:97764ms step_avg:94.82ms +step:1032/1670 train_time:97857ms step_avg:94.82ms +step:1033/1670 train_time:97951ms step_avg:94.82ms +step:1034/1670 train_time:98046ms step_avg:94.82ms +step:1035/1670 train_time:98140ms step_avg:94.82ms +step:1036/1670 train_time:98235ms step_avg:94.82ms +step:1037/1670 train_time:98329ms step_avg:94.82ms +step:1038/1670 train_time:98422ms step_avg:94.82ms +step:1039/1670 train_time:98517ms step_avg:94.82ms +step:1040/1670 train_time:98610ms step_avg:94.82ms +step:1041/1670 train_time:98703ms step_avg:94.82ms +step:1042/1670 train_time:98796ms step_avg:94.81ms +step:1043/1670 train_time:98890ms step_avg:94.81ms 
+step:1044/1670 train_time:98984ms step_avg:94.81ms +step:1045/1670 train_time:99078ms step_avg:94.81ms +step:1046/1670 train_time:99172ms step_avg:94.81ms +step:1047/1670 train_time:99266ms step_avg:94.81ms +step:1048/1670 train_time:99360ms step_avg:94.81ms +step:1049/1670 train_time:99455ms step_avg:94.81ms +step:1050/1670 train_time:99549ms step_avg:94.81ms +step:1051/1670 train_time:99642ms step_avg:94.81ms +step:1052/1670 train_time:99735ms step_avg:94.81ms +step:1053/1670 train_time:99829ms step_avg:94.80ms +step:1054/1670 train_time:99921ms step_avg:94.80ms +step:1055/1670 train_time:100015ms step_avg:94.80ms +step:1056/1670 train_time:100109ms step_avg:94.80ms +step:1057/1670 train_time:100203ms step_avg:94.80ms +step:1058/1670 train_time:100297ms step_avg:94.80ms +step:1059/1670 train_time:100392ms step_avg:94.80ms +step:1060/1670 train_time:100486ms step_avg:94.80ms +step:1061/1670 train_time:100580ms step_avg:94.80ms +step:1062/1670 train_time:100905ms step_avg:95.01ms +step:1063/1670 train_time:101101ms step_avg:95.11ms +step:1064/1670 train_time:101193ms step_avg:95.11ms +step:1065/1670 train_time:101286ms step_avg:95.10ms +step:1066/1670 train_time:101379ms step_avg:95.10ms +step:1067/1670 train_time:101471ms step_avg:95.10ms +step:1068/1670 train_time:101564ms step_avg:95.10ms +step:1069/1670 train_time:101656ms step_avg:95.09ms +step:1070/1670 train_time:101749ms step_avg:95.09ms +step:1071/1670 train_time:101841ms step_avg:95.09ms +step:1072/1670 train_time:101938ms step_avg:95.09ms +step:1073/1670 train_time:102037ms step_avg:95.09ms +step:1074/1670 train_time:102132ms step_avg:95.09ms +step:1075/1670 train_time:102226ms step_avg:95.09ms +step:1076/1670 train_time:102319ms step_avg:95.09ms +step:1077/1670 train_time:102412ms step_avg:95.09ms +step:1078/1670 train_time:102505ms step_avg:95.09ms +step:1079/1670 train_time:102598ms step_avg:95.09ms +step:1080/1670 train_time:102691ms step_avg:95.08ms +step:1081/1670 train_time:102783ms step_avg:95.08ms +step:1082/1670 train_time:102876ms step_avg:95.08ms +step:1083/1670 train_time:102971ms step_avg:95.08ms +step:1084/1670 train_time:103067ms step_avg:95.08ms +step:1085/1670 train_time:103162ms step_avg:95.08ms +step:1086/1670 train_time:103256ms step_avg:95.08ms +step:1087/1670 train_time:103350ms step_avg:95.08ms +step:1088/1670 train_time:103443ms step_avg:95.08ms +step:1089/1670 train_time:103536ms step_avg:95.07ms +step:1090/1670 train_time:103629ms step_avg:95.07ms +step:1091/1670 train_time:103723ms step_avg:95.07ms +step:1092/1670 train_time:103816ms step_avg:95.07ms +step:1093/1670 train_time:103910ms step_avg:95.07ms +step:1094/1670 train_time:104005ms step_avg:95.07ms +step:1095/1670 train_time:104099ms step_avg:95.07ms +step:1096/1670 train_time:104193ms step_avg:95.07ms +step:1097/1670 train_time:104287ms step_avg:95.07ms +step:1098/1670 train_time:104381ms step_avg:95.06ms +step:1099/1670 train_time:104475ms step_avg:95.06ms +step:1100/1670 train_time:104568ms step_avg:95.06ms +step:1101/1670 train_time:104661ms step_avg:95.06ms +step:1102/1670 train_time:104755ms step_avg:95.06ms +step:1103/1670 train_time:104849ms step_avg:95.06ms +step:1104/1670 train_time:104943ms step_avg:95.06ms +step:1105/1670 train_time:105037ms step_avg:95.06ms +step:1106/1670 train_time:105130ms step_avg:95.05ms +step:1107/1670 train_time:105224ms step_avg:95.05ms +step:1108/1670 train_time:105318ms step_avg:95.05ms +step:1109/1670 train_time:105412ms step_avg:95.05ms +step:1110/1670 train_time:105506ms step_avg:95.05ms 
+step:1111/1670 train_time:105599ms step_avg:95.05ms +step:1112/1670 train_time:105693ms step_avg:95.05ms +step:1113/1670 train_time:105786ms step_avg:95.05ms +step:1114/1670 train_time:105879ms step_avg:95.04ms +step:1115/1670 train_time:106074ms step_avg:95.13ms +step:1116/1670 train_time:106152ms step_avg:95.12ms +step:1117/1670 train_time:106245ms step_avg:95.12ms +step:1118/1670 train_time:106338ms step_avg:95.11ms +step:1119/1670 train_time:106431ms step_avg:95.11ms +step:1120/1670 train_time:106524ms step_avg:95.11ms +step:1121/1670 train_time:106618ms step_avg:95.11ms +step:1122/1670 train_time:106711ms step_avg:95.11ms +step:1123/1670 train_time:106804ms step_avg:95.11ms +step:1124/1670 train_time:106897ms step_avg:95.10ms +step:1125/1670 train_time:106995ms step_avg:95.11ms +step:1125/1670 val_loss:3.4180 train_time:107091ms step_avg:95.19ms +step:1126/1670 train_time:107115ms step_avg:95.13ms +step:1127/1670 train_time:107195ms step_avg:95.12ms +step:1128/1670 train_time:107296ms step_avg:95.12ms +step:1129/1670 train_time:107391ms step_avg:95.12ms +step:1130/1670 train_time:107484ms step_avg:95.12ms +step:1131/1670 train_time:107578ms step_avg:95.12ms +step:1132/1670 train_time:107671ms step_avg:95.12ms +step:1133/1670 train_time:107764ms step_avg:95.11ms +step:1134/1670 train_time:107858ms step_avg:95.11ms +step:1135/1670 train_time:107951ms step_avg:95.11ms +step:1136/1670 train_time:108045ms step_avg:95.11ms +step:1137/1670 train_time:108140ms step_avg:95.11ms +step:1138/1670 train_time:108238ms step_avg:95.11ms +step:1139/1670 train_time:108335ms step_avg:95.11ms +step:1140/1670 train_time:108430ms step_avg:95.11ms +step:1141/1670 train_time:108524ms step_avg:95.11ms +step:1142/1670 train_time:108618ms step_avg:95.11ms +step:1143/1670 train_time:108712ms step_avg:95.11ms +step:1144/1670 train_time:108805ms step_avg:95.11ms +step:1145/1670 train_time:108898ms step_avg:95.11ms +step:1146/1670 train_time:108993ms step_avg:95.11ms +step:1147/1670 train_time:109087ms step_avg:95.11ms +step:1148/1670 train_time:109183ms step_avg:95.11ms +step:1149/1670 train_time:109278ms step_avg:95.11ms +step:1150/1670 train_time:109374ms step_avg:95.11ms +step:1151/1670 train_time:109469ms step_avg:95.11ms +step:1152/1670 train_time:109563ms step_avg:95.11ms +step:1153/1670 train_time:109657ms step_avg:95.11ms +step:1154/1670 train_time:109750ms step_avg:95.10ms +step:1155/1670 train_time:109844ms step_avg:95.10ms +step:1156/1670 train_time:109938ms step_avg:95.10ms +step:1157/1670 train_time:110032ms step_avg:95.10ms +step:1158/1670 train_time:110128ms step_avg:95.10ms +step:1159/1670 train_time:110223ms step_avg:95.10ms +step:1160/1670 train_time:110318ms step_avg:95.10ms +step:1161/1670 train_time:110414ms step_avg:95.10ms +step:1162/1670 train_time:110509ms step_avg:95.10ms +step:1163/1670 train_time:110603ms step_avg:95.10ms +step:1164/1670 train_time:110698ms step_avg:95.10ms +step:1165/1670 train_time:110791ms step_avg:95.10ms +step:1166/1670 train_time:110886ms step_avg:95.10ms +step:1167/1670 train_time:110979ms step_avg:95.10ms +step:1168/1670 train_time:111073ms step_avg:95.10ms +step:1169/1670 train_time:111168ms step_avg:95.10ms +step:1170/1670 train_time:111263ms step_avg:95.10ms +step:1171/1670 train_time:111358ms step_avg:95.10ms +step:1172/1670 train_time:111453ms step_avg:95.10ms +step:1173/1670 train_time:111548ms step_avg:95.10ms +step:1174/1670 train_time:111642ms step_avg:95.10ms +step:1175/1670 train_time:111736ms step_avg:95.09ms +step:1176/1670 train_time:111830ms 
step_avg:95.09ms +step:1177/1670 train_time:111924ms step_avg:95.09ms +step:1178/1670 train_time:112018ms step_avg:95.09ms +step:1179/1670 train_time:112113ms step_avg:95.09ms +step:1180/1670 train_time:112208ms step_avg:95.09ms +step:1181/1670 train_time:112303ms step_avg:95.09ms +step:1182/1670 train_time:112398ms step_avg:95.09ms +step:1183/1670 train_time:112493ms step_avg:95.09ms +step:1184/1670 train_time:112587ms step_avg:95.09ms +step:1185/1670 train_time:112682ms step_avg:95.09ms +step:1186/1670 train_time:112777ms step_avg:95.09ms +step:1187/1670 train_time:112870ms step_avg:95.09ms +step:1188/1670 train_time:112964ms step_avg:95.09ms +step:1189/1670 train_time:113059ms step_avg:95.09ms +step:1190/1670 train_time:113153ms step_avg:95.09ms +step:1191/1670 train_time:113248ms step_avg:95.09ms +step:1192/1670 train_time:113344ms step_avg:95.09ms +step:1193/1670 train_time:113439ms step_avg:95.09ms +step:1194/1670 train_time:113534ms step_avg:95.09ms +step:1195/1670 train_time:113629ms step_avg:95.09ms +step:1196/1670 train_time:113724ms step_avg:95.09ms +step:1197/1670 train_time:113819ms step_avg:95.09ms +step:1198/1670 train_time:113913ms step_avg:95.09ms +step:1199/1670 train_time:114007ms step_avg:95.09ms +step:1200/1670 train_time:114102ms step_avg:95.09ms +step:1201/1670 train_time:114196ms step_avg:95.08ms +step:1202/1670 train_time:114292ms step_avg:95.08ms +step:1203/1670 train_time:114387ms step_avg:95.08ms +step:1204/1670 train_time:114481ms step_avg:95.08ms +step:1205/1670 train_time:114576ms step_avg:95.08ms +step:1206/1670 train_time:114670ms step_avg:95.08ms +step:1207/1670 train_time:114765ms step_avg:95.08ms +step:1208/1670 train_time:114859ms step_avg:95.08ms +step:1209/1670 train_time:114954ms step_avg:95.08ms +step:1210/1670 train_time:115048ms step_avg:95.08ms +step:1211/1670 train_time:115142ms step_avg:95.08ms +step:1212/1670 train_time:115237ms step_avg:95.08ms +step:1213/1670 train_time:115332ms step_avg:95.08ms +step:1214/1670 train_time:115427ms step_avg:95.08ms +step:1215/1670 train_time:115520ms step_avg:95.08ms +step:1216/1670 train_time:115614ms step_avg:95.08ms +step:1217/1670 train_time:115708ms step_avg:95.08ms +step:1218/1670 train_time:115804ms step_avg:95.08ms +step:1219/1670 train_time:115899ms step_avg:95.08ms +step:1220/1670 train_time:115993ms step_avg:95.08ms +step:1221/1670 train_time:116088ms step_avg:95.08ms +step:1222/1670 train_time:116183ms step_avg:95.08ms +step:1223/1670 train_time:116277ms step_avg:95.07ms +step:1224/1670 train_time:116372ms step_avg:95.08ms +step:1225/1670 train_time:116467ms step_avg:95.08ms +step:1226/1670 train_time:116561ms step_avg:95.07ms +step:1227/1670 train_time:116655ms step_avg:95.07ms +step:1228/1670 train_time:116750ms step_avg:95.07ms +step:1229/1670 train_time:116845ms step_avg:95.07ms +step:1230/1670 train_time:116939ms step_avg:95.07ms +step:1231/1670 train_time:117034ms step_avg:95.07ms +step:1232/1670 train_time:117129ms step_avg:95.07ms +step:1233/1670 train_time:117224ms step_avg:95.07ms +step:1234/1670 train_time:117319ms step_avg:95.07ms +step:1235/1670 train_time:117414ms step_avg:95.07ms +step:1236/1670 train_time:117508ms step_avg:95.07ms +step:1237/1670 train_time:117602ms step_avg:95.07ms +step:1238/1670 train_time:117697ms step_avg:95.07ms +step:1239/1670 train_time:117792ms step_avg:95.07ms +step:1240/1670 train_time:117887ms step_avg:95.07ms +step:1241/1670 train_time:117982ms step_avg:95.07ms +step:1242/1670 train_time:118077ms step_avg:95.07ms +step:1243/1670 train_time:118172ms 
step_avg:95.07ms +step:1244/1670 train_time:118266ms step_avg:95.07ms +step:1245/1670 train_time:118362ms step_avg:95.07ms +step:1246/1670 train_time:118456ms step_avg:95.07ms +step:1247/1670 train_time:118550ms step_avg:95.07ms +step:1248/1670 train_time:118646ms step_avg:95.07ms +step:1249/1670 train_time:118740ms step_avg:95.07ms +step:1250/1670 train_time:118834ms step_avg:95.07ms +step:1250/1670 val_loss:3.3783 train_time:118926ms step_avg:95.14ms +step:1251/1670 train_time:118951ms step_avg:95.08ms +step:1252/1670 train_time:119028ms step_avg:95.07ms +step:1253/1670 train_time:119129ms step_avg:95.07ms +step:1254/1670 train_time:119225ms step_avg:95.08ms +step:1255/1670 train_time:119319ms step_avg:95.07ms +step:1256/1670 train_time:119413ms step_avg:95.07ms +step:1257/1670 train_time:119505ms step_avg:95.07ms +step:1258/1670 train_time:119599ms step_avg:95.07ms +step:1259/1670 train_time:119693ms step_avg:95.07ms +step:1260/1670 train_time:119786ms step_avg:95.07ms +step:1261/1670 train_time:119879ms step_avg:95.07ms +step:1262/1670 train_time:119974ms step_avg:95.07ms +step:1263/1670 train_time:120071ms step_avg:95.07ms +step:1264/1670 train_time:120167ms step_avg:95.07ms +step:1265/1670 train_time:120262ms step_avg:95.07ms +step:1266/1670 train_time:120356ms step_avg:95.07ms +step:1267/1670 train_time:120451ms step_avg:95.07ms +step:1268/1670 train_time:120545ms step_avg:95.07ms +step:1269/1670 train_time:120638ms step_avg:95.07ms +step:1270/1670 train_time:120732ms step_avg:95.06ms +step:1271/1670 train_time:120825ms step_avg:95.06ms +step:1272/1670 train_time:120919ms step_avg:95.06ms +step:1273/1670 train_time:121013ms step_avg:95.06ms +step:1274/1670 train_time:121465ms step_avg:95.34ms +step:1275/1670 train_time:121536ms step_avg:95.32ms +step:1276/1670 train_time:121629ms step_avg:95.32ms +step:1277/1670 train_time:121722ms step_avg:95.32ms +step:1278/1670 train_time:121815ms step_avg:95.32ms +step:1279/1670 train_time:121909ms step_avg:95.32ms +step:1280/1670 train_time:122002ms step_avg:95.31ms +step:1281/1670 train_time:122096ms step_avg:95.31ms +step:1282/1670 train_time:122190ms step_avg:95.31ms +step:1283/1670 train_time:122283ms step_avg:95.31ms +step:1284/1670 train_time:122380ms step_avg:95.31ms +step:1285/1670 train_time:122478ms step_avg:95.31ms +step:1286/1670 train_time:122576ms step_avg:95.32ms +step:1287/1670 train_time:122670ms step_avg:95.31ms +step:1288/1670 train_time:122765ms step_avg:95.31ms +step:1289/1670 train_time:122858ms step_avg:95.31ms +step:1290/1670 train_time:122952ms step_avg:95.31ms +step:1291/1670 train_time:123047ms step_avg:95.31ms +step:1292/1670 train_time:123140ms step_avg:95.31ms +step:1293/1670 train_time:123233ms step_avg:95.31ms +step:1294/1670 train_time:123329ms step_avg:95.31ms +step:1295/1670 train_time:123425ms step_avg:95.31ms +step:1296/1670 train_time:123522ms step_avg:95.31ms +step:1297/1670 train_time:123617ms step_avg:95.31ms +step:1298/1670 train_time:123712ms step_avg:95.31ms +step:1299/1670 train_time:123807ms step_avg:95.31ms +step:1300/1670 train_time:123901ms step_avg:95.31ms +step:1301/1670 train_time:123995ms step_avg:95.31ms +step:1302/1670 train_time:124089ms step_avg:95.31ms +step:1303/1670 train_time:124182ms step_avg:95.31ms +step:1304/1670 train_time:124277ms step_avg:95.30ms +step:1305/1670 train_time:124373ms step_avg:95.30ms +step:1306/1670 train_time:124468ms step_avg:95.31ms +step:1307/1670 train_time:124564ms step_avg:95.30ms +step:1308/1670 train_time:124659ms step_avg:95.30ms +step:1309/1670 
train_time:124754ms step_avg:95.30ms +step:1310/1670 train_time:124849ms step_avg:95.30ms +step:1311/1670 train_time:124942ms step_avg:95.30ms +step:1312/1670 train_time:125037ms step_avg:95.30ms +step:1313/1670 train_time:125131ms step_avg:95.30ms +step:1314/1670 train_time:125225ms step_avg:95.30ms +step:1315/1670 train_time:125320ms step_avg:95.30ms +step:1316/1670 train_time:125415ms step_avg:95.30ms +step:1317/1670 train_time:125510ms step_avg:95.30ms +step:1318/1670 train_time:125605ms step_avg:95.30ms +step:1319/1670 train_time:125700ms step_avg:95.30ms +step:1320/1670 train_time:125795ms step_avg:95.30ms +step:1321/1670 train_time:125890ms step_avg:95.30ms +step:1322/1670 train_time:125984ms step_avg:95.30ms +step:1323/1670 train_time:126078ms step_avg:95.30ms +step:1324/1670 train_time:126171ms step_avg:95.30ms +step:1325/1670 train_time:126266ms step_avg:95.29ms +step:1326/1670 train_time:126361ms step_avg:95.29ms +step:1327/1670 train_time:126455ms step_avg:95.29ms +step:1328/1670 train_time:126551ms step_avg:95.29ms +step:1329/1670 train_time:126645ms step_avg:95.29ms +step:1330/1670 train_time:126741ms step_avg:95.29ms +step:1331/1670 train_time:126835ms step_avg:95.29ms +step:1332/1670 train_time:126930ms step_avg:95.29ms +step:1333/1670 train_time:127024ms step_avg:95.29ms +step:1334/1670 train_time:127118ms step_avg:95.29ms +step:1335/1670 train_time:127212ms step_avg:95.29ms +step:1336/1670 train_time:127307ms step_avg:95.29ms +step:1337/1670 train_time:127403ms step_avg:95.29ms +step:1338/1670 train_time:127497ms step_avg:95.29ms +step:1339/1670 train_time:127593ms step_avg:95.29ms +step:1340/1670 train_time:127688ms step_avg:95.29ms +step:1341/1670 train_time:127782ms step_avg:95.29ms +step:1342/1670 train_time:127877ms step_avg:95.29ms +step:1343/1670 train_time:127972ms step_avg:95.29ms +step:1344/1670 train_time:128066ms step_avg:95.29ms +step:1345/1670 train_time:128161ms step_avg:95.29ms +step:1346/1670 train_time:128255ms step_avg:95.29ms +step:1347/1670 train_time:128349ms step_avg:95.29ms +step:1348/1670 train_time:128444ms step_avg:95.28ms +step:1349/1670 train_time:128538ms step_avg:95.28ms +step:1350/1670 train_time:128633ms step_avg:95.28ms +step:1351/1670 train_time:128728ms step_avg:95.28ms +step:1352/1670 train_time:128822ms step_avg:95.28ms +step:1353/1670 train_time:128917ms step_avg:95.28ms +step:1354/1670 train_time:129012ms step_avg:95.28ms +step:1355/1670 train_time:129105ms step_avg:95.28ms +step:1356/1670 train_time:129200ms step_avg:95.28ms +step:1357/1670 train_time:129295ms step_avg:95.28ms +step:1358/1670 train_time:129388ms step_avg:95.28ms +step:1359/1670 train_time:129482ms step_avg:95.28ms +step:1360/1670 train_time:129577ms step_avg:95.28ms +step:1361/1670 train_time:129672ms step_avg:95.28ms +step:1362/1670 train_time:129766ms step_avg:95.28ms +step:1363/1670 train_time:129861ms step_avg:95.28ms +step:1364/1670 train_time:129955ms step_avg:95.28ms +step:1365/1670 train_time:130049ms step_avg:95.27ms +step:1366/1670 train_time:130144ms step_avg:95.27ms +step:1367/1670 train_time:130238ms step_avg:95.27ms +step:1368/1670 train_time:130333ms step_avg:95.27ms +step:1369/1670 train_time:130427ms step_avg:95.27ms +step:1370/1670 train_time:130523ms step_avg:95.27ms +step:1371/1670 train_time:130618ms step_avg:95.27ms +step:1372/1670 train_time:130713ms step_avg:95.27ms +step:1373/1670 train_time:130808ms step_avg:95.27ms +step:1374/1670 train_time:130902ms step_avg:95.27ms +step:1375/1670 train_time:130997ms step_avg:95.27ms +step:1375/1670 
val_loss:3.3447 train_time:131091ms step_avg:95.34ms +step:1376/1670 train_time:131116ms step_avg:95.29ms +step:1377/1670 train_time:131193ms step_avg:95.27ms +step:1378/1670 train_time:131295ms step_avg:95.28ms +step:1379/1670 train_time:131392ms step_avg:95.28ms +step:1380/1670 train_time:131486ms step_avg:95.28ms +step:1381/1670 train_time:131580ms step_avg:95.28ms +step:1382/1670 train_time:131672ms step_avg:95.28ms +step:1383/1670 train_time:131767ms step_avg:95.28ms +step:1384/1670 train_time:131862ms step_avg:95.28ms +step:1385/1670 train_time:131955ms step_avg:95.27ms +step:1386/1670 train_time:132050ms step_avg:95.27ms +step:1387/1670 train_time:132149ms step_avg:95.28ms +step:1388/1670 train_time:132249ms step_avg:95.28ms +step:1389/1670 train_time:132345ms step_avg:95.28ms +step:1390/1670 train_time:132439ms step_avg:95.28ms +step:1391/1670 train_time:132532ms step_avg:95.28ms +step:1392/1670 train_time:132626ms step_avg:95.28ms +step:1393/1670 train_time:132719ms step_avg:95.28ms +step:1394/1670 train_time:132813ms step_avg:95.27ms +step:1395/1670 train_time:132906ms step_avg:95.27ms +step:1396/1670 train_time:133001ms step_avg:95.27ms +step:1397/1670 train_time:133096ms step_avg:95.27ms +step:1398/1670 train_time:133192ms step_avg:95.27ms +step:1399/1670 train_time:133288ms step_avg:95.27ms +step:1400/1670 train_time:133382ms step_avg:95.27ms +step:1401/1670 train_time:133477ms step_avg:95.27ms +step:1402/1670 train_time:133571ms step_avg:95.27ms +step:1403/1670 train_time:133666ms step_avg:95.27ms +step:1404/1670 train_time:133760ms step_avg:95.27ms +step:1405/1670 train_time:133854ms step_avg:95.27ms +step:1406/1670 train_time:133948ms step_avg:95.27ms +step:1407/1670 train_time:134043ms step_avg:95.27ms +step:1408/1670 train_time:134138ms step_avg:95.27ms +step:1409/1670 train_time:134235ms step_avg:95.27ms +step:1410/1670 train_time:134330ms step_avg:95.27ms +step:1411/1670 train_time:134425ms step_avg:95.27ms +step:1412/1670 train_time:134519ms step_avg:95.27ms +step:1413/1670 train_time:134614ms step_avg:95.27ms +step:1414/1670 train_time:134708ms step_avg:95.27ms +step:1415/1670 train_time:134802ms step_avg:95.27ms +step:1416/1670 train_time:134897ms step_avg:95.27ms +step:1417/1670 train_time:134991ms step_avg:95.27ms +step:1418/1670 train_time:135086ms step_avg:95.26ms +step:1419/1670 train_time:135181ms step_avg:95.26ms +step:1420/1670 train_time:135276ms step_avg:95.26ms +step:1421/1670 train_time:135372ms step_avg:95.27ms +step:1422/1670 train_time:135466ms step_avg:95.26ms +step:1423/1670 train_time:135561ms step_avg:95.26ms +step:1424/1670 train_time:135656ms step_avg:95.26ms +step:1425/1670 train_time:135750ms step_avg:95.26ms +step:1426/1670 train_time:135844ms step_avg:95.26ms +step:1427/1670 train_time:135939ms step_avg:95.26ms +step:1428/1670 train_time:136033ms step_avg:95.26ms +step:1429/1670 train_time:136128ms step_avg:95.26ms +step:1430/1670 train_time:136223ms step_avg:95.26ms +step:1431/1670 train_time:136318ms step_avg:95.26ms +step:1432/1670 train_time:136413ms step_avg:95.26ms +step:1433/1670 train_time:136507ms step_avg:95.26ms +step:1434/1670 train_time:136602ms step_avg:95.26ms +step:1435/1670 train_time:136697ms step_avg:95.26ms +step:1436/1670 train_time:136792ms step_avg:95.26ms +step:1437/1670 train_time:136886ms step_avg:95.26ms +step:1438/1670 train_time:136980ms step_avg:95.26ms +step:1439/1670 train_time:137075ms step_avg:95.26ms +step:1440/1670 train_time:137169ms step_avg:95.26ms +step:1441/1670 train_time:137264ms step_avg:95.26ms 
+step:1442/1670 train_time:137358ms step_avg:95.26ms +step:1443/1670 train_time:137452ms step_avg:95.25ms +step:1444/1670 train_time:137547ms step_avg:95.25ms +step:1445/1670 train_time:137642ms step_avg:95.25ms +step:1446/1670 train_time:137736ms step_avg:95.25ms +step:1447/1670 train_time:137831ms step_avg:95.25ms +step:1448/1670 train_time:137927ms step_avg:95.25ms +step:1449/1670 train_time:138021ms step_avg:95.25ms +step:1450/1670 train_time:138115ms step_avg:95.25ms +step:1451/1670 train_time:138209ms step_avg:95.25ms +step:1452/1670 train_time:138304ms step_avg:95.25ms +step:1453/1670 train_time:138400ms step_avg:95.25ms +step:1454/1670 train_time:138495ms step_avg:95.25ms +step:1455/1670 train_time:138589ms step_avg:95.25ms +step:1456/1670 train_time:138684ms step_avg:95.25ms +step:1457/1670 train_time:138779ms step_avg:95.25ms +step:1458/1670 train_time:138873ms step_avg:95.25ms +step:1459/1670 train_time:138968ms step_avg:95.25ms +step:1460/1670 train_time:139063ms step_avg:95.25ms +step:1461/1670 train_time:139158ms step_avg:95.25ms +step:1462/1670 train_time:139252ms step_avg:95.25ms +step:1463/1670 train_time:139347ms step_avg:95.25ms +step:1464/1670 train_time:139442ms step_avg:95.25ms +step:1465/1670 train_time:139537ms step_avg:95.25ms +step:1466/1670 train_time:139632ms step_avg:95.25ms +step:1467/1670 train_time:139727ms step_avg:95.25ms +step:1468/1670 train_time:139822ms step_avg:95.25ms +step:1469/1670 train_time:139916ms step_avg:95.25ms +step:1470/1670 train_time:140011ms step_avg:95.25ms +step:1471/1670 train_time:140105ms step_avg:95.24ms +step:1472/1670 train_time:140200ms step_avg:95.24ms +step:1473/1670 train_time:140295ms step_avg:95.24ms +step:1474/1670 train_time:140390ms step_avg:95.24ms +step:1475/1670 train_time:140485ms step_avg:95.24ms +step:1476/1670 train_time:140579ms step_avg:95.24ms +step:1477/1670 train_time:140673ms step_avg:95.24ms +step:1478/1670 train_time:140768ms step_avg:95.24ms +step:1479/1670 train_time:140863ms step_avg:95.24ms +step:1480/1670 train_time:140958ms step_avg:95.24ms +step:1481/1670 train_time:141053ms step_avg:95.24ms +step:1482/1670 train_time:141148ms step_avg:95.24ms +step:1483/1670 train_time:141242ms step_avg:95.24ms +step:1484/1670 train_time:141336ms step_avg:95.24ms +step:1485/1670 train_time:141661ms step_avg:95.39ms +step:1486/1670 train_time:141856ms step_avg:95.46ms +step:1487/1670 train_time:141948ms step_avg:95.46ms +step:1488/1670 train_time:142042ms step_avg:95.46ms +step:1489/1670 train_time:142135ms step_avg:95.46ms +step:1490/1670 train_time:142229ms step_avg:95.46ms +step:1491/1670 train_time:142323ms step_avg:95.45ms +step:1492/1670 train_time:142416ms step_avg:95.45ms +step:1493/1670 train_time:142510ms step_avg:95.45ms +step:1494/1670 train_time:142604ms step_avg:95.45ms +step:1495/1670 train_time:142701ms step_avg:95.45ms +step:1496/1670 train_time:142798ms step_avg:95.45ms +step:1497/1670 train_time:142897ms step_avg:95.46ms +step:1498/1670 train_time:142992ms step_avg:95.45ms +step:1499/1670 train_time:143085ms step_avg:95.45ms +step:1500/1670 train_time:143179ms step_avg:95.45ms +step:1500/1670 val_loss:3.3142 train_time:143271ms step_avg:95.51ms +step:1501/1670 train_time:143295ms step_avg:95.47ms +step:1502/1670 train_time:143372ms step_avg:95.45ms +step:1503/1670 train_time:143470ms step_avg:95.46ms +step:1504/1670 train_time:143565ms step_avg:95.46ms +step:1505/1670 train_time:143660ms step_avg:95.46ms +step:1506/1670 train_time:143753ms step_avg:95.45ms +step:1507/1670 train_time:143846ms 
step_avg:95.45ms +step:1508/1670 train_time:143940ms step_avg:95.45ms +step:1509/1670 train_time:144034ms step_avg:95.45ms +step:1510/1670 train_time:144128ms step_avg:95.45ms +step:1511/1670 train_time:144223ms step_avg:95.45ms +step:1512/1670 train_time:144321ms step_avg:95.45ms +step:1513/1670 train_time:144419ms step_avg:95.45ms +step:1514/1670 train_time:144516ms step_avg:95.45ms +step:1515/1670 train_time:144611ms step_avg:95.45ms +step:1516/1670 train_time:144705ms step_avg:95.45ms +step:1517/1670 train_time:144798ms step_avg:95.45ms +step:1518/1670 train_time:144892ms step_avg:95.45ms +step:1519/1670 train_time:144985ms step_avg:95.45ms +step:1520/1670 train_time:145078ms step_avg:95.45ms +step:1521/1670 train_time:145173ms step_avg:95.45ms +step:1522/1670 train_time:145268ms step_avg:95.45ms +step:1523/1670 train_time:145364ms step_avg:95.45ms +step:1524/1670 train_time:145461ms step_avg:95.45ms +step:1525/1670 train_time:145556ms step_avg:95.45ms +step:1526/1670 train_time:145651ms step_avg:95.45ms +step:1527/1670 train_time:145745ms step_avg:95.45ms +step:1528/1670 train_time:145839ms step_avg:95.44ms +step:1529/1670 train_time:145933ms step_avg:95.44ms +step:1530/1670 train_time:146026ms step_avg:95.44ms +step:1531/1670 train_time:146120ms step_avg:95.44ms +step:1532/1670 train_time:146215ms step_avg:95.44ms +step:1533/1670 train_time:146312ms step_avg:95.44ms +step:1534/1670 train_time:146407ms step_avg:95.44ms +step:1535/1670 train_time:146504ms step_avg:95.44ms +step:1536/1670 train_time:146599ms step_avg:95.44ms +step:1537/1670 train_time:146694ms step_avg:95.44ms +step:1538/1670 train_time:146789ms step_avg:95.44ms +step:1539/1670 train_time:146883ms step_avg:95.44ms +step:1540/1670 train_time:146976ms step_avg:95.44ms +step:1541/1670 train_time:147070ms step_avg:95.44ms +step:1542/1670 train_time:147165ms step_avg:95.44ms +step:1543/1670 train_time:147260ms step_avg:95.44ms +step:1544/1670 train_time:147355ms step_avg:95.44ms +step:1545/1670 train_time:147451ms step_avg:95.44ms +step:1546/1670 train_time:147547ms step_avg:95.44ms +step:1547/1670 train_time:147642ms step_avg:95.44ms +step:1548/1670 train_time:147737ms step_avg:95.44ms +step:1549/1670 train_time:147832ms step_avg:95.44ms +step:1550/1670 train_time:147925ms step_avg:95.44ms +step:1551/1670 train_time:148019ms step_avg:95.43ms +step:1552/1670 train_time:148113ms step_avg:95.43ms +step:1553/1670 train_time:148208ms step_avg:95.43ms +step:1554/1670 train_time:148303ms step_avg:95.43ms +step:1555/1670 train_time:148398ms step_avg:95.43ms +step:1556/1670 train_time:148493ms step_avg:95.43ms +step:1557/1670 train_time:148588ms step_avg:95.43ms +step:1558/1670 train_time:148683ms step_avg:95.43ms +step:1559/1670 train_time:148777ms step_avg:95.43ms +step:1560/1670 train_time:148872ms step_avg:95.43ms +step:1561/1670 train_time:148966ms step_avg:95.43ms +step:1562/1670 train_time:149061ms step_avg:95.43ms +step:1563/1670 train_time:149155ms step_avg:95.43ms +step:1564/1670 train_time:149251ms step_avg:95.43ms +step:1565/1670 train_time:149346ms step_avg:95.43ms +step:1566/1670 train_time:149441ms step_avg:95.43ms +step:1567/1670 train_time:149536ms step_avg:95.43ms +step:1568/1670 train_time:149630ms step_avg:95.43ms +step:1569/1670 train_time:149725ms step_avg:95.43ms +step:1570/1670 train_time:149820ms step_avg:95.43ms +step:1571/1670 train_time:149914ms step_avg:95.43ms +step:1572/1670 train_time:150008ms step_avg:95.42ms +step:1573/1670 train_time:150103ms step_avg:95.42ms +step:1574/1670 train_time:150197ms 
step_avg:95.42ms +step:1575/1670 train_time:150291ms step_avg:95.42ms +step:1576/1670 train_time:150386ms step_avg:95.42ms +step:1577/1670 train_time:150481ms step_avg:95.42ms +step:1578/1670 train_time:150576ms step_avg:95.42ms +step:1579/1670 train_time:150671ms step_avg:95.42ms +step:1580/1670 train_time:150766ms step_avg:95.42ms +step:1581/1670 train_time:150861ms step_avg:95.42ms +step:1582/1670 train_time:150955ms step_avg:95.42ms +step:1583/1670 train_time:151049ms step_avg:95.42ms +step:1584/1670 train_time:151143ms step_avg:95.42ms +step:1585/1670 train_time:151238ms step_avg:95.42ms +step:1586/1670 train_time:151332ms step_avg:95.42ms +step:1587/1670 train_time:151427ms step_avg:95.42ms +step:1588/1670 train_time:151522ms step_avg:95.42ms +step:1589/1670 train_time:151617ms step_avg:95.42ms +step:1590/1670 train_time:151712ms step_avg:95.42ms +step:1591/1670 train_time:151806ms step_avg:95.42ms +step:1592/1670 train_time:151900ms step_avg:95.41ms +step:1593/1670 train_time:151994ms step_avg:95.41ms +step:1594/1670 train_time:152089ms step_avg:95.41ms +step:1595/1670 train_time:152183ms step_avg:95.41ms +step:1596/1670 train_time:152279ms step_avg:95.41ms +step:1597/1670 train_time:152373ms step_avg:95.41ms +step:1598/1670 train_time:152468ms step_avg:95.41ms +step:1599/1670 train_time:152565ms step_avg:95.41ms +step:1600/1670 train_time:152659ms step_avg:95.41ms +step:1601/1670 train_time:152753ms step_avg:95.41ms +step:1602/1670 train_time:152847ms step_avg:95.41ms +step:1603/1670 train_time:152942ms step_avg:95.41ms +step:1604/1670 train_time:153036ms step_avg:95.41ms +step:1605/1670 train_time:153131ms step_avg:95.41ms +step:1606/1670 train_time:153225ms step_avg:95.41ms +step:1607/1670 train_time:153320ms step_avg:95.41ms +step:1608/1670 train_time:153415ms step_avg:95.41ms +step:1609/1670 train_time:153511ms step_avg:95.41ms +step:1610/1670 train_time:153605ms step_avg:95.41ms +step:1611/1670 train_time:153701ms step_avg:95.41ms +step:1612/1670 train_time:153796ms step_avg:95.41ms +step:1613/1670 train_time:153891ms step_avg:95.41ms +step:1614/1670 train_time:153986ms step_avg:95.41ms +step:1615/1670 train_time:154081ms step_avg:95.41ms +step:1616/1670 train_time:154175ms step_avg:95.41ms +step:1617/1670 train_time:154269ms step_avg:95.40ms +step:1618/1670 train_time:154363ms step_avg:95.40ms +step:1619/1670 train_time:154458ms step_avg:95.40ms +step:1620/1670 train_time:154554ms step_avg:95.40ms +step:1621/1670 train_time:154650ms step_avg:95.40ms +step:1622/1670 train_time:154745ms step_avg:95.40ms +step:1623/1670 train_time:154839ms step_avg:95.40ms +step:1624/1670 train_time:154933ms step_avg:95.40ms +step:1625/1670 train_time:155027ms step_avg:95.40ms +step:1625/1670 val_loss:3.2894 train_time:155120ms step_avg:95.46ms +step:1626/1670 train_time:155144ms step_avg:95.41ms +step:1627/1670 train_time:155222ms step_avg:95.40ms +step:1628/1670 train_time:155323ms step_avg:95.41ms +step:1629/1670 train_time:155421ms step_avg:95.41ms +step:1630/1670 train_time:155517ms step_avg:95.41ms +step:1631/1670 train_time:155610ms step_avg:95.41ms +step:1632/1670 train_time:155703ms step_avg:95.41ms +step:1633/1670 train_time:155797ms step_avg:95.41ms +step:1634/1670 train_time:155890ms step_avg:95.40ms +step:1635/1670 train_time:155984ms step_avg:95.40ms +step:1636/1670 train_time:156078ms step_avg:95.40ms +step:1637/1670 train_time:156173ms step_avg:95.40ms +step:1638/1670 train_time:156269ms step_avg:95.40ms +step:1639/1670 train_time:156366ms step_avg:95.40ms +step:1640/1670 
train_time:156462ms step_avg:95.40ms +step:1641/1670 train_time:156558ms step_avg:95.40ms +step:1642/1670 train_time:156653ms step_avg:95.40ms +step:1643/1670 train_time:156746ms step_avg:95.40ms +step:1644/1670 train_time:156840ms step_avg:95.40ms +step:1645/1670 train_time:156933ms step_avg:95.40ms +step:1646/1670 train_time:157027ms step_avg:95.40ms +step:1647/1670 train_time:157121ms step_avg:95.40ms +step:1648/1670 train_time:157218ms step_avg:95.40ms +step:1649/1670 train_time:157314ms step_avg:95.40ms +step:1650/1670 train_time:157410ms step_avg:95.40ms +step:1651/1670 train_time:157505ms step_avg:95.40ms +step:1652/1670 train_time:157600ms step_avg:95.40ms +step:1653/1670 train_time:157695ms step_avg:95.40ms +step:1654/1670 train_time:157790ms step_avg:95.40ms +step:1655/1670 train_time:157884ms step_avg:95.40ms +step:1656/1670 train_time:157979ms step_avg:95.40ms +step:1657/1670 train_time:158072ms step_avg:95.40ms +step:1658/1670 train_time:158167ms step_avg:95.40ms +step:1659/1670 train_time:158261ms step_avg:95.40ms +step:1660/1670 train_time:158356ms step_avg:95.40ms +step:1661/1670 train_time:158451ms step_avg:95.40ms +step:1662/1670 train_time:158546ms step_avg:95.39ms +step:1663/1670 train_time:158640ms step_avg:95.39ms +step:1664/1670 train_time:158734ms step_avg:95.39ms +step:1665/1670 train_time:158828ms step_avg:95.39ms +step:1666/1670 train_time:158922ms step_avg:95.39ms +step:1667/1670 train_time:159017ms step_avg:95.39ms +step:1668/1670 train_time:159111ms step_avg:95.39ms +step:1669/1670 train_time:159205ms step_avg:95.39ms +step:1670/1670 train_time:159300ms step_avg:95.39ms +step:1670/1670 val_loss:3.2806 train_time:159468ms step_avg:95.49ms +peak memory allocated: 32712 MiB reserved: 47656 MiB diff --git a/records/091025_Yarn/61b04c65-2c0f-4d24-83e2-6035dfea1582.txt b/records/091025_Yarn/61b04c65-2c0f-4d24-83e2-6035dfea1582.txt new file mode 100644 index 000000000..4c448dd7c --- /dev/null +++ b/records/091025_Yarn/61b04c65-2c0f-4d24-83e2-6035dfea1582.txt @@ -0,0 +1,2863 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +import math + +from dataclasses import dataclass +from functools import lru_cache +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, 
dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + 
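# A minimal usage sketch for the FP8 custom op defined above; the scale factors
# passed here are illustrative placeholders, not the tuned values:
#
#   x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#   w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#   out, x_f8, w_f8 = torch.ops.nanogpt.mm(x, w, 1.0, 1.0, 1.0)
#   out.sum().backward()  # dispatches to nanogpt::mm_backward via register_autograd
#
# The register_fake functions give torch.compile shape/dtype-only "meta"
# implementations, so the graph can be traced without running the FP8 kernels.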
c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels 
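# For reference, one quintic Newton-Schulz step as implemented by these kernels
# (the constants a, b, c are defined in newton_schulz_triton below):
#
#   A = X @ X.mT              # ns_line_1: symmetric, so only one triangle is computed
#   B = b * A + c * (A @ A)   # ns_line_2: the alpha/beta scaling is fused into the kernel
#   X = a * X + B @ X         # ns_line_3: a plain addmm/baddbmm
#
# Both Triton kernels skip blocks strictly on one side of the diagonal and
# mirror each computed block across it, roughly halving the matmul work.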
+ pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
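# Distributed plumbing aside, the per-parameter math in this step() matches the
# following single-GPU sketch (the lr_mul/wd_mul multipliers are omitted for brevity):
#
#   buf = state["momentum_buffer"]
#   p.mul_(1 - lr * weight_decay)           # decoupled weight decay
#   buf.lerp_(grad, 1 - momentum)           # momentum buffer update
#   update = grad.lerp_(buf, momentum)      # Nesterov-style blend
#   update = newton_schulz_triton(update)   # orthogonalize the update
#   p.add_(update, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)
#
# Each rank owns every world_size-th parameter of a group: gradients are
# averaged via reduce_scatter, updated locally, then broadcast back via all_gather.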
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
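# (Adam state here is allocated at slice size: each rank keeps exp_avg/exp_avg_sq
# only for the rows it owns, p[rank * rank_size:(rank + 1) * rank_size], ZeRO-style,
# so optimizer-state memory shrinks by a factor of world_size.)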
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, 
sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + 
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws]
+            )
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # reference idx here, not the sequence start, which is unbound if the shard is exhausted on entry
+                    raise StopIteration(f"Insufficient BOS tokens ahead of BOS index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
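+# A minimal, self-contained sketch (illustrative only: `_demo_bos_packing` and the
+# toy tensor below are not part of the training run) of the packing contract that
+# BOSFinder.next_batch guarantees: every slice starts on BOS_ID, no slice exceeds
+# max_seq_len tokens, and each rank's slices sum to exactly num_tokens_local + 1
+# tokens so inputs and targets can be offset by one position.
+def _demo_bos_packing():
+    toy = torch.tensor([BOS_ID, 5, 6, BOS_ID, 7, BOS_ID, 8, 9, 10, BOS_ID, 11, 12])
+    finder = BOSFinder(toy, world_size=1)
+    starts, ends = finder.next_batch(num_tokens_local=7, max_seq_len=4)
+    buf = torch.cat([toy[i:j] for i, j in zip(starts[0], ends[0])])
+    assert buf.numel() == 7 + 1  # one extra token for the target shift
+    assert all(toy[i] == BOS_ID for i in starts[0])  # every sequence begins at a document start
+    assert all(j - i <= 4 for i, j in zip(starts[0], ends[0]))  # sequences truncated to max_seq_len
+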
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"yarn/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws=new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 03:57:09 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 32C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 33C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 34C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 31C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 31C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 33C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 61015 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 61016 C /usr/bin/python3 614MiB | +| 0 N/A N/A 61017 C /usr/bin/python3 614MiB | +| 0 N/A N/A 61018 C /usr/bin/python3 614MiB | +| 0 N/A N/A 61019 C /usr/bin/python3 614MiB | +| 0 N/A N/A 61020 C /usr/bin/python3 614MiB | +| 0 N/A N/A 61021 C /usr/bin/python3 614MiB | +| 0 N/A N/A 61022 C /usr/bin/python3 614MiB | +| 1 N/A N/A 61016 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 61017 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 61018 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 61019 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 61020 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 61021 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 61022 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:453ms step_avg:453.25ms +step:2/1670 train_time:477ms step_avg:238.73ms +step:3/1670 train_time:545ms step_avg:181.74ms +step:4/1670 train_time:636ms step_avg:158.90ms +step:5/1670 train_time:727ms step_avg:145.42ms +step:6/1670 train_time:818ms step_avg:136.41ms +step:7/1670 train_time:910ms step_avg:130.05ms 
+step:8/1670 train_time:1002ms step_avg:125.23ms +step:9/1670 train_time:1094ms step_avg:121.53ms +step:10/1670 train_time:1185ms step_avg:118.54ms +step:11/1670 train_time:1277ms step_avg:116.13ms +step:12/1670 train_time:1371ms step_avg:114.23ms +step:13/1670 train_time:1468ms step_avg:112.93ms +step:14/1670 train_time:1561ms step_avg:111.48ms +step:15/1670 train_time:1654ms step_avg:110.26ms +step:16/1670 train_time:1746ms step_avg:109.13ms +step:17/1670 train_time:1838ms step_avg:108.10ms +step:18/1670 train_time:1930ms step_avg:107.23ms +step:19/1670 train_time:2022ms step_avg:106.45ms +step:20/1670 train_time:2115ms step_avg:105.77ms +step:21/1670 train_time:2208ms step_avg:105.12ms +step:22/1670 train_time:2300ms step_avg:104.57ms +step:23/1670 train_time:2394ms step_avg:104.07ms +step:24/1670 train_time:2487ms step_avg:103.62ms +step:25/1670 train_time:2580ms step_avg:103.20ms +step:26/1670 train_time:2674ms step_avg:102.84ms +step:27/1670 train_time:2767ms step_avg:102.50ms +step:28/1670 train_time:2859ms step_avg:102.12ms +step:29/1670 train_time:2952ms step_avg:101.79ms +step:30/1670 train_time:3045ms step_avg:101.49ms +step:31/1670 train_time:3137ms step_avg:101.19ms +step:32/1670 train_time:3230ms step_avg:100.95ms +step:33/1670 train_time:3324ms step_avg:100.72ms +step:34/1670 train_time:3417ms step_avg:100.50ms +step:35/1670 train_time:3510ms step_avg:100.29ms +step:36/1670 train_time:3604ms step_avg:100.11ms +step:37/1670 train_time:3696ms step_avg:99.90ms +step:38/1670 train_time:3789ms step_avg:99.72ms +step:39/1670 train_time:3881ms step_avg:99.53ms +step:40/1670 train_time:3974ms step_avg:99.36ms +step:41/1670 train_time:4067ms step_avg:99.20ms +step:42/1670 train_time:4160ms step_avg:99.05ms +step:43/1670 train_time:4253ms step_avg:98.90ms +step:44/1670 train_time:4345ms step_avg:98.75ms +step:45/1670 train_time:4437ms step_avg:98.60ms +step:46/1670 train_time:4531ms step_avg:98.49ms +step:47/1670 train_time:4624ms step_avg:98.39ms +step:48/1670 train_time:4716ms step_avg:98.25ms +step:49/1670 train_time:4808ms step_avg:98.13ms +step:50/1670 train_time:4901ms step_avg:98.02ms +step:51/1670 train_time:4993ms step_avg:97.91ms +step:52/1670 train_time:5086ms step_avg:97.81ms +step:53/1670 train_time:5179ms step_avg:97.71ms +step:54/1670 train_time:5272ms step_avg:97.62ms +step:55/1670 train_time:5364ms step_avg:97.52ms +step:56/1670 train_time:5456ms step_avg:97.43ms +step:57/1670 train_time:5550ms step_avg:97.37ms +step:58/1670 train_time:5642ms step_avg:97.28ms +step:59/1670 train_time:5735ms step_avg:97.20ms +step:60/1670 train_time:5828ms step_avg:97.13ms +step:61/1670 train_time:5920ms step_avg:97.06ms +step:62/1670 train_time:6012ms step_avg:96.97ms +step:63/1670 train_time:6105ms step_avg:96.91ms +step:64/1670 train_time:6198ms step_avg:96.84ms +step:65/1670 train_time:6291ms step_avg:96.78ms +step:66/1670 train_time:6383ms step_avg:96.71ms +step:67/1670 train_time:6475ms step_avg:96.65ms +step:68/1670 train_time:6568ms step_avg:96.59ms +step:69/1670 train_time:6661ms step_avg:96.53ms +step:70/1670 train_time:6753ms step_avg:96.48ms +step:71/1670 train_time:6846ms step_avg:96.42ms +step:72/1670 train_time:6939ms step_avg:96.37ms +step:73/1670 train_time:7031ms step_avg:96.32ms +step:74/1670 train_time:7125ms step_avg:96.28ms +step:75/1670 train_time:7217ms step_avg:96.22ms +step:76/1670 train_time:7309ms step_avg:96.17ms +step:77/1670 train_time:7402ms step_avg:96.12ms +step:78/1670 train_time:7493ms step_avg:96.07ms +step:79/1670 train_time:7586ms step_avg:96.02ms 
+step:80/1670 train_time:7679ms step_avg:95.99ms +step:81/1670 train_time:7772ms step_avg:95.95ms +step:82/1670 train_time:7864ms step_avg:95.91ms +step:83/1670 train_time:7956ms step_avg:95.86ms +step:84/1670 train_time:8050ms step_avg:95.83ms +step:85/1670 train_time:8143ms step_avg:95.80ms +step:86/1670 train_time:8235ms step_avg:95.76ms +step:87/1670 train_time:8328ms step_avg:95.73ms +step:88/1670 train_time:8422ms step_avg:95.71ms +step:89/1670 train_time:8514ms step_avg:95.66ms +step:90/1670 train_time:8607ms step_avg:95.63ms +step:91/1670 train_time:8699ms step_avg:95.60ms +step:92/1670 train_time:8792ms step_avg:95.57ms +step:93/1670 train_time:8885ms step_avg:95.54ms +step:94/1670 train_time:8978ms step_avg:95.51ms +step:95/1670 train_time:9070ms step_avg:95.47ms +step:96/1670 train_time:9162ms step_avg:95.44ms +step:97/1670 train_time:9255ms step_avg:95.41ms +step:98/1670 train_time:9348ms step_avg:95.39ms +step:99/1670 train_time:9442ms step_avg:95.37ms +step:100/1670 train_time:9533ms step_avg:95.33ms +step:101/1670 train_time:9625ms step_avg:95.30ms +step:102/1670 train_time:9717ms step_avg:95.27ms +step:103/1670 train_time:9809ms step_avg:95.24ms +step:104/1670 train_time:9903ms step_avg:95.22ms +step:105/1670 train_time:9995ms step_avg:95.19ms +step:106/1670 train_time:10087ms step_avg:95.16ms +step:107/1670 train_time:10180ms step_avg:95.14ms +step:108/1670 train_time:10272ms step_avg:95.12ms +step:109/1670 train_time:10365ms step_avg:95.09ms +step:110/1670 train_time:10458ms step_avg:95.07ms +step:111/1670 train_time:10550ms step_avg:95.05ms +step:112/1670 train_time:10643ms step_avg:95.02ms +step:113/1670 train_time:10735ms step_avg:95.00ms +step:114/1670 train_time:10827ms step_avg:94.97ms +step:115/1670 train_time:10921ms step_avg:94.97ms +step:116/1670 train_time:11015ms step_avg:94.95ms +step:117/1670 train_time:11107ms step_avg:94.93ms +step:118/1670 train_time:11198ms step_avg:94.90ms +step:119/1670 train_time:11291ms step_avg:94.88ms +step:120/1670 train_time:11383ms step_avg:94.86ms +step:121/1670 train_time:11476ms step_avg:94.84ms +step:122/1670 train_time:11568ms step_avg:94.82ms +step:123/1670 train_time:11660ms step_avg:94.80ms +step:124/1670 train_time:11753ms step_avg:94.78ms +step:125/1670 train_time:11845ms step_avg:94.76ms +step:125/1670 val_loss:4.2943 train_time:11936ms step_avg:95.49ms +step:126/1670 train_time:11961ms step_avg:94.93ms +step:127/1670 train_time:12034ms step_avg:94.75ms +step:128/1670 train_time:12137ms step_avg:94.82ms +step:129/1670 train_time:12235ms step_avg:94.84ms +step:130/1670 train_time:12327ms step_avg:94.82ms +step:131/1670 train_time:12419ms step_avg:94.80ms +step:132/1670 train_time:12510ms step_avg:94.78ms +step:133/1670 train_time:12602ms step_avg:94.75ms +step:134/1670 train_time:12693ms step_avg:94.73ms +step:135/1670 train_time:12785ms step_avg:94.71ms +step:136/1670 train_time:12877ms step_avg:94.68ms +step:137/1670 train_time:12968ms step_avg:94.66ms +step:138/1670 train_time:13062ms step_avg:94.65ms +step:139/1670 train_time:13156ms step_avg:94.65ms +step:140/1670 train_time:13249ms step_avg:94.64ms +step:141/1670 train_time:13343ms step_avg:94.63ms +step:142/1670 train_time:13435ms step_avg:94.61ms +step:143/1670 train_time:13528ms step_avg:94.60ms +step:144/1670 train_time:13620ms step_avg:94.58ms +step:145/1670 train_time:13711ms step_avg:94.56ms +step:146/1670 train_time:13803ms step_avg:94.54ms +step:147/1670 train_time:13895ms step_avg:94.52ms +step:148/1670 train_time:13987ms step_avg:94.50ms +step:149/1670 
train_time:14079ms step_avg:94.49ms +step:150/1670 train_time:14172ms step_avg:94.48ms +step:151/1670 train_time:14266ms step_avg:94.47ms +step:152/1670 train_time:14359ms step_avg:94.47ms +step:153/1670 train_time:14451ms step_avg:94.45ms +step:154/1670 train_time:14543ms step_avg:94.44ms +step:155/1670 train_time:14635ms step_avg:94.42ms +step:156/1670 train_time:14728ms step_avg:94.41ms +step:157/1670 train_time:14820ms step_avg:94.40ms +step:158/1670 train_time:14912ms step_avg:94.38ms +step:159/1670 train_time:15006ms step_avg:94.38ms +step:160/1670 train_time:15099ms step_avg:94.37ms +step:161/1670 train_time:15192ms step_avg:94.36ms +step:162/1670 train_time:15284ms step_avg:94.35ms +step:163/1670 train_time:15377ms step_avg:94.34ms +step:164/1670 train_time:15470ms step_avg:94.33ms +step:165/1670 train_time:15562ms step_avg:94.32ms +step:166/1670 train_time:15656ms step_avg:94.31ms +step:167/1670 train_time:15748ms step_avg:94.30ms +step:168/1670 train_time:15840ms step_avg:94.29ms +step:169/1670 train_time:15932ms step_avg:94.27ms +step:170/1670 train_time:16025ms step_avg:94.27ms +step:171/1670 train_time:16117ms step_avg:94.25ms +step:172/1670 train_time:16209ms step_avg:94.24ms +step:173/1670 train_time:16302ms step_avg:94.23ms +step:174/1670 train_time:16395ms step_avg:94.22ms +step:175/1670 train_time:16488ms step_avg:94.22ms +step:176/1670 train_time:16581ms step_avg:94.21ms +step:177/1670 train_time:16673ms step_avg:94.20ms +step:178/1670 train_time:16765ms step_avg:94.18ms +step:179/1670 train_time:16857ms step_avg:94.18ms +step:180/1670 train_time:16949ms step_avg:94.16ms +step:181/1670 train_time:17042ms step_avg:94.15ms +step:182/1670 train_time:17135ms step_avg:94.15ms +step:183/1670 train_time:17226ms step_avg:94.13ms +step:184/1670 train_time:17319ms step_avg:94.12ms +step:185/1670 train_time:17411ms step_avg:94.12ms +step:186/1670 train_time:17504ms step_avg:94.11ms +step:187/1670 train_time:17597ms step_avg:94.10ms +step:188/1670 train_time:17689ms step_avg:94.09ms +step:189/1670 train_time:17782ms step_avg:94.08ms +step:190/1670 train_time:17874ms step_avg:94.07ms +step:191/1670 train_time:17967ms step_avg:94.07ms +step:192/1670 train_time:18059ms step_avg:94.06ms +step:193/1670 train_time:18152ms step_avg:94.05ms +step:194/1670 train_time:18245ms step_avg:94.05ms +step:195/1670 train_time:18338ms step_avg:94.04ms +step:196/1670 train_time:18430ms step_avg:94.03ms +step:197/1670 train_time:18523ms step_avg:94.02ms +step:198/1670 train_time:18615ms step_avg:94.02ms +step:199/1670 train_time:18708ms step_avg:94.01ms +step:200/1670 train_time:18800ms step_avg:94.00ms +step:201/1670 train_time:18892ms step_avg:93.99ms +step:202/1670 train_time:18985ms step_avg:93.98ms +step:203/1670 train_time:19078ms step_avg:93.98ms +step:204/1670 train_time:19170ms step_avg:93.97ms +step:205/1670 train_time:19262ms step_avg:93.96ms +step:206/1670 train_time:19354ms step_avg:93.95ms +step:207/1670 train_time:19446ms step_avg:93.94ms +step:208/1670 train_time:19540ms step_avg:93.94ms +step:209/1670 train_time:19632ms step_avg:93.93ms +step:210/1670 train_time:19724ms step_avg:93.93ms +step:211/1670 train_time:19817ms step_avg:93.92ms +step:212/1670 train_time:19909ms step_avg:93.91ms +step:213/1670 train_time:20263ms step_avg:95.13ms +step:214/1670 train_time:20379ms step_avg:95.23ms +step:215/1670 train_time:20469ms step_avg:95.21ms +step:216/1670 train_time:20561ms step_avg:95.19ms +step:217/1670 train_time:20652ms step_avg:95.17ms +step:218/1670 train_time:20744ms step_avg:95.16ms 
+step:219/1670 train_time:20835ms step_avg:95.14ms +step:220/1670 train_time:20926ms step_avg:95.12ms +step:221/1670 train_time:21018ms step_avg:95.10ms +step:222/1670 train_time:21110ms step_avg:95.09ms +step:223/1670 train_time:21203ms step_avg:95.08ms +step:224/1670 train_time:21298ms step_avg:95.08ms +step:225/1670 train_time:21391ms step_avg:95.07ms +step:226/1670 train_time:21484ms step_avg:95.06ms +step:227/1670 train_time:21577ms step_avg:95.05ms +step:228/1670 train_time:21669ms step_avg:95.04ms +step:229/1670 train_time:21760ms step_avg:95.02ms +step:230/1670 train_time:21852ms step_avg:95.01ms +step:231/1670 train_time:21944ms step_avg:95.00ms +step:232/1670 train_time:22035ms step_avg:94.98ms +step:233/1670 train_time:22127ms step_avg:94.97ms +step:234/1670 train_time:22220ms step_avg:94.96ms +step:235/1670 train_time:22313ms step_avg:94.95ms +step:236/1670 train_time:22406ms step_avg:94.94ms +step:237/1670 train_time:22500ms step_avg:94.94ms +step:238/1670 train_time:22593ms step_avg:94.93ms +step:239/1670 train_time:22685ms step_avg:94.92ms +step:240/1670 train_time:22778ms step_avg:94.91ms +step:241/1670 train_time:22870ms step_avg:94.90ms +step:242/1670 train_time:22962ms step_avg:94.88ms +step:243/1670 train_time:23053ms step_avg:94.87ms +step:244/1670 train_time:23146ms step_avg:94.86ms +step:245/1670 train_time:23239ms step_avg:94.85ms +step:246/1670 train_time:23331ms step_avg:94.84ms +step:247/1670 train_time:23423ms step_avg:94.83ms +step:248/1670 train_time:23517ms step_avg:94.83ms +step:249/1670 train_time:23609ms step_avg:94.82ms +step:250/1670 train_time:23702ms step_avg:94.81ms +step:250/1670 val_loss:3.9692 train_time:23793ms step_avg:95.17ms +step:251/1670 train_time:23819ms step_avg:94.90ms +step:252/1670 train_time:23894ms step_avg:94.82ms +step:253/1670 train_time:23995ms step_avg:94.84ms +step:254/1670 train_time:24089ms step_avg:94.84ms +step:255/1670 train_time:24182ms step_avg:94.83ms +step:256/1670 train_time:24273ms step_avg:94.82ms +step:257/1670 train_time:24364ms step_avg:94.80ms +step:258/1670 train_time:24456ms step_avg:94.79ms +step:259/1670 train_time:24547ms step_avg:94.78ms +step:260/1670 train_time:24639ms step_avg:94.76ms +step:261/1670 train_time:24730ms step_avg:94.75ms +step:262/1670 train_time:24823ms step_avg:94.74ms +step:263/1670 train_time:24917ms step_avg:94.74ms +step:264/1670 train_time:25011ms step_avg:94.74ms +step:265/1670 train_time:25104ms step_avg:94.73ms +step:266/1670 train_time:25197ms step_avg:94.73ms +step:267/1670 train_time:25290ms step_avg:94.72ms +step:268/1670 train_time:25382ms step_avg:94.71ms +step:269/1670 train_time:25474ms step_avg:94.70ms +step:270/1670 train_time:25565ms step_avg:94.69ms +step:271/1670 train_time:25658ms step_avg:94.68ms +step:272/1670 train_time:25749ms step_avg:94.67ms +step:273/1670 train_time:25842ms step_avg:94.66ms +step:274/1670 train_time:25935ms step_avg:94.65ms +step:275/1670 train_time:26028ms step_avg:94.65ms +step:276/1670 train_time:26121ms step_avg:94.64ms +step:277/1670 train_time:26214ms step_avg:94.64ms +step:278/1670 train_time:26308ms step_avg:94.63ms +step:279/1670 train_time:26400ms step_avg:94.62ms +step:280/1670 train_time:26492ms step_avg:94.61ms +step:281/1670 train_time:26584ms step_avg:94.61ms +step:282/1670 train_time:26676ms step_avg:94.59ms +step:283/1670 train_time:26768ms step_avg:94.59ms +step:284/1670 train_time:26861ms step_avg:94.58ms +step:285/1670 train_time:26954ms step_avg:94.58ms +step:286/1670 train_time:27047ms step_avg:94.57ms +step:287/1670 
train_time:27139ms step_avg:94.56ms +step:288/1670 train_time:27232ms step_avg:94.55ms +step:289/1670 train_time:27324ms step_avg:94.55ms +step:290/1670 train_time:27416ms step_avg:94.54ms +step:291/1670 train_time:27509ms step_avg:94.53ms +step:292/1670 train_time:27601ms step_avg:94.52ms +step:293/1670 train_time:27693ms step_avg:94.52ms +step:294/1670 train_time:27786ms step_avg:94.51ms +step:295/1670 train_time:27879ms step_avg:94.50ms +step:296/1670 train_time:27972ms step_avg:94.50ms +step:297/1670 train_time:28065ms step_avg:94.49ms +step:298/1670 train_time:28157ms step_avg:94.49ms +step:299/1670 train_time:28250ms step_avg:94.48ms +step:300/1670 train_time:28343ms step_avg:94.48ms +step:301/1670 train_time:28435ms step_avg:94.47ms +step:302/1670 train_time:28527ms step_avg:94.46ms +step:303/1670 train_time:28619ms step_avg:94.45ms +step:304/1670 train_time:28712ms step_avg:94.45ms +step:305/1670 train_time:28804ms step_avg:94.44ms +step:306/1670 train_time:28896ms step_avg:94.43ms +step:307/1670 train_time:28989ms step_avg:94.43ms +step:308/1670 train_time:29081ms step_avg:94.42ms +step:309/1670 train_time:29174ms step_avg:94.41ms +step:310/1670 train_time:29267ms step_avg:94.41ms +step:311/1670 train_time:29359ms step_avg:94.40ms +step:312/1670 train_time:29451ms step_avg:94.40ms +step:313/1670 train_time:29544ms step_avg:94.39ms +step:314/1670 train_time:29636ms step_avg:94.38ms +step:315/1670 train_time:29730ms step_avg:94.38ms +step:316/1670 train_time:29823ms step_avg:94.38ms +step:317/1670 train_time:29915ms step_avg:94.37ms +step:318/1670 train_time:30006ms step_avg:94.36ms +step:319/1670 train_time:30099ms step_avg:94.35ms +step:320/1670 train_time:30191ms step_avg:94.35ms +step:321/1670 train_time:30284ms step_avg:94.34ms +step:322/1670 train_time:30376ms step_avg:94.34ms +step:323/1670 train_time:30468ms step_avg:94.33ms +step:324/1670 train_time:30560ms step_avg:94.32ms +step:325/1670 train_time:30653ms step_avg:94.32ms +step:326/1670 train_time:30746ms step_avg:94.31ms +step:327/1670 train_time:30839ms step_avg:94.31ms +step:328/1670 train_time:30931ms step_avg:94.30ms +step:329/1670 train_time:31024ms step_avg:94.30ms +step:330/1670 train_time:31116ms step_avg:94.29ms +step:331/1670 train_time:31209ms step_avg:94.29ms +step:332/1670 train_time:31301ms step_avg:94.28ms +step:333/1670 train_time:31394ms step_avg:94.28ms +step:334/1670 train_time:31487ms step_avg:94.27ms +step:335/1670 train_time:31579ms step_avg:94.26ms +step:336/1670 train_time:31671ms step_avg:94.26ms +step:337/1670 train_time:31763ms step_avg:94.25ms +step:338/1670 train_time:31856ms step_avg:94.25ms +step:339/1670 train_time:31948ms step_avg:94.24ms +step:340/1670 train_time:32041ms step_avg:94.24ms +step:341/1670 train_time:32133ms step_avg:94.23ms +step:342/1670 train_time:32226ms step_avg:94.23ms +step:343/1670 train_time:32318ms step_avg:94.22ms +step:344/1670 train_time:32411ms step_avg:94.22ms +step:345/1670 train_time:32502ms step_avg:94.21ms +step:346/1670 train_time:32595ms step_avg:94.20ms +step:347/1670 train_time:32687ms step_avg:94.20ms +step:348/1670 train_time:32779ms step_avg:94.19ms +step:349/1670 train_time:32872ms step_avg:94.19ms +step:350/1670 train_time:32964ms step_avg:94.18ms +step:351/1670 train_time:33056ms step_avg:94.18ms +step:352/1670 train_time:33148ms step_avg:94.17ms +step:353/1670 train_time:33241ms step_avg:94.17ms +step:354/1670 train_time:33334ms step_avg:94.16ms +step:355/1670 train_time:33426ms step_avg:94.16ms +step:356/1670 train_time:33518ms step_avg:94.15ms 
+step:357/1670 train_time:33611ms step_avg:94.15ms +step:358/1670 train_time:33702ms step_avg:94.14ms +step:359/1670 train_time:33795ms step_avg:94.14ms +step:360/1670 train_time:33887ms step_avg:94.13ms +step:361/1670 train_time:33980ms step_avg:94.13ms +step:362/1670 train_time:34073ms step_avg:94.12ms +step:363/1670 train_time:34164ms step_avg:94.12ms +step:364/1670 train_time:34258ms step_avg:94.11ms +step:365/1670 train_time:34350ms step_avg:94.11ms +step:366/1670 train_time:34443ms step_avg:94.11ms +step:367/1670 train_time:34535ms step_avg:94.10ms +step:368/1670 train_time:34628ms step_avg:94.10ms +step:369/1670 train_time:34720ms step_avg:94.09ms +step:370/1670 train_time:34812ms step_avg:94.09ms +step:371/1670 train_time:34905ms step_avg:94.08ms +step:372/1670 train_time:34997ms step_avg:94.08ms +step:373/1670 train_time:35090ms step_avg:94.07ms +step:374/1670 train_time:35182ms step_avg:94.07ms +step:375/1670 train_time:35274ms step_avg:94.06ms +step:375/1670 val_loss:3.8146 train_time:35365ms step_avg:94.31ms +step:376/1670 train_time:35391ms step_avg:94.12ms +step:377/1670 train_time:35464ms step_avg:94.07ms +step:378/1670 train_time:35562ms step_avg:94.08ms +step:379/1670 train_time:35655ms step_avg:94.08ms +step:380/1670 train_time:35747ms step_avg:94.07ms +step:381/1670 train_time:35838ms step_avg:94.06ms +step:382/1670 train_time:35929ms step_avg:94.06ms +step:383/1670 train_time:36021ms step_avg:94.05ms +step:384/1670 train_time:36113ms step_avg:94.04ms +step:385/1670 train_time:36204ms step_avg:94.04ms +step:386/1670 train_time:36297ms step_avg:94.03ms +step:387/1670 train_time:36389ms step_avg:94.03ms +step:388/1670 train_time:36484ms step_avg:94.03ms +step:389/1670 train_time:36578ms step_avg:94.03ms +step:390/1670 train_time:36670ms step_avg:94.03ms +step:391/1670 train_time:36763ms step_avg:94.02ms +step:392/1670 train_time:36855ms step_avg:94.02ms +step:393/1670 train_time:36947ms step_avg:94.01ms +step:394/1670 train_time:37040ms step_avg:94.01ms +step:395/1670 train_time:37131ms step_avg:94.00ms +step:396/1670 train_time:37223ms step_avg:94.00ms +step:397/1670 train_time:37315ms step_avg:93.99ms +step:398/1670 train_time:37409ms step_avg:93.99ms +step:399/1670 train_time:37503ms step_avg:93.99ms +step:400/1670 train_time:37596ms step_avg:93.99ms +step:401/1670 train_time:37689ms step_avg:93.99ms +step:402/1670 train_time:37783ms step_avg:93.99ms +step:403/1670 train_time:37875ms step_avg:93.98ms +step:404/1670 train_time:37967ms step_avg:93.98ms +step:405/1670 train_time:38059ms step_avg:93.97ms +step:406/1670 train_time:38151ms step_avg:93.97ms +step:407/1670 train_time:38243ms step_avg:93.96ms +step:408/1670 train_time:38337ms step_avg:93.96ms +step:409/1670 train_time:38429ms step_avg:93.96ms +step:410/1670 train_time:38522ms step_avg:93.96ms +step:411/1670 train_time:38615ms step_avg:93.95ms +step:412/1670 train_time:38709ms step_avg:93.95ms +step:413/1670 train_time:38802ms step_avg:93.95ms +step:414/1670 train_time:38893ms step_avg:93.95ms +step:415/1670 train_time:38986ms step_avg:93.94ms +step:416/1670 train_time:39078ms step_avg:93.94ms +step:417/1670 train_time:39170ms step_avg:93.93ms +step:418/1670 train_time:39262ms step_avg:93.93ms +step:419/1670 train_time:39354ms step_avg:93.92ms +step:420/1670 train_time:39447ms step_avg:93.92ms +step:421/1670 train_time:39539ms step_avg:93.92ms +step:422/1670 train_time:39632ms step_avg:93.92ms +step:423/1670 train_time:39725ms step_avg:93.91ms +step:424/1670 train_time:39818ms step_avg:93.91ms +step:425/1670 
train_time:40145ms step_avg:94.46ms +step:426/1670 train_time:40336ms step_avg:94.68ms +step:427/1670 train_time:40426ms step_avg:94.67ms +step:428/1670 train_time:40517ms step_avg:94.67ms +step:429/1670 train_time:40609ms step_avg:94.66ms +step:430/1670 train_time:40700ms step_avg:94.65ms +step:431/1670 train_time:40792ms step_avg:94.64ms +step:432/1670 train_time:40884ms step_avg:94.64ms +step:433/1670 train_time:40975ms step_avg:94.63ms +step:434/1670 train_time:41067ms step_avg:94.62ms +step:435/1670 train_time:41158ms step_avg:94.62ms +step:436/1670 train_time:41253ms step_avg:94.62ms +step:437/1670 train_time:41349ms step_avg:94.62ms +step:438/1670 train_time:41442ms step_avg:94.62ms +step:439/1670 train_time:41535ms step_avg:94.61ms +step:440/1670 train_time:41627ms step_avg:94.61ms +step:441/1670 train_time:41719ms step_avg:94.60ms +step:442/1670 train_time:41812ms step_avg:94.60ms +step:443/1670 train_time:41903ms step_avg:94.59ms +step:444/1670 train_time:41995ms step_avg:94.58ms +step:445/1670 train_time:42088ms step_avg:94.58ms +step:446/1670 train_time:42180ms step_avg:94.57ms +step:447/1670 train_time:42273ms step_avg:94.57ms +step:448/1670 train_time:42367ms step_avg:94.57ms +step:449/1670 train_time:42460ms step_avg:94.57ms +step:450/1670 train_time:42553ms step_avg:94.56ms +step:451/1670 train_time:42645ms step_avg:94.56ms +step:452/1670 train_time:42738ms step_avg:94.55ms +step:453/1670 train_time:42831ms step_avg:94.55ms +step:454/1670 train_time:42923ms step_avg:94.54ms +step:455/1670 train_time:43014ms step_avg:94.54ms +step:456/1670 train_time:43106ms step_avg:94.53ms +step:457/1670 train_time:43199ms step_avg:94.53ms +step:458/1670 train_time:43292ms step_avg:94.52ms +step:459/1670 train_time:43384ms step_avg:94.52ms +step:460/1670 train_time:43477ms step_avg:94.52ms +step:461/1670 train_time:43569ms step_avg:94.51ms +step:462/1670 train_time:43662ms step_avg:94.51ms +step:463/1670 train_time:43755ms step_avg:94.50ms +step:464/1670 train_time:43847ms step_avg:94.50ms +step:465/1670 train_time:43940ms step_avg:94.49ms +step:466/1670 train_time:44032ms step_avg:94.49ms +step:467/1670 train_time:44125ms step_avg:94.49ms +step:468/1670 train_time:44217ms step_avg:94.48ms +step:469/1670 train_time:44310ms step_avg:94.48ms +step:470/1670 train_time:44402ms step_avg:94.47ms +step:471/1670 train_time:44495ms step_avg:94.47ms +step:472/1670 train_time:44587ms step_avg:94.46ms +step:473/1670 train_time:44680ms step_avg:94.46ms +step:474/1670 train_time:44773ms step_avg:94.46ms +step:475/1670 train_time:44865ms step_avg:94.45ms +step:476/1670 train_time:44957ms step_avg:94.45ms +step:477/1670 train_time:45049ms step_avg:94.44ms +step:478/1670 train_time:45141ms step_avg:94.44ms +step:479/1670 train_time:45233ms step_avg:94.43ms +step:480/1670 train_time:45326ms step_avg:94.43ms +step:481/1670 train_time:45418ms step_avg:94.42ms +step:482/1670 train_time:45511ms step_avg:94.42ms +step:483/1670 train_time:45603ms step_avg:94.42ms +step:484/1670 train_time:45696ms step_avg:94.41ms +step:485/1670 train_time:45788ms step_avg:94.41ms +step:486/1670 train_time:45881ms step_avg:94.40ms +step:487/1670 train_time:45973ms step_avg:94.40ms +step:488/1670 train_time:46065ms step_avg:94.40ms +step:489/1670 train_time:46158ms step_avg:94.39ms +step:490/1670 train_time:46251ms step_avg:94.39ms +step:491/1670 train_time:46343ms step_avg:94.38ms +step:492/1670 train_time:46435ms step_avg:94.38ms +step:493/1670 train_time:46528ms step_avg:94.38ms +step:494/1670 train_time:46620ms step_avg:94.37ms 
+step:495/1670 train_time:46712ms step_avg:94.37ms +step:496/1670 train_time:46805ms step_avg:94.37ms +step:497/1670 train_time:46898ms step_avg:94.36ms +step:498/1670 train_time:46990ms step_avg:94.36ms +step:499/1670 train_time:47082ms step_avg:94.35ms +step:500/1670 train_time:47175ms step_avg:94.35ms +step:500/1670 val_loss:3.7137 train_time:47265ms step_avg:94.53ms +step:501/1670 train_time:47290ms step_avg:94.39ms +step:502/1670 train_time:47365ms step_avg:94.35ms +step:503/1670 train_time:47462ms step_avg:94.36ms +step:504/1670 train_time:47556ms step_avg:94.36ms +step:505/1670 train_time:47648ms step_avg:94.35ms +step:506/1670 train_time:47739ms step_avg:94.35ms +step:507/1670 train_time:47831ms step_avg:94.34ms +step:508/1670 train_time:47923ms step_avg:94.34ms +step:509/1670 train_time:48015ms step_avg:94.33ms +step:510/1670 train_time:48107ms step_avg:94.33ms +step:511/1670 train_time:48198ms step_avg:94.32ms +step:512/1670 train_time:48292ms step_avg:94.32ms +step:513/1670 train_time:48387ms step_avg:94.32ms +step:514/1670 train_time:48483ms step_avg:94.33ms +step:515/1670 train_time:48576ms step_avg:94.32ms +step:516/1670 train_time:48668ms step_avg:94.32ms +step:517/1670 train_time:48759ms step_avg:94.31ms +step:518/1670 train_time:48852ms step_avg:94.31ms +step:519/1670 train_time:48943ms step_avg:94.30ms +step:520/1670 train_time:49035ms step_avg:94.30ms +step:521/1670 train_time:49127ms step_avg:94.29ms +step:522/1670 train_time:49219ms step_avg:94.29ms +step:523/1670 train_time:49313ms step_avg:94.29ms +step:524/1670 train_time:49407ms step_avg:94.29ms +step:525/1670 train_time:49499ms step_avg:94.28ms +step:526/1670 train_time:49592ms step_avg:94.28ms +step:527/1670 train_time:49685ms step_avg:94.28ms +step:528/1670 train_time:49778ms step_avg:94.28ms +step:529/1670 train_time:49869ms step_avg:94.27ms +step:530/1670 train_time:49961ms step_avg:94.27ms +step:531/1670 train_time:50053ms step_avg:94.26ms +step:532/1670 train_time:50144ms step_avg:94.26ms +step:533/1670 train_time:50236ms step_avg:94.25ms +step:534/1670 train_time:50329ms step_avg:94.25ms +step:535/1670 train_time:50422ms step_avg:94.25ms +step:536/1670 train_time:50515ms step_avg:94.24ms +step:537/1670 train_time:50608ms step_avg:94.24ms +step:538/1670 train_time:50699ms step_avg:94.24ms +step:539/1670 train_time:50792ms step_avg:94.23ms +step:540/1670 train_time:50884ms step_avg:94.23ms +step:541/1670 train_time:50976ms step_avg:94.23ms +step:542/1670 train_time:51068ms step_avg:94.22ms +step:543/1670 train_time:51160ms step_avg:94.22ms +step:544/1670 train_time:51253ms step_avg:94.22ms +step:545/1670 train_time:51346ms step_avg:94.21ms +step:546/1670 train_time:51439ms step_avg:94.21ms +step:547/1670 train_time:51531ms step_avg:94.21ms +step:548/1670 train_time:51624ms step_avg:94.20ms +step:549/1670 train_time:51716ms step_avg:94.20ms +step:550/1670 train_time:51809ms step_avg:94.20ms +step:551/1670 train_time:51902ms step_avg:94.20ms +step:552/1670 train_time:51994ms step_avg:94.19ms +step:553/1670 train_time:52086ms step_avg:94.19ms +step:554/1670 train_time:52178ms step_avg:94.18ms +step:555/1670 train_time:52271ms step_avg:94.18ms +step:556/1670 train_time:52363ms step_avg:94.18ms +step:557/1670 train_time:52456ms step_avg:94.18ms +step:558/1670 train_time:52659ms step_avg:94.37ms +step:559/1670 train_time:52726ms step_avg:94.32ms +step:560/1670 train_time:52819ms step_avg:94.32ms +step:561/1670 train_time:52912ms step_avg:94.32ms +step:562/1670 train_time:53005ms step_avg:94.31ms +step:563/1670 
train_time:53097ms step_avg:94.31ms +step:564/1670 train_time:53190ms step_avg:94.31ms +step:565/1670 train_time:53283ms step_avg:94.31ms +step:566/1670 train_time:53376ms step_avg:94.30ms +step:567/1670 train_time:53469ms step_avg:94.30ms +step:568/1670 train_time:53568ms step_avg:94.31ms +step:569/1670 train_time:53665ms step_avg:94.31ms +step:570/1670 train_time:53759ms step_avg:94.31ms +step:571/1670 train_time:53853ms step_avg:94.31ms +step:572/1670 train_time:53945ms step_avg:94.31ms +step:573/1670 train_time:54038ms step_avg:94.31ms +step:574/1670 train_time:54131ms step_avg:94.31ms +step:575/1670 train_time:54224ms step_avg:94.30ms +step:576/1670 train_time:54317ms step_avg:94.30ms +step:577/1670 train_time:54410ms step_avg:94.30ms +step:578/1670 train_time:54505ms step_avg:94.30ms +step:579/1670 train_time:54600ms step_avg:94.30ms +step:580/1670 train_time:54695ms step_avg:94.30ms +step:581/1670 train_time:54790ms step_avg:94.30ms +step:582/1670 train_time:54884ms step_avg:94.30ms +step:583/1670 train_time:54977ms step_avg:94.30ms +step:584/1670 train_time:55070ms step_avg:94.30ms +step:585/1670 train_time:55164ms step_avg:94.30ms +step:586/1670 train_time:55256ms step_avg:94.29ms +step:587/1670 train_time:55349ms step_avg:94.29ms +step:588/1670 train_time:55443ms step_avg:94.29ms +step:589/1670 train_time:55537ms step_avg:94.29ms +step:590/1670 train_time:55631ms step_avg:94.29ms +step:591/1670 train_time:55726ms step_avg:94.29ms +step:592/1670 train_time:55820ms step_avg:94.29ms +step:593/1670 train_time:55914ms step_avg:94.29ms +step:594/1670 train_time:56008ms step_avg:94.29ms +step:595/1670 train_time:56101ms step_avg:94.29ms +step:596/1670 train_time:56194ms step_avg:94.29ms +step:597/1670 train_time:56287ms step_avg:94.28ms +step:598/1670 train_time:56380ms step_avg:94.28ms +step:599/1670 train_time:56474ms step_avg:94.28ms +step:600/1670 train_time:56569ms step_avg:94.28ms +step:601/1670 train_time:56663ms step_avg:94.28ms +step:602/1670 train_time:56758ms step_avg:94.28ms +step:603/1670 train_time:56853ms step_avg:94.28ms +step:604/1670 train_time:56947ms step_avg:94.28ms +step:605/1670 train_time:57040ms step_avg:94.28ms +step:606/1670 train_time:57133ms step_avg:94.28ms +step:607/1670 train_time:57226ms step_avg:94.28ms +step:608/1670 train_time:57319ms step_avg:94.27ms +step:609/1670 train_time:57412ms step_avg:94.27ms +step:610/1670 train_time:57506ms step_avg:94.27ms +step:611/1670 train_time:57600ms step_avg:94.27ms +step:612/1670 train_time:57695ms step_avg:94.27ms +step:613/1670 train_time:57790ms step_avg:94.27ms +step:614/1670 train_time:57885ms step_avg:94.28ms +step:615/1670 train_time:57980ms step_avg:94.28ms +step:616/1670 train_time:58073ms step_avg:94.27ms +step:617/1670 train_time:58167ms step_avg:94.27ms +step:618/1670 train_time:58259ms step_avg:94.27ms +step:619/1670 train_time:58354ms step_avg:94.27ms +step:620/1670 train_time:58448ms step_avg:94.27ms +step:621/1670 train_time:58541ms step_avg:94.27ms +step:622/1670 train_time:58635ms step_avg:94.27ms +step:623/1670 train_time:58729ms step_avg:94.27ms +step:624/1670 train_time:58824ms step_avg:94.27ms +step:625/1670 train_time:58918ms step_avg:94.27ms +step:625/1670 val_loss:3.6129 train_time:59010ms step_avg:94.42ms +step:626/1670 train_time:59036ms step_avg:94.31ms +step:627/1670 train_time:59113ms step_avg:94.28ms +step:628/1670 train_time:59213ms step_avg:94.29ms +step:629/1670 train_time:59307ms step_avg:94.29ms +step:630/1670 train_time:59400ms step_avg:94.29ms +step:631/1670 train_time:59492ms 
step_avg:94.28ms +step:632/1670 train_time:59584ms step_avg:94.28ms +step:633/1670 train_time:59676ms step_avg:94.28ms +step:634/1670 train_time:59769ms step_avg:94.27ms +step:635/1670 train_time:59862ms step_avg:94.27ms +step:636/1670 train_time:59957ms step_avg:94.27ms +step:637/1670 train_time:60051ms step_avg:94.27ms +step:638/1670 train_time:60146ms step_avg:94.27ms +step:639/1670 train_time:60541ms step_avg:94.74ms +step:640/1670 train_time:60608ms step_avg:94.70ms +step:641/1670 train_time:60700ms step_avg:94.70ms +step:642/1670 train_time:60793ms step_avg:94.69ms +step:643/1670 train_time:60885ms step_avg:94.69ms +step:644/1670 train_time:60978ms step_avg:94.69ms +step:645/1670 train_time:61070ms step_avg:94.68ms +step:646/1670 train_time:61163ms step_avg:94.68ms +step:647/1670 train_time:61255ms step_avg:94.68ms +step:648/1670 train_time:61348ms step_avg:94.67ms +step:649/1670 train_time:61441ms step_avg:94.67ms +step:650/1670 train_time:61539ms step_avg:94.67ms +step:651/1670 train_time:61634ms step_avg:94.68ms +step:652/1670 train_time:61727ms step_avg:94.67ms +step:653/1670 train_time:61820ms step_avg:94.67ms +step:654/1670 train_time:61914ms step_avg:94.67ms +step:655/1670 train_time:62006ms step_avg:94.67ms +step:656/1670 train_time:62099ms step_avg:94.66ms +step:657/1670 train_time:62192ms step_avg:94.66ms +step:658/1670 train_time:62284ms step_avg:94.66ms +step:659/1670 train_time:62377ms step_avg:94.65ms +step:660/1670 train_time:62471ms step_avg:94.65ms +step:661/1670 train_time:62565ms step_avg:94.65ms +step:662/1670 train_time:62659ms step_avg:94.65ms +step:663/1670 train_time:62753ms step_avg:94.65ms +step:664/1670 train_time:62847ms step_avg:94.65ms +step:665/1670 train_time:62941ms step_avg:94.65ms +step:666/1670 train_time:63035ms step_avg:94.65ms +step:667/1670 train_time:63127ms step_avg:94.64ms +step:668/1670 train_time:63221ms step_avg:94.64ms +step:669/1670 train_time:63314ms step_avg:94.64ms +step:670/1670 train_time:63409ms step_avg:94.64ms +step:671/1670 train_time:63500ms step_avg:94.63ms +step:672/1670 train_time:63594ms step_avg:94.63ms +step:673/1670 train_time:63688ms step_avg:94.63ms +step:674/1670 train_time:63782ms step_avg:94.63ms +step:675/1670 train_time:63876ms step_avg:94.63ms +step:676/1670 train_time:63969ms step_avg:94.63ms +step:677/1670 train_time:64062ms step_avg:94.63ms +step:678/1670 train_time:64155ms step_avg:94.62ms +step:679/1670 train_time:64248ms step_avg:94.62ms +step:680/1670 train_time:64343ms step_avg:94.62ms +step:681/1670 train_time:64437ms step_avg:94.62ms +step:682/1670 train_time:64530ms step_avg:94.62ms +step:683/1670 train_time:64623ms step_avg:94.62ms +step:684/1670 train_time:64716ms step_avg:94.61ms +step:685/1670 train_time:64810ms step_avg:94.61ms +step:686/1670 train_time:64904ms step_avg:94.61ms +step:687/1670 train_time:64998ms step_avg:94.61ms +step:688/1670 train_time:65093ms step_avg:94.61ms +step:689/1670 train_time:65185ms step_avg:94.61ms +step:690/1670 train_time:65279ms step_avg:94.61ms +step:691/1670 train_time:65372ms step_avg:94.60ms +step:692/1670 train_time:65465ms step_avg:94.60ms +step:693/1670 train_time:65559ms step_avg:94.60ms +step:694/1670 train_time:65653ms step_avg:94.60ms +step:695/1670 train_time:65746ms step_avg:94.60ms +step:696/1670 train_time:65840ms step_avg:94.60ms +step:697/1670 train_time:65933ms step_avg:94.60ms +step:698/1670 train_time:66027ms step_avg:94.59ms +step:699/1670 train_time:66120ms step_avg:94.59ms +step:700/1670 train_time:66214ms step_avg:94.59ms +step:701/1670 
train_time:66307ms step_avg:94.59ms +step:702/1670 train_time:66401ms step_avg:94.59ms +step:703/1670 train_time:66494ms step_avg:94.59ms +step:704/1670 train_time:66587ms step_avg:94.58ms +step:705/1670 train_time:66681ms step_avg:94.58ms +step:706/1670 train_time:66775ms step_avg:94.58ms +step:707/1670 train_time:66868ms step_avg:94.58ms +step:708/1670 train_time:66961ms step_avg:94.58ms +step:709/1670 train_time:67055ms step_avg:94.58ms +step:710/1670 train_time:67148ms step_avg:94.57ms +step:711/1670 train_time:67243ms step_avg:94.58ms +step:712/1670 train_time:67337ms step_avg:94.57ms +step:713/1670 train_time:67430ms step_avg:94.57ms +step:714/1670 train_time:67523ms step_avg:94.57ms +step:715/1670 train_time:67616ms step_avg:94.57ms +step:716/1670 train_time:67709ms step_avg:94.57ms +step:717/1670 train_time:67803ms step_avg:94.56ms +step:718/1670 train_time:67896ms step_avg:94.56ms +step:719/1670 train_time:67990ms step_avg:94.56ms +step:720/1670 train_time:68084ms step_avg:94.56ms +step:721/1670 train_time:68177ms step_avg:94.56ms +step:722/1670 train_time:68271ms step_avg:94.56ms +step:723/1670 train_time:68365ms step_avg:94.56ms +step:724/1670 train_time:68459ms step_avg:94.56ms +step:725/1670 train_time:68553ms step_avg:94.56ms +step:726/1670 train_time:68646ms step_avg:94.55ms +step:727/1670 train_time:68739ms step_avg:94.55ms +step:728/1670 train_time:68832ms step_avg:94.55ms +step:729/1670 train_time:68926ms step_avg:94.55ms +step:730/1670 train_time:69019ms step_avg:94.55ms +step:731/1670 train_time:69113ms step_avg:94.55ms +step:732/1670 train_time:69206ms step_avg:94.54ms +step:733/1670 train_time:69300ms step_avg:94.54ms +step:734/1670 train_time:69394ms step_avg:94.54ms +step:735/1670 train_time:69487ms step_avg:94.54ms +step:736/1670 train_time:69581ms step_avg:94.54ms +step:737/1670 train_time:69674ms step_avg:94.54ms +step:738/1670 train_time:69768ms step_avg:94.54ms +step:739/1670 train_time:69862ms step_avg:94.54ms +step:740/1670 train_time:69955ms step_avg:94.53ms +step:741/1670 train_time:70049ms step_avg:94.53ms +step:742/1670 train_time:70142ms step_avg:94.53ms +step:743/1670 train_time:70237ms step_avg:94.53ms +step:744/1670 train_time:70330ms step_avg:94.53ms +step:745/1670 train_time:70424ms step_avg:94.53ms +step:746/1670 train_time:70518ms step_avg:94.53ms +step:747/1670 train_time:70611ms step_avg:94.53ms +step:748/1670 train_time:70704ms step_avg:94.52ms +step:749/1670 train_time:70798ms step_avg:94.52ms +step:750/1670 train_time:70892ms step_avg:94.52ms +step:750/1670 val_loss:3.5617 train_time:70983ms step_avg:94.64ms +step:751/1670 train_time:71009ms step_avg:94.55ms +step:752/1670 train_time:71088ms step_avg:94.53ms +step:753/1670 train_time:71187ms step_avg:94.54ms +step:754/1670 train_time:71280ms step_avg:94.54ms +step:755/1670 train_time:71373ms step_avg:94.53ms +step:756/1670 train_time:71466ms step_avg:94.53ms +step:757/1670 train_time:71558ms step_avg:94.53ms +step:758/1670 train_time:71651ms step_avg:94.53ms +step:759/1670 train_time:71743ms step_avg:94.52ms +step:760/1670 train_time:71836ms step_avg:94.52ms +step:761/1670 train_time:71929ms step_avg:94.52ms +step:762/1670 train_time:72024ms step_avg:94.52ms +step:763/1670 train_time:72119ms step_avg:94.52ms +step:764/1670 train_time:72215ms step_avg:94.52ms +step:765/1670 train_time:72308ms step_avg:94.52ms +step:766/1670 train_time:72402ms step_avg:94.52ms +step:767/1670 train_time:72495ms step_avg:94.52ms +step:768/1670 train_time:72589ms step_avg:94.52ms +step:769/1670 train_time:72681ms 
step_avg:94.51ms +step:770/1670 train_time:72774ms step_avg:94.51ms +step:771/1670 train_time:72867ms step_avg:94.51ms +step:772/1670 train_time:72960ms step_avg:94.51ms +step:773/1670 train_time:73055ms step_avg:94.51ms +step:774/1670 train_time:73149ms step_avg:94.51ms +step:775/1670 train_time:73244ms step_avg:94.51ms +step:776/1670 train_time:73338ms step_avg:94.51ms +step:777/1670 train_time:73431ms step_avg:94.51ms +step:778/1670 train_time:73524ms step_avg:94.50ms +step:779/1670 train_time:73617ms step_avg:94.50ms +step:780/1670 train_time:73710ms step_avg:94.50ms +step:781/1670 train_time:73804ms step_avg:94.50ms +step:782/1670 train_time:73897ms step_avg:94.50ms +step:783/1670 train_time:73991ms step_avg:94.50ms +step:784/1670 train_time:74085ms step_avg:94.50ms +step:785/1670 train_time:74179ms step_avg:94.50ms +step:786/1670 train_time:74274ms step_avg:94.50ms +step:787/1670 train_time:74368ms step_avg:94.50ms +step:788/1670 train_time:74462ms step_avg:94.49ms +step:789/1670 train_time:74555ms step_avg:94.49ms +step:790/1670 train_time:74648ms step_avg:94.49ms +step:791/1670 train_time:74740ms step_avg:94.49ms +step:792/1670 train_time:74834ms step_avg:94.49ms +step:793/1670 train_time:74927ms step_avg:94.49ms +step:794/1670 train_time:75022ms step_avg:94.49ms +step:795/1670 train_time:75116ms step_avg:94.49ms +step:796/1670 train_time:75210ms step_avg:94.49ms +step:797/1670 train_time:75304ms step_avg:94.48ms +step:798/1670 train_time:75398ms step_avg:94.48ms +step:799/1670 train_time:75491ms step_avg:94.48ms +step:800/1670 train_time:75585ms step_avg:94.48ms +step:801/1670 train_time:75679ms step_avg:94.48ms +step:802/1670 train_time:75772ms step_avg:94.48ms +step:803/1670 train_time:75865ms step_avg:94.48ms +step:804/1670 train_time:75958ms step_avg:94.47ms +step:805/1670 train_time:76051ms step_avg:94.47ms +step:806/1670 train_time:76144ms step_avg:94.47ms +step:807/1670 train_time:76239ms step_avg:94.47ms +step:808/1670 train_time:76333ms step_avg:94.47ms +step:809/1670 train_time:76427ms step_avg:94.47ms +step:810/1670 train_time:76521ms step_avg:94.47ms +step:811/1670 train_time:76614ms step_avg:94.47ms +step:812/1670 train_time:76707ms step_avg:94.47ms +step:813/1670 train_time:76801ms step_avg:94.47ms +step:814/1670 train_time:76894ms step_avg:94.46ms +step:815/1670 train_time:76987ms step_avg:94.46ms +step:816/1670 train_time:77081ms step_avg:94.46ms +step:817/1670 train_time:77176ms step_avg:94.46ms +step:818/1670 train_time:77270ms step_avg:94.46ms +step:819/1670 train_time:77364ms step_avg:94.46ms +step:820/1670 train_time:77457ms step_avg:94.46ms +step:821/1670 train_time:77550ms step_avg:94.46ms +step:822/1670 train_time:77645ms step_avg:94.46ms +step:823/1670 train_time:77739ms step_avg:94.46ms +step:824/1670 train_time:77833ms step_avg:94.46ms +step:825/1670 train_time:77926ms step_avg:94.46ms +step:826/1670 train_time:78020ms step_avg:94.45ms +step:827/1670 train_time:78113ms step_avg:94.45ms +step:828/1670 train_time:78207ms step_avg:94.45ms +step:829/1670 train_time:78302ms step_avg:94.45ms +step:830/1670 train_time:78396ms step_avg:94.45ms +step:831/1670 train_time:78489ms step_avg:94.45ms +step:832/1670 train_time:78583ms step_avg:94.45ms +step:833/1670 train_time:78677ms step_avg:94.45ms +step:834/1670 train_time:78771ms step_avg:94.45ms +step:835/1670 train_time:78864ms step_avg:94.45ms +step:836/1670 train_time:78958ms step_avg:94.45ms +step:837/1670 train_time:79051ms step_avg:94.45ms +step:838/1670 train_time:79145ms step_avg:94.44ms +step:839/1670 
train_time:79239ms step_avg:94.45ms +step:840/1670 train_time:79334ms step_avg:94.44ms +step:841/1670 train_time:79427ms step_avg:94.44ms +step:842/1670 train_time:79521ms step_avg:94.44ms +step:843/1670 train_time:79615ms step_avg:94.44ms +step:844/1670 train_time:79709ms step_avg:94.44ms +step:845/1670 train_time:79803ms step_avg:94.44ms +step:846/1670 train_time:79896ms step_avg:94.44ms +step:847/1670 train_time:79990ms step_avg:94.44ms +step:848/1670 train_time:80083ms step_avg:94.44ms +step:849/1670 train_time:80176ms step_avg:94.44ms +step:850/1670 train_time:80270ms step_avg:94.44ms +step:851/1670 train_time:80629ms step_avg:94.75ms +step:852/1670 train_time:80744ms step_avg:94.77ms +step:853/1670 train_time:80836ms step_avg:94.77ms +step:854/1670 train_time:80928ms step_avg:94.76ms +step:855/1670 train_time:81021ms step_avg:94.76ms +step:856/1670 train_time:81114ms step_avg:94.76ms +step:857/1670 train_time:81207ms step_avg:94.76ms +step:858/1670 train_time:81299ms step_avg:94.75ms +step:859/1670 train_time:81391ms step_avg:94.75ms +step:860/1670 train_time:81484ms step_avg:94.75ms +step:861/1670 train_time:81579ms step_avg:94.75ms +step:862/1670 train_time:81678ms step_avg:94.75ms +step:863/1670 train_time:81775ms step_avg:94.76ms +step:864/1670 train_time:81869ms step_avg:94.76ms +step:865/1670 train_time:81962ms step_avg:94.75ms +step:866/1670 train_time:82055ms step_avg:94.75ms +step:867/1670 train_time:82147ms step_avg:94.75ms +step:868/1670 train_time:82240ms step_avg:94.75ms +step:869/1670 train_time:82332ms step_avg:94.74ms +step:870/1670 train_time:82425ms step_avg:94.74ms +step:871/1670 train_time:82518ms step_avg:94.74ms +step:872/1670 train_time:82613ms step_avg:94.74ms +step:873/1670 train_time:82710ms step_avg:94.74ms +step:874/1670 train_time:82804ms step_avg:94.74ms +step:875/1670 train_time:82898ms step_avg:94.74ms +step:875/1670 val_loss:3.5164 train_time:82990ms step_avg:94.85ms +step:876/1670 train_time:83015ms step_avg:94.77ms +step:877/1670 train_time:83092ms step_avg:94.75ms +step:878/1670 train_time:83192ms step_avg:94.75ms +step:879/1670 train_time:83285ms step_avg:94.75ms +step:880/1670 train_time:83379ms step_avg:94.75ms +step:881/1670 train_time:83472ms step_avg:94.75ms +step:882/1670 train_time:83564ms step_avg:94.74ms +step:883/1670 train_time:83657ms step_avg:94.74ms +step:884/1670 train_time:83750ms step_avg:94.74ms +step:885/1670 train_time:83842ms step_avg:94.74ms +step:886/1670 train_time:83936ms step_avg:94.74ms +step:887/1670 train_time:84031ms step_avg:94.74ms +step:888/1670 train_time:84127ms step_avg:94.74ms +step:889/1670 train_time:84222ms step_avg:94.74ms +step:890/1670 train_time:84316ms step_avg:94.74ms +step:891/1670 train_time:84409ms step_avg:94.74ms +step:892/1670 train_time:84502ms step_avg:94.73ms +step:893/1670 train_time:84596ms step_avg:94.73ms +step:894/1670 train_time:84688ms step_avg:94.73ms +step:895/1670 train_time:84781ms step_avg:94.73ms +step:896/1670 train_time:84875ms step_avg:94.73ms +step:897/1670 train_time:84969ms step_avg:94.73ms +step:898/1670 train_time:85063ms step_avg:94.72ms +step:899/1670 train_time:85158ms step_avg:94.72ms +step:900/1670 train_time:85252ms step_avg:94.72ms +step:901/1670 train_time:85346ms step_avg:94.72ms +step:902/1670 train_time:85440ms step_avg:94.72ms +step:903/1670 train_time:85533ms step_avg:94.72ms +step:904/1670 train_time:85626ms step_avg:94.72ms +step:905/1670 train_time:85720ms step_avg:94.72ms +step:906/1670 train_time:85813ms step_avg:94.72ms +step:907/1670 train_time:85906ms 
step_avg:94.71ms +step:908/1670 train_time:86000ms step_avg:94.71ms +step:909/1670 train_time:86094ms step_avg:94.71ms +step:910/1670 train_time:86189ms step_avg:94.71ms +step:911/1670 train_time:86282ms step_avg:94.71ms +step:912/1670 train_time:86377ms step_avg:94.71ms +step:913/1670 train_time:86471ms step_avg:94.71ms +step:914/1670 train_time:86563ms step_avg:94.71ms +step:915/1670 train_time:86656ms step_avg:94.71ms +step:916/1670 train_time:86750ms step_avg:94.71ms +step:917/1670 train_time:86843ms step_avg:94.70ms +step:918/1670 train_time:86937ms step_avg:94.70ms +step:919/1670 train_time:87031ms step_avg:94.70ms +step:920/1670 train_time:87125ms step_avg:94.70ms +step:921/1670 train_time:87219ms step_avg:94.70ms +step:922/1670 train_time:87313ms step_avg:94.70ms +step:923/1670 train_time:87407ms step_avg:94.70ms +step:924/1670 train_time:87500ms step_avg:94.70ms +step:925/1670 train_time:87594ms step_avg:94.70ms +step:926/1670 train_time:87687ms step_avg:94.69ms +step:927/1670 train_time:87780ms step_avg:94.69ms +step:928/1670 train_time:87874ms step_avg:94.69ms +step:929/1670 train_time:87967ms step_avg:94.69ms +step:930/1670 train_time:88060ms step_avg:94.69ms +step:931/1670 train_time:88155ms step_avg:94.69ms +step:932/1670 train_time:88248ms step_avg:94.69ms +step:933/1670 train_time:88343ms step_avg:94.69ms +step:934/1670 train_time:88437ms step_avg:94.69ms +step:935/1670 train_time:88530ms step_avg:94.68ms +step:936/1670 train_time:88623ms step_avg:94.68ms +step:937/1670 train_time:88716ms step_avg:94.68ms +step:938/1670 train_time:88810ms step_avg:94.68ms +step:939/1670 train_time:88904ms step_avg:94.68ms +step:940/1670 train_time:88998ms step_avg:94.68ms +step:941/1670 train_time:89091ms step_avg:94.68ms +step:942/1670 train_time:89185ms step_avg:94.68ms +step:943/1670 train_time:89278ms step_avg:94.67ms +step:944/1670 train_time:89373ms step_avg:94.67ms +step:945/1670 train_time:89467ms step_avg:94.67ms +step:946/1670 train_time:89561ms step_avg:94.67ms +step:947/1670 train_time:89654ms step_avg:94.67ms +step:948/1670 train_time:89748ms step_avg:94.67ms +step:949/1670 train_time:89841ms step_avg:94.67ms +step:950/1670 train_time:89935ms step_avg:94.67ms +step:951/1670 train_time:90028ms step_avg:94.67ms +step:952/1670 train_time:90121ms step_avg:94.67ms +step:953/1670 train_time:90215ms step_avg:94.66ms +step:954/1670 train_time:90309ms step_avg:94.66ms +step:955/1670 train_time:90403ms step_avg:94.66ms +step:956/1670 train_time:90497ms step_avg:94.66ms +step:957/1670 train_time:90591ms step_avg:94.66ms +step:958/1670 train_time:90684ms step_avg:94.66ms +step:959/1670 train_time:90777ms step_avg:94.66ms +step:960/1670 train_time:90871ms step_avg:94.66ms +step:961/1670 train_time:90964ms step_avg:94.66ms +step:962/1670 train_time:91057ms step_avg:94.65ms +step:963/1670 train_time:91152ms step_avg:94.65ms +step:964/1670 train_time:91246ms step_avg:94.65ms +step:965/1670 train_time:91340ms step_avg:94.65ms +step:966/1670 train_time:91435ms step_avg:94.65ms +step:967/1670 train_time:91528ms step_avg:94.65ms +step:968/1670 train_time:91621ms step_avg:94.65ms +step:969/1670 train_time:91714ms step_avg:94.65ms +step:970/1670 train_time:91809ms step_avg:94.65ms +step:971/1670 train_time:91902ms step_avg:94.65ms +step:972/1670 train_time:91996ms step_avg:94.65ms +step:973/1670 train_time:92089ms step_avg:94.64ms +step:974/1670 train_time:92182ms step_avg:94.64ms +step:975/1670 train_time:92278ms step_avg:94.64ms +step:976/1670 train_time:92373ms step_avg:94.64ms +step:977/1670 
train_time:92467ms step_avg:94.64ms +step:978/1670 train_time:92560ms step_avg:94.64ms +step:979/1670 train_time:92654ms step_avg:94.64ms +step:980/1670 train_time:92748ms step_avg:94.64ms +step:981/1670 train_time:92841ms step_avg:94.64ms +step:982/1670 train_time:92935ms step_avg:94.64ms +step:983/1670 train_time:93029ms step_avg:94.64ms +step:984/1670 train_time:93122ms step_avg:94.64ms +step:985/1670 train_time:93215ms step_avg:94.63ms +step:986/1670 train_time:93310ms step_avg:94.63ms +step:987/1670 train_time:93403ms step_avg:94.63ms +step:988/1670 train_time:93497ms step_avg:94.63ms +step:989/1670 train_time:93591ms step_avg:94.63ms +step:990/1670 train_time:93684ms step_avg:94.63ms +step:991/1670 train_time:93778ms step_avg:94.63ms +step:992/1670 train_time:93871ms step_avg:94.63ms +step:993/1670 train_time:93966ms step_avg:94.63ms +step:994/1670 train_time:94058ms step_avg:94.63ms +step:995/1670 train_time:94151ms step_avg:94.62ms +step:996/1670 train_time:94245ms step_avg:94.62ms +step:997/1670 train_time:94339ms step_avg:94.62ms +step:998/1670 train_time:94432ms step_avg:94.62ms +step:999/1670 train_time:94527ms step_avg:94.62ms +step:1000/1670 train_time:94620ms step_avg:94.62ms +step:1000/1670 val_loss:3.4679 train_time:94712ms step_avg:94.71ms +step:1001/1670 train_time:94738ms step_avg:94.64ms +step:1002/1670 train_time:94812ms step_avg:94.62ms +step:1003/1670 train_time:94914ms step_avg:94.63ms +step:1004/1670 train_time:95010ms step_avg:94.63ms +step:1005/1670 train_time:95103ms step_avg:94.63ms +step:1006/1670 train_time:95196ms step_avg:94.63ms +step:1007/1670 train_time:95288ms step_avg:94.63ms +step:1008/1670 train_time:95381ms step_avg:94.62ms +step:1009/1670 train_time:95474ms step_avg:94.62ms +step:1010/1670 train_time:95566ms step_avg:94.62ms +step:1011/1670 train_time:95659ms step_avg:94.62ms +step:1012/1670 train_time:95754ms step_avg:94.62ms +step:1013/1670 train_time:95850ms step_avg:94.62ms +step:1014/1670 train_time:95946ms step_avg:94.62ms +step:1015/1670 train_time:96041ms step_avg:94.62ms +step:1016/1670 train_time:96134ms step_avg:94.62ms +step:1017/1670 train_time:96227ms step_avg:94.62ms +step:1018/1670 train_time:96320ms step_avg:94.62ms +step:1019/1670 train_time:96413ms step_avg:94.62ms +step:1020/1670 train_time:96507ms step_avg:94.61ms +step:1021/1670 train_time:96599ms step_avg:94.61ms +step:1022/1670 train_time:96692ms step_avg:94.61ms +step:1023/1670 train_time:96786ms step_avg:94.61ms +step:1024/1670 train_time:96882ms step_avg:94.61ms +step:1025/1670 train_time:96976ms step_avg:94.61ms +step:1026/1670 train_time:97071ms step_avg:94.61ms +step:1027/1670 train_time:97165ms step_avg:94.61ms +step:1028/1670 train_time:97258ms step_avg:94.61ms +step:1029/1670 train_time:97351ms step_avg:94.61ms +step:1030/1670 train_time:97445ms step_avg:94.61ms +step:1031/1670 train_time:97539ms step_avg:94.61ms +step:1032/1670 train_time:97632ms step_avg:94.60ms +step:1033/1670 train_time:97726ms step_avg:94.60ms +step:1034/1670 train_time:97819ms step_avg:94.60ms +step:1035/1670 train_time:97915ms step_avg:94.60ms +step:1036/1670 train_time:98010ms step_avg:94.60ms +step:1037/1670 train_time:98103ms step_avg:94.60ms +step:1038/1670 train_time:98196ms step_avg:94.60ms +step:1039/1670 train_time:98290ms step_avg:94.60ms +step:1040/1670 train_time:98384ms step_avg:94.60ms +step:1041/1670 train_time:98477ms step_avg:94.60ms +step:1042/1670 train_time:98570ms step_avg:94.60ms +step:1043/1670 train_time:98663ms step_avg:94.60ms +step:1044/1670 train_time:98757ms 
step_avg:94.59ms +step:1045/1670 train_time:98851ms step_avg:94.59ms +step:1046/1670 train_time:98947ms step_avg:94.60ms +step:1047/1670 train_time:99040ms step_avg:94.59ms +step:1048/1670 train_time:99135ms step_avg:94.59ms +step:1049/1670 train_time:99229ms step_avg:94.59ms +step:1050/1670 train_time:99322ms step_avg:94.59ms +step:1051/1670 train_time:99416ms step_avg:94.59ms +step:1052/1670 train_time:99509ms step_avg:94.59ms +step:1053/1670 train_time:99604ms step_avg:94.59ms +step:1054/1670 train_time:99696ms step_avg:94.59ms +step:1055/1670 train_time:99790ms step_avg:94.59ms +step:1056/1670 train_time:99885ms step_avg:94.59ms +step:1057/1670 train_time:99980ms step_avg:94.59ms +step:1058/1670 train_time:100074ms step_avg:94.59ms +step:1059/1670 train_time:100167ms step_avg:94.59ms +step:1060/1670 train_time:100261ms step_avg:94.59ms +step:1061/1670 train_time:100355ms step_avg:94.59ms +step:1062/1670 train_time:100784ms step_avg:94.90ms +step:1063/1670 train_time:100852ms step_avg:94.88ms +step:1064/1670 train_time:100944ms step_avg:94.87ms +step:1065/1670 train_time:101036ms step_avg:94.87ms +step:1066/1670 train_time:101129ms step_avg:94.87ms +step:1067/1670 train_time:101221ms step_avg:94.87ms +step:1068/1670 train_time:101314ms step_avg:94.86ms +step:1069/1670 train_time:101407ms step_avg:94.86ms +step:1070/1670 train_time:101499ms step_avg:94.86ms +step:1071/1670 train_time:101592ms step_avg:94.86ms +step:1072/1670 train_time:101687ms step_avg:94.86ms +step:1073/1670 train_time:101782ms step_avg:94.86ms +step:1074/1670 train_time:101880ms step_avg:94.86ms +step:1075/1670 train_time:101974ms step_avg:94.86ms +step:1076/1670 train_time:102067ms step_avg:94.86ms +step:1077/1670 train_time:102160ms step_avg:94.86ms +step:1078/1670 train_time:102252ms step_avg:94.85ms +step:1079/1670 train_time:102346ms step_avg:94.85ms +step:1080/1670 train_time:102439ms step_avg:94.85ms +step:1081/1670 train_time:102532ms step_avg:94.85ms +step:1082/1670 train_time:102625ms step_avg:94.85ms +step:1083/1670 train_time:102720ms step_avg:94.85ms +step:1084/1670 train_time:102816ms step_avg:94.85ms +step:1085/1670 train_time:102912ms step_avg:94.85ms +step:1086/1670 train_time:103007ms step_avg:94.85ms +step:1087/1670 train_time:103099ms step_avg:94.85ms +step:1088/1670 train_time:103193ms step_avg:94.85ms +step:1089/1670 train_time:103286ms step_avg:94.84ms +step:1090/1670 train_time:103379ms step_avg:94.84ms +step:1091/1670 train_time:103472ms step_avg:94.84ms +step:1092/1670 train_time:103566ms step_avg:94.84ms +step:1093/1670 train_time:103659ms step_avg:94.84ms +step:1094/1670 train_time:103753ms step_avg:94.84ms +step:1095/1670 train_time:103848ms step_avg:94.84ms +step:1096/1670 train_time:103942ms step_avg:94.84ms +step:1097/1670 train_time:104036ms step_avg:94.84ms +step:1098/1670 train_time:104130ms step_avg:94.84ms +step:1099/1670 train_time:104224ms step_avg:94.84ms +step:1100/1670 train_time:104317ms step_avg:94.83ms +step:1101/1670 train_time:104411ms step_avg:94.83ms +step:1102/1670 train_time:104504ms step_avg:94.83ms +step:1103/1670 train_time:104597ms step_avg:94.83ms +step:1104/1670 train_time:104690ms step_avg:94.83ms +step:1105/1670 train_time:104784ms step_avg:94.83ms +step:1106/1670 train_time:104878ms step_avg:94.83ms +step:1107/1670 train_time:104973ms step_avg:94.83ms +step:1108/1670 train_time:105066ms step_avg:94.83ms +step:1109/1670 train_time:105160ms step_avg:94.82ms +step:1110/1670 train_time:105253ms step_avg:94.82ms +step:1111/1670 train_time:105347ms step_avg:94.82ms 
+step:1112/1670 train_time:105440ms step_avg:94.82ms +step:1113/1670 train_time:105534ms step_avg:94.82ms +step:1114/1670 train_time:105628ms step_avg:94.82ms +step:1115/1670 train_time:105830ms step_avg:94.91ms +step:1116/1670 train_time:105899ms step_avg:94.89ms +step:1117/1670 train_time:105992ms step_avg:94.89ms +step:1118/1670 train_time:106085ms step_avg:94.89ms +step:1119/1670 train_time:106178ms step_avg:94.89ms +step:1120/1670 train_time:106272ms step_avg:94.89ms +step:1121/1670 train_time:106366ms step_avg:94.88ms +step:1122/1670 train_time:106459ms step_avg:94.88ms +step:1123/1670 train_time:106553ms step_avg:94.88ms +step:1124/1670 train_time:106646ms step_avg:94.88ms +step:1125/1670 train_time:106745ms step_avg:94.88ms +step:1125/1670 val_loss:3.4148 train_time:106842ms step_avg:94.97ms +step:1126/1670 train_time:106868ms step_avg:94.91ms +step:1127/1670 train_time:106942ms step_avg:94.89ms +step:1128/1670 train_time:107042ms step_avg:94.90ms +step:1129/1670 train_time:107137ms step_avg:94.90ms +step:1130/1670 train_time:107230ms step_avg:94.89ms +step:1131/1670 train_time:107323ms step_avg:94.89ms +step:1132/1670 train_time:107417ms step_avg:94.89ms +step:1133/1670 train_time:107510ms step_avg:94.89ms +step:1134/1670 train_time:107604ms step_avg:94.89ms +step:1135/1670 train_time:107697ms step_avg:94.89ms +step:1136/1670 train_time:107792ms step_avg:94.89ms +step:1137/1670 train_time:107888ms step_avg:94.89ms +step:1138/1670 train_time:107983ms step_avg:94.89ms +step:1139/1670 train_time:108077ms step_avg:94.89ms +step:1140/1670 train_time:108171ms step_avg:94.89ms +step:1141/1670 train_time:108265ms step_avg:94.89ms +step:1142/1670 train_time:108359ms step_avg:94.89ms +step:1143/1670 train_time:108453ms step_avg:94.88ms +step:1144/1670 train_time:108547ms step_avg:94.88ms +step:1145/1670 train_time:108640ms step_avg:94.88ms +step:1146/1670 train_time:108734ms step_avg:94.88ms +step:1147/1670 train_time:108829ms step_avg:94.88ms +step:1148/1670 train_time:108924ms step_avg:94.88ms +step:1149/1670 train_time:109019ms step_avg:94.88ms +step:1150/1670 train_time:109113ms step_avg:94.88ms +step:1151/1670 train_time:109208ms step_avg:94.88ms +step:1152/1670 train_time:109301ms step_avg:94.88ms +step:1153/1670 train_time:109395ms step_avg:94.88ms +step:1154/1670 train_time:109489ms step_avg:94.88ms +step:1155/1670 train_time:109583ms step_avg:94.88ms +step:1156/1670 train_time:109677ms step_avg:94.88ms +step:1157/1670 train_time:109771ms step_avg:94.88ms +step:1158/1670 train_time:109865ms step_avg:94.88ms +step:1159/1670 train_time:109960ms step_avg:94.87ms +step:1160/1670 train_time:110054ms step_avg:94.87ms +step:1161/1670 train_time:110148ms step_avg:94.87ms +step:1162/1670 train_time:110242ms step_avg:94.87ms +step:1163/1670 train_time:110336ms step_avg:94.87ms +step:1164/1670 train_time:110430ms step_avg:94.87ms +step:1165/1670 train_time:110524ms step_avg:94.87ms +step:1166/1670 train_time:110617ms step_avg:94.87ms +step:1167/1670 train_time:110712ms step_avg:94.87ms +step:1168/1670 train_time:110806ms step_avg:94.87ms +step:1169/1670 train_time:110900ms step_avg:94.87ms +step:1170/1670 train_time:110994ms step_avg:94.87ms +step:1171/1670 train_time:111088ms step_avg:94.87ms +step:1172/1670 train_time:111184ms step_avg:94.87ms +step:1173/1670 train_time:111277ms step_avg:94.87ms +step:1174/1670 train_time:111371ms step_avg:94.86ms +step:1175/1670 train_time:111466ms step_avg:94.86ms +step:1176/1670 train_time:111561ms step_avg:94.86ms +step:1177/1670 train_time:111654ms 
step_avg:94.86ms +step:1178/1670 train_time:111747ms step_avg:94.86ms +step:1179/1670 train_time:111842ms step_avg:94.86ms +step:1180/1670 train_time:111936ms step_avg:94.86ms +step:1181/1670 train_time:112030ms step_avg:94.86ms +step:1182/1670 train_time:112124ms step_avg:94.86ms +step:1183/1670 train_time:112218ms step_avg:94.86ms +step:1184/1670 train_time:112312ms step_avg:94.86ms +step:1185/1670 train_time:112406ms step_avg:94.86ms +step:1186/1670 train_time:112500ms step_avg:94.86ms +step:1187/1670 train_time:112594ms step_avg:94.86ms +step:1188/1670 train_time:112688ms step_avg:94.86ms +step:1189/1670 train_time:112783ms step_avg:94.86ms +step:1190/1670 train_time:112876ms step_avg:94.85ms +step:1191/1670 train_time:112970ms step_avg:94.85ms +step:1192/1670 train_time:113065ms step_avg:94.85ms +step:1193/1670 train_time:113159ms step_avg:94.85ms +step:1194/1670 train_time:113254ms step_avg:94.85ms +step:1195/1670 train_time:113349ms step_avg:94.85ms +step:1196/1670 train_time:113443ms step_avg:94.85ms +step:1197/1670 train_time:113538ms step_avg:94.85ms +step:1198/1670 train_time:113633ms step_avg:94.85ms +step:1199/1670 train_time:113727ms step_avg:94.85ms +step:1200/1670 train_time:113821ms step_avg:94.85ms +step:1201/1670 train_time:113915ms step_avg:94.85ms +step:1202/1670 train_time:114009ms step_avg:94.85ms +step:1203/1670 train_time:114103ms step_avg:94.85ms +step:1204/1670 train_time:114197ms step_avg:94.85ms +step:1205/1670 train_time:114292ms step_avg:94.85ms +step:1206/1670 train_time:114387ms step_avg:94.85ms +step:1207/1670 train_time:114481ms step_avg:94.85ms +step:1208/1670 train_time:114575ms step_avg:94.85ms +step:1209/1670 train_time:114670ms step_avg:94.85ms +step:1210/1670 train_time:114765ms step_avg:94.85ms +step:1211/1670 train_time:114859ms step_avg:94.85ms +step:1212/1670 train_time:114954ms step_avg:94.85ms +step:1213/1670 train_time:115048ms step_avg:94.85ms +step:1214/1670 train_time:115143ms step_avg:94.85ms +step:1215/1670 train_time:115237ms step_avg:94.85ms +step:1216/1670 train_time:115332ms step_avg:94.85ms +step:1217/1670 train_time:115427ms step_avg:94.85ms +step:1218/1670 train_time:115521ms step_avg:94.85ms +step:1219/1670 train_time:115615ms step_avg:94.84ms +step:1220/1670 train_time:115711ms step_avg:94.85ms +step:1221/1670 train_time:115806ms step_avg:94.85ms +step:1222/1670 train_time:115900ms step_avg:94.84ms +step:1223/1670 train_time:115994ms step_avg:94.84ms +step:1224/1670 train_time:116089ms step_avg:94.84ms +step:1225/1670 train_time:116182ms step_avg:94.84ms +step:1226/1670 train_time:116277ms step_avg:94.84ms +step:1227/1670 train_time:116373ms step_avg:94.84ms +step:1228/1670 train_time:116467ms step_avg:94.84ms +step:1229/1670 train_time:116561ms step_avg:94.84ms +step:1230/1670 train_time:116656ms step_avg:94.84ms +step:1231/1670 train_time:116751ms step_avg:94.84ms +step:1232/1670 train_time:116847ms step_avg:94.84ms +step:1233/1670 train_time:116941ms step_avg:94.84ms +step:1234/1670 train_time:117035ms step_avg:94.84ms +step:1235/1670 train_time:117129ms step_avg:94.84ms +step:1236/1670 train_time:117223ms step_avg:94.84ms +step:1237/1670 train_time:117316ms step_avg:94.84ms +step:1238/1670 train_time:117411ms step_avg:94.84ms +step:1239/1670 train_time:117505ms step_avg:94.84ms +step:1240/1670 train_time:117601ms step_avg:94.84ms +step:1241/1670 train_time:117696ms step_avg:94.84ms +step:1242/1670 train_time:117790ms step_avg:94.84ms +step:1243/1670 train_time:117885ms step_avg:94.84ms +step:1244/1670 train_time:117979ms 
step_avg:94.84ms +step:1245/1670 train_time:118073ms step_avg:94.84ms +step:1246/1670 train_time:118169ms step_avg:94.84ms +step:1247/1670 train_time:118262ms step_avg:94.84ms +step:1248/1670 train_time:118357ms step_avg:94.84ms +step:1249/1670 train_time:118451ms step_avg:94.84ms +step:1250/1670 train_time:118546ms step_avg:94.84ms +step:1250/1670 val_loss:3.3758 train_time:118638ms step_avg:94.91ms +step:1251/1670 train_time:118665ms step_avg:94.86ms +step:1252/1670 train_time:118745ms step_avg:94.84ms +step:1253/1670 train_time:118845ms step_avg:94.85ms +step:1254/1670 train_time:118941ms step_avg:94.85ms +step:1255/1670 train_time:119035ms step_avg:94.85ms +step:1256/1670 train_time:119128ms step_avg:94.85ms +step:1257/1670 train_time:119221ms step_avg:94.85ms +step:1258/1670 train_time:119314ms step_avg:94.84ms +step:1259/1670 train_time:119407ms step_avg:94.84ms +step:1260/1670 train_time:119501ms step_avg:94.84ms +step:1261/1670 train_time:119594ms step_avg:94.84ms +step:1262/1670 train_time:119691ms step_avg:94.84ms +step:1263/1670 train_time:119788ms step_avg:94.84ms +step:1264/1670 train_time:119885ms step_avg:94.85ms +step:1265/1670 train_time:119980ms step_avg:94.85ms +step:1266/1670 train_time:120074ms step_avg:94.85ms +step:1267/1670 train_time:120169ms step_avg:94.85ms +step:1268/1670 train_time:120263ms step_avg:94.84ms +step:1269/1670 train_time:120356ms step_avg:94.84ms +step:1270/1670 train_time:120450ms step_avg:94.84ms +step:1271/1670 train_time:120544ms step_avg:94.84ms +step:1272/1670 train_time:120639ms step_avg:94.84ms +step:1273/1670 train_time:120733ms step_avg:94.84ms +step:1274/1670 train_time:121100ms step_avg:95.06ms +step:1275/1670 train_time:121188ms step_avg:95.05ms +step:1276/1670 train_time:121281ms step_avg:95.05ms +step:1277/1670 train_time:121374ms step_avg:95.05ms +step:1278/1670 train_time:121467ms step_avg:95.04ms +step:1279/1670 train_time:121560ms step_avg:95.04ms +step:1280/1670 train_time:121653ms step_avg:95.04ms +step:1281/1670 train_time:121746ms step_avg:95.04ms +step:1282/1670 train_time:121840ms step_avg:95.04ms +step:1283/1670 train_time:121932ms step_avg:95.04ms +step:1284/1670 train_time:122027ms step_avg:95.04ms +step:1285/1670 train_time:122126ms step_avg:95.04ms +step:1286/1670 train_time:122223ms step_avg:95.04ms +step:1287/1670 train_time:122317ms step_avg:95.04ms +step:1288/1670 train_time:122411ms step_avg:95.04ms +step:1289/1670 train_time:122505ms step_avg:95.04ms +step:1290/1670 train_time:122598ms step_avg:95.04ms +step:1291/1670 train_time:122692ms step_avg:95.04ms +step:1292/1670 train_time:122786ms step_avg:95.04ms +step:1293/1670 train_time:122880ms step_avg:95.03ms +step:1294/1670 train_time:122973ms step_avg:95.03ms +step:1295/1670 train_time:123069ms step_avg:95.03ms +step:1296/1670 train_time:123165ms step_avg:95.03ms +step:1297/1670 train_time:123261ms step_avg:95.04ms +step:1298/1670 train_time:123356ms step_avg:95.04ms +step:1299/1670 train_time:123451ms step_avg:95.04ms +step:1300/1670 train_time:123545ms step_avg:95.03ms +step:1301/1670 train_time:123639ms step_avg:95.03ms +step:1302/1670 train_time:123733ms step_avg:95.03ms +step:1303/1670 train_time:123826ms step_avg:95.03ms +step:1304/1670 train_time:123920ms step_avg:95.03ms +step:1305/1670 train_time:124013ms step_avg:95.03ms +step:1306/1670 train_time:124109ms step_avg:95.03ms +step:1307/1670 train_time:124206ms step_avg:95.03ms +step:1308/1670 train_time:124301ms step_avg:95.03ms +step:1309/1670 train_time:124396ms step_avg:95.03ms +step:1310/1670 
train_time:124491ms step_avg:95.03ms +step:1311/1670 train_time:124586ms step_avg:95.03ms +step:1312/1670 train_time:124681ms step_avg:95.03ms +step:1313/1670 train_time:124775ms step_avg:95.03ms +step:1314/1670 train_time:124870ms step_avg:95.03ms +step:1315/1670 train_time:124964ms step_avg:95.03ms +step:1316/1670 train_time:125059ms step_avg:95.03ms +step:1317/1670 train_time:125153ms step_avg:95.03ms +step:1318/1670 train_time:125248ms step_avg:95.03ms +step:1319/1670 train_time:125344ms step_avg:95.03ms +step:1320/1670 train_time:125439ms step_avg:95.03ms +step:1321/1670 train_time:125533ms step_avg:95.03ms +step:1322/1670 train_time:125627ms step_avg:95.03ms +step:1323/1670 train_time:125722ms step_avg:95.03ms +step:1324/1670 train_time:125816ms step_avg:95.03ms +step:1325/1670 train_time:125910ms step_avg:95.03ms +step:1326/1670 train_time:126004ms step_avg:95.03ms +step:1327/1670 train_time:126098ms step_avg:95.02ms +step:1328/1670 train_time:126192ms step_avg:95.02ms +step:1329/1670 train_time:126288ms step_avg:95.03ms +step:1330/1670 train_time:126384ms step_avg:95.03ms +step:1331/1670 train_time:126479ms step_avg:95.03ms +step:1332/1670 train_time:126573ms step_avg:95.02ms +step:1333/1670 train_time:126667ms step_avg:95.02ms +step:1334/1670 train_time:126762ms step_avg:95.02ms +step:1335/1670 train_time:126857ms step_avg:95.02ms +step:1336/1670 train_time:126952ms step_avg:95.02ms +step:1337/1670 train_time:127045ms step_avg:95.02ms +step:1338/1670 train_time:127140ms step_avg:95.02ms +step:1339/1670 train_time:127235ms step_avg:95.02ms +step:1340/1670 train_time:127329ms step_avg:95.02ms +step:1341/1670 train_time:127423ms step_avg:95.02ms +step:1342/1670 train_time:127517ms step_avg:95.02ms +step:1343/1670 train_time:127612ms step_avg:95.02ms +step:1344/1670 train_time:127705ms step_avg:95.02ms +step:1345/1670 train_time:127799ms step_avg:95.02ms +step:1346/1670 train_time:127893ms step_avg:95.02ms +step:1347/1670 train_time:127988ms step_avg:95.02ms +step:1348/1670 train_time:128082ms step_avg:95.02ms +step:1349/1670 train_time:128177ms step_avg:95.02ms +step:1350/1670 train_time:128271ms step_avg:95.02ms +step:1351/1670 train_time:128367ms step_avg:95.02ms +step:1352/1670 train_time:128462ms step_avg:95.02ms +step:1353/1670 train_time:128556ms step_avg:95.02ms +step:1354/1670 train_time:128651ms step_avg:95.02ms +step:1355/1670 train_time:128745ms step_avg:95.01ms +step:1356/1670 train_time:128840ms step_avg:95.01ms +step:1357/1670 train_time:128933ms step_avg:95.01ms +step:1358/1670 train_time:129027ms step_avg:95.01ms +step:1359/1670 train_time:129122ms step_avg:95.01ms +step:1360/1670 train_time:129216ms step_avg:95.01ms +step:1361/1670 train_time:129311ms step_avg:95.01ms +step:1362/1670 train_time:129406ms step_avg:95.01ms +step:1363/1670 train_time:129500ms step_avg:95.01ms +step:1364/1670 train_time:129595ms step_avg:95.01ms +step:1365/1670 train_time:129689ms step_avg:95.01ms +step:1366/1670 train_time:129784ms step_avg:95.01ms +step:1367/1670 train_time:129878ms step_avg:95.01ms +step:1368/1670 train_time:129973ms step_avg:95.01ms +step:1369/1670 train_time:130068ms step_avg:95.01ms +step:1370/1670 train_time:130163ms step_avg:95.01ms +step:1371/1670 train_time:130258ms step_avg:95.01ms +step:1372/1670 train_time:130351ms step_avg:95.01ms +step:1373/1670 train_time:130446ms step_avg:95.01ms +step:1374/1670 train_time:130541ms step_avg:95.01ms +step:1375/1670 train_time:130635ms step_avg:95.01ms +step:1375/1670 val_loss:3.3413 train_time:130728ms step_avg:95.07ms 
+step:1376/1670 train_time:130754ms step_avg:95.02ms +step:1377/1670 train_time:130832ms step_avg:95.01ms +step:1378/1670 train_time:130935ms step_avg:95.02ms +step:1379/1670 train_time:131030ms step_avg:95.02ms +step:1380/1670 train_time:131123ms step_avg:95.02ms +step:1381/1670 train_time:131218ms step_avg:95.02ms +step:1382/1670 train_time:131311ms step_avg:95.02ms +step:1383/1670 train_time:131404ms step_avg:95.01ms +step:1384/1670 train_time:131498ms step_avg:95.01ms +step:1385/1670 train_time:131592ms step_avg:95.01ms +step:1386/1670 train_time:131685ms step_avg:95.01ms +step:1387/1670 train_time:131783ms step_avg:95.01ms +step:1388/1670 train_time:131881ms step_avg:95.01ms +step:1389/1670 train_time:131978ms step_avg:95.02ms +step:1390/1670 train_time:132074ms step_avg:95.02ms +step:1391/1670 train_time:132167ms step_avg:95.02ms +step:1392/1670 train_time:132261ms step_avg:95.01ms +step:1393/1670 train_time:132354ms step_avg:95.01ms +step:1394/1670 train_time:132448ms step_avg:95.01ms +step:1395/1670 train_time:132542ms step_avg:95.01ms +step:1396/1670 train_time:132635ms step_avg:95.01ms +step:1397/1670 train_time:132730ms step_avg:95.01ms +step:1398/1670 train_time:132826ms step_avg:95.01ms +step:1399/1670 train_time:132922ms step_avg:95.01ms +step:1400/1670 train_time:133019ms step_avg:95.01ms +step:1401/1670 train_time:133114ms step_avg:95.01ms +step:1402/1670 train_time:133208ms step_avg:95.01ms +step:1403/1670 train_time:133302ms step_avg:95.01ms +step:1404/1670 train_time:133395ms step_avg:95.01ms +step:1405/1670 train_time:133489ms step_avg:95.01ms +step:1406/1670 train_time:133583ms step_avg:95.01ms +step:1407/1670 train_time:133677ms step_avg:95.01ms +step:1408/1670 train_time:133773ms step_avg:95.01ms +step:1409/1670 train_time:133869ms step_avg:95.01ms +step:1410/1670 train_time:133963ms step_avg:95.01ms +step:1411/1670 train_time:134059ms step_avg:95.01ms +step:1412/1670 train_time:134154ms step_avg:95.01ms +step:1413/1670 train_time:134248ms step_avg:95.01ms +step:1414/1670 train_time:134342ms step_avg:95.01ms +step:1415/1670 train_time:134436ms step_avg:95.01ms +step:1416/1670 train_time:134531ms step_avg:95.01ms +step:1417/1670 train_time:134625ms step_avg:95.01ms +step:1418/1670 train_time:134719ms step_avg:95.01ms +step:1419/1670 train_time:134814ms step_avg:95.01ms +step:1420/1670 train_time:134909ms step_avg:95.01ms +step:1421/1670 train_time:135004ms step_avg:95.01ms +step:1422/1670 train_time:135100ms step_avg:95.01ms +step:1423/1670 train_time:135194ms step_avg:95.01ms +step:1424/1670 train_time:135288ms step_avg:95.01ms +step:1425/1670 train_time:135382ms step_avg:95.00ms +step:1426/1670 train_time:135476ms step_avg:95.00ms +step:1427/1670 train_time:135570ms step_avg:95.00ms +step:1428/1670 train_time:135664ms step_avg:95.00ms +step:1429/1670 train_time:135759ms step_avg:95.00ms +step:1430/1670 train_time:135854ms step_avg:95.00ms +step:1431/1670 train_time:135948ms step_avg:95.00ms +step:1432/1670 train_time:136043ms step_avg:95.00ms +step:1433/1670 train_time:136138ms step_avg:95.00ms +step:1434/1670 train_time:136233ms step_avg:95.00ms +step:1435/1670 train_time:136327ms step_avg:95.00ms +step:1436/1670 train_time:136422ms step_avg:95.00ms +step:1437/1670 train_time:136516ms step_avg:95.00ms +step:1438/1670 train_time:136610ms step_avg:95.00ms +step:1439/1670 train_time:136704ms step_avg:95.00ms +step:1440/1670 train_time:136800ms step_avg:95.00ms +step:1441/1670 train_time:136894ms step_avg:95.00ms +step:1442/1670 train_time:136991ms step_avg:95.00ms 
+step:1443/1670 train_time:137085ms step_avg:95.00ms +step:1444/1670 train_time:137180ms step_avg:95.00ms +step:1445/1670 train_time:137275ms step_avg:95.00ms +step:1446/1670 train_time:137369ms step_avg:95.00ms +step:1447/1670 train_time:137464ms step_avg:95.00ms +step:1448/1670 train_time:137559ms step_avg:95.00ms +step:1449/1670 train_time:137653ms step_avg:95.00ms +step:1450/1670 train_time:137748ms step_avg:95.00ms +step:1451/1670 train_time:137842ms step_avg:95.00ms +step:1452/1670 train_time:137937ms step_avg:95.00ms +step:1453/1670 train_time:138032ms step_avg:95.00ms +step:1454/1670 train_time:138127ms step_avg:95.00ms +step:1455/1670 train_time:138222ms step_avg:95.00ms +step:1456/1670 train_time:138317ms step_avg:95.00ms +step:1457/1670 train_time:138412ms step_avg:95.00ms +step:1458/1670 train_time:138506ms step_avg:95.00ms +step:1459/1670 train_time:138601ms step_avg:95.00ms +step:1460/1670 train_time:138696ms step_avg:95.00ms +step:1461/1670 train_time:138790ms step_avg:95.00ms +step:1462/1670 train_time:138884ms step_avg:95.00ms +step:1463/1670 train_time:138978ms step_avg:95.00ms +step:1464/1670 train_time:139074ms step_avg:95.00ms +step:1465/1670 train_time:139169ms step_avg:95.00ms +step:1466/1670 train_time:139263ms step_avg:95.00ms +step:1467/1670 train_time:139358ms step_avg:95.00ms +step:1468/1670 train_time:139453ms step_avg:95.00ms +step:1469/1670 train_time:139547ms step_avg:94.99ms +step:1470/1670 train_time:139642ms step_avg:94.99ms +step:1471/1670 train_time:139736ms step_avg:94.99ms +step:1472/1670 train_time:139831ms step_avg:94.99ms +step:1473/1670 train_time:139925ms step_avg:94.99ms +step:1474/1670 train_time:140020ms step_avg:94.99ms +step:1475/1670 train_time:140115ms step_avg:94.99ms +step:1476/1670 train_time:140211ms step_avg:94.99ms +step:1477/1670 train_time:140306ms step_avg:94.99ms +step:1478/1670 train_time:140400ms step_avg:94.99ms +step:1479/1670 train_time:140496ms step_avg:94.99ms +step:1480/1670 train_time:140590ms step_avg:94.99ms +step:1481/1670 train_time:140685ms step_avg:94.99ms +step:1482/1670 train_time:140780ms step_avg:94.99ms +step:1483/1670 train_time:140874ms step_avg:94.99ms +step:1484/1670 train_time:140968ms step_avg:94.99ms +step:1485/1670 train_time:141304ms step_avg:95.15ms +step:1486/1670 train_time:141379ms step_avg:95.14ms +step:1487/1670 train_time:141472ms step_avg:95.14ms +step:1488/1670 train_time:141565ms step_avg:95.14ms +step:1489/1670 train_time:141659ms step_avg:95.14ms +step:1490/1670 train_time:141753ms step_avg:95.14ms +step:1491/1670 train_time:141845ms step_avg:95.13ms +step:1492/1670 train_time:141939ms step_avg:95.13ms +step:1493/1670 train_time:142032ms step_avg:95.13ms +step:1494/1670 train_time:142126ms step_avg:95.13ms +step:1495/1670 train_time:142226ms step_avg:95.13ms +step:1496/1670 train_time:142326ms step_avg:95.14ms +step:1497/1670 train_time:142422ms step_avg:95.14ms +step:1498/1670 train_time:142516ms step_avg:95.14ms +step:1499/1670 train_time:142610ms step_avg:95.14ms +step:1500/1670 train_time:142703ms step_avg:95.14ms +step:1500/1670 val_loss:3.3112 train_time:142795ms step_avg:95.20ms +step:1501/1670 train_time:142821ms step_avg:95.15ms +step:1502/1670 train_time:142901ms step_avg:95.14ms +step:1503/1670 train_time:143000ms step_avg:95.14ms +step:1504/1670 train_time:143095ms step_avg:95.14ms +step:1505/1670 train_time:143188ms step_avg:95.14ms +step:1506/1670 train_time:143281ms step_avg:95.14ms +step:1507/1670 train_time:143374ms step_avg:95.14ms +step:1508/1670 train_time:143467ms 
step_avg:95.14ms +step:1509/1670 train_time:143560ms step_avg:95.14ms +step:1510/1670 train_time:143654ms step_avg:95.13ms +step:1511/1670 train_time:143748ms step_avg:95.13ms +step:1512/1670 train_time:143845ms step_avg:95.14ms +step:1513/1670 train_time:143941ms step_avg:95.14ms +step:1514/1670 train_time:144037ms step_avg:95.14ms +step:1515/1670 train_time:144132ms step_avg:95.14ms +step:1516/1670 train_time:144227ms step_avg:95.14ms +step:1517/1670 train_time:144322ms step_avg:95.14ms +step:1518/1670 train_time:144415ms step_avg:95.14ms +step:1519/1670 train_time:144509ms step_avg:95.13ms +step:1520/1670 train_time:144603ms step_avg:95.13ms +step:1521/1670 train_time:144697ms step_avg:95.13ms +step:1522/1670 train_time:144792ms step_avg:95.13ms +step:1523/1670 train_time:144887ms step_avg:95.13ms +step:1524/1670 train_time:144982ms step_avg:95.13ms +step:1525/1670 train_time:145078ms step_avg:95.13ms +step:1526/1670 train_time:145173ms step_avg:95.13ms +step:1527/1670 train_time:145267ms step_avg:95.13ms +step:1528/1670 train_time:145361ms step_avg:95.13ms +step:1529/1670 train_time:145455ms step_avg:95.13ms +step:1530/1670 train_time:145550ms step_avg:95.13ms +step:1531/1670 train_time:145643ms step_avg:95.13ms +step:1532/1670 train_time:145737ms step_avg:95.13ms +step:1533/1670 train_time:145832ms step_avg:95.13ms +step:1534/1670 train_time:145928ms step_avg:95.13ms +step:1535/1670 train_time:146023ms step_avg:95.13ms +step:1536/1670 train_time:146118ms step_avg:95.13ms +step:1537/1670 train_time:146213ms step_avg:95.13ms +step:1538/1670 train_time:146308ms step_avg:95.13ms +step:1539/1670 train_time:146403ms step_avg:95.13ms +step:1540/1670 train_time:146497ms step_avg:95.13ms +step:1541/1670 train_time:146591ms step_avg:95.13ms +step:1542/1670 train_time:146686ms step_avg:95.13ms +step:1543/1670 train_time:146780ms step_avg:95.13ms +step:1544/1670 train_time:146874ms step_avg:95.13ms +step:1545/1670 train_time:146971ms step_avg:95.13ms +step:1546/1670 train_time:147067ms step_avg:95.13ms +step:1547/1670 train_time:147161ms step_avg:95.13ms +step:1548/1670 train_time:147255ms step_avg:95.13ms +step:1549/1670 train_time:147350ms step_avg:95.13ms +step:1550/1670 train_time:147444ms step_avg:95.13ms +step:1551/1670 train_time:147539ms step_avg:95.12ms +step:1552/1670 train_time:147633ms step_avg:95.12ms +step:1553/1670 train_time:147728ms step_avg:95.12ms +step:1554/1670 train_time:147822ms step_avg:95.12ms +step:1555/1670 train_time:147917ms step_avg:95.12ms +step:1556/1670 train_time:148012ms step_avg:95.12ms +step:1557/1670 train_time:148108ms step_avg:95.12ms +step:1558/1670 train_time:148203ms step_avg:95.12ms +step:1559/1670 train_time:148298ms step_avg:95.12ms +step:1560/1670 train_time:148392ms step_avg:95.12ms +step:1561/1670 train_time:148487ms step_avg:95.12ms +step:1562/1670 train_time:148582ms step_avg:95.12ms +step:1563/1670 train_time:148677ms step_avg:95.12ms +step:1564/1670 train_time:148772ms step_avg:95.12ms +step:1565/1670 train_time:148867ms step_avg:95.12ms +step:1566/1670 train_time:148963ms step_avg:95.12ms +step:1567/1670 train_time:149058ms step_avg:95.12ms +step:1568/1670 train_time:149153ms step_avg:95.12ms +step:1569/1670 train_time:149248ms step_avg:95.12ms +step:1570/1670 train_time:149343ms step_avg:95.12ms +step:1571/1670 train_time:149437ms step_avg:95.12ms +step:1572/1670 train_time:149533ms step_avg:95.12ms +step:1573/1670 train_time:149628ms step_avg:95.12ms +step:1574/1670 train_time:149722ms step_avg:95.12ms +step:1575/1670 train_time:149817ms 
step_avg:95.12ms +step:1576/1670 train_time:149911ms step_avg:95.12ms +step:1577/1670 train_time:150006ms step_avg:95.12ms +step:1578/1670 train_time:150102ms step_avg:95.12ms +step:1579/1670 train_time:150196ms step_avg:95.12ms +step:1580/1670 train_time:150292ms step_avg:95.12ms +step:1581/1670 train_time:150386ms step_avg:95.12ms +step:1582/1670 train_time:150481ms step_avg:95.12ms +step:1583/1670 train_time:150576ms step_avg:95.12ms +step:1584/1670 train_time:150670ms step_avg:95.12ms +step:1585/1670 train_time:150764ms step_avg:95.12ms +step:1586/1670 train_time:150859ms step_avg:95.12ms +step:1587/1670 train_time:150954ms step_avg:95.12ms +step:1588/1670 train_time:151049ms step_avg:95.12ms +step:1589/1670 train_time:151143ms step_avg:95.12ms +step:1590/1670 train_time:151238ms step_avg:95.12ms +step:1591/1670 train_time:151333ms step_avg:95.12ms +step:1592/1670 train_time:151427ms step_avg:95.12ms +step:1593/1670 train_time:151522ms step_avg:95.12ms +step:1594/1670 train_time:151616ms step_avg:95.12ms +step:1595/1670 train_time:151712ms step_avg:95.12ms +step:1596/1670 train_time:151806ms step_avg:95.12ms +step:1597/1670 train_time:151901ms step_avg:95.12ms +step:1598/1670 train_time:151995ms step_avg:95.12ms +step:1599/1670 train_time:152089ms step_avg:95.12ms +step:1600/1670 train_time:152184ms step_avg:95.12ms +step:1601/1670 train_time:152279ms step_avg:95.12ms +step:1602/1670 train_time:152374ms step_avg:95.11ms +step:1603/1670 train_time:152468ms step_avg:95.11ms +step:1604/1670 train_time:152564ms step_avg:95.11ms +step:1605/1670 train_time:152659ms step_avg:95.11ms +step:1606/1670 train_time:152752ms step_avg:95.11ms +step:1607/1670 train_time:152847ms step_avg:95.11ms +step:1608/1670 train_time:152942ms step_avg:95.11ms +step:1609/1670 train_time:153037ms step_avg:95.11ms +step:1610/1670 train_time:153132ms step_avg:95.11ms +step:1611/1670 train_time:153227ms step_avg:95.11ms +step:1612/1670 train_time:153322ms step_avg:95.11ms +step:1613/1670 train_time:153416ms step_avg:95.11ms +step:1614/1670 train_time:153511ms step_avg:95.11ms +step:1615/1670 train_time:153607ms step_avg:95.11ms +step:1616/1670 train_time:153701ms step_avg:95.11ms +step:1617/1670 train_time:153795ms step_avg:95.11ms +step:1618/1670 train_time:153890ms step_avg:95.11ms +step:1619/1670 train_time:153985ms step_avg:95.11ms +step:1620/1670 train_time:154079ms step_avg:95.11ms +step:1621/1670 train_time:154173ms step_avg:95.11ms +step:1622/1670 train_time:154268ms step_avg:95.11ms +step:1623/1670 train_time:154363ms step_avg:95.11ms +step:1624/1670 train_time:154458ms step_avg:95.11ms +step:1625/1670 train_time:154553ms step_avg:95.11ms +step:1625/1670 val_loss:3.2865 train_time:154646ms step_avg:95.17ms +step:1626/1670 train_time:154672ms step_avg:95.12ms +step:1627/1670 train_time:154749ms step_avg:95.11ms +step:1628/1670 train_time:154849ms step_avg:95.12ms +step:1629/1670 train_time:154944ms step_avg:95.12ms +step:1630/1670 train_time:155038ms step_avg:95.12ms +step:1631/1670 train_time:155132ms step_avg:95.11ms +step:1632/1670 train_time:155226ms step_avg:95.11ms +step:1633/1670 train_time:155319ms step_avg:95.11ms +step:1634/1670 train_time:155413ms step_avg:95.11ms +step:1635/1670 train_time:155506ms step_avg:95.11ms +step:1636/1670 train_time:155600ms step_avg:95.11ms +step:1637/1670 train_time:155696ms step_avg:95.11ms +step:1638/1670 train_time:155793ms step_avg:95.11ms +step:1639/1670 train_time:155890ms step_avg:95.11ms +step:1640/1670 train_time:155985ms step_avg:95.11ms +step:1641/1670 
train_time:156079ms step_avg:95.11ms +step:1642/1670 train_time:156173ms step_avg:95.11ms +step:1643/1670 train_time:156268ms step_avg:95.11ms +step:1644/1670 train_time:156361ms step_avg:95.11ms +step:1645/1670 train_time:156455ms step_avg:95.11ms +step:1646/1670 train_time:156549ms step_avg:95.11ms +step:1647/1670 train_time:156643ms step_avg:95.11ms +step:1648/1670 train_time:156739ms step_avg:95.11ms +step:1649/1670 train_time:156836ms step_avg:95.11ms +step:1650/1670 train_time:156932ms step_avg:95.11ms +step:1651/1670 train_time:157027ms step_avg:95.11ms +step:1652/1670 train_time:157121ms step_avg:95.11ms +step:1653/1670 train_time:157215ms step_avg:95.11ms +step:1654/1670 train_time:157310ms step_avg:95.11ms +step:1655/1670 train_time:157405ms step_avg:95.11ms +step:1656/1670 train_time:157499ms step_avg:95.11ms +step:1657/1670 train_time:157592ms step_avg:95.11ms +step:1658/1670 train_time:157686ms step_avg:95.11ms +step:1659/1670 train_time:157781ms step_avg:95.11ms +step:1660/1670 train_time:157876ms step_avg:95.11ms +step:1661/1670 train_time:157972ms step_avg:95.11ms +step:1662/1670 train_time:158066ms step_avg:95.11ms +step:1663/1670 train_time:158161ms step_avg:95.11ms +step:1664/1670 train_time:158255ms step_avg:95.11ms +step:1665/1670 train_time:158350ms step_avg:95.10ms +step:1666/1670 train_time:158444ms step_avg:95.10ms +step:1667/1670 train_time:158538ms step_avg:95.10ms +step:1668/1670 train_time:158632ms step_avg:95.10ms +step:1669/1670 train_time:158727ms step_avg:95.10ms +step:1670/1670 train_time:158822ms step_avg:95.10ms +step:1670/1670 val_loss:3.2779 train_time:158998ms step_avg:95.21ms +peak memory allocated: 32304 MiB reserved: 47696 MiB
diff --git a/records/091025_Yarn/6297777d-03bd-4955-9c3a-c854246b928a.txt b/records/091025_Yarn/6297777d-03bd-4955-9c3a-c854246b928a.txt
new file mode 100644
index 000000000..f506379be
--- /dev/null
+++ b/records/091025_Yarn/6297777d-03bd-4955-9c3a-c854246b928a.txt
@@ -0,0 +1,2863 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+import math
+
+from dataclasses import dataclass
+from functools import lru_cache
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import numpy as np
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+import torch._dynamo as dynamo
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
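For reference, nanogpt::mm above is numerically a dequantize-then-matmul: both inputs are quantized to FP8 with per-tensor scales, and _scaled_mm multiplies the scales back into the low-precision product. A minimal emulation sketch of that arithmetic (the scale values are hypothetical, and plain upcasts stand in for torch._scaled_mm, which itself needs recent CUDA hardware):

# Emulate the FP8 forward path with explicit quantize/dequantize steps.
import torch

x = torch.randn(16, 32, dtype=torch.bfloat16)
w = torch.randn(64, 32, dtype=torch.bfloat16)
x_s, w_s = 4.0, 2.0  # hypothetical per-tensor scales

x_f8 = x.float().div(x_s).to(torch.float8_e4m3fn)  # quantize activations
w_f8 = w.float().div(w_s).to(torch.float8_e4m3fn)  # quantize weights

# _scaled_mm folds the scales back in; emulated here by dequantizing first.
out_ref = (x_f8.float() * x_s) @ (w_f8.float() * w_s).T
out_exact = x.float() @ w.float().T
print((out_ref - out_exact).abs().max())  # small: FP8 rounding error only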
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
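Apart from the swizzle, the decomposition in _pid_to_block is plain integer arithmetic. A pure-Python sketch with the swizzle omitted (tl.swizzle2d only permutes the (pid_m, pid_n) coordinates for L2 cache locality; it does not change which blocks exist), with illustrative sizes:

# Map a flat program id to (batch index, block row offset, block col offset).
def pid_to_block_py(pid, M, bm, bn):
    num_pid_m = -(-M // bm)  # ceil division, mirrors tl.cdiv
    num_pid_n = -(-M // bn)
    batch_idx, pid = divmod(pid, num_pid_m * num_pid_n)
    pid_m, pid_n = divmod(pid, num_pid_n)
    return batch_idx, pid_m * bm, pid_n * bn

# e.g. M=256 with 128x128 blocks -> 4 blocks per matrix in the batch:
assert pid_to_block_py(0, 256, 128, 128) == (0, 0, 0)
assert pid_to_block_py(3, 256, 128, 128) == (0, 128, 128)
assert pid_to_block_py(4, 256, 128, 128) == (1, 0, 0)  # next matrix in batch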
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, 
n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # 
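Each NS iteration below applies the quintic map X <- a*X + b*(X @ X.mT) @ X + c*(X @ X.mT)^2 @ X;
+    # with the coefficients (a, b, c) above, this pushes the singular values of X toward 1, so the
+    # five iterations approximately replace G with the nearest semi-orthogonal matrix.
+    #
+    # 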
Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
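+        # The distributed step below proceeds in three phases:
+        #   (1) launch async reduce-scatters so each rank receives the averaged gradient for the
+        #       parameters it owns (parameters are assigned to ranks round-robin within each group);
+        #   (2) as each reduce-scatter completes, the owning rank applies weight decay, momentum,
+        #       and Newton-Schulz orthogonalization to its parameters;
+        #   (3) async all-gathers broadcast the updated parameters back to every rank.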
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, 
sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + 
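# note: only three value-embedding tables are allocated below; forward() reuses them for both
+        # the first three and the last three blocks (the "012 ... 012" pattern noted there)
+        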
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws]
+            )
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
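+# A small worked example of the packing above (illustrative numbers, not from a real shard):
+# with BOS positions [0, 5, 9], num_tokens_local = 8, and max_seq_len = 16, a rank gets
+# starts = [0, 5] and ends = [5, 9], so cur_len = 5 + 4 = 9 == num_tokens_local + 1. Each batch
+# is packed from whole documents (truncated to max_seq_len) plus one extra token, so the
+# generator below can offset inputs and targets by one position.
+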
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"yarn/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws=new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 04:01:30 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 36C P0 120W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 40C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 41C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 35C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 35C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 41C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 39C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 36C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 63307 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 63308 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63309 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63310 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63311 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63312 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63313 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63314 C /usr/bin/python3 614MiB | +| 1 N/A N/A 63308 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 63309 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 63310 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 63311 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 63312 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 63313 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 63314 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:455ms step_avg:454.74ms +step:2/1670 train_time:481ms step_avg:240.42ms +step:3/1670 train_time:548ms step_avg:182.66ms +step:4/1670 train_time:638ms step_avg:159.62ms +step:5/1670 train_time:730ms step_avg:146.05ms +step:6/1670 train_time:822ms step_avg:136.98ms +step:7/1670 train_time:914ms step_avg:130.63ms 
+step:8/1670 train_time:1006ms step_avg:125.75ms +step:9/1670 train_time:1098ms step_avg:121.97ms +step:10/1670 train_time:1189ms step_avg:118.94ms +step:11/1670 train_time:1281ms step_avg:116.46ms +step:12/1670 train_time:1373ms step_avg:114.43ms +step:13/1670 train_time:1469ms step_avg:113.03ms +step:14/1670 train_time:1565ms step_avg:111.75ms +step:15/1670 train_time:1659ms step_avg:110.57ms +step:16/1670 train_time:1751ms step_avg:109.46ms +step:17/1670 train_time:1844ms step_avg:108.47ms +step:18/1670 train_time:1936ms step_avg:107.55ms +step:19/1670 train_time:2029ms step_avg:106.78ms +step:20/1670 train_time:2120ms step_avg:106.02ms +step:21/1670 train_time:2212ms step_avg:105.35ms +step:22/1670 train_time:2305ms step_avg:104.79ms +step:23/1670 train_time:2398ms step_avg:104.25ms +step:24/1670 train_time:2492ms step_avg:103.81ms +step:25/1670 train_time:2585ms step_avg:103.41ms +step:26/1670 train_time:2679ms step_avg:103.03ms +step:27/1670 train_time:2771ms step_avg:102.65ms +step:28/1670 train_time:2864ms step_avg:102.30ms +step:29/1670 train_time:2957ms step_avg:101.97ms +step:30/1670 train_time:3049ms step_avg:101.64ms +step:31/1670 train_time:3142ms step_avg:101.36ms +step:32/1670 train_time:3235ms step_avg:101.09ms +step:33/1670 train_time:3328ms step_avg:100.84ms +step:34/1670 train_time:3421ms step_avg:100.63ms +step:35/1670 train_time:3515ms step_avg:100.43ms +step:36/1670 train_time:3608ms step_avg:100.23ms +step:37/1670 train_time:3701ms step_avg:100.03ms +step:38/1670 train_time:3794ms step_avg:99.85ms +step:39/1670 train_time:3888ms step_avg:99.69ms +step:40/1670 train_time:3980ms step_avg:99.50ms +step:41/1670 train_time:4073ms step_avg:99.34ms +step:42/1670 train_time:4167ms step_avg:99.20ms +step:43/1670 train_time:4259ms step_avg:99.05ms +step:44/1670 train_time:4352ms step_avg:98.90ms +step:45/1670 train_time:4445ms step_avg:98.79ms +step:46/1670 train_time:4539ms step_avg:98.68ms +step:47/1670 train_time:4633ms step_avg:98.56ms +step:48/1670 train_time:4725ms step_avg:98.43ms +step:49/1670 train_time:4818ms step_avg:98.33ms +step:50/1670 train_time:4911ms step_avg:98.22ms +step:51/1670 train_time:5004ms step_avg:98.12ms +step:52/1670 train_time:5097ms step_avg:98.02ms +step:53/1670 train_time:5190ms step_avg:97.92ms +step:54/1670 train_time:5283ms step_avg:97.83ms +step:55/1670 train_time:5376ms step_avg:97.74ms +step:56/1670 train_time:5469ms step_avg:97.66ms +step:57/1670 train_time:5562ms step_avg:97.58ms +step:58/1670 train_time:5655ms step_avg:97.50ms +step:59/1670 train_time:5748ms step_avg:97.42ms +step:60/1670 train_time:5841ms step_avg:97.34ms +step:61/1670 train_time:5934ms step_avg:97.27ms +step:62/1670 train_time:6027ms step_avg:97.20ms +step:63/1670 train_time:6120ms step_avg:97.14ms +step:64/1670 train_time:6211ms step_avg:97.05ms +step:65/1670 train_time:6304ms step_avg:96.99ms +step:66/1670 train_time:6398ms step_avg:96.93ms +step:67/1670 train_time:6490ms step_avg:96.87ms +step:68/1670 train_time:6584ms step_avg:96.83ms +step:69/1670 train_time:6677ms step_avg:96.77ms +step:70/1670 train_time:6770ms step_avg:96.71ms +step:71/1670 train_time:6864ms step_avg:96.68ms +step:72/1670 train_time:6957ms step_avg:96.63ms +step:73/1670 train_time:7050ms step_avg:96.57ms +step:74/1670 train_time:7143ms step_avg:96.53ms +step:75/1670 train_time:7236ms step_avg:96.48ms +step:76/1670 train_time:7328ms step_avg:96.42ms +step:77/1670 train_time:7420ms step_avg:96.37ms +step:78/1670 train_time:7512ms step_avg:96.31ms +step:79/1670 train_time:7605ms 
step_avg:96.26ms +step:80/1670 train_time:7697ms step_avg:96.22ms +step:81/1670 train_time:7790ms step_avg:96.17ms +step:82/1670 train_time:7883ms step_avg:96.13ms +step:83/1670 train_time:7976ms step_avg:96.10ms +step:84/1670 train_time:8069ms step_avg:96.06ms +step:85/1670 train_time:8161ms step_avg:96.02ms +step:86/1670 train_time:8253ms step_avg:95.97ms +step:87/1670 train_time:8346ms step_avg:95.93ms +step:88/1670 train_time:8439ms step_avg:95.89ms +step:89/1670 train_time:8530ms step_avg:95.85ms +step:90/1670 train_time:8623ms step_avg:95.81ms +step:91/1670 train_time:8716ms step_avg:95.77ms +step:92/1670 train_time:8808ms step_avg:95.74ms +step:93/1670 train_time:8900ms step_avg:95.70ms +step:94/1670 train_time:8992ms step_avg:95.66ms +step:95/1670 train_time:9086ms step_avg:95.65ms +step:96/1670 train_time:9179ms step_avg:95.62ms +step:97/1670 train_time:9272ms step_avg:95.59ms +step:98/1670 train_time:9365ms step_avg:95.56ms +step:99/1670 train_time:9458ms step_avg:95.53ms +step:100/1670 train_time:9549ms step_avg:95.49ms +step:101/1670 train_time:9643ms step_avg:95.47ms +step:102/1670 train_time:9735ms step_avg:95.44ms +step:103/1670 train_time:9827ms step_avg:95.41ms +step:104/1670 train_time:9920ms step_avg:95.38ms +step:105/1670 train_time:10012ms step_avg:95.36ms +step:106/1670 train_time:10105ms step_avg:95.33ms +step:107/1670 train_time:10198ms step_avg:95.31ms +step:108/1670 train_time:10291ms step_avg:95.28ms +step:109/1670 train_time:10384ms step_avg:95.26ms +step:110/1670 train_time:10476ms step_avg:95.23ms +step:111/1670 train_time:10568ms step_avg:95.20ms +step:112/1670 train_time:10660ms step_avg:95.18ms +step:113/1670 train_time:10752ms step_avg:95.15ms +step:114/1670 train_time:10845ms step_avg:95.13ms +step:115/1670 train_time:10937ms step_avg:95.10ms +step:116/1670 train_time:11030ms step_avg:95.08ms +step:117/1670 train_time:11122ms step_avg:95.06ms +step:118/1670 train_time:11214ms step_avg:95.04ms +step:119/1670 train_time:11306ms step_avg:95.01ms +step:120/1670 train_time:11399ms step_avg:94.99ms +step:121/1670 train_time:11491ms step_avg:94.97ms +step:122/1670 train_time:11584ms step_avg:94.95ms +step:123/1670 train_time:11676ms step_avg:94.93ms +step:124/1670 train_time:11768ms step_avg:94.91ms +step:125/1670 train_time:11861ms step_avg:94.89ms +step:125/1670 val_loss:4.3018 train_time:11951ms step_avg:95.61ms +step:126/1670 train_time:11978ms step_avg:95.07ms +step:127/1670 train_time:12049ms step_avg:94.88ms +step:128/1670 train_time:12154ms step_avg:94.95ms +step:129/1670 train_time:12252ms step_avg:94.98ms +step:130/1670 train_time:12346ms step_avg:94.97ms +step:131/1670 train_time:12437ms step_avg:94.94ms +step:132/1670 train_time:12529ms step_avg:94.91ms +step:133/1670 train_time:12620ms step_avg:94.89ms +step:134/1670 train_time:12712ms step_avg:94.86ms +step:135/1670 train_time:12803ms step_avg:94.84ms +step:136/1670 train_time:12894ms step_avg:94.81ms +step:137/1670 train_time:12986ms step_avg:94.79ms +step:138/1670 train_time:13079ms step_avg:94.77ms +step:139/1670 train_time:13173ms step_avg:94.77ms +step:140/1670 train_time:13268ms step_avg:94.77ms +step:141/1670 train_time:13360ms step_avg:94.75ms +step:142/1670 train_time:13453ms step_avg:94.74ms +step:143/1670 train_time:13545ms step_avg:94.72ms +step:144/1670 train_time:13637ms step_avg:94.70ms +step:145/1670 train_time:13728ms step_avg:94.68ms +step:146/1670 train_time:13820ms step_avg:94.66ms +step:147/1670 train_time:13912ms step_avg:94.64ms +step:148/1670 train_time:14005ms 
step_avg:94.63ms +step:149/1670 train_time:14097ms step_avg:94.61ms +step:150/1670 train_time:14190ms step_avg:94.60ms +step:151/1670 train_time:14283ms step_avg:94.59ms +step:152/1670 train_time:14376ms step_avg:94.58ms +step:153/1670 train_time:14469ms step_avg:94.57ms +step:154/1670 train_time:14561ms step_avg:94.55ms +step:155/1670 train_time:14654ms step_avg:94.54ms +step:156/1670 train_time:14746ms step_avg:94.53ms +step:157/1670 train_time:14838ms step_avg:94.51ms +step:158/1670 train_time:14931ms step_avg:94.50ms +step:159/1670 train_time:15023ms step_avg:94.48ms +step:160/1670 train_time:15115ms step_avg:94.47ms +step:161/1670 train_time:15208ms step_avg:94.46ms +step:162/1670 train_time:15301ms step_avg:94.45ms +step:163/1670 train_time:15394ms step_avg:94.44ms +step:164/1670 train_time:15486ms step_avg:94.43ms +step:165/1670 train_time:15578ms step_avg:94.41ms +step:166/1670 train_time:15671ms step_avg:94.40ms +step:167/1670 train_time:15764ms step_avg:94.40ms +step:168/1670 train_time:15856ms step_avg:94.38ms +step:169/1670 train_time:15947ms step_avg:94.36ms +step:170/1670 train_time:16039ms step_avg:94.35ms +step:171/1670 train_time:16132ms step_avg:94.34ms +step:172/1670 train_time:16225ms step_avg:94.33ms +step:173/1670 train_time:16317ms step_avg:94.32ms +step:174/1670 train_time:16411ms step_avg:94.31ms +step:175/1670 train_time:16503ms step_avg:94.30ms +step:176/1670 train_time:16595ms step_avg:94.29ms +step:177/1670 train_time:16688ms step_avg:94.28ms +step:178/1670 train_time:16780ms step_avg:94.27ms +step:179/1670 train_time:16872ms step_avg:94.26ms +step:180/1670 train_time:16965ms step_avg:94.25ms +step:181/1670 train_time:17058ms step_avg:94.24ms +step:182/1670 train_time:17150ms step_avg:94.23ms +step:183/1670 train_time:17243ms step_avg:94.22ms +step:184/1670 train_time:17335ms step_avg:94.21ms +step:185/1670 train_time:17428ms step_avg:94.21ms +step:186/1670 train_time:17521ms step_avg:94.20ms +step:187/1670 train_time:17615ms step_avg:94.20ms +step:188/1670 train_time:17707ms step_avg:94.19ms +step:189/1670 train_time:17799ms step_avg:94.18ms +step:190/1670 train_time:17892ms step_avg:94.17ms +step:191/1670 train_time:17985ms step_avg:94.16ms +step:192/1670 train_time:18078ms step_avg:94.15ms +step:193/1670 train_time:18169ms step_avg:94.14ms +step:194/1670 train_time:18262ms step_avg:94.14ms +step:195/1670 train_time:18355ms step_avg:94.13ms +step:196/1670 train_time:18447ms step_avg:94.12ms +step:197/1670 train_time:18539ms step_avg:94.11ms +step:198/1670 train_time:18633ms step_avg:94.10ms +step:199/1670 train_time:18725ms step_avg:94.09ms +step:200/1670 train_time:18817ms step_avg:94.08ms +step:201/1670 train_time:18909ms step_avg:94.08ms +step:202/1670 train_time:19002ms step_avg:94.07ms +step:203/1670 train_time:19095ms step_avg:94.06ms +step:204/1670 train_time:19188ms step_avg:94.06ms +step:205/1670 train_time:19280ms step_avg:94.05ms +step:206/1670 train_time:19373ms step_avg:94.04ms +step:207/1670 train_time:19466ms step_avg:94.04ms +step:208/1670 train_time:19558ms step_avg:94.03ms +step:209/1670 train_time:19651ms step_avg:94.02ms +step:210/1670 train_time:19743ms step_avg:94.02ms +step:211/1670 train_time:19836ms step_avg:94.01ms +step:212/1670 train_time:19928ms step_avg:94.00ms +step:213/1670 train_time:20267ms step_avg:95.15ms +step:214/1670 train_time:20407ms step_avg:95.36ms +step:215/1670 train_time:20498ms step_avg:95.34ms +step:216/1670 train_time:20590ms step_avg:95.32ms +step:217/1670 train_time:20681ms step_avg:95.30ms +step:218/1670 
train_time:20772ms step_avg:95.29ms +step:219/1670 train_time:20864ms step_avg:95.27ms +step:220/1670 train_time:20956ms step_avg:95.25ms +step:221/1670 train_time:21047ms step_avg:95.23ms +step:222/1670 train_time:21138ms step_avg:95.22ms +step:223/1670 train_time:21230ms step_avg:95.20ms +step:224/1670 train_time:21324ms step_avg:95.20ms +step:225/1670 train_time:21420ms step_avg:95.20ms +step:226/1670 train_time:21513ms step_avg:95.19ms +step:227/1670 train_time:21606ms step_avg:95.18ms +step:228/1670 train_time:21698ms step_avg:95.17ms +step:229/1670 train_time:21790ms step_avg:95.15ms +step:230/1670 train_time:21882ms step_avg:95.14ms +step:231/1670 train_time:21974ms step_avg:95.12ms +step:232/1670 train_time:22066ms step_avg:95.11ms +step:233/1670 train_time:22158ms step_avg:95.10ms +step:234/1670 train_time:22250ms step_avg:95.09ms +step:235/1670 train_time:22343ms step_avg:95.08ms +step:236/1670 train_time:22436ms step_avg:95.07ms +step:237/1670 train_time:22530ms step_avg:95.06ms +step:238/1670 train_time:22623ms step_avg:95.05ms +step:239/1670 train_time:22715ms step_avg:95.04ms +step:240/1670 train_time:22807ms step_avg:95.03ms +step:241/1670 train_time:22899ms step_avg:95.01ms +step:242/1670 train_time:22991ms step_avg:95.00ms +step:243/1670 train_time:23083ms step_avg:94.99ms +step:244/1670 train_time:23175ms step_avg:94.98ms +step:245/1670 train_time:23267ms step_avg:94.97ms +step:246/1670 train_time:23361ms step_avg:94.96ms +step:247/1670 train_time:23455ms step_avg:94.96ms +step:248/1670 train_time:23547ms step_avg:94.95ms +step:249/1670 train_time:23640ms step_avg:94.94ms +step:250/1670 train_time:23732ms step_avg:94.93ms +step:250/1670 val_loss:3.9694 train_time:23822ms step_avg:95.29ms +step:251/1670 train_time:23851ms step_avg:95.02ms +step:252/1670 train_time:23922ms step_avg:94.93ms +step:253/1670 train_time:24021ms step_avg:94.94ms +step:254/1670 train_time:24115ms step_avg:94.94ms +step:255/1670 train_time:24207ms step_avg:94.93ms +step:256/1670 train_time:24299ms step_avg:94.92ms +step:257/1670 train_time:24390ms step_avg:94.90ms +step:258/1670 train_time:24482ms step_avg:94.89ms +step:259/1670 train_time:24573ms step_avg:94.88ms +step:260/1670 train_time:24664ms step_avg:94.86ms +step:261/1670 train_time:24756ms step_avg:94.85ms +step:262/1670 train_time:24849ms step_avg:94.84ms +step:263/1670 train_time:24944ms step_avg:94.84ms +step:264/1670 train_time:25039ms step_avg:94.84ms +step:265/1670 train_time:25132ms step_avg:94.84ms +step:266/1670 train_time:25224ms step_avg:94.83ms +step:267/1670 train_time:25316ms step_avg:94.82ms +step:268/1670 train_time:25409ms step_avg:94.81ms +step:269/1670 train_time:25500ms step_avg:94.80ms +step:270/1670 train_time:25593ms step_avg:94.79ms +step:271/1670 train_time:25684ms step_avg:94.78ms +step:272/1670 train_time:25777ms step_avg:94.77ms +step:273/1670 train_time:25870ms step_avg:94.76ms +step:274/1670 train_time:25962ms step_avg:94.75ms +step:275/1670 train_time:26056ms step_avg:94.75ms +step:276/1670 train_time:26150ms step_avg:94.75ms +step:277/1670 train_time:26243ms step_avg:94.74ms +step:278/1670 train_time:26336ms step_avg:94.73ms +step:279/1670 train_time:26428ms step_avg:94.72ms +step:280/1670 train_time:26520ms step_avg:94.72ms +step:281/1670 train_time:26613ms step_avg:94.71ms +step:282/1670 train_time:26704ms step_avg:94.70ms +step:283/1670 train_time:26797ms step_avg:94.69ms +step:284/1670 train_time:26889ms step_avg:94.68ms +step:285/1670 train_time:26982ms step_avg:94.67ms +step:286/1670 train_time:27076ms 
step_avg:94.67ms +step:287/1670 train_time:27168ms step_avg:94.66ms +step:288/1670 train_time:27261ms step_avg:94.66ms +step:289/1670 train_time:27354ms step_avg:94.65ms +step:290/1670 train_time:27446ms step_avg:94.64ms +step:291/1670 train_time:27539ms step_avg:94.63ms +step:292/1670 train_time:27631ms step_avg:94.63ms +step:293/1670 train_time:27723ms step_avg:94.62ms +step:294/1670 train_time:27815ms step_avg:94.61ms +step:295/1670 train_time:27907ms step_avg:94.60ms +step:296/1670 train_time:28000ms step_avg:94.59ms +step:297/1670 train_time:28093ms step_avg:94.59ms +step:298/1670 train_time:28186ms step_avg:94.58ms +step:299/1670 train_time:28279ms step_avg:94.58ms +step:300/1670 train_time:28371ms step_avg:94.57ms +step:301/1670 train_time:28463ms step_avg:94.56ms +step:302/1670 train_time:28556ms step_avg:94.56ms +step:303/1670 train_time:28648ms step_avg:94.55ms +step:304/1670 train_time:28740ms step_avg:94.54ms +step:305/1670 train_time:28833ms step_avg:94.53ms +step:306/1670 train_time:28925ms step_avg:94.53ms +step:307/1670 train_time:29018ms step_avg:94.52ms +step:308/1670 train_time:29111ms step_avg:94.52ms +step:309/1670 train_time:29203ms step_avg:94.51ms +step:310/1670 train_time:29296ms step_avg:94.50ms +step:311/1670 train_time:29388ms step_avg:94.50ms +step:312/1670 train_time:29482ms step_avg:94.49ms +step:313/1670 train_time:29574ms step_avg:94.49ms +step:314/1670 train_time:29667ms step_avg:94.48ms +step:315/1670 train_time:29759ms step_avg:94.47ms +step:316/1670 train_time:29852ms step_avg:94.47ms +step:317/1670 train_time:29944ms step_avg:94.46ms +step:318/1670 train_time:30037ms step_avg:94.46ms +step:319/1670 train_time:30131ms step_avg:94.45ms +step:320/1670 train_time:30223ms step_avg:94.45ms +step:321/1670 train_time:30315ms step_avg:94.44ms +step:322/1670 train_time:30407ms step_avg:94.43ms +step:323/1670 train_time:30501ms step_avg:94.43ms +step:324/1670 train_time:30594ms step_avg:94.43ms +step:325/1670 train_time:30686ms step_avg:94.42ms +step:326/1670 train_time:30778ms step_avg:94.41ms +step:327/1670 train_time:30871ms step_avg:94.41ms +step:328/1670 train_time:30963ms step_avg:94.40ms +step:329/1670 train_time:31056ms step_avg:94.39ms +step:330/1670 train_time:31148ms step_avg:94.39ms +step:331/1670 train_time:31240ms step_avg:94.38ms +step:332/1670 train_time:31333ms step_avg:94.38ms +step:333/1670 train_time:31426ms step_avg:94.37ms +step:334/1670 train_time:31518ms step_avg:94.37ms +step:335/1670 train_time:31611ms step_avg:94.36ms +step:336/1670 train_time:31703ms step_avg:94.35ms +step:337/1670 train_time:31795ms step_avg:94.35ms +step:338/1670 train_time:31887ms step_avg:94.34ms +step:339/1670 train_time:31979ms step_avg:94.33ms +step:340/1670 train_time:32073ms step_avg:94.33ms +step:341/1670 train_time:32165ms step_avg:94.33ms +step:342/1670 train_time:32258ms step_avg:94.32ms +step:343/1670 train_time:32350ms step_avg:94.32ms +step:344/1670 train_time:32444ms step_avg:94.31ms +step:345/1670 train_time:32537ms step_avg:94.31ms +step:346/1670 train_time:32629ms step_avg:94.30ms +step:347/1670 train_time:32721ms step_avg:94.30ms +step:348/1670 train_time:32815ms step_avg:94.30ms +step:349/1670 train_time:32907ms step_avg:94.29ms +step:350/1670 train_time:32999ms step_avg:94.28ms +step:351/1670 train_time:33092ms step_avg:94.28ms +step:352/1670 train_time:33184ms step_avg:94.27ms +step:353/1670 train_time:33277ms step_avg:94.27ms +step:354/1670 train_time:33369ms step_avg:94.26ms +step:355/1670 train_time:33462ms step_avg:94.26ms +step:356/1670 
+step:357/1670 train_time:33647ms step_avg:94.25ms
+step:358/1670 train_time:33740ms step_avg:94.25ms
+step:359/1670 train_time:33833ms step_avg:94.24ms
+step:360/1670 train_time:33925ms step_avg:94.23ms
+step:361/1670 train_time:34018ms step_avg:94.23ms
+step:362/1670 train_time:34111ms step_avg:94.23ms
+step:363/1670 train_time:34203ms step_avg:94.22ms
+step:364/1670 train_time:34295ms step_avg:94.22ms
+step:365/1670 train_time:34387ms step_avg:94.21ms
+step:366/1670 train_time:34480ms step_avg:94.21ms
+step:367/1670 train_time:34573ms step_avg:94.20ms
+step:368/1670 train_time:34666ms step_avg:94.20ms
+step:369/1670 train_time:34759ms step_avg:94.20ms
+step:370/1670 train_time:34851ms step_avg:94.19ms
+step:371/1670 train_time:34943ms step_avg:94.19ms
+step:372/1670 train_time:35036ms step_avg:94.18ms
+step:373/1670 train_time:35128ms step_avg:94.18ms
+step:374/1670 train_time:35221ms step_avg:94.17ms
+step:375/1670 train_time:35313ms step_avg:94.17ms
+step:375/1670 val_loss:3.8201 train_time:35403ms step_avg:94.41ms
+step:376/1670 train_time:35431ms step_avg:94.23ms
+step:377/1670 train_time:35506ms step_avg:94.18ms
+step:378/1670 train_time:35607ms step_avg:94.20ms
+step:379/1670 train_time:35704ms step_avg:94.20ms
+step:380/1670 train_time:35796ms step_avg:94.20ms
+step:381/1670 train_time:35887ms step_avg:94.19ms
+step:382/1670 train_time:35978ms step_avg:94.18ms
+step:383/1670 train_time:36070ms step_avg:94.18ms
+step:384/1670 train_time:36162ms step_avg:94.17ms
+step:385/1670 train_time:36253ms step_avg:94.16ms
+step:386/1670 train_time:36345ms step_avg:94.16ms
+step:387/1670 train_time:36437ms step_avg:94.15ms
+step:388/1670 train_time:36533ms step_avg:94.16ms
+step:389/1670 train_time:36628ms step_avg:94.16ms
+step:390/1670 train_time:36722ms step_avg:94.16ms
+step:391/1670 train_time:36814ms step_avg:94.15ms
+step:392/1670 train_time:36907ms step_avg:94.15ms
+step:393/1670 train_time:36999ms step_avg:94.14ms
+step:394/1670 train_time:37090ms step_avg:94.14ms
+step:395/1670 train_time:37182ms step_avg:94.13ms
+step:396/1670 train_time:37273ms step_avg:94.12ms
+step:397/1670 train_time:37365ms step_avg:94.12ms
+step:398/1670 train_time:37458ms step_avg:94.12ms
+step:399/1670 train_time:37552ms step_avg:94.11ms
+step:400/1670 train_time:37645ms step_avg:94.11ms
+step:401/1670 train_time:37738ms step_avg:94.11ms
+step:402/1670 train_time:37831ms step_avg:94.11ms
+step:403/1670 train_time:37924ms step_avg:94.10ms
+step:404/1670 train_time:38016ms step_avg:94.10ms
+step:405/1670 train_time:38109ms step_avg:94.10ms
+step:406/1670 train_time:38200ms step_avg:94.09ms
+step:407/1670 train_time:38292ms step_avg:94.08ms
+step:408/1670 train_time:38384ms step_avg:94.08ms
+step:409/1670 train_time:38478ms step_avg:94.08ms
+step:410/1670 train_time:38571ms step_avg:94.08ms
+step:411/1670 train_time:38664ms step_avg:94.07ms
+step:412/1670 train_time:38758ms step_avg:94.07ms
+step:413/1670 train_time:38850ms step_avg:94.07ms
+step:414/1670 train_time:38942ms step_avg:94.06ms
+step:415/1670 train_time:39035ms step_avg:94.06ms
+step:416/1670 train_time:39128ms step_avg:94.06ms
+step:417/1670 train_time:39220ms step_avg:94.05ms
+step:418/1670 train_time:39312ms step_avg:94.05ms
+step:419/1670 train_time:39404ms step_avg:94.04ms
+step:420/1670 train_time:39497ms step_avg:94.04ms
+step:421/1670 train_time:39591ms step_avg:94.04ms
+step:422/1670 train_time:39683ms step_avg:94.04ms
+step:423/1670 train_time:39777ms step_avg:94.03ms
+step:424/1670 train_time:39870ms step_avg:94.03ms
+step:425/1670 train_time:40207ms step_avg:94.60ms
+step:426/1670 train_time:40381ms step_avg:94.79ms
+step:427/1670 train_time:40471ms step_avg:94.78ms
+step:428/1670 train_time:40563ms step_avg:94.77ms
+step:429/1670 train_time:40655ms step_avg:94.77ms
+step:430/1670 train_time:40745ms step_avg:94.76ms
+step:431/1670 train_time:40836ms step_avg:94.75ms
+step:432/1670 train_time:40928ms step_avg:94.74ms
+step:433/1670 train_time:41020ms step_avg:94.73ms
+step:434/1670 train_time:41111ms step_avg:94.73ms
+step:435/1670 train_time:41205ms step_avg:94.72ms
+step:436/1670 train_time:41301ms step_avg:94.73ms
+step:437/1670 train_time:41397ms step_avg:94.73ms
+step:438/1670 train_time:41491ms step_avg:94.73ms
+step:439/1670 train_time:41584ms step_avg:94.72ms
+step:440/1670 train_time:41676ms step_avg:94.72ms
+step:441/1670 train_time:41768ms step_avg:94.71ms
+step:442/1670 train_time:41860ms step_avg:94.71ms
+step:443/1670 train_time:41952ms step_avg:94.70ms
+step:444/1670 train_time:42043ms step_avg:94.69ms
+step:445/1670 train_time:42134ms step_avg:94.68ms
+step:446/1670 train_time:42227ms step_avg:94.68ms
+step:447/1670 train_time:42320ms step_avg:94.68ms
+step:448/1670 train_time:42413ms step_avg:94.67ms
+step:449/1670 train_time:42507ms step_avg:94.67ms
+step:450/1670 train_time:42600ms step_avg:94.67ms
+step:451/1670 train_time:42692ms step_avg:94.66ms
+step:452/1670 train_time:42784ms step_avg:94.66ms
+step:453/1670 train_time:42877ms step_avg:94.65ms
+step:454/1670 train_time:42969ms step_avg:94.65ms
+step:455/1670 train_time:43061ms step_avg:94.64ms
+step:456/1670 train_time:43153ms step_avg:94.63ms
+step:457/1670 train_time:43245ms step_avg:94.63ms
+step:458/1670 train_time:43339ms step_avg:94.63ms
+step:459/1670 train_time:43433ms step_avg:94.62ms
+step:460/1670 train_time:43527ms step_avg:94.62ms
+step:461/1670 train_time:43619ms step_avg:94.62ms
+step:462/1670 train_time:43712ms step_avg:94.62ms
+step:463/1670 train_time:43805ms step_avg:94.61ms
+step:464/1670 train_time:43896ms step_avg:94.60ms
+step:465/1670 train_time:43989ms step_avg:94.60ms
+step:466/1670 train_time:44081ms step_avg:94.59ms
+step:467/1670 train_time:44174ms step_avg:94.59ms
+step:468/1670 train_time:44266ms step_avg:94.59ms
+step:469/1670 train_time:44358ms step_avg:94.58ms
+step:470/1670 train_time:44451ms step_avg:94.58ms
+step:471/1670 train_time:44543ms step_avg:94.57ms
+step:472/1670 train_time:44635ms step_avg:94.57ms
+step:473/1670 train_time:44729ms step_avg:94.56ms
+step:474/1670 train_time:44821ms step_avg:94.56ms
+step:475/1670 train_time:44913ms step_avg:94.55ms
+step:476/1670 train_time:45006ms step_avg:94.55ms
+step:477/1670 train_time:45098ms step_avg:94.55ms
+step:478/1670 train_time:45190ms step_avg:94.54ms
+step:479/1670 train_time:45283ms step_avg:94.54ms
+step:480/1670 train_time:45376ms step_avg:94.53ms
+step:481/1670 train_time:45468ms step_avg:94.53ms
+step:482/1670 train_time:45561ms step_avg:94.52ms
+step:483/1670 train_time:45654ms step_avg:94.52ms
+step:484/1670 train_time:45746ms step_avg:94.52ms
+step:485/1670 train_time:45838ms step_avg:94.51ms
+step:486/1670 train_time:45931ms step_avg:94.51ms
+step:487/1670 train_time:46024ms step_avg:94.51ms
+step:488/1670 train_time:46117ms step_avg:94.50ms
+step:489/1670 train_time:46209ms step_avg:94.50ms
+step:490/1670 train_time:46301ms step_avg:94.49ms
+step:491/1670 train_time:46394ms step_avg:94.49ms
+step:492/1670 train_time:46487ms step_avg:94.49ms
+step:493/1670 train_time:46580ms step_avg:94.48ms
+step:494/1670 train_time:46672ms step_avg:94.48ms
+step:495/1670 train_time:46765ms step_avg:94.47ms
+step:496/1670 train_time:46858ms step_avg:94.47ms
+step:497/1670 train_time:46951ms step_avg:94.47ms
+step:498/1670 train_time:47043ms step_avg:94.46ms
+step:499/1670 train_time:47135ms step_avg:94.46ms
+step:500/1670 train_time:47228ms step_avg:94.46ms
+step:500/1670 val_loss:3.7176 train_time:47317ms step_avg:94.63ms
+step:501/1670 train_time:47344ms step_avg:94.50ms
+step:502/1670 train_time:47417ms step_avg:94.46ms
+step:503/1670 train_time:47517ms step_avg:94.47ms
+step:504/1670 train_time:47612ms step_avg:94.47ms
+step:505/1670 train_time:47703ms step_avg:94.46ms
+step:506/1670 train_time:47795ms step_avg:94.46ms
+step:507/1670 train_time:47886ms step_avg:94.45ms
+step:508/1670 train_time:47977ms step_avg:94.44ms
+step:509/1670 train_time:48069ms step_avg:94.44ms
+step:510/1670 train_time:48161ms step_avg:94.43ms
+step:511/1670 train_time:48253ms step_avg:94.43ms
+step:512/1670 train_time:48345ms step_avg:94.42ms
+step:513/1670 train_time:48440ms step_avg:94.42ms
+step:514/1670 train_time:48535ms step_avg:94.43ms
+step:515/1670 train_time:48629ms step_avg:94.42ms
+step:516/1670 train_time:48721ms step_avg:94.42ms
+step:517/1670 train_time:48813ms step_avg:94.41ms
+step:518/1670 train_time:48906ms step_avg:94.41ms
+step:519/1670 train_time:48997ms step_avg:94.41ms
+step:520/1670 train_time:49088ms step_avg:94.40ms
+step:521/1670 train_time:49180ms step_avg:94.40ms
+step:522/1670 train_time:49272ms step_avg:94.39ms
+step:523/1670 train_time:49365ms step_avg:94.39ms
+step:524/1670 train_time:49458ms step_avg:94.39ms
+step:525/1670 train_time:49552ms step_avg:94.39ms
+step:526/1670 train_time:49645ms step_avg:94.38ms
+step:527/1670 train_time:49738ms step_avg:94.38ms
+step:528/1670 train_time:49830ms step_avg:94.38ms
+step:529/1670 train_time:49923ms step_avg:94.37ms
+step:530/1670 train_time:50014ms step_avg:94.37ms
+step:531/1670 train_time:50107ms step_avg:94.36ms
+step:532/1670 train_time:50199ms step_avg:94.36ms
+step:533/1670 train_time:50291ms step_avg:94.35ms
+step:534/1670 train_time:50384ms step_avg:94.35ms
+step:535/1670 train_time:50478ms step_avg:94.35ms
+step:536/1670 train_time:50571ms step_avg:94.35ms
+step:537/1670 train_time:50664ms step_avg:94.35ms
+step:538/1670 train_time:50757ms step_avg:94.34ms
+step:539/1670 train_time:50850ms step_avg:94.34ms
+step:540/1670 train_time:50942ms step_avg:94.34ms
+step:541/1670 train_time:51035ms step_avg:94.33ms
+step:542/1670 train_time:51128ms step_avg:94.33ms
+step:543/1670 train_time:51219ms step_avg:94.33ms
+step:544/1670 train_time:51312ms step_avg:94.32ms
+step:545/1670 train_time:51405ms step_avg:94.32ms
+step:546/1670 train_time:51498ms step_avg:94.32ms
+step:547/1670 train_time:51591ms step_avg:94.32ms
+step:548/1670 train_time:51684ms step_avg:94.31ms
+step:549/1670 train_time:51777ms step_avg:94.31ms
+step:550/1670 train_time:51870ms step_avg:94.31ms
+step:551/1670 train_time:51962ms step_avg:94.30ms
+step:552/1670 train_time:52054ms step_avg:94.30ms
+step:553/1670 train_time:52147ms step_avg:94.30ms
+step:554/1670 train_time:52239ms step_avg:94.29ms
+step:555/1670 train_time:52332ms step_avg:94.29ms
+step:556/1670 train_time:52425ms step_avg:94.29ms
+step:557/1670 train_time:52518ms step_avg:94.29ms
+step:558/1670 train_time:52713ms step_avg:94.47ms
+step:559/1670 train_time:52789ms step_avg:94.43ms
+step:560/1670 train_time:52881ms step_avg:94.43ms
+step:561/1670 train_time:52974ms step_avg:94.43ms
+step:562/1670 train_time:53066ms step_avg:94.42ms
+step:563/1670 train_time:53159ms step_avg:94.42ms
+step:564/1670 train_time:53252ms step_avg:94.42ms
+step:565/1670 train_time:53344ms step_avg:94.41ms
+step:566/1670 train_time:53437ms step_avg:94.41ms
+step:567/1670 train_time:53530ms step_avg:94.41ms
+step:568/1670 train_time:53627ms step_avg:94.41ms
+step:569/1670 train_time:53725ms step_avg:94.42ms
+step:570/1670 train_time:53819ms step_avg:94.42ms
+step:571/1670 train_time:53913ms step_avg:94.42ms
+step:572/1670 train_time:54006ms step_avg:94.42ms
+step:573/1670 train_time:54099ms step_avg:94.41ms
+step:574/1670 train_time:54192ms step_avg:94.41ms
+step:575/1670 train_time:54284ms step_avg:94.41ms
+step:576/1670 train_time:54377ms step_avg:94.41ms
+step:577/1670 train_time:54471ms step_avg:94.40ms
+step:578/1670 train_time:54564ms step_avg:94.40ms
+step:579/1670 train_time:54660ms step_avg:94.40ms
+step:580/1670 train_time:54755ms step_avg:94.41ms
+step:581/1670 train_time:54850ms step_avg:94.41ms
+step:582/1670 train_time:54945ms step_avg:94.41ms
+step:583/1670 train_time:55038ms step_avg:94.41ms
+step:584/1670 train_time:55133ms step_avg:94.41ms
+step:585/1670 train_time:55225ms step_avg:94.40ms
+step:586/1670 train_time:55318ms step_avg:94.40ms
+step:587/1670 train_time:55411ms step_avg:94.40ms
+step:588/1670 train_time:55504ms step_avg:94.39ms
+step:589/1670 train_time:55598ms step_avg:94.39ms
+step:590/1670 train_time:55692ms step_avg:94.39ms
+step:591/1670 train_time:55786ms step_avg:94.39ms
+step:592/1670 train_time:55881ms step_avg:94.39ms
+step:593/1670 train_time:55976ms step_avg:94.39ms
+step:594/1670 train_time:56070ms step_avg:94.39ms
+step:595/1670 train_time:56163ms step_avg:94.39ms
+step:596/1670 train_time:56256ms step_avg:94.39ms
+step:597/1670 train_time:56350ms step_avg:94.39ms
+step:598/1670 train_time:56443ms step_avg:94.39ms
+step:599/1670 train_time:56537ms step_avg:94.39ms
+step:600/1670 train_time:56631ms step_avg:94.39ms
+step:601/1670 train_time:56725ms step_avg:94.38ms
+step:602/1670 train_time:56820ms step_avg:94.38ms
+step:603/1670 train_time:56914ms step_avg:94.39ms
+step:604/1670 train_time:57008ms step_avg:94.38ms
+step:605/1670 train_time:57102ms step_avg:94.38ms
+step:606/1670 train_time:57195ms step_avg:94.38ms
+step:607/1670 train_time:57289ms step_avg:94.38ms
+step:608/1670 train_time:57382ms step_avg:94.38ms
+step:609/1670 train_time:57476ms step_avg:94.38ms
+step:610/1670 train_time:57569ms step_avg:94.38ms
+step:611/1670 train_time:57664ms step_avg:94.38ms
+step:612/1670 train_time:57758ms step_avg:94.38ms
+step:613/1670 train_time:57853ms step_avg:94.38ms
+step:614/1670 train_time:57948ms step_avg:94.38ms
+step:615/1670 train_time:58042ms step_avg:94.38ms
+step:616/1670 train_time:58135ms step_avg:94.38ms
+step:617/1670 train_time:58229ms step_avg:94.37ms
+step:618/1670 train_time:58321ms step_avg:94.37ms
+step:619/1670 train_time:58415ms step_avg:94.37ms
+step:620/1670 train_time:58509ms step_avg:94.37ms
+step:621/1670 train_time:58602ms step_avg:94.37ms
+step:622/1670 train_time:58696ms step_avg:94.37ms
+step:623/1670 train_time:58790ms step_avg:94.37ms
+step:624/1670 train_time:58885ms step_avg:94.37ms
+step:625/1670 train_time:58979ms step_avg:94.37ms
+step:625/1670 val_loss:3.6157 train_time:59071ms step_avg:94.51ms
+step:626/1670 train_time:59101ms step_avg:94.41ms
+step:627/1670 train_time:59176ms step_avg:94.38ms
+step:628/1670 train_time:59276ms step_avg:94.39ms
+step:629/1670 train_time:59371ms step_avg:94.39ms
+step:630/1670 train_time:59463ms step_avg:94.39ms
+step:631/1670 train_time:59556ms step_avg:94.38ms
+step:632/1670 train_time:59649ms step_avg:94.38ms
+step:633/1670 train_time:59742ms step_avg:94.38ms
+step:634/1670 train_time:59834ms step_avg:94.38ms
+step:635/1670 train_time:59927ms step_avg:94.37ms
+step:636/1670 train_time:60022ms step_avg:94.37ms
+step:637/1670 train_time:60118ms step_avg:94.38ms
+step:638/1670 train_time:60213ms step_avg:94.38ms
+step:639/1670 train_time:60659ms step_avg:94.93ms
+step:640/1670 train_time:60727ms step_avg:94.89ms
+step:641/1670 train_time:60820ms step_avg:94.88ms
+step:642/1670 train_time:60912ms step_avg:94.88ms
+step:643/1670 train_time:61005ms step_avg:94.88ms
+step:644/1670 train_time:61098ms step_avg:94.87ms
+step:645/1670 train_time:61191ms step_avg:94.87ms
+step:646/1670 train_time:61283ms step_avg:94.87ms
+step:647/1670 train_time:61376ms step_avg:94.86ms
+step:648/1670 train_time:61469ms step_avg:94.86ms
+step:649/1670 train_time:61565ms step_avg:94.86ms
+step:650/1670 train_time:61661ms step_avg:94.86ms
+step:651/1670 train_time:61756ms step_avg:94.86ms
+step:652/1670 train_time:61850ms step_avg:94.86ms
+step:653/1670 train_time:61943ms step_avg:94.86ms
+step:654/1670 train_time:62035ms step_avg:94.86ms
+step:655/1670 train_time:62128ms step_avg:94.85ms
+step:656/1670 train_time:62222ms step_avg:94.85ms
+step:657/1670 train_time:62315ms step_avg:94.85ms
+step:658/1670 train_time:62408ms step_avg:94.84ms
+step:659/1670 train_time:62502ms step_avg:94.84ms
+step:660/1670 train_time:62597ms step_avg:94.84ms
+step:661/1670 train_time:62692ms step_avg:94.84ms
+step:662/1670 train_time:62786ms step_avg:94.84ms
+step:663/1670 train_time:62880ms step_avg:94.84ms
+step:664/1670 train_time:62973ms step_avg:94.84ms
+step:665/1670 train_time:63066ms step_avg:94.84ms
+step:666/1670 train_time:63160ms step_avg:94.83ms
+step:667/1670 train_time:63254ms step_avg:94.83ms
+step:668/1670 train_time:63347ms step_avg:94.83ms
+step:669/1670 train_time:63441ms step_avg:94.83ms
+step:670/1670 train_time:63535ms step_avg:94.83ms
+step:671/1670 train_time:63629ms step_avg:94.83ms
+step:672/1670 train_time:63724ms step_avg:94.83ms
+step:673/1670 train_time:63818ms step_avg:94.83ms
+step:674/1670 train_time:63911ms step_avg:94.82ms
+step:675/1670 train_time:64004ms step_avg:94.82ms
+step:676/1670 train_time:64098ms step_avg:94.82ms
+step:677/1670 train_time:64191ms step_avg:94.82ms
+step:678/1670 train_time:64285ms step_avg:94.82ms
+step:679/1670 train_time:64379ms step_avg:94.82ms
+step:680/1670 train_time:64473ms step_avg:94.81ms
+step:681/1670 train_time:64567ms step_avg:94.81ms
+step:682/1670 train_time:64661ms step_avg:94.81ms
+step:683/1670 train_time:64755ms step_avg:94.81ms
+step:684/1670 train_time:64849ms step_avg:94.81ms
+step:685/1670 train_time:64943ms step_avg:94.81ms
+step:686/1670 train_time:65036ms step_avg:94.80ms
+step:687/1670 train_time:65129ms step_avg:94.80ms
+step:688/1670 train_time:65223ms step_avg:94.80ms
+step:689/1670 train_time:65317ms step_avg:94.80ms
+step:690/1670 train_time:65410ms step_avg:94.80ms
+step:691/1670 train_time:65504ms step_avg:94.80ms
+step:692/1670 train_time:65598ms step_avg:94.79ms
+step:693/1670 train_time:65691ms step_avg:94.79ms
+step:694/1670 train_time:65785ms step_avg:94.79ms
+step:695/1670 train_time:65879ms step_avg:94.79ms
+step:696/1670 train_time:65974ms step_avg:94.79ms
+step:697/1670 train_time:66067ms step_avg:94.79ms
+step:698/1670 train_time:66160ms step_avg:94.79ms
+step:699/1670 train_time:66255ms step_avg:94.79ms
+step:700/1670 train_time:66349ms step_avg:94.78ms
+step:701/1670 train_time:66443ms step_avg:94.78ms
+step:702/1670 train_time:66535ms step_avg:94.78ms
+step:703/1670 train_time:66628ms step_avg:94.78ms
+step:704/1670 train_time:66723ms step_avg:94.78ms
+step:705/1670 train_time:66817ms step_avg:94.78ms
+step:706/1670 train_time:66910ms step_avg:94.77ms
+step:707/1670 train_time:67005ms step_avg:94.77ms
+step:708/1670 train_time:67098ms step_avg:94.77ms
+step:709/1670 train_time:67191ms step_avg:94.77ms
+step:710/1670 train_time:67285ms step_avg:94.77ms
+step:711/1670 train_time:67379ms step_avg:94.77ms
+step:712/1670 train_time:67472ms step_avg:94.76ms
+step:713/1670 train_time:67566ms step_avg:94.76ms
+step:714/1670 train_time:67660ms step_avg:94.76ms
+step:715/1670 train_time:67754ms step_avg:94.76ms
+step:716/1670 train_time:67848ms step_avg:94.76ms
+step:717/1670 train_time:67942ms step_avg:94.76ms
+step:718/1670 train_time:68037ms step_avg:94.76ms
+step:719/1670 train_time:68130ms step_avg:94.76ms
+step:720/1670 train_time:68223ms step_avg:94.75ms
+step:721/1670 train_time:68317ms step_avg:94.75ms
+step:722/1670 train_time:68411ms step_avg:94.75ms
+step:723/1670 train_time:68505ms step_avg:94.75ms
+step:724/1670 train_time:68599ms step_avg:94.75ms
+step:725/1670 train_time:68694ms step_avg:94.75ms
+step:726/1670 train_time:68788ms step_avg:94.75ms
+step:727/1670 train_time:68882ms step_avg:94.75ms
+step:728/1670 train_time:68976ms step_avg:94.75ms
+step:729/1670 train_time:69069ms step_avg:94.75ms
+step:730/1670 train_time:69163ms step_avg:94.74ms
+step:731/1670 train_time:69257ms step_avg:94.74ms
+step:732/1670 train_time:69350ms step_avg:94.74ms
+step:733/1670 train_time:69444ms step_avg:94.74ms
+step:734/1670 train_time:69538ms step_avg:94.74ms
+step:735/1670 train_time:69631ms step_avg:94.74ms
+step:736/1670 train_time:69725ms step_avg:94.73ms
+step:737/1670 train_time:69819ms step_avg:94.73ms
+step:738/1670 train_time:69914ms step_avg:94.73ms
+step:739/1670 train_time:70008ms step_avg:94.73ms
+step:740/1670 train_time:70101ms step_avg:94.73ms
+step:741/1670 train_time:70195ms step_avg:94.73ms
+step:742/1670 train_time:70288ms step_avg:94.73ms
+step:743/1670 train_time:70383ms step_avg:94.73ms
+step:744/1670 train_time:70477ms step_avg:94.73ms
+step:745/1670 train_time:70570ms step_avg:94.73ms
+step:746/1670 train_time:70663ms step_avg:94.72ms
+step:747/1670 train_time:70757ms step_avg:94.72ms
+step:748/1670 train_time:70851ms step_avg:94.72ms
+step:749/1670 train_time:70945ms step_avg:94.72ms
+step:750/1670 train_time:71039ms step_avg:94.72ms
+step:750/1670 val_loss:3.5631 train_time:71130ms step_avg:94.84ms
+step:751/1670 train_time:71157ms step_avg:94.75ms
+step:752/1670 train_time:71233ms step_avg:94.72ms
+step:753/1670 train_time:71333ms step_avg:94.73ms
+step:754/1670 train_time:71428ms step_avg:94.73ms
+step:755/1670 train_time:71521ms step_avg:94.73ms
+step:756/1670 train_time:71614ms step_avg:94.73ms
+step:757/1670 train_time:71707ms step_avg:94.73ms
+step:758/1670 train_time:71800ms step_avg:94.72ms
+step:759/1670 train_time:71892ms step_avg:94.72ms
+step:760/1670 train_time:71985ms step_avg:94.72ms
+step:761/1670 train_time:72078ms step_avg:94.72ms
+step:762/1670 train_time:72173ms step_avg:94.72ms
+step:763/1670 train_time:72271ms step_avg:94.72ms
+step:764/1670 train_time:72366ms step_avg:94.72ms
+step:765/1670 train_time:72461ms step_avg:94.72ms
+step:766/1670 train_time:72555ms step_avg:94.72ms
+step:767/1670 train_time:72648ms step_avg:94.72ms
+step:768/1670 train_time:72741ms step_avg:94.72ms
+step:769/1670 train_time:72834ms step_avg:94.71ms
+step:770/1670 train_time:72927ms step_avg:94.71ms
+step:771/1670 train_time:73020ms step_avg:94.71ms
+step:772/1670 train_time:73113ms step_avg:94.71ms
+step:773/1670 train_time:73207ms step_avg:94.71ms
+step:774/1670 train_time:73303ms step_avg:94.71ms
+step:775/1670 train_time:73398ms step_avg:94.71ms
+step:776/1670 train_time:73492ms step_avg:94.71ms
+step:777/1670 train_time:73586ms step_avg:94.70ms
+step:778/1670 train_time:73679ms step_avg:94.70ms
+step:779/1670 train_time:73772ms step_avg:94.70ms
+step:780/1670 train_time:73865ms step_avg:94.70ms
+step:781/1670 train_time:73958ms step_avg:94.70ms
+step:782/1670 train_time:74052ms step_avg:94.70ms
+step:783/1670 train_time:74145ms step_avg:94.69ms
+step:784/1670 train_time:74240ms step_avg:94.69ms
+step:785/1670 train_time:74334ms step_avg:94.69ms
+step:786/1670 train_time:74428ms step_avg:94.69ms
+step:787/1670 train_time:74522ms step_avg:94.69ms
+step:788/1670 train_time:74616ms step_avg:94.69ms
+step:789/1670 train_time:74709ms step_avg:94.69ms
+step:790/1670 train_time:74802ms step_avg:94.69ms
+step:791/1670 train_time:74896ms step_avg:94.68ms
+step:792/1670 train_time:74989ms step_avg:94.68ms
+step:793/1670 train_time:75083ms step_avg:94.68ms
+step:794/1670 train_time:75177ms step_avg:94.68ms
+step:795/1670 train_time:75271ms step_avg:94.68ms
+step:796/1670 train_time:75365ms step_avg:94.68ms
+step:797/1670 train_time:75461ms step_avg:94.68ms
+step:798/1670 train_time:75554ms step_avg:94.68ms
+step:799/1670 train_time:75648ms step_avg:94.68ms
+step:800/1670 train_time:75741ms step_avg:94.68ms
+step:801/1670 train_time:75835ms step_avg:94.67ms
+step:802/1670 train_time:75928ms step_avg:94.67ms
+step:803/1670 train_time:76021ms step_avg:94.67ms
+step:804/1670 train_time:76114ms step_avg:94.67ms
+step:805/1670 train_time:76208ms step_avg:94.67ms
+step:806/1670 train_time:76301ms step_avg:94.67ms
+step:807/1670 train_time:76395ms step_avg:94.67ms
+step:808/1670 train_time:76490ms step_avg:94.67ms
+step:809/1670 train_time:76584ms step_avg:94.67ms
+step:810/1670 train_time:76678ms step_avg:94.66ms
+step:811/1670 train_time:76771ms step_avg:94.66ms
+step:812/1670 train_time:76865ms step_avg:94.66ms
+step:813/1670 train_time:76958ms step_avg:94.66ms
+step:814/1670 train_time:77052ms step_avg:94.66ms
+step:815/1670 train_time:77145ms step_avg:94.66ms
+step:816/1670 train_time:77239ms step_avg:94.66ms
+step:817/1670 train_time:77333ms step_avg:94.66ms
+step:818/1670 train_time:77427ms step_avg:94.65ms
+step:819/1670 train_time:77522ms step_avg:94.65ms
+step:820/1670 train_time:77616ms step_avg:94.65ms
+step:821/1670 train_time:77709ms step_avg:94.65ms
+step:822/1670 train_time:77802ms step_avg:94.65ms
+step:823/1670 train_time:77896ms step_avg:94.65ms
+step:824/1670 train_time:77990ms step_avg:94.65ms
+step:825/1670 train_time:78085ms step_avg:94.65ms
+step:826/1670 train_time:78179ms step_avg:94.65ms
+step:827/1670 train_time:78272ms step_avg:94.65ms
+step:828/1670 train_time:78367ms step_avg:94.65ms
+step:829/1670 train_time:78461ms step_avg:94.65ms
+step:830/1670 train_time:78555ms step_avg:94.64ms
+step:831/1670 train_time:78648ms step_avg:94.64ms
+step:832/1670 train_time:78742ms step_avg:94.64ms
+step:833/1670 train_time:78836ms step_avg:94.64ms
+step:834/1670 train_time:78929ms step_avg:94.64ms
+step:835/1670 train_time:79022ms step_avg:94.64ms
+step:836/1670 train_time:79116ms step_avg:94.64ms
+step:837/1670 train_time:79210ms step_avg:94.64ms
+step:838/1670 train_time:79303ms step_avg:94.63ms
+step:839/1670 train_time:79397ms step_avg:94.63ms
+step:840/1670 train_time:79492ms step_avg:94.63ms
+step:841/1670 train_time:79586ms step_avg:94.63ms
+step:842/1670 train_time:79679ms step_avg:94.63ms
+step:843/1670 train_time:79773ms step_avg:94.63ms
+step:844/1670 train_time:79866ms step_avg:94.63ms
+step:845/1670 train_time:79961ms step_avg:94.63ms
+step:846/1670 train_time:80055ms step_avg:94.63ms
+step:847/1670 train_time:80148ms step_avg:94.63ms
+step:848/1670 train_time:80241ms step_avg:94.62ms
+step:849/1670 train_time:80335ms step_avg:94.62ms
+step:850/1670 train_time:80429ms step_avg:94.62ms
+step:851/1670 train_time:80850ms step_avg:95.01ms
+step:852/1670 train_time:80953ms step_avg:95.02ms
+step:853/1670 train_time:81046ms step_avg:95.01ms
+step:854/1670 train_time:81139ms step_avg:95.01ms
+step:855/1670 train_time:81231ms step_avg:95.01ms
+step:856/1670 train_time:81325ms step_avg:95.01ms
+step:857/1670 train_time:81417ms step_avg:95.00ms
+step:858/1670 train_time:81509ms step_avg:95.00ms
+step:859/1670 train_time:81602ms step_avg:95.00ms
+step:860/1670 train_time:81695ms step_avg:94.99ms
+step:861/1670 train_time:81792ms step_avg:95.00ms
+step:862/1670 train_time:81891ms step_avg:95.00ms
+step:863/1670 train_time:81987ms step_avg:95.00ms
+step:864/1670 train_time:82081ms step_avg:95.00ms
+step:865/1670 train_time:82174ms step_avg:95.00ms
+step:866/1670 train_time:82266ms step_avg:95.00ms
+step:867/1670 train_time:82360ms step_avg:94.99ms
+step:868/1670 train_time:82453ms step_avg:94.99ms
+step:869/1670 train_time:82546ms step_avg:94.99ms
+step:870/1670 train_time:82639ms step_avg:94.99ms
+step:871/1670 train_time:82733ms step_avg:94.99ms
+step:872/1670 train_time:82828ms step_avg:94.99ms
+step:873/1670 train_time:82923ms step_avg:94.99ms
+step:874/1670 train_time:83019ms step_avg:94.99ms
+step:875/1670 train_time:83113ms step_avg:94.99ms
+step:875/1670 val_loss:3.5188 train_time:83204ms step_avg:95.09ms
+step:876/1670 train_time:83231ms step_avg:95.01ms
+step:877/1670 train_time:83307ms step_avg:94.99ms
+step:878/1670 train_time:83407ms step_avg:95.00ms
+step:879/1670 train_time:83501ms step_avg:95.00ms
+step:880/1670 train_time:83595ms step_avg:94.99ms
+step:881/1670 train_time:83687ms step_avg:94.99ms
+step:882/1670 train_time:83780ms step_avg:94.99ms
+step:883/1670 train_time:83873ms step_avg:94.99ms
+step:884/1670 train_time:83965ms step_avg:94.98ms
+step:885/1670 train_time:84058ms step_avg:94.98ms
+step:886/1670 train_time:84151ms step_avg:94.98ms
+step:887/1670 train_time:84245ms step_avg:94.98ms
+step:888/1670 train_time:84342ms step_avg:94.98ms
+step:889/1670 train_time:84438ms step_avg:94.98ms
+step:890/1670 train_time:84532ms step_avg:94.98ms
+step:891/1670 train_time:84626ms step_avg:94.98ms
+step:892/1670 train_time:84719ms step_avg:94.98ms
+step:893/1670 train_time:84811ms step_avg:94.97ms
+step:894/1670 train_time:84904ms step_avg:94.97ms
+step:895/1670 train_time:84998ms step_avg:94.97ms
+step:896/1670 train_time:85092ms step_avg:94.97ms
+step:897/1670 train_time:85185ms step_avg:94.97ms
+step:898/1670 train_time:85279ms step_avg:94.97ms
+step:899/1670 train_time:85373ms step_avg:94.96ms
+step:900/1670 train_time:85468ms step_avg:94.96ms
+step:901/1670 train_time:85562ms step_avg:94.96ms
+step:902/1670 train_time:85655ms step_avg:94.96ms
+step:903/1670 train_time:85748ms step_avg:94.96ms
+step:904/1670 train_time:85841ms step_avg:94.96ms
+step:905/1670 train_time:85935ms step_avg:94.96ms
+step:906/1670 train_time:86028ms step_avg:94.95ms
+step:907/1670 train_time:86122ms step_avg:94.95ms
+step:908/1670 train_time:86215ms step_avg:94.95ms
+step:909/1670 train_time:86309ms step_avg:94.95ms
+step:910/1670 train_time:86404ms step_avg:94.95ms
+step:911/1670 train_time:86499ms step_avg:94.95ms
+step:912/1670 train_time:86592ms step_avg:94.95ms
+step:913/1670 train_time:86686ms step_avg:94.95ms
+step:914/1670 train_time:86778ms step_avg:94.94ms
+step:915/1670 train_time:86871ms step_avg:94.94ms
+step:916/1670 train_time:86964ms step_avg:94.94ms
+step:917/1670 train_time:87058ms step_avg:94.94ms
+step:918/1670 train_time:87151ms step_avg:94.94ms
+step:919/1670 train_time:87245ms step_avg:94.93ms
+step:920/1670 train_time:87339ms step_avg:94.93ms
+step:921/1670 train_time:87434ms step_avg:94.93ms
+step:922/1670 train_time:87528ms step_avg:94.93ms
+step:923/1670 train_time:87622ms step_avg:94.93ms
+step:924/1670 train_time:87716ms step_avg:94.93ms
+step:925/1670 train_time:87809ms step_avg:94.93ms
+step:926/1670 train_time:87902ms step_avg:94.93ms
+step:927/1670 train_time:87996ms step_avg:94.93ms
+step:928/1670 train_time:88090ms step_avg:94.92ms
+step:929/1670 train_time:88184ms step_avg:94.92ms
+step:930/1670 train_time:88277ms step_avg:94.92ms
+step:931/1670 train_time:88372ms step_avg:94.92ms
+step:932/1670 train_time:88466ms step_avg:94.92ms
+step:933/1670 train_time:88560ms step_avg:94.92ms
+step:934/1670 train_time:88654ms step_avg:94.92ms
+step:935/1670 train_time:88747ms step_avg:94.92ms
+step:936/1670 train_time:88840ms step_avg:94.92ms
+step:937/1670 train_time:88934ms step_avg:94.91ms
+step:938/1670 train_time:89028ms step_avg:94.91ms
+step:939/1670 train_time:89121ms step_avg:94.91ms
+step:940/1670 train_time:89215ms step_avg:94.91ms
+step:941/1670 train_time:89308ms step_avg:94.91ms
+step:942/1670 train_time:89403ms step_avg:94.91ms
+step:943/1670 train_time:89497ms step_avg:94.91ms
+step:944/1670 train_time:89592ms step_avg:94.91ms
+step:945/1670 train_time:89685ms step_avg:94.90ms
+step:946/1670 train_time:89779ms step_avg:94.90ms
+step:947/1670 train_time:89873ms step_avg:94.90ms
+step:948/1670 train_time:89967ms step_avg:94.90ms
+step:949/1670 train_time:90061ms step_avg:94.90ms
+step:950/1670 train_time:90154ms step_avg:94.90ms
+step:951/1670 train_time:90247ms step_avg:94.90ms
+step:952/1670 train_time:90341ms step_avg:94.90ms
+step:953/1670 train_time:90435ms step_avg:94.90ms
+step:954/1670 train_time:90529ms step_avg:94.89ms
+step:955/1670 train_time:90623ms step_avg:94.89ms
+step:956/1670 train_time:90718ms step_avg:94.89ms
+step:957/1670 train_time:90811ms step_avg:94.89ms
+step:958/1670 train_time:90905ms step_avg:94.89ms
+step:959/1670 train_time:90998ms step_avg:94.89ms
+step:960/1670 train_time:91092ms step_avg:94.89ms
+step:961/1670 train_time:91185ms step_avg:94.89ms
+step:962/1670 train_time:91278ms step_avg:94.88ms
+step:963/1670 train_time:91372ms step_avg:94.88ms
+step:964/1670 train_time:91466ms step_avg:94.88ms
+step:965/1670 train_time:91560ms step_avg:94.88ms
+step:966/1670 train_time:91653ms step_avg:94.88ms
+step:967/1670 train_time:91747ms step_avg:94.88ms
+step:968/1670 train_time:91840ms step_avg:94.88ms
+step:969/1670 train_time:91935ms step_avg:94.88ms
+step:970/1670 train_time:92028ms step_avg:94.87ms
+step:971/1670 train_time:92122ms step_avg:94.87ms
+step:972/1670 train_time:92214ms step_avg:94.87ms
+step:973/1670 train_time:92308ms step_avg:94.87ms
+step:974/1670 train_time:92401ms step_avg:94.87ms
+step:975/1670 train_time:92495ms step_avg:94.87ms
+step:976/1670 train_time:92590ms step_avg:94.87ms
+step:977/1670 train_time:92683ms step_avg:94.87ms
+step:978/1670 train_time:92777ms step_avg:94.86ms
+step:979/1670 train_time:92872ms step_avg:94.86ms
+step:980/1670 train_time:92965ms step_avg:94.86ms
+step:981/1670 train_time:93058ms step_avg:94.86ms
+step:982/1670 train_time:93152ms step_avg:94.86ms
+step:983/1670 train_time:93247ms step_avg:94.86ms
+step:984/1670 train_time:93340ms step_avg:94.86ms
+step:985/1670 train_time:93434ms step_avg:94.86ms
+step:986/1670 train_time:93528ms step_avg:94.86ms
+step:987/1670 train_time:93621ms step_avg:94.85ms
+step:988/1670 train_time:93715ms step_avg:94.85ms
+step:989/1670 train_time:93808ms step_avg:94.85ms
+step:990/1670 train_time:93902ms step_avg:94.85ms
+step:991/1670 train_time:93995ms step_avg:94.85ms
+step:992/1670 train_time:94089ms step_avg:94.85ms
+step:993/1670 train_time:94182ms step_avg:94.85ms
+step:994/1670 train_time:94276ms step_avg:94.84ms
+step:995/1670 train_time:94370ms step_avg:94.84ms
+step:996/1670 train_time:94463ms step_avg:94.84ms
+step:997/1670 train_time:94557ms step_avg:94.84ms
+step:998/1670 train_time:94651ms step_avg:94.84ms
+step:999/1670 train_time:94745ms step_avg:94.84ms
+step:1000/1670 train_time:94839ms step_avg:94.84ms
+step:1000/1670 val_loss:3.4686 train_time:94930ms step_avg:94.93ms
+step:1001/1670 train_time:94957ms step_avg:94.86ms
+step:1002/1670 train_time:95032ms step_avg:94.84ms
+step:1003/1670 train_time:95130ms step_avg:94.85ms
+step:1004/1670 train_time:95224ms step_avg:94.84ms
+step:1005/1670 train_time:95316ms step_avg:94.84ms
+step:1006/1670 train_time:95409ms step_avg:94.84ms
+step:1007/1670 train_time:95502ms step_avg:94.84ms
+step:1008/1670 train_time:95594ms step_avg:94.84ms
+step:1009/1670 train_time:95687ms step_avg:94.83ms
+step:1010/1670 train_time:95780ms step_avg:94.83ms
+step:1011/1670 train_time:95873ms step_avg:94.83ms
+step:1012/1670 train_time:95970ms step_avg:94.83ms
+step:1013/1670 train_time:96066ms step_avg:94.83ms
+step:1014/1670 train_time:96162ms step_avg:94.83ms
+step:1015/1670 train_time:96256ms step_avg:94.83ms
+step:1016/1670 train_time:96349ms step_avg:94.83ms
+step:1017/1670 train_time:96442ms step_avg:94.83ms
+step:1018/1670 train_time:96535ms step_avg:94.83ms
+step:1019/1670 train_time:96629ms step_avg:94.83ms
+step:1020/1670 train_time:96722ms step_avg:94.83ms
+step:1021/1670 train_time:96814ms step_avg:94.82ms
+step:1022/1670 train_time:96908ms step_avg:94.82ms
+step:1023/1670 train_time:97003ms step_avg:94.82ms
+step:1024/1670 train_time:97099ms step_avg:94.82ms
+step:1025/1670 train_time:97193ms step_avg:94.82ms
+step:1026/1670 train_time:97288ms step_avg:94.82ms
+step:1027/1670 train_time:97382ms step_avg:94.82ms
+step:1028/1670 train_time:97476ms step_avg:94.82ms
+step:1029/1670 train_time:97568ms step_avg:94.82ms
+step:1030/1670 train_time:97662ms step_avg:94.82ms
+step:1031/1670 train_time:97755ms step_avg:94.82ms
+step:1032/1670 train_time:97848ms step_avg:94.81ms
+step:1033/1670 train_time:97942ms step_avg:94.81ms
+step:1034/1670 train_time:98036ms step_avg:94.81ms
+step:1035/1670 train_time:98131ms step_avg:94.81ms
+step:1036/1670 train_time:98225ms step_avg:94.81ms
+step:1037/1670 train_time:98319ms step_avg:94.81ms
+step:1038/1670 train_time:98413ms step_avg:94.81ms
+step:1039/1670 train_time:98507ms step_avg:94.81ms
+step:1040/1670 train_time:98600ms step_avg:94.81ms
+step:1041/1670 train_time:98693ms step_avg:94.81ms
+step:1042/1670 train_time:98787ms step_avg:94.81ms
+step:1043/1670 train_time:98881ms step_avg:94.80ms
+step:1044/1670 train_time:98975ms step_avg:94.80ms
+step:1045/1670 train_time:99069ms step_avg:94.80ms
+step:1046/1670 train_time:99163ms step_avg:94.80ms
+step:1047/1670 train_time:99257ms step_avg:94.80ms
+step:1048/1670 train_time:99351ms step_avg:94.80ms
+step:1049/1670 train_time:99445ms step_avg:94.80ms
+step:1050/1670 train_time:99539ms step_avg:94.80ms
+step:1051/1670 train_time:99632ms step_avg:94.80ms
+step:1052/1670 train_time:99726ms step_avg:94.80ms
+step:1053/1670 train_time:99820ms step_avg:94.80ms
+step:1054/1670 train_time:99914ms step_avg:94.80ms
+step:1055/1670 train_time:100008ms step_avg:94.79ms
+step:1056/1670 train_time:100102ms step_avg:94.79ms
+step:1057/1670 train_time:100196ms step_avg:94.79ms
+step:1058/1670 train_time:100290ms step_avg:94.79ms
+step:1059/1670 train_time:100384ms step_avg:94.79ms
+step:1060/1670 train_time:100478ms step_avg:94.79ms
+step:1061/1670 train_time:100571ms step_avg:94.79ms
+step:1062/1670 train_time:101016ms step_avg:95.12ms
+step:1063/1670 train_time:101091ms step_avg:95.10ms
+step:1064/1670 train_time:101183ms step_avg:95.10ms
+step:1065/1670 train_time:101276ms step_avg:95.09ms
+step:1066/1670 train_time:101368ms step_avg:95.09ms
+step:1067/1670 train_time:101461ms step_avg:95.09ms
+step:1068/1670 train_time:101553ms step_avg:95.09ms
+step:1069/1670 train_time:101645ms step_avg:95.08ms
+step:1070/1670 train_time:101738ms step_avg:95.08ms
+step:1071/1670 train_time:101831ms step_avg:95.08ms
+step:1072/1670 train_time:101925ms step_avg:95.08ms
+step:1073/1670 train_time:102024ms step_avg:95.08ms
+step:1074/1670 train_time:102121ms step_avg:95.08ms
+step:1075/1670 train_time:102215ms step_avg:95.08ms
+step:1076/1670 train_time:102310ms step_avg:95.08ms
+step:1077/1670 train_time:102403ms step_avg:95.08ms
+step:1078/1670 train_time:102496ms step_avg:95.08ms
+step:1079/1670 train_time:102589ms step_avg:95.08ms
+step:1080/1670 train_time:102682ms step_avg:95.08ms
+step:1081/1670 train_time:102775ms step_avg:95.07ms
+step:1082/1670 train_time:102868ms step_avg:95.07ms
+step:1083/1670 train_time:102963ms step_avg:95.07ms
+step:1084/1670 train_time:103059ms step_avg:95.07ms
+step:1085/1670 train_time:103154ms step_avg:95.07ms
+step:1086/1670 train_time:103249ms step_avg:95.07ms
+step:1087/1670 train_time:103343ms step_avg:95.07ms
+step:1088/1670 train_time:103436ms step_avg:95.07ms
+step:1089/1670 train_time:103530ms step_avg:95.07ms
+step:1090/1670 train_time:103623ms step_avg:95.07ms
+step:1091/1670 train_time:103716ms step_avg:95.06ms
+step:1092/1670 train_time:103809ms step_avg:95.06ms
+step:1093/1670 train_time:103903ms step_avg:95.06ms
+step:1094/1670 train_time:103997ms step_avg:95.06ms
+step:1095/1670 train_time:104091ms step_avg:95.06ms
+step:1096/1670 train_time:104186ms step_avg:95.06ms
+step:1097/1670 train_time:104280ms step_avg:95.06ms
+step:1098/1670 train_time:104373ms step_avg:95.06ms
+step:1099/1670 train_time:104468ms step_avg:95.06ms
+step:1100/1670 train_time:104562ms step_avg:95.06ms
+step:1101/1670 train_time:104655ms step_avg:95.05ms
+step:1102/1670 train_time:104749ms step_avg:95.05ms
+step:1103/1670 train_time:104842ms step_avg:95.05ms
+step:1104/1670 train_time:104935ms step_avg:95.05ms
+step:1105/1670 train_time:105030ms step_avg:95.05ms
+step:1106/1670 train_time:105124ms step_avg:95.05ms
+step:1107/1670 train_time:105218ms step_avg:95.05ms
+step:1108/1670 train_time:105311ms step_avg:95.05ms
+step:1109/1670 train_time:105405ms step_avg:95.05ms
+step:1110/1670 train_time:105499ms step_avg:95.04ms
+step:1111/1670 train_time:105593ms step_avg:95.04ms
+step:1112/1670 train_time:105686ms step_avg:95.04ms
+step:1113/1670 train_time:105779ms step_avg:95.04ms
+step:1114/1670 train_time:105872ms step_avg:95.04ms
+step:1115/1670 train_time:106070ms step_avg:95.13ms
+step:1116/1670 train_time:106146ms step_avg:95.11ms
+step:1117/1670 train_time:106239ms step_avg:95.11ms
+step:1118/1670 train_time:106333ms step_avg:95.11ms
+step:1119/1670 train_time:106426ms step_avg:95.11ms
+step:1120/1670 train_time:106520ms step_avg:95.11ms
+step:1121/1670 train_time:106613ms step_avg:95.11ms
+step:1122/1670 train_time:106708ms step_avg:95.10ms
+step:1123/1670 train_time:106801ms step_avg:95.10ms
+step:1124/1670 train_time:106895ms step_avg:95.10ms
+step:1125/1670 train_time:106996ms step_avg:95.11ms
+step:1125/1670 val_loss:3.4159 train_time:107091ms step_avg:95.19ms
+step:1126/1670 train_time:107119ms step_avg:95.13ms
+step:1127/1670 train_time:107195ms step_avg:95.11ms
+step:1128/1670 train_time:107294ms step_avg:95.12ms
+step:1129/1670 train_time:107387ms step_avg:95.12ms
+step:1130/1670 train_time:107480ms step_avg:95.12ms
+step:1131/1670 train_time:107574ms step_avg:95.11ms
+step:1132/1670 train_time:107667ms step_avg:95.11ms
+step:1133/1670 train_time:107761ms step_avg:95.11ms
+step:1134/1670 train_time:107855ms step_avg:95.11ms
+step:1135/1670 train_time:107948ms step_avg:95.11ms
+step:1136/1670 train_time:108045ms step_avg:95.11ms
+step:1137/1670 train_time:108144ms step_avg:95.11ms
+step:1138/1670 train_time:108241ms step_avg:95.11ms
+step:1139/1670 train_time:108337ms step_avg:95.12ms
+step:1140/1670 train_time:108430ms step_avg:95.11ms
+step:1141/1670 train_time:108524ms step_avg:95.11ms
+step:1142/1670 train_time:108618ms step_avg:95.11ms
+step:1143/1670 train_time:108711ms step_avg:95.11ms
+step:1144/1670 train_time:108805ms step_avg:95.11ms
+step:1145/1670 train_time:108899ms step_avg:95.11ms
+step:1146/1670 train_time:108993ms step_avg:95.11ms
+step:1147/1670 train_time:109090ms step_avg:95.11ms
+step:1148/1670 train_time:109186ms step_avg:95.11ms
+step:1149/1670 train_time:109281ms step_avg:95.11ms
+step:1150/1670 train_time:109376ms step_avg:95.11ms
+step:1151/1670 train_time:109470ms step_avg:95.11ms
+step:1152/1670 train_time:109564ms step_avg:95.11ms
+step:1153/1670 train_time:109659ms step_avg:95.11ms
+step:1154/1670 train_time:109752ms step_avg:95.11ms
+step:1155/1670 train_time:109846ms step_avg:95.10ms
+step:1156/1670 train_time:109941ms step_avg:95.10ms
+step:1157/1670 train_time:110034ms step_avg:95.10ms
+step:1158/1670 train_time:110129ms step_avg:95.10ms
+step:1159/1670 train_time:110225ms step_avg:95.10ms
+step:1160/1670 train_time:110321ms step_avg:95.10ms
+step:1161/1670 train_time:110415ms step_avg:95.10ms
+step:1162/1670 train_time:110508ms step_avg:95.10ms
+step:1163/1670 train_time:110603ms step_avg:95.10ms
+step:1164/1670 train_time:110699ms step_avg:95.10ms
+step:1165/1670 train_time:110793ms step_avg:95.10ms
+step:1166/1670 train_time:110887ms step_avg:95.10ms
+step:1167/1670 train_time:110981ms step_avg:95.10ms
+step:1168/1670 train_time:111077ms step_avg:95.10ms
+step:1169/1670 train_time:111171ms step_avg:95.10ms
+step:1170/1670 train_time:111267ms step_avg:95.10ms
+step:1171/1670 train_time:111362ms step_avg:95.10ms
+step:1172/1670 train_time:111457ms step_avg:95.10ms
+step:1173/1670 train_time:111552ms step_avg:95.10ms
+step:1174/1670 train_time:111647ms step_avg:95.10ms
+step:1175/1670 train_time:111741ms step_avg:95.10ms
+step:1176/1670 train_time:111835ms step_avg:95.10ms
+step:1177/1670 train_time:111929ms step_avg:95.10ms
+step:1178/1670 train_time:112023ms step_avg:95.10ms
+step:1179/1670 train_time:112119ms step_avg:95.10ms
+step:1180/1670 train_time:112214ms step_avg:95.10ms
+step:1181/1670 train_time:112309ms step_avg:95.10ms
+step:1182/1670 train_time:112403ms step_avg:95.10ms
+step:1183/1670 train_time:112498ms step_avg:95.10ms
+step:1184/1670 train_time:112592ms step_avg:95.09ms
+step:1185/1670 train_time:112688ms step_avg:95.10ms
+step:1186/1670 train_time:112782ms step_avg:95.09ms
+step:1187/1670 train_time:112875ms step_avg:95.09ms
+step:1188/1670 train_time:112970ms step_avg:95.09ms
+step:1189/1670 train_time:113065ms step_avg:95.09ms
+step:1190/1670 train_time:113160ms step_avg:95.09ms
+step:1191/1670 train_time:113255ms step_avg:95.09ms
+step:1192/1670 train_time:113350ms step_avg:95.09ms
+step:1193/1670 train_time:113446ms step_avg:95.09ms
+step:1194/1670 train_time:113540ms step_avg:95.09ms
+step:1195/1670 train_time:113634ms step_avg:95.09ms
+step:1196/1670 train_time:113729ms step_avg:95.09ms
+step:1197/1670 train_time:113824ms step_avg:95.09ms
+step:1198/1670 train_time:113918ms step_avg:95.09ms
+step:1199/1670 train_time:114012ms step_avg:95.09ms
+step:1200/1670 train_time:114106ms step_avg:95.09ms
+step:1201/1670 train_time:114201ms step_avg:95.09ms
+step:1202/1670 train_time:114296ms step_avg:95.09ms
+step:1203/1670 train_time:114391ms step_avg:95.09ms
+step:1204/1670 train_time:114486ms step_avg:95.09ms
+step:1205/1670 train_time:114581ms step_avg:95.09ms
+step:1206/1670 train_time:114675ms step_avg:95.09ms
+step:1207/1670 train_time:114770ms step_avg:95.09ms
+step:1208/1670 train_time:114864ms step_avg:95.09ms
+step:1209/1670 train_time:114959ms step_avg:95.09ms
+step:1210/1670 train_time:115053ms step_avg:95.09ms
+step:1211/1670 train_time:115148ms step_avg:95.09ms
+step:1212/1670 train_time:115243ms step_avg:95.08ms
+step:1213/1670 train_time:115338ms step_avg:95.09ms
+step:1214/1670 train_time:115433ms step_avg:95.08ms
+step:1215/1670 train_time:115528ms step_avg:95.08ms
+step:1216/1670 train_time:115622ms step_avg:95.08ms
+step:1217/1670 train_time:115718ms step_avg:95.08ms
+step:1218/1670 train_time:115812ms step_avg:95.08ms
+step:1219/1670 train_time:115906ms step_avg:95.08ms
+step:1220/1670 train_time:116001ms step_avg:95.08ms
+step:1221/1670 train_time:116095ms step_avg:95.08ms
+step:1222/1670 train_time:116189ms step_avg:95.08ms
+step:1223/1670 train_time:116285ms step_avg:95.08ms
+step:1224/1670 train_time:116380ms step_avg:95.08ms
+step:1225/1670 train_time:116474ms step_avg:95.08ms
+step:1226/1670 train_time:116569ms step_avg:95.08ms
+step:1227/1670 train_time:116664ms step_avg:95.08ms
+step:1228/1670 train_time:116759ms step_avg:95.08ms
+step:1229/1670 train_time:116853ms step_avg:95.08ms
+step:1230/1670 train_time:116947ms step_avg:95.08ms
+step:1231/1670 train_time:117042ms step_avg:95.08ms
+step:1232/1670 train_time:117136ms step_avg:95.08ms
+step:1233/1670 train_time:117230ms step_avg:95.08ms
+step:1234/1670 train_time:117325ms step_avg:95.08ms
+step:1235/1670 train_time:117420ms step_avg:95.08ms
+step:1236/1670 train_time:117514ms step_avg:95.08ms
+step:1237/1670 train_time:117609ms step_avg:95.08ms
+step:1238/1670 train_time:117704ms step_avg:95.08ms
+step:1239/1670 train_time:117798ms step_avg:95.08ms
+step:1240/1670 train_time:117892ms step_avg:95.07ms
+step:1241/1670 train_time:117987ms step_avg:95.07ms
+step:1242/1670 train_time:118081ms step_avg:95.07ms
+step:1243/1670 train_time:118175ms step_avg:95.07ms
+step:1244/1670 train_time:118269ms step_avg:95.07ms
+step:1245/1670 train_time:118364ms step_avg:95.07ms
+step:1246/1670 train_time:118460ms step_avg:95.07ms
+step:1247/1670 train_time:118554ms step_avg:95.07ms
+step:1248/1670 train_time:118648ms step_avg:95.07ms
+step:1249/1670 train_time:118744ms step_avg:95.07ms
+step:1250/1670 train_time:118838ms step_avg:95.07ms
+step:1250/1670 val_loss:3.3775 train_time:118930ms step_avg:95.14ms
+step:1251/1670 train_time:118957ms step_avg:95.09ms
+step:1252/1670 train_time:119032ms step_avg:95.07ms
+step:1253/1670 train_time:119135ms step_avg:95.08ms
+step:1254/1670 train_time:119231ms step_avg:95.08ms
+step:1255/1670 train_time:119325ms step_avg:95.08ms
+step:1256/1670 train_time:119418ms step_avg:95.08ms
+step:1257/1670 train_time:119511ms step_avg:95.08ms
+step:1258/1670 train_time:119604ms step_avg:95.07ms
+step:1259/1670 train_time:119698ms step_avg:95.07ms
+step:1260/1670 train_time:119791ms step_avg:95.07ms
+step:1261/1670 train_time:119885ms step_avg:95.07ms
+step:1262/1670 train_time:119980ms step_avg:95.07ms
+step:1263/1670 train_time:120079ms step_avg:95.07ms
+step:1264/1670 train_time:120176ms step_avg:95.08ms
+step:1265/1670 train_time:120273ms step_avg:95.08ms
+step:1266/1670 train_time:120368ms step_avg:95.08ms
+step:1267/1670 train_time:120462ms step_avg:95.08ms
+step:1268/1670 train_time:120557ms step_avg:95.08ms
+step:1269/1670 train_time:120650ms step_avg:95.08ms
+step:1270/1670 train_time:120744ms step_avg:95.07ms
+step:1271/1670 train_time:120837ms step_avg:95.07ms
+step:1272/1670 train_time:120931ms step_avg:95.07ms
+step:1273/1670 train_time:121028ms step_avg:95.07ms
+step:1274/1670 train_time:121477ms step_avg:95.35ms
+step:1275/1670 train_time:121548ms step_avg:95.33ms
+step:1276/1670 train_time:121641ms step_avg:95.33ms
+step:1277/1670 train_time:121734ms step_avg:95.33ms
+step:1278/1670 train_time:121827ms step_avg:95.33ms
+step:1279/1670 train_time:121920ms step_avg:95.32ms
+step:1280/1670 train_time:122013ms step_avg:95.32ms
+step:1281/1670 train_time:122107ms step_avg:95.32ms
+step:1282/1670 train_time:122200ms step_avg:95.32ms
+step:1283/1670 train_time:122294ms step_avg:95.32ms
+step:1284/1670 train_time:122394ms step_avg:95.32ms
+step:1285/1670 train_time:122492ms step_avg:95.32ms
+step:1286/1670 train_time:122588ms step_avg:95.33ms
+step:1287/1670 train_time:122682ms step_avg:95.32ms
+step:1288/1670 train_time:122776ms step_avg:95.32ms
+step:1289/1670 train_time:122870ms step_avg:95.32ms
+step:1290/1670 train_time:122964ms step_avg:95.32ms
+step:1291/1670 train_time:123057ms step_avg:95.32ms
+step:1292/1670 train_time:123151ms step_avg:95.32ms
+step:1293/1670 train_time:123245ms step_avg:95.32ms
+step:1294/1670 train_time:123340ms step_avg:95.32ms
+step:1295/1670 train_time:123436ms step_avg:95.32ms
+step:1296/1670 train_time:123533ms step_avg:95.32ms
+step:1297/1670 train_time:123629ms step_avg:95.32ms
+step:1298/1670 train_time:123723ms step_avg:95.32ms
+step:1299/1670 train_time:123817ms step_avg:95.32ms
+step:1300/1670 train_time:123911ms step_avg:95.32ms
+step:1301/1670 train_time:124005ms step_avg:95.32ms
+step:1302/1670 train_time:124099ms step_avg:95.31ms
+step:1303/1670 train_time:124192ms step_avg:95.31ms
+step:1304/1670 train_time:124287ms step_avg:95.31ms
+step:1305/1670 train_time:124382ms step_avg:95.31ms
+step:1306/1670 train_time:124477ms step_avg:95.31ms
+step:1307/1670 train_time:124572ms step_avg:95.31ms
+step:1308/1670 train_time:124667ms step_avg:95.31ms
+step:1309/1670 train_time:124761ms step_avg:95.31ms
+step:1310/1670 train_time:124856ms step_avg:95.31ms
+step:1311/1670 train_time:124951ms step_avg:95.31ms
+step:1312/1670 train_time:125045ms step_avg:95.31ms
+step:1313/1670 train_time:125139ms step_avg:95.31ms
+step:1314/1670 train_time:125233ms step_avg:95.31ms
+step:1315/1670 train_time:125328ms step_avg:95.31ms
+step:1316/1670 train_time:125423ms step_avg:95.31ms
+step:1317/1670 train_time:125517ms step_avg:95.31ms
+step:1318/1670 train_time:125613ms step_avg:95.31ms
+step:1319/1670 train_time:125708ms step_avg:95.31ms
+step:1320/1670 train_time:125803ms step_avg:95.31ms
+step:1321/1670 train_time:125899ms step_avg:95.31ms
+step:1322/1670 train_time:125992ms step_avg:95.30ms
+step:1323/1670 train_time:126087ms step_avg:95.30ms
+step:1324/1670 train_time:126180ms step_avg:95.30ms
+step:1325/1670 train_time:126275ms step_avg:95.30ms
+step:1326/1670 train_time:126371ms step_avg:95.30ms
+step:1327/1670 train_time:126465ms step_avg:95.30ms
+step:1328/1670 train_time:126561ms step_avg:95.30ms
+step:1329/1670 train_time:126656ms step_avg:95.30ms
+step:1330/1670 train_time:126751ms step_avg:95.30ms
+step:1331/1670 train_time:126846ms step_avg:95.30ms
+step:1332/1670 train_time:126940ms step_avg:95.30ms
+step:1333/1670 train_time:127035ms step_avg:95.30ms
+step:1334/1670 train_time:127129ms step_avg:95.30ms
+step:1335/1670 train_time:127223ms step_avg:95.30ms
+step:1336/1670 train_time:127317ms step_avg:95.30ms
+step:1337/1670 train_time:127412ms step_avg:95.30ms
+step:1338/1670 train_time:127508ms step_avg:95.30ms
+step:1339/1670 train_time:127603ms step_avg:95.30ms
+step:1340/1670 train_time:127696ms step_avg:95.30ms
+step:1341/1670 train_time:127794ms step_avg:95.30ms
+step:1342/1670 train_time:127889ms step_avg:95.30ms
+step:1343/1670 train_time:127983ms step_avg:95.30ms
+step:1344/1670 train_time:128077ms step_avg:95.30ms
+step:1345/1670 train_time:128171ms step_avg:95.29ms
+step:1346/1670 train_time:128266ms step_avg:95.29ms
+step:1347/1670 train_time:128360ms step_avg:95.29ms
+step:1348/1670 train_time:128454ms step_avg:95.29ms
+step:1349/1670 train_time:128548ms step_avg:95.29ms
+step:1350/1670 train_time:128643ms step_avg:95.29ms
+step:1351/1670 train_time:128739ms step_avg:95.29ms
+step:1352/1670 train_time:128833ms step_avg:95.29ms
+step:1353/1670 train_time:128928ms step_avg:95.29ms
+step:1354/1670 train_time:129023ms step_avg:95.29ms
+step:1355/1670 train_time:129117ms step_avg:95.29ms
+step:1356/1670 train_time:129211ms step_avg:95.29ms
+step:1357/1670 train_time:129306ms step_avg:95.29ms
+step:1358/1670 train_time:129399ms step_avg:95.29ms
+step:1359/1670 train_time:129495ms step_avg:95.29ms
+step:1360/1670 train_time:129589ms step_avg:95.29ms
+step:1361/1670 train_time:129684ms step_avg:95.29ms
+step:1362/1670 train_time:129778ms step_avg:95.29ms
+step:1363/1670 train_time:129874ms step_avg:95.29ms
+step:1364/1670 train_time:129969ms step_avg:95.29ms
+step:1365/1670 train_time:130064ms step_avg:95.28ms
+step:1366/1670 train_time:130157ms step_avg:95.28ms
+step:1367/1670 train_time:130252ms step_avg:95.28ms
+step:1368/1670 train_time:130348ms step_avg:95.28ms
+step:1369/1670 train_time:130441ms step_avg:95.28ms
+step:1370/1670 train_time:130536ms step_avg:95.28ms
+step:1371/1670 train_time:130630ms step_avg:95.28ms
+step:1372/1670 train_time:130725ms step_avg:95.28ms
+step:1373/1670 train_time:130820ms step_avg:95.28ms
+step:1374/1670 train_time:130916ms step_avg:95.28ms
+step:1375/1670 train_time:131011ms step_avg:95.28ms
+step:1375/1670 val_loss:3.3430 train_time:131102ms step_avg:95.35ms
+step:1376/1670 train_time:131129ms step_avg:95.30ms
+step:1377/1670 train_time:131209ms step_avg:95.29ms
+step:1378/1670 train_time:131313ms step_avg:95.29ms
+step:1379/1670 train_time:131407ms step_avg:95.29ms
+step:1380/1670 train_time:131502ms step_avg:95.29ms
+step:1381/1670 train_time:131595ms step_avg:95.29ms
+step:1382/1670 train_time:131688ms step_avg:95.29ms
+step:1383/1670 train_time:131782ms step_avg:95.29ms
+step:1384/1670 train_time:131876ms step_avg:95.29ms
+step:1385/1670 train_time:131969ms step_avg:95.28ms
+step:1386/1670 train_time:132063ms step_avg:95.28ms
+step:1387/1670 train_time:132160ms step_avg:95.28ms
+step:1388/1670 train_time:132260ms step_avg:95.29ms
+step:1389/1670 train_time:132356ms step_avg:95.29ms
+step:1390/1670 train_time:132451ms step_avg:95.29ms
+step:1391/1670 train_time:132546ms step_avg:95.29ms
+step:1392/1670 train_time:132639ms step_avg:95.29ms
+step:1393/1670 train_time:132733ms step_avg:95.29ms
+step:1394/1670 train_time:132827ms step_avg:95.28ms
+step:1395/1670 train_time:132921ms step_avg:95.28ms
+step:1396/1670 train_time:133015ms step_avg:95.28ms
+step:1397/1670 train_time:133109ms step_avg:95.28ms
+step:1398/1670 train_time:133206ms step_avg:95.28ms
+step:1399/1670 train_time:133302ms step_avg:95.28ms
+step:1400/1670 train_time:133397ms step_avg:95.28ms
+step:1401/1670 train_time:133493ms step_avg:95.28ms
+step:1402/1670 train_time:133587ms step_avg:95.28ms
+step:1403/1670 train_time:133682ms step_avg:95.28ms
+step:1404/1670 train_time:133776ms step_avg:95.28ms
+step:1405/1670 train_time:133870ms step_avg:95.28ms
+step:1406/1670 train_time:133964ms step_avg:95.28ms
+step:1407/1670 train_time:134058ms step_avg:95.28ms
+step:1408/1670 train_time:134154ms step_avg:95.28ms
+step:1409/1670 train_time:134249ms step_avg:95.28ms
+step:1410/1670 train_time:134346ms step_avg:95.28ms
+step:1411/1670 train_time:134441ms step_avg:95.28ms
+step:1412/1670 train_time:134535ms step_avg:95.28ms
+step:1413/1670 train_time:134629ms step_avg:95.28ms
+step:1414/1670 train_time:134724ms step_avg:95.28ms
+step:1415/1670 train_time:134818ms step_avg:95.28ms
+step:1416/1670 train_time:134912ms step_avg:95.28ms
+step:1417/1670 train_time:135006ms step_avg:95.28ms
+step:1418/1670 train_time:135101ms step_avg:95.28ms
+step:1419/1670 train_time:135197ms step_avg:95.28ms
+step:1420/1670 train_time:135293ms step_avg:95.28ms
+step:1421/1670 train_time:135388ms step_avg:95.28ms
+step:1422/1670 train_time:135483ms step_avg:95.28ms
+step:1423/1670 train_time:135578ms step_avg:95.28ms
+step:1424/1670 train_time:135673ms step_avg:95.28ms
+step:1425/1670 train_time:135766ms step_avg:95.27ms
+step:1426/1670 train_time:135861ms step_avg:95.27ms
+step:1427/1670 train_time:135956ms step_avg:95.27ms
+step:1428/1670 train_time:136050ms step_avg:95.27ms
+step:1429/1670 train_time:136145ms step_avg:95.27ms
+step:1430/1670 train_time:136240ms step_avg:95.27ms
+step:1431/1670 train_time:136335ms step_avg:95.27ms
+step:1432/1670 train_time:136430ms step_avg:95.27ms
+step:1433/1670 train_time:136525ms step_avg:95.27ms
+step:1434/1670 train_time:136619ms step_avg:95.27ms
+step:1435/1670 train_time:136713ms step_avg:95.27ms
+step:1436/1670 train_time:136807ms step_avg:95.27ms
+step:1437/1670 train_time:136902ms step_avg:95.27ms
+step:1438/1670 train_time:136996ms step_avg:95.27ms
+step:1439/1670 train_time:137091ms step_avg:95.27ms
+step:1440/1670 train_time:137186ms step_avg:95.27ms
+step:1441/1670 train_time:137282ms step_avg:95.27ms
+step:1442/1670 train_time:137377ms step_avg:95.27ms
+step:1443/1670 train_time:137472ms step_avg:95.27ms
+step:1444/1670 train_time:137567ms step_avg:95.27ms
+step:1445/1670 train_time:137661ms step_avg:95.27ms
+step:1446/1670 train_time:137757ms step_avg:95.27ms
+step:1447/1670 train_time:137851ms step_avg:95.27ms
+step:1448/1670 train_time:137946ms step_avg:95.27ms
+step:1449/1670 train_time:138040ms step_avg:95.27ms
+step:1450/1670 train_time:138135ms step_avg:95.27ms
+step:1451/1670 train_time:138231ms step_avg:95.27ms
+step:1452/1670 train_time:138326ms step_avg:95.27ms
+step:1453/1670 train_time:138421ms step_avg:95.27ms
+step:1454/1670 train_time:138516ms step_avg:95.27ms
+step:1455/1670 train_time:138611ms step_avg:95.27ms
+step:1456/1670 train_time:138706ms step_avg:95.27ms
+step:1457/1670 train_time:138801ms step_avg:95.26ms
+step:1458/1670 train_time:138895ms step_avg:95.26ms
+step:1459/1670 train_time:138989ms step_avg:95.26ms
+step:1460/1670 train_time:139085ms step_avg:95.26ms
+step:1461/1670 train_time:139179ms step_avg:95.26ms
+step:1462/1670 train_time:139274ms step_avg:95.26ms
+step:1463/1670 train_time:139368ms step_avg:95.26ms
+step:1464/1670 train_time:139463ms step_avg:95.26ms
+step:1465/1670 train_time:139558ms step_avg:95.26ms
+step:1466/1670 train_time:139652ms step_avg:95.26ms
+step:1467/1670 train_time:139747ms step_avg:95.26ms
+step:1468/1670 train_time:139842ms step_avg:95.26ms
+step:1469/1670 train_time:139937ms step_avg:95.26ms
+step:1470/1670 train_time:140031ms step_avg:95.26ms
+step:1471/1670 train_time:140125ms step_avg:95.26ms
+step:1472/1670 train_time:140220ms step_avg:95.26ms
+step:1473/1670 train_time:140316ms step_avg:95.26ms
+step:1474/1670 train_time:140410ms step_avg:95.26ms
+step:1475/1670 train_time:140506ms step_avg:95.26ms
+step:1476/1670 train_time:140600ms step_avg:95.26ms
+step:1477/1670 train_time:140695ms step_avg:95.26ms
+step:1478/1670 train_time:140790ms step_avg:95.26ms
+step:1479/1670 train_time:140885ms step_avg:95.26ms
+step:1480/1670 train_time:140979ms step_avg:95.26ms
+step:1481/1670 train_time:141074ms step_avg:95.26ms
+step:1482/1670 train_time:141168ms step_avg:95.25ms
+step:1483/1670 train_time:141264ms step_avg:95.26ms
+step:1484/1670 train_time:141358ms step_avg:95.25ms
+step:1485/1670 train_time:141703ms step_avg:95.42ms
+step:1486/1670 train_time:141875ms step_avg:95.47ms
+step:1487/1670 train_time:141968ms step_avg:95.47ms
+step:1488/1670 train_time:142062ms step_avg:95.47ms
+step:1489/1670 train_time:142155ms step_avg:95.47ms
+step:1490/1670 train_time:142248ms step_avg:95.47ms
+step:1491/1670 train_time:142342ms step_avg:95.47ms
+step:1492/1670 train_time:142435ms step_avg:95.47ms
+step:1493/1670 train_time:142528ms step_avg:95.46ms
+step:1494/1670 train_time:142622ms step_avg:95.46ms
+step:1495/1670 train_time:142718ms step_avg:95.46ms
+step:1496/1670 train_time:142818ms step_avg:95.47ms
+step:1497/1670 train_time:142916ms step_avg:95.47ms
+step:1498/1670 train_time:143011ms step_avg:95.47ms
+step:1499/1670 train_time:143105ms step_avg:95.47ms
+step:1500/1670 train_time:143199ms step_avg:95.47ms
+step:1500/1670 val_loss:3.3130 train_time:143290ms step_avg:95.53ms
+step:1501/1670 train_time:143318ms step_avg:95.48ms
+step:1502/1670 train_time:143397ms step_avg:95.47ms
+step:1503/1670 train_time:143497ms step_avg:95.47ms
+step:1504/1670 train_time:143591ms step_avg:95.47ms
+step:1505/1670 train_time:143684ms step_avg:95.47ms
+step:1506/1670 train_time:143778ms step_avg:95.47ms
+step:1507/1670 train_time:143870ms step_avg:95.47ms
step_avg:95.47ms +step:1508/1670 train_time:143964ms step_avg:95.47ms +step:1509/1670 train_time:144057ms step_avg:95.47ms +step:1510/1670 train_time:144151ms step_avg:95.46ms +step:1511/1670 train_time:144246ms step_avg:95.46ms +step:1512/1670 train_time:144343ms step_avg:95.46ms +step:1513/1670 train_time:144439ms step_avg:95.47ms +step:1514/1670 train_time:144535ms step_avg:95.47ms +step:1515/1670 train_time:144632ms step_avg:95.47ms +step:1516/1670 train_time:144727ms step_avg:95.47ms +step:1517/1670 train_time:144820ms step_avg:95.46ms +step:1518/1670 train_time:144914ms step_avg:95.46ms +step:1519/1670 train_time:145007ms step_avg:95.46ms +step:1520/1670 train_time:145102ms step_avg:95.46ms +step:1521/1670 train_time:145197ms step_avg:95.46ms +step:1522/1670 train_time:145290ms step_avg:95.46ms +step:1523/1670 train_time:145387ms step_avg:95.46ms +step:1524/1670 train_time:145482ms step_avg:95.46ms +step:1525/1670 train_time:145579ms step_avg:95.46ms +step:1526/1670 train_time:145673ms step_avg:95.46ms +step:1527/1670 train_time:145767ms step_avg:95.46ms +step:1528/1670 train_time:145861ms step_avg:95.46ms +step:1529/1670 train_time:145954ms step_avg:95.46ms +step:1530/1670 train_time:146049ms step_avg:95.46ms +step:1531/1670 train_time:146142ms step_avg:95.46ms +step:1532/1670 train_time:146237ms step_avg:95.46ms +step:1533/1670 train_time:146332ms step_avg:95.45ms +step:1534/1670 train_time:146427ms step_avg:95.45ms +step:1535/1670 train_time:146523ms step_avg:95.45ms +step:1536/1670 train_time:146617ms step_avg:95.45ms +step:1537/1670 train_time:146713ms step_avg:95.45ms +step:1538/1670 train_time:146808ms step_avg:95.45ms +step:1539/1670 train_time:146902ms step_avg:95.45ms +step:1540/1670 train_time:146996ms step_avg:95.45ms +step:1541/1670 train_time:147090ms step_avg:95.45ms +step:1542/1670 train_time:147185ms step_avg:95.45ms +step:1543/1670 train_time:147279ms step_avg:95.45ms +step:1544/1670 train_time:147374ms step_avg:95.45ms +step:1545/1670 train_time:147470ms step_avg:95.45ms +step:1546/1670 train_time:147565ms step_avg:95.45ms +step:1547/1670 train_time:147661ms step_avg:95.45ms +step:1548/1670 train_time:147755ms step_avg:95.45ms +step:1549/1670 train_time:147851ms step_avg:95.45ms +step:1550/1670 train_time:147945ms step_avg:95.45ms +step:1551/1670 train_time:148039ms step_avg:95.45ms +step:1552/1670 train_time:148134ms step_avg:95.45ms +step:1553/1670 train_time:148229ms step_avg:95.45ms +step:1554/1670 train_time:148323ms step_avg:95.45ms +step:1555/1670 train_time:148418ms step_avg:95.45ms +step:1556/1670 train_time:148513ms step_avg:95.45ms +step:1557/1670 train_time:148609ms step_avg:95.45ms +step:1558/1670 train_time:148705ms step_avg:95.45ms +step:1559/1670 train_time:148799ms step_avg:95.45ms +step:1560/1670 train_time:148894ms step_avg:95.44ms +step:1561/1670 train_time:148988ms step_avg:95.44ms +step:1562/1670 train_time:149083ms step_avg:95.44ms +step:1563/1670 train_time:149177ms step_avg:95.44ms +step:1564/1670 train_time:149272ms step_avg:95.44ms +step:1565/1670 train_time:149367ms step_avg:95.44ms +step:1566/1670 train_time:149462ms step_avg:95.44ms +step:1567/1670 train_time:149557ms step_avg:95.44ms +step:1568/1670 train_time:149652ms step_avg:95.44ms +step:1569/1670 train_time:149747ms step_avg:95.44ms +step:1570/1670 train_time:149842ms step_avg:95.44ms +step:1571/1670 train_time:149936ms step_avg:95.44ms +step:1572/1670 train_time:150031ms step_avg:95.44ms +step:1573/1670 train_time:150126ms step_avg:95.44ms +step:1574/1670 train_time:150220ms 
step_avg:95.44ms +step:1575/1670 train_time:150314ms step_avg:95.44ms +step:1576/1670 train_time:150409ms step_avg:95.44ms +step:1577/1670 train_time:150504ms step_avg:95.44ms +step:1578/1670 train_time:150599ms step_avg:95.44ms +step:1579/1670 train_time:150694ms step_avg:95.44ms +step:1580/1670 train_time:150788ms step_avg:95.44ms +step:1581/1670 train_time:150884ms step_avg:95.44ms +step:1582/1670 train_time:150978ms step_avg:95.44ms +step:1583/1670 train_time:151073ms step_avg:95.43ms +step:1584/1670 train_time:151168ms step_avg:95.43ms +step:1585/1670 train_time:151262ms step_avg:95.43ms +step:1586/1670 train_time:151356ms step_avg:95.43ms +step:1587/1670 train_time:151452ms step_avg:95.43ms +step:1588/1670 train_time:151546ms step_avg:95.43ms +step:1589/1670 train_time:151640ms step_avg:95.43ms +step:1590/1670 train_time:151735ms step_avg:95.43ms +step:1591/1670 train_time:151830ms step_avg:95.43ms +step:1592/1670 train_time:151926ms step_avg:95.43ms +step:1593/1670 train_time:152022ms step_avg:95.43ms +step:1594/1670 train_time:152117ms step_avg:95.43ms +step:1595/1670 train_time:152212ms step_avg:95.43ms +step:1596/1670 train_time:152306ms step_avg:95.43ms +step:1597/1670 train_time:152401ms step_avg:95.43ms +step:1598/1670 train_time:152495ms step_avg:95.43ms +step:1599/1670 train_time:152589ms step_avg:95.43ms +step:1600/1670 train_time:152684ms step_avg:95.43ms +step:1601/1670 train_time:152779ms step_avg:95.43ms +step:1602/1670 train_time:152874ms step_avg:95.43ms +step:1603/1670 train_time:152970ms step_avg:95.43ms +step:1604/1670 train_time:153065ms step_avg:95.43ms +step:1605/1670 train_time:153160ms step_avg:95.43ms +step:1606/1670 train_time:153254ms step_avg:95.43ms +step:1607/1670 train_time:153348ms step_avg:95.43ms +step:1608/1670 train_time:153443ms step_avg:95.42ms +step:1609/1670 train_time:153538ms step_avg:95.42ms +step:1610/1670 train_time:153632ms step_avg:95.42ms +step:1611/1670 train_time:153727ms step_avg:95.42ms +step:1612/1670 train_time:153822ms step_avg:95.42ms +step:1613/1670 train_time:153917ms step_avg:95.42ms +step:1614/1670 train_time:154012ms step_avg:95.42ms +step:1615/1670 train_time:154107ms step_avg:95.42ms +step:1616/1670 train_time:154201ms step_avg:95.42ms +step:1617/1670 train_time:154296ms step_avg:95.42ms +step:1618/1670 train_time:154391ms step_avg:95.42ms +step:1619/1670 train_time:154486ms step_avg:95.42ms +step:1620/1670 train_time:154582ms step_avg:95.42ms +step:1621/1670 train_time:154676ms step_avg:95.42ms +step:1622/1670 train_time:154771ms step_avg:95.42ms +step:1623/1670 train_time:154866ms step_avg:95.42ms +step:1624/1670 train_time:154961ms step_avg:95.42ms +step:1625/1670 train_time:155056ms step_avg:95.42ms +step:1625/1670 val_loss:3.2879 train_time:155148ms step_avg:95.48ms +step:1626/1670 train_time:155176ms step_avg:95.43ms +step:1627/1670 train_time:155251ms step_avg:95.42ms +step:1628/1670 train_time:155351ms step_avg:95.42ms +step:1629/1670 train_time:155448ms step_avg:95.43ms +step:1630/1670 train_time:155542ms step_avg:95.42ms +step:1631/1670 train_time:155636ms step_avg:95.42ms +step:1632/1670 train_time:155729ms step_avg:95.42ms +step:1633/1670 train_time:155822ms step_avg:95.42ms +step:1634/1670 train_time:155916ms step_avg:95.42ms +step:1635/1670 train_time:156010ms step_avg:95.42ms +step:1636/1670 train_time:156105ms step_avg:95.42ms +step:1637/1670 train_time:156203ms step_avg:95.42ms +step:1638/1670 train_time:156302ms step_avg:95.42ms +step:1639/1670 train_time:156398ms step_avg:95.42ms +step:1640/1670 
train_time:156494ms step_avg:95.42ms +step:1641/1670 train_time:156589ms step_avg:95.42ms +step:1642/1670 train_time:156682ms step_avg:95.42ms +step:1643/1670 train_time:156776ms step_avg:95.42ms +step:1644/1670 train_time:156870ms step_avg:95.42ms +step:1645/1670 train_time:156964ms step_avg:95.42ms +step:1646/1670 train_time:157058ms step_avg:95.42ms +step:1647/1670 train_time:157153ms step_avg:95.42ms +step:1648/1670 train_time:157250ms step_avg:95.42ms +step:1649/1670 train_time:157345ms step_avg:95.42ms +step:1650/1670 train_time:157441ms step_avg:95.42ms +step:1651/1670 train_time:157536ms step_avg:95.42ms +step:1652/1670 train_time:157630ms step_avg:95.42ms +step:1653/1670 train_time:157725ms step_avg:95.42ms +step:1654/1670 train_time:157819ms step_avg:95.42ms +step:1655/1670 train_time:157913ms step_avg:95.42ms +step:1656/1670 train_time:158008ms step_avg:95.42ms +step:1657/1670 train_time:158102ms step_avg:95.41ms +step:1658/1670 train_time:158197ms step_avg:95.41ms +step:1659/1670 train_time:158293ms step_avg:95.41ms +step:1660/1670 train_time:158388ms step_avg:95.41ms +step:1661/1670 train_time:158482ms step_avg:95.41ms +step:1662/1670 train_time:158577ms step_avg:95.41ms +step:1663/1670 train_time:158671ms step_avg:95.41ms +step:1664/1670 train_time:158766ms step_avg:95.41ms +step:1665/1670 train_time:158861ms step_avg:95.41ms +step:1666/1670 train_time:158954ms step_avg:95.41ms +step:1667/1670 train_time:159048ms step_avg:95.41ms +step:1668/1670 train_time:159144ms step_avg:95.41ms +step:1669/1670 train_time:159239ms step_avg:95.41ms +step:1670/1670 train_time:159334ms step_avg:95.41ms +step:1670/1670 val_loss:3.2789 train_time:159503ms step_avg:95.51ms +peak memory allocated: 32460 MiB reserved: 47576 MiB diff --git a/records/091025_Yarn/783d22ec-c441-4d93-9fd7-cd00d2c473e8.txt b/records/091025_Yarn/783d22ec-c441-4d93-9fd7-cd00d2c473e8.txt new file mode 100644 index 000000000..44ea15333 --- /dev/null +++ b/records/091025_Yarn/783d22ec-c441-4d93-9fd7-cd00d2c473e8.txt @@ -0,0 +1,2863 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +import math + +from dataclasses import dataclass +from functools import lru_cache +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, 
dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + 
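+    # strides: *_b = batch stride, *_r = row stride, *_c = column stride (a_* for the input A, c_* for the output C)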
c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels 
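+    # The symmetric-output trick still applies here: in the Newton-Schulz loop below, A is
+    # itself symmetric (it is the X @ X.T product from ns_line_1), so alpha * A @ A.T + beta * A
+    # is symmetric too, and each block is computed once and stored both in place and mirrored
+    # across the diagonal.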
+ pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
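+        # Scheme: within each param group (one group per unique shape), parameters are
+        # assigned round-robin to ranks, so rank r owns params[base_i + r]. Gradients are
+        # averaged onto the owning rank via reduce_scatter, that rank applies the update
+        # (weight decay, momentum, Newton-Schulz orthogonalization, step) to its own
+        # parameters only, and the updated weights are redistributed to all ranks via all_gather.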
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, 
attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws] + ) + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + 
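+# Illustrative sketch: how BOSFinder's (starts, ends) pairs become the packed token
+# buffer and the cumulative sequence boundaries that flash_attn_varlen_func consumes
+# as cu_seqlens. The helper below is hypothetical and is never called during training.
+def _demo_bos_packing():
+    # toy 16-token shard with documents starting at BOS positions 0, 5, and 12
+    toy = torch.tensor([BOS_ID, 1, 2, 3, 4, BOS_ID, 5, 6, 7, 8, 9, 10, BOS_ID, 11, 12, 13])
+    finder = BOSFinder(toy, world_size=1)
+    starts, ends = finder.next_batch(num_tokens_local=8, max_seq_len=6)
+    # starts == [[0, 5]], ends == [[5, 9]]: two BOS-aligned sequences of lengths 5 and 4,
+    # together covering num_tokens_local + 1 = 9 tokens (the extra token feeds the target shift)
+    buf = torch.cat([toy[i:j] for i, j in zip(starts[0], ends[0])])
+    inputs, targets = buf[:-1], buf[1:]
+    # after the shift the per-sequence boundaries are (0, 5, 8); padded out, these are
+    # the cumulative lengths the data generator hands to the model as cu_seqlens
+    return inputs, targets
+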
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"yarn/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws=new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 05:44:12 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 36C P0 120W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 41C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 42C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 35C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 35C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 42C P0 127W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 40C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 98682 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 98683 C /usr/bin/python3 614MiB | +| 0 N/A N/A 98684 C /usr/bin/python3 614MiB | +| 0 N/A N/A 98685 C /usr/bin/python3 614MiB | +| 0 N/A N/A 98686 C /usr/bin/python3 614MiB | +| 0 N/A N/A 98687 C /usr/bin/python3 614MiB | +| 0 N/A N/A 98688 C /usr/bin/python3 614MiB | +| 0 N/A N/A 98689 C /usr/bin/python3 614MiB | +| 1 N/A N/A 98683 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 98684 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 98685 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 98686 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 98687 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 98688 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 98689 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:458ms step_avg:458.31ms +step:2/1670 train_time:484ms step_avg:241.85ms +step:3/1670 train_time:551ms step_avg:183.55ms +step:4/1670 train_time:641ms step_avg:160.36ms +step:5/1670 train_time:733ms step_avg:146.66ms +step:6/1670 train_time:825ms step_avg:137.57ms +step:7/1670 train_time:917ms step_avg:131.01ms 
+step:8/1670 train_time:1009ms step_avg:126.17ms +step:9/1670 train_time:1101ms step_avg:122.37ms +step:10/1670 train_time:1193ms step_avg:119.28ms +step:11/1670 train_time:1285ms step_avg:116.78ms +step:12/1670 train_time:1376ms step_avg:114.69ms +step:13/1670 train_time:1471ms step_avg:113.18ms +step:14/1670 train_time:1566ms step_avg:111.86ms +step:15/1670 train_time:1659ms step_avg:110.61ms +step:16/1670 train_time:1752ms step_avg:109.52ms +step:17/1670 train_time:1844ms step_avg:108.49ms +step:18/1670 train_time:1937ms step_avg:107.60ms +step:19/1670 train_time:2029ms step_avg:106.80ms +step:20/1670 train_time:2121ms step_avg:106.05ms +step:21/1670 train_time:2213ms step_avg:105.40ms +step:22/1670 train_time:2306ms step_avg:104.83ms +step:23/1670 train_time:2400ms step_avg:104.33ms +step:24/1670 train_time:2493ms step_avg:103.88ms +step:25/1670 train_time:2587ms step_avg:103.46ms +step:26/1670 train_time:2679ms step_avg:103.05ms +step:27/1670 train_time:2773ms step_avg:102.69ms +step:28/1670 train_time:2866ms step_avg:102.34ms +step:29/1670 train_time:2958ms step_avg:102.00ms +step:30/1670 train_time:3050ms step_avg:101.68ms +step:31/1670 train_time:3143ms step_avg:101.38ms +step:32/1670 train_time:3235ms step_avg:101.09ms +step:33/1670 train_time:3328ms step_avg:100.86ms +step:34/1670 train_time:3421ms step_avg:100.63ms +step:35/1670 train_time:3514ms step_avg:100.41ms +step:36/1670 train_time:3607ms step_avg:100.20ms +step:37/1670 train_time:3700ms step_avg:100.00ms +step:38/1670 train_time:3792ms step_avg:99.80ms +step:39/1670 train_time:3886ms step_avg:99.63ms +step:40/1670 train_time:3977ms step_avg:99.43ms +step:41/1670 train_time:4070ms step_avg:99.27ms +step:42/1670 train_time:4163ms step_avg:99.12ms +step:43/1670 train_time:4255ms step_avg:98.96ms +step:44/1670 train_time:4349ms step_avg:98.84ms +step:45/1670 train_time:4442ms step_avg:98.71ms +step:46/1670 train_time:4536ms step_avg:98.60ms +step:47/1670 train_time:4630ms step_avg:98.50ms +step:48/1670 train_time:4722ms step_avg:98.38ms +step:49/1670 train_time:4815ms step_avg:98.26ms +step:50/1670 train_time:4907ms step_avg:98.15ms +step:51/1670 train_time:5000ms step_avg:98.04ms +step:52/1670 train_time:5093ms step_avg:97.93ms +step:53/1670 train_time:5185ms step_avg:97.82ms +step:54/1670 train_time:5277ms step_avg:97.72ms +step:55/1670 train_time:5370ms step_avg:97.63ms +step:56/1670 train_time:5462ms step_avg:97.54ms +step:57/1670 train_time:5555ms step_avg:97.46ms +step:58/1670 train_time:5649ms step_avg:97.40ms +step:59/1670 train_time:5742ms step_avg:97.32ms +step:60/1670 train_time:5835ms step_avg:97.24ms +step:61/1670 train_time:5927ms step_avg:97.16ms +step:62/1670 train_time:6019ms step_avg:97.08ms +step:63/1670 train_time:6111ms step_avg:97.01ms +step:64/1670 train_time:6204ms step_avg:96.94ms +step:65/1670 train_time:6297ms step_avg:96.87ms +step:66/1670 train_time:6390ms step_avg:96.81ms +step:67/1670 train_time:6482ms step_avg:96.74ms +step:68/1670 train_time:6575ms step_avg:96.69ms +step:69/1670 train_time:6668ms step_avg:96.63ms +step:70/1670 train_time:6761ms step_avg:96.58ms +step:71/1670 train_time:6853ms step_avg:96.52ms +step:72/1670 train_time:6946ms step_avg:96.47ms +step:73/1670 train_time:7038ms step_avg:96.41ms +step:74/1670 train_time:7131ms step_avg:96.37ms +step:75/1670 train_time:7224ms step_avg:96.32ms +step:76/1670 train_time:7316ms step_avg:96.27ms +step:77/1670 train_time:7409ms step_avg:96.23ms +step:78/1670 train_time:7502ms step_avg:96.17ms +step:79/1670 train_time:7595ms 
step_avg:96.14ms +step:80/1670 train_time:7688ms step_avg:96.10ms +step:81/1670 train_time:7780ms step_avg:96.05ms +step:82/1670 train_time:7873ms step_avg:96.01ms +step:83/1670 train_time:7966ms step_avg:95.98ms +step:84/1670 train_time:8058ms step_avg:95.93ms +step:85/1670 train_time:8151ms step_avg:95.90ms +step:86/1670 train_time:8244ms step_avg:95.86ms +step:87/1670 train_time:8337ms step_avg:95.82ms +step:88/1670 train_time:8430ms step_avg:95.79ms +step:89/1670 train_time:8522ms step_avg:95.76ms +step:90/1670 train_time:8614ms step_avg:95.71ms +step:91/1670 train_time:8707ms step_avg:95.68ms +step:92/1670 train_time:8800ms step_avg:95.65ms +step:93/1670 train_time:8892ms step_avg:95.62ms +step:94/1670 train_time:8985ms step_avg:95.58ms +step:95/1670 train_time:9077ms step_avg:95.54ms +step:96/1670 train_time:9169ms step_avg:95.51ms +step:97/1670 train_time:9262ms step_avg:95.49ms +step:98/1670 train_time:9354ms step_avg:95.45ms +step:99/1670 train_time:9447ms step_avg:95.42ms +step:100/1670 train_time:9540ms step_avg:95.40ms +step:101/1670 train_time:9633ms step_avg:95.38ms +step:102/1670 train_time:9727ms step_avg:95.36ms +step:103/1670 train_time:9818ms step_avg:95.32ms +step:104/1670 train_time:9911ms step_avg:95.30ms +step:105/1670 train_time:10003ms step_avg:95.27ms +step:106/1670 train_time:10096ms step_avg:95.24ms +step:107/1670 train_time:10188ms step_avg:95.21ms +step:108/1670 train_time:10280ms step_avg:95.18ms +step:109/1670 train_time:10373ms step_avg:95.16ms +step:110/1670 train_time:10465ms step_avg:95.14ms +step:111/1670 train_time:10557ms step_avg:95.11ms +step:112/1670 train_time:10650ms step_avg:95.09ms +step:113/1670 train_time:10742ms step_avg:95.07ms +step:114/1670 train_time:10835ms step_avg:95.04ms +step:115/1670 train_time:10928ms step_avg:95.03ms +step:116/1670 train_time:11020ms step_avg:95.00ms +step:117/1670 train_time:11112ms step_avg:94.98ms +step:118/1670 train_time:11205ms step_avg:94.96ms +step:119/1670 train_time:11298ms step_avg:94.94ms +step:120/1670 train_time:11391ms step_avg:94.92ms +step:121/1670 train_time:11482ms step_avg:94.89ms +step:122/1670 train_time:11574ms step_avg:94.87ms +step:123/1670 train_time:11667ms step_avg:94.86ms +step:124/1670 train_time:11760ms step_avg:94.84ms +step:125/1670 train_time:11853ms step_avg:94.82ms +step:125/1670 val_loss:4.2986 train_time:11943ms step_avg:95.54ms +step:126/1670 train_time:11970ms step_avg:95.00ms +step:127/1670 train_time:12041ms step_avg:94.81ms +step:128/1670 train_time:12144ms step_avg:94.88ms +step:129/1670 train_time:12242ms step_avg:94.90ms +step:130/1670 train_time:12336ms step_avg:94.89ms +step:131/1670 train_time:12427ms step_avg:94.86ms +step:132/1670 train_time:12518ms step_avg:94.84ms +step:133/1670 train_time:12609ms step_avg:94.81ms +step:134/1670 train_time:12701ms step_avg:94.78ms +step:135/1670 train_time:12793ms step_avg:94.76ms +step:136/1670 train_time:12884ms step_avg:94.74ms +step:137/1670 train_time:12976ms step_avg:94.71ms +step:138/1670 train_time:13070ms step_avg:94.71ms +step:139/1670 train_time:13164ms step_avg:94.70ms +step:140/1670 train_time:13258ms step_avg:94.70ms +step:141/1670 train_time:13351ms step_avg:94.69ms +step:142/1670 train_time:13443ms step_avg:94.67ms +step:143/1670 train_time:13534ms step_avg:94.65ms +step:144/1670 train_time:13626ms step_avg:94.63ms +step:145/1670 train_time:13718ms step_avg:94.60ms +step:146/1670 train_time:13810ms step_avg:94.59ms +step:147/1670 train_time:13902ms step_avg:94.57ms +step:148/1670 train_time:13993ms 
step_avg:94.55ms +step:149/1670 train_time:14087ms step_avg:94.54ms +step:150/1670 train_time:14180ms step_avg:94.54ms +step:151/1670 train_time:14273ms step_avg:94.52ms +step:152/1670 train_time:14367ms step_avg:94.52ms +step:153/1670 train_time:14459ms step_avg:94.50ms +step:154/1670 train_time:14551ms step_avg:94.49ms +step:155/1670 train_time:14643ms step_avg:94.47ms +step:156/1670 train_time:14735ms step_avg:94.46ms +step:157/1670 train_time:14828ms step_avg:94.44ms +step:158/1670 train_time:14919ms step_avg:94.43ms +step:159/1670 train_time:15011ms step_avg:94.41ms +step:160/1670 train_time:15104ms step_avg:94.40ms +step:161/1670 train_time:15196ms step_avg:94.38ms +step:162/1670 train_time:15289ms step_avg:94.38ms +step:163/1670 train_time:15382ms step_avg:94.37ms +step:164/1670 train_time:15476ms step_avg:94.36ms +step:165/1670 train_time:15569ms step_avg:94.36ms +step:166/1670 train_time:15661ms step_avg:94.34ms +step:167/1670 train_time:15753ms step_avg:94.33ms +step:168/1670 train_time:15845ms step_avg:94.32ms +step:169/1670 train_time:15937ms step_avg:94.30ms +step:170/1670 train_time:16029ms step_avg:94.29ms +step:171/1670 train_time:16121ms step_avg:94.27ms +step:172/1670 train_time:16213ms step_avg:94.26ms +step:173/1670 train_time:16306ms step_avg:94.26ms +step:174/1670 train_time:16399ms step_avg:94.25ms +step:175/1670 train_time:16491ms step_avg:94.23ms +step:176/1670 train_time:16584ms step_avg:94.23ms +step:177/1670 train_time:16677ms step_avg:94.22ms +step:178/1670 train_time:16769ms step_avg:94.21ms +step:179/1670 train_time:16862ms step_avg:94.20ms +step:180/1670 train_time:16953ms step_avg:94.19ms +step:181/1670 train_time:17046ms step_avg:94.18ms +step:182/1670 train_time:17139ms step_avg:94.17ms +step:183/1670 train_time:17230ms step_avg:94.15ms +step:184/1670 train_time:17324ms step_avg:94.15ms +step:185/1670 train_time:17416ms step_avg:94.14ms +step:186/1670 train_time:17509ms step_avg:94.13ms +step:187/1670 train_time:17602ms step_avg:94.13ms +step:188/1670 train_time:17694ms step_avg:94.12ms +step:189/1670 train_time:17786ms step_avg:94.11ms +step:190/1670 train_time:17879ms step_avg:94.10ms +step:191/1670 train_time:17972ms step_avg:94.09ms +step:192/1670 train_time:18064ms step_avg:94.08ms +step:193/1670 train_time:18156ms step_avg:94.07ms +step:194/1670 train_time:18248ms step_avg:94.06ms +step:195/1670 train_time:18341ms step_avg:94.05ms +step:196/1670 train_time:18433ms step_avg:94.04ms +step:197/1670 train_time:18526ms step_avg:94.04ms +step:198/1670 train_time:18618ms step_avg:94.03ms +step:199/1670 train_time:18710ms step_avg:94.02ms +step:200/1670 train_time:18804ms step_avg:94.02ms +step:201/1670 train_time:18898ms step_avg:94.02ms +step:202/1670 train_time:18989ms step_avg:94.01ms +step:203/1670 train_time:19082ms step_avg:94.00ms +step:204/1670 train_time:19175ms step_avg:94.00ms +step:205/1670 train_time:19268ms step_avg:93.99ms +step:206/1670 train_time:19360ms step_avg:93.98ms +step:207/1670 train_time:19453ms step_avg:93.97ms +step:208/1670 train_time:19545ms step_avg:93.97ms +step:209/1670 train_time:19637ms step_avg:93.95ms +step:210/1670 train_time:19730ms step_avg:93.95ms +step:211/1670 train_time:19823ms step_avg:93.95ms +step:212/1670 train_time:19915ms step_avg:93.94ms +step:213/1670 train_time:20329ms step_avg:95.44ms +step:214/1670 train_time:20395ms step_avg:95.31ms +step:215/1670 train_time:20486ms step_avg:95.28ms +step:216/1670 train_time:20577ms step_avg:95.26ms +step:217/1670 train_time:20668ms step_avg:95.25ms +step:218/1670 
train_time:20759ms step_avg:95.23ms +step:219/1670 train_time:20851ms step_avg:95.21ms +step:220/1670 train_time:20942ms step_avg:95.19ms +step:221/1670 train_time:21034ms step_avg:95.17ms +step:222/1670 train_time:21125ms step_avg:95.16ms +step:223/1670 train_time:21218ms step_avg:95.15ms +step:224/1670 train_time:21314ms step_avg:95.15ms +step:225/1670 train_time:21411ms step_avg:95.16ms +step:226/1670 train_time:21504ms step_avg:95.15ms +step:227/1670 train_time:21596ms step_avg:95.14ms +step:228/1670 train_time:21688ms step_avg:95.12ms +step:229/1670 train_time:21780ms step_avg:95.11ms +step:230/1670 train_time:21871ms step_avg:95.09ms +step:231/1670 train_time:21963ms step_avg:95.08ms +step:232/1670 train_time:22055ms step_avg:95.07ms +step:233/1670 train_time:22147ms step_avg:95.05ms +step:234/1670 train_time:22240ms step_avg:95.04ms +step:235/1670 train_time:22333ms step_avg:95.03ms +step:236/1670 train_time:22427ms step_avg:95.03ms +step:237/1670 train_time:22519ms step_avg:95.02ms +step:238/1670 train_time:22611ms step_avg:95.01ms +step:239/1670 train_time:22704ms step_avg:95.00ms +step:240/1670 train_time:22796ms step_avg:94.98ms +step:241/1670 train_time:22889ms step_avg:94.97ms +step:242/1670 train_time:22980ms step_avg:94.96ms +step:243/1670 train_time:23073ms step_avg:94.95ms +step:244/1670 train_time:23165ms step_avg:94.94ms +step:245/1670 train_time:23257ms step_avg:94.92ms +step:246/1670 train_time:23349ms step_avg:94.91ms +step:247/1670 train_time:23443ms step_avg:94.91ms +step:248/1670 train_time:23536ms step_avg:94.90ms +step:249/1670 train_time:23628ms step_avg:94.89ms +step:250/1670 train_time:23721ms step_avg:94.88ms +step:250/1670 val_loss:3.9703 train_time:23810ms step_avg:95.24ms +step:251/1670 train_time:23838ms step_avg:94.97ms +step:252/1670 train_time:23909ms step_avg:94.88ms +step:253/1670 train_time:24007ms step_avg:94.89ms +step:254/1670 train_time:24104ms step_avg:94.90ms +step:255/1670 train_time:24195ms step_avg:94.88ms +step:256/1670 train_time:24287ms step_avg:94.87ms +step:257/1670 train_time:24378ms step_avg:94.86ms +step:258/1670 train_time:24470ms step_avg:94.84ms +step:259/1670 train_time:24561ms step_avg:94.83ms +step:260/1670 train_time:24653ms step_avg:94.82ms +step:261/1670 train_time:24745ms step_avg:94.81ms +step:262/1670 train_time:24838ms step_avg:94.80ms +step:263/1670 train_time:24932ms step_avg:94.80ms +step:264/1670 train_time:25027ms step_avg:94.80ms +step:265/1670 train_time:25119ms step_avg:94.79ms +step:266/1670 train_time:25212ms step_avg:94.78ms +step:267/1670 train_time:25304ms step_avg:94.77ms +step:268/1670 train_time:25396ms step_avg:94.76ms +step:269/1670 train_time:25488ms step_avg:94.75ms +step:270/1670 train_time:25580ms step_avg:94.74ms +step:271/1670 train_time:25671ms step_avg:94.73ms +step:272/1670 train_time:25763ms step_avg:94.72ms +step:273/1670 train_time:25855ms step_avg:94.71ms +step:274/1670 train_time:25948ms step_avg:94.70ms +step:275/1670 train_time:26041ms step_avg:94.70ms +step:276/1670 train_time:26134ms step_avg:94.69ms +step:277/1670 train_time:26228ms step_avg:94.69ms +step:278/1670 train_time:26320ms step_avg:94.68ms +step:279/1670 train_time:26412ms step_avg:94.67ms +step:280/1670 train_time:26504ms step_avg:94.66ms +step:281/1670 train_time:26595ms step_avg:94.65ms +step:282/1670 train_time:26687ms step_avg:94.63ms +step:283/1670 train_time:26779ms step_avg:94.62ms +step:284/1670 train_time:26871ms step_avg:94.62ms +step:285/1670 train_time:26963ms step_avg:94.61ms +step:286/1670 train_time:27056ms 
step_avg:94.60ms +step:287/1670 train_time:27149ms step_avg:94.60ms +step:288/1670 train_time:27243ms step_avg:94.59ms +step:289/1670 train_time:27335ms step_avg:94.58ms +step:290/1670 train_time:27427ms step_avg:94.58ms +step:291/1670 train_time:27519ms step_avg:94.57ms +step:292/1670 train_time:27611ms step_avg:94.56ms +step:293/1670 train_time:27703ms step_avg:94.55ms +step:294/1670 train_time:27796ms step_avg:94.54ms +step:295/1670 train_time:27889ms step_avg:94.54ms +step:296/1670 train_time:27980ms step_avg:94.53ms +step:297/1670 train_time:28073ms step_avg:94.52ms +step:298/1670 train_time:28165ms step_avg:94.51ms +step:299/1670 train_time:28259ms step_avg:94.51ms +step:300/1670 train_time:28352ms step_avg:94.51ms +step:301/1670 train_time:28444ms step_avg:94.50ms +step:302/1670 train_time:28536ms step_avg:94.49ms +step:303/1670 train_time:28628ms step_avg:94.48ms +step:304/1670 train_time:28720ms step_avg:94.47ms +step:305/1670 train_time:28813ms step_avg:94.47ms +step:306/1670 train_time:28906ms step_avg:94.46ms +step:307/1670 train_time:28998ms step_avg:94.46ms +step:308/1670 train_time:29090ms step_avg:94.45ms +step:309/1670 train_time:29183ms step_avg:94.44ms +step:310/1670 train_time:29276ms step_avg:94.44ms +step:311/1670 train_time:29368ms step_avg:94.43ms +step:312/1670 train_time:29461ms step_avg:94.42ms +step:313/1670 train_time:29553ms step_avg:94.42ms +step:314/1670 train_time:29645ms step_avg:94.41ms +step:315/1670 train_time:29738ms step_avg:94.41ms +step:316/1670 train_time:29831ms step_avg:94.40ms +step:317/1670 train_time:29923ms step_avg:94.40ms +step:318/1670 train_time:30016ms step_avg:94.39ms +step:319/1670 train_time:30108ms step_avg:94.38ms +step:320/1670 train_time:30200ms step_avg:94.38ms +step:321/1670 train_time:30293ms step_avg:94.37ms +step:322/1670 train_time:30385ms step_avg:94.36ms +step:323/1670 train_time:30478ms step_avg:94.36ms +step:324/1670 train_time:30570ms step_avg:94.35ms +step:325/1670 train_time:30662ms step_avg:94.34ms +step:326/1670 train_time:30755ms step_avg:94.34ms +step:327/1670 train_time:30848ms step_avg:94.34ms +step:328/1670 train_time:30940ms step_avg:94.33ms +step:329/1670 train_time:31032ms step_avg:94.32ms +step:330/1670 train_time:31124ms step_avg:94.32ms +step:331/1670 train_time:31217ms step_avg:94.31ms +step:332/1670 train_time:31309ms step_avg:94.30ms +step:333/1670 train_time:31401ms step_avg:94.30ms +step:334/1670 train_time:31495ms step_avg:94.29ms +step:335/1670 train_time:31586ms step_avg:94.29ms +step:336/1670 train_time:31678ms step_avg:94.28ms +step:337/1670 train_time:31771ms step_avg:94.27ms +step:338/1670 train_time:31863ms step_avg:94.27ms +step:339/1670 train_time:31956ms step_avg:94.27ms +step:340/1670 train_time:32049ms step_avg:94.26ms +step:341/1670 train_time:32141ms step_avg:94.25ms +step:342/1670 train_time:32234ms step_avg:94.25ms +step:343/1670 train_time:32326ms step_avg:94.24ms +step:344/1670 train_time:32418ms step_avg:94.24ms +step:345/1670 train_time:32511ms step_avg:94.23ms +step:346/1670 train_time:32603ms step_avg:94.23ms +step:347/1670 train_time:32695ms step_avg:94.22ms +step:348/1670 train_time:32788ms step_avg:94.22ms +step:349/1670 train_time:32880ms step_avg:94.21ms +step:350/1670 train_time:32973ms step_avg:94.21ms +step:351/1670 train_time:33065ms step_avg:94.20ms +step:352/1670 train_time:33159ms step_avg:94.20ms +step:353/1670 train_time:33252ms step_avg:94.20ms +step:354/1670 train_time:33343ms step_avg:94.19ms +step:355/1670 train_time:33436ms step_avg:94.19ms +step:356/1670 
train_time:33528ms step_avg:94.18ms +step:357/1670 train_time:33621ms step_avg:94.18ms +step:358/1670 train_time:33714ms step_avg:94.17ms +step:359/1670 train_time:33806ms step_avg:94.17ms +step:360/1670 train_time:33898ms step_avg:94.16ms +step:361/1670 train_time:33991ms step_avg:94.16ms +step:362/1670 train_time:34083ms step_avg:94.15ms +step:363/1670 train_time:34176ms step_avg:94.15ms +step:364/1670 train_time:34269ms step_avg:94.14ms +step:365/1670 train_time:34361ms step_avg:94.14ms +step:366/1670 train_time:34454ms step_avg:94.14ms +step:367/1670 train_time:34545ms step_avg:94.13ms +step:368/1670 train_time:34638ms step_avg:94.13ms +step:369/1670 train_time:34732ms step_avg:94.12ms +step:370/1670 train_time:34824ms step_avg:94.12ms +step:371/1670 train_time:34917ms step_avg:94.12ms +step:372/1670 train_time:35009ms step_avg:94.11ms +step:373/1670 train_time:35102ms step_avg:94.11ms +step:374/1670 train_time:35195ms step_avg:94.10ms +step:375/1670 train_time:35288ms step_avg:94.10ms +step:375/1670 val_loss:3.8130 train_time:35377ms step_avg:94.34ms +step:376/1670 train_time:35404ms step_avg:94.16ms +step:377/1670 train_time:35478ms step_avg:94.11ms +step:378/1670 train_time:35578ms step_avg:94.12ms +step:379/1670 train_time:35672ms step_avg:94.12ms +step:380/1670 train_time:35764ms step_avg:94.12ms +step:381/1670 train_time:35856ms step_avg:94.11ms +step:382/1670 train_time:35947ms step_avg:94.10ms +step:383/1670 train_time:36039ms step_avg:94.10ms +step:384/1670 train_time:36130ms step_avg:94.09ms +step:385/1670 train_time:36222ms step_avg:94.08ms +step:386/1670 train_time:36313ms step_avg:94.08ms +step:387/1670 train_time:36408ms step_avg:94.08ms +step:388/1670 train_time:36502ms step_avg:94.08ms +step:389/1670 train_time:36596ms step_avg:94.08ms +step:390/1670 train_time:36688ms step_avg:94.07ms +step:391/1670 train_time:36782ms step_avg:94.07ms +step:392/1670 train_time:36874ms step_avg:94.07ms +step:393/1670 train_time:36966ms step_avg:94.06ms +step:394/1670 train_time:37057ms step_avg:94.05ms +step:395/1670 train_time:37149ms step_avg:94.05ms +step:396/1670 train_time:37240ms step_avg:94.04ms +step:397/1670 train_time:37332ms step_avg:94.04ms +step:398/1670 train_time:37425ms step_avg:94.03ms +step:399/1670 train_time:37518ms step_avg:94.03ms +step:400/1670 train_time:37611ms step_avg:94.03ms +step:401/1670 train_time:37704ms step_avg:94.03ms +step:402/1670 train_time:37797ms step_avg:94.02ms +step:403/1670 train_time:37889ms step_avg:94.02ms +step:404/1670 train_time:37981ms step_avg:94.01ms +step:405/1670 train_time:38073ms step_avg:94.01ms +step:406/1670 train_time:38165ms step_avg:94.00ms +step:407/1670 train_time:38257ms step_avg:94.00ms +step:408/1670 train_time:38350ms step_avg:93.99ms +step:409/1670 train_time:38443ms step_avg:93.99ms +step:410/1670 train_time:38536ms step_avg:93.99ms +step:411/1670 train_time:38629ms step_avg:93.99ms +step:412/1670 train_time:38721ms step_avg:93.98ms +step:413/1670 train_time:38813ms step_avg:93.98ms +step:414/1670 train_time:38906ms step_avg:93.98ms +step:415/1670 train_time:38999ms step_avg:93.97ms +step:416/1670 train_time:39090ms step_avg:93.97ms +step:417/1670 train_time:39183ms step_avg:93.96ms +step:418/1670 train_time:39275ms step_avg:93.96ms +step:419/1670 train_time:39367ms step_avg:93.96ms +step:420/1670 train_time:39460ms step_avg:93.95ms +step:421/1670 train_time:39552ms step_avg:93.95ms +step:422/1670 train_time:39645ms step_avg:93.94ms +step:423/1670 train_time:39737ms step_avg:93.94ms +step:424/1670 train_time:39829ms 
step_avg:93.94ms +step:425/1670 train_time:40145ms step_avg:94.46ms +step:426/1670 train_time:40235ms step_avg:94.45ms +step:427/1670 train_time:40325ms step_avg:94.44ms +step:428/1670 train_time:40416ms step_avg:94.43ms +step:429/1670 train_time:40507ms step_avg:94.42ms +step:430/1670 train_time:40599ms step_avg:94.42ms +step:431/1670 train_time:40690ms step_avg:94.41ms +step:432/1670 train_time:40782ms step_avg:94.40ms +step:433/1670 train_time:40873ms step_avg:94.40ms +step:434/1670 train_time:40964ms step_avg:94.39ms +step:435/1670 train_time:41058ms step_avg:94.39ms +step:436/1670 train_time:41155ms step_avg:94.39ms +step:437/1670 train_time:41249ms step_avg:94.39ms +step:438/1670 train_time:41343ms step_avg:94.39ms +step:439/1670 train_time:41435ms step_avg:94.39ms +step:440/1670 train_time:41527ms step_avg:94.38ms +step:441/1670 train_time:41619ms step_avg:94.37ms +step:442/1670 train_time:41710ms step_avg:94.37ms +step:443/1670 train_time:41801ms step_avg:94.36ms +step:444/1670 train_time:41892ms step_avg:94.35ms +step:445/1670 train_time:41985ms step_avg:94.35ms +step:446/1670 train_time:42078ms step_avg:94.35ms +step:447/1670 train_time:42172ms step_avg:94.34ms +step:448/1670 train_time:42265ms step_avg:94.34ms +step:449/1670 train_time:42358ms step_avg:94.34ms +step:450/1670 train_time:42451ms step_avg:94.34ms +step:451/1670 train_time:42543ms step_avg:94.33ms +step:452/1670 train_time:42635ms step_avg:94.32ms +step:453/1670 train_time:42727ms step_avg:94.32ms +step:454/1670 train_time:42818ms step_avg:94.31ms +step:455/1670 train_time:42909ms step_avg:94.31ms +step:456/1670 train_time:43002ms step_avg:94.30ms +step:457/1670 train_time:43095ms step_avg:94.30ms +step:458/1670 train_time:43187ms step_avg:94.30ms +step:459/1670 train_time:43281ms step_avg:94.29ms +step:460/1670 train_time:43374ms step_avg:94.29ms +step:461/1670 train_time:43468ms step_avg:94.29ms +step:462/1670 train_time:43560ms step_avg:94.29ms +step:463/1670 train_time:43652ms step_avg:94.28ms +step:464/1670 train_time:43744ms step_avg:94.28ms +step:465/1670 train_time:43836ms step_avg:94.27ms +step:466/1670 train_time:43927ms step_avg:94.26ms +step:467/1670 train_time:44020ms step_avg:94.26ms +step:468/1670 train_time:44111ms step_avg:94.25ms +step:469/1670 train_time:44204ms step_avg:94.25ms +step:470/1670 train_time:44297ms step_avg:94.25ms +step:471/1670 train_time:44389ms step_avg:94.24ms +step:472/1670 train_time:44484ms step_avg:94.25ms +step:473/1670 train_time:44578ms step_avg:94.24ms +step:474/1670 train_time:44670ms step_avg:94.24ms +step:475/1670 train_time:44762ms step_avg:94.24ms +step:476/1670 train_time:44855ms step_avg:94.23ms +step:477/1670 train_time:44947ms step_avg:94.23ms +step:478/1670 train_time:45038ms step_avg:94.22ms +step:479/1670 train_time:45130ms step_avg:94.22ms +step:480/1670 train_time:45224ms step_avg:94.22ms +step:481/1670 train_time:45316ms step_avg:94.21ms +step:482/1670 train_time:45409ms step_avg:94.21ms +step:483/1670 train_time:45503ms step_avg:94.21ms +step:484/1670 train_time:45596ms step_avg:94.21ms +step:485/1670 train_time:45688ms step_avg:94.20ms +step:486/1670 train_time:45780ms step_avg:94.20ms +step:487/1670 train_time:45873ms step_avg:94.19ms +step:488/1670 train_time:45965ms step_avg:94.19ms +step:489/1670 train_time:46057ms step_avg:94.19ms +step:490/1670 train_time:46149ms step_avg:94.18ms +step:491/1670 train_time:46242ms step_avg:94.18ms +step:492/1670 train_time:46334ms step_avg:94.18ms +step:493/1670 train_time:46427ms step_avg:94.17ms +step:494/1670 
train_time:46520ms step_avg:94.17ms +step:495/1670 train_time:46612ms step_avg:94.17ms +step:496/1670 train_time:46704ms step_avg:94.16ms +step:497/1670 train_time:46796ms step_avg:94.16ms +step:498/1670 train_time:46888ms step_avg:94.15ms +step:499/1670 train_time:46981ms step_avg:94.15ms +step:500/1670 train_time:47073ms step_avg:94.15ms +step:500/1670 val_loss:3.7124 train_time:47163ms step_avg:94.33ms +step:501/1670 train_time:47190ms step_avg:94.19ms +step:502/1670 train_time:47261ms step_avg:94.15ms +step:503/1670 train_time:47361ms step_avg:94.16ms +step:504/1670 train_time:47456ms step_avg:94.16ms +step:505/1670 train_time:47549ms step_avg:94.16ms +step:506/1670 train_time:47640ms step_avg:94.15ms +step:507/1670 train_time:47732ms step_avg:94.15ms +step:508/1670 train_time:47823ms step_avg:94.14ms +step:509/1670 train_time:47915ms step_avg:94.13ms +step:510/1670 train_time:48006ms step_avg:94.13ms +step:511/1670 train_time:48098ms step_avg:94.12ms +step:512/1670 train_time:48190ms step_avg:94.12ms +step:513/1670 train_time:48283ms step_avg:94.12ms +step:514/1670 train_time:48379ms step_avg:94.12ms +step:515/1670 train_time:48473ms step_avg:94.12ms +step:516/1670 train_time:48565ms step_avg:94.12ms +step:517/1670 train_time:48658ms step_avg:94.12ms +step:518/1670 train_time:48750ms step_avg:94.11ms +step:519/1670 train_time:48841ms step_avg:94.11ms +step:520/1670 train_time:48933ms step_avg:94.10ms +step:521/1670 train_time:49025ms step_avg:94.10ms +step:522/1670 train_time:49117ms step_avg:94.09ms +step:523/1670 train_time:49209ms step_avg:94.09ms +step:524/1670 train_time:49302ms step_avg:94.09ms +step:525/1670 train_time:49396ms step_avg:94.09ms +step:526/1670 train_time:49489ms step_avg:94.09ms +step:527/1670 train_time:49582ms step_avg:94.08ms +step:528/1670 train_time:49674ms step_avg:94.08ms +step:529/1670 train_time:49766ms step_avg:94.07ms +step:530/1670 train_time:49858ms step_avg:94.07ms +step:531/1670 train_time:49950ms step_avg:94.07ms +step:532/1670 train_time:50042ms step_avg:94.06ms +step:533/1670 train_time:50135ms step_avg:94.06ms +step:534/1670 train_time:50227ms step_avg:94.06ms +step:535/1670 train_time:50320ms step_avg:94.06ms +step:536/1670 train_time:50413ms step_avg:94.05ms +step:537/1670 train_time:50506ms step_avg:94.05ms +step:538/1670 train_time:50598ms step_avg:94.05ms +step:539/1670 train_time:50691ms step_avg:94.05ms +step:540/1670 train_time:50784ms step_avg:94.04ms +step:541/1670 train_time:50876ms step_avg:94.04ms +step:542/1670 train_time:50968ms step_avg:94.04ms +step:543/1670 train_time:51059ms step_avg:94.03ms +step:544/1670 train_time:51152ms step_avg:94.03ms +step:545/1670 train_time:51244ms step_avg:94.03ms +step:546/1670 train_time:51338ms step_avg:94.03ms +step:547/1670 train_time:51432ms step_avg:94.03ms +step:548/1670 train_time:51524ms step_avg:94.02ms +step:549/1670 train_time:51617ms step_avg:94.02ms +step:550/1670 train_time:51709ms step_avg:94.02ms +step:551/1670 train_time:51801ms step_avg:94.01ms +step:552/1670 train_time:51893ms step_avg:94.01ms +step:553/1670 train_time:51985ms step_avg:94.01ms +step:554/1670 train_time:52077ms step_avg:94.00ms +step:555/1670 train_time:52169ms step_avg:94.00ms +step:556/1670 train_time:52261ms step_avg:93.99ms +step:557/1670 train_time:52354ms step_avg:93.99ms +step:558/1670 train_time:52545ms step_avg:94.17ms +step:559/1670 train_time:52624ms step_avg:94.14ms +step:560/1670 train_time:52716ms step_avg:94.14ms +step:561/1670 train_time:52809ms step_avg:94.13ms +step:562/1670 train_time:52901ms 
step_avg:94.13ms +step:563/1670 train_time:52994ms step_avg:94.13ms +step:564/1670 train_time:53087ms step_avg:94.13ms +step:565/1670 train_time:53179ms step_avg:94.12ms +step:566/1670 train_time:53272ms step_avg:94.12ms +step:567/1670 train_time:53365ms step_avg:94.12ms +step:568/1670 train_time:53464ms step_avg:94.13ms +step:569/1670 train_time:53561ms step_avg:94.13ms +step:570/1670 train_time:53656ms step_avg:94.13ms +step:571/1670 train_time:53749ms step_avg:94.13ms +step:572/1670 train_time:53841ms step_avg:94.13ms +step:573/1670 train_time:53934ms step_avg:94.13ms +step:574/1670 train_time:54028ms step_avg:94.13ms +step:575/1670 train_time:54120ms step_avg:94.12ms +step:576/1670 train_time:54213ms step_avg:94.12ms +step:577/1670 train_time:54306ms step_avg:94.12ms +step:578/1670 train_time:54400ms step_avg:94.12ms +step:579/1670 train_time:54496ms step_avg:94.12ms +step:580/1670 train_time:54590ms step_avg:94.12ms +step:581/1670 train_time:54683ms step_avg:94.12ms +step:582/1670 train_time:54778ms step_avg:94.12ms +step:583/1670 train_time:54871ms step_avg:94.12ms +step:584/1670 train_time:54963ms step_avg:94.11ms +step:585/1670 train_time:55057ms step_avg:94.11ms +step:586/1670 train_time:55150ms step_avg:94.11ms +step:587/1670 train_time:55243ms step_avg:94.11ms +step:588/1670 train_time:55337ms step_avg:94.11ms +step:589/1670 train_time:55432ms step_avg:94.11ms +step:590/1670 train_time:55527ms step_avg:94.11ms +step:591/1670 train_time:55621ms step_avg:94.11ms +step:592/1670 train_time:55715ms step_avg:94.11ms +step:593/1670 train_time:55809ms step_avg:94.11ms +step:594/1670 train_time:55902ms step_avg:94.11ms +step:595/1670 train_time:55996ms step_avg:94.11ms +step:596/1670 train_time:56090ms step_avg:94.11ms +step:597/1670 train_time:56182ms step_avg:94.11ms +step:598/1670 train_time:56276ms step_avg:94.11ms +step:599/1670 train_time:56370ms step_avg:94.11ms +step:600/1670 train_time:56464ms step_avg:94.11ms +step:601/1670 train_time:56560ms step_avg:94.11ms +step:602/1670 train_time:56655ms step_avg:94.11ms +step:603/1670 train_time:56748ms step_avg:94.11ms +step:604/1670 train_time:56842ms step_avg:94.11ms +step:605/1670 train_time:56936ms step_avg:94.11ms +step:606/1670 train_time:57029ms step_avg:94.11ms +step:607/1670 train_time:57122ms step_avg:94.10ms +step:608/1670 train_time:57215ms step_avg:94.10ms +step:609/1670 train_time:57308ms step_avg:94.10ms +step:610/1670 train_time:57402ms step_avg:94.10ms +step:611/1670 train_time:57497ms step_avg:94.10ms +step:612/1670 train_time:57592ms step_avg:94.10ms +step:613/1670 train_time:57685ms step_avg:94.10ms +step:614/1670 train_time:57778ms step_avg:94.10ms +step:615/1670 train_time:57872ms step_avg:94.10ms +step:616/1670 train_time:57966ms step_avg:94.10ms +step:617/1670 train_time:58059ms step_avg:94.10ms +step:618/1670 train_time:58152ms step_avg:94.10ms +step:619/1670 train_time:58245ms step_avg:94.09ms +step:620/1670 train_time:58338ms step_avg:94.09ms +step:621/1670 train_time:58432ms step_avg:94.09ms +step:622/1670 train_time:58526ms step_avg:94.09ms +step:623/1670 train_time:58620ms step_avg:94.09ms +step:624/1670 train_time:58715ms step_avg:94.09ms +step:625/1670 train_time:58808ms step_avg:94.09ms +step:625/1670 val_loss:3.6122 train_time:58899ms step_avg:94.24ms +step:626/1670 train_time:58928ms step_avg:94.13ms +step:627/1670 train_time:59003ms step_avg:94.10ms +step:628/1670 train_time:59101ms step_avg:94.11ms +step:629/1670 train_time:59194ms step_avg:94.11ms +step:630/1670 train_time:59286ms step_avg:94.10ms 
+step:631/1670 train_time:59378ms step_avg:94.10ms +step:632/1670 train_time:59471ms step_avg:94.10ms +step:633/1670 train_time:59564ms step_avg:94.10ms +step:634/1670 train_time:59656ms step_avg:94.09ms +step:635/1670 train_time:59749ms step_avg:94.09ms +step:636/1670 train_time:59846ms step_avg:94.10ms +step:637/1670 train_time:59944ms step_avg:94.10ms +step:638/1670 train_time:60039ms step_avg:94.11ms +step:639/1670 train_time:60365ms step_avg:94.47ms +step:640/1670 train_time:60558ms step_avg:94.62ms +step:641/1670 train_time:60650ms step_avg:94.62ms +step:642/1670 train_time:60743ms step_avg:94.62ms +step:643/1670 train_time:60836ms step_avg:94.61ms +step:644/1670 train_time:60928ms step_avg:94.61ms +step:645/1670 train_time:61020ms step_avg:94.61ms +step:646/1670 train_time:61113ms step_avg:94.60ms +step:647/1670 train_time:61205ms step_avg:94.60ms +step:648/1670 train_time:61298ms step_avg:94.60ms +step:649/1670 train_time:61394ms step_avg:94.60ms +step:650/1670 train_time:61492ms step_avg:94.60ms +step:651/1670 train_time:61588ms step_avg:94.61ms +step:652/1670 train_time:61682ms step_avg:94.60ms +step:653/1670 train_time:61775ms step_avg:94.60ms +step:654/1670 train_time:61868ms step_avg:94.60ms +step:655/1670 train_time:61960ms step_avg:94.60ms +step:656/1670 train_time:62053ms step_avg:94.59ms +step:657/1670 train_time:62146ms step_avg:94.59ms +step:658/1670 train_time:62239ms step_avg:94.59ms +step:659/1670 train_time:62332ms step_avg:94.59ms +step:660/1670 train_time:62427ms step_avg:94.59ms +step:661/1670 train_time:62522ms step_avg:94.59ms +step:662/1670 train_time:62617ms step_avg:94.59ms +step:663/1670 train_time:62710ms step_avg:94.59ms +step:664/1670 train_time:62805ms step_avg:94.59ms +step:665/1670 train_time:62899ms step_avg:94.58ms +step:666/1670 train_time:62991ms step_avg:94.58ms +step:667/1670 train_time:63084ms step_avg:94.58ms +step:668/1670 train_time:63177ms step_avg:94.58ms +step:669/1670 train_time:63270ms step_avg:94.57ms +step:670/1670 train_time:63364ms step_avg:94.57ms +step:671/1670 train_time:63457ms step_avg:94.57ms +step:672/1670 train_time:63552ms step_avg:94.57ms +step:673/1670 train_time:63646ms step_avg:94.57ms +step:674/1670 train_time:63740ms step_avg:94.57ms +step:675/1670 train_time:63833ms step_avg:94.57ms +step:676/1670 train_time:63927ms step_avg:94.57ms +step:677/1670 train_time:64021ms step_avg:94.57ms +step:678/1670 train_time:64114ms step_avg:94.56ms +step:679/1670 train_time:64208ms step_avg:94.56ms +step:680/1670 train_time:64301ms step_avg:94.56ms +step:681/1670 train_time:64395ms step_avg:94.56ms +step:682/1670 train_time:64488ms step_avg:94.56ms +step:683/1670 train_time:64583ms step_avg:94.56ms +step:684/1670 train_time:64677ms step_avg:94.56ms +step:685/1670 train_time:64771ms step_avg:94.56ms +step:686/1670 train_time:64864ms step_avg:94.55ms +step:687/1670 train_time:64957ms step_avg:94.55ms +step:688/1670 train_time:65050ms step_avg:94.55ms +step:689/1670 train_time:65144ms step_avg:94.55ms +step:690/1670 train_time:65237ms step_avg:94.55ms +step:691/1670 train_time:65331ms step_avg:94.55ms +step:692/1670 train_time:65424ms step_avg:94.54ms +step:693/1670 train_time:65517ms step_avg:94.54ms +step:694/1670 train_time:65611ms step_avg:94.54ms +step:695/1670 train_time:65706ms step_avg:94.54ms +step:696/1670 train_time:65800ms step_avg:94.54ms +step:697/1670 train_time:65892ms step_avg:94.54ms +step:698/1670 train_time:65986ms step_avg:94.54ms +step:699/1670 train_time:66080ms step_avg:94.54ms +step:700/1670 train_time:66173ms 
step_avg:94.53ms +step:701/1670 train_time:66267ms step_avg:94.53ms +step:702/1670 train_time:66361ms step_avg:94.53ms +step:703/1670 train_time:66454ms step_avg:94.53ms +step:704/1670 train_time:66547ms step_avg:94.53ms +step:705/1670 train_time:66642ms step_avg:94.53ms +step:706/1670 train_time:66735ms step_avg:94.53ms +step:707/1670 train_time:66829ms step_avg:94.52ms +step:708/1670 train_time:66922ms step_avg:94.52ms +step:709/1670 train_time:67015ms step_avg:94.52ms +step:710/1670 train_time:67109ms step_avg:94.52ms +step:711/1670 train_time:67202ms step_avg:94.52ms +step:712/1670 train_time:67296ms step_avg:94.52ms +step:713/1670 train_time:67389ms step_avg:94.51ms +step:714/1670 train_time:67483ms step_avg:94.51ms +step:715/1670 train_time:67576ms step_avg:94.51ms +step:716/1670 train_time:67670ms step_avg:94.51ms +step:717/1670 train_time:67765ms step_avg:94.51ms +step:718/1670 train_time:67858ms step_avg:94.51ms +step:719/1670 train_time:67951ms step_avg:94.51ms +step:720/1670 train_time:68045ms step_avg:94.51ms +step:721/1670 train_time:68139ms step_avg:94.51ms +step:722/1670 train_time:68232ms step_avg:94.50ms +step:723/1670 train_time:68326ms step_avg:94.50ms +step:724/1670 train_time:68419ms step_avg:94.50ms +step:725/1670 train_time:68513ms step_avg:94.50ms +step:726/1670 train_time:68608ms step_avg:94.50ms +step:727/1670 train_time:68701ms step_avg:94.50ms +step:728/1670 train_time:68795ms step_avg:94.50ms +step:729/1670 train_time:68888ms step_avg:94.50ms +step:730/1670 train_time:68983ms step_avg:94.50ms +step:731/1670 train_time:69076ms step_avg:94.49ms +step:732/1670 train_time:69169ms step_avg:94.49ms +step:733/1670 train_time:69263ms step_avg:94.49ms +step:734/1670 train_time:69356ms step_avg:94.49ms +step:735/1670 train_time:69450ms step_avg:94.49ms +step:736/1670 train_time:69544ms step_avg:94.49ms +step:737/1670 train_time:69637ms step_avg:94.49ms +step:738/1670 train_time:69732ms step_avg:94.49ms +step:739/1670 train_time:69825ms step_avg:94.49ms +step:740/1670 train_time:69919ms step_avg:94.48ms +step:741/1670 train_time:70012ms step_avg:94.48ms +step:742/1670 train_time:70106ms step_avg:94.48ms +step:743/1670 train_time:70200ms step_avg:94.48ms +step:744/1670 train_time:70293ms step_avg:94.48ms +step:745/1670 train_time:70386ms step_avg:94.48ms +step:746/1670 train_time:70480ms step_avg:94.48ms +step:747/1670 train_time:70574ms step_avg:94.48ms +step:748/1670 train_time:70667ms step_avg:94.47ms +step:749/1670 train_time:70761ms step_avg:94.47ms +step:750/1670 train_time:70854ms step_avg:94.47ms +step:750/1670 val_loss:3.5615 train_time:70945ms step_avg:94.59ms +step:751/1670 train_time:70972ms step_avg:94.50ms +step:752/1670 train_time:71049ms step_avg:94.48ms +step:753/1670 train_time:71149ms step_avg:94.49ms +step:754/1670 train_time:71244ms step_avg:94.49ms +step:755/1670 train_time:71338ms step_avg:94.49ms +step:756/1670 train_time:71430ms step_avg:94.48ms +step:757/1670 train_time:71523ms step_avg:94.48ms +step:758/1670 train_time:71616ms step_avg:94.48ms +step:759/1670 train_time:71709ms step_avg:94.48ms +step:760/1670 train_time:71802ms step_avg:94.48ms +step:761/1670 train_time:71894ms step_avg:94.47ms +step:762/1670 train_time:71988ms step_avg:94.47ms +step:763/1670 train_time:72084ms step_avg:94.47ms +step:764/1670 train_time:72180ms step_avg:94.48ms +step:765/1670 train_time:72274ms step_avg:94.48ms +step:766/1670 train_time:72367ms step_avg:94.47ms +step:767/1670 train_time:72460ms step_avg:94.47ms +step:768/1670 train_time:72554ms step_avg:94.47ms 
+step:769/1670 train_time:72646ms step_avg:94.47ms +step:770/1670 train_time:72740ms step_avg:94.47ms +step:771/1670 train_time:72832ms step_avg:94.46ms +step:772/1670 train_time:72925ms step_avg:94.46ms +step:773/1670 train_time:73021ms step_avg:94.46ms +step:774/1670 train_time:73118ms step_avg:94.47ms +step:775/1670 train_time:73213ms step_avg:94.47ms +step:776/1670 train_time:73306ms step_avg:94.47ms +step:777/1670 train_time:73399ms step_avg:94.46ms +step:778/1670 train_time:73492ms step_avg:94.46ms +step:779/1670 train_time:73585ms step_avg:94.46ms +step:780/1670 train_time:73679ms step_avg:94.46ms +step:781/1670 train_time:73772ms step_avg:94.46ms +step:782/1670 train_time:73865ms step_avg:94.46ms +step:783/1670 train_time:73958ms step_avg:94.45ms +step:784/1670 train_time:74053ms step_avg:94.45ms +step:785/1670 train_time:74147ms step_avg:94.45ms +step:786/1670 train_time:74241ms step_avg:94.45ms +step:787/1670 train_time:74335ms step_avg:94.45ms +step:788/1670 train_time:74428ms step_avg:94.45ms +step:789/1670 train_time:74522ms step_avg:94.45ms +step:790/1670 train_time:74616ms step_avg:94.45ms +step:791/1670 train_time:74709ms step_avg:94.45ms +step:792/1670 train_time:74802ms step_avg:94.45ms +step:793/1670 train_time:74895ms step_avg:94.45ms +step:794/1670 train_time:74988ms step_avg:94.44ms +step:795/1670 train_time:75082ms step_avg:94.44ms +step:796/1670 train_time:75178ms step_avg:94.44ms +step:797/1670 train_time:75272ms step_avg:94.44ms +step:798/1670 train_time:75365ms step_avg:94.44ms +step:799/1670 train_time:75459ms step_avg:94.44ms +step:800/1670 train_time:75552ms step_avg:94.44ms +step:801/1670 train_time:75646ms step_avg:94.44ms +step:802/1670 train_time:75740ms step_avg:94.44ms +step:803/1670 train_time:75834ms step_avg:94.44ms +step:804/1670 train_time:75926ms step_avg:94.44ms +step:805/1670 train_time:76019ms step_avg:94.43ms +step:806/1670 train_time:76114ms step_avg:94.43ms +step:807/1670 train_time:76208ms step_avg:94.43ms +step:808/1670 train_time:76302ms step_avg:94.43ms +step:809/1670 train_time:76396ms step_avg:94.43ms +step:810/1670 train_time:76489ms step_avg:94.43ms +step:811/1670 train_time:76583ms step_avg:94.43ms +step:812/1670 train_time:76676ms step_avg:94.43ms +step:813/1670 train_time:76769ms step_avg:94.43ms +step:814/1670 train_time:76863ms step_avg:94.43ms +step:815/1670 train_time:76957ms step_avg:94.43ms +step:816/1670 train_time:77052ms step_avg:94.43ms +step:817/1670 train_time:77146ms step_avg:94.43ms +step:818/1670 train_time:77240ms step_avg:94.43ms +step:819/1670 train_time:77334ms step_avg:94.42ms +step:820/1670 train_time:77426ms step_avg:94.42ms +step:821/1670 train_time:77521ms step_avg:94.42ms +step:822/1670 train_time:77614ms step_avg:94.42ms +step:823/1670 train_time:77707ms step_avg:94.42ms +step:824/1670 train_time:77801ms step_avg:94.42ms +step:825/1670 train_time:77895ms step_avg:94.42ms +step:826/1670 train_time:77988ms step_avg:94.42ms +step:827/1670 train_time:78082ms step_avg:94.42ms +step:828/1670 train_time:78176ms step_avg:94.42ms +step:829/1670 train_time:78270ms step_avg:94.41ms +step:830/1670 train_time:78364ms step_avg:94.41ms +step:831/1670 train_time:78458ms step_avg:94.41ms +step:832/1670 train_time:78551ms step_avg:94.41ms +step:833/1670 train_time:78645ms step_avg:94.41ms +step:834/1670 train_time:78739ms step_avg:94.41ms +step:835/1670 train_time:78832ms step_avg:94.41ms +step:836/1670 train_time:78925ms step_avg:94.41ms +step:837/1670 train_time:79020ms step_avg:94.41ms +step:838/1670 train_time:79114ms 
step_avg:94.41ms +step:839/1670 train_time:79207ms step_avg:94.41ms +step:840/1670 train_time:79301ms step_avg:94.41ms +step:841/1670 train_time:79395ms step_avg:94.41ms +step:842/1670 train_time:79488ms step_avg:94.40ms +step:843/1670 train_time:79582ms step_avg:94.40ms +step:844/1670 train_time:79676ms step_avg:94.40ms +step:845/1670 train_time:79769ms step_avg:94.40ms +step:846/1670 train_time:79862ms step_avg:94.40ms +step:847/1670 train_time:79956ms step_avg:94.40ms +step:848/1670 train_time:80049ms step_avg:94.40ms +step:849/1670 train_time:80143ms step_avg:94.40ms +step:850/1670 train_time:80237ms step_avg:94.40ms +step:851/1670 train_time:80571ms step_avg:94.68ms +step:852/1670 train_time:80758ms step_avg:94.79ms +step:853/1670 train_time:80849ms step_avg:94.78ms +step:854/1670 train_time:80942ms step_avg:94.78ms +step:855/1670 train_time:81035ms step_avg:94.78ms +step:856/1670 train_time:81127ms step_avg:94.78ms +step:857/1670 train_time:81221ms step_avg:94.77ms +step:858/1670 train_time:81313ms step_avg:94.77ms +step:859/1670 train_time:81406ms step_avg:94.77ms +step:860/1670 train_time:81498ms step_avg:94.77ms +step:861/1670 train_time:81594ms step_avg:94.77ms +step:862/1670 train_time:81691ms step_avg:94.77ms +step:863/1670 train_time:81786ms step_avg:94.77ms +step:864/1670 train_time:81880ms step_avg:94.77ms +step:865/1670 train_time:81973ms step_avg:94.77ms +step:866/1670 train_time:82067ms step_avg:94.77ms +step:867/1670 train_time:82160ms step_avg:94.76ms +step:868/1670 train_time:82252ms step_avg:94.76ms +step:869/1670 train_time:82345ms step_avg:94.76ms +step:870/1670 train_time:82438ms step_avg:94.76ms +step:871/1670 train_time:82532ms step_avg:94.76ms +step:872/1670 train_time:82627ms step_avg:94.76ms +step:873/1670 train_time:82722ms step_avg:94.76ms +step:874/1670 train_time:82817ms step_avg:94.76ms +step:875/1670 train_time:82910ms step_avg:94.75ms +step:875/1670 val_loss:3.5171 train_time:83001ms step_avg:94.86ms +step:876/1670 train_time:83031ms step_avg:94.78ms +step:877/1670 train_time:83103ms step_avg:94.76ms +step:878/1670 train_time:83204ms step_avg:94.77ms +step:879/1670 train_time:83300ms step_avg:94.77ms +step:880/1670 train_time:83393ms step_avg:94.76ms +step:881/1670 train_time:83486ms step_avg:94.76ms +step:882/1670 train_time:83578ms step_avg:94.76ms +step:883/1670 train_time:83671ms step_avg:94.76ms +step:884/1670 train_time:83764ms step_avg:94.76ms +step:885/1670 train_time:83856ms step_avg:94.75ms +step:886/1670 train_time:83950ms step_avg:94.75ms +step:887/1670 train_time:84044ms step_avg:94.75ms +step:888/1670 train_time:84141ms step_avg:94.75ms +step:889/1670 train_time:84238ms step_avg:94.76ms +step:890/1670 train_time:84332ms step_avg:94.76ms +step:891/1670 train_time:84426ms step_avg:94.75ms +step:892/1670 train_time:84519ms step_avg:94.75ms +step:893/1670 train_time:84612ms step_avg:94.75ms +step:894/1670 train_time:84704ms step_avg:94.75ms +step:895/1670 train_time:84797ms step_avg:94.75ms +step:896/1670 train_time:84890ms step_avg:94.74ms +step:897/1670 train_time:84984ms step_avg:94.74ms +step:898/1670 train_time:85078ms step_avg:94.74ms +step:899/1670 train_time:85173ms step_avg:94.74ms +step:900/1670 train_time:85267ms step_avg:94.74ms +step:901/1670 train_time:85361ms step_avg:94.74ms +step:902/1670 train_time:85455ms step_avg:94.74ms +step:903/1670 train_time:85549ms step_avg:94.74ms +step:904/1670 train_time:85642ms step_avg:94.74ms +step:905/1670 train_time:85735ms step_avg:94.73ms +step:906/1670 train_time:85827ms step_avg:94.73ms 
+step:907/1670 train_time:85921ms step_avg:94.73ms +step:908/1670 train_time:86014ms step_avg:94.73ms +step:909/1670 train_time:86108ms step_avg:94.73ms +step:910/1670 train_time:86203ms step_avg:94.73ms +step:911/1670 train_time:86298ms step_avg:94.73ms +step:912/1670 train_time:86392ms step_avg:94.73ms +step:913/1670 train_time:86485ms step_avg:94.73ms +step:914/1670 train_time:86579ms step_avg:94.73ms +step:915/1670 train_time:86673ms step_avg:94.72ms +step:916/1670 train_time:86765ms step_avg:94.72ms +step:917/1670 train_time:86858ms step_avg:94.72ms +step:918/1670 train_time:86951ms step_avg:94.72ms +step:919/1670 train_time:87045ms step_avg:94.72ms +step:920/1670 train_time:87140ms step_avg:94.72ms +step:921/1670 train_time:87234ms step_avg:94.72ms +step:922/1670 train_time:87328ms step_avg:94.72ms +step:923/1670 train_time:87421ms step_avg:94.71ms +step:924/1670 train_time:87516ms step_avg:94.71ms +step:925/1670 train_time:87609ms step_avg:94.71ms +step:926/1670 train_time:87702ms step_avg:94.71ms +step:927/1670 train_time:87795ms step_avg:94.71ms +step:928/1670 train_time:87889ms step_avg:94.71ms +step:929/1670 train_time:87982ms step_avg:94.71ms +step:930/1670 train_time:88076ms step_avg:94.71ms +step:931/1670 train_time:88170ms step_avg:94.70ms +step:932/1670 train_time:88263ms step_avg:94.70ms +step:933/1670 train_time:88358ms step_avg:94.70ms +step:934/1670 train_time:88452ms step_avg:94.70ms +step:935/1670 train_time:88546ms step_avg:94.70ms +step:936/1670 train_time:88639ms step_avg:94.70ms +step:937/1670 train_time:88733ms step_avg:94.70ms +step:938/1670 train_time:88826ms step_avg:94.70ms +step:939/1670 train_time:88919ms step_avg:94.70ms +step:940/1670 train_time:89013ms step_avg:94.70ms +step:941/1670 train_time:89107ms step_avg:94.69ms +step:942/1670 train_time:89201ms step_avg:94.69ms +step:943/1670 train_time:89294ms step_avg:94.69ms +step:944/1670 train_time:89387ms step_avg:94.69ms +step:945/1670 train_time:89482ms step_avg:94.69ms +step:946/1670 train_time:89576ms step_avg:94.69ms +step:947/1670 train_time:89671ms step_avg:94.69ms +step:948/1670 train_time:89763ms step_avg:94.69ms +step:949/1670 train_time:89857ms step_avg:94.69ms +step:950/1670 train_time:89950ms step_avg:94.68ms +step:951/1670 train_time:90043ms step_avg:94.68ms +step:952/1670 train_time:90137ms step_avg:94.68ms +step:953/1670 train_time:90232ms step_avg:94.68ms +step:954/1670 train_time:90325ms step_avg:94.68ms +step:955/1670 train_time:90419ms step_avg:94.68ms +step:956/1670 train_time:90514ms step_avg:94.68ms +step:957/1670 train_time:90608ms step_avg:94.68ms +step:958/1670 train_time:90701ms step_avg:94.68ms +step:959/1670 train_time:90794ms step_avg:94.68ms +step:960/1670 train_time:90888ms step_avg:94.68ms +step:961/1670 train_time:90982ms step_avg:94.67ms +step:962/1670 train_time:91075ms step_avg:94.67ms +step:963/1670 train_time:91168ms step_avg:94.67ms +step:964/1670 train_time:91262ms step_avg:94.67ms +step:965/1670 train_time:91356ms step_avg:94.67ms +step:966/1670 train_time:91450ms step_avg:94.67ms +step:967/1670 train_time:91544ms step_avg:94.67ms +step:968/1670 train_time:91638ms step_avg:94.67ms +step:969/1670 train_time:91731ms step_avg:94.67ms +step:970/1670 train_time:91824ms step_avg:94.66ms +step:971/1670 train_time:91918ms step_avg:94.66ms +step:972/1670 train_time:92011ms step_avg:94.66ms +step:973/1670 train_time:92104ms step_avg:94.66ms +step:974/1670 train_time:92197ms step_avg:94.66ms +step:975/1670 train_time:92291ms step_avg:94.66ms +step:976/1670 train_time:92385ms 
step_avg:94.66ms +step:977/1670 train_time:92479ms step_avg:94.66ms +step:978/1670 train_time:92572ms step_avg:94.65ms +step:979/1670 train_time:92666ms step_avg:94.65ms +step:980/1670 train_time:92759ms step_avg:94.65ms +step:981/1670 train_time:92853ms step_avg:94.65ms +step:982/1670 train_time:92946ms step_avg:94.65ms +step:983/1670 train_time:93041ms step_avg:94.65ms +step:984/1670 train_time:93134ms step_avg:94.65ms +step:985/1670 train_time:93228ms step_avg:94.65ms +step:986/1670 train_time:93322ms step_avg:94.65ms +step:987/1670 train_time:93416ms step_avg:94.65ms +step:988/1670 train_time:93510ms step_avg:94.65ms +step:989/1670 train_time:93603ms step_avg:94.64ms +step:990/1670 train_time:93697ms step_avg:94.64ms +step:991/1670 train_time:93791ms step_avg:94.64ms +step:992/1670 train_time:93885ms step_avg:94.64ms +step:993/1670 train_time:93979ms step_avg:94.64ms +step:994/1670 train_time:94073ms step_avg:94.64ms +step:995/1670 train_time:94167ms step_avg:94.64ms +step:996/1670 train_time:94261ms step_avg:94.64ms +step:997/1670 train_time:94354ms step_avg:94.64ms +step:998/1670 train_time:94448ms step_avg:94.64ms +step:999/1670 train_time:94542ms step_avg:94.64ms +step:1000/1670 train_time:94635ms step_avg:94.63ms +step:1000/1670 val_loss:3.4675 train_time:94726ms step_avg:94.73ms +step:1001/1670 train_time:94754ms step_avg:94.66ms +step:1002/1670 train_time:94828ms step_avg:94.64ms +step:1003/1670 train_time:94928ms step_avg:94.64ms +step:1004/1670 train_time:95023ms step_avg:94.64ms +step:1005/1670 train_time:95117ms step_avg:94.64ms +step:1006/1670 train_time:95209ms step_avg:94.64ms +step:1007/1670 train_time:95302ms step_avg:94.64ms +step:1008/1670 train_time:95394ms step_avg:94.64ms +step:1009/1670 train_time:95487ms step_avg:94.64ms +step:1010/1670 train_time:95580ms step_avg:94.63ms +step:1011/1670 train_time:95674ms step_avg:94.63ms +step:1012/1670 train_time:95767ms step_avg:94.63ms +step:1013/1670 train_time:95864ms step_avg:94.63ms +step:1014/1670 train_time:95958ms step_avg:94.63ms +step:1015/1670 train_time:96053ms step_avg:94.63ms +step:1016/1670 train_time:96147ms step_avg:94.63ms +step:1017/1670 train_time:96240ms step_avg:94.63ms +step:1018/1670 train_time:96334ms step_avg:94.63ms +step:1019/1670 train_time:96427ms step_avg:94.63ms +step:1020/1670 train_time:96521ms step_avg:94.63ms +step:1021/1670 train_time:96614ms step_avg:94.63ms +step:1022/1670 train_time:96707ms step_avg:94.63ms +step:1023/1670 train_time:96801ms step_avg:94.62ms +step:1024/1670 train_time:96896ms step_avg:94.62ms +step:1025/1670 train_time:96990ms step_avg:94.62ms +step:1026/1670 train_time:97085ms step_avg:94.62ms +step:1027/1670 train_time:97178ms step_avg:94.62ms +step:1028/1670 train_time:97272ms step_avg:94.62ms +step:1029/1670 train_time:97365ms step_avg:94.62ms +step:1030/1670 train_time:97458ms step_avg:94.62ms +step:1031/1670 train_time:97552ms step_avg:94.62ms +step:1032/1670 train_time:97645ms step_avg:94.62ms +step:1033/1670 train_time:97739ms step_avg:94.62ms +step:1034/1670 train_time:97833ms step_avg:94.62ms +step:1035/1670 train_time:97926ms step_avg:94.61ms +step:1036/1670 train_time:98022ms step_avg:94.62ms +step:1037/1670 train_time:98116ms step_avg:94.62ms +step:1038/1670 train_time:98209ms step_avg:94.61ms +step:1039/1670 train_time:98304ms step_avg:94.61ms +step:1040/1670 train_time:98397ms step_avg:94.61ms +step:1041/1670 train_time:98490ms step_avg:94.61ms +step:1042/1670 train_time:98583ms step_avg:94.61ms +step:1043/1670 train_time:98677ms step_avg:94.61ms 
+step:1044/1670 train_time:98771ms step_avg:94.61ms +step:1045/1670 train_time:98865ms step_avg:94.61ms +step:1046/1670 train_time:98960ms step_avg:94.61ms +step:1047/1670 train_time:99054ms step_avg:94.61ms +step:1048/1670 train_time:99147ms step_avg:94.61ms +step:1049/1670 train_time:99241ms step_avg:94.61ms +step:1050/1670 train_time:99335ms step_avg:94.60ms +step:1051/1670 train_time:99429ms step_avg:94.60ms +step:1052/1670 train_time:99522ms step_avg:94.60ms +step:1053/1670 train_time:99615ms step_avg:94.60ms +step:1054/1670 train_time:99708ms step_avg:94.60ms +step:1055/1670 train_time:99802ms step_avg:94.60ms +step:1056/1670 train_time:99896ms step_avg:94.60ms +step:1057/1670 train_time:99990ms step_avg:94.60ms +step:1058/1670 train_time:100084ms step_avg:94.60ms +step:1059/1670 train_time:100179ms step_avg:94.60ms +step:1060/1670 train_time:100273ms step_avg:94.60ms +step:1061/1670 train_time:100367ms step_avg:94.60ms +step:1062/1670 train_time:100712ms step_avg:94.83ms +step:1063/1670 train_time:100883ms step_avg:94.90ms +step:1064/1670 train_time:100976ms step_avg:94.90ms +step:1065/1670 train_time:101068ms step_avg:94.90ms +step:1066/1670 train_time:101160ms step_avg:94.90ms +step:1067/1670 train_time:101253ms step_avg:94.89ms +step:1068/1670 train_time:101345ms step_avg:94.89ms +step:1069/1670 train_time:101438ms step_avg:94.89ms +step:1070/1670 train_time:101530ms step_avg:94.89ms +step:1071/1670 train_time:101622ms step_avg:94.88ms +step:1072/1670 train_time:101718ms step_avg:94.89ms +step:1073/1670 train_time:101814ms step_avg:94.89ms +step:1074/1670 train_time:101909ms step_avg:94.89ms +step:1075/1670 train_time:102004ms step_avg:94.89ms +step:1076/1670 train_time:102097ms step_avg:94.89ms +step:1077/1670 train_time:102191ms step_avg:94.89ms +step:1078/1670 train_time:102284ms step_avg:94.88ms +step:1079/1670 train_time:102378ms step_avg:94.88ms +step:1080/1670 train_time:102471ms step_avg:94.88ms +step:1081/1670 train_time:102563ms step_avg:94.88ms +step:1082/1670 train_time:102658ms step_avg:94.88ms +step:1083/1670 train_time:102752ms step_avg:94.88ms +step:1084/1670 train_time:102846ms step_avg:94.88ms +step:1085/1670 train_time:102940ms step_avg:94.88ms +step:1086/1670 train_time:103034ms step_avg:94.87ms +step:1087/1670 train_time:103128ms step_avg:94.87ms +step:1088/1670 train_time:103221ms step_avg:94.87ms +step:1089/1670 train_time:103314ms step_avg:94.87ms +step:1090/1670 train_time:103407ms step_avg:94.87ms +step:1091/1670 train_time:103500ms step_avg:94.87ms +step:1092/1670 train_time:103594ms step_avg:94.87ms +step:1093/1670 train_time:103689ms step_avg:94.87ms +step:1094/1670 train_time:103783ms step_avg:94.87ms +step:1095/1670 train_time:103877ms step_avg:94.87ms +step:1096/1670 train_time:103971ms step_avg:94.86ms +step:1097/1670 train_time:104065ms step_avg:94.86ms +step:1098/1670 train_time:104158ms step_avg:94.86ms +step:1099/1670 train_time:104252ms step_avg:94.86ms +step:1100/1670 train_time:104345ms step_avg:94.86ms +step:1101/1670 train_time:104439ms step_avg:94.86ms +step:1102/1670 train_time:104533ms step_avg:94.86ms +step:1103/1670 train_time:104626ms step_avg:94.86ms +step:1104/1670 train_time:104720ms step_avg:94.86ms +step:1105/1670 train_time:104815ms step_avg:94.86ms +step:1106/1670 train_time:104908ms step_avg:94.85ms +step:1107/1670 train_time:105002ms step_avg:94.85ms +step:1108/1670 train_time:105096ms step_avg:94.85ms +step:1109/1670 train_time:105189ms step_avg:94.85ms +step:1110/1670 train_time:105283ms step_avg:94.85ms +step:1111/1670 
train_time:105376ms step_avg:94.85ms +step:1112/1670 train_time:105469ms step_avg:94.85ms +step:1113/1670 train_time:105562ms step_avg:94.84ms +step:1114/1670 train_time:105657ms step_avg:94.84ms +step:1115/1670 train_time:105858ms step_avg:94.94ms +step:1116/1670 train_time:105929ms step_avg:94.92ms +step:1117/1670 train_time:106022ms step_avg:94.92ms +step:1118/1670 train_time:106115ms step_avg:94.92ms +step:1119/1670 train_time:106208ms step_avg:94.91ms +step:1120/1670 train_time:106301ms step_avg:94.91ms +step:1121/1670 train_time:106395ms step_avg:94.91ms +step:1122/1670 train_time:106488ms step_avg:94.91ms +step:1123/1670 train_time:106581ms step_avg:94.91ms +step:1124/1670 train_time:106674ms step_avg:94.91ms +step:1125/1670 train_time:106774ms step_avg:94.91ms +step:1125/1670 val_loss:3.4153 train_time:106871ms step_avg:95.00ms +step:1126/1670 train_time:106898ms step_avg:94.94ms +step:1127/1670 train_time:106978ms step_avg:94.92ms +step:1128/1670 train_time:107080ms step_avg:94.93ms +step:1129/1670 train_time:107175ms step_avg:94.93ms +step:1130/1670 train_time:107268ms step_avg:94.93ms +step:1131/1670 train_time:107361ms step_avg:94.93ms +step:1132/1670 train_time:107455ms step_avg:94.92ms +step:1133/1670 train_time:107548ms step_avg:94.92ms +step:1134/1670 train_time:107641ms step_avg:94.92ms +step:1135/1670 train_time:107735ms step_avg:94.92ms +step:1136/1670 train_time:107832ms step_avg:94.92ms +step:1137/1670 train_time:107929ms step_avg:94.92ms +step:1138/1670 train_time:108025ms step_avg:94.93ms +step:1139/1670 train_time:108120ms step_avg:94.93ms +step:1140/1670 train_time:108215ms step_avg:94.93ms +step:1141/1670 train_time:108309ms step_avg:94.92ms +step:1142/1670 train_time:108402ms step_avg:94.92ms +step:1143/1670 train_time:108496ms step_avg:94.92ms +step:1144/1670 train_time:108590ms step_avg:94.92ms +step:1145/1670 train_time:108683ms step_avg:94.92ms +step:1146/1670 train_time:108778ms step_avg:94.92ms +step:1147/1670 train_time:108875ms step_avg:94.92ms +step:1148/1670 train_time:108971ms step_avg:94.92ms +step:1149/1670 train_time:109066ms step_avg:94.92ms +step:1150/1670 train_time:109160ms step_avg:94.92ms +step:1151/1670 train_time:109256ms step_avg:94.92ms +step:1152/1670 train_time:109350ms step_avg:94.92ms +step:1153/1670 train_time:109443ms step_avg:94.92ms +step:1154/1670 train_time:109536ms step_avg:94.92ms +step:1155/1670 train_time:109631ms step_avg:94.92ms +step:1156/1670 train_time:109725ms step_avg:94.92ms +step:1157/1670 train_time:109820ms step_avg:94.92ms +step:1158/1670 train_time:109915ms step_avg:94.92ms +step:1159/1670 train_time:110011ms step_avg:94.92ms +step:1160/1670 train_time:110106ms step_avg:94.92ms +step:1161/1670 train_time:110200ms step_avg:94.92ms +step:1162/1670 train_time:110295ms step_avg:94.92ms +step:1163/1670 train_time:110390ms step_avg:94.92ms +step:1164/1670 train_time:110485ms step_avg:94.92ms +step:1165/1670 train_time:110578ms step_avg:94.92ms +step:1166/1670 train_time:110672ms step_avg:94.92ms +step:1167/1670 train_time:110766ms step_avg:94.92ms +step:1168/1670 train_time:110860ms step_avg:94.91ms +step:1169/1670 train_time:110955ms step_avg:94.91ms +step:1170/1670 train_time:111050ms step_avg:94.91ms +step:1171/1670 train_time:111145ms step_avg:94.91ms +step:1172/1670 train_time:111240ms step_avg:94.91ms +step:1173/1670 train_time:111334ms step_avg:94.91ms +step:1174/1670 train_time:111429ms step_avg:94.91ms +step:1175/1670 train_time:111524ms step_avg:94.91ms +step:1176/1670 train_time:111618ms step_avg:94.91ms 
+step:1177/1670 train_time:111712ms step_avg:94.91ms +step:1178/1670 train_time:111806ms step_avg:94.91ms +step:1179/1670 train_time:111900ms step_avg:94.91ms +step:1180/1670 train_time:111994ms step_avg:94.91ms +step:1181/1670 train_time:112090ms step_avg:94.91ms +step:1182/1670 train_time:112184ms step_avg:94.91ms +step:1183/1670 train_time:112278ms step_avg:94.91ms +step:1184/1670 train_time:112373ms step_avg:94.91ms +step:1185/1670 train_time:112468ms step_avg:94.91ms +step:1186/1670 train_time:112561ms step_avg:94.91ms +step:1187/1670 train_time:112656ms step_avg:94.91ms +step:1188/1670 train_time:112750ms step_avg:94.91ms +step:1189/1670 train_time:112845ms step_avg:94.91ms +step:1190/1670 train_time:112938ms step_avg:94.91ms +step:1191/1670 train_time:113033ms step_avg:94.91ms +step:1192/1670 train_time:113129ms step_avg:94.91ms +step:1193/1670 train_time:113223ms step_avg:94.91ms +step:1194/1670 train_time:113317ms step_avg:94.91ms +step:1195/1670 train_time:113412ms step_avg:94.91ms +step:1196/1670 train_time:113507ms step_avg:94.91ms +step:1197/1670 train_time:113602ms step_avg:94.91ms +step:1198/1670 train_time:113696ms step_avg:94.91ms +step:1199/1670 train_time:113791ms step_avg:94.91ms +step:1200/1670 train_time:113886ms step_avg:94.90ms +step:1201/1670 train_time:113979ms step_avg:94.90ms +step:1202/1670 train_time:114073ms step_avg:94.90ms +step:1203/1670 train_time:114168ms step_avg:94.90ms +step:1204/1670 train_time:114262ms step_avg:94.90ms +step:1205/1670 train_time:114358ms step_avg:94.90ms +step:1206/1670 train_time:114452ms step_avg:94.90ms +step:1207/1670 train_time:114546ms step_avg:94.90ms +step:1208/1670 train_time:114641ms step_avg:94.90ms +step:1209/1670 train_time:114735ms step_avg:94.90ms +step:1210/1670 train_time:114830ms step_avg:94.90ms +step:1211/1670 train_time:114925ms step_avg:94.90ms +step:1212/1670 train_time:115018ms step_avg:94.90ms +step:1213/1670 train_time:115113ms step_avg:94.90ms +step:1214/1670 train_time:115208ms step_avg:94.90ms +step:1215/1670 train_time:115302ms step_avg:94.90ms +step:1216/1670 train_time:115396ms step_avg:94.90ms +step:1217/1670 train_time:115493ms step_avg:94.90ms +step:1218/1670 train_time:115587ms step_avg:94.90ms +step:1219/1670 train_time:115681ms step_avg:94.90ms +step:1220/1670 train_time:115776ms step_avg:94.90ms +step:1221/1670 train_time:115870ms step_avg:94.90ms +step:1222/1670 train_time:115964ms step_avg:94.90ms +step:1223/1670 train_time:116058ms step_avg:94.90ms +step:1224/1670 train_time:116153ms step_avg:94.90ms +step:1225/1670 train_time:116248ms step_avg:94.90ms +step:1226/1670 train_time:116342ms step_avg:94.90ms +step:1227/1670 train_time:116436ms step_avg:94.89ms +step:1228/1670 train_time:116530ms step_avg:94.89ms +step:1229/1670 train_time:116625ms step_avg:94.89ms +step:1230/1670 train_time:116719ms step_avg:94.89ms +step:1231/1670 train_time:116813ms step_avg:94.89ms +step:1232/1670 train_time:116908ms step_avg:94.89ms +step:1233/1670 train_time:117002ms step_avg:94.89ms +step:1234/1670 train_time:117097ms step_avg:94.89ms +step:1235/1670 train_time:117192ms step_avg:94.89ms +step:1236/1670 train_time:117286ms step_avg:94.89ms +step:1237/1670 train_time:117380ms step_avg:94.89ms +step:1238/1670 train_time:117475ms step_avg:94.89ms +step:1239/1670 train_time:117569ms step_avg:94.89ms +step:1240/1670 train_time:117664ms step_avg:94.89ms +step:1241/1670 train_time:117759ms step_avg:94.89ms +step:1242/1670 train_time:117854ms step_avg:94.89ms +step:1243/1670 train_time:117949ms step_avg:94.89ms 
+step:1244/1670 train_time:118042ms step_avg:94.89ms +step:1245/1670 train_time:118137ms step_avg:94.89ms +step:1246/1670 train_time:118232ms step_avg:94.89ms +step:1247/1670 train_time:118326ms step_avg:94.89ms +step:1248/1670 train_time:118420ms step_avg:94.89ms +step:1249/1670 train_time:118514ms step_avg:94.89ms +step:1250/1670 train_time:118609ms step_avg:94.89ms +step:1250/1670 val_loss:3.3756 train_time:118702ms step_avg:94.96ms +step:1251/1670 train_time:118729ms step_avg:94.91ms +step:1252/1670 train_time:118803ms step_avg:94.89ms +step:1253/1670 train_time:118902ms step_avg:94.89ms +step:1254/1670 train_time:118996ms step_avg:94.89ms +step:1255/1670 train_time:119089ms step_avg:94.89ms +step:1256/1670 train_time:119183ms step_avg:94.89ms +step:1257/1670 train_time:119276ms step_avg:94.89ms +step:1258/1670 train_time:119370ms step_avg:94.89ms +step:1259/1670 train_time:119463ms step_avg:94.89ms +step:1260/1670 train_time:119556ms step_avg:94.89ms +step:1261/1670 train_time:119651ms step_avg:94.89ms +step:1262/1670 train_time:119747ms step_avg:94.89ms +step:1263/1670 train_time:119844ms step_avg:94.89ms +step:1264/1670 train_time:119941ms step_avg:94.89ms +step:1265/1670 train_time:120035ms step_avg:94.89ms +step:1266/1670 train_time:120129ms step_avg:94.89ms +step:1267/1670 train_time:120223ms step_avg:94.89ms +step:1268/1670 train_time:120316ms step_avg:94.89ms +step:1269/1670 train_time:120411ms step_avg:94.89ms +step:1270/1670 train_time:120504ms step_avg:94.88ms +step:1271/1670 train_time:120597ms step_avg:94.88ms +step:1272/1670 train_time:120693ms step_avg:94.88ms +step:1273/1670 train_time:120790ms step_avg:94.89ms +step:1274/1670 train_time:121237ms step_avg:95.16ms +step:1275/1670 train_time:121307ms step_avg:95.14ms +step:1276/1670 train_time:121399ms step_avg:95.14ms +step:1277/1670 train_time:121493ms step_avg:95.14ms +step:1278/1670 train_time:121586ms step_avg:95.14ms +step:1279/1670 train_time:121679ms step_avg:95.14ms +step:1280/1670 train_time:121773ms step_avg:95.14ms +step:1281/1670 train_time:121866ms step_avg:95.13ms +step:1282/1670 train_time:121959ms step_avg:95.13ms +step:1283/1670 train_time:122053ms step_avg:95.13ms +step:1284/1670 train_time:122151ms step_avg:95.13ms +step:1285/1670 train_time:122250ms step_avg:95.14ms +step:1286/1670 train_time:122347ms step_avg:95.14ms +step:1287/1670 train_time:122441ms step_avg:95.14ms +step:1288/1670 train_time:122535ms step_avg:95.14ms +step:1289/1670 train_time:122629ms step_avg:95.14ms +step:1290/1670 train_time:122723ms step_avg:95.13ms +step:1291/1670 train_time:122816ms step_avg:95.13ms +step:1292/1670 train_time:122911ms step_avg:95.13ms +step:1293/1670 train_time:123005ms step_avg:95.13ms +step:1294/1670 train_time:123101ms step_avg:95.13ms +step:1295/1670 train_time:123197ms step_avg:95.13ms +step:1296/1670 train_time:123292ms step_avg:95.13ms +step:1297/1670 train_time:123388ms step_avg:95.13ms +step:1298/1670 train_time:123482ms step_avg:95.13ms +step:1299/1670 train_time:123577ms step_avg:95.13ms +step:1300/1670 train_time:123670ms step_avg:95.13ms +step:1301/1670 train_time:123766ms step_avg:95.13ms +step:1302/1670 train_time:123859ms step_avg:95.13ms +step:1303/1670 train_time:123953ms step_avg:95.13ms +step:1304/1670 train_time:124047ms step_avg:95.13ms +step:1305/1670 train_time:124142ms step_avg:95.13ms +step:1306/1670 train_time:124237ms step_avg:95.13ms +step:1307/1670 train_time:124333ms step_avg:95.13ms +step:1308/1670 train_time:124428ms step_avg:95.13ms +step:1309/1670 train_time:124522ms 
step_avg:95.13ms +step:1310/1670 train_time:124616ms step_avg:95.13ms +step:1311/1670 train_time:124711ms step_avg:95.13ms +step:1312/1670 train_time:124805ms step_avg:95.13ms +step:1313/1670 train_time:124899ms step_avg:95.12ms +step:1314/1670 train_time:124994ms step_avg:95.12ms +step:1315/1670 train_time:125088ms step_avg:95.12ms +step:1316/1670 train_time:125183ms step_avg:95.12ms +step:1317/1670 train_time:125277ms step_avg:95.12ms +step:1318/1670 train_time:125371ms step_avg:95.12ms +step:1319/1670 train_time:125467ms step_avg:95.12ms +step:1320/1670 train_time:125562ms step_avg:95.12ms +step:1321/1670 train_time:125656ms step_avg:95.12ms +step:1322/1670 train_time:125750ms step_avg:95.12ms +step:1323/1670 train_time:125845ms step_avg:95.12ms +step:1324/1670 train_time:125937ms step_avg:95.12ms +step:1325/1670 train_time:126031ms step_avg:95.12ms +step:1326/1670 train_time:126127ms step_avg:95.12ms +step:1327/1670 train_time:126221ms step_avg:95.12ms +step:1328/1670 train_time:126316ms step_avg:95.12ms +step:1329/1670 train_time:126411ms step_avg:95.12ms +step:1330/1670 train_time:126506ms step_avg:95.12ms +step:1331/1670 train_time:126601ms step_avg:95.12ms +step:1332/1670 train_time:126694ms step_avg:95.12ms +step:1333/1670 train_time:126789ms step_avg:95.12ms +step:1334/1670 train_time:126882ms step_avg:95.11ms +step:1335/1670 train_time:126976ms step_avg:95.11ms +step:1336/1670 train_time:127070ms step_avg:95.11ms +step:1337/1670 train_time:127166ms step_avg:95.11ms +step:1338/1670 train_time:127261ms step_avg:95.11ms +step:1339/1670 train_time:127354ms step_avg:95.11ms +step:1340/1670 train_time:127450ms step_avg:95.11ms +step:1341/1670 train_time:127545ms step_avg:95.11ms +step:1342/1670 train_time:127640ms step_avg:95.11ms +step:1343/1670 train_time:127734ms step_avg:95.11ms +step:1344/1670 train_time:127829ms step_avg:95.11ms +step:1345/1670 train_time:127923ms step_avg:95.11ms +step:1346/1670 train_time:128017ms step_avg:95.11ms +step:1347/1670 train_time:128111ms step_avg:95.11ms +step:1348/1670 train_time:128205ms step_avg:95.11ms +step:1349/1670 train_time:128299ms step_avg:95.11ms +step:1350/1670 train_time:128394ms step_avg:95.11ms +step:1351/1670 train_time:128489ms step_avg:95.11ms +step:1352/1670 train_time:128584ms step_avg:95.11ms +step:1353/1670 train_time:128678ms step_avg:95.11ms +step:1354/1670 train_time:128772ms step_avg:95.11ms +step:1355/1670 train_time:128866ms step_avg:95.10ms +step:1356/1670 train_time:128961ms step_avg:95.10ms +step:1357/1670 train_time:129056ms step_avg:95.10ms +step:1358/1670 train_time:129150ms step_avg:95.10ms +step:1359/1670 train_time:129246ms step_avg:95.10ms +step:1360/1670 train_time:129341ms step_avg:95.10ms +step:1361/1670 train_time:129434ms step_avg:95.10ms +step:1362/1670 train_time:129530ms step_avg:95.10ms +step:1363/1670 train_time:129624ms step_avg:95.10ms +step:1364/1670 train_time:129717ms step_avg:95.10ms +step:1365/1670 train_time:129812ms step_avg:95.10ms +step:1366/1670 train_time:129906ms step_avg:95.10ms +step:1367/1670 train_time:130001ms step_avg:95.10ms +step:1368/1670 train_time:130096ms step_avg:95.10ms +step:1369/1670 train_time:130190ms step_avg:95.10ms +step:1370/1670 train_time:130286ms step_avg:95.10ms +step:1371/1670 train_time:130381ms step_avg:95.10ms +step:1372/1670 train_time:130474ms step_avg:95.10ms +step:1373/1670 train_time:130570ms step_avg:95.10ms +step:1374/1670 train_time:130664ms step_avg:95.10ms +step:1375/1670 train_time:130759ms step_avg:95.10ms +step:1375/1670 val_loss:3.3417 
train_time:130851ms step_avg:95.16ms +step:1376/1670 train_time:130878ms step_avg:95.11ms +step:1377/1670 train_time:130953ms step_avg:95.10ms +step:1378/1670 train_time:131054ms step_avg:95.10ms +step:1379/1670 train_time:131149ms step_avg:95.10ms +step:1380/1670 train_time:131243ms step_avg:95.10ms +step:1381/1670 train_time:131337ms step_avg:95.10ms +step:1382/1670 train_time:131430ms step_avg:95.10ms +step:1383/1670 train_time:131523ms step_avg:95.10ms +step:1384/1670 train_time:131617ms step_avg:95.10ms +step:1385/1670 train_time:131710ms step_avg:95.10ms +step:1386/1670 train_time:131804ms step_avg:95.10ms +step:1387/1670 train_time:131901ms step_avg:95.10ms +step:1388/1670 train_time:131999ms step_avg:95.10ms +step:1389/1670 train_time:132097ms step_avg:95.10ms +step:1390/1670 train_time:132192ms step_avg:95.10ms +step:1391/1670 train_time:132286ms step_avg:95.10ms +step:1392/1670 train_time:132380ms step_avg:95.10ms +step:1393/1670 train_time:132473ms step_avg:95.10ms +step:1394/1670 train_time:132566ms step_avg:95.10ms +step:1395/1670 train_time:132660ms step_avg:95.10ms +step:1396/1670 train_time:132753ms step_avg:95.10ms +step:1397/1670 train_time:132847ms step_avg:95.09ms +step:1398/1670 train_time:132944ms step_avg:95.10ms +step:1399/1670 train_time:133040ms step_avg:95.10ms +step:1400/1670 train_time:133137ms step_avg:95.10ms +step:1401/1670 train_time:133231ms step_avg:95.10ms +step:1402/1670 train_time:133325ms step_avg:95.10ms +step:1403/1670 train_time:133420ms step_avg:95.10ms +step:1404/1670 train_time:133514ms step_avg:95.10ms +step:1405/1670 train_time:133607ms step_avg:95.09ms +step:1406/1670 train_time:133701ms step_avg:95.09ms +step:1407/1670 train_time:133796ms step_avg:95.09ms +step:1408/1670 train_time:133890ms step_avg:95.09ms +step:1409/1670 train_time:133985ms step_avg:95.09ms +step:1410/1670 train_time:134080ms step_avg:95.09ms +step:1411/1670 train_time:134176ms step_avg:95.09ms +step:1412/1670 train_time:134271ms step_avg:95.09ms +step:1413/1670 train_time:134365ms step_avg:95.09ms +step:1414/1670 train_time:134460ms step_avg:95.09ms +step:1415/1670 train_time:134554ms step_avg:95.09ms +step:1416/1670 train_time:134648ms step_avg:95.09ms +step:1417/1670 train_time:134741ms step_avg:95.09ms +step:1418/1670 train_time:134836ms step_avg:95.09ms +step:1419/1670 train_time:134930ms step_avg:95.09ms +step:1420/1670 train_time:135024ms step_avg:95.09ms +step:1421/1670 train_time:135120ms step_avg:95.09ms +step:1422/1670 train_time:135215ms step_avg:95.09ms +step:1423/1670 train_time:135309ms step_avg:95.09ms +step:1424/1670 train_time:135403ms step_avg:95.09ms +step:1425/1670 train_time:135499ms step_avg:95.09ms +step:1426/1670 train_time:135594ms step_avg:95.09ms +step:1427/1670 train_time:135688ms step_avg:95.09ms +step:1428/1670 train_time:135781ms step_avg:95.08ms +step:1429/1670 train_time:135876ms step_avg:95.08ms +step:1430/1670 train_time:135970ms step_avg:95.08ms +step:1431/1670 train_time:136064ms step_avg:95.08ms +step:1432/1670 train_time:136159ms step_avg:95.08ms +step:1433/1670 train_time:136254ms step_avg:95.08ms +step:1434/1670 train_time:136348ms step_avg:95.08ms +step:1435/1670 train_time:136443ms step_avg:95.08ms +step:1436/1670 train_time:136538ms step_avg:95.08ms +step:1437/1670 train_time:136632ms step_avg:95.08ms +step:1438/1670 train_time:136726ms step_avg:95.08ms +step:1439/1670 train_time:136821ms step_avg:95.08ms +step:1440/1670 train_time:136916ms step_avg:95.08ms +step:1441/1670 train_time:137011ms step_avg:95.08ms +step:1442/1670 
train_time:137105ms step_avg:95.08ms +step:1443/1670 train_time:137200ms step_avg:95.08ms +step:1444/1670 train_time:137295ms step_avg:95.08ms +step:1445/1670 train_time:137390ms step_avg:95.08ms +step:1446/1670 train_time:137484ms step_avg:95.08ms +step:1447/1670 train_time:137579ms step_avg:95.08ms +step:1448/1670 train_time:137673ms step_avg:95.08ms +step:1449/1670 train_time:137767ms step_avg:95.08ms +step:1450/1670 train_time:137862ms step_avg:95.08ms +step:1451/1670 train_time:137957ms step_avg:95.08ms +step:1452/1670 train_time:138052ms step_avg:95.08ms +step:1453/1670 train_time:138146ms step_avg:95.08ms +step:1454/1670 train_time:138241ms step_avg:95.08ms +step:1455/1670 train_time:138336ms step_avg:95.08ms +step:1456/1670 train_time:138431ms step_avg:95.08ms +step:1457/1670 train_time:138525ms step_avg:95.08ms +step:1458/1670 train_time:138620ms step_avg:95.08ms +step:1459/1670 train_time:138715ms step_avg:95.08ms +step:1460/1670 train_time:138810ms step_avg:95.08ms +step:1461/1670 train_time:138904ms step_avg:95.07ms +step:1462/1670 train_time:138998ms step_avg:95.07ms +step:1463/1670 train_time:139093ms step_avg:95.07ms +step:1464/1670 train_time:139188ms step_avg:95.07ms +step:1465/1670 train_time:139282ms step_avg:95.07ms +step:1466/1670 train_time:139378ms step_avg:95.07ms +step:1467/1670 train_time:139473ms step_avg:95.07ms +step:1468/1670 train_time:139567ms step_avg:95.07ms +step:1469/1670 train_time:139661ms step_avg:95.07ms +step:1470/1670 train_time:139755ms step_avg:95.07ms +step:1471/1670 train_time:139849ms step_avg:95.07ms +step:1472/1670 train_time:139943ms step_avg:95.07ms +step:1473/1670 train_time:140038ms step_avg:95.07ms +step:1474/1670 train_time:140133ms step_avg:95.07ms +step:1475/1670 train_time:140227ms step_avg:95.07ms +step:1476/1670 train_time:140322ms step_avg:95.07ms +step:1477/1670 train_time:140418ms step_avg:95.07ms +step:1478/1670 train_time:140513ms step_avg:95.07ms +step:1479/1670 train_time:140606ms step_avg:95.07ms +step:1480/1670 train_time:140701ms step_avg:95.07ms +step:1481/1670 train_time:140796ms step_avg:95.07ms +step:1482/1670 train_time:140890ms step_avg:95.07ms +step:1483/1670 train_time:140985ms step_avg:95.07ms +step:1484/1670 train_time:141079ms step_avg:95.07ms +step:1485/1670 train_time:141431ms step_avg:95.24ms +step:1486/1670 train_time:141593ms step_avg:95.28ms +step:1487/1670 train_time:141685ms step_avg:95.28ms +step:1488/1670 train_time:141778ms step_avg:95.28ms +step:1489/1670 train_time:141871ms step_avg:95.28ms +step:1490/1670 train_time:141964ms step_avg:95.28ms +step:1491/1670 train_time:142058ms step_avg:95.28ms +step:1492/1670 train_time:142151ms step_avg:95.28ms +step:1493/1670 train_time:142245ms step_avg:95.27ms +step:1494/1670 train_time:142338ms step_avg:95.27ms +step:1495/1670 train_time:142433ms step_avg:95.27ms +step:1496/1670 train_time:142532ms step_avg:95.28ms +step:1497/1670 train_time:142629ms step_avg:95.28ms +step:1498/1670 train_time:142725ms step_avg:95.28ms +step:1499/1670 train_time:142819ms step_avg:95.28ms +step:1500/1670 train_time:142912ms step_avg:95.27ms +step:1500/1670 val_loss:3.3122 train_time:143004ms step_avg:95.34ms +step:1501/1670 train_time:143032ms step_avg:95.29ms +step:1502/1670 train_time:143109ms step_avg:95.28ms +step:1503/1670 train_time:143209ms step_avg:95.28ms +step:1504/1670 train_time:143306ms step_avg:95.28ms +step:1505/1670 train_time:143400ms step_avg:95.28ms +step:1506/1670 train_time:143493ms step_avg:95.28ms +step:1507/1670 train_time:143586ms step_avg:95.28ms 
+step:1508/1670 train_time:143680ms step_avg:95.28ms +step:1509/1670 train_time:143773ms step_avg:95.28ms +step:1510/1670 train_time:143866ms step_avg:95.28ms +step:1511/1670 train_time:143961ms step_avg:95.28ms +step:1512/1670 train_time:144057ms step_avg:95.28ms +step:1513/1670 train_time:144153ms step_avg:95.28ms +step:1514/1670 train_time:144249ms step_avg:95.28ms +step:1515/1670 train_time:144345ms step_avg:95.28ms +step:1516/1670 train_time:144439ms step_avg:95.28ms +step:1517/1670 train_time:144533ms step_avg:95.28ms +step:1518/1670 train_time:144627ms step_avg:95.27ms +step:1519/1670 train_time:144720ms step_avg:95.27ms +step:1520/1670 train_time:144813ms step_avg:95.27ms +step:1521/1670 train_time:144908ms step_avg:95.27ms +step:1522/1670 train_time:145004ms step_avg:95.27ms +step:1523/1670 train_time:145100ms step_avg:95.27ms +step:1524/1670 train_time:145197ms step_avg:95.27ms +step:1525/1670 train_time:145292ms step_avg:95.27ms +step:1526/1670 train_time:145386ms step_avg:95.27ms +step:1527/1670 train_time:145480ms step_avg:95.27ms +step:1528/1670 train_time:145574ms step_avg:95.27ms +step:1529/1670 train_time:145669ms step_avg:95.27ms +step:1530/1670 train_time:145763ms step_avg:95.27ms +step:1531/1670 train_time:145856ms step_avg:95.27ms +step:1532/1670 train_time:145950ms step_avg:95.27ms +step:1533/1670 train_time:146045ms step_avg:95.27ms +step:1534/1670 train_time:146140ms step_avg:95.27ms +step:1535/1670 train_time:146235ms step_avg:95.27ms +step:1536/1670 train_time:146329ms step_avg:95.27ms +step:1537/1670 train_time:146425ms step_avg:95.27ms +step:1538/1670 train_time:146519ms step_avg:95.27ms +step:1539/1670 train_time:146613ms step_avg:95.26ms +step:1540/1670 train_time:146707ms step_avg:95.26ms +step:1541/1670 train_time:146801ms step_avg:95.26ms +step:1542/1670 train_time:146895ms step_avg:95.26ms +step:1543/1670 train_time:146990ms step_avg:95.26ms +step:1544/1670 train_time:147085ms step_avg:95.26ms +step:1545/1670 train_time:147181ms step_avg:95.26ms +step:1546/1670 train_time:147275ms step_avg:95.26ms +step:1547/1670 train_time:147370ms step_avg:95.26ms +step:1548/1670 train_time:147465ms step_avg:95.26ms +step:1549/1670 train_time:147559ms step_avg:95.26ms +step:1550/1670 train_time:147652ms step_avg:95.26ms +step:1551/1670 train_time:147747ms step_avg:95.26ms +step:1552/1670 train_time:147842ms step_avg:95.26ms +step:1553/1670 train_time:147936ms step_avg:95.26ms +step:1554/1670 train_time:148030ms step_avg:95.26ms +step:1555/1670 train_time:148125ms step_avg:95.26ms +step:1556/1670 train_time:148220ms step_avg:95.26ms +step:1557/1670 train_time:148316ms step_avg:95.26ms +step:1558/1670 train_time:148411ms step_avg:95.26ms +step:1559/1670 train_time:148506ms step_avg:95.26ms +step:1560/1670 train_time:148601ms step_avg:95.26ms +step:1561/1670 train_time:148695ms step_avg:95.26ms +step:1562/1670 train_time:148789ms step_avg:95.26ms +step:1563/1670 train_time:148883ms step_avg:95.25ms +step:1564/1670 train_time:148977ms step_avg:95.25ms +step:1565/1670 train_time:149071ms step_avg:95.25ms +step:1566/1670 train_time:149167ms step_avg:95.25ms +step:1567/1670 train_time:149262ms step_avg:95.25ms +step:1568/1670 train_time:149357ms step_avg:95.25ms +step:1569/1670 train_time:149452ms step_avg:95.25ms +step:1570/1670 train_time:149547ms step_avg:95.25ms +step:1571/1670 train_time:149641ms step_avg:95.25ms +step:1572/1670 train_time:149735ms step_avg:95.25ms +step:1573/1670 train_time:149829ms step_avg:95.25ms +step:1574/1670 train_time:149924ms step_avg:95.25ms 
+step:1575/1670 train_time:150019ms step_avg:95.25ms +step:1576/1670 train_time:150112ms step_avg:95.25ms +step:1577/1670 train_time:150207ms step_avg:95.25ms +step:1578/1670 train_time:150303ms step_avg:95.25ms +step:1579/1670 train_time:150397ms step_avg:95.25ms +step:1580/1670 train_time:150491ms step_avg:95.25ms +step:1581/1670 train_time:150586ms step_avg:95.25ms +step:1582/1670 train_time:150681ms step_avg:95.25ms +step:1583/1670 train_time:150775ms step_avg:95.25ms +step:1584/1670 train_time:150869ms step_avg:95.25ms +step:1585/1670 train_time:150964ms step_avg:95.25ms +step:1586/1670 train_time:151058ms step_avg:95.24ms +step:1587/1670 train_time:151152ms step_avg:95.24ms +step:1588/1670 train_time:151247ms step_avg:95.24ms +step:1589/1670 train_time:151343ms step_avg:95.24ms +step:1590/1670 train_time:151438ms step_avg:95.24ms +step:1591/1670 train_time:151532ms step_avg:95.24ms +step:1592/1670 train_time:151626ms step_avg:95.24ms +step:1593/1670 train_time:151721ms step_avg:95.24ms +step:1594/1670 train_time:151816ms step_avg:95.24ms +step:1595/1670 train_time:151909ms step_avg:95.24ms +step:1596/1670 train_time:152004ms step_avg:95.24ms +step:1597/1670 train_time:152098ms step_avg:95.24ms +step:1598/1670 train_time:152192ms step_avg:95.24ms +step:1599/1670 train_time:152287ms step_avg:95.24ms +step:1600/1670 train_time:152382ms step_avg:95.24ms +step:1601/1670 train_time:152477ms step_avg:95.24ms +step:1602/1670 train_time:152571ms step_avg:95.24ms +step:1603/1670 train_time:152666ms step_avg:95.24ms +step:1604/1670 train_time:152761ms step_avg:95.24ms +step:1605/1670 train_time:152855ms step_avg:95.24ms +step:1606/1670 train_time:152949ms step_avg:95.24ms +step:1607/1670 train_time:153044ms step_avg:95.24ms +step:1608/1670 train_time:153138ms step_avg:95.24ms +step:1609/1670 train_time:153233ms step_avg:95.23ms +step:1610/1670 train_time:153327ms step_avg:95.23ms +step:1611/1670 train_time:153422ms step_avg:95.23ms +step:1612/1670 train_time:153517ms step_avg:95.23ms +step:1613/1670 train_time:153612ms step_avg:95.23ms +step:1614/1670 train_time:153707ms step_avg:95.23ms +step:1615/1670 train_time:153801ms step_avg:95.23ms +step:1616/1670 train_time:153896ms step_avg:95.23ms +step:1617/1670 train_time:153990ms step_avg:95.23ms +step:1618/1670 train_time:154085ms step_avg:95.23ms +step:1619/1670 train_time:154179ms step_avg:95.23ms +step:1620/1670 train_time:154273ms step_avg:95.23ms +step:1621/1670 train_time:154370ms step_avg:95.23ms +step:1622/1670 train_time:154465ms step_avg:95.23ms +step:1623/1670 train_time:154560ms step_avg:95.23ms +step:1624/1670 train_time:154654ms step_avg:95.23ms +step:1625/1670 train_time:154750ms step_avg:95.23ms +step:1625/1670 val_loss:3.2867 train_time:154842ms step_avg:95.29ms +step:1626/1670 train_time:154869ms step_avg:95.25ms +step:1627/1670 train_time:154946ms step_avg:95.23ms +step:1628/1670 train_time:155047ms step_avg:95.24ms +step:1629/1670 train_time:155143ms step_avg:95.24ms +step:1630/1670 train_time:155237ms step_avg:95.24ms +step:1631/1670 train_time:155330ms step_avg:95.24ms +step:1632/1670 train_time:155423ms step_avg:95.23ms +step:1633/1670 train_time:155517ms step_avg:95.23ms +step:1634/1670 train_time:155610ms step_avg:95.23ms +step:1635/1670 train_time:155704ms step_avg:95.23ms +step:1636/1670 train_time:155798ms step_avg:95.23ms +step:1637/1670 train_time:155893ms step_avg:95.23ms +step:1638/1670 train_time:155990ms step_avg:95.23ms +step:1639/1670 train_time:156087ms step_avg:95.23ms +step:1640/1670 train_time:156184ms 
step_avg:95.23ms +step:1641/1670 train_time:156277ms step_avg:95.23ms +step:1642/1670 train_time:156371ms step_avg:95.23ms +step:1643/1670 train_time:156466ms step_avg:95.23ms +step:1644/1670 train_time:156559ms step_avg:95.23ms +step:1645/1670 train_time:156653ms step_avg:95.23ms +step:1646/1670 train_time:156747ms step_avg:95.23ms +step:1647/1670 train_time:156841ms step_avg:95.23ms +step:1648/1670 train_time:156938ms step_avg:95.23ms +step:1649/1670 train_time:157036ms step_avg:95.23ms +step:1650/1670 train_time:157132ms step_avg:95.23ms +step:1651/1670 train_time:157227ms step_avg:95.23ms +step:1652/1670 train_time:157321ms step_avg:95.23ms +step:1653/1670 train_time:157415ms step_avg:95.23ms +step:1654/1670 train_time:157509ms step_avg:95.23ms +step:1655/1670 train_time:157603ms step_avg:95.23ms +step:1656/1670 train_time:157697ms step_avg:95.23ms +step:1657/1670 train_time:157790ms step_avg:95.23ms +step:1658/1670 train_time:157884ms step_avg:95.23ms +step:1659/1670 train_time:157981ms step_avg:95.23ms +step:1660/1670 train_time:158077ms step_avg:95.23ms +step:1661/1670 train_time:158173ms step_avg:95.23ms +step:1662/1670 train_time:158266ms step_avg:95.23ms +step:1663/1670 train_time:158361ms step_avg:95.23ms +step:1664/1670 train_time:158456ms step_avg:95.23ms +step:1665/1670 train_time:158550ms step_avg:95.23ms +step:1666/1670 train_time:158644ms step_avg:95.22ms +step:1667/1670 train_time:158737ms step_avg:95.22ms +step:1668/1670 train_time:158832ms step_avg:95.22ms +step:1669/1670 train_time:158926ms step_avg:95.22ms +step:1670/1670 train_time:159021ms step_avg:95.22ms +step:1670/1670 val_loss:3.2778 train_time:159191ms step_avg:95.32ms +peak memory allocated: 32470 MiB reserved: 47756 MiB diff --git a/records/091025_Yarn/9121a353-d3ce-4f54-98de-0b466773fe0b.txt b/records/091025_Yarn/9121a353-d3ce-4f54-98de-0b466773fe0b.txt new file mode 100644 index 000000000..6639de4c0 --- /dev/null +++ b/records/091025_Yarn/9121a353-d3ce-4f54-98de-0b466773fe0b.txt @@ -0,0 +1,2863 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +import math + +from dataclasses import dataclass +from functools import lru_cache +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + 
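# scale_a/scale_b are dequantization factors: x and w were divided by x_s and w_s before the fp8 cast, and _scaled_mm scales the product back up, so the pair round-trips the scaling.
+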
scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, 
c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = 
tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / 
(X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
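+        # Sketch of the update applied below to each parameter this rank owns
+        # (buf = momentum buffer; rows/cols = p.size(-2)/p.size(-1); per-param lr_mul/wd_mul omitted):
+        #   buf <- momentum * buf + (1 - momentum) * grad
+        #   g   <- (1 - momentum) * grad + momentum * buf   # Nesterov-style blend
+        #   v   <- NewtonSchulz5(g)                         # approx. nearest orthogonal matrix
+        #   p   <- (1 - lr * wd) * p - lr * sqrt(max(1, rows/cols)) * v
+        # Gradients are averaged via reduce_scatter; each rank updates its round-robin
+        # share of each same-shape group, then all_gather broadcasts the new weights.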
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, 
sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + 
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws] + ) + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + 
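+# Illustrative batch construction (toy values): with bos_idx = [0, 3, 5] in a 9-token
+# shard, num_tokens_local=6 and max_seq_len=4, next_batch yields spans [0,3), [3,5),
+# [5,7): each span starts at a BOS, is truncated at the next BOS / max_seq_len, and the
+# spans sum to num_tokens_local + 1 tokens, the extra one covering the input->target
+# shift performed by the consumer below.
+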
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"yarn/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws=new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 04:09:29 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 37C P0 119W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 41C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 43C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 35C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 43C P0 127W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 41C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 37C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 66339 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 66340 C /usr/bin/python3 614MiB | +| 0 N/A N/A 66341 C /usr/bin/python3 614MiB | +| 0 N/A N/A 66342 C /usr/bin/python3 614MiB | +| 0 N/A N/A 66343 C /usr/bin/python3 614MiB | +| 0 N/A N/A 66344 C /usr/bin/python3 614MiB | +| 0 N/A N/A 66345 C /usr/bin/python3 614MiB | +| 0 N/A N/A 66346 C /usr/bin/python3 614MiB | +| 1 N/A N/A 66340 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 66341 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 66342 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 66343 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 66344 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 66345 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 66346 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1670 train_time:456ms step_avg:455.70ms +step:2/1670 train_time:481ms step_avg:240.37ms +step:3/1670 train_time:548ms step_avg:182.61ms +step:4/1670 train_time:638ms step_avg:159.58ms +step:5/1670 train_time:730ms step_avg:145.93ms +step:6/1670 train_time:821ms step_avg:136.92ms +step:7/1670 train_time:914ms step_avg:130.50ms 
+step:8/1670 train_time:1005ms step_avg:125.68ms +step:9/1670 train_time:1098ms step_avg:121.98ms +step:10/1670 train_time:1189ms step_avg:118.92ms +step:11/1670 train_time:1281ms step_avg:116.46ms +step:12/1670 train_time:1373ms step_avg:114.45ms +step:13/1670 train_time:1470ms step_avg:113.05ms +step:14/1670 train_time:1565ms step_avg:111.79ms +step:15/1670 train_time:1658ms step_avg:110.53ms +step:16/1670 train_time:1750ms step_avg:109.37ms +step:17/1670 train_time:1842ms step_avg:108.37ms +step:18/1670 train_time:1935ms step_avg:107.50ms +step:19/1670 train_time:2028ms step_avg:106.72ms +step:20/1670 train_time:2119ms step_avg:105.97ms +step:21/1670 train_time:2211ms step_avg:105.30ms +step:22/1670 train_time:2303ms step_avg:104.68ms +step:23/1670 train_time:2396ms step_avg:104.18ms +step:24/1670 train_time:2489ms step_avg:103.70ms +step:25/1670 train_time:2583ms step_avg:103.30ms +step:26/1670 train_time:2676ms step_avg:102.91ms +step:27/1670 train_time:2768ms step_avg:102.53ms +step:28/1670 train_time:2862ms step_avg:102.20ms +step:29/1670 train_time:2955ms step_avg:101.89ms +step:30/1670 train_time:3047ms step_avg:101.57ms +step:31/1670 train_time:3141ms step_avg:101.31ms +step:32/1670 train_time:3233ms step_avg:101.03ms +step:33/1670 train_time:3326ms step_avg:100.78ms +step:34/1670 train_time:3419ms step_avg:100.56ms +step:35/1670 train_time:3512ms step_avg:100.35ms +step:36/1670 train_time:3605ms step_avg:100.14ms +step:37/1670 train_time:3698ms step_avg:99.94ms +step:38/1670 train_time:3790ms step_avg:99.74ms +step:39/1670 train_time:3883ms step_avg:99.57ms +step:40/1670 train_time:3976ms step_avg:99.39ms +step:41/1670 train_time:4068ms step_avg:99.22ms +step:42/1670 train_time:4161ms step_avg:99.06ms +step:43/1670 train_time:4253ms step_avg:98.90ms +step:44/1670 train_time:4345ms step_avg:98.74ms +step:45/1670 train_time:4438ms step_avg:98.62ms +step:46/1670 train_time:4530ms step_avg:98.49ms +step:47/1670 train_time:4623ms step_avg:98.36ms +step:48/1670 train_time:4716ms step_avg:98.25ms +step:49/1670 train_time:4808ms step_avg:98.13ms +step:50/1670 train_time:4902ms step_avg:98.05ms +step:51/1670 train_time:4995ms step_avg:97.94ms +step:52/1670 train_time:5087ms step_avg:97.84ms +step:53/1670 train_time:5181ms step_avg:97.75ms +step:54/1670 train_time:5273ms step_avg:97.65ms +step:55/1670 train_time:5366ms step_avg:97.56ms +step:56/1670 train_time:5459ms step_avg:97.48ms +step:57/1670 train_time:5552ms step_avg:97.40ms +step:58/1670 train_time:5644ms step_avg:97.31ms +step:59/1670 train_time:5737ms step_avg:97.24ms +step:60/1670 train_time:5830ms step_avg:97.17ms +step:61/1670 train_time:5922ms step_avg:97.09ms +step:62/1670 train_time:6015ms step_avg:97.02ms +step:63/1670 train_time:6108ms step_avg:96.95ms +step:64/1670 train_time:6200ms step_avg:96.87ms +step:65/1670 train_time:6292ms step_avg:96.80ms +step:66/1670 train_time:6384ms step_avg:96.73ms +step:67/1670 train_time:6477ms step_avg:96.67ms +step:68/1670 train_time:6570ms step_avg:96.61ms +step:69/1670 train_time:6662ms step_avg:96.55ms +step:70/1670 train_time:6755ms step_avg:96.50ms +step:71/1670 train_time:6847ms step_avg:96.44ms +step:72/1670 train_time:6940ms step_avg:96.39ms +step:73/1670 train_time:7034ms step_avg:96.36ms +step:74/1670 train_time:7127ms step_avg:96.32ms +step:75/1670 train_time:7219ms step_avg:96.26ms +step:76/1670 train_time:7312ms step_avg:96.21ms +step:77/1670 train_time:7404ms step_avg:96.16ms +step:78/1670 train_time:7497ms step_avg:96.11ms +step:79/1670 train_time:7589ms step_avg:96.06ms 
+step:80/1670 train_time:7682ms step_avg:96.02ms +step:81/1670 train_time:7774ms step_avg:95.98ms +step:82/1670 train_time:7867ms step_avg:95.94ms +step:83/1670 train_time:7960ms step_avg:95.90ms +step:84/1670 train_time:8052ms step_avg:95.86ms +step:85/1670 train_time:8144ms step_avg:95.82ms +step:86/1670 train_time:8237ms step_avg:95.78ms +step:87/1670 train_time:8330ms step_avg:95.74ms +step:88/1670 train_time:8422ms step_avg:95.70ms +step:89/1670 train_time:8515ms step_avg:95.67ms +step:90/1670 train_time:8608ms step_avg:95.65ms +step:91/1670 train_time:8700ms step_avg:95.60ms +step:92/1670 train_time:8793ms step_avg:95.57ms +step:93/1670 train_time:8885ms step_avg:95.53ms +step:94/1670 train_time:8977ms step_avg:95.50ms +step:95/1670 train_time:9070ms step_avg:95.47ms +step:96/1670 train_time:9163ms step_avg:95.45ms +step:97/1670 train_time:9256ms step_avg:95.42ms +step:98/1670 train_time:9348ms step_avg:95.38ms +step:99/1670 train_time:9440ms step_avg:95.36ms +step:100/1670 train_time:9533ms step_avg:95.33ms +step:101/1670 train_time:9625ms step_avg:95.30ms +step:102/1670 train_time:9717ms step_avg:95.26ms +step:103/1670 train_time:9811ms step_avg:95.25ms +step:104/1670 train_time:9903ms step_avg:95.22ms +step:105/1670 train_time:9995ms step_avg:95.19ms +step:106/1670 train_time:10088ms step_avg:95.17ms +step:107/1670 train_time:10180ms step_avg:95.14ms +step:108/1670 train_time:10272ms step_avg:95.11ms +step:109/1670 train_time:10366ms step_avg:95.10ms +step:110/1670 train_time:10458ms step_avg:95.07ms +step:111/1670 train_time:10550ms step_avg:95.05ms +step:112/1670 train_time:10643ms step_avg:95.02ms +step:113/1670 train_time:10735ms step_avg:95.00ms +step:114/1670 train_time:10828ms step_avg:94.98ms +step:115/1670 train_time:10919ms step_avg:94.95ms +step:116/1670 train_time:11013ms step_avg:94.94ms +step:117/1670 train_time:11104ms step_avg:94.91ms +step:118/1670 train_time:11196ms step_avg:94.88ms +step:119/1670 train_time:11289ms step_avg:94.87ms +step:120/1670 train_time:11382ms step_avg:94.85ms +step:121/1670 train_time:11474ms step_avg:94.82ms +step:122/1670 train_time:11567ms step_avg:94.81ms +step:123/1670 train_time:11659ms step_avg:94.79ms +step:124/1670 train_time:11751ms step_avg:94.77ms +step:125/1670 train_time:11844ms step_avg:94.75ms +step:125/1670 val_loss:4.3045 train_time:11935ms step_avg:95.48ms +step:126/1670 train_time:11963ms step_avg:94.94ms +step:127/1670 train_time:12034ms step_avg:94.76ms +step:128/1670 train_time:12139ms step_avg:94.84ms +step:129/1670 train_time:12235ms step_avg:94.85ms +step:130/1670 train_time:12328ms step_avg:94.83ms +step:131/1670 train_time:12420ms step_avg:94.81ms +step:132/1670 train_time:12512ms step_avg:94.79ms +step:133/1670 train_time:12603ms step_avg:94.76ms +step:134/1670 train_time:12694ms step_avg:94.73ms +step:135/1670 train_time:12786ms step_avg:94.71ms +step:136/1670 train_time:12877ms step_avg:94.69ms +step:137/1670 train_time:12969ms step_avg:94.66ms +step:138/1670 train_time:13063ms step_avg:94.66ms +step:139/1670 train_time:13158ms step_avg:94.66ms +step:140/1670 train_time:13252ms step_avg:94.65ms +step:141/1670 train_time:13345ms step_avg:94.65ms +step:142/1670 train_time:13438ms step_avg:94.63ms +step:143/1670 train_time:13530ms step_avg:94.62ms +step:144/1670 train_time:13621ms step_avg:94.59ms +step:145/1670 train_time:13713ms step_avg:94.57ms +step:146/1670 train_time:13804ms step_avg:94.55ms +step:147/1670 train_time:13896ms step_avg:94.53ms +step:148/1670 train_time:13987ms step_avg:94.51ms +step:149/1670 
train_time:14080ms step_avg:94.50ms +step:150/1670 train_time:14173ms step_avg:94.49ms +step:151/1670 train_time:14266ms step_avg:94.48ms +step:152/1670 train_time:14360ms step_avg:94.47ms +step:153/1670 train_time:14452ms step_avg:94.46ms +step:154/1670 train_time:14545ms step_avg:94.45ms +step:155/1670 train_time:14637ms step_avg:94.43ms +step:156/1670 train_time:14728ms step_avg:94.41ms +step:157/1670 train_time:14821ms step_avg:94.40ms +step:158/1670 train_time:14913ms step_avg:94.39ms +step:159/1670 train_time:15007ms step_avg:94.38ms +step:160/1670 train_time:15099ms step_avg:94.37ms +step:161/1670 train_time:15192ms step_avg:94.36ms +step:162/1670 train_time:15285ms step_avg:94.35ms +step:163/1670 train_time:15379ms step_avg:94.35ms +step:164/1670 train_time:15472ms step_avg:94.34ms +step:165/1670 train_time:15565ms step_avg:94.33ms +step:166/1670 train_time:15657ms step_avg:94.32ms +step:167/1670 train_time:15750ms step_avg:94.31ms +step:168/1670 train_time:15841ms step_avg:94.29ms +step:169/1670 train_time:15933ms step_avg:94.28ms +step:170/1670 train_time:16026ms step_avg:94.27ms +step:171/1670 train_time:16118ms step_avg:94.26ms +step:172/1670 train_time:16211ms step_avg:94.25ms +step:173/1670 train_time:16304ms step_avg:94.24ms +step:174/1670 train_time:16397ms step_avg:94.24ms +step:175/1670 train_time:16489ms step_avg:94.22ms +step:176/1670 train_time:16581ms step_avg:94.21ms +step:177/1670 train_time:16674ms step_avg:94.20ms +step:178/1670 train_time:16766ms step_avg:94.19ms +step:179/1670 train_time:16859ms step_avg:94.19ms +step:180/1670 train_time:16951ms step_avg:94.17ms +step:181/1670 train_time:17043ms step_avg:94.16ms +step:182/1670 train_time:17136ms step_avg:94.15ms +step:183/1670 train_time:17228ms step_avg:94.14ms +step:184/1670 train_time:17321ms step_avg:94.13ms +step:185/1670 train_time:17413ms step_avg:94.12ms +step:186/1670 train_time:17506ms step_avg:94.12ms +step:187/1670 train_time:17598ms step_avg:94.10ms +step:188/1670 train_time:17690ms step_avg:94.09ms +step:189/1670 train_time:17782ms step_avg:94.09ms +step:190/1670 train_time:17875ms step_avg:94.08ms +step:191/1670 train_time:17967ms step_avg:94.07ms +step:192/1670 train_time:18060ms step_avg:94.06ms +step:193/1670 train_time:18153ms step_avg:94.06ms +step:194/1670 train_time:18245ms step_avg:94.05ms +step:195/1670 train_time:18337ms step_avg:94.04ms +step:196/1670 train_time:18430ms step_avg:94.03ms +step:197/1670 train_time:18523ms step_avg:94.02ms +step:198/1670 train_time:18615ms step_avg:94.02ms +step:199/1670 train_time:18708ms step_avg:94.01ms +step:200/1670 train_time:18800ms step_avg:94.00ms +step:201/1670 train_time:18892ms step_avg:93.99ms +step:202/1670 train_time:18985ms step_avg:93.98ms +step:203/1670 train_time:19077ms step_avg:93.98ms +step:204/1670 train_time:19169ms step_avg:93.97ms +step:205/1670 train_time:19262ms step_avg:93.96ms +step:206/1670 train_time:19354ms step_avg:93.95ms +step:207/1670 train_time:19447ms step_avg:93.95ms +step:208/1670 train_time:19539ms step_avg:93.94ms +step:209/1670 train_time:19632ms step_avg:93.93ms +step:210/1670 train_time:19725ms step_avg:93.93ms +step:211/1670 train_time:19817ms step_avg:93.92ms +step:212/1670 train_time:19909ms step_avg:93.91ms +step:213/1670 train_time:20253ms step_avg:95.09ms +step:214/1670 train_time:20388ms step_avg:95.27ms +step:215/1670 train_time:20479ms step_avg:95.25ms +step:216/1670 train_time:20570ms step_avg:95.23ms +step:217/1670 train_time:20661ms step_avg:95.21ms +step:218/1670 train_time:20753ms step_avg:95.20ms 
+step:219/1670 train_time:20844ms step_avg:95.18ms +step:220/1670 train_time:20936ms step_avg:95.16ms +step:221/1670 train_time:21027ms step_avg:95.15ms +step:222/1670 train_time:21119ms step_avg:95.13ms +step:223/1670 train_time:21212ms step_avg:95.12ms +step:224/1670 train_time:21307ms step_avg:95.12ms +step:225/1670 train_time:21403ms step_avg:95.12ms +step:226/1670 train_time:21496ms step_avg:95.12ms +step:227/1670 train_time:21587ms step_avg:95.10ms +step:228/1670 train_time:21679ms step_avg:95.08ms +step:229/1670 train_time:21772ms step_avg:95.07ms +step:230/1670 train_time:21863ms step_avg:95.06ms +step:231/1670 train_time:21955ms step_avg:95.04ms +step:232/1670 train_time:22047ms step_avg:95.03ms +step:233/1670 train_time:22139ms step_avg:95.02ms +step:234/1670 train_time:22232ms step_avg:95.01ms +step:235/1670 train_time:22326ms step_avg:95.00ms +step:236/1670 train_time:22419ms step_avg:94.99ms +step:237/1670 train_time:22512ms step_avg:94.99ms +step:238/1670 train_time:22605ms step_avg:94.98ms +step:239/1670 train_time:22698ms step_avg:94.97ms +step:240/1670 train_time:22794ms step_avg:94.97ms +step:241/1670 train_time:22882ms step_avg:94.95ms +step:242/1670 train_time:22975ms step_avg:94.94ms +step:243/1670 train_time:23067ms step_avg:94.93ms +step:244/1670 train_time:23159ms step_avg:94.92ms +step:245/1670 train_time:23252ms step_avg:94.91ms +step:246/1670 train_time:23345ms step_avg:94.90ms +step:247/1670 train_time:23438ms step_avg:94.89ms +step:248/1670 train_time:23531ms step_avg:94.88ms +step:249/1670 train_time:23624ms step_avg:94.87ms +step:250/1670 train_time:23716ms step_avg:94.86ms +step:250/1670 val_loss:3.9649 train_time:23806ms step_avg:95.22ms +step:251/1670 train_time:23833ms step_avg:94.95ms +step:252/1670 train_time:23908ms step_avg:94.87ms +step:253/1670 train_time:24008ms step_avg:94.89ms +step:254/1670 train_time:24102ms step_avg:94.89ms +step:255/1670 train_time:24195ms step_avg:94.88ms +step:256/1670 train_time:24287ms step_avg:94.87ms +step:257/1670 train_time:24378ms step_avg:94.86ms +step:258/1670 train_time:24470ms step_avg:94.84ms +step:259/1670 train_time:24561ms step_avg:94.83ms +step:260/1670 train_time:24652ms step_avg:94.82ms +step:261/1670 train_time:24744ms step_avg:94.80ms +step:262/1670 train_time:24838ms step_avg:94.80ms +step:263/1670 train_time:24932ms step_avg:94.80ms +step:264/1670 train_time:25029ms step_avg:94.81ms +step:265/1670 train_time:25122ms step_avg:94.80ms +step:266/1670 train_time:25214ms step_avg:94.79ms +step:267/1670 train_time:25307ms step_avg:94.78ms +step:268/1670 train_time:25399ms step_avg:94.77ms +step:269/1670 train_time:25491ms step_avg:94.76ms +step:270/1670 train_time:25583ms step_avg:94.75ms +step:271/1670 train_time:25675ms step_avg:94.74ms +step:272/1670 train_time:25766ms step_avg:94.73ms +step:273/1670 train_time:25859ms step_avg:94.72ms +step:274/1670 train_time:25954ms step_avg:94.72ms +step:275/1670 train_time:26047ms step_avg:94.72ms +step:276/1670 train_time:26140ms step_avg:94.71ms +step:277/1670 train_time:26234ms step_avg:94.71ms +step:278/1670 train_time:26325ms step_avg:94.69ms +step:279/1670 train_time:26417ms step_avg:94.69ms +step:280/1670 train_time:26509ms step_avg:94.68ms +step:281/1670 train_time:26601ms step_avg:94.67ms +step:282/1670 train_time:26693ms step_avg:94.66ms +step:283/1670 train_time:26786ms step_avg:94.65ms +step:284/1670 train_time:26878ms step_avg:94.64ms +step:285/1670 train_time:26971ms step_avg:94.63ms +step:286/1670 train_time:27064ms step_avg:94.63ms +step:287/1670 
train_time:27157ms step_avg:94.62ms +step:288/1670 train_time:27250ms step_avg:94.62ms +step:289/1670 train_time:27343ms step_avg:94.61ms +step:290/1670 train_time:27436ms step_avg:94.61ms +step:291/1670 train_time:27528ms step_avg:94.60ms +step:292/1670 train_time:27620ms step_avg:94.59ms +step:293/1670 train_time:27712ms step_avg:94.58ms +step:294/1670 train_time:27805ms step_avg:94.57ms +step:295/1670 train_time:27897ms step_avg:94.57ms +step:296/1670 train_time:27990ms step_avg:94.56ms +step:297/1670 train_time:28082ms step_avg:94.55ms +step:298/1670 train_time:28175ms step_avg:94.55ms +step:299/1670 train_time:28267ms step_avg:94.54ms +step:300/1670 train_time:28360ms step_avg:94.53ms +step:301/1670 train_time:28452ms step_avg:94.53ms +step:302/1670 train_time:28544ms step_avg:94.52ms +step:303/1670 train_time:28637ms step_avg:94.51ms +step:304/1670 train_time:28728ms step_avg:94.50ms +step:305/1670 train_time:28820ms step_avg:94.49ms +step:306/1670 train_time:28912ms step_avg:94.48ms +step:307/1670 train_time:29004ms step_avg:94.48ms +step:308/1670 train_time:29098ms step_avg:94.47ms +step:309/1670 train_time:29190ms step_avg:94.46ms +step:310/1670 train_time:29282ms step_avg:94.46ms +step:311/1670 train_time:29375ms step_avg:94.45ms +step:312/1670 train_time:29467ms step_avg:94.45ms +step:313/1670 train_time:29560ms step_avg:94.44ms +step:314/1670 train_time:29652ms step_avg:94.43ms +step:315/1670 train_time:29744ms step_avg:94.43ms +step:316/1670 train_time:29837ms step_avg:94.42ms +step:317/1670 train_time:29929ms step_avg:94.41ms +step:318/1670 train_time:30021ms step_avg:94.41ms +step:319/1670 train_time:30114ms step_avg:94.40ms +step:320/1670 train_time:30207ms step_avg:94.40ms +step:321/1670 train_time:30300ms step_avg:94.39ms +step:322/1670 train_time:30392ms step_avg:94.39ms +step:323/1670 train_time:30484ms step_avg:94.38ms +step:324/1670 train_time:30577ms step_avg:94.37ms +step:325/1670 train_time:30670ms step_avg:94.37ms +step:326/1670 train_time:30762ms step_avg:94.36ms +step:327/1670 train_time:30855ms step_avg:94.36ms +step:328/1670 train_time:30947ms step_avg:94.35ms +step:329/1670 train_time:31040ms step_avg:94.35ms +step:330/1670 train_time:31131ms step_avg:94.34ms +step:331/1670 train_time:31225ms step_avg:94.33ms +step:332/1670 train_time:31318ms step_avg:94.33ms +step:333/1670 train_time:31410ms step_avg:94.32ms +step:334/1670 train_time:31503ms step_avg:94.32ms +step:335/1670 train_time:31595ms step_avg:94.31ms +step:336/1670 train_time:31687ms step_avg:94.31ms +step:337/1670 train_time:31779ms step_avg:94.30ms +step:338/1670 train_time:31872ms step_avg:94.30ms +step:339/1670 train_time:31964ms step_avg:94.29ms +step:340/1670 train_time:32057ms step_avg:94.29ms +step:341/1670 train_time:32149ms step_avg:94.28ms +step:342/1670 train_time:32242ms step_avg:94.28ms +step:343/1670 train_time:32334ms step_avg:94.27ms +step:344/1670 train_time:32427ms step_avg:94.26ms +step:345/1670 train_time:32519ms step_avg:94.26ms +step:346/1670 train_time:32611ms step_avg:94.25ms +step:347/1670 train_time:32705ms step_avg:94.25ms +step:348/1670 train_time:32799ms step_avg:94.25ms +step:349/1670 train_time:32890ms step_avg:94.24ms +step:350/1670 train_time:32983ms step_avg:94.24ms +step:351/1670 train_time:33075ms step_avg:94.23ms +step:352/1670 train_time:33167ms step_avg:94.22ms +step:353/1670 train_time:33260ms step_avg:94.22ms +step:354/1670 train_time:33352ms step_avg:94.22ms +step:355/1670 train_time:33445ms step_avg:94.21ms +step:356/1670 train_time:33538ms step_avg:94.21ms 
+step:357/1670 train_time:33631ms step_avg:94.20ms +step:358/1670 train_time:33724ms step_avg:94.20ms +step:359/1670 train_time:33816ms step_avg:94.19ms +step:360/1670 train_time:33908ms step_avg:94.19ms +step:361/1670 train_time:34002ms step_avg:94.19ms +step:362/1670 train_time:34094ms step_avg:94.18ms +step:363/1670 train_time:34187ms step_avg:94.18ms +step:364/1670 train_time:34279ms step_avg:94.17ms +step:365/1670 train_time:34372ms step_avg:94.17ms +step:366/1670 train_time:34464ms step_avg:94.16ms +step:367/1670 train_time:34558ms step_avg:94.16ms +step:368/1670 train_time:34651ms step_avg:94.16ms +step:369/1670 train_time:34743ms step_avg:94.15ms +step:370/1670 train_time:34835ms step_avg:94.15ms +step:371/1670 train_time:34928ms step_avg:94.14ms +step:372/1670 train_time:35021ms step_avg:94.14ms +step:373/1670 train_time:35113ms step_avg:94.14ms +step:374/1670 train_time:35206ms step_avg:94.13ms +step:375/1670 train_time:35298ms step_avg:94.13ms +step:375/1670 val_loss:3.8193 train_time:35389ms step_avg:94.37ms +step:376/1670 train_time:35415ms step_avg:94.19ms +step:377/1670 train_time:35489ms step_avg:94.13ms +step:378/1670 train_time:35586ms step_avg:94.14ms +step:379/1670 train_time:35682ms step_avg:94.15ms +step:380/1670 train_time:35774ms step_avg:94.14ms +step:381/1670 train_time:35865ms step_avg:94.13ms +step:382/1670 train_time:35957ms step_avg:94.13ms +step:383/1670 train_time:36048ms step_avg:94.12ms +step:384/1670 train_time:36140ms step_avg:94.11ms +step:385/1670 train_time:36231ms step_avg:94.11ms +step:386/1670 train_time:36323ms step_avg:94.10ms +step:387/1670 train_time:36416ms step_avg:94.10ms +step:388/1670 train_time:36511ms step_avg:94.10ms +step:389/1670 train_time:36606ms step_avg:94.10ms +step:390/1670 train_time:36699ms step_avg:94.10ms +step:391/1670 train_time:36792ms step_avg:94.10ms +step:392/1670 train_time:36885ms step_avg:94.09ms +step:393/1670 train_time:36977ms step_avg:94.09ms +step:394/1670 train_time:37069ms step_avg:94.08ms +step:395/1670 train_time:37160ms step_avg:94.08ms +step:396/1670 train_time:37252ms step_avg:94.07ms +step:397/1670 train_time:37345ms step_avg:94.07ms +step:398/1670 train_time:37437ms step_avg:94.06ms +step:399/1670 train_time:37531ms step_avg:94.06ms +step:400/1670 train_time:37623ms step_avg:94.06ms +step:401/1670 train_time:37716ms step_avg:94.06ms +step:402/1670 train_time:37810ms step_avg:94.05ms +step:403/1670 train_time:37902ms step_avg:94.05ms +step:404/1670 train_time:37994ms step_avg:94.04ms +step:405/1670 train_time:38087ms step_avg:94.04ms +step:406/1670 train_time:38179ms step_avg:94.04ms +step:407/1670 train_time:38271ms step_avg:94.03ms +step:408/1670 train_time:38364ms step_avg:94.03ms +step:409/1670 train_time:38456ms step_avg:94.03ms +step:410/1670 train_time:38549ms step_avg:94.02ms +step:411/1670 train_time:38642ms step_avg:94.02ms +step:412/1670 train_time:38735ms step_avg:94.02ms +step:413/1670 train_time:38828ms step_avg:94.01ms +step:414/1670 train_time:38920ms step_avg:94.01ms +step:415/1670 train_time:39012ms step_avg:94.01ms +step:416/1670 train_time:39105ms step_avg:94.00ms +step:417/1670 train_time:39197ms step_avg:94.00ms +step:418/1670 train_time:39289ms step_avg:93.99ms +step:419/1670 train_time:39383ms step_avg:93.99ms +step:420/1670 train_time:39475ms step_avg:93.99ms +step:421/1670 train_time:39568ms step_avg:93.99ms +step:422/1670 train_time:39660ms step_avg:93.98ms +step:423/1670 train_time:39754ms step_avg:93.98ms +step:424/1670 train_time:39847ms step_avg:93.98ms +step:425/1670 
train_time:40160ms step_avg:94.49ms +step:426/1670 train_time:40349ms step_avg:94.72ms +step:427/1670 train_time:40439ms step_avg:94.71ms +step:428/1670 train_time:40531ms step_avg:94.70ms +step:429/1670 train_time:40622ms step_avg:94.69ms +step:430/1670 train_time:40713ms step_avg:94.68ms +step:431/1670 train_time:40805ms step_avg:94.68ms +step:432/1670 train_time:40896ms step_avg:94.67ms +step:433/1670 train_time:40988ms step_avg:94.66ms +step:434/1670 train_time:41079ms step_avg:94.65ms +step:435/1670 train_time:41171ms step_avg:94.65ms +step:436/1670 train_time:41266ms step_avg:94.65ms +step:437/1670 train_time:41363ms step_avg:94.65ms +step:438/1670 train_time:41456ms step_avg:94.65ms +step:439/1670 train_time:41550ms step_avg:94.65ms +step:440/1670 train_time:41642ms step_avg:94.64ms +step:441/1670 train_time:41733ms step_avg:94.63ms +step:442/1670 train_time:41825ms step_avg:94.63ms +step:443/1670 train_time:41918ms step_avg:94.62ms +step:444/1670 train_time:42010ms step_avg:94.62ms +step:445/1670 train_time:42101ms step_avg:94.61ms +step:446/1670 train_time:42194ms step_avg:94.61ms +step:447/1670 train_time:42288ms step_avg:94.61ms +step:448/1670 train_time:42383ms step_avg:94.60ms +step:449/1670 train_time:42476ms step_avg:94.60ms +step:450/1670 train_time:42569ms step_avg:94.60ms +step:451/1670 train_time:42661ms step_avg:94.59ms +step:452/1670 train_time:42753ms step_avg:94.59ms +step:453/1670 train_time:42845ms step_avg:94.58ms +step:454/1670 train_time:42937ms step_avg:94.57ms +step:455/1670 train_time:43029ms step_avg:94.57ms +step:456/1670 train_time:43121ms step_avg:94.56ms +step:457/1670 train_time:43214ms step_avg:94.56ms +step:458/1670 train_time:43308ms step_avg:94.56ms +step:459/1670 train_time:43401ms step_avg:94.55ms +step:460/1670 train_time:43494ms step_avg:94.55ms +step:461/1670 train_time:43586ms step_avg:94.55ms +step:462/1670 train_time:43679ms step_avg:94.54ms +step:463/1670 train_time:43771ms step_avg:94.54ms +step:464/1670 train_time:43863ms step_avg:94.53ms +step:465/1670 train_time:43956ms step_avg:94.53ms +step:466/1670 train_time:44047ms step_avg:94.52ms +step:467/1670 train_time:44140ms step_avg:94.52ms +step:468/1670 train_time:44233ms step_avg:94.51ms +step:469/1670 train_time:44325ms step_avg:94.51ms +step:470/1670 train_time:44418ms step_avg:94.51ms +step:471/1670 train_time:44511ms step_avg:94.50ms +step:472/1670 train_time:44604ms step_avg:94.50ms +step:473/1670 train_time:44697ms step_avg:94.50ms +step:474/1670 train_time:44789ms step_avg:94.49ms +step:475/1670 train_time:44881ms step_avg:94.49ms +step:476/1670 train_time:44973ms step_avg:94.48ms +step:477/1670 train_time:45066ms step_avg:94.48ms +step:478/1670 train_time:45158ms step_avg:94.47ms +step:479/1670 train_time:45251ms step_avg:94.47ms +step:480/1670 train_time:45343ms step_avg:94.46ms +step:481/1670 train_time:45436ms step_avg:94.46ms +step:482/1670 train_time:45529ms step_avg:94.46ms +step:483/1670 train_time:45623ms step_avg:94.46ms +step:484/1670 train_time:45716ms step_avg:94.45ms +step:485/1670 train_time:45809ms step_avg:94.45ms +step:486/1670 train_time:45900ms step_avg:94.45ms +step:487/1670 train_time:45994ms step_avg:94.44ms +step:488/1670 train_time:46087ms step_avg:94.44ms +step:489/1670 train_time:46180ms step_avg:94.44ms +step:490/1670 train_time:46273ms step_avg:94.43ms +step:491/1670 train_time:46365ms step_avg:94.43ms +step:492/1670 train_time:46457ms step_avg:94.43ms +step:493/1670 train_time:46550ms step_avg:94.42ms +step:494/1670 train_time:46643ms step_avg:94.42ms 
+step:495/1670 train_time:46736ms step_avg:94.42ms +step:496/1670 train_time:46829ms step_avg:94.41ms +step:497/1670 train_time:46921ms step_avg:94.41ms +step:498/1670 train_time:47013ms step_avg:94.40ms +step:499/1670 train_time:47106ms step_avg:94.40ms +step:500/1670 train_time:47198ms step_avg:94.40ms +step:500/1670 val_loss:3.7146 train_time:47288ms step_avg:94.58ms +step:501/1670 train_time:47314ms step_avg:94.44ms +step:502/1670 train_time:47387ms step_avg:94.40ms +step:503/1670 train_time:47491ms step_avg:94.42ms +step:504/1670 train_time:47585ms step_avg:94.41ms +step:505/1670 train_time:47677ms step_avg:94.41ms +step:506/1670 train_time:47769ms step_avg:94.40ms +step:507/1670 train_time:47860ms step_avg:94.40ms +step:508/1670 train_time:47951ms step_avg:94.39ms +step:509/1670 train_time:48043ms step_avg:94.39ms +step:510/1670 train_time:48135ms step_avg:94.38ms +step:511/1670 train_time:48226ms step_avg:94.38ms +step:512/1670 train_time:48319ms step_avg:94.37ms +step:513/1670 train_time:48412ms step_avg:94.37ms +step:514/1670 train_time:48506ms step_avg:94.37ms +step:515/1670 train_time:48600ms step_avg:94.37ms +step:516/1670 train_time:48693ms step_avg:94.37ms +step:517/1670 train_time:48785ms step_avg:94.36ms +step:518/1670 train_time:48877ms step_avg:94.36ms +step:519/1670 train_time:48969ms step_avg:94.35ms +step:520/1670 train_time:49061ms step_avg:94.35ms +step:521/1670 train_time:49153ms step_avg:94.34ms +step:522/1670 train_time:49245ms step_avg:94.34ms +step:523/1670 train_time:49338ms step_avg:94.34ms +step:524/1670 train_time:49431ms step_avg:94.33ms +step:525/1670 train_time:49525ms step_avg:94.33ms +step:526/1670 train_time:49618ms step_avg:94.33ms +step:527/1670 train_time:49711ms step_avg:94.33ms +step:528/1670 train_time:49803ms step_avg:94.32ms +step:529/1670 train_time:49896ms step_avg:94.32ms +step:530/1670 train_time:49988ms step_avg:94.32ms +step:531/1670 train_time:50081ms step_avg:94.31ms +step:532/1670 train_time:50173ms step_avg:94.31ms +step:533/1670 train_time:50266ms step_avg:94.31ms +step:534/1670 train_time:50359ms step_avg:94.31ms +step:535/1670 train_time:50452ms step_avg:94.30ms +step:536/1670 train_time:50546ms step_avg:94.30ms +step:537/1670 train_time:50639ms step_avg:94.30ms +step:538/1670 train_time:50732ms step_avg:94.30ms +step:539/1670 train_time:50825ms step_avg:94.29ms +step:540/1670 train_time:50917ms step_avg:94.29ms +step:541/1670 train_time:51010ms step_avg:94.29ms +step:542/1670 train_time:51101ms step_avg:94.28ms +step:543/1670 train_time:51194ms step_avg:94.28ms +step:544/1670 train_time:51287ms step_avg:94.28ms +step:545/1670 train_time:51380ms step_avg:94.28ms +step:546/1670 train_time:51472ms step_avg:94.27ms +step:547/1670 train_time:51565ms step_avg:94.27ms +step:548/1670 train_time:51659ms step_avg:94.27ms +step:549/1670 train_time:51752ms step_avg:94.27ms +step:550/1670 train_time:51844ms step_avg:94.26ms +step:551/1670 train_time:51936ms step_avg:94.26ms +step:552/1670 train_time:52029ms step_avg:94.25ms +step:553/1670 train_time:52121ms step_avg:94.25ms +step:554/1670 train_time:52214ms step_avg:94.25ms +step:555/1670 train_time:52306ms step_avg:94.25ms +step:556/1670 train_time:52399ms step_avg:94.24ms +step:557/1670 train_time:52493ms step_avg:94.24ms +step:558/1670 train_time:52698ms step_avg:94.44ms +step:559/1670 train_time:52765ms step_avg:94.39ms +step:560/1670 train_time:52856ms step_avg:94.39ms +step:561/1670 train_time:52949ms step_avg:94.38ms +step:562/1670 train_time:53042ms step_avg:94.38ms +step:563/1670 
train_time:53135ms step_avg:94.38ms +step:564/1670 train_time:53227ms step_avg:94.37ms +step:565/1670 train_time:53320ms step_avg:94.37ms +step:566/1670 train_time:53412ms step_avg:94.37ms +step:567/1670 train_time:53505ms step_avg:94.36ms +step:568/1670 train_time:53602ms step_avg:94.37ms +step:569/1670 train_time:53700ms step_avg:94.38ms +step:570/1670 train_time:53795ms step_avg:94.38ms +step:571/1670 train_time:53889ms step_avg:94.38ms +step:572/1670 train_time:53981ms step_avg:94.37ms +step:573/1670 train_time:54074ms step_avg:94.37ms +step:574/1670 train_time:54167ms step_avg:94.37ms +step:575/1670 train_time:54260ms step_avg:94.36ms +step:576/1670 train_time:54352ms step_avg:94.36ms +step:577/1670 train_time:54445ms step_avg:94.36ms +step:578/1670 train_time:54539ms step_avg:94.36ms +step:579/1670 train_time:54635ms step_avg:94.36ms +step:580/1670 train_time:54730ms step_avg:94.36ms +step:581/1670 train_time:54825ms step_avg:94.36ms +step:582/1670 train_time:54919ms step_avg:94.36ms +step:583/1670 train_time:55013ms step_avg:94.36ms +step:584/1670 train_time:55106ms step_avg:94.36ms +step:585/1670 train_time:55199ms step_avg:94.36ms +step:586/1670 train_time:55292ms step_avg:94.35ms +step:587/1670 train_time:55384ms step_avg:94.35ms +step:588/1670 train_time:55478ms step_avg:94.35ms +step:589/1670 train_time:55572ms step_avg:94.35ms +step:590/1670 train_time:55667ms step_avg:94.35ms +step:591/1670 train_time:55762ms step_avg:94.35ms +step:592/1670 train_time:55857ms step_avg:94.35ms +step:593/1670 train_time:55950ms step_avg:94.35ms +step:594/1670 train_time:56044ms step_avg:94.35ms +step:595/1670 train_time:56137ms step_avg:94.35ms +step:596/1670 train_time:56231ms step_avg:94.35ms +step:597/1670 train_time:56324ms step_avg:94.34ms +step:598/1670 train_time:56417ms step_avg:94.34ms +step:599/1670 train_time:56511ms step_avg:94.34ms +step:600/1670 train_time:56605ms step_avg:94.34ms +step:601/1670 train_time:56700ms step_avg:94.34ms +step:602/1670 train_time:56794ms step_avg:94.34ms +step:603/1670 train_time:56889ms step_avg:94.34ms +step:604/1670 train_time:56982ms step_avg:94.34ms +step:605/1670 train_time:57076ms step_avg:94.34ms +step:606/1670 train_time:57170ms step_avg:94.34ms +step:607/1670 train_time:57263ms step_avg:94.34ms +step:608/1670 train_time:57356ms step_avg:94.34ms +step:609/1670 train_time:57449ms step_avg:94.33ms +step:610/1670 train_time:57543ms step_avg:94.33ms +step:611/1670 train_time:57638ms step_avg:94.33ms +step:612/1670 train_time:57732ms step_avg:94.33ms +step:613/1670 train_time:57826ms step_avg:94.33ms +step:614/1670 train_time:57920ms step_avg:94.33ms +step:615/1670 train_time:58014ms step_avg:94.33ms +step:616/1670 train_time:58108ms step_avg:94.33ms +step:617/1670 train_time:58202ms step_avg:94.33ms +step:618/1670 train_time:58296ms step_avg:94.33ms +step:619/1670 train_time:58389ms step_avg:94.33ms +step:620/1670 train_time:58483ms step_avg:94.33ms +step:621/1670 train_time:58576ms step_avg:94.32ms +step:622/1670 train_time:58670ms step_avg:94.32ms +step:623/1670 train_time:58763ms step_avg:94.32ms +step:624/1670 train_time:58859ms step_avg:94.33ms +step:625/1670 train_time:58952ms step_avg:94.32ms +step:625/1670 val_loss:3.6140 train_time:59044ms step_avg:94.47ms +step:626/1670 train_time:59070ms step_avg:94.36ms +step:627/1670 train_time:59150ms step_avg:94.34ms +step:628/1670 train_time:59247ms step_avg:94.34ms +step:629/1670 train_time:59341ms step_avg:94.34ms +step:630/1670 train_time:59433ms step_avg:94.34ms +step:631/1670 train_time:59526ms 
step_avg:94.34ms +step:632/1670 train_time:59619ms step_avg:94.33ms +step:633/1670 train_time:59712ms step_avg:94.33ms +step:634/1670 train_time:59804ms step_avg:94.33ms +step:635/1670 train_time:59897ms step_avg:94.33ms +step:636/1670 train_time:59991ms step_avg:94.33ms +step:637/1670 train_time:60088ms step_avg:94.33ms +step:638/1670 train_time:60184ms step_avg:94.33ms +step:639/1670 train_time:60625ms step_avg:94.88ms +step:640/1670 train_time:60703ms step_avg:94.85ms +step:641/1670 train_time:60796ms step_avg:94.85ms +step:642/1670 train_time:60888ms step_avg:94.84ms +step:643/1670 train_time:60981ms step_avg:94.84ms +step:644/1670 train_time:61074ms step_avg:94.83ms +step:645/1670 train_time:61166ms step_avg:94.83ms +step:646/1670 train_time:61258ms step_avg:94.83ms +step:647/1670 train_time:61351ms step_avg:94.82ms +step:648/1670 train_time:61444ms step_avg:94.82ms +step:649/1670 train_time:61539ms step_avg:94.82ms +step:650/1670 train_time:61636ms step_avg:94.82ms +step:651/1670 train_time:61731ms step_avg:94.82ms +step:652/1670 train_time:61824ms step_avg:94.82ms +step:653/1670 train_time:61917ms step_avg:94.82ms +step:654/1670 train_time:62011ms step_avg:94.82ms +step:655/1670 train_time:62103ms step_avg:94.81ms +step:656/1670 train_time:62196ms step_avg:94.81ms +step:657/1670 train_time:62289ms step_avg:94.81ms +step:658/1670 train_time:62382ms step_avg:94.81ms +step:659/1670 train_time:62476ms step_avg:94.80ms +step:660/1670 train_time:62571ms step_avg:94.80ms +step:661/1670 train_time:62667ms step_avg:94.81ms +step:662/1670 train_time:62761ms step_avg:94.81ms +step:663/1670 train_time:62855ms step_avg:94.80ms +step:664/1670 train_time:62949ms step_avg:94.80ms +step:665/1670 train_time:63043ms step_avg:94.80ms +step:666/1670 train_time:63136ms step_avg:94.80ms +step:667/1670 train_time:63229ms step_avg:94.80ms +step:668/1670 train_time:63322ms step_avg:94.79ms +step:669/1670 train_time:63416ms step_avg:94.79ms +step:670/1670 train_time:63510ms step_avg:94.79ms +step:671/1670 train_time:63604ms step_avg:94.79ms +step:672/1670 train_time:63699ms step_avg:94.79ms +step:673/1670 train_time:63793ms step_avg:94.79ms +step:674/1670 train_time:63887ms step_avg:94.79ms +step:675/1670 train_time:63980ms step_avg:94.79ms +step:676/1670 train_time:64074ms step_avg:94.78ms +step:677/1670 train_time:64167ms step_avg:94.78ms +step:678/1670 train_time:64260ms step_avg:94.78ms +step:679/1670 train_time:64353ms step_avg:94.78ms +step:680/1670 train_time:64447ms step_avg:94.78ms +step:681/1670 train_time:64541ms step_avg:94.77ms +step:682/1670 train_time:64635ms step_avg:94.77ms +step:683/1670 train_time:64729ms step_avg:94.77ms +step:684/1670 train_time:64823ms step_avg:94.77ms +step:685/1670 train_time:64917ms step_avg:94.77ms +step:686/1670 train_time:65011ms step_avg:94.77ms +step:687/1670 train_time:65104ms step_avg:94.77ms +step:688/1670 train_time:65198ms step_avg:94.76ms +step:689/1670 train_time:65292ms step_avg:94.76ms +step:690/1670 train_time:65385ms step_avg:94.76ms +step:691/1670 train_time:65480ms step_avg:94.76ms +step:692/1670 train_time:65573ms step_avg:94.76ms +step:693/1670 train_time:65667ms step_avg:94.76ms +step:694/1670 train_time:65761ms step_avg:94.76ms +step:695/1670 train_time:65855ms step_avg:94.76ms +step:696/1670 train_time:65950ms step_avg:94.76ms +step:697/1670 train_time:66044ms step_avg:94.75ms +step:698/1670 train_time:66138ms step_avg:94.75ms +step:699/1670 train_time:66231ms step_avg:94.75ms +step:700/1670 train_time:66325ms step_avg:94.75ms +step:701/1670 
train_time:66419ms step_avg:94.75ms +step:702/1670 train_time:66512ms step_avg:94.75ms +step:703/1670 train_time:66606ms step_avg:94.75ms +step:704/1670 train_time:66700ms step_avg:94.74ms +step:705/1670 train_time:66794ms step_avg:94.74ms +step:706/1670 train_time:66887ms step_avg:94.74ms +step:707/1670 train_time:66982ms step_avg:94.74ms +step:708/1670 train_time:67076ms step_avg:94.74ms +step:709/1670 train_time:67169ms step_avg:94.74ms +step:710/1670 train_time:67262ms step_avg:94.74ms +step:711/1670 train_time:67356ms step_avg:94.73ms +step:712/1670 train_time:67450ms step_avg:94.73ms +step:713/1670 train_time:67543ms step_avg:94.73ms +step:714/1670 train_time:67638ms step_avg:94.73ms +step:715/1670 train_time:67731ms step_avg:94.73ms +step:716/1670 train_time:67825ms step_avg:94.73ms +step:717/1670 train_time:67919ms step_avg:94.73ms +step:718/1670 train_time:68013ms step_avg:94.73ms +step:719/1670 train_time:68107ms step_avg:94.72ms +step:720/1670 train_time:68201ms step_avg:94.72ms +step:721/1670 train_time:68295ms step_avg:94.72ms +step:722/1670 train_time:68387ms step_avg:94.72ms +step:723/1670 train_time:68481ms step_avg:94.72ms +step:724/1670 train_time:68574ms step_avg:94.72ms +step:725/1670 train_time:68668ms step_avg:94.71ms +step:726/1670 train_time:68761ms step_avg:94.71ms +step:727/1670 train_time:68855ms step_avg:94.71ms +step:728/1670 train_time:68950ms step_avg:94.71ms +step:729/1670 train_time:69043ms step_avg:94.71ms +step:730/1670 train_time:69138ms step_avg:94.71ms +step:731/1670 train_time:69231ms step_avg:94.71ms +step:732/1670 train_time:69324ms step_avg:94.71ms +step:733/1670 train_time:69419ms step_avg:94.70ms +step:734/1670 train_time:69512ms step_avg:94.70ms +step:735/1670 train_time:69606ms step_avg:94.70ms +step:736/1670 train_time:69700ms step_avg:94.70ms +step:737/1670 train_time:69794ms step_avg:94.70ms +step:738/1670 train_time:69888ms step_avg:94.70ms +step:739/1670 train_time:69981ms step_avg:94.70ms +step:740/1670 train_time:70076ms step_avg:94.70ms +step:741/1670 train_time:70169ms step_avg:94.69ms +step:742/1670 train_time:70263ms step_avg:94.69ms +step:743/1670 train_time:70357ms step_avg:94.69ms +step:744/1670 train_time:70451ms step_avg:94.69ms +step:745/1670 train_time:70545ms step_avg:94.69ms +step:746/1670 train_time:70639ms step_avg:94.69ms +step:747/1670 train_time:70733ms step_avg:94.69ms +step:748/1670 train_time:70826ms step_avg:94.69ms +step:749/1670 train_time:70920ms step_avg:94.69ms +step:750/1670 train_time:71014ms step_avg:94.69ms +step:750/1670 val_loss:3.5629 train_time:71106ms step_avg:94.81ms +step:751/1670 train_time:71133ms step_avg:94.72ms +step:752/1670 train_time:71210ms step_avg:94.69ms +step:753/1670 train_time:71311ms step_avg:94.70ms +step:754/1670 train_time:71408ms step_avg:94.71ms +step:755/1670 train_time:71501ms step_avg:94.70ms +step:756/1670 train_time:71594ms step_avg:94.70ms +step:757/1670 train_time:71686ms step_avg:94.70ms +step:758/1670 train_time:71779ms step_avg:94.69ms +step:759/1670 train_time:71871ms step_avg:94.69ms +step:760/1670 train_time:71964ms step_avg:94.69ms +step:761/1670 train_time:72057ms step_avg:94.69ms +step:762/1670 train_time:72151ms step_avg:94.69ms +step:763/1670 train_time:72247ms step_avg:94.69ms +step:764/1670 train_time:72342ms step_avg:94.69ms +step:765/1670 train_time:72438ms step_avg:94.69ms +step:766/1670 train_time:72532ms step_avg:94.69ms +step:767/1670 train_time:72626ms step_avg:94.69ms +step:768/1670 train_time:72719ms step_avg:94.69ms +step:769/1670 train_time:72812ms 
step_avg:94.68ms +step:770/1670 train_time:72905ms step_avg:94.68ms +step:771/1670 train_time:72997ms step_avg:94.68ms +step:772/1670 train_time:73091ms step_avg:94.68ms +step:773/1670 train_time:73186ms step_avg:94.68ms +step:774/1670 train_time:73281ms step_avg:94.68ms +step:775/1670 train_time:73376ms step_avg:94.68ms +step:776/1670 train_time:73469ms step_avg:94.68ms +step:777/1670 train_time:73563ms step_avg:94.68ms +step:778/1670 train_time:73657ms step_avg:94.67ms +step:779/1670 train_time:73750ms step_avg:94.67ms +step:780/1670 train_time:73843ms step_avg:94.67ms +step:781/1670 train_time:73936ms step_avg:94.67ms +step:782/1670 train_time:74029ms step_avg:94.67ms +step:783/1670 train_time:74123ms step_avg:94.67ms +step:784/1670 train_time:74217ms step_avg:94.66ms +step:785/1670 train_time:74311ms step_avg:94.66ms +step:786/1670 train_time:74406ms step_avg:94.66ms +step:787/1670 train_time:74500ms step_avg:94.66ms +step:788/1670 train_time:74594ms step_avg:94.66ms +step:789/1670 train_time:74687ms step_avg:94.66ms +step:790/1670 train_time:74781ms step_avg:94.66ms +step:791/1670 train_time:74874ms step_avg:94.66ms +step:792/1670 train_time:74967ms step_avg:94.66ms +step:793/1670 train_time:75061ms step_avg:94.65ms +step:794/1670 train_time:75154ms step_avg:94.65ms +step:795/1670 train_time:75248ms step_avg:94.65ms +step:796/1670 train_time:75343ms step_avg:94.65ms +step:797/1670 train_time:75437ms step_avg:94.65ms +step:798/1670 train_time:75530ms step_avg:94.65ms +step:799/1670 train_time:75624ms step_avg:94.65ms +step:800/1670 train_time:75718ms step_avg:94.65ms +step:801/1670 train_time:75812ms step_avg:94.65ms +step:802/1670 train_time:75905ms step_avg:94.64ms +step:803/1670 train_time:75998ms step_avg:94.64ms +step:804/1670 train_time:76091ms step_avg:94.64ms +step:805/1670 train_time:76185ms step_avg:94.64ms +step:806/1670 train_time:76279ms step_avg:94.64ms +step:807/1670 train_time:76373ms step_avg:94.64ms +step:808/1670 train_time:76467ms step_avg:94.64ms +step:809/1670 train_time:76560ms step_avg:94.64ms +step:810/1670 train_time:76654ms step_avg:94.64ms +step:811/1670 train_time:76748ms step_avg:94.63ms +step:812/1670 train_time:76842ms step_avg:94.63ms +step:813/1670 train_time:76935ms step_avg:94.63ms +step:814/1670 train_time:77028ms step_avg:94.63ms +step:815/1670 train_time:77121ms step_avg:94.63ms +step:816/1670 train_time:77216ms step_avg:94.63ms +step:817/1670 train_time:77309ms step_avg:94.63ms +step:818/1670 train_time:77403ms step_avg:94.62ms +step:819/1670 train_time:77497ms step_avg:94.62ms +step:820/1670 train_time:77591ms step_avg:94.62ms +step:821/1670 train_time:77684ms step_avg:94.62ms +step:822/1670 train_time:77778ms step_avg:94.62ms +step:823/1670 train_time:77872ms step_avg:94.62ms +step:824/1670 train_time:77967ms step_avg:94.62ms +step:825/1670 train_time:78060ms step_avg:94.62ms +step:826/1670 train_time:78155ms step_avg:94.62ms +step:827/1670 train_time:78249ms step_avg:94.62ms +step:828/1670 train_time:78342ms step_avg:94.62ms +step:829/1670 train_time:78435ms step_avg:94.61ms +step:830/1670 train_time:78528ms step_avg:94.61ms +step:831/1670 train_time:78622ms step_avg:94.61ms +step:832/1670 train_time:78716ms step_avg:94.61ms +step:833/1670 train_time:78810ms step_avg:94.61ms +step:834/1670 train_time:78905ms step_avg:94.61ms +step:835/1670 train_time:78998ms step_avg:94.61ms +step:836/1670 train_time:79093ms step_avg:94.61ms +step:837/1670 train_time:79187ms step_avg:94.61ms +step:838/1670 train_time:79281ms step_avg:94.61ms +step:839/1670 
train_time:79375ms step_avg:94.61ms +step:840/1670 train_time:79467ms step_avg:94.60ms +step:841/1670 train_time:79561ms step_avg:94.60ms +step:842/1670 train_time:79655ms step_avg:94.60ms +step:843/1670 train_time:79749ms step_avg:94.60ms +step:844/1670 train_time:79843ms step_avg:94.60ms +step:845/1670 train_time:79937ms step_avg:94.60ms +step:846/1670 train_time:80031ms step_avg:94.60ms +step:847/1670 train_time:80126ms step_avg:94.60ms +step:848/1670 train_time:80220ms step_avg:94.60ms +step:849/1670 train_time:80313ms step_avg:94.60ms +step:850/1670 train_time:80407ms step_avg:94.60ms +step:851/1670 train_time:80847ms step_avg:95.00ms +step:852/1670 train_time:80915ms step_avg:94.97ms +step:853/1670 train_time:81007ms step_avg:94.97ms +step:854/1670 train_time:81099ms step_avg:94.96ms +step:855/1670 train_time:81192ms step_avg:94.96ms +step:856/1670 train_time:81285ms step_avg:94.96ms +step:857/1670 train_time:81378ms step_avg:94.96ms +step:858/1670 train_time:81470ms step_avg:94.95ms +step:859/1670 train_time:81563ms step_avg:94.95ms +step:860/1670 train_time:81656ms step_avg:94.95ms +step:861/1670 train_time:81753ms step_avg:94.95ms +step:862/1670 train_time:81851ms step_avg:94.95ms +step:863/1670 train_time:81947ms step_avg:94.96ms +step:864/1670 train_time:82041ms step_avg:94.96ms +step:865/1670 train_time:82134ms step_avg:94.95ms +step:866/1670 train_time:82227ms step_avg:94.95ms +step:867/1670 train_time:82321ms step_avg:94.95ms +step:868/1670 train_time:82413ms step_avg:94.95ms +step:869/1670 train_time:82506ms step_avg:94.94ms +step:870/1670 train_time:82598ms step_avg:94.94ms +step:871/1670 train_time:82692ms step_avg:94.94ms +step:872/1670 train_time:82788ms step_avg:94.94ms +step:873/1670 train_time:82885ms step_avg:94.94ms +step:874/1670 train_time:82979ms step_avg:94.94ms +step:875/1670 train_time:83073ms step_avg:94.94ms +step:875/1670 val_loss:3.5191 train_time:83165ms step_avg:95.05ms +step:876/1670 train_time:83191ms step_avg:94.97ms +step:877/1670 train_time:83268ms step_avg:94.95ms +step:878/1670 train_time:83365ms step_avg:94.95ms +step:879/1670 train_time:83460ms step_avg:94.95ms +step:880/1670 train_time:83553ms step_avg:94.95ms +step:881/1670 train_time:83646ms step_avg:94.94ms +step:882/1670 train_time:83739ms step_avg:94.94ms +step:883/1670 train_time:83832ms step_avg:94.94ms +step:884/1670 train_time:83924ms step_avg:94.94ms +step:885/1670 train_time:84017ms step_avg:94.93ms +step:886/1670 train_time:84110ms step_avg:94.93ms +step:887/1670 train_time:84205ms step_avg:94.93ms +step:888/1670 train_time:84302ms step_avg:94.93ms +step:889/1670 train_time:84397ms step_avg:94.93ms +step:890/1670 train_time:84492ms step_avg:94.94ms +step:891/1670 train_time:84586ms step_avg:94.93ms +step:892/1670 train_time:84679ms step_avg:94.93ms +step:893/1670 train_time:84772ms step_avg:94.93ms +step:894/1670 train_time:84865ms step_avg:94.93ms +step:895/1670 train_time:84957ms step_avg:94.92ms +step:896/1670 train_time:85050ms step_avg:94.92ms +step:897/1670 train_time:85144ms step_avg:94.92ms +step:898/1670 train_time:85239ms step_avg:94.92ms +step:899/1670 train_time:85334ms step_avg:94.92ms +step:900/1670 train_time:85428ms step_avg:94.92ms +step:901/1670 train_time:85523ms step_avg:94.92ms +step:902/1670 train_time:85617ms step_avg:94.92ms +step:903/1670 train_time:85711ms step_avg:94.92ms +step:904/1670 train_time:85803ms step_avg:94.92ms +step:905/1670 train_time:85897ms step_avg:94.91ms +step:906/1670 train_time:85991ms step_avg:94.91ms +step:907/1670 train_time:86084ms 
step_avg:94.91ms +step:908/1670 train_time:86177ms step_avg:94.91ms +step:909/1670 train_time:86271ms step_avg:94.91ms +step:910/1670 train_time:86366ms step_avg:94.91ms +step:911/1670 train_time:86461ms step_avg:94.91ms +step:912/1670 train_time:86555ms step_avg:94.91ms +step:913/1670 train_time:86649ms step_avg:94.91ms +step:914/1670 train_time:86743ms step_avg:94.90ms +step:915/1670 train_time:86837ms step_avg:94.90ms +step:916/1670 train_time:86930ms step_avg:94.90ms +step:917/1670 train_time:87024ms step_avg:94.90ms +step:918/1670 train_time:87117ms step_avg:94.90ms +step:919/1670 train_time:87210ms step_avg:94.90ms +step:920/1670 train_time:87304ms step_avg:94.90ms +step:921/1670 train_time:87398ms step_avg:94.89ms +step:922/1670 train_time:87492ms step_avg:94.89ms +step:923/1670 train_time:87586ms step_avg:94.89ms +step:924/1670 train_time:87680ms step_avg:94.89ms +step:925/1670 train_time:87774ms step_avg:94.89ms +step:926/1670 train_time:87866ms step_avg:94.89ms +step:927/1670 train_time:87960ms step_avg:94.89ms +step:928/1670 train_time:88054ms step_avg:94.89ms +step:929/1670 train_time:88147ms step_avg:94.88ms +step:930/1670 train_time:88241ms step_avg:94.88ms +step:931/1670 train_time:88334ms step_avg:94.88ms +step:932/1670 train_time:88428ms step_avg:94.88ms +step:933/1670 train_time:88524ms step_avg:94.88ms +step:934/1670 train_time:88617ms step_avg:94.88ms +step:935/1670 train_time:88711ms step_avg:94.88ms +step:936/1670 train_time:88805ms step_avg:94.88ms +step:937/1670 train_time:88898ms step_avg:94.88ms +step:938/1670 train_time:88992ms step_avg:94.87ms +step:939/1670 train_time:89086ms step_avg:94.87ms +step:940/1670 train_time:89180ms step_avg:94.87ms +step:941/1670 train_time:89273ms step_avg:94.87ms +step:942/1670 train_time:89367ms step_avg:94.87ms +step:943/1670 train_time:89461ms step_avg:94.87ms +step:944/1670 train_time:89555ms step_avg:94.87ms +step:945/1670 train_time:89649ms step_avg:94.87ms +step:946/1670 train_time:89743ms step_avg:94.87ms +step:947/1670 train_time:89836ms step_avg:94.86ms +step:948/1670 train_time:89930ms step_avg:94.86ms +step:949/1670 train_time:90024ms step_avg:94.86ms +step:950/1670 train_time:90117ms step_avg:94.86ms +step:951/1670 train_time:90211ms step_avg:94.86ms +step:952/1670 train_time:90305ms step_avg:94.86ms +step:953/1670 train_time:90399ms step_avg:94.86ms +step:954/1670 train_time:90492ms step_avg:94.86ms +step:955/1670 train_time:90586ms step_avg:94.85ms +step:956/1670 train_time:90681ms step_avg:94.85ms +step:957/1670 train_time:90775ms step_avg:94.85ms +step:958/1670 train_time:90868ms step_avg:94.85ms +step:959/1670 train_time:90961ms step_avg:94.85ms +step:960/1670 train_time:91055ms step_avg:94.85ms +step:961/1670 train_time:91149ms step_avg:94.85ms +step:962/1670 train_time:91242ms step_avg:94.85ms +step:963/1670 train_time:91335ms step_avg:94.84ms +step:964/1670 train_time:91430ms step_avg:94.84ms +step:965/1670 train_time:91524ms step_avg:94.84ms +step:966/1670 train_time:91618ms step_avg:94.84ms +step:967/1670 train_time:91712ms step_avg:94.84ms +step:968/1670 train_time:91805ms step_avg:94.84ms +step:969/1670 train_time:91899ms step_avg:94.84ms +step:970/1670 train_time:91992ms step_avg:94.84ms +step:971/1670 train_time:92086ms step_avg:94.84ms +step:972/1670 train_time:92180ms step_avg:94.84ms +step:973/1670 train_time:92273ms step_avg:94.83ms +step:974/1670 train_time:92367ms step_avg:94.83ms +step:975/1670 train_time:92461ms step_avg:94.83ms +step:976/1670 train_time:92555ms step_avg:94.83ms +step:977/1670 
train_time:92650ms step_avg:94.83ms +step:978/1670 train_time:92744ms step_avg:94.83ms +step:979/1670 train_time:92837ms step_avg:94.83ms +step:980/1670 train_time:92931ms step_avg:94.83ms +step:981/1670 train_time:93024ms step_avg:94.83ms +step:982/1670 train_time:93118ms step_avg:94.83ms +step:983/1670 train_time:93213ms step_avg:94.82ms +step:984/1670 train_time:93305ms step_avg:94.82ms +step:985/1670 train_time:93399ms step_avg:94.82ms +step:986/1670 train_time:93493ms step_avg:94.82ms +step:987/1670 train_time:93586ms step_avg:94.82ms +step:988/1670 train_time:93681ms step_avg:94.82ms +step:989/1670 train_time:93775ms step_avg:94.82ms +step:990/1670 train_time:93869ms step_avg:94.82ms +step:991/1670 train_time:93963ms step_avg:94.82ms +step:992/1670 train_time:94057ms step_avg:94.82ms +step:993/1670 train_time:94150ms step_avg:94.81ms +step:994/1670 train_time:94244ms step_avg:94.81ms +step:995/1670 train_time:94337ms step_avg:94.81ms +step:996/1670 train_time:94431ms step_avg:94.81ms +step:997/1670 train_time:94525ms step_avg:94.81ms +step:998/1670 train_time:94619ms step_avg:94.81ms +step:999/1670 train_time:94714ms step_avg:94.81ms +step:1000/1670 train_time:94807ms step_avg:94.81ms +step:1000/1670 val_loss:3.4699 train_time:94899ms step_avg:94.90ms +step:1001/1670 train_time:94926ms step_avg:94.83ms +step:1002/1670 train_time:95001ms step_avg:94.81ms +step:1003/1670 train_time:95100ms step_avg:94.82ms +step:1004/1670 train_time:95193ms step_avg:94.81ms +step:1005/1670 train_time:95286ms step_avg:94.81ms +step:1006/1670 train_time:95379ms step_avg:94.81ms +step:1007/1670 train_time:95472ms step_avg:94.81ms +step:1008/1670 train_time:95564ms step_avg:94.81ms +step:1009/1670 train_time:95657ms step_avg:94.80ms +step:1010/1670 train_time:95751ms step_avg:94.80ms +step:1011/1670 train_time:95845ms step_avg:94.80ms +step:1012/1670 train_time:95940ms step_avg:94.80ms +step:1013/1670 train_time:96036ms step_avg:94.80ms +step:1014/1670 train_time:96131ms step_avg:94.80ms +step:1015/1670 train_time:96225ms step_avg:94.80ms +step:1016/1670 train_time:96319ms step_avg:94.80ms +step:1017/1670 train_time:96412ms step_avg:94.80ms +step:1018/1670 train_time:96505ms step_avg:94.80ms +step:1019/1670 train_time:96598ms step_avg:94.80ms +step:1020/1670 train_time:96691ms step_avg:94.79ms +step:1021/1670 train_time:96784ms step_avg:94.79ms +step:1022/1670 train_time:96878ms step_avg:94.79ms +step:1023/1670 train_time:96973ms step_avg:94.79ms +step:1024/1670 train_time:97068ms step_avg:94.79ms +step:1025/1670 train_time:97162ms step_avg:94.79ms +step:1026/1670 train_time:97257ms step_avg:94.79ms +step:1027/1670 train_time:97351ms step_avg:94.79ms +step:1028/1670 train_time:97445ms step_avg:94.79ms +step:1029/1670 train_time:97538ms step_avg:94.79ms +step:1030/1670 train_time:97631ms step_avg:94.79ms +step:1031/1670 train_time:97724ms step_avg:94.79ms +step:1032/1670 train_time:97818ms step_avg:94.79ms +step:1033/1670 train_time:97912ms step_avg:94.78ms +step:1034/1670 train_time:98006ms step_avg:94.78ms +step:1035/1670 train_time:98100ms step_avg:94.78ms +step:1036/1670 train_time:98195ms step_avg:94.78ms +step:1037/1670 train_time:98289ms step_avg:94.78ms +step:1038/1670 train_time:98382ms step_avg:94.78ms +step:1039/1670 train_time:98476ms step_avg:94.78ms +step:1040/1670 train_time:98569ms step_avg:94.78ms +step:1041/1670 train_time:98662ms step_avg:94.78ms +step:1042/1670 train_time:98756ms step_avg:94.78ms +step:1043/1670 train_time:98849ms step_avg:94.77ms +step:1044/1670 train_time:98943ms 
step_avg:94.77ms +step:1045/1670 train_time:99037ms step_avg:94.77ms +step:1046/1670 train_time:99131ms step_avg:94.77ms +step:1047/1670 train_time:99225ms step_avg:94.77ms +step:1048/1670 train_time:99319ms step_avg:94.77ms +step:1049/1670 train_time:99413ms step_avg:94.77ms +step:1050/1670 train_time:99506ms step_avg:94.77ms +step:1051/1670 train_time:99600ms step_avg:94.77ms +step:1052/1670 train_time:99693ms step_avg:94.77ms +step:1053/1670 train_time:99787ms step_avg:94.76ms +step:1054/1670 train_time:99881ms step_avg:94.76ms +step:1055/1670 train_time:99975ms step_avg:94.76ms +step:1056/1670 train_time:100069ms step_avg:94.76ms +step:1057/1670 train_time:100164ms step_avg:94.76ms +step:1058/1670 train_time:100259ms step_avg:94.76ms +step:1059/1670 train_time:100353ms step_avg:94.76ms +step:1060/1670 train_time:100447ms step_avg:94.76ms +step:1061/1670 train_time:100540ms step_avg:94.76ms +step:1062/1670 train_time:100986ms step_avg:95.09ms +step:1063/1670 train_time:101056ms step_avg:95.07ms +step:1064/1670 train_time:101148ms step_avg:95.06ms +step:1065/1670 train_time:101240ms step_avg:95.06ms +step:1066/1670 train_time:101333ms step_avg:95.06ms +step:1067/1670 train_time:101426ms step_avg:95.06ms +step:1068/1670 train_time:101518ms step_avg:95.05ms +step:1069/1670 train_time:101611ms step_avg:95.05ms +step:1070/1670 train_time:101703ms step_avg:95.05ms +step:1071/1670 train_time:101796ms step_avg:95.05ms +step:1072/1670 train_time:101892ms step_avg:95.05ms +step:1073/1670 train_time:101988ms step_avg:95.05ms +step:1074/1670 train_time:102085ms step_avg:95.05ms +step:1075/1670 train_time:102179ms step_avg:95.05ms +step:1076/1670 train_time:102272ms step_avg:95.05ms +step:1077/1670 train_time:102365ms step_avg:95.05ms +step:1078/1670 train_time:102459ms step_avg:95.05ms +step:1079/1670 train_time:102551ms step_avg:95.04ms +step:1080/1670 train_time:102644ms step_avg:95.04ms +step:1081/1670 train_time:102736ms step_avg:95.04ms +step:1082/1670 train_time:102830ms step_avg:95.04ms +step:1083/1670 train_time:102925ms step_avg:95.04ms +step:1084/1670 train_time:103021ms step_avg:95.04ms +step:1085/1670 train_time:103117ms step_avg:95.04ms +step:1086/1670 train_time:103210ms step_avg:95.04ms +step:1087/1670 train_time:103304ms step_avg:95.04ms +step:1088/1670 train_time:103397ms step_avg:95.03ms +step:1089/1670 train_time:103490ms step_avg:95.03ms +step:1090/1670 train_time:103583ms step_avg:95.03ms +step:1091/1670 train_time:103677ms step_avg:95.03ms +step:1092/1670 train_time:103771ms step_avg:95.03ms +step:1093/1670 train_time:103866ms step_avg:95.03ms +step:1094/1670 train_time:103960ms step_avg:95.03ms +step:1095/1670 train_time:104055ms step_avg:95.03ms +step:1096/1670 train_time:104150ms step_avg:95.03ms +step:1097/1670 train_time:104244ms step_avg:95.03ms +step:1098/1670 train_time:104338ms step_avg:95.03ms +step:1099/1670 train_time:104431ms step_avg:95.02ms +step:1100/1670 train_time:104524ms step_avg:95.02ms +step:1101/1670 train_time:104618ms step_avg:95.02ms +step:1102/1670 train_time:104711ms step_avg:95.02ms +step:1103/1670 train_time:104804ms step_avg:95.02ms +step:1104/1670 train_time:104898ms step_avg:95.02ms +step:1105/1670 train_time:104993ms step_avg:95.02ms +step:1106/1670 train_time:105087ms step_avg:95.02ms +step:1107/1670 train_time:105182ms step_avg:95.02ms +step:1108/1670 train_time:105276ms step_avg:95.01ms +step:1109/1670 train_time:105370ms step_avg:95.01ms +step:1110/1670 train_time:105464ms step_avg:95.01ms +step:1111/1670 train_time:105557ms 
step_avg:95.01ms +step:1112/1670 train_time:105650ms step_avg:95.01ms +step:1113/1670 train_time:105744ms step_avg:95.01ms +step:1114/1670 train_time:105837ms step_avg:95.01ms +step:1115/1670 train_time:106042ms step_avg:95.10ms +step:1116/1670 train_time:106111ms step_avg:95.08ms +step:1117/1670 train_time:106204ms step_avg:95.08ms +step:1118/1670 train_time:106298ms step_avg:95.08ms +step:1119/1670 train_time:106391ms step_avg:95.08ms +step:1120/1670 train_time:106485ms step_avg:95.08ms +step:1121/1670 train_time:106578ms step_avg:95.07ms +step:1122/1670 train_time:106671ms step_avg:95.07ms +step:1123/1670 train_time:106765ms step_avg:95.07ms +step:1124/1670 train_time:106858ms step_avg:95.07ms +step:1125/1670 train_time:106956ms step_avg:95.07ms +step:1125/1670 val_loss:3.4166 train_time:107052ms step_avg:95.16ms +step:1126/1670 train_time:107078ms step_avg:95.10ms +step:1127/1670 train_time:107160ms step_avg:95.08ms +step:1128/1670 train_time:107261ms step_avg:95.09ms +step:1129/1670 train_time:107356ms step_avg:95.09ms +step:1130/1670 train_time:107449ms step_avg:95.09ms +step:1131/1670 train_time:107544ms step_avg:95.09ms +step:1132/1670 train_time:107637ms step_avg:95.09ms +step:1133/1670 train_time:107731ms step_avg:95.08ms +step:1134/1670 train_time:107824ms step_avg:95.08ms +step:1135/1670 train_time:107918ms step_avg:95.08ms +step:1136/1670 train_time:108012ms step_avg:95.08ms +step:1137/1670 train_time:108109ms step_avg:95.08ms +step:1138/1670 train_time:108208ms step_avg:95.09ms +step:1139/1670 train_time:108304ms step_avg:95.09ms +step:1140/1670 train_time:108399ms step_avg:95.09ms +step:1141/1670 train_time:108493ms step_avg:95.09ms +step:1142/1670 train_time:108587ms step_avg:95.08ms +step:1143/1670 train_time:108681ms step_avg:95.08ms +step:1144/1670 train_time:108774ms step_avg:95.08ms +step:1145/1670 train_time:108868ms step_avg:95.08ms +step:1146/1670 train_time:108962ms step_avg:95.08ms +step:1147/1670 train_time:109055ms step_avg:95.08ms +step:1148/1670 train_time:109151ms step_avg:95.08ms +step:1149/1670 train_time:109246ms step_avg:95.08ms +step:1150/1670 train_time:109341ms step_avg:95.08ms +step:1151/1670 train_time:109436ms step_avg:95.08ms +step:1152/1670 train_time:109530ms step_avg:95.08ms +step:1153/1670 train_time:109624ms step_avg:95.08ms +step:1154/1670 train_time:109718ms step_avg:95.08ms +step:1155/1670 train_time:109812ms step_avg:95.08ms +step:1156/1670 train_time:109905ms step_avg:95.07ms +step:1157/1670 train_time:110000ms step_avg:95.07ms +step:1158/1670 train_time:110096ms step_avg:95.07ms +step:1159/1670 train_time:110191ms step_avg:95.07ms +step:1160/1670 train_time:110286ms step_avg:95.07ms +step:1161/1670 train_time:110381ms step_avg:95.07ms +step:1162/1670 train_time:110475ms step_avg:95.07ms +step:1163/1670 train_time:110571ms step_avg:95.07ms +step:1164/1670 train_time:110665ms step_avg:95.07ms +step:1165/1670 train_time:110759ms step_avg:95.07ms +step:1166/1670 train_time:110852ms step_avg:95.07ms +step:1167/1670 train_time:110947ms step_avg:95.07ms +step:1168/1670 train_time:111042ms step_avg:95.07ms +step:1169/1670 train_time:111137ms step_avg:95.07ms +step:1170/1670 train_time:111233ms step_avg:95.07ms +step:1171/1670 train_time:111329ms step_avg:95.07ms +step:1172/1670 train_time:111424ms step_avg:95.07ms +step:1173/1670 train_time:111519ms step_avg:95.07ms +step:1174/1670 train_time:111613ms step_avg:95.07ms +step:1175/1670 train_time:111708ms step_avg:95.07ms +step:1176/1670 train_time:111803ms step_avg:95.07ms +step:1177/1670 
train_time:111897ms step_avg:95.07ms +step:1178/1670 train_time:111991ms step_avg:95.07ms +step:1179/1670 train_time:112085ms step_avg:95.07ms +step:1180/1670 train_time:112180ms step_avg:95.07ms +step:1181/1670 train_time:112276ms step_avg:95.07ms +step:1182/1670 train_time:112371ms step_avg:95.07ms +step:1183/1670 train_time:112467ms step_avg:95.07ms +step:1184/1670 train_time:112561ms step_avg:95.07ms +step:1185/1670 train_time:112655ms step_avg:95.07ms +step:1186/1670 train_time:112750ms step_avg:95.07ms +step:1187/1670 train_time:112844ms step_avg:95.07ms +step:1188/1670 train_time:112939ms step_avg:95.07ms +step:1189/1670 train_time:113033ms step_avg:95.07ms +step:1190/1670 train_time:113128ms step_avg:95.07ms +step:1191/1670 train_time:113223ms step_avg:95.07ms +step:1192/1670 train_time:113318ms step_avg:95.07ms +step:1193/1670 train_time:113413ms step_avg:95.07ms +step:1194/1670 train_time:113509ms step_avg:95.07ms +step:1195/1670 train_time:113604ms step_avg:95.07ms +step:1196/1670 train_time:113697ms step_avg:95.06ms +step:1197/1670 train_time:113792ms step_avg:95.06ms +step:1198/1670 train_time:113887ms step_avg:95.06ms +step:1199/1670 train_time:113981ms step_avg:95.06ms +step:1200/1670 train_time:114075ms step_avg:95.06ms +step:1201/1670 train_time:114170ms step_avg:95.06ms +step:1202/1670 train_time:114266ms step_avg:95.06ms +step:1203/1670 train_time:114360ms step_avg:95.06ms +step:1204/1670 train_time:114456ms step_avg:95.06ms +step:1205/1670 train_time:114551ms step_avg:95.06ms +step:1206/1670 train_time:114645ms step_avg:95.06ms +step:1207/1670 train_time:114740ms step_avg:95.06ms +step:1208/1670 train_time:114835ms step_avg:95.06ms +step:1209/1670 train_time:114930ms step_avg:95.06ms +step:1210/1670 train_time:115024ms step_avg:95.06ms +step:1211/1670 train_time:115119ms step_avg:95.06ms +step:1212/1670 train_time:115214ms step_avg:95.06ms +step:1213/1670 train_time:115308ms step_avg:95.06ms +step:1214/1670 train_time:115404ms step_avg:95.06ms +step:1215/1670 train_time:115499ms step_avg:95.06ms +step:1216/1670 train_time:115593ms step_avg:95.06ms +step:1217/1670 train_time:115688ms step_avg:95.06ms +step:1218/1670 train_time:115783ms step_avg:95.06ms +step:1219/1670 train_time:115878ms step_avg:95.06ms +step:1220/1670 train_time:115973ms step_avg:95.06ms +step:1221/1670 train_time:116067ms step_avg:95.06ms +step:1222/1670 train_time:116161ms step_avg:95.06ms +step:1223/1670 train_time:116254ms step_avg:95.06ms +step:1224/1670 train_time:116349ms step_avg:95.06ms +step:1225/1670 train_time:116445ms step_avg:95.06ms +step:1226/1670 train_time:116539ms step_avg:95.06ms +step:1227/1670 train_time:116634ms step_avg:95.06ms +step:1228/1670 train_time:116729ms step_avg:95.06ms +step:1229/1670 train_time:116823ms step_avg:95.06ms +step:1230/1670 train_time:116917ms step_avg:95.05ms +step:1231/1670 train_time:117012ms step_avg:95.05ms +step:1232/1670 train_time:117107ms step_avg:95.05ms +step:1233/1670 train_time:117201ms step_avg:95.05ms +step:1234/1670 train_time:117295ms step_avg:95.05ms +step:1235/1670 train_time:117390ms step_avg:95.05ms +step:1236/1670 train_time:117485ms step_avg:95.05ms +step:1237/1670 train_time:117580ms step_avg:95.05ms +step:1238/1670 train_time:117676ms step_avg:95.05ms +step:1239/1670 train_time:117771ms step_avg:95.05ms +step:1240/1670 train_time:117866ms step_avg:95.05ms +step:1241/1670 train_time:117960ms step_avg:95.05ms +step:1242/1670 train_time:118055ms step_avg:95.05ms +step:1243/1670 train_time:118149ms step_avg:95.05ms +step:1244/1670 
train_time:118244ms step_avg:95.05ms +step:1245/1670 train_time:118339ms step_avg:95.05ms +step:1246/1670 train_time:118434ms step_avg:95.05ms +step:1247/1670 train_time:118528ms step_avg:95.05ms +step:1248/1670 train_time:118623ms step_avg:95.05ms +step:1249/1670 train_time:118717ms step_avg:95.05ms +step:1250/1670 train_time:118811ms step_avg:95.05ms +step:1250/1670 val_loss:3.3777 train_time:118904ms step_avg:95.12ms +step:1251/1670 train_time:118930ms step_avg:95.07ms +step:1252/1670 train_time:119007ms step_avg:95.05ms +step:1253/1670 train_time:119109ms step_avg:95.06ms +step:1254/1670 train_time:119205ms step_avg:95.06ms +step:1255/1670 train_time:119299ms step_avg:95.06ms +step:1256/1670 train_time:119392ms step_avg:95.06ms +step:1257/1670 train_time:119486ms step_avg:95.06ms +step:1258/1670 train_time:119579ms step_avg:95.05ms +step:1259/1670 train_time:119673ms step_avg:95.05ms +step:1260/1670 train_time:119766ms step_avg:95.05ms +step:1261/1670 train_time:119860ms step_avg:95.05ms +step:1262/1670 train_time:119957ms step_avg:95.05ms +step:1263/1670 train_time:120056ms step_avg:95.06ms +step:1264/1670 train_time:120153ms step_avg:95.06ms +step:1265/1670 train_time:120249ms step_avg:95.06ms +step:1266/1670 train_time:120346ms step_avg:95.06ms +step:1267/1670 train_time:120440ms step_avg:95.06ms +step:1268/1670 train_time:120533ms step_avg:95.06ms +step:1269/1670 train_time:120627ms step_avg:95.06ms +step:1270/1670 train_time:120720ms step_avg:95.06ms +step:1271/1670 train_time:120814ms step_avg:95.05ms +step:1272/1670 train_time:120909ms step_avg:95.05ms +step:1273/1670 train_time:121006ms step_avg:95.06ms +step:1274/1670 train_time:121451ms step_avg:95.33ms +step:1275/1670 train_time:121519ms step_avg:95.31ms +step:1276/1670 train_time:121612ms step_avg:95.31ms +step:1277/1670 train_time:121705ms step_avg:95.31ms +step:1278/1670 train_time:121798ms step_avg:95.30ms +step:1279/1670 train_time:121892ms step_avg:95.30ms +step:1280/1670 train_time:121986ms step_avg:95.30ms +step:1281/1670 train_time:122079ms step_avg:95.30ms +step:1282/1670 train_time:122172ms step_avg:95.30ms +step:1283/1670 train_time:122265ms step_avg:95.30ms +step:1284/1670 train_time:122361ms step_avg:95.30ms +step:1285/1670 train_time:122460ms step_avg:95.30ms +step:1286/1670 train_time:122555ms step_avg:95.30ms +step:1287/1670 train_time:122650ms step_avg:95.30ms +step:1288/1670 train_time:122744ms step_avg:95.30ms +step:1289/1670 train_time:122838ms step_avg:95.30ms +step:1290/1670 train_time:122933ms step_avg:95.30ms +step:1291/1670 train_time:123026ms step_avg:95.30ms +step:1292/1670 train_time:123120ms step_avg:95.29ms +step:1293/1670 train_time:123215ms step_avg:95.29ms +step:1294/1670 train_time:123308ms step_avg:95.29ms +step:1295/1670 train_time:123406ms step_avg:95.29ms +step:1296/1670 train_time:123503ms step_avg:95.30ms +step:1297/1670 train_time:123598ms step_avg:95.30ms +step:1298/1670 train_time:123692ms step_avg:95.29ms +step:1299/1670 train_time:123787ms step_avg:95.29ms +step:1300/1670 train_time:123881ms step_avg:95.29ms +step:1301/1670 train_time:123976ms step_avg:95.29ms +step:1302/1670 train_time:124069ms step_avg:95.29ms +step:1303/1670 train_time:124163ms step_avg:95.29ms +step:1304/1670 train_time:124259ms step_avg:95.29ms +step:1305/1670 train_time:124354ms step_avg:95.29ms +step:1306/1670 train_time:124449ms step_avg:95.29ms +step:1307/1670 train_time:124544ms step_avg:95.29ms +step:1308/1670 train_time:124640ms step_avg:95.29ms +step:1309/1670 train_time:124734ms step_avg:95.29ms 
+step:1310/1670 train_time:124829ms step_avg:95.29ms +step:1311/1670 train_time:124923ms step_avg:95.29ms +step:1312/1670 train_time:125018ms step_avg:95.29ms +step:1313/1670 train_time:125112ms step_avg:95.29ms +step:1314/1670 train_time:125206ms step_avg:95.29ms +step:1315/1670 train_time:125301ms step_avg:95.29ms +step:1316/1670 train_time:125396ms step_avg:95.29ms +step:1317/1670 train_time:125491ms step_avg:95.29ms +step:1318/1670 train_time:125586ms step_avg:95.29ms +step:1319/1670 train_time:125681ms step_avg:95.28ms +step:1320/1670 train_time:125776ms step_avg:95.28ms +step:1321/1670 train_time:125870ms step_avg:95.28ms +step:1322/1670 train_time:125963ms step_avg:95.28ms +step:1323/1670 train_time:126059ms step_avg:95.28ms +step:1324/1670 train_time:126154ms step_avg:95.28ms +step:1325/1670 train_time:126248ms step_avg:95.28ms +step:1326/1670 train_time:126343ms step_avg:95.28ms +step:1327/1670 train_time:126437ms step_avg:95.28ms +step:1328/1670 train_time:126532ms step_avg:95.28ms +step:1329/1670 train_time:126627ms step_avg:95.28ms +step:1330/1670 train_time:126722ms step_avg:95.28ms +step:1331/1670 train_time:126816ms step_avg:95.28ms +step:1332/1670 train_time:126910ms step_avg:95.28ms +step:1333/1670 train_time:127005ms step_avg:95.28ms +step:1334/1670 train_time:127100ms step_avg:95.28ms +step:1335/1670 train_time:127195ms step_avg:95.28ms +step:1336/1670 train_time:127290ms step_avg:95.28ms +step:1337/1670 train_time:127384ms step_avg:95.28ms +step:1338/1670 train_time:127479ms step_avg:95.28ms +step:1339/1670 train_time:127574ms step_avg:95.28ms +step:1340/1670 train_time:127669ms step_avg:95.28ms +step:1341/1670 train_time:127764ms step_avg:95.28ms +step:1342/1670 train_time:127858ms step_avg:95.27ms +step:1343/1670 train_time:127953ms step_avg:95.27ms +step:1344/1670 train_time:128047ms step_avg:95.27ms +step:1345/1670 train_time:128141ms step_avg:95.27ms +step:1346/1670 train_time:128237ms step_avg:95.27ms +step:1347/1670 train_time:128331ms step_avg:95.27ms +step:1348/1670 train_time:128425ms step_avg:95.27ms +step:1349/1670 train_time:128520ms step_avg:95.27ms +step:1350/1670 train_time:128615ms step_avg:95.27ms +step:1351/1670 train_time:128709ms step_avg:95.27ms +step:1352/1670 train_time:128805ms step_avg:95.27ms +step:1353/1670 train_time:128899ms step_avg:95.27ms +step:1354/1670 train_time:128993ms step_avg:95.27ms +step:1355/1670 train_time:129088ms step_avg:95.27ms +step:1356/1670 train_time:129182ms step_avg:95.27ms +step:1357/1670 train_time:129277ms step_avg:95.27ms +step:1358/1670 train_time:129371ms step_avg:95.27ms +step:1359/1670 train_time:129465ms step_avg:95.27ms +step:1360/1670 train_time:129561ms step_avg:95.27ms +step:1361/1670 train_time:129656ms step_avg:95.26ms +step:1362/1670 train_time:129751ms step_avg:95.26ms +step:1363/1670 train_time:129845ms step_avg:95.26ms +step:1364/1670 train_time:129939ms step_avg:95.26ms +step:1365/1670 train_time:130033ms step_avg:95.26ms +step:1366/1670 train_time:130128ms step_avg:95.26ms +step:1367/1670 train_time:130223ms step_avg:95.26ms +step:1368/1670 train_time:130317ms step_avg:95.26ms +step:1369/1670 train_time:130412ms step_avg:95.26ms +step:1370/1670 train_time:130506ms step_avg:95.26ms +step:1371/1670 train_time:130601ms step_avg:95.26ms +step:1372/1670 train_time:130695ms step_avg:95.26ms +step:1373/1670 train_time:130791ms step_avg:95.26ms +step:1374/1670 train_time:130886ms step_avg:95.26ms +step:1375/1670 train_time:130980ms step_avg:95.26ms +step:1375/1670 val_loss:3.3429 train_time:131074ms 
step_avg:95.33ms +step:1376/1670 train_time:131100ms step_avg:95.28ms +step:1377/1670 train_time:131179ms step_avg:95.26ms +step:1378/1670 train_time:131280ms step_avg:95.27ms +step:1379/1670 train_time:131375ms step_avg:95.27ms +step:1380/1670 train_time:131469ms step_avg:95.27ms +step:1381/1670 train_time:131563ms step_avg:95.27ms +step:1382/1670 train_time:131656ms step_avg:95.27ms +step:1383/1670 train_time:131750ms step_avg:95.26ms +step:1384/1670 train_time:131844ms step_avg:95.26ms +step:1385/1670 train_time:131938ms step_avg:95.26ms +step:1386/1670 train_time:132032ms step_avg:95.26ms +step:1387/1670 train_time:132127ms step_avg:95.26ms +step:1388/1670 train_time:132224ms step_avg:95.26ms +step:1389/1670 train_time:132319ms step_avg:95.26ms +step:1390/1670 train_time:132415ms step_avg:95.26ms +step:1391/1670 train_time:132509ms step_avg:95.26ms +step:1392/1670 train_time:132603ms step_avg:95.26ms +step:1393/1670 train_time:132697ms step_avg:95.26ms +step:1394/1670 train_time:132791ms step_avg:95.26ms +step:1395/1670 train_time:132884ms step_avg:95.26ms +step:1396/1670 train_time:132979ms step_avg:95.26ms +step:1397/1670 train_time:133073ms step_avg:95.26ms +step:1398/1670 train_time:133169ms step_avg:95.26ms +step:1399/1670 train_time:133266ms step_avg:95.26ms +step:1400/1670 train_time:133360ms step_avg:95.26ms +step:1401/1670 train_time:133455ms step_avg:95.26ms +step:1402/1670 train_time:133549ms step_avg:95.26ms +step:1403/1670 train_time:133644ms step_avg:95.26ms +step:1404/1670 train_time:133738ms step_avg:95.25ms +step:1405/1670 train_time:133833ms step_avg:95.26ms +step:1406/1670 train_time:133927ms step_avg:95.25ms +step:1407/1670 train_time:134021ms step_avg:95.25ms +step:1408/1670 train_time:134115ms step_avg:95.25ms +step:1409/1670 train_time:134211ms step_avg:95.25ms +step:1410/1670 train_time:134306ms step_avg:95.25ms +step:1411/1670 train_time:134401ms step_avg:95.25ms +step:1412/1670 train_time:134496ms step_avg:95.25ms +step:1413/1670 train_time:134591ms step_avg:95.25ms +step:1414/1670 train_time:134684ms step_avg:95.25ms +step:1415/1670 train_time:134779ms step_avg:95.25ms +step:1416/1670 train_time:134874ms step_avg:95.25ms +step:1417/1670 train_time:134969ms step_avg:95.25ms +step:1418/1670 train_time:135063ms step_avg:95.25ms +step:1419/1670 train_time:135158ms step_avg:95.25ms +step:1420/1670 train_time:135254ms step_avg:95.25ms +step:1421/1670 train_time:135349ms step_avg:95.25ms +step:1422/1670 train_time:135444ms step_avg:95.25ms +step:1423/1670 train_time:135538ms step_avg:95.25ms +step:1424/1670 train_time:135633ms step_avg:95.25ms +step:1425/1670 train_time:135727ms step_avg:95.25ms +step:1426/1670 train_time:135821ms step_avg:95.25ms +step:1427/1670 train_time:135915ms step_avg:95.25ms +step:1428/1670 train_time:136010ms step_avg:95.24ms +step:1429/1670 train_time:136105ms step_avg:95.25ms +step:1430/1670 train_time:136200ms step_avg:95.24ms +step:1431/1670 train_time:136295ms step_avg:95.24ms +step:1432/1670 train_time:136389ms step_avg:95.24ms +step:1433/1670 train_time:136484ms step_avg:95.24ms +step:1434/1670 train_time:136579ms step_avg:95.24ms +step:1435/1670 train_time:136673ms step_avg:95.24ms +step:1436/1670 train_time:136768ms step_avg:95.24ms +step:1437/1670 train_time:136863ms step_avg:95.24ms +step:1438/1670 train_time:136956ms step_avg:95.24ms +step:1439/1670 train_time:137052ms step_avg:95.24ms +step:1440/1670 train_time:137147ms step_avg:95.24ms +step:1441/1670 train_time:137243ms step_avg:95.24ms +step:1442/1670 train_time:137338ms 
step_avg:95.24ms +step:1443/1670 train_time:137432ms step_avg:95.24ms +step:1444/1670 train_time:137527ms step_avg:95.24ms +step:1445/1670 train_time:137622ms step_avg:95.24ms +step:1446/1670 train_time:137717ms step_avg:95.24ms +step:1447/1670 train_time:137812ms step_avg:95.24ms +step:1448/1670 train_time:137907ms step_avg:95.24ms +step:1449/1670 train_time:138001ms step_avg:95.24ms +step:1450/1670 train_time:138096ms step_avg:95.24ms +step:1451/1670 train_time:138191ms step_avg:95.24ms +step:1452/1670 train_time:138286ms step_avg:95.24ms +step:1453/1670 train_time:138381ms step_avg:95.24ms +step:1454/1670 train_time:138476ms step_avg:95.24ms +step:1455/1670 train_time:138572ms step_avg:95.24ms +step:1456/1670 train_time:138666ms step_avg:95.24ms +step:1457/1670 train_time:138760ms step_avg:95.24ms +step:1458/1670 train_time:138856ms step_avg:95.24ms +step:1459/1670 train_time:138950ms step_avg:95.24ms +step:1460/1670 train_time:139044ms step_avg:95.24ms +step:1461/1670 train_time:139139ms step_avg:95.24ms +step:1462/1670 train_time:139233ms step_avg:95.23ms +step:1463/1670 train_time:139329ms step_avg:95.24ms +step:1464/1670 train_time:139424ms step_avg:95.24ms +step:1465/1670 train_time:139519ms step_avg:95.23ms +step:1466/1670 train_time:139613ms step_avg:95.23ms +step:1467/1670 train_time:139708ms step_avg:95.23ms +step:1468/1670 train_time:139803ms step_avg:95.23ms +step:1469/1670 train_time:139898ms step_avg:95.23ms +step:1470/1670 train_time:139992ms step_avg:95.23ms +step:1471/1670 train_time:140086ms step_avg:95.23ms +step:1472/1670 train_time:140180ms step_avg:95.23ms +step:1473/1670 train_time:140276ms step_avg:95.23ms +step:1474/1670 train_time:140370ms step_avg:95.23ms +step:1475/1670 train_time:140465ms step_avg:95.23ms +step:1476/1670 train_time:140561ms step_avg:95.23ms +step:1477/1670 train_time:140655ms step_avg:95.23ms +step:1478/1670 train_time:140750ms step_avg:95.23ms +step:1479/1670 train_time:140846ms step_avg:95.23ms +step:1480/1670 train_time:140940ms step_avg:95.23ms +step:1481/1670 train_time:141035ms step_avg:95.23ms +step:1482/1670 train_time:141130ms step_avg:95.23ms +step:1483/1670 train_time:141224ms step_avg:95.23ms +step:1484/1670 train_time:141319ms step_avg:95.23ms +step:1485/1670 train_time:141763ms step_avg:95.46ms +step:1486/1670 train_time:141831ms step_avg:95.44ms +step:1487/1670 train_time:141923ms step_avg:95.44ms +step:1488/1670 train_time:142017ms step_avg:95.44ms +step:1489/1670 train_time:142110ms step_avg:95.44ms +step:1490/1670 train_time:142204ms step_avg:95.44ms +step:1491/1670 train_time:142297ms step_avg:95.44ms +step:1492/1670 train_time:142391ms step_avg:95.44ms +step:1493/1670 train_time:142484ms step_avg:95.43ms +step:1494/1670 train_time:142578ms step_avg:95.43ms +step:1495/1670 train_time:142672ms step_avg:95.43ms +step:1496/1670 train_time:142774ms step_avg:95.44ms +step:1497/1670 train_time:142872ms step_avg:95.44ms +step:1498/1670 train_time:142966ms step_avg:95.44ms +step:1499/1670 train_time:143060ms step_avg:95.44ms +step:1500/1670 train_time:143154ms step_avg:95.44ms +step:1500/1670 val_loss:3.3130 train_time:143247ms step_avg:95.50ms +step:1501/1670 train_time:143273ms step_avg:95.45ms +step:1502/1670 train_time:143349ms step_avg:95.44ms +step:1503/1670 train_time:143453ms step_avg:95.44ms +step:1504/1670 train_time:143549ms step_avg:95.44ms +step:1505/1670 train_time:143642ms step_avg:95.44ms +step:1506/1670 train_time:143736ms step_avg:95.44ms +step:1507/1670 train_time:143829ms step_avg:95.44ms +step:1508/1670 
train_time:143923ms step_avg:95.44ms +step:1509/1670 train_time:144016ms step_avg:95.44ms +step:1510/1670 train_time:144109ms step_avg:95.44ms +step:1511/1670 train_time:144203ms step_avg:95.44ms +step:1512/1670 train_time:144300ms step_avg:95.44ms +step:1513/1670 train_time:144398ms step_avg:95.44ms +step:1514/1670 train_time:144496ms step_avg:95.44ms +step:1515/1670 train_time:144592ms step_avg:95.44ms +step:1516/1670 train_time:144686ms step_avg:95.44ms +step:1517/1670 train_time:144779ms step_avg:95.44ms +step:1518/1670 train_time:144873ms step_avg:95.44ms +step:1519/1670 train_time:144966ms step_avg:95.43ms +step:1520/1670 train_time:145059ms step_avg:95.43ms +step:1521/1670 train_time:145153ms step_avg:95.43ms +step:1522/1670 train_time:145248ms step_avg:95.43ms +step:1523/1670 train_time:145345ms step_avg:95.43ms +step:1524/1670 train_time:145441ms step_avg:95.43ms +step:1525/1670 train_time:145536ms step_avg:95.43ms +step:1526/1670 train_time:145631ms step_avg:95.43ms +step:1527/1670 train_time:145726ms step_avg:95.43ms +step:1528/1670 train_time:145820ms step_avg:95.43ms +step:1529/1670 train_time:145914ms step_avg:95.43ms +step:1530/1670 train_time:146008ms step_avg:95.43ms +step:1531/1670 train_time:146101ms step_avg:95.43ms +step:1532/1670 train_time:146195ms step_avg:95.43ms +step:1533/1670 train_time:146292ms step_avg:95.43ms +step:1534/1670 train_time:146387ms step_avg:95.43ms +step:1535/1670 train_time:146483ms step_avg:95.43ms +step:1536/1670 train_time:146577ms step_avg:95.43ms +step:1537/1670 train_time:146673ms step_avg:95.43ms +step:1538/1670 train_time:146767ms step_avg:95.43ms +step:1539/1670 train_time:146861ms step_avg:95.43ms +step:1540/1670 train_time:146955ms step_avg:95.43ms +step:1541/1670 train_time:147050ms step_avg:95.42ms +step:1542/1670 train_time:147144ms step_avg:95.42ms +step:1543/1670 train_time:147238ms step_avg:95.42ms +step:1544/1670 train_time:147334ms step_avg:95.42ms +step:1545/1670 train_time:147428ms step_avg:95.42ms +step:1546/1670 train_time:147523ms step_avg:95.42ms +step:1547/1670 train_time:147618ms step_avg:95.42ms +step:1548/1670 train_time:147714ms step_avg:95.42ms +step:1549/1670 train_time:147808ms step_avg:95.42ms +step:1550/1670 train_time:147902ms step_avg:95.42ms +step:1551/1670 train_time:147996ms step_avg:95.42ms +step:1552/1670 train_time:148090ms step_avg:95.42ms +step:1553/1670 train_time:148184ms step_avg:95.42ms +step:1554/1670 train_time:148279ms step_avg:95.42ms +step:1555/1670 train_time:148375ms step_avg:95.42ms +step:1556/1670 train_time:148470ms step_avg:95.42ms +step:1557/1670 train_time:148564ms step_avg:95.42ms +step:1558/1670 train_time:148659ms step_avg:95.42ms +step:1559/1670 train_time:148754ms step_avg:95.42ms +step:1560/1670 train_time:148849ms step_avg:95.42ms +step:1561/1670 train_time:148944ms step_avg:95.42ms +step:1562/1670 train_time:149038ms step_avg:95.41ms +step:1563/1670 train_time:149132ms step_avg:95.41ms +step:1564/1670 train_time:149228ms step_avg:95.41ms +step:1565/1670 train_time:149322ms step_avg:95.41ms +step:1566/1670 train_time:149418ms step_avg:95.41ms +step:1567/1670 train_time:149513ms step_avg:95.41ms +step:1568/1670 train_time:149608ms step_avg:95.41ms +step:1569/1670 train_time:149703ms step_avg:95.41ms +step:1570/1670 train_time:149798ms step_avg:95.41ms +step:1571/1670 train_time:149892ms step_avg:95.41ms +step:1572/1670 train_time:149986ms step_avg:95.41ms +step:1573/1670 train_time:150081ms step_avg:95.41ms +step:1574/1670 train_time:150176ms step_avg:95.41ms +step:1575/1670 
train_time:150269ms step_avg:95.41ms +step:1576/1670 train_time:150364ms step_avg:95.41ms +step:1577/1670 train_time:150458ms step_avg:95.41ms +step:1578/1670 train_time:150553ms step_avg:95.41ms +step:1579/1670 train_time:150648ms step_avg:95.41ms +step:1580/1670 train_time:150743ms step_avg:95.41ms +step:1581/1670 train_time:150838ms step_avg:95.41ms +step:1582/1670 train_time:150932ms step_avg:95.41ms +step:1583/1670 train_time:151027ms step_avg:95.41ms +step:1584/1670 train_time:151121ms step_avg:95.40ms +step:1585/1670 train_time:151216ms step_avg:95.40ms +step:1586/1670 train_time:151310ms step_avg:95.40ms +step:1587/1670 train_time:151405ms step_avg:95.40ms +step:1588/1670 train_time:151500ms step_avg:95.40ms +step:1589/1670 train_time:151595ms step_avg:95.40ms +step:1590/1670 train_time:151690ms step_avg:95.40ms +step:1591/1670 train_time:151784ms step_avg:95.40ms +step:1592/1670 train_time:151878ms step_avg:95.40ms +step:1593/1670 train_time:151973ms step_avg:95.40ms +step:1594/1670 train_time:152067ms step_avg:95.40ms +step:1595/1670 train_time:152163ms step_avg:95.40ms +step:1596/1670 train_time:152257ms step_avg:95.40ms +step:1597/1670 train_time:152351ms step_avg:95.40ms +step:1598/1670 train_time:152446ms step_avg:95.40ms +step:1599/1670 train_time:152540ms step_avg:95.40ms +step:1600/1670 train_time:152637ms step_avg:95.40ms +step:1601/1670 train_time:152731ms step_avg:95.40ms +step:1602/1670 train_time:152826ms step_avg:95.40ms +step:1603/1670 train_time:152921ms step_avg:95.40ms +step:1604/1670 train_time:153015ms step_avg:95.40ms +step:1605/1670 train_time:153109ms step_avg:95.40ms +step:1606/1670 train_time:153204ms step_avg:95.40ms +step:1607/1670 train_time:153299ms step_avg:95.39ms +step:1608/1670 train_time:153394ms step_avg:95.39ms +step:1609/1670 train_time:153489ms step_avg:95.39ms +step:1610/1670 train_time:153584ms step_avg:95.39ms +step:1611/1670 train_time:153678ms step_avg:95.39ms +step:1612/1670 train_time:153771ms step_avg:95.39ms +step:1613/1670 train_time:153866ms step_avg:95.39ms +step:1614/1670 train_time:153961ms step_avg:95.39ms +step:1615/1670 train_time:154056ms step_avg:95.39ms +step:1616/1670 train_time:154151ms step_avg:95.39ms +step:1617/1670 train_time:154246ms step_avg:95.39ms +step:1618/1670 train_time:154341ms step_avg:95.39ms +step:1619/1670 train_time:154437ms step_avg:95.39ms +step:1620/1670 train_time:154531ms step_avg:95.39ms +step:1621/1670 train_time:154626ms step_avg:95.39ms +step:1622/1670 train_time:154720ms step_avg:95.39ms +step:1623/1670 train_time:154815ms step_avg:95.39ms +step:1624/1670 train_time:154910ms step_avg:95.39ms +step:1625/1670 train_time:155005ms step_avg:95.39ms +step:1625/1670 val_loss:3.2878 train_time:155098ms step_avg:95.44ms +step:1626/1670 train_time:155125ms step_avg:95.40ms +step:1627/1670 train_time:155203ms step_avg:95.39ms +step:1628/1670 train_time:155304ms step_avg:95.40ms +step:1629/1670 train_time:155400ms step_avg:95.40ms +step:1630/1670 train_time:155494ms step_avg:95.40ms +step:1631/1670 train_time:155588ms step_avg:95.39ms +step:1632/1670 train_time:155681ms step_avg:95.39ms +step:1633/1670 train_time:155775ms step_avg:95.39ms +step:1634/1670 train_time:155869ms step_avg:95.39ms +step:1635/1670 train_time:155962ms step_avg:95.39ms +step:1636/1670 train_time:156057ms step_avg:95.39ms +step:1637/1670 train_time:156154ms step_avg:95.39ms +step:1638/1670 train_time:156252ms step_avg:95.39ms +step:1639/1670 train_time:156348ms step_avg:95.39ms +step:1640/1670 train_time:156444ms step_avg:95.39ms 
+step:1641/1670 train_time:156539ms step_avg:95.39ms
+step:1642/1670 train_time:156633ms step_avg:95.39ms
+step:1643/1670 train_time:156727ms step_avg:95.39ms
+step:1644/1670 train_time:156820ms step_avg:95.39ms
+step:1645/1670 train_time:156914ms step_avg:95.39ms
+step:1646/1670 train_time:157009ms step_avg:95.39ms
+step:1647/1670 train_time:157104ms step_avg:95.39ms
+step:1648/1670 train_time:157201ms step_avg:95.39ms
+step:1649/1670 train_time:157298ms step_avg:95.39ms
+step:1650/1670 train_time:157393ms step_avg:95.39ms
+step:1651/1670 train_time:157487ms step_avg:95.39ms
+step:1652/1670 train_time:157583ms step_avg:95.39ms
+step:1653/1670 train_time:157677ms step_avg:95.39ms
+step:1654/1670 train_time:157771ms step_avg:95.39ms
+step:1655/1670 train_time:157865ms step_avg:95.39ms
+step:1656/1670 train_time:157959ms step_avg:95.39ms
+step:1657/1670 train_time:158054ms step_avg:95.39ms
+step:1658/1670 train_time:158148ms step_avg:95.38ms
+step:1659/1670 train_time:158243ms step_avg:95.38ms
+step:1660/1670 train_time:158338ms step_avg:95.38ms
+step:1661/1670 train_time:158434ms step_avg:95.38ms
+step:1662/1670 train_time:158529ms step_avg:95.38ms
+step:1663/1670 train_time:158625ms step_avg:95.39ms
+step:1664/1670 train_time:158721ms step_avg:95.39ms
+step:1665/1670 train_time:158816ms step_avg:95.38ms
+step:1666/1670 train_time:158909ms step_avg:95.38ms
+step:1667/1670 train_time:159004ms step_avg:95.38ms
+step:1668/1670 train_time:159099ms step_avg:95.38ms
+step:1669/1670 train_time:159194ms step_avg:95.38ms
+step:1670/1670 train_time:159289ms step_avg:95.38ms
+step:1670/1670 val_loss:3.2789 train_time:159467ms step_avg:95.49ms
+peak memory allocated: 32712 MiB reserved: 46816 MiB
diff --git a/records/091025_Yarn/ReadMe.md b/records/091025_Yarn/ReadMe.md
new file mode 100644
index 000000000..60db04692
--- /dev/null
+++ b/records/091025_Yarn/ReadMe.md
@@ -0,0 +1,94 @@
+This PR (159.3s) incorporates YaRN into the training window schedule and the final validation. https://arxiv.org/pdf/2309.00071
+This submission includes all recent WR improvements, including dropping the initial MLP layer by @EmelyanenkoK in [PR 120](https://github.com/KellerJordan/modded-nanogpt/pull/120).
+
+Longer attention windows take longer to train, but produce models with lower loss. Two phenomena occur in RoPE when the attention window is increased during or after training:
+1. Dimensions with low-frequency rotations experience unfamiliar rotation angles. For instance, a dimension that rotates 0.1 degrees per position will have experienced at most 0.1*384=38.4 degrees of rotation during training on ws 384. When the sliding window is expanded to 896, it experiences up to 89.6 degrees of rotation. This out-of-distribution data causes a temporary loss spike.
+2. In particular, when the K and Q vectors are normed, the perplexity of the attention mechanism increases as the number of keys increases. Applying a scaling factor d inside softmax(d*QK) allows this perplexity to be controlled as the number of keys in the attention window grows (see the sketch below).
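+
+As a rough illustration of the second phenomenon, here is a hypothetical standalone sketch (not code from this PR; the RMS-normed vectors, head_dim of 128, the window sizes 384 and 896, and the scale values 0.10 and 0.14 are illustrative assumptions mirroring the numbers above). Softmax perplexity rises with the number of keys, and a larger scale d pulls it back down:
+```
+# hypothetical sketch: softmax perplexity vs. number of keys and scale d
+import torch
+
+torch.manual_seed(0)
+head_dim = 128
+
+def rms_normed(*shape):
+    x = torch.randn(*shape)
+    return x * head_dim**0.5 / x.norm(dim=-1, keepdim=True)
+
+q = rms_normed(head_dim)
+for n_keys in (384, 896):
+    k = rms_normed(n_keys, head_dim)
+    for d in (0.10, 0.14):
+        probs = torch.softmax(d * (k @ q), dim=0)
+        perplexity = torch.exp(-(probs * probs.log()).sum()).item()
+        print(f"keys={n_keys} d={d:.2f} perplexity={perplexity:.0f}")
+```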
+
+A single copy of the rotary embeddings is stored in the model root to reduce update time, reduce memory footprint, and potentially improve cache performance.
+```
+# store single copy of rotary tensors
+angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32)
+# half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
+angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)])
+t = torch.arange(self.max_seq_len, dtype=torch.float32)
+theta = torch.outer(t, angular_freq)
+self.rotary_cos = nn.Buffer(theta.cos(), persistent=False)
+self.rotary_sin = nn.Buffer(theta.sin(), persistent=False)
+```
+
+Based on empirical testing, the 0.1 constant in YaRN's 0.1*log(curr/prev)+1 formula is updated to 0.2.
+The constant attn_scale of 0.12 is replaced by a starting value of 0.1, such that the distribution over training has a similar mean, ranging between 0.1 and 0.14.
+
+
+```
+# scale attention factor f in attn=softmax(f*qk) logarithmically with window size
+windows = list(dict.fromkeys(args.ws_schedule + [args.ws_validate]))
+scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])]
+# start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor))
+self.attn_scales = dict(zip(windows, attn_scales))
+```
+
+YaRN has a straightforward implementation, shown below. alpha and beta are left at the default constants of 1 and 32 from the original YaRN paper, which was tuned for Llama.
+```
+def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32):
+    rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi)
+    scaling_factor = old_window / new_window
+    interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
+    self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
+    t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
+    theta = torch.outer(t, self.angular_freq)
+    self.rotary_cos.copy_(theta.cos())
+    self.rotary_sin.copy_(theta.sin())
+```
+The frequency update incurred by YaRN is most notable for the ws 3->7 transition and for dimensions 5 to 10.
+
+
+The ws_validate arg enables the model to be validated at a longer attention window than it was trained on. This arg is set to 13, which differs from the final training window size of 11.
+
+```
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
+```
+
+Attention args are batched to improve readability. cooldown_frac is increased from 0.45 to 0.5 to complement the reduction from 1705 to 1670 steps, following the heuristic of a fixed number of cooldown steps. Dropping below 1695 steps has a secondary benefit of eliminating the 9th file read, saving roughly 200ms.
+
+Without YaRN, there is a substantial spike in validation loss when the attention window is abruptly increased from 3 to 7.
+
+Extending the final validation window shows roughly a 0.0015 loss improvement for 11->13. Interestingly, odd increments perform substantially better. @varunneal has noted that "One thing to note is that floor division (ws_short = ws_long // 2) has different behavior for odd vs [even] window sizes. I generally found odd window sizes performed surprisingly better." The attention schedule follows (long/short) (3/1) -> (7/3) -> (11/5). It may be that the short attention window performs better when it is under 50% of the long window, or it may be that the model learns to fit the long/short ratio and performs poorly when this ratio is substantially altered, or there may be a completely different explanation.
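+
+To make the schedule concrete, here is a hypothetical standalone sketch (not code from this PR) that computes the first step at which each window is used, assuming ws_schedule = [3, 7, 11], ws_validate = 13, and 1670 iterations as described in this note:
+```
+# hypothetical sketch: first step at which each attention window is used
+ws_schedule, ws_validate, num_iterations = [3, 7, 11], 13, 1670
+
+def get_ws(step: int):
+    if step == num_iterations:
+        return ws_validate
+    x = step / (1 + num_iterations)
+    return ws_schedule[int(len(ws_schedule) * x)]
+
+first_step = {ws: min(s for s in range(num_iterations + 1) if get_ws(s) == ws) for ws in (3, 7, 11, 13)}
+print(first_step)  # {3: 0, 7: 557, 11: 1114, 13: 1670}
+```
+Under these assumptions each training window gets roughly a third of the run, which also locates the 3->7 transition (and its loss-spike risk, noted above) at around step 557.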
+
+Ablations were run to measure the impact of each change:
+* new_record
+* no_attn_scale. Keep constant attn scale of 0.12.
+* no_freq_scale. Keep constant rotary freq based on 1024^(0..1).
+* prior_record. Updated steps from 1705 to 1670.
+
+
+Future Considerations:
+* Right now model training is like a racecar with no brakes. There may be a way to effectively dampen the optimizer state momentum terms when the model updates its attention window size and 'changes direction'. Preliminary testing here on only the Muon params gave negative results.
+* There may be a way to distribute the load of finding bos token indices for all 8 files. If each GPU is given 1 file instead of 8 to locate the bos_tokens, this could save up to roughly 200ms*7 = 1.4 seconds assuming zero overhead.
+* Starting RoPE at a max angular frequency of 1 radian per position, or 57 degrees, seems arbitrary. However, increasing this to 180 degrees did not show an improvement in performance.
+* Plotting validation loss every 125 iterations masks critical issues like loss spikes on attn window updates. In general, more granular monitoring seems useful.
+
+Validation:
+```
+import scipy.stats
+import torch
+accs = [3.2779, 3.2779, 3.2789, 3.2778, 3.2789, 3.2785, 3.2806]
+times = [159.447, 158.998, 159.467, 159.191, 159.503, 159.259, 159.468]
+
+print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
+# p=0.0053
+
+print('acc:',torch.std_mean(torch.tensor(accs)))
+# acc: (tensor(0.0010), tensor(3.2786))
+
+print('time:',torch.std_mean(torch.tensor(times)))
+# time: (tensor(0.1897), tensor(159.3333))
+```
\ No newline at end of file
diff --git a/records/091025_Yarn/ef66c943-e262-400f-822f-068d397a1dc9.txt b/records/091025_Yarn/ef66c943-e262-400f-822f-068d397a1dc9.txt
new file mode 100644
index 000000000..ff868602b
--- /dev/null
+++ b/records/091025_Yarn/ef66c943-e262-400f-822f-068d397a1dc9.txt
@@ -0,0 +1,2863 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+import math
+
+from dataclasses import dataclass
+from functools import lru_cache
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import numpy as np
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+import torch._dynamo as dynamo
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = 
x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + 
key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # 
This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def 
newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
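+        # Work is sharded round-robin over ranks: rank r updates params[base_i + r].
+        # Gradients are averaged across ranks with async reduce_scatter calls; each rank
+        # then applies weight decay, the momentum blend, and Newton-Schulz
+        # orthogonalization to its own parameters (with lr scaled by
+        # sqrt(max(1, rows/cols)) so tall matrices take proportionally larger steps),
+        # and the updated weights are re-replicated to all ranks via async all_gather.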
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, 
sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + 
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws]
+            )
+            if i >= n:
+                x = x + skip_weights[i - n] * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # report the scan index: `cur` is unbound on the first pass through this loop
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
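+# Illustrative usage sketch (hypothetical; not used anywhere in this script): how
+# BOSFinder's output is consumed by distributed_data_generator below. Each rank
+# concatenates its [start, end) document slices into one flat buffer of
+# num_tokens_local + 1 tokens (one extra so inputs and targets can be offset by one
+# position), and the cumulative document lengths become the cu_seqlens offsets that
+# flash_attn_varlen_func uses to keep attention from crossing document boundaries.
+def _bos_finder_example():
+    tokens = torch.tensor([BOS_ID, 5, 6, BOS_ID, 7, BOS_ID, 8, 9, 10])
+    finder = BOSFinder(tokens, world_size=1)
+    starts, ends = finder.next_batch(num_tokens_local=5, max_seq_len=4)
+    # starts[0] == [0, 3, 5] and ends[0] == [3, 5, 6]: three documents, each starting
+    # at a BOS token, totalling num_tokens_local + 1 = 6 tokens for this rank
+    buf = torch.cat([tokens[i:j] for i, j in zip(starts[0], ends[0])])
+    inputs, targets = buf[:-1], buf[1:]  # 5 inputs and 5 next-token targets
+    return inputs, targets
+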
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"yarn/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws=new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 
1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 04:05:30 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | Off | +| N/A 41C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | Off | +| N/A 43C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | Off | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | Off | +| N/A 36C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | Off | +| N/A 43C P0 128W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | Off | +| N/A 41C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | Off | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 64827 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 64828 C /usr/bin/python3 614MiB | +| 0 N/A N/A 64829 C /usr/bin/python3 614MiB | +| 0 N/A N/A 64830 C /usr/bin/python3 614MiB | +| 0 N/A N/A 64831 C /usr/bin/python3 614MiB | +| 0 N/A N/A 64832 C /usr/bin/python3 614MiB | +| 0 N/A N/A 64833 C /usr/bin/python3 614MiB | +| 0 N/A N/A 64834 C /usr/bin/python3 614MiB | +| 1 N/A N/A 64828 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 64829 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 64830 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 64831 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 64832 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 64833 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 64834 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:435ms step_avg:434.72ms +step:2/1670 train_time:459ms step_avg:229.31ms +step:3/1670 train_time:527ms step_avg:175.56ms +step:4/1670 train_time:617ms step_avg:154.28ms +step:5/1670 train_time:709ms step_avg:141.72ms +step:6/1670 train_time:800ms step_avg:133.38ms +step:7/1670 train_time:892ms step_avg:127.40ms 
+step:8/1670 train_time:984ms step_avg:123.03ms +step:9/1670 train_time:1076ms step_avg:119.60ms +step:10/1670 train_time:1168ms step_avg:116.80ms +step:11/1670 train_time:1260ms step_avg:114.50ms +step:12/1670 train_time:1353ms step_avg:112.79ms +step:13/1670 train_time:1451ms step_avg:111.61ms +step:14/1670 train_time:1544ms step_avg:110.28ms +step:15/1670 train_time:1637ms step_avg:109.10ms +step:16/1670 train_time:1729ms step_avg:108.08ms +step:17/1670 train_time:1821ms step_avg:107.13ms +step:18/1670 train_time:1914ms step_avg:106.31ms +step:19/1670 train_time:2006ms step_avg:105.58ms +step:20/1670 train_time:2098ms step_avg:104.89ms +step:21/1670 train_time:2190ms step_avg:104.28ms +step:22/1670 train_time:2282ms step_avg:103.74ms +step:23/1670 train_time:2376ms step_avg:103.31ms +step:24/1670 train_time:2471ms step_avg:102.95ms +step:25/1670 train_time:2564ms step_avg:102.57ms +step:26/1670 train_time:2657ms step_avg:102.21ms +step:27/1670 train_time:2750ms step_avg:101.86ms +step:28/1670 train_time:2843ms step_avg:101.54ms +step:29/1670 train_time:2937ms step_avg:101.27ms +step:30/1670 train_time:3029ms step_avg:100.97ms +step:31/1670 train_time:3121ms step_avg:100.69ms +step:32/1670 train_time:3213ms step_avg:100.42ms +step:33/1670 train_time:3306ms step_avg:100.18ms +step:34/1670 train_time:3400ms step_avg:100.01ms +step:35/1670 train_time:3494ms step_avg:99.82ms +step:36/1670 train_time:3586ms step_avg:99.62ms +step:37/1670 train_time:3681ms step_avg:99.48ms +step:38/1670 train_time:3773ms step_avg:99.30ms +step:39/1670 train_time:3865ms step_avg:99.11ms +step:40/1670 train_time:3959ms step_avg:98.97ms +step:41/1670 train_time:4051ms step_avg:98.81ms +step:42/1670 train_time:4144ms step_avg:98.67ms +step:43/1670 train_time:4237ms step_avg:98.53ms +step:44/1670 train_time:4329ms step_avg:98.39ms +step:45/1670 train_time:4422ms step_avg:98.26ms +step:46/1670 train_time:4515ms step_avg:98.15ms +step:47/1670 train_time:4608ms step_avg:98.05ms +step:48/1670 train_time:4702ms step_avg:97.95ms +step:49/1670 train_time:4794ms step_avg:97.84ms +step:50/1670 train_time:4887ms step_avg:97.73ms +step:51/1670 train_time:4980ms step_avg:97.65ms +step:52/1670 train_time:5073ms step_avg:97.56ms +step:53/1670 train_time:5165ms step_avg:97.46ms +step:54/1670 train_time:5258ms step_avg:97.38ms +step:55/1670 train_time:5350ms step_avg:97.27ms +step:56/1670 train_time:5443ms step_avg:97.19ms +step:57/1670 train_time:5535ms step_avg:97.10ms +step:58/1670 train_time:5628ms step_avg:97.03ms +step:59/1670 train_time:5721ms step_avg:96.96ms +step:60/1670 train_time:5814ms step_avg:96.89ms +step:61/1670 train_time:5907ms step_avg:96.84ms +step:62/1670 train_time:6001ms step_avg:96.79ms +step:63/1670 train_time:6094ms step_avg:96.73ms +step:64/1670 train_time:6186ms step_avg:96.65ms +step:65/1670 train_time:6278ms step_avg:96.58ms +step:66/1670 train_time:6371ms step_avg:96.53ms +step:67/1670 train_time:6463ms step_avg:96.47ms +step:68/1670 train_time:6556ms step_avg:96.41ms +step:69/1670 train_time:6649ms step_avg:96.36ms +step:70/1670 train_time:6741ms step_avg:96.30ms +step:71/1670 train_time:6834ms step_avg:96.25ms +step:72/1670 train_time:6928ms step_avg:96.22ms +step:73/1670 train_time:7021ms step_avg:96.18ms +step:74/1670 train_time:7114ms step_avg:96.13ms +step:75/1670 train_time:7207ms step_avg:96.09ms +step:76/1670 train_time:7300ms step_avg:96.05ms +step:77/1670 train_time:7392ms step_avg:96.00ms +step:78/1670 train_time:7485ms step_avg:95.97ms +step:79/1670 train_time:7578ms step_avg:95.93ms 
+step:80/1670 train_time:7671ms step_avg:95.89ms +step:81/1670 train_time:7763ms step_avg:95.84ms +step:82/1670 train_time:7856ms step_avg:95.80ms +step:83/1670 train_time:7948ms step_avg:95.76ms +step:84/1670 train_time:8041ms step_avg:95.72ms +step:85/1670 train_time:8134ms step_avg:95.70ms +step:86/1670 train_time:8226ms step_avg:95.65ms +step:87/1670 train_time:8319ms step_avg:95.62ms +step:88/1670 train_time:8411ms step_avg:95.59ms +step:89/1670 train_time:8503ms step_avg:95.54ms +step:90/1670 train_time:8596ms step_avg:95.51ms +step:91/1670 train_time:8688ms step_avg:95.47ms +step:92/1670 train_time:8781ms step_avg:95.45ms +step:93/1670 train_time:8874ms step_avg:95.42ms +step:94/1670 train_time:8967ms step_avg:95.39ms +step:95/1670 train_time:9059ms step_avg:95.36ms +step:96/1670 train_time:9151ms step_avg:95.33ms +step:97/1670 train_time:9244ms step_avg:95.30ms +step:98/1670 train_time:9336ms step_avg:95.27ms +step:99/1670 train_time:9429ms step_avg:95.24ms +step:100/1670 train_time:9521ms step_avg:95.21ms +step:101/1670 train_time:9613ms step_avg:95.18ms +step:102/1670 train_time:9706ms step_avg:95.15ms +step:103/1670 train_time:9798ms step_avg:95.13ms +step:104/1670 train_time:9891ms step_avg:95.11ms +step:105/1670 train_time:9983ms step_avg:95.07ms +step:106/1670 train_time:10076ms step_avg:95.05ms +step:107/1670 train_time:10169ms step_avg:95.03ms +step:108/1670 train_time:10260ms step_avg:95.00ms +step:109/1670 train_time:10353ms step_avg:94.98ms +step:110/1670 train_time:10446ms step_avg:94.96ms +step:111/1670 train_time:10538ms step_avg:94.93ms +step:112/1670 train_time:10630ms step_avg:94.91ms +step:113/1670 train_time:10722ms step_avg:94.89ms +step:114/1670 train_time:10815ms step_avg:94.86ms +step:115/1670 train_time:10908ms step_avg:94.85ms +step:116/1670 train_time:11000ms step_avg:94.83ms +step:117/1670 train_time:11092ms step_avg:94.80ms +step:118/1670 train_time:11185ms step_avg:94.79ms +step:119/1670 train_time:11277ms step_avg:94.76ms +step:120/1670 train_time:11370ms step_avg:94.75ms +step:121/1670 train_time:11462ms step_avg:94.72ms +step:122/1670 train_time:11554ms step_avg:94.71ms +step:123/1670 train_time:11647ms step_avg:94.69ms +step:124/1670 train_time:11740ms step_avg:94.68ms +step:125/1670 train_time:11833ms step_avg:94.66ms +step:125/1670 val_loss:4.3038 train_time:11924ms step_avg:95.39ms +step:126/1670 train_time:11949ms step_avg:94.83ms +step:127/1670 train_time:12021ms step_avg:94.66ms +step:128/1670 train_time:12124ms step_avg:94.72ms +step:129/1670 train_time:12222ms step_avg:94.74ms +step:130/1670 train_time:12314ms step_avg:94.73ms +step:131/1670 train_time:12407ms step_avg:94.71ms +step:132/1670 train_time:12499ms step_avg:94.69ms +step:133/1670 train_time:12590ms step_avg:94.67ms +step:134/1670 train_time:12682ms step_avg:94.64ms +step:135/1670 train_time:12773ms step_avg:94.62ms +step:136/1670 train_time:12865ms step_avg:94.59ms +step:137/1670 train_time:12957ms step_avg:94.58ms +step:138/1670 train_time:13050ms step_avg:94.57ms +step:139/1670 train_time:13144ms step_avg:94.56ms +step:140/1670 train_time:13237ms step_avg:94.55ms +step:141/1670 train_time:13330ms step_avg:94.54ms +step:142/1670 train_time:13423ms step_avg:94.53ms +step:143/1670 train_time:13515ms step_avg:94.51ms +step:144/1670 train_time:13606ms step_avg:94.49ms +step:145/1670 train_time:13698ms step_avg:94.47ms +step:146/1670 train_time:13790ms step_avg:94.45ms +step:147/1670 train_time:13882ms step_avg:94.43ms +step:148/1670 train_time:13974ms step_avg:94.42ms +step:149/1670 
train_time:14066ms step_avg:94.40ms +step:150/1670 train_time:14160ms step_avg:94.40ms +step:151/1670 train_time:14254ms step_avg:94.39ms +step:152/1670 train_time:14346ms step_avg:94.38ms +step:153/1670 train_time:14438ms step_avg:94.37ms +step:154/1670 train_time:14530ms step_avg:94.35ms +step:155/1670 train_time:14623ms step_avg:94.34ms +step:156/1670 train_time:14715ms step_avg:94.33ms +step:157/1670 train_time:14807ms step_avg:94.31ms +step:158/1670 train_time:14899ms step_avg:94.30ms +step:159/1670 train_time:14991ms step_avg:94.29ms +step:160/1670 train_time:15084ms step_avg:94.28ms +step:161/1670 train_time:15177ms step_avg:94.27ms +step:162/1670 train_time:15270ms step_avg:94.26ms +step:163/1670 train_time:15363ms step_avg:94.25ms +step:164/1670 train_time:15457ms step_avg:94.25ms +step:165/1670 train_time:15549ms step_avg:94.23ms +step:166/1670 train_time:15641ms step_avg:94.22ms +step:167/1670 train_time:15734ms step_avg:94.21ms +step:168/1670 train_time:15826ms step_avg:94.21ms +step:169/1670 train_time:15918ms step_avg:94.19ms +step:170/1670 train_time:16010ms step_avg:94.18ms +step:171/1670 train_time:16103ms step_avg:94.17ms +step:172/1670 train_time:16195ms step_avg:94.16ms +step:173/1670 train_time:16288ms step_avg:94.15ms +step:174/1670 train_time:16381ms step_avg:94.15ms +step:175/1670 train_time:16474ms step_avg:94.14ms +step:176/1670 train_time:16566ms step_avg:94.13ms +step:177/1670 train_time:16659ms step_avg:94.12ms +step:178/1670 train_time:16750ms step_avg:94.10ms +step:179/1670 train_time:16844ms step_avg:94.10ms +step:180/1670 train_time:16936ms step_avg:94.09ms +step:181/1670 train_time:17029ms step_avg:94.08ms +step:182/1670 train_time:17121ms step_avg:94.07ms +step:183/1670 train_time:17214ms step_avg:94.07ms +step:184/1670 train_time:17306ms step_avg:94.06ms +step:185/1670 train_time:17399ms step_avg:94.05ms +step:186/1670 train_time:17492ms step_avg:94.04ms +step:187/1670 train_time:17585ms step_avg:94.04ms +step:188/1670 train_time:17677ms step_avg:94.03ms +step:189/1670 train_time:17769ms step_avg:94.02ms +step:190/1670 train_time:17861ms step_avg:94.01ms +step:191/1670 train_time:17954ms step_avg:94.00ms +step:192/1670 train_time:18046ms step_avg:93.99ms +step:193/1670 train_time:18138ms step_avg:93.98ms +step:194/1670 train_time:18230ms step_avg:93.97ms +step:195/1670 train_time:18323ms step_avg:93.97ms +step:196/1670 train_time:18416ms step_avg:93.96ms +step:197/1670 train_time:18509ms step_avg:93.95ms +step:198/1670 train_time:18602ms step_avg:93.95ms +step:199/1670 train_time:18694ms step_avg:93.94ms +step:200/1670 train_time:18786ms step_avg:93.93ms +step:201/1670 train_time:18878ms step_avg:93.92ms +step:202/1670 train_time:18971ms step_avg:93.92ms +step:203/1670 train_time:19065ms step_avg:93.91ms +step:204/1670 train_time:19157ms step_avg:93.91ms +step:205/1670 train_time:19249ms step_avg:93.90ms +step:206/1670 train_time:19343ms step_avg:93.90ms +step:207/1670 train_time:19436ms step_avg:93.89ms +step:208/1670 train_time:19528ms step_avg:93.88ms +step:209/1670 train_time:19622ms step_avg:93.88ms +step:210/1670 train_time:19714ms step_avg:93.88ms +step:211/1670 train_time:19806ms step_avg:93.87ms +step:212/1670 train_time:19899ms step_avg:93.86ms +step:213/1670 train_time:20255ms step_avg:95.10ms +step:214/1670 train_time:20379ms step_avg:95.23ms +step:215/1670 train_time:20471ms step_avg:95.22ms +step:216/1670 train_time:20563ms step_avg:95.20ms +step:217/1670 train_time:20654ms step_avg:95.18ms +step:218/1670 train_time:20746ms step_avg:95.16ms 
+step:219/1670 train_time:20837ms step_avg:95.15ms +step:220/1670 train_time:20928ms step_avg:95.13ms +step:221/1670 train_time:21019ms step_avg:95.11ms +step:222/1670 train_time:21111ms step_avg:95.09ms +step:223/1670 train_time:21203ms step_avg:95.08ms +step:224/1670 train_time:21298ms step_avg:95.08ms +step:225/1670 train_time:21394ms step_avg:95.09ms +step:226/1670 train_time:21488ms step_avg:95.08ms +step:227/1670 train_time:21580ms step_avg:95.07ms +step:228/1670 train_time:21672ms step_avg:95.05ms +step:229/1670 train_time:21764ms step_avg:95.04ms +step:230/1670 train_time:21856ms step_avg:95.03ms +step:231/1670 train_time:21948ms step_avg:95.01ms +step:232/1670 train_time:22040ms step_avg:95.00ms +step:233/1670 train_time:22131ms step_avg:94.98ms +step:234/1670 train_time:22225ms step_avg:94.98ms +step:235/1670 train_time:22318ms step_avg:94.97ms +step:236/1670 train_time:22411ms step_avg:94.96ms +step:237/1670 train_time:22504ms step_avg:94.95ms +step:238/1670 train_time:22596ms step_avg:94.94ms +step:239/1670 train_time:22688ms step_avg:94.93ms +step:240/1670 train_time:22781ms step_avg:94.92ms +step:241/1670 train_time:22873ms step_avg:94.91ms +step:242/1670 train_time:22965ms step_avg:94.90ms +step:243/1670 train_time:23057ms step_avg:94.88ms +step:244/1670 train_time:23149ms step_avg:94.87ms +step:245/1670 train_time:23243ms step_avg:94.87ms +step:246/1670 train_time:23337ms step_avg:94.86ms +step:247/1670 train_time:23429ms step_avg:94.86ms +step:248/1670 train_time:23523ms step_avg:94.85ms +step:249/1670 train_time:23616ms step_avg:94.84ms +step:250/1670 train_time:23708ms step_avg:94.83ms +step:250/1670 val_loss:3.9639 train_time:23799ms step_avg:95.20ms +step:251/1670 train_time:23825ms step_avg:94.92ms +step:252/1670 train_time:23901ms step_avg:94.84ms +step:253/1670 train_time:24001ms step_avg:94.86ms +step:254/1670 train_time:24096ms step_avg:94.87ms +step:255/1670 train_time:24187ms step_avg:94.85ms +step:256/1670 train_time:24279ms step_avg:94.84ms +step:257/1670 train_time:24370ms step_avg:94.83ms +step:258/1670 train_time:24462ms step_avg:94.81ms +step:259/1670 train_time:24553ms step_avg:94.80ms +step:260/1670 train_time:24645ms step_avg:94.79ms +step:261/1670 train_time:24737ms step_avg:94.78ms +step:262/1670 train_time:24829ms step_avg:94.77ms +step:263/1670 train_time:24924ms step_avg:94.77ms +step:264/1670 train_time:25018ms step_avg:94.77ms +step:265/1670 train_time:25112ms step_avg:94.76ms +step:266/1670 train_time:25205ms step_avg:94.75ms +step:267/1670 train_time:25297ms step_avg:94.75ms +step:268/1670 train_time:25388ms step_avg:94.73ms +step:269/1670 train_time:25480ms step_avg:94.72ms +step:270/1670 train_time:25572ms step_avg:94.71ms +step:271/1670 train_time:25663ms step_avg:94.70ms +step:272/1670 train_time:25756ms step_avg:94.69ms +step:273/1670 train_time:25849ms step_avg:94.68ms +step:274/1670 train_time:25942ms step_avg:94.68ms +step:275/1670 train_time:26035ms step_avg:94.67ms +step:276/1670 train_time:26129ms step_avg:94.67ms +step:277/1670 train_time:26222ms step_avg:94.66ms +step:278/1670 train_time:26314ms step_avg:94.66ms +step:279/1670 train_time:26406ms step_avg:94.65ms +step:280/1670 train_time:26498ms step_avg:94.64ms +step:281/1670 train_time:26589ms step_avg:94.62ms +step:282/1670 train_time:26681ms step_avg:94.61ms +step:283/1670 train_time:26773ms step_avg:94.61ms +step:284/1670 train_time:26867ms step_avg:94.60ms +step:285/1670 train_time:26960ms step_avg:94.60ms +step:286/1670 train_time:27053ms step_avg:94.59ms +step:287/1670 
train_time:27147ms step_avg:94.59ms +step:288/1670 train_time:27239ms step_avg:94.58ms +step:289/1670 train_time:27332ms step_avg:94.57ms +step:290/1670 train_time:27424ms step_avg:94.57ms +step:291/1670 train_time:27516ms step_avg:94.56ms +step:292/1670 train_time:27608ms step_avg:94.55ms +step:293/1670 train_time:27701ms step_avg:94.54ms +step:294/1670 train_time:27793ms step_avg:94.53ms +step:295/1670 train_time:27885ms step_avg:94.53ms +step:296/1670 train_time:27978ms step_avg:94.52ms +step:297/1670 train_time:28071ms step_avg:94.51ms +step:298/1670 train_time:28163ms step_avg:94.51ms +step:299/1670 train_time:28255ms step_avg:94.50ms +step:300/1670 train_time:28348ms step_avg:94.49ms +step:301/1670 train_time:28440ms step_avg:94.48ms +step:302/1670 train_time:28532ms step_avg:94.48ms +step:303/1670 train_time:28625ms step_avg:94.47ms +step:304/1670 train_time:28717ms step_avg:94.46ms +step:305/1670 train_time:28809ms step_avg:94.46ms +step:306/1670 train_time:28903ms step_avg:94.45ms +step:307/1670 train_time:28995ms step_avg:94.45ms +step:308/1670 train_time:29088ms step_avg:94.44ms +step:309/1670 train_time:29181ms step_avg:94.44ms +step:310/1670 train_time:29275ms step_avg:94.44ms +step:311/1670 train_time:29367ms step_avg:94.43ms +step:312/1670 train_time:29460ms step_avg:94.42ms +step:313/1670 train_time:29553ms step_avg:94.42ms +step:314/1670 train_time:29645ms step_avg:94.41ms +step:315/1670 train_time:29737ms step_avg:94.40ms +step:316/1670 train_time:29830ms step_avg:94.40ms +step:317/1670 train_time:29922ms step_avg:94.39ms +step:318/1670 train_time:30015ms step_avg:94.39ms +step:319/1670 train_time:30108ms step_avg:94.38ms +step:320/1670 train_time:30201ms step_avg:94.38ms +step:321/1670 train_time:30294ms step_avg:94.37ms +step:322/1670 train_time:30386ms step_avg:94.37ms +step:323/1670 train_time:30479ms step_avg:94.36ms +step:324/1670 train_time:30571ms step_avg:94.35ms +step:325/1670 train_time:30663ms step_avg:94.35ms +step:326/1670 train_time:30755ms step_avg:94.34ms +step:327/1670 train_time:30848ms step_avg:94.34ms +step:328/1670 train_time:30940ms step_avg:94.33ms +step:329/1670 train_time:31033ms step_avg:94.32ms +step:330/1670 train_time:31125ms step_avg:94.32ms +step:331/1670 train_time:31218ms step_avg:94.31ms +step:332/1670 train_time:31311ms step_avg:94.31ms +step:333/1670 train_time:31403ms step_avg:94.30ms +step:334/1670 train_time:31495ms step_avg:94.30ms +step:335/1670 train_time:31587ms step_avg:94.29ms +step:336/1670 train_time:31679ms step_avg:94.28ms +step:337/1670 train_time:31772ms step_avg:94.28ms +step:338/1670 train_time:31864ms step_avg:94.27ms +step:339/1670 train_time:31957ms step_avg:94.27ms +step:340/1670 train_time:32050ms step_avg:94.26ms +step:341/1670 train_time:32142ms step_avg:94.26ms +step:342/1670 train_time:32234ms step_avg:94.25ms +step:343/1670 train_time:32327ms step_avg:94.25ms +step:344/1670 train_time:32419ms step_avg:94.24ms +step:345/1670 train_time:32512ms step_avg:94.24ms +step:346/1670 train_time:32604ms step_avg:94.23ms +step:347/1670 train_time:32696ms step_avg:94.23ms +step:348/1670 train_time:32789ms step_avg:94.22ms +step:349/1670 train_time:32881ms step_avg:94.22ms +step:350/1670 train_time:32974ms step_avg:94.21ms +step:351/1670 train_time:33067ms step_avg:94.21ms +step:352/1670 train_time:33160ms step_avg:94.20ms +step:353/1670 train_time:33253ms step_avg:94.20ms +step:354/1670 train_time:33345ms step_avg:94.20ms +step:355/1670 train_time:33437ms step_avg:94.19ms +step:356/1670 train_time:33530ms step_avg:94.19ms 
+step:357/1670 train_time:33622ms step_avg:94.18ms +step:358/1670 train_time:33715ms step_avg:94.18ms +step:359/1670 train_time:33807ms step_avg:94.17ms +step:360/1670 train_time:33899ms step_avg:94.16ms +step:361/1670 train_time:33992ms step_avg:94.16ms +step:362/1670 train_time:34085ms step_avg:94.16ms +step:363/1670 train_time:34178ms step_avg:94.15ms +step:364/1670 train_time:34271ms step_avg:94.15ms +step:365/1670 train_time:34362ms step_avg:94.14ms +step:366/1670 train_time:34455ms step_avg:94.14ms +step:367/1670 train_time:34548ms step_avg:94.14ms +step:368/1670 train_time:34640ms step_avg:94.13ms +step:369/1670 train_time:34733ms step_avg:94.13ms +step:370/1670 train_time:34825ms step_avg:94.12ms +step:371/1670 train_time:34918ms step_avg:94.12ms +step:372/1670 train_time:35010ms step_avg:94.11ms +step:373/1670 train_time:35104ms step_avg:94.11ms +step:374/1670 train_time:35195ms step_avg:94.11ms +step:375/1670 train_time:35288ms step_avg:94.10ms +step:375/1670 val_loss:3.8113 train_time:35378ms step_avg:94.34ms +step:376/1670 train_time:35403ms step_avg:94.16ms +step:377/1670 train_time:35478ms step_avg:94.11ms +step:378/1670 train_time:35577ms step_avg:94.12ms +step:379/1670 train_time:35673ms step_avg:94.12ms +step:380/1670 train_time:35765ms step_avg:94.12ms +step:381/1670 train_time:35857ms step_avg:94.11ms +step:382/1670 train_time:35949ms step_avg:94.11ms +step:383/1670 train_time:36040ms step_avg:94.10ms +step:384/1670 train_time:36132ms step_avg:94.09ms +step:385/1670 train_time:36224ms step_avg:94.09ms +step:386/1670 train_time:36315ms step_avg:94.08ms +step:387/1670 train_time:36407ms step_avg:94.08ms +step:388/1670 train_time:36502ms step_avg:94.08ms +step:389/1670 train_time:36597ms step_avg:94.08ms +step:390/1670 train_time:36690ms step_avg:94.08ms +step:391/1670 train_time:36782ms step_avg:94.07ms +step:392/1670 train_time:36874ms step_avg:94.07ms +step:393/1670 train_time:36967ms step_avg:94.06ms +step:394/1670 train_time:37059ms step_avg:94.06ms +step:395/1670 train_time:37151ms step_avg:94.05ms +step:396/1670 train_time:37242ms step_avg:94.05ms +step:397/1670 train_time:37333ms step_avg:94.04ms +step:398/1670 train_time:37426ms step_avg:94.03ms +step:399/1670 train_time:37521ms step_avg:94.04ms +step:400/1670 train_time:37615ms step_avg:94.04ms +step:401/1670 train_time:37708ms step_avg:94.03ms +step:402/1670 train_time:37802ms step_avg:94.03ms +step:403/1670 train_time:37894ms step_avg:94.03ms +step:404/1670 train_time:37986ms step_avg:94.02ms +step:405/1670 train_time:38078ms step_avg:94.02ms +step:406/1670 train_time:38170ms step_avg:94.01ms +step:407/1670 train_time:38261ms step_avg:94.01ms +step:408/1670 train_time:38353ms step_avg:94.00ms +step:409/1670 train_time:38446ms step_avg:94.00ms +step:410/1670 train_time:38539ms step_avg:94.00ms +step:411/1670 train_time:38632ms step_avg:94.00ms +step:412/1670 train_time:38725ms step_avg:93.99ms +step:413/1670 train_time:38819ms step_avg:93.99ms +step:414/1670 train_time:38912ms step_avg:93.99ms +step:415/1670 train_time:39004ms step_avg:93.99ms +step:416/1670 train_time:39095ms step_avg:93.98ms +step:417/1670 train_time:39187ms step_avg:93.97ms +step:418/1670 train_time:39280ms step_avg:93.97ms +step:419/1670 train_time:39373ms step_avg:93.97ms +step:420/1670 train_time:39464ms step_avg:93.96ms +step:421/1670 train_time:39557ms step_avg:93.96ms +step:422/1670 train_time:39650ms step_avg:93.96ms +step:423/1670 train_time:39743ms step_avg:93.95ms +step:424/1670 train_time:39835ms step_avg:93.95ms +step:425/1670 
train_time:40165ms step_avg:94.51ms +step:426/1670 train_time:40358ms step_avg:94.74ms +step:427/1670 train_time:40448ms step_avg:94.73ms +step:428/1670 train_time:40539ms step_avg:94.72ms +step:429/1670 train_time:40630ms step_avg:94.71ms +step:430/1670 train_time:40722ms step_avg:94.70ms +step:431/1670 train_time:40814ms step_avg:94.70ms +step:432/1670 train_time:40905ms step_avg:94.69ms +step:433/1670 train_time:40997ms step_avg:94.68ms +step:434/1670 train_time:41088ms step_avg:94.67ms +step:435/1670 train_time:41182ms step_avg:94.67ms +step:436/1670 train_time:41276ms step_avg:94.67ms +step:437/1670 train_time:41374ms step_avg:94.68ms +step:438/1670 train_time:41467ms step_avg:94.67ms +step:439/1670 train_time:41560ms step_avg:94.67ms +step:440/1670 train_time:41652ms step_avg:94.66ms +step:441/1670 train_time:41744ms step_avg:94.66ms +step:442/1670 train_time:41835ms step_avg:94.65ms +step:443/1670 train_time:41931ms step_avg:94.65ms +step:444/1670 train_time:42024ms step_avg:94.65ms +step:445/1670 train_time:42116ms step_avg:94.64ms +step:446/1670 train_time:42208ms step_avg:94.64ms +step:447/1670 train_time:42297ms step_avg:94.63ms +step:448/1670 train_time:42391ms step_avg:94.62ms +step:449/1670 train_time:42484ms step_avg:94.62ms +step:450/1670 train_time:42577ms step_avg:94.62ms +step:451/1670 train_time:42670ms step_avg:94.61ms +step:452/1670 train_time:42762ms step_avg:94.61ms +step:453/1670 train_time:42854ms step_avg:94.60ms +step:454/1670 train_time:42945ms step_avg:94.59ms +step:455/1670 train_time:43037ms step_avg:94.59ms +step:456/1670 train_time:43129ms step_avg:94.58ms +step:457/1670 train_time:43223ms step_avg:94.58ms +step:458/1670 train_time:43316ms step_avg:94.58ms +step:459/1670 train_time:43408ms step_avg:94.57ms +step:460/1670 train_time:43501ms step_avg:94.57ms +step:461/1670 train_time:43594ms step_avg:94.56ms +step:462/1670 train_time:43686ms step_avg:94.56ms +step:463/1670 train_time:43779ms step_avg:94.56ms +step:464/1670 train_time:43872ms step_avg:94.55ms +step:465/1670 train_time:43964ms step_avg:94.55ms +step:466/1670 train_time:44055ms step_avg:94.54ms +step:467/1670 train_time:44148ms step_avg:94.54ms +step:468/1670 train_time:44241ms step_avg:94.53ms +step:469/1670 train_time:44334ms step_avg:94.53ms +step:470/1670 train_time:44427ms step_avg:94.52ms +step:471/1670 train_time:44520ms step_avg:94.52ms +step:472/1670 train_time:44613ms step_avg:94.52ms +step:473/1670 train_time:44705ms step_avg:94.51ms +step:474/1670 train_time:44798ms step_avg:94.51ms +step:475/1670 train_time:44891ms step_avg:94.51ms +step:476/1670 train_time:44983ms step_avg:94.50ms +step:477/1670 train_time:45075ms step_avg:94.50ms +step:478/1670 train_time:45168ms step_avg:94.49ms +step:479/1670 train_time:45259ms step_avg:94.49ms +step:480/1670 train_time:45352ms step_avg:94.48ms +step:481/1670 train_time:45444ms step_avg:94.48ms +step:482/1670 train_time:45537ms step_avg:94.47ms +step:483/1670 train_time:45629ms step_avg:94.47ms +step:484/1670 train_time:45723ms step_avg:94.47ms +step:485/1670 train_time:45816ms step_avg:94.47ms +step:486/1670 train_time:45908ms step_avg:94.46ms +step:487/1670 train_time:46000ms step_avg:94.46ms +step:488/1670 train_time:46092ms step_avg:94.45ms +step:489/1670 train_time:46185ms step_avg:94.45ms +step:490/1670 train_time:46277ms step_avg:94.44ms +step:491/1670 train_time:46369ms step_avg:94.44ms +step:492/1670 train_time:46462ms step_avg:94.44ms +step:493/1670 train_time:46554ms step_avg:94.43ms +step:494/1670 train_time:46647ms step_avg:94.43ms 
+step:495/1670 train_time:46740ms step_avg:94.43ms +step:496/1670 train_time:46834ms step_avg:94.42ms +step:497/1670 train_time:46926ms step_avg:94.42ms +step:498/1670 train_time:47019ms step_avg:94.42ms +step:499/1670 train_time:47113ms step_avg:94.41ms +step:500/1670 train_time:47204ms step_avg:94.41ms +step:500/1670 val_loss:3.7121 train_time:47295ms step_avg:94.59ms +step:501/1670 train_time:47320ms step_avg:94.45ms +step:502/1670 train_time:47397ms step_avg:94.42ms +step:503/1670 train_time:47494ms step_avg:94.42ms +step:504/1670 train_time:47588ms step_avg:94.42ms +step:505/1670 train_time:47680ms step_avg:94.42ms +step:506/1670 train_time:47772ms step_avg:94.41ms +step:507/1670 train_time:47863ms step_avg:94.40ms +step:508/1670 train_time:47954ms step_avg:94.40ms +step:509/1670 train_time:48045ms step_avg:94.39ms +step:510/1670 train_time:48137ms step_avg:94.39ms +step:511/1670 train_time:48229ms step_avg:94.38ms +step:512/1670 train_time:48321ms step_avg:94.38ms +step:513/1670 train_time:48415ms step_avg:94.38ms +step:514/1670 train_time:48508ms step_avg:94.37ms +step:515/1670 train_time:48602ms step_avg:94.37ms +step:516/1670 train_time:48694ms step_avg:94.37ms +step:517/1670 train_time:48786ms step_avg:94.36ms +step:518/1670 train_time:48879ms step_avg:94.36ms +step:519/1670 train_time:48971ms step_avg:94.36ms +step:520/1670 train_time:49062ms step_avg:94.35ms +step:521/1670 train_time:49154ms step_avg:94.34ms +step:522/1670 train_time:49245ms step_avg:94.34ms +step:523/1670 train_time:49339ms step_avg:94.34ms +step:524/1670 train_time:49432ms step_avg:94.34ms +step:525/1670 train_time:49526ms step_avg:94.34ms +step:526/1670 train_time:49619ms step_avg:94.33ms +step:527/1670 train_time:49711ms step_avg:94.33ms +step:528/1670 train_time:49804ms step_avg:94.33ms +step:529/1670 train_time:49896ms step_avg:94.32ms +step:530/1670 train_time:49989ms step_avg:94.32ms +step:531/1670 train_time:50081ms step_avg:94.31ms +step:532/1670 train_time:50173ms step_avg:94.31ms +step:533/1670 train_time:50265ms step_avg:94.31ms +step:534/1670 train_time:50357ms step_avg:94.30ms +step:535/1670 train_time:50451ms step_avg:94.30ms +step:536/1670 train_time:50544ms step_avg:94.30ms +step:537/1670 train_time:50636ms step_avg:94.29ms +step:538/1670 train_time:50729ms step_avg:94.29ms +step:539/1670 train_time:50822ms step_avg:94.29ms +step:540/1670 train_time:50914ms step_avg:94.29ms +step:541/1670 train_time:51006ms step_avg:94.28ms +step:542/1670 train_time:51098ms step_avg:94.28ms +step:543/1670 train_time:51191ms step_avg:94.27ms +step:544/1670 train_time:51283ms step_avg:94.27ms +step:545/1670 train_time:51375ms step_avg:94.27ms +step:546/1670 train_time:51469ms step_avg:94.27ms +step:547/1670 train_time:51562ms step_avg:94.26ms +step:548/1670 train_time:51655ms step_avg:94.26ms +step:549/1670 train_time:51748ms step_avg:94.26ms +step:550/1670 train_time:51841ms step_avg:94.26ms +step:551/1670 train_time:51934ms step_avg:94.25ms +step:552/1670 train_time:52026ms step_avg:94.25ms +step:553/1670 train_time:52118ms step_avg:94.25ms +step:554/1670 train_time:52210ms step_avg:94.24ms +step:555/1670 train_time:52304ms step_avg:94.24ms +step:556/1670 train_time:52396ms step_avg:94.24ms +step:557/1670 train_time:52488ms step_avg:94.23ms +step:558/1670 train_time:52690ms step_avg:94.43ms +step:559/1670 train_time:52758ms step_avg:94.38ms +step:560/1670 train_time:52851ms step_avg:94.38ms +step:561/1670 train_time:52943ms step_avg:94.37ms +step:562/1670 train_time:53036ms step_avg:94.37ms +step:563/1670 
train_time:53129ms step_avg:94.37ms +step:564/1670 train_time:53221ms step_avg:94.36ms +step:565/1670 train_time:53314ms step_avg:94.36ms +step:566/1670 train_time:53407ms step_avg:94.36ms +step:567/1670 train_time:53500ms step_avg:94.36ms +step:568/1670 train_time:53598ms step_avg:94.36ms +step:569/1670 train_time:53695ms step_avg:94.37ms +step:570/1670 train_time:53789ms step_avg:94.37ms +step:571/1670 train_time:53883ms step_avg:94.37ms +step:572/1670 train_time:53975ms step_avg:94.36ms +step:573/1670 train_time:54068ms step_avg:94.36ms +step:574/1670 train_time:54161ms step_avg:94.36ms +step:575/1670 train_time:54254ms step_avg:94.36ms +step:576/1670 train_time:54347ms step_avg:94.35ms +step:577/1670 train_time:54440ms step_avg:94.35ms +step:578/1670 train_time:54535ms step_avg:94.35ms +step:579/1670 train_time:54631ms step_avg:94.35ms +step:580/1670 train_time:54726ms step_avg:94.36ms +step:581/1670 train_time:54821ms step_avg:94.36ms +step:582/1670 train_time:54914ms step_avg:94.35ms +step:583/1670 train_time:55008ms step_avg:94.35ms +step:584/1670 train_time:55101ms step_avg:94.35ms +step:585/1670 train_time:55195ms step_avg:94.35ms +step:586/1670 train_time:55287ms step_avg:94.35ms +step:587/1670 train_time:55380ms step_avg:94.34ms +step:588/1670 train_time:55474ms step_avg:94.34ms +step:589/1670 train_time:55568ms step_avg:94.34ms +step:590/1670 train_time:55663ms step_avg:94.34ms +step:591/1670 train_time:55757ms step_avg:94.34ms +step:592/1670 train_time:55851ms step_avg:94.34ms +step:593/1670 train_time:55944ms step_avg:94.34ms +step:594/1670 train_time:56038ms step_avg:94.34ms +step:595/1670 train_time:56132ms step_avg:94.34ms +step:596/1670 train_time:56224ms step_avg:94.34ms +step:597/1670 train_time:56318ms step_avg:94.33ms +step:598/1670 train_time:56411ms step_avg:94.33ms +step:599/1670 train_time:56505ms step_avg:94.33ms +step:600/1670 train_time:56599ms step_avg:94.33ms +step:601/1670 train_time:56692ms step_avg:94.33ms +step:602/1670 train_time:56786ms step_avg:94.33ms +step:603/1670 train_time:56880ms step_avg:94.33ms +step:604/1670 train_time:56974ms step_avg:94.33ms +step:605/1670 train_time:57067ms step_avg:94.33ms +step:606/1670 train_time:57161ms step_avg:94.33ms +step:607/1670 train_time:57255ms step_avg:94.32ms +step:608/1670 train_time:57348ms step_avg:94.32ms +step:609/1670 train_time:57442ms step_avg:94.32ms +step:610/1670 train_time:57535ms step_avg:94.32ms +step:611/1670 train_time:57629ms step_avg:94.32ms +step:612/1670 train_time:57723ms step_avg:94.32ms +step:613/1670 train_time:57816ms step_avg:94.32ms +step:614/1670 train_time:57910ms step_avg:94.32ms +step:615/1670 train_time:58004ms step_avg:94.32ms +step:616/1670 train_time:58098ms step_avg:94.31ms +step:617/1670 train_time:58191ms step_avg:94.31ms +step:618/1670 train_time:58284ms step_avg:94.31ms +step:619/1670 train_time:58378ms step_avg:94.31ms +step:620/1670 train_time:58472ms step_avg:94.31ms +step:621/1670 train_time:58566ms step_avg:94.31ms +step:622/1670 train_time:58661ms step_avg:94.31ms +step:623/1670 train_time:58754ms step_avg:94.31ms +step:624/1670 train_time:58847ms step_avg:94.31ms +step:625/1670 train_time:58941ms step_avg:94.31ms +step:625/1670 val_loss:3.6111 train_time:59034ms step_avg:94.45ms +step:626/1670 train_time:59059ms step_avg:94.34ms +step:627/1670 train_time:59141ms step_avg:94.32ms +step:628/1670 train_time:59238ms step_avg:94.33ms +step:629/1670 train_time:59333ms step_avg:94.33ms +step:630/1670 train_time:59426ms step_avg:94.33ms +step:631/1670 train_time:59519ms 
step_avg:94.32ms +step:632/1670 train_time:59612ms step_avg:94.32ms +step:633/1670 train_time:59704ms step_avg:94.32ms +step:634/1670 train_time:59797ms step_avg:94.32ms +step:635/1670 train_time:59889ms step_avg:94.31ms +step:636/1670 train_time:59982ms step_avg:94.31ms +step:637/1670 train_time:60078ms step_avg:94.31ms +step:638/1670 train_time:60175ms step_avg:94.32ms +step:639/1670 train_time:60618ms step_avg:94.86ms +step:640/1670 train_time:60698ms step_avg:94.84ms +step:641/1670 train_time:60790ms step_avg:94.84ms +step:642/1670 train_time:60883ms step_avg:94.83ms +step:643/1670 train_time:60976ms step_avg:94.83ms +step:644/1670 train_time:61069ms step_avg:94.83ms +step:645/1670 train_time:61161ms step_avg:94.82ms +step:646/1670 train_time:61253ms step_avg:94.82ms +step:647/1670 train_time:61346ms step_avg:94.82ms +step:648/1670 train_time:61439ms step_avg:94.81ms +step:649/1670 train_time:61536ms step_avg:94.82ms +step:650/1670 train_time:61635ms step_avg:94.82ms +step:651/1670 train_time:61730ms step_avg:94.82ms +step:652/1670 train_time:61823ms step_avg:94.82ms +step:653/1670 train_time:61916ms step_avg:94.82ms +step:654/1670 train_time:62009ms step_avg:94.81ms +step:655/1670 train_time:62101ms step_avg:94.81ms +step:656/1670 train_time:62194ms step_avg:94.81ms +step:657/1670 train_time:62286ms step_avg:94.80ms +step:658/1670 train_time:62379ms step_avg:94.80ms +step:659/1670 train_time:62474ms step_avg:94.80ms +step:660/1670 train_time:62569ms step_avg:94.80ms +step:661/1670 train_time:62665ms step_avg:94.80ms +step:662/1670 train_time:62758ms step_avg:94.80ms +step:663/1670 train_time:62852ms step_avg:94.80ms +step:664/1670 train_time:62946ms step_avg:94.80ms +step:665/1670 train_time:63039ms step_avg:94.80ms +step:666/1670 train_time:63132ms step_avg:94.79ms +step:667/1670 train_time:63225ms step_avg:94.79ms +step:668/1670 train_time:63318ms step_avg:94.79ms +step:669/1670 train_time:63412ms step_avg:94.79ms +step:670/1670 train_time:63506ms step_avg:94.79ms +step:671/1670 train_time:63601ms step_avg:94.79ms +step:672/1670 train_time:63694ms step_avg:94.78ms +step:673/1670 train_time:63789ms step_avg:94.78ms +step:674/1670 train_time:63882ms step_avg:94.78ms +step:675/1670 train_time:63976ms step_avg:94.78ms +step:676/1670 train_time:64069ms step_avg:94.78ms +step:677/1670 train_time:64163ms step_avg:94.78ms +step:678/1670 train_time:64256ms step_avg:94.77ms +step:679/1670 train_time:64350ms step_avg:94.77ms +step:680/1670 train_time:64443ms step_avg:94.77ms +step:681/1670 train_time:64537ms step_avg:94.77ms +step:682/1670 train_time:64631ms step_avg:94.77ms +step:683/1670 train_time:64725ms step_avg:94.77ms +step:684/1670 train_time:64819ms step_avg:94.76ms +step:685/1670 train_time:64912ms step_avg:94.76ms +step:686/1670 train_time:65006ms step_avg:94.76ms +step:687/1670 train_time:65099ms step_avg:94.76ms +step:688/1670 train_time:65193ms step_avg:94.76ms +step:689/1670 train_time:65285ms step_avg:94.75ms +step:690/1670 train_time:65379ms step_avg:94.75ms +step:691/1670 train_time:65474ms step_avg:94.75ms +step:692/1670 train_time:65567ms step_avg:94.75ms +step:693/1670 train_time:65661ms step_avg:94.75ms +step:694/1670 train_time:65755ms step_avg:94.75ms +step:695/1670 train_time:65849ms step_avg:94.75ms +step:696/1670 train_time:65943ms step_avg:94.75ms +step:697/1670 train_time:66036ms step_avg:94.74ms +step:698/1670 train_time:66130ms step_avg:94.74ms +step:699/1670 train_time:66223ms step_avg:94.74ms +step:700/1670 train_time:66316ms step_avg:94.74ms +step:701/1670 
train_time:66410ms step_avg:94.74ms +step:702/1670 train_time:66504ms step_avg:94.73ms +step:703/1670 train_time:66598ms step_avg:94.73ms +step:704/1670 train_time:66692ms step_avg:94.73ms +step:705/1670 train_time:66785ms step_avg:94.73ms +step:706/1670 train_time:66879ms step_avg:94.73ms +step:707/1670 train_time:66973ms step_avg:94.73ms +step:708/1670 train_time:67067ms step_avg:94.73ms +step:709/1670 train_time:67161ms step_avg:94.73ms +step:710/1670 train_time:67254ms step_avg:94.72ms +step:711/1670 train_time:67348ms step_avg:94.72ms +step:712/1670 train_time:67442ms step_avg:94.72ms +step:713/1670 train_time:67535ms step_avg:94.72ms +step:714/1670 train_time:67629ms step_avg:94.72ms +step:715/1670 train_time:67723ms step_avg:94.72ms +step:716/1670 train_time:67816ms step_avg:94.72ms +step:717/1670 train_time:67910ms step_avg:94.71ms +step:718/1670 train_time:68004ms step_avg:94.71ms +step:719/1670 train_time:68098ms step_avg:94.71ms +step:720/1670 train_time:68192ms step_avg:94.71ms +step:721/1670 train_time:68286ms step_avg:94.71ms +step:722/1670 train_time:68381ms step_avg:94.71ms +step:723/1670 train_time:68474ms step_avg:94.71ms +step:724/1670 train_time:68568ms step_avg:94.71ms +step:725/1670 train_time:68661ms step_avg:94.70ms +step:726/1670 train_time:68754ms step_avg:94.70ms +step:727/1670 train_time:68848ms step_avg:94.70ms +step:728/1670 train_time:68943ms step_avg:94.70ms +step:729/1670 train_time:69038ms step_avg:94.70ms +step:730/1670 train_time:69132ms step_avg:94.70ms +step:731/1670 train_time:69226ms step_avg:94.70ms +step:732/1670 train_time:69320ms step_avg:94.70ms +step:733/1670 train_time:69414ms step_avg:94.70ms +step:734/1670 train_time:69507ms step_avg:94.70ms +step:735/1670 train_time:69602ms step_avg:94.70ms +step:736/1670 train_time:69695ms step_avg:94.69ms +step:737/1670 train_time:69789ms step_avg:94.69ms +step:738/1670 train_time:69884ms step_avg:94.69ms +step:739/1670 train_time:69977ms step_avg:94.69ms +step:740/1670 train_time:70071ms step_avg:94.69ms +step:741/1670 train_time:70166ms step_avg:94.69ms +step:742/1670 train_time:70259ms step_avg:94.69ms +step:743/1670 train_time:70353ms step_avg:94.69ms +step:744/1670 train_time:70446ms step_avg:94.69ms +step:745/1670 train_time:70539ms step_avg:94.68ms +step:746/1670 train_time:70633ms step_avg:94.68ms +step:747/1670 train_time:70726ms step_avg:94.68ms +step:748/1670 train_time:70819ms step_avg:94.68ms +step:749/1670 train_time:70912ms step_avg:94.68ms +step:750/1670 train_time:71007ms step_avg:94.68ms +step:750/1670 val_loss:3.5617 train_time:71099ms step_avg:94.80ms +step:751/1670 train_time:71124ms step_avg:94.71ms +step:752/1670 train_time:71201ms step_avg:94.68ms +step:753/1670 train_time:71302ms step_avg:94.69ms +step:754/1670 train_time:71397ms step_avg:94.69ms +step:755/1670 train_time:71490ms step_avg:94.69ms +step:756/1670 train_time:71583ms step_avg:94.69ms +step:757/1670 train_time:71675ms step_avg:94.68ms +step:758/1670 train_time:71768ms step_avg:94.68ms +step:759/1670 train_time:71861ms step_avg:94.68ms +step:760/1670 train_time:71953ms step_avg:94.68ms +step:761/1670 train_time:72047ms step_avg:94.67ms +step:762/1670 train_time:72142ms step_avg:94.67ms +step:763/1670 train_time:72239ms step_avg:94.68ms +step:764/1670 train_time:72336ms step_avg:94.68ms +step:765/1670 train_time:72429ms step_avg:94.68ms +step:766/1670 train_time:72522ms step_avg:94.68ms +step:767/1670 train_time:72616ms step_avg:94.68ms +step:768/1670 train_time:72709ms step_avg:94.67ms +step:769/1670 train_time:72802ms 
step_avg:94.67ms +step:770/1670 train_time:72894ms step_avg:94.67ms +step:771/1670 train_time:72987ms step_avg:94.67ms +step:772/1670 train_time:73081ms step_avg:94.66ms +step:773/1670 train_time:73176ms step_avg:94.66ms +step:774/1670 train_time:73271ms step_avg:94.67ms +step:775/1670 train_time:73366ms step_avg:94.67ms +step:776/1670 train_time:73461ms step_avg:94.67ms +step:777/1670 train_time:73554ms step_avg:94.66ms +step:778/1670 train_time:73647ms step_avg:94.66ms +step:779/1670 train_time:73742ms step_avg:94.66ms +step:780/1670 train_time:73835ms step_avg:94.66ms +step:781/1670 train_time:73929ms step_avg:94.66ms +step:782/1670 train_time:74021ms step_avg:94.66ms +step:783/1670 train_time:74115ms step_avg:94.65ms +step:784/1670 train_time:74209ms step_avg:94.65ms +step:785/1670 train_time:74302ms step_avg:94.65ms +step:786/1670 train_time:74397ms step_avg:94.65ms +step:787/1670 train_time:74491ms step_avg:94.65ms +step:788/1670 train_time:74584ms step_avg:94.65ms +step:789/1670 train_time:74678ms step_avg:94.65ms +step:790/1670 train_time:74771ms step_avg:94.65ms +step:791/1670 train_time:74865ms step_avg:94.65ms +step:792/1670 train_time:74959ms step_avg:94.64ms +step:793/1670 train_time:75053ms step_avg:94.64ms +step:794/1670 train_time:75146ms step_avg:94.64ms +step:795/1670 train_time:75240ms step_avg:94.64ms +step:796/1670 train_time:75334ms step_avg:94.64ms +step:797/1670 train_time:75428ms step_avg:94.64ms +step:798/1670 train_time:75522ms step_avg:94.64ms +step:799/1670 train_time:75615ms step_avg:94.64ms +step:800/1670 train_time:75709ms step_avg:94.64ms +step:801/1670 train_time:75803ms step_avg:94.64ms +step:802/1670 train_time:75896ms step_avg:94.63ms +step:803/1670 train_time:75990ms step_avg:94.63ms +step:804/1670 train_time:76083ms step_avg:94.63ms +step:805/1670 train_time:76178ms step_avg:94.63ms +step:806/1670 train_time:76271ms step_avg:94.63ms +step:807/1670 train_time:76364ms step_avg:94.63ms +step:808/1670 train_time:76458ms step_avg:94.63ms +step:809/1670 train_time:76551ms step_avg:94.62ms +step:810/1670 train_time:76646ms step_avg:94.62ms +step:811/1670 train_time:76739ms step_avg:94.62ms +step:812/1670 train_time:76832ms step_avg:94.62ms +step:813/1670 train_time:76925ms step_avg:94.62ms +step:814/1670 train_time:77019ms step_avg:94.62ms +step:815/1670 train_time:77112ms step_avg:94.62ms +step:816/1670 train_time:77207ms step_avg:94.62ms +step:817/1670 train_time:77301ms step_avg:94.62ms +step:818/1670 train_time:77394ms step_avg:94.61ms +step:819/1670 train_time:77488ms step_avg:94.61ms +step:820/1670 train_time:77582ms step_avg:94.61ms +step:821/1670 train_time:77676ms step_avg:94.61ms +step:822/1670 train_time:77770ms step_avg:94.61ms +step:823/1670 train_time:77864ms step_avg:94.61ms +step:824/1670 train_time:77958ms step_avg:94.61ms +step:825/1670 train_time:78051ms step_avg:94.61ms +step:826/1670 train_time:78145ms step_avg:94.61ms +step:827/1670 train_time:78240ms step_avg:94.61ms +step:828/1670 train_time:78334ms step_avg:94.61ms +step:829/1670 train_time:78427ms step_avg:94.60ms +step:830/1670 train_time:78521ms step_avg:94.60ms +step:831/1670 train_time:78614ms step_avg:94.60ms +step:832/1670 train_time:78709ms step_avg:94.60ms +step:833/1670 train_time:78802ms step_avg:94.60ms +step:834/1670 train_time:78896ms step_avg:94.60ms +step:835/1670 train_time:78989ms step_avg:94.60ms +step:836/1670 train_time:79083ms step_avg:94.60ms +step:837/1670 train_time:79178ms step_avg:94.60ms +step:838/1670 train_time:79271ms step_avg:94.60ms +step:839/1670 
train_time:79365ms step_avg:94.59ms +step:840/1670 train_time:79459ms step_avg:94.59ms +step:841/1670 train_time:79553ms step_avg:94.59ms +step:842/1670 train_time:79646ms step_avg:94.59ms +step:843/1670 train_time:79740ms step_avg:94.59ms +step:844/1670 train_time:79833ms step_avg:94.59ms +step:845/1670 train_time:79927ms step_avg:94.59ms +step:846/1670 train_time:80021ms step_avg:94.59ms +step:847/1670 train_time:80115ms step_avg:94.59ms +step:848/1670 train_time:80209ms step_avg:94.59ms +step:849/1670 train_time:80303ms step_avg:94.59ms +step:850/1670 train_time:80397ms step_avg:94.58ms +step:851/1670 train_time:80815ms step_avg:94.97ms +step:852/1670 train_time:80917ms step_avg:94.97ms +step:853/1670 train_time:81009ms step_avg:94.97ms +step:854/1670 train_time:81102ms step_avg:94.97ms +step:855/1670 train_time:81195ms step_avg:94.96ms +step:856/1670 train_time:81287ms step_avg:94.96ms +step:857/1670 train_time:81381ms step_avg:94.96ms +step:858/1670 train_time:81473ms step_avg:94.96ms +step:859/1670 train_time:81566ms step_avg:94.95ms +step:860/1670 train_time:81659ms step_avg:94.95ms +step:861/1670 train_time:81755ms step_avg:94.95ms +step:862/1670 train_time:81853ms step_avg:94.96ms +step:863/1670 train_time:81950ms step_avg:94.96ms +step:864/1670 train_time:82046ms step_avg:94.96ms +step:865/1670 train_time:82138ms step_avg:94.96ms +step:866/1670 train_time:82232ms step_avg:94.96ms +step:867/1670 train_time:82325ms step_avg:94.95ms +step:868/1670 train_time:82418ms step_avg:94.95ms +step:869/1670 train_time:82511ms step_avg:94.95ms +step:870/1670 train_time:82604ms step_avg:94.95ms +step:871/1670 train_time:82698ms step_avg:94.95ms +step:872/1670 train_time:82794ms step_avg:94.95ms +step:873/1670 train_time:82888ms step_avg:94.95ms +step:874/1670 train_time:82983ms step_avg:94.95ms +step:875/1670 train_time:83077ms step_avg:94.94ms +step:875/1670 val_loss:3.5158 train_time:83168ms step_avg:95.05ms +step:876/1670 train_time:83193ms step_avg:94.97ms +step:877/1670 train_time:83269ms step_avg:94.95ms +step:878/1670 train_time:83371ms step_avg:94.96ms +step:879/1670 train_time:83468ms step_avg:94.96ms +step:880/1670 train_time:83561ms step_avg:94.96ms +step:881/1670 train_time:83654ms step_avg:94.95ms +step:882/1670 train_time:83746ms step_avg:94.95ms +step:883/1670 train_time:83839ms step_avg:94.95ms +step:884/1670 train_time:83932ms step_avg:94.95ms +step:885/1670 train_time:84024ms step_avg:94.94ms +step:886/1670 train_time:84117ms step_avg:94.94ms +step:887/1670 train_time:84211ms step_avg:94.94ms +step:888/1670 train_time:84308ms step_avg:94.94ms +step:889/1670 train_time:84405ms step_avg:94.94ms +step:890/1670 train_time:84500ms step_avg:94.94ms +step:891/1670 train_time:84593ms step_avg:94.94ms +step:892/1670 train_time:84686ms step_avg:94.94ms +step:893/1670 train_time:84778ms step_avg:94.94ms +step:894/1670 train_time:84872ms step_avg:94.94ms +step:895/1670 train_time:84965ms step_avg:94.93ms +step:896/1670 train_time:85058ms step_avg:94.93ms +step:897/1670 train_time:85152ms step_avg:94.93ms +step:898/1670 train_time:85247ms step_avg:94.93ms +step:899/1670 train_time:85342ms step_avg:94.93ms +step:900/1670 train_time:85437ms step_avg:94.93ms +step:901/1670 train_time:85530ms step_avg:94.93ms +step:902/1670 train_time:85623ms step_avg:94.93ms +step:903/1670 train_time:85717ms step_avg:94.93ms +step:904/1670 train_time:85811ms step_avg:94.92ms +step:905/1670 train_time:85904ms step_avg:94.92ms +step:906/1670 train_time:85998ms step_avg:94.92ms +step:907/1670 train_time:86092ms 
step_avg:94.92ms +step:908/1670 train_time:86184ms step_avg:94.92ms +step:909/1670 train_time:86278ms step_avg:94.92ms +step:910/1670 train_time:86373ms step_avg:94.92ms +step:911/1670 train_time:86467ms step_avg:94.91ms +step:912/1670 train_time:86562ms step_avg:94.91ms +step:913/1670 train_time:86655ms step_avg:94.91ms +step:914/1670 train_time:86748ms step_avg:94.91ms +step:915/1670 train_time:86842ms step_avg:94.91ms +step:916/1670 train_time:86936ms step_avg:94.91ms +step:917/1670 train_time:87029ms step_avg:94.91ms +step:918/1670 train_time:87122ms step_avg:94.90ms +step:919/1670 train_time:87216ms step_avg:94.90ms +step:920/1670 train_time:87311ms step_avg:94.90ms +step:921/1670 train_time:87405ms step_avg:94.90ms +step:922/1670 train_time:87498ms step_avg:94.90ms +step:923/1670 train_time:87593ms step_avg:94.90ms +step:924/1670 train_time:87686ms step_avg:94.90ms +step:925/1670 train_time:87780ms step_avg:94.90ms +step:926/1670 train_time:87873ms step_avg:94.90ms +step:927/1670 train_time:87966ms step_avg:94.89ms +step:928/1670 train_time:88060ms step_avg:94.89ms +step:929/1670 train_time:88153ms step_avg:94.89ms +step:930/1670 train_time:88246ms step_avg:94.89ms +step:931/1670 train_time:88340ms step_avg:94.89ms +step:932/1670 train_time:88435ms step_avg:94.89ms +step:933/1670 train_time:88528ms step_avg:94.89ms +step:934/1670 train_time:88622ms step_avg:94.88ms +step:935/1670 train_time:88716ms step_avg:94.88ms +step:936/1670 train_time:88810ms step_avg:94.88ms +step:937/1670 train_time:88903ms step_avg:94.88ms +step:938/1670 train_time:88998ms step_avg:94.88ms +step:939/1670 train_time:89091ms step_avg:94.88ms +step:940/1670 train_time:89184ms step_avg:94.88ms +step:941/1670 train_time:89278ms step_avg:94.88ms +step:942/1670 train_time:89372ms step_avg:94.87ms +step:943/1670 train_time:89465ms step_avg:94.87ms +step:944/1670 train_time:89560ms step_avg:94.87ms +step:945/1670 train_time:89653ms step_avg:94.87ms +step:946/1670 train_time:89747ms step_avg:94.87ms +step:947/1670 train_time:89842ms step_avg:94.87ms +step:948/1670 train_time:89936ms step_avg:94.87ms +step:949/1670 train_time:90030ms step_avg:94.87ms +step:950/1670 train_time:90123ms step_avg:94.87ms +step:951/1670 train_time:90217ms step_avg:94.87ms +step:952/1670 train_time:90310ms step_avg:94.86ms +step:953/1670 train_time:90405ms step_avg:94.86ms +step:954/1670 train_time:90499ms step_avg:94.86ms +step:955/1670 train_time:90592ms step_avg:94.86ms +step:956/1670 train_time:90685ms step_avg:94.86ms +step:957/1670 train_time:90779ms step_avg:94.86ms +step:958/1670 train_time:90873ms step_avg:94.86ms +step:959/1670 train_time:90966ms step_avg:94.86ms +step:960/1670 train_time:91059ms step_avg:94.85ms +step:961/1670 train_time:91153ms step_avg:94.85ms +step:962/1670 train_time:91247ms step_avg:94.85ms +step:963/1670 train_time:91341ms step_avg:94.85ms +step:964/1670 train_time:91435ms step_avg:94.85ms +step:965/1670 train_time:91529ms step_avg:94.85ms +step:966/1670 train_time:91623ms step_avg:94.85ms +step:967/1670 train_time:91717ms step_avg:94.85ms +step:968/1670 train_time:91811ms step_avg:94.85ms +step:969/1670 train_time:91905ms step_avg:94.84ms +step:970/1670 train_time:91999ms step_avg:94.84ms +step:971/1670 train_time:92091ms step_avg:94.84ms +step:972/1670 train_time:92185ms step_avg:94.84ms +step:973/1670 train_time:92279ms step_avg:94.84ms +step:974/1670 train_time:92373ms step_avg:94.84ms +step:975/1670 train_time:92466ms step_avg:94.84ms +step:976/1670 train_time:92559ms step_avg:94.84ms +step:977/1670 
train_time:92652ms step_avg:94.83ms +step:978/1670 train_time:92747ms step_avg:94.83ms +step:979/1670 train_time:92842ms step_avg:94.83ms +step:980/1670 train_time:92936ms step_avg:94.83ms +step:981/1670 train_time:93029ms step_avg:94.83ms +step:982/1670 train_time:93123ms step_avg:94.83ms +step:983/1670 train_time:93217ms step_avg:94.83ms +step:984/1670 train_time:93310ms step_avg:94.83ms +step:985/1670 train_time:93405ms step_avg:94.83ms +step:986/1670 train_time:93497ms step_avg:94.82ms +step:987/1670 train_time:93591ms step_avg:94.82ms +step:988/1670 train_time:93685ms step_avg:94.82ms +step:989/1670 train_time:93779ms step_avg:94.82ms +step:990/1670 train_time:93873ms step_avg:94.82ms +step:991/1670 train_time:93966ms step_avg:94.82ms +step:992/1670 train_time:94060ms step_avg:94.82ms +step:993/1670 train_time:94154ms step_avg:94.82ms +step:994/1670 train_time:94247ms step_avg:94.82ms +step:995/1670 train_time:94341ms step_avg:94.81ms +step:996/1670 train_time:94435ms step_avg:94.81ms +step:997/1670 train_time:94528ms step_avg:94.81ms +step:998/1670 train_time:94622ms step_avg:94.81ms +step:999/1670 train_time:94716ms step_avg:94.81ms +step:1000/1670 train_time:94810ms step_avg:94.81ms +step:1000/1670 val_loss:3.4681 train_time:94902ms step_avg:94.90ms +step:1001/1670 train_time:94927ms step_avg:94.83ms +step:1002/1670 train_time:95005ms step_avg:94.81ms +step:1003/1670 train_time:95106ms step_avg:94.82ms +step:1004/1670 train_time:95201ms step_avg:94.82ms +step:1005/1670 train_time:95294ms step_avg:94.82ms +step:1006/1670 train_time:95387ms step_avg:94.82ms +step:1007/1670 train_time:95479ms step_avg:94.82ms +step:1008/1670 train_time:95572ms step_avg:94.81ms +step:1009/1670 train_time:95664ms step_avg:94.81ms +step:1010/1670 train_time:95757ms step_avg:94.81ms +step:1011/1670 train_time:95850ms step_avg:94.81ms +step:1012/1670 train_time:95944ms step_avg:94.81ms +step:1013/1670 train_time:96040ms step_avg:94.81ms +step:1014/1670 train_time:96137ms step_avg:94.81ms +step:1015/1670 train_time:96232ms step_avg:94.81ms +step:1016/1670 train_time:96326ms step_avg:94.81ms +step:1017/1670 train_time:96419ms step_avg:94.81ms +step:1018/1670 train_time:96513ms step_avg:94.81ms +step:1019/1670 train_time:96606ms step_avg:94.80ms +step:1020/1670 train_time:96698ms step_avg:94.80ms +step:1021/1670 train_time:96791ms step_avg:94.80ms +step:1022/1670 train_time:96884ms step_avg:94.80ms +step:1023/1670 train_time:96980ms step_avg:94.80ms +step:1024/1670 train_time:97075ms step_avg:94.80ms +step:1025/1670 train_time:97169ms step_avg:94.80ms +step:1026/1670 train_time:97264ms step_avg:94.80ms +step:1027/1670 train_time:97358ms step_avg:94.80ms +step:1028/1670 train_time:97452ms step_avg:94.80ms +step:1029/1670 train_time:97545ms step_avg:94.80ms +step:1030/1670 train_time:97637ms step_avg:94.79ms +step:1031/1670 train_time:97731ms step_avg:94.79ms +step:1032/1670 train_time:97823ms step_avg:94.79ms +step:1033/1670 train_time:97917ms step_avg:94.79ms +step:1034/1670 train_time:98012ms step_avg:94.79ms +step:1035/1670 train_time:98106ms step_avg:94.79ms +step:1036/1670 train_time:98200ms step_avg:94.79ms +step:1037/1670 train_time:98295ms step_avg:94.79ms +step:1038/1670 train_time:98389ms step_avg:94.79ms +step:1039/1670 train_time:98483ms step_avg:94.79ms +step:1040/1670 train_time:98577ms step_avg:94.79ms +step:1041/1670 train_time:98670ms step_avg:94.78ms +step:1042/1670 train_time:98763ms step_avg:94.78ms +step:1043/1670 train_time:98857ms step_avg:94.78ms +step:1044/1670 train_time:98951ms 
step_avg:94.78ms +step:1045/1670 train_time:99045ms step_avg:94.78ms +step:1046/1670 train_time:99139ms step_avg:94.78ms +step:1047/1670 train_time:99233ms step_avg:94.78ms +step:1048/1670 train_time:99327ms step_avg:94.78ms +step:1049/1670 train_time:99420ms step_avg:94.78ms +step:1050/1670 train_time:99515ms step_avg:94.78ms +step:1051/1670 train_time:99609ms step_avg:94.78ms +step:1052/1670 train_time:99702ms step_avg:94.77ms +step:1053/1670 train_time:99795ms step_avg:94.77ms +step:1054/1670 train_time:99889ms step_avg:94.77ms +step:1055/1670 train_time:99983ms step_avg:94.77ms +step:1056/1670 train_time:100077ms step_avg:94.77ms +step:1057/1670 train_time:100171ms step_avg:94.77ms +step:1058/1670 train_time:100265ms step_avg:94.77ms +step:1059/1670 train_time:100360ms step_avg:94.77ms +step:1060/1670 train_time:100453ms step_avg:94.77ms +step:1061/1670 train_time:100547ms step_avg:94.77ms +step:1062/1670 train_time:100874ms step_avg:94.99ms +step:1063/1670 train_time:101070ms step_avg:95.08ms +step:1064/1670 train_time:101162ms step_avg:95.08ms +step:1065/1670 train_time:101255ms step_avg:95.07ms +step:1066/1670 train_time:101347ms step_avg:95.07ms +step:1067/1670 train_time:101440ms step_avg:95.07ms +step:1068/1670 train_time:101532ms step_avg:95.07ms +step:1069/1670 train_time:101624ms step_avg:95.06ms +step:1070/1670 train_time:101717ms step_avg:95.06ms +step:1071/1670 train_time:101809ms step_avg:95.06ms +step:1072/1670 train_time:101905ms step_avg:95.06ms +step:1073/1670 train_time:102001ms step_avg:95.06ms +step:1074/1670 train_time:102097ms step_avg:95.06ms +step:1075/1670 train_time:102191ms step_avg:95.06ms +step:1076/1670 train_time:102285ms step_avg:95.06ms +step:1077/1670 train_time:102378ms step_avg:95.06ms +step:1078/1670 train_time:102471ms step_avg:95.06ms +step:1079/1670 train_time:102564ms step_avg:95.06ms +step:1080/1670 train_time:102657ms step_avg:95.05ms +step:1081/1670 train_time:102750ms step_avg:95.05ms +step:1082/1670 train_time:102843ms step_avg:95.05ms +step:1083/1670 train_time:102938ms step_avg:95.05ms +step:1084/1670 train_time:103033ms step_avg:95.05ms +step:1085/1670 train_time:103127ms step_avg:95.05ms +step:1086/1670 train_time:103221ms step_avg:95.05ms +step:1087/1670 train_time:103315ms step_avg:95.05ms +step:1088/1670 train_time:103409ms step_avg:95.04ms +step:1089/1670 train_time:103502ms step_avg:95.04ms +step:1090/1670 train_time:103595ms step_avg:95.04ms +step:1091/1670 train_time:103688ms step_avg:95.04ms +step:1092/1670 train_time:103783ms step_avg:95.04ms +step:1093/1670 train_time:103877ms step_avg:95.04ms +step:1094/1670 train_time:103971ms step_avg:95.04ms +step:1095/1670 train_time:104067ms step_avg:95.04ms +step:1096/1670 train_time:104160ms step_avg:95.04ms +step:1097/1670 train_time:104253ms step_avg:95.03ms +step:1098/1670 train_time:104347ms step_avg:95.03ms +step:1099/1670 train_time:104441ms step_avg:95.03ms +step:1100/1670 train_time:104534ms step_avg:95.03ms +step:1101/1670 train_time:104627ms step_avg:95.03ms +step:1102/1670 train_time:104721ms step_avg:95.03ms +step:1103/1670 train_time:104814ms step_avg:95.03ms +step:1104/1670 train_time:104908ms step_avg:95.03ms +step:1105/1670 train_time:105001ms step_avg:95.02ms +step:1106/1670 train_time:105095ms step_avg:95.02ms +step:1107/1670 train_time:105189ms step_avg:95.02ms +step:1108/1670 train_time:105284ms step_avg:95.02ms +step:1109/1670 train_time:105378ms step_avg:95.02ms +step:1110/1670 train_time:105472ms step_avg:95.02ms +step:1111/1670 train_time:105565ms 
step_avg:95.02ms +step:1112/1670 train_time:105658ms step_avg:95.02ms +step:1113/1670 train_time:105751ms step_avg:95.01ms +step:1114/1670 train_time:105845ms step_avg:95.01ms +step:1115/1670 train_time:106047ms step_avg:95.11ms +step:1116/1670 train_time:106116ms step_avg:95.09ms +step:1117/1670 train_time:106210ms step_avg:95.09ms +step:1118/1670 train_time:106303ms step_avg:95.08ms +step:1119/1670 train_time:106396ms step_avg:95.08ms +step:1120/1670 train_time:106490ms step_avg:95.08ms +step:1121/1670 train_time:106584ms step_avg:95.08ms +step:1122/1670 train_time:106677ms step_avg:95.08ms +step:1123/1670 train_time:106770ms step_avg:95.08ms +step:1124/1670 train_time:106863ms step_avg:95.07ms +step:1125/1670 train_time:106964ms step_avg:95.08ms +step:1125/1670 val_loss:3.4149 train_time:107061ms step_avg:95.17ms +step:1126/1670 train_time:107086ms step_avg:95.10ms +step:1127/1670 train_time:107167ms step_avg:95.09ms +step:1128/1670 train_time:107268ms step_avg:95.10ms +step:1129/1670 train_time:107363ms step_avg:95.10ms +step:1130/1670 train_time:107457ms step_avg:95.09ms +step:1131/1670 train_time:107551ms step_avg:95.09ms +step:1132/1670 train_time:107644ms step_avg:95.09ms +step:1133/1670 train_time:107737ms step_avg:95.09ms +step:1134/1670 train_time:107831ms step_avg:95.09ms +step:1135/1670 train_time:107924ms step_avg:95.09ms +step:1136/1670 train_time:108018ms step_avg:95.09ms +step:1137/1670 train_time:108116ms step_avg:95.09ms +step:1138/1670 train_time:108213ms step_avg:95.09ms +step:1139/1670 train_time:108309ms step_avg:95.09ms +step:1140/1670 train_time:108404ms step_avg:95.09ms +step:1141/1670 train_time:108498ms step_avg:95.09ms +step:1142/1670 train_time:108593ms step_avg:95.09ms +step:1143/1670 train_time:108686ms step_avg:95.09ms +step:1144/1670 train_time:108780ms step_avg:95.09ms +step:1145/1670 train_time:108875ms step_avg:95.09ms +step:1146/1670 train_time:108967ms step_avg:95.08ms +step:1147/1670 train_time:109062ms step_avg:95.08ms +step:1148/1670 train_time:109158ms step_avg:95.09ms +step:1149/1670 train_time:109255ms step_avg:95.09ms +step:1150/1670 train_time:109350ms step_avg:95.09ms +step:1151/1670 train_time:109445ms step_avg:95.09ms +step:1152/1670 train_time:109539ms step_avg:95.09ms +step:1153/1670 train_time:109633ms step_avg:95.08ms +step:1154/1670 train_time:109727ms step_avg:95.08ms +step:1155/1670 train_time:109821ms step_avg:95.08ms +step:1156/1670 train_time:109915ms step_avg:95.08ms +step:1157/1670 train_time:110009ms step_avg:95.08ms +step:1158/1670 train_time:110104ms step_avg:95.08ms +step:1159/1670 train_time:110200ms step_avg:95.08ms +step:1160/1670 train_time:110294ms step_avg:95.08ms +step:1161/1670 train_time:110389ms step_avg:95.08ms +step:1162/1670 train_time:110483ms step_avg:95.08ms +step:1163/1670 train_time:110579ms step_avg:95.08ms +step:1164/1670 train_time:110673ms step_avg:95.08ms +step:1165/1670 train_time:110766ms step_avg:95.08ms +step:1166/1670 train_time:110861ms step_avg:95.08ms +step:1167/1670 train_time:110955ms step_avg:95.08ms +step:1168/1670 train_time:111050ms step_avg:95.08ms +step:1169/1670 train_time:111145ms step_avg:95.08ms +step:1170/1670 train_time:111239ms step_avg:95.08ms +step:1171/1670 train_time:111334ms step_avg:95.08ms +step:1172/1670 train_time:111429ms step_avg:95.08ms +step:1173/1670 train_time:111523ms step_avg:95.08ms +step:1174/1670 train_time:111618ms step_avg:95.08ms +step:1175/1670 train_time:111712ms step_avg:95.07ms +step:1176/1670 train_time:111807ms step_avg:95.07ms +step:1177/1670 
train_time:111902ms step_avg:95.07ms +step:1178/1670 train_time:111996ms step_avg:95.07ms +step:1179/1670 train_time:112092ms step_avg:95.07ms +step:1180/1670 train_time:112186ms step_avg:95.07ms +step:1181/1670 train_time:112281ms step_avg:95.07ms +step:1182/1670 train_time:112376ms step_avg:95.07ms +step:1183/1670 train_time:112471ms step_avg:95.07ms +step:1184/1670 train_time:112564ms step_avg:95.07ms +step:1185/1670 train_time:112659ms step_avg:95.07ms +step:1186/1670 train_time:112753ms step_avg:95.07ms +step:1187/1670 train_time:112848ms step_avg:95.07ms +step:1188/1670 train_time:112942ms step_avg:95.07ms +step:1189/1670 train_time:113036ms step_avg:95.07ms +step:1190/1670 train_time:113131ms step_avg:95.07ms +step:1191/1670 train_time:113226ms step_avg:95.07ms +step:1192/1670 train_time:113321ms step_avg:95.07ms +step:1193/1670 train_time:113416ms step_avg:95.07ms +step:1194/1670 train_time:113511ms step_avg:95.07ms +step:1195/1670 train_time:113605ms step_avg:95.07ms +step:1196/1670 train_time:113700ms step_avg:95.07ms +step:1197/1670 train_time:113794ms step_avg:95.07ms +step:1198/1670 train_time:113889ms step_avg:95.07ms +step:1199/1670 train_time:113985ms step_avg:95.07ms +step:1200/1670 train_time:114078ms step_avg:95.07ms +step:1201/1670 train_time:114172ms step_avg:95.06ms +step:1202/1670 train_time:114267ms step_avg:95.06ms +step:1203/1670 train_time:114363ms step_avg:95.06ms +step:1204/1670 train_time:114457ms step_avg:95.06ms +step:1205/1670 train_time:114551ms step_avg:95.06ms +step:1206/1670 train_time:114646ms step_avg:95.06ms +step:1207/1670 train_time:114740ms step_avg:95.06ms +step:1208/1670 train_time:114835ms step_avg:95.06ms +step:1209/1670 train_time:114929ms step_avg:95.06ms +step:1210/1670 train_time:115024ms step_avg:95.06ms +step:1211/1670 train_time:115119ms step_avg:95.06ms +step:1212/1670 train_time:115214ms step_avg:95.06ms +step:1213/1670 train_time:115309ms step_avg:95.06ms +step:1214/1670 train_time:115404ms step_avg:95.06ms +step:1215/1670 train_time:115499ms step_avg:95.06ms +step:1216/1670 train_time:115594ms step_avg:95.06ms +step:1217/1670 train_time:115689ms step_avg:95.06ms +step:1218/1670 train_time:115784ms step_avg:95.06ms +step:1219/1670 train_time:115878ms step_avg:95.06ms +step:1220/1670 train_time:115973ms step_avg:95.06ms +step:1221/1670 train_time:116067ms step_avg:95.06ms +step:1222/1670 train_time:116161ms step_avg:95.06ms +step:1223/1670 train_time:116256ms step_avg:95.06ms +step:1224/1670 train_time:116352ms step_avg:95.06ms +step:1225/1670 train_time:116447ms step_avg:95.06ms +step:1226/1670 train_time:116541ms step_avg:95.06ms +step:1227/1670 train_time:116635ms step_avg:95.06ms +step:1228/1670 train_time:116730ms step_avg:95.06ms +step:1229/1670 train_time:116824ms step_avg:95.06ms +step:1230/1670 train_time:116918ms step_avg:95.06ms +step:1231/1670 train_time:117012ms step_avg:95.05ms +step:1232/1670 train_time:117107ms step_avg:95.05ms +step:1233/1670 train_time:117202ms step_avg:95.05ms +step:1234/1670 train_time:117297ms step_avg:95.05ms +step:1235/1670 train_time:117392ms step_avg:95.05ms +step:1236/1670 train_time:117486ms step_avg:95.05ms +step:1237/1670 train_time:117581ms step_avg:95.05ms +step:1238/1670 train_time:117675ms step_avg:95.05ms +step:1239/1670 train_time:117770ms step_avg:95.05ms +step:1240/1670 train_time:117864ms step_avg:95.05ms +step:1241/1670 train_time:117959ms step_avg:95.05ms +step:1242/1670 train_time:118053ms step_avg:95.05ms +step:1243/1670 train_time:118152ms step_avg:95.05ms +step:1244/1670 
train_time:118243ms step_avg:95.05ms +step:1245/1670 train_time:118336ms step_avg:95.05ms +step:1246/1670 train_time:118431ms step_avg:95.05ms +step:1247/1670 train_time:118525ms step_avg:95.05ms +step:1248/1670 train_time:118620ms step_avg:95.05ms +step:1249/1670 train_time:118714ms step_avg:95.05ms +step:1250/1670 train_time:118808ms step_avg:95.05ms +step:1250/1670 val_loss:3.3758 train_time:118903ms step_avg:95.12ms +step:1251/1670 train_time:118928ms step_avg:95.07ms +step:1252/1670 train_time:119009ms step_avg:95.06ms +step:1253/1670 train_time:119111ms step_avg:95.06ms +step:1254/1670 train_time:119205ms step_avg:95.06ms +step:1255/1670 train_time:119299ms step_avg:95.06ms +step:1256/1670 train_time:119392ms step_avg:95.06ms +step:1257/1670 train_time:119486ms step_avg:95.06ms +step:1258/1670 train_time:119579ms step_avg:95.05ms +step:1259/1670 train_time:119672ms step_avg:95.05ms +step:1260/1670 train_time:119766ms step_avg:95.05ms +step:1261/1670 train_time:119859ms step_avg:95.05ms +step:1262/1670 train_time:119956ms step_avg:95.05ms +step:1263/1670 train_time:120053ms step_avg:95.05ms +step:1264/1670 train_time:120149ms step_avg:95.05ms +step:1265/1670 train_time:120244ms step_avg:95.05ms +step:1266/1670 train_time:120338ms step_avg:95.05ms +step:1267/1670 train_time:120432ms step_avg:95.05ms +step:1268/1670 train_time:120525ms step_avg:95.05ms +step:1269/1670 train_time:120619ms step_avg:95.05ms +step:1270/1670 train_time:120713ms step_avg:95.05ms +step:1271/1670 train_time:120807ms step_avg:95.05ms +step:1272/1670 train_time:120902ms step_avg:95.05ms +step:1273/1670 train_time:120999ms step_avg:95.05ms +step:1274/1670 train_time:121447ms step_avg:95.33ms +step:1275/1670 train_time:121521ms step_avg:95.31ms +step:1276/1670 train_time:121614ms step_avg:95.31ms +step:1277/1670 train_time:121707ms step_avg:95.31ms +step:1278/1670 train_time:121800ms step_avg:95.31ms +step:1279/1670 train_time:121894ms step_avg:95.30ms +step:1280/1670 train_time:121987ms step_avg:95.30ms +step:1281/1670 train_time:122080ms step_avg:95.30ms +step:1282/1670 train_time:122174ms step_avg:95.30ms +step:1283/1670 train_time:122268ms step_avg:95.30ms +step:1284/1670 train_time:122366ms step_avg:95.30ms +step:1285/1670 train_time:122464ms step_avg:95.30ms +step:1286/1670 train_time:122559ms step_avg:95.30ms +step:1287/1670 train_time:122653ms step_avg:95.30ms +step:1288/1670 train_time:122747ms step_avg:95.30ms +step:1289/1670 train_time:122841ms step_avg:95.30ms +step:1290/1670 train_time:122935ms step_avg:95.30ms +step:1291/1670 train_time:123028ms step_avg:95.30ms +step:1292/1670 train_time:123122ms step_avg:95.30ms +step:1293/1670 train_time:123216ms step_avg:95.29ms +step:1294/1670 train_time:123311ms step_avg:95.29ms +step:1295/1670 train_time:123407ms step_avg:95.30ms +step:1296/1670 train_time:123503ms step_avg:95.30ms +step:1297/1670 train_time:123598ms step_avg:95.30ms +step:1298/1670 train_time:123693ms step_avg:95.29ms +step:1299/1670 train_time:123788ms step_avg:95.29ms +step:1300/1670 train_time:123882ms step_avg:95.29ms +step:1301/1670 train_time:123977ms step_avg:95.29ms +step:1302/1670 train_time:124071ms step_avg:95.29ms +step:1303/1670 train_time:124165ms step_avg:95.29ms +step:1304/1670 train_time:124259ms step_avg:95.29ms +step:1305/1670 train_time:124354ms step_avg:95.29ms +step:1306/1670 train_time:124449ms step_avg:95.29ms +step:1307/1670 train_time:124545ms step_avg:95.29ms +step:1308/1670 train_time:124640ms step_avg:95.29ms +step:1309/1670 train_time:124734ms step_avg:95.29ms 
+step:1310/1670 train_time:124829ms step_avg:95.29ms +step:1311/1670 train_time:124923ms step_avg:95.29ms +step:1312/1670 train_time:125017ms step_avg:95.29ms +step:1313/1670 train_time:125111ms step_avg:95.29ms +step:1314/1670 train_time:125206ms step_avg:95.29ms +step:1315/1670 train_time:125300ms step_avg:95.29ms +step:1316/1670 train_time:125395ms step_avg:95.28ms +step:1317/1670 train_time:125491ms step_avg:95.29ms +step:1318/1670 train_time:125586ms step_avg:95.29ms +step:1319/1670 train_time:125681ms step_avg:95.29ms +step:1320/1670 train_time:125775ms step_avg:95.28ms +step:1321/1670 train_time:125869ms step_avg:95.28ms +step:1322/1670 train_time:125963ms step_avg:95.28ms +step:1323/1670 train_time:126057ms step_avg:95.28ms +step:1324/1670 train_time:126151ms step_avg:95.28ms +step:1325/1670 train_time:126246ms step_avg:95.28ms +step:1326/1670 train_time:126341ms step_avg:95.28ms +step:1327/1670 train_time:126435ms step_avg:95.28ms +step:1328/1670 train_time:126531ms step_avg:95.28ms +step:1329/1670 train_time:126626ms step_avg:95.28ms +step:1330/1670 train_time:126721ms step_avg:95.28ms +step:1331/1670 train_time:126816ms step_avg:95.28ms +step:1332/1670 train_time:126910ms step_avg:95.28ms +step:1333/1670 train_time:127004ms step_avg:95.28ms +step:1334/1670 train_time:127098ms step_avg:95.28ms +step:1335/1670 train_time:127193ms step_avg:95.28ms +step:1336/1670 train_time:127287ms step_avg:95.28ms +step:1337/1670 train_time:127382ms step_avg:95.27ms +step:1338/1670 train_time:127477ms step_avg:95.27ms +step:1339/1670 train_time:127571ms step_avg:95.27ms +step:1340/1670 train_time:127666ms step_avg:95.27ms +step:1341/1670 train_time:127761ms step_avg:95.27ms +step:1342/1670 train_time:127856ms step_avg:95.27ms +step:1343/1670 train_time:127951ms step_avg:95.27ms +step:1344/1670 train_time:128046ms step_avg:95.27ms +step:1345/1670 train_time:128140ms step_avg:95.27ms +step:1346/1670 train_time:128235ms step_avg:95.27ms +step:1347/1670 train_time:128330ms step_avg:95.27ms +step:1348/1670 train_time:128423ms step_avg:95.27ms +step:1349/1670 train_time:128518ms step_avg:95.27ms +step:1350/1670 train_time:128613ms step_avg:95.27ms +step:1351/1670 train_time:128709ms step_avg:95.27ms +step:1352/1670 train_time:128803ms step_avg:95.27ms +step:1353/1670 train_time:128898ms step_avg:95.27ms +step:1354/1670 train_time:128993ms step_avg:95.27ms +step:1355/1670 train_time:129088ms step_avg:95.27ms +step:1356/1670 train_time:129183ms step_avg:95.27ms +step:1357/1670 train_time:129277ms step_avg:95.27ms +step:1358/1670 train_time:129371ms step_avg:95.27ms +step:1359/1670 train_time:129465ms step_avg:95.26ms +step:1360/1670 train_time:129559ms step_avg:95.26ms +step:1361/1670 train_time:129654ms step_avg:95.26ms +step:1362/1670 train_time:129749ms step_avg:95.26ms +step:1363/1670 train_time:129843ms step_avg:95.26ms +step:1364/1670 train_time:129937ms step_avg:95.26ms +step:1365/1670 train_time:130031ms step_avg:95.26ms +step:1366/1670 train_time:130126ms step_avg:95.26ms +step:1367/1670 train_time:130220ms step_avg:95.26ms +step:1368/1670 train_time:130315ms step_avg:95.26ms +step:1369/1670 train_time:130410ms step_avg:95.26ms +step:1370/1670 train_time:130504ms step_avg:95.26ms +step:1371/1670 train_time:130600ms step_avg:95.26ms +step:1372/1670 train_time:130695ms step_avg:95.26ms +step:1373/1670 train_time:130790ms step_avg:95.26ms +step:1374/1670 train_time:130885ms step_avg:95.26ms +step:1375/1670 train_time:130979ms step_avg:95.26ms +step:1375/1670 val_loss:3.3415 train_time:131073ms 
step_avg:95.33ms +step:1376/1670 train_time:131098ms step_avg:95.27ms +step:1377/1670 train_time:131176ms step_avg:95.26ms +step:1378/1670 train_time:131278ms step_avg:95.27ms +step:1379/1670 train_time:131373ms step_avg:95.27ms +step:1380/1670 train_time:131467ms step_avg:95.27ms +step:1381/1670 train_time:131560ms step_avg:95.26ms +step:1382/1670 train_time:131653ms step_avg:95.26ms +step:1383/1670 train_time:131747ms step_avg:95.26ms +step:1384/1670 train_time:131841ms step_avg:95.26ms +step:1385/1670 train_time:131935ms step_avg:95.26ms +step:1386/1670 train_time:132029ms step_avg:95.26ms +step:1387/1670 train_time:132125ms step_avg:95.26ms +step:1388/1670 train_time:132224ms step_avg:95.26ms +step:1389/1670 train_time:132322ms step_avg:95.26ms +step:1390/1670 train_time:132417ms step_avg:95.26ms +step:1391/1670 train_time:132511ms step_avg:95.26ms +step:1392/1670 train_time:132604ms step_avg:95.26ms +step:1393/1670 train_time:132698ms step_avg:95.26ms +step:1394/1670 train_time:132793ms step_avg:95.26ms +step:1395/1670 train_time:132886ms step_avg:95.26ms +step:1396/1670 train_time:132979ms step_avg:95.26ms +step:1397/1670 train_time:133076ms step_avg:95.26ms +step:1398/1670 train_time:133172ms step_avg:95.26ms +step:1399/1670 train_time:133268ms step_avg:95.26ms +step:1400/1670 train_time:133364ms step_avg:95.26ms +step:1401/1670 train_time:133458ms step_avg:95.26ms +step:1402/1670 train_time:133553ms step_avg:95.26ms +step:1403/1670 train_time:133647ms step_avg:95.26ms +step:1404/1670 train_time:133742ms step_avg:95.26ms +step:1405/1670 train_time:133837ms step_avg:95.26ms +step:1406/1670 train_time:133930ms step_avg:95.26ms +step:1407/1670 train_time:134023ms step_avg:95.25ms +step:1408/1670 train_time:134118ms step_avg:95.25ms +step:1409/1670 train_time:134214ms step_avg:95.26ms +step:1410/1670 train_time:134310ms step_avg:95.26ms +step:1411/1670 train_time:134405ms step_avg:95.25ms +step:1412/1670 train_time:134499ms step_avg:95.25ms +step:1413/1670 train_time:134594ms step_avg:95.25ms +step:1414/1670 train_time:134689ms step_avg:95.25ms +step:1415/1670 train_time:134783ms step_avg:95.25ms +step:1416/1670 train_time:134878ms step_avg:95.25ms +step:1417/1670 train_time:134972ms step_avg:95.25ms +step:1418/1670 train_time:135066ms step_avg:95.25ms +step:1419/1670 train_time:135161ms step_avg:95.25ms +step:1420/1670 train_time:135257ms step_avg:95.25ms +step:1421/1670 train_time:135353ms step_avg:95.25ms +step:1422/1670 train_time:135447ms step_avg:95.25ms +step:1423/1670 train_time:135542ms step_avg:95.25ms +step:1424/1670 train_time:135636ms step_avg:95.25ms +step:1425/1670 train_time:135730ms step_avg:95.25ms +step:1426/1670 train_time:135824ms step_avg:95.25ms +step:1427/1670 train_time:135918ms step_avg:95.25ms +step:1428/1670 train_time:136013ms step_avg:95.25ms +step:1429/1670 train_time:136107ms step_avg:95.25ms +step:1430/1670 train_time:136202ms step_avg:95.25ms +step:1431/1670 train_time:136297ms step_avg:95.25ms +step:1432/1670 train_time:136393ms step_avg:95.25ms +step:1433/1670 train_time:136486ms step_avg:95.25ms +step:1434/1670 train_time:136580ms step_avg:95.24ms +step:1435/1670 train_time:136676ms step_avg:95.24ms +step:1436/1670 train_time:136770ms step_avg:95.24ms +step:1437/1670 train_time:136865ms step_avg:95.24ms +step:1438/1670 train_time:136958ms step_avg:95.24ms +step:1439/1670 train_time:137053ms step_avg:95.24ms +step:1440/1670 train_time:137148ms step_avg:95.24ms +step:1441/1670 train_time:137243ms step_avg:95.24ms +step:1442/1670 train_time:137338ms 
step_avg:95.24ms +step:1443/1670 train_time:137432ms step_avg:95.24ms +step:1444/1670 train_time:137526ms step_avg:95.24ms +step:1445/1670 train_time:137622ms step_avg:95.24ms +step:1446/1670 train_time:137717ms step_avg:95.24ms +step:1447/1670 train_time:137811ms step_avg:95.24ms +step:1448/1670 train_time:137905ms step_avg:95.24ms +step:1449/1670 train_time:138000ms step_avg:95.24ms +step:1450/1670 train_time:138096ms step_avg:95.24ms +step:1451/1670 train_time:138189ms step_avg:95.24ms +step:1452/1670 train_time:138285ms step_avg:95.24ms +step:1453/1670 train_time:138380ms step_avg:95.24ms +step:1454/1670 train_time:138474ms step_avg:95.24ms +step:1455/1670 train_time:138569ms step_avg:95.24ms +step:1456/1670 train_time:138664ms step_avg:95.24ms +step:1457/1670 train_time:138758ms step_avg:95.24ms +step:1458/1670 train_time:138853ms step_avg:95.24ms +step:1459/1670 train_time:138947ms step_avg:95.23ms +step:1460/1670 train_time:139042ms step_avg:95.23ms +step:1461/1670 train_time:139137ms step_avg:95.23ms +step:1462/1670 train_time:139232ms step_avg:95.23ms +step:1463/1670 train_time:139327ms step_avg:95.23ms +step:1464/1670 train_time:139422ms step_avg:95.23ms +step:1465/1670 train_time:139517ms step_avg:95.23ms +step:1466/1670 train_time:139612ms step_avg:95.23ms +step:1467/1670 train_time:139706ms step_avg:95.23ms +step:1468/1670 train_time:139801ms step_avg:95.23ms +step:1469/1670 train_time:139896ms step_avg:95.23ms +step:1470/1670 train_time:139990ms step_avg:95.23ms +step:1471/1670 train_time:140084ms step_avg:95.23ms +step:1472/1670 train_time:140178ms step_avg:95.23ms +step:1473/1670 train_time:140274ms step_avg:95.23ms +step:1474/1670 train_time:140368ms step_avg:95.23ms +step:1475/1670 train_time:140464ms step_avg:95.23ms +step:1476/1670 train_time:140559ms step_avg:95.23ms +step:1477/1670 train_time:140653ms step_avg:95.23ms +step:1478/1670 train_time:140747ms step_avg:95.23ms +step:1479/1670 train_time:140842ms step_avg:95.23ms +step:1480/1670 train_time:140936ms step_avg:95.23ms +step:1481/1670 train_time:141031ms step_avg:95.23ms +step:1482/1670 train_time:141126ms step_avg:95.23ms +step:1483/1670 train_time:141221ms step_avg:95.23ms +step:1484/1670 train_time:141315ms step_avg:95.23ms +step:1485/1670 train_time:141758ms step_avg:95.46ms +step:1486/1670 train_time:141835ms step_avg:95.45ms +step:1487/1670 train_time:141927ms step_avg:95.45ms +step:1488/1670 train_time:142020ms step_avg:95.44ms +step:1489/1670 train_time:142113ms step_avg:95.44ms +step:1490/1670 train_time:142207ms step_avg:95.44ms +step:1491/1670 train_time:142300ms step_avg:95.44ms +step:1492/1670 train_time:142393ms step_avg:95.44ms +step:1493/1670 train_time:142487ms step_avg:95.44ms +step:1494/1670 train_time:142580ms step_avg:95.44ms +step:1495/1670 train_time:142677ms step_avg:95.44ms +step:1496/1670 train_time:142775ms step_avg:95.44ms +step:1497/1670 train_time:142873ms step_avg:95.44ms +step:1498/1670 train_time:142968ms step_avg:95.44ms +step:1499/1670 train_time:143062ms step_avg:95.44ms +step:1500/1670 train_time:143156ms step_avg:95.44ms +step:1500/1670 val_loss:3.3117 train_time:143248ms step_avg:95.50ms +step:1501/1670 train_time:143273ms step_avg:95.45ms +step:1502/1670 train_time:143353ms step_avg:95.44ms +step:1503/1670 train_time:143453ms step_avg:95.44ms +step:1504/1670 train_time:143549ms step_avg:95.44ms +step:1505/1670 train_time:143642ms step_avg:95.44ms +step:1506/1670 train_time:143736ms step_avg:95.44ms +step:1507/1670 train_time:143829ms step_avg:95.44ms +step:1508/1670 
train_time:143922ms step_avg:95.44ms +step:1509/1670 train_time:144016ms step_avg:95.44ms +step:1510/1670 train_time:144109ms step_avg:95.44ms +step:1511/1670 train_time:144203ms step_avg:95.44ms +step:1512/1670 train_time:144302ms step_avg:95.44ms +step:1513/1670 train_time:144399ms step_avg:95.44ms +step:1514/1670 train_time:144494ms step_avg:95.44ms +step:1515/1670 train_time:144589ms step_avg:95.44ms +step:1516/1670 train_time:144682ms step_avg:95.44ms +step:1517/1670 train_time:144776ms step_avg:95.44ms +step:1518/1670 train_time:144869ms step_avg:95.43ms +step:1519/1670 train_time:144963ms step_avg:95.43ms +step:1520/1670 train_time:145056ms step_avg:95.43ms +step:1521/1670 train_time:145150ms step_avg:95.43ms +step:1522/1670 train_time:145245ms step_avg:95.43ms +step:1523/1670 train_time:145342ms step_avg:95.43ms +step:1524/1670 train_time:145437ms step_avg:95.43ms +step:1525/1670 train_time:145533ms step_avg:95.43ms +step:1526/1670 train_time:145628ms step_avg:95.43ms +step:1527/1670 train_time:145721ms step_avg:95.43ms +step:1528/1670 train_time:145815ms step_avg:95.43ms +step:1529/1670 train_time:145909ms step_avg:95.43ms +step:1530/1670 train_time:146003ms step_avg:95.43ms +step:1531/1670 train_time:146096ms step_avg:95.43ms +step:1532/1670 train_time:146191ms step_avg:95.42ms +step:1533/1670 train_time:146286ms step_avg:95.42ms +step:1534/1670 train_time:146382ms step_avg:95.42ms +step:1535/1670 train_time:146478ms step_avg:95.43ms +step:1536/1670 train_time:146573ms step_avg:95.42ms +step:1537/1670 train_time:146667ms step_avg:95.42ms +step:1538/1670 train_time:146762ms step_avg:95.42ms +step:1539/1670 train_time:146855ms step_avg:95.42ms +step:1540/1670 train_time:146950ms step_avg:95.42ms +step:1541/1670 train_time:147045ms step_avg:95.42ms +step:1542/1670 train_time:147139ms step_avg:95.42ms +step:1543/1670 train_time:147234ms step_avg:95.42ms +step:1544/1670 train_time:147329ms step_avg:95.42ms +step:1545/1670 train_time:147424ms step_avg:95.42ms +step:1546/1670 train_time:147519ms step_avg:95.42ms +step:1547/1670 train_time:147615ms step_avg:95.42ms +step:1548/1670 train_time:147709ms step_avg:95.42ms +step:1549/1670 train_time:147803ms step_avg:95.42ms +step:1550/1670 train_time:147899ms step_avg:95.42ms +step:1551/1670 train_time:147994ms step_avg:95.42ms +step:1552/1670 train_time:148088ms step_avg:95.42ms +step:1553/1670 train_time:148181ms step_avg:95.42ms +step:1554/1670 train_time:148276ms step_avg:95.42ms +step:1555/1670 train_time:148370ms step_avg:95.41ms +step:1556/1670 train_time:148465ms step_avg:95.41ms +step:1557/1670 train_time:148560ms step_avg:95.41ms +step:1558/1670 train_time:148655ms step_avg:95.41ms +step:1559/1670 train_time:148749ms step_avg:95.41ms +step:1560/1670 train_time:148845ms step_avg:95.41ms +step:1561/1670 train_time:148940ms step_avg:95.41ms +step:1562/1670 train_time:149034ms step_avg:95.41ms +step:1563/1670 train_time:149128ms step_avg:95.41ms +step:1564/1670 train_time:149222ms step_avg:95.41ms +step:1565/1670 train_time:149318ms step_avg:95.41ms +step:1566/1670 train_time:149413ms step_avg:95.41ms +step:1567/1670 train_time:149507ms step_avg:95.41ms +step:1568/1670 train_time:149601ms step_avg:95.41ms +step:1569/1670 train_time:149698ms step_avg:95.41ms +step:1570/1670 train_time:149793ms step_avg:95.41ms +step:1571/1670 train_time:149887ms step_avg:95.41ms +step:1572/1670 train_time:149982ms step_avg:95.41ms +step:1573/1670 train_time:150076ms step_avg:95.41ms +step:1574/1670 train_time:150170ms step_avg:95.41ms +step:1575/1670 
train_time:150264ms step_avg:95.41ms +step:1576/1670 train_time:150359ms step_avg:95.41ms +step:1577/1670 train_time:150454ms step_avg:95.41ms +step:1578/1670 train_time:150549ms step_avg:95.40ms +step:1579/1670 train_time:150644ms step_avg:95.40ms +step:1580/1670 train_time:150738ms step_avg:95.40ms +step:1581/1670 train_time:150833ms step_avg:95.40ms +step:1582/1670 train_time:150928ms step_avg:95.40ms +step:1583/1670 train_time:151021ms step_avg:95.40ms +step:1584/1670 train_time:151116ms step_avg:95.40ms +step:1585/1670 train_time:151210ms step_avg:95.40ms +step:1586/1670 train_time:151304ms step_avg:95.40ms +step:1587/1670 train_time:151399ms step_avg:95.40ms +step:1588/1670 train_time:151494ms step_avg:95.40ms +step:1589/1670 train_time:151589ms step_avg:95.40ms +step:1590/1670 train_time:151684ms step_avg:95.40ms +step:1591/1670 train_time:151779ms step_avg:95.40ms +step:1592/1670 train_time:151873ms step_avg:95.40ms +step:1593/1670 train_time:151969ms step_avg:95.40ms +step:1594/1670 train_time:152063ms step_avg:95.40ms +step:1595/1670 train_time:152157ms step_avg:95.40ms +step:1596/1670 train_time:152251ms step_avg:95.40ms +step:1597/1670 train_time:152347ms step_avg:95.40ms +step:1598/1670 train_time:152478ms step_avg:95.42ms +step:1599/1670 train_time:152537ms step_avg:95.40ms +step:1600/1670 train_time:152631ms step_avg:95.39ms +step:1601/1670 train_time:152727ms step_avg:95.39ms +step:1602/1670 train_time:152821ms step_avg:95.39ms +step:1603/1670 train_time:152916ms step_avg:95.39ms +step:1604/1670 train_time:153011ms step_avg:95.39ms +step:1605/1670 train_time:153106ms step_avg:95.39ms +step:1606/1670 train_time:153200ms step_avg:95.39ms +step:1607/1670 train_time:153295ms step_avg:95.39ms +step:1608/1670 train_time:153390ms step_avg:95.39ms +step:1609/1670 train_time:153485ms step_avg:95.39ms +step:1610/1670 train_time:153580ms step_avg:95.39ms +step:1611/1670 train_time:153675ms step_avg:95.39ms +step:1612/1670 train_time:153769ms step_avg:95.39ms +step:1613/1670 train_time:153863ms step_avg:95.39ms +step:1614/1670 train_time:153959ms step_avg:95.39ms +step:1615/1670 train_time:154053ms step_avg:95.39ms +step:1616/1670 train_time:154148ms step_avg:95.39ms +step:1617/1670 train_time:154243ms step_avg:95.39ms +step:1618/1670 train_time:154338ms step_avg:95.39ms +step:1619/1670 train_time:154433ms step_avg:95.39ms +step:1620/1670 train_time:154528ms step_avg:95.39ms +step:1621/1670 train_time:154622ms step_avg:95.39ms +step:1622/1670 train_time:154717ms step_avg:95.39ms +step:1623/1670 train_time:154812ms step_avg:95.39ms +step:1624/1670 train_time:154907ms step_avg:95.39ms +step:1625/1670 train_time:155001ms step_avg:95.39ms +step:1625/1670 val_loss:3.2869 train_time:155094ms step_avg:95.44ms +step:1626/1670 train_time:155119ms step_avg:95.40ms +step:1627/1670 train_time:155195ms step_avg:95.39ms +step:1628/1670 train_time:155298ms step_avg:95.39ms +step:1629/1670 train_time:155393ms step_avg:95.39ms +step:1630/1670 train_time:155487ms step_avg:95.39ms +step:1631/1670 train_time:155581ms step_avg:95.39ms +step:1632/1670 train_time:155674ms step_avg:95.39ms +step:1633/1670 train_time:155767ms step_avg:95.39ms +step:1634/1670 train_time:155861ms step_avg:95.39ms +step:1635/1670 train_time:155955ms step_avg:95.39ms +step:1636/1670 train_time:156049ms step_avg:95.38ms +step:1637/1670 train_time:156144ms step_avg:95.38ms +step:1638/1670 train_time:156241ms step_avg:95.38ms +step:1639/1670 train_time:156337ms step_avg:95.39ms +step:1640/1670 train_time:156432ms step_avg:95.39ms 
+step:1641/1670 train_time:156527ms step_avg:95.39ms +step:1642/1670 train_time:156620ms step_avg:95.38ms +step:1643/1670 train_time:156714ms step_avg:95.38ms +step:1644/1670 train_time:156808ms step_avg:95.38ms +step:1645/1670 train_time:156901ms step_avg:95.38ms +step:1646/1670 train_time:156995ms step_avg:95.38ms +step:1647/1670 train_time:157090ms step_avg:95.38ms +step:1648/1670 train_time:157186ms step_avg:95.38ms +step:1649/1670 train_time:157282ms step_avg:95.38ms +step:1650/1670 train_time:157378ms step_avg:95.38ms +step:1651/1670 train_time:157473ms step_avg:95.38ms +step:1652/1670 train_time:157568ms step_avg:95.38ms +step:1653/1670 train_time:157661ms step_avg:95.38ms +step:1654/1670 train_time:157755ms step_avg:95.38ms +step:1655/1670 train_time:157850ms step_avg:95.38ms +step:1656/1670 train_time:157945ms step_avg:95.38ms +step:1657/1670 train_time:158039ms step_avg:95.38ms +step:1658/1670 train_time:158133ms step_avg:95.38ms +step:1659/1670 train_time:158229ms step_avg:95.38ms +step:1660/1670 train_time:158325ms step_avg:95.38ms +step:1661/1670 train_time:158421ms step_avg:95.38ms +step:1662/1670 train_time:158515ms step_avg:95.38ms +step:1663/1670 train_time:158610ms step_avg:95.38ms +step:1664/1670 train_time:158704ms step_avg:95.37ms +step:1665/1670 train_time:158798ms step_avg:95.37ms +step:1666/1670 train_time:158892ms step_avg:95.37ms +step:1667/1670 train_time:158987ms step_avg:95.37ms +step:1668/1670 train_time:159081ms step_avg:95.37ms +step:1669/1670 train_time:159175ms step_avg:95.37ms +step:1670/1670 train_time:159269ms step_avg:95.37ms +step:1670/1670 val_loss:3.2779 train_time:159447ms step_avg:95.48ms +peak memory allocated: 32712 MiB reserved: 48456 MiB diff --git a/train_gpt.py b/train_gpt.py index 257dfe5ef..6c53f0815 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -6,8 +6,11 @@ import time import copy import glob +import math + from dataclasses import dataclass from functools import lru_cache +from itertools import accumulate from pathlib import Path os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -556,27 +559,26 @@ def forward(self, x: Tensor): else: return F.linear(x, self.weight.type_as(x)) -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: 
torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + def __init__(self, dim: int, head_dim: int, num_heads: int): super().__init__() self.num_heads = num_heads self.head_dim = head_dim @@ -590,36 +592,35 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): with torch.no_grad(): self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate = CastedLinear(12, num_heads) self.attn_gate.weight.detach().zero_() - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + def forward(self, x: Tensor, attn_args: AttnArgs): B, T = x.size(0), x.size(1) # batch size, sequence length assert B == 1, "varlen sequences requires B == 1" assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) + q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v + v = sa_lambdas[0] * v max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, - causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side y = F.linear(y, self.qkvo_w[3].type_as(y)) return y @@ -644,20 +645,18 @@ def forward(self, x: Tensor): x = F.linear(x, self.c_proj.type_as(x)) return x - class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): super().__init__() # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - SKIPPED_MLP_BLOCKS = [0] # skip MLP blocks for first MLP layer by @EmelyanenkoK - self.mlp = None if layer_idx in SKIPPED_MLP_BLOCKS else MLP(dim) + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, - seqlens: Tensor, bm_size: int): + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): x = lambdas[0] * x + lambdas[1] * x0 if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.attn(norm(x), attn_args) if self.mlp is not None: x = x + self.mlp(norm(x)) return x @@ -669,14 +668,14 @@ def next_multiple_of_n(v: float | int, *, n: int): return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): super().__init__() vocab_size = next_multiple_of_n(vocab_size, n=128) self.embed = nn.Embedding(vocab_size, model_dim) # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
use_fp8 = not os.environ.get("DISABLE_FP8", False) @@ -691,6 +690,8 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas torch.ones(pad), ])) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) # set learning rates for param in self.embed.parameters(): param.lr_mul = 75. @@ -699,6 +700,33 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: self.lm_head.weight.lr_mul = 1.0 self.scalars.lr_mul = 5.0 + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) + self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) + scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): assert input_seq.ndim == 1 @@ -723,9 +751,18 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in n = len(self.blocks) // 2 for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws] + ) if i >= n: x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + x = self.blocks[i](x, x0, lambdas[i], attn_args) if i < n: skip_connections.append(x) @@ -868,15 +905,16 @@ class Hyperparameters: train_max_seq_len: int = 128 * 16 val_batch_size: int = 4 * 64 * 1024 * 8 # optimization - num_iterations: int = 1705 # number of iterations to run - cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate # evaluation and logging - run_id: str = str(uuid.uuid4()) + run_id: str = f"yarn/{uuid.uuid4()}" val_loss_every: int = 125 # every how many steps to 
evaluate val loss? 0 for only at the end save_checkpoint: bool = False # attention masking block_size: int = 128 ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd args = Hyperparameters() @@ -928,6 +966,7 @@ def nvidia_smi(): vocab_size=50257, num_layers=12, num_heads=6, + head_dim=128, model_dim=768, max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) ).cuda() @@ -964,6 +1003,8 @@ def get_lr(step: int): return lr def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate x = step / (1 + args.num_iterations) assert 0 <= x < 1 ws_idx = int(len(args.ws_schedule) * x) @@ -1003,9 +1044,13 @@ def get_ws(step: int): t0 = time.perf_counter() # begin training train_steps = args.num_iterations +ws = get_ws(0) for step in range(train_steps + 1): last_step = (step == train_steps) - ws = get_ws(step) + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws=new_ws # --------------- VALIDATION SECTION ----------------- if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): From 4bfa3a585e1bd4f1a344ed33f0c67bb3d99eacf3 Mon Sep 17 00:00:00 2001 From: Larry Dial <42926649+ClassicLarry@users.noreply.github.com> Date: Wed, 10 Sep 2025 23:59:00 -0700 Subject: [PATCH 08/14] Update ReadMe.md --- records/091025_Yarn/ReadMe.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/records/091025_Yarn/ReadMe.md b/records/091025_Yarn/ReadMe.md index 60db04692..745ee1742 100644 --- a/records/091025_Yarn/ReadMe.md +++ b/records/091025_Yarn/ReadMe.md @@ -19,7 +19,7 @@ self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) Based on empirical testing, the 0.1 constant in 0.1*log(curr/prev)+1 formula from YaRN is updated to 0.2. The constant attn_scale of 0.12 is updated to a starting value of 0.1, such that the distribution over training has a similar mean, ranging between 0.1 and 0.14. - +image ``` # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @@ -30,7 +30,8 @@ attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * f self.attn_scales = dict(zip(windows, attn_scales)) ``` -YaRN has a straighforward implementation, shown below. alpha and beta are left at the default constants of 1 and 32, based on the original YaRN paper which was tuned for Llama. +YaRN has a straighforward implementation, shown below. alpha and beta are left at the default constants of 1 and 32, based on the original YaRN paper which was tuned for Llama. The frequency update incurred by YaRN is most notable from ws 3->7 and dimensions 5 to 10. +image ``` def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) @@ -42,10 +43,11 @@ def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=3 self.rotary_cos.copy_(theta.cos()) self.rotary_sin.copy_(theta.sin()) ``` -The frequency update incurred by YaRN is most notable from ws 3->7 and dimensions 5 to 10. + Arg ws_validate enables the model to be validated at a longer attention window than training. This arg is set to 13, which differs from the final training window size of 11. +image ``` def get_ws(step: int): @@ -60,6 +62,7 @@ def get_ws(step: int): Attention args are batched to improve readablility. 
cooldown_frac is increased from 0.45 to 0.5 to compliment the reduction from 1705 to 1670 steps, following the heuristic of a fixed number of cooldown steps. Dropping below 1695 steps has a secondary benefit of eliminating the 9th file read, saving roughly 200ms. Without YaRN, there is a substantial spike in validation loss when the attention window is abrubtly increased from 3 to 7. +image Extending the final validation window out shows roughly a 0.0015 improvement in loss for 11->13. Interestingly, odd increments perform substantially better. @varunneal has noted that "One thing to note is that floor division (ws_short = ws_long // 2) has different behavior for odd vs short window sizes. I generally found odd window sizes performed surprisingly better." The attention schedule follows (long/short) (3/1) -> (7/3) -> (11/5). It may be that the short attention window performs better when it is under 50% of the long window, or it may be that the model learns to fit the long/short ratio, and performs poorly when this ratio is substantially altered, or there may be a completely different explanation. @@ -67,7 +70,8 @@ Ablations were ran to measure the impact of each change: * new_record * no_attn_scale. Keep constant attn scale of 0.12. * no_freq_scale. Keep constant rotary freq based on 1024^(0..1). -* prior_record. Updated steps from 1705 to 1670. +* prior_record. Prior record with updated steps from 1705 to 1670 and cooldown frac to 0.5. +image Future Considerations: @@ -91,4 +95,4 @@ print('acc:',torch.std_mean(torch.tensor(accs))) print('time:',torch.std_mean(torch.tensor(times))) # time: (tensor(0.1897), tensor(159.3333)) -``` \ No newline at end of file +``` From eae03e69dda1866eb7dfb3285172945c592b723f Mon Sep 17 00:00:00 2001 From: larry dial Date: Thu, 11 Sep 2025 00:29:44 -0700 Subject: [PATCH 09/14] . --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index 6c53f0815..98f8df08c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -908,7 +908,7 @@ class Hyperparameters: num_iterations: int = 1670 # number of iterations to run cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate # evaluation and logging - run_id: str = f"yarn/{uuid.uuid4()}" + run_id: str = f"{uuid.uuid4()}" val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end save_checkpoint: bool = False # attention masking From a96fbbd4309fcf6427f9191bff54ed2d460c9565 Mon Sep 17 00:00:00 2001 From: Bernardino Date: Thu, 11 Sep 2025 20:12:49 +0200 Subject: [PATCH 10/14] New WR 1.25% better than PR #122: Optimize distributed training, improve skip connection gating, and enhance bfloat16 usage --- .../0d0d9882-c34f-4d82-b961-a17d5659c988.txt | 3382 +++++++++++++++++ .../0d451b7e-6500-41ae-ac9d-02352e611b88.txt | 3382 +++++++++++++++++ .../19171a35-3730-4239-b15e-3728d8de73db.txt | 3382 +++++++++++++++++ .../3b564c86-3d85-490d-a96a-83bd60d60f11.txt | 3382 +++++++++++++++++ .../648dffae-9eb3-4d2a-a28a-dae8d1152aa7.txt | 3382 +++++++++++++++++ .../6ecf5ea5-e999-4da1-a501-4fbc7160aec5.txt | 3382 +++++++++++++++++ .../7129a36e-505d-456b-aed5-ea8e455a0bac.txt | 3382 +++++++++++++++++ records/091125_VectSigmoidBFloat16/README.md | 53 + .../a077c741-ce5d-4639-955b-d7a2660b5cf8.txt | 3382 +++++++++++++++++ .../a72c99e9-0019-4baa-a858-f8738392933f.txt | 3382 +++++++++++++++++ .../ab5b991b-3767-4092-851a-5c266ae5c1e2.txt | 3382 +++++++++++++++++ .../b4bb35d4-92c1-42f4-91dd-dfbf665e66b4.txt | 3382 +++++++++++++++++ .../deb22a2c-6cf2-46f9-a350-aec1c97e9909.txt | 3382 +++++++++++++++++ train_gpt.py | 321 +- 14 files changed, 40891 insertions(+), 67 deletions(-) create mode 100644 records/091125_VectSigmoidBFloat16/0d0d9882-c34f-4d82-b961-a17d5659c988.txt create mode 100644 records/091125_VectSigmoidBFloat16/0d451b7e-6500-41ae-ac9d-02352e611b88.txt create mode 100644 records/091125_VectSigmoidBFloat16/19171a35-3730-4239-b15e-3728d8de73db.txt create mode 100644 records/091125_VectSigmoidBFloat16/3b564c86-3d85-490d-a96a-83bd60d60f11.txt create mode 100644 records/091125_VectSigmoidBFloat16/648dffae-9eb3-4d2a-a28a-dae8d1152aa7.txt create mode 100644 records/091125_VectSigmoidBFloat16/6ecf5ea5-e999-4da1-a501-4fbc7160aec5.txt create mode 100644 records/091125_VectSigmoidBFloat16/7129a36e-505d-456b-aed5-ea8e455a0bac.txt create mode 100644 records/091125_VectSigmoidBFloat16/README.md create mode 100644 records/091125_VectSigmoidBFloat16/a077c741-ce5d-4639-955b-d7a2660b5cf8.txt create mode 100644 records/091125_VectSigmoidBFloat16/a72c99e9-0019-4baa-a858-f8738392933f.txt create mode 100644 records/091125_VectSigmoidBFloat16/ab5b991b-3767-4092-851a-5c266ae5c1e2.txt create mode 100644 records/091125_VectSigmoidBFloat16/b4bb35d4-92c1-42f4-91dd-dfbf665e66b4.txt create mode 100644 records/091125_VectSigmoidBFloat16/deb22a2c-6cf2-46f9-a350-aec1c97e9909.txt diff --git a/records/091125_VectSigmoidBFloat16/0d0d9882-c34f-4d82-b961-a17d5659c988.txt b/records/091125_VectSigmoidBFloat16/0d0d9882-c34f-4d82-b961-a17d5659c988.txt new file mode 100644 index 000000000..8771aae9e --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/0d0d9882-c34f-4d82-b961-a17d5659c988.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to 
take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: 
tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + 
ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input 
matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. 
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + 
self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = ( + [ve[0], ve[1], ve[2]] + + [None] * (len(self.blocks) - 6) + + [ve[0], ve[1], ve[2]] + ) + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [ + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + ] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]).to( + torch.bfloat16 + ) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[: (len(self.blocks) // 2)] + lambdas = self.scalars[ + 1 * len(self.blocks) : 3 * len(self.blocks) + ].view(-1, 2) + sa_lambdas = self.scalars[ + 3 * len(self.blocks) : 5 * len(self.blocks) + ].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws], + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + + +def _load_data_shard(file: Path): + header = torch.from_file( + str(file), False, 256, dtype=torch.int32 + ) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty( + num_tokens, dtype=torch.uint16, pin_memory=True + ) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto( + tokens.numpy() + ) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, ( + "number of tokens read does not match header" + ) + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = ( + (tokens == BOS_ID) + .nonzero(as_tuple=True)[0] + .to(torch.int64) + .cpu() + .numpy() + ) + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration( + f"Insufficient BOS ahead of position {cur}; hit tail of shard." 
+
+
+def distributed_data_generator(
+ filename_pattern: str,
+ num_tokens: int,
+ max_seq_len: int,
+ grad_accum_steps: int = 1,
+ align_to_bos: bool = True,
+):
+ # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ assert num_tokens % (world_size * grad_accum_steps) == 0, (
+ "Batch size must be divisible by world size and grad accum steps"
+ )
+ num_tokens = num_tokens // grad_accum_steps
+
+ files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+ if not files:
+ raise FileNotFoundError(
+ f"No files found for pattern: {filename_pattern}"
+ )
+
+ file_iter = iter(
+ files
+ ) # Use itertools.cycle(files) for multi-epoch training
+ tokens = _load_data_shard(next(file_iter))
+ finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+ pos = 0 # for unaligned case
+
+ while True:
+ num_tokens_local = num_tokens // world_size
+ max_num_docs = next_multiple_of_n(
+ num_tokens_local // 300, n=128
+ ) # median doc length is ~400; dividing by 300 over-allocates slots for safety
+
+ if align_to_bos:
+ try:
+ seq_starts, seq_ends = finder.next_batch(
+ num_tokens_local, max_seq_len
+ )
+ start_idxs, end_idxs = (
+ torch.tensor(seq_starts[rank]),
+ torch.tensor(seq_ends[rank]),
+ )
+ except StopIteration:
+ # This shard is exhausted, load the next one in the next loop iteration.
+ tokens = _load_data_shard(next(file_iter))
+ finder = BOSFinder(tokens, world_size=world_size)
+ continue
+
+ buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+ _inputs = buf[:-1]
+ _targets = buf[1:]
+ end_idxs[-1] -= (
+ 1 # trim the last document by one token to account for the _targets offset
+ )
+ cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+ else:
+ if pos + num_tokens + 1 >= len(
+ tokens
+ ): # should not occur for val data
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+
+ pos_local = pos + rank * num_tokens_local
+ buf = tokens[pos_local : pos_local + num_tokens_local + 1]
+ _inputs = buf[:-1].view(
+ num_tokens_local,
+ )
+ _targets = buf[1:].view(
+ num_tokens_local,
+ )
+
+ cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+ pos += num_tokens
+
+ _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+ _cum_lengths[0] = 0
+ _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths
+
+ new_params = yield (
+ _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+ _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+ _cum_lengths.to(
+ device="cuda", dtype=torch.int32, non_blocking=True
+ ),
+ )
+
+ if new_params is not None:
+ # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+ new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+ assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, (
+ "Num tokens must be divisible by world size and grad accum steps"
+ )
+ num_tokens = new_num_tokens // new_grad_accum_steps # mirror the constructor path: store per-micro-step token count
+ max_seq_len = new_max_seq_len
+ grad_accum_steps = new_grad_accum_steps
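+
+
+# Usage sketch for the generator above (illustrative values only; nothing here executes):
+# loader = distributed_data_generator(args.train_files, num_tokens=2048, max_seq_len=1024)
+# inputs, targets, cum_seqlens = next(loader)  # one device micro-batch per call
+# # on-the-fly reconfiguration, e.g. when a batch-size schedule switches over;
+# # .send() resumes the generator at the yield, so it also returns the next batch:
+# inputs, targets, cum_seqlens = loader.send((4096, 2048, 1))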
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:55:28 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 132W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 45C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 37C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 44C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 40C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.10ms +step:1/1670 train_time:296ms step_avg:296.06ms +step:2/1670 train_time:316ms step_avg:157.94ms +step:3/1670 train_time:382ms step_avg:127.44ms +step:4/1670 train_time:472ms step_avg:117.91ms +step:5/1670 train_time:562ms step_avg:112.41ms +step:6/1670 train_time:653ms step_avg:108.77ms +step:7/1670 train_time:744ms step_avg:106.34ms +step:8/1670 train_time:835ms step_avg:104.33ms +step:9/1670 train_time:925ms step_avg:102.81ms +step:10/1670 train_time:1016ms step_avg:101.63ms +step:11/1670 train_time:1108ms step_avg:100.73ms +step:12/1670 train_time:1202ms step_avg:100.14ms +step:13/1670 train_time:1295ms step_avg:99.61ms +step:14/1670 train_time:1388ms step_avg:99.12ms +step:15/1670 train_time:1480ms step_avg:98.66ms +step:16/1670 train_time:1571ms step_avg:98.19ms +step:17/1670 train_time:1664ms step_avg:97.89ms +step:18/1670 train_time:1755ms step_avg:97.49ms +step:19/1670 train_time:1846ms step_avg:97.16ms +step:20/1670 train_time:1936ms step_avg:96.79ms +step:21/1670 train_time:2027ms step_avg:96.51ms +step:22/1670 train_time:2122ms 
step_avg:96.44ms +step:23/1670 train_time:2214ms step_avg:96.28ms +step:24/1670 train_time:2307ms step_avg:96.13ms +step:25/1670 train_time:2399ms step_avg:95.95ms +step:26/1670 train_time:2491ms step_avg:95.83ms +step:27/1670 train_time:2584ms step_avg:95.70ms +step:28/1670 train_time:2675ms step_avg:95.55ms +step:29/1670 train_time:2768ms step_avg:95.46ms +step:30/1670 train_time:2859ms step_avg:95.31ms +step:31/1670 train_time:2951ms step_avg:95.19ms +step:32/1670 train_time:3044ms step_avg:95.12ms +step:33/1670 train_time:3136ms step_avg:95.02ms +step:34/1670 train_time:3229ms step_avg:94.97ms +step:35/1670 train_time:3321ms step_avg:94.89ms +step:36/1670 train_time:3412ms step_avg:94.79ms +step:37/1670 train_time:3504ms step_avg:94.69ms +step:38/1670 train_time:3595ms step_avg:94.61ms +step:39/1670 train_time:3688ms step_avg:94.56ms +step:40/1670 train_time:3778ms step_avg:94.45ms +step:41/1670 train_time:3871ms step_avg:94.41ms +step:42/1670 train_time:3963ms step_avg:94.35ms +step:43/1670 train_time:4054ms step_avg:94.28ms +step:44/1670 train_time:4145ms step_avg:94.20ms +step:45/1670 train_time:4236ms step_avg:94.14ms +step:46/1670 train_time:4328ms step_avg:94.09ms +step:47/1670 train_time:4420ms step_avg:94.03ms +step:48/1670 train_time:4511ms step_avg:93.98ms +step:49/1670 train_time:4602ms step_avg:93.92ms +step:50/1670 train_time:4693ms step_avg:93.87ms +step:51/1670 train_time:4785ms step_avg:93.82ms +step:52/1670 train_time:4875ms step_avg:93.76ms +step:53/1670 train_time:4968ms step_avg:93.73ms +step:54/1670 train_time:5060ms step_avg:93.70ms +step:55/1670 train_time:5152ms step_avg:93.67ms +step:56/1670 train_time:5243ms step_avg:93.63ms +step:57/1670 train_time:5335ms step_avg:93.59ms +step:58/1670 train_time:5427ms step_avg:93.57ms +step:59/1670 train_time:5519ms step_avg:93.54ms +step:60/1670 train_time:5611ms step_avg:93.51ms +step:61/1670 train_time:5702ms step_avg:93.48ms +step:62/1670 train_time:5793ms step_avg:93.44ms +step:63/1670 train_time:5886ms step_avg:93.43ms +step:64/1670 train_time:5980ms step_avg:93.43ms +step:65/1670 train_time:6071ms step_avg:93.40ms +step:66/1670 train_time:6164ms step_avg:93.39ms +step:67/1670 train_time:6256ms step_avg:93.37ms +step:68/1670 train_time:6348ms step_avg:93.35ms +step:69/1670 train_time:6439ms step_avg:93.32ms +step:70/1670 train_time:6530ms step_avg:93.28ms +step:71/1670 train_time:6622ms step_avg:93.27ms +step:72/1670 train_time:6713ms step_avg:93.24ms +step:73/1670 train_time:6804ms step_avg:93.21ms +step:74/1670 train_time:6895ms step_avg:93.18ms +step:75/1670 train_time:6988ms step_avg:93.17ms +step:76/1670 train_time:7079ms step_avg:93.15ms +step:77/1670 train_time:7171ms step_avg:93.13ms +step:78/1670 train_time:7263ms step_avg:93.12ms +step:79/1670 train_time:7354ms step_avg:93.09ms +step:80/1670 train_time:7446ms step_avg:93.08ms +step:81/1670 train_time:7538ms step_avg:93.06ms +step:82/1670 train_time:7629ms step_avg:93.04ms +step:83/1670 train_time:7722ms step_avg:93.03ms +step:84/1670 train_time:7813ms step_avg:93.01ms +step:85/1670 train_time:7905ms step_avg:93.00ms +step:86/1670 train_time:7995ms step_avg:92.97ms +step:87/1670 train_time:8088ms step_avg:92.97ms +step:88/1670 train_time:8180ms step_avg:92.95ms +step:89/1670 train_time:8273ms step_avg:92.95ms +step:90/1670 train_time:8365ms step_avg:92.94ms +step:91/1670 train_time:8456ms step_avg:92.92ms +step:92/1670 train_time:8548ms step_avg:92.92ms +step:93/1670 train_time:8639ms step_avg:92.89ms +step:94/1670 train_time:8731ms step_avg:92.88ms 
+step:95/1670 train_time:8823ms step_avg:92.87ms +step:96/1670 train_time:8914ms step_avg:92.86ms +step:97/1670 train_time:9006ms step_avg:92.84ms +step:98/1670 train_time:9097ms step_avg:92.83ms +step:99/1670 train_time:9192ms step_avg:92.84ms +step:100/1670 train_time:9283ms step_avg:92.83ms +step:101/1670 train_time:9374ms step_avg:92.81ms +step:102/1670 train_time:9466ms step_avg:92.80ms +step:103/1670 train_time:9557ms step_avg:92.78ms +step:104/1670 train_time:9649ms step_avg:92.78ms +step:105/1670 train_time:9740ms step_avg:92.76ms +step:106/1670 train_time:9832ms step_avg:92.75ms +step:107/1670 train_time:9923ms step_avg:92.74ms +step:108/1670 train_time:10013ms step_avg:92.72ms +step:109/1670 train_time:10105ms step_avg:92.70ms +step:110/1670 train_time:10196ms step_avg:92.69ms +step:111/1670 train_time:10287ms step_avg:92.68ms +step:112/1670 train_time:10379ms step_avg:92.67ms +step:113/1670 train_time:10471ms step_avg:92.67ms +step:114/1670 train_time:10564ms step_avg:92.66ms +step:115/1670 train_time:10655ms step_avg:92.65ms +step:116/1670 train_time:10748ms step_avg:92.65ms +step:117/1670 train_time:10838ms step_avg:92.64ms +step:118/1670 train_time:10930ms step_avg:92.62ms +step:119/1670 train_time:11020ms step_avg:92.61ms +step:120/1670 train_time:11112ms step_avg:92.60ms +step:121/1670 train_time:11202ms step_avg:92.58ms +step:122/1670 train_time:11294ms step_avg:92.57ms +step:123/1670 train_time:11386ms step_avg:92.57ms +step:124/1670 train_time:11477ms step_avg:92.56ms +step:125/1670 train_time:11569ms step_avg:92.55ms +step:125/1670 val_loss:4.2929 train_time:11659ms step_avg:93.27ms +step:126/1670 train_time:11682ms step_avg:92.71ms +step:127/1670 train_time:11756ms step_avg:92.56ms +step:128/1670 train_time:11858ms step_avg:92.64ms +step:129/1670 train_time:11951ms step_avg:92.64ms +step:130/1670 train_time:12041ms step_avg:92.62ms +step:131/1670 train_time:12132ms step_avg:92.61ms +step:132/1670 train_time:12222ms step_avg:92.59ms +step:133/1670 train_time:12312ms step_avg:92.57ms +step:134/1670 train_time:12402ms step_avg:92.55ms +step:135/1670 train_time:12492ms step_avg:92.53ms +step:136/1670 train_time:12582ms step_avg:92.52ms +step:137/1670 train_time:12674ms step_avg:92.51ms +step:138/1670 train_time:12767ms step_avg:92.51ms +step:139/1670 train_time:12861ms step_avg:92.53ms +step:140/1670 train_time:12954ms step_avg:92.53ms +step:141/1670 train_time:13046ms step_avg:92.53ms +step:142/1670 train_time:13137ms step_avg:92.51ms +step:143/1670 train_time:13228ms step_avg:92.50ms +step:144/1670 train_time:13319ms step_avg:92.49ms +step:145/1670 train_time:13409ms step_avg:92.47ms +step:146/1670 train_time:13499ms step_avg:92.46ms +step:147/1670 train_time:13591ms step_avg:92.45ms +step:148/1670 train_time:13681ms step_avg:92.44ms +step:149/1670 train_time:13775ms step_avg:92.45ms +step:150/1670 train_time:13867ms step_avg:92.45ms +step:151/1670 train_time:13959ms step_avg:92.45ms +step:152/1670 train_time:14052ms step_avg:92.44ms +step:153/1670 train_time:14142ms step_avg:92.43ms +step:154/1670 train_time:14233ms step_avg:92.42ms +step:155/1670 train_time:14323ms step_avg:92.41ms +step:156/1670 train_time:14413ms step_avg:92.39ms +step:157/1670 train_time:14504ms step_avg:92.38ms +step:158/1670 train_time:14594ms step_avg:92.37ms +step:159/1670 train_time:14686ms step_avg:92.36ms +step:160/1670 train_time:14778ms step_avg:92.36ms +step:161/1670 train_time:14871ms step_avg:92.36ms +step:162/1670 train_time:14963ms step_avg:92.36ms +step:163/1670 train_time:15056ms 
step_avg:92.37ms +step:164/1670 train_time:15147ms step_avg:92.36ms +step:165/1670 train_time:15238ms step_avg:92.35ms +step:166/1670 train_time:15329ms step_avg:92.34ms +step:167/1670 train_time:15419ms step_avg:92.33ms +step:168/1670 train_time:15508ms step_avg:92.31ms +step:169/1670 train_time:15599ms step_avg:92.30ms +step:170/1670 train_time:15690ms step_avg:92.29ms +step:171/1670 train_time:15781ms step_avg:92.28ms +step:172/1670 train_time:15872ms step_avg:92.28ms +step:173/1670 train_time:15965ms step_avg:92.28ms +step:174/1670 train_time:16057ms step_avg:92.28ms +step:175/1670 train_time:16149ms step_avg:92.28ms +step:176/1670 train_time:16241ms step_avg:92.28ms +step:177/1670 train_time:16332ms step_avg:92.27ms +step:178/1670 train_time:16423ms step_avg:92.26ms +step:179/1670 train_time:16514ms step_avg:92.26ms +step:180/1670 train_time:16604ms step_avg:92.25ms +step:181/1670 train_time:16695ms step_avg:92.24ms +step:182/1670 train_time:16786ms step_avg:92.23ms +step:183/1670 train_time:16878ms step_avg:92.23ms +step:184/1670 train_time:16971ms step_avg:92.23ms +step:185/1670 train_time:17062ms step_avg:92.23ms +step:186/1670 train_time:17153ms step_avg:92.22ms +step:187/1670 train_time:17245ms step_avg:92.22ms +step:188/1670 train_time:17337ms step_avg:92.22ms +step:189/1670 train_time:17428ms step_avg:92.21ms +step:190/1670 train_time:17518ms step_avg:92.20ms +step:191/1670 train_time:17608ms step_avg:92.19ms +step:192/1670 train_time:17698ms step_avg:92.18ms +step:193/1670 train_time:17789ms step_avg:92.17ms +step:194/1670 train_time:17880ms step_avg:92.17ms +step:195/1670 train_time:17972ms step_avg:92.17ms +step:196/1670 train_time:18063ms step_avg:92.16ms +step:197/1670 train_time:18155ms step_avg:92.16ms +step:198/1670 train_time:18247ms step_avg:92.16ms +step:199/1670 train_time:18339ms step_avg:92.16ms +step:200/1670 train_time:18431ms step_avg:92.15ms +step:201/1670 train_time:18521ms step_avg:92.15ms +step:202/1670 train_time:18611ms step_avg:92.13ms +step:203/1670 train_time:18702ms step_avg:92.13ms +step:204/1670 train_time:18792ms step_avg:92.12ms +step:205/1670 train_time:18886ms step_avg:92.13ms +step:206/1670 train_time:18975ms step_avg:92.11ms +step:207/1670 train_time:19065ms step_avg:92.10ms +step:208/1670 train_time:19157ms step_avg:92.10ms +step:209/1670 train_time:19249ms step_avg:92.10ms +step:210/1670 train_time:19341ms step_avg:92.10ms +step:211/1670 train_time:19432ms step_avg:92.10ms +step:212/1670 train_time:19523ms step_avg:92.09ms +step:213/1670 train_time:19772ms step_avg:92.83ms +step:214/1670 train_time:19842ms step_avg:92.72ms +step:215/1670 train_time:19931ms step_avg:92.70ms +step:216/1670 train_time:20021ms step_avg:92.69ms +step:217/1670 train_time:20111ms step_avg:92.68ms +step:218/1670 train_time:20202ms step_avg:92.67ms +step:219/1670 train_time:20291ms step_avg:92.65ms +step:220/1670 train_time:20382ms step_avg:92.64ms +step:221/1670 train_time:20472ms step_avg:92.63ms +step:222/1670 train_time:20561ms step_avg:92.62ms +step:223/1670 train_time:20657ms step_avg:92.63ms +step:224/1670 train_time:20753ms step_avg:92.65ms +step:225/1670 train_time:20845ms step_avg:92.64ms +step:226/1670 train_time:20937ms step_avg:92.64ms +step:227/1670 train_time:21028ms step_avg:92.64ms +step:228/1670 train_time:21119ms step_avg:92.63ms +step:229/1670 train_time:21210ms step_avg:92.62ms +step:230/1670 train_time:21299ms step_avg:92.61ms +step:231/1670 train_time:21390ms step_avg:92.60ms +step:232/1670 train_time:21479ms step_avg:92.58ms +step:233/1670 
train_time:21570ms step_avg:92.57ms +step:234/1670 train_time:21662ms step_avg:92.57ms +step:235/1670 train_time:21756ms step_avg:92.58ms +step:236/1670 train_time:21847ms step_avg:92.57ms +step:237/1670 train_time:21939ms step_avg:92.57ms +step:238/1670 train_time:22030ms step_avg:92.56ms +step:239/1670 train_time:22120ms step_avg:92.55ms +step:240/1670 train_time:22211ms step_avg:92.54ms +step:241/1670 train_time:22301ms step_avg:92.54ms +step:242/1670 train_time:22392ms step_avg:92.53ms +step:243/1670 train_time:22482ms step_avg:92.52ms +step:244/1670 train_time:22574ms step_avg:92.52ms +step:245/1670 train_time:22667ms step_avg:92.52ms +step:246/1670 train_time:22759ms step_avg:92.52ms +step:247/1670 train_time:22851ms step_avg:92.51ms +step:248/1670 train_time:22942ms step_avg:92.51ms +step:249/1670 train_time:23034ms step_avg:92.51ms +step:250/1670 train_time:23125ms step_avg:92.50ms +step:250/1670 val_loss:3.9750 train_time:23218ms step_avg:92.87ms +step:251/1670 train_time:23240ms step_avg:92.59ms +step:252/1670 train_time:23310ms step_avg:92.50ms +step:253/1670 train_time:23402ms step_avg:92.50ms +step:254/1670 train_time:23491ms step_avg:92.49ms +step:255/1670 train_time:23581ms step_avg:92.48ms +step:256/1670 train_time:23672ms step_avg:92.47ms +step:257/1670 train_time:23762ms step_avg:92.46ms +step:258/1670 train_time:23854ms step_avg:92.46ms +step:259/1670 train_time:23946ms step_avg:92.46ms +step:260/1670 train_time:24038ms step_avg:92.45ms +step:261/1670 train_time:24129ms step_avg:92.45ms +step:262/1670 train_time:24223ms step_avg:92.45ms +step:263/1670 train_time:24315ms step_avg:92.45ms +step:264/1670 train_time:24408ms step_avg:92.45ms +step:265/1670 train_time:24498ms step_avg:92.45ms +step:266/1670 train_time:24589ms step_avg:92.44ms +step:267/1670 train_time:24679ms step_avg:92.43ms +step:268/1670 train_time:24769ms step_avg:92.42ms +step:269/1670 train_time:24860ms step_avg:92.42ms +step:270/1670 train_time:24951ms step_avg:92.41ms +step:271/1670 train_time:25041ms step_avg:92.40ms +step:272/1670 train_time:25133ms step_avg:92.40ms +step:273/1670 train_time:25226ms step_avg:92.40ms +step:274/1670 train_time:25318ms step_avg:92.40ms +step:275/1670 train_time:25411ms step_avg:92.40ms +step:276/1670 train_time:25501ms step_avg:92.40ms +step:277/1670 train_time:25592ms step_avg:92.39ms +step:278/1670 train_time:25682ms step_avg:92.38ms +step:279/1670 train_time:25774ms step_avg:92.38ms +step:280/1670 train_time:25865ms step_avg:92.38ms +step:281/1670 train_time:25955ms step_avg:92.37ms +step:282/1670 train_time:26047ms step_avg:92.36ms +step:283/1670 train_time:26137ms step_avg:92.36ms +step:284/1670 train_time:26229ms step_avg:92.36ms +step:285/1670 train_time:26322ms step_avg:92.36ms +step:286/1670 train_time:26414ms step_avg:92.36ms +step:287/1670 train_time:26505ms step_avg:92.35ms +step:288/1670 train_time:26595ms step_avg:92.35ms +step:289/1670 train_time:26686ms step_avg:92.34ms +step:290/1670 train_time:26778ms step_avg:92.34ms +step:291/1670 train_time:26869ms step_avg:92.33ms +step:292/1670 train_time:26960ms step_avg:92.33ms +step:293/1670 train_time:27051ms step_avg:92.33ms +step:294/1670 train_time:27143ms step_avg:92.32ms +step:295/1670 train_time:27234ms step_avg:92.32ms +step:296/1670 train_time:27327ms step_avg:92.32ms +step:297/1670 train_time:27418ms step_avg:92.32ms +step:298/1670 train_time:27510ms step_avg:92.32ms +step:299/1670 train_time:27602ms step_avg:92.32ms +step:300/1670 train_time:27693ms step_avg:92.31ms +step:301/1670 train_time:27784ms 
step_avg:92.30ms +step:302/1670 train_time:27875ms step_avg:92.30ms +step:303/1670 train_time:27965ms step_avg:92.29ms +step:304/1670 train_time:28055ms step_avg:92.29ms +step:305/1670 train_time:28146ms step_avg:92.28ms +step:306/1670 train_time:28237ms step_avg:92.28ms +step:307/1670 train_time:28329ms step_avg:92.28ms +step:308/1670 train_time:28421ms step_avg:92.28ms +step:309/1670 train_time:28513ms step_avg:92.28ms +step:310/1670 train_time:28605ms step_avg:92.27ms +step:311/1670 train_time:28695ms step_avg:92.27ms +step:312/1670 train_time:28787ms step_avg:92.27ms +step:313/1670 train_time:28878ms step_avg:92.26ms +step:314/1670 train_time:28969ms step_avg:92.26ms +step:315/1670 train_time:29060ms step_avg:92.25ms +step:316/1670 train_time:29151ms step_avg:92.25ms +step:317/1670 train_time:29242ms step_avg:92.25ms +step:318/1670 train_time:29333ms step_avg:92.24ms +step:319/1670 train_time:29425ms step_avg:92.24ms +step:320/1670 train_time:29516ms step_avg:92.24ms +step:321/1670 train_time:29608ms step_avg:92.24ms +step:322/1670 train_time:29699ms step_avg:92.23ms +step:323/1670 train_time:29789ms step_avg:92.23ms +step:324/1670 train_time:29881ms step_avg:92.22ms +step:325/1670 train_time:29972ms step_avg:92.22ms +step:326/1670 train_time:30064ms step_avg:92.22ms +step:327/1670 train_time:30155ms step_avg:92.22ms +step:328/1670 train_time:30247ms step_avg:92.22ms +step:329/1670 train_time:30338ms step_avg:92.21ms +step:330/1670 train_time:30430ms step_avg:92.21ms +step:331/1670 train_time:30523ms step_avg:92.21ms +step:332/1670 train_time:30614ms step_avg:92.21ms +step:333/1670 train_time:30705ms step_avg:92.21ms +step:334/1670 train_time:30796ms step_avg:92.20ms +step:335/1670 train_time:30888ms step_avg:92.20ms +step:336/1670 train_time:30978ms step_avg:92.20ms +step:337/1670 train_time:31069ms step_avg:92.19ms +step:338/1670 train_time:31160ms step_avg:92.19ms +step:339/1670 train_time:31252ms step_avg:92.19ms +step:340/1670 train_time:31344ms step_avg:92.19ms +step:341/1670 train_time:31435ms step_avg:92.18ms +step:342/1670 train_time:31527ms step_avg:92.18ms +step:343/1670 train_time:31618ms step_avg:92.18ms +step:344/1670 train_time:31710ms step_avg:92.18ms +step:345/1670 train_time:31800ms step_avg:92.17ms +step:346/1670 train_time:31892ms step_avg:92.17ms +step:347/1670 train_time:31982ms step_avg:92.17ms +step:348/1670 train_time:32072ms step_avg:92.16ms +step:349/1670 train_time:32164ms step_avg:92.16ms +step:350/1670 train_time:32255ms step_avg:92.16ms +step:351/1670 train_time:32348ms step_avg:92.16ms +step:352/1670 train_time:32439ms step_avg:92.16ms +step:353/1670 train_time:32532ms step_avg:92.16ms +step:354/1670 train_time:32623ms step_avg:92.15ms +step:355/1670 train_time:32714ms step_avg:92.15ms +step:356/1670 train_time:32805ms step_avg:92.15ms +step:357/1670 train_time:32895ms step_avg:92.14ms +step:358/1670 train_time:32987ms step_avg:92.14ms +step:359/1670 train_time:33077ms step_avg:92.14ms +step:360/1670 train_time:33169ms step_avg:92.14ms +step:361/1670 train_time:33260ms step_avg:92.13ms +step:362/1670 train_time:33352ms step_avg:92.13ms +step:363/1670 train_time:33443ms step_avg:92.13ms +step:364/1670 train_time:33535ms step_avg:92.13ms +step:365/1670 train_time:33627ms step_avg:92.13ms +step:366/1670 train_time:33718ms step_avg:92.13ms +step:367/1670 train_time:33810ms step_avg:92.13ms +step:368/1670 train_time:33901ms step_avg:92.12ms +step:369/1670 train_time:33991ms step_avg:92.12ms +step:370/1670 train_time:34082ms step_avg:92.11ms +step:371/1670 
train_time:34172ms step_avg:92.11ms +step:372/1670 train_time:34263ms step_avg:92.10ms +step:373/1670 train_time:34354ms step_avg:92.10ms +step:374/1670 train_time:34446ms step_avg:92.10ms +step:375/1670 train_time:34537ms step_avg:92.10ms +step:375/1670 val_loss:3.8164 train_time:34628ms step_avg:92.34ms +step:376/1670 train_time:34647ms step_avg:92.15ms +step:377/1670 train_time:34718ms step_avg:92.09ms +step:378/1670 train_time:34811ms step_avg:92.09ms +step:379/1670 train_time:34902ms step_avg:92.09ms +step:380/1670 train_time:34993ms step_avg:92.09ms +step:381/1670 train_time:35083ms step_avg:92.08ms +step:382/1670 train_time:35174ms step_avg:92.08ms +step:383/1670 train_time:35265ms step_avg:92.07ms +step:384/1670 train_time:35356ms step_avg:92.07ms +step:385/1670 train_time:35449ms step_avg:92.07ms +step:386/1670 train_time:35540ms step_avg:92.07ms +step:387/1670 train_time:35632ms step_avg:92.07ms +step:388/1670 train_time:35725ms step_avg:92.07ms +step:389/1670 train_time:35817ms step_avg:92.07ms +step:390/1670 train_time:35909ms step_avg:92.07ms +step:391/1670 train_time:36000ms step_avg:92.07ms +step:392/1670 train_time:36090ms step_avg:92.07ms +step:393/1670 train_time:36180ms step_avg:92.06ms +step:394/1670 train_time:36272ms step_avg:92.06ms +step:395/1670 train_time:36364ms step_avg:92.06ms +step:396/1670 train_time:36456ms step_avg:92.06ms +step:397/1670 train_time:36547ms step_avg:92.06ms +step:398/1670 train_time:36640ms step_avg:92.06ms +step:399/1670 train_time:36732ms step_avg:92.06ms +step:400/1670 train_time:36823ms step_avg:92.06ms +step:401/1670 train_time:36914ms step_avg:92.06ms +step:402/1670 train_time:37005ms step_avg:92.05ms +step:403/1670 train_time:37095ms step_avg:92.05ms +step:404/1670 train_time:37186ms step_avg:92.04ms +step:405/1670 train_time:37277ms step_avg:92.04ms +step:406/1670 train_time:37368ms step_avg:92.04ms +step:407/1670 train_time:37459ms step_avg:92.04ms +step:408/1670 train_time:37552ms step_avg:92.04ms +step:409/1670 train_time:37643ms step_avg:92.04ms +step:410/1670 train_time:37734ms step_avg:92.03ms +step:411/1670 train_time:37825ms step_avg:92.03ms +step:412/1670 train_time:37918ms step_avg:92.03ms +step:413/1670 train_time:38008ms step_avg:92.03ms +step:414/1670 train_time:38100ms step_avg:92.03ms +step:415/1670 train_time:38191ms step_avg:92.03ms +step:416/1670 train_time:38282ms step_avg:92.02ms +step:417/1670 train_time:38373ms step_avg:92.02ms +step:418/1670 train_time:38463ms step_avg:92.02ms +step:419/1670 train_time:38556ms step_avg:92.02ms +step:420/1670 train_time:38647ms step_avg:92.02ms +step:421/1670 train_time:38738ms step_avg:92.01ms +step:422/1670 train_time:38829ms step_avg:92.01ms +step:423/1670 train_time:38921ms step_avg:92.01ms +step:424/1670 train_time:39012ms step_avg:92.01ms +step:425/1670 train_time:39260ms step_avg:92.38ms +step:426/1670 train_time:39332ms step_avg:92.33ms +step:427/1670 train_time:39421ms step_avg:92.32ms +step:428/1670 train_time:39512ms step_avg:92.32ms +step:429/1670 train_time:39602ms step_avg:92.31ms +step:430/1670 train_time:39692ms step_avg:92.31ms +step:431/1670 train_time:39782ms step_avg:92.30ms +step:432/1670 train_time:39872ms step_avg:92.30ms +step:433/1670 train_time:39962ms step_avg:92.29ms +step:434/1670 train_time:40052ms step_avg:92.29ms +step:435/1670 train_time:40147ms step_avg:92.29ms +step:436/1670 train_time:40244ms step_avg:92.30ms +step:437/1670 train_time:40337ms step_avg:92.30ms +step:438/1670 train_time:40427ms step_avg:92.30ms +step:439/1670 train_time:40519ms 
step_avg:92.30ms +step:440/1670 train_time:40610ms step_avg:92.29ms +step:441/1670 train_time:40701ms step_avg:92.29ms +step:442/1670 train_time:40792ms step_avg:92.29ms +step:443/1670 train_time:40882ms step_avg:92.28ms +step:444/1670 train_time:40972ms step_avg:92.28ms +step:445/1670 train_time:41063ms step_avg:92.28ms +step:446/1670 train_time:41157ms step_avg:92.28ms +step:447/1670 train_time:41251ms step_avg:92.28ms +step:448/1670 train_time:41343ms step_avg:92.28ms +step:449/1670 train_time:41434ms step_avg:92.28ms +step:450/1670 train_time:41525ms step_avg:92.28ms +step:451/1670 train_time:41616ms step_avg:92.27ms +step:452/1670 train_time:41706ms step_avg:92.27ms +step:453/1670 train_time:41797ms step_avg:92.27ms +step:454/1670 train_time:41887ms step_avg:92.26ms +step:455/1670 train_time:41977ms step_avg:92.26ms +step:456/1670 train_time:42068ms step_avg:92.26ms +step:457/1670 train_time:42160ms step_avg:92.25ms +step:458/1670 train_time:42252ms step_avg:92.25ms +step:459/1670 train_time:42343ms step_avg:92.25ms +step:460/1670 train_time:42435ms step_avg:92.25ms +step:461/1670 train_time:42525ms step_avg:92.25ms +step:462/1670 train_time:42618ms step_avg:92.25ms +step:463/1670 train_time:42708ms step_avg:92.24ms +step:464/1670 train_time:42799ms step_avg:92.24ms +step:465/1670 train_time:42890ms step_avg:92.24ms +step:466/1670 train_time:42981ms step_avg:92.23ms +step:467/1670 train_time:43072ms step_avg:92.23ms +step:468/1670 train_time:43164ms step_avg:92.23ms +step:469/1670 train_time:43255ms step_avg:92.23ms +step:470/1670 train_time:43346ms step_avg:92.23ms +step:471/1670 train_time:43438ms step_avg:92.23ms +step:472/1670 train_time:43529ms step_avg:92.22ms +step:473/1670 train_time:43622ms step_avg:92.23ms +step:474/1670 train_time:43715ms step_avg:92.22ms +step:475/1670 train_time:43805ms step_avg:92.22ms +step:476/1670 train_time:43896ms step_avg:92.22ms +step:477/1670 train_time:43986ms step_avg:92.21ms +step:478/1670 train_time:44078ms step_avg:92.21ms +step:479/1670 train_time:44170ms step_avg:92.21ms +step:480/1670 train_time:44261ms step_avg:92.21ms +step:481/1670 train_time:44352ms step_avg:92.21ms +step:482/1670 train_time:44444ms step_avg:92.21ms +step:483/1670 train_time:44535ms step_avg:92.20ms +step:484/1670 train_time:44626ms step_avg:92.20ms +step:485/1670 train_time:44719ms step_avg:92.20ms +step:486/1670 train_time:44811ms step_avg:92.20ms +step:487/1670 train_time:44902ms step_avg:92.20ms +step:488/1670 train_time:44993ms step_avg:92.20ms +step:489/1670 train_time:45084ms step_avg:92.20ms +step:490/1670 train_time:45175ms step_avg:92.19ms +step:491/1670 train_time:45267ms step_avg:92.19ms +step:492/1670 train_time:45358ms step_avg:92.19ms +step:493/1670 train_time:45449ms step_avg:92.19ms +step:494/1670 train_time:45540ms step_avg:92.19ms +step:495/1670 train_time:45630ms step_avg:92.18ms +step:496/1670 train_time:45722ms step_avg:92.18ms +step:497/1670 train_time:45815ms step_avg:92.18ms +step:498/1670 train_time:45906ms step_avg:92.18ms +step:499/1670 train_time:45997ms step_avg:92.18ms +step:500/1670 train_time:46088ms step_avg:92.18ms +step:500/1670 val_loss:3.7158 train_time:46179ms step_avg:92.36ms +step:501/1670 train_time:46199ms step_avg:92.21ms +step:502/1670 train_time:46271ms step_avg:92.17ms +step:503/1670 train_time:46363ms step_avg:92.17ms +step:504/1670 train_time:46454ms step_avg:92.17ms +step:505/1670 train_time:46545ms step_avg:92.17ms +step:506/1670 train_time:46636ms step_avg:92.17ms +step:507/1670 train_time:46726ms step_avg:92.16ms 
+step:508/1670 train_time:46818ms step_avg:92.16ms +step:509/1670 train_time:46909ms step_avg:92.16ms +step:510/1670 train_time:47001ms step_avg:92.16ms +step:511/1670 train_time:47092ms step_avg:92.16ms +step:512/1670 train_time:47185ms step_avg:92.16ms +step:513/1670 train_time:47276ms step_avg:92.16ms +step:514/1670 train_time:47369ms step_avg:92.16ms +step:515/1670 train_time:47461ms step_avg:92.16ms +step:516/1670 train_time:47552ms step_avg:92.16ms +step:517/1670 train_time:47643ms step_avg:92.15ms +step:518/1670 train_time:47733ms step_avg:92.15ms +step:519/1670 train_time:47824ms step_avg:92.15ms +step:520/1670 train_time:47914ms step_avg:92.14ms +step:521/1670 train_time:48005ms step_avg:92.14ms +step:522/1670 train_time:48096ms step_avg:92.14ms +step:523/1670 train_time:48188ms step_avg:92.14ms +step:524/1670 train_time:48279ms step_avg:92.14ms +step:525/1670 train_time:48371ms step_avg:92.13ms +step:526/1670 train_time:48463ms step_avg:92.14ms +step:527/1670 train_time:48554ms step_avg:92.13ms +step:528/1670 train_time:48645ms step_avg:92.13ms +step:529/1670 train_time:48735ms step_avg:92.13ms +step:530/1670 train_time:48826ms step_avg:92.12ms +step:531/1670 train_time:48916ms step_avg:92.12ms +step:532/1670 train_time:49007ms step_avg:92.12ms +step:533/1670 train_time:49098ms step_avg:92.12ms +step:534/1670 train_time:49190ms step_avg:92.12ms +step:535/1670 train_time:49282ms step_avg:92.12ms +step:536/1670 train_time:49374ms step_avg:92.12ms +step:537/1670 train_time:49466ms step_avg:92.12ms +step:538/1670 train_time:49557ms step_avg:92.11ms +step:539/1670 train_time:49648ms step_avg:92.11ms +step:540/1670 train_time:49739ms step_avg:92.11ms +step:541/1670 train_time:49830ms step_avg:92.11ms +step:542/1670 train_time:49920ms step_avg:92.10ms +step:543/1670 train_time:50011ms step_avg:92.10ms +step:544/1670 train_time:50102ms step_avg:92.10ms +step:545/1670 train_time:50193ms step_avg:92.10ms +step:546/1670 train_time:50285ms step_avg:92.10ms +step:547/1670 train_time:50376ms step_avg:92.09ms +step:548/1670 train_time:50468ms step_avg:92.10ms +step:549/1670 train_time:50560ms step_avg:92.09ms +step:550/1670 train_time:50651ms step_avg:92.09ms +step:551/1670 train_time:50742ms step_avg:92.09ms +step:552/1670 train_time:50832ms step_avg:92.09ms +step:553/1670 train_time:50923ms step_avg:92.09ms +step:554/1670 train_time:51014ms step_avg:92.08ms +step:555/1670 train_time:51105ms step_avg:92.08ms +step:556/1670 train_time:51195ms step_avg:92.08ms +step:557/1670 train_time:51287ms step_avg:92.08ms +step:558/1670 train_time:51558ms step_avg:92.40ms +step:559/1670 train_time:51645ms step_avg:92.39ms +step:560/1670 train_time:51735ms step_avg:92.38ms +step:561/1670 train_time:51827ms step_avg:92.38ms +step:562/1670 train_time:51918ms step_avg:92.38ms +step:563/1670 train_time:52009ms step_avg:92.38ms +step:564/1670 train_time:52101ms step_avg:92.38ms +step:565/1670 train_time:52192ms step_avg:92.37ms +step:566/1670 train_time:52283ms step_avg:92.37ms +step:567/1670 train_time:52375ms step_avg:92.37ms +step:568/1670 train_time:52474ms step_avg:92.38ms +step:569/1670 train_time:52573ms step_avg:92.39ms +step:570/1670 train_time:52665ms step_avg:92.40ms +step:571/1670 train_time:52757ms step_avg:92.39ms +step:572/1670 train_time:52849ms step_avg:92.39ms +step:573/1670 train_time:52941ms step_avg:92.39ms +step:574/1670 train_time:53032ms step_avg:92.39ms +step:575/1670 train_time:53124ms step_avg:92.39ms +step:576/1670 train_time:53215ms step_avg:92.39ms +step:577/1670 train_time:53307ms 
step_avg:92.39ms +step:578/1670 train_time:53399ms step_avg:92.39ms +step:579/1670 train_time:53496ms step_avg:92.39ms +step:580/1670 train_time:53594ms step_avg:92.40ms +step:581/1670 train_time:53688ms step_avg:92.41ms +step:582/1670 train_time:53779ms step_avg:92.40ms +step:583/1670 train_time:53871ms step_avg:92.40ms +step:584/1670 train_time:53963ms step_avg:92.40ms +step:585/1670 train_time:54055ms step_avg:92.40ms +step:586/1670 train_time:54146ms step_avg:92.40ms +step:587/1670 train_time:54238ms step_avg:92.40ms +step:588/1670 train_time:54330ms step_avg:92.40ms +step:589/1670 train_time:54423ms step_avg:92.40ms +step:590/1670 train_time:54517ms step_avg:92.40ms +step:591/1670 train_time:54611ms step_avg:92.40ms +step:592/1670 train_time:54705ms step_avg:92.41ms +step:593/1670 train_time:54797ms step_avg:92.41ms +step:594/1670 train_time:54890ms step_avg:92.41ms +step:595/1670 train_time:54983ms step_avg:92.41ms +step:596/1670 train_time:55074ms step_avg:92.41ms +step:597/1670 train_time:55166ms step_avg:92.41ms +step:598/1670 train_time:55258ms step_avg:92.40ms +step:599/1670 train_time:55351ms step_avg:92.41ms +step:600/1670 train_time:55444ms step_avg:92.41ms +step:601/1670 train_time:55537ms step_avg:92.41ms +step:602/1670 train_time:55631ms step_avg:92.41ms +step:603/1670 train_time:55724ms step_avg:92.41ms +step:604/1670 train_time:55817ms step_avg:92.41ms +step:605/1670 train_time:55911ms step_avg:92.42ms +step:606/1670 train_time:56003ms step_avg:92.41ms +step:607/1670 train_time:56095ms step_avg:92.41ms +step:608/1670 train_time:56188ms step_avg:92.41ms +step:609/1670 train_time:56280ms step_avg:92.41ms +step:610/1670 train_time:56374ms step_avg:92.42ms +step:611/1670 train_time:56466ms step_avg:92.42ms +step:612/1670 train_time:56560ms step_avg:92.42ms +step:613/1670 train_time:56653ms step_avg:92.42ms +step:614/1670 train_time:56745ms step_avg:92.42ms +step:615/1670 train_time:56838ms step_avg:92.42ms +step:616/1670 train_time:56931ms step_avg:92.42ms +step:617/1670 train_time:57024ms step_avg:92.42ms +step:618/1670 train_time:57116ms step_avg:92.42ms +step:619/1670 train_time:57208ms step_avg:92.42ms +step:620/1670 train_time:57300ms step_avg:92.42ms +step:621/1670 train_time:57393ms step_avg:92.42ms +step:622/1670 train_time:57486ms step_avg:92.42ms +step:623/1670 train_time:57578ms step_avg:92.42ms +step:624/1670 train_time:57672ms step_avg:92.42ms +step:625/1670 train_time:57764ms step_avg:92.42ms +step:625/1670 val_loss:3.6126 train_time:57856ms step_avg:92.57ms +step:626/1670 train_time:57876ms step_avg:92.45ms +step:627/1670 train_time:57955ms step_avg:92.43ms +step:628/1670 train_time:58057ms step_avg:92.45ms +step:629/1670 train_time:58150ms step_avg:92.45ms +step:630/1670 train_time:58242ms step_avg:92.45ms +step:631/1670 train_time:58333ms step_avg:92.44ms +step:632/1670 train_time:58424ms step_avg:92.44ms +step:633/1670 train_time:58515ms step_avg:92.44ms +step:634/1670 train_time:58606ms step_avg:92.44ms +step:635/1670 train_time:58698ms step_avg:92.44ms +step:636/1670 train_time:58789ms step_avg:92.44ms +step:637/1670 train_time:58883ms step_avg:92.44ms +step:638/1670 train_time:58980ms step_avg:92.44ms +step:639/1670 train_time:59214ms step_avg:92.67ms +step:640/1670 train_time:59288ms step_avg:92.64ms +step:641/1670 train_time:59379ms step_avg:92.63ms +step:642/1670 train_time:59470ms step_avg:92.63ms +step:643/1670 train_time:59561ms step_avg:92.63ms +step:644/1670 train_time:59653ms step_avg:92.63ms +step:645/1670 train_time:59744ms step_avg:92.63ms 
+step:646/1670 train_time:59835ms step_avg:92.62ms +step:647/1670 train_time:59926ms step_avg:92.62ms +step:648/1670 train_time:60018ms step_avg:92.62ms +step:649/1670 train_time:60116ms step_avg:92.63ms +step:650/1670 train_time:60212ms step_avg:92.63ms +step:651/1670 train_time:60305ms step_avg:92.63ms +step:652/1670 train_time:60399ms step_avg:92.64ms +step:653/1670 train_time:60490ms step_avg:92.63ms +step:654/1670 train_time:60582ms step_avg:92.63ms +step:655/1670 train_time:60674ms step_avg:92.63ms +step:656/1670 train_time:60767ms step_avg:92.63ms +step:657/1670 train_time:60858ms step_avg:92.63ms +step:658/1670 train_time:60949ms step_avg:92.63ms +step:659/1670 train_time:61042ms step_avg:92.63ms +step:660/1670 train_time:61137ms step_avg:92.63ms +step:661/1670 train_time:61230ms step_avg:92.63ms +step:662/1670 train_time:61324ms step_avg:92.63ms +step:663/1670 train_time:61417ms step_avg:92.63ms +step:664/1670 train_time:61509ms step_avg:92.63ms +step:665/1670 train_time:61602ms step_avg:92.63ms +step:666/1670 train_time:61694ms step_avg:92.63ms +step:667/1670 train_time:61785ms step_avg:92.63ms +step:668/1670 train_time:61877ms step_avg:92.63ms +step:669/1670 train_time:61968ms step_avg:92.63ms +step:670/1670 train_time:62063ms step_avg:92.63ms +step:671/1670 train_time:62158ms step_avg:92.63ms +step:672/1670 train_time:62251ms step_avg:92.63ms +step:673/1670 train_time:62345ms step_avg:92.64ms +step:674/1670 train_time:62437ms step_avg:92.64ms +step:675/1670 train_time:62529ms step_avg:92.64ms +step:676/1670 train_time:62622ms step_avg:92.64ms +step:677/1670 train_time:62715ms step_avg:92.64ms +step:678/1670 train_time:62806ms step_avg:92.63ms +step:679/1670 train_time:62898ms step_avg:92.63ms +step:680/1670 train_time:62990ms step_avg:92.63ms +step:681/1670 train_time:63083ms step_avg:92.63ms +step:682/1670 train_time:63176ms step_avg:92.63ms +step:683/1670 train_time:63268ms step_avg:92.63ms +step:684/1670 train_time:63363ms step_avg:92.64ms +step:685/1670 train_time:63456ms step_avg:92.64ms +step:686/1670 train_time:63548ms step_avg:92.64ms +step:687/1670 train_time:63641ms step_avg:92.64ms +step:688/1670 train_time:63733ms step_avg:92.64ms +step:689/1670 train_time:63825ms step_avg:92.63ms +step:690/1670 train_time:63917ms step_avg:92.63ms +step:691/1670 train_time:64009ms step_avg:92.63ms +step:692/1670 train_time:64102ms step_avg:92.63ms +step:693/1670 train_time:64195ms step_avg:92.63ms +step:694/1670 train_time:64287ms step_avg:92.63ms +step:695/1670 train_time:64380ms step_avg:92.63ms +step:696/1670 train_time:64473ms step_avg:92.63ms +step:697/1670 train_time:64565ms step_avg:92.63ms +step:698/1670 train_time:64658ms step_avg:92.63ms +step:699/1670 train_time:64750ms step_avg:92.63ms +step:700/1670 train_time:64842ms step_avg:92.63ms +step:701/1670 train_time:64935ms step_avg:92.63ms +step:702/1670 train_time:65026ms step_avg:92.63ms +step:703/1670 train_time:65120ms step_avg:92.63ms +step:704/1670 train_time:65211ms step_avg:92.63ms +step:705/1670 train_time:65305ms step_avg:92.63ms +step:706/1670 train_time:65398ms step_avg:92.63ms +step:707/1670 train_time:65492ms step_avg:92.63ms +step:708/1670 train_time:65584ms step_avg:92.63ms +step:709/1670 train_time:65676ms step_avg:92.63ms +step:710/1670 train_time:65769ms step_avg:92.63ms +step:711/1670 train_time:65862ms step_avg:92.63ms +step:712/1670 train_time:65953ms step_avg:92.63ms +step:713/1670 train_time:66045ms step_avg:92.63ms +step:714/1670 train_time:66137ms step_avg:92.63ms +step:715/1670 train_time:66229ms 
step_avg:92.63ms +step:716/1670 train_time:66323ms step_avg:92.63ms +step:717/1670 train_time:66416ms step_avg:92.63ms +step:718/1670 train_time:66508ms step_avg:92.63ms +step:719/1670 train_time:66601ms step_avg:92.63ms +step:720/1670 train_time:66694ms step_avg:92.63ms +step:721/1670 train_time:66786ms step_avg:92.63ms +step:722/1670 train_time:66880ms step_avg:92.63ms +step:723/1670 train_time:66972ms step_avg:92.63ms +step:724/1670 train_time:67064ms step_avg:92.63ms +step:725/1670 train_time:67157ms step_avg:92.63ms +step:726/1670 train_time:67249ms step_avg:92.63ms +step:727/1670 train_time:67343ms step_avg:92.63ms +step:728/1670 train_time:67436ms step_avg:92.63ms +step:729/1670 train_time:67528ms step_avg:92.63ms +step:730/1670 train_time:67622ms step_avg:92.63ms +step:731/1670 train_time:67714ms step_avg:92.63ms +step:732/1670 train_time:67806ms step_avg:92.63ms +step:733/1670 train_time:67898ms step_avg:92.63ms +step:734/1670 train_time:67990ms step_avg:92.63ms +step:735/1670 train_time:68084ms step_avg:92.63ms +step:736/1670 train_time:68176ms step_avg:92.63ms +step:737/1670 train_time:68269ms step_avg:92.63ms +step:738/1670 train_time:68362ms step_avg:92.63ms +step:739/1670 train_time:68454ms step_avg:92.63ms +step:740/1670 train_time:68547ms step_avg:92.63ms +step:741/1670 train_time:68639ms step_avg:92.63ms +step:742/1670 train_time:68731ms step_avg:92.63ms +step:743/1670 train_time:68825ms step_avg:92.63ms +step:744/1670 train_time:68918ms step_avg:92.63ms +step:745/1670 train_time:69011ms step_avg:92.63ms +step:746/1670 train_time:69103ms step_avg:92.63ms +step:747/1670 train_time:69196ms step_avg:92.63ms +step:748/1670 train_time:69288ms step_avg:92.63ms +step:749/1670 train_time:69382ms step_avg:92.63ms +step:750/1670 train_time:69474ms step_avg:92.63ms +step:750/1670 val_loss:3.5602 train_time:69566ms step_avg:92.76ms +step:751/1670 train_time:69586ms step_avg:92.66ms +step:752/1670 train_time:69660ms step_avg:92.63ms +step:753/1670 train_time:69753ms step_avg:92.63ms +step:754/1670 train_time:69846ms step_avg:92.63ms +step:755/1670 train_time:69938ms step_avg:92.63ms +step:756/1670 train_time:70029ms step_avg:92.63ms +step:757/1670 train_time:70121ms step_avg:92.63ms +step:758/1670 train_time:70213ms step_avg:92.63ms +step:759/1670 train_time:70306ms step_avg:92.63ms +step:760/1670 train_time:70398ms step_avg:92.63ms +step:761/1670 train_time:70491ms step_avg:92.63ms +step:762/1670 train_time:70586ms step_avg:92.63ms +step:763/1670 train_time:70680ms step_avg:92.63ms +step:764/1670 train_time:70774ms step_avg:92.64ms +step:765/1670 train_time:70865ms step_avg:92.63ms +step:766/1670 train_time:70958ms step_avg:92.63ms +step:767/1670 train_time:71050ms step_avg:92.63ms +step:768/1670 train_time:71142ms step_avg:92.63ms +step:769/1670 train_time:71234ms step_avg:92.63ms +step:770/1670 train_time:71326ms step_avg:92.63ms +step:771/1670 train_time:71419ms step_avg:92.63ms +step:772/1670 train_time:71512ms step_avg:92.63ms +step:773/1670 train_time:71605ms step_avg:92.63ms +step:774/1670 train_time:71699ms step_avg:92.63ms +step:775/1670 train_time:71792ms step_avg:92.63ms +step:776/1670 train_time:71884ms step_avg:92.63ms +step:777/1670 train_time:71978ms step_avg:92.64ms +step:778/1670 train_time:72071ms step_avg:92.64ms +step:779/1670 train_time:72163ms step_avg:92.63ms +step:780/1670 train_time:72254ms step_avg:92.63ms +step:781/1670 train_time:72346ms step_avg:92.63ms +step:782/1670 train_time:72439ms step_avg:92.63ms +step:783/1670 train_time:72532ms step_avg:92.63ms 
+step:784/1670 train_time:72625ms step_avg:92.63ms +step:785/1670 train_time:72719ms step_avg:92.64ms +step:786/1670 train_time:72811ms step_avg:92.64ms +step:787/1670 train_time:72903ms step_avg:92.63ms +step:788/1670 train_time:72997ms step_avg:92.64ms +step:789/1670 train_time:73089ms step_avg:92.63ms +step:790/1670 train_time:73182ms step_avg:92.64ms +step:791/1670 train_time:73274ms step_avg:92.63ms +step:792/1670 train_time:73366ms step_avg:92.63ms +step:793/1670 train_time:73459ms step_avg:92.63ms +step:794/1670 train_time:73552ms step_avg:92.63ms +step:795/1670 train_time:73645ms step_avg:92.64ms +step:796/1670 train_time:73738ms step_avg:92.64ms +step:797/1670 train_time:73831ms step_avg:92.64ms +step:798/1670 train_time:73923ms step_avg:92.64ms +step:799/1670 train_time:74016ms step_avg:92.64ms +step:800/1670 train_time:74109ms step_avg:92.64ms +step:801/1670 train_time:74202ms step_avg:92.64ms +step:802/1670 train_time:74294ms step_avg:92.64ms +step:803/1670 train_time:74386ms step_avg:92.63ms +step:804/1670 train_time:74479ms step_avg:92.64ms +step:805/1670 train_time:74572ms step_avg:92.64ms +step:806/1670 train_time:74665ms step_avg:92.64ms +step:807/1670 train_time:74759ms step_avg:92.64ms +step:808/1670 train_time:74852ms step_avg:92.64ms +step:809/1670 train_time:74945ms step_avg:92.64ms +step:810/1670 train_time:75038ms step_avg:92.64ms +step:811/1670 train_time:75131ms step_avg:92.64ms +step:812/1670 train_time:75223ms step_avg:92.64ms +step:813/1670 train_time:75316ms step_avg:92.64ms +step:814/1670 train_time:75409ms step_avg:92.64ms +step:815/1670 train_time:75502ms step_avg:92.64ms +step:816/1670 train_time:75594ms step_avg:92.64ms +step:817/1670 train_time:75686ms step_avg:92.64ms +step:818/1670 train_time:75780ms step_avg:92.64ms +step:819/1670 train_time:75873ms step_avg:92.64ms +step:820/1670 train_time:75965ms step_avg:92.64ms +step:821/1670 train_time:76058ms step_avg:92.64ms +step:822/1670 train_time:76151ms step_avg:92.64ms +step:823/1670 train_time:76243ms step_avg:92.64ms +step:824/1670 train_time:76336ms step_avg:92.64ms +step:825/1670 train_time:76428ms step_avg:92.64ms +step:826/1670 train_time:76521ms step_avg:92.64ms +step:827/1670 train_time:76614ms step_avg:92.64ms +step:828/1670 train_time:76706ms step_avg:92.64ms +step:829/1670 train_time:76800ms step_avg:92.64ms +step:830/1670 train_time:76892ms step_avg:92.64ms +step:831/1670 train_time:76985ms step_avg:92.64ms +step:832/1670 train_time:77078ms step_avg:92.64ms +step:833/1670 train_time:77171ms step_avg:92.64ms +step:834/1670 train_time:77264ms step_avg:92.64ms +step:835/1670 train_time:77357ms step_avg:92.64ms +step:836/1670 train_time:77450ms step_avg:92.64ms +step:837/1670 train_time:77542ms step_avg:92.64ms +step:838/1670 train_time:77634ms step_avg:92.64ms +step:839/1670 train_time:77726ms step_avg:92.64ms +step:840/1670 train_time:77819ms step_avg:92.64ms +step:841/1670 train_time:77911ms step_avg:92.64ms +step:842/1670 train_time:78004ms step_avg:92.64ms +step:843/1670 train_time:78097ms step_avg:92.64ms +step:844/1670 train_time:78190ms step_avg:92.64ms +step:845/1670 train_time:78284ms step_avg:92.64ms +step:846/1670 train_time:78377ms step_avg:92.64ms +step:847/1670 train_time:78470ms step_avg:92.64ms +step:848/1670 train_time:78562ms step_avg:92.64ms +step:849/1670 train_time:78655ms step_avg:92.64ms +step:850/1670 train_time:78747ms step_avg:92.64ms +step:851/1670 train_time:78998ms step_avg:92.83ms +step:852/1670 train_time:79068ms step_avg:92.80ms +step:853/1670 train_time:79160ms 
step_avg:92.80ms +step:854/1670 train_time:79252ms step_avg:92.80ms +step:855/1670 train_time:79343ms step_avg:92.80ms +step:856/1670 train_time:79435ms step_avg:92.80ms +step:857/1670 train_time:79526ms step_avg:92.80ms +step:858/1670 train_time:79617ms step_avg:92.79ms +step:859/1670 train_time:79708ms step_avg:92.79ms +step:860/1670 train_time:79800ms step_avg:92.79ms +step:861/1670 train_time:79896ms step_avg:92.79ms +step:862/1670 train_time:79993ms step_avg:92.80ms +step:863/1670 train_time:80087ms step_avg:92.80ms +step:864/1670 train_time:80180ms step_avg:92.80ms +step:865/1670 train_time:80272ms step_avg:92.80ms +step:866/1670 train_time:80363ms step_avg:92.80ms +step:867/1670 train_time:80455ms step_avg:92.80ms +step:868/1670 train_time:80546ms step_avg:92.79ms +step:869/1670 train_time:80638ms step_avg:92.79ms +step:870/1670 train_time:80729ms step_avg:92.79ms +step:871/1670 train_time:80823ms step_avg:92.79ms +step:872/1670 train_time:80919ms step_avg:92.80ms +step:873/1670 train_time:81014ms step_avg:92.80ms +step:874/1670 train_time:81107ms step_avg:92.80ms +step:875/1670 train_time:81200ms step_avg:92.80ms +step:875/1670 val_loss:3.5158 train_time:81293ms step_avg:92.91ms +step:876/1670 train_time:81313ms step_avg:92.82ms +step:877/1670 train_time:81387ms step_avg:92.80ms +step:878/1670 train_time:81482ms step_avg:92.80ms +step:879/1670 train_time:81574ms step_avg:92.80ms +step:880/1670 train_time:81666ms step_avg:92.80ms +step:881/1670 train_time:81757ms step_avg:92.80ms +step:882/1670 train_time:81849ms step_avg:92.80ms +step:883/1670 train_time:81942ms step_avg:92.80ms +step:884/1670 train_time:82034ms step_avg:92.80ms +step:885/1670 train_time:82126ms step_avg:92.80ms +step:886/1670 train_time:82220ms step_avg:92.80ms +step:887/1670 train_time:82314ms step_avg:92.80ms +step:888/1670 train_time:82409ms step_avg:92.80ms +step:889/1670 train_time:82504ms step_avg:92.81ms +step:890/1670 train_time:82596ms step_avg:92.80ms +step:891/1670 train_time:82687ms step_avg:92.80ms +step:892/1670 train_time:82779ms step_avg:92.80ms +step:893/1670 train_time:82871ms step_avg:92.80ms +step:894/1670 train_time:82966ms step_avg:92.80ms +step:895/1670 train_time:83058ms step_avg:92.80ms +step:896/1670 train_time:83150ms step_avg:92.80ms +step:897/1670 train_time:83246ms step_avg:92.80ms +step:898/1670 train_time:83339ms step_avg:92.81ms +step:899/1670 train_time:83431ms step_avg:92.80ms +step:900/1670 train_time:83526ms step_avg:92.81ms +step:901/1670 train_time:83618ms step_avg:92.81ms +step:902/1670 train_time:83710ms step_avg:92.81ms +step:903/1670 train_time:83802ms step_avg:92.80ms +step:904/1670 train_time:83893ms step_avg:92.80ms +step:905/1670 train_time:83986ms step_avg:92.80ms +step:906/1670 train_time:84078ms step_avg:92.80ms +step:907/1670 train_time:84170ms step_avg:92.80ms +step:908/1670 train_time:84263ms step_avg:92.80ms +step:909/1670 train_time:84357ms step_avg:92.80ms +step:910/1670 train_time:84449ms step_avg:92.80ms +step:911/1670 train_time:84544ms step_avg:92.80ms +step:912/1670 train_time:84637ms step_avg:92.80ms +step:913/1670 train_time:84729ms step_avg:92.80ms +step:914/1670 train_time:84822ms step_avg:92.80ms +step:915/1670 train_time:84914ms step_avg:92.80ms +step:916/1670 train_time:85006ms step_avg:92.80ms +step:917/1670 train_time:85099ms step_avg:92.80ms +step:918/1670 train_time:85191ms step_avg:92.80ms +step:919/1670 train_time:85284ms step_avg:92.80ms +step:920/1670 train_time:85376ms step_avg:92.80ms +step:921/1670 train_time:85468ms step_avg:92.80ms 
+step:922/1670 train_time:85562ms step_avg:92.80ms +step:923/1670 train_time:85656ms step_avg:92.80ms +step:924/1670 train_time:85747ms step_avg:92.80ms +step:925/1670 train_time:85840ms step_avg:92.80ms +step:926/1670 train_time:85933ms step_avg:92.80ms +step:927/1670 train_time:86026ms step_avg:92.80ms +step:928/1670 train_time:86118ms step_avg:92.80ms +step:929/1670 train_time:86211ms step_avg:92.80ms +step:930/1670 train_time:86304ms step_avg:92.80ms +step:931/1670 train_time:86397ms step_avg:92.80ms +step:932/1670 train_time:86489ms step_avg:92.80ms +step:933/1670 train_time:86582ms step_avg:92.80ms +step:934/1670 train_time:86674ms step_avg:92.80ms +step:935/1670 train_time:86767ms step_avg:92.80ms +step:936/1670 train_time:86860ms step_avg:92.80ms +step:937/1670 train_time:86952ms step_avg:92.80ms +step:938/1670 train_time:87046ms step_avg:92.80ms +step:939/1670 train_time:87139ms step_avg:92.80ms +step:940/1670 train_time:87230ms step_avg:92.80ms +step:941/1670 train_time:87324ms step_avg:92.80ms +step:942/1670 train_time:87417ms step_avg:92.80ms +step:943/1670 train_time:87510ms step_avg:92.80ms +step:944/1670 train_time:87604ms step_avg:92.80ms +step:945/1670 train_time:87697ms step_avg:92.80ms +step:946/1670 train_time:87789ms step_avg:92.80ms +step:947/1670 train_time:87882ms step_avg:92.80ms +step:948/1670 train_time:87975ms step_avg:92.80ms +step:949/1670 train_time:88068ms step_avg:92.80ms +step:950/1670 train_time:88161ms step_avg:92.80ms +step:951/1670 train_time:88254ms step_avg:92.80ms +step:952/1670 train_time:88346ms step_avg:92.80ms +step:953/1670 train_time:88439ms step_avg:92.80ms +step:954/1670 train_time:88531ms step_avg:92.80ms +step:955/1670 train_time:88623ms step_avg:92.80ms +step:956/1670 train_time:88716ms step_avg:92.80ms +step:957/1670 train_time:88809ms step_avg:92.80ms +step:958/1670 train_time:88902ms step_avg:92.80ms +step:959/1670 train_time:88993ms step_avg:92.80ms +step:960/1670 train_time:89086ms step_avg:92.80ms +step:961/1670 train_time:89178ms step_avg:92.80ms +step:962/1670 train_time:89270ms step_avg:92.80ms +step:963/1670 train_time:89365ms step_avg:92.80ms +step:964/1670 train_time:89457ms step_avg:92.80ms +step:965/1670 train_time:89549ms step_avg:92.80ms +step:966/1670 train_time:89643ms step_avg:92.80ms +step:967/1670 train_time:89735ms step_avg:92.80ms +step:968/1670 train_time:89827ms step_avg:92.80ms +step:969/1670 train_time:89920ms step_avg:92.80ms +step:970/1670 train_time:90013ms step_avg:92.80ms +step:971/1670 train_time:90104ms step_avg:92.80ms +step:972/1670 train_time:90196ms step_avg:92.79ms +step:973/1670 train_time:90288ms step_avg:92.79ms +step:974/1670 train_time:90382ms step_avg:92.79ms +step:975/1670 train_time:90474ms step_avg:92.79ms +step:976/1670 train_time:90566ms step_avg:92.79ms +step:977/1670 train_time:90659ms step_avg:92.79ms +step:978/1670 train_time:90751ms step_avg:92.79ms +step:979/1670 train_time:90844ms step_avg:92.79ms +step:980/1670 train_time:90937ms step_avg:92.79ms +step:981/1670 train_time:91029ms step_avg:92.79ms +step:982/1670 train_time:91123ms step_avg:92.79ms +step:983/1670 train_time:91216ms step_avg:92.79ms +step:984/1670 train_time:91308ms step_avg:92.79ms +step:985/1670 train_time:91400ms step_avg:92.79ms +step:986/1670 train_time:91492ms step_avg:92.79ms +step:987/1670 train_time:91584ms step_avg:92.79ms +step:988/1670 train_time:91677ms step_avg:92.79ms +step:989/1670 train_time:91770ms step_avg:92.79ms +step:990/1670 train_time:91863ms step_avg:92.79ms +step:991/1670 train_time:91955ms 
step_avg:92.79ms +step:992/1670 train_time:92048ms step_avg:92.79ms +step:993/1670 train_time:92141ms step_avg:92.79ms +step:994/1670 train_time:92234ms step_avg:92.79ms +step:995/1670 train_time:92326ms step_avg:92.79ms +step:996/1670 train_time:92418ms step_avg:92.79ms +step:997/1670 train_time:92512ms step_avg:92.79ms +step:998/1670 train_time:92604ms step_avg:92.79ms +step:999/1670 train_time:92697ms step_avg:92.79ms +step:1000/1670 train_time:92789ms step_avg:92.79ms +step:1000/1670 val_loss:3.4659 train_time:92882ms step_avg:92.88ms +step:1001/1670 train_time:92902ms step_avg:92.81ms +step:1002/1670 train_time:92979ms step_avg:92.79ms +step:1003/1670 train_time:93071ms step_avg:92.79ms +step:1004/1670 train_time:93162ms step_avg:92.79ms +step:1005/1670 train_time:93254ms step_avg:92.79ms +step:1006/1670 train_time:93347ms step_avg:92.79ms +step:1007/1670 train_time:93440ms step_avg:92.79ms +step:1008/1670 train_time:93534ms step_avg:92.79ms +step:1009/1670 train_time:93625ms step_avg:92.79ms +step:1010/1670 train_time:93717ms step_avg:92.79ms +step:1011/1670 train_time:93810ms step_avg:92.79ms +step:1012/1670 train_time:93905ms step_avg:92.79ms +step:1013/1670 train_time:93998ms step_avg:92.79ms +step:1014/1670 train_time:94092ms step_avg:92.79ms +step:1015/1670 train_time:94183ms step_avg:92.79ms +step:1016/1670 train_time:94276ms step_avg:92.79ms +step:1017/1670 train_time:94368ms step_avg:92.79ms +step:1018/1670 train_time:94461ms step_avg:92.79ms +step:1019/1670 train_time:94553ms step_avg:92.79ms +step:1020/1670 train_time:94644ms step_avg:92.79ms +step:1021/1670 train_time:94736ms step_avg:92.79ms +step:1022/1670 train_time:94829ms step_avg:92.79ms +step:1023/1670 train_time:94922ms step_avg:92.79ms +step:1024/1670 train_time:95016ms step_avg:92.79ms +step:1025/1670 train_time:95108ms step_avg:92.79ms +step:1026/1670 train_time:95200ms step_avg:92.79ms +step:1027/1670 train_time:95293ms step_avg:92.79ms +step:1028/1670 train_time:95385ms step_avg:92.79ms +step:1029/1670 train_time:95477ms step_avg:92.79ms +step:1030/1670 train_time:95570ms step_avg:92.79ms +step:1031/1670 train_time:95663ms step_avg:92.79ms +step:1032/1670 train_time:95755ms step_avg:92.79ms +step:1033/1670 train_time:95848ms step_avg:92.79ms +step:1034/1670 train_time:95941ms step_avg:92.79ms +step:1035/1670 train_time:96034ms step_avg:92.79ms +step:1036/1670 train_time:96126ms step_avg:92.79ms +step:1037/1670 train_time:96219ms step_avg:92.79ms +step:1038/1670 train_time:96311ms step_avg:92.79ms +step:1039/1670 train_time:96403ms step_avg:92.78ms +step:1040/1670 train_time:96496ms step_avg:92.78ms +step:1041/1670 train_time:96588ms step_avg:92.78ms +step:1042/1670 train_time:96681ms step_avg:92.78ms +step:1043/1670 train_time:96775ms step_avg:92.79ms +step:1044/1670 train_time:96868ms step_avg:92.79ms +step:1045/1670 train_time:96960ms step_avg:92.78ms +step:1046/1670 train_time:97053ms step_avg:92.78ms +step:1047/1670 train_time:97145ms step_avg:92.78ms +step:1048/1670 train_time:97238ms step_avg:92.78ms +step:1049/1670 train_time:97331ms step_avg:92.78ms +step:1050/1670 train_time:97423ms step_avg:92.78ms +step:1051/1670 train_time:97516ms step_avg:92.78ms +step:1052/1670 train_time:97608ms step_avg:92.78ms +step:1053/1670 train_time:97701ms step_avg:92.78ms +step:1054/1670 train_time:97794ms step_avg:92.78ms +step:1055/1670 train_time:97886ms step_avg:92.78ms +step:1056/1670 train_time:97979ms step_avg:92.78ms +step:1057/1670 train_time:98072ms step_avg:92.78ms +step:1058/1670 train_time:98165ms 
step_avg:92.78ms +step:1059/1670 train_time:98260ms step_avg:92.79ms +step:1060/1670 train_time:98352ms step_avg:92.78ms +step:1061/1670 train_time:98444ms step_avg:92.78ms +step:1062/1670 train_time:98698ms step_avg:92.94ms +step:1063/1670 train_time:98766ms step_avg:92.91ms +step:1064/1670 train_time:98858ms step_avg:92.91ms +step:1065/1670 train_time:98949ms step_avg:92.91ms +step:1066/1670 train_time:99040ms step_avg:92.91ms +step:1067/1670 train_time:99131ms step_avg:92.91ms +step:1068/1670 train_time:99222ms step_avg:92.90ms +step:1069/1670 train_time:99313ms step_avg:92.90ms +step:1070/1670 train_time:99404ms step_avg:92.90ms +step:1071/1670 train_time:99496ms step_avg:92.90ms +step:1072/1670 train_time:99592ms step_avg:92.90ms +step:1073/1670 train_time:99690ms step_avg:92.91ms +step:1074/1670 train_time:99784ms step_avg:92.91ms +step:1075/1670 train_time:99876ms step_avg:92.91ms +step:1076/1670 train_time:99968ms step_avg:92.91ms +step:1077/1670 train_time:100060ms step_avg:92.91ms +step:1078/1670 train_time:100151ms step_avg:92.90ms +step:1079/1670 train_time:100242ms step_avg:92.90ms +step:1080/1670 train_time:100334ms step_avg:92.90ms +step:1081/1670 train_time:100425ms step_avg:92.90ms +step:1082/1670 train_time:100519ms step_avg:92.90ms +step:1083/1670 train_time:100615ms step_avg:92.90ms +step:1084/1670 train_time:100708ms step_avg:92.90ms +step:1085/1670 train_time:100803ms step_avg:92.91ms +step:1086/1670 train_time:100898ms step_avg:92.91ms +step:1087/1670 train_time:100991ms step_avg:92.91ms +step:1088/1670 train_time:101082ms step_avg:92.91ms +step:1089/1670 train_time:101174ms step_avg:92.91ms +step:1090/1670 train_time:101265ms step_avg:92.90ms +step:1091/1670 train_time:101357ms step_avg:92.90ms +step:1092/1670 train_time:101448ms step_avg:92.90ms +step:1093/1670 train_time:101541ms step_avg:92.90ms +step:1094/1670 train_time:101634ms step_avg:92.90ms +step:1095/1670 train_time:101728ms step_avg:92.90ms +step:1096/1670 train_time:101820ms step_avg:92.90ms +step:1097/1670 train_time:101913ms step_avg:92.90ms +step:1098/1670 train_time:102006ms step_avg:92.90ms +step:1099/1670 train_time:102099ms step_avg:92.90ms +step:1100/1670 train_time:102191ms step_avg:92.90ms +step:1101/1670 train_time:102282ms step_avg:92.90ms +step:1102/1670 train_time:102374ms step_avg:92.90ms +step:1103/1670 train_time:102466ms step_avg:92.90ms +step:1104/1670 train_time:102560ms step_avg:92.90ms +step:1105/1670 train_time:102653ms step_avg:92.90ms +step:1106/1670 train_time:102745ms step_avg:92.90ms +step:1107/1670 train_time:102838ms step_avg:92.90ms +step:1108/1670 train_time:102931ms step_avg:92.90ms +step:1109/1670 train_time:103023ms step_avg:92.90ms +step:1110/1670 train_time:103116ms step_avg:92.90ms +step:1111/1670 train_time:103208ms step_avg:92.90ms +step:1112/1670 train_time:103302ms step_avg:92.90ms +step:1113/1670 train_time:103394ms step_avg:92.90ms +step:1114/1670 train_time:103485ms step_avg:92.90ms +step:1115/1670 train_time:103773ms step_avg:93.07ms +step:1116/1670 train_time:103842ms step_avg:93.05ms +step:1117/1670 train_time:103933ms step_avg:93.05ms +step:1118/1670 train_time:104025ms step_avg:93.05ms +step:1119/1670 train_time:104117ms step_avg:93.04ms +step:1120/1670 train_time:104208ms step_avg:93.04ms +step:1121/1670 train_time:104300ms step_avg:93.04ms +step:1122/1670 train_time:104392ms step_avg:93.04ms +step:1123/1670 train_time:104484ms step_avg:93.04ms +step:1124/1670 train_time:104576ms step_avg:93.04ms +step:1125/1670 train_time:104673ms step_avg:93.04ms 
+step:1125/1670 val_loss:3.4132 train_time:104773ms step_avg:93.13ms +step:1126/1670 train_time:104793ms step_avg:93.07ms +step:1127/1670 train_time:104875ms step_avg:93.06ms +step:1128/1670 train_time:104977ms step_avg:93.06ms +step:1129/1670 train_time:105069ms step_avg:93.06ms +step:1130/1670 train_time:105161ms step_avg:93.06ms +step:1131/1670 train_time:105253ms step_avg:93.06ms +step:1132/1670 train_time:105344ms step_avg:93.06ms +step:1133/1670 train_time:105436ms step_avg:93.06ms +step:1134/1670 train_time:105528ms step_avg:93.06ms +step:1135/1670 train_time:105622ms step_avg:93.06ms +step:1136/1670 train_time:105717ms step_avg:93.06ms +step:1137/1670 train_time:105812ms step_avg:93.06ms +step:1138/1670 train_time:105909ms step_avg:93.07ms +step:1139/1670 train_time:106004ms step_avg:93.07ms +step:1140/1670 train_time:106097ms step_avg:93.07ms +step:1141/1670 train_time:106189ms step_avg:93.07ms +step:1142/1670 train_time:106281ms step_avg:93.07ms +step:1143/1670 train_time:106373ms step_avg:93.06ms +step:1144/1670 train_time:106466ms step_avg:93.06ms +step:1145/1670 train_time:106558ms step_avg:93.06ms +step:1146/1670 train_time:106649ms step_avg:93.06ms +step:1147/1670 train_time:106744ms step_avg:93.06ms +step:1148/1670 train_time:106838ms step_avg:93.06ms +step:1149/1670 train_time:106932ms step_avg:93.07ms +step:1150/1670 train_time:107026ms step_avg:93.07ms +step:1151/1670 train_time:107120ms step_avg:93.07ms +step:1152/1670 train_time:107213ms step_avg:93.07ms +step:1153/1670 train_time:107305ms step_avg:93.07ms +step:1154/1670 train_time:107398ms step_avg:93.07ms +step:1155/1670 train_time:107491ms step_avg:93.07ms +step:1156/1670 train_time:107583ms step_avg:93.07ms +step:1157/1670 train_time:107676ms step_avg:93.07ms +step:1158/1670 train_time:107770ms step_avg:93.07ms +step:1159/1670 train_time:107866ms step_avg:93.07ms +step:1160/1670 train_time:107961ms step_avg:93.07ms +step:1161/1670 train_time:108055ms step_avg:93.07ms +step:1162/1670 train_time:108147ms step_avg:93.07ms +step:1163/1670 train_time:108241ms step_avg:93.07ms +step:1164/1670 train_time:108334ms step_avg:93.07ms +step:1165/1670 train_time:108427ms step_avg:93.07ms +step:1166/1670 train_time:108520ms step_avg:93.07ms +step:1167/1670 train_time:108612ms step_avg:93.07ms +step:1168/1670 train_time:108706ms step_avg:93.07ms +step:1169/1670 train_time:108801ms step_avg:93.07ms +step:1170/1670 train_time:108895ms step_avg:93.07ms +step:1171/1670 train_time:108989ms step_avg:93.07ms +step:1172/1670 train_time:109083ms step_avg:93.07ms +step:1173/1670 train_time:109176ms step_avg:93.07ms +step:1174/1670 train_time:109268ms step_avg:93.07ms +step:1175/1670 train_time:109361ms step_avg:93.07ms +step:1176/1670 train_time:109454ms step_avg:93.07ms +step:1177/1670 train_time:109547ms step_avg:93.07ms +step:1178/1670 train_time:109640ms step_avg:93.07ms +step:1179/1670 train_time:109733ms step_avg:93.07ms +step:1180/1670 train_time:109827ms step_avg:93.07ms +step:1181/1670 train_time:109921ms step_avg:93.07ms +step:1182/1670 train_time:110014ms step_avg:93.07ms +step:1183/1670 train_time:110107ms step_avg:93.07ms +step:1184/1670 train_time:110201ms step_avg:93.08ms +step:1185/1670 train_time:110294ms step_avg:93.07ms +step:1186/1670 train_time:110386ms step_avg:93.07ms +step:1187/1670 train_time:110478ms step_avg:93.07ms +step:1188/1670 train_time:110570ms step_avg:93.07ms +step:1189/1670 train_time:110664ms step_avg:93.07ms +step:1190/1670 train_time:110757ms step_avg:93.07ms +step:1191/1670 train_time:110851ms 
step_avg:93.07ms +step:1192/1670 train_time:110946ms step_avg:93.08ms +step:1193/1670 train_time:111039ms step_avg:93.08ms +step:1194/1670 train_time:111132ms step_avg:93.08ms +step:1195/1670 train_time:111225ms step_avg:93.08ms +step:1196/1670 train_time:111318ms step_avg:93.08ms +step:1197/1670 train_time:111411ms step_avg:93.08ms +step:1198/1670 train_time:111504ms step_avg:93.08ms +step:1199/1670 train_time:111597ms step_avg:93.07ms +step:1200/1670 train_time:111690ms step_avg:93.07ms +step:1201/1670 train_time:111783ms step_avg:93.08ms +step:1202/1670 train_time:111877ms step_avg:93.08ms +step:1203/1670 train_time:111971ms step_avg:93.08ms +step:1204/1670 train_time:112064ms step_avg:93.08ms +step:1205/1670 train_time:112158ms step_avg:93.08ms +step:1206/1670 train_time:112251ms step_avg:93.08ms +step:1207/1670 train_time:112345ms step_avg:93.08ms +step:1208/1670 train_time:112438ms step_avg:93.08ms +step:1209/1670 train_time:112531ms step_avg:93.08ms +step:1210/1670 train_time:112624ms step_avg:93.08ms +step:1211/1670 train_time:112718ms step_avg:93.08ms +step:1212/1670 train_time:112811ms step_avg:93.08ms +step:1213/1670 train_time:112904ms step_avg:93.08ms +step:1214/1670 train_time:112997ms step_avg:93.08ms +step:1215/1670 train_time:113090ms step_avg:93.08ms +step:1216/1670 train_time:113184ms step_avg:93.08ms +step:1217/1670 train_time:113277ms step_avg:93.08ms +step:1218/1670 train_time:113370ms step_avg:93.08ms +step:1219/1670 train_time:113464ms step_avg:93.08ms +step:1220/1670 train_time:113557ms step_avg:93.08ms +step:1221/1670 train_time:113649ms step_avg:93.08ms +step:1222/1670 train_time:113742ms step_avg:93.08ms +step:1223/1670 train_time:113835ms step_avg:93.08ms +step:1224/1670 train_time:113929ms step_avg:93.08ms +step:1225/1670 train_time:114024ms step_avg:93.08ms +step:1226/1670 train_time:114117ms step_avg:93.08ms +step:1227/1670 train_time:114210ms step_avg:93.08ms +step:1228/1670 train_time:114302ms step_avg:93.08ms +step:1229/1670 train_time:114395ms step_avg:93.08ms +step:1230/1670 train_time:114487ms step_avg:93.08ms +step:1231/1670 train_time:114581ms step_avg:93.08ms +step:1232/1670 train_time:114674ms step_avg:93.08ms +step:1233/1670 train_time:114766ms step_avg:93.08ms +step:1234/1670 train_time:114860ms step_avg:93.08ms +step:1235/1670 train_time:114953ms step_avg:93.08ms +step:1236/1670 train_time:115047ms step_avg:93.08ms +step:1237/1670 train_time:115141ms step_avg:93.08ms +step:1238/1670 train_time:115234ms step_avg:93.08ms +step:1239/1670 train_time:115327ms step_avg:93.08ms +step:1240/1670 train_time:115420ms step_avg:93.08ms +step:1241/1670 train_time:115512ms step_avg:93.08ms +step:1242/1670 train_time:115606ms step_avg:93.08ms +step:1243/1670 train_time:115699ms step_avg:93.08ms +step:1244/1670 train_time:115792ms step_avg:93.08ms +step:1245/1670 train_time:115885ms step_avg:93.08ms +step:1246/1670 train_time:115978ms step_avg:93.08ms +step:1247/1670 train_time:116070ms step_avg:93.08ms +step:1248/1670 train_time:116165ms step_avg:93.08ms +step:1249/1670 train_time:116258ms step_avg:93.08ms +step:1250/1670 train_time:116350ms step_avg:93.08ms +step:1250/1670 val_loss:3.3746 train_time:116443ms step_avg:93.15ms +step:1251/1670 train_time:116464ms step_avg:93.10ms +step:1252/1670 train_time:116540ms step_avg:93.08ms +step:1253/1670 train_time:116632ms step_avg:93.08ms +step:1254/1670 train_time:116724ms step_avg:93.08ms +step:1255/1670 train_time:116817ms step_avg:93.08ms +step:1256/1670 train_time:116911ms step_avg:93.08ms +step:1257/1670 
train_time:117004ms step_avg:93.08ms +step:1258/1670 train_time:117096ms step_avg:93.08ms +step:1259/1670 train_time:117190ms step_avg:93.08ms +step:1260/1670 train_time:117283ms step_avg:93.08ms +step:1261/1670 train_time:117378ms step_avg:93.08ms +step:1262/1670 train_time:117475ms step_avg:93.09ms +step:1263/1670 train_time:117569ms step_avg:93.09ms +step:1264/1670 train_time:117662ms step_avg:93.09ms +step:1265/1670 train_time:117754ms step_avg:93.09ms +step:1266/1670 train_time:117846ms step_avg:93.09ms +step:1267/1670 train_time:117939ms step_avg:93.09ms +step:1268/1670 train_time:118031ms step_avg:93.08ms +step:1269/1670 train_time:118124ms step_avg:93.08ms +step:1270/1670 train_time:118216ms step_avg:93.08ms +step:1271/1670 train_time:118310ms step_avg:93.08ms +step:1272/1670 train_time:118405ms step_avg:93.09ms +step:1273/1670 train_time:118499ms step_avg:93.09ms +step:1274/1670 train_time:118736ms step_avg:93.20ms +step:1275/1670 train_time:118823ms step_avg:93.19ms +step:1276/1670 train_time:118914ms step_avg:93.19ms +step:1277/1670 train_time:119005ms step_avg:93.19ms +step:1278/1670 train_time:119097ms step_avg:93.19ms +step:1279/1670 train_time:119189ms step_avg:93.19ms +step:1280/1670 train_time:119281ms step_avg:93.19ms +step:1281/1670 train_time:119373ms step_avg:93.19ms +step:1282/1670 train_time:119465ms step_avg:93.19ms +step:1283/1670 train_time:119557ms step_avg:93.19ms +step:1284/1670 train_time:119655ms step_avg:93.19ms +step:1285/1670 train_time:119755ms step_avg:93.19ms +step:1286/1670 train_time:119850ms step_avg:93.20ms +step:1287/1670 train_time:119943ms step_avg:93.20ms +step:1288/1670 train_time:120035ms step_avg:93.20ms +step:1289/1670 train_time:120128ms step_avg:93.19ms +step:1290/1670 train_time:120220ms step_avg:93.19ms +step:1291/1670 train_time:120312ms step_avg:93.19ms +step:1292/1670 train_time:120405ms step_avg:93.19ms +step:1293/1670 train_time:120497ms step_avg:93.19ms +step:1294/1670 train_time:120593ms step_avg:93.19ms +step:1295/1670 train_time:120688ms step_avg:93.20ms +step:1296/1670 train_time:120785ms step_avg:93.20ms +step:1297/1670 train_time:120878ms step_avg:93.20ms +step:1298/1670 train_time:120972ms step_avg:93.20ms +step:1299/1670 train_time:121064ms step_avg:93.20ms +step:1300/1670 train_time:121157ms step_avg:93.20ms +step:1301/1670 train_time:121249ms step_avg:93.20ms +step:1302/1670 train_time:121341ms step_avg:93.20ms +step:1303/1670 train_time:121433ms step_avg:93.19ms +step:1304/1670 train_time:121525ms step_avg:93.19ms +step:1305/1670 train_time:121618ms step_avg:93.19ms +step:1306/1670 train_time:121714ms step_avg:93.20ms +step:1307/1670 train_time:121808ms step_avg:93.20ms +step:1308/1670 train_time:121901ms step_avg:93.20ms +step:1309/1670 train_time:121995ms step_avg:93.20ms +step:1310/1670 train_time:122088ms step_avg:93.20ms +step:1311/1670 train_time:122180ms step_avg:93.20ms +step:1312/1670 train_time:122273ms step_avg:93.20ms +step:1313/1670 train_time:122365ms step_avg:93.20ms +step:1314/1670 train_time:122459ms step_avg:93.20ms +step:1315/1670 train_time:122551ms step_avg:93.19ms +step:1316/1670 train_time:122645ms step_avg:93.20ms +step:1317/1670 train_time:122739ms step_avg:93.20ms +step:1318/1670 train_time:122836ms step_avg:93.20ms +step:1319/1670 train_time:122929ms step_avg:93.20ms +step:1320/1670 train_time:123022ms step_avg:93.20ms +step:1321/1670 train_time:123115ms step_avg:93.20ms +step:1322/1670 train_time:123207ms step_avg:93.20ms +step:1323/1670 train_time:123301ms step_avg:93.20ms +step:1324/1670 
train_time:123393ms step_avg:93.20ms +step:1325/1670 train_time:123486ms step_avg:93.20ms +step:1326/1670 train_time:123579ms step_avg:93.20ms +step:1327/1670 train_time:123673ms step_avg:93.20ms +step:1328/1670 train_time:123766ms step_avg:93.20ms +step:1329/1670 train_time:123859ms step_avg:93.20ms +step:1330/1670 train_time:123955ms step_avg:93.20ms +step:1331/1670 train_time:124048ms step_avg:93.20ms +step:1332/1670 train_time:124140ms step_avg:93.20ms +step:1333/1670 train_time:124233ms step_avg:93.20ms +step:1334/1670 train_time:124325ms step_avg:93.20ms +step:1335/1670 train_time:124418ms step_avg:93.20ms +step:1336/1670 train_time:124512ms step_avg:93.20ms +step:1337/1670 train_time:124605ms step_avg:93.20ms +step:1338/1670 train_time:124698ms step_avg:93.20ms +step:1339/1670 train_time:124793ms step_avg:93.20ms +step:1340/1670 train_time:124887ms step_avg:93.20ms +step:1341/1670 train_time:124980ms step_avg:93.20ms +step:1342/1670 train_time:125074ms step_avg:93.20ms +step:1343/1670 train_time:125166ms step_avg:93.20ms +step:1344/1670 train_time:125258ms step_avg:93.20ms +step:1345/1670 train_time:125352ms step_avg:93.20ms +step:1346/1670 train_time:125445ms step_avg:93.20ms +step:1347/1670 train_time:125538ms step_avg:93.20ms +step:1348/1670 train_time:125632ms step_avg:93.20ms +step:1349/1670 train_time:125725ms step_avg:93.20ms +step:1350/1670 train_time:125818ms step_avg:93.20ms +step:1351/1670 train_time:125912ms step_avg:93.20ms +step:1352/1670 train_time:126006ms step_avg:93.20ms +step:1353/1670 train_time:126099ms step_avg:93.20ms +step:1354/1670 train_time:126193ms step_avg:93.20ms +step:1355/1670 train_time:126285ms step_avg:93.20ms +step:1356/1670 train_time:126378ms step_avg:93.20ms +step:1357/1670 train_time:126472ms step_avg:93.20ms +step:1358/1670 train_time:126565ms step_avg:93.20ms +step:1359/1670 train_time:126657ms step_avg:93.20ms +step:1360/1670 train_time:126750ms step_avg:93.20ms +step:1361/1670 train_time:126842ms step_avg:93.20ms +step:1362/1670 train_time:126936ms step_avg:93.20ms +step:1363/1670 train_time:127031ms step_avg:93.20ms +step:1364/1670 train_time:127125ms step_avg:93.20ms +step:1365/1670 train_time:127218ms step_avg:93.20ms +step:1366/1670 train_time:127312ms step_avg:93.20ms +step:1367/1670 train_time:127406ms step_avg:93.20ms +step:1368/1670 train_time:127499ms step_avg:93.20ms +step:1369/1670 train_time:127593ms step_avg:93.20ms +step:1370/1670 train_time:127686ms step_avg:93.20ms +step:1371/1670 train_time:127780ms step_avg:93.20ms +step:1372/1670 train_time:127873ms step_avg:93.20ms +step:1373/1670 train_time:127967ms step_avg:93.20ms +step:1374/1670 train_time:128060ms step_avg:93.20ms +step:1375/1670 train_time:128154ms step_avg:93.20ms +step:1375/1670 val_loss:3.3400 train_time:128245ms step_avg:93.27ms +step:1376/1670 train_time:128265ms step_avg:93.22ms +step:1377/1670 train_time:128340ms step_avg:93.20ms +step:1378/1670 train_time:128433ms step_avg:93.20ms +step:1379/1670 train_time:128525ms step_avg:93.20ms +step:1380/1670 train_time:128621ms step_avg:93.20ms +step:1381/1670 train_time:128714ms step_avg:93.20ms +step:1382/1670 train_time:128806ms step_avg:93.20ms +step:1383/1670 train_time:128899ms step_avg:93.20ms +step:1384/1670 train_time:128993ms step_avg:93.20ms +step:1385/1670 train_time:129086ms step_avg:93.20ms +step:1386/1670 train_time:129179ms step_avg:93.20ms +step:1387/1670 train_time:129274ms step_avg:93.20ms +step:1388/1670 train_time:129367ms step_avg:93.20ms +step:1389/1670 train_time:129461ms step_avg:93.20ms 
+step:1390/1670 train_time:129553ms step_avg:93.20ms +step:1391/1670 train_time:129646ms step_avg:93.20ms +step:1392/1670 train_time:129740ms step_avg:93.20ms +step:1393/1670 train_time:129834ms step_avg:93.20ms +step:1394/1670 train_time:129926ms step_avg:93.20ms +step:1395/1670 train_time:130019ms step_avg:93.20ms +step:1396/1670 train_time:130112ms step_avg:93.20ms +step:1397/1670 train_time:130206ms step_avg:93.20ms +step:1398/1670 train_time:130300ms step_avg:93.20ms +step:1399/1670 train_time:130393ms step_avg:93.20ms +step:1400/1670 train_time:130486ms step_avg:93.20ms +step:1401/1670 train_time:130579ms step_avg:93.20ms +step:1402/1670 train_time:130672ms step_avg:93.20ms +step:1403/1670 train_time:130765ms step_avg:93.20ms +step:1404/1670 train_time:130860ms step_avg:93.20ms +step:1405/1670 train_time:130954ms step_avg:93.21ms +step:1406/1670 train_time:131047ms step_avg:93.21ms +step:1407/1670 train_time:131140ms step_avg:93.21ms +step:1408/1670 train_time:131233ms step_avg:93.21ms +step:1409/1670 train_time:131327ms step_avg:93.21ms +step:1410/1670 train_time:131420ms step_avg:93.21ms +step:1411/1670 train_time:131514ms step_avg:93.21ms +step:1412/1670 train_time:131607ms step_avg:93.21ms +step:1413/1670 train_time:131700ms step_avg:93.21ms +step:1414/1670 train_time:131793ms step_avg:93.21ms +step:1415/1670 train_time:131888ms step_avg:93.21ms +step:1416/1670 train_time:131982ms step_avg:93.21ms +step:1417/1670 train_time:132075ms step_avg:93.21ms +step:1418/1670 train_time:132168ms step_avg:93.21ms +step:1419/1670 train_time:132262ms step_avg:93.21ms +step:1420/1670 train_time:132355ms step_avg:93.21ms +step:1421/1670 train_time:132449ms step_avg:93.21ms +step:1422/1670 train_time:132542ms step_avg:93.21ms +step:1423/1670 train_time:132635ms step_avg:93.21ms +step:1424/1670 train_time:132729ms step_avg:93.21ms +step:1425/1670 train_time:132822ms step_avg:93.21ms +step:1426/1670 train_time:132916ms step_avg:93.21ms +step:1427/1670 train_time:133009ms step_avg:93.21ms +step:1428/1670 train_time:133103ms step_avg:93.21ms +step:1429/1670 train_time:133196ms step_avg:93.21ms +step:1430/1670 train_time:133290ms step_avg:93.21ms +step:1431/1670 train_time:133384ms step_avg:93.21ms +step:1432/1670 train_time:133476ms step_avg:93.21ms +step:1433/1670 train_time:133569ms step_avg:93.21ms +step:1434/1670 train_time:133663ms step_avg:93.21ms +step:1435/1670 train_time:133756ms step_avg:93.21ms +step:1436/1670 train_time:133849ms step_avg:93.21ms +step:1437/1670 train_time:133942ms step_avg:93.21ms +step:1438/1670 train_time:134035ms step_avg:93.21ms +step:1439/1670 train_time:134129ms step_avg:93.21ms +step:1440/1670 train_time:134222ms step_avg:93.21ms +step:1441/1670 train_time:134317ms step_avg:93.21ms +step:1442/1670 train_time:134410ms step_avg:93.21ms +step:1443/1670 train_time:134503ms step_avg:93.21ms +step:1444/1670 train_time:134597ms step_avg:93.21ms +step:1445/1670 train_time:134690ms step_avg:93.21ms +step:1446/1670 train_time:134784ms step_avg:93.21ms +step:1447/1670 train_time:134877ms step_avg:93.21ms +step:1448/1670 train_time:134969ms step_avg:93.21ms +step:1449/1670 train_time:135063ms step_avg:93.21ms +step:1450/1670 train_time:135157ms step_avg:93.21ms +step:1451/1670 train_time:135251ms step_avg:93.21ms +step:1452/1670 train_time:135344ms step_avg:93.21ms +step:1453/1670 train_time:135438ms step_avg:93.21ms +step:1454/1670 train_time:135531ms step_avg:93.21ms +step:1455/1670 train_time:135624ms step_avg:93.21ms +step:1456/1670 train_time:135719ms step_avg:93.21ms 
+step:1457/1670 train_time:135811ms step_avg:93.21ms +step:1458/1670 train_time:135904ms step_avg:93.21ms +step:1459/1670 train_time:135998ms step_avg:93.21ms +step:1460/1670 train_time:136091ms step_avg:93.21ms +step:1461/1670 train_time:136184ms step_avg:93.21ms +step:1462/1670 train_time:136278ms step_avg:93.21ms +step:1463/1670 train_time:136370ms step_avg:93.21ms +step:1464/1670 train_time:136464ms step_avg:93.21ms +step:1465/1670 train_time:136558ms step_avg:93.21ms +step:1466/1670 train_time:136651ms step_avg:93.21ms +step:1467/1670 train_time:136744ms step_avg:93.21ms +step:1468/1670 train_time:136837ms step_avg:93.21ms +step:1469/1670 train_time:136930ms step_avg:93.21ms +step:1470/1670 train_time:137024ms step_avg:93.21ms +step:1471/1670 train_time:137117ms step_avg:93.21ms +step:1472/1670 train_time:137210ms step_avg:93.21ms +step:1473/1670 train_time:137303ms step_avg:93.21ms +step:1474/1670 train_time:137396ms step_avg:93.21ms +step:1475/1670 train_time:137490ms step_avg:93.21ms +step:1476/1670 train_time:137583ms step_avg:93.21ms +step:1477/1670 train_time:137677ms step_avg:93.21ms +step:1478/1670 train_time:137769ms step_avg:93.21ms +step:1479/1670 train_time:137863ms step_avg:93.21ms +step:1480/1670 train_time:137957ms step_avg:93.21ms +step:1481/1670 train_time:138051ms step_avg:93.21ms +step:1482/1670 train_time:138143ms step_avg:93.21ms +step:1483/1670 train_time:138237ms step_avg:93.21ms +step:1484/1670 train_time:138330ms step_avg:93.21ms +step:1485/1670 train_time:138580ms step_avg:93.32ms +step:1486/1670 train_time:138652ms step_avg:93.31ms +step:1487/1670 train_time:138744ms step_avg:93.30ms +step:1488/1670 train_time:138836ms step_avg:93.30ms +step:1489/1670 train_time:138928ms step_avg:93.30ms +step:1490/1670 train_time:139020ms step_avg:93.30ms +step:1491/1670 train_time:139112ms step_avg:93.30ms +step:1492/1670 train_time:139204ms step_avg:93.30ms +step:1493/1670 train_time:139296ms step_avg:93.30ms +step:1494/1670 train_time:139388ms step_avg:93.30ms +step:1495/1670 train_time:139486ms step_avg:93.30ms +step:1496/1670 train_time:139586ms step_avg:93.31ms +step:1497/1670 train_time:139681ms step_avg:93.31ms +step:1498/1670 train_time:139774ms step_avg:93.31ms +step:1499/1670 train_time:139866ms step_avg:93.31ms +step:1500/1670 train_time:139958ms step_avg:93.31ms +step:1500/1670 val_loss:3.3104 train_time:140052ms step_avg:93.37ms +step:1501/1670 train_time:140072ms step_avg:93.32ms +step:1502/1670 train_time:140146ms step_avg:93.31ms +step:1503/1670 train_time:140239ms step_avg:93.31ms +step:1504/1670 train_time:140331ms step_avg:93.31ms +step:1505/1670 train_time:140425ms step_avg:93.31ms +step:1506/1670 train_time:140517ms step_avg:93.30ms +step:1507/1670 train_time:140609ms step_avg:93.30ms +step:1508/1670 train_time:140705ms step_avg:93.31ms +step:1509/1670 train_time:140799ms step_avg:93.31ms +step:1510/1670 train_time:140893ms step_avg:93.31ms +step:1511/1670 train_time:140989ms step_avg:93.31ms +step:1512/1670 train_time:141084ms step_avg:93.31ms +step:1513/1670 train_time:141176ms step_avg:93.31ms +step:1514/1670 train_time:141270ms step_avg:93.31ms +step:1515/1670 train_time:141363ms step_avg:93.31ms +step:1516/1670 train_time:141458ms step_avg:93.31ms +step:1517/1670 train_time:141550ms step_avg:93.31ms +step:1518/1670 train_time:141643ms step_avg:93.31ms +step:1519/1670 train_time:141736ms step_avg:93.31ms +step:1520/1670 train_time:141829ms step_avg:93.31ms +step:1521/1670 train_time:141923ms step_avg:93.31ms +step:1522/1670 train_time:142016ms 
step_avg:93.31ms +step:1523/1670 train_time:142110ms step_avg:93.31ms +step:1524/1670 train_time:142204ms step_avg:93.31ms +step:1525/1670 train_time:142298ms step_avg:93.31ms +step:1526/1670 train_time:142390ms step_avg:93.31ms +step:1527/1670 train_time:142483ms step_avg:93.31ms +step:1528/1670 train_time:142577ms step_avg:93.31ms +step:1529/1670 train_time:142671ms step_avg:93.31ms +step:1530/1670 train_time:142764ms step_avg:93.31ms +step:1531/1670 train_time:142858ms step_avg:93.31ms +step:1532/1670 train_time:142950ms step_avg:93.31ms +step:1533/1670 train_time:143044ms step_avg:93.31ms +step:1534/1670 train_time:143138ms step_avg:93.31ms +step:1535/1670 train_time:143232ms step_avg:93.31ms +step:1536/1670 train_time:143325ms step_avg:93.31ms +step:1537/1670 train_time:143418ms step_avg:93.31ms +step:1538/1670 train_time:143510ms step_avg:93.31ms +step:1539/1670 train_time:143605ms step_avg:93.31ms +step:1540/1670 train_time:143699ms step_avg:93.31ms +step:1541/1670 train_time:143792ms step_avg:93.31ms +step:1542/1670 train_time:143886ms step_avg:93.31ms +step:1543/1670 train_time:143980ms step_avg:93.31ms +step:1544/1670 train_time:144073ms step_avg:93.31ms +step:1545/1670 train_time:144167ms step_avg:93.31ms +step:1546/1670 train_time:144260ms step_avg:93.31ms +step:1547/1670 train_time:144353ms step_avg:93.31ms +step:1548/1670 train_time:144446ms step_avg:93.31ms +step:1549/1670 train_time:144539ms step_avg:93.31ms +step:1550/1670 train_time:144632ms step_avg:93.31ms +step:1551/1670 train_time:144727ms step_avg:93.31ms +step:1552/1670 train_time:144820ms step_avg:93.31ms +step:1553/1670 train_time:144913ms step_avg:93.31ms +step:1554/1670 train_time:145009ms step_avg:93.31ms +step:1555/1670 train_time:145103ms step_avg:93.31ms +step:1556/1670 train_time:145195ms step_avg:93.31ms +step:1557/1670 train_time:145287ms step_avg:93.31ms +step:1558/1670 train_time:145380ms step_avg:93.31ms +step:1559/1670 train_time:145473ms step_avg:93.31ms +step:1560/1670 train_time:145566ms step_avg:93.31ms +step:1561/1670 train_time:145660ms step_avg:93.31ms +step:1562/1670 train_time:145753ms step_avg:93.31ms +step:1563/1670 train_time:145846ms step_avg:93.31ms +step:1564/1670 train_time:145940ms step_avg:93.31ms +step:1565/1670 train_time:146033ms step_avg:93.31ms +step:1566/1670 train_time:146127ms step_avg:93.31ms +step:1567/1670 train_time:146220ms step_avg:93.31ms +step:1568/1670 train_time:146312ms step_avg:93.31ms +step:1569/1670 train_time:146406ms step_avg:93.31ms +step:1570/1670 train_time:146500ms step_avg:93.31ms +step:1571/1670 train_time:146593ms step_avg:93.31ms +step:1572/1670 train_time:146686ms step_avg:93.31ms +step:1573/1670 train_time:146780ms step_avg:93.31ms +step:1574/1670 train_time:146873ms step_avg:93.31ms +step:1575/1670 train_time:146967ms step_avg:93.31ms +step:1576/1670 train_time:147061ms step_avg:93.31ms +step:1577/1670 train_time:147154ms step_avg:93.31ms +step:1578/1670 train_time:147247ms step_avg:93.31ms +step:1579/1670 train_time:147341ms step_avg:93.31ms +step:1580/1670 train_time:147434ms step_avg:93.31ms +step:1581/1670 train_time:147527ms step_avg:93.31ms +step:1582/1670 train_time:147621ms step_avg:93.31ms +step:1583/1670 train_time:147713ms step_avg:93.31ms +step:1584/1670 train_time:147807ms step_avg:93.31ms +step:1585/1670 train_time:147901ms step_avg:93.31ms +step:1586/1670 train_time:147994ms step_avg:93.31ms +step:1587/1670 train_time:148088ms step_avg:93.31ms +step:1588/1670 train_time:148182ms step_avg:93.31ms +step:1589/1670 train_time:148276ms 
step_avg:93.31ms +step:1590/1670 train_time:148368ms step_avg:93.31ms +step:1591/1670 train_time:148461ms step_avg:93.31ms +step:1592/1670 train_time:148555ms step_avg:93.31ms +step:1593/1670 train_time:148648ms step_avg:93.31ms +step:1594/1670 train_time:148742ms step_avg:93.31ms +step:1595/1670 train_time:148835ms step_avg:93.31ms +step:1596/1670 train_time:148928ms step_avg:93.31ms +step:1597/1670 train_time:149022ms step_avg:93.31ms +step:1598/1670 train_time:149115ms step_avg:93.31ms +step:1599/1670 train_time:149208ms step_avg:93.31ms +step:1600/1670 train_time:149302ms step_avg:93.31ms +step:1601/1670 train_time:149396ms step_avg:93.31ms +step:1602/1670 train_time:149488ms step_avg:93.31ms +step:1603/1670 train_time:149581ms step_avg:93.31ms +step:1604/1670 train_time:149674ms step_avg:93.31ms +step:1605/1670 train_time:149768ms step_avg:93.31ms +step:1606/1670 train_time:149861ms step_avg:93.31ms +step:1607/1670 train_time:149955ms step_avg:93.31ms +step:1608/1670 train_time:150048ms step_avg:93.31ms +step:1609/1670 train_time:150143ms step_avg:93.31ms +step:1610/1670 train_time:150236ms step_avg:93.31ms +step:1611/1670 train_time:150328ms step_avg:93.31ms +step:1612/1670 train_time:150421ms step_avg:93.31ms +step:1613/1670 train_time:150514ms step_avg:93.31ms +step:1614/1670 train_time:150608ms step_avg:93.31ms +step:1615/1670 train_time:150701ms step_avg:93.31ms +step:1616/1670 train_time:150795ms step_avg:93.31ms +step:1617/1670 train_time:150888ms step_avg:93.31ms +step:1618/1670 train_time:150980ms step_avg:93.31ms +step:1619/1670 train_time:151075ms step_avg:93.31ms +step:1620/1670 train_time:151168ms step_avg:93.31ms +step:1621/1670 train_time:151261ms step_avg:93.31ms +step:1622/1670 train_time:151354ms step_avg:93.31ms +step:1623/1670 train_time:151447ms step_avg:93.31ms +step:1624/1670 train_time:151541ms step_avg:93.31ms +step:1625/1670 train_time:151635ms step_avg:93.31ms +step:1625/1670 val_loss:3.2851 train_time:151727ms step_avg:93.37ms +step:1626/1670 train_time:151747ms step_avg:93.33ms +step:1627/1670 train_time:151822ms step_avg:93.31ms +step:1628/1670 train_time:151916ms step_avg:93.31ms +step:1629/1670 train_time:152008ms step_avg:93.31ms +step:1630/1670 train_time:152101ms step_avg:93.31ms +step:1631/1670 train_time:152194ms step_avg:93.31ms +step:1632/1670 train_time:152286ms step_avg:93.31ms +step:1633/1670 train_time:152379ms step_avg:93.31ms +step:1634/1670 train_time:152472ms step_avg:93.31ms +step:1635/1670 train_time:152565ms step_avg:93.31ms +step:1636/1670 train_time:152660ms step_avg:93.31ms +step:1637/1670 train_time:152754ms step_avg:93.31ms +step:1638/1670 train_time:152848ms step_avg:93.31ms +step:1639/1670 train_time:152942ms step_avg:93.31ms +step:1640/1670 train_time:153035ms step_avg:93.31ms +step:1641/1670 train_time:153128ms step_avg:93.31ms +step:1642/1670 train_time:153220ms step_avg:93.31ms +step:1643/1670 train_time:153314ms step_avg:93.31ms +step:1644/1670 train_time:153406ms step_avg:93.31ms +step:1645/1670 train_time:153499ms step_avg:93.31ms +step:1646/1670 train_time:153592ms step_avg:93.31ms +step:1647/1670 train_time:153686ms step_avg:93.31ms +step:1648/1670 train_time:153780ms step_avg:93.31ms +step:1649/1670 train_time:153874ms step_avg:93.31ms +step:1650/1670 train_time:153968ms step_avg:93.31ms +step:1651/1670 train_time:154061ms step_avg:93.31ms +step:1652/1670 train_time:154155ms step_avg:93.31ms +step:1653/1670 train_time:154247ms step_avg:93.31ms +step:1654/1670 train_time:154340ms step_avg:93.31ms +step:1655/1670 
train_time:154433ms step_avg:93.31ms +step:1656/1670 train_time:154526ms step_avg:93.31ms +step:1657/1670 train_time:154621ms step_avg:93.31ms +step:1658/1670 train_time:154715ms step_avg:93.31ms +step:1659/1670 train_time:154808ms step_avg:93.31ms +step:1660/1670 train_time:154900ms step_avg:93.31ms +step:1661/1670 train_time:154993ms step_avg:93.31ms +step:1662/1670 train_time:155087ms step_avg:93.31ms +step:1663/1670 train_time:155180ms step_avg:93.31ms +step:1664/1670 train_time:155275ms step_avg:93.31ms +step:1665/1670 train_time:155367ms step_avg:93.31ms +step:1666/1670 train_time:155461ms step_avg:93.31ms +step:1667/1670 train_time:155553ms step_avg:93.31ms +step:1668/1670 train_time:155646ms step_avg:93.31ms +step:1669/1670 train_time:155741ms step_avg:93.31ms +step:1670/1670 train_time:155835ms step_avg:93.31ms +step:1670/1670 val_loss:3.2767 train_time:156100ms step_avg:93.47ms +peak memory allocated: 32002 MiB reserved: 46934 MiB diff --git a/records/091125_VectSigmoidBFloat16/0d451b7e-6500-41ae-ac9d-02352e611b88.txt b/records/091125_VectSigmoidBFloat16/0d451b7e-6500-41ae-ac9d-02352e611b88.txt new file mode 100644 index 000000000..9690e3139 --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/0d451b7e-6500-41ae-ac9d-02352e611b88.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
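+
+# Minimal sanity-check sketch for newton_schulz_triton (illustrative only; the
+# helper below is hypothetical and is not used by the training loop): after the
+# five NS iterations the singular values of the norm-scaled input should all sit
+# near 1, so for a tall input matrix the output's columns are roughly orthonormal.
+def _newton_schulz_orthogonality_error(m: int = 256, n: int = 128) -> float:
+    G = torch.randn(m, n, device="cuda")
+    X = newton_schulz_triton(G).float()
+    # For m >= n we expect X.mT @ X ~ I_n, up to bfloat16 / NS-coefficient error.
+    return ((X.mT @ X - torch.eye(n, device="cuda")).norm() / n**0.5).item()
+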
+# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. 
+ eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. 
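+                # Net per-parameter update (applied just below):
+                #   W <- (1 - eff_weight_decay_val) * W - eff_lr_val * NS5(u)
+                # where u = grad.lerp(momentum_buffer, momentum) (Nesterov-style
+                # momentum) and NS5 is the Newton-Schulz orthogonalization above.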
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
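+        # e.g. next_multiple_of_n(50257, n=128) = 50304 (= 393 * 128): the extra
+        # 47 rows are padding for token ids the tokenizer never produces.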
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(
+            torch.bfloat16
+        )  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[
+            1 * len(self.blocks) : 3 * len(self.blocks)
+        ].view(-1, 2)
+        sa_lambdas = self.scalars[
+            3 * len(self.blocks) : 5 * len(self.blocks)
+        ].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(
+        str(file), False, 256, dtype=torch.int32
+    )  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(
+            num_tokens, dtype=torch.uint16, pin_memory=True
+        )  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(
+            tokens.numpy()
+        )  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, (
+            "number of tokens read does not match header"
+        )
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(
+                        f"Insufficient BOS ahead of index {idx}; hit tail of shard."
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:47:47 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 135W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 46C P0 126W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 46C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 41C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:296ms step_avg:295.61ms +step:2/1670 train_time:314ms step_avg:157.20ms +step:3/1670 train_time:382ms step_avg:127.43ms +step:4/1670 train_time:471ms step_avg:117.84ms +step:5/1670 train_time:562ms step_avg:112.38ms +step:6/1670 train_time:652ms step_avg:108.64ms +step:7/1670 train_time:743ms step_avg:106.07ms +step:8/1670 train_time:833ms step_avg:104.16ms +step:9/1670 train_time:924ms step_avg:102.68ms +step:10/1670 train_time:1015ms step_avg:101.46ms +step:11/1670 train_time:1106ms step_avg:100.51ms +step:12/1670 train_time:1199ms step_avg:99.91ms +step:13/1670 train_time:1294ms step_avg:99.50ms +step:14/1670 train_time:1386ms step_avg:99.02ms +step:15/1670 train_time:1478ms step_avg:98.50ms +step:16/1670 train_time:1570ms step_avg:98.10ms +step:17/1670 train_time:1662ms step_avg:97.75ms +step:18/1670 train_time:1753ms step_avg:97.36ms +step:19/1670 train_time:1843ms step_avg:97.02ms +step:20/1670 train_time:1934ms step_avg:96.69ms +step:21/1670 train_time:2026ms step_avg:96.48ms +step:22/1670 train_time:2118ms 
step_avg:96.26ms +step:23/1670 train_time:2211ms step_avg:96.12ms +step:24/1670 train_time:2304ms step_avg:96.01ms +step:25/1670 train_time:2396ms step_avg:95.86ms +step:26/1670 train_time:2488ms step_avg:95.69ms +step:27/1670 train_time:2579ms step_avg:95.53ms +step:28/1670 train_time:2671ms step_avg:95.39ms +step:29/1670 train_time:2763ms step_avg:95.29ms +step:30/1670 train_time:2854ms step_avg:95.13ms +step:31/1670 train_time:2945ms step_avg:95.01ms +step:32/1670 train_time:3036ms step_avg:94.87ms +step:33/1670 train_time:3128ms step_avg:94.78ms +step:34/1670 train_time:3220ms step_avg:94.69ms +step:35/1670 train_time:3312ms step_avg:94.62ms +step:36/1670 train_time:3405ms step_avg:94.57ms +step:37/1670 train_time:3496ms step_avg:94.48ms +step:38/1670 train_time:3587ms step_avg:94.40ms +step:39/1670 train_time:3679ms step_avg:94.34ms +step:40/1670 train_time:3772ms step_avg:94.30ms +step:41/1670 train_time:3864ms step_avg:94.25ms +step:42/1670 train_time:3955ms step_avg:94.18ms +step:43/1670 train_time:4048ms step_avg:94.13ms +step:44/1670 train_time:4140ms step_avg:94.09ms +step:45/1670 train_time:4231ms step_avg:94.02ms +step:46/1670 train_time:4322ms step_avg:93.97ms +step:47/1670 train_time:4414ms step_avg:93.91ms +step:48/1670 train_time:4505ms step_avg:93.86ms +step:49/1670 train_time:4597ms step_avg:93.82ms +step:50/1670 train_time:4689ms step_avg:93.78ms +step:51/1670 train_time:4780ms step_avg:93.73ms +step:52/1670 train_time:4872ms step_avg:93.69ms +step:53/1670 train_time:4964ms step_avg:93.66ms +step:54/1670 train_time:5055ms step_avg:93.62ms +step:55/1670 train_time:5149ms step_avg:93.61ms +step:56/1670 train_time:5241ms step_avg:93.59ms +step:57/1670 train_time:5333ms step_avg:93.56ms +step:58/1670 train_time:5425ms step_avg:93.54ms +step:59/1670 train_time:5516ms step_avg:93.50ms +step:60/1670 train_time:5608ms step_avg:93.46ms +step:61/1670 train_time:5699ms step_avg:93.42ms +step:62/1670 train_time:5790ms step_avg:93.39ms +step:63/1670 train_time:5881ms step_avg:93.35ms +step:64/1670 train_time:5972ms step_avg:93.32ms +step:65/1670 train_time:6064ms step_avg:93.29ms +step:66/1670 train_time:6155ms step_avg:93.26ms +step:67/1670 train_time:6248ms step_avg:93.25ms +step:68/1670 train_time:6339ms step_avg:93.23ms +step:69/1670 train_time:6432ms step_avg:93.21ms +step:70/1670 train_time:6524ms step_avg:93.20ms +step:71/1670 train_time:6615ms step_avg:93.18ms +step:72/1670 train_time:6706ms step_avg:93.14ms +step:73/1670 train_time:6798ms step_avg:93.12ms +step:74/1670 train_time:6890ms step_avg:93.11ms +step:75/1670 train_time:6982ms step_avg:93.09ms +step:76/1670 train_time:7073ms step_avg:93.07ms +step:77/1670 train_time:7166ms step_avg:93.07ms +step:78/1670 train_time:7258ms step_avg:93.05ms +step:79/1670 train_time:7350ms step_avg:93.04ms +step:80/1670 train_time:7443ms step_avg:93.03ms +step:81/1670 train_time:7534ms step_avg:93.01ms +step:82/1670 train_time:7626ms step_avg:93.00ms +step:83/1670 train_time:7717ms step_avg:92.98ms +step:84/1670 train_time:7809ms step_avg:92.97ms +step:85/1670 train_time:7900ms step_avg:92.94ms +step:86/1670 train_time:7991ms step_avg:92.92ms +step:87/1670 train_time:8083ms step_avg:92.91ms +step:88/1670 train_time:8174ms step_avg:92.88ms +step:89/1670 train_time:8267ms step_avg:92.89ms +step:90/1670 train_time:8360ms step_avg:92.89ms +step:91/1670 train_time:8452ms step_avg:92.88ms +step:92/1670 train_time:8544ms step_avg:92.87ms +step:93/1670 train_time:8635ms step_avg:92.85ms +step:94/1670 train_time:8727ms step_avg:92.84ms 
+step:95/1670 train_time:8820ms step_avg:92.84ms +step:96/1670 train_time:8911ms step_avg:92.82ms +step:97/1670 train_time:9002ms step_avg:92.81ms +step:98/1670 train_time:9093ms step_avg:92.79ms +step:99/1670 train_time:9185ms step_avg:92.78ms +step:100/1670 train_time:9277ms step_avg:92.77ms +step:101/1670 train_time:9370ms step_avg:92.77ms +step:102/1670 train_time:9462ms step_avg:92.76ms +step:103/1670 train_time:9552ms step_avg:92.74ms +step:104/1670 train_time:9644ms step_avg:92.73ms +step:105/1670 train_time:9734ms step_avg:92.71ms +step:106/1670 train_time:9826ms step_avg:92.70ms +step:107/1670 train_time:9917ms step_avg:92.68ms +step:108/1670 train_time:10009ms step_avg:92.67ms +step:109/1670 train_time:10100ms step_avg:92.66ms +step:110/1670 train_time:10192ms step_avg:92.65ms +step:111/1670 train_time:10283ms step_avg:92.64ms +step:112/1670 train_time:10373ms step_avg:92.62ms +step:113/1670 train_time:10466ms step_avg:92.62ms +step:114/1670 train_time:10558ms step_avg:92.61ms +step:115/1670 train_time:10650ms step_avg:92.61ms +step:116/1670 train_time:10741ms step_avg:92.60ms +step:117/1670 train_time:10832ms step_avg:92.58ms +step:118/1670 train_time:10924ms step_avg:92.57ms +step:119/1670 train_time:11014ms step_avg:92.56ms +step:120/1670 train_time:11106ms step_avg:92.55ms +step:121/1670 train_time:11198ms step_avg:92.54ms +step:122/1670 train_time:11289ms step_avg:92.53ms +step:123/1670 train_time:11381ms step_avg:92.52ms +step:124/1670 train_time:11472ms step_avg:92.51ms +step:125/1670 train_time:11563ms step_avg:92.50ms +step:125/1670 val_loss:4.3109 train_time:11654ms step_avg:93.23ms +step:126/1670 train_time:11674ms step_avg:92.65ms +step:127/1670 train_time:11750ms step_avg:92.52ms +step:128/1670 train_time:11851ms step_avg:92.59ms +step:129/1670 train_time:11945ms step_avg:92.59ms +step:130/1670 train_time:12036ms step_avg:92.58ms +step:131/1670 train_time:12127ms step_avg:92.57ms +step:132/1670 train_time:12216ms step_avg:92.55ms +step:133/1670 train_time:12307ms step_avg:92.54ms +step:134/1670 train_time:12397ms step_avg:92.51ms +step:135/1670 train_time:12487ms step_avg:92.50ms +step:136/1670 train_time:12578ms step_avg:92.49ms +step:137/1670 train_time:12670ms step_avg:92.48ms +step:138/1670 train_time:12763ms step_avg:92.48ms +step:139/1670 train_time:12857ms step_avg:92.50ms +step:140/1670 train_time:12950ms step_avg:92.50ms +step:141/1670 train_time:13041ms step_avg:92.49ms +step:142/1670 train_time:13132ms step_avg:92.48ms +step:143/1670 train_time:13224ms step_avg:92.47ms +step:144/1670 train_time:13314ms step_avg:92.46ms +step:145/1670 train_time:13404ms step_avg:92.44ms +step:146/1670 train_time:13496ms step_avg:92.44ms +step:147/1670 train_time:13586ms step_avg:92.42ms +step:148/1670 train_time:13677ms step_avg:92.41ms +step:149/1670 train_time:13769ms step_avg:92.41ms +step:150/1670 train_time:13861ms step_avg:92.41ms +step:151/1670 train_time:13955ms step_avg:92.42ms +step:152/1670 train_time:14048ms step_avg:92.42ms +step:153/1670 train_time:14138ms step_avg:92.41ms +step:154/1670 train_time:14229ms step_avg:92.40ms +step:155/1670 train_time:14320ms step_avg:92.39ms +step:156/1670 train_time:14411ms step_avg:92.38ms +step:157/1670 train_time:14501ms step_avg:92.36ms +step:158/1670 train_time:14591ms step_avg:92.35ms +step:159/1670 train_time:14682ms step_avg:92.34ms +step:160/1670 train_time:14774ms step_avg:92.34ms +step:161/1670 train_time:14866ms step_avg:92.33ms +step:162/1670 train_time:14957ms step_avg:92.33ms +step:163/1670 train_time:15050ms 
step_avg:92.33ms +step:164/1670 train_time:15142ms step_avg:92.33ms +step:165/1670 train_time:15234ms step_avg:92.32ms +step:166/1670 train_time:15324ms step_avg:92.31ms +step:167/1670 train_time:15414ms step_avg:92.30ms +step:168/1670 train_time:15506ms step_avg:92.30ms +step:169/1670 train_time:15598ms step_avg:92.29ms +step:170/1670 train_time:15688ms step_avg:92.28ms +step:171/1670 train_time:15779ms step_avg:92.28ms +step:172/1670 train_time:15872ms step_avg:92.28ms +step:173/1670 train_time:15963ms step_avg:92.27ms +step:174/1670 train_time:16055ms step_avg:92.27ms +step:175/1670 train_time:16148ms step_avg:92.27ms +step:176/1670 train_time:16239ms step_avg:92.26ms +step:177/1670 train_time:16329ms step_avg:92.26ms +step:178/1670 train_time:16420ms step_avg:92.25ms +step:179/1670 train_time:16512ms step_avg:92.25ms +step:180/1670 train_time:16603ms step_avg:92.24ms +step:181/1670 train_time:16693ms step_avg:92.23ms +step:182/1670 train_time:16784ms step_avg:92.22ms +step:183/1670 train_time:16875ms step_avg:92.22ms +step:184/1670 train_time:16967ms step_avg:92.21ms +step:185/1670 train_time:17058ms step_avg:92.21ms +step:186/1670 train_time:17152ms step_avg:92.21ms +step:187/1670 train_time:17243ms step_avg:92.21ms +step:188/1670 train_time:17334ms step_avg:92.20ms +step:189/1670 train_time:17425ms step_avg:92.20ms +step:190/1670 train_time:17516ms step_avg:92.19ms +step:191/1670 train_time:17608ms step_avg:92.19ms +step:192/1670 train_time:17698ms step_avg:92.18ms +step:193/1670 train_time:17789ms step_avg:92.17ms +step:194/1670 train_time:17881ms step_avg:92.17ms +step:195/1670 train_time:17972ms step_avg:92.16ms +step:196/1670 train_time:18063ms step_avg:92.16ms +step:197/1670 train_time:18155ms step_avg:92.16ms +step:198/1670 train_time:18246ms step_avg:92.15ms +step:199/1670 train_time:18337ms step_avg:92.15ms +step:200/1670 train_time:18429ms step_avg:92.14ms +step:201/1670 train_time:18519ms step_avg:92.13ms +step:202/1670 train_time:18611ms step_avg:92.13ms +step:203/1670 train_time:18701ms step_avg:92.12ms +step:204/1670 train_time:18792ms step_avg:92.12ms +step:205/1670 train_time:18883ms step_avg:92.11ms +step:206/1670 train_time:18974ms step_avg:92.11ms +step:207/1670 train_time:19065ms step_avg:92.10ms +step:208/1670 train_time:19156ms step_avg:92.10ms +step:209/1670 train_time:19249ms step_avg:92.10ms +step:210/1670 train_time:19339ms step_avg:92.09ms +step:211/1670 train_time:19431ms step_avg:92.09ms +step:212/1670 train_time:19522ms step_avg:92.09ms +step:213/1670 train_time:19773ms step_avg:92.83ms +step:214/1670 train_time:19842ms step_avg:92.72ms +step:215/1670 train_time:19932ms step_avg:92.71ms +step:216/1670 train_time:20022ms step_avg:92.69ms +step:217/1670 train_time:20112ms step_avg:92.68ms +step:218/1670 train_time:20202ms step_avg:92.67ms +step:219/1670 train_time:20292ms step_avg:92.66ms +step:220/1670 train_time:20382ms step_avg:92.64ms +step:221/1670 train_time:20471ms step_avg:92.63ms +step:222/1670 train_time:20561ms step_avg:92.62ms +step:223/1670 train_time:20654ms step_avg:92.62ms +step:224/1670 train_time:20751ms step_avg:92.64ms +step:225/1670 train_time:20846ms step_avg:92.65ms +step:226/1670 train_time:20938ms step_avg:92.65ms +step:227/1670 train_time:21029ms step_avg:92.64ms +step:228/1670 train_time:21119ms step_avg:92.63ms +step:229/1670 train_time:21209ms step_avg:92.62ms +step:230/1670 train_time:21299ms step_avg:92.61ms +step:231/1670 train_time:21389ms step_avg:92.59ms +step:232/1670 train_time:21480ms step_avg:92.58ms +step:233/1670 
train_time:21570ms step_avg:92.58ms +step:234/1670 train_time:21663ms step_avg:92.58ms +step:235/1670 train_time:21757ms step_avg:92.58ms +step:236/1670 train_time:21850ms step_avg:92.59ms +step:237/1670 train_time:21941ms step_avg:92.58ms +step:238/1670 train_time:22034ms step_avg:92.58ms +step:239/1670 train_time:22124ms step_avg:92.57ms +step:240/1670 train_time:22215ms step_avg:92.56ms +step:241/1670 train_time:22305ms step_avg:92.55ms +step:242/1670 train_time:22395ms step_avg:92.54ms +step:243/1670 train_time:22485ms step_avg:92.53ms +step:244/1670 train_time:22576ms step_avg:92.52ms +step:245/1670 train_time:22667ms step_avg:92.52ms +step:246/1670 train_time:22759ms step_avg:92.52ms +step:247/1670 train_time:22853ms step_avg:92.52ms +step:248/1670 train_time:22945ms step_avg:92.52ms +step:249/1670 train_time:23037ms step_avg:92.52ms +step:250/1670 train_time:23129ms step_avg:92.51ms +step:250/1670 val_loss:3.9614 train_time:23218ms step_avg:92.87ms +step:251/1670 train_time:23238ms step_avg:92.58ms +step:252/1670 train_time:23312ms step_avg:92.51ms +step:253/1670 train_time:23404ms step_avg:92.50ms +step:254/1670 train_time:23494ms step_avg:92.50ms +step:255/1670 train_time:23585ms step_avg:92.49ms +step:256/1670 train_time:23676ms step_avg:92.48ms +step:257/1670 train_time:23766ms step_avg:92.47ms +step:258/1670 train_time:23856ms step_avg:92.47ms +step:259/1670 train_time:23946ms step_avg:92.46ms +step:260/1670 train_time:24037ms step_avg:92.45ms +step:261/1670 train_time:24130ms step_avg:92.45ms +step:262/1670 train_time:24223ms step_avg:92.46ms +step:263/1670 train_time:24316ms step_avg:92.46ms +step:264/1670 train_time:24408ms step_avg:92.46ms +step:265/1670 train_time:24500ms step_avg:92.45ms +step:266/1670 train_time:24590ms step_avg:92.44ms +step:267/1670 train_time:24681ms step_avg:92.44ms +step:268/1670 train_time:24771ms step_avg:92.43ms +step:269/1670 train_time:24862ms step_avg:92.42ms +step:270/1670 train_time:24952ms step_avg:92.42ms +step:271/1670 train_time:25043ms step_avg:92.41ms +step:272/1670 train_time:25134ms step_avg:92.41ms +step:273/1670 train_time:25227ms step_avg:92.41ms +step:274/1670 train_time:25320ms step_avg:92.41ms +step:275/1670 train_time:25411ms step_avg:92.40ms +step:276/1670 train_time:25502ms step_avg:92.40ms +step:277/1670 train_time:25593ms step_avg:92.39ms +step:278/1670 train_time:25684ms step_avg:92.39ms +step:279/1670 train_time:25775ms step_avg:92.38ms +step:280/1670 train_time:25866ms step_avg:92.38ms +step:281/1670 train_time:25956ms step_avg:92.37ms +step:282/1670 train_time:26047ms step_avg:92.37ms +step:283/1670 train_time:26139ms step_avg:92.37ms +step:284/1670 train_time:26231ms step_avg:92.36ms +step:285/1670 train_time:26323ms step_avg:92.36ms +step:286/1670 train_time:26414ms step_avg:92.36ms +step:287/1670 train_time:26506ms step_avg:92.35ms +step:288/1670 train_time:26597ms step_avg:92.35ms +step:289/1670 train_time:26688ms step_avg:92.35ms +step:290/1670 train_time:26778ms step_avg:92.34ms +step:291/1670 train_time:26869ms step_avg:92.33ms +step:292/1670 train_time:26960ms step_avg:92.33ms +step:293/1670 train_time:27050ms step_avg:92.32ms +step:294/1670 train_time:27142ms step_avg:92.32ms +step:295/1670 train_time:27233ms step_avg:92.32ms +step:296/1670 train_time:27325ms step_avg:92.32ms +step:297/1670 train_time:27416ms step_avg:92.31ms +step:298/1670 train_time:27510ms step_avg:92.32ms +step:299/1670 train_time:27601ms step_avg:92.31ms +step:300/1670 train_time:27692ms step_avg:92.31ms +step:301/1670 train_time:27784ms 
step_avg:92.30ms +step:302/1670 train_time:27874ms step_avg:92.30ms +step:303/1670 train_time:27965ms step_avg:92.29ms +step:304/1670 train_time:28056ms step_avg:92.29ms +step:305/1670 train_time:28147ms step_avg:92.28ms +step:306/1670 train_time:28238ms step_avg:92.28ms +step:307/1670 train_time:28329ms step_avg:92.28ms +step:308/1670 train_time:28422ms step_avg:92.28ms +step:309/1670 train_time:28513ms step_avg:92.27ms +step:310/1670 train_time:28605ms step_avg:92.27ms +step:311/1670 train_time:28696ms step_avg:92.27ms +step:312/1670 train_time:28788ms step_avg:92.27ms +step:313/1670 train_time:28879ms step_avg:92.27ms +step:314/1670 train_time:28970ms step_avg:92.26ms +step:315/1670 train_time:29061ms step_avg:92.26ms +step:316/1670 train_time:29152ms step_avg:92.25ms +step:317/1670 train_time:29242ms step_avg:92.25ms +step:318/1670 train_time:29334ms step_avg:92.25ms +step:319/1670 train_time:29426ms step_avg:92.24ms +step:320/1670 train_time:29517ms step_avg:92.24ms +step:321/1670 train_time:29609ms step_avg:92.24ms +step:322/1670 train_time:29701ms step_avg:92.24ms +step:323/1670 train_time:29792ms step_avg:92.23ms +step:324/1670 train_time:29883ms step_avg:92.23ms +step:325/1670 train_time:29974ms step_avg:92.23ms +step:326/1670 train_time:30066ms step_avg:92.23ms +step:327/1670 train_time:30156ms step_avg:92.22ms +step:328/1670 train_time:30248ms step_avg:92.22ms +step:329/1670 train_time:30339ms step_avg:92.22ms +step:330/1670 train_time:30430ms step_avg:92.21ms +step:331/1670 train_time:30522ms step_avg:92.21ms +step:332/1670 train_time:30614ms step_avg:92.21ms +step:333/1670 train_time:30706ms step_avg:92.21ms +step:334/1670 train_time:30797ms step_avg:92.21ms +step:335/1670 train_time:30889ms step_avg:92.21ms +step:336/1670 train_time:30981ms step_avg:92.20ms +step:337/1670 train_time:31071ms step_avg:92.20ms +step:338/1670 train_time:31163ms step_avg:92.20ms +step:339/1670 train_time:31254ms step_avg:92.19ms +step:340/1670 train_time:31345ms step_avg:92.19ms +step:341/1670 train_time:31436ms step_avg:92.19ms +step:342/1670 train_time:31527ms step_avg:92.18ms +step:343/1670 train_time:31618ms step_avg:92.18ms +step:344/1670 train_time:31710ms step_avg:92.18ms +step:345/1670 train_time:31802ms step_avg:92.18ms +step:346/1670 train_time:31892ms step_avg:92.17ms +step:347/1670 train_time:31984ms step_avg:92.17ms +step:348/1670 train_time:32075ms step_avg:92.17ms +step:349/1670 train_time:32166ms step_avg:92.17ms +step:350/1670 train_time:32257ms step_avg:92.16ms +step:351/1670 train_time:32350ms step_avg:92.17ms +step:352/1670 train_time:32441ms step_avg:92.16ms +step:353/1670 train_time:32532ms step_avg:92.16ms +step:354/1670 train_time:32624ms step_avg:92.16ms +step:355/1670 train_time:32715ms step_avg:92.16ms +step:356/1670 train_time:32808ms step_avg:92.16ms +step:357/1670 train_time:32898ms step_avg:92.15ms +step:358/1670 train_time:32989ms step_avg:92.15ms +step:359/1670 train_time:33081ms step_avg:92.15ms +step:360/1670 train_time:33172ms step_avg:92.14ms +step:361/1670 train_time:33263ms step_avg:92.14ms +step:362/1670 train_time:33354ms step_avg:92.14ms +step:363/1670 train_time:33445ms step_avg:92.13ms +step:364/1670 train_time:33535ms step_avg:92.13ms +step:365/1670 train_time:33628ms step_avg:92.13ms +step:366/1670 train_time:33721ms step_avg:92.13ms +step:367/1670 train_time:33812ms step_avg:92.13ms +step:368/1670 train_time:33903ms step_avg:92.13ms +step:369/1670 train_time:33994ms step_avg:92.12ms +step:370/1670 train_time:34086ms step_avg:92.12ms +step:371/1670 
train_time:34177ms step_avg:92.12ms +step:372/1670 train_time:34268ms step_avg:92.12ms +step:373/1670 train_time:34358ms step_avg:92.11ms +step:374/1670 train_time:34449ms step_avg:92.11ms +step:375/1670 train_time:34540ms step_avg:92.11ms +step:375/1670 val_loss:3.8131 train_time:34630ms step_avg:92.35ms +step:376/1670 train_time:34650ms step_avg:92.15ms +step:377/1670 train_time:34723ms step_avg:92.10ms +step:378/1670 train_time:34815ms step_avg:92.10ms +step:379/1670 train_time:34906ms step_avg:92.10ms +step:380/1670 train_time:34997ms step_avg:92.10ms +step:381/1670 train_time:35089ms step_avg:92.10ms +step:382/1670 train_time:35179ms step_avg:92.09ms +step:383/1670 train_time:35269ms step_avg:92.09ms +step:384/1670 train_time:35360ms step_avg:92.08ms +step:385/1670 train_time:35452ms step_avg:92.08ms +step:386/1670 train_time:35543ms step_avg:92.08ms +step:387/1670 train_time:35635ms step_avg:92.08ms +step:388/1670 train_time:35728ms step_avg:92.08ms +step:389/1670 train_time:35819ms step_avg:92.08ms +step:390/1670 train_time:35910ms step_avg:92.08ms +step:391/1670 train_time:36000ms step_avg:92.07ms +step:392/1670 train_time:36092ms step_avg:92.07ms +step:393/1670 train_time:36182ms step_avg:92.07ms +step:394/1670 train_time:36273ms step_avg:92.06ms +step:395/1670 train_time:36364ms step_avg:92.06ms +step:396/1670 train_time:36455ms step_avg:92.06ms +step:397/1670 train_time:36547ms step_avg:92.06ms +step:398/1670 train_time:36640ms step_avg:92.06ms +step:399/1670 train_time:36732ms step_avg:92.06ms +step:400/1670 train_time:36823ms step_avg:92.06ms +step:401/1670 train_time:36914ms step_avg:92.06ms +step:402/1670 train_time:37005ms step_avg:92.05ms +step:403/1670 train_time:37097ms step_avg:92.05ms +step:404/1670 train_time:37188ms step_avg:92.05ms +step:405/1670 train_time:37279ms step_avg:92.05ms +step:406/1670 train_time:37371ms step_avg:92.05ms +step:407/1670 train_time:37461ms step_avg:92.04ms +step:408/1670 train_time:37553ms step_avg:92.04ms +step:409/1670 train_time:37644ms step_avg:92.04ms +step:410/1670 train_time:37736ms step_avg:92.04ms +step:411/1670 train_time:37827ms step_avg:92.04ms +step:412/1670 train_time:37918ms step_avg:92.03ms +step:413/1670 train_time:38008ms step_avg:92.03ms +step:414/1670 train_time:38100ms step_avg:92.03ms +step:415/1670 train_time:38192ms step_avg:92.03ms +step:416/1670 train_time:38282ms step_avg:92.02ms +step:417/1670 train_time:38375ms step_avg:92.03ms +step:418/1670 train_time:38466ms step_avg:92.02ms +step:419/1670 train_time:38558ms step_avg:92.02ms +step:420/1670 train_time:38650ms step_avg:92.02ms +step:421/1670 train_time:38741ms step_avg:92.02ms +step:422/1670 train_time:38832ms step_avg:92.02ms +step:423/1670 train_time:38922ms step_avg:92.01ms +step:424/1670 train_time:39013ms step_avg:92.01ms +step:425/1670 train_time:39270ms step_avg:92.40ms +step:426/1670 train_time:39338ms step_avg:92.34ms +step:427/1670 train_time:39428ms step_avg:92.34ms +step:428/1670 train_time:39518ms step_avg:92.33ms +step:429/1670 train_time:39608ms step_avg:92.33ms +step:430/1670 train_time:39698ms step_avg:92.32ms +step:431/1670 train_time:39788ms step_avg:92.32ms +step:432/1670 train_time:39877ms step_avg:92.31ms +step:433/1670 train_time:39968ms step_avg:92.31ms +step:434/1670 train_time:40059ms step_avg:92.30ms +step:435/1670 train_time:40153ms step_avg:92.31ms +step:436/1670 train_time:40248ms step_avg:92.31ms +step:437/1670 train_time:40340ms step_avg:92.31ms +step:438/1670 train_time:40431ms step_avg:92.31ms +step:439/1670 train_time:40522ms 
step_avg:92.31ms +step:440/1670 train_time:40613ms step_avg:92.30ms +step:441/1670 train_time:40704ms step_avg:92.30ms +step:442/1670 train_time:40795ms step_avg:92.30ms +step:443/1670 train_time:40885ms step_avg:92.29ms +step:444/1670 train_time:40976ms step_avg:92.29ms +step:445/1670 train_time:41067ms step_avg:92.28ms +step:446/1670 train_time:41159ms step_avg:92.29ms +step:447/1670 train_time:41253ms step_avg:92.29ms +step:448/1670 train_time:41344ms step_avg:92.29ms +step:449/1670 train_time:41436ms step_avg:92.29ms +step:450/1670 train_time:41528ms step_avg:92.28ms +step:451/1670 train_time:41618ms step_avg:92.28ms +step:452/1670 train_time:41709ms step_avg:92.28ms +step:453/1670 train_time:41799ms step_avg:92.27ms +step:454/1670 train_time:41889ms step_avg:92.27ms +step:455/1670 train_time:41980ms step_avg:92.26ms +step:456/1670 train_time:42072ms step_avg:92.26ms +step:457/1670 train_time:42164ms step_avg:92.26ms +step:458/1670 train_time:42258ms step_avg:92.27ms +step:459/1670 train_time:42350ms step_avg:92.27ms +step:460/1670 train_time:42441ms step_avg:92.26ms +step:461/1670 train_time:42533ms step_avg:92.26ms +step:462/1670 train_time:42624ms step_avg:92.26ms +step:463/1670 train_time:42715ms step_avg:92.26ms +step:464/1670 train_time:42806ms step_avg:92.25ms +step:465/1670 train_time:42896ms step_avg:92.25ms +step:466/1670 train_time:42987ms step_avg:92.25ms +step:467/1670 train_time:43080ms step_avg:92.25ms +step:468/1670 train_time:43172ms step_avg:92.25ms +step:469/1670 train_time:43262ms step_avg:92.24ms +step:470/1670 train_time:43355ms step_avg:92.24ms +step:471/1670 train_time:43446ms step_avg:92.24ms +step:472/1670 train_time:43538ms step_avg:92.24ms +step:473/1670 train_time:43630ms step_avg:92.24ms +step:474/1670 train_time:43720ms step_avg:92.24ms +step:475/1670 train_time:43811ms step_avg:92.23ms +step:476/1670 train_time:43902ms step_avg:92.23ms +step:477/1670 train_time:43994ms step_avg:92.23ms +step:478/1670 train_time:44084ms step_avg:92.23ms +step:479/1670 train_time:44176ms step_avg:92.23ms +step:480/1670 train_time:44267ms step_avg:92.22ms +step:481/1670 train_time:44359ms step_avg:92.22ms +step:482/1670 train_time:44451ms step_avg:92.22ms +step:483/1670 train_time:44542ms step_avg:92.22ms +step:484/1670 train_time:44634ms step_avg:92.22ms +step:485/1670 train_time:44725ms step_avg:92.22ms +step:486/1670 train_time:44817ms step_avg:92.22ms +step:487/1670 train_time:44908ms step_avg:92.21ms +step:488/1670 train_time:45000ms step_avg:92.21ms +step:489/1670 train_time:45091ms step_avg:92.21ms +step:490/1670 train_time:45182ms step_avg:92.21ms +step:491/1670 train_time:45274ms step_avg:92.21ms +step:492/1670 train_time:45365ms step_avg:92.20ms +step:493/1670 train_time:45457ms step_avg:92.20ms +step:494/1670 train_time:45548ms step_avg:92.20ms +step:495/1670 train_time:45639ms step_avg:92.20ms +step:496/1670 train_time:45731ms step_avg:92.20ms +step:497/1670 train_time:45821ms step_avg:92.20ms +step:498/1670 train_time:45912ms step_avg:92.19ms +step:499/1670 train_time:46003ms step_avg:92.19ms +step:500/1670 train_time:46096ms step_avg:92.19ms +step:500/1670 val_loss:3.7135 train_time:46186ms step_avg:92.37ms +step:501/1670 train_time:46206ms step_avg:92.23ms +step:502/1670 train_time:46279ms step_avg:92.19ms +step:503/1670 train_time:46371ms step_avg:92.19ms +step:504/1670 train_time:46462ms step_avg:92.19ms +step:505/1670 train_time:46552ms step_avg:92.18ms +step:506/1670 train_time:46642ms step_avg:92.18ms +step:507/1670 train_time:46732ms step_avg:92.17ms 
+step:508/1670 train_time:46824ms step_avg:92.17ms +step:509/1670 train_time:46914ms step_avg:92.17ms +step:510/1670 train_time:47006ms step_avg:92.17ms +step:511/1670 train_time:47099ms step_avg:92.17ms +step:512/1670 train_time:47192ms step_avg:92.17ms +step:513/1670 train_time:47284ms step_avg:92.17ms +step:514/1670 train_time:47376ms step_avg:92.17ms +step:515/1670 train_time:47468ms step_avg:92.17ms +step:516/1670 train_time:47559ms step_avg:92.17ms +step:517/1670 train_time:47649ms step_avg:92.16ms +step:518/1670 train_time:47740ms step_avg:92.16ms +step:519/1670 train_time:47831ms step_avg:92.16ms +step:520/1670 train_time:47922ms step_avg:92.16ms +step:521/1670 train_time:48012ms step_avg:92.15ms +step:522/1670 train_time:48104ms step_avg:92.15ms +step:523/1670 train_time:48196ms step_avg:92.15ms +step:524/1670 train_time:48287ms step_avg:92.15ms +step:525/1670 train_time:48378ms step_avg:92.15ms +step:526/1670 train_time:48470ms step_avg:92.15ms +step:527/1670 train_time:48562ms step_avg:92.15ms +step:528/1670 train_time:48652ms step_avg:92.14ms +step:529/1670 train_time:48743ms step_avg:92.14ms +step:530/1670 train_time:48834ms step_avg:92.14ms +step:531/1670 train_time:48926ms step_avg:92.14ms +step:532/1670 train_time:49017ms step_avg:92.14ms +step:533/1670 train_time:49108ms step_avg:92.14ms +step:534/1670 train_time:49200ms step_avg:92.13ms +step:535/1670 train_time:49291ms step_avg:92.13ms +step:536/1670 train_time:49381ms step_avg:92.13ms +step:537/1670 train_time:49472ms step_avg:92.13ms +step:538/1670 train_time:49564ms step_avg:92.13ms +step:539/1670 train_time:49655ms step_avg:92.12ms +step:540/1670 train_time:49746ms step_avg:92.12ms +step:541/1670 train_time:49838ms step_avg:92.12ms +step:542/1670 train_time:49930ms step_avg:92.12ms +step:543/1670 train_time:50021ms step_avg:92.12ms +step:544/1670 train_time:50112ms step_avg:92.12ms +step:545/1670 train_time:50204ms step_avg:92.12ms +step:546/1670 train_time:50295ms step_avg:92.11ms +step:547/1670 train_time:50386ms step_avg:92.11ms +step:548/1670 train_time:50478ms step_avg:92.11ms +step:549/1670 train_time:50570ms step_avg:92.11ms +step:550/1670 train_time:50661ms step_avg:92.11ms +step:551/1670 train_time:50753ms step_avg:92.11ms +step:552/1670 train_time:50845ms step_avg:92.11ms +step:553/1670 train_time:50936ms step_avg:92.11ms +step:554/1670 train_time:51029ms step_avg:92.11ms +step:555/1670 train_time:51121ms step_avg:92.11ms +step:556/1670 train_time:51212ms step_avg:92.11ms +step:557/1670 train_time:51303ms step_avg:92.11ms +step:558/1670 train_time:51577ms step_avg:92.43ms +step:559/1670 train_time:51658ms step_avg:92.41ms +step:560/1670 train_time:51748ms step_avg:92.41ms +step:561/1670 train_time:51840ms step_avg:92.41ms +step:562/1670 train_time:51932ms step_avg:92.40ms +step:563/1670 train_time:52023ms step_avg:92.40ms +step:564/1670 train_time:52114ms step_avg:92.40ms +step:565/1670 train_time:52205ms step_avg:92.40ms +step:566/1670 train_time:52296ms step_avg:92.40ms +step:567/1670 train_time:52389ms step_avg:92.40ms +step:568/1670 train_time:52483ms step_avg:92.40ms +step:569/1670 train_time:52580ms step_avg:92.41ms +step:570/1670 train_time:52673ms step_avg:92.41ms +step:571/1670 train_time:52767ms step_avg:92.41ms +step:572/1670 train_time:52859ms step_avg:92.41ms +step:573/1670 train_time:52951ms step_avg:92.41ms +step:574/1670 train_time:53042ms step_avg:92.41ms +step:575/1670 train_time:53134ms step_avg:92.41ms +step:576/1670 train_time:53226ms step_avg:92.41ms +step:577/1670 train_time:53317ms 
step_avg:92.40ms +step:578/1670 train_time:53410ms step_avg:92.40ms +step:579/1670 train_time:53506ms step_avg:92.41ms +step:580/1670 train_time:53600ms step_avg:92.41ms +step:581/1670 train_time:53693ms step_avg:92.42ms +step:582/1670 train_time:53789ms step_avg:92.42ms +step:583/1670 train_time:53881ms step_avg:92.42ms +step:584/1670 train_time:53973ms step_avg:92.42ms +step:585/1670 train_time:54066ms step_avg:92.42ms +step:586/1670 train_time:54157ms step_avg:92.42ms +step:587/1670 train_time:54249ms step_avg:92.42ms +step:588/1670 train_time:54341ms step_avg:92.42ms +step:589/1670 train_time:54433ms step_avg:92.42ms +step:590/1670 train_time:54527ms step_avg:92.42ms +step:591/1670 train_time:54620ms step_avg:92.42ms +step:592/1670 train_time:54713ms step_avg:92.42ms +step:593/1670 train_time:54806ms step_avg:92.42ms +step:594/1670 train_time:54899ms step_avg:92.42ms +step:595/1670 train_time:54992ms step_avg:92.42ms +step:596/1670 train_time:55085ms step_avg:92.42ms +step:597/1670 train_time:55176ms step_avg:92.42ms +step:598/1670 train_time:55268ms step_avg:92.42ms +step:599/1670 train_time:55360ms step_avg:92.42ms +step:600/1670 train_time:55454ms step_avg:92.42ms +step:601/1670 train_time:55546ms step_avg:92.42ms +step:602/1670 train_time:55639ms step_avg:92.42ms +step:603/1670 train_time:55733ms step_avg:92.43ms +step:604/1670 train_time:55826ms step_avg:92.43ms +step:605/1670 train_time:55918ms step_avg:92.43ms +step:606/1670 train_time:56009ms step_avg:92.42ms +step:607/1670 train_time:56101ms step_avg:92.42ms +step:608/1670 train_time:56194ms step_avg:92.42ms +step:609/1670 train_time:56285ms step_avg:92.42ms +step:610/1670 train_time:56378ms step_avg:92.42ms +step:611/1670 train_time:56471ms step_avg:92.42ms +step:612/1670 train_time:56564ms step_avg:92.42ms +step:613/1670 train_time:56656ms step_avg:92.42ms +step:614/1670 train_time:56750ms step_avg:92.43ms +step:615/1670 train_time:56842ms step_avg:92.43ms +step:616/1670 train_time:56935ms step_avg:92.43ms +step:617/1670 train_time:57027ms step_avg:92.43ms +step:618/1670 train_time:57119ms step_avg:92.43ms +step:619/1670 train_time:57212ms step_avg:92.43ms +step:620/1670 train_time:57305ms step_avg:92.43ms +step:621/1670 train_time:57397ms step_avg:92.43ms +step:622/1670 train_time:57490ms step_avg:92.43ms +step:623/1670 train_time:57582ms step_avg:92.43ms +step:624/1670 train_time:57675ms step_avg:92.43ms +step:625/1670 train_time:57769ms step_avg:92.43ms +step:625/1670 val_loss:3.6141 train_time:57861ms step_avg:92.58ms +step:626/1670 train_time:57881ms step_avg:92.46ms +step:627/1670 train_time:57959ms step_avg:92.44ms +step:628/1670 train_time:58064ms step_avg:92.46ms +step:629/1670 train_time:58159ms step_avg:92.46ms +step:630/1670 train_time:58251ms step_avg:92.46ms +step:631/1670 train_time:58341ms step_avg:92.46ms +step:632/1670 train_time:58432ms step_avg:92.46ms +step:633/1670 train_time:58524ms step_avg:92.45ms +step:634/1670 train_time:58615ms step_avg:92.45ms +step:635/1670 train_time:58707ms step_avg:92.45ms +step:636/1670 train_time:58798ms step_avg:92.45ms +step:637/1670 train_time:58889ms step_avg:92.45ms +step:638/1670 train_time:58984ms step_avg:92.45ms +step:639/1670 train_time:59222ms step_avg:92.68ms +step:640/1670 train_time:59292ms step_avg:92.64ms +step:641/1670 train_time:59383ms step_avg:92.64ms +step:642/1670 train_time:59474ms step_avg:92.64ms +step:643/1670 train_time:59565ms step_avg:92.64ms +step:644/1670 train_time:59656ms step_avg:92.63ms +step:645/1670 train_time:59748ms step_avg:92.63ms 
+step:646/1670 train_time:59838ms step_avg:92.63ms +step:647/1670 train_time:59929ms step_avg:92.63ms +step:648/1670 train_time:60021ms step_avg:92.62ms +step:649/1670 train_time:60117ms step_avg:92.63ms +step:650/1670 train_time:60215ms step_avg:92.64ms +step:651/1670 train_time:60308ms step_avg:92.64ms +step:652/1670 train_time:60400ms step_avg:92.64ms +step:653/1670 train_time:60493ms step_avg:92.64ms +step:654/1670 train_time:60585ms step_avg:92.64ms +step:655/1670 train_time:60677ms step_avg:92.64ms +step:656/1670 train_time:60770ms step_avg:92.64ms +step:657/1670 train_time:60861ms step_avg:92.63ms +step:658/1670 train_time:60952ms step_avg:92.63ms +step:659/1670 train_time:61044ms step_avg:92.63ms +step:660/1670 train_time:61138ms step_avg:92.63ms +step:661/1670 train_time:61232ms step_avg:92.64ms +step:662/1670 train_time:61325ms step_avg:92.64ms +step:663/1670 train_time:61418ms step_avg:92.64ms +step:664/1670 train_time:61511ms step_avg:92.64ms +step:665/1670 train_time:61603ms step_avg:92.64ms +step:666/1670 train_time:61695ms step_avg:92.64ms +step:667/1670 train_time:61788ms step_avg:92.64ms +step:668/1670 train_time:61879ms step_avg:92.63ms +step:669/1670 train_time:61971ms step_avg:92.63ms +step:670/1670 train_time:62064ms step_avg:92.63ms +step:671/1670 train_time:62158ms step_avg:92.64ms +step:672/1670 train_time:62253ms step_avg:92.64ms +step:673/1670 train_time:62346ms step_avg:92.64ms +step:674/1670 train_time:62439ms step_avg:92.64ms +step:675/1670 train_time:62531ms step_avg:92.64ms +step:676/1670 train_time:62624ms step_avg:92.64ms +step:677/1670 train_time:62717ms step_avg:92.64ms +step:678/1670 train_time:62809ms step_avg:92.64ms +step:679/1670 train_time:62901ms step_avg:92.64ms +step:680/1670 train_time:62993ms step_avg:92.64ms +step:681/1670 train_time:63085ms step_avg:92.64ms +step:682/1670 train_time:63179ms step_avg:92.64ms +step:683/1670 train_time:63273ms step_avg:92.64ms +step:684/1670 train_time:63365ms step_avg:92.64ms +step:685/1670 train_time:63458ms step_avg:92.64ms +step:686/1670 train_time:63550ms step_avg:92.64ms +step:687/1670 train_time:63642ms step_avg:92.64ms +step:688/1670 train_time:63735ms step_avg:92.64ms +step:689/1670 train_time:63826ms step_avg:92.64ms +step:690/1670 train_time:63919ms step_avg:92.64ms +step:691/1670 train_time:64011ms step_avg:92.64ms +step:692/1670 train_time:64104ms step_avg:92.64ms +step:693/1670 train_time:64198ms step_avg:92.64ms +step:694/1670 train_time:64291ms step_avg:92.64ms +step:695/1670 train_time:64383ms step_avg:92.64ms +step:696/1670 train_time:64477ms step_avg:92.64ms +step:697/1670 train_time:64569ms step_avg:92.64ms +step:698/1670 train_time:64661ms step_avg:92.64ms +step:699/1670 train_time:64753ms step_avg:92.64ms +step:700/1670 train_time:64846ms step_avg:92.64ms +step:701/1670 train_time:64939ms step_avg:92.64ms +step:702/1670 train_time:65032ms step_avg:92.64ms +step:703/1670 train_time:65125ms step_avg:92.64ms +step:704/1670 train_time:65218ms step_avg:92.64ms +step:705/1670 train_time:65310ms step_avg:92.64ms +step:706/1670 train_time:65402ms step_avg:92.64ms +step:707/1670 train_time:65496ms step_avg:92.64ms +step:708/1670 train_time:65588ms step_avg:92.64ms +step:709/1670 train_time:65681ms step_avg:92.64ms +step:710/1670 train_time:65774ms step_avg:92.64ms +step:711/1670 train_time:65866ms step_avg:92.64ms +step:712/1670 train_time:65959ms step_avg:92.64ms +step:713/1670 train_time:66052ms step_avg:92.64ms +step:714/1670 train_time:66145ms step_avg:92.64ms +step:715/1670 train_time:66238ms 
step_avg:92.64ms +step:716/1670 train_time:66330ms step_avg:92.64ms +step:717/1670 train_time:66423ms step_avg:92.64ms +step:718/1670 train_time:66515ms step_avg:92.64ms +step:719/1670 train_time:66607ms step_avg:92.64ms +step:720/1670 train_time:66701ms step_avg:92.64ms +step:721/1670 train_time:66794ms step_avg:92.64ms +step:722/1670 train_time:66886ms step_avg:92.64ms +step:723/1670 train_time:66979ms step_avg:92.64ms +step:724/1670 train_time:67072ms step_avg:92.64ms +step:725/1670 train_time:67164ms step_avg:92.64ms +step:726/1670 train_time:67257ms step_avg:92.64ms +step:727/1670 train_time:67349ms step_avg:92.64ms +step:728/1670 train_time:67441ms step_avg:92.64ms +step:729/1670 train_time:67533ms step_avg:92.64ms +step:730/1670 train_time:67625ms step_avg:92.64ms +step:731/1670 train_time:67718ms step_avg:92.64ms +step:732/1670 train_time:67811ms step_avg:92.64ms +step:733/1670 train_time:67902ms step_avg:92.64ms +step:734/1670 train_time:67996ms step_avg:92.64ms +step:735/1670 train_time:68090ms step_avg:92.64ms +step:736/1670 train_time:68182ms step_avg:92.64ms +step:737/1670 train_time:68275ms step_avg:92.64ms +step:738/1670 train_time:68368ms step_avg:92.64ms +step:739/1670 train_time:68461ms step_avg:92.64ms +step:740/1670 train_time:68552ms step_avg:92.64ms +step:741/1670 train_time:68645ms step_avg:92.64ms +step:742/1670 train_time:68738ms step_avg:92.64ms +step:743/1670 train_time:68831ms step_avg:92.64ms +step:744/1670 train_time:68924ms step_avg:92.64ms +step:745/1670 train_time:69017ms step_avg:92.64ms +step:746/1670 train_time:69109ms step_avg:92.64ms +step:747/1670 train_time:69201ms step_avg:92.64ms +step:748/1670 train_time:69294ms step_avg:92.64ms +step:749/1670 train_time:69386ms step_avg:92.64ms +step:750/1670 train_time:69479ms step_avg:92.64ms +step:750/1670 val_loss:3.5619 train_time:69571ms step_avg:92.76ms +step:751/1670 train_time:69591ms step_avg:92.66ms +step:752/1670 train_time:69665ms step_avg:92.64ms +step:753/1670 train_time:69757ms step_avg:92.64ms +step:754/1670 train_time:69849ms step_avg:92.64ms +step:755/1670 train_time:69940ms step_avg:92.64ms +step:756/1670 train_time:70033ms step_avg:92.64ms +step:757/1670 train_time:70126ms step_avg:92.64ms +step:758/1670 train_time:70218ms step_avg:92.64ms +step:759/1670 train_time:70310ms step_avg:92.64ms +step:760/1670 train_time:70403ms step_avg:92.64ms +step:761/1670 train_time:70497ms step_avg:92.64ms +step:762/1670 train_time:70590ms step_avg:92.64ms +step:763/1670 train_time:70683ms step_avg:92.64ms +step:764/1670 train_time:70776ms step_avg:92.64ms +step:765/1670 train_time:70867ms step_avg:92.64ms +step:766/1670 train_time:70959ms step_avg:92.64ms +step:767/1670 train_time:71052ms step_avg:92.64ms +step:768/1670 train_time:71144ms step_avg:92.64ms +step:769/1670 train_time:71237ms step_avg:92.64ms +step:770/1670 train_time:71330ms step_avg:92.64ms +step:771/1670 train_time:71423ms step_avg:92.64ms +step:772/1670 train_time:71516ms step_avg:92.64ms +step:773/1670 train_time:71610ms step_avg:92.64ms +step:774/1670 train_time:71703ms step_avg:92.64ms +step:775/1670 train_time:71795ms step_avg:92.64ms +step:776/1670 train_time:71887ms step_avg:92.64ms +step:777/1670 train_time:71979ms step_avg:92.64ms +step:778/1670 train_time:72071ms step_avg:92.64ms +step:779/1670 train_time:72163ms step_avg:92.64ms +step:780/1670 train_time:72256ms step_avg:92.64ms +step:781/1670 train_time:72349ms step_avg:92.64ms +step:782/1670 train_time:72441ms step_avg:92.64ms +step:783/1670 train_time:72536ms step_avg:92.64ms 
+step:784/1670 train_time:72629ms step_avg:92.64ms +step:785/1670 train_time:72721ms step_avg:92.64ms +step:786/1670 train_time:72814ms step_avg:92.64ms +step:787/1670 train_time:72906ms step_avg:92.64ms +step:788/1670 train_time:72998ms step_avg:92.64ms +step:789/1670 train_time:73091ms step_avg:92.64ms +step:790/1670 train_time:73183ms step_avg:92.64ms +step:791/1670 train_time:73276ms step_avg:92.64ms +step:792/1670 train_time:73368ms step_avg:92.64ms +step:793/1670 train_time:73460ms step_avg:92.64ms +step:794/1670 train_time:73554ms step_avg:92.64ms +step:795/1670 train_time:73648ms step_avg:92.64ms +step:796/1670 train_time:73740ms step_avg:92.64ms +step:797/1670 train_time:73834ms step_avg:92.64ms +step:798/1670 train_time:73926ms step_avg:92.64ms +step:799/1670 train_time:74018ms step_avg:92.64ms +step:800/1670 train_time:74111ms step_avg:92.64ms +step:801/1670 train_time:74203ms step_avg:92.64ms +step:802/1670 train_time:74295ms step_avg:92.64ms +step:803/1670 train_time:74388ms step_avg:92.64ms +step:804/1670 train_time:74481ms step_avg:92.64ms +step:805/1670 train_time:74573ms step_avg:92.64ms +step:806/1670 train_time:74666ms step_avg:92.64ms +step:807/1670 train_time:74759ms step_avg:92.64ms +step:808/1670 train_time:74853ms step_avg:92.64ms +step:809/1670 train_time:74945ms step_avg:92.64ms +step:810/1670 train_time:75037ms step_avg:92.64ms +step:811/1670 train_time:75130ms step_avg:92.64ms +step:812/1670 train_time:75223ms step_avg:92.64ms +step:813/1670 train_time:75317ms step_avg:92.64ms +step:814/1670 train_time:75409ms step_avg:92.64ms +step:815/1670 train_time:75502ms step_avg:92.64ms +step:816/1670 train_time:75595ms step_avg:92.64ms +step:817/1670 train_time:75687ms step_avg:92.64ms +step:818/1670 train_time:75781ms step_avg:92.64ms +step:819/1670 train_time:75874ms step_avg:92.64ms +step:820/1670 train_time:75965ms step_avg:92.64ms +step:821/1670 train_time:76059ms step_avg:92.64ms +step:822/1670 train_time:76153ms step_avg:92.64ms +step:823/1670 train_time:76246ms step_avg:92.64ms +step:824/1670 train_time:76339ms step_avg:92.64ms +step:825/1670 train_time:76430ms step_avg:92.64ms +step:826/1670 train_time:76523ms step_avg:92.64ms +step:827/1670 train_time:76616ms step_avg:92.64ms +step:828/1670 train_time:76708ms step_avg:92.64ms +step:829/1670 train_time:76800ms step_avg:92.64ms +step:830/1670 train_time:76893ms step_avg:92.64ms +step:831/1670 train_time:76986ms step_avg:92.64ms +step:832/1670 train_time:77078ms step_avg:92.64ms +step:833/1670 train_time:77171ms step_avg:92.64ms +step:834/1670 train_time:77262ms step_avg:92.64ms +step:835/1670 train_time:77356ms step_avg:92.64ms +step:836/1670 train_time:77449ms step_avg:92.64ms +step:837/1670 train_time:77541ms step_avg:92.64ms +step:838/1670 train_time:77634ms step_avg:92.64ms +step:839/1670 train_time:77727ms step_avg:92.64ms +step:840/1670 train_time:77818ms step_avg:92.64ms +step:841/1670 train_time:77911ms step_avg:92.64ms +step:842/1670 train_time:78004ms step_avg:92.64ms +step:843/1670 train_time:78096ms step_avg:92.64ms +step:844/1670 train_time:78189ms step_avg:92.64ms +step:845/1670 train_time:78281ms step_avg:92.64ms +step:846/1670 train_time:78374ms step_avg:92.64ms +step:847/1670 train_time:78466ms step_avg:92.64ms +step:848/1670 train_time:78558ms step_avg:92.64ms +step:849/1670 train_time:78651ms step_avg:92.64ms +step:850/1670 train_time:78743ms step_avg:92.64ms +step:851/1670 train_time:78995ms step_avg:92.83ms +step:852/1670 train_time:79066ms step_avg:92.80ms +step:853/1670 train_time:79157ms 
step_avg:92.80ms +step:854/1670 train_time:79248ms step_avg:92.80ms +step:855/1670 train_time:79338ms step_avg:92.79ms +step:856/1670 train_time:79430ms step_avg:92.79ms +step:857/1670 train_time:79521ms step_avg:92.79ms +step:858/1670 train_time:79613ms step_avg:92.79ms +step:859/1670 train_time:79705ms step_avg:92.79ms +step:860/1670 train_time:79796ms step_avg:92.79ms +step:861/1670 train_time:79892ms step_avg:92.79ms +step:862/1670 train_time:79988ms step_avg:92.79ms +step:863/1670 train_time:80082ms step_avg:92.79ms +step:864/1670 train_time:80175ms step_avg:92.79ms +step:865/1670 train_time:80267ms step_avg:92.79ms +step:866/1670 train_time:80358ms step_avg:92.79ms +step:867/1670 train_time:80450ms step_avg:92.79ms +step:868/1670 train_time:80541ms step_avg:92.79ms +step:869/1670 train_time:80633ms step_avg:92.79ms +step:870/1670 train_time:80724ms step_avg:92.79ms +step:871/1670 train_time:80817ms step_avg:92.79ms +step:872/1670 train_time:80913ms step_avg:92.79ms +step:873/1670 train_time:81007ms step_avg:92.79ms +step:874/1670 train_time:81100ms step_avg:92.79ms +step:875/1670 train_time:81193ms step_avg:92.79ms +step:875/1670 val_loss:3.5187 train_time:81286ms step_avg:92.90ms +step:876/1670 train_time:81306ms step_avg:92.81ms +step:877/1670 train_time:81383ms step_avg:92.80ms +step:878/1670 train_time:81479ms step_avg:92.80ms +step:879/1670 train_time:81571ms step_avg:92.80ms +step:880/1670 train_time:81661ms step_avg:92.80ms +step:881/1670 train_time:81753ms step_avg:92.80ms +step:882/1670 train_time:81844ms step_avg:92.79ms +step:883/1670 train_time:81936ms step_avg:92.79ms +step:884/1670 train_time:82028ms step_avg:92.79ms +step:885/1670 train_time:82121ms step_avg:92.79ms +step:886/1670 train_time:82214ms step_avg:92.79ms +step:887/1670 train_time:82308ms step_avg:92.79ms +step:888/1670 train_time:82402ms step_avg:92.80ms +step:889/1670 train_time:82496ms step_avg:92.80ms +step:890/1670 train_time:82588ms step_avg:92.79ms +step:891/1670 train_time:82680ms step_avg:92.79ms +step:892/1670 train_time:82772ms step_avg:92.79ms +step:893/1670 train_time:82864ms step_avg:92.79ms +step:894/1670 train_time:82956ms step_avg:92.79ms +step:895/1670 train_time:83048ms step_avg:92.79ms +step:896/1670 train_time:83141ms step_avg:92.79ms +step:897/1670 train_time:83234ms step_avg:92.79ms +step:898/1670 train_time:83327ms step_avg:92.79ms +step:899/1670 train_time:83420ms step_avg:92.79ms +step:900/1670 train_time:83514ms step_avg:92.79ms +step:901/1670 train_time:83606ms step_avg:92.79ms +step:902/1670 train_time:83698ms step_avg:92.79ms +step:903/1670 train_time:83790ms step_avg:92.79ms +step:904/1670 train_time:83882ms step_avg:92.79ms +step:905/1670 train_time:83975ms step_avg:92.79ms +step:906/1670 train_time:84067ms step_avg:92.79ms +step:907/1670 train_time:84159ms step_avg:92.79ms +step:908/1670 train_time:84252ms step_avg:92.79ms +step:909/1670 train_time:84345ms step_avg:92.79ms +step:910/1670 train_time:84439ms step_avg:92.79ms +step:911/1670 train_time:84531ms step_avg:92.79ms +step:912/1670 train_time:84623ms step_avg:92.79ms +step:913/1670 train_time:84716ms step_avg:92.79ms +step:914/1670 train_time:84808ms step_avg:92.79ms +step:915/1670 train_time:84900ms step_avg:92.79ms +step:916/1670 train_time:84992ms step_avg:92.79ms +step:917/1670 train_time:85085ms step_avg:92.79ms +step:918/1670 train_time:85177ms step_avg:92.79ms +step:919/1670 train_time:85270ms step_avg:92.79ms +step:920/1670 train_time:85363ms step_avg:92.79ms +step:921/1670 train_time:85457ms step_avg:92.79ms 
+step:922/1670 train_time:85550ms step_avg:92.79ms +step:923/1670 train_time:85643ms step_avg:92.79ms +step:924/1670 train_time:85735ms step_avg:92.79ms +step:925/1670 train_time:85827ms step_avg:92.79ms +step:926/1670 train_time:85920ms step_avg:92.79ms +step:927/1670 train_time:86012ms step_avg:92.79ms +step:928/1670 train_time:86104ms step_avg:92.79ms +step:929/1670 train_time:86197ms step_avg:92.79ms +step:930/1670 train_time:86291ms step_avg:92.79ms +step:931/1670 train_time:86384ms step_avg:92.79ms +step:932/1670 train_time:86477ms step_avg:92.79ms +step:933/1670 train_time:86568ms step_avg:92.78ms +step:934/1670 train_time:86662ms step_avg:92.79ms +step:935/1670 train_time:86754ms step_avg:92.79ms +step:936/1670 train_time:86846ms step_avg:92.78ms +step:937/1670 train_time:86940ms step_avg:92.78ms +step:938/1670 train_time:87032ms step_avg:92.78ms +step:939/1670 train_time:87123ms step_avg:92.78ms +step:940/1670 train_time:87217ms step_avg:92.78ms +step:941/1670 train_time:87310ms step_avg:92.78ms +step:942/1670 train_time:87403ms step_avg:92.78ms +step:943/1670 train_time:87495ms step_avg:92.78ms +step:944/1670 train_time:87587ms step_avg:92.78ms +step:945/1670 train_time:87681ms step_avg:92.78ms +step:946/1670 train_time:87773ms step_avg:92.78ms +step:947/1670 train_time:87866ms step_avg:92.78ms +step:948/1670 train_time:87958ms step_avg:92.78ms +step:949/1670 train_time:88050ms step_avg:92.78ms +step:950/1670 train_time:88143ms step_avg:92.78ms +step:951/1670 train_time:88236ms step_avg:92.78ms +step:952/1670 train_time:88327ms step_avg:92.78ms +step:953/1670 train_time:88421ms step_avg:92.78ms +step:954/1670 train_time:88513ms step_avg:92.78ms +step:955/1670 train_time:88605ms step_avg:92.78ms +step:956/1670 train_time:88698ms step_avg:92.78ms +step:957/1670 train_time:88791ms step_avg:92.78ms +step:958/1670 train_time:88883ms step_avg:92.78ms +step:959/1670 train_time:88976ms step_avg:92.78ms +step:960/1670 train_time:89068ms step_avg:92.78ms +step:961/1670 train_time:89160ms step_avg:92.78ms +step:962/1670 train_time:89253ms step_avg:92.78ms +step:963/1670 train_time:89345ms step_avg:92.78ms +step:964/1670 train_time:89438ms step_avg:92.78ms +step:965/1670 train_time:89530ms step_avg:92.78ms +step:966/1670 train_time:89623ms step_avg:92.78ms +step:967/1670 train_time:89717ms step_avg:92.78ms +step:968/1670 train_time:89809ms step_avg:92.78ms +step:969/1670 train_time:89903ms step_avg:92.78ms +step:970/1670 train_time:89996ms step_avg:92.78ms +step:971/1670 train_time:90088ms step_avg:92.78ms +step:972/1670 train_time:90181ms step_avg:92.78ms +step:973/1670 train_time:90273ms step_avg:92.78ms +step:974/1670 train_time:90365ms step_avg:92.78ms +step:975/1670 train_time:90457ms step_avg:92.78ms +step:976/1670 train_time:90549ms step_avg:92.78ms +step:977/1670 train_time:90642ms step_avg:92.78ms +step:978/1670 train_time:90734ms step_avg:92.78ms +step:979/1670 train_time:90826ms step_avg:92.77ms +step:980/1670 train_time:90920ms step_avg:92.78ms +step:981/1670 train_time:91012ms step_avg:92.77ms +step:982/1670 train_time:91104ms step_avg:92.77ms +step:983/1670 train_time:91197ms step_avg:92.77ms +step:984/1670 train_time:91288ms step_avg:92.77ms +step:985/1670 train_time:91382ms step_avg:92.77ms +step:986/1670 train_time:91474ms step_avg:92.77ms +step:987/1670 train_time:91566ms step_avg:92.77ms +step:988/1670 train_time:91659ms step_avg:92.77ms +step:989/1670 train_time:91751ms step_avg:92.77ms +step:990/1670 train_time:91844ms step_avg:92.77ms +step:991/1670 train_time:91938ms 
step_avg:92.77ms +step:992/1670 train_time:92030ms step_avg:92.77ms +step:993/1670 train_time:92123ms step_avg:92.77ms +step:994/1670 train_time:92216ms step_avg:92.77ms +step:995/1670 train_time:92308ms step_avg:92.77ms +step:996/1670 train_time:92400ms step_avg:92.77ms +step:997/1670 train_time:92493ms step_avg:92.77ms +step:998/1670 train_time:92585ms step_avg:92.77ms +step:999/1670 train_time:92677ms step_avg:92.77ms +step:1000/1670 train_time:92770ms step_avg:92.77ms +step:1000/1670 val_loss:3.4700 train_time:92862ms step_avg:92.86ms +step:1001/1670 train_time:92882ms step_avg:92.79ms +step:1002/1670 train_time:92957ms step_avg:92.77ms +step:1003/1670 train_time:93050ms step_avg:92.77ms +step:1004/1670 train_time:93141ms step_avg:92.77ms +step:1005/1670 train_time:93233ms step_avg:92.77ms +step:1006/1670 train_time:93324ms step_avg:92.77ms +step:1007/1670 train_time:93416ms step_avg:92.77ms +step:1008/1670 train_time:93509ms step_avg:92.77ms +step:1009/1670 train_time:93601ms step_avg:92.77ms +step:1010/1670 train_time:93693ms step_avg:92.77ms +step:1011/1670 train_time:93786ms step_avg:92.77ms +step:1012/1670 train_time:93880ms step_avg:92.77ms +step:1013/1670 train_time:93974ms step_avg:92.77ms +step:1014/1670 train_time:94068ms step_avg:92.77ms +step:1015/1670 train_time:94160ms step_avg:92.77ms +step:1016/1670 train_time:94252ms step_avg:92.77ms +step:1017/1670 train_time:94343ms step_avg:92.77ms +step:1018/1670 train_time:94437ms step_avg:92.77ms +step:1019/1670 train_time:94529ms step_avg:92.77ms +step:1020/1670 train_time:94621ms step_avg:92.77ms +step:1021/1670 train_time:94713ms step_avg:92.76ms +step:1022/1670 train_time:94806ms step_avg:92.77ms +step:1023/1670 train_time:94899ms step_avg:92.77ms +step:1024/1670 train_time:94993ms step_avg:92.77ms +step:1025/1670 train_time:95085ms step_avg:92.77ms +step:1026/1670 train_time:95178ms step_avg:92.77ms +step:1027/1670 train_time:95271ms step_avg:92.77ms +step:1028/1670 train_time:95363ms step_avg:92.77ms +step:1029/1670 train_time:95455ms step_avg:92.77ms +step:1030/1670 train_time:95548ms step_avg:92.77ms +step:1031/1670 train_time:95641ms step_avg:92.77ms +step:1032/1670 train_time:95734ms step_avg:92.77ms +step:1033/1670 train_time:95827ms step_avg:92.77ms +step:1034/1670 train_time:95919ms step_avg:92.77ms +step:1035/1670 train_time:96013ms step_avg:92.77ms +step:1036/1670 train_time:96105ms step_avg:92.77ms +step:1037/1670 train_time:96198ms step_avg:92.77ms +step:1038/1670 train_time:96290ms step_avg:92.76ms +step:1039/1670 train_time:96382ms step_avg:92.76ms +step:1040/1670 train_time:96475ms step_avg:92.76ms +step:1041/1670 train_time:96567ms step_avg:92.76ms +step:1042/1670 train_time:96659ms step_avg:92.76ms +step:1043/1670 train_time:96751ms step_avg:92.76ms +step:1044/1670 train_time:96844ms step_avg:92.76ms +step:1045/1670 train_time:96937ms step_avg:92.76ms +step:1046/1670 train_time:97030ms step_avg:92.76ms +step:1047/1670 train_time:97122ms step_avg:92.76ms +step:1048/1670 train_time:97214ms step_avg:92.76ms +step:1049/1670 train_time:97307ms step_avg:92.76ms +step:1050/1670 train_time:97399ms step_avg:92.76ms +step:1051/1670 train_time:97491ms step_avg:92.76ms +step:1052/1670 train_time:97582ms step_avg:92.76ms +step:1053/1670 train_time:97676ms step_avg:92.76ms +step:1054/1670 train_time:97769ms step_avg:92.76ms +step:1055/1670 train_time:97861ms step_avg:92.76ms +step:1056/1670 train_time:97954ms step_avg:92.76ms +step:1057/1670 train_time:98047ms step_avg:92.76ms +step:1058/1670 train_time:98139ms 
step_avg:92.76ms +step:1059/1670 train_time:98232ms step_avg:92.76ms +step:1060/1670 train_time:98325ms step_avg:92.76ms +step:1061/1670 train_time:98417ms step_avg:92.76ms +step:1062/1670 train_time:98653ms step_avg:92.89ms +step:1063/1670 train_time:98739ms step_avg:92.89ms +step:1064/1670 train_time:98830ms step_avg:92.88ms +step:1065/1670 train_time:98921ms step_avg:92.88ms +step:1066/1670 train_time:99012ms step_avg:92.88ms +step:1067/1670 train_time:99103ms step_avg:92.88ms +step:1068/1670 train_time:99195ms step_avg:92.88ms +step:1069/1670 train_time:99286ms step_avg:92.88ms +step:1070/1670 train_time:99377ms step_avg:92.88ms +step:1071/1670 train_time:99469ms step_avg:92.87ms +step:1072/1670 train_time:99568ms step_avg:92.88ms +step:1073/1670 train_time:99665ms step_avg:92.88ms +step:1074/1670 train_time:99758ms step_avg:92.88ms +step:1075/1670 train_time:99851ms step_avg:92.89ms +step:1076/1670 train_time:99943ms step_avg:92.88ms +step:1077/1670 train_time:100034ms step_avg:92.88ms +step:1078/1670 train_time:100126ms step_avg:92.88ms +step:1079/1670 train_time:100217ms step_avg:92.88ms +step:1080/1670 train_time:100309ms step_avg:92.88ms +step:1081/1670 train_time:100400ms step_avg:92.88ms +step:1082/1670 train_time:100494ms step_avg:92.88ms +step:1083/1670 train_time:100587ms step_avg:92.88ms +step:1084/1670 train_time:100681ms step_avg:92.88ms +step:1085/1670 train_time:100776ms step_avg:92.88ms +step:1086/1670 train_time:100870ms step_avg:92.88ms +step:1087/1670 train_time:100961ms step_avg:92.88ms +step:1088/1670 train_time:101054ms step_avg:92.88ms +step:1089/1670 train_time:101146ms step_avg:92.88ms +step:1090/1670 train_time:101239ms step_avg:92.88ms +step:1091/1670 train_time:101330ms step_avg:92.88ms +step:1092/1670 train_time:101423ms step_avg:92.88ms +step:1093/1670 train_time:101516ms step_avg:92.88ms +step:1094/1670 train_time:101611ms step_avg:92.88ms +step:1095/1670 train_time:101703ms step_avg:92.88ms +step:1096/1670 train_time:101797ms step_avg:92.88ms +step:1097/1670 train_time:101889ms step_avg:92.88ms +step:1098/1670 train_time:101981ms step_avg:92.88ms +step:1099/1670 train_time:102073ms step_avg:92.88ms +step:1100/1670 train_time:102167ms step_avg:92.88ms +step:1101/1670 train_time:102260ms step_avg:92.88ms +step:1102/1670 train_time:102352ms step_avg:92.88ms +step:1103/1670 train_time:102445ms step_avg:92.88ms +step:1104/1670 train_time:102538ms step_avg:92.88ms +step:1105/1670 train_time:102632ms step_avg:92.88ms +step:1106/1670 train_time:102724ms step_avg:92.88ms +step:1107/1670 train_time:102818ms step_avg:92.88ms +step:1108/1670 train_time:102912ms step_avg:92.88ms +step:1109/1670 train_time:103003ms step_avg:92.88ms +step:1110/1670 train_time:103096ms step_avg:92.88ms +step:1111/1670 train_time:103189ms step_avg:92.88ms +step:1112/1670 train_time:103283ms step_avg:92.88ms +step:1113/1670 train_time:103372ms step_avg:92.88ms +step:1114/1670 train_time:103465ms step_avg:92.88ms +step:1115/1670 train_time:103755ms step_avg:93.05ms +step:1116/1670 train_time:103827ms step_avg:93.04ms +step:1117/1670 train_time:103919ms step_avg:93.03ms +step:1118/1670 train_time:104011ms step_avg:93.03ms +step:1119/1670 train_time:104102ms step_avg:93.03ms +step:1120/1670 train_time:104194ms step_avg:93.03ms +step:1121/1670 train_time:104286ms step_avg:93.03ms +step:1122/1670 train_time:104378ms step_avg:93.03ms +step:1123/1670 train_time:104470ms step_avg:93.03ms +step:1124/1670 train_time:104562ms step_avg:93.03ms +step:1125/1670 train_time:104661ms step_avg:93.03ms 
+step:1125/1670 val_loss:3.4172 train_time:104761ms step_avg:93.12ms +step:1126/1670 train_time:104781ms step_avg:93.06ms +step:1127/1670 train_time:104857ms step_avg:93.04ms +step:1128/1670 train_time:104957ms step_avg:93.05ms +step:1129/1670 train_time:105053ms step_avg:93.05ms +step:1130/1670 train_time:105147ms step_avg:93.05ms +step:1131/1670 train_time:105238ms step_avg:93.05ms +step:1132/1670 train_time:105331ms step_avg:93.05ms +step:1133/1670 train_time:105423ms step_avg:93.05ms +step:1134/1670 train_time:105515ms step_avg:93.05ms +step:1135/1670 train_time:105607ms step_avg:93.05ms +step:1136/1670 train_time:105701ms step_avg:93.05ms +step:1137/1670 train_time:105795ms step_avg:93.05ms +step:1138/1670 train_time:105892ms step_avg:93.05ms +step:1139/1670 train_time:105989ms step_avg:93.05ms +step:1140/1670 train_time:106082ms step_avg:93.05ms +step:1141/1670 train_time:106175ms step_avg:93.05ms +step:1142/1670 train_time:106269ms step_avg:93.05ms +step:1143/1670 train_time:106361ms step_avg:93.05ms +step:1144/1670 train_time:106453ms step_avg:93.05ms +step:1145/1670 train_time:106545ms step_avg:93.05ms +step:1146/1670 train_time:106636ms step_avg:93.05ms +step:1147/1670 train_time:106729ms step_avg:93.05ms +step:1148/1670 train_time:106823ms step_avg:93.05ms +step:1149/1670 train_time:106918ms step_avg:93.05ms +step:1150/1670 train_time:107012ms step_avg:93.05ms +step:1151/1670 train_time:107105ms step_avg:93.05ms +step:1152/1670 train_time:107198ms step_avg:93.05ms +step:1153/1670 train_time:107293ms step_avg:93.06ms +step:1154/1670 train_time:107386ms step_avg:93.06ms +step:1155/1670 train_time:107478ms step_avg:93.05ms +step:1156/1670 train_time:107570ms step_avg:93.05ms +step:1157/1670 train_time:107663ms step_avg:93.05ms +step:1158/1670 train_time:107756ms step_avg:93.05ms +step:1159/1670 train_time:107851ms step_avg:93.05ms +step:1160/1670 train_time:107945ms step_avg:93.06ms +step:1161/1670 train_time:108041ms step_avg:93.06ms +step:1162/1670 train_time:108134ms step_avg:93.06ms +step:1163/1670 train_time:108229ms step_avg:93.06ms +step:1164/1670 train_time:108322ms step_avg:93.06ms +step:1165/1670 train_time:108415ms step_avg:93.06ms +step:1166/1670 train_time:108508ms step_avg:93.06ms +step:1167/1670 train_time:108600ms step_avg:93.06ms +step:1168/1670 train_time:108694ms step_avg:93.06ms +step:1169/1670 train_time:108788ms step_avg:93.06ms +step:1170/1670 train_time:108881ms step_avg:93.06ms +step:1171/1670 train_time:108974ms step_avg:93.06ms +step:1172/1670 train_time:109068ms step_avg:93.06ms +step:1173/1670 train_time:109161ms step_avg:93.06ms +step:1174/1670 train_time:109255ms step_avg:93.06ms +step:1175/1670 train_time:109348ms step_avg:93.06ms +step:1176/1670 train_time:109441ms step_avg:93.06ms +step:1177/1670 train_time:109535ms step_avg:93.06ms +step:1178/1670 train_time:109628ms step_avg:93.06ms +step:1179/1670 train_time:109720ms step_avg:93.06ms +step:1180/1670 train_time:109815ms step_avg:93.06ms +step:1181/1670 train_time:109909ms step_avg:93.06ms +step:1182/1670 train_time:110002ms step_avg:93.06ms +step:1183/1670 train_time:110096ms step_avg:93.06ms +step:1184/1670 train_time:110190ms step_avg:93.07ms +step:1185/1670 train_time:110283ms step_avg:93.07ms +step:1186/1670 train_time:110376ms step_avg:93.07ms +step:1187/1670 train_time:110468ms step_avg:93.07ms +step:1188/1670 train_time:110561ms step_avg:93.06ms +step:1189/1670 train_time:110654ms step_avg:93.06ms +step:1190/1670 train_time:110747ms step_avg:93.06ms +step:1191/1670 train_time:110839ms 
step_avg:93.06ms +step:1192/1670 train_time:110934ms step_avg:93.07ms +step:1193/1670 train_time:111027ms step_avg:93.07ms +step:1194/1670 train_time:111120ms step_avg:93.07ms +step:1195/1670 train_time:111213ms step_avg:93.07ms +step:1196/1670 train_time:111307ms step_avg:93.07ms +step:1197/1670 train_time:111400ms step_avg:93.07ms +step:1198/1670 train_time:111493ms step_avg:93.07ms +step:1199/1670 train_time:111586ms step_avg:93.07ms +step:1200/1670 train_time:111679ms step_avg:93.07ms +step:1201/1670 train_time:111771ms step_avg:93.07ms +step:1202/1670 train_time:111865ms step_avg:93.07ms +step:1203/1670 train_time:111958ms step_avg:93.07ms +step:1204/1670 train_time:112052ms step_avg:93.07ms +step:1205/1670 train_time:112145ms step_avg:93.07ms +step:1206/1670 train_time:112238ms step_avg:93.07ms +step:1207/1670 train_time:112334ms step_avg:93.07ms +step:1208/1670 train_time:112428ms step_avg:93.07ms +step:1209/1670 train_time:112521ms step_avg:93.07ms +step:1210/1670 train_time:112615ms step_avg:93.07ms +step:1211/1670 train_time:112708ms step_avg:93.07ms +step:1212/1670 train_time:112801ms step_avg:93.07ms +step:1213/1670 train_time:112894ms step_avg:93.07ms +step:1214/1670 train_time:112987ms step_avg:93.07ms +step:1215/1670 train_time:113080ms step_avg:93.07ms +step:1216/1670 train_time:113173ms step_avg:93.07ms +step:1217/1670 train_time:113268ms step_avg:93.07ms +step:1218/1670 train_time:113362ms step_avg:93.07ms +step:1219/1670 train_time:113456ms step_avg:93.07ms +step:1220/1670 train_time:113548ms step_avg:93.07ms +step:1221/1670 train_time:113642ms step_avg:93.07ms +step:1222/1670 train_time:113736ms step_avg:93.07ms +step:1223/1670 train_time:113829ms step_avg:93.07ms +step:1224/1670 train_time:113922ms step_avg:93.07ms +step:1225/1670 train_time:114015ms step_avg:93.07ms +step:1226/1670 train_time:114108ms step_avg:93.07ms +step:1227/1670 train_time:114202ms step_avg:93.07ms +step:1228/1670 train_time:114295ms step_avg:93.07ms +step:1229/1670 train_time:114388ms step_avg:93.07ms +step:1230/1670 train_time:114482ms step_avg:93.07ms +step:1231/1670 train_time:114575ms step_avg:93.07ms +step:1232/1670 train_time:114668ms step_avg:93.07ms +step:1233/1670 train_time:114762ms step_avg:93.08ms +step:1234/1670 train_time:114855ms step_avg:93.08ms +step:1235/1670 train_time:114948ms step_avg:93.08ms +step:1236/1670 train_time:115040ms step_avg:93.07ms +step:1237/1670 train_time:115134ms step_avg:93.08ms +step:1238/1670 train_time:115228ms step_avg:93.08ms +step:1239/1670 train_time:115321ms step_avg:93.08ms +step:1240/1670 train_time:115415ms step_avg:93.08ms +step:1241/1670 train_time:115508ms step_avg:93.08ms +step:1242/1670 train_time:115601ms step_avg:93.08ms +step:1243/1670 train_time:115694ms step_avg:93.08ms +step:1244/1670 train_time:115787ms step_avg:93.08ms +step:1245/1670 train_time:115880ms step_avg:93.08ms +step:1246/1670 train_time:115973ms step_avg:93.08ms +step:1247/1670 train_time:116068ms step_avg:93.08ms +step:1248/1670 train_time:116161ms step_avg:93.08ms +step:1249/1670 train_time:116254ms step_avg:93.08ms +step:1250/1670 train_time:116346ms step_avg:93.08ms +step:1250/1670 val_loss:3.3784 train_time:116438ms step_avg:93.15ms +step:1251/1670 train_time:116458ms step_avg:93.09ms +step:1252/1670 train_time:116533ms step_avg:93.08ms +step:1253/1670 train_time:116626ms step_avg:93.08ms +step:1254/1670 train_time:116719ms step_avg:93.08ms +step:1255/1670 train_time:116812ms step_avg:93.08ms +step:1256/1670 train_time:116904ms step_avg:93.08ms +step:1257/1670 
train_time:116996ms step_avg:93.08ms +step:1258/1670 train_time:117089ms step_avg:93.08ms +step:1259/1670 train_time:117182ms step_avg:93.08ms +step:1260/1670 train_time:117276ms step_avg:93.08ms +step:1261/1670 train_time:117370ms step_avg:93.08ms +step:1262/1670 train_time:117465ms step_avg:93.08ms +step:1263/1670 train_time:117559ms step_avg:93.08ms +step:1264/1670 train_time:117651ms step_avg:93.08ms +step:1265/1670 train_time:117743ms step_avg:93.08ms +step:1266/1670 train_time:117836ms step_avg:93.08ms +step:1267/1670 train_time:117930ms step_avg:93.08ms +step:1268/1670 train_time:118025ms step_avg:93.08ms +step:1269/1670 train_time:118117ms step_avg:93.08ms +step:1270/1670 train_time:118210ms step_avg:93.08ms +step:1271/1670 train_time:118303ms step_avg:93.08ms +step:1272/1670 train_time:118398ms step_avg:93.08ms +step:1273/1670 train_time:118492ms step_avg:93.08ms +step:1274/1670 train_time:118735ms step_avg:93.20ms +step:1275/1670 train_time:118809ms step_avg:93.18ms +step:1276/1670 train_time:118901ms step_avg:93.18ms +step:1277/1670 train_time:118993ms step_avg:93.18ms +step:1278/1670 train_time:119084ms step_avg:93.18ms +step:1279/1670 train_time:119176ms step_avg:93.18ms +step:1280/1670 train_time:119269ms step_avg:93.18ms +step:1281/1670 train_time:119360ms step_avg:93.18ms +step:1282/1670 train_time:119453ms step_avg:93.18ms +step:1283/1670 train_time:119545ms step_avg:93.18ms +step:1284/1670 train_time:119646ms step_avg:93.18ms +step:1285/1670 train_time:119743ms step_avg:93.19ms +step:1286/1670 train_time:119836ms step_avg:93.19ms +step:1287/1670 train_time:119930ms step_avg:93.19ms +step:1288/1670 train_time:120023ms step_avg:93.19ms +step:1289/1670 train_time:120115ms step_avg:93.18ms +step:1290/1670 train_time:120207ms step_avg:93.18ms +step:1291/1670 train_time:120299ms step_avg:93.18ms +step:1292/1670 train_time:120391ms step_avg:93.18ms +step:1293/1670 train_time:120483ms step_avg:93.18ms +step:1294/1670 train_time:120579ms step_avg:93.18ms +step:1295/1670 train_time:120679ms step_avg:93.19ms +step:1296/1670 train_time:120774ms step_avg:93.19ms +step:1297/1670 train_time:120867ms step_avg:93.19ms +step:1298/1670 train_time:120960ms step_avg:93.19ms +step:1299/1670 train_time:121052ms step_avg:93.19ms +step:1300/1670 train_time:121144ms step_avg:93.19ms +step:1301/1670 train_time:121237ms step_avg:93.19ms +step:1302/1670 train_time:121330ms step_avg:93.19ms +step:1303/1670 train_time:121421ms step_avg:93.19ms +step:1304/1670 train_time:121515ms step_avg:93.19ms +step:1305/1670 train_time:121610ms step_avg:93.19ms +step:1306/1670 train_time:121705ms step_avg:93.19ms +step:1307/1670 train_time:121798ms step_avg:93.19ms +step:1308/1670 train_time:121891ms step_avg:93.19ms +step:1309/1670 train_time:121984ms step_avg:93.19ms +step:1310/1670 train_time:122077ms step_avg:93.19ms +step:1311/1670 train_time:122170ms step_avg:93.19ms +step:1312/1670 train_time:122263ms step_avg:93.19ms +step:1313/1670 train_time:122356ms step_avg:93.19ms +step:1314/1670 train_time:122448ms step_avg:93.19ms +step:1315/1670 train_time:122541ms step_avg:93.19ms +step:1316/1670 train_time:122637ms step_avg:93.19ms +step:1317/1670 train_time:122732ms step_avg:93.19ms +step:1318/1670 train_time:122825ms step_avg:93.19ms +step:1319/1670 train_time:122918ms step_avg:93.19ms +step:1320/1670 train_time:123011ms step_avg:93.19ms +step:1321/1670 train_time:123103ms step_avg:93.19ms +step:1322/1670 train_time:123196ms step_avg:93.19ms +step:1323/1670 train_time:123288ms step_avg:93.19ms +step:1324/1670 
train_time:123381ms step_avg:93.19ms +step:1325/1670 train_time:123474ms step_avg:93.19ms +step:1326/1670 train_time:123568ms step_avg:93.19ms +step:1327/1670 train_time:123661ms step_avg:93.19ms +step:1328/1670 train_time:123756ms step_avg:93.19ms +step:1329/1670 train_time:123850ms step_avg:93.19ms +step:1330/1670 train_time:123943ms step_avg:93.19ms +step:1331/1670 train_time:124036ms step_avg:93.19ms +step:1332/1670 train_time:124129ms step_avg:93.19ms +step:1333/1670 train_time:124222ms step_avg:93.19ms +step:1334/1670 train_time:124314ms step_avg:93.19ms +step:1335/1670 train_time:124407ms step_avg:93.19ms +step:1336/1670 train_time:124500ms step_avg:93.19ms +step:1337/1670 train_time:124595ms step_avg:93.19ms +step:1338/1670 train_time:124688ms step_avg:93.19ms +step:1339/1670 train_time:124782ms step_avg:93.19ms +step:1340/1670 train_time:124877ms step_avg:93.19ms +step:1341/1670 train_time:124970ms step_avg:93.19ms +step:1342/1670 train_time:125064ms step_avg:93.19ms +step:1343/1670 train_time:125157ms step_avg:93.19ms +step:1344/1670 train_time:125250ms step_avg:93.19ms +step:1345/1670 train_time:125343ms step_avg:93.19ms +step:1346/1670 train_time:125435ms step_avg:93.19ms +step:1347/1670 train_time:125529ms step_avg:93.19ms +step:1348/1670 train_time:125621ms step_avg:93.19ms +step:1349/1670 train_time:125715ms step_avg:93.19ms +step:1350/1670 train_time:125808ms step_avg:93.19ms +step:1351/1670 train_time:125903ms step_avg:93.19ms +step:1352/1670 train_time:125997ms step_avg:93.19ms +step:1353/1670 train_time:126090ms step_avg:93.19ms +step:1354/1670 train_time:126183ms step_avg:93.19ms +step:1355/1670 train_time:126277ms step_avg:93.19ms +step:1356/1670 train_time:126371ms step_avg:93.19ms +step:1357/1670 train_time:126464ms step_avg:93.19ms +step:1358/1670 train_time:126557ms step_avg:93.19ms +step:1359/1670 train_time:126650ms step_avg:93.19ms +step:1360/1670 train_time:126742ms step_avg:93.19ms +step:1361/1670 train_time:126837ms step_avg:93.19ms +step:1362/1670 train_time:126931ms step_avg:93.19ms +step:1363/1670 train_time:127024ms step_avg:93.19ms +step:1364/1670 train_time:127117ms step_avg:93.19ms +step:1365/1670 train_time:127210ms step_avg:93.19ms +step:1366/1670 train_time:127303ms step_avg:93.19ms +step:1367/1670 train_time:127398ms step_avg:93.20ms +step:1368/1670 train_time:127491ms step_avg:93.19ms +step:1369/1670 train_time:127583ms step_avg:93.19ms +step:1370/1670 train_time:127678ms step_avg:93.20ms +step:1371/1670 train_time:127772ms step_avg:93.20ms +step:1372/1670 train_time:127865ms step_avg:93.20ms +step:1373/1670 train_time:127959ms step_avg:93.20ms +step:1374/1670 train_time:128053ms step_avg:93.20ms +step:1375/1670 train_time:128145ms step_avg:93.20ms +step:1375/1670 val_loss:3.3439 train_time:128238ms step_avg:93.26ms +step:1376/1670 train_time:128258ms step_avg:93.21ms +step:1377/1670 train_time:128331ms step_avg:93.20ms +step:1378/1670 train_time:128423ms step_avg:93.20ms +step:1379/1670 train_time:128516ms step_avg:93.20ms +step:1380/1670 train_time:128609ms step_avg:93.20ms +step:1381/1670 train_time:128702ms step_avg:93.19ms +step:1382/1670 train_time:128795ms step_avg:93.19ms +step:1383/1670 train_time:128889ms step_avg:93.20ms +step:1384/1670 train_time:128982ms step_avg:93.20ms +step:1385/1670 train_time:129077ms step_avg:93.20ms +step:1386/1670 train_time:129172ms step_avg:93.20ms +step:1387/1670 train_time:129267ms step_avg:93.20ms +step:1388/1670 train_time:129360ms step_avg:93.20ms +step:1389/1670 train_time:129453ms step_avg:93.20ms 
+step:1390/1670 train_time:129546ms step_avg:93.20ms +step:1391/1670 train_time:129639ms step_avg:93.20ms +step:1392/1670 train_time:129730ms step_avg:93.20ms +step:1393/1670 train_time:129823ms step_avg:93.20ms +step:1394/1670 train_time:129917ms step_avg:93.20ms +step:1395/1670 train_time:130012ms step_avg:93.20ms +step:1396/1670 train_time:130106ms step_avg:93.20ms +step:1397/1670 train_time:130200ms step_avg:93.20ms +step:1398/1670 train_time:130294ms step_avg:93.20ms +step:1399/1670 train_time:130386ms step_avg:93.20ms +step:1400/1670 train_time:130480ms step_avg:93.20ms +step:1401/1670 train_time:130572ms step_avg:93.20ms +step:1402/1670 train_time:130665ms step_avg:93.20ms +step:1403/1670 train_time:130757ms step_avg:93.20ms +step:1404/1670 train_time:130851ms step_avg:93.20ms +step:1405/1670 train_time:130943ms step_avg:93.20ms +step:1406/1670 train_time:131038ms step_avg:93.20ms +step:1407/1670 train_time:131131ms step_avg:93.20ms +step:1408/1670 train_time:131224ms step_avg:93.20ms +step:1409/1670 train_time:131319ms step_avg:93.20ms +step:1410/1670 train_time:131413ms step_avg:93.20ms +step:1411/1670 train_time:131505ms step_avg:93.20ms +step:1412/1670 train_time:131600ms step_avg:93.20ms +step:1413/1670 train_time:131692ms step_avg:93.20ms +step:1414/1670 train_time:131785ms step_avg:93.20ms +step:1415/1670 train_time:131880ms step_avg:93.20ms +step:1416/1670 train_time:131973ms step_avg:93.20ms +step:1417/1670 train_time:132066ms step_avg:93.20ms +step:1418/1670 train_time:132160ms step_avg:93.20ms +step:1419/1670 train_time:132254ms step_avg:93.20ms +step:1420/1670 train_time:132348ms step_avg:93.20ms +step:1421/1670 train_time:132441ms step_avg:93.20ms +step:1422/1670 train_time:132534ms step_avg:93.20ms +step:1423/1670 train_time:132627ms step_avg:93.20ms +step:1424/1670 train_time:132721ms step_avg:93.20ms +step:1425/1670 train_time:132814ms step_avg:93.20ms +step:1426/1670 train_time:132906ms step_avg:93.20ms +step:1427/1670 train_time:133000ms step_avg:93.20ms +step:1428/1670 train_time:133093ms step_avg:93.20ms +step:1429/1670 train_time:133187ms step_avg:93.20ms +step:1430/1670 train_time:133280ms step_avg:93.20ms +step:1431/1670 train_time:133373ms step_avg:93.20ms +step:1432/1670 train_time:133466ms step_avg:93.20ms +step:1433/1670 train_time:133560ms step_avg:93.20ms +step:1434/1670 train_time:133653ms step_avg:93.20ms +step:1435/1670 train_time:133746ms step_avg:93.20ms +step:1436/1670 train_time:133840ms step_avg:93.20ms +step:1437/1670 train_time:133933ms step_avg:93.20ms +step:1438/1670 train_time:134026ms step_avg:93.20ms +step:1439/1670 train_time:134121ms step_avg:93.20ms +step:1440/1670 train_time:134214ms step_avg:93.20ms +step:1441/1670 train_time:134307ms step_avg:93.20ms +step:1442/1670 train_time:134400ms step_avg:93.20ms +step:1443/1670 train_time:134493ms step_avg:93.20ms +step:1444/1670 train_time:134587ms step_avg:93.20ms +step:1445/1670 train_time:134680ms step_avg:93.20ms +step:1446/1670 train_time:134774ms step_avg:93.20ms +step:1447/1670 train_time:134867ms step_avg:93.20ms +step:1448/1670 train_time:134960ms step_avg:93.20ms +step:1449/1670 train_time:135054ms step_avg:93.20ms +step:1450/1670 train_time:135148ms step_avg:93.21ms +step:1451/1670 train_time:135241ms step_avg:93.21ms +step:1452/1670 train_time:135335ms step_avg:93.21ms +step:1453/1670 train_time:135427ms step_avg:93.21ms +step:1454/1670 train_time:135521ms step_avg:93.21ms +step:1455/1670 train_time:135615ms step_avg:93.21ms +step:1456/1670 train_time:135707ms step_avg:93.21ms 
+step:1457/1670 train_time:135800ms step_avg:93.21ms +step:1458/1670 train_time:135894ms step_avg:93.21ms +step:1459/1670 train_time:135987ms step_avg:93.21ms +step:1460/1670 train_time:136081ms step_avg:93.21ms +step:1461/1670 train_time:136175ms step_avg:93.21ms +step:1462/1670 train_time:136268ms step_avg:93.21ms +step:1463/1670 train_time:136361ms step_avg:93.21ms +step:1464/1670 train_time:136454ms step_avg:93.21ms +step:1465/1670 train_time:136547ms step_avg:93.21ms +step:1466/1670 train_time:136640ms step_avg:93.21ms +step:1467/1670 train_time:136734ms step_avg:93.21ms +step:1468/1670 train_time:136827ms step_avg:93.21ms +step:1469/1670 train_time:136920ms step_avg:93.21ms +step:1470/1670 train_time:137013ms step_avg:93.21ms +step:1471/1670 train_time:137106ms step_avg:93.21ms +step:1472/1670 train_time:137200ms step_avg:93.21ms +step:1473/1670 train_time:137293ms step_avg:93.21ms +step:1474/1670 train_time:137386ms step_avg:93.21ms +step:1475/1670 train_time:137480ms step_avg:93.21ms +step:1476/1670 train_time:137572ms step_avg:93.21ms +step:1477/1670 train_time:137666ms step_avg:93.21ms +step:1478/1670 train_time:137759ms step_avg:93.21ms +step:1479/1670 train_time:137852ms step_avg:93.21ms +step:1480/1670 train_time:137946ms step_avg:93.21ms +step:1481/1670 train_time:138040ms step_avg:93.21ms +step:1482/1670 train_time:138131ms step_avg:93.21ms +step:1483/1670 train_time:138226ms step_avg:93.21ms +step:1484/1670 train_time:138319ms step_avg:93.21ms +step:1485/1670 train_time:138567ms step_avg:93.31ms +step:1486/1670 train_time:138641ms step_avg:93.30ms +step:1487/1670 train_time:138732ms step_avg:93.30ms +step:1488/1670 train_time:138824ms step_avg:93.30ms +step:1489/1670 train_time:138916ms step_avg:93.29ms +step:1490/1670 train_time:139008ms step_avg:93.29ms +step:1491/1670 train_time:139099ms step_avg:93.29ms +step:1492/1670 train_time:139191ms step_avg:93.29ms +step:1493/1670 train_time:139283ms step_avg:93.29ms +step:1494/1670 train_time:139376ms step_avg:93.29ms +step:1495/1670 train_time:139473ms step_avg:93.29ms +step:1496/1670 train_time:139571ms step_avg:93.30ms +step:1497/1670 train_time:139666ms step_avg:93.30ms +step:1498/1670 train_time:139759ms step_avg:93.30ms +step:1499/1670 train_time:139853ms step_avg:93.30ms +step:1500/1670 train_time:139945ms step_avg:93.30ms +step:1500/1670 val_loss:3.3137 train_time:140039ms step_avg:93.36ms +step:1501/1670 train_time:140059ms step_avg:93.31ms +step:1502/1670 train_time:140133ms step_avg:93.30ms +step:1503/1670 train_time:140225ms step_avg:93.30ms +step:1504/1670 train_time:140319ms step_avg:93.30ms +step:1505/1670 train_time:140412ms step_avg:93.30ms +step:1506/1670 train_time:140504ms step_avg:93.30ms +step:1507/1670 train_time:140598ms step_avg:93.30ms +step:1508/1670 train_time:140693ms step_avg:93.30ms +step:1509/1670 train_time:140786ms step_avg:93.30ms +step:1510/1670 train_time:140879ms step_avg:93.30ms +step:1511/1670 train_time:140973ms step_avg:93.30ms +step:1512/1670 train_time:141068ms step_avg:93.30ms +step:1513/1670 train_time:141161ms step_avg:93.30ms +step:1514/1670 train_time:141254ms step_avg:93.30ms +step:1515/1670 train_time:141346ms step_avg:93.30ms +step:1516/1670 train_time:141439ms step_avg:93.30ms +step:1517/1670 train_time:141532ms step_avg:93.30ms +step:1518/1670 train_time:141625ms step_avg:93.30ms +step:1519/1670 train_time:141719ms step_avg:93.30ms +step:1520/1670 train_time:141812ms step_avg:93.30ms +step:1521/1670 train_time:141905ms step_avg:93.30ms +step:1522/1670 train_time:142000ms 
step_avg:93.30ms +step:1523/1670 train_time:142093ms step_avg:93.30ms +step:1524/1670 train_time:142186ms step_avg:93.30ms +step:1525/1670 train_time:142282ms step_avg:93.30ms +step:1526/1670 train_time:142375ms step_avg:93.30ms +step:1527/1670 train_time:142467ms step_avg:93.30ms +step:1528/1670 train_time:142561ms step_avg:93.30ms +step:1529/1670 train_time:142655ms step_avg:93.30ms +step:1530/1670 train_time:142747ms step_avg:93.30ms +step:1531/1670 train_time:142841ms step_avg:93.30ms +step:1532/1670 train_time:142934ms step_avg:93.30ms +step:1533/1670 train_time:143028ms step_avg:93.30ms +step:1534/1670 train_time:143121ms step_avg:93.30ms +step:1535/1670 train_time:143214ms step_avg:93.30ms +step:1536/1670 train_time:143308ms step_avg:93.30ms +step:1537/1670 train_time:143401ms step_avg:93.30ms +step:1538/1670 train_time:143493ms step_avg:93.30ms +step:1539/1670 train_time:143586ms step_avg:93.30ms +step:1540/1670 train_time:143680ms step_avg:93.30ms +step:1541/1670 train_time:143774ms step_avg:93.30ms +step:1542/1670 train_time:143867ms step_avg:93.30ms +step:1543/1670 train_time:143962ms step_avg:93.30ms +step:1544/1670 train_time:144055ms step_avg:93.30ms +step:1545/1670 train_time:144148ms step_avg:93.30ms +step:1546/1670 train_time:144241ms step_avg:93.30ms +step:1547/1670 train_time:144335ms step_avg:93.30ms +step:1548/1670 train_time:144428ms step_avg:93.30ms +step:1549/1670 train_time:144522ms step_avg:93.30ms +step:1550/1670 train_time:144615ms step_avg:93.30ms +step:1551/1670 train_time:144708ms step_avg:93.30ms +step:1552/1670 train_time:144801ms step_avg:93.30ms +step:1553/1670 train_time:144894ms step_avg:93.30ms +step:1554/1670 train_time:144988ms step_avg:93.30ms +step:1555/1670 train_time:145080ms step_avg:93.30ms +step:1556/1670 train_time:145173ms step_avg:93.30ms +step:1557/1670 train_time:145266ms step_avg:93.30ms +step:1558/1670 train_time:145360ms step_avg:93.30ms +step:1559/1670 train_time:145454ms step_avg:93.30ms +step:1560/1670 train_time:145547ms step_avg:93.30ms +step:1561/1670 train_time:145641ms step_avg:93.30ms +step:1562/1670 train_time:145733ms step_avg:93.30ms +step:1563/1670 train_time:145827ms step_avg:93.30ms +step:1564/1670 train_time:145921ms step_avg:93.30ms +step:1565/1670 train_time:146015ms step_avg:93.30ms +step:1566/1670 train_time:146108ms step_avg:93.30ms +step:1567/1670 train_time:146201ms step_avg:93.30ms +step:1568/1670 train_time:146294ms step_avg:93.30ms +step:1569/1670 train_time:146387ms step_avg:93.30ms +step:1570/1670 train_time:146481ms step_avg:93.30ms +step:1571/1670 train_time:146574ms step_avg:93.30ms +step:1572/1670 train_time:146666ms step_avg:93.30ms +step:1573/1670 train_time:146761ms step_avg:93.30ms +step:1574/1670 train_time:146855ms step_avg:93.30ms +step:1575/1670 train_time:146948ms step_avg:93.30ms +step:1576/1670 train_time:147042ms step_avg:93.30ms +step:1577/1670 train_time:147135ms step_avg:93.30ms +step:1578/1670 train_time:147227ms step_avg:93.30ms +step:1579/1670 train_time:147322ms step_avg:93.30ms +step:1580/1670 train_time:147415ms step_avg:93.30ms +step:1581/1670 train_time:147508ms step_avg:93.30ms +step:1582/1670 train_time:147601ms step_avg:93.30ms +step:1583/1670 train_time:147695ms step_avg:93.30ms +step:1584/1670 train_time:147787ms step_avg:93.30ms +step:1585/1670 train_time:147880ms step_avg:93.30ms +step:1586/1670 train_time:147973ms step_avg:93.30ms +step:1587/1670 train_time:148066ms step_avg:93.30ms +step:1588/1670 train_time:148161ms step_avg:93.30ms +step:1589/1670 train_time:148255ms 
step_avg:93.30ms +step:1590/1670 train_time:148347ms step_avg:93.30ms +step:1591/1670 train_time:148441ms step_avg:93.30ms +step:1592/1670 train_time:148535ms step_avg:93.30ms +step:1593/1670 train_time:148628ms step_avg:93.30ms +step:1594/1670 train_time:148721ms step_avg:93.30ms +step:1595/1670 train_time:148814ms step_avg:93.30ms +step:1596/1670 train_time:148907ms step_avg:93.30ms +step:1597/1670 train_time:149001ms step_avg:93.30ms +step:1598/1670 train_time:149094ms step_avg:93.30ms +step:1599/1670 train_time:149188ms step_avg:93.30ms +step:1600/1670 train_time:149283ms step_avg:93.30ms +step:1601/1670 train_time:149376ms step_avg:93.30ms +step:1602/1670 train_time:149468ms step_avg:93.30ms +step:1603/1670 train_time:149561ms step_avg:93.30ms +step:1604/1670 train_time:149655ms step_avg:93.30ms +step:1605/1670 train_time:149748ms step_avg:93.30ms +step:1606/1670 train_time:149841ms step_avg:93.30ms +step:1607/1670 train_time:149935ms step_avg:93.30ms +step:1608/1670 train_time:150028ms step_avg:93.30ms +step:1609/1670 train_time:150122ms step_avg:93.30ms +step:1610/1670 train_time:150216ms step_avg:93.30ms +step:1611/1670 train_time:150309ms step_avg:93.30ms +step:1612/1670 train_time:150402ms step_avg:93.30ms +step:1613/1670 train_time:150496ms step_avg:93.30ms +step:1614/1670 train_time:150591ms step_avg:93.30ms +step:1615/1670 train_time:150684ms step_avg:93.30ms +step:1616/1670 train_time:150777ms step_avg:93.30ms +step:1617/1670 train_time:150870ms step_avg:93.30ms +step:1618/1670 train_time:150963ms step_avg:93.30ms +step:1619/1670 train_time:151057ms step_avg:93.30ms +step:1620/1670 train_time:151150ms step_avg:93.30ms +step:1621/1670 train_time:151243ms step_avg:93.30ms +step:1622/1670 train_time:151336ms step_avg:93.30ms +step:1623/1670 train_time:151430ms step_avg:93.30ms +step:1624/1670 train_time:151523ms step_avg:93.30ms +step:1625/1670 train_time:151617ms step_avg:93.30ms +step:1625/1670 val_loss:3.2887 train_time:151709ms step_avg:93.36ms +step:1626/1670 train_time:151730ms step_avg:93.32ms +step:1627/1670 train_time:151804ms step_avg:93.30ms +step:1628/1670 train_time:151898ms step_avg:93.30ms +step:1629/1670 train_time:151990ms step_avg:93.30ms +step:1630/1670 train_time:152082ms step_avg:93.30ms +step:1631/1670 train_time:152175ms step_avg:93.30ms +step:1632/1670 train_time:152268ms step_avg:93.30ms +step:1633/1670 train_time:152362ms step_avg:93.30ms +step:1634/1670 train_time:152455ms step_avg:93.30ms +step:1635/1670 train_time:152547ms step_avg:93.30ms +step:1636/1670 train_time:152642ms step_avg:93.30ms +step:1637/1670 train_time:152737ms step_avg:93.30ms +step:1638/1670 train_time:152831ms step_avg:93.30ms +step:1639/1670 train_time:152926ms step_avg:93.30ms +step:1640/1670 train_time:153019ms step_avg:93.30ms +step:1641/1670 train_time:153111ms step_avg:93.30ms +step:1642/1670 train_time:153204ms step_avg:93.30ms +step:1643/1670 train_time:153297ms step_avg:93.30ms +step:1644/1670 train_time:153390ms step_avg:93.30ms +step:1645/1670 train_time:153483ms step_avg:93.30ms +step:1646/1670 train_time:153577ms step_avg:93.30ms +step:1647/1670 train_time:153671ms step_avg:93.30ms +step:1648/1670 train_time:153766ms step_avg:93.30ms +step:1649/1670 train_time:153860ms step_avg:93.31ms +step:1650/1670 train_time:153953ms step_avg:93.31ms +step:1651/1670 train_time:154046ms step_avg:93.30ms +step:1652/1670 train_time:154140ms step_avg:93.30ms +step:1653/1670 train_time:154231ms step_avg:93.30ms +step:1654/1670 train_time:154324ms step_avg:93.30ms +step:1655/1670 
train_time:154417ms step_avg:93.30ms +step:1656/1670 train_time:154510ms step_avg:93.30ms +step:1657/1670 train_time:154603ms step_avg:93.30ms +step:1658/1670 train_time:154697ms step_avg:93.30ms +step:1659/1670 train_time:154790ms step_avg:93.30ms +step:1660/1670 train_time:154884ms step_avg:93.30ms +step:1661/1670 train_time:154977ms step_avg:93.30ms +step:1662/1670 train_time:155070ms step_avg:93.30ms +step:1663/1670 train_time:155164ms step_avg:93.30ms +step:1664/1670 train_time:155257ms step_avg:93.30ms +step:1665/1670 train_time:155350ms step_avg:93.30ms +step:1666/1670 train_time:155443ms step_avg:93.30ms +step:1667/1670 train_time:155536ms step_avg:93.30ms +step:1668/1670 train_time:155629ms step_avg:93.30ms +step:1669/1670 train_time:155723ms step_avg:93.30ms +step:1670/1670 train_time:155816ms step_avg:93.30ms +step:1670/1670 val_loss:3.2802 train_time:156077ms step_avg:93.46ms +peak memory allocated: 32002 MiB reserved: 46834 MiB diff --git a/records/091125_VectSigmoidBFloat16/19171a35-3730-4239-b15e-3728d8de73db.txt b/records/091125_VectSigmoidBFloat16/19171a35-3730-4239-b15e-3728d8de73db.txt new file mode 100644 index 000000000..03bfad48d --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/19171a35-3730-4239-b15e-3728d8de73db.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
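+# [Illustrative aside] A minimal pure-PyTorch sketch of the same quintic
+# Newton-Schulz iteration, using the coefficients above; the Triton kernels
+# (ns_line_1 / ns_line_2) only accelerate these three lines per iteration.
+# `newton_schulz_reference` is a hypothetical helper for numerically
+# cross-checking newton_schulz_triton, not something this run itself calls.
+def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # orthogonalize along the short side
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # spectral norm <= 1
+    for _ in range(steps):
+        A = X @ X.mT  # ns_line_1
+        B = b * A + c * A @ A  # ns_line_2 (A is symmetric)
+        X = a * X + B @ X  # ns_line_3 (addmm/baddbmm above)
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+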
+# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. 
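+            # e.g. the square (768, 768) attention matrices here get the
+            # aspect-ratio factor max(1, 768/768) ** 0.5 == 1.0; a tall
+            # (3072, 768) matrix would instead get (3072/768) ** 0.5 == 2.0.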
+ eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. 
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
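+        # e.g. next_multiple_of_n(50257, n=128) == 50304, which is the
+        # (50304, 768) lm_head shape referenced in the fp8 backward comment.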
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = ( + [ve[0], ve[1], ve[2]] + + [None] * (len(self.blocks) - 6) + + [ve[0], ve[1], ve[2]] + ) + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [ + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + ] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]).to( + torch.bfloat16 + ) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[: (len(self.blocks) // 2)] + lambdas = self.scalars[ + 1 * len(self.blocks) : 3 * len(self.blocks) + ].view(-1, 2) + sa_lambdas = self.scalars[ + 3 * len(self.blocks) : 5 * len(self.blocks) + ].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws], + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + + +def _load_data_shard(file: Path): + header = torch.from_file( + str(file), False, 256, dtype=torch.int32 + ) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty( + num_tokens, dtype=torch.uint16, pin_memory=True + ) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto( + tokens.numpy() + ) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, ( + "number of tokens read does not match header" + ) + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = ( + (tokens == BOS_ID) + .nonzero(as_tuple=True)[0] + .to(torch.int64) + .cpu() + .numpy() + ) + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration( + f"Insufficient BOS ahead of position {cur}; hit tail of shard." 
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:12:55 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 134W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 45C P0 126W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 37C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 40C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:293ms step_avg:292.89ms +step:2/1670 train_time:312ms step_avg:156.04ms +step:3/1670 train_time:382ms step_avg:127.20ms +step:4/1670 train_time:471ms step_avg:117.85ms +step:5/1670 train_time:562ms step_avg:112.31ms +step:6/1670 train_time:652ms step_avg:108.62ms +step:7/1670 train_time:742ms step_avg:105.98ms +step:8/1670 train_time:832ms step_avg:104.01ms +step:9/1670 train_time:922ms step_avg:102.44ms +step:10/1670 train_time:1013ms step_avg:101.29ms +step:11/1670 train_time:1104ms step_avg:100.39ms +step:12/1670 train_time:1199ms step_avg:99.88ms +step:13/1670 train_time:1293ms step_avg:99.45ms +step:14/1670 train_time:1386ms step_avg:98.99ms +step:15/1670 train_time:1478ms step_avg:98.50ms +step:16/1670 train_time:1569ms step_avg:98.04ms +step:17/1670 train_time:1660ms step_avg:97.63ms +step:18/1670 train_time:1750ms step_avg:97.24ms +step:19/1670 train_time:1842ms step_avg:96.97ms +step:20/1670 train_time:1934ms step_avg:96.70ms +step:21/1670 train_time:2024ms step_avg:96.40ms +step:22/1670 train_time:2117ms 
step_avg:96.21ms +step:23/1670 train_time:2210ms step_avg:96.07ms +step:24/1670 train_time:2302ms step_avg:95.94ms +step:25/1670 train_time:2396ms step_avg:95.83ms +step:26/1670 train_time:2488ms step_avg:95.69ms +step:27/1670 train_time:2580ms step_avg:95.55ms +step:28/1670 train_time:2672ms step_avg:95.43ms +step:29/1670 train_time:2763ms step_avg:95.28ms +step:30/1670 train_time:2854ms step_avg:95.13ms +step:31/1670 train_time:2944ms step_avg:94.98ms +step:32/1670 train_time:3036ms step_avg:94.87ms +step:33/1670 train_time:3126ms step_avg:94.73ms +step:34/1670 train_time:3220ms step_avg:94.71ms +step:35/1670 train_time:3312ms step_avg:94.64ms +step:36/1670 train_time:3404ms step_avg:94.55ms +step:37/1670 train_time:3498ms step_avg:94.53ms +step:38/1670 train_time:3590ms step_avg:94.47ms +step:39/1670 train_time:3682ms step_avg:94.41ms +step:40/1670 train_time:3775ms step_avg:94.37ms +step:41/1670 train_time:3866ms step_avg:94.28ms +step:42/1670 train_time:3957ms step_avg:94.23ms +step:43/1670 train_time:4048ms step_avg:94.15ms +step:44/1670 train_time:4139ms step_avg:94.08ms +step:45/1670 train_time:4231ms step_avg:94.03ms +step:46/1670 train_time:4324ms step_avg:93.99ms +step:47/1670 train_time:4417ms step_avg:93.97ms +step:48/1670 train_time:4509ms step_avg:93.94ms +step:49/1670 train_time:4601ms step_avg:93.90ms +step:50/1670 train_time:4695ms step_avg:93.89ms +step:51/1670 train_time:4786ms step_avg:93.85ms +step:52/1670 train_time:4878ms step_avg:93.82ms +step:53/1670 train_time:4970ms step_avg:93.77ms +step:54/1670 train_time:5061ms step_avg:93.72ms +step:55/1670 train_time:5152ms step_avg:93.68ms +step:56/1670 train_time:5243ms step_avg:93.62ms +step:57/1670 train_time:5334ms step_avg:93.58ms +step:58/1670 train_time:5426ms step_avg:93.55ms +step:59/1670 train_time:5520ms step_avg:93.55ms +step:60/1670 train_time:5613ms step_avg:93.54ms +step:61/1670 train_time:5704ms step_avg:93.51ms +step:62/1670 train_time:5798ms step_avg:93.51ms +step:63/1670 train_time:5890ms step_avg:93.49ms +step:64/1670 train_time:5981ms step_avg:93.46ms +step:65/1670 train_time:6073ms step_avg:93.43ms +step:66/1670 train_time:6164ms step_avg:93.39ms +step:67/1670 train_time:6256ms step_avg:93.37ms +step:68/1670 train_time:6346ms step_avg:93.33ms +step:69/1670 train_time:6439ms step_avg:93.32ms +step:70/1670 train_time:6531ms step_avg:93.30ms +step:71/1670 train_time:6623ms step_avg:93.28ms +step:72/1670 train_time:6716ms step_avg:93.28ms +step:73/1670 train_time:6808ms step_avg:93.27ms +step:74/1670 train_time:6900ms step_avg:93.25ms +step:75/1670 train_time:6991ms step_avg:93.21ms +step:76/1670 train_time:7082ms step_avg:93.18ms +step:77/1670 train_time:7174ms step_avg:93.17ms +step:78/1670 train_time:7263ms step_avg:93.12ms +step:79/1670 train_time:7355ms step_avg:93.10ms +step:80/1670 train_time:7446ms step_avg:93.07ms +step:81/1670 train_time:7537ms step_avg:93.05ms +step:82/1670 train_time:7628ms step_avg:93.03ms +step:83/1670 train_time:7720ms step_avg:93.01ms +step:84/1670 train_time:7812ms step_avg:93.01ms +step:85/1670 train_time:7904ms step_avg:92.99ms +step:86/1670 train_time:7996ms step_avg:92.98ms +step:87/1670 train_time:8089ms step_avg:92.97ms +step:88/1670 train_time:8180ms step_avg:92.95ms +step:89/1670 train_time:8271ms step_avg:92.93ms +step:90/1670 train_time:8362ms step_avg:92.91ms +step:91/1670 train_time:8454ms step_avg:92.90ms +step:92/1670 train_time:8545ms step_avg:92.88ms +step:93/1670 train_time:8637ms step_avg:92.87ms +step:94/1670 train_time:8728ms step_avg:92.85ms 
+step:95/1670 train_time:8820ms step_avg:92.84ms +step:96/1670 train_time:8911ms step_avg:92.83ms +step:97/1670 train_time:9003ms step_avg:92.81ms +step:98/1670 train_time:9093ms step_avg:92.79ms +step:99/1670 train_time:9184ms step_avg:92.77ms +step:100/1670 train_time:9276ms step_avg:92.76ms +step:101/1670 train_time:9371ms step_avg:92.78ms +step:102/1670 train_time:9462ms step_avg:92.76ms +step:103/1670 train_time:9554ms step_avg:92.75ms +step:104/1670 train_time:9644ms step_avg:92.73ms +step:105/1670 train_time:9735ms step_avg:92.71ms +step:106/1670 train_time:9826ms step_avg:92.70ms +step:107/1670 train_time:9917ms step_avg:92.68ms +step:108/1670 train_time:10008ms step_avg:92.67ms +step:109/1670 train_time:10100ms step_avg:92.66ms +step:110/1670 train_time:10191ms step_avg:92.64ms +step:111/1670 train_time:10282ms step_avg:92.63ms +step:112/1670 train_time:10374ms step_avg:92.62ms +step:113/1670 train_time:10464ms step_avg:92.60ms +step:114/1670 train_time:10555ms step_avg:92.59ms +step:115/1670 train_time:10646ms step_avg:92.57ms +step:116/1670 train_time:10738ms step_avg:92.57ms +step:117/1670 train_time:10829ms step_avg:92.56ms +step:118/1670 train_time:10920ms step_avg:92.55ms +step:119/1670 train_time:11011ms step_avg:92.53ms +step:120/1670 train_time:11102ms step_avg:92.52ms +step:121/1670 train_time:11193ms step_avg:92.50ms +step:122/1670 train_time:11284ms step_avg:92.49ms +step:123/1670 train_time:11375ms step_avg:92.48ms +step:124/1670 train_time:11467ms step_avg:92.47ms +step:125/1670 train_time:11558ms step_avg:92.46ms +step:125/1670 val_loss:4.3115 train_time:11648ms step_avg:93.19ms +step:126/1670 train_time:11668ms step_avg:92.61ms +step:127/1670 train_time:11743ms step_avg:92.47ms +step:128/1670 train_time:11845ms step_avg:92.54ms +step:129/1670 train_time:11938ms step_avg:92.54ms +step:130/1670 train_time:12028ms step_avg:92.52ms +step:131/1670 train_time:12118ms step_avg:92.50ms +step:132/1670 train_time:12208ms step_avg:92.48ms +step:133/1670 train_time:12299ms step_avg:92.47ms +step:134/1670 train_time:12388ms step_avg:92.45ms +step:135/1670 train_time:12478ms step_avg:92.43ms +step:136/1670 train_time:12568ms step_avg:92.41ms +step:137/1670 train_time:12659ms step_avg:92.40ms +step:138/1670 train_time:12753ms step_avg:92.41ms +step:139/1670 train_time:12846ms step_avg:92.42ms +step:140/1670 train_time:12939ms step_avg:92.42ms +step:141/1670 train_time:13030ms step_avg:92.41ms +step:142/1670 train_time:13121ms step_avg:92.40ms +step:143/1670 train_time:13210ms step_avg:92.38ms +step:144/1670 train_time:13300ms step_avg:92.36ms +step:145/1670 train_time:13391ms step_avg:92.35ms +step:146/1670 train_time:13482ms step_avg:92.34ms +step:147/1670 train_time:13572ms step_avg:92.33ms +step:148/1670 train_time:13666ms step_avg:92.34ms +step:149/1670 train_time:13758ms step_avg:92.33ms +step:150/1670 train_time:13850ms step_avg:92.33ms +step:151/1670 train_time:13945ms step_avg:92.35ms +step:152/1670 train_time:14036ms step_avg:92.34ms +step:153/1670 train_time:14126ms step_avg:92.33ms +step:154/1670 train_time:14217ms step_avg:92.32ms +step:155/1670 train_time:14307ms step_avg:92.30ms +step:156/1670 train_time:14398ms step_avg:92.29ms +step:157/1670 train_time:14488ms step_avg:92.28ms +step:158/1670 train_time:14579ms step_avg:92.27ms +step:159/1670 train_time:14669ms step_avg:92.26ms +step:160/1670 train_time:14763ms step_avg:92.27ms +step:161/1670 train_time:14856ms step_avg:92.27ms +step:162/1670 train_time:14948ms step_avg:92.27ms +step:163/1670 train_time:15040ms 
step_avg:92.27ms +step:164/1670 train_time:15131ms step_avg:92.26ms +step:165/1670 train_time:15222ms step_avg:92.25ms +step:166/1670 train_time:15311ms step_avg:92.24ms +step:167/1670 train_time:15404ms step_avg:92.24ms +step:168/1670 train_time:15494ms step_avg:92.23ms +step:169/1670 train_time:15583ms step_avg:92.21ms +step:170/1670 train_time:15675ms step_avg:92.21ms +step:171/1670 train_time:15767ms step_avg:92.21ms +step:172/1670 train_time:15859ms step_avg:92.20ms +step:173/1670 train_time:15951ms step_avg:92.20ms +step:174/1670 train_time:16043ms step_avg:92.20ms +step:175/1670 train_time:16133ms step_avg:92.19ms +step:176/1670 train_time:16224ms step_avg:92.18ms +step:177/1670 train_time:16315ms step_avg:92.17ms +step:178/1670 train_time:16406ms step_avg:92.17ms +step:179/1670 train_time:16498ms step_avg:92.17ms +step:180/1670 train_time:16588ms step_avg:92.16ms +step:181/1670 train_time:16680ms step_avg:92.15ms +step:182/1670 train_time:16770ms step_avg:92.15ms +step:183/1670 train_time:16862ms step_avg:92.14ms +step:184/1670 train_time:16953ms step_avg:92.13ms +step:185/1670 train_time:17044ms step_avg:92.13ms +step:186/1670 train_time:17135ms step_avg:92.12ms +step:187/1670 train_time:17226ms step_avg:92.12ms +step:188/1670 train_time:17316ms step_avg:92.11ms +step:189/1670 train_time:17407ms step_avg:92.10ms +step:190/1670 train_time:17498ms step_avg:92.10ms +step:191/1670 train_time:17589ms step_avg:92.09ms +step:192/1670 train_time:17680ms step_avg:92.09ms +step:193/1670 train_time:17771ms step_avg:92.08ms +step:194/1670 train_time:17864ms step_avg:92.08ms +step:195/1670 train_time:17955ms step_avg:92.08ms +step:196/1670 train_time:18046ms step_avg:92.07ms +step:197/1670 train_time:18138ms step_avg:92.07ms +step:198/1670 train_time:18229ms step_avg:92.06ms +step:199/1670 train_time:18319ms step_avg:92.06ms +step:200/1670 train_time:18409ms step_avg:92.05ms +step:201/1670 train_time:18500ms step_avg:92.04ms +step:202/1670 train_time:18593ms step_avg:92.04ms +step:203/1670 train_time:18685ms step_avg:92.04ms +step:204/1670 train_time:18776ms step_avg:92.04ms +step:205/1670 train_time:18868ms step_avg:92.04ms +step:206/1670 train_time:18959ms step_avg:92.03ms +step:207/1670 train_time:19050ms step_avg:92.03ms +step:208/1670 train_time:19141ms step_avg:92.03ms +step:209/1670 train_time:19232ms step_avg:92.02ms +step:210/1670 train_time:19323ms step_avg:92.01ms +step:211/1670 train_time:19414ms step_avg:92.01ms +step:212/1670 train_time:19505ms step_avg:92.00ms +step:213/1670 train_time:19752ms step_avg:92.73ms +step:214/1670 train_time:19831ms step_avg:92.67ms +step:215/1670 train_time:19920ms step_avg:92.65ms +step:216/1670 train_time:20010ms step_avg:92.64ms +step:217/1670 train_time:20101ms step_avg:92.63ms +step:218/1670 train_time:20191ms step_avg:92.62ms +step:219/1670 train_time:20281ms step_avg:92.61ms +step:220/1670 train_time:20370ms step_avg:92.59ms +step:221/1670 train_time:20461ms step_avg:92.58ms +step:222/1670 train_time:20550ms step_avg:92.57ms +step:223/1670 train_time:20644ms step_avg:92.58ms +step:224/1670 train_time:20742ms step_avg:92.60ms +step:225/1670 train_time:20835ms step_avg:92.60ms +step:226/1670 train_time:20926ms step_avg:92.59ms +step:227/1670 train_time:21017ms step_avg:92.58ms +step:228/1670 train_time:21107ms step_avg:92.57ms +step:229/1670 train_time:21199ms step_avg:92.57ms +step:230/1670 train_time:21289ms step_avg:92.56ms +step:231/1670 train_time:21379ms step_avg:92.55ms +step:232/1670 train_time:21470ms step_avg:92.54ms +step:233/1670 
train_time:21560ms step_avg:92.53ms +step:234/1670 train_time:21652ms step_avg:92.53ms +step:235/1670 train_time:21745ms step_avg:92.53ms +step:236/1670 train_time:21837ms step_avg:92.53ms +step:237/1670 train_time:21928ms step_avg:92.53ms +step:238/1670 train_time:22019ms step_avg:92.52ms +step:239/1670 train_time:22110ms step_avg:92.51ms +step:240/1670 train_time:22201ms step_avg:92.50ms +step:241/1670 train_time:22290ms step_avg:92.49ms +step:242/1670 train_time:22381ms step_avg:92.48ms +step:243/1670 train_time:22471ms step_avg:92.47ms +step:244/1670 train_time:22561ms step_avg:92.46ms +step:245/1670 train_time:22653ms step_avg:92.46ms +step:246/1670 train_time:22745ms step_avg:92.46ms +step:247/1670 train_time:22837ms step_avg:92.46ms +step:248/1670 train_time:22928ms step_avg:92.45ms +step:249/1670 train_time:23020ms step_avg:92.45ms +step:250/1670 train_time:23111ms step_avg:92.44ms +step:250/1670 val_loss:3.9716 train_time:23200ms step_avg:92.80ms +step:251/1670 train_time:23220ms step_avg:92.51ms +step:252/1670 train_time:23292ms step_avg:92.43ms +step:253/1670 train_time:23383ms step_avg:92.42ms +step:254/1670 train_time:23475ms step_avg:92.42ms +step:255/1670 train_time:23565ms step_avg:92.41ms +step:256/1670 train_time:23657ms step_avg:92.41ms +step:257/1670 train_time:23747ms step_avg:92.40ms +step:258/1670 train_time:23839ms step_avg:92.40ms +step:259/1670 train_time:23929ms step_avg:92.39ms +step:260/1670 train_time:24020ms step_avg:92.39ms +step:261/1670 train_time:24112ms step_avg:92.38ms +step:262/1670 train_time:24206ms step_avg:92.39ms +step:263/1670 train_time:24297ms step_avg:92.38ms +step:264/1670 train_time:24388ms step_avg:92.38ms +step:265/1670 train_time:24479ms step_avg:92.37ms +step:266/1670 train_time:24571ms step_avg:92.37ms +step:267/1670 train_time:24661ms step_avg:92.36ms +step:268/1670 train_time:24752ms step_avg:92.36ms +step:269/1670 train_time:24843ms step_avg:92.35ms +step:270/1670 train_time:24934ms step_avg:92.35ms +step:271/1670 train_time:25025ms step_avg:92.34ms +step:272/1670 train_time:25117ms step_avg:92.34ms +step:273/1670 train_time:25209ms step_avg:92.34ms +step:274/1670 train_time:25301ms step_avg:92.34ms +step:275/1670 train_time:25393ms step_avg:92.34ms +step:276/1670 train_time:25483ms step_avg:92.33ms +step:277/1670 train_time:25577ms step_avg:92.34ms +step:278/1670 train_time:25667ms step_avg:92.33ms +step:279/1670 train_time:25759ms step_avg:92.33ms +step:280/1670 train_time:25849ms step_avg:92.32ms +step:281/1670 train_time:25941ms step_avg:92.32ms +step:282/1670 train_time:26031ms step_avg:92.31ms +step:283/1670 train_time:26123ms step_avg:92.31ms +step:284/1670 train_time:26214ms step_avg:92.30ms +step:285/1670 train_time:26305ms step_avg:92.30ms +step:286/1670 train_time:26396ms step_avg:92.29ms +step:287/1670 train_time:26487ms step_avg:92.29ms +step:288/1670 train_time:26579ms step_avg:92.29ms +step:289/1670 train_time:26671ms step_avg:92.29ms +step:290/1670 train_time:26762ms step_avg:92.28ms +step:291/1670 train_time:26853ms step_avg:92.28ms +step:292/1670 train_time:26943ms step_avg:92.27ms +step:293/1670 train_time:27035ms step_avg:92.27ms +step:294/1670 train_time:27125ms step_avg:92.26ms +step:295/1670 train_time:27217ms step_avg:92.26ms +step:296/1670 train_time:27308ms step_avg:92.26ms +step:297/1670 train_time:27400ms step_avg:92.25ms +step:298/1670 train_time:27491ms step_avg:92.25ms +step:299/1670 train_time:27582ms step_avg:92.25ms +step:300/1670 train_time:27673ms step_avg:92.24ms +step:301/1670 train_time:27764ms 
step_avg:92.24ms +step:302/1670 train_time:27855ms step_avg:92.24ms +step:303/1670 train_time:27946ms step_avg:92.23ms +step:304/1670 train_time:28038ms step_avg:92.23ms +step:305/1670 train_time:28129ms step_avg:92.22ms +step:306/1670 train_time:28220ms step_avg:92.22ms +step:307/1670 train_time:28311ms step_avg:92.22ms +step:308/1670 train_time:28402ms step_avg:92.21ms +step:309/1670 train_time:28494ms step_avg:92.21ms +step:310/1670 train_time:28584ms step_avg:92.21ms +step:311/1670 train_time:28675ms step_avg:92.20ms +step:312/1670 train_time:28765ms step_avg:92.20ms +step:313/1670 train_time:28857ms step_avg:92.19ms +step:314/1670 train_time:28948ms step_avg:92.19ms +step:315/1670 train_time:29040ms step_avg:92.19ms +step:316/1670 train_time:29131ms step_avg:92.19ms +step:317/1670 train_time:29222ms step_avg:92.18ms +step:318/1670 train_time:29313ms step_avg:92.18ms +step:319/1670 train_time:29403ms step_avg:92.17ms +step:320/1670 train_time:29494ms step_avg:92.17ms +step:321/1670 train_time:29584ms step_avg:92.16ms +step:322/1670 train_time:29675ms step_avg:92.16ms +step:323/1670 train_time:29766ms step_avg:92.16ms +step:324/1670 train_time:29859ms step_avg:92.16ms +step:325/1670 train_time:29951ms step_avg:92.16ms +step:326/1670 train_time:30043ms step_avg:92.16ms +step:327/1670 train_time:30134ms step_avg:92.15ms +step:328/1670 train_time:30225ms step_avg:92.15ms +step:329/1670 train_time:30316ms step_avg:92.15ms +step:330/1670 train_time:30406ms step_avg:92.14ms +step:331/1670 train_time:30497ms step_avg:92.14ms +step:332/1670 train_time:30588ms step_avg:92.13ms +step:333/1670 train_time:30680ms step_avg:92.13ms +step:334/1670 train_time:30772ms step_avg:92.13ms +step:335/1670 train_time:30863ms step_avg:92.13ms +step:336/1670 train_time:30954ms step_avg:92.13ms +step:337/1670 train_time:31045ms step_avg:92.12ms +step:338/1670 train_time:31137ms step_avg:92.12ms +step:339/1670 train_time:31228ms step_avg:92.12ms +step:340/1670 train_time:31319ms step_avg:92.11ms +step:341/1670 train_time:31410ms step_avg:92.11ms +step:342/1670 train_time:31500ms step_avg:92.11ms +step:343/1670 train_time:31590ms step_avg:92.10ms +step:344/1670 train_time:31681ms step_avg:92.10ms +step:345/1670 train_time:31772ms step_avg:92.09ms +step:346/1670 train_time:31863ms step_avg:92.09ms +step:347/1670 train_time:31955ms step_avg:92.09ms +step:348/1670 train_time:32046ms step_avg:92.09ms +step:349/1670 train_time:32138ms step_avg:92.09ms +step:350/1670 train_time:32229ms step_avg:92.08ms +step:351/1670 train_time:32320ms step_avg:92.08ms +step:352/1670 train_time:32412ms step_avg:92.08ms +step:353/1670 train_time:32502ms step_avg:92.08ms +step:354/1670 train_time:32593ms step_avg:92.07ms +step:355/1670 train_time:32683ms step_avg:92.07ms +step:356/1670 train_time:32776ms step_avg:92.07ms +step:357/1670 train_time:32866ms step_avg:92.06ms +step:358/1670 train_time:32958ms step_avg:92.06ms +step:359/1670 train_time:33050ms step_avg:92.06ms +step:360/1670 train_time:33141ms step_avg:92.06ms +step:361/1670 train_time:33232ms step_avg:92.06ms +step:362/1670 train_time:33323ms step_avg:92.05ms +step:363/1670 train_time:33414ms step_avg:92.05ms +step:364/1670 train_time:33505ms step_avg:92.05ms +step:365/1670 train_time:33596ms step_avg:92.04ms +step:366/1670 train_time:33687ms step_avg:92.04ms +step:367/1670 train_time:33779ms step_avg:92.04ms +step:368/1670 train_time:33871ms step_avg:92.04ms +step:369/1670 train_time:33962ms step_avg:92.04ms +step:370/1670 train_time:34054ms step_avg:92.04ms +step:371/1670 
train_time:34144ms step_avg:92.03ms +step:372/1670 train_time:34235ms step_avg:92.03ms +step:373/1670 train_time:34327ms step_avg:92.03ms +step:374/1670 train_time:34418ms step_avg:92.03ms +step:375/1670 train_time:34510ms step_avg:92.03ms +step:375/1670 val_loss:3.8163 train_time:34600ms step_avg:92.27ms +step:376/1670 train_time:34620ms step_avg:92.07ms +step:377/1670 train_time:34692ms step_avg:92.02ms +step:378/1670 train_time:34784ms step_avg:92.02ms +step:379/1670 train_time:34874ms step_avg:92.02ms +step:380/1670 train_time:34964ms step_avg:92.01ms +step:381/1670 train_time:35054ms step_avg:92.01ms +step:382/1670 train_time:35145ms step_avg:92.00ms +step:383/1670 train_time:35235ms step_avg:92.00ms +step:384/1670 train_time:35326ms step_avg:92.00ms +step:385/1670 train_time:35419ms step_avg:92.00ms +step:386/1670 train_time:35510ms step_avg:92.00ms +step:387/1670 train_time:35603ms step_avg:92.00ms +step:388/1670 train_time:35696ms step_avg:92.00ms +step:389/1670 train_time:35788ms step_avg:92.00ms +step:390/1670 train_time:35878ms step_avg:92.00ms +step:391/1670 train_time:35969ms step_avg:91.99ms +step:392/1670 train_time:36060ms step_avg:91.99ms +step:393/1670 train_time:36150ms step_avg:91.99ms +step:394/1670 train_time:36241ms step_avg:91.98ms +step:395/1670 train_time:36332ms step_avg:91.98ms +step:396/1670 train_time:36423ms step_avg:91.98ms +step:397/1670 train_time:36514ms step_avg:91.98ms +step:398/1670 train_time:36608ms step_avg:91.98ms +step:399/1670 train_time:36699ms step_avg:91.98ms +step:400/1670 train_time:36790ms step_avg:91.98ms +step:401/1670 train_time:36881ms step_avg:91.97ms +step:402/1670 train_time:36972ms step_avg:91.97ms +step:403/1670 train_time:37063ms step_avg:91.97ms +step:404/1670 train_time:37153ms step_avg:91.96ms +step:405/1670 train_time:37245ms step_avg:91.96ms +step:406/1670 train_time:37335ms step_avg:91.96ms +step:407/1670 train_time:37427ms step_avg:91.96ms +step:408/1670 train_time:37518ms step_avg:91.95ms +step:409/1670 train_time:37609ms step_avg:91.95ms +step:410/1670 train_time:37701ms step_avg:91.95ms +step:411/1670 train_time:37792ms step_avg:91.95ms +step:412/1670 train_time:37884ms step_avg:91.95ms +step:413/1670 train_time:37974ms step_avg:91.95ms +step:414/1670 train_time:38064ms step_avg:91.94ms +step:415/1670 train_time:38155ms step_avg:91.94ms +step:416/1670 train_time:38248ms step_avg:91.94ms +step:417/1670 train_time:38338ms step_avg:91.94ms +step:418/1670 train_time:38430ms step_avg:91.94ms +step:419/1670 train_time:38522ms step_avg:91.94ms +step:420/1670 train_time:38613ms step_avg:91.94ms +step:421/1670 train_time:38706ms step_avg:91.94ms +step:422/1670 train_time:38797ms step_avg:91.94ms +step:423/1670 train_time:38888ms step_avg:91.93ms +step:424/1670 train_time:38978ms step_avg:91.93ms +step:425/1670 train_time:39227ms step_avg:92.30ms +step:426/1670 train_time:39298ms step_avg:92.25ms +step:427/1670 train_time:39387ms step_avg:92.24ms +step:428/1670 train_time:39477ms step_avg:92.24ms +step:429/1670 train_time:39567ms step_avg:92.23ms +step:430/1670 train_time:39657ms step_avg:92.23ms +step:431/1670 train_time:39747ms step_avg:92.22ms +step:432/1670 train_time:39837ms step_avg:92.22ms +step:433/1670 train_time:39928ms step_avg:92.21ms +step:434/1670 train_time:40019ms step_avg:92.21ms +step:435/1670 train_time:40115ms step_avg:92.22ms +step:436/1670 train_time:40214ms step_avg:92.23ms +step:437/1670 train_time:40307ms step_avg:92.24ms +step:438/1670 train_time:40398ms step_avg:92.23ms +step:439/1670 train_time:40488ms 
step_avg:92.23ms +step:440/1670 train_time:40579ms step_avg:92.22ms +step:441/1670 train_time:40669ms step_avg:92.22ms +step:442/1670 train_time:40759ms step_avg:92.21ms +step:443/1670 train_time:40849ms step_avg:92.21ms +step:444/1670 train_time:40939ms step_avg:92.21ms +step:445/1670 train_time:41031ms step_avg:92.20ms +step:446/1670 train_time:41123ms step_avg:92.20ms +step:447/1670 train_time:41215ms step_avg:92.20ms +step:448/1670 train_time:41309ms step_avg:92.21ms +step:449/1670 train_time:41400ms step_avg:92.21ms +step:450/1670 train_time:41491ms step_avg:92.20ms +step:451/1670 train_time:41582ms step_avg:92.20ms +step:452/1670 train_time:41672ms step_avg:92.20ms +step:453/1670 train_time:41763ms step_avg:92.19ms +step:454/1670 train_time:41853ms step_avg:92.19ms +step:455/1670 train_time:41944ms step_avg:92.18ms +step:456/1670 train_time:42035ms step_avg:92.18ms +step:457/1670 train_time:42128ms step_avg:92.18ms +step:458/1670 train_time:42222ms step_avg:92.19ms +step:459/1670 train_time:42314ms step_avg:92.19ms +step:460/1670 train_time:42406ms step_avg:92.19ms +step:461/1670 train_time:42497ms step_avg:92.19ms +step:462/1670 train_time:42588ms step_avg:92.18ms +step:463/1670 train_time:42679ms step_avg:92.18ms +step:464/1670 train_time:42769ms step_avg:92.17ms +step:465/1670 train_time:42859ms step_avg:92.17ms +step:466/1670 train_time:42950ms step_avg:92.17ms +step:467/1670 train_time:43041ms step_avg:92.17ms +step:468/1670 train_time:43132ms step_avg:92.16ms +step:469/1670 train_time:43225ms step_avg:92.16ms +step:470/1670 train_time:43316ms step_avg:92.16ms +step:471/1670 train_time:43408ms step_avg:92.16ms +step:472/1670 train_time:43499ms step_avg:92.16ms +step:473/1670 train_time:43589ms step_avg:92.16ms +step:474/1670 train_time:43680ms step_avg:92.15ms +step:475/1670 train_time:43770ms step_avg:92.15ms +step:476/1670 train_time:43860ms step_avg:92.14ms +step:477/1670 train_time:43951ms step_avg:92.14ms +step:478/1670 train_time:44042ms step_avg:92.14ms +step:479/1670 train_time:44133ms step_avg:92.14ms +step:480/1670 train_time:44225ms step_avg:92.13ms +step:481/1670 train_time:44316ms step_avg:92.13ms +step:482/1670 train_time:44409ms step_avg:92.13ms +step:483/1670 train_time:44500ms step_avg:92.13ms +step:484/1670 train_time:44591ms step_avg:92.13ms +step:485/1670 train_time:44682ms step_avg:92.13ms +step:486/1670 train_time:44773ms step_avg:92.12ms +step:487/1670 train_time:44864ms step_avg:92.12ms +step:488/1670 train_time:44954ms step_avg:92.12ms +step:489/1670 train_time:45047ms step_avg:92.12ms +step:490/1670 train_time:45138ms step_avg:92.12ms +step:491/1670 train_time:45230ms step_avg:92.12ms +step:492/1670 train_time:45322ms step_avg:92.12ms +step:493/1670 train_time:45412ms step_avg:92.11ms +step:494/1670 train_time:45504ms step_avg:92.11ms +step:495/1670 train_time:45595ms step_avg:92.11ms +step:496/1670 train_time:45686ms step_avg:92.11ms +step:497/1670 train_time:45777ms step_avg:92.11ms +step:498/1670 train_time:45868ms step_avg:92.11ms +step:499/1670 train_time:45960ms step_avg:92.10ms +step:500/1670 train_time:46052ms step_avg:92.10ms +step:500/1670 val_loss:3.7143 train_time:46144ms step_avg:92.29ms +step:501/1670 train_time:46164ms step_avg:92.14ms +step:502/1670 train_time:46236ms step_avg:92.10ms +step:503/1670 train_time:46329ms step_avg:92.10ms +step:504/1670 train_time:46420ms step_avg:92.10ms +step:505/1670 train_time:46509ms step_avg:92.10ms +step:506/1670 train_time:46600ms step_avg:92.09ms +step:507/1670 train_time:46690ms step_avg:92.09ms 
+step:508/1670 train_time:46783ms step_avg:92.09ms +step:509/1670 train_time:46874ms step_avg:92.09ms +step:510/1670 train_time:46964ms step_avg:92.09ms +step:511/1670 train_time:47055ms step_avg:92.08ms +step:512/1670 train_time:47147ms step_avg:92.08ms +step:513/1670 train_time:47239ms step_avg:92.08ms +step:514/1670 train_time:47331ms step_avg:92.08ms +step:515/1670 train_time:47423ms step_avg:92.08ms +step:516/1670 train_time:47514ms step_avg:92.08ms +step:517/1670 train_time:47605ms step_avg:92.08ms +step:518/1670 train_time:47698ms step_avg:92.08ms +step:519/1670 train_time:47789ms step_avg:92.08ms +step:520/1670 train_time:47880ms step_avg:92.08ms +step:521/1670 train_time:47970ms step_avg:92.07ms +step:522/1670 train_time:48061ms step_avg:92.07ms +step:523/1670 train_time:48152ms step_avg:92.07ms +step:524/1670 train_time:48244ms step_avg:92.07ms +step:525/1670 train_time:48335ms step_avg:92.07ms +step:526/1670 train_time:48427ms step_avg:92.07ms +step:527/1670 train_time:48520ms step_avg:92.07ms +step:528/1670 train_time:48610ms step_avg:92.06ms +step:529/1670 train_time:48701ms step_avg:92.06ms +step:530/1670 train_time:48791ms step_avg:92.06ms +step:531/1670 train_time:48882ms step_avg:92.06ms +step:532/1670 train_time:48974ms step_avg:92.06ms +step:533/1670 train_time:49065ms step_avg:92.05ms +step:534/1670 train_time:49157ms step_avg:92.05ms +step:535/1670 train_time:49248ms step_avg:92.05ms +step:536/1670 train_time:49340ms step_avg:92.05ms +step:537/1670 train_time:49430ms step_avg:92.05ms +step:538/1670 train_time:49523ms step_avg:92.05ms +step:539/1670 train_time:49614ms step_avg:92.05ms +step:540/1670 train_time:49707ms step_avg:92.05ms +step:541/1670 train_time:49797ms step_avg:92.05ms +step:542/1670 train_time:49888ms step_avg:92.04ms +step:543/1670 train_time:49979ms step_avg:92.04ms +step:544/1670 train_time:50070ms step_avg:92.04ms +step:545/1670 train_time:50161ms step_avg:92.04ms +step:546/1670 train_time:50252ms step_avg:92.04ms +step:547/1670 train_time:50344ms step_avg:92.04ms +step:548/1670 train_time:50434ms step_avg:92.03ms +step:549/1670 train_time:50527ms step_avg:92.04ms +step:550/1670 train_time:50619ms step_avg:92.04ms +step:551/1670 train_time:50710ms step_avg:92.03ms +step:552/1670 train_time:50801ms step_avg:92.03ms +step:553/1670 train_time:50892ms step_avg:92.03ms +step:554/1670 train_time:50983ms step_avg:92.03ms +step:555/1670 train_time:51074ms step_avg:92.02ms +step:556/1670 train_time:51165ms step_avg:92.02ms +step:557/1670 train_time:51257ms step_avg:92.02ms +step:558/1670 train_time:51539ms step_avg:92.36ms +step:559/1670 train_time:51615ms step_avg:92.33ms +step:560/1670 train_time:51706ms step_avg:92.33ms +step:561/1670 train_time:51797ms step_avg:92.33ms +step:562/1670 train_time:51888ms step_avg:92.33ms +step:563/1670 train_time:51980ms step_avg:92.33ms +step:564/1670 train_time:52071ms step_avg:92.32ms +step:565/1670 train_time:52162ms step_avg:92.32ms +step:566/1670 train_time:52253ms step_avg:92.32ms +step:567/1670 train_time:52345ms step_avg:92.32ms +step:568/1670 train_time:52440ms step_avg:92.32ms +step:569/1670 train_time:52538ms step_avg:92.33ms +step:570/1670 train_time:52630ms step_avg:92.33ms +step:571/1670 train_time:52723ms step_avg:92.33ms +step:572/1670 train_time:52815ms step_avg:92.33ms +step:573/1670 train_time:52907ms step_avg:92.33ms +step:574/1670 train_time:52998ms step_avg:92.33ms +step:575/1670 train_time:53089ms step_avg:92.33ms +step:576/1670 train_time:53180ms step_avg:92.33ms +step:577/1670 train_time:53272ms 
step_avg:92.33ms +step:578/1670 train_time:53364ms step_avg:92.32ms +step:579/1670 train_time:53458ms step_avg:92.33ms +step:580/1670 train_time:53552ms step_avg:92.33ms +step:581/1670 train_time:53645ms step_avg:92.33ms +step:582/1670 train_time:53739ms step_avg:92.33ms +step:583/1670 train_time:53830ms step_avg:92.33ms +step:584/1670 train_time:53923ms step_avg:92.33ms +step:585/1670 train_time:54015ms step_avg:92.33ms +step:586/1670 train_time:54107ms step_avg:92.33ms +step:587/1670 train_time:54199ms step_avg:92.33ms +step:588/1670 train_time:54290ms step_avg:92.33ms +step:589/1670 train_time:54383ms step_avg:92.33ms +step:590/1670 train_time:54476ms step_avg:92.33ms +step:591/1670 train_time:54569ms step_avg:92.33ms +step:592/1670 train_time:54662ms step_avg:92.33ms +step:593/1670 train_time:54755ms step_avg:92.34ms +step:594/1670 train_time:54847ms step_avg:92.34ms +step:595/1670 train_time:54940ms step_avg:92.34ms +step:596/1670 train_time:55031ms step_avg:92.33ms +step:597/1670 train_time:55125ms step_avg:92.34ms +step:598/1670 train_time:55217ms step_avg:92.34ms +step:599/1670 train_time:55308ms step_avg:92.33ms +step:600/1670 train_time:55401ms step_avg:92.34ms +step:601/1670 train_time:55494ms step_avg:92.34ms +step:602/1670 train_time:55587ms step_avg:92.34ms +step:603/1670 train_time:55680ms step_avg:92.34ms +step:604/1670 train_time:55772ms step_avg:92.34ms +step:605/1670 train_time:55864ms step_avg:92.34ms +step:606/1670 train_time:55957ms step_avg:92.34ms +step:607/1670 train_time:56049ms step_avg:92.34ms +step:608/1670 train_time:56141ms step_avg:92.34ms +step:609/1670 train_time:56233ms step_avg:92.34ms +step:610/1670 train_time:56325ms step_avg:92.34ms +step:611/1670 train_time:56417ms step_avg:92.34ms +step:612/1670 train_time:56509ms step_avg:92.33ms +step:613/1670 train_time:56602ms step_avg:92.34ms +step:614/1670 train_time:56694ms step_avg:92.34ms +step:615/1670 train_time:56787ms step_avg:92.34ms +step:616/1670 train_time:56879ms step_avg:92.34ms +step:617/1670 train_time:56971ms step_avg:92.34ms +step:618/1670 train_time:57063ms step_avg:92.33ms +step:619/1670 train_time:57156ms step_avg:92.34ms +step:620/1670 train_time:57248ms step_avg:92.34ms +step:621/1670 train_time:57340ms step_avg:92.34ms +step:622/1670 train_time:57432ms step_avg:92.33ms +step:623/1670 train_time:57525ms step_avg:92.34ms +step:624/1670 train_time:57618ms step_avg:92.34ms +step:625/1670 train_time:57710ms step_avg:92.34ms +step:625/1670 val_loss:3.6125 train_time:57804ms step_avg:92.49ms +step:626/1670 train_time:57824ms step_avg:92.37ms +step:627/1670 train_time:57901ms step_avg:92.35ms +step:628/1670 train_time:58004ms step_avg:92.36ms +step:629/1670 train_time:58098ms step_avg:92.37ms +step:630/1670 train_time:58190ms step_avg:92.36ms +step:631/1670 train_time:58281ms step_avg:92.36ms +step:632/1670 train_time:58372ms step_avg:92.36ms +step:633/1670 train_time:58463ms step_avg:92.36ms +step:634/1670 train_time:58554ms step_avg:92.36ms +step:635/1670 train_time:58646ms step_avg:92.36ms +step:636/1670 train_time:58738ms step_avg:92.35ms +step:637/1670 train_time:58829ms step_avg:92.35ms +step:638/1670 train_time:58925ms step_avg:92.36ms +step:639/1670 train_time:59161ms step_avg:92.58ms +step:640/1670 train_time:59233ms step_avg:92.55ms +step:641/1670 train_time:59324ms step_avg:92.55ms +step:642/1670 train_time:59415ms step_avg:92.55ms +step:643/1670 train_time:59506ms step_avg:92.54ms +step:644/1670 train_time:59597ms step_avg:92.54ms +step:645/1670 train_time:59688ms step_avg:92.54ms 
+step:646/1670 train_time:59780ms step_avg:92.54ms +step:647/1670 train_time:59872ms step_avg:92.54ms +step:648/1670 train_time:59963ms step_avg:92.54ms +step:649/1670 train_time:60059ms step_avg:92.54ms +step:650/1670 train_time:60156ms step_avg:92.55ms +step:651/1670 train_time:60249ms step_avg:92.55ms +step:652/1670 train_time:60342ms step_avg:92.55ms +step:653/1670 train_time:60434ms step_avg:92.55ms +step:654/1670 train_time:60525ms step_avg:92.55ms +step:655/1670 train_time:60617ms step_avg:92.55ms +step:656/1670 train_time:60708ms step_avg:92.54ms +step:657/1670 train_time:60800ms step_avg:92.54ms +step:658/1670 train_time:60891ms step_avg:92.54ms +step:659/1670 train_time:60985ms step_avg:92.54ms +step:660/1670 train_time:61081ms step_avg:92.55ms +step:661/1670 train_time:61175ms step_avg:92.55ms +step:662/1670 train_time:61268ms step_avg:92.55ms +step:663/1670 train_time:61362ms step_avg:92.55ms +step:664/1670 train_time:61454ms step_avg:92.55ms +step:665/1670 train_time:61545ms step_avg:92.55ms +step:666/1670 train_time:61637ms step_avg:92.55ms +step:667/1670 train_time:61726ms step_avg:92.54ms +step:668/1670 train_time:61817ms step_avg:92.54ms +step:669/1670 train_time:61909ms step_avg:92.54ms +step:670/1670 train_time:62002ms step_avg:92.54ms +step:671/1670 train_time:62096ms step_avg:92.54ms +step:672/1670 train_time:62188ms step_avg:92.54ms +step:673/1670 train_time:62282ms step_avg:92.54ms +step:674/1670 train_time:62375ms step_avg:92.55ms +step:675/1670 train_time:62467ms step_avg:92.54ms +step:676/1670 train_time:62559ms step_avg:92.54ms +step:677/1670 train_time:62652ms step_avg:92.54ms +step:678/1670 train_time:62744ms step_avg:92.54ms +step:679/1670 train_time:62835ms step_avg:92.54ms +step:680/1670 train_time:62928ms step_avg:92.54ms +step:681/1670 train_time:63021ms step_avg:92.54ms +step:682/1670 train_time:63114ms step_avg:92.54ms +step:683/1670 train_time:63206ms step_avg:92.54ms +step:684/1670 train_time:63302ms step_avg:92.55ms +step:685/1670 train_time:63394ms step_avg:92.55ms +step:686/1670 train_time:63486ms step_avg:92.55ms +step:687/1670 train_time:63579ms step_avg:92.55ms +step:688/1670 train_time:63671ms step_avg:92.55ms +step:689/1670 train_time:63764ms step_avg:92.55ms +step:690/1670 train_time:63856ms step_avg:92.55ms +step:691/1670 train_time:63948ms step_avg:92.54ms +step:692/1670 train_time:64041ms step_avg:92.55ms +step:693/1670 train_time:64134ms step_avg:92.55ms +step:694/1670 train_time:64227ms step_avg:92.55ms +step:695/1670 train_time:64321ms step_avg:92.55ms +step:696/1670 train_time:64414ms step_avg:92.55ms +step:697/1670 train_time:64506ms step_avg:92.55ms +step:698/1670 train_time:64598ms step_avg:92.55ms +step:699/1670 train_time:64689ms step_avg:92.55ms +step:700/1670 train_time:64782ms step_avg:92.55ms +step:701/1670 train_time:64875ms step_avg:92.55ms +step:702/1670 train_time:64966ms step_avg:92.54ms +step:703/1670 train_time:65060ms step_avg:92.55ms +step:704/1670 train_time:65152ms step_avg:92.55ms +step:705/1670 train_time:65245ms step_avg:92.55ms +step:706/1670 train_time:65337ms step_avg:92.55ms +step:707/1670 train_time:65430ms step_avg:92.55ms +step:708/1670 train_time:65523ms step_avg:92.55ms +step:709/1670 train_time:65617ms step_avg:92.55ms +step:710/1670 train_time:65708ms step_avg:92.55ms +step:711/1670 train_time:65802ms step_avg:92.55ms +step:712/1670 train_time:65895ms step_avg:92.55ms +step:713/1670 train_time:65987ms step_avg:92.55ms +step:714/1670 train_time:66080ms step_avg:92.55ms +step:715/1670 train_time:66172ms 
step_avg:92.55ms +step:716/1670 train_time:66265ms step_avg:92.55ms +step:717/1670 train_time:66357ms step_avg:92.55ms +step:718/1670 train_time:66450ms step_avg:92.55ms +step:719/1670 train_time:66543ms step_avg:92.55ms +step:720/1670 train_time:66635ms step_avg:92.55ms +step:721/1670 train_time:66727ms step_avg:92.55ms +step:722/1670 train_time:66821ms step_avg:92.55ms +step:723/1670 train_time:66913ms step_avg:92.55ms +step:724/1670 train_time:67005ms step_avg:92.55ms +step:725/1670 train_time:67097ms step_avg:92.55ms +step:726/1670 train_time:67189ms step_avg:92.55ms +step:727/1670 train_time:67283ms step_avg:92.55ms +step:728/1670 train_time:67375ms step_avg:92.55ms +step:729/1670 train_time:67467ms step_avg:92.55ms +step:730/1670 train_time:67561ms step_avg:92.55ms +step:731/1670 train_time:67653ms step_avg:92.55ms +step:732/1670 train_time:67746ms step_avg:92.55ms +step:733/1670 train_time:67838ms step_avg:92.55ms +step:734/1670 train_time:67931ms step_avg:92.55ms +step:735/1670 train_time:68024ms step_avg:92.55ms +step:736/1670 train_time:68117ms step_avg:92.55ms +step:737/1670 train_time:68209ms step_avg:92.55ms +step:738/1670 train_time:68302ms step_avg:92.55ms +step:739/1670 train_time:68394ms step_avg:92.55ms +step:740/1670 train_time:68486ms step_avg:92.55ms +step:741/1670 train_time:68580ms step_avg:92.55ms +step:742/1670 train_time:68672ms step_avg:92.55ms +step:743/1670 train_time:68764ms step_avg:92.55ms +step:744/1670 train_time:68857ms step_avg:92.55ms +step:745/1670 train_time:68949ms step_avg:92.55ms +step:746/1670 train_time:69042ms step_avg:92.55ms +step:747/1670 train_time:69135ms step_avg:92.55ms +step:748/1670 train_time:69227ms step_avg:92.55ms +step:749/1670 train_time:69320ms step_avg:92.55ms +step:750/1670 train_time:69412ms step_avg:92.55ms +step:750/1670 val_loss:3.5629 train_time:69505ms step_avg:92.67ms +step:751/1670 train_time:69525ms step_avg:92.58ms +step:752/1670 train_time:69600ms step_avg:92.55ms +step:753/1670 train_time:69693ms step_avg:92.55ms +step:754/1670 train_time:69784ms step_avg:92.55ms +step:755/1670 train_time:69875ms step_avg:92.55ms +step:756/1670 train_time:69966ms step_avg:92.55ms +step:757/1670 train_time:70058ms step_avg:92.55ms +step:758/1670 train_time:70151ms step_avg:92.55ms +step:759/1670 train_time:70243ms step_avg:92.55ms +step:760/1670 train_time:70336ms step_avg:92.55ms +step:761/1670 train_time:70428ms step_avg:92.55ms +step:762/1670 train_time:70524ms step_avg:92.55ms +step:763/1670 train_time:70617ms step_avg:92.55ms +step:764/1670 train_time:70710ms step_avg:92.55ms +step:765/1670 train_time:70803ms step_avg:92.55ms +step:766/1670 train_time:70896ms step_avg:92.55ms +step:767/1670 train_time:70987ms step_avg:92.55ms +step:768/1670 train_time:71079ms step_avg:92.55ms +step:769/1670 train_time:71171ms step_avg:92.55ms +step:770/1670 train_time:71263ms step_avg:92.55ms +step:771/1670 train_time:71356ms step_avg:92.55ms +step:772/1670 train_time:71448ms step_avg:92.55ms +step:773/1670 train_time:71541ms step_avg:92.55ms +step:774/1670 train_time:71634ms step_avg:92.55ms +step:775/1670 train_time:71726ms step_avg:92.55ms +step:776/1670 train_time:71821ms step_avg:92.55ms +step:777/1670 train_time:71913ms step_avg:92.55ms +step:778/1670 train_time:72005ms step_avg:92.55ms +step:779/1670 train_time:72098ms step_avg:92.55ms +step:780/1670 train_time:72189ms step_avg:92.55ms +step:781/1670 train_time:72282ms step_avg:92.55ms +step:782/1670 train_time:72375ms step_avg:92.55ms +step:783/1670 train_time:72468ms step_avg:92.55ms 
+step:784/1670 train_time:72561ms step_avg:92.55ms +step:785/1670 train_time:72654ms step_avg:92.55ms +step:786/1670 train_time:72746ms step_avg:92.55ms +step:787/1670 train_time:72841ms step_avg:92.56ms +step:788/1670 train_time:72934ms step_avg:92.56ms +step:789/1670 train_time:73025ms step_avg:92.55ms +step:790/1670 train_time:73118ms step_avg:92.55ms +step:791/1670 train_time:73211ms step_avg:92.56ms +step:792/1670 train_time:73304ms step_avg:92.56ms +step:793/1670 train_time:73396ms step_avg:92.55ms +step:794/1670 train_time:73489ms step_avg:92.55ms +step:795/1670 train_time:73581ms step_avg:92.55ms +step:796/1670 train_time:73674ms step_avg:92.55ms +step:797/1670 train_time:73766ms step_avg:92.55ms +step:798/1670 train_time:73858ms step_avg:92.55ms +step:799/1670 train_time:73952ms step_avg:92.56ms +step:800/1670 train_time:74044ms step_avg:92.55ms +step:801/1670 train_time:74136ms step_avg:92.55ms +step:802/1670 train_time:74228ms step_avg:92.55ms +step:803/1670 train_time:74322ms step_avg:92.56ms +step:804/1670 train_time:74415ms step_avg:92.56ms +step:805/1670 train_time:74507ms step_avg:92.56ms +step:806/1670 train_time:74600ms step_avg:92.56ms +step:807/1670 train_time:74693ms step_avg:92.56ms +step:808/1670 train_time:74786ms step_avg:92.56ms +step:809/1670 train_time:74878ms step_avg:92.56ms +step:810/1670 train_time:74971ms step_avg:92.56ms +step:811/1670 train_time:75063ms step_avg:92.56ms +step:812/1670 train_time:75155ms step_avg:92.56ms +step:813/1670 train_time:75247ms step_avg:92.55ms +step:814/1670 train_time:75340ms step_avg:92.56ms +step:815/1670 train_time:75432ms step_avg:92.55ms +step:816/1670 train_time:75524ms step_avg:92.55ms +step:817/1670 train_time:75617ms step_avg:92.55ms +step:818/1670 train_time:75710ms step_avg:92.55ms +step:819/1670 train_time:75803ms step_avg:92.56ms +step:820/1670 train_time:75896ms step_avg:92.56ms +step:821/1670 train_time:75988ms step_avg:92.55ms +step:822/1670 train_time:76080ms step_avg:92.56ms +step:823/1670 train_time:76173ms step_avg:92.56ms +step:824/1670 train_time:76265ms step_avg:92.55ms +step:825/1670 train_time:76357ms step_avg:92.55ms +step:826/1670 train_time:76449ms step_avg:92.55ms +step:827/1670 train_time:76542ms step_avg:92.55ms +step:828/1670 train_time:76633ms step_avg:92.55ms +step:829/1670 train_time:76726ms step_avg:92.55ms +step:830/1670 train_time:76819ms step_avg:92.55ms +step:831/1670 train_time:76911ms step_avg:92.55ms +step:832/1670 train_time:77003ms step_avg:92.55ms +step:833/1670 train_time:77096ms step_avg:92.55ms +step:834/1670 train_time:77189ms step_avg:92.55ms +step:835/1670 train_time:77280ms step_avg:92.55ms +step:836/1670 train_time:77373ms step_avg:92.55ms +step:837/1670 train_time:77465ms step_avg:92.55ms +step:838/1670 train_time:77559ms step_avg:92.55ms +step:839/1670 train_time:77652ms step_avg:92.55ms +step:840/1670 train_time:77744ms step_avg:92.55ms +step:841/1670 train_time:77837ms step_avg:92.55ms +step:842/1670 train_time:77929ms step_avg:92.55ms +step:843/1670 train_time:78023ms step_avg:92.55ms +step:844/1670 train_time:78116ms step_avg:92.55ms +step:845/1670 train_time:78207ms step_avg:92.55ms +step:846/1670 train_time:78300ms step_avg:92.55ms +step:847/1670 train_time:78393ms step_avg:92.55ms +step:848/1670 train_time:78486ms step_avg:92.55ms +step:849/1670 train_time:78579ms step_avg:92.56ms +step:850/1670 train_time:78672ms step_avg:92.56ms +step:851/1670 train_time:78921ms step_avg:92.74ms +step:852/1670 train_time:78992ms step_avg:92.71ms +step:853/1670 train_time:79082ms 
step_avg:92.71ms +step:854/1670 train_time:79173ms step_avg:92.71ms +step:855/1670 train_time:79263ms step_avg:92.71ms +step:856/1670 train_time:79354ms step_avg:92.70ms +step:857/1670 train_time:79445ms step_avg:92.70ms +step:858/1670 train_time:79537ms step_avg:92.70ms +step:859/1670 train_time:79628ms step_avg:92.70ms +step:860/1670 train_time:79720ms step_avg:92.70ms +step:861/1670 train_time:79815ms step_avg:92.70ms +step:862/1670 train_time:79912ms step_avg:92.71ms +step:863/1670 train_time:80007ms step_avg:92.71ms +step:864/1670 train_time:80100ms step_avg:92.71ms +step:865/1670 train_time:80192ms step_avg:92.71ms +step:866/1670 train_time:80284ms step_avg:92.71ms +step:867/1670 train_time:80375ms step_avg:92.71ms +step:868/1670 train_time:80466ms step_avg:92.70ms +step:869/1670 train_time:80558ms step_avg:92.70ms +step:870/1670 train_time:80649ms step_avg:92.70ms +step:871/1670 train_time:80742ms step_avg:92.70ms +step:872/1670 train_time:80837ms step_avg:92.70ms +step:873/1670 train_time:80932ms step_avg:92.71ms +step:874/1670 train_time:81026ms step_avg:92.71ms +step:875/1670 train_time:81119ms step_avg:92.71ms +step:875/1670 val_loss:3.5185 train_time:81211ms step_avg:92.81ms +step:876/1670 train_time:81231ms step_avg:92.73ms +step:877/1670 train_time:81306ms step_avg:92.71ms +step:878/1670 train_time:81399ms step_avg:92.71ms +step:879/1670 train_time:81490ms step_avg:92.71ms +step:880/1670 train_time:81582ms step_avg:92.71ms +step:881/1670 train_time:81673ms step_avg:92.71ms +step:882/1670 train_time:81764ms step_avg:92.70ms +step:883/1670 train_time:81857ms step_avg:92.70ms +step:884/1670 train_time:81949ms step_avg:92.70ms +step:885/1670 train_time:82042ms step_avg:92.70ms +step:886/1670 train_time:82135ms step_avg:92.70ms +step:887/1670 train_time:82230ms step_avg:92.71ms +step:888/1670 train_time:82323ms step_avg:92.71ms +step:889/1670 train_time:82416ms step_avg:92.71ms +step:890/1670 train_time:82508ms step_avg:92.71ms +step:891/1670 train_time:82601ms step_avg:92.71ms +step:892/1670 train_time:82693ms step_avg:92.71ms +step:893/1670 train_time:82784ms step_avg:92.70ms +step:894/1670 train_time:82877ms step_avg:92.70ms +step:895/1670 train_time:82969ms step_avg:92.70ms +step:896/1670 train_time:83063ms step_avg:92.70ms +step:897/1670 train_time:83157ms step_avg:92.71ms +step:898/1670 train_time:83250ms step_avg:92.71ms +step:899/1670 train_time:83342ms step_avg:92.71ms +step:900/1670 train_time:83435ms step_avg:92.71ms +step:901/1670 train_time:83528ms step_avg:92.71ms +step:902/1670 train_time:83620ms step_avg:92.70ms +step:903/1670 train_time:83712ms step_avg:92.70ms +step:904/1670 train_time:83805ms step_avg:92.70ms +step:905/1670 train_time:83897ms step_avg:92.70ms +step:906/1670 train_time:83989ms step_avg:92.70ms +step:907/1670 train_time:84084ms step_avg:92.71ms +step:908/1670 train_time:84177ms step_avg:92.71ms +step:909/1670 train_time:84269ms step_avg:92.71ms +step:910/1670 train_time:84363ms step_avg:92.71ms +step:911/1670 train_time:84456ms step_avg:92.71ms +step:912/1670 train_time:84548ms step_avg:92.71ms +step:913/1670 train_time:84640ms step_avg:92.71ms +step:914/1670 train_time:84733ms step_avg:92.71ms +step:915/1670 train_time:84825ms step_avg:92.70ms +step:916/1670 train_time:84916ms step_avg:92.70ms +step:917/1670 train_time:85009ms step_avg:92.70ms +step:918/1670 train_time:85102ms step_avg:92.70ms +step:919/1670 train_time:85194ms step_avg:92.70ms +step:920/1670 train_time:85287ms step_avg:92.70ms +step:921/1670 train_time:85379ms step_avg:92.70ms 
+step:922/1670 train_time:85471ms step_avg:92.70ms +step:923/1670 train_time:85563ms step_avg:92.70ms +step:924/1670 train_time:85656ms step_avg:92.70ms +step:925/1670 train_time:85748ms step_avg:92.70ms +step:926/1670 train_time:85842ms step_avg:92.70ms +step:927/1670 train_time:85934ms step_avg:92.70ms +step:928/1670 train_time:86026ms step_avg:92.70ms +step:929/1670 train_time:86119ms step_avg:92.70ms +step:930/1670 train_time:86212ms step_avg:92.70ms +step:931/1670 train_time:86304ms step_avg:92.70ms +step:932/1670 train_time:86397ms step_avg:92.70ms +step:933/1670 train_time:86489ms step_avg:92.70ms +step:934/1670 train_time:86582ms step_avg:92.70ms +step:935/1670 train_time:86675ms step_avg:92.70ms +step:936/1670 train_time:86767ms step_avg:92.70ms +step:937/1670 train_time:86860ms step_avg:92.70ms +step:938/1670 train_time:86952ms step_avg:92.70ms +step:939/1670 train_time:87045ms step_avg:92.70ms +step:940/1670 train_time:87138ms step_avg:92.70ms +step:941/1670 train_time:87230ms step_avg:92.70ms +step:942/1670 train_time:87323ms step_avg:92.70ms +step:943/1670 train_time:87416ms step_avg:92.70ms +step:944/1670 train_time:87507ms step_avg:92.70ms +step:945/1670 train_time:87600ms step_avg:92.70ms +step:946/1670 train_time:87692ms step_avg:92.70ms +step:947/1670 train_time:87784ms step_avg:92.70ms +step:948/1670 train_time:87878ms step_avg:92.70ms +step:949/1670 train_time:87970ms step_avg:92.70ms +step:950/1670 train_time:88063ms step_avg:92.70ms +step:951/1670 train_time:88156ms step_avg:92.70ms +step:952/1670 train_time:88248ms step_avg:92.70ms +step:953/1670 train_time:88340ms step_avg:92.70ms +step:954/1670 train_time:88433ms step_avg:92.70ms +step:955/1670 train_time:88525ms step_avg:92.70ms +step:956/1670 train_time:88617ms step_avg:92.70ms +step:957/1670 train_time:88709ms step_avg:92.69ms +step:958/1670 train_time:88802ms step_avg:92.69ms +step:959/1670 train_time:88894ms step_avg:92.69ms +step:960/1670 train_time:88986ms step_avg:92.69ms +step:961/1670 train_time:89080ms step_avg:92.69ms +step:962/1670 train_time:89172ms step_avg:92.69ms +step:963/1670 train_time:89265ms step_avg:92.69ms +step:964/1670 train_time:89358ms step_avg:92.70ms +step:965/1670 train_time:89450ms step_avg:92.69ms +step:966/1670 train_time:89543ms step_avg:92.69ms +step:967/1670 train_time:89635ms step_avg:92.69ms +step:968/1670 train_time:89727ms step_avg:92.69ms +step:969/1670 train_time:89819ms step_avg:92.69ms +step:970/1670 train_time:89912ms step_avg:92.69ms +step:971/1670 train_time:90004ms step_avg:92.69ms +step:972/1670 train_time:90096ms step_avg:92.69ms +step:973/1670 train_time:90188ms step_avg:92.69ms +step:974/1670 train_time:90282ms step_avg:92.69ms +step:975/1670 train_time:90375ms step_avg:92.69ms +step:976/1670 train_time:90467ms step_avg:92.69ms +step:977/1670 train_time:90559ms step_avg:92.69ms +step:978/1670 train_time:90651ms step_avg:92.69ms +step:979/1670 train_time:90743ms step_avg:92.69ms +step:980/1670 train_time:90836ms step_avg:92.69ms +step:981/1670 train_time:90928ms step_avg:92.69ms +step:982/1670 train_time:91020ms step_avg:92.69ms +step:983/1670 train_time:91114ms step_avg:92.69ms +step:984/1670 train_time:91206ms step_avg:92.69ms +step:985/1670 train_time:91298ms step_avg:92.69ms +step:986/1670 train_time:91389ms step_avg:92.69ms +step:987/1670 train_time:91483ms step_avg:92.69ms +step:988/1670 train_time:91576ms step_avg:92.69ms +step:989/1670 train_time:91667ms step_avg:92.69ms +step:990/1670 train_time:91759ms step_avg:92.69ms +step:991/1670 train_time:91852ms 
step_avg:92.69ms +step:992/1670 train_time:91944ms step_avg:92.69ms +step:993/1670 train_time:92037ms step_avg:92.69ms +step:994/1670 train_time:92129ms step_avg:92.69ms +step:995/1670 train_time:92222ms step_avg:92.69ms +step:996/1670 train_time:92314ms step_avg:92.68ms +step:997/1670 train_time:92406ms step_avg:92.68ms +step:998/1670 train_time:92499ms step_avg:92.68ms +step:999/1670 train_time:92591ms step_avg:92.68ms +step:1000/1670 train_time:92684ms step_avg:92.68ms +step:1000/1670 val_loss:3.4675 train_time:92777ms step_avg:92.78ms +step:1001/1670 train_time:92797ms step_avg:92.70ms +step:1002/1670 train_time:92871ms step_avg:92.69ms +step:1003/1670 train_time:92963ms step_avg:92.69ms +step:1004/1670 train_time:93055ms step_avg:92.68ms +step:1005/1670 train_time:93146ms step_avg:92.68ms +step:1006/1670 train_time:93238ms step_avg:92.68ms +step:1007/1670 train_time:93330ms step_avg:92.68ms +step:1008/1670 train_time:93422ms step_avg:92.68ms +step:1009/1670 train_time:93514ms step_avg:92.68ms +step:1010/1670 train_time:93607ms step_avg:92.68ms +step:1011/1670 train_time:93701ms step_avg:92.68ms +step:1012/1670 train_time:93795ms step_avg:92.68ms +step:1013/1670 train_time:93889ms step_avg:92.68ms +step:1014/1670 train_time:93981ms step_avg:92.68ms +step:1015/1670 train_time:94075ms step_avg:92.68ms +step:1016/1670 train_time:94166ms step_avg:92.68ms +step:1017/1670 train_time:94258ms step_avg:92.68ms +step:1018/1670 train_time:94350ms step_avg:92.68ms +step:1019/1670 train_time:94442ms step_avg:92.68ms +step:1020/1670 train_time:94534ms step_avg:92.68ms +step:1021/1670 train_time:94625ms step_avg:92.68ms +step:1022/1670 train_time:94720ms step_avg:92.68ms +step:1023/1670 train_time:94814ms step_avg:92.68ms +step:1024/1670 train_time:94907ms step_avg:92.68ms +step:1025/1670 train_time:95001ms step_avg:92.68ms +step:1026/1670 train_time:95093ms step_avg:92.68ms +step:1027/1670 train_time:95185ms step_avg:92.68ms +step:1028/1670 train_time:95278ms step_avg:92.68ms +step:1029/1670 train_time:95370ms step_avg:92.68ms +step:1030/1670 train_time:95462ms step_avg:92.68ms +step:1031/1670 train_time:95554ms step_avg:92.68ms +step:1032/1670 train_time:95646ms step_avg:92.68ms +step:1033/1670 train_time:95739ms step_avg:92.68ms +step:1034/1670 train_time:95832ms step_avg:92.68ms +step:1035/1670 train_time:95925ms step_avg:92.68ms +step:1036/1670 train_time:96019ms step_avg:92.68ms +step:1037/1670 train_time:96112ms step_avg:92.68ms +step:1038/1670 train_time:96205ms step_avg:92.68ms +step:1039/1670 train_time:96297ms step_avg:92.68ms +step:1040/1670 train_time:96389ms step_avg:92.68ms +step:1041/1670 train_time:96481ms step_avg:92.68ms +step:1042/1670 train_time:96573ms step_avg:92.68ms +step:1043/1670 train_time:96665ms step_avg:92.68ms +step:1044/1670 train_time:96758ms step_avg:92.68ms +step:1045/1670 train_time:96850ms step_avg:92.68ms +step:1046/1670 train_time:96944ms step_avg:92.68ms +step:1047/1670 train_time:97037ms step_avg:92.68ms +step:1048/1670 train_time:97130ms step_avg:92.68ms +step:1049/1670 train_time:97223ms step_avg:92.68ms +step:1050/1670 train_time:97315ms step_avg:92.68ms +step:1051/1670 train_time:97407ms step_avg:92.68ms +step:1052/1670 train_time:97499ms step_avg:92.68ms +step:1053/1670 train_time:97591ms step_avg:92.68ms +step:1054/1670 train_time:97683ms step_avg:92.68ms +step:1055/1670 train_time:97776ms step_avg:92.68ms +step:1056/1670 train_time:97869ms step_avg:92.68ms +step:1057/1670 train_time:97962ms step_avg:92.68ms +step:1058/1670 train_time:98055ms 
step_avg:92.68ms +step:1059/1670 train_time:98148ms step_avg:92.68ms +step:1060/1670 train_time:98242ms step_avg:92.68ms +step:1061/1670 train_time:98336ms step_avg:92.68ms +step:1062/1670 train_time:98584ms step_avg:92.83ms +step:1063/1670 train_time:98653ms step_avg:92.81ms +step:1064/1670 train_time:98743ms step_avg:92.80ms +step:1065/1670 train_time:98835ms step_avg:92.80ms +step:1066/1670 train_time:98925ms step_avg:92.80ms +step:1067/1670 train_time:99017ms step_avg:92.80ms +step:1068/1670 train_time:99108ms step_avg:92.80ms +step:1069/1670 train_time:99199ms step_avg:92.80ms +step:1070/1670 train_time:99291ms step_avg:92.80ms +step:1071/1670 train_time:99382ms step_avg:92.79ms +step:1072/1670 train_time:99479ms step_avg:92.80ms +step:1073/1670 train_time:99575ms step_avg:92.80ms +step:1074/1670 train_time:99669ms step_avg:92.80ms +step:1075/1670 train_time:99762ms step_avg:92.80ms +step:1076/1670 train_time:99854ms step_avg:92.80ms +step:1077/1670 train_time:99944ms step_avg:92.80ms +step:1078/1670 train_time:100036ms step_avg:92.80ms +step:1079/1670 train_time:100127ms step_avg:92.80ms +step:1080/1670 train_time:100220ms step_avg:92.80ms +step:1081/1670 train_time:100311ms step_avg:92.79ms +step:1082/1670 train_time:100403ms step_avg:92.79ms +step:1083/1670 train_time:100500ms step_avg:92.80ms +step:1084/1670 train_time:100594ms step_avg:92.80ms +step:1085/1670 train_time:100687ms step_avg:92.80ms +step:1086/1670 train_time:100780ms step_avg:92.80ms +step:1087/1670 train_time:100873ms step_avg:92.80ms +step:1088/1670 train_time:100965ms step_avg:92.80ms +step:1089/1670 train_time:101056ms step_avg:92.80ms +step:1090/1670 train_time:101147ms step_avg:92.80ms +step:1091/1670 train_time:101239ms step_avg:92.79ms +step:1092/1670 train_time:101330ms step_avg:92.79ms +step:1093/1670 train_time:101424ms step_avg:92.79ms +step:1094/1670 train_time:101518ms step_avg:92.80ms +step:1095/1670 train_time:101611ms step_avg:92.80ms +step:1096/1670 train_time:101705ms step_avg:92.80ms +step:1097/1670 train_time:101798ms step_avg:92.80ms +step:1098/1670 train_time:101890ms step_avg:92.80ms +step:1099/1670 train_time:101982ms step_avg:92.80ms +step:1100/1670 train_time:102073ms step_avg:92.79ms +step:1101/1670 train_time:102165ms step_avg:92.79ms +step:1102/1670 train_time:102256ms step_avg:92.79ms +step:1103/1670 train_time:102348ms step_avg:92.79ms +step:1104/1670 train_time:102443ms step_avg:92.79ms +step:1105/1670 train_time:102537ms step_avg:92.79ms +step:1106/1670 train_time:102628ms step_avg:92.79ms +step:1107/1670 train_time:102723ms step_avg:92.79ms +step:1108/1670 train_time:102816ms step_avg:92.79ms +step:1109/1670 train_time:102909ms step_avg:92.79ms +step:1110/1670 train_time:103002ms step_avg:92.79ms +step:1111/1670 train_time:103095ms step_avg:92.79ms +step:1112/1670 train_time:103186ms step_avg:92.79ms +step:1113/1670 train_time:103277ms step_avg:92.79ms +step:1114/1670 train_time:103370ms step_avg:92.79ms +step:1115/1670 train_time:103652ms step_avg:92.96ms +step:1116/1670 train_time:103725ms step_avg:92.94ms +step:1117/1670 train_time:103816ms step_avg:92.94ms +step:1118/1670 train_time:103908ms step_avg:92.94ms +step:1119/1670 train_time:104000ms step_avg:92.94ms +step:1120/1670 train_time:104092ms step_avg:92.94ms +step:1121/1670 train_time:104183ms step_avg:92.94ms +step:1122/1670 train_time:104275ms step_avg:92.94ms +step:1123/1670 train_time:104367ms step_avg:92.94ms +step:1124/1670 train_time:104459ms step_avg:92.93ms +step:1125/1670 train_time:104555ms step_avg:92.94ms 
+step:1125/1670 val_loss:3.4152 train_time:104654ms step_avg:93.03ms +step:1126/1670 train_time:104673ms step_avg:92.96ms +step:1127/1670 train_time:104751ms step_avg:92.95ms +step:1128/1670 train_time:104852ms step_avg:92.95ms +step:1129/1670 train_time:104948ms step_avg:92.96ms +step:1130/1670 train_time:105041ms step_avg:92.96ms +step:1131/1670 train_time:105132ms step_avg:92.95ms +step:1132/1670 train_time:105224ms step_avg:92.95ms +step:1133/1670 train_time:105316ms step_avg:92.95ms +step:1134/1670 train_time:105408ms step_avg:92.95ms +step:1135/1670 train_time:105500ms step_avg:92.95ms +step:1136/1670 train_time:105592ms step_avg:92.95ms +step:1137/1670 train_time:105688ms step_avg:92.95ms +step:1138/1670 train_time:105784ms step_avg:92.96ms +step:1139/1670 train_time:105880ms step_avg:92.96ms +step:1140/1670 train_time:105973ms step_avg:92.96ms +step:1141/1670 train_time:106066ms step_avg:92.96ms +step:1142/1670 train_time:106159ms step_avg:92.96ms +step:1143/1670 train_time:106251ms step_avg:92.96ms +step:1144/1670 train_time:106343ms step_avg:92.96ms +step:1145/1670 train_time:106435ms step_avg:92.96ms +step:1146/1670 train_time:106528ms step_avg:92.96ms +step:1147/1670 train_time:106621ms step_avg:92.96ms +step:1148/1670 train_time:106715ms step_avg:92.96ms +step:1149/1670 train_time:106811ms step_avg:92.96ms +step:1150/1670 train_time:106906ms step_avg:92.96ms +step:1151/1670 train_time:107000ms step_avg:92.96ms +step:1152/1670 train_time:107092ms step_avg:92.96ms +step:1153/1670 train_time:107185ms step_avg:92.96ms +step:1154/1670 train_time:107277ms step_avg:92.96ms +step:1155/1670 train_time:107370ms step_avg:92.96ms +step:1156/1670 train_time:107462ms step_avg:92.96ms +step:1157/1670 train_time:107554ms step_avg:92.96ms +step:1158/1670 train_time:107647ms step_avg:92.96ms +step:1159/1670 train_time:107741ms step_avg:92.96ms +step:1160/1670 train_time:107836ms step_avg:92.96ms +step:1161/1670 train_time:107930ms step_avg:92.96ms +step:1162/1670 train_time:108025ms step_avg:92.96ms +step:1163/1670 train_time:108119ms step_avg:92.97ms +step:1164/1670 train_time:108211ms step_avg:92.96ms +step:1165/1670 train_time:108303ms step_avg:92.96ms +step:1166/1670 train_time:108395ms step_avg:92.96ms +step:1167/1670 train_time:108488ms step_avg:92.96ms +step:1168/1670 train_time:108582ms step_avg:92.96ms +step:1169/1670 train_time:108674ms step_avg:92.96ms +step:1170/1670 train_time:108768ms step_avg:92.96ms +step:1171/1670 train_time:108861ms step_avg:92.96ms +step:1172/1670 train_time:108954ms step_avg:92.96ms +step:1173/1670 train_time:109048ms step_avg:92.96ms +step:1174/1670 train_time:109142ms step_avg:92.97ms +step:1175/1670 train_time:109235ms step_avg:92.97ms +step:1176/1670 train_time:109329ms step_avg:92.97ms +step:1177/1670 train_time:109421ms step_avg:92.97ms +step:1178/1670 train_time:109513ms step_avg:92.97ms +step:1179/1670 train_time:109607ms step_avg:92.97ms +step:1180/1670 train_time:109701ms step_avg:92.97ms +step:1181/1670 train_time:109794ms step_avg:92.97ms +step:1182/1670 train_time:109888ms step_avg:92.97ms +step:1183/1670 train_time:109982ms step_avg:92.97ms +step:1184/1670 train_time:110074ms step_avg:92.97ms +step:1185/1670 train_time:110168ms step_avg:92.97ms +step:1186/1670 train_time:110261ms step_avg:92.97ms +step:1187/1670 train_time:110353ms step_avg:92.97ms +step:1188/1670 train_time:110446ms step_avg:92.97ms +step:1189/1670 train_time:110539ms step_avg:92.97ms +step:1190/1670 train_time:110632ms step_avg:92.97ms +step:1191/1670 train_time:110727ms 
step_avg:92.97ms +step:1192/1670 train_time:110819ms step_avg:92.97ms +step:1193/1670 train_time:110913ms step_avg:92.97ms +step:1194/1670 train_time:111008ms step_avg:92.97ms +step:1195/1670 train_time:111102ms step_avg:92.97ms +step:1196/1670 train_time:111194ms step_avg:92.97ms +step:1197/1670 train_time:111288ms step_avg:92.97ms +step:1198/1670 train_time:111381ms step_avg:92.97ms +step:1199/1670 train_time:111473ms step_avg:92.97ms +step:1200/1670 train_time:111567ms step_avg:92.97ms +step:1201/1670 train_time:111659ms step_avg:92.97ms +step:1202/1670 train_time:111752ms step_avg:92.97ms +step:1203/1670 train_time:111846ms step_avg:92.97ms +step:1204/1670 train_time:111939ms step_avg:92.97ms +step:1205/1670 train_time:112032ms step_avg:92.97ms +step:1206/1670 train_time:112126ms step_avg:92.97ms +step:1207/1670 train_time:112219ms step_avg:92.97ms +step:1208/1670 train_time:112312ms step_avg:92.97ms +step:1209/1670 train_time:112406ms step_avg:92.97ms +step:1210/1670 train_time:112500ms step_avg:92.98ms +step:1211/1670 train_time:112593ms step_avg:92.97ms +step:1212/1670 train_time:112686ms step_avg:92.98ms +step:1213/1670 train_time:112778ms step_avg:92.97ms +step:1214/1670 train_time:112871ms step_avg:92.97ms +step:1215/1670 train_time:112964ms step_avg:92.97ms +step:1216/1670 train_time:113057ms step_avg:92.97ms +step:1217/1670 train_time:113150ms step_avg:92.97ms +step:1218/1670 train_time:113243ms step_avg:92.97ms +step:1219/1670 train_time:113336ms step_avg:92.97ms +step:1220/1670 train_time:113429ms step_avg:92.97ms +step:1221/1670 train_time:113522ms step_avg:92.97ms +step:1222/1670 train_time:113614ms step_avg:92.97ms +step:1223/1670 train_time:113708ms step_avg:92.97ms +step:1224/1670 train_time:113801ms step_avg:92.97ms +step:1225/1670 train_time:113893ms step_avg:92.97ms +step:1226/1670 train_time:113988ms step_avg:92.98ms +step:1227/1670 train_time:114081ms step_avg:92.98ms +step:1228/1670 train_time:114174ms step_avg:92.98ms +step:1229/1670 train_time:114269ms step_avg:92.98ms +step:1230/1670 train_time:114361ms step_avg:92.98ms +step:1231/1670 train_time:114454ms step_avg:92.98ms +step:1232/1670 train_time:114548ms step_avg:92.98ms +step:1233/1670 train_time:114640ms step_avg:92.98ms +step:1234/1670 train_time:114733ms step_avg:92.98ms +step:1235/1670 train_time:114826ms step_avg:92.98ms +step:1236/1670 train_time:114919ms step_avg:92.98ms +step:1237/1670 train_time:115012ms step_avg:92.98ms +step:1238/1670 train_time:115106ms step_avg:92.98ms +step:1239/1670 train_time:115199ms step_avg:92.98ms +step:1240/1670 train_time:115292ms step_avg:92.98ms +step:1241/1670 train_time:115386ms step_avg:92.98ms +step:1242/1670 train_time:115479ms step_avg:92.98ms +step:1243/1670 train_time:115573ms step_avg:92.98ms +step:1244/1670 train_time:115668ms step_avg:92.98ms +step:1245/1670 train_time:115761ms step_avg:92.98ms +step:1246/1670 train_time:115853ms step_avg:92.98ms +step:1247/1670 train_time:115947ms step_avg:92.98ms +step:1248/1670 train_time:116040ms step_avg:92.98ms +step:1249/1670 train_time:116132ms step_avg:92.98ms +step:1250/1670 train_time:116225ms step_avg:92.98ms +step:1250/1670 val_loss:3.3763 train_time:116317ms step_avg:93.05ms +step:1251/1670 train_time:116337ms step_avg:93.00ms +step:1252/1670 train_time:116411ms step_avg:92.98ms +step:1253/1670 train_time:116503ms step_avg:92.98ms +step:1254/1670 train_time:116596ms step_avg:92.98ms +step:1255/1670 train_time:116690ms step_avg:92.98ms +step:1256/1670 train_time:116783ms step_avg:92.98ms +step:1257/1670 
train_time:116875ms step_avg:92.98ms +step:1258/1670 train_time:116968ms step_avg:92.98ms +step:1259/1670 train_time:117060ms step_avg:92.98ms +step:1260/1670 train_time:117153ms step_avg:92.98ms +step:1261/1670 train_time:117247ms step_avg:92.98ms +step:1262/1670 train_time:117341ms step_avg:92.98ms +step:1263/1670 train_time:117435ms step_avg:92.98ms +step:1264/1670 train_time:117529ms step_avg:92.98ms +step:1265/1670 train_time:117621ms step_avg:92.98ms +step:1266/1670 train_time:117714ms step_avg:92.98ms +step:1267/1670 train_time:117809ms step_avg:92.98ms +step:1268/1670 train_time:117901ms step_avg:92.98ms +step:1269/1670 train_time:117993ms step_avg:92.98ms +step:1270/1670 train_time:118086ms step_avg:92.98ms +step:1271/1670 train_time:118179ms step_avg:92.98ms +step:1272/1670 train_time:118274ms step_avg:92.98ms +step:1273/1670 train_time:118367ms step_avg:92.98ms +step:1274/1670 train_time:118617ms step_avg:93.11ms +step:1275/1670 train_time:118686ms step_avg:93.09ms +step:1276/1670 train_time:118777ms step_avg:93.09ms +step:1277/1670 train_time:118870ms step_avg:93.09ms +step:1278/1670 train_time:118961ms step_avg:93.08ms +step:1279/1670 train_time:119052ms step_avg:93.08ms +step:1280/1670 train_time:119145ms step_avg:93.08ms +step:1281/1670 train_time:119237ms step_avg:93.08ms +step:1282/1670 train_time:119329ms step_avg:93.08ms +step:1283/1670 train_time:119420ms step_avg:93.08ms +step:1284/1670 train_time:119520ms step_avg:93.08ms +step:1285/1670 train_time:119619ms step_avg:93.09ms +step:1286/1670 train_time:119715ms step_avg:93.09ms +step:1287/1670 train_time:119807ms step_avg:93.09ms +step:1288/1670 train_time:119899ms step_avg:93.09ms +step:1289/1670 train_time:119991ms step_avg:93.09ms +step:1290/1670 train_time:120084ms step_avg:93.09ms +step:1291/1670 train_time:120176ms step_avg:93.09ms +step:1292/1670 train_time:120267ms step_avg:93.09ms +step:1293/1670 train_time:120360ms step_avg:93.09ms +step:1294/1670 train_time:120455ms step_avg:93.09ms +step:1295/1670 train_time:120551ms step_avg:93.09ms +step:1296/1670 train_time:120646ms step_avg:93.09ms +step:1297/1670 train_time:120740ms step_avg:93.09ms +step:1298/1670 train_time:120834ms step_avg:93.09ms +step:1299/1670 train_time:120927ms step_avg:93.09ms +step:1300/1670 train_time:121019ms step_avg:93.09ms +step:1301/1670 train_time:121111ms step_avg:93.09ms +step:1302/1670 train_time:121203ms step_avg:93.09ms +step:1303/1670 train_time:121295ms step_avg:93.09ms +step:1304/1670 train_time:121388ms step_avg:93.09ms +step:1305/1670 train_time:121481ms step_avg:93.09ms +step:1306/1670 train_time:121576ms step_avg:93.09ms +step:1307/1670 train_time:121670ms step_avg:93.09ms +step:1308/1670 train_time:121763ms step_avg:93.09ms +step:1309/1670 train_time:121857ms step_avg:93.09ms +step:1310/1670 train_time:121951ms step_avg:93.09ms +step:1311/1670 train_time:122043ms step_avg:93.09ms +step:1312/1670 train_time:122136ms step_avg:93.09ms +step:1313/1670 train_time:122228ms step_avg:93.09ms +step:1314/1670 train_time:122321ms step_avg:93.09ms +step:1315/1670 train_time:122415ms step_avg:93.09ms +step:1316/1670 train_time:122508ms step_avg:93.09ms +step:1317/1670 train_time:122602ms step_avg:93.09ms +step:1318/1670 train_time:122698ms step_avg:93.09ms +step:1319/1670 train_time:122792ms step_avg:93.09ms +step:1320/1670 train_time:122885ms step_avg:93.09ms +step:1321/1670 train_time:122978ms step_avg:93.09ms +step:1322/1670 train_time:123071ms step_avg:93.09ms +step:1323/1670 train_time:123164ms step_avg:93.09ms +step:1324/1670 
train_time:123256ms step_avg:93.09ms +step:1325/1670 train_time:123349ms step_avg:93.09ms +step:1326/1670 train_time:123443ms step_avg:93.09ms +step:1327/1670 train_time:123537ms step_avg:93.09ms +step:1328/1670 train_time:123630ms step_avg:93.10ms +step:1329/1670 train_time:123724ms step_avg:93.10ms +step:1330/1670 train_time:123818ms step_avg:93.10ms +step:1331/1670 train_time:123913ms step_avg:93.10ms +step:1332/1670 train_time:124005ms step_avg:93.10ms +step:1333/1670 train_time:124097ms step_avg:93.10ms +step:1334/1670 train_time:124190ms step_avg:93.10ms +step:1335/1670 train_time:124283ms step_avg:93.10ms +step:1336/1670 train_time:124376ms step_avg:93.10ms +step:1337/1670 train_time:124469ms step_avg:93.10ms +step:1338/1670 train_time:124562ms step_avg:93.10ms +step:1339/1670 train_time:124657ms step_avg:93.10ms +step:1340/1670 train_time:124751ms step_avg:93.10ms +step:1341/1670 train_time:124843ms step_avg:93.10ms +step:1342/1670 train_time:124937ms step_avg:93.10ms +step:1343/1670 train_time:125031ms step_avg:93.10ms +step:1344/1670 train_time:125123ms step_avg:93.10ms +step:1345/1670 train_time:125217ms step_avg:93.10ms +step:1346/1670 train_time:125309ms step_avg:93.10ms +step:1347/1670 train_time:125402ms step_avg:93.10ms +step:1348/1670 train_time:125496ms step_avg:93.10ms +step:1349/1670 train_time:125589ms step_avg:93.10ms +step:1350/1670 train_time:125682ms step_avg:93.10ms +step:1351/1670 train_time:125777ms step_avg:93.10ms +step:1352/1670 train_time:125870ms step_avg:93.10ms +step:1353/1670 train_time:125963ms step_avg:93.10ms +step:1354/1670 train_time:126057ms step_avg:93.10ms +step:1355/1670 train_time:126151ms step_avg:93.10ms +step:1356/1670 train_time:126243ms step_avg:93.10ms +step:1357/1670 train_time:126337ms step_avg:93.10ms +step:1358/1670 train_time:126429ms step_avg:93.10ms +step:1359/1670 train_time:126522ms step_avg:93.10ms +step:1360/1670 train_time:126617ms step_avg:93.10ms +step:1361/1670 train_time:126709ms step_avg:93.10ms +step:1362/1670 train_time:126803ms step_avg:93.10ms +step:1363/1670 train_time:126896ms step_avg:93.10ms +step:1364/1670 train_time:126990ms step_avg:93.10ms +step:1365/1670 train_time:127082ms step_avg:93.10ms +step:1366/1670 train_time:127176ms step_avg:93.10ms +step:1367/1670 train_time:127269ms step_avg:93.10ms +step:1368/1670 train_time:127361ms step_avg:93.10ms +step:1369/1670 train_time:127455ms step_avg:93.10ms +step:1370/1670 train_time:127549ms step_avg:93.10ms +step:1371/1670 train_time:127642ms step_avg:93.10ms +step:1372/1670 train_time:127737ms step_avg:93.10ms +step:1373/1670 train_time:127830ms step_avg:93.10ms +step:1374/1670 train_time:127923ms step_avg:93.10ms +step:1375/1670 train_time:128017ms step_avg:93.10ms +step:1375/1670 val_loss:3.3410 train_time:128110ms step_avg:93.17ms +step:1376/1670 train_time:128130ms step_avg:93.12ms +step:1377/1670 train_time:128207ms step_avg:93.11ms +step:1378/1670 train_time:128300ms step_avg:93.11ms +step:1379/1670 train_time:128393ms step_avg:93.11ms +step:1380/1670 train_time:128485ms step_avg:93.10ms +step:1381/1670 train_time:128578ms step_avg:93.10ms +step:1382/1670 train_time:128670ms step_avg:93.10ms +step:1383/1670 train_time:128764ms step_avg:93.11ms +step:1384/1670 train_time:128858ms step_avg:93.11ms +step:1385/1670 train_time:128950ms step_avg:93.11ms +step:1386/1670 train_time:129045ms step_avg:93.11ms +step:1387/1670 train_time:129139ms step_avg:93.11ms +step:1388/1670 train_time:129234ms step_avg:93.11ms +step:1389/1670 train_time:129327ms step_avg:93.11ms 
+step:1390/1670 train_time:129419ms step_avg:93.11ms +step:1391/1670 train_time:129511ms step_avg:93.11ms +step:1392/1670 train_time:129605ms step_avg:93.11ms +step:1393/1670 train_time:129699ms step_avg:93.11ms +step:1394/1670 train_time:129791ms step_avg:93.11ms +step:1395/1670 train_time:129886ms step_avg:93.11ms +step:1396/1670 train_time:129979ms step_avg:93.11ms +step:1397/1670 train_time:130072ms step_avg:93.11ms +step:1398/1670 train_time:130166ms step_avg:93.11ms +step:1399/1670 train_time:130260ms step_avg:93.11ms +step:1400/1670 train_time:130352ms step_avg:93.11ms +step:1401/1670 train_time:130445ms step_avg:93.11ms +step:1402/1670 train_time:130538ms step_avg:93.11ms +step:1403/1670 train_time:130631ms step_avg:93.11ms +step:1404/1670 train_time:130723ms step_avg:93.11ms +step:1405/1670 train_time:130816ms step_avg:93.11ms +step:1406/1670 train_time:130910ms step_avg:93.11ms +step:1407/1670 train_time:131004ms step_avg:93.11ms +step:1408/1670 train_time:131097ms step_avg:93.11ms +step:1409/1670 train_time:131190ms step_avg:93.11ms +step:1410/1670 train_time:131285ms step_avg:93.11ms +step:1411/1670 train_time:131378ms step_avg:93.11ms +step:1412/1670 train_time:131471ms step_avg:93.11ms +step:1413/1670 train_time:131564ms step_avg:93.11ms +step:1414/1670 train_time:131658ms step_avg:93.11ms +step:1415/1670 train_time:131750ms step_avg:93.11ms +step:1416/1670 train_time:131843ms step_avg:93.11ms +step:1417/1670 train_time:131936ms step_avg:93.11ms +step:1418/1670 train_time:132030ms step_avg:93.11ms +step:1419/1670 train_time:132123ms step_avg:93.11ms +step:1420/1670 train_time:132216ms step_avg:93.11ms +step:1421/1670 train_time:132309ms step_avg:93.11ms +step:1422/1670 train_time:132403ms step_avg:93.11ms +step:1423/1670 train_time:132497ms step_avg:93.11ms +step:1424/1670 train_time:132589ms step_avg:93.11ms +step:1425/1670 train_time:132683ms step_avg:93.11ms +step:1426/1670 train_time:132776ms step_avg:93.11ms +step:1427/1670 train_time:132870ms step_avg:93.11ms +step:1428/1670 train_time:132964ms step_avg:93.11ms +step:1429/1670 train_time:133058ms step_avg:93.11ms +step:1430/1670 train_time:133151ms step_avg:93.11ms +step:1431/1670 train_time:133245ms step_avg:93.11ms +step:1432/1670 train_time:133337ms step_avg:93.11ms +step:1433/1670 train_time:133431ms step_avg:93.11ms +step:1434/1670 train_time:133524ms step_avg:93.11ms +step:1435/1670 train_time:133616ms step_avg:93.11ms +step:1436/1670 train_time:133710ms step_avg:93.11ms +step:1437/1670 train_time:133802ms step_avg:93.11ms +step:1438/1670 train_time:133895ms step_avg:93.11ms +step:1439/1670 train_time:133988ms step_avg:93.11ms +step:1440/1670 train_time:134082ms step_avg:93.11ms +step:1441/1670 train_time:134175ms step_avg:93.11ms +step:1442/1670 train_time:134269ms step_avg:93.11ms +step:1443/1670 train_time:134363ms step_avg:93.11ms +step:1444/1670 train_time:134457ms step_avg:93.11ms +step:1445/1670 train_time:134550ms step_avg:93.11ms +step:1446/1670 train_time:134643ms step_avg:93.11ms +step:1447/1670 train_time:134736ms step_avg:93.11ms +step:1448/1670 train_time:134829ms step_avg:93.11ms +step:1449/1670 train_time:134922ms step_avg:93.11ms +step:1450/1670 train_time:135016ms step_avg:93.11ms +step:1451/1670 train_time:135109ms step_avg:93.11ms +step:1452/1670 train_time:135203ms step_avg:93.12ms +step:1453/1670 train_time:135297ms step_avg:93.12ms +step:1454/1670 train_time:135390ms step_avg:93.12ms +step:1455/1670 train_time:135484ms step_avg:93.12ms +step:1456/1670 train_time:135577ms step_avg:93.12ms 
+step:1457/1670 train_time:135670ms step_avg:93.12ms +step:1458/1670 train_time:135765ms step_avg:93.12ms +step:1459/1670 train_time:135858ms step_avg:93.12ms +step:1460/1670 train_time:135951ms step_avg:93.12ms +step:1461/1670 train_time:136044ms step_avg:93.12ms +step:1462/1670 train_time:136138ms step_avg:93.12ms +step:1463/1670 train_time:136230ms step_avg:93.12ms +step:1464/1670 train_time:136324ms step_avg:93.12ms +step:1465/1670 train_time:136417ms step_avg:93.12ms +step:1466/1670 train_time:136511ms step_avg:93.12ms +step:1467/1670 train_time:136605ms step_avg:93.12ms +step:1468/1670 train_time:136699ms step_avg:93.12ms +step:1469/1670 train_time:136791ms step_avg:93.12ms +step:1470/1670 train_time:136884ms step_avg:93.12ms +step:1471/1670 train_time:136977ms step_avg:93.12ms +step:1472/1670 train_time:137071ms step_avg:93.12ms +step:1473/1670 train_time:137164ms step_avg:93.12ms +step:1474/1670 train_time:137257ms step_avg:93.12ms +step:1475/1670 train_time:137350ms step_avg:93.12ms +step:1476/1670 train_time:137444ms step_avg:93.12ms +step:1477/1670 train_time:137537ms step_avg:93.12ms +step:1478/1670 train_time:137630ms step_avg:93.12ms +step:1479/1670 train_time:137725ms step_avg:93.12ms +step:1480/1670 train_time:137818ms step_avg:93.12ms +step:1481/1670 train_time:137911ms step_avg:93.12ms +step:1482/1670 train_time:138005ms step_avg:93.12ms +step:1483/1670 train_time:138098ms step_avg:93.12ms +step:1484/1670 train_time:138190ms step_avg:93.12ms +step:1485/1670 train_time:138439ms step_avg:93.23ms +step:1486/1670 train_time:138513ms step_avg:93.21ms +step:1487/1670 train_time:138606ms step_avg:93.21ms +step:1488/1670 train_time:138698ms step_avg:93.21ms +step:1489/1670 train_time:138789ms step_avg:93.21ms +step:1490/1670 train_time:138882ms step_avg:93.21ms +step:1491/1670 train_time:138973ms step_avg:93.21ms +step:1492/1670 train_time:139065ms step_avg:93.21ms +step:1493/1670 train_time:139157ms step_avg:93.21ms +step:1494/1670 train_time:139249ms step_avg:93.21ms +step:1495/1670 train_time:139347ms step_avg:93.21ms +step:1496/1670 train_time:139447ms step_avg:93.21ms +step:1497/1670 train_time:139542ms step_avg:93.21ms +step:1498/1670 train_time:139634ms step_avg:93.21ms +step:1499/1670 train_time:139726ms step_avg:93.21ms +step:1500/1670 train_time:139818ms step_avg:93.21ms +step:1500/1670 val_loss:3.3110 train_time:139910ms step_avg:93.27ms +step:1501/1670 train_time:139930ms step_avg:93.22ms +step:1502/1670 train_time:140004ms step_avg:93.21ms +step:1503/1670 train_time:140098ms step_avg:93.21ms +step:1504/1670 train_time:140190ms step_avg:93.21ms +step:1505/1670 train_time:140282ms step_avg:93.21ms +step:1506/1670 train_time:140375ms step_avg:93.21ms +step:1507/1670 train_time:140468ms step_avg:93.21ms +step:1508/1670 train_time:140561ms step_avg:93.21ms +step:1509/1670 train_time:140655ms step_avg:93.21ms +step:1510/1670 train_time:140749ms step_avg:93.21ms +step:1511/1670 train_time:140842ms step_avg:93.21ms +step:1512/1670 train_time:140937ms step_avg:93.21ms +step:1513/1670 train_time:141031ms step_avg:93.21ms +step:1514/1670 train_time:141124ms step_avg:93.21ms +step:1515/1670 train_time:141217ms step_avg:93.21ms +step:1516/1670 train_time:141310ms step_avg:93.21ms +step:1517/1670 train_time:141402ms step_avg:93.21ms +step:1518/1670 train_time:141495ms step_avg:93.21ms +step:1519/1670 train_time:141588ms step_avg:93.21ms +step:1520/1670 train_time:141680ms step_avg:93.21ms +step:1521/1670 train_time:141775ms step_avg:93.21ms +step:1522/1670 train_time:141869ms 
step_avg:93.21ms +step:1523/1670 train_time:141961ms step_avg:93.21ms +step:1524/1670 train_time:142054ms step_avg:93.21ms +step:1525/1670 train_time:142147ms step_avg:93.21ms +step:1526/1670 train_time:142239ms step_avg:93.21ms +step:1527/1670 train_time:142332ms step_avg:93.21ms +step:1528/1670 train_time:142425ms step_avg:93.21ms +step:1529/1670 train_time:142519ms step_avg:93.21ms +step:1530/1670 train_time:142612ms step_avg:93.21ms +step:1531/1670 train_time:142707ms step_avg:93.21ms +step:1532/1670 train_time:142800ms step_avg:93.21ms +step:1533/1670 train_time:142895ms step_avg:93.21ms +step:1534/1670 train_time:142989ms step_avg:93.21ms +step:1535/1670 train_time:143082ms step_avg:93.21ms +step:1536/1670 train_time:143176ms step_avg:93.21ms +step:1537/1670 train_time:143268ms step_avg:93.21ms +step:1538/1670 train_time:143361ms step_avg:93.21ms +step:1539/1670 train_time:143455ms step_avg:93.21ms +step:1540/1670 train_time:143548ms step_avg:93.21ms +step:1541/1670 train_time:143641ms step_avg:93.21ms +step:1542/1670 train_time:143734ms step_avg:93.21ms +step:1543/1670 train_time:143827ms step_avg:93.21ms +step:1544/1670 train_time:143922ms step_avg:93.21ms +step:1545/1670 train_time:144016ms step_avg:93.21ms +step:1546/1670 train_time:144108ms step_avg:93.21ms +step:1547/1670 train_time:144202ms step_avg:93.21ms +step:1548/1670 train_time:144295ms step_avg:93.21ms +step:1549/1670 train_time:144388ms step_avg:93.21ms +step:1550/1670 train_time:144481ms step_avg:93.21ms +step:1551/1670 train_time:144574ms step_avg:93.21ms +step:1552/1670 train_time:144666ms step_avg:93.21ms +step:1553/1670 train_time:144763ms step_avg:93.21ms +step:1554/1670 train_time:144854ms step_avg:93.21ms +step:1555/1670 train_time:144947ms step_avg:93.21ms +step:1556/1670 train_time:145040ms step_avg:93.21ms +step:1557/1670 train_time:145134ms step_avg:93.21ms +step:1558/1670 train_time:145227ms step_avg:93.21ms +step:1559/1670 train_time:145320ms step_avg:93.21ms +step:1560/1670 train_time:145413ms step_avg:93.21ms +step:1561/1670 train_time:145506ms step_avg:93.21ms +step:1562/1670 train_time:145599ms step_avg:93.21ms +step:1563/1670 train_time:145693ms step_avg:93.21ms +step:1564/1670 train_time:145785ms step_avg:93.21ms +step:1565/1670 train_time:145879ms step_avg:93.21ms +step:1566/1670 train_time:145974ms step_avg:93.21ms +step:1567/1670 train_time:146066ms step_avg:93.21ms +step:1568/1670 train_time:146159ms step_avg:93.21ms +step:1569/1670 train_time:146253ms step_avg:93.21ms +step:1570/1670 train_time:146346ms step_avg:93.21ms +step:1571/1670 train_time:146440ms step_avg:93.21ms +step:1572/1670 train_time:146532ms step_avg:93.21ms +step:1573/1670 train_time:146626ms step_avg:93.21ms +step:1574/1670 train_time:146720ms step_avg:93.21ms +step:1575/1670 train_time:146814ms step_avg:93.22ms +step:1576/1670 train_time:146907ms step_avg:93.21ms +step:1577/1670 train_time:147000ms step_avg:93.21ms +step:1578/1670 train_time:147093ms step_avg:93.22ms +step:1579/1670 train_time:147186ms step_avg:93.21ms +step:1580/1670 train_time:147280ms step_avg:93.21ms +step:1581/1670 train_time:147373ms step_avg:93.22ms +step:1582/1670 train_time:147466ms step_avg:93.21ms +step:1583/1670 train_time:147560ms step_avg:93.22ms +step:1584/1670 train_time:147655ms step_avg:93.22ms +step:1585/1670 train_time:147747ms step_avg:93.22ms +step:1586/1670 train_time:147840ms step_avg:93.22ms +step:1587/1670 train_time:147933ms step_avg:93.22ms +step:1588/1670 train_time:148027ms step_avg:93.22ms +step:1589/1670 train_time:148121ms 
step_avg:93.22ms +step:1590/1670 train_time:148215ms step_avg:93.22ms +step:1591/1670 train_time:148308ms step_avg:93.22ms +step:1592/1670 train_time:148401ms step_avg:93.22ms +step:1593/1670 train_time:148494ms step_avg:93.22ms +step:1594/1670 train_time:148588ms step_avg:93.22ms +step:1595/1670 train_time:148681ms step_avg:93.22ms +step:1596/1670 train_time:148774ms step_avg:93.22ms +step:1597/1670 train_time:148866ms step_avg:93.22ms +step:1598/1670 train_time:148960ms step_avg:93.22ms +step:1599/1670 train_time:149053ms step_avg:93.22ms +step:1600/1670 train_time:149146ms step_avg:93.22ms +step:1601/1670 train_time:149241ms step_avg:93.22ms +step:1602/1670 train_time:149334ms step_avg:93.22ms +step:1603/1670 train_time:149427ms step_avg:93.22ms +step:1604/1670 train_time:149522ms step_avg:93.22ms +step:1605/1670 train_time:149615ms step_avg:93.22ms +step:1606/1670 train_time:149708ms step_avg:93.22ms +step:1607/1670 train_time:149801ms step_avg:93.22ms +step:1608/1670 train_time:149896ms step_avg:93.22ms +step:1609/1670 train_time:149990ms step_avg:93.22ms +step:1610/1670 train_time:150084ms step_avg:93.22ms +step:1611/1670 train_time:150177ms step_avg:93.22ms +step:1612/1670 train_time:150269ms step_avg:93.22ms +step:1613/1670 train_time:150362ms step_avg:93.22ms +step:1614/1670 train_time:150456ms step_avg:93.22ms +step:1615/1670 train_time:150550ms step_avg:93.22ms +step:1616/1670 train_time:150642ms step_avg:93.22ms +step:1617/1670 train_time:150735ms step_avg:93.22ms +step:1618/1670 train_time:150829ms step_avg:93.22ms +step:1619/1670 train_time:150922ms step_avg:93.22ms +step:1620/1670 train_time:151017ms step_avg:93.22ms +step:1621/1670 train_time:151111ms step_avg:93.22ms +step:1622/1670 train_time:151204ms step_avg:93.22ms +step:1623/1670 train_time:151297ms step_avg:93.22ms +step:1624/1670 train_time:151390ms step_avg:93.22ms +step:1625/1670 train_time:151484ms step_avg:93.22ms +step:1625/1670 val_loss:3.2859 train_time:151578ms step_avg:93.28ms +step:1626/1670 train_time:151597ms step_avg:93.23ms +step:1627/1670 train_time:151674ms step_avg:93.22ms +step:1628/1670 train_time:151767ms step_avg:93.22ms +step:1629/1670 train_time:151860ms step_avg:93.22ms +step:1630/1670 train_time:151952ms step_avg:93.22ms +step:1631/1670 train_time:152044ms step_avg:93.22ms +step:1632/1670 train_time:152137ms step_avg:93.22ms +step:1633/1670 train_time:152230ms step_avg:93.22ms +step:1634/1670 train_time:152323ms step_avg:93.22ms +step:1635/1670 train_time:152416ms step_avg:93.22ms +step:1636/1670 train_time:152510ms step_avg:93.22ms +step:1637/1670 train_time:152608ms step_avg:93.22ms +step:1638/1670 train_time:152703ms step_avg:93.23ms +step:1639/1670 train_time:152797ms step_avg:93.23ms +step:1640/1670 train_time:152889ms step_avg:93.22ms +step:1641/1670 train_time:152982ms step_avg:93.22ms +step:1642/1670 train_time:153075ms step_avg:93.22ms +step:1643/1670 train_time:153170ms step_avg:93.23ms +step:1644/1670 train_time:153264ms step_avg:93.23ms +step:1645/1670 train_time:153358ms step_avg:93.23ms +step:1646/1670 train_time:153450ms step_avg:93.23ms +step:1647/1670 train_time:153545ms step_avg:93.23ms +step:1648/1670 train_time:153640ms step_avg:93.23ms +step:1649/1670 train_time:153733ms step_avg:93.23ms +step:1650/1670 train_time:153827ms step_avg:93.23ms +step:1651/1670 train_time:153920ms step_avg:93.23ms +step:1652/1670 train_time:154012ms step_avg:93.23ms +step:1653/1670 train_time:154106ms step_avg:93.23ms +step:1654/1670 train_time:154200ms step_avg:93.23ms +step:1655/1670 
train_time:154295ms step_avg:93.23ms +step:1656/1670 train_time:154387ms step_avg:93.23ms +step:1657/1670 train_time:154480ms step_avg:93.23ms +step:1658/1670 train_time:154573ms step_avg:93.23ms +step:1659/1670 train_time:154667ms step_avg:93.23ms +step:1660/1670 train_time:154760ms step_avg:93.23ms +step:1661/1670 train_time:154853ms step_avg:93.23ms +step:1662/1670 train_time:154948ms step_avg:93.23ms +step:1663/1670 train_time:155040ms step_avg:93.23ms +step:1664/1670 train_time:155133ms step_avg:93.23ms +step:1665/1670 train_time:155226ms step_avg:93.23ms +step:1666/1670 train_time:155318ms step_avg:93.23ms +step:1667/1670 train_time:155411ms step_avg:93.23ms +step:1668/1670 train_time:155506ms step_avg:93.23ms +step:1669/1670 train_time:155599ms step_avg:93.23ms +step:1670/1670 train_time:155693ms step_avg:93.23ms +step:1670/1670 val_loss:3.2772 train_time:155956ms step_avg:93.39ms +peak memory allocated: 32002 MiB reserved: 47034 MiB diff --git a/records/091125_VectSigmoidBFloat16/3b564c86-3d85-490d-a96a-83bd60d60f11.txt b/records/091125_VectSigmoidBFloat16/3b564c86-3d85-490d-a96a-83bd60d60f11.txt new file mode 100644 index 000000000..008aff19d --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/3b564c86-3d85-490d-a96a-83bd60d60f11.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
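+        # (As in the standard Triton matmul tutorial, the m/n offsets below are
+        #  wrapped modulo M so edge blocks stay in-bounds; the wrapped rows and
+        #  columns are discarded by the store masks at the end of the kernel,
+        #  and K-dim loads are masked to zero. at_ptrs swaps A's row/column
+        #  strides to read A.T directly, so no transpose is materialized.)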
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
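For reference, the five fused iterations above implement the standard quintic Newton-Schulz orthogonalization. A minimal plain-PyTorch equivalent is sketched below for cross-checking the Triton kernels; newton_schulz_reference is a name introduced here for illustration and is not part of the patch:

import torch

def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic Newton-Schulz iteration with the same (a, b, c) coefficients as
    # newton_schulz_triton; orthogonalizes the last two dims of G.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT  # iterate on the wide orientation, as the Triton version does
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # spectral norm <= 1
    for _ in range(steps):
        A = X @ X.mT             # ns_line_1
        B = b * A + c * (A @ A)  # ns_line_2
        X = a * X + B @ X        # ns_line_3 (baddbmm/addmm)
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

# Example cross-check (tolerances are illustrative; bf16 iterates drift):
# G = torch.randn(24, 768, 3072, device="cuda")
# torch.testing.assert_close(newton_schulz_reference(G),
#                            newton_schulz_triton(G), atol=1e-1, rtol=0)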
+
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter shape
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                batch = original_shape[0] * original_shape[1]
+                # Merge the two leading dims into a single batch dim
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
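+        # Note on the FP8 head configured below: mm_op divides by each scale
+        # before casting to float8, so x_s = sqrt(model_dim)/448, w_s = 2**-9
+        # and grad_s = 1/448 stretch the small RMS-normed activations, weights
+        # and gradients to use more of the e4m3fn/e5m2 dynamic range (max 448
+        # and 57344 respectively); _scaled_mm's scale_a/scale_b undo this.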
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(
+            torch.bfloat16
+        )  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[
+            1 * len(self.blocks) : 3 * len(self.blocks)
+        ].view(-1, 2)
+        sa_lambdas = self.scalars[
+            3 * len(self.blocks) : 5 * len(self.blocks)
+        ].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(
+        str(file), False, 256, dtype=torch.int32
+    )  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(
+            num_tokens, dtype=torch.uint16, pin_memory=True
+        )  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(
+            tokens.numpy()
+        )  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, (
+            "number of tokens read does not match header"
+        )
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # report idx rather than cur: cur is still unbound if the
+                    # shard runs out before this rank finds its first document
+                    raise StopIteration(
+                        f"Insufficient BOS tokens ahead of BOS index {idx}; hit tail of shard."
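+                        # this StopIteration is caught in
+                        # distributed_data_generator, which loads the next
+                        # shard, rebuilds the BOSFinder, and retries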
+                    )
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(
+                    self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                    cur + max_seq_len,
+                    cur + num_tokens_local - cur_len + 1,
+                )
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+
+def distributed_data_generator(
+    filename_pattern: str,
+    num_tokens: int,
+    max_seq_len: int,
+    grad_accum_steps: int = 1,
+    align_to_bos: bool = True,
+):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, (
+        "num_tokens must be divisible by world_size * grad_accum_steps"
+    )
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(
+            f"No files found for pattern: {filename_pattern}"
+        )
+
+    file_iter = iter(
+        files
+    )  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(
+            num_tokens_local // 300, n=128
+        )  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(
+                    num_tokens_local, max_seq_len
+                )
+                start_idxs, end_idxs = (
+                    torch.tensor(seq_starts[rank]),
+                    torch.tensor(seq_ends[rank]),
+                )
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= (
+                1  # the last slice carried the extra token for the _targets shift, so trim it to match _inputs
+            )
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(
+                tokens
+            ):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local : pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(
+                num_tokens_local,
+            )
+            _targets = buf[1:].view(
+                num_tokens_local,
+            )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(
+                device="cuda", dtype=torch.int32, non_blocking=True
+            ),
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            # validate and normalize against the new grad_accum_steps,
+            # mirroring the handling of the initial arguments above
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, (
+                "num_tokens must be divisible by world_size * grad_accum_steps"
+            )
+            num_tokens = new_num_tokens // new_grad_accum_steps
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = (
+        "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    )
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670  # number of iterations to run
+    cooldown_frac: float = (
+        0.5  # fraction of training spent cooling down the learning rate
+    )
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = (
+        125  # every how many steps to evaluate val loss? 0 for only at the end
+    )
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws @classiclarryd
+
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = rank == 0  # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+
+
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("=" * 100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(
+    f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
+)
+print0(f"Running Triton version {triton.__version__}")
+
+
+def nvidia_smi():
+    import subprocess  # avoid top level import
+
+    return subprocess.run(
+        ["nvidia-smi"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    ).stdout
+
+
+print0(nvidia_smi())
+print0("=" * 100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size)
+    // (grad_accum_steps * world_size),
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [
+    p
+    for n, p in model.blocks.named_parameters()
+    if p.ndim >= 2 and "embed" not in n
+]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:16:45 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 132W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 45C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 41C P0 129W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.06ms +step:1/1670 train_time:296ms step_avg:296.32ms +step:2/1670 train_time:313ms step_avg:156.66ms +step:3/1670 train_time:383ms step_avg:127.83ms +step:4/1670 train_time:472ms step_avg:118.07ms +step:5/1670 train_time:563ms step_avg:112.54ms +step:6/1670 train_time:653ms step_avg:108.86ms +step:7/1670 train_time:744ms step_avg:106.32ms +step:8/1670 train_time:835ms step_avg:104.32ms +step:9/1670 train_time:925ms step_avg:102.77ms +step:10/1670 train_time:1016ms step_avg:101.58ms +step:11/1670 train_time:1108ms step_avg:100.69ms +step:12/1670 train_time:1203ms step_avg:100.28ms +step:13/1670 train_time:1299ms step_avg:99.90ms +step:14/1670 train_time:1392ms step_avg:99.46ms +step:15/1670 train_time:1484ms step_avg:98.91ms +step:16/1670 train_time:1574ms step_avg:98.39ms +step:17/1670 train_time:1666ms step_avg:97.98ms +step:18/1670 train_time:1756ms step_avg:97.55ms +step:19/1670 train_time:1846ms step_avg:97.18ms +step:20/1670 train_time:1938ms step_avg:96.88ms +step:21/1670 train_time:2029ms step_avg:96.61ms +step:22/1670 train_time:2120ms 
step_avg:96.37ms +step:23/1670 train_time:2213ms step_avg:96.20ms +step:24/1670 train_time:2307ms step_avg:96.13ms +step:25/1670 train_time:2400ms step_avg:96.01ms +step:26/1670 train_time:2491ms step_avg:95.81ms +step:27/1670 train_time:2583ms step_avg:95.66ms +step:28/1670 train_time:2674ms step_avg:95.50ms +step:29/1670 train_time:2766ms step_avg:95.38ms +step:30/1670 train_time:2857ms step_avg:95.24ms +step:31/1670 train_time:2949ms step_avg:95.12ms +step:32/1670 train_time:3040ms step_avg:95.00ms +step:33/1670 train_time:3131ms step_avg:94.88ms +step:34/1670 train_time:3223ms step_avg:94.80ms +step:35/1670 train_time:3315ms step_avg:94.72ms +step:36/1670 train_time:3408ms step_avg:94.66ms +step:37/1670 train_time:3501ms step_avg:94.61ms +step:38/1670 train_time:3591ms step_avg:94.50ms +step:39/1670 train_time:3683ms step_avg:94.43ms +step:40/1670 train_time:3773ms step_avg:94.34ms +step:41/1670 train_time:3866ms step_avg:94.28ms +step:42/1670 train_time:3957ms step_avg:94.22ms +step:43/1670 train_time:4049ms step_avg:94.16ms +step:44/1670 train_time:4141ms step_avg:94.12ms +step:45/1670 train_time:4233ms step_avg:94.06ms +step:46/1670 train_time:4326ms step_avg:94.05ms +step:47/1670 train_time:4418ms step_avg:93.99ms +step:48/1670 train_time:4509ms step_avg:93.93ms +step:49/1670 train_time:4601ms step_avg:93.90ms +step:50/1670 train_time:4692ms step_avg:93.85ms +step:51/1670 train_time:4783ms step_avg:93.79ms +step:52/1670 train_time:4875ms step_avg:93.75ms +step:53/1670 train_time:4967ms step_avg:93.72ms +step:54/1670 train_time:5059ms step_avg:93.68ms +step:55/1670 train_time:5150ms step_avg:93.63ms +step:56/1670 train_time:5244ms step_avg:93.64ms +step:57/1670 train_time:5336ms step_avg:93.61ms +step:58/1670 train_time:5428ms step_avg:93.58ms +step:59/1670 train_time:5520ms step_avg:93.55ms +step:60/1670 train_time:5611ms step_avg:93.51ms +step:61/1670 train_time:5702ms step_avg:93.48ms +step:62/1670 train_time:5794ms step_avg:93.45ms +step:63/1670 train_time:5885ms step_avg:93.42ms +step:64/1670 train_time:5977ms step_avg:93.39ms +step:65/1670 train_time:6068ms step_avg:93.35ms +step:66/1670 train_time:6160ms step_avg:93.34ms +step:67/1670 train_time:6252ms step_avg:93.31ms +step:68/1670 train_time:6345ms step_avg:93.31ms +step:69/1670 train_time:6438ms step_avg:93.31ms +step:70/1670 train_time:6531ms step_avg:93.29ms +step:71/1670 train_time:6623ms step_avg:93.28ms +step:72/1670 train_time:6715ms step_avg:93.26ms +step:73/1670 train_time:6806ms step_avg:93.24ms +step:74/1670 train_time:6899ms step_avg:93.23ms +step:75/1670 train_time:6989ms step_avg:93.19ms +step:76/1670 train_time:7081ms step_avg:93.17ms +step:77/1670 train_time:7172ms step_avg:93.14ms +step:78/1670 train_time:7264ms step_avg:93.12ms +step:79/1670 train_time:7355ms step_avg:93.10ms +step:80/1670 train_time:7448ms step_avg:93.10ms +step:81/1670 train_time:7540ms step_avg:93.09ms +step:82/1670 train_time:7632ms step_avg:93.07ms +step:83/1670 train_time:7724ms step_avg:93.06ms +step:84/1670 train_time:7815ms step_avg:93.04ms +step:85/1670 train_time:7907ms step_avg:93.02ms +step:86/1670 train_time:7998ms step_avg:93.01ms +step:87/1670 train_time:8089ms step_avg:92.98ms +step:88/1670 train_time:8180ms step_avg:92.95ms +step:89/1670 train_time:8270ms step_avg:92.92ms +step:90/1670 train_time:8361ms step_avg:92.90ms +step:91/1670 train_time:8453ms step_avg:92.89ms +step:92/1670 train_time:8546ms step_avg:92.89ms +step:93/1670 train_time:8637ms step_avg:92.87ms +step:94/1670 train_time:8729ms step_avg:92.86ms 
+step:95/1670 train_time:8822ms step_avg:92.86ms +step:96/1670 train_time:8912ms step_avg:92.83ms +step:97/1670 train_time:9004ms step_avg:92.82ms +step:98/1670 train_time:9095ms step_avg:92.80ms +step:99/1670 train_time:9186ms step_avg:92.79ms +step:100/1670 train_time:9279ms step_avg:92.79ms +step:101/1670 train_time:9369ms step_avg:92.76ms +step:102/1670 train_time:9460ms step_avg:92.75ms +step:103/1670 train_time:9551ms step_avg:92.73ms +step:104/1670 train_time:9642ms step_avg:92.72ms +step:105/1670 train_time:9733ms step_avg:92.70ms +step:106/1670 train_time:9826ms step_avg:92.69ms +step:107/1670 train_time:9917ms step_avg:92.68ms +step:108/1670 train_time:10009ms step_avg:92.68ms +step:109/1670 train_time:10101ms step_avg:92.67ms +step:110/1670 train_time:10191ms step_avg:92.65ms +step:111/1670 train_time:10283ms step_avg:92.64ms +step:112/1670 train_time:10374ms step_avg:92.62ms +step:113/1670 train_time:10465ms step_avg:92.61ms +step:114/1670 train_time:10555ms step_avg:92.59ms +step:115/1670 train_time:10648ms step_avg:92.59ms +step:116/1670 train_time:10740ms step_avg:92.59ms +step:117/1670 train_time:10831ms step_avg:92.58ms +step:118/1670 train_time:10923ms step_avg:92.57ms +step:119/1670 train_time:11015ms step_avg:92.56ms +step:120/1670 train_time:11107ms step_avg:92.56ms +step:121/1670 train_time:11199ms step_avg:92.55ms +step:122/1670 train_time:11290ms step_avg:92.54ms +step:123/1670 train_time:11381ms step_avg:92.53ms +step:124/1670 train_time:11471ms step_avg:92.51ms +step:125/1670 train_time:11562ms step_avg:92.49ms +step:125/1670 val_loss:4.3072 train_time:11653ms step_avg:93.22ms +step:126/1670 train_time:11671ms step_avg:92.63ms +step:127/1670 train_time:11748ms step_avg:92.51ms +step:128/1670 train_time:11848ms step_avg:92.57ms +step:129/1670 train_time:11940ms step_avg:92.56ms +step:130/1670 train_time:12031ms step_avg:92.55ms +step:131/1670 train_time:12122ms step_avg:92.54ms +step:132/1670 train_time:12212ms step_avg:92.52ms +step:133/1670 train_time:12303ms step_avg:92.50ms +step:134/1670 train_time:12393ms step_avg:92.48ms +step:135/1670 train_time:12483ms step_avg:92.47ms +step:136/1670 train_time:12573ms step_avg:92.45ms +step:137/1670 train_time:12663ms step_avg:92.43ms +step:138/1670 train_time:12756ms step_avg:92.43ms +step:139/1670 train_time:12849ms step_avg:92.44ms +step:140/1670 train_time:12941ms step_avg:92.44ms +step:141/1670 train_time:13033ms step_avg:92.43ms +step:142/1670 train_time:13124ms step_avg:92.42ms +step:143/1670 train_time:13214ms step_avg:92.41ms +step:144/1670 train_time:13305ms step_avg:92.39ms +step:145/1670 train_time:13395ms step_avg:92.38ms +step:146/1670 train_time:13485ms step_avg:92.36ms +step:147/1670 train_time:13575ms step_avg:92.35ms +step:148/1670 train_time:13667ms step_avg:92.35ms +step:149/1670 train_time:13759ms step_avg:92.34ms +step:150/1670 train_time:13851ms step_avg:92.34ms +step:151/1670 train_time:13943ms step_avg:92.33ms +step:152/1670 train_time:14035ms step_avg:92.33ms +step:153/1670 train_time:14126ms step_avg:92.33ms +step:154/1670 train_time:14216ms step_avg:92.31ms +step:155/1670 train_time:14306ms step_avg:92.30ms +step:156/1670 train_time:14397ms step_avg:92.29ms +step:157/1670 train_time:14488ms step_avg:92.28ms +step:158/1670 train_time:14578ms step_avg:92.26ms +step:159/1670 train_time:14669ms step_avg:92.26ms +step:160/1670 train_time:14761ms step_avg:92.26ms +step:161/1670 train_time:14852ms step_avg:92.25ms +step:162/1670 train_time:14944ms step_avg:92.25ms +step:163/1670 train_time:15036ms 
step_avg:92.24ms +step:164/1670 train_time:15129ms step_avg:92.25ms +step:165/1670 train_time:15219ms step_avg:92.24ms +step:166/1670 train_time:15310ms step_avg:92.23ms +step:167/1670 train_time:15401ms step_avg:92.22ms +step:168/1670 train_time:15492ms step_avg:92.21ms +step:169/1670 train_time:15583ms step_avg:92.21ms +step:170/1670 train_time:15674ms step_avg:92.20ms +step:171/1670 train_time:15765ms step_avg:92.19ms +step:172/1670 train_time:15856ms step_avg:92.19ms +step:173/1670 train_time:15947ms step_avg:92.18ms +step:174/1670 train_time:16039ms step_avg:92.18ms +step:175/1670 train_time:16132ms step_avg:92.18ms +step:176/1670 train_time:16224ms step_avg:92.18ms +step:177/1670 train_time:16315ms step_avg:92.18ms +step:178/1670 train_time:16406ms step_avg:92.17ms +step:179/1670 train_time:16496ms step_avg:92.16ms +step:180/1670 train_time:16588ms step_avg:92.16ms +step:181/1670 train_time:16680ms step_avg:92.15ms +step:182/1670 train_time:16771ms step_avg:92.15ms +step:183/1670 train_time:16861ms step_avg:92.14ms +step:184/1670 train_time:16952ms step_avg:92.13ms +step:185/1670 train_time:17043ms step_avg:92.13ms +step:186/1670 train_time:17135ms step_avg:92.12ms +step:187/1670 train_time:17226ms step_avg:92.12ms +step:188/1670 train_time:17316ms step_avg:92.11ms +step:189/1670 train_time:17408ms step_avg:92.11ms +step:190/1670 train_time:17499ms step_avg:92.10ms +step:191/1670 train_time:17590ms step_avg:92.10ms +step:192/1670 train_time:17681ms step_avg:92.09ms +step:193/1670 train_time:17772ms step_avg:92.08ms +step:194/1670 train_time:17863ms step_avg:92.08ms +step:195/1670 train_time:17956ms step_avg:92.08ms +step:196/1670 train_time:18046ms step_avg:92.07ms +step:197/1670 train_time:18136ms step_avg:92.06ms +step:198/1670 train_time:18227ms step_avg:92.06ms +step:199/1670 train_time:18319ms step_avg:92.06ms +step:200/1670 train_time:18411ms step_avg:92.06ms +step:201/1670 train_time:18503ms step_avg:92.05ms +step:202/1670 train_time:18594ms step_avg:92.05ms +step:203/1670 train_time:18686ms step_avg:92.05ms +step:204/1670 train_time:18777ms step_avg:92.04ms +step:205/1670 train_time:18869ms step_avg:92.04ms +step:206/1670 train_time:18959ms step_avg:92.03ms +step:207/1670 train_time:19051ms step_avg:92.03ms +step:208/1670 train_time:19142ms step_avg:92.03ms +step:209/1670 train_time:19233ms step_avg:92.03ms +step:210/1670 train_time:19324ms step_avg:92.02ms +step:211/1670 train_time:19415ms step_avg:92.02ms +step:212/1670 train_time:19507ms step_avg:92.01ms +step:213/1670 train_time:19759ms step_avg:92.76ms +step:214/1670 train_time:19829ms step_avg:92.66ms +step:215/1670 train_time:19919ms step_avg:92.64ms +step:216/1670 train_time:20009ms step_avg:92.63ms +step:217/1670 train_time:20100ms step_avg:92.63ms +step:218/1670 train_time:20190ms step_avg:92.62ms +step:219/1670 train_time:20280ms step_avg:92.60ms +step:220/1670 train_time:20371ms step_avg:92.59ms +step:221/1670 train_time:20461ms step_avg:92.58ms +step:222/1670 train_time:20550ms step_avg:92.57ms +step:223/1670 train_time:20642ms step_avg:92.57ms +step:224/1670 train_time:20738ms step_avg:92.58ms +step:225/1670 train_time:20832ms step_avg:92.58ms +step:226/1670 train_time:20925ms step_avg:92.59ms +step:227/1670 train_time:21015ms step_avg:92.58ms +step:228/1670 train_time:21105ms step_avg:92.57ms +step:229/1670 train_time:21196ms step_avg:92.56ms +step:230/1670 train_time:21287ms step_avg:92.55ms +step:231/1670 train_time:21377ms step_avg:92.54ms +step:232/1670 train_time:21467ms step_avg:92.53ms +step:233/1670 
train_time:21558ms step_avg:92.53ms +step:234/1670 train_time:21652ms step_avg:92.53ms +step:235/1670 train_time:21744ms step_avg:92.53ms +step:236/1670 train_time:21836ms step_avg:92.52ms +step:237/1670 train_time:21927ms step_avg:92.52ms +step:238/1670 train_time:22018ms step_avg:92.51ms +step:239/1670 train_time:22109ms step_avg:92.51ms +step:240/1670 train_time:22199ms step_avg:92.50ms +step:241/1670 train_time:22291ms step_avg:92.49ms +step:242/1670 train_time:22381ms step_avg:92.48ms +step:243/1670 train_time:22472ms step_avg:92.48ms +step:244/1670 train_time:22563ms step_avg:92.47ms +step:245/1670 train_time:22654ms step_avg:92.47ms +step:246/1670 train_time:22745ms step_avg:92.46ms +step:247/1670 train_time:22837ms step_avg:92.46ms +step:248/1670 train_time:22929ms step_avg:92.46ms +step:249/1670 train_time:23021ms step_avg:92.45ms +step:250/1670 train_time:23113ms step_avg:92.45ms +step:250/1670 val_loss:3.9688 train_time:23204ms step_avg:92.82ms +step:251/1670 train_time:23221ms step_avg:92.51ms +step:252/1670 train_time:23296ms step_avg:92.44ms +step:253/1670 train_time:23387ms step_avg:92.44ms +step:254/1670 train_time:23479ms step_avg:92.44ms +step:255/1670 train_time:23569ms step_avg:92.43ms +step:256/1670 train_time:23659ms step_avg:92.42ms +step:257/1670 train_time:23749ms step_avg:92.41ms +step:258/1670 train_time:23840ms step_avg:92.40ms +step:259/1670 train_time:23932ms step_avg:92.40ms +step:260/1670 train_time:24023ms step_avg:92.40ms +step:261/1670 train_time:24114ms step_avg:92.39ms +step:262/1670 train_time:24208ms step_avg:92.40ms +step:263/1670 train_time:24301ms step_avg:92.40ms +step:264/1670 train_time:24392ms step_avg:92.39ms +step:265/1670 train_time:24484ms step_avg:92.39ms +step:266/1670 train_time:24575ms step_avg:92.39ms +step:267/1670 train_time:24666ms step_avg:92.38ms +step:268/1670 train_time:24756ms step_avg:92.37ms +step:269/1670 train_time:24848ms step_avg:92.37ms +step:270/1670 train_time:24939ms step_avg:92.37ms +step:271/1670 train_time:25030ms step_avg:92.36ms +step:272/1670 train_time:25120ms step_avg:92.35ms +step:273/1670 train_time:25212ms step_avg:92.35ms +step:274/1670 train_time:25305ms step_avg:92.35ms +step:275/1670 train_time:25397ms step_avg:92.35ms +step:276/1670 train_time:25489ms step_avg:92.35ms +step:277/1670 train_time:25580ms step_avg:92.35ms +step:278/1670 train_time:25672ms step_avg:92.34ms +step:279/1670 train_time:25763ms step_avg:92.34ms +step:280/1670 train_time:25853ms step_avg:92.33ms +step:281/1670 train_time:25944ms step_avg:92.33ms +step:282/1670 train_time:26034ms step_avg:92.32ms +step:283/1670 train_time:26126ms step_avg:92.32ms +step:284/1670 train_time:26217ms step_avg:92.31ms +step:285/1670 train_time:26309ms step_avg:92.31ms +step:286/1670 train_time:26401ms step_avg:92.31ms +step:287/1670 train_time:26491ms step_avg:92.30ms +step:288/1670 train_time:26584ms step_avg:92.31ms +step:289/1670 train_time:26676ms step_avg:92.30ms +step:290/1670 train_time:26768ms step_avg:92.30ms +step:291/1670 train_time:26860ms step_avg:92.30ms +step:292/1670 train_time:26950ms step_avg:92.30ms +step:293/1670 train_time:27040ms step_avg:92.29ms +step:294/1670 train_time:27131ms step_avg:92.28ms +step:295/1670 train_time:27222ms step_avg:92.28ms +step:296/1670 train_time:27313ms step_avg:92.27ms +step:297/1670 train_time:27406ms step_avg:92.28ms +step:298/1670 train_time:27498ms step_avg:92.27ms +step:299/1670 train_time:27589ms step_avg:92.27ms +step:300/1670 train_time:27680ms step_avg:92.27ms +step:301/1670 train_time:27771ms 
step_avg:92.26ms +step:302/1670 train_time:27862ms step_avg:92.26ms +step:303/1670 train_time:27953ms step_avg:92.25ms +step:304/1670 train_time:28044ms step_avg:92.25ms +step:305/1670 train_time:28134ms step_avg:92.24ms +step:306/1670 train_time:28226ms step_avg:92.24ms +step:307/1670 train_time:28317ms step_avg:92.24ms +step:308/1670 train_time:28408ms step_avg:92.23ms +step:309/1670 train_time:28499ms step_avg:92.23ms +step:310/1670 train_time:28591ms step_avg:92.23ms +step:311/1670 train_time:28684ms step_avg:92.23ms +step:312/1670 train_time:28775ms step_avg:92.23ms +step:313/1670 train_time:28867ms step_avg:92.23ms +step:314/1670 train_time:28957ms step_avg:92.22ms +step:315/1670 train_time:29048ms step_avg:92.21ms +step:316/1670 train_time:29137ms step_avg:92.21ms +step:317/1670 train_time:29229ms step_avg:92.20ms +step:318/1670 train_time:29319ms step_avg:92.20ms +step:319/1670 train_time:29410ms step_avg:92.20ms +step:320/1670 train_time:29502ms step_avg:92.19ms +step:321/1670 train_time:29594ms step_avg:92.19ms +step:322/1670 train_time:29686ms step_avg:92.19ms +step:323/1670 train_time:29778ms step_avg:92.19ms +step:324/1670 train_time:29870ms step_avg:92.19ms +step:325/1670 train_time:29961ms step_avg:92.19ms +step:326/1670 train_time:30051ms step_avg:92.18ms +step:327/1670 train_time:30141ms step_avg:92.18ms +step:328/1670 train_time:30232ms step_avg:92.17ms +step:329/1670 train_time:30323ms step_avg:92.17ms +step:330/1670 train_time:30414ms step_avg:92.16ms +step:331/1670 train_time:30506ms step_avg:92.16ms +step:332/1670 train_time:30599ms step_avg:92.17ms +step:333/1670 train_time:30690ms step_avg:92.16ms +step:334/1670 train_time:30782ms step_avg:92.16ms +step:335/1670 train_time:30873ms step_avg:92.16ms +step:336/1670 train_time:30966ms step_avg:92.16ms +step:337/1670 train_time:31057ms step_avg:92.16ms +step:338/1670 train_time:31148ms step_avg:92.15ms +step:339/1670 train_time:31239ms step_avg:92.15ms +step:340/1670 train_time:31329ms step_avg:92.14ms +step:341/1670 train_time:31420ms step_avg:92.14ms +step:342/1670 train_time:31510ms step_avg:92.13ms +step:343/1670 train_time:31603ms step_avg:92.14ms +step:344/1670 train_time:31694ms step_avg:92.13ms +step:345/1670 train_time:31787ms step_avg:92.14ms +step:346/1670 train_time:31880ms step_avg:92.14ms +step:347/1670 train_time:31971ms step_avg:92.13ms +step:348/1670 train_time:32063ms step_avg:92.14ms +step:349/1670 train_time:32154ms step_avg:92.13ms +step:350/1670 train_time:32246ms step_avg:92.13ms +step:351/1670 train_time:32337ms step_avg:92.13ms +step:352/1670 train_time:32427ms step_avg:92.12ms +step:353/1670 train_time:32518ms step_avg:92.12ms +step:354/1670 train_time:32609ms step_avg:92.11ms +step:355/1670 train_time:32701ms step_avg:92.12ms +step:356/1670 train_time:32792ms step_avg:92.11ms +step:357/1670 train_time:32884ms step_avg:92.11ms +step:358/1670 train_time:32975ms step_avg:92.11ms +step:359/1670 train_time:33067ms step_avg:92.11ms +step:360/1670 train_time:33158ms step_avg:92.11ms +step:361/1670 train_time:33248ms step_avg:92.10ms +step:362/1670 train_time:33339ms step_avg:92.10ms +step:363/1670 train_time:33429ms step_avg:92.09ms +step:364/1670 train_time:33519ms step_avg:92.08ms +step:365/1670 train_time:33610ms step_avg:92.08ms +step:366/1670 train_time:33702ms step_avg:92.08ms +step:367/1670 train_time:33793ms step_avg:92.08ms +step:368/1670 train_time:33886ms step_avg:92.08ms +step:369/1670 train_time:33978ms step_avg:92.08ms +step:370/1670 train_time:34070ms step_avg:92.08ms +step:371/1670 
train_time:34162ms step_avg:92.08ms +step:372/1670 train_time:34252ms step_avg:92.08ms +step:373/1670 train_time:34344ms step_avg:92.07ms +step:374/1670 train_time:34435ms step_avg:92.07ms +step:375/1670 train_time:34526ms step_avg:92.07ms +step:375/1670 val_loss:3.8157 train_time:34617ms step_avg:92.31ms +step:376/1670 train_time:34633ms step_avg:92.11ms +step:377/1670 train_time:34709ms step_avg:92.07ms +step:378/1670 train_time:34800ms step_avg:92.06ms +step:379/1670 train_time:34891ms step_avg:92.06ms +step:380/1670 train_time:34982ms step_avg:92.06ms +step:381/1670 train_time:35072ms step_avg:92.05ms +step:382/1670 train_time:35162ms step_avg:92.05ms +step:383/1670 train_time:35255ms step_avg:92.05ms +step:384/1670 train_time:35346ms step_avg:92.05ms +step:385/1670 train_time:35437ms step_avg:92.04ms +step:386/1670 train_time:35529ms step_avg:92.04ms +step:387/1670 train_time:35621ms step_avg:92.04ms +step:388/1670 train_time:35713ms step_avg:92.04ms +step:389/1670 train_time:35804ms step_avg:92.04ms +step:390/1670 train_time:35896ms step_avg:92.04ms +step:391/1670 train_time:35987ms step_avg:92.04ms +step:392/1670 train_time:36078ms step_avg:92.03ms +step:393/1670 train_time:36170ms step_avg:92.03ms +step:394/1670 train_time:36260ms step_avg:92.03ms +step:395/1670 train_time:36351ms step_avg:92.03ms +step:396/1670 train_time:36442ms step_avg:92.02ms +step:397/1670 train_time:36534ms step_avg:92.02ms +step:398/1670 train_time:36625ms step_avg:92.02ms +step:399/1670 train_time:36718ms step_avg:92.02ms +step:400/1670 train_time:36809ms step_avg:92.02ms +step:401/1670 train_time:36900ms step_avg:92.02ms +step:402/1670 train_time:36991ms step_avg:92.02ms +step:403/1670 train_time:37082ms step_avg:92.01ms +step:404/1670 train_time:37173ms step_avg:92.01ms +step:405/1670 train_time:37264ms step_avg:92.01ms +step:406/1670 train_time:37354ms step_avg:92.01ms +step:407/1670 train_time:37445ms step_avg:92.00ms +step:408/1670 train_time:37537ms step_avg:92.00ms +step:409/1670 train_time:37628ms step_avg:92.00ms +step:410/1670 train_time:37719ms step_avg:92.00ms +step:411/1670 train_time:37811ms step_avg:92.00ms +step:412/1670 train_time:37902ms step_avg:91.99ms +step:413/1670 train_time:37993ms step_avg:91.99ms +step:414/1670 train_time:38084ms step_avg:91.99ms +step:415/1670 train_time:38175ms step_avg:91.99ms +step:416/1670 train_time:38266ms step_avg:91.98ms +step:417/1670 train_time:38357ms step_avg:91.98ms +step:418/1670 train_time:38448ms step_avg:91.98ms +step:419/1670 train_time:38541ms step_avg:91.98ms +step:420/1670 train_time:38631ms step_avg:91.98ms +step:421/1670 train_time:38722ms step_avg:91.98ms +step:422/1670 train_time:38814ms step_avg:91.98ms +step:423/1670 train_time:38904ms step_avg:91.97ms +step:424/1670 train_time:38995ms step_avg:91.97ms +step:425/1670 train_time:39248ms step_avg:92.35ms +step:426/1670 train_time:39319ms step_avg:92.30ms +step:427/1670 train_time:39409ms step_avg:92.29ms +step:428/1670 train_time:39499ms step_avg:92.29ms +step:429/1670 train_time:39589ms step_avg:92.28ms +step:430/1670 train_time:39679ms step_avg:92.28ms +step:431/1670 train_time:39769ms step_avg:92.27ms +step:432/1670 train_time:39859ms step_avg:92.27ms +step:433/1670 train_time:39950ms step_avg:92.26ms +step:434/1670 train_time:40040ms step_avg:92.26ms +step:435/1670 train_time:40136ms step_avg:92.27ms +step:436/1670 train_time:40232ms step_avg:92.28ms +step:437/1670 train_time:40325ms step_avg:92.28ms +step:438/1670 train_time:40417ms step_avg:92.28ms +step:439/1670 train_time:40507ms 
step_avg:92.27ms +step:440/1670 train_time:40598ms step_avg:92.27ms +step:441/1670 train_time:40688ms step_avg:92.26ms +step:442/1670 train_time:40779ms step_avg:92.26ms +step:443/1670 train_time:40869ms step_avg:92.26ms +step:444/1670 train_time:40959ms step_avg:92.25ms +step:445/1670 train_time:41051ms step_avg:92.25ms +step:446/1670 train_time:41144ms step_avg:92.25ms +step:447/1670 train_time:41238ms step_avg:92.25ms +step:448/1670 train_time:41330ms step_avg:92.25ms +step:449/1670 train_time:41421ms step_avg:92.25ms +step:450/1670 train_time:41512ms step_avg:92.25ms +step:451/1670 train_time:41602ms step_avg:92.24ms +step:452/1670 train_time:41693ms step_avg:92.24ms +step:453/1670 train_time:41783ms step_avg:92.24ms +step:454/1670 train_time:41874ms step_avg:92.23ms +step:455/1670 train_time:41964ms step_avg:92.23ms +step:456/1670 train_time:42055ms step_avg:92.23ms +step:457/1670 train_time:42146ms step_avg:92.22ms +step:458/1670 train_time:42240ms step_avg:92.23ms +step:459/1670 train_time:42333ms step_avg:92.23ms +step:460/1670 train_time:42424ms step_avg:92.23ms +step:461/1670 train_time:42515ms step_avg:92.22ms +step:462/1670 train_time:42606ms step_avg:92.22ms +step:463/1670 train_time:42698ms step_avg:92.22ms +step:464/1670 train_time:42786ms step_avg:92.21ms +step:465/1670 train_time:42878ms step_avg:92.21ms +step:466/1670 train_time:42967ms step_avg:92.20ms +step:467/1670 train_time:43059ms step_avg:92.20ms +step:468/1670 train_time:43151ms step_avg:92.20ms +step:469/1670 train_time:43243ms step_avg:92.20ms +step:470/1670 train_time:43335ms step_avg:92.20ms +step:471/1670 train_time:43426ms step_avg:92.20ms +step:472/1670 train_time:43517ms step_avg:92.20ms +step:473/1670 train_time:43608ms step_avg:92.19ms +step:474/1670 train_time:43699ms step_avg:92.19ms +step:475/1670 train_time:43791ms step_avg:92.19ms +step:476/1670 train_time:43881ms step_avg:92.19ms +step:477/1670 train_time:43971ms step_avg:92.18ms +step:478/1670 train_time:44062ms step_avg:92.18ms +step:479/1670 train_time:44154ms step_avg:92.18ms +step:480/1670 train_time:44245ms step_avg:92.18ms +step:481/1670 train_time:44338ms step_avg:92.18ms +step:482/1670 train_time:44430ms step_avg:92.18ms +step:483/1670 train_time:44522ms step_avg:92.18ms +step:484/1670 train_time:44612ms step_avg:92.17ms +step:485/1670 train_time:44703ms step_avg:92.17ms +step:486/1670 train_time:44794ms step_avg:92.17ms +step:487/1670 train_time:44885ms step_avg:92.17ms +step:488/1670 train_time:44976ms step_avg:92.16ms +step:489/1670 train_time:45066ms step_avg:92.16ms +step:490/1670 train_time:45158ms step_avg:92.16ms +step:491/1670 train_time:45250ms step_avg:92.16ms +step:492/1670 train_time:45341ms step_avg:92.16ms +step:493/1670 train_time:45433ms step_avg:92.16ms +step:494/1670 train_time:45523ms step_avg:92.15ms +step:495/1670 train_time:45615ms step_avg:92.15ms +step:496/1670 train_time:45706ms step_avg:92.15ms +step:497/1670 train_time:45796ms step_avg:92.15ms +step:498/1670 train_time:45887ms step_avg:92.14ms +step:499/1670 train_time:45978ms step_avg:92.14ms +step:500/1670 train_time:46069ms step_avg:92.14ms +step:500/1670 val_loss:3.7139 train_time:46161ms step_avg:92.32ms +step:501/1670 train_time:46178ms step_avg:92.17ms +step:502/1670 train_time:46253ms step_avg:92.14ms +step:503/1670 train_time:46345ms step_avg:92.14ms +step:504/1670 train_time:46435ms step_avg:92.13ms +step:505/1670 train_time:46526ms step_avg:92.13ms +step:506/1670 train_time:46616ms step_avg:92.13ms +step:507/1670 train_time:46707ms step_avg:92.12ms 
+step:508/1670 train_time:46798ms step_avg:92.12ms +step:509/1670 train_time:46888ms step_avg:92.12ms +step:510/1670 train_time:46979ms step_avg:92.12ms +step:511/1670 train_time:47070ms step_avg:92.11ms +step:512/1670 train_time:47163ms step_avg:92.11ms +step:513/1670 train_time:47254ms step_avg:92.11ms +step:514/1670 train_time:47347ms step_avg:92.11ms +step:515/1670 train_time:47438ms step_avg:92.11ms +step:516/1670 train_time:47529ms step_avg:92.11ms +step:517/1670 train_time:47620ms step_avg:92.11ms +step:518/1670 train_time:47710ms step_avg:92.11ms +step:519/1670 train_time:47801ms step_avg:92.10ms +step:520/1670 train_time:47891ms step_avg:92.10ms +step:521/1670 train_time:47983ms step_avg:92.10ms +step:522/1670 train_time:48074ms step_avg:92.09ms +step:523/1670 train_time:48165ms step_avg:92.09ms +step:524/1670 train_time:48256ms step_avg:92.09ms +step:525/1670 train_time:48350ms step_avg:92.09ms +step:526/1670 train_time:48441ms step_avg:92.09ms +step:527/1670 train_time:48533ms step_avg:92.09ms +step:528/1670 train_time:48624ms step_avg:92.09ms +step:529/1670 train_time:48715ms step_avg:92.09ms +step:530/1670 train_time:48806ms step_avg:92.09ms +step:531/1670 train_time:48896ms step_avg:92.08ms +step:532/1670 train_time:48988ms step_avg:92.08ms +step:533/1670 train_time:49079ms step_avg:92.08ms +step:534/1670 train_time:49170ms step_avg:92.08ms +step:535/1670 train_time:49262ms step_avg:92.08ms +step:536/1670 train_time:49354ms step_avg:92.08ms +step:537/1670 train_time:49446ms step_avg:92.08ms +step:538/1670 train_time:49537ms step_avg:92.08ms +step:539/1670 train_time:49627ms step_avg:92.07ms +step:540/1670 train_time:49718ms step_avg:92.07ms +step:541/1670 train_time:49809ms step_avg:92.07ms +step:542/1670 train_time:49899ms step_avg:92.06ms +step:543/1670 train_time:49990ms step_avg:92.06ms +step:544/1670 train_time:50081ms step_avg:92.06ms +step:545/1670 train_time:50172ms step_avg:92.06ms +step:546/1670 train_time:50264ms step_avg:92.06ms +step:547/1670 train_time:50355ms step_avg:92.06ms +step:548/1670 train_time:50447ms step_avg:92.06ms +step:549/1670 train_time:50538ms step_avg:92.05ms +step:550/1670 train_time:50628ms step_avg:92.05ms +step:551/1670 train_time:50719ms step_avg:92.05ms +step:552/1670 train_time:50812ms step_avg:92.05ms +step:553/1670 train_time:50903ms step_avg:92.05ms +step:554/1670 train_time:50994ms step_avg:92.05ms +step:555/1670 train_time:51085ms step_avg:92.04ms +step:556/1670 train_time:51175ms step_avg:92.04ms +step:557/1670 train_time:51266ms step_avg:92.04ms +step:558/1670 train_time:51556ms step_avg:92.39ms +step:559/1670 train_time:51626ms step_avg:92.35ms +step:560/1670 train_time:51716ms step_avg:92.35ms +step:561/1670 train_time:51807ms step_avg:92.35ms +step:562/1670 train_time:51898ms step_avg:92.35ms +step:563/1670 train_time:51989ms step_avg:92.34ms +step:564/1670 train_time:52081ms step_avg:92.34ms +step:565/1670 train_time:52172ms step_avg:92.34ms +step:566/1670 train_time:52263ms step_avg:92.34ms +step:567/1670 train_time:52354ms step_avg:92.34ms +step:568/1670 train_time:52456ms step_avg:92.35ms +step:569/1670 train_time:52555ms step_avg:92.36ms +step:570/1670 train_time:52648ms step_avg:92.37ms +step:571/1670 train_time:52740ms step_avg:92.36ms +step:572/1670 train_time:52832ms step_avg:92.36ms +step:573/1670 train_time:52923ms step_avg:92.36ms +step:574/1670 train_time:53014ms step_avg:92.36ms +step:575/1670 train_time:53105ms step_avg:92.36ms +step:576/1670 train_time:53197ms step_avg:92.36ms +step:577/1670 train_time:53290ms 
step_avg:92.36ms +step:578/1670 train_time:53382ms step_avg:92.36ms +step:579/1670 train_time:53477ms step_avg:92.36ms +step:580/1670 train_time:53572ms step_avg:92.37ms +step:581/1670 train_time:53665ms step_avg:92.37ms +step:582/1670 train_time:53757ms step_avg:92.37ms +step:583/1670 train_time:53851ms step_avg:92.37ms +step:584/1670 train_time:53943ms step_avg:92.37ms +step:585/1670 train_time:54034ms step_avg:92.37ms +step:586/1670 train_time:54126ms step_avg:92.36ms +step:587/1670 train_time:54217ms step_avg:92.36ms +step:588/1670 train_time:54310ms step_avg:92.36ms +step:589/1670 train_time:54403ms step_avg:92.37ms +step:590/1670 train_time:54496ms step_avg:92.37ms +step:591/1670 train_time:54591ms step_avg:92.37ms +step:592/1670 train_time:54683ms step_avg:92.37ms +step:593/1670 train_time:54775ms step_avg:92.37ms +step:594/1670 train_time:54868ms step_avg:92.37ms +step:595/1670 train_time:54959ms step_avg:92.37ms +step:596/1670 train_time:55052ms step_avg:92.37ms +step:597/1670 train_time:55144ms step_avg:92.37ms +step:598/1670 train_time:55235ms step_avg:92.37ms +step:599/1670 train_time:55329ms step_avg:92.37ms +step:600/1670 train_time:55421ms step_avg:92.37ms +step:601/1670 train_time:55514ms step_avg:92.37ms +step:602/1670 train_time:55608ms step_avg:92.37ms +step:603/1670 train_time:55701ms step_avg:92.37ms +step:604/1670 train_time:55793ms step_avg:92.37ms +step:605/1670 train_time:55885ms step_avg:92.37ms +step:606/1670 train_time:55977ms step_avg:92.37ms +step:607/1670 train_time:56069ms step_avg:92.37ms +step:608/1670 train_time:56160ms step_avg:92.37ms +step:609/1670 train_time:56253ms step_avg:92.37ms +step:610/1670 train_time:56346ms step_avg:92.37ms +step:611/1670 train_time:56438ms step_avg:92.37ms +step:612/1670 train_time:56531ms step_avg:92.37ms +step:613/1670 train_time:56624ms step_avg:92.37ms +step:614/1670 train_time:56716ms step_avg:92.37ms +step:615/1670 train_time:56809ms step_avg:92.37ms +step:616/1670 train_time:56901ms step_avg:92.37ms +step:617/1670 train_time:56992ms step_avg:92.37ms +step:618/1670 train_time:57086ms step_avg:92.37ms +step:619/1670 train_time:57178ms step_avg:92.37ms +step:620/1670 train_time:57271ms step_avg:92.37ms +step:621/1670 train_time:57364ms step_avg:92.37ms +step:622/1670 train_time:57456ms step_avg:92.37ms +step:623/1670 train_time:57550ms step_avg:92.38ms +step:624/1670 train_time:57642ms step_avg:92.38ms +step:625/1670 train_time:57734ms step_avg:92.37ms +step:625/1670 val_loss:3.6124 train_time:57827ms step_avg:92.52ms +step:626/1670 train_time:57844ms step_avg:92.40ms +step:627/1670 train_time:57923ms step_avg:92.38ms +step:628/1670 train_time:58025ms step_avg:92.40ms +step:629/1670 train_time:58118ms step_avg:92.40ms +step:630/1670 train_time:58211ms step_avg:92.40ms +step:631/1670 train_time:58302ms step_avg:92.40ms +step:632/1670 train_time:58394ms step_avg:92.40ms +step:633/1670 train_time:58484ms step_avg:92.39ms +step:634/1670 train_time:58575ms step_avg:92.39ms +step:635/1670 train_time:58666ms step_avg:92.39ms +step:636/1670 train_time:58758ms step_avg:92.39ms +step:637/1670 train_time:58852ms step_avg:92.39ms +step:638/1670 train_time:58949ms step_avg:92.40ms +step:639/1670 train_time:59184ms step_avg:92.62ms +step:640/1670 train_time:59256ms step_avg:92.59ms +step:641/1670 train_time:59347ms step_avg:92.58ms +step:642/1670 train_time:59439ms step_avg:92.58ms +step:643/1670 train_time:59530ms step_avg:92.58ms +step:644/1670 train_time:59621ms step_avg:92.58ms +step:645/1670 train_time:59712ms step_avg:92.58ms 
+step:646/1670 train_time:59804ms step_avg:92.58ms +step:647/1670 train_time:59895ms step_avg:92.57ms +step:648/1670 train_time:59986ms step_avg:92.57ms +step:649/1670 train_time:60086ms step_avg:92.58ms +step:650/1670 train_time:60184ms step_avg:92.59ms +step:651/1670 train_time:60277ms step_avg:92.59ms +step:652/1670 train_time:60369ms step_avg:92.59ms +step:653/1670 train_time:60461ms step_avg:92.59ms +step:654/1670 train_time:60554ms step_avg:92.59ms +step:655/1670 train_time:60646ms step_avg:92.59ms +step:656/1670 train_time:60737ms step_avg:92.59ms +step:657/1670 train_time:60827ms step_avg:92.58ms +step:658/1670 train_time:60919ms step_avg:92.58ms +step:659/1670 train_time:61011ms step_avg:92.58ms +step:660/1670 train_time:61105ms step_avg:92.58ms +step:661/1670 train_time:61199ms step_avg:92.59ms +step:662/1670 train_time:61292ms step_avg:92.59ms +step:663/1670 train_time:61384ms step_avg:92.59ms +step:664/1670 train_time:61477ms step_avg:92.59ms +step:665/1670 train_time:61568ms step_avg:92.58ms +step:666/1670 train_time:61660ms step_avg:92.58ms +step:667/1670 train_time:61752ms step_avg:92.58ms +step:668/1670 train_time:61843ms step_avg:92.58ms +step:669/1670 train_time:61935ms step_avg:92.58ms +step:670/1670 train_time:62027ms step_avg:92.58ms +step:671/1670 train_time:62121ms step_avg:92.58ms +step:672/1670 train_time:62214ms step_avg:92.58ms +step:673/1670 train_time:62307ms step_avg:92.58ms +step:674/1670 train_time:62400ms step_avg:92.58ms +step:675/1670 train_time:62492ms step_avg:92.58ms +step:676/1670 train_time:62584ms step_avg:92.58ms +step:677/1670 train_time:62677ms step_avg:92.58ms +step:678/1670 train_time:62768ms step_avg:92.58ms +step:679/1670 train_time:62860ms step_avg:92.58ms +step:680/1670 train_time:62952ms step_avg:92.58ms +step:681/1670 train_time:63044ms step_avg:92.58ms +step:682/1670 train_time:63137ms step_avg:92.58ms +step:683/1670 train_time:63230ms step_avg:92.58ms +step:684/1670 train_time:63323ms step_avg:92.58ms +step:685/1670 train_time:63416ms step_avg:92.58ms +step:686/1670 train_time:63509ms step_avg:92.58ms +step:687/1670 train_time:63602ms step_avg:92.58ms +step:688/1670 train_time:63694ms step_avg:92.58ms +step:689/1670 train_time:63785ms step_avg:92.58ms +step:690/1670 train_time:63877ms step_avg:92.58ms +step:691/1670 train_time:63970ms step_avg:92.58ms +step:692/1670 train_time:64062ms step_avg:92.58ms +step:693/1670 train_time:64156ms step_avg:92.58ms +step:694/1670 train_time:64247ms step_avg:92.58ms +step:695/1670 train_time:64341ms step_avg:92.58ms +step:696/1670 train_time:64435ms step_avg:92.58ms +step:697/1670 train_time:64527ms step_avg:92.58ms +step:698/1670 train_time:64621ms step_avg:92.58ms +step:699/1670 train_time:64713ms step_avg:92.58ms +step:700/1670 train_time:64805ms step_avg:92.58ms +step:701/1670 train_time:64897ms step_avg:92.58ms +step:702/1670 train_time:64990ms step_avg:92.58ms +step:703/1670 train_time:65083ms step_avg:92.58ms +step:704/1670 train_time:65175ms step_avg:92.58ms +step:705/1670 train_time:65267ms step_avg:92.58ms +step:706/1670 train_time:65360ms step_avg:92.58ms +step:707/1670 train_time:65453ms step_avg:92.58ms +step:708/1670 train_time:65545ms step_avg:92.58ms +step:709/1670 train_time:65638ms step_avg:92.58ms +step:710/1670 train_time:65729ms step_avg:92.58ms +step:711/1670 train_time:65822ms step_avg:92.58ms +step:712/1670 train_time:65915ms step_avg:92.58ms +step:713/1670 train_time:66007ms step_avg:92.58ms +step:714/1670 train_time:66100ms step_avg:92.58ms +step:715/1670 train_time:66192ms 
step_avg:92.58ms +step:716/1670 train_time:66285ms step_avg:92.58ms +step:717/1670 train_time:66377ms step_avg:92.58ms +step:718/1670 train_time:66470ms step_avg:92.58ms +step:719/1670 train_time:66563ms step_avg:92.58ms +step:720/1670 train_time:66656ms step_avg:92.58ms +step:721/1670 train_time:66748ms step_avg:92.58ms +step:722/1670 train_time:66840ms step_avg:92.58ms +step:723/1670 train_time:66933ms step_avg:92.58ms +step:724/1670 train_time:67025ms step_avg:92.58ms +step:725/1670 train_time:67118ms step_avg:92.58ms +step:726/1670 train_time:67210ms step_avg:92.58ms +step:727/1670 train_time:67303ms step_avg:92.58ms +step:728/1670 train_time:67396ms step_avg:92.58ms +step:729/1670 train_time:67488ms step_avg:92.58ms +step:730/1670 train_time:67581ms step_avg:92.58ms +step:731/1670 train_time:67674ms step_avg:92.58ms +step:732/1670 train_time:67766ms step_avg:92.58ms +step:733/1670 train_time:67859ms step_avg:92.58ms +step:734/1670 train_time:67951ms step_avg:92.58ms +step:735/1670 train_time:68044ms step_avg:92.58ms +step:736/1670 train_time:68136ms step_avg:92.58ms +step:737/1670 train_time:68228ms step_avg:92.58ms +step:738/1670 train_time:68323ms step_avg:92.58ms +step:739/1670 train_time:68416ms step_avg:92.58ms +step:740/1670 train_time:68507ms step_avg:92.58ms +step:741/1670 train_time:68601ms step_avg:92.58ms +step:742/1670 train_time:68692ms step_avg:92.58ms +step:743/1670 train_time:68785ms step_avg:92.58ms +step:744/1670 train_time:68878ms step_avg:92.58ms +step:745/1670 train_time:68969ms step_avg:92.58ms +step:746/1670 train_time:69063ms step_avg:92.58ms +step:747/1670 train_time:69156ms step_avg:92.58ms +step:748/1670 train_time:69248ms step_avg:92.58ms +step:749/1670 train_time:69341ms step_avg:92.58ms +step:750/1670 train_time:69433ms step_avg:92.58ms +step:750/1670 val_loss:3.5594 train_time:69525ms step_avg:92.70ms +step:751/1670 train_time:69542ms step_avg:92.60ms +step:752/1670 train_time:69618ms step_avg:92.58ms +step:753/1670 train_time:69711ms step_avg:92.58ms +step:754/1670 train_time:69803ms step_avg:92.58ms +step:755/1670 train_time:69895ms step_avg:92.58ms +step:756/1670 train_time:69987ms step_avg:92.58ms +step:757/1670 train_time:70078ms step_avg:92.57ms +step:758/1670 train_time:70171ms step_avg:92.57ms +step:759/1670 train_time:70263ms step_avg:92.57ms +step:760/1670 train_time:70355ms step_avg:92.57ms +step:761/1670 train_time:70449ms step_avg:92.57ms +step:762/1670 train_time:70543ms step_avg:92.58ms +step:763/1670 train_time:70636ms step_avg:92.58ms +step:764/1670 train_time:70728ms step_avg:92.58ms +step:765/1670 train_time:70820ms step_avg:92.58ms +step:766/1670 train_time:70913ms step_avg:92.58ms +step:767/1670 train_time:71005ms step_avg:92.58ms +step:768/1670 train_time:71097ms step_avg:92.57ms +step:769/1670 train_time:71190ms step_avg:92.57ms +step:770/1670 train_time:71281ms step_avg:92.57ms +step:771/1670 train_time:71375ms step_avg:92.57ms +step:772/1670 train_time:71468ms step_avg:92.58ms +step:773/1670 train_time:71563ms step_avg:92.58ms +step:774/1670 train_time:71656ms step_avg:92.58ms +step:775/1670 train_time:71748ms step_avg:92.58ms +step:776/1670 train_time:71841ms step_avg:92.58ms +step:777/1670 train_time:71933ms step_avg:92.58ms +step:778/1670 train_time:72025ms step_avg:92.58ms +step:779/1670 train_time:72118ms step_avg:92.58ms +step:780/1670 train_time:72210ms step_avg:92.58ms +step:781/1670 train_time:72302ms step_avg:92.58ms +step:782/1670 train_time:72396ms step_avg:92.58ms +step:783/1670 train_time:72490ms step_avg:92.58ms 
+step:784/1670 train_time:72582ms step_avg:92.58ms +step:785/1670 train_time:72675ms step_avg:92.58ms +step:786/1670 train_time:72767ms step_avg:92.58ms +step:787/1670 train_time:72859ms step_avg:92.58ms +step:788/1670 train_time:72951ms step_avg:92.58ms +step:789/1670 train_time:73043ms step_avg:92.58ms +step:790/1670 train_time:73136ms step_avg:92.58ms +step:791/1670 train_time:73229ms step_avg:92.58ms +step:792/1670 train_time:73320ms step_avg:92.58ms +step:793/1670 train_time:73414ms step_avg:92.58ms +step:794/1670 train_time:73508ms step_avg:92.58ms +step:795/1670 train_time:73600ms step_avg:92.58ms +step:796/1670 train_time:73693ms step_avg:92.58ms +step:797/1670 train_time:73786ms step_avg:92.58ms +step:798/1670 train_time:73878ms step_avg:92.58ms +step:799/1670 train_time:73971ms step_avg:92.58ms +step:800/1670 train_time:74063ms step_avg:92.58ms +step:801/1670 train_time:74156ms step_avg:92.58ms +step:802/1670 train_time:74248ms step_avg:92.58ms +step:803/1670 train_time:74340ms step_avg:92.58ms +step:804/1670 train_time:74433ms step_avg:92.58ms +step:805/1670 train_time:74526ms step_avg:92.58ms +step:806/1670 train_time:74618ms step_avg:92.58ms +step:807/1670 train_time:74711ms step_avg:92.58ms +step:808/1670 train_time:74803ms step_avg:92.58ms +step:809/1670 train_time:74896ms step_avg:92.58ms +step:810/1670 train_time:74988ms step_avg:92.58ms +step:811/1670 train_time:75080ms step_avg:92.58ms +step:812/1670 train_time:75174ms step_avg:92.58ms +step:813/1670 train_time:75267ms step_avg:92.58ms +step:814/1670 train_time:75359ms step_avg:92.58ms +step:815/1670 train_time:75452ms step_avg:92.58ms +step:816/1670 train_time:75544ms step_avg:92.58ms +step:817/1670 train_time:75638ms step_avg:92.58ms +step:818/1670 train_time:75731ms step_avg:92.58ms +step:819/1670 train_time:75823ms step_avg:92.58ms +step:820/1670 train_time:75917ms step_avg:92.58ms +step:821/1670 train_time:76009ms step_avg:92.58ms +step:822/1670 train_time:76101ms step_avg:92.58ms +step:823/1670 train_time:76193ms step_avg:92.58ms +step:824/1670 train_time:76286ms step_avg:92.58ms +step:825/1670 train_time:76378ms step_avg:92.58ms +step:826/1670 train_time:76470ms step_avg:92.58ms +step:827/1670 train_time:76563ms step_avg:92.58ms +step:828/1670 train_time:76654ms step_avg:92.58ms +step:829/1670 train_time:76747ms step_avg:92.58ms +step:830/1670 train_time:76839ms step_avg:92.58ms +step:831/1670 train_time:76933ms step_avg:92.58ms +step:832/1670 train_time:77026ms step_avg:92.58ms +step:833/1670 train_time:77119ms step_avg:92.58ms +step:834/1670 train_time:77212ms step_avg:92.58ms +step:835/1670 train_time:77305ms step_avg:92.58ms +step:836/1670 train_time:77397ms step_avg:92.58ms +step:837/1670 train_time:77490ms step_avg:92.58ms +step:838/1670 train_time:77581ms step_avg:92.58ms +step:839/1670 train_time:77675ms step_avg:92.58ms +step:840/1670 train_time:77768ms step_avg:92.58ms +step:841/1670 train_time:77860ms step_avg:92.58ms +step:842/1670 train_time:77952ms step_avg:92.58ms +step:843/1670 train_time:78045ms step_avg:92.58ms +step:844/1670 train_time:78138ms step_avg:92.58ms +step:845/1670 train_time:78232ms step_avg:92.58ms +step:846/1670 train_time:78323ms step_avg:92.58ms +step:847/1670 train_time:78418ms step_avg:92.58ms +step:848/1670 train_time:78509ms step_avg:92.58ms +step:849/1670 train_time:78602ms step_avg:92.58ms +step:850/1670 train_time:78695ms step_avg:92.58ms +step:851/1670 train_time:78943ms step_avg:92.76ms +step:852/1670 train_time:79017ms step_avg:92.74ms +step:853/1670 train_time:79109ms 
step_avg:92.74ms +step:854/1670 train_time:79200ms step_avg:92.74ms +step:855/1670 train_time:79291ms step_avg:92.74ms +step:856/1670 train_time:79383ms step_avg:92.74ms +step:857/1670 train_time:79475ms step_avg:92.74ms +step:858/1670 train_time:79566ms step_avg:92.73ms +step:859/1670 train_time:79657ms step_avg:92.73ms +step:860/1670 train_time:79748ms step_avg:92.73ms +step:861/1670 train_time:79846ms step_avg:92.74ms +step:862/1670 train_time:79945ms step_avg:92.74ms +step:863/1670 train_time:80039ms step_avg:92.74ms +step:864/1670 train_time:80131ms step_avg:92.74ms +step:865/1670 train_time:80223ms step_avg:92.74ms +step:866/1670 train_time:80316ms step_avg:92.74ms +step:867/1670 train_time:80408ms step_avg:92.74ms +step:868/1670 train_time:80499ms step_avg:92.74ms +step:869/1670 train_time:80590ms step_avg:92.74ms +step:870/1670 train_time:80681ms step_avg:92.74ms +step:871/1670 train_time:80775ms step_avg:92.74ms +step:872/1670 train_time:80870ms step_avg:92.74ms +step:873/1670 train_time:80964ms step_avg:92.74ms +step:874/1670 train_time:81058ms step_avg:92.74ms +step:875/1670 train_time:81150ms step_avg:92.74ms +step:875/1670 val_loss:3.5154 train_time:81241ms step_avg:92.85ms +step:876/1670 train_time:81258ms step_avg:92.76ms +step:877/1670 train_time:81334ms step_avg:92.74ms +step:878/1670 train_time:81428ms step_avg:92.74ms +step:879/1670 train_time:81519ms step_avg:92.74ms +step:880/1670 train_time:81610ms step_avg:92.74ms +step:881/1670 train_time:81701ms step_avg:92.74ms +step:882/1670 train_time:81793ms step_avg:92.74ms +step:883/1670 train_time:81884ms step_avg:92.73ms +step:884/1670 train_time:81976ms step_avg:92.73ms +step:885/1670 train_time:82070ms step_avg:92.73ms +step:886/1670 train_time:82164ms step_avg:92.74ms +step:887/1670 train_time:82259ms step_avg:92.74ms +step:888/1670 train_time:82353ms step_avg:92.74ms +step:889/1670 train_time:82446ms step_avg:92.74ms +step:890/1670 train_time:82538ms step_avg:92.74ms +step:891/1670 train_time:82630ms step_avg:92.74ms +step:892/1670 train_time:82721ms step_avg:92.74ms +step:893/1670 train_time:82813ms step_avg:92.74ms +step:894/1670 train_time:82905ms step_avg:92.73ms +step:895/1670 train_time:82999ms step_avg:92.74ms +step:896/1670 train_time:83094ms step_avg:92.74ms +step:897/1670 train_time:83187ms step_avg:92.74ms +step:898/1670 train_time:83280ms step_avg:92.74ms +step:899/1670 train_time:83373ms step_avg:92.74ms +step:900/1670 train_time:83466ms step_avg:92.74ms +step:901/1670 train_time:83558ms step_avg:92.74ms +step:902/1670 train_time:83652ms step_avg:92.74ms +step:903/1670 train_time:83743ms step_avg:92.74ms +step:904/1670 train_time:83835ms step_avg:92.74ms +step:905/1670 train_time:83927ms step_avg:92.74ms +step:906/1670 train_time:84020ms step_avg:92.74ms +step:907/1670 train_time:84113ms step_avg:92.74ms +step:908/1670 train_time:84205ms step_avg:92.74ms +step:909/1670 train_time:84300ms step_avg:92.74ms +step:910/1670 train_time:84393ms step_avg:92.74ms +step:911/1670 train_time:84485ms step_avg:92.74ms +step:912/1670 train_time:84577ms step_avg:92.74ms +step:913/1670 train_time:84670ms step_avg:92.74ms +step:914/1670 train_time:84762ms step_avg:92.74ms +step:915/1670 train_time:84854ms step_avg:92.74ms +step:916/1670 train_time:84946ms step_avg:92.74ms +step:917/1670 train_time:85038ms step_avg:92.73ms +step:918/1670 train_time:85131ms step_avg:92.74ms +step:919/1670 train_time:85223ms step_avg:92.74ms +step:920/1670 train_time:85317ms step_avg:92.74ms +step:921/1670 train_time:85410ms step_avg:92.74ms 
+step:922/1670 train_time:85502ms step_avg:92.73ms +step:923/1670 train_time:85595ms step_avg:92.74ms +step:924/1670 train_time:85687ms step_avg:92.74ms +step:925/1670 train_time:85780ms step_avg:92.73ms +step:926/1670 train_time:85872ms step_avg:92.73ms +step:927/1670 train_time:85963ms step_avg:92.73ms +step:928/1670 train_time:86057ms step_avg:92.73ms +step:929/1670 train_time:86149ms step_avg:92.73ms +step:930/1670 train_time:86241ms step_avg:92.73ms +step:931/1670 train_time:86334ms step_avg:92.73ms +step:932/1670 train_time:86426ms step_avg:92.73ms +step:933/1670 train_time:86519ms step_avg:92.73ms +step:934/1670 train_time:86612ms step_avg:92.73ms +step:935/1670 train_time:86704ms step_avg:92.73ms +step:936/1670 train_time:86797ms step_avg:92.73ms +step:937/1670 train_time:86889ms step_avg:92.73ms +step:938/1670 train_time:86981ms step_avg:92.73ms +step:939/1670 train_time:87074ms step_avg:92.73ms +step:940/1670 train_time:87167ms step_avg:92.73ms +step:941/1670 train_time:87260ms step_avg:92.73ms +step:942/1670 train_time:87352ms step_avg:92.73ms +step:943/1670 train_time:87445ms step_avg:92.73ms +step:944/1670 train_time:87537ms step_avg:92.73ms +step:945/1670 train_time:87630ms step_avg:92.73ms +step:946/1670 train_time:87723ms step_avg:92.73ms +step:947/1670 train_time:87817ms step_avg:92.73ms +step:948/1670 train_time:87909ms step_avg:92.73ms +step:949/1670 train_time:88001ms step_avg:92.73ms +step:950/1670 train_time:88095ms step_avg:92.73ms +step:951/1670 train_time:88188ms step_avg:92.73ms +step:952/1670 train_time:88280ms step_avg:92.73ms +step:953/1670 train_time:88373ms step_avg:92.73ms +step:954/1670 train_time:88466ms step_avg:92.73ms +step:955/1670 train_time:88558ms step_avg:92.73ms +step:956/1670 train_time:88651ms step_avg:92.73ms +step:957/1670 train_time:88743ms step_avg:92.73ms +step:958/1670 train_time:88835ms step_avg:92.73ms +step:959/1670 train_time:88928ms step_avg:92.73ms +step:960/1670 train_time:89020ms step_avg:92.73ms +step:961/1670 train_time:89113ms step_avg:92.73ms +step:962/1670 train_time:89205ms step_avg:92.73ms +step:963/1670 train_time:89299ms step_avg:92.73ms +step:964/1670 train_time:89392ms step_avg:92.73ms +step:965/1670 train_time:89484ms step_avg:92.73ms +step:966/1670 train_time:89577ms step_avg:92.73ms +step:967/1670 train_time:89670ms step_avg:92.73ms +step:968/1670 train_time:89762ms step_avg:92.73ms +step:969/1670 train_time:89855ms step_avg:92.73ms +step:970/1670 train_time:89947ms step_avg:92.73ms +step:971/1670 train_time:90039ms step_avg:92.73ms +step:972/1670 train_time:90132ms step_avg:92.73ms +step:973/1670 train_time:90224ms step_avg:92.73ms +step:974/1670 train_time:90317ms step_avg:92.73ms +step:975/1670 train_time:90410ms step_avg:92.73ms +step:976/1670 train_time:90502ms step_avg:92.73ms +step:977/1670 train_time:90595ms step_avg:92.73ms +step:978/1670 train_time:90687ms step_avg:92.73ms +step:979/1670 train_time:90779ms step_avg:92.73ms +step:980/1670 train_time:90872ms step_avg:92.73ms +step:981/1670 train_time:90963ms step_avg:92.72ms +step:982/1670 train_time:91057ms step_avg:92.73ms +step:983/1670 train_time:91149ms step_avg:92.73ms +step:984/1670 train_time:91242ms step_avg:92.73ms +step:985/1670 train_time:91334ms step_avg:92.73ms +step:986/1670 train_time:91426ms step_avg:92.72ms +step:987/1670 train_time:91519ms step_avg:92.72ms +step:988/1670 train_time:91611ms step_avg:92.72ms +step:989/1670 train_time:91702ms step_avg:92.72ms +step:990/1670 train_time:91796ms step_avg:92.72ms +step:991/1670 train_time:91890ms 
step_avg:92.72ms +step:992/1670 train_time:91982ms step_avg:92.72ms +step:993/1670 train_time:92075ms step_avg:92.72ms +step:994/1670 train_time:92167ms step_avg:92.72ms +step:995/1670 train_time:92260ms step_avg:92.72ms +step:996/1670 train_time:92353ms step_avg:92.72ms +step:997/1670 train_time:92445ms step_avg:92.72ms +step:998/1670 train_time:92537ms step_avg:92.72ms +step:999/1670 train_time:92630ms step_avg:92.72ms +step:1000/1670 train_time:92723ms step_avg:92.72ms +step:1000/1670 val_loss:3.4673 train_time:92816ms step_avg:92.82ms +step:1001/1670 train_time:92834ms step_avg:92.74ms +step:1002/1670 train_time:92911ms step_avg:92.73ms +step:1003/1670 train_time:93002ms step_avg:92.72ms +step:1004/1670 train_time:93094ms step_avg:92.72ms +step:1005/1670 train_time:93186ms step_avg:92.72ms +step:1006/1670 train_time:93278ms step_avg:92.72ms +step:1007/1670 train_time:93370ms step_avg:92.72ms +step:1008/1670 train_time:93462ms step_avg:92.72ms +step:1009/1670 train_time:93555ms step_avg:92.72ms +step:1010/1670 train_time:93647ms step_avg:92.72ms +step:1011/1670 train_time:93740ms step_avg:92.72ms +step:1012/1670 train_time:93835ms step_avg:92.72ms +step:1013/1670 train_time:93930ms step_avg:92.72ms +step:1014/1670 train_time:94022ms step_avg:92.72ms +step:1015/1670 train_time:94114ms step_avg:92.72ms +step:1016/1670 train_time:94207ms step_avg:92.72ms +step:1017/1670 train_time:94298ms step_avg:92.72ms +step:1018/1670 train_time:94391ms step_avg:92.72ms +step:1019/1670 train_time:94483ms step_avg:92.72ms +step:1020/1670 train_time:94576ms step_avg:92.72ms +step:1021/1670 train_time:94668ms step_avg:92.72ms +step:1022/1670 train_time:94761ms step_avg:92.72ms +step:1023/1670 train_time:94855ms step_avg:92.72ms +step:1024/1670 train_time:94948ms step_avg:92.72ms +step:1025/1670 train_time:95040ms step_avg:92.72ms +step:1026/1670 train_time:95134ms step_avg:92.72ms +step:1027/1670 train_time:95227ms step_avg:92.72ms +step:1028/1670 train_time:95319ms step_avg:92.72ms +step:1029/1670 train_time:95411ms step_avg:92.72ms +step:1030/1670 train_time:95503ms step_avg:92.72ms +step:1031/1670 train_time:95596ms step_avg:92.72ms +step:1032/1670 train_time:95688ms step_avg:92.72ms +step:1033/1670 train_time:95780ms step_avg:92.72ms +step:1034/1670 train_time:95874ms step_avg:92.72ms +step:1035/1670 train_time:95968ms step_avg:92.72ms +step:1036/1670 train_time:96060ms step_avg:92.72ms +step:1037/1670 train_time:96154ms step_avg:92.72ms +step:1038/1670 train_time:96247ms step_avg:92.72ms +step:1039/1670 train_time:96339ms step_avg:92.72ms +step:1040/1670 train_time:96430ms step_avg:92.72ms +step:1041/1670 train_time:96523ms step_avg:92.72ms +step:1042/1670 train_time:96616ms step_avg:92.72ms +step:1043/1670 train_time:96708ms step_avg:92.72ms +step:1044/1670 train_time:96800ms step_avg:92.72ms +step:1045/1670 train_time:96893ms step_avg:92.72ms +step:1046/1670 train_time:96985ms step_avg:92.72ms +step:1047/1670 train_time:97078ms step_avg:92.72ms +step:1048/1670 train_time:97170ms step_avg:92.72ms +step:1049/1670 train_time:97262ms step_avg:92.72ms +step:1050/1670 train_time:97356ms step_avg:92.72ms +step:1051/1670 train_time:97448ms step_avg:92.72ms +step:1052/1670 train_time:97541ms step_avg:92.72ms +step:1053/1670 train_time:97635ms step_avg:92.72ms +step:1054/1670 train_time:97728ms step_avg:92.72ms +step:1055/1670 train_time:97820ms step_avg:92.72ms +step:1056/1670 train_time:97912ms step_avg:92.72ms +step:1057/1670 train_time:98005ms step_avg:92.72ms +step:1058/1670 train_time:98097ms 
step_avg:92.72ms +step:1059/1670 train_time:98190ms step_avg:92.72ms +step:1060/1670 train_time:98282ms step_avg:92.72ms +step:1061/1670 train_time:98375ms step_avg:92.72ms +step:1062/1670 train_time:98624ms step_avg:92.87ms +step:1063/1670 train_time:98697ms step_avg:92.85ms +step:1064/1670 train_time:98788ms step_avg:92.85ms +step:1065/1670 train_time:98879ms step_avg:92.84ms +step:1066/1670 train_time:98970ms step_avg:92.84ms +step:1067/1670 train_time:99061ms step_avg:92.84ms +step:1068/1670 train_time:99153ms step_avg:92.84ms +step:1069/1670 train_time:99244ms step_avg:92.84ms +step:1070/1670 train_time:99335ms step_avg:92.84ms +step:1071/1670 train_time:99426ms step_avg:92.83ms +step:1072/1670 train_time:99523ms step_avg:92.84ms +step:1073/1670 train_time:99619ms step_avg:92.84ms +step:1074/1670 train_time:99714ms step_avg:92.84ms +step:1075/1670 train_time:99807ms step_avg:92.84ms +step:1076/1670 train_time:99898ms step_avg:92.84ms +step:1077/1670 train_time:99991ms step_avg:92.84ms +step:1078/1670 train_time:100082ms step_avg:92.84ms +step:1079/1670 train_time:100174ms step_avg:92.84ms +step:1080/1670 train_time:100265ms step_avg:92.84ms +step:1081/1670 train_time:100356ms step_avg:92.84ms +step:1082/1670 train_time:100452ms step_avg:92.84ms +step:1083/1670 train_time:100546ms step_avg:92.84ms +step:1084/1670 train_time:100640ms step_avg:92.84ms +step:1085/1670 train_time:100734ms step_avg:92.84ms +step:1086/1670 train_time:100827ms step_avg:92.84ms +step:1087/1670 train_time:100919ms step_avg:92.84ms +step:1088/1670 train_time:101012ms step_avg:92.84ms +step:1089/1670 train_time:101103ms step_avg:92.84ms +step:1090/1670 train_time:101194ms step_avg:92.84ms +step:1091/1670 train_time:101286ms step_avg:92.84ms +step:1092/1670 train_time:101378ms step_avg:92.84ms +step:1093/1670 train_time:101470ms step_avg:92.84ms +step:1094/1670 train_time:101563ms step_avg:92.84ms +step:1095/1670 train_time:101657ms step_avg:92.84ms +step:1096/1670 train_time:101751ms step_avg:92.84ms +step:1097/1670 train_time:101843ms step_avg:92.84ms +step:1098/1670 train_time:101936ms step_avg:92.84ms +step:1099/1670 train_time:102029ms step_avg:92.84ms +step:1100/1670 train_time:102120ms step_avg:92.84ms +step:1101/1670 train_time:102212ms step_avg:92.84ms +step:1102/1670 train_time:102304ms step_avg:92.83ms +step:1103/1670 train_time:102396ms step_avg:92.83ms +step:1104/1670 train_time:102489ms step_avg:92.83ms +step:1105/1670 train_time:102581ms step_avg:92.83ms +step:1106/1670 train_time:102675ms step_avg:92.83ms +step:1107/1670 train_time:102768ms step_avg:92.83ms +step:1108/1670 train_time:102860ms step_avg:92.83ms +step:1109/1670 train_time:102955ms step_avg:92.84ms +step:1110/1670 train_time:103047ms step_avg:92.83ms +step:1111/1670 train_time:103139ms step_avg:92.83ms +step:1112/1670 train_time:103231ms step_avg:92.83ms +step:1113/1670 train_time:103323ms step_avg:92.83ms +step:1114/1670 train_time:103415ms step_avg:92.83ms +step:1115/1670 train_time:103701ms step_avg:93.01ms +step:1116/1670 train_time:103778ms step_avg:92.99ms +step:1117/1670 train_time:103871ms step_avg:92.99ms +step:1118/1670 train_time:103962ms step_avg:92.99ms +step:1119/1670 train_time:104054ms step_avg:92.99ms +step:1120/1670 train_time:104146ms step_avg:92.99ms +step:1121/1670 train_time:104237ms step_avg:92.99ms +step:1122/1670 train_time:104329ms step_avg:92.99ms +step:1123/1670 train_time:104421ms step_avg:92.98ms +step:1124/1670 train_time:104514ms step_avg:92.98ms +step:1125/1670 train_time:104614ms step_avg:92.99ms 
+step:1125/1670 val_loss:3.4139 train_time:104716ms step_avg:93.08ms +step:1126/1670 train_time:104734ms step_avg:93.01ms +step:1127/1670 train_time:104816ms step_avg:93.00ms +step:1128/1670 train_time:104919ms step_avg:93.01ms +step:1129/1670 train_time:105014ms step_avg:93.01ms +step:1130/1670 train_time:105105ms step_avg:93.01ms +step:1131/1670 train_time:105197ms step_avg:93.01ms +step:1132/1670 train_time:105289ms step_avg:93.01ms +step:1133/1670 train_time:105381ms step_avg:93.01ms +step:1134/1670 train_time:105473ms step_avg:93.01ms +step:1135/1670 train_time:105564ms step_avg:93.01ms +step:1136/1670 train_time:105657ms step_avg:93.01ms +step:1137/1670 train_time:105751ms step_avg:93.01ms +step:1138/1670 train_time:105847ms step_avg:93.01ms +step:1139/1670 train_time:105945ms step_avg:93.02ms +step:1140/1670 train_time:106038ms step_avg:93.02ms +step:1141/1670 train_time:106131ms step_avg:93.02ms +step:1142/1670 train_time:106224ms step_avg:93.02ms +step:1143/1670 train_time:106316ms step_avg:93.01ms +step:1144/1670 train_time:106408ms step_avg:93.01ms +step:1145/1670 train_time:106501ms step_avg:93.01ms +step:1146/1670 train_time:106593ms step_avg:93.01ms +step:1147/1670 train_time:106685ms step_avg:93.01ms +step:1148/1670 train_time:106781ms step_avg:93.01ms +step:1149/1670 train_time:106878ms step_avg:93.02ms +step:1150/1670 train_time:106972ms step_avg:93.02ms +step:1151/1670 train_time:107065ms step_avg:93.02ms +step:1152/1670 train_time:107159ms step_avg:93.02ms +step:1153/1670 train_time:107251ms step_avg:93.02ms +step:1154/1670 train_time:107344ms step_avg:93.02ms +step:1155/1670 train_time:107436ms step_avg:93.02ms +step:1156/1670 train_time:107528ms step_avg:93.02ms +step:1157/1670 train_time:107621ms step_avg:93.02ms +step:1158/1670 train_time:107715ms step_avg:93.02ms +step:1159/1670 train_time:107809ms step_avg:93.02ms +step:1160/1670 train_time:107904ms step_avg:93.02ms +step:1161/1670 train_time:107998ms step_avg:93.02ms +step:1162/1670 train_time:108090ms step_avg:93.02ms +step:1163/1670 train_time:108184ms step_avg:93.02ms +step:1164/1670 train_time:108277ms step_avg:93.02ms +step:1165/1670 train_time:108369ms step_avg:93.02ms +step:1166/1670 train_time:108462ms step_avg:93.02ms +step:1167/1670 train_time:108555ms step_avg:93.02ms +step:1168/1670 train_time:108648ms step_avg:93.02ms +step:1169/1670 train_time:108743ms step_avg:93.02ms +step:1170/1670 train_time:108838ms step_avg:93.02ms +step:1171/1670 train_time:108932ms step_avg:93.02ms +step:1172/1670 train_time:109026ms step_avg:93.03ms +step:1173/1670 train_time:109118ms step_avg:93.02ms +step:1174/1670 train_time:109211ms step_avg:93.03ms +step:1175/1670 train_time:109305ms step_avg:93.03ms +step:1176/1670 train_time:109398ms step_avg:93.03ms +step:1177/1670 train_time:109490ms step_avg:93.02ms +step:1178/1670 train_time:109582ms step_avg:93.02ms +step:1179/1670 train_time:109675ms step_avg:93.02ms +step:1180/1670 train_time:109768ms step_avg:93.02ms +step:1181/1670 train_time:109863ms step_avg:93.03ms +step:1182/1670 train_time:109956ms step_avg:93.03ms +step:1183/1670 train_time:110049ms step_avg:93.03ms +step:1184/1670 train_time:110144ms step_avg:93.03ms +step:1185/1670 train_time:110237ms step_avg:93.03ms +step:1186/1670 train_time:110330ms step_avg:93.03ms +step:1187/1670 train_time:110423ms step_avg:93.03ms +step:1188/1670 train_time:110515ms step_avg:93.03ms +step:1189/1670 train_time:110608ms step_avg:93.03ms +step:1190/1670 train_time:110702ms step_avg:93.03ms +step:1191/1670 train_time:110796ms 
step_avg:93.03ms +step:1192/1670 train_time:110889ms step_avg:93.03ms +step:1193/1670 train_time:110982ms step_avg:93.03ms +step:1194/1670 train_time:111076ms step_avg:93.03ms +step:1195/1670 train_time:111169ms step_avg:93.03ms +step:1196/1670 train_time:111262ms step_avg:93.03ms +step:1197/1670 train_time:111355ms step_avg:93.03ms +step:1198/1670 train_time:111448ms step_avg:93.03ms +step:1199/1670 train_time:111541ms step_avg:93.03ms +step:1200/1670 train_time:111634ms step_avg:93.03ms +step:1201/1670 train_time:111726ms step_avg:93.03ms +step:1202/1670 train_time:111820ms step_avg:93.03ms +step:1203/1670 train_time:111913ms step_avg:93.03ms +step:1204/1670 train_time:112006ms step_avg:93.03ms +step:1205/1670 train_time:112100ms step_avg:93.03ms +step:1206/1670 train_time:112194ms step_avg:93.03ms +step:1207/1670 train_time:112286ms step_avg:93.03ms +step:1208/1670 train_time:112381ms step_avg:93.03ms +step:1209/1670 train_time:112473ms step_avg:93.03ms +step:1210/1670 train_time:112567ms step_avg:93.03ms +step:1211/1670 train_time:112661ms step_avg:93.03ms +step:1212/1670 train_time:112754ms step_avg:93.03ms +step:1213/1670 train_time:112847ms step_avg:93.03ms +step:1214/1670 train_time:112939ms step_avg:93.03ms +step:1215/1670 train_time:113033ms step_avg:93.03ms +step:1216/1670 train_time:113127ms step_avg:93.03ms +step:1217/1670 train_time:113221ms step_avg:93.03ms +step:1218/1670 train_time:113315ms step_avg:93.03ms +step:1219/1670 train_time:113408ms step_avg:93.03ms +step:1220/1670 train_time:113502ms step_avg:93.03ms +step:1221/1670 train_time:113595ms step_avg:93.03ms +step:1222/1670 train_time:113687ms step_avg:93.03ms +step:1223/1670 train_time:113781ms step_avg:93.03ms +step:1224/1670 train_time:113874ms step_avg:93.03ms +step:1225/1670 train_time:113966ms step_avg:93.03ms +step:1226/1670 train_time:114060ms step_avg:93.03ms +step:1227/1670 train_time:114152ms step_avg:93.03ms +step:1228/1670 train_time:114246ms step_avg:93.03ms +step:1229/1670 train_time:114341ms step_avg:93.04ms +step:1230/1670 train_time:114435ms step_avg:93.04ms +step:1231/1670 train_time:114528ms step_avg:93.04ms +step:1232/1670 train_time:114623ms step_avg:93.04ms +step:1233/1670 train_time:114715ms step_avg:93.04ms +step:1234/1670 train_time:114807ms step_avg:93.04ms +step:1235/1670 train_time:114901ms step_avg:93.04ms +step:1236/1670 train_time:114994ms step_avg:93.04ms +step:1237/1670 train_time:115086ms step_avg:93.04ms +step:1238/1670 train_time:115179ms step_avg:93.04ms +step:1239/1670 train_time:115273ms step_avg:93.04ms +step:1240/1670 train_time:115366ms step_avg:93.04ms +step:1241/1670 train_time:115460ms step_avg:93.04ms +step:1242/1670 train_time:115553ms step_avg:93.04ms +step:1243/1670 train_time:115646ms step_avg:93.04ms +step:1244/1670 train_time:115741ms step_avg:93.04ms +step:1245/1670 train_time:115835ms step_avg:93.04ms +step:1246/1670 train_time:115927ms step_avg:93.04ms +step:1247/1670 train_time:116020ms step_avg:93.04ms +step:1248/1670 train_time:116114ms step_avg:93.04ms +step:1249/1670 train_time:116207ms step_avg:93.04ms +step:1250/1670 train_time:116300ms step_avg:93.04ms +step:1250/1670 val_loss:3.3762 train_time:116392ms step_avg:93.11ms +step:1251/1670 train_time:116411ms step_avg:93.05ms +step:1252/1670 train_time:116487ms step_avg:93.04ms +step:1253/1670 train_time:116579ms step_avg:93.04ms +step:1254/1670 train_time:116671ms step_avg:93.04ms +step:1255/1670 train_time:116764ms step_avg:93.04ms +step:1256/1670 train_time:116857ms step_avg:93.04ms +step:1257/1670 
train_time:116949ms step_avg:93.04ms +step:1258/1670 train_time:117044ms step_avg:93.04ms +step:1259/1670 train_time:117137ms step_avg:93.04ms +step:1260/1670 train_time:117229ms step_avg:93.04ms +step:1261/1670 train_time:117324ms step_avg:93.04ms +step:1262/1670 train_time:117419ms step_avg:93.04ms +step:1263/1670 train_time:117512ms step_avg:93.04ms +step:1264/1670 train_time:117605ms step_avg:93.04ms +step:1265/1670 train_time:117698ms step_avg:93.04ms +step:1266/1670 train_time:117790ms step_avg:93.04ms +step:1267/1670 train_time:117885ms step_avg:93.04ms +step:1268/1670 train_time:117979ms step_avg:93.04ms +step:1269/1670 train_time:118071ms step_avg:93.04ms +step:1270/1670 train_time:118165ms step_avg:93.04ms +step:1271/1670 train_time:118258ms step_avg:93.04ms +step:1272/1670 train_time:118352ms step_avg:93.04ms +step:1273/1670 train_time:118446ms step_avg:93.04ms +step:1274/1670 train_time:118685ms step_avg:93.16ms +step:1275/1670 train_time:118773ms step_avg:93.16ms +step:1276/1670 train_time:118864ms step_avg:93.15ms +step:1277/1670 train_time:118956ms step_avg:93.15ms +step:1278/1670 train_time:119048ms step_avg:93.15ms +step:1279/1670 train_time:119140ms step_avg:93.15ms +step:1280/1670 train_time:119231ms step_avg:93.15ms +step:1281/1670 train_time:119323ms step_avg:93.15ms +step:1282/1670 train_time:119416ms step_avg:93.15ms +step:1283/1670 train_time:119507ms step_avg:93.15ms +step:1284/1670 train_time:119605ms step_avg:93.15ms +step:1285/1670 train_time:119703ms step_avg:93.15ms +step:1286/1670 train_time:119797ms step_avg:93.15ms +step:1287/1670 train_time:119890ms step_avg:93.15ms +step:1288/1670 train_time:119984ms step_avg:93.16ms +step:1289/1670 train_time:120076ms step_avg:93.15ms +step:1290/1670 train_time:120168ms step_avg:93.15ms +step:1291/1670 train_time:120260ms step_avg:93.15ms +step:1292/1670 train_time:120352ms step_avg:93.15ms +step:1293/1670 train_time:120445ms step_avg:93.15ms +step:1294/1670 train_time:120538ms step_avg:93.15ms +step:1295/1670 train_time:120633ms step_avg:93.15ms +step:1296/1670 train_time:120728ms step_avg:93.15ms +step:1297/1670 train_time:120821ms step_avg:93.15ms +step:1298/1670 train_time:120914ms step_avg:93.15ms +step:1299/1670 train_time:121008ms step_avg:93.15ms +step:1300/1670 train_time:121102ms step_avg:93.16ms +step:1301/1670 train_time:121195ms step_avg:93.16ms +step:1302/1670 train_time:121288ms step_avg:93.15ms +step:1303/1670 train_time:121380ms step_avg:93.15ms +step:1304/1670 train_time:121472ms step_avg:93.15ms +step:1305/1670 train_time:121565ms step_avg:93.15ms +step:1306/1670 train_time:121659ms step_avg:93.15ms +step:1307/1670 train_time:121752ms step_avg:93.15ms +step:1308/1670 train_time:121846ms step_avg:93.15ms +step:1309/1670 train_time:121940ms step_avg:93.16ms +step:1310/1670 train_time:122033ms step_avg:93.16ms +step:1311/1670 train_time:122126ms step_avg:93.16ms +step:1312/1670 train_time:122219ms step_avg:93.15ms +step:1313/1670 train_time:122311ms step_avg:93.15ms +step:1314/1670 train_time:122406ms step_avg:93.15ms +step:1315/1670 train_time:122498ms step_avg:93.15ms +step:1316/1670 train_time:122590ms step_avg:93.15ms +step:1317/1670 train_time:122687ms step_avg:93.16ms +step:1318/1670 train_time:122782ms step_avg:93.16ms +step:1319/1670 train_time:122875ms step_avg:93.16ms +step:1320/1670 train_time:122968ms step_avg:93.16ms +step:1321/1670 train_time:123061ms step_avg:93.16ms +step:1322/1670 train_time:123154ms step_avg:93.16ms +step:1323/1670 train_time:123247ms step_avg:93.16ms +step:1324/1670 
train_time:123340ms step_avg:93.16ms +step:1325/1670 train_time:123433ms step_avg:93.16ms +step:1326/1670 train_time:123527ms step_avg:93.16ms +step:1327/1670 train_time:123620ms step_avg:93.16ms +step:1328/1670 train_time:123714ms step_avg:93.16ms +step:1329/1670 train_time:123808ms step_avg:93.16ms +step:1330/1670 train_time:123902ms step_avg:93.16ms +step:1331/1670 train_time:123995ms step_avg:93.16ms +step:1332/1670 train_time:124090ms step_avg:93.16ms +step:1333/1670 train_time:124183ms step_avg:93.16ms +step:1334/1670 train_time:124274ms step_avg:93.16ms +step:1335/1670 train_time:124368ms step_avg:93.16ms +step:1336/1670 train_time:124461ms step_avg:93.16ms +step:1337/1670 train_time:124553ms step_avg:93.16ms +step:1338/1670 train_time:124647ms step_avg:93.16ms +step:1339/1670 train_time:124742ms step_avg:93.16ms +step:1340/1670 train_time:124835ms step_avg:93.16ms +step:1341/1670 train_time:124928ms step_avg:93.16ms +step:1342/1670 train_time:125021ms step_avg:93.16ms +step:1343/1670 train_time:125114ms step_avg:93.16ms +step:1344/1670 train_time:125207ms step_avg:93.16ms +step:1345/1670 train_time:125301ms step_avg:93.16ms +step:1346/1670 train_time:125392ms step_avg:93.16ms +step:1347/1670 train_time:125488ms step_avg:93.16ms +step:1348/1670 train_time:125581ms step_avg:93.16ms +step:1349/1670 train_time:125673ms step_avg:93.16ms +step:1350/1670 train_time:125767ms step_avg:93.16ms +step:1351/1670 train_time:125860ms step_avg:93.16ms +step:1352/1670 train_time:125953ms step_avg:93.16ms +step:1353/1670 train_time:126047ms step_avg:93.16ms +step:1354/1670 train_time:126140ms step_avg:93.16ms +step:1355/1670 train_time:126232ms step_avg:93.16ms +step:1356/1670 train_time:126326ms step_avg:93.16ms +step:1357/1670 train_time:126418ms step_avg:93.16ms +step:1358/1670 train_time:126512ms step_avg:93.16ms +step:1359/1670 train_time:126606ms step_avg:93.16ms +step:1360/1670 train_time:126700ms step_avg:93.16ms +step:1361/1670 train_time:126793ms step_avg:93.16ms +step:1362/1670 train_time:126886ms step_avg:93.16ms +step:1363/1670 train_time:126977ms step_avg:93.16ms +step:1364/1670 train_time:127072ms step_avg:93.16ms +step:1365/1670 train_time:127165ms step_avg:93.16ms +step:1366/1670 train_time:127258ms step_avg:93.16ms +step:1367/1670 train_time:127351ms step_avg:93.16ms +step:1368/1670 train_time:127444ms step_avg:93.16ms +step:1369/1670 train_time:127538ms step_avg:93.16ms +step:1370/1670 train_time:127631ms step_avg:93.16ms +step:1371/1670 train_time:127724ms step_avg:93.16ms +step:1372/1670 train_time:127818ms step_avg:93.16ms +step:1373/1670 train_time:127911ms step_avg:93.16ms +step:1374/1670 train_time:128004ms step_avg:93.16ms +step:1375/1670 train_time:128097ms step_avg:93.16ms +step:1375/1670 val_loss:3.3416 train_time:128190ms step_avg:93.23ms +step:1376/1670 train_time:128209ms step_avg:93.17ms +step:1377/1670 train_time:128284ms step_avg:93.16ms +step:1378/1670 train_time:128377ms step_avg:93.16ms +step:1379/1670 train_time:128470ms step_avg:93.16ms +step:1380/1670 train_time:128562ms step_avg:93.16ms +step:1381/1670 train_time:128654ms step_avg:93.16ms +step:1382/1670 train_time:128746ms step_avg:93.16ms +step:1383/1670 train_time:128839ms step_avg:93.16ms +step:1384/1670 train_time:128934ms step_avg:93.16ms +step:1385/1670 train_time:129026ms step_avg:93.16ms +step:1386/1670 train_time:129120ms step_avg:93.16ms +step:1387/1670 train_time:129216ms step_avg:93.16ms +step:1388/1670 train_time:129310ms step_avg:93.16ms +step:1389/1670 train_time:129403ms step_avg:93.16ms 
+step:1390/1670 train_time:129497ms step_avg:93.16ms +step:1391/1670 train_time:129589ms step_avg:93.16ms +step:1392/1670 train_time:129681ms step_avg:93.16ms +step:1393/1670 train_time:129775ms step_avg:93.16ms +step:1394/1670 train_time:129868ms step_avg:93.16ms +step:1395/1670 train_time:129961ms step_avg:93.16ms +step:1396/1670 train_time:130055ms step_avg:93.16ms +step:1397/1670 train_time:130148ms step_avg:93.16ms +step:1398/1670 train_time:130241ms step_avg:93.16ms +step:1399/1670 train_time:130336ms step_avg:93.16ms +step:1400/1670 train_time:130430ms step_avg:93.16ms +step:1401/1670 train_time:130523ms step_avg:93.16ms +step:1402/1670 train_time:130617ms step_avg:93.16ms +step:1403/1670 train_time:130710ms step_avg:93.16ms +step:1404/1670 train_time:130803ms step_avg:93.16ms +step:1405/1670 train_time:130896ms step_avg:93.16ms +step:1406/1670 train_time:130991ms step_avg:93.17ms +step:1407/1670 train_time:131083ms step_avg:93.17ms +step:1408/1670 train_time:131177ms step_avg:93.17ms +step:1409/1670 train_time:131271ms step_avg:93.17ms +step:1410/1670 train_time:131364ms step_avg:93.17ms +step:1411/1670 train_time:131458ms step_avg:93.17ms +step:1412/1670 train_time:131551ms step_avg:93.17ms +step:1413/1670 train_time:131644ms step_avg:93.17ms +step:1414/1670 train_time:131736ms step_avg:93.17ms +step:1415/1670 train_time:131829ms step_avg:93.17ms +step:1416/1670 train_time:131922ms step_avg:93.17ms +step:1417/1670 train_time:132015ms step_avg:93.17ms +step:1418/1670 train_time:132109ms step_avg:93.17ms +step:1419/1670 train_time:132203ms step_avg:93.17ms +step:1420/1670 train_time:132297ms step_avg:93.17ms +step:1421/1670 train_time:132391ms step_avg:93.17ms +step:1422/1670 train_time:132485ms step_avg:93.17ms +step:1423/1670 train_time:132578ms step_avg:93.17ms +step:1424/1670 train_time:132672ms step_avg:93.17ms +step:1425/1670 train_time:132766ms step_avg:93.17ms +step:1426/1670 train_time:132858ms step_avg:93.17ms +step:1427/1670 train_time:132952ms step_avg:93.17ms +step:1428/1670 train_time:133045ms step_avg:93.17ms +step:1429/1670 train_time:133139ms step_avg:93.17ms +step:1430/1670 train_time:133234ms step_avg:93.17ms +step:1431/1670 train_time:133328ms step_avg:93.17ms +step:1432/1670 train_time:133421ms step_avg:93.17ms +step:1433/1670 train_time:133514ms step_avg:93.17ms +step:1434/1670 train_time:133607ms step_avg:93.17ms +step:1435/1670 train_time:133700ms step_avg:93.17ms +step:1436/1670 train_time:133795ms step_avg:93.17ms +step:1437/1670 train_time:133887ms step_avg:93.17ms +step:1438/1670 train_time:133979ms step_avg:93.17ms +step:1439/1670 train_time:134073ms step_avg:93.17ms +step:1440/1670 train_time:134167ms step_avg:93.17ms +step:1441/1670 train_time:134260ms step_avg:93.17ms +step:1442/1670 train_time:134355ms step_avg:93.17ms +step:1443/1670 train_time:134448ms step_avg:93.17ms +step:1444/1670 train_time:134540ms step_avg:93.17ms +step:1445/1670 train_time:134634ms step_avg:93.17ms +step:1446/1670 train_time:134727ms step_avg:93.17ms +step:1447/1670 train_time:134820ms step_avg:93.17ms +step:1448/1670 train_time:134914ms step_avg:93.17ms +step:1449/1670 train_time:135006ms step_avg:93.17ms +step:1450/1670 train_time:135099ms step_avg:93.17ms +step:1451/1670 train_time:135195ms step_avg:93.17ms +step:1452/1670 train_time:135289ms step_avg:93.17ms +step:1453/1670 train_time:135382ms step_avg:93.17ms +step:1454/1670 train_time:135476ms step_avg:93.17ms +step:1455/1670 train_time:135569ms step_avg:93.17ms +step:1456/1670 train_time:135663ms step_avg:93.17ms 
+step:1457/1670 train_time:135756ms step_avg:93.17ms +step:1458/1670 train_time:135849ms step_avg:93.17ms +step:1459/1670 train_time:135941ms step_avg:93.17ms +step:1460/1670 train_time:136036ms step_avg:93.18ms +step:1461/1670 train_time:136130ms step_avg:93.18ms +step:1462/1670 train_time:136223ms step_avg:93.18ms +step:1463/1670 train_time:136316ms step_avg:93.18ms +step:1464/1670 train_time:136410ms step_avg:93.18ms +step:1465/1670 train_time:136503ms step_avg:93.18ms +step:1466/1670 train_time:136597ms step_avg:93.18ms +step:1467/1670 train_time:136691ms step_avg:93.18ms +step:1468/1670 train_time:136785ms step_avg:93.18ms +step:1469/1670 train_time:136878ms step_avg:93.18ms +step:1470/1670 train_time:136970ms step_avg:93.18ms +step:1471/1670 train_time:137063ms step_avg:93.18ms +step:1472/1670 train_time:137157ms step_avg:93.18ms +step:1473/1670 train_time:137250ms step_avg:93.18ms +step:1474/1670 train_time:137343ms step_avg:93.18ms +step:1475/1670 train_time:137438ms step_avg:93.18ms +step:1476/1670 train_time:137531ms step_avg:93.18ms +step:1477/1670 train_time:137625ms step_avg:93.18ms +step:1478/1670 train_time:137718ms step_avg:93.18ms +step:1479/1670 train_time:137811ms step_avg:93.18ms +step:1480/1670 train_time:137905ms step_avg:93.18ms +step:1481/1670 train_time:137999ms step_avg:93.18ms +step:1482/1670 train_time:138091ms step_avg:93.18ms +step:1483/1670 train_time:138185ms step_avg:93.18ms +step:1484/1670 train_time:138277ms step_avg:93.18ms +step:1485/1670 train_time:138529ms step_avg:93.29ms +step:1486/1670 train_time:138601ms step_avg:93.27ms +step:1487/1670 train_time:138693ms step_avg:93.27ms +step:1488/1670 train_time:138785ms step_avg:93.27ms +step:1489/1670 train_time:138876ms step_avg:93.27ms +step:1490/1670 train_time:138968ms step_avg:93.27ms +step:1491/1670 train_time:139060ms step_avg:93.27ms +step:1492/1670 train_time:139152ms step_avg:93.27ms +step:1493/1670 train_time:139244ms step_avg:93.26ms +step:1494/1670 train_time:139336ms step_avg:93.26ms +step:1495/1670 train_time:139437ms step_avg:93.27ms +step:1496/1670 train_time:139535ms step_avg:93.27ms +step:1497/1670 train_time:139629ms step_avg:93.27ms +step:1498/1670 train_time:139721ms step_avg:93.27ms +step:1499/1670 train_time:139814ms step_avg:93.27ms +step:1500/1670 train_time:139906ms step_avg:93.27ms +step:1500/1670 val_loss:3.3117 train_time:139999ms step_avg:93.33ms +step:1501/1670 train_time:140017ms step_avg:93.28ms +step:1502/1670 train_time:140095ms step_avg:93.27ms +step:1503/1670 train_time:140188ms step_avg:93.27ms +step:1504/1670 train_time:140280ms step_avg:93.27ms +step:1505/1670 train_time:140372ms step_avg:93.27ms +step:1506/1670 train_time:140464ms step_avg:93.27ms +step:1507/1670 train_time:140558ms step_avg:93.27ms +step:1508/1670 train_time:140652ms step_avg:93.27ms +step:1509/1670 train_time:140745ms step_avg:93.27ms +step:1510/1670 train_time:140839ms step_avg:93.27ms +step:1511/1670 train_time:140933ms step_avg:93.27ms +step:1512/1670 train_time:141027ms step_avg:93.27ms +step:1513/1670 train_time:141122ms step_avg:93.27ms +step:1514/1670 train_time:141215ms step_avg:93.27ms +step:1515/1670 train_time:141307ms step_avg:93.27ms +step:1516/1670 train_time:141400ms step_avg:93.27ms +step:1517/1670 train_time:141492ms step_avg:93.27ms +step:1518/1670 train_time:141585ms step_avg:93.27ms +step:1519/1670 train_time:141679ms step_avg:93.27ms +step:1520/1670 train_time:141771ms step_avg:93.27ms +step:1521/1670 train_time:141865ms step_avg:93.27ms +step:1522/1670 train_time:141959ms 
step_avg:93.27ms +step:1523/1670 train_time:142053ms step_avg:93.27ms +step:1524/1670 train_time:142146ms step_avg:93.27ms +step:1525/1670 train_time:142241ms step_avg:93.27ms +step:1526/1670 train_time:142333ms step_avg:93.27ms +step:1527/1670 train_time:142425ms step_avg:93.27ms +step:1528/1670 train_time:142518ms step_avg:93.27ms +step:1529/1670 train_time:142611ms step_avg:93.27ms +step:1530/1670 train_time:142704ms step_avg:93.27ms +step:1531/1670 train_time:142798ms step_avg:93.27ms +step:1532/1670 train_time:142891ms step_avg:93.27ms +step:1533/1670 train_time:142985ms step_avg:93.27ms +step:1534/1670 train_time:143081ms step_avg:93.27ms +step:1535/1670 train_time:143175ms step_avg:93.27ms +step:1536/1670 train_time:143269ms step_avg:93.27ms +step:1537/1670 train_time:143362ms step_avg:93.27ms +step:1538/1670 train_time:143456ms step_avg:93.27ms +step:1539/1670 train_time:143549ms step_avg:93.27ms +step:1540/1670 train_time:143642ms step_avg:93.27ms +step:1541/1670 train_time:143736ms step_avg:93.27ms +step:1542/1670 train_time:143830ms step_avg:93.27ms +step:1543/1670 train_time:143923ms step_avg:93.27ms +step:1544/1670 train_time:144015ms step_avg:93.27ms +step:1545/1670 train_time:144109ms step_avg:93.27ms +step:1546/1670 train_time:144202ms step_avg:93.27ms +step:1547/1670 train_time:144297ms step_avg:93.28ms +step:1548/1670 train_time:144389ms step_avg:93.27ms +step:1549/1670 train_time:144483ms step_avg:93.28ms +step:1550/1670 train_time:144577ms step_avg:93.28ms +step:1551/1670 train_time:144669ms step_avg:93.27ms +step:1552/1670 train_time:144764ms step_avg:93.28ms +step:1553/1670 train_time:144857ms step_avg:93.28ms +step:1554/1670 train_time:144949ms step_avg:93.28ms +step:1555/1670 train_time:145044ms step_avg:93.28ms +step:1556/1670 train_time:145136ms step_avg:93.28ms +step:1557/1670 train_time:145229ms step_avg:93.27ms +step:1558/1670 train_time:145322ms step_avg:93.27ms +step:1559/1670 train_time:145415ms step_avg:93.27ms +step:1560/1670 train_time:145508ms step_avg:93.27ms +step:1561/1670 train_time:145602ms step_avg:93.27ms +step:1562/1670 train_time:145695ms step_avg:93.27ms +step:1563/1670 train_time:145788ms step_avg:93.27ms +step:1564/1670 train_time:145881ms step_avg:93.27ms +step:1565/1670 train_time:145975ms step_avg:93.27ms +step:1566/1670 train_time:146069ms step_avg:93.28ms +step:1567/1670 train_time:146162ms step_avg:93.28ms +step:1568/1670 train_time:146256ms step_avg:93.28ms +step:1569/1670 train_time:146350ms step_avg:93.28ms +step:1570/1670 train_time:146444ms step_avg:93.28ms +step:1571/1670 train_time:146537ms step_avg:93.28ms +step:1572/1670 train_time:146629ms step_avg:93.28ms +step:1573/1670 train_time:146723ms step_avg:93.28ms +step:1574/1670 train_time:146816ms step_avg:93.28ms +step:1575/1670 train_time:146909ms step_avg:93.28ms +step:1576/1670 train_time:147003ms step_avg:93.28ms +step:1577/1670 train_time:147096ms step_avg:93.28ms +step:1578/1670 train_time:147189ms step_avg:93.28ms +step:1579/1670 train_time:147283ms step_avg:93.28ms +step:1580/1670 train_time:147376ms step_avg:93.28ms +step:1581/1670 train_time:147468ms step_avg:93.28ms +step:1582/1670 train_time:147562ms step_avg:93.28ms +step:1583/1670 train_time:147654ms step_avg:93.27ms +step:1584/1670 train_time:147747ms step_avg:93.27ms +step:1585/1670 train_time:147839ms step_avg:93.27ms +step:1586/1670 train_time:147932ms step_avg:93.27ms +step:1587/1670 train_time:148025ms step_avg:93.27ms +step:1588/1670 train_time:148118ms step_avg:93.27ms +step:1589/1670 train_time:148212ms 
step_avg:93.27ms +step:1590/1670 train_time:148305ms step_avg:93.27ms +step:1591/1670 train_time:148398ms step_avg:93.27ms +step:1592/1670 train_time:148491ms step_avg:93.27ms +step:1593/1670 train_time:148584ms step_avg:93.27ms +step:1594/1670 train_time:148678ms step_avg:93.27ms +step:1595/1670 train_time:148772ms step_avg:93.27ms +step:1596/1670 train_time:148866ms step_avg:93.27ms +step:1597/1670 train_time:148959ms step_avg:93.27ms +step:1598/1670 train_time:149051ms step_avg:93.27ms +step:1599/1670 train_time:149145ms step_avg:93.27ms +step:1600/1670 train_time:149239ms step_avg:93.27ms +step:1601/1670 train_time:149333ms step_avg:93.27ms +step:1602/1670 train_time:149425ms step_avg:93.27ms +step:1603/1670 train_time:149518ms step_avg:93.27ms +step:1604/1670 train_time:149611ms step_avg:93.27ms +step:1605/1670 train_time:149704ms step_avg:93.27ms +step:1606/1670 train_time:149798ms step_avg:93.27ms +step:1607/1670 train_time:149890ms step_avg:93.27ms +step:1608/1670 train_time:149984ms step_avg:93.27ms +step:1609/1670 train_time:150078ms step_avg:93.27ms +step:1610/1670 train_time:150172ms step_avg:93.27ms +step:1611/1670 train_time:150265ms step_avg:93.27ms +step:1612/1670 train_time:150359ms step_avg:93.27ms +step:1613/1670 train_time:150451ms step_avg:93.27ms +step:1614/1670 train_time:150545ms step_avg:93.27ms +step:1615/1670 train_time:150640ms step_avg:93.28ms +step:1616/1670 train_time:150732ms step_avg:93.27ms +step:1617/1670 train_time:150825ms step_avg:93.27ms +step:1618/1670 train_time:150918ms step_avg:93.27ms +step:1619/1670 train_time:151011ms step_avg:93.27ms +step:1620/1670 train_time:151105ms step_avg:93.27ms +step:1621/1670 train_time:151199ms step_avg:93.27ms +step:1622/1670 train_time:151292ms step_avg:93.27ms +step:1623/1670 train_time:151385ms step_avg:93.27ms +step:1624/1670 train_time:151479ms step_avg:93.28ms +step:1625/1670 train_time:151572ms step_avg:93.28ms +step:1625/1670 val_loss:3.2864 train_time:151665ms step_avg:93.33ms +step:1626/1670 train_time:151683ms step_avg:93.29ms +step:1627/1670 train_time:151759ms step_avg:93.28ms +step:1628/1670 train_time:151853ms step_avg:93.28ms +step:1629/1670 train_time:151946ms step_avg:93.28ms +step:1630/1670 train_time:152038ms step_avg:93.27ms +step:1631/1670 train_time:152131ms step_avg:93.27ms +step:1632/1670 train_time:152224ms step_avg:93.27ms +step:1633/1670 train_time:152317ms step_avg:93.27ms +step:1634/1670 train_time:152411ms step_avg:93.27ms +step:1635/1670 train_time:152505ms step_avg:93.28ms +step:1636/1670 train_time:152599ms step_avg:93.28ms +step:1637/1670 train_time:152695ms step_avg:93.28ms +step:1638/1670 train_time:152788ms step_avg:93.28ms +step:1639/1670 train_time:152881ms step_avg:93.28ms +step:1640/1670 train_time:152973ms step_avg:93.28ms +step:1641/1670 train_time:153066ms step_avg:93.28ms +step:1642/1670 train_time:153160ms step_avg:93.28ms +step:1643/1670 train_time:153253ms step_avg:93.28ms +step:1644/1670 train_time:153347ms step_avg:93.28ms +step:1645/1670 train_time:153440ms step_avg:93.28ms +step:1646/1670 train_time:153534ms step_avg:93.28ms +step:1647/1670 train_time:153629ms step_avg:93.28ms +step:1648/1670 train_time:153722ms step_avg:93.28ms +step:1649/1670 train_time:153815ms step_avg:93.28ms +step:1650/1670 train_time:153908ms step_avg:93.28ms +step:1651/1670 train_time:154001ms step_avg:93.28ms +step:1652/1670 train_time:154094ms step_avg:93.28ms +step:1653/1670 train_time:154187ms step_avg:93.28ms +step:1654/1670 train_time:154279ms step_avg:93.28ms +step:1655/1670 
train_time:154372ms step_avg:93.28ms +step:1656/1670 train_time:154467ms step_avg:93.28ms +step:1657/1670 train_time:154560ms step_avg:93.28ms +step:1658/1670 train_time:154653ms step_avg:93.28ms +step:1659/1670 train_time:154747ms step_avg:93.28ms +step:1660/1670 train_time:154841ms step_avg:93.28ms +step:1661/1670 train_time:154934ms step_avg:93.28ms +step:1662/1670 train_time:155028ms step_avg:93.28ms +step:1663/1670 train_time:155121ms step_avg:93.28ms +step:1664/1670 train_time:155214ms step_avg:93.28ms +step:1665/1670 train_time:155306ms step_avg:93.28ms +step:1666/1670 train_time:155400ms step_avg:93.28ms +step:1667/1670 train_time:155493ms step_avg:93.28ms +step:1668/1670 train_time:155587ms step_avg:93.28ms +step:1669/1670 train_time:155680ms step_avg:93.28ms +step:1670/1670 train_time:155773ms step_avg:93.28ms +step:1670/1670 val_loss:3.2778 train_time:156043ms step_avg:93.44ms +peak memory allocated: 32002 MiB reserved: 46354 MiB diff --git a/records/091125_VectSigmoidBFloat16/648dffae-9eb3-4d2a-a28a-dae8d1152aa7.txt b/records/091125_VectSigmoidBFloat16/648dffae-9eb3-4d2a-a28a-dae8d1152aa7.txt new file mode 100644 index 000000000..ff840da94 --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/648dffae-9eb3-4d2a-a28a-dae8d1152aa7.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
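+# ---------------------------------------------------------------------------
+# For comparison only: a minimal dense-PyTorch sketch of the same quintic
+# Newton-Schulz iteration (same coefficients, 5 steps). The name
+# `newton_schulz_dense` is illustrative and not part of the recorded run; it
+# is handy as a slow reference when validating the Triton kernels above, e.g.
+#   torch.testing.assert_close(newton_schulz_triton(G), newton_schulz_dense(G),
+#                              atol=1e-2, rtol=1e-2)  # loose tolerances: bfloat16
+def newton_schulz_dense(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # iterate on the wide orientation
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    for _ in range(steps):
+        A = X @ X.mT           # ns_line_1: A = X @ X.T (symmetric)
+        B = b * A + c * A @ A  # ns_line_2: B = b * A + c * A @ A
+        X = a * X + B @ X      # ns_line_3: X = a * X + B @ X
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+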
+# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. 
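+            # (Computed once above: the max(1, rows/cols)**0.5 term boosts the lr
+            # for tall matrices so the orthogonalized update keeps a comparable
+            # scale across shapes; per-tensor lr_mul / wd_mul attributes override
+            # the group-level defaults.)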
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for the zeropower input so the batch stays rectangular.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                # Flatten the two leading dims into one batch dim (e.g. stacked qkvo weights)
+                batch = original_shape[0] * original_shape[1]
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Apply the zeropower output (v_chunk) to the parameters in the
+            # `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
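+                # In-place descent step: alpha=-eff_lr_val makes this
+                # p_buf <- p_buf - eff_lr * v (weight decay was applied above).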
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
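+        # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 393 * 128), giving the
+        # lm_head below a matmul-friendly output dimension.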
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = ( + [ve[0], ve[1], ve[2]] + + [None] * (len(self.blocks) - 6) + + [ve[0], ve[1], ve[2]] + ) + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [ + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + ] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]).to( + torch.bfloat16 + ) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[: (len(self.blocks) // 2)] + lambdas = self.scalars[ + 1 * len(self.blocks) : 3 * len(self.blocks) + ].view(-1, 2) + sa_lambdas = self.scalars[ + 3 * len(self.blocks) : 5 * len(self.blocks) + ].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws], + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + + +def _load_data_shard(file: Path): + header = torch.from_file( + str(file), False, 256, dtype=torch.int32 + ) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty( + num_tokens, dtype=torch.uint16, pin_memory=True + ) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto( + tokens.numpy() + ) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, ( + "number of tokens read does not match header" + ) + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = ( + (tokens == BOS_ID) + .nonzero(as_tuple=True)[0] + .to(torch.int64) + .cpu() + .numpy() + ) + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration( + f"Insufficient BOS ahead of position {cur}; hit tail of shard." 
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 10:08:35 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 33C P0 120W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 34C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 34C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 31C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 33C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 34C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 33C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:291ms step_avg:290.94ms +step:2/1670 train_time:309ms step_avg:154.27ms +step:3/1670 train_time:378ms step_avg:125.96ms +step:4/1670 train_time:467ms step_avg:116.77ms +step:5/1670 train_time:557ms step_avg:111.42ms +step:6/1670 train_time:647ms step_avg:107.86ms +step:7/1670 train_time:737ms step_avg:105.34ms +step:8/1670 train_time:828ms step_avg:103.52ms +step:9/1670 train_time:918ms step_avg:102.01ms +step:10/1670 train_time:1008ms step_avg:100.84ms +step:11/1670 train_time:1099ms step_avg:99.87ms +step:12/1670 train_time:1191ms step_avg:99.28ms +step:13/1670 train_time:1285ms step_avg:98.87ms +step:14/1670 train_time:1377ms step_avg:98.37ms +step:15/1670 train_time:1471ms step_avg:98.07ms +step:16/1670 train_time:1562ms step_avg:97.64ms +step:17/1670 train_time:1653ms step_avg:97.21ms +step:18/1670 train_time:1743ms step_avg:96.81ms +step:19/1670 train_time:1833ms step_avg:96.49ms +step:20/1670 train_time:1924ms step_avg:96.18ms +step:21/1670 train_time:2016ms step_avg:95.99ms +step:22/1670 train_time:2108ms 
step_avg:95.81ms +step:23/1670 train_time:2200ms step_avg:95.65ms +step:24/1670 train_time:2294ms step_avg:95.58ms +step:25/1670 train_time:2386ms step_avg:95.44ms +step:26/1670 train_time:2478ms step_avg:95.29ms +step:27/1670 train_time:2569ms step_avg:95.13ms +step:28/1670 train_time:2660ms step_avg:94.99ms +step:29/1670 train_time:2751ms step_avg:94.85ms +step:30/1670 train_time:2841ms step_avg:94.69ms +step:31/1670 train_time:2932ms step_avg:94.58ms +step:32/1670 train_time:3023ms step_avg:94.47ms +step:33/1670 train_time:3115ms step_avg:94.38ms +step:34/1670 train_time:3206ms step_avg:94.30ms +step:35/1670 train_time:3298ms step_avg:94.24ms +step:36/1670 train_time:3390ms step_avg:94.17ms +step:37/1670 train_time:3481ms step_avg:94.09ms +step:38/1670 train_time:3574ms step_avg:94.04ms +step:39/1670 train_time:3665ms step_avg:93.98ms +step:40/1670 train_time:3757ms step_avg:93.91ms +step:41/1670 train_time:3848ms step_avg:93.84ms +step:42/1670 train_time:3938ms step_avg:93.77ms +step:43/1670 train_time:4030ms step_avg:93.71ms +step:44/1670 train_time:4121ms step_avg:93.65ms +step:45/1670 train_time:4213ms step_avg:93.63ms +step:46/1670 train_time:4305ms step_avg:93.58ms +step:47/1670 train_time:4397ms step_avg:93.56ms +step:48/1670 train_time:4490ms step_avg:93.54ms +step:49/1670 train_time:4581ms step_avg:93.48ms +step:50/1670 train_time:4671ms step_avg:93.42ms +step:51/1670 train_time:4761ms step_avg:93.36ms +step:52/1670 train_time:4852ms step_avg:93.31ms +step:53/1670 train_time:4942ms step_avg:93.25ms +step:54/1670 train_time:5033ms step_avg:93.20ms +step:55/1670 train_time:5124ms step_avg:93.15ms +step:56/1670 train_time:5215ms step_avg:93.12ms +step:57/1670 train_time:5306ms step_avg:93.10ms +step:58/1670 train_time:5398ms step_avg:93.07ms +step:59/1670 train_time:5489ms step_avg:93.04ms +step:60/1670 train_time:5580ms step_avg:93.00ms +step:61/1670 train_time:5672ms step_avg:92.99ms +step:62/1670 train_time:5763ms step_avg:92.95ms +step:63/1670 train_time:5855ms step_avg:92.93ms +step:64/1670 train_time:5946ms step_avg:92.90ms +step:65/1670 train_time:6037ms step_avg:92.88ms +step:66/1670 train_time:6130ms step_avg:92.87ms +step:67/1670 train_time:6220ms step_avg:92.84ms +step:68/1670 train_time:6312ms step_avg:92.82ms +step:69/1670 train_time:6403ms step_avg:92.80ms +step:70/1670 train_time:6495ms step_avg:92.78ms +step:71/1670 train_time:6586ms step_avg:92.76ms +step:72/1670 train_time:6677ms step_avg:92.73ms +step:73/1670 train_time:6769ms step_avg:92.72ms +step:74/1670 train_time:6860ms step_avg:92.70ms +step:75/1670 train_time:6951ms step_avg:92.68ms +step:76/1670 train_time:7042ms step_avg:92.65ms +step:77/1670 train_time:7133ms step_avg:92.64ms +step:78/1670 train_time:7224ms step_avg:92.62ms +step:79/1670 train_time:7317ms step_avg:92.62ms +step:80/1670 train_time:7408ms step_avg:92.60ms +step:81/1670 train_time:7500ms step_avg:92.59ms +step:82/1670 train_time:7591ms step_avg:92.57ms +step:83/1670 train_time:7682ms step_avg:92.55ms +step:84/1670 train_time:7773ms step_avg:92.54ms +step:85/1670 train_time:7864ms step_avg:92.52ms +step:86/1670 train_time:7956ms step_avg:92.51ms +step:87/1670 train_time:8047ms step_avg:92.49ms +step:88/1670 train_time:8138ms step_avg:92.47ms +step:89/1670 train_time:8229ms step_avg:92.46ms +step:90/1670 train_time:8320ms step_avg:92.45ms +step:91/1670 train_time:8413ms step_avg:92.45ms +step:92/1670 train_time:8504ms step_avg:92.43ms +step:93/1670 train_time:8596ms step_avg:92.43ms +step:94/1670 train_time:8688ms step_avg:92.42ms 
+step:95/1670 train_time:8778ms step_avg:92.40ms +step:96/1670 train_time:8869ms step_avg:92.38ms +step:97/1670 train_time:8960ms step_avg:92.37ms +step:98/1670 train_time:9051ms step_avg:92.35ms +step:99/1670 train_time:9141ms step_avg:92.34ms +step:100/1670 train_time:9234ms step_avg:92.34ms +step:101/1670 train_time:9327ms step_avg:92.34ms +step:102/1670 train_time:9418ms step_avg:92.33ms +step:103/1670 train_time:9509ms step_avg:92.32ms +step:104/1670 train_time:9600ms step_avg:92.31ms +step:105/1670 train_time:9692ms step_avg:92.30ms +step:106/1670 train_time:9782ms step_avg:92.28ms +step:107/1670 train_time:9873ms step_avg:92.27ms +step:108/1670 train_time:9963ms step_avg:92.25ms +step:109/1670 train_time:10055ms step_avg:92.25ms +step:110/1670 train_time:10146ms step_avg:92.23ms +step:111/1670 train_time:10237ms step_avg:92.23ms +step:112/1670 train_time:10328ms step_avg:92.22ms +step:113/1670 train_time:10419ms step_avg:92.21ms +step:114/1670 train_time:10512ms step_avg:92.21ms +step:115/1670 train_time:10603ms step_avg:92.20ms +step:116/1670 train_time:10694ms step_avg:92.19ms +step:117/1670 train_time:10785ms step_avg:92.18ms +step:118/1670 train_time:10876ms step_avg:92.17ms +step:119/1670 train_time:10966ms step_avg:92.15ms +step:120/1670 train_time:11057ms step_avg:92.14ms +step:121/1670 train_time:11147ms step_avg:92.13ms +step:122/1670 train_time:11239ms step_avg:92.12ms +step:123/1670 train_time:11332ms step_avg:92.13ms +step:124/1670 train_time:11422ms step_avg:92.11ms +step:125/1670 train_time:11514ms step_avg:92.11ms +step:125/1670 val_loss:4.3094 train_time:11605ms step_avg:92.84ms +step:126/1670 train_time:11623ms step_avg:92.25ms +step:127/1670 train_time:11698ms step_avg:92.11ms +step:128/1670 train_time:11798ms step_avg:92.17ms +step:129/1670 train_time:11892ms step_avg:92.19ms +step:130/1670 train_time:11983ms step_avg:92.18ms +step:131/1670 train_time:12073ms step_avg:92.16ms +step:132/1670 train_time:12163ms step_avg:92.14ms +step:133/1670 train_time:12253ms step_avg:92.13ms +step:134/1670 train_time:12343ms step_avg:92.11ms +step:135/1670 train_time:12433ms step_avg:92.10ms +step:136/1670 train_time:12525ms step_avg:92.10ms +step:137/1670 train_time:12618ms step_avg:92.10ms +step:138/1670 train_time:12710ms step_avg:92.10ms +step:139/1670 train_time:12804ms step_avg:92.12ms +step:140/1670 train_time:12897ms step_avg:92.12ms +step:141/1670 train_time:12990ms step_avg:92.13ms +step:142/1670 train_time:13081ms step_avg:92.12ms +step:143/1670 train_time:13172ms step_avg:92.11ms +step:144/1670 train_time:13262ms step_avg:92.10ms +step:145/1670 train_time:13352ms step_avg:92.08ms +step:146/1670 train_time:13442ms step_avg:92.07ms +step:147/1670 train_time:13533ms step_avg:92.06ms +step:148/1670 train_time:13624ms step_avg:92.05ms +step:149/1670 train_time:13716ms step_avg:92.05ms +step:150/1670 train_time:13807ms step_avg:92.05ms +step:151/1670 train_time:13902ms step_avg:92.06ms +step:152/1670 train_time:13993ms step_avg:92.06ms +step:153/1670 train_time:14085ms step_avg:92.06ms +step:154/1670 train_time:14177ms step_avg:92.06ms +step:155/1670 train_time:14267ms step_avg:92.04ms +step:156/1670 train_time:14357ms step_avg:92.03ms +step:157/1670 train_time:14447ms step_avg:92.02ms +step:158/1670 train_time:14537ms step_avg:92.00ms +step:159/1670 train_time:14627ms step_avg:92.00ms +step:160/1670 train_time:14718ms step_avg:91.99ms +step:161/1670 train_time:14810ms step_avg:91.99ms +step:162/1670 train_time:14903ms step_avg:91.99ms +step:163/1670 train_time:14995ms 
step_avg:91.99ms +step:164/1670 train_time:15086ms step_avg:91.99ms +step:165/1670 train_time:15178ms step_avg:91.99ms +step:166/1670 train_time:15268ms step_avg:91.98ms +step:167/1670 train_time:15359ms step_avg:91.97ms +step:168/1670 train_time:15449ms step_avg:91.96ms +step:169/1670 train_time:15540ms step_avg:91.95ms +step:170/1670 train_time:15630ms step_avg:91.94ms +step:171/1670 train_time:15721ms step_avg:91.94ms +step:172/1670 train_time:15813ms step_avg:91.94ms +step:173/1670 train_time:15905ms step_avg:91.94ms +step:174/1670 train_time:15997ms step_avg:91.94ms +step:175/1670 train_time:16089ms step_avg:91.94ms +step:176/1670 train_time:16180ms step_avg:91.93ms +step:177/1670 train_time:16270ms step_avg:91.92ms +step:178/1670 train_time:16360ms step_avg:91.91ms +step:179/1670 train_time:16451ms step_avg:91.90ms +step:180/1670 train_time:16542ms step_avg:91.90ms +step:181/1670 train_time:16631ms step_avg:91.89ms +step:182/1670 train_time:16723ms step_avg:91.88ms +step:183/1670 train_time:16815ms step_avg:91.88ms +step:184/1670 train_time:16906ms step_avg:91.88ms +step:185/1670 train_time:16998ms step_avg:91.88ms +step:186/1670 train_time:17090ms step_avg:91.88ms +step:187/1670 train_time:17180ms step_avg:91.87ms +step:188/1670 train_time:17271ms step_avg:91.86ms +step:189/1670 train_time:17362ms step_avg:91.86ms +step:190/1670 train_time:17453ms step_avg:91.86ms +step:191/1670 train_time:17543ms step_avg:91.85ms +step:192/1670 train_time:17633ms step_avg:91.84ms +step:193/1670 train_time:17724ms step_avg:91.84ms +step:194/1670 train_time:17816ms step_avg:91.83ms +step:195/1670 train_time:17906ms step_avg:91.83ms +step:196/1670 train_time:17999ms step_avg:91.83ms +step:197/1670 train_time:18090ms step_avg:91.83ms +step:198/1670 train_time:18181ms step_avg:91.82ms +step:199/1670 train_time:18271ms step_avg:91.81ms +step:200/1670 train_time:18362ms step_avg:91.81ms +step:201/1670 train_time:18452ms step_avg:91.80ms +step:202/1670 train_time:18543ms step_avg:91.80ms +step:203/1670 train_time:18634ms step_avg:91.79ms +step:204/1670 train_time:18725ms step_avg:91.79ms +step:205/1670 train_time:18816ms step_avg:91.78ms +step:206/1670 train_time:18906ms step_avg:91.78ms +step:207/1670 train_time:18998ms step_avg:91.78ms +step:208/1670 train_time:19089ms step_avg:91.77ms +step:209/1670 train_time:19180ms step_avg:91.77ms +step:210/1670 train_time:19270ms step_avg:91.76ms +step:211/1670 train_time:19362ms step_avg:91.76ms +step:212/1670 train_time:19453ms step_avg:91.76ms +step:213/1670 train_time:19695ms step_avg:92.46ms +step:214/1670 train_time:19766ms step_avg:92.36ms +step:215/1670 train_time:19856ms step_avg:92.35ms +step:216/1670 train_time:19945ms step_avg:92.34ms +step:217/1670 train_time:20035ms step_avg:92.33ms +step:218/1670 train_time:20125ms step_avg:92.32ms +step:219/1670 train_time:20215ms step_avg:92.31ms +step:220/1670 train_time:20304ms step_avg:92.29ms +step:221/1670 train_time:20394ms step_avg:92.28ms +step:222/1670 train_time:20484ms step_avg:92.27ms +step:223/1670 train_time:20578ms step_avg:92.28ms +step:224/1670 train_time:20675ms step_avg:92.30ms +step:225/1670 train_time:20767ms step_avg:92.30ms +step:226/1670 train_time:20859ms step_avg:92.30ms +step:227/1670 train_time:20950ms step_avg:92.29ms +step:228/1670 train_time:21040ms step_avg:92.28ms +step:229/1670 train_time:21130ms step_avg:92.27ms +step:230/1670 train_time:21221ms step_avg:92.26ms +step:231/1670 train_time:21312ms step_avg:92.26ms +step:232/1670 train_time:21402ms step_avg:92.25ms +step:233/1670 
train_time:21493ms step_avg:92.24ms
+step:234/1670 train_time:21585ms step_avg:92.24ms
+step:235/1670 train_time:21680ms step_avg:92.26ms
[~1,400 per-step train_time entries for steps 236-1654, one per line in the same format; step_avg drifts gradually from 92.24ms to 93.13ms, with brief timing spikes near steps 425, 558, 639, 851, 1062, 1115, 1274, and 1485. Validation checkpoints from this span:]
+step:250/1670 val_loss:3.9645 train_time:23141ms step_avg:92.56ms
+step:375/1670 val_loss:3.8109 train_time:34525ms step_avg:92.07ms
+step:500/1670 val_loss:3.7103 train_time:46039ms step_avg:92.08ms
+step:625/1670 val_loss:3.6125 train_time:57696ms step_avg:92.31ms
+step:750/1670 val_loss:3.5602 train_time:69380ms step_avg:92.51ms
+step:875/1670 val_loss:3.5154 train_time:81077ms step_avg:92.66ms
+step:1000/1670 val_loss:3.4668 train_time:92631ms step_avg:92.63ms
+step:1125/1670 val_loss:3.4147 train_time:104509ms step_avg:92.90ms
+step:1250/1670 val_loss:3.3765 train_time:116183ms step_avg:92.95ms
+step:1375/1670 val_loss:3.3418 train_time:127971ms step_avg:93.07ms
+step:1500/1670 val_loss:3.3119 train_time:139768ms step_avg:93.18ms
+step:1625/1670 val_loss:3.2870 train_time:151429ms step_avg:93.19ms
+step:1654/1670 train_time:154043ms step_avg:93.13ms
+step:1655/1670
train_time:154137ms step_avg:93.13ms +step:1656/1670 train_time:154231ms step_avg:93.13ms +step:1657/1670 train_time:154324ms step_avg:93.13ms +step:1658/1670 train_time:154417ms step_avg:93.13ms +step:1659/1670 train_time:154511ms step_avg:93.14ms +step:1660/1670 train_time:154604ms step_avg:93.13ms +step:1661/1670 train_time:154697ms step_avg:93.13ms +step:1662/1670 train_time:154791ms step_avg:93.14ms +step:1663/1670 train_time:154884ms step_avg:93.14ms +step:1664/1670 train_time:154976ms step_avg:93.13ms +step:1665/1670 train_time:155070ms step_avg:93.14ms +step:1666/1670 train_time:155163ms step_avg:93.14ms +step:1667/1670 train_time:155256ms step_avg:93.13ms +step:1668/1670 train_time:155351ms step_avg:93.14ms +step:1669/1670 train_time:155444ms step_avg:93.14ms +step:1670/1670 train_time:155536ms step_avg:93.14ms +step:1670/1670 val_loss:3.2783 train_time:155799ms step_avg:93.29ms +peak memory allocated: 32002 MiB reserved: 47856 MiB diff --git a/records/091125_VectSigmoidBFloat16/6ecf5ea5-e999-4da1-a501-4fbc7160aec5.txt b/records/091125_VectSigmoidBFloat16/6ecf5ea5-e999-4da1-a501-4fbc7160aec5.txt new file mode 100644 index 000000000..4cc43442f --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/6ecf5ea5-e999-4da1-a501-4fbc7160aec5.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
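+# For readability: the fused Triton kernels above compute the same quintic
+# Newton-Schulz iteration as this minimal pure-PyTorch reference. This is a
+# sketch for comparison only -- `newton_schulz_reference` is a hypothetical
+# helper, not used by the run -- and it assumes a 2D or batched-2D input,
+# matching newton_schulz_triton.
+def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # iterate on the wide orientation
+    # ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    for _ in range(steps):
+        A = X @ X.mT           # ns_line_1: A = X @ X.mT
+        B = b * A + c * A @ A  # ns_line_2: B = b * A + c * A @ A
+        X = a * X + B @ X      # ns_line_3: X = a * X + B @ X
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+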
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                batch = original_shape[0] * original_shape[1]
+                # Flatten all but the first two dims after batch
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
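+        # Worked example: next_multiple_of_n(50257, n=128) == 50304, since
+        # 128 * 392 = 50176 < 50257 <= 128 * 393 = 50304; the 47 padding ids
+        # never occur as targets in the data.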
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = ( + [ve[0], ve[1], ve[2]] + + [None] * (len(self.blocks) - 6) + + [ve[0], ve[1], ve[2]] + ) + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [ + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + ] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]).to( + torch.bfloat16 + ) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[: (len(self.blocks) // 2)] + lambdas = self.scalars[ + 1 * len(self.blocks) : 3 * len(self.blocks) + ].view(-1, 2) + sa_lambdas = self.scalars[ + 3 * len(self.blocks) : 5 * len(self.blocks) + ].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws], + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + + +def _load_data_shard(file: Path): + header = torch.from_file( + str(file), False, 256, dtype=torch.int32 + ) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty( + num_tokens, dtype=torch.uint16, pin_memory=True + ) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto( + tokens.numpy() + ) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, ( + "number of tokens read does not match header" + ) + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = ( + (tokens == BOS_ID) + .nonzero(as_tuple=True)[0] + .to(torch.int64) + .cpu() + .numpy() + ) + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration( + f"Insufficient BOS ahead of position {cur}; hit tail of shard." 
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:24:28 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:05:00.0 Off |                  Off |
+| N/A   40C    P0            126W /  700W |    5858MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:06:00.0 Off |                  Off |
+| N/A   45C    P0            132W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:65:00.0 Off |                  Off |
+| N/A   45C    P0            125W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:68:00.0 Off |                  Off |
+| N/A   37C    P0            120W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:85:00.0 Off |                  Off |
+| N/A   38C    P0            122W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:86:00.0 Off |                  Off |
+| N/A   45C    P0            123W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:E5:00.0 Off |                  Off |
+| N/A   45C    P0            123W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:E8:00.0 Off |                  Off |
+| N/A   41C    P0            132W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.09ms
+step:1/1670 train_time:297ms step_avg:297.03ms
+step:2/1670 train_time:315ms step_avg:157.52ms
+step:3/1670 train_time:384ms step_avg:127.86ms
+step:4/1670 train_time:473ms step_avg:118.16ms
+step:5/1670 train_time:563ms step_avg:112.68ms
+step:6/1670 train_time:654ms step_avg:108.95ms
+step:7/1670 train_time:744ms step_avg:106.26ms
+step:8/1670 train_time:834ms step_avg:104.28ms
+step:9/1670 train_time:925ms step_avg:102.75ms
+step:10/1670 train_time:1016ms step_avg:101.56ms
+step:11/1670 train_time:1109ms step_avg:100.78ms
+step:12/1670 train_time:1203ms step_avg:100.21ms
+step:13/1670 train_time:1298ms step_avg:99.81ms
+step:14/1670 train_time:1390ms step_avg:99.32ms
+step:15/1670 train_time:1481ms step_avg:98.74ms
+step:16/1670 train_time:1573ms step_avg:98.28ms
+step:17/1670 train_time:1664ms step_avg:97.91ms
+step:18/1670 train_time:1755ms step_avg:97.49ms
+step:19/1670 train_time:1845ms step_avg:97.12ms
+step:20/1670 train_time:1936ms step_avg:96.79ms
+step:21/1670 train_time:2029ms step_avg:96.62ms
+step:22/1670 train_time:2121ms
step_avg:96.40ms +step:23/1670 train_time:2215ms step_avg:96.31ms +step:24/1670 train_time:2308ms step_avg:96.18ms +step:25/1670 train_time:2401ms step_avg:96.03ms +step:26/1670 train_time:2492ms step_avg:95.86ms +step:27/1670 train_time:2584ms step_avg:95.71ms +step:28/1670 train_time:2675ms step_avg:95.54ms +step:29/1670 train_time:2767ms step_avg:95.42ms +step:30/1670 train_time:2858ms step_avg:95.25ms +step:31/1670 train_time:2949ms step_avg:95.13ms +step:32/1670 train_time:3040ms step_avg:94.99ms +step:33/1670 train_time:3132ms step_avg:94.90ms +step:34/1670 train_time:3224ms step_avg:94.82ms +step:35/1670 train_time:3316ms step_avg:94.75ms +step:36/1670 train_time:3408ms step_avg:94.68ms +step:37/1670 train_time:3501ms step_avg:94.62ms +step:38/1670 train_time:3593ms step_avg:94.54ms +step:39/1670 train_time:3684ms step_avg:94.46ms +step:40/1670 train_time:3775ms step_avg:94.37ms +step:41/1670 train_time:3866ms step_avg:94.28ms +step:42/1670 train_time:3956ms step_avg:94.20ms +step:43/1670 train_time:4048ms step_avg:94.14ms +step:44/1670 train_time:4139ms step_avg:94.08ms +step:45/1670 train_time:4234ms step_avg:94.08ms +step:46/1670 train_time:4327ms step_avg:94.07ms +step:47/1670 train_time:4419ms step_avg:94.03ms +step:48/1670 train_time:4510ms step_avg:93.96ms +step:49/1670 train_time:4602ms step_avg:93.91ms +step:50/1670 train_time:4693ms step_avg:93.85ms +step:51/1670 train_time:4785ms step_avg:93.82ms +step:52/1670 train_time:4876ms step_avg:93.76ms +step:53/1670 train_time:4967ms step_avg:93.72ms +step:54/1670 train_time:5058ms step_avg:93.67ms +step:55/1670 train_time:5150ms step_avg:93.65ms +step:56/1670 train_time:5243ms step_avg:93.63ms +step:57/1670 train_time:5336ms step_avg:93.61ms +step:58/1670 train_time:5428ms step_avg:93.58ms +step:59/1670 train_time:5519ms step_avg:93.55ms +step:60/1670 train_time:5611ms step_avg:93.52ms +step:61/1670 train_time:5703ms step_avg:93.50ms +step:62/1670 train_time:5793ms step_avg:93.44ms +step:63/1670 train_time:5885ms step_avg:93.41ms +step:64/1670 train_time:5975ms step_avg:93.36ms +step:65/1670 train_time:6067ms step_avg:93.34ms +step:66/1670 train_time:6158ms step_avg:93.31ms +step:67/1670 train_time:6253ms step_avg:93.32ms +step:68/1670 train_time:6345ms step_avg:93.30ms +step:69/1670 train_time:6437ms step_avg:93.28ms +step:70/1670 train_time:6528ms step_avg:93.25ms +step:71/1670 train_time:6619ms step_avg:93.23ms +step:72/1670 train_time:6710ms step_avg:93.19ms +step:73/1670 train_time:6800ms step_avg:93.15ms +step:74/1670 train_time:6892ms step_avg:93.13ms +step:75/1670 train_time:6983ms step_avg:93.11ms +step:76/1670 train_time:7076ms step_avg:93.11ms +step:77/1670 train_time:7168ms step_avg:93.09ms +step:78/1670 train_time:7259ms step_avg:93.06ms +step:79/1670 train_time:7353ms step_avg:93.08ms +step:80/1670 train_time:7447ms step_avg:93.09ms +step:81/1670 train_time:7539ms step_avg:93.08ms +step:82/1670 train_time:7632ms step_avg:93.07ms +step:83/1670 train_time:7723ms step_avg:93.05ms +step:84/1670 train_time:7814ms step_avg:93.03ms +step:85/1670 train_time:7905ms step_avg:93.00ms +step:86/1670 train_time:7996ms step_avg:92.98ms +step:87/1670 train_time:8087ms step_avg:92.96ms +step:88/1670 train_time:8178ms step_avg:92.93ms +step:89/1670 train_time:8270ms step_avg:92.92ms +step:90/1670 train_time:8361ms step_avg:92.90ms +step:91/1670 train_time:8454ms step_avg:92.90ms +step:92/1670 train_time:8546ms step_avg:92.90ms +step:93/1670 train_time:8638ms step_avg:92.88ms +step:94/1670 train_time:8729ms step_avg:92.86ms 
+step:95/1670 train_time:8821ms step_avg:92.85ms +step:96/1670 train_time:8912ms step_avg:92.83ms +step:97/1670 train_time:9003ms step_avg:92.82ms +step:98/1670 train_time:9095ms step_avg:92.80ms +step:99/1670 train_time:9186ms step_avg:92.79ms +step:100/1670 train_time:9277ms step_avg:92.77ms +step:101/1670 train_time:9369ms step_avg:92.76ms +step:102/1670 train_time:9460ms step_avg:92.74ms +step:103/1670 train_time:9552ms step_avg:92.74ms +step:104/1670 train_time:9645ms step_avg:92.74ms +step:105/1670 train_time:9736ms step_avg:92.72ms +step:106/1670 train_time:9827ms step_avg:92.71ms +step:107/1670 train_time:9918ms step_avg:92.69ms +step:108/1670 train_time:10009ms step_avg:92.68ms +step:109/1670 train_time:10102ms step_avg:92.68ms +step:110/1670 train_time:10192ms step_avg:92.66ms +step:111/1670 train_time:10285ms step_avg:92.65ms +step:112/1670 train_time:10376ms step_avg:92.65ms +step:113/1670 train_time:10467ms step_avg:92.63ms +step:114/1670 train_time:10558ms step_avg:92.62ms +step:115/1670 train_time:10651ms step_avg:92.61ms +step:116/1670 train_time:10741ms step_avg:92.60ms +step:117/1670 train_time:10834ms step_avg:92.60ms +step:118/1670 train_time:10926ms step_avg:92.60ms +step:119/1670 train_time:11018ms step_avg:92.59ms +step:120/1670 train_time:11109ms step_avg:92.58ms +step:121/1670 train_time:11200ms step_avg:92.56ms +step:122/1670 train_time:11292ms step_avg:92.55ms +step:123/1670 train_time:11382ms step_avg:92.54ms +step:124/1670 train_time:11474ms step_avg:92.53ms +step:125/1670 train_time:11566ms step_avg:92.53ms +step:125/1670 val_loss:4.2978 train_time:11657ms step_avg:93.26ms +step:126/1670 train_time:11676ms step_avg:92.67ms +step:127/1670 train_time:11750ms step_avg:92.52ms +step:128/1670 train_time:11849ms step_avg:92.57ms +step:129/1670 train_time:11944ms step_avg:92.59ms +step:130/1670 train_time:12037ms step_avg:92.59ms +step:131/1670 train_time:12128ms step_avg:92.58ms +step:132/1670 train_time:12218ms step_avg:92.56ms +step:133/1670 train_time:12308ms step_avg:92.54ms +step:134/1670 train_time:12397ms step_avg:92.52ms +step:135/1670 train_time:12487ms step_avg:92.50ms +step:136/1670 train_time:12578ms step_avg:92.49ms +step:137/1670 train_time:12670ms step_avg:92.48ms +step:138/1670 train_time:12763ms step_avg:92.49ms +step:139/1670 train_time:12857ms step_avg:92.50ms +step:140/1670 train_time:12950ms step_avg:92.50ms +step:141/1670 train_time:13041ms step_avg:92.49ms +step:142/1670 train_time:13132ms step_avg:92.48ms +step:143/1670 train_time:13222ms step_avg:92.46ms +step:144/1670 train_time:13313ms step_avg:92.45ms +step:145/1670 train_time:13403ms step_avg:92.44ms +step:146/1670 train_time:13494ms step_avg:92.43ms +step:147/1670 train_time:13585ms step_avg:92.42ms +step:148/1670 train_time:13677ms step_avg:92.41ms +step:149/1670 train_time:13769ms step_avg:92.41ms +step:150/1670 train_time:13861ms step_avg:92.41ms +step:151/1670 train_time:13954ms step_avg:92.41ms +step:152/1670 train_time:14046ms step_avg:92.41ms +step:153/1670 train_time:14137ms step_avg:92.40ms +step:154/1670 train_time:14229ms step_avg:92.40ms +step:155/1670 train_time:14320ms step_avg:92.39ms +step:156/1670 train_time:14410ms step_avg:92.37ms +step:157/1670 train_time:14501ms step_avg:92.36ms +step:158/1670 train_time:14591ms step_avg:92.35ms +step:159/1670 train_time:14681ms step_avg:92.34ms +step:160/1670 train_time:14772ms step_avg:92.33ms +step:161/1670 train_time:14864ms step_avg:92.32ms +step:162/1670 train_time:14956ms step_avg:92.32ms +step:163/1670 train_time:15047ms 
step_avg:92.31ms +step:164/1670 train_time:15139ms step_avg:92.31ms +step:165/1670 train_time:15231ms step_avg:92.31ms +step:166/1670 train_time:15322ms step_avg:92.30ms +step:167/1670 train_time:15413ms step_avg:92.30ms +step:168/1670 train_time:15504ms step_avg:92.29ms +step:169/1670 train_time:15595ms step_avg:92.28ms +step:170/1670 train_time:15686ms step_avg:92.27ms +step:171/1670 train_time:15776ms step_avg:92.26ms +step:172/1670 train_time:15868ms step_avg:92.25ms +step:173/1670 train_time:15959ms step_avg:92.25ms +step:174/1670 train_time:16051ms step_avg:92.25ms +step:175/1670 train_time:16141ms step_avg:92.24ms +step:176/1670 train_time:16233ms step_avg:92.24ms +step:177/1670 train_time:16325ms step_avg:92.23ms +step:178/1670 train_time:16417ms step_avg:92.23ms +step:179/1670 train_time:16509ms step_avg:92.23ms +step:180/1670 train_time:16600ms step_avg:92.22ms +step:181/1670 train_time:16692ms step_avg:92.22ms +step:182/1670 train_time:16783ms step_avg:92.21ms +step:183/1670 train_time:16875ms step_avg:92.21ms +step:184/1670 train_time:16965ms step_avg:92.20ms +step:185/1670 train_time:17056ms step_avg:92.20ms +step:186/1670 train_time:17148ms step_avg:92.19ms +step:187/1670 train_time:17239ms step_avg:92.19ms +step:188/1670 train_time:17330ms step_avg:92.18ms +step:189/1670 train_time:17421ms step_avg:92.17ms +step:190/1670 train_time:17512ms step_avg:92.17ms +step:191/1670 train_time:17603ms step_avg:92.16ms +step:192/1670 train_time:17695ms step_avg:92.16ms +step:193/1670 train_time:17787ms step_avg:92.16ms +step:194/1670 train_time:17878ms step_avg:92.16ms +step:195/1670 train_time:17969ms step_avg:92.15ms +step:196/1670 train_time:18060ms step_avg:92.14ms +step:197/1670 train_time:18151ms step_avg:92.14ms +step:198/1670 train_time:18241ms step_avg:92.13ms +step:199/1670 train_time:18332ms step_avg:92.12ms +step:200/1670 train_time:18423ms step_avg:92.11ms +step:201/1670 train_time:18515ms step_avg:92.11ms +step:202/1670 train_time:18605ms step_avg:92.11ms +step:203/1670 train_time:18697ms step_avg:92.10ms +step:204/1670 train_time:18788ms step_avg:92.10ms +step:205/1670 train_time:18879ms step_avg:92.09ms +step:206/1670 train_time:18970ms step_avg:92.09ms +step:207/1670 train_time:19060ms step_avg:92.08ms +step:208/1670 train_time:19151ms step_avg:92.07ms +step:209/1670 train_time:19241ms step_avg:92.06ms +step:210/1670 train_time:19332ms step_avg:92.06ms +step:211/1670 train_time:19423ms step_avg:92.05ms +step:212/1670 train_time:19515ms step_avg:92.05ms +step:213/1670 train_time:19762ms step_avg:92.78ms +step:214/1670 train_time:19834ms step_avg:92.68ms +step:215/1670 train_time:19923ms step_avg:92.67ms +step:216/1670 train_time:20013ms step_avg:92.65ms +step:217/1670 train_time:20103ms step_avg:92.64ms +step:218/1670 train_time:20193ms step_avg:92.63ms +step:219/1670 train_time:20284ms step_avg:92.62ms +step:220/1670 train_time:20374ms step_avg:92.61ms +step:221/1670 train_time:20464ms step_avg:92.60ms +step:222/1670 train_time:20554ms step_avg:92.59ms +step:223/1670 train_time:20648ms step_avg:92.59ms +step:224/1670 train_time:20743ms step_avg:92.60ms +step:225/1670 train_time:20837ms step_avg:92.61ms +step:226/1670 train_time:20929ms step_avg:92.61ms +step:227/1670 train_time:21019ms step_avg:92.59ms +step:228/1670 train_time:21110ms step_avg:92.59ms +step:229/1670 train_time:21200ms step_avg:92.58ms +step:230/1670 train_time:21290ms step_avg:92.57ms +step:231/1670 train_time:21380ms step_avg:92.55ms +step:232/1670 train_time:21470ms step_avg:92.54ms +step:233/1670 
train_time:21561ms step_avg:92.53ms +step:234/1670 train_time:21655ms step_avg:92.54ms +step:235/1670 train_time:21748ms step_avg:92.54ms +step:236/1670 train_time:21841ms step_avg:92.55ms +step:237/1670 train_time:21934ms step_avg:92.55ms +step:238/1670 train_time:22024ms step_avg:92.54ms +step:239/1670 train_time:22115ms step_avg:92.53ms +step:240/1670 train_time:22206ms step_avg:92.53ms +step:241/1670 train_time:22297ms step_avg:92.52ms +step:242/1670 train_time:22388ms step_avg:92.51ms +step:243/1670 train_time:22478ms step_avg:92.50ms +step:244/1670 train_time:22570ms step_avg:92.50ms +step:245/1670 train_time:22662ms step_avg:92.50ms +step:246/1670 train_time:22756ms step_avg:92.50ms +step:247/1670 train_time:22847ms step_avg:92.50ms +step:248/1670 train_time:22939ms step_avg:92.50ms +step:249/1670 train_time:23030ms step_avg:92.49ms +step:250/1670 train_time:23121ms step_avg:92.49ms +step:250/1670 val_loss:3.9621 train_time:23212ms step_avg:92.85ms +step:251/1670 train_time:23231ms step_avg:92.55ms +step:252/1670 train_time:23304ms step_avg:92.48ms +step:253/1670 train_time:23396ms step_avg:92.48ms +step:254/1670 train_time:23488ms step_avg:92.47ms +step:255/1670 train_time:23578ms step_avg:92.46ms +step:256/1670 train_time:23669ms step_avg:92.46ms +step:257/1670 train_time:23759ms step_avg:92.45ms +step:258/1670 train_time:23849ms step_avg:92.44ms +step:259/1670 train_time:23940ms step_avg:92.43ms +step:260/1670 train_time:24031ms step_avg:92.43ms +step:261/1670 train_time:24123ms step_avg:92.42ms +step:262/1670 train_time:24216ms step_avg:92.43ms +step:263/1670 train_time:24309ms step_avg:92.43ms +step:264/1670 train_time:24400ms step_avg:92.42ms +step:265/1670 train_time:24492ms step_avg:92.42ms +step:266/1670 train_time:24584ms step_avg:92.42ms +step:267/1670 train_time:24675ms step_avg:92.41ms +step:268/1670 train_time:24765ms step_avg:92.41ms +step:269/1670 train_time:24854ms step_avg:92.40ms +step:270/1670 train_time:24945ms step_avg:92.39ms +step:271/1670 train_time:25036ms step_avg:92.38ms +step:272/1670 train_time:25129ms step_avg:92.39ms +step:273/1670 train_time:25221ms step_avg:92.39ms +step:274/1670 train_time:25316ms step_avg:92.39ms +step:275/1670 train_time:25408ms step_avg:92.39ms +step:276/1670 train_time:25498ms step_avg:92.38ms +step:277/1670 train_time:25590ms step_avg:92.38ms +step:278/1670 train_time:25681ms step_avg:92.38ms +step:279/1670 train_time:25772ms step_avg:92.37ms +step:280/1670 train_time:25863ms step_avg:92.37ms +step:281/1670 train_time:25954ms step_avg:92.36ms +step:282/1670 train_time:26045ms step_avg:92.36ms +step:283/1670 train_time:26136ms step_avg:92.35ms +step:284/1670 train_time:26228ms step_avg:92.35ms +step:285/1670 train_time:26319ms step_avg:92.35ms +step:286/1670 train_time:26412ms step_avg:92.35ms +step:287/1670 train_time:26503ms step_avg:92.35ms +step:288/1670 train_time:26595ms step_avg:92.34ms +step:289/1670 train_time:26685ms step_avg:92.34ms +step:290/1670 train_time:26776ms step_avg:92.33ms +step:291/1670 train_time:26867ms step_avg:92.33ms +step:292/1670 train_time:26957ms step_avg:92.32ms +step:293/1670 train_time:27048ms step_avg:92.31ms +step:294/1670 train_time:27139ms step_avg:92.31ms +step:295/1670 train_time:27230ms step_avg:92.31ms +step:296/1670 train_time:27322ms step_avg:92.30ms +step:297/1670 train_time:27414ms step_avg:92.30ms +step:298/1670 train_time:27506ms step_avg:92.30ms +step:299/1670 train_time:27597ms step_avg:92.30ms +step:300/1670 train_time:27689ms step_avg:92.30ms +step:301/1670 train_time:27780ms 
step_avg:92.29ms +step:302/1670 train_time:27872ms step_avg:92.29ms +step:303/1670 train_time:27963ms step_avg:92.29ms +step:304/1670 train_time:28053ms step_avg:92.28ms +step:305/1670 train_time:28145ms step_avg:92.28ms +step:306/1670 train_time:28236ms step_avg:92.27ms +step:307/1670 train_time:28328ms step_avg:92.28ms +step:308/1670 train_time:28419ms step_avg:92.27ms +step:309/1670 train_time:28511ms step_avg:92.27ms +step:310/1670 train_time:28602ms step_avg:92.26ms +step:311/1670 train_time:28695ms step_avg:92.27ms +step:312/1670 train_time:28786ms step_avg:92.26ms +step:313/1670 train_time:28877ms step_avg:92.26ms +step:314/1670 train_time:28968ms step_avg:92.25ms +step:315/1670 train_time:29058ms step_avg:92.25ms +step:316/1670 train_time:29149ms step_avg:92.24ms +step:317/1670 train_time:29240ms step_avg:92.24ms +step:318/1670 train_time:29331ms step_avg:92.24ms +step:319/1670 train_time:29423ms step_avg:92.23ms +step:320/1670 train_time:29514ms step_avg:92.23ms +step:321/1670 train_time:29605ms step_avg:92.23ms +step:322/1670 train_time:29696ms step_avg:92.22ms +step:323/1670 train_time:29789ms step_avg:92.23ms +step:324/1670 train_time:29880ms step_avg:92.22ms +step:325/1670 train_time:29970ms step_avg:92.22ms +step:326/1670 train_time:30060ms step_avg:92.21ms +step:327/1670 train_time:30152ms step_avg:92.21ms +step:328/1670 train_time:30243ms step_avg:92.20ms +step:329/1670 train_time:30334ms step_avg:92.20ms +step:330/1670 train_time:30426ms step_avg:92.20ms +step:331/1670 train_time:30517ms step_avg:92.20ms +step:332/1670 train_time:30610ms step_avg:92.20ms +step:333/1670 train_time:30700ms step_avg:92.19ms +step:334/1670 train_time:30792ms step_avg:92.19ms +step:335/1670 train_time:30883ms step_avg:92.19ms +step:336/1670 train_time:30974ms step_avg:92.19ms +step:337/1670 train_time:31065ms step_avg:92.18ms +step:338/1670 train_time:31156ms step_avg:92.18ms +step:339/1670 train_time:31247ms step_avg:92.17ms +step:340/1670 train_time:31338ms step_avg:92.17ms +step:341/1670 train_time:31430ms step_avg:92.17ms +step:342/1670 train_time:31520ms step_avg:92.16ms +step:343/1670 train_time:31613ms step_avg:92.17ms +step:344/1670 train_time:31704ms step_avg:92.16ms +step:345/1670 train_time:31796ms step_avg:92.16ms +step:346/1670 train_time:31887ms step_avg:92.16ms +step:347/1670 train_time:31977ms step_avg:92.15ms +step:348/1670 train_time:32069ms step_avg:92.15ms +step:349/1670 train_time:32159ms step_avg:92.15ms +step:350/1670 train_time:32250ms step_avg:92.14ms +step:351/1670 train_time:32341ms step_avg:92.14ms +step:352/1670 train_time:32432ms step_avg:92.14ms +step:353/1670 train_time:32523ms step_avg:92.13ms +step:354/1670 train_time:32615ms step_avg:92.13ms +step:355/1670 train_time:32707ms step_avg:92.13ms +step:356/1670 train_time:32797ms step_avg:92.13ms +step:357/1670 train_time:32888ms step_avg:92.12ms +step:358/1670 train_time:32978ms step_avg:92.12ms +step:359/1670 train_time:33069ms step_avg:92.11ms +step:360/1670 train_time:33161ms step_avg:92.11ms +step:361/1670 train_time:33252ms step_avg:92.11ms +step:362/1670 train_time:33344ms step_avg:92.11ms +step:363/1670 train_time:33434ms step_avg:92.11ms +step:364/1670 train_time:33525ms step_avg:92.10ms +step:365/1670 train_time:33617ms step_avg:92.10ms +step:366/1670 train_time:33709ms step_avg:92.10ms +step:367/1670 train_time:33800ms step_avg:92.10ms +step:368/1670 train_time:33891ms step_avg:92.10ms +step:369/1670 train_time:33982ms step_avg:92.09ms +step:370/1670 train_time:34074ms step_avg:92.09ms +step:371/1670 
train_time:34166ms step_avg:92.09ms +step:372/1670 train_time:34257ms step_avg:92.09ms +step:373/1670 train_time:34348ms step_avg:92.09ms +step:374/1670 train_time:34440ms step_avg:92.08ms +step:375/1670 train_time:34531ms step_avg:92.08ms +step:375/1670 val_loss:3.8126 train_time:34621ms step_avg:92.32ms +step:376/1670 train_time:34642ms step_avg:92.13ms +step:377/1670 train_time:34716ms step_avg:92.08ms +step:378/1670 train_time:34807ms step_avg:92.08ms +step:379/1670 train_time:34898ms step_avg:92.08ms +step:380/1670 train_time:34989ms step_avg:92.08ms +step:381/1670 train_time:35081ms step_avg:92.08ms +step:382/1670 train_time:35172ms step_avg:92.07ms +step:383/1670 train_time:35263ms step_avg:92.07ms +step:384/1670 train_time:35355ms step_avg:92.07ms +step:385/1670 train_time:35446ms step_avg:92.07ms +step:386/1670 train_time:35537ms step_avg:92.07ms +step:387/1670 train_time:35630ms step_avg:92.07ms +step:388/1670 train_time:35723ms step_avg:92.07ms +step:389/1670 train_time:35815ms step_avg:92.07ms +step:390/1670 train_time:35904ms step_avg:92.06ms +step:391/1670 train_time:35995ms step_avg:92.06ms +step:392/1670 train_time:36086ms step_avg:92.06ms +step:393/1670 train_time:36177ms step_avg:92.05ms +step:394/1670 train_time:36267ms step_avg:92.05ms +step:395/1670 train_time:36360ms step_avg:92.05ms +step:396/1670 train_time:36450ms step_avg:92.05ms +step:397/1670 train_time:36543ms step_avg:92.05ms +step:398/1670 train_time:36635ms step_avg:92.05ms +step:399/1670 train_time:36726ms step_avg:92.04ms +step:400/1670 train_time:36817ms step_avg:92.04ms +step:401/1670 train_time:36908ms step_avg:92.04ms +step:402/1670 train_time:36999ms step_avg:92.04ms +step:403/1670 train_time:37089ms step_avg:92.03ms +step:404/1670 train_time:37183ms step_avg:92.04ms +step:405/1670 train_time:37273ms step_avg:92.03ms +step:406/1670 train_time:37364ms step_avg:92.03ms +step:407/1670 train_time:37456ms step_avg:92.03ms +step:408/1670 train_time:37548ms step_avg:92.03ms +step:409/1670 train_time:37639ms step_avg:92.03ms +step:410/1670 train_time:37730ms step_avg:92.03ms +step:411/1670 train_time:37821ms step_avg:92.02ms +step:412/1670 train_time:37912ms step_avg:92.02ms +step:413/1670 train_time:38003ms step_avg:92.02ms +step:414/1670 train_time:38095ms step_avg:92.02ms +step:415/1670 train_time:38187ms step_avg:92.02ms +step:416/1670 train_time:38279ms step_avg:92.02ms +step:417/1670 train_time:38369ms step_avg:92.01ms +step:418/1670 train_time:38461ms step_avg:92.01ms +step:419/1670 train_time:38552ms step_avg:92.01ms +step:420/1670 train_time:38645ms step_avg:92.01ms +step:421/1670 train_time:38737ms step_avg:92.01ms +step:422/1670 train_time:38827ms step_avg:92.01ms +step:423/1670 train_time:38918ms step_avg:92.00ms +step:424/1670 train_time:39008ms step_avg:92.00ms +step:425/1670 train_time:39260ms step_avg:92.38ms +step:426/1670 train_time:39329ms step_avg:92.32ms +step:427/1670 train_time:39419ms step_avg:92.32ms +step:428/1670 train_time:39509ms step_avg:92.31ms +step:429/1670 train_time:39599ms step_avg:92.30ms +step:430/1670 train_time:39688ms step_avg:92.30ms +step:431/1670 train_time:39778ms step_avg:92.29ms +step:432/1670 train_time:39868ms step_avg:92.29ms +step:433/1670 train_time:39958ms step_avg:92.28ms +step:434/1670 train_time:40048ms step_avg:92.28ms +step:435/1670 train_time:40142ms step_avg:92.28ms +step:436/1670 train_time:40238ms step_avg:92.29ms +step:437/1670 train_time:40331ms step_avg:92.29ms +step:438/1670 train_time:40422ms step_avg:92.29ms +step:439/1670 train_time:40513ms 
step_avg:92.28ms +step:440/1670 train_time:40603ms step_avg:92.28ms +step:441/1670 train_time:40693ms step_avg:92.27ms +step:442/1670 train_time:40783ms step_avg:92.27ms +step:443/1670 train_time:40874ms step_avg:92.27ms +step:444/1670 train_time:40964ms step_avg:92.26ms +step:445/1670 train_time:41055ms step_avg:92.26ms +step:446/1670 train_time:41148ms step_avg:92.26ms +step:447/1670 train_time:41242ms step_avg:92.26ms +step:448/1670 train_time:41335ms step_avg:92.26ms +step:449/1670 train_time:41426ms step_avg:92.26ms +step:450/1670 train_time:41517ms step_avg:92.26ms +step:451/1670 train_time:41607ms step_avg:92.26ms +step:452/1670 train_time:41698ms step_avg:92.25ms +step:453/1670 train_time:41788ms step_avg:92.25ms +step:454/1670 train_time:41880ms step_avg:92.25ms +step:455/1670 train_time:41970ms step_avg:92.24ms +step:456/1670 train_time:42061ms step_avg:92.24ms +step:457/1670 train_time:42153ms step_avg:92.24ms +step:458/1670 train_time:42247ms step_avg:92.24ms +step:459/1670 train_time:42340ms step_avg:92.24ms +step:460/1670 train_time:42431ms step_avg:92.24ms +step:461/1670 train_time:42522ms step_avg:92.24ms +step:462/1670 train_time:42613ms step_avg:92.24ms +step:463/1670 train_time:42704ms step_avg:92.23ms +step:464/1670 train_time:42795ms step_avg:92.23ms +step:465/1670 train_time:42886ms step_avg:92.23ms +step:466/1670 train_time:42977ms step_avg:92.23ms +step:467/1670 train_time:43069ms step_avg:92.22ms +step:468/1670 train_time:43162ms step_avg:92.23ms +step:469/1670 train_time:43254ms step_avg:92.23ms +step:470/1670 train_time:43347ms step_avg:92.23ms +step:471/1670 train_time:43439ms step_avg:92.23ms +step:472/1670 train_time:43529ms step_avg:92.22ms +step:473/1670 train_time:43621ms step_avg:92.22ms +step:474/1670 train_time:43712ms step_avg:92.22ms +step:475/1670 train_time:43802ms step_avg:92.22ms +step:476/1670 train_time:43892ms step_avg:92.21ms +step:477/1670 train_time:43984ms step_avg:92.21ms +step:478/1670 train_time:44076ms step_avg:92.21ms +step:479/1670 train_time:44167ms step_avg:92.21ms +step:480/1670 train_time:44259ms step_avg:92.21ms +step:481/1670 train_time:44351ms step_avg:92.20ms +step:482/1670 train_time:44442ms step_avg:92.20ms +step:483/1670 train_time:44533ms step_avg:92.20ms +step:484/1670 train_time:44624ms step_avg:92.20ms +step:485/1670 train_time:44715ms step_avg:92.20ms +step:486/1670 train_time:44806ms step_avg:92.19ms +step:487/1670 train_time:44896ms step_avg:92.19ms +step:488/1670 train_time:44987ms step_avg:92.19ms +step:489/1670 train_time:45078ms step_avg:92.18ms +step:490/1670 train_time:45169ms step_avg:92.18ms +step:491/1670 train_time:45260ms step_avg:92.18ms +step:492/1670 train_time:45351ms step_avg:92.18ms +step:493/1670 train_time:45443ms step_avg:92.18ms +step:494/1670 train_time:45534ms step_avg:92.17ms +step:495/1670 train_time:45625ms step_avg:92.17ms +step:496/1670 train_time:45716ms step_avg:92.17ms +step:497/1670 train_time:45806ms step_avg:92.17ms +step:498/1670 train_time:45897ms step_avg:92.16ms +step:499/1670 train_time:45988ms step_avg:92.16ms +step:500/1670 train_time:46079ms step_avg:92.16ms +step:500/1670 val_loss:3.7158 train_time:46170ms step_avg:92.34ms +step:501/1670 train_time:46190ms step_avg:92.20ms +step:502/1670 train_time:46263ms step_avg:92.16ms +step:503/1670 train_time:46356ms step_avg:92.16ms +step:504/1670 train_time:46448ms step_avg:92.16ms +step:505/1670 train_time:46539ms step_avg:92.16ms +step:506/1670 train_time:46630ms step_avg:92.15ms +step:507/1670 train_time:46721ms step_avg:92.15ms 
+step:508/1670 train_time:46813ms step_avg:92.15ms +step:509/1670 train_time:46904ms step_avg:92.15ms +step:510/1670 train_time:46995ms step_avg:92.15ms +step:511/1670 train_time:47086ms step_avg:92.14ms +step:512/1670 train_time:47178ms step_avg:92.14ms +step:513/1670 train_time:47269ms step_avg:92.14ms +step:514/1670 train_time:47362ms step_avg:92.14ms +step:515/1670 train_time:47454ms step_avg:92.14ms +step:516/1670 train_time:47545ms step_avg:92.14ms +step:517/1670 train_time:47636ms step_avg:92.14ms +step:518/1670 train_time:47727ms step_avg:92.14ms +step:519/1670 train_time:47819ms step_avg:92.14ms +step:520/1670 train_time:47909ms step_avg:92.13ms +step:521/1670 train_time:48000ms step_avg:92.13ms +step:522/1670 train_time:48092ms step_avg:92.13ms +step:523/1670 train_time:48184ms step_avg:92.13ms +step:524/1670 train_time:48275ms step_avg:92.13ms +step:525/1670 train_time:48367ms step_avg:92.13ms +step:526/1670 train_time:48460ms step_avg:92.13ms +step:527/1670 train_time:48550ms step_avg:92.13ms +step:528/1670 train_time:48641ms step_avg:92.12ms +step:529/1670 train_time:48732ms step_avg:92.12ms +step:530/1670 train_time:48823ms step_avg:92.12ms +step:531/1670 train_time:48914ms step_avg:92.12ms +step:532/1670 train_time:49004ms step_avg:92.11ms +step:533/1670 train_time:49095ms step_avg:92.11ms +step:534/1670 train_time:49185ms step_avg:92.11ms +step:535/1670 train_time:49277ms step_avg:92.11ms +step:536/1670 train_time:49368ms step_avg:92.10ms +step:537/1670 train_time:49460ms step_avg:92.10ms +step:538/1670 train_time:49551ms step_avg:92.10ms +step:539/1670 train_time:49642ms step_avg:92.10ms +step:540/1670 train_time:49732ms step_avg:92.10ms +step:541/1670 train_time:49824ms step_avg:92.10ms +step:542/1670 train_time:49916ms step_avg:92.10ms +step:543/1670 train_time:50007ms step_avg:92.09ms +step:544/1670 train_time:50098ms step_avg:92.09ms +step:545/1670 train_time:50188ms step_avg:92.09ms +step:546/1670 train_time:50279ms step_avg:92.09ms +step:547/1670 train_time:50371ms step_avg:92.09ms +step:548/1670 train_time:50462ms step_avg:92.08ms +step:549/1670 train_time:50553ms step_avg:92.08ms +step:550/1670 train_time:50645ms step_avg:92.08ms +step:551/1670 train_time:50737ms step_avg:92.08ms +step:552/1670 train_time:50828ms step_avg:92.08ms +step:553/1670 train_time:50918ms step_avg:92.08ms +step:554/1670 train_time:51009ms step_avg:92.07ms +step:555/1670 train_time:51099ms step_avg:92.07ms +step:556/1670 train_time:51190ms step_avg:92.07ms +step:557/1670 train_time:51281ms step_avg:92.07ms +step:558/1670 train_time:51569ms step_avg:92.42ms +step:559/1670 train_time:51642ms step_avg:92.38ms +step:560/1670 train_time:51734ms step_avg:92.38ms +step:561/1670 train_time:51825ms step_avg:92.38ms +step:562/1670 train_time:51917ms step_avg:92.38ms +step:563/1670 train_time:52007ms step_avg:92.38ms +step:564/1670 train_time:52099ms step_avg:92.37ms +step:565/1670 train_time:52190ms step_avg:92.37ms +step:566/1670 train_time:52281ms step_avg:92.37ms +step:567/1670 train_time:52373ms step_avg:92.37ms +step:568/1670 train_time:52473ms step_avg:92.38ms +step:569/1670 train_time:52569ms step_avg:92.39ms +step:570/1670 train_time:52663ms step_avg:92.39ms +step:571/1670 train_time:52755ms step_avg:92.39ms +step:572/1670 train_time:52847ms step_avg:92.39ms +step:573/1670 train_time:52939ms step_avg:92.39ms +step:574/1670 train_time:53030ms step_avg:92.39ms +step:575/1670 train_time:53122ms step_avg:92.39ms +step:576/1670 train_time:53213ms step_avg:92.38ms +step:577/1670 train_time:53304ms 
step_avg:92.38ms +step:578/1670 train_time:53399ms step_avg:92.39ms +step:579/1670 train_time:53493ms step_avg:92.39ms +step:580/1670 train_time:53587ms step_avg:92.39ms +step:581/1670 train_time:53680ms step_avg:92.39ms +step:582/1670 train_time:53772ms step_avg:92.39ms +step:583/1670 train_time:53866ms step_avg:92.40ms +step:584/1670 train_time:53959ms step_avg:92.40ms +step:585/1670 train_time:54050ms step_avg:92.39ms +step:586/1670 train_time:54141ms step_avg:92.39ms +step:587/1670 train_time:54232ms step_avg:92.39ms +step:588/1670 train_time:54325ms step_avg:92.39ms +step:589/1670 train_time:54417ms step_avg:92.39ms +step:590/1670 train_time:54510ms step_avg:92.39ms +step:591/1670 train_time:54604ms step_avg:92.39ms +step:592/1670 train_time:54698ms step_avg:92.40ms +step:593/1670 train_time:54791ms step_avg:92.40ms +step:594/1670 train_time:54883ms step_avg:92.40ms +step:595/1670 train_time:54976ms step_avg:92.40ms +step:596/1670 train_time:55068ms step_avg:92.40ms +step:597/1670 train_time:55160ms step_avg:92.40ms +step:598/1670 train_time:55252ms step_avg:92.39ms +step:599/1670 train_time:55345ms step_avg:92.40ms +step:600/1670 train_time:55439ms step_avg:92.40ms +step:601/1670 train_time:55531ms step_avg:92.40ms +step:602/1670 train_time:55624ms step_avg:92.40ms +step:603/1670 train_time:55717ms step_avg:92.40ms +step:604/1670 train_time:55810ms step_avg:92.40ms +step:605/1670 train_time:55902ms step_avg:92.40ms +step:606/1670 train_time:55995ms step_avg:92.40ms +step:607/1670 train_time:56087ms step_avg:92.40ms +step:608/1670 train_time:56178ms step_avg:92.40ms +step:609/1670 train_time:56270ms step_avg:92.40ms +step:610/1670 train_time:56364ms step_avg:92.40ms +step:611/1670 train_time:56457ms step_avg:92.40ms +step:612/1670 train_time:56549ms step_avg:92.40ms +step:613/1670 train_time:56642ms step_avg:92.40ms +step:614/1670 train_time:56734ms step_avg:92.40ms +step:615/1670 train_time:56827ms step_avg:92.40ms +step:616/1670 train_time:56920ms step_avg:92.40ms +step:617/1670 train_time:57011ms step_avg:92.40ms +step:618/1670 train_time:57103ms step_avg:92.40ms +step:619/1670 train_time:57196ms step_avg:92.40ms +step:620/1670 train_time:57288ms step_avg:92.40ms +step:621/1670 train_time:57381ms step_avg:92.40ms +step:622/1670 train_time:57474ms step_avg:92.40ms +step:623/1670 train_time:57567ms step_avg:92.40ms +step:624/1670 train_time:57660ms step_avg:92.40ms +step:625/1670 train_time:57751ms step_avg:92.40ms +step:625/1670 val_loss:3.6125 train_time:57844ms step_avg:92.55ms +step:626/1670 train_time:57865ms step_avg:92.44ms +step:627/1670 train_time:57943ms step_avg:92.41ms +step:628/1670 train_time:58044ms step_avg:92.43ms +step:629/1670 train_time:58138ms step_avg:92.43ms +step:630/1670 train_time:58230ms step_avg:92.43ms +step:631/1670 train_time:58321ms step_avg:92.43ms +step:632/1670 train_time:58412ms step_avg:92.42ms +step:633/1670 train_time:58503ms step_avg:92.42ms +step:634/1670 train_time:58595ms step_avg:92.42ms +step:635/1670 train_time:58686ms step_avg:92.42ms +step:636/1670 train_time:58777ms step_avg:92.42ms +step:637/1670 train_time:58869ms step_avg:92.42ms +step:638/1670 train_time:58966ms step_avg:92.42ms +step:639/1670 train_time:59204ms step_avg:92.65ms +step:640/1670 train_time:59275ms step_avg:92.62ms +step:641/1670 train_time:59366ms step_avg:92.61ms +step:642/1670 train_time:59457ms step_avg:92.61ms +step:643/1670 train_time:59548ms step_avg:92.61ms +step:644/1670 train_time:59640ms step_avg:92.61ms +step:645/1670 train_time:59731ms step_avg:92.61ms 
+step:646/1670 train_time:59823ms step_avg:92.60ms +step:647/1670 train_time:59914ms step_avg:92.60ms +step:648/1670 train_time:60005ms step_avg:92.60ms +step:649/1670 train_time:60103ms step_avg:92.61ms +step:650/1670 train_time:60200ms step_avg:92.62ms +step:651/1670 train_time:60294ms step_avg:92.62ms +step:652/1670 train_time:60386ms step_avg:92.62ms +step:653/1670 train_time:60478ms step_avg:92.61ms +step:654/1670 train_time:60569ms step_avg:92.61ms +step:655/1670 train_time:60663ms step_avg:92.62ms +step:656/1670 train_time:60754ms step_avg:92.61ms +step:657/1670 train_time:60845ms step_avg:92.61ms +step:658/1670 train_time:60936ms step_avg:92.61ms +step:659/1670 train_time:61029ms step_avg:92.61ms +step:660/1670 train_time:61124ms step_avg:92.61ms +step:661/1670 train_time:61218ms step_avg:92.61ms +step:662/1670 train_time:61311ms step_avg:92.61ms +step:663/1670 train_time:61405ms step_avg:92.62ms +step:664/1670 train_time:61497ms step_avg:92.62ms +step:665/1670 train_time:61588ms step_avg:92.61ms +step:666/1670 train_time:61681ms step_avg:92.61ms +step:667/1670 train_time:61772ms step_avg:92.61ms +step:668/1670 train_time:61864ms step_avg:92.61ms +step:669/1670 train_time:61956ms step_avg:92.61ms +step:670/1670 train_time:62048ms step_avg:92.61ms +step:671/1670 train_time:62142ms step_avg:92.61ms +step:672/1670 train_time:62236ms step_avg:92.61ms +step:673/1670 train_time:62329ms step_avg:92.61ms +step:674/1670 train_time:62422ms step_avg:92.61ms +step:675/1670 train_time:62514ms step_avg:92.61ms +step:676/1670 train_time:62606ms step_avg:92.61ms +step:677/1670 train_time:62699ms step_avg:92.61ms +step:678/1670 train_time:62790ms step_avg:92.61ms +step:679/1670 train_time:62882ms step_avg:92.61ms +step:680/1670 train_time:62974ms step_avg:92.61ms +step:681/1670 train_time:63066ms step_avg:92.61ms +step:682/1670 train_time:63159ms step_avg:92.61ms +step:683/1670 train_time:63252ms step_avg:92.61ms +step:684/1670 train_time:63345ms step_avg:92.61ms +step:685/1670 train_time:63438ms step_avg:92.61ms +step:686/1670 train_time:63529ms step_avg:92.61ms +step:687/1670 train_time:63622ms step_avg:92.61ms +step:688/1670 train_time:63714ms step_avg:92.61ms +step:689/1670 train_time:63806ms step_avg:92.61ms +step:690/1670 train_time:63898ms step_avg:92.61ms +step:691/1670 train_time:63991ms step_avg:92.61ms +step:692/1670 train_time:64084ms step_avg:92.61ms +step:693/1670 train_time:64176ms step_avg:92.61ms +step:694/1670 train_time:64269ms step_avg:92.61ms +step:695/1670 train_time:64363ms step_avg:92.61ms +step:696/1670 train_time:64455ms step_avg:92.61ms +step:697/1670 train_time:64546ms step_avg:92.61ms +step:698/1670 train_time:64639ms step_avg:92.61ms +step:699/1670 train_time:64731ms step_avg:92.60ms +step:700/1670 train_time:64823ms step_avg:92.60ms +step:701/1670 train_time:64916ms step_avg:92.60ms +step:702/1670 train_time:65008ms step_avg:92.60ms +step:703/1670 train_time:65101ms step_avg:92.60ms +step:704/1670 train_time:65194ms step_avg:92.60ms +step:705/1670 train_time:65287ms step_avg:92.61ms +step:706/1670 train_time:65380ms step_avg:92.61ms +step:707/1670 train_time:65474ms step_avg:92.61ms +step:708/1670 train_time:65567ms step_avg:92.61ms +step:709/1670 train_time:65660ms step_avg:92.61ms +step:710/1670 train_time:65752ms step_avg:92.61ms +step:711/1670 train_time:65845ms step_avg:92.61ms +step:712/1670 train_time:65938ms step_avg:92.61ms +step:713/1670 train_time:66030ms step_avg:92.61ms +step:714/1670 train_time:66123ms step_avg:92.61ms +step:715/1670 train_time:66215ms 
step_avg:92.61ms +step:716/1670 train_time:66308ms step_avg:92.61ms +step:717/1670 train_time:66400ms step_avg:92.61ms +step:718/1670 train_time:66492ms step_avg:92.61ms +step:719/1670 train_time:66586ms step_avg:92.61ms +step:720/1670 train_time:66679ms step_avg:92.61ms +step:721/1670 train_time:66771ms step_avg:92.61ms +step:722/1670 train_time:66863ms step_avg:92.61ms +step:723/1670 train_time:66956ms step_avg:92.61ms +step:724/1670 train_time:67047ms step_avg:92.61ms +step:725/1670 train_time:67140ms step_avg:92.61ms +step:726/1670 train_time:67232ms step_avg:92.61ms +step:727/1670 train_time:67326ms step_avg:92.61ms +step:728/1670 train_time:67418ms step_avg:92.61ms +step:729/1670 train_time:67510ms step_avg:92.61ms +step:730/1670 train_time:67603ms step_avg:92.61ms +step:731/1670 train_time:67694ms step_avg:92.61ms +step:732/1670 train_time:67787ms step_avg:92.61ms +step:733/1670 train_time:67881ms step_avg:92.61ms +step:734/1670 train_time:67973ms step_avg:92.61ms +step:735/1670 train_time:68066ms step_avg:92.61ms +step:736/1670 train_time:68158ms step_avg:92.61ms +step:737/1670 train_time:68250ms step_avg:92.61ms +step:738/1670 train_time:68343ms step_avg:92.61ms +step:739/1670 train_time:68436ms step_avg:92.61ms +step:740/1670 train_time:68528ms step_avg:92.61ms +step:741/1670 train_time:68620ms step_avg:92.61ms +step:742/1670 train_time:68712ms step_avg:92.60ms +step:743/1670 train_time:68805ms step_avg:92.60ms +step:744/1670 train_time:68898ms step_avg:92.60ms +step:745/1670 train_time:68990ms step_avg:92.60ms +step:746/1670 train_time:69083ms step_avg:92.60ms +step:747/1670 train_time:69175ms step_avg:92.60ms +step:748/1670 train_time:69267ms step_avg:92.60ms +step:749/1670 train_time:69360ms step_avg:92.60ms +step:750/1670 train_time:69452ms step_avg:92.60ms +step:750/1670 val_loss:3.5618 train_time:69544ms step_avg:92.73ms +step:751/1670 train_time:69563ms step_avg:92.63ms +step:752/1670 train_time:69640ms step_avg:92.61ms +step:753/1670 train_time:69732ms step_avg:92.61ms +step:754/1670 train_time:69824ms step_avg:92.60ms +step:755/1670 train_time:69916ms step_avg:92.60ms +step:756/1670 train_time:70007ms step_avg:92.60ms +step:757/1670 train_time:70099ms step_avg:92.60ms +step:758/1670 train_time:70193ms step_avg:92.60ms +step:759/1670 train_time:70284ms step_avg:92.60ms +step:760/1670 train_time:70376ms step_avg:92.60ms +step:761/1670 train_time:70469ms step_avg:92.60ms +step:762/1670 train_time:70563ms step_avg:92.60ms +step:763/1670 train_time:70659ms step_avg:92.61ms +step:764/1670 train_time:70752ms step_avg:92.61ms +step:765/1670 train_time:70844ms step_avg:92.61ms +step:766/1670 train_time:70936ms step_avg:92.61ms +step:767/1670 train_time:71028ms step_avg:92.61ms +step:768/1670 train_time:71120ms step_avg:92.60ms +step:769/1670 train_time:71212ms step_avg:92.60ms +step:770/1670 train_time:71303ms step_avg:92.60ms +step:771/1670 train_time:71395ms step_avg:92.60ms +step:772/1670 train_time:71488ms step_avg:92.60ms +step:773/1670 train_time:71582ms step_avg:92.60ms +step:774/1670 train_time:71677ms step_avg:92.61ms +step:775/1670 train_time:71770ms step_avg:92.61ms +step:776/1670 train_time:71863ms step_avg:92.61ms +step:777/1670 train_time:71956ms step_avg:92.61ms +step:778/1670 train_time:72048ms step_avg:92.61ms +step:779/1670 train_time:72141ms step_avg:92.61ms +step:780/1670 train_time:72233ms step_avg:92.61ms +step:781/1670 train_time:72325ms step_avg:92.61ms +step:782/1670 train_time:72416ms step_avg:92.60ms +step:783/1670 train_time:72510ms step_avg:92.61ms 
+step:784/1670 train_time:72603ms step_avg:92.61ms +step:785/1670 train_time:72696ms step_avg:92.61ms +step:786/1670 train_time:72788ms step_avg:92.61ms +step:787/1670 train_time:72882ms step_avg:92.61ms +step:788/1670 train_time:72974ms step_avg:92.61ms +step:789/1670 train_time:73066ms step_avg:92.61ms +step:790/1670 train_time:73158ms step_avg:92.61ms +step:791/1670 train_time:73251ms step_avg:92.61ms +step:792/1670 train_time:73343ms step_avg:92.60ms +step:793/1670 train_time:73436ms step_avg:92.61ms +step:794/1670 train_time:73529ms step_avg:92.61ms +step:795/1670 train_time:73621ms step_avg:92.61ms +step:796/1670 train_time:73714ms step_avg:92.61ms +step:797/1670 train_time:73806ms step_avg:92.60ms +step:798/1670 train_time:73899ms step_avg:92.60ms +step:799/1670 train_time:73991ms step_avg:92.60ms +step:800/1670 train_time:74083ms step_avg:92.60ms +step:801/1670 train_time:74175ms step_avg:92.60ms +step:802/1670 train_time:74267ms step_avg:92.60ms +step:803/1670 train_time:74359ms step_avg:92.60ms +step:804/1670 train_time:74452ms step_avg:92.60ms +step:805/1670 train_time:74545ms step_avg:92.60ms +step:806/1670 train_time:74638ms step_avg:92.60ms +step:807/1670 train_time:74730ms step_avg:92.60ms +step:808/1670 train_time:74822ms step_avg:92.60ms +step:809/1670 train_time:74914ms step_avg:92.60ms +step:810/1670 train_time:75006ms step_avg:92.60ms +step:811/1670 train_time:75099ms step_avg:92.60ms +step:812/1670 train_time:75191ms step_avg:92.60ms +step:813/1670 train_time:75282ms step_avg:92.60ms +step:814/1670 train_time:75374ms step_avg:92.60ms +step:815/1670 train_time:75466ms step_avg:92.60ms +step:816/1670 train_time:75561ms step_avg:92.60ms +step:817/1670 train_time:75654ms step_avg:92.60ms +step:818/1670 train_time:75745ms step_avg:92.60ms +step:819/1670 train_time:75838ms step_avg:92.60ms +step:820/1670 train_time:75930ms step_avg:92.60ms +step:821/1670 train_time:76022ms step_avg:92.60ms +step:822/1670 train_time:76116ms step_avg:92.60ms +step:823/1670 train_time:76208ms step_avg:92.60ms +step:824/1670 train_time:76301ms step_avg:92.60ms +step:825/1670 train_time:76394ms step_avg:92.60ms +step:826/1670 train_time:76486ms step_avg:92.60ms +step:827/1670 train_time:76580ms step_avg:92.60ms +step:828/1670 train_time:76673ms step_avg:92.60ms +step:829/1670 train_time:76764ms step_avg:92.60ms +step:830/1670 train_time:76857ms step_avg:92.60ms +step:831/1670 train_time:76950ms step_avg:92.60ms +step:832/1670 train_time:77043ms step_avg:92.60ms +step:833/1670 train_time:77135ms step_avg:92.60ms +step:834/1670 train_time:77227ms step_avg:92.60ms +step:835/1670 train_time:77320ms step_avg:92.60ms +step:836/1670 train_time:77413ms step_avg:92.60ms +step:837/1670 train_time:77505ms step_avg:92.60ms +step:838/1670 train_time:77599ms step_avg:92.60ms +step:839/1670 train_time:77692ms step_avg:92.60ms +step:840/1670 train_time:77784ms step_avg:92.60ms +step:841/1670 train_time:77877ms step_avg:92.60ms +step:842/1670 train_time:77969ms step_avg:92.60ms +step:843/1670 train_time:78063ms step_avg:92.60ms +step:844/1670 train_time:78156ms step_avg:92.60ms +step:845/1670 train_time:78247ms step_avg:92.60ms +step:846/1670 train_time:78341ms step_avg:92.60ms +step:847/1670 train_time:78434ms step_avg:92.60ms +step:848/1670 train_time:78525ms step_avg:92.60ms +step:849/1670 train_time:78620ms step_avg:92.60ms +step:850/1670 train_time:78713ms step_avg:92.60ms +step:851/1670 train_time:78964ms step_avg:92.79ms +step:852/1670 train_time:79035ms step_avg:92.76ms +step:853/1670 train_time:79125ms 
step_avg:92.76ms +step:854/1670 train_time:79217ms step_avg:92.76ms +step:855/1670 train_time:79308ms step_avg:92.76ms +step:856/1670 train_time:79399ms step_avg:92.76ms +step:857/1670 train_time:79491ms step_avg:92.75ms +step:858/1670 train_time:79582ms step_avg:92.75ms +step:859/1670 train_time:79673ms step_avg:92.75ms +step:860/1670 train_time:79764ms step_avg:92.75ms +step:861/1670 train_time:79867ms step_avg:92.76ms +step:862/1670 train_time:79964ms step_avg:92.77ms +step:863/1670 train_time:80057ms step_avg:92.77ms +step:864/1670 train_time:80149ms step_avg:92.76ms +step:865/1670 train_time:80240ms step_avg:92.76ms +step:866/1670 train_time:80331ms step_avg:92.76ms +step:867/1670 train_time:80422ms step_avg:92.76ms +step:868/1670 train_time:80513ms step_avg:92.76ms +step:869/1670 train_time:80604ms step_avg:92.75ms +step:870/1670 train_time:80695ms step_avg:92.75ms +step:871/1670 train_time:80790ms step_avg:92.76ms +step:872/1670 train_time:80885ms step_avg:92.76ms +step:873/1670 train_time:80981ms step_avg:92.76ms +step:874/1670 train_time:81075ms step_avg:92.76ms +step:875/1670 train_time:81167ms step_avg:92.76ms +step:875/1670 val_loss:3.5199 train_time:81260ms step_avg:92.87ms +step:876/1670 train_time:81278ms step_avg:92.78ms +step:877/1670 train_time:81353ms step_avg:92.76ms +step:878/1670 train_time:81446ms step_avg:92.76ms +step:879/1670 train_time:81538ms step_avg:92.76ms +step:880/1670 train_time:81629ms step_avg:92.76ms +step:881/1670 train_time:81721ms step_avg:92.76ms +step:882/1670 train_time:81812ms step_avg:92.76ms +step:883/1670 train_time:81904ms step_avg:92.76ms +step:884/1670 train_time:81997ms step_avg:92.76ms +step:885/1670 train_time:82091ms step_avg:92.76ms +step:886/1670 train_time:82184ms step_avg:92.76ms +step:887/1670 train_time:82278ms step_avg:92.76ms +step:888/1670 train_time:82372ms step_avg:92.76ms +step:889/1670 train_time:82464ms step_avg:92.76ms +step:890/1670 train_time:82556ms step_avg:92.76ms +step:891/1670 train_time:82647ms step_avg:92.76ms +step:892/1670 train_time:82741ms step_avg:92.76ms +step:893/1670 train_time:82833ms step_avg:92.76ms +step:894/1670 train_time:82925ms step_avg:92.76ms +step:895/1670 train_time:83018ms step_avg:92.76ms +step:896/1670 train_time:83110ms step_avg:92.76ms +step:897/1670 train_time:83204ms step_avg:92.76ms +step:898/1670 train_time:83297ms step_avg:92.76ms +step:899/1670 train_time:83390ms step_avg:92.76ms +step:900/1670 train_time:83483ms step_avg:92.76ms +step:901/1670 train_time:83576ms step_avg:92.76ms +step:902/1670 train_time:83668ms step_avg:92.76ms +step:903/1670 train_time:83760ms step_avg:92.76ms +step:904/1670 train_time:83851ms step_avg:92.76ms +step:905/1670 train_time:83944ms step_avg:92.76ms +step:906/1670 train_time:84036ms step_avg:92.75ms +step:907/1670 train_time:84127ms step_avg:92.75ms +step:908/1670 train_time:84222ms step_avg:92.76ms +step:909/1670 train_time:84315ms step_avg:92.76ms +step:910/1670 train_time:84406ms step_avg:92.75ms +step:911/1670 train_time:84500ms step_avg:92.75ms +step:912/1670 train_time:84592ms step_avg:92.75ms +step:913/1670 train_time:84684ms step_avg:92.75ms +step:914/1670 train_time:84776ms step_avg:92.75ms +step:915/1670 train_time:84868ms step_avg:92.75ms +step:916/1670 train_time:84960ms step_avg:92.75ms +step:917/1670 train_time:85053ms step_avg:92.75ms +step:918/1670 train_time:85145ms step_avg:92.75ms +step:919/1670 train_time:85237ms step_avg:92.75ms +step:920/1670 train_time:85329ms step_avg:92.75ms +step:921/1670 train_time:85422ms step_avg:92.75ms 
+step:922/1670 train_time:85515ms step_avg:92.75ms +step:923/1670 train_time:85608ms step_avg:92.75ms +step:924/1670 train_time:85701ms step_avg:92.75ms +step:925/1670 train_time:85793ms step_avg:92.75ms +step:926/1670 train_time:85885ms step_avg:92.75ms +step:927/1670 train_time:85977ms step_avg:92.75ms +step:928/1670 train_time:86069ms step_avg:92.75ms +step:929/1670 train_time:86162ms step_avg:92.75ms +step:930/1670 train_time:86255ms step_avg:92.75ms +step:931/1670 train_time:86347ms step_avg:92.75ms +step:932/1670 train_time:86440ms step_avg:92.75ms +step:933/1670 train_time:86532ms step_avg:92.75ms +step:934/1670 train_time:86626ms step_avg:92.75ms +step:935/1670 train_time:86719ms step_avg:92.75ms +step:936/1670 train_time:86810ms step_avg:92.75ms +step:937/1670 train_time:86903ms step_avg:92.75ms +step:938/1670 train_time:86995ms step_avg:92.75ms +step:939/1670 train_time:87088ms step_avg:92.75ms +step:940/1670 train_time:87181ms step_avg:92.75ms +step:941/1670 train_time:87274ms step_avg:92.75ms +step:942/1670 train_time:87366ms step_avg:92.75ms +step:943/1670 train_time:87458ms step_avg:92.74ms +step:944/1670 train_time:87551ms step_avg:92.75ms +step:945/1670 train_time:87644ms step_avg:92.75ms +step:946/1670 train_time:87737ms step_avg:92.74ms +step:947/1670 train_time:87828ms step_avg:92.74ms +step:948/1670 train_time:87921ms step_avg:92.74ms +step:949/1670 train_time:88014ms step_avg:92.74ms +step:950/1670 train_time:88106ms step_avg:92.74ms +step:951/1670 train_time:88199ms step_avg:92.74ms +step:952/1670 train_time:88291ms step_avg:92.74ms +step:953/1670 train_time:88383ms step_avg:92.74ms +step:954/1670 train_time:88476ms step_avg:92.74ms +step:955/1670 train_time:88568ms step_avg:92.74ms +step:956/1670 train_time:88661ms step_avg:92.74ms +step:957/1670 train_time:88755ms step_avg:92.74ms +step:958/1670 train_time:88847ms step_avg:92.74ms +step:959/1670 train_time:88940ms step_avg:92.74ms +step:960/1670 train_time:89032ms step_avg:92.74ms +step:961/1670 train_time:89125ms step_avg:92.74ms +step:962/1670 train_time:89218ms step_avg:92.74ms +step:963/1670 train_time:89310ms step_avg:92.74ms +step:964/1670 train_time:89403ms step_avg:92.74ms +step:965/1670 train_time:89496ms step_avg:92.74ms +step:966/1670 train_time:89588ms step_avg:92.74ms +step:967/1670 train_time:89681ms step_avg:92.74ms +step:968/1670 train_time:89773ms step_avg:92.74ms +step:969/1670 train_time:89865ms step_avg:92.74ms +step:970/1670 train_time:89958ms step_avg:92.74ms +step:971/1670 train_time:90050ms step_avg:92.74ms +step:972/1670 train_time:90144ms step_avg:92.74ms +step:973/1670 train_time:90236ms step_avg:92.74ms +step:974/1670 train_time:90328ms step_avg:92.74ms +step:975/1670 train_time:90421ms step_avg:92.74ms +step:976/1670 train_time:90513ms step_avg:92.74ms +step:977/1670 train_time:90605ms step_avg:92.74ms +step:978/1670 train_time:90698ms step_avg:92.74ms +step:979/1670 train_time:90791ms step_avg:92.74ms +step:980/1670 train_time:90883ms step_avg:92.74ms +step:981/1670 train_time:90976ms step_avg:92.74ms +step:982/1670 train_time:91068ms step_avg:92.74ms +step:983/1670 train_time:91162ms step_avg:92.74ms +step:984/1670 train_time:91254ms step_avg:92.74ms +step:985/1670 train_time:91346ms step_avg:92.74ms +step:986/1670 train_time:91438ms step_avg:92.74ms +step:987/1670 train_time:91529ms step_avg:92.74ms +step:988/1670 train_time:91623ms step_avg:92.74ms +step:989/1670 train_time:91717ms step_avg:92.74ms +step:990/1670 train_time:91809ms step_avg:92.74ms +step:991/1670 train_time:91902ms 
step_avg:92.74ms +step:992/1670 train_time:91995ms step_avg:92.74ms +step:993/1670 train_time:92088ms step_avg:92.74ms +step:994/1670 train_time:92181ms step_avg:92.74ms +step:995/1670 train_time:92273ms step_avg:92.74ms +step:996/1670 train_time:92365ms step_avg:92.74ms +step:997/1670 train_time:92457ms step_avg:92.74ms +step:998/1670 train_time:92550ms step_avg:92.74ms +step:999/1670 train_time:92643ms step_avg:92.74ms +step:1000/1670 train_time:92736ms step_avg:92.74ms +step:1000/1670 val_loss:3.4695 train_time:92828ms step_avg:92.83ms +step:1001/1670 train_time:92846ms step_avg:92.75ms +step:1002/1670 train_time:92922ms step_avg:92.74ms +step:1003/1670 train_time:93014ms step_avg:92.74ms +step:1004/1670 train_time:93105ms step_avg:92.73ms +step:1005/1670 train_time:93196ms step_avg:92.73ms +step:1006/1670 train_time:93288ms step_avg:92.73ms +step:1007/1670 train_time:93380ms step_avg:92.73ms +step:1008/1670 train_time:93472ms step_avg:92.73ms +step:1009/1670 train_time:93564ms step_avg:92.73ms +step:1010/1670 train_time:93657ms step_avg:92.73ms +step:1011/1670 train_time:93750ms step_avg:92.73ms +step:1012/1670 train_time:93844ms step_avg:92.73ms +step:1013/1670 train_time:93939ms step_avg:92.73ms +step:1014/1670 train_time:94032ms step_avg:92.73ms +step:1015/1670 train_time:94124ms step_avg:92.73ms +step:1016/1670 train_time:94216ms step_avg:92.73ms +step:1017/1670 train_time:94307ms step_avg:92.73ms +step:1018/1670 train_time:94401ms step_avg:92.73ms +step:1019/1670 train_time:94492ms step_avg:92.73ms +step:1020/1670 train_time:94584ms step_avg:92.73ms +step:1021/1670 train_time:94676ms step_avg:92.73ms +step:1022/1670 train_time:94769ms step_avg:92.73ms +step:1023/1670 train_time:94864ms step_avg:92.73ms +step:1024/1670 train_time:94958ms step_avg:92.73ms +step:1025/1670 train_time:95050ms step_avg:92.73ms +step:1026/1670 train_time:95143ms step_avg:92.73ms +step:1027/1670 train_time:95235ms step_avg:92.73ms +step:1028/1670 train_time:95326ms step_avg:92.73ms +step:1029/1670 train_time:95419ms step_avg:92.73ms +step:1030/1670 train_time:95511ms step_avg:92.73ms +step:1031/1670 train_time:95603ms step_avg:92.73ms +step:1032/1670 train_time:95696ms step_avg:92.73ms +step:1033/1670 train_time:95788ms step_avg:92.73ms +step:1034/1670 train_time:95881ms step_avg:92.73ms +step:1035/1670 train_time:95975ms step_avg:92.73ms +step:1036/1670 train_time:96067ms step_avg:92.73ms +step:1037/1670 train_time:96160ms step_avg:92.73ms +step:1038/1670 train_time:96252ms step_avg:92.73ms +step:1039/1670 train_time:96345ms step_avg:92.73ms +step:1040/1670 train_time:96438ms step_avg:92.73ms +step:1041/1670 train_time:96529ms step_avg:92.73ms +step:1042/1670 train_time:96621ms step_avg:92.73ms +step:1043/1670 train_time:96714ms step_avg:92.73ms +step:1044/1670 train_time:96807ms step_avg:92.73ms +step:1045/1670 train_time:96901ms step_avg:92.73ms +step:1046/1670 train_time:96994ms step_avg:92.73ms +step:1047/1670 train_time:97085ms step_avg:92.73ms +step:1048/1670 train_time:97178ms step_avg:92.73ms +step:1049/1670 train_time:97270ms step_avg:92.73ms +step:1050/1670 train_time:97363ms step_avg:92.73ms +step:1051/1670 train_time:97457ms step_avg:92.73ms +step:1052/1670 train_time:97548ms step_avg:92.73ms +step:1053/1670 train_time:97641ms step_avg:92.73ms +step:1054/1670 train_time:97733ms step_avg:92.73ms +step:1055/1670 train_time:97826ms step_avg:92.73ms +step:1056/1670 train_time:97919ms step_avg:92.73ms +step:1057/1670 train_time:98011ms step_avg:92.73ms +step:1058/1670 train_time:98104ms 
step_avg:92.73ms +step:1059/1670 train_time:98196ms step_avg:92.73ms +step:1060/1670 train_time:98288ms step_avg:92.72ms +step:1061/1670 train_time:98381ms step_avg:92.72ms +step:1062/1670 train_time:98630ms step_avg:92.87ms +step:1063/1670 train_time:98701ms step_avg:92.85ms +step:1064/1670 train_time:98791ms step_avg:92.85ms +step:1065/1670 train_time:98883ms step_avg:92.85ms +step:1066/1670 train_time:98974ms step_avg:92.85ms +step:1067/1670 train_time:99065ms step_avg:92.84ms +step:1068/1670 train_time:99156ms step_avg:92.84ms +step:1069/1670 train_time:99248ms step_avg:92.84ms +step:1070/1670 train_time:99339ms step_avg:92.84ms +step:1071/1670 train_time:99430ms step_avg:92.84ms +step:1072/1670 train_time:99526ms step_avg:92.84ms +step:1073/1670 train_time:99623ms step_avg:92.85ms +step:1074/1670 train_time:99717ms step_avg:92.85ms +step:1075/1670 train_time:99809ms step_avg:92.85ms +step:1076/1670 train_time:99901ms step_avg:92.85ms +step:1077/1670 train_time:99993ms step_avg:92.84ms +step:1078/1670 train_time:100084ms step_avg:92.84ms +step:1079/1670 train_time:100176ms step_avg:92.84ms +step:1080/1670 train_time:100268ms step_avg:92.84ms +step:1081/1670 train_time:100360ms step_avg:92.84ms +step:1082/1670 train_time:100452ms step_avg:92.84ms +step:1083/1670 train_time:100546ms step_avg:92.84ms +step:1084/1670 train_time:100642ms step_avg:92.84ms +step:1085/1670 train_time:100735ms step_avg:92.84ms +step:1086/1670 train_time:100827ms step_avg:92.84ms +step:1087/1670 train_time:100918ms step_avg:92.84ms +step:1088/1670 train_time:101010ms step_avg:92.84ms +step:1089/1670 train_time:101102ms step_avg:92.84ms +step:1090/1670 train_time:101194ms step_avg:92.84ms +step:1091/1670 train_time:101285ms step_avg:92.84ms +step:1092/1670 train_time:101378ms step_avg:92.84ms +step:1093/1670 train_time:101470ms step_avg:92.84ms +step:1094/1670 train_time:101565ms step_avg:92.84ms +step:1095/1670 train_time:101659ms step_avg:92.84ms +step:1096/1670 train_time:101752ms step_avg:92.84ms +step:1097/1670 train_time:101843ms step_avg:92.84ms +step:1098/1670 train_time:101936ms step_avg:92.84ms +step:1099/1670 train_time:102027ms step_avg:92.84ms +step:1100/1670 train_time:102119ms step_avg:92.84ms +step:1101/1670 train_time:102211ms step_avg:92.84ms +step:1102/1670 train_time:102304ms step_avg:92.83ms +step:1103/1670 train_time:102396ms step_avg:92.83ms +step:1104/1670 train_time:102488ms step_avg:92.83ms +step:1105/1670 train_time:102581ms step_avg:92.83ms +step:1106/1670 train_time:102675ms step_avg:92.83ms +step:1107/1670 train_time:102768ms step_avg:92.83ms +step:1108/1670 train_time:102862ms step_avg:92.84ms +step:1109/1670 train_time:102954ms step_avg:92.84ms +step:1110/1670 train_time:103045ms step_avg:92.83ms +step:1111/1670 train_time:103137ms step_avg:92.83ms +step:1112/1670 train_time:103229ms step_avg:92.83ms +step:1113/1670 train_time:103321ms step_avg:92.83ms +step:1114/1670 train_time:103413ms step_avg:92.83ms +step:1115/1670 train_time:103680ms step_avg:92.99ms +step:1116/1670 train_time:103772ms step_avg:92.99ms +step:1117/1670 train_time:103863ms step_avg:92.98ms +step:1118/1670 train_time:103955ms step_avg:92.98ms +step:1119/1670 train_time:104047ms step_avg:92.98ms +step:1120/1670 train_time:104139ms step_avg:92.98ms +step:1121/1670 train_time:104231ms step_avg:92.98ms +step:1122/1670 train_time:104322ms step_avg:92.98ms +step:1123/1670 train_time:104414ms step_avg:92.98ms +step:1124/1670 train_time:104506ms step_avg:92.98ms +step:1125/1670 train_time:104603ms step_avg:92.98ms 
+step:1125/1670 val_loss:3.4163 train_time:104704ms step_avg:93.07ms +step:1126/1670 train_time:104724ms step_avg:93.01ms +step:1127/1670 train_time:104804ms step_avg:92.99ms +step:1128/1670 train_time:104902ms step_avg:93.00ms +step:1129/1670 train_time:104996ms step_avg:93.00ms +step:1130/1670 train_time:105088ms step_avg:93.00ms +step:1131/1670 train_time:105180ms step_avg:93.00ms +step:1132/1670 train_time:105272ms step_avg:93.00ms +step:1133/1670 train_time:105364ms step_avg:93.00ms +step:1134/1670 train_time:105457ms step_avg:93.00ms +step:1135/1670 train_time:105550ms step_avg:93.00ms +step:1136/1670 train_time:105644ms step_avg:93.00ms +step:1137/1670 train_time:105741ms step_avg:93.00ms +step:1138/1670 train_time:105838ms step_avg:93.00ms +step:1139/1670 train_time:105931ms step_avg:93.00ms +step:1140/1670 train_time:106025ms step_avg:93.00ms +step:1141/1670 train_time:106117ms step_avg:93.00ms +step:1142/1670 train_time:106210ms step_avg:93.00ms +step:1143/1670 train_time:106301ms step_avg:93.00ms +step:1144/1670 train_time:106393ms step_avg:93.00ms +step:1145/1670 train_time:106485ms step_avg:93.00ms +step:1146/1670 train_time:106580ms step_avg:93.00ms +step:1147/1670 train_time:106674ms step_avg:93.00ms +step:1148/1670 train_time:106768ms step_avg:93.00ms +step:1149/1670 train_time:106864ms step_avg:93.01ms +step:1150/1670 train_time:106960ms step_avg:93.01ms +step:1151/1670 train_time:107053ms step_avg:93.01ms +step:1152/1670 train_time:107145ms step_avg:93.01ms +step:1153/1670 train_time:107237ms step_avg:93.01ms +step:1154/1670 train_time:107330ms step_avg:93.01ms +step:1155/1670 train_time:107423ms step_avg:93.01ms +step:1156/1670 train_time:107517ms step_avg:93.01ms +step:1157/1670 train_time:107610ms step_avg:93.01ms +step:1158/1670 train_time:107703ms step_avg:93.01ms +step:1159/1670 train_time:107797ms step_avg:93.01ms +step:1160/1670 train_time:107892ms step_avg:93.01ms +step:1161/1670 train_time:107985ms step_avg:93.01ms +step:1162/1670 train_time:108079ms step_avg:93.01ms +step:1163/1670 train_time:108172ms step_avg:93.01ms +step:1164/1670 train_time:108264ms step_avg:93.01ms +step:1165/1670 train_time:108358ms step_avg:93.01ms +step:1166/1670 train_time:108451ms step_avg:93.01ms +step:1167/1670 train_time:108543ms step_avg:93.01ms +step:1168/1670 train_time:108638ms step_avg:93.01ms +step:1169/1670 train_time:108732ms step_avg:93.01ms +step:1170/1670 train_time:108825ms step_avg:93.01ms +step:1171/1670 train_time:108920ms step_avg:93.01ms +step:1172/1670 train_time:109013ms step_avg:93.01ms +step:1173/1670 train_time:109106ms step_avg:93.01ms +step:1174/1670 train_time:109198ms step_avg:93.01ms +step:1175/1670 train_time:109292ms step_avg:93.01ms +step:1176/1670 train_time:109385ms step_avg:93.01ms +step:1177/1670 train_time:109479ms step_avg:93.01ms +step:1178/1670 train_time:109571ms step_avg:93.01ms +step:1179/1670 train_time:109665ms step_avg:93.01ms +step:1180/1670 train_time:109759ms step_avg:93.02ms +step:1181/1670 train_time:109853ms step_avg:93.02ms +step:1182/1670 train_time:109946ms step_avg:93.02ms +step:1183/1670 train_time:110040ms step_avg:93.02ms +step:1184/1670 train_time:110133ms step_avg:93.02ms +step:1185/1670 train_time:110225ms step_avg:93.02ms +step:1186/1670 train_time:110318ms step_avg:93.02ms +step:1187/1670 train_time:110410ms step_avg:93.02ms +step:1188/1670 train_time:110503ms step_avg:93.02ms +step:1189/1670 train_time:110599ms step_avg:93.02ms +step:1190/1670 train_time:110690ms step_avg:93.02ms +step:1191/1670 train_time:110784ms 
step_avg:93.02ms +step:1192/1670 train_time:110877ms step_avg:93.02ms +step:1193/1670 train_time:110970ms step_avg:93.02ms +step:1194/1670 train_time:111063ms step_avg:93.02ms +step:1195/1670 train_time:111156ms step_avg:93.02ms +step:1196/1670 train_time:111248ms step_avg:93.02ms +step:1197/1670 train_time:111341ms step_avg:93.02ms +step:1198/1670 train_time:111434ms step_avg:93.02ms +step:1199/1670 train_time:111526ms step_avg:93.02ms +step:1200/1670 train_time:111619ms step_avg:93.02ms +step:1201/1670 train_time:111713ms step_avg:93.02ms +step:1202/1670 train_time:111807ms step_avg:93.02ms +step:1203/1670 train_time:111901ms step_avg:93.02ms +step:1204/1670 train_time:111995ms step_avg:93.02ms +step:1205/1670 train_time:112088ms step_avg:93.02ms +step:1206/1670 train_time:112181ms step_avg:93.02ms +step:1207/1670 train_time:112275ms step_avg:93.02ms +step:1208/1670 train_time:112368ms step_avg:93.02ms +step:1209/1670 train_time:112462ms step_avg:93.02ms +step:1210/1670 train_time:112556ms step_avg:93.02ms +step:1211/1670 train_time:112649ms step_avg:93.02ms +step:1212/1670 train_time:112742ms step_avg:93.02ms +step:1213/1670 train_time:112835ms step_avg:93.02ms +step:1214/1670 train_time:112928ms step_avg:93.02ms +step:1215/1670 train_time:113022ms step_avg:93.02ms +step:1216/1670 train_time:113115ms step_avg:93.02ms +step:1217/1670 train_time:113208ms step_avg:93.02ms +step:1218/1670 train_time:113301ms step_avg:93.02ms +step:1219/1670 train_time:113394ms step_avg:93.02ms +step:1220/1670 train_time:113487ms step_avg:93.02ms +step:1221/1670 train_time:113581ms step_avg:93.02ms +step:1222/1670 train_time:113675ms step_avg:93.02ms +step:1223/1670 train_time:113767ms step_avg:93.02ms +step:1224/1670 train_time:113861ms step_avg:93.02ms +step:1225/1670 train_time:113954ms step_avg:93.02ms +step:1226/1670 train_time:114047ms step_avg:93.02ms +step:1227/1670 train_time:114141ms step_avg:93.02ms +step:1228/1670 train_time:114234ms step_avg:93.02ms +step:1229/1670 train_time:114327ms step_avg:93.02ms +step:1230/1670 train_time:114422ms step_avg:93.03ms +step:1231/1670 train_time:114513ms step_avg:93.02ms +step:1232/1670 train_time:114606ms step_avg:93.02ms +step:1233/1670 train_time:114700ms step_avg:93.02ms +step:1234/1670 train_time:114793ms step_avg:93.03ms +step:1235/1670 train_time:114886ms step_avg:93.02ms +step:1236/1670 train_time:114979ms step_avg:93.03ms +step:1237/1670 train_time:115072ms step_avg:93.02ms +step:1238/1670 train_time:115165ms step_avg:93.02ms +step:1239/1670 train_time:115258ms step_avg:93.03ms +step:1240/1670 train_time:115353ms step_avg:93.03ms +step:1241/1670 train_time:115446ms step_avg:93.03ms +step:1242/1670 train_time:115538ms step_avg:93.03ms +step:1243/1670 train_time:115632ms step_avg:93.03ms +step:1244/1670 train_time:115725ms step_avg:93.03ms +step:1245/1670 train_time:115820ms step_avg:93.03ms +step:1246/1670 train_time:115913ms step_avg:93.03ms +step:1247/1670 train_time:116005ms step_avg:93.03ms +step:1248/1670 train_time:116098ms step_avg:93.03ms +step:1249/1670 train_time:116191ms step_avg:93.03ms +step:1250/1670 train_time:116284ms step_avg:93.03ms +step:1250/1670 val_loss:3.3780 train_time:116376ms step_avg:93.10ms +step:1251/1670 train_time:116396ms step_avg:93.04ms +step:1252/1670 train_time:116471ms step_avg:93.03ms +step:1253/1670 train_time:116564ms step_avg:93.03ms +step:1254/1670 train_time:116656ms step_avg:93.03ms +step:1255/1670 train_time:116748ms step_avg:93.03ms +step:1256/1670 train_time:116841ms step_avg:93.03ms +step:1257/1670 
train_time:116933ms step_avg:93.03ms +step:1258/1670 train_time:117027ms step_avg:93.03ms +step:1259/1670 train_time:117120ms step_avg:93.03ms +step:1260/1670 train_time:117213ms step_avg:93.03ms +step:1261/1670 train_time:117308ms step_avg:93.03ms +step:1262/1670 train_time:117402ms step_avg:93.03ms +step:1263/1670 train_time:117496ms step_avg:93.03ms +step:1264/1670 train_time:117588ms step_avg:93.03ms +step:1265/1670 train_time:117681ms step_avg:93.03ms +step:1266/1670 train_time:117773ms step_avg:93.03ms +step:1267/1670 train_time:117868ms step_avg:93.03ms +step:1268/1670 train_time:117961ms step_avg:93.03ms +step:1269/1670 train_time:118054ms step_avg:93.03ms +step:1270/1670 train_time:118146ms step_avg:93.03ms +step:1271/1670 train_time:118240ms step_avg:93.03ms +step:1272/1670 train_time:118333ms step_avg:93.03ms +step:1273/1670 train_time:118429ms step_avg:93.03ms +step:1274/1670 train_time:118667ms step_avg:93.15ms +step:1275/1670 train_time:118748ms step_avg:93.14ms +step:1276/1670 train_time:118840ms step_avg:93.13ms +step:1277/1670 train_time:118932ms step_avg:93.13ms +step:1278/1670 train_time:119024ms step_avg:93.13ms +step:1279/1670 train_time:119115ms step_avg:93.13ms +step:1280/1670 train_time:119208ms step_avg:93.13ms +step:1281/1670 train_time:119299ms step_avg:93.13ms +step:1282/1670 train_time:119391ms step_avg:93.13ms +step:1283/1670 train_time:119483ms step_avg:93.13ms +step:1284/1670 train_time:119581ms step_avg:93.13ms +step:1285/1670 train_time:119680ms step_avg:93.14ms +step:1286/1670 train_time:119774ms step_avg:93.14ms +step:1287/1670 train_time:119867ms step_avg:93.14ms +step:1288/1670 train_time:119959ms step_avg:93.14ms +step:1289/1670 train_time:120051ms step_avg:93.13ms +step:1290/1670 train_time:120143ms step_avg:93.13ms +step:1291/1670 train_time:120235ms step_avg:93.13ms +step:1292/1670 train_time:120327ms step_avg:93.13ms +step:1293/1670 train_time:120419ms step_avg:93.13ms +step:1294/1670 train_time:120512ms step_avg:93.13ms +step:1295/1670 train_time:120609ms step_avg:93.13ms +step:1296/1670 train_time:120704ms step_avg:93.14ms +step:1297/1670 train_time:120798ms step_avg:93.14ms +step:1298/1670 train_time:120891ms step_avg:93.14ms +step:1299/1670 train_time:120983ms step_avg:93.14ms +step:1300/1670 train_time:121076ms step_avg:93.14ms +step:1301/1670 train_time:121169ms step_avg:93.14ms +step:1302/1670 train_time:121261ms step_avg:93.13ms +step:1303/1670 train_time:121353ms step_avg:93.13ms +step:1304/1670 train_time:121446ms step_avg:93.13ms +step:1305/1670 train_time:121540ms step_avg:93.13ms +step:1306/1670 train_time:121634ms step_avg:93.13ms +step:1307/1670 train_time:121729ms step_avg:93.14ms +step:1308/1670 train_time:121823ms step_avg:93.14ms +step:1309/1670 train_time:121916ms step_avg:93.14ms +step:1310/1670 train_time:122009ms step_avg:93.14ms +step:1311/1670 train_time:122102ms step_avg:93.14ms +step:1312/1670 train_time:122194ms step_avg:93.14ms +step:1313/1670 train_time:122287ms step_avg:93.14ms +step:1314/1670 train_time:122380ms step_avg:93.14ms +step:1315/1670 train_time:122473ms step_avg:93.14ms +step:1316/1670 train_time:122568ms step_avg:93.14ms +step:1317/1670 train_time:122662ms step_avg:93.14ms +step:1318/1670 train_time:122756ms step_avg:93.14ms +step:1319/1670 train_time:122849ms step_avg:93.14ms +step:1320/1670 train_time:122942ms step_avg:93.14ms +step:1321/1670 train_time:123035ms step_avg:93.14ms +step:1322/1670 train_time:123129ms step_avg:93.14ms +step:1323/1670 train_time:123221ms step_avg:93.14ms +step:1324/1670 
train_time:123313ms step_avg:93.14ms +step:1325/1670 train_time:123408ms step_avg:93.14ms +step:1326/1670 train_time:123502ms step_avg:93.14ms +step:1327/1670 train_time:123594ms step_avg:93.14ms +step:1328/1670 train_time:123688ms step_avg:93.14ms +step:1329/1670 train_time:123781ms step_avg:93.14ms +step:1330/1670 train_time:123874ms step_avg:93.14ms +step:1331/1670 train_time:123967ms step_avg:93.14ms +step:1332/1670 train_time:124060ms step_avg:93.14ms +step:1333/1670 train_time:124154ms step_avg:93.14ms +step:1334/1670 train_time:124246ms step_avg:93.14ms +step:1335/1670 train_time:124340ms step_avg:93.14ms +step:1336/1670 train_time:124434ms step_avg:93.14ms +step:1337/1670 train_time:124528ms step_avg:93.14ms +step:1338/1670 train_time:124621ms step_avg:93.14ms +step:1339/1670 train_time:124714ms step_avg:93.14ms +step:1340/1670 train_time:124808ms step_avg:93.14ms +step:1341/1670 train_time:124901ms step_avg:93.14ms +step:1342/1670 train_time:124993ms step_avg:93.14ms +step:1343/1670 train_time:125087ms step_avg:93.14ms +step:1344/1670 train_time:125180ms step_avg:93.14ms +step:1345/1670 train_time:125273ms step_avg:93.14ms +step:1346/1670 train_time:125366ms step_avg:93.14ms +step:1347/1670 train_time:125459ms step_avg:93.14ms +step:1348/1670 train_time:125553ms step_avg:93.14ms +step:1349/1670 train_time:125647ms step_avg:93.14ms +step:1350/1670 train_time:125740ms step_avg:93.14ms +step:1351/1670 train_time:125834ms step_avg:93.14ms +step:1352/1670 train_time:125928ms step_avg:93.14ms +step:1353/1670 train_time:126021ms step_avg:93.14ms +step:1354/1670 train_time:126113ms step_avg:93.14ms +step:1355/1670 train_time:126207ms step_avg:93.14ms +step:1356/1670 train_time:126300ms step_avg:93.14ms +step:1357/1670 train_time:126393ms step_avg:93.14ms +step:1358/1670 train_time:126486ms step_avg:93.14ms +step:1359/1670 train_time:126578ms step_avg:93.14ms +step:1360/1670 train_time:126672ms step_avg:93.14ms +step:1361/1670 train_time:126765ms step_avg:93.14ms +step:1362/1670 train_time:126858ms step_avg:93.14ms +step:1363/1670 train_time:126951ms step_avg:93.14ms +step:1364/1670 train_time:127048ms step_avg:93.14ms +step:1365/1670 train_time:127139ms step_avg:93.14ms +step:1366/1670 train_time:127232ms step_avg:93.14ms +step:1367/1670 train_time:127326ms step_avg:93.14ms +step:1368/1670 train_time:127418ms step_avg:93.14ms +step:1369/1670 train_time:127511ms step_avg:93.14ms +step:1370/1670 train_time:127604ms step_avg:93.14ms +step:1371/1670 train_time:127697ms step_avg:93.14ms +step:1372/1670 train_time:127791ms step_avg:93.14ms +step:1373/1670 train_time:127886ms step_avg:93.14ms +step:1374/1670 train_time:127979ms step_avg:93.14ms +step:1375/1670 train_time:128073ms step_avg:93.14ms +step:1375/1670 val_loss:3.3443 train_time:128166ms step_avg:93.21ms +step:1376/1670 train_time:128185ms step_avg:93.16ms +step:1377/1670 train_time:128261ms step_avg:93.15ms +step:1378/1670 train_time:128354ms step_avg:93.15ms +step:1379/1670 train_time:128447ms step_avg:93.15ms +step:1380/1670 train_time:128540ms step_avg:93.14ms +step:1381/1670 train_time:128632ms step_avg:93.14ms +step:1382/1670 train_time:128724ms step_avg:93.14ms +step:1383/1670 train_time:128819ms step_avg:93.14ms +step:1384/1670 train_time:128913ms step_avg:93.15ms +step:1385/1670 train_time:129005ms step_avg:93.14ms +step:1386/1670 train_time:129099ms step_avg:93.15ms +step:1387/1670 train_time:129194ms step_avg:93.15ms +step:1388/1670 train_time:129287ms step_avg:93.15ms +step:1389/1670 train_time:129380ms step_avg:93.15ms 
+step:1390/1670 train_time:129473ms step_avg:93.15ms +step:1391/1670 train_time:129565ms step_avg:93.15ms +step:1392/1670 train_time:129658ms step_avg:93.15ms +step:1393/1670 train_time:129752ms step_avg:93.15ms +step:1394/1670 train_time:129844ms step_avg:93.15ms +step:1395/1670 train_time:129938ms step_avg:93.15ms +step:1396/1670 train_time:130032ms step_avg:93.15ms +step:1397/1670 train_time:130125ms step_avg:93.15ms +step:1398/1670 train_time:130220ms step_avg:93.15ms +step:1399/1670 train_time:130316ms step_avg:93.15ms +step:1400/1670 train_time:130410ms step_avg:93.15ms +step:1401/1670 train_time:130503ms step_avg:93.15ms +step:1402/1670 train_time:130596ms step_avg:93.15ms +step:1403/1670 train_time:130690ms step_avg:93.15ms +step:1404/1670 train_time:130783ms step_avg:93.15ms +step:1405/1670 train_time:130876ms step_avg:93.15ms +step:1406/1670 train_time:130970ms step_avg:93.15ms +step:1407/1670 train_time:131063ms step_avg:93.15ms +step:1408/1670 train_time:131157ms step_avg:93.15ms +step:1409/1670 train_time:131250ms step_avg:93.15ms +step:1410/1670 train_time:131344ms step_avg:93.15ms +step:1411/1670 train_time:131437ms step_avg:93.15ms +step:1412/1670 train_time:131529ms step_avg:93.15ms +step:1413/1670 train_time:131623ms step_avg:93.15ms +step:1414/1670 train_time:131717ms step_avg:93.15ms +step:1415/1670 train_time:131810ms step_avg:93.15ms +step:1416/1670 train_time:131903ms step_avg:93.15ms +step:1417/1670 train_time:131996ms step_avg:93.15ms +step:1418/1670 train_time:132089ms step_avg:93.15ms +step:1419/1670 train_time:132183ms step_avg:93.15ms +step:1420/1670 train_time:132276ms step_avg:93.15ms +step:1421/1670 train_time:132370ms step_avg:93.15ms +step:1422/1670 train_time:132463ms step_avg:93.15ms +step:1423/1670 train_time:132557ms step_avg:93.15ms +step:1424/1670 train_time:132650ms step_avg:93.15ms +step:1425/1670 train_time:132743ms step_avg:93.15ms +step:1426/1670 train_time:132836ms step_avg:93.15ms +step:1427/1670 train_time:132929ms step_avg:93.15ms +step:1428/1670 train_time:133023ms step_avg:93.15ms +step:1429/1670 train_time:133117ms step_avg:93.15ms +step:1430/1670 train_time:133210ms step_avg:93.15ms +step:1431/1670 train_time:133302ms step_avg:93.15ms +step:1432/1670 train_time:133395ms step_avg:93.15ms +step:1433/1670 train_time:133489ms step_avg:93.15ms +step:1434/1670 train_time:133582ms step_avg:93.15ms +step:1435/1670 train_time:133675ms step_avg:93.15ms +step:1436/1670 train_time:133770ms step_avg:93.15ms +step:1437/1670 train_time:133861ms step_avg:93.15ms +step:1438/1670 train_time:133955ms step_avg:93.15ms +step:1439/1670 train_time:134048ms step_avg:93.15ms +step:1440/1670 train_time:134141ms step_avg:93.15ms +step:1441/1670 train_time:134234ms step_avg:93.15ms +step:1442/1670 train_time:134326ms step_avg:93.15ms +step:1443/1670 train_time:134420ms step_avg:93.15ms +step:1444/1670 train_time:134513ms step_avg:93.15ms +step:1445/1670 train_time:134606ms step_avg:93.15ms +step:1446/1670 train_time:134699ms step_avg:93.15ms +step:1447/1670 train_time:134793ms step_avg:93.15ms +step:1448/1670 train_time:134885ms step_avg:93.15ms +step:1449/1670 train_time:134980ms step_avg:93.15ms +step:1450/1670 train_time:135074ms step_avg:93.15ms +step:1451/1670 train_time:135167ms step_avg:93.15ms +step:1452/1670 train_time:135261ms step_avg:93.15ms +step:1453/1670 train_time:135354ms step_avg:93.16ms +step:1454/1670 train_time:135448ms step_avg:93.16ms +step:1455/1670 train_time:135542ms step_avg:93.16ms +step:1456/1670 train_time:135634ms step_avg:93.16ms 
+step:1457/1670 train_time:135727ms step_avg:93.16ms +step:1458/1670 train_time:135821ms step_avg:93.16ms +step:1459/1670 train_time:135914ms step_avg:93.16ms +step:1460/1670 train_time:136007ms step_avg:93.16ms +step:1461/1670 train_time:136101ms step_avg:93.16ms +step:1462/1670 train_time:136195ms step_avg:93.16ms +step:1463/1670 train_time:136288ms step_avg:93.16ms +step:1464/1670 train_time:136382ms step_avg:93.16ms +step:1465/1670 train_time:136475ms step_avg:93.16ms +step:1466/1670 train_time:136568ms step_avg:93.16ms +step:1467/1670 train_time:136663ms step_avg:93.16ms +step:1468/1670 train_time:136756ms step_avg:93.16ms +step:1469/1670 train_time:136849ms step_avg:93.16ms +step:1470/1670 train_time:136942ms step_avg:93.16ms +step:1471/1670 train_time:137034ms step_avg:93.16ms +step:1472/1670 train_time:137128ms step_avg:93.16ms +step:1473/1670 train_time:137221ms step_avg:93.16ms +step:1474/1670 train_time:137314ms step_avg:93.16ms +step:1475/1670 train_time:137407ms step_avg:93.16ms +step:1476/1670 train_time:137501ms step_avg:93.16ms +step:1477/1670 train_time:137595ms step_avg:93.16ms +step:1478/1670 train_time:137688ms step_avg:93.16ms +step:1479/1670 train_time:137782ms step_avg:93.16ms +step:1480/1670 train_time:137876ms step_avg:93.16ms +step:1481/1670 train_time:137968ms step_avg:93.16ms +step:1482/1670 train_time:138063ms step_avg:93.16ms +step:1483/1670 train_time:138157ms step_avg:93.16ms +step:1484/1670 train_time:138250ms step_avg:93.16ms +step:1485/1670 train_time:138500ms step_avg:93.27ms +step:1486/1670 train_time:138571ms step_avg:93.25ms +step:1487/1670 train_time:138663ms step_avg:93.25ms +step:1488/1670 train_time:138754ms step_avg:93.25ms +step:1489/1670 train_time:138847ms step_avg:93.25ms +step:1490/1670 train_time:138939ms step_avg:93.25ms +step:1491/1670 train_time:139030ms step_avg:93.25ms +step:1492/1670 train_time:139122ms step_avg:93.25ms +step:1493/1670 train_time:139215ms step_avg:93.25ms +step:1494/1670 train_time:139307ms step_avg:93.24ms +step:1495/1670 train_time:139405ms step_avg:93.25ms +step:1496/1670 train_time:139503ms step_avg:93.25ms +step:1497/1670 train_time:139597ms step_avg:93.25ms +step:1498/1670 train_time:139692ms step_avg:93.25ms +step:1499/1670 train_time:139784ms step_avg:93.25ms +step:1500/1670 train_time:139876ms step_avg:93.25ms +step:1500/1670 val_loss:3.3143 train_time:139969ms step_avg:93.31ms +step:1501/1670 train_time:139988ms step_avg:93.26ms +step:1502/1670 train_time:140063ms step_avg:93.25ms +step:1503/1670 train_time:140156ms step_avg:93.25ms +step:1504/1670 train_time:140248ms step_avg:93.25ms +step:1505/1670 train_time:140341ms step_avg:93.25ms +step:1506/1670 train_time:140433ms step_avg:93.25ms +step:1507/1670 train_time:140526ms step_avg:93.25ms +step:1508/1670 train_time:140618ms step_avg:93.25ms +step:1509/1670 train_time:140711ms step_avg:93.25ms +step:1510/1670 train_time:140805ms step_avg:93.25ms +step:1511/1670 train_time:140899ms step_avg:93.25ms +step:1512/1670 train_time:140993ms step_avg:93.25ms +step:1513/1670 train_time:141087ms step_avg:93.25ms +step:1514/1670 train_time:141181ms step_avg:93.25ms +step:1515/1670 train_time:141273ms step_avg:93.25ms +step:1516/1670 train_time:141367ms step_avg:93.25ms +step:1517/1670 train_time:141461ms step_avg:93.25ms +step:1518/1670 train_time:141553ms step_avg:93.25ms +step:1519/1670 train_time:141646ms step_avg:93.25ms +step:1520/1670 train_time:141739ms step_avg:93.25ms +step:1521/1670 train_time:141833ms step_avg:93.25ms +step:1522/1670 train_time:141927ms 
step_avg:93.25ms +step:1523/1670 train_time:142021ms step_avg:93.25ms +step:1524/1670 train_time:142113ms step_avg:93.25ms +step:1525/1670 train_time:142208ms step_avg:93.25ms +step:1526/1670 train_time:142302ms step_avg:93.25ms +step:1527/1670 train_time:142395ms step_avg:93.25ms +step:1528/1670 train_time:142488ms step_avg:93.25ms +step:1529/1670 train_time:142580ms step_avg:93.25ms +step:1530/1670 train_time:142673ms step_avg:93.25ms +step:1531/1670 train_time:142768ms step_avg:93.25ms +step:1532/1670 train_time:142861ms step_avg:93.25ms +step:1533/1670 train_time:142955ms step_avg:93.25ms +step:1534/1670 train_time:143048ms step_avg:93.25ms +step:1535/1670 train_time:143141ms step_avg:93.25ms +step:1536/1670 train_time:143234ms step_avg:93.25ms +step:1537/1670 train_time:143327ms step_avg:93.25ms +step:1538/1670 train_time:143420ms step_avg:93.25ms +step:1539/1670 train_time:143513ms step_avg:93.25ms +step:1540/1670 train_time:143607ms step_avg:93.25ms +step:1541/1670 train_time:143700ms step_avg:93.25ms +step:1542/1670 train_time:143794ms step_avg:93.25ms +step:1543/1670 train_time:143887ms step_avg:93.25ms +step:1544/1670 train_time:143980ms step_avg:93.25ms +step:1545/1670 train_time:144074ms step_avg:93.25ms +step:1546/1670 train_time:144167ms step_avg:93.25ms +step:1547/1670 train_time:144260ms step_avg:93.25ms +step:1548/1670 train_time:144354ms step_avg:93.25ms +step:1549/1670 train_time:144446ms step_avg:93.25ms +step:1550/1670 train_time:144539ms step_avg:93.25ms +step:1551/1670 train_time:144632ms step_avg:93.25ms +step:1552/1670 train_time:144728ms step_avg:93.25ms +step:1553/1670 train_time:144822ms step_avg:93.25ms +step:1554/1670 train_time:144914ms step_avg:93.25ms +step:1555/1670 train_time:145009ms step_avg:93.25ms +step:1556/1670 train_time:145102ms step_avg:93.25ms +step:1557/1670 train_time:145196ms step_avg:93.25ms +step:1558/1670 train_time:145290ms step_avg:93.25ms +step:1559/1670 train_time:145382ms step_avg:93.25ms +step:1560/1670 train_time:145475ms step_avg:93.25ms +step:1561/1670 train_time:145569ms step_avg:93.25ms +step:1562/1670 train_time:145663ms step_avg:93.25ms +step:1563/1670 train_time:145756ms step_avg:93.25ms +step:1564/1670 train_time:145849ms step_avg:93.25ms +step:1565/1670 train_time:145942ms step_avg:93.25ms +step:1566/1670 train_time:146034ms step_avg:93.25ms +step:1567/1670 train_time:146128ms step_avg:93.25ms +step:1568/1670 train_time:146221ms step_avg:93.25ms +step:1569/1670 train_time:146314ms step_avg:93.25ms +step:1570/1670 train_time:146407ms step_avg:93.25ms +step:1571/1670 train_time:146500ms step_avg:93.25ms +step:1572/1670 train_time:146593ms step_avg:93.25ms +step:1573/1670 train_time:146687ms step_avg:93.25ms +step:1574/1670 train_time:146780ms step_avg:93.25ms +step:1575/1670 train_time:146873ms step_avg:93.25ms +step:1576/1670 train_time:146966ms step_avg:93.25ms +step:1577/1670 train_time:147059ms step_avg:93.25ms +step:1578/1670 train_time:147152ms step_avg:93.25ms +step:1579/1670 train_time:147245ms step_avg:93.25ms +step:1580/1670 train_time:147338ms step_avg:93.25ms +step:1581/1670 train_time:147431ms step_avg:93.25ms +step:1582/1670 train_time:147525ms step_avg:93.25ms +step:1583/1670 train_time:147619ms step_avg:93.25ms +step:1584/1670 train_time:147712ms step_avg:93.25ms +step:1585/1670 train_time:147806ms step_avg:93.25ms +step:1586/1670 train_time:147900ms step_avg:93.25ms +step:1587/1670 train_time:147992ms step_avg:93.25ms +step:1588/1670 train_time:148085ms step_avg:93.25ms +step:1589/1670 train_time:148178ms 
step_avg:93.25ms +step:1590/1670 train_time:148271ms step_avg:93.25ms +step:1591/1670 train_time:148364ms step_avg:93.25ms +step:1592/1670 train_time:148457ms step_avg:93.25ms +step:1593/1670 train_time:148550ms step_avg:93.25ms +step:1594/1670 train_time:148643ms step_avg:93.25ms +step:1595/1670 train_time:148736ms step_avg:93.25ms +step:1596/1670 train_time:148830ms step_avg:93.25ms +step:1597/1670 train_time:148923ms step_avg:93.25ms +step:1598/1670 train_time:149016ms step_avg:93.25ms +step:1599/1670 train_time:149110ms step_avg:93.25ms +step:1600/1670 train_time:149204ms step_avg:93.25ms +step:1601/1670 train_time:149297ms step_avg:93.25ms +step:1602/1670 train_time:149389ms step_avg:93.25ms +step:1603/1670 train_time:149482ms step_avg:93.25ms +step:1604/1670 train_time:149575ms step_avg:93.25ms +step:1605/1670 train_time:149669ms step_avg:93.25ms +step:1606/1670 train_time:149763ms step_avg:93.25ms +step:1607/1670 train_time:149856ms step_avg:93.25ms +step:1608/1670 train_time:149949ms step_avg:93.25ms +step:1609/1670 train_time:150042ms step_avg:93.25ms +step:1610/1670 train_time:150136ms step_avg:93.25ms +step:1611/1670 train_time:150229ms step_avg:93.25ms +step:1612/1670 train_time:150322ms step_avg:93.25ms +step:1613/1670 train_time:150414ms step_avg:93.25ms +step:1614/1670 train_time:150509ms step_avg:93.25ms +step:1615/1670 train_time:150603ms step_avg:93.25ms +step:1616/1670 train_time:150696ms step_avg:93.25ms +step:1617/1670 train_time:150789ms step_avg:93.25ms +step:1618/1670 train_time:150882ms step_avg:93.25ms +step:1619/1670 train_time:150975ms step_avg:93.25ms +step:1620/1670 train_time:151070ms step_avg:93.25ms +step:1621/1670 train_time:151163ms step_avg:93.25ms +step:1622/1670 train_time:151255ms step_avg:93.25ms +step:1623/1670 train_time:151348ms step_avg:93.25ms +step:1624/1670 train_time:151441ms step_avg:93.25ms +step:1625/1670 train_time:151534ms step_avg:93.25ms +step:1625/1670 val_loss:3.2889 train_time:151627ms step_avg:93.31ms +step:1626/1670 train_time:151646ms step_avg:93.26ms +step:1627/1670 train_time:151721ms step_avg:93.25ms +step:1628/1670 train_time:151814ms step_avg:93.25ms +step:1629/1670 train_time:151907ms step_avg:93.25ms +step:1630/1670 train_time:151999ms step_avg:93.25ms +step:1631/1670 train_time:152092ms step_avg:93.25ms +step:1632/1670 train_time:152185ms step_avg:93.25ms +step:1633/1670 train_time:152278ms step_avg:93.25ms +step:1634/1670 train_time:152372ms step_avg:93.25ms +step:1635/1670 train_time:152464ms step_avg:93.25ms +step:1636/1670 train_time:152559ms step_avg:93.25ms +step:1637/1670 train_time:152655ms step_avg:93.25ms +step:1638/1670 train_time:152748ms step_avg:93.25ms +step:1639/1670 train_time:152841ms step_avg:93.25ms +step:1640/1670 train_time:152934ms step_avg:93.25ms +step:1641/1670 train_time:153026ms step_avg:93.25ms +step:1642/1670 train_time:153119ms step_avg:93.25ms +step:1643/1670 train_time:153213ms step_avg:93.25ms +step:1644/1670 train_time:153306ms step_avg:93.25ms +step:1645/1670 train_time:153398ms step_avg:93.25ms +step:1646/1670 train_time:153492ms step_avg:93.25ms +step:1647/1670 train_time:153586ms step_avg:93.25ms +step:1648/1670 train_time:153679ms step_avg:93.25ms +step:1649/1670 train_time:153773ms step_avg:93.25ms +step:1650/1670 train_time:153865ms step_avg:93.25ms +step:1651/1670 train_time:153958ms step_avg:93.25ms +step:1652/1670 train_time:154053ms step_avg:93.25ms +step:1653/1670 train_time:154145ms step_avg:93.25ms +step:1654/1670 train_time:154238ms step_avg:93.25ms +step:1655/1670 
train_time:154332ms step_avg:93.25ms +step:1656/1670 train_time:154425ms step_avg:93.25ms +step:1657/1670 train_time:154519ms step_avg:93.25ms +step:1658/1670 train_time:154613ms step_avg:93.25ms +step:1659/1670 train_time:154706ms step_avg:93.25ms +step:1660/1670 train_time:154799ms step_avg:93.25ms +step:1661/1670 train_time:154891ms step_avg:93.25ms +step:1662/1670 train_time:154985ms step_avg:93.25ms +step:1663/1670 train_time:155078ms step_avg:93.25ms +step:1664/1670 train_time:155171ms step_avg:93.25ms +step:1665/1670 train_time:155264ms step_avg:93.25ms +step:1666/1670 train_time:155358ms step_avg:93.25ms +step:1667/1670 train_time:155451ms step_avg:93.25ms +step:1668/1670 train_time:155544ms step_avg:93.25ms +step:1669/1670 train_time:155639ms step_avg:93.25ms +step:1670/1670 train_time:155732ms step_avg:93.25ms +step:1670/1670 val_loss:3.2805 train_time:155980ms step_avg:93.40ms +peak memory allocated: 32002 MiB reserved: 46616 MiB diff --git a/records/091125_VectSigmoidBFloat16/7129a36e-505d-456b-aed5-ea8e455a0bac.txt b/records/091125_VectSigmoidBFloat16/7129a36e-505d-456b-aed5-ea8e455a0bac.txt new file mode 100644 index 000000000..1060915cc --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/7129a36e-505d-456b-aed5-ea8e455a0bac.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
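+    # at_ptrs walks the same buffer as a_ptrs with the row/column strides
+    # swapped, so the K-loop below streams tiles of both A and A.T without
+    # ever materializing a transposed copy; the % M wrap keeps edge-block
+    # row indices in range (the K masks handle the tail of the K dimension).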
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
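+# A minimal pure-PyTorch sketch of the same quintic Newton-Schulz iteration,
+# added purely as a readable reference for what the three Triton helpers above
+# compute (assumption: it matches newton_schulz_triton up to bfloat16 rounding;
+# the name newton_schulz_reference is illustrative and is never called during
+# training).
+def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # orthogonalize along the short side
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # spectral norm <= 1
+    for _ in range(steps):
+        A = X @ X.mT              # ns_line_1: A = X @ X.T
+        B = b * A + c * (A @ A)   # ns_line_2: B = b*A + c*A@A
+        X = a * X + B @ X         # ns_line_3: X = a*X + B@X
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+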
+# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. 
+ eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. 
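+                # Net effect per parameter: p_new = (1 - lr*wd*wd_mul) * p
+                #   - lr * max(1, rows/cols)**0.5 * lr_mul * NS(momentum-mixed grad);
+                # the decay factor was already applied when the chunk was filled above.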
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
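+        # e.g. next_multiple_of_n(50257, n=128) == 50304 == 393 * 128, so the
+        # padded vocab keeps the lm_head and embedding row counts efficient.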
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(
+            torch.bfloat16
+        )  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[
+            1 * len(self.blocks) : 3 * len(self.blocks)
+        ].view(-1, 2)
+        sa_lambdas = self.scalars[
+            3 * len(self.blocks) : 5 * len(self.blocks)
+        ].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(
+        str(file), False, 256, dtype=torch.int32
+    )  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(
+            num_tokens, dtype=torch.uint16, pin_memory=True
+        )  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(
+            tokens.numpy()
+        )  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, (
+            "number of tokens read does not match header"
+        )
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # report idx here: cur is not yet bound on the first pass through this loop
+                    raise StopIteration(
+                        f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard." 
+                    )
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(
+                    self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                    cur + max_seq_len,
+                    cur + num_tokens_local - cur_len + 1,
+                )
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+
+def distributed_data_generator(
+    filename_pattern: str,
+    num_tokens: int,
+    max_seq_len: int,
+    grad_accum_steps: int = 1,
+    align_to_bos: bool = True,
+):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, (
+        "num_tokens must be divisible by world_size * grad_accum_steps"
+    )
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(
+            f"No files found for pattern: {filename_pattern}"
+        )
+
+    file_iter = iter(
+        files
+    )  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(
+            num_tokens_local // 300, n=128
+        )  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(
+                    num_tokens_local, max_seq_len
+                )
+                start_idxs, end_idxs = (
+                    torch.tensor(seq_starts[rank]),
+                    torch.tensor(seq_ends[rank]),
+                )
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= (
+                1  # the last sequence carried one extra token for the _targets shift; trim it so lengths sum to num_tokens_local
+            )
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(
+                tokens
+            ):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local : pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(
+                num_tokens_local,
+            )
+            _targets = buf[1:].view(
+                num_tokens_local,
+            )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(
+                device="cuda", dtype=torch.int32, non_blocking=True
+            ),
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, (
+                "new num_tokens must be divisible by world_size * grad_accum_steps"
+            )
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = (
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:09:03 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 125W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 132W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 46C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 38C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 41C P0 130W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:299ms step_avg:298.66ms +step:2/1670 train_time:318ms step_avg:158.82ms +step:3/1670 train_time:385ms step_avg:128.48ms +step:4/1670 train_time:475ms step_avg:118.73ms +step:5/1670 train_time:565ms step_avg:113.08ms +step:6/1670 train_time:656ms step_avg:109.35ms +step:7/1670 train_time:746ms step_avg:106.63ms +step:8/1670 train_time:837ms step_avg:104.67ms +step:9/1670 train_time:928ms step_avg:103.08ms +step:10/1670 train_time:1019ms step_avg:101.86ms +step:11/1670 train_time:1109ms step_avg:100.85ms +step:12/1670 train_time:1202ms step_avg:100.19ms +step:13/1670 train_time:1298ms step_avg:99.84ms +step:14/1670 train_time:1390ms step_avg:99.27ms +step:15/1670 train_time:1481ms step_avg:98.76ms +step:16/1670 train_time:1575ms step_avg:98.41ms +step:17/1670 train_time:1666ms step_avg:98.03ms +step:18/1670 train_time:1758ms step_avg:97.64ms +step:19/1670 train_time:1848ms step_avg:97.26ms +step:20/1670 train_time:1939ms step_avg:96.95ms +step:21/1670 train_time:2030ms step_avg:96.65ms +step:22/1670 train_time:2121ms 
step_avg:96.42ms +step:23/1670 train_time:2214ms step_avg:96.25ms +step:24/1670 train_time:2307ms step_avg:96.13ms +step:25/1670 train_time:2399ms step_avg:95.97ms +step:26/1670 train_time:2490ms step_avg:95.79ms +step:27/1670 train_time:2583ms step_avg:95.68ms +step:28/1670 train_time:2675ms step_avg:95.52ms +step:29/1670 train_time:2766ms step_avg:95.37ms +step:30/1670 train_time:2856ms step_avg:95.21ms +step:31/1670 train_time:2947ms step_avg:95.05ms +step:32/1670 train_time:3039ms step_avg:94.98ms +step:33/1670 train_time:3132ms step_avg:94.90ms +step:34/1670 train_time:3224ms step_avg:94.84ms +step:35/1670 train_time:3318ms step_avg:94.80ms +step:36/1670 train_time:3410ms step_avg:94.71ms +step:37/1670 train_time:3503ms step_avg:94.68ms +step:38/1670 train_time:3595ms step_avg:94.61ms +step:39/1670 train_time:3688ms step_avg:94.56ms +step:40/1670 train_time:3779ms step_avg:94.49ms +step:41/1670 train_time:3870ms step_avg:94.40ms +step:42/1670 train_time:3962ms step_avg:94.35ms +step:43/1670 train_time:4055ms step_avg:94.30ms +step:44/1670 train_time:4145ms step_avg:94.21ms +step:45/1670 train_time:4237ms step_avg:94.15ms +step:46/1670 train_time:4328ms step_avg:94.09ms +step:47/1670 train_time:4421ms step_avg:94.06ms +step:48/1670 train_time:4512ms step_avg:94.00ms +step:49/1670 train_time:4605ms step_avg:93.98ms +step:50/1670 train_time:4697ms step_avg:93.94ms +step:51/1670 train_time:4789ms step_avg:93.89ms +step:52/1670 train_time:4880ms step_avg:93.85ms +step:53/1670 train_time:4971ms step_avg:93.80ms +step:54/1670 train_time:5063ms step_avg:93.76ms +step:55/1670 train_time:5154ms step_avg:93.71ms +step:56/1670 train_time:5245ms step_avg:93.67ms +step:57/1670 train_time:5337ms step_avg:93.63ms +step:58/1670 train_time:5429ms step_avg:93.61ms +step:59/1670 train_time:5522ms step_avg:93.60ms +step:60/1670 train_time:5613ms step_avg:93.56ms +step:61/1670 train_time:5707ms step_avg:93.56ms +step:62/1670 train_time:5799ms step_avg:93.53ms +step:63/1670 train_time:5890ms step_avg:93.49ms +step:64/1670 train_time:5982ms step_avg:93.47ms +step:65/1670 train_time:6074ms step_avg:93.45ms +step:66/1670 train_time:6166ms step_avg:93.42ms +step:67/1670 train_time:6257ms step_avg:93.39ms +step:68/1670 train_time:6348ms step_avg:93.35ms +step:69/1670 train_time:6441ms step_avg:93.35ms +step:70/1670 train_time:6533ms step_avg:93.32ms +step:71/1670 train_time:6624ms step_avg:93.30ms +step:72/1670 train_time:6715ms step_avg:93.27ms +step:73/1670 train_time:6808ms step_avg:93.26ms +step:74/1670 train_time:6900ms step_avg:93.25ms +step:75/1670 train_time:6992ms step_avg:93.23ms +step:76/1670 train_time:7085ms step_avg:93.22ms +step:77/1670 train_time:7176ms step_avg:93.20ms +step:78/1670 train_time:7267ms step_avg:93.17ms +step:79/1670 train_time:7358ms step_avg:93.14ms +step:80/1670 train_time:7449ms step_avg:93.11ms +step:81/1670 train_time:7540ms step_avg:93.09ms +step:82/1670 train_time:7631ms step_avg:93.07ms +step:83/1670 train_time:7724ms step_avg:93.06ms +step:84/1670 train_time:7816ms step_avg:93.04ms +step:85/1670 train_time:7908ms step_avg:93.04ms +step:86/1670 train_time:8000ms step_avg:93.02ms +step:87/1670 train_time:8091ms step_avg:93.00ms +step:88/1670 train_time:8184ms step_avg:93.00ms +step:89/1670 train_time:8275ms step_avg:92.98ms +step:90/1670 train_time:8366ms step_avg:92.96ms +step:91/1670 train_time:8457ms step_avg:92.94ms +step:92/1670 train_time:8548ms step_avg:92.91ms +step:93/1670 train_time:8640ms step_avg:92.91ms +step:94/1670 train_time:8731ms step_avg:92.88ms 
+step:95/1670 train_time:8822ms step_avg:92.86ms +step:96/1670 train_time:8912ms step_avg:92.84ms +step:97/1670 train_time:9004ms step_avg:92.83ms +step:98/1670 train_time:9095ms step_avg:92.81ms +step:99/1670 train_time:9187ms step_avg:92.80ms +step:100/1670 train_time:9279ms step_avg:92.79ms +step:101/1670 train_time:9371ms step_avg:92.78ms +step:102/1670 train_time:9463ms step_avg:92.78ms +step:103/1670 train_time:9554ms step_avg:92.75ms +step:104/1670 train_time:9645ms step_avg:92.75ms +step:105/1670 train_time:9736ms step_avg:92.73ms +step:106/1670 train_time:9827ms step_avg:92.71ms +step:107/1670 train_time:9918ms step_avg:92.69ms +step:108/1670 train_time:10009ms step_avg:92.68ms +step:109/1670 train_time:10101ms step_avg:92.67ms +step:110/1670 train_time:10192ms step_avg:92.66ms +step:111/1670 train_time:10285ms step_avg:92.66ms +step:112/1670 train_time:10376ms step_avg:92.65ms +step:113/1670 train_time:10468ms step_avg:92.64ms +step:114/1670 train_time:10559ms step_avg:92.62ms +step:115/1670 train_time:10650ms step_avg:92.61ms +step:116/1670 train_time:10740ms step_avg:92.59ms +step:117/1670 train_time:10831ms step_avg:92.57ms +step:118/1670 train_time:10922ms step_avg:92.56ms +step:119/1670 train_time:11014ms step_avg:92.55ms +step:120/1670 train_time:11106ms step_avg:92.55ms +step:121/1670 train_time:11199ms step_avg:92.55ms +step:122/1670 train_time:11290ms step_avg:92.54ms +step:123/1670 train_time:11383ms step_avg:92.54ms +step:124/1670 train_time:11474ms step_avg:92.53ms +step:125/1670 train_time:11566ms step_avg:92.53ms +step:125/1670 val_loss:4.3003 train_time:11657ms step_avg:93.26ms +step:126/1670 train_time:11677ms step_avg:92.68ms +step:127/1670 train_time:11754ms step_avg:92.55ms +step:128/1670 train_time:11855ms step_avg:92.62ms +step:129/1670 train_time:11948ms step_avg:92.62ms +step:130/1670 train_time:12039ms step_avg:92.60ms +step:131/1670 train_time:12128ms step_avg:92.58ms +step:132/1670 train_time:12218ms step_avg:92.56ms +step:133/1670 train_time:12308ms step_avg:92.54ms +step:134/1670 train_time:12398ms step_avg:92.52ms +step:135/1670 train_time:12488ms step_avg:92.50ms +step:136/1670 train_time:12578ms step_avg:92.49ms +step:137/1670 train_time:12669ms step_avg:92.47ms +step:138/1670 train_time:12764ms step_avg:92.49ms +step:139/1670 train_time:12858ms step_avg:92.51ms +step:140/1670 train_time:12950ms step_avg:92.50ms +step:141/1670 train_time:13042ms step_avg:92.49ms +step:142/1670 train_time:13132ms step_avg:92.48ms +step:143/1670 train_time:13223ms step_avg:92.47ms +step:144/1670 train_time:13314ms step_avg:92.46ms +step:145/1670 train_time:13404ms step_avg:92.44ms +step:146/1670 train_time:13494ms step_avg:92.42ms +step:147/1670 train_time:13585ms step_avg:92.41ms +step:148/1670 train_time:13677ms step_avg:92.42ms +step:149/1670 train_time:13770ms step_avg:92.41ms +step:150/1670 train_time:13864ms step_avg:92.43ms +step:151/1670 train_time:13956ms step_avg:92.43ms +step:152/1670 train_time:14047ms step_avg:92.41ms +step:153/1670 train_time:14138ms step_avg:92.41ms +step:154/1670 train_time:14229ms step_avg:92.39ms +step:155/1670 train_time:14319ms step_avg:92.38ms +step:156/1670 train_time:14410ms step_avg:92.37ms +step:157/1670 train_time:14500ms step_avg:92.36ms +step:158/1670 train_time:14591ms step_avg:92.35ms +step:159/1670 train_time:14682ms step_avg:92.34ms +step:160/1670 train_time:14775ms step_avg:92.35ms +step:161/1670 train_time:14868ms step_avg:92.35ms +step:162/1670 train_time:14959ms step_avg:92.34ms +step:163/1670 train_time:15049ms 
step_avg:92.33ms +step:164/1670 train_time:15141ms step_avg:92.32ms +step:165/1670 train_time:15232ms step_avg:92.31ms +step:166/1670 train_time:15323ms step_avg:92.31ms +step:167/1670 train_time:15415ms step_avg:92.30ms +step:168/1670 train_time:15506ms step_avg:92.30ms +step:169/1670 train_time:15596ms step_avg:92.29ms +step:170/1670 train_time:15687ms step_avg:92.28ms +step:171/1670 train_time:15781ms step_avg:92.29ms +step:172/1670 train_time:15873ms step_avg:92.28ms +step:173/1670 train_time:15965ms step_avg:92.28ms +step:174/1670 train_time:16056ms step_avg:92.28ms +step:175/1670 train_time:16148ms step_avg:92.27ms +step:176/1670 train_time:16239ms step_avg:92.27ms +step:177/1670 train_time:16329ms step_avg:92.26ms +step:178/1670 train_time:16420ms step_avg:92.25ms +step:179/1670 train_time:16511ms step_avg:92.24ms +step:180/1670 train_time:16601ms step_avg:92.23ms +step:181/1670 train_time:16691ms step_avg:92.22ms +step:182/1670 train_time:16783ms step_avg:92.21ms +step:183/1670 train_time:16873ms step_avg:92.20ms +step:184/1670 train_time:16965ms step_avg:92.20ms +step:185/1670 train_time:17057ms step_avg:92.20ms +step:186/1670 train_time:17149ms step_avg:92.20ms +step:187/1670 train_time:17241ms step_avg:92.20ms +step:188/1670 train_time:17331ms step_avg:92.19ms +step:189/1670 train_time:17421ms step_avg:92.18ms +step:190/1670 train_time:17512ms step_avg:92.17ms +step:191/1670 train_time:17603ms step_avg:92.16ms +step:192/1670 train_time:17694ms step_avg:92.16ms +step:193/1670 train_time:17787ms step_avg:92.16ms +step:194/1670 train_time:17879ms step_avg:92.16ms +step:195/1670 train_time:17970ms step_avg:92.15ms +step:196/1670 train_time:18061ms step_avg:92.15ms +step:197/1670 train_time:18152ms step_avg:92.14ms +step:198/1670 train_time:18244ms step_avg:92.14ms +step:199/1670 train_time:18336ms step_avg:92.14ms +step:200/1670 train_time:18427ms step_avg:92.14ms +step:201/1670 train_time:18518ms step_avg:92.13ms +step:202/1670 train_time:18608ms step_avg:92.12ms +step:203/1670 train_time:18699ms step_avg:92.11ms +step:204/1670 train_time:18790ms step_avg:92.11ms +step:205/1670 train_time:18882ms step_avg:92.11ms +step:206/1670 train_time:18972ms step_avg:92.10ms +step:207/1670 train_time:19064ms step_avg:92.10ms +step:208/1670 train_time:19155ms step_avg:92.09ms +step:209/1670 train_time:19247ms step_avg:92.09ms +step:210/1670 train_time:19339ms step_avg:92.09ms +step:211/1670 train_time:19430ms step_avg:92.08ms +step:212/1670 train_time:19520ms step_avg:92.08ms +step:213/1670 train_time:19771ms step_avg:92.82ms +step:214/1670 train_time:19839ms step_avg:92.71ms +step:215/1670 train_time:19929ms step_avg:92.69ms +step:216/1670 train_time:20018ms step_avg:92.68ms +step:217/1670 train_time:20108ms step_avg:92.66ms +step:218/1670 train_time:20198ms step_avg:92.65ms +step:219/1670 train_time:20287ms step_avg:92.64ms +step:220/1670 train_time:20377ms step_avg:92.62ms +step:221/1670 train_time:20467ms step_avg:92.61ms +step:222/1670 train_time:20557ms step_avg:92.60ms +step:223/1670 train_time:20652ms step_avg:92.61ms +step:224/1670 train_time:20748ms step_avg:92.62ms +step:225/1670 train_time:20841ms step_avg:92.62ms +step:226/1670 train_time:20931ms step_avg:92.61ms +step:227/1670 train_time:21021ms step_avg:92.60ms +step:228/1670 train_time:21111ms step_avg:92.59ms +step:229/1670 train_time:21202ms step_avg:92.59ms +step:230/1670 train_time:21292ms step_avg:92.58ms +step:231/1670 train_time:21383ms step_avg:92.57ms +step:232/1670 train_time:21473ms step_avg:92.56ms +step:233/1670 
train_time:21565ms step_avg:92.55ms +step:234/1670 train_time:21658ms step_avg:92.56ms +step:235/1670 train_time:21750ms step_avg:92.55ms +step:236/1670 train_time:21843ms step_avg:92.55ms +step:237/1670 train_time:21935ms step_avg:92.55ms +step:238/1670 train_time:22026ms step_avg:92.55ms +step:239/1670 train_time:22117ms step_avg:92.54ms +step:240/1670 train_time:22207ms step_avg:92.53ms +step:241/1670 train_time:22300ms step_avg:92.53ms +step:242/1670 train_time:22389ms step_avg:92.52ms +step:243/1670 train_time:22480ms step_avg:92.51ms +step:244/1670 train_time:22571ms step_avg:92.50ms +step:245/1670 train_time:22664ms step_avg:92.50ms +step:246/1670 train_time:22756ms step_avg:92.50ms +step:247/1670 train_time:22850ms step_avg:92.51ms +step:248/1670 train_time:22942ms step_avg:92.51ms +step:249/1670 train_time:23033ms step_avg:92.50ms +step:250/1670 train_time:23123ms step_avg:92.49ms +step:250/1670 val_loss:3.9688 train_time:23213ms step_avg:92.85ms +step:251/1670 train_time:23234ms step_avg:92.56ms +step:252/1670 train_time:23307ms step_avg:92.49ms +step:253/1670 train_time:23402ms step_avg:92.50ms +step:254/1670 train_time:23493ms step_avg:92.49ms +step:255/1670 train_time:23582ms step_avg:92.48ms +step:256/1670 train_time:23672ms step_avg:92.47ms +step:257/1670 train_time:23762ms step_avg:92.46ms +step:258/1670 train_time:23853ms step_avg:92.45ms +step:259/1670 train_time:23943ms step_avg:92.44ms +step:260/1670 train_time:24034ms step_avg:92.44ms +step:261/1670 train_time:24126ms step_avg:92.44ms +step:262/1670 train_time:24218ms step_avg:92.44ms +step:263/1670 train_time:24311ms step_avg:92.44ms +step:264/1670 train_time:24403ms step_avg:92.44ms +step:265/1670 train_time:24494ms step_avg:92.43ms +step:266/1670 train_time:24585ms step_avg:92.43ms +step:267/1670 train_time:24675ms step_avg:92.42ms +step:268/1670 train_time:24765ms step_avg:92.41ms +step:269/1670 train_time:24855ms step_avg:92.40ms +step:270/1670 train_time:24945ms step_avg:92.39ms +step:271/1670 train_time:25037ms step_avg:92.39ms +step:272/1670 train_time:25128ms step_avg:92.38ms +step:273/1670 train_time:25221ms step_avg:92.39ms +step:274/1670 train_time:25313ms step_avg:92.38ms +step:275/1670 train_time:25405ms step_avg:92.38ms +step:276/1670 train_time:25497ms step_avg:92.38ms +step:277/1670 train_time:25588ms step_avg:92.38ms +step:278/1670 train_time:25680ms step_avg:92.37ms +step:279/1670 train_time:25772ms step_avg:92.37ms +step:280/1670 train_time:25863ms step_avg:92.37ms +step:281/1670 train_time:25954ms step_avg:92.36ms +step:282/1670 train_time:26045ms step_avg:92.36ms +step:283/1670 train_time:26135ms step_avg:92.35ms +step:284/1670 train_time:26227ms step_avg:92.35ms +step:285/1670 train_time:26320ms step_avg:92.35ms +step:286/1670 train_time:26412ms step_avg:92.35ms +step:287/1670 train_time:26503ms step_avg:92.35ms +step:288/1670 train_time:26594ms step_avg:92.34ms +step:289/1670 train_time:26684ms step_avg:92.33ms +step:290/1670 train_time:26774ms step_avg:92.33ms +step:291/1670 train_time:26865ms step_avg:92.32ms +step:292/1670 train_time:26956ms step_avg:92.32ms +step:293/1670 train_time:27047ms step_avg:92.31ms +step:294/1670 train_time:27138ms step_avg:92.31ms +step:295/1670 train_time:27229ms step_avg:92.30ms +step:296/1670 train_time:27321ms step_avg:92.30ms +step:297/1670 train_time:27412ms step_avg:92.30ms +step:298/1670 train_time:27504ms step_avg:92.30ms +step:299/1670 train_time:27596ms step_avg:92.29ms +step:300/1670 train_time:27687ms step_avg:92.29ms +step:301/1670 train_time:27777ms 
step_avg:92.28ms +step:302/1670 train_time:27868ms step_avg:92.28ms +step:303/1670 train_time:27960ms step_avg:92.28ms +step:304/1670 train_time:28051ms step_avg:92.27ms +step:305/1670 train_time:28144ms step_avg:92.28ms +step:306/1670 train_time:28236ms step_avg:92.27ms +step:307/1670 train_time:28328ms step_avg:92.27ms +step:308/1670 train_time:28420ms step_avg:92.27ms +step:309/1670 train_time:28511ms step_avg:92.27ms +step:310/1670 train_time:28602ms step_avg:92.26ms +step:311/1670 train_time:28693ms step_avg:92.26ms +step:312/1670 train_time:28783ms step_avg:92.25ms +step:313/1670 train_time:28874ms step_avg:92.25ms +step:314/1670 train_time:28965ms step_avg:92.24ms +step:315/1670 train_time:29056ms step_avg:92.24ms +step:316/1670 train_time:29147ms step_avg:92.24ms +step:317/1670 train_time:29239ms step_avg:92.24ms +step:318/1670 train_time:29331ms step_avg:92.24ms +step:319/1670 train_time:29423ms step_avg:92.23ms +step:320/1670 train_time:29514ms step_avg:92.23ms +step:321/1670 train_time:29606ms step_avg:92.23ms +step:322/1670 train_time:29697ms step_avg:92.23ms +step:323/1670 train_time:29787ms step_avg:92.22ms +step:324/1670 train_time:29878ms step_avg:92.22ms +step:325/1670 train_time:29968ms step_avg:92.21ms +step:326/1670 train_time:30060ms step_avg:92.21ms +step:327/1670 train_time:30152ms step_avg:92.21ms +step:328/1670 train_time:30244ms step_avg:92.21ms +step:329/1670 train_time:30335ms step_avg:92.20ms +step:330/1670 train_time:30426ms step_avg:92.20ms +step:331/1670 train_time:30517ms step_avg:92.20ms +step:332/1670 train_time:30609ms step_avg:92.19ms +step:333/1670 train_time:30700ms step_avg:92.19ms +step:334/1670 train_time:30791ms step_avg:92.19ms +step:335/1670 train_time:30882ms step_avg:92.19ms +step:336/1670 train_time:30973ms step_avg:92.18ms +step:337/1670 train_time:31064ms step_avg:92.18ms +step:338/1670 train_time:31156ms step_avg:92.18ms +step:339/1670 train_time:31247ms step_avg:92.17ms +step:340/1670 train_time:31338ms step_avg:92.17ms +step:341/1670 train_time:31430ms step_avg:92.17ms +step:342/1670 train_time:31522ms step_avg:92.17ms +step:343/1670 train_time:31613ms step_avg:92.16ms +step:344/1670 train_time:31704ms step_avg:92.16ms +step:345/1670 train_time:31795ms step_avg:92.16ms +step:346/1670 train_time:31885ms step_avg:92.15ms +step:347/1670 train_time:31976ms step_avg:92.15ms +step:348/1670 train_time:32067ms step_avg:92.15ms +step:349/1670 train_time:32160ms step_avg:92.15ms +step:350/1670 train_time:32251ms step_avg:92.14ms +step:351/1670 train_time:32342ms step_avg:92.14ms +step:352/1670 train_time:32433ms step_avg:92.14ms +step:353/1670 train_time:32524ms step_avg:92.14ms +step:354/1670 train_time:32615ms step_avg:92.13ms +step:355/1670 train_time:32706ms step_avg:92.13ms +step:356/1670 train_time:32797ms step_avg:92.13ms +step:357/1670 train_time:32887ms step_avg:92.12ms +step:358/1670 train_time:32980ms step_avg:92.12ms +step:359/1670 train_time:33071ms step_avg:92.12ms +step:360/1670 train_time:33163ms step_avg:92.12ms +step:361/1670 train_time:33255ms step_avg:92.12ms +step:362/1670 train_time:33346ms step_avg:92.12ms +step:363/1670 train_time:33437ms step_avg:92.11ms +step:364/1670 train_time:33528ms step_avg:92.11ms +step:365/1670 train_time:33619ms step_avg:92.11ms +step:366/1670 train_time:33709ms step_avg:92.10ms +step:367/1670 train_time:33801ms step_avg:92.10ms +step:368/1670 train_time:33892ms step_avg:92.10ms +step:369/1670 train_time:33983ms step_avg:92.09ms +step:370/1670 train_time:34074ms step_avg:92.09ms +step:371/1670 
train_time:34165ms step_avg:92.09ms +step:372/1670 train_time:34257ms step_avg:92.09ms +step:373/1670 train_time:34348ms step_avg:92.08ms +step:374/1670 train_time:34440ms step_avg:92.08ms +step:375/1670 train_time:34530ms step_avg:92.08ms +step:375/1670 val_loss:3.8112 train_time:34622ms step_avg:92.32ms +step:376/1670 train_time:34641ms step_avg:92.13ms +step:377/1670 train_time:34715ms step_avg:92.08ms +step:378/1670 train_time:34807ms step_avg:92.08ms +step:379/1670 train_time:34898ms step_avg:92.08ms +step:380/1670 train_time:34989ms step_avg:92.08ms +step:381/1670 train_time:35078ms step_avg:92.07ms +step:382/1670 train_time:35169ms step_avg:92.06ms +step:383/1670 train_time:35261ms step_avg:92.06ms +step:384/1670 train_time:35352ms step_avg:92.06ms +step:385/1670 train_time:35443ms step_avg:92.06ms +step:386/1670 train_time:35535ms step_avg:92.06ms +step:387/1670 train_time:35627ms step_avg:92.06ms +step:388/1670 train_time:35718ms step_avg:92.06ms +step:389/1670 train_time:35809ms step_avg:92.05ms +step:390/1670 train_time:35901ms step_avg:92.05ms +step:391/1670 train_time:35993ms step_avg:92.05ms +step:392/1670 train_time:36083ms step_avg:92.05ms +step:393/1670 train_time:36173ms step_avg:92.04ms +step:394/1670 train_time:36264ms step_avg:92.04ms +step:395/1670 train_time:36354ms step_avg:92.03ms +step:396/1670 train_time:36445ms step_avg:92.03ms +step:397/1670 train_time:36537ms step_avg:92.03ms +step:398/1670 train_time:36630ms step_avg:92.04ms +step:399/1670 train_time:36721ms step_avg:92.03ms +step:400/1670 train_time:36812ms step_avg:92.03ms +step:401/1670 train_time:36903ms step_avg:92.03ms +step:402/1670 train_time:36994ms step_avg:92.03ms +step:403/1670 train_time:37087ms step_avg:92.03ms +step:404/1670 train_time:37177ms step_avg:92.02ms +step:405/1670 train_time:37268ms step_avg:92.02ms +step:406/1670 train_time:37358ms step_avg:92.01ms +step:407/1670 train_time:37449ms step_avg:92.01ms +step:408/1670 train_time:37539ms step_avg:92.01ms +step:409/1670 train_time:37632ms step_avg:92.01ms +step:410/1670 train_time:37724ms step_avg:92.01ms +step:411/1670 train_time:37815ms step_avg:92.01ms +step:412/1670 train_time:37906ms step_avg:92.00ms +step:413/1670 train_time:37997ms step_avg:92.00ms +step:414/1670 train_time:38088ms step_avg:92.00ms +step:415/1670 train_time:38178ms step_avg:92.00ms +step:416/1670 train_time:38270ms step_avg:92.00ms +step:417/1670 train_time:38361ms step_avg:91.99ms +step:418/1670 train_time:38452ms step_avg:91.99ms +step:419/1670 train_time:38542ms step_avg:91.99ms +step:420/1670 train_time:38634ms step_avg:91.99ms +step:421/1670 train_time:38727ms step_avg:91.99ms +step:422/1670 train_time:38817ms step_avg:91.98ms +step:423/1670 train_time:38908ms step_avg:91.98ms +step:424/1670 train_time:38999ms step_avg:91.98ms +step:425/1670 train_time:39249ms step_avg:92.35ms +step:426/1670 train_time:39322ms step_avg:92.31ms +step:427/1670 train_time:39412ms step_avg:92.30ms +step:428/1670 train_time:39502ms step_avg:92.29ms +step:429/1670 train_time:39592ms step_avg:92.29ms +step:430/1670 train_time:39682ms step_avg:92.28ms +step:431/1670 train_time:39772ms step_avg:92.28ms +step:432/1670 train_time:39861ms step_avg:92.27ms +step:433/1670 train_time:39951ms step_avg:92.27ms +step:434/1670 train_time:40041ms step_avg:92.26ms +step:435/1670 train_time:40136ms step_avg:92.27ms +step:436/1670 train_time:40234ms step_avg:92.28ms +step:437/1670 train_time:40327ms step_avg:92.28ms +step:438/1670 train_time:40419ms step_avg:92.28ms +step:439/1670 train_time:40511ms 
step_avg:92.28ms +step:440/1670 train_time:40601ms step_avg:92.27ms +step:441/1670 train_time:40692ms step_avg:92.27ms +step:442/1670 train_time:40782ms step_avg:92.27ms +step:443/1670 train_time:40872ms step_avg:92.26ms +step:444/1670 train_time:40962ms step_avg:92.26ms +step:445/1670 train_time:41054ms step_avg:92.26ms +step:446/1670 train_time:41147ms step_avg:92.26ms +step:447/1670 train_time:41239ms step_avg:92.26ms +step:448/1670 train_time:41332ms step_avg:92.26ms +step:449/1670 train_time:41423ms step_avg:92.26ms +step:450/1670 train_time:41514ms step_avg:92.25ms +step:451/1670 train_time:41605ms step_avg:92.25ms +step:452/1670 train_time:41695ms step_avg:92.25ms +step:453/1670 train_time:41786ms step_avg:92.24ms +step:454/1670 train_time:41876ms step_avg:92.24ms +step:455/1670 train_time:41966ms step_avg:92.23ms +step:456/1670 train_time:42058ms step_avg:92.23ms +step:457/1670 train_time:42151ms step_avg:92.23ms +step:458/1670 train_time:42242ms step_avg:92.23ms +step:459/1670 train_time:42335ms step_avg:92.23ms +step:460/1670 train_time:42426ms step_avg:92.23ms +step:461/1670 train_time:42517ms step_avg:92.23ms +step:462/1670 train_time:42608ms step_avg:92.22ms +step:463/1670 train_time:42698ms step_avg:92.22ms +step:464/1670 train_time:42789ms step_avg:92.22ms +step:465/1670 train_time:42880ms step_avg:92.22ms +step:466/1670 train_time:42969ms step_avg:92.21ms +step:467/1670 train_time:43060ms step_avg:92.21ms +step:468/1670 train_time:43152ms step_avg:92.21ms +step:469/1670 train_time:43243ms step_avg:92.20ms +step:470/1670 train_time:43336ms step_avg:92.20ms +step:471/1670 train_time:43428ms step_avg:92.20ms +step:472/1670 train_time:43518ms step_avg:92.20ms +step:473/1670 train_time:43610ms step_avg:92.20ms +step:474/1670 train_time:43700ms step_avg:92.19ms +step:475/1670 train_time:43792ms step_avg:92.19ms +step:476/1670 train_time:43882ms step_avg:92.19ms +step:477/1670 train_time:43973ms step_avg:92.19ms +step:478/1670 train_time:44065ms step_avg:92.19ms +step:479/1670 train_time:44155ms step_avg:92.18ms +step:480/1670 train_time:44246ms step_avg:92.18ms +step:481/1670 train_time:44338ms step_avg:92.18ms +step:482/1670 train_time:44430ms step_avg:92.18ms +step:483/1670 train_time:44520ms step_avg:92.17ms +step:484/1670 train_time:44612ms step_avg:92.17ms +step:485/1670 train_time:44702ms step_avg:92.17ms +step:486/1670 train_time:44794ms step_avg:92.17ms +step:487/1670 train_time:44885ms step_avg:92.17ms +step:488/1670 train_time:44976ms step_avg:92.16ms +step:489/1670 train_time:45067ms step_avg:92.16ms +step:490/1670 train_time:45157ms step_avg:92.16ms +step:491/1670 train_time:45249ms step_avg:92.16ms +step:492/1670 train_time:45340ms step_avg:92.15ms +step:493/1670 train_time:45432ms step_avg:92.15ms +step:494/1670 train_time:45523ms step_avg:92.15ms +step:495/1670 train_time:45614ms step_avg:92.15ms +step:496/1670 train_time:45706ms step_avg:92.15ms +step:497/1670 train_time:45796ms step_avg:92.15ms +step:498/1670 train_time:45887ms step_avg:92.14ms +step:499/1670 train_time:45978ms step_avg:92.14ms +step:500/1670 train_time:46069ms step_avg:92.14ms +step:500/1670 val_loss:3.7119 train_time:46159ms step_avg:92.32ms +step:501/1670 train_time:46179ms step_avg:92.17ms +step:502/1670 train_time:46252ms step_avg:92.14ms +step:503/1670 train_time:46343ms step_avg:92.13ms +step:504/1670 train_time:46434ms step_avg:92.13ms +step:505/1670 train_time:46524ms step_avg:92.13ms +step:506/1670 train_time:46615ms step_avg:92.12ms +step:507/1670 train_time:46706ms step_avg:92.12ms 
+step:508/1670 train_time:46797ms step_avg:92.12ms +step:509/1670 train_time:46888ms step_avg:92.12ms +step:510/1670 train_time:46979ms step_avg:92.11ms +step:511/1670 train_time:47070ms step_avg:92.11ms +step:512/1670 train_time:47163ms step_avg:92.12ms +step:513/1670 train_time:47256ms step_avg:92.12ms +step:514/1670 train_time:47346ms step_avg:92.11ms +step:515/1670 train_time:47437ms step_avg:92.11ms +step:516/1670 train_time:47527ms step_avg:92.11ms +step:517/1670 train_time:47617ms step_avg:92.10ms +step:518/1670 train_time:47709ms step_avg:92.10ms +step:519/1670 train_time:47802ms step_avg:92.10ms +step:520/1670 train_time:47893ms step_avg:92.10ms +step:521/1670 train_time:47984ms step_avg:92.10ms +step:522/1670 train_time:48077ms step_avg:92.10ms +step:523/1670 train_time:48168ms step_avg:92.10ms +step:524/1670 train_time:48259ms step_avg:92.10ms +step:525/1670 train_time:48350ms step_avg:92.10ms +step:526/1670 train_time:48441ms step_avg:92.09ms +step:527/1670 train_time:48531ms step_avg:92.09ms +step:528/1670 train_time:48622ms step_avg:92.09ms +step:529/1670 train_time:48713ms step_avg:92.09ms +step:530/1670 train_time:48804ms step_avg:92.08ms +step:531/1670 train_time:48895ms step_avg:92.08ms +step:532/1670 train_time:48987ms step_avg:92.08ms +step:533/1670 train_time:49079ms step_avg:92.08ms +step:534/1670 train_time:49171ms step_avg:92.08ms +step:535/1670 train_time:49263ms step_avg:92.08ms +step:536/1670 train_time:49354ms step_avg:92.08ms +step:537/1670 train_time:49445ms step_avg:92.08ms +step:538/1670 train_time:49536ms step_avg:92.07ms +step:539/1670 train_time:49626ms step_avg:92.07ms +step:540/1670 train_time:49717ms step_avg:92.07ms +step:541/1670 train_time:49807ms step_avg:92.06ms +step:542/1670 train_time:49898ms step_avg:92.06ms +step:543/1670 train_time:49989ms step_avg:92.06ms +step:544/1670 train_time:50081ms step_avg:92.06ms +step:545/1670 train_time:50173ms step_avg:92.06ms +step:546/1670 train_time:50266ms step_avg:92.06ms +step:547/1670 train_time:50357ms step_avg:92.06ms +step:548/1670 train_time:50448ms step_avg:92.06ms +step:549/1670 train_time:50539ms step_avg:92.06ms +step:550/1670 train_time:50630ms step_avg:92.05ms +step:551/1670 train_time:50722ms step_avg:92.05ms +step:552/1670 train_time:50812ms step_avg:92.05ms +step:553/1670 train_time:50903ms step_avg:92.05ms +step:554/1670 train_time:50995ms step_avg:92.05ms +step:555/1670 train_time:51086ms step_avg:92.05ms +step:556/1670 train_time:51179ms step_avg:92.05ms +step:557/1670 train_time:51270ms step_avg:92.05ms +step:558/1670 train_time:51555ms step_avg:92.39ms +step:559/1670 train_time:51627ms step_avg:92.36ms +step:560/1670 train_time:51717ms step_avg:92.35ms +step:561/1670 train_time:51808ms step_avg:92.35ms +step:562/1670 train_time:51900ms step_avg:92.35ms +step:563/1670 train_time:51991ms step_avg:92.35ms +step:564/1670 train_time:52082ms step_avg:92.34ms +step:565/1670 train_time:52173ms step_avg:92.34ms +step:566/1670 train_time:52265ms step_avg:92.34ms +step:567/1670 train_time:52357ms step_avg:92.34ms +step:568/1670 train_time:52453ms step_avg:92.35ms +step:569/1670 train_time:52549ms step_avg:92.35ms +step:570/1670 train_time:52643ms step_avg:92.36ms +step:571/1670 train_time:52736ms step_avg:92.36ms +step:572/1670 train_time:52827ms step_avg:92.35ms +step:573/1670 train_time:52919ms step_avg:92.35ms +step:574/1670 train_time:53010ms step_avg:92.35ms +step:575/1670 train_time:53102ms step_avg:92.35ms +step:576/1670 train_time:53193ms step_avg:92.35ms +step:577/1670 train_time:53285ms 
step_avg:92.35ms +step:578/1670 train_time:53378ms step_avg:92.35ms +step:579/1670 train_time:53473ms step_avg:92.35ms +step:580/1670 train_time:53566ms step_avg:92.36ms +step:581/1670 train_time:53660ms step_avg:92.36ms +step:582/1670 train_time:53753ms step_avg:92.36ms +step:583/1670 train_time:53845ms step_avg:92.36ms +step:584/1670 train_time:53937ms step_avg:92.36ms +step:585/1670 train_time:54028ms step_avg:92.36ms +step:586/1670 train_time:54120ms step_avg:92.35ms +step:587/1670 train_time:54211ms step_avg:92.35ms +step:588/1670 train_time:54303ms step_avg:92.35ms +step:589/1670 train_time:54396ms step_avg:92.35ms +step:590/1670 train_time:54489ms step_avg:92.35ms +step:591/1670 train_time:54583ms step_avg:92.36ms +step:592/1670 train_time:54675ms step_avg:92.36ms +step:593/1670 train_time:54768ms step_avg:92.36ms +step:594/1670 train_time:54862ms step_avg:92.36ms +step:595/1670 train_time:54955ms step_avg:92.36ms +step:596/1670 train_time:55047ms step_avg:92.36ms +step:597/1670 train_time:55139ms step_avg:92.36ms +step:598/1670 train_time:55231ms step_avg:92.36ms +step:599/1670 train_time:55323ms step_avg:92.36ms +step:600/1670 train_time:55416ms step_avg:92.36ms +step:601/1670 train_time:55508ms step_avg:92.36ms +step:602/1670 train_time:55602ms step_avg:92.36ms +step:603/1670 train_time:55696ms step_avg:92.36ms +step:604/1670 train_time:55788ms step_avg:92.36ms +step:605/1670 train_time:55882ms step_avg:92.37ms +step:606/1670 train_time:55974ms step_avg:92.37ms +step:607/1670 train_time:56066ms step_avg:92.37ms +step:608/1670 train_time:56158ms step_avg:92.37ms +step:609/1670 train_time:56249ms step_avg:92.36ms +step:610/1670 train_time:56341ms step_avg:92.36ms +step:611/1670 train_time:56434ms step_avg:92.36ms +step:612/1670 train_time:56526ms step_avg:92.36ms +step:613/1670 train_time:56619ms step_avg:92.36ms +step:614/1670 train_time:56711ms step_avg:92.36ms +step:615/1670 train_time:56805ms step_avg:92.37ms +step:616/1670 train_time:56897ms step_avg:92.37ms +step:617/1670 train_time:56989ms step_avg:92.36ms +step:618/1670 train_time:57081ms step_avg:92.36ms +step:619/1670 train_time:57174ms step_avg:92.36ms +step:620/1670 train_time:57266ms step_avg:92.36ms +step:621/1670 train_time:57359ms step_avg:92.36ms +step:622/1670 train_time:57451ms step_avg:92.36ms +step:623/1670 train_time:57544ms step_avg:92.37ms +step:624/1670 train_time:57637ms step_avg:92.37ms +step:625/1670 train_time:57729ms step_avg:92.37ms +step:625/1670 val_loss:3.6095 train_time:57822ms step_avg:92.51ms +step:626/1670 train_time:57842ms step_avg:92.40ms +step:627/1670 train_time:57921ms step_avg:92.38ms +step:628/1670 train_time:58020ms step_avg:92.39ms +step:629/1670 train_time:58113ms step_avg:92.39ms +step:630/1670 train_time:58204ms step_avg:92.39ms +step:631/1670 train_time:58295ms step_avg:92.38ms +step:632/1670 train_time:58386ms step_avg:92.38ms +step:633/1670 train_time:58477ms step_avg:92.38ms +step:634/1670 train_time:58568ms step_avg:92.38ms +step:635/1670 train_time:58659ms step_avg:92.38ms +step:636/1670 train_time:58750ms step_avg:92.37ms +step:637/1670 train_time:58844ms step_avg:92.38ms +step:638/1670 train_time:58938ms step_avg:92.38ms +step:639/1670 train_time:59176ms step_avg:92.61ms +step:640/1670 train_time:59246ms step_avg:92.57ms +step:641/1670 train_time:59336ms step_avg:92.57ms +step:642/1670 train_time:59428ms step_avg:92.57ms +step:643/1670 train_time:59519ms step_avg:92.56ms +step:644/1670 train_time:59610ms step_avg:92.56ms +step:645/1670 train_time:59701ms step_avg:92.56ms 
+step:646/1670 train_time:59793ms step_avg:92.56ms +step:647/1670 train_time:59885ms step_avg:92.56ms +step:648/1670 train_time:59975ms step_avg:92.55ms +step:649/1670 train_time:60073ms step_avg:92.56ms +step:650/1670 train_time:60172ms step_avg:92.57ms +step:651/1670 train_time:60265ms step_avg:92.57ms +step:652/1670 train_time:60357ms step_avg:92.57ms +step:653/1670 train_time:60450ms step_avg:92.57ms +step:654/1670 train_time:60541ms step_avg:92.57ms +step:655/1670 train_time:60633ms step_avg:92.57ms +step:656/1670 train_time:60724ms step_avg:92.57ms +step:657/1670 train_time:60815ms step_avg:92.56ms +step:658/1670 train_time:60907ms step_avg:92.56ms +step:659/1670 train_time:60999ms step_avg:92.56ms +step:660/1670 train_time:61094ms step_avg:92.57ms +step:661/1670 train_time:61189ms step_avg:92.57ms +step:662/1670 train_time:61282ms step_avg:92.57ms +step:663/1670 train_time:61375ms step_avg:92.57ms +step:664/1670 train_time:61468ms step_avg:92.57ms +step:665/1670 train_time:61559ms step_avg:92.57ms +step:666/1670 train_time:61652ms step_avg:92.57ms +step:667/1670 train_time:61743ms step_avg:92.57ms +step:668/1670 train_time:61834ms step_avg:92.57ms +step:669/1670 train_time:61926ms step_avg:92.57ms +step:670/1670 train_time:62019ms step_avg:92.57ms +step:671/1670 train_time:62113ms step_avg:92.57ms +step:672/1670 train_time:62206ms step_avg:92.57ms +step:673/1670 train_time:62298ms step_avg:92.57ms +step:674/1670 train_time:62392ms step_avg:92.57ms +step:675/1670 train_time:62484ms step_avg:92.57ms +step:676/1670 train_time:62576ms step_avg:92.57ms +step:677/1670 train_time:62667ms step_avg:92.57ms +step:678/1670 train_time:62758ms step_avg:92.56ms +step:679/1670 train_time:62850ms step_avg:92.56ms +step:680/1670 train_time:62942ms step_avg:92.56ms +step:681/1670 train_time:63035ms step_avg:92.56ms +step:682/1670 train_time:63128ms step_avg:92.56ms +step:683/1670 train_time:63221ms step_avg:92.56ms +step:684/1670 train_time:63315ms step_avg:92.57ms +step:685/1670 train_time:63408ms step_avg:92.57ms +step:686/1670 train_time:63499ms step_avg:92.56ms +step:687/1670 train_time:63592ms step_avg:92.57ms +step:688/1670 train_time:63684ms step_avg:92.56ms +step:689/1670 train_time:63775ms step_avg:92.56ms +step:690/1670 train_time:63867ms step_avg:92.56ms +step:691/1670 train_time:63958ms step_avg:92.56ms +step:692/1670 train_time:64050ms step_avg:92.56ms +step:693/1670 train_time:64143ms step_avg:92.56ms +step:694/1670 train_time:64236ms step_avg:92.56ms +step:695/1670 train_time:64330ms step_avg:92.56ms +step:696/1670 train_time:64422ms step_avg:92.56ms +step:697/1670 train_time:64514ms step_avg:92.56ms +step:698/1670 train_time:64606ms step_avg:92.56ms +step:699/1670 train_time:64698ms step_avg:92.56ms +step:700/1670 train_time:64790ms step_avg:92.56ms +step:701/1670 train_time:64883ms step_avg:92.56ms +step:702/1670 train_time:64976ms step_avg:92.56ms +step:703/1670 train_time:65068ms step_avg:92.56ms +step:704/1670 train_time:65160ms step_avg:92.56ms +step:705/1670 train_time:65253ms step_avg:92.56ms +step:706/1670 train_time:65346ms step_avg:92.56ms +step:707/1670 train_time:65438ms step_avg:92.56ms +step:708/1670 train_time:65531ms step_avg:92.56ms +step:709/1670 train_time:65624ms step_avg:92.56ms +step:710/1670 train_time:65716ms step_avg:92.56ms +step:711/1670 train_time:65809ms step_avg:92.56ms +step:712/1670 train_time:65901ms step_avg:92.56ms +step:713/1670 train_time:65993ms step_avg:92.56ms +step:714/1670 train_time:66086ms step_avg:92.56ms +step:715/1670 train_time:66177ms 
step_avg:92.56ms +step:716/1670 train_time:66271ms step_avg:92.56ms +step:717/1670 train_time:66363ms step_avg:92.56ms +step:718/1670 train_time:66455ms step_avg:92.56ms +step:719/1670 train_time:66548ms step_avg:92.56ms +step:720/1670 train_time:66640ms step_avg:92.56ms +step:721/1670 train_time:66732ms step_avg:92.56ms +step:722/1670 train_time:66825ms step_avg:92.56ms +step:723/1670 train_time:66918ms step_avg:92.56ms +step:724/1670 train_time:67011ms step_avg:92.56ms +step:725/1670 train_time:67103ms step_avg:92.56ms +step:726/1670 train_time:67195ms step_avg:92.56ms +step:727/1670 train_time:67288ms step_avg:92.56ms +step:728/1670 train_time:67380ms step_avg:92.55ms +step:729/1670 train_time:67472ms step_avg:92.55ms +step:730/1670 train_time:67564ms step_avg:92.55ms +step:731/1670 train_time:67656ms step_avg:92.55ms +step:732/1670 train_time:67748ms step_avg:92.55ms +step:733/1670 train_time:67840ms step_avg:92.55ms +step:734/1670 train_time:67933ms step_avg:92.55ms +step:735/1670 train_time:68027ms step_avg:92.55ms +step:736/1670 train_time:68119ms step_avg:92.55ms +step:737/1670 train_time:68212ms step_avg:92.55ms +step:738/1670 train_time:68304ms step_avg:92.55ms +step:739/1670 train_time:68396ms step_avg:92.55ms +step:740/1670 train_time:68489ms step_avg:92.55ms +step:741/1670 train_time:68582ms step_avg:92.55ms +step:742/1670 train_time:68674ms step_avg:92.55ms +step:743/1670 train_time:68766ms step_avg:92.55ms +step:744/1670 train_time:68857ms step_avg:92.55ms +step:745/1670 train_time:68952ms step_avg:92.55ms +step:746/1670 train_time:69045ms step_avg:92.55ms +step:747/1670 train_time:69136ms step_avg:92.55ms +step:748/1670 train_time:69229ms step_avg:92.55ms +step:749/1670 train_time:69321ms step_avg:92.55ms +step:750/1670 train_time:69413ms step_avg:92.55ms +step:750/1670 val_loss:3.5590 train_time:69505ms step_avg:92.67ms +step:751/1670 train_time:69525ms step_avg:92.58ms +step:752/1670 train_time:69598ms step_avg:92.55ms +step:753/1670 train_time:69691ms step_avg:92.55ms +step:754/1670 train_time:69785ms step_avg:92.55ms +step:755/1670 train_time:69877ms step_avg:92.55ms +step:756/1670 train_time:69969ms step_avg:92.55ms +step:757/1670 train_time:70061ms step_avg:92.55ms +step:758/1670 train_time:70152ms step_avg:92.55ms +step:759/1670 train_time:70243ms step_avg:92.55ms +step:760/1670 train_time:70335ms step_avg:92.55ms +step:761/1670 train_time:70428ms step_avg:92.55ms +step:762/1670 train_time:70522ms step_avg:92.55ms +step:763/1670 train_time:70614ms step_avg:92.55ms +step:764/1670 train_time:70707ms step_avg:92.55ms +step:765/1670 train_time:70800ms step_avg:92.55ms +step:766/1670 train_time:70892ms step_avg:92.55ms +step:767/1670 train_time:70984ms step_avg:92.55ms +step:768/1670 train_time:71076ms step_avg:92.55ms +step:769/1670 train_time:71169ms step_avg:92.55ms +step:770/1670 train_time:71262ms step_avg:92.55ms +step:771/1670 train_time:71353ms step_avg:92.55ms +step:772/1670 train_time:71446ms step_avg:92.55ms +step:773/1670 train_time:71538ms step_avg:92.55ms +step:774/1670 train_time:71632ms step_avg:92.55ms +step:775/1670 train_time:71723ms step_avg:92.55ms +step:776/1670 train_time:71816ms step_avg:92.55ms +step:777/1670 train_time:71908ms step_avg:92.55ms +step:778/1670 train_time:71999ms step_avg:92.54ms +step:779/1670 train_time:72092ms step_avg:92.54ms +step:780/1670 train_time:72184ms step_avg:92.54ms +step:781/1670 train_time:72276ms step_avg:92.54ms +step:782/1670 train_time:72370ms step_avg:92.54ms +step:783/1670 train_time:72463ms step_avg:92.55ms 
+step:784/1670 train_time:72555ms step_avg:92.55ms +step:785/1670 train_time:72649ms step_avg:92.55ms +step:786/1670 train_time:72742ms step_avg:92.55ms +step:787/1670 train_time:72833ms step_avg:92.54ms +step:788/1670 train_time:72924ms step_avg:92.54ms +step:789/1670 train_time:73016ms step_avg:92.54ms +step:790/1670 train_time:73108ms step_avg:92.54ms +step:791/1670 train_time:73200ms step_avg:92.54ms +step:792/1670 train_time:73293ms step_avg:92.54ms +step:793/1670 train_time:73387ms step_avg:92.54ms +step:794/1670 train_time:73479ms step_avg:92.54ms +step:795/1670 train_time:73572ms step_avg:92.54ms +step:796/1670 train_time:73665ms step_avg:92.54ms +step:797/1670 train_time:73758ms step_avg:92.54ms +step:798/1670 train_time:73850ms step_avg:92.54ms +step:799/1670 train_time:73943ms step_avg:92.54ms +step:800/1670 train_time:74034ms step_avg:92.54ms +step:801/1670 train_time:74126ms step_avg:92.54ms +step:802/1670 train_time:74219ms step_avg:92.54ms +step:803/1670 train_time:74311ms step_avg:92.54ms +step:804/1670 train_time:74403ms step_avg:92.54ms +step:805/1670 train_time:74496ms step_avg:92.54ms +step:806/1670 train_time:74588ms step_avg:92.54ms +step:807/1670 train_time:74682ms step_avg:92.54ms +step:808/1670 train_time:74774ms step_avg:92.54ms +step:809/1670 train_time:74867ms step_avg:92.54ms +step:810/1670 train_time:74959ms step_avg:92.54ms +step:811/1670 train_time:75051ms step_avg:92.54ms +step:812/1670 train_time:75144ms step_avg:92.54ms +step:813/1670 train_time:75236ms step_avg:92.54ms +step:814/1670 train_time:75329ms step_avg:92.54ms +step:815/1670 train_time:75421ms step_avg:92.54ms +step:816/1670 train_time:75514ms step_avg:92.54ms +step:817/1670 train_time:75607ms step_avg:92.54ms +step:818/1670 train_time:75699ms step_avg:92.54ms +step:819/1670 train_time:75792ms step_avg:92.54ms +step:820/1670 train_time:75886ms step_avg:92.54ms +step:821/1670 train_time:75978ms step_avg:92.54ms +step:822/1670 train_time:76071ms step_avg:92.54ms +step:823/1670 train_time:76163ms step_avg:92.54ms +step:824/1670 train_time:76255ms step_avg:92.54ms +step:825/1670 train_time:76348ms step_avg:92.54ms +step:826/1670 train_time:76441ms step_avg:92.54ms +step:827/1670 train_time:76533ms step_avg:92.54ms +step:828/1670 train_time:76626ms step_avg:92.54ms +step:829/1670 train_time:76718ms step_avg:92.54ms +step:830/1670 train_time:76811ms step_avg:92.54ms +step:831/1670 train_time:76903ms step_avg:92.54ms +step:832/1670 train_time:76994ms step_avg:92.54ms +step:833/1670 train_time:77088ms step_avg:92.54ms +step:834/1670 train_time:77181ms step_avg:92.54ms +step:835/1670 train_time:77273ms step_avg:92.54ms +step:836/1670 train_time:77366ms step_avg:92.54ms +step:837/1670 train_time:77458ms step_avg:92.54ms +step:838/1670 train_time:77550ms step_avg:92.54ms +step:839/1670 train_time:77643ms step_avg:92.54ms +step:840/1670 train_time:77734ms step_avg:92.54ms +step:841/1670 train_time:77827ms step_avg:92.54ms +step:842/1670 train_time:77920ms step_avg:92.54ms +step:843/1670 train_time:78012ms step_avg:92.54ms +step:844/1670 train_time:78106ms step_avg:92.54ms +step:845/1670 train_time:78198ms step_avg:92.54ms +step:846/1670 train_time:78291ms step_avg:92.54ms +step:847/1670 train_time:78384ms step_avg:92.54ms +step:848/1670 train_time:78476ms step_avg:92.54ms +step:849/1670 train_time:78570ms step_avg:92.54ms +step:850/1670 train_time:78662ms step_avg:92.54ms +step:851/1670 train_time:78910ms step_avg:92.73ms +step:852/1670 train_time:78989ms step_avg:92.71ms +step:853/1670 train_time:79079ms 
step_avg:92.71ms +step:854/1670 train_time:79170ms step_avg:92.71ms +step:855/1670 train_time:79261ms step_avg:92.70ms +step:856/1670 train_time:79352ms step_avg:92.70ms +step:857/1670 train_time:79444ms step_avg:92.70ms +step:858/1670 train_time:79535ms step_avg:92.70ms +step:859/1670 train_time:79626ms step_avg:92.70ms +step:860/1670 train_time:79717ms step_avg:92.69ms +step:861/1670 train_time:79812ms step_avg:92.70ms +step:862/1670 train_time:79913ms step_avg:92.71ms +step:863/1670 train_time:80007ms step_avg:92.71ms +step:864/1670 train_time:80099ms step_avg:92.71ms +step:865/1670 train_time:80190ms step_avg:92.71ms +step:866/1670 train_time:80283ms step_avg:92.71ms +step:867/1670 train_time:80373ms step_avg:92.70ms +step:868/1670 train_time:80465ms step_avg:92.70ms +step:869/1670 train_time:80557ms step_avg:92.70ms +step:870/1670 train_time:80648ms step_avg:92.70ms +step:871/1670 train_time:80741ms step_avg:92.70ms +step:872/1670 train_time:80834ms step_avg:92.70ms +step:873/1670 train_time:80929ms step_avg:92.70ms +step:874/1670 train_time:81024ms step_avg:92.70ms +step:875/1670 train_time:81117ms step_avg:92.70ms +step:875/1670 val_loss:3.5142 train_time:81209ms step_avg:92.81ms +step:876/1670 train_time:81229ms step_avg:92.73ms +step:877/1670 train_time:81303ms step_avg:92.71ms +step:878/1670 train_time:81395ms step_avg:92.71ms +step:879/1670 train_time:81488ms step_avg:92.71ms +step:880/1670 train_time:81579ms step_avg:92.70ms +step:881/1670 train_time:81671ms step_avg:92.70ms +step:882/1670 train_time:81762ms step_avg:92.70ms +step:883/1670 train_time:81853ms step_avg:92.70ms +step:884/1670 train_time:81944ms step_avg:92.70ms +step:885/1670 train_time:82035ms step_avg:92.69ms +step:886/1670 train_time:82129ms step_avg:92.70ms +step:887/1670 train_time:82224ms step_avg:92.70ms +step:888/1670 train_time:82317ms step_avg:92.70ms +step:889/1670 train_time:82410ms step_avg:92.70ms +step:890/1670 train_time:82503ms step_avg:92.70ms +step:891/1670 train_time:82596ms step_avg:92.70ms +step:892/1670 train_time:82688ms step_avg:92.70ms +step:893/1670 train_time:82779ms step_avg:92.70ms +step:894/1670 train_time:82871ms step_avg:92.70ms +step:895/1670 train_time:82963ms step_avg:92.70ms +step:896/1670 train_time:83055ms step_avg:92.70ms +step:897/1670 train_time:83149ms step_avg:92.70ms +step:898/1670 train_time:83243ms step_avg:92.70ms +step:899/1670 train_time:83335ms step_avg:92.70ms +step:900/1670 train_time:83428ms step_avg:92.70ms +step:901/1670 train_time:83521ms step_avg:92.70ms +step:902/1670 train_time:83613ms step_avg:92.70ms +step:903/1670 train_time:83704ms step_avg:92.70ms +step:904/1670 train_time:83796ms step_avg:92.69ms +step:905/1670 train_time:83888ms step_avg:92.69ms +step:906/1670 train_time:83980ms step_avg:92.69ms +step:907/1670 train_time:84073ms step_avg:92.69ms +step:908/1670 train_time:84166ms step_avg:92.69ms +step:909/1670 train_time:84258ms step_avg:92.69ms +step:910/1670 train_time:84352ms step_avg:92.69ms +step:911/1670 train_time:84446ms step_avg:92.70ms +step:912/1670 train_time:84538ms step_avg:92.70ms +step:913/1670 train_time:84631ms step_avg:92.70ms +step:914/1670 train_time:84723ms step_avg:92.69ms +step:915/1670 train_time:84814ms step_avg:92.69ms +step:916/1670 train_time:84906ms step_avg:92.69ms +step:917/1670 train_time:84998ms step_avg:92.69ms +step:918/1670 train_time:85089ms step_avg:92.69ms +step:919/1670 train_time:85182ms step_avg:92.69ms +step:920/1670 train_time:85275ms step_avg:92.69ms +step:921/1670 train_time:85369ms step_avg:92.69ms 
+step:922/1670 train_time:85461ms step_avg:92.69ms +step:923/1670 train_time:85554ms step_avg:92.69ms +step:924/1670 train_time:85646ms step_avg:92.69ms +step:925/1670 train_time:85738ms step_avg:92.69ms +step:926/1670 train_time:85830ms step_avg:92.69ms +step:927/1670 train_time:85923ms step_avg:92.69ms +step:928/1670 train_time:86015ms step_avg:92.69ms +step:929/1670 train_time:86108ms step_avg:92.69ms +step:930/1670 train_time:86201ms step_avg:92.69ms +step:931/1670 train_time:86293ms step_avg:92.69ms +step:932/1670 train_time:86385ms step_avg:92.69ms +step:933/1670 train_time:86478ms step_avg:92.69ms +step:934/1670 train_time:86571ms step_avg:92.69ms +step:935/1670 train_time:86664ms step_avg:92.69ms +step:936/1670 train_time:86755ms step_avg:92.69ms +step:937/1670 train_time:86849ms step_avg:92.69ms +step:938/1670 train_time:86941ms step_avg:92.69ms +step:939/1670 train_time:87033ms step_avg:92.69ms +step:940/1670 train_time:87126ms step_avg:92.69ms +step:941/1670 train_time:87218ms step_avg:92.69ms +step:942/1670 train_time:87311ms step_avg:92.69ms +step:943/1670 train_time:87403ms step_avg:92.69ms +step:944/1670 train_time:87495ms step_avg:92.69ms +step:945/1670 train_time:87588ms step_avg:92.69ms +step:946/1670 train_time:87681ms step_avg:92.69ms +step:947/1670 train_time:87774ms step_avg:92.69ms +step:948/1670 train_time:87866ms step_avg:92.69ms +step:949/1670 train_time:87957ms step_avg:92.68ms +step:950/1670 train_time:88051ms step_avg:92.68ms +step:951/1670 train_time:88143ms step_avg:92.68ms +step:952/1670 train_time:88235ms step_avg:92.68ms +step:953/1670 train_time:88329ms step_avg:92.68ms +step:954/1670 train_time:88421ms step_avg:92.68ms +step:955/1670 train_time:88513ms step_avg:92.68ms +step:956/1670 train_time:88606ms step_avg:92.68ms +step:957/1670 train_time:88698ms step_avg:92.68ms +step:958/1670 train_time:88790ms step_avg:92.68ms +step:959/1670 train_time:88883ms step_avg:92.68ms +step:960/1670 train_time:88975ms step_avg:92.68ms +step:961/1670 train_time:89067ms step_avg:92.68ms +step:962/1670 train_time:89158ms step_avg:92.68ms +step:963/1670 train_time:89252ms step_avg:92.68ms +step:964/1670 train_time:89346ms step_avg:92.68ms +step:965/1670 train_time:89438ms step_avg:92.68ms +step:966/1670 train_time:89530ms step_avg:92.68ms +step:967/1670 train_time:89623ms step_avg:92.68ms +step:968/1670 train_time:89715ms step_avg:92.68ms +step:969/1670 train_time:89807ms step_avg:92.68ms +step:970/1670 train_time:89900ms step_avg:92.68ms +step:971/1670 train_time:89991ms step_avg:92.68ms +step:972/1670 train_time:90084ms step_avg:92.68ms +step:973/1670 train_time:90175ms step_avg:92.68ms +step:974/1670 train_time:90269ms step_avg:92.68ms +step:975/1670 train_time:90361ms step_avg:92.68ms +step:976/1670 train_time:90454ms step_avg:92.68ms +step:977/1670 train_time:90547ms step_avg:92.68ms +step:978/1670 train_time:90638ms step_avg:92.68ms +step:979/1670 train_time:90731ms step_avg:92.68ms +step:980/1670 train_time:90824ms step_avg:92.68ms +step:981/1670 train_time:90917ms step_avg:92.68ms +step:982/1670 train_time:91011ms step_avg:92.68ms +step:983/1670 train_time:91104ms step_avg:92.68ms +step:984/1670 train_time:91196ms step_avg:92.68ms +step:985/1670 train_time:91289ms step_avg:92.68ms +step:986/1670 train_time:91381ms step_avg:92.68ms +step:987/1670 train_time:91474ms step_avg:92.68ms +step:988/1670 train_time:91566ms step_avg:92.68ms +step:989/1670 train_time:91658ms step_avg:92.68ms +step:990/1670 train_time:91751ms step_avg:92.68ms +step:991/1670 train_time:91843ms 
step_avg:92.68ms +step:992/1670 train_time:91935ms step_avg:92.68ms +step:993/1670 train_time:92028ms step_avg:92.68ms +step:994/1670 train_time:92120ms step_avg:92.68ms +step:995/1670 train_time:92213ms step_avg:92.68ms +step:996/1670 train_time:92305ms step_avg:92.68ms +step:997/1670 train_time:92397ms step_avg:92.67ms +step:998/1670 train_time:92490ms step_avg:92.67ms +step:999/1670 train_time:92583ms step_avg:92.68ms +step:1000/1670 train_time:92675ms step_avg:92.67ms +step:1000/1670 val_loss:3.4660 train_time:92766ms step_avg:92.77ms +step:1001/1670 train_time:92786ms step_avg:92.69ms +step:1002/1670 train_time:92859ms step_avg:92.67ms +step:1003/1670 train_time:92950ms step_avg:92.67ms +step:1004/1670 train_time:93042ms step_avg:92.67ms +step:1005/1670 train_time:93133ms step_avg:92.67ms +step:1006/1670 train_time:93225ms step_avg:92.67ms +step:1007/1670 train_time:93317ms step_avg:92.67ms +step:1008/1670 train_time:93409ms step_avg:92.67ms +step:1009/1670 train_time:93501ms step_avg:92.67ms +step:1010/1670 train_time:93592ms step_avg:92.67ms +step:1011/1670 train_time:93686ms step_avg:92.67ms +step:1012/1670 train_time:93781ms step_avg:92.67ms +step:1013/1670 train_time:93875ms step_avg:92.67ms +step:1014/1670 train_time:93967ms step_avg:92.67ms +step:1015/1670 train_time:94059ms step_avg:92.67ms +step:1016/1670 train_time:94150ms step_avg:92.67ms +step:1017/1670 train_time:94242ms step_avg:92.67ms +step:1018/1670 train_time:94335ms step_avg:92.67ms +step:1019/1670 train_time:94427ms step_avg:92.67ms +step:1020/1670 train_time:94519ms step_avg:92.67ms +step:1021/1670 train_time:94611ms step_avg:92.67ms +step:1022/1670 train_time:94704ms step_avg:92.67ms +step:1023/1670 train_time:94801ms step_avg:92.67ms +step:1024/1670 train_time:94893ms step_avg:92.67ms +step:1025/1670 train_time:94986ms step_avg:92.67ms +step:1026/1670 train_time:95078ms step_avg:92.67ms +step:1027/1670 train_time:95171ms step_avg:92.67ms +step:1028/1670 train_time:95263ms step_avg:92.67ms +step:1029/1670 train_time:95354ms step_avg:92.67ms +step:1030/1670 train_time:95446ms step_avg:92.67ms +step:1031/1670 train_time:95538ms step_avg:92.67ms +step:1032/1670 train_time:95630ms step_avg:92.66ms +step:1033/1670 train_time:95722ms step_avg:92.66ms +step:1034/1670 train_time:95815ms step_avg:92.66ms +step:1035/1670 train_time:95909ms step_avg:92.67ms +step:1036/1670 train_time:96003ms step_avg:92.67ms +step:1037/1670 train_time:96095ms step_avg:92.67ms +step:1038/1670 train_time:96188ms step_avg:92.67ms +step:1039/1670 train_time:96279ms step_avg:92.67ms +step:1040/1670 train_time:96371ms step_avg:92.66ms +step:1041/1670 train_time:96464ms step_avg:92.66ms +step:1042/1670 train_time:96555ms step_avg:92.66ms +step:1043/1670 train_time:96649ms step_avg:92.66ms +step:1044/1670 train_time:96742ms step_avg:92.66ms +step:1045/1670 train_time:96834ms step_avg:92.66ms +step:1046/1670 train_time:96928ms step_avg:92.67ms +step:1047/1670 train_time:97020ms step_avg:92.66ms +step:1048/1670 train_time:97112ms step_avg:92.66ms +step:1049/1670 train_time:97205ms step_avg:92.66ms +step:1050/1670 train_time:97297ms step_avg:92.66ms +step:1051/1670 train_time:97389ms step_avg:92.66ms +step:1052/1670 train_time:97481ms step_avg:92.66ms +step:1053/1670 train_time:97573ms step_avg:92.66ms +step:1054/1670 train_time:97666ms step_avg:92.66ms +step:1055/1670 train_time:97759ms step_avg:92.66ms +step:1056/1670 train_time:97852ms step_avg:92.66ms +step:1057/1670 train_time:97946ms step_avg:92.66ms +step:1058/1670 train_time:98038ms 
step_avg:92.66ms +step:1059/1670 train_time:98130ms step_avg:92.66ms +step:1060/1670 train_time:98222ms step_avg:92.66ms +step:1061/1670 train_time:98314ms step_avg:92.66ms +step:1062/1670 train_time:98566ms step_avg:92.81ms +step:1063/1670 train_time:98635ms step_avg:92.79ms +step:1064/1670 train_time:98726ms step_avg:92.79ms +step:1065/1670 train_time:98817ms step_avg:92.79ms +step:1066/1670 train_time:98908ms step_avg:92.78ms +step:1067/1670 train_time:98999ms step_avg:92.78ms +step:1068/1670 train_time:99091ms step_avg:92.78ms +step:1069/1670 train_time:99182ms step_avg:92.78ms +step:1070/1670 train_time:99273ms step_avg:92.78ms +step:1071/1670 train_time:99364ms step_avg:92.78ms +step:1072/1670 train_time:99462ms step_avg:92.78ms +step:1073/1670 train_time:99559ms step_avg:92.79ms +step:1074/1670 train_time:99651ms step_avg:92.79ms +step:1075/1670 train_time:99744ms step_avg:92.78ms +step:1076/1670 train_time:99836ms step_avg:92.78ms +step:1077/1670 train_time:99927ms step_avg:92.78ms +step:1078/1670 train_time:100018ms step_avg:92.78ms +step:1079/1670 train_time:100110ms step_avg:92.78ms +step:1080/1670 train_time:100200ms step_avg:92.78ms +step:1081/1670 train_time:100291ms step_avg:92.78ms +step:1082/1670 train_time:100384ms step_avg:92.78ms +step:1083/1670 train_time:100479ms step_avg:92.78ms +step:1084/1670 train_time:100574ms step_avg:92.78ms +step:1085/1670 train_time:100668ms step_avg:92.78ms +step:1086/1670 train_time:100762ms step_avg:92.78ms +step:1087/1670 train_time:100853ms step_avg:92.78ms +step:1088/1670 train_time:100944ms step_avg:92.78ms +step:1089/1670 train_time:101036ms step_avg:92.78ms +step:1090/1670 train_time:101127ms step_avg:92.78ms +step:1091/1670 train_time:101219ms step_avg:92.78ms +step:1092/1670 train_time:101311ms step_avg:92.78ms +step:1093/1670 train_time:101404ms step_avg:92.78ms +step:1094/1670 train_time:101499ms step_avg:92.78ms +step:1095/1670 train_time:101592ms step_avg:92.78ms +step:1096/1670 train_time:101686ms step_avg:92.78ms +step:1097/1670 train_time:101779ms step_avg:92.78ms +step:1098/1670 train_time:101871ms step_avg:92.78ms +step:1099/1670 train_time:101964ms step_avg:92.78ms +step:1100/1670 train_time:102055ms step_avg:92.78ms +step:1101/1670 train_time:102147ms step_avg:92.78ms +step:1102/1670 train_time:102239ms step_avg:92.78ms +step:1103/1670 train_time:102331ms step_avg:92.78ms +step:1104/1670 train_time:102428ms step_avg:92.78ms +step:1105/1670 train_time:102520ms step_avg:92.78ms +step:1106/1670 train_time:102612ms step_avg:92.78ms +step:1107/1670 train_time:102706ms step_avg:92.78ms +step:1108/1670 train_time:102798ms step_avg:92.78ms +step:1109/1670 train_time:102890ms step_avg:92.78ms +step:1110/1670 train_time:102982ms step_avg:92.78ms +step:1111/1670 train_time:103074ms step_avg:92.78ms +step:1112/1670 train_time:103166ms step_avg:92.78ms +step:1113/1670 train_time:103258ms step_avg:92.77ms +step:1114/1670 train_time:103351ms step_avg:92.77ms +step:1115/1670 train_time:103640ms step_avg:92.95ms +step:1116/1670 train_time:103709ms step_avg:92.93ms +step:1117/1670 train_time:103800ms step_avg:92.93ms +step:1118/1670 train_time:103892ms step_avg:92.93ms +step:1119/1670 train_time:103984ms step_avg:92.93ms +step:1120/1670 train_time:104075ms step_avg:92.92ms +step:1121/1670 train_time:104167ms step_avg:92.92ms +step:1122/1670 train_time:104259ms step_avg:92.92ms +step:1123/1670 train_time:104351ms step_avg:92.92ms +step:1124/1670 train_time:104443ms step_avg:92.92ms +step:1125/1670 train_time:104542ms step_avg:92.93ms 
+step:1125/1670 val_loss:3.4139 train_time:104641ms step_avg:93.01ms +step:1126/1670 train_time:104662ms step_avg:92.95ms +step:1127/1670 train_time:104737ms step_avg:92.93ms +step:1128/1670 train_time:104836ms step_avg:92.94ms +step:1129/1670 train_time:104929ms step_avg:92.94ms +step:1130/1670 train_time:105022ms step_avg:92.94ms +step:1131/1670 train_time:105113ms step_avg:92.94ms +step:1132/1670 train_time:105205ms step_avg:92.94ms +step:1133/1670 train_time:105297ms step_avg:92.94ms +step:1134/1670 train_time:105388ms step_avg:92.93ms +step:1135/1670 train_time:105481ms step_avg:92.93ms +step:1136/1670 train_time:105576ms step_avg:92.94ms +step:1137/1670 train_time:105671ms step_avg:92.94ms +step:1138/1670 train_time:105768ms step_avg:92.94ms +step:1139/1670 train_time:105864ms step_avg:92.94ms +step:1140/1670 train_time:105956ms step_avg:92.94ms +step:1141/1670 train_time:106049ms step_avg:92.94ms +step:1142/1670 train_time:106142ms step_avg:92.94ms +step:1143/1670 train_time:106235ms step_avg:92.94ms +step:1144/1670 train_time:106326ms step_avg:92.94ms +step:1145/1670 train_time:106419ms step_avg:92.94ms +step:1146/1670 train_time:106512ms step_avg:92.94ms +step:1147/1670 train_time:106605ms step_avg:92.94ms +step:1148/1670 train_time:106699ms step_avg:92.94ms +step:1149/1670 train_time:106793ms step_avg:92.94ms +step:1150/1670 train_time:106887ms step_avg:92.95ms +step:1151/1670 train_time:106981ms step_avg:92.95ms +step:1152/1670 train_time:107074ms step_avg:92.95ms +step:1153/1670 train_time:107167ms step_avg:92.95ms +step:1154/1670 train_time:107259ms step_avg:92.95ms +step:1155/1670 train_time:107351ms step_avg:92.94ms +step:1156/1670 train_time:107444ms step_avg:92.94ms +step:1157/1670 train_time:107538ms step_avg:92.95ms +step:1158/1670 train_time:107630ms step_avg:92.94ms +step:1159/1670 train_time:107724ms step_avg:92.95ms +step:1160/1670 train_time:107818ms step_avg:92.95ms +step:1161/1670 train_time:107911ms step_avg:92.95ms +step:1162/1670 train_time:108005ms step_avg:92.95ms +step:1163/1670 train_time:108098ms step_avg:92.95ms +step:1164/1670 train_time:108191ms step_avg:92.95ms +step:1165/1670 train_time:108283ms step_avg:92.95ms +step:1166/1670 train_time:108376ms step_avg:92.95ms +step:1167/1670 train_time:108469ms step_avg:92.95ms +step:1168/1670 train_time:108562ms step_avg:92.95ms +step:1169/1670 train_time:108655ms step_avg:92.95ms +step:1170/1670 train_time:108748ms step_avg:92.95ms +step:1171/1670 train_time:108843ms step_avg:92.95ms +step:1172/1670 train_time:108937ms step_avg:92.95ms +step:1173/1670 train_time:109029ms step_avg:92.95ms +step:1174/1670 train_time:109122ms step_avg:92.95ms +step:1175/1670 train_time:109215ms step_avg:92.95ms +step:1176/1670 train_time:109308ms step_avg:92.95ms +step:1177/1670 train_time:109400ms step_avg:92.95ms +step:1178/1670 train_time:109493ms step_avg:92.95ms +step:1179/1670 train_time:109587ms step_avg:92.95ms +step:1180/1670 train_time:109681ms step_avg:92.95ms +step:1181/1670 train_time:109774ms step_avg:92.95ms +step:1182/1670 train_time:109868ms step_avg:92.95ms +step:1183/1670 train_time:109962ms step_avg:92.95ms +step:1184/1670 train_time:110054ms step_avg:92.95ms +step:1185/1670 train_time:110146ms step_avg:92.95ms +step:1186/1670 train_time:110240ms step_avg:92.95ms +step:1187/1670 train_time:110332ms step_avg:92.95ms +step:1188/1670 train_time:110425ms step_avg:92.95ms +step:1189/1670 train_time:110518ms step_avg:92.95ms +step:1190/1670 train_time:110611ms step_avg:92.95ms +step:1191/1670 train_time:110705ms 
step_avg:92.95ms +step:1192/1670 train_time:110799ms step_avg:92.95ms +step:1193/1670 train_time:110891ms step_avg:92.95ms +step:1194/1670 train_time:110986ms step_avg:92.95ms +step:1195/1670 train_time:111081ms step_avg:92.95ms +step:1196/1670 train_time:111174ms step_avg:92.95ms +step:1197/1670 train_time:111269ms step_avg:92.96ms +step:1198/1670 train_time:111361ms step_avg:92.96ms +step:1199/1670 train_time:111453ms step_avg:92.95ms +step:1200/1670 train_time:111547ms step_avg:92.96ms +step:1201/1670 train_time:111640ms step_avg:92.96ms +step:1202/1670 train_time:111733ms step_avg:92.96ms +step:1203/1670 train_time:111827ms step_avg:92.96ms +step:1204/1670 train_time:111920ms step_avg:92.96ms +step:1205/1670 train_time:112013ms step_avg:92.96ms +step:1206/1670 train_time:112107ms step_avg:92.96ms +step:1207/1670 train_time:112200ms step_avg:92.96ms +step:1208/1670 train_time:112293ms step_avg:92.96ms +step:1209/1670 train_time:112387ms step_avg:92.96ms +step:1210/1670 train_time:112479ms step_avg:92.96ms +step:1211/1670 train_time:112572ms step_avg:92.96ms +step:1212/1670 train_time:112667ms step_avg:92.96ms +step:1213/1670 train_time:112760ms step_avg:92.96ms +step:1214/1670 train_time:112853ms step_avg:92.96ms +step:1215/1670 train_time:112947ms step_avg:92.96ms +step:1216/1670 train_time:113040ms step_avg:92.96ms +step:1217/1670 train_time:113134ms step_avg:92.96ms +step:1218/1670 train_time:113227ms step_avg:92.96ms +step:1219/1670 train_time:113320ms step_avg:92.96ms +step:1220/1670 train_time:113413ms step_avg:92.96ms +step:1221/1670 train_time:113506ms step_avg:92.96ms +step:1222/1670 train_time:113599ms step_avg:92.96ms +step:1223/1670 train_time:113691ms step_avg:92.96ms +step:1224/1670 train_time:113784ms step_avg:92.96ms +step:1225/1670 train_time:113878ms step_avg:92.96ms +step:1226/1670 train_time:113971ms step_avg:92.96ms +step:1227/1670 train_time:114065ms step_avg:92.96ms +step:1228/1670 train_time:114158ms step_avg:92.96ms +step:1229/1670 train_time:114250ms step_avg:92.96ms +step:1230/1670 train_time:114344ms step_avg:92.96ms +step:1231/1670 train_time:114436ms step_avg:92.96ms +step:1232/1670 train_time:114529ms step_avg:92.96ms +step:1233/1670 train_time:114622ms step_avg:92.96ms +step:1234/1670 train_time:114715ms step_avg:92.96ms +step:1235/1670 train_time:114809ms step_avg:92.96ms +step:1236/1670 train_time:114903ms step_avg:92.96ms +step:1237/1670 train_time:114995ms step_avg:92.96ms +step:1238/1670 train_time:115088ms step_avg:92.96ms +step:1239/1670 train_time:115180ms step_avg:92.96ms +step:1240/1670 train_time:115274ms step_avg:92.96ms +step:1241/1670 train_time:115368ms step_avg:92.96ms +step:1242/1670 train_time:115461ms step_avg:92.96ms +step:1243/1670 train_time:115554ms step_avg:92.96ms +step:1244/1670 train_time:115648ms step_avg:92.96ms +step:1245/1670 train_time:115741ms step_avg:92.96ms +step:1246/1670 train_time:115833ms step_avg:92.96ms +step:1247/1670 train_time:115928ms step_avg:92.97ms +step:1248/1670 train_time:116021ms step_avg:92.97ms +step:1249/1670 train_time:116114ms step_avg:92.97ms +step:1250/1670 train_time:116207ms step_avg:92.97ms +step:1250/1670 val_loss:3.3747 train_time:116299ms step_avg:93.04ms +step:1251/1670 train_time:116319ms step_avg:92.98ms +step:1252/1670 train_time:116393ms step_avg:92.97ms +step:1253/1670 train_time:116488ms step_avg:92.97ms +step:1254/1670 train_time:116580ms step_avg:92.97ms +step:1255/1670 train_time:116672ms step_avg:92.97ms +step:1256/1670 train_time:116764ms step_avg:92.96ms +step:1257/1670 
train_time:116856ms step_avg:92.96ms +step:1258/1670 train_time:116949ms step_avg:92.96ms +step:1259/1670 train_time:117042ms step_avg:92.96ms +step:1260/1670 train_time:117135ms step_avg:92.96ms +step:1261/1670 train_time:117230ms step_avg:92.97ms +step:1262/1670 train_time:117325ms step_avg:92.97ms +step:1263/1670 train_time:117419ms step_avg:92.97ms +step:1264/1670 train_time:117513ms step_avg:92.97ms +step:1265/1670 train_time:117605ms step_avg:92.97ms +step:1266/1670 train_time:117697ms step_avg:92.97ms +step:1267/1670 train_time:117791ms step_avg:92.97ms +step:1268/1670 train_time:117884ms step_avg:92.97ms +step:1269/1670 train_time:117976ms step_avg:92.97ms +step:1270/1670 train_time:118070ms step_avg:92.97ms +step:1271/1670 train_time:118163ms step_avg:92.97ms +step:1272/1670 train_time:118257ms step_avg:92.97ms +step:1273/1670 train_time:118353ms step_avg:92.97ms +step:1274/1670 train_time:118601ms step_avg:93.09ms +step:1275/1670 train_time:118672ms step_avg:93.08ms +step:1276/1670 train_time:118763ms step_avg:93.07ms +step:1277/1670 train_time:118855ms step_avg:93.07ms +step:1278/1670 train_time:118947ms step_avg:93.07ms +step:1279/1670 train_time:119039ms step_avg:93.07ms +step:1280/1670 train_time:119131ms step_avg:93.07ms +step:1281/1670 train_time:119222ms step_avg:93.07ms +step:1282/1670 train_time:119315ms step_avg:93.07ms +step:1283/1670 train_time:119407ms step_avg:93.07ms +step:1284/1670 train_time:119503ms step_avg:93.07ms +step:1285/1670 train_time:119599ms step_avg:93.07ms +step:1286/1670 train_time:119694ms step_avg:93.07ms +step:1287/1670 train_time:119787ms step_avg:93.07ms +step:1288/1670 train_time:119880ms step_avg:93.07ms +step:1289/1670 train_time:119972ms step_avg:93.07ms +step:1290/1670 train_time:120064ms step_avg:93.07ms +step:1291/1670 train_time:120156ms step_avg:93.07ms +step:1292/1670 train_time:120249ms step_avg:93.07ms +step:1293/1670 train_time:120340ms step_avg:93.07ms +step:1294/1670 train_time:120434ms step_avg:93.07ms +step:1295/1670 train_time:120529ms step_avg:93.07ms +step:1296/1670 train_time:120625ms step_avg:93.07ms +step:1297/1670 train_time:120718ms step_avg:93.07ms +step:1298/1670 train_time:120811ms step_avg:93.07ms +step:1299/1670 train_time:120904ms step_avg:93.07ms +step:1300/1670 train_time:120996ms step_avg:93.07ms +step:1301/1670 train_time:121090ms step_avg:93.07ms +step:1302/1670 train_time:121182ms step_avg:93.07ms +step:1303/1670 train_time:121274ms step_avg:93.07ms +step:1304/1670 train_time:121366ms step_avg:93.07ms +step:1305/1670 train_time:121459ms step_avg:93.07ms +step:1306/1670 train_time:121553ms step_avg:93.07ms +step:1307/1670 train_time:121648ms step_avg:93.07ms +step:1308/1670 train_time:121741ms step_avg:93.07ms +step:1309/1670 train_time:121834ms step_avg:93.07ms +step:1310/1670 train_time:121928ms step_avg:93.07ms +step:1311/1670 train_time:122020ms step_avg:93.07ms +step:1312/1670 train_time:122113ms step_avg:93.07ms +step:1313/1670 train_time:122206ms step_avg:93.07ms +step:1314/1670 train_time:122298ms step_avg:93.07ms +step:1315/1670 train_time:122393ms step_avg:93.07ms +step:1316/1670 train_time:122487ms step_avg:93.07ms +step:1317/1670 train_time:122581ms step_avg:93.08ms +step:1318/1670 train_time:122675ms step_avg:93.08ms +step:1319/1670 train_time:122768ms step_avg:93.08ms +step:1320/1670 train_time:122861ms step_avg:93.08ms +step:1321/1670 train_time:122954ms step_avg:93.08ms +step:1322/1670 train_time:123047ms step_avg:93.08ms +step:1323/1670 train_time:123139ms step_avg:93.08ms +step:1324/1670 
train_time:123231ms step_avg:93.07ms +step:1325/1670 train_time:123325ms step_avg:93.08ms +step:1326/1670 train_time:123418ms step_avg:93.08ms +step:1327/1670 train_time:123512ms step_avg:93.08ms +step:1328/1670 train_time:123606ms step_avg:93.08ms +step:1329/1670 train_time:123699ms step_avg:93.08ms +step:1330/1670 train_time:123794ms step_avg:93.08ms +step:1331/1670 train_time:123887ms step_avg:93.08ms +step:1332/1670 train_time:123980ms step_avg:93.08ms +step:1333/1670 train_time:124074ms step_avg:93.08ms +step:1334/1670 train_time:124167ms step_avg:93.08ms +step:1335/1670 train_time:124259ms step_avg:93.08ms +step:1336/1670 train_time:124353ms step_avg:93.08ms +step:1337/1670 train_time:124446ms step_avg:93.08ms +step:1338/1670 train_time:124539ms step_avg:93.08ms +step:1339/1670 train_time:124634ms step_avg:93.08ms +step:1340/1670 train_time:124728ms step_avg:93.08ms +step:1341/1670 train_time:124821ms step_avg:93.08ms +step:1342/1670 train_time:124914ms step_avg:93.08ms +step:1343/1670 train_time:125008ms step_avg:93.08ms +step:1344/1670 train_time:125101ms step_avg:93.08ms +step:1345/1670 train_time:125195ms step_avg:93.08ms +step:1346/1670 train_time:125288ms step_avg:93.08ms +step:1347/1670 train_time:125380ms step_avg:93.08ms +step:1348/1670 train_time:125474ms step_avg:93.08ms +step:1349/1670 train_time:125567ms step_avg:93.08ms +step:1350/1670 train_time:125661ms step_avg:93.08ms +step:1351/1670 train_time:125754ms step_avg:93.08ms +step:1352/1670 train_time:125847ms step_avg:93.08ms +step:1353/1670 train_time:125939ms step_avg:93.08ms +step:1354/1670 train_time:126033ms step_avg:93.08ms +step:1355/1670 train_time:126126ms step_avg:93.08ms +step:1356/1670 train_time:126218ms step_avg:93.08ms +step:1357/1670 train_time:126311ms step_avg:93.08ms +step:1358/1670 train_time:126405ms step_avg:93.08ms +step:1359/1670 train_time:126497ms step_avg:93.08ms +step:1360/1670 train_time:126591ms step_avg:93.08ms +step:1361/1670 train_time:126685ms step_avg:93.08ms +step:1362/1670 train_time:126777ms step_avg:93.08ms +step:1363/1670 train_time:126871ms step_avg:93.08ms +step:1364/1670 train_time:126964ms step_avg:93.08ms +step:1365/1670 train_time:127057ms step_avg:93.08ms +step:1366/1670 train_time:127151ms step_avg:93.08ms +step:1367/1670 train_time:127244ms step_avg:93.08ms +step:1368/1670 train_time:127336ms step_avg:93.08ms +step:1369/1670 train_time:127430ms step_avg:93.08ms +step:1370/1670 train_time:127523ms step_avg:93.08ms +step:1371/1670 train_time:127616ms step_avg:93.08ms +step:1372/1670 train_time:127710ms step_avg:93.08ms +step:1373/1670 train_time:127803ms step_avg:93.08ms +step:1374/1670 train_time:127897ms step_avg:93.08ms +step:1375/1670 train_time:127990ms step_avg:93.08ms +step:1375/1670 val_loss:3.3404 train_time:128082ms step_avg:93.15ms +step:1376/1670 train_time:128102ms step_avg:93.10ms +step:1377/1670 train_time:128177ms step_avg:93.08ms +step:1378/1670 train_time:128271ms step_avg:93.08ms +step:1379/1670 train_time:128365ms step_avg:93.09ms +step:1380/1670 train_time:128458ms step_avg:93.09ms +step:1381/1670 train_time:128549ms step_avg:93.08ms +step:1382/1670 train_time:128643ms step_avg:93.08ms +step:1383/1670 train_time:128737ms step_avg:93.09ms +step:1384/1670 train_time:128829ms step_avg:93.08ms +step:1385/1670 train_time:128922ms step_avg:93.08ms +step:1386/1670 train_time:129017ms step_avg:93.09ms +step:1387/1670 train_time:129112ms step_avg:93.09ms +step:1388/1670 train_time:129207ms step_avg:93.09ms +step:1389/1670 train_time:129300ms step_avg:93.09ms 
+step:1390/1670 train_time:129392ms step_avg:93.09ms +step:1391/1670 train_time:129485ms step_avg:93.09ms +step:1392/1670 train_time:129578ms step_avg:93.09ms +step:1393/1670 train_time:129670ms step_avg:93.09ms +step:1394/1670 train_time:129763ms step_avg:93.09ms +step:1395/1670 train_time:129855ms step_avg:93.09ms +step:1396/1670 train_time:129949ms step_avg:93.09ms +step:1397/1670 train_time:130043ms step_avg:93.09ms +step:1398/1670 train_time:130137ms step_avg:93.09ms +step:1399/1670 train_time:130230ms step_avg:93.09ms +step:1400/1670 train_time:130323ms step_avg:93.09ms +step:1401/1670 train_time:130416ms step_avg:93.09ms +step:1402/1670 train_time:130510ms step_avg:93.09ms +step:1403/1670 train_time:130604ms step_avg:93.09ms +step:1404/1670 train_time:130696ms step_avg:93.09ms +step:1405/1670 train_time:130789ms step_avg:93.09ms +step:1406/1670 train_time:130882ms step_avg:93.09ms +step:1407/1670 train_time:130975ms step_avg:93.09ms +step:1408/1670 train_time:131068ms step_avg:93.09ms +step:1409/1670 train_time:131162ms step_avg:93.09ms +step:1410/1670 train_time:131255ms step_avg:93.09ms +step:1411/1670 train_time:131347ms step_avg:93.09ms +step:1412/1670 train_time:131441ms step_avg:93.09ms +step:1413/1670 train_time:131536ms step_avg:93.09ms +step:1414/1670 train_time:131628ms step_avg:93.09ms +step:1415/1670 train_time:131721ms step_avg:93.09ms +step:1416/1670 train_time:131814ms step_avg:93.09ms +step:1417/1670 train_time:131907ms step_avg:93.09ms +step:1418/1670 train_time:132000ms step_avg:93.09ms +step:1419/1670 train_time:132093ms step_avg:93.09ms +step:1420/1670 train_time:132188ms step_avg:93.09ms +step:1421/1670 train_time:132281ms step_avg:93.09ms +step:1422/1670 train_time:132374ms step_avg:93.09ms +step:1423/1670 train_time:132469ms step_avg:93.09ms +step:1424/1670 train_time:132562ms step_avg:93.09ms +step:1425/1670 train_time:132654ms step_avg:93.09ms +step:1426/1670 train_time:132747ms step_avg:93.09ms +step:1427/1670 train_time:132840ms step_avg:93.09ms +step:1428/1670 train_time:132933ms step_avg:93.09ms +step:1429/1670 train_time:133027ms step_avg:93.09ms +step:1430/1670 train_time:133121ms step_avg:93.09ms +step:1431/1670 train_time:133213ms step_avg:93.09ms +step:1432/1670 train_time:133307ms step_avg:93.09ms +step:1433/1670 train_time:133400ms step_avg:93.09ms +step:1434/1670 train_time:133492ms step_avg:93.09ms +step:1435/1670 train_time:133585ms step_avg:93.09ms +step:1436/1670 train_time:133679ms step_avg:93.09ms +step:1437/1670 train_time:133771ms step_avg:93.09ms +step:1438/1670 train_time:133864ms step_avg:93.09ms +step:1439/1670 train_time:133957ms step_avg:93.09ms +step:1440/1670 train_time:134050ms step_avg:93.09ms +step:1441/1670 train_time:134145ms step_avg:93.09ms +step:1442/1670 train_time:134239ms step_avg:93.09ms +step:1443/1670 train_time:134332ms step_avg:93.09ms +step:1444/1670 train_time:134426ms step_avg:93.09ms +step:1445/1670 train_time:134519ms step_avg:93.09ms +step:1446/1670 train_time:134612ms step_avg:93.09ms +step:1447/1670 train_time:134706ms step_avg:93.09ms +step:1448/1670 train_time:134799ms step_avg:93.09ms +step:1449/1670 train_time:134892ms step_avg:93.09ms +step:1450/1670 train_time:134985ms step_avg:93.09ms +step:1451/1670 train_time:135078ms step_avg:93.09ms +step:1452/1670 train_time:135171ms step_avg:93.09ms +step:1453/1670 train_time:135265ms step_avg:93.09ms +step:1454/1670 train_time:135358ms step_avg:93.09ms +step:1455/1670 train_time:135452ms step_avg:93.09ms +step:1456/1670 train_time:135546ms step_avg:93.09ms 
+step:1457/1670 train_time:135639ms step_avg:93.09ms +step:1458/1670 train_time:135733ms step_avg:93.10ms +step:1459/1670 train_time:135826ms step_avg:93.10ms +step:1460/1670 train_time:135919ms step_avg:93.10ms +step:1461/1670 train_time:136012ms step_avg:93.10ms +step:1462/1670 train_time:136106ms step_avg:93.10ms +step:1463/1670 train_time:136199ms step_avg:93.10ms +step:1464/1670 train_time:136294ms step_avg:93.10ms +step:1465/1670 train_time:136387ms step_avg:93.10ms +step:1466/1670 train_time:136480ms step_avg:93.10ms +step:1467/1670 train_time:136573ms step_avg:93.10ms +step:1468/1670 train_time:136665ms step_avg:93.10ms +step:1469/1670 train_time:136758ms step_avg:93.10ms +step:1470/1670 train_time:136851ms step_avg:93.10ms +step:1471/1670 train_time:136945ms step_avg:93.10ms +step:1472/1670 train_time:137038ms step_avg:93.10ms +step:1473/1670 train_time:137130ms step_avg:93.10ms +step:1474/1670 train_time:137224ms step_avg:93.10ms +step:1475/1670 train_time:137316ms step_avg:93.10ms +step:1476/1670 train_time:137410ms step_avg:93.10ms +step:1477/1670 train_time:137503ms step_avg:93.10ms +step:1478/1670 train_time:137596ms step_avg:93.10ms +step:1479/1670 train_time:137689ms step_avg:93.10ms +step:1480/1670 train_time:137783ms step_avg:93.10ms +step:1481/1670 train_time:137876ms step_avg:93.10ms +step:1482/1670 train_time:137969ms step_avg:93.10ms +step:1483/1670 train_time:138063ms step_avg:93.10ms +step:1484/1670 train_time:138155ms step_avg:93.10ms +step:1485/1670 train_time:138407ms step_avg:93.20ms +step:1486/1670 train_time:138476ms step_avg:93.19ms +step:1487/1670 train_time:138568ms step_avg:93.19ms +step:1488/1670 train_time:138660ms step_avg:93.19ms +step:1489/1670 train_time:138752ms step_avg:93.18ms +step:1490/1670 train_time:138844ms step_avg:93.18ms +step:1491/1670 train_time:138936ms step_avg:93.18ms +step:1492/1670 train_time:139028ms step_avg:93.18ms +step:1493/1670 train_time:139119ms step_avg:93.18ms +step:1494/1670 train_time:139212ms step_avg:93.18ms +step:1495/1670 train_time:139311ms step_avg:93.18ms +step:1496/1670 train_time:139409ms step_avg:93.19ms +step:1497/1670 train_time:139503ms step_avg:93.19ms +step:1498/1670 train_time:139596ms step_avg:93.19ms +step:1499/1670 train_time:139687ms step_avg:93.19ms +step:1500/1670 train_time:139779ms step_avg:93.19ms +step:1500/1670 val_loss:3.3104 train_time:139871ms step_avg:93.25ms +step:1501/1670 train_time:139891ms step_avg:93.20ms +step:1502/1670 train_time:139966ms step_avg:93.19ms +step:1503/1670 train_time:140059ms step_avg:93.19ms +step:1504/1670 train_time:140150ms step_avg:93.19ms +step:1505/1670 train_time:140243ms step_avg:93.18ms +step:1506/1670 train_time:140334ms step_avg:93.18ms +step:1507/1670 train_time:140428ms step_avg:93.18ms +step:1508/1670 train_time:140523ms step_avg:93.18ms +step:1509/1670 train_time:140615ms step_avg:93.18ms +step:1510/1670 train_time:140709ms step_avg:93.18ms +step:1511/1670 train_time:140803ms step_avg:93.19ms +step:1512/1670 train_time:140897ms step_avg:93.19ms +step:1513/1670 train_time:140992ms step_avg:93.19ms +step:1514/1670 train_time:141085ms step_avg:93.19ms +step:1515/1670 train_time:141178ms step_avg:93.19ms +step:1516/1670 train_time:141271ms step_avg:93.19ms +step:1517/1670 train_time:141363ms step_avg:93.19ms +step:1518/1670 train_time:141458ms step_avg:93.19ms +step:1519/1670 train_time:141552ms step_avg:93.19ms +step:1520/1670 train_time:141644ms step_avg:93.19ms +step:1521/1670 train_time:141738ms step_avg:93.19ms +step:1522/1670 train_time:141832ms 
step_avg:93.19ms +step:1523/1670 train_time:141926ms step_avg:93.19ms +step:1524/1670 train_time:142019ms step_avg:93.19ms +step:1525/1670 train_time:142112ms step_avg:93.19ms +step:1526/1670 train_time:142204ms step_avg:93.19ms +step:1527/1670 train_time:142296ms step_avg:93.19ms +step:1528/1670 train_time:142390ms step_avg:93.19ms +step:1529/1670 train_time:142484ms step_avg:93.19ms +step:1530/1670 train_time:142576ms step_avg:93.19ms +step:1531/1670 train_time:142669ms step_avg:93.19ms +step:1532/1670 train_time:142763ms step_avg:93.19ms +step:1533/1670 train_time:142857ms step_avg:93.19ms +step:1534/1670 train_time:142951ms step_avg:93.19ms +step:1535/1670 train_time:143045ms step_avg:93.19ms +step:1536/1670 train_time:143137ms step_avg:93.19ms +step:1537/1670 train_time:143231ms step_avg:93.19ms +step:1538/1670 train_time:143324ms step_avg:93.19ms +step:1539/1670 train_time:143416ms step_avg:93.19ms +step:1540/1670 train_time:143510ms step_avg:93.19ms +step:1541/1670 train_time:143603ms step_avg:93.19ms +step:1542/1670 train_time:143696ms step_avg:93.19ms +step:1543/1670 train_time:143790ms step_avg:93.19ms +step:1544/1670 train_time:143883ms step_avg:93.19ms +step:1545/1670 train_time:143976ms step_avg:93.19ms +step:1546/1670 train_time:144071ms step_avg:93.19ms +step:1547/1670 train_time:144164ms step_avg:93.19ms +step:1548/1670 train_time:144257ms step_avg:93.19ms +step:1549/1670 train_time:144351ms step_avg:93.19ms +step:1550/1670 train_time:144443ms step_avg:93.19ms +step:1551/1670 train_time:144535ms step_avg:93.19ms +step:1552/1670 train_time:144629ms step_avg:93.19ms +step:1553/1670 train_time:144721ms step_avg:93.19ms +step:1554/1670 train_time:144814ms step_avg:93.19ms +step:1555/1670 train_time:144905ms step_avg:93.19ms +step:1556/1670 train_time:144998ms step_avg:93.19ms +step:1557/1670 train_time:145092ms step_avg:93.19ms +step:1558/1670 train_time:145185ms step_avg:93.19ms +step:1559/1670 train_time:145279ms step_avg:93.19ms +step:1560/1670 train_time:145372ms step_avg:93.19ms +step:1561/1670 train_time:145466ms step_avg:93.19ms +step:1562/1670 train_time:145559ms step_avg:93.19ms +step:1563/1670 train_time:145653ms step_avg:93.19ms +step:1564/1670 train_time:145745ms step_avg:93.19ms +step:1565/1670 train_time:145839ms step_avg:93.19ms +step:1566/1670 train_time:145931ms step_avg:93.19ms +step:1567/1670 train_time:146025ms step_avg:93.19ms +step:1568/1670 train_time:146118ms step_avg:93.19ms +step:1569/1670 train_time:146212ms step_avg:93.19ms +step:1570/1670 train_time:146305ms step_avg:93.19ms +step:1571/1670 train_time:146398ms step_avg:93.19ms +step:1572/1670 train_time:146491ms step_avg:93.19ms +step:1573/1670 train_time:146586ms step_avg:93.19ms +step:1574/1670 train_time:146679ms step_avg:93.19ms +step:1575/1670 train_time:146773ms step_avg:93.19ms +step:1576/1670 train_time:146867ms step_avg:93.19ms +step:1577/1670 train_time:146961ms step_avg:93.19ms +step:1578/1670 train_time:147054ms step_avg:93.19ms +step:1579/1670 train_time:147148ms step_avg:93.19ms +step:1580/1670 train_time:147241ms step_avg:93.19ms +step:1581/1670 train_time:147334ms step_avg:93.19ms +step:1582/1670 train_time:147427ms step_avg:93.19ms +step:1583/1670 train_time:147520ms step_avg:93.19ms +step:1584/1670 train_time:147613ms step_avg:93.19ms +step:1585/1670 train_time:147706ms step_avg:93.19ms +step:1586/1670 train_time:147798ms step_avg:93.19ms +step:1587/1670 train_time:147893ms step_avg:93.19ms +step:1588/1670 train_time:147986ms step_avg:93.19ms +step:1589/1670 train_time:148078ms 
step_avg:93.19ms +step:1590/1670 train_time:148172ms step_avg:93.19ms +step:1591/1670 train_time:148266ms step_avg:93.19ms +step:1592/1670 train_time:148358ms step_avg:93.19ms +step:1593/1670 train_time:148452ms step_avg:93.19ms +step:1594/1670 train_time:148546ms step_avg:93.19ms +step:1595/1670 train_time:148638ms step_avg:93.19ms +step:1596/1670 train_time:148731ms step_avg:93.19ms +step:1597/1670 train_time:148825ms step_avg:93.19ms +step:1598/1670 train_time:148918ms step_avg:93.19ms +step:1599/1670 train_time:149010ms step_avg:93.19ms +step:1600/1670 train_time:149103ms step_avg:93.19ms +step:1601/1670 train_time:149196ms step_avg:93.19ms +step:1602/1670 train_time:149291ms step_avg:93.19ms +step:1603/1670 train_time:149384ms step_avg:93.19ms +step:1604/1670 train_time:149477ms step_avg:93.19ms +step:1605/1670 train_time:149571ms step_avg:93.19ms +step:1606/1670 train_time:149664ms step_avg:93.19ms +step:1607/1670 train_time:149757ms step_avg:93.19ms +step:1608/1670 train_time:149852ms step_avg:93.19ms +step:1609/1670 train_time:149946ms step_avg:93.19ms +step:1610/1670 train_time:150038ms step_avg:93.19ms +step:1611/1670 train_time:150130ms step_avg:93.19ms +step:1612/1670 train_time:150224ms step_avg:93.19ms +step:1613/1670 train_time:150316ms step_avg:93.19ms +step:1614/1670 train_time:150410ms step_avg:93.19ms +step:1615/1670 train_time:150503ms step_avg:93.19ms +step:1616/1670 train_time:150596ms step_avg:93.19ms +step:1617/1670 train_time:150689ms step_avg:93.19ms +step:1618/1670 train_time:150784ms step_avg:93.19ms +step:1619/1670 train_time:150876ms step_avg:93.19ms +step:1620/1670 train_time:150969ms step_avg:93.19ms +step:1621/1670 train_time:151063ms step_avg:93.19ms +step:1622/1670 train_time:151155ms step_avg:93.19ms +step:1623/1670 train_time:151250ms step_avg:93.19ms +step:1624/1670 train_time:151343ms step_avg:93.19ms +step:1625/1670 train_time:151436ms step_avg:93.19ms +step:1625/1670 val_loss:3.2856 train_time:151530ms step_avg:93.25ms +step:1626/1670 train_time:151550ms step_avg:93.20ms +step:1627/1670 train_time:151628ms step_avg:93.19ms +step:1628/1670 train_time:151720ms step_avg:93.19ms +step:1629/1670 train_time:151813ms step_avg:93.19ms +step:1630/1670 train_time:151905ms step_avg:93.19ms +step:1631/1670 train_time:151999ms step_avg:93.19ms +step:1632/1670 train_time:152091ms step_avg:93.19ms +step:1633/1670 train_time:152183ms step_avg:93.19ms +step:1634/1670 train_time:152276ms step_avg:93.19ms +step:1635/1670 train_time:152370ms step_avg:93.19ms +step:1636/1670 train_time:152464ms step_avg:93.19ms +step:1637/1670 train_time:152560ms step_avg:93.19ms +step:1638/1670 train_time:152654ms step_avg:93.20ms +step:1639/1670 train_time:152748ms step_avg:93.20ms +step:1640/1670 train_time:152841ms step_avg:93.20ms +step:1641/1670 train_time:152935ms step_avg:93.20ms +step:1642/1670 train_time:153028ms step_avg:93.20ms +step:1643/1670 train_time:153121ms step_avg:93.20ms +step:1644/1670 train_time:153214ms step_avg:93.20ms +step:1645/1670 train_time:153307ms step_avg:93.20ms +step:1646/1670 train_time:153401ms step_avg:93.20ms +step:1647/1670 train_time:153496ms step_avg:93.20ms +step:1648/1670 train_time:153589ms step_avg:93.20ms +step:1649/1670 train_time:153682ms step_avg:93.20ms +step:1650/1670 train_time:153777ms step_avg:93.20ms +step:1651/1670 train_time:153870ms step_avg:93.20ms +step:1652/1670 train_time:153963ms step_avg:93.20ms +step:1653/1670 train_time:154056ms step_avg:93.20ms +step:1654/1670 train_time:154149ms step_avg:93.20ms +step:1655/1670 
train_time:154242ms step_avg:93.20ms
+step:1656/1670 train_time:154336ms step_avg:93.20ms
+step:1657/1670 train_time:154429ms step_avg:93.20ms
+step:1658/1670 train_time:154522ms step_avg:93.20ms
+step:1659/1670 train_time:154616ms step_avg:93.20ms
+step:1660/1670 train_time:154709ms step_avg:93.20ms
+step:1661/1670 train_time:154802ms step_avg:93.20ms
+step:1662/1670 train_time:154897ms step_avg:93.20ms
+step:1663/1670 train_time:154989ms step_avg:93.20ms
+step:1664/1670 train_time:155081ms step_avg:93.20ms
+step:1665/1670 train_time:155174ms step_avg:93.20ms
+step:1666/1670 train_time:155267ms step_avg:93.20ms
+step:1667/1670 train_time:155360ms step_avg:93.20ms
+step:1668/1670 train_time:155454ms step_avg:93.20ms
+step:1669/1670 train_time:155547ms step_avg:93.20ms
+step:1670/1670 train_time:155640ms step_avg:93.20ms
+step:1670/1670 val_loss:3.2770 train_time:155902ms step_avg:93.35ms
+peak memory allocated: 32002 MiB reserved: 47174 MiB
diff --git a/records/091125_VectSigmoidBFloat16/README.md b/records/091125_VectSigmoidBFloat16/README.md
new file mode 100644
index 000000000..efee9cf58
--- /dev/null
+++ b/records/091125_VectSigmoidBFloat16/README.md
@@ -0,0 +1,53 @@
+## New WR 1.25% better than PR #122: Optimize distributed training, improve skip connection gating, and enhance bfloat16 usage
+
+This PR takes all recent improvements, including today's PR #122, and adds the following three ideas on top:
+
+- Replacing Python for-loops in the Muon optimizer with vectorized PyTorch tensor operations, improving gradient sharding, padding, and parameter synchronization (a toy example appears below).
+
+- Casting more tensors and buffers (embeddings, linear layers, optimizer state, positional encodings) to torch.bfloat16. This gives faster experiments with minimal change in model accuracy.
+
+- Applying sigmoid gating to U-Net skip connections, with skip weights initialized to -1.5 for better learnability. Instead of multiplying skip connections directly by a raw trainable parameter (which could be unbounded and unstable), the code now passes the skip weight through a sigmoid function. This constrains the gate value to the range (0, 1), making the effect of each skip connection smoothly adjustable and numerically stable (a minimal sketch appears below).
+
+This improves the runtime by 2 seconds, i.e. 1.25%; see the validation below.
+
+### Validation
+I used an 8 × H100 SXM NVLink 80GB node on RunPod. The results I've been getting when benchmarking PR #122 are a bit better than the ones reported there.
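+Before the numbers, two minimal sketches of the ideas above. First, the sigmoid skip gating from the third bullet. This is illustrative code under assumed names (`GatedSkip`, `raw_weight`); it is not the actual class in this PR:
+
+```
+import torch
+from torch import nn
+
+class GatedSkip(nn.Module):
+    """Sigmoid-gated U-Net skip connection (illustrative sketch)."""
+    def __init__(self, init: float = -1.5):
+        super().__init__()
+        # The raw weight is unconstrained; the gate actually applied is
+        # sigmoid(raw_weight), which lies in (0, 1). With init = -1.5 the
+        # gate starts near sigmoid(-1.5) ~= 0.18: mostly closed, but learnable.
+        self.raw_weight = nn.Parameter(torch.tensor(init))
+
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+        return x + torch.sigmoid(self.raw_weight) * skip
+```
+
+Because the gate is bounded, a badly scaled update cannot blow up the skip path the way an unconstrained multiplier can. Second, the loop-to-vectorized rewrite from the first bullet is, in spirit, the move below. `torch._foreach_mul_` is one of PyTorch's fused multi-tensor primitives; the PR's actual code also handles gradient sharding, padding, and parameter synchronization, which this toy does not show:
+
+```
+import torch
+
+grads = [torch.randn(256, 256) for _ in range(24)]
+
+# Before: one Python-level call per tensor in the list.
+for g in grads:
+    g.mul_(0.9)
+
+# After: a single fused call over the whole tensor list.
+torch._foreach_mul_(grads, 0.9)
+```
+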
So here I present the statistics of both PR #122 and this PR when using that node: + + +Validation for PR #122 +``` +import scipy.stats +import torch + +accs = [3.2798, 3.2798, 3.2829, 3.2785, 3.2783, 3.2787, 3.2787, 3.2784, 3.2821, 3.2794, 3.2786, 3.2765, 3.2794, 3.2776, 3.2778, 3.2774, 3.2777] + +times = [157.977, 157.889, 158.014, 158.103, 158.093, 158.001, 158.089, 157.981, 158.019, 157.963, 158.043, 157.957, 157.880, 157.687, 158.002, 157.947, 158.097] + +print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue) +# p=0.0069 + +print("acc:", torch.std_mean(torch.tensor(accs))) +# acc: (tensor(0.0016), tensor(3.2789)) + +print("time:", torch.std_mean(torch.tensor(times))) +# time: (tensor(0.1021), tensor(157.9848)) +``` + +Validation for PR #123 +``` +import scipy.stats +import torch + +accs = [3.277, 3.2772, 3.2778, 3.2767, 3.2805, 3.2781, 3.2797, 3.2802, 3.2774, 3.2767, 3.2769, 3.2783] + +times = [155.902, 155.956, 156.043, 155.987, 155.980, 155.717, 156.019, 156.077, 156.064, 156.100, 156.129, 155.799] + +print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue) +# p=0.0002 + +print("acc:", torch.std_mean(torch.tensor(accs))) +# acc: (tensor(0.0014), tensor(3.2780)) + +print("time:", torch.std_mean(torch.tensor(times))) +# time: (tensor(0.1233), tensor(155.9811)) +``` \ No newline at end of file diff --git a/records/091125_VectSigmoidBFloat16/a077c741-ce5d-4639-955b-d7a2660b5cf8.txt b/records/091125_VectSigmoidBFloat16/a077c741-ce5d-4639-955b-d7a2660b5cf8.txt new file mode 100644 index 000000000..34b5d6a43 --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/a077c741-ce5d-4639-955b-d7a2660b5cf8.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, 
x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip 
blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx 
+ ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) 
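+    # Buffer roles for the Newton-Schulz iterations below:
+    #   A <- X @ X.mT              (square Gram matrix)
+    #   B <- b * A + c * A @ A     (polynomial term, square)
+    #   C <- a * X + B @ X         (next iterate, same shape as X)
+    # All three are preallocated once and reused; X and C are swapped each
+    # iteration, so the loop itself allocates no new memory.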
+ C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. 
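+            # Muon's shape-dependent LR: scale by sqrt(max(1, rows / cols)) so
+            # the update RMS stays consistent across matrix shapes; weight decay
+            # is applied below as a decoupled multiplier (1 - lr * wd * wd_mul).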
eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                batch = original_shape[0] * original_shape[1]
+                # Flatten all but the first two dims after batch
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
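+                # v_chunk[i] is the orthogonalized momentum update; passing
+                # alpha=-eff_lr_val folds the learning rate into one fused add.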
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
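+        # FP8 scales for the lm_head matmul (see mm_op above): x_s and w_s keep
+        # activations and weights within float8_e4m3fn range (max 448), while
+        # grad_s rescales gradients, which are cast to float8_e5m2 in backward.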
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(
+            torch.bfloat16
+        )  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[
+            1 * len(self.blocks) : 3 * len(self.blocks)
+        ].view(-1, 2)
+        sa_lambdas = self.scalars[
+            3 * len(self.blocks) : 5 * len(self.blocks)
+        ].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(
+        str(file), False, 256, dtype=torch.int32
+    )  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(
+            num_tokens, dtype=torch.uint16, pin_memory=True
+        )  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(
+            tokens.numpy()
+        )  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, (
+            "number of tokens read does not match header"
+        )
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(
+                        f"Insufficient BOS tokens ahead (bos index {idx} of {n}); hit tail of shard."
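+                        # (Caught by distributed_data_generator, which then
+                        # loads the next shard and retries.)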
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:51:37 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 125W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 134W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 46C P0 126W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 37C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 41C P0 132W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.06ms +step:1/1670 train_time:294ms step_avg:294.48ms +step:2/1670 train_time:313ms step_avg:156.32ms +step:3/1670 train_time:382ms step_avg:127.30ms +step:4/1670 train_time:471ms step_avg:117.80ms +step:5/1670 train_time:561ms step_avg:112.24ms +step:6/1670 train_time:652ms step_avg:108.63ms +step:7/1670 train_time:742ms step_avg:106.00ms +step:8/1670 train_time:833ms step_avg:104.12ms +step:9/1670 train_time:923ms step_avg:102.59ms +step:10/1670 train_time:1015ms step_avg:101.48ms +step:11/1670 train_time:1106ms step_avg:100.52ms +step:12/1670 train_time:1200ms step_avg:99.98ms +step:13/1670 train_time:1295ms step_avg:99.59ms +step:14/1670 train_time:1387ms step_avg:99.06ms +step:15/1670 train_time:1479ms step_avg:98.57ms +step:16/1670 train_time:1571ms step_avg:98.17ms +step:17/1670 train_time:1661ms step_avg:97.73ms +step:18/1670 train_time:1754ms step_avg:97.44ms +step:19/1670 train_time:1844ms step_avg:97.07ms +step:20/1670 train_time:1936ms step_avg:96.82ms +step:21/1670 train_time:2028ms step_avg:96.59ms +step:22/1670 train_time:2121ms 
step_avg:96.40ms +step:23/1670 train_time:2215ms step_avg:96.30ms +step:24/1670 train_time:2308ms step_avg:96.16ms +step:25/1670 train_time:2400ms step_avg:96.00ms +step:26/1670 train_time:2491ms step_avg:95.82ms +step:27/1670 train_time:2582ms step_avg:95.64ms +step:28/1670 train_time:2676ms step_avg:95.58ms +step:29/1670 train_time:2767ms step_avg:95.43ms +step:30/1670 train_time:2859ms step_avg:95.31ms +step:31/1670 train_time:2950ms step_avg:95.16ms +step:32/1670 train_time:3041ms step_avg:95.04ms +step:33/1670 train_time:3133ms step_avg:94.94ms +step:34/1670 train_time:3225ms step_avg:94.85ms +step:35/1670 train_time:3318ms step_avg:94.81ms +step:36/1670 train_time:3410ms step_avg:94.72ms +step:37/1670 train_time:3501ms step_avg:94.63ms +step:38/1670 train_time:3592ms step_avg:94.54ms +step:39/1670 train_time:3684ms step_avg:94.46ms +step:40/1670 train_time:3775ms step_avg:94.37ms +step:41/1670 train_time:3866ms step_avg:94.30ms +step:42/1670 train_time:3958ms step_avg:94.25ms +step:43/1670 train_time:4051ms step_avg:94.20ms +step:44/1670 train_time:4143ms step_avg:94.15ms +step:45/1670 train_time:4235ms step_avg:94.12ms +step:46/1670 train_time:4327ms step_avg:94.06ms +step:47/1670 train_time:4419ms step_avg:94.02ms +step:48/1670 train_time:4512ms step_avg:93.99ms +step:49/1670 train_time:4603ms step_avg:93.93ms +step:50/1670 train_time:4694ms step_avg:93.88ms +step:51/1670 train_time:4786ms step_avg:93.84ms +step:52/1670 train_time:4877ms step_avg:93.80ms +step:53/1670 train_time:4968ms step_avg:93.74ms +step:54/1670 train_time:5060ms step_avg:93.71ms +step:55/1670 train_time:5152ms step_avg:93.68ms +step:56/1670 train_time:5243ms step_avg:93.63ms +step:57/1670 train_time:5335ms step_avg:93.60ms +step:58/1670 train_time:5426ms step_avg:93.56ms +step:59/1670 train_time:5519ms step_avg:93.54ms +step:60/1670 train_time:5611ms step_avg:93.52ms +step:61/1670 train_time:5702ms step_avg:93.48ms +step:62/1670 train_time:5794ms step_avg:93.45ms +step:63/1670 train_time:5885ms step_avg:93.42ms +step:64/1670 train_time:5977ms step_avg:93.39ms +step:65/1670 train_time:6068ms step_avg:93.35ms +step:66/1670 train_time:6159ms step_avg:93.32ms +step:67/1670 train_time:6251ms step_avg:93.30ms +step:68/1670 train_time:6342ms step_avg:93.26ms +step:69/1670 train_time:6435ms step_avg:93.25ms +step:70/1670 train_time:6527ms step_avg:93.24ms +step:71/1670 train_time:6618ms step_avg:93.22ms +step:72/1670 train_time:6710ms step_avg:93.19ms +step:73/1670 train_time:6801ms step_avg:93.17ms +step:74/1670 train_time:6894ms step_avg:93.16ms +step:75/1670 train_time:6985ms step_avg:93.14ms +step:76/1670 train_time:7077ms step_avg:93.12ms +step:77/1670 train_time:7169ms step_avg:93.10ms +step:78/1670 train_time:7261ms step_avg:93.09ms +step:79/1670 train_time:7352ms step_avg:93.06ms +step:80/1670 train_time:7444ms step_avg:93.05ms +step:81/1670 train_time:7538ms step_avg:93.06ms +step:82/1670 train_time:7631ms step_avg:93.06ms +step:83/1670 train_time:7722ms step_avg:93.04ms +step:84/1670 train_time:7814ms step_avg:93.03ms +step:85/1670 train_time:7905ms step_avg:93.00ms +step:86/1670 train_time:7996ms step_avg:92.97ms +step:87/1670 train_time:8087ms step_avg:92.95ms +step:88/1670 train_time:8178ms step_avg:92.93ms +step:89/1670 train_time:8269ms step_avg:92.91ms +step:90/1670 train_time:8361ms step_avg:92.90ms +step:91/1670 train_time:8454ms step_avg:92.90ms +step:92/1670 train_time:8545ms step_avg:92.88ms +step:93/1670 train_time:8638ms step_avg:92.88ms +step:94/1670 train_time:8730ms step_avg:92.87ms 
+step:95/1670 train_time:8822ms step_avg:92.86ms +step:96/1670 train_time:8914ms step_avg:92.85ms +step:97/1670 train_time:9005ms step_avg:92.84ms +step:98/1670 train_time:9097ms step_avg:92.82ms +step:99/1670 train_time:9188ms step_avg:92.81ms +step:100/1670 train_time:9279ms step_avg:92.79ms +step:101/1670 train_time:9370ms step_avg:92.77ms +step:102/1670 train_time:9462ms step_avg:92.76ms +step:103/1670 train_time:9555ms step_avg:92.76ms +step:104/1670 train_time:9647ms step_avg:92.76ms +step:105/1670 train_time:9739ms step_avg:92.76ms +step:106/1670 train_time:9831ms step_avg:92.75ms +step:107/1670 train_time:9922ms step_avg:92.73ms +step:108/1670 train_time:10014ms step_avg:92.72ms +step:109/1670 train_time:10105ms step_avg:92.71ms +step:110/1670 train_time:10196ms step_avg:92.69ms +step:111/1670 train_time:10287ms step_avg:92.68ms +step:112/1670 train_time:10379ms step_avg:92.67ms +step:113/1670 train_time:10470ms step_avg:92.65ms +step:114/1670 train_time:10561ms step_avg:92.64ms +step:115/1670 train_time:10653ms step_avg:92.64ms +step:116/1670 train_time:10745ms step_avg:92.63ms +step:117/1670 train_time:10837ms step_avg:92.63ms +step:118/1670 train_time:10929ms step_avg:92.62ms +step:119/1670 train_time:11020ms step_avg:92.60ms +step:120/1670 train_time:11111ms step_avg:92.59ms +step:121/1670 train_time:11202ms step_avg:92.58ms +step:122/1670 train_time:11292ms step_avg:92.56ms +step:123/1670 train_time:11384ms step_avg:92.55ms +step:124/1670 train_time:11474ms step_avg:92.53ms +step:125/1670 train_time:11565ms step_avg:92.52ms +step:125/1670 val_loss:4.3026 train_time:11657ms step_avg:93.25ms +step:126/1670 train_time:11676ms step_avg:92.66ms +step:127/1670 train_time:11752ms step_avg:92.53ms +step:128/1670 train_time:11852ms step_avg:92.59ms +step:129/1670 train_time:11945ms step_avg:92.60ms +step:130/1670 train_time:12038ms step_avg:92.60ms +step:131/1670 train_time:12128ms step_avg:92.58ms +step:132/1670 train_time:12218ms step_avg:92.56ms +step:133/1670 train_time:12308ms step_avg:92.54ms +step:134/1670 train_time:12399ms step_avg:92.53ms +step:135/1670 train_time:12489ms step_avg:92.51ms +step:136/1670 train_time:12580ms step_avg:92.50ms +step:137/1670 train_time:12671ms step_avg:92.49ms +step:138/1670 train_time:12763ms step_avg:92.49ms +step:139/1670 train_time:12857ms step_avg:92.50ms +step:140/1670 train_time:12949ms step_avg:92.49ms +step:141/1670 train_time:13040ms step_avg:92.49ms +step:142/1670 train_time:13132ms step_avg:92.48ms +step:143/1670 train_time:13222ms step_avg:92.46ms +step:144/1670 train_time:13313ms step_avg:92.45ms +step:145/1670 train_time:13403ms step_avg:92.43ms +step:146/1670 train_time:13496ms step_avg:92.43ms +step:147/1670 train_time:13586ms step_avg:92.42ms +step:148/1670 train_time:13677ms step_avg:92.41ms +step:149/1670 train_time:13769ms step_avg:92.41ms +step:150/1670 train_time:13861ms step_avg:92.40ms +step:151/1670 train_time:13953ms step_avg:92.40ms +step:152/1670 train_time:14044ms step_avg:92.39ms +step:153/1670 train_time:14136ms step_avg:92.39ms +step:154/1670 train_time:14227ms step_avg:92.38ms +step:155/1670 train_time:14318ms step_avg:92.38ms +step:156/1670 train_time:14409ms step_avg:92.37ms +step:157/1670 train_time:14500ms step_avg:92.36ms +step:158/1670 train_time:14591ms step_avg:92.35ms +step:159/1670 train_time:14681ms step_avg:92.34ms +step:160/1670 train_time:14772ms step_avg:92.33ms +step:161/1670 train_time:14864ms step_avg:92.32ms +step:162/1670 train_time:14955ms step_avg:92.32ms +step:163/1670 train_time:15046ms 
step_avg:92.31ms +step:164/1670 train_time:15139ms step_avg:92.31ms +step:165/1670 train_time:15231ms step_avg:92.31ms +step:166/1670 train_time:15322ms step_avg:92.30ms +step:167/1670 train_time:15413ms step_avg:92.30ms +step:168/1670 train_time:15503ms step_avg:92.28ms +step:169/1670 train_time:15594ms step_avg:92.27ms +step:170/1670 train_time:15686ms step_avg:92.27ms +step:171/1670 train_time:15776ms step_avg:92.26ms +step:172/1670 train_time:15867ms step_avg:92.25ms +step:173/1670 train_time:15958ms step_avg:92.25ms +step:174/1670 train_time:16050ms step_avg:92.24ms +step:175/1670 train_time:16141ms step_avg:92.24ms +step:176/1670 train_time:16232ms step_avg:92.23ms +step:177/1670 train_time:16323ms step_avg:92.22ms +step:178/1670 train_time:16415ms step_avg:92.22ms +step:179/1670 train_time:16505ms step_avg:92.21ms +step:180/1670 train_time:16597ms step_avg:92.21ms +step:181/1670 train_time:16688ms step_avg:92.20ms +step:182/1670 train_time:16779ms step_avg:92.19ms +step:183/1670 train_time:16870ms step_avg:92.18ms +step:184/1670 train_time:16961ms step_avg:92.18ms +step:185/1670 train_time:17052ms step_avg:92.18ms +step:186/1670 train_time:17143ms step_avg:92.17ms +step:187/1670 train_time:17235ms step_avg:92.17ms +step:188/1670 train_time:17326ms step_avg:92.16ms +step:189/1670 train_time:17419ms step_avg:92.16ms +step:190/1670 train_time:17511ms step_avg:92.16ms +step:191/1670 train_time:17602ms step_avg:92.16ms +step:192/1670 train_time:17695ms step_avg:92.16ms +step:193/1670 train_time:17785ms step_avg:92.15ms +step:194/1670 train_time:17876ms step_avg:92.14ms +step:195/1670 train_time:17967ms step_avg:92.14ms +step:196/1670 train_time:18058ms step_avg:92.13ms +step:197/1670 train_time:18149ms step_avg:92.13ms +step:198/1670 train_time:18240ms step_avg:92.12ms +step:199/1670 train_time:18332ms step_avg:92.12ms +step:200/1670 train_time:18422ms step_avg:92.11ms +step:201/1670 train_time:18514ms step_avg:92.11ms +step:202/1670 train_time:18604ms step_avg:92.10ms +step:203/1670 train_time:18696ms step_avg:92.10ms +step:204/1670 train_time:18787ms step_avg:92.09ms +step:205/1670 train_time:18878ms step_avg:92.09ms +step:206/1670 train_time:18969ms step_avg:92.08ms +step:207/1670 train_time:19061ms step_avg:92.08ms +step:208/1670 train_time:19151ms step_avg:92.07ms +step:209/1670 train_time:19242ms step_avg:92.07ms +step:210/1670 train_time:19333ms step_avg:92.06ms +step:211/1670 train_time:19424ms step_avg:92.06ms +step:212/1670 train_time:19515ms step_avg:92.05ms +step:213/1670 train_time:19765ms step_avg:92.79ms +step:214/1670 train_time:19833ms step_avg:92.68ms +step:215/1670 train_time:19922ms step_avg:92.66ms +step:216/1670 train_time:20013ms step_avg:92.65ms +step:217/1670 train_time:20102ms step_avg:92.64ms +step:218/1670 train_time:20193ms step_avg:92.63ms +step:219/1670 train_time:20283ms step_avg:92.62ms +step:220/1670 train_time:20374ms step_avg:92.61ms +step:221/1670 train_time:20464ms step_avg:92.60ms +step:222/1670 train_time:20554ms step_avg:92.59ms +step:223/1670 train_time:20647ms step_avg:92.59ms +step:224/1670 train_time:20742ms step_avg:92.60ms +step:225/1670 train_time:20838ms step_avg:92.62ms +step:226/1670 train_time:20930ms step_avg:92.61ms +step:227/1670 train_time:21021ms step_avg:92.60ms +step:228/1670 train_time:21112ms step_avg:92.59ms +step:229/1670 train_time:21202ms step_avg:92.58ms +step:230/1670 train_time:21292ms step_avg:92.57ms +step:231/1670 train_time:21382ms step_avg:92.56ms +step:232/1670 train_time:21472ms step_avg:92.55ms +step:233/1670 
train_time:21563ms step_avg:92.55ms +step:234/1670 train_time:21656ms step_avg:92.55ms +step:235/1670 train_time:21749ms step_avg:92.55ms +step:236/1670 train_time:21842ms step_avg:92.55ms +step:237/1670 train_time:21935ms step_avg:92.55ms +step:238/1670 train_time:22025ms step_avg:92.54ms +step:239/1670 train_time:22116ms step_avg:92.54ms +step:240/1670 train_time:22207ms step_avg:92.53ms +step:241/1670 train_time:22298ms step_avg:92.52ms +step:242/1670 train_time:22388ms step_avg:92.51ms +step:243/1670 train_time:22479ms step_avg:92.50ms +step:244/1670 train_time:22569ms step_avg:92.50ms +step:245/1670 train_time:22661ms step_avg:92.49ms +step:246/1670 train_time:22752ms step_avg:92.49ms +step:247/1670 train_time:22844ms step_avg:92.49ms +step:248/1670 train_time:22936ms step_avg:92.48ms +step:249/1670 train_time:23027ms step_avg:92.48ms +step:250/1670 train_time:23118ms step_avg:92.47ms +step:250/1670 val_loss:3.9615 train_time:23208ms step_avg:92.83ms +step:251/1670 train_time:23228ms step_avg:92.54ms +step:252/1670 train_time:23301ms step_avg:92.46ms +step:253/1670 train_time:23393ms step_avg:92.46ms +step:254/1670 train_time:23484ms step_avg:92.46ms +step:255/1670 train_time:23574ms step_avg:92.45ms +step:256/1670 train_time:23664ms step_avg:92.44ms +step:257/1670 train_time:23754ms step_avg:92.43ms +step:258/1670 train_time:23845ms step_avg:92.42ms +step:259/1670 train_time:23936ms step_avg:92.42ms +step:260/1670 train_time:24026ms step_avg:92.41ms +step:261/1670 train_time:24119ms step_avg:92.41ms +step:262/1670 train_time:24213ms step_avg:92.42ms +step:263/1670 train_time:24305ms step_avg:92.42ms +step:264/1670 train_time:24398ms step_avg:92.42ms +step:265/1670 train_time:24491ms step_avg:92.42ms +step:266/1670 train_time:24581ms step_avg:92.41ms +step:267/1670 train_time:24671ms step_avg:92.40ms +step:268/1670 train_time:24761ms step_avg:92.39ms +step:269/1670 train_time:24852ms step_avg:92.39ms +step:270/1670 train_time:24942ms step_avg:92.38ms +step:271/1670 train_time:25033ms step_avg:92.37ms +step:272/1670 train_time:25125ms step_avg:92.37ms +step:273/1670 train_time:25217ms step_avg:92.37ms +step:274/1670 train_time:25309ms step_avg:92.37ms +step:275/1670 train_time:25402ms step_avg:92.37ms +step:276/1670 train_time:25493ms step_avg:92.37ms +step:277/1670 train_time:25584ms step_avg:92.36ms +step:278/1670 train_time:25676ms step_avg:92.36ms +step:279/1670 train_time:25765ms step_avg:92.35ms +step:280/1670 train_time:25856ms step_avg:92.34ms +step:281/1670 train_time:25947ms step_avg:92.34ms +step:282/1670 train_time:26039ms step_avg:92.34ms +step:283/1670 train_time:26131ms step_avg:92.33ms +step:284/1670 train_time:26223ms step_avg:92.33ms +step:285/1670 train_time:26315ms step_avg:92.33ms +step:286/1670 train_time:26405ms step_avg:92.33ms +step:287/1670 train_time:26498ms step_avg:92.33ms +step:288/1670 train_time:26590ms step_avg:92.33ms +step:289/1670 train_time:26682ms step_avg:92.33ms +step:290/1670 train_time:26774ms step_avg:92.32ms +step:291/1670 train_time:26865ms step_avg:92.32ms +step:292/1670 train_time:26956ms step_avg:92.31ms +step:293/1670 train_time:27046ms step_avg:92.31ms +step:294/1670 train_time:27138ms step_avg:92.31ms +step:295/1670 train_time:27229ms step_avg:92.30ms +step:296/1670 train_time:27320ms step_avg:92.30ms +step:297/1670 train_time:27412ms step_avg:92.30ms +step:298/1670 train_time:27504ms step_avg:92.30ms +step:299/1670 train_time:27594ms step_avg:92.29ms +step:300/1670 train_time:27687ms step_avg:92.29ms +step:301/1670 train_time:27777ms 
step_avg:92.28ms +step:302/1670 train_time:27869ms step_avg:92.28ms +step:303/1670 train_time:27959ms step_avg:92.27ms +step:304/1670 train_time:28049ms step_avg:92.27ms +step:305/1670 train_time:28141ms step_avg:92.27ms +step:306/1670 train_time:28231ms step_avg:92.26ms +step:307/1670 train_time:28323ms step_avg:92.26ms +step:308/1670 train_time:28414ms step_avg:92.25ms +step:309/1670 train_time:28506ms step_avg:92.25ms +step:310/1670 train_time:28597ms step_avg:92.25ms +step:311/1670 train_time:28689ms step_avg:92.25ms +step:312/1670 train_time:28779ms step_avg:92.24ms +step:313/1670 train_time:28872ms step_avg:92.24ms +step:314/1670 train_time:28962ms step_avg:92.23ms +step:315/1670 train_time:29053ms step_avg:92.23ms +step:316/1670 train_time:29144ms step_avg:92.23ms +step:317/1670 train_time:29234ms step_avg:92.22ms +step:318/1670 train_time:29326ms step_avg:92.22ms +step:319/1670 train_time:29417ms step_avg:92.22ms +step:320/1670 train_time:29508ms step_avg:92.21ms +step:321/1670 train_time:29600ms step_avg:92.21ms +step:322/1670 train_time:29693ms step_avg:92.21ms +step:323/1670 train_time:29785ms step_avg:92.21ms +step:324/1670 train_time:29876ms step_avg:92.21ms +step:325/1670 train_time:29967ms step_avg:92.20ms +step:326/1670 train_time:30057ms step_avg:92.20ms +step:327/1670 train_time:30149ms step_avg:92.20ms +step:328/1670 train_time:30240ms step_avg:92.20ms +step:329/1670 train_time:30332ms step_avg:92.19ms +step:330/1670 train_time:30423ms step_avg:92.19ms +step:331/1670 train_time:30514ms step_avg:92.19ms +step:332/1670 train_time:30606ms step_avg:92.19ms +step:333/1670 train_time:30697ms step_avg:92.18ms +step:334/1670 train_time:30789ms step_avg:92.18ms +step:335/1670 train_time:30880ms step_avg:92.18ms +step:336/1670 train_time:30972ms step_avg:92.18ms +step:337/1670 train_time:31063ms step_avg:92.17ms +step:338/1670 train_time:31153ms step_avg:92.17ms +step:339/1670 train_time:31244ms step_avg:92.17ms +step:340/1670 train_time:31336ms step_avg:92.16ms +step:341/1670 train_time:31425ms step_avg:92.16ms +step:342/1670 train_time:31516ms step_avg:92.15ms +step:343/1670 train_time:31607ms step_avg:92.15ms +step:344/1670 train_time:31698ms step_avg:92.14ms +step:345/1670 train_time:31790ms step_avg:92.15ms +step:346/1670 train_time:31883ms step_avg:92.15ms +step:347/1670 train_time:31974ms step_avg:92.14ms +step:348/1670 train_time:32065ms step_avg:92.14ms +step:349/1670 train_time:32155ms step_avg:92.14ms +step:350/1670 train_time:32246ms step_avg:92.13ms +step:351/1670 train_time:32338ms step_avg:92.13ms +step:352/1670 train_time:32429ms step_avg:92.13ms +step:353/1670 train_time:32520ms step_avg:92.12ms +step:354/1670 train_time:32611ms step_avg:92.12ms +step:355/1670 train_time:32702ms step_avg:92.12ms +step:356/1670 train_time:32795ms step_avg:92.12ms +step:357/1670 train_time:32887ms step_avg:92.12ms +step:358/1670 train_time:32978ms step_avg:92.12ms +step:359/1670 train_time:33070ms step_avg:92.12ms +step:360/1670 train_time:33161ms step_avg:92.11ms +step:361/1670 train_time:33252ms step_avg:92.11ms +step:362/1670 train_time:33343ms step_avg:92.11ms +step:363/1670 train_time:33434ms step_avg:92.10ms +step:364/1670 train_time:33525ms step_avg:92.10ms +step:365/1670 train_time:33615ms step_avg:92.10ms +step:366/1670 train_time:33706ms step_avg:92.09ms +step:367/1670 train_time:33798ms step_avg:92.09ms +step:368/1670 train_time:33889ms step_avg:92.09ms +step:369/1670 train_time:33980ms step_avg:92.09ms +step:370/1670 train_time:34073ms step_avg:92.09ms +step:371/1670 
train_time:34164ms step_avg:92.09ms +step:372/1670 train_time:34255ms step_avg:92.08ms +step:373/1670 train_time:34346ms step_avg:92.08ms +step:374/1670 train_time:34436ms step_avg:92.08ms +step:375/1670 train_time:34528ms step_avg:92.07ms +step:375/1670 val_loss:3.8087 train_time:34618ms step_avg:92.32ms +step:376/1670 train_time:34638ms step_avg:92.12ms +step:377/1670 train_time:34712ms step_avg:92.07ms +step:378/1670 train_time:34803ms step_avg:92.07ms +step:379/1670 train_time:34894ms step_avg:92.07ms +step:380/1670 train_time:34984ms step_avg:92.06ms +step:381/1670 train_time:35075ms step_avg:92.06ms +step:382/1670 train_time:35164ms step_avg:92.05ms +step:383/1670 train_time:35256ms step_avg:92.05ms +step:384/1670 train_time:35347ms step_avg:92.05ms +step:385/1670 train_time:35440ms step_avg:92.05ms +step:386/1670 train_time:35531ms step_avg:92.05ms +step:387/1670 train_time:35623ms step_avg:92.05ms +step:388/1670 train_time:35716ms step_avg:92.05ms +step:389/1670 train_time:35809ms step_avg:92.05ms +step:390/1670 train_time:35903ms step_avg:92.06ms +step:391/1670 train_time:35992ms step_avg:92.05ms +step:392/1670 train_time:36083ms step_avg:92.05ms +step:393/1670 train_time:36174ms step_avg:92.05ms +step:394/1670 train_time:36264ms step_avg:92.04ms +step:395/1670 train_time:36355ms step_avg:92.04ms +step:396/1670 train_time:36446ms step_avg:92.04ms +step:397/1670 train_time:36538ms step_avg:92.03ms +step:398/1670 train_time:36630ms step_avg:92.04ms +step:399/1670 train_time:36722ms step_avg:92.03ms +step:400/1670 train_time:36813ms step_avg:92.03ms +step:401/1670 train_time:36905ms step_avg:92.03ms +step:402/1670 train_time:36994ms step_avg:92.03ms +step:403/1670 train_time:37085ms step_avg:92.02ms +step:404/1670 train_time:37176ms step_avg:92.02ms +step:405/1670 train_time:37266ms step_avg:92.02ms +step:406/1670 train_time:37358ms step_avg:92.01ms +step:407/1670 train_time:37449ms step_avg:92.01ms +step:408/1670 train_time:37541ms step_avg:92.01ms +step:409/1670 train_time:37634ms step_avg:92.01ms +step:410/1670 train_time:37724ms step_avg:92.01ms +step:411/1670 train_time:37816ms step_avg:92.01ms +step:412/1670 train_time:37907ms step_avg:92.01ms +step:413/1670 train_time:37999ms step_avg:92.01ms +step:414/1670 train_time:38090ms step_avg:92.00ms +step:415/1670 train_time:38182ms step_avg:92.00ms +step:416/1670 train_time:38273ms step_avg:92.00ms +step:417/1670 train_time:38363ms step_avg:92.00ms +step:418/1670 train_time:38454ms step_avg:91.99ms +step:419/1670 train_time:38545ms step_avg:91.99ms +step:420/1670 train_time:38637ms step_avg:91.99ms +step:421/1670 train_time:38729ms step_avg:91.99ms +step:422/1670 train_time:38820ms step_avg:91.99ms +step:423/1670 train_time:38912ms step_avg:91.99ms +step:424/1670 train_time:39002ms step_avg:91.99ms +step:425/1670 train_time:39252ms step_avg:92.36ms +step:426/1670 train_time:39328ms step_avg:92.32ms +step:427/1670 train_time:39418ms step_avg:92.31ms +step:428/1670 train_time:39508ms step_avg:92.31ms +step:429/1670 train_time:39598ms step_avg:92.30ms +step:430/1670 train_time:39688ms step_avg:92.30ms +step:431/1670 train_time:39778ms step_avg:92.29ms +step:432/1670 train_time:39868ms step_avg:92.29ms +step:433/1670 train_time:39958ms step_avg:92.28ms +step:434/1670 train_time:40049ms step_avg:92.28ms +step:435/1670 train_time:40144ms step_avg:92.28ms +step:436/1670 train_time:40243ms step_avg:92.30ms +step:437/1670 train_time:40336ms step_avg:92.30ms +step:438/1670 train_time:40427ms step_avg:92.30ms +step:439/1670 train_time:40518ms 
step_avg:92.30ms +step:440/1670 train_time:40608ms step_avg:92.29ms +step:441/1670 train_time:40699ms step_avg:92.29ms +step:442/1670 train_time:40789ms step_avg:92.28ms +step:443/1670 train_time:40879ms step_avg:92.28ms +step:444/1670 train_time:40969ms step_avg:92.27ms +step:445/1670 train_time:41060ms step_avg:92.27ms +step:446/1670 train_time:41153ms step_avg:92.27ms +step:447/1670 train_time:41246ms step_avg:92.27ms +step:448/1670 train_time:41340ms step_avg:92.28ms +step:449/1670 train_time:41431ms step_avg:92.27ms +step:450/1670 train_time:41521ms step_avg:92.27ms +step:451/1670 train_time:41613ms step_avg:92.27ms +step:452/1670 train_time:41703ms step_avg:92.26ms +step:453/1670 train_time:41793ms step_avg:92.26ms +step:454/1670 train_time:41883ms step_avg:92.25ms +step:455/1670 train_time:41974ms step_avg:92.25ms +step:456/1670 train_time:42066ms step_avg:92.25ms +step:457/1670 train_time:42158ms step_avg:92.25ms +step:458/1670 train_time:42250ms step_avg:92.25ms +step:459/1670 train_time:42343ms step_avg:92.25ms +step:460/1670 train_time:42434ms step_avg:92.25ms +step:461/1670 train_time:42525ms step_avg:92.25ms +step:462/1670 train_time:42617ms step_avg:92.24ms +step:463/1670 train_time:42708ms step_avg:92.24ms +step:464/1670 train_time:42799ms step_avg:92.24ms +step:465/1670 train_time:42889ms step_avg:92.24ms +step:466/1670 train_time:42981ms step_avg:92.23ms +step:467/1670 train_time:43072ms step_avg:92.23ms +step:468/1670 train_time:43164ms step_avg:92.23ms +step:469/1670 train_time:43255ms step_avg:92.23ms +step:470/1670 train_time:43347ms step_avg:92.23ms +step:471/1670 train_time:43439ms step_avg:92.23ms +step:472/1670 train_time:43530ms step_avg:92.23ms +step:473/1670 train_time:43621ms step_avg:92.22ms +step:474/1670 train_time:43711ms step_avg:92.22ms +step:475/1670 train_time:43802ms step_avg:92.21ms +step:476/1670 train_time:43892ms step_avg:92.21ms +step:477/1670 train_time:43983ms step_avg:92.21ms +step:478/1670 train_time:44075ms step_avg:92.21ms +step:479/1670 train_time:44165ms step_avg:92.20ms +step:480/1670 train_time:44257ms step_avg:92.20ms +step:481/1670 train_time:44348ms step_avg:92.20ms +step:482/1670 train_time:44441ms step_avg:92.20ms +step:483/1670 train_time:44533ms step_avg:92.20ms +step:484/1670 train_time:44624ms step_avg:92.20ms +step:485/1670 train_time:44715ms step_avg:92.20ms +step:486/1670 train_time:44806ms step_avg:92.19ms +step:487/1670 train_time:44897ms step_avg:92.19ms +step:488/1670 train_time:44987ms step_avg:92.19ms +step:489/1670 train_time:45078ms step_avg:92.18ms +step:490/1670 train_time:45170ms step_avg:92.18ms +step:491/1670 train_time:45262ms step_avg:92.18ms +step:492/1670 train_time:45354ms step_avg:92.18ms +step:493/1670 train_time:45444ms step_avg:92.18ms +step:494/1670 train_time:45536ms step_avg:92.18ms +step:495/1670 train_time:45628ms step_avg:92.18ms +step:496/1670 train_time:45719ms step_avg:92.18ms +step:497/1670 train_time:45810ms step_avg:92.17ms +step:498/1670 train_time:45901ms step_avg:92.17ms +step:499/1670 train_time:45993ms step_avg:92.17ms +step:500/1670 train_time:46083ms step_avg:92.17ms +step:500/1670 val_loss:3.7114 train_time:46173ms step_avg:92.35ms +step:501/1670 train_time:46193ms step_avg:92.20ms +step:502/1670 train_time:46266ms step_avg:92.16ms +step:503/1670 train_time:46359ms step_avg:92.16ms +step:504/1670 train_time:46452ms step_avg:92.17ms +step:505/1670 train_time:46542ms step_avg:92.16ms +step:506/1670 train_time:46632ms step_avg:92.16ms +step:507/1670 train_time:46722ms step_avg:92.15ms 
+step:508/1670 train_time:46814ms step_avg:92.15ms +step:509/1670 train_time:46904ms step_avg:92.15ms +step:510/1670 train_time:46995ms step_avg:92.15ms +step:511/1670 train_time:47086ms step_avg:92.14ms +step:512/1670 train_time:47178ms step_avg:92.14ms +step:513/1670 train_time:47270ms step_avg:92.14ms +step:514/1670 train_time:47362ms step_avg:92.14ms +step:515/1670 train_time:47455ms step_avg:92.15ms +step:516/1670 train_time:47546ms step_avg:92.14ms +step:517/1670 train_time:47637ms step_avg:92.14ms +step:518/1670 train_time:47728ms step_avg:92.14ms +step:519/1670 train_time:47819ms step_avg:92.14ms +step:520/1670 train_time:47909ms step_avg:92.13ms +step:521/1670 train_time:48001ms step_avg:92.13ms +step:522/1670 train_time:48092ms step_avg:92.13ms +step:523/1670 train_time:48184ms step_avg:92.13ms +step:524/1670 train_time:48275ms step_avg:92.13ms +step:525/1670 train_time:48366ms step_avg:92.13ms +step:526/1670 train_time:48457ms step_avg:92.12ms +step:527/1670 train_time:48549ms step_avg:92.12ms +step:528/1670 train_time:48639ms step_avg:92.12ms +step:529/1670 train_time:48730ms step_avg:92.12ms +step:530/1670 train_time:48821ms step_avg:92.11ms +step:531/1670 train_time:48914ms step_avg:92.12ms +step:532/1670 train_time:49005ms step_avg:92.11ms +step:533/1670 train_time:49096ms step_avg:92.11ms +step:534/1670 train_time:49187ms step_avg:92.11ms +step:535/1670 train_time:49279ms step_avg:92.11ms +step:536/1670 train_time:49370ms step_avg:92.11ms +step:537/1670 train_time:49462ms step_avg:92.11ms +step:538/1670 train_time:49554ms step_avg:92.11ms +step:539/1670 train_time:49645ms step_avg:92.11ms +step:540/1670 train_time:49737ms step_avg:92.11ms +step:541/1670 train_time:49829ms step_avg:92.11ms +step:542/1670 train_time:49920ms step_avg:92.10ms +step:543/1670 train_time:50012ms step_avg:92.10ms +step:544/1670 train_time:50103ms step_avg:92.10ms +step:545/1670 train_time:50194ms step_avg:92.10ms +step:546/1670 train_time:50286ms step_avg:92.10ms +step:547/1670 train_time:50378ms step_avg:92.10ms +step:548/1670 train_time:50469ms step_avg:92.10ms +step:549/1670 train_time:50560ms step_avg:92.09ms +step:550/1670 train_time:50651ms step_avg:92.09ms +step:551/1670 train_time:50742ms step_avg:92.09ms +step:552/1670 train_time:50835ms step_avg:92.09ms +step:553/1670 train_time:50926ms step_avg:92.09ms +step:554/1670 train_time:51017ms step_avg:92.09ms +step:555/1670 train_time:51108ms step_avg:92.09ms +step:556/1670 train_time:51199ms step_avg:92.08ms +step:557/1670 train_time:51289ms step_avg:92.08ms +step:558/1670 train_time:51574ms step_avg:92.43ms +step:559/1670 train_time:51647ms step_avg:92.39ms +step:560/1670 train_time:51738ms step_avg:92.39ms +step:561/1670 train_time:51829ms step_avg:92.39ms +step:562/1670 train_time:51920ms step_avg:92.38ms +step:563/1670 train_time:52011ms step_avg:92.38ms +step:564/1670 train_time:52103ms step_avg:92.38ms +step:565/1670 train_time:52194ms step_avg:92.38ms +step:566/1670 train_time:52285ms step_avg:92.38ms +step:567/1670 train_time:52376ms step_avg:92.37ms +step:568/1670 train_time:52472ms step_avg:92.38ms +step:569/1670 train_time:52568ms step_avg:92.39ms +step:570/1670 train_time:52661ms step_avg:92.39ms +step:571/1670 train_time:52755ms step_avg:92.39ms +step:572/1670 train_time:52847ms step_avg:92.39ms +step:573/1670 train_time:52940ms step_avg:92.39ms +step:574/1670 train_time:53032ms step_avg:92.39ms +step:575/1670 train_time:53123ms step_avg:92.39ms +step:576/1670 train_time:53215ms step_avg:92.39ms +step:577/1670 train_time:53307ms 
step_avg:92.39ms +step:578/1670 train_time:53399ms step_avg:92.39ms +step:579/1670 train_time:53493ms step_avg:92.39ms +step:580/1670 train_time:53586ms step_avg:92.39ms +step:581/1670 train_time:53680ms step_avg:92.39ms +step:582/1670 train_time:53774ms step_avg:92.40ms +step:583/1670 train_time:53866ms step_avg:92.39ms +step:584/1670 train_time:53958ms step_avg:92.39ms +step:585/1670 train_time:54051ms step_avg:92.39ms +step:586/1670 train_time:54142ms step_avg:92.39ms +step:587/1670 train_time:54233ms step_avg:92.39ms +step:588/1670 train_time:54325ms step_avg:92.39ms +step:589/1670 train_time:54418ms step_avg:92.39ms +step:590/1670 train_time:54511ms step_avg:92.39ms +step:591/1670 train_time:54604ms step_avg:92.39ms +step:592/1670 train_time:54697ms step_avg:92.39ms +step:593/1670 train_time:54791ms step_avg:92.40ms +step:594/1670 train_time:54884ms step_avg:92.40ms +step:595/1670 train_time:54977ms step_avg:92.40ms +step:596/1670 train_time:55069ms step_avg:92.40ms +step:597/1670 train_time:55161ms step_avg:92.40ms +step:598/1670 train_time:55253ms step_avg:92.40ms +step:599/1670 train_time:55345ms step_avg:92.40ms +step:600/1670 train_time:55438ms step_avg:92.40ms +step:601/1670 train_time:55531ms step_avg:92.40ms +step:602/1670 train_time:55624ms step_avg:92.40ms +step:603/1670 train_time:55718ms step_avg:92.40ms +step:604/1670 train_time:55811ms step_avg:92.40ms +step:605/1670 train_time:55903ms step_avg:92.40ms +step:606/1670 train_time:55995ms step_avg:92.40ms +step:607/1670 train_time:56087ms step_avg:92.40ms +step:608/1670 train_time:56179ms step_avg:92.40ms +step:609/1670 train_time:56271ms step_avg:92.40ms +step:610/1670 train_time:56362ms step_avg:92.40ms +step:611/1670 train_time:56456ms step_avg:92.40ms +step:612/1670 train_time:56549ms step_avg:92.40ms +step:613/1670 train_time:56641ms step_avg:92.40ms +step:614/1670 train_time:56735ms step_avg:92.40ms +step:615/1670 train_time:56827ms step_avg:92.40ms +step:616/1670 train_time:56920ms step_avg:92.40ms +step:617/1670 train_time:57012ms step_avg:92.40ms +step:618/1670 train_time:57104ms step_avg:92.40ms +step:619/1670 train_time:57196ms step_avg:92.40ms +step:620/1670 train_time:57288ms step_avg:92.40ms +step:621/1670 train_time:57380ms step_avg:92.40ms +step:622/1670 train_time:57472ms step_avg:92.40ms +step:623/1670 train_time:57564ms step_avg:92.40ms +step:624/1670 train_time:57659ms step_avg:92.40ms +step:625/1670 train_time:57752ms step_avg:92.40ms +step:625/1670 val_loss:3.6108 train_time:57844ms step_avg:92.55ms +step:626/1670 train_time:57864ms step_avg:92.44ms +step:627/1670 train_time:57943ms step_avg:92.41ms +step:628/1670 train_time:58042ms step_avg:92.42ms +step:629/1670 train_time:58137ms step_avg:92.43ms +step:630/1670 train_time:58229ms step_avg:92.43ms +step:631/1670 train_time:58320ms step_avg:92.42ms +step:632/1670 train_time:58411ms step_avg:92.42ms +step:633/1670 train_time:58502ms step_avg:92.42ms +step:634/1670 train_time:58593ms step_avg:92.42ms +step:635/1670 train_time:58684ms step_avg:92.42ms +step:636/1670 train_time:58776ms step_avg:92.41ms +step:637/1670 train_time:58868ms step_avg:92.42ms +step:638/1670 train_time:58966ms step_avg:92.42ms +step:639/1670 train_time:59188ms step_avg:92.63ms +step:640/1670 train_time:59276ms step_avg:92.62ms +step:641/1670 train_time:59367ms step_avg:92.62ms +step:642/1670 train_time:59458ms step_avg:92.61ms +step:643/1670 train_time:59549ms step_avg:92.61ms +step:644/1670 train_time:59641ms step_avg:92.61ms +step:645/1670 train_time:59732ms step_avg:92.61ms 
+step:646/1670 train_time:59824ms step_avg:92.61ms +step:647/1670 train_time:59915ms step_avg:92.60ms +step:648/1670 train_time:60006ms step_avg:92.60ms +step:649/1670 train_time:60102ms step_avg:92.61ms +step:650/1670 train_time:60198ms step_avg:92.61ms +step:651/1670 train_time:60292ms step_avg:92.61ms +step:652/1670 train_time:60385ms step_avg:92.61ms +step:653/1670 train_time:60477ms step_avg:92.61ms +step:654/1670 train_time:60568ms step_avg:92.61ms +step:655/1670 train_time:60660ms step_avg:92.61ms +step:656/1670 train_time:60752ms step_avg:92.61ms +step:657/1670 train_time:60844ms step_avg:92.61ms +step:658/1670 train_time:60936ms step_avg:92.61ms +step:659/1670 train_time:61029ms step_avg:92.61ms +step:660/1670 train_time:61125ms step_avg:92.61ms +step:661/1670 train_time:61220ms step_avg:92.62ms +step:662/1670 train_time:61313ms step_avg:92.62ms +step:663/1670 train_time:61406ms step_avg:92.62ms +step:664/1670 train_time:61499ms step_avg:92.62ms +step:665/1670 train_time:61591ms step_avg:92.62ms +step:666/1670 train_time:61683ms step_avg:92.62ms +step:667/1670 train_time:61774ms step_avg:92.61ms +step:668/1670 train_time:61865ms step_avg:92.61ms +step:669/1670 train_time:61958ms step_avg:92.61ms +step:670/1670 train_time:62051ms step_avg:92.61ms +step:671/1670 train_time:62146ms step_avg:92.62ms +step:672/1670 train_time:62240ms step_avg:92.62ms +step:673/1670 train_time:62332ms step_avg:92.62ms +step:674/1670 train_time:62425ms step_avg:92.62ms +step:675/1670 train_time:62518ms step_avg:92.62ms +step:676/1670 train_time:62610ms step_avg:92.62ms +step:677/1670 train_time:62703ms step_avg:92.62ms +step:678/1670 train_time:62794ms step_avg:92.62ms +step:679/1670 train_time:62886ms step_avg:92.62ms +step:680/1670 train_time:62978ms step_avg:92.61ms +step:681/1670 train_time:63070ms step_avg:92.61ms +step:682/1670 train_time:63164ms step_avg:92.62ms +step:683/1670 train_time:63257ms step_avg:92.62ms +step:684/1670 train_time:63349ms step_avg:92.62ms +step:685/1670 train_time:63443ms step_avg:92.62ms +step:686/1670 train_time:63535ms step_avg:92.62ms +step:687/1670 train_time:63627ms step_avg:92.62ms +step:688/1670 train_time:63720ms step_avg:92.62ms +step:689/1670 train_time:63812ms step_avg:92.62ms +step:690/1670 train_time:63904ms step_avg:92.61ms +step:691/1670 train_time:63995ms step_avg:92.61ms +step:692/1670 train_time:64087ms step_avg:92.61ms +step:693/1670 train_time:64181ms step_avg:92.61ms +step:694/1670 train_time:64274ms step_avg:92.61ms +step:695/1670 train_time:64367ms step_avg:92.61ms +step:696/1670 train_time:64459ms step_avg:92.61ms +step:697/1670 train_time:64552ms step_avg:92.61ms +step:698/1670 train_time:64645ms step_avg:92.61ms +step:699/1670 train_time:64737ms step_avg:92.61ms +step:700/1670 train_time:64829ms step_avg:92.61ms +step:701/1670 train_time:64921ms step_avg:92.61ms +step:702/1670 train_time:65014ms step_avg:92.61ms +step:703/1670 train_time:65105ms step_avg:92.61ms +step:704/1670 train_time:65198ms step_avg:92.61ms +step:705/1670 train_time:65291ms step_avg:92.61ms +step:706/1670 train_time:65385ms step_avg:92.61ms +step:707/1670 train_time:65478ms step_avg:92.61ms +step:708/1670 train_time:65570ms step_avg:92.61ms +step:709/1670 train_time:65663ms step_avg:92.61ms +step:710/1670 train_time:65756ms step_avg:92.61ms +step:711/1670 train_time:65848ms step_avg:92.61ms +step:712/1670 train_time:65940ms step_avg:92.61ms +step:713/1670 train_time:66032ms step_avg:92.61ms +step:714/1670 train_time:66125ms step_avg:92.61ms +step:715/1670 train_time:66218ms 
step_avg:92.61ms +step:716/1670 train_time:66310ms step_avg:92.61ms +step:717/1670 train_time:66403ms step_avg:92.61ms +step:718/1670 train_time:66495ms step_avg:92.61ms +step:719/1670 train_time:66587ms step_avg:92.61ms +step:720/1670 train_time:66680ms step_avg:92.61ms +step:721/1670 train_time:66772ms step_avg:92.61ms +step:722/1670 train_time:66865ms step_avg:92.61ms +step:723/1670 train_time:66958ms step_avg:92.61ms +step:724/1670 train_time:67050ms step_avg:92.61ms +step:725/1670 train_time:67143ms step_avg:92.61ms +step:726/1670 train_time:67235ms step_avg:92.61ms +step:727/1670 train_time:67328ms step_avg:92.61ms +step:728/1670 train_time:67420ms step_avg:92.61ms +step:729/1670 train_time:67512ms step_avg:92.61ms +step:730/1670 train_time:67605ms step_avg:92.61ms +step:731/1670 train_time:67697ms step_avg:92.61ms +step:732/1670 train_time:67789ms step_avg:92.61ms +step:733/1670 train_time:67881ms step_avg:92.61ms +step:734/1670 train_time:67974ms step_avg:92.61ms +step:735/1670 train_time:68066ms step_avg:92.61ms +step:736/1670 train_time:68159ms step_avg:92.61ms +step:737/1670 train_time:68252ms step_avg:92.61ms +step:738/1670 train_time:68345ms step_avg:92.61ms +step:739/1670 train_time:68438ms step_avg:92.61ms +step:740/1670 train_time:68529ms step_avg:92.61ms +step:741/1670 train_time:68623ms step_avg:92.61ms +step:742/1670 train_time:68714ms step_avg:92.61ms +step:743/1670 train_time:68806ms step_avg:92.61ms +step:744/1670 train_time:68899ms step_avg:92.61ms +step:745/1670 train_time:68991ms step_avg:92.61ms +step:746/1670 train_time:69084ms step_avg:92.61ms +step:747/1670 train_time:69177ms step_avg:92.61ms +step:748/1670 train_time:69269ms step_avg:92.61ms +step:749/1670 train_time:69362ms step_avg:92.61ms +step:750/1670 train_time:69454ms step_avg:92.61ms +step:750/1670 val_loss:3.5593 train_time:69547ms step_avg:92.73ms +step:751/1670 train_time:69566ms step_avg:92.63ms +step:752/1670 train_time:69642ms step_avg:92.61ms +step:753/1670 train_time:69735ms step_avg:92.61ms +step:754/1670 train_time:69827ms step_avg:92.61ms +step:755/1670 train_time:69919ms step_avg:92.61ms +step:756/1670 train_time:70010ms step_avg:92.61ms +step:757/1670 train_time:70102ms step_avg:92.61ms +step:758/1670 train_time:70195ms step_avg:92.61ms +step:759/1670 train_time:70287ms step_avg:92.61ms +step:760/1670 train_time:70381ms step_avg:92.61ms +step:761/1670 train_time:70474ms step_avg:92.61ms +step:762/1670 train_time:70567ms step_avg:92.61ms +step:763/1670 train_time:70661ms step_avg:92.61ms +step:764/1670 train_time:70754ms step_avg:92.61ms +step:765/1670 train_time:70846ms step_avg:92.61ms +step:766/1670 train_time:70938ms step_avg:92.61ms +step:767/1670 train_time:71029ms step_avg:92.61ms +step:768/1670 train_time:71122ms step_avg:92.61ms +step:769/1670 train_time:71214ms step_avg:92.61ms +step:770/1670 train_time:71307ms step_avg:92.61ms +step:771/1670 train_time:71400ms step_avg:92.61ms +step:772/1670 train_time:71493ms step_avg:92.61ms +step:773/1670 train_time:71586ms step_avg:92.61ms +step:774/1670 train_time:71679ms step_avg:92.61ms +step:775/1670 train_time:71771ms step_avg:92.61ms +step:776/1670 train_time:71864ms step_avg:92.61ms +step:777/1670 train_time:71955ms step_avg:92.61ms +step:778/1670 train_time:72047ms step_avg:92.61ms +step:779/1670 train_time:72142ms step_avg:92.61ms +step:780/1670 train_time:72236ms step_avg:92.61ms +step:781/1670 train_time:72328ms step_avg:92.61ms +step:782/1670 train_time:72421ms step_avg:92.61ms +step:783/1670 train_time:72514ms step_avg:92.61ms 
+step:784/1670 train_time:72606ms step_avg:92.61ms +step:785/1670 train_time:72699ms step_avg:92.61ms +step:786/1670 train_time:72791ms step_avg:92.61ms +step:787/1670 train_time:72883ms step_avg:92.61ms +step:788/1670 train_time:72976ms step_avg:92.61ms +step:789/1670 train_time:73068ms step_avg:92.61ms +step:790/1670 train_time:73160ms step_avg:92.61ms +step:791/1670 train_time:73252ms step_avg:92.61ms +step:792/1670 train_time:73346ms step_avg:92.61ms +step:793/1670 train_time:73438ms step_avg:92.61ms +step:794/1670 train_time:73530ms step_avg:92.61ms +step:795/1670 train_time:73623ms step_avg:92.61ms +step:796/1670 train_time:73716ms step_avg:92.61ms +step:797/1670 train_time:73808ms step_avg:92.61ms +step:798/1670 train_time:73902ms step_avg:92.61ms +step:799/1670 train_time:73994ms step_avg:92.61ms +step:800/1670 train_time:74086ms step_avg:92.61ms +step:801/1670 train_time:74179ms step_avg:92.61ms +step:802/1670 train_time:74272ms step_avg:92.61ms +step:803/1670 train_time:74365ms step_avg:92.61ms +step:804/1670 train_time:74457ms step_avg:92.61ms +step:805/1670 train_time:74548ms step_avg:92.61ms +step:806/1670 train_time:74642ms step_avg:92.61ms +step:807/1670 train_time:74735ms step_avg:92.61ms +step:808/1670 train_time:74827ms step_avg:92.61ms +step:809/1670 train_time:74920ms step_avg:92.61ms +step:810/1670 train_time:75012ms step_avg:92.61ms +step:811/1670 train_time:75105ms step_avg:92.61ms +step:812/1670 train_time:75197ms step_avg:92.61ms +step:813/1670 train_time:75290ms step_avg:92.61ms +step:814/1670 train_time:75383ms step_avg:92.61ms +step:815/1670 train_time:75476ms step_avg:92.61ms +step:816/1670 train_time:75568ms step_avg:92.61ms +step:817/1670 train_time:75661ms step_avg:92.61ms +step:818/1670 train_time:75753ms step_avg:92.61ms +step:819/1670 train_time:75845ms step_avg:92.61ms +step:820/1670 train_time:75938ms step_avg:92.61ms +step:821/1670 train_time:76030ms step_avg:92.61ms +step:822/1670 train_time:76123ms step_avg:92.61ms +step:823/1670 train_time:76216ms step_avg:92.61ms +step:824/1670 train_time:76308ms step_avg:92.61ms +step:825/1670 train_time:76401ms step_avg:92.61ms +step:826/1670 train_time:76493ms step_avg:92.61ms +step:827/1670 train_time:76585ms step_avg:92.61ms +step:828/1670 train_time:76678ms step_avg:92.61ms +step:829/1670 train_time:76770ms step_avg:92.61ms +step:830/1670 train_time:76862ms step_avg:92.61ms +step:831/1670 train_time:76955ms step_avg:92.61ms +step:832/1670 train_time:77047ms step_avg:92.60ms +step:833/1670 train_time:77140ms step_avg:92.60ms +step:834/1670 train_time:77232ms step_avg:92.60ms +step:835/1670 train_time:77325ms step_avg:92.61ms +step:836/1670 train_time:77418ms step_avg:92.61ms +step:837/1670 train_time:77510ms step_avg:92.60ms +step:838/1670 train_time:77603ms step_avg:92.61ms +step:839/1670 train_time:77696ms step_avg:92.61ms +step:840/1670 train_time:77788ms step_avg:92.60ms +step:841/1670 train_time:77881ms step_avg:92.61ms +step:842/1670 train_time:77973ms step_avg:92.60ms +step:843/1670 train_time:78066ms step_avg:92.60ms +step:844/1670 train_time:78158ms step_avg:92.60ms +step:845/1670 train_time:78251ms step_avg:92.61ms +step:846/1670 train_time:78344ms step_avg:92.61ms +step:847/1670 train_time:78437ms step_avg:92.61ms +step:848/1670 train_time:78530ms step_avg:92.61ms +step:849/1670 train_time:78623ms step_avg:92.61ms +step:850/1670 train_time:78716ms step_avg:92.61ms +step:851/1670 train_time:78965ms step_avg:92.79ms +step:852/1670 train_time:79037ms step_avg:92.77ms +step:853/1670 train_time:79128ms 
step_avg:92.76ms +step:854/1670 train_time:79220ms step_avg:92.76ms +step:855/1670 train_time:79311ms step_avg:92.76ms +step:856/1670 train_time:79402ms step_avg:92.76ms +step:857/1670 train_time:79493ms step_avg:92.76ms +step:858/1670 train_time:79585ms step_avg:92.76ms +step:859/1670 train_time:79676ms step_avg:92.75ms +step:860/1670 train_time:79767ms step_avg:92.75ms +step:861/1670 train_time:79867ms step_avg:92.76ms +step:862/1670 train_time:79965ms step_avg:92.77ms +step:863/1670 train_time:80058ms step_avg:92.77ms +step:864/1670 train_time:80150ms step_avg:92.77ms +step:865/1670 train_time:80242ms step_avg:92.77ms +step:866/1670 train_time:80334ms step_avg:92.76ms +step:867/1670 train_time:80426ms step_avg:92.76ms +step:868/1670 train_time:80517ms step_avg:92.76ms +step:869/1670 train_time:80608ms step_avg:92.76ms +step:870/1670 train_time:80699ms step_avg:92.76ms +step:871/1670 train_time:80793ms step_avg:92.76ms +step:872/1670 train_time:80888ms step_avg:92.76ms +step:873/1670 train_time:80983ms step_avg:92.76ms +step:874/1670 train_time:81076ms step_avg:92.76ms +step:875/1670 train_time:81169ms step_avg:92.76ms +step:875/1670 val_loss:3.5164 train_time:81263ms step_avg:92.87ms +step:876/1670 train_time:81283ms step_avg:92.79ms +step:877/1670 train_time:81357ms step_avg:92.77ms +step:878/1670 train_time:81450ms step_avg:92.77ms +step:879/1670 train_time:81542ms step_avg:92.77ms +step:880/1670 train_time:81633ms step_avg:92.77ms +step:881/1670 train_time:81725ms step_avg:92.76ms +step:882/1670 train_time:81816ms step_avg:92.76ms +step:883/1670 train_time:81909ms step_avg:92.76ms +step:884/1670 train_time:82001ms step_avg:92.76ms +step:885/1670 train_time:82094ms step_avg:92.76ms +step:886/1670 train_time:82188ms step_avg:92.76ms +step:887/1670 train_time:82281ms step_avg:92.76ms +step:888/1670 train_time:82376ms step_avg:92.77ms +step:889/1670 train_time:82468ms step_avg:92.76ms +step:890/1670 train_time:82560ms step_avg:92.76ms +step:891/1670 train_time:82651ms step_avg:92.76ms +step:892/1670 train_time:82743ms step_avg:92.76ms +step:893/1670 train_time:82834ms step_avg:92.76ms +step:894/1670 train_time:82927ms step_avg:92.76ms +step:895/1670 train_time:83020ms step_avg:92.76ms +step:896/1670 train_time:83112ms step_avg:92.76ms +step:897/1670 train_time:83206ms step_avg:92.76ms +step:898/1670 train_time:83301ms step_avg:92.76ms +step:899/1670 train_time:83395ms step_avg:92.76ms +step:900/1670 train_time:83487ms step_avg:92.76ms +step:901/1670 train_time:83580ms step_avg:92.76ms +step:902/1670 train_time:83672ms step_avg:92.76ms +step:903/1670 train_time:83764ms step_avg:92.76ms +step:904/1670 train_time:83856ms step_avg:92.76ms +step:905/1670 train_time:83948ms step_avg:92.76ms +step:906/1670 train_time:84040ms step_avg:92.76ms +step:907/1670 train_time:84133ms step_avg:92.76ms +step:908/1670 train_time:84226ms step_avg:92.76ms +step:909/1670 train_time:84321ms step_avg:92.76ms +step:910/1670 train_time:84414ms step_avg:92.76ms +step:911/1670 train_time:84506ms step_avg:92.76ms +step:912/1670 train_time:84599ms step_avg:92.76ms +step:913/1670 train_time:84691ms step_avg:92.76ms +step:914/1670 train_time:84783ms step_avg:92.76ms +step:915/1670 train_time:84875ms step_avg:92.76ms +step:916/1670 train_time:84967ms step_avg:92.76ms +step:917/1670 train_time:85060ms step_avg:92.76ms +step:918/1670 train_time:85152ms step_avg:92.76ms +step:919/1670 train_time:85245ms step_avg:92.76ms +step:920/1670 train_time:85338ms step_avg:92.76ms +step:921/1670 train_time:85430ms step_avg:92.76ms 
+step:922/1670 train_time:85523ms step_avg:92.76ms +step:923/1670 train_time:85617ms step_avg:92.76ms +step:924/1670 train_time:85709ms step_avg:92.76ms +step:925/1670 train_time:85803ms step_avg:92.76ms +step:926/1670 train_time:85895ms step_avg:92.76ms +step:927/1670 train_time:85987ms step_avg:92.76ms +step:928/1670 train_time:86081ms step_avg:92.76ms +step:929/1670 train_time:86174ms step_avg:92.76ms +step:930/1670 train_time:86267ms step_avg:92.76ms +step:931/1670 train_time:86361ms step_avg:92.76ms +step:932/1670 train_time:86453ms step_avg:92.76ms +step:933/1670 train_time:86546ms step_avg:92.76ms +step:934/1670 train_time:86639ms step_avg:92.76ms +step:935/1670 train_time:86731ms step_avg:92.76ms +step:936/1670 train_time:86822ms step_avg:92.76ms +step:937/1670 train_time:86914ms step_avg:92.76ms +step:938/1670 train_time:87006ms step_avg:92.76ms +step:939/1670 train_time:87099ms step_avg:92.76ms +step:940/1670 train_time:87192ms step_avg:92.76ms +step:941/1670 train_time:87284ms step_avg:92.76ms +step:942/1670 train_time:87378ms step_avg:92.76ms +step:943/1670 train_time:87470ms step_avg:92.76ms +step:944/1670 train_time:87563ms step_avg:92.76ms +step:945/1670 train_time:87655ms step_avg:92.76ms +step:946/1670 train_time:87748ms step_avg:92.76ms +step:947/1670 train_time:87840ms step_avg:92.76ms +step:948/1670 train_time:87933ms step_avg:92.76ms +step:949/1670 train_time:88026ms step_avg:92.76ms +step:950/1670 train_time:88118ms step_avg:92.76ms +step:951/1670 train_time:88210ms step_avg:92.76ms +step:952/1670 train_time:88303ms step_avg:92.75ms +step:953/1670 train_time:88394ms step_avg:92.75ms +step:954/1670 train_time:88487ms step_avg:92.75ms +step:955/1670 train_time:88580ms step_avg:92.75ms +step:956/1670 train_time:88673ms step_avg:92.75ms +step:957/1670 train_time:88766ms step_avg:92.75ms +step:958/1670 train_time:88859ms step_avg:92.75ms +step:959/1670 train_time:88952ms step_avg:92.76ms +step:960/1670 train_time:89045ms step_avg:92.75ms +step:961/1670 train_time:89137ms step_avg:92.75ms +step:962/1670 train_time:89229ms step_avg:92.75ms +step:963/1670 train_time:89322ms step_avg:92.75ms +step:964/1670 train_time:89414ms step_avg:92.75ms +step:965/1670 train_time:89507ms step_avg:92.75ms +step:966/1670 train_time:89601ms step_avg:92.75ms +step:967/1670 train_time:89694ms step_avg:92.75ms +step:968/1670 train_time:89787ms step_avg:92.75ms +step:969/1670 train_time:89881ms step_avg:92.76ms +step:970/1670 train_time:89973ms step_avg:92.76ms +step:971/1670 train_time:90065ms step_avg:92.75ms +step:972/1670 train_time:90157ms step_avg:92.75ms +step:973/1670 train_time:90248ms step_avg:92.75ms +step:974/1670 train_time:90341ms step_avg:92.75ms +step:975/1670 train_time:90433ms step_avg:92.75ms +step:976/1670 train_time:90526ms step_avg:92.75ms +step:977/1670 train_time:90618ms step_avg:92.75ms +step:978/1670 train_time:90710ms step_avg:92.75ms +step:979/1670 train_time:90803ms step_avg:92.75ms +step:980/1670 train_time:90897ms step_avg:92.75ms +step:981/1670 train_time:90989ms step_avg:92.75ms +step:982/1670 train_time:91081ms step_avg:92.75ms +step:983/1670 train_time:91174ms step_avg:92.75ms +step:984/1670 train_time:91267ms step_avg:92.75ms +step:985/1670 train_time:91360ms step_avg:92.75ms +step:986/1670 train_time:91452ms step_avg:92.75ms +step:987/1670 train_time:91545ms step_avg:92.75ms +step:988/1670 train_time:91638ms step_avg:92.75ms +step:989/1670 train_time:91731ms step_avg:92.75ms +step:990/1670 train_time:91824ms step_avg:92.75ms +step:991/1670 train_time:91916ms 
step_avg:92.75ms +step:992/1670 train_time:92009ms step_avg:92.75ms +step:993/1670 train_time:92102ms step_avg:92.75ms +step:994/1670 train_time:92194ms step_avg:92.75ms +step:995/1670 train_time:92287ms step_avg:92.75ms +step:996/1670 train_time:92380ms step_avg:92.75ms +step:997/1670 train_time:92473ms step_avg:92.75ms +step:998/1670 train_time:92565ms step_avg:92.75ms +step:999/1670 train_time:92658ms step_avg:92.75ms +step:1000/1670 train_time:92751ms step_avg:92.75ms +step:1000/1670 val_loss:3.4669 train_time:92843ms step_avg:92.84ms +step:1001/1670 train_time:92863ms step_avg:92.77ms +step:1002/1670 train_time:92941ms step_avg:92.76ms +step:1003/1670 train_time:93033ms step_avg:92.75ms +step:1004/1670 train_time:93124ms step_avg:92.75ms +step:1005/1670 train_time:93216ms step_avg:92.75ms +step:1006/1670 train_time:93307ms step_avg:92.75ms +step:1007/1670 train_time:93399ms step_avg:92.75ms +step:1008/1670 train_time:93493ms step_avg:92.75ms +step:1009/1670 train_time:93585ms step_avg:92.75ms +step:1010/1670 train_time:93678ms step_avg:92.75ms +step:1011/1670 train_time:93771ms step_avg:92.75ms +step:1012/1670 train_time:93865ms step_avg:92.75ms +step:1013/1670 train_time:93959ms step_avg:92.75ms +step:1014/1670 train_time:94052ms step_avg:92.75ms +step:1015/1670 train_time:94144ms step_avg:92.75ms +step:1016/1670 train_time:94236ms step_avg:92.75ms +step:1017/1670 train_time:94328ms step_avg:92.75ms +step:1018/1670 train_time:94421ms step_avg:92.75ms +step:1019/1670 train_time:94513ms step_avg:92.75ms +step:1020/1670 train_time:94604ms step_avg:92.75ms +step:1021/1670 train_time:94697ms step_avg:92.75ms +step:1022/1670 train_time:94791ms step_avg:92.75ms +step:1023/1670 train_time:94884ms step_avg:92.75ms +step:1024/1670 train_time:94978ms step_avg:92.75ms +step:1025/1670 train_time:95070ms step_avg:92.75ms +step:1026/1670 train_time:95162ms step_avg:92.75ms +step:1027/1670 train_time:95255ms step_avg:92.75ms +step:1028/1670 train_time:95347ms step_avg:92.75ms +step:1029/1670 train_time:95439ms step_avg:92.75ms +step:1030/1670 train_time:95532ms step_avg:92.75ms +step:1031/1670 train_time:95625ms step_avg:92.75ms +step:1032/1670 train_time:95718ms step_avg:92.75ms +step:1033/1670 train_time:95811ms step_avg:92.75ms +step:1034/1670 train_time:95903ms step_avg:92.75ms +step:1035/1670 train_time:95996ms step_avg:92.75ms +step:1036/1670 train_time:96088ms step_avg:92.75ms +step:1037/1670 train_time:96181ms step_avg:92.75ms +step:1038/1670 train_time:96273ms step_avg:92.75ms +step:1039/1670 train_time:96365ms step_avg:92.75ms +step:1040/1670 train_time:96458ms step_avg:92.75ms +step:1041/1670 train_time:96552ms step_avg:92.75ms +step:1042/1670 train_time:96645ms step_avg:92.75ms +step:1043/1670 train_time:96738ms step_avg:92.75ms +step:1044/1670 train_time:96832ms step_avg:92.75ms +step:1045/1670 train_time:96924ms step_avg:92.75ms +step:1046/1670 train_time:97016ms step_avg:92.75ms +step:1047/1670 train_time:97109ms step_avg:92.75ms +step:1048/1670 train_time:97200ms step_avg:92.75ms +step:1049/1670 train_time:97292ms step_avg:92.75ms +step:1050/1670 train_time:97385ms step_avg:92.75ms +step:1051/1670 train_time:97476ms step_avg:92.75ms +step:1052/1670 train_time:97568ms step_avg:92.75ms +step:1053/1670 train_time:97662ms step_avg:92.75ms +step:1054/1670 train_time:97757ms step_avg:92.75ms +step:1055/1670 train_time:97849ms step_avg:92.75ms +step:1056/1670 train_time:97942ms step_avg:92.75ms +step:1057/1670 train_time:98036ms step_avg:92.75ms +step:1058/1670 train_time:98128ms 
step_avg:92.75ms +step:1059/1670 train_time:98221ms step_avg:92.75ms +step:1060/1670 train_time:98313ms step_avg:92.75ms +step:1061/1670 train_time:98405ms step_avg:92.75ms +step:1062/1670 train_time:98656ms step_avg:92.90ms +step:1063/1670 train_time:98727ms step_avg:92.88ms +step:1064/1670 train_time:98818ms step_avg:92.87ms +step:1065/1670 train_time:98909ms step_avg:92.87ms +step:1066/1670 train_time:99000ms step_avg:92.87ms +step:1067/1670 train_time:99091ms step_avg:92.87ms +step:1068/1670 train_time:99183ms step_avg:92.87ms +step:1069/1670 train_time:99274ms step_avg:92.87ms +step:1070/1670 train_time:99365ms step_avg:92.86ms +step:1071/1670 train_time:99457ms step_avg:92.86ms +step:1072/1670 train_time:99555ms step_avg:92.87ms +step:1073/1670 train_time:99654ms step_avg:92.87ms +step:1074/1670 train_time:99747ms step_avg:92.87ms +step:1075/1670 train_time:99839ms step_avg:92.87ms +step:1076/1670 train_time:99931ms step_avg:92.87ms +step:1077/1670 train_time:100022ms step_avg:92.87ms +step:1078/1670 train_time:100113ms step_avg:92.87ms +step:1079/1670 train_time:100205ms step_avg:92.87ms +step:1080/1670 train_time:100296ms step_avg:92.87ms +step:1081/1670 train_time:100387ms step_avg:92.87ms +step:1082/1670 train_time:100482ms step_avg:92.87ms +step:1083/1670 train_time:100577ms step_avg:92.87ms +step:1084/1670 train_time:100670ms step_avg:92.87ms +step:1085/1670 train_time:100764ms step_avg:92.87ms +step:1086/1670 train_time:100858ms step_avg:92.87ms +step:1087/1670 train_time:100949ms step_avg:92.87ms +step:1088/1670 train_time:101044ms step_avg:92.87ms +step:1089/1670 train_time:101135ms step_avg:92.87ms +step:1090/1670 train_time:101226ms step_avg:92.87ms +step:1091/1670 train_time:101318ms step_avg:92.87ms +step:1092/1670 train_time:101410ms step_avg:92.87ms +step:1093/1670 train_time:101502ms step_avg:92.87ms +step:1094/1670 train_time:101596ms step_avg:92.87ms +step:1095/1670 train_time:101689ms step_avg:92.87ms +step:1096/1670 train_time:101783ms step_avg:92.87ms +step:1097/1670 train_time:101876ms step_avg:92.87ms +step:1098/1670 train_time:101968ms step_avg:92.87ms +step:1099/1670 train_time:102060ms step_avg:92.87ms +step:1100/1670 train_time:102153ms step_avg:92.87ms +step:1101/1670 train_time:102245ms step_avg:92.87ms +step:1102/1670 train_time:102336ms step_avg:92.86ms +step:1103/1670 train_time:102429ms step_avg:92.86ms +step:1104/1670 train_time:102522ms step_avg:92.86ms +step:1105/1670 train_time:102615ms step_avg:92.86ms +step:1106/1670 train_time:102708ms step_avg:92.86ms +step:1107/1670 train_time:102801ms step_avg:92.86ms +step:1108/1670 train_time:102894ms step_avg:92.87ms +step:1109/1670 train_time:102987ms step_avg:92.86ms +step:1110/1670 train_time:103079ms step_avg:92.86ms +step:1111/1670 train_time:103171ms step_avg:92.86ms +step:1112/1670 train_time:103263ms step_avg:92.86ms +step:1113/1670 train_time:103356ms step_avg:92.86ms +step:1114/1670 train_time:103448ms step_avg:92.86ms +step:1115/1670 train_time:103729ms step_avg:93.03ms +step:1116/1670 train_time:103806ms step_avg:93.02ms +step:1117/1670 train_time:103897ms step_avg:93.01ms +step:1118/1670 train_time:103989ms step_avg:93.01ms +step:1119/1670 train_time:104080ms step_avg:93.01ms +step:1120/1670 train_time:104172ms step_avg:93.01ms +step:1121/1670 train_time:104264ms step_avg:93.01ms +step:1122/1670 train_time:104356ms step_avg:93.01ms +step:1123/1670 train_time:104447ms step_avg:93.01ms +step:1124/1670 train_time:104540ms step_avg:93.01ms +step:1125/1670 train_time:104638ms step_avg:93.01ms 
+step:1125/1670 val_loss:3.4137 train_time:104738ms step_avg:93.10ms +step:1126/1670 train_time:104760ms step_avg:93.04ms +step:1127/1670 train_time:104836ms step_avg:93.02ms +step:1128/1670 train_time:104938ms step_avg:93.03ms +step:1129/1670 train_time:105033ms step_avg:93.03ms +step:1130/1670 train_time:105126ms step_avg:93.03ms +step:1131/1670 train_time:105218ms step_avg:93.03ms +step:1132/1670 train_time:105310ms step_avg:93.03ms +step:1133/1670 train_time:105403ms step_avg:93.03ms +step:1134/1670 train_time:105495ms step_avg:93.03ms +step:1135/1670 train_time:105587ms step_avg:93.03ms +step:1136/1670 train_time:105680ms step_avg:93.03ms +step:1137/1670 train_time:105773ms step_avg:93.03ms +step:1138/1670 train_time:105869ms step_avg:93.03ms +step:1139/1670 train_time:105966ms step_avg:93.03ms +step:1140/1670 train_time:106060ms step_avg:93.04ms +step:1141/1670 train_time:106153ms step_avg:93.03ms +step:1142/1670 train_time:106246ms step_avg:93.03ms +step:1143/1670 train_time:106338ms step_avg:93.03ms +step:1144/1670 train_time:106430ms step_avg:93.03ms +step:1145/1670 train_time:106523ms step_avg:93.03ms +step:1146/1670 train_time:106617ms step_avg:93.03ms +step:1147/1670 train_time:106709ms step_avg:93.03ms +step:1148/1670 train_time:106802ms step_avg:93.03ms +step:1149/1670 train_time:106897ms step_avg:93.03ms +step:1150/1670 train_time:106992ms step_avg:93.04ms +step:1151/1670 train_time:107086ms step_avg:93.04ms +step:1152/1670 train_time:107179ms step_avg:93.04ms +step:1153/1670 train_time:107271ms step_avg:93.04ms +step:1154/1670 train_time:107364ms step_avg:93.04ms +step:1155/1670 train_time:107458ms step_avg:93.04ms +step:1156/1670 train_time:107550ms step_avg:93.04ms +step:1157/1670 train_time:107642ms step_avg:93.04ms +step:1158/1670 train_time:107735ms step_avg:93.04ms +step:1159/1670 train_time:107830ms step_avg:93.04ms +step:1160/1670 train_time:107926ms step_avg:93.04ms +step:1161/1670 train_time:108020ms step_avg:93.04ms +step:1162/1670 train_time:108114ms step_avg:93.04ms +step:1163/1670 train_time:108208ms step_avg:93.04ms +step:1164/1670 train_time:108301ms step_avg:93.04ms +step:1165/1670 train_time:108393ms step_avg:93.04ms +step:1166/1670 train_time:108487ms step_avg:93.04ms +step:1167/1670 train_time:108580ms step_avg:93.04ms +step:1168/1670 train_time:108672ms step_avg:93.04ms +step:1169/1670 train_time:108765ms step_avg:93.04ms +step:1170/1670 train_time:108859ms step_avg:93.04ms +step:1171/1670 train_time:108952ms step_avg:93.04ms +step:1172/1670 train_time:109047ms step_avg:93.04ms +step:1173/1670 train_time:109140ms step_avg:93.04ms +step:1174/1670 train_time:109233ms step_avg:93.04ms +step:1175/1670 train_time:109328ms step_avg:93.04ms +step:1176/1670 train_time:109421ms step_avg:93.05ms +step:1177/1670 train_time:109514ms step_avg:93.04ms +step:1178/1670 train_time:109607ms step_avg:93.04ms +step:1179/1670 train_time:109700ms step_avg:93.05ms +step:1180/1670 train_time:109794ms step_avg:93.05ms +step:1181/1670 train_time:109888ms step_avg:93.05ms +step:1182/1670 train_time:109981ms step_avg:93.05ms +step:1183/1670 train_time:110075ms step_avg:93.05ms +step:1184/1670 train_time:110168ms step_avg:93.05ms +step:1185/1670 train_time:110261ms step_avg:93.05ms +step:1186/1670 train_time:110354ms step_avg:93.05ms +step:1187/1670 train_time:110447ms step_avg:93.05ms +step:1188/1670 train_time:110540ms step_avg:93.05ms +step:1189/1670 train_time:110632ms step_avg:93.05ms +step:1190/1670 train_time:110726ms step_avg:93.05ms +step:1191/1670 train_time:110819ms 
step_avg:93.05ms +step:1192/1670 train_time:110912ms step_avg:93.05ms +step:1193/1670 train_time:111006ms step_avg:93.05ms +step:1194/1670 train_time:111099ms step_avg:93.05ms +step:1195/1670 train_time:111193ms step_avg:93.05ms +step:1196/1670 train_time:111286ms step_avg:93.05ms +step:1197/1670 train_time:111379ms step_avg:93.05ms +step:1198/1670 train_time:111472ms step_avg:93.05ms +step:1199/1670 train_time:111565ms step_avg:93.05ms +step:1200/1670 train_time:111658ms step_avg:93.05ms +step:1201/1670 train_time:111751ms step_avg:93.05ms +step:1202/1670 train_time:111846ms step_avg:93.05ms +step:1203/1670 train_time:111939ms step_avg:93.05ms +step:1204/1670 train_time:112032ms step_avg:93.05ms +step:1205/1670 train_time:112127ms step_avg:93.05ms +step:1206/1670 train_time:112221ms step_avg:93.05ms +step:1207/1670 train_time:112315ms step_avg:93.05ms +step:1208/1670 train_time:112408ms step_avg:93.05ms +step:1209/1670 train_time:112501ms step_avg:93.05ms +step:1210/1670 train_time:112593ms step_avg:93.05ms +step:1211/1670 train_time:112688ms step_avg:93.05ms +step:1212/1670 train_time:112782ms step_avg:93.05ms +step:1213/1670 train_time:112874ms step_avg:93.05ms +step:1214/1670 train_time:112966ms step_avg:93.05ms +step:1215/1670 train_time:113060ms step_avg:93.05ms +step:1216/1670 train_time:113153ms step_avg:93.05ms +step:1217/1670 train_time:113248ms step_avg:93.05ms +step:1218/1670 train_time:113341ms step_avg:93.06ms +step:1219/1670 train_time:113434ms step_avg:93.05ms +step:1220/1670 train_time:113528ms step_avg:93.06ms +step:1221/1670 train_time:113622ms step_avg:93.06ms +step:1222/1670 train_time:113714ms step_avg:93.06ms +step:1223/1670 train_time:113807ms step_avg:93.06ms +step:1224/1670 train_time:113900ms step_avg:93.06ms +step:1225/1670 train_time:113993ms step_avg:93.06ms +step:1226/1670 train_time:114090ms step_avg:93.06ms +step:1227/1670 train_time:114181ms step_avg:93.06ms +step:1228/1670 train_time:114274ms step_avg:93.06ms +step:1229/1670 train_time:114367ms step_avg:93.06ms +step:1230/1670 train_time:114460ms step_avg:93.06ms +step:1231/1670 train_time:114553ms step_avg:93.06ms +step:1232/1670 train_time:114647ms step_avg:93.06ms +step:1233/1670 train_time:114739ms step_avg:93.06ms +step:1234/1670 train_time:114831ms step_avg:93.06ms +step:1235/1670 train_time:114925ms step_avg:93.06ms +step:1236/1670 train_time:115019ms step_avg:93.06ms +step:1237/1670 train_time:115112ms step_avg:93.06ms +step:1238/1670 train_time:115207ms step_avg:93.06ms +step:1239/1670 train_time:115299ms step_avg:93.06ms +step:1240/1670 train_time:115391ms step_avg:93.06ms +step:1241/1670 train_time:115485ms step_avg:93.06ms +step:1242/1670 train_time:115578ms step_avg:93.06ms +step:1243/1670 train_time:115670ms step_avg:93.06ms +step:1244/1670 train_time:115764ms step_avg:93.06ms +step:1245/1670 train_time:115856ms step_avg:93.06ms +step:1246/1670 train_time:115950ms step_avg:93.06ms +step:1247/1670 train_time:116044ms step_avg:93.06ms +step:1248/1670 train_time:116137ms step_avg:93.06ms +step:1249/1670 train_time:116230ms step_avg:93.06ms +step:1250/1670 train_time:116325ms step_avg:93.06ms +step:1250/1670 val_loss:3.3755 train_time:116419ms step_avg:93.13ms +step:1251/1670 train_time:116438ms step_avg:93.08ms +step:1252/1670 train_time:116514ms step_avg:93.06ms +step:1253/1670 train_time:116608ms step_avg:93.06ms +step:1254/1670 train_time:116700ms step_avg:93.06ms +step:1255/1670 train_time:116794ms step_avg:93.06ms +step:1256/1670 train_time:116887ms step_avg:93.06ms +step:1257/1670 
train_time:116979ms step_avg:93.06ms +step:1258/1670 train_time:117072ms step_avg:93.06ms +step:1259/1670 train_time:117165ms step_avg:93.06ms +step:1260/1670 train_time:117258ms step_avg:93.06ms +step:1261/1670 train_time:117353ms step_avg:93.06ms +step:1262/1670 train_time:117449ms step_avg:93.07ms +step:1263/1670 train_time:117542ms step_avg:93.07ms +step:1264/1670 train_time:117636ms step_avg:93.07ms +step:1265/1670 train_time:117728ms step_avg:93.07ms +step:1266/1670 train_time:117821ms step_avg:93.07ms +step:1267/1670 train_time:117914ms step_avg:93.07ms +step:1268/1670 train_time:118007ms step_avg:93.07ms +step:1269/1670 train_time:118109ms step_avg:93.07ms +step:1270/1670 train_time:118196ms step_avg:93.07ms +step:1271/1670 train_time:118289ms step_avg:93.07ms +step:1272/1670 train_time:118383ms step_avg:93.07ms +step:1273/1670 train_time:118478ms step_avg:93.07ms +step:1274/1670 train_time:118727ms step_avg:93.19ms +step:1275/1670 train_time:118797ms step_avg:93.17ms +step:1276/1670 train_time:118888ms step_avg:93.17ms +step:1277/1670 train_time:118980ms step_avg:93.17ms +step:1278/1670 train_time:119072ms step_avg:93.17ms +step:1279/1670 train_time:119164ms step_avg:93.17ms +step:1280/1670 train_time:119256ms step_avg:93.17ms +step:1281/1670 train_time:119348ms step_avg:93.17ms +step:1282/1670 train_time:119440ms step_avg:93.17ms +step:1283/1670 train_time:119533ms step_avg:93.17ms +step:1284/1670 train_time:119631ms step_avg:93.17ms +step:1285/1670 train_time:119727ms step_avg:93.17ms +step:1286/1670 train_time:119823ms step_avg:93.18ms +step:1287/1670 train_time:119916ms step_avg:93.17ms +step:1288/1670 train_time:120008ms step_avg:93.17ms +step:1289/1670 train_time:120100ms step_avg:93.17ms +step:1290/1670 train_time:120193ms step_avg:93.17ms +step:1291/1670 train_time:120285ms step_avg:93.17ms +step:1292/1670 train_time:120378ms step_avg:93.17ms +step:1293/1670 train_time:120470ms step_avg:93.17ms +step:1294/1670 train_time:120564ms step_avg:93.17ms +step:1295/1670 train_time:120660ms step_avg:93.17ms +step:1296/1670 train_time:120756ms step_avg:93.18ms +step:1297/1670 train_time:120850ms step_avg:93.18ms +step:1298/1670 train_time:120943ms step_avg:93.18ms +step:1299/1670 train_time:121036ms step_avg:93.18ms +step:1300/1670 train_time:121128ms step_avg:93.18ms +step:1301/1670 train_time:121221ms step_avg:93.17ms +step:1302/1670 train_time:121313ms step_avg:93.17ms +step:1303/1670 train_time:121404ms step_avg:93.17ms +step:1304/1670 train_time:121497ms step_avg:93.17ms +step:1305/1670 train_time:121590ms step_avg:93.17ms +step:1306/1670 train_time:121684ms step_avg:93.17ms +step:1307/1670 train_time:121778ms step_avg:93.17ms +step:1308/1670 train_time:121872ms step_avg:93.17ms +step:1309/1670 train_time:121966ms step_avg:93.17ms +step:1310/1670 train_time:122058ms step_avg:93.17ms +step:1311/1670 train_time:122151ms step_avg:93.17ms +step:1312/1670 train_time:122244ms step_avg:93.17ms +step:1313/1670 train_time:122337ms step_avg:93.17ms +step:1314/1670 train_time:122430ms step_avg:93.17ms +step:1315/1670 train_time:122522ms step_avg:93.17ms +step:1316/1670 train_time:122616ms step_avg:93.17ms +step:1317/1670 train_time:122710ms step_avg:93.17ms +step:1318/1670 train_time:122805ms step_avg:93.18ms +step:1319/1670 train_time:122898ms step_avg:93.18ms +step:1320/1670 train_time:122991ms step_avg:93.18ms +step:1321/1670 train_time:123084ms step_avg:93.18ms +step:1322/1670 train_time:123177ms step_avg:93.17ms +step:1323/1670 train_time:123270ms step_avg:93.17ms +step:1324/1670 
train_time:123362ms step_avg:93.17ms +step:1325/1670 train_time:123455ms step_avg:93.17ms +step:1326/1670 train_time:123548ms step_avg:93.17ms +step:1327/1670 train_time:123642ms step_avg:93.17ms +step:1328/1670 train_time:123736ms step_avg:93.17ms +step:1329/1670 train_time:123830ms step_avg:93.18ms +step:1330/1670 train_time:123923ms step_avg:93.18ms +step:1331/1670 train_time:124017ms step_avg:93.18ms +step:1332/1670 train_time:124111ms step_avg:93.18ms +step:1333/1670 train_time:124203ms step_avg:93.18ms +step:1334/1670 train_time:124296ms step_avg:93.18ms +step:1335/1670 train_time:124389ms step_avg:93.18ms +step:1336/1670 train_time:124482ms step_avg:93.18ms +step:1337/1670 train_time:124576ms step_avg:93.18ms +step:1338/1670 train_time:124669ms step_avg:93.18ms +step:1339/1670 train_time:124762ms step_avg:93.18ms +step:1340/1670 train_time:124856ms step_avg:93.18ms +step:1341/1670 train_time:124950ms step_avg:93.18ms +step:1342/1670 train_time:125043ms step_avg:93.18ms +step:1343/1670 train_time:125137ms step_avg:93.18ms +step:1344/1670 train_time:125230ms step_avg:93.18ms +step:1345/1670 train_time:125322ms step_avg:93.18ms +step:1346/1670 train_time:125415ms step_avg:93.18ms +step:1347/1670 train_time:125508ms step_avg:93.18ms +step:1348/1670 train_time:125600ms step_avg:93.18ms +step:1349/1670 train_time:125693ms step_avg:93.18ms +step:1350/1670 train_time:125786ms step_avg:93.17ms +step:1351/1670 train_time:125879ms step_avg:93.17ms +step:1352/1670 train_time:125973ms step_avg:93.17ms +step:1353/1670 train_time:126065ms step_avg:93.17ms +step:1354/1670 train_time:126158ms step_avg:93.17ms +step:1355/1670 train_time:126253ms step_avg:93.18ms +step:1356/1670 train_time:126346ms step_avg:93.18ms +step:1357/1670 train_time:126439ms step_avg:93.18ms +step:1358/1670 train_time:126532ms step_avg:93.17ms +step:1359/1670 train_time:126624ms step_avg:93.17ms +step:1360/1670 train_time:126717ms step_avg:93.17ms +step:1361/1670 train_time:126811ms step_avg:93.17ms +step:1362/1670 train_time:126904ms step_avg:93.17ms +step:1363/1670 train_time:126997ms step_avg:93.17ms +step:1364/1670 train_time:127090ms step_avg:93.17ms +step:1365/1670 train_time:127183ms step_avg:93.17ms +step:1366/1670 train_time:127276ms step_avg:93.17ms +step:1367/1670 train_time:127371ms step_avg:93.18ms +step:1368/1670 train_time:127465ms step_avg:93.18ms +step:1369/1670 train_time:127557ms step_avg:93.18ms +step:1370/1670 train_time:127651ms step_avg:93.18ms +step:1371/1670 train_time:127744ms step_avg:93.18ms +step:1372/1670 train_time:127838ms step_avg:93.18ms +step:1373/1670 train_time:127931ms step_avg:93.18ms +step:1374/1670 train_time:128024ms step_avg:93.18ms +step:1375/1670 train_time:128117ms step_avg:93.18ms +step:1375/1670 val_loss:3.3410 train_time:128210ms step_avg:93.24ms +step:1376/1670 train_time:128230ms step_avg:93.19ms +step:1377/1670 train_time:128304ms step_avg:93.18ms +step:1378/1670 train_time:128398ms step_avg:93.18ms +step:1379/1670 train_time:128490ms step_avg:93.18ms +step:1380/1670 train_time:128584ms step_avg:93.18ms +step:1381/1670 train_time:128678ms step_avg:93.18ms +step:1382/1670 train_time:128771ms step_avg:93.18ms +step:1383/1670 train_time:128864ms step_avg:93.18ms +step:1384/1670 train_time:128957ms step_avg:93.18ms +step:1385/1670 train_time:129050ms step_avg:93.18ms +step:1386/1670 train_time:129144ms step_avg:93.18ms +step:1387/1670 train_time:129239ms step_avg:93.18ms +step:1388/1670 train_time:129333ms step_avg:93.18ms +step:1389/1670 train_time:129426ms step_avg:93.18ms 
+step:1390/1670 train_time:129519ms step_avg:93.18ms +step:1391/1670 train_time:129611ms step_avg:93.18ms +step:1392/1670 train_time:129705ms step_avg:93.18ms +step:1393/1670 train_time:129800ms step_avg:93.18ms +step:1394/1670 train_time:129893ms step_avg:93.18ms +step:1395/1670 train_time:129985ms step_avg:93.18ms +step:1396/1670 train_time:130079ms step_avg:93.18ms +step:1397/1670 train_time:130173ms step_avg:93.18ms +step:1398/1670 train_time:130267ms step_avg:93.18ms +step:1399/1670 train_time:130361ms step_avg:93.18ms +step:1400/1670 train_time:130453ms step_avg:93.18ms +step:1401/1670 train_time:130545ms step_avg:93.18ms +step:1402/1670 train_time:130639ms step_avg:93.18ms +step:1403/1670 train_time:130731ms step_avg:93.18ms +step:1404/1670 train_time:130824ms step_avg:93.18ms +step:1405/1670 train_time:130917ms step_avg:93.18ms +step:1406/1670 train_time:131010ms step_avg:93.18ms +step:1407/1670 train_time:131104ms step_avg:93.18ms +step:1408/1670 train_time:131199ms step_avg:93.18ms +step:1409/1670 train_time:131293ms step_avg:93.18ms +step:1410/1670 train_time:131387ms step_avg:93.18ms +step:1411/1670 train_time:131481ms step_avg:93.18ms +step:1412/1670 train_time:131575ms step_avg:93.18ms +step:1413/1670 train_time:131667ms step_avg:93.18ms +step:1414/1670 train_time:131760ms step_avg:93.18ms +step:1415/1670 train_time:131853ms step_avg:93.18ms +step:1416/1670 train_time:131945ms step_avg:93.18ms +step:1417/1670 train_time:132039ms step_avg:93.18ms +step:1418/1670 train_time:132132ms step_avg:93.18ms +step:1419/1670 train_time:132226ms step_avg:93.18ms +step:1420/1670 train_time:132320ms step_avg:93.18ms +step:1421/1670 train_time:132414ms step_avg:93.18ms +step:1422/1670 train_time:132508ms step_avg:93.18ms +step:1423/1670 train_time:132602ms step_avg:93.18ms +step:1424/1670 train_time:132696ms step_avg:93.19ms +step:1425/1670 train_time:132789ms step_avg:93.19ms +step:1426/1670 train_time:132882ms step_avg:93.19ms +step:1427/1670 train_time:132975ms step_avg:93.18ms +step:1428/1670 train_time:133068ms step_avg:93.18ms +step:1429/1670 train_time:133160ms step_avg:93.18ms +step:1430/1670 train_time:133254ms step_avg:93.18ms +step:1431/1670 train_time:133347ms step_avg:93.18ms +step:1432/1670 train_time:133440ms step_avg:93.18ms +step:1433/1670 train_time:133533ms step_avg:93.18ms +step:1434/1670 train_time:133627ms step_avg:93.18ms +step:1435/1670 train_time:133721ms step_avg:93.19ms +step:1436/1670 train_time:133815ms step_avg:93.19ms +step:1437/1670 train_time:133907ms step_avg:93.19ms +step:1438/1670 train_time:134000ms step_avg:93.19ms +step:1439/1670 train_time:134093ms step_avg:93.19ms +step:1440/1670 train_time:134187ms step_avg:93.19ms +step:1441/1670 train_time:134281ms step_avg:93.19ms +step:1442/1670 train_time:134373ms step_avg:93.19ms +step:1443/1670 train_time:134466ms step_avg:93.19ms +step:1444/1670 train_time:134560ms step_avg:93.19ms +step:1445/1670 train_time:134653ms step_avg:93.19ms +step:1446/1670 train_time:134747ms step_avg:93.19ms +step:1447/1670 train_time:134841ms step_avg:93.19ms +step:1448/1670 train_time:134934ms step_avg:93.19ms +step:1449/1670 train_time:135027ms step_avg:93.19ms +step:1450/1670 train_time:135122ms step_avg:93.19ms +step:1451/1670 train_time:135215ms step_avg:93.19ms +step:1452/1670 train_time:135308ms step_avg:93.19ms +step:1453/1670 train_time:135401ms step_avg:93.19ms +step:1454/1670 train_time:135494ms step_avg:93.19ms +step:1455/1670 train_time:135588ms step_avg:93.19ms +step:1456/1670 train_time:135681ms step_avg:93.19ms 
+step:1457/1670 train_time:135774ms step_avg:93.19ms +step:1458/1670 train_time:135867ms step_avg:93.19ms +step:1459/1670 train_time:135961ms step_avg:93.19ms +step:1460/1670 train_time:136054ms step_avg:93.19ms +step:1461/1670 train_time:136148ms step_avg:93.19ms +step:1462/1670 train_time:136242ms step_avg:93.19ms +step:1463/1670 train_time:136335ms step_avg:93.19ms +step:1464/1670 train_time:136428ms step_avg:93.19ms +step:1465/1670 train_time:136523ms step_avg:93.19ms +step:1466/1670 train_time:136617ms step_avg:93.19ms +step:1467/1670 train_time:136710ms step_avg:93.19ms +step:1468/1670 train_time:136803ms step_avg:93.19ms +step:1469/1670 train_time:136896ms step_avg:93.19ms +step:1470/1670 train_time:136988ms step_avg:93.19ms +step:1471/1670 train_time:137082ms step_avg:93.19ms +step:1472/1670 train_time:137175ms step_avg:93.19ms +step:1473/1670 train_time:137267ms step_avg:93.19ms +step:1474/1670 train_time:137361ms step_avg:93.19ms +step:1475/1670 train_time:137456ms step_avg:93.19ms +step:1476/1670 train_time:137549ms step_avg:93.19ms +step:1477/1670 train_time:137644ms step_avg:93.19ms +step:1478/1670 train_time:137738ms step_avg:93.19ms +step:1479/1670 train_time:137831ms step_avg:93.19ms +step:1480/1670 train_time:137924ms step_avg:93.19ms +step:1481/1670 train_time:138017ms step_avg:93.19ms +step:1482/1670 train_time:138110ms step_avg:93.19ms +step:1483/1670 train_time:138204ms step_avg:93.19ms +step:1484/1670 train_time:138297ms step_avg:93.19ms +step:1485/1670 train_time:138548ms step_avg:93.30ms +step:1486/1670 train_time:138619ms step_avg:93.28ms +step:1487/1670 train_time:138711ms step_avg:93.28ms +step:1488/1670 train_time:138802ms step_avg:93.28ms +step:1489/1670 train_time:138894ms step_avg:93.28ms +step:1490/1670 train_time:138986ms step_avg:93.28ms +step:1491/1670 train_time:139078ms step_avg:93.28ms +step:1492/1670 train_time:139170ms step_avg:93.28ms +step:1493/1670 train_time:139261ms step_avg:93.28ms +step:1494/1670 train_time:139354ms step_avg:93.28ms +step:1495/1670 train_time:139453ms step_avg:93.28ms +step:1496/1670 train_time:139551ms step_avg:93.28ms +step:1497/1670 train_time:139644ms step_avg:93.28ms +step:1498/1670 train_time:139738ms step_avg:93.28ms +step:1499/1670 train_time:139830ms step_avg:93.28ms +step:1500/1670 train_time:139922ms step_avg:93.28ms +step:1500/1670 val_loss:3.3107 train_time:140016ms step_avg:93.34ms +step:1501/1670 train_time:140036ms step_avg:93.30ms +step:1502/1670 train_time:140112ms step_avg:93.28ms +step:1503/1670 train_time:140207ms step_avg:93.28ms +step:1504/1670 train_time:140299ms step_avg:93.28ms +step:1505/1670 train_time:140392ms step_avg:93.28ms +step:1506/1670 train_time:140484ms step_avg:93.28ms +step:1507/1670 train_time:140577ms step_avg:93.28ms +step:1508/1670 train_time:140673ms step_avg:93.28ms +step:1509/1670 train_time:140766ms step_avg:93.28ms +step:1510/1670 train_time:140859ms step_avg:93.28ms +step:1511/1670 train_time:140953ms step_avg:93.28ms +step:1512/1670 train_time:141047ms step_avg:93.29ms +step:1513/1670 train_time:141142ms step_avg:93.29ms +step:1514/1670 train_time:141236ms step_avg:93.29ms +step:1515/1670 train_time:141330ms step_avg:93.29ms +step:1516/1670 train_time:141422ms step_avg:93.29ms +step:1517/1670 train_time:141515ms step_avg:93.29ms +step:1518/1670 train_time:141610ms step_avg:93.29ms +step:1519/1670 train_time:141703ms step_avg:93.29ms +step:1520/1670 train_time:141796ms step_avg:93.29ms +step:1521/1670 train_time:141889ms step_avg:93.29ms +step:1522/1670 train_time:141983ms 
step_avg:93.29ms +step:1523/1670 train_time:142077ms step_avg:93.29ms +step:1524/1670 train_time:142170ms step_avg:93.29ms +step:1525/1670 train_time:142263ms step_avg:93.29ms +step:1526/1670 train_time:142356ms step_avg:93.29ms +step:1527/1670 train_time:142448ms step_avg:93.29ms +step:1528/1670 train_time:142540ms step_avg:93.29ms +step:1529/1670 train_time:142634ms step_avg:93.29ms +step:1530/1670 train_time:142728ms step_avg:93.29ms +step:1531/1670 train_time:142820ms step_avg:93.29ms +step:1532/1670 train_time:142914ms step_avg:93.29ms +step:1533/1670 train_time:143007ms step_avg:93.29ms +step:1534/1670 train_time:143101ms step_avg:93.29ms +step:1535/1670 train_time:143195ms step_avg:93.29ms +step:1536/1670 train_time:143288ms step_avg:93.29ms +step:1537/1670 train_time:143381ms step_avg:93.29ms +step:1538/1670 train_time:143475ms step_avg:93.29ms +step:1539/1670 train_time:143568ms step_avg:93.29ms +step:1540/1670 train_time:143661ms step_avg:93.29ms +step:1541/1670 train_time:143754ms step_avg:93.29ms +step:1542/1670 train_time:143847ms step_avg:93.29ms +step:1543/1670 train_time:143941ms step_avg:93.29ms +step:1544/1670 train_time:144036ms step_avg:93.29ms +step:1545/1670 train_time:144129ms step_avg:93.29ms +step:1546/1670 train_time:144221ms step_avg:93.29ms +step:1547/1670 train_time:144316ms step_avg:93.29ms +step:1548/1670 train_time:144409ms step_avg:93.29ms +step:1549/1670 train_time:144501ms step_avg:93.29ms +step:1550/1670 train_time:144595ms step_avg:93.29ms +step:1551/1670 train_time:144688ms step_avg:93.29ms +step:1552/1670 train_time:144780ms step_avg:93.29ms +step:1553/1670 train_time:144875ms step_avg:93.29ms +step:1554/1670 train_time:144969ms step_avg:93.29ms +step:1555/1670 train_time:145062ms step_avg:93.29ms +step:1556/1670 train_time:145155ms step_avg:93.29ms +step:1557/1670 train_time:145248ms step_avg:93.29ms +step:1558/1670 train_time:145341ms step_avg:93.29ms +step:1559/1670 train_time:145435ms step_avg:93.29ms +step:1560/1670 train_time:145527ms step_avg:93.29ms +step:1561/1670 train_time:145620ms step_avg:93.29ms +step:1562/1670 train_time:145713ms step_avg:93.29ms +step:1563/1670 train_time:145807ms step_avg:93.29ms +step:1564/1670 train_time:145900ms step_avg:93.29ms +step:1565/1670 train_time:145993ms step_avg:93.29ms +step:1566/1670 train_time:146087ms step_avg:93.29ms +step:1567/1670 train_time:146180ms step_avg:93.29ms +step:1568/1670 train_time:146275ms step_avg:93.29ms +step:1569/1670 train_time:146369ms step_avg:93.29ms +step:1570/1670 train_time:146462ms step_avg:93.29ms +step:1571/1670 train_time:146554ms step_avg:93.29ms +step:1572/1670 train_time:146647ms step_avg:93.29ms +step:1573/1670 train_time:146741ms step_avg:93.29ms +step:1574/1670 train_time:146836ms step_avg:93.29ms +step:1575/1670 train_time:146930ms step_avg:93.29ms +step:1576/1670 train_time:147023ms step_avg:93.29ms +step:1577/1670 train_time:147116ms step_avg:93.29ms +step:1578/1670 train_time:147209ms step_avg:93.29ms +step:1579/1670 train_time:147303ms step_avg:93.29ms +step:1580/1670 train_time:147398ms step_avg:93.29ms +step:1581/1670 train_time:147491ms step_avg:93.29ms +step:1582/1670 train_time:147584ms step_avg:93.29ms +step:1583/1670 train_time:147678ms step_avg:93.29ms +step:1584/1670 train_time:147771ms step_avg:93.29ms +step:1585/1670 train_time:147864ms step_avg:93.29ms +step:1586/1670 train_time:147957ms step_avg:93.29ms +step:1587/1670 train_time:148050ms step_avg:93.29ms +step:1588/1670 train_time:148143ms step_avg:93.29ms +step:1589/1670 train_time:148237ms 
step_avg:93.29ms +step:1590/1670 train_time:148330ms step_avg:93.29ms +step:1591/1670 train_time:148423ms step_avg:93.29ms +step:1592/1670 train_time:148516ms step_avg:93.29ms +step:1593/1670 train_time:148610ms step_avg:93.29ms +step:1594/1670 train_time:148703ms step_avg:93.29ms +step:1595/1670 train_time:148797ms step_avg:93.29ms +step:1596/1670 train_time:148891ms step_avg:93.29ms +step:1597/1670 train_time:148983ms step_avg:93.29ms +step:1598/1670 train_time:149077ms step_avg:93.29ms +step:1599/1670 train_time:149170ms step_avg:93.29ms +step:1600/1670 train_time:149264ms step_avg:93.29ms +step:1601/1670 train_time:149358ms step_avg:93.29ms +step:1602/1670 train_time:149450ms step_avg:93.29ms +step:1603/1670 train_time:149542ms step_avg:93.29ms +step:1604/1670 train_time:149637ms step_avg:93.29ms +step:1605/1670 train_time:149731ms step_avg:93.29ms +step:1606/1670 train_time:149824ms step_avg:93.29ms +step:1607/1670 train_time:149916ms step_avg:93.29ms +step:1608/1670 train_time:150010ms step_avg:93.29ms +step:1609/1670 train_time:150103ms step_avg:93.29ms +step:1610/1670 train_time:150198ms step_avg:93.29ms +step:1611/1670 train_time:150291ms step_avg:93.29ms +step:1612/1670 train_time:150384ms step_avg:93.29ms +step:1613/1670 train_time:150477ms step_avg:93.29ms +step:1614/1670 train_time:150571ms step_avg:93.29ms +step:1615/1670 train_time:150664ms step_avg:93.29ms +step:1616/1670 train_time:150757ms step_avg:93.29ms +step:1617/1670 train_time:150850ms step_avg:93.29ms +step:1618/1670 train_time:150943ms step_avg:93.29ms +step:1619/1670 train_time:151037ms step_avg:93.29ms +step:1620/1670 train_time:151132ms step_avg:93.29ms +step:1621/1670 train_time:151225ms step_avg:93.29ms +step:1622/1670 train_time:151318ms step_avg:93.29ms +step:1623/1670 train_time:151413ms step_avg:93.29ms +step:1624/1670 train_time:151506ms step_avg:93.29ms +step:1625/1670 train_time:151599ms step_avg:93.29ms +step:1625/1670 val_loss:3.2859 train_time:151692ms step_avg:93.35ms +step:1626/1670 train_time:151712ms step_avg:93.30ms +step:1627/1670 train_time:151786ms step_avg:93.29ms +step:1628/1670 train_time:151879ms step_avg:93.29ms +step:1629/1670 train_time:151972ms step_avg:93.29ms +step:1630/1670 train_time:152064ms step_avg:93.29ms +step:1631/1670 train_time:152159ms step_avg:93.29ms +step:1632/1670 train_time:152252ms step_avg:93.29ms +step:1633/1670 train_time:152345ms step_avg:93.29ms +step:1634/1670 train_time:152439ms step_avg:93.29ms +step:1635/1670 train_time:152531ms step_avg:93.29ms +step:1636/1670 train_time:152626ms step_avg:93.29ms +step:1637/1670 train_time:152724ms step_avg:93.29ms +step:1638/1670 train_time:152818ms step_avg:93.30ms +step:1639/1670 train_time:152910ms step_avg:93.29ms +step:1640/1670 train_time:153003ms step_avg:93.29ms +step:1641/1670 train_time:153096ms step_avg:93.29ms +step:1642/1670 train_time:153189ms step_avg:93.29ms +step:1643/1670 train_time:153283ms step_avg:93.29ms +step:1644/1670 train_time:153377ms step_avg:93.29ms +step:1645/1670 train_time:153469ms step_avg:93.29ms +step:1646/1670 train_time:153563ms step_avg:93.29ms +step:1647/1670 train_time:153658ms step_avg:93.30ms +step:1648/1670 train_time:153752ms step_avg:93.30ms +step:1649/1670 train_time:153845ms step_avg:93.30ms +step:1650/1670 train_time:153939ms step_avg:93.30ms +step:1651/1670 train_time:154031ms step_avg:93.30ms +step:1652/1670 train_time:154125ms step_avg:93.30ms +step:1653/1670 train_time:154217ms step_avg:93.30ms +step:1654/1670 train_time:154310ms step_avg:93.29ms +step:1655/1670 
train_time:154404ms step_avg:93.30ms +step:1656/1670 train_time:154497ms step_avg:93.30ms +step:1657/1670 train_time:154591ms step_avg:93.30ms +step:1658/1670 train_time:154685ms step_avg:93.30ms +step:1659/1670 train_time:154779ms step_avg:93.30ms +step:1660/1670 train_time:154872ms step_avg:93.30ms +step:1661/1670 train_time:154964ms step_avg:93.30ms +step:1662/1670 train_time:155057ms step_avg:93.30ms +step:1663/1670 train_time:155150ms step_avg:93.30ms +step:1664/1670 train_time:155243ms step_avg:93.29ms +step:1665/1670 train_time:155335ms step_avg:93.29ms +step:1666/1670 train_time:155428ms step_avg:93.29ms +step:1667/1670 train_time:155521ms step_avg:93.29ms +step:1668/1670 train_time:155615ms step_avg:93.29ms +step:1669/1670 train_time:155708ms step_avg:93.29ms +step:1670/1670 train_time:155801ms step_avg:93.29ms +step:1670/1670 val_loss:3.2774 train_time:156064ms step_avg:93.45ms +peak memory allocated: 32002 MiB reserved: 47054 MiB diff --git a/records/091125_VectSigmoidBFloat16/a72c99e9-0019-4baa-a858-f8738392933f.txt b/records/091125_VectSigmoidBFloat16/a72c99e9-0019-4baa-a858-f8738392933f.txt new file mode 100644 index 000000000..23f52c71e --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/a72c99e9-0019-4baa-a858-f8738392933f.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
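+    # Both tiles below are read from the same matrix A: a_ptrs walks a
+    # (BLOCK_SIZE_M, BLOCK_SIZE_K) tile of A row-wise, while at_ptrs reads a
+    # (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of A.T by swapping the row/column
+    # strides, so C = A @ A.T never materializes the transpose. Together with
+    # the diagonal skip above, each output tile is computed once and mirrored.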
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
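+
+# For reference, a minimal pure-PyTorch sketch of the same quintic Newton-Schulz
+# iteration that newton_schulz_triton computes above (same coefficients and five
+# iterations), with plain matmuls in place of the fused symmetric kernels.
+# newton_schulz_reference is a name chosen here for illustration; it is not part
+# of this record's training path.
+def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # iterate on the wide orientation
+    # ensure spectral norm is at most 1 (the Frobenius norm upper-bounds it)
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    for _ in range(steps):
+        A = X @ X.mT             # what ns_line_1 computes
+        B = b * A + c * A @ A    # what ns_line_2 computes (A is symmetric)
+        X = a * X + B @ X        # what ns_line_3 (addmm/baddbmm) computes
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+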
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-Schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                batch = original_shape[0] * original_shape[1]
+                # Flatten all but the first two dims after batch
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
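+                # eff_lr_val folds together the group's base LR, the
+                # max(1, rows / cols) ** 0.5 shape scaling, and any
+                # per-parameter lr_mul, so one fused add applies the step.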
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
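+        # e.g. next_multiple_of_n(50257, n=128) == 50304, so the embedding and
+        # lm_head carry 47 padded rows that no real token ever indexes.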
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = ( + [ve[0], ve[1], ve[2]] + + [None] * (len(self.blocks) - 6) + + [ve[0], ve[1], ve[2]] + ) + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [ + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + long_bm, + short_bm, + short_bm, + short_bm, + long_bm, + ] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]).to( + torch.bfloat16 + ) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[: (len(self.blocks) // 2)] + lambdas = self.scalars[ + 1 * len(self.blocks) : 3 * len(self.blocks) + ].view(-1, 2) + sa_lambdas = self.scalars[ + 3 * len(self.blocks) : 5 * len(self.blocks) + ].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + rotary_cos=self.rotary_cos, + rotary_sin=self.rotary_sin, + attn_scale=self.attn_scales[ws], + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + target_seq, + reduction="sum" if self.training else "mean", + ) + return loss + + +# ----------------------------------------------------------------------------- +# Distributed data loader + + +def _load_data_shard(file: Path): + header = torch.from_file( + str(file), False, 256, dtype=torch.int32 + ) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty( + num_tokens, dtype=torch.uint16, pin_memory=True + ) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto( + tokens.numpy() + ) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, ( + "number of tokens read does not match header" + ) + return tokens + + +BOS_ID = 50256 + + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = ( + (tokens == BOS_ID) + .nonzero(as_tuple=True)[0] + .to(torch.int64) + .cpu() + .numpy() + ) + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration( + f"Insufficient BOS ahead of position {cur}; hit tail of shard." 
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(
+    hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0
+)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+
+# window-size schedule: step through ws_schedule during training, ws_validate at the end
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
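+
+# Worked example (derived from the two schedules above with num_iterations=1670,
+# cooldown_frac=0.5, ws_schedule=(3, 7, 11), ws_validate=13):
+#
+#     step    get_lr(step)    get_ws(step)
+#        0        1.000             3
+#      557        1.000             7    <- first window-size switch
+#     1114        0.699            11    <- second switch, mid-cooldown
+#     1669        0.101            11
+#     1670          -              13    <- final validation-only step
+#
+# get_lr stays at 1.0 until x = 0.5 (step 835), then decays linearly toward 0.1.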
+
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(
+    model=copy.deepcopy(model.state_dict()),
+    optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers],
+)  # save the initial state
+train_loader = distributed_data_generator(
+    args.train_files,
+    args.train_batch_size,
+    args.train_max_seq_len,
+    grad_accum_steps=grad_accum_steps,
+)
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[
+        step % len(args.ws_schedule)
+    ]  # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(
+    args.train_files,
+    args.train_batch_size,
+    args.train_max_seq_len,
+    grad_accum_steps=grad_accum_steps,
+)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+ws = get_ws(0)
+for step in range(train_steps + 1):
+    last_step = step == train_steps
+    new_ws = get_ws(step)
+    if new_ws != ws:
+        model.apply_yarn(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (
+        args.val_loss_every > 0 and step % args.val_loss_every == 0
+    ):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(
+            args.val_files,
+            args.val_batch_size,
+            -1,
+            grad_accum_steps=grad_accum_steps,
+            align_to_bos=False,
+        )
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(
+            f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms",
+            console=True,
+        )
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(
+                step=step,
+                code=code,
+                model=model.state_dict(),
+                optimizers=[opt.state_dict() for opt in optimizers],
+            )
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1)  # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (
+        time.perf_counter() - t0
+    )
+    print0(
+        f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms",
+        console=True,
+    )
+
+print0(
+    f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+    f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB",
+    console=True,
+)
+dist.destroy_process_group()
+
+====================================================================================================
+Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ]
+Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 11 09:40:05 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.163.01             Driver Version: 550.163.01     CUDA Version: 12.8     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:05:00.0 Off |                  Off |
+| N/A   33C    P0            120W /  700W |    5858MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:06:00.0 Off |                  Off |
+| N/A   34C    P0            122W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:65:00.0 Off |                  Off |
+| N/A   33C    P0            116W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:68:00.0 Off |                  Off |
+| N/A   30C    P0            115W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:85:00.0 Off |                  Off |
+| N/A   31C    P0            116W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:86:00.0 Off |                  Off |
+| N/A   33C    P0            116W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:E5:00.0 Off |                  Off |
+| N/A   33C    P0            116W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:E8:00.0 Off |                  Off |
+| N/A   33C    P0            122W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.08ms
+step:1/1670 train_time:297ms step_avg:296.71ms
+step:2/1670 train_time:316ms step_avg:158.07ms
+step:3/1670 train_time:383ms step_avg:127.80ms
+step:4/1670 train_time:472ms step_avg:118.02ms
+step:5/1670 train_time:562ms step_avg:112.35ms
+step:6/1670 train_time:652ms step_avg:108.62ms
+step:7/1670 train_time:742ms step_avg:105.94ms
+step:8/1670 train_time:833ms step_avg:104.07ms
+step:9/1670 train_time:922ms step_avg:102.48ms
+step:10/1670 train_time:1013ms step_avg:101.26ms
+step:11/1670 train_time:1103ms step_avg:100.24ms
+step:12/1670 train_time:1198ms step_avg:99.79ms
+step:13/1670 train_time:1292ms step_avg:99.37ms
+step:14/1670 train_time:1384ms step_avg:98.87ms
+step:15/1670 train_time:1475ms step_avg:98.36ms
+step:16/1670 train_time:1566ms step_avg:97.89ms
+step:17/1670 train_time:1658ms step_avg:97.51ms
+step:18/1670 train_time:1748ms step_avg:97.12ms
+step:19/1670 train_time:1838ms step_avg:96.73ms
+step:20/1670 train_time:1928ms step_avg:96.40ms
+step:21/1670 train_time:2018ms step_avg:96.12ms
+step:22/1670 train_time:2110ms
step_avg:95.90ms +step:23/1670 train_time:2202ms step_avg:95.73ms +step:24/1670 train_time:2296ms step_avg:95.65ms +step:25/1670 train_time:2388ms step_avg:95.51ms +step:26/1670 train_time:2481ms step_avg:95.43ms +step:27/1670 train_time:2571ms step_avg:95.21ms +step:28/1670 train_time:2662ms step_avg:95.07ms +step:29/1670 train_time:2752ms step_avg:94.91ms +step:30/1670 train_time:2843ms step_avg:94.76ms +step:31/1670 train_time:2934ms step_avg:94.63ms +step:32/1670 train_time:3024ms step_avg:94.51ms +step:33/1670 train_time:3115ms step_avg:94.39ms +step:34/1670 train_time:3208ms step_avg:94.36ms +step:35/1670 train_time:3300ms step_avg:94.29ms +step:36/1670 train_time:3392ms step_avg:94.24ms +step:37/1670 train_time:3484ms step_avg:94.17ms +step:38/1670 train_time:3577ms step_avg:94.13ms +step:39/1670 train_time:3668ms step_avg:94.06ms +step:40/1670 train_time:3759ms step_avg:93.98ms +step:41/1670 train_time:3850ms step_avg:93.90ms +step:42/1670 train_time:3941ms step_avg:93.84ms +step:43/1670 train_time:4032ms step_avg:93.76ms +step:44/1670 train_time:4123ms step_avg:93.71ms +step:45/1670 train_time:4214ms step_avg:93.64ms +step:46/1670 train_time:4306ms step_avg:93.61ms +step:47/1670 train_time:4398ms step_avg:93.57ms +step:48/1670 train_time:4489ms step_avg:93.53ms +step:49/1670 train_time:4580ms step_avg:93.48ms +step:50/1670 train_time:4672ms step_avg:93.44ms +step:51/1670 train_time:4768ms step_avg:93.50ms +step:52/1670 train_time:4855ms step_avg:93.36ms +step:53/1670 train_time:4945ms step_avg:93.31ms +step:54/1670 train_time:5037ms step_avg:93.28ms +step:55/1670 train_time:5128ms step_avg:93.23ms +step:56/1670 train_time:5218ms step_avg:93.18ms +step:57/1670 train_time:5310ms step_avg:93.16ms +step:58/1670 train_time:5402ms step_avg:93.15ms +step:59/1670 train_time:5494ms step_avg:93.11ms +step:60/1670 train_time:5584ms step_avg:93.07ms +step:61/1670 train_time:5676ms step_avg:93.04ms +step:62/1670 train_time:5766ms step_avg:93.00ms +step:63/1670 train_time:5858ms step_avg:92.99ms +step:64/1670 train_time:5950ms step_avg:92.96ms +step:65/1670 train_time:6041ms step_avg:92.93ms +step:66/1670 train_time:6132ms step_avg:92.90ms +step:67/1670 train_time:6222ms step_avg:92.86ms +step:68/1670 train_time:6313ms step_avg:92.84ms +step:69/1670 train_time:6404ms step_avg:92.82ms +step:70/1670 train_time:6496ms step_avg:92.80ms +step:71/1670 train_time:6586ms step_avg:92.77ms +step:72/1670 train_time:6679ms step_avg:92.76ms +step:73/1670 train_time:6771ms step_avg:92.75ms +step:74/1670 train_time:6863ms step_avg:92.74ms +step:75/1670 train_time:6954ms step_avg:92.72ms +step:76/1670 train_time:7046ms step_avg:92.71ms +step:77/1670 train_time:7138ms step_avg:92.70ms +step:78/1670 train_time:7228ms step_avg:92.67ms +step:79/1670 train_time:7319ms step_avg:92.65ms +step:80/1670 train_time:7410ms step_avg:92.62ms +step:81/1670 train_time:7502ms step_avg:92.61ms +step:82/1670 train_time:7592ms step_avg:92.58ms +step:83/1670 train_time:7684ms step_avg:92.58ms +step:84/1670 train_time:7776ms step_avg:92.58ms +step:85/1670 train_time:7867ms step_avg:92.56ms +step:86/1670 train_time:7958ms step_avg:92.54ms +step:87/1670 train_time:8049ms step_avg:92.52ms +step:88/1670 train_time:8140ms step_avg:92.50ms +step:89/1670 train_time:8231ms step_avg:92.48ms +step:90/1670 train_time:8322ms step_avg:92.47ms +step:91/1670 train_time:8414ms step_avg:92.46ms +step:92/1670 train_time:8504ms step_avg:92.44ms +step:93/1670 train_time:8595ms step_avg:92.42ms +step:94/1670 train_time:8686ms step_avg:92.40ms 
+step:95/1670 train_time:8776ms step_avg:92.38ms +step:96/1670 train_time:8868ms step_avg:92.38ms +step:97/1670 train_time:8961ms step_avg:92.39ms +step:98/1670 train_time:9053ms step_avg:92.37ms +step:99/1670 train_time:9143ms step_avg:92.36ms +step:100/1670 train_time:9234ms step_avg:92.34ms +step:101/1670 train_time:9325ms step_avg:92.32ms +step:102/1670 train_time:9416ms step_avg:92.31ms +step:103/1670 train_time:9507ms step_avg:92.30ms +step:104/1670 train_time:9597ms step_avg:92.28ms +step:105/1670 train_time:9688ms step_avg:92.27ms +step:106/1670 train_time:9780ms step_avg:92.26ms +step:107/1670 train_time:9873ms step_avg:92.27ms +step:108/1670 train_time:9964ms step_avg:92.26ms +step:109/1670 train_time:10055ms step_avg:92.25ms +step:110/1670 train_time:10146ms step_avg:92.23ms +step:111/1670 train_time:10237ms step_avg:92.23ms +step:112/1670 train_time:10328ms step_avg:92.21ms +step:113/1670 train_time:10420ms step_avg:92.21ms +step:114/1670 train_time:10510ms step_avg:92.20ms +step:115/1670 train_time:10601ms step_avg:92.18ms +step:116/1670 train_time:10692ms step_avg:92.18ms +step:117/1670 train_time:10784ms step_avg:92.17ms +step:118/1670 train_time:10877ms step_avg:92.17ms +step:119/1670 train_time:10967ms step_avg:92.16ms +step:120/1670 train_time:11059ms step_avg:92.16ms +step:121/1670 train_time:11150ms step_avg:92.15ms +step:122/1670 train_time:11242ms step_avg:92.15ms +step:123/1670 train_time:11334ms step_avg:92.14ms +step:124/1670 train_time:11425ms step_avg:92.13ms +step:125/1670 train_time:11516ms step_avg:92.13ms +step:125/1670 val_loss:4.3116 train_time:11606ms step_avg:92.85ms +step:126/1670 train_time:11629ms step_avg:92.30ms +step:127/1670 train_time:11698ms step_avg:92.11ms +step:128/1670 train_time:11800ms step_avg:92.18ms +step:129/1670 train_time:11893ms step_avg:92.19ms +step:130/1670 train_time:11983ms step_avg:92.18ms +step:131/1670 train_time:12073ms step_avg:92.16ms +step:132/1670 train_time:12164ms step_avg:92.15ms +step:133/1670 train_time:12253ms step_avg:92.13ms +step:134/1670 train_time:12343ms step_avg:92.11ms +step:135/1670 train_time:12434ms step_avg:92.10ms +step:136/1670 train_time:12525ms step_avg:92.09ms +step:137/1670 train_time:12616ms step_avg:92.09ms +step:138/1670 train_time:12708ms step_avg:92.09ms +step:139/1670 train_time:12801ms step_avg:92.10ms +step:140/1670 train_time:12893ms step_avg:92.10ms +step:141/1670 train_time:12985ms step_avg:92.09ms +step:142/1670 train_time:13075ms step_avg:92.08ms +step:143/1670 train_time:13166ms step_avg:92.07ms +step:144/1670 train_time:13255ms step_avg:92.05ms +step:145/1670 train_time:13345ms step_avg:92.04ms +step:146/1670 train_time:13437ms step_avg:92.03ms +step:147/1670 train_time:13528ms step_avg:92.03ms +step:148/1670 train_time:13619ms step_avg:92.02ms +step:149/1670 train_time:13712ms step_avg:92.03ms +step:150/1670 train_time:13805ms step_avg:92.03ms +step:151/1670 train_time:13897ms step_avg:92.03ms +step:152/1670 train_time:13989ms step_avg:92.03ms +step:153/1670 train_time:14079ms step_avg:92.02ms +step:154/1670 train_time:14169ms step_avg:92.01ms +step:155/1670 train_time:14259ms step_avg:91.99ms +step:156/1670 train_time:14349ms step_avg:91.98ms +step:157/1670 train_time:14440ms step_avg:91.97ms +step:158/1670 train_time:14530ms step_avg:91.96ms +step:159/1670 train_time:14622ms step_avg:91.96ms +step:160/1670 train_time:14713ms step_avg:91.96ms +step:161/1670 train_time:14805ms step_avg:91.96ms +step:162/1670 train_time:14897ms step_avg:91.96ms +step:163/1670 train_time:14988ms 
step_avg:91.95ms +step:164/1670 train_time:15079ms step_avg:91.94ms +step:165/1670 train_time:15169ms step_avg:91.93ms +step:166/1670 train_time:15259ms step_avg:91.92ms +step:167/1670 train_time:15349ms step_avg:91.91ms +step:168/1670 train_time:15439ms step_avg:91.90ms +step:169/1670 train_time:15529ms step_avg:91.89ms +step:170/1670 train_time:15621ms step_avg:91.89ms +step:171/1670 train_time:15712ms step_avg:91.88ms +step:172/1670 train_time:15803ms step_avg:91.88ms +step:173/1670 train_time:15895ms step_avg:91.88ms +step:174/1670 train_time:15986ms step_avg:91.88ms +step:175/1670 train_time:16077ms step_avg:91.87ms +step:176/1670 train_time:16168ms step_avg:91.86ms +step:177/1670 train_time:16258ms step_avg:91.85ms +step:178/1670 train_time:16349ms step_avg:91.85ms +step:179/1670 train_time:16439ms step_avg:91.84ms +step:180/1670 train_time:16530ms step_avg:91.83ms +step:181/1670 train_time:16620ms step_avg:91.82ms +step:182/1670 train_time:16711ms step_avg:91.82ms +step:183/1670 train_time:16801ms step_avg:91.81ms +step:184/1670 train_time:16892ms step_avg:91.81ms +step:185/1670 train_time:16983ms step_avg:91.80ms +step:186/1670 train_time:17075ms step_avg:91.80ms +step:187/1670 train_time:17167ms step_avg:91.80ms +step:188/1670 train_time:17258ms step_avg:91.80ms +step:189/1670 train_time:17347ms step_avg:91.78ms +step:190/1670 train_time:17438ms step_avg:91.78ms +step:191/1670 train_time:17528ms step_avg:91.77ms +step:192/1670 train_time:17620ms step_avg:91.77ms +step:193/1670 train_time:17711ms step_avg:91.77ms +step:194/1670 train_time:17802ms step_avg:91.76ms +step:195/1670 train_time:17893ms step_avg:91.76ms +step:196/1670 train_time:17984ms step_avg:91.75ms +step:197/1670 train_time:18075ms step_avg:91.75ms +step:198/1670 train_time:18167ms step_avg:91.75ms +step:199/1670 train_time:18258ms step_avg:91.75ms +step:200/1670 train_time:18348ms step_avg:91.74ms +step:201/1670 train_time:18439ms step_avg:91.73ms +step:202/1670 train_time:18529ms step_avg:91.73ms +step:203/1670 train_time:18620ms step_avg:91.73ms +step:204/1670 train_time:18710ms step_avg:91.72ms +step:205/1670 train_time:18801ms step_avg:91.71ms +step:206/1670 train_time:18892ms step_avg:91.71ms +step:207/1670 train_time:18984ms step_avg:91.71ms +step:208/1670 train_time:19074ms step_avg:91.70ms +step:209/1670 train_time:19166ms step_avg:91.70ms +step:210/1670 train_time:19259ms step_avg:91.71ms +step:211/1670 train_time:19350ms step_avg:91.70ms +step:212/1670 train_time:19441ms step_avg:91.70ms +step:213/1670 train_time:19690ms step_avg:92.44ms +step:214/1670 train_time:19760ms step_avg:92.34ms +step:215/1670 train_time:19850ms step_avg:92.32ms +step:216/1670 train_time:19941ms step_avg:92.32ms +step:217/1670 train_time:20030ms step_avg:92.31ms +step:218/1670 train_time:20121ms step_avg:92.30ms +step:219/1670 train_time:20210ms step_avg:92.28ms +step:220/1670 train_time:20301ms step_avg:92.28ms +step:221/1670 train_time:20391ms step_avg:92.27ms +step:222/1670 train_time:20481ms step_avg:92.26ms +step:223/1670 train_time:20574ms step_avg:92.26ms +step:224/1670 train_time:20670ms step_avg:92.28ms +step:225/1670 train_time:20764ms step_avg:92.29ms +step:226/1670 train_time:20856ms step_avg:92.28ms +step:227/1670 train_time:20946ms step_avg:92.27ms +step:228/1670 train_time:21037ms step_avg:92.27ms +step:229/1670 train_time:21127ms step_avg:92.26ms +step:230/1670 train_time:21217ms step_avg:92.25ms +step:231/1670 train_time:21306ms step_avg:92.24ms +step:232/1670 train_time:21399ms step_avg:92.24ms +step:233/1670 
train_time:21487ms step_avg:92.22ms +step:234/1670 train_time:21579ms step_avg:92.22ms +step:235/1670 train_time:21674ms step_avg:92.23ms +step:236/1670 train_time:21767ms step_avg:92.23ms +step:237/1670 train_time:21859ms step_avg:92.23ms +step:238/1670 train_time:21950ms step_avg:92.23ms +step:239/1670 train_time:22042ms step_avg:92.22ms +step:240/1670 train_time:22132ms step_avg:92.22ms +step:241/1670 train_time:22223ms step_avg:92.21ms +step:242/1670 train_time:22313ms step_avg:92.20ms +step:243/1670 train_time:22403ms step_avg:92.19ms +step:244/1670 train_time:22493ms step_avg:92.19ms +step:245/1670 train_time:22585ms step_avg:92.18ms +step:246/1670 train_time:22676ms step_avg:92.18ms +step:247/1670 train_time:22769ms step_avg:92.18ms +step:248/1670 train_time:22861ms step_avg:92.18ms +step:249/1670 train_time:22952ms step_avg:92.18ms +step:250/1670 train_time:23044ms step_avg:92.18ms +step:250/1670 val_loss:3.9690 train_time:23134ms step_avg:92.54ms +step:251/1670 train_time:23154ms step_avg:92.25ms +step:252/1670 train_time:23228ms step_avg:92.17ms +step:253/1670 train_time:23319ms step_avg:92.17ms +step:254/1670 train_time:23409ms step_avg:92.16ms +step:255/1670 train_time:23500ms step_avg:92.16ms +step:256/1670 train_time:23590ms step_avg:92.15ms +step:257/1670 train_time:23680ms step_avg:92.14ms +step:258/1670 train_time:23773ms step_avg:92.14ms +step:259/1670 train_time:23864ms step_avg:92.14ms +step:260/1670 train_time:23955ms step_avg:92.14ms +step:261/1670 train_time:24045ms step_avg:92.13ms +step:262/1670 train_time:24138ms step_avg:92.13ms +step:263/1670 train_time:24229ms step_avg:92.13ms +step:264/1670 train_time:24321ms step_avg:92.12ms +step:265/1670 train_time:24412ms step_avg:92.12ms +step:266/1670 train_time:24503ms step_avg:92.12ms +step:267/1670 train_time:24594ms step_avg:92.11ms +step:268/1670 train_time:24685ms step_avg:92.11ms +step:269/1670 train_time:24776ms step_avg:92.10ms +step:270/1670 train_time:24866ms step_avg:92.10ms +step:271/1670 train_time:24957ms step_avg:92.09ms +step:272/1670 train_time:25048ms step_avg:92.09ms +step:273/1670 train_time:25140ms step_avg:92.09ms +step:274/1670 train_time:25233ms step_avg:92.09ms +step:275/1670 train_time:25324ms step_avg:92.09ms +step:276/1670 train_time:25415ms step_avg:92.08ms +step:277/1670 train_time:25506ms step_avg:92.08ms +step:278/1670 train_time:25596ms step_avg:92.07ms +step:279/1670 train_time:25687ms step_avg:92.07ms +step:280/1670 train_time:25778ms step_avg:92.06ms +step:281/1670 train_time:25869ms step_avg:92.06ms +step:282/1670 train_time:25961ms step_avg:92.06ms +step:283/1670 train_time:26052ms step_avg:92.05ms +step:284/1670 train_time:26143ms step_avg:92.05ms +step:285/1670 train_time:26235ms step_avg:92.05ms +step:286/1670 train_time:26326ms step_avg:92.05ms +step:287/1670 train_time:26417ms step_avg:92.04ms +step:288/1670 train_time:26507ms step_avg:92.04ms +step:289/1670 train_time:26599ms step_avg:92.04ms +step:290/1670 train_time:26690ms step_avg:92.03ms +step:291/1670 train_time:26780ms step_avg:92.03ms +step:292/1670 train_time:26871ms step_avg:92.02ms +step:293/1670 train_time:26963ms step_avg:92.02ms +step:294/1670 train_time:27052ms step_avg:92.01ms +step:295/1670 train_time:27143ms step_avg:92.01ms +step:296/1670 train_time:27235ms step_avg:92.01ms +step:297/1670 train_time:27326ms step_avg:92.01ms +step:298/1670 train_time:27417ms step_avg:92.00ms +step:299/1670 train_time:27508ms step_avg:92.00ms +step:300/1670 train_time:27599ms step_avg:92.00ms +step:301/1670 train_time:27689ms 
step_avg:91.99ms +step:302/1670 train_time:27780ms step_avg:91.99ms +step:303/1670 train_time:27871ms step_avg:91.98ms +step:304/1670 train_time:27962ms step_avg:91.98ms +step:305/1670 train_time:28053ms step_avg:91.98ms +step:306/1670 train_time:28144ms step_avg:91.97ms +step:307/1670 train_time:28236ms step_avg:91.97ms +step:308/1670 train_time:28327ms step_avg:91.97ms +step:309/1670 train_time:28418ms step_avg:91.97ms +step:310/1670 train_time:28509ms step_avg:91.96ms +step:311/1670 train_time:28600ms step_avg:91.96ms +step:312/1670 train_time:28691ms step_avg:91.96ms +step:313/1670 train_time:28781ms step_avg:91.95ms +step:314/1670 train_time:28872ms step_avg:91.95ms +step:315/1670 train_time:28963ms step_avg:91.95ms +step:316/1670 train_time:29054ms step_avg:91.94ms +step:317/1670 train_time:29145ms step_avg:91.94ms +step:318/1670 train_time:29236ms step_avg:91.94ms +step:319/1670 train_time:29327ms step_avg:91.93ms +step:320/1670 train_time:29419ms step_avg:91.93ms +step:321/1670 train_time:29509ms step_avg:91.93ms +step:322/1670 train_time:29601ms step_avg:91.93ms +step:323/1670 train_time:29692ms step_avg:91.93ms +step:324/1670 train_time:29782ms step_avg:91.92ms +step:325/1670 train_time:29874ms step_avg:91.92ms +step:326/1670 train_time:29964ms step_avg:91.91ms +step:327/1670 train_time:30054ms step_avg:91.91ms +step:328/1670 train_time:30145ms step_avg:91.91ms +step:329/1670 train_time:30237ms step_avg:91.90ms +step:330/1670 train_time:30327ms step_avg:91.90ms +step:331/1670 train_time:30419ms step_avg:91.90ms +step:332/1670 train_time:30510ms step_avg:91.90ms +step:333/1670 train_time:30601ms step_avg:91.89ms +step:334/1670 train_time:30691ms step_avg:91.89ms +step:335/1670 train_time:30781ms step_avg:91.88ms +step:336/1670 train_time:30873ms step_avg:91.88ms +step:337/1670 train_time:30964ms step_avg:91.88ms +step:338/1670 train_time:31055ms step_avg:91.88ms +step:339/1670 train_time:31145ms step_avg:91.87ms +step:340/1670 train_time:31236ms step_avg:91.87ms +step:341/1670 train_time:31327ms step_avg:91.87ms +step:342/1670 train_time:31418ms step_avg:91.87ms +step:343/1670 train_time:31509ms step_avg:91.86ms +step:344/1670 train_time:31600ms step_avg:91.86ms +step:345/1670 train_time:31690ms step_avg:91.86ms +step:346/1670 train_time:31781ms step_avg:91.85ms +step:347/1670 train_time:31873ms step_avg:91.85ms +step:348/1670 train_time:31963ms step_avg:91.85ms +step:349/1670 train_time:32055ms step_avg:91.85ms +step:350/1670 train_time:32145ms step_avg:91.84ms +step:351/1670 train_time:32237ms step_avg:91.84ms +step:352/1670 train_time:32328ms step_avg:91.84ms +step:353/1670 train_time:32420ms step_avg:91.84ms +step:354/1670 train_time:32511ms step_avg:91.84ms +step:355/1670 train_time:32604ms step_avg:91.84ms +step:356/1670 train_time:32694ms step_avg:91.84ms +step:357/1670 train_time:32784ms step_avg:91.83ms +step:358/1670 train_time:32875ms step_avg:91.83ms +step:359/1670 train_time:32966ms step_avg:91.83ms +step:360/1670 train_time:33057ms step_avg:91.82ms +step:361/1670 train_time:33147ms step_avg:91.82ms +step:362/1670 train_time:33239ms step_avg:91.82ms +step:363/1670 train_time:33330ms step_avg:91.82ms +step:364/1670 train_time:33422ms step_avg:91.82ms +step:365/1670 train_time:33513ms step_avg:91.82ms +step:366/1670 train_time:33604ms step_avg:91.81ms +step:367/1670 train_time:33695ms step_avg:91.81ms +step:368/1670 train_time:33786ms step_avg:91.81ms +step:369/1670 train_time:33876ms step_avg:91.81ms +step:370/1670 train_time:33967ms step_avg:91.80ms +step:371/1670 
train_time:34058ms step_avg:91.80ms +step:372/1670 train_time:34149ms step_avg:91.80ms +step:373/1670 train_time:34240ms step_avg:91.80ms +step:374/1670 train_time:34331ms step_avg:91.79ms +step:375/1670 train_time:34423ms step_avg:91.80ms +step:375/1670 val_loss:3.8129 train_time:34514ms step_avg:92.04ms +step:376/1670 train_time:34533ms step_avg:91.84ms +step:377/1670 train_time:34608ms step_avg:91.80ms +step:378/1670 train_time:34701ms step_avg:91.80ms +step:379/1670 train_time:34792ms step_avg:91.80ms +step:380/1670 train_time:34882ms step_avg:91.79ms +step:381/1670 train_time:34972ms step_avg:91.79ms +step:382/1670 train_time:35062ms step_avg:91.79ms +step:383/1670 train_time:35153ms step_avg:91.78ms +step:384/1670 train_time:35243ms step_avg:91.78ms +step:385/1670 train_time:35335ms step_avg:91.78ms +step:386/1670 train_time:35426ms step_avg:91.78ms +step:387/1670 train_time:35518ms step_avg:91.78ms +step:388/1670 train_time:35613ms step_avg:91.79ms +step:389/1670 train_time:35706ms step_avg:91.79ms +step:390/1670 train_time:35796ms step_avg:91.78ms +step:391/1670 train_time:35886ms step_avg:91.78ms +step:392/1670 train_time:35976ms step_avg:91.78ms +step:393/1670 train_time:36067ms step_avg:91.77ms +step:394/1670 train_time:36157ms step_avg:91.77ms +step:395/1670 train_time:36248ms step_avg:91.77ms +step:396/1670 train_time:36338ms step_avg:91.76ms +step:397/1670 train_time:36431ms step_avg:91.77ms +step:398/1670 train_time:36522ms step_avg:91.76ms +step:399/1670 train_time:36616ms step_avg:91.77ms +step:400/1670 train_time:36708ms step_avg:91.77ms +step:401/1670 train_time:36799ms step_avg:91.77ms +step:402/1670 train_time:36889ms step_avg:91.76ms +step:403/1670 train_time:36979ms step_avg:91.76ms +step:404/1670 train_time:37070ms step_avg:91.76ms +step:405/1670 train_time:37160ms step_avg:91.75ms +step:406/1670 train_time:37251ms step_avg:91.75ms +step:407/1670 train_time:37341ms step_avg:91.75ms +step:408/1670 train_time:37433ms step_avg:91.75ms +step:409/1670 train_time:37524ms step_avg:91.74ms +step:410/1670 train_time:37617ms step_avg:91.75ms +step:411/1670 train_time:37708ms step_avg:91.75ms +step:412/1670 train_time:37800ms step_avg:91.75ms +step:413/1670 train_time:37890ms step_avg:91.74ms +step:414/1670 train_time:37981ms step_avg:91.74ms +step:415/1670 train_time:38071ms step_avg:91.74ms +step:416/1670 train_time:38162ms step_avg:91.74ms +step:417/1670 train_time:38252ms step_avg:91.73ms +step:418/1670 train_time:38343ms step_avg:91.73ms +step:419/1670 train_time:38434ms step_avg:91.73ms +step:420/1670 train_time:38526ms step_avg:91.73ms +step:421/1670 train_time:38618ms step_avg:91.73ms +step:422/1670 train_time:38710ms step_avg:91.73ms +step:423/1670 train_time:38802ms step_avg:91.73ms +step:424/1670 train_time:38893ms step_avg:91.73ms +step:425/1670 train_time:39144ms step_avg:92.10ms +step:426/1670 train_time:39212ms step_avg:92.05ms +step:427/1670 train_time:39302ms step_avg:92.04ms +step:428/1670 train_time:39392ms step_avg:92.04ms +step:429/1670 train_time:39482ms step_avg:92.03ms +step:430/1670 train_time:39572ms step_avg:92.03ms +step:431/1670 train_time:39662ms step_avg:92.02ms +step:432/1670 train_time:39752ms step_avg:92.02ms +step:433/1670 train_time:39843ms step_avg:92.02ms +step:434/1670 train_time:39934ms step_avg:92.01ms +step:435/1670 train_time:40027ms step_avg:92.02ms +step:436/1670 train_time:40122ms step_avg:92.02ms +step:437/1670 train_time:40215ms step_avg:92.02ms +step:438/1670 train_time:40306ms step_avg:92.02ms +step:439/1670 train_time:40397ms 
step_avg:92.02ms +step:440/1670 train_time:40489ms step_avg:92.02ms +step:441/1670 train_time:40579ms step_avg:92.02ms +step:442/1670 train_time:40669ms step_avg:92.01ms +step:443/1670 train_time:40759ms step_avg:92.01ms +step:444/1670 train_time:40850ms step_avg:92.00ms +step:445/1670 train_time:40940ms step_avg:92.00ms +step:446/1670 train_time:41032ms step_avg:92.00ms +step:447/1670 train_time:41124ms step_avg:92.00ms +step:448/1670 train_time:41218ms step_avg:92.00ms +step:449/1670 train_time:41310ms step_avg:92.00ms +step:450/1670 train_time:41401ms step_avg:92.00ms +step:451/1670 train_time:41492ms step_avg:92.00ms +step:452/1670 train_time:41582ms step_avg:92.00ms +step:453/1670 train_time:41673ms step_avg:91.99ms +step:454/1670 train_time:41763ms step_avg:91.99ms +step:455/1670 train_time:41853ms step_avg:91.98ms +step:456/1670 train_time:41944ms step_avg:91.98ms +step:457/1670 train_time:42036ms step_avg:91.98ms +step:458/1670 train_time:42128ms step_avg:91.98ms +step:459/1670 train_time:42220ms step_avg:91.98ms +step:460/1670 train_time:42312ms step_avg:91.98ms +step:461/1670 train_time:42402ms step_avg:91.98ms +step:462/1670 train_time:42493ms step_avg:91.98ms +step:463/1670 train_time:42584ms step_avg:91.97ms +step:464/1670 train_time:42675ms step_avg:91.97ms +step:465/1670 train_time:42765ms step_avg:91.97ms +step:466/1670 train_time:42856ms step_avg:91.97ms +step:467/1670 train_time:42947ms step_avg:91.96ms +step:468/1670 train_time:43039ms step_avg:91.96ms +step:469/1670 train_time:43130ms step_avg:91.96ms +step:470/1670 train_time:43221ms step_avg:91.96ms +step:471/1670 train_time:43313ms step_avg:91.96ms +step:472/1670 train_time:43404ms step_avg:91.96ms +step:473/1670 train_time:43496ms step_avg:91.96ms +step:474/1670 train_time:43586ms step_avg:91.95ms +step:475/1670 train_time:43676ms step_avg:91.95ms +step:476/1670 train_time:43767ms step_avg:91.95ms +step:477/1670 train_time:43857ms step_avg:91.94ms +step:478/1670 train_time:43947ms step_avg:91.94ms +step:479/1670 train_time:44038ms step_avg:91.94ms +step:480/1670 train_time:44129ms step_avg:91.94ms +step:481/1670 train_time:44220ms step_avg:91.93ms +step:482/1670 train_time:44311ms step_avg:91.93ms +step:483/1670 train_time:44402ms step_avg:91.93ms +step:484/1670 train_time:44494ms step_avg:91.93ms +step:485/1670 train_time:44585ms step_avg:91.93ms +step:486/1670 train_time:44676ms step_avg:91.93ms +step:487/1670 train_time:44767ms step_avg:91.92ms +step:488/1670 train_time:44857ms step_avg:91.92ms +step:489/1670 train_time:44948ms step_avg:91.92ms +step:490/1670 train_time:45038ms step_avg:91.91ms +step:491/1670 train_time:45129ms step_avg:91.91ms +step:492/1670 train_time:45220ms step_avg:91.91ms +step:493/1670 train_time:45312ms step_avg:91.91ms +step:494/1670 train_time:45403ms step_avg:91.91ms +step:495/1670 train_time:45494ms step_avg:91.91ms +step:496/1670 train_time:45585ms step_avg:91.91ms +step:497/1670 train_time:45677ms step_avg:91.90ms +step:498/1670 train_time:45768ms step_avg:91.90ms +step:499/1670 train_time:45858ms step_avg:91.90ms +step:500/1670 train_time:45948ms step_avg:91.90ms +step:500/1670 val_loss:3.7130 train_time:46038ms step_avg:92.08ms +step:501/1670 train_time:46058ms step_avg:91.93ms +step:502/1670 train_time:46132ms step_avg:91.90ms +step:503/1670 train_time:46224ms step_avg:91.90ms +step:504/1670 train_time:46314ms step_avg:91.89ms +step:505/1670 train_time:46404ms step_avg:91.89ms +step:506/1670 train_time:46494ms step_avg:91.88ms +step:507/1670 train_time:46584ms step_avg:91.88ms 
+step:508/1670 train_time:46673ms step_avg:91.88ms +step:509/1670 train_time:46764ms step_avg:91.87ms +step:510/1670 train_time:46856ms step_avg:91.87ms +step:511/1670 train_time:46947ms step_avg:91.87ms +step:512/1670 train_time:47041ms step_avg:91.88ms +step:513/1670 train_time:47132ms step_avg:91.88ms +step:514/1670 train_time:47224ms step_avg:91.88ms +step:515/1670 train_time:47318ms step_avg:91.88ms +step:516/1670 train_time:47408ms step_avg:91.88ms +step:517/1670 train_time:47498ms step_avg:91.87ms +step:518/1670 train_time:47589ms step_avg:91.87ms +step:519/1670 train_time:47680ms step_avg:91.87ms +step:520/1670 train_time:47770ms step_avg:91.87ms +step:521/1670 train_time:47861ms step_avg:91.86ms +step:522/1670 train_time:47952ms step_avg:91.86ms +step:523/1670 train_time:48044ms step_avg:91.86ms +step:524/1670 train_time:48136ms step_avg:91.86ms +step:525/1670 train_time:48227ms step_avg:91.86ms +step:526/1670 train_time:48318ms step_avg:91.86ms +step:527/1670 train_time:48409ms step_avg:91.86ms +step:528/1670 train_time:48500ms step_avg:91.86ms +step:529/1670 train_time:48590ms step_avg:91.85ms +step:530/1670 train_time:48682ms step_avg:91.85ms +step:531/1670 train_time:48772ms step_avg:91.85ms +step:532/1670 train_time:48864ms step_avg:91.85ms +step:533/1670 train_time:48955ms step_avg:91.85ms +step:534/1670 train_time:49047ms step_avg:91.85ms +step:535/1670 train_time:49139ms step_avg:91.85ms +step:536/1670 train_time:49231ms step_avg:91.85ms +step:537/1670 train_time:49322ms step_avg:91.85ms +step:538/1670 train_time:49413ms step_avg:91.85ms +step:539/1670 train_time:49504ms step_avg:91.84ms +step:540/1670 train_time:49594ms step_avg:91.84ms +step:541/1670 train_time:49684ms step_avg:91.84ms +step:542/1670 train_time:49776ms step_avg:91.84ms +step:543/1670 train_time:49868ms step_avg:91.84ms +step:544/1670 train_time:49960ms step_avg:91.84ms +step:545/1670 train_time:50050ms step_avg:91.84ms +step:546/1670 train_time:50142ms step_avg:91.83ms +step:547/1670 train_time:50233ms step_avg:91.83ms +step:548/1670 train_time:50324ms step_avg:91.83ms +step:549/1670 train_time:50416ms step_avg:91.83ms +step:550/1670 train_time:50507ms step_avg:91.83ms +step:551/1670 train_time:50598ms step_avg:91.83ms +step:552/1670 train_time:50688ms step_avg:91.83ms +step:553/1670 train_time:50778ms step_avg:91.82ms +step:554/1670 train_time:50869ms step_avg:91.82ms +step:555/1670 train_time:50961ms step_avg:91.82ms +step:556/1670 train_time:51052ms step_avg:91.82ms +step:557/1670 train_time:51144ms step_avg:91.82ms +step:558/1670 train_time:51427ms step_avg:92.16ms +step:559/1670 train_time:51500ms step_avg:92.13ms +step:560/1670 train_time:51591ms step_avg:92.13ms +step:561/1670 train_time:51683ms step_avg:92.13ms +step:562/1670 train_time:51774ms step_avg:92.12ms +step:563/1670 train_time:51865ms step_avg:92.12ms +step:564/1670 train_time:51956ms step_avg:92.12ms +step:565/1670 train_time:52047ms step_avg:92.12ms +step:566/1670 train_time:52138ms step_avg:92.12ms +step:567/1670 train_time:52229ms step_avg:92.11ms +step:568/1670 train_time:52324ms step_avg:92.12ms +step:569/1670 train_time:52420ms step_avg:92.13ms +step:570/1670 train_time:52514ms step_avg:92.13ms +step:571/1670 train_time:52606ms step_avg:92.13ms +step:572/1670 train_time:52698ms step_avg:92.13ms +step:573/1670 train_time:52789ms step_avg:92.13ms +step:574/1670 train_time:52881ms step_avg:92.13ms +step:575/1670 train_time:52972ms step_avg:92.12ms +step:576/1670 train_time:53063ms step_avg:92.12ms +step:577/1670 train_time:53156ms 
step_avg:92.12ms +step:578/1670 train_time:53248ms step_avg:92.13ms +step:579/1670 train_time:53342ms step_avg:92.13ms +step:580/1670 train_time:53435ms step_avg:92.13ms +step:581/1670 train_time:53528ms step_avg:92.13ms +step:582/1670 train_time:53620ms step_avg:92.13ms +step:583/1670 train_time:53712ms step_avg:92.13ms +step:584/1670 train_time:53804ms step_avg:92.13ms +step:585/1670 train_time:53896ms step_avg:92.13ms +step:586/1670 train_time:53988ms step_avg:92.13ms +step:587/1670 train_time:54080ms step_avg:92.13ms +step:588/1670 train_time:54172ms step_avg:92.13ms +step:589/1670 train_time:54265ms step_avg:92.13ms +step:590/1670 train_time:54357ms step_avg:92.13ms +step:591/1670 train_time:54450ms step_avg:92.13ms +step:592/1670 train_time:54544ms step_avg:92.13ms +step:593/1670 train_time:54636ms step_avg:92.13ms +step:594/1670 train_time:54728ms step_avg:92.13ms +step:595/1670 train_time:54819ms step_avg:92.13ms +step:596/1670 train_time:54910ms step_avg:92.13ms +step:597/1670 train_time:55002ms step_avg:92.13ms +step:598/1670 train_time:55093ms step_avg:92.13ms +step:599/1670 train_time:55185ms step_avg:92.13ms +step:600/1670 train_time:55277ms step_avg:92.13ms +step:601/1670 train_time:55369ms step_avg:92.13ms +step:602/1670 train_time:55463ms step_avg:92.13ms +step:603/1670 train_time:55556ms step_avg:92.13ms +step:604/1670 train_time:55648ms step_avg:92.13ms +step:605/1670 train_time:55741ms step_avg:92.13ms +step:606/1670 train_time:55832ms step_avg:92.13ms +step:607/1670 train_time:55927ms step_avg:92.14ms +step:608/1670 train_time:56019ms step_avg:92.14ms +step:609/1670 train_time:56110ms step_avg:92.13ms +step:610/1670 train_time:56202ms step_avg:92.14ms +step:611/1670 train_time:56294ms step_avg:92.13ms +step:612/1670 train_time:56387ms step_avg:92.14ms +step:613/1670 train_time:56481ms step_avg:92.14ms +step:614/1670 train_time:56573ms step_avg:92.14ms +step:615/1670 train_time:56667ms step_avg:92.14ms +step:616/1670 train_time:56759ms step_avg:92.14ms +step:617/1670 train_time:56851ms step_avg:92.14ms +step:618/1670 train_time:56943ms step_avg:92.14ms +step:619/1670 train_time:57034ms step_avg:92.14ms +step:620/1670 train_time:57127ms step_avg:92.14ms +step:621/1670 train_time:57219ms step_avg:92.14ms +step:622/1670 train_time:57311ms step_avg:92.14ms +step:623/1670 train_time:57404ms step_avg:92.14ms +step:624/1670 train_time:57496ms step_avg:92.14ms +step:625/1670 train_time:57589ms step_avg:92.14ms +step:625/1670 val_loss:3.6120 train_time:57683ms step_avg:92.29ms +step:626/1670 train_time:57703ms step_avg:92.18ms +step:627/1670 train_time:57783ms step_avg:92.16ms +step:628/1670 train_time:57883ms step_avg:92.17ms +step:629/1670 train_time:57976ms step_avg:92.17ms +step:630/1670 train_time:58068ms step_avg:92.17ms +step:631/1670 train_time:58159ms step_avg:92.17ms +step:632/1670 train_time:58250ms step_avg:92.17ms +step:633/1670 train_time:58341ms step_avg:92.17ms +step:634/1670 train_time:58432ms step_avg:92.16ms +step:635/1670 train_time:58523ms step_avg:92.16ms +step:636/1670 train_time:58615ms step_avg:92.16ms +step:637/1670 train_time:58707ms step_avg:92.16ms +step:638/1670 train_time:58803ms step_avg:92.17ms +step:639/1670 train_time:59038ms step_avg:92.39ms +step:640/1670 train_time:59110ms step_avg:92.36ms +step:641/1670 train_time:59201ms step_avg:92.36ms +step:642/1670 train_time:59292ms step_avg:92.36ms +step:643/1670 train_time:59383ms step_avg:92.35ms +step:644/1670 train_time:59475ms step_avg:92.35ms +step:645/1670 train_time:59566ms step_avg:92.35ms 
+step:646/1670 train_time:59657ms step_avg:92.35ms +step:647/1670 train_time:59747ms step_avg:92.35ms +step:648/1670 train_time:59838ms step_avg:92.34ms +step:649/1670 train_time:59936ms step_avg:92.35ms +step:650/1670 train_time:60033ms step_avg:92.36ms +step:651/1670 train_time:60126ms step_avg:92.36ms +step:652/1670 train_time:60218ms step_avg:92.36ms +step:653/1670 train_time:60310ms step_avg:92.36ms +step:654/1670 train_time:60401ms step_avg:92.36ms +step:655/1670 train_time:60493ms step_avg:92.36ms +step:656/1670 train_time:60584ms step_avg:92.35ms +step:657/1670 train_time:60675ms step_avg:92.35ms +step:658/1670 train_time:60766ms step_avg:92.35ms +step:659/1670 train_time:60859ms step_avg:92.35ms +step:660/1670 train_time:60953ms step_avg:92.35ms +step:661/1670 train_time:61048ms step_avg:92.36ms +step:662/1670 train_time:61140ms step_avg:92.36ms +step:663/1670 train_time:61233ms step_avg:92.36ms +step:664/1670 train_time:61326ms step_avg:92.36ms +step:665/1670 train_time:61417ms step_avg:92.36ms +step:666/1670 train_time:61509ms step_avg:92.36ms +step:667/1670 train_time:61602ms step_avg:92.36ms +step:668/1670 train_time:61694ms step_avg:92.36ms +step:669/1670 train_time:61785ms step_avg:92.35ms +step:670/1670 train_time:61877ms step_avg:92.35ms +step:671/1670 train_time:61971ms step_avg:92.36ms +step:672/1670 train_time:62065ms step_avg:92.36ms +step:673/1670 train_time:62158ms step_avg:92.36ms +step:674/1670 train_time:62251ms step_avg:92.36ms +step:675/1670 train_time:62344ms step_avg:92.36ms +step:676/1670 train_time:62436ms step_avg:92.36ms +step:677/1670 train_time:62528ms step_avg:92.36ms +step:678/1670 train_time:62618ms step_avg:92.36ms +step:679/1670 train_time:62710ms step_avg:92.36ms +step:680/1670 train_time:62802ms step_avg:92.36ms +step:681/1670 train_time:62894ms step_avg:92.36ms +step:682/1670 train_time:62987ms step_avg:92.36ms +step:683/1670 train_time:63080ms step_avg:92.36ms +step:684/1670 train_time:63172ms step_avg:92.36ms +step:685/1670 train_time:63265ms step_avg:92.36ms +step:686/1670 train_time:63356ms step_avg:92.36ms +step:687/1670 train_time:63449ms step_avg:92.36ms +step:688/1670 train_time:63541ms step_avg:92.36ms +step:689/1670 train_time:63632ms step_avg:92.35ms +step:690/1670 train_time:63723ms step_avg:92.35ms +step:691/1670 train_time:63816ms step_avg:92.35ms +step:692/1670 train_time:63909ms step_avg:92.35ms +step:693/1670 train_time:64003ms step_avg:92.36ms +step:694/1670 train_time:64096ms step_avg:92.36ms +step:695/1670 train_time:64188ms step_avg:92.36ms +step:696/1670 train_time:64281ms step_avg:92.36ms +step:697/1670 train_time:64375ms step_avg:92.36ms +step:698/1670 train_time:64464ms step_avg:92.36ms +step:699/1670 train_time:64555ms step_avg:92.35ms +step:700/1670 train_time:64648ms step_avg:92.35ms +step:701/1670 train_time:64739ms step_avg:92.35ms +step:702/1670 train_time:64832ms step_avg:92.35ms +step:703/1670 train_time:64924ms step_avg:92.35ms +step:704/1670 train_time:65016ms step_avg:92.35ms +step:705/1670 train_time:65109ms step_avg:92.35ms +step:706/1670 train_time:65201ms step_avg:92.35ms +step:707/1670 train_time:65294ms step_avg:92.35ms +step:708/1670 train_time:65387ms step_avg:92.35ms +step:709/1670 train_time:65479ms step_avg:92.35ms +step:710/1670 train_time:65573ms step_avg:92.36ms +step:711/1670 train_time:65666ms step_avg:92.36ms +step:712/1670 train_time:65757ms step_avg:92.36ms +step:713/1670 train_time:65849ms step_avg:92.36ms +step:714/1670 train_time:65941ms step_avg:92.35ms +step:715/1670 train_time:66033ms 
step_avg:92.35ms +step:716/1670 train_time:66125ms step_avg:92.35ms +step:717/1670 train_time:66218ms step_avg:92.35ms +step:718/1670 train_time:66312ms step_avg:92.36ms +step:719/1670 train_time:66404ms step_avg:92.36ms +step:720/1670 train_time:66495ms step_avg:92.35ms +step:721/1670 train_time:66588ms step_avg:92.35ms +step:722/1670 train_time:66680ms step_avg:92.35ms +step:723/1670 train_time:66772ms step_avg:92.35ms +step:724/1670 train_time:66864ms step_avg:92.35ms +step:725/1670 train_time:66956ms step_avg:92.35ms +step:726/1670 train_time:67048ms step_avg:92.35ms +step:727/1670 train_time:67140ms step_avg:92.35ms +step:728/1670 train_time:67233ms step_avg:92.35ms +step:729/1670 train_time:67325ms step_avg:92.35ms +step:730/1670 train_time:67417ms step_avg:92.35ms +step:731/1670 train_time:67509ms step_avg:92.35ms +step:732/1670 train_time:67601ms step_avg:92.35ms +step:733/1670 train_time:67693ms step_avg:92.35ms +step:734/1670 train_time:67786ms step_avg:92.35ms +step:735/1670 train_time:67878ms step_avg:92.35ms +step:736/1670 train_time:67971ms step_avg:92.35ms +step:737/1670 train_time:68063ms step_avg:92.35ms +step:738/1670 train_time:68155ms step_avg:92.35ms +step:739/1670 train_time:68247ms step_avg:92.35ms +step:740/1670 train_time:68339ms step_avg:92.35ms +step:741/1670 train_time:68432ms step_avg:92.35ms +step:742/1670 train_time:68525ms step_avg:92.35ms +step:743/1670 train_time:68617ms step_avg:92.35ms +step:744/1670 train_time:68709ms step_avg:92.35ms +step:745/1670 train_time:68801ms step_avg:92.35ms +step:746/1670 train_time:68895ms step_avg:92.35ms +step:747/1670 train_time:68988ms step_avg:92.35ms +step:748/1670 train_time:69080ms step_avg:92.35ms +step:749/1670 train_time:69172ms step_avg:92.35ms +step:750/1670 train_time:69264ms step_avg:92.35ms +step:750/1670 val_loss:3.5608 train_time:69355ms step_avg:92.47ms +step:751/1670 train_time:69375ms step_avg:92.38ms +step:752/1670 train_time:69449ms step_avg:92.35ms +step:753/1670 train_time:69542ms step_avg:92.35ms +step:754/1670 train_time:69634ms step_avg:92.35ms +step:755/1670 train_time:69726ms step_avg:92.35ms +step:756/1670 train_time:69818ms step_avg:92.35ms +step:757/1670 train_time:69909ms step_avg:92.35ms +step:758/1670 train_time:70001ms step_avg:92.35ms +step:759/1670 train_time:70094ms step_avg:92.35ms +step:760/1670 train_time:70189ms step_avg:92.35ms +step:761/1670 train_time:70281ms step_avg:92.35ms +step:762/1670 train_time:70374ms step_avg:92.35ms +step:763/1670 train_time:70467ms step_avg:92.36ms +step:764/1670 train_time:70560ms step_avg:92.36ms +step:765/1670 train_time:70652ms step_avg:92.36ms +step:766/1670 train_time:70745ms step_avg:92.36ms +step:767/1670 train_time:70836ms step_avg:92.35ms +step:768/1670 train_time:70929ms step_avg:92.36ms +step:769/1670 train_time:71021ms step_avg:92.36ms +step:770/1670 train_time:71113ms step_avg:92.35ms +step:771/1670 train_time:71205ms step_avg:92.35ms +step:772/1670 train_time:71298ms step_avg:92.36ms +step:773/1670 train_time:71391ms step_avg:92.36ms +step:774/1670 train_time:71486ms step_avg:92.36ms +step:775/1670 train_time:71577ms step_avg:92.36ms +step:776/1670 train_time:71671ms step_avg:92.36ms +step:777/1670 train_time:71763ms step_avg:92.36ms +step:778/1670 train_time:71855ms step_avg:92.36ms +step:779/1670 train_time:71949ms step_avg:92.36ms +step:780/1670 train_time:72041ms step_avg:92.36ms +step:781/1670 train_time:72133ms step_avg:92.36ms +step:782/1670 train_time:72226ms step_avg:92.36ms +step:783/1670 train_time:72318ms step_avg:92.36ms 
+step:784/1670 train_time:72411ms step_avg:92.36ms +step:785/1670 train_time:72505ms step_avg:92.36ms +step:786/1670 train_time:72598ms step_avg:92.36ms +step:787/1670 train_time:72691ms step_avg:92.36ms +step:788/1670 train_time:72783ms step_avg:92.36ms +step:789/1670 train_time:72875ms step_avg:92.36ms +step:790/1670 train_time:72967ms step_avg:92.36ms +step:791/1670 train_time:73060ms step_avg:92.36ms +step:792/1670 train_time:73152ms step_avg:92.36ms +step:793/1670 train_time:73245ms step_avg:92.36ms +step:794/1670 train_time:73337ms step_avg:92.36ms +step:795/1670 train_time:73430ms step_avg:92.36ms +step:796/1670 train_time:73523ms step_avg:92.37ms +step:797/1670 train_time:73615ms step_avg:92.37ms +step:798/1670 train_time:73708ms step_avg:92.37ms +step:799/1670 train_time:73799ms step_avg:92.36ms +step:800/1670 train_time:73892ms step_avg:92.36ms +step:801/1670 train_time:73984ms step_avg:92.36ms +step:802/1670 train_time:74075ms step_avg:92.36ms +step:803/1670 train_time:74168ms step_avg:92.36ms +step:804/1670 train_time:74260ms step_avg:92.36ms +step:805/1670 train_time:74352ms step_avg:92.36ms +step:806/1670 train_time:74445ms step_avg:92.36ms +step:807/1670 train_time:74537ms step_avg:92.36ms +step:808/1670 train_time:74631ms step_avg:92.37ms +step:809/1670 train_time:74724ms step_avg:92.37ms +step:810/1670 train_time:74815ms step_avg:92.36ms +step:811/1670 train_time:74907ms step_avg:92.36ms +step:812/1670 train_time:74999ms step_avg:92.36ms +step:813/1670 train_time:75092ms step_avg:92.36ms +step:814/1670 train_time:75184ms step_avg:92.36ms +step:815/1670 train_time:75276ms step_avg:92.36ms +step:816/1670 train_time:75370ms step_avg:92.37ms +step:817/1670 train_time:75463ms step_avg:92.37ms +step:818/1670 train_time:75554ms step_avg:92.36ms +step:819/1670 train_time:75648ms step_avg:92.37ms +step:820/1670 train_time:75740ms step_avg:92.37ms +step:821/1670 train_time:75832ms step_avg:92.37ms +step:822/1670 train_time:75924ms step_avg:92.37ms +step:823/1670 train_time:76016ms step_avg:92.36ms +step:824/1670 train_time:76109ms step_avg:92.37ms +step:825/1670 train_time:76202ms step_avg:92.37ms +step:826/1670 train_time:76293ms step_avg:92.36ms +step:827/1670 train_time:76385ms step_avg:92.36ms +step:828/1670 train_time:76478ms step_avg:92.36ms +step:829/1670 train_time:76571ms step_avg:92.37ms +step:830/1670 train_time:76665ms step_avg:92.37ms +step:831/1670 train_time:76757ms step_avg:92.37ms +step:832/1670 train_time:76849ms step_avg:92.37ms +step:833/1670 train_time:76941ms step_avg:92.37ms +step:834/1670 train_time:77034ms step_avg:92.37ms +step:835/1670 train_time:77125ms step_avg:92.37ms +step:836/1670 train_time:77217ms step_avg:92.37ms +step:837/1670 train_time:77311ms step_avg:92.37ms +step:838/1670 train_time:77403ms step_avg:92.37ms +step:839/1670 train_time:77495ms step_avg:92.37ms +step:840/1670 train_time:77588ms step_avg:92.37ms +step:841/1670 train_time:77681ms step_avg:92.37ms +step:842/1670 train_time:77773ms step_avg:92.37ms +step:843/1670 train_time:77865ms step_avg:92.37ms +step:844/1670 train_time:77957ms step_avg:92.37ms +step:845/1670 train_time:78049ms step_avg:92.37ms +step:846/1670 train_time:78142ms step_avg:92.37ms +step:847/1670 train_time:78234ms step_avg:92.37ms +step:848/1670 train_time:78327ms step_avg:92.37ms +step:849/1670 train_time:78420ms step_avg:92.37ms +step:850/1670 train_time:78513ms step_avg:92.37ms +step:851/1670 train_time:78767ms step_avg:92.56ms +step:852/1670 train_time:78837ms step_avg:92.53ms +step:853/1670 train_time:78928ms 
step_avg:92.53ms +step:854/1670 train_time:79019ms step_avg:92.53ms +step:855/1670 train_time:79110ms step_avg:92.53ms +step:856/1670 train_time:79201ms step_avg:92.52ms +step:857/1670 train_time:79292ms step_avg:92.52ms +step:858/1670 train_time:79384ms step_avg:92.52ms +step:859/1670 train_time:79474ms step_avg:92.52ms +step:860/1670 train_time:79565ms step_avg:92.52ms +step:861/1670 train_time:79661ms step_avg:92.52ms +step:862/1670 train_time:79760ms step_avg:92.53ms +step:863/1670 train_time:79853ms step_avg:92.53ms +step:864/1670 train_time:79946ms step_avg:92.53ms +step:865/1670 train_time:80037ms step_avg:92.53ms +step:866/1670 train_time:80129ms step_avg:92.53ms +step:867/1670 train_time:80221ms step_avg:92.53ms +step:868/1670 train_time:80312ms step_avg:92.53ms +step:869/1670 train_time:80403ms step_avg:92.52ms +step:870/1670 train_time:80494ms step_avg:92.52ms +step:871/1670 train_time:80587ms step_avg:92.52ms +step:872/1670 train_time:80680ms step_avg:92.52ms +step:873/1670 train_time:80775ms step_avg:92.53ms +step:874/1670 train_time:80869ms step_avg:92.53ms +step:875/1670 train_time:80961ms step_avg:92.53ms +step:875/1670 val_loss:3.5160 train_time:81052ms step_avg:92.63ms +step:876/1670 train_time:81072ms step_avg:92.55ms +step:877/1670 train_time:81145ms step_avg:92.53ms +step:878/1670 train_time:81237ms step_avg:92.53ms +step:879/1670 train_time:81329ms step_avg:92.52ms +step:880/1670 train_time:81422ms step_avg:92.52ms +step:881/1670 train_time:81512ms step_avg:92.52ms +step:882/1670 train_time:81603ms step_avg:92.52ms +step:883/1670 train_time:81695ms step_avg:92.52ms +step:884/1670 train_time:81787ms step_avg:92.52ms +step:885/1670 train_time:81880ms step_avg:92.52ms +step:886/1670 train_time:81974ms step_avg:92.52ms +step:887/1670 train_time:82068ms step_avg:92.52ms +step:888/1670 train_time:82162ms step_avg:92.52ms +step:889/1670 train_time:82254ms step_avg:92.52ms +step:890/1670 train_time:82346ms step_avg:92.52ms +step:891/1670 train_time:82437ms step_avg:92.52ms +step:892/1670 train_time:82529ms step_avg:92.52ms +step:893/1670 train_time:82621ms step_avg:92.52ms +step:894/1670 train_time:82713ms step_avg:92.52ms +step:895/1670 train_time:82805ms step_avg:92.52ms +step:896/1670 train_time:82897ms step_avg:92.52ms +step:897/1670 train_time:82990ms step_avg:92.52ms +step:898/1670 train_time:83086ms step_avg:92.52ms +step:899/1670 train_time:83178ms step_avg:92.52ms +step:900/1670 train_time:83270ms step_avg:92.52ms +step:901/1670 train_time:83363ms step_avg:92.52ms +step:902/1670 train_time:83455ms step_avg:92.52ms +step:903/1670 train_time:83546ms step_avg:92.52ms +step:904/1670 train_time:83637ms step_avg:92.52ms +step:905/1670 train_time:83730ms step_avg:92.52ms +step:906/1670 train_time:83823ms step_avg:92.52ms +step:907/1670 train_time:83914ms step_avg:92.52ms +step:908/1670 train_time:84007ms step_avg:92.52ms +step:909/1670 train_time:84100ms step_avg:92.52ms +step:910/1670 train_time:84194ms step_avg:92.52ms +step:911/1670 train_time:84286ms step_avg:92.52ms +step:912/1670 train_time:84378ms step_avg:92.52ms +step:913/1670 train_time:84470ms step_avg:92.52ms +step:914/1670 train_time:84562ms step_avg:92.52ms +step:915/1670 train_time:84654ms step_avg:92.52ms +step:916/1670 train_time:84746ms step_avg:92.52ms +step:917/1670 train_time:84838ms step_avg:92.52ms +step:918/1670 train_time:84931ms step_avg:92.52ms +step:919/1670 train_time:85024ms step_avg:92.52ms +step:920/1670 train_time:85116ms step_avg:92.52ms +step:921/1670 train_time:85209ms step_avg:92.52ms 
+step:922/1670 train_time:85302ms step_avg:92.52ms +step:923/1670 train_time:85395ms step_avg:92.52ms +step:924/1670 train_time:85487ms step_avg:92.52ms +step:925/1670 train_time:85579ms step_avg:92.52ms +step:926/1670 train_time:85671ms step_avg:92.52ms +step:927/1670 train_time:85763ms step_avg:92.52ms +step:928/1670 train_time:85855ms step_avg:92.52ms +step:929/1670 train_time:85948ms step_avg:92.52ms +step:930/1670 train_time:86040ms step_avg:92.52ms +step:931/1670 train_time:86132ms step_avg:92.52ms +step:932/1670 train_time:86225ms step_avg:92.52ms +step:933/1670 train_time:86317ms step_avg:92.52ms +step:934/1670 train_time:86410ms step_avg:92.52ms +step:935/1670 train_time:86503ms step_avg:92.52ms +step:936/1670 train_time:86595ms step_avg:92.52ms +step:937/1670 train_time:86687ms step_avg:92.52ms +step:938/1670 train_time:86778ms step_avg:92.51ms +step:939/1670 train_time:86872ms step_avg:92.52ms +step:940/1670 train_time:86964ms step_avg:92.52ms +step:941/1670 train_time:87056ms step_avg:92.51ms +step:942/1670 train_time:87149ms step_avg:92.51ms +step:943/1670 train_time:87241ms step_avg:92.51ms +step:944/1670 train_time:87334ms step_avg:92.51ms +step:945/1670 train_time:87427ms step_avg:92.52ms +step:946/1670 train_time:87519ms step_avg:92.52ms +step:947/1670 train_time:87612ms step_avg:92.52ms +step:948/1670 train_time:87704ms step_avg:92.51ms +step:949/1670 train_time:87797ms step_avg:92.51ms +step:950/1670 train_time:87889ms step_avg:92.51ms +step:951/1670 train_time:87981ms step_avg:92.51ms +step:952/1670 train_time:88074ms step_avg:92.51ms +step:953/1670 train_time:88166ms step_avg:92.51ms +step:954/1670 train_time:88258ms step_avg:92.51ms +step:955/1670 train_time:88351ms step_avg:92.51ms +step:956/1670 train_time:88443ms step_avg:92.51ms +step:957/1670 train_time:88535ms step_avg:92.51ms +step:958/1670 train_time:88628ms step_avg:92.51ms +step:959/1670 train_time:88720ms step_avg:92.51ms +step:960/1670 train_time:88813ms step_avg:92.51ms +step:961/1670 train_time:88905ms step_avg:92.51ms +step:962/1670 train_time:88996ms step_avg:92.51ms +step:963/1670 train_time:89089ms step_avg:92.51ms +step:964/1670 train_time:89182ms step_avg:92.51ms +step:965/1670 train_time:89274ms step_avg:92.51ms +step:966/1670 train_time:89367ms step_avg:92.51ms +step:967/1670 train_time:89459ms step_avg:92.51ms +step:968/1670 train_time:89551ms step_avg:92.51ms +step:969/1670 train_time:89644ms step_avg:92.51ms +step:970/1670 train_time:89736ms step_avg:92.51ms +step:971/1670 train_time:89828ms step_avg:92.51ms +step:972/1670 train_time:89920ms step_avg:92.51ms +step:973/1670 train_time:90013ms step_avg:92.51ms +step:974/1670 train_time:90106ms step_avg:92.51ms +step:975/1670 train_time:90198ms step_avg:92.51ms +step:976/1670 train_time:90291ms step_avg:92.51ms +step:977/1670 train_time:90383ms step_avg:92.51ms +step:978/1670 train_time:90475ms step_avg:92.51ms +step:979/1670 train_time:90568ms step_avg:92.51ms +step:980/1670 train_time:90659ms step_avg:92.51ms +step:981/1670 train_time:90752ms step_avg:92.51ms +step:982/1670 train_time:90845ms step_avg:92.51ms +step:983/1670 train_time:90937ms step_avg:92.51ms +step:984/1670 train_time:91029ms step_avg:92.51ms +step:985/1670 train_time:91121ms step_avg:92.51ms +step:986/1670 train_time:91214ms step_avg:92.51ms +step:987/1670 train_time:91307ms step_avg:92.51ms +step:988/1670 train_time:91399ms step_avg:92.51ms +step:989/1670 train_time:91493ms step_avg:92.51ms +step:990/1670 train_time:91585ms step_avg:92.51ms +step:991/1670 train_time:91676ms 
step_avg:92.51ms +step:992/1670 train_time:91769ms step_avg:92.51ms +step:993/1670 train_time:91861ms step_avg:92.51ms +step:994/1670 train_time:91954ms step_avg:92.51ms +step:995/1670 train_time:92045ms step_avg:92.51ms +step:996/1670 train_time:92136ms step_avg:92.51ms +step:997/1670 train_time:92230ms step_avg:92.51ms +step:998/1670 train_time:92322ms step_avg:92.51ms +step:999/1670 train_time:92415ms step_avg:92.51ms +step:1000/1670 train_time:92507ms step_avg:92.51ms +step:1000/1670 val_loss:3.4677 train_time:92599ms step_avg:92.60ms +step:1001/1670 train_time:92618ms step_avg:92.53ms +step:1002/1670 train_time:92692ms step_avg:92.51ms +step:1003/1670 train_time:92785ms step_avg:92.51ms +step:1004/1670 train_time:92877ms step_avg:92.51ms +step:1005/1670 train_time:92968ms step_avg:92.51ms +step:1006/1670 train_time:93059ms step_avg:92.50ms +step:1007/1670 train_time:93151ms step_avg:92.50ms +step:1008/1670 train_time:93242ms step_avg:92.50ms +step:1009/1670 train_time:93334ms step_avg:92.50ms +step:1010/1670 train_time:93431ms step_avg:92.51ms +step:1011/1670 train_time:93521ms step_avg:92.50ms +step:1012/1670 train_time:93614ms step_avg:92.50ms +step:1013/1670 train_time:93708ms step_avg:92.51ms +step:1014/1670 train_time:93801ms step_avg:92.51ms +step:1015/1670 train_time:93893ms step_avg:92.51ms +step:1016/1670 train_time:93986ms step_avg:92.51ms +step:1017/1670 train_time:94078ms step_avg:92.50ms +step:1018/1670 train_time:94172ms step_avg:92.51ms +step:1019/1670 train_time:94264ms step_avg:92.51ms +step:1020/1670 train_time:94357ms step_avg:92.51ms +step:1021/1670 train_time:94449ms step_avg:92.51ms +step:1022/1670 train_time:94542ms step_avg:92.51ms +step:1023/1670 train_time:94635ms step_avg:92.51ms +step:1024/1670 train_time:94728ms step_avg:92.51ms +step:1025/1670 train_time:94820ms step_avg:92.51ms +step:1026/1670 train_time:94913ms step_avg:92.51ms +step:1027/1670 train_time:95006ms step_avg:92.51ms +step:1028/1670 train_time:95099ms step_avg:92.51ms +step:1029/1670 train_time:95192ms step_avg:92.51ms +step:1030/1670 train_time:95283ms step_avg:92.51ms +step:1031/1670 train_time:95376ms step_avg:92.51ms +step:1032/1670 train_time:95469ms step_avg:92.51ms +step:1033/1670 train_time:95561ms step_avg:92.51ms +step:1034/1670 train_time:95655ms step_avg:92.51ms +step:1035/1670 train_time:95747ms step_avg:92.51ms +step:1036/1670 train_time:95839ms step_avg:92.51ms +step:1037/1670 train_time:95931ms step_avg:92.51ms +step:1038/1670 train_time:96023ms step_avg:92.51ms +step:1039/1670 train_time:96115ms step_avg:92.51ms +step:1040/1670 train_time:96208ms step_avg:92.51ms +step:1041/1670 train_time:96299ms step_avg:92.51ms +step:1042/1670 train_time:96392ms step_avg:92.51ms +step:1043/1670 train_time:96483ms step_avg:92.51ms +step:1044/1670 train_time:96577ms step_avg:92.51ms +step:1045/1670 train_time:96671ms step_avg:92.51ms +step:1046/1670 train_time:96763ms step_avg:92.51ms +step:1047/1670 train_time:96855ms step_avg:92.51ms +step:1048/1670 train_time:96947ms step_avg:92.51ms +step:1049/1670 train_time:97039ms step_avg:92.51ms +step:1050/1670 train_time:97131ms step_avg:92.51ms +step:1051/1670 train_time:97224ms step_avg:92.51ms +step:1052/1670 train_time:97316ms step_avg:92.51ms +step:1053/1670 train_time:97408ms step_avg:92.50ms +step:1054/1670 train_time:97500ms step_avg:92.50ms +step:1055/1670 train_time:97593ms step_avg:92.51ms +step:1056/1670 train_time:97686ms step_avg:92.51ms +step:1057/1670 train_time:97779ms step_avg:92.51ms +step:1058/1670 train_time:97872ms 
step_avg:92.51ms +step:1059/1670 train_time:97964ms step_avg:92.51ms +step:1060/1670 train_time:98057ms step_avg:92.51ms +step:1061/1670 train_time:98149ms step_avg:92.51ms +step:1062/1670 train_time:98402ms step_avg:92.66ms +step:1063/1670 train_time:98471ms step_avg:92.64ms +step:1064/1670 train_time:98561ms step_avg:92.63ms +step:1065/1670 train_time:98653ms step_avg:92.63ms +step:1066/1670 train_time:98743ms step_avg:92.63ms +step:1067/1670 train_time:98834ms step_avg:92.63ms +step:1068/1670 train_time:98926ms step_avg:92.63ms +step:1069/1670 train_time:99017ms step_avg:92.63ms +step:1070/1670 train_time:99108ms step_avg:92.62ms +step:1071/1670 train_time:99200ms step_avg:92.62ms +step:1072/1670 train_time:99296ms step_avg:92.63ms +step:1073/1670 train_time:99395ms step_avg:92.63ms +step:1074/1670 train_time:99489ms step_avg:92.63ms +step:1075/1670 train_time:99580ms step_avg:92.63ms +step:1076/1670 train_time:99673ms step_avg:92.63ms +step:1077/1670 train_time:99765ms step_avg:92.63ms +step:1078/1670 train_time:99856ms step_avg:92.63ms +step:1079/1670 train_time:99947ms step_avg:92.63ms +step:1080/1670 train_time:100038ms step_avg:92.63ms +step:1081/1670 train_time:100130ms step_avg:92.63ms +step:1082/1670 train_time:100222ms step_avg:92.63ms +step:1083/1670 train_time:100317ms step_avg:92.63ms +step:1084/1670 train_time:100414ms step_avg:92.63ms +step:1085/1670 train_time:100508ms step_avg:92.63ms +step:1086/1670 train_time:100600ms step_avg:92.63ms +step:1087/1670 train_time:100693ms step_avg:92.63ms +step:1088/1670 train_time:100785ms step_avg:92.63ms +step:1089/1670 train_time:100877ms step_avg:92.63ms +step:1090/1670 train_time:100968ms step_avg:92.63ms +step:1091/1670 train_time:101060ms step_avg:92.63ms +step:1092/1670 train_time:101152ms step_avg:92.63ms +step:1093/1670 train_time:101244ms step_avg:92.63ms +step:1094/1670 train_time:101337ms step_avg:92.63ms +step:1095/1670 train_time:101431ms step_avg:92.63ms +step:1096/1670 train_time:101523ms step_avg:92.63ms +step:1097/1670 train_time:101616ms step_avg:92.63ms +step:1098/1670 train_time:101708ms step_avg:92.63ms +step:1099/1670 train_time:101799ms step_avg:92.63ms +step:1100/1670 train_time:101891ms step_avg:92.63ms +step:1101/1670 train_time:101982ms step_avg:92.63ms +step:1102/1670 train_time:102075ms step_avg:92.63ms +step:1103/1670 train_time:102166ms step_avg:92.63ms +step:1104/1670 train_time:102259ms step_avg:92.63ms +step:1105/1670 train_time:102351ms step_avg:92.63ms +step:1106/1670 train_time:102444ms step_avg:92.63ms +step:1107/1670 train_time:102537ms step_avg:92.63ms +step:1108/1670 train_time:102630ms step_avg:92.63ms +step:1109/1670 train_time:102723ms step_avg:92.63ms +step:1110/1670 train_time:102816ms step_avg:92.63ms +step:1111/1670 train_time:102908ms step_avg:92.63ms +step:1112/1670 train_time:102999ms step_avg:92.63ms +step:1113/1670 train_time:103091ms step_avg:92.62ms +step:1114/1670 train_time:103182ms step_avg:92.62ms +step:1115/1670 train_time:103467ms step_avg:92.80ms +step:1116/1670 train_time:103544ms step_avg:92.78ms +step:1117/1670 train_time:103635ms step_avg:92.78ms +step:1118/1670 train_time:103727ms step_avg:92.78ms +step:1119/1670 train_time:103819ms step_avg:92.78ms +step:1120/1670 train_time:103911ms step_avg:92.78ms +step:1121/1670 train_time:104003ms step_avg:92.78ms +step:1122/1670 train_time:104095ms step_avg:92.78ms +step:1123/1670 train_time:104187ms step_avg:92.78ms +step:1124/1670 train_time:104278ms step_avg:92.77ms +step:1125/1670 train_time:104376ms step_avg:92.78ms 
+step:1125/1670 val_loss:3.4149 train_time:104475ms step_avg:92.87ms +step:1126/1670 train_time:104497ms step_avg:92.80ms +step:1127/1670 train_time:104574ms step_avg:92.79ms +step:1128/1670 train_time:104672ms step_avg:92.79ms +step:1129/1670 train_time:104765ms step_avg:92.79ms +step:1130/1670 train_time:104857ms step_avg:92.79ms +step:1131/1670 train_time:104949ms step_avg:92.79ms +step:1132/1670 train_time:105042ms step_avg:92.79ms +step:1133/1670 train_time:105134ms step_avg:92.79ms +step:1134/1670 train_time:105225ms step_avg:92.79ms +step:1135/1670 train_time:105317ms step_avg:92.79ms +step:1136/1670 train_time:105411ms step_avg:92.79ms +step:1137/1670 train_time:105508ms step_avg:92.79ms +step:1138/1670 train_time:105606ms step_avg:92.80ms +step:1139/1670 train_time:105702ms step_avg:92.80ms +step:1140/1670 train_time:105794ms step_avg:92.80ms +step:1141/1670 train_time:105887ms step_avg:92.80ms +step:1142/1670 train_time:105979ms step_avg:92.80ms +step:1143/1670 train_time:106072ms step_avg:92.80ms +step:1144/1670 train_time:106165ms step_avg:92.80ms +step:1145/1670 train_time:106256ms step_avg:92.80ms +step:1146/1670 train_time:106348ms step_avg:92.80ms +step:1147/1670 train_time:106441ms step_avg:92.80ms +step:1148/1670 train_time:106536ms step_avg:92.80ms +step:1149/1670 train_time:106630ms step_avg:92.80ms +step:1150/1670 train_time:106725ms step_avg:92.80ms +step:1151/1670 train_time:106819ms step_avg:92.81ms +step:1152/1670 train_time:106911ms step_avg:92.80ms +step:1153/1670 train_time:107004ms step_avg:92.80ms +step:1154/1670 train_time:107097ms step_avg:92.80ms +step:1155/1670 train_time:107189ms step_avg:92.80ms +step:1156/1670 train_time:107281ms step_avg:92.80ms +step:1157/1670 train_time:107373ms step_avg:92.80ms +step:1158/1670 train_time:107468ms step_avg:92.81ms +step:1159/1670 train_time:107562ms step_avg:92.81ms +step:1160/1670 train_time:107656ms step_avg:92.81ms +step:1161/1670 train_time:107749ms step_avg:92.81ms +step:1162/1670 train_time:107842ms step_avg:92.81ms +step:1163/1670 train_time:107935ms step_avg:92.81ms +step:1164/1670 train_time:108028ms step_avg:92.81ms +step:1165/1670 train_time:108121ms step_avg:92.81ms +step:1166/1670 train_time:108210ms step_avg:92.80ms +step:1167/1670 train_time:108304ms step_avg:92.81ms +step:1168/1670 train_time:108398ms step_avg:92.81ms +step:1169/1670 train_time:108491ms step_avg:92.81ms +step:1170/1670 train_time:108584ms step_avg:92.81ms +step:1171/1670 train_time:108679ms step_avg:92.81ms +step:1172/1670 train_time:108771ms step_avg:92.81ms +step:1173/1670 train_time:108864ms step_avg:92.81ms +step:1174/1670 train_time:108957ms step_avg:92.81ms +step:1175/1670 train_time:109050ms step_avg:92.81ms +step:1176/1670 train_time:109144ms step_avg:92.81ms +step:1177/1670 train_time:109237ms step_avg:92.81ms +step:1178/1670 train_time:109329ms step_avg:92.81ms +step:1179/1670 train_time:109423ms step_avg:92.81ms +step:1180/1670 train_time:109516ms step_avg:92.81ms +step:1181/1670 train_time:109610ms step_avg:92.81ms +step:1182/1670 train_time:109704ms step_avg:92.81ms +step:1183/1670 train_time:109797ms step_avg:92.81ms +step:1184/1670 train_time:109889ms step_avg:92.81ms +step:1185/1670 train_time:109982ms step_avg:92.81ms +step:1186/1670 train_time:110075ms step_avg:92.81ms +step:1187/1670 train_time:110167ms step_avg:92.81ms +step:1188/1670 train_time:110260ms step_avg:92.81ms +step:1189/1670 train_time:110353ms step_avg:92.81ms +step:1190/1670 train_time:110446ms step_avg:92.81ms +step:1191/1670 train_time:110540ms 
step_avg:92.81ms +step:1192/1670 train_time:110633ms step_avg:92.81ms +step:1193/1670 train_time:110727ms step_avg:92.81ms +step:1194/1670 train_time:110821ms step_avg:92.81ms +step:1195/1670 train_time:110914ms step_avg:92.81ms +step:1196/1670 train_time:111006ms step_avg:92.81ms +step:1197/1670 train_time:111099ms step_avg:92.81ms +step:1198/1670 train_time:111191ms step_avg:92.81ms +step:1199/1670 train_time:111283ms step_avg:92.81ms +step:1200/1670 train_time:111376ms step_avg:92.81ms +step:1201/1670 train_time:111469ms step_avg:92.81ms +step:1202/1670 train_time:111563ms step_avg:92.81ms +step:1203/1670 train_time:111657ms step_avg:92.82ms +step:1204/1670 train_time:111750ms step_avg:92.82ms +step:1205/1670 train_time:111843ms step_avg:92.82ms +step:1206/1670 train_time:111937ms step_avg:92.82ms +step:1207/1670 train_time:112029ms step_avg:92.82ms +step:1208/1670 train_time:112124ms step_avg:92.82ms +step:1209/1670 train_time:112217ms step_avg:92.82ms +step:1210/1670 train_time:112310ms step_avg:92.82ms +step:1211/1670 train_time:112404ms step_avg:92.82ms +step:1212/1670 train_time:112497ms step_avg:92.82ms +step:1213/1670 train_time:112589ms step_avg:92.82ms +step:1214/1670 train_time:112682ms step_avg:92.82ms +step:1215/1670 train_time:112775ms step_avg:92.82ms +step:1216/1670 train_time:112868ms step_avg:92.82ms +step:1217/1670 train_time:112962ms step_avg:92.82ms +step:1218/1670 train_time:113054ms step_avg:92.82ms +step:1219/1670 train_time:113147ms step_avg:92.82ms +step:1220/1670 train_time:113240ms step_avg:92.82ms +step:1221/1670 train_time:113333ms step_avg:92.82ms +step:1222/1670 train_time:113427ms step_avg:92.82ms +step:1223/1670 train_time:113519ms step_avg:92.82ms +step:1224/1670 train_time:113612ms step_avg:92.82ms +step:1225/1670 train_time:113705ms step_avg:92.82ms +step:1226/1670 train_time:113799ms step_avg:92.82ms +step:1227/1670 train_time:113892ms step_avg:92.82ms +step:1228/1670 train_time:113984ms step_avg:92.82ms +step:1229/1670 train_time:114078ms step_avg:92.82ms +step:1230/1670 train_time:114171ms step_avg:92.82ms +step:1231/1670 train_time:114265ms step_avg:92.82ms +step:1232/1670 train_time:114358ms step_avg:92.82ms +step:1233/1670 train_time:114451ms step_avg:92.82ms +step:1234/1670 train_time:114545ms step_avg:92.82ms +step:1235/1670 train_time:114639ms step_avg:92.82ms +step:1236/1670 train_time:114731ms step_avg:92.82ms +step:1237/1670 train_time:114825ms step_avg:92.83ms +step:1238/1670 train_time:114918ms step_avg:92.83ms +step:1239/1670 train_time:115010ms step_avg:92.82ms +step:1240/1670 train_time:115103ms step_avg:92.83ms +step:1241/1670 train_time:115197ms step_avg:92.83ms +step:1242/1670 train_time:115291ms step_avg:92.83ms +step:1243/1670 train_time:115384ms step_avg:92.83ms +step:1244/1670 train_time:115477ms step_avg:92.83ms +step:1245/1670 train_time:115570ms step_avg:92.83ms +step:1246/1670 train_time:115664ms step_avg:92.83ms +step:1247/1670 train_time:115756ms step_avg:92.83ms +step:1248/1670 train_time:115849ms step_avg:92.83ms +step:1249/1670 train_time:115941ms step_avg:92.83ms +step:1250/1670 train_time:116035ms step_avg:92.83ms +step:1250/1670 val_loss:3.3759 train_time:116127ms step_avg:92.90ms +step:1251/1670 train_time:116148ms step_avg:92.84ms +step:1252/1670 train_time:116222ms step_avg:92.83ms +step:1253/1670 train_time:116316ms step_avg:92.83ms +step:1254/1670 train_time:116408ms step_avg:92.83ms +step:1255/1670 train_time:116500ms step_avg:92.83ms +step:1256/1670 train_time:116592ms step_avg:92.83ms +step:1257/1670 
train_time:116684ms step_avg:92.83ms +step:1258/1670 train_time:116777ms step_avg:92.83ms +step:1259/1670 train_time:116871ms step_avg:92.83ms +step:1260/1670 train_time:116964ms step_avg:92.83ms +step:1261/1670 train_time:117059ms step_avg:92.83ms +step:1262/1670 train_time:117153ms step_avg:92.83ms +step:1263/1670 train_time:117247ms step_avg:92.83ms +step:1264/1670 train_time:117340ms step_avg:92.83ms +step:1265/1670 train_time:117433ms step_avg:92.83ms +step:1266/1670 train_time:117525ms step_avg:92.83ms +step:1267/1670 train_time:117618ms step_avg:92.83ms +step:1268/1670 train_time:117711ms step_avg:92.83ms +step:1269/1670 train_time:117803ms step_avg:92.83ms +step:1270/1670 train_time:117897ms step_avg:92.83ms +step:1271/1670 train_time:117990ms step_avg:92.83ms +step:1272/1670 train_time:118084ms step_avg:92.83ms +step:1273/1670 train_time:118178ms step_avg:92.83ms +step:1274/1670 train_time:118419ms step_avg:92.95ms +step:1275/1670 train_time:118499ms step_avg:92.94ms +step:1276/1670 train_time:118591ms step_avg:92.94ms +step:1277/1670 train_time:118683ms step_avg:92.94ms +step:1278/1670 train_time:118775ms step_avg:92.94ms +step:1279/1670 train_time:118866ms step_avg:92.94ms +step:1280/1670 train_time:118958ms step_avg:92.94ms +step:1281/1670 train_time:119049ms step_avg:92.93ms +step:1282/1670 train_time:119141ms step_avg:92.93ms +step:1283/1670 train_time:119233ms step_avg:92.93ms +step:1284/1670 train_time:119331ms step_avg:92.94ms +step:1285/1670 train_time:119428ms step_avg:92.94ms +step:1286/1670 train_time:119522ms step_avg:92.94ms +step:1287/1670 train_time:119615ms step_avg:92.94ms +step:1288/1670 train_time:119708ms step_avg:92.94ms +step:1289/1670 train_time:119800ms step_avg:92.94ms +step:1290/1670 train_time:119893ms step_avg:92.94ms +step:1291/1670 train_time:119986ms step_avg:92.94ms +step:1292/1670 train_time:120078ms step_avg:92.94ms +step:1293/1670 train_time:120170ms step_avg:92.94ms +step:1294/1670 train_time:120264ms step_avg:92.94ms +step:1295/1670 train_time:120358ms step_avg:92.94ms +step:1296/1670 train_time:120453ms step_avg:92.94ms +step:1297/1670 train_time:120548ms step_avg:92.94ms +step:1298/1670 train_time:120640ms step_avg:92.94ms +step:1299/1670 train_time:120733ms step_avg:92.94ms +step:1300/1670 train_time:120827ms step_avg:92.94ms +step:1301/1670 train_time:120920ms step_avg:92.94ms +step:1302/1670 train_time:121012ms step_avg:92.94ms +step:1303/1670 train_time:121104ms step_avg:92.94ms +step:1304/1670 train_time:121197ms step_avg:92.94ms +step:1305/1670 train_time:121290ms step_avg:92.94ms +step:1306/1670 train_time:121383ms step_avg:92.94ms +step:1307/1670 train_time:121477ms step_avg:92.94ms +step:1308/1670 train_time:121571ms step_avg:92.94ms +step:1309/1670 train_time:121664ms step_avg:92.94ms +step:1310/1670 train_time:121758ms step_avg:92.94ms +step:1311/1670 train_time:121851ms step_avg:92.95ms +step:1312/1670 train_time:121944ms step_avg:92.95ms +step:1313/1670 train_time:122036ms step_avg:92.94ms +step:1314/1670 train_time:122129ms step_avg:92.94ms +step:1315/1670 train_time:122221ms step_avg:92.94ms +step:1316/1670 train_time:122315ms step_avg:92.94ms +step:1317/1670 train_time:122409ms step_avg:92.94ms +step:1318/1670 train_time:122502ms step_avg:92.95ms +step:1319/1670 train_time:122596ms step_avg:92.95ms +step:1320/1670 train_time:122689ms step_avg:92.95ms +step:1321/1670 train_time:122782ms step_avg:92.95ms +step:1322/1670 train_time:122877ms step_avg:92.95ms +step:1323/1670 train_time:122971ms step_avg:92.95ms +step:1324/1670 
train_time:123063ms step_avg:92.95ms +step:1325/1670 train_time:123157ms step_avg:92.95ms +step:1326/1670 train_time:123251ms step_avg:92.95ms +step:1327/1670 train_time:123343ms step_avg:92.95ms +step:1328/1670 train_time:123436ms step_avg:92.95ms +step:1329/1670 train_time:123530ms step_avg:92.95ms +step:1330/1670 train_time:123623ms step_avg:92.95ms +step:1331/1670 train_time:123716ms step_avg:92.95ms +step:1332/1670 train_time:123809ms step_avg:92.95ms +step:1333/1670 train_time:123901ms step_avg:92.95ms +step:1334/1670 train_time:123994ms step_avg:92.95ms +step:1335/1670 train_time:124088ms step_avg:92.95ms +step:1336/1670 train_time:124181ms step_avg:92.95ms +step:1337/1670 train_time:124274ms step_avg:92.95ms +step:1338/1670 train_time:124367ms step_avg:92.95ms +step:1339/1670 train_time:124461ms step_avg:92.95ms +step:1340/1670 train_time:124555ms step_avg:92.95ms +step:1341/1670 train_time:124648ms step_avg:92.95ms +step:1342/1670 train_time:124741ms step_avg:92.95ms +step:1343/1670 train_time:124834ms step_avg:92.95ms +step:1344/1670 train_time:124926ms step_avg:92.95ms +step:1345/1670 train_time:125020ms step_avg:92.95ms +step:1346/1670 train_time:125113ms step_avg:92.95ms +step:1347/1670 train_time:125206ms step_avg:92.95ms +step:1348/1670 train_time:125299ms step_avg:92.95ms +step:1349/1670 train_time:125393ms step_avg:92.95ms +step:1350/1670 train_time:125486ms step_avg:92.95ms +step:1351/1670 train_time:125578ms step_avg:92.95ms +step:1352/1670 train_time:125671ms step_avg:92.95ms +step:1353/1670 train_time:125765ms step_avg:92.95ms +step:1354/1670 train_time:125858ms step_avg:92.95ms +step:1355/1670 train_time:125950ms step_avg:92.95ms +step:1356/1670 train_time:126043ms step_avg:92.95ms +step:1357/1670 train_time:126136ms step_avg:92.95ms +step:1358/1670 train_time:126229ms step_avg:92.95ms +step:1359/1670 train_time:126322ms step_avg:92.95ms +step:1360/1670 train_time:126415ms step_avg:92.95ms +step:1361/1670 train_time:126508ms step_avg:92.95ms +step:1362/1670 train_time:126600ms step_avg:92.95ms +step:1363/1670 train_time:126693ms step_avg:92.95ms +step:1364/1670 train_time:126787ms step_avg:92.95ms +step:1365/1670 train_time:126880ms step_avg:92.95ms +step:1366/1670 train_time:126974ms step_avg:92.95ms +step:1367/1670 train_time:127068ms step_avg:92.95ms +step:1368/1670 train_time:127161ms step_avg:92.95ms +step:1369/1670 train_time:127254ms step_avg:92.95ms +step:1370/1670 train_time:127347ms step_avg:92.95ms +step:1371/1670 train_time:127440ms step_avg:92.95ms +step:1372/1670 train_time:127533ms step_avg:92.95ms +step:1373/1670 train_time:127626ms step_avg:92.95ms +step:1374/1670 train_time:127719ms step_avg:92.95ms +step:1375/1670 train_time:127812ms step_avg:92.95ms +step:1375/1670 val_loss:3.3418 train_time:127904ms step_avg:93.02ms +step:1376/1670 train_time:127924ms step_avg:92.97ms +step:1377/1670 train_time:127998ms step_avg:92.95ms +step:1378/1670 train_time:128090ms step_avg:92.95ms +step:1379/1670 train_time:128183ms step_avg:92.95ms +step:1380/1670 train_time:128276ms step_avg:92.95ms +step:1381/1670 train_time:128368ms step_avg:92.95ms +step:1382/1670 train_time:128461ms step_avg:92.95ms +step:1383/1670 train_time:128553ms step_avg:92.95ms +step:1384/1670 train_time:128646ms step_avg:92.95ms +step:1385/1670 train_time:128740ms step_avg:92.95ms +step:1386/1670 train_time:128835ms step_avg:92.95ms +step:1387/1670 train_time:128930ms step_avg:92.96ms +step:1388/1670 train_time:129025ms step_avg:92.96ms +step:1389/1670 train_time:129118ms step_avg:92.96ms 
+step:1390/1670 train_time:129212ms step_avg:92.96ms +step:1391/1670 train_time:129306ms step_avg:92.96ms +step:1392/1670 train_time:129399ms step_avg:92.96ms +step:1393/1670 train_time:129491ms step_avg:92.96ms +step:1394/1670 train_time:129583ms step_avg:92.96ms +step:1395/1670 train_time:129675ms step_avg:92.96ms +step:1396/1670 train_time:129768ms step_avg:92.96ms +step:1397/1670 train_time:129863ms step_avg:92.96ms +step:1398/1670 train_time:129957ms step_avg:92.96ms +step:1399/1670 train_time:130050ms step_avg:92.96ms +step:1400/1670 train_time:130143ms step_avg:92.96ms +step:1401/1670 train_time:130239ms step_avg:92.96ms +step:1402/1670 train_time:130331ms step_avg:92.96ms +step:1403/1670 train_time:130425ms step_avg:92.96ms +step:1404/1670 train_time:130518ms step_avg:92.96ms +step:1405/1670 train_time:130610ms step_avg:92.96ms +step:1406/1670 train_time:130704ms step_avg:92.96ms +step:1407/1670 train_time:130797ms step_avg:92.96ms +step:1408/1670 train_time:130890ms step_avg:92.96ms +step:1409/1670 train_time:130984ms step_avg:92.96ms +step:1410/1670 train_time:131077ms step_avg:92.96ms +step:1411/1670 train_time:131170ms step_avg:92.96ms +step:1412/1670 train_time:131265ms step_avg:92.96ms +step:1413/1670 train_time:131358ms step_avg:92.96ms +step:1414/1670 train_time:131451ms step_avg:92.96ms +step:1415/1670 train_time:131544ms step_avg:92.96ms +step:1416/1670 train_time:131637ms step_avg:92.96ms +step:1417/1670 train_time:131729ms step_avg:92.96ms +step:1418/1670 train_time:131822ms step_avg:92.96ms +step:1419/1670 train_time:131915ms step_avg:92.96ms +step:1420/1670 train_time:132007ms step_avg:92.96ms +step:1421/1670 train_time:132101ms step_avg:92.96ms +step:1422/1670 train_time:132194ms step_avg:92.96ms +step:1423/1670 train_time:132287ms step_avg:92.96ms +step:1424/1670 train_time:132380ms step_avg:92.96ms +step:1425/1670 train_time:132474ms step_avg:92.96ms +step:1426/1670 train_time:132567ms step_avg:92.96ms +step:1427/1670 train_time:132660ms step_avg:92.96ms +step:1428/1670 train_time:132752ms step_avg:92.96ms +step:1429/1670 train_time:132847ms step_avg:92.96ms +step:1430/1670 train_time:132940ms step_avg:92.97ms +step:1431/1670 train_time:133033ms step_avg:92.96ms +step:1432/1670 train_time:133126ms step_avg:92.96ms +step:1433/1670 train_time:133218ms step_avg:92.96ms +step:1434/1670 train_time:133311ms step_avg:92.96ms +step:1435/1670 train_time:133403ms step_avg:92.96ms +step:1436/1670 train_time:133497ms step_avg:92.96ms +step:1437/1670 train_time:133589ms step_avg:92.96ms +step:1438/1670 train_time:133682ms step_avg:92.96ms +step:1439/1670 train_time:133775ms step_avg:92.96ms +step:1440/1670 train_time:133868ms step_avg:92.96ms +step:1441/1670 train_time:133961ms step_avg:92.96ms +step:1442/1670 train_time:134054ms step_avg:92.96ms +step:1443/1670 train_time:134148ms step_avg:92.96ms +step:1444/1670 train_time:134241ms step_avg:92.96ms +step:1445/1670 train_time:134335ms step_avg:92.97ms +step:1446/1670 train_time:134428ms step_avg:92.97ms +step:1447/1670 train_time:134522ms step_avg:92.97ms +step:1448/1670 train_time:134615ms step_avg:92.97ms +step:1449/1670 train_time:134709ms step_avg:92.97ms +step:1450/1670 train_time:134803ms step_avg:92.97ms +step:1451/1670 train_time:134897ms step_avg:92.97ms +step:1452/1670 train_time:134990ms step_avg:92.97ms +step:1453/1670 train_time:135082ms step_avg:92.97ms +step:1454/1670 train_time:135175ms step_avg:92.97ms +step:1455/1670 train_time:135269ms step_avg:92.97ms +step:1456/1670 train_time:135362ms step_avg:92.97ms 
+step:1457/1670 train_time:135454ms step_avg:92.97ms +step:1458/1670 train_time:135548ms step_avg:92.97ms +step:1459/1670 train_time:135641ms step_avg:92.97ms +step:1460/1670 train_time:135735ms step_avg:92.97ms +step:1461/1670 train_time:135828ms step_avg:92.97ms +step:1462/1670 train_time:135921ms step_avg:92.97ms +step:1463/1670 train_time:136014ms step_avg:92.97ms +step:1464/1670 train_time:136108ms step_avg:92.97ms +step:1465/1670 train_time:136202ms step_avg:92.97ms +step:1466/1670 train_time:136294ms step_avg:92.97ms +step:1467/1670 train_time:136387ms step_avg:92.97ms +step:1468/1670 train_time:136480ms step_avg:92.97ms +step:1469/1670 train_time:136572ms step_avg:92.97ms +step:1470/1670 train_time:136666ms step_avg:92.97ms +step:1471/1670 train_time:136759ms step_avg:92.97ms +step:1472/1670 train_time:136852ms step_avg:92.97ms +step:1473/1670 train_time:136946ms step_avg:92.97ms +step:1474/1670 train_time:137040ms step_avg:92.97ms +step:1475/1670 train_time:137133ms step_avg:92.97ms +step:1476/1670 train_time:137227ms step_avg:92.97ms +step:1477/1670 train_time:137321ms step_avg:92.97ms +step:1478/1670 train_time:137413ms step_avg:92.97ms +step:1479/1670 train_time:137506ms step_avg:92.97ms +step:1480/1670 train_time:137598ms step_avg:92.97ms +step:1481/1670 train_time:137690ms step_avg:92.97ms +step:1482/1670 train_time:137783ms step_avg:92.97ms +step:1483/1670 train_time:137877ms step_avg:92.97ms +step:1484/1670 train_time:137970ms step_avg:92.97ms +step:1485/1670 train_time:138221ms step_avg:93.08ms +step:1486/1670 train_time:138293ms step_avg:93.06ms +step:1487/1670 train_time:138385ms step_avg:93.06ms +step:1488/1670 train_time:138477ms step_avg:93.06ms +step:1489/1670 train_time:138568ms step_avg:93.06ms +step:1490/1670 train_time:138660ms step_avg:93.06ms +step:1491/1670 train_time:138752ms step_avg:93.06ms +step:1492/1670 train_time:138844ms step_avg:93.06ms +step:1493/1670 train_time:138936ms step_avg:93.06ms +step:1494/1670 train_time:139028ms step_avg:93.06ms +step:1495/1670 train_time:139125ms step_avg:93.06ms +step:1496/1670 train_time:139224ms step_avg:93.06ms +step:1497/1670 train_time:139319ms step_avg:93.07ms +step:1498/1670 train_time:139412ms step_avg:93.07ms +step:1499/1670 train_time:139505ms step_avg:93.07ms +step:1500/1670 train_time:139597ms step_avg:93.06ms +step:1500/1670 val_loss:3.3115 train_time:139690ms step_avg:93.13ms +step:1501/1670 train_time:139710ms step_avg:93.08ms +step:1502/1670 train_time:139784ms step_avg:93.06ms +step:1503/1670 train_time:139876ms step_avg:93.06ms +step:1504/1670 train_time:139968ms step_avg:93.06ms +step:1505/1670 train_time:140060ms step_avg:93.06ms +step:1506/1670 train_time:140152ms step_avg:93.06ms +step:1507/1670 train_time:140246ms step_avg:93.06ms +step:1508/1670 train_time:140342ms step_avg:93.07ms +step:1509/1670 train_time:140436ms step_avg:93.07ms +step:1510/1670 train_time:140528ms step_avg:93.06ms +step:1511/1670 train_time:140622ms step_avg:93.07ms +step:1512/1670 train_time:140716ms step_avg:93.07ms +step:1513/1670 train_time:140809ms step_avg:93.07ms +step:1514/1670 train_time:140901ms step_avg:93.07ms +step:1515/1670 train_time:140993ms step_avg:93.06ms +step:1516/1670 train_time:141087ms step_avg:93.07ms +step:1517/1670 train_time:141180ms step_avg:93.06ms +step:1518/1670 train_time:141273ms step_avg:93.07ms +step:1519/1670 train_time:141368ms step_avg:93.07ms +step:1520/1670 train_time:141462ms step_avg:93.07ms +step:1521/1670 train_time:141555ms step_avg:93.07ms +step:1522/1670 train_time:141650ms 
step_avg:93.07ms +step:1523/1670 train_time:141744ms step_avg:93.07ms +step:1524/1670 train_time:141837ms step_avg:93.07ms +step:1525/1670 train_time:141930ms step_avg:93.07ms +step:1526/1670 train_time:142022ms step_avg:93.07ms +step:1527/1670 train_time:142114ms step_avg:93.07ms +step:1528/1670 train_time:142208ms step_avg:93.07ms +step:1529/1670 train_time:142301ms step_avg:93.07ms +step:1530/1670 train_time:142394ms step_avg:93.07ms +step:1531/1670 train_time:142489ms step_avg:93.07ms +step:1532/1670 train_time:142583ms step_avg:93.07ms +step:1533/1670 train_time:142677ms step_avg:93.07ms +step:1534/1670 train_time:142770ms step_avg:93.07ms +step:1535/1670 train_time:142863ms step_avg:93.07ms +step:1536/1670 train_time:142955ms step_avg:93.07ms +step:1537/1670 train_time:143047ms step_avg:93.07ms +step:1538/1670 train_time:143140ms step_avg:93.07ms +step:1539/1670 train_time:143233ms step_avg:93.07ms +step:1540/1670 train_time:143326ms step_avg:93.07ms +step:1541/1670 train_time:143421ms step_avg:93.07ms +step:1542/1670 train_time:143514ms step_avg:93.07ms +step:1543/1670 train_time:143607ms step_avg:93.07ms +step:1544/1670 train_time:143701ms step_avg:93.07ms +step:1545/1670 train_time:143793ms step_avg:93.07ms +step:1546/1670 train_time:143887ms step_avg:93.07ms +step:1547/1670 train_time:143979ms step_avg:93.07ms +step:1548/1670 train_time:144072ms step_avg:93.07ms +step:1549/1670 train_time:144164ms step_avg:93.07ms +step:1550/1670 train_time:144257ms step_avg:93.07ms +step:1551/1670 train_time:144351ms step_avg:93.07ms +step:1552/1670 train_time:144445ms step_avg:93.07ms +step:1553/1670 train_time:144538ms step_avg:93.07ms +step:1554/1670 train_time:144631ms step_avg:93.07ms +step:1555/1670 train_time:144723ms step_avg:93.07ms +step:1556/1670 train_time:144817ms step_avg:93.07ms +step:1557/1670 train_time:144911ms step_avg:93.07ms +step:1558/1670 train_time:145003ms step_avg:93.07ms +step:1559/1670 train_time:145095ms step_avg:93.07ms +step:1560/1670 train_time:145190ms step_avg:93.07ms +step:1561/1670 train_time:145283ms step_avg:93.07ms +step:1562/1670 train_time:145376ms step_avg:93.07ms +step:1563/1670 train_time:145471ms step_avg:93.07ms +step:1564/1670 train_time:145563ms step_avg:93.07ms +step:1565/1670 train_time:145656ms step_avg:93.07ms +step:1566/1670 train_time:145749ms step_avg:93.07ms +step:1567/1670 train_time:145842ms step_avg:93.07ms +step:1568/1670 train_time:145936ms step_avg:93.07ms +step:1569/1670 train_time:146029ms step_avg:93.07ms +step:1570/1670 train_time:146121ms step_avg:93.07ms +step:1571/1670 train_time:146215ms step_avg:93.07ms +step:1572/1670 train_time:146308ms step_avg:93.07ms +step:1573/1670 train_time:146402ms step_avg:93.07ms +step:1574/1670 train_time:146495ms step_avg:93.07ms +step:1575/1670 train_time:146589ms step_avg:93.07ms +step:1576/1670 train_time:146683ms step_avg:93.07ms +step:1577/1670 train_time:146776ms step_avg:93.07ms +step:1578/1670 train_time:146870ms step_avg:93.07ms +step:1579/1670 train_time:146965ms step_avg:93.07ms +step:1580/1670 train_time:147058ms step_avg:93.07ms +step:1581/1670 train_time:147151ms step_avg:93.07ms +step:1582/1670 train_time:147244ms step_avg:93.07ms +step:1583/1670 train_time:147338ms step_avg:93.08ms +step:1584/1670 train_time:147431ms step_avg:93.08ms +step:1585/1670 train_time:147524ms step_avg:93.07ms +step:1586/1670 train_time:147617ms step_avg:93.08ms +step:1587/1670 train_time:147710ms step_avg:93.08ms +step:1588/1670 train_time:147802ms step_avg:93.07ms +step:1589/1670 train_time:147895ms 
step_avg:93.07ms +step:1590/1670 train_time:147989ms step_avg:93.07ms +step:1591/1670 train_time:148083ms step_avg:93.08ms +step:1592/1670 train_time:148175ms step_avg:93.07ms +step:1593/1670 train_time:148269ms step_avg:93.08ms +step:1594/1670 train_time:148362ms step_avg:93.08ms +step:1595/1670 train_time:148455ms step_avg:93.08ms +step:1596/1670 train_time:148549ms step_avg:93.08ms +step:1597/1670 train_time:148643ms step_avg:93.08ms +step:1598/1670 train_time:148734ms step_avg:93.08ms +step:1599/1670 train_time:148827ms step_avg:93.07ms +step:1600/1670 train_time:148920ms step_avg:93.07ms +step:1601/1670 train_time:149013ms step_avg:93.07ms +step:1602/1670 train_time:149106ms step_avg:93.08ms +step:1603/1670 train_time:149199ms step_avg:93.07ms +step:1604/1670 train_time:149292ms step_avg:93.07ms +step:1605/1670 train_time:149386ms step_avg:93.08ms +step:1606/1670 train_time:149479ms step_avg:93.08ms +step:1607/1670 train_time:149572ms step_avg:93.08ms +step:1608/1670 train_time:149666ms step_avg:93.08ms +step:1609/1670 train_time:149759ms step_avg:93.08ms +step:1610/1670 train_time:149853ms step_avg:93.08ms +step:1611/1670 train_time:149946ms step_avg:93.08ms +step:1612/1670 train_time:150039ms step_avg:93.08ms +step:1613/1670 train_time:150132ms step_avg:93.08ms +step:1614/1670 train_time:150225ms step_avg:93.08ms +step:1615/1670 train_time:150319ms step_avg:93.08ms +step:1616/1670 train_time:150412ms step_avg:93.08ms +step:1617/1670 train_time:150505ms step_avg:93.08ms +step:1618/1670 train_time:150597ms step_avg:93.08ms +step:1619/1670 train_time:150692ms step_avg:93.08ms +step:1620/1670 train_time:150785ms step_avg:93.08ms +step:1621/1670 train_time:150878ms step_avg:93.08ms +step:1622/1670 train_time:150971ms step_avg:93.08ms +step:1623/1670 train_time:151064ms step_avg:93.08ms +step:1624/1670 train_time:151156ms step_avg:93.08ms +step:1625/1670 train_time:151251ms step_avg:93.08ms +step:1625/1670 val_loss:3.2866 train_time:151345ms step_avg:93.14ms +step:1626/1670 train_time:151370ms step_avg:93.09ms +step:1627/1670 train_time:151440ms step_avg:93.08ms +step:1628/1670 train_time:151535ms step_avg:93.08ms +step:1629/1670 train_time:151628ms step_avg:93.08ms +step:1630/1670 train_time:151720ms step_avg:93.08ms +step:1631/1670 train_time:151814ms step_avg:93.08ms +step:1632/1670 train_time:151908ms step_avg:93.08ms +step:1633/1670 train_time:152000ms step_avg:93.08ms +step:1634/1670 train_time:152093ms step_avg:93.08ms +step:1635/1670 train_time:152186ms step_avg:93.08ms +step:1636/1670 train_time:152281ms step_avg:93.08ms +step:1637/1670 train_time:152374ms step_avg:93.08ms +step:1638/1670 train_time:152468ms step_avg:93.08ms +step:1639/1670 train_time:152561ms step_avg:93.08ms +step:1640/1670 train_time:152656ms step_avg:93.08ms +step:1641/1670 train_time:152749ms step_avg:93.08ms +step:1642/1670 train_time:152842ms step_avg:93.08ms +step:1643/1670 train_time:152939ms step_avg:93.09ms +step:1644/1670 train_time:153031ms step_avg:93.08ms +step:1645/1670 train_time:153124ms step_avg:93.08ms +step:1646/1670 train_time:153217ms step_avg:93.08ms +step:1647/1670 train_time:153311ms step_avg:93.08ms +step:1648/1670 train_time:153404ms step_avg:93.09ms +step:1649/1670 train_time:153498ms step_avg:93.09ms +step:1650/1670 train_time:153592ms step_avg:93.09ms +step:1651/1670 train_time:153685ms step_avg:93.09ms +step:1652/1670 train_time:153778ms step_avg:93.09ms +step:1653/1670 train_time:153870ms step_avg:93.09ms +step:1654/1670 train_time:153963ms step_avg:93.09ms +step:1655/1670 
train_time:154058ms step_avg:93.09ms +step:1656/1670 train_time:154151ms step_avg:93.09ms +step:1657/1670 train_time:154244ms step_avg:93.09ms +step:1658/1670 train_time:154336ms step_avg:93.09ms +step:1659/1670 train_time:154429ms step_avg:93.09ms +step:1660/1670 train_time:154522ms step_avg:93.09ms +step:1661/1670 train_time:154615ms step_avg:93.09ms +step:1662/1670 train_time:154708ms step_avg:93.09ms +step:1663/1670 train_time:154801ms step_avg:93.09ms +step:1664/1670 train_time:154896ms step_avg:93.09ms +step:1665/1670 train_time:154988ms step_avg:93.09ms +step:1666/1670 train_time:155082ms step_avg:93.09ms +step:1667/1670 train_time:155175ms step_avg:93.09ms +step:1668/1670 train_time:155268ms step_avg:93.09ms +step:1669/1670 train_time:155361ms step_avg:93.09ms +step:1670/1670 train_time:155455ms step_avg:93.09ms +step:1670/1670 val_loss:3.2781 train_time:155717ms step_avg:93.24ms +peak memory allocated: 32002 MiB reserved: 46556 MiB diff --git a/records/091125_VectSigmoidBFloat16/ab5b991b-3767-4092-851a-5c266ae5c1e2.txt b/records/091125_VectSigmoidBFloat16/ab5b991b-3767-4092-851a-5c266ae5c1e2.txt new file mode 100644 index 000000000..cef62018d --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/ab5b991b-3767-4092-851a-5c266ae5c1e2.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
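+# A quick, illustrative sanity check for newton_schulz_triton (a sketch, not
+# part of the training path; the helper below is hypothetical and unused
+# elsewhere in this script). Newton-Schulz with the quintic coefficients above
+# pushes the singular values of the input toward 1, so for a wide matrix G the
+# output X should satisfy X @ X.mT ~= I up to a loose tolerance. Assumes a
+# CUDA device with Triton available.
+def _check_newton_schulz_orthogonality():
+    G = torch.randn(768, 3072, device="cuda")
+    X = newton_schulz_triton(G)  # internally casts to bfloat16
+    gram = X.float() @ X.float().mT  # ~= identity if X is semi-orthogonal
+    eye = torch.eye(gram.size(0), device=gram.device)
+    # singular values land roughly in (0.7, 1.2), so only check the ballpark
+    print("max |X @ X.mT - I| entry:", (gram - eye).abs().max().item())
+
+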
+# ----------------------------------------------------------------------------- +# Muon optimizer + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. 
+ eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. 
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor(
+                        p, p_slice, async_op=True
+                    ).get_future()
+                )
+        torch.futures.collect_all(all_gather_futures).wait()
+
+
+# -----------------------------------------------------------------------------
+# PyTorch nn.Module definitions for the model
+
+
+def norm(x: Tensor):
+    return F.rms_norm(x, (x.size(-1),))
+
+
+class CastedLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        use_fp8=False,
+        x_s=1.0,
+        w_s=1.0,
+        grad_s=1.0,
+    ):
+        super().__init__(in_features, out_features, bias=False)
+        self.use_fp8 = use_fp8
+        self.x_s = x_s
+        self.w_s = w_s
+        self.grad_s = grad_s
+
+    def reset_parameters(self) -> None:
+        std = 0.5 * (self.in_features**-0.5)  # 0.5 is a bit better than the default 1/sqrt(3)
+        bound = (3**0.5) * std
+        with torch.no_grad():
+            self.weight.uniform_(-bound, bound)
+
+    def forward(self, x: Tensor):
+        if self.use_fp8 and self.training:
+            _x = x.flatten(0, -2)
+            out: Tensor = torch.ops.nanogpt.mm(
+                _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s
+            )[0]
+            return out.reshape(*x.shape[:-1], -1)
+        else:
+            return F.linear(x, self.weight.type_as(x))
+
+
+def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
+    assert cos.size(0) >= x_BTHD.size(-3)
+    cos, sin = (
+        cos[None, : x_BTHD.size(-3), None, :],
+        sin[None, : x_BTHD.size(-3), None, :],
+    )
+    x1, x2 = x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
+
+
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    rotary_cos: torch.Tensor
+    rotary_sin: torch.Tensor
+    attn_scale: float
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim**-0.5)
+        bound = (3**0.5) * std  # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound)  # init QKV weights
+            self.qkvo_w[3].zero_()  # init output weights to zero
+
+        # sparse gated attention to enable a context-based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1)  # batch size, sequence length
+        assert B == 1, "varlen sequences require B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = (
+            attn_args.seqlens,
+            attn_args.attn_scale,
+            attn_args.bm_size,
+        )
+
+        q, k, v = (
+            F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x))
+            .view(B, T, 3 * self.num_heads, self.head_dim)
+            .chunk(3, dim=-2)
+        )
+        q, k = norm(q), norm(k)  # QK norm @Grad62304977
+        q, k = (
+            rotary(q, rotary_cos, rotary_sin),
+            rotary(k, rotary_cos, rotary_sin),
+        )
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v)  # @KoszarskyB & @Grad62304977
+        else:  # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = (
+            args.train_max_seq_len
+            if
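+            # Note on the flash_attn call below: window_size=(bm_size, 0) lets each
+            # query attend to at most bm_size earlier tokens and none ahead, i.e.
+            # sliding-window causal attention with a per-layer span from bm_sizes;
+            # max_seqlen_q/k are only an upper bound used for kernel scheduling,
+            # the true sequence boundaries come from cu_seqlens.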
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
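+        # e.g. next_multiple_of_n(50257, n=128) -> 50304 = 128 * 393; the 47 padded
+        # vocab rows are never used as targets, and the rounded dimension keeps the
+        # embedding and lm_head matmuls tensor-core friendly.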
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(torch.bfloat16)  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks) : 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks) : 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32)  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy())  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(
+                        f"Insufficient BOS tokens ahead (index {idx} of {n}); hit tail of shard."
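+                        # StopIteration here is caught by distributed_data_generator,
+                        # which then loads the next shard and rebuilds the BOSFinder.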
+                    )
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(
+                    self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                    cur + max_seq_len,
+                    cur + num_tokens_local - cur_len + 1,
+                )
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+
+def distributed_data_generator(
+    filename_pattern: str,
+    num_tokens: int,
+    max_seq_len: int,
+    grad_accum_steps: int = 1,
+    align_to_bos: bool = True,
+):
+    # align_to_bos: each sequence begins with the BOS token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, (
+        "num_tokens must be divisible by world_size * grad_accum_steps"
+    )
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0  # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = (
+                    torch.tensor(seq_starts[rank]),
+                    torch.tensor(seq_ends[rank]),
+                )
+            except StopIteration:
+                # This shard is exhausted; load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # last document was one token too long, accounting for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local : pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True),
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, (
+                "num_tokens must be divisible by world_size * grad_accum_steps"
+            )
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = (
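+        # shard layout (cf. _load_data_shard): a 256*int32 header holding the magic
+        # number 20240520, version 1, and the token count, followed by uint16 token ids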
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:59:20 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 125W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 132W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 46C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 41C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.06ms +step:1/1670 train_time:293ms step_avg:293.10ms +step:2/1670 train_time:311ms step_avg:155.46ms +step:3/1670 train_time:380ms step_avg:126.68ms +step:4/1670 train_time:469ms step_avg:117.30ms +step:5/1670 train_time:560ms step_avg:111.94ms +step:6/1670 train_time:651ms step_avg:108.52ms +step:7/1670 train_time:741ms step_avg:105.91ms +step:8/1670 train_time:832ms step_avg:103.98ms +step:9/1670 train_time:922ms step_avg:102.50ms +step:10/1670 train_time:1013ms step_avg:101.30ms +step:11/1670 train_time:1103ms step_avg:100.29ms +step:12/1670 train_time:1198ms step_avg:99.86ms +step:13/1670 train_time:1292ms step_avg:99.42ms +step:14/1670 train_time:1386ms step_avg:98.97ms +step:15/1670 train_time:1478ms step_avg:98.51ms +step:16/1670 train_time:1569ms step_avg:98.09ms +step:17/1670 train_time:1661ms step_avg:97.68ms +step:18/1670 train_time:1752ms step_avg:97.36ms +step:19/1670 train_time:1844ms step_avg:97.03ms +step:20/1670 train_time:1934ms step_avg:96.72ms +step:21/1670 train_time:2026ms step_avg:96.46ms +step:22/1670 train_time:2118ms 
step_avg:96.28ms +step:23/1670 train_time:2212ms step_avg:96.15ms +step:24/1670 train_time:2303ms step_avg:95.97ms +step:25/1670 train_time:2398ms step_avg:95.90ms +step:26/1670 train_time:2492ms step_avg:95.84ms +step:27/1670 train_time:2584ms step_avg:95.69ms +step:28/1670 train_time:2675ms step_avg:95.54ms +step:29/1670 train_time:2766ms step_avg:95.37ms +step:30/1670 train_time:2857ms step_avg:95.22ms +step:31/1670 train_time:2948ms step_avg:95.11ms +step:32/1670 train_time:3039ms step_avg:94.97ms +step:33/1670 train_time:3132ms step_avg:94.90ms +step:34/1670 train_time:3223ms step_avg:94.80ms +step:35/1670 train_time:3316ms step_avg:94.73ms +step:36/1670 train_time:3407ms step_avg:94.65ms +step:37/1670 train_time:3501ms step_avg:94.62ms +step:38/1670 train_time:3594ms step_avg:94.57ms +step:39/1670 train_time:3686ms step_avg:94.52ms +step:40/1670 train_time:3779ms step_avg:94.48ms +step:41/1670 train_time:3872ms step_avg:94.43ms +step:42/1670 train_time:3963ms step_avg:94.35ms +step:43/1670 train_time:4055ms step_avg:94.30ms +step:44/1670 train_time:4145ms step_avg:94.21ms +step:45/1670 train_time:4237ms step_avg:94.15ms +step:46/1670 train_time:4329ms step_avg:94.11ms +step:47/1670 train_time:4422ms step_avg:94.08ms +step:48/1670 train_time:4514ms step_avg:94.05ms +step:49/1670 train_time:4605ms step_avg:93.98ms +step:50/1670 train_time:4698ms step_avg:93.97ms +step:51/1670 train_time:4791ms step_avg:93.95ms +step:52/1670 train_time:4883ms step_avg:93.89ms +step:53/1670 train_time:4975ms step_avg:93.87ms +step:54/1670 train_time:5066ms step_avg:93.81ms +step:55/1670 train_time:5157ms step_avg:93.76ms +step:56/1670 train_time:5249ms step_avg:93.73ms +step:57/1670 train_time:5341ms step_avg:93.69ms +step:58/1670 train_time:5433ms step_avg:93.67ms +step:59/1670 train_time:5524ms step_avg:93.62ms +step:60/1670 train_time:5615ms step_avg:93.59ms +step:61/1670 train_time:5707ms step_avg:93.56ms +step:62/1670 train_time:5799ms step_avg:93.54ms +step:63/1670 train_time:5892ms step_avg:93.52ms +step:64/1670 train_time:5983ms step_avg:93.48ms +step:65/1670 train_time:6075ms step_avg:93.46ms +step:66/1670 train_time:6166ms step_avg:93.42ms +step:67/1670 train_time:6258ms step_avg:93.40ms +step:68/1670 train_time:6349ms step_avg:93.36ms +step:69/1670 train_time:6441ms step_avg:93.34ms +step:70/1670 train_time:6532ms step_avg:93.32ms +step:71/1670 train_time:6623ms step_avg:93.28ms +step:72/1670 train_time:6714ms step_avg:93.25ms +step:73/1670 train_time:6805ms step_avg:93.22ms +step:74/1670 train_time:6898ms step_avg:93.21ms +step:75/1670 train_time:6990ms step_avg:93.20ms +step:76/1670 train_time:7082ms step_avg:93.18ms +step:77/1670 train_time:7173ms step_avg:93.16ms +step:78/1670 train_time:7265ms step_avg:93.14ms +step:79/1670 train_time:7357ms step_avg:93.13ms +step:80/1670 train_time:7450ms step_avg:93.12ms +step:81/1670 train_time:7539ms step_avg:93.08ms +step:82/1670 train_time:7630ms step_avg:93.05ms +step:83/1670 train_time:7721ms step_avg:93.03ms +step:84/1670 train_time:7813ms step_avg:93.01ms +step:85/1670 train_time:7904ms step_avg:92.98ms +step:86/1670 train_time:7996ms step_avg:92.98ms +step:87/1670 train_time:8088ms step_avg:92.96ms +step:88/1670 train_time:8179ms step_avg:92.94ms +step:89/1670 train_time:8270ms step_avg:92.92ms +step:90/1670 train_time:8361ms step_avg:92.90ms +step:91/1670 train_time:8456ms step_avg:92.92ms +step:92/1670 train_time:8548ms step_avg:92.92ms +step:93/1670 train_time:8640ms step_avg:92.91ms +step:94/1670 train_time:8732ms step_avg:92.89ms 
+step:95/1670 train_time:8822ms step_avg:92.87ms +step:96/1670 train_time:8915ms step_avg:92.86ms +step:97/1670 train_time:9006ms step_avg:92.85ms +step:98/1670 train_time:9098ms step_avg:92.84ms +step:99/1670 train_time:9191ms step_avg:92.84ms +step:100/1670 train_time:9282ms step_avg:92.82ms +step:101/1670 train_time:9373ms step_avg:92.81ms +step:102/1670 train_time:9465ms step_avg:92.79ms +step:103/1670 train_time:9556ms step_avg:92.78ms +step:104/1670 train_time:9647ms step_avg:92.76ms +step:105/1670 train_time:9739ms step_avg:92.75ms +step:106/1670 train_time:9830ms step_avg:92.73ms +step:107/1670 train_time:9921ms step_avg:92.72ms +step:108/1670 train_time:10013ms step_avg:92.71ms +step:109/1670 train_time:10104ms step_avg:92.70ms +step:110/1670 train_time:10197ms step_avg:92.70ms +step:111/1670 train_time:10289ms step_avg:92.69ms +step:112/1670 train_time:10380ms step_avg:92.68ms +step:113/1670 train_time:10472ms step_avg:92.68ms +step:114/1670 train_time:10563ms step_avg:92.66ms +step:115/1670 train_time:10655ms step_avg:92.65ms +step:116/1670 train_time:10746ms step_avg:92.64ms +step:117/1670 train_time:10837ms step_avg:92.62ms +step:118/1670 train_time:10929ms step_avg:92.62ms +step:119/1670 train_time:11020ms step_avg:92.60ms +step:120/1670 train_time:11112ms step_avg:92.60ms +step:121/1670 train_time:11202ms step_avg:92.58ms +step:122/1670 train_time:11294ms step_avg:92.57ms +step:123/1670 train_time:11385ms step_avg:92.56ms +step:124/1670 train_time:11478ms step_avg:92.57ms +step:125/1670 train_time:11570ms step_avg:92.56ms +step:125/1670 val_loss:4.2895 train_time:11661ms step_avg:93.29ms +step:126/1670 train_time:11680ms step_avg:92.70ms +step:127/1670 train_time:11755ms step_avg:92.56ms +step:128/1670 train_time:11855ms step_avg:92.62ms +step:129/1670 train_time:11950ms step_avg:92.64ms +step:130/1670 train_time:12041ms step_avg:92.63ms +step:131/1670 train_time:12132ms step_avg:92.61ms +step:132/1670 train_time:12222ms step_avg:92.59ms +step:133/1670 train_time:12312ms step_avg:92.57ms +step:134/1670 train_time:12402ms step_avg:92.55ms +step:135/1670 train_time:12492ms step_avg:92.54ms +step:136/1670 train_time:12583ms step_avg:92.52ms +step:137/1670 train_time:12673ms step_avg:92.50ms +step:138/1670 train_time:12766ms step_avg:92.51ms +step:139/1670 train_time:12860ms step_avg:92.52ms +step:140/1670 train_time:12955ms step_avg:92.53ms +step:141/1670 train_time:13046ms step_avg:92.53ms +step:142/1670 train_time:13139ms step_avg:92.53ms +step:143/1670 train_time:13230ms step_avg:92.52ms +step:144/1670 train_time:13320ms step_avg:92.50ms +step:145/1670 train_time:13411ms step_avg:92.49ms +step:146/1670 train_time:13501ms step_avg:92.47ms +step:147/1670 train_time:13590ms step_avg:92.45ms +step:148/1670 train_time:13681ms step_avg:92.44ms +step:149/1670 train_time:13773ms step_avg:92.44ms +step:150/1670 train_time:13865ms step_avg:92.43ms +step:151/1670 train_time:13958ms step_avg:92.44ms +step:152/1670 train_time:14049ms step_avg:92.43ms +step:153/1670 train_time:14141ms step_avg:92.42ms +step:154/1670 train_time:14233ms step_avg:92.42ms +step:155/1670 train_time:14323ms step_avg:92.40ms +step:156/1670 train_time:14413ms step_avg:92.39ms +step:157/1670 train_time:14503ms step_avg:92.38ms +step:158/1670 train_time:14594ms step_avg:92.36ms +step:159/1670 train_time:14684ms step_avg:92.36ms +step:160/1670 train_time:14778ms step_avg:92.36ms +step:161/1670 train_time:14869ms step_avg:92.36ms +step:162/1670 train_time:14962ms step_avg:92.36ms +step:163/1670 train_time:15054ms 
step_avg:92.36ms +step:164/1670 train_time:15146ms step_avg:92.35ms +step:165/1670 train_time:15237ms step_avg:92.34ms +step:166/1670 train_time:15327ms step_avg:92.33ms +step:167/1670 train_time:15418ms step_avg:92.32ms +step:168/1670 train_time:15508ms step_avg:92.31ms +step:169/1670 train_time:15599ms step_avg:92.30ms +step:170/1670 train_time:15690ms step_avg:92.30ms +step:171/1670 train_time:15783ms step_avg:92.30ms +step:172/1670 train_time:15875ms step_avg:92.30ms +step:173/1670 train_time:15966ms step_avg:92.29ms +step:174/1670 train_time:16058ms step_avg:92.29ms +step:175/1670 train_time:16150ms step_avg:92.29ms +step:176/1670 train_time:16242ms step_avg:92.28ms +step:177/1670 train_time:16333ms step_avg:92.28ms +step:178/1670 train_time:16423ms step_avg:92.26ms +step:179/1670 train_time:16514ms step_avg:92.25ms +step:180/1670 train_time:16604ms step_avg:92.24ms +step:181/1670 train_time:16696ms step_avg:92.24ms +step:182/1670 train_time:16786ms step_avg:92.23ms +step:183/1670 train_time:16879ms step_avg:92.24ms +step:184/1670 train_time:16971ms step_avg:92.24ms +step:185/1670 train_time:17063ms step_avg:92.23ms +step:186/1670 train_time:17155ms step_avg:92.23ms +step:187/1670 train_time:17247ms step_avg:92.23ms +step:188/1670 train_time:17339ms step_avg:92.23ms +step:189/1670 train_time:17430ms step_avg:92.22ms +step:190/1670 train_time:17520ms step_avg:92.21ms +step:191/1670 train_time:17611ms step_avg:92.20ms +step:192/1670 train_time:17703ms step_avg:92.20ms +step:193/1670 train_time:17795ms step_avg:92.20ms +step:194/1670 train_time:17885ms step_avg:92.19ms +step:195/1670 train_time:17977ms step_avg:92.19ms +step:196/1670 train_time:18069ms step_avg:92.19ms +step:197/1670 train_time:18161ms step_avg:92.19ms +step:198/1670 train_time:18253ms step_avg:92.19ms +step:199/1670 train_time:18345ms step_avg:92.18ms +step:200/1670 train_time:18438ms step_avg:92.19ms +step:201/1670 train_time:18527ms step_avg:92.17ms +step:202/1670 train_time:18619ms step_avg:92.17ms +step:203/1670 train_time:18709ms step_avg:92.16ms +step:204/1670 train_time:18801ms step_avg:92.16ms +step:205/1670 train_time:18891ms step_avg:92.15ms +step:206/1670 train_time:18982ms step_avg:92.15ms +step:207/1670 train_time:19073ms step_avg:92.14ms +step:208/1670 train_time:19164ms step_avg:92.14ms +step:209/1670 train_time:19257ms step_avg:92.14ms +step:210/1670 train_time:19347ms step_avg:92.13ms +step:211/1670 train_time:19438ms step_avg:92.12ms +step:212/1670 train_time:19529ms step_avg:92.12ms +step:213/1670 train_time:19782ms step_avg:92.87ms +step:214/1670 train_time:19851ms step_avg:92.76ms +step:215/1670 train_time:19941ms step_avg:92.75ms +step:216/1670 train_time:20031ms step_avg:92.74ms +step:217/1670 train_time:20121ms step_avg:92.73ms +step:218/1670 train_time:20211ms step_avg:92.71ms +step:219/1670 train_time:20301ms step_avg:92.70ms +step:220/1670 train_time:20391ms step_avg:92.69ms +step:221/1670 train_time:20481ms step_avg:92.67ms +step:222/1670 train_time:20572ms step_avg:92.67ms +step:223/1670 train_time:20666ms step_avg:92.67ms +step:224/1670 train_time:20761ms step_avg:92.68ms +step:225/1670 train_time:20854ms step_avg:92.69ms +step:226/1670 train_time:20946ms step_avg:92.68ms +step:227/1670 train_time:21037ms step_avg:92.67ms +step:228/1670 train_time:21127ms step_avg:92.66ms +step:229/1670 train_time:21217ms step_avg:92.65ms +step:230/1670 train_time:21308ms step_avg:92.64ms +step:231/1670 train_time:21400ms step_avg:92.64ms +step:232/1670 train_time:21491ms step_avg:92.63ms +step:233/1670 
train_time:21583ms step_avg:92.63ms +step:234/1670 train_time:21676ms step_avg:92.63ms +step:235/1670 train_time:21769ms step_avg:92.63ms +step:236/1670 train_time:21862ms step_avg:92.64ms +step:237/1670 train_time:21955ms step_avg:92.64ms +step:238/1670 train_time:22045ms step_avg:92.63ms +step:239/1670 train_time:22136ms step_avg:92.62ms +step:240/1670 train_time:22226ms step_avg:92.61ms +step:241/1670 train_time:22316ms step_avg:92.60ms +step:242/1670 train_time:22406ms step_avg:92.59ms +step:243/1670 train_time:22498ms step_avg:92.58ms +step:244/1670 train_time:22589ms step_avg:92.58ms +step:245/1670 train_time:22683ms step_avg:92.58ms +step:246/1670 train_time:22775ms step_avg:92.58ms +step:247/1670 train_time:22867ms step_avg:92.58ms +step:248/1670 train_time:22961ms step_avg:92.58ms +step:249/1670 train_time:23052ms step_avg:92.58ms +step:250/1670 train_time:23144ms step_avg:92.57ms +step:250/1670 val_loss:3.9615 train_time:23234ms step_avg:92.94ms +step:251/1670 train_time:23252ms step_avg:92.64ms +step:252/1670 train_time:23327ms step_avg:92.57ms +step:253/1670 train_time:23419ms step_avg:92.57ms +step:254/1670 train_time:23516ms step_avg:92.58ms +step:255/1670 train_time:23608ms step_avg:92.58ms +step:256/1670 train_time:23699ms step_avg:92.57ms +step:257/1670 train_time:23789ms step_avg:92.56ms +step:258/1670 train_time:23880ms step_avg:92.56ms +step:259/1670 train_time:23969ms step_avg:92.55ms +step:260/1670 train_time:24060ms step_avg:92.54ms +step:261/1670 train_time:24153ms step_avg:92.54ms +step:262/1670 train_time:24248ms step_avg:92.55ms +step:263/1670 train_time:24340ms step_avg:92.55ms +step:264/1670 train_time:24432ms step_avg:92.55ms +step:265/1670 train_time:24525ms step_avg:92.55ms +step:266/1670 train_time:24618ms step_avg:92.55ms +step:267/1670 train_time:24709ms step_avg:92.54ms +step:268/1670 train_time:24800ms step_avg:92.54ms +step:269/1670 train_time:24890ms step_avg:92.53ms +step:270/1670 train_time:24981ms step_avg:92.52ms +step:271/1670 train_time:25071ms step_avg:92.51ms +step:272/1670 train_time:25162ms step_avg:92.51ms +step:273/1670 train_time:25253ms step_avg:92.50ms +step:274/1670 train_time:25347ms step_avg:92.51ms +step:275/1670 train_time:25439ms step_avg:92.50ms +step:276/1670 train_time:25530ms step_avg:92.50ms +step:277/1670 train_time:25624ms step_avg:92.50ms +step:278/1670 train_time:25717ms step_avg:92.51ms +step:279/1670 train_time:25807ms step_avg:92.50ms +step:280/1670 train_time:25898ms step_avg:92.49ms +step:281/1670 train_time:25989ms step_avg:92.49ms +step:282/1670 train_time:26080ms step_avg:92.48ms +step:283/1670 train_time:26170ms step_avg:92.47ms +step:284/1670 train_time:26261ms step_avg:92.47ms +step:285/1670 train_time:26353ms step_avg:92.47ms +step:286/1670 train_time:26444ms step_avg:92.46ms +step:287/1670 train_time:26536ms step_avg:92.46ms +step:288/1670 train_time:26628ms step_avg:92.46ms +step:289/1670 train_time:26720ms step_avg:92.46ms +step:290/1670 train_time:26811ms step_avg:92.45ms +step:291/1670 train_time:26902ms step_avg:92.45ms +step:292/1670 train_time:26993ms step_avg:92.44ms +step:293/1670 train_time:27085ms step_avg:92.44ms +step:294/1670 train_time:27175ms step_avg:92.43ms +step:295/1670 train_time:27266ms step_avg:92.43ms +step:296/1670 train_time:27358ms step_avg:92.42ms +step:297/1670 train_time:27449ms step_avg:92.42ms +step:298/1670 train_time:27541ms step_avg:92.42ms +step:299/1670 train_time:27634ms step_avg:92.42ms +step:300/1670 train_time:27725ms step_avg:92.42ms +step:301/1670 train_time:27817ms 
step_avg:92.41ms +step:302/1670 train_time:27908ms step_avg:92.41ms +step:303/1670 train_time:27999ms step_avg:92.41ms +step:304/1670 train_time:28090ms step_avg:92.40ms +step:305/1670 train_time:28181ms step_avg:92.40ms +step:306/1670 train_time:28273ms step_avg:92.40ms +step:307/1670 train_time:28365ms step_avg:92.39ms +step:308/1670 train_time:28456ms step_avg:92.39ms +step:309/1670 train_time:28547ms step_avg:92.38ms +step:310/1670 train_time:28639ms step_avg:92.38ms +step:311/1670 train_time:28729ms step_avg:92.38ms +step:312/1670 train_time:28822ms step_avg:92.38ms +step:313/1670 train_time:28914ms step_avg:92.38ms +step:314/1670 train_time:29006ms step_avg:92.37ms +step:315/1670 train_time:29096ms step_avg:92.37ms +step:316/1670 train_time:29187ms step_avg:92.36ms +step:317/1670 train_time:29278ms step_avg:92.36ms +step:318/1670 train_time:29369ms step_avg:92.36ms +step:319/1670 train_time:29460ms step_avg:92.35ms +step:320/1670 train_time:29551ms step_avg:92.35ms +step:321/1670 train_time:29642ms step_avg:92.34ms +step:322/1670 train_time:29733ms step_avg:92.34ms +step:323/1670 train_time:29824ms step_avg:92.33ms +step:324/1670 train_time:29916ms step_avg:92.33ms +step:325/1670 train_time:30007ms step_avg:92.33ms +step:326/1670 train_time:30098ms step_avg:92.33ms +step:327/1670 train_time:30189ms step_avg:92.32ms +step:328/1670 train_time:30281ms step_avg:92.32ms +step:329/1670 train_time:30372ms step_avg:92.32ms +step:330/1670 train_time:30464ms step_avg:92.31ms +step:331/1670 train_time:30554ms step_avg:92.31ms +step:332/1670 train_time:30645ms step_avg:92.30ms +step:333/1670 train_time:30735ms step_avg:92.30ms +step:334/1670 train_time:30826ms step_avg:92.29ms +step:335/1670 train_time:30918ms step_avg:92.29ms +step:336/1670 train_time:31008ms step_avg:92.29ms +step:337/1670 train_time:31101ms step_avg:92.29ms +step:338/1670 train_time:31193ms step_avg:92.29ms +step:339/1670 train_time:31284ms step_avg:92.28ms +step:340/1670 train_time:31375ms step_avg:92.28ms +step:341/1670 train_time:31466ms step_avg:92.28ms +step:342/1670 train_time:31558ms step_avg:92.28ms +step:343/1670 train_time:31649ms step_avg:92.27ms +step:344/1670 train_time:31740ms step_avg:92.27ms +step:345/1670 train_time:31831ms step_avg:92.26ms +step:346/1670 train_time:31923ms step_avg:92.26ms +step:347/1670 train_time:32014ms step_avg:92.26ms +step:348/1670 train_time:32106ms step_avg:92.26ms +step:349/1670 train_time:32198ms step_avg:92.26ms +step:350/1670 train_time:32289ms step_avg:92.25ms +step:351/1670 train_time:32381ms step_avg:92.25ms +step:352/1670 train_time:32472ms step_avg:92.25ms +step:353/1670 train_time:32564ms step_avg:92.25ms +step:354/1670 train_time:32656ms step_avg:92.25ms +step:355/1670 train_time:32747ms step_avg:92.25ms +step:356/1670 train_time:32838ms step_avg:92.24ms +step:357/1670 train_time:32929ms step_avg:92.24ms +step:358/1670 train_time:33021ms step_avg:92.24ms +step:359/1670 train_time:33112ms step_avg:92.23ms +step:360/1670 train_time:33205ms step_avg:92.24ms +step:361/1670 train_time:33296ms step_avg:92.23ms +step:362/1670 train_time:33387ms step_avg:92.23ms +step:363/1670 train_time:33479ms step_avg:92.23ms +step:364/1670 train_time:33569ms step_avg:92.22ms +step:365/1670 train_time:33661ms step_avg:92.22ms +step:366/1670 train_time:33752ms step_avg:92.22ms +step:367/1670 train_time:33844ms step_avg:92.22ms +step:368/1670 train_time:33936ms step_avg:92.22ms +step:369/1670 train_time:34026ms step_avg:92.21ms +step:370/1670 train_time:34117ms step_avg:92.21ms +step:371/1670 
train_time:34208ms step_avg:92.20ms +step:372/1670 train_time:34300ms step_avg:92.20ms +step:373/1670 train_time:34391ms step_avg:92.20ms +step:374/1670 train_time:34482ms step_avg:92.20ms +step:375/1670 train_time:34573ms step_avg:92.20ms +step:375/1670 val_loss:3.8098 train_time:34664ms step_avg:92.44ms +step:376/1670 train_time:34681ms step_avg:92.24ms +step:377/1670 train_time:34759ms step_avg:92.20ms +step:378/1670 train_time:34852ms step_avg:92.20ms +step:379/1670 train_time:34942ms step_avg:92.20ms +step:380/1670 train_time:35033ms step_avg:92.19ms +step:381/1670 train_time:35124ms step_avg:92.19ms +step:382/1670 train_time:35215ms step_avg:92.19ms +step:383/1670 train_time:35306ms step_avg:92.18ms +step:384/1670 train_time:35397ms step_avg:92.18ms +step:385/1670 train_time:35488ms step_avg:92.18ms +step:386/1670 train_time:35579ms step_avg:92.17ms +step:387/1670 train_time:35671ms step_avg:92.17ms +step:388/1670 train_time:35764ms step_avg:92.18ms +step:389/1670 train_time:35857ms step_avg:92.18ms +step:390/1670 train_time:35948ms step_avg:92.17ms +step:391/1670 train_time:36039ms step_avg:92.17ms +step:392/1670 train_time:36130ms step_avg:92.17ms +step:393/1670 train_time:36220ms step_avg:92.16ms +step:394/1670 train_time:36313ms step_avg:92.16ms +step:395/1670 train_time:36403ms step_avg:92.16ms +step:396/1670 train_time:36495ms step_avg:92.16ms +step:397/1670 train_time:36586ms step_avg:92.16ms +step:398/1670 train_time:36679ms step_avg:92.16ms +step:399/1670 train_time:36771ms step_avg:92.16ms +step:400/1670 train_time:36862ms step_avg:92.16ms +step:401/1670 train_time:36954ms step_avg:92.15ms +step:402/1670 train_time:37045ms step_avg:92.15ms +step:403/1670 train_time:37136ms step_avg:92.15ms +step:404/1670 train_time:37227ms step_avg:92.15ms +step:405/1670 train_time:37317ms step_avg:92.14ms +step:406/1670 train_time:37407ms step_avg:92.14ms +step:407/1670 train_time:37497ms step_avg:92.13ms +step:408/1670 train_time:37588ms step_avg:92.13ms +step:409/1670 train_time:37679ms step_avg:92.12ms +step:410/1670 train_time:37771ms step_avg:92.12ms +step:411/1670 train_time:37862ms step_avg:92.12ms +step:412/1670 train_time:37955ms step_avg:92.12ms +step:413/1670 train_time:38046ms step_avg:92.12ms +step:414/1670 train_time:38137ms step_avg:92.12ms +step:415/1670 train_time:38228ms step_avg:92.11ms +step:416/1670 train_time:38318ms step_avg:92.11ms +step:417/1670 train_time:38408ms step_avg:92.11ms +step:418/1670 train_time:38499ms step_avg:92.10ms +step:419/1670 train_time:38590ms step_avg:92.10ms +step:420/1670 train_time:38681ms step_avg:92.10ms +step:421/1670 train_time:38773ms step_avg:92.10ms +step:422/1670 train_time:38864ms step_avg:92.10ms +step:423/1670 train_time:38956ms step_avg:92.10ms +step:424/1670 train_time:39048ms step_avg:92.09ms +step:425/1670 train_time:39295ms step_avg:92.46ms +step:426/1670 train_time:39366ms step_avg:92.41ms +step:427/1670 train_time:39456ms step_avg:92.40ms +step:428/1670 train_time:39546ms step_avg:92.40ms +step:429/1670 train_time:39636ms step_avg:92.39ms +step:430/1670 train_time:39727ms step_avg:92.39ms +step:431/1670 train_time:39817ms step_avg:92.38ms +step:432/1670 train_time:39907ms step_avg:92.38ms +step:433/1670 train_time:39997ms step_avg:92.37ms +step:434/1670 train_time:40087ms step_avg:92.37ms +step:435/1670 train_time:40183ms step_avg:92.37ms +step:436/1670 train_time:40281ms step_avg:92.39ms +step:437/1670 train_time:40375ms step_avg:92.39ms +step:438/1670 train_time:40466ms step_avg:92.39ms +step:439/1670 train_time:40556ms 
step_avg:92.38ms +step:440/1670 train_time:40646ms step_avg:92.38ms +step:441/1670 train_time:40736ms step_avg:92.37ms +step:442/1670 train_time:40827ms step_avg:92.37ms +step:443/1670 train_time:40917ms step_avg:92.36ms +step:444/1670 train_time:41007ms step_avg:92.36ms +step:445/1670 train_time:41098ms step_avg:92.36ms +step:446/1670 train_time:41193ms step_avg:92.36ms +step:447/1670 train_time:41285ms step_avg:92.36ms +step:448/1670 train_time:41377ms step_avg:92.36ms +step:449/1670 train_time:41469ms step_avg:92.36ms +step:450/1670 train_time:41560ms step_avg:92.35ms +step:451/1670 train_time:41650ms step_avg:92.35ms +step:452/1670 train_time:41741ms step_avg:92.35ms +step:453/1670 train_time:41831ms step_avg:92.34ms +step:454/1670 train_time:41921ms step_avg:92.34ms +step:455/1670 train_time:42012ms step_avg:92.33ms +step:456/1670 train_time:42104ms step_avg:92.33ms +step:457/1670 train_time:42197ms step_avg:92.34ms +step:458/1670 train_time:42291ms step_avg:92.34ms +step:459/1670 train_time:42381ms step_avg:92.33ms +step:460/1670 train_time:42474ms step_avg:92.33ms +step:461/1670 train_time:42564ms step_avg:92.33ms +step:462/1670 train_time:42656ms step_avg:92.33ms +step:463/1670 train_time:42747ms step_avg:92.33ms +step:464/1670 train_time:42837ms step_avg:92.32ms +step:465/1670 train_time:42928ms step_avg:92.32ms +step:466/1670 train_time:43019ms step_avg:92.32ms +step:467/1670 train_time:43110ms step_avg:92.31ms +step:468/1670 train_time:43202ms step_avg:92.31ms +step:469/1670 train_time:43295ms step_avg:92.31ms +step:470/1670 train_time:43388ms step_avg:92.32ms +step:471/1670 train_time:43480ms step_avg:92.31ms +step:472/1670 train_time:43571ms step_avg:92.31ms +step:473/1670 train_time:43662ms step_avg:92.31ms +step:474/1670 train_time:43754ms step_avg:92.31ms +step:475/1670 train_time:43845ms step_avg:92.31ms +step:476/1670 train_time:43935ms step_avg:92.30ms +step:477/1670 train_time:44026ms step_avg:92.30ms +step:478/1670 train_time:44117ms step_avg:92.29ms +step:479/1670 train_time:44208ms step_avg:92.29ms +step:480/1670 train_time:44300ms step_avg:92.29ms +step:481/1670 train_time:44392ms step_avg:92.29ms +step:482/1670 train_time:44483ms step_avg:92.29ms +step:483/1670 train_time:44574ms step_avg:92.29ms +step:484/1670 train_time:44665ms step_avg:92.28ms +step:485/1670 train_time:44757ms step_avg:92.28ms +step:486/1670 train_time:44848ms step_avg:92.28ms +step:487/1670 train_time:44939ms step_avg:92.28ms +step:488/1670 train_time:45030ms step_avg:92.28ms +step:489/1670 train_time:45121ms step_avg:92.27ms +step:490/1670 train_time:45213ms step_avg:92.27ms +step:491/1670 train_time:45304ms step_avg:92.27ms +step:492/1670 train_time:45397ms step_avg:92.27ms +step:493/1670 train_time:45488ms step_avg:92.27ms +step:494/1670 train_time:45580ms step_avg:92.27ms +step:495/1670 train_time:45670ms step_avg:92.26ms +step:496/1670 train_time:45760ms step_avg:92.26ms +step:497/1670 train_time:45851ms step_avg:92.26ms +step:498/1670 train_time:45943ms step_avg:92.26ms +step:499/1670 train_time:46034ms step_avg:92.25ms +step:500/1670 train_time:46126ms step_avg:92.25ms +step:500/1670 val_loss:3.7097 train_time:46217ms step_avg:92.43ms +step:501/1670 train_time:46234ms step_avg:92.28ms +step:502/1670 train_time:46311ms step_avg:92.25ms +step:503/1670 train_time:46403ms step_avg:92.25ms +step:504/1670 train_time:46495ms step_avg:92.25ms +step:505/1670 train_time:46585ms step_avg:92.25ms +step:506/1670 train_time:46677ms step_avg:92.25ms +step:507/1670 train_time:46768ms step_avg:92.24ms 
+step:508/1670 train_time:46859ms step_avg:92.24ms +step:509/1670 train_time:46949ms step_avg:92.24ms +step:510/1670 train_time:47041ms step_avg:92.24ms +step:511/1670 train_time:47132ms step_avg:92.24ms +step:512/1670 train_time:47227ms step_avg:92.24ms +step:513/1670 train_time:47318ms step_avg:92.24ms +step:514/1670 train_time:47410ms step_avg:92.24ms +step:515/1670 train_time:47501ms step_avg:92.23ms +step:516/1670 train_time:47591ms step_avg:92.23ms +step:517/1670 train_time:47682ms step_avg:92.23ms +step:518/1670 train_time:47774ms step_avg:92.23ms +step:519/1670 train_time:47865ms step_avg:92.23ms +step:520/1670 train_time:47956ms step_avg:92.22ms +step:521/1670 train_time:48047ms step_avg:92.22ms +step:522/1670 train_time:48138ms step_avg:92.22ms +step:523/1670 train_time:48230ms step_avg:92.22ms +step:524/1670 train_time:48322ms step_avg:92.22ms +step:525/1670 train_time:48413ms step_avg:92.22ms +step:526/1670 train_time:48506ms step_avg:92.22ms +step:527/1670 train_time:48597ms step_avg:92.21ms +step:528/1670 train_time:48688ms step_avg:92.21ms +step:529/1670 train_time:48779ms step_avg:92.21ms +step:530/1670 train_time:48869ms step_avg:92.21ms +step:531/1670 train_time:48960ms step_avg:92.20ms +step:532/1670 train_time:49050ms step_avg:92.20ms +step:533/1670 train_time:49142ms step_avg:92.20ms +step:534/1670 train_time:49234ms step_avg:92.20ms +step:535/1670 train_time:49327ms step_avg:92.20ms +step:536/1670 train_time:49418ms step_avg:92.20ms +step:537/1670 train_time:49509ms step_avg:92.20ms +step:538/1670 train_time:49600ms step_avg:92.19ms +step:539/1670 train_time:49691ms step_avg:92.19ms +step:540/1670 train_time:49782ms step_avg:92.19ms +step:541/1670 train_time:49872ms step_avg:92.18ms +step:542/1670 train_time:49963ms step_avg:92.18ms +step:543/1670 train_time:50055ms step_avg:92.18ms +step:544/1670 train_time:50148ms step_avg:92.18ms +step:545/1670 train_time:50239ms step_avg:92.18ms +step:546/1670 train_time:50330ms step_avg:92.18ms +step:547/1670 train_time:50423ms step_avg:92.18ms +step:548/1670 train_time:50513ms step_avg:92.18ms +step:549/1670 train_time:50605ms step_avg:92.18ms +step:550/1670 train_time:50695ms step_avg:92.17ms +step:551/1670 train_time:50786ms step_avg:92.17ms +step:552/1670 train_time:50877ms step_avg:92.17ms +step:553/1670 train_time:50969ms step_avg:92.17ms +step:554/1670 train_time:51060ms step_avg:92.17ms +step:555/1670 train_time:51151ms step_avg:92.16ms +step:556/1670 train_time:51243ms step_avg:92.16ms +step:557/1670 train_time:51334ms step_avg:92.16ms +step:558/1670 train_time:51619ms step_avg:92.51ms +step:559/1670 train_time:51697ms step_avg:92.48ms +step:560/1670 train_time:51787ms step_avg:92.48ms +step:561/1670 train_time:51878ms step_avg:92.47ms +step:562/1670 train_time:51969ms step_avg:92.47ms +step:563/1670 train_time:52061ms step_avg:92.47ms +step:564/1670 train_time:52152ms step_avg:92.47ms +step:565/1670 train_time:52243ms step_avg:92.47ms +step:566/1670 train_time:52334ms step_avg:92.46ms +step:567/1670 train_time:52426ms step_avg:92.46ms +step:568/1670 train_time:52523ms step_avg:92.47ms +step:569/1670 train_time:52620ms step_avg:92.48ms +step:570/1670 train_time:52714ms step_avg:92.48ms +step:571/1670 train_time:52807ms step_avg:92.48ms +step:572/1670 train_time:52899ms step_avg:92.48ms +step:573/1670 train_time:52990ms step_avg:92.48ms +step:574/1670 train_time:53082ms step_avg:92.48ms +step:575/1670 train_time:53173ms step_avg:92.47ms +step:576/1670 train_time:53265ms step_avg:92.47ms +step:577/1670 train_time:53357ms 
step_avg:92.47ms +step:578/1670 train_time:53450ms step_avg:92.47ms +step:579/1670 train_time:53545ms step_avg:92.48ms +step:580/1670 train_time:53640ms step_avg:92.48ms +step:581/1670 train_time:53732ms step_avg:92.48ms +step:582/1670 train_time:53826ms step_avg:92.48ms +step:583/1670 train_time:53918ms step_avg:92.48ms +step:584/1670 train_time:54010ms step_avg:92.48ms +step:585/1670 train_time:54103ms step_avg:92.48ms +step:586/1670 train_time:54195ms step_avg:92.48ms +step:587/1670 train_time:54287ms step_avg:92.48ms +step:588/1670 train_time:54379ms step_avg:92.48ms +step:589/1670 train_time:54471ms step_avg:92.48ms +step:590/1670 train_time:54565ms step_avg:92.48ms +step:591/1670 train_time:54660ms step_avg:92.49ms +step:592/1670 train_time:54753ms step_avg:92.49ms +step:593/1670 train_time:54847ms step_avg:92.49ms +step:594/1670 train_time:54939ms step_avg:92.49ms +step:595/1670 train_time:55031ms step_avg:92.49ms +step:596/1670 train_time:55124ms step_avg:92.49ms +step:597/1670 train_time:55215ms step_avg:92.49ms +step:598/1670 train_time:55307ms step_avg:92.49ms +step:599/1670 train_time:55399ms step_avg:92.49ms +step:600/1670 train_time:55492ms step_avg:92.49ms +step:601/1670 train_time:55586ms step_avg:92.49ms +step:602/1670 train_time:55679ms step_avg:92.49ms +step:603/1670 train_time:55771ms step_avg:92.49ms +step:604/1670 train_time:55865ms step_avg:92.49ms +step:605/1670 train_time:55956ms step_avg:92.49ms +step:606/1670 train_time:56049ms step_avg:92.49ms +step:607/1670 train_time:56141ms step_avg:92.49ms +step:608/1670 train_time:56233ms step_avg:92.49ms +step:609/1670 train_time:56325ms step_avg:92.49ms +step:610/1670 train_time:56418ms step_avg:92.49ms +step:611/1670 train_time:56510ms step_avg:92.49ms +step:612/1670 train_time:56603ms step_avg:92.49ms +step:613/1670 train_time:56695ms step_avg:92.49ms +step:614/1670 train_time:56788ms step_avg:92.49ms +step:615/1670 train_time:56881ms step_avg:92.49ms +step:616/1670 train_time:56973ms step_avg:92.49ms +step:617/1670 train_time:57066ms step_avg:92.49ms +step:618/1670 train_time:57158ms step_avg:92.49ms +step:619/1670 train_time:57250ms step_avg:92.49ms +step:620/1670 train_time:57343ms step_avg:92.49ms +step:621/1670 train_time:57435ms step_avg:92.49ms +step:622/1670 train_time:57528ms step_avg:92.49ms +step:623/1670 train_time:57621ms step_avg:92.49ms +step:624/1670 train_time:57714ms step_avg:92.49ms +step:625/1670 train_time:57807ms step_avg:92.49ms +step:625/1670 val_loss:3.6092 train_time:57899ms step_avg:92.64ms +step:626/1670 train_time:57917ms step_avg:92.52ms +step:627/1670 train_time:58000ms step_avg:92.50ms +step:628/1670 train_time:58099ms step_avg:92.51ms +step:629/1670 train_time:58193ms step_avg:92.52ms +step:630/1670 train_time:58286ms step_avg:92.52ms +step:631/1670 train_time:58376ms step_avg:92.51ms +step:632/1670 train_time:58467ms step_avg:92.51ms +step:633/1670 train_time:58558ms step_avg:92.51ms +step:634/1670 train_time:58650ms step_avg:92.51ms +step:635/1670 train_time:58741ms step_avg:92.51ms +step:636/1670 train_time:58832ms step_avg:92.50ms +step:637/1670 train_time:58926ms step_avg:92.51ms +step:638/1670 train_time:59021ms step_avg:92.51ms +step:639/1670 train_time:59256ms step_avg:92.73ms +step:640/1670 train_time:59331ms step_avg:92.70ms +step:641/1670 train_time:59422ms step_avg:92.70ms +step:642/1670 train_time:59513ms step_avg:92.70ms +step:643/1670 train_time:59604ms step_avg:92.70ms +step:644/1670 train_time:59695ms step_avg:92.69ms +step:645/1670 train_time:59787ms step_avg:92.69ms 
+step:646/1670 train_time:59878ms step_avg:92.69ms +step:647/1670 train_time:59969ms step_avg:92.69ms +step:648/1670 train_time:60060ms step_avg:92.69ms +step:649/1670 train_time:60161ms step_avg:92.70ms +step:650/1670 train_time:60258ms step_avg:92.70ms +step:651/1670 train_time:60353ms step_avg:92.71ms +step:652/1670 train_time:60446ms step_avg:92.71ms +step:653/1670 train_time:60537ms step_avg:92.71ms +step:654/1670 train_time:60628ms step_avg:92.70ms +step:655/1670 train_time:60720ms step_avg:92.70ms +step:656/1670 train_time:60811ms step_avg:92.70ms +step:657/1670 train_time:60902ms step_avg:92.70ms +step:658/1670 train_time:60994ms step_avg:92.70ms +step:659/1670 train_time:61089ms step_avg:92.70ms +step:660/1670 train_time:61183ms step_avg:92.70ms +step:661/1670 train_time:61277ms step_avg:92.70ms +step:662/1670 train_time:61371ms step_avg:92.71ms +step:663/1670 train_time:61464ms step_avg:92.71ms +step:664/1670 train_time:61557ms step_avg:92.71ms +step:665/1670 train_time:61648ms step_avg:92.70ms +step:666/1670 train_time:61740ms step_avg:92.70ms +step:667/1670 train_time:61831ms step_avg:92.70ms +step:668/1670 train_time:61923ms step_avg:92.70ms +step:669/1670 train_time:62014ms step_avg:92.70ms +step:670/1670 train_time:62108ms step_avg:92.70ms +step:671/1670 train_time:62201ms step_avg:92.70ms +step:672/1670 train_time:62296ms step_avg:92.70ms +step:673/1670 train_time:62390ms step_avg:92.70ms +step:674/1670 train_time:62482ms step_avg:92.70ms +step:675/1670 train_time:62576ms step_avg:92.70ms +step:676/1670 train_time:62668ms step_avg:92.70ms +step:677/1670 train_time:62760ms step_avg:92.70ms +step:678/1670 train_time:62851ms step_avg:92.70ms +step:679/1670 train_time:62943ms step_avg:92.70ms +step:680/1670 train_time:63034ms step_avg:92.70ms +step:681/1670 train_time:63127ms step_avg:92.70ms +step:682/1670 train_time:63220ms step_avg:92.70ms +step:683/1670 train_time:63313ms step_avg:92.70ms +step:684/1670 train_time:63407ms step_avg:92.70ms +step:685/1670 train_time:63499ms step_avg:92.70ms +step:686/1670 train_time:63593ms step_avg:92.70ms +step:687/1670 train_time:63686ms step_avg:92.70ms +step:688/1670 train_time:63777ms step_avg:92.70ms +step:689/1670 train_time:63870ms step_avg:92.70ms +step:690/1670 train_time:63961ms step_avg:92.70ms +step:691/1670 train_time:64053ms step_avg:92.70ms +step:692/1670 train_time:64146ms step_avg:92.70ms +step:693/1670 train_time:64239ms step_avg:92.70ms +step:694/1670 train_time:64332ms step_avg:92.70ms +step:695/1670 train_time:64425ms step_avg:92.70ms +step:696/1670 train_time:64518ms step_avg:92.70ms +step:697/1670 train_time:64612ms step_avg:92.70ms +step:698/1670 train_time:64704ms step_avg:92.70ms +step:699/1670 train_time:64797ms step_avg:92.70ms +step:700/1670 train_time:64889ms step_avg:92.70ms +step:701/1670 train_time:64981ms step_avg:92.70ms +step:702/1670 train_time:65074ms step_avg:92.70ms +step:703/1670 train_time:65167ms step_avg:92.70ms +step:704/1670 train_time:65260ms step_avg:92.70ms +step:705/1670 train_time:65352ms step_avg:92.70ms +step:706/1670 train_time:65445ms step_avg:92.70ms +step:707/1670 train_time:65537ms step_avg:92.70ms +step:708/1670 train_time:65629ms step_avg:92.70ms +step:709/1670 train_time:65723ms step_avg:92.70ms +step:710/1670 train_time:65815ms step_avg:92.70ms +step:711/1670 train_time:65908ms step_avg:92.70ms +step:712/1670 train_time:65999ms step_avg:92.70ms +step:713/1670 train_time:66092ms step_avg:92.70ms +step:714/1670 train_time:66185ms step_avg:92.70ms +step:715/1670 train_time:66278ms 
step_avg:92.70ms +step:716/1670 train_time:66371ms step_avg:92.70ms +step:717/1670 train_time:66463ms step_avg:92.70ms +step:718/1670 train_time:66557ms step_avg:92.70ms +step:719/1670 train_time:66649ms step_avg:92.70ms +step:720/1670 train_time:66742ms step_avg:92.70ms +step:721/1670 train_time:66834ms step_avg:92.70ms +step:722/1670 train_time:66926ms step_avg:92.70ms +step:723/1670 train_time:67018ms step_avg:92.69ms +step:724/1670 train_time:67111ms step_avg:92.69ms +step:725/1670 train_time:67202ms step_avg:92.69ms +step:726/1670 train_time:67297ms step_avg:92.70ms +step:727/1670 train_time:67391ms step_avg:92.70ms +step:728/1670 train_time:67483ms step_avg:92.70ms +step:729/1670 train_time:67575ms step_avg:92.70ms +step:730/1670 train_time:67668ms step_avg:92.70ms +step:731/1670 train_time:67760ms step_avg:92.69ms +step:732/1670 train_time:67853ms step_avg:92.69ms +step:733/1670 train_time:67945ms step_avg:92.69ms +step:734/1670 train_time:68037ms step_avg:92.69ms +step:735/1670 train_time:68129ms step_avg:92.69ms +step:736/1670 train_time:68221ms step_avg:92.69ms +step:737/1670 train_time:68316ms step_avg:92.69ms +step:738/1670 train_time:68409ms step_avg:92.70ms +step:739/1670 train_time:68501ms step_avg:92.69ms +step:740/1670 train_time:68595ms step_avg:92.70ms +step:741/1670 train_time:68687ms step_avg:92.70ms +step:742/1670 train_time:68779ms step_avg:92.69ms +step:743/1670 train_time:68872ms step_avg:92.69ms +step:744/1670 train_time:68963ms step_avg:92.69ms +step:745/1670 train_time:69056ms step_avg:92.69ms +step:746/1670 train_time:69149ms step_avg:92.69ms +step:747/1670 train_time:69242ms step_avg:92.69ms +step:748/1670 train_time:69335ms step_avg:92.69ms +step:749/1670 train_time:69428ms step_avg:92.69ms +step:750/1670 train_time:69520ms step_avg:92.69ms +step:750/1670 val_loss:3.5584 train_time:69612ms step_avg:92.82ms +step:751/1670 train_time:69629ms step_avg:92.72ms +step:752/1670 train_time:69707ms step_avg:92.70ms +step:753/1670 train_time:69800ms step_avg:92.70ms +step:754/1670 train_time:69892ms step_avg:92.69ms +step:755/1670 train_time:69983ms step_avg:92.69ms +step:756/1670 train_time:70075ms step_avg:92.69ms +step:757/1670 train_time:70167ms step_avg:92.69ms +step:758/1670 train_time:70259ms step_avg:92.69ms +step:759/1670 train_time:70352ms step_avg:92.69ms +step:760/1670 train_time:70445ms step_avg:92.69ms +step:761/1670 train_time:70538ms step_avg:92.69ms +step:762/1670 train_time:70633ms step_avg:92.69ms +step:763/1670 train_time:70726ms step_avg:92.69ms +step:764/1670 train_time:70819ms step_avg:92.70ms +step:765/1670 train_time:70912ms step_avg:92.70ms +step:766/1670 train_time:71004ms step_avg:92.69ms +step:767/1670 train_time:71096ms step_avg:92.69ms +step:768/1670 train_time:71187ms step_avg:92.69ms +step:769/1670 train_time:71279ms step_avg:92.69ms +step:770/1670 train_time:71372ms step_avg:92.69ms +step:771/1670 train_time:71465ms step_avg:92.69ms +step:772/1670 train_time:71559ms step_avg:92.69ms +step:773/1670 train_time:71653ms step_avg:92.69ms +step:774/1670 train_time:71745ms step_avg:92.69ms +step:775/1670 train_time:71837ms step_avg:92.69ms +step:776/1670 train_time:71931ms step_avg:92.69ms +step:777/1670 train_time:72023ms step_avg:92.69ms +step:778/1670 train_time:72115ms step_avg:92.69ms +step:779/1670 train_time:72207ms step_avg:92.69ms +step:780/1670 train_time:72299ms step_avg:92.69ms +step:781/1670 train_time:72391ms step_avg:92.69ms +step:782/1670 train_time:72484ms step_avg:92.69ms +step:783/1670 train_time:72577ms step_avg:92.69ms 
+step:784/1670 train_time:72670ms step_avg:92.69ms +step:785/1670 train_time:72763ms step_avg:92.69ms +step:786/1670 train_time:72856ms step_avg:92.69ms +step:787/1670 train_time:72949ms step_avg:92.69ms +step:788/1670 train_time:73042ms step_avg:92.69ms +step:789/1670 train_time:73135ms step_avg:92.69ms +step:790/1670 train_time:73226ms step_avg:92.69ms +step:791/1670 train_time:73318ms step_avg:92.69ms +step:792/1670 train_time:73411ms step_avg:92.69ms +step:793/1670 train_time:73504ms step_avg:92.69ms +step:794/1670 train_time:73597ms step_avg:92.69ms +step:795/1670 train_time:73691ms step_avg:92.69ms +step:796/1670 train_time:73784ms step_avg:92.69ms +step:797/1670 train_time:73877ms step_avg:92.69ms +step:798/1670 train_time:73972ms step_avg:92.70ms +step:799/1670 train_time:74065ms step_avg:92.70ms +step:800/1670 train_time:74156ms step_avg:92.69ms +step:801/1670 train_time:74250ms step_avg:92.70ms +step:802/1670 train_time:74343ms step_avg:92.70ms +step:803/1670 train_time:74435ms step_avg:92.70ms +step:804/1670 train_time:74527ms step_avg:92.69ms +step:805/1670 train_time:74619ms step_avg:92.69ms +step:806/1670 train_time:74711ms step_avg:92.69ms +step:807/1670 train_time:74804ms step_avg:92.69ms +step:808/1670 train_time:74896ms step_avg:92.69ms +step:809/1670 train_time:74989ms step_avg:92.69ms +step:810/1670 train_time:75082ms step_avg:92.69ms +step:811/1670 train_time:75174ms step_avg:92.69ms +step:812/1670 train_time:75267ms step_avg:92.69ms +step:813/1670 train_time:75359ms step_avg:92.69ms +step:814/1670 train_time:75452ms step_avg:92.69ms +step:815/1670 train_time:75544ms step_avg:92.69ms +step:816/1670 train_time:75637ms step_avg:92.69ms +step:817/1670 train_time:75730ms step_avg:92.69ms +step:818/1670 train_time:75822ms step_avg:92.69ms +step:819/1670 train_time:75915ms step_avg:92.69ms +step:820/1670 train_time:76009ms step_avg:92.69ms +step:821/1670 train_time:76102ms step_avg:92.69ms +step:822/1670 train_time:76193ms step_avg:92.69ms +step:823/1670 train_time:76286ms step_avg:92.69ms +step:824/1670 train_time:76378ms step_avg:92.69ms +step:825/1670 train_time:76472ms step_avg:92.69ms +step:826/1670 train_time:76564ms step_avg:92.69ms +step:827/1670 train_time:76656ms step_avg:92.69ms +step:828/1670 train_time:76749ms step_avg:92.69ms +step:829/1670 train_time:76840ms step_avg:92.69ms +step:830/1670 train_time:76933ms step_avg:92.69ms +step:831/1670 train_time:77025ms step_avg:92.69ms +step:832/1670 train_time:77117ms step_avg:92.69ms +step:833/1670 train_time:77211ms step_avg:92.69ms +step:834/1670 train_time:77304ms step_avg:92.69ms +step:835/1670 train_time:77396ms step_avg:92.69ms +step:836/1670 train_time:77489ms step_avg:92.69ms +step:837/1670 train_time:77581ms step_avg:92.69ms +step:838/1670 train_time:77674ms step_avg:92.69ms +step:839/1670 train_time:77766ms step_avg:92.69ms +step:840/1670 train_time:77858ms step_avg:92.69ms +step:841/1670 train_time:77952ms step_avg:92.69ms +step:842/1670 train_time:78044ms step_avg:92.69ms +step:843/1670 train_time:78136ms step_avg:92.69ms +step:844/1670 train_time:78230ms step_avg:92.69ms +step:845/1670 train_time:78321ms step_avg:92.69ms +step:846/1670 train_time:78414ms step_avg:92.69ms +step:847/1670 train_time:78508ms step_avg:92.69ms +step:848/1670 train_time:78600ms step_avg:92.69ms +step:849/1670 train_time:78693ms step_avg:92.69ms +step:850/1670 train_time:78785ms step_avg:92.69ms +step:851/1670 train_time:79032ms step_avg:92.87ms +step:852/1670 train_time:79108ms step_avg:92.85ms +step:853/1670 train_time:79199ms 
step_avg:92.85ms +step:854/1670 train_time:79290ms step_avg:92.85ms +step:855/1670 train_time:79381ms step_avg:92.84ms +step:856/1670 train_time:79472ms step_avg:92.84ms +step:857/1670 train_time:79564ms step_avg:92.84ms +step:858/1670 train_time:79655ms step_avg:92.84ms +step:859/1670 train_time:79746ms step_avg:92.84ms +step:860/1670 train_time:79837ms step_avg:92.83ms +step:861/1670 train_time:79935ms step_avg:92.84ms +step:862/1670 train_time:80034ms step_avg:92.85ms +step:863/1670 train_time:80128ms step_avg:92.85ms +step:864/1670 train_time:80220ms step_avg:92.85ms +step:865/1670 train_time:80312ms step_avg:92.85ms +step:866/1670 train_time:80404ms step_avg:92.84ms +step:867/1670 train_time:80495ms step_avg:92.84ms +step:868/1670 train_time:80588ms step_avg:92.84ms +step:869/1670 train_time:80680ms step_avg:92.84ms +step:870/1670 train_time:80771ms step_avg:92.84ms +step:871/1670 train_time:80864ms step_avg:92.84ms +step:872/1670 train_time:80958ms step_avg:92.84ms +step:873/1670 train_time:81054ms step_avg:92.85ms +step:874/1670 train_time:81149ms step_avg:92.85ms +step:875/1670 train_time:81241ms step_avg:92.85ms +step:875/1670 val_loss:3.5154 train_time:81333ms step_avg:92.95ms +step:876/1670 train_time:81350ms step_avg:92.87ms +step:877/1670 train_time:81428ms step_avg:92.85ms +step:878/1670 train_time:81521ms step_avg:92.85ms +step:879/1670 train_time:81612ms step_avg:92.85ms +step:880/1670 train_time:81705ms step_avg:92.85ms +step:881/1670 train_time:81796ms step_avg:92.84ms +step:882/1670 train_time:81887ms step_avg:92.84ms +step:883/1670 train_time:81979ms step_avg:92.84ms +step:884/1670 train_time:82071ms step_avg:92.84ms +step:885/1670 train_time:82165ms step_avg:92.84ms +step:886/1670 train_time:82258ms step_avg:92.84ms +step:887/1670 train_time:82353ms step_avg:92.84ms +step:888/1670 train_time:82447ms step_avg:92.85ms +step:889/1670 train_time:82540ms step_avg:92.85ms +step:890/1670 train_time:82632ms step_avg:92.84ms +step:891/1670 train_time:82724ms step_avg:92.84ms +step:892/1670 train_time:82815ms step_avg:92.84ms +step:893/1670 train_time:82907ms step_avg:92.84ms +step:894/1670 train_time:83000ms step_avg:92.84ms +step:895/1670 train_time:83091ms step_avg:92.84ms +step:896/1670 train_time:83185ms step_avg:92.84ms +step:897/1670 train_time:83280ms step_avg:92.84ms +step:898/1670 train_time:83373ms step_avg:92.84ms +step:899/1670 train_time:83466ms step_avg:92.84ms +step:900/1670 train_time:83560ms step_avg:92.84ms +step:901/1670 train_time:83652ms step_avg:92.84ms +step:902/1670 train_time:83744ms step_avg:92.84ms +step:903/1670 train_time:83836ms step_avg:92.84ms +step:904/1670 train_time:83928ms step_avg:92.84ms +step:905/1670 train_time:84020ms step_avg:92.84ms +step:906/1670 train_time:84111ms step_avg:92.84ms +step:907/1670 train_time:84205ms step_avg:92.84ms +step:908/1670 train_time:84298ms step_avg:92.84ms +step:909/1670 train_time:84392ms step_avg:92.84ms +step:910/1670 train_time:84485ms step_avg:92.84ms +step:911/1670 train_time:84579ms step_avg:92.84ms +step:912/1670 train_time:84671ms step_avg:92.84ms +step:913/1670 train_time:84764ms step_avg:92.84ms +step:914/1670 train_time:84856ms step_avg:92.84ms +step:915/1670 train_time:84948ms step_avg:92.84ms +step:916/1670 train_time:85040ms step_avg:92.84ms +step:917/1670 train_time:85132ms step_avg:92.84ms +step:918/1670 train_time:85227ms step_avg:92.84ms +step:919/1670 train_time:85318ms step_avg:92.84ms +step:920/1670 train_time:85411ms step_avg:92.84ms +step:921/1670 train_time:85504ms step_avg:92.84ms 
+step:922/1670 train_time:85596ms step_avg:92.84ms +step:923/1670 train_time:85689ms step_avg:92.84ms +step:924/1670 train_time:85781ms step_avg:92.84ms +step:925/1670 train_time:85874ms step_avg:92.84ms +step:926/1670 train_time:85966ms step_avg:92.84ms +step:927/1670 train_time:86059ms step_avg:92.84ms +step:928/1670 train_time:86151ms step_avg:92.84ms +step:929/1670 train_time:86244ms step_avg:92.84ms +step:930/1670 train_time:86336ms step_avg:92.83ms +step:931/1670 train_time:86429ms step_avg:92.83ms +step:932/1670 train_time:86522ms step_avg:92.83ms +step:933/1670 train_time:86614ms step_avg:92.83ms +step:934/1670 train_time:86707ms step_avg:92.83ms +step:935/1670 train_time:86800ms step_avg:92.83ms +step:936/1670 train_time:86892ms step_avg:92.83ms +step:937/1670 train_time:86985ms step_avg:92.83ms +step:938/1670 train_time:87078ms step_avg:92.83ms +step:939/1670 train_time:87170ms step_avg:92.83ms +step:940/1670 train_time:87263ms step_avg:92.83ms +step:941/1670 train_time:87355ms step_avg:92.83ms +step:942/1670 train_time:87448ms step_avg:92.83ms +step:943/1670 train_time:87541ms step_avg:92.83ms +step:944/1670 train_time:87633ms step_avg:92.83ms +step:945/1670 train_time:87728ms step_avg:92.83ms +step:946/1670 train_time:87821ms step_avg:92.83ms +step:947/1670 train_time:87912ms step_avg:92.83ms +step:948/1670 train_time:88006ms step_avg:92.83ms +step:949/1670 train_time:88099ms step_avg:92.83ms +step:950/1670 train_time:88190ms step_avg:92.83ms +step:951/1670 train_time:88284ms step_avg:92.83ms +step:952/1670 train_time:88377ms step_avg:92.83ms +step:953/1670 train_time:88470ms step_avg:92.83ms +step:954/1670 train_time:88563ms step_avg:92.83ms +step:955/1670 train_time:88655ms step_avg:92.83ms +step:956/1670 train_time:88748ms step_avg:92.83ms +step:957/1670 train_time:88841ms step_avg:92.83ms +step:958/1670 train_time:88933ms step_avg:92.83ms +step:959/1670 train_time:89025ms step_avg:92.83ms +step:960/1670 train_time:89117ms step_avg:92.83ms +step:961/1670 train_time:89209ms step_avg:92.83ms +step:962/1670 train_time:89302ms step_avg:92.83ms +step:963/1670 train_time:89394ms step_avg:92.83ms +step:964/1670 train_time:89487ms step_avg:92.83ms +step:965/1670 train_time:89581ms step_avg:92.83ms +step:966/1670 train_time:89673ms step_avg:92.83ms +step:967/1670 train_time:89766ms step_avg:92.83ms +step:968/1670 train_time:89859ms step_avg:92.83ms +step:969/1670 train_time:89950ms step_avg:92.83ms +step:970/1670 train_time:90043ms step_avg:92.83ms +step:971/1670 train_time:90135ms step_avg:92.83ms +step:972/1670 train_time:90227ms step_avg:92.83ms +step:973/1670 train_time:90318ms step_avg:92.82ms +step:974/1670 train_time:90410ms step_avg:92.82ms +step:975/1670 train_time:90504ms step_avg:92.82ms +step:976/1670 train_time:90597ms step_avg:92.83ms +step:977/1670 train_time:90690ms step_avg:92.83ms +step:978/1670 train_time:90784ms step_avg:92.83ms +step:979/1670 train_time:90876ms step_avg:92.83ms +step:980/1670 train_time:90969ms step_avg:92.83ms +step:981/1670 train_time:91061ms step_avg:92.82ms +step:982/1670 train_time:91153ms step_avg:92.82ms +step:983/1670 train_time:91246ms step_avg:92.82ms +step:984/1670 train_time:91339ms step_avg:92.82ms +step:985/1670 train_time:91431ms step_avg:92.82ms +step:986/1670 train_time:91524ms step_avg:92.82ms +step:987/1670 train_time:91617ms step_avg:92.82ms +step:988/1670 train_time:91709ms step_avg:92.82ms +step:989/1670 train_time:91800ms step_avg:92.82ms +step:990/1670 train_time:91893ms step_avg:92.82ms +step:991/1670 train_time:91987ms 
step_avg:92.82ms +step:992/1670 train_time:92079ms step_avg:92.82ms +step:993/1670 train_time:92171ms step_avg:92.82ms +step:994/1670 train_time:92264ms step_avg:92.82ms +step:995/1670 train_time:92356ms step_avg:92.82ms +step:996/1670 train_time:92448ms step_avg:92.82ms +step:997/1670 train_time:92540ms step_avg:92.82ms +step:998/1670 train_time:92632ms step_avg:92.82ms +step:999/1670 train_time:92725ms step_avg:92.82ms +step:1000/1670 train_time:92817ms step_avg:92.82ms +step:1000/1670 val_loss:3.4664 train_time:92911ms step_avg:92.91ms +step:1001/1670 train_time:92928ms step_avg:92.84ms +step:1002/1670 train_time:93005ms step_avg:92.82ms +step:1003/1670 train_time:93098ms step_avg:92.82ms +step:1004/1670 train_time:93189ms step_avg:92.82ms +step:1005/1670 train_time:93281ms step_avg:92.82ms +step:1006/1670 train_time:93373ms step_avg:92.82ms +step:1007/1670 train_time:93464ms step_avg:92.81ms +step:1008/1670 train_time:93557ms step_avg:92.81ms +step:1009/1670 train_time:93650ms step_avg:92.81ms +step:1010/1670 train_time:93742ms step_avg:92.81ms +step:1011/1670 train_time:93835ms step_avg:92.81ms +step:1012/1670 train_time:93929ms step_avg:92.82ms +step:1013/1670 train_time:94023ms step_avg:92.82ms +step:1014/1670 train_time:94116ms step_avg:92.82ms +step:1015/1670 train_time:94208ms step_avg:92.82ms +step:1016/1670 train_time:94301ms step_avg:92.82ms +step:1017/1670 train_time:94392ms step_avg:92.81ms +step:1018/1670 train_time:94484ms step_avg:92.81ms +step:1019/1670 train_time:94576ms step_avg:92.81ms +step:1020/1670 train_time:94668ms step_avg:92.81ms +step:1021/1670 train_time:94761ms step_avg:92.81ms +step:1022/1670 train_time:94856ms step_avg:92.81ms +step:1023/1670 train_time:94949ms step_avg:92.81ms +step:1024/1670 train_time:95042ms step_avg:92.81ms +step:1025/1670 train_time:95135ms step_avg:92.81ms +step:1026/1670 train_time:95227ms step_avg:92.81ms +step:1027/1670 train_time:95320ms step_avg:92.81ms +step:1028/1670 train_time:95412ms step_avg:92.81ms +step:1029/1670 train_time:95504ms step_avg:92.81ms +step:1030/1670 train_time:95598ms step_avg:92.81ms +step:1031/1670 train_time:95690ms step_avg:92.81ms +step:1032/1670 train_time:95783ms step_avg:92.81ms +step:1033/1670 train_time:95875ms step_avg:92.81ms +step:1034/1670 train_time:95969ms step_avg:92.81ms +step:1035/1670 train_time:96062ms step_avg:92.81ms +step:1036/1670 train_time:96154ms step_avg:92.81ms +step:1037/1670 train_time:96246ms step_avg:92.81ms +step:1038/1670 train_time:96338ms step_avg:92.81ms +step:1039/1670 train_time:96431ms step_avg:92.81ms +step:1040/1670 train_time:96523ms step_avg:92.81ms +step:1041/1670 train_time:96615ms step_avg:92.81ms +step:1042/1670 train_time:96707ms step_avg:92.81ms +step:1043/1670 train_time:96800ms step_avg:92.81ms +step:1044/1670 train_time:96892ms step_avg:92.81ms +step:1045/1670 train_time:96985ms step_avg:92.81ms +step:1046/1670 train_time:97079ms step_avg:92.81ms +step:1047/1670 train_time:97170ms step_avg:92.81ms +step:1048/1670 train_time:97263ms step_avg:92.81ms +step:1049/1670 train_time:97356ms step_avg:92.81ms +step:1050/1670 train_time:97449ms step_avg:92.81ms +step:1051/1670 train_time:97542ms step_avg:92.81ms +step:1052/1670 train_time:97634ms step_avg:92.81ms +step:1053/1670 train_time:97726ms step_avg:92.81ms +step:1054/1670 train_time:97820ms step_avg:92.81ms +step:1055/1670 train_time:97913ms step_avg:92.81ms +step:1056/1670 train_time:98006ms step_avg:92.81ms +step:1057/1670 train_time:98098ms step_avg:92.81ms +step:1058/1670 train_time:98190ms 
step_avg:92.81ms +step:1059/1670 train_time:98283ms step_avg:92.81ms +step:1060/1670 train_time:98375ms step_avg:92.81ms +step:1061/1670 train_time:98467ms step_avg:92.81ms +step:1062/1670 train_time:98716ms step_avg:92.95ms +step:1063/1670 train_time:98788ms step_avg:92.93ms +step:1064/1670 train_time:98878ms step_avg:92.93ms +step:1065/1670 train_time:98969ms step_avg:92.93ms +step:1066/1670 train_time:99060ms step_avg:92.93ms +step:1067/1670 train_time:99151ms step_avg:92.93ms +step:1068/1670 train_time:99242ms step_avg:92.92ms +step:1069/1670 train_time:99333ms step_avg:92.92ms +step:1070/1670 train_time:99425ms step_avg:92.92ms +step:1071/1670 train_time:99516ms step_avg:92.92ms +step:1072/1670 train_time:99612ms step_avg:92.92ms +step:1073/1670 train_time:99709ms step_avg:92.93ms +step:1074/1670 train_time:99803ms step_avg:92.93ms +step:1075/1670 train_time:99895ms step_avg:92.93ms +step:1076/1670 train_time:99987ms step_avg:92.92ms +step:1077/1670 train_time:100078ms step_avg:92.92ms +step:1078/1670 train_time:100169ms step_avg:92.92ms +step:1079/1670 train_time:100260ms step_avg:92.92ms +step:1080/1670 train_time:100354ms step_avg:92.92ms +step:1081/1670 train_time:100445ms step_avg:92.92ms +step:1082/1670 train_time:100540ms step_avg:92.92ms +step:1083/1670 train_time:100636ms step_avg:92.92ms +step:1084/1670 train_time:100730ms step_avg:92.92ms +step:1085/1670 train_time:100824ms step_avg:92.93ms +step:1086/1670 train_time:100917ms step_avg:92.93ms +step:1087/1670 train_time:101009ms step_avg:92.92ms +step:1088/1670 train_time:101100ms step_avg:92.92ms +step:1089/1670 train_time:101192ms step_avg:92.92ms +step:1090/1670 train_time:101283ms step_avg:92.92ms +step:1091/1670 train_time:101375ms step_avg:92.92ms +step:1092/1670 train_time:101467ms step_avg:92.92ms +step:1093/1670 train_time:101561ms step_avg:92.92ms +step:1094/1670 train_time:101655ms step_avg:92.92ms +step:1095/1670 train_time:101747ms step_avg:92.92ms +step:1096/1670 train_time:101843ms step_avg:92.92ms +step:1097/1670 train_time:101936ms step_avg:92.92ms +step:1098/1670 train_time:102028ms step_avg:92.92ms +step:1099/1670 train_time:102121ms step_avg:92.92ms +step:1100/1670 train_time:102212ms step_avg:92.92ms +step:1101/1670 train_time:102303ms step_avg:92.92ms +step:1102/1670 train_time:102396ms step_avg:92.92ms +step:1103/1670 train_time:102488ms step_avg:92.92ms +step:1104/1670 train_time:102583ms step_avg:92.92ms +step:1105/1670 train_time:102676ms step_avg:92.92ms +step:1106/1670 train_time:102768ms step_avg:92.92ms +step:1107/1670 train_time:102863ms step_avg:92.92ms +step:1108/1670 train_time:102957ms step_avg:92.92ms +step:1109/1670 train_time:103049ms step_avg:92.92ms +step:1110/1670 train_time:103141ms step_avg:92.92ms +step:1111/1670 train_time:103232ms step_avg:92.92ms +step:1112/1670 train_time:103324ms step_avg:92.92ms +step:1113/1670 train_time:103416ms step_avg:92.92ms +step:1114/1670 train_time:103509ms step_avg:92.92ms +step:1115/1670 train_time:103793ms step_avg:93.09ms +step:1116/1670 train_time:103870ms step_avg:93.07ms +step:1117/1670 train_time:103962ms step_avg:93.07ms +step:1118/1670 train_time:104054ms step_avg:93.07ms +step:1119/1670 train_time:104146ms step_avg:93.07ms +step:1120/1670 train_time:104237ms step_avg:93.07ms +step:1121/1670 train_time:104329ms step_avg:93.07ms +step:1122/1670 train_time:104421ms step_avg:93.07ms +step:1123/1670 train_time:104513ms step_avg:93.07ms +step:1124/1670 train_time:104605ms step_avg:93.07ms +step:1125/1670 train_time:104705ms step_avg:93.07ms 
+step:1125/1670 val_loss:3.4137 train_time:104806ms step_avg:93.16ms +step:1126/1670 train_time:104823ms step_avg:93.09ms +step:1127/1670 train_time:104906ms step_avg:93.08ms +step:1128/1670 train_time:105006ms step_avg:93.09ms +step:1129/1670 train_time:105099ms step_avg:93.09ms +step:1130/1670 train_time:105192ms step_avg:93.09ms +step:1131/1670 train_time:105285ms step_avg:93.09ms +step:1132/1670 train_time:105377ms step_avg:93.09ms +step:1133/1670 train_time:105469ms step_avg:93.09ms +step:1134/1670 train_time:105561ms step_avg:93.09ms +step:1135/1670 train_time:105654ms step_avg:93.09ms +step:1136/1670 train_time:105747ms step_avg:93.09ms +step:1137/1670 train_time:105842ms step_avg:93.09ms +step:1138/1670 train_time:105937ms step_avg:93.09ms +step:1139/1670 train_time:106033ms step_avg:93.09ms +step:1140/1670 train_time:106127ms step_avg:93.09ms +step:1141/1670 train_time:106220ms step_avg:93.09ms +step:1142/1670 train_time:106313ms step_avg:93.09ms +step:1143/1670 train_time:106406ms step_avg:93.09ms +step:1144/1670 train_time:106497ms step_avg:93.09ms +step:1145/1670 train_time:106589ms step_avg:93.09ms +step:1146/1670 train_time:106683ms step_avg:93.09ms +step:1147/1670 train_time:106777ms step_avg:93.09ms +step:1148/1670 train_time:106871ms step_avg:93.09ms +step:1149/1670 train_time:106966ms step_avg:93.09ms +step:1150/1670 train_time:107061ms step_avg:93.10ms +step:1151/1670 train_time:107154ms step_avg:93.10ms +step:1152/1670 train_time:107248ms step_avg:93.10ms +step:1153/1670 train_time:107340ms step_avg:93.10ms +step:1154/1670 train_time:107432ms step_avg:93.10ms +step:1155/1670 train_time:107525ms step_avg:93.10ms +step:1156/1670 train_time:107618ms step_avg:93.10ms +step:1157/1670 train_time:107711ms step_avg:93.09ms +step:1158/1670 train_time:107804ms step_avg:93.10ms +step:1159/1670 train_time:107898ms step_avg:93.10ms +step:1160/1670 train_time:107993ms step_avg:93.10ms +step:1161/1670 train_time:108087ms step_avg:93.10ms +step:1162/1670 train_time:108181ms step_avg:93.10ms +step:1163/1670 train_time:108274ms step_avg:93.10ms +step:1164/1670 train_time:108367ms step_avg:93.10ms +step:1165/1670 train_time:108460ms step_avg:93.10ms +step:1166/1670 train_time:108553ms step_avg:93.10ms +step:1167/1670 train_time:108645ms step_avg:93.10ms +step:1168/1670 train_time:108737ms step_avg:93.10ms +step:1169/1670 train_time:108832ms step_avg:93.10ms +step:1170/1670 train_time:108926ms step_avg:93.10ms +step:1171/1670 train_time:109019ms step_avg:93.10ms +step:1172/1670 train_time:109113ms step_avg:93.10ms +step:1173/1670 train_time:109207ms step_avg:93.10ms +step:1174/1670 train_time:109300ms step_avg:93.10ms +step:1175/1670 train_time:109393ms step_avg:93.10ms +step:1176/1670 train_time:109488ms step_avg:93.10ms +step:1177/1670 train_time:109581ms step_avg:93.10ms +step:1178/1670 train_time:109673ms step_avg:93.10ms +step:1179/1670 train_time:109766ms step_avg:93.10ms +step:1180/1670 train_time:109859ms step_avg:93.10ms +step:1181/1670 train_time:109953ms step_avg:93.10ms +step:1182/1670 train_time:110046ms step_avg:93.10ms +step:1183/1670 train_time:110139ms step_avg:93.10ms +step:1184/1670 train_time:110232ms step_avg:93.10ms +step:1185/1670 train_time:110326ms step_avg:93.10ms +step:1186/1670 train_time:110419ms step_avg:93.10ms +step:1187/1670 train_time:110512ms step_avg:93.10ms +step:1188/1670 train_time:110604ms step_avg:93.10ms +step:1189/1670 train_time:110698ms step_avg:93.10ms +step:1190/1670 train_time:110791ms step_avg:93.10ms +step:1191/1670 train_time:110885ms 
step_avg:93.10ms +step:1192/1670 train_time:110978ms step_avg:93.10ms +step:1193/1670 train_time:111071ms step_avg:93.10ms +step:1194/1670 train_time:111164ms step_avg:93.10ms +step:1195/1670 train_time:111257ms step_avg:93.10ms +step:1196/1670 train_time:111350ms step_avg:93.10ms +step:1197/1670 train_time:111443ms step_avg:93.10ms +step:1198/1670 train_time:111536ms step_avg:93.10ms +step:1199/1670 train_time:111630ms step_avg:93.10ms +step:1200/1670 train_time:111723ms step_avg:93.10ms +step:1201/1670 train_time:111815ms step_avg:93.10ms +step:1202/1670 train_time:111911ms step_avg:93.10ms +step:1203/1670 train_time:112004ms step_avg:93.10ms +step:1204/1670 train_time:112097ms step_avg:93.10ms +step:1205/1670 train_time:112190ms step_avg:93.10ms +step:1206/1670 train_time:112284ms step_avg:93.10ms +step:1207/1670 train_time:112377ms step_avg:93.10ms +step:1208/1670 train_time:112471ms step_avg:93.11ms +step:1209/1670 train_time:112564ms step_avg:93.11ms +step:1210/1670 train_time:112657ms step_avg:93.10ms +step:1211/1670 train_time:112751ms step_avg:93.11ms +step:1212/1670 train_time:112844ms step_avg:93.11ms +step:1213/1670 train_time:112937ms step_avg:93.11ms +step:1214/1670 train_time:113030ms step_avg:93.11ms +step:1215/1670 train_time:113124ms step_avg:93.11ms +step:1216/1670 train_time:113216ms step_avg:93.11ms +step:1217/1670 train_time:113310ms step_avg:93.11ms +step:1218/1670 train_time:113403ms step_avg:93.11ms +step:1219/1670 train_time:113497ms step_avg:93.11ms +step:1220/1670 train_time:113590ms step_avg:93.11ms +step:1221/1670 train_time:113683ms step_avg:93.11ms +step:1222/1670 train_time:113776ms step_avg:93.11ms +step:1223/1670 train_time:113869ms step_avg:93.11ms +step:1224/1670 train_time:113962ms step_avg:93.11ms +step:1225/1670 train_time:114055ms step_avg:93.11ms +step:1226/1670 train_time:114149ms step_avg:93.11ms +step:1227/1670 train_time:114242ms step_avg:93.11ms +step:1228/1670 train_time:114334ms step_avg:93.11ms +step:1229/1670 train_time:114428ms step_avg:93.11ms +step:1230/1670 train_time:114521ms step_avg:93.11ms +step:1231/1670 train_time:114614ms step_avg:93.11ms +step:1232/1670 train_time:114706ms step_avg:93.11ms +step:1233/1670 train_time:114799ms step_avg:93.11ms +step:1234/1670 train_time:114892ms step_avg:93.11ms +step:1235/1670 train_time:114987ms step_avg:93.11ms +step:1236/1670 train_time:115080ms step_avg:93.11ms +step:1237/1670 train_time:115173ms step_avg:93.11ms +step:1238/1670 train_time:115267ms step_avg:93.11ms +step:1239/1670 train_time:115359ms step_avg:93.11ms +step:1240/1670 train_time:115452ms step_avg:93.11ms +step:1241/1670 train_time:115546ms step_avg:93.11ms +step:1242/1670 train_time:115639ms step_avg:93.11ms +step:1243/1670 train_time:115732ms step_avg:93.11ms +step:1244/1670 train_time:115827ms step_avg:93.11ms +step:1245/1670 train_time:115920ms step_avg:93.11ms +step:1246/1670 train_time:116013ms step_avg:93.11ms +step:1247/1670 train_time:116106ms step_avg:93.11ms +step:1248/1670 train_time:116199ms step_avg:93.11ms +step:1249/1670 train_time:116292ms step_avg:93.11ms +step:1250/1670 train_time:116386ms step_avg:93.11ms +step:1250/1670 val_loss:3.3755 train_time:116479ms step_avg:93.18ms +step:1251/1670 train_time:116496ms step_avg:93.12ms +step:1252/1670 train_time:116574ms step_avg:93.11ms +step:1253/1670 train_time:116668ms step_avg:93.11ms +step:1254/1670 train_time:116761ms step_avg:93.11ms +step:1255/1670 train_time:116854ms step_avg:93.11ms +step:1256/1670 train_time:116946ms step_avg:93.11ms +step:1257/1670 
train_time:117039ms step_avg:93.11ms +step:1258/1670 train_time:117132ms step_avg:93.11ms +step:1259/1670 train_time:117224ms step_avg:93.11ms +step:1260/1670 train_time:117318ms step_avg:93.11ms +step:1261/1670 train_time:117412ms step_avg:93.11ms +step:1262/1670 train_time:117506ms step_avg:93.11ms +step:1263/1670 train_time:117601ms step_avg:93.11ms +step:1264/1670 train_time:117694ms step_avg:93.11ms +step:1265/1670 train_time:117786ms step_avg:93.11ms +step:1266/1670 train_time:117879ms step_avg:93.11ms +step:1267/1670 train_time:117973ms step_avg:93.11ms +step:1268/1670 train_time:118066ms step_avg:93.11ms +step:1269/1670 train_time:118159ms step_avg:93.11ms +step:1270/1670 train_time:118251ms step_avg:93.11ms +step:1271/1670 train_time:118344ms step_avg:93.11ms +step:1272/1670 train_time:118438ms step_avg:93.11ms +step:1273/1670 train_time:118532ms step_avg:93.11ms +step:1274/1670 train_time:118769ms step_avg:93.23ms +step:1275/1670 train_time:118853ms step_avg:93.22ms +step:1276/1670 train_time:118944ms step_avg:93.22ms +step:1277/1670 train_time:119036ms step_avg:93.22ms +step:1278/1670 train_time:119129ms step_avg:93.21ms +step:1279/1670 train_time:119220ms step_avg:93.21ms +step:1280/1670 train_time:119312ms step_avg:93.21ms +step:1281/1670 train_time:119404ms step_avg:93.21ms +step:1282/1670 train_time:119497ms step_avg:93.21ms +step:1283/1670 train_time:119588ms step_avg:93.21ms +step:1284/1670 train_time:119686ms step_avg:93.21ms +step:1285/1670 train_time:119785ms step_avg:93.22ms +step:1286/1670 train_time:119879ms step_avg:93.22ms +step:1287/1670 train_time:119972ms step_avg:93.22ms +step:1288/1670 train_time:120064ms step_avg:93.22ms +step:1289/1670 train_time:120156ms step_avg:93.22ms +step:1290/1670 train_time:120249ms step_avg:93.22ms +step:1291/1670 train_time:120341ms step_avg:93.22ms +step:1292/1670 train_time:120433ms step_avg:93.21ms +step:1293/1670 train_time:120525ms step_avg:93.21ms +step:1294/1670 train_time:120619ms step_avg:93.21ms +step:1295/1670 train_time:120715ms step_avg:93.22ms +step:1296/1670 train_time:120810ms step_avg:93.22ms +step:1297/1670 train_time:120903ms step_avg:93.22ms +step:1298/1670 train_time:120997ms step_avg:93.22ms +step:1299/1670 train_time:121090ms step_avg:93.22ms +step:1300/1670 train_time:121182ms step_avg:93.22ms +step:1301/1670 train_time:121278ms step_avg:93.22ms +step:1302/1670 train_time:121371ms step_avg:93.22ms +step:1303/1670 train_time:121462ms step_avg:93.22ms +step:1304/1670 train_time:121555ms step_avg:93.22ms +step:1305/1670 train_time:121648ms step_avg:93.22ms +step:1306/1670 train_time:121743ms step_avg:93.22ms +step:1307/1670 train_time:121838ms step_avg:93.22ms +step:1308/1670 train_time:121931ms step_avg:93.22ms +step:1309/1670 train_time:122024ms step_avg:93.22ms +step:1310/1670 train_time:122117ms step_avg:93.22ms +step:1311/1670 train_time:122211ms step_avg:93.22ms +step:1312/1670 train_time:122304ms step_avg:93.22ms +step:1313/1670 train_time:122397ms step_avg:93.22ms +step:1314/1670 train_time:122490ms step_avg:93.22ms +step:1315/1670 train_time:122583ms step_avg:93.22ms +step:1316/1670 train_time:122677ms step_avg:93.22ms +step:1317/1670 train_time:122771ms step_avg:93.22ms +step:1318/1670 train_time:122865ms step_avg:93.22ms +step:1319/1670 train_time:122958ms step_avg:93.22ms +step:1320/1670 train_time:123050ms step_avg:93.22ms +step:1321/1670 train_time:123144ms step_avg:93.22ms +step:1322/1670 train_time:123238ms step_avg:93.22ms +step:1323/1670 train_time:123331ms step_avg:93.22ms +step:1324/1670 
train_time:123423ms step_avg:93.22ms +step:1325/1670 train_time:123516ms step_avg:93.22ms +step:1326/1670 train_time:123610ms step_avg:93.22ms +step:1327/1670 train_time:123704ms step_avg:93.22ms +step:1328/1670 train_time:123798ms step_avg:93.22ms +step:1329/1670 train_time:123891ms step_avg:93.22ms +step:1330/1670 train_time:123984ms step_avg:93.22ms +step:1331/1670 train_time:124078ms step_avg:93.22ms +step:1332/1670 train_time:124171ms step_avg:93.22ms +step:1333/1670 train_time:124263ms step_avg:93.22ms +step:1334/1670 train_time:124356ms step_avg:93.22ms +step:1335/1670 train_time:124448ms step_avg:93.22ms +step:1336/1670 train_time:124542ms step_avg:93.22ms +step:1337/1670 train_time:124636ms step_avg:93.22ms +step:1338/1670 train_time:124730ms step_avg:93.22ms +step:1339/1670 train_time:124824ms step_avg:93.22ms +step:1340/1670 train_time:124918ms step_avg:93.22ms +step:1341/1670 train_time:125010ms step_avg:93.22ms +step:1342/1670 train_time:125105ms step_avg:93.22ms +step:1343/1670 train_time:125198ms step_avg:93.22ms +step:1344/1670 train_time:125291ms step_avg:93.22ms +step:1345/1670 train_time:125383ms step_avg:93.22ms +step:1346/1670 train_time:125477ms step_avg:93.22ms +step:1347/1670 train_time:125570ms step_avg:93.22ms +step:1348/1670 train_time:125663ms step_avg:93.22ms +step:1349/1670 train_time:125756ms step_avg:93.22ms +step:1350/1670 train_time:125849ms step_avg:93.22ms +step:1351/1670 train_time:125942ms step_avg:93.22ms +step:1352/1670 train_time:126037ms step_avg:93.22ms +step:1353/1670 train_time:126131ms step_avg:93.22ms +step:1354/1670 train_time:126224ms step_avg:93.22ms +step:1355/1670 train_time:126317ms step_avg:93.22ms +step:1356/1670 train_time:126410ms step_avg:93.22ms +step:1357/1670 train_time:126503ms step_avg:93.22ms +step:1358/1670 train_time:126597ms step_avg:93.22ms +step:1359/1670 train_time:126690ms step_avg:93.22ms +step:1360/1670 train_time:126783ms step_avg:93.22ms +step:1361/1670 train_time:126878ms step_avg:93.22ms +step:1362/1670 train_time:126971ms step_avg:93.22ms +step:1363/1670 train_time:127064ms step_avg:93.22ms +step:1364/1670 train_time:127157ms step_avg:93.22ms +step:1365/1670 train_time:127250ms step_avg:93.22ms +step:1366/1670 train_time:127343ms step_avg:93.22ms +step:1367/1670 train_time:127437ms step_avg:93.22ms +step:1368/1670 train_time:127530ms step_avg:93.22ms +step:1369/1670 train_time:127623ms step_avg:93.22ms +step:1370/1670 train_time:127715ms step_avg:93.22ms +step:1371/1670 train_time:127808ms step_avg:93.22ms +step:1372/1670 train_time:127903ms step_avg:93.22ms +step:1373/1670 train_time:127997ms step_avg:93.22ms +step:1374/1670 train_time:128090ms step_avg:93.22ms +step:1375/1670 train_time:128184ms step_avg:93.22ms +step:1375/1670 val_loss:3.3405 train_time:128278ms step_avg:93.29ms +step:1376/1670 train_time:128295ms step_avg:93.24ms +step:1377/1670 train_time:128373ms step_avg:93.23ms +step:1378/1670 train_time:128466ms step_avg:93.23ms +step:1379/1670 train_time:128559ms step_avg:93.23ms +step:1380/1670 train_time:128651ms step_avg:93.23ms +step:1381/1670 train_time:128744ms step_avg:93.23ms +step:1382/1670 train_time:128836ms step_avg:93.22ms +step:1383/1670 train_time:128930ms step_avg:93.23ms +step:1384/1670 train_time:129023ms step_avg:93.22ms +step:1385/1670 train_time:129116ms step_avg:93.22ms +step:1386/1670 train_time:129211ms step_avg:93.23ms +step:1387/1670 train_time:129305ms step_avg:93.23ms +step:1388/1670 train_time:129400ms step_avg:93.23ms +step:1389/1670 train_time:129494ms step_avg:93.23ms 
+step:1390/1670 train_time:129587ms step_avg:93.23ms +step:1391/1670 train_time:129680ms step_avg:93.23ms +step:1392/1670 train_time:129772ms step_avg:93.23ms +step:1393/1670 train_time:129865ms step_avg:93.23ms +step:1394/1670 train_time:129958ms step_avg:93.23ms +step:1395/1670 train_time:130051ms step_avg:93.23ms +step:1396/1670 train_time:130145ms step_avg:93.23ms +step:1397/1670 train_time:130238ms step_avg:93.23ms +step:1398/1670 train_time:130334ms step_avg:93.23ms +step:1399/1670 train_time:130429ms step_avg:93.23ms +step:1400/1670 train_time:130522ms step_avg:93.23ms +step:1401/1670 train_time:130615ms step_avg:93.23ms +step:1402/1670 train_time:130709ms step_avg:93.23ms +step:1403/1670 train_time:130801ms step_avg:93.23ms +step:1404/1670 train_time:130894ms step_avg:93.23ms +step:1405/1670 train_time:130987ms step_avg:93.23ms +step:1406/1670 train_time:131079ms step_avg:93.23ms +step:1407/1670 train_time:131172ms step_avg:93.23ms +step:1408/1670 train_time:131266ms step_avg:93.23ms +step:1409/1670 train_time:131360ms step_avg:93.23ms +step:1410/1670 train_time:131453ms step_avg:93.23ms +step:1411/1670 train_time:131547ms step_avg:93.23ms +step:1412/1670 train_time:131641ms step_avg:93.23ms +step:1413/1670 train_time:131734ms step_avg:93.23ms +step:1414/1670 train_time:131828ms step_avg:93.23ms +step:1415/1670 train_time:131922ms step_avg:93.23ms +step:1416/1670 train_time:132014ms step_avg:93.23ms +step:1417/1670 train_time:132108ms step_avg:93.23ms +step:1418/1670 train_time:132202ms step_avg:93.23ms +step:1419/1670 train_time:132295ms step_avg:93.23ms +step:1420/1670 train_time:132388ms step_avg:93.23ms +step:1421/1670 train_time:132480ms step_avg:93.23ms +step:1422/1670 train_time:132574ms step_avg:93.23ms +step:1423/1670 train_time:132668ms step_avg:93.23ms +step:1424/1670 train_time:132762ms step_avg:93.23ms +step:1425/1670 train_time:132855ms step_avg:93.23ms +step:1426/1670 train_time:132948ms step_avg:93.23ms +step:1427/1670 train_time:133041ms step_avg:93.23ms +step:1428/1670 train_time:133134ms step_avg:93.23ms +step:1429/1670 train_time:133229ms step_avg:93.23ms +step:1430/1670 train_time:133323ms step_avg:93.23ms +step:1431/1670 train_time:133416ms step_avg:93.23ms +step:1432/1670 train_time:133509ms step_avg:93.23ms +step:1433/1670 train_time:133602ms step_avg:93.23ms +step:1434/1670 train_time:133695ms step_avg:93.23ms +step:1435/1670 train_time:133789ms step_avg:93.23ms +step:1436/1670 train_time:133882ms step_avg:93.23ms +step:1437/1670 train_time:133974ms step_avg:93.23ms +step:1438/1670 train_time:134068ms step_avg:93.23ms +step:1439/1670 train_time:134161ms step_avg:93.23ms +step:1440/1670 train_time:134256ms step_avg:93.23ms +step:1441/1670 train_time:134350ms step_avg:93.23ms +step:1442/1670 train_time:134444ms step_avg:93.23ms +step:1443/1670 train_time:134537ms step_avg:93.23ms +step:1444/1670 train_time:134630ms step_avg:93.23ms +step:1445/1670 train_time:134724ms step_avg:93.23ms +step:1446/1670 train_time:134817ms step_avg:93.23ms +step:1447/1670 train_time:134911ms step_avg:93.23ms +step:1448/1670 train_time:135004ms step_avg:93.23ms +step:1449/1670 train_time:135097ms step_avg:93.23ms +step:1450/1670 train_time:135191ms step_avg:93.24ms +step:1451/1670 train_time:135286ms step_avg:93.24ms +step:1452/1670 train_time:135379ms step_avg:93.24ms +step:1453/1670 train_time:135472ms step_avg:93.24ms +step:1454/1670 train_time:135565ms step_avg:93.24ms +step:1455/1670 train_time:135659ms step_avg:93.24ms +step:1456/1670 train_time:135752ms step_avg:93.24ms 
+step:1457/1670 train_time:135844ms step_avg:93.24ms +step:1458/1670 train_time:135937ms step_avg:93.24ms +step:1459/1670 train_time:136032ms step_avg:93.24ms +step:1460/1670 train_time:136126ms step_avg:93.24ms +step:1461/1670 train_time:136220ms step_avg:93.24ms +step:1462/1670 train_time:136313ms step_avg:93.24ms +step:1463/1670 train_time:136406ms step_avg:93.24ms +step:1464/1670 train_time:136499ms step_avg:93.24ms +step:1465/1670 train_time:136593ms step_avg:93.24ms +step:1466/1670 train_time:136688ms step_avg:93.24ms +step:1467/1670 train_time:136780ms step_avg:93.24ms +step:1468/1670 train_time:136873ms step_avg:93.24ms +step:1469/1670 train_time:136966ms step_avg:93.24ms +step:1470/1670 train_time:137060ms step_avg:93.24ms +step:1471/1670 train_time:137153ms step_avg:93.24ms +step:1472/1670 train_time:137248ms step_avg:93.24ms +step:1473/1670 train_time:137341ms step_avg:93.24ms +step:1474/1670 train_time:137434ms step_avg:93.24ms +step:1475/1670 train_time:137528ms step_avg:93.24ms +step:1476/1670 train_time:137621ms step_avg:93.24ms +step:1477/1670 train_time:137715ms step_avg:93.24ms +step:1478/1670 train_time:137808ms step_avg:93.24ms +step:1479/1670 train_time:137900ms step_avg:93.24ms +step:1480/1670 train_time:137993ms step_avg:93.24ms +step:1481/1670 train_time:138086ms step_avg:93.24ms +step:1482/1670 train_time:138180ms step_avg:93.24ms +step:1483/1670 train_time:138273ms step_avg:93.24ms +step:1484/1670 train_time:138366ms step_avg:93.24ms +step:1485/1670 train_time:138616ms step_avg:93.34ms +step:1486/1670 train_time:138689ms step_avg:93.33ms +step:1487/1670 train_time:138782ms step_avg:93.33ms +step:1488/1670 train_time:138873ms step_avg:93.33ms +step:1489/1670 train_time:138965ms step_avg:93.33ms +step:1490/1670 train_time:139057ms step_avg:93.33ms +step:1491/1670 train_time:139149ms step_avg:93.33ms +step:1492/1670 train_time:139241ms step_avg:93.32ms +step:1493/1670 train_time:139332ms step_avg:93.32ms +step:1494/1670 train_time:139425ms step_avg:93.32ms +step:1495/1670 train_time:139523ms step_avg:93.33ms +step:1496/1670 train_time:139620ms step_avg:93.33ms +step:1497/1670 train_time:139714ms step_avg:93.33ms +step:1498/1670 train_time:139807ms step_avg:93.33ms +step:1499/1670 train_time:139900ms step_avg:93.33ms +step:1500/1670 train_time:139992ms step_avg:93.33ms +step:1500/1670 val_loss:3.3106 train_time:140086ms step_avg:93.39ms +step:1501/1670 train_time:140103ms step_avg:93.34ms +step:1502/1670 train_time:140180ms step_avg:93.33ms +step:1503/1670 train_time:140272ms step_avg:93.33ms +step:1504/1670 train_time:140364ms step_avg:93.33ms +step:1505/1670 train_time:140457ms step_avg:93.33ms +step:1506/1670 train_time:140549ms step_avg:93.33ms +step:1507/1670 train_time:140642ms step_avg:93.33ms +step:1508/1670 train_time:140738ms step_avg:93.33ms +step:1509/1670 train_time:140831ms step_avg:93.33ms +step:1510/1670 train_time:140924ms step_avg:93.33ms +step:1511/1670 train_time:141018ms step_avg:93.33ms +step:1512/1670 train_time:141112ms step_avg:93.33ms +step:1513/1670 train_time:141206ms step_avg:93.33ms +step:1514/1670 train_time:141299ms step_avg:93.33ms +step:1515/1670 train_time:141391ms step_avg:93.33ms +step:1516/1670 train_time:141483ms step_avg:93.33ms +step:1517/1670 train_time:141577ms step_avg:93.33ms +step:1518/1670 train_time:141670ms step_avg:93.33ms +step:1519/1670 train_time:141764ms step_avg:93.33ms +step:1520/1670 train_time:141858ms step_avg:93.33ms +step:1521/1670 train_time:141951ms step_avg:93.33ms +step:1522/1670 train_time:142045ms 
step_avg:93.33ms +step:1523/1670 train_time:142140ms step_avg:93.33ms +step:1524/1670 train_time:142233ms step_avg:93.33ms +step:1525/1670 train_time:142326ms step_avg:93.33ms +step:1526/1670 train_time:142420ms step_avg:93.33ms +step:1527/1670 train_time:142512ms step_avg:93.33ms +step:1528/1670 train_time:142605ms step_avg:93.33ms +step:1529/1670 train_time:142699ms step_avg:93.33ms +step:1530/1670 train_time:142792ms step_avg:93.33ms +step:1531/1670 train_time:142885ms step_avg:93.33ms +step:1532/1670 train_time:142980ms step_avg:93.33ms +step:1533/1670 train_time:143074ms step_avg:93.33ms +step:1534/1670 train_time:143167ms step_avg:93.33ms +step:1535/1670 train_time:143260ms step_avg:93.33ms +step:1536/1670 train_time:143353ms step_avg:93.33ms +step:1537/1670 train_time:143447ms step_avg:93.33ms +step:1538/1670 train_time:143541ms step_avg:93.33ms +step:1539/1670 train_time:143634ms step_avg:93.33ms +step:1540/1670 train_time:143727ms step_avg:93.33ms +step:1541/1670 train_time:143821ms step_avg:93.33ms +step:1542/1670 train_time:143916ms step_avg:93.33ms +step:1543/1670 train_time:144010ms step_avg:93.33ms +step:1544/1670 train_time:144103ms step_avg:93.33ms +step:1545/1670 train_time:144196ms step_avg:93.33ms +step:1546/1670 train_time:144289ms step_avg:93.33ms +step:1547/1670 train_time:144382ms step_avg:93.33ms +step:1548/1670 train_time:144476ms step_avg:93.33ms +step:1549/1670 train_time:144569ms step_avg:93.33ms +step:1550/1670 train_time:144662ms step_avg:93.33ms +step:1551/1670 train_time:144755ms step_avg:93.33ms +step:1552/1670 train_time:144849ms step_avg:93.33ms +step:1553/1670 train_time:144942ms step_avg:93.33ms +step:1554/1670 train_time:145035ms step_avg:93.33ms +step:1555/1670 train_time:145128ms step_avg:93.33ms +step:1556/1670 train_time:145221ms step_avg:93.33ms +step:1557/1670 train_time:145313ms step_avg:93.33ms +step:1558/1670 train_time:145406ms step_avg:93.33ms +step:1559/1670 train_time:145499ms step_avg:93.33ms +step:1560/1670 train_time:145592ms step_avg:93.33ms +step:1561/1670 train_time:145685ms step_avg:93.33ms +step:1562/1670 train_time:145780ms step_avg:93.33ms +step:1563/1670 train_time:145875ms step_avg:93.33ms +step:1564/1670 train_time:145966ms step_avg:93.33ms +step:1565/1670 train_time:146059ms step_avg:93.33ms +step:1566/1670 train_time:146152ms step_avg:93.33ms +step:1567/1670 train_time:146245ms step_avg:93.33ms +step:1568/1670 train_time:146338ms step_avg:93.33ms +step:1569/1670 train_time:146431ms step_avg:93.33ms +step:1570/1670 train_time:146524ms step_avg:93.33ms +step:1571/1670 train_time:146618ms step_avg:93.33ms +step:1572/1670 train_time:146711ms step_avg:93.33ms +step:1573/1670 train_time:146804ms step_avg:93.33ms +step:1574/1670 train_time:146898ms step_avg:93.33ms +step:1575/1670 train_time:146992ms step_avg:93.33ms +step:1576/1670 train_time:147084ms step_avg:93.33ms +step:1577/1670 train_time:147180ms step_avg:93.33ms +step:1578/1670 train_time:147273ms step_avg:93.33ms +step:1579/1670 train_time:147365ms step_avg:93.33ms +step:1580/1670 train_time:147459ms step_avg:93.33ms +step:1581/1670 train_time:147553ms step_avg:93.33ms +step:1582/1670 train_time:147646ms step_avg:93.33ms +step:1583/1670 train_time:147740ms step_avg:93.33ms +step:1584/1670 train_time:147833ms step_avg:93.33ms +step:1585/1670 train_time:147927ms step_avg:93.33ms +step:1586/1670 train_time:148020ms step_avg:93.33ms +step:1587/1670 train_time:148113ms step_avg:93.33ms +step:1588/1670 train_time:148206ms step_avg:93.33ms +step:1589/1670 train_time:148300ms 
step_avg:93.33ms +step:1590/1670 train_time:148393ms step_avg:93.33ms +step:1591/1670 train_time:148486ms step_avg:93.33ms +step:1592/1670 train_time:148579ms step_avg:93.33ms +step:1593/1670 train_time:148673ms step_avg:93.33ms +step:1594/1670 train_time:148766ms step_avg:93.33ms +step:1595/1670 train_time:148860ms step_avg:93.33ms +step:1596/1670 train_time:148952ms step_avg:93.33ms +step:1597/1670 train_time:149046ms step_avg:93.33ms +step:1598/1670 train_time:149140ms step_avg:93.33ms +step:1599/1670 train_time:149233ms step_avg:93.33ms +step:1600/1670 train_time:149326ms step_avg:93.33ms +step:1601/1670 train_time:149419ms step_avg:93.33ms +step:1602/1670 train_time:149511ms step_avg:93.33ms +step:1603/1670 train_time:149604ms step_avg:93.33ms +step:1604/1670 train_time:149699ms step_avg:93.33ms +step:1605/1670 train_time:149792ms step_avg:93.33ms +step:1606/1670 train_time:149884ms step_avg:93.33ms +step:1607/1670 train_time:149979ms step_avg:93.33ms +step:1608/1670 train_time:150073ms step_avg:93.33ms +step:1609/1670 train_time:150166ms step_avg:93.33ms +step:1610/1670 train_time:150261ms step_avg:93.33ms +step:1611/1670 train_time:150354ms step_avg:93.33ms +step:1612/1670 train_time:150447ms step_avg:93.33ms +step:1613/1670 train_time:150540ms step_avg:93.33ms +step:1614/1670 train_time:150635ms step_avg:93.33ms +step:1615/1670 train_time:150729ms step_avg:93.33ms +step:1616/1670 train_time:150821ms step_avg:93.33ms +step:1617/1670 train_time:150914ms step_avg:93.33ms +step:1618/1670 train_time:151008ms step_avg:93.33ms +step:1619/1670 train_time:151101ms step_avg:93.33ms +step:1620/1670 train_time:151194ms step_avg:93.33ms +step:1621/1670 train_time:151287ms step_avg:93.33ms +step:1622/1670 train_time:151381ms step_avg:93.33ms +step:1623/1670 train_time:151474ms step_avg:93.33ms +step:1624/1670 train_time:151567ms step_avg:93.33ms +step:1625/1670 train_time:151662ms step_avg:93.33ms +step:1625/1670 val_loss:3.2855 train_time:151756ms step_avg:93.39ms +step:1626/1670 train_time:151773ms step_avg:93.34ms +step:1627/1670 train_time:151850ms step_avg:93.33ms +step:1628/1670 train_time:151944ms step_avg:93.33ms +step:1629/1670 train_time:152038ms step_avg:93.33ms +step:1630/1670 train_time:152130ms step_avg:93.33ms +step:1631/1670 train_time:152224ms step_avg:93.33ms +step:1632/1670 train_time:152317ms step_avg:93.33ms +step:1633/1670 train_time:152410ms step_avg:93.33ms +step:1634/1670 train_time:152503ms step_avg:93.33ms +step:1635/1670 train_time:152597ms step_avg:93.33ms +step:1636/1670 train_time:152690ms step_avg:93.33ms +step:1637/1670 train_time:152785ms step_avg:93.33ms +step:1638/1670 train_time:152879ms step_avg:93.33ms +step:1639/1670 train_time:152972ms step_avg:93.33ms +step:1640/1670 train_time:153066ms step_avg:93.33ms +step:1641/1670 train_time:153160ms step_avg:93.33ms +step:1642/1670 train_time:153252ms step_avg:93.33ms +step:1643/1670 train_time:153346ms step_avg:93.33ms +step:1644/1670 train_time:153439ms step_avg:93.33ms +step:1645/1670 train_time:153532ms step_avg:93.33ms +step:1646/1670 train_time:153627ms step_avg:93.33ms +step:1647/1670 train_time:153722ms step_avg:93.33ms +step:1648/1670 train_time:153816ms step_avg:93.34ms +step:1649/1670 train_time:153910ms step_avg:93.34ms +step:1650/1670 train_time:154004ms step_avg:93.34ms +step:1651/1670 train_time:154096ms step_avg:93.34ms +step:1652/1670 train_time:154190ms step_avg:93.34ms +step:1653/1670 train_time:154283ms step_avg:93.34ms +step:1654/1670 train_time:154375ms step_avg:93.33ms +step:1655/1670 
train_time:154469ms step_avg:93.33ms +step:1656/1670 train_time:154562ms step_avg:93.33ms +step:1657/1670 train_time:154655ms step_avg:93.33ms +step:1658/1670 train_time:154749ms step_avg:93.33ms +step:1659/1670 train_time:154842ms step_avg:93.33ms +step:1660/1670 train_time:154936ms step_avg:93.33ms +step:1661/1670 train_time:155029ms step_avg:93.33ms +step:1662/1670 train_time:155123ms step_avg:93.33ms +step:1663/1670 train_time:155216ms step_avg:93.33ms +step:1664/1670 train_time:155309ms step_avg:93.33ms +step:1665/1670 train_time:155402ms step_avg:93.33ms +step:1666/1670 train_time:155496ms step_avg:93.33ms +step:1667/1670 train_time:155589ms step_avg:93.33ms +step:1668/1670 train_time:155683ms step_avg:93.33ms +step:1669/1670 train_time:155775ms step_avg:93.33ms +step:1670/1670 train_time:155868ms step_avg:93.33ms +step:1670/1670 val_loss:3.2769 train_time:156129ms step_avg:93.49ms +peak memory allocated: 31587 MiB reserved: 47114 MiB diff --git a/records/091125_VectSigmoidBFloat16/b4bb35d4-92c1-42f4-91dd-dfbf665e66b4.txt b/records/091125_VectSigmoidBFloat16/b4bb35d4-92c1-42f4-91dd-dfbf665e66b4.txt new file mode 100644 index 000000000..ca7b53ae1 --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/b4bb35d4-92c1-42f4-91dd-dfbf665e66b4.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
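+# Optional sanity check for the Triton Newton-Schulz path: a minimal sketch
+# comparing newton_schulz_triton against a plain PyTorch version of the same
+# quintic iteration (A = X @ X.mT; B = b*A + c*A@A; X <- a*X + B@X). The
+# NS_SELF_TEST env var is hypothetical (not part of the original record);
+# leave it unset for normal runs so the record's behavior is unchanged.
+if os.environ.get("NS_SELF_TEST"):
+
+    def _ns_reference(G: Tensor):
+        a, b, c = (3.4445, -4.7750, 2.0315)
+        X = G.bfloat16()
+        if G.size(-2) > G.size(-1):
+            X = X.mT
+        # ensure spectral norm is at most 1, as in the Triton version
+        X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+        for _ in range(5):
+            A = X @ X.mT
+            B = b * A + c * A @ A
+            X = a * X + B @ X
+        if G.size(-2) > G.size(-1):
+            X = X.mT
+        return X
+
+    _G = torch.randn(4, 256, 128, device="cuda")
+    _err = (newton_schulz_triton(_G) - _ns_reference(_G)).abs().max().item()
+    print(f"newton_schulz_triton max abs diff vs reference: {_err:.4f}")
+
+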
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one param group per unique parameter shape (enables stacked/batched updates)
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for
+            # same-shaped params (all params in a group have the same shape); this
+            # helps in vectorizing operations later.
+ eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. 
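+                # Net effect: p <- (1 - lr * wd * wd_mul) * p - eff_lr * NS5(grad.lerp(buf, momentum)),
+                # where eff_lr = lr * max(1, rows / cols) ** 0.5 * lr_mul and NS5 is the
+                # 5-iteration Newton-Schulz orthogonalization defined above.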
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
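+        # e.g., next_multiple_of_n(50257, n=128) == 50304 (= 393 * 128); the 47
+        # padded token ids never occur in the data, so their logit rows go unused.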
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(
+            torch.bfloat16
+        )  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[
+            1 * len(self.blocks) : 3 * len(self.blocks)
+        ].view(-1, 2)
+        sa_lambdas = self.scalars[
+            3 * len(self.blocks) : 5 * len(self.blocks)
+        ].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(
+        str(file), False, 256, dtype=torch.int32
+    )  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(
+            num_tokens, dtype=torch.uint16, pin_memory=True
+        )  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(
+            tokens.numpy()
+        )  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, (
+            "number of tokens read does not match header"
+        )
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # report idx rather than cur: cur is still unbound on the first pass of this loop
+                    raise StopIteration(
+                        f"Insufficient BOS ahead of index {idx}; hit tail of shard."
+ ) + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min( + self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1, + ) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + + +def distributed_data_generator( + filename_pattern: str, + num_tokens: int, + max_seq_len: int, + grad_accum_steps: int = 1, + align_to_bos: bool = True, +): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, ( + "Batch size must be divisible by world size" + ) + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError( + f"No files found for pattern: {filename_pattern}" + ) + + file_iter = iter( + files + ) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n( + num_tokens_local // 300, n=128 + ) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch( + num_tokens_local, max_seq_len + ) + start_idxs, end_idxs = ( + torch.tensor(seq_starts[rank]), + torch.tensor(seq_ends[rank]), + ) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= ( + 1 # last document was too long to account for _targets offset + ) + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len( + tokens + ): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local : pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view( + num_tokens_local, + ) + _targets = buf[1:].view( + num_tokens_local, + ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to( + device="cuda", dtype=torch.int32, non_blocking=True + ), + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, ( + "Num tokens must be divisible by world size" + ) + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + + +@dataclass +class Hyperparameters: + # data + train_files: str = ( + 
"data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + ) + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = ( + 0.5 # fraction of training spent cooling down the learning rate + ) + # evaluation and logging + run_id: str = f"{uuid.uuid4()}" + val_loss_every: int = ( + 125 # every how many steps to evaluate val loss? 0 for only at the end + ) + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = rank == 0 # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + + +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + + +# begin by printing this file (the Python code) +print0(code) +print0("=" * 100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0( + f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}" +) +print0(f"Running Triton version {triton.__version__}") + + +def nvidia_smi(): + import subprocess # avoid top level import + + return subprocess.run( + ["nvidia-smi"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ).stdout + + +print0(nvidia_smi()) +print0("=" * 100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) + // (grad_accum_steps * world_size), +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [ + p + for n, p in model.blocks.named_parameters() + if p.ndim >= 2 and "embed" not in n +] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:20:37 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 127W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 133W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 46C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 38C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 46C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 45C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 41C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:292ms step_avg:291.66ms +step:2/1670 train_time:311ms step_avg:155.28ms +step:3/1670 train_time:378ms step_avg:126.07ms +step:4/1670 train_time:468ms step_avg:116.89ms +step:5/1670 train_time:558ms step_avg:111.51ms +step:6/1670 train_time:649ms step_avg:108.12ms +step:7/1670 train_time:739ms step_avg:105.56ms +step:8/1670 train_time:830ms step_avg:103.75ms +step:9/1670 train_time:920ms step_avg:102.26ms +step:10/1670 train_time:1011ms step_avg:101.11ms +step:11/1670 train_time:1101ms step_avg:100.12ms +step:12/1670 train_time:1195ms step_avg:99.57ms +step:13/1670 train_time:1289ms step_avg:99.13ms +step:14/1670 train_time:1381ms step_avg:98.67ms +step:15/1670 train_time:1474ms step_avg:98.25ms +step:16/1670 train_time:1565ms step_avg:97.84ms +step:17/1670 train_time:1656ms step_avg:97.44ms +step:18/1670 train_time:1748ms step_avg:97.09ms +step:19/1670 train_time:1839ms step_avg:96.79ms +step:20/1670 train_time:1930ms step_avg:96.50ms +step:21/1670 train_time:2020ms step_avg:96.20ms +step:22/1670 train_time:2111ms 
step_avg:95.97ms +step:23/1670 train_time:2205ms step_avg:95.86ms +step:24/1670 train_time:2299ms step_avg:95.78ms +step:25/1670 train_time:2391ms step_avg:95.65ms +step:26/1670 train_time:2483ms step_avg:95.50ms +step:27/1670 train_time:2575ms step_avg:95.37ms +step:28/1670 train_time:2668ms step_avg:95.28ms +step:29/1670 train_time:2759ms step_avg:95.15ms +step:30/1670 train_time:2851ms step_avg:95.03ms +step:31/1670 train_time:2942ms step_avg:94.90ms +step:32/1670 train_time:3034ms step_avg:94.82ms +step:33/1670 train_time:3125ms step_avg:94.70ms +step:34/1670 train_time:3218ms step_avg:94.65ms +step:35/1670 train_time:3311ms step_avg:94.60ms +step:36/1670 train_time:3403ms step_avg:94.51ms +step:37/1670 train_time:3494ms step_avg:94.44ms +step:38/1670 train_time:3586ms step_avg:94.36ms +step:39/1670 train_time:3679ms step_avg:94.34ms +step:40/1670 train_time:3772ms step_avg:94.31ms +step:41/1670 train_time:3864ms step_avg:94.26ms +step:42/1670 train_time:3956ms step_avg:94.20ms +step:43/1670 train_time:4047ms step_avg:94.12ms +step:44/1670 train_time:4139ms step_avg:94.07ms +step:45/1670 train_time:4232ms step_avg:94.05ms +step:46/1670 train_time:4325ms step_avg:94.01ms +step:47/1670 train_time:4418ms step_avg:94.00ms +step:48/1670 train_time:4511ms step_avg:93.97ms +step:49/1670 train_time:4602ms step_avg:93.92ms +step:50/1670 train_time:4694ms step_avg:93.89ms +step:51/1670 train_time:4786ms step_avg:93.84ms +step:52/1670 train_time:4878ms step_avg:93.81ms +step:53/1670 train_time:4970ms step_avg:93.78ms +step:54/1670 train_time:5062ms step_avg:93.74ms +step:55/1670 train_time:5154ms step_avg:93.71ms +step:56/1670 train_time:5245ms step_avg:93.66ms +step:57/1670 train_time:5338ms step_avg:93.65ms +step:58/1670 train_time:5431ms step_avg:93.65ms +step:59/1670 train_time:5523ms step_avg:93.61ms +step:60/1670 train_time:5615ms step_avg:93.59ms +step:61/1670 train_time:5707ms step_avg:93.55ms +step:62/1670 train_time:5799ms step_avg:93.53ms +step:63/1670 train_time:5890ms step_avg:93.49ms +step:64/1670 train_time:5981ms step_avg:93.45ms +step:65/1670 train_time:6073ms step_avg:93.44ms +step:66/1670 train_time:6164ms step_avg:93.39ms +step:67/1670 train_time:6256ms step_avg:93.37ms +step:68/1670 train_time:6348ms step_avg:93.35ms +step:69/1670 train_time:6441ms step_avg:93.34ms +step:70/1670 train_time:6534ms step_avg:93.34ms +step:71/1670 train_time:6625ms step_avg:93.31ms +step:72/1670 train_time:6718ms step_avg:93.30ms +step:73/1670 train_time:6810ms step_avg:93.29ms +step:74/1670 train_time:6901ms step_avg:93.26ms +step:75/1670 train_time:6993ms step_avg:93.24ms +step:76/1670 train_time:7085ms step_avg:93.23ms +step:77/1670 train_time:7177ms step_avg:93.20ms +step:78/1670 train_time:7268ms step_avg:93.17ms +step:79/1670 train_time:7359ms step_avg:93.15ms +step:80/1670 train_time:7450ms step_avg:93.13ms +step:81/1670 train_time:7542ms step_avg:93.11ms +step:82/1670 train_time:7634ms step_avg:93.09ms +step:83/1670 train_time:7725ms step_avg:93.07ms +step:84/1670 train_time:7817ms step_avg:93.06ms +step:85/1670 train_time:7910ms step_avg:93.06ms +step:86/1670 train_time:8001ms step_avg:93.04ms +step:87/1670 train_time:8092ms step_avg:93.01ms +step:88/1670 train_time:8182ms step_avg:92.98ms +step:89/1670 train_time:8274ms step_avg:92.96ms +step:90/1670 train_time:8365ms step_avg:92.94ms +step:91/1670 train_time:8456ms step_avg:92.92ms +step:92/1670 train_time:8547ms step_avg:92.90ms +step:93/1670 train_time:8639ms step_avg:92.89ms +step:94/1670 train_time:8731ms step_avg:92.88ms 
+step:95/1670 train_time:8822ms step_avg:92.86ms +step:96/1670 train_time:8914ms step_avg:92.85ms +step:97/1670 train_time:9006ms step_avg:92.84ms +step:98/1670 train_time:9099ms step_avg:92.84ms +step:99/1670 train_time:9190ms step_avg:92.83ms +step:100/1670 train_time:9281ms step_avg:92.81ms +step:101/1670 train_time:9372ms step_avg:92.79ms +step:102/1670 train_time:9463ms step_avg:92.78ms +step:103/1670 train_time:9554ms step_avg:92.76ms +step:104/1670 train_time:9645ms step_avg:92.74ms +step:105/1670 train_time:9738ms step_avg:92.75ms +step:106/1670 train_time:9829ms step_avg:92.73ms +step:107/1670 train_time:9921ms step_avg:92.72ms +step:108/1670 train_time:10013ms step_avg:92.71ms +step:109/1670 train_time:10105ms step_avg:92.71ms +step:110/1670 train_time:10197ms step_avg:92.70ms +step:111/1670 train_time:10288ms step_avg:92.68ms +step:112/1670 train_time:10380ms step_avg:92.68ms +step:113/1670 train_time:10472ms step_avg:92.67ms +step:114/1670 train_time:10563ms step_avg:92.66ms +step:115/1670 train_time:10656ms step_avg:92.66ms +step:116/1670 train_time:10748ms step_avg:92.65ms +step:117/1670 train_time:10840ms step_avg:92.65ms +step:118/1670 train_time:10932ms step_avg:92.64ms +step:119/1670 train_time:11022ms step_avg:92.63ms +step:120/1670 train_time:11115ms step_avg:92.62ms +step:121/1670 train_time:11206ms step_avg:92.62ms +step:122/1670 train_time:11299ms step_avg:92.61ms +step:123/1670 train_time:11390ms step_avg:92.60ms +step:124/1670 train_time:11482ms step_avg:92.59ms +step:125/1670 train_time:11573ms step_avg:92.59ms +step:125/1670 val_loss:4.3115 train_time:11664ms step_avg:93.31ms +step:126/1670 train_time:11684ms step_avg:92.73ms +step:127/1670 train_time:11760ms step_avg:92.60ms +step:128/1670 train_time:11861ms step_avg:92.67ms +step:129/1670 train_time:11954ms step_avg:92.67ms +step:130/1670 train_time:12045ms step_avg:92.66ms +step:131/1670 train_time:12136ms step_avg:92.64ms +step:132/1670 train_time:12225ms step_avg:92.62ms +step:133/1670 train_time:12315ms step_avg:92.60ms +step:134/1670 train_time:12405ms step_avg:92.57ms +step:135/1670 train_time:12496ms step_avg:92.56ms +step:136/1670 train_time:12586ms step_avg:92.55ms +step:137/1670 train_time:12678ms step_avg:92.54ms +step:138/1670 train_time:12771ms step_avg:92.54ms +step:139/1670 train_time:12866ms step_avg:92.56ms +step:140/1670 train_time:12960ms step_avg:92.57ms +step:141/1670 train_time:13053ms step_avg:92.57ms +step:142/1670 train_time:13143ms step_avg:92.56ms +step:143/1670 train_time:13234ms step_avg:92.54ms +step:144/1670 train_time:13325ms step_avg:92.53ms +step:145/1670 train_time:13415ms step_avg:92.52ms +step:146/1670 train_time:13505ms step_avg:92.50ms +step:147/1670 train_time:13597ms step_avg:92.50ms +step:148/1670 train_time:13688ms step_avg:92.49ms +step:149/1670 train_time:13781ms step_avg:92.49ms +step:150/1670 train_time:13873ms step_avg:92.49ms +step:151/1670 train_time:13966ms step_avg:92.49ms +step:152/1670 train_time:14061ms step_avg:92.51ms +step:153/1670 train_time:14152ms step_avg:92.49ms +step:154/1670 train_time:14243ms step_avg:92.49ms +step:155/1670 train_time:14333ms step_avg:92.47ms +step:156/1670 train_time:14425ms step_avg:92.47ms +step:157/1670 train_time:14516ms step_avg:92.46ms +step:158/1670 train_time:14606ms step_avg:92.44ms +step:159/1670 train_time:14697ms step_avg:92.43ms +step:160/1670 train_time:14787ms step_avg:92.42ms +step:161/1670 train_time:14880ms step_avg:92.42ms +step:162/1670 train_time:14971ms step_avg:92.41ms +step:163/1670 train_time:15063ms 
step_avg:92.41ms +step:164/1670 train_time:15155ms step_avg:92.41ms +step:165/1670 train_time:15247ms step_avg:92.40ms +step:166/1670 train_time:15338ms step_avg:92.40ms +step:167/1670 train_time:15428ms step_avg:92.38ms +step:168/1670 train_time:15520ms step_avg:92.38ms +step:169/1670 train_time:15610ms step_avg:92.37ms +step:170/1670 train_time:15702ms step_avg:92.36ms +step:171/1670 train_time:15794ms step_avg:92.36ms +step:172/1670 train_time:15885ms step_avg:92.36ms +step:173/1670 train_time:15977ms step_avg:92.35ms +step:174/1670 train_time:16068ms step_avg:92.35ms +step:175/1670 train_time:16159ms step_avg:92.34ms +step:176/1670 train_time:16250ms step_avg:92.33ms +step:177/1670 train_time:16341ms step_avg:92.32ms +step:178/1670 train_time:16431ms step_avg:92.31ms +step:179/1670 train_time:16524ms step_avg:92.31ms +step:180/1670 train_time:16616ms step_avg:92.31ms +step:181/1670 train_time:16707ms step_avg:92.30ms +step:182/1670 train_time:16798ms step_avg:92.30ms +step:183/1670 train_time:16889ms step_avg:92.29ms +step:184/1670 train_time:16980ms step_avg:92.28ms +step:185/1670 train_time:17072ms step_avg:92.28ms +step:186/1670 train_time:17164ms step_avg:92.28ms +step:187/1670 train_time:17255ms step_avg:92.27ms +step:188/1670 train_time:17347ms step_avg:92.27ms +step:189/1670 train_time:17438ms step_avg:92.26ms +step:190/1670 train_time:17528ms step_avg:92.25ms +step:191/1670 train_time:17621ms step_avg:92.26ms +step:192/1670 train_time:17712ms step_avg:92.25ms +step:193/1670 train_time:17804ms step_avg:92.25ms +step:194/1670 train_time:17896ms step_avg:92.25ms +step:195/1670 train_time:17987ms step_avg:92.24ms +step:196/1670 train_time:18078ms step_avg:92.24ms +step:197/1670 train_time:18169ms step_avg:92.23ms +step:198/1670 train_time:18260ms step_avg:92.22ms +step:199/1670 train_time:18351ms step_avg:92.22ms +step:200/1670 train_time:18442ms step_avg:92.21ms +step:201/1670 train_time:18533ms step_avg:92.20ms +step:202/1670 train_time:18625ms step_avg:92.20ms +step:203/1670 train_time:18717ms step_avg:92.20ms +step:204/1670 train_time:18808ms step_avg:92.20ms +step:205/1670 train_time:18901ms step_avg:92.20ms +step:206/1670 train_time:18993ms step_avg:92.20ms +step:207/1670 train_time:19085ms step_avg:92.20ms +step:208/1670 train_time:19176ms step_avg:92.19ms +step:209/1670 train_time:19267ms step_avg:92.19ms +step:210/1670 train_time:19359ms step_avg:92.18ms +step:211/1670 train_time:19449ms step_avg:92.18ms +step:212/1670 train_time:19540ms step_avg:92.17ms +step:213/1670 train_time:19789ms step_avg:92.91ms +step:214/1670 train_time:19858ms step_avg:92.79ms +step:215/1670 train_time:19947ms step_avg:92.78ms +step:216/1670 train_time:20037ms step_avg:92.77ms +step:217/1670 train_time:20127ms step_avg:92.75ms +step:218/1670 train_time:20217ms step_avg:92.74ms +step:219/1670 train_time:20307ms step_avg:92.72ms +step:220/1670 train_time:20397ms step_avg:92.71ms +step:221/1670 train_time:20486ms step_avg:92.70ms +step:222/1670 train_time:20577ms step_avg:92.69ms +step:223/1670 train_time:20672ms step_avg:92.70ms +step:224/1670 train_time:20768ms step_avg:92.71ms +step:225/1670 train_time:20861ms step_avg:92.71ms +step:226/1670 train_time:20951ms step_avg:92.70ms +step:227/1670 train_time:21042ms step_avg:92.70ms +step:228/1670 train_time:21133ms step_avg:92.69ms +step:229/1670 train_time:21223ms step_avg:92.68ms +step:230/1670 train_time:21314ms step_avg:92.67ms +step:231/1670 train_time:21404ms step_avg:92.66ms +step:232/1670 train_time:21496ms step_avg:92.65ms +step:233/1670 
train_time:21587ms step_avg:92.65ms +step:234/1670 train_time:21681ms step_avg:92.66ms +step:235/1670 train_time:21774ms step_avg:92.66ms +step:236/1670 train_time:21867ms step_avg:92.66ms +step:237/1670 train_time:21958ms step_avg:92.65ms +step:238/1670 train_time:22048ms step_avg:92.64ms +step:239/1670 train_time:22139ms step_avg:92.63ms +step:240/1670 train_time:22229ms step_avg:92.62ms +step:241/1670 train_time:22319ms step_avg:92.61ms +step:242/1670 train_time:22409ms step_avg:92.60ms +step:243/1670 train_time:22500ms step_avg:92.59ms +step:244/1670 train_time:22591ms step_avg:92.59ms +step:245/1670 train_time:22684ms step_avg:92.59ms +step:246/1670 train_time:22777ms step_avg:92.59ms +step:247/1670 train_time:22869ms step_avg:92.59ms +step:248/1670 train_time:22961ms step_avg:92.58ms +step:249/1670 train_time:23052ms step_avg:92.58ms +step:250/1670 train_time:23142ms step_avg:92.57ms +step:250/1670 val_loss:3.9705 train_time:23232ms step_avg:92.93ms +step:251/1670 train_time:23253ms step_avg:92.64ms +step:252/1670 train_time:23325ms step_avg:92.56ms +step:253/1670 train_time:23416ms step_avg:92.55ms +step:254/1670 train_time:23506ms step_avg:92.54ms +step:255/1670 train_time:23596ms step_avg:92.53ms +step:256/1670 train_time:23687ms step_avg:92.53ms +step:257/1670 train_time:23777ms step_avg:92.52ms +step:258/1670 train_time:23867ms step_avg:92.51ms +step:259/1670 train_time:23958ms step_avg:92.50ms +step:260/1670 train_time:24052ms step_avg:92.51ms +step:261/1670 train_time:24142ms step_avg:92.50ms +step:262/1670 train_time:24235ms step_avg:92.50ms +step:263/1670 train_time:24328ms step_avg:92.50ms +step:264/1670 train_time:24418ms step_avg:92.49ms +step:265/1670 train_time:24509ms step_avg:92.49ms +step:266/1670 train_time:24599ms step_avg:92.48ms +step:267/1670 train_time:24690ms step_avg:92.47ms +step:268/1670 train_time:24780ms step_avg:92.46ms +step:269/1670 train_time:24871ms step_avg:92.46ms +step:270/1670 train_time:24961ms step_avg:92.45ms +step:271/1670 train_time:25054ms step_avg:92.45ms +step:272/1670 train_time:25145ms step_avg:92.44ms +step:273/1670 train_time:25236ms step_avg:92.44ms +step:274/1670 train_time:25328ms step_avg:92.44ms +step:275/1670 train_time:25419ms step_avg:92.43ms +step:276/1670 train_time:25512ms step_avg:92.43ms +step:277/1670 train_time:25602ms step_avg:92.43ms +step:278/1670 train_time:25694ms step_avg:92.42ms +step:279/1670 train_time:25785ms step_avg:92.42ms +step:280/1670 train_time:25876ms step_avg:92.41ms +step:281/1670 train_time:25966ms step_avg:92.41ms +step:282/1670 train_time:26058ms step_avg:92.40ms +step:283/1670 train_time:26149ms step_avg:92.40ms +step:284/1670 train_time:26240ms step_avg:92.39ms +step:285/1670 train_time:26332ms step_avg:92.39ms +step:286/1670 train_time:26423ms step_avg:92.39ms +step:287/1670 train_time:26514ms step_avg:92.38ms +step:288/1670 train_time:26605ms step_avg:92.38ms +step:289/1670 train_time:26697ms step_avg:92.38ms +step:290/1670 train_time:26789ms step_avg:92.38ms +step:291/1670 train_time:26880ms step_avg:92.37ms +step:292/1670 train_time:26971ms step_avg:92.37ms +step:293/1670 train_time:27061ms step_avg:92.36ms +step:294/1670 train_time:27152ms step_avg:92.35ms +step:295/1670 train_time:27242ms step_avg:92.35ms +step:296/1670 train_time:27333ms step_avg:92.34ms +step:297/1670 train_time:27424ms step_avg:92.34ms +step:298/1670 train_time:27515ms step_avg:92.33ms +step:299/1670 train_time:27605ms step_avg:92.33ms +step:300/1670 train_time:27697ms step_avg:92.32ms +step:301/1670 train_time:27790ms 
step_avg:92.33ms +step:302/1670 train_time:27881ms step_avg:92.32ms +step:303/1670 train_time:27972ms step_avg:92.32ms +step:304/1670 train_time:28062ms step_avg:92.31ms +step:305/1670 train_time:28153ms step_avg:92.30ms +step:306/1670 train_time:28243ms step_avg:92.30ms +step:307/1670 train_time:28336ms step_avg:92.30ms +step:308/1670 train_time:28427ms step_avg:92.29ms +step:309/1670 train_time:28517ms step_avg:92.29ms +step:310/1670 train_time:28608ms step_avg:92.28ms +step:311/1670 train_time:28699ms step_avg:92.28ms +step:312/1670 train_time:28790ms step_avg:92.28ms +step:313/1670 train_time:28882ms step_avg:92.27ms +step:314/1670 train_time:28973ms step_avg:92.27ms +step:315/1670 train_time:29063ms step_avg:92.26ms +step:316/1670 train_time:29155ms step_avg:92.26ms +step:317/1670 train_time:29246ms step_avg:92.26ms +step:318/1670 train_time:29337ms step_avg:92.25ms +step:319/1670 train_time:29428ms step_avg:92.25ms +step:320/1670 train_time:29518ms step_avg:92.24ms +step:321/1670 train_time:29610ms step_avg:92.24ms +step:322/1670 train_time:29700ms step_avg:92.24ms +step:323/1670 train_time:29793ms step_avg:92.24ms +step:324/1670 train_time:29883ms step_avg:92.23ms +step:325/1670 train_time:29975ms step_avg:92.23ms +step:326/1670 train_time:30066ms step_avg:92.23ms +step:327/1670 train_time:30157ms step_avg:92.22ms +step:328/1670 train_time:30248ms step_avg:92.22ms +step:329/1670 train_time:30338ms step_avg:92.21ms +step:330/1670 train_time:30429ms step_avg:92.21ms +step:331/1670 train_time:30520ms step_avg:92.20ms +step:332/1670 train_time:30610ms step_avg:92.20ms +step:333/1670 train_time:30701ms step_avg:92.19ms +step:334/1670 train_time:30793ms step_avg:92.19ms +step:335/1670 train_time:30884ms step_avg:92.19ms +step:336/1670 train_time:30975ms step_avg:92.19ms +step:337/1670 train_time:31068ms step_avg:92.19ms +step:338/1670 train_time:31159ms step_avg:92.19ms +step:339/1670 train_time:31249ms step_avg:92.18ms +step:340/1670 train_time:31340ms step_avg:92.18ms +step:341/1670 train_time:31430ms step_avg:92.17ms +step:342/1670 train_time:31520ms step_avg:92.17ms +step:343/1670 train_time:31611ms step_avg:92.16ms +step:344/1670 train_time:31702ms step_avg:92.16ms +step:345/1670 train_time:31794ms step_avg:92.16ms +step:346/1670 train_time:31885ms step_avg:92.15ms +step:347/1670 train_time:31976ms step_avg:92.15ms +step:348/1670 train_time:32068ms step_avg:92.15ms +step:349/1670 train_time:32159ms step_avg:92.15ms +step:350/1670 train_time:32251ms step_avg:92.15ms +step:351/1670 train_time:32341ms step_avg:92.14ms +step:352/1670 train_time:32432ms step_avg:92.14ms +step:353/1670 train_time:32522ms step_avg:92.13ms +step:354/1670 train_time:32613ms step_avg:92.13ms +step:355/1670 train_time:32703ms step_avg:92.12ms +step:356/1670 train_time:32796ms step_avg:92.12ms +step:357/1670 train_time:32886ms step_avg:92.12ms +step:358/1670 train_time:32977ms step_avg:92.11ms +step:359/1670 train_time:33068ms step_avg:92.11ms +step:360/1670 train_time:33159ms step_avg:92.11ms +step:361/1670 train_time:33251ms step_avg:92.11ms +step:362/1670 train_time:33342ms step_avg:92.10ms +step:363/1670 train_time:33433ms step_avg:92.10ms +step:364/1670 train_time:33525ms step_avg:92.10ms +step:365/1670 train_time:33617ms step_avg:92.10ms +step:366/1670 train_time:33708ms step_avg:92.10ms +step:367/1670 train_time:33799ms step_avg:92.10ms +step:368/1670 train_time:33889ms step_avg:92.09ms +step:369/1670 train_time:33980ms step_avg:92.09ms +step:370/1670 train_time:34071ms step_avg:92.08ms +step:371/1670 
train_time:34162ms step_avg:92.08ms +step:372/1670 train_time:34254ms step_avg:92.08ms +step:373/1670 train_time:34345ms step_avg:92.08ms +step:374/1670 train_time:34435ms step_avg:92.07ms +step:375/1670 train_time:34526ms step_avg:92.07ms +step:375/1670 val_loss:3.8117 train_time:34617ms step_avg:92.31ms +step:376/1670 train_time:34637ms step_avg:92.12ms +step:377/1670 train_time:34710ms step_avg:92.07ms +step:378/1670 train_time:34802ms step_avg:92.07ms +step:379/1670 train_time:34892ms step_avg:92.06ms +step:380/1670 train_time:34984ms step_avg:92.06ms +step:381/1670 train_time:35074ms step_avg:92.06ms +step:382/1670 train_time:35165ms step_avg:92.06ms +step:383/1670 train_time:35256ms step_avg:92.05ms +step:384/1670 train_time:35347ms step_avg:92.05ms +step:385/1670 train_time:35437ms step_avg:92.05ms +step:386/1670 train_time:35529ms step_avg:92.04ms +step:387/1670 train_time:35623ms step_avg:92.05ms +step:388/1670 train_time:35715ms step_avg:92.05ms +step:389/1670 train_time:35807ms step_avg:92.05ms +step:390/1670 train_time:35898ms step_avg:92.05ms +step:391/1670 train_time:35988ms step_avg:92.04ms +step:392/1670 train_time:36079ms step_avg:92.04ms +step:393/1670 train_time:36170ms step_avg:92.03ms +step:394/1670 train_time:36260ms step_avg:92.03ms +step:395/1670 train_time:36351ms step_avg:92.03ms +step:396/1670 train_time:36444ms step_avg:92.03ms +step:397/1670 train_time:36535ms step_avg:92.03ms +step:398/1670 train_time:36628ms step_avg:92.03ms +step:399/1670 train_time:36719ms step_avg:92.03ms +step:400/1670 train_time:36810ms step_avg:92.03ms +step:401/1670 train_time:36902ms step_avg:92.03ms +step:402/1670 train_time:36993ms step_avg:92.02ms +step:403/1670 train_time:37084ms step_avg:92.02ms +step:404/1670 train_time:37174ms step_avg:92.01ms +step:405/1670 train_time:37265ms step_avg:92.01ms +step:406/1670 train_time:37356ms step_avg:92.01ms +step:407/1670 train_time:37447ms step_avg:92.01ms +step:408/1670 train_time:37539ms step_avg:92.01ms +step:409/1670 train_time:37630ms step_avg:92.00ms +step:410/1670 train_time:37721ms step_avg:92.00ms +step:411/1670 train_time:37812ms step_avg:92.00ms +step:412/1670 train_time:37904ms step_avg:92.00ms +step:413/1670 train_time:37995ms step_avg:92.00ms +step:414/1670 train_time:38087ms step_avg:92.00ms +step:415/1670 train_time:38177ms step_avg:91.99ms +step:416/1670 train_time:38269ms step_avg:91.99ms +step:417/1670 train_time:38360ms step_avg:91.99ms +step:418/1670 train_time:38451ms step_avg:91.99ms +step:419/1670 train_time:38543ms step_avg:91.99ms +step:420/1670 train_time:38634ms step_avg:91.99ms +step:421/1670 train_time:38726ms step_avg:91.99ms +step:422/1670 train_time:38817ms step_avg:91.98ms +step:423/1670 train_time:38908ms step_avg:91.98ms +step:424/1670 train_time:39000ms step_avg:91.98ms +step:425/1670 train_time:39250ms step_avg:92.35ms +step:426/1670 train_time:39322ms step_avg:92.30ms +step:427/1670 train_time:39411ms step_avg:92.30ms +step:428/1670 train_time:39501ms step_avg:92.29ms +step:429/1670 train_time:39592ms step_avg:92.29ms +step:430/1670 train_time:39682ms step_avg:92.28ms +step:431/1670 train_time:39772ms step_avg:92.28ms +step:432/1670 train_time:39861ms step_avg:92.27ms +step:433/1670 train_time:39951ms step_avg:92.27ms +step:434/1670 train_time:40042ms step_avg:92.26ms +step:435/1670 train_time:40133ms step_avg:92.26ms +step:436/1670 train_time:40230ms step_avg:92.27ms +step:437/1670 train_time:40324ms step_avg:92.27ms +step:438/1670 train_time:40415ms step_avg:92.27ms +step:439/1670 train_time:40506ms 
step_avg:92.27ms +step:440/1670 train_time:40598ms step_avg:92.27ms +step:441/1670 train_time:40688ms step_avg:92.26ms +step:442/1670 train_time:40778ms step_avg:92.26ms +step:443/1670 train_time:40868ms step_avg:92.25ms +step:444/1670 train_time:40960ms step_avg:92.25ms +step:445/1670 train_time:41050ms step_avg:92.25ms +step:446/1670 train_time:41141ms step_avg:92.24ms +step:447/1670 train_time:41234ms step_avg:92.25ms +step:448/1670 train_time:41326ms step_avg:92.25ms +step:449/1670 train_time:41418ms step_avg:92.25ms +step:450/1670 train_time:41509ms step_avg:92.24ms +step:451/1670 train_time:41601ms step_avg:92.24ms +step:452/1670 train_time:41691ms step_avg:92.24ms +step:453/1670 train_time:41782ms step_avg:92.23ms +step:454/1670 train_time:41873ms step_avg:92.23ms +step:455/1670 train_time:41963ms step_avg:92.23ms +step:456/1670 train_time:42053ms step_avg:92.22ms +step:457/1670 train_time:42145ms step_avg:92.22ms +step:458/1670 train_time:42238ms step_avg:92.22ms +step:459/1670 train_time:42329ms step_avg:92.22ms +step:460/1670 train_time:42422ms step_avg:92.22ms +step:461/1670 train_time:42512ms step_avg:92.22ms +step:462/1670 train_time:42603ms step_avg:92.22ms +step:463/1670 train_time:42694ms step_avg:92.21ms +step:464/1670 train_time:42785ms step_avg:92.21ms +step:465/1670 train_time:42876ms step_avg:92.21ms +step:466/1670 train_time:42968ms step_avg:92.21ms +step:467/1670 train_time:43059ms step_avg:92.20ms +step:468/1670 train_time:43151ms step_avg:92.20ms +step:469/1670 train_time:43243ms step_avg:92.20ms +step:470/1670 train_time:43334ms step_avg:92.20ms +step:471/1670 train_time:43426ms step_avg:92.20ms +step:472/1670 train_time:43517ms step_avg:92.20ms +step:473/1670 train_time:43609ms step_avg:92.20ms +step:474/1670 train_time:43700ms step_avg:92.19ms +step:475/1670 train_time:43790ms step_avg:92.19ms +step:476/1670 train_time:43881ms step_avg:92.19ms +step:477/1670 train_time:43972ms step_avg:92.18ms +step:478/1670 train_time:44064ms step_avg:92.18ms +step:479/1670 train_time:44155ms step_avg:92.18ms +step:480/1670 train_time:44247ms step_avg:92.18ms +step:481/1670 train_time:44339ms step_avg:92.18ms +step:482/1670 train_time:44431ms step_avg:92.18ms +step:483/1670 train_time:44523ms step_avg:92.18ms +step:484/1670 train_time:44613ms step_avg:92.18ms +step:485/1670 train_time:44704ms step_avg:92.17ms +step:486/1670 train_time:44794ms step_avg:92.17ms +step:487/1670 train_time:44886ms step_avg:92.17ms +step:488/1670 train_time:44977ms step_avg:92.17ms +step:489/1670 train_time:45068ms step_avg:92.16ms +step:490/1670 train_time:45160ms step_avg:92.16ms +step:491/1670 train_time:45250ms step_avg:92.16ms +step:492/1670 train_time:45342ms step_avg:92.16ms +step:493/1670 train_time:45433ms step_avg:92.16ms +step:494/1670 train_time:45525ms step_avg:92.16ms +step:495/1670 train_time:45616ms step_avg:92.15ms +step:496/1670 train_time:45707ms step_avg:92.15ms +step:497/1670 train_time:45798ms step_avg:92.15ms +step:498/1670 train_time:45889ms step_avg:92.15ms +step:499/1670 train_time:45980ms step_avg:92.14ms +step:500/1670 train_time:46071ms step_avg:92.14ms +step:500/1670 val_loss:3.7132 train_time:46162ms step_avg:92.32ms +step:501/1670 train_time:46183ms step_avg:92.18ms +step:502/1670 train_time:46255ms step_avg:92.14ms +step:503/1670 train_time:46346ms step_avg:92.14ms +step:504/1670 train_time:46438ms step_avg:92.14ms +step:505/1670 train_time:46529ms step_avg:92.14ms +step:506/1670 train_time:46619ms step_avg:92.13ms +step:507/1670 train_time:46710ms step_avg:92.13ms 
+step:508/1670 train_time:46800ms step_avg:92.13ms +step:509/1670 train_time:46892ms step_avg:92.13ms +step:510/1670 train_time:46983ms step_avg:92.12ms +step:511/1670 train_time:47075ms step_avg:92.12ms +step:512/1670 train_time:47168ms step_avg:92.12ms +step:513/1670 train_time:47260ms step_avg:92.12ms +step:514/1670 train_time:47352ms step_avg:92.12ms +step:515/1670 train_time:47442ms step_avg:92.12ms +step:516/1670 train_time:47533ms step_avg:92.12ms +step:517/1670 train_time:47623ms step_avg:92.11ms +step:518/1670 train_time:47715ms step_avg:92.11ms +step:519/1670 train_time:47807ms step_avg:92.11ms +step:520/1670 train_time:47898ms step_avg:92.11ms +step:521/1670 train_time:47989ms step_avg:92.11ms +step:522/1670 train_time:48080ms step_avg:92.11ms +step:523/1670 train_time:48173ms step_avg:92.11ms +step:524/1670 train_time:48264ms step_avg:92.11ms +step:525/1670 train_time:48356ms step_avg:92.11ms +step:526/1670 train_time:48448ms step_avg:92.11ms +step:527/1670 train_time:48539ms step_avg:92.10ms +step:528/1670 train_time:48629ms step_avg:92.10ms +step:529/1670 train_time:48720ms step_avg:92.10ms +step:530/1670 train_time:48810ms step_avg:92.10ms +step:531/1670 train_time:48902ms step_avg:92.09ms +step:532/1670 train_time:48993ms step_avg:92.09ms +step:533/1670 train_time:49084ms step_avg:92.09ms +step:534/1670 train_time:49176ms step_avg:92.09ms +step:535/1670 train_time:49268ms step_avg:92.09ms +step:536/1670 train_time:49361ms step_avg:92.09ms +step:537/1670 train_time:49452ms step_avg:92.09ms +step:538/1670 train_time:49542ms step_avg:92.09ms +step:539/1670 train_time:49633ms step_avg:92.08ms +step:540/1670 train_time:49724ms step_avg:92.08ms +step:541/1670 train_time:49816ms step_avg:92.08ms +step:542/1670 train_time:49907ms step_avg:92.08ms +step:543/1670 train_time:49998ms step_avg:92.08ms +step:544/1670 train_time:50089ms step_avg:92.07ms +step:545/1670 train_time:50180ms step_avg:92.07ms +step:546/1670 train_time:50271ms step_avg:92.07ms +step:547/1670 train_time:50363ms step_avg:92.07ms +step:548/1670 train_time:50454ms step_avg:92.07ms +step:549/1670 train_time:50545ms step_avg:92.07ms +step:550/1670 train_time:50637ms step_avg:92.07ms +step:551/1670 train_time:50728ms step_avg:92.07ms +step:552/1670 train_time:50819ms step_avg:92.06ms +step:553/1670 train_time:50910ms step_avg:92.06ms +step:554/1670 train_time:51001ms step_avg:92.06ms +step:555/1670 train_time:51092ms step_avg:92.06ms +step:556/1670 train_time:51182ms step_avg:92.05ms +step:557/1670 train_time:51274ms step_avg:92.05ms +step:558/1670 train_time:51565ms step_avg:92.41ms +step:559/1670 train_time:51634ms step_avg:92.37ms +step:560/1670 train_time:51725ms step_avg:92.37ms +step:561/1670 train_time:51816ms step_avg:92.36ms +step:562/1670 train_time:51907ms step_avg:92.36ms +step:563/1670 train_time:51998ms step_avg:92.36ms +step:564/1670 train_time:52090ms step_avg:92.36ms +step:565/1670 train_time:52184ms step_avg:92.36ms +step:566/1670 train_time:52272ms step_avg:92.35ms +step:567/1670 train_time:52363ms step_avg:92.35ms +step:568/1670 train_time:52460ms step_avg:92.36ms +step:569/1670 train_time:52558ms step_avg:92.37ms +step:570/1670 train_time:52651ms step_avg:92.37ms +step:571/1670 train_time:52742ms step_avg:92.37ms +step:572/1670 train_time:52835ms step_avg:92.37ms +step:573/1670 train_time:52927ms step_avg:92.37ms +step:574/1670 train_time:53018ms step_avg:92.37ms +step:575/1670 train_time:53110ms step_avg:92.36ms +step:576/1670 train_time:53201ms step_avg:92.36ms +step:577/1670 train_time:53294ms 
step_avg:92.36ms +step:578/1670 train_time:53388ms step_avg:92.37ms +step:579/1670 train_time:53482ms step_avg:92.37ms +step:580/1670 train_time:53577ms step_avg:92.37ms +step:581/1670 train_time:53670ms step_avg:92.38ms +step:582/1670 train_time:53762ms step_avg:92.38ms +step:583/1670 train_time:53856ms step_avg:92.38ms +step:584/1670 train_time:53948ms step_avg:92.38ms +step:585/1670 train_time:54039ms step_avg:92.37ms +step:586/1670 train_time:54131ms step_avg:92.37ms +step:587/1670 train_time:54225ms step_avg:92.38ms +step:588/1670 train_time:54319ms step_avg:92.38ms +step:589/1670 train_time:54410ms step_avg:92.38ms +step:590/1670 train_time:54504ms step_avg:92.38ms +step:591/1670 train_time:54598ms step_avg:92.38ms +step:592/1670 train_time:54692ms step_avg:92.38ms +step:593/1670 train_time:54784ms step_avg:92.38ms +step:594/1670 train_time:54876ms step_avg:92.38ms +step:595/1670 train_time:54969ms step_avg:92.38ms +step:596/1670 train_time:55060ms step_avg:92.38ms +step:597/1670 train_time:55152ms step_avg:92.38ms +step:598/1670 train_time:55244ms step_avg:92.38ms +step:599/1670 train_time:55336ms step_avg:92.38ms +step:600/1670 train_time:55429ms step_avg:92.38ms +step:601/1670 train_time:55521ms step_avg:92.38ms +step:602/1670 train_time:55615ms step_avg:92.38ms +step:603/1670 train_time:55708ms step_avg:92.38ms +step:604/1670 train_time:55800ms step_avg:92.38ms +step:605/1670 train_time:55893ms step_avg:92.38ms +step:606/1670 train_time:55984ms step_avg:92.38ms +step:607/1670 train_time:56076ms step_avg:92.38ms +step:608/1670 train_time:56168ms step_avg:92.38ms +step:609/1670 train_time:56260ms step_avg:92.38ms +step:610/1670 train_time:56353ms step_avg:92.38ms +step:611/1670 train_time:56445ms step_avg:92.38ms +step:612/1670 train_time:56539ms step_avg:92.38ms +step:613/1670 train_time:56632ms step_avg:92.39ms +step:614/1670 train_time:56725ms step_avg:92.39ms +step:615/1670 train_time:56819ms step_avg:92.39ms +step:616/1670 train_time:56912ms step_avg:92.39ms +step:617/1670 train_time:57003ms step_avg:92.39ms +step:618/1670 train_time:57096ms step_avg:92.39ms +step:619/1670 train_time:57188ms step_avg:92.39ms +step:620/1670 train_time:57280ms step_avg:92.39ms +step:621/1670 train_time:57373ms step_avg:92.39ms +step:622/1670 train_time:57465ms step_avg:92.39ms +step:623/1670 train_time:57560ms step_avg:92.39ms +step:624/1670 train_time:57653ms step_avg:92.39ms +step:625/1670 train_time:57744ms step_avg:92.39ms +step:625/1670 val_loss:3.6122 train_time:57838ms step_avg:92.54ms +step:626/1670 train_time:57861ms step_avg:92.43ms +step:627/1670 train_time:57936ms step_avg:92.40ms +step:628/1670 train_time:58035ms step_avg:92.41ms +step:629/1670 train_time:58128ms step_avg:92.41ms +step:630/1670 train_time:58220ms step_avg:92.41ms +step:631/1670 train_time:58311ms step_avg:92.41ms +step:632/1670 train_time:58402ms step_avg:92.41ms +step:633/1670 train_time:58493ms step_avg:92.41ms +step:634/1670 train_time:58584ms step_avg:92.40ms +step:635/1670 train_time:58675ms step_avg:92.40ms +step:636/1670 train_time:58766ms step_avg:92.40ms +step:637/1670 train_time:58858ms step_avg:92.40ms +step:638/1670 train_time:58954ms step_avg:92.40ms +step:639/1670 train_time:59190ms step_avg:92.63ms +step:640/1670 train_time:59264ms step_avg:92.60ms +step:641/1670 train_time:59355ms step_avg:92.60ms +step:642/1670 train_time:59447ms step_avg:92.60ms +step:643/1670 train_time:59537ms step_avg:92.59ms +step:644/1670 train_time:59629ms step_avg:92.59ms +step:645/1670 train_time:59720ms step_avg:92.59ms 
+step:646/1670 train_time:59811ms step_avg:92.59ms +step:647/1670 train_time:59902ms step_avg:92.58ms +step:648/1670 train_time:59993ms step_avg:92.58ms +step:649/1670 train_time:60089ms step_avg:92.59ms +step:650/1670 train_time:60185ms step_avg:92.59ms +step:651/1670 train_time:60278ms step_avg:92.59ms +step:652/1670 train_time:60371ms step_avg:92.59ms +step:653/1670 train_time:60463ms step_avg:92.59ms +step:654/1670 train_time:60554ms step_avg:92.59ms +step:655/1670 train_time:60647ms step_avg:92.59ms +step:656/1670 train_time:60738ms step_avg:92.59ms +step:657/1670 train_time:60830ms step_avg:92.59ms +step:658/1670 train_time:60921ms step_avg:92.59ms +step:659/1670 train_time:61014ms step_avg:92.59ms +step:660/1670 train_time:61109ms step_avg:92.59ms +step:661/1670 train_time:61203ms step_avg:92.59ms +step:662/1670 train_time:61296ms step_avg:92.59ms +step:663/1670 train_time:61390ms step_avg:92.59ms +step:664/1670 train_time:61484ms step_avg:92.60ms +step:665/1670 train_time:61576ms step_avg:92.60ms +step:666/1670 train_time:61668ms step_avg:92.59ms +step:667/1670 train_time:61759ms step_avg:92.59ms +step:668/1670 train_time:61851ms step_avg:92.59ms +step:669/1670 train_time:61942ms step_avg:92.59ms +step:670/1670 train_time:62037ms step_avg:92.59ms +step:671/1670 train_time:62131ms step_avg:92.59ms +step:672/1670 train_time:62224ms step_avg:92.60ms +step:673/1670 train_time:62316ms step_avg:92.59ms +step:674/1670 train_time:62408ms step_avg:92.59ms +step:675/1670 train_time:62500ms step_avg:92.59ms +step:676/1670 train_time:62595ms step_avg:92.60ms +step:677/1670 train_time:62685ms step_avg:92.59ms +step:678/1670 train_time:62777ms step_avg:92.59ms +step:679/1670 train_time:62870ms step_avg:92.59ms +step:680/1670 train_time:62962ms step_avg:92.59ms +step:681/1670 train_time:63055ms step_avg:92.59ms +step:682/1670 train_time:63149ms step_avg:92.59ms +step:683/1670 train_time:63241ms step_avg:92.59ms +step:684/1670 train_time:63337ms step_avg:92.60ms +step:685/1670 train_time:63428ms step_avg:92.60ms +step:686/1670 train_time:63520ms step_avg:92.59ms +step:687/1670 train_time:63613ms step_avg:92.60ms +step:688/1670 train_time:63704ms step_avg:92.59ms +step:689/1670 train_time:63796ms step_avg:92.59ms +step:690/1670 train_time:63889ms step_avg:92.59ms +step:691/1670 train_time:63980ms step_avg:92.59ms +step:692/1670 train_time:64073ms step_avg:92.59ms +step:693/1670 train_time:64168ms step_avg:92.59ms +step:694/1670 train_time:64259ms step_avg:92.59ms +step:695/1670 train_time:64352ms step_avg:92.59ms +step:696/1670 train_time:64446ms step_avg:92.59ms +step:697/1670 train_time:64537ms step_avg:92.59ms +step:698/1670 train_time:64630ms step_avg:92.59ms +step:699/1670 train_time:64722ms step_avg:92.59ms +step:700/1670 train_time:64814ms step_avg:92.59ms +step:701/1670 train_time:64907ms step_avg:92.59ms +step:702/1670 train_time:65000ms step_avg:92.59ms +step:703/1670 train_time:65092ms step_avg:92.59ms +step:704/1670 train_time:65185ms step_avg:92.59ms +step:705/1670 train_time:65278ms step_avg:92.59ms +step:706/1670 train_time:65371ms step_avg:92.59ms +step:707/1670 train_time:65464ms step_avg:92.59ms +step:708/1670 train_time:65556ms step_avg:92.59ms +step:709/1670 train_time:65647ms step_avg:92.59ms +step:710/1670 train_time:65739ms step_avg:92.59ms +step:711/1670 train_time:65832ms step_avg:92.59ms +step:712/1670 train_time:65925ms step_avg:92.59ms +step:713/1670 train_time:66016ms step_avg:92.59ms +step:714/1670 train_time:66108ms step_avg:92.59ms +step:715/1670 train_time:66201ms 
step_avg:92.59ms +step:716/1670 train_time:66294ms step_avg:92.59ms +step:717/1670 train_time:66386ms step_avg:92.59ms +step:718/1670 train_time:66478ms step_avg:92.59ms +step:719/1670 train_time:66572ms step_avg:92.59ms +step:720/1670 train_time:66664ms step_avg:92.59ms +step:721/1670 train_time:66756ms step_avg:92.59ms +step:722/1670 train_time:66849ms step_avg:92.59ms +step:723/1670 train_time:66939ms step_avg:92.59ms +step:724/1670 train_time:67032ms step_avg:92.59ms +step:725/1670 train_time:67124ms step_avg:92.58ms +step:726/1670 train_time:67215ms step_avg:92.58ms +step:727/1670 train_time:67309ms step_avg:92.58ms +step:728/1670 train_time:67401ms step_avg:92.58ms +step:729/1670 train_time:67494ms step_avg:92.58ms +step:730/1670 train_time:67587ms step_avg:92.58ms +step:731/1670 train_time:67678ms step_avg:92.58ms +step:732/1670 train_time:67771ms step_avg:92.58ms +step:733/1670 train_time:67863ms step_avg:92.58ms +step:734/1670 train_time:67955ms step_avg:92.58ms +step:735/1670 train_time:68047ms step_avg:92.58ms +step:736/1670 train_time:68139ms step_avg:92.58ms +step:737/1670 train_time:68233ms step_avg:92.58ms +step:738/1670 train_time:68325ms step_avg:92.58ms +step:739/1670 train_time:68417ms step_avg:92.58ms +step:740/1670 train_time:68511ms step_avg:92.58ms +step:741/1670 train_time:68606ms step_avg:92.59ms +step:742/1670 train_time:68696ms step_avg:92.58ms +step:743/1670 train_time:68788ms step_avg:92.58ms +step:744/1670 train_time:68880ms step_avg:92.58ms +step:745/1670 train_time:68973ms step_avg:92.58ms +step:746/1670 train_time:69067ms step_avg:92.58ms +step:747/1670 train_time:69158ms step_avg:92.58ms +step:748/1670 train_time:69251ms step_avg:92.58ms +step:749/1670 train_time:69343ms step_avg:92.58ms +step:750/1670 train_time:69436ms step_avg:92.58ms +step:750/1670 val_loss:3.5597 train_time:69529ms step_avg:92.70ms +step:751/1670 train_time:69549ms step_avg:92.61ms +step:752/1670 train_time:69622ms step_avg:92.58ms +step:753/1670 train_time:69714ms step_avg:92.58ms +step:754/1670 train_time:69808ms step_avg:92.58ms +step:755/1670 train_time:69900ms step_avg:92.58ms +step:756/1670 train_time:69991ms step_avg:92.58ms +step:757/1670 train_time:70084ms step_avg:92.58ms +step:758/1670 train_time:70175ms step_avg:92.58ms +step:759/1670 train_time:70268ms step_avg:92.58ms +step:760/1670 train_time:70360ms step_avg:92.58ms +step:761/1670 train_time:70453ms step_avg:92.58ms +step:762/1670 train_time:70548ms step_avg:92.58ms +step:763/1670 train_time:70641ms step_avg:92.58ms +step:764/1670 train_time:70734ms step_avg:92.58ms +step:765/1670 train_time:70827ms step_avg:92.58ms +step:766/1670 train_time:70920ms step_avg:92.59ms +step:767/1670 train_time:71013ms step_avg:92.58ms +step:768/1670 train_time:71105ms step_avg:92.58ms +step:769/1670 train_time:71197ms step_avg:92.58ms +step:770/1670 train_time:71290ms step_avg:92.58ms +step:771/1670 train_time:71382ms step_avg:92.58ms +step:772/1670 train_time:71475ms step_avg:92.58ms +step:773/1670 train_time:71567ms step_avg:92.58ms +step:774/1670 train_time:71660ms step_avg:92.58ms +step:775/1670 train_time:71753ms step_avg:92.58ms +step:776/1670 train_time:71845ms step_avg:92.58ms +step:777/1670 train_time:71936ms step_avg:92.58ms +step:778/1670 train_time:72029ms step_avg:92.58ms +step:779/1670 train_time:72121ms step_avg:92.58ms +step:780/1670 train_time:72213ms step_avg:92.58ms +step:781/1670 train_time:72305ms step_avg:92.58ms +step:782/1670 train_time:72397ms step_avg:92.58ms +step:783/1670 train_time:72491ms step_avg:92.58ms 
+step:784/1670 train_time:72584ms step_avg:92.58ms +step:785/1670 train_time:72676ms step_avg:92.58ms +step:786/1670 train_time:72768ms step_avg:92.58ms +step:787/1670 train_time:72860ms step_avg:92.58ms +step:788/1670 train_time:72952ms step_avg:92.58ms +step:789/1670 train_time:73045ms step_avg:92.58ms +step:790/1670 train_time:73137ms step_avg:92.58ms +step:791/1670 train_time:73229ms step_avg:92.58ms +step:792/1670 train_time:73322ms step_avg:92.58ms +step:793/1670 train_time:73415ms step_avg:92.58ms +step:794/1670 train_time:73509ms step_avg:92.58ms +step:795/1670 train_time:73603ms step_avg:92.58ms +step:796/1670 train_time:73695ms step_avg:92.58ms +step:797/1670 train_time:73787ms step_avg:92.58ms +step:798/1670 train_time:73880ms step_avg:92.58ms +step:799/1670 train_time:73972ms step_avg:92.58ms +step:800/1670 train_time:74065ms step_avg:92.58ms +step:801/1670 train_time:74157ms step_avg:92.58ms +step:802/1670 train_time:74250ms step_avg:92.58ms +step:803/1670 train_time:74342ms step_avg:92.58ms +step:804/1670 train_time:74434ms step_avg:92.58ms +step:805/1670 train_time:74527ms step_avg:92.58ms +step:806/1670 train_time:74620ms step_avg:92.58ms +step:807/1670 train_time:74713ms step_avg:92.58ms +step:808/1670 train_time:74806ms step_avg:92.58ms +step:809/1670 train_time:74898ms step_avg:92.58ms +step:810/1670 train_time:74991ms step_avg:92.58ms +step:811/1670 train_time:75083ms step_avg:92.58ms +step:812/1670 train_time:75175ms step_avg:92.58ms +step:813/1670 train_time:75268ms step_avg:92.58ms +step:814/1670 train_time:75361ms step_avg:92.58ms +step:815/1670 train_time:75453ms step_avg:92.58ms +step:816/1670 train_time:75546ms step_avg:92.58ms +step:817/1670 train_time:75637ms step_avg:92.58ms +step:818/1670 train_time:75731ms step_avg:92.58ms +step:819/1670 train_time:75823ms step_avg:92.58ms +step:820/1670 train_time:75915ms step_avg:92.58ms +step:821/1670 train_time:76008ms step_avg:92.58ms +step:822/1670 train_time:76100ms step_avg:92.58ms +step:823/1670 train_time:76193ms step_avg:92.58ms +step:824/1670 train_time:76286ms step_avg:92.58ms +step:825/1670 train_time:76380ms step_avg:92.58ms +step:826/1670 train_time:76473ms step_avg:92.58ms +step:827/1670 train_time:76566ms step_avg:92.58ms +step:828/1670 train_time:76658ms step_avg:92.58ms +step:829/1670 train_time:76751ms step_avg:92.58ms +step:830/1670 train_time:76844ms step_avg:92.58ms +step:831/1670 train_time:76935ms step_avg:92.58ms +step:832/1670 train_time:77028ms step_avg:92.58ms +step:833/1670 train_time:77120ms step_avg:92.58ms +step:834/1670 train_time:77213ms step_avg:92.58ms +step:835/1670 train_time:77306ms step_avg:92.58ms +step:836/1670 train_time:77398ms step_avg:92.58ms +step:837/1670 train_time:77491ms step_avg:92.58ms +step:838/1670 train_time:77584ms step_avg:92.58ms +step:839/1670 train_time:77676ms step_avg:92.58ms +step:840/1670 train_time:77769ms step_avg:92.58ms +step:841/1670 train_time:77862ms step_avg:92.58ms +step:842/1670 train_time:77953ms step_avg:92.58ms +step:843/1670 train_time:78046ms step_avg:92.58ms +step:844/1670 train_time:78138ms step_avg:92.58ms +step:845/1670 train_time:78232ms step_avg:92.58ms +step:846/1670 train_time:78325ms step_avg:92.58ms +step:847/1670 train_time:78416ms step_avg:92.58ms +step:848/1670 train_time:78510ms step_avg:92.58ms +step:849/1670 train_time:78602ms step_avg:92.58ms +step:850/1670 train_time:78694ms step_avg:92.58ms +step:851/1670 train_time:78946ms step_avg:92.77ms +step:852/1670 train_time:79015ms step_avg:92.74ms +step:853/1670 train_time:79106ms 
step_avg:92.74ms +step:854/1670 train_time:79198ms step_avg:92.74ms +step:855/1670 train_time:79290ms step_avg:92.74ms +step:856/1670 train_time:79381ms step_avg:92.73ms +step:857/1670 train_time:79472ms step_avg:92.73ms +step:858/1670 train_time:79564ms step_avg:92.73ms +step:859/1670 train_time:79655ms step_avg:92.73ms +step:860/1670 train_time:79746ms step_avg:92.73ms +step:861/1670 train_time:79842ms step_avg:92.73ms +step:862/1670 train_time:79939ms step_avg:92.74ms +step:863/1670 train_time:80033ms step_avg:92.74ms +step:864/1670 train_time:80125ms step_avg:92.74ms +step:865/1670 train_time:80216ms step_avg:92.74ms +step:866/1670 train_time:80309ms step_avg:92.74ms +step:867/1670 train_time:80400ms step_avg:92.73ms +step:868/1670 train_time:80492ms step_avg:92.73ms +step:869/1670 train_time:80583ms step_avg:92.73ms +step:870/1670 train_time:80673ms step_avg:92.73ms +step:871/1670 train_time:80766ms step_avg:92.73ms +step:872/1670 train_time:80860ms step_avg:92.73ms +step:873/1670 train_time:80954ms step_avg:92.73ms +step:874/1670 train_time:81049ms step_avg:92.73ms +step:875/1670 train_time:81141ms step_avg:92.73ms +step:875/1670 val_loss:3.5163 train_time:81234ms step_avg:92.84ms +step:876/1670 train_time:81254ms step_avg:92.76ms +step:877/1670 train_time:81331ms step_avg:92.74ms +step:878/1670 train_time:81423ms step_avg:92.74ms +step:879/1670 train_time:81514ms step_avg:92.73ms +step:880/1670 train_time:81605ms step_avg:92.73ms +step:881/1670 train_time:81696ms step_avg:92.73ms +step:882/1670 train_time:81788ms step_avg:92.73ms +step:883/1670 train_time:81879ms step_avg:92.73ms +step:884/1670 train_time:81972ms step_avg:92.73ms +step:885/1670 train_time:82067ms step_avg:92.73ms +step:886/1670 train_time:82160ms step_avg:92.73ms +step:887/1670 train_time:82255ms step_avg:92.73ms +step:888/1670 train_time:82349ms step_avg:92.74ms +step:889/1670 train_time:82442ms step_avg:92.74ms +step:890/1670 train_time:82534ms step_avg:92.73ms +step:891/1670 train_time:82626ms step_avg:92.73ms +step:892/1670 train_time:82717ms step_avg:92.73ms +step:893/1670 train_time:82809ms step_avg:92.73ms +step:894/1670 train_time:82900ms step_avg:92.73ms +step:895/1670 train_time:82992ms step_avg:92.73ms +step:896/1670 train_time:83086ms step_avg:92.73ms +step:897/1670 train_time:83179ms step_avg:92.73ms +step:898/1670 train_time:83273ms step_avg:92.73ms +step:899/1670 train_time:83366ms step_avg:92.73ms +step:900/1670 train_time:83459ms step_avg:92.73ms +step:901/1670 train_time:83551ms step_avg:92.73ms +step:902/1670 train_time:83645ms step_avg:92.73ms +step:903/1670 train_time:83736ms step_avg:92.73ms +step:904/1670 train_time:83828ms step_avg:92.73ms +step:905/1670 train_time:83920ms step_avg:92.73ms +step:906/1670 train_time:84013ms step_avg:92.73ms +step:907/1670 train_time:84105ms step_avg:92.73ms +step:908/1670 train_time:84198ms step_avg:92.73ms +step:909/1670 train_time:84292ms step_avg:92.73ms +step:910/1670 train_time:84385ms step_avg:92.73ms +step:911/1670 train_time:84476ms step_avg:92.73ms +step:912/1670 train_time:84569ms step_avg:92.73ms +step:913/1670 train_time:84661ms step_avg:92.73ms +step:914/1670 train_time:84753ms step_avg:92.73ms +step:915/1670 train_time:84845ms step_avg:92.73ms +step:916/1670 train_time:84936ms step_avg:92.73ms +step:917/1670 train_time:85028ms step_avg:92.72ms +step:918/1670 train_time:85120ms step_avg:92.72ms +step:919/1670 train_time:85213ms step_avg:92.72ms +step:920/1670 train_time:85306ms step_avg:92.72ms +step:921/1670 train_time:85398ms step_avg:92.72ms 
+step:922/1670 train_time:85493ms step_avg:92.73ms +step:923/1670 train_time:85585ms step_avg:92.73ms +step:924/1670 train_time:85678ms step_avg:92.73ms +step:925/1670 train_time:85770ms step_avg:92.72ms +step:926/1670 train_time:85863ms step_avg:92.72ms +step:927/1670 train_time:85955ms step_avg:92.72ms +step:928/1670 train_time:86047ms step_avg:92.72ms +step:929/1670 train_time:86140ms step_avg:92.72ms +step:930/1670 train_time:86233ms step_avg:92.72ms +step:931/1670 train_time:86326ms step_avg:92.72ms +step:932/1670 train_time:86418ms step_avg:92.72ms +step:933/1670 train_time:86511ms step_avg:92.72ms +step:934/1670 train_time:86604ms step_avg:92.72ms +step:935/1670 train_time:86697ms step_avg:92.72ms +step:936/1670 train_time:86790ms step_avg:92.72ms +step:937/1670 train_time:86883ms step_avg:92.72ms +step:938/1670 train_time:86976ms step_avg:92.72ms +step:939/1670 train_time:87069ms step_avg:92.72ms +step:940/1670 train_time:87161ms step_avg:92.72ms +step:941/1670 train_time:87254ms step_avg:92.72ms +step:942/1670 train_time:87346ms step_avg:92.72ms +step:943/1670 train_time:87438ms step_avg:92.72ms +step:944/1670 train_time:87530ms step_avg:92.72ms +step:945/1670 train_time:87623ms step_avg:92.72ms +step:946/1670 train_time:87715ms step_avg:92.72ms +step:947/1670 train_time:87807ms step_avg:92.72ms +step:948/1670 train_time:87899ms step_avg:92.72ms +step:949/1670 train_time:87992ms step_avg:92.72ms +step:950/1670 train_time:88085ms step_avg:92.72ms +step:951/1670 train_time:88178ms step_avg:92.72ms +step:952/1670 train_time:88271ms step_avg:92.72ms +step:953/1670 train_time:88363ms step_avg:92.72ms +step:954/1670 train_time:88456ms step_avg:92.72ms +step:955/1670 train_time:88548ms step_avg:92.72ms +step:956/1670 train_time:88640ms step_avg:92.72ms +step:957/1670 train_time:88733ms step_avg:92.72ms +step:958/1670 train_time:88825ms step_avg:92.72ms +step:959/1670 train_time:88917ms step_avg:92.72ms +step:960/1670 train_time:89010ms step_avg:92.72ms +step:961/1670 train_time:89103ms step_avg:92.72ms +step:962/1670 train_time:89195ms step_avg:92.72ms +step:963/1670 train_time:89288ms step_avg:92.72ms +step:964/1670 train_time:89380ms step_avg:92.72ms +step:965/1670 train_time:89473ms step_avg:92.72ms +step:966/1670 train_time:89566ms step_avg:92.72ms +step:967/1670 train_time:89658ms step_avg:92.72ms +step:968/1670 train_time:89751ms step_avg:92.72ms +step:969/1670 train_time:89843ms step_avg:92.72ms +step:970/1670 train_time:89935ms step_avg:92.72ms +step:971/1670 train_time:90028ms step_avg:92.72ms +step:972/1670 train_time:90120ms step_avg:92.72ms +step:973/1670 train_time:90213ms step_avg:92.72ms +step:974/1670 train_time:90305ms step_avg:92.72ms +step:975/1670 train_time:90397ms step_avg:92.71ms +step:976/1670 train_time:90490ms step_avg:92.71ms +step:977/1670 train_time:90582ms step_avg:92.71ms +step:978/1670 train_time:90675ms step_avg:92.72ms +step:979/1670 train_time:90768ms step_avg:92.71ms +step:980/1670 train_time:90860ms step_avg:92.71ms +step:981/1670 train_time:90953ms step_avg:92.71ms +step:982/1670 train_time:91046ms step_avg:92.71ms +step:983/1670 train_time:91138ms step_avg:92.71ms +step:984/1670 train_time:91231ms step_avg:92.71ms +step:985/1670 train_time:91323ms step_avg:92.71ms +step:986/1670 train_time:91415ms step_avg:92.71ms +step:987/1670 train_time:91507ms step_avg:92.71ms +step:988/1670 train_time:91599ms step_avg:92.71ms +step:989/1670 train_time:91692ms step_avg:92.71ms +step:990/1670 train_time:91784ms step_avg:92.71ms +step:991/1670 train_time:91876ms 
step_avg:92.71ms +step:992/1670 train_time:91970ms step_avg:92.71ms +step:993/1670 train_time:92062ms step_avg:92.71ms +step:994/1670 train_time:92155ms step_avg:92.71ms +step:995/1670 train_time:92247ms step_avg:92.71ms +step:996/1670 train_time:92339ms step_avg:92.71ms +step:997/1670 train_time:92431ms step_avg:92.71ms +step:998/1670 train_time:92523ms step_avg:92.71ms +step:999/1670 train_time:92616ms step_avg:92.71ms +step:1000/1670 train_time:92708ms step_avg:92.71ms +step:1000/1670 val_loss:3.4665 train_time:92799ms step_avg:92.80ms +step:1001/1670 train_time:92819ms step_avg:92.73ms +step:1002/1670 train_time:92892ms step_avg:92.71ms +step:1003/1670 train_time:92985ms step_avg:92.71ms +step:1004/1670 train_time:93077ms step_avg:92.71ms +step:1005/1670 train_time:93170ms step_avg:92.71ms +step:1006/1670 train_time:93261ms step_avg:92.71ms +step:1007/1670 train_time:93353ms step_avg:92.70ms +step:1008/1670 train_time:93445ms step_avg:92.70ms +step:1009/1670 train_time:93537ms step_avg:92.70ms +step:1010/1670 train_time:93630ms step_avg:92.70ms +step:1011/1670 train_time:93723ms step_avg:92.70ms +step:1012/1670 train_time:93817ms step_avg:92.70ms +step:1013/1670 train_time:93910ms step_avg:92.70ms +step:1014/1670 train_time:94002ms step_avg:92.70ms +step:1015/1670 train_time:94095ms step_avg:92.70ms +step:1016/1670 train_time:94189ms step_avg:92.71ms +step:1017/1670 train_time:94281ms step_avg:92.70ms +step:1018/1670 train_time:94373ms step_avg:92.70ms +step:1019/1670 train_time:94464ms step_avg:92.70ms +step:1020/1670 train_time:94556ms step_avg:92.70ms +step:1021/1670 train_time:94649ms step_avg:92.70ms +step:1022/1670 train_time:94742ms step_avg:92.70ms +step:1023/1670 train_time:94835ms step_avg:92.70ms +step:1024/1670 train_time:94928ms step_avg:92.70ms +step:1025/1670 train_time:95021ms step_avg:92.70ms +step:1026/1670 train_time:95115ms step_avg:92.70ms +step:1027/1670 train_time:95208ms step_avg:92.70ms +step:1028/1670 train_time:95299ms step_avg:92.70ms +step:1029/1670 train_time:95391ms step_avg:92.70ms +step:1030/1670 train_time:95483ms step_avg:92.70ms +step:1031/1670 train_time:95576ms step_avg:92.70ms +step:1032/1670 train_time:95669ms step_avg:92.70ms +step:1033/1670 train_time:95761ms step_avg:92.70ms +step:1034/1670 train_time:95855ms step_avg:92.70ms +step:1035/1670 train_time:95948ms step_avg:92.70ms +step:1036/1670 train_time:96041ms step_avg:92.70ms +step:1037/1670 train_time:96134ms step_avg:92.70ms +step:1038/1670 train_time:96227ms step_avg:92.70ms +step:1039/1670 train_time:96319ms step_avg:92.70ms +step:1040/1670 train_time:96411ms step_avg:92.70ms +step:1041/1670 train_time:96503ms step_avg:92.70ms +step:1042/1670 train_time:96596ms step_avg:92.70ms +step:1043/1670 train_time:96688ms step_avg:92.70ms +step:1044/1670 train_time:96780ms step_avg:92.70ms +step:1045/1670 train_time:96873ms step_avg:92.70ms +step:1046/1670 train_time:96965ms step_avg:92.70ms +step:1047/1670 train_time:97059ms step_avg:92.70ms +step:1048/1670 train_time:97152ms step_avg:92.70ms +step:1049/1670 train_time:97244ms step_avg:92.70ms +step:1050/1670 train_time:97336ms step_avg:92.70ms +step:1051/1670 train_time:97429ms step_avg:92.70ms +step:1052/1670 train_time:97521ms step_avg:92.70ms +step:1053/1670 train_time:97614ms step_avg:92.70ms +step:1054/1670 train_time:97706ms step_avg:92.70ms +step:1055/1670 train_time:97798ms step_avg:92.70ms +step:1056/1670 train_time:97890ms step_avg:92.70ms +step:1057/1670 train_time:97983ms step_avg:92.70ms +step:1058/1670 train_time:98078ms 
step_avg:92.70ms +step:1059/1670 train_time:98169ms step_avg:92.70ms +step:1060/1670 train_time:98261ms step_avg:92.70ms +step:1061/1670 train_time:98354ms step_avg:92.70ms +step:1062/1670 train_time:98605ms step_avg:92.85ms +step:1063/1670 train_time:98676ms step_avg:92.83ms +step:1064/1670 train_time:98768ms step_avg:92.83ms +step:1065/1670 train_time:98858ms step_avg:92.82ms +step:1066/1670 train_time:98950ms step_avg:92.82ms +step:1067/1670 train_time:99041ms step_avg:92.82ms +step:1068/1670 train_time:99133ms step_avg:92.82ms +step:1069/1670 train_time:99225ms step_avg:92.82ms +step:1070/1670 train_time:99316ms step_avg:92.82ms +step:1071/1670 train_time:99408ms step_avg:92.82ms +step:1072/1670 train_time:99505ms step_avg:92.82ms +step:1073/1670 train_time:99601ms step_avg:92.83ms +step:1074/1670 train_time:99695ms step_avg:92.83ms +step:1075/1670 train_time:99787ms step_avg:92.82ms +step:1076/1670 train_time:99878ms step_avg:92.82ms +step:1077/1670 train_time:99970ms step_avg:92.82ms +step:1078/1670 train_time:100060ms step_avg:92.82ms +step:1079/1670 train_time:100153ms step_avg:92.82ms +step:1080/1670 train_time:100245ms step_avg:92.82ms +step:1081/1670 train_time:100336ms step_avg:92.82ms +step:1082/1670 train_time:100430ms step_avg:92.82ms +step:1083/1670 train_time:100524ms step_avg:92.82ms +step:1084/1670 train_time:100618ms step_avg:92.82ms +step:1085/1670 train_time:100713ms step_avg:92.82ms +step:1086/1670 train_time:100807ms step_avg:92.82ms +step:1087/1670 train_time:100898ms step_avg:92.82ms +step:1088/1670 train_time:100990ms step_avg:92.82ms +step:1089/1670 train_time:101081ms step_avg:92.82ms +step:1090/1670 train_time:101173ms step_avg:92.82ms +step:1091/1670 train_time:101264ms step_avg:92.82ms +step:1092/1670 train_time:101357ms step_avg:92.82ms +step:1093/1670 train_time:101451ms step_avg:92.82ms +step:1094/1670 train_time:101543ms step_avg:92.82ms +step:1095/1670 train_time:101638ms step_avg:92.82ms +step:1096/1670 train_time:101731ms step_avg:92.82ms +step:1097/1670 train_time:101823ms step_avg:92.82ms +step:1098/1670 train_time:101916ms step_avg:92.82ms +step:1099/1670 train_time:102007ms step_avg:92.82ms +step:1100/1670 train_time:102099ms step_avg:92.82ms +step:1101/1670 train_time:102191ms step_avg:92.82ms +step:1102/1670 train_time:102282ms step_avg:92.81ms +step:1103/1670 train_time:102375ms step_avg:92.81ms +step:1104/1670 train_time:102467ms step_avg:92.81ms +step:1105/1670 train_time:102560ms step_avg:92.81ms +step:1106/1670 train_time:102655ms step_avg:92.82ms +step:1107/1670 train_time:102749ms step_avg:92.82ms +step:1108/1670 train_time:102842ms step_avg:92.82ms +step:1109/1670 train_time:102935ms step_avg:92.82ms +step:1110/1670 train_time:103028ms step_avg:92.82ms +step:1111/1670 train_time:103121ms step_avg:92.82ms +step:1112/1670 train_time:103213ms step_avg:92.82ms +step:1113/1670 train_time:103304ms step_avg:92.82ms +step:1114/1670 train_time:103397ms step_avg:92.82ms +step:1115/1670 train_time:103685ms step_avg:92.99ms +step:1116/1670 train_time:103754ms step_avg:92.97ms +step:1117/1670 train_time:103845ms step_avg:92.97ms +step:1118/1670 train_time:103936ms step_avg:92.97ms +step:1119/1670 train_time:104028ms step_avg:92.97ms +step:1120/1670 train_time:104120ms step_avg:92.96ms +step:1121/1670 train_time:104212ms step_avg:92.96ms +step:1122/1670 train_time:104304ms step_avg:92.96ms +step:1123/1670 train_time:104395ms step_avg:92.96ms +step:1124/1670 train_time:104487ms step_avg:92.96ms +step:1125/1670 train_time:104584ms step_avg:92.96ms 
+step:1125/1670 val_loss:3.4132 train_time:104682ms step_avg:93.05ms +step:1126/1670 train_time:104704ms step_avg:92.99ms +step:1127/1670 train_time:104782ms step_avg:92.97ms +step:1128/1670 train_time:104882ms step_avg:92.98ms +step:1129/1670 train_time:104977ms step_avg:92.98ms +step:1130/1670 train_time:105071ms step_avg:92.98ms +step:1131/1670 train_time:105163ms step_avg:92.98ms +step:1132/1670 train_time:105255ms step_avg:92.98ms +step:1133/1670 train_time:105347ms step_avg:92.98ms +step:1134/1670 train_time:105438ms step_avg:92.98ms +step:1135/1670 train_time:105530ms step_avg:92.98ms +step:1136/1670 train_time:105622ms step_avg:92.98ms +step:1137/1670 train_time:105717ms step_avg:92.98ms +step:1138/1670 train_time:105812ms step_avg:92.98ms +step:1139/1670 train_time:105908ms step_avg:92.98ms +step:1140/1670 train_time:106002ms step_avg:92.98ms +step:1141/1670 train_time:106095ms step_avg:92.98ms +step:1142/1670 train_time:106187ms step_avg:92.98ms +step:1143/1670 train_time:106279ms step_avg:92.98ms +step:1144/1670 train_time:106372ms step_avg:92.98ms +step:1145/1670 train_time:106465ms step_avg:92.98ms +step:1146/1670 train_time:106557ms step_avg:92.98ms +step:1147/1670 train_time:106649ms step_avg:92.98ms +step:1148/1670 train_time:106744ms step_avg:92.98ms +step:1149/1670 train_time:106839ms step_avg:92.98ms +step:1150/1670 train_time:106933ms step_avg:92.99ms +step:1151/1670 train_time:107027ms step_avg:92.99ms +step:1152/1670 train_time:107120ms step_avg:92.99ms +step:1153/1670 train_time:107212ms step_avg:92.99ms +step:1154/1670 train_time:107305ms step_avg:92.99ms +step:1155/1670 train_time:107398ms step_avg:92.98ms +step:1156/1670 train_time:107489ms step_avg:92.98ms +step:1157/1670 train_time:107582ms step_avg:92.98ms +step:1158/1670 train_time:107676ms step_avg:92.98ms +step:1159/1670 train_time:107771ms step_avg:92.99ms +step:1160/1670 train_time:107866ms step_avg:92.99ms +step:1161/1670 train_time:107959ms step_avg:92.99ms +step:1162/1670 train_time:108052ms step_avg:92.99ms +step:1163/1670 train_time:108146ms step_avg:92.99ms +step:1164/1670 train_time:108239ms step_avg:92.99ms +step:1165/1670 train_time:108331ms step_avg:92.99ms +step:1166/1670 train_time:108424ms step_avg:92.99ms +step:1167/1670 train_time:108517ms step_avg:92.99ms +step:1168/1670 train_time:108610ms step_avg:92.99ms +step:1169/1670 train_time:108704ms step_avg:92.99ms +step:1170/1670 train_time:108798ms step_avg:92.99ms +step:1171/1670 train_time:108891ms step_avg:92.99ms +step:1172/1670 train_time:108985ms step_avg:92.99ms +step:1173/1670 train_time:109080ms step_avg:92.99ms +step:1174/1670 train_time:109172ms step_avg:92.99ms +step:1175/1670 train_time:109266ms step_avg:92.99ms +step:1176/1670 train_time:109359ms step_avg:92.99ms +step:1177/1670 train_time:109451ms step_avg:92.99ms +step:1178/1670 train_time:109543ms step_avg:92.99ms +step:1179/1670 train_time:109636ms step_avg:92.99ms +step:1180/1670 train_time:109731ms step_avg:92.99ms +step:1181/1670 train_time:109826ms step_avg:92.99ms +step:1182/1670 train_time:109920ms step_avg:92.99ms +step:1183/1670 train_time:110013ms step_avg:93.00ms +step:1184/1670 train_time:110108ms step_avg:93.00ms +step:1185/1670 train_time:110201ms step_avg:93.00ms +step:1186/1670 train_time:110293ms step_avg:93.00ms +step:1187/1670 train_time:110386ms step_avg:93.00ms +step:1188/1670 train_time:110479ms step_avg:93.00ms +step:1189/1670 train_time:110572ms step_avg:93.00ms +step:1190/1670 train_time:110666ms step_avg:93.00ms +step:1191/1670 train_time:110759ms 
step_avg:93.00ms +step:1192/1670 train_time:110852ms step_avg:93.00ms +step:1193/1670 train_time:110947ms step_avg:93.00ms +step:1194/1670 train_time:111040ms step_avg:93.00ms +step:1195/1670 train_time:111133ms step_avg:93.00ms +step:1196/1670 train_time:111225ms step_avg:93.00ms +step:1197/1670 train_time:111319ms step_avg:93.00ms +step:1198/1670 train_time:111411ms step_avg:93.00ms +step:1199/1670 train_time:111504ms step_avg:93.00ms +step:1200/1670 train_time:111597ms step_avg:93.00ms +step:1201/1670 train_time:111690ms step_avg:93.00ms +step:1202/1670 train_time:111785ms step_avg:93.00ms +step:1203/1670 train_time:111879ms step_avg:93.00ms +step:1204/1670 train_time:111972ms step_avg:93.00ms +step:1205/1670 train_time:112067ms step_avg:93.00ms +step:1206/1670 train_time:112160ms step_avg:93.00ms +step:1207/1670 train_time:112253ms step_avg:93.00ms +step:1208/1670 train_time:112348ms step_avg:93.00ms +step:1209/1670 train_time:112441ms step_avg:93.00ms +step:1210/1670 train_time:112533ms step_avg:93.00ms +step:1211/1670 train_time:112626ms step_avg:93.00ms +step:1212/1670 train_time:112719ms step_avg:93.00ms +step:1213/1670 train_time:112812ms step_avg:93.00ms +step:1214/1670 train_time:112906ms step_avg:93.00ms +step:1215/1670 train_time:112999ms step_avg:93.00ms +step:1216/1670 train_time:113092ms step_avg:93.00ms +step:1217/1670 train_time:113186ms step_avg:93.00ms +step:1218/1670 train_time:113279ms step_avg:93.00ms +step:1219/1670 train_time:113373ms step_avg:93.00ms +step:1220/1670 train_time:113466ms step_avg:93.01ms +step:1221/1670 train_time:113560ms step_avg:93.01ms +step:1222/1670 train_time:113652ms step_avg:93.01ms +step:1223/1670 train_time:113746ms step_avg:93.01ms +step:1224/1670 train_time:113840ms step_avg:93.01ms +step:1225/1670 train_time:113932ms step_avg:93.01ms +step:1226/1670 train_time:114027ms step_avg:93.01ms +step:1227/1670 train_time:114120ms step_avg:93.01ms +step:1228/1670 train_time:114212ms step_avg:93.01ms +step:1229/1670 train_time:114305ms step_avg:93.01ms +step:1230/1670 train_time:114398ms step_avg:93.01ms +step:1231/1670 train_time:114491ms step_avg:93.01ms +step:1232/1670 train_time:114585ms step_avg:93.01ms +step:1233/1670 train_time:114677ms step_avg:93.01ms +step:1234/1670 train_time:114772ms step_avg:93.01ms +step:1235/1670 train_time:114865ms step_avg:93.01ms +step:1236/1670 train_time:114959ms step_avg:93.01ms +step:1237/1670 train_time:115052ms step_avg:93.01ms +step:1238/1670 train_time:115146ms step_avg:93.01ms +step:1239/1670 train_time:115238ms step_avg:93.01ms +step:1240/1670 train_time:115331ms step_avg:93.01ms +step:1241/1670 train_time:115424ms step_avg:93.01ms +step:1242/1670 train_time:115517ms step_avg:93.01ms +step:1243/1670 train_time:115611ms step_avg:93.01ms +step:1244/1670 train_time:115705ms step_avg:93.01ms +step:1245/1670 train_time:115798ms step_avg:93.01ms +step:1246/1670 train_time:115891ms step_avg:93.01ms +step:1247/1670 train_time:115984ms step_avg:93.01ms +step:1248/1670 train_time:116077ms step_avg:93.01ms +step:1249/1670 train_time:116170ms step_avg:93.01ms +step:1250/1670 train_time:116263ms step_avg:93.01ms +step:1250/1670 val_loss:3.3746 train_time:116355ms step_avg:93.08ms +step:1251/1670 train_time:116375ms step_avg:93.03ms +step:1252/1670 train_time:116450ms step_avg:93.01ms +step:1253/1670 train_time:116544ms step_avg:93.01ms +step:1254/1670 train_time:116635ms step_avg:93.01ms +step:1255/1670 train_time:116728ms step_avg:93.01ms +step:1256/1670 train_time:116820ms step_avg:93.01ms +step:1257/1670 
train_time:116913ms step_avg:93.01ms +step:1258/1670 train_time:117005ms step_avg:93.01ms +step:1259/1670 train_time:117097ms step_avg:93.01ms +step:1260/1670 train_time:117191ms step_avg:93.01ms +step:1261/1670 train_time:117288ms step_avg:93.01ms +step:1262/1670 train_time:117383ms step_avg:93.01ms +step:1263/1670 train_time:117477ms step_avg:93.01ms +step:1264/1670 train_time:117570ms step_avg:93.01ms +step:1265/1670 train_time:117664ms step_avg:93.01ms +step:1266/1670 train_time:117756ms step_avg:93.01ms +step:1267/1670 train_time:117851ms step_avg:93.02ms +step:1268/1670 train_time:117944ms step_avg:93.02ms +step:1269/1670 train_time:118036ms step_avg:93.01ms +step:1270/1670 train_time:118128ms step_avg:93.01ms +step:1271/1670 train_time:118222ms step_avg:93.02ms +step:1272/1670 train_time:118316ms step_avg:93.02ms +step:1273/1670 train_time:118410ms step_avg:93.02ms +step:1274/1670 train_time:118649ms step_avg:93.13ms +step:1275/1670 train_time:118728ms step_avg:93.12ms +step:1276/1670 train_time:118819ms step_avg:93.12ms +step:1277/1670 train_time:118911ms step_avg:93.12ms +step:1278/1670 train_time:119002ms step_avg:93.12ms +step:1279/1670 train_time:119094ms step_avg:93.11ms +step:1280/1670 train_time:119186ms step_avg:93.11ms +step:1281/1670 train_time:119278ms step_avg:93.11ms +step:1282/1670 train_time:119371ms step_avg:93.11ms +step:1283/1670 train_time:119463ms step_avg:93.11ms +step:1284/1670 train_time:119561ms step_avg:93.12ms +step:1285/1670 train_time:119658ms step_avg:93.12ms +step:1286/1670 train_time:119753ms step_avg:93.12ms +step:1287/1670 train_time:119846ms step_avg:93.12ms +step:1288/1670 train_time:119939ms step_avg:93.12ms +step:1289/1670 train_time:120033ms step_avg:93.12ms +step:1290/1670 train_time:120126ms step_avg:93.12ms +step:1291/1670 train_time:120218ms step_avg:93.12ms +step:1292/1670 train_time:120310ms step_avg:93.12ms +step:1293/1670 train_time:120402ms step_avg:93.12ms +step:1294/1670 train_time:120498ms step_avg:93.12ms +step:1295/1670 train_time:120593ms step_avg:93.12ms +step:1296/1670 train_time:120688ms step_avg:93.12ms +step:1297/1670 train_time:120781ms step_avg:93.12ms +step:1298/1670 train_time:120875ms step_avg:93.12ms +step:1299/1670 train_time:120968ms step_avg:93.12ms +step:1300/1670 train_time:121060ms step_avg:93.12ms +step:1301/1670 train_time:121153ms step_avg:93.12ms +step:1302/1670 train_time:121245ms step_avg:93.12ms +step:1303/1670 train_time:121336ms step_avg:93.12ms +step:1304/1670 train_time:121430ms step_avg:93.12ms +step:1305/1670 train_time:121523ms step_avg:93.12ms +step:1306/1670 train_time:121617ms step_avg:93.12ms +step:1307/1670 train_time:121711ms step_avg:93.12ms +step:1308/1670 train_time:121804ms step_avg:93.12ms +step:1309/1670 train_time:121898ms step_avg:93.12ms +step:1310/1670 train_time:121991ms step_avg:93.12ms +step:1311/1670 train_time:122084ms step_avg:93.12ms +step:1312/1670 train_time:122177ms step_avg:93.12ms +step:1313/1670 train_time:122270ms step_avg:93.12ms +step:1314/1670 train_time:122362ms step_avg:93.12ms +step:1315/1670 train_time:122456ms step_avg:93.12ms +step:1316/1670 train_time:122551ms step_avg:93.12ms +step:1317/1670 train_time:122644ms step_avg:93.12ms +step:1318/1670 train_time:122738ms step_avg:93.12ms +step:1319/1670 train_time:122832ms step_avg:93.12ms +step:1320/1670 train_time:122924ms step_avg:93.12ms +step:1321/1670 train_time:123017ms step_avg:93.12ms +step:1322/1670 train_time:123109ms step_avg:93.12ms +step:1323/1670 train_time:123202ms step_avg:93.12ms +step:1324/1670 
train_time:123295ms step_avg:93.12ms +step:1325/1670 train_time:123389ms step_avg:93.12ms +step:1326/1670 train_time:123481ms step_avg:93.12ms +step:1327/1670 train_time:123576ms step_avg:93.12ms +step:1328/1670 train_time:123670ms step_avg:93.12ms +step:1329/1670 train_time:123763ms step_avg:93.12ms +step:1330/1670 train_time:123857ms step_avg:93.13ms +step:1331/1670 train_time:123952ms step_avg:93.13ms +step:1332/1670 train_time:124044ms step_avg:93.13ms +step:1333/1670 train_time:124136ms step_avg:93.13ms +step:1334/1670 train_time:124229ms step_avg:93.12ms +step:1335/1670 train_time:124321ms step_avg:93.12ms +step:1336/1670 train_time:124415ms step_avg:93.12ms +step:1337/1670 train_time:124508ms step_avg:93.13ms +step:1338/1670 train_time:124601ms step_avg:93.12ms +step:1339/1670 train_time:124695ms step_avg:93.13ms +step:1340/1670 train_time:124789ms step_avg:93.13ms +step:1341/1670 train_time:124881ms step_avg:93.13ms +step:1342/1670 train_time:124975ms step_avg:93.13ms +step:1343/1670 train_time:125068ms step_avg:93.13ms +step:1344/1670 train_time:125160ms step_avg:93.13ms +step:1345/1670 train_time:125254ms step_avg:93.13ms +step:1346/1670 train_time:125347ms step_avg:93.13ms +step:1347/1670 train_time:125441ms step_avg:93.13ms +step:1348/1670 train_time:125533ms step_avg:93.13ms +step:1349/1670 train_time:125626ms step_avg:93.13ms +step:1350/1670 train_time:125719ms step_avg:93.13ms +step:1351/1670 train_time:125813ms step_avg:93.13ms +step:1352/1670 train_time:125906ms step_avg:93.13ms +step:1353/1670 train_time:125999ms step_avg:93.13ms +step:1354/1670 train_time:126093ms step_avg:93.13ms +step:1355/1670 train_time:126186ms step_avg:93.13ms +step:1356/1670 train_time:126278ms step_avg:93.13ms +step:1357/1670 train_time:126372ms step_avg:93.13ms +step:1358/1670 train_time:126465ms step_avg:93.13ms +step:1359/1670 train_time:126558ms step_avg:93.13ms +step:1360/1670 train_time:126652ms step_avg:93.13ms +step:1361/1670 train_time:126746ms step_avg:93.13ms +step:1362/1670 train_time:126839ms step_avg:93.13ms +step:1363/1670 train_time:126933ms step_avg:93.13ms +step:1364/1670 train_time:127027ms step_avg:93.13ms +step:1365/1670 train_time:127121ms step_avg:93.13ms +step:1366/1670 train_time:127215ms step_avg:93.13ms +step:1367/1670 train_time:127308ms step_avg:93.13ms +step:1368/1670 train_time:127400ms step_avg:93.13ms +step:1369/1670 train_time:127494ms step_avg:93.13ms +step:1370/1670 train_time:127588ms step_avg:93.13ms +step:1371/1670 train_time:127681ms step_avg:93.13ms +step:1372/1670 train_time:127776ms step_avg:93.13ms +step:1373/1670 train_time:127869ms step_avg:93.13ms +step:1374/1670 train_time:127962ms step_avg:93.13ms +step:1375/1670 train_time:128056ms step_avg:93.13ms +step:1375/1670 val_loss:3.3403 train_time:128149ms step_avg:93.20ms +step:1376/1670 train_time:128169ms step_avg:93.15ms +step:1377/1670 train_time:128244ms step_avg:93.13ms +step:1378/1670 train_time:128337ms step_avg:93.13ms +step:1379/1670 train_time:128430ms step_avg:93.13ms +step:1380/1670 train_time:128522ms step_avg:93.13ms +step:1381/1670 train_time:128615ms step_avg:93.13ms +step:1382/1670 train_time:128708ms step_avg:93.13ms +step:1383/1670 train_time:128804ms step_avg:93.13ms +step:1384/1670 train_time:128898ms step_avg:93.13ms +step:1385/1670 train_time:128991ms step_avg:93.13ms +step:1386/1670 train_time:129085ms step_avg:93.14ms +step:1387/1670 train_time:129181ms step_avg:93.14ms +step:1388/1670 train_time:129274ms step_avg:93.14ms +step:1389/1670 train_time:129367ms step_avg:93.14ms 
+step:1390/1670 train_time:129459ms step_avg:93.14ms +step:1391/1670 train_time:129552ms step_avg:93.14ms +step:1392/1670 train_time:129644ms step_avg:93.14ms +step:1393/1670 train_time:129738ms step_avg:93.14ms +step:1394/1670 train_time:129830ms step_avg:93.13ms +step:1395/1670 train_time:129923ms step_avg:93.13ms +step:1396/1670 train_time:130017ms step_avg:93.14ms +step:1397/1670 train_time:130111ms step_avg:93.14ms +step:1398/1670 train_time:130205ms step_avg:93.14ms +step:1399/1670 train_time:130300ms step_avg:93.14ms +step:1400/1670 train_time:130392ms step_avg:93.14ms +step:1401/1670 train_time:130485ms step_avg:93.14ms +step:1402/1670 train_time:130578ms step_avg:93.14ms +step:1403/1670 train_time:130671ms step_avg:93.14ms +step:1404/1670 train_time:130764ms step_avg:93.14ms +step:1405/1670 train_time:130857ms step_avg:93.14ms +step:1406/1670 train_time:130949ms step_avg:93.14ms +step:1407/1670 train_time:131043ms step_avg:93.14ms +step:1408/1670 train_time:131137ms step_avg:93.14ms +step:1409/1670 train_time:131232ms step_avg:93.14ms +step:1410/1670 train_time:131325ms step_avg:93.14ms +step:1411/1670 train_time:131419ms step_avg:93.14ms +step:1412/1670 train_time:131513ms step_avg:93.14ms +step:1413/1670 train_time:131606ms step_avg:93.14ms +step:1414/1670 train_time:131700ms step_avg:93.14ms +step:1415/1670 train_time:131793ms step_avg:93.14ms +step:1416/1670 train_time:131886ms step_avg:93.14ms +step:1417/1670 train_time:131979ms step_avg:93.14ms +step:1418/1670 train_time:132071ms step_avg:93.14ms +step:1419/1670 train_time:132165ms step_avg:93.14ms +step:1420/1670 train_time:132258ms step_avg:93.14ms +step:1421/1670 train_time:132351ms step_avg:93.14ms +step:1422/1670 train_time:132444ms step_avg:93.14ms +step:1423/1670 train_time:132538ms step_avg:93.14ms +step:1424/1670 train_time:132630ms step_avg:93.14ms +step:1425/1670 train_time:132724ms step_avg:93.14ms +step:1426/1670 train_time:132818ms step_avg:93.14ms +step:1427/1670 train_time:132910ms step_avg:93.14ms +step:1428/1670 train_time:133004ms step_avg:93.14ms +step:1429/1670 train_time:133097ms step_avg:93.14ms +step:1430/1670 train_time:133190ms step_avg:93.14ms +step:1431/1670 train_time:133283ms step_avg:93.14ms +step:1432/1670 train_time:133376ms step_avg:93.14ms +step:1433/1670 train_time:133469ms step_avg:93.14ms +step:1434/1670 train_time:133562ms step_avg:93.14ms +step:1435/1670 train_time:133655ms step_avg:93.14ms +step:1436/1670 train_time:133747ms step_avg:93.14ms +step:1437/1670 train_time:133841ms step_avg:93.14ms +step:1438/1670 train_time:133934ms step_avg:93.14ms +step:1439/1670 train_time:134027ms step_avg:93.14ms +step:1440/1670 train_time:134121ms step_avg:93.14ms +step:1441/1670 train_time:134214ms step_avg:93.14ms +step:1442/1670 train_time:134308ms step_avg:93.14ms +step:1443/1670 train_time:134403ms step_avg:93.14ms +step:1444/1670 train_time:134496ms step_avg:93.14ms +step:1445/1670 train_time:134588ms step_avg:93.14ms +step:1446/1670 train_time:134681ms step_avg:93.14ms +step:1447/1670 train_time:134775ms step_avg:93.14ms +step:1448/1670 train_time:134867ms step_avg:93.14ms +step:1449/1670 train_time:134961ms step_avg:93.14ms +step:1450/1670 train_time:135055ms step_avg:93.14ms +step:1451/1670 train_time:135148ms step_avg:93.14ms +step:1452/1670 train_time:135242ms step_avg:93.14ms +step:1453/1670 train_time:135336ms step_avg:93.14ms +step:1454/1670 train_time:135429ms step_avg:93.14ms +step:1455/1670 train_time:135523ms step_avg:93.14ms +step:1456/1670 train_time:135615ms step_avg:93.14ms 
+step:1457/1670 train_time:135708ms step_avg:93.14ms +step:1458/1670 train_time:135803ms step_avg:93.14ms +step:1459/1670 train_time:135896ms step_avg:93.14ms +step:1460/1670 train_time:135989ms step_avg:93.14ms +step:1461/1670 train_time:136083ms step_avg:93.14ms +step:1462/1670 train_time:136176ms step_avg:93.14ms +step:1463/1670 train_time:136269ms step_avg:93.14ms +step:1464/1670 train_time:136364ms step_avg:93.14ms +step:1465/1670 train_time:136457ms step_avg:93.14ms +step:1466/1670 train_time:136550ms step_avg:93.14ms +step:1467/1670 train_time:136644ms step_avg:93.14ms +step:1468/1670 train_time:136737ms step_avg:93.15ms +step:1469/1670 train_time:136830ms step_avg:93.15ms +step:1470/1670 train_time:136923ms step_avg:93.14ms +step:1471/1670 train_time:137015ms step_avg:93.14ms +step:1472/1670 train_time:137109ms step_avg:93.14ms +step:1473/1670 train_time:137202ms step_avg:93.14ms +step:1474/1670 train_time:137297ms step_avg:93.15ms +step:1475/1670 train_time:137389ms step_avg:93.15ms +step:1476/1670 train_time:137482ms step_avg:93.15ms +step:1477/1670 train_time:137575ms step_avg:93.15ms +step:1478/1670 train_time:137668ms step_avg:93.14ms +step:1479/1670 train_time:137761ms step_avg:93.14ms +step:1480/1670 train_time:137854ms step_avg:93.14ms +step:1481/1670 train_time:137947ms step_avg:93.14ms +step:1482/1670 train_time:138040ms step_avg:93.14ms +step:1483/1670 train_time:138133ms step_avg:93.14ms +step:1484/1670 train_time:138226ms step_avg:93.14ms +step:1485/1670 train_time:138476ms step_avg:93.25ms +step:1486/1670 train_time:138546ms step_avg:93.23ms +step:1487/1670 train_time:138637ms step_avg:93.23ms +step:1488/1670 train_time:138729ms step_avg:93.23ms +step:1489/1670 train_time:138820ms step_avg:93.23ms +step:1490/1670 train_time:138912ms step_avg:93.23ms +step:1491/1670 train_time:139004ms step_avg:93.23ms +step:1492/1670 train_time:139096ms step_avg:93.23ms +step:1493/1670 train_time:139188ms step_avg:93.23ms +step:1494/1670 train_time:139280ms step_avg:93.23ms +step:1495/1670 train_time:139377ms step_avg:93.23ms +step:1496/1670 train_time:139474ms step_avg:93.23ms +step:1497/1670 train_time:139572ms step_avg:93.23ms +step:1498/1670 train_time:139665ms step_avg:93.23ms +step:1499/1670 train_time:139758ms step_avg:93.23ms +step:1500/1670 train_time:139849ms step_avg:93.23ms +step:1500/1670 val_loss:3.3104 train_time:139944ms step_avg:93.30ms +step:1501/1670 train_time:139964ms step_avg:93.25ms +step:1502/1670 train_time:140037ms step_avg:93.23ms +step:1503/1670 train_time:140130ms step_avg:93.23ms +step:1504/1670 train_time:140222ms step_avg:93.23ms +step:1505/1670 train_time:140315ms step_avg:93.23ms +step:1506/1670 train_time:140408ms step_avg:93.23ms +step:1507/1670 train_time:140501ms step_avg:93.23ms +step:1508/1670 train_time:140594ms step_avg:93.23ms +step:1509/1670 train_time:140689ms step_avg:93.23ms +step:1510/1670 train_time:140784ms step_avg:93.23ms +step:1511/1670 train_time:140877ms step_avg:93.23ms +step:1512/1670 train_time:140972ms step_avg:93.24ms +step:1513/1670 train_time:141066ms step_avg:93.24ms +step:1514/1670 train_time:141159ms step_avg:93.24ms +step:1515/1670 train_time:141252ms step_avg:93.24ms +step:1516/1670 train_time:141345ms step_avg:93.24ms +step:1517/1670 train_time:141437ms step_avg:93.23ms +step:1518/1670 train_time:141529ms step_avg:93.23ms +step:1519/1670 train_time:141623ms step_avg:93.23ms +step:1520/1670 train_time:141716ms step_avg:93.23ms +step:1521/1670 train_time:141811ms step_avg:93.24ms +step:1522/1670 train_time:141905ms 
step_avg:93.24ms +step:1523/1670 train_time:141998ms step_avg:93.24ms +step:1524/1670 train_time:142091ms step_avg:93.24ms +step:1525/1670 train_time:142185ms step_avg:93.24ms +step:1526/1670 train_time:142279ms step_avg:93.24ms +step:1527/1670 train_time:142371ms step_avg:93.24ms +step:1528/1670 train_time:142464ms step_avg:93.24ms +step:1529/1670 train_time:142557ms step_avg:93.24ms +step:1530/1670 train_time:142651ms step_avg:93.24ms +step:1531/1670 train_time:142745ms step_avg:93.24ms +step:1532/1670 train_time:142839ms step_avg:93.24ms +step:1533/1670 train_time:142932ms step_avg:93.24ms +step:1534/1670 train_time:143026ms step_avg:93.24ms +step:1535/1670 train_time:143120ms step_avg:93.24ms +step:1536/1670 train_time:143212ms step_avg:93.24ms +step:1537/1670 train_time:143306ms step_avg:93.24ms +step:1538/1670 train_time:143399ms step_avg:93.24ms +step:1539/1670 train_time:143492ms step_avg:93.24ms +step:1540/1670 train_time:143585ms step_avg:93.24ms +step:1541/1670 train_time:143678ms step_avg:93.24ms +step:1542/1670 train_time:143772ms step_avg:93.24ms +step:1543/1670 train_time:143865ms step_avg:93.24ms +step:1544/1670 train_time:143958ms step_avg:93.24ms +step:1545/1670 train_time:144052ms step_avg:93.24ms +step:1546/1670 train_time:144146ms step_avg:93.24ms +step:1547/1670 train_time:144240ms step_avg:93.24ms +step:1548/1670 train_time:144333ms step_avg:93.24ms +step:1549/1670 train_time:144427ms step_avg:93.24ms +step:1550/1670 train_time:144519ms step_avg:93.24ms +step:1551/1670 train_time:144612ms step_avg:93.24ms +step:1552/1670 train_time:144706ms step_avg:93.24ms +step:1553/1670 train_time:144799ms step_avg:93.24ms +step:1554/1670 train_time:144892ms step_avg:93.24ms +step:1555/1670 train_time:144986ms step_avg:93.24ms +step:1556/1670 train_time:145079ms step_avg:93.24ms +step:1557/1670 train_time:145172ms step_avg:93.24ms +step:1558/1670 train_time:145265ms step_avg:93.24ms +step:1559/1670 train_time:145358ms step_avg:93.24ms +step:1560/1670 train_time:145451ms step_avg:93.24ms +step:1561/1670 train_time:145544ms step_avg:93.24ms +step:1562/1670 train_time:145638ms step_avg:93.24ms +step:1563/1670 train_time:145732ms step_avg:93.24ms +step:1564/1670 train_time:145825ms step_avg:93.24ms +step:1565/1670 train_time:145918ms step_avg:93.24ms +step:1566/1670 train_time:146012ms step_avg:93.24ms +step:1567/1670 train_time:146106ms step_avg:93.24ms +step:1568/1670 train_time:146198ms step_avg:93.24ms +step:1569/1670 train_time:146291ms step_avg:93.24ms +step:1570/1670 train_time:146385ms step_avg:93.24ms +step:1571/1670 train_time:146478ms step_avg:93.24ms +step:1572/1670 train_time:146571ms step_avg:93.24ms +step:1573/1670 train_time:146664ms step_avg:93.24ms +step:1574/1670 train_time:146758ms step_avg:93.24ms +step:1575/1670 train_time:146852ms step_avg:93.24ms +step:1576/1670 train_time:146946ms step_avg:93.24ms +step:1577/1670 train_time:147039ms step_avg:93.24ms +step:1578/1670 train_time:147132ms step_avg:93.24ms +step:1579/1670 train_time:147227ms step_avg:93.24ms +step:1580/1670 train_time:147321ms step_avg:93.24ms +step:1581/1670 train_time:147414ms step_avg:93.24ms +step:1582/1670 train_time:147507ms step_avg:93.24ms +step:1583/1670 train_time:147600ms step_avg:93.24ms +step:1584/1670 train_time:147693ms step_avg:93.24ms +step:1585/1670 train_time:147786ms step_avg:93.24ms +step:1586/1670 train_time:147880ms step_avg:93.24ms +step:1587/1670 train_time:147973ms step_avg:93.24ms +step:1588/1670 train_time:148066ms step_avg:93.24ms +step:1589/1670 train_time:148160ms 
step_avg:93.24ms +step:1590/1670 train_time:148253ms step_avg:93.24ms +step:1591/1670 train_time:148348ms step_avg:93.24ms +step:1592/1670 train_time:148441ms step_avg:93.24ms +step:1593/1670 train_time:148534ms step_avg:93.24ms +step:1594/1670 train_time:148628ms step_avg:93.24ms +step:1595/1670 train_time:148721ms step_avg:93.24ms +step:1596/1670 train_time:148814ms step_avg:93.24ms +step:1597/1670 train_time:148907ms step_avg:93.24ms +step:1598/1670 train_time:149001ms step_avg:93.24ms +step:1599/1670 train_time:149093ms step_avg:93.24ms +step:1600/1670 train_time:149187ms step_avg:93.24ms +step:1601/1670 train_time:149281ms step_avg:93.24ms +step:1602/1670 train_time:149373ms step_avg:93.24ms +step:1603/1670 train_time:149466ms step_avg:93.24ms +step:1604/1670 train_time:149559ms step_avg:93.24ms +step:1605/1670 train_time:149652ms step_avg:93.24ms +step:1606/1670 train_time:149745ms step_avg:93.24ms +step:1607/1670 train_time:149838ms step_avg:93.24ms +step:1608/1670 train_time:149931ms step_avg:93.24ms +step:1609/1670 train_time:150025ms step_avg:93.24ms +step:1610/1670 train_time:150119ms step_avg:93.24ms +step:1611/1670 train_time:150211ms step_avg:93.24ms +step:1612/1670 train_time:150305ms step_avg:93.24ms +step:1613/1670 train_time:150398ms step_avg:93.24ms +step:1614/1670 train_time:150493ms step_avg:93.24ms +step:1615/1670 train_time:150588ms step_avg:93.24ms +step:1616/1670 train_time:150681ms step_avg:93.24ms +step:1617/1670 train_time:150773ms step_avg:93.24ms +step:1618/1670 train_time:150865ms step_avg:93.24ms +step:1619/1670 train_time:150958ms step_avg:93.24ms +step:1620/1670 train_time:151052ms step_avg:93.24ms +step:1621/1670 train_time:151146ms step_avg:93.24ms +step:1622/1670 train_time:151239ms step_avg:93.24ms +step:1623/1670 train_time:151332ms step_avg:93.24ms +step:1624/1670 train_time:151426ms step_avg:93.24ms +step:1625/1670 train_time:151519ms step_avg:93.24ms +step:1625/1670 val_loss:3.2852 train_time:151612ms step_avg:93.30ms +step:1626/1670 train_time:151632ms step_avg:93.25ms +step:1627/1670 train_time:151707ms step_avg:93.24ms +step:1628/1670 train_time:151800ms step_avg:93.24ms +step:1629/1670 train_time:151893ms step_avg:93.24ms +step:1630/1670 train_time:151985ms step_avg:93.24ms +step:1631/1670 train_time:152078ms step_avg:93.24ms +step:1632/1670 train_time:152171ms step_avg:93.24ms +step:1633/1670 train_time:152265ms step_avg:93.24ms +step:1634/1670 train_time:152357ms step_avg:93.24ms +step:1635/1670 train_time:152450ms step_avg:93.24ms +step:1636/1670 train_time:152544ms step_avg:93.24ms +step:1637/1670 train_time:152639ms step_avg:93.24ms +step:1638/1670 train_time:152734ms step_avg:93.24ms +step:1639/1670 train_time:152827ms step_avg:93.24ms +step:1640/1670 train_time:152920ms step_avg:93.24ms +step:1641/1670 train_time:153013ms step_avg:93.24ms +step:1642/1670 train_time:153105ms step_avg:93.24ms +step:1643/1670 train_time:153198ms step_avg:93.24ms +step:1644/1670 train_time:153292ms step_avg:93.24ms +step:1645/1670 train_time:153387ms step_avg:93.24ms +step:1646/1670 train_time:153481ms step_avg:93.24ms +step:1647/1670 train_time:153575ms step_avg:93.25ms +step:1648/1670 train_time:153670ms step_avg:93.25ms +step:1649/1670 train_time:153764ms step_avg:93.25ms +step:1650/1670 train_time:153857ms step_avg:93.25ms +step:1651/1670 train_time:153951ms step_avg:93.25ms +step:1652/1670 train_time:154044ms step_avg:93.25ms +step:1653/1670 train_time:154137ms step_avg:93.25ms +step:1654/1670 train_time:154230ms step_avg:93.25ms +step:1655/1670 
train_time:154323ms step_avg:93.25ms +step:1656/1670 train_time:154416ms step_avg:93.25ms +step:1657/1670 train_time:154509ms step_avg:93.25ms +step:1658/1670 train_time:154603ms step_avg:93.25ms +step:1659/1670 train_time:154697ms step_avg:93.25ms +step:1660/1670 train_time:154790ms step_avg:93.25ms +step:1661/1670 train_time:154884ms step_avg:93.25ms +step:1662/1670 train_time:154977ms step_avg:93.25ms +step:1663/1670 train_time:155069ms step_avg:93.25ms +step:1664/1670 train_time:155162ms step_avg:93.25ms +step:1665/1670 train_time:155255ms step_avg:93.25ms +step:1666/1670 train_time:155349ms step_avg:93.25ms +step:1667/1670 train_time:155441ms step_avg:93.25ms +step:1668/1670 train_time:155536ms step_avg:93.25ms +step:1669/1670 train_time:155630ms step_avg:93.25ms +step:1670/1670 train_time:155723ms step_avg:93.25ms +step:1670/1670 val_loss:3.2767 train_time:155987ms step_avg:93.41ms +peak memory allocated: 31753 MiB reserved: 46816 MiB diff --git a/records/091125_VectSigmoidBFloat16/deb22a2c-6cf2-46f9-a350-aec1c97e9909.txt b/records/091125_VectSigmoidBFloat16/deb22a2c-6cf2-46f9-a350-aec1c97e9909.txt new file mode 100644 index 000000000..ff40b419b --- /dev/null +++ b/records/091125_VectSigmoidBFloat16/deb22a2c-6cf2-46f9-a350-aec1c97e9909.txt @@ -0,0 +1,3382 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op( + x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op( + g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float +) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = 
grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d( + pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M + ) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + 
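# (editor's note) The m/n offsets below are wrapped modulo M so that lanes of
+    # partial edge blocks always load in-bounds memory; the stores at the end of
+    # the kernel use the unwrapped offsets with explicit bounds masks, so the
+    # wrapped rows/columns are discarded there.
+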
offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and ( + n_idx + BLOCK_SIZE_N <= m_idx + ) + skip_block_above_diag = (LOWER_UPPER != 0) and ( + m_idx + BLOCK_SIZE_M <= n_idx + ) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = 
tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + ( + offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c + ) + at_ptrs = A_ptr + ( + offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0 + ) + at = tl.load( + at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + ( + offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c + ) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + ( + offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + ( + offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c + ) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size + * triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + + +@torch.compile( + dynamic=False, fullgraph=True +) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + 
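+# (editor's note) The fused Triton loop above is equivalent, up to buffer reuse
+# and bfloat16 rounding, to the plain-PyTorch sketch below; it can be handy for
+# numerically cross-checking the kernels. Illustrative only; not used by the run.
+def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # orthogonalize along the smaller dimension
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # spectral norm <= 1
+    for _ in range(steps):
+        A = X @ X.mT             # ns_line_1
+        B = b * A + c * (A @ A)  # ns_line_2
+        X = a * X + B @ X        # ns_line_3
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+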
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                batch = original_shape[0] * original_shape[1]
+                # Flatten all but the first two dims after batch
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.01, + ): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append( + dist.reduce_scatter_tensor( + grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + ) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group["betas"] + eps = group["eps"] + wd = group["weight_decay"] + params = group["params"] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size : (rank + 1) * rank_size] + lr = group["lr"] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + g_slice, g_slice, value=1 - beta2 + ) + # bias corrections + bias1 = 1 - beta1**t + bias2 = 1 - beta2**t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append( + 
dist.all_gather_into_tensor( + p, p_slice, async_op=True + ).get_future() + ) + torch.futures.collect_all(all_gather_futures).wait() + + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + + +class CastedLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + use_fp8=False, + x_s=1.0, + w_s=1.0, + grad_s=1.0, + ): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * ( + self.in_features**-0.5 + ) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3**0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm( + _x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s + )[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + rotary_cos: torch.Tensor + rotary_sin: torch.Tensor + attn_scale: float + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = ( + attn_args.seqlens, + attn_args.attn_scale, + attn_args.bm_size, + ) + + q, k, v = ( + F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)) + .view(B, T, 3 * self.num_heads, self.head_dim) + .chunk(3, dim=-2) + ) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = ( + rotary(q, rotary_cos, rotary_sin), + rotary(k, rotary_cos, rotary_sin), + ) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as( + v + ) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = ( + args.train_max_seq_len + if 
self.training + else (args.val_batch_size // (grad_accum_steps * world_size)) + ) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=seqlens, + cu_seqlens_k=seqlens, + max_seqlen_q=max_len, + max_seqlen_k=max_len, + causal=True, + softmax_scale=attn_scale, + window_size=(bm_size, 0), + ) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid( + self.attn_gate(x[..., : self.attn_gate.weight.size(-1)]) + ).view(B, T, self.num_heads, 1) + y = y.contiguous().view( + B, T, self.num_heads * self.head_dim + ) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim**-0.5) + bound = (3**0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu( + x + ).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = ( + CausalSelfAttention(dim, head_dim, num_heads) + if layer_idx != 7 + else None + ) + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward( + self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs + ): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + + +# ----------------------------------------------------------------------------- +# The main model + + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + model_dim: int, + max_seq_len: int, + ): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList( + [nn.Embedding(vocab_size, model_dim) for _ in range(3)] + ) + self.blocks = nn.ModuleList( + [ + Block(model_dim, head_dim, num_heads, i) + for i in range(num_layers) + ] + ) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
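+        # (editor's note) e.g. next_multiple_of_n(50257, n=128) == 50304, the padded vocab size used here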
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear( + model_dim, + vocab_size, + use_fp8=use_fp8, + x_s=(model_dim**0.5) / 448, + w_s=2**-9, + grad_s=1 / 448, + ) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + self.max_seq_len = max_seq_len + self.setup_yarn(head_dim) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75.0 + for param in self.value_embeds.parameters(): + param.lr_mul = 75.0 + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def setup_yarn(self, head_dim: int): + # store single copy of rotary tensors + angular_freq = (1 / 1024) ** torch.linspace( + 0, 1, steps=head_dim // 4, dtype=torch.float32 + ) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat( + [angular_freq, angular_freq.new_zeros(head_dim // 4)] + ) + t = torch.arange(self.max_seq_len, dtype=torch.float32) + theta = torch.outer(t, angular_freq) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + + # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd + windows = list( + dict.fromkeys(list(args.ws_schedule) + [args.ws_validate]) + ) + scale_factors = [ + 0.2 * math.log(curr / prev) + 1 + for prev, curr in zip(windows[:-1], windows[1:]) + ] + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + attn_scales = list( + accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor) + ) + self.attn_scales = dict(zip(windows, attn_scales)) + + def apply_yarn( + self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32 + ): + rotations = ( + args.block_size * old_window * self.angular_freq / (2 * torch.pi) + ) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp( + (rotations - alpha) / (beta - alpha), 0, 1 + ) + self.angular_freq *= scaling_factor + interpolation_weight * ( + 1 - scaling_factor + ) + t = torch.arange( + self.max_seq_len, + dtype=torch.float32, + device=self.angular_freq.device, + ) + theta = torch.outer(t, self.angular_freq) + self.rotary_cos.copy_(theta.cos()) + self.rotary_sin.copy_(theta.sin()) + + def forward( + self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int + ): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = (
+            [ve[0], ve[1], ve[2]]
+            + [None] * (len(self.blocks) - 6)
+            + [ve[0], ve[1], ve[2]]
+        )
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        bm_sizes = [
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+            short_bm,
+            short_bm,
+            short_bm,
+            long_bm,
+        ]
+        assert len(bm_sizes) == len(self.blocks)
+
+        x = x0 = norm(self.embed(input_seq)[None]).to(
+            torch.bfloat16
+        )  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[: (len(self.blocks) // 2)]
+        lambdas = self.scalars[
+            1 * len(self.blocks) : 3 * len(self.blocks)
+        ].view(-1, 2)
+        sa_lambdas = self.scalars[
+            3 * len(self.blocks) : 5 * len(self.blocks)
+        ].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                rotary_cos=self.rotary_cos,
+                rotary_sin=self.rotary_sin,
+                attn_scale=self.attn_scales[ws],
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            target_seq,
+            reduction="sum" if self.training else "mean",
+        )
+        return loss
+
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(
+        str(file), False, 256, dtype=torch.int32
+    )  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(
+            num_tokens, dtype=torch.uint16, pin_memory=True
+        )  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(
+            tokens.numpy()
+        )  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, (
+            "number of tokens read does not match header"
+        )
+    return tokens
+
+
+BOS_ID = 50256
+
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1):
+        # Precompute BOS positions once per shard
+        self.size = tokens.numel()
+        self.bos_idx = (
+            (tokens == BOS_ID)
+            .nonzero(as_tuple=True)[0]
+            .to(torch.int64)
+            .cpu()
+            .numpy()
+        )
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(
+                        f"Insufficient BOS tokens after index {idx}; hit tail of shard."
+                    )
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(
+                    self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                    cur + max_seq_len,
+                    cur + num_tokens_local - cur_len + 1,
+                )
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+
+def distributed_data_generator(
+    filename_pattern: str,
+    num_tokens: int,
+    max_seq_len: int,
+    grad_accum_steps: int = 1,
+    align_to_bos: bool = True,
+):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, (
+        "num_tokens must be divisible by world_size * grad_accum_steps"
+    )
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(
+            f"No files found for pattern: {filename_pattern}"
+        )
+
+    file_iter = iter(
+        files
+    )  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(
+            num_tokens_local // 300, n=128
+        )  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(
+                    num_tokens_local, max_seq_len
+                )
+                start_idxs, end_idxs = (
+                    torch.tensor(seq_starts[rank]),
+                    torch.tensor(seq_ends[rank]),
+                )
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= (
+                1  # trim the last sequence by one token to account for the _targets offset
+            )
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(
+                tokens
+            ):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local : pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(
+                num_tokens_local,
+            )
+            _targets = buf[1:].view(
+                num_tokens_local,
+            )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1 : len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(
+                device="cuda", dtype=torch.int32, non_blocking=True
+            ),
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, (
+                "num_tokens must be divisible by world_size * grad_accum_steps"
+            )
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = (
+        "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    )
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670  # number of iterations to run
+    cooldown_frac: float = (
+        0.5  # fraction of training spent cooling down the learning rate
+    )
+    # evaluation and logging
+    run_id: str = f"{uuid.uuid4()}"
+    val_loss_every: int = (
+        125  # every how many steps to evaluate val loss? 0 for only at the end
+    )
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws @classiclarryd
+
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = rank == 0  # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+
+
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("=" * 100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(
+    f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}"
+)
+print0(f"Running Triton version {triton.__version__}")
+
+
+def nvidia_smi():
+    import subprocess  # avoid top level import
+
+    return subprocess.run(
+        ["nvidia-smi"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    ).stdout
+
+
+print0(nvidia_smi())
+print0("=" * 100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size)
+    // (grad_accum_steps * world_size),
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [
+    p
+    for n, p in model.blocks.named_parameters()
+    if p.ndim >= 2 and "embed" not in n
+]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon( + hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0 +) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict( + model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers], +) # save the initial state +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[ + step % len(args.ws_schedule) + ] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator( + args.train_files, + args.train_batch_size, + args.train_max_seq_len, + grad_accum_steps=grad_accum_steps, +) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws = get_ws(0) +for step in range(train_steps + 1): + last_step = step == train_steps + new_ws = get_ws(step) + if new_ws != ws: + model.apply_yarn(ws, new_ws) + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or ( + args.val_loss_every > 0 and step % args.val_loss_every == 0 + ): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator( + args.val_files, + args.val_batch_size, + -1, + grad_accum_steps=grad_accum_steps, + align_to_bos=False, + ) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0( + f"step:{step}/{train_steps} 
val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms", + console=True, + ) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict( + step=step, + code=code, + model=model.state_dict(), + optimizers=[opt.state_dict() for opt in optimizers], + ) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * ( + time.perf_counter() - t0 + ) + print0( + f"step:{step + 1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / (step + 1):.2f}ms", + console=True, + ) + +print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, +) +dist.destroy_process_group() + +==================================================================================================== +Running Python 3.12.11 (main, Sep 2 2025, 14:20:58) [Clang 20.1.4 ] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 11 09:43:56 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | Off | +| N/A 40C P0 125W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | Off | +| N/A 45C P0 134W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:65:00.0 Off | Off | +| N/A 45C P0 126W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:68:00.0 Off | Off | +| N/A 37C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | Off | +| N/A 37C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:86:00.0 Off | Off | +| N/A 45C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:E5:00.0 Off | Off | +| N/A 44C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E8:00.0 Off | Off | +| N/A 40C P0 129W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.07ms +step:1/1670 train_time:294ms step_avg:293.96ms +step:2/1670 train_time:312ms step_avg:155.93ms +step:3/1670 train_time:381ms step_avg:127.08ms +step:4/1670 train_time:471ms step_avg:117.71ms +step:5/1670 train_time:561ms step_avg:112.17ms +step:6/1670 train_time:652ms step_avg:108.66ms +step:7/1670 train_time:742ms step_avg:106.04ms +step:8/1670 train_time:833ms step_avg:104.11ms +step:9/1670 train_time:923ms step_avg:102.55ms +step:10/1670 train_time:1014ms step_avg:101.38ms +step:11/1670 train_time:1104ms step_avg:100.35ms +step:12/1670 train_time:1197ms step_avg:99.77ms +step:13/1670 train_time:1292ms step_avg:99.36ms +step:14/1670 train_time:1384ms step_avg:98.84ms +step:15/1670 train_time:1476ms step_avg:98.42ms +step:16/1670 train_time:1568ms step_avg:97.98ms +step:17/1670 train_time:1659ms step_avg:97.62ms +step:18/1670 train_time:1751ms step_avg:97.29ms +step:19/1670 train_time:1842ms step_avg:96.95ms +step:20/1670 train_time:1933ms step_avg:96.63ms +step:21/1670 train_time:2024ms step_avg:96.39ms +step:22/1670 train_time:2116ms 
step_avg:96.18ms +step:23/1670 train_time:2208ms step_avg:95.98ms +step:24/1670 train_time:2301ms step_avg:95.87ms +step:25/1670 train_time:2393ms step_avg:95.73ms +step:26/1670 train_time:2486ms step_avg:95.63ms +step:27/1670 train_time:2578ms step_avg:95.47ms +step:28/1670 train_time:2666ms step_avg:95.23ms +step:29/1670 train_time:2757ms step_avg:95.08ms +step:30/1670 train_time:2849ms step_avg:94.95ms +step:31/1670 train_time:2940ms step_avg:94.83ms +step:32/1670 train_time:3032ms step_avg:94.74ms +step:33/1670 train_time:3123ms step_avg:94.64ms +step:34/1670 train_time:3216ms step_avg:94.59ms +step:35/1670 train_time:3307ms step_avg:94.50ms +step:36/1670 train_time:3400ms step_avg:94.44ms +step:37/1670 train_time:3493ms step_avg:94.40ms +step:38/1670 train_time:3585ms step_avg:94.33ms +step:39/1670 train_time:3677ms step_avg:94.28ms +step:40/1670 train_time:3768ms step_avg:94.19ms +step:41/1670 train_time:3859ms step_avg:94.12ms +step:42/1670 train_time:3950ms step_avg:94.04ms +step:43/1670 train_time:4041ms step_avg:93.97ms +step:44/1670 train_time:4133ms step_avg:93.93ms +step:45/1670 train_time:4224ms step_avg:93.87ms +step:46/1670 train_time:4318ms step_avg:93.86ms +step:47/1670 train_time:4411ms step_avg:93.85ms +step:48/1670 train_time:4504ms step_avg:93.83ms +step:49/1670 train_time:4596ms step_avg:93.79ms +step:50/1670 train_time:4688ms step_avg:93.76ms +step:51/1670 train_time:4780ms step_avg:93.73ms +step:52/1670 train_time:4872ms step_avg:93.69ms +step:53/1670 train_time:4963ms step_avg:93.63ms +step:54/1670 train_time:5053ms step_avg:93.58ms +step:55/1670 train_time:5144ms step_avg:93.53ms +step:56/1670 train_time:5235ms step_avg:93.48ms +step:57/1670 train_time:5327ms step_avg:93.45ms +step:58/1670 train_time:5420ms step_avg:93.46ms +step:59/1670 train_time:5513ms step_avg:93.44ms +step:60/1670 train_time:5604ms step_avg:93.41ms +step:61/1670 train_time:5697ms step_avg:93.39ms +step:62/1670 train_time:5788ms step_avg:93.36ms +step:63/1670 train_time:5881ms step_avg:93.35ms +step:64/1670 train_time:5972ms step_avg:93.31ms +step:65/1670 train_time:6062ms step_avg:93.27ms +step:66/1670 train_time:6153ms step_avg:93.23ms +step:67/1670 train_time:6244ms step_avg:93.19ms +step:68/1670 train_time:6335ms step_avg:93.16ms +step:69/1670 train_time:6426ms step_avg:93.13ms +step:70/1670 train_time:6518ms step_avg:93.11ms +step:71/1670 train_time:6611ms step_avg:93.11ms +step:72/1670 train_time:6703ms step_avg:93.09ms +step:73/1670 train_time:6795ms step_avg:93.08ms +step:74/1670 train_time:6886ms step_avg:93.05ms +step:75/1670 train_time:6979ms step_avg:93.06ms +step:76/1670 train_time:7070ms step_avg:93.03ms +step:77/1670 train_time:7160ms step_avg:92.99ms +step:78/1670 train_time:7253ms step_avg:92.99ms +step:79/1670 train_time:7345ms step_avg:92.97ms +step:80/1670 train_time:7436ms step_avg:92.95ms +step:81/1670 train_time:7527ms step_avg:92.93ms +step:82/1670 train_time:7620ms step_avg:92.93ms +step:83/1670 train_time:7712ms step_avg:92.92ms +step:84/1670 train_time:7804ms step_avg:92.90ms +step:85/1670 train_time:7895ms step_avg:92.89ms +step:86/1670 train_time:7986ms step_avg:92.86ms +step:87/1670 train_time:8078ms step_avg:92.85ms +step:88/1670 train_time:8169ms step_avg:92.83ms +step:89/1670 train_time:8260ms step_avg:92.81ms +step:90/1670 train_time:8352ms step_avg:92.80ms +step:91/1670 train_time:8443ms step_avg:92.78ms +step:92/1670 train_time:8535ms step_avg:92.77ms +step:93/1670 train_time:8626ms step_avg:92.75ms +step:94/1670 train_time:8719ms step_avg:92.76ms 
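+# [Editor's sketch, not part of the recorded run: a minimal illustration of the
+# U-net skip pattern in the forward pass above. The first half of the blocks
+# push their activations onto a stack; each block in the second half pops one
+# back in through a learned sigmoid gate in (0, 1), pairing block 5 with 6,
+# 4 with 7, and so on. The per-block transform is omitted for brevity.]
+import torch
+n_blocks = 12
+skip_weights = torch.zeros(n_blocks // 2)  # gate logits; sigmoid(0) = 0.5
+stack, x = [], torch.randn(1, 8)
+for i in range(n_blocks):
+    if i >= n_blocks // 2:  # second half: consume skips in LIFO order
+        x = x + torch.sigmoid(skip_weights[i - n_blocks // 2]) * stack.pop()
+    else:  # first half: save activations for the symmetric block
+        stack.append(x)
+assert not stack  # every skip connection was consumed exactly once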
+step:95/1670 train_time:8812ms step_avg:92.76ms +step:96/1670 train_time:8904ms step_avg:92.75ms +step:97/1670 train_time:8996ms step_avg:92.74ms +step:98/1670 train_time:9087ms step_avg:92.72ms +step:99/1670 train_time:9178ms step_avg:92.70ms +step:100/1670 train_time:9269ms step_avg:92.69ms +step:101/1670 train_time:9360ms step_avg:92.67ms +step:102/1670 train_time:9451ms step_avg:92.66ms +step:103/1670 train_time:9543ms step_avg:92.65ms +step:104/1670 train_time:9634ms step_avg:92.64ms +step:105/1670 train_time:9725ms step_avg:92.62ms +step:106/1670 train_time:9819ms step_avg:92.63ms +step:107/1670 train_time:9910ms step_avg:92.62ms +step:108/1670 train_time:10002ms step_avg:92.61ms +step:109/1670 train_time:10094ms step_avg:92.61ms +step:110/1670 train_time:10185ms step_avg:92.59ms +step:111/1670 train_time:10276ms step_avg:92.58ms +step:112/1670 train_time:10366ms step_avg:92.56ms +step:113/1670 train_time:10458ms step_avg:92.55ms +step:114/1670 train_time:10549ms step_avg:92.54ms +step:115/1670 train_time:10641ms step_avg:92.53ms +step:116/1670 train_time:10733ms step_avg:92.52ms +step:117/1670 train_time:10824ms step_avg:92.51ms +step:118/1670 train_time:10918ms step_avg:92.52ms +step:119/1670 train_time:11009ms step_avg:92.51ms +step:120/1670 train_time:11101ms step_avg:92.51ms +step:121/1670 train_time:11192ms step_avg:92.50ms +step:122/1670 train_time:11283ms step_avg:92.48ms +step:123/1670 train_time:11374ms step_avg:92.48ms +step:124/1670 train_time:11466ms step_avg:92.47ms +step:125/1670 train_time:11559ms step_avg:92.47ms +step:125/1670 val_loss:4.3014 train_time:11649ms step_avg:93.19ms +step:126/1670 train_time:11667ms step_avg:92.59ms +step:127/1670 train_time:11744ms step_avg:92.47ms +step:128/1670 train_time:11844ms step_avg:92.53ms +step:129/1670 train_time:11937ms step_avg:92.53ms +step:130/1670 train_time:12026ms step_avg:92.51ms +step:131/1670 train_time:12117ms step_avg:92.50ms +step:132/1670 train_time:12208ms step_avg:92.48ms +step:133/1670 train_time:12298ms step_avg:92.47ms +step:134/1670 train_time:12388ms step_avg:92.45ms +step:135/1670 train_time:12479ms step_avg:92.43ms +step:136/1670 train_time:12569ms step_avg:92.42ms +step:137/1670 train_time:12661ms step_avg:92.42ms +step:138/1670 train_time:12756ms step_avg:92.43ms +step:139/1670 train_time:12852ms step_avg:92.46ms +step:140/1670 train_time:12945ms step_avg:92.47ms +step:141/1670 train_time:13037ms step_avg:92.46ms +step:142/1670 train_time:13128ms step_avg:92.45ms +step:143/1670 train_time:13218ms step_avg:92.43ms +step:144/1670 train_time:13309ms step_avg:92.42ms +step:145/1670 train_time:13400ms step_avg:92.41ms +step:146/1670 train_time:13490ms step_avg:92.40ms +step:147/1670 train_time:13583ms step_avg:92.40ms +step:148/1670 train_time:13674ms step_avg:92.39ms +step:149/1670 train_time:13768ms step_avg:92.40ms +step:150/1670 train_time:13861ms step_avg:92.41ms +step:151/1670 train_time:13953ms step_avg:92.40ms +step:152/1670 train_time:14046ms step_avg:92.41ms +step:153/1670 train_time:14137ms step_avg:92.40ms +step:154/1670 train_time:14228ms step_avg:92.39ms +step:155/1670 train_time:14318ms step_avg:92.37ms +step:156/1670 train_time:14409ms step_avg:92.36ms +step:157/1670 train_time:14500ms step_avg:92.36ms +step:158/1670 train_time:14591ms step_avg:92.35ms +step:159/1670 train_time:14682ms step_avg:92.34ms +step:160/1670 train_time:14775ms step_avg:92.34ms +step:161/1670 train_time:14868ms step_avg:92.35ms +step:162/1670 train_time:14960ms step_avg:92.35ms +step:163/1670 train_time:15052ms 
step_avg:92.34ms +step:164/1670 train_time:15143ms step_avg:92.34ms +step:165/1670 train_time:15233ms step_avg:92.32ms +step:166/1670 train_time:15324ms step_avg:92.31ms +step:167/1670 train_time:15415ms step_avg:92.30ms +step:168/1670 train_time:15507ms step_avg:92.31ms +step:169/1670 train_time:15598ms step_avg:92.30ms +step:170/1670 train_time:15690ms step_avg:92.30ms +step:171/1670 train_time:15783ms step_avg:92.30ms +step:172/1670 train_time:15876ms step_avg:92.30ms +step:173/1670 train_time:15969ms step_avg:92.30ms +step:174/1670 train_time:16061ms step_avg:92.30ms +step:175/1670 train_time:16152ms step_avg:92.30ms +step:176/1670 train_time:16242ms step_avg:92.28ms +step:177/1670 train_time:16333ms step_avg:92.28ms +step:178/1670 train_time:16424ms step_avg:92.27ms +step:179/1670 train_time:16514ms step_avg:92.26ms +step:180/1670 train_time:16605ms step_avg:92.25ms +step:181/1670 train_time:16696ms step_avg:92.24ms +step:182/1670 train_time:16789ms step_avg:92.25ms +step:183/1670 train_time:16882ms step_avg:92.25ms +step:184/1670 train_time:16973ms step_avg:92.25ms +step:185/1670 train_time:17066ms step_avg:92.25ms +step:186/1670 train_time:17157ms step_avg:92.24ms +step:187/1670 train_time:17249ms step_avg:92.24ms +step:188/1670 train_time:17340ms step_avg:92.23ms +step:189/1670 train_time:17431ms step_avg:92.23ms +step:190/1670 train_time:17521ms step_avg:92.21ms +step:191/1670 train_time:17612ms step_avg:92.21ms +step:192/1670 train_time:17704ms step_avg:92.21ms +step:193/1670 train_time:17795ms step_avg:92.20ms +step:194/1670 train_time:17887ms step_avg:92.20ms +step:195/1670 train_time:17979ms step_avg:92.20ms +step:196/1670 train_time:18072ms step_avg:92.21ms +step:197/1670 train_time:18164ms step_avg:92.20ms +step:198/1670 train_time:18255ms step_avg:92.20ms +step:199/1670 train_time:18346ms step_avg:92.19ms +step:200/1670 train_time:18438ms step_avg:92.19ms +step:201/1670 train_time:18529ms step_avg:92.19ms +step:202/1670 train_time:18620ms step_avg:92.18ms +step:203/1670 train_time:18711ms step_avg:92.17ms +step:204/1670 train_time:18803ms step_avg:92.17ms +step:205/1670 train_time:18894ms step_avg:92.17ms +step:206/1670 train_time:18986ms step_avg:92.16ms +step:207/1670 train_time:19077ms step_avg:92.16ms +step:208/1670 train_time:19171ms step_avg:92.17ms +step:209/1670 train_time:19263ms step_avg:92.17ms +step:210/1670 train_time:19353ms step_avg:92.16ms +step:211/1670 train_time:19444ms step_avg:92.15ms +step:212/1670 train_time:19534ms step_avg:92.14ms +step:213/1670 train_time:19782ms step_avg:92.87ms +step:214/1670 train_time:19857ms step_avg:92.79ms +step:215/1670 train_time:19948ms step_avg:92.78ms +step:216/1670 train_time:20038ms step_avg:92.77ms +step:217/1670 train_time:20129ms step_avg:92.76ms +step:218/1670 train_time:20219ms step_avg:92.75ms +step:219/1670 train_time:20309ms step_avg:92.74ms +step:220/1670 train_time:20399ms step_avg:92.72ms +step:221/1670 train_time:20489ms step_avg:92.71ms +step:222/1670 train_time:20579ms step_avg:92.70ms +step:223/1670 train_time:20674ms step_avg:92.71ms +step:224/1670 train_time:20769ms step_avg:92.72ms +step:225/1670 train_time:20862ms step_avg:92.72ms +step:226/1670 train_time:20952ms step_avg:92.71ms +step:227/1670 train_time:21042ms step_avg:92.70ms +step:228/1670 train_time:21133ms step_avg:92.69ms +step:229/1670 train_time:21224ms step_avg:92.68ms +step:230/1670 train_time:21315ms step_avg:92.67ms +step:231/1670 train_time:21406ms step_avg:92.67ms +step:232/1670 train_time:21496ms step_avg:92.66ms +step:233/1670 
train_time:21590ms step_avg:92.66ms +step:234/1670 train_time:21683ms step_avg:92.66ms +step:235/1670 train_time:21775ms step_avg:92.66ms +step:236/1670 train_time:21868ms step_avg:92.66ms +step:237/1670 train_time:21960ms step_avg:92.66ms +step:238/1670 train_time:22051ms step_avg:92.65ms +step:239/1670 train_time:22142ms step_avg:92.65ms +step:240/1670 train_time:22232ms step_avg:92.64ms +step:241/1670 train_time:22323ms step_avg:92.63ms +step:242/1670 train_time:22414ms step_avg:92.62ms +step:243/1670 train_time:22505ms step_avg:92.61ms +step:244/1670 train_time:22595ms step_avg:92.60ms +step:245/1670 train_time:22688ms step_avg:92.60ms +step:246/1670 train_time:22781ms step_avg:92.60ms +step:247/1670 train_time:22873ms step_avg:92.60ms +step:248/1670 train_time:22964ms step_avg:92.60ms +step:249/1670 train_time:23055ms step_avg:92.59ms +step:250/1670 train_time:23146ms step_avg:92.58ms +step:250/1670 val_loss:3.9754 train_time:23237ms step_avg:92.95ms +step:251/1670 train_time:23254ms step_avg:92.65ms +step:252/1670 train_time:23329ms step_avg:92.58ms +step:253/1670 train_time:23422ms step_avg:92.58ms +step:254/1670 train_time:23512ms step_avg:92.57ms +step:255/1670 train_time:23602ms step_avg:92.56ms +step:256/1670 train_time:23692ms step_avg:92.55ms +step:257/1670 train_time:23782ms step_avg:92.54ms +step:258/1670 train_time:23872ms step_avg:92.53ms +step:259/1670 train_time:23962ms step_avg:92.52ms +step:260/1670 train_time:24054ms step_avg:92.52ms +step:261/1670 train_time:24145ms step_avg:92.51ms +step:262/1670 train_time:24238ms step_avg:92.51ms +step:263/1670 train_time:24331ms step_avg:92.51ms +step:264/1670 train_time:24423ms step_avg:92.51ms +step:265/1670 train_time:24515ms step_avg:92.51ms +step:266/1670 train_time:24605ms step_avg:92.50ms +step:267/1670 train_time:24696ms step_avg:92.49ms +step:268/1670 train_time:24786ms step_avg:92.48ms +step:269/1670 train_time:24877ms step_avg:92.48ms +step:270/1670 train_time:24968ms step_avg:92.47ms +step:271/1670 train_time:25060ms step_avg:92.47ms +step:272/1670 train_time:25151ms step_avg:92.47ms +step:273/1670 train_time:25244ms step_avg:92.47ms +step:274/1670 train_time:25336ms step_avg:92.47ms +step:275/1670 train_time:25428ms step_avg:92.46ms +step:276/1670 train_time:25518ms step_avg:92.46ms +step:277/1670 train_time:25609ms step_avg:92.45ms +step:278/1670 train_time:25699ms step_avg:92.44ms +step:279/1670 train_time:25790ms step_avg:92.44ms +step:280/1670 train_time:25881ms step_avg:92.43ms +step:281/1670 train_time:25972ms step_avg:92.43ms +step:282/1670 train_time:26065ms step_avg:92.43ms +step:283/1670 train_time:26157ms step_avg:92.43ms +step:284/1670 train_time:26248ms step_avg:92.42ms +step:285/1670 train_time:26340ms step_avg:92.42ms +step:286/1670 train_time:26431ms step_avg:92.42ms +step:287/1670 train_time:26521ms step_avg:92.41ms +step:288/1670 train_time:26613ms step_avg:92.40ms +step:289/1670 train_time:26704ms step_avg:92.40ms +step:290/1670 train_time:26795ms step_avg:92.39ms +step:291/1670 train_time:26886ms step_avg:92.39ms +step:292/1670 train_time:26977ms step_avg:92.39ms +step:293/1670 train_time:27068ms step_avg:92.38ms +step:294/1670 train_time:27159ms step_avg:92.38ms +step:295/1670 train_time:27249ms step_avg:92.37ms +step:296/1670 train_time:27342ms step_avg:92.37ms +step:297/1670 train_time:27433ms step_avg:92.37ms +step:298/1670 train_time:27525ms step_avg:92.37ms +step:299/1670 train_time:27615ms step_avg:92.36ms +step:300/1670 train_time:27705ms step_avg:92.35ms +step:301/1670 train_time:27796ms 
step_avg:92.35ms +step:302/1670 train_time:27887ms step_avg:92.34ms +step:303/1670 train_time:27979ms step_avg:92.34ms +step:304/1670 train_time:28070ms step_avg:92.34ms +step:305/1670 train_time:28162ms step_avg:92.33ms +step:306/1670 train_time:28252ms step_avg:92.33ms +step:307/1670 train_time:28345ms step_avg:92.33ms +step:308/1670 train_time:28437ms step_avg:92.33ms +step:309/1670 train_time:28528ms step_avg:92.32ms +step:310/1670 train_time:28619ms step_avg:92.32ms +step:311/1670 train_time:28709ms step_avg:92.31ms +step:312/1670 train_time:28799ms step_avg:92.31ms +step:313/1670 train_time:28890ms step_avg:92.30ms +step:314/1670 train_time:28981ms step_avg:92.30ms +step:315/1670 train_time:29072ms step_avg:92.29ms +step:316/1670 train_time:29164ms step_avg:92.29ms +step:317/1670 train_time:29256ms step_avg:92.29ms +step:318/1670 train_time:29348ms step_avg:92.29ms +step:319/1670 train_time:29441ms step_avg:92.29ms +step:320/1670 train_time:29531ms step_avg:92.29ms +step:321/1670 train_time:29622ms step_avg:92.28ms +step:322/1670 train_time:29713ms step_avg:92.28ms +step:323/1670 train_time:29804ms step_avg:92.27ms +step:324/1670 train_time:29895ms step_avg:92.27ms +step:325/1670 train_time:29986ms step_avg:92.26ms +step:326/1670 train_time:30077ms step_avg:92.26ms +step:327/1670 train_time:30167ms step_avg:92.25ms +step:328/1670 train_time:30259ms step_avg:92.25ms +step:329/1670 train_time:30350ms step_avg:92.25ms +step:330/1670 train_time:30442ms step_avg:92.25ms +step:331/1670 train_time:30533ms step_avg:92.24ms +step:332/1670 train_time:30624ms step_avg:92.24ms +step:333/1670 train_time:30715ms step_avg:92.24ms +step:334/1670 train_time:30806ms step_avg:92.23ms +step:335/1670 train_time:30897ms step_avg:92.23ms +step:336/1670 train_time:30987ms step_avg:92.22ms +step:337/1670 train_time:31079ms step_avg:92.22ms +step:338/1670 train_time:31170ms step_avg:92.22ms +step:339/1670 train_time:31262ms step_avg:92.22ms +step:340/1670 train_time:31354ms step_avg:92.22ms +step:341/1670 train_time:31446ms step_avg:92.22ms +step:342/1670 train_time:31537ms step_avg:92.21ms +step:343/1670 train_time:31627ms step_avg:92.21ms +step:344/1670 train_time:31719ms step_avg:92.21ms +step:345/1670 train_time:31810ms step_avg:92.20ms +step:346/1670 train_time:31901ms step_avg:92.20ms +step:347/1670 train_time:31992ms step_avg:92.20ms +step:348/1670 train_time:32085ms step_avg:92.20ms +step:349/1670 train_time:32176ms step_avg:92.19ms +step:350/1670 train_time:32266ms step_avg:92.19ms +step:351/1670 train_time:32359ms step_avg:92.19ms +step:352/1670 train_time:32450ms step_avg:92.19ms +step:353/1670 train_time:32541ms step_avg:92.18ms +step:354/1670 train_time:32632ms step_avg:92.18ms +step:355/1670 train_time:32723ms step_avg:92.18ms +step:356/1670 train_time:32814ms step_avg:92.17ms +step:357/1670 train_time:32905ms step_avg:92.17ms +step:358/1670 train_time:32996ms step_avg:92.17ms +step:359/1670 train_time:33087ms step_avg:92.17ms +step:360/1670 train_time:33178ms step_avg:92.16ms +step:361/1670 train_time:33269ms step_avg:92.16ms +step:362/1670 train_time:33361ms step_avg:92.16ms +step:363/1670 train_time:33453ms step_avg:92.16ms +step:364/1670 train_time:33544ms step_avg:92.15ms +step:365/1670 train_time:33635ms step_avg:92.15ms +step:366/1670 train_time:33727ms step_avg:92.15ms +step:367/1670 train_time:33818ms step_avg:92.15ms +step:368/1670 train_time:33908ms step_avg:92.14ms +step:369/1670 train_time:33999ms step_avg:92.14ms +step:370/1670 train_time:34090ms step_avg:92.13ms +step:371/1670 
train_time:34181ms step_avg:92.13ms +step:372/1670 train_time:34273ms step_avg:92.13ms +step:373/1670 train_time:34366ms step_avg:92.13ms +step:374/1670 train_time:34457ms step_avg:92.13ms +step:375/1670 train_time:34548ms step_avg:92.13ms +step:375/1670 val_loss:3.8165 train_time:34639ms step_avg:92.37ms +step:376/1670 train_time:34656ms step_avg:92.17ms +step:377/1670 train_time:34732ms step_avg:92.13ms +step:378/1670 train_time:34824ms step_avg:92.13ms +step:379/1670 train_time:34916ms step_avg:92.13ms +step:380/1670 train_time:35006ms step_avg:92.12ms +step:381/1670 train_time:35096ms step_avg:92.12ms +step:382/1670 train_time:35186ms step_avg:92.11ms +step:383/1670 train_time:35277ms step_avg:92.11ms +step:384/1670 train_time:35368ms step_avg:92.10ms +step:385/1670 train_time:35459ms step_avg:92.10ms +step:386/1670 train_time:35552ms step_avg:92.10ms +step:387/1670 train_time:35644ms step_avg:92.10ms +step:388/1670 train_time:35737ms step_avg:92.11ms +step:389/1670 train_time:35829ms step_avg:92.11ms +step:390/1670 train_time:35920ms step_avg:92.10ms +step:391/1670 train_time:36009ms step_avg:92.10ms +step:392/1670 train_time:36099ms step_avg:92.09ms +step:393/1670 train_time:36190ms step_avg:92.09ms +step:394/1670 train_time:36280ms step_avg:92.08ms +step:395/1670 train_time:36373ms step_avg:92.08ms +step:396/1670 train_time:36463ms step_avg:92.08ms +step:397/1670 train_time:36555ms step_avg:92.08ms +step:398/1670 train_time:36647ms step_avg:92.08ms +step:399/1670 train_time:36739ms step_avg:92.08ms +step:400/1670 train_time:36832ms step_avg:92.08ms +step:401/1670 train_time:36923ms step_avg:92.08ms +step:402/1670 train_time:37014ms step_avg:92.07ms +step:403/1670 train_time:37104ms step_avg:92.07ms +step:404/1670 train_time:37195ms step_avg:92.07ms +step:405/1670 train_time:37286ms step_avg:92.06ms +step:406/1670 train_time:37377ms step_avg:92.06ms +step:407/1670 train_time:37467ms step_avg:92.06ms +step:408/1670 train_time:37559ms step_avg:92.06ms +step:409/1670 train_time:37651ms step_avg:92.06ms +step:410/1670 train_time:37742ms step_avg:92.05ms +step:411/1670 train_time:37835ms step_avg:92.06ms +step:412/1670 train_time:37926ms step_avg:92.05ms +step:413/1670 train_time:38017ms step_avg:92.05ms +step:414/1670 train_time:38108ms step_avg:92.05ms +step:415/1670 train_time:38198ms step_avg:92.04ms +step:416/1670 train_time:38290ms step_avg:92.04ms +step:417/1670 train_time:38380ms step_avg:92.04ms +step:418/1670 train_time:38470ms step_avg:92.03ms +step:419/1670 train_time:38561ms step_avg:92.03ms +step:420/1670 train_time:38652ms step_avg:92.03ms +step:421/1670 train_time:38743ms step_avg:92.03ms +step:422/1670 train_time:38836ms step_avg:92.03ms +step:423/1670 train_time:38928ms step_avg:92.03ms +step:424/1670 train_time:39020ms step_avg:92.03ms +step:425/1670 train_time:39267ms step_avg:92.39ms +step:426/1670 train_time:39346ms step_avg:92.36ms +step:427/1670 train_time:39435ms step_avg:92.35ms +step:428/1670 train_time:39526ms step_avg:92.35ms +step:429/1670 train_time:39616ms step_avg:92.35ms +step:430/1670 train_time:39706ms step_avg:92.34ms +step:431/1670 train_time:39796ms step_avg:92.33ms +step:432/1670 train_time:39886ms step_avg:92.33ms +step:433/1670 train_time:39976ms step_avg:92.32ms +step:434/1670 train_time:40065ms step_avg:92.32ms +step:435/1670 train_time:40158ms step_avg:92.32ms +step:436/1670 train_time:40255ms step_avg:92.33ms +step:437/1670 train_time:40349ms step_avg:92.33ms +step:438/1670 train_time:40440ms step_avg:92.33ms +step:439/1670 train_time:40532ms 
step_avg:92.33ms +step:440/1670 train_time:40620ms step_avg:92.32ms +step:441/1670 train_time:40712ms step_avg:92.32ms +step:442/1670 train_time:40801ms step_avg:92.31ms +step:443/1670 train_time:40891ms step_avg:92.31ms +step:444/1670 train_time:40981ms step_avg:92.30ms +step:445/1670 train_time:41074ms step_avg:92.30ms +step:446/1670 train_time:41166ms step_avg:92.30ms +step:447/1670 train_time:41261ms step_avg:92.31ms +step:448/1670 train_time:41355ms step_avg:92.31ms +step:449/1670 train_time:41446ms step_avg:92.31ms +step:450/1670 train_time:41538ms step_avg:92.31ms +step:451/1670 train_time:41629ms step_avg:92.30ms +step:452/1670 train_time:41719ms step_avg:92.30ms +step:453/1670 train_time:41809ms step_avg:92.29ms +step:454/1670 train_time:41899ms step_avg:92.29ms +step:455/1670 train_time:41990ms step_avg:92.29ms +step:456/1670 train_time:42080ms step_avg:92.28ms +step:457/1670 train_time:42172ms step_avg:92.28ms +step:458/1670 train_time:42264ms step_avg:92.28ms +step:459/1670 train_time:42357ms step_avg:92.28ms +step:460/1670 train_time:42449ms step_avg:92.28ms +step:461/1670 train_time:42540ms step_avg:92.28ms +step:462/1670 train_time:42631ms step_avg:92.27ms +step:463/1670 train_time:42721ms step_avg:92.27ms +step:464/1670 train_time:42812ms step_avg:92.27ms +step:465/1670 train_time:42902ms step_avg:92.26ms +step:466/1670 train_time:42992ms step_avg:92.26ms +step:467/1670 train_time:43083ms step_avg:92.25ms +step:468/1670 train_time:43174ms step_avg:92.25ms +step:469/1670 train_time:43266ms step_avg:92.25ms +step:470/1670 train_time:43358ms step_avg:92.25ms +step:471/1670 train_time:43451ms step_avg:92.25ms +step:472/1670 train_time:43541ms step_avg:92.25ms +step:473/1670 train_time:43632ms step_avg:92.25ms +step:474/1670 train_time:43723ms step_avg:92.24ms +step:475/1670 train_time:43814ms step_avg:92.24ms +step:476/1670 train_time:43904ms step_avg:92.24ms +step:477/1670 train_time:43994ms step_avg:92.23ms +step:478/1670 train_time:44086ms step_avg:92.23ms +step:479/1670 train_time:44178ms step_avg:92.23ms +step:480/1670 train_time:44269ms step_avg:92.23ms +step:481/1670 train_time:44361ms step_avg:92.23ms +step:482/1670 train_time:44452ms step_avg:92.22ms +step:483/1670 train_time:44544ms step_avg:92.22ms +step:484/1670 train_time:44635ms step_avg:92.22ms +step:485/1670 train_time:44728ms step_avg:92.22ms +step:486/1670 train_time:44819ms step_avg:92.22ms +step:487/1670 train_time:44910ms step_avg:92.22ms +step:488/1670 train_time:45000ms step_avg:92.21ms +step:489/1670 train_time:45092ms step_avg:92.21ms +step:490/1670 train_time:45183ms step_avg:92.21ms +step:491/1670 train_time:45275ms step_avg:92.21ms +step:492/1670 train_time:45365ms step_avg:92.21ms +step:493/1670 train_time:45458ms step_avg:92.21ms +step:494/1670 train_time:45548ms step_avg:92.20ms +step:495/1670 train_time:45639ms step_avg:92.20ms +step:496/1670 train_time:45731ms step_avg:92.20ms +step:497/1670 train_time:45822ms step_avg:92.20ms +step:498/1670 train_time:45913ms step_avg:92.20ms +step:499/1670 train_time:46004ms step_avg:92.19ms +step:500/1670 train_time:46095ms step_avg:92.19ms +step:500/1670 val_loss:3.7140 train_time:46186ms step_avg:92.37ms +step:501/1670 train_time:46203ms step_avg:92.22ms +step:502/1670 train_time:46279ms step_avg:92.19ms +step:503/1670 train_time:46369ms step_avg:92.19ms +step:504/1670 train_time:46460ms step_avg:92.18ms +step:505/1670 train_time:46550ms step_avg:92.18ms +step:506/1670 train_time:46640ms step_avg:92.17ms +step:507/1670 train_time:46730ms step_avg:92.17ms 
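+# [Editor's sketch, not part of the recorded run: a quick check of the logit
+# softcap identity cited in the forward pass above. Since 2*sigmoid(2x) =
+# tanh(x) + 1, the cap 30*sigmoid(logits/7.5) equals 15*tanh(logits/15) + 15,
+# i.e. the earlier 15-softcap shifted into the positive range.]
+import torch
+x = torch.linspace(-60.0, 60.0, steps=9)
+assert torch.allclose(30 * torch.sigmoid(x / 7.5), 15 * torch.tanh(x / 15) + 15)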
+step:508/1670 train_time:46822ms step_avg:92.17ms +step:509/1670 train_time:46913ms step_avg:92.17ms +step:510/1670 train_time:47005ms step_avg:92.17ms +step:511/1670 train_time:47096ms step_avg:92.16ms +step:512/1670 train_time:47189ms step_avg:92.17ms +step:513/1670 train_time:47281ms step_avg:92.17ms +step:514/1670 train_time:47373ms step_avg:92.16ms +step:515/1670 train_time:47466ms step_avg:92.17ms +step:516/1670 train_time:47557ms step_avg:92.16ms +step:517/1670 train_time:47647ms step_avg:92.16ms +step:518/1670 train_time:47738ms step_avg:92.16ms +step:519/1670 train_time:47829ms step_avg:92.16ms +step:520/1670 train_time:47921ms step_avg:92.16ms +step:521/1670 train_time:48011ms step_avg:92.15ms +step:522/1670 train_time:48103ms step_avg:92.15ms +step:523/1670 train_time:48194ms step_avg:92.15ms +step:524/1670 train_time:48285ms step_avg:92.15ms +step:525/1670 train_time:48376ms step_avg:92.15ms +step:526/1670 train_time:48467ms step_avg:92.14ms +step:527/1670 train_time:48558ms step_avg:92.14ms +step:528/1670 train_time:48648ms step_avg:92.14ms +step:529/1670 train_time:48739ms step_avg:92.13ms +step:530/1670 train_time:48830ms step_avg:92.13ms +step:531/1670 train_time:48922ms step_avg:92.13ms +step:532/1670 train_time:49013ms step_avg:92.13ms +step:533/1670 train_time:49105ms step_avg:92.13ms +step:534/1670 train_time:49197ms step_avg:92.13ms +step:535/1670 train_time:49287ms step_avg:92.13ms +step:536/1670 train_time:49378ms step_avg:92.12ms +step:537/1670 train_time:49469ms step_avg:92.12ms +step:538/1670 train_time:49560ms step_avg:92.12ms +step:539/1670 train_time:49650ms step_avg:92.12ms +step:540/1670 train_time:49742ms step_avg:92.11ms +step:541/1670 train_time:49832ms step_avg:92.11ms +step:542/1670 train_time:49923ms step_avg:92.11ms +step:543/1670 train_time:50014ms step_avg:92.11ms +step:544/1670 train_time:50106ms step_avg:92.11ms +step:545/1670 train_time:50199ms step_avg:92.11ms +step:546/1670 train_time:50290ms step_avg:92.11ms +step:547/1670 train_time:50381ms step_avg:92.11ms +step:548/1670 train_time:50472ms step_avg:92.10ms +step:549/1670 train_time:50563ms step_avg:92.10ms +step:550/1670 train_time:50654ms step_avg:92.10ms +step:551/1670 train_time:50746ms step_avg:92.10ms +step:552/1670 train_time:50837ms step_avg:92.10ms +step:553/1670 train_time:50928ms step_avg:92.09ms +step:554/1670 train_time:51020ms step_avg:92.09ms +step:555/1670 train_time:51110ms step_avg:92.09ms +step:556/1670 train_time:51202ms step_avg:92.09ms +step:557/1670 train_time:51293ms step_avg:92.09ms +step:558/1670 train_time:51577ms step_avg:92.43ms +step:559/1670 train_time:51648ms step_avg:92.39ms +step:560/1670 train_time:51739ms step_avg:92.39ms +step:561/1670 train_time:51830ms step_avg:92.39ms +step:562/1670 train_time:51921ms step_avg:92.39ms +step:563/1670 train_time:52012ms step_avg:92.38ms +step:564/1670 train_time:52103ms step_avg:92.38ms +step:565/1670 train_time:52195ms step_avg:92.38ms +step:566/1670 train_time:52286ms step_avg:92.38ms +step:567/1670 train_time:52378ms step_avg:92.38ms +step:568/1670 train_time:52474ms step_avg:92.38ms +step:569/1670 train_time:52570ms step_avg:92.39ms +step:570/1670 train_time:52665ms step_avg:92.39ms +step:571/1670 train_time:52758ms step_avg:92.40ms +step:572/1670 train_time:52850ms step_avg:92.39ms +step:573/1670 train_time:52942ms step_avg:92.39ms +step:574/1670 train_time:53032ms step_avg:92.39ms +step:575/1670 train_time:53125ms step_avg:92.39ms +step:576/1670 train_time:53216ms step_avg:92.39ms +step:577/1670 train_time:53308ms 
step_avg:92.39ms +step:578/1670 train_time:53400ms step_avg:92.39ms +step:579/1670 train_time:53494ms step_avg:92.39ms +step:580/1670 train_time:53589ms step_avg:92.39ms +step:581/1670 train_time:53683ms step_avg:92.40ms +step:582/1670 train_time:53775ms step_avg:92.40ms +step:583/1670 train_time:53867ms step_avg:92.40ms +step:584/1670 train_time:53959ms step_avg:92.39ms +step:585/1670 train_time:54050ms step_avg:92.39ms +step:586/1670 train_time:54141ms step_avg:92.39ms +step:587/1670 train_time:54232ms step_avg:92.39ms +step:588/1670 train_time:54325ms step_avg:92.39ms +step:589/1670 train_time:54419ms step_avg:92.39ms +step:590/1670 train_time:54511ms step_avg:92.39ms +step:591/1670 train_time:54604ms step_avg:92.39ms +step:592/1670 train_time:54698ms step_avg:92.39ms +step:593/1670 train_time:54790ms step_avg:92.39ms +step:594/1670 train_time:54883ms step_avg:92.40ms +step:595/1670 train_time:54975ms step_avg:92.40ms +step:596/1670 train_time:55067ms step_avg:92.39ms +step:597/1670 train_time:55158ms step_avg:92.39ms +step:598/1670 train_time:55250ms step_avg:92.39ms +step:599/1670 train_time:55342ms step_avg:92.39ms +step:600/1670 train_time:55435ms step_avg:92.39ms +step:601/1670 train_time:55528ms step_avg:92.39ms +step:602/1670 train_time:55622ms step_avg:92.39ms +step:603/1670 train_time:55715ms step_avg:92.40ms +step:604/1670 train_time:55808ms step_avg:92.40ms +step:605/1670 train_time:55901ms step_avg:92.40ms +step:606/1670 train_time:55993ms step_avg:92.40ms +step:607/1670 train_time:56087ms step_avg:92.40ms +step:608/1670 train_time:56179ms step_avg:92.40ms +step:609/1670 train_time:56270ms step_avg:92.40ms +step:610/1670 train_time:56363ms step_avg:92.40ms +step:611/1670 train_time:56455ms step_avg:92.40ms +step:612/1670 train_time:56547ms step_avg:92.40ms +step:613/1670 train_time:56639ms step_avg:92.40ms +step:614/1670 train_time:56731ms step_avg:92.40ms +step:615/1670 train_time:56824ms step_avg:92.40ms +step:616/1670 train_time:56918ms step_avg:92.40ms +step:617/1670 train_time:57010ms step_avg:92.40ms +step:618/1670 train_time:57103ms step_avg:92.40ms +step:619/1670 train_time:57194ms step_avg:92.40ms +step:620/1670 train_time:57287ms step_avg:92.40ms +step:621/1670 train_time:57379ms step_avg:92.40ms +step:622/1670 train_time:57473ms step_avg:92.40ms +step:623/1670 train_time:57566ms step_avg:92.40ms +step:624/1670 train_time:57659ms step_avg:92.40ms +step:625/1670 train_time:57751ms step_avg:92.40ms +step:625/1670 val_loss:3.6159 train_time:57844ms step_avg:92.55ms +step:626/1670 train_time:57861ms step_avg:92.43ms +step:627/1670 train_time:57937ms step_avg:92.40ms +step:628/1670 train_time:58041ms step_avg:92.42ms +step:629/1670 train_time:58136ms step_avg:92.43ms +step:630/1670 train_time:58229ms step_avg:92.43ms +step:631/1670 train_time:58320ms step_avg:92.43ms +step:632/1670 train_time:58412ms step_avg:92.42ms +step:633/1670 train_time:58502ms step_avg:92.42ms +step:634/1670 train_time:58593ms step_avg:92.42ms +step:635/1670 train_time:58684ms step_avg:92.42ms +step:636/1670 train_time:58776ms step_avg:92.41ms +step:637/1670 train_time:58869ms step_avg:92.42ms +step:638/1670 train_time:58962ms step_avg:92.42ms +step:639/1670 train_time:59199ms step_avg:92.64ms +step:640/1670 train_time:59274ms step_avg:92.62ms +step:641/1670 train_time:59364ms step_avg:92.61ms +step:642/1670 train_time:59455ms step_avg:92.61ms +step:643/1670 train_time:59546ms step_avg:92.61ms +step:644/1670 train_time:59637ms step_avg:92.60ms +step:645/1670 train_time:59728ms step_avg:92.60ms 
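+# [Editor's sketch, not part of the recorded run: a standalone rendering of the
+# stable-then-decay schedule defined by get_lr above. The multiplier stays at
+# 1.0 for the first (1 - cooldown_frac) of training, then decays linearly
+# toward 0.1 over the final cooldown_frac of the run. lr_multiplier is an
+# illustrative helper name, not from the script.]
+def lr_multiplier(step, num_iterations=1670, cooldown_frac=0.5):
+    x = step / num_iterations
+    if x < 1 - cooldown_frac:
+        return 1.0
+    w = (1 - x) / cooldown_frac
+    return w * 1.0 + (1 - w) * 0.1
+assert lr_multiplier(0) == 1.0 and lr_multiplier(835) == 1.0  # stable phase
+assert abs(lr_multiplier(1669) - 0.1) < 2e-3  # end of the cooldown ramp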
+step:646/1670 train_time:59819ms step_avg:92.60ms +step:647/1670 train_time:59911ms step_avg:92.60ms +step:648/1670 train_time:60002ms step_avg:92.60ms +step:649/1670 train_time:60098ms step_avg:92.60ms +step:650/1670 train_time:60195ms step_avg:92.61ms +step:651/1670 train_time:60290ms step_avg:92.61ms +step:652/1670 train_time:60382ms step_avg:92.61ms +step:653/1670 train_time:60474ms step_avg:92.61ms +step:654/1670 train_time:60565ms step_avg:92.61ms +step:655/1670 train_time:60658ms step_avg:92.61ms +step:656/1670 train_time:60750ms step_avg:92.61ms +step:657/1670 train_time:60840ms step_avg:92.60ms +step:658/1670 train_time:60932ms step_avg:92.60ms +step:659/1670 train_time:61025ms step_avg:92.60ms +step:660/1670 train_time:61119ms step_avg:92.60ms +step:661/1670 train_time:61215ms step_avg:92.61ms +step:662/1670 train_time:61309ms step_avg:92.61ms +step:663/1670 train_time:61401ms step_avg:92.61ms +step:664/1670 train_time:61493ms step_avg:92.61ms +step:665/1670 train_time:61585ms step_avg:92.61ms +step:666/1670 train_time:61676ms step_avg:92.61ms +step:667/1670 train_time:61768ms step_avg:92.61ms +step:668/1670 train_time:61859ms step_avg:92.60ms +step:669/1670 train_time:61951ms step_avg:92.60ms +step:670/1670 train_time:62044ms step_avg:92.60ms +step:671/1670 train_time:62137ms step_avg:92.60ms +step:672/1670 train_time:62230ms step_avg:92.60ms +step:673/1670 train_time:62323ms step_avg:92.61ms +step:674/1670 train_time:62415ms step_avg:92.60ms +step:675/1670 train_time:62508ms step_avg:92.60ms +step:676/1670 train_time:62599ms step_avg:92.60ms +step:677/1670 train_time:62692ms step_avg:92.60ms +step:678/1670 train_time:62783ms step_avg:92.60ms +step:679/1670 train_time:62875ms step_avg:92.60ms +step:680/1670 train_time:62966ms step_avg:92.60ms +step:681/1670 train_time:63058ms step_avg:92.60ms +step:682/1670 train_time:63153ms step_avg:92.60ms +step:683/1670 train_time:63247ms step_avg:92.60ms +step:684/1670 train_time:63339ms step_avg:92.60ms +step:685/1670 train_time:63432ms step_avg:92.60ms +step:686/1670 train_time:63524ms step_avg:92.60ms +step:687/1670 train_time:63615ms step_avg:92.60ms +step:688/1670 train_time:63707ms step_avg:92.60ms +step:689/1670 train_time:63799ms step_avg:92.60ms +step:690/1670 train_time:63891ms step_avg:92.60ms +step:691/1670 train_time:63984ms step_avg:92.60ms +step:692/1670 train_time:64076ms step_avg:92.59ms +step:693/1670 train_time:64170ms step_avg:92.60ms +step:694/1670 train_time:64262ms step_avg:92.60ms +step:695/1670 train_time:64355ms step_avg:92.60ms +step:696/1670 train_time:64448ms step_avg:92.60ms +step:697/1670 train_time:64540ms step_avg:92.60ms +step:698/1670 train_time:64633ms step_avg:92.60ms +step:699/1670 train_time:64725ms step_avg:92.60ms +step:700/1670 train_time:64818ms step_avg:92.60ms +step:701/1670 train_time:64911ms step_avg:92.60ms +step:702/1670 train_time:65003ms step_avg:92.60ms +step:703/1670 train_time:65095ms step_avg:92.60ms +step:704/1670 train_time:65188ms step_avg:92.60ms +step:705/1670 train_time:65280ms step_avg:92.60ms +step:706/1670 train_time:65373ms step_avg:92.60ms +step:707/1670 train_time:65466ms step_avg:92.60ms +step:708/1670 train_time:65558ms step_avg:92.60ms +step:709/1670 train_time:65651ms step_avg:92.60ms +step:710/1670 train_time:65743ms step_avg:92.60ms +step:711/1670 train_time:65836ms step_avg:92.60ms +step:712/1670 train_time:65928ms step_avg:92.60ms +step:713/1670 train_time:66020ms step_avg:92.59ms +step:714/1670 train_time:66112ms step_avg:92.59ms +step:715/1670 train_time:66205ms 
step_avg:92.59ms +step:716/1670 train_time:66297ms step_avg:92.59ms +step:717/1670 train_time:66389ms step_avg:92.59ms +step:718/1670 train_time:66482ms step_avg:92.59ms +step:719/1670 train_time:66574ms step_avg:92.59ms +step:720/1670 train_time:66668ms step_avg:92.59ms +step:721/1670 train_time:66759ms step_avg:92.59ms +step:722/1670 train_time:66853ms step_avg:92.59ms +step:723/1670 train_time:66946ms step_avg:92.60ms +step:724/1670 train_time:67038ms step_avg:92.59ms +step:725/1670 train_time:67131ms step_avg:92.59ms +step:726/1670 train_time:67223ms step_avg:92.59ms +step:727/1670 train_time:67315ms step_avg:92.59ms +step:728/1670 train_time:67409ms step_avg:92.59ms +step:729/1670 train_time:67502ms step_avg:92.60ms +step:730/1670 train_time:67595ms step_avg:92.60ms +step:731/1670 train_time:67687ms step_avg:92.60ms +step:732/1670 train_time:67780ms step_avg:92.59ms +step:733/1670 train_time:67873ms step_avg:92.60ms +step:734/1670 train_time:67964ms step_avg:92.59ms +step:735/1670 train_time:68056ms step_avg:92.59ms +step:736/1670 train_time:68150ms step_avg:92.59ms +step:737/1670 train_time:68241ms step_avg:92.59ms +step:738/1670 train_time:68334ms step_avg:92.59ms +step:739/1670 train_time:68427ms step_avg:92.59ms +step:740/1670 train_time:68519ms step_avg:92.59ms +step:741/1670 train_time:68613ms step_avg:92.60ms +step:742/1670 train_time:68706ms step_avg:92.60ms +step:743/1670 train_time:68798ms step_avg:92.59ms +step:744/1670 train_time:68891ms step_avg:92.59ms +step:745/1670 train_time:68982ms step_avg:92.59ms +step:746/1670 train_time:69075ms step_avg:92.59ms +step:747/1670 train_time:69167ms step_avg:92.59ms +step:748/1670 train_time:69259ms step_avg:92.59ms +step:749/1670 train_time:69352ms step_avg:92.59ms +step:750/1670 train_time:69444ms step_avg:92.59ms +step:750/1670 val_loss:3.5620 train_time:69537ms step_avg:92.72ms +step:751/1670 train_time:69555ms step_avg:92.62ms +step:752/1670 train_time:69632ms step_avg:92.60ms +step:753/1670 train_time:69725ms step_avg:92.60ms +step:754/1670 train_time:69816ms step_avg:92.59ms +step:755/1670 train_time:69908ms step_avg:92.59ms +step:756/1670 train_time:70000ms step_avg:92.59ms +step:757/1670 train_time:70092ms step_avg:92.59ms +step:758/1670 train_time:70184ms step_avg:92.59ms +step:759/1670 train_time:70276ms step_avg:92.59ms +step:760/1670 train_time:70369ms step_avg:92.59ms +step:761/1670 train_time:70463ms step_avg:92.59ms +step:762/1670 train_time:70556ms step_avg:92.59ms +step:763/1670 train_time:70649ms step_avg:92.59ms +step:764/1670 train_time:70743ms step_avg:92.60ms +step:765/1670 train_time:70835ms step_avg:92.59ms +step:766/1670 train_time:70926ms step_avg:92.59ms +step:767/1670 train_time:71019ms step_avg:92.59ms +step:768/1670 train_time:71111ms step_avg:92.59ms +step:769/1670 train_time:71204ms step_avg:92.59ms +step:770/1670 train_time:71295ms step_avg:92.59ms +step:771/1670 train_time:71388ms step_avg:92.59ms +step:772/1670 train_time:71481ms step_avg:92.59ms +step:773/1670 train_time:71574ms step_avg:92.59ms +step:774/1670 train_time:71667ms step_avg:92.59ms +step:775/1670 train_time:71760ms step_avg:92.59ms +step:776/1670 train_time:71854ms step_avg:92.59ms +step:777/1670 train_time:71946ms step_avg:92.59ms +step:778/1670 train_time:72037ms step_avg:92.59ms +step:779/1670 train_time:72129ms step_avg:92.59ms +step:780/1670 train_time:72221ms step_avg:92.59ms +step:781/1670 train_time:72314ms step_avg:92.59ms +step:782/1670 train_time:72407ms step_avg:92.59ms +step:783/1670 train_time:72500ms step_avg:92.59ms 
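+# [Editor's sketch, not part of the recorded run: the window-size schedule
+# defined by get_ws above splits training into equal thirds over
+# ws_schedule=(3, 7, 11) window sizes, switching to ws_validate=13 only for
+# the final validation step. window_size is an illustrative helper name.]
+def window_size(step, num_iterations=1670, schedule=(3, 7, 11), validate=13):
+    if step == num_iterations:
+        return validate  # larger window for the final validation only
+    x = step / (1 + num_iterations)
+    return schedule[int(len(schedule) * x)]
+assert window_size(0) == 3 and window_size(835) == 7
+assert window_size(1500) == 11 and window_size(1670) == 13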
+step:784/1670 train_time:72595ms step_avg:92.60ms +step:785/1670 train_time:72688ms step_avg:92.60ms +step:786/1670 train_time:72780ms step_avg:92.60ms +step:787/1670 train_time:72874ms step_avg:92.60ms +step:788/1670 train_time:72967ms step_avg:92.60ms +step:789/1670 train_time:73058ms step_avg:92.60ms +step:790/1670 train_time:73149ms step_avg:92.59ms +step:791/1670 train_time:73241ms step_avg:92.59ms +step:792/1670 train_time:73335ms step_avg:92.59ms +step:793/1670 train_time:73429ms step_avg:92.60ms +step:794/1670 train_time:73521ms step_avg:92.60ms +step:795/1670 train_time:73615ms step_avg:92.60ms +step:796/1670 train_time:73708ms step_avg:92.60ms +step:797/1670 train_time:73800ms step_avg:92.60ms +step:798/1670 train_time:73894ms step_avg:92.60ms +step:799/1670 train_time:73987ms step_avg:92.60ms +step:800/1670 train_time:74079ms step_avg:92.60ms +step:801/1670 train_time:74171ms step_avg:92.60ms +step:802/1670 train_time:74263ms step_avg:92.60ms +step:803/1670 train_time:74356ms step_avg:92.60ms +step:804/1670 train_time:74448ms step_avg:92.60ms +step:805/1670 train_time:74540ms step_avg:92.60ms +step:806/1670 train_time:74634ms step_avg:92.60ms +step:807/1670 train_time:74727ms step_avg:92.60ms +step:808/1670 train_time:74819ms step_avg:92.60ms +step:809/1670 train_time:74913ms step_avg:92.60ms +step:810/1670 train_time:75005ms step_avg:92.60ms +step:811/1670 train_time:75097ms step_avg:92.60ms +step:812/1670 train_time:75190ms step_avg:92.60ms +step:813/1670 train_time:75282ms step_avg:92.60ms +step:814/1670 train_time:75374ms step_avg:92.60ms +step:815/1670 train_time:75467ms step_avg:92.60ms +step:816/1670 train_time:75559ms step_avg:92.60ms +step:817/1670 train_time:75651ms step_avg:92.60ms +step:818/1670 train_time:75743ms step_avg:92.60ms +step:819/1670 train_time:75836ms step_avg:92.60ms +step:820/1670 train_time:75928ms step_avg:92.60ms +step:821/1670 train_time:76020ms step_avg:92.59ms +step:822/1670 train_time:76114ms step_avg:92.60ms +step:823/1670 train_time:76207ms step_avg:92.60ms +step:824/1670 train_time:76298ms step_avg:92.59ms +step:825/1670 train_time:76391ms step_avg:92.60ms +step:826/1670 train_time:76484ms step_avg:92.60ms +step:827/1670 train_time:76576ms step_avg:92.59ms +step:828/1670 train_time:76669ms step_avg:92.59ms +step:829/1670 train_time:76761ms step_avg:92.59ms +step:830/1670 train_time:76853ms step_avg:92.59ms +step:831/1670 train_time:76947ms step_avg:92.60ms +step:832/1670 train_time:77038ms step_avg:92.59ms +step:833/1670 train_time:77131ms step_avg:92.59ms +step:834/1670 train_time:77223ms step_avg:92.59ms +step:835/1670 train_time:77316ms step_avg:92.59ms +step:836/1670 train_time:77409ms step_avg:92.59ms +step:837/1670 train_time:77500ms step_avg:92.59ms +step:838/1670 train_time:77594ms step_avg:92.59ms +step:839/1670 train_time:77687ms step_avg:92.59ms +step:840/1670 train_time:77779ms step_avg:92.59ms +step:841/1670 train_time:77872ms step_avg:92.59ms +step:842/1670 train_time:77965ms step_avg:92.59ms +step:843/1670 train_time:78057ms step_avg:92.59ms +step:844/1670 train_time:78149ms step_avg:92.59ms +step:845/1670 train_time:78241ms step_avg:92.59ms +step:846/1670 train_time:78334ms step_avg:92.59ms +step:847/1670 train_time:78427ms step_avg:92.59ms +step:848/1670 train_time:78518ms step_avg:92.59ms +step:849/1670 train_time:78612ms step_avg:92.59ms +step:850/1670 train_time:78705ms step_avg:92.59ms +step:851/1670 train_time:78955ms step_avg:92.78ms +step:852/1670 train_time:79029ms step_avg:92.76ms +step:853/1670 train_time:79120ms 
step_avg:92.75ms +step:854/1670 train_time:79211ms step_avg:92.75ms +step:855/1670 train_time:79302ms step_avg:92.75ms +step:856/1670 train_time:79393ms step_avg:92.75ms +step:857/1670 train_time:79484ms step_avg:92.75ms +step:858/1670 train_time:79576ms step_avg:92.75ms +step:859/1670 train_time:79667ms step_avg:92.74ms +step:860/1670 train_time:79758ms step_avg:92.74ms +step:861/1670 train_time:79854ms step_avg:92.75ms +step:862/1670 train_time:79953ms step_avg:92.75ms +step:863/1670 train_time:80047ms step_avg:92.75ms +step:864/1670 train_time:80140ms step_avg:92.75ms +step:865/1670 train_time:80232ms step_avg:92.75ms +step:866/1670 train_time:80323ms step_avg:92.75ms +step:867/1670 train_time:80415ms step_avg:92.75ms +step:868/1670 train_time:80508ms step_avg:92.75ms +step:869/1670 train_time:80599ms step_avg:92.75ms +step:870/1670 train_time:80690ms step_avg:92.75ms +step:871/1670 train_time:80782ms step_avg:92.75ms +step:872/1670 train_time:80877ms step_avg:92.75ms +step:873/1670 train_time:80972ms step_avg:92.75ms +step:874/1670 train_time:81066ms step_avg:92.75ms +step:875/1670 train_time:81158ms step_avg:92.75ms +step:875/1670 val_loss:3.5192 train_time:81250ms step_avg:92.86ms +step:876/1670 train_time:81268ms step_avg:92.77ms +step:877/1670 train_time:81345ms step_avg:92.75ms +step:878/1670 train_time:81438ms step_avg:92.75ms +step:879/1670 train_time:81530ms step_avg:92.75ms +step:880/1670 train_time:81621ms step_avg:92.75ms +step:881/1670 train_time:81712ms step_avg:92.75ms +step:882/1670 train_time:81803ms step_avg:92.75ms +step:883/1670 train_time:81894ms step_avg:92.75ms +step:884/1670 train_time:81986ms step_avg:92.74ms +step:885/1670 train_time:82077ms step_avg:92.74ms +step:886/1670 train_time:82171ms step_avg:92.74ms +step:887/1670 train_time:82267ms step_avg:92.75ms +step:888/1670 train_time:82361ms step_avg:92.75ms +step:889/1670 train_time:82453ms step_avg:92.75ms +step:890/1670 train_time:82547ms step_avg:92.75ms +step:891/1670 train_time:82639ms step_avg:92.75ms +step:892/1670 train_time:82731ms step_avg:92.75ms +step:893/1670 train_time:82823ms step_avg:92.75ms +step:894/1670 train_time:82914ms step_avg:92.75ms +step:895/1670 train_time:83005ms step_avg:92.74ms +step:896/1670 train_time:83097ms step_avg:92.74ms +step:897/1670 train_time:83191ms step_avg:92.74ms +step:898/1670 train_time:83285ms step_avg:92.74ms +step:899/1670 train_time:83377ms step_avg:92.74ms +step:900/1670 train_time:83471ms step_avg:92.75ms +step:901/1670 train_time:83564ms step_avg:92.75ms +step:902/1670 train_time:83656ms step_avg:92.75ms +step:903/1670 train_time:83749ms step_avg:92.75ms +step:904/1670 train_time:83841ms step_avg:92.74ms +step:905/1670 train_time:83933ms step_avg:92.74ms +step:906/1670 train_time:84025ms step_avg:92.74ms +step:907/1670 train_time:84118ms step_avg:92.74ms +step:908/1670 train_time:84212ms step_avg:92.74ms +step:909/1670 train_time:84306ms step_avg:92.75ms +step:910/1670 train_time:84399ms step_avg:92.75ms +step:911/1670 train_time:84493ms step_avg:92.75ms +step:912/1670 train_time:84585ms step_avg:92.75ms +step:913/1670 train_time:84677ms step_avg:92.75ms +step:914/1670 train_time:84770ms step_avg:92.75ms +step:915/1670 train_time:84862ms step_avg:92.75ms +step:916/1670 train_time:84953ms step_avg:92.74ms +step:917/1670 train_time:85046ms step_avg:92.74ms +step:918/1670 train_time:85139ms step_avg:92.74ms +step:919/1670 train_time:85232ms step_avg:92.74ms +step:920/1670 train_time:85325ms step_avg:92.74ms +step:921/1670 train_time:85418ms step_avg:92.74ms 
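+# [Editor's sketch, not part of the recorded run: the training loop above warms
+# up Muon's momentum linearly from 0.85 to 0.95 over the first 300 steps, then
+# holds it at 0.95. muon_momentum is an illustrative helper name.]
+def muon_momentum(step):
+    frac = min(step / 300, 1)
+    return (1 - frac) * 0.85 + frac * 0.95
+assert muon_momentum(0) == 0.85
+assert abs(muon_momentum(150) - 0.90) < 1e-12
+assert muon_momentum(300) == 0.95 and muon_momentum(1000) == 0.95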
+step:922/1670 train_time:85511ms step_avg:92.75ms +step:923/1670 train_time:85604ms step_avg:92.75ms +step:924/1670 train_time:85696ms step_avg:92.74ms +step:925/1670 train_time:85789ms step_avg:92.74ms +step:926/1670 train_time:85881ms step_avg:92.74ms +step:927/1670 train_time:85974ms step_avg:92.74ms +step:928/1670 train_time:86067ms step_avg:92.74ms +step:929/1670 train_time:86159ms step_avg:92.74ms +step:930/1670 train_time:86252ms step_avg:92.74ms +step:931/1670 train_time:86344ms step_avg:92.74ms +step:932/1670 train_time:86437ms step_avg:92.74ms +step:933/1670 train_time:86530ms step_avg:92.74ms +step:934/1670 train_time:86624ms step_avg:92.74ms +step:935/1670 train_time:86716ms step_avg:92.74ms +step:936/1670 train_time:86808ms step_avg:92.74ms +step:937/1670 train_time:86900ms step_avg:92.74ms +step:938/1670 train_time:86992ms step_avg:92.74ms +step:939/1670 train_time:87085ms step_avg:92.74ms +step:940/1670 train_time:87177ms step_avg:92.74ms +step:941/1670 train_time:87269ms step_avg:92.74ms +step:942/1670 train_time:87362ms step_avg:92.74ms +step:943/1670 train_time:87455ms step_avg:92.74ms +step:944/1670 train_time:87548ms step_avg:92.74ms +step:945/1670 train_time:87641ms step_avg:92.74ms +step:946/1670 train_time:87734ms step_avg:92.74ms +step:947/1670 train_time:87827ms step_avg:92.74ms +step:948/1670 train_time:87919ms step_avg:92.74ms +step:949/1670 train_time:88012ms step_avg:92.74ms +step:950/1670 train_time:88105ms step_avg:92.74ms +step:951/1670 train_time:88197ms step_avg:92.74ms +step:952/1670 train_time:88290ms step_avg:92.74ms +step:953/1670 train_time:88382ms step_avg:92.74ms +step:954/1670 train_time:88475ms step_avg:92.74ms +step:955/1670 train_time:88567ms step_avg:92.74ms +step:956/1670 train_time:88660ms step_avg:92.74ms +step:957/1670 train_time:88753ms step_avg:92.74ms +step:958/1670 train_time:88845ms step_avg:92.74ms +step:959/1670 train_time:88937ms step_avg:92.74ms +step:960/1670 train_time:89029ms step_avg:92.74ms +step:961/1670 train_time:89121ms step_avg:92.74ms +step:962/1670 train_time:89213ms step_avg:92.74ms +step:963/1670 train_time:89306ms step_avg:92.74ms +step:964/1670 train_time:89398ms step_avg:92.74ms +step:965/1670 train_time:89492ms step_avg:92.74ms +step:966/1670 train_time:89584ms step_avg:92.74ms +step:967/1670 train_time:89676ms step_avg:92.74ms +step:968/1670 train_time:89771ms step_avg:92.74ms +step:969/1670 train_time:89864ms step_avg:92.74ms +step:970/1670 train_time:89956ms step_avg:92.74ms +step:971/1670 train_time:90049ms step_avg:92.74ms +step:972/1670 train_time:90141ms step_avg:92.74ms +step:973/1670 train_time:90234ms step_avg:92.74ms +step:974/1670 train_time:90326ms step_avg:92.74ms +step:975/1670 train_time:90419ms step_avg:92.74ms +step:976/1670 train_time:90511ms step_avg:92.74ms +step:977/1670 train_time:90604ms step_avg:92.74ms +step:978/1670 train_time:90697ms step_avg:92.74ms +step:979/1670 train_time:90790ms step_avg:92.74ms +step:980/1670 train_time:90882ms step_avg:92.74ms +step:981/1670 train_time:90974ms step_avg:92.74ms +step:982/1670 train_time:91068ms step_avg:92.74ms +step:983/1670 train_time:91160ms step_avg:92.74ms +step:984/1670 train_time:91253ms step_avg:92.74ms +step:985/1670 train_time:91346ms step_avg:92.74ms +step:986/1670 train_time:91438ms step_avg:92.74ms +step:987/1670 train_time:91532ms step_avg:92.74ms +step:988/1670 train_time:91625ms step_avg:92.74ms +step:989/1670 train_time:91717ms step_avg:92.74ms +step:990/1670 train_time:91810ms step_avg:92.74ms +step:991/1670 train_time:91902ms 
step_avg:92.74ms +step:992/1670 train_time:91995ms step_avg:92.74ms +step:993/1670 train_time:92088ms step_avg:92.74ms +step:994/1670 train_time:92180ms step_avg:92.74ms +step:995/1670 train_time:92273ms step_avg:92.74ms +step:996/1670 train_time:92366ms step_avg:92.74ms +step:997/1670 train_time:92458ms step_avg:92.74ms +step:998/1670 train_time:92552ms step_avg:92.74ms +step:999/1670 train_time:92645ms step_avg:92.74ms +step:1000/1670 train_time:92737ms step_avg:92.74ms +step:1000/1670 val_loss:3.4692 train_time:92831ms step_avg:92.83ms +step:1001/1670 train_time:92848ms step_avg:92.76ms +step:1002/1670 train_time:92926ms step_avg:92.74ms +step:1003/1670 train_time:93018ms step_avg:92.74ms +step:1004/1670 train_time:93111ms step_avg:92.74ms +step:1005/1670 train_time:93203ms step_avg:92.74ms +step:1006/1670 train_time:93296ms step_avg:92.74ms +step:1007/1670 train_time:93388ms step_avg:92.74ms +step:1008/1670 train_time:93479ms step_avg:92.74ms +step:1009/1670 train_time:93572ms step_avg:92.74ms +step:1010/1670 train_time:93664ms step_avg:92.74ms +step:1011/1670 train_time:93756ms step_avg:92.74ms +step:1012/1670 train_time:93850ms step_avg:92.74ms +step:1013/1670 train_time:93944ms step_avg:92.74ms +step:1014/1670 train_time:94036ms step_avg:92.74ms +step:1015/1670 train_time:94129ms step_avg:92.74ms +step:1016/1670 train_time:94221ms step_avg:92.74ms +step:1017/1670 train_time:94313ms step_avg:92.74ms +step:1018/1670 train_time:94406ms step_avg:92.74ms +step:1019/1670 train_time:94498ms step_avg:92.74ms +step:1020/1670 train_time:94590ms step_avg:92.74ms +step:1021/1670 train_time:94684ms step_avg:92.74ms +step:1022/1670 train_time:94776ms step_avg:92.74ms +step:1023/1670 train_time:94870ms step_avg:92.74ms +step:1024/1670 train_time:94963ms step_avg:92.74ms +step:1025/1670 train_time:95055ms step_avg:92.74ms +step:1026/1670 train_time:95147ms step_avg:92.74ms +step:1027/1670 train_time:95239ms step_avg:92.74ms +step:1028/1670 train_time:95333ms step_avg:92.74ms +step:1029/1670 train_time:95425ms step_avg:92.74ms +step:1030/1670 train_time:95517ms step_avg:92.74ms +step:1031/1670 train_time:95610ms step_avg:92.74ms +step:1032/1670 train_time:95702ms step_avg:92.73ms +step:1033/1670 train_time:95794ms step_avg:92.73ms +step:1034/1670 train_time:95887ms step_avg:92.73ms +step:1035/1670 train_time:95980ms step_avg:92.73ms +step:1036/1670 train_time:96073ms step_avg:92.73ms +step:1037/1670 train_time:96165ms step_avg:92.73ms +step:1038/1670 train_time:96257ms step_avg:92.73ms +step:1039/1670 train_time:96349ms step_avg:92.73ms +step:1040/1670 train_time:96441ms step_avg:92.73ms +step:1041/1670 train_time:96535ms step_avg:92.73ms +step:1042/1670 train_time:96627ms step_avg:92.73ms +step:1043/1670 train_time:96719ms step_avg:92.73ms +step:1044/1670 train_time:96813ms step_avg:92.73ms +step:1045/1670 train_time:96906ms step_avg:92.73ms +step:1046/1670 train_time:96998ms step_avg:92.73ms +step:1047/1670 train_time:97091ms step_avg:92.73ms +step:1048/1670 train_time:97183ms step_avg:92.73ms +step:1049/1670 train_time:97275ms step_avg:92.73ms +step:1050/1670 train_time:97368ms step_avg:92.73ms +step:1051/1670 train_time:97461ms step_avg:92.73ms +step:1052/1670 train_time:97553ms step_avg:92.73ms +step:1053/1670 train_time:97646ms step_avg:92.73ms +step:1054/1670 train_time:97738ms step_avg:92.73ms +step:1055/1670 train_time:97831ms step_avg:92.73ms +step:1056/1670 train_time:97924ms step_avg:92.73ms +step:1057/1670 train_time:98017ms step_avg:92.73ms +step:1058/1670 train_time:98109ms 
step_avg:92.73ms +step:1059/1670 train_time:98201ms step_avg:92.73ms +step:1060/1670 train_time:98294ms step_avg:92.73ms +step:1061/1670 train_time:98386ms step_avg:92.73ms +step:1062/1670 train_time:98634ms step_avg:92.88ms +step:1063/1670 train_time:98714ms step_avg:92.86ms +step:1064/1670 train_time:98805ms step_avg:92.86ms +step:1065/1670 train_time:98896ms step_avg:92.86ms +step:1066/1670 train_time:98987ms step_avg:92.86ms +step:1067/1670 train_time:99078ms step_avg:92.86ms +step:1068/1670 train_time:99169ms step_avg:92.85ms +step:1069/1670 train_time:99260ms step_avg:92.85ms +step:1070/1670 train_time:99352ms step_avg:92.85ms +step:1071/1670 train_time:99443ms step_avg:92.85ms +step:1072/1670 train_time:99542ms step_avg:92.86ms +step:1073/1670 train_time:99638ms step_avg:92.86ms +step:1074/1670 train_time:99733ms step_avg:92.86ms +step:1075/1670 train_time:99826ms step_avg:92.86ms +step:1076/1670 train_time:99917ms step_avg:92.86ms +step:1077/1670 train_time:100009ms step_avg:92.86ms +step:1078/1670 train_time:100100ms step_avg:92.86ms +step:1079/1670 train_time:100192ms step_avg:92.86ms +step:1080/1670 train_time:100283ms step_avg:92.85ms +step:1081/1670 train_time:100374ms step_avg:92.85ms +step:1082/1670 train_time:100469ms step_avg:92.85ms +step:1083/1670 train_time:100563ms step_avg:92.86ms +step:1084/1670 train_time:100658ms step_avg:92.86ms +step:1085/1670 train_time:100752ms step_avg:92.86ms +step:1086/1670 train_time:100844ms step_avg:92.86ms +step:1087/1670 train_time:100936ms step_avg:92.86ms +step:1088/1670 train_time:101027ms step_avg:92.86ms +step:1089/1670 train_time:101119ms step_avg:92.86ms +step:1090/1670 train_time:101211ms step_avg:92.85ms +step:1091/1670 train_time:101302ms step_avg:92.85ms +step:1092/1670 train_time:101395ms step_avg:92.85ms +step:1093/1670 train_time:101488ms step_avg:92.85ms +step:1094/1670 train_time:101580ms step_avg:92.85ms +step:1095/1670 train_time:101674ms step_avg:92.85ms +step:1096/1670 train_time:101768ms step_avg:92.85ms +step:1097/1670 train_time:101861ms step_avg:92.85ms +step:1098/1670 train_time:101954ms step_avg:92.85ms +step:1099/1670 train_time:102045ms step_avg:92.85ms +step:1100/1670 train_time:102137ms step_avg:92.85ms +step:1101/1670 train_time:102228ms step_avg:92.85ms +step:1102/1670 train_time:102320ms step_avg:92.85ms +step:1103/1670 train_time:102412ms step_avg:92.85ms +step:1104/1670 train_time:102505ms step_avg:92.85ms +step:1105/1670 train_time:102598ms step_avg:92.85ms +step:1106/1670 train_time:102692ms step_avg:92.85ms +step:1107/1670 train_time:102784ms step_avg:92.85ms +step:1108/1670 train_time:102877ms step_avg:92.85ms +step:1109/1670 train_time:102970ms step_avg:92.85ms +step:1110/1670 train_time:103062ms step_avg:92.85ms +step:1111/1670 train_time:103154ms step_avg:92.85ms +step:1112/1670 train_time:103246ms step_avg:92.85ms +step:1113/1670 train_time:103336ms step_avg:92.84ms +step:1114/1670 train_time:103429ms step_avg:92.84ms +step:1115/1670 train_time:103715ms step_avg:93.02ms +step:1116/1670 train_time:103788ms step_avg:93.00ms +step:1117/1670 train_time:103879ms step_avg:93.00ms +step:1118/1670 train_time:103971ms step_avg:93.00ms +step:1119/1670 train_time:104063ms step_avg:93.00ms +step:1120/1670 train_time:104154ms step_avg:93.00ms +step:1121/1670 train_time:104246ms step_avg:92.99ms +step:1122/1670 train_time:104338ms step_avg:92.99ms +step:1123/1670 train_time:104430ms step_avg:92.99ms +step:1124/1670 train_time:104522ms step_avg:92.99ms +step:1125/1670 train_time:104622ms step_avg:93.00ms 
+step:1125/1670 val_loss:3.4164 train_time:104722ms step_avg:93.09ms +step:1126/1670 train_time:104740ms step_avg:93.02ms +step:1127/1670 train_time:104821ms step_avg:93.01ms +step:1128/1670 train_time:104921ms step_avg:93.01ms +step:1129/1670 train_time:105015ms step_avg:93.02ms +step:1130/1670 train_time:105108ms step_avg:93.02ms +step:1131/1670 train_time:105199ms step_avg:93.01ms +step:1132/1670 train_time:105292ms step_avg:93.01ms +step:1133/1670 train_time:105383ms step_avg:93.01ms +step:1134/1670 train_time:105475ms step_avg:93.01ms +step:1135/1670 train_time:105567ms step_avg:93.01ms +step:1136/1670 train_time:105662ms step_avg:93.01ms +step:1137/1670 train_time:105758ms step_avg:93.01ms +step:1138/1670 train_time:105855ms step_avg:93.02ms +step:1139/1670 train_time:105950ms step_avg:93.02ms +step:1140/1670 train_time:106044ms step_avg:93.02ms +step:1141/1670 train_time:106136ms step_avg:93.02ms +step:1142/1670 train_time:106228ms step_avg:93.02ms +step:1143/1670 train_time:106320ms step_avg:93.02ms +step:1144/1670 train_time:106413ms step_avg:93.02ms +step:1145/1670 train_time:106504ms step_avg:93.02ms +step:1146/1670 train_time:106597ms step_avg:93.02ms +step:1147/1670 train_time:106691ms step_avg:93.02ms +step:1148/1670 train_time:106785ms step_avg:93.02ms +step:1149/1670 train_time:106882ms step_avg:93.02ms +step:1150/1670 train_time:106976ms step_avg:93.02ms +step:1151/1670 train_time:107070ms step_avg:93.02ms +step:1152/1670 train_time:107162ms step_avg:93.02ms +step:1153/1670 train_time:107254ms step_avg:93.02ms +step:1154/1670 train_time:107348ms step_avg:93.02ms +step:1155/1670 train_time:107440ms step_avg:93.02ms +step:1156/1670 train_time:107532ms step_avg:93.02ms +step:1157/1670 train_time:107625ms step_avg:93.02ms +step:1158/1670 train_time:107718ms step_avg:93.02ms +step:1159/1670 train_time:107816ms step_avg:93.02ms +step:1160/1670 train_time:107911ms step_avg:93.03ms +step:1161/1670 train_time:108004ms step_avg:93.03ms +step:1162/1670 train_time:108098ms step_avg:93.03ms +step:1163/1670 train_time:108191ms step_avg:93.03ms +step:1164/1670 train_time:108283ms step_avg:93.03ms +step:1165/1670 train_time:108375ms step_avg:93.03ms +step:1166/1670 train_time:108468ms step_avg:93.03ms +step:1167/1670 train_time:108560ms step_avg:93.03ms +step:1168/1670 train_time:108654ms step_avg:93.03ms +step:1169/1670 train_time:108749ms step_avg:93.03ms +step:1170/1670 train_time:108843ms step_avg:93.03ms +step:1171/1670 train_time:108937ms step_avg:93.03ms +step:1172/1670 train_time:109032ms step_avg:93.03ms +step:1173/1670 train_time:109124ms step_avg:93.03ms +step:1174/1670 train_time:109218ms step_avg:93.03ms +step:1175/1670 train_time:109311ms step_avg:93.03ms +step:1176/1670 train_time:109403ms step_avg:93.03ms +step:1177/1670 train_time:109496ms step_avg:93.03ms +step:1178/1670 train_time:109588ms step_avg:93.03ms +step:1179/1670 train_time:109682ms step_avg:93.03ms +step:1180/1670 train_time:109775ms step_avg:93.03ms +step:1181/1670 train_time:109868ms step_avg:93.03ms +step:1182/1670 train_time:109962ms step_avg:93.03ms +step:1183/1670 train_time:110056ms step_avg:93.03ms +step:1184/1670 train_time:110149ms step_avg:93.03ms +step:1185/1670 train_time:110242ms step_avg:93.03ms +step:1186/1670 train_time:110335ms step_avg:93.03ms +step:1187/1670 train_time:110427ms step_avg:93.03ms +step:1188/1670 train_time:110520ms step_avg:93.03ms +step:1189/1670 train_time:110614ms step_avg:93.03ms +step:1190/1670 train_time:110707ms step_avg:93.03ms +step:1191/1670 train_time:110799ms 
step_avg:93.03ms +step:1192/1670 train_time:110893ms step_avg:93.03ms +step:1193/1670 train_time:110987ms step_avg:93.03ms +step:1194/1670 train_time:111081ms step_avg:93.03ms +step:1195/1670 train_time:111174ms step_avg:93.03ms +step:1196/1670 train_time:111267ms step_avg:93.03ms +step:1197/1670 train_time:111359ms step_avg:93.03ms +step:1198/1670 train_time:111452ms step_avg:93.03ms +step:1199/1670 train_time:111544ms step_avg:93.03ms +step:1200/1670 train_time:111637ms step_avg:93.03ms +step:1201/1670 train_time:111731ms step_avg:93.03ms +step:1202/1670 train_time:111824ms step_avg:93.03ms +step:1203/1670 train_time:111918ms step_avg:93.03ms +step:1204/1670 train_time:112012ms step_avg:93.03ms +step:1205/1670 train_time:112104ms step_avg:93.03ms +step:1206/1670 train_time:112198ms step_avg:93.03ms +step:1207/1670 train_time:112292ms step_avg:93.03ms +step:1208/1670 train_time:112385ms step_avg:93.03ms +step:1209/1670 train_time:112477ms step_avg:93.03ms +step:1210/1670 train_time:112570ms step_avg:93.03ms +step:1211/1670 train_time:112664ms step_avg:93.03ms +step:1212/1670 train_time:112757ms step_avg:93.03ms +step:1213/1670 train_time:112850ms step_avg:93.03ms +step:1214/1670 train_time:112944ms step_avg:93.03ms +step:1215/1670 train_time:113037ms step_avg:93.03ms +step:1216/1670 train_time:113131ms step_avg:93.04ms +step:1217/1670 train_time:113224ms step_avg:93.04ms +step:1218/1670 train_time:113317ms step_avg:93.04ms +step:1219/1670 train_time:113410ms step_avg:93.04ms +step:1220/1670 train_time:113503ms step_avg:93.04ms +step:1221/1670 train_time:113596ms step_avg:93.04ms +step:1222/1670 train_time:113689ms step_avg:93.03ms +step:1223/1670 train_time:113782ms step_avg:93.04ms +step:1224/1670 train_time:113875ms step_avg:93.04ms +step:1225/1670 train_time:113968ms step_avg:93.04ms +step:1226/1670 train_time:114061ms step_avg:93.04ms +step:1227/1670 train_time:114155ms step_avg:93.04ms +step:1228/1670 train_time:114248ms step_avg:93.04ms +step:1229/1670 train_time:114341ms step_avg:93.04ms +step:1230/1670 train_time:114434ms step_avg:93.04ms +step:1231/1670 train_time:114527ms step_avg:93.04ms +step:1232/1670 train_time:114620ms step_avg:93.04ms +step:1233/1670 train_time:114713ms step_avg:93.04ms +step:1234/1670 train_time:114807ms step_avg:93.04ms +step:1235/1670 train_time:114900ms step_avg:93.04ms +step:1236/1670 train_time:114993ms step_avg:93.04ms +step:1237/1670 train_time:115086ms step_avg:93.04ms +step:1238/1670 train_time:115179ms step_avg:93.04ms +step:1239/1670 train_time:115272ms step_avg:93.04ms +step:1240/1670 train_time:115365ms step_avg:93.04ms +step:1241/1670 train_time:115457ms step_avg:93.04ms +step:1242/1670 train_time:115550ms step_avg:93.04ms +step:1243/1670 train_time:115643ms step_avg:93.04ms +step:1244/1670 train_time:115737ms step_avg:93.04ms +step:1245/1670 train_time:115831ms step_avg:93.04ms +step:1246/1670 train_time:115924ms step_avg:93.04ms +step:1247/1670 train_time:116018ms step_avg:93.04ms +step:1248/1670 train_time:116111ms step_avg:93.04ms +step:1249/1670 train_time:116203ms step_avg:93.04ms +step:1250/1670 train_time:116297ms step_avg:93.04ms +step:1250/1670 val_loss:3.3777 train_time:116390ms step_avg:93.11ms +step:1251/1670 train_time:116407ms step_avg:93.05ms +step:1252/1670 train_time:116484ms step_avg:93.04ms +step:1253/1670 train_time:116578ms step_avg:93.04ms +step:1254/1670 train_time:116671ms step_avg:93.04ms +step:1255/1670 train_time:116763ms step_avg:93.04ms +step:1256/1670 train_time:116855ms step_avg:93.04ms +step:1257/1670 
train_time:116948ms step_avg:93.04ms +step:1258/1670 train_time:117042ms step_avg:93.04ms +step:1259/1670 train_time:117135ms step_avg:93.04ms +step:1260/1670 train_time:117229ms step_avg:93.04ms +step:1261/1670 train_time:117322ms step_avg:93.04ms +step:1262/1670 train_time:117416ms step_avg:93.04ms +step:1263/1670 train_time:117510ms step_avg:93.04ms +step:1264/1670 train_time:117603ms step_avg:93.04ms +step:1265/1670 train_time:117696ms step_avg:93.04ms +step:1266/1670 train_time:117788ms step_avg:93.04ms +step:1267/1670 train_time:117881ms step_avg:93.04ms +step:1268/1670 train_time:117974ms step_avg:93.04ms +step:1269/1670 train_time:118066ms step_avg:93.04ms +step:1270/1670 train_time:118160ms step_avg:93.04ms +step:1271/1670 train_time:118253ms step_avg:93.04ms +step:1272/1670 train_time:118347ms step_avg:93.04ms +step:1273/1670 train_time:118441ms step_avg:93.04ms +step:1274/1670 train_time:118678ms step_avg:93.15ms +step:1275/1670 train_time:118770ms step_avg:93.15ms +step:1276/1670 train_time:118862ms step_avg:93.15ms +step:1277/1670 train_time:118954ms step_avg:93.15ms +step:1278/1670 train_time:119045ms step_avg:93.15ms +step:1279/1670 train_time:119137ms step_avg:93.15ms +step:1280/1670 train_time:119229ms step_avg:93.15ms +step:1281/1670 train_time:119321ms step_avg:93.15ms +step:1282/1670 train_time:119413ms step_avg:93.15ms +step:1283/1670 train_time:119505ms step_avg:93.14ms +step:1284/1670 train_time:119602ms step_avg:93.15ms +step:1285/1670 train_time:119699ms step_avg:93.15ms +step:1286/1670 train_time:119793ms step_avg:93.15ms +step:1287/1670 train_time:119886ms step_avg:93.15ms +step:1288/1670 train_time:119979ms step_avg:93.15ms +step:1289/1670 train_time:120071ms step_avg:93.15ms +step:1290/1670 train_time:120163ms step_avg:93.15ms +step:1291/1670 train_time:120256ms step_avg:93.15ms +step:1292/1670 train_time:120347ms step_avg:93.15ms +step:1293/1670 train_time:120441ms step_avg:93.15ms +step:1294/1670 train_time:120535ms step_avg:93.15ms +step:1295/1670 train_time:120630ms step_avg:93.15ms +step:1296/1670 train_time:120726ms step_avg:93.15ms +step:1297/1670 train_time:120820ms step_avg:93.15ms +step:1298/1670 train_time:120912ms step_avg:93.15ms +step:1299/1670 train_time:121006ms step_avg:93.15ms +step:1300/1670 train_time:121099ms step_avg:93.15ms +step:1301/1670 train_time:121191ms step_avg:93.15ms +step:1302/1670 train_time:121283ms step_avg:93.15ms +step:1303/1670 train_time:121375ms step_avg:93.15ms +step:1304/1670 train_time:121467ms step_avg:93.15ms +step:1305/1670 train_time:121561ms step_avg:93.15ms +step:1306/1670 train_time:121655ms step_avg:93.15ms +step:1307/1670 train_time:121749ms step_avg:93.15ms +step:1308/1670 train_time:121842ms step_avg:93.15ms +step:1309/1670 train_time:121936ms step_avg:93.15ms +step:1310/1670 train_time:122029ms step_avg:93.15ms +step:1311/1670 train_time:122122ms step_avg:93.15ms +step:1312/1670 train_time:122214ms step_avg:93.15ms +step:1313/1670 train_time:122306ms step_avg:93.15ms +step:1314/1670 train_time:122398ms step_avg:93.15ms +step:1315/1670 train_time:122491ms step_avg:93.15ms +step:1316/1670 train_time:122586ms step_avg:93.15ms +step:1317/1670 train_time:122680ms step_avg:93.15ms +step:1318/1670 train_time:122774ms step_avg:93.15ms +step:1319/1670 train_time:122868ms step_avg:93.15ms +step:1320/1670 train_time:122962ms step_avg:93.15ms +step:1321/1670 train_time:123055ms step_avg:93.15ms +step:1322/1670 train_time:123147ms step_avg:93.15ms +step:1323/1670 train_time:123240ms step_avg:93.15ms +step:1324/1670 
train_time:123333ms step_avg:93.15ms +step:1325/1670 train_time:123426ms step_avg:93.15ms +step:1326/1670 train_time:123520ms step_avg:93.15ms +step:1327/1670 train_time:123614ms step_avg:93.15ms +step:1328/1670 train_time:123707ms step_avg:93.15ms +step:1329/1670 train_time:123801ms step_avg:93.15ms +step:1330/1670 train_time:123895ms step_avg:93.15ms +step:1331/1670 train_time:123988ms step_avg:93.15ms +step:1332/1670 train_time:124080ms step_avg:93.15ms +step:1333/1670 train_time:124173ms step_avg:93.15ms +step:1334/1670 train_time:124267ms step_avg:93.15ms +step:1335/1670 train_time:124360ms step_avg:93.15ms +step:1336/1670 train_time:124453ms step_avg:93.15ms +step:1337/1670 train_time:124546ms step_avg:93.15ms +step:1338/1670 train_time:124639ms step_avg:93.15ms +step:1339/1670 train_time:124733ms step_avg:93.15ms +step:1340/1670 train_time:124828ms step_avg:93.15ms +step:1341/1670 train_time:124921ms step_avg:93.16ms +step:1342/1670 train_time:125014ms step_avg:93.16ms +step:1343/1670 train_time:125107ms step_avg:93.15ms +step:1344/1670 train_time:125199ms step_avg:93.15ms +step:1345/1670 train_time:125291ms step_avg:93.15ms +step:1346/1670 train_time:125384ms step_avg:93.15ms +step:1347/1670 train_time:125478ms step_avg:93.15ms +step:1348/1670 train_time:125571ms step_avg:93.15ms +step:1349/1670 train_time:125664ms step_avg:93.15ms +step:1350/1670 train_time:125757ms step_avg:93.15ms +step:1351/1670 train_time:125850ms step_avg:93.15ms +step:1352/1670 train_time:125945ms step_avg:93.15ms +step:1353/1670 train_time:126038ms step_avg:93.15ms +step:1354/1670 train_time:126132ms step_avg:93.15ms +step:1355/1670 train_time:126224ms step_avg:93.15ms +step:1356/1670 train_time:126317ms step_avg:93.15ms +step:1357/1670 train_time:126409ms step_avg:93.15ms +step:1358/1670 train_time:126503ms step_avg:93.15ms +step:1359/1670 train_time:126597ms step_avg:93.15ms +step:1360/1670 train_time:126689ms step_avg:93.15ms +step:1361/1670 train_time:126782ms step_avg:93.15ms +step:1362/1670 train_time:126876ms step_avg:93.15ms +step:1363/1670 train_time:126969ms step_avg:93.15ms +step:1364/1670 train_time:127064ms step_avg:93.16ms +step:1365/1670 train_time:127158ms step_avg:93.16ms +step:1366/1670 train_time:127250ms step_avg:93.15ms +step:1367/1670 train_time:127343ms step_avg:93.16ms +step:1368/1670 train_time:127436ms step_avg:93.15ms +step:1369/1670 train_time:127529ms step_avg:93.15ms +step:1370/1670 train_time:127623ms step_avg:93.16ms +step:1371/1670 train_time:127716ms step_avg:93.16ms +step:1372/1670 train_time:127809ms step_avg:93.16ms +step:1373/1670 train_time:127903ms step_avg:93.16ms +step:1374/1670 train_time:127997ms step_avg:93.16ms +step:1375/1670 train_time:128090ms step_avg:93.16ms +step:1375/1670 val_loss:3.3432 train_time:128184ms step_avg:93.22ms +step:1376/1670 train_time:128201ms step_avg:93.17ms +step:1377/1670 train_time:128280ms step_avg:93.16ms +step:1378/1670 train_time:128375ms step_avg:93.16ms +step:1379/1670 train_time:128468ms step_avg:93.16ms +step:1380/1670 train_time:128559ms step_avg:93.16ms +step:1381/1670 train_time:128652ms step_avg:93.16ms +step:1382/1670 train_time:128745ms step_avg:93.16ms +step:1383/1670 train_time:128838ms step_avg:93.16ms +step:1384/1670 train_time:128931ms step_avg:93.16ms +step:1385/1670 train_time:129024ms step_avg:93.16ms +step:1386/1670 train_time:129118ms step_avg:93.16ms +step:1387/1670 train_time:129214ms step_avg:93.16ms +step:1388/1670 train_time:129309ms step_avg:93.16ms +step:1389/1670 train_time:129402ms step_avg:93.16ms 
+step:1390/1670 train_time:129495ms step_avg:93.16ms +step:1391/1670 train_time:129588ms step_avg:93.16ms +step:1392/1670 train_time:129680ms step_avg:93.16ms +step:1393/1670 train_time:129774ms step_avg:93.16ms +step:1394/1670 train_time:129867ms step_avg:93.16ms +step:1395/1670 train_time:129959ms step_avg:93.16ms +step:1396/1670 train_time:130052ms step_avg:93.16ms +step:1397/1670 train_time:130147ms step_avg:93.16ms +step:1398/1670 train_time:130240ms step_avg:93.16ms +step:1399/1670 train_time:130335ms step_avg:93.16ms +step:1400/1670 train_time:130428ms step_avg:93.16ms +step:1401/1670 train_time:130520ms step_avg:93.16ms +step:1402/1670 train_time:130615ms step_avg:93.16ms +step:1403/1670 train_time:130709ms step_avg:93.16ms +step:1404/1670 train_time:130801ms step_avg:93.16ms +step:1405/1670 train_time:130893ms step_avg:93.16ms +step:1406/1670 train_time:130986ms step_avg:93.16ms +step:1407/1670 train_time:131079ms step_avg:93.16ms +step:1408/1670 train_time:131174ms step_avg:93.16ms +step:1409/1670 train_time:131269ms step_avg:93.16ms +step:1410/1670 train_time:131361ms step_avg:93.16ms +step:1411/1670 train_time:131455ms step_avg:93.16ms +step:1412/1670 train_time:131549ms step_avg:93.16ms +step:1413/1670 train_time:131642ms step_avg:93.16ms +step:1414/1670 train_time:131737ms step_avg:93.17ms +step:1415/1670 train_time:131829ms step_avg:93.17ms +step:1416/1670 train_time:131921ms step_avg:93.16ms +step:1417/1670 train_time:132015ms step_avg:93.17ms +step:1418/1670 train_time:132108ms step_avg:93.17ms +step:1419/1670 train_time:132201ms step_avg:93.17ms +step:1420/1670 train_time:132294ms step_avg:93.16ms +step:1421/1670 train_time:132386ms step_avg:93.16ms +step:1422/1670 train_time:132479ms step_avg:93.16ms +step:1423/1670 train_time:132572ms step_avg:93.16ms +step:1424/1670 train_time:132666ms step_avg:93.16ms +step:1425/1670 train_time:132759ms step_avg:93.16ms +step:1426/1670 train_time:132852ms step_avg:93.16ms +step:1427/1670 train_time:132945ms step_avg:93.16ms +step:1428/1670 train_time:133038ms step_avg:93.16ms +step:1429/1670 train_time:133132ms step_avg:93.16ms +step:1430/1670 train_time:133224ms step_avg:93.16ms +step:1431/1670 train_time:133318ms step_avg:93.16ms +step:1432/1670 train_time:133411ms step_avg:93.16ms +step:1433/1670 train_time:133504ms step_avg:93.16ms +step:1434/1670 train_time:133598ms step_avg:93.16ms +step:1435/1670 train_time:133691ms step_avg:93.16ms +step:1436/1670 train_time:133784ms step_avg:93.16ms +step:1437/1670 train_time:133877ms step_avg:93.16ms +step:1438/1670 train_time:133969ms step_avg:93.16ms +step:1439/1670 train_time:134062ms step_avg:93.16ms +step:1440/1670 train_time:134155ms step_avg:93.16ms +step:1441/1670 train_time:134248ms step_avg:93.16ms +step:1442/1670 train_time:134341ms step_avg:93.16ms +step:1443/1670 train_time:134436ms step_avg:93.16ms +step:1444/1670 train_time:134530ms step_avg:93.16ms +step:1445/1670 train_time:134623ms step_avg:93.16ms +step:1446/1670 train_time:134717ms step_avg:93.17ms +step:1447/1670 train_time:134810ms step_avg:93.17ms +step:1448/1670 train_time:134903ms step_avg:93.16ms +step:1449/1670 train_time:134996ms step_avg:93.16ms +step:1450/1670 train_time:135089ms step_avg:93.16ms +step:1451/1670 train_time:135181ms step_avg:93.16ms +step:1452/1670 train_time:135274ms step_avg:93.16ms +step:1453/1670 train_time:135367ms step_avg:93.16ms +step:1454/1670 train_time:135460ms step_avg:93.16ms +step:1455/1670 train_time:135554ms step_avg:93.16ms +step:1456/1670 train_time:135647ms step_avg:93.16ms 
+step:1457/1670 train_time:135741ms step_avg:93.16ms +step:1458/1670 train_time:135836ms step_avg:93.17ms +step:1459/1670 train_time:135928ms step_avg:93.17ms +step:1460/1670 train_time:136022ms step_avg:93.17ms +step:1461/1670 train_time:136117ms step_avg:93.17ms +step:1462/1670 train_time:136211ms step_avg:93.17ms +step:1463/1670 train_time:136304ms step_avg:93.17ms +step:1464/1670 train_time:136397ms step_avg:93.17ms +step:1465/1670 train_time:136490ms step_avg:93.17ms +step:1466/1670 train_time:136583ms step_avg:93.17ms +step:1467/1670 train_time:136676ms step_avg:93.17ms +step:1468/1670 train_time:136769ms step_avg:93.17ms +step:1469/1670 train_time:136862ms step_avg:93.17ms +step:1470/1670 train_time:136955ms step_avg:93.17ms +step:1471/1670 train_time:137047ms step_avg:93.17ms +step:1472/1670 train_time:137140ms step_avg:93.17ms +step:1473/1670 train_time:137234ms step_avg:93.17ms +step:1474/1670 train_time:137327ms step_avg:93.17ms +step:1475/1670 train_time:137421ms step_avg:93.17ms +step:1476/1670 train_time:137516ms step_avg:93.17ms +step:1477/1670 train_time:137610ms step_avg:93.17ms +step:1478/1670 train_time:137702ms step_avg:93.17ms +step:1479/1670 train_time:137796ms step_avg:93.17ms +step:1480/1670 train_time:137889ms step_avg:93.17ms +step:1481/1670 train_time:137982ms step_avg:93.17ms +step:1482/1670 train_time:138076ms step_avg:93.17ms +step:1483/1670 train_time:138169ms step_avg:93.17ms +step:1484/1670 train_time:138261ms step_avg:93.17ms +step:1485/1670 train_time:138512ms step_avg:93.27ms +step:1486/1670 train_time:138583ms step_avg:93.26ms +step:1487/1670 train_time:138676ms step_avg:93.26ms +step:1488/1670 train_time:138769ms step_avg:93.26ms +step:1489/1670 train_time:138860ms step_avg:93.26ms +step:1490/1670 train_time:138952ms step_avg:93.26ms +step:1491/1670 train_time:139044ms step_avg:93.26ms +step:1492/1670 train_time:139136ms step_avg:93.25ms +step:1493/1670 train_time:139228ms step_avg:93.25ms +step:1494/1670 train_time:139319ms step_avg:93.25ms +step:1495/1670 train_time:139418ms step_avg:93.26ms +step:1496/1670 train_time:139516ms step_avg:93.26ms +step:1497/1670 train_time:139610ms step_avg:93.26ms +step:1498/1670 train_time:139703ms step_avg:93.26ms +step:1499/1670 train_time:139796ms step_avg:93.26ms +step:1500/1670 train_time:139888ms step_avg:93.26ms +step:1500/1670 val_loss:3.3132 train_time:139982ms step_avg:93.32ms +step:1501/1670 train_time:140000ms step_avg:93.27ms +step:1502/1670 train_time:140078ms step_avg:93.26ms +step:1503/1670 train_time:140171ms step_avg:93.26ms +step:1504/1670 train_time:140262ms step_avg:93.26ms +step:1505/1670 train_time:140355ms step_avg:93.26ms +step:1506/1670 train_time:140446ms step_avg:93.26ms +step:1507/1670 train_time:140540ms step_avg:93.26ms +step:1508/1670 train_time:140634ms step_avg:93.26ms +step:1509/1670 train_time:140727ms step_avg:93.26ms +step:1510/1670 train_time:140821ms step_avg:93.26ms +step:1511/1670 train_time:140917ms step_avg:93.26ms +step:1512/1670 train_time:141011ms step_avg:93.26ms +step:1513/1670 train_time:141104ms step_avg:93.26ms +step:1514/1670 train_time:141197ms step_avg:93.26ms +step:1515/1670 train_time:141289ms step_avg:93.26ms +step:1516/1670 train_time:141381ms step_avg:93.26ms +step:1517/1670 train_time:141473ms step_avg:93.26ms +step:1518/1670 train_time:141567ms step_avg:93.26ms +step:1519/1670 train_time:141660ms step_avg:93.26ms +step:1520/1670 train_time:141754ms step_avg:93.26ms +step:1521/1670 train_time:141847ms step_avg:93.26ms +step:1522/1670 train_time:141942ms 
step_avg:93.26ms +step:1523/1670 train_time:142036ms step_avg:93.26ms +step:1524/1670 train_time:142130ms step_avg:93.26ms +step:1525/1670 train_time:142223ms step_avg:93.26ms +step:1526/1670 train_time:142317ms step_avg:93.26ms +step:1527/1670 train_time:142409ms step_avg:93.26ms +step:1528/1670 train_time:142501ms step_avg:93.26ms +step:1529/1670 train_time:142593ms step_avg:93.26ms +step:1530/1670 train_time:142686ms step_avg:93.26ms +step:1531/1670 train_time:142780ms step_avg:93.26ms +step:1532/1670 train_time:142874ms step_avg:93.26ms +step:1533/1670 train_time:142968ms step_avg:93.26ms +step:1534/1670 train_time:143062ms step_avg:93.26ms +step:1535/1670 train_time:143156ms step_avg:93.26ms +step:1536/1670 train_time:143248ms step_avg:93.26ms +step:1537/1670 train_time:143342ms step_avg:93.26ms +step:1538/1670 train_time:143435ms step_avg:93.26ms +step:1539/1670 train_time:143528ms step_avg:93.26ms +step:1540/1670 train_time:143621ms step_avg:93.26ms +step:1541/1670 train_time:143716ms step_avg:93.26ms +step:1542/1670 train_time:143809ms step_avg:93.26ms +step:1543/1670 train_time:143903ms step_avg:93.26ms +step:1544/1670 train_time:143997ms step_avg:93.26ms +step:1545/1670 train_time:144091ms step_avg:93.26ms +step:1546/1670 train_time:144183ms step_avg:93.26ms +step:1547/1670 train_time:144277ms step_avg:93.26ms +step:1548/1670 train_time:144369ms step_avg:93.26ms +step:1549/1670 train_time:144462ms step_avg:93.26ms +step:1550/1670 train_time:144555ms step_avg:93.26ms +step:1551/1670 train_time:144649ms step_avg:93.26ms +step:1552/1670 train_time:144742ms step_avg:93.26ms +step:1553/1670 train_time:144836ms step_avg:93.26ms +step:1554/1670 train_time:144931ms step_avg:93.26ms +step:1555/1670 train_time:145024ms step_avg:93.26ms +step:1556/1670 train_time:145116ms step_avg:93.26ms +step:1557/1670 train_time:145210ms step_avg:93.26ms +step:1558/1670 train_time:145302ms step_avg:93.26ms +step:1559/1670 train_time:145395ms step_avg:93.26ms +step:1560/1670 train_time:145488ms step_avg:93.26ms +step:1561/1670 train_time:145582ms step_avg:93.26ms +step:1562/1670 train_time:145675ms step_avg:93.26ms +step:1563/1670 train_time:145768ms step_avg:93.26ms +step:1564/1670 train_time:145862ms step_avg:93.26ms +step:1565/1670 train_time:145955ms step_avg:93.26ms +step:1566/1670 train_time:146049ms step_avg:93.26ms +step:1567/1670 train_time:146142ms step_avg:93.26ms +step:1568/1670 train_time:146236ms step_avg:93.26ms +step:1569/1670 train_time:146328ms step_avg:93.26ms +step:1570/1670 train_time:146423ms step_avg:93.26ms +step:1571/1670 train_time:146516ms step_avg:93.26ms +step:1572/1670 train_time:146609ms step_avg:93.26ms +step:1573/1670 train_time:146702ms step_avg:93.26ms +step:1574/1670 train_time:146796ms step_avg:93.26ms +step:1575/1670 train_time:146889ms step_avg:93.26ms +step:1576/1670 train_time:146983ms step_avg:93.26ms +step:1577/1670 train_time:147076ms step_avg:93.26ms +step:1578/1670 train_time:147168ms step_avg:93.26ms +step:1579/1670 train_time:147263ms step_avg:93.26ms +step:1580/1670 train_time:147357ms step_avg:93.26ms +step:1581/1670 train_time:147450ms step_avg:93.26ms +step:1582/1670 train_time:147543ms step_avg:93.26ms +step:1583/1670 train_time:147637ms step_avg:93.26ms +step:1584/1670 train_time:147730ms step_avg:93.26ms +step:1585/1670 train_time:147822ms step_avg:93.26ms +step:1586/1670 train_time:147916ms step_avg:93.26ms +step:1587/1670 train_time:148010ms step_avg:93.26ms +step:1588/1670 train_time:148102ms step_avg:93.26ms +step:1589/1670 train_time:148196ms 
step_avg:93.26ms +step:1590/1670 train_time:148288ms step_avg:93.26ms +step:1591/1670 train_time:148383ms step_avg:93.26ms +step:1592/1670 train_time:148476ms step_avg:93.26ms +step:1593/1670 train_time:148569ms step_avg:93.26ms +step:1594/1670 train_time:148663ms step_avg:93.26ms +step:1595/1670 train_time:148757ms step_avg:93.26ms +step:1596/1670 train_time:148849ms step_avg:93.26ms +step:1597/1670 train_time:148943ms step_avg:93.26ms +step:1598/1670 train_time:149036ms step_avg:93.26ms +step:1599/1670 train_time:149129ms step_avg:93.26ms +step:1600/1670 train_time:149222ms step_avg:93.26ms +step:1601/1670 train_time:149316ms step_avg:93.26ms +step:1602/1670 train_time:149408ms step_avg:93.26ms +step:1603/1670 train_time:149501ms step_avg:93.26ms +step:1604/1670 train_time:149595ms step_avg:93.26ms +step:1605/1670 train_time:149687ms step_avg:93.26ms +step:1606/1670 train_time:149780ms step_avg:93.26ms +step:1607/1670 train_time:149873ms step_avg:93.26ms +step:1608/1670 train_time:149966ms step_avg:93.26ms +step:1609/1670 train_time:150060ms step_avg:93.26ms +step:1610/1670 train_time:150154ms step_avg:93.26ms +step:1611/1670 train_time:150247ms step_avg:93.26ms +step:1612/1670 train_time:150341ms step_avg:93.26ms +step:1613/1670 train_time:150434ms step_avg:93.26ms +step:1614/1670 train_time:150528ms step_avg:93.26ms +step:1615/1670 train_time:150622ms step_avg:93.26ms +step:1616/1670 train_time:150716ms step_avg:93.26ms +step:1617/1670 train_time:150808ms step_avg:93.26ms +step:1618/1670 train_time:150901ms step_avg:93.26ms +step:1619/1670 train_time:150994ms step_avg:93.26ms +step:1620/1670 train_time:151087ms step_avg:93.26ms +step:1621/1670 train_time:151180ms step_avg:93.26ms +step:1622/1670 train_time:151273ms step_avg:93.26ms +step:1623/1670 train_time:151366ms step_avg:93.26ms +step:1624/1670 train_time:151459ms step_avg:93.26ms +step:1625/1670 train_time:151553ms step_avg:93.26ms +step:1625/1670 val_loss:3.2883 train_time:151647ms step_avg:93.32ms +step:1626/1670 train_time:151664ms step_avg:93.27ms +step:1627/1670 train_time:151739ms step_avg:93.26ms +step:1628/1670 train_time:151831ms step_avg:93.26ms +step:1629/1670 train_time:151925ms step_avg:93.26ms +step:1630/1670 train_time:152017ms step_avg:93.26ms +step:1631/1670 train_time:152110ms step_avg:93.26ms +step:1632/1670 train_time:152203ms step_avg:93.26ms +step:1633/1670 train_time:152296ms step_avg:93.26ms +step:1634/1670 train_time:152390ms step_avg:93.26ms +step:1635/1670 train_time:152483ms step_avg:93.26ms +step:1636/1670 train_time:152578ms step_avg:93.26ms +step:1637/1670 train_time:152674ms step_avg:93.26ms +step:1638/1670 train_time:152767ms step_avg:93.26ms +step:1639/1670 train_time:152860ms step_avg:93.26ms +step:1640/1670 train_time:152952ms step_avg:93.26ms +step:1641/1670 train_time:153045ms step_avg:93.26ms +step:1642/1670 train_time:153138ms step_avg:93.26ms +step:1643/1670 train_time:153231ms step_avg:93.26ms +step:1644/1670 train_time:153326ms step_avg:93.26ms +step:1645/1670 train_time:153420ms step_avg:93.26ms +step:1646/1670 train_time:153513ms step_avg:93.26ms +step:1647/1670 train_time:153607ms step_avg:93.26ms +step:1648/1670 train_time:153701ms step_avg:93.26ms +step:1649/1670 train_time:153794ms step_avg:93.27ms +step:1650/1670 train_time:153890ms step_avg:93.27ms +step:1651/1670 train_time:153982ms step_avg:93.27ms +step:1652/1670 train_time:154075ms step_avg:93.27ms +step:1653/1670 train_time:154168ms step_avg:93.27ms +step:1654/1670 train_time:154261ms step_avg:93.27ms +step:1655/1670 
train_time:154354ms step_avg:93.27ms +step:1656/1670 train_time:154448ms step_avg:93.27ms +step:1657/1670 train_time:154542ms step_avg:93.27ms +step:1658/1670 train_time:154635ms step_avg:93.27ms +step:1659/1670 train_time:154729ms step_avg:93.27ms +step:1660/1670 train_time:154821ms step_avg:93.27ms +step:1661/1670 train_time:154914ms step_avg:93.27ms +step:1662/1670 train_time:155008ms step_avg:93.27ms +step:1663/1670 train_time:155101ms step_avg:93.27ms +step:1664/1670 train_time:155194ms step_avg:93.27ms +step:1665/1670 train_time:155288ms step_avg:93.27ms +step:1666/1670 train_time:155382ms step_avg:93.27ms +step:1667/1670 train_time:155474ms step_avg:93.27ms +step:1668/1670 train_time:155569ms step_avg:93.27ms +step:1669/1670 train_time:155662ms step_avg:93.27ms +step:1670/1670 train_time:155755ms step_avg:93.27ms +step:1670/1670 val_loss:3.2797 train_time:156019ms step_avg:93.42ms +peak memory allocated: 32002 MiB reserved: 46294 MiB diff --git a/train_gpt.py b/train_gpt.py index 98f8df08c..c4d8ab06c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1,35 +1,39 @@ import os import sys + with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time + code = f.read() # read the code of this file ASAP, for logging import copy import glob import math - +import time +import uuid from dataclasses import dataclass -from functools import lru_cache from itertools import accumulate from pathlib import Path os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min import triton import triton.language as tl from flash_attn_interface import flash_attn_varlen_func -import torch._dynamo as dynamo +from torch import Tensor, nn + dynamo.config.recompile_limit = 64 # ----------------------------------------------------------------------------- # Custom operators: FP8 matmul by @YouJiacheng + @torch.library.custom_op("nanogpt::mm", mutates_args=()) def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: @torch.compile @@ -422,42 +426,189 @@ def step(self): # @ryanyang0, and @vagrawal. 
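+        # Editorial sketch (not from the original patch): the refactor below
+        # replaces the old per-parameter dist.reduce_scatter calls with a
+        # three-phase, ZeRO-style update. Names refer to the code that follows;
+        # concrete sizes in later comments are illustrative only.
+        #   1) Stack each group's grads, zero-pad the stack to a multiple of
+        #      world_size, and launch one async reduce_scatter_tensor per group.
+        #   2) Each rank applies the Muon update (momentum, batched
+        #      Newton-Schulz via newton_schulz_triton, weight decay) only to its
+        #      own shard of chunk_size = padded_num_params // world_size params.
+        #   3) One async all_gather_into_tensor per group replicates the updated
+        #      shards back to every rank, which then copies them into the
+        #      original parameter tensors.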
         rank = dist.get_rank()
         world_size = dist.get_world_size()
-        reduce_scatter_futures: list[torch.Future] = []
-        all_gather_futures: list[torch.Future] = []
+        group_infos = []
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
-            grad = torch.empty_like(params[-1])
-            grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size
-            for base_i in range(0, len(params), world_size):
-                if base_i + rank < len(params):
-                    grad = params[base_i + rank].grad
-                # This gives strange dynamo warnings
-                reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future())
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
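+                    # Illustrative arithmetic (hypothetical sizes, not from the
+                    # patch): with world_size=8 and 12 hidden matrices,
+                    # padded_num_params = ceil(12/8)*8 = 16 and chunk_size = 2,
+                    # so ranks 0-5 each own two real parameters while ranks 6
+                    # and 7 land entirely in this padding branch.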
+ updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
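+        # Sketch of the gather layout (assumes standard all_gather_into_tensor
+        # semantics): rank r's updated_param_chunk fills rows
+        # [r * chunk_size, (r + 1) * chunk_size) of stacked_params, so after the
+        # wait below every rank holds an identical padded stack and the final
+        # copy simply ignores the trailing padded rows.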
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() class DistAdam(torch.optim.Optimizer): def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): @@ -505,13 +656,23 @@ def step(self): g_slice = grad_slices[idx] # State init if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] # weight decay if wd != 0: eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) @@ -559,13 +720,18 @@ def forward(self, x: Tensor): else: return F.linear(x, self.weight.type_as(x)) + def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): assert cos.size(0) >= x_BTHD.size(-3) - cos, sin = cos[None, :x_BTHD.size(-3), None, :], sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) + return torch.cat((y1, y2), 3) + @dataclass class AttnArgs: @@ -684,12 +850,21 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i # Add learnable skip connection weights for decoder layers assert num_layers % 2 == 0 pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * 
torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) self.max_seq_len = max_seq_len self.setup_yarn(head_dim) # set learning rates @@ -707,8 +882,12 @@ def setup_yarn(self, head_dim: int): angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) t = torch.arange(self.max_seq_len, dtype=torch.float32) theta = torch.outer(t, angular_freq) - self.rotary_cos = nn.Buffer(theta.cos(), persistent=False) - self.rotary_sin = nn.Buffer(theta.sin(), persistent=False) + self.rotary_cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.rotary_sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) self.angular_freq = angular_freq # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd @@ -740,7 +919,9 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] assert len(bm_sizes) == len(self.blocks) - x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = x0 = norm(self.embed(input_seq)[None]).to( + torch.bfloat16 + ) # use of norm here by @Grad62304977 # U-net design by @brendanh0gan skip_connections = [] @@ -761,7 +942,8 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in attn_scale=self.attn_scales[ws] ) if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() x = self.blocks[i](x, x0, lambdas[i], attn_args) if i < n: skip_connections.append(x) @@ -872,7 +1054,6 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] pos += num_tokens - _cum_lengths = torch.full((max_num_docs,), num_tokens_local) _cum_lengths[0] = 0 _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths @@ -971,7 +1152,7 @@ def nvidia_smi(): max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) ).cuda() for m in model.modules(): - if isinstance(m, nn.Embedding): + if isinstance(m, (nn.Embedding, nn.Linear)): m.bfloat16() for param in model.parameters(): dist.broadcast(param.detach(), 0) @@ -985,7 +1166,13 @@ def nvidia_smi(): # init the optimizer(s) # small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) optimizers = [optimizer1, optimizer2] for opt in optimizers: From e89948185a107fb36a0d69ae9a944c7e0c891924 Mon Sep 17 00:00:00 2001 From: larry dial Date: Mon, 15 Sep 2025 21:31:49 -0700 Subject: [PATCH 11/14] data threading and final layer window --- .../25db37c7-2bab-4ef4-ae63-d593590ef823.txt | 3111 +++++++++++++++++ .../26acd99c-9089-406e-8249-f0532e6c2a13.txt | 3111 +++++++++++++++++ .../305f24ee-051f-41a0-939a-0fa26654712a.txt | 3111 +++++++++++++++++ .../517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt | 3111 +++++++++++++++++ .../584d668b-cc79-4dde-b5be-b911623bdb61.txt | 3111 +++++++++++++++++ .../61705980-e239-4d86-9233-210200da7010.txt | 3111 +++++++++++++++++ .../720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt | 3111 +++++++++++++++++ .../93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt | 3111 +++++++++++++++++ .../93480ae1-43e4-415b-813e-ba8b43ab899b.txt | 3111 +++++++++++++++++ records/091525_ThreadingFinalWindow/README.md | 29 + .../a04db288-4cdd-4401-bdd1-444b26c53cd8.txt | 3111 +++++++++++++++++ train_gpt.py | 203 +- 12 files changed, 31277 insertions(+), 65 deletions(-) create mode 100644 records/091525_ThreadingFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt create mode 100644 records/091525_ThreadingFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt create mode 100644 records/091525_ThreadingFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt create mode 100644 records/091525_ThreadingFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt create mode 100644 records/091525_ThreadingFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt create mode 100644 records/091525_ThreadingFinalWindow/61705980-e239-4d86-9233-210200da7010.txt create mode 100644 records/091525_ThreadingFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt create mode 100644 records/091525_ThreadingFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt create mode 100644 records/091525_ThreadingFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt create mode 100644 records/091525_ThreadingFinalWindow/README.md create mode 100644 records/091525_ThreadingFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt diff --git a/records/091525_ThreadingFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt b/records/091525_ThreadingFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt new file mode 100644 index 000000000..a6d14b7b6 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# 
torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def 
_pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) 
== M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. 
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + #t0 = time.perf_counter() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + #t1 = time.perf_counter() + #print(f'{t1-t0} slowload') + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + def __init__(self, file_iter, world_size): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences 
truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + #loading in a whole shard will be slow... + #BosFinder tracks its self.i index. I can kickoff with a smaller set. then have the + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + #tokens = _load_data_shard(next(file_iter)) + #finder = BOSFinder(tokens, world_size=world_size) + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1660 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"data_threading_1660/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer is highly indifferent to length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + 
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:57:59 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 198817 C /usr/bin/python3 1510MiB |
| 0 N/A N/A 198818 C /usr/bin/python3 614MiB |
| 0 N/A N/A 198819 C /usr/bin/python3 614MiB |
| 0 N/A N/A 198820 C /usr/bin/python3 614MiB |
| 0 N/A N/A 198821 C /usr/bin/python3 614MiB |
| 0 N/A N/A 198822 C /usr/bin/python3 614MiB |
| 0 N/A N/A 198823 C /usr/bin/python3 614MiB |
| 0 N/A N/A 198824 C /usr/bin/python3 614MiB |
| 1 N/A N/A 198818 C /usr/bin/python3 1510MiB |
| 2 N/A N/A 198819 C /usr/bin/python3 1510MiB |
| 3 N/A N/A 198820 C /usr/bin/python3 1510MiB |
| 4 N/A N/A 198821 C /usr/bin/python3 1510MiB |
| 5 N/A N/A 198822 C /usr/bin/python3 1510MiB |
| 6 N/A N/A 198823 C /usr/bin/python3 1510MiB |
| 7 N/A N/A 198824 C /usr/bin/python3 1510MiB |
+-----------------------------------------------------------------------------------------+

====================================================================================================
step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms
step:1/1660 train_time:146ms step_avg:145.63ms
step:2/1660 train_time:172ms step_avg:85.76ms
step:3/1660 train_time:234ms step_avg:77.91ms
step:4/1660 train_time:323ms step_avg:80.77ms
step:5/1660 train_time:413ms step_avg:82.67ms
step:6/1660 train_time:505ms step_avg:84.14ms
step:7/1660 train_time:595ms step_avg:85.05ms
step:8/1660 train_time:686ms step_avg:85.78ms
step:9/1660 train_time:776ms step_avg:86.26ms
step:10/1660 train_time:867ms step_avg:86.73ms
step:11/1660 train_time:958ms step_avg:87.06ms
step:12/1660 train_time:1050ms step_avg:87.50ms
step:13/1660 train_time:1145ms step_avg:88.09ms
step:14/1660 train_time:1239ms step_avg:88.51ms
step:15/1660 train_time:1330ms step_avg:88.69ms
step:16/1660 train_time:1421ms step_avg:88.84ms
step:17/1660 train_time:1513ms step_avg:88.98ms
step:18/1660 train_time:1604ms step_avg:89.11ms
step:19/1660 train_time:1696ms step_avg:89.24ms
step:20/1660 train_time:1786ms step_avg:89.31ms
step:21/1660 train_time:1877ms step_avg:89.40ms
step:22/1660 train_time:1969ms step_avg:89.49ms
step:23/1660 train_time:2062ms step_avg:89.63ms
step:24/1660 train_time:2155ms step_avg:89.78ms
step:25/1660 train_time:2248ms step_avg:89.93ms
step:26/1660 train_time:2341ms step_avg:90.05ms
step:27/1660 train_time:2433ms step_avg:90.12ms
step:28/1660 train_time:2525ms step_avg:90.19ms
step:29/1660 train_time:2617ms step_avg:90.23ms
step:30/1660 train_time:2709ms step_avg:90.31ms
step:31/1660 train_time:2801ms step_avg:90.37ms
step:32/1660 train_time:2892ms step_avg:90.38ms
step:33/1660 train_time:2983ms step_avg:90.41ms
step:34/1660 train_time:3075ms step_avg:90.45ms
step:35/1660 train_time:3167ms step_avg:90.48ms
step:36/1660 train_time:3258ms step_avg:90.51ms
step:37/1660 train_time:3351ms step_avg:90.56ms
step:38/1660 train_time:3444ms step_avg:90.62ms
step:39/1660 train_time:3535ms step_avg:90.64ms
step:40/1660 train_time:3627ms step_avg:90.67ms
step:41/1660 train_time:3718ms step_avg:90.69ms
step:42/1660 train_time:3810ms step_avg:90.71ms
step:43/1660 train_time:3901ms step_avg:90.72ms
step:44/1660 train_time:3993ms step_avg:90.74ms
step:45/1660 train_time:4084ms step_avg:90.75ms
step:46/1660 train_time:4176ms step_avg:90.78ms
step:47/1660 train_time:4268ms step_avg:90.81ms
step:48/1660 train_time:4360ms step_avg:90.84ms
step:49/1660 train_time:4453ms step_avg:90.87ms
step:50/1660 train_time:4545ms step_avg:90.90ms
step:51/1660 train_time:4637ms step_avg:90.92ms
step:52/1660 train_time:4729ms step_avg:90.94ms
step:53/1660 train_time:4821ms step_avg:90.97ms
step:54/1660 train_time:4912ms step_avg:90.97ms
step:55/1660 train_time:5004ms step_avg:90.98ms
step:56/1660 train_time:5096ms step_avg:90.99ms
step:57/1660 train_time:5187ms step_avg:90.99ms
step:58/1660 train_time:5278ms step_avg:91.00ms
step:59/1660 train_time:5370ms step_avg:91.02ms
step:60/1660 train_time:5462ms step_avg:91.04ms
step:61/1660 train_time:5554ms step_avg:91.05ms
step:62/1660 train_time:5646ms step_avg:91.06ms
step:63/1660 train_time:5737ms step_avg:91.07ms
step:64/1660 train_time:5830ms step_avg:91.09ms
step:65/1660 train_time:5922ms step_avg:91.10ms
step:66/1660 train_time:6014ms step_avg:91.11ms
step:67/1660 train_time:6105ms step_avg:91.12ms
step:68/1660 train_time:6197ms step_avg:91.14ms
step:69/1660 train_time:6289ms step_avg:91.15ms
step:70/1660 train_time:6381ms step_avg:91.16ms
step:71/1660 train_time:6473ms step_avg:91.17ms
step:72/1660 train_time:6565ms step_avg:91.17ms
step:73/1660 train_time:6656ms step_avg:91.18ms
step:74/1660 train_time:6750ms step_avg:91.21ms
step:75/1660 train_time:6843ms step_avg:91.24ms
step:76/1660 train_time:6934ms step_avg:91.24ms
step:77/1660 train_time:7026ms step_avg:91.24ms
step:78/1660 train_time:7118ms step_avg:91.26ms
step:79/1660 train_time:7210ms step_avg:91.26ms
step:80/1660 train_time:7302ms step_avg:91.28ms
step:81/1660 train_time:7394ms step_avg:91.28ms
step:82/1660 train_time:7485ms step_avg:91.28ms
step:83/1660 train_time:7576ms step_avg:91.28ms
step:84/1660 train_time:7669ms step_avg:91.30ms
step:85/1660 train_time:7762ms step_avg:91.32ms
step:86/1660 train_time:7854ms step_avg:91.33ms
step:87/1660 train_time:7948ms step_avg:91.36ms
step:88/1660 train_time:8040ms step_avg:91.36ms
step:89/1660 train_time:8132ms step_avg:91.37ms
step:90/1660 train_time:8224ms step_avg:91.38ms
step:91/1660 train_time:8316ms step_avg:91.38ms
step:92/1660 train_time:8407ms step_avg:91.38ms
step:93/1660 train_time:8498ms step_avg:91.38ms
step:94/1660 train_time:8591ms step_avg:91.39ms
step:95/1660 train_time:8682ms step_avg:91.39ms
step:96/1660 train_time:8774ms step_avg:91.40ms
step:97/1660 train_time:8866ms step_avg:91.40ms
step:98/1660 train_time:8957ms step_avg:91.40ms
step:99/1660 train_time:9050ms step_avg:91.41ms
step:100/1660 train_time:9143ms step_avg:91.43ms
step:101/1660 train_time:9235ms step_avg:91.43ms
step:102/1660 train_time:9326ms step_avg:91.43ms
step:103/1660 train_time:9417ms step_avg:91.43ms
step:104/1660 train_time:9509ms step_avg:91.43ms
step:105/1660 train_time:9600ms step_avg:91.43ms
step:106/1660 train_time:9692ms step_avg:91.43ms
step:107/1660 train_time:9783ms step_avg:91.43ms
step:108/1660 train_time:9875ms step_avg:91.44ms
step:109/1660 train_time:9967ms step_avg:91.44ms
step:110/1660 train_time:10058ms step_avg:91.44ms
step:111/1660 train_time:10151ms step_avg:91.45ms
step:112/1660 train_time:10243ms step_avg:91.46ms
step:113/1660 train_time:10335ms step_avg:91.46ms
step:114/1660 train_time:10426ms step_avg:91.46ms
step:115/1660 train_time:10517ms step_avg:91.45ms
step:116/1660 train_time:10609ms step_avg:91.46ms
step:117/1660 train_time:10702ms step_avg:91.47ms
step:118/1660 train_time:10793ms step_avg:91.47ms
step:119/1660 train_time:10884ms step_avg:91.47ms
step:120/1660 train_time:10976ms step_avg:91.46ms
step:121/1660 train_time:11067ms step_avg:91.47ms
step:122/1660 train_time:11159ms step_avg:91.47ms
step:123/1660 train_time:11252ms step_avg:91.48ms
step:124/1660 train_time:11343ms step_avg:91.48ms
step:125/1660 train_time:11435ms step_avg:91.48ms
step:125/1660 val_loss:4.3159 train_time:11528ms step_avg:92.22ms
step:126/1660 train_time:11551ms step_avg:91.68ms
step:127/1660 train_time:11623ms step_avg:91.52ms
step:128/1660 train_time:11729ms step_avg:91.63ms
step:129/1660 train_time:11823ms step_avg:91.65ms
step:130/1660 train_time:11914ms step_avg:91.64ms
step:131/1660 train_time:12004ms step_avg:91.63ms
step:132/1660 train_time:12094ms step_avg:91.62ms
step:133/1660 train_time:12184ms step_avg:91.61ms
step:134/1660 train_time:12274ms step_avg:91.60ms
step:135/1660 train_time:12365ms step_avg:91.59ms
step:136/1660 train_time:12455ms step_avg:91.58ms
step:137/1660 train_time:12546ms step_avg:91.57ms
step:138/1660 train_time:12640ms step_avg:91.60ms
step:139/1660 train_time:12734ms step_avg:91.61ms
step:140/1660 train_time:12828ms step_avg:91.63ms
step:141/1660 train_time:12920ms step_avg:91.63ms
step:142/1660 train_time:13012ms step_avg:91.63ms
step:143/1660 train_time:13103ms step_avg:91.63ms
step:144/1660 train_time:13193ms step_avg:91.62ms
step:145/1660 train_time:13284ms step_avg:91.62ms
step:146/1660 train_time:13375ms step_avg:91.61ms
step:147/1660 train_time:13466ms step_avg:91.61ms
step:148/1660 train_time:13558ms step_avg:91.61ms
step:149/1660 train_time:13650ms step_avg:91.61ms
step:150/1660 train_time:13743ms step_avg:91.62ms
step:151/1660 train_time:13835ms step_avg:91.63ms
step:152/1660 train_time:13928ms step_avg:91.63ms
step:153/1660 train_time:14020ms step_avg:91.63ms
step:154/1660 train_time:14111ms step_avg:91.63ms
step:155/1660 train_time:14202ms step_avg:91.62ms
step:156/1660 train_time:14292ms step_avg:91.62ms
step:157/1660 train_time:14384ms step_avg:91.62ms
step:158/1660 train_time:14474ms step_avg:91.61ms
step:159/1660 train_time:14566ms step_avg:91.61ms
step:160/1660 train_time:14658ms step_avg:91.61ms
step:161/1660 train_time:14750ms step_avg:91.62ms
step:162/1660 train_time:14843ms step_avg:91.62ms
step:163/1660 train_time:14935ms step_avg:91.62ms
step:164/1660 train_time:15026ms step_avg:91.62ms
step:165/1660 train_time:15117ms step_avg:91.62ms
step:166/1660 train_time:15208ms step_avg:91.61ms
step:167/1660 train_time:15299ms step_avg:91.61ms
step:168/1660 train_time:15390ms step_avg:91.61ms
step:169/1660 train_time:15482ms step_avg:91.61ms
step:170/1660 train_time:15573ms step_avg:91.60ms
step:171/1660 train_time:15666ms step_avg:91.61ms
step:172/1660 train_time:15758ms step_avg:91.62ms
step:173/1660 train_time:15851ms step_avg:91.62ms
step:174/1660 train_time:15942ms step_avg:91.62ms
step:175/1660 train_time:16034ms step_avg:91.62ms
step:176/1660 train_time:16125ms step_avg:91.62ms
step:177/1660 train_time:16215ms step_avg:91.61ms
step:178/1660 train_time:16307ms step_avg:91.61ms
step:179/1660 train_time:16398ms step_avg:91.61ms
step:180/1660 train_time:16489ms step_avg:91.61ms
step:181/1660 train_time:16581ms step_avg:91.61ms
step:182/1660 train_time:16673ms step_avg:91.61ms
step:183/1660 train_time:16766ms step_avg:91.62ms
step:184/1660 train_time:16858ms step_avg:91.62ms
step:185/1660 train_time:16949ms step_avg:91.62ms
step:186/1660 train_time:17041ms step_avg:91.62ms
step:187/1660 train_time:17132ms step_avg:91.62ms
step:188/1660 train_time:17223ms step_avg:91.61ms
step:189/1660 train_time:17313ms step_avg:91.61ms
step:190/1660 train_time:17403ms step_avg:91.60ms
step:191/1660 train_time:17494ms step_avg:91.59ms
step:192/1660 train_time:17587ms step_avg:91.60ms
step:193/1660 train_time:17679ms step_avg:91.60ms
step:194/1660 train_time:17771ms step_avg:91.60ms
step:195/1660 train_time:17862ms step_avg:91.60ms
step:196/1660 train_time:17953ms step_avg:91.60ms
step:197/1660 train_time:18045ms step_avg:91.60ms
step:198/1660 train_time:18136ms step_avg:91.60ms
step:199/1660 train_time:18227ms step_avg:91.59ms
step:200/1660 train_time:18318ms step_avg:91.59ms
step:201/1660 train_time:18410ms step_avg:91.59ms
step:202/1660 train_time:18501ms step_avg:91.59ms
step:203/1660 train_time:18592ms step_avg:91.59ms
step:204/1660 train_time:18685ms step_avg:91.60ms
step:205/1660 train_time:18777ms step_avg:91.60ms
step:206/1660 train_time:18869ms step_avg:91.60ms
step:207/1660 train_time:18961ms step_avg:91.60ms
step:208/1660 train_time:19052ms step_avg:91.60ms
step:209/1660 train_time:19143ms step_avg:91.59ms
step:210/1660 train_time:19234ms step_avg:91.59ms
step:211/1660 train_time:19324ms step_avg:91.58ms
step:212/1660 train_time:19416ms step_avg:91.59ms
step:213/1660 train_time:19508ms step_avg:91.58ms
step:214/1660 train_time:19599ms step_avg:91.58ms
step:215/1660 train_time:19691ms step_avg:91.58ms
step:216/1660 train_time:19782ms step_avg:91.58ms
step:217/1660 train_time:19874ms step_avg:91.58ms
step:218/1660 train_time:19966ms step_avg:91.59ms
step:219/1660 train_time:20058ms step_avg:91.59ms
step:220/1660 train_time:20151ms step_avg:91.59ms
step:221/1660 train_time:20242ms step_avg:91.59ms
step:222/1660 train_time:20333ms step_avg:91.59ms
step:223/1660 train_time:20425ms step_avg:91.59ms
step:224/1660 train_time:20516ms step_avg:91.59ms
step:225/1660 train_time:20608ms step_avg:91.59ms
step:226/1660 train_time:20700ms step_avg:91.59ms
step:227/1660 train_time:20792ms step_avg:91.59ms
step:228/1660 train_time:20884ms step_avg:91.60ms
step:229/1660 train_time:20976ms step_avg:91.60ms
step:230/1660 train_time:21068ms step_avg:91.60ms
step:231/1660 train_time:21159ms step_avg:91.60ms
step:232/1660 train_time:21250ms step_avg:91.60ms
step:233/1660 train_time:21342ms step_avg:91.59ms
step:234/1660 train_time:21432ms step_avg:91.59ms
step:235/1660 train_time:21524ms step_avg:91.59ms
step:236/1660 train_time:21615ms step_avg:91.59ms
step:237/1660 train_time:21707ms step_avg:91.59ms
step:238/1660 train_time:21798ms step_avg:91.59ms
step:239/1660 train_time:21890ms step_avg:91.59ms
step:240/1660 train_time:21983ms step_avg:91.59ms
step:241/1660 train_time:22075ms step_avg:91.60ms
step:242/1660 train_time:22168ms step_avg:91.60ms
step:243/1660 train_time:22260ms step_avg:91.60ms
step:244/1660 train_time:22351ms step_avg:91.60ms
step:245/1660 train_time:22442ms step_avg:91.60ms
step:246/1660 train_time:22533ms step_avg:91.60ms
step:247/1660 train_time:22623ms step_avg:91.59ms
step:248/1660 train_time:22715ms step_avg:91.59ms
step:249/1660 train_time:22807ms step_avg:91.59ms
step:250/1660 train_time:22898ms step_avg:91.59ms
step:250/1660 val_loss:3.9653 train_time:22992ms step_avg:91.97ms
step:251/1660 train_time:23014ms step_avg:91.69ms
step:252/1660 train_time:23087ms step_avg:91.61ms
step:253/1660 train_time:23182ms step_avg:91.63ms
step:254/1660 train_time:23274ms step_avg:91.63ms
step:255/1660 train_time:23364ms step_avg:91.63ms
step:256/1660 train_time:23455ms step_avg:91.62ms
step:257/1660 train_time:23545ms step_avg:91.61ms
step:258/1660 train_time:23635ms step_avg:91.61ms
step:259/1660 train_time:23725ms step_avg:91.60ms
step:260/1660 train_time:23816ms step_avg:91.60ms
step:261/1660 train_time:23907ms step_avg:91.60ms
step:262/1660 train_time:24000ms step_avg:91.60ms
step:263/1660 train_time:24094ms step_avg:91.61ms
step:264/1660 train_time:24187ms step_avg:91.62ms
step:265/1660 train_time:24279ms step_avg:91.62ms
step:266/1660 train_time:24371ms step_avg:91.62ms
step:267/1660 train_time:24462ms step_avg:91.62ms
step:268/1660 train_time:24552ms step_avg:91.61ms
step:269/1660 train_time:24643ms step_avg:91.61ms
step:270/1660 train_time:24734ms step_avg:91.61ms
step:271/1660 train_time:24824ms step_avg:91.60ms
step:272/1660 train_time:24915ms step_avg:91.60ms
step:273/1660 train_time:25006ms step_avg:91.60ms
step:274/1660 train_time:25100ms step_avg:91.60ms
step:275/1660 train_time:25192ms step_avg:91.61ms
step:276/1660 train_time:25283ms step_avg:91.60ms
step:277/1660 train_time:25375ms step_avg:91.61ms
step:278/1660 train_time:25466ms step_avg:91.60ms
step:279/1660 train_time:25557ms step_avg:91.60ms
step:280/1660 train_time:25648ms step_avg:91.60ms
step:281/1660 train_time:25739ms step_avg:91.60ms
step:282/1660 train_time:25830ms step_avg:91.59ms
step:283/1660 train_time:25921ms step_avg:91.59ms
step:284/1660 train_time:26012ms step_avg:91.59ms
step:285/1660 train_time:26105ms step_avg:91.59ms
step:286/1660 train_time:26197ms step_avg:91.60ms
step:287/1660 train_time:26289ms step_avg:91.60ms
step:288/1660 train_time:26381ms step_avg:91.60ms
step:289/1660 train_time:26471ms step_avg:91.60ms
step:290/1660 train_time:26562ms step_avg:91.59ms
step:291/1660 train_time:26653ms step_avg:91.59ms
step:292/1660 train_time:26744ms step_avg:91.59ms
step:293/1660 train_time:26835ms step_avg:91.59ms
step:294/1660 train_time:26927ms step_avg:91.59ms
step:295/1660 train_time:27019ms step_avg:91.59ms
step:296/1660 train_time:27111ms step_avg:91.59ms
step:297/1660 train_time:27203ms step_avg:91.59ms
step:298/1660 train_time:27295ms step_avg:91.59ms
step:299/1660 train_time:27387ms step_avg:91.59ms
step:300/1660 train_time:27478ms step_avg:91.59ms
step:301/1660 train_time:27569ms step_avg:91.59ms
step:302/1660 train_time:27661ms step_avg:91.59ms
step:303/1660 train_time:27752ms step_avg:91.59ms
step:304/1660 train_time:27843ms step_avg:91.59ms
step:305/1660 train_time:27935ms step_avg:91.59ms
step:306/1660 train_time:28027ms step_avg:91.59ms
step:307/1660 train_time:28119ms step_avg:91.59ms
step:308/1660 train_time:28211ms step_avg:91.59ms
step:309/1660 train_time:28302ms step_avg:91.59ms
step:310/1660 train_time:28394ms step_avg:91.59ms
step:311/1660 train_time:28485ms step_avg:91.59ms
step:312/1660 train_time:28576ms step_avg:91.59ms
step:313/1660 train_time:28668ms step_avg:91.59ms
step:314/1660 train_time:28760ms step_avg:91.59ms
step:315/1660 train_time:28851ms step_avg:91.59ms
step:316/1660 train_time:28942ms step_avg:91.59ms
step:317/1660 train_time:29035ms step_avg:91.59ms
step:318/1660 train_time:29127ms step_avg:91.60ms
step:319/1660 train_time:29220ms step_avg:91.60ms
step:320/1660 train_time:29312ms step_avg:91.60ms
step:321/1660 train_time:29403ms step_avg:91.60ms
step:322/1660 train_time:29494ms step_avg:91.60ms
step:323/1660 train_time:29585ms step_avg:91.59ms
step:324/1660 train_time:29677ms step_avg:91.59ms
step:325/1660 train_time:29768ms step_avg:91.59ms
step:326/1660 train_time:29859ms step_avg:91.59ms
step:327/1660 train_time:29951ms step_avg:91.59ms
step:328/1660 train_time:30042ms step_avg:91.59ms
step:329/1660 train_time:30135ms step_avg:91.60ms
step:330/1660 train_time:30227ms step_avg:91.60ms
step:331/1660 train_time:30320ms step_avg:91.60ms
step:332/1660 train_time:30411ms step_avg:91.60ms
step:333/1660 train_time:30502ms step_avg:91.60ms
step:334/1660 train_time:30592ms step_avg:91.59ms
step:335/1660 train_time:30682ms step_avg:91.59ms
step:336/1660 train_time:30774ms step_avg:91.59ms
step:337/1660 train_time:30865ms step_avg:91.59ms
step:338/1660 train_time:30956ms step_avg:91.59ms
step:339/1660 train_time:31049ms step_avg:91.59ms
step:340/1660 train_time:31140ms step_avg:91.59ms
step:341/1660 train_time:31232ms step_avg:91.59ms
step:342/1660 train_time:31323ms step_avg:91.59ms
step:343/1660 train_time:31415ms step_avg:91.59ms
step:344/1660 train_time:31506ms step_avg:91.59ms
step:345/1660 train_time:31598ms step_avg:91.59ms
step:346/1660 train_time:31689ms step_avg:91.59ms
step:347/1660 train_time:31779ms step_avg:91.58ms
step:348/1660 train_time:31870ms step_avg:91.58ms
step:349/1660 train_time:31962ms step_avg:91.58ms
step:350/1660 train_time:32053ms step_avg:91.58ms
step:351/1660 train_time:32144ms step_avg:91.58ms
step:352/1660 train_time:32237ms step_avg:91.58ms
step:353/1660 train_time:32329ms step_avg:91.58ms
step:354/1660 train_time:32421ms step_avg:91.58ms
step:355/1660 train_time:32513ms step_avg:91.59ms
step:356/1660 train_time:32604ms step_avg:91.58ms
step:357/1660 train_time:32696ms step_avg:91.59ms
step:358/1660 train_time:32787ms step_avg:91.58ms
step:359/1660 train_time:32878ms step_avg:91.58ms
step:360/1660 train_time:32970ms step_avg:91.58ms
step:361/1660 train_time:33061ms step_avg:91.58ms
step:362/1660 train_time:33152ms step_avg:91.58ms
step:363/1660 train_time:33243ms step_avg:91.58ms
step:364/1660 train_time:33335ms step_avg:91.58ms
step:365/1660 train_time:33426ms step_avg:91.58ms
step:366/1660 train_time:33520ms step_avg:91.58ms
step:367/1660 train_time:33611ms step_avg:91.58ms
step:368/1660 train_time:33702ms step_avg:91.58ms
step:369/1660 train_time:33793ms step_avg:91.58ms
step:370/1660 train_time:33885ms step_avg:91.58ms
step:371/1660 train_time:33976ms step_avg:91.58ms
step:372/1660 train_time:34067ms step_avg:91.58ms
step:373/1660 train_time:34158ms step_avg:91.58ms
step:374/1660 train_time:34250ms step_avg:91.58ms
step:375/1660 train_time:34342ms step_avg:91.58ms
step:375/1660 val_loss:3.8136 train_time:34436ms step_avg:91.83ms
step:376/1660 train_time:34458ms step_avg:91.64ms
step:377/1660 train_time:34533ms step_avg:91.60ms
step:378/1660 train_time:34631ms step_avg:91.62ms
step:379/1660 train_time:34722ms step_avg:91.61ms
step:380/1660 train_time:34812ms step_avg:91.61ms
step:381/1660 train_time:34902ms step_avg:91.61ms
step:382/1660 train_time:34993ms step_avg:91.60ms
step:383/1660 train_time:35083ms step_avg:91.60ms
step:384/1660 train_time:35173ms step_avg:91.60ms
step:385/1660 train_time:35264ms step_avg:91.59ms
step:386/1660 train_time:35354ms step_avg:91.59ms
step:387/1660 train_time:35447ms step_avg:91.59ms
step:388/1660 train_time:35541ms step_avg:91.60ms
step:389/1660 train_time:35635ms step_avg:91.61ms
step:390/1660 train_time:35727ms step_avg:91.61ms
step:391/1660 train_time:35818ms step_avg:91.61ms
step:392/1660 train_time:35909ms step_avg:91.60ms
step:393/1660 train_time:36000ms step_avg:91.60ms
step:394/1660 train_time:36091ms step_avg:91.60ms
step:395/1660 train_time:36181ms step_avg:91.60ms
step:396/1660 train_time:36272ms step_avg:91.60ms
step:397/1660 train_time:36363ms step_avg:91.59ms
step:398/1660 train_time:36455ms step_avg:91.59ms
step:399/1660 train_time:36547ms step_avg:91.60ms
step:400/1660 train_time:36640ms step_avg:91.60ms
step:401/1660 train_time:36732ms step_avg:91.60ms
step:402/1660 train_time:36823ms step_avg:91.60ms
step:403/1660 train_time:36914ms step_avg:91.60ms
step:404/1660 train_time:37005ms step_avg:91.60ms
step:405/1660 train_time:37096ms step_avg:91.60ms
step:406/1660 train_time:37187ms step_avg:91.59ms
step:407/1660 train_time:37279ms step_avg:91.59ms
step:408/1660 train_time:37370ms step_avg:91.59ms
step:409/1660 train_time:37463ms step_avg:91.60ms
step:410/1660 train_time:37555ms step_avg:91.60ms
step:411/1660 train_time:37646ms step_avg:91.60ms
step:412/1660 train_time:37739ms step_avg:91.60ms
step:413/1660 train_time:37830ms step_avg:91.60ms
step:414/1660 train_time:37922ms step_avg:91.60ms
step:415/1660 train_time:38012ms step_avg:91.60ms
step:416/1660 train_time:38103ms step_avg:91.59ms
step:417/1660 train_time:38194ms step_avg:91.59ms
step:418/1660 train_time:38286ms step_avg:91.59ms
step:419/1660 train_time:38378ms step_avg:91.59ms
step:420/1660 train_time:38470ms step_avg:91.60ms
step:421/1660 train_time:38562ms step_avg:91.60ms
step:422/1660 train_time:38654ms step_avg:91.60ms
step:423/1660 train_time:38746ms step_avg:91.60ms
step:424/1660 train_time:38838ms step_avg:91.60ms
step:425/1660 train_time:38929ms step_avg:91.60ms
step:426/1660 train_time:39020ms step_avg:91.60ms
step:427/1660 train_time:39111ms step_avg:91.60ms
step:428/1660 train_time:39202ms step_avg:91.59ms
step:429/1660 train_time:39293ms step_avg:91.59ms
step:430/1660 train_time:39384ms step_avg:91.59ms
step:431/1660 train_time:39475ms step_avg:91.59ms
step:432/1660 train_time:39567ms step_avg:91.59ms
step:433/1660 train_time:39659ms step_avg:91.59ms
step:434/1660 train_time:39751ms step_avg:91.59ms
step:435/1660 train_time:39843ms step_avg:91.59ms
step:436/1660 train_time:39934ms step_avg:91.59ms
step:437/1660 train_time:40025ms step_avg:91.59ms
step:438/1660 train_time:40117ms step_avg:91.59ms
step:439/1660 train_time:40209ms step_avg:91.59ms
step:440/1660 train_time:40300ms step_avg:91.59ms
step:441/1660 train_time:40392ms step_avg:91.59ms
step:442/1660 train_time:40484ms step_avg:91.59ms
step:443/1660 train_time:40575ms step_avg:91.59ms
step:444/1660 train_time:40668ms step_avg:91.59ms
step:445/1660 train_time:40760ms step_avg:91.60ms
step:446/1660 train_time:40852ms step_avg:91.60ms
step:447/1660 train_time:40944ms step_avg:91.60ms
step:448/1660 train_time:41034ms step_avg:91.59ms
step:449/1660 train_time:41125ms step_avg:91.59ms
step:450/1660 train_time:41216ms step_avg:91.59ms
step:451/1660 train_time:41309ms step_avg:91.59ms
step:452/1660 train_time:41400ms step_avg:91.59ms
step:453/1660 train_time:41491ms step_avg:91.59ms
step:454/1660 train_time:41583ms step_avg:91.59ms
step:455/1660 train_time:41674ms step_avg:91.59ms
step:456/1660 train_time:41767ms step_avg:91.59ms
step:457/1660 train_time:41858ms step_avg:91.59ms
step:458/1660 train_time:41950ms step_avg:91.59ms
step:459/1660 train_time:42041ms step_avg:91.59ms
step:460/1660 train_time:42133ms step_avg:91.59ms
step:461/1660 train_time:42224ms step_avg:91.59ms
step:462/1660 train_time:42315ms step_avg:91.59ms
step:463/1660 train_time:42407ms step_avg:91.59ms
step:464/1660 train_time:42498ms step_avg:91.59ms
step:465/1660 train_time:42590ms step_avg:91.59ms
step:466/1660 train_time:42682ms step_avg:91.59ms
step:467/1660 train_time:42773ms step_avg:91.59ms
step:468/1660 train_time:42864ms step_avg:91.59ms
step:469/1660 train_time:42955ms step_avg:91.59ms
step:470/1660 train_time:43047ms step_avg:91.59ms
step:471/1660 train_time:43138ms step_avg:91.59ms
step:472/1660 train_time:43229ms step_avg:91.59ms
step:473/1660 train_time:43320ms step_avg:91.59ms
step:474/1660 train_time:43412ms step_avg:91.59ms
step:475/1660 train_time:43504ms step_avg:91.59ms
step:476/1660 train_time:43595ms step_avg:91.59ms
step:477/1660 train_time:43688ms step_avg:91.59ms
step:478/1660 train_time:43780ms step_avg:91.59ms
step:479/1660 train_time:43871ms step_avg:91.59ms
step:480/1660 train_time:43963ms step_avg:91.59ms
step:481/1660 train_time:44054ms step_avg:91.59ms
step:482/1660 train_time:44145ms step_avg:91.59ms
step:483/1660 train_time:44236ms step_avg:91.59ms
step:484/1660 train_time:44328ms step_avg:91.59ms
step:485/1660 train_time:44419ms step_avg:91.59ms
step:486/1660 train_time:44511ms step_avg:91.59ms
step:487/1660 train_time:44602ms step_avg:91.59ms
step:488/1660 train_time:44693ms step_avg:91.58ms
step:489/1660 train_time:44784ms step_avg:91.58ms
step:490/1660 train_time:44876ms step_avg:91.58ms
step:491/1660 train_time:44968ms step_avg:91.59ms
step:492/1660 train_time:45060ms step_avg:91.58ms
step:493/1660 train_time:45151ms step_avg:91.58ms
step:494/1660 train_time:45243ms step_avg:91.58ms
step:495/1660 train_time:45334ms step_avg:91.58ms
step:496/1660 train_time:45426ms step_avg:91.58ms
step:497/1660 train_time:45517ms step_avg:91.58ms
step:498/1660 train_time:45608ms step_avg:91.58ms
step:499/1660 train_time:45700ms step_avg:91.58ms
step:500/1660 train_time:45792ms step_avg:91.58ms
step:500/1660 val_loss:3.7121 train_time:45885ms step_avg:91.77ms
step:501/1660 train_time:45908ms step_avg:91.63ms
step:502/1660 train_time:45981ms step_avg:91.60ms
step:503/1660 train_time:46076ms step_avg:91.60ms
step:504/1660 train_time:46169ms step_avg:91.61ms
step:505/1660 train_time:46260ms step_avg:91.60ms
step:506/1660 train_time:46351ms step_avg:91.60ms
step:507/1660 train_time:46441ms step_avg:91.60ms
step:508/1660 train_time:46531ms step_avg:91.60ms
step:509/1660 train_time:46621ms step_avg:91.59ms
step:510/1660 train_time:46712ms step_avg:91.59ms
step:511/1660 train_time:46803ms step_avg:91.59ms
step:512/1660 train_time:46897ms step_avg:91.59ms
step:513/1660 train_time:46991ms step_avg:91.60ms
step:514/1660 train_time:47083ms step_avg:91.60ms
step:515/1660 train_time:47175ms step_avg:91.60ms
step:516/1660 train_time:47266ms step_avg:91.60ms
step:517/1660 train_time:47357ms step_avg:91.60ms
step:518/1660 train_time:47448ms step_avg:91.60ms
step:519/1660 train_time:47539ms step_avg:91.60ms
step:520/1660 train_time:47629ms step_avg:91.59ms
step:521/1660 train_time:47719ms step_avg:91.59ms
step:522/1660 train_time:47811ms step_avg:91.59ms
step:523/1660 train_time:47903ms step_avg:91.59ms
step:524/1660 train_time:47996ms step_avg:91.60ms
step:525/1660 train_time:48089ms step_avg:91.60ms
step:526/1660 train_time:48181ms step_avg:91.60ms
step:527/1660 train_time:48272ms step_avg:91.60ms
step:528/1660 train_time:48363ms step_avg:91.60ms
step:529/1660 train_time:48455ms step_avg:91.60ms
step:530/1660 train_time:48545ms step_avg:91.60ms
step:531/1660 train_time:48636ms step_avg:91.59ms
step:532/1660 train_time:48728ms step_avg:91.59ms
step:533/1660 train_time:48820ms step_avg:91.60ms
step:534/1660 train_time:48913ms step_avg:91.60ms
step:535/1660 train_time:49004ms step_avg:91.60ms
step:536/1660 train_time:49096ms step_avg:91.60ms
step:537/1660 train_time:49189ms step_avg:91.60ms
step:538/1660 train_time:49281ms step_avg:91.60ms
step:539/1660 train_time:49371ms step_avg:91.60ms
step:540/1660 train_time:49462ms step_avg:91.60ms
step:541/1660 train_time:49553ms step_avg:91.60ms
step:542/1660 train_time:49644ms step_avg:91.59ms
step:543/1660 train_time:49736ms step_avg:91.59ms
step:544/1660 train_time:49828ms step_avg:91.60ms
step:545/1660 train_time:49920ms step_avg:91.60ms
step:546/1660 train_time:50012ms step_avg:91.60ms
step:547/1660 train_time:50104ms step_avg:91.60ms
step:548/1660 train_time:50197ms step_avg:91.60ms
step:549/1660 train_time:50289ms step_avg:91.60ms
step:550/1660 train_time:50380ms step_avg:91.60ms
step:551/1660 train_time:50471ms step_avg:91.60ms
step:552/1660 train_time:50562ms step_avg:91.60ms
step:553/1660 train_time:50653ms step_avg:91.60ms
step:554/1660 train_time:50744ms step_avg:91.60ms
step:555/1660 train_time:50836ms step_avg:91.60ms
step:556/1660 train_time:50930ms step_avg:91.60ms
step:557/1660 train_time:51024ms step_avg:91.60ms
step:558/1660 train_time:51118ms step_avg:91.61ms
step:559/1660 train_time:51211ms step_avg:91.61ms
step:560/1660 train_time:51304ms step_avg:91.61ms
step:561/1660 train_time:51398ms step_avg:91.62ms
step:562/1660 train_time:51491ms step_avg:91.62ms
step:563/1660 train_time:51583ms step_avg:91.62ms
step:564/1660 train_time:51675ms step_avg:91.62ms
step:565/1660 train_time:51767ms step_avg:91.62ms
step:566/1660 train_time:51861ms step_avg:91.63ms
step:567/1660 train_time:51954ms step_avg:91.63ms
step:568/1660 train_time:52046ms step_avg:91.63ms
step:569/1660 train_time:52140ms step_avg:91.63ms
step:570/1660 train_time:52233ms step_avg:91.64ms
step:571/1660 train_time:52326ms step_avg:91.64ms
step:572/1660 train_time:52419ms step_avg:91.64ms
step:573/1660 train_time:52512ms step_avg:91.64ms
step:574/1660 train_time:52604ms step_avg:91.64ms
step:575/1660 train_time:52697ms step_avg:91.65ms
step:576/1660 train_time:52790ms step_avg:91.65ms
step:577/1660 train_time:52883ms step_avg:91.65ms
step:578/1660 train_time:52976ms step_avg:91.65ms
step:579/1660 train_time:53069ms step_avg:91.66ms
step:580/1660 train_time:53161ms step_avg:91.66ms
step:581/1660 train_time:53254ms step_avg:91.66ms
step:582/1660 train_time:53347ms step_avg:91.66ms
step:583/1660 train_time:53441ms step_avg:91.67ms
step:584/1660 train_time:53534ms step_avg:91.67ms
step:585/1660 train_time:53627ms step_avg:91.67ms
step:586/1660 train_time:53719ms step_avg:91.67ms
step:587/1660 train_time:53812ms step_avg:91.67ms
step:588/1660 train_time:53905ms step_avg:91.68ms
step:589/1660 train_time:53999ms step_avg:91.68ms
step:590/1660 train_time:54092ms step_avg:91.68ms
step:591/1660 train_time:54186ms step_avg:91.68ms
step:592/1660 train_time:54278ms step_avg:91.69ms
step:593/1660 train_time:54372ms step_avg:91.69ms
step:594/1660 train_time:54464ms step_avg:91.69ms
step:595/1660 train_time:54557ms step_avg:91.69ms
step:596/1660 train_time:54650ms step_avg:91.69ms
step:597/1660 train_time:54742ms step_avg:91.70ms
step:598/1660 train_time:54836ms step_avg:91.70ms
step:599/1660 train_time:54929ms step_avg:91.70ms
step:600/1660 train_time:55022ms step_avg:91.70ms
step:601/1660 train_time:55115ms step_avg:91.70ms
step:602/1660 train_time:55207ms step_avg:91.71ms
step:603/1660 train_time:55300ms step_avg:91.71ms
step:604/1660 train_time:55394ms step_avg:91.71ms
step:605/1660 train_time:55486ms step_avg:91.71ms
step:606/1660 train_time:55579ms step_avg:91.71ms
step:607/1660 train_time:55672ms step_avg:91.72ms
step:608/1660 train_time:55765ms step_avg:91.72ms
step:609/1660 train_time:55858ms step_avg:91.72ms
step:610/1660 train_time:55950ms step_avg:91.72ms
step:611/1660 train_time:56042ms step_avg:91.72ms
step:612/1660 train_time:56135ms step_avg:91.72ms
step:613/1660 train_time:56228ms step_avg:91.73ms
step:614/1660 train_time:56320ms step_avg:91.73ms
step:615/1660 train_time:56412ms step_avg:91.73ms
step:616/1660 train_time:56505ms step_avg:91.73ms
step:617/1660 train_time:56598ms step_avg:91.73ms
step:618/1660 train_time:56691ms step_avg:91.73ms
step:619/1660 train_time:56784ms step_avg:91.73ms
step:620/1660 train_time:56876ms step_avg:91.74ms
step:621/1660 train_time:56968ms step_avg:91.74ms
step:622/1660 train_time:57060ms step_avg:91.74ms
step:623/1660 train_time:57153ms step_avg:91.74ms
step:624/1660 train_time:57246ms step_avg:91.74ms
step:625/1660 train_time:57339ms step_avg:91.74ms
step:625/1660 val_loss:3.6115 train_time:57434ms step_avg:91.89ms
step:626/1660 train_time:57455ms step_avg:91.78ms
step:627/1660 train_time:57527ms step_avg:91.75ms
step:628/1660 train_time:57628ms step_avg:91.76ms
step:629/1660 train_time:57722ms step_avg:91.77ms
step:630/1660 train_time:57813ms step_avg:91.77ms
step:631/1660 train_time:57904ms step_avg:91.77ms
step:632/1660 train_time:57995ms step_avg:91.76ms
step:633/1660 train_time:58087ms step_avg:91.76ms
step:634/1660 train_time:58178ms step_avg:91.76ms
step:635/1660 train_time:58269ms step_avg:91.76ms
step:636/1660 train_time:58364ms step_avg:91.77ms
step:637/1660 train_time:58461ms step_avg:91.77ms
step:638/1660 train_time:58556ms step_avg:91.78ms
step:639/1660 train_time:58650ms step_avg:91.78ms
step:640/1660 train_time:58743ms step_avg:91.79ms
step:641/1660 train_time:58835ms step_avg:91.79ms
step:642/1660 train_time:58927ms step_avg:91.79ms
step:643/1660 train_time:59019ms step_avg:91.79ms
step:644/1660 train_time:59111ms step_avg:91.79ms
step:645/1660 train_time:59202ms step_avg:91.79ms
step:646/1660 train_time:59295ms step_avg:91.79ms
step:647/1660 train_time:59389ms step_avg:91.79ms
step:648/1660 train_time:59483ms step_avg:91.79ms
step:649/1660 train_time:59578ms step_avg:91.80ms
step:650/1660 train_time:59671ms step_avg:91.80ms
step:651/1660 train_time:59763ms step_avg:91.80ms
step:652/1660 train_time:59855ms step_avg:91.80ms
step:653/1660 train_time:59948ms step_avg:91.80ms
step:654/1660 train_time:60040ms step_avg:91.80ms
step:655/1660 train_time:60131ms step_avg:91.80ms
step:656/1660 train_time:60224ms step_avg:91.81ms
step:657/1660 train_time:60317ms step_avg:91.81ms
step:658/1660 train_time:60409ms step_avg:91.81ms
step:659/1660 train_time:60503ms step_avg:91.81ms
step:660/1660 train_time:60598ms step_avg:91.82ms
step:661/1660 train_time:60691ms step_avg:91.82ms
step:662/1660 train_time:60783ms step_avg:91.82ms
step:663/1660 train_time:60877ms step_avg:91.82ms
step:664/1660 train_time:60970ms step_avg:91.82ms
step:665/1660 train_time:61061ms step_avg:91.82ms
step:666/1660 train_time:61154ms step_avg:91.82ms
step:667/1660 train_time:61246ms step_avg:91.82ms
step:668/1660 train_time:61338ms step_avg:91.82ms
step:669/1660 train_time:61431ms step_avg:91.82ms
step:670/1660 train_time:61525ms step_avg:91.83ms
step:671/1660 train_time:61619ms step_avg:91.83ms
step:672/1660 train_time:61712ms step_avg:91.83ms
step:673/1660 train_time:61804ms step_avg:91.83ms
step:674/1660 train_time:61899ms step_avg:91.84ms
step:675/1660 train_time:61991ms step_avg:91.84ms
step:676/1660 train_time:62083ms step_avg:91.84ms
step:677/1660 train_time:62177ms step_avg:91.84ms
step:678/1660 train_time:62270ms step_avg:91.84ms
step:679/1660 train_time:62362ms step_avg:91.84ms
step:680/1660 train_time:62454ms step_avg:91.84ms
step:681/1660 train_time:62547ms step_avg:91.85ms
step:682/1660 train_time:62640ms step_avg:91.85ms
step:683/1660 train_time:62732ms step_avg:91.85ms
step:684/1660 train_time:62825ms step_avg:91.85ms
step:685/1660 train_time:62917ms step_avg:91.85ms
step:686/1660 train_time:63010ms step_avg:91.85ms
step:687/1660 train_time:63102ms step_avg:91.85ms
step:688/1660 train_time:63196ms step_avg:91.85ms
step:689/1660 train_time:63289ms step_avg:91.86ms
step:690/1660 train_time:63381ms step_avg:91.86ms
step:691/1660 train_time:63474ms step_avg:91.86ms
step:692/1660 train_time:63567ms step_avg:91.86ms
step:693/1660 train_time:63660ms step_avg:91.86ms
step:694/1660 train_time:63753ms step_avg:91.86ms
step:695/1660 train_time:63847ms step_avg:91.87ms
step:696/1660 train_time:63939ms step_avg:91.87ms
step:697/1660 train_time:64031ms step_avg:91.87ms
step:698/1660 train_time:64124ms step_avg:91.87ms
step:699/1660 train_time:64218ms step_avg:91.87ms
step:700/1660 train_time:64310ms step_avg:91.87ms
step:701/1660 train_time:64402ms step_avg:91.87ms
step:702/1660 train_time:64496ms step_avg:91.87ms
step:703/1660 train_time:64590ms step_avg:91.88ms
step:704/1660 train_time:64681ms step_avg:91.88ms
step:705/1660 train_time:64774ms step_avg:91.88ms
step:706/1660 train_time:64867ms step_avg:91.88ms
step:707/1660 train_time:64960ms step_avg:91.88ms
step:708/1660 train_time:65052ms step_avg:91.88ms
step:709/1660 train_time:65144ms step_avg:91.88ms
step:710/1660 train_time:65237ms step_avg:91.88ms
step:711/1660 train_time:65330ms step_avg:91.88ms
step:712/1660 train_time:65422ms step_avg:91.88ms
step:713/1660 train_time:65515ms step_avg:91.89ms
step:714/1660 train_time:65608ms step_avg:91.89ms
step:715/1660 train_time:65701ms step_avg:91.89ms
step:716/1660 train_time:65795ms step_avg:91.89ms
step:717/1660 train_time:65888ms step_avg:91.89ms
step:718/1660 train_time:65980ms step_avg:91.89ms
step:719/1660 train_time:66072ms step_avg:91.89ms
step:720/1660 train_time:66164ms step_avg:91.89ms
step:721/1660 train_time:66257ms step_avg:91.90ms
step:722/1660 train_time:66351ms step_avg:91.90ms
step:723/1660 train_time:66443ms step_avg:91.90ms
step:724/1660 train_time:66536ms step_avg:91.90ms
step:725/1660 train_time:66629ms step_avg:91.90ms
step:726/1660 train_time:66721ms step_avg:91.90ms
step:727/1660 train_time:66815ms step_avg:91.91ms
step:728/1660 train_time:66908ms step_avg:91.91ms
step:729/1660 train_time:67000ms step_avg:91.91ms
step:730/1660 train_time:67093ms step_avg:91.91ms
step:731/1660 train_time:67186ms step_avg:91.91ms
step:732/1660 train_time:67279ms step_avg:91.91ms
step:733/1660 train_time:67371ms step_avg:91.91ms
step:734/1660 train_time:67464ms step_avg:91.91ms
step:735/1660 train_time:67557ms step_avg:91.91ms
step:736/1660 train_time:67650ms step_avg:91.92ms
step:737/1660 train_time:67743ms step_avg:91.92ms
step:738/1660 train_time:67835ms step_avg:91.92ms
step:739/1660 train_time:67927ms step_avg:91.92ms
step:740/1660 train_time:68020ms step_avg:91.92ms
step:741/1660 train_time:68113ms step_avg:91.92ms
step:742/1660 train_time:68205ms step_avg:91.92ms
step:743/1660 train_time:68299ms step_avg:91.92ms
step:744/1660 train_time:68391ms step_avg:91.92ms
step:745/1660 train_time:68483ms step_avg:91.92ms
step:746/1660 train_time:68576ms step_avg:91.93ms
step:747/1660 train_time:68670ms step_avg:91.93ms
step:748/1660 train_time:68762ms step_avg:91.93ms
step:749/1660 train_time:68856ms step_avg:91.93ms
step:750/1660 train_time:68949ms step_avg:91.93ms
step:750/1660 val_loss:3.5579 train_time:69043ms step_avg:92.06ms
step:751/1660 train_time:69064ms step_avg:91.96ms
step:752/1660 train_time:69141ms step_avg:91.94ms
step:753/1660 train_time:69238ms step_avg:91.95ms
step:754/1660 train_time:69332ms step_avg:91.95ms
step:755/1660 train_time:69424ms step_avg:91.95ms
step:756/1660 train_time:69515ms step_avg:91.95ms
step:757/1660 train_time:69607ms step_avg:91.95ms
step:758/1660 train_time:69699ms step_avg:91.95ms
step:759/1660 train_time:69791ms step_avg:91.95ms
step:760/1660 train_time:69882ms step_avg:91.95ms
step:761/1660 train_time:69974ms step_avg:91.95ms
step:762/1660 train_time:70068ms step_avg:91.95ms
step:763/1660 train_time:70164ms step_avg:91.96ms
step:764/1660 train_time:70260ms step_avg:91.96ms
step:765/1660 train_time:70352ms step_avg:91.96ms
step:766/1660 train_time:70444ms step_avg:91.96ms
step:767/1660 train_time:70537ms step_avg:91.96ms
step:768/1660 train_time:70629ms step_avg:91.96ms
step:769/1660 train_time:70721ms step_avg:91.96ms
step:770/1660 train_time:70813ms step_avg:91.96ms
step:771/1660 train_time:70904ms step_avg:91.96ms
step:772/1660 train_time:70998ms step_avg:91.97ms
step:773/1660 train_time:71092ms step_avg:91.97ms
step:774/1660 train_time:71187ms step_avg:91.97ms
step:775/1660 train_time:71281ms step_avg:91.98ms
step:776/1660 train_time:71375ms step_avg:91.98ms
step:777/1660 train_time:71468ms step_avg:91.98ms
step:778/1660 train_time:71560ms step_avg:91.98ms
step:779/1660 train_time:71652ms step_avg:91.98ms
step:780/1660 train_time:71744ms step_avg:91.98ms
step:781/1660 train_time:71836ms step_avg:91.98ms
step:782/1660 train_time:71928ms step_avg:91.98ms
step:783/1660 train_time:72021ms step_avg:91.98ms
step:784/1660 train_time:72113ms step_avg:91.98ms
step:785/1660 train_time:72207ms step_avg:91.98ms
step:786/1660 train_time:72302ms step_avg:91.99ms
step:787/1660 train_time:72396ms step_avg:91.99ms
step:788/1660 train_time:72488ms step_avg:91.99ms
step:789/1660 train_time:72582ms step_avg:91.99ms
step:790/1660 train_time:72674ms step_avg:91.99ms
step:791/1660 train_time:72765ms step_avg:91.99ms
step:792/1660 train_time:72858ms step_avg:91.99ms
step:793/1660 train_time:72950ms step_avg:91.99ms
step:794/1660 train_time:73042ms step_avg:91.99ms
step:795/1660 train_time:73137ms step_avg:92.00ms
step:796/1660 train_time:73231ms step_avg:92.00ms
step:797/1660 train_time:73324ms step_avg:92.00ms
step:798/1660 train_time:73417ms step_avg:92.00ms
step:799/1660 train_time:73511ms step_avg:92.00ms
step:800/1660 train_time:73603ms step_avg:92.00ms
step:801/1660 train_time:73695ms step_avg:92.00ms
step:802/1660 train_time:73787ms step_avg:92.00ms
step:803/1660 train_time:73879ms step_avg:92.00ms
step:804/1660 train_time:73972ms step_avg:92.00ms
step:805/1660 train_time:74064ms step_avg:92.01ms
step:806/1660 train_time:74159ms step_avg:92.01ms
step:807/1660 train_time:74252ms step_avg:92.01ms
step:808/1660 train_time:74345ms step_avg:92.01ms
step:809/1660 train_time:74440ms step_avg:92.01ms
step:810/1660 train_time:74533ms step_avg:92.02ms
step:811/1660 train_time:74625ms step_avg:92.02ms
step:812/1660 train_time:74718ms step_avg:92.02ms
step:813/1660 train_time:74810ms step_avg:92.02ms
step:814/1660 train_time:74903ms step_avg:92.02ms
step:815/1660 train_time:74996ms step_avg:92.02ms
step:816/1660 train_time:75089ms step_avg:92.02ms
step:817/1660 train_time:75182ms step_avg:92.02ms
step:818/1660 train_time:75275ms step_avg:92.02ms
step:819/1660 train_time:75368ms step_avg:92.02ms
step:820/1660 train_time:75461ms step_avg:92.03ms
step:821/1660 train_time:75555ms step_avg:92.03ms
step:822/1660 train_time:75647ms step_avg:92.03ms
step:823/1660 train_time:75740ms step_avg:92.03ms
step:824/1660 train_time:75833ms step_avg:92.03ms
step:825/1660 train_time:75925ms step_avg:92.03ms
step:826/1660 train_time:76019ms step_avg:92.03ms
step:827/1660 train_time:76112ms step_avg:92.03ms
step:828/1660 train_time:76204ms step_avg:92.03ms
step:829/1660 train_time:76299ms step_avg:92.04ms
step:830/1660 train_time:76392ms step_avg:92.04ms
step:831/1660 train_time:76483ms step_avg:92.04ms
step:832/1660 train_time:76576ms step_avg:92.04ms
step:833/1660 train_time:76669ms step_avg:92.04ms
step:834/1660 train_time:76761ms step_avg:92.04ms
step:835/1660 train_time:76854ms step_avg:92.04ms
step:836/1660 train_time:76946ms step_avg:92.04ms
step:837/1660 train_time:77039ms step_avg:92.04ms
step:838/1660 train_time:77132ms step_avg:92.04ms
step:839/1660 train_time:77225ms step_avg:92.04ms
step:840/1660 train_time:77319ms step_avg:92.05ms
step:841/1660 train_time:77412ms step_avg:92.05ms
step:842/1660 train_time:77504ms step_avg:92.05ms
step:843/1660 train_time:77596ms step_avg:92.05ms
step:844/1660 train_time:77689ms step_avg:92.05ms
step:845/1660 train_time:77782ms step_avg:92.05ms
step:846/1660 train_time:77876ms step_avg:92.05ms
step:847/1660 train_time:77968ms step_avg:92.05ms
step:848/1660 train_time:78062ms step_avg:92.05ms
step:849/1660 train_time:78156ms step_avg:92.06ms
step:850/1660 train_time:78248ms step_avg:92.06ms
step:851/1660 train_time:78342ms step_avg:92.06ms
step:852/1660 train_time:78434ms step_avg:92.06ms
step:853/1660 train_time:78527ms step_avg:92.06ms
step:854/1660 train_time:78619ms step_avg:92.06ms
step:855/1660 train_time:78711ms step_avg:92.06ms
step:856/1660 train_time:78804ms step_avg:92.06ms
step:857/1660 train_time:78898ms step_avg:92.06ms
step:858/1660 train_time:78991ms step_avg:92.06ms
step:859/1660 train_time:79084ms step_avg:92.06ms
step:860/1660 train_time:79178ms step_avg:92.07ms
step:861/1660 train_time:79271ms step_avg:92.07ms
step:862/1660 train_time:79364ms step_avg:92.07ms
step:863/1660 train_time:79456ms step_avg:92.07ms
step:864/1660 train_time:79548ms step_avg:92.07ms
step:865/1660 train_time:79641ms step_avg:92.07ms
step:866/1660 train_time:79734ms step_avg:92.07ms
step:867/1660 train_time:79827ms step_avg:92.07ms
step:868/1660 train_time:79919ms step_avg:92.07ms
step:869/1660 train_time:80012ms step_avg:92.07ms
step:870/1660 train_time:80105ms step_avg:92.07ms
step:871/1660 train_time:80199ms step_avg:92.08ms
step:872/1660 train_time:80292ms step_avg:92.08ms
step:873/1660 train_time:80385ms step_avg:92.08ms
step:874/1660 train_time:80477ms step_avg:92.08ms
step:875/1660 train_time:80570ms step_avg:92.08ms
step:875/1660 val_loss:3.5145 train_time:80664ms step_avg:92.19ms
step:876/1660 train_time:80685ms step_avg:92.11ms
step:877/1660 train_time:80762ms step_avg:92.09ms
step:878/1660 train_time:80859ms step_avg:92.09ms
step:879/1660 train_time:80951ms step_avg:92.09ms
step:880/1660 train_time:81042ms step_avg:92.09ms
step:881/1660 train_time:81134ms step_avg:92.09ms
step:882/1660 train_time:81226ms step_avg:92.09ms
step:883/1660 train_time:81317ms step_avg:92.09ms
step:884/1660 train_time:81409ms step_avg:92.09ms
step:885/1660 train_time:81501ms step_avg:92.09ms
step:886/1660 train_time:81593ms step_avg:92.09ms
step:887/1660 train_time:81688ms step_avg:92.09ms
step:888/1660 train_time:81782ms step_avg:92.10ms
step:889/1660 train_time:81877ms step_avg:92.10ms
step:890/1660 train_time:81971ms step_avg:92.10ms
step:891/1660 train_time:82064ms step_avg:92.10ms
step:892/1660 train_time:82157ms step_avg:92.10ms
step:893/1660 train_time:82249ms step_avg:92.10ms
step:894/1660 train_time:82341ms step_avg:92.10ms
step:895/1660 train_time:82432ms step_avg:92.10ms
step:896/1660 train_time:82525ms step_avg:92.10ms
step:897/1660 train_time:82619ms step_avg:92.11ms
step:898/1660 train_time:82712ms step_avg:92.11ms
step:899/1660 train_time:82806ms step_avg:92.11ms
step:900/1660 train_time:82899ms step_avg:92.11ms
step:901/1660 train_time:82993ms step_avg:92.11ms
step:902/1660 train_time:83084ms step_avg:92.11ms
step:903/1660 train_time:83178ms step_avg:92.11ms
step:904/1660 train_time:83270ms step_avg:92.11ms
step:905/1660 train_time:83362ms step_avg:92.11ms
step:906/1660 train_time:83454ms step_avg:92.11ms
step:907/1660 train_time:83546ms step_avg:92.11ms
step:908/1660 train_time:83638ms step_avg:92.11ms
step:909/1660 train_time:83731ms step_avg:92.11ms
step:910/1660 train_time:83824ms step_avg:92.11ms
step:911/1660 train_time:83918ms step_avg:92.12ms
step:912/1660 train_time:84012ms step_avg:92.12ms
step:913/1660 train_time:84104ms step_avg:92.12ms
step:914/1660 train_time:84197ms step_avg:92.12ms
step:915/1660 train_time:84289ms step_avg:92.12ms
step:916/1660 train_time:84381ms step_avg:92.12ms
step:917/1660 train_time:84474ms step_avg:92.12ms
step:918/1660 train_time:84567ms step_avg:92.12ms
step:919/1660 train_time:84659ms step_avg:92.12ms
step:920/1660 train_time:84752ms step_avg:92.12ms
step:921/1660 train_time:84845ms step_avg:92.12ms
step:922/1660 train_time:84938ms step_avg:92.12ms
step:923/1660 train_time:85031ms step_avg:92.12ms
step:924/1660 train_time:85123ms step_avg:92.12ms
step:925/1660 train_time:85217ms step_avg:92.13ms
step:926/1660 train_time:85309ms step_avg:92.13ms
step:927/1660 train_time:85401ms step_avg:92.13ms
step:928/1660 train_time:85494ms step_avg:92.13ms
step:929/1660 train_time:85586ms step_avg:92.13ms
step:930/1660 train_time:85679ms step_avg:92.13ms
step:931/1660 train_time:85772ms step_avg:92.13ms
step:932/1660 train_time:85865ms step_avg:92.13ms
step:933/1660 train_time:85958ms step_avg:92.13ms
step:934/1660 train_time:86051ms step_avg:92.13ms
step:935/1660 train_time:86143ms step_avg:92.13ms
step:936/1660 train_time:86237ms step_avg:92.13ms
step:937/1660 train_time:86330ms step_avg:92.13ms
step:938/1660 train_time:86422ms step_avg:92.13ms
step:939/1660 train_time:86515ms step_avg:92.13ms
step:940/1660 train_time:86607ms step_avg:92.14ms
step:941/1660 train_time:86700ms step_avg:92.14ms
step:942/1660 train_time:86793ms step_avg:92.14ms
step:943/1660 train_time:86885ms step_avg:92.14ms
step:944/1660 train_time:86978ms step_avg:92.14ms
step:945/1660 train_time:87073ms step_avg:92.14ms
step:946/1660 train_time:87166ms step_avg:92.14ms
step:947/1660 train_time:87259ms step_avg:92.14ms
step:948/1660 train_time:87351ms step_avg:92.14ms
step:949/1660 train_time:87443ms step_avg:92.14ms
step:950/1660 train_time:87537ms step_avg:92.14ms
step:951/1660 train_time:87629ms step_avg:92.14ms
step:952/1660 train_time:87722ms step_avg:92.15ms
step:953/1660 train_time:87815ms step_avg:92.15ms
step:954/1660 train_time:87908ms step_avg:92.15ms
step:955/1660 train_time:88000ms step_avg:92.15ms
step:956/1660 train_time:88095ms step_avg:92.15ms
step:957/1660 train_time:88187ms step_avg:92.15ms
step:958/1660 train_time:88280ms step_avg:92.15ms
step:959/1660 train_time:88373ms step_avg:92.15ms
step:960/1660 train_time:88466ms step_avg:92.15ms
step:961/1660 train_time:88558ms step_avg:92.15ms
step:962/1660 train_time:88651ms step_avg:92.15ms
step:963/1660 train_time:88743ms step_avg:92.15ms
step:964/1660 train_time:88836ms step_avg:92.15ms
step:965/1660 train_time:88929ms step_avg:92.15ms
step:966/1660 train_time:89021ms step_avg:92.15ms
step:967/1660 train_time:89115ms step_avg:92.16ms
step:968/1660 train_time:89207ms step_avg:92.16ms
step:969/1660 train_time:89300ms step_avg:92.16ms
step:970/1660 train_time:89394ms step_avg:92.16ms
step:971/1660 train_time:89487ms step_avg:92.16ms
step:972/1660 train_time:89580ms step_avg:92.16ms
step:973/1660 train_time:89671ms step_avg:92.16ms
step:974/1660 train_time:89763ms step_avg:92.16ms
step:975/1660 train_time:89856ms step_avg:92.16ms
step:976/1660 train_time:89948ms step_avg:92.16ms
step:977/1660 train_time:90041ms step_avg:92.16ms
step:978/1660 train_time:90134ms step_avg:92.16ms
step:979/1660 train_time:90227ms step_avg:92.16ms
step:980/1660 train_time:90320ms step_avg:92.16ms
step:981/1660 train_time:90413ms step_avg:92.16ms
step:982/1660 train_time:90505ms step_avg:92.16ms
step:983/1660 train_time:90598ms step_avg:92.16ms
step:984/1660 train_time:90691ms step_avg:92.17ms
step:985/1660 train_time:90783ms step_avg:92.17ms
step:986/1660 train_time:90878ms step_avg:92.17ms
step:987/1660 train_time:90971ms step_avg:92.17ms
step:988/1660 train_time:91063ms step_avg:92.17ms
step:989/1660 train_time:91156ms step_avg:92.17ms
step:990/1660 train_time:91249ms step_avg:92.17ms
step:991/1660 train_time:91341ms step_avg:92.17ms
step:992/1660 train_time:91434ms step_avg:92.17ms
step:993/1660 train_time:91527ms step_avg:92.17ms
step:994/1660 train_time:91620ms step_avg:92.17ms
step:995/1660 train_time:91712ms step_avg:92.17ms
step:996/1660 train_time:91805ms step_avg:92.17ms
step:997/1660 train_time:91898ms step_avg:92.17ms
step:998/1660 train_time:91991ms step_avg:92.18ms
step:999/1660 train_time:92083ms step_avg:92.18ms
step:1000/1660 train_time:92178ms step_avg:92.18ms
step:1000/1660 val_loss:3.4640 train_time:92272ms step_avg:92.27ms
step:1001/1660 train_time:92293ms step_avg:92.20ms
step:1002/1660 train_time:92369ms step_avg:92.18ms
step:1003/1660 train_time:92466ms step_avg:92.19ms
step:1004/1660 train_time:92558ms step_avg:92.19ms
step:1005/1660 train_time:92650ms step_avg:92.19ms
step:1006/1660 train_time:92741ms step_avg:92.19ms
step:1007/1660 train_time:92833ms step_avg:92.19ms
step:1008/1660 train_time:92925ms step_avg:92.19ms
step:1009/1660 train_time:93017ms step_avg:92.19ms
step:1010/1660 train_time:93109ms step_avg:92.19ms
step:1011/1660 train_time:93202ms step_avg:92.19ms
step:1012/1660 train_time:93298ms step_avg:92.19ms
step:1013/1660 train_time:93394ms step_avg:92.20ms
step:1014/1660 train_time:93489ms step_avg:92.20ms
step:1015/1660 train_time:93582ms step_avg:92.20ms
step:1016/1660 train_time:93674ms step_avg:92.20ms
step:1017/1660 train_time:93766ms step_avg:92.20ms
step:1018/1660 train_time:93858ms step_avg:92.20ms
step:1019/1660 train_time:93949ms step_avg:92.20ms
step:1020/1660 train_time:94041ms step_avg:92.20ms
step:1021/1660 train_time:94134ms step_avg:92.20ms
step:1022/1660 train_time:94227ms step_avg:92.20ms
step:1023/1660 train_time:94320ms step_avg:92.20ms
step:1024/1660 train_time:94416ms step_avg:92.20ms
step:1025/1660 train_time:94509ms step_avg:92.20ms
step:1026/1660 train_time:94602ms step_avg:92.21ms
step:1027/1660 train_time:94696ms step_avg:92.21ms
step:1028/1660 train_time:94788ms step_avg:92.21ms
step:1029/1660 train_time:94880ms step_avg:92.21ms
step:1030/1660 train_time:94972ms step_avg:92.21ms
step:1031/1660 train_time:95064ms step_avg:92.21ms
step:1032/1660 train_time:95156ms step_avg:92.21ms
step:1033/1660 train_time:95249ms step_avg:92.21ms
step:1034/1660 train_time:95342ms step_avg:92.21ms
step:1035/1660 train_time:95437ms step_avg:92.21ms
step:1036/1660 train_time:95529ms step_avg:92.21ms
step:1037/1660 train_time:95622ms step_avg:92.21ms
step:1038/1660 train_time:95716ms step_avg:92.21ms
step:1039/1660 train_time:95808ms step_avg:92.21ms
step:1040/1660 train_time:95900ms step_avg:92.21ms
step:1041/1660 train_time:95993ms step_avg:92.21ms
step:1042/1660 train_time:96085ms step_avg:92.21ms
step:1043/1660 train_time:96178ms step_avg:92.21ms
step:1044/1660 train_time:96270ms step_avg:92.21ms
step:1045/1660 train_time:96364ms step_avg:92.21ms
step:1046/1660 train_time:96457ms step_avg:92.22ms
step:1047/1660 train_time:96551ms step_avg:92.22ms
step:1048/1660 train_time:96644ms step_avg:92.22ms
step:1049/1660 train_time:96736ms step_avg:92.22ms
step:1050/1660 train_time:96829ms step_avg:92.22ms
step:1051/1660 train_time:96921ms step_avg:92.22ms
step:1052/1660 train_time:97014ms step_avg:92.22ms
step:1053/1660 train_time:97107ms step_avg:92.22ms
step:1054/1660 train_time:97199ms step_avg:92.22ms
step:1055/1660 train_time:97293ms step_avg:92.22ms
step:1056/1660 train_time:97386ms step_avg:92.22ms
step:1057/1660 train_time:97479ms step_avg:92.22ms
step:1058/1660 train_time:97572ms step_avg:92.22ms
step:1059/1660 train_time:97665ms step_avg:92.22ms
step:1060/1660 train_time:97758ms step_avg:92.22ms
step:1061/1660 train_time:97851ms step_avg:92.23ms
step:1062/1660 train_time:97943ms step_avg:92.23ms
step:1063/1660 train_time:98036ms step_avg:92.23ms
step:1064/1660 train_time:98129ms step_avg:92.23ms
step:1065/1660 train_time:98222ms step_avg:92.23ms
step:1066/1660 train_time:98315ms step_avg:92.23ms
step:1067/1660 train_time:98409ms step_avg:92.23ms
step:1068/1660 train_time:98501ms step_avg:92.23ms
step:1069/1660 train_time:98594ms step_avg:92.23ms
step:1070/1660 train_time:98688ms step_avg:92.23ms
step:1071/1660 train_time:98780ms step_avg:92.23ms
step:1072/1660 train_time:98874ms step_avg:92.23ms
step:1073/1660 train_time:98967ms step_avg:92.23ms
step:1074/1660 train_time:99059ms step_avg:92.23ms
step:1075/1660 train_time:99151ms step_avg:92.23ms
step:1076/1660 train_time:99243ms step_avg:92.23ms
step:1077/1660 train_time:99336ms step_avg:92.23ms
step:1078/1660 train_time:99429ms step_avg:92.23ms
step:1079/1660 train_time:99522ms step_avg:92.24ms
step:1080/1660 train_time:99616ms step_avg:92.24ms
step:1081/1660 train_time:99708ms step_avg:92.24ms
step:1082/1660 train_time:99800ms step_avg:92.24ms
step:1083/1660 train_time:99893ms step_avg:92.24ms
step:1084/1660 train_time:99986ms step_avg:92.24ms
step:1085/1660 train_time:100078ms step_avg:92.24ms
step:1086/1660 train_time:100171ms step_avg:92.24ms
step:1087/1660 train_time:100264ms step_avg:92.24ms
step:1088/1660 train_time:100357ms step_avg:92.24ms
step:1089/1660 train_time:100450ms step_avg:92.24ms
step:1090/1660 train_time:100542ms step_avg:92.24ms
step:1091/1660 train_time:100635ms step_avg:92.24ms
step:1092/1660 train_time:100728ms step_avg:92.24ms
step:1093/1660 train_time:100820ms step_avg:92.24ms
step:1094/1660 train_time:100916ms step_avg:92.24ms
step:1095/1660 train_time:101009ms step_avg:92.25ms
step:1096/1660 train_time:101101ms step_avg:92.25ms
step:1097/1660 train_time:101194ms step_avg:92.25ms
step:1098/1660 train_time:101287ms step_avg:92.25ms
step:1099/1660 train_time:101379ms step_avg:92.25ms
step:1100/1660 train_time:101472ms step_avg:92.25ms
step:1101/1660 train_time:101565ms step_avg:92.25ms
step:1102/1660 train_time:101657ms step_avg:92.25ms
step:1103/1660 train_time:101750ms step_avg:92.25ms
step:1104/1660 train_time:101843ms step_avg:92.25ms
step:1105/1660 train_time:101936ms step_avg:92.25ms
step:1106/1660 train_time:102029ms step_avg:92.25ms
step:1107/1660 train_time:102121ms step_avg:92.25ms
step:1108/1660 train_time:102214ms step_avg:92.25ms
step:1109/1660 train_time:102308ms step_avg:92.25ms
step:1110/1660 train_time:102401ms step_avg:92.25ms
step:1111/1660 train_time:102495ms step_avg:92.25ms
step:1112/1660 train_time:102589ms step_avg:92.26ms
step:1113/1660 train_time:102684ms step_avg:92.26ms
step:1114/1660 train_time:102776ms step_avg:92.26ms
step:1115/1660 train_time:102870ms step_avg:92.26ms
step:1116/1660 train_time:102963ms step_avg:92.26ms
step:1117/1660 train_time:103056ms step_avg:92.26ms
step:1118/1660 train_time:103149ms step_avg:92.26ms
step:1119/1660 train_time:103242ms step_avg:92.26ms
step:1120/1660 train_time:103335ms step_avg:92.26ms
step:1121/1660 train_time:103429ms step_avg:92.26ms
step:1122/1660 train_time:103522ms step_avg:92.27ms
step:1123/1660 train_time:103616ms step_avg:92.27ms
step:1124/1660 train_time:103711ms step_avg:92.27ms
step:1125/1660 train_time:103804ms step_avg:92.27ms
step:1125/1660 val_loss:3.4118 train_time:103899ms step_avg:92.35ms
step:1126/1660 train_time:103920ms step_avg:92.29ms
step:1127/1660 train_time:103997ms step_avg:92.28ms
step:1128/1660 train_time:104095ms step_avg:92.28ms
step:1129/1660 train_time:104187ms step_avg:92.28ms
step:1130/1660 train_time:104280ms step_avg:92.28ms
step:1131/1660 train_time:104372ms step_avg:92.28ms
step:1132/1660 train_time:104464ms step_avg:92.28ms
step:1133/1660 train_time:104557ms step_avg:92.28ms
step:1134/1660 train_time:104649ms step_avg:92.28ms
step:1135/1660 train_time:104742ms step_avg:92.28ms
step:1136/1660 train_time:104838ms step_avg:92.29ms
step:1137/1660 train_time:104933ms step_avg:92.29ms
step:1138/1660 train_time:105029ms step_avg:92.29ms
step:1139/1660 train_time:105125ms step_avg:92.30ms
step:1140/1660 train_time:105220ms step_avg:92.30ms
step:1141/1660 train_time:105312ms step_avg:92.30ms
step:1142/1660 train_time:105404ms step_avg:92.30ms
step:1143/1660 train_time:105496ms step_avg:92.30ms
step:1144/1660 train_time:105588ms step_avg:92.30ms
step:1145/1660 train_time:105681ms step_avg:92.30ms
step:1146/1660 train_time:105775ms step_avg:92.30ms
step:1147/1660 train_time:105869ms step_avg:92.30ms
step:1148/1660 train_time:105964ms step_avg:92.30ms
step:1149/1660 train_time:106059ms step_avg:92.31ms
step:1150/1660 train_time:106153ms step_avg:92.31ms
step:1151/1660 train_time:106246ms step_avg:92.31ms
step:1152/1660 train_time:106339ms step_avg:92.31ms
step:1153/1660 train_time:106432ms step_avg:92.31ms
step:1154/1660 train_time:106525ms step_avg:92.31ms
step:1155/1660 train_time:106618ms step_avg:92.31ms
step:1156/1660 train_time:106711ms step_avg:92.31ms
step:1157/1660 train_time:106805ms step_avg:92.31ms
step:1158/1660 train_time:106899ms step_avg:92.31ms
step:1159/1660 train_time:106993ms step_avg:92.32ms
step:1160/1660 train_time:107088ms step_avg:92.32ms
step:1161/1660 train_time:107183ms step_avg:92.32ms
step:1162/1660 train_time:107278ms step_avg:92.32ms
step:1163/1660 train_time:107370ms step_avg:92.32ms
step:1164/1660 train_time:107463ms step_avg:92.32ms
step:1165/1660 train_time:107555ms step_avg:92.32ms
step:1166/1660 train_time:107648ms step_avg:92.32ms
step:1167/1660 train_time:107742ms step_avg:92.32ms
step:1168/1660 train_time:107834ms step_avg:92.32ms
step:1169/1660 train_time:107929ms step_avg:92.33ms
step:1170/1660 train_time:108024ms step_avg:92.33ms
step:1171/1660 train_time:108120ms step_avg:92.33ms
step:1172/1660 train_time:108215ms step_avg:92.33ms
step:1173/1660 train_time:108308ms step_avg:92.33ms
step:1174/1660 train_time:108401ms step_avg:92.33ms
step:1175/1660 train_time:108494ms step_avg:92.34ms
step:1176/1660 train_time:108587ms step_avg:92.34ms
step:1177/1660 train_time:108680ms step_avg:92.34ms
step:1178/1660 train_time:108773ms step_avg:92.34ms
step:1179/1660 train_time:108866ms step_avg:92.34ms
step:1180/1660 train_time:108960ms step_avg:92.34ms
step:1181/1660 train_time:109055ms step_avg:92.34ms
step:1182/1660 train_time:109150ms step_avg:92.34ms
step:1183/1660 train_time:109243ms step_avg:92.34ms
step:1184/1660 train_time:109337ms step_avg:92.35ms
step:1185/1660 train_time:109430ms step_avg:92.35ms
step:1186/1660 train_time:109524ms step_avg:92.35ms
step:1187/1660 train_time:109617ms step_avg:92.35ms
step:1188/1660 train_time:109710ms step_avg:92.35ms
step:1189/1660 train_time:109803ms step_avg:92.35ms
step:1190/1660 train_time:109897ms step_avg:92.35ms
step:1191/1660 train_time:109991ms step_avg:92.35ms
step:1192/1660 train_time:110086ms step_avg:92.35ms
step:1193/1660 train_time:110181ms step_avg:92.36ms
step:1194/1660 train_time:110275ms step_avg:92.36ms
step:1195/1660 train_time:110367ms step_avg:92.36ms
step:1196/1660 train_time:110460ms step_avg:92.36ms
step:1197/1660 train_time:110553ms step_avg:92.36ms
step:1198/1660 train_time:110646ms step_avg:92.36ms
step:1199/1660 train_time:110739ms step_avg:92.36ms
step:1200/1660 train_time:110833ms step_avg:92.36ms
step:1201/1660 train_time:110926ms step_avg:92.36ms
step:1202/1660 train_time:111020ms step_avg:92.36ms
step:1203/1660 train_time:111114ms step_avg:92.36ms
step:1204/1660 train_time:111208ms step_avg:92.37ms
step:1205/1660 train_time:111303ms step_avg:92.37ms
step:1206/1660 train_time:111395ms step_avg:92.37ms
step:1207/1660 train_time:111489ms step_avg:92.37ms
step:1208/1660 train_time:111583ms step_avg:92.37ms
step:1209/1660 train_time:111676ms step_avg:92.37ms
step:1210/1660 train_time:111768ms step_avg:92.37ms
step:1211/1660 train_time:111861ms step_avg:92.37ms
step:1212/1660 train_time:111955ms step_avg:92.37ms
step:1213/1660 train_time:112049ms step_avg:92.37ms
step:1214/1660 train_time:112143ms step_avg:92.38ms
step:1215/1660 train_time:112237ms step_avg:92.38ms
step:1216/1660 train_time:112330ms step_avg:92.38ms
step:1217/1660 train_time:112425ms step_avg:92.38ms
step:1218/1660 train_time:112519ms step_avg:92.38ms
step:1219/1660 train_time:112612ms step_avg:92.38ms
step:1220/1660 train_time:112705ms step_avg:92.38ms
step:1221/1660 train_time:112798ms step_avg:92.38ms
step:1222/1660 train_time:112891ms step_avg:92.38ms
step:1223/1660 train_time:112984ms step_avg:92.38ms
step:1224/1660 train_time:113078ms step_avg:92.38ms
step:1225/1660 train_time:113171ms step_avg:92.38ms
step:1226/1660 train_time:113265ms step_avg:92.39ms
step:1227/1660 train_time:113359ms step_avg:92.39ms
step:1228/1660 train_time:113452ms step_avg:92.39ms
step:1229/1660 train_time:113545ms step_avg:92.39ms
step:1230/1660 train_time:113639ms step_avg:92.39ms
step:1231/1660 train_time:113731ms step_avg:92.39ms
step:1232/1660 train_time:113825ms step_avg:92.39ms
step:1233/1660 train_time:113918ms step_avg:92.39ms
step:1234/1660 train_time:114011ms step_avg:92.39ms
step:1235/1660 train_time:114104ms step_avg:92.39ms
step:1236/1660 train_time:114198ms step_avg:92.39ms
step:1237/1660 train_time:114292ms step_avg:92.39ms
step:1238/1660 train_time:114387ms step_avg:92.40ms
step:1239/1660 train_time:114481ms step_avg:92.40ms
step:1240/1660 train_time:114574ms step_avg:92.40ms
step:1241/1660 train_time:114668ms step_avg:92.40ms
step:1242/1660 train_time:114761ms step_avg:92.40ms
step:1243/1660 train_time:114855ms step_avg:92.40ms
step:1244/1660 train_time:114947ms step_avg:92.40ms
step:1245/1660 train_time:115040ms step_avg:92.40ms
step:1246/1660 train_time:115133ms step_avg:92.40ms
step:1247/1660 train_time:115226ms step_avg:92.40ms
step:1248/1660 train_time:115320ms step_avg:92.40ms
step:1249/1660 train_time:115413ms step_avg:92.40ms
step:1250/1660 train_time:115506ms step_avg:92.41ms
step:1250/1660 val_loss:3.3735 train_time:115601ms step_avg:92.48ms
step:1251/1660 train_time:115622ms step_avg:92.42ms
step:1252/1660 train_time:115701ms step_avg:92.41ms
step:1253/1660 train_time:115798ms step_avg:92.42ms
step:1254/1660 train_time:115892ms step_avg:92.42ms
step:1255/1660 train_time:115984ms step_avg:92.42ms
step:1256/1660 train_time:116076ms step_avg:92.42ms
step:1257/1660 train_time:116168ms step_avg:92.42ms
step:1258/1660 train_time:116259ms step_avg:92.42ms
step:1259/1660 train_time:116352ms step_avg:92.42ms
step:1260/1660 train_time:116444ms step_avg:92.42ms
step:1261/1660 train_time:116538ms step_avg:92.42ms
step:1262/1660 train_time:116635ms step_avg:92.42ms
step:1263/1660 train_time:116731ms step_avg:92.42ms
step:1264/1660 train_time:116826ms step_avg:92.43ms
step:1265/1660 train_time:116918ms step_avg:92.43ms
step:1266/1660 train_time:117012ms step_avg:92.43ms
step:1267/1660 train_time:117105ms step_avg:92.43ms
step:1268/1660 train_time:117197ms step_avg:92.43ms
step:1269/1660 train_time:117290ms step_avg:92.43ms
step:1270/1660 train_time:117382ms step_avg:92.43ms
step:1271/1660 train_time:117477ms step_avg:92.43ms
step:1272/1660 train_time:117570ms step_avg:92.43ms
step:1273/1660 train_time:117665ms step_avg:92.43ms
step:1274/1660 train_time:117761ms step_avg:92.43ms
step:1275/1660 train_time:117854ms step_avg:92.43ms
step:1276/1660 train_time:117948ms step_avg:92.44ms
step:1277/1660 train_time:118041ms step_avg:92.44ms
step:1278/1660 train_time:118134ms step_avg:92.44ms
step:1279/1660 train_time:118226ms step_avg:92.44ms
step:1280/1660 train_time:118319ms step_avg:92.44ms
step:1281/1660 train_time:118413ms step_avg:92.44ms
step:1282/1660 train_time:118507ms step_avg:92.44ms
step:1283/1660 train_time:118600ms step_avg:92.44ms
step:1284/1660 train_time:118697ms step_avg:92.44ms
step:1285/1660 train_time:118792ms step_avg:92.44ms
step:1286/1660 train_time:118886ms step_avg:92.45ms
step:1287/1660 train_time:118978ms step_avg:92.45ms
step:1288/1660 train_time:119072ms step_avg:92.45ms
step:1289/1660 train_time:119165ms step_avg:92.45ms
step:1290/1660 train_time:119258ms step_avg:92.45ms
step:1291/1660 train_time:119352ms step_avg:92.45ms
step:1292/1660 train_time:119445ms step_avg:92.45ms
step:1293/1660 train_time:119538ms step_avg:92.45ms
step:1294/1660 train_time:119632ms step_avg:92.45ms
step:1295/1660 train_time:119726ms step_avg:92.45ms
step:1296/1660 train_time:119821ms step_avg:92.45ms
step:1297/1660 train_time:119916ms step_avg:92.46ms
step:1298/1660 train_time:120010ms step_avg:92.46ms
step:1299/1660 train_time:120103ms step_avg:92.46ms
step:1300/1660 train_time:120196ms step_avg:92.46ms
step:1301/1660 train_time:120289ms step_avg:92.46ms
step:1302/1660 train_time:120381ms step_avg:92.46ms
step:1303/1660 train_time:120475ms step_avg:92.46ms
step:1304/1660 train_time:120569ms step_avg:92.46ms
step:1305/1660 train_time:120662ms step_avg:92.46ms
step:1306/1660 train_time:120756ms step_avg:92.46ms
step:1307/1660 train_time:120851ms step_avg:92.46ms
step:1308/1660 train_time:120946ms step_avg:92.47ms
step:1309/1660 train_time:121039ms step_avg:92.47ms
step:1310/1660 train_time:121131ms step_avg:92.47ms
step:1311/1660 train_time:121225ms step_avg:92.47ms
step:1312/1660 train_time:121317ms step_avg:92.47ms
step:1313/1660 train_time:121411ms step_avg:92.47ms
step:1314/1660 train_time:121504ms step_avg:92.47ms
step:1315/1660 train_time:121597ms step_avg:92.47ms
step:1316/1660 train_time:121691ms step_avg:92.47ms
step:1317/1660 train_time:121785ms step_avg:92.47ms
step:1318/1660 train_time:121879ms step_avg:92.47ms
step:1319/1660 train_time:121973ms step_avg:92.47ms
step:1320/1660 train_time:122067ms step_avg:92.47ms
step:1321/1660 train_time:122159ms step_avg:92.47ms
step:1322/1660 train_time:122253ms step_avg:92.48ms
step:1323/1660 train_time:122347ms step_avg:92.48ms
step:1324/1660 train_time:122440ms step_avg:92.48ms
step:1325/1660 train_time:122535ms step_avg:92.48ms
step:1326/1660 train_time:122630ms step_avg:92.48ms
step:1327/1660 train_time:122722ms step_avg:92.48ms
step:1328/1660 train_time:122818ms step_avg:92.48ms
step:1329/1660 train_time:122912ms step_avg:92.48ms
step:1330/1660 train_time:123006ms step_avg:92.49ms
step:1331/1660 train_time:123099ms step_avg:92.49ms
step:1332/1660 train_time:123193ms step_avg:92.49ms
step:1333/1660 train_time:123287ms step_avg:92.49ms
step:1334/1660 train_time:123380ms step_avg:92.49ms
step:1335/1660 train_time:123473ms step_avg:92.49ms
step:1336/1660 train_time:123567ms step_avg:92.49ms
step:1337/1660 train_time:123660ms step_avg:92.49ms
step:1338/1660 train_time:123754ms step_avg:92.49ms
step:1339/1660 train_time:123848ms step_avg:92.49ms
step:1340/1660 train_time:123941ms step_avg:92.49ms
step:1341/1660 train_time:124037ms step_avg:92.50ms
step:1342/1660 train_time:124129ms step_avg:92.50ms
step:1343/1660 train_time:124222ms step_avg:92.50ms
step:1344/1660 train_time:124315ms step_avg:92.50ms
step:1345/1660 train_time:124409ms step_avg:92.50ms
step:1346/1660 train_time:124502ms step_avg:92.50ms
step:1347/1660 train_time:124596ms step_avg:92.50ms
step:1348/1660 train_time:124690ms step_avg:92.50ms
step:1349/1660 train_time:124783ms step_avg:92.50ms
step:1350/1660 train_time:124878ms step_avg:92.50ms
step:1351/1660 train_time:124973ms step_avg:92.50ms
step:1352/1660 train_time:125068ms step_avg:92.51ms
step:1353/1660 train_time:125160ms step_avg:92.51ms
step:1354/1660 train_time:125254ms step_avg:92.51ms
step:1355/1660 train_time:125348ms step_avg:92.51ms
step:1356/1660 train_time:125441ms step_avg:92.51ms
step:1357/1660 train_time:125535ms step_avg:92.51ms
step:1358/1660 train_time:125628ms step_avg:92.51ms
step:1359/1660 train_time:125721ms step_avg:92.51ms
step:1360/1660 train_time:125815ms step_avg:92.51ms
step:1361/1660 train_time:125909ms step_avg:92.51ms
step:1362/1660 train_time:126002ms step_avg:92.51ms
step:1363/1660 train_time:126096ms step_avg:92.51ms
step:1364/1660 train_time:126190ms step_avg:92.51ms
step:1365/1660 train_time:126283ms step_avg:92.51ms
step:1366/1660 train_time:126377ms step_avg:92.52ms
step:1367/1660 train_time:126470ms step_avg:92.52ms
step:1368/1660 train_time:126564ms step_avg:92.52ms
step:1369/1660 train_time:126657ms step_avg:92.52ms
step:1370/1660 train_time:126749ms step_avg:92.52ms
step:1371/1660 train_time:126843ms step_avg:92.52ms
step:1372/1660 train_time:126937ms step_avg:92.52ms
step:1373/1660 train_time:127031ms step_avg:92.52ms
step:1374/1660 train_time:127125ms step_avg:92.52ms
step:1375/1660 train_time:127218ms step_avg:92.52ms
step:1375/1660 val_loss:3.3391 train_time:127314ms step_avg:92.59ms
step:1376/1660 train_time:127336ms step_avg:92.54ms
step:1377/1660 train_time:127412ms step_avg:92.53ms
step:1378/1660 train_time:127507ms step_avg:92.53ms
step:1379/1660 train_time:127601ms step_avg:92.53ms
step:1380/1660 train_time:127693ms step_avg:92.53ms
step:1381/1660 train_time:127787ms step_avg:92.53ms
step:1382/1660 train_time:127880ms step_avg:92.53ms
step:1383/1660 train_time:127972ms step_avg:92.53ms
step:1384/1660 train_time:128066ms step_avg:92.53ms
step:1385/1660 train_time:128158ms step_avg:92.53ms
step:1386/1660 train_time:128252ms step_avg:92.53ms
step:1387/1660 train_time:128350ms step_avg:92.54ms
step:1388/1660 train_time:128447ms step_avg:92.54ms
step:1389/1660 train_time:128542ms step_avg:92.54ms
step:1390/1660 train_time:128635ms step_avg:92.54ms
step:1391/1660 train_time:128727ms step_avg:92.54ms
step:1392/1660 train_time:128820ms step_avg:92.54ms
step:1393/1660 train_time:128912ms step_avg:92.54ms
step:1394/1660 train_time:129005ms step_avg:92.54ms
step:1395/1660 train_time:129099ms step_avg:92.54ms
step:1396/1660 train_time:129191ms step_avg:92.54ms
step:1397/1660 train_time:129286ms step_avg:92.55ms
step:1398/1660 train_time:129380ms step_avg:92.55ms
step:1399/1660 train_time:129474ms step_avg:92.55ms
step:1400/1660 train_time:129569ms step_avg:92.55ms
step:1401/1660 train_time:129662ms step_avg:92.55ms
step:1402/1660 train_time:129755ms step_avg:92.55ms
step:1403/1660 train_time:129848ms step_avg:92.55ms
step:1404/1660 train_time:129942ms step_avg:92.55ms
step:1405/1660 train_time:130035ms step_avg:92.55ms
step:1406/1660 train_time:130128ms step_avg:92.55ms
step:1407/1660 train_time:130221ms step_avg:92.55ms
step:1408/1660 train_time:130315ms step_avg:92.55ms
step:1409/1660 train_time:130409ms step_avg:92.55ms
step:1410/1660 train_time:130505ms step_avg:92.56ms
step:1411/1660 train_time:130598ms step_avg:92.56ms
step:1412/1660 train_time:130692ms step_avg:92.56ms
step:1413/1660 train_time:130785ms step_avg:92.56ms
step:1414/1660 train_time:130879ms step_avg:92.56ms
step:1415/1660 train_time:130971ms step_avg:92.56ms
step:1416/1660
train_time:131064ms step_avg:92.56ms +step:1417/1660 train_time:131157ms step_avg:92.56ms +step:1418/1660 train_time:131251ms step_avg:92.56ms +step:1419/1660 train_time:131347ms step_avg:92.56ms +step:1420/1660 train_time:131441ms step_avg:92.56ms +step:1421/1660 train_time:131535ms step_avg:92.56ms +step:1422/1660 train_time:131628ms step_avg:92.57ms +step:1423/1660 train_time:131722ms step_avg:92.57ms +step:1424/1660 train_time:131814ms step_avg:92.57ms +step:1425/1660 train_time:131907ms step_avg:92.57ms +step:1426/1660 train_time:132000ms step_avg:92.57ms +step:1427/1660 train_time:132094ms step_avg:92.57ms +step:1428/1660 train_time:132187ms step_avg:92.57ms +step:1429/1660 train_time:132281ms step_avg:92.57ms +step:1430/1660 train_time:132375ms step_avg:92.57ms +step:1431/1660 train_time:132468ms step_avg:92.57ms +step:1432/1660 train_time:132562ms step_avg:92.57ms +step:1433/1660 train_time:132655ms step_avg:92.57ms +step:1434/1660 train_time:132748ms step_avg:92.57ms +step:1435/1660 train_time:132843ms step_avg:92.57ms +step:1436/1660 train_time:132935ms step_avg:92.57ms +step:1437/1660 train_time:133029ms step_avg:92.57ms +step:1438/1660 train_time:133122ms step_avg:92.57ms +step:1439/1660 train_time:133216ms step_avg:92.58ms +step:1440/1660 train_time:133310ms step_avg:92.58ms +step:1441/1660 train_time:133404ms step_avg:92.58ms +step:1442/1660 train_time:133497ms step_avg:92.58ms +step:1443/1660 train_time:133591ms step_avg:92.58ms +step:1444/1660 train_time:133684ms step_avg:92.58ms +step:1445/1660 train_time:133779ms step_avg:92.58ms +step:1446/1660 train_time:133871ms step_avg:92.58ms +step:1447/1660 train_time:133964ms step_avg:92.58ms +step:1448/1660 train_time:134058ms step_avg:92.58ms +step:1449/1660 train_time:134151ms step_avg:92.58ms +step:1450/1660 train_time:134246ms step_avg:92.58ms +step:1451/1660 train_time:134341ms step_avg:92.58ms +step:1452/1660 train_time:134434ms step_avg:92.59ms +step:1453/1660 train_time:134528ms step_avg:92.59ms +step:1454/1660 train_time:134621ms step_avg:92.59ms +step:1455/1660 train_time:134715ms step_avg:92.59ms +step:1456/1660 train_time:134808ms step_avg:92.59ms +step:1457/1660 train_time:134901ms step_avg:92.59ms +step:1458/1660 train_time:134995ms step_avg:92.59ms +step:1459/1660 train_time:135088ms step_avg:92.59ms +step:1460/1660 train_time:135181ms step_avg:92.59ms +step:1461/1660 train_time:135274ms step_avg:92.59ms +step:1462/1660 train_time:135368ms step_avg:92.59ms +step:1463/1660 train_time:135461ms step_avg:92.59ms +step:1464/1660 train_time:135555ms step_avg:92.59ms +step:1465/1660 train_time:135648ms step_avg:92.59ms +step:1466/1660 train_time:135742ms step_avg:92.59ms +step:1467/1660 train_time:135836ms step_avg:92.59ms +step:1468/1660 train_time:135930ms step_avg:92.60ms +step:1469/1660 train_time:136023ms step_avg:92.60ms +step:1470/1660 train_time:136116ms step_avg:92.60ms +step:1471/1660 train_time:136209ms step_avg:92.60ms +step:1472/1660 train_time:136304ms step_avg:92.60ms +step:1473/1660 train_time:136398ms step_avg:92.60ms +step:1474/1660 train_time:136491ms step_avg:92.60ms +step:1475/1660 train_time:136584ms step_avg:92.60ms +step:1476/1660 train_time:136679ms step_avg:92.60ms +step:1477/1660 train_time:136772ms step_avg:92.60ms +step:1478/1660 train_time:136865ms step_avg:92.60ms +step:1479/1660 train_time:136958ms step_avg:92.60ms +step:1480/1660 train_time:137051ms step_avg:92.60ms +step:1481/1660 train_time:137146ms step_avg:92.60ms +step:1482/1660 train_time:137239ms step_avg:92.60ms +step:1483/1660 
train_time:137332ms step_avg:92.60ms +step:1484/1660 train_time:137426ms step_avg:92.61ms +step:1485/1660 train_time:137520ms step_avg:92.61ms +step:1486/1660 train_time:137613ms step_avg:92.61ms +step:1487/1660 train_time:137707ms step_avg:92.61ms +step:1488/1660 train_time:137800ms step_avg:92.61ms +step:1489/1660 train_time:137892ms step_avg:92.61ms +step:1490/1660 train_time:137986ms step_avg:92.61ms +step:1491/1660 train_time:138080ms step_avg:92.61ms +step:1492/1660 train_time:138174ms step_avg:92.61ms +step:1493/1660 train_time:138268ms step_avg:92.61ms +step:1494/1660 train_time:138361ms step_avg:92.61ms +step:1495/1660 train_time:138454ms step_avg:92.61ms +step:1496/1660 train_time:138548ms step_avg:92.61ms +step:1497/1660 train_time:138642ms step_avg:92.61ms +step:1498/1660 train_time:138736ms step_avg:92.61ms +step:1499/1660 train_time:138829ms step_avg:92.61ms +step:1500/1660 train_time:138923ms step_avg:92.62ms +step:1500/1660 val_loss:3.3095 train_time:139017ms step_avg:92.68ms +step:1501/1660 train_time:139039ms step_avg:92.63ms +step:1502/1660 train_time:139119ms step_avg:92.62ms +step:1503/1660 train_time:139217ms step_avg:92.63ms +step:1504/1660 train_time:139310ms step_avg:92.63ms +step:1505/1660 train_time:139402ms step_avg:92.63ms +step:1506/1660 train_time:139494ms step_avg:92.63ms +step:1507/1660 train_time:139586ms step_avg:92.62ms +step:1508/1660 train_time:139679ms step_avg:92.63ms +step:1509/1660 train_time:139772ms step_avg:92.63ms +step:1510/1660 train_time:139864ms step_avg:92.63ms +step:1511/1660 train_time:139960ms step_avg:92.63ms +step:1512/1660 train_time:140057ms step_avg:92.63ms +step:1513/1660 train_time:140152ms step_avg:92.63ms +step:1514/1660 train_time:140246ms step_avg:92.63ms +step:1515/1660 train_time:140340ms step_avg:92.63ms +step:1516/1660 train_time:140433ms step_avg:92.63ms +step:1517/1660 train_time:140525ms step_avg:92.63ms +step:1518/1660 train_time:140618ms step_avg:92.63ms +step:1519/1660 train_time:140710ms step_avg:92.63ms +step:1520/1660 train_time:140803ms step_avg:92.63ms +step:1521/1660 train_time:140897ms step_avg:92.63ms +step:1522/1660 train_time:140990ms step_avg:92.63ms +step:1523/1660 train_time:141085ms step_avg:92.64ms +step:1524/1660 train_time:141179ms step_avg:92.64ms +step:1525/1660 train_time:141273ms step_avg:92.64ms +step:1526/1660 train_time:141367ms step_avg:92.64ms +step:1527/1660 train_time:141460ms step_avg:92.64ms +step:1528/1660 train_time:141553ms step_avg:92.64ms +step:1529/1660 train_time:141645ms step_avg:92.64ms +step:1530/1660 train_time:141738ms step_avg:92.64ms +step:1531/1660 train_time:141830ms step_avg:92.64ms +step:1532/1660 train_time:141924ms step_avg:92.64ms +step:1533/1660 train_time:142018ms step_avg:92.64ms +step:1534/1660 train_time:142112ms step_avg:92.64ms +step:1535/1660 train_time:142206ms step_avg:92.64ms +step:1536/1660 train_time:142300ms step_avg:92.64ms +step:1537/1660 train_time:142395ms step_avg:92.64ms +step:1538/1660 train_time:142487ms step_avg:92.64ms +step:1539/1660 train_time:142580ms step_avg:92.64ms +step:1540/1660 train_time:142674ms step_avg:92.65ms +step:1541/1660 train_time:142766ms step_avg:92.65ms +step:1542/1660 train_time:142860ms step_avg:92.65ms +step:1543/1660 train_time:142954ms step_avg:92.65ms +step:1544/1660 train_time:143049ms step_avg:92.65ms +step:1545/1660 train_time:143142ms step_avg:92.65ms +step:1546/1660 train_time:143236ms step_avg:92.65ms +step:1547/1660 train_time:143329ms step_avg:92.65ms +step:1548/1660 train_time:143422ms step_avg:92.65ms 
+step:1549/1660 train_time:143515ms step_avg:92.65ms +step:1550/1660 train_time:143608ms step_avg:92.65ms +step:1551/1660 train_time:143701ms step_avg:92.65ms +step:1552/1660 train_time:143794ms step_avg:92.65ms +step:1553/1660 train_time:143888ms step_avg:92.65ms +step:1554/1660 train_time:143981ms step_avg:92.65ms +step:1555/1660 train_time:144077ms step_avg:92.65ms +step:1556/1660 train_time:144173ms step_avg:92.66ms +step:1557/1660 train_time:144266ms step_avg:92.66ms +step:1558/1660 train_time:144360ms step_avg:92.66ms +step:1559/1660 train_time:144454ms step_avg:92.66ms +step:1560/1660 train_time:144547ms step_avg:92.66ms +step:1561/1660 train_time:144641ms step_avg:92.66ms +step:1562/1660 train_time:144733ms step_avg:92.66ms +step:1563/1660 train_time:144826ms step_avg:92.66ms +step:1564/1660 train_time:144919ms step_avg:92.66ms +step:1565/1660 train_time:145013ms step_avg:92.66ms +step:1566/1660 train_time:145107ms step_avg:92.66ms +step:1567/1660 train_time:145201ms step_avg:92.66ms +step:1568/1660 train_time:145295ms step_avg:92.66ms +step:1569/1660 train_time:145388ms step_avg:92.66ms +step:1570/1660 train_time:145481ms step_avg:92.66ms +step:1571/1660 train_time:145574ms step_avg:92.66ms +step:1572/1660 train_time:145668ms step_avg:92.66ms +step:1573/1660 train_time:145761ms step_avg:92.66ms +step:1574/1660 train_time:145853ms step_avg:92.66ms +step:1575/1660 train_time:145946ms step_avg:92.66ms +step:1576/1660 train_time:146040ms step_avg:92.66ms +step:1577/1660 train_time:146133ms step_avg:92.67ms +step:1578/1660 train_time:146227ms step_avg:92.67ms +step:1579/1660 train_time:146321ms step_avg:92.67ms +step:1580/1660 train_time:146415ms step_avg:92.67ms +step:1581/1660 train_time:146509ms step_avg:92.67ms +step:1582/1660 train_time:146602ms step_avg:92.67ms +step:1583/1660 train_time:146695ms step_avg:92.67ms +step:1584/1660 train_time:146788ms step_avg:92.67ms +step:1585/1660 train_time:146881ms step_avg:92.67ms +step:1586/1660 train_time:146976ms step_avg:92.67ms +step:1587/1660 train_time:147069ms step_avg:92.67ms +step:1588/1660 train_time:147163ms step_avg:92.67ms +step:1589/1660 train_time:147258ms step_avg:92.67ms +step:1590/1660 train_time:147352ms step_avg:92.67ms +step:1591/1660 train_time:147445ms step_avg:92.67ms +step:1592/1660 train_time:147539ms step_avg:92.68ms +step:1593/1660 train_time:147633ms step_avg:92.68ms +step:1594/1660 train_time:147726ms step_avg:92.68ms +step:1595/1660 train_time:147819ms step_avg:92.68ms +step:1596/1660 train_time:147912ms step_avg:92.68ms +step:1597/1660 train_time:148006ms step_avg:92.68ms +step:1598/1660 train_time:148100ms step_avg:92.68ms +step:1599/1660 train_time:148193ms step_avg:92.68ms +step:1600/1660 train_time:148287ms step_avg:92.68ms +step:1601/1660 train_time:148382ms step_avg:92.68ms +step:1602/1660 train_time:148475ms step_avg:92.68ms +step:1603/1660 train_time:148569ms step_avg:92.68ms +step:1604/1660 train_time:148663ms step_avg:92.68ms +step:1605/1660 train_time:148756ms step_avg:92.68ms +step:1606/1660 train_time:148849ms step_avg:92.68ms +step:1607/1660 train_time:148943ms step_avg:92.68ms +step:1608/1660 train_time:149036ms step_avg:92.68ms +step:1609/1660 train_time:149129ms step_avg:92.68ms +step:1610/1660 train_time:149223ms step_avg:92.69ms +step:1611/1660 train_time:149317ms step_avg:92.69ms +step:1612/1660 train_time:149411ms step_avg:92.69ms +step:1613/1660 train_time:149505ms step_avg:92.69ms +step:1614/1660 train_time:149600ms step_avg:92.69ms +step:1615/1660 train_time:149693ms step_avg:92.69ms 
+step:1616/1660 train_time:149786ms step_avg:92.69ms +step:1617/1660 train_time:149881ms step_avg:92.69ms +step:1618/1660 train_time:149976ms step_avg:92.69ms +step:1619/1660 train_time:150069ms step_avg:92.69ms +step:1620/1660 train_time:150162ms step_avg:92.69ms +step:1621/1660 train_time:150256ms step_avg:92.69ms +step:1622/1660 train_time:150351ms step_avg:92.69ms +step:1623/1660 train_time:150444ms step_avg:92.69ms +step:1624/1660 train_time:150538ms step_avg:92.70ms +step:1625/1660 train_time:150632ms step_avg:92.70ms +step:1625/1660 val_loss:3.2848 train_time:150725ms step_avg:92.75ms +step:1626/1660 train_time:150747ms step_avg:92.71ms +step:1627/1660 train_time:150822ms step_avg:92.70ms +step:1628/1660 train_time:150921ms step_avg:92.70ms +step:1629/1660 train_time:151014ms step_avg:92.70ms +step:1630/1660 train_time:151107ms step_avg:92.70ms +step:1631/1660 train_time:151200ms step_avg:92.70ms +step:1632/1660 train_time:151293ms step_avg:92.70ms +step:1633/1660 train_time:151385ms step_avg:92.70ms +step:1634/1660 train_time:151478ms step_avg:92.70ms +step:1635/1660 train_time:151571ms step_avg:92.70ms +step:1636/1660 train_time:151665ms step_avg:92.70ms +step:1637/1660 train_time:151761ms step_avg:92.71ms +step:1638/1660 train_time:151856ms step_avg:92.71ms +step:1639/1660 train_time:151951ms step_avg:92.71ms +step:1640/1660 train_time:152045ms step_avg:92.71ms +step:1641/1660 train_time:152139ms step_avg:92.71ms +step:1642/1660 train_time:152232ms step_avg:92.71ms +step:1643/1660 train_time:152325ms step_avg:92.71ms +step:1644/1660 train_time:152417ms step_avg:92.71ms +step:1645/1660 train_time:152510ms step_avg:92.71ms +step:1646/1660 train_time:152605ms step_avg:92.71ms +step:1647/1660 train_time:152698ms step_avg:92.71ms +step:1648/1660 train_time:152793ms step_avg:92.71ms +step:1649/1660 train_time:152889ms step_avg:92.72ms +step:1650/1660 train_time:152984ms step_avg:92.72ms +step:1651/1660 train_time:153076ms step_avg:92.72ms +step:1652/1660 train_time:153170ms step_avg:92.72ms +step:1653/1660 train_time:153264ms step_avg:92.72ms +step:1654/1660 train_time:153356ms step_avg:92.72ms +step:1655/1660 train_time:153449ms step_avg:92.72ms +step:1656/1660 train_time:153542ms step_avg:92.72ms +step:1657/1660 train_time:153635ms step_avg:92.72ms +step:1658/1660 train_time:153729ms step_avg:92.72ms +step:1659/1660 train_time:153823ms step_avg:92.72ms +step:1660/1660 train_time:153918ms step_avg:92.72ms +step:1660/1660 val_loss:3.2769 train_time:154015ms step_avg:92.78ms +peak memory allocated: 32002 MiB reserved: 46316 MiB diff --git a/records/091525_ThreadingFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt b/records/091525_ThreadingFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt new file mode 100644 index 000000000..90683742d --- /dev/null +++ b/records/091525_ThreadingFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True 
# we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + 
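+    # One program per (batch, M-block, N-block). Because C = A @ A.T is symmetric,
+    # programs whose tile lies entirely on the skipped side of the diagonal exit early,
+    # and each surviving program writes both its tile and the mirrored tile across the
+    # diagonal, so only about half of the output tiles are actually computed.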
ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) 
== M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
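+        # Three asynchronous phases per parameter group:
+        #   1) reduce-scatter the stacked gradients, so each rank averages and owns a
+        #      contiguous chunk of the group's parameters;
+        #   2) apply momentum + batched Newton-Schulz orthogonalization to that local
+        #      chunk only (1/world_size of the work per rank);
+        #   3) all-gather the updated chunks so every rank ends with identical weights.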
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (num_params + world_size - 1) // world_size * world_size
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend([padding_grad] * (padded_num_params - num_params))
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group; all params in a group share
+            # the same shape, so these are constant across the group and help vectorize
+            # the operations below.
+            p_example = params[0]
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of updated_param_chunk with
+                    # zeros. These padded entries are not used by other ranks in the
+                    # all_gather, but initializing them prevents uninitialized memory access.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor so the zeropower input stays padded.
+                    update_grads_for_zeropower.append(torch.zeros_like(params[0].grad))
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # this gradient corresponds to the current parameter p
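+                # The next few lines implement Nesterov-style momentum: the buffer is
+                # lerped toward the raw gradient, and the tensor handed to Newton-Schulz
+                # is grad.lerp(momentum_buffer, momentum), i.e. the gradient blended with
+                # the freshly updated buffer.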
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
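+        # torch.unbind yields per-parameter views of the gathered stack (no extra copy);
+        # the non_blocking copies below let the copy-back overlap with later gather waits.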
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + #t0 = time.perf_counter() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + #t1 = time.perf_counter() + #print(f'{t1-t0} slowload') + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + def __init__(self, file_iter, world_size): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences 
truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + #loading in a whole shard will be slow... + #BosFinder tracks its self.i index. I can kickoff with a smaller set. then have the + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + #tokens = _load_data_shard(next(file_iter)) + #finder = BOSFinder(tokens, world_size=world_size) + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1660 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"data_threading_1660/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer is highly indifferent to length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + 
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:25:29 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 28C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 30C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 28C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 27C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 30C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 28C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 185791 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 185792 C /usr/bin/python3 614MiB | +| 0 N/A N/A 185793 C /usr/bin/python3 614MiB | +| 0 N/A N/A 185794 C /usr/bin/python3 614MiB | +| 0 N/A N/A 185795 C /usr/bin/python3 614MiB | +| 0 N/A N/A 185796 C /usr/bin/python3 614MiB | +| 0 N/A N/A 185797 C /usr/bin/python3 614MiB | +| 0 N/A N/A 185798 C /usr/bin/python3 614MiB | +| 1 N/A N/A 185792 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 185793 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 185794 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 185795 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 185796 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 185797 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 185798 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:146ms step_avg:145.92ms +step:2/1660 train_time:167ms step_avg:83.66ms +step:3/1660 train_time:233ms step_avg:77.78ms +step:4/1660 train_time:322ms step_avg:80.61ms +step:5/1660 train_time:413ms step_avg:82.54ms +step:6/1660 train_time:503ms step_avg:83.81ms +step:7/1660 train_time:593ms step_avg:84.73ms +step:8/1660 train_time:684ms step_avg:85.51ms +step:9/1660 train_time:775ms step_avg:86.13ms +step:10/1660 train_time:867ms step_avg:86.66ms +step:11/1660 train_time:957ms step_avg:86.99ms +step:12/1660 train_time:1049ms step_avg:87.43ms +step:13/1660 train_time:1144ms step_avg:88.01ms +step:14/1660 train_time:1237ms step_avg:88.38ms +step:15/1660 train_time:1329ms step_avg:88.62ms +step:16/1660 train_time:1421ms step_avg:88.80ms +step:17/1660 train_time:1512ms step_avg:88.92ms +step:18/1660 train_time:1602ms step_avg:89.02ms +step:19/1660 train_time:1694ms step_avg:89.15ms +step:20/1660 train_time:1785ms step_avg:89.23ms +step:21/1660 train_time:1876ms step_avg:89.33ms +step:22/1660 train_time:1968ms step_avg:89.44ms +step:23/1660 train_time:2061ms step_avg:89.60ms +step:24/1660 train_time:2154ms step_avg:89.73ms +step:25/1660 train_time:2246ms step_avg:89.83ms +step:26/1660 train_time:2338ms step_avg:89.91ms +step:27/1660 train_time:2429ms step_avg:89.98ms +step:28/1660 train_time:2521ms step_avg:90.04ms +step:29/1660 train_time:2612ms step_avg:90.07ms +step:30/1660 train_time:2703ms step_avg:90.11ms +step:31/1660 train_time:2794ms step_avg:90.14ms +step:32/1660 train_time:2886ms step_avg:90.20ms +step:33/1660 train_time:2979ms step_avg:90.28ms +step:34/1660 train_time:3071ms step_avg:90.32ms +step:35/1660 train_time:3163ms step_avg:90.38ms +step:36/1660 train_time:3255ms step_avg:90.42ms +step:37/1660 train_time:3347ms step_avg:90.46ms +step:38/1660 train_time:3439ms step_avg:90.49ms +step:39/1660 train_time:3530ms step_avg:90.51ms +step:40/1660 train_time:3622ms step_avg:90.54ms +step:41/1660 train_time:3713ms step_avg:90.56ms +step:42/1660 train_time:3804ms step_avg:90.58ms +step:43/1660 train_time:3896ms step_avg:90.60ms +step:44/1660 train_time:3988ms step_avg:90.64ms +step:45/1660 train_time:4081ms step_avg:90.68ms +step:46/1660 train_time:4172ms step_avg:90.70ms +step:47/1660 train_time:4265ms step_avg:90.74ms +step:48/1660 train_time:4356ms step_avg:90.75ms +step:49/1660 train_time:4448ms step_avg:90.77ms +step:50/1660 train_time:4539ms step_avg:90.78ms 
+step:51/1660 train_time:4630ms step_avg:90.79ms +step:52/1660 train_time:4721ms step_avg:90.79ms +step:53/1660 train_time:4812ms step_avg:90.79ms +step:54/1660 train_time:4905ms step_avg:90.83ms +step:55/1660 train_time:4996ms step_avg:90.84ms +step:56/1660 train_time:5089ms step_avg:90.87ms +step:57/1660 train_time:5180ms step_avg:90.88ms +step:58/1660 train_time:5272ms step_avg:90.90ms +step:59/1660 train_time:5364ms step_avg:90.91ms +step:60/1660 train_time:5455ms step_avg:90.92ms +step:61/1660 train_time:5546ms step_avg:90.92ms +step:62/1660 train_time:5638ms step_avg:90.93ms +step:63/1660 train_time:5729ms step_avg:90.93ms +step:64/1660 train_time:5820ms step_avg:90.94ms +step:65/1660 train_time:5911ms step_avg:90.95ms +step:66/1660 train_time:6005ms step_avg:90.99ms +step:67/1660 train_time:6098ms step_avg:91.01ms +step:68/1660 train_time:6190ms step_avg:91.03ms +step:69/1660 train_time:6282ms step_avg:91.04ms +step:70/1660 train_time:6373ms step_avg:91.05ms +step:71/1660 train_time:6465ms step_avg:91.05ms +step:72/1660 train_time:6556ms step_avg:91.05ms +step:73/1660 train_time:6647ms step_avg:91.06ms +step:74/1660 train_time:6738ms step_avg:91.06ms +step:75/1660 train_time:6829ms step_avg:91.06ms +step:76/1660 train_time:6921ms step_avg:91.06ms +step:77/1660 train_time:7012ms step_avg:91.07ms +step:78/1660 train_time:7105ms step_avg:91.09ms +step:79/1660 train_time:7198ms step_avg:91.11ms +step:80/1660 train_time:7290ms step_avg:91.12ms +step:81/1660 train_time:7381ms step_avg:91.13ms +step:82/1660 train_time:7473ms step_avg:91.13ms +step:83/1660 train_time:7564ms step_avg:91.14ms +step:84/1660 train_time:7656ms step_avg:91.14ms +step:85/1660 train_time:7747ms step_avg:91.14ms +step:86/1660 train_time:7838ms step_avg:91.14ms +step:87/1660 train_time:7930ms step_avg:91.15ms +step:88/1660 train_time:8022ms step_avg:91.16ms +step:89/1660 train_time:8115ms step_avg:91.18ms +step:90/1660 train_time:8207ms step_avg:91.18ms +step:91/1660 train_time:8298ms step_avg:91.19ms +step:92/1660 train_time:8390ms step_avg:91.19ms +step:93/1660 train_time:8482ms step_avg:91.20ms +step:94/1660 train_time:8573ms step_avg:91.20ms +step:95/1660 train_time:8665ms step_avg:91.21ms +step:96/1660 train_time:8755ms step_avg:91.20ms +step:97/1660 train_time:8846ms step_avg:91.20ms +step:98/1660 train_time:8938ms step_avg:91.20ms +step:99/1660 train_time:9029ms step_avg:91.20ms +step:100/1660 train_time:9121ms step_avg:91.21ms +step:101/1660 train_time:9212ms step_avg:91.21ms +step:102/1660 train_time:9304ms step_avg:91.22ms +step:103/1660 train_time:9396ms step_avg:91.22ms +step:104/1660 train_time:9488ms step_avg:91.23ms +step:105/1660 train_time:9581ms step_avg:91.25ms +step:106/1660 train_time:9673ms step_avg:91.25ms +step:107/1660 train_time:9764ms step_avg:91.25ms +step:108/1660 train_time:9855ms step_avg:91.25ms +step:109/1660 train_time:9946ms step_avg:91.25ms +step:110/1660 train_time:10037ms step_avg:91.24ms +step:111/1660 train_time:10128ms step_avg:91.24ms +step:112/1660 train_time:10219ms step_avg:91.24ms +step:113/1660 train_time:10310ms step_avg:91.24ms +step:114/1660 train_time:10402ms step_avg:91.24ms +step:115/1660 train_time:10493ms step_avg:91.24ms +step:116/1660 train_time:10586ms step_avg:91.26ms +step:117/1660 train_time:10677ms step_avg:91.26ms +step:118/1660 train_time:10769ms step_avg:91.26ms +step:119/1660 train_time:10860ms step_avg:91.26ms +step:120/1660 train_time:10951ms step_avg:91.26ms +step:121/1660 train_time:11042ms step_avg:91.26ms +step:122/1660 train_time:11133ms 
step_avg:91.26ms +step:123/1660 train_time:11225ms step_avg:91.26ms +step:124/1660 train_time:11316ms step_avg:91.26ms +step:125/1660 train_time:11408ms step_avg:91.26ms +step:125/1660 val_loss:4.3111 train_time:11501ms step_avg:92.01ms +step:126/1660 train_time:11524ms step_avg:91.46ms +step:127/1660 train_time:11595ms step_avg:91.30ms +step:128/1660 train_time:11697ms step_avg:91.38ms +step:129/1660 train_time:11793ms step_avg:91.42ms +step:130/1660 train_time:11884ms step_avg:91.42ms +step:131/1660 train_time:11975ms step_avg:91.41ms +step:132/1660 train_time:12065ms step_avg:91.40ms +step:133/1660 train_time:12156ms step_avg:91.40ms +step:134/1660 train_time:12245ms step_avg:91.38ms +step:135/1660 train_time:12336ms step_avg:91.38ms +step:136/1660 train_time:12426ms step_avg:91.37ms +step:137/1660 train_time:12517ms step_avg:91.36ms +step:138/1660 train_time:12610ms step_avg:91.38ms +step:139/1660 train_time:12705ms step_avg:91.40ms +step:140/1660 train_time:12798ms step_avg:91.41ms +step:141/1660 train_time:12890ms step_avg:91.42ms +step:142/1660 train_time:12981ms step_avg:91.42ms +step:143/1660 train_time:13072ms step_avg:91.41ms +step:144/1660 train_time:13163ms step_avg:91.41ms +step:145/1660 train_time:13254ms step_avg:91.40ms +step:146/1660 train_time:13344ms step_avg:91.40ms +step:147/1660 train_time:13435ms step_avg:91.39ms +step:148/1660 train_time:13526ms step_avg:91.40ms +step:149/1660 train_time:13619ms step_avg:91.40ms +step:150/1660 train_time:13712ms step_avg:91.41ms +step:151/1660 train_time:13804ms step_avg:91.41ms +step:152/1660 train_time:13896ms step_avg:91.42ms +step:153/1660 train_time:13987ms step_avg:91.42ms +step:154/1660 train_time:14079ms step_avg:91.42ms +step:155/1660 train_time:14170ms step_avg:91.42ms +step:156/1660 train_time:14261ms step_avg:91.42ms +step:157/1660 train_time:14352ms step_avg:91.41ms +step:158/1660 train_time:14442ms step_avg:91.40ms +step:159/1660 train_time:14533ms step_avg:91.40ms +step:160/1660 train_time:14625ms step_avg:91.40ms +step:161/1660 train_time:14717ms step_avg:91.41ms +step:162/1660 train_time:14809ms step_avg:91.42ms +step:163/1660 train_time:14901ms step_avg:91.42ms +step:164/1660 train_time:14992ms step_avg:91.42ms +step:165/1660 train_time:15084ms step_avg:91.42ms +step:166/1660 train_time:15174ms step_avg:91.41ms +step:167/1660 train_time:15265ms step_avg:91.41ms +step:168/1660 train_time:15357ms step_avg:91.41ms +step:169/1660 train_time:15449ms step_avg:91.42ms +step:170/1660 train_time:15541ms step_avg:91.42ms +step:171/1660 train_time:15632ms step_avg:91.42ms +step:172/1660 train_time:15723ms step_avg:91.41ms +step:173/1660 train_time:15815ms step_avg:91.41ms +step:174/1660 train_time:15906ms step_avg:91.41ms +step:175/1660 train_time:15997ms step_avg:91.41ms +step:176/1660 train_time:16088ms step_avg:91.41ms +step:177/1660 train_time:16180ms step_avg:91.41ms +step:178/1660 train_time:16270ms step_avg:91.41ms +step:179/1660 train_time:16362ms step_avg:91.41ms +step:180/1660 train_time:16454ms step_avg:91.41ms +step:181/1660 train_time:16545ms step_avg:91.41ms +step:182/1660 train_time:16637ms step_avg:91.41ms +step:183/1660 train_time:16730ms step_avg:91.42ms +step:184/1660 train_time:16822ms step_avg:91.42ms +step:185/1660 train_time:16913ms step_avg:91.42ms +step:186/1660 train_time:17004ms step_avg:91.42ms +step:187/1660 train_time:17095ms step_avg:91.42ms +step:188/1660 train_time:17186ms step_avg:91.41ms +step:189/1660 train_time:17277ms step_avg:91.41ms +step:190/1660 train_time:17368ms step_avg:91.41ms 
+step:191/1660 train_time:17460ms step_avg:91.41ms +step:192/1660 train_time:17551ms step_avg:91.41ms +step:193/1660 train_time:17642ms step_avg:91.41ms +step:194/1660 train_time:17734ms step_avg:91.41ms +step:195/1660 train_time:17825ms step_avg:91.41ms +step:196/1660 train_time:17918ms step_avg:91.42ms +step:197/1660 train_time:18010ms step_avg:91.42ms +step:198/1660 train_time:18101ms step_avg:91.42ms +step:199/1660 train_time:18193ms step_avg:91.42ms +step:200/1660 train_time:18283ms step_avg:91.42ms +step:201/1660 train_time:18374ms step_avg:91.41ms +step:202/1660 train_time:18464ms step_avg:91.41ms +step:203/1660 train_time:18556ms step_avg:91.41ms +step:204/1660 train_time:18649ms step_avg:91.42ms +step:205/1660 train_time:18741ms step_avg:91.42ms +step:206/1660 train_time:18833ms step_avg:91.42ms +step:207/1660 train_time:18924ms step_avg:91.42ms +step:208/1660 train_time:19015ms step_avg:91.42ms +step:209/1660 train_time:19107ms step_avg:91.42ms +step:210/1660 train_time:19198ms step_avg:91.42ms +step:211/1660 train_time:19290ms step_avg:91.42ms +step:212/1660 train_time:19382ms step_avg:91.42ms +step:213/1660 train_time:19473ms step_avg:91.42ms +step:214/1660 train_time:19563ms step_avg:91.42ms +step:215/1660 train_time:19655ms step_avg:91.42ms +step:216/1660 train_time:19747ms step_avg:91.42ms +step:217/1660 train_time:19840ms step_avg:91.43ms +step:218/1660 train_time:19931ms step_avg:91.43ms +step:219/1660 train_time:20022ms step_avg:91.43ms +step:220/1660 train_time:20113ms step_avg:91.42ms +step:221/1660 train_time:20204ms step_avg:91.42ms +step:222/1660 train_time:20296ms step_avg:91.42ms +step:223/1660 train_time:20387ms step_avg:91.42ms +step:224/1660 train_time:20478ms step_avg:91.42ms +step:225/1660 train_time:20569ms step_avg:91.42ms +step:226/1660 train_time:20661ms step_avg:91.42ms +step:227/1660 train_time:20752ms step_avg:91.42ms +step:228/1660 train_time:20843ms step_avg:91.42ms +step:229/1660 train_time:20935ms step_avg:91.42ms +step:230/1660 train_time:21026ms step_avg:91.42ms +step:231/1660 train_time:21118ms step_avg:91.42ms +step:232/1660 train_time:21209ms step_avg:91.42ms +step:233/1660 train_time:21300ms step_avg:91.42ms +step:234/1660 train_time:21392ms step_avg:91.42ms +step:235/1660 train_time:21483ms step_avg:91.42ms +step:236/1660 train_time:21574ms step_avg:91.41ms +step:237/1660 train_time:21665ms step_avg:91.41ms +step:238/1660 train_time:21758ms step_avg:91.42ms +step:239/1660 train_time:21849ms step_avg:91.42ms +step:240/1660 train_time:21941ms step_avg:91.42ms +step:241/1660 train_time:22032ms step_avg:91.42ms +step:242/1660 train_time:22124ms step_avg:91.42ms +step:243/1660 train_time:22215ms step_avg:91.42ms +step:244/1660 train_time:22306ms step_avg:91.42ms +step:245/1660 train_time:22398ms step_avg:91.42ms +step:246/1660 train_time:22489ms step_avg:91.42ms +step:247/1660 train_time:22580ms step_avg:91.42ms +step:248/1660 train_time:22671ms step_avg:91.41ms +step:249/1660 train_time:22762ms step_avg:91.42ms +step:250/1660 train_time:22854ms step_avg:91.42ms +step:250/1660 val_loss:3.9721 train_time:22948ms step_avg:91.79ms +step:251/1660 train_time:22968ms step_avg:91.51ms +step:252/1660 train_time:23043ms step_avg:91.44ms +step:253/1660 train_time:23140ms step_avg:91.46ms +step:254/1660 train_time:23233ms step_avg:91.47ms +step:255/1660 train_time:23324ms step_avg:91.47ms +step:256/1660 train_time:23416ms step_avg:91.47ms +step:257/1660 train_time:23507ms step_avg:91.47ms +step:258/1660 train_time:23597ms step_avg:91.46ms +step:259/1660 
train_time:23688ms step_avg:91.46ms +step:260/1660 train_time:23778ms step_avg:91.45ms +step:261/1660 train_time:23869ms step_avg:91.45ms +step:262/1660 train_time:23962ms step_avg:91.46ms +step:263/1660 train_time:24056ms step_avg:91.47ms +step:264/1660 train_time:24149ms step_avg:91.47ms +step:265/1660 train_time:24241ms step_avg:91.48ms +step:266/1660 train_time:24333ms step_avg:91.48ms +step:267/1660 train_time:24425ms step_avg:91.48ms +step:268/1660 train_time:24516ms step_avg:91.48ms +step:269/1660 train_time:24606ms step_avg:91.47ms +step:270/1660 train_time:24696ms step_avg:91.47ms +step:271/1660 train_time:24787ms step_avg:91.46ms +step:272/1660 train_time:24877ms step_avg:91.46ms +step:273/1660 train_time:24969ms step_avg:91.46ms +step:274/1660 train_time:25063ms step_avg:91.47ms +step:275/1660 train_time:25155ms step_avg:91.47ms +step:276/1660 train_time:25247ms step_avg:91.48ms +step:277/1660 train_time:25338ms step_avg:91.47ms +step:278/1660 train_time:25429ms step_avg:91.47ms +step:279/1660 train_time:25520ms step_avg:91.47ms +step:280/1660 train_time:25611ms step_avg:91.47ms +step:281/1660 train_time:25702ms step_avg:91.47ms +step:282/1660 train_time:25793ms step_avg:91.47ms +step:283/1660 train_time:25885ms step_avg:91.46ms +step:284/1660 train_time:25976ms step_avg:91.47ms +step:285/1660 train_time:26068ms step_avg:91.47ms +step:286/1660 train_time:26160ms step_avg:91.47ms +step:287/1660 train_time:26252ms step_avg:91.47ms +step:288/1660 train_time:26343ms step_avg:91.47ms +step:289/1660 train_time:26435ms step_avg:91.47ms +step:290/1660 train_time:26527ms step_avg:91.47ms +step:291/1660 train_time:26617ms step_avg:91.47ms +step:292/1660 train_time:26709ms step_avg:91.47ms +step:293/1660 train_time:26800ms step_avg:91.47ms +step:294/1660 train_time:26891ms step_avg:91.47ms +step:295/1660 train_time:26983ms step_avg:91.47ms +step:296/1660 train_time:27075ms step_avg:91.47ms +step:297/1660 train_time:27166ms step_avg:91.47ms +step:298/1660 train_time:27258ms step_avg:91.47ms +step:299/1660 train_time:27349ms step_avg:91.47ms +step:300/1660 train_time:27440ms step_avg:91.47ms +step:301/1660 train_time:27531ms step_avg:91.47ms +step:302/1660 train_time:27623ms step_avg:91.47ms +step:303/1660 train_time:27715ms step_avg:91.47ms +step:304/1660 train_time:27807ms step_avg:91.47ms +step:305/1660 train_time:27897ms step_avg:91.47ms +step:306/1660 train_time:27989ms step_avg:91.47ms +step:307/1660 train_time:28081ms step_avg:91.47ms +step:308/1660 train_time:28172ms step_avg:91.47ms +step:309/1660 train_time:28264ms step_avg:91.47ms +step:310/1660 train_time:28355ms step_avg:91.47ms +step:311/1660 train_time:28447ms step_avg:91.47ms +step:312/1660 train_time:28537ms step_avg:91.47ms +step:313/1660 train_time:28628ms step_avg:91.46ms +step:314/1660 train_time:28719ms step_avg:91.46ms +step:315/1660 train_time:28811ms step_avg:91.46ms +step:316/1660 train_time:28904ms step_avg:91.47ms +step:317/1660 train_time:28996ms step_avg:91.47ms +step:318/1660 train_time:29087ms step_avg:91.47ms +step:319/1660 train_time:29179ms step_avg:91.47ms +step:320/1660 train_time:29271ms step_avg:91.47ms +step:321/1660 train_time:29363ms step_avg:91.47ms +step:322/1660 train_time:29454ms step_avg:91.47ms +step:323/1660 train_time:29545ms step_avg:91.47ms +step:324/1660 train_time:29636ms step_avg:91.47ms +step:325/1660 train_time:29726ms step_avg:91.47ms +step:326/1660 train_time:29817ms step_avg:91.46ms +step:327/1660 train_time:29910ms step_avg:91.47ms +step:328/1660 train_time:30001ms step_avg:91.47ms 
+step:329/1660 train_time:30094ms step_avg:91.47ms +step:330/1660 train_time:30186ms step_avg:91.47ms +step:331/1660 train_time:30277ms step_avg:91.47ms +step:332/1660 train_time:30368ms step_avg:91.47ms +step:333/1660 train_time:30459ms step_avg:91.47ms +step:334/1660 train_time:30551ms step_avg:91.47ms +step:335/1660 train_time:30642ms step_avg:91.47ms +step:336/1660 train_time:30733ms step_avg:91.47ms +step:337/1660 train_time:30824ms step_avg:91.47ms +step:338/1660 train_time:30915ms step_avg:91.47ms +step:339/1660 train_time:31007ms step_avg:91.47ms +step:340/1660 train_time:31098ms step_avg:91.47ms +step:341/1660 train_time:31189ms step_avg:91.46ms +step:342/1660 train_time:31282ms step_avg:91.47ms +step:343/1660 train_time:31373ms step_avg:91.47ms +step:344/1660 train_time:31464ms step_avg:91.47ms +step:345/1660 train_time:31555ms step_avg:91.46ms +step:346/1660 train_time:31647ms step_avg:91.46ms +step:347/1660 train_time:31738ms step_avg:91.46ms +step:348/1660 train_time:31829ms step_avg:91.46ms +step:349/1660 train_time:31921ms step_avg:91.46ms +step:350/1660 train_time:32014ms step_avg:91.47ms +step:351/1660 train_time:32106ms step_avg:91.47ms +step:352/1660 train_time:32196ms step_avg:91.47ms +step:353/1660 train_time:32287ms step_avg:91.47ms +step:354/1660 train_time:32378ms step_avg:91.46ms +step:355/1660 train_time:32469ms step_avg:91.46ms +step:356/1660 train_time:32560ms step_avg:91.46ms +step:357/1660 train_time:32651ms step_avg:91.46ms +step:358/1660 train_time:32744ms step_avg:91.46ms +step:359/1660 train_time:32836ms step_avg:91.46ms +step:360/1660 train_time:32927ms step_avg:91.47ms +step:361/1660 train_time:33019ms step_avg:91.47ms +step:362/1660 train_time:33113ms step_avg:91.47ms +step:363/1660 train_time:33204ms step_avg:91.47ms +step:364/1660 train_time:33296ms step_avg:91.47ms +step:365/1660 train_time:33388ms step_avg:91.47ms +step:366/1660 train_time:33479ms step_avg:91.47ms +step:367/1660 train_time:33570ms step_avg:91.47ms +step:368/1660 train_time:33662ms step_avg:91.47ms +step:369/1660 train_time:33753ms step_avg:91.47ms +step:370/1660 train_time:33845ms step_avg:91.47ms +step:371/1660 train_time:33936ms step_avg:91.47ms +step:372/1660 train_time:34027ms step_avg:91.47ms +step:373/1660 train_time:34118ms step_avg:91.47ms +step:374/1660 train_time:34210ms step_avg:91.47ms +step:375/1660 train_time:34302ms step_avg:91.47ms +step:375/1660 val_loss:3.8189 train_time:34396ms step_avg:91.72ms +step:376/1660 train_time:34417ms step_avg:91.53ms +step:377/1660 train_time:34495ms step_avg:91.50ms +step:378/1660 train_time:34591ms step_avg:91.51ms +step:379/1660 train_time:34682ms step_avg:91.51ms +step:380/1660 train_time:34772ms step_avg:91.51ms +step:381/1660 train_time:34862ms step_avg:91.50ms +step:382/1660 train_time:34953ms step_avg:91.50ms +step:383/1660 train_time:35043ms step_avg:91.50ms +step:384/1660 train_time:35133ms step_avg:91.49ms +step:385/1660 train_time:35224ms step_avg:91.49ms +step:386/1660 train_time:35315ms step_avg:91.49ms +step:387/1660 train_time:35407ms step_avg:91.49ms +step:388/1660 train_time:35501ms step_avg:91.50ms +step:389/1660 train_time:35595ms step_avg:91.50ms +step:390/1660 train_time:35686ms step_avg:91.50ms +step:391/1660 train_time:35778ms step_avg:91.50ms +step:392/1660 train_time:35868ms step_avg:91.50ms +step:393/1660 train_time:35959ms step_avg:91.50ms +step:394/1660 train_time:36050ms step_avg:91.50ms +step:395/1660 train_time:36140ms step_avg:91.49ms +step:396/1660 train_time:36230ms step_avg:91.49ms +step:397/1660 
train_time:36321ms step_avg:91.49ms +step:398/1660 train_time:36413ms step_avg:91.49ms +step:399/1660 train_time:36505ms step_avg:91.49ms +step:400/1660 train_time:36599ms step_avg:91.50ms +step:401/1660 train_time:36690ms step_avg:91.50ms +step:402/1660 train_time:36781ms step_avg:91.49ms +step:403/1660 train_time:36872ms step_avg:91.49ms +step:404/1660 train_time:36963ms step_avg:91.49ms +step:405/1660 train_time:37054ms step_avg:91.49ms +step:406/1660 train_time:37145ms step_avg:91.49ms +step:407/1660 train_time:37236ms step_avg:91.49ms +step:408/1660 train_time:37327ms step_avg:91.49ms +step:409/1660 train_time:37418ms step_avg:91.49ms +step:410/1660 train_time:37510ms step_avg:91.49ms +step:411/1660 train_time:37603ms step_avg:91.49ms +step:412/1660 train_time:37695ms step_avg:91.49ms +step:413/1660 train_time:37786ms step_avg:91.49ms +step:414/1660 train_time:37877ms step_avg:91.49ms +step:415/1660 train_time:37968ms step_avg:91.49ms +step:416/1660 train_time:38059ms step_avg:91.49ms +step:417/1660 train_time:38149ms step_avg:91.48ms +step:418/1660 train_time:38240ms step_avg:91.48ms +step:419/1660 train_time:38331ms step_avg:91.48ms +step:420/1660 train_time:38422ms step_avg:91.48ms +step:421/1660 train_time:38513ms step_avg:91.48ms +step:422/1660 train_time:38605ms step_avg:91.48ms +step:423/1660 train_time:38696ms step_avg:91.48ms +step:424/1660 train_time:38788ms step_avg:91.48ms +step:425/1660 train_time:38880ms step_avg:91.48ms +step:426/1660 train_time:38970ms step_avg:91.48ms +step:427/1660 train_time:39060ms step_avg:91.48ms +step:428/1660 train_time:39151ms step_avg:91.47ms +step:429/1660 train_time:39242ms step_avg:91.47ms +step:430/1660 train_time:39333ms step_avg:91.47ms +step:431/1660 train_time:39424ms step_avg:91.47ms +step:432/1660 train_time:39516ms step_avg:91.47ms +step:433/1660 train_time:39607ms step_avg:91.47ms +step:434/1660 train_time:39699ms step_avg:91.47ms +step:435/1660 train_time:39790ms step_avg:91.47ms +step:436/1660 train_time:39881ms step_avg:91.47ms +step:437/1660 train_time:39973ms step_avg:91.47ms +step:438/1660 train_time:40065ms step_avg:91.47ms +step:439/1660 train_time:40156ms step_avg:91.47ms +step:440/1660 train_time:40247ms step_avg:91.47ms +step:441/1660 train_time:40339ms step_avg:91.47ms +step:442/1660 train_time:40429ms step_avg:91.47ms +step:443/1660 train_time:40521ms step_avg:91.47ms +step:444/1660 train_time:40611ms step_avg:91.47ms +step:445/1660 train_time:40703ms step_avg:91.47ms +step:446/1660 train_time:40794ms step_avg:91.47ms +step:447/1660 train_time:40885ms step_avg:91.46ms +step:448/1660 train_time:40977ms step_avg:91.47ms +step:449/1660 train_time:41068ms step_avg:91.47ms +step:450/1660 train_time:41160ms step_avg:91.47ms +step:451/1660 train_time:41252ms step_avg:91.47ms +step:452/1660 train_time:41344ms step_avg:91.47ms +step:453/1660 train_time:41435ms step_avg:91.47ms +step:454/1660 train_time:41526ms step_avg:91.47ms +step:455/1660 train_time:41617ms step_avg:91.47ms +step:456/1660 train_time:41709ms step_avg:91.47ms +step:457/1660 train_time:41800ms step_avg:91.47ms +step:458/1660 train_time:41891ms step_avg:91.46ms +step:459/1660 train_time:41981ms step_avg:91.46ms +step:460/1660 train_time:42072ms step_avg:91.46ms +step:461/1660 train_time:42164ms step_avg:91.46ms +step:462/1660 train_time:42256ms step_avg:91.46ms +step:463/1660 train_time:42347ms step_avg:91.46ms +step:464/1660 train_time:42438ms step_avg:91.46ms +step:465/1660 train_time:42529ms step_avg:91.46ms +step:466/1660 train_time:42620ms step_avg:91.46ms 
+step:467/1660 train_time:42711ms step_avg:91.46ms +step:468/1660 train_time:42803ms step_avg:91.46ms +step:469/1660 train_time:42893ms step_avg:91.46ms +step:470/1660 train_time:42984ms step_avg:91.46ms +step:471/1660 train_time:43075ms step_avg:91.45ms +step:472/1660 train_time:43168ms step_avg:91.46ms +step:473/1660 train_time:43260ms step_avg:91.46ms +step:474/1660 train_time:43351ms step_avg:91.46ms +step:475/1660 train_time:43443ms step_avg:91.46ms +step:476/1660 train_time:43535ms step_avg:91.46ms +step:477/1660 train_time:43627ms step_avg:91.46ms +step:478/1660 train_time:43718ms step_avg:91.46ms +step:479/1660 train_time:43809ms step_avg:91.46ms +step:480/1660 train_time:43900ms step_avg:91.46ms +step:481/1660 train_time:43991ms step_avg:91.46ms +step:482/1660 train_time:44082ms step_avg:91.46ms +step:483/1660 train_time:44173ms step_avg:91.46ms +step:484/1660 train_time:44264ms step_avg:91.45ms +step:485/1660 train_time:44355ms step_avg:91.45ms +step:486/1660 train_time:44447ms step_avg:91.45ms +step:487/1660 train_time:44538ms step_avg:91.45ms +step:488/1660 train_time:44630ms step_avg:91.45ms +step:489/1660 train_time:44722ms step_avg:91.46ms +step:490/1660 train_time:44813ms step_avg:91.46ms +step:491/1660 train_time:44904ms step_avg:91.46ms +step:492/1660 train_time:44995ms step_avg:91.45ms +step:493/1660 train_time:45086ms step_avg:91.45ms +step:494/1660 train_time:45177ms step_avg:91.45ms +step:495/1660 train_time:45270ms step_avg:91.45ms +step:496/1660 train_time:45361ms step_avg:91.45ms +step:497/1660 train_time:45452ms step_avg:91.45ms +step:498/1660 train_time:45544ms step_avg:91.45ms +step:499/1660 train_time:45635ms step_avg:91.45ms +step:500/1660 train_time:45727ms step_avg:91.45ms +step:500/1660 val_loss:3.7155 train_time:45820ms step_avg:91.64ms +step:501/1660 train_time:45840ms step_avg:91.50ms +step:502/1660 train_time:45914ms step_avg:91.46ms +step:503/1660 train_time:46008ms step_avg:91.47ms +step:504/1660 train_time:46101ms step_avg:91.47ms +step:505/1660 train_time:46191ms step_avg:91.47ms +step:506/1660 train_time:46281ms step_avg:91.46ms +step:507/1660 train_time:46371ms step_avg:91.46ms +step:508/1660 train_time:46461ms step_avg:91.46ms +step:509/1660 train_time:46551ms step_avg:91.46ms +step:510/1660 train_time:46641ms step_avg:91.45ms +step:511/1660 train_time:46731ms step_avg:91.45ms +step:512/1660 train_time:46823ms step_avg:91.45ms +step:513/1660 train_time:46918ms step_avg:91.46ms +step:514/1660 train_time:47012ms step_avg:91.46ms +step:515/1660 train_time:47103ms step_avg:91.46ms +step:516/1660 train_time:47194ms step_avg:91.46ms +step:517/1660 train_time:47285ms step_avg:91.46ms +step:518/1660 train_time:47377ms step_avg:91.46ms +step:519/1660 train_time:47468ms step_avg:91.46ms +step:520/1660 train_time:47558ms step_avg:91.46ms +step:521/1660 train_time:47648ms step_avg:91.46ms +step:522/1660 train_time:47739ms step_avg:91.45ms +step:523/1660 train_time:47830ms step_avg:91.45ms +step:524/1660 train_time:47922ms step_avg:91.45ms +step:525/1660 train_time:48014ms step_avg:91.46ms +step:526/1660 train_time:48106ms step_avg:91.46ms +step:527/1660 train_time:48198ms step_avg:91.46ms +step:528/1660 train_time:48289ms step_avg:91.46ms +step:529/1660 train_time:48380ms step_avg:91.46ms +step:530/1660 train_time:48471ms step_avg:91.46ms +step:531/1660 train_time:48562ms step_avg:91.45ms +step:532/1660 train_time:48653ms step_avg:91.45ms +step:533/1660 train_time:48744ms step_avg:91.45ms +step:534/1660 train_time:48835ms step_avg:91.45ms +step:535/1660 
train_time:48926ms step_avg:91.45ms +step:536/1660 train_time:49018ms step_avg:91.45ms +step:537/1660 train_time:49110ms step_avg:91.45ms +step:538/1660 train_time:49202ms step_avg:91.45ms +step:539/1660 train_time:49294ms step_avg:91.45ms +step:540/1660 train_time:49385ms step_avg:91.45ms +step:541/1660 train_time:49476ms step_avg:91.45ms +step:542/1660 train_time:49568ms step_avg:91.45ms +step:543/1660 train_time:49659ms step_avg:91.45ms +step:544/1660 train_time:49749ms step_avg:91.45ms +step:545/1660 train_time:49841ms step_avg:91.45ms +step:546/1660 train_time:49932ms step_avg:91.45ms +step:547/1660 train_time:50023ms step_avg:91.45ms +step:548/1660 train_time:50114ms step_avg:91.45ms +step:549/1660 train_time:50206ms step_avg:91.45ms +step:550/1660 train_time:50298ms step_avg:91.45ms +step:551/1660 train_time:50390ms step_avg:91.45ms +step:552/1660 train_time:50481ms step_avg:91.45ms +step:553/1660 train_time:50573ms step_avg:91.45ms +step:554/1660 train_time:50664ms step_avg:91.45ms +step:555/1660 train_time:50755ms step_avg:91.45ms +step:556/1660 train_time:50848ms step_avg:91.45ms +step:557/1660 train_time:50940ms step_avg:91.45ms +step:558/1660 train_time:51032ms step_avg:91.46ms +step:559/1660 train_time:51125ms step_avg:91.46ms +step:560/1660 train_time:51217ms step_avg:91.46ms +step:561/1660 train_time:51311ms step_avg:91.46ms +step:562/1660 train_time:51404ms step_avg:91.47ms +step:563/1660 train_time:51497ms step_avg:91.47ms +step:564/1660 train_time:51589ms step_avg:91.47ms +step:565/1660 train_time:51681ms step_avg:91.47ms +step:566/1660 train_time:51774ms step_avg:91.47ms +step:567/1660 train_time:51867ms step_avg:91.48ms +step:568/1660 train_time:51960ms step_avg:91.48ms +step:569/1660 train_time:52052ms step_avg:91.48ms +step:570/1660 train_time:52145ms step_avg:91.48ms +step:571/1660 train_time:52238ms step_avg:91.48ms +step:572/1660 train_time:52330ms step_avg:91.49ms +step:573/1660 train_time:52421ms step_avg:91.49ms +step:574/1660 train_time:52514ms step_avg:91.49ms +step:575/1660 train_time:52606ms step_avg:91.49ms +step:576/1660 train_time:52699ms step_avg:91.49ms +step:577/1660 train_time:52792ms step_avg:91.49ms +step:578/1660 train_time:52885ms step_avg:91.50ms +step:579/1660 train_time:52977ms step_avg:91.50ms +step:580/1660 train_time:53070ms step_avg:91.50ms +step:581/1660 train_time:53162ms step_avg:91.50ms +step:582/1660 train_time:53255ms step_avg:91.50ms +step:583/1660 train_time:53348ms step_avg:91.51ms +step:584/1660 train_time:53440ms step_avg:91.51ms +step:585/1660 train_time:53533ms step_avg:91.51ms +step:586/1660 train_time:53626ms step_avg:91.51ms +step:587/1660 train_time:53718ms step_avg:91.51ms +step:588/1660 train_time:53812ms step_avg:91.52ms +step:589/1660 train_time:53903ms step_avg:91.52ms +step:590/1660 train_time:53996ms step_avg:91.52ms +step:591/1660 train_time:54089ms step_avg:91.52ms +step:592/1660 train_time:54181ms step_avg:91.52ms +step:593/1660 train_time:54273ms step_avg:91.52ms +step:594/1660 train_time:54366ms step_avg:91.52ms +step:595/1660 train_time:54458ms step_avg:91.53ms +step:596/1660 train_time:54551ms step_avg:91.53ms +step:597/1660 train_time:54644ms step_avg:91.53ms +step:598/1660 train_time:54737ms step_avg:91.53ms +step:599/1660 train_time:54830ms step_avg:91.54ms +step:600/1660 train_time:54922ms step_avg:91.54ms +step:601/1660 train_time:55016ms step_avg:91.54ms +step:602/1660 train_time:55108ms step_avg:91.54ms +step:603/1660 train_time:55200ms step_avg:91.54ms +step:604/1660 train_time:55294ms step_avg:91.55ms 
+step:605/1660 train_time:55386ms step_avg:91.55ms +step:606/1660 train_time:55479ms step_avg:91.55ms +step:607/1660 train_time:55572ms step_avg:91.55ms +step:608/1660 train_time:55664ms step_avg:91.55ms +step:609/1660 train_time:55756ms step_avg:91.55ms +step:610/1660 train_time:55848ms step_avg:91.55ms +step:611/1660 train_time:55941ms step_avg:91.56ms +step:612/1660 train_time:56033ms step_avg:91.56ms +step:613/1660 train_time:56126ms step_avg:91.56ms +step:614/1660 train_time:56218ms step_avg:91.56ms +step:615/1660 train_time:56311ms step_avg:91.56ms +step:616/1660 train_time:56403ms step_avg:91.56ms +step:617/1660 train_time:56497ms step_avg:91.57ms +step:618/1660 train_time:56590ms step_avg:91.57ms +step:619/1660 train_time:56682ms step_avg:91.57ms +step:620/1660 train_time:56775ms step_avg:91.57ms +step:621/1660 train_time:56868ms step_avg:91.57ms +step:622/1660 train_time:56960ms step_avg:91.58ms +step:623/1660 train_time:57052ms step_avg:91.58ms +step:624/1660 train_time:57145ms step_avg:91.58ms +step:625/1660 train_time:57237ms step_avg:91.58ms +step:625/1660 val_loss:3.6152 train_time:57331ms step_avg:91.73ms +step:626/1660 train_time:57351ms step_avg:91.61ms +step:627/1660 train_time:57427ms step_avg:91.59ms +step:628/1660 train_time:57530ms step_avg:91.61ms +step:629/1660 train_time:57623ms step_avg:91.61ms +step:630/1660 train_time:57716ms step_avg:91.61ms +step:631/1660 train_time:57808ms step_avg:91.61ms +step:632/1660 train_time:57899ms step_avg:91.61ms +step:633/1660 train_time:57990ms step_avg:91.61ms +step:634/1660 train_time:58082ms step_avg:91.61ms +step:635/1660 train_time:58173ms step_avg:91.61ms +step:636/1660 train_time:58265ms step_avg:91.61ms +step:637/1660 train_time:58357ms step_avg:91.61ms +step:638/1660 train_time:58454ms step_avg:91.62ms +step:639/1660 train_time:58550ms step_avg:91.63ms +step:640/1660 train_time:58643ms step_avg:91.63ms +step:641/1660 train_time:58736ms step_avg:91.63ms +step:642/1660 train_time:58828ms step_avg:91.63ms +step:643/1660 train_time:58920ms step_avg:91.63ms +step:644/1660 train_time:59012ms step_avg:91.63ms +step:645/1660 train_time:59104ms step_avg:91.63ms +step:646/1660 train_time:59195ms step_avg:91.63ms +step:647/1660 train_time:59288ms step_avg:91.63ms +step:648/1660 train_time:59380ms step_avg:91.64ms +step:649/1660 train_time:59476ms step_avg:91.64ms +step:650/1660 train_time:59570ms step_avg:91.65ms +step:651/1660 train_time:59663ms step_avg:91.65ms +step:652/1660 train_time:59758ms step_avg:91.65ms +step:653/1660 train_time:59850ms step_avg:91.65ms +step:654/1660 train_time:59941ms step_avg:91.65ms +step:655/1660 train_time:60033ms step_avg:91.65ms +step:656/1660 train_time:60125ms step_avg:91.65ms +step:657/1660 train_time:60218ms step_avg:91.66ms +step:658/1660 train_time:60311ms step_avg:91.66ms +step:659/1660 train_time:60404ms step_avg:91.66ms +step:660/1660 train_time:60498ms step_avg:91.66ms +step:661/1660 train_time:60592ms step_avg:91.67ms +step:662/1660 train_time:60686ms step_avg:91.67ms +step:663/1660 train_time:60779ms step_avg:91.67ms +step:664/1660 train_time:60872ms step_avg:91.68ms +step:665/1660 train_time:60964ms step_avg:91.68ms +step:666/1660 train_time:61056ms step_avg:91.68ms +step:667/1660 train_time:61148ms step_avg:91.68ms +step:668/1660 train_time:61240ms step_avg:91.68ms +step:669/1660 train_time:61334ms step_avg:91.68ms +step:670/1660 train_time:61426ms step_avg:91.68ms +step:671/1660 train_time:61519ms step_avg:91.68ms +step:672/1660 train_time:61613ms step_avg:91.69ms +step:673/1660 
train_time:61707ms step_avg:91.69ms +step:674/1660 train_time:61799ms step_avg:91.69ms +step:675/1660 train_time:61892ms step_avg:91.69ms +step:676/1660 train_time:61985ms step_avg:91.69ms +step:677/1660 train_time:62077ms step_avg:91.69ms +step:678/1660 train_time:62169ms step_avg:91.70ms +step:679/1660 train_time:62261ms step_avg:91.70ms +step:680/1660 train_time:62353ms step_avg:91.70ms +step:681/1660 train_time:62446ms step_avg:91.70ms +step:682/1660 train_time:62539ms step_avg:91.70ms +step:683/1660 train_time:62633ms step_avg:91.70ms +step:684/1660 train_time:62726ms step_avg:91.70ms +step:685/1660 train_time:62819ms step_avg:91.71ms +step:686/1660 train_time:62913ms step_avg:91.71ms +step:687/1660 train_time:63005ms step_avg:91.71ms +step:688/1660 train_time:63097ms step_avg:91.71ms +step:689/1660 train_time:63190ms step_avg:91.71ms +step:690/1660 train_time:63282ms step_avg:91.71ms +step:691/1660 train_time:63374ms step_avg:91.71ms +step:692/1660 train_time:63466ms step_avg:91.71ms +step:693/1660 train_time:63558ms step_avg:91.71ms +step:694/1660 train_time:63651ms step_avg:91.72ms +step:695/1660 train_time:63744ms step_avg:91.72ms +step:696/1660 train_time:63837ms step_avg:91.72ms +step:697/1660 train_time:63930ms step_avg:91.72ms +step:698/1660 train_time:64023ms step_avg:91.72ms +step:699/1660 train_time:64115ms step_avg:91.72ms +step:700/1660 train_time:64208ms step_avg:91.73ms +step:701/1660 train_time:64300ms step_avg:91.73ms +step:702/1660 train_time:64394ms step_avg:91.73ms +step:703/1660 train_time:64487ms step_avg:91.73ms +step:704/1660 train_time:64580ms step_avg:91.73ms +step:705/1660 train_time:64673ms step_avg:91.73ms +step:706/1660 train_time:64766ms step_avg:91.74ms +step:707/1660 train_time:64858ms step_avg:91.74ms +step:708/1660 train_time:64950ms step_avg:91.74ms +step:709/1660 train_time:65043ms step_avg:91.74ms +step:710/1660 train_time:65136ms step_avg:91.74ms +step:711/1660 train_time:65228ms step_avg:91.74ms +step:712/1660 train_time:65320ms step_avg:91.74ms +step:713/1660 train_time:65414ms step_avg:91.74ms +step:714/1660 train_time:65507ms step_avg:91.75ms +step:715/1660 train_time:65599ms step_avg:91.75ms +step:716/1660 train_time:65693ms step_avg:91.75ms +step:717/1660 train_time:65786ms step_avg:91.75ms +step:718/1660 train_time:65879ms step_avg:91.75ms +step:719/1660 train_time:65971ms step_avg:91.75ms +step:720/1660 train_time:66064ms step_avg:91.76ms +step:721/1660 train_time:66156ms step_avg:91.76ms +step:722/1660 train_time:66249ms step_avg:91.76ms +step:723/1660 train_time:66341ms step_avg:91.76ms +step:724/1660 train_time:66434ms step_avg:91.76ms +step:725/1660 train_time:66527ms step_avg:91.76ms +step:726/1660 train_time:66620ms step_avg:91.76ms +step:727/1660 train_time:66712ms step_avg:91.76ms +step:728/1660 train_time:66805ms step_avg:91.77ms +step:729/1660 train_time:66897ms step_avg:91.77ms +step:730/1660 train_time:66990ms step_avg:91.77ms +step:731/1660 train_time:67083ms step_avg:91.77ms +step:732/1660 train_time:67176ms step_avg:91.77ms +step:733/1660 train_time:67269ms step_avg:91.77ms +step:734/1660 train_time:67361ms step_avg:91.77ms +step:735/1660 train_time:67454ms step_avg:91.77ms +step:736/1660 train_time:67547ms step_avg:91.78ms +step:737/1660 train_time:67639ms step_avg:91.78ms +step:738/1660 train_time:67732ms step_avg:91.78ms +step:739/1660 train_time:67825ms step_avg:91.78ms +step:740/1660 train_time:67918ms step_avg:91.78ms +step:741/1660 train_time:68010ms step_avg:91.78ms +step:742/1660 train_time:68103ms step_avg:91.78ms 
+step:743/1660 train_time:68195ms step_avg:91.78ms +step:744/1660 train_time:68288ms step_avg:91.78ms +step:745/1660 train_time:68380ms step_avg:91.78ms +step:746/1660 train_time:68473ms step_avg:91.79ms +step:747/1660 train_time:68565ms step_avg:91.79ms +step:748/1660 train_time:68658ms step_avg:91.79ms +step:749/1660 train_time:68751ms step_avg:91.79ms +step:750/1660 train_time:68844ms step_avg:91.79ms +step:750/1660 val_loss:3.5623 train_time:68939ms step_avg:91.92ms +step:751/1660 train_time:68959ms step_avg:91.82ms +step:752/1660 train_time:69041ms step_avg:91.81ms +step:753/1660 train_time:69141ms step_avg:91.82ms +step:754/1660 train_time:69236ms step_avg:91.82ms +step:755/1660 train_time:69327ms step_avg:91.82ms +step:756/1660 train_time:69418ms step_avg:91.82ms +step:757/1660 train_time:69510ms step_avg:91.82ms +step:758/1660 train_time:69601ms step_avg:91.82ms +step:759/1660 train_time:69692ms step_avg:91.82ms +step:760/1660 train_time:69783ms step_avg:91.82ms +step:761/1660 train_time:69875ms step_avg:91.82ms +step:762/1660 train_time:69969ms step_avg:91.82ms +step:763/1660 train_time:70064ms step_avg:91.83ms +step:764/1660 train_time:70161ms step_avg:91.83ms +step:765/1660 train_time:70256ms step_avg:91.84ms +step:766/1660 train_time:70348ms step_avg:91.84ms +step:767/1660 train_time:70440ms step_avg:91.84ms +step:768/1660 train_time:70533ms step_avg:91.84ms +step:769/1660 train_time:70624ms step_avg:91.84ms +step:770/1660 train_time:70716ms step_avg:91.84ms +step:771/1660 train_time:70808ms step_avg:91.84ms +step:772/1660 train_time:70900ms step_avg:91.84ms +step:773/1660 train_time:70993ms step_avg:91.84ms +step:774/1660 train_time:71088ms step_avg:91.85ms +step:775/1660 train_time:71181ms step_avg:91.85ms +step:776/1660 train_time:71276ms step_avg:91.85ms +step:777/1660 train_time:71369ms step_avg:91.85ms +step:778/1660 train_time:71461ms step_avg:91.85ms +step:779/1660 train_time:71554ms step_avg:91.85ms +step:780/1660 train_time:71646ms step_avg:91.85ms +step:781/1660 train_time:71738ms step_avg:91.85ms +step:782/1660 train_time:71830ms step_avg:91.85ms +step:783/1660 train_time:71924ms step_avg:91.86ms +step:784/1660 train_time:72018ms step_avg:91.86ms +step:785/1660 train_time:72111ms step_avg:91.86ms +step:786/1660 train_time:72203ms step_avg:91.86ms +step:787/1660 train_time:72298ms step_avg:91.86ms +step:788/1660 train_time:72391ms step_avg:91.87ms +step:789/1660 train_time:72483ms step_avg:91.87ms +step:790/1660 train_time:72577ms step_avg:91.87ms +step:791/1660 train_time:72668ms step_avg:91.87ms +step:792/1660 train_time:72760ms step_avg:91.87ms +step:793/1660 train_time:72853ms step_avg:91.87ms +step:794/1660 train_time:72946ms step_avg:91.87ms +step:795/1660 train_time:73040ms step_avg:91.87ms +step:796/1660 train_time:73132ms step_avg:91.87ms +step:797/1660 train_time:73226ms step_avg:91.88ms +step:798/1660 train_time:73319ms step_avg:91.88ms +step:799/1660 train_time:73411ms step_avg:91.88ms +step:800/1660 train_time:73503ms step_avg:91.88ms +step:801/1660 train_time:73596ms step_avg:91.88ms +step:802/1660 train_time:73689ms step_avg:91.88ms +step:803/1660 train_time:73781ms step_avg:91.88ms +step:804/1660 train_time:73874ms step_avg:91.88ms +step:805/1660 train_time:73967ms step_avg:91.88ms +step:806/1660 train_time:74060ms step_avg:91.89ms +step:807/1660 train_time:74153ms step_avg:91.89ms +step:808/1660 train_time:74246ms step_avg:91.89ms +step:809/1660 train_time:74339ms step_avg:91.89ms +step:810/1660 train_time:74432ms step_avg:91.89ms +step:811/1660 
train_time:74525ms step_avg:91.89ms +step:812/1660 train_time:74618ms step_avg:91.89ms +step:813/1660 train_time:74710ms step_avg:91.89ms +step:814/1660 train_time:74803ms step_avg:91.90ms +step:815/1660 train_time:74896ms step_avg:91.90ms +step:816/1660 train_time:74988ms step_avg:91.90ms +step:817/1660 train_time:75081ms step_avg:91.90ms +step:818/1660 train_time:75174ms step_avg:91.90ms +step:819/1660 train_time:75267ms step_avg:91.90ms +step:820/1660 train_time:75359ms step_avg:91.90ms +step:821/1660 train_time:75452ms step_avg:91.90ms +step:822/1660 train_time:75545ms step_avg:91.90ms +step:823/1660 train_time:75637ms step_avg:91.90ms +step:824/1660 train_time:75730ms step_avg:91.91ms +step:825/1660 train_time:75823ms step_avg:91.91ms +step:826/1660 train_time:75916ms step_avg:91.91ms +step:827/1660 train_time:76009ms step_avg:91.91ms +step:828/1660 train_time:76101ms step_avg:91.91ms +step:829/1660 train_time:76194ms step_avg:91.91ms +step:830/1660 train_time:76287ms step_avg:91.91ms +step:831/1660 train_time:76379ms step_avg:91.91ms +step:832/1660 train_time:76472ms step_avg:91.91ms +step:833/1660 train_time:76564ms step_avg:91.91ms +step:834/1660 train_time:76657ms step_avg:91.92ms +step:835/1660 train_time:76751ms step_avg:91.92ms +step:836/1660 train_time:76843ms step_avg:91.92ms +step:837/1660 train_time:76936ms step_avg:91.92ms +step:838/1660 train_time:77028ms step_avg:91.92ms +step:839/1660 train_time:77121ms step_avg:91.92ms +step:840/1660 train_time:77214ms step_avg:91.92ms +step:841/1660 train_time:77308ms step_avg:91.92ms +step:842/1660 train_time:77400ms step_avg:91.92ms +step:843/1660 train_time:77492ms step_avg:91.92ms +step:844/1660 train_time:77585ms step_avg:91.92ms +step:845/1660 train_time:77679ms step_avg:91.93ms +step:846/1660 train_time:77771ms step_avg:91.93ms +step:847/1660 train_time:77863ms step_avg:91.93ms +step:848/1660 train_time:77956ms step_avg:91.93ms +step:849/1660 train_time:78049ms step_avg:91.93ms +step:850/1660 train_time:78141ms step_avg:91.93ms +step:851/1660 train_time:78235ms step_avg:91.93ms +step:852/1660 train_time:78328ms step_avg:91.93ms +step:853/1660 train_time:78420ms step_avg:91.93ms +step:854/1660 train_time:78514ms step_avg:91.94ms +step:855/1660 train_time:78606ms step_avg:91.94ms +step:856/1660 train_time:78699ms step_avg:91.94ms +step:857/1660 train_time:78791ms step_avg:91.94ms +step:858/1660 train_time:78884ms step_avg:91.94ms +step:859/1660 train_time:78978ms step_avg:91.94ms +step:860/1660 train_time:79070ms step_avg:91.94ms +step:861/1660 train_time:79164ms step_avg:91.94ms +step:862/1660 train_time:79257ms step_avg:91.95ms +step:863/1660 train_time:79350ms step_avg:91.95ms +step:864/1660 train_time:79442ms step_avg:91.95ms +step:865/1660 train_time:79535ms step_avg:91.95ms +step:866/1660 train_time:79628ms step_avg:91.95ms +step:867/1660 train_time:79720ms step_avg:91.95ms +step:868/1660 train_time:79813ms step_avg:91.95ms +step:869/1660 train_time:79905ms step_avg:91.95ms +step:870/1660 train_time:79997ms step_avg:91.95ms +step:871/1660 train_time:80089ms step_avg:91.95ms +step:872/1660 train_time:80182ms step_avg:91.95ms +step:873/1660 train_time:80275ms step_avg:91.95ms +step:874/1660 train_time:80368ms step_avg:91.95ms +step:875/1660 train_time:80460ms step_avg:91.95ms +step:875/1660 val_loss:3.5164 train_time:80554ms step_avg:92.06ms +step:876/1660 train_time:80574ms step_avg:91.98ms +step:877/1660 train_time:80651ms step_avg:91.96ms +step:878/1660 train_time:80747ms step_avg:91.97ms +step:879/1660 train_time:80840ms 
step_avg:91.97ms +step:880/1660 train_time:80932ms step_avg:91.97ms +step:881/1660 train_time:81023ms step_avg:91.97ms +step:882/1660 train_time:81115ms step_avg:91.97ms +step:883/1660 train_time:81206ms step_avg:91.97ms +step:884/1660 train_time:81297ms step_avg:91.97ms +step:885/1660 train_time:81389ms step_avg:91.96ms +step:886/1660 train_time:81481ms step_avg:91.97ms +step:887/1660 train_time:81576ms step_avg:91.97ms +step:888/1660 train_time:81672ms step_avg:91.97ms +step:889/1660 train_time:81766ms step_avg:91.97ms +step:890/1660 train_time:81858ms step_avg:91.98ms +step:891/1660 train_time:81952ms step_avg:91.98ms +step:892/1660 train_time:82044ms step_avg:91.98ms +step:893/1660 train_time:82135ms step_avg:91.98ms +step:894/1660 train_time:82227ms step_avg:91.98ms +step:895/1660 train_time:82318ms step_avg:91.98ms +step:896/1660 train_time:82410ms step_avg:91.98ms +step:897/1660 train_time:82502ms step_avg:91.98ms +step:898/1660 train_time:82598ms step_avg:91.98ms +step:899/1660 train_time:82691ms step_avg:91.98ms +step:900/1660 train_time:82785ms step_avg:91.98ms +step:901/1660 train_time:82877ms step_avg:91.98ms +step:902/1660 train_time:82970ms step_avg:91.98ms +step:903/1660 train_time:83062ms step_avg:91.98ms +step:904/1660 train_time:83154ms step_avg:91.98ms +step:905/1660 train_time:83245ms step_avg:91.98ms +step:906/1660 train_time:83337ms step_avg:91.98ms +step:907/1660 train_time:83429ms step_avg:91.98ms +step:908/1660 train_time:83522ms step_avg:91.98ms +step:909/1660 train_time:83615ms step_avg:91.99ms +step:910/1660 train_time:83709ms step_avg:91.99ms +step:911/1660 train_time:83802ms step_avg:91.99ms +step:912/1660 train_time:83894ms step_avg:91.99ms +step:913/1660 train_time:83988ms step_avg:91.99ms +step:914/1660 train_time:84080ms step_avg:91.99ms +step:915/1660 train_time:84172ms step_avg:91.99ms +step:916/1660 train_time:84264ms step_avg:91.99ms +step:917/1660 train_time:84356ms step_avg:91.99ms +step:918/1660 train_time:84449ms step_avg:91.99ms +step:919/1660 train_time:84540ms step_avg:91.99ms +step:920/1660 train_time:84633ms step_avg:91.99ms +step:921/1660 train_time:84726ms step_avg:91.99ms +step:922/1660 train_time:84818ms step_avg:91.99ms +step:923/1660 train_time:84911ms step_avg:91.99ms +step:924/1660 train_time:85005ms step_avg:92.00ms +step:925/1660 train_time:85097ms step_avg:92.00ms +step:926/1660 train_time:85190ms step_avg:92.00ms +step:927/1660 train_time:85283ms step_avg:92.00ms +step:928/1660 train_time:85376ms step_avg:92.00ms +step:929/1660 train_time:85469ms step_avg:92.00ms +step:930/1660 train_time:85561ms step_avg:92.00ms +step:931/1660 train_time:85654ms step_avg:92.00ms +step:932/1660 train_time:85747ms step_avg:92.00ms +step:933/1660 train_time:85839ms step_avg:92.00ms +step:934/1660 train_time:85932ms step_avg:92.00ms +step:935/1660 train_time:86025ms step_avg:92.01ms +step:936/1660 train_time:86117ms step_avg:92.01ms +step:937/1660 train_time:86210ms step_avg:92.01ms +step:938/1660 train_time:86303ms step_avg:92.01ms +step:939/1660 train_time:86397ms step_avg:92.01ms +step:940/1660 train_time:86489ms step_avg:92.01ms +step:941/1660 train_time:86582ms step_avg:92.01ms +step:942/1660 train_time:86675ms step_avg:92.01ms +step:943/1660 train_time:86767ms step_avg:92.01ms +step:944/1660 train_time:86859ms step_avg:92.01ms +step:945/1660 train_time:86952ms step_avg:92.01ms +step:946/1660 train_time:87044ms step_avg:92.01ms +step:947/1660 train_time:87137ms step_avg:92.01ms +step:948/1660 train_time:87229ms step_avg:92.01ms +step:949/1660 
train_time:87321ms step_avg:92.01ms +step:950/1660 train_time:87414ms step_avg:92.02ms +step:951/1660 train_time:87507ms step_avg:92.02ms +step:952/1660 train_time:87599ms step_avg:92.02ms +step:953/1660 train_time:87692ms step_avg:92.02ms +step:954/1660 train_time:87784ms step_avg:92.02ms +step:955/1660 train_time:87876ms step_avg:92.02ms +step:956/1660 train_time:87969ms step_avg:92.02ms +step:957/1660 train_time:88061ms step_avg:92.02ms +step:958/1660 train_time:88153ms step_avg:92.02ms +step:959/1660 train_time:88246ms step_avg:92.02ms +step:960/1660 train_time:88338ms step_avg:92.02ms +step:961/1660 train_time:88430ms step_avg:92.02ms +step:962/1660 train_time:88523ms step_avg:92.02ms +step:963/1660 train_time:88615ms step_avg:92.02ms +step:964/1660 train_time:88708ms step_avg:92.02ms +step:965/1660 train_time:88800ms step_avg:92.02ms +step:966/1660 train_time:88893ms step_avg:92.02ms +step:967/1660 train_time:88986ms step_avg:92.02ms +step:968/1660 train_time:89078ms step_avg:92.02ms +step:969/1660 train_time:89171ms step_avg:92.02ms +step:970/1660 train_time:89263ms step_avg:92.02ms +step:971/1660 train_time:89357ms step_avg:92.03ms +step:972/1660 train_time:89450ms step_avg:92.03ms +step:973/1660 train_time:89542ms step_avg:92.03ms +step:974/1660 train_time:89634ms step_avg:92.03ms +step:975/1660 train_time:89726ms step_avg:92.03ms +step:976/1660 train_time:89819ms step_avg:92.03ms +step:977/1660 train_time:89911ms step_avg:92.03ms +step:978/1660 train_time:90004ms step_avg:92.03ms +step:979/1660 train_time:90097ms step_avg:92.03ms +step:980/1660 train_time:90190ms step_avg:92.03ms +step:981/1660 train_time:90283ms step_avg:92.03ms +step:982/1660 train_time:90375ms step_avg:92.03ms +step:983/1660 train_time:90468ms step_avg:92.03ms +step:984/1660 train_time:90560ms step_avg:92.03ms +step:985/1660 train_time:90652ms step_avg:92.03ms +step:986/1660 train_time:90745ms step_avg:92.03ms +step:987/1660 train_time:90837ms step_avg:92.03ms +step:988/1660 train_time:90929ms step_avg:92.03ms +step:989/1660 train_time:91022ms step_avg:92.03ms +step:990/1660 train_time:91116ms step_avg:92.04ms +step:991/1660 train_time:91209ms step_avg:92.04ms +step:992/1660 train_time:91301ms step_avg:92.04ms +step:993/1660 train_time:91393ms step_avg:92.04ms +step:994/1660 train_time:91487ms step_avg:92.04ms +step:995/1660 train_time:91578ms step_avg:92.04ms +step:996/1660 train_time:91671ms step_avg:92.04ms +step:997/1660 train_time:91763ms step_avg:92.04ms +step:998/1660 train_time:91857ms step_avg:92.04ms +step:999/1660 train_time:91949ms step_avg:92.04ms +step:1000/1660 train_time:92042ms step_avg:92.04ms +step:1000/1660 val_loss:3.4671 train_time:92136ms step_avg:92.14ms +step:1001/1660 train_time:92156ms step_avg:92.06ms +step:1002/1660 train_time:92232ms step_avg:92.05ms +step:1003/1660 train_time:92330ms step_avg:92.05ms +step:1004/1660 train_time:92424ms step_avg:92.06ms +step:1005/1660 train_time:92515ms step_avg:92.05ms +step:1006/1660 train_time:92607ms step_avg:92.05ms +step:1007/1660 train_time:92698ms step_avg:92.05ms +step:1008/1660 train_time:92789ms step_avg:92.05ms +step:1009/1660 train_time:92880ms step_avg:92.05ms +step:1010/1660 train_time:92973ms step_avg:92.05ms +step:1011/1660 train_time:93065ms step_avg:92.05ms +step:1012/1660 train_time:93160ms step_avg:92.06ms +step:1013/1660 train_time:93255ms step_avg:92.06ms +step:1014/1660 train_time:93349ms step_avg:92.06ms +step:1015/1660 train_time:93441ms step_avg:92.06ms +step:1016/1660 train_time:93534ms step_avg:92.06ms +step:1017/1660 
train_time:93626ms step_avg:92.06ms +step:1018/1660 train_time:93718ms step_avg:92.06ms +step:1019/1660 train_time:93809ms step_avg:92.06ms +step:1020/1660 train_time:93901ms step_avg:92.06ms +step:1021/1660 train_time:93993ms step_avg:92.06ms +step:1022/1660 train_time:94086ms step_avg:92.06ms +step:1023/1660 train_time:94180ms step_avg:92.06ms +step:1024/1660 train_time:94274ms step_avg:92.06ms +step:1025/1660 train_time:94367ms step_avg:92.06ms +step:1026/1660 train_time:94459ms step_avg:92.07ms +step:1027/1660 train_time:94551ms step_avg:92.07ms +step:1028/1660 train_time:94644ms step_avg:92.07ms +step:1029/1660 train_time:94737ms step_avg:92.07ms +step:1030/1660 train_time:94829ms step_avg:92.07ms +step:1031/1660 train_time:94921ms step_avg:92.07ms +step:1032/1660 train_time:95013ms step_avg:92.07ms +step:1033/1660 train_time:95106ms step_avg:92.07ms +step:1034/1660 train_time:95200ms step_avg:92.07ms +step:1035/1660 train_time:95294ms step_avg:92.07ms +step:1036/1660 train_time:95386ms step_avg:92.07ms +step:1037/1660 train_time:95479ms step_avg:92.07ms +step:1038/1660 train_time:95571ms step_avg:92.07ms +step:1039/1660 train_time:95664ms step_avg:92.07ms +step:1040/1660 train_time:95756ms step_avg:92.07ms +step:1041/1660 train_time:95848ms step_avg:92.07ms +step:1042/1660 train_time:95940ms step_avg:92.07ms +step:1043/1660 train_time:96032ms step_avg:92.07ms +step:1044/1660 train_time:96125ms step_avg:92.07ms +step:1045/1660 train_time:96219ms step_avg:92.08ms +step:1046/1660 train_time:96313ms step_avg:92.08ms +step:1047/1660 train_time:96405ms step_avg:92.08ms +step:1048/1660 train_time:96498ms step_avg:92.08ms +step:1049/1660 train_time:96591ms step_avg:92.08ms +step:1050/1660 train_time:96684ms step_avg:92.08ms +step:1051/1660 train_time:96776ms step_avg:92.08ms +step:1052/1660 train_time:96869ms step_avg:92.08ms +step:1053/1660 train_time:96961ms step_avg:92.08ms +step:1054/1660 train_time:97053ms step_avg:92.08ms +step:1055/1660 train_time:97146ms step_avg:92.08ms +step:1056/1660 train_time:97239ms step_avg:92.08ms +step:1057/1660 train_time:97332ms step_avg:92.08ms +step:1058/1660 train_time:97426ms step_avg:92.09ms +step:1059/1660 train_time:97518ms step_avg:92.09ms +step:1060/1660 train_time:97611ms step_avg:92.09ms +step:1061/1660 train_time:97704ms step_avg:92.09ms +step:1062/1660 train_time:97797ms step_avg:92.09ms +step:1063/1660 train_time:97889ms step_avg:92.09ms +step:1064/1660 train_time:97980ms step_avg:92.09ms +step:1065/1660 train_time:98073ms step_avg:92.09ms +step:1066/1660 train_time:98166ms step_avg:92.09ms +step:1067/1660 train_time:98258ms step_avg:92.09ms +step:1068/1660 train_time:98351ms step_avg:92.09ms +step:1069/1660 train_time:98444ms step_avg:92.09ms +step:1070/1660 train_time:98537ms step_avg:92.09ms +step:1071/1660 train_time:98630ms step_avg:92.09ms +step:1072/1660 train_time:98723ms step_avg:92.09ms +step:1073/1660 train_time:98815ms step_avg:92.09ms +step:1074/1660 train_time:98908ms step_avg:92.09ms +step:1075/1660 train_time:99000ms step_avg:92.09ms +step:1076/1660 train_time:99092ms step_avg:92.09ms +step:1077/1660 train_time:99184ms step_avg:92.09ms +step:1078/1660 train_time:99276ms step_avg:92.09ms +step:1079/1660 train_time:99369ms step_avg:92.09ms +step:1080/1660 train_time:99461ms step_avg:92.09ms +step:1081/1660 train_time:99554ms step_avg:92.09ms +step:1082/1660 train_time:99647ms step_avg:92.10ms +step:1083/1660 train_time:99740ms step_avg:92.10ms +step:1084/1660 train_time:99832ms step_avg:92.10ms +step:1085/1660 train_time:99925ms 
step_avg:92.10ms +step:1086/1660 train_time:100018ms step_avg:92.10ms +step:1087/1660 train_time:100111ms step_avg:92.10ms +step:1088/1660 train_time:100204ms step_avg:92.10ms +step:1089/1660 train_time:100296ms step_avg:92.10ms +step:1090/1660 train_time:100388ms step_avg:92.10ms +step:1091/1660 train_time:100481ms step_avg:92.10ms +step:1092/1660 train_time:100574ms step_avg:92.10ms +step:1093/1660 train_time:100666ms step_avg:92.10ms +step:1094/1660 train_time:100759ms step_avg:92.10ms +step:1095/1660 train_time:100851ms step_avg:92.10ms +step:1096/1660 train_time:100944ms step_avg:92.10ms +step:1097/1660 train_time:101037ms step_avg:92.10ms +step:1098/1660 train_time:101130ms step_avg:92.10ms +step:1099/1660 train_time:101222ms step_avg:92.10ms +step:1100/1660 train_time:101315ms step_avg:92.10ms +step:1101/1660 train_time:101407ms step_avg:92.10ms +step:1102/1660 train_time:101500ms step_avg:92.11ms +step:1103/1660 train_time:101593ms step_avg:92.11ms +step:1104/1660 train_time:101686ms step_avg:92.11ms +step:1105/1660 train_time:101778ms step_avg:92.11ms +step:1106/1660 train_time:101871ms step_avg:92.11ms +step:1107/1660 train_time:101964ms step_avg:92.11ms +step:1108/1660 train_time:102057ms step_avg:92.11ms +step:1109/1660 train_time:102149ms step_avg:92.11ms +step:1110/1660 train_time:102242ms step_avg:92.11ms +step:1111/1660 train_time:102337ms step_avg:92.11ms +step:1112/1660 train_time:102430ms step_avg:92.11ms +step:1113/1660 train_time:102522ms step_avg:92.11ms +step:1114/1660 train_time:102617ms step_avg:92.12ms +step:1115/1660 train_time:102710ms step_avg:92.12ms +step:1116/1660 train_time:102803ms step_avg:92.12ms +step:1117/1660 train_time:102896ms step_avg:92.12ms +step:1118/1660 train_time:102990ms step_avg:92.12ms +step:1119/1660 train_time:103082ms step_avg:92.12ms +step:1120/1660 train_time:103175ms step_avg:92.12ms +step:1121/1660 train_time:103269ms step_avg:92.12ms +step:1122/1660 train_time:103362ms step_avg:92.12ms +step:1123/1660 train_time:103455ms step_avg:92.12ms +step:1124/1660 train_time:103549ms step_avg:92.13ms +step:1125/1660 train_time:103642ms step_avg:92.13ms +step:1125/1660 val_loss:3.4134 train_time:103738ms step_avg:92.21ms +step:1126/1660 train_time:103759ms step_avg:92.15ms +step:1127/1660 train_time:103836ms step_avg:92.13ms +step:1128/1660 train_time:103940ms step_avg:92.15ms +step:1129/1660 train_time:104033ms step_avg:92.15ms +step:1130/1660 train_time:104125ms step_avg:92.15ms +step:1131/1660 train_time:104217ms step_avg:92.15ms +step:1132/1660 train_time:104309ms step_avg:92.15ms +step:1133/1660 train_time:104402ms step_avg:92.15ms +step:1134/1660 train_time:104495ms step_avg:92.15ms +step:1135/1660 train_time:104587ms step_avg:92.15ms +step:1136/1660 train_time:104679ms step_avg:92.15ms +step:1137/1660 train_time:104773ms step_avg:92.15ms +step:1138/1660 train_time:104869ms step_avg:92.15ms +step:1139/1660 train_time:104966ms step_avg:92.16ms +step:1140/1660 train_time:105060ms step_avg:92.16ms +step:1141/1660 train_time:105154ms step_avg:92.16ms +step:1142/1660 train_time:105246ms step_avg:92.16ms +step:1143/1660 train_time:105338ms step_avg:92.16ms +step:1144/1660 train_time:105431ms step_avg:92.16ms +step:1145/1660 train_time:105523ms step_avg:92.16ms +step:1146/1660 train_time:105616ms step_avg:92.16ms +step:1147/1660 train_time:105709ms step_avg:92.16ms +step:1148/1660 train_time:105804ms step_avg:92.16ms +step:1149/1660 train_time:105899ms step_avg:92.17ms +step:1150/1660 train_time:105993ms step_avg:92.17ms +step:1151/1660 
train_time:106087ms step_avg:92.17ms +step:1152/1660 train_time:106180ms step_avg:92.17ms +step:1153/1660 train_time:106273ms step_avg:92.17ms +step:1154/1660 train_time:106366ms step_avg:92.17ms +step:1155/1660 train_time:106459ms step_avg:92.17ms +step:1156/1660 train_time:106551ms step_avg:92.17ms +step:1157/1660 train_time:106643ms step_avg:92.17ms +step:1158/1660 train_time:106739ms step_avg:92.17ms +step:1159/1660 train_time:106834ms step_avg:92.18ms +step:1160/1660 train_time:106929ms step_avg:92.18ms +step:1161/1660 train_time:107022ms step_avg:92.18ms +step:1162/1660 train_time:107116ms step_avg:92.18ms +step:1163/1660 train_time:107209ms step_avg:92.18ms +step:1164/1660 train_time:107302ms step_avg:92.18ms +step:1165/1660 train_time:107396ms step_avg:92.19ms +step:1166/1660 train_time:107488ms step_avg:92.19ms +step:1167/1660 train_time:107580ms step_avg:92.19ms +step:1168/1660 train_time:107673ms step_avg:92.19ms +step:1169/1660 train_time:107767ms step_avg:92.19ms +step:1170/1660 train_time:107860ms step_avg:92.19ms +step:1171/1660 train_time:107956ms step_avg:92.19ms +step:1172/1660 train_time:108050ms step_avg:92.19ms +step:1173/1660 train_time:108143ms step_avg:92.19ms +step:1174/1660 train_time:108236ms step_avg:92.19ms +step:1175/1660 train_time:108330ms step_avg:92.20ms +step:1176/1660 train_time:108422ms step_avg:92.20ms +step:1177/1660 train_time:108515ms step_avg:92.20ms +step:1178/1660 train_time:108608ms step_avg:92.20ms +step:1179/1660 train_time:108701ms step_avg:92.20ms +step:1180/1660 train_time:108795ms step_avg:92.20ms +step:1181/1660 train_time:108889ms step_avg:92.20ms +step:1182/1660 train_time:108983ms step_avg:92.20ms +step:1183/1660 train_time:109076ms step_avg:92.20ms +step:1184/1660 train_time:109170ms step_avg:92.20ms +step:1185/1660 train_time:109263ms step_avg:92.21ms +step:1186/1660 train_time:109358ms step_avg:92.21ms +step:1187/1660 train_time:109451ms step_avg:92.21ms +step:1188/1660 train_time:109543ms step_avg:92.21ms +step:1189/1660 train_time:109637ms step_avg:92.21ms +step:1190/1660 train_time:109731ms step_avg:92.21ms +step:1191/1660 train_time:109824ms step_avg:92.21ms +step:1192/1660 train_time:109918ms step_avg:92.21ms +step:1193/1660 train_time:110012ms step_avg:92.21ms +step:1194/1660 train_time:110106ms step_avg:92.22ms +step:1195/1660 train_time:110199ms step_avg:92.22ms +step:1196/1660 train_time:110292ms step_avg:92.22ms +step:1197/1660 train_time:110385ms step_avg:92.22ms +step:1198/1660 train_time:110478ms step_avg:92.22ms +step:1199/1660 train_time:110571ms step_avg:92.22ms +step:1200/1660 train_time:110664ms step_avg:92.22ms +step:1201/1660 train_time:110758ms step_avg:92.22ms +step:1202/1660 train_time:110851ms step_avg:92.22ms +step:1203/1660 train_time:110945ms step_avg:92.22ms +step:1204/1660 train_time:111038ms step_avg:92.22ms +step:1205/1660 train_time:111132ms step_avg:92.23ms +step:1206/1660 train_time:111225ms step_avg:92.23ms +step:1207/1660 train_time:111319ms step_avg:92.23ms +step:1208/1660 train_time:111412ms step_avg:92.23ms +step:1209/1660 train_time:111506ms step_avg:92.23ms +step:1210/1660 train_time:111599ms step_avg:92.23ms +step:1211/1660 train_time:111692ms step_avg:92.23ms +step:1212/1660 train_time:111785ms step_avg:92.23ms +step:1213/1660 train_time:111879ms step_avg:92.23ms +step:1214/1660 train_time:111972ms step_avg:92.23ms +step:1215/1660 train_time:112065ms step_avg:92.23ms +step:1216/1660 train_time:112158ms step_avg:92.24ms +step:1217/1660 train_time:112252ms step_avg:92.24ms +step:1218/1660 
train_time:112345ms step_avg:92.24ms +step:1219/1660 train_time:112439ms step_avg:92.24ms +step:1220/1660 train_time:112534ms step_avg:92.24ms +step:1221/1660 train_time:112628ms step_avg:92.24ms +step:1222/1660 train_time:112720ms step_avg:92.24ms +step:1223/1660 train_time:112814ms step_avg:92.24ms +step:1224/1660 train_time:112908ms step_avg:92.25ms +step:1225/1660 train_time:113001ms step_avg:92.25ms +step:1226/1660 train_time:113095ms step_avg:92.25ms +step:1227/1660 train_time:113189ms step_avg:92.25ms +step:1228/1660 train_time:113282ms step_avg:92.25ms +step:1229/1660 train_time:113375ms step_avg:92.25ms +step:1230/1660 train_time:113469ms step_avg:92.25ms +step:1231/1660 train_time:113562ms step_avg:92.25ms +step:1232/1660 train_time:113655ms step_avg:92.25ms +step:1233/1660 train_time:113749ms step_avg:92.25ms +step:1234/1660 train_time:113842ms step_avg:92.25ms +step:1235/1660 train_time:113935ms step_avg:92.26ms +step:1236/1660 train_time:114029ms step_avg:92.26ms +step:1237/1660 train_time:114122ms step_avg:92.26ms +step:1238/1660 train_time:114215ms step_avg:92.26ms +step:1239/1660 train_time:114309ms step_avg:92.26ms +step:1240/1660 train_time:114403ms step_avg:92.26ms +step:1241/1660 train_time:114497ms step_avg:92.26ms +step:1242/1660 train_time:114591ms step_avg:92.26ms +step:1243/1660 train_time:114684ms step_avg:92.26ms +step:1244/1660 train_time:114778ms step_avg:92.27ms +step:1245/1660 train_time:114871ms step_avg:92.27ms +step:1246/1660 train_time:114964ms step_avg:92.27ms +step:1247/1660 train_time:115058ms step_avg:92.27ms +step:1248/1660 train_time:115151ms step_avg:92.27ms +step:1249/1660 train_time:115244ms step_avg:92.27ms +step:1250/1660 train_time:115338ms step_avg:92.27ms +step:1250/1660 val_loss:3.3750 train_time:115433ms step_avg:92.35ms +step:1251/1660 train_time:115455ms step_avg:92.29ms +step:1252/1660 train_time:115532ms step_avg:92.28ms +step:1253/1660 train_time:115629ms step_avg:92.28ms +step:1254/1660 train_time:115722ms step_avg:92.28ms +step:1255/1660 train_time:115814ms step_avg:92.28ms +step:1256/1660 train_time:115906ms step_avg:92.28ms +step:1257/1660 train_time:115999ms step_avg:92.28ms +step:1258/1660 train_time:116091ms step_avg:92.28ms +step:1259/1660 train_time:116183ms step_avg:92.28ms +step:1260/1660 train_time:116275ms step_avg:92.28ms +step:1261/1660 train_time:116369ms step_avg:92.28ms +step:1262/1660 train_time:116465ms step_avg:92.29ms +step:1263/1660 train_time:116562ms step_avg:92.29ms +step:1264/1660 train_time:116656ms step_avg:92.29ms +step:1265/1660 train_time:116751ms step_avg:92.29ms +step:1266/1660 train_time:116844ms step_avg:92.29ms +step:1267/1660 train_time:116937ms step_avg:92.29ms +step:1268/1660 train_time:117029ms step_avg:92.29ms +step:1269/1660 train_time:117122ms step_avg:92.29ms +step:1270/1660 train_time:117214ms step_avg:92.29ms +step:1271/1660 train_time:117307ms step_avg:92.29ms +step:1272/1660 train_time:117400ms step_avg:92.30ms +step:1273/1660 train_time:117496ms step_avg:92.30ms +step:1274/1660 train_time:117591ms step_avg:92.30ms +step:1275/1660 train_time:117684ms step_avg:92.30ms +step:1276/1660 train_time:117778ms step_avg:92.30ms +step:1277/1660 train_time:117871ms step_avg:92.30ms +step:1278/1660 train_time:117964ms step_avg:92.30ms +step:1279/1660 train_time:118056ms step_avg:92.30ms +step:1280/1660 train_time:118149ms step_avg:92.30ms +step:1281/1660 train_time:118241ms step_avg:92.30ms +step:1282/1660 train_time:118335ms step_avg:92.30ms +step:1283/1660 train_time:118430ms step_avg:92.31ms 
+step:1284/1660 train_time:118525ms step_avg:92.31ms +step:1285/1660 train_time:118619ms step_avg:92.31ms +step:1286/1660 train_time:118712ms step_avg:92.31ms +step:1287/1660 train_time:118805ms step_avg:92.31ms +step:1288/1660 train_time:118900ms step_avg:92.31ms +step:1289/1660 train_time:118993ms step_avg:92.31ms +step:1290/1660 train_time:119086ms step_avg:92.31ms +step:1291/1660 train_time:119179ms step_avg:92.31ms +step:1292/1660 train_time:119271ms step_avg:92.32ms +step:1293/1660 train_time:119365ms step_avg:92.32ms +step:1294/1660 train_time:119459ms step_avg:92.32ms +step:1295/1660 train_time:119553ms step_avg:92.32ms +step:1296/1660 train_time:119648ms step_avg:92.32ms +step:1297/1660 train_time:119742ms step_avg:92.32ms +step:1298/1660 train_time:119836ms step_avg:92.32ms +step:1299/1660 train_time:119929ms step_avg:92.32ms +step:1300/1660 train_time:120022ms step_avg:92.32ms +step:1301/1660 train_time:120115ms step_avg:92.33ms +step:1302/1660 train_time:120209ms step_avg:92.33ms +step:1303/1660 train_time:120302ms step_avg:92.33ms +step:1304/1660 train_time:120396ms step_avg:92.33ms +step:1305/1660 train_time:120490ms step_avg:92.33ms +step:1306/1660 train_time:120583ms step_avg:92.33ms +step:1307/1660 train_time:120677ms step_avg:92.33ms +step:1308/1660 train_time:120771ms step_avg:92.33ms +step:1309/1660 train_time:120865ms step_avg:92.33ms +step:1310/1660 train_time:120958ms step_avg:92.33ms +step:1311/1660 train_time:121051ms step_avg:92.33ms +step:1312/1660 train_time:121144ms step_avg:92.34ms +step:1313/1660 train_time:121237ms step_avg:92.34ms +step:1314/1660 train_time:121330ms step_avg:92.34ms +step:1315/1660 train_time:121424ms step_avg:92.34ms +step:1316/1660 train_time:121516ms step_avg:92.34ms +step:1317/1660 train_time:121610ms step_avg:92.34ms +step:1318/1660 train_time:121703ms step_avg:92.34ms +step:1319/1660 train_time:121797ms step_avg:92.34ms +step:1320/1660 train_time:121891ms step_avg:92.34ms +step:1321/1660 train_time:121985ms step_avg:92.34ms +step:1322/1660 train_time:122078ms step_avg:92.34ms +step:1323/1660 train_time:122171ms step_avg:92.34ms +step:1324/1660 train_time:122265ms step_avg:92.35ms +step:1325/1660 train_time:122359ms step_avg:92.35ms +step:1326/1660 train_time:122452ms step_avg:92.35ms +step:1327/1660 train_time:122545ms step_avg:92.35ms +step:1328/1660 train_time:122638ms step_avg:92.35ms +step:1329/1660 train_time:122733ms step_avg:92.35ms +step:1330/1660 train_time:122826ms step_avg:92.35ms +step:1331/1660 train_time:122920ms step_avg:92.35ms +step:1332/1660 train_time:123013ms step_avg:92.35ms +step:1333/1660 train_time:123107ms step_avg:92.35ms +step:1334/1660 train_time:123201ms step_avg:92.35ms +step:1335/1660 train_time:123294ms step_avg:92.36ms +step:1336/1660 train_time:123389ms step_avg:92.36ms +step:1337/1660 train_time:123482ms step_avg:92.36ms +step:1338/1660 train_time:123575ms step_avg:92.36ms +step:1339/1660 train_time:123669ms step_avg:92.36ms +step:1340/1660 train_time:123763ms step_avg:92.36ms +step:1341/1660 train_time:123856ms step_avg:92.36ms +step:1342/1660 train_time:123949ms step_avg:92.36ms +step:1343/1660 train_time:124042ms step_avg:92.36ms +step:1344/1660 train_time:124135ms step_avg:92.36ms +step:1345/1660 train_time:124228ms step_avg:92.36ms +step:1346/1660 train_time:124321ms step_avg:92.36ms +step:1347/1660 train_time:124414ms step_avg:92.36ms +step:1348/1660 train_time:124508ms step_avg:92.37ms +step:1349/1660 train_time:124603ms step_avg:92.37ms +step:1350/1660 train_time:124696ms step_avg:92.37ms 
+step:1351/1660 train_time:124791ms step_avg:92.37ms +step:1352/1660 train_time:124885ms step_avg:92.37ms +step:1353/1660 train_time:124978ms step_avg:92.37ms +step:1354/1660 train_time:125072ms step_avg:92.37ms +step:1355/1660 train_time:125165ms step_avg:92.37ms +step:1356/1660 train_time:125258ms step_avg:92.37ms +step:1357/1660 train_time:125352ms step_avg:92.37ms +step:1358/1660 train_time:125445ms step_avg:92.37ms +step:1359/1660 train_time:125538ms step_avg:92.38ms +step:1360/1660 train_time:125632ms step_avg:92.38ms +step:1361/1660 train_time:125726ms step_avg:92.38ms +step:1362/1660 train_time:125820ms step_avg:92.38ms +step:1363/1660 train_time:125913ms step_avg:92.38ms +step:1364/1660 train_time:126006ms step_avg:92.38ms +step:1365/1660 train_time:126100ms step_avg:92.38ms +step:1366/1660 train_time:126194ms step_avg:92.38ms +step:1367/1660 train_time:126286ms step_avg:92.38ms +step:1368/1660 train_time:126379ms step_avg:92.38ms +step:1369/1660 train_time:126472ms step_avg:92.38ms +step:1370/1660 train_time:126567ms step_avg:92.38ms +step:1371/1660 train_time:126660ms step_avg:92.39ms +step:1372/1660 train_time:126754ms step_avg:92.39ms +step:1373/1660 train_time:126848ms step_avg:92.39ms +step:1374/1660 train_time:126942ms step_avg:92.39ms +step:1375/1660 train_time:127035ms step_avg:92.39ms +step:1375/1660 val_loss:3.3400 train_time:127130ms step_avg:92.46ms +step:1376/1660 train_time:127150ms step_avg:92.41ms +step:1377/1660 train_time:127228ms step_avg:92.40ms +step:1378/1660 train_time:127324ms step_avg:92.40ms +step:1379/1660 train_time:127417ms step_avg:92.40ms +step:1380/1660 train_time:127509ms step_avg:92.40ms +step:1381/1660 train_time:127601ms step_avg:92.40ms +step:1382/1660 train_time:127693ms step_avg:92.40ms +step:1383/1660 train_time:127785ms step_avg:92.40ms +step:1384/1660 train_time:127877ms step_avg:92.40ms +step:1385/1660 train_time:127970ms step_avg:92.40ms +step:1386/1660 train_time:128064ms step_avg:92.40ms +step:1387/1660 train_time:128159ms step_avg:92.40ms +step:1388/1660 train_time:128255ms step_avg:92.40ms +step:1389/1660 train_time:128351ms step_avg:92.41ms +step:1390/1660 train_time:128445ms step_avg:92.41ms +step:1391/1660 train_time:128538ms step_avg:92.41ms +step:1392/1660 train_time:128631ms step_avg:92.41ms +step:1393/1660 train_time:128724ms step_avg:92.41ms +step:1394/1660 train_time:128816ms step_avg:92.41ms +step:1395/1660 train_time:128909ms step_avg:92.41ms +step:1396/1660 train_time:129002ms step_avg:92.41ms +step:1397/1660 train_time:129096ms step_avg:92.41ms +step:1398/1660 train_time:129192ms step_avg:92.41ms +step:1399/1660 train_time:129287ms step_avg:92.41ms +step:1400/1660 train_time:129381ms step_avg:92.42ms +step:1401/1660 train_time:129474ms step_avg:92.42ms +step:1402/1660 train_time:129567ms step_avg:92.42ms +step:1403/1660 train_time:129660ms step_avg:92.42ms +step:1404/1660 train_time:129753ms step_avg:92.42ms +step:1405/1660 train_time:129846ms step_avg:92.42ms +step:1406/1660 train_time:129938ms step_avg:92.42ms +step:1407/1660 train_time:130032ms step_avg:92.42ms +step:1408/1660 train_time:130126ms step_avg:92.42ms +step:1409/1660 train_time:130221ms step_avg:92.42ms +step:1410/1660 train_time:130316ms step_avg:92.42ms +step:1411/1660 train_time:130411ms step_avg:92.42ms +step:1412/1660 train_time:130504ms step_avg:92.43ms +step:1413/1660 train_time:130598ms step_avg:92.43ms +step:1414/1660 train_time:130691ms step_avg:92.43ms +step:1415/1660 train_time:130784ms step_avg:92.43ms +step:1416/1660 train_time:130877ms 
step_avg:92.43ms +step:1417/1660 train_time:130970ms step_avg:92.43ms +step:1418/1660 train_time:131064ms step_avg:92.43ms +step:1419/1660 train_time:131157ms step_avg:92.43ms +step:1420/1660 train_time:131252ms step_avg:92.43ms +step:1421/1660 train_time:131346ms step_avg:92.43ms +step:1422/1660 train_time:131440ms step_avg:92.43ms +step:1423/1660 train_time:131533ms step_avg:92.43ms +step:1424/1660 train_time:131626ms step_avg:92.43ms +step:1425/1660 train_time:131720ms step_avg:92.43ms +step:1426/1660 train_time:131814ms step_avg:92.44ms +step:1427/1660 train_time:131906ms step_avg:92.44ms +step:1428/1660 train_time:132000ms step_avg:92.44ms +step:1429/1660 train_time:132094ms step_avg:92.44ms +step:1430/1660 train_time:132187ms step_avg:92.44ms +step:1431/1660 train_time:132280ms step_avg:92.44ms +step:1432/1660 train_time:132374ms step_avg:92.44ms +step:1433/1660 train_time:132468ms step_avg:92.44ms +step:1434/1660 train_time:132561ms step_avg:92.44ms +step:1435/1660 train_time:132655ms step_avg:92.44ms +step:1436/1660 train_time:132748ms step_avg:92.44ms +step:1437/1660 train_time:132841ms step_avg:92.44ms +step:1438/1660 train_time:132934ms step_avg:92.44ms +step:1439/1660 train_time:133026ms step_avg:92.44ms +step:1440/1660 train_time:133120ms step_avg:92.44ms +step:1441/1660 train_time:133214ms step_avg:92.45ms +step:1442/1660 train_time:133308ms step_avg:92.45ms +step:1443/1660 train_time:133401ms step_avg:92.45ms +step:1444/1660 train_time:133495ms step_avg:92.45ms +step:1445/1660 train_time:133591ms step_avg:92.45ms +step:1446/1660 train_time:133686ms step_avg:92.45ms +step:1447/1660 train_time:133779ms step_avg:92.45ms +step:1448/1660 train_time:133872ms step_avg:92.45ms +step:1449/1660 train_time:133966ms step_avg:92.45ms +step:1450/1660 train_time:134060ms step_avg:92.46ms +step:1451/1660 train_time:134154ms step_avg:92.46ms +step:1452/1660 train_time:134248ms step_avg:92.46ms +step:1453/1660 train_time:134340ms step_avg:92.46ms +step:1454/1660 train_time:134434ms step_avg:92.46ms +step:1455/1660 train_time:134527ms step_avg:92.46ms +step:1456/1660 train_time:134622ms step_avg:92.46ms +step:1457/1660 train_time:134717ms step_avg:92.46ms +step:1458/1660 train_time:134810ms step_avg:92.46ms +step:1459/1660 train_time:134903ms step_avg:92.46ms +step:1460/1660 train_time:134996ms step_avg:92.46ms +step:1461/1660 train_time:135089ms step_avg:92.46ms +step:1462/1660 train_time:135183ms step_avg:92.46ms +step:1463/1660 train_time:135276ms step_avg:92.46ms +step:1464/1660 train_time:135370ms step_avg:92.47ms +step:1465/1660 train_time:135464ms step_avg:92.47ms +step:1466/1660 train_time:135558ms step_avg:92.47ms +step:1467/1660 train_time:135652ms step_avg:92.47ms +step:1468/1660 train_time:135745ms step_avg:92.47ms +step:1469/1660 train_time:135838ms step_avg:92.47ms +step:1470/1660 train_time:135932ms step_avg:92.47ms +step:1471/1660 train_time:136025ms step_avg:92.47ms +step:1472/1660 train_time:136119ms step_avg:92.47ms +step:1473/1660 train_time:136213ms step_avg:92.47ms +step:1474/1660 train_time:136305ms step_avg:92.47ms +step:1475/1660 train_time:136399ms step_avg:92.47ms +step:1476/1660 train_time:136493ms step_avg:92.47ms +step:1477/1660 train_time:136587ms step_avg:92.48ms +step:1478/1660 train_time:136680ms step_avg:92.48ms +step:1479/1660 train_time:136774ms step_avg:92.48ms +step:1480/1660 train_time:136868ms step_avg:92.48ms +step:1481/1660 train_time:136962ms step_avg:92.48ms +step:1482/1660 train_time:137055ms step_avg:92.48ms +step:1483/1660 train_time:137148ms 
step_avg:92.48ms +step:1484/1660 train_time:137241ms step_avg:92.48ms +step:1485/1660 train_time:137335ms step_avg:92.48ms +step:1486/1660 train_time:137429ms step_avg:92.48ms +step:1487/1660 train_time:137522ms step_avg:92.48ms +step:1488/1660 train_time:137615ms step_avg:92.48ms +step:1489/1660 train_time:137709ms step_avg:92.48ms +step:1490/1660 train_time:137802ms step_avg:92.48ms +step:1491/1660 train_time:137896ms step_avg:92.49ms +step:1492/1660 train_time:137990ms step_avg:92.49ms +step:1493/1660 train_time:138085ms step_avg:92.49ms +step:1494/1660 train_time:138178ms step_avg:92.49ms +step:1495/1660 train_time:138273ms step_avg:92.49ms +step:1496/1660 train_time:138366ms step_avg:92.49ms +step:1497/1660 train_time:138460ms step_avg:92.49ms +step:1498/1660 train_time:138554ms step_avg:92.49ms +step:1499/1660 train_time:138648ms step_avg:92.49ms +step:1500/1660 train_time:138741ms step_avg:92.49ms +step:1500/1660 val_loss:3.3101 train_time:138835ms step_avg:92.56ms +step:1501/1660 train_time:138856ms step_avg:92.51ms +step:1502/1660 train_time:138931ms step_avg:92.50ms +step:1503/1660 train_time:139028ms step_avg:92.50ms +step:1504/1660 train_time:139122ms step_avg:92.50ms +step:1505/1660 train_time:139215ms step_avg:92.50ms +step:1506/1660 train_time:139307ms step_avg:92.50ms +step:1507/1660 train_time:139400ms step_avg:92.50ms +step:1508/1660 train_time:139493ms step_avg:92.50ms +step:1509/1660 train_time:139587ms step_avg:92.50ms +step:1510/1660 train_time:139679ms step_avg:92.50ms +step:1511/1660 train_time:139773ms step_avg:92.50ms +step:1512/1660 train_time:139868ms step_avg:92.51ms +step:1513/1660 train_time:139962ms step_avg:92.51ms +step:1514/1660 train_time:140057ms step_avg:92.51ms +step:1515/1660 train_time:140149ms step_avg:92.51ms +step:1516/1660 train_time:140243ms step_avg:92.51ms +step:1517/1660 train_time:140335ms step_avg:92.51ms +step:1518/1660 train_time:140427ms step_avg:92.51ms +step:1519/1660 train_time:140520ms step_avg:92.51ms +step:1520/1660 train_time:140614ms step_avg:92.51ms +step:1521/1660 train_time:140707ms step_avg:92.51ms +step:1522/1660 train_time:140801ms step_avg:92.51ms +step:1523/1660 train_time:140895ms step_avg:92.51ms +step:1524/1660 train_time:140989ms step_avg:92.51ms +step:1525/1660 train_time:141083ms step_avg:92.51ms +step:1526/1660 train_time:141177ms step_avg:92.51ms +step:1527/1660 train_time:141270ms step_avg:92.51ms +step:1528/1660 train_time:141363ms step_avg:92.51ms +step:1529/1660 train_time:141456ms step_avg:92.52ms +step:1530/1660 train_time:141549ms step_avg:92.52ms +step:1531/1660 train_time:141642ms step_avg:92.52ms +step:1532/1660 train_time:141735ms step_avg:92.52ms +step:1533/1660 train_time:141829ms step_avg:92.52ms +step:1534/1660 train_time:141922ms step_avg:92.52ms +step:1535/1660 train_time:142017ms step_avg:92.52ms +step:1536/1660 train_time:142112ms step_avg:92.52ms +step:1537/1660 train_time:142207ms step_avg:92.52ms +step:1538/1660 train_time:142300ms step_avg:92.52ms +step:1539/1660 train_time:142394ms step_avg:92.52ms +step:1540/1660 train_time:142487ms step_avg:92.52ms +step:1541/1660 train_time:142580ms step_avg:92.52ms +step:1542/1660 train_time:142673ms step_avg:92.52ms +step:1543/1660 train_time:142767ms step_avg:92.53ms +step:1544/1660 train_time:142860ms step_avg:92.53ms +step:1545/1660 train_time:142955ms step_avg:92.53ms +step:1546/1660 train_time:143048ms step_avg:92.53ms +step:1547/1660 train_time:143141ms step_avg:92.53ms +step:1548/1660 train_time:143235ms step_avg:92.53ms +step:1549/1660 
train_time:143328ms step_avg:92.53ms +step:1550/1660 train_time:143421ms step_avg:92.53ms +step:1551/1660 train_time:143514ms step_avg:92.53ms +step:1552/1660 train_time:143608ms step_avg:92.53ms +step:1553/1660 train_time:143701ms step_avg:92.53ms +step:1554/1660 train_time:143794ms step_avg:92.53ms +step:1555/1660 train_time:143888ms step_avg:92.53ms +step:1556/1660 train_time:143982ms step_avg:92.53ms +step:1557/1660 train_time:144076ms step_avg:92.53ms +step:1558/1660 train_time:144170ms step_avg:92.54ms +step:1559/1660 train_time:144264ms step_avg:92.54ms +step:1560/1660 train_time:144357ms step_avg:92.54ms +step:1561/1660 train_time:144450ms step_avg:92.54ms +step:1562/1660 train_time:144543ms step_avg:92.54ms +step:1563/1660 train_time:144636ms step_avg:92.54ms +step:1564/1660 train_time:144729ms step_avg:92.54ms +step:1565/1660 train_time:144823ms step_avg:92.54ms +step:1566/1660 train_time:144916ms step_avg:92.54ms +step:1567/1660 train_time:145011ms step_avg:92.54ms +step:1568/1660 train_time:145104ms step_avg:92.54ms +step:1569/1660 train_time:145198ms step_avg:92.54ms +step:1570/1660 train_time:145293ms step_avg:92.54ms +step:1571/1660 train_time:145386ms step_avg:92.54ms +step:1572/1660 train_time:145478ms step_avg:92.54ms +step:1573/1660 train_time:145572ms step_avg:92.54ms +step:1574/1660 train_time:145665ms step_avg:92.54ms +step:1575/1660 train_time:145758ms step_avg:92.54ms +step:1576/1660 train_time:145851ms step_avg:92.55ms +step:1577/1660 train_time:145945ms step_avg:92.55ms +step:1578/1660 train_time:146038ms step_avg:92.55ms +step:1579/1660 train_time:146133ms step_avg:92.55ms +step:1580/1660 train_time:146226ms step_avg:92.55ms +step:1581/1660 train_time:146321ms step_avg:92.55ms +step:1582/1660 train_time:146414ms step_avg:92.55ms +step:1583/1660 train_time:146508ms step_avg:92.55ms +step:1584/1660 train_time:146601ms step_avg:92.55ms +step:1585/1660 train_time:146694ms step_avg:92.55ms +step:1586/1660 train_time:146787ms step_avg:92.55ms +step:1587/1660 train_time:146881ms step_avg:92.55ms +step:1588/1660 train_time:146975ms step_avg:92.55ms +step:1589/1660 train_time:147069ms step_avg:92.55ms +step:1590/1660 train_time:147162ms step_avg:92.55ms +step:1591/1660 train_time:147255ms step_avg:92.55ms +step:1592/1660 train_time:147348ms step_avg:92.56ms +step:1593/1660 train_time:147442ms step_avg:92.56ms +step:1594/1660 train_time:147535ms step_avg:92.56ms +step:1595/1660 train_time:147628ms step_avg:92.56ms +step:1596/1660 train_time:147721ms step_avg:92.56ms +step:1597/1660 train_time:147815ms step_avg:92.56ms +step:1598/1660 train_time:147908ms step_avg:92.56ms +step:1599/1660 train_time:148002ms step_avg:92.56ms +step:1600/1660 train_time:148095ms step_avg:92.56ms +step:1601/1660 train_time:148189ms step_avg:92.56ms +step:1602/1660 train_time:148282ms step_avg:92.56ms +step:1603/1660 train_time:148376ms step_avg:92.56ms +step:1604/1660 train_time:148469ms step_avg:92.56ms +step:1605/1660 train_time:148563ms step_avg:92.56ms +step:1606/1660 train_time:148655ms step_avg:92.56ms +step:1607/1660 train_time:148748ms step_avg:92.56ms +step:1608/1660 train_time:148842ms step_avg:92.56ms +step:1609/1660 train_time:148935ms step_avg:92.56ms +step:1610/1660 train_time:149028ms step_avg:92.56ms +step:1611/1660 train_time:149122ms step_avg:92.57ms +step:1612/1660 train_time:149217ms step_avg:92.57ms +step:1613/1660 train_time:149311ms step_avg:92.57ms +step:1614/1660 train_time:149405ms step_avg:92.57ms +step:1615/1660 train_time:149498ms step_avg:92.57ms +step:1616/1660 
train_time:149594ms step_avg:92.57ms +step:1617/1660 train_time:149687ms step_avg:92.57ms +step:1618/1660 train_time:149780ms step_avg:92.57ms +step:1619/1660 train_time:149874ms step_avg:92.57ms +step:1620/1660 train_time:149967ms step_avg:92.57ms +step:1621/1660 train_time:150060ms step_avg:92.57ms +step:1622/1660 train_time:150154ms step_avg:92.57ms +step:1623/1660 train_time:150247ms step_avg:92.57ms +step:1624/1660 train_time:150341ms step_avg:92.57ms +step:1625/1660 train_time:150435ms step_avg:92.58ms +step:1625/1660 val_loss:3.2852 train_time:150530ms step_avg:92.63ms +step:1626/1660 train_time:150550ms step_avg:92.59ms +step:1627/1660 train_time:150627ms step_avg:92.58ms +step:1628/1660 train_time:150723ms step_avg:92.58ms +step:1629/1660 train_time:150817ms step_avg:92.58ms +step:1630/1660 train_time:150910ms step_avg:92.58ms +step:1631/1660 train_time:151002ms step_avg:92.58ms +step:1632/1660 train_time:151095ms step_avg:92.58ms +step:1633/1660 train_time:151187ms step_avg:92.58ms +step:1634/1660 train_time:151279ms step_avg:92.58ms +step:1635/1660 train_time:151372ms step_avg:92.58ms +step:1636/1660 train_time:151466ms step_avg:92.58ms +step:1637/1660 train_time:151562ms step_avg:92.59ms +step:1638/1660 train_time:151660ms step_avg:92.59ms +step:1639/1660 train_time:151755ms step_avg:92.59ms +step:1640/1660 train_time:151848ms step_avg:92.59ms +step:1641/1660 train_time:151942ms step_avg:92.59ms +step:1642/1660 train_time:152036ms step_avg:92.59ms +step:1643/1660 train_time:152128ms step_avg:92.59ms +step:1644/1660 train_time:152220ms step_avg:92.59ms +step:1645/1660 train_time:152314ms step_avg:92.59ms +step:1646/1660 train_time:152408ms step_avg:92.59ms +step:1647/1660 train_time:152502ms step_avg:92.59ms +step:1648/1660 train_time:152596ms step_avg:92.59ms +step:1649/1660 train_time:152690ms step_avg:92.60ms +step:1650/1660 train_time:152784ms step_avg:92.60ms +step:1651/1660 train_time:152877ms step_avg:92.60ms +step:1652/1660 train_time:152972ms step_avg:92.60ms +step:1653/1660 train_time:153064ms step_avg:92.60ms +step:1654/1660 train_time:153157ms step_avg:92.60ms +step:1655/1660 train_time:153249ms step_avg:92.60ms +step:1656/1660 train_time:153342ms step_avg:92.60ms +step:1657/1660 train_time:153438ms step_avg:92.60ms +step:1658/1660 train_time:153533ms step_avg:92.60ms +step:1659/1660 train_time:153627ms step_avg:92.60ms +step:1660/1660 train_time:153721ms step_avg:92.60ms +step:1660/1660 val_loss:3.2770 train_time:153816ms step_avg:92.66ms +peak memory allocated: 32002 MiB reserved: 46836 MiB diff --git a/records/091525_ThreadingFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt b/records/091525_ThreadingFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt new file mode 100644 index 000000000..7f2a5d990 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned 
this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + 
C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if 
A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal.
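+        # The step runs in three passes so that collective communication overlaps with + # compute: (1) launch async reduce-scatters of the stacked per-group gradients, + # giving each rank the averaged gradients for its shard of parameters; (2) for the + # local shard, apply momentum, orthogonalize the batched updates via Newton-Schulz, + # and apply weight decay plus the scaled update; (3) launch async all-gathers of the + # updated shards and finally copy them back into the full parameter tensors.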
+        rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[i] # This gradient corresponds to the current parameter p.
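+                # The two lerp calls below implement Nesterov-style momentum: + # buf <- mu * buf + (1 - mu) * g, then update <- (1 - mu) * g + mu * buf; + # the Newton-Schulz orthogonalization is applied to `update`, not to `buf`.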
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                # Merge the first two dims into a single batch dim so the Triton NS
+                # kernel sees a 3D (batch, rows, cols) tensor.
+                batch = original_shape[0] * original_shape[1]
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Apply the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
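+        # (Each stacked_params holds padded_num_params tensors; only the first
+        # len(orig_params) entries are copied back below, so padding is discarded.)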
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    def __init__(self, file_iter, world_size):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        # scanning a whole shard for BOS positions is slow, so quickload scans a small
+        # prefix now and finishes the full scan on a background thread (see BOSFinder)
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; swap in the preloaded next shard and retry.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1660 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"data_threading_1660/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer is highly indifferent to length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + 
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:29:15 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 187426 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 187427 C /usr/bin/python3 614MiB | +| 0 N/A N/A 187428 C /usr/bin/python3 614MiB | +| 0 N/A N/A 187429 C /usr/bin/python3 614MiB | +| 0 N/A N/A 187430 C /usr/bin/python3 614MiB | +| 0 N/A N/A 187431 C /usr/bin/python3 614MiB | +| 0 N/A N/A 187432 C /usr/bin/python3 614MiB | +| 0 N/A N/A 187433 C /usr/bin/python3 614MiB | +| 1 N/A N/A 187427 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 187428 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 187429 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 187430 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 187431 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 187432 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 187433 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:160ms step_avg:160.47ms +step:2/1660 train_time:183ms step_avg:91.41ms +step:3/1660 train_time:248ms step_avg:82.58ms +step:4/1660 train_time:337ms step_avg:84.32ms +step:5/1660 train_time:427ms step_avg:85.47ms +step:6/1660 train_time:518ms step_avg:86.25ms +step:7/1660 train_time:608ms step_avg:86.84ms +step:8/1660 train_time:698ms step_avg:87.31ms +step:9/1660 train_time:789ms step_avg:87.64ms +step:10/1660 train_time:879ms step_avg:87.94ms +step:11/1660 train_time:970ms step_avg:88.20ms +step:12/1660 train_time:1066ms step_avg:88.80ms +step:13/1660 train_time:1162ms step_avg:89.37ms +step:14/1660 train_time:1255ms step_avg:89.62ms +step:15/1660 train_time:1347ms step_avg:89.77ms +step:16/1660 train_time:1438ms step_avg:89.87ms +step:17/1660 train_time:1529ms step_avg:89.95ms +step:18/1660 train_time:1620ms step_avg:90.01ms +step:19/1660 train_time:1711ms step_avg:90.05ms +step:20/1660 train_time:1803ms step_avg:90.13ms +step:21/1660 train_time:1894ms step_avg:90.18ms +step:22/1660 train_time:1986ms step_avg:90.28ms +step:23/1660 train_time:2081ms step_avg:90.46ms +step:24/1660 train_time:2173ms step_avg:90.56ms +step:25/1660 train_time:2266ms step_avg:90.66ms +step:26/1660 train_time:2360ms step_avg:90.75ms +step:27/1660 train_time:2451ms step_avg:90.78ms +step:28/1660 train_time:2543ms step_avg:90.81ms +step:29/1660 train_time:2634ms step_avg:90.82ms +step:30/1660 train_time:2725ms step_avg:90.84ms +step:31/1660 train_time:2816ms step_avg:90.83ms +step:32/1660 train_time:2907ms step_avg:90.84ms +step:33/1660 train_time:3000ms step_avg:90.90ms +step:34/1660 train_time:3093ms step_avg:90.96ms +step:35/1660 train_time:3187ms step_avg:91.05ms +step:36/1660 train_time:3279ms step_avg:91.09ms +step:37/1660 train_time:3371ms step_avg:91.12ms +step:38/1660 train_time:3464ms step_avg:91.16ms +step:39/1660 train_time:3555ms step_avg:91.16ms +step:40/1660 train_time:3646ms step_avg:91.15ms +step:41/1660 train_time:3738ms step_avg:91.16ms +step:42/1660 train_time:3829ms step_avg:91.16ms +step:43/1660 train_time:3919ms step_avg:91.15ms +step:44/1660 train_time:4011ms step_avg:91.15ms +step:45/1660 train_time:4104ms step_avg:91.20ms +step:46/1660 train_time:4197ms step_avg:91.23ms +step:47/1660 train_time:4289ms step_avg:91.26ms +step:48/1660 train_time:4382ms step_avg:91.29ms +step:49/1660 train_time:4474ms step_avg:91.30ms +step:50/1660 train_time:4566ms step_avg:91.32ms 
+step:51/1660 train_time:4658ms step_avg:91.33ms +step:52/1660 train_time:4749ms step_avg:91.33ms +step:53/1660 train_time:4841ms step_avg:91.34ms +step:54/1660 train_time:4932ms step_avg:91.34ms +step:55/1660 train_time:5024ms step_avg:91.34ms +step:56/1660 train_time:5116ms step_avg:91.35ms +step:57/1660 train_time:5208ms step_avg:91.36ms +step:58/1660 train_time:5300ms step_avg:91.39ms +step:59/1660 train_time:5392ms step_avg:91.39ms +step:60/1660 train_time:5484ms step_avg:91.40ms +step:61/1660 train_time:5576ms step_avg:91.41ms +step:62/1660 train_time:5668ms step_avg:91.41ms +step:63/1660 train_time:5759ms step_avg:91.41ms +step:64/1660 train_time:5850ms step_avg:91.41ms +step:65/1660 train_time:5942ms step_avg:91.41ms +step:66/1660 train_time:6033ms step_avg:91.41ms +step:67/1660 train_time:6125ms step_avg:91.42ms +step:68/1660 train_time:6217ms step_avg:91.43ms +step:69/1660 train_time:6309ms step_avg:91.43ms +step:70/1660 train_time:6401ms step_avg:91.44ms +step:71/1660 train_time:6493ms step_avg:91.45ms +step:72/1660 train_time:6586ms step_avg:91.47ms +step:73/1660 train_time:6678ms step_avg:91.47ms +step:74/1660 train_time:6769ms step_avg:91.47ms +step:75/1660 train_time:6860ms step_avg:91.47ms +step:76/1660 train_time:6951ms step_avg:91.46ms +step:77/1660 train_time:7042ms step_avg:91.46ms +step:78/1660 train_time:7134ms step_avg:91.46ms +step:79/1660 train_time:7226ms step_avg:91.46ms +step:80/1660 train_time:7317ms step_avg:91.46ms +step:81/1660 train_time:7409ms step_avg:91.47ms +step:82/1660 train_time:7501ms step_avg:91.48ms +step:83/1660 train_time:7592ms step_avg:91.47ms +step:84/1660 train_time:7684ms step_avg:91.47ms +step:85/1660 train_time:7776ms step_avg:91.48ms +step:86/1660 train_time:7867ms step_avg:91.48ms +step:87/1660 train_time:7959ms step_avg:91.48ms +step:88/1660 train_time:8050ms step_avg:91.48ms +step:89/1660 train_time:8142ms step_avg:91.48ms +step:90/1660 train_time:8234ms step_avg:91.49ms +step:91/1660 train_time:8326ms step_avg:91.50ms +step:92/1660 train_time:8417ms step_avg:91.49ms +step:93/1660 train_time:8509ms step_avg:91.49ms +step:94/1660 train_time:8601ms step_avg:91.49ms +step:95/1660 train_time:8692ms step_avg:91.50ms +step:96/1660 train_time:8783ms step_avg:91.49ms +step:97/1660 train_time:8875ms step_avg:91.50ms +step:98/1660 train_time:8967ms step_avg:91.50ms +step:99/1660 train_time:9059ms step_avg:91.51ms +step:100/1660 train_time:9151ms step_avg:91.51ms +step:101/1660 train_time:9243ms step_avg:91.51ms +step:102/1660 train_time:9334ms step_avg:91.51ms +step:103/1660 train_time:9426ms step_avg:91.51ms +step:104/1660 train_time:9518ms step_avg:91.52ms +step:105/1660 train_time:9609ms step_avg:91.52ms +step:106/1660 train_time:9701ms step_avg:91.52ms +step:107/1660 train_time:9793ms step_avg:91.52ms +step:108/1660 train_time:9885ms step_avg:91.53ms +step:109/1660 train_time:9977ms step_avg:91.53ms +step:110/1660 train_time:10069ms step_avg:91.53ms +step:111/1660 train_time:10160ms step_avg:91.53ms +step:112/1660 train_time:10252ms step_avg:91.53ms +step:113/1660 train_time:10344ms step_avg:91.54ms +step:114/1660 train_time:10436ms step_avg:91.55ms +step:115/1660 train_time:10528ms step_avg:91.55ms +step:116/1660 train_time:10619ms step_avg:91.54ms +step:117/1660 train_time:10711ms step_avg:91.54ms +step:118/1660 train_time:10803ms step_avg:91.55ms +step:119/1660 train_time:10896ms step_avg:91.56ms +step:120/1660 train_time:10988ms step_avg:91.57ms +step:121/1660 train_time:11081ms step_avg:91.57ms +step:122/1660 train_time:11172ms 
step_avg:91.57ms +step:123/1660 train_time:11263ms step_avg:91.57ms +step:124/1660 train_time:11354ms step_avg:91.57ms +step:125/1660 train_time:11445ms step_avg:91.56ms +step:125/1660 val_loss:4.3086 train_time:11539ms step_avg:92.32ms +step:126/1660 train_time:11562ms step_avg:91.76ms +step:127/1660 train_time:11638ms step_avg:91.64ms +step:128/1660 train_time:11741ms step_avg:91.73ms +step:129/1660 train_time:11836ms step_avg:91.75ms +step:130/1660 train_time:11928ms step_avg:91.75ms +step:131/1660 train_time:12018ms step_avg:91.74ms +step:132/1660 train_time:12108ms step_avg:91.73ms +step:133/1660 train_time:12198ms step_avg:91.72ms +step:134/1660 train_time:12288ms step_avg:91.71ms +step:135/1660 train_time:12379ms step_avg:91.69ms +step:136/1660 train_time:12469ms step_avg:91.68ms +step:137/1660 train_time:12559ms step_avg:91.67ms +step:138/1660 train_time:12653ms step_avg:91.69ms +step:139/1660 train_time:12748ms step_avg:91.71ms +step:140/1660 train_time:12839ms step_avg:91.71ms +step:141/1660 train_time:12931ms step_avg:91.71ms +step:142/1660 train_time:13024ms step_avg:91.72ms +step:143/1660 train_time:13115ms step_avg:91.72ms +step:144/1660 train_time:13206ms step_avg:91.71ms +step:145/1660 train_time:13297ms step_avg:91.70ms +step:146/1660 train_time:13387ms step_avg:91.69ms +step:147/1660 train_time:13477ms step_avg:91.68ms +step:148/1660 train_time:13568ms step_avg:91.68ms +step:149/1660 train_time:13660ms step_avg:91.68ms +step:150/1660 train_time:13754ms step_avg:91.69ms +step:151/1660 train_time:13846ms step_avg:91.70ms +step:152/1660 train_time:13938ms step_avg:91.70ms +step:153/1660 train_time:14031ms step_avg:91.71ms +step:154/1660 train_time:14122ms step_avg:91.70ms +step:155/1660 train_time:14214ms step_avg:91.70ms +step:156/1660 train_time:14305ms step_avg:91.70ms +step:157/1660 train_time:14396ms step_avg:91.69ms +step:158/1660 train_time:14487ms step_avg:91.69ms +step:159/1660 train_time:14577ms step_avg:91.68ms +step:160/1660 train_time:14669ms step_avg:91.68ms +step:161/1660 train_time:14761ms step_avg:91.68ms +step:162/1660 train_time:14855ms step_avg:91.70ms +step:163/1660 train_time:14947ms step_avg:91.70ms +step:164/1660 train_time:15039ms step_avg:91.70ms +step:165/1660 train_time:15131ms step_avg:91.70ms +step:166/1660 train_time:15222ms step_avg:91.70ms +step:167/1660 train_time:15314ms step_avg:91.70ms +step:168/1660 train_time:15405ms step_avg:91.70ms +step:169/1660 train_time:15496ms step_avg:91.69ms +step:170/1660 train_time:15588ms step_avg:91.69ms +step:171/1660 train_time:15679ms step_avg:91.69ms +step:172/1660 train_time:15772ms step_avg:91.70ms +step:173/1660 train_time:15863ms step_avg:91.69ms +step:174/1660 train_time:15956ms step_avg:91.70ms +step:175/1660 train_time:16048ms step_avg:91.70ms +step:176/1660 train_time:16139ms step_avg:91.70ms +step:177/1660 train_time:16230ms step_avg:91.70ms +step:178/1660 train_time:16322ms step_avg:91.69ms +step:179/1660 train_time:16414ms step_avg:91.70ms +step:180/1660 train_time:16505ms step_avg:91.70ms +step:181/1660 train_time:16596ms step_avg:91.69ms +step:182/1660 train_time:16688ms step_avg:91.69ms +step:183/1660 train_time:16780ms step_avg:91.69ms +step:184/1660 train_time:16872ms step_avg:91.69ms +step:185/1660 train_time:16962ms step_avg:91.69ms +step:186/1660 train_time:17056ms step_avg:91.70ms +step:187/1660 train_time:17148ms step_avg:91.70ms +step:188/1660 train_time:17240ms step_avg:91.70ms +step:189/1660 train_time:17331ms step_avg:91.70ms +step:190/1660 train_time:17422ms step_avg:91.69ms 
+step:191/1660 train_time:17513ms step_avg:91.69ms +step:192/1660 train_time:17604ms step_avg:91.69ms +step:193/1660 train_time:17696ms step_avg:91.69ms +step:194/1660 train_time:17787ms step_avg:91.68ms +step:195/1660 train_time:17878ms step_avg:91.68ms +step:196/1660 train_time:17970ms step_avg:91.68ms +step:197/1660 train_time:18062ms step_avg:91.69ms +step:198/1660 train_time:18155ms step_avg:91.69ms +step:199/1660 train_time:18248ms step_avg:91.70ms +step:200/1660 train_time:18339ms step_avg:91.69ms +step:201/1660 train_time:18430ms step_avg:91.69ms +step:202/1660 train_time:18521ms step_avg:91.69ms +step:203/1660 train_time:18613ms step_avg:91.69ms +step:204/1660 train_time:18704ms step_avg:91.69ms +step:205/1660 train_time:18795ms step_avg:91.68ms +step:206/1660 train_time:18888ms step_avg:91.69ms +step:207/1660 train_time:18979ms step_avg:91.69ms +step:208/1660 train_time:19071ms step_avg:91.69ms +step:209/1660 train_time:19162ms step_avg:91.68ms +step:210/1660 train_time:19254ms step_avg:91.69ms +step:211/1660 train_time:19347ms step_avg:91.69ms +step:212/1660 train_time:19439ms step_avg:91.69ms +step:213/1660 train_time:19530ms step_avg:91.69ms +step:214/1660 train_time:19621ms step_avg:91.69ms +step:215/1660 train_time:19714ms step_avg:91.69ms +step:216/1660 train_time:19805ms step_avg:91.69ms +step:217/1660 train_time:19896ms step_avg:91.69ms +step:218/1660 train_time:19988ms step_avg:91.69ms +step:219/1660 train_time:20079ms step_avg:91.69ms +step:220/1660 train_time:20170ms step_avg:91.68ms +step:221/1660 train_time:20263ms step_avg:91.69ms +step:222/1660 train_time:20356ms step_avg:91.69ms +step:223/1660 train_time:20448ms step_avg:91.70ms +step:224/1660 train_time:20539ms step_avg:91.69ms +step:225/1660 train_time:20631ms step_avg:91.69ms +step:226/1660 train_time:20722ms step_avg:91.69ms +step:227/1660 train_time:20814ms step_avg:91.69ms +step:228/1660 train_time:20905ms step_avg:91.69ms +step:229/1660 train_time:20997ms step_avg:91.69ms +step:230/1660 train_time:21088ms step_avg:91.69ms +step:231/1660 train_time:21180ms step_avg:91.69ms +step:232/1660 train_time:21272ms step_avg:91.69ms +step:233/1660 train_time:21363ms step_avg:91.69ms +step:234/1660 train_time:21456ms step_avg:91.69ms +step:235/1660 train_time:21549ms step_avg:91.70ms +step:236/1660 train_time:21639ms step_avg:91.69ms +step:237/1660 train_time:21730ms step_avg:91.69ms +step:238/1660 train_time:21821ms step_avg:91.69ms +step:239/1660 train_time:21912ms step_avg:91.68ms +step:240/1660 train_time:22003ms step_avg:91.68ms +step:241/1660 train_time:22094ms step_avg:91.68ms +step:242/1660 train_time:22186ms step_avg:91.68ms +step:243/1660 train_time:22277ms step_avg:91.68ms +step:244/1660 train_time:22369ms step_avg:91.67ms +step:245/1660 train_time:22461ms step_avg:91.68ms +step:246/1660 train_time:22555ms step_avg:91.69ms +step:247/1660 train_time:22647ms step_avg:91.69ms +step:248/1660 train_time:22738ms step_avg:91.68ms +step:249/1660 train_time:22829ms step_avg:91.68ms +step:250/1660 train_time:22920ms step_avg:91.68ms +step:250/1660 val_loss:3.9601 train_time:23013ms step_avg:92.05ms +step:251/1660 train_time:23034ms step_avg:91.77ms +step:252/1660 train_time:23107ms step_avg:91.69ms +step:253/1660 train_time:23205ms step_avg:91.72ms +step:254/1660 train_time:23299ms step_avg:91.73ms +step:255/1660 train_time:23390ms step_avg:91.73ms +step:256/1660 train_time:23481ms step_avg:91.72ms +step:257/1660 train_time:23571ms step_avg:91.72ms +step:258/1660 train_time:23661ms step_avg:91.71ms +step:259/1660 
train_time:23752ms step_avg:91.71ms +step:260/1660 train_time:23842ms step_avg:91.70ms +step:261/1660 train_time:23933ms step_avg:91.70ms +step:262/1660 train_time:24027ms step_avg:91.70ms +step:263/1660 train_time:24120ms step_avg:91.71ms +step:264/1660 train_time:24214ms step_avg:91.72ms +step:265/1660 train_time:24307ms step_avg:91.72ms +step:266/1660 train_time:24398ms step_avg:91.72ms +step:267/1660 train_time:24490ms step_avg:91.72ms +step:268/1660 train_time:24580ms step_avg:91.72ms +step:269/1660 train_time:24672ms step_avg:91.72ms +step:270/1660 train_time:24763ms step_avg:91.71ms +step:271/1660 train_time:24854ms step_avg:91.71ms +step:272/1660 train_time:24945ms step_avg:91.71ms +step:273/1660 train_time:25037ms step_avg:91.71ms +step:274/1660 train_time:25129ms step_avg:91.71ms +step:275/1660 train_time:25221ms step_avg:91.71ms +step:276/1660 train_time:25313ms step_avg:91.71ms +step:277/1660 train_time:25405ms step_avg:91.71ms +step:278/1660 train_time:25497ms step_avg:91.72ms +step:279/1660 train_time:25589ms step_avg:91.72ms +step:280/1660 train_time:25679ms step_avg:91.71ms +step:281/1660 train_time:25770ms step_avg:91.71ms +step:282/1660 train_time:25862ms step_avg:91.71ms +step:283/1660 train_time:25953ms step_avg:91.71ms +step:284/1660 train_time:26046ms step_avg:91.71ms +step:285/1660 train_time:26138ms step_avg:91.71ms +step:286/1660 train_time:26230ms step_avg:91.71ms +step:287/1660 train_time:26321ms step_avg:91.71ms +step:288/1660 train_time:26413ms step_avg:91.71ms +step:289/1660 train_time:26505ms step_avg:91.71ms +step:290/1660 train_time:26596ms step_avg:91.71ms +step:291/1660 train_time:26688ms step_avg:91.71ms +step:292/1660 train_time:26778ms step_avg:91.71ms +step:293/1660 train_time:26869ms step_avg:91.70ms +step:294/1660 train_time:26961ms step_avg:91.70ms +step:295/1660 train_time:27053ms step_avg:91.70ms +step:296/1660 train_time:27144ms step_avg:91.70ms +step:297/1660 train_time:27236ms step_avg:91.70ms +step:298/1660 train_time:27327ms step_avg:91.70ms +step:299/1660 train_time:27419ms step_avg:91.70ms +step:300/1660 train_time:27511ms step_avg:91.70ms +step:301/1660 train_time:27602ms step_avg:91.70ms +step:302/1660 train_time:27694ms step_avg:91.70ms +step:303/1660 train_time:27786ms step_avg:91.70ms +step:304/1660 train_time:27877ms step_avg:91.70ms +step:305/1660 train_time:27969ms step_avg:91.70ms +step:306/1660 train_time:28060ms step_avg:91.70ms +step:307/1660 train_time:28151ms step_avg:91.70ms +step:308/1660 train_time:28242ms step_avg:91.70ms +step:309/1660 train_time:28334ms step_avg:91.70ms +step:310/1660 train_time:28426ms step_avg:91.70ms +step:311/1660 train_time:28517ms step_avg:91.69ms +step:312/1660 train_time:28608ms step_avg:91.69ms +step:313/1660 train_time:28699ms step_avg:91.69ms +step:314/1660 train_time:28792ms step_avg:91.69ms +step:315/1660 train_time:28884ms step_avg:91.70ms +step:316/1660 train_time:28976ms step_avg:91.69ms +step:317/1660 train_time:29068ms step_avg:91.70ms +step:318/1660 train_time:29160ms step_avg:91.70ms +step:319/1660 train_time:29251ms step_avg:91.70ms +step:320/1660 train_time:29342ms step_avg:91.69ms +step:321/1660 train_time:29435ms step_avg:91.70ms +step:322/1660 train_time:29527ms step_avg:91.70ms +step:323/1660 train_time:29618ms step_avg:91.70ms +step:324/1660 train_time:29709ms step_avg:91.69ms +step:325/1660 train_time:29800ms step_avg:91.69ms +step:326/1660 train_time:29893ms step_avg:91.70ms +step:327/1660 train_time:29985ms step_avg:91.70ms +step:328/1660 train_time:30076ms step_avg:91.70ms 
+step:329/1660 train_time:30168ms step_avg:91.69ms +step:330/1660 train_time:30259ms step_avg:91.69ms +step:331/1660 train_time:30352ms step_avg:91.70ms +step:332/1660 train_time:30443ms step_avg:91.70ms +step:333/1660 train_time:30534ms step_avg:91.70ms +step:334/1660 train_time:30627ms step_avg:91.70ms +step:335/1660 train_time:30718ms step_avg:91.69ms +step:336/1660 train_time:30809ms step_avg:91.69ms +step:337/1660 train_time:30901ms step_avg:91.69ms +step:338/1660 train_time:30993ms step_avg:91.70ms +step:339/1660 train_time:31085ms step_avg:91.70ms +step:340/1660 train_time:31177ms step_avg:91.70ms +step:341/1660 train_time:31268ms step_avg:91.70ms +step:342/1660 train_time:31360ms step_avg:91.70ms +step:343/1660 train_time:31452ms step_avg:91.70ms +step:344/1660 train_time:31544ms step_avg:91.70ms +step:345/1660 train_time:31636ms step_avg:91.70ms +step:346/1660 train_time:31728ms step_avg:91.70ms +step:347/1660 train_time:31819ms step_avg:91.70ms +step:348/1660 train_time:31911ms step_avg:91.70ms +step:349/1660 train_time:32003ms step_avg:91.70ms +step:350/1660 train_time:32095ms step_avg:91.70ms +step:351/1660 train_time:32188ms step_avg:91.70ms +step:352/1660 train_time:32279ms step_avg:91.70ms +step:353/1660 train_time:32371ms step_avg:91.70ms +step:354/1660 train_time:32463ms step_avg:91.70ms +step:355/1660 train_time:32556ms step_avg:91.71ms +step:356/1660 train_time:32647ms step_avg:91.71ms +step:357/1660 train_time:32738ms step_avg:91.70ms +step:358/1660 train_time:32830ms step_avg:91.70ms +step:359/1660 train_time:32921ms step_avg:91.70ms +step:360/1660 train_time:33012ms step_avg:91.70ms +step:361/1660 train_time:33104ms step_avg:91.70ms +step:362/1660 train_time:33197ms step_avg:91.70ms +step:363/1660 train_time:33288ms step_avg:91.70ms +step:364/1660 train_time:33379ms step_avg:91.70ms +step:365/1660 train_time:33471ms step_avg:91.70ms +step:366/1660 train_time:33564ms step_avg:91.71ms +step:367/1660 train_time:33656ms step_avg:91.71ms +step:368/1660 train_time:33747ms step_avg:91.70ms +step:369/1660 train_time:33838ms step_avg:91.70ms +step:370/1660 train_time:33930ms step_avg:91.70ms +step:371/1660 train_time:34021ms step_avg:91.70ms +step:372/1660 train_time:34113ms step_avg:91.70ms +step:373/1660 train_time:34205ms step_avg:91.70ms +step:374/1660 train_time:34297ms step_avg:91.70ms +step:375/1660 train_time:34388ms step_avg:91.70ms +step:375/1660 val_loss:3.8115 train_time:34481ms step_avg:91.95ms +step:376/1660 train_time:34502ms step_avg:91.76ms +step:377/1660 train_time:34577ms step_avg:91.72ms +step:378/1660 train_time:34676ms step_avg:91.74ms +step:379/1660 train_time:34770ms step_avg:91.74ms +step:380/1660 train_time:34861ms step_avg:91.74ms +step:381/1660 train_time:34952ms step_avg:91.74ms +step:382/1660 train_time:35042ms step_avg:91.73ms +step:383/1660 train_time:35132ms step_avg:91.73ms +step:384/1660 train_time:35222ms step_avg:91.72ms +step:385/1660 train_time:35313ms step_avg:91.72ms +step:386/1660 train_time:35403ms step_avg:91.72ms +step:387/1660 train_time:35495ms step_avg:91.72ms +step:388/1660 train_time:35588ms step_avg:91.72ms +step:389/1660 train_time:35682ms step_avg:91.73ms +step:390/1660 train_time:35775ms step_avg:91.73ms +step:391/1660 train_time:35867ms step_avg:91.73ms +step:392/1660 train_time:35958ms step_avg:91.73ms +step:393/1660 train_time:36049ms step_avg:91.73ms +step:394/1660 train_time:36140ms step_avg:91.73ms +step:395/1660 train_time:36232ms step_avg:91.73ms +step:396/1660 train_time:36322ms step_avg:91.72ms +step:397/1660 
train_time:36414ms step_avg:91.72ms +step:398/1660 train_time:36505ms step_avg:91.72ms +step:399/1660 train_time:36599ms step_avg:91.73ms +step:400/1660 train_time:36692ms step_avg:91.73ms +step:401/1660 train_time:36783ms step_avg:91.73ms +step:402/1660 train_time:36875ms step_avg:91.73ms +step:403/1660 train_time:36966ms step_avg:91.73ms +step:404/1660 train_time:37057ms step_avg:91.73ms +step:405/1660 train_time:37147ms step_avg:91.72ms +step:406/1660 train_time:37239ms step_avg:91.72ms +step:407/1660 train_time:37330ms step_avg:91.72ms +step:408/1660 train_time:37423ms step_avg:91.72ms +step:409/1660 train_time:37516ms step_avg:91.73ms +step:410/1660 train_time:37608ms step_avg:91.73ms +step:411/1660 train_time:37700ms step_avg:91.73ms +step:412/1660 train_time:37793ms step_avg:91.73ms +step:413/1660 train_time:37884ms step_avg:91.73ms +step:414/1660 train_time:37976ms step_avg:91.73ms +step:415/1660 train_time:38067ms step_avg:91.73ms +step:416/1660 train_time:38158ms step_avg:91.72ms +step:417/1660 train_time:38249ms step_avg:91.72ms +step:418/1660 train_time:38340ms step_avg:91.72ms +step:419/1660 train_time:38432ms step_avg:91.72ms +step:420/1660 train_time:38524ms step_avg:91.72ms +step:421/1660 train_time:38617ms step_avg:91.73ms +step:422/1660 train_time:38710ms step_avg:91.73ms +step:423/1660 train_time:38802ms step_avg:91.73ms +step:424/1660 train_time:38894ms step_avg:91.73ms +step:425/1660 train_time:38985ms step_avg:91.73ms +step:426/1660 train_time:39076ms step_avg:91.73ms +step:427/1660 train_time:39166ms step_avg:91.72ms +step:428/1660 train_time:39257ms step_avg:91.72ms +step:429/1660 train_time:39348ms step_avg:91.72ms +step:430/1660 train_time:39440ms step_avg:91.72ms +step:431/1660 train_time:39533ms step_avg:91.72ms +step:432/1660 train_time:39624ms step_avg:91.72ms +step:433/1660 train_time:39717ms step_avg:91.72ms +step:434/1660 train_time:39809ms step_avg:91.73ms +step:435/1660 train_time:39900ms step_avg:91.72ms +step:436/1660 train_time:39991ms step_avg:91.72ms +step:437/1660 train_time:40082ms step_avg:91.72ms +step:438/1660 train_time:40172ms step_avg:91.72ms +step:439/1660 train_time:40263ms step_avg:91.72ms +step:440/1660 train_time:40355ms step_avg:91.72ms +step:441/1660 train_time:40446ms step_avg:91.71ms +step:442/1660 train_time:40539ms step_avg:91.72ms +step:443/1660 train_time:40631ms step_avg:91.72ms +step:444/1660 train_time:40723ms step_avg:91.72ms +step:445/1660 train_time:40815ms step_avg:91.72ms +step:446/1660 train_time:40906ms step_avg:91.72ms +step:447/1660 train_time:40999ms step_avg:91.72ms +step:448/1660 train_time:41091ms step_avg:91.72ms +step:449/1660 train_time:41181ms step_avg:91.72ms +step:450/1660 train_time:41272ms step_avg:91.71ms +step:451/1660 train_time:41362ms step_avg:91.71ms +step:452/1660 train_time:41453ms step_avg:91.71ms +step:453/1660 train_time:41545ms step_avg:91.71ms +step:454/1660 train_time:41637ms step_avg:91.71ms +step:455/1660 train_time:41730ms step_avg:91.71ms +step:456/1660 train_time:41821ms step_avg:91.71ms +step:457/1660 train_time:41913ms step_avg:91.71ms +step:458/1660 train_time:42005ms step_avg:91.71ms +step:459/1660 train_time:42096ms step_avg:91.71ms +step:460/1660 train_time:42187ms step_avg:91.71ms +step:461/1660 train_time:42278ms step_avg:91.71ms +step:462/1660 train_time:42370ms step_avg:91.71ms +step:463/1660 train_time:42461ms step_avg:91.71ms +step:464/1660 train_time:42553ms step_avg:91.71ms +step:465/1660 train_time:42644ms step_avg:91.71ms +step:466/1660 train_time:42737ms step_avg:91.71ms 
+step:467/1660 train_time:42830ms step_avg:91.71ms +step:468/1660 train_time:42921ms step_avg:91.71ms +step:469/1660 train_time:43013ms step_avg:91.71ms +step:470/1660 train_time:43105ms step_avg:91.71ms +step:471/1660 train_time:43197ms step_avg:91.71ms +step:472/1660 train_time:43288ms step_avg:91.71ms +step:473/1660 train_time:43378ms step_avg:91.71ms +step:474/1660 train_time:43469ms step_avg:91.71ms +step:475/1660 train_time:43561ms step_avg:91.71ms +step:476/1660 train_time:43651ms step_avg:91.70ms +step:477/1660 train_time:43742ms step_avg:91.70ms +step:478/1660 train_time:43836ms step_avg:91.71ms +step:479/1660 train_time:43926ms step_avg:91.70ms +step:480/1660 train_time:44019ms step_avg:91.71ms +step:481/1660 train_time:44111ms step_avg:91.71ms +step:482/1660 train_time:44202ms step_avg:91.71ms +step:483/1660 train_time:44294ms step_avg:91.71ms +step:484/1660 train_time:44384ms step_avg:91.70ms +step:485/1660 train_time:44475ms step_avg:91.70ms +step:486/1660 train_time:44567ms step_avg:91.70ms +step:487/1660 train_time:44658ms step_avg:91.70ms +step:488/1660 train_time:44749ms step_avg:91.70ms +step:489/1660 train_time:44841ms step_avg:91.70ms +step:490/1660 train_time:44933ms step_avg:91.70ms +step:491/1660 train_time:45025ms step_avg:91.70ms +step:492/1660 train_time:45118ms step_avg:91.70ms +step:493/1660 train_time:45211ms step_avg:91.70ms +step:494/1660 train_time:45301ms step_avg:91.70ms +step:495/1660 train_time:45392ms step_avg:91.70ms +step:496/1660 train_time:45482ms step_avg:91.70ms +step:497/1660 train_time:45573ms step_avg:91.70ms +step:498/1660 train_time:45664ms step_avg:91.69ms +step:499/1660 train_time:45755ms step_avg:91.69ms +step:500/1660 train_time:45847ms step_avg:91.69ms +step:500/1660 val_loss:3.7117 train_time:45940ms step_avg:91.88ms +step:501/1660 train_time:45962ms step_avg:91.74ms +step:502/1660 train_time:46036ms step_avg:91.70ms +step:503/1660 train_time:46134ms step_avg:91.72ms +step:504/1660 train_time:46226ms step_avg:91.72ms +step:505/1660 train_time:46317ms step_avg:91.72ms +step:506/1660 train_time:46407ms step_avg:91.71ms +step:507/1660 train_time:46497ms step_avg:91.71ms +step:508/1660 train_time:46587ms step_avg:91.71ms +step:509/1660 train_time:46678ms step_avg:91.70ms +step:510/1660 train_time:46768ms step_avg:91.70ms +step:511/1660 train_time:46858ms step_avg:91.70ms +step:512/1660 train_time:46950ms step_avg:91.70ms +step:513/1660 train_time:47043ms step_avg:91.70ms +step:514/1660 train_time:47137ms step_avg:91.71ms +step:515/1660 train_time:47230ms step_avg:91.71ms +step:516/1660 train_time:47322ms step_avg:91.71ms +step:517/1660 train_time:47414ms step_avg:91.71ms +step:518/1660 train_time:47504ms step_avg:91.71ms +step:519/1660 train_time:47595ms step_avg:91.70ms +step:520/1660 train_time:47684ms step_avg:91.70ms +step:521/1660 train_time:47775ms step_avg:91.70ms +step:522/1660 train_time:47866ms step_avg:91.70ms +step:523/1660 train_time:47959ms step_avg:91.70ms +step:524/1660 train_time:48051ms step_avg:91.70ms +step:525/1660 train_time:48143ms step_avg:91.70ms +step:526/1660 train_time:48236ms step_avg:91.70ms +step:527/1660 train_time:48327ms step_avg:91.70ms +step:528/1660 train_time:48419ms step_avg:91.70ms +step:529/1660 train_time:48510ms step_avg:91.70ms +step:530/1660 train_time:48601ms step_avg:91.70ms +step:531/1660 train_time:48691ms step_avg:91.70ms +step:532/1660 train_time:48782ms step_avg:91.70ms +step:533/1660 train_time:48874ms step_avg:91.70ms +step:534/1660 train_time:48965ms step_avg:91.69ms +step:535/1660 
train_time:49058ms step_avg:91.70ms +step:536/1660 train_time:49150ms step_avg:91.70ms +step:537/1660 train_time:49242ms step_avg:91.70ms +step:538/1660 train_time:49334ms step_avg:91.70ms +step:539/1660 train_time:49426ms step_avg:91.70ms +step:540/1660 train_time:49517ms step_avg:91.70ms +step:541/1660 train_time:49608ms step_avg:91.70ms +step:542/1660 train_time:49699ms step_avg:91.70ms +step:543/1660 train_time:49789ms step_avg:91.69ms +step:544/1660 train_time:49881ms step_avg:91.69ms +step:545/1660 train_time:49972ms step_avg:91.69ms +step:546/1660 train_time:50064ms step_avg:91.69ms +step:547/1660 train_time:50157ms step_avg:91.69ms +step:548/1660 train_time:50248ms step_avg:91.69ms +step:549/1660 train_time:50340ms step_avg:91.69ms +step:550/1660 train_time:50432ms step_avg:91.69ms +step:551/1660 train_time:50522ms step_avg:91.69ms +step:552/1660 train_time:50613ms step_avg:91.69ms +step:553/1660 train_time:50704ms step_avg:91.69ms +step:554/1660 train_time:50795ms step_avg:91.69ms +step:555/1660 train_time:50886ms step_avg:91.69ms +step:556/1660 train_time:50978ms step_avg:91.69ms +step:557/1660 train_time:51072ms step_avg:91.69ms +step:558/1660 train_time:51165ms step_avg:91.69ms +step:559/1660 train_time:51259ms step_avg:91.70ms +step:560/1660 train_time:51352ms step_avg:91.70ms +step:561/1660 train_time:51445ms step_avg:91.70ms +step:562/1660 train_time:51538ms step_avg:91.70ms +step:563/1660 train_time:51629ms step_avg:91.70ms +step:564/1660 train_time:51721ms step_avg:91.70ms +step:565/1660 train_time:51815ms step_avg:91.71ms +step:566/1660 train_time:51908ms step_avg:91.71ms +step:567/1660 train_time:52001ms step_avg:91.71ms +step:568/1660 train_time:52093ms step_avg:91.71ms +step:569/1660 train_time:52186ms step_avg:91.71ms +step:570/1660 train_time:52281ms step_avg:91.72ms +step:571/1660 train_time:52375ms step_avg:91.72ms +step:572/1660 train_time:52467ms step_avg:91.73ms +step:573/1660 train_time:52560ms step_avg:91.73ms +step:574/1660 train_time:52652ms step_avg:91.73ms +step:575/1660 train_time:52745ms step_avg:91.73ms +step:576/1660 train_time:52839ms step_avg:91.73ms +step:577/1660 train_time:52932ms step_avg:91.74ms +step:578/1660 train_time:53024ms step_avg:91.74ms +step:579/1660 train_time:53116ms step_avg:91.74ms +step:580/1660 train_time:53209ms step_avg:91.74ms +step:581/1660 train_time:53302ms step_avg:91.74ms +step:582/1660 train_time:53395ms step_avg:91.74ms +step:583/1660 train_time:53488ms step_avg:91.75ms +step:584/1660 train_time:53580ms step_avg:91.75ms +step:585/1660 train_time:53673ms step_avg:91.75ms +step:586/1660 train_time:53766ms step_avg:91.75ms +step:587/1660 train_time:53860ms step_avg:91.75ms +step:588/1660 train_time:53953ms step_avg:91.76ms +step:589/1660 train_time:54045ms step_avg:91.76ms +step:590/1660 train_time:54138ms step_avg:91.76ms +step:591/1660 train_time:54231ms step_avg:91.76ms +step:592/1660 train_time:54323ms step_avg:91.76ms +step:593/1660 train_time:54415ms step_avg:91.76ms +step:594/1660 train_time:54508ms step_avg:91.76ms +step:595/1660 train_time:54602ms step_avg:91.77ms +step:596/1660 train_time:54695ms step_avg:91.77ms +step:597/1660 train_time:54788ms step_avg:91.77ms +step:598/1660 train_time:54881ms step_avg:91.77ms +step:599/1660 train_time:54975ms step_avg:91.78ms +step:600/1660 train_time:55068ms step_avg:91.78ms +step:601/1660 train_time:55160ms step_avg:91.78ms +step:602/1660 train_time:55253ms step_avg:91.78ms +step:603/1660 train_time:55346ms step_avg:91.78ms +step:604/1660 train_time:55440ms step_avg:91.79ms 
+step:605/1660 train_time:55533ms step_avg:91.79ms +step:606/1660 train_time:55625ms step_avg:91.79ms +step:607/1660 train_time:55718ms step_avg:91.79ms +step:608/1660 train_time:55810ms step_avg:91.79ms +step:609/1660 train_time:55903ms step_avg:91.80ms +step:610/1660 train_time:55997ms step_avg:91.80ms +step:611/1660 train_time:56090ms step_avg:91.80ms +step:612/1660 train_time:56183ms step_avg:91.80ms +step:613/1660 train_time:56275ms step_avg:91.80ms +step:614/1660 train_time:56368ms step_avg:91.80ms +step:615/1660 train_time:56460ms step_avg:91.81ms +step:616/1660 train_time:56553ms step_avg:91.81ms +step:617/1660 train_time:56645ms step_avg:91.81ms +step:618/1660 train_time:56739ms step_avg:91.81ms +step:619/1660 train_time:56832ms step_avg:91.81ms +step:620/1660 train_time:56924ms step_avg:91.81ms +step:621/1660 train_time:57017ms step_avg:91.82ms +step:622/1660 train_time:57111ms step_avg:91.82ms +step:623/1660 train_time:57203ms step_avg:91.82ms +step:624/1660 train_time:57296ms step_avg:91.82ms +step:625/1660 train_time:57389ms step_avg:91.82ms +step:625/1660 val_loss:3.6113 train_time:57484ms step_avg:91.97ms +step:626/1660 train_time:57504ms step_avg:91.86ms +step:627/1660 train_time:57585ms step_avg:91.84ms +step:628/1660 train_time:57688ms step_avg:91.86ms +step:629/1660 train_time:57783ms step_avg:91.87ms +step:630/1660 train_time:57876ms step_avg:91.87ms +step:631/1660 train_time:57967ms step_avg:91.87ms +step:632/1660 train_time:58059ms step_avg:91.87ms +step:633/1660 train_time:58150ms step_avg:91.86ms +step:634/1660 train_time:58242ms step_avg:91.86ms +step:635/1660 train_time:58334ms step_avg:91.86ms +step:636/1660 train_time:58426ms step_avg:91.86ms +step:637/1660 train_time:58519ms step_avg:91.87ms +step:638/1660 train_time:58616ms step_avg:91.88ms +step:639/1660 train_time:58714ms step_avg:91.88ms +step:640/1660 train_time:58807ms step_avg:91.89ms +step:641/1660 train_time:58899ms step_avg:91.89ms +step:642/1660 train_time:58991ms step_avg:91.89ms +step:643/1660 train_time:59083ms step_avg:91.89ms +step:644/1660 train_time:59175ms step_avg:91.89ms +step:645/1660 train_time:59267ms step_avg:91.89ms +step:646/1660 train_time:59358ms step_avg:91.88ms +step:647/1660 train_time:59451ms step_avg:91.89ms +step:648/1660 train_time:59546ms step_avg:91.89ms +step:649/1660 train_time:59639ms step_avg:91.89ms +step:650/1660 train_time:59733ms step_avg:91.90ms +step:651/1660 train_time:59827ms step_avg:91.90ms +step:652/1660 train_time:59920ms step_avg:91.90ms +step:653/1660 train_time:60013ms step_avg:91.90ms +step:654/1660 train_time:60105ms step_avg:91.90ms +step:655/1660 train_time:60197ms step_avg:91.90ms +step:656/1660 train_time:60288ms step_avg:91.90ms +step:657/1660 train_time:60380ms step_avg:91.90ms +step:658/1660 train_time:60475ms step_avg:91.91ms +step:659/1660 train_time:60569ms step_avg:91.91ms +step:660/1660 train_time:60662ms step_avg:91.91ms +step:661/1660 train_time:60755ms step_avg:91.91ms +step:662/1660 train_time:60850ms step_avg:91.92ms +step:663/1660 train_time:60943ms step_avg:91.92ms +step:664/1660 train_time:61036ms step_avg:91.92ms +step:665/1660 train_time:61128ms step_avg:91.92ms +step:666/1660 train_time:61220ms step_avg:91.92ms +step:667/1660 train_time:61312ms step_avg:91.92ms +step:668/1660 train_time:61404ms step_avg:91.92ms +step:669/1660 train_time:61497ms step_avg:91.92ms +step:670/1660 train_time:61593ms step_avg:91.93ms +step:671/1660 train_time:61687ms step_avg:91.93ms +step:672/1660 train_time:61780ms step_avg:91.93ms +step:673/1660 
train_time:61873ms step_avg:91.94ms +step:674/1660 train_time:61966ms step_avg:91.94ms +step:675/1660 train_time:62058ms step_avg:91.94ms +step:676/1660 train_time:62152ms step_avg:91.94ms +step:677/1660 train_time:62246ms step_avg:91.94ms +step:678/1660 train_time:62337ms step_avg:91.94ms +step:679/1660 train_time:62429ms step_avg:91.94ms +step:680/1660 train_time:62523ms step_avg:91.94ms +step:681/1660 train_time:62616ms step_avg:91.95ms +step:682/1660 train_time:62709ms step_avg:91.95ms +step:683/1660 train_time:62802ms step_avg:91.95ms +step:684/1660 train_time:62895ms step_avg:91.95ms +step:685/1660 train_time:62988ms step_avg:91.95ms +step:686/1660 train_time:63080ms step_avg:91.95ms +step:687/1660 train_time:63173ms step_avg:91.96ms +step:688/1660 train_time:63266ms step_avg:91.96ms +step:689/1660 train_time:63358ms step_avg:91.96ms +step:690/1660 train_time:63453ms step_avg:91.96ms +step:691/1660 train_time:63546ms step_avg:91.96ms +step:692/1660 train_time:63639ms step_avg:91.96ms +step:693/1660 train_time:63733ms step_avg:91.97ms +step:694/1660 train_time:63826ms step_avg:91.97ms +step:695/1660 train_time:63919ms step_avg:91.97ms +step:696/1660 train_time:64012ms step_avg:91.97ms +step:697/1660 train_time:64104ms step_avg:91.97ms +step:698/1660 train_time:64196ms step_avg:91.97ms +step:699/1660 train_time:64289ms step_avg:91.97ms +step:700/1660 train_time:64381ms step_avg:91.97ms +step:701/1660 train_time:64473ms step_avg:91.97ms +step:702/1660 train_time:64566ms step_avg:91.97ms +step:703/1660 train_time:64658ms step_avg:91.98ms +step:704/1660 train_time:64752ms step_avg:91.98ms +step:705/1660 train_time:64846ms step_avg:91.98ms +step:706/1660 train_time:64939ms step_avg:91.98ms +step:707/1660 train_time:65033ms step_avg:91.98ms +step:708/1660 train_time:65126ms step_avg:91.99ms +step:709/1660 train_time:65218ms step_avg:91.99ms +step:710/1660 train_time:65311ms step_avg:91.99ms +step:711/1660 train_time:65404ms step_avg:91.99ms +step:712/1660 train_time:65496ms step_avg:91.99ms +step:713/1660 train_time:65589ms step_avg:91.99ms +step:714/1660 train_time:65681ms step_avg:91.99ms +step:715/1660 train_time:65775ms step_avg:91.99ms +step:716/1660 train_time:65868ms step_avg:91.99ms +step:717/1660 train_time:65961ms step_avg:92.00ms +step:718/1660 train_time:66055ms step_avg:92.00ms +step:719/1660 train_time:66147ms step_avg:92.00ms +step:720/1660 train_time:66239ms step_avg:92.00ms +step:721/1660 train_time:66331ms step_avg:92.00ms +step:722/1660 train_time:66424ms step_avg:92.00ms +step:723/1660 train_time:66516ms step_avg:92.00ms +step:724/1660 train_time:66609ms step_avg:92.00ms +step:725/1660 train_time:66701ms step_avg:92.00ms +step:726/1660 train_time:66794ms step_avg:92.00ms +step:727/1660 train_time:66887ms step_avg:92.00ms +step:728/1660 train_time:66979ms step_avg:92.00ms +step:729/1660 train_time:67074ms step_avg:92.01ms +step:730/1660 train_time:67166ms step_avg:92.01ms +step:731/1660 train_time:67259ms step_avg:92.01ms +step:732/1660 train_time:67352ms step_avg:92.01ms +step:733/1660 train_time:67445ms step_avg:92.01ms +step:734/1660 train_time:67537ms step_avg:92.01ms +step:735/1660 train_time:67630ms step_avg:92.01ms +step:736/1660 train_time:67723ms step_avg:92.02ms +step:737/1660 train_time:67816ms step_avg:92.02ms +step:738/1660 train_time:67909ms step_avg:92.02ms +step:739/1660 train_time:68002ms step_avg:92.02ms +step:740/1660 train_time:68095ms step_avg:92.02ms +step:741/1660 train_time:68188ms step_avg:92.02ms +step:742/1660 train_time:68281ms step_avg:92.02ms 
+step:743/1660 train_time:68375ms step_avg:92.03ms +step:744/1660 train_time:68468ms step_avg:92.03ms +step:745/1660 train_time:68559ms step_avg:92.03ms +step:746/1660 train_time:68654ms step_avg:92.03ms +step:747/1660 train_time:68747ms step_avg:92.03ms +step:748/1660 train_time:68839ms step_avg:92.03ms +step:749/1660 train_time:68933ms step_avg:92.03ms +step:750/1660 train_time:69026ms step_avg:92.04ms +step:750/1660 val_loss:3.5612 train_time:69121ms step_avg:92.16ms +step:751/1660 train_time:69142ms step_avg:92.07ms +step:752/1660 train_time:69221ms step_avg:92.05ms +step:753/1660 train_time:69318ms step_avg:92.06ms +step:754/1660 train_time:69411ms step_avg:92.06ms +step:755/1660 train_time:69502ms step_avg:92.06ms +step:756/1660 train_time:69594ms step_avg:92.06ms +step:757/1660 train_time:69686ms step_avg:92.06ms +step:758/1660 train_time:69778ms step_avg:92.06ms +step:759/1660 train_time:69869ms step_avg:92.05ms +step:760/1660 train_time:69961ms step_avg:92.05ms +step:761/1660 train_time:70054ms step_avg:92.06ms +step:762/1660 train_time:70151ms step_avg:92.06ms +step:763/1660 train_time:70248ms step_avg:92.07ms +step:764/1660 train_time:70342ms step_avg:92.07ms +step:765/1660 train_time:70434ms step_avg:92.07ms +step:766/1660 train_time:70527ms step_avg:92.07ms +step:767/1660 train_time:70620ms step_avg:92.07ms +step:768/1660 train_time:70712ms step_avg:92.07ms +step:769/1660 train_time:70804ms step_avg:92.07ms +step:770/1660 train_time:70895ms step_avg:92.07ms +step:771/1660 train_time:70988ms step_avg:92.07ms +step:772/1660 train_time:71082ms step_avg:92.07ms +step:773/1660 train_time:71176ms step_avg:92.08ms +step:774/1660 train_time:71270ms step_avg:92.08ms +step:775/1660 train_time:71364ms step_avg:92.08ms +step:776/1660 train_time:71457ms step_avg:92.08ms +step:777/1660 train_time:71550ms step_avg:92.08ms +step:778/1660 train_time:71642ms step_avg:92.08ms +step:779/1660 train_time:71733ms step_avg:92.08ms +step:780/1660 train_time:71825ms step_avg:92.08ms +step:781/1660 train_time:71917ms step_avg:92.08ms +step:782/1660 train_time:72010ms step_avg:92.08ms +step:783/1660 train_time:72104ms step_avg:92.09ms +step:784/1660 train_time:72197ms step_avg:92.09ms +step:785/1660 train_time:72291ms step_avg:92.09ms +step:786/1660 train_time:72385ms step_avg:92.09ms +step:787/1660 train_time:72479ms step_avg:92.10ms +step:788/1660 train_time:72571ms step_avg:92.10ms +step:789/1660 train_time:72664ms step_avg:92.10ms +step:790/1660 train_time:72756ms step_avg:92.10ms +step:791/1660 train_time:72848ms step_avg:92.10ms +step:792/1660 train_time:72940ms step_avg:92.10ms +step:793/1660 train_time:73033ms step_avg:92.10ms +step:794/1660 train_time:73127ms step_avg:92.10ms +step:795/1660 train_time:73221ms step_avg:92.10ms +step:796/1660 train_time:73314ms step_avg:92.10ms +step:797/1660 train_time:73409ms step_avg:92.11ms +step:798/1660 train_time:73502ms step_avg:92.11ms +step:799/1660 train_time:73594ms step_avg:92.11ms +step:800/1660 train_time:73687ms step_avg:92.11ms +step:801/1660 train_time:73779ms step_avg:92.11ms +step:802/1660 train_time:73871ms step_avg:92.11ms +step:803/1660 train_time:73963ms step_avg:92.11ms +step:804/1660 train_time:74057ms step_avg:92.11ms +step:805/1660 train_time:74150ms step_avg:92.11ms +step:806/1660 train_time:74243ms step_avg:92.11ms +step:807/1660 train_time:74336ms step_avg:92.11ms +step:808/1660 train_time:74430ms step_avg:92.12ms +step:809/1660 train_time:74523ms step_avg:92.12ms +step:810/1660 train_time:74616ms step_avg:92.12ms +step:811/1660 
train_time:74709ms step_avg:92.12ms +step:812/1660 train_time:74802ms step_avg:92.12ms +step:813/1660 train_time:74893ms step_avg:92.12ms +step:814/1660 train_time:74986ms step_avg:92.12ms +step:815/1660 train_time:75080ms step_avg:92.12ms +step:816/1660 train_time:75173ms step_avg:92.12ms +step:817/1660 train_time:75267ms step_avg:92.13ms +step:818/1660 train_time:75359ms step_avg:92.13ms +step:819/1660 train_time:75453ms step_avg:92.13ms +step:820/1660 train_time:75546ms step_avg:92.13ms +step:821/1660 train_time:75639ms step_avg:92.13ms +step:822/1660 train_time:75732ms step_avg:92.13ms +step:823/1660 train_time:75824ms step_avg:92.13ms +step:824/1660 train_time:75916ms step_avg:92.13ms +step:825/1660 train_time:76010ms step_avg:92.13ms +step:826/1660 train_time:76102ms step_avg:92.13ms +step:827/1660 train_time:76195ms step_avg:92.13ms +step:828/1660 train_time:76289ms step_avg:92.14ms +step:829/1660 train_time:76382ms step_avg:92.14ms +step:830/1660 train_time:76475ms step_avg:92.14ms +step:831/1660 train_time:76568ms step_avg:92.14ms +step:832/1660 train_time:76661ms step_avg:92.14ms +step:833/1660 train_time:76753ms step_avg:92.14ms +step:834/1660 train_time:76847ms step_avg:92.14ms +step:835/1660 train_time:76940ms step_avg:92.14ms +step:836/1660 train_time:77032ms step_avg:92.14ms +step:837/1660 train_time:77125ms step_avg:92.14ms +step:838/1660 train_time:77218ms step_avg:92.15ms +step:839/1660 train_time:77311ms step_avg:92.15ms +step:840/1660 train_time:77405ms step_avg:92.15ms +step:841/1660 train_time:77498ms step_avg:92.15ms +step:842/1660 train_time:77592ms step_avg:92.15ms +step:843/1660 train_time:77684ms step_avg:92.15ms +step:844/1660 train_time:77777ms step_avg:92.15ms +step:845/1660 train_time:77869ms step_avg:92.15ms +step:846/1660 train_time:77962ms step_avg:92.15ms +step:847/1660 train_time:78055ms step_avg:92.16ms +step:848/1660 train_time:78149ms step_avg:92.16ms +step:849/1660 train_time:78242ms step_avg:92.16ms +step:850/1660 train_time:78335ms step_avg:92.16ms +step:851/1660 train_time:78428ms step_avg:92.16ms +step:852/1660 train_time:78521ms step_avg:92.16ms +step:853/1660 train_time:78613ms step_avg:92.16ms +step:854/1660 train_time:78707ms step_avg:92.16ms +step:855/1660 train_time:78800ms step_avg:92.16ms +step:856/1660 train_time:78892ms step_avg:92.16ms +step:857/1660 train_time:78986ms step_avg:92.17ms +step:858/1660 train_time:79079ms step_avg:92.17ms +step:859/1660 train_time:79172ms step_avg:92.17ms +step:860/1660 train_time:79265ms step_avg:92.17ms +step:861/1660 train_time:79357ms step_avg:92.17ms +step:862/1660 train_time:79450ms step_avg:92.17ms +step:863/1660 train_time:79542ms step_avg:92.17ms +step:864/1660 train_time:79635ms step_avg:92.17ms +step:865/1660 train_time:79728ms step_avg:92.17ms +step:866/1660 train_time:79821ms step_avg:92.17ms +step:867/1660 train_time:79914ms step_avg:92.17ms +step:868/1660 train_time:80008ms step_avg:92.17ms +step:869/1660 train_time:80101ms step_avg:92.18ms +step:870/1660 train_time:80193ms step_avg:92.18ms +step:871/1660 train_time:80286ms step_avg:92.18ms +step:872/1660 train_time:80379ms step_avg:92.18ms +step:873/1660 train_time:80471ms step_avg:92.18ms +step:874/1660 train_time:80563ms step_avg:92.18ms +step:875/1660 train_time:80657ms step_avg:92.18ms +step:875/1660 val_loss:3.5150 train_time:80752ms step_avg:92.29ms +step:876/1660 train_time:80773ms step_avg:92.21ms +step:877/1660 train_time:80848ms step_avg:92.19ms +step:878/1660 train_time:80945ms step_avg:92.19ms +step:879/1660 train_time:81040ms 
step_avg:92.20ms +step:880/1660 train_time:81132ms step_avg:92.20ms +step:881/1660 train_time:81223ms step_avg:92.19ms +step:882/1660 train_time:81314ms step_avg:92.19ms +step:883/1660 train_time:81406ms step_avg:92.19ms +step:884/1660 train_time:81498ms step_avg:92.19ms +step:885/1660 train_time:81590ms step_avg:92.19ms +step:886/1660 train_time:81682ms step_avg:92.19ms +step:887/1660 train_time:81777ms step_avg:92.20ms +step:888/1660 train_time:81873ms step_avg:92.20ms +step:889/1660 train_time:81967ms step_avg:92.20ms +step:890/1660 train_time:82060ms step_avg:92.20ms +step:891/1660 train_time:82153ms step_avg:92.20ms +step:892/1660 train_time:82245ms step_avg:92.20ms +step:893/1660 train_time:82337ms step_avg:92.20ms +step:894/1660 train_time:82429ms step_avg:92.20ms +step:895/1660 train_time:82521ms step_avg:92.20ms +step:896/1660 train_time:82613ms step_avg:92.20ms +step:897/1660 train_time:82706ms step_avg:92.20ms +step:898/1660 train_time:82800ms step_avg:92.20ms +step:899/1660 train_time:82895ms step_avg:92.21ms +step:900/1660 train_time:82989ms step_avg:92.21ms +step:901/1660 train_time:83082ms step_avg:92.21ms +step:902/1660 train_time:83175ms step_avg:92.21ms +step:903/1660 train_time:83267ms step_avg:92.21ms +step:904/1660 train_time:83359ms step_avg:92.21ms +step:905/1660 train_time:83451ms step_avg:92.21ms +step:906/1660 train_time:83543ms step_avg:92.21ms +step:907/1660 train_time:83636ms step_avg:92.21ms +step:908/1660 train_time:83727ms step_avg:92.21ms +step:909/1660 train_time:83821ms step_avg:92.21ms +step:910/1660 train_time:83915ms step_avg:92.21ms +step:911/1660 train_time:84008ms step_avg:92.22ms +step:912/1660 train_time:84101ms step_avg:92.22ms +step:913/1660 train_time:84195ms step_avg:92.22ms +step:914/1660 train_time:84287ms step_avg:92.22ms +step:915/1660 train_time:84380ms step_avg:92.22ms +step:916/1660 train_time:84472ms step_avg:92.22ms +step:917/1660 train_time:84564ms step_avg:92.22ms +step:918/1660 train_time:84657ms step_avg:92.22ms +step:919/1660 train_time:84750ms step_avg:92.22ms +step:920/1660 train_time:84843ms step_avg:92.22ms +step:921/1660 train_time:84936ms step_avg:92.22ms +step:922/1660 train_time:85029ms step_avg:92.22ms +step:923/1660 train_time:85122ms step_avg:92.22ms +step:924/1660 train_time:85217ms step_avg:92.23ms +step:925/1660 train_time:85310ms step_avg:92.23ms +step:926/1660 train_time:85402ms step_avg:92.23ms +step:927/1660 train_time:85495ms step_avg:92.23ms +step:928/1660 train_time:85588ms step_avg:92.23ms +step:929/1660 train_time:85680ms step_avg:92.23ms +step:930/1660 train_time:85773ms step_avg:92.23ms +step:931/1660 train_time:85865ms step_avg:92.23ms +step:932/1660 train_time:85960ms step_avg:92.23ms +step:933/1660 train_time:86053ms step_avg:92.23ms +step:934/1660 train_time:86146ms step_avg:92.23ms +step:935/1660 train_time:86239ms step_avg:92.23ms +step:936/1660 train_time:86333ms step_avg:92.24ms +step:937/1660 train_time:86425ms step_avg:92.24ms +step:938/1660 train_time:86518ms step_avg:92.24ms +step:939/1660 train_time:86611ms step_avg:92.24ms +step:940/1660 train_time:86704ms step_avg:92.24ms +step:941/1660 train_time:86796ms step_avg:92.24ms +step:942/1660 train_time:86889ms step_avg:92.24ms +step:943/1660 train_time:86981ms step_avg:92.24ms +step:944/1660 train_time:87074ms step_avg:92.24ms +step:945/1660 train_time:87167ms step_avg:92.24ms +step:946/1660 train_time:87260ms step_avg:92.24ms +step:947/1660 train_time:87353ms step_avg:92.24ms +step:948/1660 train_time:87445ms step_avg:92.24ms +step:949/1660 
train_time:87537ms step_avg:92.24ms +step:950/1660 train_time:87630ms step_avg:92.24ms +step:951/1660 train_time:87723ms step_avg:92.24ms +step:952/1660 train_time:87816ms step_avg:92.24ms +step:953/1660 train_time:87910ms step_avg:92.25ms +step:954/1660 train_time:88002ms step_avg:92.25ms +step:955/1660 train_time:88095ms step_avg:92.25ms +step:956/1660 train_time:88188ms step_avg:92.25ms +step:957/1660 train_time:88281ms step_avg:92.25ms +step:958/1660 train_time:88374ms step_avg:92.25ms +step:959/1660 train_time:88466ms step_avg:92.25ms +step:960/1660 train_time:88559ms step_avg:92.25ms +step:961/1660 train_time:88652ms step_avg:92.25ms +step:962/1660 train_time:88745ms step_avg:92.25ms +step:963/1660 train_time:88838ms step_avg:92.25ms +step:964/1660 train_time:88931ms step_avg:92.25ms +step:965/1660 train_time:89023ms step_avg:92.25ms +step:966/1660 train_time:89116ms step_avg:92.25ms +step:967/1660 train_time:89210ms step_avg:92.25ms +step:968/1660 train_time:89303ms step_avg:92.25ms +step:969/1660 train_time:89396ms step_avg:92.26ms +step:970/1660 train_time:89489ms step_avg:92.26ms +step:971/1660 train_time:89581ms step_avg:92.26ms +step:972/1660 train_time:89674ms step_avg:92.26ms +step:973/1660 train_time:89766ms step_avg:92.26ms +step:974/1660 train_time:89859ms step_avg:92.26ms +step:975/1660 train_time:89953ms step_avg:92.26ms +step:976/1660 train_time:90046ms step_avg:92.26ms +step:977/1660 train_time:90139ms step_avg:92.26ms +step:978/1660 train_time:90232ms step_avg:92.26ms +step:979/1660 train_time:90324ms step_avg:92.26ms +step:980/1660 train_time:90417ms step_avg:92.26ms +step:981/1660 train_time:90510ms step_avg:92.26ms +step:982/1660 train_time:90602ms step_avg:92.26ms +step:983/1660 train_time:90696ms step_avg:92.26ms +step:984/1660 train_time:90789ms step_avg:92.27ms +step:985/1660 train_time:90882ms step_avg:92.27ms +step:986/1660 train_time:90976ms step_avg:92.27ms +step:987/1660 train_time:91068ms step_avg:92.27ms +step:988/1660 train_time:91160ms step_avg:92.27ms +step:989/1660 train_time:91253ms step_avg:92.27ms +step:990/1660 train_time:91346ms step_avg:92.27ms +step:991/1660 train_time:91440ms step_avg:92.27ms +step:992/1660 train_time:91533ms step_avg:92.27ms +step:993/1660 train_time:91625ms step_avg:92.27ms +step:994/1660 train_time:91718ms step_avg:92.27ms +step:995/1660 train_time:91811ms step_avg:92.27ms +step:996/1660 train_time:91904ms step_avg:92.27ms +step:997/1660 train_time:91997ms step_avg:92.27ms +step:998/1660 train_time:92090ms step_avg:92.27ms +step:999/1660 train_time:92184ms step_avg:92.28ms +step:1000/1660 train_time:92277ms step_avg:92.28ms +step:1000/1660 val_loss:3.4668 train_time:92371ms step_avg:92.37ms +step:1001/1660 train_time:92392ms step_avg:92.30ms +step:1002/1660 train_time:92466ms step_avg:92.28ms +step:1003/1660 train_time:92566ms step_avg:92.29ms +step:1004/1660 train_time:92658ms step_avg:92.29ms +step:1005/1660 train_time:92750ms step_avg:92.29ms +step:1006/1660 train_time:92841ms step_avg:92.29ms +step:1007/1660 train_time:92933ms step_avg:92.29ms +step:1008/1660 train_time:93024ms step_avg:92.29ms +step:1009/1660 train_time:93115ms step_avg:92.28ms +step:1010/1660 train_time:93208ms step_avg:92.28ms +step:1011/1660 train_time:93300ms step_avg:92.28ms +step:1012/1660 train_time:93394ms step_avg:92.29ms +step:1013/1660 train_time:93490ms step_avg:92.29ms +step:1014/1660 train_time:93585ms step_avg:92.29ms +step:1015/1660 train_time:93678ms step_avg:92.29ms +step:1016/1660 train_time:93770ms step_avg:92.29ms +step:1017/1660 
train_time:93862ms step_avg:92.29ms +step:1018/1660 train_time:93953ms step_avg:92.29ms +step:1019/1660 train_time:94045ms step_avg:92.29ms +step:1020/1660 train_time:94136ms step_avg:92.29ms +step:1021/1660 train_time:94229ms step_avg:92.29ms +step:1022/1660 train_time:94322ms step_avg:92.29ms +step:1023/1660 train_time:94417ms step_avg:92.29ms +step:1024/1660 train_time:94511ms step_avg:92.30ms +step:1025/1660 train_time:94604ms step_avg:92.30ms +step:1026/1660 train_time:94697ms step_avg:92.30ms +step:1027/1660 train_time:94789ms step_avg:92.30ms +step:1028/1660 train_time:94881ms step_avg:92.30ms +step:1029/1660 train_time:94973ms step_avg:92.30ms +step:1030/1660 train_time:95065ms step_avg:92.30ms +step:1031/1660 train_time:95157ms step_avg:92.30ms +step:1032/1660 train_time:95250ms step_avg:92.30ms +step:1033/1660 train_time:95343ms step_avg:92.30ms +step:1034/1660 train_time:95436ms step_avg:92.30ms +step:1035/1660 train_time:95531ms step_avg:92.30ms +step:1036/1660 train_time:95623ms step_avg:92.30ms +step:1037/1660 train_time:95716ms step_avg:92.30ms +step:1038/1660 train_time:95809ms step_avg:92.30ms +step:1039/1660 train_time:95901ms step_avg:92.30ms +step:1040/1660 train_time:95993ms step_avg:92.30ms +step:1041/1660 train_time:96085ms step_avg:92.30ms +step:1042/1660 train_time:96176ms step_avg:92.30ms +step:1043/1660 train_time:96269ms step_avg:92.30ms +step:1044/1660 train_time:96362ms step_avg:92.30ms +step:1045/1660 train_time:96456ms step_avg:92.30ms +step:1046/1660 train_time:96550ms step_avg:92.30ms +step:1047/1660 train_time:96643ms step_avg:92.30ms +step:1048/1660 train_time:96735ms step_avg:92.30ms +step:1049/1660 train_time:96829ms step_avg:92.31ms +step:1050/1660 train_time:96922ms step_avg:92.31ms +step:1051/1660 train_time:97014ms step_avg:92.31ms +step:1052/1660 train_time:97105ms step_avg:92.31ms +step:1053/1660 train_time:97197ms step_avg:92.30ms +step:1054/1660 train_time:97290ms step_avg:92.31ms +step:1055/1660 train_time:97383ms step_avg:92.31ms +step:1056/1660 train_time:97475ms step_avg:92.31ms +step:1057/1660 train_time:97569ms step_avg:92.31ms +step:1058/1660 train_time:97661ms step_avg:92.31ms +step:1059/1660 train_time:97755ms step_avg:92.31ms +step:1060/1660 train_time:97847ms step_avg:92.31ms +step:1061/1660 train_time:97941ms step_avg:92.31ms +step:1062/1660 train_time:98034ms step_avg:92.31ms +step:1063/1660 train_time:98126ms step_avg:92.31ms +step:1064/1660 train_time:98218ms step_avg:92.31ms +step:1065/1660 train_time:98311ms step_avg:92.31ms +step:1066/1660 train_time:98404ms step_avg:92.31ms +step:1067/1660 train_time:98496ms step_avg:92.31ms +step:1068/1660 train_time:98589ms step_avg:92.31ms +step:1069/1660 train_time:98683ms step_avg:92.31ms +step:1070/1660 train_time:98775ms step_avg:92.31ms +step:1071/1660 train_time:98868ms step_avg:92.31ms +step:1072/1660 train_time:98961ms step_avg:92.31ms +step:1073/1660 train_time:99054ms step_avg:92.32ms +step:1074/1660 train_time:99147ms step_avg:92.32ms +step:1075/1660 train_time:99240ms step_avg:92.32ms +step:1076/1660 train_time:99333ms step_avg:92.32ms +step:1077/1660 train_time:99425ms step_avg:92.32ms +step:1078/1660 train_time:99518ms step_avg:92.32ms +step:1079/1660 train_time:99611ms step_avg:92.32ms +step:1080/1660 train_time:99703ms step_avg:92.32ms +step:1081/1660 train_time:99795ms step_avg:92.32ms +step:1082/1660 train_time:99889ms step_avg:92.32ms +step:1083/1660 train_time:99982ms step_avg:92.32ms +step:1084/1660 train_time:100074ms step_avg:92.32ms +step:1085/1660 
train_time:100167ms step_avg:92.32ms +step:1086/1660 train_time:100260ms step_avg:92.32ms +step:1087/1660 train_time:100352ms step_avg:92.32ms +step:1088/1660 train_time:100445ms step_avg:92.32ms +step:1089/1660 train_time:100537ms step_avg:92.32ms +step:1090/1660 train_time:100630ms step_avg:92.32ms +step:1091/1660 train_time:100722ms step_avg:92.32ms +step:1092/1660 train_time:100815ms step_avg:92.32ms +step:1093/1660 train_time:100908ms step_avg:92.32ms +step:1094/1660 train_time:101001ms step_avg:92.32ms +step:1095/1660 train_time:101093ms step_avg:92.32ms +step:1096/1660 train_time:101186ms step_avg:92.32ms +step:1097/1660 train_time:101278ms step_avg:92.32ms +step:1098/1660 train_time:101370ms step_avg:92.32ms +step:1099/1660 train_time:101463ms step_avg:92.32ms +step:1100/1660 train_time:101555ms step_avg:92.32ms +step:1101/1660 train_time:101648ms step_avg:92.32ms +step:1102/1660 train_time:101740ms step_avg:92.32ms +step:1103/1660 train_time:101833ms step_avg:92.32ms +step:1104/1660 train_time:101926ms step_avg:92.32ms +step:1105/1660 train_time:102019ms step_avg:92.32ms +step:1106/1660 train_time:102111ms step_avg:92.32ms +step:1107/1660 train_time:102203ms step_avg:92.32ms +step:1108/1660 train_time:102296ms step_avg:92.33ms +step:1109/1660 train_time:102389ms step_avg:92.33ms +step:1110/1660 train_time:102484ms step_avg:92.33ms +step:1111/1660 train_time:102577ms step_avg:92.33ms +step:1112/1660 train_time:102670ms step_avg:92.33ms +step:1113/1660 train_time:102763ms step_avg:92.33ms +step:1114/1660 train_time:102856ms step_avg:92.33ms +step:1115/1660 train_time:102951ms step_avg:92.33ms +step:1116/1660 train_time:103044ms step_avg:92.33ms +step:1117/1660 train_time:103136ms step_avg:92.33ms +step:1118/1660 train_time:103229ms step_avg:92.33ms +step:1119/1660 train_time:103323ms step_avg:92.33ms +step:1120/1660 train_time:103416ms step_avg:92.34ms +step:1121/1660 train_time:103509ms step_avg:92.34ms +step:1122/1660 train_time:103603ms step_avg:92.34ms +step:1123/1660 train_time:103695ms step_avg:92.34ms +step:1124/1660 train_time:103789ms step_avg:92.34ms +step:1125/1660 train_time:103883ms step_avg:92.34ms +step:1125/1660 val_loss:3.4127 train_time:103978ms step_avg:92.42ms +step:1126/1660 train_time:103999ms step_avg:92.36ms +step:1127/1660 train_time:104075ms step_avg:92.35ms +step:1128/1660 train_time:104181ms step_avg:92.36ms +step:1129/1660 train_time:104276ms step_avg:92.36ms +step:1130/1660 train_time:104369ms step_avg:92.36ms +step:1131/1660 train_time:104461ms step_avg:92.36ms +step:1132/1660 train_time:104553ms step_avg:92.36ms +step:1133/1660 train_time:104646ms step_avg:92.36ms +step:1134/1660 train_time:104738ms step_avg:92.36ms +step:1135/1660 train_time:104831ms step_avg:92.36ms +step:1136/1660 train_time:104924ms step_avg:92.36ms +step:1137/1660 train_time:105019ms step_avg:92.36ms +step:1138/1660 train_time:105117ms step_avg:92.37ms +step:1139/1660 train_time:105216ms step_avg:92.38ms +step:1140/1660 train_time:105311ms step_avg:92.38ms +step:1141/1660 train_time:105404ms step_avg:92.38ms +step:1142/1660 train_time:105495ms step_avg:92.38ms +step:1143/1660 train_time:105589ms step_avg:92.38ms +step:1144/1660 train_time:105682ms step_avg:92.38ms +step:1145/1660 train_time:105775ms step_avg:92.38ms +step:1146/1660 train_time:105867ms step_avg:92.38ms +step:1147/1660 train_time:105960ms step_avg:92.38ms +step:1148/1660 train_time:106055ms step_avg:92.38ms +step:1149/1660 train_time:106153ms step_avg:92.39ms +step:1150/1660 train_time:106248ms step_avg:92.39ms 
+step:1151/1660 train_time:106342ms step_avg:92.39ms +step:1152/1660 train_time:106435ms step_avg:92.39ms +step:1153/1660 train_time:106527ms step_avg:92.39ms +step:1154/1660 train_time:106619ms step_avg:92.39ms +step:1155/1660 train_time:106713ms step_avg:92.39ms +step:1156/1660 train_time:106805ms step_avg:92.39ms +step:1157/1660 train_time:106897ms step_avg:92.39ms +step:1158/1660 train_time:106992ms step_avg:92.39ms +step:1159/1660 train_time:107087ms step_avg:92.40ms +step:1160/1660 train_time:107181ms step_avg:92.40ms +step:1161/1660 train_time:107276ms step_avg:92.40ms +step:1162/1660 train_time:107371ms step_avg:92.40ms +step:1163/1660 train_time:107465ms step_avg:92.40ms +step:1164/1660 train_time:107558ms step_avg:92.40ms +step:1165/1660 train_time:107651ms step_avg:92.40ms +step:1166/1660 train_time:107745ms step_avg:92.41ms +step:1167/1660 train_time:107837ms step_avg:92.41ms +step:1168/1660 train_time:107931ms step_avg:92.41ms +step:1169/1660 train_time:108025ms step_avg:92.41ms +step:1170/1660 train_time:108119ms step_avg:92.41ms +step:1171/1660 train_time:108214ms step_avg:92.41ms +step:1172/1660 train_time:108309ms step_avg:92.41ms +step:1173/1660 train_time:108404ms step_avg:92.42ms +step:1174/1660 train_time:108496ms step_avg:92.42ms +step:1175/1660 train_time:108590ms step_avg:92.42ms +step:1176/1660 train_time:108683ms step_avg:92.42ms +step:1177/1660 train_time:108776ms step_avg:92.42ms +step:1178/1660 train_time:108870ms step_avg:92.42ms +step:1179/1660 train_time:108963ms step_avg:92.42ms +step:1180/1660 train_time:109056ms step_avg:92.42ms +step:1181/1660 train_time:109150ms step_avg:92.42ms +step:1182/1660 train_time:109244ms step_avg:92.42ms +step:1183/1660 train_time:109338ms step_avg:92.42ms +step:1184/1660 train_time:109431ms step_avg:92.43ms +step:1185/1660 train_time:109524ms step_avg:92.43ms +step:1186/1660 train_time:109618ms step_avg:92.43ms +step:1187/1660 train_time:109711ms step_avg:92.43ms +step:1188/1660 train_time:109803ms step_avg:92.43ms +step:1189/1660 train_time:109897ms step_avg:92.43ms +step:1190/1660 train_time:109992ms step_avg:92.43ms +step:1191/1660 train_time:110086ms step_avg:92.43ms +step:1192/1660 train_time:110179ms step_avg:92.43ms +step:1193/1660 train_time:110273ms step_avg:92.43ms +step:1194/1660 train_time:110367ms step_avg:92.43ms +step:1195/1660 train_time:110460ms step_avg:92.43ms +step:1196/1660 train_time:110553ms step_avg:92.44ms +step:1197/1660 train_time:110647ms step_avg:92.44ms +step:1198/1660 train_time:110740ms step_avg:92.44ms +step:1199/1660 train_time:110833ms step_avg:92.44ms +step:1200/1660 train_time:110926ms step_avg:92.44ms +step:1201/1660 train_time:111019ms step_avg:92.44ms +step:1202/1660 train_time:111113ms step_avg:92.44ms +step:1203/1660 train_time:111206ms step_avg:92.44ms +step:1204/1660 train_time:111299ms step_avg:92.44ms +step:1205/1660 train_time:111393ms step_avg:92.44ms +step:1206/1660 train_time:111487ms step_avg:92.44ms +step:1207/1660 train_time:111580ms step_avg:92.44ms +step:1208/1660 train_time:111674ms step_avg:92.45ms +step:1209/1660 train_time:111767ms step_avg:92.45ms +step:1210/1660 train_time:111860ms step_avg:92.45ms +step:1211/1660 train_time:111953ms step_avg:92.45ms +step:1212/1660 train_time:112048ms step_avg:92.45ms +step:1213/1660 train_time:112142ms step_avg:92.45ms +step:1214/1660 train_time:112236ms step_avg:92.45ms +step:1215/1660 train_time:112329ms step_avg:92.45ms +step:1216/1660 train_time:112421ms step_avg:92.45ms +step:1217/1660 train_time:112515ms step_avg:92.45ms 
+step:1218/1660 train_time:112610ms step_avg:92.46ms +step:1219/1660 train_time:112705ms step_avg:92.46ms +step:1220/1660 train_time:112798ms step_avg:92.46ms +step:1221/1660 train_time:112892ms step_avg:92.46ms +step:1222/1660 train_time:112985ms step_avg:92.46ms +step:1223/1660 train_time:113078ms step_avg:92.46ms +step:1224/1660 train_time:113172ms step_avg:92.46ms +step:1225/1660 train_time:113265ms step_avg:92.46ms +step:1226/1660 train_time:113358ms step_avg:92.46ms +step:1227/1660 train_time:113452ms step_avg:92.46ms +step:1228/1660 train_time:113546ms step_avg:92.46ms +step:1229/1660 train_time:113639ms step_avg:92.46ms +step:1230/1660 train_time:113733ms step_avg:92.47ms +step:1231/1660 train_time:113827ms step_avg:92.47ms +step:1232/1660 train_time:113920ms step_avg:92.47ms +step:1233/1660 train_time:114014ms step_avg:92.47ms +step:1234/1660 train_time:114108ms step_avg:92.47ms +step:1235/1660 train_time:114201ms step_avg:92.47ms +step:1236/1660 train_time:114294ms step_avg:92.47ms +step:1237/1660 train_time:114388ms step_avg:92.47ms +step:1238/1660 train_time:114481ms step_avg:92.47ms +step:1239/1660 train_time:114575ms step_avg:92.47ms +step:1240/1660 train_time:114669ms step_avg:92.48ms +step:1241/1660 train_time:114763ms step_avg:92.48ms +step:1242/1660 train_time:114856ms step_avg:92.48ms +step:1243/1660 train_time:114949ms step_avg:92.48ms +step:1244/1660 train_time:115042ms step_avg:92.48ms +step:1245/1660 train_time:115136ms step_avg:92.48ms +step:1246/1660 train_time:115229ms step_avg:92.48ms +step:1247/1660 train_time:115321ms step_avg:92.48ms +step:1248/1660 train_time:115414ms step_avg:92.48ms +step:1249/1660 train_time:115508ms step_avg:92.48ms +step:1250/1660 train_time:115601ms step_avg:92.48ms +step:1250/1660 val_loss:3.3746 train_time:115697ms step_avg:92.56ms +step:1251/1660 train_time:115718ms step_avg:92.50ms +step:1252/1660 train_time:115797ms step_avg:92.49ms +step:1253/1660 train_time:115894ms step_avg:92.49ms +step:1254/1660 train_time:115988ms step_avg:92.49ms +step:1255/1660 train_time:116081ms step_avg:92.49ms +step:1256/1660 train_time:116173ms step_avg:92.49ms +step:1257/1660 train_time:116266ms step_avg:92.49ms +step:1258/1660 train_time:116358ms step_avg:92.49ms +step:1259/1660 train_time:116451ms step_avg:92.49ms +step:1260/1660 train_time:116542ms step_avg:92.49ms +step:1261/1660 train_time:116636ms step_avg:92.50ms +step:1262/1660 train_time:116732ms step_avg:92.50ms +step:1263/1660 train_time:116829ms step_avg:92.50ms +step:1264/1660 train_time:116924ms step_avg:92.50ms +step:1265/1660 train_time:117017ms step_avg:92.50ms +step:1266/1660 train_time:117110ms step_avg:92.50ms +step:1267/1660 train_time:117203ms step_avg:92.50ms +step:1268/1660 train_time:117296ms step_avg:92.50ms +step:1269/1660 train_time:117388ms step_avg:92.50ms +step:1270/1660 train_time:117480ms step_avg:92.50ms +step:1271/1660 train_time:117574ms step_avg:92.50ms +step:1272/1660 train_time:117667ms step_avg:92.51ms +step:1273/1660 train_time:117762ms step_avg:92.51ms +step:1274/1660 train_time:117857ms step_avg:92.51ms +step:1275/1660 train_time:117951ms step_avg:92.51ms +step:1276/1660 train_time:118044ms step_avg:92.51ms +step:1277/1660 train_time:118138ms step_avg:92.51ms +step:1278/1660 train_time:118233ms step_avg:92.51ms +step:1279/1660 train_time:118326ms step_avg:92.51ms +step:1280/1660 train_time:118418ms step_avg:92.51ms +step:1281/1660 train_time:118511ms step_avg:92.51ms +step:1282/1660 train_time:118604ms step_avg:92.51ms +step:1283/1660 train_time:118698ms 
step_avg:92.52ms +step:1284/1660 train_time:118795ms step_avg:92.52ms +step:1285/1660 train_time:118889ms step_avg:92.52ms +step:1286/1660 train_time:118983ms step_avg:92.52ms +step:1287/1660 train_time:119075ms step_avg:92.52ms +step:1288/1660 train_time:119169ms step_avg:92.52ms +step:1289/1660 train_time:119262ms step_avg:92.52ms +step:1290/1660 train_time:119356ms step_avg:92.52ms +step:1291/1660 train_time:119450ms step_avg:92.53ms +step:1292/1660 train_time:119543ms step_avg:92.53ms +step:1293/1660 train_time:119636ms step_avg:92.53ms +step:1294/1660 train_time:119730ms step_avg:92.53ms +step:1295/1660 train_time:119825ms step_avg:92.53ms +step:1296/1660 train_time:119918ms step_avg:92.53ms +step:1297/1660 train_time:120013ms step_avg:92.53ms +step:1298/1660 train_time:120106ms step_avg:92.53ms +step:1299/1660 train_time:120199ms step_avg:92.53ms +step:1300/1660 train_time:120293ms step_avg:92.53ms +step:1301/1660 train_time:120385ms step_avg:92.53ms +step:1302/1660 train_time:120478ms step_avg:92.53ms +step:1303/1660 train_time:120572ms step_avg:92.53ms +step:1304/1660 train_time:120666ms step_avg:92.53ms +step:1305/1660 train_time:120760ms step_avg:92.54ms +step:1306/1660 train_time:120854ms step_avg:92.54ms +step:1307/1660 train_time:120947ms step_avg:92.54ms +step:1308/1660 train_time:121040ms step_avg:92.54ms +step:1309/1660 train_time:121135ms step_avg:92.54ms +step:1310/1660 train_time:121228ms step_avg:92.54ms +step:1311/1660 train_time:121322ms step_avg:92.54ms +step:1312/1660 train_time:121414ms step_avg:92.54ms +step:1313/1660 train_time:121508ms step_avg:92.54ms +step:1314/1660 train_time:121601ms step_avg:92.54ms +step:1315/1660 train_time:121694ms step_avg:92.54ms +step:1316/1660 train_time:121789ms step_avg:92.54ms +step:1317/1660 train_time:121881ms step_avg:92.54ms +step:1318/1660 train_time:121975ms step_avg:92.55ms +step:1319/1660 train_time:122069ms step_avg:92.55ms +step:1320/1660 train_time:122163ms step_avg:92.55ms +step:1321/1660 train_time:122257ms step_avg:92.55ms +step:1322/1660 train_time:122350ms step_avg:92.55ms +step:1323/1660 train_time:122443ms step_avg:92.55ms +step:1324/1660 train_time:122537ms step_avg:92.55ms +step:1325/1660 train_time:122631ms step_avg:92.55ms +step:1326/1660 train_time:122725ms step_avg:92.55ms +step:1327/1660 train_time:122818ms step_avg:92.55ms +step:1328/1660 train_time:122912ms step_avg:92.55ms +step:1329/1660 train_time:123005ms step_avg:92.55ms +step:1330/1660 train_time:123099ms step_avg:92.56ms +step:1331/1660 train_time:123194ms step_avg:92.56ms +step:1332/1660 train_time:123288ms step_avg:92.56ms +step:1333/1660 train_time:123381ms step_avg:92.56ms +step:1334/1660 train_time:123474ms step_avg:92.56ms +step:1335/1660 train_time:123568ms step_avg:92.56ms +step:1336/1660 train_time:123661ms step_avg:92.56ms +step:1337/1660 train_time:123755ms step_avg:92.56ms +step:1338/1660 train_time:123848ms step_avg:92.56ms +step:1339/1660 train_time:123941ms step_avg:92.56ms +step:1340/1660 train_time:124036ms step_avg:92.56ms +step:1341/1660 train_time:124131ms step_avg:92.57ms +step:1342/1660 train_time:124224ms step_avg:92.57ms +step:1343/1660 train_time:124317ms step_avg:92.57ms +step:1344/1660 train_time:124411ms step_avg:92.57ms +step:1345/1660 train_time:124505ms step_avg:92.57ms +step:1346/1660 train_time:124598ms step_avg:92.57ms +step:1347/1660 train_time:124692ms step_avg:92.57ms +step:1348/1660 train_time:124786ms step_avg:92.57ms +step:1349/1660 train_time:124879ms step_avg:92.57ms +step:1350/1660 train_time:124972ms 
step_avg:92.57ms +step:1351/1660 train_time:125067ms step_avg:92.57ms +step:1352/1660 train_time:125161ms step_avg:92.57ms +step:1353/1660 train_time:125256ms step_avg:92.58ms +step:1354/1660 train_time:125349ms step_avg:92.58ms +step:1355/1660 train_time:125442ms step_avg:92.58ms +step:1356/1660 train_time:125536ms step_avg:92.58ms +step:1357/1660 train_time:125631ms step_avg:92.58ms +step:1358/1660 train_time:125725ms step_avg:92.58ms +step:1359/1660 train_time:125818ms step_avg:92.58ms +step:1360/1660 train_time:125912ms step_avg:92.58ms +step:1361/1660 train_time:126006ms step_avg:92.58ms +step:1362/1660 train_time:126099ms step_avg:92.58ms +step:1363/1660 train_time:126194ms step_avg:92.59ms +step:1364/1660 train_time:126287ms step_avg:92.59ms +step:1365/1660 train_time:126380ms step_avg:92.59ms +step:1366/1660 train_time:126473ms step_avg:92.59ms +step:1367/1660 train_time:126568ms step_avg:92.59ms +step:1368/1660 train_time:126661ms step_avg:92.59ms +step:1369/1660 train_time:126755ms step_avg:92.59ms +step:1370/1660 train_time:126847ms step_avg:92.59ms +step:1371/1660 train_time:126940ms step_avg:92.59ms +step:1372/1660 train_time:127035ms step_avg:92.59ms +step:1373/1660 train_time:127129ms step_avg:92.59ms +step:1374/1660 train_time:127223ms step_avg:92.59ms +step:1375/1660 train_time:127316ms step_avg:92.59ms +step:1375/1660 val_loss:3.3407 train_time:127411ms step_avg:92.66ms +step:1376/1660 train_time:127433ms step_avg:92.61ms +step:1377/1660 train_time:127509ms step_avg:92.60ms +step:1378/1660 train_time:127606ms step_avg:92.60ms +step:1379/1660 train_time:127701ms step_avg:92.60ms +step:1380/1660 train_time:127793ms step_avg:92.60ms +step:1381/1660 train_time:127886ms step_avg:92.60ms +step:1382/1660 train_time:127978ms step_avg:92.60ms +step:1383/1660 train_time:128070ms step_avg:92.60ms +step:1384/1660 train_time:128163ms step_avg:92.60ms +step:1385/1660 train_time:128257ms step_avg:92.60ms +step:1386/1660 train_time:128352ms step_avg:92.61ms +step:1387/1660 train_time:128448ms step_avg:92.61ms +step:1388/1660 train_time:128545ms step_avg:92.61ms +step:1389/1660 train_time:128642ms step_avg:92.61ms +step:1390/1660 train_time:128737ms step_avg:92.62ms +step:1391/1660 train_time:128829ms step_avg:92.62ms +step:1392/1660 train_time:128922ms step_avg:92.62ms +step:1393/1660 train_time:129015ms step_avg:92.62ms +step:1394/1660 train_time:129107ms step_avg:92.62ms +step:1395/1660 train_time:129200ms step_avg:92.62ms +step:1396/1660 train_time:129293ms step_avg:92.62ms +step:1397/1660 train_time:129387ms step_avg:92.62ms +step:1398/1660 train_time:129481ms step_avg:92.62ms +step:1399/1660 train_time:129575ms step_avg:92.62ms +step:1400/1660 train_time:129671ms step_avg:92.62ms +step:1401/1660 train_time:129765ms step_avg:92.62ms +step:1402/1660 train_time:129859ms step_avg:92.62ms +step:1403/1660 train_time:129952ms step_avg:92.62ms +step:1404/1660 train_time:130044ms step_avg:92.62ms +step:1405/1660 train_time:130137ms step_avg:92.62ms +step:1406/1660 train_time:130229ms step_avg:92.62ms +step:1407/1660 train_time:130322ms step_avg:92.62ms +step:1408/1660 train_time:130416ms step_avg:92.63ms +step:1409/1660 train_time:130509ms step_avg:92.63ms +step:1410/1660 train_time:130605ms step_avg:92.63ms +step:1411/1660 train_time:130699ms step_avg:92.63ms +step:1412/1660 train_time:130792ms step_avg:92.63ms +step:1413/1660 train_time:130886ms step_avg:92.63ms +step:1414/1660 train_time:130980ms step_avg:92.63ms +step:1415/1660 train_time:131073ms step_avg:92.63ms +step:1416/1660 
train_time:131165ms step_avg:92.63ms +step:1417/1660 train_time:131258ms step_avg:92.63ms +step:1418/1660 train_time:131352ms step_avg:92.63ms +step:1419/1660 train_time:131444ms step_avg:92.63ms +step:1420/1660 train_time:131538ms step_avg:92.63ms +step:1421/1660 train_time:131632ms step_avg:92.63ms +step:1422/1660 train_time:131726ms step_avg:92.63ms +step:1423/1660 train_time:131819ms step_avg:92.63ms +step:1424/1660 train_time:131912ms step_avg:92.64ms +step:1425/1660 train_time:132005ms step_avg:92.64ms +step:1426/1660 train_time:132098ms step_avg:92.64ms +step:1427/1660 train_time:132192ms step_avg:92.64ms +step:1428/1660 train_time:132285ms step_avg:92.64ms +step:1429/1660 train_time:132379ms step_avg:92.64ms +step:1430/1660 train_time:132472ms step_avg:92.64ms +step:1431/1660 train_time:132566ms step_avg:92.64ms +step:1432/1660 train_time:132662ms step_avg:92.64ms +step:1433/1660 train_time:132756ms step_avg:92.64ms +step:1434/1660 train_time:132849ms step_avg:92.64ms +step:1435/1660 train_time:132943ms step_avg:92.64ms +step:1436/1660 train_time:133036ms step_avg:92.64ms +step:1437/1660 train_time:133129ms step_avg:92.64ms +step:1438/1660 train_time:133222ms step_avg:92.64ms +step:1439/1660 train_time:133315ms step_avg:92.64ms +step:1440/1660 train_time:133408ms step_avg:92.64ms +step:1441/1660 train_time:133502ms step_avg:92.65ms +step:1442/1660 train_time:133595ms step_avg:92.65ms +step:1443/1660 train_time:133689ms step_avg:92.65ms +step:1444/1660 train_time:133783ms step_avg:92.65ms +step:1445/1660 train_time:133876ms step_avg:92.65ms +step:1446/1660 train_time:133969ms step_avg:92.65ms +step:1447/1660 train_time:134063ms step_avg:92.65ms +step:1448/1660 train_time:134158ms step_avg:92.65ms +step:1449/1660 train_time:134252ms step_avg:92.65ms +step:1450/1660 train_time:134345ms step_avg:92.65ms +step:1451/1660 train_time:134439ms step_avg:92.65ms +step:1452/1660 train_time:134532ms step_avg:92.65ms +step:1453/1660 train_time:134626ms step_avg:92.65ms +step:1454/1660 train_time:134719ms step_avg:92.65ms +step:1455/1660 train_time:134813ms step_avg:92.65ms +step:1456/1660 train_time:134906ms step_avg:92.66ms +step:1457/1660 train_time:134999ms step_avg:92.66ms +step:1458/1660 train_time:135093ms step_avg:92.66ms +step:1459/1660 train_time:135187ms step_avg:92.66ms +step:1460/1660 train_time:135281ms step_avg:92.66ms +step:1461/1660 train_time:135374ms step_avg:92.66ms +step:1462/1660 train_time:135468ms step_avg:92.66ms +step:1463/1660 train_time:135561ms step_avg:92.66ms +step:1464/1660 train_time:135655ms step_avg:92.66ms +step:1465/1660 train_time:135749ms step_avg:92.66ms +step:1466/1660 train_time:135842ms step_avg:92.66ms +step:1467/1660 train_time:135935ms step_avg:92.66ms +step:1468/1660 train_time:136029ms step_avg:92.66ms +step:1469/1660 train_time:136123ms step_avg:92.66ms +step:1470/1660 train_time:136216ms step_avg:92.66ms +step:1471/1660 train_time:136309ms step_avg:92.66ms +step:1472/1660 train_time:136403ms step_avg:92.66ms +step:1473/1660 train_time:136497ms step_avg:92.67ms +step:1474/1660 train_time:136590ms step_avg:92.67ms +step:1475/1660 train_time:136684ms step_avg:92.67ms +step:1476/1660 train_time:136778ms step_avg:92.67ms +step:1477/1660 train_time:136872ms step_avg:92.67ms +step:1478/1660 train_time:136966ms step_avg:92.67ms +step:1479/1660 train_time:137061ms step_avg:92.67ms +step:1480/1660 train_time:137155ms step_avg:92.67ms +step:1481/1660 train_time:137248ms step_avg:92.67ms +step:1482/1660 train_time:137341ms step_avg:92.67ms +step:1483/1660 
train_time:137434ms step_avg:92.67ms +step:1484/1660 train_time:137527ms step_avg:92.67ms +step:1485/1660 train_time:137621ms step_avg:92.67ms +step:1486/1660 train_time:137714ms step_avg:92.67ms +step:1487/1660 train_time:137807ms step_avg:92.67ms +step:1488/1660 train_time:137901ms step_avg:92.68ms +step:1489/1660 train_time:137995ms step_avg:92.68ms +step:1490/1660 train_time:138089ms step_avg:92.68ms +step:1491/1660 train_time:138183ms step_avg:92.68ms +step:1492/1660 train_time:138276ms step_avg:92.68ms +step:1493/1660 train_time:138369ms step_avg:92.68ms +step:1494/1660 train_time:138463ms step_avg:92.68ms +step:1495/1660 train_time:138557ms step_avg:92.68ms +step:1496/1660 train_time:138649ms step_avg:92.68ms +step:1497/1660 train_time:138743ms step_avg:92.68ms +step:1498/1660 train_time:138836ms step_avg:92.68ms +step:1499/1660 train_time:138930ms step_avg:92.68ms +step:1500/1660 train_time:139024ms step_avg:92.68ms +step:1500/1660 val_loss:3.3105 train_time:139120ms step_avg:92.75ms +step:1501/1660 train_time:139141ms step_avg:92.70ms +step:1502/1660 train_time:139220ms step_avg:92.69ms +step:1503/1660 train_time:139319ms step_avg:92.69ms +step:1504/1660 train_time:139413ms step_avg:92.69ms +step:1505/1660 train_time:139506ms step_avg:92.69ms +step:1506/1660 train_time:139599ms step_avg:92.69ms +step:1507/1660 train_time:139691ms step_avg:92.69ms +step:1508/1660 train_time:139783ms step_avg:92.69ms +step:1509/1660 train_time:139875ms step_avg:92.69ms +step:1510/1660 train_time:139967ms step_avg:92.69ms +step:1511/1660 train_time:140061ms step_avg:92.69ms +step:1512/1660 train_time:140157ms step_avg:92.70ms +step:1513/1660 train_time:140251ms step_avg:92.70ms +step:1514/1660 train_time:140347ms step_avg:92.70ms +step:1515/1660 train_time:140441ms step_avg:92.70ms +step:1516/1660 train_time:140534ms step_avg:92.70ms +step:1517/1660 train_time:140626ms step_avg:92.70ms +step:1518/1660 train_time:140720ms step_avg:92.70ms +step:1519/1660 train_time:140814ms step_avg:92.70ms +step:1520/1660 train_time:140906ms step_avg:92.70ms +step:1521/1660 train_time:140998ms step_avg:92.70ms +step:1522/1660 train_time:141092ms step_avg:92.70ms +step:1523/1660 train_time:141186ms step_avg:92.70ms +step:1524/1660 train_time:141283ms step_avg:92.71ms +step:1525/1660 train_time:141379ms step_avg:92.71ms +step:1526/1660 train_time:141473ms step_avg:92.71ms +step:1527/1660 train_time:141566ms step_avg:92.71ms +step:1528/1660 train_time:141660ms step_avg:92.71ms +step:1529/1660 train_time:141754ms step_avg:92.71ms +step:1530/1660 train_time:141847ms step_avg:92.71ms +step:1531/1660 train_time:141939ms step_avg:92.71ms +step:1532/1660 train_time:142032ms step_avg:92.71ms +step:1533/1660 train_time:142125ms step_avg:92.71ms +step:1534/1660 train_time:142219ms step_avg:92.71ms +step:1535/1660 train_time:142314ms step_avg:92.71ms +step:1536/1660 train_time:142408ms step_avg:92.71ms +step:1537/1660 train_time:142502ms step_avg:92.71ms +step:1538/1660 train_time:142595ms step_avg:92.71ms +step:1539/1660 train_time:142688ms step_avg:92.71ms +step:1540/1660 train_time:142782ms step_avg:92.72ms +step:1541/1660 train_time:142875ms step_avg:92.72ms +step:1542/1660 train_time:142968ms step_avg:92.72ms +step:1543/1660 train_time:143061ms step_avg:92.72ms +step:1544/1660 train_time:143154ms step_avg:92.72ms +step:1545/1660 train_time:143248ms step_avg:92.72ms +step:1546/1660 train_time:143342ms step_avg:92.72ms +step:1547/1660 train_time:143436ms step_avg:92.72ms +step:1548/1660 train_time:143529ms step_avg:92.72ms 
+step:1549/1660 train_time:143622ms step_avg:92.72ms +step:1550/1660 train_time:143715ms step_avg:92.72ms +step:1551/1660 train_time:143809ms step_avg:92.72ms +step:1552/1660 train_time:143902ms step_avg:92.72ms +step:1553/1660 train_time:143995ms step_avg:92.72ms +step:1554/1660 train_time:144088ms step_avg:92.72ms +step:1555/1660 train_time:144183ms step_avg:92.72ms +step:1556/1660 train_time:144279ms step_avg:92.72ms +step:1557/1660 train_time:144373ms step_avg:92.73ms +step:1558/1660 train_time:144467ms step_avg:92.73ms +step:1559/1660 train_time:144563ms step_avg:92.73ms +step:1560/1660 train_time:144657ms step_avg:92.73ms +step:1561/1660 train_time:144751ms step_avg:92.73ms +step:1562/1660 train_time:144845ms step_avg:92.73ms +step:1563/1660 train_time:144938ms step_avg:92.73ms +step:1564/1660 train_time:145031ms step_avg:92.73ms +step:1565/1660 train_time:145124ms step_avg:92.73ms +step:1566/1660 train_time:145219ms step_avg:92.73ms +step:1567/1660 train_time:145314ms step_avg:92.73ms +step:1568/1660 train_time:145407ms step_avg:92.73ms +step:1569/1660 train_time:145502ms step_avg:92.74ms +step:1570/1660 train_time:145595ms step_avg:92.74ms +step:1571/1660 train_time:145688ms step_avg:92.74ms +step:1572/1660 train_time:145783ms step_avg:92.74ms +step:1573/1660 train_time:145877ms step_avg:92.74ms +step:1574/1660 train_time:145970ms step_avg:92.74ms +step:1575/1660 train_time:146063ms step_avg:92.74ms +step:1576/1660 train_time:146156ms step_avg:92.74ms +step:1577/1660 train_time:146249ms step_avg:92.74ms +step:1578/1660 train_time:146343ms step_avg:92.74ms +step:1579/1660 train_time:146437ms step_avg:92.74ms +step:1580/1660 train_time:146531ms step_avg:92.74ms +step:1581/1660 train_time:146624ms step_avg:92.74ms +step:1582/1660 train_time:146717ms step_avg:92.74ms +step:1583/1660 train_time:146810ms step_avg:92.74ms +step:1584/1660 train_time:146903ms step_avg:92.74ms +step:1585/1660 train_time:146997ms step_avg:92.74ms +step:1586/1660 train_time:147090ms step_avg:92.74ms +step:1587/1660 train_time:147183ms step_avg:92.74ms +step:1588/1660 train_time:147277ms step_avg:92.74ms +step:1589/1660 train_time:147371ms step_avg:92.74ms +step:1590/1660 train_time:147464ms step_avg:92.74ms +step:1591/1660 train_time:147558ms step_avg:92.75ms +step:1592/1660 train_time:147651ms step_avg:92.75ms +step:1593/1660 train_time:147745ms step_avg:92.75ms +step:1594/1660 train_time:147839ms step_avg:92.75ms +step:1595/1660 train_time:147932ms step_avg:92.75ms +step:1596/1660 train_time:148025ms step_avg:92.75ms +step:1597/1660 train_time:148118ms step_avg:92.75ms +step:1598/1660 train_time:148212ms step_avg:92.75ms +step:1599/1660 train_time:148305ms step_avg:92.75ms +step:1600/1660 train_time:148401ms step_avg:92.75ms +step:1601/1660 train_time:148494ms step_avg:92.75ms +step:1602/1660 train_time:148588ms step_avg:92.75ms +step:1603/1660 train_time:148681ms step_avg:92.75ms +step:1604/1660 train_time:148774ms step_avg:92.75ms +step:1605/1660 train_time:148868ms step_avg:92.75ms +step:1606/1660 train_time:148962ms step_avg:92.75ms +step:1607/1660 train_time:149054ms step_avg:92.75ms +step:1608/1660 train_time:149148ms step_avg:92.75ms +step:1609/1660 train_time:149242ms step_avg:92.75ms +step:1610/1660 train_time:149336ms step_avg:92.75ms +step:1611/1660 train_time:149429ms step_avg:92.76ms +step:1612/1660 train_time:149522ms step_avg:92.76ms +step:1613/1660 train_time:149617ms step_avg:92.76ms +step:1614/1660 train_time:149710ms step_avg:92.76ms +step:1615/1660 train_time:149804ms step_avg:92.76ms 
+step:1616/1660 train_time:149898ms step_avg:92.76ms +step:1617/1660 train_time:149992ms step_avg:92.76ms +step:1618/1660 train_time:150085ms step_avg:92.76ms +step:1619/1660 train_time:150178ms step_avg:92.76ms +step:1620/1660 train_time:150271ms step_avg:92.76ms +step:1621/1660 train_time:150364ms step_avg:92.76ms +step:1622/1660 train_time:150459ms step_avg:92.76ms +step:1623/1660 train_time:150553ms step_avg:92.76ms +step:1624/1660 train_time:150647ms step_avg:92.76ms +step:1625/1660 train_time:150742ms step_avg:92.76ms +step:1625/1660 val_loss:3.2860 train_time:150836ms step_avg:92.82ms +step:1626/1660 train_time:150857ms step_avg:92.78ms +step:1627/1660 train_time:150934ms step_avg:92.77ms +step:1628/1660 train_time:151032ms step_avg:92.77ms +step:1629/1660 train_time:151126ms step_avg:92.77ms +step:1630/1660 train_time:151219ms step_avg:92.77ms +step:1631/1660 train_time:151312ms step_avg:92.77ms +step:1632/1660 train_time:151404ms step_avg:92.77ms +step:1633/1660 train_time:151496ms step_avg:92.77ms +step:1634/1660 train_time:151588ms step_avg:92.77ms +step:1635/1660 train_time:151681ms step_avg:92.77ms +step:1636/1660 train_time:151774ms step_avg:92.77ms +step:1637/1660 train_time:151870ms step_avg:92.77ms +step:1638/1660 train_time:151968ms step_avg:92.78ms +step:1639/1660 train_time:152064ms step_avg:92.78ms +step:1640/1660 train_time:152159ms step_avg:92.78ms +step:1641/1660 train_time:152252ms step_avg:92.78ms +step:1642/1660 train_time:152345ms step_avg:92.78ms +step:1643/1660 train_time:152438ms step_avg:92.78ms +step:1644/1660 train_time:152530ms step_avg:92.78ms +step:1645/1660 train_time:152623ms step_avg:92.78ms +step:1646/1660 train_time:152716ms step_avg:92.78ms +step:1647/1660 train_time:152809ms step_avg:92.78ms +step:1648/1660 train_time:152905ms step_avg:92.78ms +step:1649/1660 train_time:153002ms step_avg:92.78ms +step:1650/1660 train_time:153097ms step_avg:92.79ms +step:1651/1660 train_time:153190ms step_avg:92.79ms +step:1652/1660 train_time:153282ms step_avg:92.79ms +step:1653/1660 train_time:153375ms step_avg:92.79ms +step:1654/1660 train_time:153468ms step_avg:92.79ms +step:1655/1660 train_time:153561ms step_avg:92.79ms +step:1656/1660 train_time:153654ms step_avg:92.79ms +step:1657/1660 train_time:153746ms step_avg:92.79ms +step:1658/1660 train_time:153840ms step_avg:92.79ms +step:1659/1660 train_time:153936ms step_avg:92.79ms +step:1660/1660 train_time:154030ms step_avg:92.79ms +step:1660/1660 val_loss:3.2778 train_time:154126ms step_avg:92.85ms +peak memory allocated: 31587 MiB reserved: 46516 MiB diff --git a/records/091525_ThreadingFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt b/records/091525_ThreadingFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt new file mode 100644 index 000000000..5afdd3e3a --- /dev/null +++ b/records/091525_ThreadingFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True 
# we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + 
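+    # One flat 1D launch grid: batch * ceil(M/BM) * ceil(M/BN) programs. The kernel
+    # decodes each program id back into (batch, row-block, col-block) coordinates,
+    # skips tiles strictly on the redundant side of the diagonal (C = A @ A.T is
+    # symmetric), and stores each computed tile twice: in place and mirrored.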
ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) 
== M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
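+        # The update below is a three-phase sharded pipeline:
+        #   1) per-group async reduce_scatter averages gradients across ranks and
+        #      leaves each rank holding one contiguous chunk of the stacked gradients;
+        #   2) each rank applies momentum and a batched Newton-Schulz orthogonalization
+        #      to its chunk only, writing updated parameters into a staging buffer;
+        #   3) async all_gather redistributes the updated chunks so every rank finishes
+        #      the step with identical parameters.
+        # With W ranks, each rank does ~1/W of the Newton-Schulz work, at the cost of
+        # two collectives per parameter group.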
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
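+                # The two lerps below implement Nesterov-style momentum:
+                #   buffer <- momentum * buffer + (1 - momentum) * grad
+                #   update <- (1 - momentum) * grad + momentum * buffer
+                # so the orthogonalized direction blends the fresh gradient with the
+                # already-updated buffer instead of using the buffer alone.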
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
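+        # The copies below use non_blocking=True, so copying one group's gathered
+        # parameters back can overlap with waiting on the next group's gather.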
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    def __init__(self, file_iter, world_size):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence (BOS) token; sequences are truncated to max_seq_len
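+    # Generator protocol: yields (inputs, targets, cum_lengths) tensors already moved
+    # to the GPU, and accepts a (num_tokens, max_seq_len, grad_accum_steps) tuple via
+    # .send() so the batch shape can change mid-run without rebuilding the loader.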
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        # Scanning a whole shard for BOS tokens up front is slow, so BOSFinder starts
+        # from a quick partial scan (it tracks its own self.i index) and finishes the
+        # scan on a background thread; the next shard is likewise preloaded asynchronously.
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; swap in the preloaded shard and kick off the next preload.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1660 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"data_threading_1660/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer is highly indifferent to length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + 
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:40:36 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 40C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 192275 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 192276 C /usr/bin/python3 614MiB | +| 0 N/A N/A 192277 C /usr/bin/python3 614MiB | +| 0 N/A N/A 192278 C /usr/bin/python3 614MiB | +| 0 N/A N/A 192279 C /usr/bin/python3 614MiB | +| 0 N/A N/A 192280 C /usr/bin/python3 614MiB | +| 0 N/A N/A 192281 C /usr/bin/python3 614MiB | +| 0 N/A N/A 192282 C /usr/bin/python3 614MiB | +| 1 N/A N/A 192276 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 192277 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 192278 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 192279 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 192280 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 192281 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 192282 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:157ms step_avg:156.84ms +step:2/1660 train_time:178ms step_avg:88.98ms +step:3/1660 train_time:244ms step_avg:81.45ms +step:4/1660 train_time:333ms step_avg:83.34ms +step:5/1660 train_time:423ms step_avg:84.63ms +step:6/1660 train_time:513ms step_avg:85.58ms +step:7/1660 train_time:604ms step_avg:86.24ms +step:8/1660 train_time:694ms step_avg:86.77ms +step:9/1660 train_time:784ms step_avg:87.11ms +step:10/1660 train_time:875ms step_avg:87.47ms +step:11/1660 train_time:965ms step_avg:87.75ms +step:12/1660 train_time:1058ms step_avg:88.19ms +step:13/1660 train_time:1154ms step_avg:88.74ms +step:14/1660 train_time:1247ms step_avg:89.06ms +step:15/1660 train_time:1340ms step_avg:89.34ms +step:16/1660 train_time:1432ms step_avg:89.52ms +step:17/1660 train_time:1523ms step_avg:89.60ms +step:18/1660 train_time:1615ms step_avg:89.72ms +step:19/1660 train_time:1706ms step_avg:89.78ms +step:20/1660 train_time:1797ms step_avg:89.85ms +step:21/1660 train_time:1887ms step_avg:89.87ms +step:22/1660 train_time:1979ms step_avg:89.95ms +step:23/1660 train_time:2072ms step_avg:90.08ms +step:24/1660 train_time:2165ms step_avg:90.20ms +step:25/1660 train_time:2259ms step_avg:90.36ms +step:26/1660 train_time:2352ms step_avg:90.46ms +step:27/1660 train_time:2443ms step_avg:90.48ms +step:28/1660 train_time:2534ms step_avg:90.50ms +step:29/1660 train_time:2625ms step_avg:90.51ms +step:30/1660 train_time:2716ms step_avg:90.53ms +step:31/1660 train_time:2806ms step_avg:90.51ms +step:32/1660 train_time:2897ms step_avg:90.53ms +step:33/1660 train_time:2988ms step_avg:90.54ms +step:34/1660 train_time:3080ms step_avg:90.59ms +step:35/1660 train_time:3173ms step_avg:90.65ms +step:36/1660 train_time:3266ms step_avg:90.72ms +step:37/1660 train_time:3360ms step_avg:90.80ms +step:38/1660 train_time:3452ms step_avg:90.84ms +step:39/1660 train_time:3543ms step_avg:90.84ms +step:40/1660 train_time:3634ms step_avg:90.85ms +step:41/1660 train_time:3725ms step_avg:90.86ms +step:42/1660 train_time:3817ms step_avg:90.87ms +step:43/1660 train_time:3908ms step_avg:90.88ms +step:44/1660 train_time:4000ms step_avg:90.91ms +step:45/1660 train_time:4091ms step_avg:90.92ms +step:46/1660 train_time:4183ms step_avg:90.94ms +step:47/1660 train_time:4276ms step_avg:90.98ms +step:48/1660 train_time:4369ms step_avg:91.01ms +step:49/1660 train_time:4461ms step_avg:91.04ms +step:50/1660 train_time:4553ms step_avg:91.07ms 
+step:51/1660 train_time:4645ms step_avg:91.08ms +step:52/1660 train_time:4737ms step_avg:91.09ms +step:53/1660 train_time:4827ms step_avg:91.07ms +step:54/1660 train_time:4918ms step_avg:91.07ms +step:55/1660 train_time:5010ms step_avg:91.08ms +step:56/1660 train_time:5101ms step_avg:91.09ms +step:57/1660 train_time:5193ms step_avg:91.10ms +step:58/1660 train_time:5284ms step_avg:91.11ms +step:59/1660 train_time:5377ms step_avg:91.14ms +step:60/1660 train_time:5471ms step_avg:91.18ms +step:61/1660 train_time:5563ms step_avg:91.19ms +step:62/1660 train_time:5656ms step_avg:91.23ms +step:63/1660 train_time:5747ms step_avg:91.23ms +step:64/1660 train_time:5839ms step_avg:91.23ms +step:65/1660 train_time:5931ms step_avg:91.25ms +step:66/1660 train_time:6022ms step_avg:91.24ms +step:67/1660 train_time:6114ms step_avg:91.25ms +step:68/1660 train_time:6204ms step_avg:91.23ms +step:69/1660 train_time:6295ms step_avg:91.24ms +step:70/1660 train_time:6387ms step_avg:91.24ms +step:71/1660 train_time:6480ms step_avg:91.27ms +step:72/1660 train_time:6572ms step_avg:91.28ms +step:73/1660 train_time:6664ms step_avg:91.29ms +step:74/1660 train_time:6757ms step_avg:91.31ms +step:75/1660 train_time:6849ms step_avg:91.32ms +step:76/1660 train_time:6940ms step_avg:91.32ms +step:77/1660 train_time:7032ms step_avg:91.32ms +step:78/1660 train_time:7122ms step_avg:91.31ms +step:79/1660 train_time:7213ms step_avg:91.31ms +step:80/1660 train_time:7304ms step_avg:91.30ms +step:81/1660 train_time:7396ms step_avg:91.31ms +step:82/1660 train_time:7487ms step_avg:91.31ms +step:83/1660 train_time:7580ms step_avg:91.33ms +step:84/1660 train_time:7673ms step_avg:91.34ms +step:85/1660 train_time:7764ms step_avg:91.34ms +step:86/1660 train_time:7855ms step_avg:91.34ms +step:87/1660 train_time:7948ms step_avg:91.35ms +step:88/1660 train_time:8040ms step_avg:91.36ms +step:89/1660 train_time:8132ms step_avg:91.38ms +step:90/1660 train_time:8223ms step_avg:91.37ms +step:91/1660 train_time:8314ms step_avg:91.37ms +step:92/1660 train_time:8405ms step_avg:91.36ms +step:93/1660 train_time:8497ms step_avg:91.37ms +step:94/1660 train_time:8589ms step_avg:91.37ms +step:95/1660 train_time:8682ms step_avg:91.39ms +step:96/1660 train_time:8774ms step_avg:91.40ms +step:97/1660 train_time:8865ms step_avg:91.39ms +step:98/1660 train_time:8957ms step_avg:91.40ms +step:99/1660 train_time:9049ms step_avg:91.40ms +step:100/1660 train_time:9141ms step_avg:91.41ms +step:101/1660 train_time:9232ms step_avg:91.41ms +step:102/1660 train_time:9322ms step_avg:91.40ms +step:103/1660 train_time:9414ms step_avg:91.40ms +step:104/1660 train_time:9505ms step_avg:91.39ms +step:105/1660 train_time:9597ms step_avg:91.40ms +step:106/1660 train_time:9689ms step_avg:91.40ms +step:107/1660 train_time:9782ms step_avg:91.42ms +step:108/1660 train_time:9874ms step_avg:91.43ms +step:109/1660 train_time:9965ms step_avg:91.42ms +step:110/1660 train_time:10057ms step_avg:91.43ms +step:111/1660 train_time:10150ms step_avg:91.44ms +step:112/1660 train_time:10241ms step_avg:91.44ms +step:113/1660 train_time:10332ms step_avg:91.43ms +step:114/1660 train_time:10422ms step_avg:91.42ms +step:115/1660 train_time:10514ms step_avg:91.43ms +step:116/1660 train_time:10605ms step_avg:91.42ms +step:117/1660 train_time:10697ms step_avg:91.43ms +step:118/1660 train_time:10789ms step_avg:91.43ms +step:119/1660 train_time:10881ms step_avg:91.44ms +step:120/1660 train_time:10974ms step_avg:91.45ms +step:121/1660 train_time:11065ms step_avg:91.44ms +step:122/1660 train_time:11156ms 
step_avg:91.44ms +step:123/1660 train_time:11248ms step_avg:91.45ms +step:124/1660 train_time:11340ms step_avg:91.45ms +step:125/1660 train_time:11431ms step_avg:91.45ms +step:125/1660 val_loss:4.3019 train_time:11523ms step_avg:92.19ms +step:126/1660 train_time:11546ms step_avg:91.63ms +step:127/1660 train_time:11618ms step_avg:91.48ms +step:128/1660 train_time:11718ms step_avg:91.55ms +step:129/1660 train_time:11811ms step_avg:91.56ms +step:130/1660 train_time:11903ms step_avg:91.56ms +step:131/1660 train_time:11994ms step_avg:91.55ms +step:132/1660 train_time:12084ms step_avg:91.55ms +step:133/1660 train_time:12175ms step_avg:91.54ms +step:134/1660 train_time:12266ms step_avg:91.53ms +step:135/1660 train_time:12356ms step_avg:91.52ms +step:136/1660 train_time:12446ms step_avg:91.51ms +step:137/1660 train_time:12537ms step_avg:91.51ms +step:138/1660 train_time:12630ms step_avg:91.53ms +step:139/1660 train_time:12726ms step_avg:91.56ms +step:140/1660 train_time:12821ms step_avg:91.58ms +step:141/1660 train_time:12912ms step_avg:91.58ms +step:142/1660 train_time:13003ms step_avg:91.57ms +step:143/1660 train_time:13094ms step_avg:91.57ms +step:144/1660 train_time:13185ms step_avg:91.56ms +step:145/1660 train_time:13276ms step_avg:91.56ms +step:146/1660 train_time:13366ms step_avg:91.55ms +step:147/1660 train_time:13457ms step_avg:91.55ms +step:148/1660 train_time:13548ms step_avg:91.54ms +step:149/1660 train_time:13641ms step_avg:91.55ms +step:150/1660 train_time:13733ms step_avg:91.55ms +step:151/1660 train_time:13826ms step_avg:91.56ms +step:152/1660 train_time:13918ms step_avg:91.57ms +step:153/1660 train_time:14009ms step_avg:91.56ms +step:154/1660 train_time:14101ms step_avg:91.57ms +step:155/1660 train_time:14193ms step_avg:91.57ms +step:156/1660 train_time:14285ms step_avg:91.57ms +step:157/1660 train_time:14376ms step_avg:91.56ms +step:158/1660 train_time:14466ms step_avg:91.56ms +step:159/1660 train_time:14558ms step_avg:91.56ms +step:160/1660 train_time:14649ms step_avg:91.55ms +step:161/1660 train_time:14742ms step_avg:91.56ms +step:162/1660 train_time:14834ms step_avg:91.57ms +step:163/1660 train_time:14926ms step_avg:91.57ms +step:164/1660 train_time:15019ms step_avg:91.58ms +step:165/1660 train_time:15109ms step_avg:91.57ms +step:166/1660 train_time:15201ms step_avg:91.57ms +step:167/1660 train_time:15293ms step_avg:91.58ms +step:168/1660 train_time:15385ms step_avg:91.58ms +step:169/1660 train_time:15476ms step_avg:91.57ms +step:170/1660 train_time:15567ms step_avg:91.57ms +step:171/1660 train_time:15659ms step_avg:91.57ms +step:172/1660 train_time:15751ms step_avg:91.57ms +step:173/1660 train_time:15843ms step_avg:91.58ms +step:174/1660 train_time:15935ms step_avg:91.58ms +step:175/1660 train_time:16026ms step_avg:91.58ms +step:176/1660 train_time:16117ms step_avg:91.58ms +step:177/1660 train_time:16209ms step_avg:91.57ms +step:178/1660 train_time:16300ms step_avg:91.57ms +step:179/1660 train_time:16391ms step_avg:91.57ms +step:180/1660 train_time:16483ms step_avg:91.57ms +step:181/1660 train_time:16574ms step_avg:91.57ms +step:182/1660 train_time:16665ms step_avg:91.57ms +step:183/1660 train_time:16758ms step_avg:91.57ms +step:184/1660 train_time:16849ms step_avg:91.57ms +step:185/1660 train_time:16941ms step_avg:91.57ms +step:186/1660 train_time:17032ms step_avg:91.57ms +step:187/1660 train_time:17125ms step_avg:91.58ms +step:188/1660 train_time:17216ms step_avg:91.57ms +step:189/1660 train_time:17307ms step_avg:91.57ms +step:190/1660 train_time:17398ms step_avg:91.57ms 
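
The log lines in this record have a fixed shape: ordinary training steps report the cumulative train_time and a running step_avg, and every 125 steps (plus step 0) a val_loss field appears before train_time. A hypothetical parser for pulling these lines into structured form (the regex and function name are illustrative assumptions, not part of the record):

import re

STEP_RE = re.compile(
    r"step:(\d+)/(\d+)"
    r"(?: val_loss:([\d.]+))?"
    r" train_time:(\d+)ms step_avg:([\d.]+)ms"
)

def parse_steps(record_text: str):
    # Yields one tuple per "step:..." line; val_loss is None except on eval steps.
    for step, total, val, ms, avg in STEP_RE.findall(record_text):
        yield int(step), int(total), (float(val) if val else None), int(ms), float(avg)

# Example: collect the validation-loss curve from this record.
#   val_curve = [(s, v) for s, _, v, _, _ in parse_steps(text) if v is not None]
#   -> [(0, 10.8258), (125, 4.3019), (250, 3.9644), (375, 3.8148), ...]
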
+step:191/1660 train_time:17489ms step_avg:91.57ms +step:192/1660 train_time:17581ms step_avg:91.57ms +step:193/1660 train_time:17673ms step_avg:91.57ms +step:194/1660 train_time:17764ms step_avg:91.57ms +step:195/1660 train_time:17856ms step_avg:91.57ms +step:196/1660 train_time:17948ms step_avg:91.57ms +step:197/1660 train_time:18039ms step_avg:91.57ms +step:198/1660 train_time:18130ms step_avg:91.57ms +step:199/1660 train_time:18223ms step_avg:91.57ms +step:200/1660 train_time:18314ms step_avg:91.57ms +step:201/1660 train_time:18405ms step_avg:91.57ms +step:202/1660 train_time:18497ms step_avg:91.57ms +step:203/1660 train_time:18588ms step_avg:91.57ms +step:204/1660 train_time:18680ms step_avg:91.57ms +step:205/1660 train_time:18771ms step_avg:91.57ms +step:206/1660 train_time:18863ms step_avg:91.57ms +step:207/1660 train_time:18955ms step_avg:91.57ms +step:208/1660 train_time:19046ms step_avg:91.57ms +step:209/1660 train_time:19137ms step_avg:91.57ms +step:210/1660 train_time:19228ms step_avg:91.56ms +step:211/1660 train_time:19320ms step_avg:91.56ms +step:212/1660 train_time:19411ms step_avg:91.56ms +step:213/1660 train_time:19505ms step_avg:91.57ms +step:214/1660 train_time:19597ms step_avg:91.57ms +step:215/1660 train_time:19688ms step_avg:91.57ms +step:216/1660 train_time:19779ms step_avg:91.57ms +step:217/1660 train_time:19869ms step_avg:91.56ms +step:218/1660 train_time:19961ms step_avg:91.57ms +step:219/1660 train_time:20053ms step_avg:91.56ms +step:220/1660 train_time:20144ms step_avg:91.56ms +step:221/1660 train_time:20235ms step_avg:91.56ms +step:222/1660 train_time:20326ms step_avg:91.56ms +step:223/1660 train_time:20417ms step_avg:91.56ms +step:224/1660 train_time:20509ms step_avg:91.56ms +step:225/1660 train_time:20601ms step_avg:91.56ms +step:226/1660 train_time:20693ms step_avg:91.56ms +step:227/1660 train_time:20784ms step_avg:91.56ms +step:228/1660 train_time:20876ms step_avg:91.56ms +step:229/1660 train_time:20967ms step_avg:91.56ms +step:230/1660 train_time:21059ms step_avg:91.56ms +step:231/1660 train_time:21150ms step_avg:91.56ms +step:232/1660 train_time:21242ms step_avg:91.56ms +step:233/1660 train_time:21334ms step_avg:91.56ms +step:234/1660 train_time:21426ms step_avg:91.56ms +step:235/1660 train_time:21518ms step_avg:91.57ms +step:236/1660 train_time:21609ms step_avg:91.56ms +step:237/1660 train_time:21700ms step_avg:91.56ms +step:238/1660 train_time:21793ms step_avg:91.57ms +step:239/1660 train_time:21884ms step_avg:91.57ms +step:240/1660 train_time:21975ms step_avg:91.56ms +step:241/1660 train_time:22066ms step_avg:91.56ms +step:242/1660 train_time:22157ms step_avg:91.56ms +step:243/1660 train_time:22248ms step_avg:91.55ms +step:244/1660 train_time:22339ms step_avg:91.55ms +step:245/1660 train_time:22430ms step_avg:91.55ms +step:246/1660 train_time:22524ms step_avg:91.56ms +step:247/1660 train_time:22616ms step_avg:91.56ms +step:248/1660 train_time:22707ms step_avg:91.56ms +step:249/1660 train_time:22798ms step_avg:91.56ms +step:250/1660 train_time:22889ms step_avg:91.56ms +step:250/1660 val_loss:3.9644 train_time:22982ms step_avg:91.93ms +step:251/1660 train_time:23002ms step_avg:91.64ms +step:252/1660 train_time:23075ms step_avg:91.57ms +step:253/1660 train_time:23173ms step_avg:91.59ms +step:254/1660 train_time:23267ms step_avg:91.60ms +step:255/1660 train_time:23358ms step_avg:91.60ms +step:256/1660 train_time:23448ms step_avg:91.59ms +step:257/1660 train_time:23538ms step_avg:91.59ms +step:258/1660 train_time:23628ms step_avg:91.58ms +step:259/1660 
train_time:23718ms step_avg:91.58ms +step:260/1660 train_time:23808ms step_avg:91.57ms +step:261/1660 train_time:23899ms step_avg:91.57ms +step:262/1660 train_time:23991ms step_avg:91.57ms +step:263/1660 train_time:24085ms step_avg:91.58ms +step:264/1660 train_time:24178ms step_avg:91.58ms +step:265/1660 train_time:24270ms step_avg:91.58ms +step:266/1660 train_time:24362ms step_avg:91.59ms +step:267/1660 train_time:24453ms step_avg:91.58ms +step:268/1660 train_time:24544ms step_avg:91.58ms +step:269/1660 train_time:24635ms step_avg:91.58ms +step:270/1660 train_time:24726ms step_avg:91.58ms +step:271/1660 train_time:24816ms step_avg:91.57ms +step:272/1660 train_time:24907ms step_avg:91.57ms +step:273/1660 train_time:24998ms step_avg:91.57ms +step:274/1660 train_time:25090ms step_avg:91.57ms +step:275/1660 train_time:25183ms step_avg:91.58ms +step:276/1660 train_time:25275ms step_avg:91.58ms +step:277/1660 train_time:25367ms step_avg:91.58ms +step:278/1660 train_time:25459ms step_avg:91.58ms +step:279/1660 train_time:25550ms step_avg:91.58ms +step:280/1660 train_time:25642ms step_avg:91.58ms +step:281/1660 train_time:25733ms step_avg:91.58ms +step:282/1660 train_time:25825ms step_avg:91.58ms +step:283/1660 train_time:25916ms step_avg:91.58ms +step:284/1660 train_time:26007ms step_avg:91.57ms +step:285/1660 train_time:26098ms step_avg:91.57ms +step:286/1660 train_time:26190ms step_avg:91.57ms +step:287/1660 train_time:26281ms step_avg:91.57ms +step:288/1660 train_time:26373ms step_avg:91.57ms +step:289/1660 train_time:26466ms step_avg:91.58ms +step:290/1660 train_time:26558ms step_avg:91.58ms +step:291/1660 train_time:26649ms step_avg:91.58ms +step:292/1660 train_time:26741ms step_avg:91.58ms +step:293/1660 train_time:26831ms step_avg:91.57ms +step:294/1660 train_time:26923ms step_avg:91.57ms +step:295/1660 train_time:27014ms step_avg:91.57ms +step:296/1660 train_time:27105ms step_avg:91.57ms +step:297/1660 train_time:27197ms step_avg:91.57ms +step:298/1660 train_time:27289ms step_avg:91.57ms +step:299/1660 train_time:27381ms step_avg:91.57ms +step:300/1660 train_time:27472ms step_avg:91.57ms +step:301/1660 train_time:27565ms step_avg:91.58ms +step:302/1660 train_time:27657ms step_avg:91.58ms +step:303/1660 train_time:27748ms step_avg:91.58ms +step:304/1660 train_time:27840ms step_avg:91.58ms +step:305/1660 train_time:27930ms step_avg:91.57ms +step:306/1660 train_time:28022ms step_avg:91.58ms +step:307/1660 train_time:28114ms step_avg:91.58ms +step:308/1660 train_time:28205ms step_avg:91.57ms +step:309/1660 train_time:28297ms step_avg:91.58ms +step:310/1660 train_time:28388ms step_avg:91.57ms +step:311/1660 train_time:28479ms step_avg:91.57ms +step:312/1660 train_time:28570ms step_avg:91.57ms +step:313/1660 train_time:28663ms step_avg:91.57ms +step:314/1660 train_time:28754ms step_avg:91.57ms +step:315/1660 train_time:28845ms step_avg:91.57ms +step:316/1660 train_time:28937ms step_avg:91.57ms +step:317/1660 train_time:29028ms step_avg:91.57ms +step:318/1660 train_time:29119ms step_avg:91.57ms +step:319/1660 train_time:29211ms step_avg:91.57ms +step:320/1660 train_time:29302ms step_avg:91.57ms +step:321/1660 train_time:29394ms step_avg:91.57ms +step:322/1660 train_time:29486ms step_avg:91.57ms +step:323/1660 train_time:29576ms step_avg:91.57ms +step:324/1660 train_time:29668ms step_avg:91.57ms +step:325/1660 train_time:29760ms step_avg:91.57ms +step:326/1660 train_time:29852ms step_avg:91.57ms +step:327/1660 train_time:29945ms step_avg:91.57ms +step:328/1660 train_time:30036ms step_avg:91.57ms 
+step:329/1660 train_time:30128ms step_avg:91.57ms +step:330/1660 train_time:30219ms step_avg:91.57ms +step:331/1660 train_time:30309ms step_avg:91.57ms +step:332/1660 train_time:30400ms step_avg:91.57ms +step:333/1660 train_time:30491ms step_avg:91.57ms +step:334/1660 train_time:30583ms step_avg:91.56ms +step:335/1660 train_time:30674ms step_avg:91.56ms +step:336/1660 train_time:30766ms step_avg:91.56ms +step:337/1660 train_time:30858ms step_avg:91.57ms +step:338/1660 train_time:30950ms step_avg:91.57ms +step:339/1660 train_time:31042ms step_avg:91.57ms +step:340/1660 train_time:31134ms step_avg:91.57ms +step:341/1660 train_time:31225ms step_avg:91.57ms +step:342/1660 train_time:31316ms step_avg:91.57ms +step:343/1660 train_time:31407ms step_avg:91.57ms +step:344/1660 train_time:31498ms step_avg:91.56ms +step:345/1660 train_time:31589ms step_avg:91.56ms +step:346/1660 train_time:31680ms step_avg:91.56ms +step:347/1660 train_time:31771ms step_avg:91.56ms +step:348/1660 train_time:31863ms step_avg:91.56ms +step:349/1660 train_time:31956ms step_avg:91.56ms +step:350/1660 train_time:32047ms step_avg:91.56ms +step:351/1660 train_time:32140ms step_avg:91.57ms +step:352/1660 train_time:32232ms step_avg:91.57ms +step:353/1660 train_time:32323ms step_avg:91.57ms +step:354/1660 train_time:32415ms step_avg:91.57ms +step:355/1660 train_time:32507ms step_avg:91.57ms +step:356/1660 train_time:32597ms step_avg:91.57ms +step:357/1660 train_time:32688ms step_avg:91.56ms +step:358/1660 train_time:32780ms step_avg:91.56ms +step:359/1660 train_time:32871ms step_avg:91.56ms +step:360/1660 train_time:32964ms step_avg:91.57ms +step:361/1660 train_time:33056ms step_avg:91.57ms +step:362/1660 train_time:33147ms step_avg:91.57ms +step:363/1660 train_time:33238ms step_avg:91.57ms +step:364/1660 train_time:33329ms step_avg:91.56ms +step:365/1660 train_time:33420ms step_avg:91.56ms +step:366/1660 train_time:33510ms step_avg:91.56ms +step:367/1660 train_time:33602ms step_avg:91.56ms +step:368/1660 train_time:33693ms step_avg:91.56ms +step:369/1660 train_time:33785ms step_avg:91.56ms +step:370/1660 train_time:33877ms step_avg:91.56ms +step:371/1660 train_time:33968ms step_avg:91.56ms +step:372/1660 train_time:34061ms step_avg:91.56ms +step:373/1660 train_time:34152ms step_avg:91.56ms +step:374/1660 train_time:34244ms step_avg:91.56ms +step:375/1660 train_time:34337ms step_avg:91.57ms +step:375/1660 val_loss:3.8148 train_time:34429ms step_avg:91.81ms +step:376/1660 train_time:34450ms step_avg:91.62ms +step:377/1660 train_time:34524ms step_avg:91.58ms +step:378/1660 train_time:34623ms step_avg:91.60ms +step:379/1660 train_time:34715ms step_avg:91.60ms +step:380/1660 train_time:34806ms step_avg:91.59ms +step:381/1660 train_time:34896ms step_avg:91.59ms +step:382/1660 train_time:34986ms step_avg:91.59ms +step:383/1660 train_time:35076ms step_avg:91.58ms +step:384/1660 train_time:35166ms step_avg:91.58ms +step:385/1660 train_time:35256ms step_avg:91.57ms +step:386/1660 train_time:35347ms step_avg:91.57ms +step:387/1660 train_time:35440ms step_avg:91.58ms +step:388/1660 train_time:35533ms step_avg:91.58ms +step:389/1660 train_time:35626ms step_avg:91.58ms +step:390/1660 train_time:35719ms step_avg:91.59ms +step:391/1660 train_time:35810ms step_avg:91.59ms +step:392/1660 train_time:35902ms step_avg:91.59ms +step:393/1660 train_time:35994ms step_avg:91.59ms +step:394/1660 train_time:36085ms step_avg:91.59ms +step:395/1660 train_time:36176ms step_avg:91.58ms +step:396/1660 train_time:36266ms step_avg:91.58ms +step:397/1660 
train_time:36358ms step_avg:91.58ms +step:398/1660 train_time:36449ms step_avg:91.58ms +step:399/1660 train_time:36541ms step_avg:91.58ms +step:400/1660 train_time:36633ms step_avg:91.58ms +step:401/1660 train_time:36726ms step_avg:91.59ms +step:402/1660 train_time:36818ms step_avg:91.59ms +step:403/1660 train_time:36909ms step_avg:91.59ms +step:404/1660 train_time:37000ms step_avg:91.59ms +step:405/1660 train_time:37091ms step_avg:91.58ms +step:406/1660 train_time:37182ms step_avg:91.58ms +step:407/1660 train_time:37272ms step_avg:91.58ms +step:408/1660 train_time:37364ms step_avg:91.58ms +step:409/1660 train_time:37456ms step_avg:91.58ms +step:410/1660 train_time:37547ms step_avg:91.58ms +step:411/1660 train_time:37640ms step_avg:91.58ms +step:412/1660 train_time:37731ms step_avg:91.58ms +step:413/1660 train_time:37822ms step_avg:91.58ms +step:414/1660 train_time:37913ms step_avg:91.58ms +step:415/1660 train_time:38006ms step_avg:91.58ms +step:416/1660 train_time:38098ms step_avg:91.58ms +step:417/1660 train_time:38189ms step_avg:91.58ms +step:418/1660 train_time:38279ms step_avg:91.58ms +step:419/1660 train_time:38370ms step_avg:91.57ms +step:420/1660 train_time:38462ms step_avg:91.58ms +step:421/1660 train_time:38555ms step_avg:91.58ms +step:422/1660 train_time:38646ms step_avg:91.58ms +step:423/1660 train_time:38738ms step_avg:91.58ms +step:424/1660 train_time:38828ms step_avg:91.58ms +step:425/1660 train_time:38919ms step_avg:91.57ms +step:426/1660 train_time:39011ms step_avg:91.57ms +step:427/1660 train_time:39103ms step_avg:91.58ms +step:428/1660 train_time:39196ms step_avg:91.58ms +step:429/1660 train_time:39287ms step_avg:91.58ms +step:430/1660 train_time:39378ms step_avg:91.58ms +step:431/1660 train_time:39469ms step_avg:91.58ms +step:432/1660 train_time:39561ms step_avg:91.58ms +step:433/1660 train_time:39652ms step_avg:91.57ms +step:434/1660 train_time:39744ms step_avg:91.58ms +step:435/1660 train_time:39835ms step_avg:91.57ms +step:436/1660 train_time:39926ms step_avg:91.57ms +step:437/1660 train_time:40017ms step_avg:91.57ms +step:438/1660 train_time:40109ms step_avg:91.57ms +step:439/1660 train_time:40202ms step_avg:91.58ms +step:440/1660 train_time:40293ms step_avg:91.58ms +step:441/1660 train_time:40385ms step_avg:91.58ms +step:442/1660 train_time:40476ms step_avg:91.58ms +step:443/1660 train_time:40568ms step_avg:91.58ms +step:444/1660 train_time:40659ms step_avg:91.57ms +step:445/1660 train_time:40750ms step_avg:91.57ms +step:446/1660 train_time:40842ms step_avg:91.57ms +step:447/1660 train_time:40933ms step_avg:91.57ms +step:448/1660 train_time:41024ms step_avg:91.57ms +step:449/1660 train_time:41115ms step_avg:91.57ms +step:450/1660 train_time:41208ms step_avg:91.57ms +step:451/1660 train_time:41300ms step_avg:91.57ms +step:452/1660 train_time:41390ms step_avg:91.57ms +step:453/1660 train_time:41483ms step_avg:91.57ms +step:454/1660 train_time:41574ms step_avg:91.57ms +step:455/1660 train_time:41665ms step_avg:91.57ms +step:456/1660 train_time:41756ms step_avg:91.57ms +step:457/1660 train_time:41846ms step_avg:91.57ms +step:458/1660 train_time:41938ms step_avg:91.57ms +step:459/1660 train_time:42029ms step_avg:91.57ms +step:460/1660 train_time:42121ms step_avg:91.57ms +step:461/1660 train_time:42212ms step_avg:91.57ms +step:462/1660 train_time:42304ms step_avg:91.57ms +step:463/1660 train_time:42396ms step_avg:91.57ms +step:464/1660 train_time:42488ms step_avg:91.57ms +step:465/1660 train_time:42580ms step_avg:91.57ms +step:466/1660 train_time:42670ms step_avg:91.57ms 
+step:467/1660 train_time:42762ms step_avg:91.57ms +step:468/1660 train_time:42853ms step_avg:91.57ms +step:469/1660 train_time:42944ms step_avg:91.57ms +step:470/1660 train_time:43035ms step_avg:91.56ms +step:471/1660 train_time:43127ms step_avg:91.56ms +step:472/1660 train_time:43219ms step_avg:91.56ms +step:473/1660 train_time:43310ms step_avg:91.56ms +step:474/1660 train_time:43402ms step_avg:91.57ms +step:475/1660 train_time:43494ms step_avg:91.57ms +step:476/1660 train_time:43586ms step_avg:91.57ms +step:477/1660 train_time:43678ms step_avg:91.57ms +step:478/1660 train_time:43769ms step_avg:91.57ms +step:479/1660 train_time:43860ms step_avg:91.57ms +step:480/1660 train_time:43950ms step_avg:91.56ms +step:481/1660 train_time:44042ms step_avg:91.56ms +step:482/1660 train_time:44133ms step_avg:91.56ms +step:483/1660 train_time:44225ms step_avg:91.56ms +step:484/1660 train_time:44316ms step_avg:91.56ms +step:485/1660 train_time:44409ms step_avg:91.56ms +step:486/1660 train_time:44501ms step_avg:91.57ms +step:487/1660 train_time:44592ms step_avg:91.57ms +step:488/1660 train_time:44685ms step_avg:91.57ms +step:489/1660 train_time:44777ms step_avg:91.57ms +step:490/1660 train_time:44868ms step_avg:91.57ms +step:491/1660 train_time:44959ms step_avg:91.57ms +step:492/1660 train_time:45050ms step_avg:91.56ms +step:493/1660 train_time:45141ms step_avg:91.56ms +step:494/1660 train_time:45231ms step_avg:91.56ms +step:495/1660 train_time:45323ms step_avg:91.56ms +step:496/1660 train_time:45414ms step_avg:91.56ms +step:497/1660 train_time:45507ms step_avg:91.56ms +step:498/1660 train_time:45599ms step_avg:91.56ms +step:499/1660 train_time:45690ms step_avg:91.56ms +step:500/1660 train_time:45782ms step_avg:91.56ms +step:500/1660 val_loss:3.7134 train_time:45876ms step_avg:91.75ms +step:501/1660 train_time:45896ms step_avg:91.61ms +step:502/1660 train_time:45972ms step_avg:91.58ms +step:503/1660 train_time:46071ms step_avg:91.59ms +step:504/1660 train_time:46164ms step_avg:91.60ms +step:505/1660 train_time:46255ms step_avg:91.59ms +step:506/1660 train_time:46346ms step_avg:91.59ms +step:507/1660 train_time:46436ms step_avg:91.59ms +step:508/1660 train_time:46528ms step_avg:91.59ms +step:509/1660 train_time:46618ms step_avg:91.59ms +step:510/1660 train_time:46708ms step_avg:91.58ms +step:511/1660 train_time:46799ms step_avg:91.58ms +step:512/1660 train_time:46891ms step_avg:91.58ms +step:513/1660 train_time:46985ms step_avg:91.59ms +step:514/1660 train_time:47079ms step_avg:91.59ms +step:515/1660 train_time:47170ms step_avg:91.59ms +step:516/1660 train_time:47262ms step_avg:91.59ms +step:517/1660 train_time:47352ms step_avg:91.59ms +step:518/1660 train_time:47443ms step_avg:91.59ms +step:519/1660 train_time:47533ms step_avg:91.59ms +step:520/1660 train_time:47624ms step_avg:91.58ms +step:521/1660 train_time:47714ms step_avg:91.58ms +step:522/1660 train_time:47805ms step_avg:91.58ms +step:523/1660 train_time:47896ms step_avg:91.58ms +step:524/1660 train_time:47990ms step_avg:91.58ms +step:525/1660 train_time:48083ms step_avg:91.59ms +step:526/1660 train_time:48175ms step_avg:91.59ms +step:527/1660 train_time:48267ms step_avg:91.59ms +step:528/1660 train_time:48357ms step_avg:91.58ms +step:529/1660 train_time:48448ms step_avg:91.58ms +step:530/1660 train_time:48539ms step_avg:91.58ms +step:531/1660 train_time:48630ms step_avg:91.58ms +step:532/1660 train_time:48720ms step_avg:91.58ms +step:533/1660 train_time:48811ms step_avg:91.58ms +step:534/1660 train_time:48903ms step_avg:91.58ms +step:535/1660 
train_time:48994ms step_avg:91.58ms +step:536/1660 train_time:49087ms step_avg:91.58ms +step:537/1660 train_time:49179ms step_avg:91.58ms +step:538/1660 train_time:49271ms step_avg:91.58ms +step:539/1660 train_time:49362ms step_avg:91.58ms +step:540/1660 train_time:49452ms step_avg:91.58ms +step:541/1660 train_time:49544ms step_avg:91.58ms +step:542/1660 train_time:49635ms step_avg:91.58ms +step:543/1660 train_time:49726ms step_avg:91.58ms +step:544/1660 train_time:49817ms step_avg:91.57ms +step:545/1660 train_time:49909ms step_avg:91.58ms +step:546/1660 train_time:50000ms step_avg:91.58ms +step:547/1660 train_time:50091ms step_avg:91.57ms +step:548/1660 train_time:50184ms step_avg:91.58ms +step:549/1660 train_time:50275ms step_avg:91.58ms +step:550/1660 train_time:50367ms step_avg:91.58ms +step:551/1660 train_time:50457ms step_avg:91.57ms +step:552/1660 train_time:50550ms step_avg:91.58ms +step:553/1660 train_time:50642ms step_avg:91.58ms +step:554/1660 train_time:50733ms step_avg:91.58ms +step:555/1660 train_time:50824ms step_avg:91.58ms +step:556/1660 train_time:50916ms step_avg:91.58ms +step:557/1660 train_time:51009ms step_avg:91.58ms +step:558/1660 train_time:51101ms step_avg:91.58ms +step:559/1660 train_time:51193ms step_avg:91.58ms +step:560/1660 train_time:51287ms step_avg:91.58ms +step:561/1660 train_time:51379ms step_avg:91.59ms +step:562/1660 train_time:51472ms step_avg:91.59ms +step:563/1660 train_time:51565ms step_avg:91.59ms +step:564/1660 train_time:51657ms step_avg:91.59ms +step:565/1660 train_time:51750ms step_avg:91.59ms +step:566/1660 train_time:51843ms step_avg:91.60ms +step:567/1660 train_time:51935ms step_avg:91.60ms +step:568/1660 train_time:52028ms step_avg:91.60ms +step:569/1660 train_time:52122ms step_avg:91.60ms +step:570/1660 train_time:52214ms step_avg:91.60ms +step:571/1660 train_time:52308ms step_avg:91.61ms +step:572/1660 train_time:52401ms step_avg:91.61ms +step:573/1660 train_time:52493ms step_avg:91.61ms +step:574/1660 train_time:52586ms step_avg:91.61ms +step:575/1660 train_time:52679ms step_avg:91.61ms +step:576/1660 train_time:52771ms step_avg:91.62ms +step:577/1660 train_time:52864ms step_avg:91.62ms +step:578/1660 train_time:52955ms step_avg:91.62ms +step:579/1660 train_time:53049ms step_avg:91.62ms +step:580/1660 train_time:53142ms step_avg:91.62ms +step:581/1660 train_time:53234ms step_avg:91.63ms +step:582/1660 train_time:53328ms step_avg:91.63ms +step:583/1660 train_time:53421ms step_avg:91.63ms +step:584/1660 train_time:53514ms step_avg:91.63ms +step:585/1660 train_time:53607ms step_avg:91.64ms +step:586/1660 train_time:53699ms step_avg:91.64ms +step:587/1660 train_time:53792ms step_avg:91.64ms +step:588/1660 train_time:53884ms step_avg:91.64ms +step:589/1660 train_time:53977ms step_avg:91.64ms +step:590/1660 train_time:54071ms step_avg:91.64ms +step:591/1660 train_time:54163ms step_avg:91.65ms +step:592/1660 train_time:54255ms step_avg:91.65ms +step:593/1660 train_time:54350ms step_avg:91.65ms +step:594/1660 train_time:54443ms step_avg:91.65ms +step:595/1660 train_time:54534ms step_avg:91.65ms +step:596/1660 train_time:54627ms step_avg:91.66ms +step:597/1660 train_time:54720ms step_avg:91.66ms +step:598/1660 train_time:54812ms step_avg:91.66ms +step:599/1660 train_time:54905ms step_avg:91.66ms +step:600/1660 train_time:54997ms step_avg:91.66ms +step:601/1660 train_time:55090ms step_avg:91.66ms +step:602/1660 train_time:55183ms step_avg:91.67ms +step:603/1660 train_time:55275ms step_avg:91.67ms +step:604/1660 train_time:55369ms step_avg:91.67ms 
+step:605/1660 train_time:55463ms step_avg:91.67ms +step:606/1660 train_time:55555ms step_avg:91.68ms +step:607/1660 train_time:55650ms step_avg:91.68ms +step:608/1660 train_time:55743ms step_avg:91.68ms +step:609/1660 train_time:55834ms step_avg:91.68ms +step:610/1660 train_time:55927ms step_avg:91.68ms +step:611/1660 train_time:56020ms step_avg:91.69ms +step:612/1660 train_time:56113ms step_avg:91.69ms +step:613/1660 train_time:56205ms step_avg:91.69ms +step:614/1660 train_time:56298ms step_avg:91.69ms +step:615/1660 train_time:56390ms step_avg:91.69ms +step:616/1660 train_time:56483ms step_avg:91.69ms +step:617/1660 train_time:56575ms step_avg:91.69ms +step:618/1660 train_time:56669ms step_avg:91.70ms +step:619/1660 train_time:56762ms step_avg:91.70ms +step:620/1660 train_time:56854ms step_avg:91.70ms +step:621/1660 train_time:56946ms step_avg:91.70ms +step:622/1660 train_time:57038ms step_avg:91.70ms +step:623/1660 train_time:57131ms step_avg:91.70ms +step:624/1660 train_time:57224ms step_avg:91.71ms +step:625/1660 train_time:57316ms step_avg:91.71ms +step:625/1660 val_loss:3.6141 train_time:57411ms step_avg:91.86ms +step:626/1660 train_time:57431ms step_avg:91.74ms +step:627/1660 train_time:57508ms step_avg:91.72ms +step:628/1660 train_time:57612ms step_avg:91.74ms +step:629/1660 train_time:57708ms step_avg:91.75ms +step:630/1660 train_time:57800ms step_avg:91.75ms +step:631/1660 train_time:57891ms step_avg:91.75ms +step:632/1660 train_time:57982ms step_avg:91.74ms +step:633/1660 train_time:58073ms step_avg:91.74ms +step:634/1660 train_time:58165ms step_avg:91.74ms +step:635/1660 train_time:58256ms step_avg:91.74ms +step:636/1660 train_time:58347ms step_avg:91.74ms +step:637/1660 train_time:58439ms step_avg:91.74ms +step:638/1660 train_time:58536ms step_avg:91.75ms +step:639/1660 train_time:58634ms step_avg:91.76ms +step:640/1660 train_time:58728ms step_avg:91.76ms +step:641/1660 train_time:58820ms step_avg:91.76ms +step:642/1660 train_time:58913ms step_avg:91.76ms +step:643/1660 train_time:59005ms step_avg:91.77ms +step:644/1660 train_time:59096ms step_avg:91.76ms +step:645/1660 train_time:59188ms step_avg:91.76ms +step:646/1660 train_time:59279ms step_avg:91.76ms +step:647/1660 train_time:59371ms step_avg:91.76ms +step:648/1660 train_time:59464ms step_avg:91.77ms +step:649/1660 train_time:59558ms step_avg:91.77ms +step:650/1660 train_time:59653ms step_avg:91.77ms +step:651/1660 train_time:59747ms step_avg:91.78ms +step:652/1660 train_time:59839ms step_avg:91.78ms +step:653/1660 train_time:59932ms step_avg:91.78ms +step:654/1660 train_time:60025ms step_avg:91.78ms +step:655/1660 train_time:60116ms step_avg:91.78ms +step:656/1660 train_time:60208ms step_avg:91.78ms +step:657/1660 train_time:60299ms step_avg:91.78ms +step:658/1660 train_time:60392ms step_avg:91.78ms +step:659/1660 train_time:60484ms step_avg:91.78ms +step:660/1660 train_time:60577ms step_avg:91.78ms +step:661/1660 train_time:60672ms step_avg:91.79ms +step:662/1660 train_time:60766ms step_avg:91.79ms +step:663/1660 train_time:60858ms step_avg:91.79ms +step:664/1660 train_time:60952ms step_avg:91.80ms +step:665/1660 train_time:61044ms step_avg:91.80ms +step:666/1660 train_time:61136ms step_avg:91.80ms +step:667/1660 train_time:61229ms step_avg:91.80ms +step:668/1660 train_time:61321ms step_avg:91.80ms +step:669/1660 train_time:61413ms step_avg:91.80ms +step:670/1660 train_time:61504ms step_avg:91.80ms +step:671/1660 train_time:61597ms step_avg:91.80ms +step:672/1660 train_time:61691ms step_avg:91.80ms +step:673/1660 
train_time:61785ms step_avg:91.81ms +step:674/1660 train_time:61877ms step_avg:91.80ms +step:675/1660 train_time:61969ms step_avg:91.81ms +step:676/1660 train_time:62061ms step_avg:91.81ms +step:677/1660 train_time:62154ms step_avg:91.81ms +step:678/1660 train_time:62247ms step_avg:91.81ms +step:679/1660 train_time:62339ms step_avg:91.81ms +step:680/1660 train_time:62433ms step_avg:91.81ms +step:681/1660 train_time:62526ms step_avg:91.82ms +step:682/1660 train_time:62618ms step_avg:91.82ms +step:683/1660 train_time:62711ms step_avg:91.82ms +step:684/1660 train_time:62805ms step_avg:91.82ms +step:685/1660 train_time:62897ms step_avg:91.82ms +step:686/1660 train_time:62990ms step_avg:91.82ms +step:687/1660 train_time:63083ms step_avg:91.82ms +step:688/1660 train_time:63175ms step_avg:91.82ms +step:689/1660 train_time:63267ms step_avg:91.82ms +step:690/1660 train_time:63360ms step_avg:91.83ms +step:691/1660 train_time:63454ms step_avg:91.83ms +step:692/1660 train_time:63547ms step_avg:91.83ms +step:693/1660 train_time:63638ms step_avg:91.83ms +step:694/1660 train_time:63731ms step_avg:91.83ms +step:695/1660 train_time:63824ms step_avg:91.83ms +step:696/1660 train_time:63917ms step_avg:91.83ms +step:697/1660 train_time:64010ms step_avg:91.84ms +step:698/1660 train_time:64104ms step_avg:91.84ms +step:699/1660 train_time:64196ms step_avg:91.84ms +step:700/1660 train_time:64289ms step_avg:91.84ms +step:701/1660 train_time:64381ms step_avg:91.84ms +step:702/1660 train_time:64474ms step_avg:91.84ms +step:703/1660 train_time:64567ms step_avg:91.84ms +step:704/1660 train_time:64659ms step_avg:91.84ms +step:705/1660 train_time:64753ms step_avg:91.85ms +step:706/1660 train_time:64845ms step_avg:91.85ms +step:707/1660 train_time:64937ms step_avg:91.85ms +step:708/1660 train_time:65032ms step_avg:91.85ms +step:709/1660 train_time:65125ms step_avg:91.85ms +step:710/1660 train_time:65217ms step_avg:91.85ms +step:711/1660 train_time:65310ms step_avg:91.86ms +step:712/1660 train_time:65402ms step_avg:91.86ms +step:713/1660 train_time:65495ms step_avg:91.86ms +step:714/1660 train_time:65587ms step_avg:91.86ms +step:715/1660 train_time:65680ms step_avg:91.86ms +step:716/1660 train_time:65773ms step_avg:91.86ms +step:717/1660 train_time:65866ms step_avg:91.86ms +step:718/1660 train_time:65958ms step_avg:91.86ms +step:719/1660 train_time:66053ms step_avg:91.87ms +step:720/1660 train_time:66145ms step_avg:91.87ms +step:721/1660 train_time:66237ms step_avg:91.87ms +step:722/1660 train_time:66331ms step_avg:91.87ms +step:723/1660 train_time:66424ms step_avg:91.87ms +step:724/1660 train_time:66516ms step_avg:91.87ms +step:725/1660 train_time:66609ms step_avg:91.87ms +step:726/1660 train_time:66701ms step_avg:91.87ms +step:727/1660 train_time:66794ms step_avg:91.88ms +step:728/1660 train_time:66887ms step_avg:91.88ms +step:729/1660 train_time:66979ms step_avg:91.88ms +step:730/1660 train_time:67072ms step_avg:91.88ms +step:731/1660 train_time:67165ms step_avg:91.88ms +step:732/1660 train_time:67257ms step_avg:91.88ms +step:733/1660 train_time:67350ms step_avg:91.88ms +step:734/1660 train_time:67442ms step_avg:91.88ms +step:735/1660 train_time:67535ms step_avg:91.88ms +step:736/1660 train_time:67629ms step_avg:91.89ms +step:737/1660 train_time:67721ms step_avg:91.89ms +step:738/1660 train_time:67814ms step_avg:91.89ms +step:739/1660 train_time:67906ms step_avg:91.89ms +step:740/1660 train_time:67998ms step_avg:91.89ms +step:741/1660 train_time:68092ms step_avg:91.89ms +step:742/1660 train_time:68185ms step_avg:91.89ms 
+step:743/1660 train_time:68277ms step_avg:91.89ms +step:744/1660 train_time:68371ms step_avg:91.90ms +step:745/1660 train_time:68463ms step_avg:91.90ms +step:746/1660 train_time:68556ms step_avg:91.90ms +step:747/1660 train_time:68649ms step_avg:91.90ms +step:748/1660 train_time:68742ms step_avg:91.90ms +step:749/1660 train_time:68835ms step_avg:91.90ms +step:750/1660 train_time:68929ms step_avg:91.90ms +step:750/1660 val_loss:3.5627 train_time:69024ms step_avg:92.03ms +step:751/1660 train_time:69044ms step_avg:91.94ms +step:752/1660 train_time:69121ms step_avg:91.92ms +step:753/1660 train_time:69217ms step_avg:91.92ms +step:754/1660 train_time:69309ms step_avg:91.92ms +step:755/1660 train_time:69400ms step_avg:91.92ms +step:756/1660 train_time:69491ms step_avg:91.92ms +step:757/1660 train_time:69583ms step_avg:91.92ms +step:758/1660 train_time:69674ms step_avg:91.92ms +step:759/1660 train_time:69766ms step_avg:91.92ms +step:760/1660 train_time:69858ms step_avg:91.92ms +step:761/1660 train_time:69950ms step_avg:91.92ms +step:762/1660 train_time:70046ms step_avg:91.92ms +step:763/1660 train_time:70143ms step_avg:91.93ms +step:764/1660 train_time:70237ms step_avg:91.93ms +step:765/1660 train_time:70330ms step_avg:91.93ms +step:766/1660 train_time:70423ms step_avg:91.94ms +step:767/1660 train_time:70514ms step_avg:91.94ms +step:768/1660 train_time:70606ms step_avg:91.94ms +step:769/1660 train_time:70698ms step_avg:91.93ms +step:770/1660 train_time:70789ms step_avg:91.93ms +step:771/1660 train_time:70882ms step_avg:91.93ms +step:772/1660 train_time:70974ms step_avg:91.93ms +step:773/1660 train_time:71069ms step_avg:91.94ms +step:774/1660 train_time:71164ms step_avg:91.94ms +step:775/1660 train_time:71258ms step_avg:91.95ms +step:776/1660 train_time:71350ms step_avg:91.95ms +step:777/1660 train_time:71443ms step_avg:91.95ms +step:778/1660 train_time:71534ms step_avg:91.95ms +step:779/1660 train_time:71628ms step_avg:91.95ms +step:780/1660 train_time:71720ms step_avg:91.95ms +step:781/1660 train_time:71811ms step_avg:91.95ms +step:782/1660 train_time:71904ms step_avg:91.95ms +step:783/1660 train_time:71996ms step_avg:91.95ms +step:784/1660 train_time:72090ms step_avg:91.95ms +step:785/1660 train_time:72185ms step_avg:91.96ms +step:786/1660 train_time:72278ms step_avg:91.96ms +step:787/1660 train_time:72371ms step_avg:91.96ms +step:788/1660 train_time:72464ms step_avg:91.96ms +step:789/1660 train_time:72556ms step_avg:91.96ms +step:790/1660 train_time:72647ms step_avg:91.96ms +step:791/1660 train_time:72739ms step_avg:91.96ms +step:792/1660 train_time:72831ms step_avg:91.96ms +step:793/1660 train_time:72924ms step_avg:91.96ms +step:794/1660 train_time:73017ms step_avg:91.96ms +step:795/1660 train_time:73109ms step_avg:91.96ms +step:796/1660 train_time:73202ms step_avg:91.96ms +step:797/1660 train_time:73295ms step_avg:91.96ms +step:798/1660 train_time:73387ms step_avg:91.96ms +step:799/1660 train_time:73480ms step_avg:91.96ms +step:800/1660 train_time:73571ms step_avg:91.96ms +step:801/1660 train_time:73665ms step_avg:91.97ms +step:802/1660 train_time:73757ms step_avg:91.97ms +step:803/1660 train_time:73849ms step_avg:91.97ms +step:804/1660 train_time:73942ms step_avg:91.97ms +step:805/1660 train_time:74034ms step_avg:91.97ms +step:806/1660 train_time:74128ms step_avg:91.97ms +step:807/1660 train_time:74221ms step_avg:91.97ms +step:808/1660 train_time:74313ms step_avg:91.97ms +step:809/1660 train_time:74406ms step_avg:91.97ms +step:810/1660 train_time:74499ms step_avg:91.97ms +step:811/1660 
train_time:74592ms step_avg:91.97ms +step:812/1660 train_time:74686ms step_avg:91.98ms +step:813/1660 train_time:74778ms step_avg:91.98ms +step:814/1660 train_time:74870ms step_avg:91.98ms +step:815/1660 train_time:74963ms step_avg:91.98ms +step:816/1660 train_time:75056ms step_avg:91.98ms +step:817/1660 train_time:75148ms step_avg:91.98ms +step:818/1660 train_time:75240ms step_avg:91.98ms +step:819/1660 train_time:75332ms step_avg:91.98ms +step:820/1660 train_time:75426ms step_avg:91.98ms +step:821/1660 train_time:75518ms step_avg:91.98ms +step:822/1660 train_time:75611ms step_avg:91.98ms +step:823/1660 train_time:75703ms step_avg:91.98ms +step:824/1660 train_time:75795ms step_avg:91.98ms +step:825/1660 train_time:75888ms step_avg:91.99ms +step:826/1660 train_time:75981ms step_avg:91.99ms +step:827/1660 train_time:76073ms step_avg:91.99ms +step:828/1660 train_time:76166ms step_avg:91.99ms +step:829/1660 train_time:76259ms step_avg:91.99ms +step:830/1660 train_time:76351ms step_avg:91.99ms +step:831/1660 train_time:76444ms step_avg:91.99ms +step:832/1660 train_time:76537ms step_avg:91.99ms +step:833/1660 train_time:76630ms step_avg:91.99ms +step:834/1660 train_time:76723ms step_avg:91.99ms +step:835/1660 train_time:76815ms step_avg:91.99ms +step:836/1660 train_time:76907ms step_avg:91.99ms +step:837/1660 train_time:77000ms step_avg:91.99ms +step:838/1660 train_time:77091ms step_avg:91.99ms +step:839/1660 train_time:77185ms step_avg:92.00ms +step:840/1660 train_time:77277ms step_avg:92.00ms +step:841/1660 train_time:77370ms step_avg:92.00ms +step:842/1660 train_time:77463ms step_avg:92.00ms +step:843/1660 train_time:77556ms step_avg:92.00ms +step:844/1660 train_time:77648ms step_avg:92.00ms +step:845/1660 train_time:77741ms step_avg:92.00ms +step:846/1660 train_time:77833ms step_avg:92.00ms +step:847/1660 train_time:77926ms step_avg:92.00ms +step:848/1660 train_time:78019ms step_avg:92.00ms +step:849/1660 train_time:78112ms step_avg:92.00ms +step:850/1660 train_time:78205ms step_avg:92.01ms +step:851/1660 train_time:78297ms step_avg:92.01ms +step:852/1660 train_time:78390ms step_avg:92.01ms +step:853/1660 train_time:78484ms step_avg:92.01ms +step:854/1660 train_time:78576ms step_avg:92.01ms +step:855/1660 train_time:78669ms step_avg:92.01ms +step:856/1660 train_time:78762ms step_avg:92.01ms +step:857/1660 train_time:78855ms step_avg:92.01ms +step:858/1660 train_time:78948ms step_avg:92.01ms +step:859/1660 train_time:79041ms step_avg:92.02ms +step:860/1660 train_time:79133ms step_avg:92.02ms +step:861/1660 train_time:79227ms step_avg:92.02ms +step:862/1660 train_time:79319ms step_avg:92.02ms +step:863/1660 train_time:79411ms step_avg:92.02ms +step:864/1660 train_time:79503ms step_avg:92.02ms +step:865/1660 train_time:79596ms step_avg:92.02ms +step:866/1660 train_time:79689ms step_avg:92.02ms +step:867/1660 train_time:79781ms step_avg:92.02ms +step:868/1660 train_time:79873ms step_avg:92.02ms +step:869/1660 train_time:79966ms step_avg:92.02ms +step:870/1660 train_time:80059ms step_avg:92.02ms +step:871/1660 train_time:80151ms step_avg:92.02ms +step:872/1660 train_time:80244ms step_avg:92.02ms +step:873/1660 train_time:80337ms step_avg:92.02ms +step:874/1660 train_time:80430ms step_avg:92.02ms +step:875/1660 train_time:80522ms step_avg:92.03ms +step:875/1660 val_loss:3.5191 train_time:80616ms step_avg:92.13ms +step:876/1660 train_time:80636ms step_avg:92.05ms +step:877/1660 train_time:80716ms step_avg:92.04ms +step:878/1660 train_time:80812ms step_avg:92.04ms +step:879/1660 train_time:80906ms 
step_avg:92.04ms +step:880/1660 train_time:80997ms step_avg:92.04ms +step:881/1660 train_time:81089ms step_avg:92.04ms +step:882/1660 train_time:81180ms step_avg:92.04ms +step:883/1660 train_time:81271ms step_avg:92.04ms +step:884/1660 train_time:81363ms step_avg:92.04ms +step:885/1660 train_time:81455ms step_avg:92.04ms +step:886/1660 train_time:81548ms step_avg:92.04ms +step:887/1660 train_time:81643ms step_avg:92.04ms +step:888/1660 train_time:81737ms step_avg:92.05ms +step:889/1660 train_time:81834ms step_avg:92.05ms +step:890/1660 train_time:81928ms step_avg:92.05ms +step:891/1660 train_time:82020ms step_avg:92.05ms +step:892/1660 train_time:82112ms step_avg:92.05ms +step:893/1660 train_time:82203ms step_avg:92.05ms +step:894/1660 train_time:82294ms step_avg:92.05ms +step:895/1660 train_time:82386ms step_avg:92.05ms +step:896/1660 train_time:82478ms step_avg:92.05ms +step:897/1660 train_time:82570ms step_avg:92.05ms +step:898/1660 train_time:82664ms step_avg:92.05ms +step:899/1660 train_time:82758ms step_avg:92.06ms +step:900/1660 train_time:82853ms step_avg:92.06ms +step:901/1660 train_time:82946ms step_avg:92.06ms +step:902/1660 train_time:83037ms step_avg:92.06ms +step:903/1660 train_time:83130ms step_avg:92.06ms +step:904/1660 train_time:83222ms step_avg:92.06ms +step:905/1660 train_time:83314ms step_avg:92.06ms +step:906/1660 train_time:83406ms step_avg:92.06ms +step:907/1660 train_time:83497ms step_avg:92.06ms +step:908/1660 train_time:83591ms step_avg:92.06ms +step:909/1660 train_time:83683ms step_avg:92.06ms +step:910/1660 train_time:83776ms step_avg:92.06ms +step:911/1660 train_time:83870ms step_avg:92.06ms +step:912/1660 train_time:83962ms step_avg:92.06ms +step:913/1660 train_time:84054ms step_avg:92.06ms +step:914/1660 train_time:84147ms step_avg:92.06ms +step:915/1660 train_time:84239ms step_avg:92.06ms +step:916/1660 train_time:84331ms step_avg:92.06ms +step:917/1660 train_time:84423ms step_avg:92.06ms +step:918/1660 train_time:84515ms step_avg:92.06ms +step:919/1660 train_time:84608ms step_avg:92.07ms +step:920/1660 train_time:84701ms step_avg:92.07ms +step:921/1660 train_time:84795ms step_avg:92.07ms +step:922/1660 train_time:84888ms step_avg:92.07ms +step:923/1660 train_time:84980ms step_avg:92.07ms +step:924/1660 train_time:85073ms step_avg:92.07ms +step:925/1660 train_time:85167ms step_avg:92.07ms +step:926/1660 train_time:85258ms step_avg:92.07ms +step:927/1660 train_time:85351ms step_avg:92.07ms +step:928/1660 train_time:85442ms step_avg:92.07ms +step:929/1660 train_time:85535ms step_avg:92.07ms +step:930/1660 train_time:85629ms step_avg:92.07ms +step:931/1660 train_time:85722ms step_avg:92.07ms +step:932/1660 train_time:85814ms step_avg:92.08ms +step:933/1660 train_time:85907ms step_avg:92.08ms +step:934/1660 train_time:86000ms step_avg:92.08ms +step:935/1660 train_time:86093ms step_avg:92.08ms +step:936/1660 train_time:86187ms step_avg:92.08ms +step:937/1660 train_time:86278ms step_avg:92.08ms +step:938/1660 train_time:86371ms step_avg:92.08ms +step:939/1660 train_time:86463ms step_avg:92.08ms +step:940/1660 train_time:86556ms step_avg:92.08ms +step:941/1660 train_time:86648ms step_avg:92.08ms +step:942/1660 train_time:86740ms step_avg:92.08ms +step:943/1660 train_time:86833ms step_avg:92.08ms +step:944/1660 train_time:86926ms step_avg:92.08ms +step:945/1660 train_time:87019ms step_avg:92.08ms +step:946/1660 train_time:87112ms step_avg:92.08ms +step:947/1660 train_time:87204ms step_avg:92.08ms +step:948/1660 train_time:87296ms step_avg:92.08ms +step:949/1660 
train_time:87388ms step_avg:92.08ms +step:950/1660 train_time:87481ms step_avg:92.08ms +step:951/1660 train_time:87573ms step_avg:92.09ms +step:952/1660 train_time:87666ms step_avg:92.09ms +step:953/1660 train_time:87759ms step_avg:92.09ms +step:954/1660 train_time:87853ms step_avg:92.09ms +step:955/1660 train_time:87946ms step_avg:92.09ms +step:956/1660 train_time:88038ms step_avg:92.09ms +step:957/1660 train_time:88130ms step_avg:92.09ms +step:958/1660 train_time:88222ms step_avg:92.09ms +step:959/1660 train_time:88315ms step_avg:92.09ms +step:960/1660 train_time:88407ms step_avg:92.09ms +step:961/1660 train_time:88499ms step_avg:92.09ms +step:962/1660 train_time:88594ms step_avg:92.09ms +step:963/1660 train_time:88686ms step_avg:92.09ms +step:964/1660 train_time:88779ms step_avg:92.09ms +step:965/1660 train_time:88872ms step_avg:92.10ms +step:966/1660 train_time:88965ms step_avg:92.10ms +step:967/1660 train_time:89057ms step_avg:92.10ms +step:968/1660 train_time:89150ms step_avg:92.10ms +step:969/1660 train_time:89243ms step_avg:92.10ms +step:970/1660 train_time:89335ms step_avg:92.10ms +step:971/1660 train_time:89428ms step_avg:92.10ms +step:972/1660 train_time:89520ms step_avg:92.10ms +step:973/1660 train_time:89613ms step_avg:92.10ms +step:974/1660 train_time:89706ms step_avg:92.10ms +step:975/1660 train_time:89797ms step_avg:92.10ms +step:976/1660 train_time:89891ms step_avg:92.10ms +step:977/1660 train_time:89983ms step_avg:92.10ms +step:978/1660 train_time:90075ms step_avg:92.10ms +step:979/1660 train_time:90168ms step_avg:92.10ms +step:980/1660 train_time:90261ms step_avg:92.10ms +step:981/1660 train_time:90353ms step_avg:92.10ms +step:982/1660 train_time:90446ms step_avg:92.10ms +step:983/1660 train_time:90538ms step_avg:92.10ms +step:984/1660 train_time:90631ms step_avg:92.10ms +step:985/1660 train_time:90724ms step_avg:92.11ms +step:986/1660 train_time:90816ms step_avg:92.11ms +step:987/1660 train_time:90909ms step_avg:92.11ms +step:988/1660 train_time:91001ms step_avg:92.11ms +step:989/1660 train_time:91095ms step_avg:92.11ms +step:990/1660 train_time:91187ms step_avg:92.11ms +step:991/1660 train_time:91280ms step_avg:92.11ms +step:992/1660 train_time:91373ms step_avg:92.11ms +step:993/1660 train_time:91465ms step_avg:92.11ms +step:994/1660 train_time:91557ms step_avg:92.11ms +step:995/1660 train_time:91650ms step_avg:92.11ms +step:996/1660 train_time:91742ms step_avg:92.11ms +step:997/1660 train_time:91834ms step_avg:92.11ms +step:998/1660 train_time:91927ms step_avg:92.11ms +step:999/1660 train_time:92019ms step_avg:92.11ms +step:1000/1660 train_time:92112ms step_avg:92.11ms +step:1000/1660 val_loss:3.4689 train_time:92206ms step_avg:92.21ms +step:1001/1660 train_time:92226ms step_avg:92.13ms +step:1002/1660 train_time:92302ms step_avg:92.12ms +step:1003/1660 train_time:92398ms step_avg:92.12ms +step:1004/1660 train_time:92492ms step_avg:92.12ms +step:1005/1660 train_time:92584ms step_avg:92.12ms +step:1006/1660 train_time:92676ms step_avg:92.12ms +step:1007/1660 train_time:92768ms step_avg:92.12ms +step:1008/1660 train_time:92859ms step_avg:92.12ms +step:1009/1660 train_time:92950ms step_avg:92.12ms +step:1010/1660 train_time:93041ms step_avg:92.12ms +step:1011/1660 train_time:93134ms step_avg:92.12ms +step:1012/1660 train_time:93228ms step_avg:92.12ms +step:1013/1660 train_time:93322ms step_avg:92.12ms +step:1014/1660 train_time:93418ms step_avg:92.13ms +step:1015/1660 train_time:93513ms step_avg:92.13ms +step:1016/1660 train_time:93605ms step_avg:92.13ms +step:1017/1660 
train_time:93697ms step_avg:92.13ms +step:1018/1660 train_time:93789ms step_avg:92.13ms +step:1019/1660 train_time:93880ms step_avg:92.13ms +step:1020/1660 train_time:93972ms step_avg:92.13ms +step:1021/1660 train_time:94063ms step_avg:92.13ms +step:1022/1660 train_time:94156ms step_avg:92.13ms +step:1023/1660 train_time:94250ms step_avg:92.13ms +step:1024/1660 train_time:94344ms step_avg:92.13ms +step:1025/1660 train_time:94438ms step_avg:92.13ms +step:1026/1660 train_time:94532ms step_avg:92.14ms +step:1027/1660 train_time:94623ms step_avg:92.14ms +step:1028/1660 train_time:94717ms step_avg:92.14ms +step:1029/1660 train_time:94810ms step_avg:92.14ms +step:1030/1660 train_time:94901ms step_avg:92.14ms +step:1031/1660 train_time:94994ms step_avg:92.14ms +step:1032/1660 train_time:95087ms step_avg:92.14ms +step:1033/1660 train_time:95179ms step_avg:92.14ms +step:1034/1660 train_time:95274ms step_avg:92.14ms +step:1035/1660 train_time:95367ms step_avg:92.14ms +step:1036/1660 train_time:95459ms step_avg:92.14ms +step:1037/1660 train_time:95552ms step_avg:92.14ms +step:1038/1660 train_time:95644ms step_avg:92.14ms +step:1039/1660 train_time:95737ms step_avg:92.14ms +step:1040/1660 train_time:95829ms step_avg:92.14ms +step:1041/1660 train_time:95920ms step_avg:92.14ms +step:1042/1660 train_time:96013ms step_avg:92.14ms +step:1043/1660 train_time:96106ms step_avg:92.14ms +step:1044/1660 train_time:96199ms step_avg:92.14ms +step:1045/1660 train_time:96292ms step_avg:92.15ms +step:1046/1660 train_time:96385ms step_avg:92.15ms +step:1047/1660 train_time:96478ms step_avg:92.15ms +step:1048/1660 train_time:96571ms step_avg:92.15ms +step:1049/1660 train_time:96663ms step_avg:92.15ms +step:1050/1660 train_time:96755ms step_avg:92.15ms +step:1051/1660 train_time:96848ms step_avg:92.15ms +step:1052/1660 train_time:96940ms step_avg:92.15ms +step:1053/1660 train_time:97033ms step_avg:92.15ms +step:1054/1660 train_time:97126ms step_avg:92.15ms +step:1055/1660 train_time:97218ms step_avg:92.15ms +step:1056/1660 train_time:97313ms step_avg:92.15ms +step:1057/1660 train_time:97406ms step_avg:92.15ms +step:1058/1660 train_time:97498ms step_avg:92.15ms +step:1059/1660 train_time:97591ms step_avg:92.15ms +step:1060/1660 train_time:97683ms step_avg:92.15ms +step:1061/1660 train_time:97776ms step_avg:92.15ms +step:1062/1660 train_time:97869ms step_avg:92.16ms +step:1063/1660 train_time:97961ms step_avg:92.15ms +step:1064/1660 train_time:98054ms step_avg:92.16ms +step:1065/1660 train_time:98146ms step_avg:92.16ms +step:1066/1660 train_time:98239ms step_avg:92.16ms +step:1067/1660 train_time:98332ms step_avg:92.16ms +step:1068/1660 train_time:98424ms step_avg:92.16ms +step:1069/1660 train_time:98518ms step_avg:92.16ms +step:1070/1660 train_time:98610ms step_avg:92.16ms +step:1071/1660 train_time:98703ms step_avg:92.16ms +step:1072/1660 train_time:98796ms step_avg:92.16ms +step:1073/1660 train_time:98889ms step_avg:92.16ms +step:1074/1660 train_time:98980ms step_avg:92.16ms +step:1075/1660 train_time:99074ms step_avg:92.16ms +step:1076/1660 train_time:99167ms step_avg:92.16ms +step:1077/1660 train_time:99259ms step_avg:92.16ms +step:1078/1660 train_time:99352ms step_avg:92.16ms +step:1079/1660 train_time:99445ms step_avg:92.16ms +step:1080/1660 train_time:99538ms step_avg:92.16ms +step:1081/1660 train_time:99630ms step_avg:92.16ms +step:1082/1660 train_time:99722ms step_avg:92.16ms +step:1083/1660 train_time:99816ms step_avg:92.17ms +step:1084/1660 train_time:99909ms step_avg:92.17ms +step:1085/1660 
train_time:100001ms step_avg:92.17ms +step:1086/1660 train_time:100094ms step_avg:92.17ms +step:1087/1660 train_time:100186ms step_avg:92.17ms +step:1088/1660 train_time:100279ms step_avg:92.17ms +step:1089/1660 train_time:100372ms step_avg:92.17ms +step:1090/1660 train_time:100465ms step_avg:92.17ms +step:1091/1660 train_time:100557ms step_avg:92.17ms +step:1092/1660 train_time:100650ms step_avg:92.17ms +step:1093/1660 train_time:100742ms step_avg:92.17ms +step:1094/1660 train_time:100836ms step_avg:92.17ms +step:1095/1660 train_time:100929ms step_avg:92.17ms +step:1096/1660 train_time:101021ms step_avg:92.17ms +step:1097/1660 train_time:101114ms step_avg:92.17ms +step:1098/1660 train_time:101207ms step_avg:92.17ms +step:1099/1660 train_time:101299ms step_avg:92.17ms +step:1100/1660 train_time:101392ms step_avg:92.17ms +step:1101/1660 train_time:101484ms step_avg:92.17ms +step:1102/1660 train_time:101577ms step_avg:92.18ms +step:1103/1660 train_time:101670ms step_avg:92.18ms +step:1104/1660 train_time:101762ms step_avg:92.18ms +step:1105/1660 train_time:101855ms step_avg:92.18ms +step:1106/1660 train_time:101948ms step_avg:92.18ms +step:1107/1660 train_time:102040ms step_avg:92.18ms +step:1108/1660 train_time:102133ms step_avg:92.18ms +step:1109/1660 train_time:102227ms step_avg:92.18ms +step:1110/1660 train_time:102319ms step_avg:92.18ms +step:1111/1660 train_time:102415ms step_avg:92.18ms +step:1112/1660 train_time:102508ms step_avg:92.18ms +step:1113/1660 train_time:102601ms step_avg:92.18ms +step:1114/1660 train_time:102694ms step_avg:92.19ms +step:1115/1660 train_time:102789ms step_avg:92.19ms +step:1116/1660 train_time:102881ms step_avg:92.19ms +step:1117/1660 train_time:102975ms step_avg:92.19ms +step:1118/1660 train_time:103067ms step_avg:92.19ms +step:1119/1660 train_time:103160ms step_avg:92.19ms +step:1120/1660 train_time:103254ms step_avg:92.19ms +step:1121/1660 train_time:103349ms step_avg:92.19ms +step:1122/1660 train_time:103442ms step_avg:92.19ms +step:1123/1660 train_time:103537ms step_avg:92.20ms +step:1124/1660 train_time:103630ms step_avg:92.20ms +step:1125/1660 train_time:103723ms step_avg:92.20ms +step:1125/1660 val_loss:3.4165 train_time:103819ms step_avg:92.28ms +step:1126/1660 train_time:103839ms step_avg:92.22ms +step:1127/1660 train_time:103918ms step_avg:92.21ms +step:1128/1660 train_time:104017ms step_avg:92.21ms +step:1129/1660 train_time:104111ms step_avg:92.22ms +step:1130/1660 train_time:104203ms step_avg:92.22ms +step:1131/1660 train_time:104295ms step_avg:92.21ms +step:1132/1660 train_time:104387ms step_avg:92.21ms +step:1133/1660 train_time:104479ms step_avg:92.21ms +step:1134/1660 train_time:104571ms step_avg:92.21ms +step:1135/1660 train_time:104664ms step_avg:92.21ms +step:1136/1660 train_time:104756ms step_avg:92.22ms +step:1137/1660 train_time:104853ms step_avg:92.22ms +step:1138/1660 train_time:104951ms step_avg:92.22ms +step:1139/1660 train_time:105048ms step_avg:92.23ms +step:1140/1660 train_time:105142ms step_avg:92.23ms +step:1141/1660 train_time:105234ms step_avg:92.23ms +step:1142/1660 train_time:105326ms step_avg:92.23ms +step:1143/1660 train_time:105418ms step_avg:92.23ms +step:1144/1660 train_time:105511ms step_avg:92.23ms +step:1145/1660 train_time:105603ms step_avg:92.23ms +step:1146/1660 train_time:105695ms step_avg:92.23ms +step:1147/1660 train_time:105791ms step_avg:92.23ms +step:1148/1660 train_time:105887ms step_avg:92.24ms +step:1149/1660 train_time:105982ms step_avg:92.24ms +step:1150/1660 train_time:106075ms step_avg:92.24ms 
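
A quick consistency check on the two timing columns: step_avg is simply the cumulative train_time divided by the printed step index, and extrapolating it to the full run gives a rough wall-clock estimate (the numbers below are copied from the step:1150 line just above; validation and logging overhead is excluded, since the timer is paused around evals):

train_time_ms, step, total_steps = 106075, 1150, 1660
print(train_time_ms / step)        # 92.239... ms, matching the printed step_avg of 92.24ms
print(total_steps * 92.24 / 1000)  # ~153.1 s projected training time for all 1660 steps
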
+step:1151/1660 train_time:106169ms step_avg:92.24ms +step:1152/1660 train_time:106263ms step_avg:92.24ms +step:1153/1660 train_time:106356ms step_avg:92.24ms +step:1154/1660 train_time:106450ms step_avg:92.24ms +step:1155/1660 train_time:106543ms step_avg:92.24ms +step:1156/1660 train_time:106634ms step_avg:92.24ms +step:1157/1660 train_time:106727ms step_avg:92.24ms +step:1158/1660 train_time:106820ms step_avg:92.25ms +step:1159/1660 train_time:106914ms step_avg:92.25ms +step:1160/1660 train_time:107008ms step_avg:92.25ms +step:1161/1660 train_time:107103ms step_avg:92.25ms +step:1162/1660 train_time:107196ms step_avg:92.25ms +step:1163/1660 train_time:107290ms step_avg:92.25ms +step:1164/1660 train_time:107384ms step_avg:92.25ms +step:1165/1660 train_time:107476ms step_avg:92.25ms +step:1166/1660 train_time:107570ms step_avg:92.26ms +step:1167/1660 train_time:107662ms step_avg:92.26ms +step:1168/1660 train_time:107754ms step_avg:92.26ms +step:1169/1660 train_time:107848ms step_avg:92.26ms +step:1170/1660 train_time:107942ms step_avg:92.26ms +step:1171/1660 train_time:108035ms step_avg:92.26ms +step:1172/1660 train_time:108129ms step_avg:92.26ms +step:1173/1660 train_time:108223ms step_avg:92.26ms +step:1174/1660 train_time:108316ms step_avg:92.26ms +step:1175/1660 train_time:108410ms step_avg:92.26ms +step:1176/1660 train_time:108503ms step_avg:92.26ms +step:1177/1660 train_time:108595ms step_avg:92.26ms +step:1178/1660 train_time:108689ms step_avg:92.27ms +step:1179/1660 train_time:108782ms step_avg:92.27ms +step:1180/1660 train_time:108876ms step_avg:92.27ms +step:1181/1660 train_time:108969ms step_avg:92.27ms +step:1182/1660 train_time:109064ms step_avg:92.27ms +step:1183/1660 train_time:109157ms step_avg:92.27ms +step:1184/1660 train_time:109251ms step_avg:92.27ms +step:1185/1660 train_time:109345ms step_avg:92.27ms +step:1186/1660 train_time:109438ms step_avg:92.27ms +step:1187/1660 train_time:109531ms step_avg:92.28ms +step:1188/1660 train_time:109625ms step_avg:92.28ms +step:1189/1660 train_time:109717ms step_avg:92.28ms +step:1190/1660 train_time:109812ms step_avg:92.28ms +step:1191/1660 train_time:109906ms step_avg:92.28ms +step:1192/1660 train_time:109999ms step_avg:92.28ms +step:1193/1660 train_time:110092ms step_avg:92.28ms +step:1194/1660 train_time:110186ms step_avg:92.28ms +step:1195/1660 train_time:110280ms step_avg:92.28ms +step:1196/1660 train_time:110372ms step_avg:92.28ms +step:1197/1660 train_time:110466ms step_avg:92.29ms +step:1198/1660 train_time:110560ms step_avg:92.29ms +step:1199/1660 train_time:110653ms step_avg:92.29ms +step:1200/1660 train_time:110747ms step_avg:92.29ms +step:1201/1660 train_time:110840ms step_avg:92.29ms +step:1202/1660 train_time:110933ms step_avg:92.29ms +step:1203/1660 train_time:111026ms step_avg:92.29ms +step:1204/1660 train_time:111120ms step_avg:92.29ms +step:1205/1660 train_time:111213ms step_avg:92.29ms +step:1206/1660 train_time:111307ms step_avg:92.29ms +step:1207/1660 train_time:111400ms step_avg:92.30ms +step:1208/1660 train_time:111493ms step_avg:92.30ms +step:1209/1660 train_time:111588ms step_avg:92.30ms +step:1210/1660 train_time:111682ms step_avg:92.30ms +step:1211/1660 train_time:111774ms step_avg:92.30ms +step:1212/1660 train_time:111868ms step_avg:92.30ms +step:1213/1660 train_time:111961ms step_avg:92.30ms +step:1214/1660 train_time:112055ms step_avg:92.30ms +step:1215/1660 train_time:112148ms step_avg:92.30ms +step:1216/1660 train_time:112242ms step_avg:92.30ms +step:1217/1660 train_time:112334ms step_avg:92.30ms 
+step:1218/1660 train_time:112429ms step_avg:92.31ms +step:1219/1660 train_time:112523ms step_avg:92.31ms +step:1220/1660 train_time:112617ms step_avg:92.31ms +step:1221/1660 train_time:112710ms step_avg:92.31ms +step:1222/1660 train_time:112803ms step_avg:92.31ms +step:1223/1660 train_time:112896ms step_avg:92.31ms +step:1224/1660 train_time:112991ms step_avg:92.31ms +step:1225/1660 train_time:113085ms step_avg:92.31ms +step:1226/1660 train_time:113179ms step_avg:92.32ms +step:1227/1660 train_time:113271ms step_avg:92.32ms +step:1228/1660 train_time:113365ms step_avg:92.32ms +step:1229/1660 train_time:113458ms step_avg:92.32ms +step:1230/1660 train_time:113552ms step_avg:92.32ms +step:1231/1660 train_time:113647ms step_avg:92.32ms +step:1232/1660 train_time:113741ms step_avg:92.32ms +step:1233/1660 train_time:113833ms step_avg:92.32ms +step:1234/1660 train_time:113927ms step_avg:92.32ms +step:1235/1660 train_time:114019ms step_avg:92.32ms +step:1236/1660 train_time:114112ms step_avg:92.32ms +step:1237/1660 train_time:114206ms step_avg:92.32ms +step:1238/1660 train_time:114299ms step_avg:92.33ms +step:1239/1660 train_time:114392ms step_avg:92.33ms +step:1240/1660 train_time:114486ms step_avg:92.33ms +step:1241/1660 train_time:114580ms step_avg:92.33ms +step:1242/1660 train_time:114673ms step_avg:92.33ms +step:1243/1660 train_time:114767ms step_avg:92.33ms +step:1244/1660 train_time:114860ms step_avg:92.33ms +step:1245/1660 train_time:114953ms step_avg:92.33ms +step:1246/1660 train_time:115047ms step_avg:92.33ms +step:1247/1660 train_time:115140ms step_avg:92.33ms +step:1248/1660 train_time:115232ms step_avg:92.33ms +step:1249/1660 train_time:115325ms step_avg:92.33ms +step:1250/1660 train_time:115419ms step_avg:92.33ms +step:1250/1660 val_loss:3.3774 train_time:115513ms step_avg:92.41ms +step:1251/1660 train_time:115533ms step_avg:92.35ms +step:1252/1660 train_time:115609ms step_avg:92.34ms +step:1253/1660 train_time:115706ms step_avg:92.34ms +step:1254/1660 train_time:115799ms step_avg:92.34ms +step:1255/1660 train_time:115891ms step_avg:92.34ms +step:1256/1660 train_time:115983ms step_avg:92.34ms +step:1257/1660 train_time:116075ms step_avg:92.34ms +step:1258/1660 train_time:116168ms step_avg:92.34ms +step:1259/1660 train_time:116261ms step_avg:92.34ms +step:1260/1660 train_time:116354ms step_avg:92.34ms +step:1261/1660 train_time:116447ms step_avg:92.35ms +step:1262/1660 train_time:116543ms step_avg:92.35ms +step:1263/1660 train_time:116639ms step_avg:92.35ms +step:1264/1660 train_time:116733ms step_avg:92.35ms +step:1265/1660 train_time:116826ms step_avg:92.35ms +step:1266/1660 train_time:116919ms step_avg:92.35ms +step:1267/1660 train_time:117012ms step_avg:92.35ms +step:1268/1660 train_time:117104ms step_avg:92.35ms +step:1269/1660 train_time:117197ms step_avg:92.35ms +step:1270/1660 train_time:117289ms step_avg:92.35ms +step:1271/1660 train_time:117382ms step_avg:92.35ms +step:1272/1660 train_time:117476ms step_avg:92.36ms +step:1273/1660 train_time:117570ms step_avg:92.36ms +step:1274/1660 train_time:117665ms step_avg:92.36ms +step:1275/1660 train_time:117759ms step_avg:92.36ms +step:1276/1660 train_time:117853ms step_avg:92.36ms +step:1277/1660 train_time:117946ms step_avg:92.36ms +step:1278/1660 train_time:118038ms step_avg:92.36ms +step:1279/1660 train_time:118130ms step_avg:92.36ms +step:1280/1660 train_time:118224ms step_avg:92.36ms +step:1281/1660 train_time:118316ms step_avg:92.36ms +step:1282/1660 train_time:118409ms step_avg:92.36ms +step:1283/1660 train_time:118505ms 
step_avg:92.37ms +step:1284/1660 train_time:118600ms step_avg:92.37ms +step:1285/1660 train_time:118694ms step_avg:92.37ms +step:1286/1660 train_time:118788ms step_avg:92.37ms +step:1287/1660 train_time:118881ms step_avg:92.37ms +step:1288/1660 train_time:118974ms step_avg:92.37ms +step:1289/1660 train_time:119067ms step_avg:92.37ms +step:1290/1660 train_time:119160ms step_avg:92.37ms +step:1291/1660 train_time:119252ms step_avg:92.37ms +step:1292/1660 train_time:119344ms step_avg:92.37ms +step:1293/1660 train_time:119438ms step_avg:92.37ms +step:1294/1660 train_time:119531ms step_avg:92.37ms +step:1295/1660 train_time:119626ms step_avg:92.38ms +step:1296/1660 train_time:119720ms step_avg:92.38ms +step:1297/1660 train_time:119814ms step_avg:92.38ms +step:1298/1660 train_time:119907ms step_avg:92.38ms +step:1299/1660 train_time:120001ms step_avg:92.38ms +step:1300/1660 train_time:120094ms step_avg:92.38ms +step:1301/1660 train_time:120187ms step_avg:92.38ms +step:1302/1660 train_time:120280ms step_avg:92.38ms +step:1303/1660 train_time:120373ms step_avg:92.38ms +step:1304/1660 train_time:120466ms step_avg:92.38ms +step:1305/1660 train_time:120561ms step_avg:92.38ms +step:1306/1660 train_time:120655ms step_avg:92.39ms +step:1307/1660 train_time:120747ms step_avg:92.39ms +step:1308/1660 train_time:120842ms step_avg:92.39ms +step:1309/1660 train_time:120937ms step_avg:92.39ms +step:1310/1660 train_time:121030ms step_avg:92.39ms +step:1311/1660 train_time:121124ms step_avg:92.39ms +step:1312/1660 train_time:121217ms step_avg:92.39ms +step:1313/1660 train_time:121309ms step_avg:92.39ms +step:1314/1660 train_time:121402ms step_avg:92.39ms +step:1315/1660 train_time:121496ms step_avg:92.39ms +step:1316/1660 train_time:121590ms step_avg:92.39ms +step:1317/1660 train_time:121683ms step_avg:92.39ms +step:1318/1660 train_time:121777ms step_avg:92.40ms +step:1319/1660 train_time:121870ms step_avg:92.40ms +step:1320/1660 train_time:121964ms step_avg:92.40ms +step:1321/1660 train_time:122058ms step_avg:92.40ms +step:1322/1660 train_time:122151ms step_avg:92.40ms +step:1323/1660 train_time:122246ms step_avg:92.40ms +step:1324/1660 train_time:122339ms step_avg:92.40ms +step:1325/1660 train_time:122432ms step_avg:92.40ms +step:1326/1660 train_time:122526ms step_avg:92.40ms +step:1327/1660 train_time:122620ms step_avg:92.40ms +step:1328/1660 train_time:122713ms step_avg:92.40ms +step:1329/1660 train_time:122807ms step_avg:92.41ms +step:1330/1660 train_time:122902ms step_avg:92.41ms +step:1331/1660 train_time:122998ms step_avg:92.41ms +step:1332/1660 train_time:123091ms step_avg:92.41ms +step:1333/1660 train_time:123184ms step_avg:92.41ms +step:1334/1660 train_time:123276ms step_avg:92.41ms +step:1335/1660 train_time:123369ms step_avg:92.41ms +step:1336/1660 train_time:123464ms step_avg:92.41ms +step:1337/1660 train_time:123557ms step_avg:92.41ms +step:1338/1660 train_time:123650ms step_avg:92.41ms +step:1339/1660 train_time:123743ms step_avg:92.41ms +step:1340/1660 train_time:123838ms step_avg:92.42ms +step:1341/1660 train_time:123930ms step_avg:92.42ms +step:1342/1660 train_time:124024ms step_avg:92.42ms +step:1343/1660 train_time:124117ms step_avg:92.42ms +step:1344/1660 train_time:124209ms step_avg:92.42ms +step:1345/1660 train_time:124304ms step_avg:92.42ms +step:1346/1660 train_time:124398ms step_avg:92.42ms +step:1347/1660 train_time:124491ms step_avg:92.42ms +step:1348/1660 train_time:124584ms step_avg:92.42ms +step:1349/1660 train_time:124678ms step_avg:92.42ms +step:1350/1660 train_time:124770ms 
step_avg:92.42ms +step:1351/1660 train_time:124865ms step_avg:92.42ms +step:1352/1660 train_time:124960ms step_avg:92.43ms +step:1353/1660 train_time:125053ms step_avg:92.43ms +step:1354/1660 train_time:125146ms step_avg:92.43ms +step:1355/1660 train_time:125239ms step_avg:92.43ms +step:1356/1660 train_time:125331ms step_avg:92.43ms +step:1357/1660 train_time:125424ms step_avg:92.43ms +step:1358/1660 train_time:125518ms step_avg:92.43ms +step:1359/1660 train_time:125610ms step_avg:92.43ms +step:1360/1660 train_time:125705ms step_avg:92.43ms +step:1361/1660 train_time:125798ms step_avg:92.43ms +step:1362/1660 train_time:125891ms step_avg:92.43ms +step:1363/1660 train_time:125986ms step_avg:92.43ms +step:1364/1660 train_time:126079ms step_avg:92.43ms +step:1365/1660 train_time:126171ms step_avg:92.43ms +step:1366/1660 train_time:126266ms step_avg:92.43ms +step:1367/1660 train_time:126359ms step_avg:92.44ms +step:1368/1660 train_time:126452ms step_avg:92.44ms +step:1369/1660 train_time:126545ms step_avg:92.44ms +step:1370/1660 train_time:126639ms step_avg:92.44ms +step:1371/1660 train_time:126732ms step_avg:92.44ms +step:1372/1660 train_time:126825ms step_avg:92.44ms +step:1373/1660 train_time:126918ms step_avg:92.44ms +step:1374/1660 train_time:127011ms step_avg:92.44ms +step:1375/1660 train_time:127105ms step_avg:92.44ms +step:1375/1660 val_loss:3.3433 train_time:127200ms step_avg:92.51ms +step:1376/1660 train_time:127222ms step_avg:92.46ms +step:1377/1660 train_time:127298ms step_avg:92.45ms +step:1378/1660 train_time:127395ms step_avg:92.45ms +step:1379/1660 train_time:127489ms step_avg:92.45ms +step:1380/1660 train_time:127581ms step_avg:92.45ms +step:1381/1660 train_time:127673ms step_avg:92.45ms +step:1382/1660 train_time:127765ms step_avg:92.45ms +step:1383/1660 train_time:127857ms step_avg:92.45ms +step:1384/1660 train_time:127950ms step_avg:92.45ms +step:1385/1660 train_time:128042ms step_avg:92.45ms +step:1386/1660 train_time:128136ms step_avg:92.45ms +step:1387/1660 train_time:128234ms step_avg:92.45ms +step:1388/1660 train_time:128331ms step_avg:92.46ms +step:1389/1660 train_time:128426ms step_avg:92.46ms +step:1390/1660 train_time:128520ms step_avg:92.46ms +step:1391/1660 train_time:128612ms step_avg:92.46ms +step:1392/1660 train_time:128704ms step_avg:92.46ms +step:1393/1660 train_time:128797ms step_avg:92.46ms +step:1394/1660 train_time:128890ms step_avg:92.46ms +step:1395/1660 train_time:128982ms step_avg:92.46ms +step:1396/1660 train_time:129075ms step_avg:92.46ms +step:1397/1660 train_time:129169ms step_avg:92.46ms +step:1398/1660 train_time:129263ms step_avg:92.46ms +step:1399/1660 train_time:129357ms step_avg:92.46ms +step:1400/1660 train_time:129452ms step_avg:92.47ms +step:1401/1660 train_time:129545ms step_avg:92.47ms +step:1402/1660 train_time:129637ms step_avg:92.47ms +step:1403/1660 train_time:129731ms step_avg:92.47ms +step:1404/1660 train_time:129824ms step_avg:92.47ms +step:1405/1660 train_time:129916ms step_avg:92.47ms +step:1406/1660 train_time:130009ms step_avg:92.47ms +step:1407/1660 train_time:130102ms step_avg:92.47ms +step:1408/1660 train_time:130195ms step_avg:92.47ms +step:1409/1660 train_time:130291ms step_avg:92.47ms +step:1410/1660 train_time:130386ms step_avg:92.47ms +step:1411/1660 train_time:130480ms step_avg:92.47ms +step:1412/1660 train_time:130573ms step_avg:92.47ms +step:1413/1660 train_time:130666ms step_avg:92.47ms +step:1414/1660 train_time:130759ms step_avg:92.47ms +step:1415/1660 train_time:130852ms step_avg:92.47ms +step:1416/1660 
train_time:130944ms step_avg:92.47ms +step:1417/1660 train_time:131036ms step_avg:92.47ms +step:1418/1660 train_time:131131ms step_avg:92.48ms +step:1419/1660 train_time:131225ms step_avg:92.48ms +step:1420/1660 train_time:131318ms step_avg:92.48ms +step:1421/1660 train_time:131412ms step_avg:92.48ms +step:1422/1660 train_time:131506ms step_avg:92.48ms +step:1423/1660 train_time:131599ms step_avg:92.48ms +step:1424/1660 train_time:131693ms step_avg:92.48ms +step:1425/1660 train_time:131787ms step_avg:92.48ms +step:1426/1660 train_time:131880ms step_avg:92.48ms +step:1427/1660 train_time:131972ms step_avg:92.48ms +step:1428/1660 train_time:132065ms step_avg:92.48ms +step:1429/1660 train_time:132159ms step_avg:92.48ms +step:1430/1660 train_time:132253ms step_avg:92.48ms +step:1431/1660 train_time:132346ms step_avg:92.49ms +step:1432/1660 train_time:132439ms step_avg:92.49ms +step:1433/1660 train_time:132533ms step_avg:92.49ms +step:1434/1660 train_time:132628ms step_avg:92.49ms +step:1435/1660 train_time:132724ms step_avg:92.49ms +step:1436/1660 train_time:132816ms step_avg:92.49ms +step:1437/1660 train_time:132909ms step_avg:92.49ms +step:1438/1660 train_time:133002ms step_avg:92.49ms +step:1439/1660 train_time:133096ms step_avg:92.49ms +step:1440/1660 train_time:133190ms step_avg:92.49ms +step:1441/1660 train_time:133284ms step_avg:92.49ms +step:1442/1660 train_time:133377ms step_avg:92.49ms +step:1443/1660 train_time:133470ms step_avg:92.49ms +step:1444/1660 train_time:133564ms step_avg:92.50ms +step:1445/1660 train_time:133657ms step_avg:92.50ms +step:1446/1660 train_time:133752ms step_avg:92.50ms +step:1447/1660 train_time:133845ms step_avg:92.50ms +step:1448/1660 train_time:133937ms step_avg:92.50ms +step:1449/1660 train_time:134032ms step_avg:92.50ms +step:1450/1660 train_time:134125ms step_avg:92.50ms +step:1451/1660 train_time:134218ms step_avg:92.50ms +step:1452/1660 train_time:134311ms step_avg:92.50ms +step:1453/1660 train_time:134404ms step_avg:92.50ms +step:1454/1660 train_time:134497ms step_avg:92.50ms +step:1455/1660 train_time:134592ms step_avg:92.50ms +step:1456/1660 train_time:134686ms step_avg:92.50ms +step:1457/1660 train_time:134779ms step_avg:92.50ms +step:1458/1660 train_time:134872ms step_avg:92.50ms +step:1459/1660 train_time:134965ms step_avg:92.51ms +step:1460/1660 train_time:135059ms step_avg:92.51ms +step:1461/1660 train_time:135153ms step_avg:92.51ms +step:1462/1660 train_time:135246ms step_avg:92.51ms +step:1463/1660 train_time:135338ms step_avg:92.51ms +step:1464/1660 train_time:135433ms step_avg:92.51ms +step:1465/1660 train_time:135526ms step_avg:92.51ms +step:1466/1660 train_time:135620ms step_avg:92.51ms +step:1467/1660 train_time:135714ms step_avg:92.51ms +step:1468/1660 train_time:135807ms step_avg:92.51ms +step:1469/1660 train_time:135901ms step_avg:92.51ms +step:1470/1660 train_time:135994ms step_avg:92.51ms +step:1471/1660 train_time:136088ms step_avg:92.51ms +step:1472/1660 train_time:136182ms step_avg:92.51ms +step:1473/1660 train_time:136274ms step_avg:92.51ms +step:1474/1660 train_time:136368ms step_avg:92.52ms +step:1475/1660 train_time:136462ms step_avg:92.52ms +step:1476/1660 train_time:136554ms step_avg:92.52ms +step:1477/1660 train_time:136648ms step_avg:92.52ms +step:1478/1660 train_time:136741ms step_avg:92.52ms +step:1479/1660 train_time:136834ms step_avg:92.52ms +step:1480/1660 train_time:136929ms step_avg:92.52ms +step:1481/1660 train_time:137023ms step_avg:92.52ms +step:1482/1660 train_time:137116ms step_avg:92.52ms +step:1483/1660 
train_time:137209ms step_avg:92.52ms +step:1484/1660 train_time:137304ms step_avg:92.52ms +step:1485/1660 train_time:137397ms step_avg:92.52ms +step:1486/1660 train_time:137490ms step_avg:92.52ms +step:1487/1660 train_time:137583ms step_avg:92.52ms +step:1488/1660 train_time:137676ms step_avg:92.52ms +step:1489/1660 train_time:137769ms step_avg:92.52ms +step:1490/1660 train_time:137862ms step_avg:92.53ms +step:1491/1660 train_time:137956ms step_avg:92.53ms +step:1492/1660 train_time:138051ms step_avg:92.53ms +step:1493/1660 train_time:138144ms step_avg:92.53ms +step:1494/1660 train_time:138236ms step_avg:92.53ms +step:1495/1660 train_time:138331ms step_avg:92.53ms +step:1496/1660 train_time:138425ms step_avg:92.53ms +step:1497/1660 train_time:138517ms step_avg:92.53ms +step:1498/1660 train_time:138611ms step_avg:92.53ms +step:1499/1660 train_time:138704ms step_avg:92.53ms +step:1500/1660 train_time:138797ms step_avg:92.53ms +step:1500/1660 val_loss:3.3137 train_time:138892ms step_avg:92.59ms +step:1501/1660 train_time:138913ms step_avg:92.55ms +step:1502/1660 train_time:138992ms step_avg:92.54ms +step:1503/1660 train_time:139091ms step_avg:92.54ms +step:1504/1660 train_time:139185ms step_avg:92.54ms +step:1505/1660 train_time:139277ms step_avg:92.54ms +step:1506/1660 train_time:139370ms step_avg:92.54ms +step:1507/1660 train_time:139461ms step_avg:92.54ms +step:1508/1660 train_time:139553ms step_avg:92.54ms +step:1509/1660 train_time:139645ms step_avg:92.54ms +step:1510/1660 train_time:139738ms step_avg:92.54ms +step:1511/1660 train_time:139831ms step_avg:92.54ms +step:1512/1660 train_time:139925ms step_avg:92.54ms +step:1513/1660 train_time:140022ms step_avg:92.55ms +step:1514/1660 train_time:140119ms step_avg:92.55ms +step:1515/1660 train_time:140212ms step_avg:92.55ms +step:1516/1660 train_time:140306ms step_avg:92.55ms +step:1517/1660 train_time:140398ms step_avg:92.55ms +step:1518/1660 train_time:140491ms step_avg:92.55ms +step:1519/1660 train_time:140583ms step_avg:92.55ms +step:1520/1660 train_time:140676ms step_avg:92.55ms +step:1521/1660 train_time:140768ms step_avg:92.55ms +step:1522/1660 train_time:140861ms step_avg:92.55ms +step:1523/1660 train_time:140957ms step_avg:92.55ms +step:1524/1660 train_time:141052ms step_avg:92.55ms +step:1525/1660 train_time:141146ms step_avg:92.55ms +step:1526/1660 train_time:141240ms step_avg:92.56ms +step:1527/1660 train_time:141334ms step_avg:92.56ms +step:1528/1660 train_time:141426ms step_avg:92.56ms +step:1529/1660 train_time:141518ms step_avg:92.56ms +step:1530/1660 train_time:141611ms step_avg:92.56ms +step:1531/1660 train_time:141703ms step_avg:92.56ms +step:1532/1660 train_time:141796ms step_avg:92.56ms +step:1533/1660 train_time:141890ms step_avg:92.56ms +step:1534/1660 train_time:141984ms step_avg:92.56ms +step:1535/1660 train_time:142078ms step_avg:92.56ms +step:1536/1660 train_time:142173ms step_avg:92.56ms +step:1537/1660 train_time:142266ms step_avg:92.56ms +step:1538/1660 train_time:142360ms step_avg:92.56ms +step:1539/1660 train_time:142454ms step_avg:92.56ms +step:1540/1660 train_time:142547ms step_avg:92.56ms +step:1541/1660 train_time:142639ms step_avg:92.56ms +step:1542/1660 train_time:142732ms step_avg:92.56ms +step:1543/1660 train_time:142824ms step_avg:92.56ms +step:1544/1660 train_time:142919ms step_avg:92.56ms +step:1545/1660 train_time:143013ms step_avg:92.57ms +step:1546/1660 train_time:143107ms step_avg:92.57ms +step:1547/1660 train_time:143200ms step_avg:92.57ms +step:1548/1660 train_time:143294ms step_avg:92.57ms 
+step:1549/1660 train_time:143387ms step_avg:92.57ms +step:1550/1660 train_time:143480ms step_avg:92.57ms +step:1551/1660 train_time:143573ms step_avg:92.57ms +step:1552/1660 train_time:143665ms step_avg:92.57ms +step:1553/1660 train_time:143758ms step_avg:92.57ms +step:1554/1660 train_time:143852ms step_avg:92.57ms +step:1555/1660 train_time:143945ms step_avg:92.57ms +step:1556/1660 train_time:144041ms step_avg:92.57ms +step:1557/1660 train_time:144134ms step_avg:92.57ms +step:1558/1660 train_time:144227ms step_avg:92.57ms +step:1559/1660 train_time:144321ms step_avg:92.57ms +step:1560/1660 train_time:144415ms step_avg:92.57ms +step:1561/1660 train_time:144507ms step_avg:92.57ms +step:1562/1660 train_time:144600ms step_avg:92.57ms +step:1563/1660 train_time:144693ms step_avg:92.57ms +step:1564/1660 train_time:144786ms step_avg:92.57ms +step:1565/1660 train_time:144879ms step_avg:92.57ms +step:1566/1660 train_time:144973ms step_avg:92.58ms +step:1567/1660 train_time:145066ms step_avg:92.58ms +step:1568/1660 train_time:145159ms step_avg:92.58ms +step:1569/1660 train_time:145253ms step_avg:92.58ms +step:1570/1660 train_time:145347ms step_avg:92.58ms +step:1571/1660 train_time:145440ms step_avg:92.58ms +step:1572/1660 train_time:145533ms step_avg:92.58ms +step:1573/1660 train_time:145627ms step_avg:92.58ms +step:1574/1660 train_time:145720ms step_avg:92.58ms +step:1575/1660 train_time:145814ms step_avg:92.58ms +step:1576/1660 train_time:145907ms step_avg:92.58ms +step:1577/1660 train_time:146000ms step_avg:92.58ms +step:1578/1660 train_time:146095ms step_avg:92.58ms +step:1579/1660 train_time:146189ms step_avg:92.58ms +step:1580/1660 train_time:146283ms step_avg:92.58ms +step:1581/1660 train_time:146378ms step_avg:92.59ms +step:1582/1660 train_time:146472ms step_avg:92.59ms +step:1583/1660 train_time:146564ms step_avg:92.59ms +step:1584/1660 train_time:146657ms step_avg:92.59ms +step:1585/1660 train_time:146750ms step_avg:92.59ms +step:1586/1660 train_time:146843ms step_avg:92.59ms +step:1587/1660 train_time:146937ms step_avg:92.59ms +step:1588/1660 train_time:147031ms step_avg:92.59ms +step:1589/1660 train_time:147124ms step_avg:92.59ms +step:1590/1660 train_time:147218ms step_avg:92.59ms +step:1591/1660 train_time:147313ms step_avg:92.59ms +step:1592/1660 train_time:147406ms step_avg:92.59ms +step:1593/1660 train_time:147499ms step_avg:92.59ms +step:1594/1660 train_time:147592ms step_avg:92.59ms +step:1595/1660 train_time:147685ms step_avg:92.59ms +step:1596/1660 train_time:147778ms step_avg:92.59ms +step:1597/1660 train_time:147871ms step_avg:92.59ms +step:1598/1660 train_time:147964ms step_avg:92.59ms +step:1599/1660 train_time:148059ms step_avg:92.59ms +step:1600/1660 train_time:148153ms step_avg:92.60ms +step:1601/1660 train_time:148247ms step_avg:92.60ms +step:1602/1660 train_time:148340ms step_avg:92.60ms +step:1603/1660 train_time:148434ms step_avg:92.60ms +step:1604/1660 train_time:148527ms step_avg:92.60ms +step:1605/1660 train_time:148620ms step_avg:92.60ms +step:1606/1660 train_time:148714ms step_avg:92.60ms +step:1607/1660 train_time:148806ms step_avg:92.60ms +step:1608/1660 train_time:148899ms step_avg:92.60ms +step:1609/1660 train_time:148994ms step_avg:92.60ms +step:1610/1660 train_time:149087ms step_avg:92.60ms +step:1611/1660 train_time:149183ms step_avg:92.60ms +step:1612/1660 train_time:149276ms step_avg:92.60ms +step:1613/1660 train_time:149369ms step_avg:92.60ms +step:1614/1660 train_time:149462ms step_avg:92.60ms +step:1615/1660 train_time:149556ms step_avg:92.60ms 
+step:1616/1660 train_time:149649ms step_avg:92.60ms +step:1617/1660 train_time:149741ms step_avg:92.60ms +step:1618/1660 train_time:149835ms step_avg:92.60ms +step:1619/1660 train_time:149927ms step_avg:92.60ms +step:1620/1660 train_time:150021ms step_avg:92.61ms +step:1621/1660 train_time:150116ms step_avg:92.61ms +step:1622/1660 train_time:150211ms step_avg:92.61ms +step:1623/1660 train_time:150303ms step_avg:92.61ms +step:1624/1660 train_time:150396ms step_avg:92.61ms +step:1625/1660 train_time:150490ms step_avg:92.61ms +step:1625/1660 val_loss:3.2892 train_time:150585ms step_avg:92.67ms +step:1626/1660 train_time:150605ms step_avg:92.62ms +step:1627/1660 train_time:150683ms step_avg:92.61ms +step:1628/1660 train_time:150784ms step_avg:92.62ms +step:1629/1660 train_time:150877ms step_avg:92.62ms +step:1630/1660 train_time:150970ms step_avg:92.62ms +step:1631/1660 train_time:151062ms step_avg:92.62ms +step:1632/1660 train_time:151154ms step_avg:92.62ms +step:1633/1660 train_time:151245ms step_avg:92.62ms +step:1634/1660 train_time:151338ms step_avg:92.62ms +step:1635/1660 train_time:151430ms step_avg:92.62ms +step:1636/1660 train_time:151523ms step_avg:92.62ms +step:1637/1660 train_time:151618ms step_avg:92.62ms +step:1638/1660 train_time:151715ms step_avg:92.62ms +step:1639/1660 train_time:151809ms step_avg:92.62ms +step:1640/1660 train_time:151903ms step_avg:92.62ms +step:1641/1660 train_time:151996ms step_avg:92.62ms +step:1642/1660 train_time:152089ms step_avg:92.62ms +step:1643/1660 train_time:152183ms step_avg:92.62ms +step:1644/1660 train_time:152275ms step_avg:92.62ms +step:1645/1660 train_time:152367ms step_avg:92.62ms +step:1646/1660 train_time:152460ms step_avg:92.62ms +step:1647/1660 train_time:152553ms step_avg:92.62ms +step:1648/1660 train_time:152647ms step_avg:92.63ms +step:1649/1660 train_time:152743ms step_avg:92.63ms +step:1650/1660 train_time:152838ms step_avg:92.63ms +step:1651/1660 train_time:152932ms step_avg:92.63ms +step:1652/1660 train_time:153025ms step_avg:92.63ms +step:1653/1660 train_time:153117ms step_avg:92.63ms +step:1654/1660 train_time:153210ms step_avg:92.63ms +step:1655/1660 train_time:153303ms step_avg:92.63ms +step:1656/1660 train_time:153396ms step_avg:92.63ms +step:1657/1660 train_time:153489ms step_avg:92.63ms +step:1658/1660 train_time:153584ms step_avg:92.63ms +step:1659/1660 train_time:153680ms step_avg:92.63ms +step:1660/1660 train_time:153774ms step_avg:92.63ms +step:1660/1660 val_loss:3.2813 train_time:153869ms step_avg:92.69ms +peak memory allocated: 32002 MiB reserved: 47316 MiB diff --git a/records/091525_ThreadingFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt b/records/091525_ThreadingFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt new file mode 100644 index 000000000..4fe370693 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True 
# we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + 
ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) 
== M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
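+        # For orientation: a minimal single-device sketch of the update that this sharded
+        # machinery implements for each 2D parameter W with gradient G (ignoring the
+        # per-parameter lr_mul/wd_mul multipliers; `ns` abbreviates newton_schulz_triton):
+        #
+        #     buf.lerp_(G, 1 - momentum)        # buf = momentum*buf + (1-momentum)*G
+        #     U = ns(G.lerp(buf, momentum))     # Nesterov-style blend, then orthogonalize
+        #     W.mul_(1 - lr * wd).add_(U, alpha=-lr * max(1, W.size(-2) / W.size(-1)) ** 0.5)
+        #
+        # The passes below compute the same thing, but shard the per-parameter work across
+        # ranks: reduce_scatter averages gradients in, all_gather broadcasts updated params out.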
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
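+        # Note: stacked_params has padded_num_params rows; any trailing zero-padded rows are
+        # never read back, since the copy below only touches the first len(orig_params) entries.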
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+import threading # used for the async shard scanning below
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            #t0 = time.perf_counter()
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            #t1 = time.perf_counter()
+            #print(f'{t1-t0} slowload')
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full BOS index after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    def __init__(self, file_iter, world_size):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
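+
+# The loader packs whole documents: every sequence starts at a BOS token, runs to the
+# next BOS (truncated to max_seq_len), and each rank receives exactly
+# num_tokens_local + 1 tokens so inputs and targets can be shifted by one. A minimal
+# sketch of that greedy packing on a plain Python list (illustrative only, never
+# called during training):
+def _demo_bos_packing(tokens: list[int], num_tokens_local: int, max_seq_len: int):
+    bos_idx = [i for i, t in enumerate(tokens) if t == BOS_ID]
+    spans, cur_len, idx = [], 0, 0
+    while cur_len <= num_tokens_local: # same stopping rule as BOSFinder.next_batch
+        cur = bos_idx[idx]
+        nxt = bos_idx[idx + 1] if idx + 1 < len(bos_idx) else len(tokens)
+        end = min(nxt, cur + max_seq_len, cur + num_tokens_local - cur_len + 1)
+        spans.append((cur, end))
+        cur_len += end - cur
+        idx += 1
+    buf = [t for s, e in spans for t in tokens[s:e]] # concatenated documents
+    assert len(buf) == num_tokens_local + 1
+    return buf[:-1], buf[1:] # inputs, targets
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences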
truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + #loading in a whole shard will be slow... + #BosFinder tracks its self.i index. I can kickoff with a smaller set. then have the + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + #tokens = _load_data_shard(next(file_iter)) + #finder = BOSFinder(tokens, world_size=world_size) + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1660 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"data_threading_1660/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer is largely indifferent to window length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    logfile = f"logs/{run_id}.txt"
+    os.makedirs(os.path.dirname(logfile), exist_ok=True) # run_id contains a subdirectory
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
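+
+# The global batch is held fixed no matter how many GPUs run the job: with
+# 8 // world_size gradient-accumulation steps, every optimizer step sees
+# train_batch_size tokens and every forward pass sees the same 49152 tokens per
+# rank. A small illustrative check of that arithmetic (never called during training):
+def _demo_batch_invariance():
+    train_batch_size = 2048 * 24 * 8 # tokens per optimizer step, as in Hyperparameters
+    for world_size in (1, 2, 4, 8):
+        grad_accum_steps = 8 // world_size
+        per_forward = train_batch_size // grad_accum_steps # tokens per micro-batch
+        per_rank = per_forward // world_size # tokens each GPU processes per forward
+        assert per_rank == 49152 and grad_accum_steps * per_forward == train_batch_size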
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations # progress in training
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1 # linear cooldown from 1.0 to 0.1
+    return lr
+
+# attention window size schedule: short -> long
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws != ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:54:13 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 197242 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 197243 C /usr/bin/python3 614MiB | +| 0 N/A N/A 197244 C /usr/bin/python3 614MiB | +| 0 N/A N/A 197245 C /usr/bin/python3 614MiB | +| 0 N/A N/A 197246 C /usr/bin/python3 614MiB | +| 0 N/A N/A 197247 C /usr/bin/python3 614MiB | +| 0 N/A N/A 197248 C /usr/bin/python3 614MiB | +| 0 N/A N/A 197249 C /usr/bin/python3 614MiB | +| 1 N/A N/A 197243 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 197244 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 197245 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 197246 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 197247 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 197248 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 197249 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:145ms step_avg:145.00ms +step:2/1660 train_time:166ms step_avg:82.86ms +step:3/1660 train_time:232ms step_avg:77.41ms +step:4/1660 train_time:321ms step_avg:80.27ms +step:5/1660 train_time:411ms step_avg:82.26ms +step:6/1660 train_time:501ms step_avg:83.55ms +step:7/1660 train_time:592ms step_avg:84.53ms +step:8/1660 train_time:682ms step_avg:85.30ms +step:9/1660 train_time:774ms step_avg:85.96ms +step:10/1660 train_time:864ms step_avg:86.43ms +step:11/1660 train_time:955ms step_avg:86.85ms +step:12/1660 train_time:1048ms step_avg:87.32ms +step:13/1660 train_time:1143ms step_avg:87.94ms +step:14/1660 train_time:1239ms step_avg:88.49ms +step:15/1660 train_time:1332ms step_avg:88.78ms +step:16/1660 train_time:1423ms step_avg:88.94ms +step:17/1660 train_time:1515ms step_avg:89.12ms +step:18/1660 train_time:1606ms step_avg:89.20ms +step:19/1660 train_time:1697ms step_avg:89.31ms +step:20/1660 train_time:1788ms step_avg:89.41ms +step:21/1660 train_time:1879ms step_avg:89.46ms +step:22/1660 train_time:1971ms step_avg:89.60ms +step:23/1660 train_time:2064ms step_avg:89.73ms +step:24/1660 train_time:2156ms step_avg:89.84ms +step:25/1660 train_time:2249ms step_avg:89.95ms +step:26/1660 train_time:2342ms step_avg:90.07ms +step:27/1660 train_time:2434ms step_avg:90.16ms +step:28/1660 train_time:2527ms step_avg:90.25ms +step:29/1660 train_time:2619ms step_avg:90.29ms +step:30/1660 train_time:2711ms step_avg:90.35ms +step:31/1660 train_time:2801ms step_avg:90.37ms +step:32/1660 train_time:2893ms step_avg:90.40ms +step:33/1660 train_time:2985ms step_avg:90.45ms +step:34/1660 train_time:3077ms step_avg:90.49ms +step:35/1660 train_time:3169ms step_avg:90.55ms +step:36/1660 train_time:3261ms step_avg:90.59ms +step:37/1660 train_time:3354ms step_avg:90.65ms +step:38/1660 train_time:3446ms step_avg:90.67ms +step:39/1660 train_time:3538ms step_avg:90.71ms +step:40/1660 train_time:3630ms step_avg:90.74ms +step:41/1660 train_time:3721ms step_avg:90.75ms +step:42/1660 train_time:3812ms step_avg:90.77ms +step:43/1660 train_time:3904ms step_avg:90.78ms +step:44/1660 train_time:3996ms step_avg:90.82ms +step:45/1660 train_time:4089ms step_avg:90.86ms +step:46/1660 train_time:4180ms step_avg:90.88ms +step:47/1660 train_time:4272ms step_avg:90.90ms +step:48/1660 train_time:4365ms step_avg:90.94ms +step:49/1660 train_time:4458ms step_avg:90.98ms +step:50/1660 train_time:4550ms step_avg:91.00ms 
+step:51/1660 train_time:4642ms step_avg:91.02ms +step:52/1660 train_time:4734ms step_avg:91.04ms +step:53/1660 train_time:4825ms step_avg:91.04ms +step:54/1660 train_time:4917ms step_avg:91.05ms +step:55/1660 train_time:5008ms step_avg:91.06ms +step:56/1660 train_time:5100ms step_avg:91.08ms +step:57/1660 train_time:5192ms step_avg:91.09ms +step:58/1660 train_time:5284ms step_avg:91.10ms +step:59/1660 train_time:5376ms step_avg:91.11ms +step:60/1660 train_time:5467ms step_avg:91.12ms +step:61/1660 train_time:5558ms step_avg:91.12ms +step:62/1660 train_time:5650ms step_avg:91.13ms +step:63/1660 train_time:5742ms step_avg:91.14ms +step:64/1660 train_time:5834ms step_avg:91.15ms +step:65/1660 train_time:5925ms step_avg:91.16ms +step:66/1660 train_time:6017ms step_avg:91.17ms +step:67/1660 train_time:6109ms step_avg:91.19ms +step:68/1660 train_time:6201ms step_avg:91.19ms +step:69/1660 train_time:6293ms step_avg:91.20ms +step:70/1660 train_time:6385ms step_avg:91.22ms +step:71/1660 train_time:6477ms step_avg:91.23ms +step:72/1660 train_time:6568ms step_avg:91.22ms +step:73/1660 train_time:6659ms step_avg:91.22ms +step:74/1660 train_time:6751ms step_avg:91.22ms +step:75/1660 train_time:6842ms step_avg:91.22ms +step:76/1660 train_time:6934ms step_avg:91.23ms +step:77/1660 train_time:7025ms step_avg:91.23ms +step:78/1660 train_time:7119ms step_avg:91.27ms +step:79/1660 train_time:7211ms step_avg:91.28ms +step:80/1660 train_time:7302ms step_avg:91.28ms +step:81/1660 train_time:7394ms step_avg:91.29ms +step:82/1660 train_time:7486ms step_avg:91.29ms +step:83/1660 train_time:7578ms step_avg:91.30ms +step:84/1660 train_time:7669ms step_avg:91.29ms +step:85/1660 train_time:7759ms step_avg:91.29ms +step:86/1660 train_time:7850ms step_avg:91.28ms +step:87/1660 train_time:7941ms step_avg:91.28ms +step:88/1660 train_time:8034ms step_avg:91.29ms +step:89/1660 train_time:8125ms step_avg:91.30ms +step:90/1660 train_time:8218ms step_avg:91.31ms +step:91/1660 train_time:8310ms step_avg:91.32ms +step:92/1660 train_time:8402ms step_avg:91.32ms +step:93/1660 train_time:8493ms step_avg:91.33ms +step:94/1660 train_time:8585ms step_avg:91.33ms +step:95/1660 train_time:8677ms step_avg:91.33ms +step:96/1660 train_time:8768ms step_avg:91.33ms +step:97/1660 train_time:8859ms step_avg:91.33ms +step:98/1660 train_time:8950ms step_avg:91.33ms +step:99/1660 train_time:9042ms step_avg:91.33ms +step:100/1660 train_time:9134ms step_avg:91.34ms +step:101/1660 train_time:9226ms step_avg:91.35ms +step:102/1660 train_time:9318ms step_avg:91.36ms +step:103/1660 train_time:9410ms step_avg:91.36ms +step:104/1660 train_time:9501ms step_avg:91.35ms +step:105/1660 train_time:9592ms step_avg:91.36ms +step:106/1660 train_time:9684ms step_avg:91.36ms +step:107/1660 train_time:9775ms step_avg:91.35ms +step:108/1660 train_time:9865ms step_avg:91.34ms +step:109/1660 train_time:9956ms step_avg:91.34ms +step:110/1660 train_time:10048ms step_avg:91.34ms +step:111/1660 train_time:10140ms step_avg:91.35ms +step:112/1660 train_time:10233ms step_avg:91.37ms +step:113/1660 train_time:10324ms step_avg:91.37ms +step:114/1660 train_time:10417ms step_avg:91.38ms +step:115/1660 train_time:10510ms step_avg:91.39ms +step:116/1660 train_time:10602ms step_avg:91.39ms +step:117/1660 train_time:10692ms step_avg:91.39ms +step:118/1660 train_time:10784ms step_avg:91.39ms +step:119/1660 train_time:10875ms step_avg:91.38ms +step:120/1660 train_time:10966ms step_avg:91.39ms +step:121/1660 train_time:11058ms step_avg:91.38ms +step:122/1660 train_time:11149ms 
step_avg:91.39ms +step:123/1660 train_time:11241ms step_avg:91.39ms +step:124/1660 train_time:11333ms step_avg:91.39ms +step:125/1660 train_time:11425ms step_avg:91.40ms +step:125/1660 val_loss:4.2980 train_time:11521ms step_avg:92.17ms +step:126/1660 train_time:11544ms step_avg:91.62ms +step:127/1660 train_time:11617ms step_avg:91.48ms +step:128/1660 train_time:11717ms step_avg:91.54ms +step:129/1660 train_time:11811ms step_avg:91.56ms +step:130/1660 train_time:11903ms step_avg:91.56ms +step:131/1660 train_time:11993ms step_avg:91.55ms +step:132/1660 train_time:12083ms step_avg:91.54ms +step:133/1660 train_time:12173ms step_avg:91.53ms +step:134/1660 train_time:12264ms step_avg:91.52ms +step:135/1660 train_time:12354ms step_avg:91.51ms +step:136/1660 train_time:12444ms step_avg:91.50ms +step:137/1660 train_time:12536ms step_avg:91.51ms +step:138/1660 train_time:12630ms step_avg:91.52ms +step:139/1660 train_time:12724ms step_avg:91.54ms +step:140/1660 train_time:12817ms step_avg:91.55ms +step:141/1660 train_time:12910ms step_avg:91.56ms +step:142/1660 train_time:13000ms step_avg:91.55ms +step:143/1660 train_time:13091ms step_avg:91.54ms +step:144/1660 train_time:13181ms step_avg:91.54ms +step:145/1660 train_time:13272ms step_avg:91.53ms +step:146/1660 train_time:13363ms step_avg:91.53ms +step:147/1660 train_time:13454ms step_avg:91.52ms +step:148/1660 train_time:13546ms step_avg:91.52ms +step:149/1660 train_time:13638ms step_avg:91.53ms +step:150/1660 train_time:13730ms step_avg:91.53ms +step:151/1660 train_time:13823ms step_avg:91.54ms +step:152/1660 train_time:13914ms step_avg:91.54ms +step:153/1660 train_time:14006ms step_avg:91.54ms +step:154/1660 train_time:14097ms step_avg:91.54ms +step:155/1660 train_time:14188ms step_avg:91.54ms +step:156/1660 train_time:14279ms step_avg:91.53ms +step:157/1660 train_time:14370ms step_avg:91.53ms +step:158/1660 train_time:14461ms step_avg:91.53ms +step:159/1660 train_time:14552ms step_avg:91.52ms +step:160/1660 train_time:14646ms step_avg:91.54ms +step:161/1660 train_time:14739ms step_avg:91.54ms +step:162/1660 train_time:14830ms step_avg:91.55ms +step:163/1660 train_time:14922ms step_avg:91.55ms +step:164/1660 train_time:15014ms step_avg:91.55ms +step:165/1660 train_time:15105ms step_avg:91.55ms +step:166/1660 train_time:15196ms step_avg:91.54ms +step:167/1660 train_time:15287ms step_avg:91.54ms +step:168/1660 train_time:15378ms step_avg:91.54ms +step:169/1660 train_time:15469ms step_avg:91.53ms +step:170/1660 train_time:15560ms step_avg:91.53ms +step:171/1660 train_time:15652ms step_avg:91.53ms +step:172/1660 train_time:15744ms step_avg:91.53ms +step:173/1660 train_time:15835ms step_avg:91.53ms +step:174/1660 train_time:15927ms step_avg:91.54ms +step:175/1660 train_time:16020ms step_avg:91.54ms +step:176/1660 train_time:16111ms step_avg:91.54ms +step:177/1660 train_time:16202ms step_avg:91.54ms +step:178/1660 train_time:16293ms step_avg:91.54ms +step:179/1660 train_time:16386ms step_avg:91.54ms +step:180/1660 train_time:16478ms step_avg:91.55ms +step:181/1660 train_time:16569ms step_avg:91.54ms +step:182/1660 train_time:16661ms step_avg:91.54ms +step:183/1660 train_time:16752ms step_avg:91.54ms +step:184/1660 train_time:16844ms step_avg:91.54ms +step:185/1660 train_time:16936ms step_avg:91.54ms +step:186/1660 train_time:17027ms step_avg:91.54ms +step:187/1660 train_time:17119ms step_avg:91.55ms +step:188/1660 train_time:17211ms step_avg:91.55ms +step:189/1660 train_time:17302ms step_avg:91.54ms +step:190/1660 train_time:17392ms step_avg:91.54ms 
+step:191/1660 train_time:17484ms step_avg:91.54ms +step:192/1660 train_time:17575ms step_avg:91.54ms +step:193/1660 train_time:17668ms step_avg:91.54ms +step:194/1660 train_time:17759ms step_avg:91.54ms +step:195/1660 train_time:17850ms step_avg:91.54ms +step:196/1660 train_time:17942ms step_avg:91.54ms +step:197/1660 train_time:18033ms step_avg:91.54ms +step:198/1660 train_time:18125ms step_avg:91.54ms +step:199/1660 train_time:18217ms step_avg:91.54ms +step:200/1660 train_time:18308ms step_avg:91.54ms +step:201/1660 train_time:18399ms step_avg:91.54ms +step:202/1660 train_time:18489ms step_avg:91.53ms +step:203/1660 train_time:18582ms step_avg:91.54ms +step:204/1660 train_time:18673ms step_avg:91.54ms +step:205/1660 train_time:18765ms step_avg:91.54ms +step:206/1660 train_time:18856ms step_avg:91.54ms +step:207/1660 train_time:18948ms step_avg:91.54ms +step:208/1660 train_time:19039ms step_avg:91.54ms +step:209/1660 train_time:19131ms step_avg:91.54ms +step:210/1660 train_time:19223ms step_avg:91.54ms +step:211/1660 train_time:19314ms step_avg:91.53ms +step:212/1660 train_time:19405ms step_avg:91.53ms +step:213/1660 train_time:19496ms step_avg:91.53ms +step:214/1660 train_time:19588ms step_avg:91.53ms +step:215/1660 train_time:19680ms step_avg:91.54ms +step:216/1660 train_time:19772ms step_avg:91.54ms +step:217/1660 train_time:19864ms step_avg:91.54ms +step:218/1660 train_time:19955ms step_avg:91.54ms +step:219/1660 train_time:20046ms step_avg:91.54ms +step:220/1660 train_time:20139ms step_avg:91.54ms +step:221/1660 train_time:20230ms step_avg:91.54ms +step:222/1660 train_time:20321ms step_avg:91.54ms +step:223/1660 train_time:20412ms step_avg:91.53ms +step:224/1660 train_time:20503ms step_avg:91.53ms +step:225/1660 train_time:20594ms step_avg:91.53ms +step:226/1660 train_time:20687ms step_avg:91.54ms +step:227/1660 train_time:20779ms step_avg:91.54ms +step:228/1660 train_time:20871ms step_avg:91.54ms +step:229/1660 train_time:20962ms step_avg:91.54ms +step:230/1660 train_time:21053ms step_avg:91.53ms +step:231/1660 train_time:21144ms step_avg:91.53ms +step:232/1660 train_time:21235ms step_avg:91.53ms +step:233/1660 train_time:21327ms step_avg:91.53ms +step:234/1660 train_time:21418ms step_avg:91.53ms +step:235/1660 train_time:21509ms step_avg:91.53ms +step:236/1660 train_time:21601ms step_avg:91.53ms +step:237/1660 train_time:21692ms step_avg:91.53ms +step:238/1660 train_time:21786ms step_avg:91.54ms +step:239/1660 train_time:21877ms step_avg:91.54ms +step:240/1660 train_time:21969ms step_avg:91.54ms +step:241/1660 train_time:22062ms step_avg:91.54ms +step:242/1660 train_time:22153ms step_avg:91.54ms +step:243/1660 train_time:22245ms step_avg:91.54ms +step:244/1660 train_time:22336ms step_avg:91.54ms +step:245/1660 train_time:22428ms step_avg:91.54ms +step:246/1660 train_time:22519ms step_avg:91.54ms +step:247/1660 train_time:22610ms step_avg:91.54ms +step:248/1660 train_time:22702ms step_avg:91.54ms +step:249/1660 train_time:22793ms step_avg:91.54ms +step:250/1660 train_time:22885ms step_avg:91.54ms +step:250/1660 val_loss:3.9698 train_time:22978ms step_avg:91.91ms +step:251/1660 train_time:22998ms step_avg:91.63ms +step:252/1660 train_time:23071ms step_avg:91.55ms +step:253/1660 train_time:23167ms step_avg:91.57ms +step:254/1660 train_time:23260ms step_avg:91.57ms +step:255/1660 train_time:23350ms step_avg:91.57ms +step:256/1660 train_time:23440ms step_avg:91.56ms +step:257/1660 train_time:23530ms step_avg:91.56ms +step:258/1660 train_time:23621ms step_avg:91.55ms +step:259/1660 
train_time:23711ms step_avg:91.55ms +step:260/1660 train_time:23801ms step_avg:91.54ms +step:261/1660 train_time:23893ms step_avg:91.54ms +step:262/1660 train_time:23985ms step_avg:91.54ms +step:263/1660 train_time:24078ms step_avg:91.55ms +step:264/1660 train_time:24172ms step_avg:91.56ms +step:265/1660 train_time:24264ms step_avg:91.56ms +step:266/1660 train_time:24355ms step_avg:91.56ms +step:267/1660 train_time:24447ms step_avg:91.56ms +step:268/1660 train_time:24537ms step_avg:91.56ms +step:269/1660 train_time:24627ms step_avg:91.55ms +step:270/1660 train_time:24718ms step_avg:91.55ms +step:271/1660 train_time:24810ms step_avg:91.55ms +step:272/1660 train_time:24902ms step_avg:91.55ms +step:273/1660 train_time:24994ms step_avg:91.55ms +step:274/1660 train_time:25086ms step_avg:91.56ms +step:275/1660 train_time:25178ms step_avg:91.56ms +step:276/1660 train_time:25270ms step_avg:91.56ms +step:277/1660 train_time:25361ms step_avg:91.56ms +step:278/1660 train_time:25452ms step_avg:91.55ms +step:279/1660 train_time:25542ms step_avg:91.55ms +step:280/1660 train_time:25633ms step_avg:91.55ms +step:281/1660 train_time:25724ms step_avg:91.54ms +step:282/1660 train_time:25814ms step_avg:91.54ms +step:283/1660 train_time:25905ms step_avg:91.54ms +step:284/1660 train_time:25997ms step_avg:91.54ms +step:285/1660 train_time:26090ms step_avg:91.54ms +step:286/1660 train_time:26183ms step_avg:91.55ms +step:287/1660 train_time:26275ms step_avg:91.55ms +step:288/1660 train_time:26366ms step_avg:91.55ms +step:289/1660 train_time:26458ms step_avg:91.55ms +step:290/1660 train_time:26549ms step_avg:91.55ms +step:291/1660 train_time:26639ms step_avg:91.54ms +step:292/1660 train_time:26731ms step_avg:91.54ms +step:293/1660 train_time:26821ms step_avg:91.54ms +step:294/1660 train_time:26913ms step_avg:91.54ms +step:295/1660 train_time:27004ms step_avg:91.54ms +step:296/1660 train_time:27097ms step_avg:91.54ms +step:297/1660 train_time:27188ms step_avg:91.54ms +step:298/1660 train_time:27280ms step_avg:91.54ms +step:299/1660 train_time:27372ms step_avg:91.55ms +step:300/1660 train_time:27463ms step_avg:91.54ms +step:301/1660 train_time:27554ms step_avg:91.54ms +step:302/1660 train_time:27645ms step_avg:91.54ms +step:303/1660 train_time:27736ms step_avg:91.54ms +step:304/1660 train_time:27827ms step_avg:91.53ms +step:305/1660 train_time:27918ms step_avg:91.53ms +step:306/1660 train_time:28009ms step_avg:91.53ms +step:307/1660 train_time:28101ms step_avg:91.54ms +step:308/1660 train_time:28193ms step_avg:91.54ms +step:309/1660 train_time:28285ms step_avg:91.54ms +step:310/1660 train_time:28376ms step_avg:91.53ms +step:311/1660 train_time:28467ms step_avg:91.54ms +step:312/1660 train_time:28558ms step_avg:91.53ms +step:313/1660 train_time:28650ms step_avg:91.53ms +step:314/1660 train_time:28742ms step_avg:91.53ms +step:315/1660 train_time:28833ms step_avg:91.53ms +step:316/1660 train_time:28923ms step_avg:91.53ms +step:317/1660 train_time:29015ms step_avg:91.53ms +step:318/1660 train_time:29106ms step_avg:91.53ms +step:319/1660 train_time:29197ms step_avg:91.53ms +step:320/1660 train_time:29289ms step_avg:91.53ms +step:321/1660 train_time:29381ms step_avg:91.53ms +step:322/1660 train_time:29473ms step_avg:91.53ms +step:323/1660 train_time:29565ms step_avg:91.53ms +step:324/1660 train_time:29655ms step_avg:91.53ms +step:325/1660 train_time:29746ms step_avg:91.53ms +step:326/1660 train_time:29837ms step_avg:91.53ms +step:327/1660 train_time:29928ms step_avg:91.52ms +step:328/1660 train_time:30021ms step_avg:91.53ms 
+step:329/1660 train_time:30112ms step_avg:91.53ms +step:330/1660 train_time:30204ms step_avg:91.53ms +step:331/1660 train_time:30295ms step_avg:91.53ms +step:332/1660 train_time:30387ms step_avg:91.53ms +step:333/1660 train_time:30479ms step_avg:91.53ms +step:334/1660 train_time:30570ms step_avg:91.53ms +step:335/1660 train_time:30661ms step_avg:91.53ms +step:336/1660 train_time:30752ms step_avg:91.52ms +step:337/1660 train_time:30843ms step_avg:91.52ms +step:338/1660 train_time:30935ms step_avg:91.52ms +step:339/1660 train_time:31026ms step_avg:91.52ms +step:340/1660 train_time:31117ms step_avg:91.52ms +step:341/1660 train_time:31209ms step_avg:91.52ms +step:342/1660 train_time:31300ms step_avg:91.52ms +step:343/1660 train_time:31393ms step_avg:91.52ms +step:344/1660 train_time:31484ms step_avg:91.52ms +step:345/1660 train_time:31575ms step_avg:91.52ms +step:346/1660 train_time:31667ms step_avg:91.52ms +step:347/1660 train_time:31758ms step_avg:91.52ms +step:348/1660 train_time:31849ms step_avg:91.52ms +step:349/1660 train_time:31940ms step_avg:91.52ms +step:350/1660 train_time:32031ms step_avg:91.52ms +step:351/1660 train_time:32123ms step_avg:91.52ms +step:352/1660 train_time:32215ms step_avg:91.52ms +step:353/1660 train_time:32306ms step_avg:91.52ms +step:354/1660 train_time:32398ms step_avg:91.52ms +step:355/1660 train_time:32489ms step_avg:91.52ms +step:356/1660 train_time:32581ms step_avg:91.52ms +step:357/1660 train_time:32673ms step_avg:91.52ms +step:358/1660 train_time:32763ms step_avg:91.52ms +step:359/1660 train_time:32854ms step_avg:91.52ms +step:360/1660 train_time:32945ms step_avg:91.51ms +step:361/1660 train_time:33037ms step_avg:91.51ms +step:362/1660 train_time:33128ms step_avg:91.51ms +step:363/1660 train_time:33220ms step_avg:91.51ms +step:364/1660 train_time:33312ms step_avg:91.52ms +step:365/1660 train_time:33403ms step_avg:91.52ms +step:366/1660 train_time:33495ms step_avg:91.52ms +step:367/1660 train_time:33586ms step_avg:91.52ms +step:368/1660 train_time:33677ms step_avg:91.51ms +step:369/1660 train_time:33769ms step_avg:91.51ms +step:370/1660 train_time:33859ms step_avg:91.51ms +step:371/1660 train_time:33950ms step_avg:91.51ms +step:372/1660 train_time:34042ms step_avg:91.51ms +step:373/1660 train_time:34133ms step_avg:91.51ms +step:374/1660 train_time:34224ms step_avg:91.51ms +step:375/1660 train_time:34315ms step_avg:91.51ms +step:375/1660 val_loss:3.8141 train_time:34408ms step_avg:91.75ms +step:376/1660 train_time:34428ms step_avg:91.56ms +step:377/1660 train_time:34504ms step_avg:91.52ms +step:378/1660 train_time:34599ms step_avg:91.53ms +step:379/1660 train_time:34692ms step_avg:91.54ms +step:380/1660 train_time:34783ms step_avg:91.53ms +step:381/1660 train_time:34873ms step_avg:91.53ms +step:382/1660 train_time:34964ms step_avg:91.53ms +step:383/1660 train_time:35054ms step_avg:91.53ms +step:384/1660 train_time:35145ms step_avg:91.52ms +step:385/1660 train_time:35234ms step_avg:91.52ms +step:386/1660 train_time:35325ms step_avg:91.52ms +step:387/1660 train_time:35417ms step_avg:91.52ms +step:388/1660 train_time:35510ms step_avg:91.52ms +step:389/1660 train_time:35603ms step_avg:91.52ms +step:390/1660 train_time:35696ms step_avg:91.53ms +step:391/1660 train_time:35788ms step_avg:91.53ms +step:392/1660 train_time:35879ms step_avg:91.53ms +step:393/1660 train_time:35970ms step_avg:91.53ms +step:394/1660 train_time:36060ms step_avg:91.52ms +step:395/1660 train_time:36150ms step_avg:91.52ms +step:396/1660 train_time:36240ms step_avg:91.52ms +step:397/1660 
train_time:36331ms step_avg:91.51ms +step:398/1660 train_time:36423ms step_avg:91.51ms +step:399/1660 train_time:36515ms step_avg:91.52ms +step:400/1660 train_time:36608ms step_avg:91.52ms +step:401/1660 train_time:36700ms step_avg:91.52ms +step:402/1660 train_time:36792ms step_avg:91.52ms +step:403/1660 train_time:36883ms step_avg:91.52ms +step:404/1660 train_time:36974ms step_avg:91.52ms +step:405/1660 train_time:37065ms step_avg:91.52ms +step:406/1660 train_time:37156ms step_avg:91.52ms +step:407/1660 train_time:37246ms step_avg:91.51ms +step:408/1660 train_time:37337ms step_avg:91.51ms +step:409/1660 train_time:37428ms step_avg:91.51ms +step:410/1660 train_time:37520ms step_avg:91.51ms +step:411/1660 train_time:37612ms step_avg:91.51ms +step:412/1660 train_time:37703ms step_avg:91.51ms +step:413/1660 train_time:37795ms step_avg:91.51ms +step:414/1660 train_time:37886ms step_avg:91.51ms +step:415/1660 train_time:37977ms step_avg:91.51ms +step:416/1660 train_time:38068ms step_avg:91.51ms +step:417/1660 train_time:38159ms step_avg:91.51ms +step:418/1660 train_time:38250ms step_avg:91.51ms +step:419/1660 train_time:38341ms step_avg:91.51ms +step:420/1660 train_time:38432ms step_avg:91.51ms +step:421/1660 train_time:38524ms step_avg:91.51ms +step:422/1660 train_time:38616ms step_avg:91.51ms +step:423/1660 train_time:38707ms step_avg:91.51ms +step:424/1660 train_time:38799ms step_avg:91.51ms +step:425/1660 train_time:38890ms step_avg:91.51ms +step:426/1660 train_time:38981ms step_avg:91.50ms +step:427/1660 train_time:39072ms step_avg:91.50ms +step:428/1660 train_time:39163ms step_avg:91.50ms +step:429/1660 train_time:39253ms step_avg:91.50ms +step:430/1660 train_time:39345ms step_avg:91.50ms +step:431/1660 train_time:39436ms step_avg:91.50ms +step:432/1660 train_time:39528ms step_avg:91.50ms +step:433/1660 train_time:39619ms step_avg:91.50ms +step:434/1660 train_time:39711ms step_avg:91.50ms +step:435/1660 train_time:39803ms step_avg:91.50ms +step:436/1660 train_time:39894ms step_avg:91.50ms +step:437/1660 train_time:39986ms step_avg:91.50ms +step:438/1660 train_time:40077ms step_avg:91.50ms +step:439/1660 train_time:40169ms step_avg:91.50ms +step:440/1660 train_time:40260ms step_avg:91.50ms +step:441/1660 train_time:40351ms step_avg:91.50ms +step:442/1660 train_time:40442ms step_avg:91.50ms +step:443/1660 train_time:40534ms step_avg:91.50ms +step:444/1660 train_time:40626ms step_avg:91.50ms +step:445/1660 train_time:40717ms step_avg:91.50ms +step:446/1660 train_time:40808ms step_avg:91.50ms +step:447/1660 train_time:40899ms step_avg:91.50ms +step:448/1660 train_time:40991ms step_avg:91.50ms +step:449/1660 train_time:41082ms step_avg:91.50ms +step:450/1660 train_time:41174ms step_avg:91.50ms +step:451/1660 train_time:41265ms step_avg:91.50ms +step:452/1660 train_time:41356ms step_avg:91.50ms +step:453/1660 train_time:41447ms step_avg:91.49ms +step:454/1660 train_time:41538ms step_avg:91.49ms +step:455/1660 train_time:41630ms step_avg:91.49ms +step:456/1660 train_time:41722ms step_avg:91.50ms +step:457/1660 train_time:41813ms step_avg:91.49ms +step:458/1660 train_time:41904ms step_avg:91.49ms +step:459/1660 train_time:41996ms step_avg:91.49ms +step:460/1660 train_time:42088ms step_avg:91.50ms +step:461/1660 train_time:42179ms step_avg:91.50ms +step:462/1660 train_time:42272ms step_avg:91.50ms +step:463/1660 train_time:42363ms step_avg:91.50ms +step:464/1660 train_time:42454ms step_avg:91.50ms +step:465/1660 train_time:42545ms step_avg:91.49ms +step:466/1660 train_time:42636ms step_avg:91.49ms 
+step:467/1660 train_time:42727ms step_avg:91.49ms +step:468/1660 train_time:42819ms step_avg:91.49ms +step:469/1660 train_time:42909ms step_avg:91.49ms +step:470/1660 train_time:43002ms step_avg:91.49ms +step:471/1660 train_time:43093ms step_avg:91.49ms +step:472/1660 train_time:43185ms step_avg:91.49ms +step:473/1660 train_time:43276ms step_avg:91.49ms +step:474/1660 train_time:43367ms step_avg:91.49ms +step:475/1660 train_time:43459ms step_avg:91.49ms +step:476/1660 train_time:43550ms step_avg:91.49ms +step:477/1660 train_time:43642ms step_avg:91.49ms +step:478/1660 train_time:43734ms step_avg:91.49ms +step:479/1660 train_time:43825ms step_avg:91.49ms +step:480/1660 train_time:43917ms step_avg:91.49ms +step:481/1660 train_time:44008ms step_avg:91.49ms +step:482/1660 train_time:44099ms step_avg:91.49ms +step:483/1660 train_time:44191ms step_avg:91.49ms +step:484/1660 train_time:44283ms step_avg:91.49ms +step:485/1660 train_time:44375ms step_avg:91.49ms +step:486/1660 train_time:44467ms step_avg:91.49ms +step:487/1660 train_time:44557ms step_avg:91.49ms +step:488/1660 train_time:44649ms step_avg:91.49ms +step:489/1660 train_time:44740ms step_avg:91.49ms +step:490/1660 train_time:44831ms step_avg:91.49ms +step:491/1660 train_time:44922ms step_avg:91.49ms +step:492/1660 train_time:45014ms step_avg:91.49ms +step:493/1660 train_time:45105ms step_avg:91.49ms +step:494/1660 train_time:45196ms step_avg:91.49ms +step:495/1660 train_time:45288ms step_avg:91.49ms +step:496/1660 train_time:45380ms step_avg:91.49ms +step:497/1660 train_time:45471ms step_avg:91.49ms +step:498/1660 train_time:45562ms step_avg:91.49ms +step:499/1660 train_time:45653ms step_avg:91.49ms +step:500/1660 train_time:45745ms step_avg:91.49ms +step:500/1660 val_loss:3.7144 train_time:45838ms step_avg:91.68ms +step:501/1660 train_time:45858ms step_avg:91.53ms +step:502/1660 train_time:45932ms step_avg:91.50ms +step:503/1660 train_time:46029ms step_avg:91.51ms +step:504/1660 train_time:46123ms step_avg:91.51ms +step:505/1660 train_time:46214ms step_avg:91.51ms +step:506/1660 train_time:46304ms step_avg:91.51ms +step:507/1660 train_time:46394ms step_avg:91.51ms +step:508/1660 train_time:46485ms step_avg:91.51ms +step:509/1660 train_time:46576ms step_avg:91.50ms +step:510/1660 train_time:46667ms step_avg:91.50ms +step:511/1660 train_time:46758ms step_avg:91.50ms +step:512/1660 train_time:46849ms step_avg:91.50ms +step:513/1660 train_time:46943ms step_avg:91.51ms +step:514/1660 train_time:47036ms step_avg:91.51ms +step:515/1660 train_time:47128ms step_avg:91.51ms +step:516/1660 train_time:47220ms step_avg:91.51ms +step:517/1660 train_time:47311ms step_avg:91.51ms +step:518/1660 train_time:47402ms step_avg:91.51ms +step:519/1660 train_time:47492ms step_avg:91.51ms +step:520/1660 train_time:47583ms step_avg:91.51ms +step:521/1660 train_time:47674ms step_avg:91.50ms +step:522/1660 train_time:47765ms step_avg:91.50ms +step:523/1660 train_time:47857ms step_avg:91.50ms +step:524/1660 train_time:47949ms step_avg:91.51ms +step:525/1660 train_time:48043ms step_avg:91.51ms +step:526/1660 train_time:48135ms step_avg:91.51ms +step:527/1660 train_time:48227ms step_avg:91.51ms +step:528/1660 train_time:48319ms step_avg:91.51ms +step:529/1660 train_time:48410ms step_avg:91.51ms +step:530/1660 train_time:48500ms step_avg:91.51ms +step:531/1660 train_time:48591ms step_avg:91.51ms +step:532/1660 train_time:48683ms step_avg:91.51ms +step:533/1660 train_time:48774ms step_avg:91.51ms +step:534/1660 train_time:48866ms step_avg:91.51ms +step:535/1660 
train_time:48958ms step_avg:91.51ms +step:536/1660 train_time:49050ms step_avg:91.51ms +step:537/1660 train_time:49142ms step_avg:91.51ms +step:538/1660 train_time:49234ms step_avg:91.51ms +step:539/1660 train_time:49327ms step_avg:91.52ms +step:540/1660 train_time:49419ms step_avg:91.52ms +step:541/1660 train_time:49509ms step_avg:91.51ms +step:542/1660 train_time:49600ms step_avg:91.51ms +step:543/1660 train_time:49690ms step_avg:91.51ms +step:544/1660 train_time:49781ms step_avg:91.51ms +step:545/1660 train_time:49872ms step_avg:91.51ms +step:546/1660 train_time:49963ms step_avg:91.51ms +step:547/1660 train_time:50055ms step_avg:91.51ms +step:548/1660 train_time:50148ms step_avg:91.51ms +step:549/1660 train_time:50240ms step_avg:91.51ms +step:550/1660 train_time:50331ms step_avg:91.51ms +step:551/1660 train_time:50422ms step_avg:91.51ms +step:552/1660 train_time:50513ms step_avg:91.51ms +step:553/1660 train_time:50604ms step_avg:91.51ms +step:554/1660 train_time:50695ms step_avg:91.51ms +step:555/1660 train_time:50786ms step_avg:91.51ms +step:556/1660 train_time:50879ms step_avg:91.51ms +step:557/1660 train_time:50973ms step_avg:91.51ms +step:558/1660 train_time:51067ms step_avg:91.52ms +step:559/1660 train_time:51161ms step_avg:91.52ms +step:560/1660 train_time:51253ms step_avg:91.52ms +step:561/1660 train_time:51347ms step_avg:91.53ms +step:562/1660 train_time:51440ms step_avg:91.53ms +step:563/1660 train_time:51532ms step_avg:91.53ms +step:564/1660 train_time:51625ms step_avg:91.53ms +step:565/1660 train_time:51718ms step_avg:91.54ms +step:566/1660 train_time:51810ms step_avg:91.54ms +step:567/1660 train_time:51903ms step_avg:91.54ms +step:568/1660 train_time:51996ms step_avg:91.54ms +step:569/1660 train_time:52088ms step_avg:91.54ms +step:570/1660 train_time:52181ms step_avg:91.55ms +step:571/1660 train_time:52275ms step_avg:91.55ms +step:572/1660 train_time:52368ms step_avg:91.55ms +step:573/1660 train_time:52461ms step_avg:91.55ms +step:574/1660 train_time:52553ms step_avg:91.56ms +step:575/1660 train_time:52647ms step_avg:91.56ms +step:576/1660 train_time:52740ms step_avg:91.56ms +step:577/1660 train_time:52833ms step_avg:91.56ms +step:578/1660 train_time:52925ms step_avg:91.57ms +step:579/1660 train_time:53017ms step_avg:91.57ms +step:580/1660 train_time:53109ms step_avg:91.57ms +step:581/1660 train_time:53203ms step_avg:91.57ms +step:582/1660 train_time:53296ms step_avg:91.57ms +step:583/1660 train_time:53389ms step_avg:91.58ms +step:584/1660 train_time:53481ms step_avg:91.58ms +step:585/1660 train_time:53574ms step_avg:91.58ms +step:586/1660 train_time:53667ms step_avg:91.58ms +step:587/1660 train_time:53760ms step_avg:91.58ms +step:588/1660 train_time:53852ms step_avg:91.58ms +step:589/1660 train_time:53946ms step_avg:91.59ms +step:590/1660 train_time:54039ms step_avg:91.59ms +step:591/1660 train_time:54131ms step_avg:91.59ms +step:592/1660 train_time:54224ms step_avg:91.59ms +step:593/1660 train_time:54317ms step_avg:91.60ms +step:594/1660 train_time:54410ms step_avg:91.60ms +step:595/1660 train_time:54503ms step_avg:91.60ms +step:596/1660 train_time:54596ms step_avg:91.60ms +step:597/1660 train_time:54689ms step_avg:91.61ms +step:598/1660 train_time:54781ms step_avg:91.61ms +step:599/1660 train_time:54873ms step_avg:91.61ms +step:600/1660 train_time:54967ms step_avg:91.61ms +step:601/1660 train_time:55060ms step_avg:91.61ms +step:602/1660 train_time:55152ms step_avg:91.61ms +step:603/1660 train_time:55245ms step_avg:91.62ms +step:604/1660 train_time:55339ms step_avg:91.62ms 
+step:605/1660 train_time:55431ms step_avg:91.62ms +step:606/1660 train_time:55524ms step_avg:91.62ms +step:607/1660 train_time:55617ms step_avg:91.63ms +step:608/1660 train_time:55710ms step_avg:91.63ms +step:609/1660 train_time:55803ms step_avg:91.63ms +step:610/1660 train_time:55895ms step_avg:91.63ms +step:611/1660 train_time:55988ms step_avg:91.63ms +step:612/1660 train_time:56080ms step_avg:91.63ms +step:613/1660 train_time:56172ms step_avg:91.63ms +step:614/1660 train_time:56265ms step_avg:91.64ms +step:615/1660 train_time:56358ms step_avg:91.64ms +step:616/1660 train_time:56450ms step_avg:91.64ms +step:617/1660 train_time:56544ms step_avg:91.64ms +step:618/1660 train_time:56637ms step_avg:91.65ms +step:619/1660 train_time:56729ms step_avg:91.65ms +step:620/1660 train_time:56822ms step_avg:91.65ms +step:621/1660 train_time:56914ms step_avg:91.65ms +step:622/1660 train_time:57008ms step_avg:91.65ms +step:623/1660 train_time:57100ms step_avg:91.65ms +step:624/1660 train_time:57193ms step_avg:91.66ms +step:625/1660 train_time:57286ms step_avg:91.66ms +step:625/1660 val_loss:3.6126 train_time:57381ms step_avg:91.81ms +step:626/1660 train_time:57401ms step_avg:91.70ms +step:627/1660 train_time:57478ms step_avg:91.67ms +step:628/1660 train_time:57576ms step_avg:91.68ms +step:629/1660 train_time:57669ms step_avg:91.68ms +step:630/1660 train_time:57761ms step_avg:91.68ms +step:631/1660 train_time:57852ms step_avg:91.68ms +step:632/1660 train_time:57943ms step_avg:91.68ms +step:633/1660 train_time:58034ms step_avg:91.68ms +step:634/1660 train_time:58125ms step_avg:91.68ms +step:635/1660 train_time:58217ms step_avg:91.68ms +step:636/1660 train_time:58312ms step_avg:91.69ms +step:637/1660 train_time:58408ms step_avg:91.69ms +step:638/1660 train_time:58503ms step_avg:91.70ms +step:639/1660 train_time:58597ms step_avg:91.70ms +step:640/1660 train_time:58689ms step_avg:91.70ms +step:641/1660 train_time:58782ms step_avg:91.70ms +step:642/1660 train_time:58874ms step_avg:91.70ms +step:643/1660 train_time:58966ms step_avg:91.70ms +step:644/1660 train_time:59057ms step_avg:91.70ms +step:645/1660 train_time:59149ms step_avg:91.70ms +step:646/1660 train_time:59241ms step_avg:91.70ms +step:647/1660 train_time:59334ms step_avg:91.71ms +step:648/1660 train_time:59432ms step_avg:91.72ms +step:649/1660 train_time:59527ms step_avg:91.72ms +step:650/1660 train_time:59620ms step_avg:91.72ms +step:651/1660 train_time:59712ms step_avg:91.72ms +step:652/1660 train_time:59805ms step_avg:91.73ms +step:653/1660 train_time:59897ms step_avg:91.73ms +step:654/1660 train_time:59990ms step_avg:91.73ms +step:655/1660 train_time:60082ms step_avg:91.73ms +step:656/1660 train_time:60175ms step_avg:91.73ms +step:657/1660 train_time:60268ms step_avg:91.73ms +step:658/1660 train_time:60362ms step_avg:91.74ms +step:659/1660 train_time:60456ms step_avg:91.74ms +step:660/1660 train_time:60552ms step_avg:91.74ms +step:661/1660 train_time:60645ms step_avg:91.75ms +step:662/1660 train_time:60737ms step_avg:91.75ms +step:663/1660 train_time:60830ms step_avg:91.75ms +step:664/1660 train_time:60922ms step_avg:91.75ms +step:665/1660 train_time:61014ms step_avg:91.75ms +step:666/1660 train_time:61106ms step_avg:91.75ms +step:667/1660 train_time:61198ms step_avg:91.75ms +step:668/1660 train_time:61291ms step_avg:91.75ms +step:669/1660 train_time:61385ms step_avg:91.76ms +step:670/1660 train_time:61478ms step_avg:91.76ms +step:671/1660 train_time:61572ms step_avg:91.76ms +step:672/1660 train_time:61665ms step_avg:91.76ms +step:673/1660 
train_time:61757ms step_avg:91.76ms +step:674/1660 train_time:61850ms step_avg:91.77ms +step:675/1660 train_time:61942ms step_avg:91.77ms +step:676/1660 train_time:62033ms step_avg:91.77ms +step:677/1660 train_time:62126ms step_avg:91.77ms +step:678/1660 train_time:62218ms step_avg:91.77ms +step:679/1660 train_time:62311ms step_avg:91.77ms +step:680/1660 train_time:62404ms step_avg:91.77ms +step:681/1660 train_time:62497ms step_avg:91.77ms +step:682/1660 train_time:62590ms step_avg:91.77ms +step:683/1660 train_time:62683ms step_avg:91.78ms +step:684/1660 train_time:62776ms step_avg:91.78ms +step:685/1660 train_time:62869ms step_avg:91.78ms +step:686/1660 train_time:62962ms step_avg:91.78ms +step:687/1660 train_time:63054ms step_avg:91.78ms +step:688/1660 train_time:63147ms step_avg:91.78ms +step:689/1660 train_time:63239ms step_avg:91.78ms +step:690/1660 train_time:63331ms step_avg:91.78ms +step:691/1660 train_time:63425ms step_avg:91.79ms +step:692/1660 train_time:63517ms step_avg:91.79ms +step:693/1660 train_time:63610ms step_avg:91.79ms +step:694/1660 train_time:63702ms step_avg:91.79ms +step:695/1660 train_time:63795ms step_avg:91.79ms +step:696/1660 train_time:63888ms step_avg:91.79ms +step:697/1660 train_time:63981ms step_avg:91.80ms +step:698/1660 train_time:64074ms step_avg:91.80ms +step:699/1660 train_time:64166ms step_avg:91.80ms +step:700/1660 train_time:64259ms step_avg:91.80ms +step:701/1660 train_time:64351ms step_avg:91.80ms +step:702/1660 train_time:64445ms step_avg:91.80ms +step:703/1660 train_time:64537ms step_avg:91.80ms +step:704/1660 train_time:64631ms step_avg:91.81ms +step:705/1660 train_time:64725ms step_avg:91.81ms +step:706/1660 train_time:64817ms step_avg:91.81ms +step:707/1660 train_time:64911ms step_avg:91.81ms +step:708/1660 train_time:65004ms step_avg:91.81ms +step:709/1660 train_time:65096ms step_avg:91.81ms +step:710/1660 train_time:65189ms step_avg:91.82ms +step:711/1660 train_time:65282ms step_avg:91.82ms +step:712/1660 train_time:65374ms step_avg:91.82ms +step:713/1660 train_time:65466ms step_avg:91.82ms +step:714/1660 train_time:65558ms step_avg:91.82ms +step:715/1660 train_time:65651ms step_avg:91.82ms +step:716/1660 train_time:65744ms step_avg:91.82ms +step:717/1660 train_time:65837ms step_avg:91.82ms +step:718/1660 train_time:65930ms step_avg:91.82ms +step:719/1660 train_time:66023ms step_avg:91.83ms +step:720/1660 train_time:66116ms step_avg:91.83ms +step:721/1660 train_time:66209ms step_avg:91.83ms +step:722/1660 train_time:66302ms step_avg:91.83ms +step:723/1660 train_time:66394ms step_avg:91.83ms +step:724/1660 train_time:66486ms step_avg:91.83ms +step:725/1660 train_time:66579ms step_avg:91.83ms +step:726/1660 train_time:66672ms step_avg:91.83ms +step:727/1660 train_time:66764ms step_avg:91.84ms +step:728/1660 train_time:66857ms step_avg:91.84ms +step:729/1660 train_time:66951ms step_avg:91.84ms +step:730/1660 train_time:67044ms step_avg:91.84ms +step:731/1660 train_time:67137ms step_avg:91.84ms +step:732/1660 train_time:67230ms step_avg:91.84ms +step:733/1660 train_time:67323ms step_avg:91.85ms +step:734/1660 train_time:67415ms step_avg:91.85ms +step:735/1660 train_time:67508ms step_avg:91.85ms +step:736/1660 train_time:67601ms step_avg:91.85ms +step:737/1660 train_time:67694ms step_avg:91.85ms +step:738/1660 train_time:67786ms step_avg:91.85ms +step:739/1660 train_time:67880ms step_avg:91.85ms +step:740/1660 train_time:67972ms step_avg:91.85ms +step:741/1660 train_time:68065ms step_avg:91.86ms +step:742/1660 train_time:68158ms step_avg:91.86ms 
+step:743/1660 train_time:68250ms step_avg:91.86ms +step:744/1660 train_time:68343ms step_avg:91.86ms +step:745/1660 train_time:68435ms step_avg:91.86ms +step:746/1660 train_time:68527ms step_avg:91.86ms +step:747/1660 train_time:68620ms step_avg:91.86ms +step:748/1660 train_time:68712ms step_avg:91.86ms +step:749/1660 train_time:68804ms step_avg:91.86ms +step:750/1660 train_time:68897ms step_avg:91.86ms +step:750/1660 val_loss:3.5588 train_time:68991ms step_avg:91.99ms +step:751/1660 train_time:69011ms step_avg:91.89ms +step:752/1660 train_time:69087ms step_avg:91.87ms +step:753/1660 train_time:69183ms step_avg:91.88ms +step:754/1660 train_time:69277ms step_avg:91.88ms +step:755/1660 train_time:69368ms step_avg:91.88ms +step:756/1660 train_time:69460ms step_avg:91.88ms +step:757/1660 train_time:69552ms step_avg:91.88ms +step:758/1660 train_time:69644ms step_avg:91.88ms +step:759/1660 train_time:69735ms step_avg:91.88ms +step:760/1660 train_time:69827ms step_avg:91.88ms +step:761/1660 train_time:69920ms step_avg:91.88ms +step:762/1660 train_time:70014ms step_avg:91.88ms +step:763/1660 train_time:70109ms step_avg:91.89ms +step:764/1660 train_time:70204ms step_avg:91.89ms +step:765/1660 train_time:70297ms step_avg:91.89ms +step:766/1660 train_time:70390ms step_avg:91.89ms +step:767/1660 train_time:70483ms step_avg:91.89ms +step:768/1660 train_time:70575ms step_avg:91.89ms +step:769/1660 train_time:70667ms step_avg:91.89ms +step:770/1660 train_time:70758ms step_avg:91.89ms +step:771/1660 train_time:70851ms step_avg:91.90ms +step:772/1660 train_time:70944ms step_avg:91.90ms +step:773/1660 train_time:71039ms step_avg:91.90ms +step:774/1660 train_time:71134ms step_avg:91.90ms +step:775/1660 train_time:71227ms step_avg:91.91ms +step:776/1660 train_time:71320ms step_avg:91.91ms +step:777/1660 train_time:71413ms step_avg:91.91ms +step:778/1660 train_time:71505ms step_avg:91.91ms +step:779/1660 train_time:71598ms step_avg:91.91ms +step:780/1660 train_time:71690ms step_avg:91.91ms +step:781/1660 train_time:71783ms step_avg:91.91ms +step:782/1660 train_time:71875ms step_avg:91.91ms +step:783/1660 train_time:71967ms step_avg:91.91ms +step:784/1660 train_time:72061ms step_avg:91.91ms +step:785/1660 train_time:72155ms step_avg:91.92ms +step:786/1660 train_time:72248ms step_avg:91.92ms +step:787/1660 train_time:72340ms step_avg:91.92ms +step:788/1660 train_time:72433ms step_avg:91.92ms +step:789/1660 train_time:72526ms step_avg:91.92ms +step:790/1660 train_time:72620ms step_avg:91.92ms +step:791/1660 train_time:72712ms step_avg:91.92ms +step:792/1660 train_time:72804ms step_avg:91.92ms +step:793/1660 train_time:72897ms step_avg:91.93ms +step:794/1660 train_time:72989ms step_avg:91.93ms +step:795/1660 train_time:73083ms step_avg:91.93ms +step:796/1660 train_time:73175ms step_avg:91.93ms +step:797/1660 train_time:73268ms step_avg:91.93ms +step:798/1660 train_time:73361ms step_avg:91.93ms +step:799/1660 train_time:73454ms step_avg:91.93ms +step:800/1660 train_time:73546ms step_avg:91.93ms +step:801/1660 train_time:73639ms step_avg:91.93ms +step:802/1660 train_time:73731ms step_avg:91.93ms +step:803/1660 train_time:73825ms step_avg:91.94ms +step:804/1660 train_time:73918ms step_avg:91.94ms +step:805/1660 train_time:74010ms step_avg:91.94ms +step:806/1660 train_time:74104ms step_avg:91.94ms +step:807/1660 train_time:74197ms step_avg:91.94ms +step:808/1660 train_time:74290ms step_avg:91.94ms +step:809/1660 train_time:74383ms step_avg:91.94ms +step:810/1660 train_time:74475ms step_avg:91.94ms +step:811/1660 
train_time:74567ms step_avg:91.94ms +step:812/1660 train_time:74660ms step_avg:91.95ms +step:813/1660 train_time:74752ms step_avg:91.95ms +step:814/1660 train_time:74845ms step_avg:91.95ms +step:815/1660 train_time:74937ms step_avg:91.95ms +step:816/1660 train_time:75030ms step_avg:91.95ms +step:817/1660 train_time:75123ms step_avg:91.95ms +step:818/1660 train_time:75216ms step_avg:91.95ms +step:819/1660 train_time:75309ms step_avg:91.95ms +step:820/1660 train_time:75402ms step_avg:91.95ms +step:821/1660 train_time:75496ms step_avg:91.96ms +step:822/1660 train_time:75589ms step_avg:91.96ms +step:823/1660 train_time:75682ms step_avg:91.96ms +step:824/1660 train_time:75775ms step_avg:91.96ms +step:825/1660 train_time:75867ms step_avg:91.96ms +step:826/1660 train_time:75959ms step_avg:91.96ms +step:827/1660 train_time:76052ms step_avg:91.96ms +step:828/1660 train_time:76144ms step_avg:91.96ms +step:829/1660 train_time:76237ms step_avg:91.96ms +step:830/1660 train_time:76330ms step_avg:91.96ms +step:831/1660 train_time:76423ms step_avg:91.96ms +step:832/1660 train_time:76515ms step_avg:91.97ms +step:833/1660 train_time:76608ms step_avg:91.97ms +step:834/1660 train_time:76702ms step_avg:91.97ms +step:835/1660 train_time:76795ms step_avg:91.97ms +step:836/1660 train_time:76887ms step_avg:91.97ms +step:837/1660 train_time:76979ms step_avg:91.97ms +step:838/1660 train_time:77072ms step_avg:91.97ms +step:839/1660 train_time:77165ms step_avg:91.97ms +step:840/1660 train_time:77259ms step_avg:91.98ms +step:841/1660 train_time:77351ms step_avg:91.98ms +step:842/1660 train_time:77445ms step_avg:91.98ms +step:843/1660 train_time:77538ms step_avg:91.98ms +step:844/1660 train_time:77630ms step_avg:91.98ms +step:845/1660 train_time:77723ms step_avg:91.98ms +step:846/1660 train_time:77816ms step_avg:91.98ms +step:847/1660 train_time:77908ms step_avg:91.98ms +step:848/1660 train_time:78001ms step_avg:91.98ms +step:849/1660 train_time:78093ms step_avg:91.98ms +step:850/1660 train_time:78186ms step_avg:91.98ms +step:851/1660 train_time:78279ms step_avg:91.98ms +step:852/1660 train_time:78372ms step_avg:91.99ms +step:853/1660 train_time:78465ms step_avg:91.99ms +step:854/1660 train_time:78558ms step_avg:91.99ms +step:855/1660 train_time:78651ms step_avg:91.99ms +step:856/1660 train_time:78745ms step_avg:91.99ms +step:857/1660 train_time:78837ms step_avg:91.99ms +step:858/1660 train_time:78930ms step_avg:91.99ms +step:859/1660 train_time:79022ms step_avg:91.99ms +step:860/1660 train_time:79115ms step_avg:91.99ms +step:861/1660 train_time:79207ms step_avg:91.99ms +step:862/1660 train_time:79300ms step_avg:92.00ms +step:863/1660 train_time:79393ms step_avg:92.00ms +step:864/1660 train_time:79485ms step_avg:92.00ms +step:865/1660 train_time:79579ms step_avg:92.00ms +step:866/1660 train_time:79671ms step_avg:92.00ms +step:867/1660 train_time:79765ms step_avg:92.00ms +step:868/1660 train_time:79858ms step_avg:92.00ms +step:869/1660 train_time:79950ms step_avg:92.00ms +step:870/1660 train_time:80043ms step_avg:92.00ms +step:871/1660 train_time:80136ms step_avg:92.00ms +step:872/1660 train_time:80229ms step_avg:92.01ms +step:873/1660 train_time:80322ms step_avg:92.01ms +step:874/1660 train_time:80415ms step_avg:92.01ms +step:875/1660 train_time:80507ms step_avg:92.01ms +step:875/1660 val_loss:3.5150 train_time:80601ms step_avg:92.12ms +step:876/1660 train_time:80621ms step_avg:92.03ms +step:877/1660 train_time:80697ms step_avg:92.02ms +step:878/1660 train_time:80796ms step_avg:92.02ms +step:879/1660 train_time:80889ms 
step_avg:92.02ms +step:880/1660 train_time:80982ms step_avg:92.02ms +step:881/1660 train_time:81073ms step_avg:92.02ms +step:882/1660 train_time:81165ms step_avg:92.02ms +step:883/1660 train_time:81256ms step_avg:92.02ms +step:884/1660 train_time:81347ms step_avg:92.02ms +step:885/1660 train_time:81439ms step_avg:92.02ms +step:886/1660 train_time:81531ms step_avg:92.02ms +step:887/1660 train_time:81626ms step_avg:92.02ms +step:888/1660 train_time:81721ms step_avg:92.03ms +step:889/1660 train_time:81816ms step_avg:92.03ms +step:890/1660 train_time:81909ms step_avg:92.03ms +step:891/1660 train_time:82002ms step_avg:92.03ms +step:892/1660 train_time:82094ms step_avg:92.03ms +step:893/1660 train_time:82186ms step_avg:92.03ms +step:894/1660 train_time:82278ms step_avg:92.03ms +step:895/1660 train_time:82369ms step_avg:92.03ms +step:896/1660 train_time:82462ms step_avg:92.03ms +step:897/1660 train_time:82554ms step_avg:92.03ms +step:898/1660 train_time:82648ms step_avg:92.04ms +step:899/1660 train_time:82743ms step_avg:92.04ms +step:900/1660 train_time:82837ms step_avg:92.04ms +step:901/1660 train_time:82930ms step_avg:92.04ms +step:902/1660 train_time:83022ms step_avg:92.04ms +step:903/1660 train_time:83114ms step_avg:92.04ms +step:904/1660 train_time:83207ms step_avg:92.04ms +step:905/1660 train_time:83300ms step_avg:92.04ms +step:906/1660 train_time:83391ms step_avg:92.04ms +step:907/1660 train_time:83483ms step_avg:92.04ms +step:908/1660 train_time:83577ms step_avg:92.04ms +step:909/1660 train_time:83670ms step_avg:92.05ms +step:910/1660 train_time:83765ms step_avg:92.05ms +step:911/1660 train_time:83858ms step_avg:92.05ms +step:912/1660 train_time:83951ms step_avg:92.05ms +step:913/1660 train_time:84043ms step_avg:92.05ms +step:914/1660 train_time:84136ms step_avg:92.05ms +step:915/1660 train_time:84229ms step_avg:92.05ms +step:916/1660 train_time:84322ms step_avg:92.05ms +step:917/1660 train_time:84413ms step_avg:92.05ms +step:918/1660 train_time:84507ms step_avg:92.06ms +step:919/1660 train_time:84600ms step_avg:92.06ms +step:920/1660 train_time:84692ms step_avg:92.06ms +step:921/1660 train_time:84786ms step_avg:92.06ms +step:922/1660 train_time:84880ms step_avg:92.06ms +step:923/1660 train_time:84973ms step_avg:92.06ms +step:924/1660 train_time:85066ms step_avg:92.06ms +step:925/1660 train_time:85158ms step_avg:92.06ms +step:926/1660 train_time:85250ms step_avg:92.06ms +step:927/1660 train_time:85343ms step_avg:92.06ms +step:928/1660 train_time:85436ms step_avg:92.07ms +step:929/1660 train_time:85529ms step_avg:92.07ms +step:930/1660 train_time:85622ms step_avg:92.07ms +step:931/1660 train_time:85715ms step_avg:92.07ms +step:932/1660 train_time:85808ms step_avg:92.07ms +step:933/1660 train_time:85901ms step_avg:92.07ms +step:934/1660 train_time:85993ms step_avg:92.07ms +step:935/1660 train_time:86087ms step_avg:92.07ms +step:936/1660 train_time:86180ms step_avg:92.07ms +step:937/1660 train_time:86273ms step_avg:92.07ms +step:938/1660 train_time:86366ms step_avg:92.07ms +step:939/1660 train_time:86459ms step_avg:92.08ms +step:940/1660 train_time:86552ms step_avg:92.08ms +step:941/1660 train_time:86644ms step_avg:92.08ms +step:942/1660 train_time:86737ms step_avg:92.08ms +step:943/1660 train_time:86830ms step_avg:92.08ms +step:944/1660 train_time:86922ms step_avg:92.08ms +step:945/1660 train_time:87015ms step_avg:92.08ms +step:946/1660 train_time:87108ms step_avg:92.08ms +step:947/1660 train_time:87203ms step_avg:92.08ms +step:948/1660 train_time:87295ms step_avg:92.08ms +step:949/1660 
train_time:87388ms step_avg:92.08ms +step:950/1660 train_time:87481ms step_avg:92.08ms +step:951/1660 train_time:87573ms step_avg:92.09ms +step:952/1660 train_time:87666ms step_avg:92.09ms +step:953/1660 train_time:87759ms step_avg:92.09ms +step:954/1660 train_time:87852ms step_avg:92.09ms +step:955/1660 train_time:87945ms step_avg:92.09ms +step:956/1660 train_time:88038ms step_avg:92.09ms +step:957/1660 train_time:88131ms step_avg:92.09ms +step:958/1660 train_time:88225ms step_avg:92.09ms +step:959/1660 train_time:88317ms step_avg:92.09ms +step:960/1660 train_time:88410ms step_avg:92.09ms +step:961/1660 train_time:88504ms step_avg:92.10ms +step:962/1660 train_time:88598ms step_avg:92.10ms +step:963/1660 train_time:88690ms step_avg:92.10ms +step:964/1660 train_time:88782ms step_avg:92.10ms +step:965/1660 train_time:88875ms step_avg:92.10ms +step:966/1660 train_time:88967ms step_avg:92.10ms +step:967/1660 train_time:89060ms step_avg:92.10ms +step:968/1660 train_time:89153ms step_avg:92.10ms +step:969/1660 train_time:89246ms step_avg:92.10ms +step:970/1660 train_time:89339ms step_avg:92.10ms +step:971/1660 train_time:89432ms step_avg:92.10ms +step:972/1660 train_time:89526ms step_avg:92.10ms +step:973/1660 train_time:89618ms step_avg:92.10ms +step:974/1660 train_time:89711ms step_avg:92.11ms +step:975/1660 train_time:89804ms step_avg:92.11ms +step:976/1660 train_time:89896ms step_avg:92.11ms +step:977/1660 train_time:89989ms step_avg:92.11ms +step:978/1660 train_time:90082ms step_avg:92.11ms +step:979/1660 train_time:90175ms step_avg:92.11ms +step:980/1660 train_time:90268ms step_avg:92.11ms +step:981/1660 train_time:90361ms step_avg:92.11ms +step:982/1660 train_time:90454ms step_avg:92.11ms +step:983/1660 train_time:90547ms step_avg:92.11ms +step:984/1660 train_time:90640ms step_avg:92.11ms +step:985/1660 train_time:90733ms step_avg:92.11ms +step:986/1660 train_time:90826ms step_avg:92.12ms +step:987/1660 train_time:90919ms step_avg:92.12ms +step:988/1660 train_time:91012ms step_avg:92.12ms +step:989/1660 train_time:91105ms step_avg:92.12ms +step:990/1660 train_time:91199ms step_avg:92.12ms +step:991/1660 train_time:91291ms step_avg:92.12ms +step:992/1660 train_time:91384ms step_avg:92.12ms +step:993/1660 train_time:91477ms step_avg:92.12ms +step:994/1660 train_time:91570ms step_avg:92.12ms +step:995/1660 train_time:91662ms step_avg:92.12ms +step:996/1660 train_time:91755ms step_avg:92.12ms +step:997/1660 train_time:91848ms step_avg:92.12ms +step:998/1660 train_time:91942ms step_avg:92.13ms +step:999/1660 train_time:92034ms step_avg:92.13ms +step:1000/1660 train_time:92128ms step_avg:92.13ms +step:1000/1660 val_loss:3.4647 train_time:92222ms step_avg:92.22ms +step:1001/1660 train_time:92243ms step_avg:92.15ms +step:1002/1660 train_time:92320ms step_avg:92.14ms +step:1003/1660 train_time:92419ms step_avg:92.14ms +step:1004/1660 train_time:92512ms step_avg:92.14ms +step:1005/1660 train_time:92604ms step_avg:92.14ms +step:1006/1660 train_time:92695ms step_avg:92.14ms +step:1007/1660 train_time:92786ms step_avg:92.14ms +step:1008/1660 train_time:92878ms step_avg:92.14ms +step:1009/1660 train_time:92969ms step_avg:92.14ms +step:1010/1660 train_time:93061ms step_avg:92.14ms +step:1011/1660 train_time:93153ms step_avg:92.14ms +step:1012/1660 train_time:93248ms step_avg:92.14ms +step:1013/1660 train_time:93345ms step_avg:92.15ms +step:1014/1660 train_time:93440ms step_avg:92.15ms +step:1015/1660 train_time:93533ms step_avg:92.15ms +step:1016/1660 train_time:93626ms step_avg:92.15ms +step:1017/1660 
train_time:93719ms step_avg:92.15ms +step:1018/1660 train_time:93810ms step_avg:92.15ms +step:1019/1660 train_time:93902ms step_avg:92.15ms +step:1020/1660 train_time:93994ms step_avg:92.15ms +step:1021/1660 train_time:94085ms step_avg:92.15ms +step:1022/1660 train_time:94178ms step_avg:92.15ms +step:1023/1660 train_time:94273ms step_avg:92.15ms +step:1024/1660 train_time:94369ms step_avg:92.16ms +step:1025/1660 train_time:94464ms step_avg:92.16ms +step:1026/1660 train_time:94556ms step_avg:92.16ms +step:1027/1660 train_time:94649ms step_avg:92.16ms +step:1028/1660 train_time:94741ms step_avg:92.16ms +step:1029/1660 train_time:94834ms step_avg:92.16ms +step:1030/1660 train_time:94926ms step_avg:92.16ms +step:1031/1660 train_time:95017ms step_avg:92.16ms +step:1032/1660 train_time:95109ms step_avg:92.16ms +step:1033/1660 train_time:95201ms step_avg:92.16ms +step:1034/1660 train_time:95296ms step_avg:92.16ms +step:1035/1660 train_time:95389ms step_avg:92.16ms +step:1036/1660 train_time:95482ms step_avg:92.16ms +step:1037/1660 train_time:95575ms step_avg:92.16ms +step:1038/1660 train_time:95668ms step_avg:92.17ms +step:1039/1660 train_time:95761ms step_avg:92.17ms +step:1040/1660 train_time:95853ms step_avg:92.17ms +step:1041/1660 train_time:95947ms step_avg:92.17ms +step:1042/1660 train_time:96039ms step_avg:92.17ms +step:1043/1660 train_time:96131ms step_avg:92.17ms +step:1044/1660 train_time:96225ms step_avg:92.17ms +step:1045/1660 train_time:96318ms step_avg:92.17ms +step:1046/1660 train_time:96411ms step_avg:92.17ms +step:1047/1660 train_time:96505ms step_avg:92.17ms +step:1048/1660 train_time:96599ms step_avg:92.17ms +step:1049/1660 train_time:96691ms step_avg:92.17ms +step:1050/1660 train_time:96784ms step_avg:92.17ms +step:1051/1660 train_time:96876ms step_avg:92.17ms +step:1052/1660 train_time:96968ms step_avg:92.18ms +step:1053/1660 train_time:97061ms step_avg:92.18ms +step:1054/1660 train_time:97154ms step_avg:92.18ms +step:1055/1660 train_time:97248ms step_avg:92.18ms +step:1056/1660 train_time:97341ms step_avg:92.18ms +step:1057/1660 train_time:97434ms step_avg:92.18ms +step:1058/1660 train_time:97527ms step_avg:92.18ms +step:1059/1660 train_time:97620ms step_avg:92.18ms +step:1060/1660 train_time:97714ms step_avg:92.18ms +step:1061/1660 train_time:97806ms step_avg:92.18ms +step:1062/1660 train_time:97899ms step_avg:92.18ms +step:1063/1660 train_time:97991ms step_avg:92.18ms +step:1064/1660 train_time:98083ms step_avg:92.18ms +step:1065/1660 train_time:98176ms step_avg:92.18ms +step:1066/1660 train_time:98269ms step_avg:92.18ms +step:1067/1660 train_time:98362ms step_avg:92.19ms +step:1068/1660 train_time:98455ms step_avg:92.19ms +step:1069/1660 train_time:98549ms step_avg:92.19ms +step:1070/1660 train_time:98643ms step_avg:92.19ms +step:1071/1660 train_time:98735ms step_avg:92.19ms +step:1072/1660 train_time:98828ms step_avg:92.19ms +step:1073/1660 train_time:98920ms step_avg:92.19ms +step:1074/1660 train_time:99012ms step_avg:92.19ms +step:1075/1660 train_time:99104ms step_avg:92.19ms +step:1076/1660 train_time:99197ms step_avg:92.19ms +step:1077/1660 train_time:99289ms step_avg:92.19ms +step:1078/1660 train_time:99383ms step_avg:92.19ms +step:1079/1660 train_time:99475ms step_avg:92.19ms +step:1080/1660 train_time:99569ms step_avg:92.19ms +step:1081/1660 train_time:99662ms step_avg:92.19ms +step:1082/1660 train_time:99754ms step_avg:92.19ms +step:1083/1660 train_time:99848ms step_avg:92.20ms +step:1084/1660 train_time:99940ms step_avg:92.20ms +step:1085/1660 
train_time:100032ms step_avg:92.20ms +step:1086/1660 train_time:100125ms step_avg:92.20ms +step:1087/1660 train_time:100219ms step_avg:92.20ms +step:1088/1660 train_time:100312ms step_avg:92.20ms +step:1089/1660 train_time:100404ms step_avg:92.20ms +step:1090/1660 train_time:100497ms step_avg:92.20ms +step:1091/1660 train_time:100590ms step_avg:92.20ms +step:1092/1660 train_time:100683ms step_avg:92.20ms +step:1093/1660 train_time:100776ms step_avg:92.20ms +step:1094/1660 train_time:100869ms step_avg:92.20ms +step:1095/1660 train_time:100961ms step_avg:92.20ms +step:1096/1660 train_time:101053ms step_avg:92.20ms +step:1097/1660 train_time:101146ms step_avg:92.20ms +step:1098/1660 train_time:101239ms step_avg:92.20ms +step:1099/1660 train_time:101332ms step_avg:92.20ms +step:1100/1660 train_time:101426ms step_avg:92.21ms +step:1101/1660 train_time:101519ms step_avg:92.21ms +step:1102/1660 train_time:101611ms step_avg:92.21ms +step:1103/1660 train_time:101703ms step_avg:92.21ms +step:1104/1660 train_time:101797ms step_avg:92.21ms +step:1105/1660 train_time:101889ms step_avg:92.21ms +step:1106/1660 train_time:101981ms step_avg:92.21ms +step:1107/1660 train_time:102073ms step_avg:92.21ms +step:1108/1660 train_time:102166ms step_avg:92.21ms +step:1109/1660 train_time:102259ms step_avg:92.21ms +step:1110/1660 train_time:102353ms step_avg:92.21ms +step:1111/1660 train_time:102447ms step_avg:92.21ms +step:1112/1660 train_time:102541ms step_avg:92.21ms +step:1113/1660 train_time:102633ms step_avg:92.21ms +step:1114/1660 train_time:102728ms step_avg:92.22ms +step:1115/1660 train_time:102822ms step_avg:92.22ms +step:1116/1660 train_time:102916ms step_avg:92.22ms +step:1117/1660 train_time:103009ms step_avg:92.22ms +step:1118/1660 train_time:103101ms step_avg:92.22ms +step:1119/1660 train_time:103194ms step_avg:92.22ms +step:1120/1660 train_time:103288ms step_avg:92.22ms +step:1121/1660 train_time:103381ms step_avg:92.22ms +step:1122/1660 train_time:103475ms step_avg:92.22ms +step:1123/1660 train_time:103568ms step_avg:92.22ms +step:1124/1660 train_time:103662ms step_avg:92.23ms +step:1125/1660 train_time:103755ms step_avg:92.23ms +step:1125/1660 val_loss:3.4123 train_time:103851ms step_avg:92.31ms +step:1126/1660 train_time:103871ms step_avg:92.25ms +step:1127/1660 train_time:103950ms step_avg:92.24ms +step:1128/1660 train_time:104052ms step_avg:92.24ms +step:1129/1660 train_time:104147ms step_avg:92.25ms +step:1130/1660 train_time:104239ms step_avg:92.25ms +step:1131/1660 train_time:104332ms step_avg:92.25ms +step:1132/1660 train_time:104423ms step_avg:92.25ms +step:1133/1660 train_time:104516ms step_avg:92.25ms +step:1134/1660 train_time:104608ms step_avg:92.25ms +step:1135/1660 train_time:104700ms step_avg:92.25ms +step:1136/1660 train_time:104792ms step_avg:92.25ms +step:1137/1660 train_time:104886ms step_avg:92.25ms +step:1138/1660 train_time:104981ms step_avg:92.25ms +step:1139/1660 train_time:105077ms step_avg:92.25ms +step:1140/1660 train_time:105171ms step_avg:92.26ms +step:1141/1660 train_time:105264ms step_avg:92.26ms +step:1142/1660 train_time:105356ms step_avg:92.26ms +step:1143/1660 train_time:105448ms step_avg:92.26ms +step:1144/1660 train_time:105541ms step_avg:92.26ms +step:1145/1660 train_time:105634ms step_avg:92.26ms +step:1146/1660 train_time:105726ms step_avg:92.26ms +step:1147/1660 train_time:105819ms step_avg:92.26ms +step:1148/1660 train_time:105913ms step_avg:92.26ms +step:1149/1660 train_time:106007ms step_avg:92.26ms +step:1150/1660 train_time:106102ms step_avg:92.26ms 
+step:1151/1660 train_time:106196ms step_avg:92.26ms +step:1152/1660 train_time:106290ms step_avg:92.27ms +step:1153/1660 train_time:106383ms step_avg:92.27ms +step:1154/1660 train_time:106477ms step_avg:92.27ms +step:1155/1660 train_time:106571ms step_avg:92.27ms +step:1156/1660 train_time:106663ms step_avg:92.27ms +step:1157/1660 train_time:106755ms step_avg:92.27ms +step:1158/1660 train_time:106848ms step_avg:92.27ms +step:1159/1660 train_time:106942ms step_avg:92.27ms +step:1160/1660 train_time:107036ms step_avg:92.27ms +step:1161/1660 train_time:107130ms step_avg:92.27ms +step:1162/1660 train_time:107223ms step_avg:92.27ms +step:1163/1660 train_time:107318ms step_avg:92.28ms +step:1164/1660 train_time:107411ms step_avg:92.28ms +step:1165/1660 train_time:107504ms step_avg:92.28ms +step:1166/1660 train_time:107598ms step_avg:92.28ms +step:1167/1660 train_time:107691ms step_avg:92.28ms +step:1168/1660 train_time:107784ms step_avg:92.28ms +step:1169/1660 train_time:107876ms step_avg:92.28ms +step:1170/1660 train_time:107970ms step_avg:92.28ms +step:1171/1660 train_time:108064ms step_avg:92.28ms +step:1172/1660 train_time:108157ms step_avg:92.28ms +step:1173/1660 train_time:108251ms step_avg:92.29ms +step:1174/1660 train_time:108344ms step_avg:92.29ms +step:1175/1660 train_time:108437ms step_avg:92.29ms +step:1176/1660 train_time:108531ms step_avg:92.29ms +step:1177/1660 train_time:108624ms step_avg:92.29ms +step:1178/1660 train_time:108716ms step_avg:92.29ms +step:1179/1660 train_time:108809ms step_avg:92.29ms +step:1180/1660 train_time:108902ms step_avg:92.29ms +step:1181/1660 train_time:108997ms step_avg:92.29ms +step:1182/1660 train_time:109091ms step_avg:92.29ms +step:1183/1660 train_time:109184ms step_avg:92.29ms +step:1184/1660 train_time:109278ms step_avg:92.30ms +step:1185/1660 train_time:109371ms step_avg:92.30ms +step:1186/1660 train_time:109464ms step_avg:92.30ms +step:1187/1660 train_time:109557ms step_avg:92.30ms +step:1188/1660 train_time:109650ms step_avg:92.30ms +step:1189/1660 train_time:109743ms step_avg:92.30ms +step:1190/1660 train_time:109838ms step_avg:92.30ms +step:1191/1660 train_time:109932ms step_avg:92.30ms +step:1192/1660 train_time:110025ms step_avg:92.30ms +step:1193/1660 train_time:110119ms step_avg:92.30ms +step:1194/1660 train_time:110212ms step_avg:92.30ms +step:1195/1660 train_time:110304ms step_avg:92.30ms +step:1196/1660 train_time:110398ms step_avg:92.31ms +step:1197/1660 train_time:110491ms step_avg:92.31ms +step:1198/1660 train_time:110584ms step_avg:92.31ms +step:1199/1660 train_time:110678ms step_avg:92.31ms +step:1200/1660 train_time:110771ms step_avg:92.31ms +step:1201/1660 train_time:110864ms step_avg:92.31ms +step:1202/1660 train_time:110957ms step_avg:92.31ms +step:1203/1660 train_time:111050ms step_avg:92.31ms +step:1204/1660 train_time:111143ms step_avg:92.31ms +step:1205/1660 train_time:111237ms step_avg:92.31ms +step:1206/1660 train_time:111331ms step_avg:92.31ms +step:1207/1660 train_time:111425ms step_avg:92.32ms +step:1208/1660 train_time:111517ms step_avg:92.32ms +step:1209/1660 train_time:111610ms step_avg:92.32ms +step:1210/1660 train_time:111704ms step_avg:92.32ms +step:1211/1660 train_time:111797ms step_avg:92.32ms +step:1212/1660 train_time:111891ms step_avg:92.32ms +step:1213/1660 train_time:111984ms step_avg:92.32ms +step:1214/1660 train_time:112077ms step_avg:92.32ms +step:1215/1660 train_time:112171ms step_avg:92.32ms +step:1216/1660 train_time:112264ms step_avg:92.32ms +step:1217/1660 train_time:112358ms step_avg:92.32ms 
+step:1218/1660 train_time:112451ms step_avg:92.32ms +step:1219/1660 train_time:112543ms step_avg:92.32ms +step:1220/1660 train_time:112637ms step_avg:92.33ms +step:1221/1660 train_time:112730ms step_avg:92.33ms +step:1222/1660 train_time:112824ms step_avg:92.33ms +step:1223/1660 train_time:112917ms step_avg:92.33ms +step:1224/1660 train_time:113011ms step_avg:92.33ms +step:1225/1660 train_time:113104ms step_avg:92.33ms +step:1226/1660 train_time:113198ms step_avg:92.33ms +step:1227/1660 train_time:113291ms step_avg:92.33ms +step:1228/1660 train_time:113385ms step_avg:92.33ms +step:1229/1660 train_time:113478ms step_avg:92.33ms +step:1230/1660 train_time:113572ms step_avg:92.33ms +step:1231/1660 train_time:113665ms step_avg:92.34ms +step:1232/1660 train_time:113757ms step_avg:92.34ms +step:1233/1660 train_time:113850ms step_avg:92.34ms +step:1234/1660 train_time:113943ms step_avg:92.34ms +step:1235/1660 train_time:114036ms step_avg:92.34ms +step:1236/1660 train_time:114130ms step_avg:92.34ms +step:1237/1660 train_time:114224ms step_avg:92.34ms +step:1238/1660 train_time:114318ms step_avg:92.34ms +step:1239/1660 train_time:114411ms step_avg:92.34ms +step:1240/1660 train_time:114504ms step_avg:92.34ms +step:1241/1660 train_time:114598ms step_avg:92.34ms +step:1242/1660 train_time:114691ms step_avg:92.34ms +step:1243/1660 train_time:114784ms step_avg:92.34ms +step:1244/1660 train_time:114877ms step_avg:92.34ms +step:1245/1660 train_time:114971ms step_avg:92.35ms +step:1246/1660 train_time:115064ms step_avg:92.35ms +step:1247/1660 train_time:115157ms step_avg:92.35ms +step:1248/1660 train_time:115250ms step_avg:92.35ms +step:1249/1660 train_time:115343ms step_avg:92.35ms +step:1250/1660 train_time:115437ms step_avg:92.35ms +step:1250/1660 val_loss:3.3738 train_time:115532ms step_avg:92.43ms +step:1251/1660 train_time:115553ms step_avg:92.37ms +step:1252/1660 train_time:115628ms step_avg:92.35ms +step:1253/1660 train_time:115728ms step_avg:92.36ms +step:1254/1660 train_time:115822ms step_avg:92.36ms +step:1255/1660 train_time:115913ms step_avg:92.36ms +step:1256/1660 train_time:116006ms step_avg:92.36ms +step:1257/1660 train_time:116097ms step_avg:92.36ms +step:1258/1660 train_time:116190ms step_avg:92.36ms +step:1259/1660 train_time:116282ms step_avg:92.36ms +step:1260/1660 train_time:116374ms step_avg:92.36ms +step:1261/1660 train_time:116468ms step_avg:92.36ms +step:1262/1660 train_time:116565ms step_avg:92.37ms +step:1263/1660 train_time:116660ms step_avg:92.37ms +step:1264/1660 train_time:116754ms step_avg:92.37ms +step:1265/1660 train_time:116847ms step_avg:92.37ms +step:1266/1660 train_time:116940ms step_avg:92.37ms +step:1267/1660 train_time:117032ms step_avg:92.37ms +step:1268/1660 train_time:117124ms step_avg:92.37ms +step:1269/1660 train_time:117217ms step_avg:92.37ms +step:1270/1660 train_time:117309ms step_avg:92.37ms +step:1271/1660 train_time:117402ms step_avg:92.37ms +step:1272/1660 train_time:117494ms step_avg:92.37ms +step:1273/1660 train_time:117590ms step_avg:92.37ms +step:1274/1660 train_time:117686ms step_avg:92.37ms +step:1275/1660 train_time:117779ms step_avg:92.38ms +step:1276/1660 train_time:117874ms step_avg:92.38ms +step:1277/1660 train_time:117967ms step_avg:92.38ms +step:1278/1660 train_time:118059ms step_avg:92.38ms +step:1279/1660 train_time:118151ms step_avg:92.38ms +step:1280/1660 train_time:118244ms step_avg:92.38ms +step:1281/1660 train_time:118336ms step_avg:92.38ms +step:1282/1660 train_time:118429ms step_avg:92.38ms +step:1283/1660 train_time:118523ms 
step_avg:92.38ms +step:1284/1660 train_time:118618ms step_avg:92.38ms +step:1285/1660 train_time:118711ms step_avg:92.38ms +step:1286/1660 train_time:118806ms step_avg:92.38ms +step:1287/1660 train_time:118898ms step_avg:92.38ms +step:1288/1660 train_time:118992ms step_avg:92.39ms +step:1289/1660 train_time:119085ms step_avg:92.39ms +step:1290/1660 train_time:119177ms step_avg:92.39ms +step:1291/1660 train_time:119270ms step_avg:92.39ms +step:1292/1660 train_time:119363ms step_avg:92.39ms +step:1293/1660 train_time:119456ms step_avg:92.39ms +step:1294/1660 train_time:119549ms step_avg:92.39ms +step:1295/1660 train_time:119643ms step_avg:92.39ms +step:1296/1660 train_time:119736ms step_avg:92.39ms +step:1297/1660 train_time:119830ms step_avg:92.39ms +step:1298/1660 train_time:119924ms step_avg:92.39ms +step:1299/1660 train_time:120017ms step_avg:92.39ms +step:1300/1660 train_time:120109ms step_avg:92.39ms +step:1301/1660 train_time:120203ms step_avg:92.39ms +step:1302/1660 train_time:120295ms step_avg:92.39ms +step:1303/1660 train_time:120387ms step_avg:92.39ms +step:1304/1660 train_time:120480ms step_avg:92.39ms +step:1305/1660 train_time:120575ms step_avg:92.39ms +step:1306/1660 train_time:120670ms step_avg:92.40ms +step:1307/1660 train_time:120762ms step_avg:92.40ms +step:1308/1660 train_time:120857ms step_avg:92.40ms +step:1309/1660 train_time:120950ms step_avg:92.40ms +step:1310/1660 train_time:121043ms step_avg:92.40ms +step:1311/1660 train_time:121136ms step_avg:92.40ms +step:1312/1660 train_time:121229ms step_avg:92.40ms +step:1313/1660 train_time:121322ms step_avg:92.40ms +step:1314/1660 train_time:121415ms step_avg:92.40ms +step:1315/1660 train_time:121508ms step_avg:92.40ms +step:1316/1660 train_time:121601ms step_avg:92.40ms +step:1317/1660 train_time:121694ms step_avg:92.40ms +step:1318/1660 train_time:121787ms step_avg:92.40ms +step:1319/1660 train_time:121880ms step_avg:92.40ms +step:1320/1660 train_time:121975ms step_avg:92.41ms +step:1321/1660 train_time:122068ms step_avg:92.41ms +step:1322/1660 train_time:122161ms step_avg:92.41ms +step:1323/1660 train_time:122254ms step_avg:92.41ms +step:1324/1660 train_time:122348ms step_avg:92.41ms +step:1325/1660 train_time:122441ms step_avg:92.41ms +step:1326/1660 train_time:122534ms step_avg:92.41ms +step:1327/1660 train_time:122628ms step_avg:92.41ms +step:1328/1660 train_time:122722ms step_avg:92.41ms +step:1329/1660 train_time:122815ms step_avg:92.41ms +step:1330/1660 train_time:122908ms step_avg:92.41ms +step:1331/1660 train_time:123002ms step_avg:92.41ms +step:1332/1660 train_time:123096ms step_avg:92.41ms +step:1333/1660 train_time:123189ms step_avg:92.41ms +step:1334/1660 train_time:123281ms step_avg:92.41ms +step:1335/1660 train_time:123375ms step_avg:92.42ms +step:1336/1660 train_time:123468ms step_avg:92.42ms +step:1337/1660 train_time:123562ms step_avg:92.42ms +step:1338/1660 train_time:123654ms step_avg:92.42ms +step:1339/1660 train_time:123748ms step_avg:92.42ms +step:1340/1660 train_time:123841ms step_avg:92.42ms +step:1341/1660 train_time:123935ms step_avg:92.42ms +step:1342/1660 train_time:124028ms step_avg:92.42ms +step:1343/1660 train_time:124121ms step_avg:92.42ms +step:1344/1660 train_time:124214ms step_avg:92.42ms +step:1345/1660 train_time:124307ms step_avg:92.42ms +step:1346/1660 train_time:124400ms step_avg:92.42ms +step:1347/1660 train_time:124494ms step_avg:92.42ms +step:1348/1660 train_time:124586ms step_avg:92.42ms +step:1349/1660 train_time:124679ms step_avg:92.42ms +step:1350/1660 train_time:124772ms 
step_avg:92.42ms +step:1351/1660 train_time:124867ms step_avg:92.43ms +step:1352/1660 train_time:124959ms step_avg:92.43ms +step:1353/1660 train_time:125054ms step_avg:92.43ms +step:1354/1660 train_time:125147ms step_avg:92.43ms +step:1355/1660 train_time:125239ms step_avg:92.43ms +step:1356/1660 train_time:125333ms step_avg:92.43ms +step:1357/1660 train_time:125426ms step_avg:92.43ms +step:1358/1660 train_time:125519ms step_avg:92.43ms +step:1359/1660 train_time:125612ms step_avg:92.43ms +step:1360/1660 train_time:125705ms step_avg:92.43ms +step:1361/1660 train_time:125798ms step_avg:92.43ms +step:1362/1660 train_time:125892ms step_avg:92.43ms +step:1363/1660 train_time:125986ms step_avg:92.43ms +step:1364/1660 train_time:126078ms step_avg:92.43ms +step:1365/1660 train_time:126173ms step_avg:92.43ms +step:1366/1660 train_time:126266ms step_avg:92.43ms +step:1367/1660 train_time:126359ms step_avg:92.44ms +step:1368/1660 train_time:126453ms step_avg:92.44ms +step:1369/1660 train_time:126546ms step_avg:92.44ms +step:1370/1660 train_time:126639ms step_avg:92.44ms +step:1371/1660 train_time:126732ms step_avg:92.44ms +step:1372/1660 train_time:126826ms step_avg:92.44ms +step:1373/1660 train_time:126920ms step_avg:92.44ms +step:1374/1660 train_time:127013ms step_avg:92.44ms +step:1375/1660 train_time:127106ms step_avg:92.44ms +step:1375/1660 val_loss:3.3400 train_time:127200ms step_avg:92.51ms +step:1376/1660 train_time:127221ms step_avg:92.46ms +step:1377/1660 train_time:127297ms step_avg:92.44ms +step:1378/1660 train_time:127393ms step_avg:92.45ms +step:1379/1660 train_time:127487ms step_avg:92.45ms +step:1380/1660 train_time:127579ms step_avg:92.45ms +step:1381/1660 train_time:127671ms step_avg:92.45ms +step:1382/1660 train_time:127763ms step_avg:92.45ms +step:1383/1660 train_time:127856ms step_avg:92.45ms +step:1384/1660 train_time:127947ms step_avg:92.45ms +step:1385/1660 train_time:128039ms step_avg:92.45ms +step:1386/1660 train_time:128133ms step_avg:92.45ms +step:1387/1660 train_time:128227ms step_avg:92.45ms +step:1388/1660 train_time:128323ms step_avg:92.45ms +step:1389/1660 train_time:128417ms step_avg:92.45ms +step:1390/1660 train_time:128511ms step_avg:92.45ms +step:1391/1660 train_time:128604ms step_avg:92.45ms +step:1392/1660 train_time:128697ms step_avg:92.46ms +step:1393/1660 train_time:128790ms step_avg:92.45ms +step:1394/1660 train_time:128882ms step_avg:92.45ms +step:1395/1660 train_time:128974ms step_avg:92.45ms +step:1396/1660 train_time:129067ms step_avg:92.46ms +step:1397/1660 train_time:129161ms step_avg:92.46ms +step:1398/1660 train_time:129255ms step_avg:92.46ms +step:1399/1660 train_time:129350ms step_avg:92.46ms +step:1400/1660 train_time:129444ms step_avg:92.46ms +step:1401/1660 train_time:129537ms step_avg:92.46ms +step:1402/1660 train_time:129630ms step_avg:92.46ms +step:1403/1660 train_time:129725ms step_avg:92.46ms +step:1404/1660 train_time:129817ms step_avg:92.46ms +step:1405/1660 train_time:129910ms step_avg:92.46ms +step:1406/1660 train_time:130002ms step_avg:92.46ms +step:1407/1660 train_time:130095ms step_avg:92.46ms +step:1408/1660 train_time:130188ms step_avg:92.46ms +step:1409/1660 train_time:130284ms step_avg:92.47ms +step:1410/1660 train_time:130378ms step_avg:92.47ms +step:1411/1660 train_time:130471ms step_avg:92.47ms +step:1412/1660 train_time:130564ms step_avg:92.47ms +step:1413/1660 train_time:130657ms step_avg:92.47ms +step:1414/1660 train_time:130750ms step_avg:92.47ms +step:1415/1660 train_time:130843ms step_avg:92.47ms +step:1416/1660 
train_time:130935ms step_avg:92.47ms +step:1417/1660 train_time:131028ms step_avg:92.47ms +step:1418/1660 train_time:131121ms step_avg:92.47ms +step:1419/1660 train_time:131214ms step_avg:92.47ms +step:1420/1660 train_time:131309ms step_avg:92.47ms +step:1421/1660 train_time:131402ms step_avg:92.47ms +step:1422/1660 train_time:131496ms step_avg:92.47ms +step:1423/1660 train_time:131589ms step_avg:92.47ms +step:1424/1660 train_time:131683ms step_avg:92.47ms +step:1425/1660 train_time:131776ms step_avg:92.47ms +step:1426/1660 train_time:131869ms step_avg:92.48ms +step:1427/1660 train_time:131962ms step_avg:92.48ms +step:1428/1660 train_time:132056ms step_avg:92.48ms +step:1429/1660 train_time:132149ms step_avg:92.48ms +step:1430/1660 train_time:132243ms step_avg:92.48ms +step:1431/1660 train_time:132336ms step_avg:92.48ms +step:1432/1660 train_time:132429ms step_avg:92.48ms +step:1433/1660 train_time:132523ms step_avg:92.48ms +step:1434/1660 train_time:132617ms step_avg:92.48ms +step:1435/1660 train_time:132710ms step_avg:92.48ms +step:1436/1660 train_time:132803ms step_avg:92.48ms +step:1437/1660 train_time:132897ms step_avg:92.48ms +step:1438/1660 train_time:132990ms step_avg:92.48ms +step:1439/1660 train_time:133083ms step_avg:92.48ms +step:1440/1660 train_time:133177ms step_avg:92.48ms +step:1441/1660 train_time:133270ms step_avg:92.48ms +step:1442/1660 train_time:133364ms step_avg:92.49ms +step:1443/1660 train_time:133457ms step_avg:92.49ms +step:1444/1660 train_time:133550ms step_avg:92.49ms +step:1445/1660 train_time:133644ms step_avg:92.49ms +step:1446/1660 train_time:133738ms step_avg:92.49ms +step:1447/1660 train_time:133831ms step_avg:92.49ms +step:1448/1660 train_time:133925ms step_avg:92.49ms +step:1449/1660 train_time:134017ms step_avg:92.49ms +step:1450/1660 train_time:134110ms step_avg:92.49ms +step:1451/1660 train_time:134204ms step_avg:92.49ms +step:1452/1660 train_time:134298ms step_avg:92.49ms +step:1453/1660 train_time:134391ms step_avg:92.49ms +step:1454/1660 train_time:134486ms step_avg:92.49ms +step:1455/1660 train_time:134580ms step_avg:92.49ms +step:1456/1660 train_time:134672ms step_avg:92.49ms +step:1457/1660 train_time:134765ms step_avg:92.50ms +step:1458/1660 train_time:134859ms step_avg:92.50ms +step:1459/1660 train_time:134952ms step_avg:92.50ms +step:1460/1660 train_time:135045ms step_avg:92.50ms +step:1461/1660 train_time:135139ms step_avg:92.50ms +step:1462/1660 train_time:135232ms step_avg:92.50ms +step:1463/1660 train_time:135326ms step_avg:92.50ms +step:1464/1660 train_time:135420ms step_avg:92.50ms +step:1465/1660 train_time:135514ms step_avg:92.50ms +step:1466/1660 train_time:135607ms step_avg:92.50ms +step:1467/1660 train_time:135701ms step_avg:92.50ms +step:1468/1660 train_time:135794ms step_avg:92.50ms +step:1469/1660 train_time:135886ms step_avg:92.50ms +step:1470/1660 train_time:135980ms step_avg:92.50ms +step:1471/1660 train_time:136073ms step_avg:92.50ms +step:1472/1660 train_time:136166ms step_avg:92.50ms +step:1473/1660 train_time:136259ms step_avg:92.50ms +step:1474/1660 train_time:136353ms step_avg:92.51ms +step:1475/1660 train_time:136447ms step_avg:92.51ms +step:1476/1660 train_time:136539ms step_avg:92.51ms +step:1477/1660 train_time:136633ms step_avg:92.51ms +step:1478/1660 train_time:136727ms step_avg:92.51ms +step:1479/1660 train_time:136820ms step_avg:92.51ms +step:1480/1660 train_time:136912ms step_avg:92.51ms +step:1481/1660 train_time:137006ms step_avg:92.51ms +step:1482/1660 train_time:137100ms step_avg:92.51ms +step:1483/1660 
train_time:137193ms step_avg:92.51ms +step:1484/1660 train_time:137286ms step_avg:92.51ms +step:1485/1660 train_time:137380ms step_avg:92.51ms +step:1486/1660 train_time:137472ms step_avg:92.51ms +step:1487/1660 train_time:137566ms step_avg:92.51ms +step:1488/1660 train_time:137660ms step_avg:92.51ms +step:1489/1660 train_time:137753ms step_avg:92.51ms +step:1490/1660 train_time:137846ms step_avg:92.51ms +step:1491/1660 train_time:137939ms step_avg:92.51ms +step:1492/1660 train_time:138033ms step_avg:92.52ms +step:1493/1660 train_time:138126ms step_avg:92.52ms +step:1494/1660 train_time:138219ms step_avg:92.52ms +step:1495/1660 train_time:138312ms step_avg:92.52ms +step:1496/1660 train_time:138407ms step_avg:92.52ms +step:1497/1660 train_time:138500ms step_avg:92.52ms +step:1498/1660 train_time:138594ms step_avg:92.52ms +step:1499/1660 train_time:138687ms step_avg:92.52ms +step:1500/1660 train_time:138780ms step_avg:92.52ms +step:1500/1660 val_loss:3.3098 train_time:138875ms step_avg:92.58ms +step:1501/1660 train_time:138895ms step_avg:92.54ms +step:1502/1660 train_time:138974ms step_avg:92.53ms +step:1503/1660 train_time:139072ms step_avg:92.53ms +step:1504/1660 train_time:139165ms step_avg:92.53ms +step:1505/1660 train_time:139258ms step_avg:92.53ms +step:1506/1660 train_time:139350ms step_avg:92.53ms +step:1507/1660 train_time:139442ms step_avg:92.53ms +step:1508/1660 train_time:139534ms step_avg:92.53ms +step:1509/1660 train_time:139626ms step_avg:92.53ms +step:1510/1660 train_time:139718ms step_avg:92.53ms +step:1511/1660 train_time:139811ms step_avg:92.53ms +step:1512/1660 train_time:139908ms step_avg:92.53ms +step:1513/1660 train_time:140003ms step_avg:92.53ms +step:1514/1660 train_time:140099ms step_avg:92.54ms +step:1515/1660 train_time:140193ms step_avg:92.54ms +step:1516/1660 train_time:140286ms step_avg:92.54ms +step:1517/1660 train_time:140378ms step_avg:92.54ms +step:1518/1660 train_time:140470ms step_avg:92.54ms +step:1519/1660 train_time:140562ms step_avg:92.54ms +step:1520/1660 train_time:140655ms step_avg:92.54ms +step:1521/1660 train_time:140748ms step_avg:92.54ms +step:1522/1660 train_time:140841ms step_avg:92.54ms +step:1523/1660 train_time:140936ms step_avg:92.54ms +step:1524/1660 train_time:141031ms step_avg:92.54ms +step:1525/1660 train_time:141125ms step_avg:92.54ms +step:1526/1660 train_time:141219ms step_avg:92.54ms +step:1527/1660 train_time:141312ms step_avg:92.54ms +step:1528/1660 train_time:141405ms step_avg:92.54ms +step:1529/1660 train_time:141497ms step_avg:92.54ms +step:1530/1660 train_time:141589ms step_avg:92.54ms +step:1531/1660 train_time:141681ms step_avg:92.54ms +step:1532/1660 train_time:141774ms step_avg:92.54ms +step:1533/1660 train_time:141868ms step_avg:92.54ms +step:1534/1660 train_time:141962ms step_avg:92.54ms +step:1535/1660 train_time:142056ms step_avg:92.54ms +step:1536/1660 train_time:142152ms step_avg:92.55ms +step:1537/1660 train_time:142246ms step_avg:92.55ms +step:1538/1660 train_time:142339ms step_avg:92.55ms +step:1539/1660 train_time:142432ms step_avg:92.55ms +step:1540/1660 train_time:142525ms step_avg:92.55ms +step:1541/1660 train_time:142618ms step_avg:92.55ms +step:1542/1660 train_time:142711ms step_avg:92.55ms +step:1543/1660 train_time:142803ms step_avg:92.55ms +step:1544/1660 train_time:142897ms step_avg:92.55ms +step:1545/1660 train_time:142991ms step_avg:92.55ms +step:1546/1660 train_time:143085ms step_avg:92.55ms +step:1547/1660 train_time:143179ms step_avg:92.55ms +step:1548/1660 train_time:143272ms step_avg:92.55ms 
+step:1549/1660 train_time:143365ms step_avg:92.55ms +step:1550/1660 train_time:143459ms step_avg:92.55ms +step:1551/1660 train_time:143552ms step_avg:92.55ms +step:1552/1660 train_time:143644ms step_avg:92.55ms +step:1553/1660 train_time:143737ms step_avg:92.55ms +step:1554/1660 train_time:143830ms step_avg:92.55ms +step:1555/1660 train_time:143923ms step_avg:92.56ms +step:1556/1660 train_time:144016ms step_avg:92.56ms +step:1557/1660 train_time:144110ms step_avg:92.56ms +step:1558/1660 train_time:144204ms step_avg:92.56ms +step:1559/1660 train_time:144298ms step_avg:92.56ms +step:1560/1660 train_time:144392ms step_avg:92.56ms +step:1561/1660 train_time:144484ms step_avg:92.56ms +step:1562/1660 train_time:144577ms step_avg:92.56ms +step:1563/1660 train_time:144670ms step_avg:92.56ms +step:1564/1660 train_time:144763ms step_avg:92.56ms +step:1565/1660 train_time:144856ms step_avg:92.56ms +step:1566/1660 train_time:144950ms step_avg:92.56ms +step:1567/1660 train_time:145042ms step_avg:92.56ms +step:1568/1660 train_time:145137ms step_avg:92.56ms +step:1569/1660 train_time:145230ms step_avg:92.56ms +step:1570/1660 train_time:145324ms step_avg:92.56ms +step:1571/1660 train_time:145417ms step_avg:92.56ms +step:1572/1660 train_time:145510ms step_avg:92.56ms +step:1573/1660 train_time:145603ms step_avg:92.56ms +step:1574/1660 train_time:145696ms step_avg:92.56ms +step:1575/1660 train_time:145789ms step_avg:92.56ms +step:1576/1660 train_time:145881ms step_avg:92.56ms +step:1577/1660 train_time:145974ms step_avg:92.56ms +step:1578/1660 train_time:146067ms step_avg:92.56ms +step:1579/1660 train_time:146161ms step_avg:92.57ms +step:1580/1660 train_time:146255ms step_avg:92.57ms +step:1581/1660 train_time:146349ms step_avg:92.57ms +step:1582/1660 train_time:146441ms step_avg:92.57ms +step:1583/1660 train_time:146534ms step_avg:92.57ms +step:1584/1660 train_time:146627ms step_avg:92.57ms +step:1585/1660 train_time:146720ms step_avg:92.57ms +step:1586/1660 train_time:146813ms step_avg:92.57ms +step:1587/1660 train_time:146906ms step_avg:92.57ms +step:1588/1660 train_time:147000ms step_avg:92.57ms +step:1589/1660 train_time:147094ms step_avg:92.57ms +step:1590/1660 train_time:147187ms step_avg:92.57ms +step:1591/1660 train_time:147280ms step_avg:92.57ms +step:1592/1660 train_time:147373ms step_avg:92.57ms +step:1593/1660 train_time:147466ms step_avg:92.57ms +step:1594/1660 train_time:147560ms step_avg:92.57ms +step:1595/1660 train_time:147653ms step_avg:92.57ms +step:1596/1660 train_time:147746ms step_avg:92.57ms +step:1597/1660 train_time:147840ms step_avg:92.57ms +step:1598/1660 train_time:147933ms step_avg:92.57ms +step:1599/1660 train_time:148027ms step_avg:92.57ms +step:1600/1660 train_time:148119ms step_avg:92.57ms +step:1601/1660 train_time:148213ms step_avg:92.58ms +step:1602/1660 train_time:148305ms step_avg:92.58ms +step:1603/1660 train_time:148400ms step_avg:92.58ms +step:1604/1660 train_time:148493ms step_avg:92.58ms +step:1605/1660 train_time:148586ms step_avg:92.58ms +step:1606/1660 train_time:148679ms step_avg:92.58ms +step:1607/1660 train_time:148772ms step_avg:92.58ms +step:1608/1660 train_time:148865ms step_avg:92.58ms +step:1609/1660 train_time:148959ms step_avg:92.58ms +step:1610/1660 train_time:149052ms step_avg:92.58ms +step:1611/1660 train_time:149145ms step_avg:92.58ms +step:1612/1660 train_time:149238ms step_avg:92.58ms +step:1613/1660 train_time:149332ms step_avg:92.58ms +step:1614/1660 train_time:149425ms step_avg:92.58ms +step:1615/1660 train_time:149519ms step_avg:92.58ms 
+step:1616/1660 train_time:149612ms step_avg:92.58ms +step:1617/1660 train_time:149705ms step_avg:92.58ms +step:1618/1660 train_time:149799ms step_avg:92.58ms +step:1619/1660 train_time:149893ms step_avg:92.58ms +step:1620/1660 train_time:149986ms step_avg:92.58ms +step:1621/1660 train_time:150079ms step_avg:92.58ms +step:1622/1660 train_time:150173ms step_avg:92.58ms +step:1623/1660 train_time:150266ms step_avg:92.59ms +step:1624/1660 train_time:150359ms step_avg:92.59ms +step:1625/1660 train_time:150452ms step_avg:92.59ms +step:1625/1660 val_loss:3.2853 train_time:150547ms step_avg:92.64ms +step:1626/1660 train_time:150567ms step_avg:92.60ms +step:1627/1660 train_time:150644ms step_avg:92.59ms +step:1628/1660 train_time:150743ms step_avg:92.59ms +step:1629/1660 train_time:150835ms step_avg:92.59ms +step:1630/1660 train_time:150928ms step_avg:92.59ms +step:1631/1660 train_time:151020ms step_avg:92.59ms +step:1632/1660 train_time:151112ms step_avg:92.59ms +step:1633/1660 train_time:151204ms step_avg:92.59ms +step:1634/1660 train_time:151296ms step_avg:92.59ms +step:1635/1660 train_time:151388ms step_avg:92.59ms +step:1636/1660 train_time:151481ms step_avg:92.59ms +step:1637/1660 train_time:151577ms step_avg:92.59ms +step:1638/1660 train_time:151675ms step_avg:92.60ms +step:1639/1660 train_time:151770ms step_avg:92.60ms +step:1640/1660 train_time:151863ms step_avg:92.60ms +step:1641/1660 train_time:151956ms step_avg:92.60ms +step:1642/1660 train_time:152049ms step_avg:92.60ms +step:1643/1660 train_time:152141ms step_avg:92.60ms +step:1644/1660 train_time:152233ms step_avg:92.60ms +step:1645/1660 train_time:152326ms step_avg:92.60ms +step:1646/1660 train_time:152418ms step_avg:92.60ms +step:1647/1660 train_time:152512ms step_avg:92.60ms +step:1648/1660 train_time:152607ms step_avg:92.60ms +step:1649/1660 train_time:152702ms step_avg:92.60ms +step:1650/1660 train_time:152795ms step_avg:92.60ms +step:1651/1660 train_time:152889ms step_avg:92.60ms +step:1652/1660 train_time:152983ms step_avg:92.60ms +step:1653/1660 train_time:153076ms step_avg:92.60ms +step:1654/1660 train_time:153169ms step_avg:92.61ms +step:1655/1660 train_time:153261ms step_avg:92.60ms +step:1656/1660 train_time:153353ms step_avg:92.60ms +step:1657/1660 train_time:153447ms step_avg:92.61ms +step:1658/1660 train_time:153542ms step_avg:92.61ms +step:1659/1660 train_time:153637ms step_avg:92.61ms +step:1660/1660 train_time:153733ms step_avg:92.61ms +step:1660/1660 val_loss:3.2774 train_time:153828ms step_avg:92.67ms +peak memory allocated: 32002 MiB reserved: 47016 MiB diff --git a/records/091525_ThreadingFinalWindow/61705980-e239-4d86-9233-210200da7010.txt b/records/091525_ThreadingFinalWindow/61705980-e239-4d86-9233-210200da7010.txt new file mode 100644 index 000000000..d02635a6d --- /dev/null +++ b/records/091525_ThreadingFinalWindow/61705980-e239-4d86-9233-210200da7010.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True 
# we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + 
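+    # Sizing sketch (numbers illustrative, not from this run): the 1D grid holds
+    # batch_size * cdiv(M, BLOCK_SIZE_M) * cdiv(M, BLOCK_SIZE_N) programs, and
+    # _pid_to_block decodes each pid back into (batch_idx, m_idx, n_idx), using
+    # tl.swizzle2d to group blocks for L2 reuse. E.g. with M=1024 and
+    # BLOCK_SIZE_M=BLOCK_SIZE_N=128 each matrix gets 8*8=64 programs, so a batch
+    # of 4 launches 256; per matrix, 28 of the 64 sit strictly on the skipped
+    # side of the diagonal and exit early via the symmetry check in the kernel.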
ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) 
== M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
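+        # Outline of this step (a sketch; the world_size/param counts below are
+        # illustrative, not from this run):
+        #   1. reduce_scatter the stacked gradients of each group so every rank
+        #      owns an averaged shard of chunk_size = padded_num_params // world_size.
+        #   2. orthogonalize only the local shard with the batched Newton-Schulz
+        #      kernel, while other groups' reductions are still in flight.
+        #   3. all_gather the updated parameter shards back to all ranks.
+        # E.g. with world_size=8 and a group of 30 same-shaped params, the group
+        # is padded to 32 and each rank orthogonalizes a shard of 4.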
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group; all params in a group
+            # share one shape, so these are constant within the group.
+            p_example = params[0]
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor so the zeropower batch keeps a consistent size.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # this gradient corresponds to the current parameter p
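+                # The two lerps below implement a Nesterov-style momentum update:
+                #   buf    <- momentum * buf + (1 - momentum) * grad
+                #   update <- (1 - momentum) * grad + momentum * buf
+                # so the orthogonalized update "looks ahead" one momentum step.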
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                # Fold the two leading dims into a single batch dim (e.g. stacked QKVO
+                # weights of shape (chunk, 4, hdim, dim)) so the kernel sees 3D input.
+                batch = original_shape[0] * original_shape[1]
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Apply the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
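+        # The gathers above were launched asynchronously per group, so the waits in
+        # this final pass let earlier groups' communication overlap with later
+        # groups' Newton-Schulz compute.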
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32)  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy())  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full BOS index after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    def __init__(self, file_iter, world_size):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token and is truncated to max_seq_len
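+    # The generator yields (inputs, targets, cum_seqlens) triples already moved to the
+    # GPU; cum_seqlens is the padded cumulative-document-length vector consumed as
+    # cu_seqlens by flash_attn_varlen_func. With align_to_bos set, each rank greedily
+    # packs whole documents (each capped at max_seq_len) until it holds exactly
+    # num_tokens_local + 1 tokens, so inputs and targets can be offset by one token.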
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        # Scanning a whole shard for BOS positions is slow, so start with a quickload
+        # prefix scan; BOSFinder tracks its position (self.i) and finishes the full
+        # scan on a background thread while the first batches are served.
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400; divide by 300 for headroom
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; swap in the preloaded next shard and retry.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # trim the last document by one token to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? keep this fixed for consistent comparisons
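+    # train_batch_size below is the global token budget per optimizer step: each rank
+    # sees train_batch_size / (grad_accum_steps * world_size) tokens per micro-step,
+    # e.g. 2048 * 24 * 8 / 8 = 49152 tokens with world_size=8 and grad_accum_steps=1.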
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1660  # number of iterations to run
+    cooldown_frac: float = 0.5  # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"data_threading_1660/{uuid.uuid4()}"
+    val_loss_every: int = 125  # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20  # the final layer is largely insensitive to window length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0)  # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess  # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable, then linear decay over the final cooldown_frac of training.
+# e.g. with num_iterations=1660 and cooldown_frac=0.5 the multiplier is 1.0 until step 830,
+# then falls linearly towards 0.1 by the end of training.
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window-size schedule: training is split evenly across the entries of ws_schedule;
+# the final validation step switches to the larger ws_validate / ws_validate_final_layer.
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers])  # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)]  # each window size is a new graph, need to warm up each
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        # [record truncated here: the remainder of the warmup loop, the restore from
+        # initial_state, and the opening of the main training loop (which defines
+        # train_steps, t0, training_time_ms, last_step, and ws_final_layer) are not
+        # present in the source file]
+
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
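+            # each group's base LR was saved as "initial_lr" above; the schedule is
+            # applied as a pure multiplier so Adam and Muon groups scale together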
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:33:03 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 35C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 189042 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 189043 C /usr/bin/python3 614MiB | +| 0 N/A N/A 189044 C /usr/bin/python3 614MiB | +| 0 N/A N/A 189045 C /usr/bin/python3 614MiB | +| 0 N/A N/A 189046 C /usr/bin/python3 614MiB | +| 0 N/A N/A 189047 C /usr/bin/python3 614MiB | +| 0 N/A N/A 189048 C /usr/bin/python3 614MiB | +| 0 N/A N/A 189049 C /usr/bin/python3 614MiB | +| 1 N/A N/A 189043 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 189044 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 189045 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 189046 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 189047 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 189048 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 189049 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:146ms step_avg:145.73ms +step:2/1660 train_time:166ms step_avg:82.77ms +step:3/1660 train_time:234ms step_avg:77.88ms +step:4/1660 train_time:323ms step_avg:80.79ms +step:5/1660 train_time:414ms step_avg:82.78ms +step:6/1660 train_time:505ms step_avg:84.09ms +step:7/1660 train_time:596ms step_avg:85.16ms +step:8/1660 train_time:687ms step_avg:85.91ms +step:9/1660 train_time:779ms step_avg:86.51ms +step:10/1660 train_time:869ms step_avg:86.89ms +step:11/1660 train_time:960ms step_avg:87.24ms +step:12/1660 train_time:1052ms step_avg:87.70ms +step:13/1660 train_time:1148ms step_avg:88.33ms +step:14/1660 train_time:1242ms step_avg:88.71ms +step:15/1660 train_time:1333ms step_avg:88.88ms +step:16/1660 train_time:1425ms step_avg:89.07ms +step:17/1660 train_time:1517ms step_avg:89.21ms +step:18/1660 train_time:1608ms step_avg:89.33ms +step:19/1660 train_time:1699ms step_avg:89.43ms +step:20/1660 train_time:1791ms step_avg:89.53ms +step:21/1660 train_time:1881ms step_avg:89.59ms +step:22/1660 train_time:1972ms step_avg:89.65ms +step:23/1660 train_time:2065ms step_avg:89.78ms +step:24/1660 train_time:2157ms step_avg:89.88ms +step:25/1660 train_time:2250ms step_avg:89.98ms +step:26/1660 train_time:2343ms step_avg:90.10ms +step:27/1660 train_time:2435ms step_avg:90.17ms +step:28/1660 train_time:2526ms step_avg:90.22ms +step:29/1660 train_time:2618ms step_avg:90.28ms +step:30/1660 train_time:2709ms step_avg:90.29ms +step:31/1660 train_time:2800ms step_avg:90.31ms +step:32/1660 train_time:2891ms step_avg:90.35ms +step:33/1660 train_time:2983ms step_avg:90.40ms +step:34/1660 train_time:3076ms step_avg:90.46ms +step:35/1660 train_time:3168ms step_avg:90.52ms +step:36/1660 train_time:3261ms step_avg:90.57ms +step:37/1660 train_time:3352ms step_avg:90.61ms +step:38/1660 train_time:3444ms step_avg:90.64ms +step:39/1660 train_time:3535ms step_avg:90.65ms +step:40/1660 train_time:3627ms step_avg:90.67ms +step:41/1660 train_time:3719ms step_avg:90.71ms +step:42/1660 train_time:3811ms step_avg:90.73ms +step:43/1660 train_time:3902ms step_avg:90.74ms +step:44/1660 train_time:3993ms step_avg:90.75ms +step:45/1660 train_time:4086ms step_avg:90.80ms +step:46/1660 train_time:4178ms step_avg:90.82ms +step:47/1660 train_time:4269ms step_avg:90.83ms +step:48/1660 train_time:4362ms step_avg:90.87ms +step:49/1660 train_time:4454ms step_avg:90.89ms +step:50/1660 train_time:4545ms step_avg:90.91ms 
+step:51/1660 train_time:4636ms step_avg:90.91ms +step:52/1660 train_time:4728ms step_avg:90.93ms +step:53/1660 train_time:4820ms step_avg:90.95ms +step:54/1660 train_time:4912ms step_avg:90.97ms +step:55/1660 train_time:5005ms step_avg:90.99ms +step:56/1660 train_time:5099ms step_avg:91.04ms +step:57/1660 train_time:5190ms step_avg:91.05ms +step:58/1660 train_time:5282ms step_avg:91.06ms +step:59/1660 train_time:5373ms step_avg:91.06ms +step:60/1660 train_time:5464ms step_avg:91.07ms +step:61/1660 train_time:5556ms step_avg:91.08ms +step:62/1660 train_time:5648ms step_avg:91.09ms +step:63/1660 train_time:5739ms step_avg:91.10ms +step:64/1660 train_time:5831ms step_avg:91.10ms +step:65/1660 train_time:5923ms step_avg:91.12ms +step:66/1660 train_time:6015ms step_avg:91.14ms +step:67/1660 train_time:6108ms step_avg:91.16ms +step:68/1660 train_time:6200ms step_avg:91.18ms +step:69/1660 train_time:6291ms step_avg:91.18ms +step:70/1660 train_time:6384ms step_avg:91.20ms +step:71/1660 train_time:6475ms step_avg:91.20ms +step:72/1660 train_time:6566ms step_avg:91.20ms +step:73/1660 train_time:6658ms step_avg:91.21ms +step:74/1660 train_time:6749ms step_avg:91.20ms +step:75/1660 train_time:6840ms step_avg:91.20ms +step:76/1660 train_time:6931ms step_avg:91.20ms +step:77/1660 train_time:7025ms step_avg:91.23ms +step:78/1660 train_time:7118ms step_avg:91.25ms +step:79/1660 train_time:7209ms step_avg:91.25ms +step:80/1660 train_time:7301ms step_avg:91.26ms +step:81/1660 train_time:7392ms step_avg:91.26ms +step:82/1660 train_time:7484ms step_avg:91.27ms +step:83/1660 train_time:7575ms step_avg:91.27ms +step:84/1660 train_time:7666ms step_avg:91.27ms +step:85/1660 train_time:7758ms step_avg:91.26ms +step:86/1660 train_time:7849ms step_avg:91.27ms +step:87/1660 train_time:7941ms step_avg:91.28ms +step:88/1660 train_time:8032ms step_avg:91.28ms +step:89/1660 train_time:8125ms step_avg:91.29ms +step:90/1660 train_time:8218ms step_avg:91.31ms +step:91/1660 train_time:8310ms step_avg:91.32ms +step:92/1660 train_time:8402ms step_avg:91.32ms +step:93/1660 train_time:8493ms step_avg:91.33ms +step:94/1660 train_time:8585ms step_avg:91.33ms +step:95/1660 train_time:8677ms step_avg:91.33ms +step:96/1660 train_time:8767ms step_avg:91.33ms +step:97/1660 train_time:8859ms step_avg:91.33ms +step:98/1660 train_time:8950ms step_avg:91.33ms +step:99/1660 train_time:9042ms step_avg:91.33ms +step:100/1660 train_time:9133ms step_avg:91.33ms +step:101/1660 train_time:9226ms step_avg:91.34ms +step:102/1660 train_time:9318ms step_avg:91.35ms +step:103/1660 train_time:9409ms step_avg:91.35ms +step:104/1660 train_time:9501ms step_avg:91.36ms +step:105/1660 train_time:9592ms step_avg:91.36ms +step:106/1660 train_time:9684ms step_avg:91.36ms +step:107/1660 train_time:9776ms step_avg:91.36ms +step:108/1660 train_time:9867ms step_avg:91.36ms +step:109/1660 train_time:9957ms step_avg:91.35ms +step:110/1660 train_time:10048ms step_avg:91.35ms +step:111/1660 train_time:10140ms step_avg:91.35ms +step:112/1660 train_time:10231ms step_avg:91.35ms +step:113/1660 train_time:10323ms step_avg:91.36ms +step:114/1660 train_time:10415ms step_avg:91.36ms +step:115/1660 train_time:10507ms step_avg:91.36ms +step:116/1660 train_time:10598ms step_avg:91.36ms +step:117/1660 train_time:10689ms step_avg:91.36ms +step:118/1660 train_time:10781ms step_avg:91.37ms +step:119/1660 train_time:10872ms step_avg:91.36ms +step:120/1660 train_time:10963ms step_avg:91.36ms +step:121/1660 train_time:11055ms step_avg:91.36ms +step:122/1660 train_time:11146ms 
step_avg:91.36ms +step:123/1660 train_time:11238ms step_avg:91.36ms +step:124/1660 train_time:11329ms step_avg:91.36ms +step:125/1660 train_time:11421ms step_avg:91.37ms +step:125/1660 val_loss:4.3232 train_time:11514ms step_avg:92.11ms +step:126/1660 train_time:11535ms step_avg:91.55ms +step:127/1660 train_time:11610ms step_avg:91.42ms +step:128/1660 train_time:11712ms step_avg:91.50ms +step:129/1660 train_time:11806ms step_avg:91.52ms +step:130/1660 train_time:11897ms step_avg:91.52ms +step:131/1660 train_time:11988ms step_avg:91.51ms +step:132/1660 train_time:12078ms step_avg:91.50ms +step:133/1660 train_time:12168ms step_avg:91.49ms +step:134/1660 train_time:12259ms step_avg:91.48ms +step:135/1660 train_time:12349ms step_avg:91.47ms +step:136/1660 train_time:12440ms step_avg:91.47ms +step:137/1660 train_time:12532ms step_avg:91.47ms +step:138/1660 train_time:12626ms step_avg:91.49ms +step:139/1660 train_time:12720ms step_avg:91.51ms +step:140/1660 train_time:12812ms step_avg:91.52ms +step:141/1660 train_time:12905ms step_avg:91.52ms +step:142/1660 train_time:12996ms step_avg:91.52ms +step:143/1660 train_time:13087ms step_avg:91.51ms +step:144/1660 train_time:13176ms step_avg:91.50ms +step:145/1660 train_time:13267ms step_avg:91.50ms +step:146/1660 train_time:13358ms step_avg:91.49ms +step:147/1660 train_time:13449ms step_avg:91.49ms +step:148/1660 train_time:13540ms step_avg:91.49ms +step:149/1660 train_time:13633ms step_avg:91.49ms +step:150/1660 train_time:13726ms step_avg:91.51ms +step:151/1660 train_time:13819ms step_avg:91.51ms +step:152/1660 train_time:13910ms step_avg:91.51ms +step:153/1660 train_time:14001ms step_avg:91.51ms +step:154/1660 train_time:14092ms step_avg:91.51ms +step:155/1660 train_time:14183ms step_avg:91.50ms +step:156/1660 train_time:14273ms step_avg:91.50ms +step:157/1660 train_time:14365ms step_avg:91.50ms +step:158/1660 train_time:14457ms step_avg:91.50ms +step:159/1660 train_time:14548ms step_avg:91.50ms +step:160/1660 train_time:14640ms step_avg:91.50ms +step:161/1660 train_time:14731ms step_avg:91.50ms +step:162/1660 train_time:14824ms step_avg:91.50ms +step:163/1660 train_time:14915ms step_avg:91.50ms +step:164/1660 train_time:15006ms step_avg:91.50ms +step:165/1660 train_time:15098ms step_avg:91.50ms +step:166/1660 train_time:15188ms step_avg:91.50ms +step:167/1660 train_time:15279ms step_avg:91.49ms +step:168/1660 train_time:15370ms step_avg:91.49ms +step:169/1660 train_time:15462ms step_avg:91.49ms +step:170/1660 train_time:15553ms step_avg:91.49ms +step:171/1660 train_time:15644ms step_avg:91.49ms +step:172/1660 train_time:15736ms step_avg:91.49ms +step:173/1660 train_time:15829ms step_avg:91.50ms +step:174/1660 train_time:15920ms step_avg:91.50ms +step:175/1660 train_time:16011ms step_avg:91.49ms +step:176/1660 train_time:16103ms step_avg:91.49ms +step:177/1660 train_time:16194ms step_avg:91.49ms +step:178/1660 train_time:16286ms step_avg:91.50ms +step:179/1660 train_time:16378ms step_avg:91.49ms +step:180/1660 train_time:16469ms step_avg:91.49ms +step:181/1660 train_time:16561ms step_avg:91.50ms +step:182/1660 train_time:16652ms step_avg:91.50ms +step:183/1660 train_time:16745ms step_avg:91.50ms +step:184/1660 train_time:16837ms step_avg:91.51ms +step:185/1660 train_time:16928ms step_avg:91.51ms +step:186/1660 train_time:17019ms step_avg:91.50ms +step:187/1660 train_time:17111ms step_avg:91.50ms +step:188/1660 train_time:17202ms step_avg:91.50ms +step:189/1660 train_time:17293ms step_avg:91.50ms +step:190/1660 train_time:17386ms step_avg:91.50ms 
+step:191/1660 train_time:17477ms step_avg:91.50ms +step:192/1660 train_time:17569ms step_avg:91.50ms +step:193/1660 train_time:17660ms step_avg:91.50ms +step:194/1660 train_time:17752ms step_avg:91.50ms +step:195/1660 train_time:17844ms step_avg:91.51ms +step:196/1660 train_time:17935ms step_avg:91.50ms +step:197/1660 train_time:18026ms step_avg:91.50ms +step:198/1660 train_time:18117ms step_avg:91.50ms +step:199/1660 train_time:18209ms step_avg:91.50ms +step:200/1660 train_time:18300ms step_avg:91.50ms +step:201/1660 train_time:18392ms step_avg:91.50ms +step:202/1660 train_time:18483ms step_avg:91.50ms +step:203/1660 train_time:18574ms step_avg:91.50ms +step:204/1660 train_time:18668ms step_avg:91.51ms +step:205/1660 train_time:18760ms step_avg:91.51ms +step:206/1660 train_time:18851ms step_avg:91.51ms +step:207/1660 train_time:18943ms step_avg:91.51ms +step:208/1660 train_time:19034ms step_avg:91.51ms +step:209/1660 train_time:19125ms step_avg:91.51ms +step:210/1660 train_time:19216ms step_avg:91.51ms +step:211/1660 train_time:19308ms step_avg:91.51ms +step:212/1660 train_time:19400ms step_avg:91.51ms +step:213/1660 train_time:19491ms step_avg:91.51ms +step:214/1660 train_time:19583ms step_avg:91.51ms +step:215/1660 train_time:19674ms step_avg:91.51ms +step:216/1660 train_time:19767ms step_avg:91.51ms +step:217/1660 train_time:19858ms step_avg:91.51ms +step:218/1660 train_time:19950ms step_avg:91.51ms +step:219/1660 train_time:20041ms step_avg:91.51ms +step:220/1660 train_time:20132ms step_avg:91.51ms +step:221/1660 train_time:20224ms step_avg:91.51ms +step:222/1660 train_time:20315ms step_avg:91.51ms +step:223/1660 train_time:20405ms step_avg:91.50ms +step:224/1660 train_time:20496ms step_avg:91.50ms +step:225/1660 train_time:20587ms step_avg:91.50ms +step:226/1660 train_time:20679ms step_avg:91.50ms +step:227/1660 train_time:20770ms step_avg:91.50ms +step:228/1660 train_time:20861ms step_avg:91.50ms +step:229/1660 train_time:20952ms step_avg:91.49ms +step:230/1660 train_time:21043ms step_avg:91.49ms +step:231/1660 train_time:21134ms step_avg:91.49ms +step:232/1660 train_time:21225ms step_avg:91.49ms +step:233/1660 train_time:21317ms step_avg:91.49ms +step:234/1660 train_time:21408ms step_avg:91.49ms +step:235/1660 train_time:21499ms step_avg:91.49ms +step:236/1660 train_time:21591ms step_avg:91.49ms +step:237/1660 train_time:21682ms step_avg:91.48ms +step:238/1660 train_time:21773ms step_avg:91.48ms +step:239/1660 train_time:21865ms step_avg:91.48ms +step:240/1660 train_time:21955ms step_avg:91.48ms +step:241/1660 train_time:22048ms step_avg:91.48ms +step:242/1660 train_time:22140ms step_avg:91.49ms +step:243/1660 train_time:22231ms step_avg:91.49ms +step:244/1660 train_time:22322ms step_avg:91.49ms +step:245/1660 train_time:22413ms step_avg:91.48ms +step:246/1660 train_time:22504ms step_avg:91.48ms +step:247/1660 train_time:22596ms step_avg:91.48ms +step:248/1660 train_time:22687ms step_avg:91.48ms +step:249/1660 train_time:22778ms step_avg:91.48ms +step:250/1660 train_time:22870ms step_avg:91.48ms +step:250/1660 val_loss:3.9761 train_time:22963ms step_avg:91.85ms +step:251/1660 train_time:22982ms step_avg:91.56ms +step:252/1660 train_time:23059ms step_avg:91.50ms +step:253/1660 train_time:23155ms step_avg:91.52ms +step:254/1660 train_time:23247ms step_avg:91.52ms +step:255/1660 train_time:23337ms step_avg:91.52ms +step:256/1660 train_time:23428ms step_avg:91.51ms +step:257/1660 train_time:23517ms step_avg:91.51ms +step:258/1660 train_time:23608ms step_avg:91.51ms +step:259/1660 
train_time:23698ms step_avg:91.50ms +step:260/1660 train_time:23788ms step_avg:91.49ms +step:261/1660 train_time:23880ms step_avg:91.49ms +step:262/1660 train_time:23973ms step_avg:91.50ms +step:263/1660 train_time:24066ms step_avg:91.51ms +step:264/1660 train_time:24159ms step_avg:91.51ms +step:265/1660 train_time:24251ms step_avg:91.51ms +step:266/1660 train_time:24342ms step_avg:91.51ms +step:267/1660 train_time:24434ms step_avg:91.51ms +step:268/1660 train_time:24524ms step_avg:91.51ms +step:269/1660 train_time:24616ms step_avg:91.51ms +step:270/1660 train_time:24708ms step_avg:91.51ms +step:271/1660 train_time:24798ms step_avg:91.51ms +step:272/1660 train_time:24890ms step_avg:91.51ms +step:273/1660 train_time:24981ms step_avg:91.51ms +step:274/1660 train_time:25074ms step_avg:91.51ms +step:275/1660 train_time:25166ms step_avg:91.51ms +step:276/1660 train_time:25258ms step_avg:91.51ms +step:277/1660 train_time:25350ms step_avg:91.52ms +step:278/1660 train_time:25440ms step_avg:91.51ms +step:279/1660 train_time:25531ms step_avg:91.51ms +step:280/1660 train_time:25621ms step_avg:91.51ms +step:281/1660 train_time:25713ms step_avg:91.51ms +step:282/1660 train_time:25805ms step_avg:91.51ms +step:283/1660 train_time:25896ms step_avg:91.51ms +step:284/1660 train_time:25987ms step_avg:91.50ms +step:285/1660 train_time:26079ms step_avg:91.50ms +step:286/1660 train_time:26171ms step_avg:91.51ms +step:287/1660 train_time:26262ms step_avg:91.50ms +step:288/1660 train_time:26353ms step_avg:91.50ms +step:289/1660 train_time:26444ms step_avg:91.50ms +step:290/1660 train_time:26535ms step_avg:91.50ms +step:291/1660 train_time:26626ms step_avg:91.50ms +step:292/1660 train_time:26717ms step_avg:91.50ms +step:293/1660 train_time:26808ms step_avg:91.49ms +step:294/1660 train_time:26899ms step_avg:91.49ms +step:295/1660 train_time:26991ms step_avg:91.49ms +step:296/1660 train_time:27082ms step_avg:91.49ms +step:297/1660 train_time:27174ms step_avg:91.50ms +step:298/1660 train_time:27266ms step_avg:91.50ms +step:299/1660 train_time:27358ms step_avg:91.50ms +step:300/1660 train_time:27449ms step_avg:91.50ms +step:301/1660 train_time:27540ms step_avg:91.49ms +step:302/1660 train_time:27631ms step_avg:91.49ms +step:303/1660 train_time:27721ms step_avg:91.49ms +step:304/1660 train_time:27813ms step_avg:91.49ms +step:305/1660 train_time:27905ms step_avg:91.49ms +step:306/1660 train_time:27997ms step_avg:91.49ms +step:307/1660 train_time:28088ms step_avg:91.49ms +step:308/1660 train_time:28180ms step_avg:91.49ms +step:309/1660 train_time:28271ms step_avg:91.49ms +step:310/1660 train_time:28362ms step_avg:91.49ms +step:311/1660 train_time:28453ms step_avg:91.49ms +step:312/1660 train_time:28545ms step_avg:91.49ms +step:313/1660 train_time:28636ms step_avg:91.49ms +step:314/1660 train_time:28727ms step_avg:91.49ms +step:315/1660 train_time:28818ms step_avg:91.48ms +step:316/1660 train_time:28909ms step_avg:91.49ms +step:317/1660 train_time:29001ms step_avg:91.49ms +step:318/1660 train_time:29094ms step_avg:91.49ms +step:319/1660 train_time:29185ms step_avg:91.49ms +step:320/1660 train_time:29276ms step_avg:91.49ms +step:321/1660 train_time:29368ms step_avg:91.49ms +step:322/1660 train_time:29459ms step_avg:91.49ms +step:323/1660 train_time:29550ms step_avg:91.49ms +step:324/1660 train_time:29642ms step_avg:91.49ms +step:325/1660 train_time:29733ms step_avg:91.49ms +step:326/1660 train_time:29824ms step_avg:91.49ms +step:327/1660 train_time:29916ms step_avg:91.49ms +step:328/1660 train_time:30008ms step_avg:91.49ms 
+step:329/1660 train_time:30099ms step_avg:91.49ms +step:330/1660 train_time:30190ms step_avg:91.49ms +step:331/1660 train_time:30282ms step_avg:91.49ms +step:332/1660 train_time:30373ms step_avg:91.49ms +step:333/1660 train_time:30464ms step_avg:91.48ms +step:334/1660 train_time:30555ms step_avg:91.48ms +step:335/1660 train_time:30646ms step_avg:91.48ms +step:336/1660 train_time:30738ms step_avg:91.48ms +step:337/1660 train_time:30828ms step_avg:91.48ms +step:338/1660 train_time:30919ms step_avg:91.48ms +step:339/1660 train_time:31011ms step_avg:91.48ms +step:340/1660 train_time:31103ms step_avg:91.48ms +step:341/1660 train_time:31195ms step_avg:91.48ms +step:342/1660 train_time:31287ms step_avg:91.48ms +step:343/1660 train_time:31378ms step_avg:91.48ms +step:344/1660 train_time:31470ms step_avg:91.48ms +step:345/1660 train_time:31560ms step_avg:91.48ms +step:346/1660 train_time:31651ms step_avg:91.48ms +step:347/1660 train_time:31742ms step_avg:91.48ms +step:348/1660 train_time:31834ms step_avg:91.48ms +step:349/1660 train_time:31925ms step_avg:91.47ms +step:350/1660 train_time:32016ms step_avg:91.47ms +step:351/1660 train_time:32108ms step_avg:91.48ms +step:352/1660 train_time:32199ms step_avg:91.48ms +step:353/1660 train_time:32292ms step_avg:91.48ms +step:354/1660 train_time:32384ms step_avg:91.48ms +step:355/1660 train_time:32476ms step_avg:91.48ms +step:356/1660 train_time:32567ms step_avg:91.48ms +step:357/1660 train_time:32658ms step_avg:91.48ms +step:358/1660 train_time:32749ms step_avg:91.48ms +step:359/1660 train_time:32840ms step_avg:91.48ms +step:360/1660 train_time:32931ms step_avg:91.47ms +step:361/1660 train_time:33022ms step_avg:91.47ms +step:362/1660 train_time:33115ms step_avg:91.48ms +step:363/1660 train_time:33207ms step_avg:91.48ms +step:364/1660 train_time:33298ms step_avg:91.48ms +step:365/1660 train_time:33390ms step_avg:91.48ms +step:366/1660 train_time:33482ms step_avg:91.48ms +step:367/1660 train_time:33574ms step_avg:91.48ms +step:368/1660 train_time:33665ms step_avg:91.48ms +step:369/1660 train_time:33756ms step_avg:91.48ms +step:370/1660 train_time:33846ms step_avg:91.48ms +step:371/1660 train_time:33937ms step_avg:91.48ms +step:372/1660 train_time:34028ms step_avg:91.47ms +step:373/1660 train_time:34120ms step_avg:91.47ms +step:374/1660 train_time:34212ms step_avg:91.48ms +step:375/1660 train_time:34304ms step_avg:91.48ms +step:375/1660 val_loss:3.8184 train_time:34398ms step_avg:91.73ms +step:376/1660 train_time:34417ms step_avg:91.53ms +step:377/1660 train_time:34493ms step_avg:91.49ms +step:378/1660 train_time:34590ms step_avg:91.51ms +step:379/1660 train_time:34682ms step_avg:91.51ms +step:380/1660 train_time:34773ms step_avg:91.51ms +step:381/1660 train_time:34865ms step_avg:91.51ms +step:382/1660 train_time:34956ms step_avg:91.51ms +step:383/1660 train_time:35046ms step_avg:91.50ms +step:384/1660 train_time:35136ms step_avg:91.50ms +step:385/1660 train_time:35227ms step_avg:91.50ms +step:386/1660 train_time:35318ms step_avg:91.50ms +step:387/1660 train_time:35409ms step_avg:91.50ms +step:388/1660 train_time:35502ms step_avg:91.50ms +step:389/1660 train_time:35595ms step_avg:91.50ms +step:390/1660 train_time:35688ms step_avg:91.51ms +step:391/1660 train_time:35779ms step_avg:91.51ms +step:392/1660 train_time:35870ms step_avg:91.50ms +step:393/1660 train_time:35961ms step_avg:91.50ms +step:394/1660 train_time:36052ms step_avg:91.50ms +step:395/1660 train_time:36143ms step_avg:91.50ms +step:396/1660 train_time:36233ms step_avg:91.50ms +step:397/1660 
train_time:36324ms step_avg:91.50ms +step:398/1660 train_time:36416ms step_avg:91.50ms +step:399/1660 train_time:36509ms step_avg:91.50ms +step:400/1660 train_time:36600ms step_avg:91.50ms +step:401/1660 train_time:36692ms step_avg:91.50ms +step:402/1660 train_time:36784ms step_avg:91.50ms +step:403/1660 train_time:36875ms step_avg:91.50ms +step:404/1660 train_time:36966ms step_avg:91.50ms +step:405/1660 train_time:37057ms step_avg:91.50ms +step:406/1660 train_time:37149ms step_avg:91.50ms +step:407/1660 train_time:37240ms step_avg:91.50ms +step:408/1660 train_time:37331ms step_avg:91.50ms +step:409/1660 train_time:37422ms step_avg:91.50ms +step:410/1660 train_time:37514ms step_avg:91.50ms +step:411/1660 train_time:37606ms step_avg:91.50ms +step:412/1660 train_time:37698ms step_avg:91.50ms +step:413/1660 train_time:37789ms step_avg:91.50ms +step:414/1660 train_time:37881ms step_avg:91.50ms +step:415/1660 train_time:37972ms step_avg:91.50ms +step:416/1660 train_time:38063ms step_avg:91.50ms +step:417/1660 train_time:38154ms step_avg:91.50ms +step:418/1660 train_time:38247ms step_avg:91.50ms +step:419/1660 train_time:38339ms step_avg:91.50ms +step:420/1660 train_time:38430ms step_avg:91.50ms +step:421/1660 train_time:38522ms step_avg:91.50ms +step:422/1660 train_time:38613ms step_avg:91.50ms +step:423/1660 train_time:38705ms step_avg:91.50ms +step:424/1660 train_time:38796ms step_avg:91.50ms +step:425/1660 train_time:38887ms step_avg:91.50ms +step:426/1660 train_time:38979ms step_avg:91.50ms +step:427/1660 train_time:39071ms step_avg:91.50ms +step:428/1660 train_time:39162ms step_avg:91.50ms +step:429/1660 train_time:39252ms step_avg:91.50ms +step:430/1660 train_time:39344ms step_avg:91.50ms +step:431/1660 train_time:39436ms step_avg:91.50ms +step:432/1660 train_time:39528ms step_avg:91.50ms +step:433/1660 train_time:39619ms step_avg:91.50ms +step:434/1660 train_time:39710ms step_avg:91.50ms +step:435/1660 train_time:39802ms step_avg:91.50ms +step:436/1660 train_time:39893ms step_avg:91.50ms +step:437/1660 train_time:39986ms step_avg:91.50ms +step:438/1660 train_time:40077ms step_avg:91.50ms +step:439/1660 train_time:40168ms step_avg:91.50ms +step:440/1660 train_time:40259ms step_avg:91.50ms +step:441/1660 train_time:40350ms step_avg:91.50ms +step:442/1660 train_time:40442ms step_avg:91.50ms +step:443/1660 train_time:40534ms step_avg:91.50ms +step:444/1660 train_time:40625ms step_avg:91.50ms +step:445/1660 train_time:40717ms step_avg:91.50ms +step:446/1660 train_time:40808ms step_avg:91.50ms +step:447/1660 train_time:40899ms step_avg:91.50ms +step:448/1660 train_time:40991ms step_avg:91.50ms +step:449/1660 train_time:41082ms step_avg:91.50ms +step:450/1660 train_time:41173ms step_avg:91.50ms +step:451/1660 train_time:41266ms step_avg:91.50ms +step:452/1660 train_time:41358ms step_avg:91.50ms +step:453/1660 train_time:41449ms step_avg:91.50ms +step:454/1660 train_time:41540ms step_avg:91.50ms +step:455/1660 train_time:41632ms step_avg:91.50ms +step:456/1660 train_time:41724ms step_avg:91.50ms +step:457/1660 train_time:41815ms step_avg:91.50ms +step:458/1660 train_time:41907ms step_avg:91.50ms +step:459/1660 train_time:41999ms step_avg:91.50ms +step:460/1660 train_time:42090ms step_avg:91.50ms +step:461/1660 train_time:42182ms step_avg:91.50ms +step:462/1660 train_time:42272ms step_avg:91.50ms +step:463/1660 train_time:42364ms step_avg:91.50ms +step:464/1660 train_time:42457ms step_avg:91.50ms +step:465/1660 train_time:42548ms step_avg:91.50ms +step:466/1660 train_time:42640ms step_avg:91.50ms 
+step:467/1660 train_time:42731ms step_avg:91.50ms +step:468/1660 train_time:42823ms step_avg:91.50ms +step:469/1660 train_time:42914ms step_avg:91.50ms +step:470/1660 train_time:43006ms step_avg:91.50ms +step:471/1660 train_time:43097ms step_avg:91.50ms +step:472/1660 train_time:43188ms step_avg:91.50ms +step:473/1660 train_time:43279ms step_avg:91.50ms +step:474/1660 train_time:43370ms step_avg:91.50ms +step:475/1660 train_time:43463ms step_avg:91.50ms +step:476/1660 train_time:43555ms step_avg:91.50ms +step:477/1660 train_time:43647ms step_avg:91.50ms +step:478/1660 train_time:43738ms step_avg:91.50ms +step:479/1660 train_time:43829ms step_avg:91.50ms +step:480/1660 train_time:43920ms step_avg:91.50ms +step:481/1660 train_time:44012ms step_avg:91.50ms +step:482/1660 train_time:44103ms step_avg:91.50ms +step:483/1660 train_time:44195ms step_avg:91.50ms +step:484/1660 train_time:44286ms step_avg:91.50ms +step:485/1660 train_time:44378ms step_avg:91.50ms +step:486/1660 train_time:44469ms step_avg:91.50ms +step:487/1660 train_time:44561ms step_avg:91.50ms +step:488/1660 train_time:44653ms step_avg:91.50ms +step:489/1660 train_time:44743ms step_avg:91.50ms +step:490/1660 train_time:44835ms step_avg:91.50ms +step:491/1660 train_time:44927ms step_avg:91.50ms +step:492/1660 train_time:45018ms step_avg:91.50ms +step:493/1660 train_time:45110ms step_avg:91.50ms +step:494/1660 train_time:45200ms step_avg:91.50ms +step:495/1660 train_time:45292ms step_avg:91.50ms +step:496/1660 train_time:45383ms step_avg:91.50ms +step:497/1660 train_time:45474ms step_avg:91.50ms +step:498/1660 train_time:45567ms step_avg:91.50ms +step:499/1660 train_time:45658ms step_avg:91.50ms +step:500/1660 train_time:45749ms step_avg:91.50ms +step:500/1660 val_loss:3.7164 train_time:45842ms step_avg:91.68ms +step:501/1660 train_time:45862ms step_avg:91.54ms +step:502/1660 train_time:45937ms step_avg:91.51ms +step:503/1660 train_time:46033ms step_avg:91.52ms +step:504/1660 train_time:46125ms step_avg:91.52ms +step:505/1660 train_time:46216ms step_avg:91.52ms +step:506/1660 train_time:46308ms step_avg:91.52ms +step:507/1660 train_time:46398ms step_avg:91.52ms +step:508/1660 train_time:46488ms step_avg:91.51ms +step:509/1660 train_time:46579ms step_avg:91.51ms +step:510/1660 train_time:46669ms step_avg:91.51ms +step:511/1660 train_time:46760ms step_avg:91.51ms +step:512/1660 train_time:46853ms step_avg:91.51ms +step:513/1660 train_time:46945ms step_avg:91.51ms +step:514/1660 train_time:47039ms step_avg:91.52ms +step:515/1660 train_time:47131ms step_avg:91.52ms +step:516/1660 train_time:47223ms step_avg:91.52ms +step:517/1660 train_time:47314ms step_avg:91.52ms +step:518/1660 train_time:47405ms step_avg:91.51ms +step:519/1660 train_time:47495ms step_avg:91.51ms +step:520/1660 train_time:47585ms step_avg:91.51ms +step:521/1660 train_time:47675ms step_avg:91.51ms +step:522/1660 train_time:47766ms step_avg:91.51ms +step:523/1660 train_time:47858ms step_avg:91.51ms +step:524/1660 train_time:47950ms step_avg:91.51ms +step:525/1660 train_time:48043ms step_avg:91.51ms +step:526/1660 train_time:48134ms step_avg:91.51ms +step:527/1660 train_time:48226ms step_avg:91.51ms +step:528/1660 train_time:48318ms step_avg:91.51ms +step:529/1660 train_time:48408ms step_avg:91.51ms +step:530/1660 train_time:48499ms step_avg:91.51ms +step:531/1660 train_time:48589ms step_avg:91.51ms +step:532/1660 train_time:48680ms step_avg:91.50ms +step:533/1660 train_time:48770ms step_avg:91.50ms +step:534/1660 train_time:48862ms step_avg:91.50ms +step:535/1660 
train_time:48953ms step_avg:91.50ms +step:536/1660 train_time:49044ms step_avg:91.50ms +step:537/1660 train_time:49136ms step_avg:91.50ms +step:538/1660 train_time:49227ms step_avg:91.50ms +step:539/1660 train_time:49319ms step_avg:91.50ms +step:540/1660 train_time:49410ms step_avg:91.50ms +step:541/1660 train_time:49501ms step_avg:91.50ms +step:542/1660 train_time:49593ms step_avg:91.50ms +step:543/1660 train_time:49683ms step_avg:91.50ms +step:544/1660 train_time:49773ms step_avg:91.50ms +step:545/1660 train_time:49864ms step_avg:91.49ms +step:546/1660 train_time:49956ms step_avg:91.50ms +step:547/1660 train_time:50047ms step_avg:91.49ms +step:548/1660 train_time:50139ms step_avg:91.49ms +step:549/1660 train_time:50231ms step_avg:91.49ms +step:550/1660 train_time:50322ms step_avg:91.49ms +step:551/1660 train_time:50413ms step_avg:91.49ms +step:552/1660 train_time:50504ms step_avg:91.49ms +step:553/1660 train_time:50596ms step_avg:91.49ms +step:554/1660 train_time:50687ms step_avg:91.49ms +step:555/1660 train_time:50778ms step_avg:91.49ms +step:556/1660 train_time:50871ms step_avg:91.49ms +step:557/1660 train_time:50964ms step_avg:91.50ms +step:558/1660 train_time:51056ms step_avg:91.50ms +step:559/1660 train_time:51149ms step_avg:91.50ms +step:560/1660 train_time:51241ms step_avg:91.50ms +step:561/1660 train_time:51334ms step_avg:91.50ms +step:562/1660 train_time:51426ms step_avg:91.51ms +step:563/1660 train_time:51521ms step_avg:91.51ms +step:564/1660 train_time:51614ms step_avg:91.51ms +step:565/1660 train_time:51706ms step_avg:91.51ms +step:566/1660 train_time:51799ms step_avg:91.52ms +step:567/1660 train_time:51892ms step_avg:91.52ms +step:568/1660 train_time:51984ms step_avg:91.52ms +step:569/1660 train_time:52076ms step_avg:91.52ms +step:570/1660 train_time:52168ms step_avg:91.52ms +step:571/1660 train_time:52261ms step_avg:91.52ms +step:572/1660 train_time:52353ms step_avg:91.53ms +step:573/1660 train_time:52446ms step_avg:91.53ms +step:574/1660 train_time:52539ms step_avg:91.53ms +step:575/1660 train_time:52632ms step_avg:91.53ms +step:576/1660 train_time:52725ms step_avg:91.54ms +step:577/1660 train_time:52819ms step_avg:91.54ms +step:578/1660 train_time:52911ms step_avg:91.54ms +step:579/1660 train_time:53003ms step_avg:91.54ms +step:580/1660 train_time:53095ms step_avg:91.54ms +step:581/1660 train_time:53189ms step_avg:91.55ms +step:582/1660 train_time:53281ms step_avg:91.55ms +step:583/1660 train_time:53373ms step_avg:91.55ms +step:584/1660 train_time:53465ms step_avg:91.55ms +step:585/1660 train_time:53559ms step_avg:91.55ms +step:586/1660 train_time:53652ms step_avg:91.56ms +step:587/1660 train_time:53745ms step_avg:91.56ms +step:588/1660 train_time:53838ms step_avg:91.56ms +step:589/1660 train_time:53931ms step_avg:91.56ms +step:590/1660 train_time:54023ms step_avg:91.56ms +step:591/1660 train_time:54115ms step_avg:91.57ms +step:592/1660 train_time:54207ms step_avg:91.57ms +step:593/1660 train_time:54299ms step_avg:91.57ms +step:594/1660 train_time:54392ms step_avg:91.57ms +step:595/1660 train_time:54485ms step_avg:91.57ms +step:596/1660 train_time:54578ms step_avg:91.57ms +step:597/1660 train_time:54670ms step_avg:91.57ms +step:598/1660 train_time:54763ms step_avg:91.58ms +step:599/1660 train_time:54855ms step_avg:91.58ms +step:600/1660 train_time:54947ms step_avg:91.58ms +step:601/1660 train_time:55040ms step_avg:91.58ms +step:602/1660 train_time:55133ms step_avg:91.58ms +step:603/1660 train_time:55226ms step_avg:91.58ms +step:604/1660 train_time:55318ms step_avg:91.59ms 
+step:605/1660 train_time:55411ms step_avg:91.59ms +step:606/1660 train_time:55503ms step_avg:91.59ms +step:607/1660 train_time:55595ms step_avg:91.59ms +step:608/1660 train_time:55687ms step_avg:91.59ms +step:609/1660 train_time:55781ms step_avg:91.59ms +step:610/1660 train_time:55873ms step_avg:91.59ms +step:611/1660 train_time:55965ms step_avg:91.60ms +step:612/1660 train_time:56057ms step_avg:91.60ms +step:613/1660 train_time:56150ms step_avg:91.60ms +step:614/1660 train_time:56242ms step_avg:91.60ms +step:615/1660 train_time:56334ms step_avg:91.60ms +step:616/1660 train_time:56426ms step_avg:91.60ms +step:617/1660 train_time:56519ms step_avg:91.60ms +step:618/1660 train_time:56612ms step_avg:91.60ms +step:619/1660 train_time:56704ms step_avg:91.61ms +step:620/1660 train_time:56797ms step_avg:91.61ms +step:621/1660 train_time:56889ms step_avg:91.61ms +step:622/1660 train_time:56982ms step_avg:91.61ms +step:623/1660 train_time:57075ms step_avg:91.61ms +step:624/1660 train_time:57167ms step_avg:91.61ms +step:625/1660 train_time:57259ms step_avg:91.61ms +step:625/1660 val_loss:3.6137 train_time:57354ms step_avg:91.77ms +step:626/1660 train_time:57373ms step_avg:91.65ms +step:627/1660 train_time:57454ms step_avg:91.63ms +step:628/1660 train_time:57553ms step_avg:91.64ms +step:629/1660 train_time:57647ms step_avg:91.65ms +step:630/1660 train_time:57739ms step_avg:91.65ms +step:631/1660 train_time:57829ms step_avg:91.65ms +step:632/1660 train_time:57920ms step_avg:91.65ms +step:633/1660 train_time:58012ms step_avg:91.65ms +step:634/1660 train_time:58103ms step_avg:91.65ms +step:635/1660 train_time:58194ms step_avg:91.64ms +step:636/1660 train_time:58287ms step_avg:91.65ms +step:637/1660 train_time:58384ms step_avg:91.65ms +step:638/1660 train_time:58483ms step_avg:91.67ms +step:639/1660 train_time:58579ms step_avg:91.67ms +step:640/1660 train_time:58673ms step_avg:91.68ms +step:641/1660 train_time:58765ms step_avg:91.68ms +step:642/1660 train_time:58857ms step_avg:91.68ms +step:643/1660 train_time:58948ms step_avg:91.68ms +step:644/1660 train_time:59040ms step_avg:91.68ms +step:645/1660 train_time:59132ms step_avg:91.68ms +step:646/1660 train_time:59223ms step_avg:91.68ms +step:647/1660 train_time:59316ms step_avg:91.68ms +step:648/1660 train_time:59409ms step_avg:91.68ms +step:649/1660 train_time:59504ms step_avg:91.69ms +step:650/1660 train_time:59599ms step_avg:91.69ms +step:651/1660 train_time:59692ms step_avg:91.69ms +step:652/1660 train_time:59785ms step_avg:91.69ms +step:653/1660 train_time:59877ms step_avg:91.70ms +step:654/1660 train_time:59970ms step_avg:91.70ms +step:655/1660 train_time:60061ms step_avg:91.70ms +step:656/1660 train_time:60152ms step_avg:91.70ms +step:657/1660 train_time:60244ms step_avg:91.70ms +step:658/1660 train_time:60337ms step_avg:91.70ms +step:659/1660 train_time:60429ms step_avg:91.70ms +step:660/1660 train_time:60523ms step_avg:91.70ms +step:661/1660 train_time:60617ms step_avg:91.70ms +step:662/1660 train_time:60709ms step_avg:91.71ms +step:663/1660 train_time:60802ms step_avg:91.71ms +step:664/1660 train_time:60895ms step_avg:91.71ms +step:665/1660 train_time:60987ms step_avg:91.71ms +step:666/1660 train_time:61079ms step_avg:91.71ms +step:667/1660 train_time:61170ms step_avg:91.71ms +step:668/1660 train_time:61262ms step_avg:91.71ms +step:669/1660 train_time:61355ms step_avg:91.71ms +step:670/1660 train_time:61447ms step_avg:91.71ms +step:671/1660 train_time:61541ms step_avg:91.71ms +step:672/1660 train_time:61633ms step_avg:91.72ms +step:673/1660 
train_time:61725ms step_avg:91.72ms +step:674/1660 train_time:61819ms step_avg:91.72ms +step:675/1660 train_time:61911ms step_avg:91.72ms +step:676/1660 train_time:62003ms step_avg:91.72ms +step:677/1660 train_time:62096ms step_avg:91.72ms +step:678/1660 train_time:62187ms step_avg:91.72ms +step:679/1660 train_time:62280ms step_avg:91.72ms +step:680/1660 train_time:62373ms step_avg:91.72ms +step:681/1660 train_time:62466ms step_avg:91.73ms +step:682/1660 train_time:62560ms step_avg:91.73ms +step:683/1660 train_time:62653ms step_avg:91.73ms +step:684/1660 train_time:62745ms step_avg:91.73ms +step:685/1660 train_time:62840ms step_avg:91.74ms +step:686/1660 train_time:62932ms step_avg:91.74ms +step:687/1660 train_time:63023ms step_avg:91.74ms +step:688/1660 train_time:63116ms step_avg:91.74ms +step:689/1660 train_time:63208ms step_avg:91.74ms +step:690/1660 train_time:63301ms step_avg:91.74ms +step:691/1660 train_time:63394ms step_avg:91.74ms +step:692/1660 train_time:63487ms step_avg:91.74ms +step:693/1660 train_time:63581ms step_avg:91.75ms +step:694/1660 train_time:63674ms step_avg:91.75ms +step:695/1660 train_time:63767ms step_avg:91.75ms +step:696/1660 train_time:63861ms step_avg:91.75ms +step:697/1660 train_time:63953ms step_avg:91.75ms +step:698/1660 train_time:64045ms step_avg:91.75ms +step:699/1660 train_time:64137ms step_avg:91.76ms +step:700/1660 train_time:64229ms step_avg:91.76ms +step:701/1660 train_time:64322ms step_avg:91.76ms +step:702/1660 train_time:64414ms step_avg:91.76ms +step:703/1660 train_time:64508ms step_avg:91.76ms +step:704/1660 train_time:64602ms step_avg:91.76ms +step:705/1660 train_time:64695ms step_avg:91.77ms +step:706/1660 train_time:64787ms step_avg:91.77ms +step:707/1660 train_time:64881ms step_avg:91.77ms +step:708/1660 train_time:64974ms step_avg:91.77ms +step:709/1660 train_time:65066ms step_avg:91.77ms +step:710/1660 train_time:65159ms step_avg:91.77ms +step:711/1660 train_time:65251ms step_avg:91.77ms +step:712/1660 train_time:65343ms step_avg:91.77ms +step:713/1660 train_time:65435ms step_avg:91.77ms +step:714/1660 train_time:65528ms step_avg:91.78ms +step:715/1660 train_time:65621ms step_avg:91.78ms +step:716/1660 train_time:65714ms step_avg:91.78ms +step:717/1660 train_time:65807ms step_avg:91.78ms +step:718/1660 train_time:65900ms step_avg:91.78ms +step:719/1660 train_time:65993ms step_avg:91.78ms +step:720/1660 train_time:66085ms step_avg:91.78ms +step:721/1660 train_time:66177ms step_avg:91.79ms +step:722/1660 train_time:66270ms step_avg:91.79ms +step:723/1660 train_time:66363ms step_avg:91.79ms +step:724/1660 train_time:66455ms step_avg:91.79ms +step:725/1660 train_time:66548ms step_avg:91.79ms +step:726/1660 train_time:66640ms step_avg:91.79ms +step:727/1660 train_time:66734ms step_avg:91.79ms +step:728/1660 train_time:66827ms step_avg:91.79ms +step:729/1660 train_time:66920ms step_avg:91.80ms +step:730/1660 train_time:67013ms step_avg:91.80ms +step:731/1660 train_time:67105ms step_avg:91.80ms +step:732/1660 train_time:67198ms step_avg:91.80ms +step:733/1660 train_time:67290ms step_avg:91.80ms +step:734/1660 train_time:67382ms step_avg:91.80ms +step:735/1660 train_time:67475ms step_avg:91.80ms +step:736/1660 train_time:67567ms step_avg:91.80ms +step:737/1660 train_time:67660ms step_avg:91.80ms +step:738/1660 train_time:67753ms step_avg:91.81ms +step:739/1660 train_time:67845ms step_avg:91.81ms +step:740/1660 train_time:67938ms step_avg:91.81ms +step:741/1660 train_time:68030ms step_avg:91.81ms +step:742/1660 train_time:68123ms step_avg:91.81ms 
+step:743/1660 train_time:68215ms step_avg:91.81ms +step:744/1660 train_time:68307ms step_avg:91.81ms +step:745/1660 train_time:68401ms step_avg:91.81ms +step:746/1660 train_time:68494ms step_avg:91.81ms +step:747/1660 train_time:68586ms step_avg:91.81ms +step:748/1660 train_time:68679ms step_avg:91.82ms +step:749/1660 train_time:68772ms step_avg:91.82ms +step:750/1660 train_time:68865ms step_avg:91.82ms +step:750/1660 val_loss:3.5616 train_time:68960ms step_avg:91.95ms +step:751/1660 train_time:68979ms step_avg:91.85ms +step:752/1660 train_time:69055ms step_avg:91.83ms +step:753/1660 train_time:69152ms step_avg:91.84ms +step:754/1660 train_time:69245ms step_avg:91.84ms +step:755/1660 train_time:69337ms step_avg:91.84ms +step:756/1660 train_time:69427ms step_avg:91.84ms +step:757/1660 train_time:69519ms step_avg:91.83ms +step:758/1660 train_time:69611ms step_avg:91.83ms +step:759/1660 train_time:69702ms step_avg:91.83ms +step:760/1660 train_time:69793ms step_avg:91.83ms +step:761/1660 train_time:69886ms step_avg:91.83ms +step:762/1660 train_time:69981ms step_avg:91.84ms +step:763/1660 train_time:70077ms step_avg:91.84ms +step:764/1660 train_time:70171ms step_avg:91.85ms +step:765/1660 train_time:70263ms step_avg:91.85ms +step:766/1660 train_time:70357ms step_avg:91.85ms +step:767/1660 train_time:70448ms step_avg:91.85ms +step:768/1660 train_time:70540ms step_avg:91.85ms +step:769/1660 train_time:70633ms step_avg:91.85ms +step:770/1660 train_time:70724ms step_avg:91.85ms +step:771/1660 train_time:70816ms step_avg:91.85ms +step:772/1660 train_time:70909ms step_avg:91.85ms +step:773/1660 train_time:71003ms step_avg:91.85ms +step:774/1660 train_time:71097ms step_avg:91.86ms +step:775/1660 train_time:71190ms step_avg:91.86ms +step:776/1660 train_time:71283ms step_avg:91.86ms +step:777/1660 train_time:71375ms step_avg:91.86ms +step:778/1660 train_time:71467ms step_avg:91.86ms +step:779/1660 train_time:71559ms step_avg:91.86ms +step:780/1660 train_time:71650ms step_avg:91.86ms +step:781/1660 train_time:71742ms step_avg:91.86ms +step:782/1660 train_time:71835ms step_avg:91.86ms +step:783/1660 train_time:71927ms step_avg:91.86ms +step:784/1660 train_time:72022ms step_avg:91.86ms +step:785/1660 train_time:72114ms step_avg:91.87ms +step:786/1660 train_time:72207ms step_avg:91.87ms +step:787/1660 train_time:72300ms step_avg:91.87ms +step:788/1660 train_time:72393ms step_avg:91.87ms +step:789/1660 train_time:72485ms step_avg:91.87ms +step:790/1660 train_time:72578ms step_avg:91.87ms +step:791/1660 train_time:72670ms step_avg:91.87ms +step:792/1660 train_time:72763ms step_avg:91.87ms +step:793/1660 train_time:72854ms step_avg:91.87ms +step:794/1660 train_time:72946ms step_avg:91.87ms +step:795/1660 train_time:73040ms step_avg:91.87ms +step:796/1660 train_time:73133ms step_avg:91.88ms +step:797/1660 train_time:73225ms step_avg:91.88ms +step:798/1660 train_time:73319ms step_avg:91.88ms +step:799/1660 train_time:73412ms step_avg:91.88ms +step:800/1660 train_time:73504ms step_avg:91.88ms +step:801/1660 train_time:73596ms step_avg:91.88ms +step:802/1660 train_time:73688ms step_avg:91.88ms +step:803/1660 train_time:73780ms step_avg:91.88ms +step:804/1660 train_time:73873ms step_avg:91.88ms +step:805/1660 train_time:73966ms step_avg:91.88ms +step:806/1660 train_time:74059ms step_avg:91.88ms +step:807/1660 train_time:74152ms step_avg:91.89ms +step:808/1660 train_time:74245ms step_avg:91.89ms +step:809/1660 train_time:74338ms step_avg:91.89ms +step:810/1660 train_time:74431ms step_avg:91.89ms +step:811/1660 
train_time:74523ms step_avg:91.89ms +step:812/1660 train_time:74615ms step_avg:91.89ms +step:813/1660 train_time:74707ms step_avg:91.89ms +step:814/1660 train_time:74800ms step_avg:91.89ms +step:815/1660 train_time:74892ms step_avg:91.89ms +step:816/1660 train_time:74984ms step_avg:91.89ms +step:817/1660 train_time:75077ms step_avg:91.89ms +step:818/1660 train_time:75170ms step_avg:91.89ms +step:819/1660 train_time:75263ms step_avg:91.90ms +step:820/1660 train_time:75355ms step_avg:91.90ms +step:821/1660 train_time:75447ms step_avg:91.90ms +step:822/1660 train_time:75539ms step_avg:91.90ms +step:823/1660 train_time:75633ms step_avg:91.90ms +step:824/1660 train_time:75725ms step_avg:91.90ms +step:825/1660 train_time:75818ms step_avg:91.90ms +step:826/1660 train_time:75910ms step_avg:91.90ms +step:827/1660 train_time:76002ms step_avg:91.90ms +step:828/1660 train_time:76095ms step_avg:91.90ms +step:829/1660 train_time:76187ms step_avg:91.90ms +step:830/1660 train_time:76280ms step_avg:91.90ms +step:831/1660 train_time:76372ms step_avg:91.90ms +step:832/1660 train_time:76464ms step_avg:91.90ms +step:833/1660 train_time:76556ms step_avg:91.90ms +step:834/1660 train_time:76649ms step_avg:91.90ms +step:835/1660 train_time:76741ms step_avg:91.91ms +step:836/1660 train_time:76834ms step_avg:91.91ms +step:837/1660 train_time:76926ms step_avg:91.91ms +step:838/1660 train_time:77019ms step_avg:91.91ms +step:839/1660 train_time:77111ms step_avg:91.91ms +step:840/1660 train_time:77204ms step_avg:91.91ms +step:841/1660 train_time:77297ms step_avg:91.91ms +step:842/1660 train_time:77389ms step_avg:91.91ms +step:843/1660 train_time:77482ms step_avg:91.91ms +step:844/1660 train_time:77574ms step_avg:91.91ms +step:845/1660 train_time:77667ms step_avg:91.91ms +step:846/1660 train_time:77759ms step_avg:91.91ms +step:847/1660 train_time:77852ms step_avg:91.92ms +step:848/1660 train_time:77943ms step_avg:91.91ms +step:849/1660 train_time:78037ms step_avg:91.92ms +step:850/1660 train_time:78130ms step_avg:91.92ms +step:851/1660 train_time:78223ms step_avg:91.92ms +step:852/1660 train_time:78315ms step_avg:91.92ms +step:853/1660 train_time:78408ms step_avg:91.92ms +step:854/1660 train_time:78501ms step_avg:91.92ms +step:855/1660 train_time:78593ms step_avg:91.92ms +step:856/1660 train_time:78685ms step_avg:91.92ms +step:857/1660 train_time:78777ms step_avg:91.92ms +step:858/1660 train_time:78869ms step_avg:91.92ms +step:859/1660 train_time:78962ms step_avg:91.92ms +step:860/1660 train_time:79054ms step_avg:91.92ms +step:861/1660 train_time:79146ms step_avg:91.92ms +step:862/1660 train_time:79240ms step_avg:91.93ms +step:863/1660 train_time:79333ms step_avg:91.93ms +step:864/1660 train_time:79425ms step_avg:91.93ms +step:865/1660 train_time:79519ms step_avg:91.93ms +step:866/1660 train_time:79612ms step_avg:91.93ms +step:867/1660 train_time:79704ms step_avg:91.93ms +step:868/1660 train_time:79796ms step_avg:91.93ms +step:869/1660 train_time:79889ms step_avg:91.93ms +step:870/1660 train_time:79982ms step_avg:91.93ms +step:871/1660 train_time:80075ms step_avg:91.93ms +step:872/1660 train_time:80168ms step_avg:91.94ms +step:873/1660 train_time:80260ms step_avg:91.94ms +step:874/1660 train_time:80353ms step_avg:91.94ms +step:875/1660 train_time:80445ms step_avg:91.94ms +step:875/1660 val_loss:3.5172 train_time:80540ms step_avg:92.05ms +step:876/1660 train_time:80559ms step_avg:91.96ms +step:877/1660 train_time:80640ms step_avg:91.95ms +step:878/1660 train_time:80739ms step_avg:91.96ms +step:879/1660 train_time:80832ms 
step_avg:91.96ms +step:880/1660 train_time:80923ms step_avg:91.96ms +step:881/1660 train_time:81014ms step_avg:91.96ms +step:882/1660 train_time:81106ms step_avg:91.96ms +step:883/1660 train_time:81197ms step_avg:91.96ms +step:884/1660 train_time:81288ms step_avg:91.96ms +step:885/1660 train_time:81380ms step_avg:91.96ms +step:886/1660 train_time:81473ms step_avg:91.96ms +step:887/1660 train_time:81568ms step_avg:91.96ms +step:888/1660 train_time:81664ms step_avg:91.96ms +step:889/1660 train_time:81759ms step_avg:91.97ms +step:890/1660 train_time:81853ms step_avg:91.97ms +step:891/1660 train_time:81946ms step_avg:91.97ms +step:892/1660 train_time:82039ms step_avg:91.97ms +step:893/1660 train_time:82131ms step_avg:91.97ms +step:894/1660 train_time:82222ms step_avg:91.97ms +step:895/1660 train_time:82314ms step_avg:91.97ms +step:896/1660 train_time:82405ms step_avg:91.97ms +step:897/1660 train_time:82499ms step_avg:91.97ms +step:898/1660 train_time:82593ms step_avg:91.97ms +step:899/1660 train_time:82686ms step_avg:91.98ms +step:900/1660 train_time:82781ms step_avg:91.98ms +step:901/1660 train_time:82876ms step_avg:91.98ms +step:902/1660 train_time:82968ms step_avg:91.98ms +step:903/1660 train_time:83060ms step_avg:91.98ms +step:904/1660 train_time:83152ms step_avg:91.98ms +step:905/1660 train_time:83243ms step_avg:91.98ms +step:906/1660 train_time:83336ms step_avg:91.98ms +step:907/1660 train_time:83428ms step_avg:91.98ms +step:908/1660 train_time:83521ms step_avg:91.98ms +step:909/1660 train_time:83614ms step_avg:91.98ms +step:910/1660 train_time:83707ms step_avg:91.99ms +step:911/1660 train_time:83801ms step_avg:91.99ms +step:912/1660 train_time:83895ms step_avg:91.99ms +step:913/1660 train_time:83987ms step_avg:91.99ms +step:914/1660 train_time:84080ms step_avg:91.99ms +step:915/1660 train_time:84172ms step_avg:91.99ms +step:916/1660 train_time:84264ms step_avg:91.99ms +step:917/1660 train_time:84356ms step_avg:91.99ms +step:918/1660 train_time:84450ms step_avg:91.99ms +step:919/1660 train_time:84542ms step_avg:91.99ms +step:920/1660 train_time:84635ms step_avg:92.00ms +step:921/1660 train_time:84728ms step_avg:92.00ms +step:922/1660 train_time:84822ms step_avg:92.00ms +step:923/1660 train_time:84916ms step_avg:92.00ms +step:924/1660 train_time:85008ms step_avg:92.00ms +step:925/1660 train_time:85102ms step_avg:92.00ms +step:926/1660 train_time:85194ms step_avg:92.00ms +step:927/1660 train_time:85286ms step_avg:92.00ms +step:928/1660 train_time:85379ms step_avg:92.00ms +step:929/1660 train_time:85472ms step_avg:92.00ms +step:930/1660 train_time:85564ms step_avg:92.00ms +step:931/1660 train_time:85657ms step_avg:92.00ms +step:932/1660 train_time:85749ms step_avg:92.01ms +step:933/1660 train_time:85843ms step_avg:92.01ms +step:934/1660 train_time:85937ms step_avg:92.01ms +step:935/1660 train_time:86030ms step_avg:92.01ms +step:936/1660 train_time:86122ms step_avg:92.01ms +step:937/1660 train_time:86215ms step_avg:92.01ms +step:938/1660 train_time:86307ms step_avg:92.01ms +step:939/1660 train_time:86399ms step_avg:92.01ms +step:940/1660 train_time:86492ms step_avg:92.01ms +step:941/1660 train_time:86584ms step_avg:92.01ms +step:942/1660 train_time:86676ms step_avg:92.01ms +step:943/1660 train_time:86769ms step_avg:92.01ms +step:944/1660 train_time:86862ms step_avg:92.01ms +step:945/1660 train_time:86955ms step_avg:92.02ms +step:946/1660 train_time:87048ms step_avg:92.02ms +step:947/1660 train_time:87140ms step_avg:92.02ms +step:948/1660 train_time:87233ms step_avg:92.02ms +step:949/1660 
train_time:87325ms step_avg:92.02ms +step:950/1660 train_time:87417ms step_avg:92.02ms +step:951/1660 train_time:87510ms step_avg:92.02ms +step:952/1660 train_time:87603ms step_avg:92.02ms +step:953/1660 train_time:87696ms step_avg:92.02ms +step:954/1660 train_time:87788ms step_avg:92.02ms +step:955/1660 train_time:87882ms step_avg:92.02ms +step:956/1660 train_time:87976ms step_avg:92.02ms +step:957/1660 train_time:88068ms step_avg:92.03ms +step:958/1660 train_time:88161ms step_avg:92.03ms +step:959/1660 train_time:88254ms step_avg:92.03ms +step:960/1660 train_time:88347ms step_avg:92.03ms +step:961/1660 train_time:88440ms step_avg:92.03ms +step:962/1660 train_time:88533ms step_avg:92.03ms +step:963/1660 train_time:88625ms step_avg:92.03ms +step:964/1660 train_time:88717ms step_avg:92.03ms +step:965/1660 train_time:88810ms step_avg:92.03ms +step:966/1660 train_time:88903ms step_avg:92.03ms +step:967/1660 train_time:88998ms step_avg:92.03ms +step:968/1660 train_time:89090ms step_avg:92.04ms +step:969/1660 train_time:89183ms step_avg:92.04ms +step:970/1660 train_time:89277ms step_avg:92.04ms +step:971/1660 train_time:89370ms step_avg:92.04ms +step:972/1660 train_time:89462ms step_avg:92.04ms +step:973/1660 train_time:89555ms step_avg:92.04ms +step:974/1660 train_time:89647ms step_avg:92.04ms +step:975/1660 train_time:89740ms step_avg:92.04ms +step:976/1660 train_time:89833ms step_avg:92.04ms +step:977/1660 train_time:89926ms step_avg:92.04ms +step:978/1660 train_time:90019ms step_avg:92.04ms +step:979/1660 train_time:90111ms step_avg:92.04ms +step:980/1660 train_time:90204ms step_avg:92.05ms +step:981/1660 train_time:90298ms step_avg:92.05ms +step:982/1660 train_time:90390ms step_avg:92.05ms +step:983/1660 train_time:90482ms step_avg:92.05ms +step:984/1660 train_time:90576ms step_avg:92.05ms +step:985/1660 train_time:90669ms step_avg:92.05ms +step:986/1660 train_time:90762ms step_avg:92.05ms +step:987/1660 train_time:90855ms step_avg:92.05ms +step:988/1660 train_time:90948ms step_avg:92.05ms +step:989/1660 train_time:91041ms step_avg:92.05ms +step:990/1660 train_time:91133ms step_avg:92.05ms +step:991/1660 train_time:91225ms step_avg:92.05ms +step:992/1660 train_time:91319ms step_avg:92.06ms +step:993/1660 train_time:91411ms step_avg:92.06ms +step:994/1660 train_time:91504ms step_avg:92.06ms +step:995/1660 train_time:91597ms step_avg:92.06ms +step:996/1660 train_time:91689ms step_avg:92.06ms +step:997/1660 train_time:91782ms step_avg:92.06ms +step:998/1660 train_time:91876ms step_avg:92.06ms +step:999/1660 train_time:91969ms step_avg:92.06ms +step:1000/1660 train_time:92062ms step_avg:92.06ms +step:1000/1660 val_loss:3.4660 train_time:92155ms step_avg:92.16ms +step:1001/1660 train_time:92175ms step_avg:92.08ms +step:1002/1660 train_time:92251ms step_avg:92.07ms +step:1003/1660 train_time:92349ms step_avg:92.07ms +step:1004/1660 train_time:92442ms step_avg:92.07ms +step:1005/1660 train_time:92535ms step_avg:92.07ms +step:1006/1660 train_time:92627ms step_avg:92.07ms +step:1007/1660 train_time:92718ms step_avg:92.07ms +step:1008/1660 train_time:92810ms step_avg:92.07ms +step:1009/1660 train_time:92901ms step_avg:92.07ms +step:1010/1660 train_time:92992ms step_avg:92.07ms +step:1011/1660 train_time:93084ms step_avg:92.07ms +step:1012/1660 train_time:93179ms step_avg:92.07ms +step:1013/1660 train_time:93275ms step_avg:92.08ms +step:1014/1660 train_time:93371ms step_avg:92.08ms +step:1015/1660 train_time:93464ms step_avg:92.08ms +step:1016/1660 train_time:93557ms step_avg:92.08ms +step:1017/1660 
train_time:93649ms step_avg:92.08ms +step:1018/1660 train_time:93741ms step_avg:92.08ms +step:1019/1660 train_time:93832ms step_avg:92.08ms +step:1020/1660 train_time:93923ms step_avg:92.08ms +step:1021/1660 train_time:94016ms step_avg:92.08ms +step:1022/1660 train_time:94108ms step_avg:92.08ms +step:1023/1660 train_time:94201ms step_avg:92.08ms +step:1024/1660 train_time:94297ms step_avg:92.09ms +step:1025/1660 train_time:94392ms step_avg:92.09ms +step:1026/1660 train_time:94485ms step_avg:92.09ms +step:1027/1660 train_time:94578ms step_avg:92.09ms +step:1028/1660 train_time:94670ms step_avg:92.09ms +step:1029/1660 train_time:94762ms step_avg:92.09ms +step:1030/1660 train_time:94853ms step_avg:92.09ms +step:1031/1660 train_time:94946ms step_avg:92.09ms +step:1032/1660 train_time:95038ms step_avg:92.09ms +step:1033/1660 train_time:95130ms step_avg:92.09ms +step:1034/1660 train_time:95223ms step_avg:92.09ms +step:1035/1660 train_time:95318ms step_avg:92.09ms +step:1036/1660 train_time:95411ms step_avg:92.10ms +step:1037/1660 train_time:95504ms step_avg:92.10ms +step:1038/1660 train_time:95598ms step_avg:92.10ms +step:1039/1660 train_time:95690ms step_avg:92.10ms +step:1040/1660 train_time:95782ms step_avg:92.10ms +step:1041/1660 train_time:95875ms step_avg:92.10ms +step:1042/1660 train_time:95967ms step_avg:92.10ms +step:1043/1660 train_time:96060ms step_avg:92.10ms +step:1044/1660 train_time:96152ms step_avg:92.10ms +step:1045/1660 train_time:96246ms step_avg:92.10ms +step:1046/1660 train_time:96339ms step_avg:92.10ms +step:1047/1660 train_time:96432ms step_avg:92.10ms +step:1048/1660 train_time:96525ms step_avg:92.10ms +step:1049/1660 train_time:96618ms step_avg:92.10ms +step:1050/1660 train_time:96711ms step_avg:92.11ms +step:1051/1660 train_time:96803ms step_avg:92.11ms +step:1052/1660 train_time:96895ms step_avg:92.11ms +step:1053/1660 train_time:96987ms step_avg:92.11ms +step:1054/1660 train_time:97079ms step_avg:92.11ms +step:1055/1660 train_time:97172ms step_avg:92.11ms +step:1056/1660 train_time:97265ms step_avg:92.11ms +step:1057/1660 train_time:97358ms step_avg:92.11ms +step:1058/1660 train_time:97451ms step_avg:92.11ms +step:1059/1660 train_time:97543ms step_avg:92.11ms +step:1060/1660 train_time:97637ms step_avg:92.11ms +step:1061/1660 train_time:97729ms step_avg:92.11ms +step:1062/1660 train_time:97822ms step_avg:92.11ms +step:1063/1660 train_time:97914ms step_avg:92.11ms +step:1064/1660 train_time:98006ms step_avg:92.11ms +step:1065/1660 train_time:98099ms step_avg:92.11ms +step:1066/1660 train_time:98192ms step_avg:92.11ms +step:1067/1660 train_time:98286ms step_avg:92.11ms +step:1068/1660 train_time:98378ms step_avg:92.11ms +step:1069/1660 train_time:98471ms step_avg:92.12ms +step:1070/1660 train_time:98564ms step_avg:92.12ms +step:1071/1660 train_time:98658ms step_avg:92.12ms +step:1072/1660 train_time:98751ms step_avg:92.12ms +step:1073/1660 train_time:98843ms step_avg:92.12ms +step:1074/1660 train_time:98936ms step_avg:92.12ms +step:1075/1660 train_time:99027ms step_avg:92.12ms +step:1076/1660 train_time:99120ms step_avg:92.12ms +step:1077/1660 train_time:99214ms step_avg:92.12ms +step:1078/1660 train_time:99307ms step_avg:92.12ms +step:1079/1660 train_time:99399ms step_avg:92.12ms +step:1080/1660 train_time:99491ms step_avg:92.12ms +step:1081/1660 train_time:99585ms step_avg:92.12ms +step:1082/1660 train_time:99679ms step_avg:92.12ms +step:1083/1660 train_time:99771ms step_avg:92.13ms +step:1084/1660 train_time:99863ms step_avg:92.12ms +step:1085/1660 train_time:99956ms 
step_avg:92.13ms +step:1086/1660 train_time:100048ms step_avg:92.13ms +step:1087/1660 train_time:100140ms step_avg:92.12ms +step:1088/1660 train_time:100233ms step_avg:92.13ms +step:1089/1660 train_time:100325ms step_avg:92.13ms +step:1090/1660 train_time:100420ms step_avg:92.13ms +step:1091/1660 train_time:100513ms step_avg:92.13ms +step:1092/1660 train_time:100606ms step_avg:92.13ms +step:1093/1660 train_time:100699ms step_avg:92.13ms +step:1094/1660 train_time:100792ms step_avg:92.13ms +step:1095/1660 train_time:100884ms step_avg:92.13ms +step:1096/1660 train_time:100977ms step_avg:92.13ms +step:1097/1660 train_time:101070ms step_avg:92.13ms +step:1098/1660 train_time:101162ms step_avg:92.13ms +step:1099/1660 train_time:101256ms step_avg:92.13ms +step:1100/1660 train_time:101349ms step_avg:92.14ms +step:1101/1660 train_time:101441ms step_avg:92.14ms +step:1102/1660 train_time:101535ms step_avg:92.14ms +step:1103/1660 train_time:101627ms step_avg:92.14ms +step:1104/1660 train_time:101720ms step_avg:92.14ms +step:1105/1660 train_time:101814ms step_avg:92.14ms +step:1106/1660 train_time:101907ms step_avg:92.14ms +step:1107/1660 train_time:101999ms step_avg:92.14ms +step:1108/1660 train_time:102092ms step_avg:92.14ms +step:1109/1660 train_time:102185ms step_avg:92.14ms +step:1110/1660 train_time:102279ms step_avg:92.14ms +step:1111/1660 train_time:102372ms step_avg:92.14ms +step:1112/1660 train_time:102465ms step_avg:92.14ms +step:1113/1660 train_time:102558ms step_avg:92.15ms +step:1114/1660 train_time:102652ms step_avg:92.15ms +step:1115/1660 train_time:102745ms step_avg:92.15ms +step:1116/1660 train_time:102840ms step_avg:92.15ms +step:1117/1660 train_time:102934ms step_avg:92.15ms +step:1118/1660 train_time:103027ms step_avg:92.15ms +step:1119/1660 train_time:103120ms step_avg:92.15ms +step:1120/1660 train_time:103213ms step_avg:92.15ms +step:1121/1660 train_time:103306ms step_avg:92.16ms +step:1122/1660 train_time:103399ms step_avg:92.16ms +step:1123/1660 train_time:103492ms step_avg:92.16ms +step:1124/1660 train_time:103585ms step_avg:92.16ms +step:1125/1660 train_time:103679ms step_avg:92.16ms +step:1125/1660 val_loss:3.4132 train_time:103774ms step_avg:92.24ms +step:1126/1660 train_time:103793ms step_avg:92.18ms +step:1127/1660 train_time:103871ms step_avg:92.17ms +step:1128/1660 train_time:103973ms step_avg:92.17ms +step:1129/1660 train_time:104067ms step_avg:92.18ms +step:1130/1660 train_time:104160ms step_avg:92.18ms +step:1131/1660 train_time:104252ms step_avg:92.18ms +step:1132/1660 train_time:104344ms step_avg:92.18ms +step:1133/1660 train_time:104436ms step_avg:92.18ms +step:1134/1660 train_time:104528ms step_avg:92.18ms +step:1135/1660 train_time:104620ms step_avg:92.18ms +step:1136/1660 train_time:104713ms step_avg:92.18ms +step:1137/1660 train_time:104808ms step_avg:92.18ms +step:1138/1660 train_time:104904ms step_avg:92.18ms +step:1139/1660 train_time:104999ms step_avg:92.19ms +step:1140/1660 train_time:105093ms step_avg:92.19ms +step:1141/1660 train_time:105186ms step_avg:92.19ms +step:1142/1660 train_time:105278ms step_avg:92.19ms +step:1143/1660 train_time:105371ms step_avg:92.19ms +step:1144/1660 train_time:105464ms step_avg:92.19ms +step:1145/1660 train_time:105556ms step_avg:92.19ms +step:1146/1660 train_time:105649ms step_avg:92.19ms +step:1147/1660 train_time:105741ms step_avg:92.19ms +step:1148/1660 train_time:105836ms step_avg:92.19ms +step:1149/1660 train_time:105933ms step_avg:92.20ms +step:1150/1660 train_time:106028ms step_avg:92.20ms +step:1151/1660 
train_time:106122ms step_avg:92.20ms +step:1152/1660 train_time:106214ms step_avg:92.20ms +step:1153/1660 train_time:106307ms step_avg:92.20ms +step:1154/1660 train_time:106399ms step_avg:92.20ms +step:1155/1660 train_time:106492ms step_avg:92.20ms +step:1156/1660 train_time:106585ms step_avg:92.20ms +step:1157/1660 train_time:106677ms step_avg:92.20ms +step:1158/1660 train_time:106771ms step_avg:92.20ms +step:1159/1660 train_time:106866ms step_avg:92.21ms +step:1160/1660 train_time:106960ms step_avg:92.21ms +step:1161/1660 train_time:107056ms step_avg:92.21ms +step:1162/1660 train_time:107150ms step_avg:92.21ms +step:1163/1660 train_time:107242ms step_avg:92.21ms +step:1164/1660 train_time:107335ms step_avg:92.21ms +step:1165/1660 train_time:107429ms step_avg:92.21ms +step:1166/1660 train_time:107521ms step_avg:92.21ms +step:1167/1660 train_time:107614ms step_avg:92.21ms +step:1168/1660 train_time:107707ms step_avg:92.21ms +step:1169/1660 train_time:107800ms step_avg:92.22ms +step:1170/1660 train_time:107895ms step_avg:92.22ms +step:1171/1660 train_time:107990ms step_avg:92.22ms +step:1172/1660 train_time:108084ms step_avg:92.22ms +step:1173/1660 train_time:108177ms step_avg:92.22ms +step:1174/1660 train_time:108270ms step_avg:92.22ms +step:1175/1660 train_time:108363ms step_avg:92.22ms +step:1176/1660 train_time:108457ms step_avg:92.23ms +step:1177/1660 train_time:108550ms step_avg:92.23ms +step:1178/1660 train_time:108643ms step_avg:92.23ms +step:1179/1660 train_time:108736ms step_avg:92.23ms +step:1180/1660 train_time:108830ms step_avg:92.23ms +step:1181/1660 train_time:108924ms step_avg:92.23ms +step:1182/1660 train_time:109018ms step_avg:92.23ms +step:1183/1660 train_time:109113ms step_avg:92.23ms +step:1184/1660 train_time:109206ms step_avg:92.23ms +step:1185/1660 train_time:109300ms step_avg:92.24ms +step:1186/1660 train_time:109393ms step_avg:92.24ms +step:1187/1660 train_time:109487ms step_avg:92.24ms +step:1188/1660 train_time:109580ms step_avg:92.24ms +step:1189/1660 train_time:109673ms step_avg:92.24ms +step:1190/1660 train_time:109767ms step_avg:92.24ms +step:1191/1660 train_time:109860ms step_avg:92.24ms +step:1192/1660 train_time:109954ms step_avg:92.24ms +step:1193/1660 train_time:110048ms step_avg:92.24ms +step:1194/1660 train_time:110141ms step_avg:92.25ms +step:1195/1660 train_time:110235ms step_avg:92.25ms +step:1196/1660 train_time:110328ms step_avg:92.25ms +step:1197/1660 train_time:110421ms step_avg:92.25ms +step:1198/1660 train_time:110514ms step_avg:92.25ms +step:1199/1660 train_time:110607ms step_avg:92.25ms +step:1200/1660 train_time:110699ms step_avg:92.25ms +step:1201/1660 train_time:110794ms step_avg:92.25ms +step:1202/1660 train_time:110888ms step_avg:92.25ms +step:1203/1660 train_time:110981ms step_avg:92.25ms +step:1204/1660 train_time:111075ms step_avg:92.25ms +step:1205/1660 train_time:111168ms step_avg:92.26ms +step:1206/1660 train_time:111262ms step_avg:92.26ms +step:1207/1660 train_time:111356ms step_avg:92.26ms +step:1208/1660 train_time:111448ms step_avg:92.26ms +step:1209/1660 train_time:111542ms step_avg:92.26ms +step:1210/1660 train_time:111635ms step_avg:92.26ms +step:1211/1660 train_time:111728ms step_avg:92.26ms +step:1212/1660 train_time:111821ms step_avg:92.26ms +step:1213/1660 train_time:111914ms step_avg:92.26ms +step:1214/1660 train_time:112008ms step_avg:92.26ms +step:1215/1660 train_time:112101ms step_avg:92.26ms +step:1216/1660 train_time:112197ms step_avg:92.27ms +step:1217/1660 train_time:112293ms step_avg:92.27ms +step:1218/1660 
train_time:112388ms step_avg:92.27ms +step:1219/1660 train_time:112480ms step_avg:92.27ms +step:1220/1660 train_time:112574ms step_avg:92.27ms +step:1221/1660 train_time:112667ms step_avg:92.27ms +step:1222/1660 train_time:112760ms step_avg:92.27ms +step:1223/1660 train_time:112853ms step_avg:92.28ms +step:1224/1660 train_time:112947ms step_avg:92.28ms +step:1225/1660 train_time:113040ms step_avg:92.28ms +step:1226/1660 train_time:113135ms step_avg:92.28ms +step:1227/1660 train_time:113230ms step_avg:92.28ms +step:1228/1660 train_time:113324ms step_avg:92.28ms +step:1229/1660 train_time:113417ms step_avg:92.28ms +step:1230/1660 train_time:113510ms step_avg:92.28ms +step:1231/1660 train_time:113603ms step_avg:92.29ms +step:1232/1660 train_time:113696ms step_avg:92.29ms +step:1233/1660 train_time:113790ms step_avg:92.29ms +step:1234/1660 train_time:113882ms step_avg:92.29ms +step:1235/1660 train_time:113975ms step_avg:92.29ms +step:1236/1660 train_time:114069ms step_avg:92.29ms +step:1237/1660 train_time:114163ms step_avg:92.29ms +step:1238/1660 train_time:114257ms step_avg:92.29ms +step:1239/1660 train_time:114353ms step_avg:92.29ms +step:1240/1660 train_time:114447ms step_avg:92.30ms +step:1241/1660 train_time:114539ms step_avg:92.30ms +step:1242/1660 train_time:114633ms step_avg:92.30ms +step:1243/1660 train_time:114727ms step_avg:92.30ms +step:1244/1660 train_time:114821ms step_avg:92.30ms +step:1245/1660 train_time:114915ms step_avg:92.30ms +step:1246/1660 train_time:115008ms step_avg:92.30ms +step:1247/1660 train_time:115101ms step_avg:92.30ms +step:1248/1660 train_time:115195ms step_avg:92.30ms +step:1249/1660 train_time:115288ms step_avg:92.30ms +step:1250/1660 train_time:115382ms step_avg:92.31ms +step:1250/1660 val_loss:3.3741 train_time:115476ms step_avg:92.38ms +step:1251/1660 train_time:115495ms step_avg:92.32ms +step:1252/1660 train_time:115574ms step_avg:92.31ms +step:1253/1660 train_time:115674ms step_avg:92.32ms +step:1254/1660 train_time:115767ms step_avg:92.32ms +step:1255/1660 train_time:115859ms step_avg:92.32ms +step:1256/1660 train_time:115952ms step_avg:92.32ms +step:1257/1660 train_time:116044ms step_avg:92.32ms +step:1258/1660 train_time:116136ms step_avg:92.32ms +step:1259/1660 train_time:116228ms step_avg:92.32ms +step:1260/1660 train_time:116320ms step_avg:92.32ms +step:1261/1660 train_time:116414ms step_avg:92.32ms +step:1262/1660 train_time:116509ms step_avg:92.32ms +step:1263/1660 train_time:116607ms step_avg:92.33ms +step:1264/1660 train_time:116703ms step_avg:92.33ms +step:1265/1660 train_time:116797ms step_avg:92.33ms +step:1266/1660 train_time:116889ms step_avg:92.33ms +step:1267/1660 train_time:116982ms step_avg:92.33ms +step:1268/1660 train_time:117075ms step_avg:92.33ms +step:1269/1660 train_time:117167ms step_avg:92.33ms +step:1270/1660 train_time:117259ms step_avg:92.33ms +step:1271/1660 train_time:117352ms step_avg:92.33ms +step:1272/1660 train_time:117445ms step_avg:92.33ms +step:1273/1660 train_time:117541ms step_avg:92.33ms +step:1274/1660 train_time:117637ms step_avg:92.34ms +step:1275/1660 train_time:117731ms step_avg:92.34ms +step:1276/1660 train_time:117825ms step_avg:92.34ms +step:1277/1660 train_time:117919ms step_avg:92.34ms +step:1278/1660 train_time:118011ms step_avg:92.34ms +step:1279/1660 train_time:118104ms step_avg:92.34ms +step:1280/1660 train_time:118196ms step_avg:92.34ms +step:1281/1660 train_time:118289ms step_avg:92.34ms +step:1282/1660 train_time:118382ms step_avg:92.34ms +step:1283/1660 train_time:118475ms step_avg:92.34ms 
+step:1284/1660 train_time:118569ms step_avg:92.34ms +step:1285/1660 train_time:118665ms step_avg:92.35ms +step:1286/1660 train_time:118758ms step_avg:92.35ms +step:1287/1660 train_time:118852ms step_avg:92.35ms +step:1288/1660 train_time:118945ms step_avg:92.35ms +step:1289/1660 train_time:119038ms step_avg:92.35ms +step:1290/1660 train_time:119130ms step_avg:92.35ms +step:1291/1660 train_time:119224ms step_avg:92.35ms +step:1292/1660 train_time:119317ms step_avg:92.35ms +step:1293/1660 train_time:119409ms step_avg:92.35ms +step:1294/1660 train_time:119503ms step_avg:92.35ms +step:1295/1660 train_time:119597ms step_avg:92.35ms +step:1296/1660 train_time:119691ms step_avg:92.35ms +step:1297/1660 train_time:119785ms step_avg:92.36ms +step:1298/1660 train_time:119878ms step_avg:92.36ms +step:1299/1660 train_time:119971ms step_avg:92.36ms +step:1300/1660 train_time:120063ms step_avg:92.36ms +step:1301/1660 train_time:120156ms step_avg:92.36ms +step:1302/1660 train_time:120249ms step_avg:92.36ms +step:1303/1660 train_time:120343ms step_avg:92.36ms +step:1304/1660 train_time:120435ms step_avg:92.36ms +step:1305/1660 train_time:120529ms step_avg:92.36ms +step:1306/1660 train_time:120624ms step_avg:92.36ms +step:1307/1660 train_time:120718ms step_avg:92.36ms +step:1308/1660 train_time:120811ms step_avg:92.36ms +step:1309/1660 train_time:120904ms step_avg:92.36ms +step:1310/1660 train_time:120997ms step_avg:92.36ms +step:1311/1660 train_time:121091ms step_avg:92.37ms +step:1312/1660 train_time:121184ms step_avg:92.37ms +step:1313/1660 train_time:121277ms step_avg:92.37ms +step:1314/1660 train_time:121369ms step_avg:92.37ms +step:1315/1660 train_time:121463ms step_avg:92.37ms +step:1316/1660 train_time:121557ms step_avg:92.37ms +step:1317/1660 train_time:121649ms step_avg:92.37ms +step:1318/1660 train_time:121744ms step_avg:92.37ms +step:1319/1660 train_time:121838ms step_avg:92.37ms +step:1320/1660 train_time:121930ms step_avg:92.37ms +step:1321/1660 train_time:122025ms step_avg:92.37ms +step:1322/1660 train_time:122121ms step_avg:92.38ms +step:1323/1660 train_time:122215ms step_avg:92.38ms +step:1324/1660 train_time:122307ms step_avg:92.38ms +step:1325/1660 train_time:122401ms step_avg:92.38ms +step:1326/1660 train_time:122495ms step_avg:92.38ms +step:1327/1660 train_time:122588ms step_avg:92.38ms +step:1328/1660 train_time:122682ms step_avg:92.38ms +step:1329/1660 train_time:122776ms step_avg:92.38ms +step:1330/1660 train_time:122869ms step_avg:92.38ms +step:1331/1660 train_time:122963ms step_avg:92.38ms +step:1332/1660 train_time:123058ms step_avg:92.39ms +step:1333/1660 train_time:123151ms step_avg:92.39ms +step:1334/1660 train_time:123244ms step_avg:92.39ms +step:1335/1660 train_time:123337ms step_avg:92.39ms +step:1336/1660 train_time:123430ms step_avg:92.39ms +step:1337/1660 train_time:123524ms step_avg:92.39ms +step:1338/1660 train_time:123617ms step_avg:92.39ms +step:1339/1660 train_time:123711ms step_avg:92.39ms +step:1340/1660 train_time:123805ms step_avg:92.39ms +step:1341/1660 train_time:123898ms step_avg:92.39ms +step:1342/1660 train_time:123990ms step_avg:92.39ms +step:1343/1660 train_time:124084ms step_avg:92.39ms +step:1344/1660 train_time:124178ms step_avg:92.39ms +step:1345/1660 train_time:124271ms step_avg:92.39ms +step:1346/1660 train_time:124364ms step_avg:92.40ms +step:1347/1660 train_time:124458ms step_avg:92.40ms +step:1348/1660 train_time:124551ms step_avg:92.40ms +step:1349/1660 train_time:124645ms step_avg:92.40ms +step:1350/1660 train_time:124738ms step_avg:92.40ms 
+step:1351/1660 train_time:124831ms step_avg:92.40ms +step:1352/1660 train_time:124925ms step_avg:92.40ms +step:1353/1660 train_time:125020ms step_avg:92.40ms +step:1354/1660 train_time:125115ms step_avg:92.40ms +step:1355/1660 train_time:125207ms step_avg:92.40ms +step:1356/1660 train_time:125300ms step_avg:92.40ms +step:1357/1660 train_time:125394ms step_avg:92.41ms +step:1358/1660 train_time:125487ms step_avg:92.41ms +step:1359/1660 train_time:125581ms step_avg:92.41ms +step:1360/1660 train_time:125674ms step_avg:92.41ms +step:1361/1660 train_time:125767ms step_avg:92.41ms +step:1362/1660 train_time:125860ms step_avg:92.41ms +step:1363/1660 train_time:125953ms step_avg:92.41ms +step:1364/1660 train_time:126046ms step_avg:92.41ms +step:1365/1660 train_time:126141ms step_avg:92.41ms +step:1366/1660 train_time:126235ms step_avg:92.41ms +step:1367/1660 train_time:126328ms step_avg:92.41ms +step:1368/1660 train_time:126423ms step_avg:92.41ms +step:1369/1660 train_time:126518ms step_avg:92.42ms +step:1370/1660 train_time:126611ms step_avg:92.42ms +step:1371/1660 train_time:126704ms step_avg:92.42ms +step:1372/1660 train_time:126797ms step_avg:92.42ms +step:1373/1660 train_time:126891ms step_avg:92.42ms +step:1374/1660 train_time:126984ms step_avg:92.42ms +step:1375/1660 train_time:127077ms step_avg:92.42ms +step:1375/1660 val_loss:3.3396 train_time:127171ms step_avg:92.49ms +step:1376/1660 train_time:127191ms step_avg:92.44ms +step:1377/1660 train_time:127270ms step_avg:92.43ms +step:1378/1660 train_time:127368ms step_avg:92.43ms +step:1379/1660 train_time:127462ms step_avg:92.43ms +step:1380/1660 train_time:127554ms step_avg:92.43ms +step:1381/1660 train_time:127646ms step_avg:92.43ms +step:1382/1660 train_time:127738ms step_avg:92.43ms +step:1383/1660 train_time:127831ms step_avg:92.43ms +step:1384/1660 train_time:127923ms step_avg:92.43ms +step:1385/1660 train_time:128016ms step_avg:92.43ms +step:1386/1660 train_time:128110ms step_avg:92.43ms +step:1387/1660 train_time:128205ms step_avg:92.43ms +step:1388/1660 train_time:128301ms step_avg:92.44ms +step:1389/1660 train_time:128396ms step_avg:92.44ms +step:1390/1660 train_time:128492ms step_avg:92.44ms +step:1391/1660 train_time:128586ms step_avg:92.44ms +step:1392/1660 train_time:128678ms step_avg:92.44ms +step:1393/1660 train_time:128772ms step_avg:92.44ms +step:1394/1660 train_time:128865ms step_avg:92.44ms +step:1395/1660 train_time:128957ms step_avg:92.44ms +step:1396/1660 train_time:129050ms step_avg:92.44ms +step:1397/1660 train_time:129143ms step_avg:92.44ms +step:1398/1660 train_time:129238ms step_avg:92.44ms +step:1399/1660 train_time:129332ms step_avg:92.45ms +step:1400/1660 train_time:129427ms step_avg:92.45ms +step:1401/1660 train_time:129521ms step_avg:92.45ms +step:1402/1660 train_time:129615ms step_avg:92.45ms +step:1403/1660 train_time:129708ms step_avg:92.45ms +step:1404/1660 train_time:129800ms step_avg:92.45ms +step:1405/1660 train_time:129893ms step_avg:92.45ms +step:1406/1660 train_time:129985ms step_avg:92.45ms +step:1407/1660 train_time:130078ms step_avg:92.45ms +step:1408/1660 train_time:130173ms step_avg:92.45ms +step:1409/1660 train_time:130266ms step_avg:92.45ms +step:1410/1660 train_time:130360ms step_avg:92.45ms +step:1411/1660 train_time:130454ms step_avg:92.45ms +step:1412/1660 train_time:130549ms step_avg:92.46ms +step:1413/1660 train_time:130642ms step_avg:92.46ms +step:1414/1660 train_time:130735ms step_avg:92.46ms +step:1415/1660 train_time:130827ms step_avg:92.46ms +step:1416/1660 train_time:130920ms 
step_avg:92.46ms +step:1417/1660 train_time:131013ms step_avg:92.46ms +step:1418/1660 train_time:131106ms step_avg:92.46ms +step:1419/1660 train_time:131199ms step_avg:92.46ms +step:1420/1660 train_time:131293ms step_avg:92.46ms +step:1421/1660 train_time:131387ms step_avg:92.46ms +step:1422/1660 train_time:131481ms step_avg:92.46ms +step:1423/1660 train_time:131574ms step_avg:92.46ms +step:1424/1660 train_time:131668ms step_avg:92.46ms +step:1425/1660 train_time:131761ms step_avg:92.46ms +step:1426/1660 train_time:131854ms step_avg:92.46ms +step:1427/1660 train_time:131947ms step_avg:92.46ms +step:1428/1660 train_time:132041ms step_avg:92.47ms +step:1429/1660 train_time:132136ms step_avg:92.47ms +step:1430/1660 train_time:132230ms step_avg:92.47ms +step:1431/1660 train_time:132324ms step_avg:92.47ms +step:1432/1660 train_time:132418ms step_avg:92.47ms +step:1433/1660 train_time:132512ms step_avg:92.47ms +step:1434/1660 train_time:132605ms step_avg:92.47ms +step:1435/1660 train_time:132698ms step_avg:92.47ms +step:1436/1660 train_time:132792ms step_avg:92.47ms +step:1437/1660 train_time:132886ms step_avg:92.47ms +step:1438/1660 train_time:132980ms step_avg:92.48ms +step:1439/1660 train_time:133074ms step_avg:92.48ms +step:1440/1660 train_time:133167ms step_avg:92.48ms +step:1441/1660 train_time:133260ms step_avg:92.48ms +step:1442/1660 train_time:133354ms step_avg:92.48ms +step:1443/1660 train_time:133449ms step_avg:92.48ms +step:1444/1660 train_time:133543ms step_avg:92.48ms +step:1445/1660 train_time:133635ms step_avg:92.48ms +step:1446/1660 train_time:133729ms step_avg:92.48ms +step:1447/1660 train_time:133822ms step_avg:92.48ms +step:1448/1660 train_time:133916ms step_avg:92.48ms +step:1449/1660 train_time:134010ms step_avg:92.48ms +step:1450/1660 train_time:134103ms step_avg:92.48ms +step:1451/1660 train_time:134196ms step_avg:92.48ms +step:1452/1660 train_time:134289ms step_avg:92.49ms +step:1453/1660 train_time:134383ms step_avg:92.49ms +step:1454/1660 train_time:134478ms step_avg:92.49ms +step:1455/1660 train_time:134571ms step_avg:92.49ms +step:1456/1660 train_time:134665ms step_avg:92.49ms +step:1457/1660 train_time:134757ms step_avg:92.49ms +step:1458/1660 train_time:134852ms step_avg:92.49ms +step:1459/1660 train_time:134946ms step_avg:92.49ms +step:1460/1660 train_time:135039ms step_avg:92.49ms +step:1461/1660 train_time:135134ms step_avg:92.49ms +step:1462/1660 train_time:135226ms step_avg:92.49ms +step:1463/1660 train_time:135320ms step_avg:92.49ms +step:1464/1660 train_time:135414ms step_avg:92.50ms +step:1465/1660 train_time:135507ms step_avg:92.50ms +step:1466/1660 train_time:135600ms step_avg:92.50ms +step:1467/1660 train_time:135693ms step_avg:92.50ms +step:1468/1660 train_time:135788ms step_avg:92.50ms +step:1469/1660 train_time:135882ms step_avg:92.50ms +step:1470/1660 train_time:135974ms step_avg:92.50ms +step:1471/1660 train_time:136069ms step_avg:92.50ms +step:1472/1660 train_time:136162ms step_avg:92.50ms +step:1473/1660 train_time:136255ms step_avg:92.50ms +step:1474/1660 train_time:136349ms step_avg:92.50ms +step:1475/1660 train_time:136443ms step_avg:92.50ms +step:1476/1660 train_time:136536ms step_avg:92.50ms +step:1477/1660 train_time:136629ms step_avg:92.50ms +step:1478/1660 train_time:136722ms step_avg:92.50ms +step:1479/1660 train_time:136816ms step_avg:92.51ms +step:1480/1660 train_time:136910ms step_avg:92.51ms +step:1481/1660 train_time:137004ms step_avg:92.51ms +step:1482/1660 train_time:137097ms step_avg:92.51ms +step:1483/1660 train_time:137191ms 
step_avg:92.51ms +step:1484/1660 train_time:137285ms step_avg:92.51ms +step:1485/1660 train_time:137379ms step_avg:92.51ms +step:1486/1660 train_time:137473ms step_avg:92.51ms +step:1487/1660 train_time:137566ms step_avg:92.51ms +step:1488/1660 train_time:137659ms step_avg:92.51ms +step:1489/1660 train_time:137753ms step_avg:92.51ms +step:1490/1660 train_time:137846ms step_avg:92.51ms +step:1491/1660 train_time:137940ms step_avg:92.52ms +step:1492/1660 train_time:138033ms step_avg:92.52ms +step:1493/1660 train_time:138126ms step_avg:92.52ms +step:1494/1660 train_time:138219ms step_avg:92.52ms +step:1495/1660 train_time:138313ms step_avg:92.52ms +step:1496/1660 train_time:138407ms step_avg:92.52ms +step:1497/1660 train_time:138499ms step_avg:92.52ms +step:1498/1660 train_time:138592ms step_avg:92.52ms +step:1499/1660 train_time:138686ms step_avg:92.52ms +step:1500/1660 train_time:138779ms step_avg:92.52ms +step:1500/1660 val_loss:3.3097 train_time:138874ms step_avg:92.58ms +step:1501/1660 train_time:138894ms step_avg:92.53ms +step:1502/1660 train_time:138971ms step_avg:92.52ms +step:1503/1660 train_time:139066ms step_avg:92.53ms +step:1504/1660 train_time:139158ms step_avg:92.53ms +step:1505/1660 train_time:139250ms step_avg:92.53ms +step:1506/1660 train_time:139342ms step_avg:92.52ms +step:1507/1660 train_time:139435ms step_avg:92.52ms +step:1508/1660 train_time:139527ms step_avg:92.52ms +step:1509/1660 train_time:139619ms step_avg:92.52ms +step:1510/1660 train_time:139713ms step_avg:92.52ms +step:1511/1660 train_time:139807ms step_avg:92.53ms +step:1512/1660 train_time:139902ms step_avg:92.53ms +step:1513/1660 train_time:139997ms step_avg:92.53ms +step:1514/1660 train_time:140091ms step_avg:92.53ms +step:1515/1660 train_time:140184ms step_avg:92.53ms +step:1516/1660 train_time:140277ms step_avg:92.53ms +step:1517/1660 train_time:140370ms step_avg:92.53ms +step:1518/1660 train_time:140462ms step_avg:92.53ms +step:1519/1660 train_time:140555ms step_avg:92.53ms +step:1520/1660 train_time:140648ms step_avg:92.53ms +step:1521/1660 train_time:140741ms step_avg:92.53ms +step:1522/1660 train_time:140836ms step_avg:92.53ms +step:1523/1660 train_time:140930ms step_avg:92.53ms +step:1524/1660 train_time:141023ms step_avg:92.53ms +step:1525/1660 train_time:141117ms step_avg:92.54ms +step:1526/1660 train_time:141211ms step_avg:92.54ms +step:1527/1660 train_time:141304ms step_avg:92.54ms +step:1528/1660 train_time:141397ms step_avg:92.54ms +step:1529/1660 train_time:141490ms step_avg:92.54ms +step:1530/1660 train_time:141582ms step_avg:92.54ms +step:1531/1660 train_time:141676ms step_avg:92.54ms +step:1532/1660 train_time:141770ms step_avg:92.54ms +step:1533/1660 train_time:141863ms step_avg:92.54ms +step:1534/1660 train_time:141957ms step_avg:92.54ms +step:1535/1660 train_time:142050ms step_avg:92.54ms +step:1536/1660 train_time:142144ms step_avg:92.54ms +step:1537/1660 train_time:142238ms step_avg:92.54ms +step:1538/1660 train_time:142331ms step_avg:92.54ms +step:1539/1660 train_time:142424ms step_avg:92.54ms +step:1540/1660 train_time:142517ms step_avg:92.54ms +step:1541/1660 train_time:142609ms step_avg:92.54ms +step:1542/1660 train_time:142702ms step_avg:92.54ms +step:1543/1660 train_time:142796ms step_avg:92.54ms +step:1544/1660 train_time:142889ms step_avg:92.54ms +step:1545/1660 train_time:142982ms step_avg:92.55ms +step:1546/1660 train_time:143076ms step_avg:92.55ms +step:1547/1660 train_time:143171ms step_avg:92.55ms +step:1548/1660 train_time:143265ms step_avg:92.55ms +step:1549/1660 
train_time:143358ms step_avg:92.55ms +step:1550/1660 train_time:143451ms step_avg:92.55ms +step:1551/1660 train_time:143544ms step_avg:92.55ms +step:1552/1660 train_time:143638ms step_avg:92.55ms +step:1553/1660 train_time:143731ms step_avg:92.55ms +step:1554/1660 train_time:143823ms step_avg:92.55ms +step:1555/1660 train_time:143917ms step_avg:92.55ms +step:1556/1660 train_time:144011ms step_avg:92.55ms +step:1557/1660 train_time:144104ms step_avg:92.55ms +step:1558/1660 train_time:144198ms step_avg:92.55ms +step:1559/1660 train_time:144291ms step_avg:92.55ms +step:1560/1660 train_time:144384ms step_avg:92.55ms +step:1561/1660 train_time:144478ms step_avg:92.55ms +step:1562/1660 train_time:144571ms step_avg:92.55ms +step:1563/1660 train_time:144663ms step_avg:92.55ms +step:1564/1660 train_time:144756ms step_avg:92.56ms +step:1565/1660 train_time:144849ms step_avg:92.56ms +step:1566/1660 train_time:144942ms step_avg:92.56ms +step:1567/1660 train_time:145036ms step_avg:92.56ms +step:1568/1660 train_time:145130ms step_avg:92.56ms +step:1569/1660 train_time:145223ms step_avg:92.56ms +step:1570/1660 train_time:145317ms step_avg:92.56ms +step:1571/1660 train_time:145409ms step_avg:92.56ms +step:1572/1660 train_time:145503ms step_avg:92.56ms +step:1573/1660 train_time:145596ms step_avg:92.56ms +step:1574/1660 train_time:145689ms step_avg:92.56ms +step:1575/1660 train_time:145781ms step_avg:92.56ms +step:1576/1660 train_time:145875ms step_avg:92.56ms +step:1577/1660 train_time:145968ms step_avg:92.56ms +step:1578/1660 train_time:146061ms step_avg:92.56ms +step:1579/1660 train_time:146156ms step_avg:92.56ms +step:1580/1660 train_time:146249ms step_avg:92.56ms +step:1581/1660 train_time:146343ms step_avg:92.56ms +step:1582/1660 train_time:146437ms step_avg:92.56ms +step:1583/1660 train_time:146530ms step_avg:92.57ms +step:1584/1660 train_time:146623ms step_avg:92.56ms +step:1585/1660 train_time:146716ms step_avg:92.57ms +step:1586/1660 train_time:146810ms step_avg:92.57ms +step:1587/1660 train_time:146903ms step_avg:92.57ms +step:1588/1660 train_time:146997ms step_avg:92.57ms +step:1589/1660 train_time:147090ms step_avg:92.57ms +step:1590/1660 train_time:147183ms step_avg:92.57ms +step:1591/1660 train_time:147277ms step_avg:92.57ms +step:1592/1660 train_time:147370ms step_avg:92.57ms +step:1593/1660 train_time:147464ms step_avg:92.57ms +step:1594/1660 train_time:147557ms step_avg:92.57ms +step:1595/1660 train_time:147650ms step_avg:92.57ms +step:1596/1660 train_time:147743ms step_avg:92.57ms +step:1597/1660 train_time:147837ms step_avg:92.57ms +step:1598/1660 train_time:147931ms step_avg:92.57ms +step:1599/1660 train_time:148023ms step_avg:92.57ms +step:1600/1660 train_time:148116ms step_avg:92.57ms +step:1601/1660 train_time:148209ms step_avg:92.57ms +step:1602/1660 train_time:148302ms step_avg:92.57ms +step:1603/1660 train_time:148396ms step_avg:92.57ms +step:1604/1660 train_time:148490ms step_avg:92.57ms +step:1605/1660 train_time:148583ms step_avg:92.57ms +step:1606/1660 train_time:148676ms step_avg:92.58ms +step:1607/1660 train_time:148770ms step_avg:92.58ms +step:1608/1660 train_time:148863ms step_avg:92.58ms +step:1609/1660 train_time:148956ms step_avg:92.58ms +step:1610/1660 train_time:149050ms step_avg:92.58ms +step:1611/1660 train_time:149143ms step_avg:92.58ms +step:1612/1660 train_time:149237ms step_avg:92.58ms +step:1613/1660 train_time:149330ms step_avg:92.58ms +step:1614/1660 train_time:149423ms step_avg:92.58ms +step:1615/1660 train_time:149516ms step_avg:92.58ms +step:1616/1660 
train_time:149609ms step_avg:92.58ms +step:1617/1660 train_time:149703ms step_avg:92.58ms +step:1618/1660 train_time:149796ms step_avg:92.58ms +step:1619/1660 train_time:149889ms step_avg:92.58ms +step:1620/1660 train_time:149982ms step_avg:92.58ms +step:1621/1660 train_time:150076ms step_avg:92.58ms +step:1622/1660 train_time:150170ms step_avg:92.58ms +step:1623/1660 train_time:150263ms step_avg:92.58ms +step:1624/1660 train_time:150356ms step_avg:92.58ms +step:1625/1660 train_time:150449ms step_avg:92.58ms +step:1625/1660 val_loss:3.2846 train_time:150544ms step_avg:92.64ms +step:1626/1660 train_time:150564ms step_avg:92.60ms +step:1627/1660 train_time:150642ms step_avg:92.59ms +step:1628/1660 train_time:150739ms step_avg:92.59ms +step:1629/1660 train_time:150832ms step_avg:92.59ms +step:1630/1660 train_time:150925ms step_avg:92.59ms +step:1631/1660 train_time:151016ms step_avg:92.59ms +step:1632/1660 train_time:151109ms step_avg:92.59ms +step:1633/1660 train_time:151200ms step_avg:92.59ms +step:1634/1660 train_time:151293ms step_avg:92.59ms +step:1635/1660 train_time:151385ms step_avg:92.59ms +step:1636/1660 train_time:151479ms step_avg:92.59ms +step:1637/1660 train_time:151575ms step_avg:92.59ms +step:1638/1660 train_time:151672ms step_avg:92.60ms +step:1639/1660 train_time:151768ms step_avg:92.60ms +step:1640/1660 train_time:151861ms step_avg:92.60ms +step:1641/1660 train_time:151953ms step_avg:92.60ms +step:1642/1660 train_time:152046ms step_avg:92.60ms +step:1643/1660 train_time:152138ms step_avg:92.60ms +step:1644/1660 train_time:152230ms step_avg:92.60ms +step:1645/1660 train_time:152323ms step_avg:92.60ms +step:1646/1660 train_time:152415ms step_avg:92.60ms +step:1647/1660 train_time:152510ms step_avg:92.60ms +step:1648/1660 train_time:152606ms step_avg:92.60ms +step:1649/1660 train_time:152700ms step_avg:92.60ms +step:1650/1660 train_time:152793ms step_avg:92.60ms +step:1651/1660 train_time:152887ms step_avg:92.60ms +step:1652/1660 train_time:152979ms step_avg:92.60ms +step:1653/1660 train_time:153072ms step_avg:92.60ms +step:1654/1660 train_time:153164ms step_avg:92.60ms +step:1655/1660 train_time:153257ms step_avg:92.60ms +step:1656/1660 train_time:153349ms step_avg:92.60ms +step:1657/1660 train_time:153443ms step_avg:92.60ms +step:1658/1660 train_time:153537ms step_avg:92.60ms +step:1659/1660 train_time:153631ms step_avg:92.60ms +step:1660/1660 train_time:153725ms step_avg:92.61ms +step:1660/1660 val_loss:3.2764 train_time:153820ms step_avg:92.66ms +peak memory allocated: 32002 MiB reserved: 46836 MiB diff --git a/records/091525_ThreadingFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt b/records/091525_ThreadingFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt new file mode 100644 index 000000000..dbb9b8251 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned 
this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + 
C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if 
A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
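+        # The step runs in three phases so communication overlaps with compute: (1) async
+        # reduce_scatter shards the averaged gradients across ranks, (2) each rank runs
+        # Newton-Schulz only on its own shard of updates, and (3) async all_gather
+        # broadcasts the updated parameters back to every rank.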
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
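+                # The momentum buffer updated below is an EMA of gradients; the tensor that is
+                # actually fed to Newton-Schulz is grad.lerp(momentum_buffer, momentum), i.e. a
+                # Nesterov-style lookahead along the momentum direction.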
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
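+    # Waiting only here (rather than right after each launch) lets every group's all_gather
+    # overlap with the Newton-Schulz compute of the groups processed after it.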
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + #t0 = time.perf_counter() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + #t1 = time.perf_counter() + #print(f'{t1-t0} slowload') + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + def __init__(self, file_iter, world_size): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences 
truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        # Scanning a whole shard for BOS tokens up front is slow, so BOSFinder quickloads the
+        # first few million tokens and finishes the scan on a background thread, while the
+        # DataPreloader readies the next shard.
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted: swap in the preloaded shard and start preloading the next one.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data?
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1660 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"data_threading_1660/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer is highly indifferent to length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# re-initialize the state saved above so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+model.yarn.reset()
+ws = args.ws_schedule[0]
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+train_steps = args.num_iterations
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws != ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
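+            # get_lr above holds the LR at 1.0x for the first half of training, then cools it
+            # down linearly to 0.1x of the initial LR over the final cooldown_frac of steps.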
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 04:01:46 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 39C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 127W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 200415 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 200416 C /usr/bin/python3 614MiB | +| 0 N/A N/A 200417 C /usr/bin/python3 614MiB | +| 0 N/A N/A 200418 C /usr/bin/python3 614MiB | +| 0 N/A N/A 200419 C /usr/bin/python3 614MiB | +| 0 N/A N/A 200420 C /usr/bin/python3 614MiB | +| 0 N/A N/A 200421 C /usr/bin/python3 614MiB | +| 0 N/A N/A 200422 C /usr/bin/python3 614MiB | +| 1 N/A N/A 200416 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 200417 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 200418 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 200419 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 200420 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 200421 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 200422 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:146ms step_avg:145.85ms +step:2/1660 train_time:171ms step_avg:85.47ms +step:3/1660 train_time:234ms step_avg:77.90ms +step:4/1660 train_time:323ms step_avg:80.80ms +step:5/1660 train_time:414ms step_avg:82.73ms +step:6/1660 train_time:504ms step_avg:84.07ms +step:7/1660 train_time:595ms step_avg:84.98ms +step:8/1660 train_time:685ms step_avg:85.65ms +step:9/1660 train_time:776ms step_avg:86.22ms +step:10/1660 train_time:868ms step_avg:86.78ms +step:11/1660 train_time:959ms step_avg:87.16ms +step:12/1660 train_time:1054ms step_avg:87.79ms +step:13/1660 train_time:1149ms step_avg:88.40ms +step:14/1660 train_time:1242ms step_avg:88.72ms +step:15/1660 train_time:1333ms step_avg:88.87ms +step:16/1660 train_time:1424ms step_avg:88.99ms +step:17/1660 train_time:1515ms step_avg:89.12ms +step:18/1660 train_time:1607ms step_avg:89.27ms +step:19/1660 train_time:1698ms step_avg:89.36ms +step:20/1660 train_time:1789ms step_avg:89.44ms +step:21/1660 train_time:1880ms step_avg:89.50ms +step:22/1660 train_time:1972ms step_avg:89.65ms +step:23/1660 train_time:2066ms step_avg:89.83ms +step:24/1660 train_time:2160ms step_avg:89.99ms +step:25/1660 train_time:2253ms step_avg:90.11ms +step:26/1660 train_time:2346ms step_avg:90.23ms +step:27/1660 train_time:2437ms step_avg:90.27ms +step:28/1660 train_time:2530ms step_avg:90.35ms +step:29/1660 train_time:2621ms step_avg:90.38ms +step:30/1660 train_time:2713ms step_avg:90.42ms +step:31/1660 train_time:2804ms step_avg:90.44ms +step:32/1660 train_time:2896ms step_avg:90.50ms +step:33/1660 train_time:2988ms step_avg:90.55ms +step:34/1660 train_time:3081ms step_avg:90.61ms +step:35/1660 train_time:3173ms step_avg:90.66ms +step:36/1660 train_time:3264ms step_avg:90.68ms +step:37/1660 train_time:3356ms step_avg:90.71ms +step:38/1660 train_time:3448ms step_avg:90.73ms +step:39/1660 train_time:3539ms step_avg:90.74ms +step:40/1660 train_time:3630ms step_avg:90.76ms +step:41/1660 train_time:3722ms step_avg:90.77ms +step:42/1660 train_time:3814ms step_avg:90.81ms +step:43/1660 train_time:3906ms step_avg:90.84ms +step:44/1660 train_time:3997ms step_avg:90.84ms +step:45/1660 train_time:4089ms step_avg:90.87ms +step:46/1660 train_time:4182ms step_avg:90.91ms +step:47/1660 train_time:4274ms step_avg:90.94ms +step:48/1660 train_time:4366ms step_avg:90.96ms +step:49/1660 train_time:4458ms step_avg:90.97ms +step:50/1660 train_time:4549ms step_avg:90.98ms 
+step:51/1660 train_time:4640ms step_avg:90.99ms +step:52/1660 train_time:4732ms step_avg:91.01ms +step:53/1660 train_time:4824ms step_avg:91.02ms +step:54/1660 train_time:4916ms step_avg:91.03ms +step:55/1660 train_time:5007ms step_avg:91.04ms +step:56/1660 train_time:5099ms step_avg:91.06ms +step:57/1660 train_time:5193ms step_avg:91.10ms +step:58/1660 train_time:5285ms step_avg:91.12ms +step:59/1660 train_time:5376ms step_avg:91.12ms +step:60/1660 train_time:5467ms step_avg:91.12ms +step:61/1660 train_time:5559ms step_avg:91.13ms +step:62/1660 train_time:5650ms step_avg:91.13ms +step:63/1660 train_time:5742ms step_avg:91.14ms +step:64/1660 train_time:5833ms step_avg:91.15ms +step:65/1660 train_time:5925ms step_avg:91.15ms +step:66/1660 train_time:6016ms step_avg:91.15ms +step:67/1660 train_time:6108ms step_avg:91.17ms +step:68/1660 train_time:6201ms step_avg:91.19ms +step:69/1660 train_time:6294ms step_avg:91.22ms +step:70/1660 train_time:6386ms step_avg:91.23ms +step:71/1660 train_time:6478ms step_avg:91.24ms +step:72/1660 train_time:6570ms step_avg:91.25ms +step:73/1660 train_time:6662ms step_avg:91.26ms +step:74/1660 train_time:6754ms step_avg:91.27ms +step:75/1660 train_time:6847ms step_avg:91.29ms +step:76/1660 train_time:6938ms step_avg:91.29ms +step:77/1660 train_time:7029ms step_avg:91.29ms +step:78/1660 train_time:7121ms step_avg:91.29ms +step:79/1660 train_time:7214ms step_avg:91.31ms +step:80/1660 train_time:7307ms step_avg:91.33ms +step:81/1660 train_time:7398ms step_avg:91.33ms +step:82/1660 train_time:7490ms step_avg:91.35ms +step:83/1660 train_time:7582ms step_avg:91.35ms +step:84/1660 train_time:7673ms step_avg:91.35ms +step:85/1660 train_time:7764ms step_avg:91.35ms +step:86/1660 train_time:7856ms step_avg:91.35ms +step:87/1660 train_time:7948ms step_avg:91.35ms +step:88/1660 train_time:8039ms step_avg:91.35ms +step:89/1660 train_time:8130ms step_avg:91.35ms +step:90/1660 train_time:8221ms step_avg:91.35ms +step:91/1660 train_time:8313ms step_avg:91.36ms +step:92/1660 train_time:8405ms step_avg:91.36ms +step:93/1660 train_time:8496ms step_avg:91.36ms +step:94/1660 train_time:8589ms step_avg:91.37ms +step:95/1660 train_time:8680ms step_avg:91.37ms +step:96/1660 train_time:8772ms step_avg:91.38ms +step:97/1660 train_time:8864ms step_avg:91.38ms +step:98/1660 train_time:8956ms step_avg:91.38ms +step:99/1660 train_time:9047ms step_avg:91.39ms +step:100/1660 train_time:9138ms step_avg:91.38ms +step:101/1660 train_time:9229ms step_avg:91.38ms +step:102/1660 train_time:9320ms step_avg:91.38ms +step:103/1660 train_time:9413ms step_avg:91.39ms +step:104/1660 train_time:9506ms step_avg:91.40ms +step:105/1660 train_time:9597ms step_avg:91.40ms +step:106/1660 train_time:9688ms step_avg:91.40ms +step:107/1660 train_time:9779ms step_avg:91.40ms +step:108/1660 train_time:9871ms step_avg:91.40ms +step:109/1660 train_time:9962ms step_avg:91.40ms +step:110/1660 train_time:10054ms step_avg:91.40ms +step:111/1660 train_time:10147ms step_avg:91.42ms +step:112/1660 train_time:10238ms step_avg:91.41ms +step:113/1660 train_time:10329ms step_avg:91.41ms +step:114/1660 train_time:10421ms step_avg:91.41ms +step:115/1660 train_time:10515ms step_avg:91.43ms +step:116/1660 train_time:10606ms step_avg:91.43ms +step:117/1660 train_time:10697ms step_avg:91.43ms +step:118/1660 train_time:10789ms step_avg:91.43ms +step:119/1660 train_time:10881ms step_avg:91.43ms +step:120/1660 train_time:10972ms step_avg:91.44ms +step:121/1660 train_time:11064ms step_avg:91.44ms +step:122/1660 train_time:11155ms 
step_avg:91.44ms +step:123/1660 train_time:11247ms step_avg:91.44ms +step:124/1660 train_time:11338ms step_avg:91.44ms +step:125/1660 train_time:11430ms step_avg:91.44ms +step:125/1660 val_loss:4.3006 train_time:11523ms step_avg:92.18ms +step:126/1660 train_time:11546ms step_avg:91.63ms +step:127/1660 train_time:11617ms step_avg:91.47ms +step:128/1660 train_time:11718ms step_avg:91.55ms +step:129/1660 train_time:11814ms step_avg:91.58ms +step:130/1660 train_time:11905ms step_avg:91.58ms +step:131/1660 train_time:11997ms step_avg:91.58ms +step:132/1660 train_time:12087ms step_avg:91.57ms +step:133/1660 train_time:12178ms step_avg:91.56ms +step:134/1660 train_time:12268ms step_avg:91.55ms +step:135/1660 train_time:12359ms step_avg:91.55ms +step:136/1660 train_time:12449ms step_avg:91.54ms +step:137/1660 train_time:12540ms step_avg:91.53ms +step:138/1660 train_time:12633ms step_avg:91.55ms +step:139/1660 train_time:12728ms step_avg:91.57ms +step:140/1660 train_time:12824ms step_avg:91.60ms +step:141/1660 train_time:12916ms step_avg:91.60ms +step:142/1660 train_time:13008ms step_avg:91.60ms +step:143/1660 train_time:13100ms step_avg:91.60ms +step:144/1660 train_time:13190ms step_avg:91.60ms +step:145/1660 train_time:13280ms step_avg:91.59ms +step:146/1660 train_time:13370ms step_avg:91.58ms +step:147/1660 train_time:13461ms step_avg:91.57ms +step:148/1660 train_time:13552ms step_avg:91.57ms +step:149/1660 train_time:13645ms step_avg:91.58ms +step:150/1660 train_time:13737ms step_avg:91.58ms +step:151/1660 train_time:13830ms step_avg:91.59ms +step:152/1660 train_time:13923ms step_avg:91.60ms +step:153/1660 train_time:14014ms step_avg:91.60ms +step:154/1660 train_time:14105ms step_avg:91.59ms +step:155/1660 train_time:14196ms step_avg:91.59ms +step:156/1660 train_time:14286ms step_avg:91.58ms +step:157/1660 train_time:14377ms step_avg:91.57ms +step:158/1660 train_time:14468ms step_avg:91.57ms +step:159/1660 train_time:14559ms step_avg:91.57ms +step:160/1660 train_time:14653ms step_avg:91.58ms +step:161/1660 train_time:14744ms step_avg:91.58ms +step:162/1660 train_time:14836ms step_avg:91.58ms +step:163/1660 train_time:14929ms step_avg:91.59ms +step:164/1660 train_time:15021ms step_avg:91.59ms +step:165/1660 train_time:15112ms step_avg:91.59ms +step:166/1660 train_time:15203ms step_avg:91.58ms +step:167/1660 train_time:15294ms step_avg:91.58ms +step:168/1660 train_time:15384ms step_avg:91.57ms +step:169/1660 train_time:15475ms step_avg:91.57ms +step:170/1660 train_time:15566ms step_avg:91.56ms +step:171/1660 train_time:15659ms step_avg:91.57ms +step:172/1660 train_time:15751ms step_avg:91.58ms +step:173/1660 train_time:15843ms step_avg:91.58ms +step:174/1660 train_time:15936ms step_avg:91.58ms +step:175/1660 train_time:16027ms step_avg:91.58ms +step:176/1660 train_time:16119ms step_avg:91.58ms +step:177/1660 train_time:16210ms step_avg:91.58ms +step:178/1660 train_time:16301ms step_avg:91.58ms +step:179/1660 train_time:16393ms step_avg:91.58ms +step:180/1660 train_time:16484ms step_avg:91.58ms +step:181/1660 train_time:16575ms step_avg:91.57ms +step:182/1660 train_time:16667ms step_avg:91.57ms +step:183/1660 train_time:16760ms step_avg:91.58ms +step:184/1660 train_time:16853ms step_avg:91.59ms +step:185/1660 train_time:16945ms step_avg:91.59ms +step:186/1660 train_time:17037ms step_avg:91.60ms +step:187/1660 train_time:17128ms step_avg:91.59ms +step:188/1660 train_time:17220ms step_avg:91.59ms +step:189/1660 train_time:17312ms step_avg:91.60ms +step:190/1660 train_time:17403ms step_avg:91.59ms 
+step:191/1660 train_time:17494ms step_avg:91.59ms +step:192/1660 train_time:17585ms step_avg:91.59ms +step:193/1660 train_time:17676ms step_avg:91.59ms +step:194/1660 train_time:17768ms step_avg:91.59ms +step:195/1660 train_time:17861ms step_avg:91.60ms +step:196/1660 train_time:17953ms step_avg:91.60ms +step:197/1660 train_time:18044ms step_avg:91.60ms +step:198/1660 train_time:18135ms step_avg:91.59ms +step:199/1660 train_time:18227ms step_avg:91.60ms +step:200/1660 train_time:18319ms step_avg:91.60ms +step:201/1660 train_time:18410ms step_avg:91.59ms +step:202/1660 train_time:18501ms step_avg:91.59ms +step:203/1660 train_time:18593ms step_avg:91.59ms +step:204/1660 train_time:18685ms step_avg:91.59ms +step:205/1660 train_time:18777ms step_avg:91.59ms +step:206/1660 train_time:18868ms step_avg:91.59ms +step:207/1660 train_time:18961ms step_avg:91.60ms +step:208/1660 train_time:19052ms step_avg:91.60ms +step:209/1660 train_time:19144ms step_avg:91.60ms +step:210/1660 train_time:19235ms step_avg:91.60ms +step:211/1660 train_time:19326ms step_avg:91.59ms +step:212/1660 train_time:19418ms step_avg:91.60ms +step:213/1660 train_time:19510ms step_avg:91.60ms +step:214/1660 train_time:19602ms step_avg:91.60ms +step:215/1660 train_time:19693ms step_avg:91.60ms +step:216/1660 train_time:19785ms step_avg:91.60ms +step:217/1660 train_time:19876ms step_avg:91.60ms +step:218/1660 train_time:19968ms step_avg:91.60ms +step:219/1660 train_time:20060ms step_avg:91.60ms +step:220/1660 train_time:20152ms step_avg:91.60ms +step:221/1660 train_time:20244ms step_avg:91.60ms +step:222/1660 train_time:20335ms step_avg:91.60ms +step:223/1660 train_time:20426ms step_avg:91.59ms +step:224/1660 train_time:20517ms step_avg:91.59ms +step:225/1660 train_time:20609ms step_avg:91.59ms +step:226/1660 train_time:20701ms step_avg:91.60ms +step:227/1660 train_time:20792ms step_avg:91.60ms +step:228/1660 train_time:20884ms step_avg:91.60ms +step:229/1660 train_time:20976ms step_avg:91.60ms +step:230/1660 train_time:21068ms step_avg:91.60ms +step:231/1660 train_time:21160ms step_avg:91.60ms +step:232/1660 train_time:21251ms step_avg:91.60ms +step:233/1660 train_time:21343ms step_avg:91.60ms +step:234/1660 train_time:21434ms step_avg:91.60ms +step:235/1660 train_time:21525ms step_avg:91.59ms +step:236/1660 train_time:21616ms step_avg:91.59ms +step:237/1660 train_time:21707ms step_avg:91.59ms +step:238/1660 train_time:21799ms step_avg:91.59ms +step:239/1660 train_time:21891ms step_avg:91.59ms +step:240/1660 train_time:21983ms step_avg:91.60ms +step:241/1660 train_time:22075ms step_avg:91.60ms +step:242/1660 train_time:22166ms step_avg:91.60ms +step:243/1660 train_time:22258ms step_avg:91.60ms +step:244/1660 train_time:22349ms step_avg:91.59ms +step:245/1660 train_time:22441ms step_avg:91.59ms +step:246/1660 train_time:22533ms step_avg:91.60ms +step:247/1660 train_time:22623ms step_avg:91.59ms +step:248/1660 train_time:22715ms step_avg:91.59ms +step:249/1660 train_time:22807ms step_avg:91.59ms +step:250/1660 train_time:22899ms step_avg:91.60ms +step:250/1660 val_loss:3.9630 train_time:22992ms step_avg:91.97ms +step:251/1660 train_time:23013ms step_avg:91.69ms +step:252/1660 train_time:23085ms step_avg:91.61ms +step:253/1660 train_time:23181ms step_avg:91.62ms +step:254/1660 train_time:23275ms step_avg:91.63ms +step:255/1660 train_time:23367ms step_avg:91.63ms +step:256/1660 train_time:23457ms step_avg:91.63ms +step:257/1660 train_time:23547ms step_avg:91.62ms +step:258/1660 train_time:23637ms step_avg:91.62ms +step:259/1660 
train_time:23728ms step_avg:91.61ms +step:260/1660 train_time:23818ms step_avg:91.61ms +step:261/1660 train_time:23910ms step_avg:91.61ms +step:262/1660 train_time:24002ms step_avg:91.61ms +step:263/1660 train_time:24096ms step_avg:91.62ms +step:264/1660 train_time:24190ms step_avg:91.63ms +step:265/1660 train_time:24283ms step_avg:91.63ms +step:266/1660 train_time:24375ms step_avg:91.63ms +step:267/1660 train_time:24466ms step_avg:91.63ms +step:268/1660 train_time:24556ms step_avg:91.63ms +step:269/1660 train_time:24648ms step_avg:91.63ms +step:270/1660 train_time:24738ms step_avg:91.62ms +step:271/1660 train_time:24829ms step_avg:91.62ms +step:272/1660 train_time:24920ms step_avg:91.62ms +step:273/1660 train_time:25011ms step_avg:91.62ms +step:274/1660 train_time:25104ms step_avg:91.62ms +step:275/1660 train_time:25196ms step_avg:91.62ms +step:276/1660 train_time:25289ms step_avg:91.63ms +step:277/1660 train_time:25381ms step_avg:91.63ms +step:278/1660 train_time:25473ms step_avg:91.63ms +step:279/1660 train_time:25563ms step_avg:91.62ms +step:280/1660 train_time:25653ms step_avg:91.62ms +step:281/1660 train_time:25744ms step_avg:91.61ms +step:282/1660 train_time:25835ms step_avg:91.61ms +step:283/1660 train_time:25926ms step_avg:91.61ms +step:284/1660 train_time:26017ms step_avg:91.61ms +step:285/1660 train_time:26110ms step_avg:91.61ms +step:286/1660 train_time:26202ms step_avg:91.62ms +step:287/1660 train_time:26294ms step_avg:91.62ms +step:288/1660 train_time:26386ms step_avg:91.62ms +step:289/1660 train_time:26478ms step_avg:91.62ms +step:290/1660 train_time:26569ms step_avg:91.62ms +step:291/1660 train_time:26660ms step_avg:91.62ms +step:292/1660 train_time:26751ms step_avg:91.61ms +step:293/1660 train_time:26842ms step_avg:91.61ms +step:294/1660 train_time:26933ms step_avg:91.61ms +step:295/1660 train_time:27024ms step_avg:91.61ms +step:296/1660 train_time:27116ms step_avg:91.61ms +step:297/1660 train_time:27211ms step_avg:91.62ms +step:298/1660 train_time:27304ms step_avg:91.62ms +step:299/1660 train_time:27395ms step_avg:91.62ms +step:300/1660 train_time:27486ms step_avg:91.62ms +step:301/1660 train_time:27577ms step_avg:91.62ms +step:302/1660 train_time:27669ms step_avg:91.62ms +step:303/1660 train_time:27759ms step_avg:91.62ms +step:304/1660 train_time:27851ms step_avg:91.61ms +step:305/1660 train_time:27942ms step_avg:91.61ms +step:306/1660 train_time:28034ms step_avg:91.61ms +step:307/1660 train_time:28126ms step_avg:91.62ms +step:308/1660 train_time:28219ms step_avg:91.62ms +step:309/1660 train_time:28311ms step_avg:91.62ms +step:310/1660 train_time:28403ms step_avg:91.62ms +step:311/1660 train_time:28495ms step_avg:91.62ms +step:312/1660 train_time:28586ms step_avg:91.62ms +step:313/1660 train_time:28677ms step_avg:91.62ms +step:314/1660 train_time:28769ms step_avg:91.62ms +step:315/1660 train_time:28860ms step_avg:91.62ms +step:316/1660 train_time:28951ms step_avg:91.62ms +step:317/1660 train_time:29043ms step_avg:91.62ms +step:318/1660 train_time:29135ms step_avg:91.62ms +step:319/1660 train_time:29227ms step_avg:91.62ms +step:320/1660 train_time:29319ms step_avg:91.62ms +step:321/1660 train_time:29412ms step_avg:91.63ms +step:322/1660 train_time:29504ms step_avg:91.63ms +step:323/1660 train_time:29595ms step_avg:91.63ms +step:324/1660 train_time:29686ms step_avg:91.62ms +step:325/1660 train_time:29777ms step_avg:91.62ms +step:326/1660 train_time:29869ms step_avg:91.62ms +step:327/1660 train_time:29960ms step_avg:91.62ms +step:328/1660 train_time:30052ms step_avg:91.62ms 
+step:329/1660 train_time:30143ms step_avg:91.62ms +step:330/1660 train_time:30235ms step_avg:91.62ms +step:331/1660 train_time:30327ms step_avg:91.62ms +step:332/1660 train_time:30419ms step_avg:91.62ms +step:333/1660 train_time:30511ms step_avg:91.62ms +step:334/1660 train_time:30602ms step_avg:91.62ms +step:335/1660 train_time:30694ms step_avg:91.62ms +step:336/1660 train_time:30785ms step_avg:91.62ms +step:337/1660 train_time:30876ms step_avg:91.62ms +step:338/1660 train_time:30967ms step_avg:91.62ms +step:339/1660 train_time:31058ms step_avg:91.62ms +step:340/1660 train_time:31150ms step_avg:91.62ms +step:341/1660 train_time:31242ms step_avg:91.62ms +step:342/1660 train_time:31334ms step_avg:91.62ms +step:343/1660 train_time:31425ms step_avg:91.62ms +step:344/1660 train_time:31517ms step_avg:91.62ms +step:345/1660 train_time:31610ms step_avg:91.62ms +step:346/1660 train_time:31701ms step_avg:91.62ms +step:347/1660 train_time:31794ms step_avg:91.62ms +step:348/1660 train_time:31885ms step_avg:91.62ms +step:349/1660 train_time:31977ms step_avg:91.62ms +step:350/1660 train_time:32068ms step_avg:91.62ms +step:351/1660 train_time:32158ms step_avg:91.62ms +step:352/1660 train_time:32250ms step_avg:91.62ms +step:353/1660 train_time:32342ms step_avg:91.62ms +step:354/1660 train_time:32435ms step_avg:91.62ms +step:355/1660 train_time:32526ms step_avg:91.62ms +step:356/1660 train_time:32617ms step_avg:91.62ms +step:357/1660 train_time:32709ms step_avg:91.62ms +step:358/1660 train_time:32800ms step_avg:91.62ms +step:359/1660 train_time:32891ms step_avg:91.62ms +step:360/1660 train_time:32982ms step_avg:91.62ms +step:361/1660 train_time:33074ms step_avg:91.62ms +step:362/1660 train_time:33165ms step_avg:91.62ms +step:363/1660 train_time:33256ms step_avg:91.61ms +step:364/1660 train_time:33347ms step_avg:91.61ms +step:365/1660 train_time:33439ms step_avg:91.61ms +step:366/1660 train_time:33531ms step_avg:91.61ms +step:367/1660 train_time:33622ms step_avg:91.61ms +step:368/1660 train_time:33713ms step_avg:91.61ms +step:369/1660 train_time:33805ms step_avg:91.61ms +step:370/1660 train_time:33897ms step_avg:91.61ms +step:371/1660 train_time:33989ms step_avg:91.61ms +step:372/1660 train_time:34080ms step_avg:91.61ms +step:373/1660 train_time:34172ms step_avg:91.61ms +step:374/1660 train_time:34263ms step_avg:91.61ms +step:375/1660 train_time:34357ms step_avg:91.62ms +step:375/1660 val_loss:3.8119 train_time:34450ms step_avg:91.87ms +step:376/1660 train_time:34472ms step_avg:91.68ms +step:377/1660 train_time:34544ms step_avg:91.63ms +step:378/1660 train_time:34638ms step_avg:91.64ms +step:379/1660 train_time:34729ms step_avg:91.63ms +step:380/1660 train_time:34820ms step_avg:91.63ms +step:381/1660 train_time:34911ms step_avg:91.63ms +step:382/1660 train_time:35001ms step_avg:91.63ms +step:383/1660 train_time:35092ms step_avg:91.62ms +step:384/1660 train_time:35183ms step_avg:91.62ms +step:385/1660 train_time:35274ms step_avg:91.62ms +step:386/1660 train_time:35365ms step_avg:91.62ms +step:387/1660 train_time:35458ms step_avg:91.62ms +step:388/1660 train_time:35551ms step_avg:91.63ms +step:389/1660 train_time:35644ms step_avg:91.63ms +step:390/1660 train_time:35736ms step_avg:91.63ms +step:391/1660 train_time:35826ms step_avg:91.63ms +step:392/1660 train_time:35917ms step_avg:91.63ms +step:393/1660 train_time:36008ms step_avg:91.62ms +step:394/1660 train_time:36099ms step_avg:91.62ms +step:395/1660 train_time:36189ms step_avg:91.62ms +step:396/1660 train_time:36280ms step_avg:91.62ms +step:397/1660 
train_time:36371ms step_avg:91.62ms +step:398/1660 train_time:36464ms step_avg:91.62ms +step:399/1660 train_time:36556ms step_avg:91.62ms +step:400/1660 train_time:36648ms step_avg:91.62ms +step:401/1660 train_time:36739ms step_avg:91.62ms +step:402/1660 train_time:36831ms step_avg:91.62ms +step:403/1660 train_time:36922ms step_avg:91.62ms +step:404/1660 train_time:37013ms step_avg:91.62ms +step:405/1660 train_time:37104ms step_avg:91.61ms +step:406/1660 train_time:37195ms step_avg:91.61ms +step:407/1660 train_time:37285ms step_avg:91.61ms +step:408/1660 train_time:37376ms step_avg:91.61ms +step:409/1660 train_time:37467ms step_avg:91.61ms +step:410/1660 train_time:37560ms step_avg:91.61ms +step:411/1660 train_time:37652ms step_avg:91.61ms +step:412/1660 train_time:37744ms step_avg:91.61ms +step:413/1660 train_time:37836ms step_avg:91.61ms +step:414/1660 train_time:37927ms step_avg:91.61ms +step:415/1660 train_time:38018ms step_avg:91.61ms +step:416/1660 train_time:38109ms step_avg:91.61ms +step:417/1660 train_time:38200ms step_avg:91.61ms +step:418/1660 train_time:38291ms step_avg:91.60ms +step:419/1660 train_time:38382ms step_avg:91.60ms +step:420/1660 train_time:38473ms step_avg:91.60ms +step:421/1660 train_time:38564ms step_avg:91.60ms +step:422/1660 train_time:38656ms step_avg:91.60ms +step:423/1660 train_time:38748ms step_avg:91.60ms +step:424/1660 train_time:38840ms step_avg:91.60ms +step:425/1660 train_time:38932ms step_avg:91.61ms +step:426/1660 train_time:39023ms step_avg:91.60ms +step:427/1660 train_time:39114ms step_avg:91.60ms +step:428/1660 train_time:39205ms step_avg:91.60ms +step:429/1660 train_time:39296ms step_avg:91.60ms +step:430/1660 train_time:39387ms step_avg:91.60ms +step:431/1660 train_time:39478ms step_avg:91.60ms +step:432/1660 train_time:39570ms step_avg:91.60ms +step:433/1660 train_time:39662ms step_avg:91.60ms +step:434/1660 train_time:39753ms step_avg:91.60ms +step:435/1660 train_time:39844ms step_avg:91.60ms +step:436/1660 train_time:39936ms step_avg:91.60ms +step:437/1660 train_time:40026ms step_avg:91.59ms +step:438/1660 train_time:40117ms step_avg:91.59ms +step:439/1660 train_time:40210ms step_avg:91.59ms +step:440/1660 train_time:40301ms step_avg:91.59ms +step:441/1660 train_time:40391ms step_avg:91.59ms +step:442/1660 train_time:40482ms step_avg:91.59ms +step:443/1660 train_time:40574ms step_avg:91.59ms +step:444/1660 train_time:40666ms step_avg:91.59ms +step:445/1660 train_time:40757ms step_avg:91.59ms +step:446/1660 train_time:40849ms step_avg:91.59ms +step:447/1660 train_time:40941ms step_avg:91.59ms +step:448/1660 train_time:41032ms step_avg:91.59ms +step:449/1660 train_time:41123ms step_avg:91.59ms +step:450/1660 train_time:41214ms step_avg:91.59ms +step:451/1660 train_time:41305ms step_avg:91.58ms +step:452/1660 train_time:41396ms step_avg:91.58ms +step:453/1660 train_time:41487ms step_avg:91.58ms +step:454/1660 train_time:41578ms step_avg:91.58ms +step:455/1660 train_time:41670ms step_avg:91.58ms +step:456/1660 train_time:41761ms step_avg:91.58ms +step:457/1660 train_time:41854ms step_avg:91.58ms +step:458/1660 train_time:41945ms step_avg:91.58ms +step:459/1660 train_time:42036ms step_avg:91.58ms +step:460/1660 train_time:42127ms step_avg:91.58ms +step:461/1660 train_time:42219ms step_avg:91.58ms +step:462/1660 train_time:42310ms step_avg:91.58ms +step:463/1660 train_time:42401ms step_avg:91.58ms +step:464/1660 train_time:42493ms step_avg:91.58ms +step:465/1660 train_time:42584ms step_avg:91.58ms +step:466/1660 train_time:42675ms step_avg:91.58ms 
+step:467/1660 train_time:42767ms step_avg:91.58ms +step:468/1660 train_time:42859ms step_avg:91.58ms +step:469/1660 train_time:42950ms step_avg:91.58ms +step:470/1660 train_time:43041ms step_avg:91.58ms +step:471/1660 train_time:43132ms step_avg:91.57ms +step:472/1660 train_time:43223ms step_avg:91.57ms +step:473/1660 train_time:43314ms step_avg:91.57ms +step:474/1660 train_time:43405ms step_avg:91.57ms +step:475/1660 train_time:43496ms step_avg:91.57ms +step:476/1660 train_time:43588ms step_avg:91.57ms +step:477/1660 train_time:43679ms step_avg:91.57ms +step:478/1660 train_time:43770ms step_avg:91.57ms +step:479/1660 train_time:43861ms step_avg:91.57ms +step:480/1660 train_time:43953ms step_avg:91.57ms +step:481/1660 train_time:44045ms step_avg:91.57ms +step:482/1660 train_time:44136ms step_avg:91.57ms +step:483/1660 train_time:44227ms step_avg:91.57ms +step:484/1660 train_time:44318ms step_avg:91.57ms +step:485/1660 train_time:44409ms step_avg:91.57ms +step:486/1660 train_time:44501ms step_avg:91.57ms +step:487/1660 train_time:44593ms step_avg:91.57ms +step:488/1660 train_time:44685ms step_avg:91.57ms +step:489/1660 train_time:44776ms step_avg:91.57ms +step:490/1660 train_time:44868ms step_avg:91.57ms +step:491/1660 train_time:44960ms step_avg:91.57ms +step:492/1660 train_time:45052ms step_avg:91.57ms +step:493/1660 train_time:45143ms step_avg:91.57ms +step:494/1660 train_time:45236ms step_avg:91.57ms +step:495/1660 train_time:45326ms step_avg:91.57ms +step:496/1660 train_time:45418ms step_avg:91.57ms +step:497/1660 train_time:45509ms step_avg:91.57ms +step:498/1660 train_time:45601ms step_avg:91.57ms +step:499/1660 train_time:45692ms step_avg:91.57ms +step:500/1660 train_time:45783ms step_avg:91.57ms +step:500/1660 val_loss:3.7131 train_time:45875ms step_avg:91.75ms +step:501/1660 train_time:45897ms step_avg:91.61ms +step:502/1660 train_time:45972ms step_avg:91.58ms +step:503/1660 train_time:46072ms step_avg:91.59ms +step:504/1660 train_time:46165ms step_avg:91.60ms +step:505/1660 train_time:46255ms step_avg:91.59ms +step:506/1660 train_time:46346ms step_avg:91.59ms +step:507/1660 train_time:46436ms step_avg:91.59ms +step:508/1660 train_time:46527ms step_avg:91.59ms +step:509/1660 train_time:46617ms step_avg:91.58ms +step:510/1660 train_time:46707ms step_avg:91.58ms +step:511/1660 train_time:46798ms step_avg:91.58ms +step:512/1660 train_time:46891ms step_avg:91.58ms +step:513/1660 train_time:46986ms step_avg:91.59ms +step:514/1660 train_time:47081ms step_avg:91.60ms +step:515/1660 train_time:47173ms step_avg:91.60ms +step:516/1660 train_time:47264ms step_avg:91.60ms +step:517/1660 train_time:47354ms step_avg:91.59ms +step:518/1660 train_time:47445ms step_avg:91.59ms +step:519/1660 train_time:47535ms step_avg:91.59ms +step:520/1660 train_time:47626ms step_avg:91.59ms +step:521/1660 train_time:47716ms step_avg:91.59ms +step:522/1660 train_time:47808ms step_avg:91.59ms +step:523/1660 train_time:47900ms step_avg:91.59ms +step:524/1660 train_time:47992ms step_avg:91.59ms +step:525/1660 train_time:48086ms step_avg:91.59ms +step:526/1660 train_time:48177ms step_avg:91.59ms +step:527/1660 train_time:48269ms step_avg:91.59ms +step:528/1660 train_time:48359ms step_avg:91.59ms +step:529/1660 train_time:48451ms step_avg:91.59ms +step:530/1660 train_time:48541ms step_avg:91.59ms +step:531/1660 train_time:48631ms step_avg:91.58ms +step:532/1660 train_time:48722ms step_avg:91.58ms +step:533/1660 train_time:48813ms step_avg:91.58ms +step:534/1660 train_time:48905ms step_avg:91.58ms +step:535/1660 
train_time:48998ms step_avg:91.58ms +step:536/1660 train_time:49090ms step_avg:91.59ms +step:537/1660 train_time:49182ms step_avg:91.59ms +step:538/1660 train_time:49273ms step_avg:91.59ms +step:539/1660 train_time:49365ms step_avg:91.59ms +step:540/1660 train_time:49456ms step_avg:91.59ms +step:541/1660 train_time:49547ms step_avg:91.58ms +step:542/1660 train_time:49638ms step_avg:91.58ms +step:543/1660 train_time:49729ms step_avg:91.58ms +step:544/1660 train_time:49820ms step_avg:91.58ms +step:545/1660 train_time:49911ms step_avg:91.58ms +step:546/1660 train_time:50004ms step_avg:91.58ms +step:547/1660 train_time:50096ms step_avg:91.58ms +step:548/1660 train_time:50188ms step_avg:91.58ms +step:549/1660 train_time:50279ms step_avg:91.58ms +step:550/1660 train_time:50370ms step_avg:91.58ms +step:551/1660 train_time:50462ms step_avg:91.58ms +step:552/1660 train_time:50553ms step_avg:91.58ms +step:553/1660 train_time:50643ms step_avg:91.58ms +step:554/1660 train_time:50734ms step_avg:91.58ms +step:555/1660 train_time:50825ms step_avg:91.58ms +step:556/1660 train_time:50917ms step_avg:91.58ms +step:557/1660 train_time:51011ms step_avg:91.58ms +step:558/1660 train_time:51103ms step_avg:91.58ms +step:559/1660 train_time:51196ms step_avg:91.58ms +step:560/1660 train_time:51289ms step_avg:91.59ms +step:561/1660 train_time:51381ms step_avg:91.59ms +step:562/1660 train_time:51473ms step_avg:91.59ms +step:563/1660 train_time:51565ms step_avg:91.59ms +step:564/1660 train_time:51657ms step_avg:91.59ms +step:565/1660 train_time:51749ms step_avg:91.59ms +step:566/1660 train_time:51842ms step_avg:91.59ms +step:567/1660 train_time:51935ms step_avg:91.60ms +step:568/1660 train_time:52028ms step_avg:91.60ms +step:569/1660 train_time:52121ms step_avg:91.60ms +step:570/1660 train_time:52213ms step_avg:91.60ms +step:571/1660 train_time:52307ms step_avg:91.61ms +step:572/1660 train_time:52400ms step_avg:91.61ms +step:573/1660 train_time:52492ms step_avg:91.61ms +step:574/1660 train_time:52585ms step_avg:91.61ms +step:575/1660 train_time:52677ms step_avg:91.61ms +step:576/1660 train_time:52769ms step_avg:91.61ms +step:577/1660 train_time:52862ms step_avg:91.61ms +step:578/1660 train_time:52954ms step_avg:91.62ms +step:579/1660 train_time:53048ms step_avg:91.62ms +step:580/1660 train_time:53140ms step_avg:91.62ms +step:581/1660 train_time:53233ms step_avg:91.62ms +step:582/1660 train_time:53327ms step_avg:91.63ms +step:583/1660 train_time:53420ms step_avg:91.63ms +step:584/1660 train_time:53512ms step_avg:91.63ms +step:585/1660 train_time:53606ms step_avg:91.63ms +step:586/1660 train_time:53698ms step_avg:91.64ms +step:587/1660 train_time:53790ms step_avg:91.64ms +step:588/1660 train_time:53882ms step_avg:91.64ms +step:589/1660 train_time:53974ms step_avg:91.64ms +step:590/1660 train_time:54067ms step_avg:91.64ms +step:591/1660 train_time:54160ms step_avg:91.64ms +step:592/1660 train_time:54253ms step_avg:91.64ms +step:593/1660 train_time:54346ms step_avg:91.65ms +step:594/1660 train_time:54438ms step_avg:91.65ms +step:595/1660 train_time:54531ms step_avg:91.65ms +step:596/1660 train_time:54624ms step_avg:91.65ms +step:597/1660 train_time:54716ms step_avg:91.65ms +step:598/1660 train_time:54809ms step_avg:91.65ms +step:599/1660 train_time:54902ms step_avg:91.66ms +step:600/1660 train_time:54995ms step_avg:91.66ms +step:601/1660 train_time:55088ms step_avg:91.66ms +step:602/1660 train_time:55181ms step_avg:91.66ms +step:603/1660 train_time:55273ms step_avg:91.66ms +step:604/1660 train_time:55366ms step_avg:91.67ms 
+step:605/1660 train_time:55458ms step_avg:91.67ms +step:606/1660 train_time:55551ms step_avg:91.67ms +step:607/1660 train_time:55644ms step_avg:91.67ms +step:608/1660 train_time:55736ms step_avg:91.67ms +step:609/1660 train_time:55829ms step_avg:91.67ms +step:610/1660 train_time:55921ms step_avg:91.67ms +step:611/1660 train_time:56013ms step_avg:91.67ms +step:612/1660 train_time:56107ms step_avg:91.68ms +step:613/1660 train_time:56199ms step_avg:91.68ms +step:614/1660 train_time:56292ms step_avg:91.68ms +step:615/1660 train_time:56385ms step_avg:91.68ms +step:616/1660 train_time:56478ms step_avg:91.68ms +step:617/1660 train_time:56570ms step_avg:91.69ms +step:618/1660 train_time:56663ms step_avg:91.69ms +step:619/1660 train_time:56755ms step_avg:91.69ms +step:620/1660 train_time:56847ms step_avg:91.69ms +step:621/1660 train_time:56939ms step_avg:91.69ms +step:622/1660 train_time:57032ms step_avg:91.69ms +step:623/1660 train_time:57125ms step_avg:91.69ms +step:624/1660 train_time:57218ms step_avg:91.70ms +step:625/1660 train_time:57310ms step_avg:91.70ms +step:625/1660 val_loss:3.6096 train_time:57404ms step_avg:91.85ms +step:626/1660 train_time:57425ms step_avg:91.73ms +step:627/1660 train_time:57500ms step_avg:91.71ms +step:628/1660 train_time:57600ms step_avg:91.72ms +step:629/1660 train_time:57695ms step_avg:91.73ms +step:630/1660 train_time:57789ms step_avg:91.73ms +step:631/1660 train_time:57881ms step_avg:91.73ms +step:632/1660 train_time:57971ms step_avg:91.73ms +step:633/1660 train_time:58063ms step_avg:91.73ms +step:634/1660 train_time:58154ms step_avg:91.73ms +step:635/1660 train_time:58246ms step_avg:91.73ms +step:636/1660 train_time:58338ms step_avg:91.73ms +step:637/1660 train_time:58431ms step_avg:91.73ms +step:638/1660 train_time:58527ms step_avg:91.73ms +step:639/1660 train_time:58622ms step_avg:91.74ms +step:640/1660 train_time:58716ms step_avg:91.74ms +step:641/1660 train_time:58809ms step_avg:91.75ms +step:642/1660 train_time:58902ms step_avg:91.75ms +step:643/1660 train_time:58994ms step_avg:91.75ms +step:644/1660 train_time:59085ms step_avg:91.75ms +step:645/1660 train_time:59177ms step_avg:91.75ms +step:646/1660 train_time:59269ms step_avg:91.75ms +step:647/1660 train_time:59361ms step_avg:91.75ms +step:648/1660 train_time:59454ms step_avg:91.75ms +step:649/1660 train_time:59549ms step_avg:91.76ms +step:650/1660 train_time:59643ms step_avg:91.76ms +step:651/1660 train_time:59736ms step_avg:91.76ms +step:652/1660 train_time:59830ms step_avg:91.76ms +step:653/1660 train_time:59923ms step_avg:91.77ms +step:654/1660 train_time:60016ms step_avg:91.77ms +step:655/1660 train_time:60107ms step_avg:91.77ms +step:656/1660 train_time:60199ms step_avg:91.77ms +step:657/1660 train_time:60292ms step_avg:91.77ms +step:658/1660 train_time:60384ms step_avg:91.77ms +step:659/1660 train_time:60477ms step_avg:91.77ms +step:660/1660 train_time:60571ms step_avg:91.77ms +step:661/1660 train_time:60664ms step_avg:91.78ms +step:662/1660 train_time:60757ms step_avg:91.78ms +step:663/1660 train_time:60851ms step_avg:91.78ms +step:664/1660 train_time:60944ms step_avg:91.78ms +step:665/1660 train_time:61036ms step_avg:91.78ms +step:666/1660 train_time:61129ms step_avg:91.78ms +step:667/1660 train_time:61221ms step_avg:91.79ms +step:668/1660 train_time:61313ms step_avg:91.79ms +step:669/1660 train_time:61405ms step_avg:91.79ms +step:670/1660 train_time:61498ms step_avg:91.79ms +step:671/1660 train_time:61592ms step_avg:91.79ms +step:672/1660 train_time:61685ms step_avg:91.79ms +step:673/1660 
train_time:61777ms step_avg:91.79ms +step:674/1660 train_time:61870ms step_avg:91.80ms +step:675/1660 train_time:61962ms step_avg:91.80ms +step:676/1660 train_time:62055ms step_avg:91.80ms +step:677/1660 train_time:62148ms step_avg:91.80ms +step:678/1660 train_time:62239ms step_avg:91.80ms +step:679/1660 train_time:62331ms step_avg:91.80ms +step:680/1660 train_time:62423ms step_avg:91.80ms +step:681/1660 train_time:62516ms step_avg:91.80ms +step:682/1660 train_time:62609ms step_avg:91.80ms +step:683/1660 train_time:62702ms step_avg:91.80ms +step:684/1660 train_time:62794ms step_avg:91.80ms +step:685/1660 train_time:62887ms step_avg:91.81ms +step:686/1660 train_time:62980ms step_avg:91.81ms +step:687/1660 train_time:63072ms step_avg:91.81ms +step:688/1660 train_time:63163ms step_avg:91.81ms +step:689/1660 train_time:63255ms step_avg:91.81ms +step:690/1660 train_time:63348ms step_avg:91.81ms +step:691/1660 train_time:63441ms step_avg:91.81ms +step:692/1660 train_time:63533ms step_avg:91.81ms +step:693/1660 train_time:63626ms step_avg:91.81ms +step:694/1660 train_time:63719ms step_avg:91.81ms +step:695/1660 train_time:63812ms step_avg:91.82ms +step:696/1660 train_time:63905ms step_avg:91.82ms +step:697/1660 train_time:63998ms step_avg:91.82ms +step:698/1660 train_time:64091ms step_avg:91.82ms +step:699/1660 train_time:64184ms step_avg:91.82ms +step:700/1660 train_time:64276ms step_avg:91.82ms +step:701/1660 train_time:64370ms step_avg:91.83ms +step:702/1660 train_time:64463ms step_avg:91.83ms +step:703/1660 train_time:64555ms step_avg:91.83ms +step:704/1660 train_time:64648ms step_avg:91.83ms +step:705/1660 train_time:64740ms step_avg:91.83ms +step:706/1660 train_time:64834ms step_avg:91.83ms +step:707/1660 train_time:64927ms step_avg:91.83ms +step:708/1660 train_time:65018ms step_avg:91.83ms +step:709/1660 train_time:65111ms step_avg:91.84ms +step:710/1660 train_time:65203ms step_avg:91.84ms +step:711/1660 train_time:65295ms step_avg:91.84ms +step:712/1660 train_time:65388ms step_avg:91.84ms +step:713/1660 train_time:65481ms step_avg:91.84ms +step:714/1660 train_time:65573ms step_avg:91.84ms +step:715/1660 train_time:65665ms step_avg:91.84ms +step:716/1660 train_time:65758ms step_avg:91.84ms +step:717/1660 train_time:65851ms step_avg:91.84ms +step:718/1660 train_time:65943ms step_avg:91.84ms +step:719/1660 train_time:66036ms step_avg:91.84ms +step:720/1660 train_time:66129ms step_avg:91.85ms +step:721/1660 train_time:66221ms step_avg:91.85ms +step:722/1660 train_time:66314ms step_avg:91.85ms +step:723/1660 train_time:66406ms step_avg:91.85ms +step:724/1660 train_time:66498ms step_avg:91.85ms +step:725/1660 train_time:66591ms step_avg:91.85ms +step:726/1660 train_time:66684ms step_avg:91.85ms +step:727/1660 train_time:66777ms step_avg:91.85ms +step:728/1660 train_time:66870ms step_avg:91.85ms +step:729/1660 train_time:66963ms step_avg:91.86ms +step:730/1660 train_time:67056ms step_avg:91.86ms +step:731/1660 train_time:67148ms step_avg:91.86ms +step:732/1660 train_time:67241ms step_avg:91.86ms +step:733/1660 train_time:67333ms step_avg:91.86ms +step:734/1660 train_time:67426ms step_avg:91.86ms +step:735/1660 train_time:67518ms step_avg:91.86ms +step:736/1660 train_time:67611ms step_avg:91.86ms +step:737/1660 train_time:67703ms step_avg:91.86ms +step:738/1660 train_time:67796ms step_avg:91.86ms +step:739/1660 train_time:67890ms step_avg:91.87ms +step:740/1660 train_time:67983ms step_avg:91.87ms +step:741/1660 train_time:68075ms step_avg:91.87ms +step:742/1660 train_time:68167ms step_avg:91.87ms 
+step:743/1660 train_time:68260ms step_avg:91.87ms +step:744/1660 train_time:68352ms step_avg:91.87ms +step:745/1660 train_time:68445ms step_avg:91.87ms +step:746/1660 train_time:68538ms step_avg:91.87ms +step:747/1660 train_time:68630ms step_avg:91.87ms +step:748/1660 train_time:68723ms step_avg:91.88ms +step:749/1660 train_time:68816ms step_avg:91.88ms +step:750/1660 train_time:68909ms step_avg:91.88ms +step:750/1660 val_loss:3.5605 train_time:69003ms step_avg:92.00ms +step:751/1660 train_time:69024ms step_avg:91.91ms +step:752/1660 train_time:69100ms step_avg:91.89ms +step:753/1660 train_time:69200ms step_avg:91.90ms +step:754/1660 train_time:69292ms step_avg:91.90ms +step:755/1660 train_time:69383ms step_avg:91.90ms +step:756/1660 train_time:69475ms step_avg:91.90ms +step:757/1660 train_time:69566ms step_avg:91.90ms +step:758/1660 train_time:69658ms step_avg:91.90ms +step:759/1660 train_time:69749ms step_avg:91.90ms +step:760/1660 train_time:69840ms step_avg:91.89ms +step:761/1660 train_time:69932ms step_avg:91.89ms +step:762/1660 train_time:70027ms step_avg:91.90ms +step:763/1660 train_time:70122ms step_avg:91.90ms +step:764/1660 train_time:70217ms step_avg:91.91ms +step:765/1660 train_time:70310ms step_avg:91.91ms +step:766/1660 train_time:70403ms step_avg:91.91ms +step:767/1660 train_time:70495ms step_avg:91.91ms +step:768/1660 train_time:70587ms step_avg:91.91ms +step:769/1660 train_time:70679ms step_avg:91.91ms +step:770/1660 train_time:70770ms step_avg:91.91ms +step:771/1660 train_time:70862ms step_avg:91.91ms +step:772/1660 train_time:70954ms step_avg:91.91ms +step:773/1660 train_time:71049ms step_avg:91.91ms +step:774/1660 train_time:71144ms step_avg:91.92ms +step:775/1660 train_time:71238ms step_avg:91.92ms +step:776/1660 train_time:71330ms step_avg:91.92ms +step:777/1660 train_time:71422ms step_avg:91.92ms +step:778/1660 train_time:71514ms step_avg:91.92ms +step:779/1660 train_time:71607ms step_avg:91.92ms +step:780/1660 train_time:71699ms step_avg:91.92ms +step:781/1660 train_time:71791ms step_avg:91.92ms +step:782/1660 train_time:71883ms step_avg:91.92ms +step:783/1660 train_time:71976ms step_avg:91.92ms +step:784/1660 train_time:72071ms step_avg:91.93ms +step:785/1660 train_time:72165ms step_avg:91.93ms +step:786/1660 train_time:72258ms step_avg:91.93ms +step:787/1660 train_time:72351ms step_avg:91.93ms +step:788/1660 train_time:72443ms step_avg:91.93ms +step:789/1660 train_time:72535ms step_avg:91.93ms +step:790/1660 train_time:72629ms step_avg:91.93ms +step:791/1660 train_time:72720ms step_avg:91.93ms +step:792/1660 train_time:72812ms step_avg:91.93ms +step:793/1660 train_time:72904ms step_avg:91.93ms +step:794/1660 train_time:72996ms step_avg:91.93ms +step:795/1660 train_time:73089ms step_avg:91.94ms +step:796/1660 train_time:73183ms step_avg:91.94ms +step:797/1660 train_time:73276ms step_avg:91.94ms +step:798/1660 train_time:73368ms step_avg:91.94ms +step:799/1660 train_time:73460ms step_avg:91.94ms +step:800/1660 train_time:73552ms step_avg:91.94ms +step:801/1660 train_time:73646ms step_avg:91.94ms +step:802/1660 train_time:73738ms step_avg:91.94ms +step:803/1660 train_time:73830ms step_avg:91.94ms +step:804/1660 train_time:73923ms step_avg:91.94ms +step:805/1660 train_time:74017ms step_avg:91.95ms +step:806/1660 train_time:74111ms step_avg:91.95ms +step:807/1660 train_time:74204ms step_avg:91.95ms +step:808/1660 train_time:74296ms step_avg:91.95ms +step:809/1660 train_time:74389ms step_avg:91.95ms +step:810/1660 train_time:74481ms step_avg:91.95ms +step:811/1660 
train_time:74575ms step_avg:91.95ms +step:812/1660 train_time:74667ms step_avg:91.95ms +step:813/1660 train_time:74760ms step_avg:91.96ms +step:814/1660 train_time:74852ms step_avg:91.96ms +step:815/1660 train_time:74945ms step_avg:91.96ms +step:816/1660 train_time:75038ms step_avg:91.96ms +step:817/1660 train_time:75131ms step_avg:91.96ms +step:818/1660 train_time:75225ms step_avg:91.96ms +step:819/1660 train_time:75317ms step_avg:91.96ms +step:820/1660 train_time:75409ms step_avg:91.96ms +step:821/1660 train_time:75501ms step_avg:91.96ms +step:822/1660 train_time:75594ms step_avg:91.96ms +step:823/1660 train_time:75687ms step_avg:91.96ms +step:824/1660 train_time:75779ms step_avg:91.97ms +step:825/1660 train_time:75872ms step_avg:91.97ms +step:826/1660 train_time:75964ms step_avg:91.97ms +step:827/1660 train_time:76057ms step_avg:91.97ms +step:828/1660 train_time:76150ms step_avg:91.97ms +step:829/1660 train_time:76243ms step_avg:91.97ms +step:830/1660 train_time:76335ms step_avg:91.97ms +step:831/1660 train_time:76428ms step_avg:91.97ms +step:832/1660 train_time:76522ms step_avg:91.97ms +step:833/1660 train_time:76614ms step_avg:91.97ms +step:834/1660 train_time:76707ms step_avg:91.97ms +step:835/1660 train_time:76798ms step_avg:91.97ms +step:836/1660 train_time:76891ms step_avg:91.97ms +step:837/1660 train_time:76983ms step_avg:91.97ms +step:838/1660 train_time:77075ms step_avg:91.98ms +step:839/1660 train_time:77169ms step_avg:91.98ms +step:840/1660 train_time:77262ms step_avg:91.98ms +step:841/1660 train_time:77354ms step_avg:91.98ms +step:842/1660 train_time:77447ms step_avg:91.98ms +step:843/1660 train_time:77540ms step_avg:91.98ms +step:844/1660 train_time:77632ms step_avg:91.98ms +step:845/1660 train_time:77725ms step_avg:91.98ms +step:846/1660 train_time:77817ms step_avg:91.98ms +step:847/1660 train_time:77910ms step_avg:91.98ms +step:848/1660 train_time:78002ms step_avg:91.98ms +step:849/1660 train_time:78095ms step_avg:91.98ms +step:850/1660 train_time:78188ms step_avg:91.99ms +step:851/1660 train_time:78281ms step_avg:91.99ms +step:852/1660 train_time:78374ms step_avg:91.99ms +step:853/1660 train_time:78466ms step_avg:91.99ms +step:854/1660 train_time:78559ms step_avg:91.99ms +step:855/1660 train_time:78652ms step_avg:91.99ms +step:856/1660 train_time:78744ms step_avg:91.99ms +step:857/1660 train_time:78836ms step_avg:91.99ms +step:858/1660 train_time:78931ms step_avg:91.99ms +step:859/1660 train_time:79024ms step_avg:91.99ms +step:860/1660 train_time:79115ms step_avg:91.99ms +step:861/1660 train_time:79209ms step_avg:92.00ms +step:862/1660 train_time:79302ms step_avg:92.00ms +step:863/1660 train_time:79395ms step_avg:92.00ms +step:864/1660 train_time:79488ms step_avg:92.00ms +step:865/1660 train_time:79580ms step_avg:92.00ms +step:866/1660 train_time:79672ms step_avg:92.00ms +step:867/1660 train_time:79765ms step_avg:92.00ms +step:868/1660 train_time:79857ms step_avg:92.00ms +step:869/1660 train_time:79950ms step_avg:92.00ms +step:870/1660 train_time:80042ms step_avg:92.00ms +step:871/1660 train_time:80135ms step_avg:92.00ms +step:872/1660 train_time:80228ms step_avg:92.00ms +step:873/1660 train_time:80321ms step_avg:92.01ms +step:874/1660 train_time:80413ms step_avg:92.01ms +step:875/1660 train_time:80505ms step_avg:92.01ms +step:875/1660 val_loss:3.5144 train_time:80599ms step_avg:92.11ms +step:876/1660 train_time:80620ms step_avg:92.03ms +step:877/1660 train_time:80695ms step_avg:92.01ms +step:878/1660 train_time:80791ms step_avg:92.02ms +step:879/1660 train_time:80885ms 
step_avg:92.02ms +step:880/1660 train_time:80977ms step_avg:92.02ms +step:881/1660 train_time:81069ms step_avg:92.02ms +step:882/1660 train_time:81160ms step_avg:92.02ms +step:883/1660 train_time:81251ms step_avg:92.02ms +step:884/1660 train_time:81343ms step_avg:92.02ms +step:885/1660 train_time:81434ms step_avg:92.02ms +step:886/1660 train_time:81528ms step_avg:92.02ms +step:887/1660 train_time:81621ms step_avg:92.02ms +step:888/1660 train_time:81716ms step_avg:92.02ms +step:889/1660 train_time:81810ms step_avg:92.02ms +step:890/1660 train_time:81904ms step_avg:92.03ms +step:891/1660 train_time:81996ms step_avg:92.03ms +step:892/1660 train_time:82089ms step_avg:92.03ms +step:893/1660 train_time:82181ms step_avg:92.03ms +step:894/1660 train_time:82273ms step_avg:92.03ms +step:895/1660 train_time:82365ms step_avg:92.03ms +step:896/1660 train_time:82457ms step_avg:92.03ms +step:897/1660 train_time:82550ms step_avg:92.03ms +step:898/1660 train_time:82644ms step_avg:92.03ms +step:899/1660 train_time:82737ms step_avg:92.03ms +step:900/1660 train_time:82831ms step_avg:92.03ms +step:901/1660 train_time:82924ms step_avg:92.04ms +step:902/1660 train_time:83017ms step_avg:92.04ms +step:903/1660 train_time:83110ms step_avg:92.04ms +step:904/1660 train_time:83202ms step_avg:92.04ms +step:905/1660 train_time:83294ms step_avg:92.04ms +step:906/1660 train_time:83386ms step_avg:92.04ms +step:907/1660 train_time:83477ms step_avg:92.04ms +step:908/1660 train_time:83570ms step_avg:92.04ms +step:909/1660 train_time:83663ms step_avg:92.04ms +step:910/1660 train_time:83757ms step_avg:92.04ms +step:911/1660 train_time:83851ms step_avg:92.04ms +step:912/1660 train_time:83944ms step_avg:92.04ms +step:913/1660 train_time:84036ms step_avg:92.04ms +step:914/1660 train_time:84129ms step_avg:92.04ms +step:915/1660 train_time:84221ms step_avg:92.05ms +step:916/1660 train_time:84313ms step_avg:92.04ms +step:917/1660 train_time:84405ms step_avg:92.04ms +step:918/1660 train_time:84496ms step_avg:92.04ms +step:919/1660 train_time:84589ms step_avg:92.04ms +step:920/1660 train_time:84681ms step_avg:92.05ms +step:921/1660 train_time:84774ms step_avg:92.05ms +step:922/1660 train_time:84868ms step_avg:92.05ms +step:923/1660 train_time:84961ms step_avg:92.05ms +step:924/1660 train_time:85053ms step_avg:92.05ms +step:925/1660 train_time:85146ms step_avg:92.05ms +step:926/1660 train_time:85238ms step_avg:92.05ms +step:927/1660 train_time:85331ms step_avg:92.05ms +step:928/1660 train_time:85423ms step_avg:92.05ms +step:929/1660 train_time:85515ms step_avg:92.05ms +step:930/1660 train_time:85608ms step_avg:92.05ms +step:931/1660 train_time:85701ms step_avg:92.05ms +step:932/1660 train_time:85793ms step_avg:92.05ms +step:933/1660 train_time:85887ms step_avg:92.05ms +step:934/1660 train_time:85979ms step_avg:92.05ms +step:935/1660 train_time:86072ms step_avg:92.06ms +step:936/1660 train_time:86165ms step_avg:92.06ms +step:937/1660 train_time:86258ms step_avg:92.06ms +step:938/1660 train_time:86351ms step_avg:92.06ms +step:939/1660 train_time:86443ms step_avg:92.06ms +step:940/1660 train_time:86535ms step_avg:92.06ms +step:941/1660 train_time:86628ms step_avg:92.06ms +step:942/1660 train_time:86721ms step_avg:92.06ms +step:943/1660 train_time:86814ms step_avg:92.06ms +step:944/1660 train_time:86906ms step_avg:92.06ms +step:945/1660 train_time:86998ms step_avg:92.06ms +step:946/1660 train_time:87091ms step_avg:92.06ms +step:947/1660 train_time:87184ms step_avg:92.06ms +step:948/1660 train_time:87276ms step_avg:92.06ms +step:949/1660 
train_time:87369ms step_avg:92.06ms +step:950/1660 train_time:87462ms step_avg:92.07ms +step:951/1660 train_time:87555ms step_avg:92.07ms +step:952/1660 train_time:87647ms step_avg:92.07ms +step:953/1660 train_time:87741ms step_avg:92.07ms +step:954/1660 train_time:87833ms step_avg:92.07ms +step:955/1660 train_time:87925ms step_avg:92.07ms +step:956/1660 train_time:88019ms step_avg:92.07ms +step:957/1660 train_time:88112ms step_avg:92.07ms +step:958/1660 train_time:88204ms step_avg:92.07ms +step:959/1660 train_time:88297ms step_avg:92.07ms +step:960/1660 train_time:88390ms step_avg:92.07ms +step:961/1660 train_time:88482ms step_avg:92.07ms +step:962/1660 train_time:88575ms step_avg:92.07ms +step:963/1660 train_time:88668ms step_avg:92.07ms +step:964/1660 train_time:88761ms step_avg:92.08ms +step:965/1660 train_time:88854ms step_avg:92.08ms +step:966/1660 train_time:88947ms step_avg:92.08ms +step:967/1660 train_time:89040ms step_avg:92.08ms +step:968/1660 train_time:89133ms step_avg:92.08ms +step:969/1660 train_time:89226ms step_avg:92.08ms +step:970/1660 train_time:89319ms step_avg:92.08ms +step:971/1660 train_time:89412ms step_avg:92.08ms +step:972/1660 train_time:89504ms step_avg:92.08ms +step:973/1660 train_time:89596ms step_avg:92.08ms +step:974/1660 train_time:89689ms step_avg:92.08ms +step:975/1660 train_time:89780ms step_avg:92.08ms +step:976/1660 train_time:89872ms step_avg:92.08ms +step:977/1660 train_time:89965ms step_avg:92.08ms +step:978/1660 train_time:90058ms step_avg:92.08ms +step:979/1660 train_time:90151ms step_avg:92.08ms +step:980/1660 train_time:90244ms step_avg:92.09ms +step:981/1660 train_time:90336ms step_avg:92.09ms +step:982/1660 train_time:90430ms step_avg:92.09ms +step:983/1660 train_time:90522ms step_avg:92.09ms +step:984/1660 train_time:90615ms step_avg:92.09ms +step:985/1660 train_time:90708ms step_avg:92.09ms +step:986/1660 train_time:90801ms step_avg:92.09ms +step:987/1660 train_time:90893ms step_avg:92.09ms +step:988/1660 train_time:90985ms step_avg:92.09ms +step:989/1660 train_time:91078ms step_avg:92.09ms +step:990/1660 train_time:91171ms step_avg:92.09ms +step:991/1660 train_time:91264ms step_avg:92.09ms +step:992/1660 train_time:91356ms step_avg:92.09ms +step:993/1660 train_time:91449ms step_avg:92.09ms +step:994/1660 train_time:91542ms step_avg:92.09ms +step:995/1660 train_time:91634ms step_avg:92.09ms +step:996/1660 train_time:91727ms step_avg:92.09ms +step:997/1660 train_time:91819ms step_avg:92.09ms +step:998/1660 train_time:91911ms step_avg:92.10ms +step:999/1660 train_time:92004ms step_avg:92.10ms +step:1000/1660 train_time:92096ms step_avg:92.10ms +step:1000/1660 val_loss:3.4656 train_time:92191ms step_avg:92.19ms +step:1001/1660 train_time:92213ms step_avg:92.12ms +step:1002/1660 train_time:92291ms step_avg:92.11ms +step:1003/1660 train_time:92386ms step_avg:92.11ms +step:1004/1660 train_time:92480ms step_avg:92.11ms +step:1005/1660 train_time:92572ms step_avg:92.11ms +step:1006/1660 train_time:92663ms step_avg:92.11ms +step:1007/1660 train_time:92755ms step_avg:92.11ms +step:1008/1660 train_time:92846ms step_avg:92.11ms +step:1009/1660 train_time:92938ms step_avg:92.11ms +step:1010/1660 train_time:93029ms step_avg:92.11ms +step:1011/1660 train_time:93121ms step_avg:92.11ms +step:1012/1660 train_time:93217ms step_avg:92.11ms +step:1013/1660 train_time:93312ms step_avg:92.11ms +step:1014/1660 train_time:93407ms step_avg:92.12ms +step:1015/1660 train_time:93501ms step_avg:92.12ms +step:1016/1660 train_time:93592ms step_avg:92.12ms +step:1017/1660 
train_time:93684ms step_avg:92.12ms +step:1018/1660 train_time:93775ms step_avg:92.12ms +step:1019/1660 train_time:93867ms step_avg:92.12ms +step:1020/1660 train_time:93958ms step_avg:92.12ms +step:1021/1660 train_time:94050ms step_avg:92.12ms +step:1022/1660 train_time:94143ms step_avg:92.12ms +step:1023/1660 train_time:94237ms step_avg:92.12ms +step:1024/1660 train_time:94331ms step_avg:92.12ms +step:1025/1660 train_time:94425ms step_avg:92.12ms +step:1026/1660 train_time:94517ms step_avg:92.12ms +step:1027/1660 train_time:94610ms step_avg:92.12ms +step:1028/1660 train_time:94703ms step_avg:92.12ms +step:1029/1660 train_time:94795ms step_avg:92.12ms +step:1030/1660 train_time:94887ms step_avg:92.12ms +step:1031/1660 train_time:94978ms step_avg:92.12ms +step:1032/1660 train_time:95070ms step_avg:92.12ms +step:1033/1660 train_time:95163ms step_avg:92.12ms +step:1034/1660 train_time:95257ms step_avg:92.12ms +step:1035/1660 train_time:95350ms step_avg:92.13ms +step:1036/1660 train_time:95444ms step_avg:92.13ms +step:1037/1660 train_time:95537ms step_avg:92.13ms +step:1038/1660 train_time:95631ms step_avg:92.13ms +step:1039/1660 train_time:95724ms step_avg:92.13ms +step:1040/1660 train_time:95815ms step_avg:92.13ms +step:1041/1660 train_time:95909ms step_avg:92.13ms +step:1042/1660 train_time:96001ms step_avg:92.13ms +step:1043/1660 train_time:96093ms step_avg:92.13ms +step:1044/1660 train_time:96185ms step_avg:92.13ms +step:1045/1660 train_time:96279ms step_avg:92.13ms +step:1046/1660 train_time:96372ms step_avg:92.13ms +step:1047/1660 train_time:96466ms step_avg:92.14ms +step:1048/1660 train_time:96559ms step_avg:92.14ms +step:1049/1660 train_time:96652ms step_avg:92.14ms +step:1050/1660 train_time:96745ms step_avg:92.14ms +step:1051/1660 train_time:96837ms step_avg:92.14ms +step:1052/1660 train_time:96930ms step_avg:92.14ms +step:1053/1660 train_time:97022ms step_avg:92.14ms +step:1054/1660 train_time:97114ms step_avg:92.14ms +step:1055/1660 train_time:97206ms step_avg:92.14ms +step:1056/1660 train_time:97299ms step_avg:92.14ms +step:1057/1660 train_time:97393ms step_avg:92.14ms +step:1058/1660 train_time:97485ms step_avg:92.14ms +step:1059/1660 train_time:97578ms step_avg:92.14ms +step:1060/1660 train_time:97670ms step_avg:92.14ms +step:1061/1660 train_time:97763ms step_avg:92.14ms +step:1062/1660 train_time:97856ms step_avg:92.14ms +step:1063/1660 train_time:97949ms step_avg:92.14ms +step:1064/1660 train_time:98041ms step_avg:92.14ms +step:1065/1660 train_time:98133ms step_avg:92.14ms +step:1066/1660 train_time:98225ms step_avg:92.14ms +step:1067/1660 train_time:98318ms step_avg:92.14ms +step:1068/1660 train_time:98412ms step_avg:92.15ms +step:1069/1660 train_time:98505ms step_avg:92.15ms +step:1070/1660 train_time:98597ms step_avg:92.15ms +step:1071/1660 train_time:98690ms step_avg:92.15ms +step:1072/1660 train_time:98783ms step_avg:92.15ms +step:1073/1660 train_time:98875ms step_avg:92.15ms +step:1074/1660 train_time:98968ms step_avg:92.15ms +step:1075/1660 train_time:99061ms step_avg:92.15ms +step:1076/1660 train_time:99153ms step_avg:92.15ms +step:1077/1660 train_time:99245ms step_avg:92.15ms +step:1078/1660 train_time:99338ms step_avg:92.15ms +step:1079/1660 train_time:99431ms step_avg:92.15ms +step:1080/1660 train_time:99524ms step_avg:92.15ms +step:1081/1660 train_time:99617ms step_avg:92.15ms +step:1082/1660 train_time:99711ms step_avg:92.15ms +step:1083/1660 train_time:99803ms step_avg:92.15ms +step:1084/1660 train_time:99895ms step_avg:92.15ms +step:1085/1660 train_time:99988ms 
step_avg:92.15ms +step:1086/1660 train_time:100080ms step_avg:92.15ms +step:1087/1660 train_time:100172ms step_avg:92.15ms +step:1088/1660 train_time:100265ms step_avg:92.16ms +step:1089/1660 train_time:100357ms step_avg:92.16ms +step:1090/1660 train_time:100451ms step_avg:92.16ms +step:1091/1660 train_time:100543ms step_avg:92.16ms +step:1092/1660 train_time:100636ms step_avg:92.16ms +step:1093/1660 train_time:100730ms step_avg:92.16ms +step:1094/1660 train_time:100822ms step_avg:92.16ms +step:1095/1660 train_time:100914ms step_avg:92.16ms +step:1096/1660 train_time:101008ms step_avg:92.16ms +step:1097/1660 train_time:101100ms step_avg:92.16ms +step:1098/1660 train_time:101192ms step_avg:92.16ms +step:1099/1660 train_time:101285ms step_avg:92.16ms +step:1100/1660 train_time:101376ms step_avg:92.16ms +step:1101/1660 train_time:101470ms step_avg:92.16ms +step:1102/1660 train_time:101563ms step_avg:92.16ms +step:1103/1660 train_time:101656ms step_avg:92.16ms +step:1104/1660 train_time:101749ms step_avg:92.16ms +step:1105/1660 train_time:101841ms step_avg:92.16ms +step:1106/1660 train_time:101934ms step_avg:92.16ms +step:1107/1660 train_time:102027ms step_avg:92.17ms +step:1108/1660 train_time:102119ms step_avg:92.17ms +step:1109/1660 train_time:102212ms step_avg:92.17ms +step:1110/1660 train_time:102305ms step_avg:92.17ms +step:1111/1660 train_time:102399ms step_avg:92.17ms +step:1112/1660 train_time:102492ms step_avg:92.17ms +step:1113/1660 train_time:102586ms step_avg:92.17ms +step:1114/1660 train_time:102679ms step_avg:92.17ms +step:1115/1660 train_time:102774ms step_avg:92.17ms +step:1116/1660 train_time:102868ms step_avg:92.18ms +step:1117/1660 train_time:102961ms step_avg:92.18ms +step:1118/1660 train_time:103055ms step_avg:92.18ms +step:1119/1660 train_time:103149ms step_avg:92.18ms +step:1120/1660 train_time:103241ms step_avg:92.18ms +step:1121/1660 train_time:103335ms step_avg:92.18ms +step:1122/1660 train_time:103428ms step_avg:92.18ms +step:1123/1660 train_time:103521ms step_avg:92.18ms +step:1124/1660 train_time:103614ms step_avg:92.18ms +step:1125/1660 train_time:103707ms step_avg:92.18ms +step:1125/1660 val_loss:3.4124 train_time:103803ms step_avg:92.27ms +step:1126/1660 train_time:103824ms step_avg:92.21ms +step:1127/1660 train_time:103899ms step_avg:92.19ms +step:1128/1660 train_time:103999ms step_avg:92.20ms +step:1129/1660 train_time:104097ms step_avg:92.20ms +step:1130/1660 train_time:104190ms step_avg:92.20ms +step:1131/1660 train_time:104282ms step_avg:92.20ms +step:1132/1660 train_time:104374ms step_avg:92.20ms +step:1133/1660 train_time:104466ms step_avg:92.20ms +step:1134/1660 train_time:104559ms step_avg:92.20ms +step:1135/1660 train_time:104650ms step_avg:92.20ms +step:1136/1660 train_time:104742ms step_avg:92.20ms +step:1137/1660 train_time:104837ms step_avg:92.20ms +step:1138/1660 train_time:104935ms step_avg:92.21ms +step:1139/1660 train_time:105030ms step_avg:92.21ms +step:1140/1660 train_time:105124ms step_avg:92.21ms +step:1141/1660 train_time:105218ms step_avg:92.22ms +step:1142/1660 train_time:105310ms step_avg:92.22ms +step:1143/1660 train_time:105403ms step_avg:92.22ms +step:1144/1660 train_time:105495ms step_avg:92.22ms +step:1145/1660 train_time:105587ms step_avg:92.22ms +step:1146/1660 train_time:105679ms step_avg:92.22ms +step:1147/1660 train_time:105771ms step_avg:92.22ms +step:1148/1660 train_time:105866ms step_avg:92.22ms +step:1149/1660 train_time:105962ms step_avg:92.22ms +step:1150/1660 train_time:106056ms step_avg:92.22ms +step:1151/1660 
train_time:106150ms step_avg:92.22ms +step:1152/1660 train_time:106243ms step_avg:92.23ms +step:1153/1660 train_time:106336ms step_avg:92.23ms +step:1154/1660 train_time:106429ms step_avg:92.23ms +step:1155/1660 train_time:106523ms step_avg:92.23ms +step:1156/1660 train_time:106616ms step_avg:92.23ms +step:1157/1660 train_time:106708ms step_avg:92.23ms +step:1158/1660 train_time:106801ms step_avg:92.23ms +step:1159/1660 train_time:106895ms step_avg:92.23ms +step:1160/1660 train_time:106989ms step_avg:92.23ms +step:1161/1660 train_time:107085ms step_avg:92.24ms +step:1162/1660 train_time:107179ms step_avg:92.24ms +step:1163/1660 train_time:107272ms step_avg:92.24ms +step:1164/1660 train_time:107365ms step_avg:92.24ms +step:1165/1660 train_time:107459ms step_avg:92.24ms +step:1166/1660 train_time:107552ms step_avg:92.24ms +step:1167/1660 train_time:107644ms step_avg:92.24ms +step:1168/1660 train_time:107737ms step_avg:92.24ms +step:1169/1660 train_time:107830ms step_avg:92.24ms +step:1170/1660 train_time:107925ms step_avg:92.24ms +step:1171/1660 train_time:108019ms step_avg:92.24ms +step:1172/1660 train_time:108113ms step_avg:92.25ms +step:1173/1660 train_time:108207ms step_avg:92.25ms +step:1174/1660 train_time:108300ms step_avg:92.25ms +step:1175/1660 train_time:108393ms step_avg:92.25ms +step:1176/1660 train_time:108485ms step_avg:92.25ms +step:1177/1660 train_time:108578ms step_avg:92.25ms +step:1178/1660 train_time:108672ms step_avg:92.25ms +step:1179/1660 train_time:108765ms step_avg:92.25ms +step:1180/1660 train_time:108857ms step_avg:92.25ms +step:1181/1660 train_time:108950ms step_avg:92.25ms +step:1182/1660 train_time:109044ms step_avg:92.25ms +step:1183/1660 train_time:109138ms step_avg:92.26ms +step:1184/1660 train_time:109231ms step_avg:92.26ms +step:1185/1660 train_time:109325ms step_avg:92.26ms +step:1186/1660 train_time:109418ms step_avg:92.26ms +step:1187/1660 train_time:109510ms step_avg:92.26ms +step:1188/1660 train_time:109605ms step_avg:92.26ms +step:1189/1660 train_time:109700ms step_avg:92.26ms +step:1190/1660 train_time:109792ms step_avg:92.26ms +step:1191/1660 train_time:109886ms step_avg:92.26ms +step:1192/1660 train_time:109979ms step_avg:92.26ms +step:1193/1660 train_time:110072ms step_avg:92.26ms +step:1194/1660 train_time:110165ms step_avg:92.27ms +step:1195/1660 train_time:110258ms step_avg:92.27ms +step:1196/1660 train_time:110352ms step_avg:92.27ms +step:1197/1660 train_time:110445ms step_avg:92.27ms +step:1198/1660 train_time:110538ms step_avg:92.27ms +step:1199/1660 train_time:110631ms step_avg:92.27ms +step:1200/1660 train_time:110724ms step_avg:92.27ms +step:1201/1660 train_time:110817ms step_avg:92.27ms +step:1202/1660 train_time:110910ms step_avg:92.27ms +step:1203/1660 train_time:111005ms step_avg:92.27ms +step:1204/1660 train_time:111098ms step_avg:92.27ms +step:1205/1660 train_time:111191ms step_avg:92.28ms +step:1206/1660 train_time:111285ms step_avg:92.28ms +step:1207/1660 train_time:111378ms step_avg:92.28ms +step:1208/1660 train_time:111471ms step_avg:92.28ms +step:1209/1660 train_time:111564ms step_avg:92.28ms +step:1210/1660 train_time:111657ms step_avg:92.28ms +step:1211/1660 train_time:111749ms step_avg:92.28ms +step:1212/1660 train_time:111843ms step_avg:92.28ms +step:1213/1660 train_time:111936ms step_avg:92.28ms +step:1214/1660 train_time:112029ms step_avg:92.28ms +step:1215/1660 train_time:112122ms step_avg:92.28ms +step:1216/1660 train_time:112215ms step_avg:92.28ms +step:1217/1660 train_time:112308ms step_avg:92.28ms +step:1218/1660 
train_time:112403ms step_avg:92.28ms +step:1219/1660 train_time:112498ms step_avg:92.29ms +step:1220/1660 train_time:112590ms step_avg:92.29ms +step:1221/1660 train_time:112684ms step_avg:92.29ms +step:1222/1660 train_time:112776ms step_avg:92.29ms +step:1223/1660 train_time:112870ms step_avg:92.29ms +step:1224/1660 train_time:112964ms step_avg:92.29ms +step:1225/1660 train_time:113058ms step_avg:92.29ms +step:1226/1660 train_time:113151ms step_avg:92.29ms +step:1227/1660 train_time:113244ms step_avg:92.29ms +step:1228/1660 train_time:113336ms step_avg:92.29ms +step:1229/1660 train_time:113429ms step_avg:92.29ms +step:1230/1660 train_time:113523ms step_avg:92.29ms +step:1231/1660 train_time:113616ms step_avg:92.30ms +step:1232/1660 train_time:113709ms step_avg:92.30ms +step:1233/1660 train_time:113803ms step_avg:92.30ms +step:1234/1660 train_time:113897ms step_avg:92.30ms +step:1235/1660 train_time:113990ms step_avg:92.30ms +step:1236/1660 train_time:114083ms step_avg:92.30ms +step:1237/1660 train_time:114177ms step_avg:92.30ms +step:1238/1660 train_time:114270ms step_avg:92.30ms +step:1239/1660 train_time:114365ms step_avg:92.30ms +step:1240/1660 train_time:114458ms step_avg:92.31ms +step:1241/1660 train_time:114552ms step_avg:92.31ms +step:1242/1660 train_time:114645ms step_avg:92.31ms +step:1243/1660 train_time:114738ms step_avg:92.31ms +step:1244/1660 train_time:114831ms step_avg:92.31ms +step:1245/1660 train_time:114926ms step_avg:92.31ms +step:1246/1660 train_time:115019ms step_avg:92.31ms +step:1247/1660 train_time:115112ms step_avg:92.31ms +step:1248/1660 train_time:115206ms step_avg:92.31ms +step:1249/1660 train_time:115299ms step_avg:92.31ms +step:1250/1660 train_time:115392ms step_avg:92.31ms +step:1250/1660 val_loss:3.3741 train_time:115487ms step_avg:92.39ms +step:1251/1660 train_time:115510ms step_avg:92.33ms +step:1252/1660 train_time:115586ms step_avg:92.32ms +step:1253/1660 train_time:115687ms step_avg:92.33ms +step:1254/1660 train_time:115781ms step_avg:92.33ms +step:1255/1660 train_time:115873ms step_avg:92.33ms +step:1256/1660 train_time:115965ms step_avg:92.33ms +step:1257/1660 train_time:116057ms step_avg:92.33ms +step:1258/1660 train_time:116149ms step_avg:92.33ms +step:1259/1660 train_time:116242ms step_avg:92.33ms +step:1260/1660 train_time:116334ms step_avg:92.33ms +step:1261/1660 train_time:116427ms step_avg:92.33ms +step:1262/1660 train_time:116522ms step_avg:92.33ms +step:1263/1660 train_time:116619ms step_avg:92.33ms +step:1264/1660 train_time:116716ms step_avg:92.34ms +step:1265/1660 train_time:116809ms step_avg:92.34ms +step:1266/1660 train_time:116902ms step_avg:92.34ms +step:1267/1660 train_time:116995ms step_avg:92.34ms +step:1268/1660 train_time:117087ms step_avg:92.34ms +step:1269/1660 train_time:117180ms step_avg:92.34ms +step:1270/1660 train_time:117273ms step_avg:92.34ms +step:1271/1660 train_time:117365ms step_avg:92.34ms +step:1272/1660 train_time:117458ms step_avg:92.34ms +step:1273/1660 train_time:117552ms step_avg:92.34ms +step:1274/1660 train_time:117649ms step_avg:92.35ms +step:1275/1660 train_time:117744ms step_avg:92.35ms +step:1276/1660 train_time:117838ms step_avg:92.35ms +step:1277/1660 train_time:117931ms step_avg:92.35ms +step:1278/1660 train_time:118025ms step_avg:92.35ms +step:1279/1660 train_time:118118ms step_avg:92.35ms +step:1280/1660 train_time:118211ms step_avg:92.35ms +step:1281/1660 train_time:118303ms step_avg:92.35ms +step:1282/1660 train_time:118396ms step_avg:92.35ms +step:1283/1660 train_time:118489ms step_avg:92.35ms 
+step:1284/1660 train_time:118585ms step_avg:92.36ms +step:1285/1660 train_time:118680ms step_avg:92.36ms +step:1286/1660 train_time:118774ms step_avg:92.36ms +step:1287/1660 train_time:118867ms step_avg:92.36ms +step:1288/1660 train_time:118960ms step_avg:92.36ms +step:1289/1660 train_time:119053ms step_avg:92.36ms +step:1290/1660 train_time:119147ms step_avg:92.36ms +step:1291/1660 train_time:119239ms step_avg:92.36ms +step:1292/1660 train_time:119332ms step_avg:92.36ms +step:1293/1660 train_time:119425ms step_avg:92.36ms +step:1294/1660 train_time:119520ms step_avg:92.36ms +step:1295/1660 train_time:119614ms step_avg:92.37ms +step:1296/1660 train_time:119708ms step_avg:92.37ms +step:1297/1660 train_time:119803ms step_avg:92.37ms +step:1298/1660 train_time:119897ms step_avg:92.37ms +step:1299/1660 train_time:119991ms step_avg:92.37ms +step:1300/1660 train_time:120084ms step_avg:92.37ms +step:1301/1660 train_time:120178ms step_avg:92.37ms +step:1302/1660 train_time:120271ms step_avg:92.37ms +step:1303/1660 train_time:120365ms step_avg:92.38ms +step:1304/1660 train_time:120458ms step_avg:92.38ms +step:1305/1660 train_time:120552ms step_avg:92.38ms +step:1306/1660 train_time:120645ms step_avg:92.38ms +step:1307/1660 train_time:120739ms step_avg:92.38ms +step:1308/1660 train_time:120832ms step_avg:92.38ms +step:1309/1660 train_time:120927ms step_avg:92.38ms +step:1310/1660 train_time:121020ms step_avg:92.38ms +step:1311/1660 train_time:121113ms step_avg:92.38ms +step:1312/1660 train_time:121206ms step_avg:92.38ms +step:1313/1660 train_time:121299ms step_avg:92.38ms +step:1314/1660 train_time:121392ms step_avg:92.38ms +step:1315/1660 train_time:121486ms step_avg:92.38ms +step:1316/1660 train_time:121579ms step_avg:92.38ms +step:1317/1660 train_time:121671ms step_avg:92.39ms +step:1318/1660 train_time:121766ms step_avg:92.39ms +step:1319/1660 train_time:121860ms step_avg:92.39ms +step:1320/1660 train_time:121954ms step_avg:92.39ms +step:1321/1660 train_time:122047ms step_avg:92.39ms +step:1322/1660 train_time:122141ms step_avg:92.39ms +step:1323/1660 train_time:122234ms step_avg:92.39ms +step:1324/1660 train_time:122328ms step_avg:92.39ms +step:1325/1660 train_time:122423ms step_avg:92.39ms +step:1326/1660 train_time:122517ms step_avg:92.40ms +step:1327/1660 train_time:122609ms step_avg:92.40ms +step:1328/1660 train_time:122703ms step_avg:92.40ms +step:1329/1660 train_time:122796ms step_avg:92.40ms +step:1330/1660 train_time:122889ms step_avg:92.40ms +step:1331/1660 train_time:122984ms step_avg:92.40ms +step:1332/1660 train_time:123078ms step_avg:92.40ms +step:1333/1660 train_time:123170ms step_avg:92.40ms +step:1334/1660 train_time:123264ms step_avg:92.40ms +step:1335/1660 train_time:123357ms step_avg:92.40ms +step:1336/1660 train_time:123451ms step_avg:92.40ms +step:1337/1660 train_time:123545ms step_avg:92.40ms +step:1338/1660 train_time:123638ms step_avg:92.41ms +step:1339/1660 train_time:123731ms step_avg:92.41ms +step:1340/1660 train_time:123826ms step_avg:92.41ms +step:1341/1660 train_time:123921ms step_avg:92.41ms +step:1342/1660 train_time:124014ms step_avg:92.41ms +step:1343/1660 train_time:124107ms step_avg:92.41ms +step:1344/1660 train_time:124200ms step_avg:92.41ms +step:1345/1660 train_time:124292ms step_avg:92.41ms +step:1346/1660 train_time:124387ms step_avg:92.41ms +step:1347/1660 train_time:124480ms step_avg:92.41ms +step:1348/1660 train_time:124573ms step_avg:92.41ms +step:1349/1660 train_time:124668ms step_avg:92.41ms +step:1350/1660 train_time:124761ms step_avg:92.42ms 
+step:1351/1660 train_time:124854ms step_avg:92.42ms +step:1352/1660 train_time:124950ms step_avg:92.42ms +step:1353/1660 train_time:125043ms step_avg:92.42ms +step:1354/1660 train_time:125136ms step_avg:92.42ms +step:1355/1660 train_time:125229ms step_avg:92.42ms +step:1356/1660 train_time:125323ms step_avg:92.42ms +step:1357/1660 train_time:125417ms step_avg:92.42ms +step:1358/1660 train_time:125510ms step_avg:92.42ms +step:1359/1660 train_time:125604ms step_avg:92.42ms +step:1360/1660 train_time:125697ms step_avg:92.42ms +step:1361/1660 train_time:125790ms step_avg:92.42ms +step:1362/1660 train_time:125883ms step_avg:92.43ms +step:1363/1660 train_time:125977ms step_avg:92.43ms +step:1364/1660 train_time:126070ms step_avg:92.43ms +step:1365/1660 train_time:126163ms step_avg:92.43ms +step:1366/1660 train_time:126258ms step_avg:92.43ms +step:1367/1660 train_time:126350ms step_avg:92.43ms +step:1368/1660 train_time:126444ms step_avg:92.43ms +step:1369/1660 train_time:126537ms step_avg:92.43ms +step:1370/1660 train_time:126631ms step_avg:92.43ms +step:1371/1660 train_time:126726ms step_avg:92.43ms +step:1372/1660 train_time:126820ms step_avg:92.43ms +step:1373/1660 train_time:126913ms step_avg:92.44ms +step:1374/1660 train_time:127007ms step_avg:92.44ms +step:1375/1660 train_time:127100ms step_avg:92.44ms +step:1375/1660 val_loss:3.3396 train_time:127195ms step_avg:92.51ms +step:1376/1660 train_time:127216ms step_avg:92.45ms +step:1377/1660 train_time:127293ms step_avg:92.44ms +step:1378/1660 train_time:127391ms step_avg:92.45ms +step:1379/1660 train_time:127484ms step_avg:92.45ms +step:1380/1660 train_time:127575ms step_avg:92.45ms +step:1381/1660 train_time:127668ms step_avg:92.45ms +step:1382/1660 train_time:127760ms step_avg:92.45ms +step:1383/1660 train_time:127853ms step_avg:92.45ms +step:1384/1660 train_time:127946ms step_avg:92.45ms +step:1385/1660 train_time:128038ms step_avg:92.45ms +step:1386/1660 train_time:128131ms step_avg:92.45ms +step:1387/1660 train_time:128226ms step_avg:92.45ms +step:1388/1660 train_time:128324ms step_avg:92.45ms +step:1389/1660 train_time:128417ms step_avg:92.45ms +step:1390/1660 train_time:128511ms step_avg:92.45ms +step:1391/1660 train_time:128603ms step_avg:92.45ms +step:1392/1660 train_time:128695ms step_avg:92.45ms +step:1393/1660 train_time:128788ms step_avg:92.45ms +step:1394/1660 train_time:128881ms step_avg:92.45ms +step:1395/1660 train_time:128974ms step_avg:92.45ms +step:1396/1660 train_time:129067ms step_avg:92.46ms +step:1397/1660 train_time:129161ms step_avg:92.46ms +step:1398/1660 train_time:129257ms step_avg:92.46ms +step:1399/1660 train_time:129353ms step_avg:92.46ms +step:1400/1660 train_time:129448ms step_avg:92.46ms +step:1401/1660 train_time:129541ms step_avg:92.46ms +step:1402/1660 train_time:129634ms step_avg:92.46ms +step:1403/1660 train_time:129727ms step_avg:92.46ms +step:1404/1660 train_time:129820ms step_avg:92.46ms +step:1405/1660 train_time:129913ms step_avg:92.46ms +step:1406/1660 train_time:130005ms step_avg:92.46ms +step:1407/1660 train_time:130098ms step_avg:92.47ms +step:1408/1660 train_time:130192ms step_avg:92.47ms +step:1409/1660 train_time:130287ms step_avg:92.47ms +step:1410/1660 train_time:130381ms step_avg:92.47ms +step:1411/1660 train_time:130475ms step_avg:92.47ms +step:1412/1660 train_time:130569ms step_avg:92.47ms +step:1413/1660 train_time:130662ms step_avg:92.47ms +step:1414/1660 train_time:130755ms step_avg:92.47ms +step:1415/1660 train_time:130848ms step_avg:92.47ms +step:1416/1660 train_time:130941ms 
step_avg:92.47ms +step:1417/1660 train_time:131034ms step_avg:92.47ms +step:1418/1660 train_time:131127ms step_avg:92.47ms +step:1419/1660 train_time:131220ms step_avg:92.47ms +step:1420/1660 train_time:131315ms step_avg:92.48ms +step:1421/1660 train_time:131410ms step_avg:92.48ms +step:1422/1660 train_time:131504ms step_avg:92.48ms +step:1423/1660 train_time:131597ms step_avg:92.48ms +step:1424/1660 train_time:131690ms step_avg:92.48ms +step:1425/1660 train_time:131784ms step_avg:92.48ms +step:1426/1660 train_time:131877ms step_avg:92.48ms +step:1427/1660 train_time:131970ms step_avg:92.48ms +step:1428/1660 train_time:132063ms step_avg:92.48ms +step:1429/1660 train_time:132157ms step_avg:92.48ms +step:1430/1660 train_time:132251ms step_avg:92.48ms +step:1431/1660 train_time:132345ms step_avg:92.48ms +step:1432/1660 train_time:132439ms step_avg:92.49ms +step:1433/1660 train_time:132533ms step_avg:92.49ms +step:1434/1660 train_time:132627ms step_avg:92.49ms +step:1435/1660 train_time:132721ms step_avg:92.49ms +step:1436/1660 train_time:132815ms step_avg:92.49ms +step:1437/1660 train_time:132909ms step_avg:92.49ms +step:1438/1660 train_time:133002ms step_avg:92.49ms +step:1439/1660 train_time:133095ms step_avg:92.49ms +step:1440/1660 train_time:133189ms step_avg:92.49ms +step:1441/1660 train_time:133282ms step_avg:92.49ms +step:1442/1660 train_time:133376ms step_avg:92.49ms +step:1443/1660 train_time:133470ms step_avg:92.49ms +step:1444/1660 train_time:133563ms step_avg:92.50ms +step:1445/1660 train_time:133657ms step_avg:92.50ms +step:1446/1660 train_time:133753ms step_avg:92.50ms +step:1447/1660 train_time:133848ms step_avg:92.50ms +step:1448/1660 train_time:133940ms step_avg:92.50ms +step:1449/1660 train_time:134034ms step_avg:92.50ms +step:1450/1660 train_time:134128ms step_avg:92.50ms +step:1451/1660 train_time:134221ms step_avg:92.50ms +step:1452/1660 train_time:134315ms step_avg:92.50ms +step:1453/1660 train_time:134408ms step_avg:92.50ms +step:1454/1660 train_time:134502ms step_avg:92.50ms +step:1455/1660 train_time:134596ms step_avg:92.51ms +step:1456/1660 train_time:134689ms step_avg:92.51ms +step:1457/1660 train_time:134782ms step_avg:92.51ms +step:1458/1660 train_time:134876ms step_avg:92.51ms +step:1459/1660 train_time:134969ms step_avg:92.51ms +step:1460/1660 train_time:135062ms step_avg:92.51ms +step:1461/1660 train_time:135157ms step_avg:92.51ms +step:1462/1660 train_time:135250ms step_avg:92.51ms +step:1463/1660 train_time:135343ms step_avg:92.51ms +step:1464/1660 train_time:135436ms step_avg:92.51ms +step:1465/1660 train_time:135529ms step_avg:92.51ms +step:1466/1660 train_time:135622ms step_avg:92.51ms +step:1467/1660 train_time:135716ms step_avg:92.51ms +step:1468/1660 train_time:135810ms step_avg:92.51ms +step:1469/1660 train_time:135904ms step_avg:92.51ms +step:1470/1660 train_time:135997ms step_avg:92.52ms +step:1471/1660 train_time:136091ms step_avg:92.52ms +step:1472/1660 train_time:136185ms step_avg:92.52ms +step:1473/1660 train_time:136279ms step_avg:92.52ms +step:1474/1660 train_time:136372ms step_avg:92.52ms +step:1475/1660 train_time:136466ms step_avg:92.52ms +step:1476/1660 train_time:136559ms step_avg:92.52ms +step:1477/1660 train_time:136653ms step_avg:92.52ms +step:1478/1660 train_time:136747ms step_avg:92.52ms +step:1479/1660 train_time:136841ms step_avg:92.52ms +step:1480/1660 train_time:136936ms step_avg:92.52ms +step:1481/1660 train_time:137030ms step_avg:92.53ms +step:1482/1660 train_time:137123ms step_avg:92.53ms +step:1483/1660 train_time:137216ms 
step_avg:92.53ms +step:1484/1660 train_time:137310ms step_avg:92.53ms +step:1485/1660 train_time:137404ms step_avg:92.53ms +step:1486/1660 train_time:137497ms step_avg:92.53ms +step:1487/1660 train_time:137590ms step_avg:92.53ms +step:1488/1660 train_time:137683ms step_avg:92.53ms +step:1489/1660 train_time:137776ms step_avg:92.53ms +step:1490/1660 train_time:137870ms step_avg:92.53ms +step:1491/1660 train_time:137965ms step_avg:92.53ms +step:1492/1660 train_time:138059ms step_avg:92.53ms +step:1493/1660 train_time:138154ms step_avg:92.53ms +step:1494/1660 train_time:138249ms step_avg:92.54ms +step:1495/1660 train_time:138341ms step_avg:92.54ms +step:1496/1660 train_time:138434ms step_avg:92.54ms +step:1497/1660 train_time:138528ms step_avg:92.54ms +step:1498/1660 train_time:138621ms step_avg:92.54ms +step:1499/1660 train_time:138715ms step_avg:92.54ms +step:1500/1660 train_time:138808ms step_avg:92.54ms +step:1500/1660 val_loss:3.3092 train_time:138903ms step_avg:92.60ms +step:1501/1660 train_time:138925ms step_avg:92.56ms +step:1502/1660 train_time:138998ms step_avg:92.54ms +step:1503/1660 train_time:139096ms step_avg:92.55ms +step:1504/1660 train_time:139189ms step_avg:92.55ms +step:1505/1660 train_time:139284ms step_avg:92.55ms +step:1506/1660 train_time:139376ms step_avg:92.55ms +step:1507/1660 train_time:139468ms step_avg:92.55ms +step:1508/1660 train_time:139561ms step_avg:92.55ms +step:1509/1660 train_time:139654ms step_avg:92.55ms +step:1510/1660 train_time:139746ms step_avg:92.55ms +step:1511/1660 train_time:139839ms step_avg:92.55ms +step:1512/1660 train_time:139934ms step_avg:92.55ms +step:1513/1660 train_time:140030ms step_avg:92.55ms +step:1514/1660 train_time:140124ms step_avg:92.55ms +step:1515/1660 train_time:140218ms step_avg:92.55ms +step:1516/1660 train_time:140312ms step_avg:92.55ms +step:1517/1660 train_time:140405ms step_avg:92.55ms +step:1518/1660 train_time:140498ms step_avg:92.55ms +step:1519/1660 train_time:140590ms step_avg:92.55ms +step:1520/1660 train_time:140682ms step_avg:92.55ms +step:1521/1660 train_time:140775ms step_avg:92.55ms +step:1522/1660 train_time:140869ms step_avg:92.56ms +step:1523/1660 train_time:140963ms step_avg:92.56ms +step:1524/1660 train_time:141056ms step_avg:92.56ms +step:1525/1660 train_time:141152ms step_avg:92.56ms +step:1526/1660 train_time:141245ms step_avg:92.56ms +step:1527/1660 train_time:141339ms step_avg:92.56ms +step:1528/1660 train_time:141432ms step_avg:92.56ms +step:1529/1660 train_time:141525ms step_avg:92.56ms +step:1530/1660 train_time:141618ms step_avg:92.56ms +step:1531/1660 train_time:141711ms step_avg:92.56ms +step:1532/1660 train_time:141805ms step_avg:92.56ms +step:1533/1660 train_time:141899ms step_avg:92.56ms +step:1534/1660 train_time:141992ms step_avg:92.56ms +step:1535/1660 train_time:142087ms step_avg:92.56ms +step:1536/1660 train_time:142181ms step_avg:92.57ms +step:1537/1660 train_time:142275ms step_avg:92.57ms +step:1538/1660 train_time:142369ms step_avg:92.57ms +step:1539/1660 train_time:142463ms step_avg:92.57ms +step:1540/1660 train_time:142556ms step_avg:92.57ms +step:1541/1660 train_time:142649ms step_avg:92.57ms +step:1542/1660 train_time:142743ms step_avg:92.57ms +step:1543/1660 train_time:142837ms step_avg:92.57ms +step:1544/1660 train_time:142930ms step_avg:92.57ms +step:1545/1660 train_time:143024ms step_avg:92.57ms +step:1546/1660 train_time:143118ms step_avg:92.57ms +step:1547/1660 train_time:143212ms step_avg:92.57ms +step:1548/1660 train_time:143305ms step_avg:92.57ms +step:1549/1660 
train_time:143398ms step_avg:92.57ms +step:1550/1660 train_time:143491ms step_avg:92.57ms +step:1551/1660 train_time:143586ms step_avg:92.58ms +step:1552/1660 train_time:143679ms step_avg:92.58ms +step:1553/1660 train_time:143773ms step_avg:92.58ms +step:1554/1660 train_time:143866ms step_avg:92.58ms +step:1555/1660 train_time:143960ms step_avg:92.58ms +step:1556/1660 train_time:144055ms step_avg:92.58ms +step:1557/1660 train_time:144149ms step_avg:92.58ms +step:1558/1660 train_time:144243ms step_avg:92.58ms +step:1559/1660 train_time:144337ms step_avg:92.58ms +step:1560/1660 train_time:144431ms step_avg:92.58ms +step:1561/1660 train_time:144523ms step_avg:92.58ms +step:1562/1660 train_time:144616ms step_avg:92.58ms +step:1563/1660 train_time:144709ms step_avg:92.58ms +step:1564/1660 train_time:144803ms step_avg:92.59ms +step:1565/1660 train_time:144896ms step_avg:92.59ms +step:1566/1660 train_time:144990ms step_avg:92.59ms +step:1567/1660 train_time:145084ms step_avg:92.59ms +step:1568/1660 train_time:145179ms step_avg:92.59ms +step:1569/1660 train_time:145273ms step_avg:92.59ms +step:1570/1660 train_time:145367ms step_avg:92.59ms +step:1571/1660 train_time:145461ms step_avg:92.59ms +step:1572/1660 train_time:145553ms step_avg:92.59ms +step:1573/1660 train_time:145647ms step_avg:92.59ms +step:1574/1660 train_time:145740ms step_avg:92.59ms +step:1575/1660 train_time:145833ms step_avg:92.59ms +step:1576/1660 train_time:145926ms step_avg:92.59ms +step:1577/1660 train_time:146019ms step_avg:92.59ms +step:1578/1660 train_time:146113ms step_avg:92.59ms +step:1579/1660 train_time:146206ms step_avg:92.59ms +step:1580/1660 train_time:146300ms step_avg:92.60ms +step:1581/1660 train_time:146394ms step_avg:92.60ms +step:1582/1660 train_time:146487ms step_avg:92.60ms +step:1583/1660 train_time:146581ms step_avg:92.60ms +step:1584/1660 train_time:146675ms step_avg:92.60ms +step:1585/1660 train_time:146768ms step_avg:92.60ms +step:1586/1660 train_time:146862ms step_avg:92.60ms +step:1587/1660 train_time:146956ms step_avg:92.60ms +step:1588/1660 train_time:147050ms step_avg:92.60ms +step:1589/1660 train_time:147144ms step_avg:92.60ms +step:1590/1660 train_time:147237ms step_avg:92.60ms +step:1591/1660 train_time:147330ms step_avg:92.60ms +step:1592/1660 train_time:147423ms step_avg:92.60ms +step:1593/1660 train_time:147517ms step_avg:92.60ms +step:1594/1660 train_time:147610ms step_avg:92.60ms +step:1595/1660 train_time:147703ms step_avg:92.60ms +step:1596/1660 train_time:147796ms step_avg:92.60ms +step:1597/1660 train_time:147890ms step_avg:92.60ms +step:1598/1660 train_time:147983ms step_avg:92.61ms +step:1599/1660 train_time:148077ms step_avg:92.61ms +step:1600/1660 train_time:148170ms step_avg:92.61ms +step:1601/1660 train_time:148264ms step_avg:92.61ms +step:1602/1660 train_time:148357ms step_avg:92.61ms +step:1603/1660 train_time:148450ms step_avg:92.61ms +step:1604/1660 train_time:148544ms step_avg:92.61ms +step:1605/1660 train_time:148638ms step_avg:92.61ms +step:1606/1660 train_time:148730ms step_avg:92.61ms +step:1607/1660 train_time:148823ms step_avg:92.61ms +step:1608/1660 train_time:148917ms step_avg:92.61ms +step:1609/1660 train_time:149010ms step_avg:92.61ms +step:1610/1660 train_time:149105ms step_avg:92.61ms +step:1611/1660 train_time:149199ms step_avg:92.61ms +step:1612/1660 train_time:149292ms step_avg:92.61ms +step:1613/1660 train_time:149385ms step_avg:92.61ms +step:1614/1660 train_time:149479ms step_avg:92.61ms +step:1615/1660 train_time:149572ms step_avg:92.61ms +step:1616/1660 
train_time:149665ms step_avg:92.61ms +step:1617/1660 train_time:149758ms step_avg:92.61ms +step:1618/1660 train_time:149852ms step_avg:92.62ms +step:1619/1660 train_time:149945ms step_avg:92.62ms +step:1620/1660 train_time:150039ms step_avg:92.62ms +step:1621/1660 train_time:150132ms step_avg:92.62ms +step:1622/1660 train_time:150226ms step_avg:92.62ms +step:1623/1660 train_time:150320ms step_avg:92.62ms +step:1624/1660 train_time:150414ms step_avg:92.62ms +step:1625/1660 train_time:150507ms step_avg:92.62ms +step:1625/1660 val_loss:3.2849 train_time:150602ms step_avg:92.68ms +step:1626/1660 train_time:150623ms step_avg:92.63ms +step:1627/1660 train_time:150701ms step_avg:92.63ms +step:1628/1660 train_time:150798ms step_avg:92.63ms +step:1629/1660 train_time:150892ms step_avg:92.63ms +step:1630/1660 train_time:150985ms step_avg:92.63ms +step:1631/1660 train_time:151078ms step_avg:92.63ms +step:1632/1660 train_time:151171ms step_avg:92.63ms +step:1633/1660 train_time:151263ms step_avg:92.63ms +step:1634/1660 train_time:151356ms step_avg:92.63ms +step:1635/1660 train_time:151450ms step_avg:92.63ms +step:1636/1660 train_time:151543ms step_avg:92.63ms +step:1637/1660 train_time:151639ms step_avg:92.63ms +step:1638/1660 train_time:151735ms step_avg:92.63ms +step:1639/1660 train_time:151830ms step_avg:92.64ms +step:1640/1660 train_time:151923ms step_avg:92.64ms +step:1641/1660 train_time:152016ms step_avg:92.64ms +step:1642/1660 train_time:152109ms step_avg:92.64ms +step:1643/1660 train_time:152202ms step_avg:92.64ms +step:1644/1660 train_time:152294ms step_avg:92.64ms +step:1645/1660 train_time:152387ms step_avg:92.64ms +step:1646/1660 train_time:152480ms step_avg:92.64ms +step:1647/1660 train_time:152574ms step_avg:92.64ms +step:1648/1660 train_time:152669ms step_avg:92.64ms +step:1649/1660 train_time:152763ms step_avg:92.64ms +step:1650/1660 train_time:152857ms step_avg:92.64ms +step:1651/1660 train_time:152951ms step_avg:92.64ms +step:1652/1660 train_time:153045ms step_avg:92.64ms +step:1653/1660 train_time:153137ms step_avg:92.64ms +step:1654/1660 train_time:153230ms step_avg:92.64ms +step:1655/1660 train_time:153322ms step_avg:92.64ms +step:1656/1660 train_time:153415ms step_avg:92.64ms +step:1657/1660 train_time:153510ms step_avg:92.64ms +step:1658/1660 train_time:153603ms step_avg:92.64ms +step:1659/1660 train_time:153697ms step_avg:92.64ms +step:1660/1660 train_time:153791ms step_avg:92.65ms +step:1660/1660 val_loss:3.2767 train_time:153887ms step_avg:92.70ms +peak memory allocated: 31587 MiB reserved: 47056 MiB diff --git a/records/091525_ThreadingFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt b/records/091525_ThreadingFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt new file mode 100644 index 000000000..d7def6a35 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned 
this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + 
C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if 
A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
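+        # Shape sketch of the sharded update (illustrative; the numbers follow the
+        # MLP comment later in this file: 24 same-shaped matrices on an 8-GPU world):
+        #   stacked_grads (24, m, n) --reduce_scatter--> grad_chunk (3, m, n) per rank
+        #   grad_chunk --momentum + newton_schulz_triton--> v_chunk (3, m, n)
+        #   updated_param_chunk (3, m, n) --all_gather--> stacked_params (24, m, n)
+        # so each rank orthogonalizes only 1/world_size of the parameters.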
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for the zeropower input so the batch stays padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
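+        # (Editor's note) torch.unbind yields views into the gathered buffer, and
+        # copy_(..., non_blocking=True) queues the device-side copies asynchronously
+        # with respect to the host, so the write-back of all parameters proceeds
+        # without a host synchronization per parameter.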
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+import threading # used below for background shard scanning and preloading
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    def __init__(self, file_iter, world_size):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
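+
+# Illustrative sketch (hypothetical helper, not used by this script): the quickload/preload
+# pattern above overlaps slow work (the full BOS scan, loading the next shard) with training.
+# A worker thread computes the result and sets an Event; the consumer blocks only if it asks
+# for the result before the worker is done. `expensive_fn` is a hypothetical stand-in.
+class PrefetchSketch:
+    def __init__(self, expensive_fn):
+        self.fn = expensive_fn
+        self.result = None
+        self.ready = threading.Event()
+        self.thread = threading.Thread(target=self._run)
+        self.thread.start() # runs concurrently with the caller
+
+    def _run(self):
+        self.result = self.fn() # e.g. scan the remaining tokens for BOS positions
+        self.ready.set()
+
+    def get(self):
+        self.ready.wait() # returns immediately if the worker already finished
+        self.thread.join()
+        return self.result
+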
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        # scanning a whole shard for BOS positions up front is slow; quickload starts from a
+        # partial scan (BOSFinder tracks its self.i index) and finishes it on a background thread
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; swap in the shard (and BOSFinder) prepared in the background.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document ran one token long, to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
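+
+# Illustrative usage of distributed_data_generator (hypothetical sizes, not code from this
+# run): the loader is a generator, so after priming it with next(), the training loop can
+# re-target (num_tokens, max_seq_len, grad_accum_steps) mid-run via .send():
+#
+#   loader = distributed_data_generator("fineweb_train_*.bin", num_tokens=65536, max_seq_len=2048)
+#   inputs, targets, cum_seqlens = next(loader)                    # ordinary batch
+#   inputs, targets, cum_seqlens = loader.send((131072, 4096, 1))  # resize, then get the next batch
+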
+# -----------------------------------------------------------------------------
+# int main
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1660 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"data_threading_1660/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer is highly indifferent to length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
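+
+# Sanity-check sketch (illustrative, assuming the four groups above are meant to partition
+# the model): every parameter should land in exactly one optimizer -- the 2-D hidden
+# matrices in Muon below, everything else in DistAdam.
+grouped_ids = [id(p) for p in hidden_matrix_params + embed_params + scalar_params + head_params]
+assert len(grouped_ids) == len(set(grouped_ids)), "a parameter was assigned to two groups"
+assert set(grouped_ids) == {id(p) for p in model.parameters()}, "a parameter is missing from every group"
+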
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws != ws: # window size changed: re-scale the rotary embeddings
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:36:50 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 40C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 190659 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 190660 C /usr/bin/python3 614MiB | +| 0 N/A N/A 190661 C /usr/bin/python3 614MiB | +| 0 N/A N/A 190662 C /usr/bin/python3 614MiB | +| 0 N/A N/A 190663 C /usr/bin/python3 614MiB | +| 0 N/A N/A 190664 C /usr/bin/python3 614MiB | +| 0 N/A N/A 190665 C /usr/bin/python3 614MiB | +| 0 N/A N/A 190666 C /usr/bin/python3 614MiB | +| 1 N/A N/A 190660 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 190661 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 190662 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 190663 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 190664 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 190665 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 190666 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:148ms step_avg:147.75ms +step:2/1660 train_time:168ms step_avg:84.23ms +step:3/1660 train_time:235ms step_avg:78.45ms +step:4/1660 train_time:324ms step_avg:81.10ms +step:5/1660 train_time:415ms step_avg:82.91ms +step:6/1660 train_time:505ms step_avg:84.22ms +step:7/1660 train_time:596ms step_avg:85.20ms +step:8/1660 train_time:687ms step_avg:85.84ms +step:9/1660 train_time:778ms step_avg:86.39ms +step:10/1660 train_time:869ms step_avg:86.95ms +step:11/1660 train_time:961ms step_avg:87.38ms +step:12/1660 train_time:1057ms step_avg:88.10ms +step:13/1660 train_time:1153ms step_avg:88.72ms +step:14/1660 train_time:1246ms step_avg:89.01ms +step:15/1660 train_time:1339ms step_avg:89.25ms +step:16/1660 train_time:1430ms step_avg:89.40ms +step:17/1660 train_time:1521ms step_avg:89.47ms +step:18/1660 train_time:1613ms step_avg:89.59ms +step:19/1660 train_time:1703ms step_avg:89.65ms +step:20/1660 train_time:1794ms step_avg:89.72ms +step:21/1660 train_time:1885ms step_avg:89.75ms +step:22/1660 train_time:1979ms step_avg:89.95ms +step:23/1660 train_time:2074ms step_avg:90.15ms +step:24/1660 train_time:2166ms step_avg:90.26ms +step:25/1660 train_time:2260ms step_avg:90.38ms +step:26/1660 train_time:2352ms step_avg:90.47ms +step:27/1660 train_time:2444ms step_avg:90.51ms +step:28/1660 train_time:2536ms step_avg:90.55ms +step:29/1660 train_time:2626ms step_avg:90.55ms +step:30/1660 train_time:2717ms step_avg:90.56ms +step:31/1660 train_time:2808ms step_avg:90.58ms +step:32/1660 train_time:2900ms step_avg:90.62ms +step:33/1660 train_time:2993ms step_avg:90.69ms +step:34/1660 train_time:3085ms step_avg:90.74ms +step:35/1660 train_time:3179ms step_avg:90.83ms +step:36/1660 train_time:3272ms step_avg:90.90ms +step:37/1660 train_time:3364ms step_avg:90.91ms +step:38/1660 train_time:3456ms step_avg:90.94ms +step:39/1660 train_time:3547ms step_avg:90.96ms +step:40/1660 train_time:3639ms step_avg:90.97ms +step:41/1660 train_time:3730ms step_avg:90.97ms +step:42/1660 train_time:3821ms step_avg:90.97ms +step:43/1660 train_time:3912ms step_avg:90.98ms +step:44/1660 train_time:4004ms step_avg:91.00ms +step:45/1660 train_time:4097ms step_avg:91.04ms +step:46/1660 train_time:4189ms step_avg:91.05ms +step:47/1660 train_time:4283ms step_avg:91.12ms +step:48/1660 train_time:4375ms step_avg:91.16ms +step:49/1660 train_time:4467ms step_avg:91.15ms +step:50/1660 train_time:4558ms step_avg:91.16ms 
+step:51/1660 train_time:4649ms step_avg:91.16ms +step:52/1660 train_time:4741ms step_avg:91.18ms +step:53/1660 train_time:4833ms step_avg:91.18ms +step:54/1660 train_time:4924ms step_avg:91.18ms +step:55/1660 train_time:5015ms step_avg:91.18ms +step:56/1660 train_time:5107ms step_avg:91.20ms +step:57/1660 train_time:5200ms step_avg:91.22ms +step:58/1660 train_time:5292ms step_avg:91.23ms +step:59/1660 train_time:5385ms step_avg:91.26ms +step:60/1660 train_time:5477ms step_avg:91.28ms +step:61/1660 train_time:5569ms step_avg:91.29ms +step:62/1660 train_time:5661ms step_avg:91.31ms +step:63/1660 train_time:5753ms step_avg:91.32ms +step:64/1660 train_time:5844ms step_avg:91.31ms +step:65/1660 train_time:5935ms step_avg:91.31ms +step:66/1660 train_time:6026ms step_avg:91.31ms +step:67/1660 train_time:6118ms step_avg:91.32ms +step:68/1660 train_time:6210ms step_avg:91.32ms +step:69/1660 train_time:6302ms step_avg:91.33ms +step:70/1660 train_time:6394ms step_avg:91.34ms +step:71/1660 train_time:6485ms step_avg:91.34ms +step:72/1660 train_time:6578ms step_avg:91.36ms +step:73/1660 train_time:6670ms step_avg:91.37ms +step:74/1660 train_time:6762ms step_avg:91.38ms +step:75/1660 train_time:6854ms step_avg:91.38ms +step:76/1660 train_time:6944ms step_avg:91.37ms +step:77/1660 train_time:7036ms step_avg:91.38ms +step:78/1660 train_time:7128ms step_avg:91.39ms +step:79/1660 train_time:7221ms step_avg:91.40ms +step:80/1660 train_time:7313ms step_avg:91.41ms +step:81/1660 train_time:7404ms step_avg:91.41ms +step:82/1660 train_time:7495ms step_avg:91.41ms +step:83/1660 train_time:7586ms step_avg:91.40ms +step:84/1660 train_time:7679ms step_avg:91.42ms +step:85/1660 train_time:7772ms step_avg:91.43ms +step:86/1660 train_time:7863ms step_avg:91.43ms +step:87/1660 train_time:7954ms step_avg:91.43ms +step:88/1660 train_time:8045ms step_avg:91.42ms +step:89/1660 train_time:8137ms step_avg:91.43ms +step:90/1660 train_time:8228ms step_avg:91.42ms +step:91/1660 train_time:8320ms step_avg:91.43ms +step:92/1660 train_time:8412ms step_avg:91.43ms +step:93/1660 train_time:8503ms step_avg:91.43ms +step:94/1660 train_time:8594ms step_avg:91.43ms +step:95/1660 train_time:8686ms step_avg:91.43ms +step:96/1660 train_time:8778ms step_avg:91.44ms +step:97/1660 train_time:8869ms step_avg:91.44ms +step:98/1660 train_time:8962ms step_avg:91.45ms +step:99/1660 train_time:9053ms step_avg:91.45ms +step:100/1660 train_time:9144ms step_avg:91.44ms +step:101/1660 train_time:9236ms step_avg:91.44ms +step:102/1660 train_time:9327ms step_avg:91.44ms +step:103/1660 train_time:9419ms step_avg:91.45ms +step:104/1660 train_time:9512ms step_avg:91.46ms +step:105/1660 train_time:9603ms step_avg:91.46ms +step:106/1660 train_time:9694ms step_avg:91.45ms +step:107/1660 train_time:9785ms step_avg:91.45ms +step:108/1660 train_time:9877ms step_avg:91.45ms +step:109/1660 train_time:9968ms step_avg:91.45ms +step:110/1660 train_time:10060ms step_avg:91.46ms +step:111/1660 train_time:10153ms step_avg:91.46ms +step:112/1660 train_time:10245ms step_avg:91.47ms +step:113/1660 train_time:10336ms step_avg:91.47ms +step:114/1660 train_time:10428ms step_avg:91.47ms +step:115/1660 train_time:10519ms step_avg:91.47ms +step:116/1660 train_time:10610ms step_avg:91.46ms +step:117/1660 train_time:10701ms step_avg:91.46ms +step:118/1660 train_time:10792ms step_avg:91.46ms +step:119/1660 train_time:10884ms step_avg:91.46ms +step:120/1660 train_time:10976ms step_avg:91.46ms +step:121/1660 train_time:11066ms step_avg:91.45ms +step:122/1660 train_time:11157ms 
step_avg:91.45ms +step:123/1660 train_time:11249ms step_avg:91.46ms +step:124/1660 train_time:11342ms step_avg:91.47ms +step:125/1660 train_time:11434ms step_avg:91.48ms +step:125/1660 val_loss:4.3151 train_time:11527ms step_avg:92.21ms +step:126/1660 train_time:11548ms step_avg:91.65ms +step:127/1660 train_time:11621ms step_avg:91.50ms +step:128/1660 train_time:11722ms step_avg:91.58ms +step:129/1660 train_time:11816ms step_avg:91.60ms +step:130/1660 train_time:11907ms step_avg:91.60ms +step:131/1660 train_time:11998ms step_avg:91.59ms +step:132/1660 train_time:12089ms step_avg:91.58ms +step:133/1660 train_time:12179ms step_avg:91.57ms +step:134/1660 train_time:12270ms step_avg:91.56ms +step:135/1660 train_time:12360ms step_avg:91.55ms +step:136/1660 train_time:12450ms step_avg:91.54ms +step:137/1660 train_time:12541ms step_avg:91.54ms +step:138/1660 train_time:12634ms step_avg:91.55ms +step:139/1660 train_time:12731ms step_avg:91.59ms +step:140/1660 train_time:12825ms step_avg:91.61ms +step:141/1660 train_time:12917ms step_avg:91.61ms +step:142/1660 train_time:13009ms step_avg:91.61ms +step:143/1660 train_time:13100ms step_avg:91.61ms +step:144/1660 train_time:13191ms step_avg:91.60ms +step:145/1660 train_time:13282ms step_avg:91.60ms +step:146/1660 train_time:13371ms step_avg:91.58ms +step:147/1660 train_time:13462ms step_avg:91.58ms +step:148/1660 train_time:13552ms step_avg:91.57ms +step:149/1660 train_time:13644ms step_avg:91.57ms +step:150/1660 train_time:13737ms step_avg:91.58ms +step:151/1660 train_time:13830ms step_avg:91.59ms +step:152/1660 train_time:13922ms step_avg:91.59ms +step:153/1660 train_time:14013ms step_avg:91.59ms +step:154/1660 train_time:14104ms step_avg:91.59ms +step:155/1660 train_time:14196ms step_avg:91.59ms +step:156/1660 train_time:14287ms step_avg:91.59ms +step:157/1660 train_time:14379ms step_avg:91.58ms +step:158/1660 train_time:14470ms step_avg:91.58ms +step:159/1660 train_time:14560ms step_avg:91.57ms +step:160/1660 train_time:14651ms step_avg:91.57ms +step:161/1660 train_time:14744ms step_avg:91.58ms +step:162/1660 train_time:14836ms step_avg:91.58ms +step:163/1660 train_time:14929ms step_avg:91.59ms +step:164/1660 train_time:15021ms step_avg:91.59ms +step:165/1660 train_time:15112ms step_avg:91.59ms +step:166/1660 train_time:15204ms step_avg:91.59ms +step:167/1660 train_time:15295ms step_avg:91.59ms +step:168/1660 train_time:15387ms step_avg:91.59ms +step:169/1660 train_time:15478ms step_avg:91.59ms +step:170/1660 train_time:15569ms step_avg:91.58ms +step:171/1660 train_time:15660ms step_avg:91.58ms +step:172/1660 train_time:15752ms step_avg:91.58ms +step:173/1660 train_time:15844ms step_avg:91.59ms +step:174/1660 train_time:15937ms step_avg:91.59ms +step:175/1660 train_time:16029ms step_avg:91.59ms +step:176/1660 train_time:16122ms step_avg:91.60ms +step:177/1660 train_time:16213ms step_avg:91.60ms +step:178/1660 train_time:16305ms step_avg:91.60ms +step:179/1660 train_time:16397ms step_avg:91.60ms +step:180/1660 train_time:16488ms step_avg:91.60ms +step:181/1660 train_time:16579ms step_avg:91.60ms +step:182/1660 train_time:16670ms step_avg:91.59ms +step:183/1660 train_time:16762ms step_avg:91.59ms +step:184/1660 train_time:16853ms step_avg:91.59ms +step:185/1660 train_time:16946ms step_avg:91.60ms +step:186/1660 train_time:17037ms step_avg:91.60ms +step:187/1660 train_time:17130ms step_avg:91.61ms +step:188/1660 train_time:17222ms step_avg:91.61ms +step:189/1660 train_time:17313ms step_avg:91.60ms +step:190/1660 train_time:17405ms step_avg:91.60ms 
+step:191/1660 train_time:17496ms step_avg:91.60ms +step:192/1660 train_time:17587ms step_avg:91.60ms +step:193/1660 train_time:17679ms step_avg:91.60ms +step:194/1660 train_time:17770ms step_avg:91.60ms +step:195/1660 train_time:17861ms step_avg:91.59ms +step:196/1660 train_time:17952ms step_avg:91.59ms +step:197/1660 train_time:18045ms step_avg:91.60ms +step:198/1660 train_time:18136ms step_avg:91.59ms +step:199/1660 train_time:18228ms step_avg:91.60ms +step:200/1660 train_time:18320ms step_avg:91.60ms +step:201/1660 train_time:18411ms step_avg:91.60ms +step:202/1660 train_time:18503ms step_avg:91.60ms +step:203/1660 train_time:18594ms step_avg:91.60ms +step:204/1660 train_time:18686ms step_avg:91.60ms +step:205/1660 train_time:18778ms step_avg:91.60ms +step:206/1660 train_time:18869ms step_avg:91.60ms +step:207/1660 train_time:18961ms step_avg:91.60ms +step:208/1660 train_time:19052ms step_avg:91.59ms +step:209/1660 train_time:19143ms step_avg:91.59ms +step:210/1660 train_time:19234ms step_avg:91.59ms +step:211/1660 train_time:19327ms step_avg:91.60ms +step:212/1660 train_time:19419ms step_avg:91.60ms +step:213/1660 train_time:19511ms step_avg:91.60ms +step:214/1660 train_time:19603ms step_avg:91.60ms +step:215/1660 train_time:19694ms step_avg:91.60ms +step:216/1660 train_time:19785ms step_avg:91.60ms +step:217/1660 train_time:19877ms step_avg:91.60ms +step:218/1660 train_time:19968ms step_avg:91.60ms +step:219/1660 train_time:20060ms step_avg:91.60ms +step:220/1660 train_time:20150ms step_avg:91.59ms +step:221/1660 train_time:20242ms step_avg:91.59ms +step:222/1660 train_time:20332ms step_avg:91.59ms +step:223/1660 train_time:20424ms step_avg:91.59ms +step:224/1660 train_time:20515ms step_avg:91.59ms +step:225/1660 train_time:20607ms step_avg:91.59ms +step:226/1660 train_time:20699ms step_avg:91.59ms +step:227/1660 train_time:20791ms step_avg:91.59ms +step:228/1660 train_time:20882ms step_avg:91.59ms +step:229/1660 train_time:20974ms step_avg:91.59ms +step:230/1660 train_time:21066ms step_avg:91.59ms +step:231/1660 train_time:21157ms step_avg:91.59ms +step:232/1660 train_time:21249ms step_avg:91.59ms +step:233/1660 train_time:21339ms step_avg:91.59ms +step:234/1660 train_time:21431ms step_avg:91.58ms +step:235/1660 train_time:21522ms step_avg:91.58ms +step:236/1660 train_time:21613ms step_avg:91.58ms +step:237/1660 train_time:21705ms step_avg:91.58ms +step:238/1660 train_time:21797ms step_avg:91.59ms +step:239/1660 train_time:21889ms step_avg:91.59ms +step:240/1660 train_time:21981ms step_avg:91.59ms +step:241/1660 train_time:22071ms step_avg:91.58ms +step:242/1660 train_time:22163ms step_avg:91.58ms +step:243/1660 train_time:22254ms step_avg:91.58ms +step:244/1660 train_time:22346ms step_avg:91.58ms +step:245/1660 train_time:22437ms step_avg:91.58ms +step:246/1660 train_time:22528ms step_avg:91.58ms +step:247/1660 train_time:22620ms step_avg:91.58ms +step:248/1660 train_time:22711ms step_avg:91.58ms +step:249/1660 train_time:22803ms step_avg:91.58ms +step:250/1660 train_time:22895ms step_avg:91.58ms +step:250/1660 val_loss:3.9760 train_time:22989ms step_avg:91.95ms +step:251/1660 train_time:23008ms step_avg:91.66ms +step:252/1660 train_time:23083ms step_avg:91.60ms +step:253/1660 train_time:23180ms step_avg:91.62ms +step:254/1660 train_time:23272ms step_avg:91.62ms +step:255/1660 train_time:23363ms step_avg:91.62ms +step:256/1660 train_time:23453ms step_avg:91.61ms +step:257/1660 train_time:23543ms step_avg:91.61ms +step:258/1660 train_time:23634ms step_avg:91.60ms +step:259/1660 
train_time:23723ms step_avg:91.60ms +step:260/1660 train_time:23814ms step_avg:91.59ms +step:261/1660 train_time:23905ms step_avg:91.59ms +step:262/1660 train_time:23997ms step_avg:91.59ms +step:263/1660 train_time:24090ms step_avg:91.60ms +step:264/1660 train_time:24184ms step_avg:91.61ms +step:265/1660 train_time:24277ms step_avg:91.61ms +step:266/1660 train_time:24369ms step_avg:91.61ms +step:267/1660 train_time:24461ms step_avg:91.61ms +step:268/1660 train_time:24551ms step_avg:91.61ms +step:269/1660 train_time:24642ms step_avg:91.60ms +step:270/1660 train_time:24732ms step_avg:91.60ms +step:271/1660 train_time:24823ms step_avg:91.60ms +step:272/1660 train_time:24914ms step_avg:91.59ms +step:273/1660 train_time:25005ms step_avg:91.59ms +step:274/1660 train_time:25097ms step_avg:91.60ms +step:275/1660 train_time:25189ms step_avg:91.60ms +step:276/1660 train_time:25283ms step_avg:91.60ms +step:277/1660 train_time:25375ms step_avg:91.61ms +step:278/1660 train_time:25466ms step_avg:91.60ms +step:279/1660 train_time:25557ms step_avg:91.60ms +step:280/1660 train_time:25648ms step_avg:91.60ms +step:281/1660 train_time:25739ms step_avg:91.60ms +step:282/1660 train_time:25831ms step_avg:91.60ms +step:283/1660 train_time:25921ms step_avg:91.60ms +step:284/1660 train_time:26012ms step_avg:91.59ms +step:285/1660 train_time:26104ms step_avg:91.59ms +step:286/1660 train_time:26196ms step_avg:91.59ms +step:287/1660 train_time:26287ms step_avg:91.59ms +step:288/1660 train_time:26380ms step_avg:91.60ms +step:289/1660 train_time:26472ms step_avg:91.60ms +step:290/1660 train_time:26563ms step_avg:91.60ms +step:291/1660 train_time:26654ms step_avg:91.60ms +step:292/1660 train_time:26746ms step_avg:91.59ms +step:293/1660 train_time:26837ms step_avg:91.59ms +step:294/1660 train_time:26928ms step_avg:91.59ms +step:295/1660 train_time:27019ms step_avg:91.59ms +step:296/1660 train_time:27111ms step_avg:91.59ms +step:297/1660 train_time:27203ms step_avg:91.59ms +step:298/1660 train_time:27294ms step_avg:91.59ms +step:299/1660 train_time:27385ms step_avg:91.59ms +step:300/1660 train_time:27478ms step_avg:91.59ms +step:301/1660 train_time:27569ms step_avg:91.59ms +step:302/1660 train_time:27661ms step_avg:91.59ms +step:303/1660 train_time:27753ms step_avg:91.59ms +step:304/1660 train_time:27844ms step_avg:91.59ms +step:305/1660 train_time:27935ms step_avg:91.59ms +step:306/1660 train_time:28025ms step_avg:91.59ms +step:307/1660 train_time:28117ms step_avg:91.59ms +step:308/1660 train_time:28208ms step_avg:91.59ms +step:309/1660 train_time:28300ms step_avg:91.59ms +step:310/1660 train_time:28392ms step_avg:91.59ms +step:311/1660 train_time:28484ms step_avg:91.59ms +step:312/1660 train_time:28576ms step_avg:91.59ms +step:313/1660 train_time:28667ms step_avg:91.59ms +step:314/1660 train_time:28760ms step_avg:91.59ms +step:315/1660 train_time:28852ms step_avg:91.59ms +step:316/1660 train_time:28944ms step_avg:91.59ms +step:317/1660 train_time:29036ms step_avg:91.59ms +step:318/1660 train_time:29126ms step_avg:91.59ms +step:319/1660 train_time:29218ms step_avg:91.59ms +step:320/1660 train_time:29309ms step_avg:91.59ms +step:321/1660 train_time:29401ms step_avg:91.59ms +step:322/1660 train_time:29492ms step_avg:91.59ms +step:323/1660 train_time:29585ms step_avg:91.59ms +step:324/1660 train_time:29677ms step_avg:91.59ms +step:325/1660 train_time:29768ms step_avg:91.59ms +step:326/1660 train_time:29861ms step_avg:91.60ms +step:327/1660 train_time:29953ms step_avg:91.60ms +step:328/1660 train_time:30044ms step_avg:91.60ms 
+step:329/1660 train_time:30135ms step_avg:91.59ms +step:330/1660 train_time:30226ms step_avg:91.59ms +step:331/1660 train_time:30318ms step_avg:91.59ms +step:332/1660 train_time:30410ms step_avg:91.60ms +step:333/1660 train_time:30501ms step_avg:91.59ms +step:334/1660 train_time:30592ms step_avg:91.59ms +step:335/1660 train_time:30685ms step_avg:91.60ms +step:336/1660 train_time:30777ms step_avg:91.60ms +step:337/1660 train_time:30869ms step_avg:91.60ms +step:338/1660 train_time:30961ms step_avg:91.60ms +step:339/1660 train_time:31054ms step_avg:91.60ms +step:340/1660 train_time:31145ms step_avg:91.60ms +step:341/1660 train_time:31236ms step_avg:91.60ms +step:342/1660 train_time:31327ms step_avg:91.60ms +step:343/1660 train_time:31418ms step_avg:91.60ms +step:344/1660 train_time:31509ms step_avg:91.60ms +step:345/1660 train_time:31601ms step_avg:91.60ms +step:346/1660 train_time:31692ms step_avg:91.59ms +step:347/1660 train_time:31784ms step_avg:91.60ms +step:348/1660 train_time:31876ms step_avg:91.60ms +step:349/1660 train_time:31967ms step_avg:91.60ms +step:350/1660 train_time:32059ms step_avg:91.60ms +step:351/1660 train_time:32150ms step_avg:91.60ms +step:352/1660 train_time:32242ms step_avg:91.60ms +step:353/1660 train_time:32333ms step_avg:91.60ms +step:354/1660 train_time:32424ms step_avg:91.59ms +step:355/1660 train_time:32516ms step_avg:91.59ms +step:356/1660 train_time:32607ms step_avg:91.59ms +step:357/1660 train_time:32698ms step_avg:91.59ms +step:358/1660 train_time:32790ms step_avg:91.59ms +step:359/1660 train_time:32883ms step_avg:91.60ms +step:360/1660 train_time:32974ms step_avg:91.60ms +step:361/1660 train_time:33065ms step_avg:91.59ms +step:362/1660 train_time:33157ms step_avg:91.59ms +step:363/1660 train_time:33249ms step_avg:91.59ms +step:364/1660 train_time:33341ms step_avg:91.60ms +step:365/1660 train_time:33432ms step_avg:91.60ms +step:366/1660 train_time:33523ms step_avg:91.59ms +step:367/1660 train_time:33614ms step_avg:91.59ms +step:368/1660 train_time:33705ms step_avg:91.59ms +step:369/1660 train_time:33796ms step_avg:91.59ms +step:370/1660 train_time:33887ms step_avg:91.59ms +step:371/1660 train_time:33980ms step_avg:91.59ms +step:372/1660 train_time:34071ms step_avg:91.59ms +step:373/1660 train_time:34163ms step_avg:91.59ms +step:374/1660 train_time:34255ms step_avg:91.59ms +step:375/1660 train_time:34347ms step_avg:91.59ms +step:375/1660 val_loss:3.8235 train_time:34442ms step_avg:91.84ms +step:376/1660 train_time:34461ms step_avg:91.65ms +step:377/1660 train_time:34536ms step_avg:91.61ms +step:378/1660 train_time:34633ms step_avg:91.62ms +step:379/1660 train_time:34725ms step_avg:91.62ms +step:380/1660 train_time:34818ms step_avg:91.63ms +step:381/1660 train_time:34910ms step_avg:91.63ms +step:382/1660 train_time:35000ms step_avg:91.62ms +step:383/1660 train_time:35091ms step_avg:91.62ms +step:384/1660 train_time:35181ms step_avg:91.62ms +step:385/1660 train_time:35271ms step_avg:91.61ms +step:386/1660 train_time:35362ms step_avg:91.61ms +step:387/1660 train_time:35454ms step_avg:91.61ms +step:388/1660 train_time:35547ms step_avg:91.62ms +step:389/1660 train_time:35640ms step_avg:91.62ms +step:390/1660 train_time:35733ms step_avg:91.62ms +step:391/1660 train_time:35825ms step_avg:91.62ms +step:392/1660 train_time:35917ms step_avg:91.62ms +step:393/1660 train_time:36008ms step_avg:91.62ms +step:394/1660 train_time:36099ms step_avg:91.62ms +step:395/1660 train_time:36190ms step_avg:91.62ms +step:396/1660 train_time:36281ms step_avg:91.62ms +step:397/1660 
train_time:36371ms step_avg:91.62ms +step:398/1660 train_time:36462ms step_avg:91.61ms +step:399/1660 train_time:36554ms step_avg:91.61ms +step:400/1660 train_time:36646ms step_avg:91.61ms +step:401/1660 train_time:36739ms step_avg:91.62ms +step:402/1660 train_time:36832ms step_avg:91.62ms +step:403/1660 train_time:36922ms step_avg:91.62ms +step:404/1660 train_time:37014ms step_avg:91.62ms +step:405/1660 train_time:37105ms step_avg:91.62ms +step:406/1660 train_time:37196ms step_avg:91.62ms +step:407/1660 train_time:37287ms step_avg:91.61ms +step:408/1660 train_time:37378ms step_avg:91.61ms +step:409/1660 train_time:37470ms step_avg:91.61ms +step:410/1660 train_time:37561ms step_avg:91.61ms +step:411/1660 train_time:37653ms step_avg:91.61ms +step:412/1660 train_time:37745ms step_avg:91.61ms +step:413/1660 train_time:37838ms step_avg:91.62ms +step:414/1660 train_time:37930ms step_avg:91.62ms +step:415/1660 train_time:38022ms step_avg:91.62ms +step:416/1660 train_time:38114ms step_avg:91.62ms +step:417/1660 train_time:38206ms step_avg:91.62ms +step:418/1660 train_time:38297ms step_avg:91.62ms +step:419/1660 train_time:38388ms step_avg:91.62ms +step:420/1660 train_time:38479ms step_avg:91.62ms +step:421/1660 train_time:38570ms step_avg:91.62ms +step:422/1660 train_time:38661ms step_avg:91.61ms +step:423/1660 train_time:38753ms step_avg:91.61ms +step:424/1660 train_time:38844ms step_avg:91.61ms +step:425/1660 train_time:38936ms step_avg:91.62ms +step:426/1660 train_time:39028ms step_avg:91.62ms +step:427/1660 train_time:39120ms step_avg:91.62ms +step:428/1660 train_time:39212ms step_avg:91.62ms +step:429/1660 train_time:39303ms step_avg:91.62ms +step:430/1660 train_time:39395ms step_avg:91.62ms +step:431/1660 train_time:39486ms step_avg:91.61ms +step:432/1660 train_time:39578ms step_avg:91.61ms +step:433/1660 train_time:39669ms step_avg:91.62ms +step:434/1660 train_time:39761ms step_avg:91.61ms +step:435/1660 train_time:39852ms step_avg:91.61ms +step:436/1660 train_time:39943ms step_avg:91.61ms +step:437/1660 train_time:40036ms step_avg:91.62ms +step:438/1660 train_time:40129ms step_avg:91.62ms +step:439/1660 train_time:40220ms step_avg:91.62ms +step:440/1660 train_time:40312ms step_avg:91.62ms +step:441/1660 train_time:40402ms step_avg:91.61ms +step:442/1660 train_time:40494ms step_avg:91.62ms +step:443/1660 train_time:40585ms step_avg:91.61ms +step:444/1660 train_time:40677ms step_avg:91.61ms +step:445/1660 train_time:40768ms step_avg:91.61ms +step:446/1660 train_time:40859ms step_avg:91.61ms +step:447/1660 train_time:40950ms step_avg:91.61ms +step:448/1660 train_time:41041ms step_avg:91.61ms +step:449/1660 train_time:41134ms step_avg:91.61ms +step:450/1660 train_time:41225ms step_avg:91.61ms +step:451/1660 train_time:41317ms step_avg:91.61ms +step:452/1660 train_time:41409ms step_avg:91.61ms +step:453/1660 train_time:41500ms step_avg:91.61ms +step:454/1660 train_time:41591ms step_avg:91.61ms +step:455/1660 train_time:41683ms step_avg:91.61ms +step:456/1660 train_time:41775ms step_avg:91.61ms +step:457/1660 train_time:41866ms step_avg:91.61ms +step:458/1660 train_time:41960ms step_avg:91.61ms +step:459/1660 train_time:42050ms step_avg:91.61ms +step:460/1660 train_time:42140ms step_avg:91.61ms +step:461/1660 train_time:42232ms step_avg:91.61ms +step:462/1660 train_time:42325ms step_avg:91.61ms +step:463/1660 train_time:42417ms step_avg:91.61ms +step:464/1660 train_time:42509ms step_avg:91.62ms +step:465/1660 train_time:42601ms step_avg:91.61ms +step:466/1660 train_time:42692ms step_avg:91.61ms 
+step:467/1660 train_time:42783ms step_avg:91.61ms +step:468/1660 train_time:42874ms step_avg:91.61ms +step:469/1660 train_time:42966ms step_avg:91.61ms +step:470/1660 train_time:43057ms step_avg:91.61ms +step:471/1660 train_time:43149ms step_avg:91.61ms +step:472/1660 train_time:43241ms step_avg:91.61ms +step:473/1660 train_time:43333ms step_avg:91.61ms +step:474/1660 train_time:43425ms step_avg:91.61ms +step:475/1660 train_time:43517ms step_avg:91.62ms +step:476/1660 train_time:43609ms step_avg:91.62ms +step:477/1660 train_time:43700ms step_avg:91.61ms +step:478/1660 train_time:43792ms step_avg:91.62ms +step:479/1660 train_time:43883ms step_avg:91.61ms +step:480/1660 train_time:43974ms step_avg:91.61ms +step:481/1660 train_time:44065ms step_avg:91.61ms +step:482/1660 train_time:44157ms step_avg:91.61ms +step:483/1660 train_time:44250ms step_avg:91.61ms +step:484/1660 train_time:44341ms step_avg:91.61ms +step:485/1660 train_time:44434ms step_avg:91.62ms +step:486/1660 train_time:44525ms step_avg:91.61ms +step:487/1660 train_time:44616ms step_avg:91.61ms +step:488/1660 train_time:44708ms step_avg:91.62ms +step:489/1660 train_time:44800ms step_avg:91.61ms +step:490/1660 train_time:44891ms step_avg:91.61ms +step:491/1660 train_time:44983ms step_avg:91.62ms +step:492/1660 train_time:45074ms step_avg:91.61ms +step:493/1660 train_time:45166ms step_avg:91.61ms +step:494/1660 train_time:45257ms step_avg:91.61ms +step:495/1660 train_time:45349ms step_avg:91.61ms +step:496/1660 train_time:45440ms step_avg:91.61ms +step:497/1660 train_time:45533ms step_avg:91.62ms +step:498/1660 train_time:45625ms step_avg:91.62ms +step:499/1660 train_time:45717ms step_avg:91.62ms +step:500/1660 train_time:45808ms step_avg:91.62ms +step:500/1660 val_loss:3.7191 train_time:45902ms step_avg:91.80ms +step:501/1660 train_time:45921ms step_avg:91.66ms +step:502/1660 train_time:45994ms step_avg:91.62ms +step:503/1660 train_time:46089ms step_avg:91.63ms +step:504/1660 train_time:46181ms step_avg:91.63ms +step:505/1660 train_time:46271ms step_avg:91.63ms +step:506/1660 train_time:46362ms step_avg:91.62ms +step:507/1660 train_time:46452ms step_avg:91.62ms +step:508/1660 train_time:46543ms step_avg:91.62ms +step:509/1660 train_time:46634ms step_avg:91.62ms +step:510/1660 train_time:46725ms step_avg:91.62ms +step:511/1660 train_time:46816ms step_avg:91.62ms +step:512/1660 train_time:46909ms step_avg:91.62ms +step:513/1660 train_time:47002ms step_avg:91.62ms +step:514/1660 train_time:47094ms step_avg:91.62ms +step:515/1660 train_time:47187ms step_avg:91.63ms +step:516/1660 train_time:47279ms step_avg:91.63ms +step:517/1660 train_time:47369ms step_avg:91.62ms +step:518/1660 train_time:47460ms step_avg:91.62ms +step:519/1660 train_time:47550ms step_avg:91.62ms +step:520/1660 train_time:47642ms step_avg:91.62ms +step:521/1660 train_time:47732ms step_avg:91.62ms +step:522/1660 train_time:47824ms step_avg:91.62ms +step:523/1660 train_time:47916ms step_avg:91.62ms +step:524/1660 train_time:48008ms step_avg:91.62ms +step:525/1660 train_time:48100ms step_avg:91.62ms +step:526/1660 train_time:48192ms step_avg:91.62ms +step:527/1660 train_time:48286ms step_avg:91.62ms +step:528/1660 train_time:48377ms step_avg:91.62ms +step:529/1660 train_time:48468ms step_avg:91.62ms +step:530/1660 train_time:48559ms step_avg:91.62ms +step:531/1660 train_time:48650ms step_avg:91.62ms +step:532/1660 train_time:48741ms step_avg:91.62ms +step:533/1660 train_time:48832ms step_avg:91.62ms +step:534/1660 train_time:48924ms step_avg:91.62ms +step:535/1660 
train_time:49016ms step_avg:91.62ms +step:536/1660 train_time:49108ms step_avg:91.62ms +step:537/1660 train_time:49199ms step_avg:91.62ms +step:538/1660 train_time:49291ms step_avg:91.62ms +step:539/1660 train_time:49384ms step_avg:91.62ms +step:540/1660 train_time:49475ms step_avg:91.62ms +step:541/1660 train_time:49566ms step_avg:91.62ms +step:542/1660 train_time:49657ms step_avg:91.62ms +step:543/1660 train_time:49748ms step_avg:91.62ms +step:544/1660 train_time:49839ms step_avg:91.62ms +step:545/1660 train_time:49930ms step_avg:91.61ms +step:546/1660 train_time:50022ms step_avg:91.62ms +step:547/1660 train_time:50113ms step_avg:91.61ms +step:548/1660 train_time:50204ms step_avg:91.61ms +step:549/1660 train_time:50295ms step_avg:91.61ms +step:550/1660 train_time:50387ms step_avg:91.61ms +step:551/1660 train_time:50480ms step_avg:91.61ms +step:552/1660 train_time:50571ms step_avg:91.61ms +step:553/1660 train_time:50663ms step_avg:91.62ms +step:554/1660 train_time:50754ms step_avg:91.61ms +step:555/1660 train_time:50846ms step_avg:91.61ms +step:556/1660 train_time:50939ms step_avg:91.62ms +step:557/1660 train_time:51031ms step_avg:91.62ms +step:558/1660 train_time:51124ms step_avg:91.62ms +step:559/1660 train_time:51216ms step_avg:91.62ms +step:560/1660 train_time:51309ms step_avg:91.62ms +step:561/1660 train_time:51402ms step_avg:91.63ms +step:562/1660 train_time:51494ms step_avg:91.63ms +step:563/1660 train_time:51587ms step_avg:91.63ms +step:564/1660 train_time:51681ms step_avg:91.63ms +step:565/1660 train_time:51773ms step_avg:91.63ms +step:566/1660 train_time:51866ms step_avg:91.64ms +step:567/1660 train_time:51959ms step_avg:91.64ms +step:568/1660 train_time:52051ms step_avg:91.64ms +step:569/1660 train_time:52145ms step_avg:91.64ms +step:570/1660 train_time:52237ms step_avg:91.64ms +step:571/1660 train_time:52330ms step_avg:91.65ms +step:572/1660 train_time:52423ms step_avg:91.65ms +step:573/1660 train_time:52515ms step_avg:91.65ms +step:574/1660 train_time:52608ms step_avg:91.65ms +step:575/1660 train_time:52701ms step_avg:91.65ms +step:576/1660 train_time:52793ms step_avg:91.66ms +step:577/1660 train_time:52888ms step_avg:91.66ms +step:578/1660 train_time:52981ms step_avg:91.66ms +step:579/1660 train_time:53073ms step_avg:91.66ms +step:580/1660 train_time:53167ms step_avg:91.67ms +step:581/1660 train_time:53260ms step_avg:91.67ms +step:582/1660 train_time:53352ms step_avg:91.67ms +step:583/1660 train_time:53446ms step_avg:91.67ms +step:584/1660 train_time:53538ms step_avg:91.68ms +step:585/1660 train_time:53631ms step_avg:91.68ms +step:586/1660 train_time:53724ms step_avg:91.68ms +step:587/1660 train_time:53816ms step_avg:91.68ms +step:588/1660 train_time:53909ms step_avg:91.68ms +step:589/1660 train_time:54002ms step_avg:91.68ms +step:590/1660 train_time:54095ms step_avg:91.69ms +step:591/1660 train_time:54188ms step_avg:91.69ms +step:592/1660 train_time:54281ms step_avg:91.69ms +step:593/1660 train_time:54373ms step_avg:91.69ms +step:594/1660 train_time:54466ms step_avg:91.69ms +step:595/1660 train_time:54560ms step_avg:91.70ms +step:596/1660 train_time:54652ms step_avg:91.70ms +step:597/1660 train_time:54745ms step_avg:91.70ms +step:598/1660 train_time:54837ms step_avg:91.70ms +step:599/1660 train_time:54930ms step_avg:91.70ms +step:600/1660 train_time:55024ms step_avg:91.71ms +step:601/1660 train_time:55116ms step_avg:91.71ms +step:602/1660 train_time:55210ms step_avg:91.71ms +step:603/1660 train_time:55303ms step_avg:91.71ms +step:604/1660 train_time:55396ms step_avg:91.72ms 
+step:605/1660 train_time:55490ms step_avg:91.72ms +step:606/1660 train_time:55582ms step_avg:91.72ms +step:607/1660 train_time:55674ms step_avg:91.72ms +step:608/1660 train_time:55767ms step_avg:91.72ms +step:609/1660 train_time:55860ms step_avg:91.72ms +step:610/1660 train_time:55953ms step_avg:91.73ms +step:611/1660 train_time:56047ms step_avg:91.73ms +step:612/1660 train_time:56140ms step_avg:91.73ms +step:613/1660 train_time:56233ms step_avg:91.73ms +step:614/1660 train_time:56325ms step_avg:91.74ms +step:615/1660 train_time:56418ms step_avg:91.74ms +step:616/1660 train_time:56510ms step_avg:91.74ms +step:617/1660 train_time:56603ms step_avg:91.74ms +step:618/1660 train_time:56694ms step_avg:91.74ms +step:619/1660 train_time:56788ms step_avg:91.74ms +step:620/1660 train_time:56881ms step_avg:91.74ms +step:621/1660 train_time:56973ms step_avg:91.74ms +step:622/1660 train_time:57068ms step_avg:91.75ms +step:623/1660 train_time:57160ms step_avg:91.75ms +step:624/1660 train_time:57253ms step_avg:91.75ms +step:625/1660 train_time:57345ms step_avg:91.75ms +step:625/1660 val_loss:3.6151 train_time:57440ms step_avg:91.90ms +step:626/1660 train_time:57459ms step_avg:91.79ms +step:627/1660 train_time:57540ms step_avg:91.77ms +step:628/1660 train_time:57640ms step_avg:91.78ms +step:629/1660 train_time:57737ms step_avg:91.79ms +step:630/1660 train_time:57830ms step_avg:91.79ms +step:631/1660 train_time:57921ms step_avg:91.79ms +step:632/1660 train_time:58012ms step_avg:91.79ms +step:633/1660 train_time:58103ms step_avg:91.79ms +step:634/1660 train_time:58195ms step_avg:91.79ms +step:635/1660 train_time:58286ms step_avg:91.79ms +step:636/1660 train_time:58380ms step_avg:91.79ms +step:637/1660 train_time:58474ms step_avg:91.80ms +step:638/1660 train_time:58570ms step_avg:91.80ms +step:639/1660 train_time:58664ms step_avg:91.81ms +step:640/1660 train_time:58758ms step_avg:91.81ms +step:641/1660 train_time:58854ms step_avg:91.82ms +step:642/1660 train_time:58947ms step_avg:91.82ms +step:643/1660 train_time:59038ms step_avg:91.82ms +step:644/1660 train_time:59130ms step_avg:91.82ms +step:645/1660 train_time:59221ms step_avg:91.81ms +step:646/1660 train_time:59313ms step_avg:91.82ms +step:647/1660 train_time:59405ms step_avg:91.82ms +step:648/1660 train_time:59498ms step_avg:91.82ms +step:649/1660 train_time:59593ms step_avg:91.82ms +step:650/1660 train_time:59687ms step_avg:91.83ms +step:651/1660 train_time:59781ms step_avg:91.83ms +step:652/1660 train_time:59875ms step_avg:91.83ms +step:653/1660 train_time:59967ms step_avg:91.83ms +step:654/1660 train_time:60058ms step_avg:91.83ms +step:655/1660 train_time:60150ms step_avg:91.83ms +step:656/1660 train_time:60242ms step_avg:91.83ms +step:657/1660 train_time:60334ms step_avg:91.83ms +step:658/1660 train_time:60426ms step_avg:91.83ms +step:659/1660 train_time:60519ms step_avg:91.83ms +step:660/1660 train_time:60612ms step_avg:91.84ms +step:661/1660 train_time:60707ms step_avg:91.84ms +step:662/1660 train_time:60799ms step_avg:91.84ms +step:663/1660 train_time:60895ms step_avg:91.85ms +step:664/1660 train_time:60988ms step_avg:91.85ms +step:665/1660 train_time:61080ms step_avg:91.85ms +step:666/1660 train_time:61172ms step_avg:91.85ms +step:667/1660 train_time:61264ms step_avg:91.85ms +step:668/1660 train_time:61357ms step_avg:91.85ms +step:669/1660 train_time:61449ms step_avg:91.85ms +step:670/1660 train_time:61541ms step_avg:91.85ms +step:671/1660 train_time:61635ms step_avg:91.86ms +step:672/1660 train_time:61728ms step_avg:91.86ms +step:673/1660 
train_time:61821ms step_avg:91.86ms +step:674/1660 train_time:61915ms step_avg:91.86ms +step:675/1660 train_time:62008ms step_avg:91.86ms +step:676/1660 train_time:62100ms step_avg:91.86ms +step:677/1660 train_time:62192ms step_avg:91.86ms +step:678/1660 train_time:62284ms step_avg:91.86ms +step:679/1660 train_time:62375ms step_avg:91.86ms +step:680/1660 train_time:62468ms step_avg:91.87ms +step:681/1660 train_time:62561ms step_avg:91.87ms +step:682/1660 train_time:62655ms step_avg:91.87ms +step:683/1660 train_time:62747ms step_avg:91.87ms +step:684/1660 train_time:62840ms step_avg:91.87ms +step:685/1660 train_time:62935ms step_avg:91.88ms +step:686/1660 train_time:63027ms step_avg:91.88ms +step:687/1660 train_time:63119ms step_avg:91.88ms +step:688/1660 train_time:63212ms step_avg:91.88ms +step:689/1660 train_time:63304ms step_avg:91.88ms +step:690/1660 train_time:63397ms step_avg:91.88ms +step:691/1660 train_time:63490ms step_avg:91.88ms +step:692/1660 train_time:63582ms step_avg:91.88ms +step:693/1660 train_time:63676ms step_avg:91.88ms +step:694/1660 train_time:63769ms step_avg:91.89ms +step:695/1660 train_time:63861ms step_avg:91.89ms +step:696/1660 train_time:63956ms step_avg:91.89ms +step:697/1660 train_time:64051ms step_avg:91.89ms +step:698/1660 train_time:64143ms step_avg:91.90ms +step:699/1660 train_time:64235ms step_avg:91.90ms +step:700/1660 train_time:64327ms step_avg:91.90ms +step:701/1660 train_time:64419ms step_avg:91.90ms +step:702/1660 train_time:64513ms step_avg:91.90ms +step:703/1660 train_time:64605ms step_avg:91.90ms +step:704/1660 train_time:64697ms step_avg:91.90ms +step:705/1660 train_time:64790ms step_avg:91.90ms +step:706/1660 train_time:64883ms step_avg:91.90ms +step:707/1660 train_time:64976ms step_avg:91.90ms +step:708/1660 train_time:65069ms step_avg:91.91ms +step:709/1660 train_time:65161ms step_avg:91.91ms +step:710/1660 train_time:65255ms step_avg:91.91ms +step:711/1660 train_time:65348ms step_avg:91.91ms +step:712/1660 train_time:65440ms step_avg:91.91ms +step:713/1660 train_time:65533ms step_avg:91.91ms +step:714/1660 train_time:65627ms step_avg:91.91ms +step:715/1660 train_time:65719ms step_avg:91.92ms +step:716/1660 train_time:65813ms step_avg:91.92ms +step:717/1660 train_time:65906ms step_avg:91.92ms +step:718/1660 train_time:65999ms step_avg:91.92ms +step:719/1660 train_time:66092ms step_avg:91.92ms +step:720/1660 train_time:66184ms step_avg:91.92ms +step:721/1660 train_time:66277ms step_avg:91.92ms +step:722/1660 train_time:66369ms step_avg:91.92ms +step:723/1660 train_time:66461ms step_avg:91.92ms +step:724/1660 train_time:66555ms step_avg:91.93ms +step:725/1660 train_time:66649ms step_avg:91.93ms +step:726/1660 train_time:66741ms step_avg:91.93ms +step:727/1660 train_time:66835ms step_avg:91.93ms +step:728/1660 train_time:66928ms step_avg:91.93ms +step:729/1660 train_time:67020ms step_avg:91.93ms +step:730/1660 train_time:67113ms step_avg:91.94ms +step:731/1660 train_time:67206ms step_avg:91.94ms +step:732/1660 train_time:67298ms step_avg:91.94ms +step:733/1660 train_time:67391ms step_avg:91.94ms +step:734/1660 train_time:67484ms step_avg:91.94ms +step:735/1660 train_time:67578ms step_avg:91.94ms +step:736/1660 train_time:67671ms step_avg:91.94ms +step:737/1660 train_time:67763ms step_avg:91.94ms +step:738/1660 train_time:67855ms step_avg:91.94ms +step:739/1660 train_time:67948ms step_avg:91.95ms +step:740/1660 train_time:68041ms step_avg:91.95ms +step:741/1660 train_time:68134ms step_avg:91.95ms +step:742/1660 train_time:68227ms step_avg:91.95ms 
+step:743/1660 train_time:68319ms step_avg:91.95ms +step:744/1660 train_time:68411ms step_avg:91.95ms +step:745/1660 train_time:68504ms step_avg:91.95ms +step:746/1660 train_time:68597ms step_avg:91.95ms +step:747/1660 train_time:68690ms step_avg:91.95ms +step:748/1660 train_time:68782ms step_avg:91.95ms +step:749/1660 train_time:68875ms step_avg:91.96ms +step:750/1660 train_time:68967ms step_avg:91.96ms +step:750/1660 val_loss:3.5624 train_time:69061ms step_avg:92.08ms +step:751/1660 train_time:69080ms step_avg:91.98ms +step:752/1660 train_time:69159ms step_avg:91.97ms +step:753/1660 train_time:69259ms step_avg:91.98ms +step:754/1660 train_time:69352ms step_avg:91.98ms +step:755/1660 train_time:69443ms step_avg:91.98ms +step:756/1660 train_time:69534ms step_avg:91.98ms +step:757/1660 train_time:69626ms step_avg:91.98ms +step:758/1660 train_time:69718ms step_avg:91.98ms +step:759/1660 train_time:69809ms step_avg:91.98ms +step:760/1660 train_time:69901ms step_avg:91.97ms +step:761/1660 train_time:69994ms step_avg:91.98ms +step:762/1660 train_time:70089ms step_avg:91.98ms +step:763/1660 train_time:70185ms step_avg:91.99ms +step:764/1660 train_time:70279ms step_avg:91.99ms +step:765/1660 train_time:70372ms step_avg:91.99ms +step:766/1660 train_time:70465ms step_avg:91.99ms +step:767/1660 train_time:70557ms step_avg:91.99ms +step:768/1660 train_time:70649ms step_avg:91.99ms +step:769/1660 train_time:70740ms step_avg:91.99ms +step:770/1660 train_time:70832ms step_avg:91.99ms +step:771/1660 train_time:70924ms step_avg:91.99ms +step:772/1660 train_time:71017ms step_avg:91.99ms +step:773/1660 train_time:71111ms step_avg:91.99ms +step:774/1660 train_time:71204ms step_avg:92.00ms +step:775/1660 train_time:71298ms step_avg:92.00ms +step:776/1660 train_time:71392ms step_avg:92.00ms +step:777/1660 train_time:71485ms step_avg:92.00ms +step:778/1660 train_time:71577ms step_avg:92.00ms +step:779/1660 train_time:71669ms step_avg:92.00ms +step:780/1660 train_time:71761ms step_avg:92.00ms +step:781/1660 train_time:71854ms step_avg:92.00ms +step:782/1660 train_time:71946ms step_avg:92.00ms +step:783/1660 train_time:72039ms step_avg:92.00ms +step:784/1660 train_time:72134ms step_avg:92.01ms +step:785/1660 train_time:72227ms step_avg:92.01ms +step:786/1660 train_time:72320ms step_avg:92.01ms +step:787/1660 train_time:72415ms step_avg:92.01ms +step:788/1660 train_time:72508ms step_avg:92.01ms +step:789/1660 train_time:72600ms step_avg:92.01ms +step:790/1660 train_time:72694ms step_avg:92.02ms +step:791/1660 train_time:72785ms step_avg:92.02ms +step:792/1660 train_time:72877ms step_avg:92.02ms +step:793/1660 train_time:72970ms step_avg:92.02ms +step:794/1660 train_time:73062ms step_avg:92.02ms +step:795/1660 train_time:73156ms step_avg:92.02ms +step:796/1660 train_time:73249ms step_avg:92.02ms +step:797/1660 train_time:73342ms step_avg:92.02ms +step:798/1660 train_time:73435ms step_avg:92.02ms +step:799/1660 train_time:73528ms step_avg:92.02ms +step:800/1660 train_time:73620ms step_avg:92.02ms +step:801/1660 train_time:73713ms step_avg:92.03ms +step:802/1660 train_time:73806ms step_avg:92.03ms +step:803/1660 train_time:73898ms step_avg:92.03ms +step:804/1660 train_time:73991ms step_avg:92.03ms +step:805/1660 train_time:74083ms step_avg:92.03ms +step:806/1660 train_time:74176ms step_avg:92.03ms +step:807/1660 train_time:74269ms step_avg:92.03ms +step:808/1660 train_time:74361ms step_avg:92.03ms +step:809/1660 train_time:74455ms step_avg:92.03ms +step:810/1660 train_time:74548ms step_avg:92.03ms +step:811/1660 
train_time:74641ms step_avg:92.04ms +step:812/1660 train_time:74733ms step_avg:92.04ms +step:813/1660 train_time:74826ms step_avg:92.04ms +step:814/1660 train_time:74918ms step_avg:92.04ms +step:815/1660 train_time:75012ms step_avg:92.04ms +step:816/1660 train_time:75104ms step_avg:92.04ms +step:817/1660 train_time:75197ms step_avg:92.04ms +step:818/1660 train_time:75290ms step_avg:92.04ms +step:819/1660 train_time:75384ms step_avg:92.04ms +step:820/1660 train_time:75477ms step_avg:92.04ms +step:821/1660 train_time:75569ms step_avg:92.05ms +step:822/1660 train_time:75662ms step_avg:92.05ms +step:823/1660 train_time:75755ms step_avg:92.05ms +step:824/1660 train_time:75847ms step_avg:92.05ms +step:825/1660 train_time:75940ms step_avg:92.05ms +step:826/1660 train_time:76032ms step_avg:92.05ms +step:827/1660 train_time:76125ms step_avg:92.05ms +step:828/1660 train_time:76217ms step_avg:92.05ms +step:829/1660 train_time:76310ms step_avg:92.05ms +step:830/1660 train_time:76402ms step_avg:92.05ms +step:831/1660 train_time:76496ms step_avg:92.05ms +step:832/1660 train_time:76589ms step_avg:92.05ms +step:833/1660 train_time:76681ms step_avg:92.05ms +step:834/1660 train_time:76774ms step_avg:92.06ms +step:835/1660 train_time:76866ms step_avg:92.06ms +step:836/1660 train_time:76959ms step_avg:92.06ms +step:837/1660 train_time:77052ms step_avg:92.06ms +step:838/1660 train_time:77144ms step_avg:92.06ms +step:839/1660 train_time:77236ms step_avg:92.06ms +step:840/1660 train_time:77329ms step_avg:92.06ms +step:841/1660 train_time:77421ms step_avg:92.06ms +step:842/1660 train_time:77515ms step_avg:92.06ms +step:843/1660 train_time:77608ms step_avg:92.06ms +step:844/1660 train_time:77700ms step_avg:92.06ms +step:845/1660 train_time:77793ms step_avg:92.06ms +step:846/1660 train_time:77886ms step_avg:92.06ms +step:847/1660 train_time:77978ms step_avg:92.06ms +step:848/1660 train_time:78070ms step_avg:92.06ms +step:849/1660 train_time:78162ms step_avg:92.06ms +step:850/1660 train_time:78255ms step_avg:92.06ms +step:851/1660 train_time:78348ms step_avg:92.07ms +step:852/1660 train_time:78440ms step_avg:92.07ms +step:853/1660 train_time:78533ms step_avg:92.07ms +step:854/1660 train_time:78626ms step_avg:92.07ms +step:855/1660 train_time:78719ms step_avg:92.07ms +step:856/1660 train_time:78812ms step_avg:92.07ms +step:857/1660 train_time:78905ms step_avg:92.07ms +step:858/1660 train_time:78998ms step_avg:92.07ms +step:859/1660 train_time:79092ms step_avg:92.07ms +step:860/1660 train_time:79185ms step_avg:92.08ms +step:861/1660 train_time:79278ms step_avg:92.08ms +step:862/1660 train_time:79370ms step_avg:92.08ms +step:863/1660 train_time:79462ms step_avg:92.08ms +step:864/1660 train_time:79555ms step_avg:92.08ms +step:865/1660 train_time:79648ms step_avg:92.08ms +step:866/1660 train_time:79741ms step_avg:92.08ms +step:867/1660 train_time:79834ms step_avg:92.08ms +step:868/1660 train_time:79927ms step_avg:92.08ms +step:869/1660 train_time:80020ms step_avg:92.08ms +step:870/1660 train_time:80113ms step_avg:92.08ms +step:871/1660 train_time:80206ms step_avg:92.08ms +step:872/1660 train_time:80298ms step_avg:92.09ms +step:873/1660 train_time:80393ms step_avg:92.09ms +step:874/1660 train_time:80486ms step_avg:92.09ms +step:875/1660 train_time:80578ms step_avg:92.09ms +step:875/1660 val_loss:3.5183 train_time:80673ms step_avg:92.20ms +step:876/1660 train_time:80692ms step_avg:92.11ms +step:877/1660 train_time:80771ms step_avg:92.10ms +step:878/1660 train_time:80867ms step_avg:92.10ms +step:879/1660 train_time:80959ms 
step_avg:92.10ms +step:880/1660 train_time:81051ms step_avg:92.10ms +step:881/1660 train_time:81143ms step_avg:92.10ms +step:882/1660 train_time:81233ms step_avg:92.10ms +step:883/1660 train_time:81325ms step_avg:92.10ms +step:884/1660 train_time:81417ms step_avg:92.10ms +step:885/1660 train_time:81510ms step_avg:92.10ms +step:886/1660 train_time:81604ms step_avg:92.10ms +step:887/1660 train_time:81698ms step_avg:92.11ms +step:888/1660 train_time:81793ms step_avg:92.11ms +step:889/1660 train_time:81889ms step_avg:92.11ms +step:890/1660 train_time:81982ms step_avg:92.11ms +step:891/1660 train_time:82074ms step_avg:92.11ms +step:892/1660 train_time:82167ms step_avg:92.12ms +step:893/1660 train_time:82259ms step_avg:92.12ms +step:894/1660 train_time:82352ms step_avg:92.12ms +step:895/1660 train_time:82443ms step_avg:92.12ms +step:896/1660 train_time:82535ms step_avg:92.12ms +step:897/1660 train_time:82630ms step_avg:92.12ms +step:898/1660 train_time:82724ms step_avg:92.12ms +step:899/1660 train_time:82818ms step_avg:92.12ms +step:900/1660 train_time:82911ms step_avg:92.12ms +step:901/1660 train_time:83004ms step_avg:92.12ms +step:902/1660 train_time:83096ms step_avg:92.12ms +step:903/1660 train_time:83189ms step_avg:92.12ms +step:904/1660 train_time:83281ms step_avg:92.12ms +step:905/1660 train_time:83373ms step_avg:92.12ms +step:906/1660 train_time:83465ms step_avg:92.13ms +step:907/1660 train_time:83559ms step_avg:92.13ms +step:908/1660 train_time:83651ms step_avg:92.13ms +step:909/1660 train_time:83745ms step_avg:92.13ms +step:910/1660 train_time:83839ms step_avg:92.13ms +step:911/1660 train_time:83932ms step_avg:92.13ms +step:912/1660 train_time:84025ms step_avg:92.13ms +step:913/1660 train_time:84117ms step_avg:92.13ms +step:914/1660 train_time:84209ms step_avg:92.13ms +step:915/1660 train_time:84301ms step_avg:92.13ms +step:916/1660 train_time:84393ms step_avg:92.13ms +step:917/1660 train_time:84486ms step_avg:92.13ms +step:918/1660 train_time:84580ms step_avg:92.13ms +step:919/1660 train_time:84672ms step_avg:92.13ms +step:920/1660 train_time:84765ms step_avg:92.14ms +step:921/1660 train_time:84859ms step_avg:92.14ms +step:922/1660 train_time:84951ms step_avg:92.14ms +step:923/1660 train_time:85044ms step_avg:92.14ms +step:924/1660 train_time:85136ms step_avg:92.14ms +step:925/1660 train_time:85228ms step_avg:92.14ms +step:926/1660 train_time:85321ms step_avg:92.14ms +step:927/1660 train_time:85413ms step_avg:92.14ms +step:928/1660 train_time:85506ms step_avg:92.14ms +step:929/1660 train_time:85599ms step_avg:92.14ms +step:930/1660 train_time:85691ms step_avg:92.14ms +step:931/1660 train_time:85786ms step_avg:92.14ms +step:932/1660 train_time:85880ms step_avg:92.15ms +step:933/1660 train_time:85972ms step_avg:92.15ms +step:934/1660 train_time:86065ms step_avg:92.15ms +step:935/1660 train_time:86157ms step_avg:92.15ms +step:936/1660 train_time:86250ms step_avg:92.15ms +step:937/1660 train_time:86342ms step_avg:92.15ms +step:938/1660 train_time:86434ms step_avg:92.15ms +step:939/1660 train_time:86528ms step_avg:92.15ms +step:940/1660 train_time:86621ms step_avg:92.15ms +step:941/1660 train_time:86714ms step_avg:92.15ms +step:942/1660 train_time:86809ms step_avg:92.15ms +step:943/1660 train_time:86902ms step_avg:92.15ms +step:944/1660 train_time:86994ms step_avg:92.16ms +step:945/1660 train_time:87088ms step_avg:92.16ms +step:946/1660 train_time:87180ms step_avg:92.16ms +step:947/1660 train_time:87272ms step_avg:92.16ms +step:948/1660 train_time:87363ms step_avg:92.16ms +step:949/1660 
train_time:87455ms step_avg:92.15ms +step:950/1660 train_time:87548ms step_avg:92.16ms +step:951/1660 train_time:87640ms step_avg:92.16ms +step:952/1660 train_time:87732ms step_avg:92.16ms +step:953/1660 train_time:87826ms step_avg:92.16ms +step:954/1660 train_time:87920ms step_avg:92.16ms +step:955/1660 train_time:88012ms step_avg:92.16ms +step:956/1660 train_time:88106ms step_avg:92.16ms +step:957/1660 train_time:88198ms step_avg:92.16ms +step:958/1660 train_time:88291ms step_avg:92.16ms +step:959/1660 train_time:88384ms step_avg:92.16ms +step:960/1660 train_time:88476ms step_avg:92.16ms +step:961/1660 train_time:88569ms step_avg:92.16ms +step:962/1660 train_time:88662ms step_avg:92.16ms +step:963/1660 train_time:88753ms step_avg:92.16ms +step:964/1660 train_time:88848ms step_avg:92.17ms +step:965/1660 train_time:88941ms step_avg:92.17ms +step:966/1660 train_time:89033ms step_avg:92.17ms +step:967/1660 train_time:89127ms step_avg:92.17ms +step:968/1660 train_time:89220ms step_avg:92.17ms +step:969/1660 train_time:89312ms step_avg:92.17ms +step:970/1660 train_time:89405ms step_avg:92.17ms +step:971/1660 train_time:89497ms step_avg:92.17ms +step:972/1660 train_time:89590ms step_avg:92.17ms +step:973/1660 train_time:89683ms step_avg:92.17ms +step:974/1660 train_time:89776ms step_avg:92.17ms +step:975/1660 train_time:89868ms step_avg:92.17ms +step:976/1660 train_time:89961ms step_avg:92.17ms +step:977/1660 train_time:90053ms step_avg:92.17ms +step:978/1660 train_time:90146ms step_avg:92.17ms +step:979/1660 train_time:90239ms step_avg:92.17ms +step:980/1660 train_time:90331ms step_avg:92.17ms +step:981/1660 train_time:90425ms step_avg:92.18ms +step:982/1660 train_time:90518ms step_avg:92.18ms +step:983/1660 train_time:90610ms step_avg:92.18ms +step:984/1660 train_time:90704ms step_avg:92.18ms +step:985/1660 train_time:90796ms step_avg:92.18ms +step:986/1660 train_time:90888ms step_avg:92.18ms +step:987/1660 train_time:90981ms step_avg:92.18ms +step:988/1660 train_time:91073ms step_avg:92.18ms +step:989/1660 train_time:91165ms step_avg:92.18ms +step:990/1660 train_time:91257ms step_avg:92.18ms +step:991/1660 train_time:91350ms step_avg:92.18ms +step:992/1660 train_time:91443ms step_avg:92.18ms +step:993/1660 train_time:91535ms step_avg:92.18ms +step:994/1660 train_time:91628ms step_avg:92.18ms +step:995/1660 train_time:91723ms step_avg:92.18ms +step:996/1660 train_time:91814ms step_avg:92.18ms +step:997/1660 train_time:91908ms step_avg:92.18ms +step:998/1660 train_time:92001ms step_avg:92.19ms +step:999/1660 train_time:92094ms step_avg:92.19ms +step:1000/1660 train_time:92188ms step_avg:92.19ms +step:1000/1660 val_loss:3.4677 train_time:92282ms step_avg:92.28ms +step:1001/1660 train_time:92302ms step_avg:92.21ms +step:1002/1660 train_time:92377ms step_avg:92.19ms +step:1003/1660 train_time:92477ms step_avg:92.20ms +step:1004/1660 train_time:92570ms step_avg:92.20ms +step:1005/1660 train_time:92662ms step_avg:92.20ms +step:1006/1660 train_time:92753ms step_avg:92.20ms +step:1007/1660 train_time:92844ms step_avg:92.20ms +step:1008/1660 train_time:92936ms step_avg:92.20ms +step:1009/1660 train_time:93027ms step_avg:92.20ms +step:1010/1660 train_time:93120ms step_avg:92.20ms +step:1011/1660 train_time:93212ms step_avg:92.20ms +step:1012/1660 train_time:93307ms step_avg:92.20ms +step:1013/1660 train_time:93403ms step_avg:92.20ms +step:1014/1660 train_time:93498ms step_avg:92.21ms +step:1015/1660 train_time:93591ms step_avg:92.21ms +step:1016/1660 train_time:93683ms step_avg:92.21ms +step:1017/1660 
train_time:93774ms step_avg:92.21ms +step:1018/1660 train_time:93866ms step_avg:92.21ms +step:1019/1660 train_time:93960ms step_avg:92.21ms +step:1020/1660 train_time:94052ms step_avg:92.21ms +step:1021/1660 train_time:94144ms step_avg:92.21ms +step:1022/1660 train_time:94238ms step_avg:92.21ms +step:1023/1660 train_time:94331ms step_avg:92.21ms +step:1024/1660 train_time:94426ms step_avg:92.21ms +step:1025/1660 train_time:94519ms step_avg:92.21ms +step:1026/1660 train_time:94612ms step_avg:92.21ms +step:1027/1660 train_time:94704ms step_avg:92.21ms +step:1028/1660 train_time:94796ms step_avg:92.21ms +step:1029/1660 train_time:94888ms step_avg:92.21ms +step:1030/1660 train_time:94981ms step_avg:92.21ms +step:1031/1660 train_time:95073ms step_avg:92.21ms +step:1032/1660 train_time:95165ms step_avg:92.21ms +step:1033/1660 train_time:95259ms step_avg:92.22ms +step:1034/1660 train_time:95352ms step_avg:92.22ms +step:1035/1660 train_time:95445ms step_avg:92.22ms +step:1036/1660 train_time:95540ms step_avg:92.22ms +step:1037/1660 train_time:95632ms step_avg:92.22ms +step:1038/1660 train_time:95724ms step_avg:92.22ms +step:1039/1660 train_time:95816ms step_avg:92.22ms +step:1040/1660 train_time:95907ms step_avg:92.22ms +step:1041/1660 train_time:96000ms step_avg:92.22ms +step:1042/1660 train_time:96094ms step_avg:92.22ms +step:1043/1660 train_time:96186ms step_avg:92.22ms +step:1044/1660 train_time:96280ms step_avg:92.22ms +step:1045/1660 train_time:96373ms step_avg:92.22ms +step:1046/1660 train_time:96465ms step_avg:92.22ms +step:1047/1660 train_time:96560ms step_avg:92.23ms +step:1048/1660 train_time:96653ms step_avg:92.23ms +step:1049/1660 train_time:96746ms step_avg:92.23ms +step:1050/1660 train_time:96838ms step_avg:92.23ms +step:1051/1660 train_time:96929ms step_avg:92.23ms +step:1052/1660 train_time:97022ms step_avg:92.23ms +step:1053/1660 train_time:97115ms step_avg:92.23ms +step:1054/1660 train_time:97206ms step_avg:92.23ms +step:1055/1660 train_time:97299ms step_avg:92.23ms +step:1056/1660 train_time:97392ms step_avg:92.23ms +step:1057/1660 train_time:97485ms step_avg:92.23ms +step:1058/1660 train_time:97580ms step_avg:92.23ms +step:1059/1660 train_time:97673ms step_avg:92.23ms +step:1060/1660 train_time:97765ms step_avg:92.23ms +step:1061/1660 train_time:97858ms step_avg:92.23ms +step:1062/1660 train_time:97951ms step_avg:92.23ms +step:1063/1660 train_time:98043ms step_avg:92.23ms +step:1064/1660 train_time:98135ms step_avg:92.23ms +step:1065/1660 train_time:98227ms step_avg:92.23ms +step:1066/1660 train_time:98321ms step_avg:92.23ms +step:1067/1660 train_time:98414ms step_avg:92.23ms +step:1068/1660 train_time:98507ms step_avg:92.23ms +step:1069/1660 train_time:98602ms step_avg:92.24ms +step:1070/1660 train_time:98696ms step_avg:92.24ms +step:1071/1660 train_time:98787ms step_avg:92.24ms +step:1072/1660 train_time:98881ms step_avg:92.24ms +step:1073/1660 train_time:98973ms step_avg:92.24ms +step:1074/1660 train_time:99066ms step_avg:92.24ms +step:1075/1660 train_time:99159ms step_avg:92.24ms +step:1076/1660 train_time:99252ms step_avg:92.24ms +step:1077/1660 train_time:99345ms step_avg:92.24ms +step:1078/1660 train_time:99437ms step_avg:92.24ms +step:1079/1660 train_time:99530ms step_avg:92.24ms +step:1080/1660 train_time:99623ms step_avg:92.24ms +step:1081/1660 train_time:99716ms step_avg:92.24ms +step:1082/1660 train_time:99808ms step_avg:92.24ms +step:1083/1660 train_time:99901ms step_avg:92.24ms +step:1084/1660 train_time:99994ms step_avg:92.25ms +step:1085/1660 
train_time:100087ms step_avg:92.25ms +step:1086/1660 train_time:100180ms step_avg:92.25ms +step:1087/1660 train_time:100273ms step_avg:92.25ms +step:1088/1660 train_time:100365ms step_avg:92.25ms +step:1089/1660 train_time:100459ms step_avg:92.25ms +step:1090/1660 train_time:100552ms step_avg:92.25ms +step:1091/1660 train_time:100644ms step_avg:92.25ms +step:1092/1660 train_time:100738ms step_avg:92.25ms +step:1093/1660 train_time:100829ms step_avg:92.25ms +step:1094/1660 train_time:100924ms step_avg:92.25ms +step:1095/1660 train_time:101016ms step_avg:92.25ms +step:1096/1660 train_time:101107ms step_avg:92.25ms +step:1097/1660 train_time:101201ms step_avg:92.25ms +step:1098/1660 train_time:101294ms step_avg:92.25ms +step:1099/1660 train_time:101386ms step_avg:92.25ms +step:1100/1660 train_time:101480ms step_avg:92.25ms +step:1101/1660 train_time:101573ms step_avg:92.25ms +step:1102/1660 train_time:101665ms step_avg:92.26ms +step:1103/1660 train_time:101758ms step_avg:92.26ms +step:1104/1660 train_time:101850ms step_avg:92.26ms +step:1105/1660 train_time:101943ms step_avg:92.26ms +step:1106/1660 train_time:102035ms step_avg:92.26ms +step:1107/1660 train_time:102128ms step_avg:92.26ms +step:1108/1660 train_time:102222ms step_avg:92.26ms +step:1109/1660 train_time:102315ms step_avg:92.26ms +step:1110/1660 train_time:102408ms step_avg:92.26ms +step:1111/1660 train_time:102502ms step_avg:92.26ms +step:1112/1660 train_time:102595ms step_avg:92.26ms +step:1113/1660 train_time:102688ms step_avg:92.26ms +step:1114/1660 train_time:102782ms step_avg:92.26ms +step:1115/1660 train_time:102876ms step_avg:92.27ms +step:1116/1660 train_time:102970ms step_avg:92.27ms +step:1117/1660 train_time:103064ms step_avg:92.27ms +step:1118/1660 train_time:103159ms step_avg:92.27ms +step:1119/1660 train_time:103253ms step_avg:92.27ms +step:1120/1660 train_time:103346ms step_avg:92.27ms +step:1121/1660 train_time:103440ms step_avg:92.28ms +step:1122/1660 train_time:103534ms step_avg:92.28ms +step:1123/1660 train_time:103627ms step_avg:92.28ms +step:1124/1660 train_time:103720ms step_avg:92.28ms +step:1125/1660 train_time:103814ms step_avg:92.28ms +step:1125/1660 val_loss:3.4156 train_time:103909ms step_avg:92.36ms +step:1126/1660 train_time:103928ms step_avg:92.30ms +step:1127/1660 train_time:104012ms step_avg:92.29ms +step:1128/1660 train_time:104110ms step_avg:92.30ms +step:1129/1660 train_time:104204ms step_avg:92.30ms +step:1130/1660 train_time:104296ms step_avg:92.30ms +step:1131/1660 train_time:104388ms step_avg:92.30ms +step:1132/1660 train_time:104481ms step_avg:92.30ms +step:1133/1660 train_time:104574ms step_avg:92.30ms +step:1134/1660 train_time:104666ms step_avg:92.30ms +step:1135/1660 train_time:104758ms step_avg:92.30ms +step:1136/1660 train_time:104852ms step_avg:92.30ms +step:1137/1660 train_time:104947ms step_avg:92.30ms +step:1138/1660 train_time:105045ms step_avg:92.31ms +step:1139/1660 train_time:105142ms step_avg:92.31ms +step:1140/1660 train_time:105235ms step_avg:92.31ms +step:1141/1660 train_time:105327ms step_avg:92.31ms +step:1142/1660 train_time:105419ms step_avg:92.31ms +step:1143/1660 train_time:105512ms step_avg:92.31ms +step:1144/1660 train_time:105605ms step_avg:92.31ms +step:1145/1660 train_time:105697ms step_avg:92.31ms +step:1146/1660 train_time:105791ms step_avg:92.31ms +step:1147/1660 train_time:105885ms step_avg:92.31ms +step:1148/1660 train_time:105981ms step_avg:92.32ms +step:1149/1660 train_time:106077ms step_avg:92.32ms +step:1150/1660 train_time:106170ms step_avg:92.32ms 
+step:1151/1660 train_time:106264ms step_avg:92.32ms +step:1152/1660 train_time:106357ms step_avg:92.32ms +step:1153/1660 train_time:106450ms step_avg:92.32ms +step:1154/1660 train_time:106543ms step_avg:92.32ms +step:1155/1660 train_time:106636ms step_avg:92.33ms +step:1156/1660 train_time:106728ms step_avg:92.33ms +step:1157/1660 train_time:106822ms step_avg:92.33ms +step:1158/1660 train_time:106917ms step_avg:92.33ms +step:1159/1660 train_time:107013ms step_avg:92.33ms +step:1160/1660 train_time:107107ms step_avg:92.33ms +step:1161/1660 train_time:107203ms step_avg:92.34ms +step:1162/1660 train_time:107297ms step_avg:92.34ms +step:1163/1660 train_time:107391ms step_avg:92.34ms +step:1164/1660 train_time:107484ms step_avg:92.34ms +step:1165/1660 train_time:107576ms step_avg:92.34ms +step:1166/1660 train_time:107669ms step_avg:92.34ms +step:1167/1660 train_time:107762ms step_avg:92.34ms +step:1168/1660 train_time:107855ms step_avg:92.34ms +step:1169/1660 train_time:107949ms step_avg:92.34ms +step:1170/1660 train_time:108043ms step_avg:92.34ms +step:1171/1660 train_time:108137ms step_avg:92.35ms +step:1172/1660 train_time:108231ms step_avg:92.35ms +step:1173/1660 train_time:108324ms step_avg:92.35ms +step:1174/1660 train_time:108419ms step_avg:92.35ms +step:1175/1660 train_time:108513ms step_avg:92.35ms +step:1176/1660 train_time:108605ms step_avg:92.35ms +step:1177/1660 train_time:108698ms step_avg:92.35ms +step:1178/1660 train_time:108791ms step_avg:92.35ms +step:1179/1660 train_time:108884ms step_avg:92.35ms +step:1180/1660 train_time:108980ms step_avg:92.36ms +step:1181/1660 train_time:109074ms step_avg:92.36ms +step:1182/1660 train_time:109167ms step_avg:92.36ms +step:1183/1660 train_time:109261ms step_avg:92.36ms +step:1184/1660 train_time:109355ms step_avg:92.36ms +step:1185/1660 train_time:109448ms step_avg:92.36ms +step:1186/1660 train_time:109541ms step_avg:92.36ms +step:1187/1660 train_time:109634ms step_avg:92.36ms +step:1188/1660 train_time:109726ms step_avg:92.36ms +step:1189/1660 train_time:109820ms step_avg:92.36ms +step:1190/1660 train_time:109914ms step_avg:92.36ms +step:1191/1660 train_time:110008ms step_avg:92.37ms +step:1192/1660 train_time:110101ms step_avg:92.37ms +step:1193/1660 train_time:110194ms step_avg:92.37ms +step:1194/1660 train_time:110287ms step_avg:92.37ms +step:1195/1660 train_time:110380ms step_avg:92.37ms +step:1196/1660 train_time:110474ms step_avg:92.37ms +step:1197/1660 train_time:110566ms step_avg:92.37ms +step:1198/1660 train_time:110659ms step_avg:92.37ms +step:1199/1660 train_time:110753ms step_avg:92.37ms +step:1200/1660 train_time:110846ms step_avg:92.37ms +step:1201/1660 train_time:110940ms step_avg:92.37ms +step:1202/1660 train_time:111034ms step_avg:92.37ms +step:1203/1660 train_time:111127ms step_avg:92.37ms +step:1204/1660 train_time:111221ms step_avg:92.38ms +step:1205/1660 train_time:111315ms step_avg:92.38ms +step:1206/1660 train_time:111409ms step_avg:92.38ms +step:1207/1660 train_time:111502ms step_avg:92.38ms +step:1208/1660 train_time:111594ms step_avg:92.38ms +step:1209/1660 train_time:111688ms step_avg:92.38ms +step:1210/1660 train_time:111782ms step_avg:92.38ms +step:1211/1660 train_time:111875ms step_avg:92.38ms +step:1212/1660 train_time:111968ms step_avg:92.38ms +step:1213/1660 train_time:112061ms step_avg:92.38ms +step:1214/1660 train_time:112155ms step_avg:92.38ms +step:1215/1660 train_time:112248ms step_avg:92.39ms +step:1216/1660 train_time:112342ms step_avg:92.39ms +step:1217/1660 train_time:112436ms step_avg:92.39ms 
+step:1218/1660 train_time:112529ms step_avg:92.39ms +step:1219/1660 train_time:112623ms step_avg:92.39ms +step:1220/1660 train_time:112715ms step_avg:92.39ms +step:1221/1660 train_time:112808ms step_avg:92.39ms +step:1222/1660 train_time:112902ms step_avg:92.39ms +step:1223/1660 train_time:112995ms step_avg:92.39ms +step:1224/1660 train_time:113088ms step_avg:92.39ms +step:1225/1660 train_time:113182ms step_avg:92.39ms +step:1226/1660 train_time:113276ms step_avg:92.39ms +step:1227/1660 train_time:113370ms step_avg:92.40ms +step:1228/1660 train_time:113463ms step_avg:92.40ms +step:1229/1660 train_time:113556ms step_avg:92.40ms +step:1230/1660 train_time:113650ms step_avg:92.40ms +step:1231/1660 train_time:113744ms step_avg:92.40ms +step:1232/1660 train_time:113837ms step_avg:92.40ms +step:1233/1660 train_time:113931ms step_avg:92.40ms +step:1234/1660 train_time:114024ms step_avg:92.40ms +step:1235/1660 train_time:114118ms step_avg:92.40ms +step:1236/1660 train_time:114211ms step_avg:92.40ms +step:1237/1660 train_time:114305ms step_avg:92.40ms +step:1238/1660 train_time:114398ms step_avg:92.41ms +step:1239/1660 train_time:114491ms step_avg:92.41ms +step:1240/1660 train_time:114584ms step_avg:92.41ms +step:1241/1660 train_time:114678ms step_avg:92.41ms +step:1242/1660 train_time:114771ms step_avg:92.41ms +step:1243/1660 train_time:114863ms step_avg:92.41ms +step:1244/1660 train_time:114957ms step_avg:92.41ms +step:1245/1660 train_time:115051ms step_avg:92.41ms +step:1246/1660 train_time:115145ms step_avg:92.41ms +step:1247/1660 train_time:115238ms step_avg:92.41ms +step:1248/1660 train_time:115332ms step_avg:92.41ms +step:1249/1660 train_time:115424ms step_avg:92.41ms +step:1250/1660 train_time:115519ms step_avg:92.42ms +step:1250/1660 val_loss:3.3764 train_time:115614ms step_avg:92.49ms +step:1251/1660 train_time:115634ms step_avg:92.43ms +step:1252/1660 train_time:115712ms step_avg:92.42ms +step:1253/1660 train_time:115810ms step_avg:92.43ms +step:1254/1660 train_time:115904ms step_avg:92.43ms +step:1255/1660 train_time:115996ms step_avg:92.43ms +step:1256/1660 train_time:116088ms step_avg:92.43ms +step:1257/1660 train_time:116181ms step_avg:92.43ms +step:1258/1660 train_time:116273ms step_avg:92.43ms +step:1259/1660 train_time:116366ms step_avg:92.43ms +step:1260/1660 train_time:116458ms step_avg:92.43ms +step:1261/1660 train_time:116552ms step_avg:92.43ms +step:1262/1660 train_time:116650ms step_avg:92.43ms +step:1263/1660 train_time:116747ms step_avg:92.44ms +step:1264/1660 train_time:116842ms step_avg:92.44ms +step:1265/1660 train_time:116936ms step_avg:92.44ms +step:1266/1660 train_time:117029ms step_avg:92.44ms +step:1267/1660 train_time:117122ms step_avg:92.44ms +step:1268/1660 train_time:117214ms step_avg:92.44ms +step:1269/1660 train_time:117307ms step_avg:92.44ms +step:1270/1660 train_time:117399ms step_avg:92.44ms +step:1271/1660 train_time:117492ms step_avg:92.44ms +step:1272/1660 train_time:117586ms step_avg:92.44ms +step:1273/1660 train_time:117682ms step_avg:92.44ms +step:1274/1660 train_time:117776ms step_avg:92.45ms +step:1275/1660 train_time:117870ms step_avg:92.45ms +step:1276/1660 train_time:117963ms step_avg:92.45ms +step:1277/1660 train_time:118056ms step_avg:92.45ms +step:1278/1660 train_time:118150ms step_avg:92.45ms +step:1279/1660 train_time:118243ms step_avg:92.45ms +step:1280/1660 train_time:118335ms step_avg:92.45ms +step:1281/1660 train_time:118427ms step_avg:92.45ms +step:1282/1660 train_time:118521ms step_avg:92.45ms +step:1283/1660 train_time:118614ms 
step_avg:92.45ms +step:1284/1660 train_time:118710ms step_avg:92.45ms +step:1285/1660 train_time:118806ms step_avg:92.46ms +step:1286/1660 train_time:118900ms step_avg:92.46ms +step:1287/1660 train_time:118993ms step_avg:92.46ms +step:1288/1660 train_time:119087ms step_avg:92.46ms +step:1289/1660 train_time:119181ms step_avg:92.46ms +step:1290/1660 train_time:119273ms step_avg:92.46ms +step:1291/1660 train_time:119367ms step_avg:92.46ms +step:1292/1660 train_time:119459ms step_avg:92.46ms +step:1293/1660 train_time:119552ms step_avg:92.46ms +step:1294/1660 train_time:119647ms step_avg:92.46ms +step:1295/1660 train_time:119741ms step_avg:92.46ms +step:1296/1660 train_time:119835ms step_avg:92.46ms +step:1297/1660 train_time:119929ms step_avg:92.47ms +step:1298/1660 train_time:120022ms step_avg:92.47ms +step:1299/1660 train_time:120115ms step_avg:92.47ms +step:1300/1660 train_time:120209ms step_avg:92.47ms +step:1301/1660 train_time:120302ms step_avg:92.47ms +step:1302/1660 train_time:120394ms step_avg:92.47ms +step:1303/1660 train_time:120488ms step_avg:92.47ms +step:1304/1660 train_time:120580ms step_avg:92.47ms +step:1305/1660 train_time:120674ms step_avg:92.47ms +step:1306/1660 train_time:120769ms step_avg:92.47ms +step:1307/1660 train_time:120862ms step_avg:92.47ms +step:1308/1660 train_time:120955ms step_avg:92.47ms +step:1309/1660 train_time:121050ms step_avg:92.48ms +step:1310/1660 train_time:121145ms step_avg:92.48ms +step:1311/1660 train_time:121238ms step_avg:92.48ms +step:1312/1660 train_time:121331ms step_avg:92.48ms +step:1313/1660 train_time:121424ms step_avg:92.48ms +step:1314/1660 train_time:121518ms step_avg:92.48ms +step:1315/1660 train_time:121611ms step_avg:92.48ms +step:1316/1660 train_time:121705ms step_avg:92.48ms +step:1317/1660 train_time:121799ms step_avg:92.48ms +step:1318/1660 train_time:121892ms step_avg:92.48ms +step:1319/1660 train_time:121986ms step_avg:92.48ms +step:1320/1660 train_time:122080ms step_avg:92.48ms +step:1321/1660 train_time:122175ms step_avg:92.49ms +step:1322/1660 train_time:122269ms step_avg:92.49ms +step:1323/1660 train_time:122362ms step_avg:92.49ms +step:1324/1660 train_time:122455ms step_avg:92.49ms +step:1325/1660 train_time:122549ms step_avg:92.49ms +step:1326/1660 train_time:122643ms step_avg:92.49ms +step:1327/1660 train_time:122736ms step_avg:92.49ms +step:1328/1660 train_time:122830ms step_avg:92.49ms +step:1329/1660 train_time:122924ms step_avg:92.49ms +step:1330/1660 train_time:123018ms step_avg:92.49ms +step:1331/1660 train_time:123112ms step_avg:92.50ms +step:1332/1660 train_time:123207ms step_avg:92.50ms +step:1333/1660 train_time:123301ms step_avg:92.50ms +step:1334/1660 train_time:123393ms step_avg:92.50ms +step:1335/1660 train_time:123487ms step_avg:92.50ms +step:1336/1660 train_time:123581ms step_avg:92.50ms +step:1337/1660 train_time:123674ms step_avg:92.50ms +step:1338/1660 train_time:123768ms step_avg:92.50ms +step:1339/1660 train_time:123861ms step_avg:92.50ms +step:1340/1660 train_time:123954ms step_avg:92.50ms +step:1341/1660 train_time:124049ms step_avg:92.50ms +step:1342/1660 train_time:124142ms step_avg:92.51ms +step:1343/1660 train_time:124235ms step_avg:92.51ms +step:1344/1660 train_time:124330ms step_avg:92.51ms +step:1345/1660 train_time:124423ms step_avg:92.51ms +step:1346/1660 train_time:124516ms step_avg:92.51ms +step:1347/1660 train_time:124610ms step_avg:92.51ms +step:1348/1660 train_time:124704ms step_avg:92.51ms +step:1349/1660 train_time:124797ms step_avg:92.51ms +step:1350/1660 train_time:124890ms 
step_avg:92.51ms +step:1351/1660 train_time:124983ms step_avg:92.51ms +step:1352/1660 train_time:125076ms step_avg:92.51ms +step:1353/1660 train_time:125170ms step_avg:92.51ms +step:1354/1660 train_time:125263ms step_avg:92.51ms +step:1355/1660 train_time:125356ms step_avg:92.51ms +step:1356/1660 train_time:125450ms step_avg:92.51ms +step:1357/1660 train_time:125545ms step_avg:92.52ms +step:1358/1660 train_time:125638ms step_avg:92.52ms +step:1359/1660 train_time:125731ms step_avg:92.52ms +step:1360/1660 train_time:125825ms step_avg:92.52ms +step:1361/1660 train_time:125918ms step_avg:92.52ms +step:1362/1660 train_time:126012ms step_avg:92.52ms +step:1363/1660 train_time:126105ms step_avg:92.52ms +step:1364/1660 train_time:126199ms step_avg:92.52ms +step:1365/1660 train_time:126291ms step_avg:92.52ms +step:1366/1660 train_time:126385ms step_avg:92.52ms +step:1367/1660 train_time:126478ms step_avg:92.52ms +step:1368/1660 train_time:126572ms step_avg:92.52ms +step:1369/1660 train_time:126667ms step_avg:92.53ms +step:1370/1660 train_time:126759ms step_avg:92.53ms +step:1371/1660 train_time:126852ms step_avg:92.53ms +step:1372/1660 train_time:126946ms step_avg:92.53ms +step:1373/1660 train_time:127041ms step_avg:92.53ms +step:1374/1660 train_time:127134ms step_avg:92.53ms +step:1375/1660 train_time:127227ms step_avg:92.53ms +step:1375/1660 val_loss:3.3418 train_time:127322ms step_avg:92.60ms +step:1376/1660 train_time:127341ms step_avg:92.54ms +step:1377/1660 train_time:127420ms step_avg:92.53ms +step:1378/1660 train_time:127520ms step_avg:92.54ms +step:1379/1660 train_time:127614ms step_avg:92.54ms +step:1380/1660 train_time:127706ms step_avg:92.54ms +step:1381/1660 train_time:127799ms step_avg:92.54ms +step:1382/1660 train_time:127892ms step_avg:92.54ms +step:1383/1660 train_time:127984ms step_avg:92.54ms +step:1384/1660 train_time:128077ms step_avg:92.54ms +step:1385/1660 train_time:128170ms step_avg:92.54ms +step:1386/1660 train_time:128262ms step_avg:92.54ms +step:1387/1660 train_time:128357ms step_avg:92.54ms +step:1388/1660 train_time:128453ms step_avg:92.55ms +step:1389/1660 train_time:128548ms step_avg:92.55ms +step:1390/1660 train_time:128643ms step_avg:92.55ms +step:1391/1660 train_time:128736ms step_avg:92.55ms +step:1392/1660 train_time:128828ms step_avg:92.55ms +step:1393/1660 train_time:128920ms step_avg:92.55ms +step:1394/1660 train_time:129013ms step_avg:92.55ms +step:1395/1660 train_time:129105ms step_avg:92.55ms +step:1396/1660 train_time:129199ms step_avg:92.55ms +step:1397/1660 train_time:129292ms step_avg:92.55ms +step:1398/1660 train_time:129386ms step_avg:92.55ms +step:1399/1660 train_time:129481ms step_avg:92.55ms +step:1400/1660 train_time:129577ms step_avg:92.55ms +step:1401/1660 train_time:129671ms step_avg:92.56ms +step:1402/1660 train_time:129764ms step_avg:92.56ms +step:1403/1660 train_time:129858ms step_avg:92.56ms +step:1404/1660 train_time:129952ms step_avg:92.56ms +step:1405/1660 train_time:130044ms step_avg:92.56ms +step:1406/1660 train_time:130137ms step_avg:92.56ms +step:1407/1660 train_time:130229ms step_avg:92.56ms +step:1408/1660 train_time:130322ms step_avg:92.56ms +step:1409/1660 train_time:130417ms step_avg:92.56ms +step:1410/1660 train_time:130512ms step_avg:92.56ms +step:1411/1660 train_time:130605ms step_avg:92.56ms +step:1412/1660 train_time:130701ms step_avg:92.56ms +step:1413/1660 train_time:130794ms step_avg:92.57ms +step:1414/1660 train_time:130887ms step_avg:92.57ms +step:1415/1660 train_time:130980ms step_avg:92.57ms +step:1416/1660 
train_time:131073ms step_avg:92.57ms +step:1417/1660 train_time:131166ms step_avg:92.57ms +step:1418/1660 train_time:131260ms step_avg:92.57ms +step:1419/1660 train_time:131353ms step_avg:92.57ms +step:1420/1660 train_time:131447ms step_avg:92.57ms +step:1421/1660 train_time:131541ms step_avg:92.57ms +step:1422/1660 train_time:131635ms step_avg:92.57ms +step:1423/1660 train_time:131729ms step_avg:92.57ms +step:1424/1660 train_time:131822ms step_avg:92.57ms +step:1425/1660 train_time:131915ms step_avg:92.57ms +step:1426/1660 train_time:132008ms step_avg:92.57ms +step:1427/1660 train_time:132101ms step_avg:92.57ms +step:1428/1660 train_time:132194ms step_avg:92.57ms +step:1429/1660 train_time:132287ms step_avg:92.57ms +step:1430/1660 train_time:132381ms step_avg:92.57ms +step:1431/1660 train_time:132475ms step_avg:92.58ms +step:1432/1660 train_time:132569ms step_avg:92.58ms +step:1433/1660 train_time:132662ms step_avg:92.58ms +step:1434/1660 train_time:132756ms step_avg:92.58ms +step:1435/1660 train_time:132850ms step_avg:92.58ms +step:1436/1660 train_time:132943ms step_avg:92.58ms +step:1437/1660 train_time:133039ms step_avg:92.58ms +step:1438/1660 train_time:133133ms step_avg:92.58ms +step:1439/1660 train_time:133225ms step_avg:92.58ms +step:1440/1660 train_time:133318ms step_avg:92.58ms +step:1441/1660 train_time:133412ms step_avg:92.58ms +step:1442/1660 train_time:133505ms step_avg:92.58ms +step:1443/1660 train_time:133598ms step_avg:92.58ms +step:1444/1660 train_time:133692ms step_avg:92.58ms +step:1445/1660 train_time:133786ms step_avg:92.59ms +step:1446/1660 train_time:133882ms step_avg:92.59ms +step:1447/1660 train_time:133975ms step_avg:92.59ms +step:1448/1660 train_time:134067ms step_avg:92.59ms +step:1449/1660 train_time:134161ms step_avg:92.59ms +step:1450/1660 train_time:134255ms step_avg:92.59ms +step:1451/1660 train_time:134349ms step_avg:92.59ms +step:1452/1660 train_time:134443ms step_avg:92.59ms +step:1453/1660 train_time:134536ms step_avg:92.59ms +step:1454/1660 train_time:134630ms step_avg:92.59ms +step:1455/1660 train_time:134723ms step_avg:92.59ms +step:1456/1660 train_time:134817ms step_avg:92.59ms +step:1457/1660 train_time:134911ms step_avg:92.60ms +step:1458/1660 train_time:135004ms step_avg:92.60ms +step:1459/1660 train_time:135100ms step_avg:92.60ms +step:1460/1660 train_time:135193ms step_avg:92.60ms +step:1461/1660 train_time:135286ms step_avg:92.60ms +step:1462/1660 train_time:135379ms step_avg:92.60ms +step:1463/1660 train_time:135473ms step_avg:92.60ms +step:1464/1660 train_time:135566ms step_avg:92.60ms +step:1465/1660 train_time:135660ms step_avg:92.60ms +step:1466/1660 train_time:135753ms step_avg:92.60ms +step:1467/1660 train_time:135846ms step_avg:92.60ms +step:1468/1660 train_time:135940ms step_avg:92.60ms +step:1469/1660 train_time:136034ms step_avg:92.60ms +step:1470/1660 train_time:136126ms step_avg:92.60ms +step:1471/1660 train_time:136219ms step_avg:92.60ms +step:1472/1660 train_time:136314ms step_avg:92.60ms +step:1473/1660 train_time:136407ms step_avg:92.60ms +step:1474/1660 train_time:136500ms step_avg:92.60ms +step:1475/1660 train_time:136593ms step_avg:92.61ms +step:1476/1660 train_time:136686ms step_avg:92.61ms +step:1477/1660 train_time:136781ms step_avg:92.61ms +step:1478/1660 train_time:136874ms step_avg:92.61ms +step:1479/1660 train_time:136967ms step_avg:92.61ms +step:1480/1660 train_time:137061ms step_avg:92.61ms +step:1481/1660 train_time:137155ms step_avg:92.61ms +step:1482/1660 train_time:137248ms step_avg:92.61ms +step:1483/1660 
train_time:137342ms step_avg:92.61ms +step:1484/1660 train_time:137436ms step_avg:92.61ms +step:1485/1660 train_time:137529ms step_avg:92.61ms +step:1486/1660 train_time:137621ms step_avg:92.61ms +step:1487/1660 train_time:137714ms step_avg:92.61ms +step:1488/1660 train_time:137808ms step_avg:92.61ms +step:1489/1660 train_time:137902ms step_avg:92.61ms +step:1490/1660 train_time:137996ms step_avg:92.61ms +step:1491/1660 train_time:138088ms step_avg:92.61ms +step:1492/1660 train_time:138182ms step_avg:92.62ms +step:1493/1660 train_time:138277ms step_avg:92.62ms +step:1494/1660 train_time:138371ms step_avg:92.62ms +step:1495/1660 train_time:138465ms step_avg:92.62ms +step:1496/1660 train_time:138559ms step_avg:92.62ms +step:1497/1660 train_time:138653ms step_avg:92.62ms +step:1498/1660 train_time:138746ms step_avg:92.62ms +step:1499/1660 train_time:138841ms step_avg:92.62ms +step:1500/1660 train_time:138935ms step_avg:92.62ms +step:1500/1660 val_loss:3.3118 train_time:139028ms step_avg:92.69ms +step:1501/1660 train_time:139047ms step_avg:92.64ms +step:1502/1660 train_time:139124ms step_avg:92.63ms +step:1503/1660 train_time:139221ms step_avg:92.63ms +step:1504/1660 train_time:139314ms step_avg:92.63ms +step:1505/1660 train_time:139407ms step_avg:92.63ms +step:1506/1660 train_time:139499ms step_avg:92.63ms +step:1507/1660 train_time:139590ms step_avg:92.63ms +step:1508/1660 train_time:139685ms step_avg:92.63ms +step:1509/1660 train_time:139778ms step_avg:92.63ms +step:1510/1660 train_time:139871ms step_avg:92.63ms +step:1511/1660 train_time:139965ms step_avg:92.63ms +step:1512/1660 train_time:140060ms step_avg:92.63ms +step:1513/1660 train_time:140154ms step_avg:92.63ms +step:1514/1660 train_time:140249ms step_avg:92.63ms +step:1515/1660 train_time:140342ms step_avg:92.64ms +step:1516/1660 train_time:140435ms step_avg:92.64ms +step:1517/1660 train_time:140528ms step_avg:92.64ms +step:1518/1660 train_time:140621ms step_avg:92.64ms +step:1519/1660 train_time:140714ms step_avg:92.64ms +step:1520/1660 train_time:140808ms step_avg:92.64ms +step:1521/1660 train_time:140902ms step_avg:92.64ms +step:1522/1660 train_time:140995ms step_avg:92.64ms +step:1523/1660 train_time:141089ms step_avg:92.64ms +step:1524/1660 train_time:141183ms step_avg:92.64ms +step:1525/1660 train_time:141277ms step_avg:92.64ms +step:1526/1660 train_time:141371ms step_avg:92.64ms +step:1527/1660 train_time:141465ms step_avg:92.64ms +step:1528/1660 train_time:141558ms step_avg:92.64ms +step:1529/1660 train_time:141650ms step_avg:92.64ms +step:1530/1660 train_time:141743ms step_avg:92.64ms +step:1531/1660 train_time:141835ms step_avg:92.64ms +step:1532/1660 train_time:141929ms step_avg:92.64ms +step:1533/1660 train_time:142023ms step_avg:92.64ms +step:1534/1660 train_time:142117ms step_avg:92.64ms +step:1535/1660 train_time:142211ms step_avg:92.65ms +step:1536/1660 train_time:142307ms step_avg:92.65ms +step:1537/1660 train_time:142401ms step_avg:92.65ms +step:1538/1660 train_time:142495ms step_avg:92.65ms +step:1539/1660 train_time:142588ms step_avg:92.65ms +step:1540/1660 train_time:142681ms step_avg:92.65ms +step:1541/1660 train_time:142774ms step_avg:92.65ms +step:1542/1660 train_time:142867ms step_avg:92.65ms +step:1543/1660 train_time:142961ms step_avg:92.65ms +step:1544/1660 train_time:143054ms step_avg:92.65ms +step:1545/1660 train_time:143148ms step_avg:92.65ms +step:1546/1660 train_time:143242ms step_avg:92.65ms +step:1547/1660 train_time:143336ms step_avg:92.65ms +step:1548/1660 train_time:143430ms step_avg:92.65ms 
+step:1549/1660 train_time:143523ms step_avg:92.66ms +step:1550/1660 train_time:143617ms step_avg:92.66ms +step:1551/1660 train_time:143709ms step_avg:92.66ms +step:1552/1660 train_time:143802ms step_avg:92.66ms +step:1553/1660 train_time:143894ms step_avg:92.66ms +step:1554/1660 train_time:143988ms step_avg:92.66ms +step:1555/1660 train_time:144083ms step_avg:92.66ms +step:1556/1660 train_time:144179ms step_avg:92.66ms +step:1557/1660 train_time:144272ms step_avg:92.66ms +step:1558/1660 train_time:144366ms step_avg:92.66ms +step:1559/1660 train_time:144460ms step_avg:92.66ms +step:1560/1660 train_time:144553ms step_avg:92.66ms +step:1561/1660 train_time:144647ms step_avg:92.66ms +step:1562/1660 train_time:144741ms step_avg:92.66ms +step:1563/1660 train_time:144834ms step_avg:92.66ms +step:1564/1660 train_time:144927ms step_avg:92.66ms +step:1565/1660 train_time:145020ms step_avg:92.66ms +step:1566/1660 train_time:145114ms step_avg:92.67ms +step:1567/1660 train_time:145209ms step_avg:92.67ms +step:1568/1660 train_time:145304ms step_avg:92.67ms +step:1569/1660 train_time:145398ms step_avg:92.67ms +step:1570/1660 train_time:145492ms step_avg:92.67ms +step:1571/1660 train_time:145586ms step_avg:92.67ms +step:1572/1660 train_time:145679ms step_avg:92.67ms +step:1573/1660 train_time:145773ms step_avg:92.67ms +step:1574/1660 train_time:145866ms step_avg:92.67ms +step:1575/1660 train_time:145959ms step_avg:92.67ms +step:1576/1660 train_time:146051ms step_avg:92.67ms +step:1577/1660 train_time:146145ms step_avg:92.67ms +step:1578/1660 train_time:146239ms step_avg:92.67ms +step:1579/1660 train_time:146332ms step_avg:92.67ms +step:1580/1660 train_time:146426ms step_avg:92.67ms +step:1581/1660 train_time:146520ms step_avg:92.68ms +step:1582/1660 train_time:146612ms step_avg:92.68ms +step:1583/1660 train_time:146706ms step_avg:92.68ms +step:1584/1660 train_time:146800ms step_avg:92.68ms +step:1585/1660 train_time:146892ms step_avg:92.68ms +step:1586/1660 train_time:146986ms step_avg:92.68ms +step:1587/1660 train_time:147080ms step_avg:92.68ms +step:1588/1660 train_time:147174ms step_avg:92.68ms +step:1589/1660 train_time:147268ms step_avg:92.68ms +step:1590/1660 train_time:147362ms step_avg:92.68ms +step:1591/1660 train_time:147455ms step_avg:92.68ms +step:1592/1660 train_time:147548ms step_avg:92.68ms +step:1593/1660 train_time:147641ms step_avg:92.68ms +step:1594/1660 train_time:147735ms step_avg:92.68ms +step:1595/1660 train_time:147828ms step_avg:92.68ms +step:1596/1660 train_time:147921ms step_avg:92.68ms +step:1597/1660 train_time:148014ms step_avg:92.68ms +step:1598/1660 train_time:148108ms step_avg:92.68ms +step:1599/1660 train_time:148202ms step_avg:92.68ms +step:1600/1660 train_time:148295ms step_avg:92.68ms +step:1601/1660 train_time:148390ms step_avg:92.69ms +step:1602/1660 train_time:148484ms step_avg:92.69ms +step:1603/1660 train_time:148577ms step_avg:92.69ms +step:1604/1660 train_time:148671ms step_avg:92.69ms +step:1605/1660 train_time:148764ms step_avg:92.69ms +step:1606/1660 train_time:148858ms step_avg:92.69ms +step:1607/1660 train_time:148951ms step_avg:92.69ms +step:1608/1660 train_time:149044ms step_avg:92.69ms +step:1609/1660 train_time:149138ms step_avg:92.69ms +step:1610/1660 train_time:149232ms step_avg:92.69ms +step:1611/1660 train_time:149327ms step_avg:92.69ms +step:1612/1660 train_time:149420ms step_avg:92.69ms +step:1613/1660 train_time:149513ms step_avg:92.69ms +step:1614/1660 train_time:149608ms step_avg:92.69ms +step:1615/1660 train_time:149701ms step_avg:92.69ms 
+step:1616/1660 train_time:149794ms step_avg:92.69ms +step:1617/1660 train_time:149887ms step_avg:92.69ms +step:1618/1660 train_time:149981ms step_avg:92.70ms +step:1619/1660 train_time:150074ms step_avg:92.70ms +step:1620/1660 train_time:150168ms step_avg:92.70ms +step:1621/1660 train_time:150262ms step_avg:92.70ms +step:1622/1660 train_time:150355ms step_avg:92.70ms +step:1623/1660 train_time:150448ms step_avg:92.70ms +step:1624/1660 train_time:150542ms step_avg:92.70ms +step:1625/1660 train_time:150635ms step_avg:92.70ms +step:1625/1660 val_loss:3.2870 train_time:150730ms step_avg:92.76ms +step:1626/1660 train_time:150750ms step_avg:92.71ms +step:1627/1660 train_time:150829ms step_avg:92.70ms +step:1628/1660 train_time:150925ms step_avg:92.71ms +step:1629/1660 train_time:151018ms step_avg:92.71ms +step:1630/1660 train_time:151111ms step_avg:92.71ms +step:1631/1660 train_time:151205ms step_avg:92.71ms +step:1632/1660 train_time:151297ms step_avg:92.71ms +step:1633/1660 train_time:151390ms step_avg:92.71ms +step:1634/1660 train_time:151482ms step_avg:92.71ms +step:1635/1660 train_time:151575ms step_avg:92.71ms +step:1636/1660 train_time:151669ms step_avg:92.71ms +step:1637/1660 train_time:151765ms step_avg:92.71ms +step:1638/1660 train_time:151861ms step_avg:92.71ms +step:1639/1660 train_time:151955ms step_avg:92.71ms +step:1640/1660 train_time:152048ms step_avg:92.71ms +step:1641/1660 train_time:152142ms step_avg:92.71ms +step:1642/1660 train_time:152235ms step_avg:92.71ms +step:1643/1660 train_time:152328ms step_avg:92.71ms +step:1644/1660 train_time:152421ms step_avg:92.71ms +step:1645/1660 train_time:152513ms step_avg:92.71ms +step:1646/1660 train_time:152607ms step_avg:92.71ms +step:1647/1660 train_time:152701ms step_avg:92.71ms +step:1648/1660 train_time:152795ms step_avg:92.72ms +step:1649/1660 train_time:152890ms step_avg:92.72ms +step:1650/1660 train_time:152984ms step_avg:92.72ms +step:1651/1660 train_time:153077ms step_avg:92.72ms +step:1652/1660 train_time:153170ms step_avg:92.72ms +step:1653/1660 train_time:153263ms step_avg:92.72ms +step:1654/1660 train_time:153356ms step_avg:92.72ms +step:1655/1660 train_time:153449ms step_avg:92.72ms +step:1656/1660 train_time:153541ms step_avg:92.72ms +step:1657/1660 train_time:153636ms step_avg:92.72ms +step:1658/1660 train_time:153731ms step_avg:92.72ms +step:1659/1660 train_time:153825ms step_avg:92.72ms +step:1660/1660 train_time:153919ms step_avg:92.72ms +step:1660/1660 val_loss:3.2787 train_time:154014ms step_avg:92.78ms +peak memory allocated: 32002 MiB reserved: 46896 MiB diff --git a/records/091525_ThreadingFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt b/records/091525_ThreadingFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt new file mode 100644 index 000000000..6e7904b9d --- /dev/null +++ b/records/091525_ThreadingFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True 
# we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + 
ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) 
== M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
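+ # Communication pattern: for each same-shape parameter group, the stacked grads
+ # are reduce-scattered so every rank owns padded_num_params // world_size
+ # matrices, each rank runs Newton-Schulz only on its own shard, and the updated
+ # params are all-gathered back at the end. Illustrative example (numbers not
+ # from this run): 24 same-shape matrices on 8 GPUs -> each rank orthogonalizes
+ # a stack of 3.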
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ group_infos = []
+ for group in self.param_groups:
+ params: list[Tensor] = group["params"]
+ if not params:
+ continue
+
+ num_params = len(params)
+ padded_num_params = (
+ (num_params + world_size - 1) // world_size * world_size
+ )
+
+ grads_to_stack = [p.grad for p in params]
+ if padded_num_params > num_params:
+ padding_grad = torch.zeros_like(params[0].grad)
+ grads_to_stack.extend(
+ [padding_grad] * (padded_num_params - num_params)
+ )
+
+ stacked_grads = torch.stack(grads_to_stack)
+
+ chunk_size = padded_num_params // world_size
+ grad_chunk = torch.empty(
+ (chunk_size, *params[0].grad.shape),
+ dtype=stacked_grads.dtype,
+ device=stacked_grads.device,
+ )
+
+ reduce_future = dist.reduce_scatter_tensor(
+ grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+ ).get_future()
+
+ group_infos.append(
+ {
+ "params": params,
+ "grad_chunk": grad_chunk,
+ "reduce_future": reduce_future,
+ "chunk_size": chunk_size,
+ "padded_num_params": padded_num_params,
+ }
+ )
+
+ all_gather_infos = []
+ # Second pass: wait for gradients, compute updates for the local shard of parameters,
+ # and launch all async all_gather operations.
+ for group, info in zip(self.param_groups, group_infos):
+ info["reduce_future"].wait()
+
+ params = info["params"]
+ grad_chunk = info["grad_chunk"]
+ chunk_size = info["chunk_size"]
+ start_idx = rank * chunk_size
+
+ # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+ # This helps in vectorizing operations later.
+ p_example = params[0] # All params in a group have the same shape.
+ eff_lr_val = (
+ group["lr"]
+ * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+ * getattr(p_example, "lr_mul", 1.0)
+ )
+ eff_weight_decay_val = (
+ group["lr"]
+ * group["weight_decay"]
+ * getattr(p_example, "wd_mul", 1.0)
+ )
+
+ # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+ # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+ updated_param_chunk = torch.empty(
+ (chunk_size, *params[0].shape),
+ dtype=params[0].dtype,
+ device=params[0].device,
+ )
+
+ # List to collect update_grad tensors for batched zeropower computation.
+ update_grads_for_zeropower = []
+
+ # Process each parameter in this rank's chunk.
+ for i in range(chunk_size):
+ param_idx = start_idx + i
+
+ if param_idx >= len(params):
+ # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+ # These padded entries will not be used by other ranks in the all_gather, but
+ # initializing them prevents uninitialized memory access issues.
+ updated_param_chunk[i].zero_()
+ # Also append a zero tensor for zeropower input if it must be padded.
+ update_grads_for_zeropower.append(
+ torch.zeros_like(params[0].grad)
+ )
+ continue
+ p = params[param_idx]
+ grad = grad_chunk[
+ i
+ ] # This gradient corresponds to the current parameter p.
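+ # The two lerps below implement Nesterov-style momentum:
+ #   buf    <- momentum * buf + (1 - momentum) * grad   (momentum_buffer.lerp_)
+ #   update <- (1 - momentum) * grad + momentum * buf   (grad.lerp)
+ # i.e. the fresh gradient is blended back in before orthogonalization.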
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
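+ # Deferring these waits until every group's all_gather has been launched lets
+ # the gather for one group overlap with Newton-Schulz compute for the next.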
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def 
_load_data_shard(file: Path):
+ header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+ assert header[1] == 1, "unsupported version"
+ num_tokens = int(header[2]) # number of tokens (claimed)
+ with file.open("rb", buffering=0) as f:
+ tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+ f.seek(256 * 4)
+ nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+ assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+ return tokens
+
+# imports required at runtime by Yarn.apply (math), BOSFinder/DataPreloader
+# (threading), and CausalSelfAttention.forward (flash_attn_varlen_func);
+# assumed to have dropped out of the top-of-file import block in this record
+import math
+import threading
+from flash_attn import flash_attn_varlen_func
+
+BOS_ID = 50256
+
+class BOSFinder:
+ # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+ def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False):
+ # Precompute BOS positions once per shard
+ self.tokens = tokens
+ self.size = tokens.numel()
+ self.quickload = quickload
+ if quickload:
+ # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+ self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.thread = None
+ self.ready = threading.Event()
+ self.start()
+ else:
+ self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.i = 0
+ self.world_size = world_size
+ self.batch_iter = 0
+
+ def _load(self):
+ self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.ready.set()
+
+ def start(self):
+ self.ready.clear()
+ self.thread = threading.Thread(target=self._load)
+ self.thread.start()
+
+ def get(self):
+ if self.thread:
+ self.ready.wait()
+ self.thread.join()
+ self.bos_idx = self.bos_idx_async
+
+ def next_batch(self, num_tokens_local: int, max_seq_len: int):
+ # if quickload was used, repoint to the full dataset after 5 batches
+ if self.quickload and self.batch_iter == 5:
+ self.get()
+ n = len(self.bos_idx)
+ starts = [[] for _ in range(self.world_size)]
+ ends = [[] for _ in range(self.world_size)]
+
+ idx = self.i
+ for r in range(self.world_size):
+ cur_len = 0
+ while cur_len <= num_tokens_local:
+ if idx >= n:
+ raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+ cur = self.bos_idx[idx]
+ starts[r].append(cur)
+ end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+ cur + max_seq_len,
+ cur + num_tokens_local - cur_len + 1)
+ ends[r].append(end)
+ cur_len += end - cur
+ idx += 1
+
+ assert cur_len == num_tokens_local + 1
+ self.i = idx
+ self.batch_iter += 1
+ return starts, ends
+
+class DataPreloader:
+ def __init__(self, file_iter, world_size):
+ self.file_iter = file_iter
+ self.world_size = world_size
+ self.thread = None
+ self.data = None
+ self.ready = threading.Event()
+
+ def _load(self):
+ tokens = _load_data_shard(next(self.file_iter))
+ self.data = (tokens, BOSFinder(tokens, self.world_size))
+ self.ready.set()
+
+ def start(self):
+ self.ready.clear()
+ self.thread = threading.Thread(target=self._load)
+ self.thread.start()
+
+ def get(self):
+ if self.thread:
+ self.ready.wait()
+ self.thread.join()
+ return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+ # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
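+ # Batch layout (illustrative numbers): documents of lengths [5, 3, 4] with
+ # num_tokens_local = 12 yield _cum_lengths = [0, 5, 8, 12, 12, ...]; the model
+ # hands these to flash_attn_varlen_func as cu_seqlens, so attention never
+ # crosses document boundaries within the flat token stream.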
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+ num_tokens = num_tokens // grad_accum_steps
+
+ files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+ if not files:
+ raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+ file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+ tokens = _load_data_shard(next(file_iter))
+ if align_to_bos:
+ # scanning a whole shard for BOS positions up front is slow; quickload scans a
+ # prefix first and BOSFinder backfills the full index on a background thread
+ finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+ preloader = DataPreloader(file_iter, world_size)
+ preloader.start()
+ else:
+ pos = 0 # for unaligned case
+
+ while True:
+ num_tokens_local = num_tokens // world_size
+ max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+ if align_to_bos:
+ try:
+ seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+ start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+ except StopIteration:
+ # This shard is exhausted; swap in the preloaded next shard and its BOSFinder.
+ tokens, finder = preloader.get()
+ preloader.start()
+ continue
+
+ buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+ _inputs = buf[:-1]
+ _targets = buf[1:]
+ end_idxs[-1] -= 1 # trim the last document by one token to account for the _targets offset
+ cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+ else:
+ if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+
+ pos_local = pos + rank * num_tokens_local
+ buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+ _inputs = buf[:-1].view(num_tokens_local)
+ _targets = buf[1:].view(num_tokens_local)
+
+ cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+ pos += num_tokens
+
+ _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+ _cum_lengths[0] = 0
+ _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+ new_params = yield (
+ _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+ _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+ _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+ )
+
+ if new_params is not None:
+ # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+ new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+ assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+ num_tokens = new_num_tokens
+ max_seq_len = new_max_seq_len
+ grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+@dataclass
+class Hyperparameters:
+ # data
+ train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+ val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+ val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+ train_batch_size: int = 2048 * 24 * 8
+ train_max_seq_len: int = 128 * 16
+ val_batch_size: int = 4 * 64 * 1024 * 8
+ # optimization
+ num_iterations: int = 1660 # number of iterations to run
+ cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+ # evaluation and logging
+ run_id: str = f"data_threading_1660/{uuid.uuid4()}"
+ val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+ save_checkpoint: bool = False
+ # attention masking
+ block_size: int = 128
+ ws_schedule: tuple = (3, 7, 11)
+ ws_validate: int = 13 # increase final validation ws @classiclarryd
+ ws_validate_final_layer: int = 20 # the final layer is largely indifferent to window length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+ run_id = args.run_id
+ logfile = f"logs/{run_id}.txt"
+ os.makedirs(os.path.dirname(logfile), exist_ok=True) # run_id may contain a subdirectory
+ print(logfile)
+def print0(s, console=False):
+ if master_process:
+ with open(logfile, "a") as f:
+ if console:
+ print(s)
+ print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+ import subprocess # avoid top level import
+ return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+ vocab_size=50257,
+ num_layers=12,
+ num_heads=6,
+ head_dim=128,
+ model_dim=768,
+ max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+ if isinstance(m, (nn.Embedding, nn.Linear)):
+ m.bfloat16()
+for param in model.parameters():
+ dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + 
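+ # get_lr(step) returns a multiplier: 1.0 for the first (1 - cooldown_frac) of
+ # training, then linearly decaying to 0.1 at the final step; e.g. at
+ # step = 0.75 * num_iterations the multiplier is 0.5 * 1.0 + 0.5 * 0.1 = 0.55.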
group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:50:26 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 36C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 36C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type 
Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 195661 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 195662 C /usr/bin/python3 614MiB | +| 0 N/A N/A 195663 C /usr/bin/python3 614MiB | +| 0 N/A N/A 195664 C /usr/bin/python3 614MiB | +| 0 N/A N/A 195665 C /usr/bin/python3 614MiB | +| 0 N/A N/A 195666 C /usr/bin/python3 614MiB | +| 0 N/A N/A 195667 C /usr/bin/python3 614MiB | +| 0 N/A N/A 195668 C /usr/bin/python3 614MiB | +| 1 N/A N/A 195662 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 195663 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 195664 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 195665 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 195666 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 195667 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 195668 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:146ms step_avg:146.21ms +step:2/1660 train_time:167ms step_avg:83.59ms +step:3/1660 train_time:234ms step_avg:77.88ms +step:4/1660 train_time:324ms step_avg:80.90ms +step:5/1660 train_time:414ms step_avg:82.76ms +step:6/1660 train_time:505ms step_avg:84.13ms +step:7/1660 train_time:595ms step_avg:85.07ms +step:8/1660 train_time:686ms step_avg:85.75ms +step:9/1660 train_time:777ms step_avg:86.30ms +step:10/1660 train_time:867ms step_avg:86.72ms +step:11/1660 train_time:958ms step_avg:87.06ms +step:12/1660 train_time:1051ms step_avg:87.55ms +step:13/1660 train_time:1146ms step_avg:88.16ms +step:14/1660 train_time:1240ms step_avg:88.57ms +step:15/1660 train_time:1332ms step_avg:88.78ms +step:16/1660 train_time:1424ms step_avg:89.00ms +step:17/1660 train_time:1515ms step_avg:89.14ms +step:18/1660 train_time:1606ms step_avg:89.23ms +step:19/1660 train_time:1697ms step_avg:89.34ms +step:20/1660 train_time:1789ms step_avg:89.43ms +step:21/1660 train_time:1879ms step_avg:89.49ms +step:22/1660 train_time:1971ms step_avg:89.61ms +step:23/1660 train_time:2064ms step_avg:89.74ms +step:24/1660 train_time:2157ms step_avg:89.87ms +step:25/1660 train_time:2249ms step_avg:89.97ms +step:26/1660 train_time:2343ms step_avg:90.12ms +step:27/1660 train_time:2435ms step_avg:90.20ms +step:28/1660 train_time:2527ms step_avg:90.24ms +step:29/1660 train_time:2618ms step_avg:90.27ms +step:30/1660 train_time:2708ms step_avg:90.28ms +step:31/1660 train_time:2799ms step_avg:90.30ms +step:32/1660 train_time:2890ms step_avg:90.33ms +step:33/1660 train_time:2982ms step_avg:90.37ms +step:34/1660 train_time:3075ms step_avg:90.44ms +step:35/1660 train_time:3167ms step_avg:90.50ms +step:36/1660 train_time:3260ms step_avg:90.54ms +step:37/1660 train_time:3351ms step_avg:90.57ms +step:38/1660 train_time:3444ms step_avg:90.62ms +step:39/1660 train_time:3536ms step_avg:90.68ms +step:40/1660 train_time:3628ms step_avg:90.70ms +step:41/1660 train_time:3720ms step_avg:90.73ms +step:42/1660 train_time:3812ms step_avg:90.75ms +step:43/1660 train_time:3903ms step_avg:90.77ms +step:44/1660 train_time:3995ms step_avg:90.79ms +step:45/1660 train_time:4087ms step_avg:90.82ms +step:46/1660 train_time:4179ms step_avg:90.84ms +step:47/1660 train_time:4271ms step_avg:90.87ms +step:48/1660 train_time:4363ms step_avg:90.89ms +step:49/1660 train_time:4454ms step_avg:90.90ms +step:50/1660 train_time:4546ms step_avg:90.93ms 
+step:51/1660 train_time:4639ms step_avg:90.95ms +step:52/1660 train_time:4731ms step_avg:90.97ms +step:53/1660 train_time:4823ms step_avg:91.00ms +step:54/1660 train_time:4915ms step_avg:91.02ms +step:55/1660 train_time:5007ms step_avg:91.04ms +step:56/1660 train_time:5099ms step_avg:91.05ms +step:57/1660 train_time:5190ms step_avg:91.05ms +step:58/1660 train_time:5283ms step_avg:91.08ms +step:59/1660 train_time:5374ms step_avg:91.08ms +step:60/1660 train_time:5465ms step_avg:91.09ms +step:61/1660 train_time:5557ms step_avg:91.09ms +step:62/1660 train_time:5648ms step_avg:91.10ms +step:63/1660 train_time:5741ms step_avg:91.12ms +step:64/1660 train_time:5832ms step_avg:91.12ms +step:65/1660 train_time:5923ms step_avg:91.13ms +step:66/1660 train_time:6015ms step_avg:91.13ms +step:67/1660 train_time:6106ms step_avg:91.14ms +step:68/1660 train_time:6198ms step_avg:91.15ms +step:69/1660 train_time:6289ms step_avg:91.15ms +step:70/1660 train_time:6381ms step_avg:91.16ms +step:71/1660 train_time:6474ms step_avg:91.18ms +step:72/1660 train_time:6566ms step_avg:91.19ms +step:73/1660 train_time:6657ms step_avg:91.19ms +step:74/1660 train_time:6749ms step_avg:91.20ms +step:75/1660 train_time:6841ms step_avg:91.22ms +step:76/1660 train_time:6934ms step_avg:91.24ms +step:77/1660 train_time:7026ms step_avg:91.25ms +step:78/1660 train_time:7118ms step_avg:91.26ms +step:79/1660 train_time:7209ms step_avg:91.25ms +step:80/1660 train_time:7300ms step_avg:91.25ms +step:81/1660 train_time:7391ms step_avg:91.25ms +step:82/1660 train_time:7482ms step_avg:91.25ms +step:83/1660 train_time:7573ms step_avg:91.25ms +step:84/1660 train_time:7665ms step_avg:91.25ms +step:85/1660 train_time:7757ms step_avg:91.25ms +step:86/1660 train_time:7848ms step_avg:91.26ms +step:87/1660 train_time:7940ms step_avg:91.26ms +step:88/1660 train_time:8032ms step_avg:91.27ms +step:89/1660 train_time:8124ms step_avg:91.28ms +step:90/1660 train_time:8216ms step_avg:91.29ms +step:91/1660 train_time:8308ms step_avg:91.30ms +step:92/1660 train_time:8400ms step_avg:91.31ms +step:93/1660 train_time:8492ms step_avg:91.31ms +step:94/1660 train_time:8583ms step_avg:91.31ms +step:95/1660 train_time:8675ms step_avg:91.32ms +step:96/1660 train_time:8766ms step_avg:91.32ms +step:97/1660 train_time:8858ms step_avg:91.32ms +step:98/1660 train_time:8949ms step_avg:91.32ms +step:99/1660 train_time:9043ms step_avg:91.34ms +step:100/1660 train_time:9135ms step_avg:91.35ms +step:101/1660 train_time:9227ms step_avg:91.35ms +step:102/1660 train_time:9319ms step_avg:91.36ms +step:103/1660 train_time:9410ms step_avg:91.36ms +step:104/1660 train_time:9501ms step_avg:91.36ms +step:105/1660 train_time:9592ms step_avg:91.35ms +step:106/1660 train_time:9683ms step_avg:91.35ms +step:107/1660 train_time:9775ms step_avg:91.36ms +step:108/1660 train_time:9866ms step_avg:91.35ms +step:109/1660 train_time:9958ms step_avg:91.36ms +step:110/1660 train_time:10049ms step_avg:91.36ms +step:111/1660 train_time:10141ms step_avg:91.36ms +step:112/1660 train_time:10234ms step_avg:91.37ms +step:113/1660 train_time:10326ms step_avg:91.38ms +step:114/1660 train_time:10417ms step_avg:91.38ms +step:115/1660 train_time:10509ms step_avg:91.38ms +step:116/1660 train_time:10601ms step_avg:91.39ms +step:117/1660 train_time:10692ms step_avg:91.39ms +step:118/1660 train_time:10784ms step_avg:91.39ms +step:119/1660 train_time:10876ms step_avg:91.39ms +step:120/1660 train_time:10967ms step_avg:91.39ms +step:121/1660 train_time:11059ms step_avg:91.39ms +step:122/1660 train_time:11149ms 
step_avg:91.39ms +step:123/1660 train_time:11242ms step_avg:91.40ms +step:124/1660 train_time:11334ms step_avg:91.41ms +step:125/1660 train_time:11426ms step_avg:91.41ms +step:125/1660 val_loss:4.3282 train_time:11519ms step_avg:92.15ms +step:126/1660 train_time:11540ms step_avg:91.59ms +step:127/1660 train_time:11612ms step_avg:91.44ms +step:128/1660 train_time:11716ms step_avg:91.53ms +step:129/1660 train_time:11810ms step_avg:91.55ms +step:130/1660 train_time:11902ms step_avg:91.55ms +step:131/1660 train_time:11993ms step_avg:91.55ms +step:132/1660 train_time:12084ms step_avg:91.54ms +step:133/1660 train_time:12174ms step_avg:91.53ms +step:134/1660 train_time:12264ms step_avg:91.53ms +step:135/1660 train_time:12355ms step_avg:91.52ms +step:136/1660 train_time:12445ms step_avg:91.51ms +step:137/1660 train_time:12536ms step_avg:91.50ms +step:138/1660 train_time:12629ms step_avg:91.52ms +step:139/1660 train_time:12723ms step_avg:91.53ms +step:140/1660 train_time:12816ms step_avg:91.54ms +step:141/1660 train_time:12909ms step_avg:91.55ms +step:142/1660 train_time:12999ms step_avg:91.54ms +step:143/1660 train_time:13091ms step_avg:91.54ms +step:144/1660 train_time:13182ms step_avg:91.54ms +step:145/1660 train_time:13273ms step_avg:91.54ms +step:146/1660 train_time:13364ms step_avg:91.53ms +step:147/1660 train_time:13455ms step_avg:91.53ms +step:148/1660 train_time:13547ms step_avg:91.53ms +step:149/1660 train_time:13639ms step_avg:91.54ms +step:150/1660 train_time:13732ms step_avg:91.54ms +step:151/1660 train_time:13824ms step_avg:91.55ms +step:152/1660 train_time:13916ms step_avg:91.55ms +step:153/1660 train_time:14008ms step_avg:91.55ms +step:154/1660 train_time:14098ms step_avg:91.55ms +step:155/1660 train_time:14190ms step_avg:91.55ms +step:156/1660 train_time:14282ms step_avg:91.55ms +step:157/1660 train_time:14373ms step_avg:91.54ms +step:158/1660 train_time:14463ms step_avg:91.54ms +step:159/1660 train_time:14555ms step_avg:91.54ms +step:160/1660 train_time:14647ms step_avg:91.54ms +step:161/1660 train_time:14739ms step_avg:91.55ms +step:162/1660 train_time:14831ms step_avg:91.55ms +step:163/1660 train_time:14924ms step_avg:91.56ms +step:164/1660 train_time:15015ms step_avg:91.55ms +step:165/1660 train_time:15106ms step_avg:91.55ms +step:166/1660 train_time:15197ms step_avg:91.55ms +step:167/1660 train_time:15288ms step_avg:91.55ms +step:168/1660 train_time:15379ms step_avg:91.54ms +step:169/1660 train_time:15470ms step_avg:91.54ms +step:170/1660 train_time:15563ms step_avg:91.55ms +step:171/1660 train_time:15655ms step_avg:91.55ms +step:172/1660 train_time:15747ms step_avg:91.55ms +step:173/1660 train_time:15838ms step_avg:91.55ms +step:174/1660 train_time:15929ms step_avg:91.55ms +step:175/1660 train_time:16020ms step_avg:91.54ms +step:176/1660 train_time:16112ms step_avg:91.55ms +step:177/1660 train_time:16203ms step_avg:91.54ms +step:178/1660 train_time:16295ms step_avg:91.55ms +step:179/1660 train_time:16387ms step_avg:91.55ms +step:180/1660 train_time:16478ms step_avg:91.54ms +step:181/1660 train_time:16569ms step_avg:91.54ms +step:182/1660 train_time:16661ms step_avg:91.55ms +step:183/1660 train_time:16754ms step_avg:91.55ms +step:184/1660 train_time:16846ms step_avg:91.55ms +step:185/1660 train_time:16936ms step_avg:91.55ms +step:186/1660 train_time:17028ms step_avg:91.55ms +step:187/1660 train_time:17119ms step_avg:91.54ms +step:188/1660 train_time:17210ms step_avg:91.54ms +step:189/1660 train_time:17302ms step_avg:91.54ms +step:190/1660 train_time:17394ms step_avg:91.54ms 
+step:191/1660 train_time:17485ms step_avg:91.55ms +step:192/1660 train_time:17576ms step_avg:91.54ms +step:193/1660 train_time:17668ms step_avg:91.54ms +step:194/1660 train_time:17759ms step_avg:91.54ms +step:195/1660 train_time:17851ms step_avg:91.54ms +step:196/1660 train_time:17942ms step_avg:91.54ms +step:197/1660 train_time:18034ms step_avg:91.54ms +step:198/1660 train_time:18125ms step_avg:91.54ms +step:199/1660 train_time:18216ms step_avg:91.54ms +step:200/1660 train_time:18307ms step_avg:91.54ms +step:201/1660 train_time:18398ms step_avg:91.53ms +step:202/1660 train_time:18491ms step_avg:91.54ms +step:203/1660 train_time:18582ms step_avg:91.54ms +step:204/1660 train_time:18674ms step_avg:91.54ms +step:205/1660 train_time:18766ms step_avg:91.54ms +step:206/1660 train_time:18858ms step_avg:91.54ms +step:207/1660 train_time:18950ms step_avg:91.54ms +step:208/1660 train_time:19041ms step_avg:91.54ms +step:209/1660 train_time:19133ms step_avg:91.55ms +step:210/1660 train_time:19224ms step_avg:91.54ms +step:211/1660 train_time:19315ms step_avg:91.54ms +step:212/1660 train_time:19406ms step_avg:91.54ms +step:213/1660 train_time:19497ms step_avg:91.54ms +step:214/1660 train_time:19590ms step_avg:91.54ms +step:215/1660 train_time:19681ms step_avg:91.54ms +step:216/1660 train_time:19773ms step_avg:91.54ms +step:217/1660 train_time:19865ms step_avg:91.55ms +step:218/1660 train_time:19956ms step_avg:91.54ms +step:219/1660 train_time:20048ms step_avg:91.54ms +step:220/1660 train_time:20139ms step_avg:91.54ms +step:221/1660 train_time:20230ms step_avg:91.54ms +step:222/1660 train_time:20322ms step_avg:91.54ms +step:223/1660 train_time:20413ms step_avg:91.54ms +step:224/1660 train_time:20505ms step_avg:91.54ms +step:225/1660 train_time:20597ms step_avg:91.54ms +step:226/1660 train_time:20688ms step_avg:91.54ms +step:227/1660 train_time:20779ms step_avg:91.54ms +step:228/1660 train_time:20871ms step_avg:91.54ms +step:229/1660 train_time:20963ms step_avg:91.54ms +step:230/1660 train_time:21054ms step_avg:91.54ms +step:231/1660 train_time:21145ms step_avg:91.54ms +step:232/1660 train_time:21236ms step_avg:91.54ms +step:233/1660 train_time:21327ms step_avg:91.53ms +step:234/1660 train_time:21419ms step_avg:91.53ms +step:235/1660 train_time:21512ms step_avg:91.54ms +step:236/1660 train_time:21603ms step_avg:91.54ms +step:237/1660 train_time:21695ms step_avg:91.54ms +step:238/1660 train_time:21786ms step_avg:91.54ms +step:239/1660 train_time:21877ms step_avg:91.54ms +step:240/1660 train_time:21970ms step_avg:91.54ms +step:241/1660 train_time:22062ms step_avg:91.54ms +step:242/1660 train_time:22153ms step_avg:91.54ms +step:243/1660 train_time:22245ms step_avg:91.54ms +step:244/1660 train_time:22335ms step_avg:91.54ms +step:245/1660 train_time:22427ms step_avg:91.54ms +step:246/1660 train_time:22518ms step_avg:91.54ms +step:247/1660 train_time:22609ms step_avg:91.54ms +step:248/1660 train_time:22701ms step_avg:91.54ms +step:249/1660 train_time:22793ms step_avg:91.54ms +step:250/1660 train_time:22885ms step_avg:91.54ms +step:250/1660 val_loss:3.9800 train_time:22977ms step_avg:91.91ms +step:251/1660 train_time:23000ms step_avg:91.63ms +step:252/1660 train_time:23072ms step_avg:91.56ms +step:253/1660 train_time:23169ms step_avg:91.58ms +step:254/1660 train_time:23262ms step_avg:91.58ms +step:255/1660 train_time:23353ms step_avg:91.58ms +step:256/1660 train_time:23444ms step_avg:91.58ms +step:257/1660 train_time:23534ms step_avg:91.57ms +step:258/1660 train_time:23624ms step_avg:91.57ms +step:259/1660 
train_time:23715ms step_avg:91.56ms +step:260/1660 train_time:23806ms step_avg:91.56ms +step:261/1660 train_time:23897ms step_avg:91.56ms +step:262/1660 train_time:23989ms step_avg:91.56ms +step:263/1660 train_time:24084ms step_avg:91.57ms +step:264/1660 train_time:24178ms step_avg:91.58ms +step:265/1660 train_time:24269ms step_avg:91.58ms +step:266/1660 train_time:24361ms step_avg:91.58ms +step:267/1660 train_time:24453ms step_avg:91.58ms +step:268/1660 train_time:24544ms step_avg:91.58ms +step:269/1660 train_time:24636ms step_avg:91.58ms +step:270/1660 train_time:24726ms step_avg:91.58ms +step:271/1660 train_time:24817ms step_avg:91.58ms +step:272/1660 train_time:24908ms step_avg:91.57ms +step:273/1660 train_time:24999ms step_avg:91.57ms +step:274/1660 train_time:25091ms step_avg:91.57ms +step:275/1660 train_time:25184ms step_avg:91.58ms +step:276/1660 train_time:25275ms step_avg:91.58ms +step:277/1660 train_time:25366ms step_avg:91.58ms +step:278/1660 train_time:25458ms step_avg:91.57ms +step:279/1660 train_time:25549ms step_avg:91.57ms +step:280/1660 train_time:25640ms step_avg:91.57ms +step:281/1660 train_time:25731ms step_avg:91.57ms +step:282/1660 train_time:25822ms step_avg:91.57ms +step:283/1660 train_time:25913ms step_avg:91.56ms +step:284/1660 train_time:26004ms step_avg:91.56ms +step:285/1660 train_time:26095ms step_avg:91.56ms +step:286/1660 train_time:26187ms step_avg:91.56ms +step:287/1660 train_time:26279ms step_avg:91.56ms +step:288/1660 train_time:26370ms step_avg:91.56ms +step:289/1660 train_time:26461ms step_avg:91.56ms +step:290/1660 train_time:26552ms step_avg:91.56ms +step:291/1660 train_time:26643ms step_avg:91.56ms +step:292/1660 train_time:26735ms step_avg:91.56ms +step:293/1660 train_time:26826ms step_avg:91.56ms +step:294/1660 train_time:26917ms step_avg:91.55ms +step:295/1660 train_time:27009ms step_avg:91.55ms +step:296/1660 train_time:27100ms step_avg:91.56ms +step:297/1660 train_time:27191ms step_avg:91.55ms +step:298/1660 train_time:27284ms step_avg:91.56ms +step:299/1660 train_time:27375ms step_avg:91.56ms +step:300/1660 train_time:27466ms step_avg:91.55ms +step:301/1660 train_time:27557ms step_avg:91.55ms +step:302/1660 train_time:27648ms step_avg:91.55ms +step:303/1660 train_time:27739ms step_avg:91.55ms +step:304/1660 train_time:27830ms step_avg:91.55ms +step:305/1660 train_time:27922ms step_avg:91.55ms +step:306/1660 train_time:28015ms step_avg:91.55ms +step:307/1660 train_time:28106ms step_avg:91.55ms +step:308/1660 train_time:28199ms step_avg:91.55ms +step:309/1660 train_time:28290ms step_avg:91.55ms +step:310/1660 train_time:28382ms step_avg:91.55ms +step:311/1660 train_time:28473ms step_avg:91.55ms +step:312/1660 train_time:28564ms step_avg:91.55ms +step:313/1660 train_time:28656ms step_avg:91.55ms +step:314/1660 train_time:28747ms step_avg:91.55ms +step:315/1660 train_time:28839ms step_avg:91.55ms +step:316/1660 train_time:28930ms step_avg:91.55ms +step:317/1660 train_time:29022ms step_avg:91.55ms +step:318/1660 train_time:29112ms step_avg:91.55ms +step:319/1660 train_time:29205ms step_avg:91.55ms +step:320/1660 train_time:29297ms step_avg:91.55ms +step:321/1660 train_time:29388ms step_avg:91.55ms +step:322/1660 train_time:29480ms step_avg:91.55ms +step:323/1660 train_time:29571ms step_avg:91.55ms +step:324/1660 train_time:29662ms step_avg:91.55ms +step:325/1660 train_time:29753ms step_avg:91.55ms +step:326/1660 train_time:29845ms step_avg:91.55ms +step:327/1660 train_time:29937ms step_avg:91.55ms +step:328/1660 train_time:30029ms step_avg:91.55ms 
+step:329/1660 train_time:30120ms step_avg:91.55ms
+step:330/1660 train_time:30212ms step_avg:91.55ms
+step:331/1660 train_time:30304ms step_avg:91.55ms
+step:332/1660 train_time:30396ms step_avg:91.55ms
+step:333/1660 train_time:30487ms step_avg:91.55ms
+step:334/1660 train_time:30578ms step_avg:91.55ms
+step:335/1660 train_time:30670ms step_avg:91.55ms
+step:336/1660 train_time:30762ms step_avg:91.55ms
+step:337/1660 train_time:30853ms step_avg:91.55ms
+step:338/1660 train_time:30945ms step_avg:91.55ms
+step:339/1660 train_time:31036ms step_avg:91.55ms
+step:340/1660 train_time:31127ms step_avg:91.55ms
+step:341/1660 train_time:31220ms step_avg:91.55ms
+step:342/1660 train_time:31311ms step_avg:91.55ms
+step:343/1660 train_time:31403ms step_avg:91.55ms
+step:344/1660 train_time:31495ms step_avg:91.56ms
+step:345/1660 train_time:31586ms step_avg:91.55ms
+step:346/1660 train_time:31677ms step_avg:91.55ms
+step:347/1660 train_time:31769ms step_avg:91.55ms
+step:348/1660 train_time:31860ms step_avg:91.55ms
+step:349/1660 train_time:31951ms step_avg:91.55ms
+step:350/1660 train_time:32043ms step_avg:91.55ms
+step:351/1660 train_time:32135ms step_avg:91.55ms
+step:352/1660 train_time:32227ms step_avg:91.55ms
+step:353/1660 train_time:32319ms step_avg:91.55ms
+step:354/1660 train_time:32410ms step_avg:91.55ms
+step:355/1660 train_time:32502ms step_avg:91.56ms
+step:356/1660 train_time:32594ms step_avg:91.56ms
+step:357/1660 train_time:32686ms step_avg:91.56ms
+step:358/1660 train_time:32777ms step_avg:91.56ms
+step:359/1660 train_time:32868ms step_avg:91.55ms
+step:360/1660 train_time:32958ms step_avg:91.55ms
+step:361/1660 train_time:33049ms step_avg:91.55ms
+step:362/1660 train_time:33141ms step_avg:91.55ms
+step:363/1660 train_time:33232ms step_avg:91.55ms
+step:364/1660 train_time:33325ms step_avg:91.55ms
+step:365/1660 train_time:33416ms step_avg:91.55ms
+step:366/1660 train_time:33508ms step_avg:91.55ms
+step:367/1660 train_time:33599ms step_avg:91.55ms
+step:368/1660 train_time:33691ms step_avg:91.55ms
+step:369/1660 train_time:33782ms step_avg:91.55ms
+step:370/1660 train_time:33874ms step_avg:91.55ms
+step:371/1660 train_time:33965ms step_avg:91.55ms
+step:372/1660 train_time:34057ms step_avg:91.55ms
+step:373/1660 train_time:34148ms step_avg:91.55ms
+step:374/1660 train_time:34239ms step_avg:91.55ms
+step:375/1660 train_time:34331ms step_avg:91.55ms
+step:375/1660 val_loss:3.8207 train_time:34425ms step_avg:91.80ms
+step:376/1660 train_time:34446ms step_avg:91.61ms
+step:377/1660 train_time:34518ms step_avg:91.56ms
+step:378/1660 train_time:34616ms step_avg:91.58ms
+step:379/1660 train_time:34709ms step_avg:91.58ms
+step:380/1660 train_time:34800ms step_avg:91.58ms
+step:381/1660 train_time:34891ms step_avg:91.58ms
+step:382/1660 train_time:34981ms step_avg:91.57ms
+step:383/1660 train_time:35071ms step_avg:91.57ms
+step:384/1660 train_time:35161ms step_avg:91.57ms
+step:385/1660 train_time:35252ms step_avg:91.56ms
+step:386/1660 train_time:35342ms step_avg:91.56ms
+step:387/1660 train_time:35434ms step_avg:91.56ms
+step:388/1660 train_time:35527ms step_avg:91.57ms
+step:389/1660 train_time:35620ms step_avg:91.57ms
+step:390/1660 train_time:35712ms step_avg:91.57ms
+step:391/1660 train_time:35803ms step_avg:91.57ms
+step:392/1660 train_time:35895ms step_avg:91.57ms
+step:393/1660 train_time:35986ms step_avg:91.57ms
+step:394/1660 train_time:36078ms step_avg:91.57ms
+step:395/1660 train_time:36169ms step_avg:91.57ms
+step:396/1660 train_time:36260ms step_avg:91.57ms
+step:397/1660 train_time:36351ms step_avg:91.56ms
+step:398/1660 train_time:36442ms step_avg:91.56ms
+step:399/1660 train_time:36535ms step_avg:91.57ms
+step:400/1660 train_time:36627ms step_avg:91.57ms
+step:401/1660 train_time:36719ms step_avg:91.57ms
+step:402/1660 train_time:36811ms step_avg:91.57ms
+step:403/1660 train_time:36902ms step_avg:91.57ms
+step:404/1660 train_time:36993ms step_avg:91.57ms
+step:405/1660 train_time:37084ms step_avg:91.56ms
+step:406/1660 train_time:37176ms step_avg:91.57ms
+step:407/1660 train_time:37267ms step_avg:91.56ms
+step:408/1660 train_time:37359ms step_avg:91.57ms
+step:409/1660 train_time:37451ms step_avg:91.57ms
+step:410/1660 train_time:37542ms step_avg:91.57ms
+step:411/1660 train_time:37635ms step_avg:91.57ms
+step:412/1660 train_time:37726ms step_avg:91.57ms
+step:413/1660 train_time:37818ms step_avg:91.57ms
+step:414/1660 train_time:37909ms step_avg:91.57ms
+step:415/1660 train_time:38000ms step_avg:91.57ms
+step:416/1660 train_time:38092ms step_avg:91.57ms
+step:417/1660 train_time:38182ms step_avg:91.56ms
+step:418/1660 train_time:38274ms step_avg:91.57ms
+step:419/1660 train_time:38366ms step_avg:91.56ms
+step:420/1660 train_time:38458ms step_avg:91.57ms
+step:421/1660 train_time:38550ms step_avg:91.57ms
+step:422/1660 train_time:38642ms step_avg:91.57ms
+step:423/1660 train_time:38734ms step_avg:91.57ms
+step:424/1660 train_time:38824ms step_avg:91.57ms
+step:425/1660 train_time:38916ms step_avg:91.57ms
+step:426/1660 train_time:39007ms step_avg:91.57ms
+step:427/1660 train_time:39099ms step_avg:91.57ms
+step:428/1660 train_time:39191ms step_avg:91.57ms
+step:429/1660 train_time:39281ms step_avg:91.57ms
+step:430/1660 train_time:39373ms step_avg:91.56ms
+step:431/1660 train_time:39464ms step_avg:91.56ms
+step:432/1660 train_time:39555ms step_avg:91.56ms
+step:433/1660 train_time:39647ms step_avg:91.56ms
+step:434/1660 train_time:39738ms step_avg:91.56ms
+step:435/1660 train_time:39829ms step_avg:91.56ms
+step:436/1660 train_time:39920ms step_avg:91.56ms
+step:437/1660 train_time:40012ms step_avg:91.56ms
+step:438/1660 train_time:40102ms step_avg:91.56ms
+step:439/1660 train_time:40194ms step_avg:91.56ms
+step:440/1660 train_time:40285ms step_avg:91.56ms
+step:441/1660 train_time:40377ms step_avg:91.56ms
+step:442/1660 train_time:40468ms step_avg:91.56ms
+step:443/1660 train_time:40560ms step_avg:91.56ms
+step:444/1660 train_time:40652ms step_avg:91.56ms
+step:445/1660 train_time:40743ms step_avg:91.56ms
+step:446/1660 train_time:40835ms step_avg:91.56ms
+step:447/1660 train_time:40925ms step_avg:91.56ms
+step:448/1660 train_time:41017ms step_avg:91.56ms
+step:449/1660 train_time:41108ms step_avg:91.55ms
+step:450/1660 train_time:41199ms step_avg:91.55ms
+step:451/1660 train_time:41291ms step_avg:91.56ms
+step:452/1660 train_time:41382ms step_avg:91.55ms
+step:453/1660 train_time:41474ms step_avg:91.55ms
+step:454/1660 train_time:41565ms step_avg:91.55ms
+step:455/1660 train_time:41657ms step_avg:91.55ms
+step:456/1660 train_time:41749ms step_avg:91.55ms
+step:457/1660 train_time:41840ms step_avg:91.55ms
+step:458/1660 train_time:41931ms step_avg:91.55ms
+step:459/1660 train_time:42021ms step_avg:91.55ms
+step:460/1660 train_time:42113ms step_avg:91.55ms
+step:461/1660 train_time:42203ms step_avg:91.55ms
+step:462/1660 train_time:42294ms step_avg:91.55ms
+step:463/1660 train_time:42386ms step_avg:91.55ms
+step:464/1660 train_time:42478ms step_avg:91.55ms
+step:465/1660 train_time:42570ms step_avg:91.55ms
+step:466/1660 train_time:42661ms step_avg:91.55ms
+step:467/1660 train_time:42752ms step_avg:91.55ms
+step:468/1660 train_time:42843ms step_avg:91.55ms
+step:469/1660 train_time:42935ms step_avg:91.55ms
+step:470/1660 train_time:43026ms step_avg:91.54ms
+step:471/1660 train_time:43117ms step_avg:91.54ms
+step:472/1660 train_time:43208ms step_avg:91.54ms
+step:473/1660 train_time:43299ms step_avg:91.54ms
+step:474/1660 train_time:43391ms step_avg:91.54ms
+step:475/1660 train_time:43482ms step_avg:91.54ms
+step:476/1660 train_time:43574ms step_avg:91.54ms
+step:477/1660 train_time:43666ms step_avg:91.54ms
+step:478/1660 train_time:43758ms step_avg:91.54ms
+step:479/1660 train_time:43850ms step_avg:91.54ms
+step:480/1660 train_time:43940ms step_avg:91.54ms
+step:481/1660 train_time:44031ms step_avg:91.54ms
+step:482/1660 train_time:44122ms step_avg:91.54ms
+step:483/1660 train_time:44213ms step_avg:91.54ms
+step:484/1660 train_time:44305ms step_avg:91.54ms
+step:485/1660 train_time:44397ms step_avg:91.54ms
+step:486/1660 train_time:44489ms step_avg:91.54ms
+step:487/1660 train_time:44580ms step_avg:91.54ms
+step:488/1660 train_time:44671ms step_avg:91.54ms
+step:489/1660 train_time:44763ms step_avg:91.54ms
+step:490/1660 train_time:44854ms step_avg:91.54ms
+step:491/1660 train_time:44945ms step_avg:91.54ms
+step:492/1660 train_time:45036ms step_avg:91.54ms
+step:493/1660 train_time:45127ms step_avg:91.54ms
+step:494/1660 train_time:45219ms step_avg:91.54ms
+step:495/1660 train_time:45310ms step_avg:91.53ms
+step:496/1660 train_time:45401ms step_avg:91.53ms
+step:497/1660 train_time:45493ms step_avg:91.54ms
+step:498/1660 train_time:45585ms step_avg:91.54ms
+step:499/1660 train_time:45677ms step_avg:91.54ms
+step:500/1660 train_time:45769ms step_avg:91.54ms
+step:500/1660 val_loss:3.7180 train_time:45861ms step_avg:91.72ms
+step:501/1660 train_time:45882ms step_avg:91.58ms
+step:502/1660 train_time:45953ms step_avg:91.54ms
+step:503/1660 train_time:46051ms step_avg:91.55ms
+step:504/1660 train_time:46143ms step_avg:91.55ms
+step:505/1660 train_time:46234ms step_avg:91.55ms
+step:506/1660 train_time:46324ms step_avg:91.55ms
+step:507/1660 train_time:46415ms step_avg:91.55ms
+step:508/1660 train_time:46506ms step_avg:91.55ms
+step:509/1660 train_time:46597ms step_avg:91.55ms
+step:510/1660 train_time:46688ms step_avg:91.54ms
+step:511/1660 train_time:46779ms step_avg:91.54ms
+step:512/1660 train_time:46871ms step_avg:91.54ms
+step:513/1660 train_time:46964ms step_avg:91.55ms
+step:514/1660 train_time:47057ms step_avg:91.55ms
+step:515/1660 train_time:47149ms step_avg:91.55ms
+step:516/1660 train_time:47241ms step_avg:91.55ms
+step:517/1660 train_time:47331ms step_avg:91.55ms
+step:518/1660 train_time:47422ms step_avg:91.55ms
+step:519/1660 train_time:47512ms step_avg:91.55ms
+step:520/1660 train_time:47603ms step_avg:91.55ms
+step:521/1660 train_time:47694ms step_avg:91.54ms
+step:522/1660 train_time:47786ms step_avg:91.54ms
+step:523/1660 train_time:47878ms step_avg:91.55ms
+step:524/1660 train_time:47971ms step_avg:91.55ms
+step:525/1660 train_time:48063ms step_avg:91.55ms
+step:526/1660 train_time:48155ms step_avg:91.55ms
+step:527/1660 train_time:48247ms step_avg:91.55ms
+step:528/1660 train_time:48338ms step_avg:91.55ms
+step:529/1660 train_time:48429ms step_avg:91.55ms
+step:530/1660 train_time:48520ms step_avg:91.55ms
+step:531/1660 train_time:48611ms step_avg:91.55ms
+step:532/1660 train_time:48702ms step_avg:91.55ms
+step:533/1660 train_time:48793ms step_avg:91.54ms
+step:534/1660 train_time:48885ms step_avg:91.54ms
+step:535/1660 train_time:48977ms step_avg:91.55ms
+step:536/1660 train_time:49069ms step_avg:91.55ms
+step:537/1660 train_time:49161ms step_avg:91.55ms
+step:538/1660 train_time:49253ms step_avg:91.55ms
+step:539/1660 train_time:49345ms step_avg:91.55ms
+step:540/1660 train_time:49436ms step_avg:91.55ms
+step:541/1660 train_time:49527ms step_avg:91.55ms
+step:542/1660 train_time:49619ms step_avg:91.55ms
+step:543/1660 train_time:49710ms step_avg:91.55ms
+step:544/1660 train_time:49801ms step_avg:91.55ms
+step:545/1660 train_time:49892ms step_avg:91.54ms
+step:546/1660 train_time:49983ms step_avg:91.54ms
+step:547/1660 train_time:50074ms step_avg:91.54ms
+step:548/1660 train_time:50166ms step_avg:91.54ms
+step:549/1660 train_time:50257ms step_avg:91.54ms
+step:550/1660 train_time:50349ms step_avg:91.54ms
+step:551/1660 train_time:50440ms step_avg:91.54ms
+step:552/1660 train_time:50531ms step_avg:91.54ms
+step:553/1660 train_time:50623ms step_avg:91.54ms
+step:554/1660 train_time:50714ms step_avg:91.54ms
+step:555/1660 train_time:50806ms step_avg:91.54ms
+step:556/1660 train_time:50898ms step_avg:91.54ms
+step:557/1660 train_time:50990ms step_avg:91.54ms
+step:558/1660 train_time:51083ms step_avg:91.55ms
+step:559/1660 train_time:51175ms step_avg:91.55ms
+step:560/1660 train_time:51267ms step_avg:91.55ms
+step:561/1660 train_time:51360ms step_avg:91.55ms
+step:562/1660 train_time:51452ms step_avg:91.55ms
+step:563/1660 train_time:51546ms step_avg:91.56ms
+step:564/1660 train_time:51639ms step_avg:91.56ms
+step:565/1660 train_time:51731ms step_avg:91.56ms
+step:566/1660 train_time:51824ms step_avg:91.56ms
+step:567/1660 train_time:51917ms step_avg:91.56ms
+step:568/1660 train_time:52009ms step_avg:91.57ms
+step:569/1660 train_time:52102ms step_avg:91.57ms
+step:570/1660 train_time:52194ms step_avg:91.57ms
+step:571/1660 train_time:52287ms step_avg:91.57ms
+step:572/1660 train_time:52380ms step_avg:91.57ms
+step:573/1660 train_time:52472ms step_avg:91.57ms
+step:574/1660 train_time:52565ms step_avg:91.58ms
+step:575/1660 train_time:52658ms step_avg:91.58ms
+step:576/1660 train_time:52750ms step_avg:91.58ms
+step:577/1660 train_time:52844ms step_avg:91.58ms
+step:578/1660 train_time:52937ms step_avg:91.59ms
+step:579/1660 train_time:53030ms step_avg:91.59ms
+step:580/1660 train_time:53123ms step_avg:91.59ms
+step:581/1660 train_time:53215ms step_avg:91.59ms
+step:582/1660 train_time:53309ms step_avg:91.60ms
+step:583/1660 train_time:53401ms step_avg:91.60ms
+step:584/1660 train_time:53493ms step_avg:91.60ms
+step:585/1660 train_time:53587ms step_avg:91.60ms
+step:586/1660 train_time:53681ms step_avg:91.61ms
+step:587/1660 train_time:53773ms step_avg:91.61ms
+step:588/1660 train_time:53868ms step_avg:91.61ms
+step:589/1660 train_time:53961ms step_avg:91.61ms
+step:590/1660 train_time:54052ms step_avg:91.61ms
+step:591/1660 train_time:54145ms step_avg:91.62ms
+step:592/1660 train_time:54237ms step_avg:91.62ms
+step:593/1660 train_time:54330ms step_avg:91.62ms
+step:594/1660 train_time:54423ms step_avg:91.62ms
+step:595/1660 train_time:54516ms step_avg:91.62ms
+step:596/1660 train_time:54609ms step_avg:91.63ms
+step:597/1660 train_time:54702ms step_avg:91.63ms
+step:598/1660 train_time:54795ms step_avg:91.63ms
+step:599/1660 train_time:54887ms step_avg:91.63ms
+step:600/1660 train_time:54980ms step_avg:91.63ms
+step:601/1660 train_time:55072ms step_avg:91.63ms
+step:602/1660 train_time:55164ms step_avg:91.64ms
+step:603/1660 train_time:55257ms step_avg:91.64ms
+step:604/1660 train_time:55350ms step_avg:91.64ms
+step:605/1660 train_time:55443ms step_avg:91.64ms
+step:606/1660 train_time:55536ms step_avg:91.64ms
+step:607/1660 train_time:55628ms step_avg:91.64ms
+step:608/1660 train_time:55721ms step_avg:91.65ms
+step:609/1660 train_time:55814ms step_avg:91.65ms
+step:610/1660 train_time:55907ms step_avg:91.65ms
+step:611/1660 train_time:56000ms step_avg:91.65ms
+step:612/1660 train_time:56092ms step_avg:91.65ms
+step:613/1660 train_time:56185ms step_avg:91.66ms
+step:614/1660 train_time:56278ms step_avg:91.66ms
+step:615/1660 train_time:56370ms step_avg:91.66ms
+step:616/1660 train_time:56463ms step_avg:91.66ms
+step:617/1660 train_time:56557ms step_avg:91.66ms
+step:618/1660 train_time:56649ms step_avg:91.66ms
+step:619/1660 train_time:56742ms step_avg:91.67ms
+step:620/1660 train_time:56835ms step_avg:91.67ms
+step:621/1660 train_time:56928ms step_avg:91.67ms
+step:622/1660 train_time:57020ms step_avg:91.67ms
+step:623/1660 train_time:57111ms step_avg:91.67ms
+step:624/1660 train_time:57205ms step_avg:91.67ms
+step:625/1660 train_time:57298ms step_avg:91.68ms
+step:625/1660 val_loss:3.6152 train_time:57392ms step_avg:91.83ms
+step:626/1660 train_time:57412ms step_avg:91.71ms
+step:627/1660 train_time:57487ms step_avg:91.69ms
+step:628/1660 train_time:57592ms step_avg:91.71ms
+step:629/1660 train_time:57684ms step_avg:91.71ms
+step:630/1660 train_time:57776ms step_avg:91.71ms
+step:631/1660 train_time:57868ms step_avg:91.71ms
+step:632/1660 train_time:57959ms step_avg:91.71ms
+step:633/1660 train_time:58051ms step_avg:91.71ms
+step:634/1660 train_time:58142ms step_avg:91.71ms
+step:635/1660 train_time:58234ms step_avg:91.71ms
+step:636/1660 train_time:58326ms step_avg:91.71ms
+step:637/1660 train_time:58418ms step_avg:91.71ms
+step:638/1660 train_time:58516ms step_avg:91.72ms
+step:639/1660 train_time:58612ms step_avg:91.73ms
+step:640/1660 train_time:58706ms step_avg:91.73ms
+step:641/1660 train_time:58798ms step_avg:91.73ms
+step:642/1660 train_time:58891ms step_avg:91.73ms
+step:643/1660 train_time:58982ms step_avg:91.73ms
+step:644/1660 train_time:59074ms step_avg:91.73ms
+step:645/1660 train_time:59165ms step_avg:91.73ms
+step:646/1660 train_time:59256ms step_avg:91.73ms
+step:647/1660 train_time:59349ms step_avg:91.73ms
+step:648/1660 train_time:59443ms step_avg:91.73ms
+step:649/1660 train_time:59537ms step_avg:91.74ms
+step:650/1660 train_time:59632ms step_avg:91.74ms
+step:651/1660 train_time:59725ms step_avg:91.74ms
+step:652/1660 train_time:59817ms step_avg:91.74ms
+step:653/1660 train_time:59909ms step_avg:91.74ms
+step:654/1660 train_time:60001ms step_avg:91.74ms
+step:655/1660 train_time:60093ms step_avg:91.74ms
+step:656/1660 train_time:60185ms step_avg:91.75ms
+step:657/1660 train_time:60277ms step_avg:91.75ms
+step:658/1660 train_time:60370ms step_avg:91.75ms
+step:659/1660 train_time:60464ms step_avg:91.75ms
+step:660/1660 train_time:60557ms step_avg:91.75ms
+step:661/1660 train_time:60651ms step_avg:91.76ms
+step:662/1660 train_time:60743ms step_avg:91.76ms
+step:663/1660 train_time:60836ms step_avg:91.76ms
+step:664/1660 train_time:60930ms step_avg:91.76ms
+step:665/1660 train_time:61022ms step_avg:91.76ms
+step:666/1660 train_time:61114ms step_avg:91.76ms
+step:667/1660 train_time:61207ms step_avg:91.76ms
+step:668/1660 train_time:61298ms step_avg:91.76ms
+step:669/1660 train_time:61391ms step_avg:91.77ms
+step:670/1660 train_time:61484ms step_avg:91.77ms
+step:671/1660 train_time:61577ms step_avg:91.77ms
+step:672/1660 train_time:61672ms step_avg:91.77ms
+step:673/1660 train_time:61765ms step_avg:91.78ms
+step:674/1660 train_time:61857ms step_avg:91.78ms
+step:675/1660 train_time:61950ms step_avg:91.78ms
+step:676/1660 train_time:62043ms step_avg:91.78ms
+step:677/1660 train_time:62135ms step_avg:91.78ms
+step:678/1660 train_time:62228ms step_avg:91.78ms
+step:679/1660 train_time:62320ms step_avg:91.78ms
+step:680/1660 train_time:62413ms step_avg:91.78ms
+step:681/1660 train_time:62506ms step_avg:91.79ms
+step:682/1660 train_time:62599ms step_avg:91.79ms
+step:683/1660 train_time:62692ms step_avg:91.79ms
+step:684/1660 train_time:62785ms step_avg:91.79ms
+step:685/1660 train_time:62877ms step_avg:91.79ms
+step:686/1660 train_time:62969ms step_avg:91.79ms
+step:687/1660 train_time:63062ms step_avg:91.79ms
+step:688/1660 train_time:63154ms step_avg:91.79ms
+step:689/1660 train_time:63247ms step_avg:91.79ms
+step:690/1660 train_time:63339ms step_avg:91.80ms
+step:691/1660 train_time:63432ms step_avg:91.80ms
+step:692/1660 train_time:63525ms step_avg:91.80ms
+step:693/1660 train_time:63618ms step_avg:91.80ms
+step:694/1660 train_time:63712ms step_avg:91.80ms
+step:695/1660 train_time:63805ms step_avg:91.81ms
+step:696/1660 train_time:63898ms step_avg:91.81ms
+step:697/1660 train_time:63990ms step_avg:91.81ms
+step:698/1660 train_time:64083ms step_avg:91.81ms
+step:699/1660 train_time:64174ms step_avg:91.81ms
+step:700/1660 train_time:64267ms step_avg:91.81ms
+step:701/1660 train_time:64360ms step_avg:91.81ms
+step:702/1660 train_time:64453ms step_avg:91.81ms
+step:703/1660 train_time:64546ms step_avg:91.81ms
+step:704/1660 train_time:64638ms step_avg:91.81ms
+step:705/1660 train_time:64732ms step_avg:91.82ms
+step:706/1660 train_time:64825ms step_avg:91.82ms
+step:707/1660 train_time:64917ms step_avg:91.82ms
+step:708/1660 train_time:65010ms step_avg:91.82ms
+step:709/1660 train_time:65102ms step_avg:91.82ms
+step:710/1660 train_time:65194ms step_avg:91.82ms
+step:711/1660 train_time:65287ms step_avg:91.82ms
+step:712/1660 train_time:65380ms step_avg:91.83ms
+step:713/1660 train_time:65473ms step_avg:91.83ms
+step:714/1660 train_time:65566ms step_avg:91.83ms
+step:715/1660 train_time:65658ms step_avg:91.83ms
+step:716/1660 train_time:65752ms step_avg:91.83ms
+step:717/1660 train_time:65845ms step_avg:91.83ms
+step:718/1660 train_time:65937ms step_avg:91.83ms
+step:719/1660 train_time:66030ms step_avg:91.84ms
+step:720/1660 train_time:66122ms step_avg:91.84ms
+step:721/1660 train_time:66214ms step_avg:91.84ms
+step:722/1660 train_time:66306ms step_avg:91.84ms
+step:723/1660 train_time:66398ms step_avg:91.84ms
+step:724/1660 train_time:66491ms step_avg:91.84ms
+step:725/1660 train_time:66584ms step_avg:91.84ms
+step:726/1660 train_time:66677ms step_avg:91.84ms
+step:727/1660 train_time:66770ms step_avg:91.84ms
+step:728/1660 train_time:66863ms step_avg:91.84ms
+step:729/1660 train_time:66956ms step_avg:91.85ms
+step:730/1660 train_time:67049ms step_avg:91.85ms
+step:731/1660 train_time:67140ms step_avg:91.85ms
+step:732/1660 train_time:67233ms step_avg:91.85ms
+step:733/1660 train_time:67326ms step_avg:91.85ms
+step:734/1660 train_time:67419ms step_avg:91.85ms
+step:735/1660 train_time:67512ms step_avg:91.85ms
+step:736/1660 train_time:67605ms step_avg:91.85ms
+step:737/1660 train_time:67697ms step_avg:91.86ms
+step:738/1660 train_time:67790ms step_avg:91.86ms
+step:739/1660 train_time:67883ms step_avg:91.86ms
+step:740/1660 train_time:67975ms step_avg:91.86ms
+step:741/1660 train_time:68068ms step_avg:91.86ms
+step:742/1660 train_time:68160ms step_avg:91.86ms
+step:743/1660 train_time:68253ms step_avg:91.86ms
+step:744/1660 train_time:68346ms step_avg:91.86ms
+step:745/1660 train_time:68439ms step_avg:91.86ms
+step:746/1660 train_time:68532ms step_avg:91.87ms
+step:747/1660 train_time:68625ms step_avg:91.87ms
+step:748/1660 train_time:68717ms step_avg:91.87ms
+step:749/1660 train_time:68809ms step_avg:91.87ms
+step:750/1660 train_time:68901ms step_avg:91.87ms
+step:750/1660 val_loss:3.5627 train_time:68995ms step_avg:91.99ms
+step:751/1660 train_time:69016ms step_avg:91.90ms
+step:752/1660 train_time:69091ms step_avg:91.88ms
+step:753/1660 train_time:69188ms step_avg:91.88ms
+step:754/1660 train_time:69281ms step_avg:91.88ms
+step:755/1660 train_time:69372ms step_avg:91.88ms
+step:756/1660 train_time:69463ms step_avg:91.88ms
+step:757/1660 train_time:69554ms step_avg:91.88ms
+step:758/1660 train_time:69646ms step_avg:91.88ms
+step:759/1660 train_time:69737ms step_avg:91.88ms
+step:760/1660 train_time:69828ms step_avg:91.88ms
+step:761/1660 train_time:69921ms step_avg:91.88ms
+step:762/1660 train_time:70015ms step_avg:91.88ms
+step:763/1660 train_time:70110ms step_avg:91.89ms
+step:764/1660 train_time:70203ms step_avg:91.89ms
+step:765/1660 train_time:70296ms step_avg:91.89ms
+step:766/1660 train_time:70388ms step_avg:91.89ms
+step:767/1660 train_time:70479ms step_avg:91.89ms
+step:768/1660 train_time:70572ms step_avg:91.89ms
+step:769/1660 train_time:70664ms step_avg:91.89ms
+step:770/1660 train_time:70756ms step_avg:91.89ms
+step:771/1660 train_time:70848ms step_avg:91.89ms
+step:772/1660 train_time:70941ms step_avg:91.89ms
+step:773/1660 train_time:71034ms step_avg:91.89ms
+step:774/1660 train_time:71128ms step_avg:91.90ms
+step:775/1660 train_time:71222ms step_avg:91.90ms
+step:776/1660 train_time:71314ms step_avg:91.90ms
+step:777/1660 train_time:71407ms step_avg:91.90ms
+step:778/1660 train_time:71498ms step_avg:91.90ms
+step:779/1660 train_time:71590ms step_avg:91.90ms
+step:780/1660 train_time:71683ms step_avg:91.90ms
+step:781/1660 train_time:71776ms step_avg:91.90ms
+step:782/1660 train_time:71869ms step_avg:91.90ms
+step:783/1660 train_time:71962ms step_avg:91.91ms
+step:784/1660 train_time:72054ms step_avg:91.91ms
+step:785/1660 train_time:72148ms step_avg:91.91ms
+step:786/1660 train_time:72241ms step_avg:91.91ms
+step:787/1660 train_time:72333ms step_avg:91.91ms
+step:788/1660 train_time:72426ms step_avg:91.91ms
+step:789/1660 train_time:72518ms step_avg:91.91ms
+step:790/1660 train_time:72610ms step_avg:91.91ms
+step:791/1660 train_time:72703ms step_avg:91.91ms
+step:792/1660 train_time:72795ms step_avg:91.91ms
+step:793/1660 train_time:72889ms step_avg:91.91ms
+step:794/1660 train_time:72981ms step_avg:91.92ms
+step:795/1660 train_time:73074ms step_avg:91.92ms
+step:796/1660 train_time:73168ms step_avg:91.92ms
+step:797/1660 train_time:73260ms step_avg:91.92ms
+step:798/1660 train_time:73353ms step_avg:91.92ms
+step:799/1660 train_time:73445ms step_avg:91.92ms
+step:800/1660 train_time:73537ms step_avg:91.92ms
+step:801/1660 train_time:73630ms step_avg:91.92ms
+step:802/1660 train_time:73724ms step_avg:91.93ms
+step:803/1660 train_time:73816ms step_avg:91.93ms
+step:804/1660 train_time:73909ms step_avg:91.93ms
+step:805/1660 train_time:74001ms step_avg:91.93ms
+step:806/1660 train_time:74094ms step_avg:91.93ms
+step:807/1660 train_time:74187ms step_avg:91.93ms
+step:808/1660 train_time:74280ms step_avg:91.93ms
+step:809/1660 train_time:74373ms step_avg:91.93ms
+step:810/1660 train_time:74466ms step_avg:91.93ms
+step:811/1660 train_time:74558ms step_avg:91.93ms
+step:812/1660 train_time:74651ms step_avg:91.94ms
+step:813/1660 train_time:74744ms step_avg:91.94ms
+step:814/1660 train_time:74835ms step_avg:91.94ms
+step:815/1660 train_time:74928ms step_avg:91.94ms
+step:816/1660 train_time:75021ms step_avg:91.94ms
+step:817/1660 train_time:75114ms step_avg:91.94ms
+step:818/1660 train_time:75207ms step_avg:91.94ms
+step:819/1660 train_time:75300ms step_avg:91.94ms
+step:820/1660 train_time:75395ms step_avg:91.95ms
+step:821/1660 train_time:75487ms step_avg:91.95ms
+step:822/1660 train_time:75579ms step_avg:91.94ms
+step:823/1660 train_time:75671ms step_avg:91.95ms
+step:824/1660 train_time:75763ms step_avg:91.95ms
+step:825/1660 train_time:75855ms step_avg:91.95ms
+step:826/1660 train_time:75948ms step_avg:91.95ms
+step:827/1660 train_time:76040ms step_avg:91.95ms
+step:828/1660 train_time:76132ms step_avg:91.95ms
+step:829/1660 train_time:76226ms step_avg:91.95ms
+step:830/1660 train_time:76319ms step_avg:91.95ms
+step:831/1660 train_time:76411ms step_avg:91.95ms
+step:832/1660 train_time:76504ms step_avg:91.95ms
+step:833/1660 train_time:76597ms step_avg:91.95ms
+step:834/1660 train_time:76690ms step_avg:91.95ms
+step:835/1660 train_time:76783ms step_avg:91.96ms
+step:836/1660 train_time:76875ms step_avg:91.96ms
+step:837/1660 train_time:76969ms step_avg:91.96ms
+step:838/1660 train_time:77061ms step_avg:91.96ms
+step:839/1660 train_time:77154ms step_avg:91.96ms
+step:840/1660 train_time:77246ms step_avg:91.96ms
+step:841/1660 train_time:77338ms step_avg:91.96ms
+step:842/1660 train_time:77432ms step_avg:91.96ms
+step:843/1660 train_time:77526ms step_avg:91.96ms
+step:844/1660 train_time:77618ms step_avg:91.96ms
+step:845/1660 train_time:77711ms step_avg:91.97ms
+step:846/1660 train_time:77804ms step_avg:91.97ms
+step:847/1660 train_time:77897ms step_avg:91.97ms
+step:848/1660 train_time:77990ms step_avg:91.97ms
+step:849/1660 train_time:78083ms step_avg:91.97ms
+step:850/1660 train_time:78176ms step_avg:91.97ms
+step:851/1660 train_time:78268ms step_avg:91.97ms
+step:852/1660 train_time:78361ms step_avg:91.97ms
+step:853/1660 train_time:78453ms step_avg:91.97ms
+step:854/1660 train_time:78547ms step_avg:91.97ms
+step:855/1660 train_time:78639ms step_avg:91.98ms
+step:856/1660 train_time:78731ms step_avg:91.98ms
+step:857/1660 train_time:78824ms step_avg:91.98ms
+step:858/1660 train_time:78917ms step_avg:91.98ms
+step:859/1660 train_time:79009ms step_avg:91.98ms
+step:860/1660 train_time:79102ms step_avg:91.98ms
+step:861/1660 train_time:79195ms step_avg:91.98ms
+step:862/1660 train_time:79287ms step_avg:91.98ms
+step:863/1660 train_time:79380ms step_avg:91.98ms
+step:864/1660 train_time:79472ms step_avg:91.98ms
+step:865/1660 train_time:79566ms step_avg:91.98ms
+step:866/1660 train_time:79658ms step_avg:91.98ms
+step:867/1660 train_time:79752ms step_avg:91.99ms
+step:868/1660 train_time:79844ms step_avg:91.99ms
+step:869/1660 train_time:79938ms step_avg:91.99ms
+step:870/1660 train_time:80031ms step_avg:91.99ms
+step:871/1660 train_time:80123ms step_avg:91.99ms
+step:872/1660 train_time:80216ms step_avg:91.99ms
+step:873/1660 train_time:80308ms step_avg:91.99ms
+step:874/1660 train_time:80401ms step_avg:91.99ms
+step:875/1660 train_time:80494ms step_avg:91.99ms
+step:875/1660 val_loss:3.5172 train_time:80588ms step_avg:92.10ms
+step:876/1660 train_time:80609ms step_avg:92.02ms
+step:877/1660 train_time:80685ms step_avg:92.00ms
+step:878/1660 train_time:80782ms step_avg:92.01ms
+step:879/1660 train_time:80874ms step_avg:92.01ms
+step:880/1660 train_time:80965ms step_avg:92.01ms
+step:881/1660 train_time:81056ms step_avg:92.00ms
+step:882/1660 train_time:81147ms step_avg:92.00ms
+step:883/1660 train_time:81239ms step_avg:92.00ms
+step:884/1660 train_time:81331ms step_avg:92.00ms
+step:885/1660 train_time:81422ms step_avg:92.00ms
+step:886/1660 train_time:81515ms step_avg:92.00ms
+step:887/1660 train_time:81611ms step_avg:92.01ms
+step:888/1660 train_time:81708ms step_avg:92.01ms
+step:889/1660 train_time:81802ms step_avg:92.02ms
+step:890/1660 train_time:81894ms step_avg:92.02ms
+step:891/1660 train_time:81987ms step_avg:92.02ms
+step:892/1660 train_time:82079ms step_avg:92.02ms
+step:893/1660 train_time:82171ms step_avg:92.02ms
+step:894/1660 train_time:82263ms step_avg:92.02ms
+step:895/1660 train_time:82354ms step_avg:92.02ms
+step:896/1660 train_time:82446ms step_avg:92.02ms
+step:897/1660 train_time:82540ms step_avg:92.02ms
+step:898/1660 train_time:82635ms step_avg:92.02ms
+step:899/1660 train_time:82731ms step_avg:92.03ms
+step:900/1660 train_time:82824ms step_avg:92.03ms
+step:901/1660 train_time:82916ms step_avg:92.03ms
+step:902/1660 train_time:83009ms step_avg:92.03ms
+step:903/1660 train_time:83101ms step_avg:92.03ms
+step:904/1660 train_time:83193ms step_avg:92.03ms
+step:905/1660 train_time:83285ms step_avg:92.03ms
+step:906/1660 train_time:83377ms step_avg:92.03ms
+step:907/1660 train_time:83470ms step_avg:92.03ms
+step:908/1660 train_time:83563ms step_avg:92.03ms
+step:909/1660 train_time:83657ms step_avg:92.03ms
+step:910/1660 train_time:83752ms step_avg:92.04ms
+step:911/1660 train_time:83846ms step_avg:92.04ms
+step:912/1660 train_time:83938ms step_avg:92.04ms
+step:913/1660 train_time:84031ms step_avg:92.04ms
+step:914/1660 train_time:84124ms step_avg:92.04ms
+step:915/1660 train_time:84215ms step_avg:92.04ms
+step:916/1660 train_time:84308ms step_avg:92.04ms
+step:917/1660 train_time:84400ms step_avg:92.04ms
+step:918/1660 train_time:84494ms step_avg:92.04ms
+step:919/1660 train_time:84586ms step_avg:92.04ms
+step:920/1660 train_time:84679ms step_avg:92.04ms
+step:921/1660 train_time:84772ms step_avg:92.04ms
+step:922/1660 train_time:84865ms step_avg:92.04ms
+step:923/1660 train_time:84958ms step_avg:92.05ms
+step:924/1660 train_time:85051ms step_avg:92.05ms
+step:925/1660 train_time:85143ms step_avg:92.05ms
+step:926/1660 train_time:85235ms step_avg:92.05ms
+step:927/1660 train_time:85329ms step_avg:92.05ms
+step:928/1660 train_time:85421ms step_avg:92.05ms
+step:929/1660 train_time:85514ms step_avg:92.05ms
+step:930/1660 train_time:85608ms step_avg:92.05ms
+step:931/1660 train_time:85701ms step_avg:92.05ms
+step:932/1660 train_time:85793ms step_avg:92.05ms
+step:933/1660 train_time:85887ms step_avg:92.05ms
+step:934/1660 train_time:85979ms step_avg:92.05ms
+step:935/1660 train_time:86073ms step_avg:92.06ms
+step:936/1660 train_time:86166ms step_avg:92.06ms
+step:937/1660 train_time:86257ms step_avg:92.06ms
+step:938/1660 train_time:86351ms step_avg:92.06ms
+step:939/1660 train_time:86444ms step_avg:92.06ms
+step:940/1660 train_time:86537ms step_avg:92.06ms
+step:941/1660 train_time:86630ms step_avg:92.06ms
+step:942/1660 train_time:86724ms step_avg:92.06ms
+step:943/1660 train_time:86816ms step_avg:92.06ms
+step:944/1660 train_time:86908ms step_avg:92.06ms
+step:945/1660 train_time:87001ms step_avg:92.06ms
+step:946/1660 train_time:87093ms step_avg:92.06ms
+step:947/1660 train_time:87186ms step_avg:92.07ms
+step:948/1660 train_time:87278ms step_avg:92.07ms
+step:949/1660 train_time:87371ms step_avg:92.07ms
+step:950/1660 train_time:87464ms step_avg:92.07ms
+step:951/1660 train_time:87556ms step_avg:92.07ms
+step:952/1660 train_time:87649ms step_avg:92.07ms
+step:953/1660 train_time:87742ms step_avg:92.07ms
+step:954/1660 train_time:87834ms step_avg:92.07ms
+step:955/1660 train_time:87927ms step_avg:92.07ms
+step:956/1660 train_time:88019ms step_avg:92.07ms
+step:957/1660 train_time:88113ms step_avg:92.07ms
+step:958/1660 train_time:88205ms step_avg:92.07ms
+step:959/1660 train_time:88298ms step_avg:92.07ms
+step:960/1660 train_time:88393ms step_avg:92.08ms
+step:961/1660 train_time:88486ms step_avg:92.08ms
+step:962/1660 train_time:88577ms step_avg:92.08ms
+step:963/1660 train_time:88670ms step_avg:92.08ms
+step:964/1660 train_time:88763ms step_avg:92.08ms
+step:965/1660 train_time:88856ms step_avg:92.08ms
+step:966/1660 train_time:88948ms step_avg:92.08ms
+step:967/1660 train_time:89041ms step_avg:92.08ms
+step:968/1660 train_time:89134ms step_avg:92.08ms
+step:969/1660 train_time:89226ms step_avg:92.08ms
+step:970/1660 train_time:89319ms step_avg:92.08ms
+step:971/1660 train_time:89413ms step_avg:92.08ms
+step:972/1660 train_time:89506ms step_avg:92.08ms
+step:973/1660 train_time:89598ms step_avg:92.08ms
+step:974/1660 train_time:89691ms step_avg:92.09ms
+step:975/1660 train_time:89784ms step_avg:92.09ms
+step:976/1660 train_time:89876ms step_avg:92.09ms
+step:977/1660 train_time:89969ms step_avg:92.09ms
+step:978/1660 train_time:90062ms step_avg:92.09ms
+step:979/1660 train_time:90154ms step_avg:92.09ms
+step:980/1660 train_time:90246ms step_avg:92.09ms
+step:981/1660 train_time:90339ms step_avg:92.09ms
+step:982/1660 train_time:90432ms step_avg:92.09ms
+step:983/1660 train_time:90525ms step_avg:92.09ms
+step:984/1660 train_time:90617ms step_avg:92.09ms
+step:985/1660 train_time:90712ms step_avg:92.09ms
+step:986/1660 train_time:90805ms step_avg:92.09ms
+step:987/1660 train_time:90897ms step_avg:92.09ms
+step:988/1660 train_time:90989ms step_avg:92.09ms
+step:989/1660 train_time:91082ms step_avg:92.10ms
+step:990/1660 train_time:91175ms step_avg:92.10ms
+step:991/1660 train_time:91268ms step_avg:92.10ms
+step:992/1660 train_time:91362ms step_avg:92.10ms
+step:993/1660 train_time:91455ms step_avg:92.10ms
+step:994/1660 train_time:91547ms step_avg:92.10ms
+step:995/1660 train_time:91640ms step_avg:92.10ms
+step:996/1660 train_time:91733ms step_avg:92.10ms
+step:997/1660 train_time:91826ms step_avg:92.10ms
+step:998/1660 train_time:91917ms step_avg:92.10ms
+step:999/1660 train_time:92010ms step_avg:92.10ms
+step:1000/1660 train_time:92103ms step_avg:92.10ms
+step:1000/1660 val_loss:3.4688 train_time:92197ms step_avg:92.20ms
+step:1001/1660 train_time:92217ms step_avg:92.13ms
+step:1002/1660 train_time:92293ms step_avg:92.11ms
+step:1003/1660 train_time:92388ms step_avg:92.11ms
+step:1004/1660 train_time:92480ms step_avg:92.11ms
+step:1005/1660 train_time:92571ms step_avg:92.11ms
+step:1006/1660 train_time:92663ms step_avg:92.11ms
+step:1007/1660 train_time:92754ms step_avg:92.11ms
+step:1008/1660 train_time:92846ms step_avg:92.11ms
+step:1009/1660 train_time:92938ms step_avg:92.11ms
+step:1010/1660 train_time:93030ms step_avg:92.11ms
+step:1011/1660 train_time:93123ms step_avg:92.11ms
+step:1012/1660 train_time:93217ms step_avg:92.11ms
+step:1013/1660 train_time:93311ms step_avg:92.11ms
+step:1014/1660 train_time:93405ms step_avg:92.12ms
+step:1015/1660 train_time:93498ms step_avg:92.12ms
+step:1016/1660 train_time:93590ms step_avg:92.12ms
+step:1017/1660 train_time:93682ms step_avg:92.12ms
+step:1018/1660 train_time:93774ms step_avg:92.12ms
+step:1019/1660 train_time:93866ms step_avg:92.12ms
+step:1020/1660 train_time:93958ms step_avg:92.12ms
+step:1021/1660 train_time:94050ms step_avg:92.12ms
+step:1022/1660 train_time:94144ms step_avg:92.12ms
+step:1023/1660 train_time:94237ms step_avg:92.12ms
+step:1024/1660 train_time:94332ms step_avg:92.12ms
+step:1025/1660 train_time:94425ms step_avg:92.12ms
+step:1026/1660 train_time:94517ms step_avg:92.12ms
+step:1027/1660 train_time:94609ms step_avg:92.12ms
+step:1028/1660 train_time:94702ms step_avg:92.12ms
+step:1029/1660 train_time:94794ms step_avg:92.12ms
+step:1030/1660 train_time:94886ms step_avg:92.12ms
+step:1031/1660 train_time:94979ms step_avg:92.12ms
+step:1032/1660 train_time:95071ms step_avg:92.12ms
+step:1033/1660 train_time:95164ms step_avg:92.12ms
+step:1034/1660 train_time:95257ms step_avg:92.13ms
+step:1035/1660 train_time:95351ms step_avg:92.13ms
+step:1036/1660 train_time:95444ms step_avg:92.13ms
+step:1037/1660 train_time:95537ms step_avg:92.13ms
+step:1038/1660 train_time:95630ms step_avg:92.13ms
+step:1039/1660 train_time:95723ms step_avg:92.13ms
+step:1040/1660 train_time:95815ms step_avg:92.13ms
+step:1041/1660 train_time:95907ms step_avg:92.13ms
+step:1042/1660 train_time:96000ms step_avg:92.13ms
+step:1043/1660 train_time:96093ms step_avg:92.13ms
+step:1044/1660 train_time:96186ms step_avg:92.13ms
+step:1045/1660 train_time:96279ms step_avg:92.13ms
+step:1046/1660 train_time:96374ms step_avg:92.14ms
+step:1047/1660 train_time:96467ms step_avg:92.14ms
+step:1048/1660 train_time:96560ms step_avg:92.14ms
+step:1049/1660 train_time:96653ms step_avg:92.14ms
+step:1050/1660 train_time:96745ms step_avg:92.14ms
+step:1051/1660 train_time:96837ms step_avg:92.14ms
+step:1052/1660 train_time:96930ms step_avg:92.14ms
+step:1053/1660 train_time:97023ms step_avg:92.14ms
+step:1054/1660 train_time:97115ms step_avg:92.14ms
+step:1055/1660 train_time:97209ms step_avg:92.14ms
+step:1056/1660 train_time:97302ms step_avg:92.14ms
+step:1057/1660 train_time:97394ms step_avg:92.14ms
+step:1058/1660 train_time:97487ms step_avg:92.14ms
+step:1059/1660 train_time:97579ms step_avg:92.14ms
+step:1060/1660 train_time:97672ms step_avg:92.14ms
+step:1061/1660 train_time:97765ms step_avg:92.14ms
+step:1062/1660 train_time:97857ms step_avg:92.14ms
+step:1063/1660 train_time:97952ms step_avg:92.15ms
+step:1064/1660 train_time:98044ms step_avg:92.15ms
+step:1065/1660 train_time:98137ms step_avg:92.15ms
+step:1066/1660 train_time:98231ms step_avg:92.15ms
+step:1067/1660 train_time:98324ms step_avg:92.15ms
+step:1068/1660 train_time:98416ms step_avg:92.15ms
+step:1069/1660 train_time:98509ms step_avg:92.15ms
+step:1070/1660 train_time:98602ms step_avg:92.15ms
+step:1071/1660 train_time:98695ms step_avg:92.15ms
+step:1072/1660 train_time:98787ms step_avg:92.15ms
+step:1073/1660 train_time:98880ms step_avg:92.15ms
+step:1074/1660 train_time:98973ms step_avg:92.15ms
+step:1075/1660 train_time:99066ms step_avg:92.15ms
+step:1076/1660 train_time:99157ms step_avg:92.15ms
+step:1077/1660 train_time:99251ms step_avg:92.15ms
+step:1078/1660 train_time:99344ms step_avg:92.16ms
+step:1079/1660 train_time:99436ms step_avg:92.16ms
+step:1080/1660 train_time:99529ms step_avg:92.16ms
+step:1081/1660 train_time:99622ms step_avg:92.16ms
+step:1082/1660 train_time:99714ms step_avg:92.16ms
+step:1083/1660 train_time:99807ms step_avg:92.16ms
+step:1084/1660 train_time:99900ms step_avg:92.16ms
+step:1085/1660 train_time:99993ms step_avg:92.16ms
+step:1086/1660 train_time:100086ms step_avg:92.16ms
+step:1087/1660 train_time:100178ms step_avg:92.16ms
+step:1088/1660 train_time:100272ms step_avg:92.16ms
+step:1089/1660 train_time:100365ms step_avg:92.16ms
+step:1090/1660 train_time:100457ms step_avg:92.16ms
+step:1091/1660 train_time:100550ms step_avg:92.16ms
+step:1092/1660 train_time:100644ms step_avg:92.16ms
+step:1093/1660 train_time:100736ms step_avg:92.16ms
+step:1094/1660 train_time:100830ms step_avg:92.17ms
+step:1095/1660 train_time:100922ms step_avg:92.17ms
+step:1096/1660 train_time:101014ms step_avg:92.17ms
+step:1097/1660 train_time:101107ms step_avg:92.17ms
+step:1098/1660 train_time:101200ms step_avg:92.17ms
+step:1099/1660 train_time:101293ms step_avg:92.17ms
+step:1100/1660 train_time:101386ms step_avg:92.17ms
+step:1101/1660 train_time:101478ms step_avg:92.17ms
+step:1102/1660 train_time:101572ms step_avg:92.17ms
+step:1103/1660 train_time:101664ms step_avg:92.17ms
+step:1104/1660 train_time:101756ms step_avg:92.17ms
+step:1105/1660 train_time:101849ms step_avg:92.17ms
+step:1106/1660 train_time:101942ms step_avg:92.17ms
+step:1107/1660 train_time:102034ms step_avg:92.17ms
+step:1108/1660 train_time:102128ms step_avg:92.17ms
+step:1109/1660 train_time:102220ms step_avg:92.17ms
+step:1110/1660 train_time:102315ms step_avg:92.18ms
+step:1111/1660 train_time:102408ms step_avg:92.18ms
+step:1112/1660 train_time:102501ms step_avg:92.18ms
+step:1113/1660 train_time:102594ms step_avg:92.18ms
+step:1114/1660 train_time:102688ms step_avg:92.18ms
+step:1115/1660 train_time:102781ms step_avg:92.18ms
+step:1116/1660 train_time:102875ms step_avg:92.18ms
+step:1117/1660 train_time:102968ms step_avg:92.18ms
+step:1118/1660 train_time:103061ms step_avg:92.18ms
+step:1119/1660 train_time:103155ms step_avg:92.18ms
+step:1120/1660 train_time:103249ms step_avg:92.19ms
+step:1121/1660 train_time:103342ms step_avg:92.19ms
+step:1122/1660 train_time:103436ms step_avg:92.19ms
+step:1123/1660 train_time:103530ms step_avg:92.19ms
+step:1124/1660 train_time:103623ms step_avg:92.19ms
+step:1125/1660 train_time:103717ms step_avg:92.19ms
+step:1125/1660 val_loss:3.4139 train_time:103812ms step_avg:92.28ms
+step:1126/1660 train_time:103832ms step_avg:92.21ms
+step:1127/1660 train_time:103907ms step_avg:92.20ms
+step:1128/1660 train_time:104006ms step_avg:92.20ms
+step:1129/1660 train_time:104100ms step_avg:92.21ms
+step:1130/1660 train_time:104192ms step_avg:92.21ms
+step:1131/1660 train_time:104284ms step_avg:92.21ms
+step:1132/1660 train_time:104376ms step_avg:92.21ms
+step:1133/1660 train_time:104468ms step_avg:92.20ms
+step:1134/1660 train_time:104560ms step_avg:92.20ms
+step:1135/1660 train_time:104652ms step_avg:92.20ms
+step:1136/1660 train_time:104747ms step_avg:92.21ms
+step:1137/1660 train_time:104845ms step_avg:92.21ms
+step:1138/1660 train_time:104940ms step_avg:92.21ms
+step:1139/1660 train_time:105035ms step_avg:92.22ms
+step:1140/1660 train_time:105129ms step_avg:92.22ms
+step:1141/1660 train_time:105221ms step_avg:92.22ms
+step:1142/1660 train_time:105314ms step_avg:92.22ms
+step:1143/1660 train_time:105406ms step_avg:92.22ms
+step:1144/1660 train_time:105500ms step_avg:92.22ms
+step:1145/1660 train_time:105591ms step_avg:92.22ms
+step:1146/1660 train_time:105685ms step_avg:92.22ms
+step:1147/1660 train_time:105779ms step_avg:92.22ms
+step:1148/1660 train_time:105873ms step_avg:92.22ms
+step:1149/1660 train_time:105968ms step_avg:92.23ms
+step:1150/1660 train_time:106061ms step_avg:92.23ms
+step:1151/1660 train_time:106155ms step_avg:92.23ms
+step:1152/1660 train_time:106247ms step_avg:92.23ms
+step:1153/1660 train_time:106340ms step_avg:92.23ms
+step:1154/1660 train_time:106432ms step_avg:92.23ms
+step:1155/1660 train_time:106525ms step_avg:92.23ms
+step:1156/1660 train_time:106619ms step_avg:92.23ms
+step:1157/1660 train_time:106711ms step_avg:92.23ms
+step:1158/1660 train_time:106805ms step_avg:92.23ms
+step:1159/1660 train_time:106899ms step_avg:92.23ms
+step:1160/1660 train_time:106992ms step_avg:92.23ms
+step:1161/1660 train_time:107087ms step_avg:92.24ms
+step:1162/1660 train_time:107180ms step_avg:92.24ms
+step:1163/1660 train_time:107273ms step_avg:92.24ms
+step:1164/1660 train_time:107366ms step_avg:92.24ms
+step:1165/1660 train_time:107458ms step_avg:92.24ms
+step:1166/1660 train_time:107551ms step_avg:92.24ms
+step:1167/1660 train_time:107645ms step_avg:92.24ms
+step:1168/1660 train_time:107738ms step_avg:92.24ms
+step:1169/1660 train_time:107831ms step_avg:92.24ms
+step:1170/1660 train_time:107926ms step_avg:92.24ms
+step:1171/1660 train_time:108020ms step_avg:92.25ms
+step:1172/1660 train_time:108113ms step_avg:92.25ms
+step:1173/1660 train_time:108207ms step_avg:92.25ms
+step:1174/1660 train_time:108300ms step_avg:92.25ms
+step:1175/1660 train_time:108392ms step_avg:92.25ms
+step:1176/1660 train_time:108485ms step_avg:92.25ms
+step:1177/1660 train_time:108579ms step_avg:92.25ms
+step:1178/1660 train_time:108671ms step_avg:92.25ms
+step:1179/1660 train_time:108765ms step_avg:92.25ms
+step:1180/1660 train_time:108858ms step_avg:92.25ms
+step:1181/1660 train_time:108951ms step_avg:92.25ms
+step:1182/1660 train_time:109045ms step_avg:92.25ms
+step:1183/1660 train_time:109138ms step_avg:92.26ms
+step:1184/1660 train_time:109231ms step_avg:92.26ms
+step:1185/1660 train_time:109324ms step_avg:92.26ms
+step:1186/1660 train_time:109417ms step_avg:92.26ms
+step:1187/1660 train_time:109509ms step_avg:92.26ms
+step:1188/1660 train_time:109602ms step_avg:92.26ms
+step:1189/1660 train_time:109695ms step_avg:92.26ms
+step:1190/1660 train_time:109789ms step_avg:92.26ms
+step:1191/1660 train_time:109883ms step_avg:92.26ms
+step:1192/1660 train_time:109977ms step_avg:92.26ms
+step:1193/1660 train_time:110070ms step_avg:92.26ms
+step:1194/1660 train_time:110162ms step_avg:92.26ms
+step:1195/1660 train_time:110255ms step_avg:92.26ms
+step:1196/1660 train_time:110349ms step_avg:92.26ms
+step:1197/1660 train_time:110442ms step_avg:92.27ms
+step:1198/1660 train_time:110534ms step_avg:92.27ms
+step:1199/1660 train_time:110628ms step_avg:92.27ms
+step:1200/1660 train_time:110721ms step_avg:92.27ms
+step:1201/1660 train_time:110813ms step_avg:92.27ms
+step:1202/1660 train_time:110907ms step_avg:92.27ms
+step:1203/1660 train_time:111000ms step_avg:92.27ms
+step:1204/1660 train_time:111093ms step_avg:92.27ms
+step:1205/1660 train_time:111187ms step_avg:92.27ms
+step:1206/1660 train_time:111280ms step_avg:92.27ms
+step:1207/1660 train_time:111372ms step_avg:92.27ms
+step:1208/1660 train_time:111465ms step_avg:92.27ms
+step:1209/1660 train_time:111559ms step_avg:92.27ms
+step:1210/1660 train_time:111652ms step_avg:92.27ms
+step:1211/1660 train_time:111745ms step_avg:92.27ms
+step:1212/1660 train_time:111837ms step_avg:92.28ms
+step:1213/1660 train_time:111931ms step_avg:92.28ms
+step:1214/1660 train_time:112025ms step_avg:92.28ms
+step:1215/1660 train_time:112118ms step_avg:92.28ms
+step:1216/1660 train_time:112212ms step_avg:92.28ms
+step:1217/1660 train_time:112305ms step_avg:92.28ms
+step:1218/1660 train_time:112399ms step_avg:92.28ms
+step:1219/1660 train_time:112491ms step_avg:92.28ms
+step:1220/1660 train_time:112584ms step_avg:92.28ms
+step:1221/1660 train_time:112678ms step_avg:92.28ms
+step:1222/1660 train_time:112770ms step_avg:92.28ms
+step:1223/1660 train_time:112863ms step_avg:92.28ms
+step:1224/1660 train_time:112956ms step_avg:92.28ms
+step:1225/1660 train_time:113049ms step_avg:92.29ms
+step:1226/1660 train_time:113143ms step_avg:92.29ms
+step:1227/1660 train_time:113236ms step_avg:92.29ms
+step:1228/1660 train_time:113330ms step_avg:92.29ms
+step:1229/1660 train_time:113422ms step_avg:92.29ms
+step:1230/1660 train_time:113515ms step_avg:92.29ms
+step:1231/1660 train_time:113609ms step_avg:92.29ms
+step:1232/1660 train_time:113702ms step_avg:92.29ms
+step:1233/1660 train_time:113795ms step_avg:92.29ms
+step:1234/1660 train_time:113888ms step_avg:92.29ms
+step:1235/1660 train_time:113981ms step_avg:92.29ms
+step:1236/1660 train_time:114074ms step_avg:92.29ms
+step:1237/1660 train_time:114169ms step_avg:92.29ms
+step:1238/1660 train_time:114263ms step_avg:92.30ms
+step:1239/1660 train_time:114356ms step_avg:92.30ms
+step:1240/1660 train_time:114449ms step_avg:92.30ms
+step:1241/1660 train_time:114543ms step_avg:92.30ms
+step:1242/1660 train_time:114636ms step_avg:92.30ms
+step:1243/1660 train_time:114729ms step_avg:92.30ms
+step:1244/1660 train_time:114822ms step_avg:92.30ms
+step:1245/1660 train_time:114916ms step_avg:92.30ms
+step:1246/1660 train_time:115009ms step_avg:92.30ms
+step:1247/1660 train_time:115102ms step_avg:92.30ms
+step:1248/1660 train_time:115195ms step_avg:92.30ms
+step:1249/1660 train_time:115288ms step_avg:92.30ms
+step:1250/1660 train_time:115382ms step_avg:92.31ms
+step:1250/1660 val_loss:3.3752 train_time:115476ms step_avg:92.38ms
+step:1251/1660 train_time:115496ms step_avg:92.32ms
+step:1252/1660 train_time:115573ms step_avg:92.31ms
+step:1253/1660 train_time:115669ms step_avg:92.31ms
+step:1254/1660 train_time:115762ms step_avg:92.31ms
+step:1255/1660 train_time:115855ms step_avg:92.31ms
+step:1256/1660 train_time:115947ms step_avg:92.31ms
+step:1257/1660 train_time:116039ms step_avg:92.31ms
+step:1258/1660 train_time:116131ms step_avg:92.31ms
+step:1259/1660 train_time:116224ms step_avg:92.31ms
+step:1260/1660 train_time:116315ms step_avg:92.31ms
+step:1261/1660 train_time:116410ms step_avg:92.32ms
+step:1262/1660 train_time:116507ms step_avg:92.32ms
+step:1263/1660 train_time:116601ms step_avg:92.32ms
+step:1264/1660 train_time:116695ms step_avg:92.32ms
+step:1265/1660 train_time:116789ms step_avg:92.32ms
+step:1266/1660 train_time:116882ms step_avg:92.32ms
+step:1267/1660 train_time:116974ms step_avg:92.32ms
+step:1268/1660 train_time:117067ms step_avg:92.32ms
+step:1269/1660 train_time:117159ms step_avg:92.32ms
+step:1270/1660 train_time:117251ms step_avg:92.32ms
+step:1271/1660 train_time:117344ms step_avg:92.32ms
+step:1272/1660 train_time:117437ms step_avg:92.32ms
+step:1273/1660 train_time:117532ms step_avg:92.33ms
+step:1274/1660 train_time:117626ms step_avg:92.33ms
+step:1275/1660 train_time:117720ms step_avg:92.33ms
+step:1276/1660 train_time:117814ms step_avg:92.33ms
+step:1277/1660 train_time:117906ms step_avg:92.33ms
+step:1278/1660 train_time:117999ms step_avg:92.33ms
+step:1279/1660 train_time:118091ms step_avg:92.33ms
+step:1280/1660 train_time:118183ms step_avg:92.33ms
+step:1281/1660 train_time:118276ms step_avg:92.33ms
+step:1282/1660 train_time:118370ms step_avg:92.33ms
+step:1283/1660 train_time:118463ms step_avg:92.33ms
+step:1284/1660 train_time:118557ms step_avg:92.33ms
+step:1285/1660 train_time:118652ms step_avg:92.34ms
+step:1286/1660 train_time:118745ms step_avg:92.34ms
+step:1287/1660 train_time:118837ms step_avg:92.34ms
+step:1288/1660 train_time:118931ms step_avg:92.34ms
+step:1289/1660 train_time:119024ms step_avg:92.34ms
+step:1290/1660 train_time:119117ms step_avg:92.34ms
+step:1291/1660 train_time:119210ms step_avg:92.34ms
+step:1292/1660 train_time:119303ms step_avg:92.34ms
+step:1293/1660 train_time:119397ms step_avg:92.34ms
+step:1294/1660 train_time:119490ms step_avg:92.34ms
+step:1295/1660 train_time:119584ms step_avg:92.34ms
+step:1296/1660 train_time:119677ms step_avg:92.34ms
+step:1297/1660 train_time:119771ms step_avg:92.34ms
+step:1298/1660 train_time:119864ms step_avg:92.34ms
+step:1299/1660 train_time:119957ms step_avg:92.35ms
+step:1300/1660 train_time:120050ms step_avg:92.35ms
+step:1301/1660 train_time:120143ms step_avg:92.35ms
+step:1302/1660 train_time:120236ms step_avg:92.35ms
+step:1303/1660 train_time:120329ms step_avg:92.35ms
+step:1304/1660 train_time:120422ms step_avg:92.35ms
+step:1305/1660 train_time:120518ms step_avg:92.35ms
+step:1306/1660 train_time:120612ms step_avg:92.35ms
+step:1307/1660 train_time:120705ms step_avg:92.35ms
+step:1308/1660 train_time:120798ms step_avg:92.35ms
+step:1309/1660 train_time:120891ms step_avg:92.35ms
+step:1310/1660 train_time:120984ms step_avg:92.35ms
+step:1311/1660 train_time:121077ms step_avg:92.35ms
+step:1312/1660 train_time:121171ms step_avg:92.36ms
+step:1313/1660 train_time:121264ms step_avg:92.36ms
+step:1314/1660 train_time:121357ms step_avg:92.36ms
+step:1315/1660 train_time:121450ms step_avg:92.36ms
+step:1316/1660 train_time:121543ms step_avg:92.36ms
+step:1317/1660 train_time:121637ms step_avg:92.36ms
+step:1318/1660 train_time:121730ms step_avg:92.36ms
+step:1319/1660 train_time:121824ms step_avg:92.36ms
+step:1320/1660 train_time:121918ms step_avg:92.36ms
+step:1321/1660 train_time:122011ms step_avg:92.36ms
+step:1322/1660 train_time:122104ms step_avg:92.36ms
+step:1323/1660 train_time:122197ms step_avg:92.36ms
+step:1324/1660 train_time:122290ms step_avg:92.36ms
+step:1325/1660 train_time:122384ms step_avg:92.37ms
+step:1326/1660 train_time:122478ms step_avg:92.37ms
+step:1327/1660 train_time:122570ms step_avg:92.37ms
+step:1328/1660 train_time:122663ms step_avg:92.37ms
+step:1329/1660 train_time:122758ms step_avg:92.37ms
+step:1330/1660 train_time:122851ms step_avg:92.37ms
+step:1331/1660 train_time:122945ms step_avg:92.37ms
+step:1332/1660 train_time:123038ms step_avg:92.37ms
+step:1333/1660 train_time:123131ms step_avg:92.37ms
+step:1334/1660 train_time:123223ms step_avg:92.37ms
+step:1335/1660 train_time:123317ms step_avg:92.37ms
+step:1336/1660 train_time:123411ms step_avg:92.37ms
+step:1337/1660 train_time:123504ms step_avg:92.37ms
+step:1338/1660 train_time:123597ms step_avg:92.37ms
+step:1339/1660 train_time:123690ms step_avg:92.37ms
+step:1340/1660 train_time:123783ms step_avg:92.38ms
+step:1341/1660 train_time:123876ms step_avg:92.38ms
+step:1342/1660 train_time:123970ms step_avg:92.38ms
+step:1343/1660 train_time:124064ms step_avg:92.38ms
+step:1344/1660 train_time:124156ms step_avg:92.38ms
+step:1345/1660 train_time:124249ms step_avg:92.38ms
+step:1346/1660 train_time:124343ms step_avg:92.38ms
+step:1347/1660 train_time:124436ms step_avg:92.38ms
+step:1348/1660 train_time:124529ms step_avg:92.38ms
+step:1349/1660 train_time:124622ms step_avg:92.38ms
+step:1350/1660 train_time:124717ms step_avg:92.38ms
+step:1351/1660 train_time:124811ms step_avg:92.38ms
+step:1352/1660 train_time:124904ms step_avg:92.38ms
+step:1353/1660 train_time:124998ms step_avg:92.39ms
+step:1354/1660 train_time:125091ms step_avg:92.39ms
+step:1355/1660 train_time:125184ms step_avg:92.39ms
+step:1356/1660 train_time:125276ms step_avg:92.39ms
+step:1357/1660 train_time:125369ms step_avg:92.39ms
+step:1358/1660 train_time:125462ms step_avg:92.39ms
+step:1359/1660 train_time:125555ms step_avg:92.39ms
+step:1360/1660 train_time:125648ms step_avg:92.39ms
+step:1361/1660 train_time:125741ms step_avg:92.39ms
+step:1362/1660 train_time:125835ms step_avg:92.39ms
+step:1363/1660 train_time:125929ms step_avg:92.39ms
+step:1364/1660 train_time:126022ms step_avg:92.39ms
+step:1365/1660 train_time:126117ms step_avg:92.39ms
+step:1366/1660 train_time:126211ms step_avg:92.39ms
+step:1367/1660 train_time:126304ms step_avg:92.40ms
+step:1368/1660 train_time:126396ms step_avg:92.40ms
+step:1369/1660 train_time:126489ms step_avg:92.40ms
+step:1370/1660 train_time:126583ms step_avg:92.40ms
+step:1371/1660 train_time:126676ms step_avg:92.40ms
+step:1372/1660 train_time:126769ms step_avg:92.40ms
+step:1373/1660 train_time:126863ms step_avg:92.40ms
+step:1374/1660 train_time:126956ms step_avg:92.40ms
+step:1375/1660 train_time:127050ms step_avg:92.40ms
+step:1375/1660 val_loss:3.3410 train_time:127145ms step_avg:92.47ms
+step:1376/1660 train_time:127165ms step_avg:92.42ms
+step:1377/1660 train_time:127244ms step_avg:92.41ms
+step:1378/1660 train_time:127343ms step_avg:92.41ms
+step:1379/1660 train_time:127435ms step_avg:92.41ms
+step:1380/1660 train_time:127527ms step_avg:92.41ms
+step:1381/1660 train_time:127619ms step_avg:92.41ms
+step:1382/1660 train_time:127711ms step_avg:92.41ms
+step:1383/1660 train_time:127804ms step_avg:92.41ms
+step:1384/1660 train_time:127897ms step_avg:92.41ms
+step:1385/1660 train_time:127990ms step_avg:92.41ms
+step:1386/1660 train_time:128083ms step_avg:92.41ms
+step:1387/1660 train_time:128179ms step_avg:92.41ms
+step:1388/1660 train_time:128275ms step_avg:92.42ms
+step:1389/1660 train_time:128370ms step_avg:92.42ms
+step:1390/1660 train_time:128464ms step_avg:92.42ms
+step:1391/1660 train_time:128556ms step_avg:92.42ms
+step:1392/1660 train_time:128648ms step_avg:92.42ms
+step:1393/1660 train_time:128741ms step_avg:92.42ms
+step:1394/1660 train_time:128833ms step_avg:92.42ms
+step:1395/1660 train_time:128925ms step_avg:92.42ms
+step:1396/1660 train_time:129018ms step_avg:92.42ms
+step:1397/1660 train_time:129111ms step_avg:92.42ms
+step:1398/1660 train_time:129205ms step_avg:92.42ms
+step:1399/1660 train_time:129299ms step_avg:92.42ms
+step:1400/1660 train_time:129393ms step_avg:92.42ms
+step:1401/1660 train_time:129490ms step_avg:92.43ms
+step:1402/1660 train_time:129582ms step_avg:92.43ms
+step:1403/1660 train_time:129675ms step_avg:92.43ms
+step:1404/1660 train_time:129767ms step_avg:92.43ms
+step:1405/1660 train_time:129859ms step_avg:92.43ms
+step:1406/1660 train_time:129951ms step_avg:92.43ms
+step:1407/1660 train_time:130044ms step_avg:92.43ms
+step:1408/1660 train_time:130138ms step_avg:92.43ms
+step:1409/1660 train_time:130232ms step_avg:92.43ms
+step:1410/1660 train_time:130326ms step_avg:92.43ms
+step:1411/1660 train_time:130420ms step_avg:92.43ms
+step:1412/1660 train_time:130513ms step_avg:92.43ms
+step:1413/1660 train_time:130606ms step_avg:92.43ms
+step:1414/1660 train_time:130699ms step_avg:92.43ms
+step:1415/1660 train_time:130791ms step_avg:92.43ms
+step:1416/1660 train_time:130884ms step_avg:92.43ms
+step:1417/1660 train_time:130977ms step_avg:92.43ms
+step:1418/1660 train_time:131070ms step_avg:92.43ms
+step:1419/1660 train_time:131163ms step_avg:92.43ms
+step:1420/1660 train_time:131257ms step_avg:92.43ms
+step:1421/1660 train_time:131351ms step_avg:92.44ms
+step:1422/1660 train_time:131445ms step_avg:92.44ms
+step:1423/1660 train_time:131538ms step_avg:92.44ms
+step:1424/1660 train_time:131632ms step_avg:92.44ms
+step:1425/1660 train_time:131724ms step_avg:92.44ms
+step:1426/1660 train_time:131817ms step_avg:92.44ms
+step:1427/1660 train_time:131909ms step_avg:92.44ms
+step:1428/1660 train_time:132002ms step_avg:92.44ms
+step:1429/1660 train_time:132094ms step_avg:92.44ms
+step:1430/1660 train_time:132188ms step_avg:92.44ms
+step:1431/1660 train_time:132282ms step_avg:92.44ms
+step:1432/1660 train_time:132376ms step_avg:92.44ms
+step:1433/1660 train_time:132470ms step_avg:92.44ms
+step:1434/1660 train_time:132563ms step_avg:92.44ms
+step:1435/1660 train_time:132656ms step_avg:92.44ms
+step:1436/1660 train_time:132749ms step_avg:92.44ms
+step:1437/1660 train_time:132842ms step_avg:92.44ms
+step:1438/1660 train_time:132935ms step_avg:92.44ms
+step:1439/1660 train_time:133028ms step_avg:92.44ms
+step:1440/1660 train_time:133121ms step_avg:92.45ms
+step:1441/1660 train_time:133213ms step_avg:92.45ms
+step:1442/1660 train_time:133307ms step_avg:92.45ms
+step:1443/1660 train_time:133400ms step_avg:92.45ms
+step:1444/1660 train_time:133494ms step_avg:92.45ms
+step:1445/1660 train_time:133588ms step_avg:92.45ms
+step:1446/1660 train_time:133681ms step_avg:92.45ms
+step:1447/1660 train_time:133775ms step_avg:92.45ms
+step:1448/1660 train_time:133867ms step_avg:92.45ms
+step:1449/1660 train_time:133960ms step_avg:92.45ms
+step:1450/1660 train_time:134053ms step_avg:92.45ms
+step:1451/1660 train_time:134147ms step_avg:92.45ms
+step:1452/1660 train_time:134240ms step_avg:92.45ms
+step:1453/1660 train_time:134333ms step_avg:92.45ms
+step:1454/1660 train_time:134427ms step_avg:92.45ms
+step:1455/1660 train_time:134522ms step_avg:92.45ms
+step:1456/1660 train_time:134615ms step_avg:92.46ms
+step:1457/1660 train_time:134709ms step_avg:92.46ms
+step:1458/1660 train_time:134802ms step_avg:92.46ms
+step:1459/1660 train_time:134895ms step_avg:92.46ms
+step:1460/1660 train_time:134989ms step_avg:92.46ms
+step:1461/1660 train_time:135082ms step_avg:92.46ms
+step:1462/1660 train_time:135176ms step_avg:92.46ms
+step:1463/1660 train_time:135269ms step_avg:92.46ms
+step:1464/1660 train_time:135362ms step_avg:92.46ms
+step:1465/1660 train_time:135455ms step_avg:92.46ms
+step:1466/1660 train_time:135548ms step_avg:92.46ms
+step:1467/1660 train_time:135641ms step_avg:92.46ms
+step:1468/1660 train_time:135735ms step_avg:92.46ms
+step:1469/1660 train_time:135828ms step_avg:92.46ms
+step:1470/1660 train_time:135922ms step_avg:92.46ms
+step:1471/1660 train_time:136014ms step_avg:92.46ms
+step:1472/1660 train_time:136109ms step_avg:92.47ms
+step:1473/1660 train_time:136202ms step_avg:92.47ms
+step:1474/1660 train_time:136296ms step_avg:92.47ms
+step:1475/1660 train_time:136389ms step_avg:92.47ms
+step:1476/1660 train_time:136482ms step_avg:92.47ms
+step:1477/1660 train_time:136576ms step_avg:92.47ms
+step:1478/1660 train_time:136668ms step_avg:92.47ms
+step:1479/1660 train_time:136762ms step_avg:92.47ms
+step:1480/1660 train_time:136855ms step_avg:92.47ms
+step:1481/1660 train_time:136949ms step_avg:92.47ms
+step:1482/1660 train_time:137043ms step_avg:92.47ms
+step:1483/1660 train_time:137135ms step_avg:92.47ms
step_avg:92.47ms +step:1484/1660 train_time:137228ms step_avg:92.47ms +step:1485/1660 train_time:137322ms step_avg:92.47ms +step:1486/1660 train_time:137415ms step_avg:92.47ms +step:1487/1660 train_time:137508ms step_avg:92.47ms +step:1488/1660 train_time:137601ms step_avg:92.47ms +step:1489/1660 train_time:137693ms step_avg:92.47ms +step:1490/1660 train_time:137787ms step_avg:92.47ms +step:1491/1660 train_time:137881ms step_avg:92.48ms +step:1492/1660 train_time:137974ms step_avg:92.48ms +step:1493/1660 train_time:138068ms step_avg:92.48ms +step:1494/1660 train_time:138161ms step_avg:92.48ms +step:1495/1660 train_time:138254ms step_avg:92.48ms +step:1496/1660 train_time:138348ms step_avg:92.48ms +step:1497/1660 train_time:138441ms step_avg:92.48ms +step:1498/1660 train_time:138534ms step_avg:92.48ms +step:1499/1660 train_time:138627ms step_avg:92.48ms +step:1500/1660 train_time:138720ms step_avg:92.48ms +step:1500/1660 val_loss:3.3111 train_time:138815ms step_avg:92.54ms +step:1501/1660 train_time:138836ms step_avg:92.50ms +step:1502/1660 train_time:138913ms step_avg:92.49ms +step:1503/1660 train_time:139008ms step_avg:92.49ms +step:1504/1660 train_time:139102ms step_avg:92.49ms +step:1505/1660 train_time:139194ms step_avg:92.49ms +step:1506/1660 train_time:139285ms step_avg:92.49ms +step:1507/1660 train_time:139378ms step_avg:92.49ms +step:1508/1660 train_time:139470ms step_avg:92.49ms +step:1509/1660 train_time:139562ms step_avg:92.49ms +step:1510/1660 train_time:139655ms step_avg:92.49ms +step:1511/1660 train_time:139749ms step_avg:92.49ms +step:1512/1660 train_time:139845ms step_avg:92.49ms +step:1513/1660 train_time:139940ms step_avg:92.49ms +step:1514/1660 train_time:140035ms step_avg:92.49ms +step:1515/1660 train_time:140128ms step_avg:92.49ms +step:1516/1660 train_time:140222ms step_avg:92.49ms +step:1517/1660 train_time:140314ms step_avg:92.49ms +step:1518/1660 train_time:140406ms step_avg:92.49ms +step:1519/1660 train_time:140498ms step_avg:92.49ms +step:1520/1660 train_time:140590ms step_avg:92.49ms +step:1521/1660 train_time:140683ms step_avg:92.49ms +step:1522/1660 train_time:140777ms step_avg:92.50ms +step:1523/1660 train_time:140873ms step_avg:92.50ms +step:1524/1660 train_time:140966ms step_avg:92.50ms +step:1525/1660 train_time:141060ms step_avg:92.50ms +step:1526/1660 train_time:141154ms step_avg:92.50ms +step:1527/1660 train_time:141247ms step_avg:92.50ms +step:1528/1660 train_time:141339ms step_avg:92.50ms +step:1529/1660 train_time:141431ms step_avg:92.50ms +step:1530/1660 train_time:141523ms step_avg:92.50ms +step:1531/1660 train_time:141615ms step_avg:92.50ms +step:1532/1660 train_time:141709ms step_avg:92.50ms +step:1533/1660 train_time:141802ms step_avg:92.50ms +step:1534/1660 train_time:141896ms step_avg:92.50ms +step:1535/1660 train_time:141990ms step_avg:92.50ms +step:1536/1660 train_time:142084ms step_avg:92.50ms +step:1537/1660 train_time:142177ms step_avg:92.50ms +step:1538/1660 train_time:142271ms step_avg:92.50ms +step:1539/1660 train_time:142363ms step_avg:92.50ms +step:1540/1660 train_time:142456ms step_avg:92.50ms +step:1541/1660 train_time:142548ms step_avg:92.50ms +step:1542/1660 train_time:142641ms step_avg:92.50ms +step:1543/1660 train_time:142735ms step_avg:92.51ms +step:1544/1660 train_time:142829ms step_avg:92.51ms +step:1545/1660 train_time:142924ms step_avg:92.51ms +step:1546/1660 train_time:143018ms step_avg:92.51ms +step:1547/1660 train_time:143112ms step_avg:92.51ms +step:1548/1660 train_time:143205ms step_avg:92.51ms +step:1549/1660 
train_time:143298ms step_avg:92.51ms +step:1550/1660 train_time:143391ms step_avg:92.51ms +step:1551/1660 train_time:143483ms step_avg:92.51ms +step:1552/1660 train_time:143577ms step_avg:92.51ms +step:1553/1660 train_time:143670ms step_avg:92.51ms +step:1554/1660 train_time:143763ms step_avg:92.51ms +step:1555/1660 train_time:143857ms step_avg:92.51ms +step:1556/1660 train_time:143952ms step_avg:92.51ms +step:1557/1660 train_time:144045ms step_avg:92.51ms +step:1558/1660 train_time:144140ms step_avg:92.52ms +step:1559/1660 train_time:144233ms step_avg:92.52ms +step:1560/1660 train_time:144326ms step_avg:92.52ms +step:1561/1660 train_time:144420ms step_avg:92.52ms +step:1562/1660 train_time:144513ms step_avg:92.52ms +step:1563/1660 train_time:144606ms step_avg:92.52ms +step:1564/1660 train_time:144699ms step_avg:92.52ms +step:1565/1660 train_time:144793ms step_avg:92.52ms +step:1566/1660 train_time:144886ms step_avg:92.52ms +step:1567/1660 train_time:144979ms step_avg:92.52ms +step:1568/1660 train_time:145073ms step_avg:92.52ms +step:1569/1660 train_time:145166ms step_avg:92.52ms +step:1570/1660 train_time:145259ms step_avg:92.52ms +step:1571/1660 train_time:145352ms step_avg:92.52ms +step:1572/1660 train_time:145445ms step_avg:92.52ms +step:1573/1660 train_time:145538ms step_avg:92.52ms +step:1574/1660 train_time:145631ms step_avg:92.52ms +step:1575/1660 train_time:145724ms step_avg:92.52ms +step:1576/1660 train_time:145817ms step_avg:92.52ms +step:1577/1660 train_time:145910ms step_avg:92.52ms +step:1578/1660 train_time:146004ms step_avg:92.52ms +step:1579/1660 train_time:146097ms step_avg:92.52ms +step:1580/1660 train_time:146191ms step_avg:92.53ms +step:1581/1660 train_time:146284ms step_avg:92.53ms +step:1582/1660 train_time:146378ms step_avg:92.53ms +step:1583/1660 train_time:146471ms step_avg:92.53ms +step:1584/1660 train_time:146563ms step_avg:92.53ms +step:1585/1660 train_time:146656ms step_avg:92.53ms +step:1586/1660 train_time:146749ms step_avg:92.53ms +step:1587/1660 train_time:146843ms step_avg:92.53ms +step:1588/1660 train_time:146937ms step_avg:92.53ms +step:1589/1660 train_time:147030ms step_avg:92.53ms +step:1590/1660 train_time:147124ms step_avg:92.53ms +step:1591/1660 train_time:147218ms step_avg:92.53ms +step:1592/1660 train_time:147311ms step_avg:92.53ms +step:1593/1660 train_time:147403ms step_avg:92.53ms +step:1594/1660 train_time:147496ms step_avg:92.53ms +step:1595/1660 train_time:147590ms step_avg:92.53ms +step:1596/1660 train_time:147682ms step_avg:92.53ms +step:1597/1660 train_time:147776ms step_avg:92.53ms +step:1598/1660 train_time:147870ms step_avg:92.53ms +step:1599/1660 train_time:147962ms step_avg:92.53ms +step:1600/1660 train_time:148056ms step_avg:92.53ms +step:1601/1660 train_time:148149ms step_avg:92.54ms +step:1602/1660 train_time:148243ms step_avg:92.54ms +step:1603/1660 train_time:148336ms step_avg:92.54ms +step:1604/1660 train_time:148430ms step_avg:92.54ms +step:1605/1660 train_time:148523ms step_avg:92.54ms +step:1606/1660 train_time:148617ms step_avg:92.54ms +step:1607/1660 train_time:148709ms step_avg:92.54ms +step:1608/1660 train_time:148802ms step_avg:92.54ms +step:1609/1660 train_time:148896ms step_avg:92.54ms +step:1610/1660 train_time:148989ms step_avg:92.54ms +step:1611/1660 train_time:149083ms step_avg:92.54ms +step:1612/1660 train_time:149176ms step_avg:92.54ms +step:1613/1660 train_time:149270ms step_avg:92.54ms +step:1614/1660 train_time:149363ms step_avg:92.54ms +step:1615/1660 train_time:149456ms step_avg:92.54ms +step:1616/1660 
train_time:149548ms step_avg:92.54ms +step:1617/1660 train_time:149642ms step_avg:92.54ms +step:1618/1660 train_time:149736ms step_avg:92.54ms +step:1619/1660 train_time:149829ms step_avg:92.54ms +step:1620/1660 train_time:149923ms step_avg:92.54ms +step:1621/1660 train_time:150017ms step_avg:92.55ms +step:1622/1660 train_time:150110ms step_avg:92.55ms +step:1623/1660 train_time:150203ms step_avg:92.55ms +step:1624/1660 train_time:150296ms step_avg:92.55ms +step:1625/1660 train_time:150390ms step_avg:92.55ms +step:1625/1660 val_loss:3.2860 train_time:150484ms step_avg:92.61ms +step:1626/1660 train_time:150504ms step_avg:92.56ms +step:1627/1660 train_time:150580ms step_avg:92.55ms +step:1628/1660 train_time:150676ms step_avg:92.55ms +step:1629/1660 train_time:150769ms step_avg:92.55ms +step:1630/1660 train_time:150861ms step_avg:92.55ms +step:1631/1660 train_time:150953ms step_avg:92.55ms +step:1632/1660 train_time:151045ms step_avg:92.55ms +step:1633/1660 train_time:151137ms step_avg:92.55ms +step:1634/1660 train_time:151230ms step_avg:92.55ms +step:1635/1660 train_time:151323ms step_avg:92.55ms +step:1636/1660 train_time:151417ms step_avg:92.55ms +step:1637/1660 train_time:151512ms step_avg:92.55ms +step:1638/1660 train_time:151608ms step_avg:92.56ms +step:1639/1660 train_time:151703ms step_avg:92.56ms +step:1640/1660 train_time:151796ms step_avg:92.56ms +step:1641/1660 train_time:151889ms step_avg:92.56ms +step:1642/1660 train_time:151982ms step_avg:92.56ms +step:1643/1660 train_time:152075ms step_avg:92.56ms +step:1644/1660 train_time:152168ms step_avg:92.56ms +step:1645/1660 train_time:152260ms step_avg:92.56ms +step:1646/1660 train_time:152353ms step_avg:92.56ms +step:1647/1660 train_time:152447ms step_avg:92.56ms +step:1648/1660 train_time:152541ms step_avg:92.56ms +step:1649/1660 train_time:152635ms step_avg:92.56ms +step:1650/1660 train_time:152729ms step_avg:92.56ms +step:1651/1660 train_time:152823ms step_avg:92.56ms +step:1652/1660 train_time:152916ms step_avg:92.56ms +step:1653/1660 train_time:153008ms step_avg:92.56ms +step:1654/1660 train_time:153101ms step_avg:92.56ms +step:1655/1660 train_time:153194ms step_avg:92.56ms +step:1656/1660 train_time:153286ms step_avg:92.56ms +step:1657/1660 train_time:153379ms step_avg:92.56ms +step:1658/1660 train_time:153472ms step_avg:92.56ms +step:1659/1660 train_time:153567ms step_avg:92.57ms +step:1660/1660 train_time:153661ms step_avg:92.57ms +step:1660/1660 val_loss:3.2781 train_time:153756ms step_avg:92.62ms +peak memory allocated: 31836 MiB reserved: 47036 MiB diff --git a/records/091525_ThreadingFinalWindow/README.md b/records/091525_ThreadingFinalWindow/README.md new file mode 100644 index 000000000..019a26bb6 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/README.md @@ -0,0 +1,29 @@ +## New WR 153.9s: Apply threading to preload data and extend final layer attention window for validation + +This PR builds on all recent WR improvements including PR #125 by @bernard24. This one adds: + +- Start prefetching and indexing the next shard immediately. Since this occurs on the CPU, there is ample time to perform this during the GPU heavy workload, and we shouldn't be bottlenecking GPU activities on CPU data indexing. (1.5s) +- Only partially index the first shard before starting to train on it. Kickoff a parallel thread to finish indexing it, which gets picked up on the 5th step. (300ms) +- Extend the final layer attention window out to 20 for validation (no need to apply YaRN for this layer). 
If curious, some inspiration for this change came from: https://medium.com/@larry36d/formation-of-induction-heads-in-modded-nanogpt-5eb899de89e4. This dropped loss by roughly 0.001 and enabled -10 steps (1s) while still cleanly remaining under 3.28. + +(Exact runtimes will vary by GPU provider, mine is about 1s slower than I believe some GPU setups will get) + +Validation: +``` +import scipy.stats +import torch + +accs = [3.2813, 3.2778, 3.2781, 3.2764, 3.277 , 3.2787, 3.2767, 3.2807, + 3.2769, 3.2774] + +times = [153.869, 154.126, 153.756, 153.82 , 153.816, 154.014, 153.887, + 153.811, 154.015, 153.828] + +print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue) +# p=0.0030 +print("acc:", torch.std_mean(torch.tensor(accs))) +# acc: (tensor(0.0017), tensor(3.2781)) + +print("time:", torch.std_mean(torch.tensor(times))) +# time: (tensor(0.1180), tensor(153.8942)) +``` \ No newline at end of file diff --git a/records/091525_ThreadingFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt b/records/091525_ThreadingFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt new file mode 100644 index 000000000..1c7026728 --- /dev/null +++ b/records/091525_ThreadingFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt @@ -0,0 +1,3111 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, 
dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = 
(n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + 
(offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by 
Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. 
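+            # Worked example of the scaling below (illustrative shapes, not from this model):
+            # a tall (3072, 768) parameter would get lr * max(1, 3072/768) ** 0.5 = 2 * lr,
+            # while square (768, 768) matrices keep lr unchanged. The effective weight decay
+            # is multiplied by lr, so it follows the learning-rate schedule.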
+ eff_lr_val = ( + group["lr"] + * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5 + * getattr(params[0], "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(params[0], "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. 
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, 
async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: 
suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x 
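+
+# Minimal sketch of the sparse attention gate wired into CausalSelfAttention above
+# (illustrative and standalone; this helper is not invoked by the training loop and its
+# name is an assumption). Each head receives a scalar gate in (0, 1) computed from the
+# first 12 channels of the pre-attention input; because attn_gate.weight is zero-initialized,
+# every gate starts at sigmoid(0) = 0.5, and training can push a head's gate toward 0,
+# turning that head into a contextual no-op.
+def _attn_gate_sketch(x: Tensor, y: Tensor, gate_weight: Tensor) -> Tensor:
+    # x: (B, T, dim) block input; y: (B, T, H, D) attention output;
+    # gate_weight: (H, 12) zero-initialized linear weight (no bias).
+    gate = torch.sigmoid(F.linear(x[..., :gate_weight.size(-1)], gate_weight))  # (B, T, H)
+    return y * gate.unsqueeze(-1)  # broadcast the per-head gate over head_dim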
+ +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + #t0 = time.perf_counter() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + #t1 = time.perf_counter() + #print(f'{t1-t0} slowload') + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + 
        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    def __init__(self, file_iter, world_size):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        # Indexing a whole shard up front is slow. BOSFinder tracks its position in self.i,
+        # so we kick off with a partially indexed shard (quickload) and finish the index on
+        # a background thread while the next shard is prefetched by DataPreloader.
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
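+                # The two commented-out lines below are the previous synchronous reload that
+                # DataPreloader replaces: get() blocks only if the background prefetch has
+                # not finished yet, and start() immediately begins loading the following
+                # shard, so shard I/O and BOS indexing overlap with GPU work.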
+ #tokens = _load_data_shard(next(file_iter)) + #finder = BOSFinder(tokens, world_size=world_size) + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1660 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"data_threading_1660/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer is highly indifferent to length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
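+
+# Batch bookkeeping note (illustrative arithmetic): grad_accum_steps = 8 // world_size keeps
+# the per-step token budget identical across world sizes. With world_size = 8, each rank sees
+# train_batch_size / 8 tokens in a single micro-batch; with world_size = 4, each rank sees the
+# same train_batch_size / 8 tokens per micro-batch but accumulates two micro-batches per step.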
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % 
args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Tue Sep 16 03:46:40 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 29C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 27C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 29C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 29C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 194072 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 194073 C /usr/bin/python3 614MiB | +| 0 N/A N/A 194074 C /usr/bin/python3 614MiB | +| 0 N/A N/A 194075 C /usr/bin/python3 614MiB | +| 0 N/A N/A 194076 C /usr/bin/python3 614MiB | +| 0 N/A N/A 194077 C /usr/bin/python3 614MiB | +| 0 N/A N/A 194078 C /usr/bin/python3 614MiB | +| 0 N/A N/A 194079 C /usr/bin/python3 614MiB | +| 1 N/A N/A 194073 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 194074 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 194075 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 194076 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 194077 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 194078 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 194079 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1660 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1660 train_time:148ms step_avg:148.18ms +step:2/1660 train_time:171ms step_avg:85.50ms +step:3/1660 train_time:238ms step_avg:79.45ms +step:4/1660 train_time:328ms step_avg:81.94ms +step:5/1660 train_time:418ms step_avg:83.69ms +step:6/1660 train_time:509ms step_avg:84.78ms +step:7/1660 train_time:599ms step_avg:85.64ms +step:8/1660 
train_time:690ms step_avg:86.25ms +step:9/1660 train_time:781ms step_avg:86.74ms +step:10/1660 train_time:871ms step_avg:87.14ms +step:11/1660 train_time:962ms step_avg:87.45ms +step:12/1660 train_time:1058ms step_avg:88.13ms +step:13/1660 train_time:1153ms step_avg:88.68ms +step:14/1660 train_time:1246ms step_avg:88.98ms +step:15/1660 train_time:1337ms step_avg:89.16ms +step:16/1660 train_time:1428ms step_avg:89.26ms +step:17/1660 train_time:1519ms step_avg:89.35ms +step:18/1660 train_time:1609ms step_avg:89.41ms +step:19/1660 train_time:1700ms step_avg:89.49ms +step:20/1660 train_time:1791ms step_avg:89.56ms +step:21/1660 train_time:1882ms step_avg:89.61ms +step:22/1660 train_time:1975ms step_avg:89.75ms +step:23/1660 train_time:2067ms step_avg:89.87ms +step:24/1660 train_time:2162ms step_avg:90.09ms +step:25/1660 train_time:2255ms step_avg:90.22ms +step:26/1660 train_time:2347ms step_avg:90.28ms +step:27/1660 train_time:2439ms step_avg:90.32ms +step:28/1660 train_time:2530ms step_avg:90.35ms +step:29/1660 train_time:2620ms step_avg:90.36ms +step:30/1660 train_time:2711ms step_avg:90.38ms +step:31/1660 train_time:2802ms step_avg:90.38ms +step:32/1660 train_time:2893ms step_avg:90.40ms +step:33/1660 train_time:2984ms step_avg:90.41ms +step:34/1660 train_time:3076ms step_avg:90.48ms +step:35/1660 train_time:3169ms step_avg:90.54ms +step:36/1660 train_time:3262ms step_avg:90.60ms +step:37/1660 train_time:3354ms step_avg:90.64ms +step:38/1660 train_time:3445ms step_avg:90.67ms +step:39/1660 train_time:3537ms step_avg:90.69ms +step:40/1660 train_time:3628ms step_avg:90.70ms +step:41/1660 train_time:3719ms step_avg:90.70ms +step:42/1660 train_time:3811ms step_avg:90.73ms +step:43/1660 train_time:3902ms step_avg:90.74ms +step:44/1660 train_time:3992ms step_avg:90.73ms +step:45/1660 train_time:4084ms step_avg:90.75ms +step:46/1660 train_time:4177ms step_avg:90.80ms +step:47/1660 train_time:4268ms step_avg:90.82ms +step:48/1660 train_time:4361ms step_avg:90.86ms +step:49/1660 train_time:4453ms step_avg:90.88ms +step:50/1660 train_time:4544ms step_avg:90.89ms +step:51/1660 train_time:4636ms step_avg:90.91ms +step:52/1660 train_time:4728ms step_avg:90.92ms +step:53/1660 train_time:4818ms step_avg:90.91ms +step:54/1660 train_time:4910ms step_avg:90.92ms +step:55/1660 train_time:5001ms step_avg:90.93ms +step:56/1660 train_time:5093ms step_avg:90.95ms +step:57/1660 train_time:5185ms step_avg:90.96ms +step:58/1660 train_time:5278ms step_avg:91.00ms +step:59/1660 train_time:5370ms step_avg:91.01ms +step:60/1660 train_time:5461ms step_avg:91.02ms +step:61/1660 train_time:5552ms step_avg:91.02ms +step:62/1660 train_time:5644ms step_avg:91.03ms +step:63/1660 train_time:5736ms step_avg:91.05ms +step:64/1660 train_time:5827ms step_avg:91.05ms +step:65/1660 train_time:5919ms step_avg:91.06ms +step:66/1660 train_time:6011ms step_avg:91.08ms +step:67/1660 train_time:6103ms step_avg:91.08ms +step:68/1660 train_time:6194ms step_avg:91.09ms +step:69/1660 train_time:6286ms step_avg:91.10ms +step:70/1660 train_time:6378ms step_avg:91.11ms +step:71/1660 train_time:6470ms step_avg:91.12ms +step:72/1660 train_time:6562ms step_avg:91.14ms +step:73/1660 train_time:6653ms step_avg:91.14ms +step:74/1660 train_time:6745ms step_avg:91.15ms +step:75/1660 train_time:6837ms step_avg:91.16ms +step:76/1660 train_time:6928ms step_avg:91.16ms +step:77/1660 train_time:7020ms step_avg:91.16ms +step:78/1660 train_time:7112ms step_avg:91.18ms +step:79/1660 train_time:7204ms step_avg:91.19ms +step:80/1660 train_time:7295ms 
step_avg:91.19ms +step:81/1660 train_time:7387ms step_avg:91.19ms +step:82/1660 train_time:7481ms step_avg:91.23ms +step:83/1660 train_time:7573ms step_avg:91.24ms +step:84/1660 train_time:7664ms step_avg:91.24ms +step:85/1660 train_time:7757ms step_avg:91.25ms +step:86/1660 train_time:7847ms step_avg:91.25ms +step:87/1660 train_time:7939ms step_avg:91.25ms +step:88/1660 train_time:8031ms step_avg:91.26ms +step:89/1660 train_time:8122ms step_avg:91.25ms +step:90/1660 train_time:8213ms step_avg:91.26ms +step:91/1660 train_time:8306ms step_avg:91.27ms +step:92/1660 train_time:8398ms step_avg:91.28ms +step:93/1660 train_time:8489ms step_avg:91.28ms +step:94/1660 train_time:8581ms step_avg:91.28ms +step:95/1660 train_time:8673ms step_avg:91.29ms +step:96/1660 train_time:8764ms step_avg:91.29ms +step:97/1660 train_time:8855ms step_avg:91.29ms +step:98/1660 train_time:8946ms step_avg:91.29ms +step:99/1660 train_time:9038ms step_avg:91.29ms +step:100/1660 train_time:9129ms step_avg:91.29ms +step:101/1660 train_time:9220ms step_avg:91.29ms +step:102/1660 train_time:9312ms step_avg:91.29ms +step:103/1660 train_time:9404ms step_avg:91.30ms +step:104/1660 train_time:9496ms step_avg:91.31ms +step:105/1660 train_time:9587ms step_avg:91.30ms +step:106/1660 train_time:9680ms step_avg:91.32ms +step:107/1660 train_time:9771ms step_avg:91.32ms +step:108/1660 train_time:9862ms step_avg:91.32ms +step:109/1660 train_time:9954ms step_avg:91.32ms +step:110/1660 train_time:10044ms step_avg:91.31ms +step:111/1660 train_time:10136ms step_avg:91.31ms +step:112/1660 train_time:10227ms step_avg:91.32ms +step:113/1660 train_time:10319ms step_avg:91.32ms +step:114/1660 train_time:10411ms step_avg:91.33ms +step:115/1660 train_time:10502ms step_avg:91.32ms +step:116/1660 train_time:10594ms step_avg:91.32ms +step:117/1660 train_time:10684ms step_avg:91.32ms +step:118/1660 train_time:10777ms step_avg:91.33ms +step:119/1660 train_time:10868ms step_avg:91.33ms +step:120/1660 train_time:10960ms step_avg:91.33ms +step:121/1660 train_time:11052ms step_avg:91.34ms +step:122/1660 train_time:11143ms step_avg:91.34ms +step:123/1660 train_time:11234ms step_avg:91.34ms +step:124/1660 train_time:11325ms step_avg:91.33ms +step:125/1660 train_time:11416ms step_avg:91.33ms +step:125/1660 val_loss:4.3163 train_time:11509ms step_avg:92.07ms +step:126/1660 train_time:11530ms step_avg:91.50ms +step:127/1660 train_time:11603ms step_avg:91.36ms +step:128/1660 train_time:11705ms step_avg:91.45ms +step:129/1660 train_time:11804ms step_avg:91.50ms +step:130/1660 train_time:11896ms step_avg:91.51ms +step:131/1660 train_time:11986ms step_avg:91.50ms +step:132/1660 train_time:12076ms step_avg:91.49ms +step:133/1660 train_time:12166ms step_avg:91.47ms +step:134/1660 train_time:12257ms step_avg:91.47ms +step:135/1660 train_time:12347ms step_avg:91.46ms +step:136/1660 train_time:12437ms step_avg:91.45ms +step:137/1660 train_time:12528ms step_avg:91.44ms +step:138/1660 train_time:12621ms step_avg:91.46ms +step:139/1660 train_time:12716ms step_avg:91.48ms +step:140/1660 train_time:12809ms step_avg:91.49ms +step:141/1660 train_time:12905ms step_avg:91.52ms +step:142/1660 train_time:12996ms step_avg:91.52ms +step:143/1660 train_time:13087ms step_avg:91.52ms +step:144/1660 train_time:13178ms step_avg:91.51ms +step:145/1660 train_time:13267ms step_avg:91.50ms +step:146/1660 train_time:13358ms step_avg:91.49ms +step:147/1660 train_time:13448ms step_avg:91.49ms +step:148/1660 train_time:13539ms step_avg:91.48ms +step:149/1660 train_time:13631ms step_avg:91.49ms 
+step:150/1660 train_time:13724ms step_avg:91.49ms +step:151/1660 train_time:13817ms step_avg:91.50ms +step:152/1660 train_time:13909ms step_avg:91.51ms +step:153/1660 train_time:14002ms step_avg:91.52ms +step:154/1660 train_time:14093ms step_avg:91.52ms +step:155/1660 train_time:14185ms step_avg:91.52ms +step:156/1660 train_time:14276ms step_avg:91.51ms +step:157/1660 train_time:14367ms step_avg:91.51ms +step:158/1660 train_time:14457ms step_avg:91.50ms +step:159/1660 train_time:14548ms step_avg:91.50ms +step:160/1660 train_time:14640ms step_avg:91.50ms +step:161/1660 train_time:14731ms step_avg:91.50ms +step:162/1660 train_time:14824ms step_avg:91.51ms +step:163/1660 train_time:14917ms step_avg:91.52ms +step:164/1660 train_time:15008ms step_avg:91.51ms +step:165/1660 train_time:15100ms step_avg:91.51ms +step:166/1660 train_time:15191ms step_avg:91.51ms +step:167/1660 train_time:15283ms step_avg:91.52ms +step:168/1660 train_time:15374ms step_avg:91.51ms +step:169/1660 train_time:15465ms step_avg:91.51ms +step:170/1660 train_time:15556ms step_avg:91.51ms +step:171/1660 train_time:15647ms step_avg:91.50ms +step:172/1660 train_time:15738ms step_avg:91.50ms +step:173/1660 train_time:15829ms step_avg:91.50ms +step:174/1660 train_time:15923ms step_avg:91.51ms +step:175/1660 train_time:16014ms step_avg:91.51ms +step:176/1660 train_time:16106ms step_avg:91.51ms +step:177/1660 train_time:16198ms step_avg:91.51ms +step:178/1660 train_time:16288ms step_avg:91.51ms +step:179/1660 train_time:16379ms step_avg:91.50ms +step:180/1660 train_time:16470ms step_avg:91.50ms +step:181/1660 train_time:16561ms step_avg:91.50ms +step:182/1660 train_time:16652ms step_avg:91.49ms +step:183/1660 train_time:16744ms step_avg:91.50ms +step:184/1660 train_time:16836ms step_avg:91.50ms +step:185/1660 train_time:16927ms step_avg:91.50ms +step:186/1660 train_time:17019ms step_avg:91.50ms +step:187/1660 train_time:17110ms step_avg:91.50ms +step:188/1660 train_time:17203ms step_avg:91.50ms +step:189/1660 train_time:17295ms step_avg:91.51ms +step:190/1660 train_time:17387ms step_avg:91.51ms +step:191/1660 train_time:17477ms step_avg:91.50ms +step:192/1660 train_time:17568ms step_avg:91.50ms +step:193/1660 train_time:17659ms step_avg:91.50ms +step:194/1660 train_time:17750ms step_avg:91.50ms +step:195/1660 train_time:17843ms step_avg:91.50ms +step:196/1660 train_time:17934ms step_avg:91.50ms +step:197/1660 train_time:18026ms step_avg:91.50ms +step:198/1660 train_time:18118ms step_avg:91.51ms +step:199/1660 train_time:18210ms step_avg:91.51ms +step:200/1660 train_time:18302ms step_avg:91.51ms +step:201/1660 train_time:18394ms step_avg:91.51ms +step:202/1660 train_time:18485ms step_avg:91.51ms +step:203/1660 train_time:18576ms step_avg:91.51ms +step:204/1660 train_time:18667ms step_avg:91.50ms +step:205/1660 train_time:18759ms step_avg:91.51ms +step:206/1660 train_time:18850ms step_avg:91.51ms +step:207/1660 train_time:18942ms step_avg:91.51ms +step:208/1660 train_time:19032ms step_avg:91.50ms +step:209/1660 train_time:19125ms step_avg:91.51ms +step:210/1660 train_time:19217ms step_avg:91.51ms +step:211/1660 train_time:19308ms step_avg:91.51ms +step:212/1660 train_time:19399ms step_avg:91.51ms +step:213/1660 train_time:19491ms step_avg:91.51ms +step:214/1660 train_time:19583ms step_avg:91.51ms +step:215/1660 train_time:19674ms step_avg:91.51ms +step:216/1660 train_time:19765ms step_avg:91.51ms +step:217/1660 train_time:19858ms step_avg:91.51ms +step:218/1660 train_time:19949ms step_avg:91.51ms +step:219/1660 train_time:20040ms 
step_avg:91.51ms +step:220/1660 train_time:20131ms step_avg:91.50ms +step:221/1660 train_time:20223ms step_avg:91.51ms +step:222/1660 train_time:20315ms step_avg:91.51ms +step:223/1660 train_time:20406ms step_avg:91.51ms +step:224/1660 train_time:20498ms step_avg:91.51ms +step:225/1660 train_time:20590ms step_avg:91.51ms +step:226/1660 train_time:20683ms step_avg:91.52ms +step:227/1660 train_time:20774ms step_avg:91.52ms +step:228/1660 train_time:20865ms step_avg:91.51ms +step:229/1660 train_time:20958ms step_avg:91.52ms +step:230/1660 train_time:21048ms step_avg:91.52ms +step:231/1660 train_time:21140ms step_avg:91.51ms +step:232/1660 train_time:21231ms step_avg:91.51ms +step:233/1660 train_time:21323ms step_avg:91.52ms +step:234/1660 train_time:21416ms step_avg:91.52ms +step:235/1660 train_time:21507ms step_avg:91.52ms +step:236/1660 train_time:21598ms step_avg:91.52ms +step:237/1660 train_time:21689ms step_avg:91.52ms +step:238/1660 train_time:21781ms step_avg:91.52ms +step:239/1660 train_time:21872ms step_avg:91.52ms +step:240/1660 train_time:21964ms step_avg:91.52ms +step:241/1660 train_time:22055ms step_avg:91.52ms +step:242/1660 train_time:22147ms step_avg:91.51ms +step:243/1660 train_time:22239ms step_avg:91.52ms +step:244/1660 train_time:22330ms step_avg:91.51ms +step:245/1660 train_time:22421ms step_avg:91.51ms +step:246/1660 train_time:22512ms step_avg:91.51ms +step:247/1660 train_time:22606ms step_avg:91.52ms +step:248/1660 train_time:22698ms step_avg:91.52ms +step:249/1660 train_time:22789ms step_avg:91.52ms +step:250/1660 train_time:22880ms step_avg:91.52ms +step:250/1660 val_loss:3.9765 train_time:22972ms step_avg:91.89ms +step:251/1660 train_time:22991ms step_avg:91.60ms +step:252/1660 train_time:23067ms step_avg:91.54ms +step:253/1660 train_time:23166ms step_avg:91.56ms +step:254/1660 train_time:23259ms step_avg:91.57ms +step:255/1660 train_time:23351ms step_avg:91.57ms +step:256/1660 train_time:23442ms step_avg:91.57ms +step:257/1660 train_time:23533ms step_avg:91.57ms +step:258/1660 train_time:23623ms step_avg:91.56ms +step:259/1660 train_time:23713ms step_avg:91.56ms +step:260/1660 train_time:23803ms step_avg:91.55ms +step:261/1660 train_time:23894ms step_avg:91.55ms +step:262/1660 train_time:23985ms step_avg:91.55ms +step:263/1660 train_time:24080ms step_avg:91.56ms +step:264/1660 train_time:24174ms step_avg:91.57ms +step:265/1660 train_time:24266ms step_avg:91.57ms +step:266/1660 train_time:24358ms step_avg:91.57ms +step:267/1660 train_time:24449ms step_avg:91.57ms +step:268/1660 train_time:24540ms step_avg:91.57ms +step:269/1660 train_time:24631ms step_avg:91.57ms +step:270/1660 train_time:24722ms step_avg:91.56ms +step:271/1660 train_time:24813ms step_avg:91.56ms +step:272/1660 train_time:24903ms step_avg:91.56ms +step:273/1660 train_time:24996ms step_avg:91.56ms +step:274/1660 train_time:25087ms step_avg:91.56ms +step:275/1660 train_time:25181ms step_avg:91.57ms +step:276/1660 train_time:25273ms step_avg:91.57ms +step:277/1660 train_time:25365ms step_avg:91.57ms +step:278/1660 train_time:25457ms step_avg:91.57ms +step:279/1660 train_time:25549ms step_avg:91.57ms +step:280/1660 train_time:25640ms step_avg:91.57ms +step:281/1660 train_time:25731ms step_avg:91.57ms +step:282/1660 train_time:25821ms step_avg:91.56ms +step:283/1660 train_time:25912ms step_avg:91.56ms +step:284/1660 train_time:26004ms step_avg:91.56ms +step:285/1660 train_time:26095ms step_avg:91.56ms +step:286/1660 train_time:26187ms step_avg:91.56ms +step:287/1660 train_time:26279ms step_avg:91.56ms 
+step:288/1660 train_time:26371ms step_avg:91.57ms +step:289/1660 train_time:26463ms step_avg:91.57ms +step:290/1660 train_time:26555ms step_avg:91.57ms +step:291/1660 train_time:26646ms step_avg:91.57ms +step:292/1660 train_time:26737ms step_avg:91.56ms +step:293/1660 train_time:26827ms step_avg:91.56ms +step:294/1660 train_time:26918ms step_avg:91.56ms +step:295/1660 train_time:27009ms step_avg:91.56ms +step:296/1660 train_time:27101ms step_avg:91.56ms +step:297/1660 train_time:27193ms step_avg:91.56ms +step:298/1660 train_time:27283ms step_avg:91.56ms +step:299/1660 train_time:27376ms step_avg:91.56ms +step:300/1660 train_time:27468ms step_avg:91.56ms +step:301/1660 train_time:27560ms step_avg:91.56ms +step:302/1660 train_time:27652ms step_avg:91.56ms +step:303/1660 train_time:27743ms step_avg:91.56ms +step:304/1660 train_time:27835ms step_avg:91.56ms +step:305/1660 train_time:27925ms step_avg:91.56ms +step:306/1660 train_time:28017ms step_avg:91.56ms +step:307/1660 train_time:28108ms step_avg:91.56ms +step:308/1660 train_time:28199ms step_avg:91.56ms +step:309/1660 train_time:28291ms step_avg:91.56ms +step:310/1660 train_time:28383ms step_avg:91.56ms +step:311/1660 train_time:28475ms step_avg:91.56ms +step:312/1660 train_time:28567ms step_avg:91.56ms +step:313/1660 train_time:28659ms step_avg:91.56ms +step:314/1660 train_time:28750ms step_avg:91.56ms +step:315/1660 train_time:28841ms step_avg:91.56ms +step:316/1660 train_time:28932ms step_avg:91.56ms +step:317/1660 train_time:29024ms step_avg:91.56ms +step:318/1660 train_time:29115ms step_avg:91.56ms +step:319/1660 train_time:29205ms step_avg:91.55ms +step:320/1660 train_time:29297ms step_avg:91.55ms +step:321/1660 train_time:29388ms step_avg:91.55ms +step:322/1660 train_time:29482ms step_avg:91.56ms +step:323/1660 train_time:29574ms step_avg:91.56ms +step:324/1660 train_time:29665ms step_avg:91.56ms +step:325/1660 train_time:29758ms step_avg:91.56ms +step:326/1660 train_time:29850ms step_avg:91.56ms +step:327/1660 train_time:29941ms step_avg:91.56ms +step:328/1660 train_time:30032ms step_avg:91.56ms +step:329/1660 train_time:30123ms step_avg:91.56ms +step:330/1660 train_time:30214ms step_avg:91.56ms +step:331/1660 train_time:30306ms step_avg:91.56ms +step:332/1660 train_time:30398ms step_avg:91.56ms +step:333/1660 train_time:30489ms step_avg:91.56ms +step:334/1660 train_time:30582ms step_avg:91.56ms +step:335/1660 train_time:30674ms step_avg:91.57ms +step:336/1660 train_time:30766ms step_avg:91.57ms +step:337/1660 train_time:30858ms step_avg:91.57ms +step:338/1660 train_time:30950ms step_avg:91.57ms +step:339/1660 train_time:31041ms step_avg:91.57ms +step:340/1660 train_time:31132ms step_avg:91.56ms +step:341/1660 train_time:31223ms step_avg:91.56ms +step:342/1660 train_time:31315ms step_avg:91.56ms +step:343/1660 train_time:31407ms step_avg:91.56ms +step:344/1660 train_time:31498ms step_avg:91.57ms +step:345/1660 train_time:31590ms step_avg:91.57ms +step:346/1660 train_time:31683ms step_avg:91.57ms +step:347/1660 train_time:31775ms step_avg:91.57ms +step:348/1660 train_time:31865ms step_avg:91.57ms +step:349/1660 train_time:31957ms step_avg:91.57ms +step:350/1660 train_time:32049ms step_avg:91.57ms +step:351/1660 train_time:32140ms step_avg:91.57ms +step:352/1660 train_time:32231ms step_avg:91.57ms +step:353/1660 train_time:32322ms step_avg:91.56ms +step:354/1660 train_time:32413ms step_avg:91.56ms +step:355/1660 train_time:32504ms step_avg:91.56ms +step:356/1660 train_time:32596ms step_avg:91.56ms +step:357/1660 train_time:32686ms 
step_avg:91.56ms +step:358/1660 train_time:32779ms step_avg:91.56ms +step:359/1660 train_time:32871ms step_avg:91.56ms +step:360/1660 train_time:32963ms step_avg:91.56ms +step:361/1660 train_time:33055ms step_avg:91.57ms +step:362/1660 train_time:33145ms step_avg:91.56ms +step:363/1660 train_time:33237ms step_avg:91.56ms +step:364/1660 train_time:33327ms step_avg:91.56ms +step:365/1660 train_time:33418ms step_avg:91.56ms +step:366/1660 train_time:33510ms step_avg:91.56ms +step:367/1660 train_time:33601ms step_avg:91.56ms +step:368/1660 train_time:33692ms step_avg:91.56ms +step:369/1660 train_time:33784ms step_avg:91.55ms +step:370/1660 train_time:33876ms step_avg:91.56ms +step:371/1660 train_time:33967ms step_avg:91.55ms +step:372/1660 train_time:34059ms step_avg:91.56ms +step:373/1660 train_time:34150ms step_avg:91.56ms +step:374/1660 train_time:34242ms step_avg:91.56ms +step:375/1660 train_time:34333ms step_avg:91.56ms +step:375/1660 val_loss:3.8187 train_time:34426ms step_avg:91.80ms +step:376/1660 train_time:34445ms step_avg:91.61ms +step:377/1660 train_time:34520ms step_avg:91.57ms +step:378/1660 train_time:34615ms step_avg:91.57ms +step:379/1660 train_time:34707ms step_avg:91.57ms +step:380/1660 train_time:34797ms step_avg:91.57ms +step:381/1660 train_time:34888ms step_avg:91.57ms +step:382/1660 train_time:34978ms step_avg:91.56ms +step:383/1660 train_time:35068ms step_avg:91.56ms +step:384/1660 train_time:35158ms step_avg:91.56ms +step:385/1660 train_time:35249ms step_avg:91.56ms +step:386/1660 train_time:35340ms step_avg:91.55ms +step:387/1660 train_time:35435ms step_avg:91.56ms +step:388/1660 train_time:35529ms step_avg:91.57ms +step:389/1660 train_time:35622ms step_avg:91.57ms +step:390/1660 train_time:35714ms step_avg:91.58ms +step:391/1660 train_time:35805ms step_avg:91.57ms +step:392/1660 train_time:35896ms step_avg:91.57ms +step:393/1660 train_time:35986ms step_avg:91.57ms +step:394/1660 train_time:36077ms step_avg:91.57ms +step:395/1660 train_time:36167ms step_avg:91.56ms +step:396/1660 train_time:36257ms step_avg:91.56ms +step:397/1660 train_time:36348ms step_avg:91.56ms +step:398/1660 train_time:36439ms step_avg:91.56ms +step:399/1660 train_time:36533ms step_avg:91.56ms +step:400/1660 train_time:36626ms step_avg:91.56ms +step:401/1660 train_time:36717ms step_avg:91.56ms +step:402/1660 train_time:36810ms step_avg:91.57ms +step:403/1660 train_time:36901ms step_avg:91.57ms +step:404/1660 train_time:36993ms step_avg:91.57ms +step:405/1660 train_time:37084ms step_avg:91.57ms +step:406/1660 train_time:37174ms step_avg:91.56ms +step:407/1660 train_time:37265ms step_avg:91.56ms +step:408/1660 train_time:37356ms step_avg:91.56ms +step:409/1660 train_time:37448ms step_avg:91.56ms +step:410/1660 train_time:37539ms step_avg:91.56ms +step:411/1660 train_time:37632ms step_avg:91.56ms +step:412/1660 train_time:37724ms step_avg:91.56ms +step:413/1660 train_time:37815ms step_avg:91.56ms +step:414/1660 train_time:37907ms step_avg:91.56ms +step:415/1660 train_time:37998ms step_avg:91.56ms +step:416/1660 train_time:38090ms step_avg:91.56ms +step:417/1660 train_time:38181ms step_avg:91.56ms +step:418/1660 train_time:38272ms step_avg:91.56ms +step:419/1660 train_time:38364ms step_avg:91.56ms +step:420/1660 train_time:38455ms step_avg:91.56ms +step:421/1660 train_time:38547ms step_avg:91.56ms +step:422/1660 train_time:38638ms step_avg:91.56ms +step:423/1660 train_time:38731ms step_avg:91.56ms +step:424/1660 train_time:38824ms step_avg:91.57ms +step:425/1660 train_time:38916ms step_avg:91.57ms 
+step:426/1660 train_time:39007ms step_avg:91.57ms +step:427/1660 train_time:39098ms step_avg:91.57ms +step:428/1660 train_time:39190ms step_avg:91.57ms +step:429/1660 train_time:39281ms step_avg:91.56ms +step:430/1660 train_time:39371ms step_avg:91.56ms +step:431/1660 train_time:39463ms step_avg:91.56ms +step:432/1660 train_time:39554ms step_avg:91.56ms +step:433/1660 train_time:39645ms step_avg:91.56ms +step:434/1660 train_time:39736ms step_avg:91.56ms +step:435/1660 train_time:39830ms step_avg:91.56ms +step:436/1660 train_time:39922ms step_avg:91.56ms +step:437/1660 train_time:40014ms step_avg:91.56ms +step:438/1660 train_time:40105ms step_avg:91.56ms +step:439/1660 train_time:40196ms step_avg:91.56ms +step:440/1660 train_time:40287ms step_avg:91.56ms +step:441/1660 train_time:40378ms step_avg:91.56ms +step:442/1660 train_time:40469ms step_avg:91.56ms +step:443/1660 train_time:40561ms step_avg:91.56ms +step:444/1660 train_time:40653ms step_avg:91.56ms +step:445/1660 train_time:40744ms step_avg:91.56ms +step:446/1660 train_time:40835ms step_avg:91.56ms +step:447/1660 train_time:40929ms step_avg:91.56ms +step:448/1660 train_time:41020ms step_avg:91.56ms +step:449/1660 train_time:41112ms step_avg:91.56ms +step:450/1660 train_time:41203ms step_avg:91.56ms +step:451/1660 train_time:41295ms step_avg:91.56ms +step:452/1660 train_time:41386ms step_avg:91.56ms +step:453/1660 train_time:41477ms step_avg:91.56ms +step:454/1660 train_time:41569ms step_avg:91.56ms +step:455/1660 train_time:41660ms step_avg:91.56ms +step:456/1660 train_time:41752ms step_avg:91.56ms +step:457/1660 train_time:41844ms step_avg:91.56ms +step:458/1660 train_time:41935ms step_avg:91.56ms +step:459/1660 train_time:42027ms step_avg:91.56ms +step:460/1660 train_time:42120ms step_avg:91.56ms +step:461/1660 train_time:42212ms step_avg:91.57ms +step:462/1660 train_time:42303ms step_avg:91.57ms +step:463/1660 train_time:42395ms step_avg:91.57ms +step:464/1660 train_time:42486ms step_avg:91.56ms +step:465/1660 train_time:42576ms step_avg:91.56ms +step:466/1660 train_time:42667ms step_avg:91.56ms +step:467/1660 train_time:42758ms step_avg:91.56ms +step:468/1660 train_time:42850ms step_avg:91.56ms +step:469/1660 train_time:42942ms step_avg:91.56ms +step:470/1660 train_time:43034ms step_avg:91.56ms +step:471/1660 train_time:43126ms step_avg:91.56ms +step:472/1660 train_time:43217ms step_avg:91.56ms +step:473/1660 train_time:43309ms step_avg:91.56ms +step:474/1660 train_time:43401ms step_avg:91.56ms +step:475/1660 train_time:43491ms step_avg:91.56ms +step:476/1660 train_time:43583ms step_avg:91.56ms +step:477/1660 train_time:43674ms step_avg:91.56ms +step:478/1660 train_time:43765ms step_avg:91.56ms +step:479/1660 train_time:43856ms step_avg:91.56ms +step:480/1660 train_time:43948ms step_avg:91.56ms +step:481/1660 train_time:44040ms step_avg:91.56ms +step:482/1660 train_time:44132ms step_avg:91.56ms +step:483/1660 train_time:44225ms step_avg:91.56ms +step:484/1660 train_time:44315ms step_avg:91.56ms +step:485/1660 train_time:44406ms step_avg:91.56ms +step:486/1660 train_time:44497ms step_avg:91.56ms +step:487/1660 train_time:44588ms step_avg:91.56ms +step:488/1660 train_time:44679ms step_avg:91.56ms +step:489/1660 train_time:44770ms step_avg:91.55ms +step:490/1660 train_time:44862ms step_avg:91.56ms +step:491/1660 train_time:44953ms step_avg:91.55ms +step:492/1660 train_time:45044ms step_avg:91.55ms +step:493/1660 train_time:45135ms step_avg:91.55ms +step:494/1660 train_time:45227ms step_avg:91.55ms +step:495/1660 train_time:45318ms 
step_avg:91.55ms +step:496/1660 train_time:45410ms step_avg:91.55ms +step:497/1660 train_time:45502ms step_avg:91.55ms +step:498/1660 train_time:45593ms step_avg:91.55ms +step:499/1660 train_time:45684ms step_avg:91.55ms +step:500/1660 train_time:45775ms step_avg:91.55ms +step:500/1660 val_loss:3.7178 train_time:45867ms step_avg:91.73ms +step:501/1660 train_time:45886ms step_avg:91.59ms +step:502/1660 train_time:45960ms step_avg:91.55ms +step:503/1660 train_time:46055ms step_avg:91.56ms +step:504/1660 train_time:46146ms step_avg:91.56ms +step:505/1660 train_time:46237ms step_avg:91.56ms +step:506/1660 train_time:46327ms step_avg:91.56ms +step:507/1660 train_time:46417ms step_avg:91.55ms +step:508/1660 train_time:46508ms step_avg:91.55ms +step:509/1660 train_time:46598ms step_avg:91.55ms +step:510/1660 train_time:46690ms step_avg:91.55ms +step:511/1660 train_time:46781ms step_avg:91.55ms +step:512/1660 train_time:46874ms step_avg:91.55ms +step:513/1660 train_time:46967ms step_avg:91.55ms +step:514/1660 train_time:47060ms step_avg:91.56ms +step:515/1660 train_time:47152ms step_avg:91.56ms +step:516/1660 train_time:47243ms step_avg:91.56ms +step:517/1660 train_time:47334ms step_avg:91.55ms +step:518/1660 train_time:47424ms step_avg:91.55ms +step:519/1660 train_time:47515ms step_avg:91.55ms +step:520/1660 train_time:47605ms step_avg:91.55ms +step:521/1660 train_time:47696ms step_avg:91.55ms +step:522/1660 train_time:47787ms step_avg:91.55ms +step:523/1660 train_time:47880ms step_avg:91.55ms +step:524/1660 train_time:47973ms step_avg:91.55ms +step:525/1660 train_time:48065ms step_avg:91.55ms +step:526/1660 train_time:48157ms step_avg:91.55ms +step:527/1660 train_time:48248ms step_avg:91.55ms +step:528/1660 train_time:48339ms step_avg:91.55ms +step:529/1660 train_time:48430ms step_avg:91.55ms +step:530/1660 train_time:48521ms step_avg:91.55ms +step:531/1660 train_time:48611ms step_avg:91.55ms +step:532/1660 train_time:48702ms step_avg:91.54ms +step:533/1660 train_time:48794ms step_avg:91.55ms +step:534/1660 train_time:48885ms step_avg:91.54ms +step:535/1660 train_time:48977ms step_avg:91.55ms +step:536/1660 train_time:49069ms step_avg:91.55ms +step:537/1660 train_time:49161ms step_avg:91.55ms +step:538/1660 train_time:49253ms step_avg:91.55ms +step:539/1660 train_time:49343ms step_avg:91.55ms +step:540/1660 train_time:49434ms step_avg:91.54ms +step:541/1660 train_time:49524ms step_avg:91.54ms +step:542/1660 train_time:49615ms step_avg:91.54ms +step:543/1660 train_time:49706ms step_avg:91.54ms +step:544/1660 train_time:49797ms step_avg:91.54ms +step:545/1660 train_time:49888ms step_avg:91.54ms +step:546/1660 train_time:49980ms step_avg:91.54ms +step:547/1660 train_time:50072ms step_avg:91.54ms +step:548/1660 train_time:50162ms step_avg:91.54ms +step:549/1660 train_time:50254ms step_avg:91.54ms +step:550/1660 train_time:50346ms step_avg:91.54ms +step:551/1660 train_time:50438ms step_avg:91.54ms +step:552/1660 train_time:50528ms step_avg:91.54ms +step:553/1660 train_time:50619ms step_avg:91.54ms +step:554/1660 train_time:50710ms step_avg:91.54ms +step:555/1660 train_time:50802ms step_avg:91.53ms +step:556/1660 train_time:50895ms step_avg:91.54ms +step:557/1660 train_time:50988ms step_avg:91.54ms +step:558/1660 train_time:51080ms step_avg:91.54ms +step:559/1660 train_time:51174ms step_avg:91.55ms +step:560/1660 train_time:51267ms step_avg:91.55ms +step:561/1660 train_time:51360ms step_avg:91.55ms +step:562/1660 train_time:51454ms step_avg:91.55ms +step:563/1660 train_time:51546ms step_avg:91.56ms 
+step:564/1660 train_time:51637ms step_avg:91.56ms +step:565/1660 train_time:51730ms step_avg:91.56ms +step:566/1660 train_time:51822ms step_avg:91.56ms +step:567/1660 train_time:51915ms step_avg:91.56ms +step:568/1660 train_time:52009ms step_avg:91.56ms +step:569/1660 train_time:52101ms step_avg:91.57ms +step:570/1660 train_time:52195ms step_avg:91.57ms +step:571/1660 train_time:52288ms step_avg:91.57ms +step:572/1660 train_time:52380ms step_avg:91.57ms +step:573/1660 train_time:52472ms step_avg:91.57ms +step:574/1660 train_time:52564ms step_avg:91.58ms +step:575/1660 train_time:52657ms step_avg:91.58ms +step:576/1660 train_time:52750ms step_avg:91.58ms +step:577/1660 train_time:52842ms step_avg:91.58ms +step:578/1660 train_time:52934ms step_avg:91.58ms +step:579/1660 train_time:53026ms step_avg:91.58ms +step:580/1660 train_time:53119ms step_avg:91.58ms +step:581/1660 train_time:53212ms step_avg:91.59ms +step:582/1660 train_time:53304ms step_avg:91.59ms +step:583/1660 train_time:53398ms step_avg:91.59ms +step:584/1660 train_time:53490ms step_avg:91.59ms +step:585/1660 train_time:53582ms step_avg:91.59ms +step:586/1660 train_time:53676ms step_avg:91.60ms +step:587/1660 train_time:53768ms step_avg:91.60ms +step:588/1660 train_time:53860ms step_avg:91.60ms +step:589/1660 train_time:53953ms step_avg:91.60ms +step:590/1660 train_time:54045ms step_avg:91.60ms +step:591/1660 train_time:54138ms step_avg:91.60ms +step:592/1660 train_time:54231ms step_avg:91.61ms +step:593/1660 train_time:54323ms step_avg:91.61ms +step:594/1660 train_time:54416ms step_avg:91.61ms +step:595/1660 train_time:54510ms step_avg:91.61ms +step:596/1660 train_time:54602ms step_avg:91.61ms +step:597/1660 train_time:54695ms step_avg:91.62ms +step:598/1660 train_time:54787ms step_avg:91.62ms +step:599/1660 train_time:54880ms step_avg:91.62ms +step:600/1660 train_time:54973ms step_avg:91.62ms +step:601/1660 train_time:55065ms step_avg:91.62ms +step:602/1660 train_time:55160ms step_avg:91.63ms +step:603/1660 train_time:55252ms step_avg:91.63ms +step:604/1660 train_time:55344ms step_avg:91.63ms +step:605/1660 train_time:55438ms step_avg:91.63ms +step:606/1660 train_time:55530ms step_avg:91.63ms +step:607/1660 train_time:55622ms step_avg:91.63ms +step:608/1660 train_time:55715ms step_avg:91.64ms +step:609/1660 train_time:55807ms step_avg:91.64ms +step:610/1660 train_time:55900ms step_avg:91.64ms +step:611/1660 train_time:55992ms step_avg:91.64ms +step:612/1660 train_time:56083ms step_avg:91.64ms +step:613/1660 train_time:56177ms step_avg:91.64ms +step:614/1660 train_time:56270ms step_avg:91.65ms +step:615/1660 train_time:56363ms step_avg:91.65ms +step:616/1660 train_time:56456ms step_avg:91.65ms +step:617/1660 train_time:56549ms step_avg:91.65ms +step:618/1660 train_time:56641ms step_avg:91.65ms +step:619/1660 train_time:56734ms step_avg:91.65ms +step:620/1660 train_time:56826ms step_avg:91.66ms +step:621/1660 train_time:56919ms step_avg:91.66ms +step:622/1660 train_time:57012ms step_avg:91.66ms +step:623/1660 train_time:57104ms step_avg:91.66ms +step:624/1660 train_time:57198ms step_avg:91.66ms +step:625/1660 train_time:57291ms step_avg:91.67ms +step:625/1660 val_loss:3.6161 train_time:57385ms step_avg:91.82ms +step:626/1660 train_time:57404ms step_avg:91.70ms +step:627/1660 train_time:57480ms step_avg:91.67ms +step:628/1660 train_time:57577ms step_avg:91.68ms +step:629/1660 train_time:57671ms step_avg:91.69ms +step:630/1660 train_time:57763ms step_avg:91.69ms +step:631/1660 train_time:57854ms step_avg:91.69ms +step:632/1660 
train_time:57945ms step_avg:91.68ms +step:633/1660 train_time:58035ms step_avg:91.68ms +step:634/1660 train_time:58126ms step_avg:91.68ms +step:635/1660 train_time:58217ms step_avg:91.68ms +step:636/1660 train_time:58311ms step_avg:91.68ms +step:637/1660 train_time:58407ms step_avg:91.69ms +step:638/1660 train_time:58502ms step_avg:91.70ms +step:639/1660 train_time:58597ms step_avg:91.70ms +step:640/1660 train_time:58691ms step_avg:91.70ms +step:641/1660 train_time:58784ms step_avg:91.71ms +step:642/1660 train_time:58876ms step_avg:91.71ms +step:643/1660 train_time:58968ms step_avg:91.71ms +step:644/1660 train_time:59059ms step_avg:91.71ms +step:645/1660 train_time:59150ms step_avg:91.71ms +step:646/1660 train_time:59241ms step_avg:91.70ms +step:647/1660 train_time:59335ms step_avg:91.71ms +step:648/1660 train_time:59430ms step_avg:91.71ms +step:649/1660 train_time:59524ms step_avg:91.72ms +step:650/1660 train_time:59616ms step_avg:91.72ms +step:651/1660 train_time:59710ms step_avg:91.72ms +step:652/1660 train_time:59803ms step_avg:91.72ms +step:653/1660 train_time:59895ms step_avg:91.72ms +step:654/1660 train_time:59988ms step_avg:91.72ms +step:655/1660 train_time:60079ms step_avg:91.72ms +step:656/1660 train_time:60170ms step_avg:91.72ms +step:657/1660 train_time:60262ms step_avg:91.72ms +step:658/1660 train_time:60355ms step_avg:91.73ms +step:659/1660 train_time:60449ms step_avg:91.73ms +step:660/1660 train_time:60542ms step_avg:91.73ms +step:661/1660 train_time:60635ms step_avg:91.73ms +step:662/1660 train_time:60730ms step_avg:91.74ms +step:663/1660 train_time:60823ms step_avg:91.74ms +step:664/1660 train_time:60915ms step_avg:91.74ms +step:665/1660 train_time:61007ms step_avg:91.74ms +step:666/1660 train_time:61098ms step_avg:91.74ms +step:667/1660 train_time:61191ms step_avg:91.74ms +step:668/1660 train_time:61283ms step_avg:91.74ms +step:669/1660 train_time:61375ms step_avg:91.74ms +step:670/1660 train_time:61468ms step_avg:91.74ms +step:671/1660 train_time:61561ms step_avg:91.74ms +step:672/1660 train_time:61655ms step_avg:91.75ms +step:673/1660 train_time:61749ms step_avg:91.75ms +step:674/1660 train_time:61841ms step_avg:91.75ms +step:675/1660 train_time:61934ms step_avg:91.75ms +step:676/1660 train_time:62026ms step_avg:91.75ms +step:677/1660 train_time:62118ms step_avg:91.75ms +step:678/1660 train_time:62210ms step_avg:91.76ms +step:679/1660 train_time:62303ms step_avg:91.76ms +step:680/1660 train_time:62395ms step_avg:91.76ms +step:681/1660 train_time:62488ms step_avg:91.76ms +step:682/1660 train_time:62582ms step_avg:91.76ms +step:683/1660 train_time:62674ms step_avg:91.76ms +step:684/1660 train_time:62767ms step_avg:91.76ms +step:685/1660 train_time:62859ms step_avg:91.77ms +step:686/1660 train_time:62952ms step_avg:91.77ms +step:687/1660 train_time:63045ms step_avg:91.77ms +step:688/1660 train_time:63136ms step_avg:91.77ms +step:689/1660 train_time:63231ms step_avg:91.77ms +step:690/1660 train_time:63323ms step_avg:91.77ms +step:691/1660 train_time:63414ms step_avg:91.77ms +step:692/1660 train_time:63507ms step_avg:91.77ms +step:693/1660 train_time:63600ms step_avg:91.77ms +step:694/1660 train_time:63694ms step_avg:91.78ms +step:695/1660 train_time:63787ms step_avg:91.78ms +step:696/1660 train_time:63879ms step_avg:91.78ms +step:697/1660 train_time:63971ms step_avg:91.78ms +step:698/1660 train_time:64064ms step_avg:91.78ms +step:699/1660 train_time:64156ms step_avg:91.78ms +step:700/1660 train_time:64250ms step_avg:91.79ms +step:701/1660 train_time:64342ms step_avg:91.79ms 
+step:702/1660 train_time:64435ms step_avg:91.79ms +step:703/1660 train_time:64528ms step_avg:91.79ms +step:704/1660 train_time:64620ms step_avg:91.79ms +step:705/1660 train_time:64713ms step_avg:91.79ms +step:706/1660 train_time:64805ms step_avg:91.79ms +step:707/1660 train_time:64897ms step_avg:91.79ms +step:708/1660 train_time:64991ms step_avg:91.79ms +step:709/1660 train_time:65084ms step_avg:91.80ms +step:710/1660 train_time:65175ms step_avg:91.80ms +step:711/1660 train_time:65268ms step_avg:91.80ms +step:712/1660 train_time:65360ms step_avg:91.80ms +step:713/1660 train_time:65453ms step_avg:91.80ms +step:714/1660 train_time:65545ms step_avg:91.80ms +step:715/1660 train_time:65638ms step_avg:91.80ms +step:716/1660 train_time:65732ms step_avg:91.80ms +step:717/1660 train_time:65824ms step_avg:91.81ms +step:718/1660 train_time:65916ms step_avg:91.81ms +step:719/1660 train_time:66009ms step_avg:91.81ms +step:720/1660 train_time:66101ms step_avg:91.81ms +step:721/1660 train_time:66193ms step_avg:91.81ms +step:722/1660 train_time:66286ms step_avg:91.81ms +step:723/1660 train_time:66377ms step_avg:91.81ms +step:724/1660 train_time:66470ms step_avg:91.81ms +step:725/1660 train_time:66562ms step_avg:91.81ms +step:726/1660 train_time:66654ms step_avg:91.81ms +step:727/1660 train_time:66746ms step_avg:91.81ms +step:728/1660 train_time:66837ms step_avg:91.81ms +step:729/1660 train_time:66931ms step_avg:91.81ms +step:730/1660 train_time:67023ms step_avg:91.81ms +step:731/1660 train_time:67115ms step_avg:91.81ms +step:732/1660 train_time:67208ms step_avg:91.81ms +step:733/1660 train_time:67301ms step_avg:91.82ms +step:734/1660 train_time:67393ms step_avg:91.82ms +step:735/1660 train_time:67486ms step_avg:91.82ms +step:736/1660 train_time:67578ms step_avg:91.82ms +step:737/1660 train_time:67671ms step_avg:91.82ms +step:738/1660 train_time:67763ms step_avg:91.82ms +step:739/1660 train_time:67856ms step_avg:91.82ms +step:740/1660 train_time:67949ms step_avg:91.82ms +step:741/1660 train_time:68041ms step_avg:91.82ms +step:742/1660 train_time:68134ms step_avg:91.82ms +step:743/1660 train_time:68227ms step_avg:91.83ms +step:744/1660 train_time:68318ms step_avg:91.83ms +step:745/1660 train_time:68411ms step_avg:91.83ms +step:746/1660 train_time:68503ms step_avg:91.83ms +step:747/1660 train_time:68595ms step_avg:91.83ms +step:748/1660 train_time:68689ms step_avg:91.83ms +step:749/1660 train_time:68780ms step_avg:91.83ms +step:750/1660 train_time:68873ms step_avg:91.83ms +step:750/1660 val_loss:3.5652 train_time:68968ms step_avg:91.96ms +step:751/1660 train_time:68987ms step_avg:91.86ms +step:752/1660 train_time:69063ms step_avg:91.84ms +step:753/1660 train_time:69159ms step_avg:91.84ms +step:754/1660 train_time:69252ms step_avg:91.85ms +step:755/1660 train_time:69344ms step_avg:91.85ms +step:756/1660 train_time:69435ms step_avg:91.84ms +step:757/1660 train_time:69526ms step_avg:91.84ms +step:758/1660 train_time:69618ms step_avg:91.84ms +step:759/1660 train_time:69710ms step_avg:91.84ms +step:760/1660 train_time:69803ms step_avg:91.85ms +step:761/1660 train_time:69896ms step_avg:91.85ms +step:762/1660 train_time:69990ms step_avg:91.85ms +step:763/1660 train_time:70085ms step_avg:91.85ms +step:764/1660 train_time:70177ms step_avg:91.86ms +step:765/1660 train_time:70272ms step_avg:91.86ms +step:766/1660 train_time:70364ms step_avg:91.86ms +step:767/1660 train_time:70456ms step_avg:91.86ms +step:768/1660 train_time:70548ms step_avg:91.86ms +step:769/1660 train_time:70639ms step_avg:91.86ms +step:770/1660 
train_time:70732ms step_avg:91.86ms +step:771/1660 train_time:70825ms step_avg:91.86ms +step:772/1660 train_time:70917ms step_avg:91.86ms +step:773/1660 train_time:71011ms step_avg:91.86ms +step:774/1660 train_time:71105ms step_avg:91.87ms +step:775/1660 train_time:71198ms step_avg:91.87ms +step:776/1660 train_time:71291ms step_avg:91.87ms +step:777/1660 train_time:71383ms step_avg:91.87ms +step:778/1660 train_time:71475ms step_avg:91.87ms +step:779/1660 train_time:71568ms step_avg:91.87ms +step:780/1660 train_time:71660ms step_avg:91.87ms +step:781/1660 train_time:71752ms step_avg:91.87ms +step:782/1660 train_time:71845ms step_avg:91.87ms +step:783/1660 train_time:71938ms step_avg:91.87ms +step:784/1660 train_time:72032ms step_avg:91.88ms +step:785/1660 train_time:72125ms step_avg:91.88ms +step:786/1660 train_time:72218ms step_avg:91.88ms +step:787/1660 train_time:72311ms step_avg:91.88ms +step:788/1660 train_time:72404ms step_avg:91.88ms +step:789/1660 train_time:72496ms step_avg:91.88ms +step:790/1660 train_time:72589ms step_avg:91.88ms +step:791/1660 train_time:72681ms step_avg:91.89ms +step:792/1660 train_time:72772ms step_avg:91.88ms +step:793/1660 train_time:72866ms step_avg:91.89ms +step:794/1660 train_time:72958ms step_avg:91.89ms +step:795/1660 train_time:73051ms step_avg:91.89ms +step:796/1660 train_time:73145ms step_avg:91.89ms +step:797/1660 train_time:73237ms step_avg:91.89ms +step:798/1660 train_time:73331ms step_avg:91.89ms +step:799/1660 train_time:73424ms step_avg:91.89ms +step:800/1660 train_time:73515ms step_avg:91.89ms +step:801/1660 train_time:73607ms step_avg:91.89ms +step:802/1660 train_time:73700ms step_avg:91.89ms +step:803/1660 train_time:73791ms step_avg:91.89ms +step:804/1660 train_time:73884ms step_avg:91.90ms +step:805/1660 train_time:73976ms step_avg:91.90ms +step:806/1660 train_time:74070ms step_avg:91.90ms +step:807/1660 train_time:74162ms step_avg:91.90ms +step:808/1660 train_time:74254ms step_avg:91.90ms +step:809/1660 train_time:74348ms step_avg:91.90ms +step:810/1660 train_time:74442ms step_avg:91.90ms +step:811/1660 train_time:74534ms step_avg:91.90ms +step:812/1660 train_time:74627ms step_avg:91.90ms +step:813/1660 train_time:74719ms step_avg:91.91ms +step:814/1660 train_time:74811ms step_avg:91.90ms +step:815/1660 train_time:74903ms step_avg:91.91ms +step:816/1660 train_time:74995ms step_avg:91.91ms +step:817/1660 train_time:75088ms step_avg:91.91ms +step:818/1660 train_time:75181ms step_avg:91.91ms +step:819/1660 train_time:75274ms step_avg:91.91ms +step:820/1660 train_time:75368ms step_avg:91.91ms +step:821/1660 train_time:75461ms step_avg:91.91ms +step:822/1660 train_time:75553ms step_avg:91.91ms +step:823/1660 train_time:75647ms step_avg:91.92ms +step:824/1660 train_time:75739ms step_avg:91.92ms +step:825/1660 train_time:75831ms step_avg:91.92ms +step:826/1660 train_time:75923ms step_avg:91.92ms +step:827/1660 train_time:76015ms step_avg:91.92ms +step:828/1660 train_time:76107ms step_avg:91.92ms +step:829/1660 train_time:76199ms step_avg:91.92ms +step:830/1660 train_time:76291ms step_avg:91.92ms +step:831/1660 train_time:76384ms step_avg:91.92ms +step:832/1660 train_time:76476ms step_avg:91.92ms +step:833/1660 train_time:76569ms step_avg:91.92ms +step:834/1660 train_time:76662ms step_avg:91.92ms +step:835/1660 train_time:76754ms step_avg:91.92ms +step:836/1660 train_time:76847ms step_avg:91.92ms +step:837/1660 train_time:76940ms step_avg:91.92ms +step:838/1660 train_time:77032ms step_avg:91.92ms +step:839/1660 train_time:77124ms step_avg:91.92ms 
+step:840/1660 train_time:77216ms step_avg:91.92ms +step:841/1660 train_time:77310ms step_avg:91.93ms +step:842/1660 train_time:77403ms step_avg:91.93ms +step:843/1660 train_time:77494ms step_avg:91.93ms +step:844/1660 train_time:77587ms step_avg:91.93ms +step:845/1660 train_time:77680ms step_avg:91.93ms +step:846/1660 train_time:77773ms step_avg:91.93ms +step:847/1660 train_time:77866ms step_avg:91.93ms +step:848/1660 train_time:77960ms step_avg:91.93ms +step:849/1660 train_time:78052ms step_avg:91.93ms +step:850/1660 train_time:78144ms step_avg:91.93ms +step:851/1660 train_time:78235ms step_avg:91.93ms +step:852/1660 train_time:78329ms step_avg:91.94ms +step:853/1660 train_time:78423ms step_avg:91.94ms +step:854/1660 train_time:78514ms step_avg:91.94ms +step:855/1660 train_time:78608ms step_avg:91.94ms +step:856/1660 train_time:78700ms step_avg:91.94ms +step:857/1660 train_time:78793ms step_avg:91.94ms +step:858/1660 train_time:78887ms step_avg:91.94ms +step:859/1660 train_time:78980ms step_avg:91.94ms +step:860/1660 train_time:79073ms step_avg:91.94ms +step:861/1660 train_time:79166ms step_avg:91.95ms +step:862/1660 train_time:79257ms step_avg:91.95ms +step:863/1660 train_time:79351ms step_avg:91.95ms +step:864/1660 train_time:79444ms step_avg:91.95ms +step:865/1660 train_time:79536ms step_avg:91.95ms +step:866/1660 train_time:79630ms step_avg:91.95ms +step:867/1660 train_time:79723ms step_avg:91.95ms +step:868/1660 train_time:79815ms step_avg:91.95ms +step:869/1660 train_time:79908ms step_avg:91.95ms +step:870/1660 train_time:80000ms step_avg:91.95ms +step:871/1660 train_time:80092ms step_avg:91.95ms +step:872/1660 train_time:80184ms step_avg:91.95ms +step:873/1660 train_time:80276ms step_avg:91.95ms +step:874/1660 train_time:80369ms step_avg:91.96ms +step:875/1660 train_time:80462ms step_avg:91.96ms +step:875/1660 val_loss:3.5191 train_time:80555ms step_avg:92.06ms +step:876/1660 train_time:80575ms step_avg:91.98ms +step:877/1660 train_time:80650ms step_avg:91.96ms +step:878/1660 train_time:80751ms step_avg:91.97ms +step:879/1660 train_time:80844ms step_avg:91.97ms +step:880/1660 train_time:80935ms step_avg:91.97ms +step:881/1660 train_time:81027ms step_avg:91.97ms +step:882/1660 train_time:81118ms step_avg:91.97ms +step:883/1660 train_time:81209ms step_avg:91.97ms +step:884/1660 train_time:81301ms step_avg:91.97ms +step:885/1660 train_time:81392ms step_avg:91.97ms +step:886/1660 train_time:81485ms step_avg:91.97ms +step:887/1660 train_time:81579ms step_avg:91.97ms +step:888/1660 train_time:81674ms step_avg:91.98ms +step:889/1660 train_time:81770ms step_avg:91.98ms +step:890/1660 train_time:81864ms step_avg:91.98ms +step:891/1660 train_time:81956ms step_avg:91.98ms +step:892/1660 train_time:82048ms step_avg:91.98ms +step:893/1660 train_time:82140ms step_avg:91.98ms +step:894/1660 train_time:82231ms step_avg:91.98ms +step:895/1660 train_time:82323ms step_avg:91.98ms +step:896/1660 train_time:82414ms step_avg:91.98ms +step:897/1660 train_time:82508ms step_avg:91.98ms +step:898/1660 train_time:82602ms step_avg:91.98ms +step:899/1660 train_time:82695ms step_avg:91.99ms +step:900/1660 train_time:82788ms step_avg:91.99ms +step:901/1660 train_time:82882ms step_avg:91.99ms +step:902/1660 train_time:82974ms step_avg:91.99ms +step:903/1660 train_time:83067ms step_avg:91.99ms +step:904/1660 train_time:83159ms step_avg:91.99ms +step:905/1660 train_time:83251ms step_avg:91.99ms +step:906/1660 train_time:83342ms step_avg:91.99ms +step:907/1660 train_time:83434ms step_avg:91.99ms +step:908/1660 
train_time:83528ms step_avg:91.99ms +step:909/1660 train_time:83621ms step_avg:91.99ms +step:910/1660 train_time:83714ms step_avg:91.99ms +step:911/1660 train_time:83809ms step_avg:92.00ms +step:912/1660 train_time:83904ms step_avg:92.00ms +step:913/1660 train_time:83997ms step_avg:92.00ms +step:914/1660 train_time:84089ms step_avg:92.00ms +step:915/1660 train_time:84181ms step_avg:92.00ms +step:916/1660 train_time:84272ms step_avg:92.00ms +step:917/1660 train_time:84365ms step_avg:92.00ms +step:918/1660 train_time:84458ms step_avg:92.00ms +step:919/1660 train_time:84550ms step_avg:92.00ms +step:920/1660 train_time:84643ms step_avg:92.00ms +step:921/1660 train_time:84736ms step_avg:92.00ms +step:922/1660 train_time:84829ms step_avg:92.01ms +step:923/1660 train_time:84923ms step_avg:92.01ms +step:924/1660 train_time:85015ms step_avg:92.01ms +step:925/1660 train_time:85107ms step_avg:92.01ms +step:926/1660 train_time:85199ms step_avg:92.01ms +step:927/1660 train_time:85291ms step_avg:92.01ms +step:928/1660 train_time:85383ms step_avg:92.01ms +step:929/1660 train_time:85474ms step_avg:92.01ms +step:930/1660 train_time:85568ms step_avg:92.01ms +step:931/1660 train_time:85661ms step_avg:92.01ms +step:932/1660 train_time:85754ms step_avg:92.01ms +step:933/1660 train_time:85848ms step_avg:92.01ms +step:934/1660 train_time:85941ms step_avg:92.01ms +step:935/1660 train_time:86033ms step_avg:92.01ms +step:936/1660 train_time:86127ms step_avg:92.02ms +step:937/1660 train_time:86220ms step_avg:92.02ms +step:938/1660 train_time:86311ms step_avg:92.02ms +step:939/1660 train_time:86404ms step_avg:92.02ms +step:940/1660 train_time:86496ms step_avg:92.02ms +step:941/1660 train_time:86589ms step_avg:92.02ms +step:942/1660 train_time:86681ms step_avg:92.02ms +step:943/1660 train_time:86774ms step_avg:92.02ms +step:944/1660 train_time:86867ms step_avg:92.02ms +step:945/1660 train_time:86960ms step_avg:92.02ms +step:946/1660 train_time:87052ms step_avg:92.02ms +step:947/1660 train_time:87145ms step_avg:92.02ms +step:948/1660 train_time:87238ms step_avg:92.02ms +step:949/1660 train_time:87330ms step_avg:92.02ms +step:950/1660 train_time:87422ms step_avg:92.02ms +step:951/1660 train_time:87514ms step_avg:92.02ms +step:952/1660 train_time:87608ms step_avg:92.03ms +step:953/1660 train_time:87701ms step_avg:92.03ms +step:954/1660 train_time:87793ms step_avg:92.03ms +step:955/1660 train_time:87886ms step_avg:92.03ms +step:956/1660 train_time:87980ms step_avg:92.03ms +step:957/1660 train_time:88072ms step_avg:92.03ms +step:958/1660 train_time:88165ms step_avg:92.03ms +step:959/1660 train_time:88257ms step_avg:92.03ms +step:960/1660 train_time:88349ms step_avg:92.03ms +step:961/1660 train_time:88441ms step_avg:92.03ms +step:962/1660 train_time:88534ms step_avg:92.03ms +step:963/1660 train_time:88627ms step_avg:92.03ms +step:964/1660 train_time:88720ms step_avg:92.03ms +step:965/1660 train_time:88812ms step_avg:92.03ms +step:966/1660 train_time:88905ms step_avg:92.03ms +step:967/1660 train_time:88998ms step_avg:92.04ms +step:968/1660 train_time:89090ms step_avg:92.04ms +step:969/1660 train_time:89182ms step_avg:92.04ms +step:970/1660 train_time:89275ms step_avg:92.04ms +step:971/1660 train_time:89366ms step_avg:92.04ms +step:972/1660 train_time:89459ms step_avg:92.04ms +step:973/1660 train_time:89551ms step_avg:92.04ms +step:974/1660 train_time:89644ms step_avg:92.04ms +step:975/1660 train_time:89737ms step_avg:92.04ms +step:976/1660 train_time:89829ms step_avg:92.04ms +step:977/1660 train_time:89921ms step_avg:92.04ms 
+step:978/1660 train_time:90013ms step_avg:92.04ms +step:979/1660 train_time:90106ms step_avg:92.04ms +step:980/1660 train_time:90199ms step_avg:92.04ms +step:981/1660 train_time:90291ms step_avg:92.04ms +step:982/1660 train_time:90384ms step_avg:92.04ms +step:983/1660 train_time:90476ms step_avg:92.04ms +step:984/1660 train_time:90569ms step_avg:92.04ms +step:985/1660 train_time:90662ms step_avg:92.04ms +step:986/1660 train_time:90753ms step_avg:92.04ms +step:987/1660 train_time:90845ms step_avg:92.04ms +step:988/1660 train_time:90939ms step_avg:92.04ms +step:989/1660 train_time:91031ms step_avg:92.04ms +step:990/1660 train_time:91124ms step_avg:92.04ms +step:991/1660 train_time:91216ms step_avg:92.04ms +step:992/1660 train_time:91309ms step_avg:92.05ms +step:993/1660 train_time:91402ms step_avg:92.05ms +step:994/1660 train_time:91494ms step_avg:92.05ms +step:995/1660 train_time:91585ms step_avg:92.05ms +step:996/1660 train_time:91678ms step_avg:92.05ms +step:997/1660 train_time:91771ms step_avg:92.05ms +step:998/1660 train_time:91864ms step_avg:92.05ms +step:999/1660 train_time:91957ms step_avg:92.05ms +step:1000/1660 train_time:92050ms step_avg:92.05ms +step:1000/1660 val_loss:3.4692 train_time:92144ms step_avg:92.14ms +step:1001/1660 train_time:92164ms step_avg:92.07ms +step:1002/1660 train_time:92239ms step_avg:92.06ms +step:1003/1660 train_time:92335ms step_avg:92.06ms +step:1004/1660 train_time:92428ms step_avg:92.06ms +step:1005/1660 train_time:92520ms step_avg:92.06ms +step:1006/1660 train_time:92612ms step_avg:92.06ms +step:1007/1660 train_time:92703ms step_avg:92.06ms +step:1008/1660 train_time:92794ms step_avg:92.06ms +step:1009/1660 train_time:92885ms step_avg:92.06ms +step:1010/1660 train_time:92977ms step_avg:92.06ms +step:1011/1660 train_time:93069ms step_avg:92.06ms +step:1012/1660 train_time:93164ms step_avg:92.06ms +step:1013/1660 train_time:93258ms step_avg:92.06ms +step:1014/1660 train_time:93352ms step_avg:92.06ms +step:1015/1660 train_time:93448ms step_avg:92.07ms +step:1016/1660 train_time:93542ms step_avg:92.07ms +step:1017/1660 train_time:93634ms step_avg:92.07ms +step:1018/1660 train_time:93726ms step_avg:92.07ms +step:1019/1660 train_time:93818ms step_avg:92.07ms +step:1020/1660 train_time:93909ms step_avg:92.07ms +step:1021/1660 train_time:94001ms step_avg:92.07ms +step:1022/1660 train_time:94094ms step_avg:92.07ms +step:1023/1660 train_time:94188ms step_avg:92.07ms +step:1024/1660 train_time:94282ms step_avg:92.07ms +step:1025/1660 train_time:94375ms step_avg:92.07ms +step:1026/1660 train_time:94469ms step_avg:92.08ms +step:1027/1660 train_time:94562ms step_avg:92.08ms +step:1028/1660 train_time:94653ms step_avg:92.07ms +step:1029/1660 train_time:94746ms step_avg:92.08ms +step:1030/1660 train_time:94838ms step_avg:92.08ms +step:1031/1660 train_time:94929ms step_avg:92.07ms +step:1032/1660 train_time:95022ms step_avg:92.08ms +step:1033/1660 train_time:95114ms step_avg:92.08ms +step:1034/1660 train_time:95209ms step_avg:92.08ms +step:1035/1660 train_time:95303ms step_avg:92.08ms +step:1036/1660 train_time:95395ms step_avg:92.08ms +step:1037/1660 train_time:95488ms step_avg:92.08ms +step:1038/1660 train_time:95581ms step_avg:92.08ms +step:1039/1660 train_time:95673ms step_avg:92.08ms +step:1040/1660 train_time:95766ms step_avg:92.08ms +step:1041/1660 train_time:95858ms step_avg:92.08ms +step:1042/1660 train_time:95950ms step_avg:92.08ms +step:1043/1660 train_time:96043ms step_avg:92.08ms +step:1044/1660 train_time:96135ms step_avg:92.08ms +step:1045/1660 
train_time:96230ms step_avg:92.09ms +step:1046/1660 train_time:96323ms step_avg:92.09ms +step:1047/1660 train_time:96415ms step_avg:92.09ms +step:1048/1660 train_time:96508ms step_avg:92.09ms +step:1049/1660 train_time:96601ms step_avg:92.09ms +step:1050/1660 train_time:96693ms step_avg:92.09ms +step:1051/1660 train_time:96786ms step_avg:92.09ms +step:1052/1660 train_time:96878ms step_avg:92.09ms +step:1053/1660 train_time:96970ms step_avg:92.09ms +step:1054/1660 train_time:97062ms step_avg:92.09ms +step:1055/1660 train_time:97156ms step_avg:92.09ms +step:1056/1660 train_time:97250ms step_avg:92.09ms +step:1057/1660 train_time:97344ms step_avg:92.09ms +step:1058/1660 train_time:97437ms step_avg:92.10ms +step:1059/1660 train_time:97529ms step_avg:92.10ms +step:1060/1660 train_time:97621ms step_avg:92.10ms +step:1061/1660 train_time:97713ms step_avg:92.10ms +step:1062/1660 train_time:97808ms step_avg:92.10ms +step:1063/1660 train_time:97900ms step_avg:92.10ms +step:1064/1660 train_time:97992ms step_avg:92.10ms +step:1065/1660 train_time:98084ms step_avg:92.10ms +step:1066/1660 train_time:98177ms step_avg:92.10ms +step:1067/1660 train_time:98269ms step_avg:92.10ms +step:1068/1660 train_time:98362ms step_avg:92.10ms +step:1069/1660 train_time:98455ms step_avg:92.10ms +step:1070/1660 train_time:98548ms step_avg:92.10ms +step:1071/1660 train_time:98641ms step_avg:92.10ms +step:1072/1660 train_time:98733ms step_avg:92.10ms +step:1073/1660 train_time:98826ms step_avg:92.10ms +step:1074/1660 train_time:98919ms step_avg:92.10ms +step:1075/1660 train_time:99012ms step_avg:92.10ms +step:1076/1660 train_time:99104ms step_avg:92.10ms +step:1077/1660 train_time:99196ms step_avg:92.10ms +step:1078/1660 train_time:99288ms step_avg:92.10ms +step:1079/1660 train_time:99381ms step_avg:92.10ms +step:1080/1660 train_time:99474ms step_avg:92.11ms +step:1081/1660 train_time:99567ms step_avg:92.11ms +step:1082/1660 train_time:99660ms step_avg:92.11ms +step:1083/1660 train_time:99752ms step_avg:92.11ms +step:1084/1660 train_time:99846ms step_avg:92.11ms +step:1085/1660 train_time:99938ms step_avg:92.11ms +step:1086/1660 train_time:100031ms step_avg:92.11ms +step:1087/1660 train_time:100123ms step_avg:92.11ms +step:1088/1660 train_time:100215ms step_avg:92.11ms +step:1089/1660 train_time:100308ms step_avg:92.11ms +step:1090/1660 train_time:100401ms step_avg:92.11ms +step:1091/1660 train_time:100494ms step_avg:92.11ms +step:1092/1660 train_time:100587ms step_avg:92.11ms +step:1093/1660 train_time:100680ms step_avg:92.11ms +step:1094/1660 train_time:100772ms step_avg:92.11ms +step:1095/1660 train_time:100865ms step_avg:92.11ms +step:1096/1660 train_time:100958ms step_avg:92.12ms +step:1097/1660 train_time:101050ms step_avg:92.12ms +step:1098/1660 train_time:101143ms step_avg:92.12ms +step:1099/1660 train_time:101235ms step_avg:92.12ms +step:1100/1660 train_time:101328ms step_avg:92.12ms +step:1101/1660 train_time:101421ms step_avg:92.12ms +step:1102/1660 train_time:101513ms step_avg:92.12ms +step:1103/1660 train_time:101608ms step_avg:92.12ms +step:1104/1660 train_time:101701ms step_avg:92.12ms +step:1105/1660 train_time:101793ms step_avg:92.12ms +step:1106/1660 train_time:101886ms step_avg:92.12ms +step:1107/1660 train_time:101979ms step_avg:92.12ms +step:1108/1660 train_time:102072ms step_avg:92.12ms +step:1109/1660 train_time:102166ms step_avg:92.12ms +step:1110/1660 train_time:102259ms step_avg:92.13ms +step:1111/1660 train_time:102353ms step_avg:92.13ms +step:1112/1660 train_time:102447ms step_avg:92.13ms 
+step:1113/1660 train_time:102541ms step_avg:92.13ms +step:1114/1660 train_time:102633ms step_avg:92.13ms +step:1115/1660 train_time:102728ms step_avg:92.13ms +step:1116/1660 train_time:102822ms step_avg:92.13ms +step:1117/1660 train_time:102914ms step_avg:92.13ms +step:1118/1660 train_time:103008ms step_avg:92.14ms +step:1119/1660 train_time:103101ms step_avg:92.14ms +step:1120/1660 train_time:103194ms step_avg:92.14ms +step:1121/1660 train_time:103288ms step_avg:92.14ms +step:1122/1660 train_time:103381ms step_avg:92.14ms +step:1123/1660 train_time:103474ms step_avg:92.14ms +step:1124/1660 train_time:103568ms step_avg:92.14ms +step:1125/1660 train_time:103661ms step_avg:92.14ms +step:1125/1660 val_loss:3.4165 train_time:103755ms step_avg:92.23ms +step:1126/1660 train_time:103775ms step_avg:92.16ms +step:1127/1660 train_time:103851ms step_avg:92.15ms +step:1128/1660 train_time:103955ms step_avg:92.16ms +step:1129/1660 train_time:104050ms step_avg:92.16ms +step:1130/1660 train_time:104143ms step_avg:92.16ms +step:1131/1660 train_time:104235ms step_avg:92.16ms +step:1132/1660 train_time:104327ms step_avg:92.16ms +step:1133/1660 train_time:104420ms step_avg:92.16ms +step:1134/1660 train_time:104513ms step_avg:92.16ms +step:1135/1660 train_time:104605ms step_avg:92.16ms +step:1136/1660 train_time:104698ms step_avg:92.16ms +step:1137/1660 train_time:104792ms step_avg:92.17ms +step:1138/1660 train_time:104890ms step_avg:92.17ms +step:1139/1660 train_time:104986ms step_avg:92.17ms +step:1140/1660 train_time:105081ms step_avg:92.18ms +step:1141/1660 train_time:105173ms step_avg:92.18ms +step:1142/1660 train_time:105266ms step_avg:92.18ms +step:1143/1660 train_time:105358ms step_avg:92.18ms +step:1144/1660 train_time:105451ms step_avg:92.18ms +step:1145/1660 train_time:105544ms step_avg:92.18ms +step:1146/1660 train_time:105636ms step_avg:92.18ms +step:1147/1660 train_time:105729ms step_avg:92.18ms +step:1148/1660 train_time:105823ms step_avg:92.18ms +step:1149/1660 train_time:105917ms step_avg:92.18ms +step:1150/1660 train_time:106013ms step_avg:92.19ms +step:1151/1660 train_time:106108ms step_avg:92.19ms +step:1152/1660 train_time:106201ms step_avg:92.19ms +step:1153/1660 train_time:106293ms step_avg:92.19ms +step:1154/1660 train_time:106385ms step_avg:92.19ms +step:1155/1660 train_time:106477ms step_avg:92.19ms +step:1156/1660 train_time:106569ms step_avg:92.19ms +step:1157/1660 train_time:106662ms step_avg:92.19ms +step:1158/1660 train_time:106756ms step_avg:92.19ms +step:1159/1660 train_time:106851ms step_avg:92.19ms +step:1160/1660 train_time:106945ms step_avg:92.19ms +step:1161/1660 train_time:107039ms step_avg:92.20ms +step:1162/1660 train_time:107134ms step_avg:92.20ms +step:1163/1660 train_time:107227ms step_avg:92.20ms +step:1164/1660 train_time:107320ms step_avg:92.20ms +step:1165/1660 train_time:107412ms step_avg:92.20ms +step:1166/1660 train_time:107505ms step_avg:92.20ms +step:1167/1660 train_time:107597ms step_avg:92.20ms +step:1168/1660 train_time:107690ms step_avg:92.20ms +step:1169/1660 train_time:107783ms step_avg:92.20ms +step:1170/1660 train_time:107876ms step_avg:92.20ms +step:1171/1660 train_time:107970ms step_avg:92.20ms +step:1172/1660 train_time:108064ms step_avg:92.20ms +step:1173/1660 train_time:108156ms step_avg:92.20ms +step:1174/1660 train_time:108249ms step_avg:92.21ms +step:1175/1660 train_time:108342ms step_avg:92.21ms +step:1176/1660 train_time:108435ms step_avg:92.21ms +step:1177/1660 train_time:108528ms step_avg:92.21ms +step:1178/1660 train_time:108621ms 
step_avg:92.21ms +step:1179/1660 train_time:108714ms step_avg:92.21ms +step:1180/1660 train_time:108808ms step_avg:92.21ms +step:1181/1660 train_time:108903ms step_avg:92.21ms +step:1182/1660 train_time:108996ms step_avg:92.21ms +step:1183/1660 train_time:109089ms step_avg:92.21ms +step:1184/1660 train_time:109183ms step_avg:92.22ms +step:1185/1660 train_time:109276ms step_avg:92.22ms +step:1186/1660 train_time:109369ms step_avg:92.22ms +step:1187/1660 train_time:109462ms step_avg:92.22ms +step:1188/1660 train_time:109554ms step_avg:92.22ms +step:1189/1660 train_time:109648ms step_avg:92.22ms +step:1190/1660 train_time:109741ms step_avg:92.22ms +step:1191/1660 train_time:109834ms step_avg:92.22ms +step:1192/1660 train_time:109927ms step_avg:92.22ms +step:1193/1660 train_time:110022ms step_avg:92.22ms +step:1194/1660 train_time:110114ms step_avg:92.22ms +step:1195/1660 train_time:110208ms step_avg:92.22ms +step:1196/1660 train_time:110301ms step_avg:92.23ms +step:1197/1660 train_time:110394ms step_avg:92.23ms +step:1198/1660 train_time:110487ms step_avg:92.23ms +step:1199/1660 train_time:110580ms step_avg:92.23ms +step:1200/1660 train_time:110672ms step_avg:92.23ms +step:1201/1660 train_time:110766ms step_avg:92.23ms +step:1202/1660 train_time:110860ms step_avg:92.23ms +step:1203/1660 train_time:110952ms step_avg:92.23ms +step:1204/1660 train_time:111046ms step_avg:92.23ms +step:1205/1660 train_time:111140ms step_avg:92.23ms +step:1206/1660 train_time:111233ms step_avg:92.23ms +step:1207/1660 train_time:111327ms step_avg:92.23ms +step:1208/1660 train_time:111420ms step_avg:92.23ms +step:1209/1660 train_time:111512ms step_avg:92.24ms +step:1210/1660 train_time:111605ms step_avg:92.24ms +step:1211/1660 train_time:111699ms step_avg:92.24ms +step:1212/1660 train_time:111792ms step_avg:92.24ms +step:1213/1660 train_time:111886ms step_avg:92.24ms +step:1214/1660 train_time:111978ms step_avg:92.24ms +step:1215/1660 train_time:112071ms step_avg:92.24ms +step:1216/1660 train_time:112165ms step_avg:92.24ms +step:1217/1660 train_time:112259ms step_avg:92.24ms +step:1218/1660 train_time:112353ms step_avg:92.24ms +step:1219/1660 train_time:112446ms step_avg:92.24ms +step:1220/1660 train_time:112539ms step_avg:92.24ms +step:1221/1660 train_time:112632ms step_avg:92.25ms +step:1222/1660 train_time:112724ms step_avg:92.25ms +step:1223/1660 train_time:112817ms step_avg:92.25ms +step:1224/1660 train_time:112911ms step_avg:92.25ms +step:1225/1660 train_time:113004ms step_avg:92.25ms +step:1226/1660 train_time:113099ms step_avg:92.25ms +step:1227/1660 train_time:113191ms step_avg:92.25ms +step:1228/1660 train_time:113284ms step_avg:92.25ms +step:1229/1660 train_time:113376ms step_avg:92.25ms +step:1230/1660 train_time:113470ms step_avg:92.25ms +step:1231/1660 train_time:113564ms step_avg:92.25ms +step:1232/1660 train_time:113657ms step_avg:92.25ms +step:1233/1660 train_time:113751ms step_avg:92.26ms +step:1234/1660 train_time:113844ms step_avg:92.26ms +step:1235/1660 train_time:113936ms step_avg:92.26ms +step:1236/1660 train_time:114030ms step_avg:92.26ms +step:1237/1660 train_time:114124ms step_avg:92.26ms +step:1238/1660 train_time:114218ms step_avg:92.26ms +step:1239/1660 train_time:114310ms step_avg:92.26ms +step:1240/1660 train_time:114405ms step_avg:92.26ms +step:1241/1660 train_time:114498ms step_avg:92.26ms +step:1242/1660 train_time:114590ms step_avg:92.26ms +step:1243/1660 train_time:114684ms step_avg:92.26ms +step:1244/1660 train_time:114777ms step_avg:92.26ms +step:1245/1660 train_time:114871ms 
step_avg:92.27ms +step:1246/1660 train_time:114964ms step_avg:92.27ms +step:1247/1660 train_time:115057ms step_avg:92.27ms +step:1248/1660 train_time:115151ms step_avg:92.27ms +step:1249/1660 train_time:115244ms step_avg:92.27ms +step:1250/1660 train_time:115337ms step_avg:92.27ms +step:1250/1660 val_loss:3.3776 train_time:115432ms step_avg:92.35ms +step:1251/1660 train_time:115451ms step_avg:92.29ms +step:1252/1660 train_time:115528ms step_avg:92.27ms +step:1253/1660 train_time:115624ms step_avg:92.28ms +step:1254/1660 train_time:115716ms step_avg:92.28ms +step:1255/1660 train_time:115808ms step_avg:92.28ms +step:1256/1660 train_time:115900ms step_avg:92.28ms +step:1257/1660 train_time:115993ms step_avg:92.28ms +step:1258/1660 train_time:116086ms step_avg:92.28ms +step:1259/1660 train_time:116178ms step_avg:92.28ms +step:1260/1660 train_time:116272ms step_avg:92.28ms +step:1261/1660 train_time:116366ms step_avg:92.28ms +step:1262/1660 train_time:116460ms step_avg:92.28ms +step:1263/1660 train_time:116555ms step_avg:92.28ms +step:1264/1660 train_time:116649ms step_avg:92.29ms +step:1265/1660 train_time:116743ms step_avg:92.29ms +step:1266/1660 train_time:116835ms step_avg:92.29ms +step:1267/1660 train_time:116928ms step_avg:92.29ms +step:1268/1660 train_time:117020ms step_avg:92.29ms +step:1269/1660 train_time:117113ms step_avg:92.29ms +step:1270/1660 train_time:117207ms step_avg:92.29ms +step:1271/1660 train_time:117301ms step_avg:92.29ms +step:1272/1660 train_time:117396ms step_avg:92.29ms +step:1273/1660 train_time:117491ms step_avg:92.29ms +step:1274/1660 train_time:117585ms step_avg:92.30ms +step:1275/1660 train_time:117678ms step_avg:92.30ms +step:1276/1660 train_time:117773ms step_avg:92.30ms +step:1277/1660 train_time:117865ms step_avg:92.30ms +step:1278/1660 train_time:117957ms step_avg:92.30ms +step:1279/1660 train_time:118050ms step_avg:92.30ms +step:1280/1660 train_time:118143ms step_avg:92.30ms +step:1281/1660 train_time:118235ms step_avg:92.30ms +step:1282/1660 train_time:118330ms step_avg:92.30ms +step:1283/1660 train_time:118424ms step_avg:92.30ms +step:1284/1660 train_time:118517ms step_avg:92.30ms +step:1285/1660 train_time:118612ms step_avg:92.31ms +step:1286/1660 train_time:118707ms step_avg:92.31ms +step:1287/1660 train_time:118801ms step_avg:92.31ms +step:1288/1660 train_time:118894ms step_avg:92.31ms +step:1289/1660 train_time:118987ms step_avg:92.31ms +step:1290/1660 train_time:119079ms step_avg:92.31ms +step:1291/1660 train_time:119173ms step_avg:92.31ms +step:1292/1660 train_time:119267ms step_avg:92.31ms +step:1293/1660 train_time:119360ms step_avg:92.31ms +step:1294/1660 train_time:119454ms step_avg:92.31ms +step:1295/1660 train_time:119547ms step_avg:92.31ms +step:1296/1660 train_time:119641ms step_avg:92.32ms +step:1297/1660 train_time:119734ms step_avg:92.32ms +step:1298/1660 train_time:119828ms step_avg:92.32ms +step:1299/1660 train_time:119921ms step_avg:92.32ms +step:1300/1660 train_time:120014ms step_avg:92.32ms +step:1301/1660 train_time:120107ms step_avg:92.32ms +step:1302/1660 train_time:120200ms step_avg:92.32ms +step:1303/1660 train_time:120293ms step_avg:92.32ms +step:1304/1660 train_time:120386ms step_avg:92.32ms +step:1305/1660 train_time:120480ms step_avg:92.32ms +step:1306/1660 train_time:120576ms step_avg:92.32ms +step:1307/1660 train_time:120669ms step_avg:92.33ms +step:1308/1660 train_time:120763ms step_avg:92.33ms +step:1309/1660 train_time:120856ms step_avg:92.33ms +step:1310/1660 train_time:120949ms step_avg:92.33ms +step:1311/1660 
train_time:121042ms step_avg:92.33ms +step:1312/1660 train_time:121135ms step_avg:92.33ms +step:1313/1660 train_time:121229ms step_avg:92.33ms +step:1314/1660 train_time:121323ms step_avg:92.33ms +step:1315/1660 train_time:121415ms step_avg:92.33ms +step:1316/1660 train_time:121509ms step_avg:92.33ms +step:1317/1660 train_time:121602ms step_avg:92.33ms +step:1318/1660 train_time:121696ms step_avg:92.33ms +step:1319/1660 train_time:121790ms step_avg:92.33ms +step:1320/1660 train_time:121883ms step_avg:92.34ms +step:1321/1660 train_time:121976ms step_avg:92.34ms +step:1322/1660 train_time:122071ms step_avg:92.34ms +step:1323/1660 train_time:122165ms step_avg:92.34ms +step:1324/1660 train_time:122258ms step_avg:92.34ms +step:1325/1660 train_time:122352ms step_avg:92.34ms +step:1326/1660 train_time:122446ms step_avg:92.34ms +step:1327/1660 train_time:122539ms step_avg:92.34ms +step:1328/1660 train_time:122633ms step_avg:92.34ms +step:1329/1660 train_time:122726ms step_avg:92.34ms +step:1330/1660 train_time:122819ms step_avg:92.35ms +step:1331/1660 train_time:122912ms step_avg:92.35ms +step:1332/1660 train_time:123006ms step_avg:92.35ms +step:1333/1660 train_time:123099ms step_avg:92.35ms +step:1334/1660 train_time:123192ms step_avg:92.35ms +step:1335/1660 train_time:123286ms step_avg:92.35ms +step:1336/1660 train_time:123380ms step_avg:92.35ms +step:1337/1660 train_time:123475ms step_avg:92.35ms +step:1338/1660 train_time:123569ms step_avg:92.35ms +step:1339/1660 train_time:123662ms step_avg:92.35ms +step:1340/1660 train_time:123756ms step_avg:92.35ms +step:1341/1660 train_time:123849ms step_avg:92.36ms +step:1342/1660 train_time:123942ms step_avg:92.36ms +step:1343/1660 train_time:124035ms step_avg:92.36ms +step:1344/1660 train_time:124129ms step_avg:92.36ms +step:1345/1660 train_time:124222ms step_avg:92.36ms +step:1346/1660 train_time:124315ms step_avg:92.36ms +step:1347/1660 train_time:124409ms step_avg:92.36ms +step:1348/1660 train_time:124502ms step_avg:92.36ms +step:1349/1660 train_time:124596ms step_avg:92.36ms +step:1350/1660 train_time:124690ms step_avg:92.36ms +step:1351/1660 train_time:124783ms step_avg:92.36ms +step:1352/1660 train_time:124876ms step_avg:92.36ms +step:1353/1660 train_time:124970ms step_avg:92.37ms +step:1354/1660 train_time:125064ms step_avg:92.37ms +step:1355/1660 train_time:125156ms step_avg:92.37ms +step:1356/1660 train_time:125250ms step_avg:92.37ms +step:1357/1660 train_time:125343ms step_avg:92.37ms +step:1358/1660 train_time:125436ms step_avg:92.37ms +step:1359/1660 train_time:125530ms step_avg:92.37ms +step:1360/1660 train_time:125623ms step_avg:92.37ms +step:1361/1660 train_time:125716ms step_avg:92.37ms +step:1362/1660 train_time:125809ms step_avg:92.37ms +step:1363/1660 train_time:125902ms step_avg:92.37ms +step:1364/1660 train_time:125995ms step_avg:92.37ms +step:1365/1660 train_time:126089ms step_avg:92.37ms +step:1366/1660 train_time:126182ms step_avg:92.37ms +step:1367/1660 train_time:126276ms step_avg:92.37ms +step:1368/1660 train_time:126371ms step_avg:92.38ms +step:1369/1660 train_time:126464ms step_avg:92.38ms +step:1370/1660 train_time:126557ms step_avg:92.38ms +step:1371/1660 train_time:126652ms step_avg:92.38ms +step:1372/1660 train_time:126746ms step_avg:92.38ms +step:1373/1660 train_time:126839ms step_avg:92.38ms +step:1374/1660 train_time:126933ms step_avg:92.38ms +step:1375/1660 train_time:127026ms step_avg:92.38ms +step:1375/1660 val_loss:3.3434 train_time:127120ms step_avg:92.45ms +step:1376/1660 train_time:127142ms step_avg:92.40ms 
+step:1377/1660 train_time:127217ms step_avg:92.39ms +step:1378/1660 train_time:127314ms step_avg:92.39ms +step:1379/1660 train_time:127407ms step_avg:92.39ms +step:1380/1660 train_time:127499ms step_avg:92.39ms +step:1381/1660 train_time:127592ms step_avg:92.39ms +step:1382/1660 train_time:127685ms step_avg:92.39ms +step:1383/1660 train_time:127777ms step_avg:92.39ms +step:1384/1660 train_time:127870ms step_avg:92.39ms +step:1385/1660 train_time:127963ms step_avg:92.39ms +step:1386/1660 train_time:128057ms step_avg:92.39ms +step:1387/1660 train_time:128155ms step_avg:92.40ms +step:1388/1660 train_time:128250ms step_avg:92.40ms +step:1389/1660 train_time:128343ms step_avg:92.40ms +step:1390/1660 train_time:128438ms step_avg:92.40ms +step:1391/1660 train_time:128531ms step_avg:92.40ms +step:1392/1660 train_time:128624ms step_avg:92.40ms +step:1393/1660 train_time:128716ms step_avg:92.40ms +step:1394/1660 train_time:128809ms step_avg:92.40ms +step:1395/1660 train_time:128901ms step_avg:92.40ms +step:1396/1660 train_time:128994ms step_avg:92.40ms +step:1397/1660 train_time:129088ms step_avg:92.40ms +step:1398/1660 train_time:129183ms step_avg:92.41ms +step:1399/1660 train_time:129277ms step_avg:92.41ms +step:1400/1660 train_time:129372ms step_avg:92.41ms +step:1401/1660 train_time:129465ms step_avg:92.41ms +step:1402/1660 train_time:129558ms step_avg:92.41ms +step:1403/1660 train_time:129652ms step_avg:92.41ms +step:1404/1660 train_time:129743ms step_avg:92.41ms +step:1405/1660 train_time:129836ms step_avg:92.41ms +step:1406/1660 train_time:129929ms step_avg:92.41ms +step:1407/1660 train_time:130022ms step_avg:92.41ms +step:1408/1660 train_time:130116ms step_avg:92.41ms +step:1409/1660 train_time:130210ms step_avg:92.41ms +step:1410/1660 train_time:130305ms step_avg:92.41ms +step:1411/1660 train_time:130400ms step_avg:92.42ms +step:1412/1660 train_time:130494ms step_avg:92.42ms +step:1413/1660 train_time:130588ms step_avg:92.42ms +step:1414/1660 train_time:130680ms step_avg:92.42ms +step:1415/1660 train_time:130773ms step_avg:92.42ms +step:1416/1660 train_time:130866ms step_avg:92.42ms +step:1417/1660 train_time:130959ms step_avg:92.42ms +step:1418/1660 train_time:131053ms step_avg:92.42ms +step:1419/1660 train_time:131146ms step_avg:92.42ms +step:1420/1660 train_time:131239ms step_avg:92.42ms +step:1421/1660 train_time:131334ms step_avg:92.42ms +step:1422/1660 train_time:131430ms step_avg:92.43ms +step:1423/1660 train_time:131523ms step_avg:92.43ms +step:1424/1660 train_time:131616ms step_avg:92.43ms +step:1425/1660 train_time:131708ms step_avg:92.43ms +step:1426/1660 train_time:131801ms step_avg:92.43ms +step:1427/1660 train_time:131895ms step_avg:92.43ms +step:1428/1660 train_time:131988ms step_avg:92.43ms +step:1429/1660 train_time:132081ms step_avg:92.43ms +step:1430/1660 train_time:132175ms step_avg:92.43ms +step:1431/1660 train_time:132268ms step_avg:92.43ms +step:1432/1660 train_time:132362ms step_avg:92.43ms +step:1433/1660 train_time:132457ms step_avg:92.43ms +step:1434/1660 train_time:132551ms step_avg:92.43ms +step:1435/1660 train_time:132644ms step_avg:92.43ms +step:1436/1660 train_time:132737ms step_avg:92.43ms +step:1437/1660 train_time:132830ms step_avg:92.44ms +step:1438/1660 train_time:132922ms step_avg:92.44ms +step:1439/1660 train_time:133016ms step_avg:92.44ms +step:1440/1660 train_time:133109ms step_avg:92.44ms +step:1441/1660 train_time:133202ms step_avg:92.44ms +step:1442/1660 train_time:133297ms step_avg:92.44ms +step:1443/1660 train_time:133391ms step_avg:92.44ms 
+step:1444/1660 train_time:133485ms step_avg:92.44ms +step:1445/1660 train_time:133581ms step_avg:92.44ms +step:1446/1660 train_time:133675ms step_avg:92.44ms +step:1447/1660 train_time:133767ms step_avg:92.44ms +step:1448/1660 train_time:133861ms step_avg:92.45ms +step:1449/1660 train_time:133954ms step_avg:92.45ms +step:1450/1660 train_time:134048ms step_avg:92.45ms +step:1451/1660 train_time:134141ms step_avg:92.45ms +step:1452/1660 train_time:134234ms step_avg:92.45ms +step:1453/1660 train_time:134327ms step_avg:92.45ms +step:1454/1660 train_time:134422ms step_avg:92.45ms +step:1455/1660 train_time:134516ms step_avg:92.45ms +step:1456/1660 train_time:134609ms step_avg:92.45ms +step:1457/1660 train_time:134703ms step_avg:92.45ms +step:1458/1660 train_time:134797ms step_avg:92.45ms +step:1459/1660 train_time:134890ms step_avg:92.45ms +step:1460/1660 train_time:134983ms step_avg:92.45ms +step:1461/1660 train_time:135077ms step_avg:92.46ms +step:1462/1660 train_time:135171ms step_avg:92.46ms +step:1463/1660 train_time:135264ms step_avg:92.46ms +step:1464/1660 train_time:135359ms step_avg:92.46ms +step:1465/1660 train_time:135453ms step_avg:92.46ms +step:1466/1660 train_time:135545ms step_avg:92.46ms +step:1467/1660 train_time:135638ms step_avg:92.46ms +step:1468/1660 train_time:135732ms step_avg:92.46ms +step:1469/1660 train_time:135826ms step_avg:92.46ms +step:1470/1660 train_time:135919ms step_avg:92.46ms +step:1471/1660 train_time:136012ms step_avg:92.46ms +step:1472/1660 train_time:136106ms step_avg:92.46ms +step:1473/1660 train_time:136201ms step_avg:92.46ms +step:1474/1660 train_time:136294ms step_avg:92.47ms +step:1475/1660 train_time:136388ms step_avg:92.47ms +step:1476/1660 train_time:136482ms step_avg:92.47ms +step:1477/1660 train_time:136576ms step_avg:92.47ms +step:1478/1660 train_time:136669ms step_avg:92.47ms +step:1479/1660 train_time:136763ms step_avg:92.47ms +step:1480/1660 train_time:136857ms step_avg:92.47ms +step:1481/1660 train_time:136952ms step_avg:92.47ms +step:1482/1660 train_time:137044ms step_avg:92.47ms +step:1483/1660 train_time:137138ms step_avg:92.47ms +step:1484/1660 train_time:137233ms step_avg:92.47ms +step:1485/1660 train_time:137326ms step_avg:92.48ms +step:1486/1660 train_time:137420ms step_avg:92.48ms +step:1487/1660 train_time:137513ms step_avg:92.48ms +step:1488/1660 train_time:137606ms step_avg:92.48ms +step:1489/1660 train_time:137700ms step_avg:92.48ms +step:1490/1660 train_time:137793ms step_avg:92.48ms +step:1491/1660 train_time:137887ms step_avg:92.48ms +step:1492/1660 train_time:137981ms step_avg:92.48ms +step:1493/1660 train_time:138074ms step_avg:92.48ms +step:1494/1660 train_time:138168ms step_avg:92.48ms +step:1495/1660 train_time:138262ms step_avg:92.48ms +step:1496/1660 train_time:138355ms step_avg:92.48ms +step:1497/1660 train_time:138448ms step_avg:92.48ms +step:1498/1660 train_time:138541ms step_avg:92.48ms +step:1499/1660 train_time:138636ms step_avg:92.49ms +step:1500/1660 train_time:138729ms step_avg:92.49ms +step:1500/1660 val_loss:3.3136 train_time:138824ms step_avg:92.55ms +step:1501/1660 train_time:138843ms step_avg:92.50ms +step:1502/1660 train_time:138920ms step_avg:92.49ms +step:1503/1660 train_time:139016ms step_avg:92.49ms +step:1504/1660 train_time:139109ms step_avg:92.49ms +step:1505/1660 train_time:139202ms step_avg:92.49ms +step:1506/1660 train_time:139294ms step_avg:92.49ms +step:1507/1660 train_time:139386ms step_avg:92.49ms +step:1508/1660 train_time:139479ms step_avg:92.49ms +step:1509/1660 train_time:139572ms 
step_avg:92.49ms +step:1510/1660 train_time:139664ms step_avg:92.49ms +step:1511/1660 train_time:139759ms step_avg:92.49ms +step:1512/1660 train_time:139855ms step_avg:92.50ms +step:1513/1660 train_time:139951ms step_avg:92.50ms +step:1514/1660 train_time:140045ms step_avg:92.50ms +step:1515/1660 train_time:140138ms step_avg:92.50ms +step:1516/1660 train_time:140230ms step_avg:92.50ms +step:1517/1660 train_time:140323ms step_avg:92.50ms +step:1518/1660 train_time:140417ms step_avg:92.50ms +step:1519/1660 train_time:140510ms step_avg:92.50ms +step:1520/1660 train_time:140602ms step_avg:92.50ms +step:1521/1660 train_time:140695ms step_avg:92.50ms +step:1522/1660 train_time:140789ms step_avg:92.50ms +step:1523/1660 train_time:140883ms step_avg:92.50ms +step:1524/1660 train_time:140978ms step_avg:92.51ms +step:1525/1660 train_time:141073ms step_avg:92.51ms +step:1526/1660 train_time:141166ms step_avg:92.51ms +step:1527/1660 train_time:141260ms step_avg:92.51ms +step:1528/1660 train_time:141353ms step_avg:92.51ms +step:1529/1660 train_time:141446ms step_avg:92.51ms +step:1530/1660 train_time:141539ms step_avg:92.51ms +step:1531/1660 train_time:141632ms step_avg:92.51ms +step:1532/1660 train_time:141726ms step_avg:92.51ms +step:1533/1660 train_time:141820ms step_avg:92.51ms +step:1534/1660 train_time:141915ms step_avg:92.51ms +step:1535/1660 train_time:142009ms step_avg:92.51ms +step:1536/1660 train_time:142103ms step_avg:92.52ms +step:1537/1660 train_time:142196ms step_avg:92.52ms +step:1538/1660 train_time:142289ms step_avg:92.52ms +step:1539/1660 train_time:142382ms step_avg:92.52ms +step:1540/1660 train_time:142475ms step_avg:92.52ms +step:1541/1660 train_time:142568ms step_avg:92.52ms +step:1542/1660 train_time:142661ms step_avg:92.52ms +step:1543/1660 train_time:142755ms step_avg:92.52ms +step:1544/1660 train_time:142849ms step_avg:92.52ms +step:1545/1660 train_time:142945ms step_avg:92.52ms +step:1546/1660 train_time:143038ms step_avg:92.52ms +step:1547/1660 train_time:143132ms step_avg:92.52ms +step:1548/1660 train_time:143225ms step_avg:92.52ms +step:1549/1660 train_time:143318ms step_avg:92.52ms +step:1550/1660 train_time:143411ms step_avg:92.52ms +step:1551/1660 train_time:143505ms step_avg:92.52ms +step:1552/1660 train_time:143598ms step_avg:92.52ms +step:1553/1660 train_time:143691ms step_avg:92.53ms +step:1554/1660 train_time:143784ms step_avg:92.53ms +step:1555/1660 train_time:143879ms step_avg:92.53ms +step:1556/1660 train_time:143973ms step_avg:92.53ms +step:1557/1660 train_time:144067ms step_avg:92.53ms +step:1558/1660 train_time:144161ms step_avg:92.53ms +step:1559/1660 train_time:144256ms step_avg:92.53ms +step:1560/1660 train_time:144348ms step_avg:92.53ms +step:1561/1660 train_time:144441ms step_avg:92.53ms +step:1562/1660 train_time:144534ms step_avg:92.53ms +step:1563/1660 train_time:144628ms step_avg:92.53ms +step:1564/1660 train_time:144722ms step_avg:92.53ms +step:1565/1660 train_time:144816ms step_avg:92.53ms +step:1566/1660 train_time:144909ms step_avg:92.53ms +step:1567/1660 train_time:145003ms step_avg:92.54ms +step:1568/1660 train_time:145096ms step_avg:92.54ms +step:1569/1660 train_time:145189ms step_avg:92.54ms +step:1570/1660 train_time:145283ms step_avg:92.54ms +step:1571/1660 train_time:145376ms step_avg:92.54ms +step:1572/1660 train_time:145469ms step_avg:92.54ms +step:1573/1660 train_time:145562ms step_avg:92.54ms +step:1574/1660 train_time:145656ms step_avg:92.54ms +step:1575/1660 train_time:145750ms step_avg:92.54ms +step:1576/1660 train_time:145844ms 
step_avg:92.54ms +step:1577/1660 train_time:145937ms step_avg:92.54ms +step:1578/1660 train_time:146030ms step_avg:92.54ms +step:1579/1660 train_time:146124ms step_avg:92.54ms +step:1580/1660 train_time:146219ms step_avg:92.54ms +step:1581/1660 train_time:146313ms step_avg:92.54ms +step:1582/1660 train_time:146406ms step_avg:92.54ms +step:1583/1660 train_time:146499ms step_avg:92.54ms +step:1584/1660 train_time:146591ms step_avg:92.55ms +step:1585/1660 train_time:146685ms step_avg:92.55ms +step:1586/1660 train_time:146779ms step_avg:92.55ms +step:1587/1660 train_time:146872ms step_avg:92.55ms +step:1588/1660 train_time:146965ms step_avg:92.55ms +step:1589/1660 train_time:147060ms step_avg:92.55ms +step:1590/1660 train_time:147154ms step_avg:92.55ms +step:1591/1660 train_time:147247ms step_avg:92.55ms +step:1592/1660 train_time:147340ms step_avg:92.55ms +step:1593/1660 train_time:147433ms step_avg:92.55ms +step:1594/1660 train_time:147527ms step_avg:92.55ms +step:1595/1660 train_time:147620ms step_avg:92.55ms +step:1596/1660 train_time:147713ms step_avg:92.55ms +step:1597/1660 train_time:147807ms step_avg:92.55ms +step:1598/1660 train_time:147901ms step_avg:92.55ms +step:1599/1660 train_time:147994ms step_avg:92.55ms +step:1600/1660 train_time:148087ms step_avg:92.55ms +step:1601/1660 train_time:148182ms step_avg:92.56ms +step:1602/1660 train_time:148275ms step_avg:92.56ms +step:1603/1660 train_time:148369ms step_avg:92.56ms +step:1604/1660 train_time:148462ms step_avg:92.56ms +step:1605/1660 train_time:148556ms step_avg:92.56ms +step:1606/1660 train_time:148648ms step_avg:92.56ms +step:1607/1660 train_time:148742ms step_avg:92.56ms +step:1608/1660 train_time:148835ms step_avg:92.56ms +step:1609/1660 train_time:148929ms step_avg:92.56ms +step:1610/1660 train_time:149022ms step_avg:92.56ms +step:1611/1660 train_time:149117ms step_avg:92.56ms +step:1612/1660 train_time:149211ms step_avg:92.56ms +step:1613/1660 train_time:149304ms step_avg:92.56ms +step:1614/1660 train_time:149398ms step_avg:92.56ms +step:1615/1660 train_time:149491ms step_avg:92.56ms +step:1616/1660 train_time:149584ms step_avg:92.56ms +step:1617/1660 train_time:149678ms step_avg:92.57ms +step:1618/1660 train_time:149771ms step_avg:92.57ms +step:1619/1660 train_time:149864ms step_avg:92.57ms +step:1620/1660 train_time:149957ms step_avg:92.57ms +step:1621/1660 train_time:150050ms step_avg:92.57ms +step:1622/1660 train_time:150145ms step_avg:92.57ms +step:1623/1660 train_time:150240ms step_avg:92.57ms +step:1624/1660 train_time:150334ms step_avg:92.57ms +step:1625/1660 train_time:150427ms step_avg:92.57ms +step:1625/1660 val_loss:3.2889 train_time:150522ms step_avg:92.63ms +step:1626/1660 train_time:150541ms step_avg:92.58ms +step:1627/1660 train_time:150621ms step_avg:92.58ms +step:1628/1660 train_time:150718ms step_avg:92.58ms +step:1629/1660 train_time:150814ms step_avg:92.58ms +step:1630/1660 train_time:150906ms step_avg:92.58ms +step:1631/1660 train_time:150999ms step_avg:92.58ms +step:1632/1660 train_time:151092ms step_avg:92.58ms +step:1633/1660 train_time:151184ms step_avg:92.58ms +step:1634/1660 train_time:151277ms step_avg:92.58ms +step:1635/1660 train_time:151369ms step_avg:92.58ms +step:1636/1660 train_time:151461ms step_avg:92.58ms +step:1637/1660 train_time:151558ms step_avg:92.58ms +step:1638/1660 train_time:151657ms step_avg:92.59ms +step:1639/1660 train_time:151754ms step_avg:92.59ms +step:1640/1660 train_time:151847ms step_avg:92.59ms +step:1641/1660 train_time:151940ms step_avg:92.59ms +step:1642/1660 
train_time:152033ms step_avg:92.59ms +step:1643/1660 train_time:152126ms step_avg:92.59ms +step:1644/1660 train_time:152218ms step_avg:92.59ms +step:1645/1660 train_time:152310ms step_avg:92.59ms +step:1646/1660 train_time:152403ms step_avg:92.59ms +step:1647/1660 train_time:152497ms step_avg:92.59ms +step:1648/1660 train_time:152592ms step_avg:92.59ms +step:1649/1660 train_time:152687ms step_avg:92.59ms +step:1650/1660 train_time:152781ms step_avg:92.59ms +step:1651/1660 train_time:152874ms step_avg:92.59ms +step:1652/1660 train_time:152967ms step_avg:92.60ms +step:1653/1660 train_time:153059ms step_avg:92.59ms +step:1654/1660 train_time:153153ms step_avg:92.60ms +step:1655/1660 train_time:153246ms step_avg:92.60ms +step:1656/1660 train_time:153338ms step_avg:92.60ms +step:1657/1660 train_time:153431ms step_avg:92.60ms +step:1658/1660 train_time:153525ms step_avg:92.60ms +step:1659/1660 train_time:153619ms step_avg:92.60ms +step:1660/1660 train_time:153715ms step_avg:92.60ms +step:1660/1660 val_loss:3.2807 train_time:153811ms step_avg:92.66ms +peak memory allocated: 31587 MiB reserved: 47116 MiB diff --git a/train_gpt.py b/train_gpt.py index c4d8ab06c..2082bd9fa 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -6,6 +6,7 @@ import copy import glob import math +import threading import time import uuid from dataclasses import dataclass @@ -720,6 +721,40 @@ def forward(self, x: Tensor): else: return F.linear(x, self.weight.type_as(x)) +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): assert cos.size(0) >= x_BTHD.size(-3) @@ -732,15 +767,14 @@ def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): y2 = x1 * (-sin) + x2 * cos return torch.cat((y1, y2), 3) - @dataclass class AttnArgs: ve: torch.Tensor sa_lambdas: torch.Tensor seqlens: torch.Tensor bm_size: int - rotary_cos: torch.Tensor - rotary_sin: torch.Tensor + cos: torch.Tensor + sin: torch.Tensor attn_scale: float class CausalSelfAttention(nn.Module): @@ -768,13 +802,13 @@ def 
forward(self, x: Tensor, attn_args: AttnArgs): assert B == 1, "varlen sequences requires B == 1" assert T % 16 == 0 # unpack attention args - rotary_cos, rotary_sin = attn_args.rotary_cos, attn_args.rotary_sin + cos, sin = attn_args.cos, attn_args.sin ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = rotary(q, rotary_cos, rotary_sin), rotary(k, rotary_cos, rotary_sin) + q, k = rotary(q, cos, sin), rotary(k, cos, sin) if ve is not None: v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 else: # skip mid-layers token value embeddings by @YouJiacheng @@ -842,6 +876,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. use_fp8 = not os.environ.get("DISABLE_FP8", False) @@ -865,8 +900,6 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i ] ) ) - self.max_seq_len = max_seq_len - self.setup_yarn(head_dim) # set learning rates for param in self.embed.parameters(): param.lr_mul = 75. 
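The Yarn module introduced above replaces the old setup_yarn/apply_yarn pair (removed in the next hunk): whenever the window-size schedule grows the attention span, apply() rescales each rotary frequency by old_window/new_window, but only for dimensions completing fewer than beta rotations over the old span (dimensions past beta rotations are left untouched, with a linear blend in between), and multiplies attn_scale by 0.2*log(new_window/old_window) + 1. Below is a minimal standalone sketch of that interpolation, with block_size passed in explicitly instead of read from args; the names are illustrative, not the patch's code.

import math
import torch

def yarn_interpolate(angular_freq: torch.Tensor, block_size: int,
                     old_window: int, new_window: int,
                     alpha: float = 1.0, beta: float = 32.0) -> torch.Tensor:
    # rotations each frequency completes across the old attention span
    rotations = block_size * old_window * angular_freq / (2 * math.pi)
    scale = old_window / new_window
    # weight 0 -> fully rescale (slow dims), weight 1 -> keep as-is (fast dims)
    w = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
    return angular_freq * (scale + w * (1 - scale))

For the schedule in this patch (block_size=128, ws stepping 3 -> 7), the slow dimensions' frequencies are scaled by 3/7 (stretching their effective positions by 7/3) while the attention temperature rises by a factor of 0.2*ln(7/3) + 1, roughly 1.17.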
@@ -875,39 +908,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i self.lm_head.weight.lr_mul = 1.0 self.scalars.lr_mul = 5.0 - def setup_yarn(self, head_dim: int): - # store single copy of rotary tensors - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=head_dim//4, dtype=torch.float32) - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(head_dim//4)]) - t = torch.arange(self.max_seq_len, dtype=torch.float32) - theta = torch.outer(t, angular_freq) - self.rotary_cos = nn.Buffer( - theta.cos().to(torch.bfloat16), persistent=False - ) - self.rotary_sin = nn.Buffer( - theta.sin().to(torch.bfloat16), persistent=False - ) - self.angular_freq = angular_freq - - # scale attention factor f in attn=softmax(f*qk) logarithmically with window size @classiclarryd - windows = list(dict.fromkeys(list(args.ws_schedule) + [args.ws_validate])) - scale_factors = [0.2 * math.log(curr / prev) + 1 for prev, curr in zip(windows[:-1], windows[1:])] - # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - attn_scales = list(accumulate([0.1] + scale_factors, lambda acc, factor: acc * factor)) - self.attn_scales = dict(zip(windows, attn_scales)) - - def apply_yarn(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): - rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) - scaling_factor = old_window / new_window - interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) - self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) - t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) - theta = torch.outer(t, self.angular_freq) - self.rotary_cos.copy_(theta.cos()) - self.rotary_sin.copy_(theta.sin()) - - def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): assert input_seq.ndim == 1 ve = [value_embed(input_seq) for value_embed in self.value_embeds] @@ -916,12 +917,11 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in assert len(ve) == len(self.blocks) long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] assert len(bm_sizes) == len(self.blocks) - x = x0 = norm(self.embed(input_seq)[None]).to( - torch.bfloat16 - ) # use of norm here by @Grad62304977 + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 # U-net design by @brendanh0gan skip_connections = [] @@ -937,9 +937,9 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in sa_lambdas=sa_lambdas[i], seqlens=seqlens, bm_size=bm_sizes[i], - rotary_cos=self.rotary_cos, - rotary_sin=self.rotary_sin, - attn_scale=self.attn_scales[ws] + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale ) if i >= n: gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) @@ -974,14 +974,45 @@ def _load_data_shard(file: Path): class BOSFinder: # Helper for getting sequences that start at the beginning of 
documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1): + def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): # Precompute BOS positions once per shard + self.tokens=tokens self.size = tokens.numel() - self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + #t0 = time.perf_counter() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + #t1 = time.perf_counter() + #print(f'{t1-t0} slowload') self.i = 0 self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() n = len(self.bos_idx) starts = [[] for _ in range(self.world_size)] ends = [[] for _ in range(self.world_size)] @@ -1003,9 +1034,33 @@ def next_batch(self, num_tokens_local: int, max_seq_len: int): assert cur_len == num_tokens_local + 1 self.i = idx - + self.batch_iter+=1 return starts, ends +class DataPreloader: + def __init__(self, file_iter, world_size): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len rank = dist.get_rank() if dist.is_initialized() else 0 @@ -1019,8 +1074,14 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training tokens = _load_data_shard(next(file_iter)) - finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None - pos = 0 # for unaligned case + if align_to_bos: + #loading in a whole shard will be slow... + #BosFinder tracks its self.i index. I can kickoff with a smaller set. 
then have the + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case while True: num_tokens_local = num_tokens // world_size @@ -1032,8 +1093,10 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) except StopIteration: # This shard is exhausted, load the next one in the next loop iteration. - tokens = _load_data_shard(next(file_iter)) - finder = BOSFinder(tokens, world_size=world_size) + #tokens = _load_data_shard(next(file_iter)) + #finder = BOSFinder(tokens, world_size=world_size) + tokens, finder = preloader.get() + preloader.start() continue buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) @@ -1054,6 +1117,7 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] pos += num_tokens + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) _cum_lengths[0] = 0 _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths @@ -1070,7 +1134,7 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" num_tokens = new_num_tokens max_seq_len = new_max_seq_len - grad_accum_steps = new_grad_accum_steps + grad_accum_steps = new_grad_accum_steps # ----------------------------------------------------------------------------- @@ -1086,7 +1150,7 @@ class Hyperparameters: train_max_seq_len: int = 128 * 16 val_batch_size: int = 4 * 64 * 1024 * 8 # optimization - num_iterations: int = 1670 # number of iterations to run + num_iterations: int = 1660 # number of iterations to run cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate # evaluation and logging run_id: str = f"{uuid.uuid4()}" @@ -1096,6 +1160,7 @@ class Hyperparameters: block_size: int = 128 ws_schedule: tuple = (3, 7, 11) ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer shows no degradation with context length args = Hyperparameters() @@ -1191,11 +1256,11 @@ def get_lr(step: int): def get_ws(step: int): if step == args.num_iterations: - return args.ws_validate + return args.ws_validate, args.ws_validate_final_layer x = step / (1 + args.num_iterations) assert 0 <= x < 1 ws_idx = int(len(args.ws_schedule) * x) - return args.ws_schedule[ws_idx] + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) @@ -1208,13 +1273,21 @@ def get_ws(step: int): initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] for step in range(warmup_steps): inputs, targets, cum_seqlens = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, cum_seqlens, ws).backward() + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws Date: Tue, 
16 Sep 2025 18:37:25 -0700 Subject: [PATCH 12/14] cleanup --- .../25db37c7-2bab-4ef4-ae63-d593590ef823.txt | 0 .../26acd99c-9089-406e-8249-f0532e6c2a13.txt | 0 .../305f24ee-051f-41a0-939a-0fa26654712a.txt | 0 .../517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt | 0 .../584d668b-cc79-4dde-b5be-b911623bdb61.txt | 0 .../61705980-e239-4d86-9233-210200da7010.txt | 0 .../720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt | 0 .../93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt | 0 .../93480ae1-43e4-415b-813e-ba8b43ab899b.txt | 0 .../README.md | 2 +- .../a04db288-4cdd-4401-bdd1-444b26c53cd8.txt | 0 train_gpt.py | 20 +++++++------------ 12 files changed, 8 insertions(+), 14 deletions(-) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/25db37c7-2bab-4ef4-ae63-d593590ef823.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/26acd99c-9089-406e-8249-f0532e6c2a13.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/305f24ee-051f-41a0-939a-0fa26654712a.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/584d668b-cc79-4dde-b5be-b911623bdb61.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/61705980-e239-4d86-9233-210200da7010.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/93480ae1-43e4-415b-813e-ba8b43ab899b.txt (100%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/README.md (92%) rename records/{091525_ThreadingFinalWindow => 091525_AsyncDataLoadAttnFinalWindow}/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt (100%) diff --git a/records/091525_ThreadingFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt b/records/091525_AsyncDataLoadAttnFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt diff --git a/records/091525_ThreadingFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt b/records/091525_AsyncDataLoadAttnFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/26acd99c-9089-406e-8249-f0532e6c2a13.txt diff --git a/records/091525_ThreadingFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt b/records/091525_AsyncDataLoadAttnFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/305f24ee-051f-41a0-939a-0fa26654712a.txt diff --git a/records/091525_ThreadingFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt b/records/091525_AsyncDataLoadAttnFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt rename to 
records/091525_AsyncDataLoadAttnFinalWindow/517e4b6c-2cc5-46ae-82fe-55c86f28ac0e.txt diff --git a/records/091525_ThreadingFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt b/records/091525_AsyncDataLoadAttnFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/584d668b-cc79-4dde-b5be-b911623bdb61.txt diff --git a/records/091525_ThreadingFinalWindow/61705980-e239-4d86-9233-210200da7010.txt b/records/091525_AsyncDataLoadAttnFinalWindow/61705980-e239-4d86-9233-210200da7010.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/61705980-e239-4d86-9233-210200da7010.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/61705980-e239-4d86-9233-210200da7010.txt diff --git a/records/091525_ThreadingFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt b/records/091525_AsyncDataLoadAttnFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/720dfef5-25c6-4e17-81b5-975ecfdd4d81.txt diff --git a/records/091525_ThreadingFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt b/records/091525_AsyncDataLoadAttnFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/93252460-2bb6-4e87-bdf0-1d2ca99f48df.txt diff --git a/records/091525_ThreadingFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt b/records/091525_AsyncDataLoadAttnFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/93480ae1-43e4-415b-813e-ba8b43ab899b.txt diff --git a/records/091525_ThreadingFinalWindow/README.md b/records/091525_AsyncDataLoadAttnFinalWindow/README.md similarity index 92% rename from records/091525_ThreadingFinalWindow/README.md rename to records/091525_AsyncDataLoadAttnFinalWindow/README.md index 019a26bb6..1aa37076d 100644 --- a/records/091525_ThreadingFinalWindow/README.md +++ b/records/091525_AsyncDataLoadAttnFinalWindow/README.md @@ -1,4 +1,4 @@ -## New WR 153.9s: Apply threading to preload data and extend final layer attention window for validation +## New WR 153.9s: Asynchronously fetch and index data batches, extend final layer attention window for validation This PR builds on all recent WR improvements including PR #125 by @bernard24. 
This one adds: diff --git a/records/091525_ThreadingFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt b/records/091525_AsyncDataLoadAttnFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt similarity index 100% rename from records/091525_ThreadingFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt rename to records/091525_AsyncDataLoadAttnFinalWindow/a04db288-4cdd-4401-bdd1-444b26c53cd8.txt diff --git a/train_gpt.py b/train_gpt.py index 2082bd9fa..4b2935897 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -974,7 +974,7 @@ def _load_data_shard(file: Path): class BOSFinder: # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): # Precompute BOS positions once per shard self.tokens=tokens self.size = tokens.numel() @@ -986,10 +986,7 @@ def __init__(self, tokens: Tensor, world_size: int = 1, quickload=False): self.ready = threading.Event() self.start() else: - #t0 = time.perf_counter() self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - #t1 = time.perf_counter() - #print(f'{t1-t0} slowload') self.i = 0 self.world_size = world_size self.batch_iter = 0 @@ -1038,7 +1035,8 @@ def next_batch(self, num_tokens_local: int, max_seq_len: int): return starts, ends class DataPreloader: - def __init__(self, file_iter, world_size): + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): self.file_iter = file_iter self.world_size = world_size self.thread = None @@ -1075,13 +1073,11 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training tokens = _load_data_shard(next(file_iter)) if align_to_bos: - #loading in a whole shard will be slow... - #BosFinder tracks its self.i index. I can kickoff with a smaller set. then have the - finder = BOSFinder(tokens, world_size=world_size, quickload=True) - preloader = DataPreloader(file_iter, world_size) - preloader.start() + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() else: - pos = 0 # for unaligned case + pos = 0 # for unaligned case while True: num_tokens_local = num_tokens // world_size @@ -1093,8 +1089,6 @@ def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_l start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) except StopIteration: # This shard is exhausted, load the next one in the next loop iteration. 
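# [Editor's sketch, not part of this patch] The cleanup above keeps the core
# pattern from the earlier patch in this series: BOSFinder.quickload indexes
# only the first 4 million tokens so training can start immediately, while
# DataPreloader reads and indexes the next shard on a background thread, and
# get() blocks only if that work hasn't finished. A minimal standalone version
# of the same thread-plus-Event pattern follows; loader_fn is a generic
# stand-in for _load_data_shard/BOSFinder, and the usage names are hypothetical.
import threading

class AsyncLoader:
    def __init__(self, loader_fn):
        self.loader_fn = loader_fn
        self.result = None
        self.thread = None
        self.ready = threading.Event()

    def start(self, *args):
        self.ready.clear()
        self.thread = threading.Thread(target=self._run, args=args)
        self.thread.start()

    def _run(self, *args):
        self.result = self.loader_fn(*args)
        self.ready.set()

    def get(self):
        self.ready.wait()  # blocks only if the background load is still running
        self.thread.join()
        return self.result

# usage sketch: pre = AsyncLoader(load_shard); pre.start(path); ...; tokens = pre.get()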
- #tokens = _load_data_shard(next(file_iter)) - #finder = BOSFinder(tokens, world_size=world_size) tokens, finder = preloader.get() preloader.start() continue From d149ec4ac119fc600fbda007da7303878c9aae9b Mon Sep 17 00:00:00 2001 From: larry dial Date: Thu, 18 Sep 2025 12:10:22 -0700 Subject: [PATCH 13/14] add smear --- .../18a1e5c7-947e-479d-bc3a-a57a61a98fc9.txt | 3089 +++++++++++++++++ .../1cc585a6-ecd2-452c-905d-d5774079e6ff.txt | 3089 +++++++++++++++++ .../2bbcf732-7a1d-4bad-b992-aa018419033e.txt | 3089 +++++++++++++++++ .../36761e6e-19ee-414f-a43c-63729950dfe7.txt | 3089 +++++++++++++++++ .../4e8d8366-3db6-43fc-aba0-03be7d484dd3.txt | 3089 +++++++++++++++++ .../692ba682-7466-4b59-a31a-7e0adcb55b4b.txt | 3089 +++++++++++++++++ .../81535293-56a7-49f3-925c-569441d4f87c.txt | 3089 +++++++++++++++++ .../898a21a4-3cf7-4c32-a61b-c3427618ae7b.txt | 3089 +++++++++++++++++ records/091825_Smear/README.md | 35 + .../cc16404d-92d4-48c1-b9b0-b906360363b4.txt | 3089 +++++++++++++++++ .../fd1a1ca9-afff-4881-a686-8b56bad5901b.txt | 3089 +++++++++++++++++ train_gpt.py | 18 +- 12 files changed, 30939 insertions(+), 4 deletions(-) create mode 100644 records/091825_Smear/18a1e5c7-947e-479d-bc3a-a57a61a98fc9.txt create mode 100644 records/091825_Smear/1cc585a6-ecd2-452c-905d-d5774079e6ff.txt create mode 100644 records/091825_Smear/2bbcf732-7a1d-4bad-b992-aa018419033e.txt create mode 100644 records/091825_Smear/36761e6e-19ee-414f-a43c-63729950dfe7.txt create mode 100644 records/091825_Smear/4e8d8366-3db6-43fc-aba0-03be7d484dd3.txt create mode 100644 records/091825_Smear/692ba682-7466-4b59-a31a-7e0adcb55b4b.txt create mode 100644 records/091825_Smear/81535293-56a7-49f3-925c-569441d4f87c.txt create mode 100644 records/091825_Smear/898a21a4-3cf7-4c32-a61b-c3427618ae7b.txt create mode 100644 records/091825_Smear/README.md create mode 100644 records/091825_Smear/cc16404d-92d4-48c1-b9b0-b906360363b4.txt create mode 100644 records/091825_Smear/fd1a1ca9-afff-4881-a686-8b56bad5901b.txt diff --git a/records/091825_Smear/18a1e5c7-947e-479d-bc3a-a57a61a98fc9.txt b/records/091825_Smear/18a1e5c7-947e-479d-bc3a-a57a61a98fc9.txt new file mode 100644 index 000000000..a215a46ed --- /dev/null +++ b/records/091825_Smear/18a1e5c7-947e-479d-bc3a-a57a61a98fc9.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert 
x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + 
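# The swizzled mapping above is easier to see in plain Python. The sketch below is a
# rough reimplementation for illustration only (the helper name is made up, and
# tl.swizzle2d's exact traversal may differ in detail): tiles are visited in bands of
# GROUP_SIZE_M block-rows, so consecutive program ids reuse the same rows of A from
# L2 cache instead of streaming a whole block-row of the matrix at a time.
def grouped_tile_order_sketch(num_pid_m: int, num_pid_n: int, group_size_m: int):
    order = []
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // (group_size_m * num_pid_n)        # which band of block-rows
        first_pid_m = group_id * group_size_m
        rows = min(num_pid_m - first_pid_m, group_size_m)   # last band may be ragged
        pid_m = first_pid_m + (pid % rows)                  # walk down the band...
        pid_n = (pid % (group_size_m * num_pid_n)) // rows  # ...then across columns
        order.append((pid_m, pid_n))
    return order

# e.g. grouped_tile_order_sketch(4, 4, 2) begins (0,0), (1,0), (0,1), (1,1), ...:
# two block-rows stay hot in cache while the sweep moves across columns.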
+@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + 
GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # 
Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. 
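        # The step is a three-phase pipeline per parameter group:
        #   1. reduce_scatter(grads): each rank ends up with chunk_size averaged grads
        #   2. momentum + batched Newton-Schulz on the local chunk only,
        #      i.e. 1/world_size of the orthogonalization work per rank
        #   3. all_gather(updated params): every rank receives all updated weights
        # Both collectives are async, so communication overlaps with compute across groups.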
+ for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. 
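            # e.g. the merged QKVO weights are stored as (4, hdim, dim), so a stacked
            # chunk arrives here as (chunk_size, 4, hdim, dim); the branch below
            # flattens it to (chunk_size * 4, hdim, dim) so Newton-Schulz sees a plain
            # batch of matrices, then restores the original shape afterwards.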
+ + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + 
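                    # like exp_avg above, the second moment below is kept in bfloat16
                    # and only for this rank's slice of the parameter, so Adam state
                    # costs half the fp32 footprint, further divided by world_size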
state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert 
cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. 
flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
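        # e.g. next_multiple_of_n(50257, n=128) == 50304 (= 128 * 393); the 47 extra
        # logits are never used as targets and only pad the lm_head matmul to a
        # hardware-friendly tile size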
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + smear_lambda = self.scalars[5 * len(self.blocks)] + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = 
torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of bos_idx position {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = 
dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1645 # number of iterations to run + cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"smear/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer shows no degradation with context length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws < ws: + model.yarn.reset() # schedule wrapped back to the smallest window + ws = new_ws + model(inputs, targets, cum_seqlens, ws, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +model.yarn.reset() +del initial_state, train_loader + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +ws, ws_final_layer = get_ws(0) +for step in range(train_steps + 1): + last_step = (step == train_steps) + new_ws, ws_final_layer = get_ws(step) + if new_ws > ws: + model.yarn.apply(ws, new_ws) # advance YaRN when the window schedule steps up + ws = new_ws + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for 
group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 18 17:24:39 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 35C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: 
| +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:134ms step_avg:133.56ms +step:2/1645 train_time:152ms step_avg:76.25ms +step:3/1645 train_time:222ms step_avg:74.07ms +step:4/1645 train_time:312ms step_avg:78.03ms +step:5/1645 train_time:402ms step_avg:80.46ms +step:6/1645 train_time:493ms step_avg:82.20ms +step:7/1645 train_time:584ms step_avg:83.44ms +step:8/1645 train_time:675ms step_avg:84.39ms +step:9/1645 train_time:766ms step_avg:85.09ms +step:10/1645 train_time:857ms step_avg:85.67ms +step:11/1645 train_time:948ms step_avg:86.19ms +step:12/1645 train_time:1043ms step_avg:86.90ms +step:13/1645 train_time:1138ms step_avg:87.55ms +step:14/1645 train_time:1231ms step_avg:87.95ms +step:15/1645 train_time:1323ms step_avg:88.20ms +step:16/1645 train_time:1414ms step_avg:88.38ms +step:17/1645 train_time:1506ms step_avg:88.58ms +step:18/1645 train_time:1597ms step_avg:88.71ms +step:19/1645 train_time:1689ms step_avg:88.87ms +step:20/1645 train_time:1779ms step_avg:88.97ms +step:21/1645 train_time:1871ms step_avg:89.08ms +step:22/1645 train_time:1963ms step_avg:89.23ms +step:23/1645 train_time:2056ms step_avg:89.41ms +step:24/1645 train_time:2150ms step_avg:89.59ms +step:25/1645 train_time:2243ms step_avg:89.71ms +step:26/1645 train_time:2335ms step_avg:89.80ms +step:27/1645 train_time:2426ms step_avg:89.87ms +step:28/1645 train_time:2518ms step_avg:89.93ms +step:29/1645 train_time:2610ms step_avg:89.99ms +step:30/1645 train_time:2701ms step_avg:90.03ms +step:31/1645 train_time:2792ms step_avg:90.06ms +step:32/1645 train_time:2884ms step_avg:90.11ms +step:33/1645 train_time:2975ms step_avg:90.15ms +step:34/1645 train_time:3069ms step_avg:90.28ms +step:35/1645 train_time:3163ms step_avg:90.37ms +step:36/1645 train_time:3256ms step_avg:90.43ms +step:37/1645 train_time:3347ms step_avg:90.47ms +step:38/1645 train_time:3440ms step_avg:90.52ms +step:39/1645 train_time:3532ms step_avg:90.56ms +step:40/1645 train_time:3623ms step_avg:90.58ms +step:41/1645 train_time:3715ms step_avg:90.61ms +step:42/1645 train_time:3807ms step_avg:90.64ms +step:43/1645 train_time:3898ms step_avg:90.65ms +step:44/1645 train_time:3990ms step_avg:90.69ms +step:45/1645 train_time:4082ms step_avg:90.72ms +step:46/1645 train_time:4174ms step_avg:90.74ms +step:47/1645 train_time:4267ms step_avg:90.79ms +step:48/1645 train_time:4359ms step_avg:90.82ms +step:49/1645 train_time:4451ms step_avg:90.84ms +step:50/1645 train_time:4544ms step_avg:90.87ms +step:51/1645 train_time:4635ms step_avg:90.88ms +step:52/1645 train_time:4727ms step_avg:90.89ms +step:53/1645 train_time:4820ms step_avg:90.94ms +step:54/1645 train_time:4910ms step_avg:90.93ms +step:55/1645 train_time:5002ms step_avg:90.94ms +step:56/1645 train_time:5093ms step_avg:90.95ms +step:57/1645 train_time:5185ms step_avg:90.96ms +step:58/1645 train_time:5277ms step_avg:90.98ms +step:59/1645 train_time:5369ms step_avg:91.01ms +step:60/1645 train_time:5461ms step_avg:91.02ms +step:61/1645 train_time:5553ms step_avg:91.03ms +step:62/1645 train_time:5645ms step_avg:91.05ms +step:63/1645 train_time:5737ms step_avg:91.07ms +step:64/1645 train_time:5829ms step_avg:91.08ms +step:65/1645 
train_time:5922ms step_avg:91.11ms +step:66/1645 train_time:6013ms step_avg:91.10ms +step:67/1645 train_time:6104ms step_avg:91.10ms +step:68/1645 train_time:6195ms step_avg:91.11ms +step:69/1645 train_time:6287ms step_avg:91.12ms +step:70/1645 train_time:6379ms step_avg:91.13ms +step:71/1645 train_time:6471ms step_avg:91.15ms +step:72/1645 train_time:6564ms step_avg:91.16ms +step:73/1645 train_time:6655ms step_avg:91.17ms +step:74/1645 train_time:6747ms step_avg:91.18ms +step:75/1645 train_time:6840ms step_avg:91.20ms +step:76/1645 train_time:6932ms step_avg:91.20ms +step:77/1645 train_time:7023ms step_avg:91.21ms +step:78/1645 train_time:7114ms step_avg:91.21ms +step:79/1645 train_time:7206ms step_avg:91.21ms +step:80/1645 train_time:7298ms step_avg:91.22ms +step:81/1645 train_time:7390ms step_avg:91.24ms +step:82/1645 train_time:7481ms step_avg:91.24ms +step:83/1645 train_time:7574ms step_avg:91.25ms +step:84/1645 train_time:7667ms step_avg:91.27ms +step:85/1645 train_time:7758ms step_avg:91.28ms +step:86/1645 train_time:7850ms step_avg:91.28ms +step:87/1645 train_time:7943ms step_avg:91.29ms +step:88/1645 train_time:8035ms step_avg:91.30ms +step:89/1645 train_time:8126ms step_avg:91.31ms +step:90/1645 train_time:8217ms step_avg:91.30ms +step:91/1645 train_time:8309ms step_avg:91.31ms +step:92/1645 train_time:8401ms step_avg:91.31ms +step:93/1645 train_time:8493ms step_avg:91.32ms +step:94/1645 train_time:8584ms step_avg:91.32ms +step:95/1645 train_time:8675ms step_avg:91.32ms +step:96/1645 train_time:8769ms step_avg:91.34ms +step:97/1645 train_time:8862ms step_avg:91.36ms +step:98/1645 train_time:8954ms step_avg:91.37ms +step:99/1645 train_time:9045ms step_avg:91.37ms +step:100/1645 train_time:9137ms step_avg:91.37ms +step:101/1645 train_time:9229ms step_avg:91.38ms +step:102/1645 train_time:9321ms step_avg:91.38ms +step:103/1645 train_time:9412ms step_avg:91.38ms +step:104/1645 train_time:9503ms step_avg:91.38ms +step:105/1645 train_time:9594ms step_avg:91.38ms +step:106/1645 train_time:9687ms step_avg:91.39ms +step:107/1645 train_time:9778ms step_avg:91.38ms +step:108/1645 train_time:9872ms step_avg:91.40ms +step:109/1645 train_time:9963ms step_avg:91.41ms +step:110/1645 train_time:10055ms step_avg:91.41ms +step:111/1645 train_time:10147ms step_avg:91.41ms +step:112/1645 train_time:10240ms step_avg:91.43ms +step:113/1645 train_time:10331ms step_avg:91.43ms +step:114/1645 train_time:10423ms step_avg:91.43ms +step:115/1645 train_time:10515ms step_avg:91.43ms +step:116/1645 train_time:10607ms step_avg:91.44ms +step:117/1645 train_time:10699ms step_avg:91.44ms +step:118/1645 train_time:10791ms step_avg:91.45ms +step:119/1645 train_time:10882ms step_avg:91.45ms +step:120/1645 train_time:10974ms step_avg:91.45ms +step:121/1645 train_time:11066ms step_avg:91.46ms +step:122/1645 train_time:11158ms step_avg:91.46ms +step:123/1645 train_time:11251ms step_avg:91.47ms +step:124/1645 train_time:11342ms step_avg:91.47ms +step:125/1645 train_time:11433ms step_avg:91.47ms +step:125/1645 val_loss:4.3202 train_time:11525ms step_avg:92.20ms +step:126/1645 train_time:11547ms step_avg:91.64ms +step:127/1645 train_time:11621ms step_avg:91.50ms +step:128/1645 train_time:11720ms step_avg:91.57ms +step:129/1645 train_time:11814ms step_avg:91.58ms +step:130/1645 train_time:11906ms step_avg:91.59ms +step:131/1645 train_time:11997ms step_avg:91.58ms +step:132/1645 train_time:12087ms step_avg:91.57ms +step:133/1645 train_time:12179ms step_avg:91.57ms +step:134/1645 train_time:12269ms step_avg:91.56ms 
+step:135/1645 train_time:12360ms step_avg:91.56ms +step:136/1645 train_time:12451ms step_avg:91.55ms +step:137/1645 train_time:12544ms step_avg:91.56ms +step:138/1645 train_time:12638ms step_avg:91.58ms +step:139/1645 train_time:12732ms step_avg:91.59ms +step:140/1645 train_time:12825ms step_avg:91.61ms +step:141/1645 train_time:12918ms step_avg:91.61ms +step:142/1645 train_time:13009ms step_avg:91.61ms +step:143/1645 train_time:13100ms step_avg:91.61ms +step:144/1645 train_time:13191ms step_avg:91.60ms +step:145/1645 train_time:13282ms step_avg:91.60ms +step:146/1645 train_time:13373ms step_avg:91.60ms +step:147/1645 train_time:13465ms step_avg:91.60ms +step:148/1645 train_time:13557ms step_avg:91.60ms +step:149/1645 train_time:13649ms step_avg:91.61ms +step:150/1645 train_time:13742ms step_avg:91.61ms +step:151/1645 train_time:13835ms step_avg:91.62ms +step:152/1645 train_time:13927ms step_avg:91.62ms +step:153/1645 train_time:14018ms step_avg:91.62ms +step:154/1645 train_time:14109ms step_avg:91.62ms +step:155/1645 train_time:14200ms step_avg:91.61ms +step:156/1645 train_time:14291ms step_avg:91.61ms +step:157/1645 train_time:14382ms step_avg:91.60ms +step:158/1645 train_time:14474ms step_avg:91.61ms +step:159/1645 train_time:14566ms step_avg:91.61ms +step:160/1645 train_time:14658ms step_avg:91.61ms +step:161/1645 train_time:14750ms step_avg:91.62ms +step:162/1645 train_time:14843ms step_avg:91.63ms +step:163/1645 train_time:14936ms step_avg:91.63ms +step:164/1645 train_time:15027ms step_avg:91.63ms +step:165/1645 train_time:15118ms step_avg:91.62ms +step:166/1645 train_time:15209ms step_avg:91.62ms +step:167/1645 train_time:15301ms step_avg:91.62ms +step:168/1645 train_time:15392ms step_avg:91.62ms +step:169/1645 train_time:15484ms step_avg:91.62ms +step:170/1645 train_time:15576ms step_avg:91.62ms +step:171/1645 train_time:15668ms step_avg:91.63ms +step:172/1645 train_time:15760ms step_avg:91.63ms +step:173/1645 train_time:15852ms step_avg:91.63ms +step:174/1645 train_time:15945ms step_avg:91.64ms +step:175/1645 train_time:16037ms step_avg:91.64ms +step:176/1645 train_time:16128ms step_avg:91.64ms +step:177/1645 train_time:16219ms step_avg:91.63ms +step:178/1645 train_time:16310ms step_avg:91.63ms +step:179/1645 train_time:16401ms step_avg:91.63ms +step:180/1645 train_time:16493ms step_avg:91.63ms +step:181/1645 train_time:16585ms step_avg:91.63ms +step:182/1645 train_time:16677ms step_avg:91.63ms +step:183/1645 train_time:16769ms step_avg:91.64ms +step:184/1645 train_time:16862ms step_avg:91.64ms +step:185/1645 train_time:16953ms step_avg:91.64ms +step:186/1645 train_time:17045ms step_avg:91.64ms +step:187/1645 train_time:17137ms step_avg:91.64ms +step:188/1645 train_time:17228ms step_avg:91.64ms +step:189/1645 train_time:17319ms step_avg:91.63ms +step:190/1645 train_time:17410ms step_avg:91.63ms +step:191/1645 train_time:17502ms step_avg:91.63ms +step:192/1645 train_time:17594ms step_avg:91.64ms +step:193/1645 train_time:17687ms step_avg:91.64ms +step:194/1645 train_time:17778ms step_avg:91.64ms +step:195/1645 train_time:17870ms step_avg:91.64ms +step:196/1645 train_time:17962ms step_avg:91.64ms +step:197/1645 train_time:18055ms step_avg:91.65ms +step:198/1645 train_time:18146ms step_avg:91.65ms +step:199/1645 train_time:18238ms step_avg:91.65ms +step:200/1645 train_time:18329ms step_avg:91.64ms +step:201/1645 train_time:18420ms step_avg:91.64ms +step:202/1645 train_time:18511ms step_avg:91.64ms +step:203/1645 train_time:18604ms step_avg:91.65ms +step:204/1645 train_time:18695ms 
step_avg:91.64ms +step:205/1645 train_time:18788ms step_avg:91.65ms +step:206/1645 train_time:18880ms step_avg:91.65ms +step:207/1645 train_time:18972ms step_avg:91.65ms +step:208/1645 train_time:19063ms step_avg:91.65ms +step:209/1645 train_time:19156ms step_avg:91.65ms +step:210/1645 train_time:19247ms step_avg:91.65ms +step:211/1645 train_time:19339ms step_avg:91.65ms +step:212/1645 train_time:19429ms step_avg:91.65ms +step:213/1645 train_time:19520ms step_avg:91.65ms +step:214/1645 train_time:19612ms step_avg:91.64ms +step:215/1645 train_time:19704ms step_avg:91.65ms +step:216/1645 train_time:19797ms step_avg:91.65ms +step:217/1645 train_time:19888ms step_avg:91.65ms +step:218/1645 train_time:19980ms step_avg:91.65ms +step:219/1645 train_time:20071ms step_avg:91.65ms +step:220/1645 train_time:20163ms step_avg:91.65ms +step:221/1645 train_time:20255ms step_avg:91.65ms +step:222/1645 train_time:20346ms step_avg:91.65ms +step:223/1645 train_time:20437ms step_avg:91.65ms +step:224/1645 train_time:20528ms step_avg:91.64ms +step:225/1645 train_time:20619ms step_avg:91.64ms +step:226/1645 train_time:20711ms step_avg:91.64ms +step:227/1645 train_time:20805ms step_avg:91.65ms +step:228/1645 train_time:20895ms step_avg:91.64ms +step:229/1645 train_time:20987ms step_avg:91.64ms +step:230/1645 train_time:21079ms step_avg:91.65ms +step:231/1645 train_time:21170ms step_avg:91.64ms +step:232/1645 train_time:21262ms step_avg:91.65ms +step:233/1645 train_time:21353ms step_avg:91.65ms +step:234/1645 train_time:21445ms step_avg:91.65ms +step:235/1645 train_time:21537ms step_avg:91.65ms +step:236/1645 train_time:21628ms step_avg:91.64ms +step:237/1645 train_time:21719ms step_avg:91.64ms +step:238/1645 train_time:21811ms step_avg:91.64ms +step:239/1645 train_time:21902ms step_avg:91.64ms +step:240/1645 train_time:21994ms step_avg:91.64ms +step:241/1645 train_time:22087ms step_avg:91.65ms +step:242/1645 train_time:22179ms step_avg:91.65ms +step:243/1645 train_time:22271ms step_avg:91.65ms +step:244/1645 train_time:22362ms step_avg:91.65ms +step:245/1645 train_time:22454ms step_avg:91.65ms +step:246/1645 train_time:22545ms step_avg:91.65ms +step:247/1645 train_time:22637ms step_avg:91.65ms +step:248/1645 train_time:22728ms step_avg:91.65ms +step:249/1645 train_time:22819ms step_avg:91.64ms +step:250/1645 train_time:22911ms step_avg:91.65ms +step:250/1645 val_loss:3.9696 train_time:23002ms step_avg:92.01ms +step:251/1645 train_time:23023ms step_avg:91.73ms +step:252/1645 train_time:23102ms step_avg:91.67ms +step:253/1645 train_time:23195ms step_avg:91.68ms +step:254/1645 train_time:23286ms step_avg:91.68ms +step:255/1645 train_time:23377ms step_avg:91.68ms +step:256/1645 train_time:23469ms step_avg:91.67ms +step:257/1645 train_time:23559ms step_avg:91.67ms +step:258/1645 train_time:23649ms step_avg:91.66ms +step:259/1645 train_time:23740ms step_avg:91.66ms +step:260/1645 train_time:23831ms step_avg:91.66ms +step:261/1645 train_time:23923ms step_avg:91.66ms +step:262/1645 train_time:24017ms step_avg:91.67ms +step:263/1645 train_time:24112ms step_avg:91.68ms +step:264/1645 train_time:24204ms step_avg:91.68ms +step:265/1645 train_time:24297ms step_avg:91.69ms +step:266/1645 train_time:24388ms step_avg:91.68ms +step:267/1645 train_time:24480ms step_avg:91.68ms +step:268/1645 train_time:24570ms step_avg:91.68ms +step:269/1645 train_time:24661ms step_avg:91.68ms +step:270/1645 train_time:24751ms step_avg:91.67ms +step:271/1645 train_time:24841ms step_avg:91.67ms +step:272/1645 train_time:24933ms step_avg:91.66ms 
+step:273/1645 train_time:25026ms step_avg:91.67ms +step:274/1645 train_time:25120ms step_avg:91.68ms +step:275/1645 train_time:25212ms step_avg:91.68ms +step:276/1645 train_time:25304ms step_avg:91.68ms +step:277/1645 train_time:25395ms step_avg:91.68ms +step:278/1645 train_time:25487ms step_avg:91.68ms +step:279/1645 train_time:25579ms step_avg:91.68ms +step:280/1645 train_time:25668ms step_avg:91.67ms +step:281/1645 train_time:25760ms step_avg:91.67ms +step:282/1645 train_time:25851ms step_avg:91.67ms +step:283/1645 train_time:25943ms step_avg:91.67ms +step:284/1645 train_time:26035ms step_avg:91.67ms +step:285/1645 train_time:26126ms step_avg:91.67ms +step:286/1645 train_time:26220ms step_avg:91.68ms +step:287/1645 train_time:26312ms step_avg:91.68ms +step:288/1645 train_time:26404ms step_avg:91.68ms +step:289/1645 train_time:26495ms step_avg:91.68ms +step:290/1645 train_time:26587ms step_avg:91.68ms +step:291/1645 train_time:26677ms step_avg:91.67ms +step:292/1645 train_time:26768ms step_avg:91.67ms +step:293/1645 train_time:26860ms step_avg:91.67ms +step:294/1645 train_time:26951ms step_avg:91.67ms +step:295/1645 train_time:27043ms step_avg:91.67ms +step:296/1645 train_time:27135ms step_avg:91.67ms +step:297/1645 train_time:27227ms step_avg:91.67ms +step:298/1645 train_time:27319ms step_avg:91.67ms +step:299/1645 train_time:27411ms step_avg:91.68ms +step:300/1645 train_time:27503ms step_avg:91.68ms +step:301/1645 train_time:27595ms step_avg:91.68ms +step:302/1645 train_time:27686ms step_avg:91.68ms +step:303/1645 train_time:27777ms step_avg:91.67ms +step:304/1645 train_time:27868ms step_avg:91.67ms +step:305/1645 train_time:27960ms step_avg:91.67ms +step:306/1645 train_time:28052ms step_avg:91.67ms +step:307/1645 train_time:28144ms step_avg:91.67ms +step:308/1645 train_time:28237ms step_avg:91.68ms +step:309/1645 train_time:28329ms step_avg:91.68ms +step:310/1645 train_time:28421ms step_avg:91.68ms +step:311/1645 train_time:28512ms step_avg:91.68ms +step:312/1645 train_time:28603ms step_avg:91.68ms +step:313/1645 train_time:28695ms step_avg:91.68ms +step:314/1645 train_time:28786ms step_avg:91.67ms +step:315/1645 train_time:28877ms step_avg:91.67ms +step:316/1645 train_time:28968ms step_avg:91.67ms +step:317/1645 train_time:29060ms step_avg:91.67ms +step:318/1645 train_time:29152ms step_avg:91.67ms +step:319/1645 train_time:29244ms step_avg:91.67ms +step:320/1645 train_time:29337ms step_avg:91.68ms +step:321/1645 train_time:29429ms step_avg:91.68ms +step:322/1645 train_time:29520ms step_avg:91.68ms +step:323/1645 train_time:29612ms step_avg:91.68ms +step:324/1645 train_time:29703ms step_avg:91.68ms +step:325/1645 train_time:29795ms step_avg:91.68ms +step:326/1645 train_time:29886ms step_avg:91.68ms +step:327/1645 train_time:29977ms step_avg:91.67ms +step:328/1645 train_time:30069ms step_avg:91.67ms +step:329/1645 train_time:30161ms step_avg:91.67ms +step:330/1645 train_time:30253ms step_avg:91.67ms +step:331/1645 train_time:30345ms step_avg:91.68ms +step:332/1645 train_time:30438ms step_avg:91.68ms +step:333/1645 train_time:30529ms step_avg:91.68ms +step:334/1645 train_time:30621ms step_avg:91.68ms +step:335/1645 train_time:30712ms step_avg:91.68ms +step:336/1645 train_time:30805ms step_avg:91.68ms +step:337/1645 train_time:30895ms step_avg:91.68ms +step:338/1645 train_time:30986ms step_avg:91.68ms +step:339/1645 train_time:31079ms step_avg:91.68ms +step:340/1645 train_time:31171ms step_avg:91.68ms +step:341/1645 train_time:31262ms step_avg:91.68ms +step:342/1645 train_time:31354ms 
step_avg:91.68ms +step:343/1645 train_time:31446ms step_avg:91.68ms +step:344/1645 train_time:31538ms step_avg:91.68ms +step:345/1645 train_time:31630ms step_avg:91.68ms +step:346/1645 train_time:31722ms step_avg:91.68ms +step:347/1645 train_time:31814ms step_avg:91.68ms +step:348/1645 train_time:31906ms step_avg:91.68ms +step:349/1645 train_time:31998ms step_avg:91.68ms +step:350/1645 train_time:32089ms step_avg:91.68ms +step:351/1645 train_time:32180ms step_avg:91.68ms +step:352/1645 train_time:32272ms step_avg:91.68ms +step:353/1645 train_time:32364ms step_avg:91.68ms +step:354/1645 train_time:32455ms step_avg:91.68ms +step:355/1645 train_time:32546ms step_avg:91.68ms +step:356/1645 train_time:32639ms step_avg:91.68ms +step:357/1645 train_time:32731ms step_avg:91.68ms +step:358/1645 train_time:32823ms step_avg:91.68ms +step:359/1645 train_time:32914ms step_avg:91.68ms +step:360/1645 train_time:33005ms step_avg:91.68ms +step:361/1645 train_time:33096ms step_avg:91.68ms +step:362/1645 train_time:33188ms step_avg:91.68ms +step:363/1645 train_time:33280ms step_avg:91.68ms +step:364/1645 train_time:33371ms step_avg:91.68ms +step:365/1645 train_time:33463ms step_avg:91.68ms +step:366/1645 train_time:33554ms step_avg:91.68ms +step:367/1645 train_time:33645ms step_avg:91.68ms +step:368/1645 train_time:33739ms step_avg:91.68ms +step:369/1645 train_time:33830ms step_avg:91.68ms +step:370/1645 train_time:33922ms step_avg:91.68ms +step:371/1645 train_time:34016ms step_avg:91.69ms +step:372/1645 train_time:34105ms step_avg:91.68ms +step:373/1645 train_time:34198ms step_avg:91.68ms +step:374/1645 train_time:34290ms step_avg:91.68ms +step:375/1645 train_time:34381ms step_avg:91.68ms +step:375/1645 val_loss:3.8141 train_time:34473ms step_avg:91.93ms +step:376/1645 train_time:34493ms step_avg:91.74ms +step:377/1645 train_time:34567ms step_avg:91.69ms +step:378/1645 train_time:34661ms step_avg:91.69ms +step:379/1645 train_time:34752ms step_avg:91.70ms +step:380/1645 train_time:34843ms step_avg:91.69ms +step:381/1645 train_time:34934ms step_avg:91.69ms +step:382/1645 train_time:35024ms step_avg:91.69ms +step:383/1645 train_time:35115ms step_avg:91.68ms +step:384/1645 train_time:35205ms step_avg:91.68ms +step:385/1645 train_time:35296ms step_avg:91.68ms +step:386/1645 train_time:35389ms step_avg:91.68ms +step:387/1645 train_time:35483ms step_avg:91.69ms +step:388/1645 train_time:35578ms step_avg:91.69ms +step:389/1645 train_time:35670ms step_avg:91.70ms +step:390/1645 train_time:35761ms step_avg:91.70ms +step:391/1645 train_time:35853ms step_avg:91.69ms +step:392/1645 train_time:35944ms step_avg:91.69ms +step:393/1645 train_time:36035ms step_avg:91.69ms +step:394/1645 train_time:36125ms step_avg:91.69ms +step:395/1645 train_time:36216ms step_avg:91.69ms +step:396/1645 train_time:36307ms step_avg:91.68ms +step:397/1645 train_time:36399ms step_avg:91.68ms +step:398/1645 train_time:36493ms step_avg:91.69ms +step:399/1645 train_time:36587ms step_avg:91.70ms +step:400/1645 train_time:36679ms step_avg:91.70ms +step:401/1645 train_time:36772ms step_avg:91.70ms +step:402/1645 train_time:36863ms step_avg:91.70ms +step:403/1645 train_time:36955ms step_avg:91.70ms +step:404/1645 train_time:37045ms step_avg:91.70ms +step:405/1645 train_time:37136ms step_avg:91.69ms +step:406/1645 train_time:37226ms step_avg:91.69ms +step:407/1645 train_time:37318ms step_avg:91.69ms +step:408/1645 train_time:37409ms step_avg:91.69ms +step:409/1645 train_time:37501ms step_avg:91.69ms +step:410/1645 train_time:37595ms step_avg:91.69ms 
+step:411/1645 train_time:37687ms step_avg:91.70ms +step:412/1645 train_time:37779ms step_avg:91.70ms +step:413/1645 train_time:37872ms step_avg:91.70ms +step:414/1645 train_time:37963ms step_avg:91.70ms +step:415/1645 train_time:38054ms step_avg:91.70ms +step:416/1645 train_time:38145ms step_avg:91.70ms +step:417/1645 train_time:38236ms step_avg:91.69ms +step:418/1645 train_time:38326ms step_avg:91.69ms +step:419/1645 train_time:38418ms step_avg:91.69ms +step:420/1645 train_time:38510ms step_avg:91.69ms +step:421/1645 train_time:38602ms step_avg:91.69ms +step:422/1645 train_time:38695ms step_avg:91.69ms +step:423/1645 train_time:38787ms step_avg:91.70ms +step:424/1645 train_time:38878ms step_avg:91.69ms +step:425/1645 train_time:38971ms step_avg:91.70ms +step:426/1645 train_time:39062ms step_avg:91.70ms +step:427/1645 train_time:39154ms step_avg:91.69ms +step:428/1645 train_time:39245ms step_avg:91.69ms +step:429/1645 train_time:39336ms step_avg:91.69ms +step:430/1645 train_time:39427ms step_avg:91.69ms +step:431/1645 train_time:39518ms step_avg:91.69ms +step:432/1645 train_time:39611ms step_avg:91.69ms +step:433/1645 train_time:39702ms step_avg:91.69ms +step:434/1645 train_time:39794ms step_avg:91.69ms +step:435/1645 train_time:39886ms step_avg:91.69ms +step:436/1645 train_time:39978ms step_avg:91.69ms +step:437/1645 train_time:40070ms step_avg:91.69ms +step:438/1645 train_time:40162ms step_avg:91.69ms +step:439/1645 train_time:40253ms step_avg:91.69ms +step:440/1645 train_time:40345ms step_avg:91.69ms +step:441/1645 train_time:40436ms step_avg:91.69ms +step:442/1645 train_time:40527ms step_avg:91.69ms +step:443/1645 train_time:40618ms step_avg:91.69ms +step:444/1645 train_time:40711ms step_avg:91.69ms +step:445/1645 train_time:40803ms step_avg:91.69ms +step:446/1645 train_time:40894ms step_avg:91.69ms +step:447/1645 train_time:40986ms step_avg:91.69ms +step:448/1645 train_time:41077ms step_avg:91.69ms +step:449/1645 train_time:41170ms step_avg:91.69ms +step:450/1645 train_time:41262ms step_avg:91.69ms +step:451/1645 train_time:41353ms step_avg:91.69ms +step:452/1645 train_time:41445ms step_avg:91.69ms +step:453/1645 train_time:41536ms step_avg:91.69ms +step:454/1645 train_time:41627ms step_avg:91.69ms +step:455/1645 train_time:41718ms step_avg:91.69ms +step:456/1645 train_time:41810ms step_avg:91.69ms +step:457/1645 train_time:41903ms step_avg:91.69ms +step:458/1645 train_time:41994ms step_avg:91.69ms +step:459/1645 train_time:42087ms step_avg:91.69ms +step:460/1645 train_time:42178ms step_avg:91.69ms +step:461/1645 train_time:42271ms step_avg:91.69ms +step:462/1645 train_time:42362ms step_avg:91.69ms +step:463/1645 train_time:42455ms step_avg:91.69ms +step:464/1645 train_time:42546ms step_avg:91.69ms +step:465/1645 train_time:42637ms step_avg:91.69ms +step:466/1645 train_time:42728ms step_avg:91.69ms +step:467/1645 train_time:42820ms step_avg:91.69ms +step:468/1645 train_time:42911ms step_avg:91.69ms +step:469/1645 train_time:43003ms step_avg:91.69ms +step:470/1645 train_time:43095ms step_avg:91.69ms +step:471/1645 train_time:43186ms step_avg:91.69ms +step:472/1645 train_time:43278ms step_avg:91.69ms +step:473/1645 train_time:43370ms step_avg:91.69ms +step:474/1645 train_time:43462ms step_avg:91.69ms +step:475/1645 train_time:43554ms step_avg:91.69ms +step:476/1645 train_time:43645ms step_avg:91.69ms +step:477/1645 train_time:43737ms step_avg:91.69ms +step:478/1645 train_time:43828ms step_avg:91.69ms +step:479/1645 train_time:43920ms step_avg:91.69ms +step:480/1645 train_time:44013ms 
step_avg:91.69ms +step:481/1645 train_time:44104ms step_avg:91.69ms +step:482/1645 train_time:44196ms step_avg:91.69ms +step:483/1645 train_time:44289ms step_avg:91.70ms +step:484/1645 train_time:44381ms step_avg:91.70ms +step:485/1645 train_time:44473ms step_avg:91.70ms +step:486/1645 train_time:44564ms step_avg:91.70ms +step:487/1645 train_time:44656ms step_avg:91.70ms +step:488/1645 train_time:44748ms step_avg:91.70ms +step:489/1645 train_time:44839ms step_avg:91.70ms +step:490/1645 train_time:44931ms step_avg:91.69ms +step:491/1645 train_time:45023ms step_avg:91.70ms +step:492/1645 train_time:45114ms step_avg:91.70ms +step:493/1645 train_time:45207ms step_avg:91.70ms +step:494/1645 train_time:45298ms step_avg:91.70ms +step:495/1645 train_time:45391ms step_avg:91.70ms +step:496/1645 train_time:45482ms step_avg:91.70ms +step:497/1645 train_time:45574ms step_avg:91.70ms +step:498/1645 train_time:45666ms step_avg:91.70ms +step:499/1645 train_time:45757ms step_avg:91.70ms +step:500/1645 train_time:45848ms step_avg:91.70ms +step:500/1645 val_loss:3.7137 train_time:45941ms step_avg:91.88ms +step:501/1645 train_time:45961ms step_avg:91.74ms +step:502/1645 train_time:46036ms step_avg:91.70ms +step:503/1645 train_time:46131ms step_avg:91.71ms +step:504/1645 train_time:46223ms step_avg:91.71ms +step:505/1645 train_time:46313ms step_avg:91.71ms +step:506/1645 train_time:46404ms step_avg:91.71ms +step:507/1645 train_time:46494ms step_avg:91.70ms +step:508/1645 train_time:46585ms step_avg:91.70ms +step:509/1645 train_time:46677ms step_avg:91.70ms +step:510/1645 train_time:46768ms step_avg:91.70ms +step:511/1645 train_time:46861ms step_avg:91.70ms +step:512/1645 train_time:46954ms step_avg:91.71ms +step:513/1645 train_time:47047ms step_avg:91.71ms +step:514/1645 train_time:47139ms step_avg:91.71ms +step:515/1645 train_time:47232ms step_avg:91.71ms +step:516/1645 train_time:47323ms step_avg:91.71ms +step:517/1645 train_time:47414ms step_avg:91.71ms +step:518/1645 train_time:47505ms step_avg:91.71ms +step:519/1645 train_time:47595ms step_avg:91.71ms +step:520/1645 train_time:47687ms step_avg:91.71ms +step:521/1645 train_time:47778ms step_avg:91.70ms +step:522/1645 train_time:47870ms step_avg:91.70ms +step:523/1645 train_time:47962ms step_avg:91.70ms +step:524/1645 train_time:48054ms step_avg:91.71ms +step:525/1645 train_time:48147ms step_avg:91.71ms +step:526/1645 train_time:48240ms step_avg:91.71ms +step:527/1645 train_time:48331ms step_avg:91.71ms +step:528/1645 train_time:48423ms step_avg:91.71ms +step:529/1645 train_time:48513ms step_avg:91.71ms +step:530/1645 train_time:48605ms step_avg:91.71ms +step:531/1645 train_time:48696ms step_avg:91.71ms +step:532/1645 train_time:48787ms step_avg:91.71ms +step:533/1645 train_time:48878ms step_avg:91.70ms +step:534/1645 train_time:48970ms step_avg:91.70ms +step:535/1645 train_time:49062ms step_avg:91.70ms +step:536/1645 train_time:49154ms step_avg:91.71ms +step:537/1645 train_time:49248ms step_avg:91.71ms +step:538/1645 train_time:49340ms step_avg:91.71ms +step:539/1645 train_time:49432ms step_avg:91.71ms +step:540/1645 train_time:49524ms step_avg:91.71ms +step:541/1645 train_time:49615ms step_avg:91.71ms +step:542/1645 train_time:49707ms step_avg:91.71ms +step:543/1645 train_time:49799ms step_avg:91.71ms +step:544/1645 train_time:49889ms step_avg:91.71ms +step:545/1645 train_time:49980ms step_avg:91.71ms +step:546/1645 train_time:50072ms step_avg:91.71ms +step:547/1645 train_time:50165ms step_avg:91.71ms +step:548/1645 train_time:50257ms step_avg:91.71ms 
+step:549/1645 train_time:50350ms step_avg:91.71ms +step:550/1645 train_time:50443ms step_avg:91.71ms +step:551/1645 train_time:50535ms step_avg:91.71ms +step:552/1645 train_time:50628ms step_avg:91.72ms +step:553/1645 train_time:50721ms step_avg:91.72ms +step:554/1645 train_time:50813ms step_avg:91.72ms +step:555/1645 train_time:50906ms step_avg:91.72ms +step:556/1645 train_time:50998ms step_avg:91.72ms +step:557/1645 train_time:51091ms step_avg:91.73ms +step:558/1645 train_time:51184ms step_avg:91.73ms +step:559/1645 train_time:51277ms step_avg:91.73ms +step:560/1645 train_time:51371ms step_avg:91.73ms +step:561/1645 train_time:51464ms step_avg:91.74ms +step:562/1645 train_time:51557ms step_avg:91.74ms +step:563/1645 train_time:51650ms step_avg:91.74ms +step:564/1645 train_time:51742ms step_avg:91.74ms +step:565/1645 train_time:51835ms step_avg:91.74ms +step:566/1645 train_time:51928ms step_avg:91.74ms +step:567/1645 train_time:52020ms step_avg:91.75ms +step:568/1645 train_time:52112ms step_avg:91.75ms +step:569/1645 train_time:52207ms step_avg:91.75ms +step:570/1645 train_time:52300ms step_avg:91.75ms +step:571/1645 train_time:52393ms step_avg:91.76ms +step:572/1645 train_time:52486ms step_avg:91.76ms +step:573/1645 train_time:52579ms step_avg:91.76ms +step:574/1645 train_time:52673ms step_avg:91.76ms +step:575/1645 train_time:52767ms step_avg:91.77ms +step:576/1645 train_time:52859ms step_avg:91.77ms +step:577/1645 train_time:52953ms step_avg:91.77ms +step:578/1645 train_time:53045ms step_avg:91.77ms +step:579/1645 train_time:53138ms step_avg:91.78ms +step:580/1645 train_time:53232ms step_avg:91.78ms +step:581/1645 train_time:53325ms step_avg:91.78ms +step:582/1645 train_time:53418ms step_avg:91.78ms +step:583/1645 train_time:53512ms step_avg:91.79ms +step:584/1645 train_time:53604ms step_avg:91.79ms +step:585/1645 train_time:53698ms step_avg:91.79ms +step:586/1645 train_time:53791ms step_avg:91.79ms +step:587/1645 train_time:53884ms step_avg:91.80ms +step:588/1645 train_time:53976ms step_avg:91.80ms +step:589/1645 train_time:54069ms step_avg:91.80ms +step:590/1645 train_time:54161ms step_avg:91.80ms +step:591/1645 train_time:54255ms step_avg:91.80ms +step:592/1645 train_time:54347ms step_avg:91.80ms +step:593/1645 train_time:54441ms step_avg:91.81ms +step:594/1645 train_time:54534ms step_avg:91.81ms +step:595/1645 train_time:54627ms step_avg:91.81ms +step:596/1645 train_time:54720ms step_avg:91.81ms +step:597/1645 train_time:54813ms step_avg:91.81ms +step:598/1645 train_time:54906ms step_avg:91.82ms +step:599/1645 train_time:54999ms step_avg:91.82ms +step:600/1645 train_time:55092ms step_avg:91.82ms +step:601/1645 train_time:55185ms step_avg:91.82ms +step:602/1645 train_time:55278ms step_avg:91.82ms +step:603/1645 train_time:55371ms step_avg:91.83ms +step:604/1645 train_time:55464ms step_avg:91.83ms +step:605/1645 train_time:55557ms step_avg:91.83ms +step:606/1645 train_time:55650ms step_avg:91.83ms +step:607/1645 train_time:55743ms step_avg:91.83ms +step:608/1645 train_time:55836ms step_avg:91.84ms +step:609/1645 train_time:55930ms step_avg:91.84ms +step:610/1645 train_time:56022ms step_avg:91.84ms +step:611/1645 train_time:56114ms step_avg:91.84ms +step:612/1645 train_time:56207ms step_avg:91.84ms +step:613/1645 train_time:56301ms step_avg:91.84ms +step:614/1645 train_time:56392ms step_avg:91.84ms +step:615/1645 train_time:56485ms step_avg:91.85ms +step:616/1645 train_time:56578ms step_avg:91.85ms +step:617/1645 train_time:56671ms step_avg:91.85ms +step:618/1645 train_time:56764ms 
step_avg:91.85ms +step:619/1645 train_time:56856ms step_avg:91.85ms +step:620/1645 train_time:56951ms step_avg:91.86ms +step:621/1645 train_time:57044ms step_avg:91.86ms +step:622/1645 train_time:57136ms step_avg:91.86ms +step:623/1645 train_time:57231ms step_avg:91.86ms +step:624/1645 train_time:57323ms step_avg:91.86ms +step:625/1645 train_time:57415ms step_avg:91.86ms +step:625/1645 val_loss:3.6114 train_time:57508ms step_avg:92.01ms +step:626/1645 train_time:57530ms step_avg:91.90ms +step:627/1645 train_time:57608ms step_avg:91.88ms +step:628/1645 train_time:57705ms step_avg:91.89ms +step:629/1645 train_time:57798ms step_avg:91.89ms +step:630/1645 train_time:57890ms step_avg:91.89ms +step:631/1645 train_time:57982ms step_avg:91.89ms +step:632/1645 train_time:58074ms step_avg:91.89ms +step:633/1645 train_time:58165ms step_avg:91.89ms +step:634/1645 train_time:58257ms step_avg:91.89ms +step:635/1645 train_time:58348ms step_avg:91.89ms +step:636/1645 train_time:58443ms step_avg:91.89ms +step:637/1645 train_time:58540ms step_avg:91.90ms +step:638/1645 train_time:58635ms step_avg:91.90ms +step:639/1645 train_time:58729ms step_avg:91.91ms +step:640/1645 train_time:58823ms step_avg:91.91ms +step:641/1645 train_time:58915ms step_avg:91.91ms +step:642/1645 train_time:59007ms step_avg:91.91ms +step:643/1645 train_time:59100ms step_avg:91.91ms +step:644/1645 train_time:59193ms step_avg:91.91ms +step:645/1645 train_time:59284ms step_avg:91.91ms +step:646/1645 train_time:59376ms step_avg:91.91ms +step:647/1645 train_time:59469ms step_avg:91.92ms +step:648/1645 train_time:59565ms step_avg:91.92ms +step:649/1645 train_time:59660ms step_avg:91.93ms +step:650/1645 train_time:59753ms step_avg:91.93ms +step:651/1645 train_time:59846ms step_avg:91.93ms +step:652/1645 train_time:59939ms step_avg:91.93ms +step:653/1645 train_time:60032ms step_avg:91.93ms +step:654/1645 train_time:60125ms step_avg:91.93ms +step:655/1645 train_time:60216ms step_avg:91.93ms +step:656/1645 train_time:60309ms step_avg:91.93ms +step:657/1645 train_time:60402ms step_avg:91.94ms +step:658/1645 train_time:60496ms step_avg:91.94ms +step:659/1645 train_time:60589ms step_avg:91.94ms +step:660/1645 train_time:60683ms step_avg:91.94ms +step:661/1645 train_time:60776ms step_avg:91.95ms +step:662/1645 train_time:60869ms step_avg:91.95ms +step:663/1645 train_time:60962ms step_avg:91.95ms +step:664/1645 train_time:61054ms step_avg:91.95ms +step:665/1645 train_time:61148ms step_avg:91.95ms +step:666/1645 train_time:61240ms step_avg:91.95ms +step:667/1645 train_time:61333ms step_avg:91.95ms +step:668/1645 train_time:61426ms step_avg:91.96ms +step:669/1645 train_time:61519ms step_avg:91.96ms +step:670/1645 train_time:61613ms step_avg:91.96ms +step:671/1645 train_time:61706ms step_avg:91.96ms +step:672/1645 train_time:61800ms step_avg:91.96ms +step:673/1645 train_time:61894ms step_avg:91.97ms +step:674/1645 train_time:61986ms step_avg:91.97ms +step:675/1645 train_time:62079ms step_avg:91.97ms +step:676/1645 train_time:62171ms step_avg:91.97ms +step:677/1645 train_time:62263ms step_avg:91.97ms +step:678/1645 train_time:62356ms step_avg:91.97ms +step:679/1645 train_time:62448ms step_avg:91.97ms +step:680/1645 train_time:62541ms step_avg:91.97ms +step:681/1645 train_time:62634ms step_avg:91.97ms +step:682/1645 train_time:62727ms step_avg:91.97ms +step:683/1645 train_time:62821ms step_avg:91.98ms +step:684/1645 train_time:62914ms step_avg:91.98ms +step:685/1645 train_time:63006ms step_avg:91.98ms +step:686/1645 train_time:63100ms step_avg:91.98ms 
+step:687/1645 train_time:63192ms step_avg:91.98ms +step:688/1645 train_time:63285ms step_avg:91.98ms +step:689/1645 train_time:63378ms step_avg:91.98ms +step:690/1645 train_time:63470ms step_avg:91.98ms +step:691/1645 train_time:63563ms step_avg:91.99ms +step:692/1645 train_time:63656ms step_avg:91.99ms +step:693/1645 train_time:63749ms step_avg:91.99ms +step:694/1645 train_time:63842ms step_avg:91.99ms +step:695/1645 train_time:63934ms step_avg:91.99ms +step:696/1645 train_time:64028ms step_avg:91.99ms +step:697/1645 train_time:64121ms step_avg:92.00ms +step:698/1645 train_time:64215ms step_avg:92.00ms +step:699/1645 train_time:64307ms step_avg:92.00ms +step:700/1645 train_time:64400ms step_avg:92.00ms +step:701/1645 train_time:64493ms step_avg:92.00ms +step:702/1645 train_time:64586ms step_avg:92.00ms +step:703/1645 train_time:64679ms step_avg:92.00ms +step:704/1645 train_time:64772ms step_avg:92.01ms +step:705/1645 train_time:64865ms step_avg:92.01ms +step:706/1645 train_time:64958ms step_avg:92.01ms +step:707/1645 train_time:65050ms step_avg:92.01ms +step:708/1645 train_time:65143ms step_avg:92.01ms +step:709/1645 train_time:65236ms step_avg:92.01ms +step:710/1645 train_time:65328ms step_avg:92.01ms +step:711/1645 train_time:65421ms step_avg:92.01ms +step:712/1645 train_time:65514ms step_avg:92.01ms +step:713/1645 train_time:65607ms step_avg:92.02ms +step:714/1645 train_time:65701ms step_avg:92.02ms +step:715/1645 train_time:65793ms step_avg:92.02ms +step:716/1645 train_time:65886ms step_avg:92.02ms +step:717/1645 train_time:65980ms step_avg:92.02ms +step:718/1645 train_time:66073ms step_avg:92.02ms +step:719/1645 train_time:66165ms step_avg:92.02ms +step:720/1645 train_time:66258ms step_avg:92.02ms +step:721/1645 train_time:66350ms step_avg:92.03ms +step:722/1645 train_time:66443ms step_avg:92.03ms +step:723/1645 train_time:66535ms step_avg:92.03ms +step:724/1645 train_time:66628ms step_avg:92.03ms +step:725/1645 train_time:66722ms step_avg:92.03ms +step:726/1645 train_time:66815ms step_avg:92.03ms +step:727/1645 train_time:66908ms step_avg:92.03ms +step:728/1645 train_time:67001ms step_avg:92.03ms +step:729/1645 train_time:67095ms step_avg:92.04ms +step:730/1645 train_time:67187ms step_avg:92.04ms +step:731/1645 train_time:67280ms step_avg:92.04ms +step:732/1645 train_time:67373ms step_avg:92.04ms +step:733/1645 train_time:67465ms step_avg:92.04ms +step:734/1645 train_time:67558ms step_avg:92.04ms +step:735/1645 train_time:67652ms step_avg:92.04ms +step:736/1645 train_time:67745ms step_avg:92.04ms +step:737/1645 train_time:67838ms step_avg:92.05ms +step:738/1645 train_time:67930ms step_avg:92.05ms +step:739/1645 train_time:68024ms step_avg:92.05ms +step:740/1645 train_time:68117ms step_avg:92.05ms +step:741/1645 train_time:68209ms step_avg:92.05ms +step:742/1645 train_time:68303ms step_avg:92.05ms +step:743/1645 train_time:68395ms step_avg:92.05ms +step:744/1645 train_time:68488ms step_avg:92.05ms +step:745/1645 train_time:68580ms step_avg:92.05ms +step:746/1645 train_time:68674ms step_avg:92.06ms +step:747/1645 train_time:68767ms step_avg:92.06ms +step:748/1645 train_time:68859ms step_avg:92.06ms +step:749/1645 train_time:68953ms step_avg:92.06ms +step:750/1645 train_time:69047ms step_avg:92.06ms +step:750/1645 val_loss:3.5596 train_time:69140ms step_avg:92.19ms +step:751/1645 train_time:69161ms step_avg:92.09ms +step:752/1645 train_time:69237ms step_avg:92.07ms +step:753/1645 train_time:69334ms step_avg:92.08ms +step:754/1645 train_time:69427ms step_avg:92.08ms +step:755/1645 
train_time:69518ms step_avg:92.08ms +step:756/1645 train_time:69610ms step_avg:92.08ms +step:757/1645 train_time:69702ms step_avg:92.08ms +step:758/1645 train_time:69794ms step_avg:92.08ms +step:759/1645 train_time:69886ms step_avg:92.08ms +step:760/1645 train_time:69978ms step_avg:92.08ms +step:761/1645 train_time:70070ms step_avg:92.08ms +step:762/1645 train_time:70164ms step_avg:92.08ms +step:763/1645 train_time:70260ms step_avg:92.08ms +step:764/1645 train_time:70353ms step_avg:92.09ms +step:765/1645 train_time:70446ms step_avg:92.09ms +step:766/1645 train_time:70540ms step_avg:92.09ms +step:767/1645 train_time:70632ms step_avg:92.09ms +step:768/1645 train_time:70726ms step_avg:92.09ms +step:769/1645 train_time:70818ms step_avg:92.09ms +step:770/1645 train_time:70910ms step_avg:92.09ms +step:771/1645 train_time:71003ms step_avg:92.09ms +step:772/1645 train_time:71097ms step_avg:92.09ms +step:773/1645 train_time:71192ms step_avg:92.10ms +step:774/1645 train_time:71286ms step_avg:92.10ms +step:775/1645 train_time:71379ms step_avg:92.10ms +step:776/1645 train_time:71472ms step_avg:92.10ms +step:777/1645 train_time:71565ms step_avg:92.10ms +step:778/1645 train_time:71658ms step_avg:92.11ms +step:779/1645 train_time:71750ms step_avg:92.11ms +step:780/1645 train_time:71843ms step_avg:92.11ms +step:781/1645 train_time:71935ms step_avg:92.11ms +step:782/1645 train_time:72028ms step_avg:92.11ms +step:783/1645 train_time:72120ms step_avg:92.11ms +step:784/1645 train_time:72213ms step_avg:92.11ms +step:785/1645 train_time:72307ms step_avg:92.11ms +step:786/1645 train_time:72400ms step_avg:92.11ms +step:787/1645 train_time:72494ms step_avg:92.11ms +step:788/1645 train_time:72587ms step_avg:92.12ms +step:789/1645 train_time:72680ms step_avg:92.12ms +step:790/1645 train_time:72772ms step_avg:92.12ms +step:791/1645 train_time:72864ms step_avg:92.12ms +step:792/1645 train_time:72956ms step_avg:92.12ms +step:793/1645 train_time:73049ms step_avg:92.12ms +step:794/1645 train_time:73143ms step_avg:92.12ms +step:795/1645 train_time:73236ms step_avg:92.12ms +step:796/1645 train_time:73330ms step_avg:92.12ms +step:797/1645 train_time:73423ms step_avg:92.12ms +step:798/1645 train_time:73517ms step_avg:92.13ms +step:799/1645 train_time:73611ms step_avg:92.13ms +step:800/1645 train_time:73703ms step_avg:92.13ms +step:801/1645 train_time:73796ms step_avg:92.13ms +step:802/1645 train_time:73889ms step_avg:92.13ms +step:803/1645 train_time:73980ms step_avg:92.13ms +step:804/1645 train_time:74073ms step_avg:92.13ms +step:805/1645 train_time:74167ms step_avg:92.13ms +step:806/1645 train_time:74260ms step_avg:92.13ms +step:807/1645 train_time:74353ms step_avg:92.13ms +step:808/1645 train_time:74447ms step_avg:92.14ms +step:809/1645 train_time:74539ms step_avg:92.14ms +step:810/1645 train_time:74632ms step_avg:92.14ms +step:811/1645 train_time:74725ms step_avg:92.14ms +step:812/1645 train_time:74819ms step_avg:92.14ms +step:813/1645 train_time:74911ms step_avg:92.14ms +step:814/1645 train_time:75004ms step_avg:92.14ms +step:815/1645 train_time:75097ms step_avg:92.14ms +step:816/1645 train_time:75190ms step_avg:92.14ms +step:817/1645 train_time:75283ms step_avg:92.15ms +step:818/1645 train_time:75375ms step_avg:92.15ms +step:819/1645 train_time:75469ms step_avg:92.15ms +step:820/1645 train_time:75562ms step_avg:92.15ms +step:821/1645 train_time:75654ms step_avg:92.15ms +step:822/1645 train_time:75749ms step_avg:92.15ms +step:823/1645 train_time:75842ms step_avg:92.15ms +step:824/1645 train_time:75934ms step_avg:92.15ms 
+step:825/1645 train_time:76027ms step_avg:92.15ms +step:826/1645 train_time:76119ms step_avg:92.15ms +step:827/1645 train_time:76212ms step_avg:92.16ms +step:828/1645 train_time:76306ms step_avg:92.16ms +step:829/1645 train_time:76398ms step_avg:92.16ms +step:830/1645 train_time:76491ms step_avg:92.16ms +step:831/1645 train_time:76584ms step_avg:92.16ms +step:832/1645 train_time:76677ms step_avg:92.16ms +step:833/1645 train_time:76771ms step_avg:92.16ms +step:834/1645 train_time:76864ms step_avg:92.16ms +step:835/1645 train_time:76956ms step_avg:92.16ms +step:836/1645 train_time:77049ms step_avg:92.16ms +step:837/1645 train_time:77142ms step_avg:92.17ms +step:838/1645 train_time:77235ms step_avg:92.17ms +step:839/1645 train_time:77329ms step_avg:92.17ms +step:840/1645 train_time:77423ms step_avg:92.17ms +step:841/1645 train_time:77516ms step_avg:92.17ms +step:842/1645 train_time:77609ms step_avg:92.17ms +step:843/1645 train_time:77702ms step_avg:92.17ms +step:844/1645 train_time:77795ms step_avg:92.17ms +step:845/1645 train_time:77887ms step_avg:92.17ms +step:846/1645 train_time:77980ms step_avg:92.17ms +step:847/1645 train_time:78072ms step_avg:92.17ms +step:848/1645 train_time:78165ms step_avg:92.18ms +step:849/1645 train_time:78258ms step_avg:92.18ms +step:850/1645 train_time:78350ms step_avg:92.18ms +step:851/1645 train_time:78445ms step_avg:92.18ms +step:852/1645 train_time:78537ms step_avg:92.18ms +step:853/1645 train_time:78630ms step_avg:92.18ms +step:854/1645 train_time:78725ms step_avg:92.18ms +step:855/1645 train_time:78817ms step_avg:92.18ms +step:856/1645 train_time:78910ms step_avg:92.18ms +step:857/1645 train_time:79003ms step_avg:92.19ms +step:858/1645 train_time:79096ms step_avg:92.19ms +step:859/1645 train_time:79188ms step_avg:92.19ms +step:860/1645 train_time:79281ms step_avg:92.19ms +step:861/1645 train_time:79374ms step_avg:92.19ms +step:862/1645 train_time:79467ms step_avg:92.19ms +step:863/1645 train_time:79559ms step_avg:92.19ms +step:864/1645 train_time:79652ms step_avg:92.19ms +step:865/1645 train_time:79746ms step_avg:92.19ms +step:866/1645 train_time:79839ms step_avg:92.19ms +step:867/1645 train_time:79932ms step_avg:92.19ms +step:868/1645 train_time:80025ms step_avg:92.19ms +step:869/1645 train_time:80117ms step_avg:92.19ms +step:870/1645 train_time:80210ms step_avg:92.20ms +step:871/1645 train_time:80303ms step_avg:92.20ms +step:872/1645 train_time:80396ms step_avg:92.20ms +step:873/1645 train_time:80489ms step_avg:92.20ms +step:874/1645 train_time:80582ms step_avg:92.20ms +step:875/1645 train_time:80674ms step_avg:92.20ms +step:875/1645 val_loss:3.5143 train_time:80768ms step_avg:92.31ms +step:876/1645 train_time:80788ms step_avg:92.22ms +step:877/1645 train_time:80864ms step_avg:92.21ms +step:878/1645 train_time:80960ms step_avg:92.21ms +step:879/1645 train_time:81053ms step_avg:92.21ms +step:880/1645 train_time:81145ms step_avg:92.21ms +step:881/1645 train_time:81237ms step_avg:92.21ms +step:882/1645 train_time:81328ms step_avg:92.21ms +step:883/1645 train_time:81420ms step_avg:92.21ms +step:884/1645 train_time:81512ms step_avg:92.21ms +step:885/1645 train_time:81605ms step_avg:92.21ms +step:886/1645 train_time:81697ms step_avg:92.21ms +step:887/1645 train_time:81791ms step_avg:92.21ms +step:888/1645 train_time:81887ms step_avg:92.21ms +step:889/1645 train_time:81982ms step_avg:92.22ms +step:890/1645 train_time:82076ms step_avg:92.22ms +step:891/1645 train_time:82168ms step_avg:92.22ms +step:892/1645 train_time:82260ms step_avg:92.22ms +step:893/1645 
train_time:82352ms step_avg:92.22ms +step:894/1645 train_time:82445ms step_avg:92.22ms +step:895/1645 train_time:82537ms step_avg:92.22ms +step:896/1645 train_time:82629ms step_avg:92.22ms +step:897/1645 train_time:82722ms step_avg:92.22ms +step:898/1645 train_time:82814ms step_avg:92.22ms +step:899/1645 train_time:82910ms step_avg:92.22ms +step:900/1645 train_time:83005ms step_avg:92.23ms +step:901/1645 train_time:83097ms step_avg:92.23ms +step:902/1645 train_time:83190ms step_avg:92.23ms +step:903/1645 train_time:83284ms step_avg:92.23ms +step:904/1645 train_time:83375ms step_avg:92.23ms +step:905/1645 train_time:83468ms step_avg:92.23ms +step:906/1645 train_time:83560ms step_avg:92.23ms +step:907/1645 train_time:83652ms step_avg:92.23ms +step:908/1645 train_time:83745ms step_avg:92.23ms +step:909/1645 train_time:83838ms step_avg:92.23ms +step:910/1645 train_time:83932ms step_avg:92.23ms +step:911/1645 train_time:84026ms step_avg:92.23ms +step:912/1645 train_time:84119ms step_avg:92.24ms +step:913/1645 train_time:84212ms step_avg:92.24ms +step:914/1645 train_time:84305ms step_avg:92.24ms +step:915/1645 train_time:84398ms step_avg:92.24ms +step:916/1645 train_time:84491ms step_avg:92.24ms +step:917/1645 train_time:84583ms step_avg:92.24ms +step:918/1645 train_time:84675ms step_avg:92.24ms +step:919/1645 train_time:84768ms step_avg:92.24ms +step:920/1645 train_time:84861ms step_avg:92.24ms +step:921/1645 train_time:84955ms step_avg:92.24ms +step:922/1645 train_time:85048ms step_avg:92.24ms +step:923/1645 train_time:85141ms step_avg:92.24ms +step:924/1645 train_time:85233ms step_avg:92.24ms +step:925/1645 train_time:85327ms step_avg:92.24ms +step:926/1645 train_time:85419ms step_avg:92.24ms +step:927/1645 train_time:85512ms step_avg:92.25ms +step:928/1645 train_time:85605ms step_avg:92.25ms +step:929/1645 train_time:85698ms step_avg:92.25ms +step:930/1645 train_time:85791ms step_avg:92.25ms +step:931/1645 train_time:85884ms step_avg:92.25ms +step:932/1645 train_time:85978ms step_avg:92.25ms +step:933/1645 train_time:86071ms step_avg:92.25ms +step:934/1645 train_time:86163ms step_avg:92.25ms +step:935/1645 train_time:86256ms step_avg:92.25ms +step:936/1645 train_time:86349ms step_avg:92.25ms +step:937/1645 train_time:86442ms step_avg:92.25ms +step:938/1645 train_time:86534ms step_avg:92.25ms +step:939/1645 train_time:86626ms step_avg:92.25ms +step:940/1645 train_time:86719ms step_avg:92.25ms +step:941/1645 train_time:86812ms step_avg:92.25ms +step:942/1645 train_time:86905ms step_avg:92.26ms +step:943/1645 train_time:86999ms step_avg:92.26ms +step:944/1645 train_time:87091ms step_avg:92.26ms +step:945/1645 train_time:87185ms step_avg:92.26ms +step:946/1645 train_time:87278ms step_avg:92.26ms +step:947/1645 train_time:87371ms step_avg:92.26ms +step:948/1645 train_time:87465ms step_avg:92.26ms +step:949/1645 train_time:87557ms step_avg:92.26ms +step:950/1645 train_time:87650ms step_avg:92.26ms +step:951/1645 train_time:87744ms step_avg:92.26ms +step:952/1645 train_time:87837ms step_avg:92.27ms +step:953/1645 train_time:87929ms step_avg:92.27ms +step:954/1645 train_time:88022ms step_avg:92.27ms +step:955/1645 train_time:88114ms step_avg:92.27ms +step:956/1645 train_time:88208ms step_avg:92.27ms +step:957/1645 train_time:88300ms step_avg:92.27ms +step:958/1645 train_time:88392ms step_avg:92.27ms +step:959/1645 train_time:88486ms step_avg:92.27ms +step:960/1645 train_time:88579ms step_avg:92.27ms +step:961/1645 train_time:88672ms step_avg:92.27ms +step:962/1645 train_time:88766ms step_avg:92.27ms 
+step:963/1645 train_time:88859ms step_avg:92.27ms +step:964/1645 train_time:88951ms step_avg:92.27ms +step:965/1645 train_time:89045ms step_avg:92.27ms +step:966/1645 train_time:89137ms step_avg:92.27ms +step:967/1645 train_time:89230ms step_avg:92.28ms +step:968/1645 train_time:89323ms step_avg:92.28ms +step:969/1645 train_time:89417ms step_avg:92.28ms +step:970/1645 train_time:89510ms step_avg:92.28ms +step:971/1645 train_time:89603ms step_avg:92.28ms +step:972/1645 train_time:89696ms step_avg:92.28ms +step:973/1645 train_time:89790ms step_avg:92.28ms +step:974/1645 train_time:89883ms step_avg:92.28ms +step:975/1645 train_time:89975ms step_avg:92.28ms +step:976/1645 train_time:90068ms step_avg:92.28ms +step:977/1645 train_time:90161ms step_avg:92.28ms +step:978/1645 train_time:90254ms step_avg:92.28ms +step:979/1645 train_time:90347ms step_avg:92.28ms +step:980/1645 train_time:90440ms step_avg:92.29ms +step:981/1645 train_time:90532ms step_avg:92.29ms +step:982/1645 train_time:90625ms step_avg:92.29ms +step:983/1645 train_time:90718ms step_avg:92.29ms +step:984/1645 train_time:90811ms step_avg:92.29ms +step:985/1645 train_time:90904ms step_avg:92.29ms +step:986/1645 train_time:90997ms step_avg:92.29ms +step:987/1645 train_time:91090ms step_avg:92.29ms +step:988/1645 train_time:91183ms step_avg:92.29ms +step:989/1645 train_time:91276ms step_avg:92.29ms +step:990/1645 train_time:91369ms step_avg:92.29ms +step:991/1645 train_time:91462ms step_avg:92.29ms +step:992/1645 train_time:91555ms step_avg:92.29ms +step:993/1645 train_time:91648ms step_avg:92.29ms +step:994/1645 train_time:91741ms step_avg:92.30ms +step:995/1645 train_time:91835ms step_avg:92.30ms +step:996/1645 train_time:91927ms step_avg:92.30ms +step:997/1645 train_time:92021ms step_avg:92.30ms +step:998/1645 train_time:92113ms step_avg:92.30ms +step:999/1645 train_time:92206ms step_avg:92.30ms +step:1000/1645 train_time:92299ms step_avg:92.30ms +step:1000/1645 val_loss:3.4659 train_time:92392ms step_avg:92.39ms +step:1001/1645 train_time:92417ms step_avg:92.32ms +step:1002/1645 train_time:92490ms step_avg:92.31ms +step:1003/1645 train_time:92586ms step_avg:92.31ms +step:1004/1645 train_time:92679ms step_avg:92.31ms +step:1005/1645 train_time:92771ms step_avg:92.31ms +step:1006/1645 train_time:92862ms step_avg:92.31ms +step:1007/1645 train_time:92954ms step_avg:92.31ms +step:1008/1645 train_time:93046ms step_avg:92.31ms +step:1009/1645 train_time:93138ms step_avg:92.31ms +step:1010/1645 train_time:93230ms step_avg:92.31ms +step:1011/1645 train_time:93323ms step_avg:92.31ms +step:1012/1645 train_time:93417ms step_avg:92.31ms +step:1013/1645 train_time:93513ms step_avg:92.31ms +step:1014/1645 train_time:93607ms step_avg:92.31ms +step:1015/1645 train_time:93700ms step_avg:92.32ms +step:1016/1645 train_time:93793ms step_avg:92.32ms +step:1017/1645 train_time:93885ms step_avg:92.32ms +step:1018/1645 train_time:93977ms step_avg:92.31ms +step:1019/1645 train_time:94069ms step_avg:92.31ms +step:1020/1645 train_time:94161ms step_avg:92.31ms +step:1021/1645 train_time:94253ms step_avg:92.31ms +step:1022/1645 train_time:94347ms step_avg:92.32ms +step:1023/1645 train_time:94441ms step_avg:92.32ms +step:1024/1645 train_time:94534ms step_avg:92.32ms +step:1025/1645 train_time:94627ms step_avg:92.32ms +step:1026/1645 train_time:94720ms step_avg:92.32ms +step:1027/1645 train_time:94813ms step_avg:92.32ms +step:1028/1645 train_time:94906ms step_avg:92.32ms +step:1029/1645 train_time:94998ms step_avg:92.32ms +step:1030/1645 train_time:95090ms 
step_avg:92.32ms +step:1031/1645 train_time:95183ms step_avg:92.32ms +step:1032/1645 train_time:95275ms step_avg:92.32ms +step:1033/1645 train_time:95368ms step_avg:92.32ms +step:1034/1645 train_time:95463ms step_avg:92.32ms +step:1035/1645 train_time:95555ms step_avg:92.32ms +step:1036/1645 train_time:95649ms step_avg:92.33ms +step:1037/1645 train_time:95741ms step_avg:92.33ms +step:1038/1645 train_time:95834ms step_avg:92.33ms +step:1039/1645 train_time:95928ms step_avg:92.33ms +step:1040/1645 train_time:96020ms step_avg:92.33ms +step:1041/1645 train_time:96112ms step_avg:92.33ms +step:1042/1645 train_time:96205ms step_avg:92.33ms +step:1043/1645 train_time:96298ms step_avg:92.33ms +step:1044/1645 train_time:96390ms step_avg:92.33ms +step:1045/1645 train_time:96485ms step_avg:92.33ms +step:1046/1645 train_time:96578ms step_avg:92.33ms +step:1047/1645 train_time:96671ms step_avg:92.33ms +step:1048/1645 train_time:96763ms step_avg:92.33ms +step:1049/1645 train_time:96856ms step_avg:92.33ms +step:1050/1645 train_time:96949ms step_avg:92.33ms +step:1051/1645 train_time:97041ms step_avg:92.33ms +step:1052/1645 train_time:97133ms step_avg:92.33ms +step:1053/1645 train_time:97226ms step_avg:92.33ms +step:1054/1645 train_time:97319ms step_avg:92.33ms +step:1055/1645 train_time:97413ms step_avg:92.33ms +step:1056/1645 train_time:97507ms step_avg:92.34ms +step:1057/1645 train_time:97600ms step_avg:92.34ms +step:1058/1645 train_time:97693ms step_avg:92.34ms +step:1059/1645 train_time:97788ms step_avg:92.34ms +step:1060/1645 train_time:97880ms step_avg:92.34ms +step:1061/1645 train_time:97972ms step_avg:92.34ms +step:1062/1645 train_time:98066ms step_avg:92.34ms +step:1063/1645 train_time:98158ms step_avg:92.34ms +step:1064/1645 train_time:98250ms step_avg:92.34ms +step:1065/1645 train_time:98344ms step_avg:92.34ms +step:1066/1645 train_time:98436ms step_avg:92.34ms +step:1067/1645 train_time:98530ms step_avg:92.34ms +step:1068/1645 train_time:98623ms step_avg:92.34ms +step:1069/1645 train_time:98716ms step_avg:92.34ms +step:1070/1645 train_time:98810ms step_avg:92.35ms +step:1071/1645 train_time:98902ms step_avg:92.35ms +step:1072/1645 train_time:98996ms step_avg:92.35ms +step:1073/1645 train_time:99088ms step_avg:92.35ms +step:1074/1645 train_time:99180ms step_avg:92.35ms +step:1075/1645 train_time:99272ms step_avg:92.35ms +step:1076/1645 train_time:99365ms step_avg:92.35ms +step:1077/1645 train_time:99458ms step_avg:92.35ms +step:1078/1645 train_time:99551ms step_avg:92.35ms +step:1079/1645 train_time:99644ms step_avg:92.35ms +step:1080/1645 train_time:99737ms step_avg:92.35ms +step:1081/1645 train_time:99830ms step_avg:92.35ms +step:1082/1645 train_time:99924ms step_avg:92.35ms +step:1083/1645 train_time:100018ms step_avg:92.35ms +step:1084/1645 train_time:100110ms step_avg:92.35ms +step:1085/1645 train_time:100203ms step_avg:92.35ms +step:1086/1645 train_time:100296ms step_avg:92.35ms +step:1087/1645 train_time:100389ms step_avg:92.35ms +step:1088/1645 train_time:100482ms step_avg:92.35ms +step:1089/1645 train_time:100574ms step_avg:92.35ms +step:1090/1645 train_time:100667ms step_avg:92.36ms +step:1091/1645 train_time:100760ms step_avg:92.36ms +step:1092/1645 train_time:100853ms step_avg:92.36ms +step:1093/1645 train_time:100946ms step_avg:92.36ms +step:1094/1645 train_time:101039ms step_avg:92.36ms +step:1095/1645 train_time:101132ms step_avg:92.36ms +step:1096/1645 train_time:101225ms step_avg:92.36ms +step:1097/1645 train_time:101318ms step_avg:92.36ms +step:1098/1645 train_time:101411ms 
step_avg:92.36ms +step:1099/1645 train_time:101506ms step_avg:92.36ms +step:1100/1645 train_time:101599ms step_avg:92.36ms +step:1101/1645 train_time:101692ms step_avg:92.36ms +step:1102/1645 train_time:101786ms step_avg:92.37ms +step:1103/1645 train_time:101879ms step_avg:92.37ms +step:1104/1645 train_time:101973ms step_avg:92.37ms +step:1105/1645 train_time:102067ms step_avg:92.37ms +step:1106/1645 train_time:102159ms step_avg:92.37ms +step:1107/1645 train_time:102253ms step_avg:92.37ms +step:1108/1645 train_time:102347ms step_avg:92.37ms +step:1109/1645 train_time:102440ms step_avg:92.37ms +step:1110/1645 train_time:102533ms step_avg:92.37ms +step:1111/1645 train_time:102628ms step_avg:92.37ms +step:1112/1645 train_time:102722ms step_avg:92.38ms +step:1113/1645 train_time:102814ms step_avg:92.38ms +step:1114/1645 train_time:102909ms step_avg:92.38ms +step:1115/1645 train_time:103002ms step_avg:92.38ms +step:1116/1645 train_time:103096ms step_avg:92.38ms +step:1117/1645 train_time:103189ms step_avg:92.38ms +step:1118/1645 train_time:103282ms step_avg:92.38ms +step:1119/1645 train_time:103375ms step_avg:92.38ms +step:1120/1645 train_time:103469ms step_avg:92.38ms +step:1121/1645 train_time:103562ms step_avg:92.38ms +step:1122/1645 train_time:103656ms step_avg:92.39ms +step:1123/1645 train_time:103750ms step_avg:92.39ms +step:1124/1645 train_time:103843ms step_avg:92.39ms +step:1125/1645 train_time:103936ms step_avg:92.39ms +step:1125/1645 val_loss:3.4121 train_time:104031ms step_avg:92.47ms +step:1126/1645 train_time:104057ms step_avg:92.41ms +step:1127/1645 train_time:104133ms step_avg:92.40ms +step:1128/1645 train_time:104235ms step_avg:92.41ms +step:1129/1645 train_time:104329ms step_avg:92.41ms +step:1130/1645 train_time:104421ms step_avg:92.41ms +step:1131/1645 train_time:104514ms step_avg:92.41ms +step:1132/1645 train_time:104607ms step_avg:92.41ms +step:1133/1645 train_time:104700ms step_avg:92.41ms +step:1134/1645 train_time:104793ms step_avg:92.41ms +step:1135/1645 train_time:104885ms step_avg:92.41ms +step:1136/1645 train_time:104978ms step_avg:92.41ms +step:1137/1645 train_time:105074ms step_avg:92.41ms +step:1138/1645 train_time:105171ms step_avg:92.42ms +step:1139/1645 train_time:105267ms step_avg:92.42ms +step:1140/1645 train_time:105361ms step_avg:92.42ms +step:1141/1645 train_time:105455ms step_avg:92.42ms +step:1142/1645 train_time:105548ms step_avg:92.42ms +step:1143/1645 train_time:105640ms step_avg:92.42ms +step:1144/1645 train_time:105733ms step_avg:92.42ms +step:1145/1645 train_time:105826ms step_avg:92.42ms +step:1146/1645 train_time:105918ms step_avg:92.42ms +step:1147/1645 train_time:106012ms step_avg:92.43ms +step:1148/1645 train_time:106107ms step_avg:92.43ms +step:1149/1645 train_time:106201ms step_avg:92.43ms +step:1150/1645 train_time:106296ms step_avg:92.43ms +step:1151/1645 train_time:106390ms step_avg:92.43ms +step:1152/1645 train_time:106483ms step_avg:92.43ms +step:1153/1645 train_time:106576ms step_avg:92.43ms +step:1154/1645 train_time:106669ms step_avg:92.43ms +step:1155/1645 train_time:106762ms step_avg:92.43ms +step:1156/1645 train_time:106856ms step_avg:92.44ms +step:1157/1645 train_time:106950ms step_avg:92.44ms +step:1158/1645 train_time:107044ms step_avg:92.44ms +step:1159/1645 train_time:107138ms step_avg:92.44ms +step:1160/1645 train_time:107232ms step_avg:92.44ms +step:1161/1645 train_time:107326ms step_avg:92.44ms +step:1162/1645 train_time:107420ms step_avg:92.44ms +step:1163/1645 train_time:107514ms step_avg:92.45ms +step:1164/1645 
train_time:107607ms step_avg:92.45ms +step:1165/1645 train_time:107700ms step_avg:92.45ms +step:1166/1645 train_time:107793ms step_avg:92.45ms +step:1167/1645 train_time:107886ms step_avg:92.45ms +step:1168/1645 train_time:107979ms step_avg:92.45ms +step:1169/1645 train_time:108073ms step_avg:92.45ms +step:1170/1645 train_time:108167ms step_avg:92.45ms +step:1171/1645 train_time:108261ms step_avg:92.45ms +step:1172/1645 train_time:108355ms step_avg:92.45ms +step:1173/1645 train_time:108449ms step_avg:92.45ms +step:1174/1645 train_time:108543ms step_avg:92.46ms +step:1175/1645 train_time:108636ms step_avg:92.46ms +step:1176/1645 train_time:108730ms step_avg:92.46ms +step:1177/1645 train_time:108823ms step_avg:92.46ms +step:1178/1645 train_time:108916ms step_avg:92.46ms +step:1179/1645 train_time:109009ms step_avg:92.46ms +step:1180/1645 train_time:109103ms step_avg:92.46ms +step:1181/1645 train_time:109197ms step_avg:92.46ms +step:1182/1645 train_time:109290ms step_avg:92.46ms +step:1183/1645 train_time:109383ms step_avg:92.46ms +step:1184/1645 train_time:109477ms step_avg:92.46ms +step:1185/1645 train_time:109572ms step_avg:92.47ms +step:1186/1645 train_time:109666ms step_avg:92.47ms +step:1187/1645 train_time:109758ms step_avg:92.47ms +step:1188/1645 train_time:109852ms step_avg:92.47ms +step:1189/1645 train_time:109946ms step_avg:92.47ms +step:1190/1645 train_time:110039ms step_avg:92.47ms +step:1191/1645 train_time:110133ms step_avg:92.47ms +step:1192/1645 train_time:110226ms step_avg:92.47ms +step:1193/1645 train_time:110320ms step_avg:92.47ms +step:1194/1645 train_time:110413ms step_avg:92.47ms +step:1195/1645 train_time:110507ms step_avg:92.47ms +step:1196/1645 train_time:110601ms step_avg:92.48ms +step:1197/1645 train_time:110694ms step_avg:92.48ms +step:1198/1645 train_time:110788ms step_avg:92.48ms +step:1199/1645 train_time:110881ms step_avg:92.48ms +step:1200/1645 train_time:110975ms step_avg:92.48ms +step:1201/1645 train_time:111068ms step_avg:92.48ms +step:1202/1645 train_time:111162ms step_avg:92.48ms +step:1203/1645 train_time:111256ms step_avg:92.48ms +step:1204/1645 train_time:111349ms step_avg:92.48ms +step:1205/1645 train_time:111443ms step_avg:92.48ms +step:1206/1645 train_time:111537ms step_avg:92.49ms +step:1207/1645 train_time:111631ms step_avg:92.49ms +step:1208/1645 train_time:111724ms step_avg:92.49ms +step:1209/1645 train_time:111818ms step_avg:92.49ms +step:1210/1645 train_time:111911ms step_avg:92.49ms +step:1211/1645 train_time:112005ms step_avg:92.49ms +step:1212/1645 train_time:112098ms step_avg:92.49ms +step:1213/1645 train_time:112191ms step_avg:92.49ms +step:1214/1645 train_time:112284ms step_avg:92.49ms +step:1215/1645 train_time:112378ms step_avg:92.49ms +step:1216/1645 train_time:112472ms step_avg:92.49ms +step:1217/1645 train_time:112566ms step_avg:92.49ms +step:1218/1645 train_time:112659ms step_avg:92.50ms +step:1219/1645 train_time:112753ms step_avg:92.50ms +step:1220/1645 train_time:112847ms step_avg:92.50ms +step:1221/1645 train_time:112940ms step_avg:92.50ms +step:1222/1645 train_time:113034ms step_avg:92.50ms +step:1223/1645 train_time:113127ms step_avg:92.50ms +step:1224/1645 train_time:113220ms step_avg:92.50ms +step:1225/1645 train_time:113314ms step_avg:92.50ms +step:1226/1645 train_time:113407ms step_avg:92.50ms +step:1227/1645 train_time:113501ms step_avg:92.50ms +step:1228/1645 train_time:113595ms step_avg:92.50ms +step:1229/1645 train_time:113688ms step_avg:92.50ms +step:1230/1645 train_time:113783ms step_avg:92.51ms +step:1231/1645 
train_time:113876ms step_avg:92.51ms +step:1232/1645 train_time:113970ms step_avg:92.51ms +step:1233/1645 train_time:114064ms step_avg:92.51ms +step:1234/1645 train_time:114157ms step_avg:92.51ms +step:1235/1645 train_time:114251ms step_avg:92.51ms +step:1236/1645 train_time:114345ms step_avg:92.51ms +step:1237/1645 train_time:114438ms step_avg:92.51ms +step:1238/1645 train_time:114532ms step_avg:92.51ms +step:1239/1645 train_time:114625ms step_avg:92.51ms +step:1240/1645 train_time:114719ms step_avg:92.52ms +step:1241/1645 train_time:114813ms step_avg:92.52ms +step:1242/1645 train_time:114907ms step_avg:92.52ms +step:1243/1645 train_time:115000ms step_avg:92.52ms +step:1244/1645 train_time:115093ms step_avg:92.52ms +step:1245/1645 train_time:115186ms step_avg:92.52ms +step:1246/1645 train_time:115280ms step_avg:92.52ms +step:1247/1645 train_time:115373ms step_avg:92.52ms +step:1248/1645 train_time:115466ms step_avg:92.52ms +step:1249/1645 train_time:115559ms step_avg:92.52ms +step:1250/1645 train_time:115654ms step_avg:92.52ms +step:1250/1645 val_loss:3.3733 train_time:115747ms step_avg:92.60ms +step:1251/1645 train_time:115773ms step_avg:92.54ms +step:1252/1645 train_time:115848ms step_avg:92.53ms +step:1253/1645 train_time:115942ms step_avg:92.53ms +step:1254/1645 train_time:116035ms step_avg:92.53ms +step:1255/1645 train_time:116127ms step_avg:92.53ms +step:1256/1645 train_time:116219ms step_avg:92.53ms +step:1257/1645 train_time:116312ms step_avg:92.53ms +step:1258/1645 train_time:116404ms step_avg:92.53ms +step:1259/1645 train_time:116497ms step_avg:92.53ms +step:1260/1645 train_time:116590ms step_avg:92.53ms +step:1261/1645 train_time:116685ms step_avg:92.53ms +step:1262/1645 train_time:116783ms step_avg:92.54ms +step:1263/1645 train_time:116879ms step_avg:92.54ms +step:1264/1645 train_time:116973ms step_avg:92.54ms +step:1265/1645 train_time:117066ms step_avg:92.54ms +step:1266/1645 train_time:117160ms step_avg:92.54ms +step:1267/1645 train_time:117252ms step_avg:92.54ms +step:1268/1645 train_time:117345ms step_avg:92.54ms +step:1269/1645 train_time:117438ms step_avg:92.54ms +step:1270/1645 train_time:117531ms step_avg:92.54ms +step:1271/1645 train_time:117624ms step_avg:92.54ms +step:1272/1645 train_time:117719ms step_avg:92.55ms +step:1273/1645 train_time:117813ms step_avg:92.55ms +step:1274/1645 train_time:117909ms step_avg:92.55ms +step:1275/1645 train_time:118003ms step_avg:92.55ms +step:1276/1645 train_time:118096ms step_avg:92.55ms +step:1277/1645 train_time:118188ms step_avg:92.55ms +step:1278/1645 train_time:118282ms step_avg:92.55ms +step:1279/1645 train_time:118374ms step_avg:92.55ms +step:1280/1645 train_time:118467ms step_avg:92.55ms +step:1281/1645 train_time:118560ms step_avg:92.55ms +step:1282/1645 train_time:118654ms step_avg:92.55ms +step:1283/1645 train_time:118748ms step_avg:92.55ms +step:1284/1645 train_time:118842ms step_avg:92.56ms +step:1285/1645 train_time:118936ms step_avg:92.56ms +step:1286/1645 train_time:119029ms step_avg:92.56ms +step:1287/1645 train_time:119123ms step_avg:92.56ms +step:1288/1645 train_time:119216ms step_avg:92.56ms +step:1289/1645 train_time:119310ms step_avg:92.56ms +step:1290/1645 train_time:119403ms step_avg:92.56ms +step:1291/1645 train_time:119496ms step_avg:92.56ms +step:1292/1645 train_time:119590ms step_avg:92.56ms +step:1293/1645 train_time:119684ms step_avg:92.56ms +step:1294/1645 train_time:119778ms step_avg:92.56ms +step:1295/1645 train_time:119872ms step_avg:92.57ms +step:1296/1645 train_time:119966ms step_avg:92.57ms 
+step:1297/1645 train_time:120060ms step_avg:92.57ms +step:1298/1645 train_time:120154ms step_avg:92.57ms +step:1299/1645 train_time:120248ms step_avg:92.57ms +step:1300/1645 train_time:120341ms step_avg:92.57ms +step:1301/1645 train_time:120433ms step_avg:92.57ms +step:1302/1645 train_time:120527ms step_avg:92.57ms +step:1303/1645 train_time:120620ms step_avg:92.57ms +step:1304/1645 train_time:120714ms step_avg:92.57ms +step:1305/1645 train_time:120808ms step_avg:92.57ms +step:1306/1645 train_time:120905ms step_avg:92.58ms +step:1307/1645 train_time:120998ms step_avg:92.58ms +step:1308/1645 train_time:121090ms step_avg:92.58ms +step:1309/1645 train_time:121185ms step_avg:92.58ms +step:1310/1645 train_time:121278ms step_avg:92.58ms +step:1311/1645 train_time:121371ms step_avg:92.58ms +step:1312/1645 train_time:121464ms step_avg:92.58ms +step:1313/1645 train_time:121558ms step_avg:92.58ms +step:1314/1645 train_time:121652ms step_avg:92.58ms +step:1315/1645 train_time:121745ms step_avg:92.58ms +step:1316/1645 train_time:121839ms step_avg:92.58ms +step:1317/1645 train_time:121933ms step_avg:92.58ms +step:1318/1645 train_time:122027ms step_avg:92.58ms +step:1319/1645 train_time:122120ms step_avg:92.59ms +step:1320/1645 train_time:122213ms step_avg:92.59ms +step:1321/1645 train_time:122307ms step_avg:92.59ms +step:1322/1645 train_time:122400ms step_avg:92.59ms +step:1323/1645 train_time:122493ms step_avg:92.59ms +step:1324/1645 train_time:122587ms step_avg:92.59ms +step:1325/1645 train_time:122681ms step_avg:92.59ms +step:1326/1645 train_time:122776ms step_avg:92.59ms +step:1327/1645 train_time:122869ms step_avg:92.59ms +step:1328/1645 train_time:122963ms step_avg:92.59ms +step:1329/1645 train_time:123057ms step_avg:92.59ms +step:1330/1645 train_time:123150ms step_avg:92.59ms +step:1331/1645 train_time:123244ms step_avg:92.59ms +step:1332/1645 train_time:123337ms step_avg:92.60ms +step:1333/1645 train_time:123430ms step_avg:92.60ms +step:1334/1645 train_time:123524ms step_avg:92.60ms +step:1335/1645 train_time:123619ms step_avg:92.60ms +step:1336/1645 train_time:123712ms step_avg:92.60ms +step:1337/1645 train_time:123805ms step_avg:92.60ms +step:1338/1645 train_time:123899ms step_avg:92.60ms +step:1339/1645 train_time:123992ms step_avg:92.60ms +step:1340/1645 train_time:124085ms step_avg:92.60ms +step:1341/1645 train_time:124180ms step_avg:92.60ms +step:1342/1645 train_time:124274ms step_avg:92.60ms +step:1343/1645 train_time:124367ms step_avg:92.60ms +step:1344/1645 train_time:124460ms step_avg:92.60ms +step:1345/1645 train_time:124553ms step_avg:92.60ms +step:1346/1645 train_time:124647ms step_avg:92.61ms +step:1347/1645 train_time:124740ms step_avg:92.61ms +step:1348/1645 train_time:124834ms step_avg:92.61ms +step:1349/1645 train_time:124928ms step_avg:92.61ms +step:1350/1645 train_time:125021ms step_avg:92.61ms +step:1351/1645 train_time:125115ms step_avg:92.61ms +step:1352/1645 train_time:125209ms step_avg:92.61ms +step:1353/1645 train_time:125303ms step_avg:92.61ms +step:1354/1645 train_time:125396ms step_avg:92.61ms +step:1355/1645 train_time:125488ms step_avg:92.61ms +step:1356/1645 train_time:125583ms step_avg:92.61ms +step:1357/1645 train_time:125677ms step_avg:92.61ms +step:1358/1645 train_time:125771ms step_avg:92.62ms +step:1359/1645 train_time:125864ms step_avg:92.62ms +step:1360/1645 train_time:125957ms step_avg:92.62ms +step:1361/1645 train_time:126051ms step_avg:92.62ms +step:1362/1645 train_time:126146ms step_avg:92.62ms +step:1363/1645 train_time:126239ms step_avg:92.62ms 
+step:1364/1645 train_time:126332ms step_avg:92.62ms +step:1365/1645 train_time:126427ms step_avg:92.62ms +step:1366/1645 train_time:126520ms step_avg:92.62ms +step:1367/1645 train_time:126613ms step_avg:92.62ms +step:1368/1645 train_time:126706ms step_avg:92.62ms +step:1369/1645 train_time:126800ms step_avg:92.62ms +step:1370/1645 train_time:126893ms step_avg:92.62ms +step:1371/1645 train_time:126986ms step_avg:92.62ms +step:1372/1645 train_time:127081ms step_avg:92.62ms +step:1373/1645 train_time:127176ms step_avg:92.63ms +step:1374/1645 train_time:127269ms step_avg:92.63ms +step:1375/1645 train_time:127363ms step_avg:92.63ms +step:1375/1645 val_loss:3.3391 train_time:127456ms step_avg:92.70ms +step:1376/1645 train_time:127477ms step_avg:92.64ms +step:1377/1645 train_time:127555ms step_avg:92.63ms +step:1378/1645 train_time:127650ms step_avg:92.63ms +step:1379/1645 train_time:127743ms step_avg:92.63ms +step:1380/1645 train_time:127836ms step_avg:92.63ms +step:1381/1645 train_time:127928ms step_avg:92.63ms +step:1382/1645 train_time:128021ms step_avg:92.63ms +step:1383/1645 train_time:128113ms step_avg:92.63ms +step:1384/1645 train_time:128206ms step_avg:92.63ms +step:1385/1645 train_time:128300ms step_avg:92.64ms +step:1386/1645 train_time:128394ms step_avg:92.64ms +step:1387/1645 train_time:128489ms step_avg:92.64ms +step:1388/1645 train_time:128584ms step_avg:92.64ms +step:1389/1645 train_time:128679ms step_avg:92.64ms +step:1390/1645 train_time:128773ms step_avg:92.64ms +step:1391/1645 train_time:128866ms step_avg:92.64ms +step:1392/1645 train_time:128960ms step_avg:92.64ms +step:1393/1645 train_time:129052ms step_avg:92.64ms +step:1394/1645 train_time:129145ms step_avg:92.64ms +step:1395/1645 train_time:129237ms step_avg:92.64ms +step:1396/1645 train_time:129330ms step_avg:92.64ms +step:1397/1645 train_time:129424ms step_avg:92.64ms +step:1398/1645 train_time:129522ms step_avg:92.65ms +step:1399/1645 train_time:129617ms step_avg:92.65ms +step:1400/1645 train_time:129710ms step_avg:92.65ms +step:1401/1645 train_time:129804ms step_avg:92.65ms +step:1402/1645 train_time:129898ms step_avg:92.65ms +step:1403/1645 train_time:129992ms step_avg:92.65ms +step:1404/1645 train_time:130085ms step_avg:92.65ms +step:1405/1645 train_time:130177ms step_avg:92.65ms +step:1406/1645 train_time:130270ms step_avg:92.65ms +step:1407/1645 train_time:130363ms step_avg:92.65ms +step:1408/1645 train_time:130457ms step_avg:92.65ms +step:1409/1645 train_time:130552ms step_avg:92.66ms +step:1410/1645 train_time:130645ms step_avg:92.66ms +step:1411/1645 train_time:130739ms step_avg:92.66ms +step:1412/1645 train_time:130832ms step_avg:92.66ms +step:1413/1645 train_time:130925ms step_avg:92.66ms +step:1414/1645 train_time:131019ms step_avg:92.66ms +step:1415/1645 train_time:131112ms step_avg:92.66ms +step:1416/1645 train_time:131206ms step_avg:92.66ms +step:1417/1645 train_time:131299ms step_avg:92.66ms +step:1418/1645 train_time:131391ms step_avg:92.66ms +step:1419/1645 train_time:131485ms step_avg:92.66ms +step:1420/1645 train_time:131579ms step_avg:92.66ms +step:1421/1645 train_time:131673ms step_avg:92.66ms +step:1422/1645 train_time:131766ms step_avg:92.66ms +step:1423/1645 train_time:131860ms step_avg:92.66ms +step:1424/1645 train_time:131953ms step_avg:92.66ms +step:1425/1645 train_time:132047ms step_avg:92.66ms +step:1426/1645 train_time:132141ms step_avg:92.67ms +step:1427/1645 train_time:132235ms step_avg:92.67ms +step:1428/1645 train_time:132327ms step_avg:92.67ms +step:1429/1645 train_time:132421ms 
step_avg:92.67ms +step:1430/1645 train_time:132516ms step_avg:92.67ms +step:1431/1645 train_time:132609ms step_avg:92.67ms +step:1432/1645 train_time:132703ms step_avg:92.67ms +step:1433/1645 train_time:132797ms step_avg:92.67ms +step:1434/1645 train_time:132892ms step_avg:92.67ms +step:1435/1645 train_time:132985ms step_avg:92.67ms +step:1436/1645 train_time:133079ms step_avg:92.67ms +step:1437/1645 train_time:133172ms step_avg:92.67ms +step:1438/1645 train_time:133264ms step_avg:92.67ms +step:1439/1645 train_time:133358ms step_avg:92.67ms +step:1440/1645 train_time:133453ms step_avg:92.68ms +step:1441/1645 train_time:133546ms step_avg:92.68ms +step:1442/1645 train_time:133641ms step_avg:92.68ms +step:1443/1645 train_time:133734ms step_avg:92.68ms +step:1444/1645 train_time:133828ms step_avg:92.68ms +step:1445/1645 train_time:133922ms step_avg:92.68ms +step:1446/1645 train_time:134017ms step_avg:92.68ms +step:1447/1645 train_time:134110ms step_avg:92.68ms +step:1448/1645 train_time:134203ms step_avg:92.68ms +step:1449/1645 train_time:134297ms step_avg:92.68ms +step:1450/1645 train_time:134391ms step_avg:92.68ms +step:1451/1645 train_time:134484ms step_avg:92.68ms +step:1452/1645 train_time:134578ms step_avg:92.68ms +step:1453/1645 train_time:134672ms step_avg:92.69ms +step:1454/1645 train_time:134765ms step_avg:92.69ms +step:1455/1645 train_time:134860ms step_avg:92.69ms +step:1456/1645 train_time:134955ms step_avg:92.69ms +step:1457/1645 train_time:135049ms step_avg:92.69ms +step:1458/1645 train_time:135141ms step_avg:92.69ms +step:1459/1645 train_time:135235ms step_avg:92.69ms +step:1460/1645 train_time:135329ms step_avg:92.69ms +step:1461/1645 train_time:135422ms step_avg:92.69ms +step:1462/1645 train_time:135518ms step_avg:92.69ms +step:1463/1645 train_time:135611ms step_avg:92.69ms +step:1464/1645 train_time:135705ms step_avg:92.69ms +step:1465/1645 train_time:135799ms step_avg:92.70ms +step:1466/1645 train_time:135893ms step_avg:92.70ms +step:1467/1645 train_time:135986ms step_avg:92.70ms +step:1468/1645 train_time:136079ms step_avg:92.70ms +step:1469/1645 train_time:136173ms step_avg:92.70ms +step:1470/1645 train_time:136266ms step_avg:92.70ms +step:1471/1645 train_time:136360ms step_avg:92.70ms +step:1472/1645 train_time:136455ms step_avg:92.70ms +step:1473/1645 train_time:136549ms step_avg:92.70ms +step:1474/1645 train_time:136642ms step_avg:92.70ms +step:1475/1645 train_time:136736ms step_avg:92.70ms +step:1476/1645 train_time:136829ms step_avg:92.70ms +step:1477/1645 train_time:136923ms step_avg:92.70ms +step:1478/1645 train_time:137017ms step_avg:92.70ms +step:1479/1645 train_time:137111ms step_avg:92.71ms +step:1480/1645 train_time:137204ms step_avg:92.71ms +step:1481/1645 train_time:137297ms step_avg:92.71ms +step:1482/1645 train_time:137390ms step_avg:92.71ms +step:1483/1645 train_time:137483ms step_avg:92.71ms +step:1484/1645 train_time:137577ms step_avg:92.71ms +step:1485/1645 train_time:137671ms step_avg:92.71ms +step:1486/1645 train_time:137766ms step_avg:92.71ms +step:1487/1645 train_time:137860ms step_avg:92.71ms +step:1488/1645 train_time:137953ms step_avg:92.71ms +step:1489/1645 train_time:138046ms step_avg:92.71ms +step:1490/1645 train_time:138139ms step_avg:92.71ms +step:1491/1645 train_time:138232ms step_avg:92.71ms +step:1492/1645 train_time:138325ms step_avg:92.71ms +step:1493/1645 train_time:138418ms step_avg:92.71ms +step:1494/1645 train_time:138512ms step_avg:92.71ms +step:1495/1645 train_time:138605ms step_avg:92.71ms +step:1496/1645 train_time:138698ms 
step_avg:92.71ms +step:1497/1645 train_time:138792ms step_avg:92.71ms +step:1498/1645 train_time:138885ms step_avg:92.71ms +step:1499/1645 train_time:138979ms step_avg:92.71ms +step:1500/1645 train_time:139073ms step_avg:92.72ms +step:1500/1645 val_loss:3.3094 train_time:139166ms step_avg:92.78ms +step:1501/1645 train_time:139191ms step_avg:92.73ms +step:1502/1645 train_time:139263ms step_avg:92.72ms +step:1503/1645 train_time:139359ms step_avg:92.72ms +step:1504/1645 train_time:139453ms step_avg:92.72ms +step:1505/1645 train_time:139547ms step_avg:92.72ms +step:1506/1645 train_time:139639ms step_avg:92.72ms +step:1507/1645 train_time:139732ms step_avg:92.72ms +step:1508/1645 train_time:139825ms step_avg:92.72ms +step:1509/1645 train_time:139917ms step_avg:92.72ms +step:1510/1645 train_time:140010ms step_avg:92.72ms +step:1511/1645 train_time:140104ms step_avg:92.72ms +step:1512/1645 train_time:140200ms step_avg:92.72ms +step:1513/1645 train_time:140295ms step_avg:92.73ms +step:1514/1645 train_time:140390ms step_avg:92.73ms +step:1515/1645 train_time:140484ms step_avg:92.73ms +step:1516/1645 train_time:140577ms step_avg:92.73ms +step:1517/1645 train_time:140670ms step_avg:92.73ms +step:1518/1645 train_time:140762ms step_avg:92.73ms +step:1519/1645 train_time:140856ms step_avg:92.73ms +step:1520/1645 train_time:140950ms step_avg:92.73ms +step:1521/1645 train_time:141043ms step_avg:92.73ms +step:1522/1645 train_time:141138ms step_avg:92.73ms +step:1523/1645 train_time:141231ms step_avg:92.73ms +step:1524/1645 train_time:141325ms step_avg:92.73ms +step:1525/1645 train_time:141419ms step_avg:92.73ms +step:1526/1645 train_time:141513ms step_avg:92.73ms +step:1527/1645 train_time:141607ms step_avg:92.74ms +step:1528/1645 train_time:141700ms step_avg:92.74ms +step:1529/1645 train_time:141793ms step_avg:92.74ms +step:1530/1645 train_time:141886ms step_avg:92.74ms +step:1531/1645 train_time:141979ms step_avg:92.74ms +step:1532/1645 train_time:142073ms step_avg:92.74ms +step:1533/1645 train_time:142167ms step_avg:92.74ms +step:1534/1645 train_time:142261ms step_avg:92.74ms +step:1535/1645 train_time:142355ms step_avg:92.74ms +step:1536/1645 train_time:142449ms step_avg:92.74ms +step:1537/1645 train_time:142543ms step_avg:92.74ms +step:1538/1645 train_time:142636ms step_avg:92.74ms +step:1539/1645 train_time:142730ms step_avg:92.74ms +step:1540/1645 train_time:142823ms step_avg:92.74ms +step:1541/1645 train_time:142917ms step_avg:92.74ms +step:1542/1645 train_time:143010ms step_avg:92.74ms +step:1543/1645 train_time:143104ms step_avg:92.74ms +step:1544/1645 train_time:143197ms step_avg:92.74ms +step:1545/1645 train_time:143290ms step_avg:92.74ms +step:1546/1645 train_time:143384ms step_avg:92.75ms +step:1547/1645 train_time:143478ms step_avg:92.75ms +step:1548/1645 train_time:143573ms step_avg:92.75ms +step:1549/1645 train_time:143667ms step_avg:92.75ms +step:1550/1645 train_time:143760ms step_avg:92.75ms +step:1551/1645 train_time:143853ms step_avg:92.75ms +step:1552/1645 train_time:143947ms step_avg:92.75ms +step:1553/1645 train_time:144041ms step_avg:92.75ms +step:1554/1645 train_time:144134ms step_avg:92.75ms +step:1555/1645 train_time:144229ms step_avg:92.75ms +step:1556/1645 train_time:144324ms step_avg:92.75ms +step:1557/1645 train_time:144417ms step_avg:92.75ms +step:1558/1645 train_time:144511ms step_avg:92.75ms +step:1559/1645 train_time:144604ms step_avg:92.75ms +step:1560/1645 train_time:144698ms step_avg:92.76ms +step:1561/1645 train_time:144792ms step_avg:92.76ms +step:1562/1645 
train_time:144885ms step_avg:92.76ms +step:1563/1645 train_time:144979ms step_avg:92.76ms +step:1564/1645 train_time:145072ms step_avg:92.76ms +step:1565/1645 train_time:145166ms step_avg:92.76ms +step:1566/1645 train_time:145260ms step_avg:92.76ms +step:1567/1645 train_time:145353ms step_avg:92.76ms +step:1568/1645 train_time:145448ms step_avg:92.76ms +step:1569/1645 train_time:145542ms step_avg:92.76ms +step:1570/1645 train_time:145635ms step_avg:92.76ms +step:1571/1645 train_time:145730ms step_avg:92.76ms +step:1572/1645 train_time:145822ms step_avg:92.76ms +step:1573/1645 train_time:145915ms step_avg:92.76ms +step:1574/1645 train_time:146008ms step_avg:92.76ms +step:1575/1645 train_time:146102ms step_avg:92.76ms +step:1576/1645 train_time:146195ms step_avg:92.76ms +step:1577/1645 train_time:146288ms step_avg:92.76ms +step:1578/1645 train_time:146382ms step_avg:92.76ms +step:1579/1645 train_time:146476ms step_avg:92.77ms +step:1580/1645 train_time:146570ms step_avg:92.77ms +step:1581/1645 train_time:146663ms step_avg:92.77ms +step:1582/1645 train_time:146756ms step_avg:92.77ms +step:1583/1645 train_time:146850ms step_avg:92.77ms +step:1584/1645 train_time:146943ms step_avg:92.77ms +step:1585/1645 train_time:147036ms step_avg:92.77ms +step:1586/1645 train_time:147131ms step_avg:92.77ms +step:1587/1645 train_time:147224ms step_avg:92.77ms +step:1588/1645 train_time:147317ms step_avg:92.77ms +step:1589/1645 train_time:147410ms step_avg:92.77ms +step:1590/1645 train_time:147504ms step_avg:92.77ms +step:1591/1645 train_time:147597ms step_avg:92.77ms +step:1592/1645 train_time:147691ms step_avg:92.77ms +step:1593/1645 train_time:147784ms step_avg:92.77ms +step:1594/1645 train_time:147878ms step_avg:92.77ms +step:1595/1645 train_time:147971ms step_avg:92.77ms +step:1596/1645 train_time:148065ms step_avg:92.77ms +step:1597/1645 train_time:148158ms step_avg:92.77ms +step:1598/1645 train_time:148252ms step_avg:92.77ms +step:1599/1645 train_time:148346ms step_avg:92.77ms +step:1600/1645 train_time:148439ms step_avg:92.77ms +step:1601/1645 train_time:148534ms step_avg:92.78ms +step:1602/1645 train_time:148628ms step_avg:92.78ms +step:1603/1645 train_time:148722ms step_avg:92.78ms +step:1604/1645 train_time:148815ms step_avg:92.78ms +step:1605/1645 train_time:148909ms step_avg:92.78ms +step:1606/1645 train_time:149003ms step_avg:92.78ms +step:1607/1645 train_time:149096ms step_avg:92.78ms +step:1608/1645 train_time:149190ms step_avg:92.78ms +step:1609/1645 train_time:149283ms step_avg:92.78ms +step:1610/1645 train_time:149377ms step_avg:92.78ms +step:1611/1645 train_time:149472ms step_avg:92.78ms +step:1612/1645 train_time:149565ms step_avg:92.78ms +step:1613/1645 train_time:149658ms step_avg:92.78ms +step:1614/1645 train_time:149753ms step_avg:92.78ms +step:1615/1645 train_time:149848ms step_avg:92.78ms +step:1616/1645 train_time:149941ms step_avg:92.79ms +step:1617/1645 train_time:150035ms step_avg:92.79ms +step:1618/1645 train_time:150128ms step_avg:92.79ms +step:1619/1645 train_time:150222ms step_avg:92.79ms +step:1620/1645 train_time:150315ms step_avg:92.79ms +step:1621/1645 train_time:150409ms step_avg:92.79ms +step:1622/1645 train_time:150503ms step_avg:92.79ms +step:1623/1645 train_time:150596ms step_avg:92.79ms +step:1624/1645 train_time:150689ms step_avg:92.79ms +step:1625/1645 train_time:150782ms step_avg:92.79ms +step:1625/1645 val_loss:3.2852 train_time:150876ms step_avg:92.85ms +step:1626/1645 train_time:150901ms step_avg:92.81ms +step:1627/1645 train_time:150976ms step_avg:92.79ms 
+step:1628/1645 train_time:151072ms step_avg:92.80ms +step:1629/1645 train_time:151165ms step_avg:92.80ms +step:1630/1645 train_time:151258ms step_avg:92.80ms +step:1631/1645 train_time:151351ms step_avg:92.80ms +step:1632/1645 train_time:151444ms step_avg:92.80ms +step:1633/1645 train_time:151537ms step_avg:92.80ms +step:1634/1645 train_time:151631ms step_avg:92.80ms +step:1635/1645 train_time:151723ms step_avg:92.80ms +step:1636/1645 train_time:151817ms step_avg:92.80ms +step:1637/1645 train_time:151912ms step_avg:92.80ms +step:1638/1645 train_time:152007ms step_avg:92.80ms +step:1639/1645 train_time:152101ms step_avg:92.80ms +step:1640/1645 train_time:152195ms step_avg:92.80ms +step:1641/1645 train_time:152288ms step_avg:92.80ms +step:1642/1645 train_time:152382ms step_avg:92.80ms +step:1643/1645 train_time:152475ms step_avg:92.80ms +step:1644/1645 train_time:152567ms step_avg:92.80ms +step:1645/1645 train_time:152661ms step_avg:92.80ms +step:1645/1645 val_loss:3.2796 train_time:152755ms step_avg:92.86ms +peak memory allocated: 32074 MiB reserved: 47316 MiB diff --git a/records/091825_Smear/1cc585a6-ecd2-452c-905d-d5774079e6ff.txt b/records/091825_Smear/1cc585a6-ecd2-452c-905d-d5774079e6ff.txt new file mode 100644 index 000000000..dd2f508b5 --- /dev/null +++ b/records/091825_Smear/1cc585a6-ecd2-452c-905d-d5774079e6ff.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert 
grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + 
+ # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + 
a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# 
Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) +
+ # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[i] # This gradient corresponds to the current parameter p. + state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer.
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, 
async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: 
suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x 
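+
+# The per-head attention gate above multiplies each head's output by
+# sigmoid(attn_gate(x[..., :12])), giving the head a context-dependent soft
+# no-op. Since attn_gate is a bias-free CastedLinear with zero-initialized
+# weights, every gate starts at sigmoid(0) = 0.5. A minimal standalone sketch
+# of the same arithmetic, kept in comments so it does not execute with this
+# script (shapes are illustrative, not this run's exact config):
+#
+#   import torch
+#   B, T, H, D = 1, 16, 6, 128                     # batch, tokens, heads, head_dim
+#   x = torch.randn(B, T, H * D)                   # block input (residual stream)
+#   y = torch.randn(B, T, H, D)                    # per-head attention outputs
+#   gate_w = torch.zeros(H, 12)                    # zero-init -> all gates open at 0.5
+#   gate = torch.sigmoid(x[..., :12] @ gate_w.T)   # (B, T, H): one scalar per head, per token
+#   y = y * gate.unsqueeze(-1)                     # heads can be softly switched off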
+ +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + smear_lambda = self.scalars[5 * len(self.blocks)] + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens = tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens ==
BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + # use idx here: cur may be unbound on the first iteration + raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration.
+ tokens, finder = preloader.get()
+ preloader.start()
+ continue
+
+ buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+ _inputs = buf[:-1]
+ _targets = buf[1:]
+ end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+ cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+ else:
+ if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+
+ pos_local = pos + rank * num_tokens_local
+ buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+ _inputs = buf[:-1].view(num_tokens_local, )
+ _targets = buf[1:].view(num_tokens_local, )
+
+ cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+ pos += num_tokens
+
+ _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+ _cum_lengths[0] = 0
+ _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+ new_params = yield (
+ _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+ _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+ _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+ )
+
+ if new_params is not None:
+ # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+ new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+ assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+ num_tokens = new_num_tokens
+ max_seq_len = new_max_seq_len
+ grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+ # data
+ train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+ val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+ val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+ train_batch_size: int = 2048 * 24 * 8
+ train_max_seq_len: int = 128 * 16
+ val_batch_size: int = 4 * 64 * 1024 * 8
+ # optimization
+ num_iterations: int = 1645 # number of iterations to run
+ cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+ # evaluation and logging
+ run_id: str = f"smear/{uuid.uuid4()}"
+ val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+ save_checkpoint: bool = False
+ # attention masking
+ block_size: int = 128
+ ws_schedule: tuple = (3, 7, 11)
+ ws_validate: int = 13 # increase final validation ws @classiclarryd
+ ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
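+# Quick sanity check of the batch arithmetic above (worked example added for
+# illustration; it assumes the default single-node setup with world_size == 8):
+#   grad_accum_steps = 8 // 8 = 1
+#   train_batch_size = 2048 * 24 * 8 = 393,216 tokens per optimizer step
+#   tokens per rank per forward = 393,216 // (1 * 8) = 49,152
+#   train_max_seq_len = 128 * 16 = 2,048 tokens, the cap on one BOS-aligned sequence
+#   val_batch_size = 4 * 64 * 1024 * 8 = 2,097,152, so the model below is built with
+#   max_seq_len = 2,097,152 // (1 * 8) = 262,144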
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up 
each with YaRN params
+ if new_ws > ws:
+ model.yarn.apply(ws, new_ws)
+ ws = new_ws
+ elif new_ws < ws:
+ model.yarn.apply(ws, new_ws)
+ ws = new_ws
+ model(inputs, targets, cum_seqlens, ws, ws).backward()
+ for opt in optimizers:
+ opt.step()
+ model.zero_grad(set_to_none=True)
+
+# restore the snapshot taken before warmup so the timed run starts from scratch
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+ opt.load_state_dict(opt_state)
+del initial_state
+ws = args.ws_schedule[0] # restoring the state dict also resets the YaRN buffers to the initial window
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+########################################
+# Training and validation #
+########################################
+
+train_steps = args.num_iterations
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+ last_step = (step == train_steps)
+ new_ws, ws_final_layer = get_ws(step)
+ if new_ws != ws:
+ model.yarn.apply(ws, new_ws) # rescale YaRN params whenever the attention window changes
+ ws = new_ws
+
+ # --------------- VALIDATION SECTION -----------------
+ if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+ # stop the clock
+ torch.cuda.synchronize()
+ training_time_ms += 1000 * (time.perf_counter() - t0)
+ model.eval()
+ assert args.val_tokens % args.val_batch_size == 0
+ val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+ val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+ val_loss = 0
+ with torch.no_grad():
+ for _ in range(val_steps):
+ inputs, targets, cum_seqlens = next(val_loader)
+ val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+ val_loss /= val_steps
+ del val_loader
+ dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+ print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+ model.train()
+ # start the clock again
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+
+ if last_step:
+ if master_process and args.save_checkpoint:
+ log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+ os.makedirs(f"logs/{run_id}", exist_ok=True)
+ torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+ # the last step only has the validation loop, so break to avoid training
+ break
+
+ # --------------- TRAINING SECTION -----------------
+ for _ in range(grad_accum_steps):
+ inputs, targets, cum_seqlens = next(train_loader)
+ model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+ # set optimization hyperparameters
+ for opt in optimizers:
+ for group in opt.param_groups:
+ group["lr"] = group["initial_lr"] * get_lr(step)
+ for group in optimizer2.param_groups:
+ frac = min(step / 300, 1) # momentum warmup for muon
+ group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+ # step the optimizers
+ for opt in optimizers:
+ opt.step()
+ # null the gradients
+ model.zero_grad(set_to_none=True)
+ # logging
+ approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+ print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:28:32 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 33C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 33C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:132ms step_avg:131.69ms +step:2/1645 train_time:147ms step_avg:73.69ms +step:3/1645 train_time:220ms step_avg:73.41ms +step:4/1645 train_time:310ms step_avg:77.48ms +step:5/1645 train_time:401ms step_avg:80.12ms +step:6/1645 train_time:492ms step_avg:82.00ms +step:7/1645 train_time:583ms step_avg:83.23ms +step:8/1645 train_time:673ms step_avg:84.14ms +step:9/1645 train_time:764ms step_avg:84.87ms +step:10/1645 train_time:854ms step_avg:85.45ms +step:11/1645 train_time:945ms step_avg:85.94ms +step:12/1645 train_time:1040ms step_avg:86.66ms +step:13/1645 train_time:1135ms step_avg:87.28ms +step:14/1645 train_time:1228ms step_avg:87.72ms +step:15/1645 train_time:1320ms step_avg:88.02ms +step:16/1645 train_time:1412ms step_avg:88.26ms +step:17/1645 train_time:1503ms step_avg:88.44ms +step:18/1645 train_time:1596ms step_avg:88.64ms +step:19/1645 train_time:1688ms step_avg:88.82ms +step:20/1645 train_time:1779ms step_avg:88.93ms +step:21/1645 train_time:1869ms step_avg:89.01ms +step:22/1645 train_time:1962ms step_avg:89.17ms +step:23/1645 
train_time:2054ms step_avg:89.32ms +step:24/1645 train_time:2149ms step_avg:89.53ms +step:25/1645 train_time:2242ms step_avg:89.69ms +step:26/1645 train_time:2334ms step_avg:89.78ms +step:27/1645 train_time:2426ms step_avg:89.85ms +step:28/1645 train_time:2518ms step_avg:89.92ms +step:29/1645 train_time:2609ms step_avg:89.98ms +step:30/1645 train_time:2701ms step_avg:90.02ms +step:31/1645 train_time:2792ms step_avg:90.06ms +step:32/1645 train_time:2883ms step_avg:90.09ms +step:33/1645 train_time:2975ms step_avg:90.15ms +step:34/1645 train_time:3067ms step_avg:90.21ms +step:35/1645 train_time:3160ms step_avg:90.30ms +step:36/1645 train_time:3253ms step_avg:90.35ms +step:37/1645 train_time:3346ms step_avg:90.42ms +step:38/1645 train_time:3438ms step_avg:90.47ms +step:39/1645 train_time:3530ms step_avg:90.51ms +step:40/1645 train_time:3621ms step_avg:90.52ms +step:41/1645 train_time:3713ms step_avg:90.55ms +step:42/1645 train_time:3804ms step_avg:90.56ms +step:43/1645 train_time:3895ms step_avg:90.59ms +step:44/1645 train_time:3987ms step_avg:90.62ms +step:45/1645 train_time:4079ms step_avg:90.64ms +step:46/1645 train_time:4171ms step_avg:90.68ms +step:47/1645 train_time:4264ms step_avg:90.72ms +step:48/1645 train_time:4356ms step_avg:90.75ms +step:49/1645 train_time:4449ms step_avg:90.80ms +step:50/1645 train_time:4542ms step_avg:90.84ms +step:51/1645 train_time:4634ms step_avg:90.87ms +step:52/1645 train_time:4725ms step_avg:90.86ms +step:53/1645 train_time:4816ms step_avg:90.86ms +step:54/1645 train_time:4907ms step_avg:90.87ms +step:55/1645 train_time:4999ms step_avg:90.89ms +step:56/1645 train_time:5090ms step_avg:90.90ms +step:57/1645 train_time:5182ms step_avg:90.92ms +step:58/1645 train_time:5274ms step_avg:90.94ms +step:59/1645 train_time:5368ms step_avg:90.98ms +step:60/1645 train_time:5460ms step_avg:91.00ms +step:61/1645 train_time:5552ms step_avg:91.01ms +step:62/1645 train_time:5643ms step_avg:91.02ms +step:63/1645 train_time:5736ms step_avg:91.04ms +step:64/1645 train_time:5826ms step_avg:91.04ms +step:65/1645 train_time:5918ms step_avg:91.04ms +step:66/1645 train_time:6009ms step_avg:91.04ms +step:67/1645 train_time:6100ms step_avg:91.04ms +step:68/1645 train_time:6192ms step_avg:91.06ms +step:69/1645 train_time:6284ms step_avg:91.07ms +step:70/1645 train_time:6375ms step_avg:91.07ms +step:71/1645 train_time:6468ms step_avg:91.10ms +step:72/1645 train_time:6560ms step_avg:91.11ms +step:73/1645 train_time:6652ms step_avg:91.12ms +step:74/1645 train_time:6744ms step_avg:91.14ms +step:75/1645 train_time:6834ms step_avg:91.13ms +step:76/1645 train_time:6926ms step_avg:91.13ms +step:77/1645 train_time:7017ms step_avg:91.13ms +step:78/1645 train_time:7109ms step_avg:91.14ms +step:79/1645 train_time:7200ms step_avg:91.14ms +step:80/1645 train_time:7292ms step_avg:91.15ms +step:81/1645 train_time:7385ms step_avg:91.17ms +step:82/1645 train_time:7476ms step_avg:91.17ms +step:83/1645 train_time:7568ms step_avg:91.19ms +step:84/1645 train_time:7661ms step_avg:91.20ms +step:85/1645 train_time:7753ms step_avg:91.21ms +step:86/1645 train_time:7847ms step_avg:91.24ms +step:87/1645 train_time:7938ms step_avg:91.24ms +step:88/1645 train_time:8029ms step_avg:91.24ms +step:89/1645 train_time:8121ms step_avg:91.25ms +step:90/1645 train_time:8212ms step_avg:91.25ms +step:91/1645 train_time:8304ms step_avg:91.25ms +step:92/1645 train_time:8396ms step_avg:91.26ms +step:93/1645 train_time:8488ms step_avg:91.27ms +step:94/1645 train_time:8580ms step_avg:91.28ms +step:95/1645 train_time:8672ms 
step_avg:91.29ms +step:96/1645 train_time:8764ms step_avg:91.29ms +step:97/1645 train_time:8855ms step_avg:91.29ms +step:98/1645 train_time:8947ms step_avg:91.29ms +step:99/1645 train_time:9038ms step_avg:91.30ms +step:100/1645 train_time:9130ms step_avg:91.30ms +step:101/1645 train_time:9221ms step_avg:91.30ms +step:102/1645 train_time:9312ms step_avg:91.29ms +step:103/1645 train_time:9403ms step_avg:91.29ms +step:104/1645 train_time:9494ms step_avg:91.29ms +step:105/1645 train_time:9587ms step_avg:91.31ms +step:106/1645 train_time:9679ms step_avg:91.31ms +step:107/1645 train_time:9770ms step_avg:91.31ms +step:108/1645 train_time:9862ms step_avg:91.32ms +step:109/1645 train_time:9955ms step_avg:91.33ms +step:110/1645 train_time:10046ms step_avg:91.33ms +step:111/1645 train_time:10138ms step_avg:91.33ms +step:112/1645 train_time:10230ms step_avg:91.34ms +step:113/1645 train_time:10321ms step_avg:91.34ms +step:114/1645 train_time:10413ms step_avg:91.34ms +step:115/1645 train_time:10505ms step_avg:91.35ms +step:116/1645 train_time:10596ms step_avg:91.35ms +step:117/1645 train_time:10688ms step_avg:91.35ms +step:118/1645 train_time:10780ms step_avg:91.35ms +step:119/1645 train_time:10871ms step_avg:91.36ms +step:120/1645 train_time:10964ms step_avg:91.37ms +step:121/1645 train_time:11055ms step_avg:91.37ms +step:122/1645 train_time:11147ms step_avg:91.37ms +step:123/1645 train_time:11239ms step_avg:91.38ms +step:124/1645 train_time:11331ms step_avg:91.38ms +step:125/1645 train_time:11423ms step_avg:91.38ms +step:125/1645 val_loss:4.3081 train_time:11514ms step_avg:92.11ms +step:126/1645 train_time:11535ms step_avg:91.55ms +step:127/1645 train_time:11612ms step_avg:91.43ms +step:128/1645 train_time:11715ms step_avg:91.52ms +step:129/1645 train_time:11810ms step_avg:91.55ms +step:130/1645 train_time:11903ms step_avg:91.56ms +step:131/1645 train_time:11993ms step_avg:91.55ms +step:132/1645 train_time:12084ms step_avg:91.54ms +step:133/1645 train_time:12174ms step_avg:91.54ms +step:134/1645 train_time:12265ms step_avg:91.53ms +step:135/1645 train_time:12356ms step_avg:91.52ms +step:136/1645 train_time:12447ms step_avg:91.53ms +step:137/1645 train_time:12541ms step_avg:91.54ms +step:138/1645 train_time:12635ms step_avg:91.56ms +step:139/1645 train_time:12728ms step_avg:91.57ms +step:140/1645 train_time:12821ms step_avg:91.58ms +step:141/1645 train_time:12913ms step_avg:91.58ms +step:142/1645 train_time:13005ms step_avg:91.58ms +step:143/1645 train_time:13096ms step_avg:91.58ms +step:144/1645 train_time:13187ms step_avg:91.58ms +step:145/1645 train_time:13278ms step_avg:91.57ms +step:146/1645 train_time:13369ms step_avg:91.57ms +step:147/1645 train_time:13461ms step_avg:91.57ms +step:148/1645 train_time:13555ms step_avg:91.59ms +step:149/1645 train_time:13648ms step_avg:91.60ms +step:150/1645 train_time:13740ms step_avg:91.60ms +step:151/1645 train_time:13832ms step_avg:91.60ms +step:152/1645 train_time:13924ms step_avg:91.60ms +step:153/1645 train_time:14015ms step_avg:91.60ms +step:154/1645 train_time:14106ms step_avg:91.60ms +step:155/1645 train_time:14198ms step_avg:91.60ms +step:156/1645 train_time:14289ms step_avg:91.60ms +step:157/1645 train_time:14380ms step_avg:91.59ms +step:158/1645 train_time:14471ms step_avg:91.59ms +step:159/1645 train_time:14564ms step_avg:91.60ms +step:160/1645 train_time:14657ms step_avg:91.60ms +step:161/1645 train_time:14748ms step_avg:91.60ms +step:162/1645 train_time:14840ms step_avg:91.61ms +step:163/1645 train_time:14932ms step_avg:91.61ms +step:164/1645 
train_time:15025ms step_avg:91.61ms +step:165/1645 train_time:15116ms step_avg:91.61ms +step:166/1645 train_time:15207ms step_avg:91.61ms +step:167/1645 train_time:15299ms step_avg:91.61ms +step:168/1645 train_time:15390ms step_avg:91.60ms +step:169/1645 train_time:15482ms step_avg:91.61ms +step:170/1645 train_time:15574ms step_avg:91.61ms +step:171/1645 train_time:15666ms step_avg:91.61ms +step:172/1645 train_time:15758ms step_avg:91.62ms +step:173/1645 train_time:15849ms step_avg:91.61ms +step:174/1645 train_time:15941ms step_avg:91.61ms +step:175/1645 train_time:16033ms step_avg:91.61ms +step:176/1645 train_time:16124ms step_avg:91.61ms +step:177/1645 train_time:16216ms step_avg:91.61ms +step:178/1645 train_time:16307ms step_avg:91.61ms +step:179/1645 train_time:16399ms step_avg:91.61ms +step:180/1645 train_time:16491ms step_avg:91.62ms +step:181/1645 train_time:16582ms step_avg:91.62ms +step:182/1645 train_time:16674ms step_avg:91.62ms +step:183/1645 train_time:16766ms step_avg:91.62ms +step:184/1645 train_time:16858ms step_avg:91.62ms +step:185/1645 train_time:16949ms step_avg:91.61ms +step:186/1645 train_time:17040ms step_avg:91.61ms +step:187/1645 train_time:17132ms step_avg:91.61ms +step:188/1645 train_time:17224ms step_avg:91.62ms +step:189/1645 train_time:17315ms step_avg:91.62ms +step:190/1645 train_time:17407ms step_avg:91.61ms +step:191/1645 train_time:17499ms step_avg:91.62ms +step:192/1645 train_time:17590ms step_avg:91.62ms +step:193/1645 train_time:17682ms step_avg:91.62ms +step:194/1645 train_time:17773ms step_avg:91.61ms +step:195/1645 train_time:17866ms step_avg:91.62ms +step:196/1645 train_time:17958ms step_avg:91.62ms +step:197/1645 train_time:18049ms step_avg:91.62ms +step:198/1645 train_time:18141ms step_avg:91.62ms +step:199/1645 train_time:18233ms step_avg:91.62ms +step:200/1645 train_time:18324ms step_avg:91.62ms +step:201/1645 train_time:18416ms step_avg:91.62ms +step:202/1645 train_time:18507ms step_avg:91.62ms +step:203/1645 train_time:18599ms step_avg:91.62ms +step:204/1645 train_time:18690ms step_avg:91.62ms +step:205/1645 train_time:18782ms step_avg:91.62ms +step:206/1645 train_time:18873ms step_avg:91.62ms +step:207/1645 train_time:18965ms step_avg:91.62ms +step:208/1645 train_time:19057ms step_avg:91.62ms +step:209/1645 train_time:19148ms step_avg:91.62ms +step:210/1645 train_time:19241ms step_avg:91.63ms +step:211/1645 train_time:19333ms step_avg:91.62ms +step:212/1645 train_time:19425ms step_avg:91.63ms +step:213/1645 train_time:19517ms step_avg:91.63ms +step:214/1645 train_time:19607ms step_avg:91.62ms +step:215/1645 train_time:19699ms step_avg:91.62ms +step:216/1645 train_time:19791ms step_avg:91.63ms +step:217/1645 train_time:19882ms step_avg:91.62ms +step:218/1645 train_time:19974ms step_avg:91.62ms +step:219/1645 train_time:20065ms step_avg:91.62ms +step:220/1645 train_time:20157ms step_avg:91.62ms +step:221/1645 train_time:20248ms step_avg:91.62ms +step:222/1645 train_time:20340ms step_avg:91.62ms +step:223/1645 train_time:20430ms step_avg:91.62ms +step:224/1645 train_time:20523ms step_avg:91.62ms +step:225/1645 train_time:20615ms step_avg:91.62ms +step:226/1645 train_time:20707ms step_avg:91.63ms +step:227/1645 train_time:20800ms step_avg:91.63ms +step:228/1645 train_time:20892ms step_avg:91.63ms +step:229/1645 train_time:20985ms step_avg:91.64ms +step:230/1645 train_time:21076ms step_avg:91.64ms +step:231/1645 train_time:21168ms step_avg:91.64ms +step:232/1645 train_time:21260ms step_avg:91.64ms +step:233/1645 train_time:21350ms step_avg:91.63ms 
+step:234/1645 train_time:21442ms step_avg:91.63ms +step:235/1645 train_time:21533ms step_avg:91.63ms +step:236/1645 train_time:21625ms step_avg:91.63ms +step:237/1645 train_time:21717ms step_avg:91.63ms +step:238/1645 train_time:21809ms step_avg:91.63ms +step:239/1645 train_time:21903ms step_avg:91.64ms +step:240/1645 train_time:21996ms step_avg:91.65ms +step:241/1645 train_time:22087ms step_avg:91.65ms +step:242/1645 train_time:22179ms step_avg:91.65ms +step:243/1645 train_time:22270ms step_avg:91.64ms +step:244/1645 train_time:22361ms step_avg:91.64ms +step:245/1645 train_time:22452ms step_avg:91.64ms +step:246/1645 train_time:22543ms step_avg:91.64ms +step:247/1645 train_time:22634ms step_avg:91.64ms +step:248/1645 train_time:22726ms step_avg:91.64ms +step:249/1645 train_time:22817ms step_avg:91.64ms +step:250/1645 train_time:22910ms step_avg:91.64ms +step:250/1645 val_loss:3.9670 train_time:23004ms step_avg:92.01ms +step:251/1645 train_time:23024ms step_avg:91.73ms +step:252/1645 train_time:23098ms step_avg:91.66ms +step:253/1645 train_time:23194ms step_avg:91.68ms +step:254/1645 train_time:23285ms step_avg:91.67ms +step:255/1645 train_time:23376ms step_avg:91.67ms +step:256/1645 train_time:23466ms step_avg:91.66ms +step:257/1645 train_time:23557ms step_avg:91.66ms +step:258/1645 train_time:23648ms step_avg:91.66ms +step:259/1645 train_time:23738ms step_avg:91.65ms +step:260/1645 train_time:23830ms step_avg:91.65ms +step:261/1645 train_time:23921ms step_avg:91.65ms +step:262/1645 train_time:24016ms step_avg:91.66ms +step:263/1645 train_time:24110ms step_avg:91.67ms +step:264/1645 train_time:24202ms step_avg:91.67ms +step:265/1645 train_time:24294ms step_avg:91.67ms +step:266/1645 train_time:24385ms step_avg:91.67ms +step:267/1645 train_time:24475ms step_avg:91.67ms +step:268/1645 train_time:24566ms step_avg:91.66ms +step:269/1645 train_time:24657ms step_avg:91.66ms +step:270/1645 train_time:24748ms step_avg:91.66ms +step:271/1645 train_time:24839ms step_avg:91.66ms +step:272/1645 train_time:24932ms step_avg:91.66ms +step:273/1645 train_time:25026ms step_avg:91.67ms +step:274/1645 train_time:25119ms step_avg:91.67ms +step:275/1645 train_time:25212ms step_avg:91.68ms +step:276/1645 train_time:25303ms step_avg:91.68ms +step:277/1645 train_time:25395ms step_avg:91.68ms +step:278/1645 train_time:25487ms step_avg:91.68ms +step:279/1645 train_time:25578ms step_avg:91.68ms +step:280/1645 train_time:25668ms step_avg:91.67ms +step:281/1645 train_time:25759ms step_avg:91.67ms +step:282/1645 train_time:25851ms step_avg:91.67ms +step:283/1645 train_time:25943ms step_avg:91.67ms +step:284/1645 train_time:26036ms step_avg:91.68ms +step:285/1645 train_time:26129ms step_avg:91.68ms +step:286/1645 train_time:26221ms step_avg:91.68ms +step:287/1645 train_time:26312ms step_avg:91.68ms +step:288/1645 train_time:26404ms step_avg:91.68ms +step:289/1645 train_time:26495ms step_avg:91.68ms +step:290/1645 train_time:26586ms step_avg:91.68ms +step:291/1645 train_time:26678ms step_avg:91.68ms +step:292/1645 train_time:26768ms step_avg:91.67ms +step:293/1645 train_time:26859ms step_avg:91.67ms +step:294/1645 train_time:26953ms step_avg:91.68ms +step:295/1645 train_time:27046ms step_avg:91.68ms +step:296/1645 train_time:27139ms step_avg:91.68ms +step:297/1645 train_time:27230ms step_avg:91.68ms +step:298/1645 train_time:27323ms step_avg:91.69ms +step:299/1645 train_time:27414ms step_avg:91.69ms +step:300/1645 train_time:27506ms step_avg:91.69ms +step:301/1645 train_time:27596ms step_avg:91.68ms +step:302/1645 
train_time:27687ms step_avg:91.68ms +step:303/1645 train_time:27779ms step_avg:91.68ms +step:304/1645 train_time:27871ms step_avg:91.68ms +step:305/1645 train_time:27963ms step_avg:91.68ms +step:306/1645 train_time:28056ms step_avg:91.69ms +step:307/1645 train_time:28147ms step_avg:91.68ms +step:308/1645 train_time:28239ms step_avg:91.69ms +step:309/1645 train_time:28331ms step_avg:91.69ms +step:310/1645 train_time:28424ms step_avg:91.69ms +step:311/1645 train_time:28515ms step_avg:91.69ms +step:312/1645 train_time:28607ms step_avg:91.69ms +step:313/1645 train_time:28697ms step_avg:91.69ms +step:314/1645 train_time:28788ms step_avg:91.68ms +step:315/1645 train_time:28879ms step_avg:91.68ms +step:316/1645 train_time:28971ms step_avg:91.68ms +step:317/1645 train_time:29062ms step_avg:91.68ms +step:318/1645 train_time:29155ms step_avg:91.68ms +step:319/1645 train_time:29248ms step_avg:91.69ms +step:320/1645 train_time:29340ms step_avg:91.69ms +step:321/1645 train_time:29433ms step_avg:91.69ms +step:322/1645 train_time:29525ms step_avg:91.69ms +step:323/1645 train_time:29616ms step_avg:91.69ms +step:324/1645 train_time:29707ms step_avg:91.69ms +step:325/1645 train_time:29799ms step_avg:91.69ms +step:326/1645 train_time:29890ms step_avg:91.69ms +step:327/1645 train_time:29981ms step_avg:91.68ms +step:328/1645 train_time:30073ms step_avg:91.68ms +step:329/1645 train_time:30164ms step_avg:91.68ms +step:330/1645 train_time:30256ms step_avg:91.69ms +step:331/1645 train_time:30348ms step_avg:91.69ms +step:332/1645 train_time:30440ms step_avg:91.69ms +step:333/1645 train_time:30532ms step_avg:91.69ms +step:334/1645 train_time:30624ms step_avg:91.69ms +step:335/1645 train_time:30715ms step_avg:91.69ms +step:336/1645 train_time:30806ms step_avg:91.69ms +step:337/1645 train_time:30898ms step_avg:91.69ms +step:338/1645 train_time:30990ms step_avg:91.69ms +step:339/1645 train_time:31081ms step_avg:91.68ms +step:340/1645 train_time:31172ms step_avg:91.68ms +step:341/1645 train_time:31265ms step_avg:91.69ms +step:342/1645 train_time:31357ms step_avg:91.69ms +step:343/1645 train_time:31449ms step_avg:91.69ms +step:344/1645 train_time:31543ms step_avg:91.70ms +step:345/1645 train_time:31635ms step_avg:91.70ms +step:346/1645 train_time:31726ms step_avg:91.70ms +step:347/1645 train_time:31818ms step_avg:91.69ms +step:348/1645 train_time:31909ms step_avg:91.69ms +step:349/1645 train_time:32000ms step_avg:91.69ms +step:350/1645 train_time:32092ms step_avg:91.69ms +step:351/1645 train_time:32184ms step_avg:91.69ms +step:352/1645 train_time:32275ms step_avg:91.69ms +step:353/1645 train_time:32367ms step_avg:91.69ms +step:354/1645 train_time:32458ms step_avg:91.69ms +step:355/1645 train_time:32552ms step_avg:91.69ms +step:356/1645 train_time:32643ms step_avg:91.69ms +step:357/1645 train_time:32735ms step_avg:91.70ms +step:358/1645 train_time:32827ms step_avg:91.69ms +step:359/1645 train_time:32918ms step_avg:91.69ms +step:360/1645 train_time:33009ms step_avg:91.69ms +step:361/1645 train_time:33100ms step_avg:91.69ms +step:362/1645 train_time:33192ms step_avg:91.69ms +step:363/1645 train_time:33283ms step_avg:91.69ms +step:364/1645 train_time:33375ms step_avg:91.69ms +step:365/1645 train_time:33466ms step_avg:91.69ms +step:366/1645 train_time:33558ms step_avg:91.69ms +step:367/1645 train_time:33651ms step_avg:91.69ms +step:368/1645 train_time:33742ms step_avg:91.69ms +step:369/1645 train_time:33834ms step_avg:91.69ms +step:370/1645 train_time:33926ms step_avg:91.69ms +step:371/1645 train_time:34017ms step_avg:91.69ms 
+step:372/1645 train_time:34108ms step_avg:91.69ms +step:373/1645 train_time:34199ms step_avg:91.69ms +step:374/1645 train_time:34290ms step_avg:91.69ms +step:375/1645 train_time:34381ms step_avg:91.68ms +step:375/1645 val_loss:3.8172 train_time:34473ms step_avg:91.93ms +step:376/1645 train_time:34494ms step_avg:91.74ms +step:377/1645 train_time:34571ms step_avg:91.70ms +step:378/1645 train_time:34664ms step_avg:91.70ms +step:379/1645 train_time:34756ms step_avg:91.70ms +step:380/1645 train_time:34846ms step_avg:91.70ms +step:381/1645 train_time:34938ms step_avg:91.70ms +step:382/1645 train_time:35029ms step_avg:91.70ms +step:383/1645 train_time:35121ms step_avg:91.70ms +step:384/1645 train_time:35212ms step_avg:91.70ms +step:385/1645 train_time:35302ms step_avg:91.69ms +step:386/1645 train_time:35394ms step_avg:91.70ms +step:387/1645 train_time:35488ms step_avg:91.70ms +step:388/1645 train_time:35582ms step_avg:91.71ms +step:389/1645 train_time:35675ms step_avg:91.71ms +step:390/1645 train_time:35767ms step_avg:91.71ms +step:391/1645 train_time:35858ms step_avg:91.71ms +step:392/1645 train_time:35950ms step_avg:91.71ms +step:393/1645 train_time:36041ms step_avg:91.71ms +step:394/1645 train_time:36131ms step_avg:91.70ms +step:395/1645 train_time:36223ms step_avg:91.70ms +step:396/1645 train_time:36313ms step_avg:91.70ms +step:397/1645 train_time:36405ms step_avg:91.70ms +step:398/1645 train_time:36498ms step_avg:91.70ms +step:399/1645 train_time:36592ms step_avg:91.71ms +step:400/1645 train_time:36684ms step_avg:91.71ms +step:401/1645 train_time:36775ms step_avg:91.71ms +step:402/1645 train_time:36867ms step_avg:91.71ms +step:403/1645 train_time:36958ms step_avg:91.71ms +step:404/1645 train_time:37049ms step_avg:91.71ms +step:405/1645 train_time:37140ms step_avg:91.70ms +step:406/1645 train_time:37231ms step_avg:91.70ms +step:407/1645 train_time:37321ms step_avg:91.70ms +step:408/1645 train_time:37413ms step_avg:91.70ms +step:409/1645 train_time:37506ms step_avg:91.70ms +step:410/1645 train_time:37598ms step_avg:91.70ms +step:411/1645 train_time:37690ms step_avg:91.70ms +step:412/1645 train_time:37782ms step_avg:91.70ms +step:413/1645 train_time:37874ms step_avg:91.71ms +step:414/1645 train_time:37966ms step_avg:91.71ms +step:415/1645 train_time:38057ms step_avg:91.70ms +step:416/1645 train_time:38148ms step_avg:91.70ms +step:417/1645 train_time:38240ms step_avg:91.70ms +step:418/1645 train_time:38331ms step_avg:91.70ms +step:419/1645 train_time:38422ms step_avg:91.70ms +step:420/1645 train_time:38514ms step_avg:91.70ms +step:421/1645 train_time:38607ms step_avg:91.70ms +step:422/1645 train_time:38700ms step_avg:91.71ms +step:423/1645 train_time:38791ms step_avg:91.71ms +step:424/1645 train_time:38883ms step_avg:91.70ms +step:425/1645 train_time:38974ms step_avg:91.70ms +step:426/1645 train_time:39066ms step_avg:91.70ms +step:427/1645 train_time:39157ms step_avg:91.70ms +step:428/1645 train_time:39248ms step_avg:91.70ms +step:429/1645 train_time:39339ms step_avg:91.70ms +step:430/1645 train_time:39430ms step_avg:91.70ms +step:431/1645 train_time:39521ms step_avg:91.70ms +step:432/1645 train_time:39613ms step_avg:91.70ms +step:433/1645 train_time:39706ms step_avg:91.70ms +step:434/1645 train_time:39798ms step_avg:91.70ms +step:435/1645 train_time:39891ms step_avg:91.70ms +step:436/1645 train_time:39982ms step_avg:91.70ms +step:437/1645 train_time:40074ms step_avg:91.70ms +step:438/1645 train_time:40165ms step_avg:91.70ms +step:439/1645 train_time:40256ms step_avg:91.70ms +step:440/1645 
train_time:40347ms step_avg:91.70ms +step:441/1645 train_time:40439ms step_avg:91.70ms +step:442/1645 train_time:40530ms step_avg:91.70ms +step:443/1645 train_time:40622ms step_avg:91.70ms +step:444/1645 train_time:40713ms step_avg:91.70ms +step:445/1645 train_time:40805ms step_avg:91.70ms +step:446/1645 train_time:40897ms step_avg:91.70ms +step:447/1645 train_time:40990ms step_avg:91.70ms +step:448/1645 train_time:41081ms step_avg:91.70ms +step:449/1645 train_time:41173ms step_avg:91.70ms +step:450/1645 train_time:41265ms step_avg:91.70ms +step:451/1645 train_time:41356ms step_avg:91.70ms +step:452/1645 train_time:41448ms step_avg:91.70ms +step:453/1645 train_time:41539ms step_avg:91.70ms +step:454/1645 train_time:41630ms step_avg:91.70ms +step:455/1645 train_time:41721ms step_avg:91.70ms +step:456/1645 train_time:41813ms step_avg:91.70ms +step:457/1645 train_time:41906ms step_avg:91.70ms +step:458/1645 train_time:41998ms step_avg:91.70ms +step:459/1645 train_time:42090ms step_avg:91.70ms +step:460/1645 train_time:42182ms step_avg:91.70ms +step:461/1645 train_time:42274ms step_avg:91.70ms +step:462/1645 train_time:42365ms step_avg:91.70ms +step:463/1645 train_time:42456ms step_avg:91.70ms +step:464/1645 train_time:42547ms step_avg:91.70ms +step:465/1645 train_time:42639ms step_avg:91.70ms +step:466/1645 train_time:42731ms step_avg:91.70ms +step:467/1645 train_time:42823ms step_avg:91.70ms +step:468/1645 train_time:42915ms step_avg:91.70ms +step:469/1645 train_time:43009ms step_avg:91.70ms +step:470/1645 train_time:43101ms step_avg:91.70ms +step:471/1645 train_time:43192ms step_avg:91.70ms +step:472/1645 train_time:43284ms step_avg:91.70ms +step:473/1645 train_time:43375ms step_avg:91.70ms +step:474/1645 train_time:43467ms step_avg:91.70ms +step:475/1645 train_time:43558ms step_avg:91.70ms +step:476/1645 train_time:43650ms step_avg:91.70ms +step:477/1645 train_time:43741ms step_avg:91.70ms +step:478/1645 train_time:43833ms step_avg:91.70ms +step:479/1645 train_time:43924ms step_avg:91.70ms +step:480/1645 train_time:44016ms step_avg:91.70ms +step:481/1645 train_time:44108ms step_avg:91.70ms +step:482/1645 train_time:44201ms step_avg:91.70ms +step:483/1645 train_time:44292ms step_avg:91.70ms +step:484/1645 train_time:44384ms step_avg:91.70ms +step:485/1645 train_time:44474ms step_avg:91.70ms +step:486/1645 train_time:44566ms step_avg:91.70ms +step:487/1645 train_time:44659ms step_avg:91.70ms +step:488/1645 train_time:44750ms step_avg:91.70ms +step:489/1645 train_time:44841ms step_avg:91.70ms +step:490/1645 train_time:44933ms step_avg:91.70ms +step:491/1645 train_time:45025ms step_avg:91.70ms +step:492/1645 train_time:45116ms step_avg:91.70ms +step:493/1645 train_time:45209ms step_avg:91.70ms +step:494/1645 train_time:45300ms step_avg:91.70ms +step:495/1645 train_time:45392ms step_avg:91.70ms +step:496/1645 train_time:45483ms step_avg:91.70ms +step:497/1645 train_time:45575ms step_avg:91.70ms +step:498/1645 train_time:45667ms step_avg:91.70ms +step:499/1645 train_time:45758ms step_avg:91.70ms +step:500/1645 train_time:45849ms step_avg:91.70ms +step:500/1645 val_loss:3.7134 train_time:45941ms step_avg:91.88ms +step:501/1645 train_time:45962ms step_avg:91.74ms +step:502/1645 train_time:46038ms step_avg:91.71ms +step:503/1645 train_time:46130ms step_avg:91.71ms +step:504/1645 train_time:46222ms step_avg:91.71ms +step:505/1645 train_time:46313ms step_avg:91.71ms +step:506/1645 train_time:46403ms step_avg:91.71ms +step:507/1645 train_time:46493ms step_avg:91.70ms +step:508/1645 train_time:46585ms 
step_avg:91.70ms +step:509/1645 train_time:46676ms step_avg:91.70ms +step:510/1645 train_time:46767ms step_avg:91.70ms +step:511/1645 train_time:46859ms step_avg:91.70ms +step:512/1645 train_time:46952ms step_avg:91.70ms +step:513/1645 train_time:47046ms step_avg:91.71ms +step:514/1645 train_time:47138ms step_avg:91.71ms +step:515/1645 train_time:47230ms step_avg:91.71ms +step:516/1645 train_time:47321ms step_avg:91.71ms +step:517/1645 train_time:47412ms step_avg:91.71ms +step:518/1645 train_time:47502ms step_avg:91.70ms +step:519/1645 train_time:47594ms step_avg:91.70ms +step:520/1645 train_time:47685ms step_avg:91.70ms +step:521/1645 train_time:47777ms step_avg:91.70ms +step:522/1645 train_time:47868ms step_avg:91.70ms +step:523/1645 train_time:47960ms step_avg:91.70ms +step:524/1645 train_time:48053ms step_avg:91.70ms +step:525/1645 train_time:48146ms step_avg:91.71ms +step:526/1645 train_time:48238ms step_avg:91.71ms +step:527/1645 train_time:48329ms step_avg:91.71ms +step:528/1645 train_time:48422ms step_avg:91.71ms +step:529/1645 train_time:48512ms step_avg:91.71ms +step:530/1645 train_time:48602ms step_avg:91.70ms +step:531/1645 train_time:48693ms step_avg:91.70ms +step:532/1645 train_time:48784ms step_avg:91.70ms +step:533/1645 train_time:48876ms step_avg:91.70ms +step:534/1645 train_time:48968ms step_avg:91.70ms +step:535/1645 train_time:49060ms step_avg:91.70ms +step:536/1645 train_time:49152ms step_avg:91.70ms +step:537/1645 train_time:49244ms step_avg:91.70ms +step:538/1645 train_time:49336ms step_avg:91.70ms +step:539/1645 train_time:49429ms step_avg:91.70ms +step:540/1645 train_time:49520ms step_avg:91.70ms +step:541/1645 train_time:49611ms step_avg:91.70ms +step:542/1645 train_time:49702ms step_avg:91.70ms +step:543/1645 train_time:49793ms step_avg:91.70ms +step:544/1645 train_time:49883ms step_avg:91.70ms +step:545/1645 train_time:49976ms step_avg:91.70ms +step:546/1645 train_time:50067ms step_avg:91.70ms +step:547/1645 train_time:50159ms step_avg:91.70ms +step:548/1645 train_time:50251ms step_avg:91.70ms +step:549/1645 train_time:50343ms step_avg:91.70ms +step:550/1645 train_time:50436ms step_avg:91.70ms +step:551/1645 train_time:50528ms step_avg:91.70ms +step:552/1645 train_time:50621ms step_avg:91.70ms +step:553/1645 train_time:50714ms step_avg:91.71ms +step:554/1645 train_time:50807ms step_avg:91.71ms +step:555/1645 train_time:50899ms step_avg:91.71ms +step:556/1645 train_time:50993ms step_avg:91.71ms +step:557/1645 train_time:51086ms step_avg:91.72ms +step:558/1645 train_time:51180ms step_avg:91.72ms +step:559/1645 train_time:51274ms step_avg:91.72ms +step:560/1645 train_time:51367ms step_avg:91.73ms +step:561/1645 train_time:51459ms step_avg:91.73ms +step:562/1645 train_time:51552ms step_avg:91.73ms +step:563/1645 train_time:51645ms step_avg:91.73ms +step:564/1645 train_time:51738ms step_avg:91.73ms +step:565/1645 train_time:51831ms step_avg:91.74ms +step:566/1645 train_time:51923ms step_avg:91.74ms +step:567/1645 train_time:52016ms step_avg:91.74ms +step:568/1645 train_time:52109ms step_avg:91.74ms +step:569/1645 train_time:52202ms step_avg:91.74ms +step:570/1645 train_time:52295ms step_avg:91.75ms +step:571/1645 train_time:52388ms step_avg:91.75ms +step:572/1645 train_time:52481ms step_avg:91.75ms +step:573/1645 train_time:52573ms step_avg:91.75ms +step:574/1645 train_time:52666ms step_avg:91.75ms +step:575/1645 train_time:52759ms step_avg:91.75ms +step:576/1645 train_time:52852ms step_avg:91.76ms +step:577/1645 train_time:52945ms step_avg:91.76ms +step:578/1645 
train_time:53039ms step_avg:91.76ms +step:579/1645 train_time:53131ms step_avg:91.76ms +step:580/1645 train_time:53224ms step_avg:91.77ms +step:581/1645 train_time:53317ms step_avg:91.77ms +step:582/1645 train_time:53410ms step_avg:91.77ms +step:583/1645 train_time:53503ms step_avg:91.77ms +step:584/1645 train_time:53596ms step_avg:91.77ms +step:585/1645 train_time:53688ms step_avg:91.77ms +step:586/1645 train_time:53781ms step_avg:91.78ms +step:587/1645 train_time:53874ms step_avg:91.78ms +step:588/1645 train_time:53968ms step_avg:91.78ms +step:589/1645 train_time:54061ms step_avg:91.78ms +step:590/1645 train_time:54153ms step_avg:91.79ms +step:591/1645 train_time:54248ms step_avg:91.79ms +step:592/1645 train_time:54340ms step_avg:91.79ms +step:593/1645 train_time:54434ms step_avg:91.79ms +step:594/1645 train_time:54526ms step_avg:91.79ms +step:595/1645 train_time:54619ms step_avg:91.80ms +step:596/1645 train_time:54712ms step_avg:91.80ms +step:597/1645 train_time:54804ms step_avg:91.80ms +step:598/1645 train_time:54897ms step_avg:91.80ms +step:599/1645 train_time:54990ms step_avg:91.80ms +step:600/1645 train_time:55083ms step_avg:91.80ms +step:601/1645 train_time:55176ms step_avg:91.81ms +step:602/1645 train_time:55270ms step_avg:91.81ms +step:603/1645 train_time:55362ms step_avg:91.81ms +step:604/1645 train_time:55456ms step_avg:91.81ms +step:605/1645 train_time:55550ms step_avg:91.82ms +step:606/1645 train_time:55642ms step_avg:91.82ms +step:607/1645 train_time:55735ms step_avg:91.82ms +step:608/1645 train_time:55828ms step_avg:91.82ms +step:609/1645 train_time:55921ms step_avg:91.82ms +step:610/1645 train_time:56014ms step_avg:91.83ms +step:611/1645 train_time:56107ms step_avg:91.83ms +step:612/1645 train_time:56200ms step_avg:91.83ms +step:613/1645 train_time:56293ms step_avg:91.83ms +step:614/1645 train_time:56386ms step_avg:91.83ms +step:615/1645 train_time:56479ms step_avg:91.84ms +step:616/1645 train_time:56572ms step_avg:91.84ms +step:617/1645 train_time:56665ms step_avg:91.84ms +step:618/1645 train_time:56760ms step_avg:91.84ms +step:619/1645 train_time:56851ms step_avg:91.84ms +step:620/1645 train_time:56944ms step_avg:91.84ms +step:621/1645 train_time:57036ms step_avg:91.85ms +step:622/1645 train_time:57129ms step_avg:91.85ms +step:623/1645 train_time:57222ms step_avg:91.85ms +step:624/1645 train_time:57316ms step_avg:91.85ms +step:625/1645 train_time:57408ms step_avg:91.85ms +step:625/1645 val_loss:3.6115 train_time:57501ms step_avg:92.00ms +step:626/1645 train_time:57522ms step_avg:91.89ms +step:627/1645 train_time:57599ms step_avg:91.86ms +step:628/1645 train_time:57702ms step_avg:91.88ms +step:629/1645 train_time:57796ms step_avg:91.89ms +step:630/1645 train_time:57888ms step_avg:91.89ms +step:631/1645 train_time:57979ms step_avg:91.88ms +step:632/1645 train_time:58071ms step_avg:91.88ms +step:633/1645 train_time:58163ms step_avg:91.88ms +step:634/1645 train_time:58254ms step_avg:91.88ms +step:635/1645 train_time:58346ms step_avg:91.88ms +step:636/1645 train_time:58441ms step_avg:91.89ms +step:637/1645 train_time:58536ms step_avg:91.89ms +step:638/1645 train_time:58632ms step_avg:91.90ms +step:639/1645 train_time:58726ms step_avg:91.90ms +step:640/1645 train_time:58820ms step_avg:91.91ms +step:641/1645 train_time:58913ms step_avg:91.91ms +step:642/1645 train_time:59005ms step_avg:91.91ms +step:643/1645 train_time:59097ms step_avg:91.91ms +step:644/1645 train_time:59189ms step_avg:91.91ms +step:645/1645 train_time:59280ms step_avg:91.91ms +step:646/1645 train_time:59372ms 
step_avg:91.91ms +step:647/1645 train_time:59466ms step_avg:91.91ms +step:648/1645 train_time:59558ms step_avg:91.91ms +step:649/1645 train_time:59653ms step_avg:91.91ms +step:650/1645 train_time:59747ms step_avg:91.92ms +step:651/1645 train_time:59840ms step_avg:91.92ms +step:652/1645 train_time:59934ms step_avg:91.92ms +step:653/1645 train_time:60026ms step_avg:91.92ms +step:654/1645 train_time:60119ms step_avg:91.93ms +step:655/1645 train_time:60212ms step_avg:91.93ms +step:656/1645 train_time:60304ms step_avg:91.93ms +step:657/1645 train_time:60397ms step_avg:91.93ms +step:658/1645 train_time:60490ms step_avg:91.93ms +step:659/1645 train_time:60584ms step_avg:91.93ms +step:660/1645 train_time:60677ms step_avg:91.93ms +step:661/1645 train_time:60770ms step_avg:91.94ms +step:662/1645 train_time:60864ms step_avg:91.94ms +step:663/1645 train_time:60957ms step_avg:91.94ms +step:664/1645 train_time:61050ms step_avg:91.94ms +step:665/1645 train_time:61142ms step_avg:91.94ms +step:666/1645 train_time:61236ms step_avg:91.95ms +step:667/1645 train_time:61328ms step_avg:91.95ms +step:668/1645 train_time:61420ms step_avg:91.95ms +step:669/1645 train_time:61513ms step_avg:91.95ms +step:670/1645 train_time:61606ms step_avg:91.95ms +step:671/1645 train_time:61699ms step_avg:91.95ms +step:672/1645 train_time:61793ms step_avg:91.95ms +step:673/1645 train_time:61887ms step_avg:91.96ms +step:674/1645 train_time:61979ms step_avg:91.96ms +step:675/1645 train_time:62071ms step_avg:91.96ms +step:676/1645 train_time:62164ms step_avg:91.96ms +step:677/1645 train_time:62257ms step_avg:91.96ms +step:678/1645 train_time:62349ms step_avg:91.96ms +step:679/1645 train_time:62442ms step_avg:91.96ms +step:680/1645 train_time:62534ms step_avg:91.96ms +step:681/1645 train_time:62627ms step_avg:91.96ms +step:682/1645 train_time:62721ms step_avg:91.97ms +step:683/1645 train_time:62814ms step_avg:91.97ms +step:684/1645 train_time:62907ms step_avg:91.97ms +step:685/1645 train_time:63001ms step_avg:91.97ms +step:686/1645 train_time:63094ms step_avg:91.97ms +step:687/1645 train_time:63186ms step_avg:91.97ms +step:688/1645 train_time:63279ms step_avg:91.98ms +step:689/1645 train_time:63371ms step_avg:91.98ms +step:690/1645 train_time:63464ms step_avg:91.98ms +step:691/1645 train_time:63556ms step_avg:91.98ms +step:692/1645 train_time:63649ms step_avg:91.98ms +step:693/1645 train_time:63742ms step_avg:91.98ms +step:694/1645 train_time:63835ms step_avg:91.98ms +step:695/1645 train_time:63929ms step_avg:91.98ms +step:696/1645 train_time:64023ms step_avg:91.99ms +step:697/1645 train_time:64116ms step_avg:91.99ms +step:698/1645 train_time:64209ms step_avg:91.99ms +step:699/1645 train_time:64302ms step_avg:91.99ms +step:700/1645 train_time:64395ms step_avg:91.99ms +step:701/1645 train_time:64488ms step_avg:91.99ms +step:702/1645 train_time:64580ms step_avg:91.99ms +step:703/1645 train_time:64672ms step_avg:91.99ms +step:704/1645 train_time:64765ms step_avg:92.00ms +step:705/1645 train_time:64857ms step_avg:92.00ms +step:706/1645 train_time:64951ms step_avg:92.00ms +step:707/1645 train_time:65044ms step_avg:92.00ms +step:708/1645 train_time:65136ms step_avg:92.00ms +step:709/1645 train_time:65229ms step_avg:92.00ms +step:710/1645 train_time:65322ms step_avg:92.00ms +step:711/1645 train_time:65415ms step_avg:92.00ms +step:712/1645 train_time:65509ms step_avg:92.01ms +step:713/1645 train_time:65602ms step_avg:92.01ms +step:714/1645 train_time:65695ms step_avg:92.01ms +step:715/1645 train_time:65788ms step_avg:92.01ms +step:716/1645 
train_time:65880ms step_avg:92.01ms +step:717/1645 train_time:65973ms step_avg:92.01ms +step:718/1645 train_time:66066ms step_avg:92.01ms +step:719/1645 train_time:66160ms step_avg:92.02ms +step:720/1645 train_time:66252ms step_avg:92.02ms +step:721/1645 train_time:66344ms step_avg:92.02ms +step:722/1645 train_time:66437ms step_avg:92.02ms +step:723/1645 train_time:66530ms step_avg:92.02ms +step:724/1645 train_time:66622ms step_avg:92.02ms +step:725/1645 train_time:66715ms step_avg:92.02ms +step:726/1645 train_time:66809ms step_avg:92.02ms +step:727/1645 train_time:66901ms step_avg:92.02ms +step:728/1645 train_time:66994ms step_avg:92.03ms +step:729/1645 train_time:67087ms step_avg:92.03ms +step:730/1645 train_time:67180ms step_avg:92.03ms +step:731/1645 train_time:67273ms step_avg:92.03ms +step:732/1645 train_time:67366ms step_avg:92.03ms +step:733/1645 train_time:67458ms step_avg:92.03ms +step:734/1645 train_time:67552ms step_avg:92.03ms +step:735/1645 train_time:67645ms step_avg:92.03ms +step:736/1645 train_time:67737ms step_avg:92.03ms +step:737/1645 train_time:67831ms step_avg:92.04ms +step:738/1645 train_time:67924ms step_avg:92.04ms +step:739/1645 train_time:68016ms step_avg:92.04ms +step:740/1645 train_time:68109ms step_avg:92.04ms +step:741/1645 train_time:68202ms step_avg:92.04ms +step:742/1645 train_time:68295ms step_avg:92.04ms +step:743/1645 train_time:68389ms step_avg:92.04ms +step:744/1645 train_time:68482ms step_avg:92.05ms +step:745/1645 train_time:68574ms step_avg:92.05ms +step:746/1645 train_time:68668ms step_avg:92.05ms +step:747/1645 train_time:68761ms step_avg:92.05ms +step:748/1645 train_time:68855ms step_avg:92.05ms +step:749/1645 train_time:68947ms step_avg:92.05ms +step:750/1645 train_time:69040ms step_avg:92.05ms +step:750/1645 val_loss:3.5603 train_time:69133ms step_avg:92.18ms +step:751/1645 train_time:69154ms step_avg:92.08ms +step:752/1645 train_time:69231ms step_avg:92.06ms +step:753/1645 train_time:69327ms step_avg:92.07ms +step:754/1645 train_time:69420ms step_avg:92.07ms +step:755/1645 train_time:69512ms step_avg:92.07ms +step:756/1645 train_time:69604ms step_avg:92.07ms +step:757/1645 train_time:69696ms step_avg:92.07ms +step:758/1645 train_time:69788ms step_avg:92.07ms +step:759/1645 train_time:69880ms step_avg:92.07ms +step:760/1645 train_time:69971ms step_avg:92.07ms +step:761/1645 train_time:70064ms step_avg:92.07ms +step:762/1645 train_time:70158ms step_avg:92.07ms +step:763/1645 train_time:70253ms step_avg:92.07ms +step:764/1645 train_time:70348ms step_avg:92.08ms +step:765/1645 train_time:70441ms step_avg:92.08ms +step:766/1645 train_time:70534ms step_avg:92.08ms +step:767/1645 train_time:70627ms step_avg:92.08ms +step:768/1645 train_time:70718ms step_avg:92.08ms +step:769/1645 train_time:70811ms step_avg:92.08ms +step:770/1645 train_time:70903ms step_avg:92.08ms +step:771/1645 train_time:70995ms step_avg:92.08ms +step:772/1645 train_time:71088ms step_avg:92.08ms +step:773/1645 train_time:71183ms step_avg:92.09ms +step:774/1645 train_time:71277ms step_avg:92.09ms +step:775/1645 train_time:71370ms step_avg:92.09ms +step:776/1645 train_time:71464ms step_avg:92.09ms +step:777/1645 train_time:71556ms step_avg:92.09ms +step:778/1645 train_time:71649ms step_avg:92.09ms +step:779/1645 train_time:71742ms step_avg:92.09ms +step:780/1645 train_time:71834ms step_avg:92.09ms +step:781/1645 train_time:71926ms step_avg:92.10ms +step:782/1645 train_time:72020ms step_avg:92.10ms +step:783/1645 train_time:72113ms step_avg:92.10ms +step:784/1645 train_time:72208ms 
step_avg:92.10ms +step:785/1645 train_time:72301ms step_avg:92.10ms +step:786/1645 train_time:72394ms step_avg:92.10ms +step:787/1645 train_time:72487ms step_avg:92.11ms +step:788/1645 train_time:72580ms step_avg:92.11ms +step:789/1645 train_time:72673ms step_avg:92.11ms +step:790/1645 train_time:72766ms step_avg:92.11ms +step:791/1645 train_time:72858ms step_avg:92.11ms +step:792/1645 train_time:72950ms step_avg:92.11ms +step:793/1645 train_time:73043ms step_avg:92.11ms +step:794/1645 train_time:73136ms step_avg:92.11ms +step:795/1645 train_time:73229ms step_avg:92.11ms +step:796/1645 train_time:73324ms step_avg:92.12ms +step:797/1645 train_time:73416ms step_avg:92.12ms +step:798/1645 train_time:73510ms step_avg:92.12ms +step:799/1645 train_time:73603ms step_avg:92.12ms +step:800/1645 train_time:73695ms step_avg:92.12ms +step:801/1645 train_time:73787ms step_avg:92.12ms +step:802/1645 train_time:73879ms step_avg:92.12ms +step:803/1645 train_time:73972ms step_avg:92.12ms +step:804/1645 train_time:74065ms step_avg:92.12ms +step:805/1645 train_time:74157ms step_avg:92.12ms +step:806/1645 train_time:74253ms step_avg:92.12ms +step:807/1645 train_time:74346ms step_avg:92.13ms +step:808/1645 train_time:74439ms step_avg:92.13ms +step:809/1645 train_time:74532ms step_avg:92.13ms +step:810/1645 train_time:74625ms step_avg:92.13ms +step:811/1645 train_time:74717ms step_avg:92.13ms +step:812/1645 train_time:74810ms step_avg:92.13ms +step:813/1645 train_time:74902ms step_avg:92.13ms +step:814/1645 train_time:74995ms step_avg:92.13ms +step:815/1645 train_time:75088ms step_avg:92.13ms +step:816/1645 train_time:75181ms step_avg:92.13ms +step:817/1645 train_time:75274ms step_avg:92.13ms +step:818/1645 train_time:75367ms step_avg:92.14ms +step:819/1645 train_time:75461ms step_avg:92.14ms +step:820/1645 train_time:75554ms step_avg:92.14ms +step:821/1645 train_time:75647ms step_avg:92.14ms +step:822/1645 train_time:75743ms step_avg:92.14ms +step:823/1645 train_time:75834ms step_avg:92.14ms +step:824/1645 train_time:75927ms step_avg:92.14ms +step:825/1645 train_time:76018ms step_avg:92.14ms +step:826/1645 train_time:76111ms step_avg:92.14ms +step:827/1645 train_time:76204ms step_avg:92.15ms +step:828/1645 train_time:76296ms step_avg:92.15ms +step:829/1645 train_time:76390ms step_avg:92.15ms +step:830/1645 train_time:76483ms step_avg:92.15ms +step:831/1645 train_time:76576ms step_avg:92.15ms +step:832/1645 train_time:76669ms step_avg:92.15ms +step:833/1645 train_time:76762ms step_avg:92.15ms +step:834/1645 train_time:76854ms step_avg:92.15ms +step:835/1645 train_time:76948ms step_avg:92.15ms +step:836/1645 train_time:77041ms step_avg:92.15ms +step:837/1645 train_time:77133ms step_avg:92.15ms +step:838/1645 train_time:77226ms step_avg:92.16ms +step:839/1645 train_time:77320ms step_avg:92.16ms +step:840/1645 train_time:77412ms step_avg:92.16ms +step:841/1645 train_time:77507ms step_avg:92.16ms +step:842/1645 train_time:77600ms step_avg:92.16ms +step:843/1645 train_time:77694ms step_avg:92.16ms +step:844/1645 train_time:77787ms step_avg:92.16ms +step:845/1645 train_time:77879ms step_avg:92.16ms +step:846/1645 train_time:77972ms step_avg:92.17ms +step:847/1645 train_time:78065ms step_avg:92.17ms +step:848/1645 train_time:78158ms step_avg:92.17ms +step:849/1645 train_time:78251ms step_avg:92.17ms +step:850/1645 train_time:78344ms step_avg:92.17ms +step:851/1645 train_time:78438ms step_avg:92.17ms +step:852/1645 train_time:78532ms step_avg:92.17ms +step:853/1645 train_time:78624ms step_avg:92.17ms +step:854/1645 
train_time:78717ms step_avg:92.17ms +step:855/1645 train_time:78810ms step_avg:92.18ms +step:856/1645 train_time:78903ms step_avg:92.18ms +step:857/1645 train_time:78996ms step_avg:92.18ms +step:858/1645 train_time:79089ms step_avg:92.18ms +step:859/1645 train_time:79182ms step_avg:92.18ms +step:860/1645 train_time:79275ms step_avg:92.18ms +step:861/1645 train_time:79368ms step_avg:92.18ms +step:862/1645 train_time:79461ms step_avg:92.18ms +step:863/1645 train_time:79553ms step_avg:92.18ms +step:864/1645 train_time:79646ms step_avg:92.18ms +step:865/1645 train_time:79740ms step_avg:92.18ms +step:866/1645 train_time:79833ms step_avg:92.19ms +step:867/1645 train_time:79926ms step_avg:92.19ms +step:868/1645 train_time:80019ms step_avg:92.19ms +step:869/1645 train_time:80112ms step_avg:92.19ms +step:870/1645 train_time:80206ms step_avg:92.19ms +step:871/1645 train_time:80298ms step_avg:92.19ms +step:872/1645 train_time:80391ms step_avg:92.19ms +step:873/1645 train_time:80484ms step_avg:92.19ms +step:874/1645 train_time:80576ms step_avg:92.19ms +step:875/1645 train_time:80669ms step_avg:92.19ms +step:875/1645 val_loss:3.5146 train_time:80762ms step_avg:92.30ms +step:876/1645 train_time:80783ms step_avg:92.22ms +step:877/1645 train_time:80861ms step_avg:92.20ms +step:878/1645 train_time:80956ms step_avg:92.20ms +step:879/1645 train_time:81048ms step_avg:92.20ms +step:880/1645 train_time:81140ms step_avg:92.20ms +step:881/1645 train_time:81231ms step_avg:92.20ms +step:882/1645 train_time:81323ms step_avg:92.20ms +step:883/1645 train_time:81414ms step_avg:92.20ms +step:884/1645 train_time:81506ms step_avg:92.20ms +step:885/1645 train_time:81598ms step_avg:92.20ms +step:886/1645 train_time:81692ms step_avg:92.20ms +step:887/1645 train_time:81788ms step_avg:92.21ms +step:888/1645 train_time:81883ms step_avg:92.21ms +step:889/1645 train_time:81977ms step_avg:92.21ms +step:890/1645 train_time:82070ms step_avg:92.21ms +step:891/1645 train_time:82162ms step_avg:92.21ms +step:892/1645 train_time:82255ms step_avg:92.21ms +step:893/1645 train_time:82347ms step_avg:92.21ms +step:894/1645 train_time:82439ms step_avg:92.21ms +step:895/1645 train_time:82531ms step_avg:92.21ms +step:896/1645 train_time:82624ms step_avg:92.21ms +step:897/1645 train_time:82717ms step_avg:92.22ms +step:898/1645 train_time:82812ms step_avg:92.22ms +step:899/1645 train_time:82907ms step_avg:92.22ms +step:900/1645 train_time:83001ms step_avg:92.22ms +step:901/1645 train_time:83093ms step_avg:92.22ms +step:902/1645 train_time:83187ms step_avg:92.22ms +step:903/1645 train_time:83279ms step_avg:92.23ms +step:904/1645 train_time:83372ms step_avg:92.23ms +step:905/1645 train_time:83464ms step_avg:92.23ms +step:906/1645 train_time:83556ms step_avg:92.23ms +step:907/1645 train_time:83650ms step_avg:92.23ms +step:908/1645 train_time:83742ms step_avg:92.23ms +step:909/1645 train_time:83836ms step_avg:92.23ms +step:910/1645 train_time:83930ms step_avg:92.23ms +step:911/1645 train_time:84024ms step_avg:92.23ms +step:912/1645 train_time:84116ms step_avg:92.23ms +step:913/1645 train_time:84209ms step_avg:92.23ms +step:914/1645 train_time:84301ms step_avg:92.23ms +step:915/1645 train_time:84394ms step_avg:92.23ms +step:916/1645 train_time:84486ms step_avg:92.23ms +step:917/1645 train_time:84579ms step_avg:92.23ms +step:918/1645 train_time:84672ms step_avg:92.24ms +step:919/1645 train_time:84765ms step_avg:92.24ms +step:920/1645 train_time:84858ms step_avg:92.24ms +step:921/1645 train_time:84951ms step_avg:92.24ms +step:922/1645 train_time:85045ms 
step_avg:92.24ms +step:923/1645 train_time:85138ms step_avg:92.24ms +step:924/1645 train_time:85230ms step_avg:92.24ms +step:925/1645 train_time:85324ms step_avg:92.24ms +step:926/1645 train_time:85416ms step_avg:92.24ms +step:927/1645 train_time:85509ms step_avg:92.24ms +step:928/1645 train_time:85601ms step_avg:92.24ms +step:929/1645 train_time:85693ms step_avg:92.24ms +step:930/1645 train_time:85787ms step_avg:92.24ms +step:931/1645 train_time:85880ms step_avg:92.25ms +step:932/1645 train_time:85974ms step_avg:92.25ms +step:933/1645 train_time:86067ms step_avg:92.25ms +step:934/1645 train_time:86160ms step_avg:92.25ms +step:935/1645 train_time:86252ms step_avg:92.25ms +step:936/1645 train_time:86345ms step_avg:92.25ms +step:937/1645 train_time:86438ms step_avg:92.25ms +step:938/1645 train_time:86531ms step_avg:92.25ms +step:939/1645 train_time:86623ms step_avg:92.25ms +step:940/1645 train_time:86716ms step_avg:92.25ms +step:941/1645 train_time:86809ms step_avg:92.25ms +step:942/1645 train_time:86902ms step_avg:92.25ms +step:943/1645 train_time:86994ms step_avg:92.25ms +step:944/1645 train_time:87089ms step_avg:92.26ms +step:945/1645 train_time:87182ms step_avg:92.26ms +step:946/1645 train_time:87275ms step_avg:92.26ms +step:947/1645 train_time:87368ms step_avg:92.26ms +step:948/1645 train_time:87460ms step_avg:92.26ms +step:949/1645 train_time:87552ms step_avg:92.26ms +step:950/1645 train_time:87645ms step_avg:92.26ms +step:951/1645 train_time:87738ms step_avg:92.26ms +step:952/1645 train_time:87832ms step_avg:92.26ms +step:953/1645 train_time:87923ms step_avg:92.26ms +step:954/1645 train_time:88016ms step_avg:92.26ms +step:955/1645 train_time:88109ms step_avg:92.26ms +step:956/1645 train_time:88202ms step_avg:92.26ms +step:957/1645 train_time:88295ms step_avg:92.26ms +step:958/1645 train_time:88389ms step_avg:92.26ms +step:959/1645 train_time:88483ms step_avg:92.27ms +step:960/1645 train_time:88575ms step_avg:92.27ms +step:961/1645 train_time:88669ms step_avg:92.27ms +step:962/1645 train_time:88762ms step_avg:92.27ms +step:963/1645 train_time:88855ms step_avg:92.27ms +step:964/1645 train_time:88947ms step_avg:92.27ms +step:965/1645 train_time:89041ms step_avg:92.27ms +step:966/1645 train_time:89134ms step_avg:92.27ms +step:967/1645 train_time:89226ms step_avg:92.27ms +step:968/1645 train_time:89320ms step_avg:92.27ms +step:969/1645 train_time:89413ms step_avg:92.27ms +step:970/1645 train_time:89506ms step_avg:92.27ms +step:971/1645 train_time:89599ms step_avg:92.27ms +step:972/1645 train_time:89692ms step_avg:92.28ms +step:973/1645 train_time:89786ms step_avg:92.28ms +step:974/1645 train_time:89878ms step_avg:92.28ms +step:975/1645 train_time:89971ms step_avg:92.28ms +step:976/1645 train_time:90065ms step_avg:92.28ms +step:977/1645 train_time:90157ms step_avg:92.28ms +step:978/1645 train_time:90251ms step_avg:92.28ms +step:979/1645 train_time:90344ms step_avg:92.28ms +step:980/1645 train_time:90437ms step_avg:92.28ms +step:981/1645 train_time:90531ms step_avg:92.28ms +step:982/1645 train_time:90623ms step_avg:92.28ms +step:983/1645 train_time:90716ms step_avg:92.29ms +step:984/1645 train_time:90810ms step_avg:92.29ms +step:985/1645 train_time:90903ms step_avg:92.29ms +step:986/1645 train_time:90996ms step_avg:92.29ms +step:987/1645 train_time:91089ms step_avg:92.29ms +step:988/1645 train_time:91182ms step_avg:92.29ms +step:989/1645 train_time:91275ms step_avg:92.29ms +step:990/1645 train_time:91368ms step_avg:92.29ms +step:991/1645 train_time:91462ms step_avg:92.29ms +step:992/1645 
train_time:91555ms step_avg:92.29ms +step:993/1645 train_time:91649ms step_avg:92.30ms +step:994/1645 train_time:91742ms step_avg:92.30ms +step:995/1645 train_time:91834ms step_avg:92.30ms +step:996/1645 train_time:91927ms step_avg:92.30ms +step:997/1645 train_time:92020ms step_avg:92.30ms +step:998/1645 train_time:92113ms step_avg:92.30ms +step:999/1645 train_time:92208ms step_avg:92.30ms +step:1000/1645 train_time:92301ms step_avg:92.30ms +step:1000/1645 val_loss:3.4655 train_time:92393ms step_avg:92.39ms +step:1001/1645 train_time:92415ms step_avg:92.32ms +step:1002/1645 train_time:92492ms step_avg:92.31ms +step:1003/1645 train_time:92587ms step_avg:92.31ms +step:1004/1645 train_time:92680ms step_avg:92.31ms +step:1005/1645 train_time:92772ms step_avg:92.31ms +step:1006/1645 train_time:92865ms step_avg:92.31ms +step:1007/1645 train_time:92957ms step_avg:92.31ms +step:1008/1645 train_time:93049ms step_avg:92.31ms +step:1009/1645 train_time:93142ms step_avg:92.31ms +step:1010/1645 train_time:93233ms step_avg:92.31ms +step:1011/1645 train_time:93328ms step_avg:92.31ms +step:1012/1645 train_time:93423ms step_avg:92.32ms +step:1013/1645 train_time:93519ms step_avg:92.32ms +step:1014/1645 train_time:93612ms step_avg:92.32ms +step:1015/1645 train_time:93706ms step_avg:92.32ms +step:1016/1645 train_time:93800ms step_avg:92.32ms +step:1017/1645 train_time:93891ms step_avg:92.32ms +step:1018/1645 train_time:93983ms step_avg:92.32ms +step:1019/1645 train_time:94075ms step_avg:92.32ms +step:1020/1645 train_time:94167ms step_avg:92.32ms +step:1021/1645 train_time:94260ms step_avg:92.32ms +step:1022/1645 train_time:94353ms step_avg:92.32ms +step:1023/1645 train_time:94448ms step_avg:92.32ms +step:1024/1645 train_time:94543ms step_avg:92.33ms +step:1025/1645 train_time:94636ms step_avg:92.33ms +step:1026/1645 train_time:94730ms step_avg:92.33ms +step:1027/1645 train_time:94822ms step_avg:92.33ms +step:1028/1645 train_time:94915ms step_avg:92.33ms +step:1029/1645 train_time:95008ms step_avg:92.33ms +step:1030/1645 train_time:95099ms step_avg:92.33ms +step:1031/1645 train_time:95192ms step_avg:92.33ms +step:1032/1645 train_time:95284ms step_avg:92.33ms +step:1033/1645 train_time:95377ms step_avg:92.33ms +step:1034/1645 train_time:95471ms step_avg:92.33ms +step:1035/1645 train_time:95564ms step_avg:92.33ms +step:1036/1645 train_time:95657ms step_avg:92.33ms +step:1037/1645 train_time:95751ms step_avg:92.33ms +step:1038/1645 train_time:95844ms step_avg:92.34ms +step:1039/1645 train_time:95937ms step_avg:92.34ms +step:1040/1645 train_time:96029ms step_avg:92.34ms +step:1041/1645 train_time:96122ms step_avg:92.34ms +step:1042/1645 train_time:96215ms step_avg:92.34ms +step:1043/1645 train_time:96308ms step_avg:92.34ms +step:1044/1645 train_time:96402ms step_avg:92.34ms +step:1045/1645 train_time:96495ms step_avg:92.34ms +step:1046/1645 train_time:96589ms step_avg:92.34ms +step:1047/1645 train_time:96682ms step_avg:92.34ms +step:1048/1645 train_time:96775ms step_avg:92.34ms +step:1049/1645 train_time:96869ms step_avg:92.34ms +step:1050/1645 train_time:96962ms step_avg:92.34ms +step:1051/1645 train_time:97054ms step_avg:92.34ms +step:1052/1645 train_time:97148ms step_avg:92.35ms +step:1053/1645 train_time:97241ms step_avg:92.35ms +step:1054/1645 train_time:97334ms step_avg:92.35ms +step:1055/1645 train_time:97428ms step_avg:92.35ms +step:1056/1645 train_time:97521ms step_avg:92.35ms +step:1057/1645 train_time:97616ms step_avg:92.35ms +step:1058/1645 train_time:97709ms step_avg:92.35ms +step:1059/1645 
train_time:97802ms step_avg:92.35ms +step:1060/1645 train_time:97895ms step_avg:92.35ms +step:1061/1645 train_time:97989ms step_avg:92.35ms +step:1062/1645 train_time:98082ms step_avg:92.36ms +step:1063/1645 train_time:98175ms step_avg:92.36ms +step:1064/1645 train_time:98267ms step_avg:92.36ms +step:1065/1645 train_time:98361ms step_avg:92.36ms +step:1066/1645 train_time:98454ms step_avg:92.36ms +step:1067/1645 train_time:98547ms step_avg:92.36ms +step:1068/1645 train_time:98640ms step_avg:92.36ms +step:1069/1645 train_time:98733ms step_avg:92.36ms +step:1070/1645 train_time:98827ms step_avg:92.36ms +step:1071/1645 train_time:98920ms step_avg:92.36ms +step:1072/1645 train_time:99013ms step_avg:92.36ms +step:1073/1645 train_time:99106ms step_avg:92.36ms +step:1074/1645 train_time:99198ms step_avg:92.36ms +step:1075/1645 train_time:99291ms step_avg:92.36ms +step:1076/1645 train_time:99384ms step_avg:92.36ms +step:1077/1645 train_time:99477ms step_avg:92.36ms +step:1078/1645 train_time:99570ms step_avg:92.37ms +step:1079/1645 train_time:99663ms step_avg:92.37ms +step:1080/1645 train_time:99756ms step_avg:92.37ms +step:1081/1645 train_time:99849ms step_avg:92.37ms +step:1082/1645 train_time:99942ms step_avg:92.37ms +step:1083/1645 train_time:100035ms step_avg:92.37ms +step:1084/1645 train_time:100128ms step_avg:92.37ms +step:1085/1645 train_time:100220ms step_avg:92.37ms +step:1086/1645 train_time:100314ms step_avg:92.37ms +step:1087/1645 train_time:100407ms step_avg:92.37ms +step:1088/1645 train_time:100500ms step_avg:92.37ms +step:1089/1645 train_time:100593ms step_avg:92.37ms +step:1090/1645 train_time:100687ms step_avg:92.37ms +step:1091/1645 train_time:100779ms step_avg:92.37ms +step:1092/1645 train_time:100873ms step_avg:92.37ms +step:1093/1645 train_time:100966ms step_avg:92.38ms +step:1094/1645 train_time:101059ms step_avg:92.38ms +step:1095/1645 train_time:101151ms step_avg:92.38ms +step:1096/1645 train_time:101244ms step_avg:92.38ms +step:1097/1645 train_time:101336ms step_avg:92.38ms +step:1098/1645 train_time:101430ms step_avg:92.38ms +step:1099/1645 train_time:101524ms step_avg:92.38ms +step:1100/1645 train_time:101617ms step_avg:92.38ms +step:1101/1645 train_time:101710ms step_avg:92.38ms +step:1102/1645 train_time:101804ms step_avg:92.38ms +step:1103/1645 train_time:101897ms step_avg:92.38ms +step:1104/1645 train_time:101991ms step_avg:92.38ms +step:1105/1645 train_time:102085ms step_avg:92.38ms +step:1106/1645 train_time:102178ms step_avg:92.39ms +step:1107/1645 train_time:102271ms step_avg:92.39ms +step:1108/1645 train_time:102364ms step_avg:92.39ms +step:1109/1645 train_time:102458ms step_avg:92.39ms +step:1110/1645 train_time:102552ms step_avg:92.39ms +step:1111/1645 train_time:102646ms step_avg:92.39ms +step:1112/1645 train_time:102741ms step_avg:92.39ms +step:1113/1645 train_time:102834ms step_avg:92.39ms +step:1114/1645 train_time:102930ms step_avg:92.40ms +step:1115/1645 train_time:103023ms step_avg:92.40ms +step:1116/1645 train_time:103117ms step_avg:92.40ms +step:1117/1645 train_time:103210ms step_avg:92.40ms +step:1118/1645 train_time:103304ms step_avg:92.40ms +step:1119/1645 train_time:103398ms step_avg:92.40ms +step:1120/1645 train_time:103492ms step_avg:92.40ms +step:1121/1645 train_time:103586ms step_avg:92.41ms +step:1122/1645 train_time:103680ms step_avg:92.41ms +step:1123/1645 train_time:103774ms step_avg:92.41ms +step:1124/1645 train_time:103868ms step_avg:92.41ms +step:1125/1645 train_time:103961ms step_avg:92.41ms +step:1125/1645 val_loss:3.4120 
train_time:104055ms step_avg:92.49ms +step:1126/1645 train_time:104078ms step_avg:92.43ms +step:1127/1645 train_time:104159ms step_avg:92.42ms +step:1128/1645 train_time:104260ms step_avg:92.43ms +step:1129/1645 train_time:104353ms step_avg:92.43ms +step:1130/1645 train_time:104445ms step_avg:92.43ms +step:1131/1645 train_time:104538ms step_avg:92.43ms +step:1132/1645 train_time:104630ms step_avg:92.43ms +step:1133/1645 train_time:104723ms step_avg:92.43ms +step:1134/1645 train_time:104816ms step_avg:92.43ms +step:1135/1645 train_time:104908ms step_avg:92.43ms +step:1136/1645 train_time:105001ms step_avg:92.43ms +step:1137/1645 train_time:105096ms step_avg:92.43ms +step:1138/1645 train_time:105192ms step_avg:92.44ms +step:1139/1645 train_time:105287ms step_avg:92.44ms +step:1140/1645 train_time:105381ms step_avg:92.44ms +step:1141/1645 train_time:105474ms step_avg:92.44ms +step:1142/1645 train_time:105566ms step_avg:92.44ms +step:1143/1645 train_time:105659ms step_avg:92.44ms +step:1144/1645 train_time:105752ms step_avg:92.44ms +step:1145/1645 train_time:105844ms step_avg:92.44ms +step:1146/1645 train_time:105937ms step_avg:92.44ms +step:1147/1645 train_time:106031ms step_avg:92.44ms +step:1148/1645 train_time:106126ms step_avg:92.44ms +step:1149/1645 train_time:106221ms step_avg:92.45ms +step:1150/1645 train_time:106316ms step_avg:92.45ms +step:1151/1645 train_time:106410ms step_avg:92.45ms +step:1152/1645 train_time:106505ms step_avg:92.45ms +step:1153/1645 train_time:106597ms step_avg:92.45ms +step:1154/1645 train_time:106691ms step_avg:92.45ms +step:1155/1645 train_time:106784ms step_avg:92.45ms +step:1156/1645 train_time:106876ms step_avg:92.45ms +step:1157/1645 train_time:106969ms step_avg:92.45ms +step:1158/1645 train_time:107064ms step_avg:92.46ms +step:1159/1645 train_time:107158ms step_avg:92.46ms +step:1160/1645 train_time:107252ms step_avg:92.46ms +step:1161/1645 train_time:107346ms step_avg:92.46ms +step:1162/1645 train_time:107441ms step_avg:92.46ms +step:1163/1645 train_time:107534ms step_avg:92.46ms +step:1164/1645 train_time:107627ms step_avg:92.46ms +step:1165/1645 train_time:107720ms step_avg:92.46ms +step:1166/1645 train_time:107814ms step_avg:92.46ms +step:1167/1645 train_time:107906ms step_avg:92.46ms +step:1168/1645 train_time:107999ms step_avg:92.47ms +step:1169/1645 train_time:108093ms step_avg:92.47ms +step:1170/1645 train_time:108187ms step_avg:92.47ms +step:1171/1645 train_time:108283ms step_avg:92.47ms +step:1172/1645 train_time:108378ms step_avg:92.47ms +step:1173/1645 train_time:108472ms step_avg:92.47ms +step:1174/1645 train_time:108565ms step_avg:92.47ms +step:1175/1645 train_time:108658ms step_avg:92.48ms +step:1176/1645 train_time:108751ms step_avg:92.48ms +step:1177/1645 train_time:108844ms step_avg:92.48ms +step:1178/1645 train_time:108938ms step_avg:92.48ms +step:1179/1645 train_time:109031ms step_avg:92.48ms +step:1180/1645 train_time:109125ms step_avg:92.48ms +step:1181/1645 train_time:109219ms step_avg:92.48ms +step:1182/1645 train_time:109312ms step_avg:92.48ms +step:1183/1645 train_time:109406ms step_avg:92.48ms +step:1184/1645 train_time:109500ms step_avg:92.48ms +step:1185/1645 train_time:109593ms step_avg:92.48ms +step:1186/1645 train_time:109686ms step_avg:92.48ms +step:1187/1645 train_time:109780ms step_avg:92.48ms +step:1188/1645 train_time:109874ms step_avg:92.49ms +step:1189/1645 train_time:109967ms step_avg:92.49ms +step:1190/1645 train_time:110061ms step_avg:92.49ms +step:1191/1645 train_time:110155ms step_avg:92.49ms +step:1192/1645 
train_time:110249ms step_avg:92.49ms +step:1193/1645 train_time:110343ms step_avg:92.49ms +step:1194/1645 train_time:110437ms step_avg:92.49ms +step:1195/1645 train_time:110532ms step_avg:92.50ms +step:1196/1645 train_time:110626ms step_avg:92.50ms +step:1197/1645 train_time:110718ms step_avg:92.50ms +step:1198/1645 train_time:110811ms step_avg:92.50ms +step:1199/1645 train_time:110905ms step_avg:92.50ms +step:1200/1645 train_time:110999ms step_avg:92.50ms +step:1201/1645 train_time:111092ms step_avg:92.50ms +step:1202/1645 train_time:111185ms step_avg:92.50ms +step:1203/1645 train_time:111280ms step_avg:92.50ms +step:1204/1645 train_time:111375ms step_avg:92.50ms +step:1205/1645 train_time:111468ms step_avg:92.50ms +step:1206/1645 train_time:111561ms step_avg:92.51ms +step:1207/1645 train_time:111654ms step_avg:92.51ms +step:1208/1645 train_time:111748ms step_avg:92.51ms +step:1209/1645 train_time:111841ms step_avg:92.51ms +step:1210/1645 train_time:111935ms step_avg:92.51ms +step:1211/1645 train_time:112029ms step_avg:92.51ms +step:1212/1645 train_time:112122ms step_avg:92.51ms +step:1213/1645 train_time:112217ms step_avg:92.51ms +step:1214/1645 train_time:112310ms step_avg:92.51ms +step:1215/1645 train_time:112404ms step_avg:92.51ms +step:1216/1645 train_time:112498ms step_avg:92.52ms +step:1217/1645 train_time:112592ms step_avg:92.52ms +step:1218/1645 train_time:112685ms step_avg:92.52ms +step:1219/1645 train_time:112779ms step_avg:92.52ms +step:1220/1645 train_time:112872ms step_avg:92.52ms +step:1221/1645 train_time:112965ms step_avg:92.52ms +step:1222/1645 train_time:113059ms step_avg:92.52ms +step:1223/1645 train_time:113152ms step_avg:92.52ms +step:1224/1645 train_time:113245ms step_avg:92.52ms +step:1225/1645 train_time:113339ms step_avg:92.52ms +step:1226/1645 train_time:113433ms step_avg:92.52ms +step:1227/1645 train_time:113527ms step_avg:92.52ms +step:1228/1645 train_time:113621ms step_avg:92.52ms +step:1229/1645 train_time:113714ms step_avg:92.53ms +step:1230/1645 train_time:113807ms step_avg:92.53ms +step:1231/1645 train_time:113901ms step_avg:92.53ms +step:1232/1645 train_time:113995ms step_avg:92.53ms +step:1233/1645 train_time:114088ms step_avg:92.53ms +step:1234/1645 train_time:114182ms step_avg:92.53ms +step:1235/1645 train_time:114276ms step_avg:92.53ms +step:1236/1645 train_time:114370ms step_avg:92.53ms +step:1237/1645 train_time:114464ms step_avg:92.53ms +step:1238/1645 train_time:114558ms step_avg:92.53ms +step:1239/1645 train_time:114651ms step_avg:92.54ms +step:1240/1645 train_time:114745ms step_avg:92.54ms +step:1241/1645 train_time:114839ms step_avg:92.54ms +step:1242/1645 train_time:114933ms step_avg:92.54ms +step:1243/1645 train_time:115026ms step_avg:92.54ms +step:1244/1645 train_time:115119ms step_avg:92.54ms +step:1245/1645 train_time:115213ms step_avg:92.54ms +step:1246/1645 train_time:115307ms step_avg:92.54ms +step:1247/1645 train_time:115402ms step_avg:92.54ms +step:1248/1645 train_time:115495ms step_avg:92.54ms +step:1249/1645 train_time:115588ms step_avg:92.54ms +step:1250/1645 train_time:115682ms step_avg:92.55ms +step:1250/1645 val_loss:3.3738 train_time:115775ms step_avg:92.62ms +step:1251/1645 train_time:115796ms step_avg:92.56ms +step:1252/1645 train_time:115874ms step_avg:92.55ms +step:1253/1645 train_time:115969ms step_avg:92.55ms +step:1254/1645 train_time:116063ms step_avg:92.55ms +step:1255/1645 train_time:116156ms step_avg:92.55ms +step:1256/1645 train_time:116248ms step_avg:92.55ms +step:1257/1645 train_time:116340ms step_avg:92.55ms 
+step:1258/1645 train_time:116433ms step_avg:92.55ms +step:1259/1645 train_time:116526ms step_avg:92.55ms +step:1260/1645 train_time:116619ms step_avg:92.55ms +step:1261/1645 train_time:116713ms step_avg:92.56ms +step:1262/1645 train_time:116809ms step_avg:92.56ms +step:1263/1645 train_time:116903ms step_avg:92.56ms +step:1264/1645 train_time:116997ms step_avg:92.56ms +step:1265/1645 train_time:117092ms step_avg:92.56ms +step:1266/1645 train_time:117184ms step_avg:92.56ms +step:1267/1645 train_time:117277ms step_avg:92.56ms +step:1268/1645 train_time:117371ms step_avg:92.56ms +step:1269/1645 train_time:117463ms step_avg:92.56ms +step:1270/1645 train_time:117556ms step_avg:92.56ms +step:1271/1645 train_time:117649ms step_avg:92.56ms +step:1272/1645 train_time:117744ms step_avg:92.57ms +step:1273/1645 train_time:117838ms step_avg:92.57ms +step:1274/1645 train_time:117932ms step_avg:92.57ms +step:1275/1645 train_time:118026ms step_avg:92.57ms +step:1276/1645 train_time:118120ms step_avg:92.57ms +step:1277/1645 train_time:118214ms step_avg:92.57ms +step:1278/1645 train_time:118307ms step_avg:92.57ms +step:1279/1645 train_time:118401ms step_avg:92.57ms +step:1280/1645 train_time:118494ms step_avg:92.57ms +step:1281/1645 train_time:118587ms step_avg:92.57ms +step:1282/1645 train_time:118680ms step_avg:92.57ms +step:1283/1645 train_time:118774ms step_avg:92.58ms +step:1284/1645 train_time:118868ms step_avg:92.58ms +step:1285/1645 train_time:118962ms step_avg:92.58ms +step:1286/1645 train_time:119056ms step_avg:92.58ms +step:1287/1645 train_time:119149ms step_avg:92.58ms +step:1288/1645 train_time:119242ms step_avg:92.58ms +step:1289/1645 train_time:119337ms step_avg:92.58ms +step:1290/1645 train_time:119430ms step_avg:92.58ms +step:1291/1645 train_time:119523ms step_avg:92.58ms +step:1292/1645 train_time:119618ms step_avg:92.58ms +step:1293/1645 train_time:119711ms step_avg:92.58ms +step:1294/1645 train_time:119805ms step_avg:92.58ms +step:1295/1645 train_time:119899ms step_avg:92.59ms +step:1296/1645 train_time:119993ms step_avg:92.59ms +step:1297/1645 train_time:120088ms step_avg:92.59ms +step:1298/1645 train_time:120182ms step_avg:92.59ms +step:1299/1645 train_time:120275ms step_avg:92.59ms +step:1300/1645 train_time:120369ms step_avg:92.59ms +step:1301/1645 train_time:120463ms step_avg:92.59ms +step:1302/1645 train_time:120557ms step_avg:92.59ms +step:1303/1645 train_time:120650ms step_avg:92.59ms +step:1304/1645 train_time:120743ms step_avg:92.59ms +step:1305/1645 train_time:120837ms step_avg:92.60ms +step:1306/1645 train_time:120930ms step_avg:92.60ms +step:1307/1645 train_time:121023ms step_avg:92.60ms +step:1308/1645 train_time:121118ms step_avg:92.60ms +step:1309/1645 train_time:121211ms step_avg:92.60ms +step:1310/1645 train_time:121304ms step_avg:92.60ms +step:1311/1645 train_time:121398ms step_avg:92.60ms +step:1312/1645 train_time:121492ms step_avg:92.60ms +step:1313/1645 train_time:121586ms step_avg:92.60ms +step:1314/1645 train_time:121679ms step_avg:92.60ms +step:1315/1645 train_time:121772ms step_avg:92.60ms +step:1316/1645 train_time:121866ms step_avg:92.60ms +step:1317/1645 train_time:121959ms step_avg:92.60ms +step:1318/1645 train_time:122053ms step_avg:92.61ms +step:1319/1645 train_time:122147ms step_avg:92.61ms +step:1320/1645 train_time:122241ms step_avg:92.61ms +step:1321/1645 train_time:122336ms step_avg:92.61ms +step:1322/1645 train_time:122429ms step_avg:92.61ms +step:1323/1645 train_time:122523ms step_avg:92.61ms +step:1324/1645 train_time:122617ms step_avg:92.61ms 
+step:1325/1645 train_time:122711ms step_avg:92.61ms +step:1326/1645 train_time:122804ms step_avg:92.61ms +step:1327/1645 train_time:122898ms step_avg:92.61ms +step:1328/1645 train_time:122991ms step_avg:92.61ms +step:1329/1645 train_time:123085ms step_avg:92.61ms +step:1330/1645 train_time:123179ms step_avg:92.62ms +step:1331/1645 train_time:123272ms step_avg:92.62ms +step:1332/1645 train_time:123366ms step_avg:92.62ms +step:1333/1645 train_time:123459ms step_avg:92.62ms +step:1334/1645 train_time:123552ms step_avg:92.62ms +step:1335/1645 train_time:123646ms step_avg:92.62ms +step:1336/1645 train_time:123740ms step_avg:92.62ms +step:1337/1645 train_time:123833ms step_avg:92.62ms +step:1338/1645 train_time:123927ms step_avg:92.62ms +step:1339/1645 train_time:124022ms step_avg:92.62ms +step:1340/1645 train_time:124115ms step_avg:92.62ms +step:1341/1645 train_time:124210ms step_avg:92.62ms +step:1342/1645 train_time:124303ms step_avg:92.63ms +step:1343/1645 train_time:124398ms step_avg:92.63ms +step:1344/1645 train_time:124492ms step_avg:92.63ms +step:1345/1645 train_time:124584ms step_avg:92.63ms +step:1346/1645 train_time:124678ms step_avg:92.63ms +step:1347/1645 train_time:124771ms step_avg:92.63ms +step:1348/1645 train_time:124864ms step_avg:92.63ms +step:1349/1645 train_time:124958ms step_avg:92.63ms +step:1350/1645 train_time:125051ms step_avg:92.63ms +step:1351/1645 train_time:125145ms step_avg:92.63ms +step:1352/1645 train_time:125238ms step_avg:92.63ms +step:1353/1645 train_time:125333ms step_avg:92.63ms +step:1354/1645 train_time:125428ms step_avg:92.63ms +step:1355/1645 train_time:125520ms step_avg:92.64ms +step:1356/1645 train_time:125614ms step_avg:92.64ms +step:1357/1645 train_time:125707ms step_avg:92.64ms +step:1358/1645 train_time:125801ms step_avg:92.64ms +step:1359/1645 train_time:125894ms step_avg:92.64ms +step:1360/1645 train_time:125987ms step_avg:92.64ms +step:1361/1645 train_time:126081ms step_avg:92.64ms +step:1362/1645 train_time:126175ms step_avg:92.64ms +step:1363/1645 train_time:126268ms step_avg:92.64ms +step:1364/1645 train_time:126362ms step_avg:92.64ms +step:1365/1645 train_time:126456ms step_avg:92.64ms +step:1366/1645 train_time:126549ms step_avg:92.64ms +step:1367/1645 train_time:126643ms step_avg:92.64ms +step:1368/1645 train_time:126736ms step_avg:92.64ms +step:1369/1645 train_time:126830ms step_avg:92.64ms +step:1370/1645 train_time:126923ms step_avg:92.64ms +step:1371/1645 train_time:127018ms step_avg:92.65ms +step:1372/1645 train_time:127112ms step_avg:92.65ms +step:1373/1645 train_time:127206ms step_avg:92.65ms +step:1374/1645 train_time:127299ms step_avg:92.65ms +step:1375/1645 train_time:127393ms step_avg:92.65ms +step:1375/1645 val_loss:3.3395 train_time:127487ms step_avg:92.72ms +step:1376/1645 train_time:127513ms step_avg:92.67ms +step:1377/1645 train_time:127585ms step_avg:92.65ms +step:1378/1645 train_time:127682ms step_avg:92.66ms +step:1379/1645 train_time:127775ms step_avg:92.66ms +step:1380/1645 train_time:127869ms step_avg:92.66ms +step:1381/1645 train_time:127961ms step_avg:92.66ms +step:1382/1645 train_time:128053ms step_avg:92.66ms +step:1383/1645 train_time:128147ms step_avg:92.66ms +step:1384/1645 train_time:128240ms step_avg:92.66ms +step:1385/1645 train_time:128333ms step_avg:92.66ms +step:1386/1645 train_time:128427ms step_avg:92.66ms +step:1387/1645 train_time:128523ms step_avg:92.66ms +step:1388/1645 train_time:128618ms step_avg:92.66ms +step:1389/1645 train_time:128713ms step_avg:92.67ms +step:1390/1645 train_time:128806ms 
step_avg:92.67ms +step:1391/1645 train_time:128900ms step_avg:92.67ms +step:1392/1645 train_time:128993ms step_avg:92.67ms +step:1393/1645 train_time:129086ms step_avg:92.67ms +step:1394/1645 train_time:129179ms step_avg:92.67ms +step:1395/1645 train_time:129272ms step_avg:92.67ms +step:1396/1645 train_time:129366ms step_avg:92.67ms +step:1397/1645 train_time:129460ms step_avg:92.67ms +step:1398/1645 train_time:129554ms step_avg:92.67ms +step:1399/1645 train_time:129649ms step_avg:92.67ms +step:1400/1645 train_time:129742ms step_avg:92.67ms +step:1401/1645 train_time:129836ms step_avg:92.67ms +step:1402/1645 train_time:129929ms step_avg:92.67ms +step:1403/1645 train_time:130023ms step_avg:92.67ms +step:1404/1645 train_time:130116ms step_avg:92.67ms +step:1405/1645 train_time:130209ms step_avg:92.68ms +step:1406/1645 train_time:130302ms step_avg:92.68ms +step:1407/1645 train_time:130395ms step_avg:92.68ms +step:1408/1645 train_time:130489ms step_avg:92.68ms +step:1409/1645 train_time:130585ms step_avg:92.68ms +step:1410/1645 train_time:130678ms step_avg:92.68ms +step:1411/1645 train_time:130772ms step_avg:92.68ms +step:1412/1645 train_time:130865ms step_avg:92.68ms +step:1413/1645 train_time:130959ms step_avg:92.68ms +step:1414/1645 train_time:131052ms step_avg:92.68ms +step:1415/1645 train_time:131146ms step_avg:92.68ms +step:1416/1645 train_time:131240ms step_avg:92.68ms +step:1417/1645 train_time:131333ms step_avg:92.68ms +step:1418/1645 train_time:131427ms step_avg:92.69ms +step:1419/1645 train_time:131521ms step_avg:92.69ms +step:1420/1645 train_time:131614ms step_avg:92.69ms +step:1421/1645 train_time:131709ms step_avg:92.69ms +step:1422/1645 train_time:131802ms step_avg:92.69ms +step:1423/1645 train_time:131896ms step_avg:92.69ms +step:1424/1645 train_time:131989ms step_avg:92.69ms +step:1425/1645 train_time:132082ms step_avg:92.69ms +step:1426/1645 train_time:132176ms step_avg:92.69ms +step:1427/1645 train_time:132270ms step_avg:92.69ms +step:1428/1645 train_time:132364ms step_avg:92.69ms +step:1429/1645 train_time:132457ms step_avg:92.69ms +step:1430/1645 train_time:132551ms step_avg:92.69ms +step:1431/1645 train_time:132645ms step_avg:92.69ms +step:1432/1645 train_time:132739ms step_avg:92.69ms +step:1433/1645 train_time:132833ms step_avg:92.70ms +step:1434/1645 train_time:132927ms step_avg:92.70ms +step:1435/1645 train_time:133020ms step_avg:92.70ms +step:1436/1645 train_time:133113ms step_avg:92.70ms +step:1437/1645 train_time:133207ms step_avg:92.70ms +step:1438/1645 train_time:133300ms step_avg:92.70ms +step:1439/1645 train_time:133393ms step_avg:92.70ms +step:1440/1645 train_time:133487ms step_avg:92.70ms +step:1441/1645 train_time:133582ms step_avg:92.70ms +step:1442/1645 train_time:133675ms step_avg:92.70ms +step:1443/1645 train_time:133769ms step_avg:92.70ms +step:1444/1645 train_time:133862ms step_avg:92.70ms +step:1445/1645 train_time:133956ms step_avg:92.70ms +step:1446/1645 train_time:134049ms step_avg:92.70ms +step:1447/1645 train_time:134142ms step_avg:92.70ms +step:1448/1645 train_time:134236ms step_avg:92.70ms +step:1449/1645 train_time:134330ms step_avg:92.71ms +step:1450/1645 train_time:134424ms step_avg:92.71ms +step:1451/1645 train_time:134518ms step_avg:92.71ms +step:1452/1645 train_time:134612ms step_avg:92.71ms +step:1453/1645 train_time:134707ms step_avg:92.71ms +step:1454/1645 train_time:134801ms step_avg:92.71ms +step:1455/1645 train_time:134894ms step_avg:92.71ms +step:1456/1645 train_time:134987ms step_avg:92.71ms +step:1457/1645 train_time:135080ms 
step_avg:92.71ms +step:1458/1645 train_time:135174ms step_avg:92.71ms +step:1459/1645 train_time:135268ms step_avg:92.71ms +step:1460/1645 train_time:135361ms step_avg:92.71ms +step:1461/1645 train_time:135455ms step_avg:92.71ms +step:1462/1645 train_time:135549ms step_avg:92.71ms +step:1463/1645 train_time:135643ms step_avg:92.72ms +step:1464/1645 train_time:135737ms step_avg:92.72ms +step:1465/1645 train_time:135830ms step_avg:92.72ms +step:1466/1645 train_time:135924ms step_avg:92.72ms +step:1467/1645 train_time:136018ms step_avg:92.72ms +step:1468/1645 train_time:136111ms step_avg:92.72ms +step:1469/1645 train_time:136206ms step_avg:92.72ms +step:1470/1645 train_time:136299ms step_avg:92.72ms +step:1471/1645 train_time:136392ms step_avg:92.72ms +step:1472/1645 train_time:136486ms step_avg:92.72ms +step:1473/1645 train_time:136581ms step_avg:92.72ms +step:1474/1645 train_time:136674ms step_avg:92.72ms +step:1475/1645 train_time:136768ms step_avg:92.72ms +step:1476/1645 train_time:136861ms step_avg:92.72ms +step:1477/1645 train_time:136955ms step_avg:92.72ms +step:1478/1645 train_time:137049ms step_avg:92.73ms +step:1479/1645 train_time:137143ms step_avg:92.73ms +step:1480/1645 train_time:137237ms step_avg:92.73ms +step:1481/1645 train_time:137331ms step_avg:92.73ms +step:1482/1645 train_time:137424ms step_avg:92.73ms +step:1483/1645 train_time:137517ms step_avg:92.73ms +step:1484/1645 train_time:137611ms step_avg:92.73ms +step:1485/1645 train_time:137707ms step_avg:92.73ms +step:1486/1645 train_time:137800ms step_avg:92.73ms +step:1487/1645 train_time:137895ms step_avg:92.73ms +step:1488/1645 train_time:137988ms step_avg:92.73ms +step:1489/1645 train_time:138082ms step_avg:92.73ms +step:1490/1645 train_time:138175ms step_avg:92.73ms +step:1491/1645 train_time:138269ms step_avg:92.74ms +step:1492/1645 train_time:138362ms step_avg:92.74ms +step:1493/1645 train_time:138455ms step_avg:92.74ms +step:1494/1645 train_time:138548ms step_avg:92.74ms +step:1495/1645 train_time:138642ms step_avg:92.74ms +step:1496/1645 train_time:138736ms step_avg:92.74ms +step:1497/1645 train_time:138830ms step_avg:92.74ms +step:1498/1645 train_time:138925ms step_avg:92.74ms +step:1499/1645 train_time:139019ms step_avg:92.74ms +step:1500/1645 train_time:139112ms step_avg:92.74ms +step:1500/1645 val_loss:3.3097 train_time:139206ms step_avg:92.80ms +step:1501/1645 train_time:139232ms step_avg:92.76ms +step:1502/1645 train_time:139304ms step_avg:92.75ms +step:1503/1645 train_time:139400ms step_avg:92.75ms +step:1504/1645 train_time:139495ms step_avg:92.75ms +step:1505/1645 train_time:139588ms step_avg:92.75ms +step:1506/1645 train_time:139680ms step_avg:92.75ms +step:1507/1645 train_time:139773ms step_avg:92.75ms +step:1508/1645 train_time:139866ms step_avg:92.75ms +step:1509/1645 train_time:139958ms step_avg:92.75ms +step:1510/1645 train_time:140052ms step_avg:92.75ms +step:1511/1645 train_time:140146ms step_avg:92.75ms +step:1512/1645 train_time:140241ms step_avg:92.75ms +step:1513/1645 train_time:140337ms step_avg:92.75ms +step:1514/1645 train_time:140432ms step_avg:92.76ms +step:1515/1645 train_time:140525ms step_avg:92.76ms +step:1516/1645 train_time:140618ms step_avg:92.76ms +step:1517/1645 train_time:140711ms step_avg:92.76ms +step:1518/1645 train_time:140803ms step_avg:92.76ms +step:1519/1645 train_time:140897ms step_avg:92.76ms +step:1520/1645 train_time:140990ms step_avg:92.76ms +step:1521/1645 train_time:141083ms step_avg:92.76ms +step:1522/1645 train_time:141179ms step_avg:92.76ms +step:1523/1645 
train_time:141273ms step_avg:92.76ms +step:1524/1645 train_time:141367ms step_avg:92.76ms +step:1525/1645 train_time:141462ms step_avg:92.76ms +step:1526/1645 train_time:141555ms step_avg:92.76ms +step:1527/1645 train_time:141648ms step_avg:92.76ms +step:1528/1645 train_time:141741ms step_avg:92.76ms +step:1529/1645 train_time:141834ms step_avg:92.76ms +step:1530/1645 train_time:141927ms step_avg:92.76ms +step:1531/1645 train_time:142019ms step_avg:92.76ms +step:1532/1645 train_time:142113ms step_avg:92.76ms +step:1533/1645 train_time:142206ms step_avg:92.76ms +step:1534/1645 train_time:142302ms step_avg:92.77ms +step:1535/1645 train_time:142396ms step_avg:92.77ms +step:1536/1645 train_time:142490ms step_avg:92.77ms +step:1537/1645 train_time:142583ms step_avg:92.77ms +step:1538/1645 train_time:142676ms step_avg:92.77ms +step:1539/1645 train_time:142770ms step_avg:92.77ms +step:1540/1645 train_time:142863ms step_avg:92.77ms +step:1541/1645 train_time:142956ms step_avg:92.77ms +step:1542/1645 train_time:143049ms step_avg:92.77ms +step:1543/1645 train_time:143142ms step_avg:92.77ms +step:1544/1645 train_time:143237ms step_avg:92.77ms +step:1545/1645 train_time:143331ms step_avg:92.77ms +step:1546/1645 train_time:143424ms step_avg:92.77ms +step:1547/1645 train_time:143518ms step_avg:92.77ms +step:1548/1645 train_time:143611ms step_avg:92.77ms +step:1549/1645 train_time:143704ms step_avg:92.77ms +step:1550/1645 train_time:143798ms step_avg:92.77ms +step:1551/1645 train_time:143891ms step_avg:92.77ms +step:1552/1645 train_time:143984ms step_avg:92.77ms +step:1553/1645 train_time:144077ms step_avg:92.77ms +step:1554/1645 train_time:144170ms step_avg:92.77ms +step:1555/1645 train_time:144264ms step_avg:92.77ms +step:1556/1645 train_time:144359ms step_avg:92.78ms +step:1557/1645 train_time:144452ms step_avg:92.78ms +step:1558/1645 train_time:144546ms step_avg:92.78ms +step:1559/1645 train_time:144639ms step_avg:92.78ms +step:1560/1645 train_time:144733ms step_avg:92.78ms +step:1561/1645 train_time:144826ms step_avg:92.78ms +step:1562/1645 train_time:144920ms step_avg:92.78ms +step:1563/1645 train_time:145013ms step_avg:92.78ms +step:1564/1645 train_time:145106ms step_avg:92.78ms +step:1565/1645 train_time:145200ms step_avg:92.78ms +step:1566/1645 train_time:145294ms step_avg:92.78ms +step:1567/1645 train_time:145387ms step_avg:92.78ms +step:1568/1645 train_time:145481ms step_avg:92.78ms +step:1569/1645 train_time:145576ms step_avg:92.78ms +step:1570/1645 train_time:145670ms step_avg:92.78ms +step:1571/1645 train_time:145763ms step_avg:92.78ms +step:1572/1645 train_time:145856ms step_avg:92.78ms +step:1573/1645 train_time:145951ms step_avg:92.78ms +step:1574/1645 train_time:146045ms step_avg:92.79ms +step:1575/1645 train_time:146138ms step_avg:92.79ms +step:1576/1645 train_time:146231ms step_avg:92.79ms +step:1577/1645 train_time:146324ms step_avg:92.79ms +step:1578/1645 train_time:146418ms step_avg:92.79ms +step:1579/1645 train_time:146511ms step_avg:92.79ms +step:1580/1645 train_time:146605ms step_avg:92.79ms +step:1581/1645 train_time:146699ms step_avg:92.79ms +step:1582/1645 train_time:146792ms step_avg:92.79ms +step:1583/1645 train_time:146887ms step_avg:92.79ms +step:1584/1645 train_time:146980ms step_avg:92.79ms +step:1585/1645 train_time:147073ms step_avg:92.79ms +step:1586/1645 train_time:147166ms step_avg:92.79ms +step:1587/1645 train_time:147260ms step_avg:92.79ms +step:1588/1645 train_time:147354ms step_avg:92.79ms +step:1589/1645 train_time:147447ms step_avg:92.79ms +step:1590/1645 
train_time:147541ms step_avg:92.79ms +step:1591/1645 train_time:147635ms step_avg:92.79ms +step:1592/1645 train_time:147727ms step_avg:92.79ms +step:1593/1645 train_time:147821ms step_avg:92.79ms +step:1594/1645 train_time:147915ms step_avg:92.80ms +step:1595/1645 train_time:148010ms step_avg:92.80ms +step:1596/1645 train_time:148103ms step_avg:92.80ms +step:1597/1645 train_time:148198ms step_avg:92.80ms +step:1598/1645 train_time:148291ms step_avg:92.80ms +step:1599/1645 train_time:148384ms step_avg:92.80ms +step:1600/1645 train_time:148478ms step_avg:92.80ms +step:1601/1645 train_time:148572ms step_avg:92.80ms +step:1602/1645 train_time:148665ms step_avg:92.80ms +step:1603/1645 train_time:148759ms step_avg:92.80ms +step:1604/1645 train_time:148852ms step_avg:92.80ms +step:1605/1645 train_time:148945ms step_avg:92.80ms +step:1606/1645 train_time:149039ms step_avg:92.80ms +step:1607/1645 train_time:149132ms step_avg:92.80ms +step:1608/1645 train_time:149225ms step_avg:92.80ms +step:1609/1645 train_time:149318ms step_avg:92.80ms +step:1610/1645 train_time:149412ms step_avg:92.80ms +step:1611/1645 train_time:149508ms step_avg:92.80ms +step:1612/1645 train_time:149602ms step_avg:92.80ms +step:1613/1645 train_time:149697ms step_avg:92.81ms +step:1614/1645 train_time:149790ms step_avg:92.81ms +step:1615/1645 train_time:149883ms step_avg:92.81ms +step:1616/1645 train_time:149976ms step_avg:92.81ms +step:1617/1645 train_time:150070ms step_avg:92.81ms +step:1618/1645 train_time:150164ms step_avg:92.81ms +step:1619/1645 train_time:150257ms step_avg:92.81ms +step:1620/1645 train_time:150351ms step_avg:92.81ms +step:1621/1645 train_time:150444ms step_avg:92.81ms +step:1622/1645 train_time:150537ms step_avg:92.81ms +step:1623/1645 train_time:150632ms step_avg:92.81ms +step:1624/1645 train_time:150726ms step_avg:92.81ms +step:1625/1645 train_time:150819ms step_avg:92.81ms +step:1625/1645 val_loss:3.2861 train_time:150912ms step_avg:92.87ms +step:1626/1645 train_time:150940ms step_avg:92.83ms +step:1627/1645 train_time:151011ms step_avg:92.82ms +step:1628/1645 train_time:151106ms step_avg:92.82ms +step:1629/1645 train_time:151200ms step_avg:92.82ms +step:1630/1645 train_time:151293ms step_avg:92.82ms +step:1631/1645 train_time:151385ms step_avg:92.82ms +step:1632/1645 train_time:151478ms step_avg:92.82ms +step:1633/1645 train_time:151570ms step_avg:92.82ms +step:1634/1645 train_time:151664ms step_avg:92.82ms +step:1635/1645 train_time:151756ms step_avg:92.82ms +step:1636/1645 train_time:151850ms step_avg:92.82ms +step:1637/1645 train_time:151946ms step_avg:92.82ms +step:1638/1645 train_time:152042ms step_avg:92.82ms +step:1639/1645 train_time:152137ms step_avg:92.82ms +step:1640/1645 train_time:152232ms step_avg:92.82ms +step:1641/1645 train_time:152324ms step_avg:92.82ms +step:1642/1645 train_time:152417ms step_avg:92.82ms +step:1643/1645 train_time:152510ms step_avg:92.82ms +step:1644/1645 train_time:152602ms step_avg:92.82ms +step:1645/1645 train_time:152696ms step_avg:92.82ms +step:1645/1645 val_loss:3.2803 train_time:152789ms step_avg:92.88ms +peak memory allocated: 32074 MiB reserved: 46896 MiB diff --git a/records/091825_Smear/2bbcf732-7a1d-4bad-b992-aa018419033e.txt b/records/091825_Smear/2bbcf732-7a1d-4bad-b992-aa018419033e.txt new file mode 100644 index 000000000..858f17e23 --- /dev/null +++ b/records/091825_Smear/2bbcf732-7a1d-4bad-b992-aa018419033e.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging 
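+# (This early self-read is why each record file embeds the full training script
+# ahead of the run log it produced.)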
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
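+        # The flow below is pipelined in three phases so collectives overlap compute:
+        #   1. reduce_scatter_tensor averages gradients across ranks, leaving each
+        #      rank one chunk of the per-shape gradient stack;
+        #   2. each rank applies momentum + Newton-Schulz to its own chunk only;
+        #   3. all_gather_into_tensor broadcasts the updated chunks so every rank
+        #      finishes the step with identical parameters.
+        # Per 2D parameter this computes, roughly (a sketch ignoring lr_mul/wd_mul
+        # and the sharding, not the literal code path):
+        #   buf.lerp_(g, 1 - momentum); g = g.lerp(buf, momentum)
+        #   p.mul_(1 - lr * wd).add_(newton_schulz_triton(g),
+        #                            alpha=-lr * max(1, rows / cols) ** 0.5)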
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0] # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i] # This gradient corresponds to the current parameter p.
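+                # The two lerp calls below implement Nesterov-style momentum: the
+                # buffer is pulled toward the raw gradient, then the gradient is
+                # blended back toward the updated buffer, so the orthogonalized
+                # update is effectively taken one momentum step ahead.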
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
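+        # Every rank gathers the same padded stack, so after these copies all
+        # ranks hold identical parameters; zero-padded slots past len(orig_params)
+        # are never read.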
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        final_bm = ws_final_layer * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == 
BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+        assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing its BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last document ran one token long to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1645 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"smear/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
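+# A minimal, self-contained sketch (illustrative only; this helper is not part of the run
+# and is never called) of how the loader above turns packed document lengths into the
+# fixed-size cumulative-boundary tensor that the varlen attention path consumes as
+# cu_seqlens: slot 0 is 0, the next slots hold the running document boundaries, and every
+# remaining slot is padded with the total token count, mirroring the _cum_lengths logic.
+def _cu_seqlens_sketch(doc_lengths: list[int], max_num_docs: int) -> Tensor:
+    assert len(doc_lengths) < max_num_docs
+    total = sum(doc_lengths)
+    cu = torch.full((max_num_docs,), total, dtype=torch.int32)
+    cu[0] = 0
+    cu[1:len(doc_lengths) + 1] = torch.tensor(doc_lengths, dtype=torch.int32).cumsum(0)
+    return cu # e.g. doc_lengths=[5, 3, 4], max_num_docs=8 -> [0, 5, 8, 12, 12, 12, 12, 12]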
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up 
each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        # the warmup schedule cycles back to the smallest window; restore the base YaRN parameters
+        model.yarn.reset()
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# re-initialize the state saved above so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+model.yarn.reset() # the yarn buffers are non-persistent, so restore them explicitly
+ws = args.ws_schedule[0]
+ws_final_layer = ws
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:36:53 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 34C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:128ms step_avg:128.44ms +step:2/1645 train_time:146ms step_avg:73.18ms +step:3/1645 train_time:217ms step_avg:72.17ms +step:4/1645 train_time:306ms step_avg:76.51ms +step:5/1645 train_time:397ms step_avg:79.33ms +step:6/1645 train_time:487ms step_avg:81.20ms +step:7/1645 train_time:578ms step_avg:82.61ms +step:8/1645 train_time:669ms step_avg:83.61ms +step:9/1645 train_time:760ms step_avg:84.46ms +step:10/1645 train_time:851ms step_avg:85.12ms +step:11/1645 train_time:942ms step_avg:85.65ms +step:12/1645 train_time:1035ms step_avg:86.21ms +step:13/1645 train_time:1131ms step_avg:86.98ms +step:14/1645 train_time:1226ms step_avg:87.59ms +step:15/1645 train_time:1319ms step_avg:87.90ms +step:16/1645 train_time:1410ms step_avg:88.14ms +step:17/1645 train_time:1502ms step_avg:88.34ms +step:18/1645 train_time:1594ms step_avg:88.53ms +step:19/1645 train_time:1684ms step_avg:88.65ms +step:20/1645 train_time:1775ms step_avg:88.76ms +step:21/1645 train_time:1867ms step_avg:88.90ms +step:22/1645 train_time:1959ms step_avg:89.04ms +step:23/1645 
train_time:2051ms step_avg:89.18ms +step:24/1645 train_time:2145ms step_avg:89.38ms +step:25/1645 train_time:2238ms step_avg:89.53ms +step:26/1645 train_time:2331ms step_avg:89.64ms +step:27/1645 train_time:2423ms step_avg:89.74ms +step:28/1645 train_time:2515ms step_avg:89.81ms +step:29/1645 train_time:2606ms step_avg:89.86ms +step:30/1645 train_time:2697ms step_avg:89.88ms +step:31/1645 train_time:2787ms step_avg:89.90ms +step:32/1645 train_time:2878ms step_avg:89.94ms +step:33/1645 train_time:2969ms step_avg:89.98ms +step:34/1645 train_time:3062ms step_avg:90.07ms +step:35/1645 train_time:3154ms step_avg:90.12ms +step:36/1645 train_time:3247ms step_avg:90.20ms +step:37/1645 train_time:3340ms step_avg:90.26ms +step:38/1645 train_time:3432ms step_avg:90.32ms +step:39/1645 train_time:3525ms step_avg:90.37ms +step:40/1645 train_time:3617ms step_avg:90.41ms +step:41/1645 train_time:3707ms step_avg:90.43ms +step:42/1645 train_time:3799ms step_avg:90.44ms +step:43/1645 train_time:3890ms step_avg:90.47ms +step:44/1645 train_time:3982ms step_avg:90.49ms +step:45/1645 train_time:4074ms step_avg:90.53ms +step:46/1645 train_time:4167ms step_avg:90.58ms +step:47/1645 train_time:4260ms step_avg:90.63ms +step:48/1645 train_time:4352ms step_avg:90.67ms +step:49/1645 train_time:4445ms step_avg:90.71ms +step:50/1645 train_time:4536ms step_avg:90.72ms +step:51/1645 train_time:4628ms step_avg:90.75ms +step:52/1645 train_time:4720ms step_avg:90.76ms +step:53/1645 train_time:4811ms step_avg:90.77ms +step:54/1645 train_time:4902ms step_avg:90.78ms +step:55/1645 train_time:4993ms step_avg:90.79ms +step:56/1645 train_time:5085ms step_avg:90.81ms +step:57/1645 train_time:5177ms step_avg:90.82ms +step:58/1645 train_time:5270ms step_avg:90.86ms +step:59/1645 train_time:5363ms step_avg:90.90ms +step:60/1645 train_time:5455ms step_avg:90.91ms +step:61/1645 train_time:5546ms step_avg:90.92ms +step:62/1645 train_time:5638ms step_avg:90.94ms +step:63/1645 train_time:5730ms step_avg:90.95ms +step:64/1645 train_time:5821ms step_avg:90.96ms +step:65/1645 train_time:5913ms step_avg:90.96ms +step:66/1645 train_time:6004ms step_avg:90.97ms +step:67/1645 train_time:6095ms step_avg:90.97ms +step:68/1645 train_time:6188ms step_avg:91.00ms +step:69/1645 train_time:6280ms step_avg:91.02ms +step:70/1645 train_time:6372ms step_avg:91.03ms +step:71/1645 train_time:6464ms step_avg:91.04ms +step:72/1645 train_time:6556ms step_avg:91.05ms +step:73/1645 train_time:6648ms step_avg:91.06ms +step:74/1645 train_time:6740ms step_avg:91.08ms +step:75/1645 train_time:6832ms step_avg:91.09ms +step:76/1645 train_time:6924ms step_avg:91.10ms +step:77/1645 train_time:7015ms step_avg:91.10ms +step:78/1645 train_time:7106ms step_avg:91.11ms +step:79/1645 train_time:7199ms step_avg:91.13ms +step:80/1645 train_time:7291ms step_avg:91.14ms +step:81/1645 train_time:7383ms step_avg:91.14ms +step:82/1645 train_time:7474ms step_avg:91.15ms +step:83/1645 train_time:7566ms step_avg:91.16ms +step:84/1645 train_time:7659ms step_avg:91.18ms +step:85/1645 train_time:7751ms step_avg:91.19ms +step:86/1645 train_time:7843ms step_avg:91.20ms +step:87/1645 train_time:7936ms step_avg:91.22ms +step:88/1645 train_time:8027ms step_avg:91.22ms +step:89/1645 train_time:8118ms step_avg:91.22ms +step:90/1645 train_time:8210ms step_avg:91.22ms +step:91/1645 train_time:8301ms step_avg:91.22ms +step:92/1645 train_time:8393ms step_avg:91.23ms +step:93/1645 train_time:8484ms step_avg:91.23ms +step:94/1645 train_time:8576ms step_avg:91.23ms +step:95/1645 train_time:8669ms 
step_avg:91.25ms +step:96/1645 train_time:8761ms step_avg:91.27ms +step:97/1645 train_time:8853ms step_avg:91.26ms +step:98/1645 train_time:8945ms step_avg:91.28ms +step:99/1645 train_time:9037ms step_avg:91.28ms +step:100/1645 train_time:9128ms step_avg:91.28ms +step:101/1645 train_time:9220ms step_avg:91.29ms +step:102/1645 train_time:9312ms step_avg:91.29ms +step:103/1645 train_time:9402ms step_avg:91.28ms +step:104/1645 train_time:9494ms step_avg:91.29ms +step:105/1645 train_time:9587ms step_avg:91.30ms +step:106/1645 train_time:9679ms step_avg:91.31ms +step:107/1645 train_time:9771ms step_avg:91.31ms +step:108/1645 train_time:9863ms step_avg:91.33ms +step:109/1645 train_time:9955ms step_avg:91.33ms +step:110/1645 train_time:10047ms step_avg:91.34ms +step:111/1645 train_time:10140ms step_avg:91.35ms +step:112/1645 train_time:10231ms step_avg:91.35ms +step:113/1645 train_time:10323ms step_avg:91.35ms +step:114/1645 train_time:10414ms step_avg:91.35ms +step:115/1645 train_time:10506ms step_avg:91.35ms +step:116/1645 train_time:10598ms step_avg:91.36ms +step:117/1645 train_time:10689ms step_avg:91.36ms +step:118/1645 train_time:10781ms step_avg:91.36ms +step:119/1645 train_time:10872ms step_avg:91.36ms +step:120/1645 train_time:10964ms step_avg:91.37ms +step:121/1645 train_time:11056ms step_avg:91.37ms +step:122/1645 train_time:11149ms step_avg:91.38ms +step:123/1645 train_time:11241ms step_avg:91.39ms +step:124/1645 train_time:11332ms step_avg:91.39ms +step:125/1645 train_time:11423ms step_avg:91.38ms +step:125/1645 val_loss:4.3027 train_time:11514ms step_avg:92.11ms +step:126/1645 train_time:11531ms step_avg:91.51ms +step:127/1645 train_time:11611ms step_avg:91.43ms +step:128/1645 train_time:11713ms step_avg:91.51ms +step:129/1645 train_time:11809ms step_avg:91.54ms +step:130/1645 train_time:11900ms step_avg:91.54ms +step:131/1645 train_time:11991ms step_avg:91.53ms +step:132/1645 train_time:12082ms step_avg:91.53ms +step:133/1645 train_time:12172ms step_avg:91.52ms +step:134/1645 train_time:12263ms step_avg:91.52ms +step:135/1645 train_time:12354ms step_avg:91.51ms +step:136/1645 train_time:12445ms step_avg:91.50ms +step:137/1645 train_time:12537ms step_avg:91.51ms +step:138/1645 train_time:12630ms step_avg:91.52ms +step:139/1645 train_time:12724ms step_avg:91.54ms +step:140/1645 train_time:12817ms step_avg:91.55ms +step:141/1645 train_time:12909ms step_avg:91.56ms +step:142/1645 train_time:13001ms step_avg:91.56ms +step:143/1645 train_time:13092ms step_avg:91.55ms +step:144/1645 train_time:13183ms step_avg:91.55ms +step:145/1645 train_time:13275ms step_avg:91.55ms +step:146/1645 train_time:13366ms step_avg:91.55ms +step:147/1645 train_time:13457ms step_avg:91.54ms +step:148/1645 train_time:13548ms step_avg:91.54ms +step:149/1645 train_time:13641ms step_avg:91.55ms +step:150/1645 train_time:13734ms step_avg:91.56ms +step:151/1645 train_time:13826ms step_avg:91.56ms +step:152/1645 train_time:13918ms step_avg:91.56ms +step:153/1645 train_time:14010ms step_avg:91.57ms +step:154/1645 train_time:14102ms step_avg:91.57ms +step:155/1645 train_time:14194ms step_avg:91.57ms +step:156/1645 train_time:14285ms step_avg:91.57ms +step:157/1645 train_time:14376ms step_avg:91.57ms +step:158/1645 train_time:14468ms step_avg:91.57ms +step:159/1645 train_time:14559ms step_avg:91.57ms +step:160/1645 train_time:14652ms step_avg:91.57ms +step:161/1645 train_time:14745ms step_avg:91.58ms +step:162/1645 train_time:14838ms step_avg:91.59ms +step:163/1645 train_time:14929ms step_avg:91.59ms +step:164/1645 
train_time:15021ms step_avg:91.59ms +step:165/1645 train_time:15112ms step_avg:91.59ms +step:166/1645 train_time:15206ms step_avg:91.60ms +step:167/1645 train_time:15296ms step_avg:91.60ms +step:168/1645 train_time:15387ms step_avg:91.59ms +step:169/1645 train_time:15479ms step_avg:91.59ms +step:170/1645 train_time:15570ms step_avg:91.59ms +step:171/1645 train_time:15663ms step_avg:91.59ms +step:172/1645 train_time:15755ms step_avg:91.60ms +step:173/1645 train_time:15847ms step_avg:91.60ms +step:174/1645 train_time:15939ms step_avg:91.60ms +step:175/1645 train_time:16030ms step_avg:91.60ms +step:176/1645 train_time:16122ms step_avg:91.60ms +step:177/1645 train_time:16213ms step_avg:91.60ms +step:178/1645 train_time:16305ms step_avg:91.60ms +step:179/1645 train_time:16397ms step_avg:91.60ms +step:180/1645 train_time:16489ms step_avg:91.60ms +step:181/1645 train_time:16581ms step_avg:91.61ms +step:182/1645 train_time:16673ms step_avg:91.61ms +step:183/1645 train_time:16765ms step_avg:91.61ms +step:184/1645 train_time:16857ms step_avg:91.61ms +step:185/1645 train_time:16948ms step_avg:91.61ms +step:186/1645 train_time:17039ms step_avg:91.61ms +step:187/1645 train_time:17130ms step_avg:91.61ms +step:188/1645 train_time:17223ms step_avg:91.61ms +step:189/1645 train_time:17314ms step_avg:91.61ms +step:190/1645 train_time:17405ms step_avg:91.61ms +step:191/1645 train_time:17498ms step_avg:91.61ms +step:192/1645 train_time:17590ms step_avg:91.61ms +step:193/1645 train_time:17682ms step_avg:91.62ms +step:194/1645 train_time:17774ms step_avg:91.62ms +step:195/1645 train_time:17866ms step_avg:91.62ms +step:196/1645 train_time:17958ms step_avg:91.62ms +step:197/1645 train_time:18049ms step_avg:91.62ms +step:198/1645 train_time:18140ms step_avg:91.62ms +step:199/1645 train_time:18231ms step_avg:91.61ms +step:200/1645 train_time:18322ms step_avg:91.61ms +step:201/1645 train_time:18413ms step_avg:91.61ms +step:202/1645 train_time:18505ms step_avg:91.61ms +step:203/1645 train_time:18597ms step_avg:91.61ms +step:204/1645 train_time:18689ms step_avg:91.61ms +step:205/1645 train_time:18782ms step_avg:91.62ms +step:206/1645 train_time:18873ms step_avg:91.62ms +step:207/1645 train_time:18964ms step_avg:91.61ms +step:208/1645 train_time:19056ms step_avg:91.62ms +step:209/1645 train_time:19148ms step_avg:91.62ms +step:210/1645 train_time:19239ms step_avg:91.61ms +step:211/1645 train_time:19330ms step_avg:91.61ms +step:212/1645 train_time:19422ms step_avg:91.61ms +step:213/1645 train_time:19513ms step_avg:91.61ms +step:214/1645 train_time:19606ms step_avg:91.62ms +step:215/1645 train_time:19698ms step_avg:91.62ms +step:216/1645 train_time:19790ms step_avg:91.62ms +step:217/1645 train_time:19882ms step_avg:91.62ms +step:218/1645 train_time:19974ms step_avg:91.62ms +step:219/1645 train_time:20066ms step_avg:91.62ms +step:220/1645 train_time:20157ms step_avg:91.62ms +step:221/1645 train_time:20248ms step_avg:91.62ms +step:222/1645 train_time:20339ms step_avg:91.62ms +step:223/1645 train_time:20430ms step_avg:91.61ms +step:224/1645 train_time:20521ms step_avg:91.61ms +step:225/1645 train_time:20612ms step_avg:91.61ms +step:226/1645 train_time:20705ms step_avg:91.61ms +step:227/1645 train_time:20797ms step_avg:91.61ms +step:228/1645 train_time:20888ms step_avg:91.61ms +step:229/1645 train_time:20980ms step_avg:91.61ms +step:230/1645 train_time:21072ms step_avg:91.62ms +step:231/1645 train_time:21164ms step_avg:91.62ms +step:232/1645 train_time:21255ms step_avg:91.62ms +step:233/1645 train_time:21347ms step_avg:91.62ms 
+step:234/1645 train_time:21437ms step_avg:91.61ms +step:235/1645 train_time:21529ms step_avg:91.61ms +step:236/1645 train_time:21621ms step_avg:91.61ms +step:237/1645 train_time:21712ms step_avg:91.61ms +step:238/1645 train_time:21804ms step_avg:91.62ms +step:239/1645 train_time:21897ms step_avg:91.62ms +step:240/1645 train_time:21989ms step_avg:91.62ms +step:241/1645 train_time:22081ms step_avg:91.62ms +step:242/1645 train_time:22172ms step_avg:91.62ms +step:243/1645 train_time:22264ms step_avg:91.62ms +step:244/1645 train_time:22356ms step_avg:91.62ms +step:245/1645 train_time:22447ms step_avg:91.62ms +step:246/1645 train_time:22538ms step_avg:91.62ms +step:247/1645 train_time:22629ms step_avg:91.62ms +step:248/1645 train_time:22721ms step_avg:91.62ms +step:249/1645 train_time:22812ms step_avg:91.62ms +step:250/1645 train_time:22906ms step_avg:91.62ms +step:250/1645 val_loss:3.9618 train_time:22998ms step_avg:91.99ms +step:251/1645 train_time:23014ms step_avg:91.69ms +step:252/1645 train_time:23095ms step_avg:91.65ms +step:253/1645 train_time:23188ms step_avg:91.65ms +step:254/1645 train_time:23280ms step_avg:91.66ms +step:255/1645 train_time:23371ms step_avg:91.65ms +step:256/1645 train_time:23462ms step_avg:91.65ms +step:257/1645 train_time:23552ms step_avg:91.64ms +step:258/1645 train_time:23644ms step_avg:91.64ms +step:259/1645 train_time:23736ms step_avg:91.64ms +step:260/1645 train_time:23826ms step_avg:91.64ms +step:261/1645 train_time:23918ms step_avg:91.64ms +step:262/1645 train_time:24011ms step_avg:91.64ms +step:263/1645 train_time:24105ms step_avg:91.65ms +step:264/1645 train_time:24198ms step_avg:91.66ms +step:265/1645 train_time:24290ms step_avg:91.66ms +step:266/1645 train_time:24381ms step_avg:91.66ms +step:267/1645 train_time:24472ms step_avg:91.65ms +step:268/1645 train_time:24564ms step_avg:91.66ms +step:269/1645 train_time:24655ms step_avg:91.66ms +step:270/1645 train_time:24746ms step_avg:91.65ms +step:271/1645 train_time:24837ms step_avg:91.65ms +step:272/1645 train_time:24928ms step_avg:91.65ms +step:273/1645 train_time:25021ms step_avg:91.65ms +step:274/1645 train_time:25113ms step_avg:91.65ms +step:275/1645 train_time:25205ms step_avg:91.66ms +step:276/1645 train_time:25297ms step_avg:91.66ms +step:277/1645 train_time:25389ms step_avg:91.66ms +step:278/1645 train_time:25481ms step_avg:91.66ms +step:279/1645 train_time:25572ms step_avg:91.66ms +step:280/1645 train_time:25664ms step_avg:91.66ms +step:281/1645 train_time:25756ms step_avg:91.66ms +step:282/1645 train_time:25847ms step_avg:91.66ms +step:283/1645 train_time:25938ms step_avg:91.65ms +step:284/1645 train_time:26030ms step_avg:91.65ms +step:285/1645 train_time:26122ms step_avg:91.65ms +step:286/1645 train_time:26213ms step_avg:91.66ms +step:287/1645 train_time:26305ms step_avg:91.65ms +step:288/1645 train_time:26396ms step_avg:91.65ms +step:289/1645 train_time:26488ms step_avg:91.65ms +step:290/1645 train_time:26579ms step_avg:91.65ms +step:291/1645 train_time:26670ms step_avg:91.65ms +step:292/1645 train_time:26763ms step_avg:91.65ms +step:293/1645 train_time:26855ms step_avg:91.65ms +step:294/1645 train_time:26946ms step_avg:91.65ms +step:295/1645 train_time:27039ms step_avg:91.66ms +step:296/1645 train_time:27130ms step_avg:91.66ms +step:297/1645 train_time:27223ms step_avg:91.66ms +step:298/1645 train_time:27313ms step_avg:91.66ms +step:299/1645 train_time:27405ms step_avg:91.66ms +step:300/1645 train_time:27497ms step_avg:91.66ms +step:301/1645 train_time:27589ms step_avg:91.66ms +step:302/1645 
train_time:27680ms step_avg:91.66ms +step:303/1645 train_time:27771ms step_avg:91.65ms +step:304/1645 train_time:27863ms step_avg:91.66ms +step:305/1645 train_time:27956ms step_avg:91.66ms +step:306/1645 train_time:28047ms step_avg:91.66ms +step:307/1645 train_time:28139ms step_avg:91.66ms +step:308/1645 train_time:28230ms step_avg:91.65ms +step:309/1645 train_time:28323ms step_avg:91.66ms +step:310/1645 train_time:28414ms step_avg:91.66ms +step:311/1645 train_time:28505ms step_avg:91.66ms +step:312/1645 train_time:28597ms step_avg:91.66ms +step:313/1645 train_time:28689ms step_avg:91.66ms +step:314/1645 train_time:28780ms step_avg:91.66ms +step:315/1645 train_time:28872ms step_avg:91.66ms +step:316/1645 train_time:28965ms step_avg:91.66ms +step:317/1645 train_time:29057ms step_avg:91.66ms +step:318/1645 train_time:29148ms step_avg:91.66ms +step:319/1645 train_time:29240ms step_avg:91.66ms +step:320/1645 train_time:29333ms step_avg:91.66ms +step:321/1645 train_time:29425ms step_avg:91.67ms +step:322/1645 train_time:29516ms step_avg:91.66ms +step:323/1645 train_time:29607ms step_avg:91.66ms +step:324/1645 train_time:29698ms step_avg:91.66ms +step:325/1645 train_time:29789ms step_avg:91.66ms +step:326/1645 train_time:29881ms step_avg:91.66ms +step:327/1645 train_time:29972ms step_avg:91.66ms +step:328/1645 train_time:30064ms step_avg:91.66ms +step:329/1645 train_time:30157ms step_avg:91.66ms +step:330/1645 train_time:30250ms step_avg:91.67ms +step:331/1645 train_time:30341ms step_avg:91.67ms +step:332/1645 train_time:30432ms step_avg:91.66ms +step:333/1645 train_time:30525ms step_avg:91.67ms +step:334/1645 train_time:30616ms step_avg:91.66ms +step:335/1645 train_time:30707ms step_avg:91.66ms +step:336/1645 train_time:30798ms step_avg:91.66ms +step:337/1645 train_time:30890ms step_avg:91.66ms +step:338/1645 train_time:30982ms step_avg:91.66ms +step:339/1645 train_time:31074ms step_avg:91.66ms +step:340/1645 train_time:31167ms step_avg:91.67ms +step:341/1645 train_time:31259ms step_avg:91.67ms +step:342/1645 train_time:31350ms step_avg:91.67ms +step:343/1645 train_time:31443ms step_avg:91.67ms +step:344/1645 train_time:31536ms step_avg:91.67ms +step:345/1645 train_time:31626ms step_avg:91.67ms +step:346/1645 train_time:31718ms step_avg:91.67ms +step:347/1645 train_time:31810ms step_avg:91.67ms +step:348/1645 train_time:31901ms step_avg:91.67ms +step:349/1645 train_time:31992ms step_avg:91.67ms +step:350/1645 train_time:32084ms step_avg:91.67ms +step:351/1645 train_time:32175ms step_avg:91.67ms +step:352/1645 train_time:32266ms step_avg:91.67ms +step:353/1645 train_time:32359ms step_avg:91.67ms +step:354/1645 train_time:32451ms step_avg:91.67ms +step:355/1645 train_time:32545ms step_avg:91.67ms +step:356/1645 train_time:32636ms step_avg:91.67ms +step:357/1645 train_time:32728ms step_avg:91.67ms +step:358/1645 train_time:32819ms step_avg:91.67ms +step:359/1645 train_time:32911ms step_avg:91.67ms +step:360/1645 train_time:33001ms step_avg:91.67ms +step:361/1645 train_time:33092ms step_avg:91.67ms +step:362/1645 train_time:33184ms step_avg:91.67ms +step:363/1645 train_time:33275ms step_avg:91.67ms +step:364/1645 train_time:33367ms step_avg:91.67ms +step:365/1645 train_time:33460ms step_avg:91.67ms +step:366/1645 train_time:33552ms step_avg:91.67ms +step:367/1645 train_time:33645ms step_avg:91.68ms +step:368/1645 train_time:33736ms step_avg:91.67ms +step:369/1645 train_time:33828ms step_avg:91.67ms +step:370/1645 train_time:33918ms step_avg:91.67ms +step:371/1645 train_time:34010ms step_avg:91.67ms 
+step:372/1645 train_time:34101ms step_avg:91.67ms +step:373/1645 train_time:34193ms step_avg:91.67ms +step:374/1645 train_time:34285ms step_avg:91.67ms +step:375/1645 train_time:34377ms step_avg:91.67ms +step:375/1645 val_loss:3.8105 train_time:34469ms step_avg:91.92ms +step:376/1645 train_time:34485ms step_avg:91.72ms +step:377/1645 train_time:34564ms step_avg:91.68ms +step:378/1645 train_time:34658ms step_avg:91.69ms +step:379/1645 train_time:34751ms step_avg:91.69ms +step:380/1645 train_time:34842ms step_avg:91.69ms +step:381/1645 train_time:34932ms step_avg:91.69ms +step:382/1645 train_time:35023ms step_avg:91.68ms +step:383/1645 train_time:35114ms step_avg:91.68ms +step:384/1645 train_time:35205ms step_avg:91.68ms +step:385/1645 train_time:35296ms step_avg:91.68ms +step:386/1645 train_time:35388ms step_avg:91.68ms +step:387/1645 train_time:35481ms step_avg:91.68ms +step:388/1645 train_time:35575ms step_avg:91.69ms +step:389/1645 train_time:35667ms step_avg:91.69ms +step:390/1645 train_time:35759ms step_avg:91.69ms +step:391/1645 train_time:35850ms step_avg:91.69ms +step:392/1645 train_time:35941ms step_avg:91.69ms +step:393/1645 train_time:36032ms step_avg:91.68ms +step:394/1645 train_time:36123ms step_avg:91.68ms +step:395/1645 train_time:36215ms step_avg:91.68ms +step:396/1645 train_time:36305ms step_avg:91.68ms +step:397/1645 train_time:36397ms step_avg:91.68ms +step:398/1645 train_time:36490ms step_avg:91.68ms +step:399/1645 train_time:36583ms step_avg:91.69ms +step:400/1645 train_time:36675ms step_avg:91.69ms +step:401/1645 train_time:36767ms step_avg:91.69ms +step:402/1645 train_time:36859ms step_avg:91.69ms +step:403/1645 train_time:36951ms step_avg:91.69ms +step:404/1645 train_time:37042ms step_avg:91.69ms +step:405/1645 train_time:37133ms step_avg:91.69ms +step:406/1645 train_time:37225ms step_avg:91.69ms +step:407/1645 train_time:37316ms step_avg:91.69ms +step:408/1645 train_time:37408ms step_avg:91.69ms +step:409/1645 train_time:37501ms step_avg:91.69ms +step:410/1645 train_time:37593ms step_avg:91.69ms +step:411/1645 train_time:37685ms step_avg:91.69ms +step:412/1645 train_time:37777ms step_avg:91.69ms +step:413/1645 train_time:37868ms step_avg:91.69ms +step:414/1645 train_time:37959ms step_avg:91.69ms +step:415/1645 train_time:38050ms step_avg:91.69ms +step:416/1645 train_time:38141ms step_avg:91.69ms +step:417/1645 train_time:38232ms step_avg:91.68ms +step:418/1645 train_time:38325ms step_avg:91.69ms +step:419/1645 train_time:38416ms step_avg:91.68ms +step:420/1645 train_time:38507ms step_avg:91.68ms +step:421/1645 train_time:38600ms step_avg:91.69ms +step:422/1645 train_time:38693ms step_avg:91.69ms +step:423/1645 train_time:38786ms step_avg:91.69ms +step:424/1645 train_time:38877ms step_avg:91.69ms +step:425/1645 train_time:38969ms step_avg:91.69ms +step:426/1645 train_time:39061ms step_avg:91.69ms +step:427/1645 train_time:39152ms step_avg:91.69ms +step:428/1645 train_time:39243ms step_avg:91.69ms +step:429/1645 train_time:39335ms step_avg:91.69ms +step:430/1645 train_time:39426ms step_avg:91.69ms +step:431/1645 train_time:39518ms step_avg:91.69ms +step:432/1645 train_time:39610ms step_avg:91.69ms +step:433/1645 train_time:39702ms step_avg:91.69ms +step:434/1645 train_time:39794ms step_avg:91.69ms +step:435/1645 train_time:39886ms step_avg:91.69ms +step:436/1645 train_time:39978ms step_avg:91.69ms +step:437/1645 train_time:40069ms step_avg:91.69ms +step:438/1645 train_time:40159ms step_avg:91.69ms +step:439/1645 train_time:40251ms step_avg:91.69ms +step:440/1645 
train_time:40342ms step_avg:91.69ms +step:441/1645 train_time:40433ms step_avg:91.69ms +step:442/1645 train_time:40525ms step_avg:91.69ms +step:443/1645 train_time:40617ms step_avg:91.69ms +step:444/1645 train_time:40709ms step_avg:91.69ms +step:445/1645 train_time:40801ms step_avg:91.69ms +step:446/1645 train_time:40893ms step_avg:91.69ms +step:447/1645 train_time:40985ms step_avg:91.69ms +step:448/1645 train_time:41077ms step_avg:91.69ms +step:449/1645 train_time:41168ms step_avg:91.69ms +step:450/1645 train_time:41259ms step_avg:91.69ms +step:451/1645 train_time:41350ms step_avg:91.69ms +step:452/1645 train_time:41442ms step_avg:91.68ms +step:453/1645 train_time:41533ms step_avg:91.69ms +step:454/1645 train_time:41625ms step_avg:91.68ms +step:455/1645 train_time:41718ms step_avg:91.69ms +step:456/1645 train_time:41810ms step_avg:91.69ms +step:457/1645 train_time:41902ms step_avg:91.69ms +step:458/1645 train_time:41994ms step_avg:91.69ms +step:459/1645 train_time:42085ms step_avg:91.69ms +step:460/1645 train_time:42177ms step_avg:91.69ms +step:461/1645 train_time:42269ms step_avg:91.69ms +step:462/1645 train_time:42360ms step_avg:91.69ms +step:463/1645 train_time:42452ms step_avg:91.69ms +step:464/1645 train_time:42543ms step_avg:91.69ms +step:465/1645 train_time:42635ms step_avg:91.69ms +step:466/1645 train_time:42727ms step_avg:91.69ms +step:467/1645 train_time:42820ms step_avg:91.69ms +step:468/1645 train_time:42913ms step_avg:91.69ms +step:469/1645 train_time:43004ms step_avg:91.69ms +step:470/1645 train_time:43095ms step_avg:91.69ms +step:471/1645 train_time:43186ms step_avg:91.69ms +step:472/1645 train_time:43277ms step_avg:91.69ms +step:473/1645 train_time:43368ms step_avg:91.69ms +step:474/1645 train_time:43460ms step_avg:91.69ms +step:475/1645 train_time:43552ms step_avg:91.69ms +step:476/1645 train_time:43643ms step_avg:91.69ms +step:477/1645 train_time:43736ms step_avg:91.69ms +step:478/1645 train_time:43828ms step_avg:91.69ms +step:479/1645 train_time:43920ms step_avg:91.69ms +step:480/1645 train_time:44013ms step_avg:91.69ms +step:481/1645 train_time:44104ms step_avg:91.69ms +step:482/1645 train_time:44195ms step_avg:91.69ms +step:483/1645 train_time:44287ms step_avg:91.69ms +step:484/1645 train_time:44379ms step_avg:91.69ms +step:485/1645 train_time:44471ms step_avg:91.69ms +step:486/1645 train_time:44562ms step_avg:91.69ms +step:487/1645 train_time:44654ms step_avg:91.69ms +step:488/1645 train_time:44746ms step_avg:91.69ms +step:489/1645 train_time:44838ms step_avg:91.69ms +step:490/1645 train_time:44929ms step_avg:91.69ms +step:491/1645 train_time:45021ms step_avg:91.69ms +step:492/1645 train_time:45113ms step_avg:91.69ms +step:493/1645 train_time:45205ms step_avg:91.69ms +step:494/1645 train_time:45297ms step_avg:91.69ms +step:495/1645 train_time:45389ms step_avg:91.69ms +step:496/1645 train_time:45481ms step_avg:91.70ms +step:497/1645 train_time:45573ms step_avg:91.70ms +step:498/1645 train_time:45664ms step_avg:91.70ms +step:499/1645 train_time:45756ms step_avg:91.70ms +step:500/1645 train_time:45846ms step_avg:91.69ms +step:500/1645 val_loss:3.7085 train_time:45938ms step_avg:91.88ms +step:501/1645 train_time:45954ms step_avg:91.72ms +step:502/1645 train_time:46037ms step_avg:91.71ms +step:503/1645 train_time:46130ms step_avg:91.71ms +step:504/1645 train_time:46221ms step_avg:91.71ms +step:505/1645 train_time:46311ms step_avg:91.71ms +step:506/1645 train_time:46402ms step_avg:91.70ms +step:507/1645 train_time:46492ms step_avg:91.70ms +step:508/1645 train_time:46583ms 
step_avg:91.70ms +step:509/1645 train_time:46677ms step_avg:91.70ms +step:510/1645 train_time:46767ms step_avg:91.70ms +step:511/1645 train_time:46859ms step_avg:91.70ms +step:512/1645 train_time:46952ms step_avg:91.70ms +step:513/1645 train_time:47047ms step_avg:91.71ms +step:514/1645 train_time:47140ms step_avg:91.71ms +step:515/1645 train_time:47232ms step_avg:91.71ms +step:516/1645 train_time:47323ms step_avg:91.71ms +step:517/1645 train_time:47414ms step_avg:91.71ms +step:518/1645 train_time:47504ms step_avg:91.71ms +step:519/1645 train_time:47596ms step_avg:91.71ms +step:520/1645 train_time:47687ms step_avg:91.71ms +step:521/1645 train_time:47779ms step_avg:91.71ms +step:522/1645 train_time:47870ms step_avg:91.71ms +step:523/1645 train_time:47963ms step_avg:91.71ms +step:524/1645 train_time:48057ms step_avg:91.71ms +step:525/1645 train_time:48150ms step_avg:91.71ms +step:526/1645 train_time:48242ms step_avg:91.72ms +step:527/1645 train_time:48334ms step_avg:91.72ms +step:528/1645 train_time:48425ms step_avg:91.71ms +step:529/1645 train_time:48516ms step_avg:91.71ms +step:530/1645 train_time:48607ms step_avg:91.71ms +step:531/1645 train_time:48697ms step_avg:91.71ms +step:532/1645 train_time:48788ms step_avg:91.71ms +step:533/1645 train_time:48879ms step_avg:91.71ms +step:534/1645 train_time:48972ms step_avg:91.71ms +step:535/1645 train_time:49064ms step_avg:91.71ms +step:536/1645 train_time:49156ms step_avg:91.71ms +step:537/1645 train_time:49249ms step_avg:91.71ms +step:538/1645 train_time:49342ms step_avg:91.71ms +step:539/1645 train_time:49433ms step_avg:91.71ms +step:540/1645 train_time:49524ms step_avg:91.71ms +step:541/1645 train_time:49615ms step_avg:91.71ms +step:542/1645 train_time:49706ms step_avg:91.71ms +step:543/1645 train_time:49797ms step_avg:91.71ms +step:544/1645 train_time:49888ms step_avg:91.71ms +step:545/1645 train_time:49980ms step_avg:91.71ms +step:546/1645 train_time:50072ms step_avg:91.71ms +step:547/1645 train_time:50164ms step_avg:91.71ms +step:548/1645 train_time:50257ms step_avg:91.71ms +step:549/1645 train_time:50350ms step_avg:91.71ms +step:550/1645 train_time:50442ms step_avg:91.71ms +step:551/1645 train_time:50536ms step_avg:91.72ms +step:552/1645 train_time:50628ms step_avg:91.72ms +step:553/1645 train_time:50720ms step_avg:91.72ms +step:554/1645 train_time:50813ms step_avg:91.72ms +step:555/1645 train_time:50906ms step_avg:91.72ms +step:556/1645 train_time:50999ms step_avg:91.73ms +step:557/1645 train_time:51092ms step_avg:91.73ms +step:558/1645 train_time:51184ms step_avg:91.73ms +step:559/1645 train_time:51279ms step_avg:91.73ms +step:560/1645 train_time:51373ms step_avg:91.74ms +step:561/1645 train_time:51465ms step_avg:91.74ms +step:562/1645 train_time:51558ms step_avg:91.74ms +step:563/1645 train_time:51651ms step_avg:91.74ms +step:564/1645 train_time:51743ms step_avg:91.74ms +step:565/1645 train_time:51836ms step_avg:91.75ms +step:566/1645 train_time:51929ms step_avg:91.75ms +step:567/1645 train_time:52022ms step_avg:91.75ms +step:568/1645 train_time:52115ms step_avg:91.75ms +step:569/1645 train_time:52209ms step_avg:91.75ms +step:570/1645 train_time:52302ms step_avg:91.76ms +step:571/1645 train_time:52395ms step_avg:91.76ms +step:572/1645 train_time:52488ms step_avg:91.76ms +step:573/1645 train_time:52581ms step_avg:91.76ms +step:574/1645 train_time:52674ms step_avg:91.77ms +step:575/1645 train_time:52766ms step_avg:91.77ms +step:576/1645 train_time:52859ms step_avg:91.77ms +step:577/1645 train_time:52952ms step_avg:91.77ms +step:578/1645 
train_time:53045ms step_avg:91.77ms +step:579/1645 train_time:53137ms step_avg:91.77ms +step:580/1645 train_time:53231ms step_avg:91.78ms +step:581/1645 train_time:53324ms step_avg:91.78ms +step:582/1645 train_time:53417ms step_avg:91.78ms +step:583/1645 train_time:53511ms step_avg:91.78ms +step:584/1645 train_time:53603ms step_avg:91.79ms +step:585/1645 train_time:53697ms step_avg:91.79ms +step:586/1645 train_time:53790ms step_avg:91.79ms +step:587/1645 train_time:53882ms step_avg:91.79ms +step:588/1645 train_time:53977ms step_avg:91.80ms +step:589/1645 train_time:54069ms step_avg:91.80ms +step:590/1645 train_time:54162ms step_avg:91.80ms +step:591/1645 train_time:54255ms step_avg:91.80ms +step:592/1645 train_time:54348ms step_avg:91.80ms +step:593/1645 train_time:54441ms step_avg:91.81ms +step:594/1645 train_time:54535ms step_avg:91.81ms +step:595/1645 train_time:54628ms step_avg:91.81ms +step:596/1645 train_time:54720ms step_avg:91.81ms +step:597/1645 train_time:54813ms step_avg:91.81ms +step:598/1645 train_time:54905ms step_avg:91.82ms +step:599/1645 train_time:54998ms step_avg:91.82ms +step:600/1645 train_time:55091ms step_avg:91.82ms +step:601/1645 train_time:55183ms step_avg:91.82ms +step:602/1645 train_time:55277ms step_avg:91.82ms +step:603/1645 train_time:55370ms step_avg:91.82ms +step:604/1645 train_time:55464ms step_avg:91.83ms +step:605/1645 train_time:55557ms step_avg:91.83ms +step:606/1645 train_time:55649ms step_avg:91.83ms +step:607/1645 train_time:55742ms step_avg:91.83ms +step:608/1645 train_time:55835ms step_avg:91.83ms +step:609/1645 train_time:55928ms step_avg:91.84ms +step:610/1645 train_time:56021ms step_avg:91.84ms +step:611/1645 train_time:56114ms step_avg:91.84ms +step:612/1645 train_time:56207ms step_avg:91.84ms +step:613/1645 train_time:56300ms step_avg:91.84ms +step:614/1645 train_time:56394ms step_avg:91.85ms +step:615/1645 train_time:56487ms step_avg:91.85ms +step:616/1645 train_time:56580ms step_avg:91.85ms +step:617/1645 train_time:56673ms step_avg:91.85ms +step:618/1645 train_time:56766ms step_avg:91.85ms +step:619/1645 train_time:56859ms step_avg:91.86ms +step:620/1645 train_time:56951ms step_avg:91.86ms +step:621/1645 train_time:57043ms step_avg:91.86ms +step:622/1645 train_time:57137ms step_avg:91.86ms +step:623/1645 train_time:57229ms step_avg:91.86ms +step:624/1645 train_time:57322ms step_avg:91.86ms +step:625/1645 train_time:57416ms step_avg:91.87ms +step:625/1645 val_loss:3.6095 train_time:57509ms step_avg:92.01ms +step:626/1645 train_time:57526ms step_avg:91.89ms +step:627/1645 train_time:57611ms step_avg:91.88ms +step:628/1645 train_time:57714ms step_avg:91.90ms +step:629/1645 train_time:57812ms step_avg:91.91ms +step:630/1645 train_time:57904ms step_avg:91.91ms +step:631/1645 train_time:57995ms step_avg:91.91ms +step:632/1645 train_time:58087ms step_avg:91.91ms +step:633/1645 train_time:58178ms step_avg:91.91ms +step:634/1645 train_time:58269ms step_avg:91.91ms +step:635/1645 train_time:58361ms step_avg:91.91ms +step:636/1645 train_time:58454ms step_avg:91.91ms +step:637/1645 train_time:58546ms step_avg:91.91ms +step:638/1645 train_time:58643ms step_avg:91.92ms +step:639/1645 train_time:58739ms step_avg:91.92ms +step:640/1645 train_time:58834ms step_avg:91.93ms +step:641/1645 train_time:58926ms step_avg:91.93ms +step:642/1645 train_time:59020ms step_avg:91.93ms +step:643/1645 train_time:59112ms step_avg:91.93ms +step:644/1645 train_time:59203ms step_avg:91.93ms +step:645/1645 train_time:59296ms step_avg:91.93ms +step:646/1645 train_time:59387ms 
step_avg:91.93ms +step:647/1645 train_time:59479ms step_avg:91.93ms +step:648/1645 train_time:59573ms step_avg:91.93ms +step:649/1645 train_time:59668ms step_avg:91.94ms +step:650/1645 train_time:59763ms step_avg:91.94ms +step:651/1645 train_time:59857ms step_avg:91.95ms +step:652/1645 train_time:59950ms step_avg:91.95ms +step:653/1645 train_time:60044ms step_avg:91.95ms +step:654/1645 train_time:60135ms step_avg:91.95ms +step:655/1645 train_time:60227ms step_avg:91.95ms +step:656/1645 train_time:60320ms step_avg:91.95ms +step:657/1645 train_time:60412ms step_avg:91.95ms +step:658/1645 train_time:60505ms step_avg:91.95ms +step:659/1645 train_time:60599ms step_avg:91.96ms +step:660/1645 train_time:60692ms step_avg:91.96ms +step:661/1645 train_time:60786ms step_avg:91.96ms +step:662/1645 train_time:60879ms step_avg:91.96ms +step:663/1645 train_time:60972ms step_avg:91.96ms +step:664/1645 train_time:61065ms step_avg:91.97ms +step:665/1645 train_time:61158ms step_avg:91.97ms +step:666/1645 train_time:61252ms step_avg:91.97ms +step:667/1645 train_time:61343ms step_avg:91.97ms +step:668/1645 train_time:61436ms step_avg:91.97ms +step:669/1645 train_time:61529ms step_avg:91.97ms +step:670/1645 train_time:61622ms step_avg:91.97ms +step:671/1645 train_time:61715ms step_avg:91.98ms +step:672/1645 train_time:61808ms step_avg:91.98ms +step:673/1645 train_time:61902ms step_avg:91.98ms +step:674/1645 train_time:61995ms step_avg:91.98ms +step:675/1645 train_time:62088ms step_avg:91.98ms +step:676/1645 train_time:62181ms step_avg:91.98ms +step:677/1645 train_time:62273ms step_avg:91.98ms +step:678/1645 train_time:62365ms step_avg:91.98ms +step:679/1645 train_time:62457ms step_avg:91.98ms +step:680/1645 train_time:62550ms step_avg:91.98ms +step:681/1645 train_time:62644ms step_avg:91.99ms +step:682/1645 train_time:62737ms step_avg:91.99ms +step:683/1645 train_time:62830ms step_avg:91.99ms +step:684/1645 train_time:62923ms step_avg:91.99ms +step:685/1645 train_time:63017ms step_avg:92.00ms +step:686/1645 train_time:63109ms step_avg:92.00ms +step:687/1645 train_time:63203ms step_avg:92.00ms +step:688/1645 train_time:63295ms step_avg:92.00ms +step:689/1645 train_time:63388ms step_avg:92.00ms +step:690/1645 train_time:63480ms step_avg:92.00ms +step:691/1645 train_time:63573ms step_avg:92.00ms +step:692/1645 train_time:63667ms step_avg:92.00ms +step:693/1645 train_time:63761ms step_avg:92.01ms +step:694/1645 train_time:63854ms step_avg:92.01ms +step:695/1645 train_time:63947ms step_avg:92.01ms +step:696/1645 train_time:64041ms step_avg:92.01ms +step:697/1645 train_time:64134ms step_avg:92.01ms +step:698/1645 train_time:64227ms step_avg:92.02ms +step:699/1645 train_time:64320ms step_avg:92.02ms +step:700/1645 train_time:64414ms step_avg:92.02ms +step:701/1645 train_time:64506ms step_avg:92.02ms +step:702/1645 train_time:64601ms step_avg:92.02ms +step:703/1645 train_time:64694ms step_avg:92.03ms +step:704/1645 train_time:64786ms step_avg:92.03ms +step:705/1645 train_time:64878ms step_avg:92.03ms +step:706/1645 train_time:64971ms step_avg:92.03ms +step:707/1645 train_time:65064ms step_avg:92.03ms +step:708/1645 train_time:65157ms step_avg:92.03ms +step:709/1645 train_time:65249ms step_avg:92.03ms +step:710/1645 train_time:65343ms step_avg:92.03ms +step:711/1645 train_time:65436ms step_avg:92.03ms +step:712/1645 train_time:65529ms step_avg:92.03ms +step:713/1645 train_time:65623ms step_avg:92.04ms +step:714/1645 train_time:65716ms step_avg:92.04ms +step:715/1645 train_time:65808ms step_avg:92.04ms +step:716/1645 
train_time:65901ms step_avg:92.04ms +step:717/1645 train_time:65994ms step_avg:92.04ms +step:718/1645 train_time:66088ms step_avg:92.04ms +step:719/1645 train_time:66180ms step_avg:92.05ms +step:720/1645 train_time:66274ms step_avg:92.05ms +step:721/1645 train_time:66367ms step_avg:92.05ms +step:722/1645 train_time:66461ms step_avg:92.05ms +step:723/1645 train_time:66554ms step_avg:92.05ms +step:724/1645 train_time:66646ms step_avg:92.05ms +step:725/1645 train_time:66739ms step_avg:92.05ms +step:726/1645 train_time:66832ms step_avg:92.05ms +step:727/1645 train_time:66925ms step_avg:92.06ms +step:728/1645 train_time:67019ms step_avg:92.06ms +step:729/1645 train_time:67111ms step_avg:92.06ms +step:730/1645 train_time:67205ms step_avg:92.06ms +step:731/1645 train_time:67297ms step_avg:92.06ms +step:732/1645 train_time:67391ms step_avg:92.06ms +step:733/1645 train_time:67484ms step_avg:92.07ms +step:734/1645 train_time:67579ms step_avg:92.07ms +step:735/1645 train_time:67670ms step_avg:92.07ms +step:736/1645 train_time:67763ms step_avg:92.07ms +step:737/1645 train_time:67856ms step_avg:92.07ms +step:738/1645 train_time:67949ms step_avg:92.07ms +step:739/1645 train_time:68042ms step_avg:92.07ms +step:740/1645 train_time:68136ms step_avg:92.08ms +step:741/1645 train_time:68229ms step_avg:92.08ms +step:742/1645 train_time:68322ms step_avg:92.08ms +step:743/1645 train_time:68415ms step_avg:92.08ms +step:744/1645 train_time:68508ms step_avg:92.08ms +step:745/1645 train_time:68602ms step_avg:92.08ms +step:746/1645 train_time:68695ms step_avg:92.08ms +step:747/1645 train_time:68787ms step_avg:92.08ms +step:748/1645 train_time:68881ms step_avg:92.09ms +step:749/1645 train_time:68974ms step_avg:92.09ms +step:750/1645 train_time:69067ms step_avg:92.09ms +step:750/1645 val_loss:3.5597 train_time:69161ms step_avg:92.21ms +step:751/1645 train_time:69177ms step_avg:92.11ms +step:752/1645 train_time:69258ms step_avg:92.10ms +step:753/1645 train_time:69352ms step_avg:92.10ms +step:754/1645 train_time:69446ms step_avg:92.10ms +step:755/1645 train_time:69538ms step_avg:92.10ms +step:756/1645 train_time:69631ms step_avg:92.10ms +step:757/1645 train_time:69722ms step_avg:92.10ms +step:758/1645 train_time:69815ms step_avg:92.10ms +step:759/1645 train_time:69907ms step_avg:92.10ms +step:760/1645 train_time:70000ms step_avg:92.10ms +step:761/1645 train_time:70093ms step_avg:92.11ms +step:762/1645 train_time:70188ms step_avg:92.11ms +step:763/1645 train_time:70283ms step_avg:92.11ms +step:764/1645 train_time:70377ms step_avg:92.12ms +step:765/1645 train_time:70470ms step_avg:92.12ms +step:766/1645 train_time:70563ms step_avg:92.12ms +step:767/1645 train_time:70655ms step_avg:92.12ms +step:768/1645 train_time:70748ms step_avg:92.12ms +step:769/1645 train_time:70841ms step_avg:92.12ms +step:770/1645 train_time:70933ms step_avg:92.12ms +step:771/1645 train_time:71026ms step_avg:92.12ms +step:772/1645 train_time:71119ms step_avg:92.12ms +step:773/1645 train_time:71213ms step_avg:92.13ms +step:774/1645 train_time:71307ms step_avg:92.13ms +step:775/1645 train_time:71401ms step_avg:92.13ms +step:776/1645 train_time:71494ms step_avg:92.13ms +step:777/1645 train_time:71587ms step_avg:92.13ms +step:778/1645 train_time:71680ms step_avg:92.13ms +step:779/1645 train_time:71773ms step_avg:92.14ms +step:780/1645 train_time:71866ms step_avg:92.14ms +step:781/1645 train_time:71958ms step_avg:92.14ms +step:782/1645 train_time:72051ms step_avg:92.14ms +step:783/1645 train_time:72145ms step_avg:92.14ms +step:784/1645 train_time:72238ms 
step_avg:92.14ms +step:785/1645 train_time:72332ms step_avg:92.14ms +step:786/1645 train_time:72425ms step_avg:92.14ms +step:787/1645 train_time:72518ms step_avg:92.14ms +step:788/1645 train_time:72611ms step_avg:92.15ms +step:789/1645 train_time:72704ms step_avg:92.15ms +step:790/1645 train_time:72796ms step_avg:92.15ms +step:791/1645 train_time:72889ms step_avg:92.15ms +step:792/1645 train_time:72981ms step_avg:92.15ms +step:793/1645 train_time:73075ms step_avg:92.15ms +step:794/1645 train_time:73169ms step_avg:92.15ms +step:795/1645 train_time:73261ms step_avg:92.15ms +step:796/1645 train_time:73355ms step_avg:92.15ms +step:797/1645 train_time:73448ms step_avg:92.16ms +step:798/1645 train_time:73541ms step_avg:92.16ms +step:799/1645 train_time:73634ms step_avg:92.16ms +step:800/1645 train_time:73727ms step_avg:92.16ms +step:801/1645 train_time:73820ms step_avg:92.16ms +step:802/1645 train_time:73912ms step_avg:92.16ms +step:803/1645 train_time:74005ms step_avg:92.16ms +step:804/1645 train_time:74099ms step_avg:92.16ms +step:805/1645 train_time:74192ms step_avg:92.16ms +step:806/1645 train_time:74286ms step_avg:92.17ms +step:807/1645 train_time:74380ms step_avg:92.17ms +step:808/1645 train_time:74473ms step_avg:92.17ms +step:809/1645 train_time:74567ms step_avg:92.17ms +step:810/1645 train_time:74661ms step_avg:92.17ms +step:811/1645 train_time:74753ms step_avg:92.17ms +step:812/1645 train_time:74846ms step_avg:92.17ms +step:813/1645 train_time:74938ms step_avg:92.17ms +step:814/1645 train_time:75032ms step_avg:92.18ms +step:815/1645 train_time:75124ms step_avg:92.18ms +step:816/1645 train_time:75217ms step_avg:92.18ms +step:817/1645 train_time:75309ms step_avg:92.18ms +step:818/1645 train_time:75402ms step_avg:92.18ms +step:819/1645 train_time:75495ms step_avg:92.18ms +step:820/1645 train_time:75588ms step_avg:92.18ms +step:821/1645 train_time:75682ms step_avg:92.18ms +step:822/1645 train_time:75774ms step_avg:92.18ms +step:823/1645 train_time:75868ms step_avg:92.18ms +step:824/1645 train_time:75962ms step_avg:92.19ms +step:825/1645 train_time:76054ms step_avg:92.19ms +step:826/1645 train_time:76148ms step_avg:92.19ms +step:827/1645 train_time:76241ms step_avg:92.19ms +step:828/1645 train_time:76334ms step_avg:92.19ms +step:829/1645 train_time:76426ms step_avg:92.19ms +step:830/1645 train_time:76519ms step_avg:92.19ms +step:831/1645 train_time:76612ms step_avg:92.19ms +step:832/1645 train_time:76705ms step_avg:92.19ms +step:833/1645 train_time:76798ms step_avg:92.19ms +step:834/1645 train_time:76891ms step_avg:92.19ms +step:835/1645 train_time:76984ms step_avg:92.20ms +step:836/1645 train_time:77077ms step_avg:92.20ms +step:837/1645 train_time:77170ms step_avg:92.20ms +step:838/1645 train_time:77263ms step_avg:92.20ms +step:839/1645 train_time:77355ms step_avg:92.20ms +step:840/1645 train_time:77449ms step_avg:92.20ms +step:841/1645 train_time:77542ms step_avg:92.20ms +step:842/1645 train_time:77635ms step_avg:92.20ms +step:843/1645 train_time:77728ms step_avg:92.20ms +step:844/1645 train_time:77821ms step_avg:92.20ms +step:845/1645 train_time:77913ms step_avg:92.20ms +step:846/1645 train_time:78007ms step_avg:92.21ms +step:847/1645 train_time:78101ms step_avg:92.21ms +step:848/1645 train_time:78193ms step_avg:92.21ms +step:849/1645 train_time:78286ms step_avg:92.21ms +step:850/1645 train_time:78380ms step_avg:92.21ms +step:851/1645 train_time:78474ms step_avg:92.21ms +step:852/1645 train_time:78566ms step_avg:92.21ms +step:853/1645 train_time:78659ms step_avg:92.21ms +step:854/1645 
train_time:78752ms step_avg:92.22ms +step:855/1645 train_time:78844ms step_avg:92.22ms +step:856/1645 train_time:78937ms step_avg:92.22ms +step:857/1645 train_time:79030ms step_avg:92.22ms +step:858/1645 train_time:79123ms step_avg:92.22ms +step:859/1645 train_time:79215ms step_avg:92.22ms +step:860/1645 train_time:79308ms step_avg:92.22ms +step:861/1645 train_time:79401ms step_avg:92.22ms +step:862/1645 train_time:79494ms step_avg:92.22ms +step:863/1645 train_time:79588ms step_avg:92.22ms +step:864/1645 train_time:79680ms step_avg:92.22ms +step:865/1645 train_time:79773ms step_avg:92.22ms +step:866/1645 train_time:79866ms step_avg:92.22ms +step:867/1645 train_time:79959ms step_avg:92.23ms +step:868/1645 train_time:80052ms step_avg:92.23ms +step:869/1645 train_time:80145ms step_avg:92.23ms +step:870/1645 train_time:80238ms step_avg:92.23ms +step:871/1645 train_time:80331ms step_avg:92.23ms +step:872/1645 train_time:80424ms step_avg:92.23ms +step:873/1645 train_time:80517ms step_avg:92.23ms +step:874/1645 train_time:80610ms step_avg:92.23ms +step:875/1645 train_time:80704ms step_avg:92.23ms +step:875/1645 val_loss:3.5121 train_time:80797ms step_avg:92.34ms +step:876/1645 train_time:80818ms step_avg:92.26ms +step:877/1645 train_time:80895ms step_avg:92.24ms +step:878/1645 train_time:80988ms step_avg:92.24ms +step:879/1645 train_time:81081ms step_avg:92.24ms +step:880/1645 train_time:81173ms step_avg:92.24ms +step:881/1645 train_time:81265ms step_avg:92.24ms +step:882/1645 train_time:81357ms step_avg:92.24ms +step:883/1645 train_time:81449ms step_avg:92.24ms +step:884/1645 train_time:81542ms step_avg:92.24ms +step:885/1645 train_time:81634ms step_avg:92.24ms +step:886/1645 train_time:81730ms step_avg:92.25ms +step:887/1645 train_time:81824ms step_avg:92.25ms +step:888/1645 train_time:81918ms step_avg:92.25ms +step:889/1645 train_time:82011ms step_avg:92.25ms +step:890/1645 train_time:82104ms step_avg:92.25ms +step:891/1645 train_time:82196ms step_avg:92.25ms +step:892/1645 train_time:82289ms step_avg:92.25ms +step:893/1645 train_time:82381ms step_avg:92.25ms +step:894/1645 train_time:82474ms step_avg:92.25ms +step:895/1645 train_time:82566ms step_avg:92.25ms +step:896/1645 train_time:82659ms step_avg:92.25ms +step:897/1645 train_time:82753ms step_avg:92.26ms +step:898/1645 train_time:82847ms step_avg:92.26ms +step:899/1645 train_time:82940ms step_avg:92.26ms +step:900/1645 train_time:83034ms step_avg:92.26ms +step:901/1645 train_time:83127ms step_avg:92.26ms +step:902/1645 train_time:83220ms step_avg:92.26ms +step:903/1645 train_time:83312ms step_avg:92.26ms +step:904/1645 train_time:83405ms step_avg:92.26ms +step:905/1645 train_time:83497ms step_avg:92.26ms +step:906/1645 train_time:83590ms step_avg:92.26ms +step:907/1645 train_time:83684ms step_avg:92.26ms +step:908/1645 train_time:83778ms step_avg:92.27ms +step:909/1645 train_time:83870ms step_avg:92.27ms +step:910/1645 train_time:83964ms step_avg:92.27ms +step:911/1645 train_time:84056ms step_avg:92.27ms +step:912/1645 train_time:84149ms step_avg:92.27ms +step:913/1645 train_time:84242ms step_avg:92.27ms +step:914/1645 train_time:84334ms step_avg:92.27ms +step:915/1645 train_time:84427ms step_avg:92.27ms +step:916/1645 train_time:84520ms step_avg:92.27ms +step:917/1645 train_time:84613ms step_avg:92.27ms +step:918/1645 train_time:84706ms step_avg:92.27ms +step:919/1645 train_time:84799ms step_avg:92.27ms +step:920/1645 train_time:84892ms step_avg:92.27ms +step:921/1645 train_time:84985ms step_avg:92.27ms +step:922/1645 train_time:85078ms 
step_avg:92.28ms +step:923/1645 train_time:85171ms step_avg:92.28ms +step:924/1645 train_time:85264ms step_avg:92.28ms +step:925/1645 train_time:85356ms step_avg:92.28ms +step:926/1645 train_time:85449ms step_avg:92.28ms +step:927/1645 train_time:85542ms step_avg:92.28ms +step:928/1645 train_time:85635ms step_avg:92.28ms +step:929/1645 train_time:85728ms step_avg:92.28ms +step:930/1645 train_time:85822ms step_avg:92.28ms +step:931/1645 train_time:85915ms step_avg:92.28ms +step:932/1645 train_time:86008ms step_avg:92.28ms +step:933/1645 train_time:86101ms step_avg:92.28ms +step:934/1645 train_time:86194ms step_avg:92.29ms +step:935/1645 train_time:86287ms step_avg:92.29ms +step:936/1645 train_time:86380ms step_avg:92.29ms +step:937/1645 train_time:86472ms step_avg:92.29ms +step:938/1645 train_time:86565ms step_avg:92.29ms +step:939/1645 train_time:86658ms step_avg:92.29ms +step:940/1645 train_time:86752ms step_avg:92.29ms +step:941/1645 train_time:86844ms step_avg:92.29ms +step:942/1645 train_time:86938ms step_avg:92.29ms +step:943/1645 train_time:87031ms step_avg:92.29ms +step:944/1645 train_time:87124ms step_avg:92.29ms +step:945/1645 train_time:87216ms step_avg:92.29ms +step:946/1645 train_time:87310ms step_avg:92.29ms +step:947/1645 train_time:87403ms step_avg:92.30ms +step:948/1645 train_time:87496ms step_avg:92.29ms +step:949/1645 train_time:87589ms step_avg:92.30ms +step:950/1645 train_time:87682ms step_avg:92.30ms +step:951/1645 train_time:87775ms step_avg:92.30ms +step:952/1645 train_time:87867ms step_avg:92.30ms +step:953/1645 train_time:87960ms step_avg:92.30ms +step:954/1645 train_time:88054ms step_avg:92.30ms +step:955/1645 train_time:88146ms step_avg:92.30ms +step:956/1645 train_time:88240ms step_avg:92.30ms +step:957/1645 train_time:88333ms step_avg:92.30ms +step:958/1645 train_time:88427ms step_avg:92.30ms +step:959/1645 train_time:88520ms step_avg:92.30ms +step:960/1645 train_time:88612ms step_avg:92.30ms +step:961/1645 train_time:88705ms step_avg:92.31ms +step:962/1645 train_time:88799ms step_avg:92.31ms +step:963/1645 train_time:88891ms step_avg:92.31ms +step:964/1645 train_time:88984ms step_avg:92.31ms +step:965/1645 train_time:89077ms step_avg:92.31ms +step:966/1645 train_time:89169ms step_avg:92.31ms +step:967/1645 train_time:89264ms step_avg:92.31ms +step:968/1645 train_time:89356ms step_avg:92.31ms +step:969/1645 train_time:89449ms step_avg:92.31ms +step:970/1645 train_time:89542ms step_avg:92.31ms +step:971/1645 train_time:89635ms step_avg:92.31ms +step:972/1645 train_time:89728ms step_avg:92.31ms +step:973/1645 train_time:89821ms step_avg:92.31ms +step:974/1645 train_time:89913ms step_avg:92.31ms +step:975/1645 train_time:90006ms step_avg:92.31ms +step:976/1645 train_time:90099ms step_avg:92.31ms +step:977/1645 train_time:90191ms step_avg:92.31ms +step:978/1645 train_time:90284ms step_avg:92.32ms +step:979/1645 train_time:90378ms step_avg:92.32ms +step:980/1645 train_time:90471ms step_avg:92.32ms +step:981/1645 train_time:90564ms step_avg:92.32ms +step:982/1645 train_time:90656ms step_avg:92.32ms +step:983/1645 train_time:90750ms step_avg:92.32ms +step:984/1645 train_time:90843ms step_avg:92.32ms +step:985/1645 train_time:90935ms step_avg:92.32ms +step:986/1645 train_time:91029ms step_avg:92.32ms +step:987/1645 train_time:91122ms step_avg:92.32ms +step:988/1645 train_time:91215ms step_avg:92.32ms +step:989/1645 train_time:91308ms step_avg:92.32ms +step:990/1645 train_time:91402ms step_avg:92.33ms +step:991/1645 train_time:91495ms step_avg:92.33ms +step:992/1645 
train_time:91587ms step_avg:92.33ms +step:993/1645 train_time:91681ms step_avg:92.33ms +step:994/1645 train_time:91773ms step_avg:92.33ms +step:995/1645 train_time:91865ms step_avg:92.33ms +step:996/1645 train_time:91958ms step_avg:92.33ms +step:997/1645 train_time:92051ms step_avg:92.33ms +step:998/1645 train_time:92144ms step_avg:92.33ms +step:999/1645 train_time:92236ms step_avg:92.33ms +step:1000/1645 train_time:92331ms step_avg:92.33ms +step:1000/1645 val_loss:3.4634 train_time:92424ms step_avg:92.42ms +step:1001/1645 train_time:92446ms step_avg:92.35ms +step:1002/1645 train_time:92522ms step_avg:92.34ms +step:1003/1645 train_time:92619ms step_avg:92.34ms +step:1004/1645 train_time:92712ms step_avg:92.34ms +step:1005/1645 train_time:92805ms step_avg:92.34ms +step:1006/1645 train_time:92897ms step_avg:92.34ms +step:1007/1645 train_time:92988ms step_avg:92.34ms +step:1008/1645 train_time:93079ms step_avg:92.34ms +step:1009/1645 train_time:93171ms step_avg:92.34ms +step:1010/1645 train_time:93265ms step_avg:92.34ms +step:1011/1645 train_time:93358ms step_avg:92.34ms +step:1012/1645 train_time:93453ms step_avg:92.34ms +step:1013/1645 train_time:93549ms step_avg:92.35ms +step:1014/1645 train_time:93643ms step_avg:92.35ms +step:1015/1645 train_time:93736ms step_avg:92.35ms +step:1016/1645 train_time:93830ms step_avg:92.35ms +step:1017/1645 train_time:93921ms step_avg:92.35ms +step:1018/1645 train_time:94013ms step_avg:92.35ms +step:1019/1645 train_time:94105ms step_avg:92.35ms +step:1020/1645 train_time:94197ms step_avg:92.35ms +step:1021/1645 train_time:94289ms step_avg:92.35ms +step:1022/1645 train_time:94382ms step_avg:92.35ms +step:1023/1645 train_time:94476ms step_avg:92.35ms +step:1024/1645 train_time:94571ms step_avg:92.35ms +step:1025/1645 train_time:94666ms step_avg:92.36ms +step:1026/1645 train_time:94759ms step_avg:92.36ms +step:1027/1645 train_time:94852ms step_avg:92.36ms +step:1028/1645 train_time:94945ms step_avg:92.36ms +step:1029/1645 train_time:95037ms step_avg:92.36ms +step:1030/1645 train_time:95129ms step_avg:92.36ms +step:1031/1645 train_time:95221ms step_avg:92.36ms +step:1032/1645 train_time:95313ms step_avg:92.36ms +step:1033/1645 train_time:95407ms step_avg:92.36ms +step:1034/1645 train_time:95500ms step_avg:92.36ms +step:1035/1645 train_time:95595ms step_avg:92.36ms +step:1036/1645 train_time:95689ms step_avg:92.36ms +step:1037/1645 train_time:95781ms step_avg:92.36ms +step:1038/1645 train_time:95875ms step_avg:92.36ms +step:1039/1645 train_time:95968ms step_avg:92.37ms +step:1040/1645 train_time:96060ms step_avg:92.37ms +step:1041/1645 train_time:96153ms step_avg:92.37ms +step:1042/1645 train_time:96245ms step_avg:92.37ms +step:1043/1645 train_time:96338ms step_avg:92.37ms +step:1044/1645 train_time:96431ms step_avg:92.37ms +step:1045/1645 train_time:96524ms step_avg:92.37ms +step:1046/1645 train_time:96618ms step_avg:92.37ms +step:1047/1645 train_time:96711ms step_avg:92.37ms +step:1048/1645 train_time:96804ms step_avg:92.37ms +step:1049/1645 train_time:96897ms step_avg:92.37ms +step:1050/1645 train_time:96991ms step_avg:92.37ms +step:1051/1645 train_time:97082ms step_avg:92.37ms +step:1052/1645 train_time:97175ms step_avg:92.37ms +step:1053/1645 train_time:97268ms step_avg:92.37ms +step:1054/1645 train_time:97360ms step_avg:92.37ms +step:1055/1645 train_time:97454ms step_avg:92.37ms +step:1056/1645 train_time:97547ms step_avg:92.37ms +step:1057/1645 train_time:97640ms step_avg:92.37ms +step:1058/1645 train_time:97735ms step_avg:92.38ms +step:1059/1645 
train_time:97828ms step_avg:92.38ms +step:1060/1645 train_time:97921ms step_avg:92.38ms +step:1061/1645 train_time:98013ms step_avg:92.38ms +step:1062/1645 train_time:98106ms step_avg:92.38ms +step:1063/1645 train_time:98199ms step_avg:92.38ms +step:1064/1645 train_time:98292ms step_avg:92.38ms +step:1065/1645 train_time:98384ms step_avg:92.38ms +step:1066/1645 train_time:98477ms step_avg:92.38ms +step:1067/1645 train_time:98571ms step_avg:92.38ms +step:1068/1645 train_time:98665ms step_avg:92.38ms +step:1069/1645 train_time:98758ms step_avg:92.38ms +step:1070/1645 train_time:98852ms step_avg:92.38ms +step:1071/1645 train_time:98945ms step_avg:92.39ms +step:1072/1645 train_time:99038ms step_avg:92.39ms +step:1073/1645 train_time:99132ms step_avg:92.39ms +step:1074/1645 train_time:99224ms step_avg:92.39ms +step:1075/1645 train_time:99316ms step_avg:92.39ms +step:1076/1645 train_time:99409ms step_avg:92.39ms +step:1077/1645 train_time:99501ms step_avg:92.39ms +step:1078/1645 train_time:99594ms step_avg:92.39ms +step:1079/1645 train_time:99687ms step_avg:92.39ms +step:1080/1645 train_time:99780ms step_avg:92.39ms +step:1081/1645 train_time:99874ms step_avg:92.39ms +step:1082/1645 train_time:99967ms step_avg:92.39ms +step:1083/1645 train_time:100060ms step_avg:92.39ms +step:1084/1645 train_time:100152ms step_avg:92.39ms +step:1085/1645 train_time:100245ms step_avg:92.39ms +step:1086/1645 train_time:100337ms step_avg:92.39ms +step:1087/1645 train_time:100430ms step_avg:92.39ms +step:1088/1645 train_time:100523ms step_avg:92.39ms +step:1089/1645 train_time:100616ms step_avg:92.39ms +step:1090/1645 train_time:100709ms step_avg:92.39ms +step:1091/1645 train_time:100802ms step_avg:92.39ms +step:1092/1645 train_time:100896ms step_avg:92.40ms +step:1093/1645 train_time:100990ms step_avg:92.40ms +step:1094/1645 train_time:101081ms step_avg:92.40ms +step:1095/1645 train_time:101175ms step_avg:92.40ms +step:1096/1645 train_time:101268ms step_avg:92.40ms +step:1097/1645 train_time:101362ms step_avg:92.40ms +step:1098/1645 train_time:101455ms step_avg:92.40ms +step:1099/1645 train_time:101548ms step_avg:92.40ms +step:1100/1645 train_time:101641ms step_avg:92.40ms +step:1101/1645 train_time:101735ms step_avg:92.40ms +step:1102/1645 train_time:101829ms step_avg:92.40ms +step:1103/1645 train_time:101922ms step_avg:92.40ms +step:1104/1645 train_time:102016ms step_avg:92.41ms +step:1105/1645 train_time:102109ms step_avg:92.41ms +step:1106/1645 train_time:102203ms step_avg:92.41ms +step:1107/1645 train_time:102297ms step_avg:92.41ms +step:1108/1645 train_time:102390ms step_avg:92.41ms +step:1109/1645 train_time:102484ms step_avg:92.41ms +step:1110/1645 train_time:102577ms step_avg:92.41ms +step:1111/1645 train_time:102672ms step_avg:92.41ms +step:1112/1645 train_time:102765ms step_avg:92.41ms +step:1113/1645 train_time:102859ms step_avg:92.42ms +step:1114/1645 train_time:102953ms step_avg:92.42ms +step:1115/1645 train_time:103046ms step_avg:92.42ms +step:1116/1645 train_time:103139ms step_avg:92.42ms +step:1117/1645 train_time:103232ms step_avg:92.42ms +step:1118/1645 train_time:103327ms step_avg:92.42ms +step:1119/1645 train_time:103421ms step_avg:92.42ms +step:1120/1645 train_time:103514ms step_avg:92.42ms +step:1121/1645 train_time:103609ms step_avg:92.43ms +step:1122/1645 train_time:103702ms step_avg:92.43ms +step:1123/1645 train_time:103795ms step_avg:92.43ms +step:1124/1645 train_time:103889ms step_avg:92.43ms +step:1125/1645 train_time:103982ms step_avg:92.43ms +step:1125/1645 val_loss:3.4102 
train_time:104076ms step_avg:92.51ms +step:1126/1645 train_time:104092ms step_avg:92.44ms +step:1127/1645 train_time:104175ms step_avg:92.44ms +step:1128/1645 train_time:104276ms step_avg:92.44ms +step:1129/1645 train_time:104369ms step_avg:92.44ms +step:1130/1645 train_time:104462ms step_avg:92.44ms +step:1131/1645 train_time:104555ms step_avg:92.44ms +step:1132/1645 train_time:104647ms step_avg:92.44ms +step:1133/1645 train_time:104739ms step_avg:92.44ms +step:1134/1645 train_time:104832ms step_avg:92.44ms +step:1135/1645 train_time:104924ms step_avg:92.44ms +step:1136/1645 train_time:105017ms step_avg:92.44ms +step:1137/1645 train_time:105112ms step_avg:92.45ms +step:1138/1645 train_time:105213ms step_avg:92.45ms +step:1139/1645 train_time:105308ms step_avg:92.46ms +step:1140/1645 train_time:105402ms step_avg:92.46ms +step:1141/1645 train_time:105495ms step_avg:92.46ms +step:1142/1645 train_time:105588ms step_avg:92.46ms +step:1143/1645 train_time:105680ms step_avg:92.46ms +step:1144/1645 train_time:105772ms step_avg:92.46ms +step:1145/1645 train_time:105865ms step_avg:92.46ms +step:1146/1645 train_time:105957ms step_avg:92.46ms +step:1147/1645 train_time:106051ms step_avg:92.46ms +step:1148/1645 train_time:106145ms step_avg:92.46ms +step:1149/1645 train_time:106240ms step_avg:92.46ms +step:1150/1645 train_time:106337ms step_avg:92.47ms +step:1151/1645 train_time:106430ms step_avg:92.47ms +step:1152/1645 train_time:106523ms step_avg:92.47ms +step:1153/1645 train_time:106617ms step_avg:92.47ms +step:1154/1645 train_time:106709ms step_avg:92.47ms +step:1155/1645 train_time:106802ms step_avg:92.47ms +step:1156/1645 train_time:106895ms step_avg:92.47ms +step:1157/1645 train_time:106988ms step_avg:92.47ms +step:1158/1645 train_time:107082ms step_avg:92.47ms +step:1159/1645 train_time:107176ms step_avg:92.47ms +step:1160/1645 train_time:107271ms step_avg:92.47ms +step:1161/1645 train_time:107366ms step_avg:92.48ms +step:1162/1645 train_time:107459ms step_avg:92.48ms +step:1163/1645 train_time:107553ms step_avg:92.48ms +step:1164/1645 train_time:107647ms step_avg:92.48ms +step:1165/1645 train_time:107739ms step_avg:92.48ms +step:1166/1645 train_time:107833ms step_avg:92.48ms +step:1167/1645 train_time:107926ms step_avg:92.48ms +step:1168/1645 train_time:108019ms step_avg:92.48ms +step:1169/1645 train_time:108113ms step_avg:92.48ms +step:1170/1645 train_time:108207ms step_avg:92.48ms +step:1171/1645 train_time:108302ms step_avg:92.49ms +step:1172/1645 train_time:108396ms step_avg:92.49ms +step:1173/1645 train_time:108491ms step_avg:92.49ms +step:1174/1645 train_time:108584ms step_avg:92.49ms +step:1175/1645 train_time:108677ms step_avg:92.49ms +step:1176/1645 train_time:108770ms step_avg:92.49ms +step:1177/1645 train_time:108863ms step_avg:92.49ms +step:1178/1645 train_time:108958ms step_avg:92.49ms +step:1179/1645 train_time:109051ms step_avg:92.49ms +step:1180/1645 train_time:109144ms step_avg:92.50ms +step:1181/1645 train_time:109237ms step_avg:92.50ms +step:1182/1645 train_time:109332ms step_avg:92.50ms +step:1183/1645 train_time:109427ms step_avg:92.50ms +step:1184/1645 train_time:109520ms step_avg:92.50ms +step:1185/1645 train_time:109614ms step_avg:92.50ms +step:1186/1645 train_time:109707ms step_avg:92.50ms +step:1187/1645 train_time:109800ms step_avg:92.50ms +step:1188/1645 train_time:109893ms step_avg:92.50ms +step:1189/1645 train_time:109988ms step_avg:92.50ms +step:1190/1645 train_time:110080ms step_avg:92.50ms +step:1191/1645 train_time:110173ms step_avg:92.51ms +step:1192/1645 
train_time:110269ms step_avg:92.51ms +step:1193/1645 train_time:110362ms step_avg:92.51ms +step:1194/1645 train_time:110456ms step_avg:92.51ms +step:1195/1645 train_time:110551ms step_avg:92.51ms +step:1196/1645 train_time:110644ms step_avg:92.51ms +step:1197/1645 train_time:110737ms step_avg:92.51ms +step:1198/1645 train_time:110831ms step_avg:92.51ms +step:1199/1645 train_time:110924ms step_avg:92.51ms +step:1200/1645 train_time:111017ms step_avg:92.51ms +step:1201/1645 train_time:111110ms step_avg:92.51ms +step:1202/1645 train_time:111204ms step_avg:92.52ms +step:1203/1645 train_time:111297ms step_avg:92.52ms +step:1204/1645 train_time:111391ms step_avg:92.52ms +step:1205/1645 train_time:111485ms step_avg:92.52ms +step:1206/1645 train_time:111579ms step_avg:92.52ms +step:1207/1645 train_time:111672ms step_avg:92.52ms +step:1208/1645 train_time:111768ms step_avg:92.52ms +step:1209/1645 train_time:111860ms step_avg:92.52ms +step:1210/1645 train_time:111954ms step_avg:92.52ms +step:1211/1645 train_time:112047ms step_avg:92.52ms +step:1212/1645 train_time:112140ms step_avg:92.52ms +step:1213/1645 train_time:112234ms step_avg:92.53ms +step:1214/1645 train_time:112327ms step_avg:92.53ms +step:1215/1645 train_time:112420ms step_avg:92.53ms +step:1216/1645 train_time:112514ms step_avg:92.53ms +step:1217/1645 train_time:112608ms step_avg:92.53ms +step:1218/1645 train_time:112701ms step_avg:92.53ms +step:1219/1645 train_time:112795ms step_avg:92.53ms +step:1220/1645 train_time:112887ms step_avg:92.53ms +step:1221/1645 train_time:112981ms step_avg:92.53ms +step:1222/1645 train_time:113075ms step_avg:92.53ms +step:1223/1645 train_time:113167ms step_avg:92.53ms +step:1224/1645 train_time:113261ms step_avg:92.53ms +step:1225/1645 train_time:113355ms step_avg:92.53ms +step:1226/1645 train_time:113448ms step_avg:92.54ms +step:1227/1645 train_time:113541ms step_avg:92.54ms +step:1228/1645 train_time:113636ms step_avg:92.54ms +step:1229/1645 train_time:113729ms step_avg:92.54ms +step:1230/1645 train_time:113823ms step_avg:92.54ms +step:1231/1645 train_time:113916ms step_avg:92.54ms +step:1232/1645 train_time:114010ms step_avg:92.54ms +step:1233/1645 train_time:114104ms step_avg:92.54ms +step:1234/1645 train_time:114197ms step_avg:92.54ms +step:1235/1645 train_time:114292ms step_avg:92.54ms +step:1236/1645 train_time:114386ms step_avg:92.54ms +step:1237/1645 train_time:114480ms step_avg:92.55ms +step:1238/1645 train_time:114572ms step_avg:92.55ms +step:1239/1645 train_time:114665ms step_avg:92.55ms +step:1240/1645 train_time:114758ms step_avg:92.55ms +step:1241/1645 train_time:114853ms step_avg:92.55ms +step:1242/1645 train_time:114946ms step_avg:92.55ms +step:1243/1645 train_time:115039ms step_avg:92.55ms +step:1244/1645 train_time:115133ms step_avg:92.55ms +step:1245/1645 train_time:115228ms step_avg:92.55ms +step:1246/1645 train_time:115322ms step_avg:92.55ms +step:1247/1645 train_time:115415ms step_avg:92.55ms +step:1248/1645 train_time:115508ms step_avg:92.55ms +step:1249/1645 train_time:115602ms step_avg:92.56ms +step:1250/1645 train_time:115695ms step_avg:92.56ms +step:1250/1645 val_loss:3.3726 train_time:115789ms step_avg:92.63ms +step:1251/1645 train_time:115811ms step_avg:92.57ms +step:1252/1645 train_time:115887ms step_avg:92.56ms +step:1253/1645 train_time:115983ms step_avg:92.56ms +step:1254/1645 train_time:116076ms step_avg:92.56ms +step:1255/1645 train_time:116169ms step_avg:92.56ms +step:1256/1645 train_time:116261ms step_avg:92.56ms +step:1257/1645 train_time:116353ms step_avg:92.56ms 
+step:1258/1645 train_time:116446ms step_avg:92.56ms +step:1259/1645 train_time:116539ms step_avg:92.57ms +step:1260/1645 train_time:116632ms step_avg:92.57ms +step:1261/1645 train_time:116727ms step_avg:92.57ms +step:1262/1645 train_time:116821ms step_avg:92.57ms +step:1263/1645 train_time:116916ms step_avg:92.57ms +step:1264/1645 train_time:117011ms step_avg:92.57ms +step:1265/1645 train_time:117105ms step_avg:92.57ms +step:1266/1645 train_time:117198ms step_avg:92.57ms +step:1267/1645 train_time:117292ms step_avg:92.57ms +step:1268/1645 train_time:117385ms step_avg:92.57ms +step:1269/1645 train_time:117477ms step_avg:92.57ms +step:1270/1645 train_time:117570ms step_avg:92.57ms +step:1271/1645 train_time:117664ms step_avg:92.58ms +step:1272/1645 train_time:117757ms step_avg:92.58ms +step:1273/1645 train_time:117851ms step_avg:92.58ms +step:1274/1645 train_time:117945ms step_avg:92.58ms +step:1275/1645 train_time:118039ms step_avg:92.58ms +step:1276/1645 train_time:118133ms step_avg:92.58ms +step:1277/1645 train_time:118227ms step_avg:92.58ms +step:1278/1645 train_time:118320ms step_avg:92.58ms +step:1279/1645 train_time:118413ms step_avg:92.58ms +step:1280/1645 train_time:118506ms step_avg:92.58ms +step:1281/1645 train_time:118599ms step_avg:92.58ms +step:1282/1645 train_time:118694ms step_avg:92.58ms +step:1283/1645 train_time:118788ms step_avg:92.59ms +step:1284/1645 train_time:118882ms step_avg:92.59ms +step:1285/1645 train_time:118975ms step_avg:92.59ms +step:1286/1645 train_time:119070ms step_avg:92.59ms +step:1287/1645 train_time:119164ms step_avg:92.59ms +step:1288/1645 train_time:119258ms step_avg:92.59ms +step:1289/1645 train_time:119351ms step_avg:92.59ms +step:1290/1645 train_time:119444ms step_avg:92.59ms +step:1291/1645 train_time:119538ms step_avg:92.59ms +step:1292/1645 train_time:119631ms step_avg:92.59ms +step:1293/1645 train_time:119725ms step_avg:92.59ms +step:1294/1645 train_time:119819ms step_avg:92.60ms +step:1295/1645 train_time:119912ms step_avg:92.60ms +step:1296/1645 train_time:120007ms step_avg:92.60ms +step:1297/1645 train_time:120102ms step_avg:92.60ms +step:1298/1645 train_time:120196ms step_avg:92.60ms +step:1299/1645 train_time:120290ms step_avg:92.60ms +step:1300/1645 train_time:120382ms step_avg:92.60ms +step:1301/1645 train_time:120475ms step_avg:92.60ms +step:1302/1645 train_time:120570ms step_avg:92.60ms +step:1303/1645 train_time:120663ms step_avg:92.60ms +step:1304/1645 train_time:120755ms step_avg:92.60ms +step:1305/1645 train_time:120849ms step_avg:92.60ms +step:1306/1645 train_time:120945ms step_avg:92.61ms +step:1307/1645 train_time:121038ms step_avg:92.61ms +step:1308/1645 train_time:121132ms step_avg:92.61ms +step:1309/1645 train_time:121226ms step_avg:92.61ms +step:1310/1645 train_time:121320ms step_avg:92.61ms +step:1311/1645 train_time:121413ms step_avg:92.61ms +step:1312/1645 train_time:121507ms step_avg:92.61ms +step:1313/1645 train_time:121600ms step_avg:92.61ms +step:1314/1645 train_time:121693ms step_avg:92.61ms +step:1315/1645 train_time:121786ms step_avg:92.61ms +step:1316/1645 train_time:121879ms step_avg:92.61ms +step:1317/1645 train_time:121973ms step_avg:92.61ms +step:1318/1645 train_time:122069ms step_avg:92.62ms +step:1319/1645 train_time:122163ms step_avg:92.62ms +step:1320/1645 train_time:122257ms step_avg:92.62ms +step:1321/1645 train_time:122351ms step_avg:92.62ms +step:1322/1645 train_time:122445ms step_avg:92.62ms +step:1323/1645 train_time:122538ms step_avg:92.62ms +step:1324/1645 train_time:122632ms step_avg:92.62ms 
+step:1325/1645 train_time:122725ms step_avg:92.62ms +step:1326/1645 train_time:122818ms step_avg:92.62ms +step:1327/1645 train_time:122911ms step_avg:92.62ms +step:1328/1645 train_time:123005ms step_avg:92.62ms +step:1329/1645 train_time:123098ms step_avg:92.62ms +step:1330/1645 train_time:123193ms step_avg:92.63ms +step:1331/1645 train_time:123286ms step_avg:92.63ms +step:1332/1645 train_time:123380ms step_avg:92.63ms +step:1333/1645 train_time:123473ms step_avg:92.63ms +step:1334/1645 train_time:123568ms step_avg:92.63ms +step:1335/1645 train_time:123662ms step_avg:92.63ms +step:1336/1645 train_time:123756ms step_avg:92.63ms +step:1337/1645 train_time:123850ms step_avg:92.63ms +step:1338/1645 train_time:123943ms step_avg:92.63ms +step:1339/1645 train_time:124037ms step_avg:92.63ms +step:1340/1645 train_time:124130ms step_avg:92.63ms +step:1341/1645 train_time:124224ms step_avg:92.64ms +step:1342/1645 train_time:124317ms step_avg:92.64ms +step:1343/1645 train_time:124410ms step_avg:92.64ms +step:1344/1645 train_time:124503ms step_avg:92.64ms +step:1345/1645 train_time:124597ms step_avg:92.64ms +step:1346/1645 train_time:124691ms step_avg:92.64ms +step:1347/1645 train_time:124784ms step_avg:92.64ms +step:1348/1645 train_time:124877ms step_avg:92.64ms +step:1349/1645 train_time:124971ms step_avg:92.64ms +step:1350/1645 train_time:125067ms step_avg:92.64ms +step:1351/1645 train_time:125161ms step_avg:92.64ms +step:1352/1645 train_time:125254ms step_avg:92.64ms +step:1353/1645 train_time:125347ms step_avg:92.64ms +step:1354/1645 train_time:125441ms step_avg:92.64ms +step:1355/1645 train_time:125534ms step_avg:92.64ms +step:1356/1645 train_time:125629ms step_avg:92.65ms +step:1357/1645 train_time:125723ms step_avg:92.65ms +step:1358/1645 train_time:125815ms step_avg:92.65ms +step:1359/1645 train_time:125909ms step_avg:92.65ms +step:1360/1645 train_time:126002ms step_avg:92.65ms +step:1361/1645 train_time:126095ms step_avg:92.65ms +step:1362/1645 train_time:126189ms step_avg:92.65ms +step:1363/1645 train_time:126283ms step_avg:92.65ms +step:1364/1645 train_time:126376ms step_avg:92.65ms +step:1365/1645 train_time:126472ms step_avg:92.65ms +step:1366/1645 train_time:126564ms step_avg:92.65ms +step:1367/1645 train_time:126659ms step_avg:92.65ms +step:1368/1645 train_time:126751ms step_avg:92.65ms +step:1369/1645 train_time:126845ms step_avg:92.65ms +step:1370/1645 train_time:126938ms step_avg:92.66ms +step:1371/1645 train_time:127032ms step_avg:92.66ms +step:1372/1645 train_time:127125ms step_avg:92.66ms +step:1373/1645 train_time:127218ms step_avg:92.66ms +step:1374/1645 train_time:127311ms step_avg:92.66ms +step:1375/1645 train_time:127405ms step_avg:92.66ms +step:1375/1645 val_loss:3.3378 train_time:127499ms step_avg:92.73ms +step:1376/1645 train_time:127521ms step_avg:92.67ms +step:1377/1645 train_time:127597ms step_avg:92.66ms +step:1378/1645 train_time:127691ms step_avg:92.66ms +step:1379/1645 train_time:127785ms step_avg:92.66ms +step:1380/1645 train_time:127878ms step_avg:92.66ms +step:1381/1645 train_time:127971ms step_avg:92.67ms +step:1382/1645 train_time:128064ms step_avg:92.67ms +step:1383/1645 train_time:128157ms step_avg:92.67ms +step:1384/1645 train_time:128250ms step_avg:92.67ms +step:1385/1645 train_time:128343ms step_avg:92.67ms +step:1386/1645 train_time:128437ms step_avg:92.67ms +step:1387/1645 train_time:128532ms step_avg:92.67ms +step:1388/1645 train_time:128627ms step_avg:92.67ms +step:1389/1645 train_time:128721ms step_avg:92.67ms +step:1390/1645 train_time:128816ms 
step_avg:92.67ms +step:1391/1645 train_time:128908ms step_avg:92.67ms +step:1392/1645 train_time:129001ms step_avg:92.67ms +step:1393/1645 train_time:129094ms step_avg:92.67ms +step:1394/1645 train_time:129188ms step_avg:92.67ms +step:1395/1645 train_time:129281ms step_avg:92.67ms +step:1396/1645 train_time:129375ms step_avg:92.68ms +step:1397/1645 train_time:129468ms step_avg:92.68ms +step:1398/1645 train_time:129563ms step_avg:92.68ms +step:1399/1645 train_time:129657ms step_avg:92.68ms +step:1400/1645 train_time:129751ms step_avg:92.68ms +step:1401/1645 train_time:129844ms step_avg:92.68ms +step:1402/1645 train_time:129937ms step_avg:92.68ms +step:1403/1645 train_time:130030ms step_avg:92.68ms +step:1404/1645 train_time:130123ms step_avg:92.68ms +step:1405/1645 train_time:130217ms step_avg:92.68ms +step:1406/1645 train_time:130310ms step_avg:92.68ms +step:1407/1645 train_time:130404ms step_avg:92.68ms +step:1408/1645 train_time:130498ms step_avg:92.68ms +step:1409/1645 train_time:130593ms step_avg:92.68ms +step:1410/1645 train_time:130687ms step_avg:92.69ms +step:1411/1645 train_time:130781ms step_avg:92.69ms +step:1412/1645 train_time:130874ms step_avg:92.69ms +step:1413/1645 train_time:130968ms step_avg:92.69ms +step:1414/1645 train_time:131062ms step_avg:92.69ms +step:1415/1645 train_time:131154ms step_avg:92.69ms +step:1416/1645 train_time:131247ms step_avg:92.69ms +step:1417/1645 train_time:131340ms step_avg:92.69ms +step:1418/1645 train_time:131434ms step_avg:92.69ms +step:1419/1645 train_time:131528ms step_avg:92.69ms +step:1420/1645 train_time:131622ms step_avg:92.69ms +step:1421/1645 train_time:131716ms step_avg:92.69ms +step:1422/1645 train_time:131809ms step_avg:92.69ms +step:1423/1645 train_time:131904ms step_avg:92.69ms +step:1424/1645 train_time:131998ms step_avg:92.70ms +step:1425/1645 train_time:132091ms step_avg:92.70ms +step:1426/1645 train_time:132184ms step_avg:92.70ms +step:1427/1645 train_time:132279ms step_avg:92.70ms +step:1428/1645 train_time:132372ms step_avg:92.70ms +step:1429/1645 train_time:132466ms step_avg:92.70ms +step:1430/1645 train_time:132560ms step_avg:92.70ms +step:1431/1645 train_time:132654ms step_avg:92.70ms +step:1432/1645 train_time:132748ms step_avg:92.70ms +step:1433/1645 train_time:132841ms step_avg:92.70ms +step:1434/1645 train_time:132935ms step_avg:92.70ms +step:1435/1645 train_time:133029ms step_avg:92.70ms +step:1436/1645 train_time:133123ms step_avg:92.70ms +step:1437/1645 train_time:133216ms step_avg:92.70ms +step:1438/1645 train_time:133309ms step_avg:92.70ms +step:1439/1645 train_time:133403ms step_avg:92.71ms +step:1440/1645 train_time:133498ms step_avg:92.71ms +step:1441/1645 train_time:133591ms step_avg:92.71ms +step:1442/1645 train_time:133684ms step_avg:92.71ms +step:1443/1645 train_time:133777ms step_avg:92.71ms +step:1444/1645 train_time:133870ms step_avg:92.71ms +step:1445/1645 train_time:133964ms step_avg:92.71ms +step:1446/1645 train_time:134058ms step_avg:92.71ms +step:1447/1645 train_time:134151ms step_avg:92.71ms +step:1448/1645 train_time:134245ms step_avg:92.71ms +step:1449/1645 train_time:134342ms step_avg:92.71ms +step:1450/1645 train_time:134434ms step_avg:92.71ms +step:1451/1645 train_time:134528ms step_avg:92.71ms +step:1452/1645 train_time:134622ms step_avg:92.72ms +step:1453/1645 train_time:134716ms step_avg:92.72ms +step:1454/1645 train_time:134808ms step_avg:92.72ms +step:1455/1645 train_time:134903ms step_avg:92.72ms +step:1456/1645 train_time:134997ms step_avg:92.72ms +step:1457/1645 train_time:135090ms 
step_avg:92.72ms +step:1458/1645 train_time:135184ms step_avg:92.72ms +step:1459/1645 train_time:135277ms step_avg:92.72ms +step:1460/1645 train_time:135370ms step_avg:92.72ms +step:1461/1645 train_time:135464ms step_avg:92.72ms +step:1462/1645 train_time:135561ms step_avg:92.72ms +step:1463/1645 train_time:135656ms step_avg:92.72ms +step:1464/1645 train_time:135748ms step_avg:92.72ms +step:1465/1645 train_time:135842ms step_avg:92.73ms +step:1466/1645 train_time:135937ms step_avg:92.73ms +step:1467/1645 train_time:136030ms step_avg:92.73ms +step:1468/1645 train_time:136123ms step_avg:92.73ms +step:1469/1645 train_time:136218ms step_avg:92.73ms +step:1470/1645 train_time:136312ms step_avg:92.73ms +step:1471/1645 train_time:136405ms step_avg:92.73ms +step:1472/1645 train_time:136501ms step_avg:92.73ms +step:1473/1645 train_time:136595ms step_avg:92.73ms +step:1474/1645 train_time:136688ms step_avg:92.73ms +step:1475/1645 train_time:136782ms step_avg:92.73ms +step:1476/1645 train_time:136876ms step_avg:92.73ms +step:1477/1645 train_time:136969ms step_avg:92.73ms +step:1478/1645 train_time:137063ms step_avg:92.74ms +step:1479/1645 train_time:137156ms step_avg:92.74ms +step:1480/1645 train_time:137251ms step_avg:92.74ms +step:1481/1645 train_time:137342ms step_avg:92.74ms +step:1482/1645 train_time:137437ms step_avg:92.74ms +step:1483/1645 train_time:137532ms step_avg:92.74ms +step:1484/1645 train_time:137626ms step_avg:92.74ms +step:1485/1645 train_time:137720ms step_avg:92.74ms +step:1486/1645 train_time:137813ms step_avg:92.74ms +step:1487/1645 train_time:137907ms step_avg:92.74ms +step:1488/1645 train_time:138001ms step_avg:92.74ms +step:1489/1645 train_time:138094ms step_avg:92.74ms +step:1490/1645 train_time:138188ms step_avg:92.74ms +step:1491/1645 train_time:138282ms step_avg:92.74ms +step:1492/1645 train_time:138374ms step_avg:92.74ms +step:1493/1645 train_time:138468ms step_avg:92.74ms +step:1494/1645 train_time:138563ms step_avg:92.75ms +step:1495/1645 train_time:138658ms step_avg:92.75ms +step:1496/1645 train_time:138750ms step_avg:92.75ms +step:1497/1645 train_time:138844ms step_avg:92.75ms +step:1498/1645 train_time:138937ms step_avg:92.75ms +step:1499/1645 train_time:139032ms step_avg:92.75ms +step:1500/1645 train_time:139125ms step_avg:92.75ms +step:1500/1645 val_loss:3.3082 train_time:139220ms step_avg:92.81ms +step:1501/1645 train_time:139242ms step_avg:92.77ms +step:1502/1645 train_time:139319ms step_avg:92.76ms +step:1503/1645 train_time:139417ms step_avg:92.76ms +step:1504/1645 train_time:139510ms step_avg:92.76ms +step:1505/1645 train_time:139603ms step_avg:92.76ms +step:1506/1645 train_time:139695ms step_avg:92.76ms +step:1507/1645 train_time:139787ms step_avg:92.76ms +step:1508/1645 train_time:139880ms step_avg:92.76ms +step:1509/1645 train_time:139973ms step_avg:92.76ms +step:1510/1645 train_time:140066ms step_avg:92.76ms +step:1511/1645 train_time:140160ms step_avg:92.76ms +step:1512/1645 train_time:140256ms step_avg:92.76ms +step:1513/1645 train_time:140352ms step_avg:92.76ms +step:1514/1645 train_time:140446ms step_avg:92.77ms +step:1515/1645 train_time:140540ms step_avg:92.77ms +step:1516/1645 train_time:140634ms step_avg:92.77ms +step:1517/1645 train_time:140727ms step_avg:92.77ms +step:1518/1645 train_time:140820ms step_avg:92.77ms +step:1519/1645 train_time:140914ms step_avg:92.77ms +step:1520/1645 train_time:141006ms step_avg:92.77ms +step:1521/1645 train_time:141101ms step_avg:92.77ms +step:1522/1645 train_time:141195ms step_avg:92.77ms +step:1523/1645 
train_time:141289ms step_avg:92.77ms +step:1524/1645 train_time:141383ms step_avg:92.77ms +step:1525/1645 train_time:141477ms step_avg:92.77ms +step:1526/1645 train_time:141571ms step_avg:92.77ms +step:1527/1645 train_time:141665ms step_avg:92.77ms +step:1528/1645 train_time:141758ms step_avg:92.77ms +step:1529/1645 train_time:141851ms step_avg:92.77ms +step:1530/1645 train_time:141944ms step_avg:92.77ms +step:1531/1645 train_time:142037ms step_avg:92.77ms +step:1532/1645 train_time:142130ms step_avg:92.77ms +step:1533/1645 train_time:142224ms step_avg:92.78ms +step:1534/1645 train_time:142319ms step_avg:92.78ms +step:1535/1645 train_time:142414ms step_avg:92.78ms +step:1536/1645 train_time:142508ms step_avg:92.78ms +step:1537/1645 train_time:142601ms step_avg:92.78ms +step:1538/1645 train_time:142695ms step_avg:92.78ms +step:1539/1645 train_time:142789ms step_avg:92.78ms +step:1540/1645 train_time:142882ms step_avg:92.78ms +step:1541/1645 train_time:142975ms step_avg:92.78ms +step:1542/1645 train_time:143068ms step_avg:92.78ms +step:1543/1645 train_time:143162ms step_avg:92.78ms +step:1544/1645 train_time:143256ms step_avg:92.78ms +step:1545/1645 train_time:143350ms step_avg:92.78ms +step:1546/1645 train_time:143444ms step_avg:92.78ms +step:1547/1645 train_time:143538ms step_avg:92.78ms +step:1548/1645 train_time:143634ms step_avg:92.79ms +step:1549/1645 train_time:143727ms step_avg:92.79ms +step:1550/1645 train_time:143820ms step_avg:92.79ms +step:1551/1645 train_time:143914ms step_avg:92.79ms +step:1552/1645 train_time:144008ms step_avg:92.79ms +step:1553/1645 train_time:144101ms step_avg:92.79ms +step:1554/1645 train_time:144195ms step_avg:92.79ms +step:1555/1645 train_time:144289ms step_avg:92.79ms +step:1556/1645 train_time:144383ms step_avg:92.79ms +step:1557/1645 train_time:144476ms step_avg:92.79ms +step:1558/1645 train_time:144571ms step_avg:92.79ms +step:1559/1645 train_time:144664ms step_avg:92.79ms +step:1560/1645 train_time:144758ms step_avg:92.79ms +step:1561/1645 train_time:144851ms step_avg:92.79ms +step:1562/1645 train_time:144945ms step_avg:92.79ms +step:1563/1645 train_time:145038ms step_avg:92.79ms +step:1564/1645 train_time:145133ms step_avg:92.80ms +step:1565/1645 train_time:145227ms step_avg:92.80ms +step:1566/1645 train_time:145323ms step_avg:92.80ms +step:1567/1645 train_time:145415ms step_avg:92.80ms +step:1568/1645 train_time:145509ms step_avg:92.80ms +step:1569/1645 train_time:145602ms step_avg:92.80ms +step:1570/1645 train_time:145697ms step_avg:92.80ms +step:1571/1645 train_time:145790ms step_avg:92.80ms +step:1572/1645 train_time:145883ms step_avg:92.80ms +step:1573/1645 train_time:145976ms step_avg:92.80ms +step:1574/1645 train_time:146070ms step_avg:92.80ms +step:1575/1645 train_time:146164ms step_avg:92.80ms +step:1576/1645 train_time:146257ms step_avg:92.80ms +step:1577/1645 train_time:146351ms step_avg:92.80ms +step:1578/1645 train_time:146444ms step_avg:92.80ms +step:1579/1645 train_time:146538ms step_avg:92.80ms +step:1580/1645 train_time:146632ms step_avg:92.81ms +step:1581/1645 train_time:146726ms step_avg:92.81ms +step:1582/1645 train_time:146819ms step_avg:92.81ms +step:1583/1645 train_time:146913ms step_avg:92.81ms +step:1584/1645 train_time:147007ms step_avg:92.81ms +step:1585/1645 train_time:147100ms step_avg:92.81ms +step:1586/1645 train_time:147194ms step_avg:92.81ms +step:1587/1645 train_time:147287ms step_avg:92.81ms +step:1588/1645 train_time:147381ms step_avg:92.81ms +step:1589/1645 train_time:147474ms step_avg:92.81ms +step:1590/1645 
train_time:147567ms step_avg:92.81ms +step:1591/1645 train_time:147662ms step_avg:92.81ms +step:1592/1645 train_time:147756ms step_avg:92.81ms +step:1593/1645 train_time:147849ms step_avg:92.81ms +step:1594/1645 train_time:147942ms step_avg:92.81ms +step:1595/1645 train_time:148037ms step_avg:92.81ms +step:1596/1645 train_time:148132ms step_avg:92.81ms +step:1597/1645 train_time:148225ms step_avg:92.81ms +step:1598/1645 train_time:148319ms step_avg:92.82ms +step:1599/1645 train_time:148412ms step_avg:92.82ms +step:1600/1645 train_time:148506ms step_avg:92.82ms +step:1601/1645 train_time:148599ms step_avg:92.82ms +step:1602/1645 train_time:148693ms step_avg:92.82ms +step:1603/1645 train_time:148787ms step_avg:92.82ms +step:1604/1645 train_time:148881ms step_avg:92.82ms +step:1605/1645 train_time:148974ms step_avg:92.82ms +step:1606/1645 train_time:149068ms step_avg:92.82ms +step:1607/1645 train_time:149161ms step_avg:92.82ms +step:1608/1645 train_time:149255ms step_avg:92.82ms +step:1609/1645 train_time:149348ms step_avg:92.82ms +step:1610/1645 train_time:149442ms step_avg:92.82ms +step:1611/1645 train_time:149535ms step_avg:92.82ms +step:1612/1645 train_time:149630ms step_avg:92.82ms +step:1613/1645 train_time:149724ms step_avg:92.82ms +step:1614/1645 train_time:149818ms step_avg:92.82ms +step:1615/1645 train_time:149911ms step_avg:92.82ms +step:1616/1645 train_time:150003ms step_avg:92.82ms +step:1617/1645 train_time:150097ms step_avg:92.82ms +step:1618/1645 train_time:150191ms step_avg:92.83ms +step:1619/1645 train_time:150284ms step_avg:92.83ms +step:1620/1645 train_time:150377ms step_avg:92.83ms +step:1621/1645 train_time:150472ms step_avg:92.83ms +step:1622/1645 train_time:150565ms step_avg:92.83ms +step:1623/1645 train_time:150659ms step_avg:92.83ms +step:1624/1645 train_time:150754ms step_avg:92.83ms +step:1625/1645 train_time:150848ms step_avg:92.83ms +step:1625/1645 val_loss:3.2840 train_time:150940ms step_avg:92.89ms +step:1626/1645 train_time:150962ms step_avg:92.84ms +step:1627/1645 train_time:151039ms step_avg:92.83ms +step:1628/1645 train_time:151135ms step_avg:92.83ms +step:1629/1645 train_time:151228ms step_avg:92.84ms +step:1630/1645 train_time:151321ms step_avg:92.84ms +step:1631/1645 train_time:151414ms step_avg:92.84ms +step:1632/1645 train_time:151507ms step_avg:92.83ms +step:1633/1645 train_time:151599ms step_avg:92.83ms +step:1634/1645 train_time:151691ms step_avg:92.83ms +step:1635/1645 train_time:151785ms step_avg:92.83ms +step:1636/1645 train_time:151878ms step_avg:92.84ms +step:1637/1645 train_time:151974ms step_avg:92.84ms +step:1638/1645 train_time:152070ms step_avg:92.84ms +step:1639/1645 train_time:152165ms step_avg:92.84ms +step:1640/1645 train_time:152258ms step_avg:92.84ms +step:1641/1645 train_time:152351ms step_avg:92.84ms +step:1642/1645 train_time:152444ms step_avg:92.84ms +step:1643/1645 train_time:152536ms step_avg:92.84ms +step:1644/1645 train_time:152629ms step_avg:92.84ms +step:1645/1645 train_time:152722ms step_avg:92.84ms +step:1645/1645 val_loss:3.2787 train_time:152815ms step_avg:92.90ms +peak memory allocated: 31497 MiB reserved: 47034 MiB diff --git a/records/091825_Smear/36761e6e-19ee-414f-a43c-63729950dfe7.txt b/records/091825_Smear/36761e6e-19ee-414f-a43c-63729950dfe7.txt new file mode 100644 index 000000000..e275bcd3e --- /dev/null +++ b/records/091825_Smear/36761e6e-19ee-414f-a43c-63729950dfe7.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging 
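+# The `code` string read above is kept so the run record can embed the exact
+# training script; the records/*.txt files in this patch are that script
+# followed by the run's step/val_loss log.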
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
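+        # Rough shape of the step (a sketch, names illustrative): three passes —
+        # (1) reduce_scatter averages gradients, leaving each rank one contiguous
+        # chunk of the stacked parameter list, (2) each rank orthogonalizes and
+        # applies updates for its chunk only, (3) all_gather broadcasts the updated
+        # chunks so every rank ends the step with identical parameters.
+        # Per parameter, the update is roughly the single-GPU recipe:
+        #   buf.lerp_(p.grad, 1 - momentum)            # momentum accumulation
+        #   update = p.grad.lerp(buf, momentum)        # Nesterov-style blend
+        #   update = newton_schulz_triton(update)      # orthogonalize the update
+        #   scale = lr * max(1, p.size(-2) / p.size(-1)) ** 0.5
+        #   p.mul_(1 - lr * weight_decay).add_(update, alpha=-scale)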
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ]  # This gradient corresponds to the current parameter p.
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
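+        # (The gather futures were launched asynchronously above, so the parameter
+        # broadcast overlaps with the later groups' compute; the loop below is
+        # where the step finally blocks on them.)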
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + smear_lambda = self.scalars[5 * len(self.blocks)] + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == 
BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
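+                # (Double buffering: preloader.get() blocks on the background
+                # thread that has been reading the next shard and indexing its
+                # BOS positions, and start() immediately kicks off the shard
+                # after that.)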
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1645 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"smear/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer shows no degradation with context length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
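+# Token accounting under these defaults: each optimizer step consumes
+# args.train_batch_size = 2048 * 24 * 8 = 393,216 tokens; the data generator
+# splits that into grad_accum_steps micro-batches, each sharded across
+# world_size ranks (e.g. 49,152 tokens per rank per micro-batch at
+# world_size = 8, grad_accum_steps = 1).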
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up 
each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 18 17:33:01 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 25C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 28C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 30C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 27C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 26C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 27C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:134ms step_avg:134.40ms +step:2/1645 train_time:150ms step_avg:75.09ms +step:3/1645 train_time:223ms step_avg:74.32ms +step:4/1645 train_time:313ms step_avg:78.24ms +step:5/1645 train_time:404ms step_avg:80.80ms +step:6/1645 train_time:495ms step_avg:82.42ms +step:7/1645 train_time:585ms step_avg:83.61ms +step:8/1645 train_time:676ms step_avg:84.55ms +step:9/1645 train_time:767ms step_avg:85.27ms +step:10/1645 train_time:859ms step_avg:85.85ms +step:11/1645 train_time:950ms step_avg:86.33ms +step:12/1645 train_time:1045ms step_avg:87.04ms +step:13/1645 train_time:1142ms step_avg:87.86ms +step:14/1645 train_time:1235ms step_avg:88.25ms +step:15/1645 train_time:1327ms step_avg:88.45ms +step:16/1645 train_time:1419ms step_avg:88.71ms +step:17/1645 train_time:1511ms step_avg:88.89ms +step:18/1645 train_time:1602ms step_avg:89.00ms +step:19/1645 train_time:1693ms step_avg:89.12ms +step:20/1645 train_time:1784ms step_avg:89.21ms +step:21/1645 train_time:1875ms step_avg:89.28ms +step:22/1645 train_time:1967ms step_avg:89.42ms +step:23/1645 
train_time:2061ms step_avg:89.59ms
+step:24/1645 train_time:2155ms step_avg:89.79ms
+step:25/1645 train_time:2248ms step_avg:89.90ms
+step:26/1645 train_time:2340ms step_avg:89.98ms
+step:27/1645 train_time:2432ms step_avg:90.06ms
+step:28/1645 train_time:2523ms step_avg:90.12ms
+step:29/1645 train_time:2615ms step_avg:90.18ms
+step:30/1645 train_time:2707ms step_avg:90.23ms
+step:31/1645 train_time:2799ms step_avg:90.28ms
+step:32/1645 train_time:2891ms step_avg:90.33ms
+step:33/1645 train_time:2983ms step_avg:90.38ms
+step:34/1645 train_time:3075ms step_avg:90.45ms
+step:35/1645 train_time:3168ms step_avg:90.52ms
+step:36/1645 train_time:3260ms step_avg:90.56ms
+step:37/1645 train_time:3352ms step_avg:90.61ms
+step:38/1645 train_time:3445ms step_avg:90.65ms
+step:39/1645 train_time:3536ms step_avg:90.67ms
+step:40/1645 train_time:3628ms step_avg:90.70ms
+step:41/1645 train_time:3719ms step_avg:90.72ms
+step:42/1645 train_time:3812ms step_avg:90.75ms
+step:43/1645 train_time:3904ms step_avg:90.79ms
+step:44/1645 train_time:3996ms step_avg:90.83ms
+step:45/1645 train_time:4089ms step_avg:90.86ms
+step:46/1645 train_time:4181ms step_avg:90.90ms
+step:47/1645 train_time:4274ms step_avg:90.94ms
+step:48/1645 train_time:4366ms step_avg:90.96ms
+step:49/1645 train_time:4458ms step_avg:90.99ms
+step:50/1645 train_time:4549ms step_avg:90.99ms
+step:51/1645 train_time:4641ms step_avg:91.00ms
+step:52/1645 train_time:4732ms step_avg:91.01ms
+step:53/1645 train_time:4824ms step_avg:91.03ms
+step:54/1645 train_time:4917ms step_avg:91.05ms
+step:55/1645 train_time:5009ms step_avg:91.07ms
+step:56/1645 train_time:5102ms step_avg:91.10ms
+step:57/1645 train_time:5196ms step_avg:91.16ms
+step:58/1645 train_time:5287ms step_avg:91.15ms
+step:59/1645 train_time:5379ms step_avg:91.17ms
+step:60/1645 train_time:5470ms step_avg:91.17ms
+step:61/1645 train_time:5562ms step_avg:91.18ms
+step:62/1645 train_time:5654ms step_avg:91.19ms
+step:63/1645 train_time:5745ms step_avg:91.19ms
+step:64/1645 train_time:5839ms step_avg:91.23ms
+step:65/1645 train_time:5930ms step_avg:91.22ms
+step:66/1645 train_time:6023ms step_avg:91.25ms
+step:67/1645 train_time:6116ms step_avg:91.28ms
+step:68/1645 train_time:6208ms step_avg:91.29ms
+step:69/1645 train_time:6302ms step_avg:91.33ms
+step:70/1645 train_time:6393ms step_avg:91.33ms
+step:71/1645 train_time:6484ms step_avg:91.32ms
+step:72/1645 train_time:6575ms step_avg:91.32ms
+step:73/1645 train_time:6667ms step_avg:91.33ms
+step:74/1645 train_time:6759ms step_avg:91.33ms
+step:75/1645 train_time:6849ms step_avg:91.33ms
+step:76/1645 train_time:6942ms step_avg:91.34ms
+step:77/1645 train_time:7034ms step_avg:91.35ms
+step:78/1645 train_time:7127ms step_avg:91.37ms
+step:79/1645 train_time:7219ms step_avg:91.38ms
+step:80/1645 train_time:7310ms step_avg:91.38ms
+step:81/1645 train_time:7403ms step_avg:91.40ms
+step:82/1645 train_time:7495ms step_avg:91.41ms
+step:83/1645 train_time:7586ms step_avg:91.40ms
+step:84/1645 train_time:7678ms step_avg:91.41ms
+step:85/1645 train_time:7769ms step_avg:91.40ms
+step:86/1645 train_time:7861ms step_avg:91.40ms
+step:87/1645 train_time:7953ms step_avg:91.41ms
+step:88/1645 train_time:8045ms step_avg:91.42ms
+step:89/1645 train_time:8136ms step_avg:91.42ms
+step:90/1645 train_time:8228ms step_avg:91.42ms
+step:91/1645 train_time:8322ms step_avg:91.45ms
+step:92/1645 train_time:8414ms step_avg:91.46ms
+step:93/1645 train_time:8505ms step_avg:91.46ms
+step:94/1645 train_time:8597ms step_avg:91.45ms
+step:95/1645 train_time:8688ms step_avg:91.46ms
+step:96/1645 train_time:8780ms step_avg:91.46ms
+step:97/1645 train_time:8872ms step_avg:91.46ms
+step:98/1645 train_time:8963ms step_avg:91.46ms
+step:99/1645 train_time:9055ms step_avg:91.46ms
+step:100/1645 train_time:9147ms step_avg:91.47ms
+step:101/1645 train_time:9240ms step_avg:91.48ms
+step:102/1645 train_time:9331ms step_avg:91.48ms
+step:103/1645 train_time:9424ms step_avg:91.50ms
+step:104/1645 train_time:9516ms step_avg:91.50ms
+step:105/1645 train_time:9607ms step_avg:91.50ms
+step:106/1645 train_time:9699ms step_avg:91.50ms
+step:107/1645 train_time:9790ms step_avg:91.49ms
+step:108/1645 train_time:9882ms step_avg:91.50ms
+step:109/1645 train_time:9974ms step_avg:91.51ms
+step:110/1645 train_time:10066ms step_avg:91.51ms
+step:111/1645 train_time:10158ms step_avg:91.51ms
+step:112/1645 train_time:10249ms step_avg:91.51ms
+step:113/1645 train_time:10342ms step_avg:91.52ms
+step:114/1645 train_time:10434ms step_avg:91.53ms
+step:115/1645 train_time:10527ms step_avg:91.54ms
+step:116/1645 train_time:10618ms step_avg:91.53ms
+step:117/1645 train_time:10709ms step_avg:91.53ms
+step:118/1645 train_time:10801ms step_avg:91.54ms
+step:119/1645 train_time:10893ms step_avg:91.54ms
+step:120/1645 train_time:10985ms step_avg:91.54ms
+step:121/1645 train_time:11077ms step_avg:91.54ms
+step:122/1645 train_time:11168ms step_avg:91.54ms
+step:123/1645 train_time:11259ms step_avg:91.54ms
+step:124/1645 train_time:11350ms step_avg:91.54ms
+step:125/1645 train_time:11442ms step_avg:91.54ms
+step:125/1645 val_loss:4.3125 train_time:11534ms step_avg:92.27ms
+step:126/1645 train_time:11549ms step_avg:91.66ms
+step:127/1645 train_time:11633ms step_avg:91.59ms
+step:128/1645 train_time:11735ms step_avg:91.68ms
+step:129/1645 train_time:11833ms step_avg:91.73ms
+step:130/1645 train_time:11924ms step_avg:91.72ms
+step:131/1645 train_time:12014ms step_avg:91.71ms
+step:132/1645 train_time:12104ms step_avg:91.70ms
+step:133/1645 train_time:12195ms step_avg:91.69ms
+step:134/1645 train_time:12285ms step_avg:91.68ms
+step:135/1645 train_time:12376ms step_avg:91.67ms
+step:136/1645 train_time:12466ms step_avg:91.66ms
+step:137/1645 train_time:12557ms step_avg:91.66ms
+step:138/1645 train_time:12652ms step_avg:91.68ms
+step:139/1645 train_time:12747ms step_avg:91.70ms
+step:140/1645 train_time:12841ms step_avg:91.72ms
+step:141/1645 train_time:12934ms step_avg:91.73ms
+step:142/1645 train_time:13025ms step_avg:91.73ms
+step:143/1645 train_time:13118ms step_avg:91.73ms
+step:144/1645 train_time:13208ms step_avg:91.72ms
+step:145/1645 train_time:13299ms step_avg:91.72ms
+step:146/1645 train_time:13389ms step_avg:91.71ms
+step:147/1645 train_time:13482ms step_avg:91.72ms
+step:148/1645 train_time:13572ms step_avg:91.70ms
+step:149/1645 train_time:13665ms step_avg:91.71ms
+step:150/1645 train_time:13759ms step_avg:91.73ms
+step:151/1645 train_time:13851ms step_avg:91.73ms
+step:152/1645 train_time:13944ms step_avg:91.74ms
+step:153/1645 train_time:14036ms step_avg:91.74ms
+step:154/1645 train_time:14128ms step_avg:91.74ms
+step:155/1645 train_time:14220ms step_avg:91.74ms
+step:156/1645 train_time:14311ms step_avg:91.73ms
+step:157/1645 train_time:14401ms step_avg:91.73ms
+step:158/1645 train_time:14492ms step_avg:91.72ms
+step:159/1645 train_time:14584ms step_avg:91.72ms
+step:160/1645 train_time:14675ms step_avg:91.72ms
+step:161/1645 train_time:14766ms step_avg:91.72ms
+step:162/1645 train_time:14859ms step_avg:91.72ms
+step:163/1645 train_time:14952ms step_avg:91.73ms
+step:164/1645 train_time:15044ms step_avg:91.73ms
+step:165/1645 train_time:15136ms step_avg:91.74ms
+step:166/1645 train_time:15227ms step_avg:91.73ms
+step:167/1645 train_time:15318ms step_avg:91.73ms
+step:168/1645 train_time:15409ms step_avg:91.72ms
+step:169/1645 train_time:15500ms step_avg:91.72ms
+step:170/1645 train_time:15591ms step_avg:91.71ms
+step:171/1645 train_time:15682ms step_avg:91.71ms
+step:172/1645 train_time:15774ms step_avg:91.71ms
+step:173/1645 train_time:15866ms step_avg:91.71ms
+step:174/1645 train_time:15960ms step_avg:91.72ms
+step:175/1645 train_time:16052ms step_avg:91.73ms
+step:176/1645 train_time:16144ms step_avg:91.73ms
+step:177/1645 train_time:16235ms step_avg:91.73ms
+step:178/1645 train_time:16327ms step_avg:91.72ms
+step:179/1645 train_time:16418ms step_avg:91.72ms
+step:180/1645 train_time:16509ms step_avg:91.72ms
+step:181/1645 train_time:16600ms step_avg:91.71ms
+step:182/1645 train_time:16693ms step_avg:91.72ms
+step:183/1645 train_time:16783ms step_avg:91.71ms
+step:184/1645 train_time:16875ms step_avg:91.71ms
+step:185/1645 train_time:16966ms step_avg:91.71ms
+step:186/1645 train_time:17059ms step_avg:91.71ms
+step:187/1645 train_time:17151ms step_avg:91.72ms
+step:188/1645 train_time:17242ms step_avg:91.72ms
+step:189/1645 train_time:17334ms step_avg:91.72ms
+step:190/1645 train_time:17426ms step_avg:91.72ms
+step:191/1645 train_time:17517ms step_avg:91.71ms
+step:192/1645 train_time:17608ms step_avg:91.71ms
+step:193/1645 train_time:17699ms step_avg:91.70ms
+step:194/1645 train_time:17790ms step_avg:91.70ms
+step:195/1645 train_time:17882ms step_avg:91.70ms
+step:196/1645 train_time:17974ms step_avg:91.71ms
+step:197/1645 train_time:18066ms step_avg:91.71ms
+step:198/1645 train_time:18159ms step_avg:91.71ms
+step:199/1645 train_time:18251ms step_avg:91.71ms
+step:200/1645 train_time:18342ms step_avg:91.71ms
+step:201/1645 train_time:18434ms step_avg:91.71ms
+step:202/1645 train_time:18525ms step_avg:91.71ms
+step:203/1645 train_time:18616ms step_avg:91.70ms
+step:204/1645 train_time:18707ms step_avg:91.70ms
+step:205/1645 train_time:18798ms step_avg:91.70ms
+step:206/1645 train_time:18890ms step_avg:91.70ms
+step:207/1645 train_time:18982ms step_avg:91.70ms
+step:208/1645 train_time:19074ms step_avg:91.70ms
+step:209/1645 train_time:19166ms step_avg:91.70ms
+step:210/1645 train_time:19259ms step_avg:91.71ms
+step:211/1645 train_time:19350ms step_avg:91.71ms
+step:212/1645 train_time:19442ms step_avg:91.71ms
+step:213/1645 train_time:19533ms step_avg:91.71ms
+step:214/1645 train_time:19625ms step_avg:91.70ms
+step:215/1645 train_time:19715ms step_avg:91.70ms
+step:216/1645 train_time:19809ms step_avg:91.71ms
+step:217/1645 train_time:19898ms step_avg:91.70ms
+step:218/1645 train_time:19990ms step_avg:91.70ms
+step:219/1645 train_time:20081ms step_avg:91.70ms
+step:220/1645 train_time:20173ms step_avg:91.70ms
+step:221/1645 train_time:20265ms step_avg:91.70ms
+step:222/1645 train_time:20357ms step_avg:91.70ms
+step:223/1645 train_time:20448ms step_avg:91.70ms
+step:224/1645 train_time:20541ms step_avg:91.70ms
+step:225/1645 train_time:20633ms step_avg:91.70ms
+step:226/1645 train_time:20724ms step_avg:91.70ms
+step:227/1645 train_time:20816ms step_avg:91.70ms
+step:228/1645 train_time:20906ms step_avg:91.69ms
+step:229/1645 train_time:20998ms step_avg:91.69ms
+step:230/1645 train_time:21089ms step_avg:91.69ms
+step:231/1645 train_time:21181ms step_avg:91.69ms
+step:232/1645 train_time:21273ms step_avg:91.69ms
+step:233/1645 train_time:21365ms step_avg:91.69ms
+step:234/1645 train_time:21458ms step_avg:91.70ms
+step:235/1645 train_time:21549ms step_avg:91.70ms
+step:236/1645 train_time:21641ms step_avg:91.70ms
+step:237/1645 train_time:21732ms step_avg:91.70ms
+step:238/1645 train_time:21823ms step_avg:91.69ms
+step:239/1645 train_time:21915ms step_avg:91.70ms
+step:240/1645 train_time:22005ms step_avg:91.69ms
+step:241/1645 train_time:22097ms step_avg:91.69ms
+step:242/1645 train_time:22188ms step_avg:91.69ms
+step:243/1645 train_time:22280ms step_avg:91.69ms
+step:244/1645 train_time:22372ms step_avg:91.69ms
+step:245/1645 train_time:22464ms step_avg:91.69ms
+step:246/1645 train_time:22556ms step_avg:91.69ms
+step:247/1645 train_time:22647ms step_avg:91.69ms
+step:248/1645 train_time:22739ms step_avg:91.69ms
+step:249/1645 train_time:22831ms step_avg:91.69ms
+step:250/1645 train_time:22923ms step_avg:91.69ms
+step:250/1645 val_loss:3.9668 train_time:23015ms step_avg:92.06ms
+step:251/1645 train_time:23029ms step_avg:91.75ms
+step:252/1645 train_time:23110ms step_avg:91.71ms
+step:253/1645 train_time:23203ms step_avg:91.71ms
+step:254/1645 train_time:23295ms step_avg:91.71ms
+step:255/1645 train_time:23386ms step_avg:91.71ms
+step:256/1645 train_time:23476ms step_avg:91.71ms
+step:257/1645 train_time:23567ms step_avg:91.70ms
+step:258/1645 train_time:23658ms step_avg:91.70ms
+step:259/1645 train_time:23749ms step_avg:91.70ms
+step:260/1645 train_time:23840ms step_avg:91.69ms
+step:261/1645 train_time:23932ms step_avg:91.69ms
+step:262/1645 train_time:24024ms step_avg:91.70ms
+step:263/1645 train_time:24119ms step_avg:91.71ms
+step:264/1645 train_time:24212ms step_avg:91.71ms
+step:265/1645 train_time:24303ms step_avg:91.71ms
+step:266/1645 train_time:24395ms step_avg:91.71ms
+step:267/1645 train_time:24487ms step_avg:91.71ms
+step:268/1645 train_time:24578ms step_avg:91.71ms
+step:269/1645 train_time:24669ms step_avg:91.71ms
+step:270/1645 train_time:24760ms step_avg:91.70ms
+step:271/1645 train_time:24851ms step_avg:91.70ms
+step:272/1645 train_time:24943ms step_avg:91.70ms
+step:273/1645 train_time:25035ms step_avg:91.70ms
+step:274/1645 train_time:25127ms step_avg:91.71ms
+step:275/1645 train_time:25219ms step_avg:91.71ms
+step:276/1645 train_time:25311ms step_avg:91.71ms
+step:277/1645 train_time:25402ms step_avg:91.70ms
+step:278/1645 train_time:25494ms step_avg:91.71ms
+step:279/1645 train_time:25586ms step_avg:91.71ms
+step:280/1645 train_time:25677ms step_avg:91.70ms
+step:281/1645 train_time:25768ms step_avg:91.70ms
+step:282/1645 train_time:25859ms step_avg:91.70ms
+step:283/1645 train_time:25951ms step_avg:91.70ms
+step:284/1645 train_time:26042ms step_avg:91.70ms
+step:285/1645 train_time:26135ms step_avg:91.70ms
+step:286/1645 train_time:26227ms step_avg:91.70ms
+step:287/1645 train_time:26318ms step_avg:91.70ms
+step:288/1645 train_time:26410ms step_avg:91.70ms
+step:289/1645 train_time:26502ms step_avg:91.70ms
+step:290/1645 train_time:26594ms step_avg:91.70ms
+step:291/1645 train_time:26686ms step_avg:91.70ms
+step:292/1645 train_time:26777ms step_avg:91.70ms
+step:293/1645 train_time:26869ms step_avg:91.70ms
+step:294/1645 train_time:26959ms step_avg:91.70ms
+step:295/1645 train_time:27050ms step_avg:91.70ms
+step:296/1645 train_time:27142ms step_avg:91.70ms
+step:297/1645 train_time:27235ms step_avg:91.70ms
+step:298/1645 train_time:27327ms step_avg:91.70ms
+step:299/1645 train_time:27418ms step_avg:91.70ms
+step:300/1645 train_time:27509ms step_avg:91.70ms
+step:301/1645 train_time:27601ms step_avg:91.70ms
+step:302/1645 train_time:27693ms step_avg:91.70ms
+step:303/1645 train_time:27785ms step_avg:91.70ms
+step:304/1645 train_time:27876ms step_avg:91.70ms
+step:305/1645 train_time:27968ms step_avg:91.70ms
+step:306/1645 train_time:28059ms step_avg:91.70ms
+step:307/1645 train_time:28151ms step_avg:91.70ms
+step:308/1645 train_time:28242ms step_avg:91.70ms
+step:309/1645 train_time:28334ms step_avg:91.70ms
+step:310/1645 train_time:28426ms step_avg:91.70ms
+step:311/1645 train_time:28517ms step_avg:91.70ms
+step:312/1645 train_time:28610ms step_avg:91.70ms
+step:313/1645 train_time:28701ms step_avg:91.70ms
+step:314/1645 train_time:28792ms step_avg:91.69ms
+step:315/1645 train_time:28884ms step_avg:91.69ms
+step:316/1645 train_time:28976ms step_avg:91.70ms
+step:317/1645 train_time:29067ms step_avg:91.69ms
+step:318/1645 train_time:29159ms step_avg:91.69ms
+step:319/1645 train_time:29251ms step_avg:91.69ms
+step:320/1645 train_time:29343ms step_avg:91.70ms
+step:321/1645 train_time:29434ms step_avg:91.70ms
+step:322/1645 train_time:29526ms step_avg:91.69ms
+step:323/1645 train_time:29618ms step_avg:91.70ms
+step:324/1645 train_time:29710ms step_avg:91.70ms
+step:325/1645 train_time:29802ms step_avg:91.70ms
+step:326/1645 train_time:29893ms step_avg:91.70ms
+step:327/1645 train_time:29985ms step_avg:91.70ms
+step:328/1645 train_time:30076ms step_avg:91.70ms
+step:329/1645 train_time:30168ms step_avg:91.70ms
+step:330/1645 train_time:30260ms step_avg:91.70ms
+step:331/1645 train_time:30351ms step_avg:91.69ms
+step:332/1645 train_time:30443ms step_avg:91.69ms
+step:333/1645 train_time:30535ms step_avg:91.70ms
+step:334/1645 train_time:30626ms step_avg:91.70ms
+step:335/1645 train_time:30717ms step_avg:91.69ms
+step:336/1645 train_time:30809ms step_avg:91.69ms
+step:337/1645 train_time:30900ms step_avg:91.69ms
+step:338/1645 train_time:30992ms step_avg:91.69ms
+step:339/1645 train_time:31083ms step_avg:91.69ms
+step:340/1645 train_time:31175ms step_avg:91.69ms
+step:341/1645 train_time:31266ms step_avg:91.69ms
+step:342/1645 train_time:31358ms step_avg:91.69ms
+step:343/1645 train_time:31449ms step_avg:91.69ms
+step:344/1645 train_time:31541ms step_avg:91.69ms
+step:345/1645 train_time:31632ms step_avg:91.69ms
+step:346/1645 train_time:31724ms step_avg:91.69ms
+step:347/1645 train_time:31817ms step_avg:91.69ms
+step:348/1645 train_time:31909ms step_avg:91.69ms
+step:349/1645 train_time:32001ms step_avg:91.69ms
+step:350/1645 train_time:32092ms step_avg:91.69ms
+step:351/1645 train_time:32182ms step_avg:91.69ms
+step:352/1645 train_time:32274ms step_avg:91.69ms
+step:353/1645 train_time:32366ms step_avg:91.69ms
+step:354/1645 train_time:32457ms step_avg:91.69ms
+step:355/1645 train_time:32549ms step_avg:91.69ms
+step:356/1645 train_time:32641ms step_avg:91.69ms
+step:357/1645 train_time:32732ms step_avg:91.69ms
+step:358/1645 train_time:32824ms step_avg:91.69ms
+step:359/1645 train_time:32916ms step_avg:91.69ms
+step:360/1645 train_time:33008ms step_avg:91.69ms
+step:361/1645 train_time:33100ms step_avg:91.69ms
+step:362/1645 train_time:33192ms step_avg:91.69ms
+step:363/1645 train_time:33283ms step_avg:91.69ms
+step:364/1645 train_time:33375ms step_avg:91.69ms
+step:365/1645 train_time:33466ms step_avg:91.69ms
+step:366/1645 train_time:33557ms step_avg:91.69ms
+step:367/1645 train_time:33649ms step_avg:91.69ms
+step:368/1645 train_time:33740ms step_avg:91.69ms
+step:369/1645 train_time:33833ms step_avg:91.69ms
+step:370/1645 train_time:33923ms step_avg:91.68ms
+step:371/1645 train_time:34016ms step_avg:91.69ms
+step:372/1645 train_time:34108ms step_avg:91.69ms
+step:373/1645 train_time:34201ms step_avg:91.69ms
+step:374/1645 train_time:34293ms step_avg:91.69ms
+step:375/1645 train_time:34385ms step_avg:91.69ms
+step:375/1645 val_loss:3.8152 train_time:34477ms step_avg:91.94ms
+step:376/1645 train_time:34498ms step_avg:91.75ms
+step:377/1645 train_time:34574ms step_avg:91.71ms
+step:378/1645 train_time:34669ms step_avg:91.72ms
+step:379/1645 train_time:34761ms step_avg:91.72ms
+step:380/1645 train_time:34852ms step_avg:91.72ms
+step:381/1645 train_time:34942ms step_avg:91.71ms
+step:382/1645 train_time:35032ms step_avg:91.71ms
+step:383/1645 train_time:35123ms step_avg:91.71ms
+step:384/1645 train_time:35214ms step_avg:91.70ms
+step:385/1645 train_time:35305ms step_avg:91.70ms
+step:386/1645 train_time:35396ms step_avg:91.70ms
+step:387/1645 train_time:35490ms step_avg:91.70ms
+step:388/1645 train_time:35582ms step_avg:91.71ms
+step:389/1645 train_time:35675ms step_avg:91.71ms
+step:390/1645 train_time:35767ms step_avg:91.71ms
+step:391/1645 train_time:35858ms step_avg:91.71ms
+step:392/1645 train_time:35949ms step_avg:91.71ms
+step:393/1645 train_time:36040ms step_avg:91.71ms
+step:394/1645 train_time:36131ms step_avg:91.70ms
+step:395/1645 train_time:36222ms step_avg:91.70ms
+step:396/1645 train_time:36313ms step_avg:91.70ms
+step:397/1645 train_time:36404ms step_avg:91.70ms
+step:398/1645 train_time:36497ms step_avg:91.70ms
+step:399/1645 train_time:36588ms step_avg:91.70ms
+step:400/1645 train_time:36680ms step_avg:91.70ms
+step:401/1645 train_time:36772ms step_avg:91.70ms
+step:402/1645 train_time:36864ms step_avg:91.70ms
+step:403/1645 train_time:36955ms step_avg:91.70ms
+step:404/1645 train_time:37047ms step_avg:91.70ms
+step:405/1645 train_time:37138ms step_avg:91.70ms
+step:406/1645 train_time:37229ms step_avg:91.70ms
+step:407/1645 train_time:37321ms step_avg:91.70ms
+step:408/1645 train_time:37412ms step_avg:91.70ms
+step:409/1645 train_time:37504ms step_avg:91.70ms
+step:410/1645 train_time:37594ms step_avg:91.69ms
+step:411/1645 train_time:37687ms step_avg:91.70ms
+step:412/1645 train_time:37778ms step_avg:91.69ms
+step:413/1645 train_time:37870ms step_avg:91.70ms
+step:414/1645 train_time:37962ms step_avg:91.70ms
+step:415/1645 train_time:38053ms step_avg:91.69ms
+step:416/1645 train_time:38145ms step_avg:91.69ms
+step:417/1645 train_time:38236ms step_avg:91.69ms
+step:418/1645 train_time:38328ms step_avg:91.69ms
+step:419/1645 train_time:38418ms step_avg:91.69ms
+step:420/1645 train_time:38509ms step_avg:91.69ms
+step:421/1645 train_time:38601ms step_avg:91.69ms
+step:422/1645 train_time:38692ms step_avg:91.69ms
+step:423/1645 train_time:38783ms step_avg:91.69ms
+step:424/1645 train_time:38875ms step_avg:91.69ms
+step:425/1645 train_time:38968ms step_avg:91.69ms
+step:426/1645 train_time:39059ms step_avg:91.69ms
+step:427/1645 train_time:39150ms step_avg:91.69ms
+step:428/1645 train_time:39242ms step_avg:91.69ms
+step:429/1645 train_time:39334ms step_avg:91.69ms
+step:430/1645 train_time:39425ms step_avg:91.69ms
+step:431/1645 train_time:39516ms step_avg:91.68ms
+step:432/1645 train_time:39609ms step_avg:91.69ms
+step:433/1645 train_time:39700ms step_avg:91.69ms
+step:434/1645 train_time:39791ms step_avg:91.68ms
+step:435/1645 train_time:39884ms step_avg:91.69ms
+step:436/1645 train_time:39975ms step_avg:91.69ms
+step:437/1645 train_time:40067ms step_avg:91.69ms
+step:438/1645 train_time:40158ms step_avg:91.69ms
+step:439/1645 train_time:40250ms step_avg:91.69ms
+step:440/1645 train_time:40342ms step_avg:91.69ms
+step:441/1645 train_time:40433ms step_avg:91.68ms
+step:442/1645 train_time:40524ms step_avg:91.68ms
+step:443/1645 train_time:40617ms step_avg:91.69ms
+step:444/1645 train_time:40707ms step_avg:91.68ms
+step:445/1645 train_time:40798ms step_avg:91.68ms
+step:446/1645 train_time:40890ms step_avg:91.68ms
+step:447/1645 train_time:40982ms step_avg:91.68ms
+step:448/1645 train_time:41073ms step_avg:91.68ms
+step:449/1645 train_time:41165ms step_avg:91.68ms
+step:450/1645 train_time:41257ms step_avg:91.68ms
+step:451/1645 train_time:41348ms step_avg:91.68ms
+step:452/1645 train_time:41440ms step_avg:91.68ms
+step:453/1645 train_time:41532ms step_avg:91.68ms
+step:454/1645 train_time:41623ms step_avg:91.68ms
+step:455/1645 train_time:41715ms step_avg:91.68ms
+step:456/1645 train_time:41806ms step_avg:91.68ms
+step:457/1645 train_time:41897ms step_avg:91.68ms
+step:458/1645 train_time:41989ms step_avg:91.68ms
+step:459/1645 train_time:42080ms step_avg:91.68ms
+step:460/1645 train_time:42172ms step_avg:91.68ms
+step:461/1645 train_time:42263ms step_avg:91.68ms
+step:462/1645 train_time:42355ms step_avg:91.68ms
+step:463/1645 train_time:42447ms step_avg:91.68ms
+step:464/1645 train_time:42539ms step_avg:91.68ms
+step:465/1645 train_time:42632ms step_avg:91.68ms
+step:466/1645 train_time:42724ms step_avg:91.68ms
+step:467/1645 train_time:42814ms step_avg:91.68ms
+step:468/1645 train_time:42906ms step_avg:91.68ms
+step:469/1645 train_time:42997ms step_avg:91.68ms
+step:470/1645 train_time:43089ms step_avg:91.68ms
+step:471/1645 train_time:43180ms step_avg:91.68ms
+step:472/1645 train_time:43271ms step_avg:91.68ms
+step:473/1645 train_time:43363ms step_avg:91.68ms
+step:474/1645 train_time:43455ms step_avg:91.68ms
+step:475/1645 train_time:43547ms step_avg:91.68ms
+step:476/1645 train_time:43638ms step_avg:91.68ms
+step:477/1645 train_time:43731ms step_avg:91.68ms
+step:478/1645 train_time:43821ms step_avg:91.68ms
+step:479/1645 train_time:43912ms step_avg:91.68ms
+step:480/1645 train_time:44004ms step_avg:91.67ms
+step:481/1645 train_time:44095ms step_avg:91.67ms
+step:482/1645 train_time:44187ms step_avg:91.67ms
+step:483/1645 train_time:44278ms step_avg:91.67ms
+step:484/1645 train_time:44370ms step_avg:91.67ms
+step:485/1645 train_time:44462ms step_avg:91.67ms
+step:486/1645 train_time:44554ms step_avg:91.68ms
+step:487/1645 train_time:44647ms step_avg:91.68ms
+step:488/1645 train_time:44740ms step_avg:91.68ms
+step:489/1645 train_time:44831ms step_avg:91.68ms
+step:490/1645 train_time:44922ms step_avg:91.68ms
+step:491/1645 train_time:45013ms step_avg:91.68ms
+step:492/1645 train_time:45104ms step_avg:91.68ms
+step:493/1645 train_time:45195ms step_avg:91.67ms
+step:494/1645 train_time:45286ms step_avg:91.67ms
+step:495/1645 train_time:45378ms step_avg:91.67ms
+step:496/1645 train_time:45471ms step_avg:91.68ms
+step:497/1645 train_time:45563ms step_avg:91.68ms
+step:498/1645 train_time:45656ms step_avg:91.68ms
+step:499/1645 train_time:45747ms step_avg:91.68ms
+step:500/1645 train_time:45841ms step_avg:91.68ms
+step:500/1645 val_loss:3.7132 train_time:45931ms step_avg:91.86ms
+step:501/1645 train_time:45952ms step_avg:91.72ms
+step:502/1645 train_time:46027ms step_avg:91.69ms
+step:503/1645 train_time:46121ms step_avg:91.69ms
+step:504/1645 train_time:46213ms step_avg:91.69ms
+step:505/1645 train_time:46304ms step_avg:91.69ms
+step:506/1645 train_time:46394ms step_avg:91.69ms
+step:507/1645 train_time:46485ms step_avg:91.69ms
+step:508/1645 train_time:46576ms step_avg:91.69ms
+step:509/1645 train_time:46667ms step_avg:91.68ms
+step:510/1645 train_time:46758ms step_avg:91.68ms
+step:511/1645 train_time:46850ms step_avg:91.68ms
+step:512/1645 train_time:46942ms step_avg:91.68ms
+step:513/1645 train_time:47034ms step_avg:91.69ms
+step:514/1645 train_time:47128ms step_avg:91.69ms
+step:515/1645 train_time:47221ms step_avg:91.69ms
+step:516/1645 train_time:47313ms step_avg:91.69ms
+step:517/1645 train_time:47404ms step_avg:91.69ms
+step:518/1645 train_time:47495ms step_avg:91.69ms
+step:519/1645 train_time:47586ms step_avg:91.69ms
+step:520/1645 train_time:47677ms step_avg:91.69ms
+step:521/1645 train_time:47767ms step_avg:91.68ms
+step:522/1645 train_time:47858ms step_avg:91.68ms
+step:523/1645 train_time:47950ms step_avg:91.68ms
+step:524/1645 train_time:48043ms step_avg:91.69ms
+step:525/1645 train_time:48136ms step_avg:91.69ms
+step:526/1645 train_time:48228ms step_avg:91.69ms
+step:527/1645 train_time:48319ms step_avg:91.69ms
+step:528/1645 train_time:48411ms step_avg:91.69ms
+step:529/1645 train_time:48502ms step_avg:91.69ms
+step:530/1645 train_time:48593ms step_avg:91.69ms
+step:531/1645 train_time:48684ms step_avg:91.68ms
+step:532/1645 train_time:48775ms step_avg:91.68ms
+step:533/1645 train_time:48867ms step_avg:91.68ms
+step:534/1645 train_time:48959ms step_avg:91.68ms
+step:535/1645 train_time:49051ms step_avg:91.68ms
+step:536/1645 train_time:49143ms step_avg:91.68ms
+step:537/1645 train_time:49235ms step_avg:91.69ms
+step:538/1645 train_time:49326ms step_avg:91.68ms
+step:539/1645 train_time:49418ms step_avg:91.69ms
+step:540/1645 train_time:49509ms step_avg:91.68ms
+step:541/1645 train_time:49601ms step_avg:91.68ms
+step:542/1645 train_time:49691ms step_avg:91.68ms
+step:543/1645 train_time:49783ms step_avg:91.68ms
+step:544/1645 train_time:49874ms step_avg:91.68ms
+step:545/1645 train_time:49966ms step_avg:91.68ms
+step:546/1645 train_time:50058ms step_avg:91.68ms
+step:547/1645 train_time:50150ms step_avg:91.68ms
+step:548/1645 train_time:50241ms step_avg:91.68ms
+step:549/1645 train_time:50333ms step_avg:91.68ms
+step:550/1645 train_time:50426ms step_avg:91.68ms
+step:551/1645 train_time:50520ms step_avg:91.69ms
+step:552/1645 train_time:50612ms step_avg:91.69ms
+step:553/1645 train_time:50705ms step_avg:91.69ms
+step:554/1645 train_time:50797ms step_avg:91.69ms
+step:555/1645 train_time:50890ms step_avg:91.69ms
+step:556/1645 train_time:50984ms step_avg:91.70ms
+step:557/1645 train_time:51077ms step_avg:91.70ms
+step:558/1645 train_time:51170ms step_avg:91.70ms
+step:559/1645 train_time:51263ms step_avg:91.70ms
+step:560/1645 train_time:51355ms step_avg:91.71ms
+step:561/1645 train_time:51449ms step_avg:91.71ms
+step:562/1645 train_time:51542ms step_avg:91.71ms
+step:563/1645 train_time:51634ms step_avg:91.71ms
+step:564/1645 train_time:51728ms step_avg:91.72ms
+step:565/1645 train_time:51821ms step_avg:91.72ms
+step:566/1645 train_time:51914ms step_avg:91.72ms
+step:567/1645 train_time:52007ms step_avg:91.72ms
+step:568/1645 train_time:52101ms step_avg:91.73ms
+step:569/1645 train_time:52193ms step_avg:91.73ms
+step:570/1645 train_time:52286ms step_avg:91.73ms
+step:571/1645 train_time:52379ms step_avg:91.73ms
+step:572/1645 train_time:52471ms step_avg:91.73ms
+step:573/1645 train_time:52563ms step_avg:91.73ms
+step:574/1645 train_time:52656ms step_avg:91.74ms
+step:575/1645 train_time:52749ms step_avg:91.74ms
+step:576/1645 train_time:52842ms step_avg:91.74ms
+step:577/1645 train_time:52936ms step_avg:91.74ms
+step:578/1645 train_time:53029ms step_avg:91.75ms
+step:579/1645 train_time:53122ms step_avg:91.75ms
+step:580/1645 train_time:53215ms step_avg:91.75ms
+step:581/1645 train_time:53309ms step_avg:91.75ms
+step:582/1645 train_time:53402ms step_avg:91.76ms
+step:583/1645 train_time:53494ms step_avg:91.76ms
+step:584/1645 train_time:53588ms step_avg:91.76ms
+step:585/1645 train_time:53680ms step_avg:91.76ms
+step:586/1645 train_time:53772ms step_avg:91.76ms
+step:587/1645 train_time:53866ms step_avg:91.76ms
+step:588/1645 train_time:53959ms step_avg:91.77ms
+step:589/1645 train_time:54052ms step_avg:91.77ms
+step:590/1645 train_time:54145ms step_avg:91.77ms
+step:591/1645 train_time:54238ms step_avg:91.77ms
+step:592/1645 train_time:54331ms step_avg:91.77ms
+step:593/1645 train_time:54424ms step_avg:91.78ms
+step:594/1645 train_time:54518ms step_avg:91.78ms
+step:595/1645 train_time:54611ms step_avg:91.78ms
+step:596/1645 train_time:54703ms step_avg:91.78ms
+step:597/1645 train_time:54795ms step_avg:91.78ms
+step:598/1645 train_time:54888ms step_avg:91.79ms
+step:599/1645 train_time:54980ms step_avg:91.79ms
+step:600/1645 train_time:55073ms step_avg:91.79ms
+step:601/1645 train_time:55167ms step_avg:91.79ms
+step:602/1645 train_time:55260ms step_avg:91.79ms
+step:603/1645 train_time:55352ms step_avg:91.79ms
+step:604/1645 train_time:55446ms step_avg:91.80ms
+step:605/1645 train_time:55540ms step_avg:91.80ms
+step:606/1645 train_time:55632ms step_avg:91.80ms
+step:607/1645 train_time:55725ms step_avg:91.80ms
+step:608/1645 train_time:55818ms step_avg:91.81ms
+step:609/1645 train_time:55911ms step_avg:91.81ms
+step:610/1645 train_time:56005ms step_avg:91.81ms
+step:611/1645 train_time:56098ms step_avg:91.81ms
+step:612/1645 train_time:56191ms step_avg:91.81ms
+step:613/1645 train_time:56285ms step_avg:91.82ms
+step:614/1645 train_time:56378ms step_avg:91.82ms
+step:615/1645 train_time:56471ms step_avg:91.82ms
+step:616/1645 train_time:56563ms step_avg:91.82ms
+step:617/1645 train_time:56656ms step_avg:91.82ms
+step:618/1645 train_time:56750ms step_avg:91.83ms
+step:619/1645 train_time:56842ms step_avg:91.83ms
+step:620/1645 train_time:56935ms step_avg:91.83ms
+step:621/1645 train_time:57029ms step_avg:91.83ms
+step:622/1645 train_time:57121ms step_avg:91.84ms
+step:623/1645 train_time:57215ms step_avg:91.84ms
+step:624/1645 train_time:57308ms step_avg:91.84ms
+step:625/1645 train_time:57401ms step_avg:91.84ms
+step:625/1645 val_loss:3.6152 train_time:57494ms step_avg:91.99ms
+step:626/1645 train_time:57515ms step_avg:91.88ms
+step:627/1645 train_time:57589ms step_avg:91.85ms
+step:628/1645 train_time:57692ms step_avg:91.87ms
+step:629/1645 train_time:57788ms step_avg:91.87ms
+step:630/1645 train_time:57880ms step_avg:91.87ms
+step:631/1645 train_time:57971ms step_avg:91.87ms
+step:632/1645 train_time:58063ms step_avg:91.87ms
+step:633/1645 train_time:58155ms step_avg:91.87ms
+step:634/1645 train_time:58246ms step_avg:91.87ms
+step:635/1645 train_time:58338ms step_avg:91.87ms
+step:636/1645 train_time:58432ms step_avg:91.87ms
+step:637/1645 train_time:58526ms step_avg:91.88ms
+step:638/1645 train_time:58621ms step_avg:91.88ms
+step:639/1645 train_time:58717ms step_avg:91.89ms
+step:640/1645 train_time:58811ms step_avg:91.89ms
+step:641/1645 train_time:58904ms step_avg:91.89ms
+step:642/1645 train_time:58996ms step_avg:91.89ms
+step:643/1645 train_time:59088ms step_avg:91.89ms
+step:644/1645 train_time:59180ms step_avg:91.90ms
+step:645/1645 train_time:59273ms step_avg:91.90ms
+step:646/1645 train_time:59365ms step_avg:91.90ms
+step:647/1645 train_time:59457ms step_avg:91.90ms
+step:648/1645 train_time:59551ms step_avg:91.90ms
+step:649/1645 train_time:59645ms step_avg:91.90ms
+step:650/1645 train_time:59741ms step_avg:91.91ms
+step:651/1645 train_time:59834ms step_avg:91.91ms
+step:652/1645 train_time:59926ms step_avg:91.91ms
+step:653/1645 train_time:60019ms step_avg:91.91ms
+step:654/1645 train_time:60111ms step_avg:91.91ms
+step:655/1645 train_time:60203ms step_avg:91.91ms
+step:656/1645 train_time:60295ms step_avg:91.91ms
+step:657/1645 train_time:60387ms step_avg:91.91ms
+step:658/1645 train_time:60482ms step_avg:91.92ms
+step:659/1645 train_time:60576ms step_avg:91.92ms
+step:660/1645 train_time:60669ms step_avg:91.92ms
+step:661/1645 train_time:60762ms step_avg:91.92ms
+step:662/1645 train_time:60856ms step_avg:91.93ms
+step:663/1645 train_time:60948ms step_avg:91.93ms
+step:664/1645 train_time:61042ms step_avg:91.93ms
+step:665/1645 train_time:61135ms step_avg:91.93ms
+step:666/1645 train_time:61228ms step_avg:91.93ms
+step:667/1645 train_time:61320ms step_avg:91.93ms
+step:668/1645 train_time:61412ms step_avg:91.93ms
+step:669/1645 train_time:61505ms step_avg:91.94ms
+step:670/1645 train_time:61599ms step_avg:91.94ms
+step:671/1645 train_time:61693ms step_avg:91.94ms
+step:672/1645 train_time:61785ms step_avg:91.94ms
+step:673/1645 train_time:61878ms step_avg:91.94ms
+step:674/1645 train_time:61972ms step_avg:91.95ms
+step:675/1645 train_time:62065ms step_avg:91.95ms
+step:676/1645 train_time:62157ms step_avg:91.95ms
+step:677/1645 train_time:62249ms step_avg:91.95ms
+step:678/1645 train_time:62341ms step_avg:91.95ms
+step:679/1645 train_time:62434ms step_avg:91.95ms
+step:680/1645 train_time:62526ms step_avg:91.95ms
+step:681/1645 train_time:62619ms step_avg:91.95ms
+step:682/1645 train_time:62713ms step_avg:91.95ms
+step:683/1645 train_time:62806ms step_avg:91.96ms
+step:684/1645 train_time:62900ms step_avg:91.96ms
+step:685/1645 train_time:62994ms step_avg:91.96ms
+step:686/1645 train_time:63086ms step_avg:91.96ms
+step:687/1645 train_time:63179ms step_avg:91.96ms
+step:688/1645 train_time:63271ms step_avg:91.96ms
+step:689/1645 train_time:63363ms step_avg:91.96ms
+step:690/1645 train_time:63456ms step_avg:91.97ms
+step:691/1645 train_time:63549ms step_avg:91.97ms
+step:692/1645 train_time:63641ms step_avg:91.97ms
+step:693/1645 train_time:63735ms step_avg:91.97ms
+step:694/1645 train_time:63828ms step_avg:91.97ms
+step:695/1645 train_time:63921ms step_avg:91.97ms
+step:696/1645 train_time:64014ms step_avg:91.97ms
+step:697/1645 train_time:64107ms step_avg:91.98ms
+step:698/1645 train_time:64201ms step_avg:91.98ms
+step:699/1645 train_time:64294ms step_avg:91.98ms
+step:700/1645 train_time:64387ms step_avg:91.98ms
+step:701/1645 train_time:64480ms step_avg:91.98ms
+step:702/1645 train_time:64573ms step_avg:91.98ms
+step:703/1645 train_time:64666ms step_avg:91.99ms
+step:704/1645 train_time:64760ms step_avg:91.99ms
+step:705/1645 train_time:64853ms step_avg:91.99ms
+step:706/1645 train_time:64947ms step_avg:91.99ms
+step:707/1645 train_time:65039ms step_avg:91.99ms
+step:708/1645 train_time:65132ms step_avg:91.99ms
+step:709/1645 train_time:65225ms step_avg:92.00ms
+step:710/1645 train_time:65317ms step_avg:92.00ms
+step:711/1645 train_time:65409ms step_avg:92.00ms
+step:712/1645 train_time:65503ms step_avg:92.00ms
+step:713/1645 train_time:65596ms step_avg:92.00ms
+step:714/1645 train_time:65688ms step_avg:92.00ms
+step:715/1645 train_time:65782ms step_avg:92.00ms
+step:716/1645 train_time:65875ms step_avg:92.00ms
+step:717/1645 train_time:65968ms step_avg:92.01ms
+step:718/1645 train_time:66062ms step_avg:92.01ms
+step:719/1645 train_time:66155ms step_avg:92.01ms
+step:720/1645 train_time:66247ms step_avg:92.01ms
+step:721/1645 train_time:66340ms step_avg:92.01ms
+step:722/1645 train_time:66434ms step_avg:92.01ms
+step:723/1645 train_time:66526ms step_avg:92.01ms
+step:724/1645 train_time:66619ms step_avg:92.02ms
+step:725/1645 train_time:66712ms step_avg:92.02ms
+step:726/1645 train_time:66805ms step_avg:92.02ms
+step:727/1645 train_time:66898ms step_avg:92.02ms
+step:728/1645 train_time:66990ms step_avg:92.02ms
+step:729/1645 train_time:67084ms step_avg:92.02ms
+step:730/1645 train_time:67177ms step_avg:92.02ms
+step:731/1645 train_time:67270ms step_avg:92.02ms
+step:732/1645 train_time:67363ms step_avg:92.03ms
+step:733/1645 train_time:67457ms step_avg:92.03ms
+step:734/1645 train_time:67548ms step_avg:92.03ms
+step:735/1645 train_time:67641ms step_avg:92.03ms
+step:736/1645 train_time:67734ms step_avg:92.03ms
+step:737/1645 train_time:67827ms step_avg:92.03ms
+step:738/1645 train_time:67921ms step_avg:92.03ms
+step:739/1645 train_time:68014ms step_avg:92.03ms
+step:740/1645 train_time:68106ms step_avg:92.04ms
+step:741/1645 train_time:68199ms step_avg:92.04ms
+step:742/1645 train_time:68291ms step_avg:92.04ms
+step:743/1645 train_time:68385ms step_avg:92.04ms
+step:744/1645 train_time:68477ms step_avg:92.04ms
+step:745/1645 train_time:68570ms step_avg:92.04ms
+step:746/1645 train_time:68664ms step_avg:92.04ms
+step:747/1645 train_time:68756ms step_avg:92.04ms
+step:748/1645 train_time:68848ms step_avg:92.04ms
+step:749/1645 train_time:68941ms step_avg:92.04ms
+step:750/1645 train_time:69035ms step_avg:92.05ms
+step:750/1645 val_loss:3.5608 train_time:69127ms step_avg:92.17ms
+step:751/1645 train_time:69148ms step_avg:92.07ms
+step:752/1645 train_time:69224ms step_avg:92.05ms
+step:753/1645 train_time:69320ms step_avg:92.06ms
+step:754/1645 train_time:69413ms step_avg:92.06ms
+step:755/1645 train_time:69504ms step_avg:92.06ms
+step:756/1645 train_time:69596ms step_avg:92.06ms
+step:757/1645 train_time:69688ms step_avg:92.06ms
+step:758/1645 train_time:69780ms step_avg:92.06ms
+step:759/1645 train_time:69873ms step_avg:92.06ms
+step:760/1645 train_time:69964ms step_avg:92.06ms
+step:761/1645 train_time:70057ms step_avg:92.06ms
+step:762/1645 train_time:70151ms step_avg:92.06ms
+step:763/1645 train_time:70245ms step_avg:92.06ms
+step:764/1645 train_time:70341ms step_avg:92.07ms
+step:765/1645 train_time:70435ms step_avg:92.07ms
+step:766/1645 train_time:70527ms step_avg:92.07ms
+step:767/1645 train_time:70619ms step_avg:92.07ms
+step:768/1645 train_time:70712ms step_avg:92.07ms
+step:769/1645 train_time:70804ms step_avg:92.07ms
+step:770/1645 train_time:70896ms step_avg:92.07ms
+step:771/1645 train_time:70988ms step_avg:92.07ms
+step:772/1645 train_time:71082ms step_avg:92.08ms
+step:773/1645 train_time:71176ms step_avg:92.08ms
+step:774/1645 train_time:71270ms step_avg:92.08ms
+step:775/1645 train_time:71363ms step_avg:92.08ms
+step:776/1645 train_time:71455ms step_avg:92.08ms
+step:777/1645 train_time:71548ms step_avg:92.08ms
+step:778/1645 train_time:71641ms step_avg:92.08ms
+step:779/1645 train_time:71734ms step_avg:92.08ms
+step:780/1645 train_time:71827ms step_avg:92.09ms
+step:781/1645 train_time:71920ms step_avg:92.09ms
+step:782/1645 train_time:72013ms step_avg:92.09ms
+step:783/1645 train_time:72106ms step_avg:92.09ms
+step:784/1645 train_time:72201ms step_avg:92.09ms
+step:785/1645 train_time:72294ms step_avg:92.09ms
+step:786/1645 train_time:72387ms step_avg:92.09ms
+step:787/1645 train_time:72480ms step_avg:92.10ms
+step:788/1645 train_time:72573ms step_avg:92.10ms
+step:789/1645 train_time:72666ms step_avg:92.10ms
+step:790/1645 train_time:72758ms step_avg:92.10ms
+step:791/1645 train_time:72849ms step_avg:92.10ms
+step:792/1645 train_time:72942ms step_avg:92.10ms
+step:793/1645 train_time:73036ms step_avg:92.10ms
+step:794/1645 train_time:73128ms step_avg:92.10ms
+step:795/1645 train_time:73221ms step_avg:92.10ms
+step:796/1645 train_time:73316ms step_avg:92.11ms
+step:797/1645 train_time:73409ms step_avg:92.11ms
+step:798/1645 train_time:73503ms step_avg:92.11ms
+step:799/1645 train_time:73597ms step_avg:92.11ms
+step:800/1645 train_time:73689ms step_avg:92.11ms
+step:801/1645 train_time:73782ms step_avg:92.11ms
+step:802/1645 train_time:73875ms step_avg:92.11ms
+step:803/1645 train_time:73967ms step_avg:92.11ms
+step:804/1645 train_time:74060ms step_avg:92.11ms
+step:805/1645 train_time:74152ms step_avg:92.11ms
+step:806/1645 train_time:74245ms step_avg:92.12ms
+step:807/1645 train_time:74338ms step_avg:92.12ms
+step:808/1645 train_time:74431ms step_avg:92.12ms
+step:809/1645 train_time:74524ms step_avg:92.12ms
+step:810/1645 train_time:74617ms step_avg:92.12ms
+step:811/1645 train_time:74710ms step_avg:92.12ms
+step:812/1645 train_time:74804ms step_avg:92.12ms
+step:813/1645 train_time:74896ms step_avg:92.12ms
+step:814/1645 train_time:74989ms step_avg:92.12ms
+step:815/1645 train_time:75083ms step_avg:92.13ms
+step:816/1645 train_time:75175ms step_avg:92.13ms
+step:817/1645 train_time:75268ms step_avg:92.13ms
+step:818/1645 train_time:75360ms step_avg:92.13ms
+step:819/1645 train_time:75453ms step_avg:92.13ms
+step:820/1645 train_time:75546ms step_avg:92.13ms
+step:821/1645 train_time:75640ms step_avg:92.13ms
+step:822/1645 train_time:75734ms step_avg:92.13ms
+step:823/1645 train_time:75826ms step_avg:92.13ms
+step:824/1645 train_time:75919ms step_avg:92.13ms
+step:825/1645 train_time:76012ms step_avg:92.14ms
+step:826/1645 train_time:76105ms step_avg:92.14ms
+step:827/1645 train_time:76198ms step_avg:92.14ms
+step:828/1645 train_time:76291ms step_avg:92.14ms
+step:829/1645 train_time:76384ms step_avg:92.14ms
+step:830/1645 train_time:76477ms step_avg:92.14ms
+step:831/1645 train_time:76570ms step_avg:92.14ms
+step:832/1645 train_time:76663ms step_avg:92.14ms
+step:833/1645 train_time:76755ms step_avg:92.14ms
+step:834/1645 train_time:76848ms step_avg:92.14ms
+step:835/1645 train_time:76941ms step_avg:92.15ms
+step:836/1645 train_time:77035ms step_avg:92.15ms
+step:837/1645 train_time:77128ms step_avg:92.15ms
+step:838/1645 train_time:77220ms step_avg:92.15ms
+step:839/1645 train_time:77313ms step_avg:92.15ms
+step:840/1645 train_time:77406ms step_avg:92.15ms
+step:841/1645 train_time:77500ms step_avg:92.15ms
+step:842/1645 train_time:77592ms step_avg:92.15ms
+step:843/1645 train_time:77686ms step_avg:92.15ms
+step:844/1645 train_time:77779ms step_avg:92.15ms
+step:845/1645 train_time:77871ms step_avg:92.15ms
+step:846/1645 train_time:77963ms step_avg:92.16ms
+step:847/1645 train_time:78057ms step_avg:92.16ms
+step:848/1645 train_time:78150ms step_avg:92.16ms
+step:849/1645 train_time:78244ms step_avg:92.16ms
+step:850/1645 train_time:78336ms step_avg:92.16ms
+step:851/1645 train_time:78430ms step_avg:92.16ms
+step:852/1645 train_time:78523ms step_avg:92.16ms
+step:853/1645 train_time:78618ms step_avg:92.17ms
+step:854/1645 train_time:78711ms step_avg:92.17ms
+step:855/1645 train_time:78803ms step_avg:92.17ms
+step:856/1645 train_time:78896ms step_avg:92.17ms
+step:857/1645 train_time:78988ms step_avg:92.17ms
+step:858/1645 train_time:79081ms step_avg:92.17ms
+step:859/1645 train_time:79175ms step_avg:92.17ms
+step:860/1645 train_time:79268ms step_avg:92.17ms
+step:861/1645 train_time:79360ms step_avg:92.17ms
+step:862/1645 train_time:79453ms step_avg:92.17ms
+step:863/1645 train_time:79545ms step_avg:92.17ms
+step:864/1645 train_time:79638ms step_avg:92.17ms
+step:865/1645 train_time:79730ms step_avg:92.17ms
+step:866/1645 train_time:79824ms step_avg:92.18ms
+step:867/1645 train_time:79917ms step_avg:92.18ms
+step:868/1645 train_time:80009ms step_avg:92.18ms
+step:869/1645 train_time:80103ms step_avg:92.18ms
+step:870/1645 train_time:80196ms step_avg:92.18ms
+step:871/1645 train_time:80288ms step_avg:92.18ms
+step:872/1645 train_time:80381ms step_avg:92.18ms
+step:873/1645 train_time:80474ms step_avg:92.18ms
+step:874/1645 train_time:80567ms step_avg:92.18ms
+step:875/1645 train_time:80660ms step_avg:92.18ms
+step:875/1645 val_loss:3.5166 train_time:80752ms step_avg:92.29ms
+step:876/1645 train_time:80772ms step_avg:92.21ms
+step:877/1645 train_time:80852ms step_avg:92.19ms
+step:878/1645 train_time:80945ms step_avg:92.19ms
+step:879/1645 train_time:81038ms step_avg:92.19ms
+step:880/1645 train_time:81130ms step_avg:92.19ms
+step:881/1645 train_time:81222ms step_avg:92.19ms
+step:882/1645 train_time:81314ms step_avg:92.19ms
+step:883/1645 train_time:81406ms step_avg:92.19ms
+step:884/1645 train_time:81498ms step_avg:92.19ms
+step:885/1645 train_time:81590ms step_avg:92.19ms
+step:886/1645 train_time:81684ms step_avg:92.19ms
+step:887/1645 train_time:81778ms step_avg:92.20ms
+step:888/1645 train_time:81872ms step_avg:92.20ms
+step:889/1645 train_time:81966ms step_avg:92.20ms
+step:890/1645 train_time:82059ms step_avg:92.20ms
+step:891/1645 train_time:82152ms step_avg:92.20ms
+step:892/1645 train_time:82244ms step_avg:92.20ms
+step:893/1645 train_time:82336ms step_avg:92.20ms
+step:894/1645 train_time:82429ms step_avg:92.20ms
+step:895/1645 train_time:82521ms step_avg:92.20ms
+step:896/1645 train_time:82614ms step_avg:92.20ms
+step:897/1645 train_time:82708ms step_avg:92.21ms
+step:898/1645 train_time:82801ms step_avg:92.21ms
+step:899/1645 train_time:82894ms step_avg:92.21ms
+step:900/1645 train_time:82988ms step_avg:92.21ms
+step:901/1645 train_time:83082ms step_avg:92.21ms
+step:902/1645 train_time:83174ms step_avg:92.21ms
+step:903/1645 train_time:83267ms step_avg:92.21ms
+step:904/1645 train_time:83359ms step_avg:92.21ms
+step:905/1645 train_time:83451ms step_avg:92.21ms
+step:906/1645 train_time:83543ms step_avg:92.21ms
+step:907/1645 train_time:83637ms step_avg:92.21ms
+step:908/1645 train_time:83730ms step_avg:92.21ms
+step:909/1645 train_time:83824ms step_avg:92.22ms
+step:910/1645 train_time:83918ms step_avg:92.22ms
+step:911/1645 train_time:84012ms step_avg:92.22ms
+step:912/1645 train_time:84105ms step_avg:92.22ms
+step:913/1645 train_time:84198ms step_avg:92.22ms
+step:914/1645 train_time:84290ms step_avg:92.22ms
+step:915/1645 train_time:84383ms step_avg:92.22ms
+step:916/1645 train_time:84475ms step_avg:92.22ms
+step:917/1645 train_time:84569ms step_avg:92.22ms
+step:918/1645 train_time:84661ms step_avg:92.22ms
+step:919/1645 train_time:84754ms step_avg:92.22ms
+step:920/1645 train_time:84847ms step_avg:92.23ms
+step:921/1645 train_time:84941ms step_avg:92.23ms
+step:922/1645 train_time:85034ms step_avg:92.23ms
+step:923/1645 train_time:85127ms step_avg:92.23ms
+step:924/1645 train_time:85221ms step_avg:92.23ms
+step:925/1645 train_time:85312ms step_avg:92.23ms
+step:926/1645 train_time:85405ms step_avg:92.23ms
+step:927/1645 train_time:85497ms step_avg:92.23ms
+step:928/1645 train_time:85590ms step_avg:92.23ms
+step:929/1645 train_time:85684ms step_avg:92.23ms
+step:930/1645 train_time:85777ms step_avg:92.23ms
+step:931/1645 train_time:85870ms step_avg:92.23ms
+step:932/1645 train_time:85963ms step_avg:92.24ms
+step:933/1645 train_time:86057ms step_avg:92.24ms
+step:934/1645 train_time:86150ms step_avg:92.24ms
+step:935/1645 train_time:86243ms step_avg:92.24ms
+step:936/1645 train_time:86337ms step_avg:92.24ms
+step:937/1645 train_time:86429ms step_avg:92.24ms
+step:938/1645 train_time:86523ms step_avg:92.24ms
+step:939/1645 train_time:86615ms step_avg:92.24ms
+step:940/1645 train_time:86708ms step_avg:92.24ms
+step:941/1645 train_time:86800ms step_avg:92.24ms
+step:942/1645 train_time:86893ms step_avg:92.24ms
+step:943/1645 train_time:86987ms step_avg:92.24ms
+step:944/1645 train_time:87080ms step_avg:92.25ms
+step:945/1645 train_time:87172ms step_avg:92.25ms
+step:946/1645 train_time:87267ms step_avg:92.25ms
+step:947/1645 train_time:87360ms step_avg:92.25ms
+step:948/1645 train_time:87452ms step_avg:92.25ms
+step:949/1645 train_time:87545ms step_avg:92.25ms
+step:950/1645 train_time:87638ms step_avg:92.25ms
+step:951/1645 train_time:87731ms step_avg:92.25ms
+step:952/1645 train_time:87824ms step_avg:92.25ms
+step:953/1645 train_time:87917ms step_avg:92.25ms
+step:954/1645 train_time:88010ms step_avg:92.25ms
+step:955/1645 train_time:88103ms step_avg:92.25ms
+step:956/1645 train_time:88196ms step_avg:92.25ms
+step:957/1645 train_time:88288ms step_avg:92.26ms
+step:958/1645 train_time:88381ms step_avg:92.26ms
+step:959/1645 train_time:88474ms step_avg:92.26ms
+step:960/1645 train_time:88566ms step_avg:92.26ms
+step:961/1645 train_time:88660ms step_avg:92.26ms
+step:962/1645 train_time:88753ms step_avg:92.26ms
+step:963/1645 train_time:88846ms step_avg:92.26ms
+step:964/1645 train_time:88940ms step_avg:92.26ms
+step:965/1645 train_time:89032ms step_avg:92.26ms
+step:966/1645 train_time:89126ms step_avg:92.26ms
+step:967/1645 train_time:89219ms step_avg:92.26ms
+step:968/1645 train_time:89311ms step_avg:92.26ms
+step:969/1645 train_time:89404ms step_avg:92.26ms
+step:970/1645 train_time:89497ms step_avg:92.26ms
+step:971/1645 train_time:89590ms step_avg:92.27ms
+step:972/1645 train_time:89683ms step_avg:92.27ms
+step:973/1645 train_time:89777ms step_avg:92.27ms
+step:974/1645 train_time:89870ms step_avg:92.27ms
+step:975/1645 train_time:89963ms step_avg:92.27ms
+step:976/1645 train_time:90057ms step_avg:92.27ms
+step:977/1645 train_time:90150ms step_avg:92.27ms
+step:978/1645 train_time:90243ms step_avg:92.27ms
+step:979/1645 train_time:90337ms step_avg:92.27ms
+step:980/1645 train_time:90429ms step_avg:92.27ms
+step:981/1645 train_time:90522ms step_avg:92.27ms
+step:982/1645 train_time:90615ms step_avg:92.28ms
+step:983/1645 train_time:90707ms step_avg:92.28ms
+step:984/1645 train_time:90799ms step_avg:92.28ms
+step:985/1645 train_time:90892ms step_avg:92.28ms
+step:986/1645 train_time:90985ms step_avg:92.28ms
+step:987/1645 train_time:91078ms step_avg:92.28ms
+step:988/1645 train_time:91171ms step_avg:92.28ms
+step:989/1645 train_time:91265ms step_avg:92.28ms
+step:990/1645 train_time:91357ms step_avg:92.28ms
+step:991/1645 train_time:91451ms step_avg:92.28ms
+step:992/1645 train_time:91545ms step_avg:92.28ms
+step:993/1645 train_time:91638ms step_avg:92.28ms
+step:994/1645 train_time:91730ms step_avg:92.28ms
+step:995/1645 train_time:91824ms step_avg:92.29ms
+step:996/1645 train_time:91917ms step_avg:92.29ms
+step:997/1645 train_time:92011ms step_avg:92.29ms
+step:998/1645 train_time:92103ms step_avg:92.29ms
+step:999/1645 train_time:92196ms step_avg:92.29ms
+step:1000/1645 train_time:92289ms step_avg:92.29ms
+step:1000/1645 val_loss:3.4662 train_time:92382ms step_avg:92.38ms
+step:1001/1645 train_time:92403ms step_avg:92.31ms
+step:1002/1645 train_time:92479ms step_avg:92.29ms
+step:1003/1645 train_time:92573ms step_avg:92.30ms
+step:1004/1645 train_time:92665ms step_avg:92.30ms
+step:1005/1645 train_time:92757ms step_avg:92.30ms
+step:1006/1645 train_time:92848ms step_avg:92.29ms
+step:1007/1645 train_time:92941ms step_avg:92.29ms
+step:1008/1645 train_time:93033ms step_avg:92.29ms
+step:1009/1645 train_time:93124ms step_avg:92.29ms
+step:1010/1645 train_time:93217ms step_avg:92.29ms
+step:1011/1645 train_time:93311ms step_avg:92.30ms
+step:1012/1645 train_time:93406ms step_avg:92.30ms
+step:1013/1645 train_time:93501ms step_avg:92.30ms
+step:1014/1645 train_time:93595ms step_avg:92.30ms
+step:1015/1645 train_time:93687ms step_avg:92.30ms
+step:1016/1645 train_time:93779ms step_avg:92.30ms
+step:1017/1645 train_time:93871ms step_avg:92.30ms
+step:1018/1645 train_time:93965ms step_avg:92.30ms
+step:1019/1645 train_time:94057ms step_avg:92.30ms
+step:1020/1645 train_time:94149ms step_avg:92.30ms
+step:1021/1645 train_time:94242ms step_avg:92.30ms
+step:1022/1645 train_time:94335ms step_avg:92.30ms
+step:1023/1645 train_time:94429ms step_avg:92.31ms
+step:1024/1645 train_time:94524ms step_avg:92.31ms
+step:1025/1645 train_time:94617ms step_avg:92.31ms
+step:1026/1645 train_time:94709ms step_avg:92.31ms
+step:1027/1645 train_time:94801ms step_avg:92.31ms
+step:1028/1645 train_time:94894ms step_avg:92.31ms
+step:1029/1645 train_time:94987ms step_avg:92.31ms
+step:1030/1645 train_time:95079ms step_avg:92.31ms
+step:1031/1645 train_time:95171ms step_avg:92.31ms
+step:1032/1645 train_time:95264ms step_avg:92.31ms
+step:1033/1645 train_time:95358ms step_avg:92.31ms
+step:1034/1645 train_time:95451ms step_avg:92.31ms
+step:1035/1645 train_time:95544ms step_avg:92.31ms
+step:1036/1645 train_time:95636ms step_avg:92.31ms
+step:1037/1645 train_time:95730ms step_avg:92.31ms
+step:1038/1645 train_time:95823ms step_avg:92.31ms
+step:1039/1645 train_time:95916ms step_avg:92.32ms
+step:1040/1645 train_time:96008ms step_avg:92.32ms
+step:1041/1645 train_time:96103ms step_avg:92.32ms
+step:1042/1645 train_time:96194ms step_avg:92.32ms
+step:1043/1645 train_time:96287ms step_avg:92.32ms
+step:1044/1645 train_time:96380ms step_avg:92.32ms
+step:1045/1645 train_time:96473ms step_avg:92.32ms
+step:1046/1645 train_time:96566ms step_avg:92.32ms
+step:1047/1645 train_time:96659ms step_avg:92.32ms
+step:1048/1645 train_time:96752ms step_avg:92.32ms
+step:1049/1645 train_time:96845ms step_avg:92.32ms
+step:1050/1645 train_time:96938ms step_avg:92.32ms
+step:1051/1645 train_time:97030ms step_avg:92.32ms
+step:1052/1645 train_time:97123ms step_avg:92.32ms
+step:1053/1645 train_time:97216ms step_avg:92.32ms
+step:1054/1645 train_time:97310ms step_avg:92.32ms
+step:1055/1645 train_time:97403ms step_avg:92.32ms
+step:1056/1645 train_time:97496ms step_avg:92.33ms
+step:1057/1645 train_time:97590ms step_avg:92.33ms
+step:1058/1645 train_time:97683ms step_avg:92.33ms
+step:1059/1645 train_time:97776ms step_avg:92.33ms
+step:1060/1645 train_time:97869ms step_avg:92.33ms
+step:1061/1645 train_time:97962ms step_avg:92.33ms
+step:1062/1645 train_time:98056ms step_avg:92.33ms
+step:1063/1645 train_time:98148ms step_avg:92.33ms
+step:1064/1645 train_time:98241ms step_avg:92.33ms
+step:1065/1645 train_time:98333ms step_avg:92.33ms
+step:1066/1645 train_time:98427ms step_avg:92.33ms
+step:1067/1645 train_time:98520ms step_avg:92.33ms
+step:1068/1645 train_time:98614ms step_avg:92.33ms
+step:1069/1645 train_time:98706ms step_avg:92.33ms
+step:1070/1645 train_time:98799ms step_avg:92.34ms
+step:1071/1645 train_time:98891ms step_avg:92.34ms
+step:1072/1645 train_time:98985ms step_avg:92.34ms
+step:1073/1645 train_time:99077ms step_avg:92.34ms
+step:1074/1645 train_time:99170ms step_avg:92.34ms
+step:1075/1645 train_time:99263ms step_avg:92.34ms
+step:1076/1645 train_time:99356ms step_avg:92.34ms
+step:1077/1645 train_time:99449ms step_avg:92.34ms
+step:1078/1645 train_time:99543ms step_avg:92.34ms
+step:1079/1645 train_time:99636ms step_avg:92.34ms
+step:1080/1645 train_time:99729ms step_avg:92.34ms
+step:1081/1645 train_time:99823ms step_avg:92.34ms
+step:1082/1645 train_time:99917ms step_avg:92.34ms
+step:1083/1645 train_time:100009ms step_avg:92.34ms
+step:1084/1645 train_time:100102ms step_avg:92.34ms
+step:1085/1645 train_time:100195ms step_avg:92.35ms
+step:1086/1645 train_time:100288ms step_avg:92.35ms
+step:1087/1645 train_time:100381ms step_avg:92.35ms
+step:1088/1645 train_time:100473ms step_avg:92.35ms
+step:1089/1645 train_time:100566ms step_avg:92.35ms
+step:1090/1645 train_time:100659ms step_avg:92.35ms
+step:1091/1645 train_time:100752ms step_avg:92.35ms
+step:1092/1645 train_time:100845ms step_avg:92.35ms
+step:1093/1645 train_time:100937ms step_avg:92.35ms
+step:1094/1645 train_time:101031ms step_avg:92.35ms
+step:1095/1645 train_time:101123ms step_avg:92.35ms
+step:1096/1645 train_time:101216ms step_avg:92.35ms
+step:1097/1645 train_time:101308ms step_avg:92.35ms
+step:1098/1645 train_time:101402ms step_avg:92.35ms
+step:1099/1645 train_time:101496ms step_avg:92.35ms
+step:1100/1645 train_time:101589ms step_avg:92.35ms
+step:1101/1645 train_time:101683ms step_avg:92.36ms
+step:1102/1645 train_time:101777ms step_avg:92.36ms
+step:1103/1645 train_time:101870ms step_avg:92.36ms
+step:1104/1645 train_time:101964ms step_avg:92.36ms
+step:1105/1645 train_time:102057ms step_avg:92.36ms
+step:1106/1645 train_time:102150ms step_avg:92.36ms
+step:1107/1645 train_time:102243ms step_avg:92.36ms
+step:1108/1645 train_time:102337ms step_avg:92.36ms
+step:1109/1645 train_time:102431ms step_avg:92.36ms
+step:1110/1645 train_time:102525ms step_avg:92.36ms
+step:1111/1645 train_time:102619ms step_avg:92.37ms
+step:1112/1645 train_time:102712ms step_avg:92.37ms
+step:1113/1645 train_time:102806ms step_avg:92.37ms
+step:1114/1645 train_time:102900ms step_avg:92.37ms
+step:1115/1645 train_time:102994ms step_avg:92.37ms
+step:1116/1645 train_time:103087ms step_avg:92.37ms
+step:1117/1645 train_time:103180ms step_avg:92.37ms
+step:1118/1645 train_time:103273ms step_avg:92.37ms
+step:1119/1645 train_time:103367ms step_avg:92.37ms
+step:1120/1645 train_time:103461ms step_avg:92.38ms
+step:1121/1645 train_time:103554ms step_avg:92.38ms
+step:1122/1645 train_time:103649ms step_avg:92.38ms
+step:1123/1645 train_time:103743ms step_avg:92.38ms
+step:1124/1645 train_time:103836ms step_avg:92.38ms
+step:1125/1645 train_time:103930ms step_avg:92.38ms
+step:1125/1645 val_loss:3.4134 train_time:104024ms step_avg:92.47ms
+step:1126/1645 train_time:104039ms step_avg:92.40ms
+step:1127/1645 train_time:104123ms step_avg:92.39ms
+step:1128/1645 train_time:104227ms step_avg:92.40ms
+step:1129/1645 train_time:104323ms step_avg:92.40ms
+step:1130/1645 train_time:104415ms step_avg:92.40ms
+step:1131/1645 train_time:104508ms step_avg:92.40ms
+step:1132/1645 train_time:104600ms step_avg:92.40ms
+step:1133/1645 train_time:104693ms step_avg:92.40ms
+step:1134/1645 train_time:104785ms step_avg:92.40ms
+step:1135/1645 train_time:104877ms step_avg:92.40ms
+step:1136/1645 train_time:104972ms step_avg:92.41ms
+step:1137/1645 train_time:105067ms step_avg:92.41ms
+step:1138/1645 train_time:105164ms step_avg:92.41ms
+step:1139/1645 train_time:105260ms step_avg:92.41ms
+step:1140/1645 train_time:105354ms step_avg:92.42ms
+step:1141/1645 train_time:105447ms step_avg:92.42ms
+step:1142/1645 train_time:105541ms step_avg:92.42ms
+step:1143/1645 train_time:105633ms step_avg:92.42ms
+step:1144/1645 train_time:105726ms step_avg:92.42ms
+step:1145/1645 train_time:105819ms step_avg:92.42ms
+step:1146/1645 train_time:105912ms step_avg:92.42ms
+step:1147/1645 train_time:106006ms step_avg:92.42ms
+step:1148/1645 train_time:106100ms step_avg:92.42ms
+step:1149/1645 train_time:106195ms step_avg:92.42ms
+step:1150/1645 train_time:106289ms step_avg:92.43ms
+step:1151/1645 train_time:106382ms step_avg:92.43ms
+step:1152/1645 train_time:106476ms step_avg:92.43ms
+step:1153/1645 train_time:106569ms step_avg:92.43ms
+step:1154/1645 train_time:106662ms step_avg:92.43ms
+step:1155/1645 train_time:106755ms step_avg:92.43ms
+step:1156/1645 train_time:106848ms step_avg:92.43ms
+step:1157/1645 train_time:106942ms step_avg:92.43ms
+step:1158/1645 train_time:107036ms step_avg:92.43ms
+step:1159/1645 train_time:107130ms step_avg:92.43ms
+step:1160/1645 train_time:107225ms step_avg:92.43ms
+step:1161/1645 train_time:107319ms step_avg:92.44ms
+step:1162/1645 train_time:107413ms step_avg:92.44ms
+step:1163/1645 train_time:107507ms step_avg:92.44ms
+step:1164/1645 train_time:107601ms step_avg:92.44ms
+step:1165/1645 train_time:107693ms step_avg:92.44ms
+step:1166/1645 train_time:107786ms step_avg:92.44ms
+step:1167/1645 train_time:107880ms step_avg:92.44ms
+step:1168/1645 train_time:107974ms step_avg:92.44ms
+step:1169/1645 train_time:108068ms step_avg:92.44ms
+step:1170/1645 train_time:108162ms step_avg:92.45ms
+step:1171/1645 train_time:108256ms step_avg:92.45ms
+step:1172/1645 train_time:108349ms step_avg:92.45ms
+step:1173/1645 train_time:108444ms step_avg:92.45ms
+step:1174/1645 train_time:108537ms step_avg:92.45ms
+step:1175/1645 train_time:108631ms step_avg:92.45ms
+step:1176/1645 train_time:108724ms step_avg:92.45ms
+step:1177/1645 train_time:108817ms step_avg:92.45ms
+step:1178/1645 train_time:108911ms step_avg:92.45ms
+step:1179/1645 train_time:109004ms step_avg:92.45ms
+step:1180/1645 train_time:109098ms step_avg:92.46ms
+step:1181/1645 train_time:109191ms step_avg:92.46ms
+step:1182/1645 train_time:109285ms step_avg:92.46ms
+step:1183/1645 train_time:109379ms step_avg:92.46ms
+step:1184/1645 train_time:109473ms step_avg:92.46ms
+step:1185/1645 train_time:109567ms step_avg:92.46ms
+step:1186/1645 train_time:109661ms step_avg:92.46ms
+step:1187/1645 train_time:109754ms step_avg:92.46ms
+step:1188/1645 train_time:109849ms step_avg:92.47ms
+step:1189/1645 train_time:109942ms step_avg:92.47ms
+step:1190/1645 train_time:110036ms step_avg:92.47ms
+step:1191/1645 train_time:110130ms step_avg:92.47ms
+step:1192/1645 train_time:110223ms step_avg:92.47ms
+step:1193/1645 train_time:110317ms step_avg:92.47ms
+step:1194/1645 train_time:110411ms step_avg:92.47ms
+step:1195/1645 train_time:110506ms step_avg:92.47ms
+step:1196/1645 train_time:110599ms step_avg:92.47ms
+step:1197/1645 train_time:110692ms step_avg:92.47ms
+step:1198/1645 train_time:110786ms step_avg:92.48ms
+step:1199/1645 train_time:110879ms step_avg:92.48ms
+step:1200/1645 train_time:110973ms step_avg:92.48ms
+step:1201/1645 train_time:111066ms step_avg:92.48ms
+step:1202/1645 train_time:111159ms step_avg:92.48ms
+step:1203/1645 train_time:111253ms step_avg:92.48ms
+step:1204/1645 train_time:111347ms step_avg:92.48ms
+step:1205/1645 train_time:111441ms step_avg:92.48ms
+step:1206/1645 train_time:111535ms step_avg:92.48ms
+step:1207/1645 train_time:111629ms step_avg:92.48ms
+step:1208/1645 train_time:111722ms step_avg:92.49ms
+step:1209/1645 train_time:111816ms step_avg:92.49ms
+step:1210/1645 train_time:111909ms step_avg:92.49ms
+step:1211/1645 train_time:112002ms step_avg:92.49ms
+step:1212/1645 train_time:112095ms step_avg:92.49ms
+step:1213/1645 train_time:112189ms step_avg:92.49ms
+step:1214/1645 train_time:112284ms step_avg:92.49ms
+step:1215/1645 train_time:112377ms step_avg:92.49ms
+step:1216/1645 train_time:112471ms step_avg:92.49ms
+step:1217/1645 train_time:112565ms step_avg:92.49ms
+step:1218/1645 train_time:112658ms step_avg:92.49ms
+step:1219/1645 train_time:112752ms step_avg:92.50ms
+step:1220/1645 train_time:112846ms step_avg:92.50ms
+step:1221/1645 train_time:112939ms step_avg:92.50ms
+step:1222/1645 train_time:113032ms step_avg:92.50ms
+step:1223/1645 train_time:113126ms step_avg:92.50ms
+step:1224/1645 train_time:113220ms step_avg:92.50ms
+step:1225/1645 train_time:113314ms step_avg:92.50ms
+step:1226/1645 train_time:113407ms step_avg:92.50ms
+step:1227/1645 train_time:113502ms step_avg:92.50ms
+step:1228/1645 train_time:113595ms step_avg:92.50ms
+step:1229/1645 train_time:113688ms step_avg:92.50ms
+step:1230/1645 train_time:113782ms step_avg:92.51ms
+step:1231/1645 train_time:113875ms step_avg:92.51ms
+step:1232/1645 train_time:113969ms step_avg:92.51ms
+step:1233/1645 train_time:114062ms step_avg:92.51ms
+step:1234/1645 train_time:114156ms step_avg:92.51ms
+step:1235/1645 train_time:114250ms step_avg:92.51ms
+step:1236/1645 train_time:114343ms step_avg:92.51ms
+step:1237/1645 train_time:114437ms step_avg:92.51ms
+step:1238/1645 train_time:114530ms step_avg:92.51ms
+step:1239/1645 train_time:114624ms step_avg:92.51ms
+step:1240/1645 train_time:114716ms step_avg:92.51ms
+step:1241/1645 train_time:114810ms step_avg:92.51ms
+step:1242/1645 train_time:114903ms step_avg:92.51ms
+step:1243/1645 train_time:114996ms step_avg:92.52ms
+step:1244/1645 train_time:115090ms step_avg:92.52ms
+step:1245/1645 train_time:115184ms step_avg:92.52ms
+step:1246/1645 train_time:115278ms step_avg:92.52ms
+step:1247/1645 train_time:115372ms step_avg:92.52ms
+step:1248/1645 train_time:115467ms step_avg:92.52ms
+step:1249/1645 train_time:115560ms step_avg:92.52ms
+step:1250/1645 train_time:115654ms step_avg:92.52ms
+step:1250/1645 val_loss:3.3744 train_time:115749ms step_avg:92.60ms
+step:1251/1645 train_time:115769ms step_avg:92.54ms
+step:1252/1645 train_time:115848ms step_avg:92.53ms
+step:1253/1645 train_time:115942ms step_avg:92.53ms
+step:1254/1645 train_time:116036ms step_avg:92.53ms
+step:1255/1645 train_time:116127ms step_avg:92.53ms
+step:1256/1645 train_time:116221ms step_avg:92.53ms
+step:1257/1645 train_time:116313ms step_avg:92.53ms
+step:1258/1645 train_time:116405ms step_avg:92.53ms
+step:1259/1645 train_time:116498ms step_avg:92.53ms
+step:1260/1645 train_time:116590ms step_avg:92.53ms
+step:1261/1645 train_time:116685ms step_avg:92.53ms
+step:1262/1645 train_time:116783ms step_avg:92.54ms
+step:1263/1645 train_time:116879ms step_avg:92.54ms
+step:1264/1645 train_time:116972ms step_avg:92.54ms
+step:1265/1645 train_time:117066ms step_avg:92.54ms
+step:1266/1645 train_time:117158ms step_avg:92.54ms
+step:1267/1645 train_time:117251ms step_avg:92.54ms
+step:1268/1645 train_time:117345ms step_avg:92.54ms
+step:1269/1645 train_time:117438ms step_avg:92.54ms
+step:1270/1645 train_time:117531ms step_avg:92.54ms
+step:1271/1645 train_time:117624ms step_avg:92.54ms
+step:1272/1645 train_time:117718ms step_avg:92.55ms
+step:1273/1645 train_time:117812ms step_avg:92.55ms
+step:1274/1645 train_time:117907ms step_avg:92.55ms
+step:1275/1645 train_time:118002ms step_avg:92.55ms
+step:1276/1645 train_time:118096ms step_avg:92.55ms
+step:1277/1645 train_time:118189ms step_avg:92.55ms
+step:1278/1645 train_time:118282ms step_avg:92.55ms
+step:1279/1645 train_time:118375ms step_avg:92.55ms
+step:1280/1645 train_time:118470ms step_avg:92.55ms
+step:1281/1645 train_time:118562ms step_avg:92.55ms
+step:1282/1645 train_time:118656ms step_avg:92.56ms
+step:1283/1645 train_time:118749ms step_avg:92.56ms
+step:1284/1645 train_time:118844ms step_avg:92.56ms
+step:1285/1645 train_time:118938ms step_avg:92.56ms
+step:1286/1645 train_time:119032ms step_avg:92.56ms
+step:1287/1645 train_time:119126ms step_avg:92.56ms
+step:1288/1645 train_time:119219ms step_avg:92.56ms
+step:1289/1645 train_time:119313ms step_avg:92.56ms
+step:1290/1645 train_time:119406ms step_avg:92.56ms
+step:1291/1645 train_time:119499ms step_avg:92.56ms
+step:1292/1645 train_time:119592ms step_avg:92.56ms
+step:1293/1645 train_time:119686ms step_avg:92.56ms
+step:1294/1645 train_time:119780ms step_avg:92.57ms
+step:1295/1645 train_time:119874ms step_avg:92.57ms
+step:1296/1645 train_time:119969ms step_avg:92.57ms
+step:1297/1645 train_time:120063ms step_avg:92.57ms
+step:1298/1645 train_time:120158ms step_avg:92.57ms
+step:1299/1645 train_time:120252ms step_avg:92.57ms
+step:1300/1645 train_time:120345ms step_avg:92.57ms
+step:1301/1645 train_time:120438ms step_avg:92.57ms
+step:1302/1645 train_time:120530ms step_avg:92.57ms
+step:1303/1645 train_time:120624ms step_avg:92.57ms
+step:1304/1645 train_time:120717ms step_avg:92.57ms
+step:1305/1645 train_time:120810ms step_avg:92.57ms
+step:1306/1645 train_time:120905ms step_avg:92.58ms
+step:1307/1645 train_time:120998ms step_avg:92.58ms
+step:1308/1645 train_time:121090ms step_avg:92.58ms
+step:1309/1645 train_time:121184ms step_avg:92.58ms
+step:1310/1645 train_time:121278ms step_avg:92.58ms
+step:1311/1645 train_time:121371ms step_avg:92.58ms
+step:1312/1645 train_time:121465ms step_avg:92.58ms
+step:1313/1645 train_time:121558ms step_avg:92.58ms
+step:1314/1645 train_time:121652ms step_avg:92.58ms
+step:1315/1645 train_time:121745ms step_avg:92.58ms
+step:1316/1645 train_time:121839ms step_avg:92.58ms
+step:1317/1645 train_time:121931ms step_avg:92.58ms
+step:1318/1645 train_time:122025ms step_avg:92.58ms
+step:1319/1645 train_time:122119ms step_avg:92.58ms
+step:1320/1645 train_time:122213ms step_avg:92.59ms
+step:1321/1645 train_time:122307ms step_avg:92.59ms
+step:1322/1645 train_time:122401ms step_avg:92.59ms
+step:1323/1645 train_time:122496ms step_avg:92.59ms
+step:1324/1645 train_time:122589ms step_avg:92.59ms
+step:1325/1645 train_time:122682ms step_avg:92.59ms
+step:1326/1645 train_time:122775ms step_avg:92.59ms
+step:1327/1645 train_time:122869ms step_avg:92.59ms
+step:1328/1645 train_time:122962ms step_avg:92.59ms
+step:1329/1645 train_time:123055ms step_avg:92.59ms
+step:1330/1645 train_time:123150ms step_avg:92.59ms
+step:1331/1645 train_time:123244ms step_avg:92.59ms
+step:1332/1645 train_time:123337ms step_avg:92.60ms
+step:1333/1645 train_time:123430ms step_avg:92.60ms
+step:1334/1645 train_time:123523ms step_avg:92.60ms
+step:1335/1645 train_time:123617ms step_avg:92.60ms
+step:1336/1645 train_time:123710ms step_avg:92.60ms
+step:1337/1645 train_time:123805ms step_avg:92.60ms
+step:1338/1645 train_time:123898ms step_avg:92.60ms
+step:1339/1645 train_time:123991ms step_avg:92.60ms
+step:1340/1645 train_time:124087ms step_avg:92.60ms
+step:1341/1645 train_time:124180ms step_avg:92.60ms
+step:1342/1645 train_time:124274ms step_avg:92.60ms
+step:1343/1645 train_time:124368ms step_avg:92.60ms
+step:1344/1645 train_time:124462ms step_avg:92.61ms
+step:1345/1645 train_time:124556ms step_avg:92.61ms
+step:1346/1645 train_time:124649ms step_avg:92.61ms
+step:1347/1645 train_time:124742ms step_avg:92.61ms
+step:1348/1645 train_time:124836ms step_avg:92.61ms
+step:1349/1645 train_time:124929ms step_avg:92.61ms
+step:1350/1645 train_time:125023ms step_avg:92.61ms
+step:1351/1645 train_time:125117ms step_avg:92.61ms
+step:1352/1645 train_time:125210ms step_avg:92.61ms
+step:1353/1645 train_time:125304ms step_avg:92.61ms
+step:1354/1645 train_time:125397ms step_avg:92.61ms
+step:1355/1645 train_time:125492ms step_avg:92.61ms
+step:1356/1645 train_time:125586ms step_avg:92.61ms
+step:1357/1645 train_time:125679ms step_avg:92.62ms
+step:1358/1645 train_time:125772ms step_avg:92.62ms
+step:1359/1645 train_time:125865ms step_avg:92.62ms
+step:1360/1645 train_time:125958ms step_avg:92.62ms
+step:1361/1645 train_time:126052ms step_avg:92.62ms
+step:1362/1645 train_time:126146ms step_avg:92.62ms
+step:1363/1645 train_time:126240ms step_avg:92.62ms
+step:1364/1645 train_time:126333ms step_avg:92.62ms
+step:1365/1645 train_time:126426ms step_avg:92.62ms
+step:1366/1645 train_time:126519ms step_avg:92.62ms
+step:1367/1645 train_time:126614ms step_avg:92.62ms
+step:1368/1645 train_time:126707ms step_avg:92.62ms
+step:1369/1645 train_time:126801ms step_avg:92.62ms
+step:1370/1645 train_time:126896ms step_avg:92.62ms
+step:1371/1645 train_time:126988ms step_avg:92.62ms
+step:1372/1645 train_time:127082ms step_avg:92.63ms
+step:1373/1645 train_time:127176ms step_avg:92.63ms
+step:1374/1645 train_time:127269ms step_avg:92.63ms
+step:1375/1645 train_time:127364ms step_avg:92.63ms
+step:1375/1645 val_loss:3.3396 train_time:127457ms step_avg:92.70ms
+step:1376/1645 train_time:127478ms step_avg:92.64ms
+step:1377/1645 train_time:127556ms step_avg:92.63ms
+step:1378/1645 train_time:127652ms step_avg:92.64ms
+step:1379/1645 train_time:127746ms step_avg:92.64ms
+step:1380/1645 train_time:127839ms step_avg:92.64ms
+step:1381/1645 train_time:127933ms step_avg:92.64ms
+step:1382/1645 train_time:128024ms step_avg:92.64ms
+step:1383/1645 train_time:128116ms step_avg:92.64ms
+step:1384/1645 train_time:128210ms step_avg:92.64ms
+step:1385/1645 train_time:128303ms step_avg:92.64ms
+step:1386/1645 train_time:128398ms step_avg:92.64ms
+step:1387/1645 train_time:128494ms step_avg:92.64ms
+step:1388/1645 train_time:128589ms step_avg:92.64ms
+step:1389/1645 train_time:128683ms step_avg:92.64ms
+step:1390/1645 train_time:128777ms step_avg:92.65ms
+step:1391/1645 train_time:128871ms step_avg:92.65ms
+step:1392/1645 train_time:128964ms step_avg:92.65ms
+step:1393/1645 train_time:129057ms step_avg:92.65ms
+step:1394/1645 train_time:129150ms step_avg:92.65ms
+step:1395/1645 train_time:129242ms step_avg:92.65ms
+step:1396/1645 train_time:129337ms step_avg:92.65ms
+step:1397/1645 train_time:129431ms step_avg:92.65ms
+step:1398/1645 train_time:129524ms step_avg:92.65ms
+step:1399/1645 train_time:129620ms step_avg:92.65ms
+step:1400/1645 train_time:129714ms step_avg:92.65ms
+step:1401/1645 train_time:129807ms step_avg:92.65ms
+step:1402/1645 train_time:129901ms step_avg:92.65ms
+step:1403/1645 train_time:129994ms step_avg:92.65ms
+step:1404/1645 train_time:130087ms step_avg:92.65ms
+step:1405/1645 train_time:130181ms step_avg:92.66ms
+step:1406/1645 train_time:130274ms step_avg:92.66ms
+step:1407/1645 train_time:130366ms step_avg:92.66ms
+step:1408/1645 train_time:130461ms step_avg:92.66ms
+step:1409/1645 train_time:130556ms step_avg:92.66ms
+step:1410/1645 train_time:130650ms step_avg:92.66ms
+step:1411/1645 train_time:130744ms step_avg:92.66ms
+step:1412/1645 train_time:130838ms step_avg:92.66ms
+step:1413/1645 train_time:130931ms step_avg:92.66ms
+step:1414/1645 train_time:131023ms step_avg:92.66ms
+step:1415/1645 train_time:131116ms step_avg:92.66ms
+step:1416/1645 train_time:131209ms step_avg:92.66ms
+step:1417/1645 train_time:131302ms step_avg:92.66ms
+step:1418/1645 train_time:131395ms step_avg:92.66ms
+step:1419/1645 train_time:131489ms step_avg:92.66ms
+step:1420/1645 train_time:131583ms step_avg:92.66ms
+step:1421/1645 train_time:131678ms step_avg:92.67ms
+step:1422/1645 train_time:131772ms step_avg:92.67ms
+step:1423/1645 train_time:131865ms step_avg:92.67ms
+step:1424/1645 train_time:131959ms step_avg:92.67ms
+step:1425/1645 train_time:132053ms step_avg:92.67ms
+step:1426/1645 train_time:132146ms step_avg:92.67ms
+step:1427/1645 train_time:132239ms step_avg:92.67ms
+step:1428/1645 train_time:132332ms step_avg:92.67ms
+step:1429/1645 train_time:132425ms step_avg:92.67ms
+step:1430/1645 train_time:132520ms step_avg:92.67ms
+step:1431/1645 train_time:132614ms step_avg:92.67ms
+step:1432/1645 train_time:132708ms step_avg:92.67ms
+step:1433/1645 train_time:132802ms step_avg:92.67ms
+step:1434/1645 train_time:132896ms step_avg:92.67ms
+step:1435/1645 train_time:132990ms step_avg:92.68ms
+step:1436/1645 train_time:133083ms step_avg:92.68ms
+step:1437/1645 train_time:133176ms step_avg:92.68ms
+step:1438/1645 train_time:133269ms step_avg:92.68ms
+step:1439/1645 train_time:133363ms step_avg:92.68ms
+step:1440/1645 train_time:133458ms step_avg:92.68ms
+step:1441/1645 train_time:133550ms step_avg:92.68ms
+step:1442/1645 train_time:133644ms step_avg:92.68ms
+step:1443/1645 train_time:133738ms step_avg:92.68ms
+step:1444/1645 train_time:133831ms step_avg:92.68ms
+step:1445/1645 train_time:133925ms step_avg:92.68ms
+step:1446/1645 train_time:134019ms step_avg:92.68ms
+step:1447/1645 train_time:134113ms step_avg:92.68ms
+step:1448/1645 train_time:134206ms step_avg:92.68ms
+step:1449/1645 train_time:134300ms step_avg:92.68ms
+step:1450/1645 train_time:134393ms step_avg:92.68ms
+step:1451/1645 train_time:134487ms step_avg:92.69ms
+step:1452/1645 train_time:134581ms step_avg:92.69ms
+step:1453/1645 train_time:134675ms step_avg:92.69ms
+step:1454/1645 train_time:134769ms step_avg:92.69ms
+step:1455/1645 train_time:134864ms step_avg:92.69ms
+step:1456/1645 train_time:134958ms step_avg:92.69ms
+step:1457/1645 train_time:135052ms
step_avg:92.69ms +step:1458/1645 train_time:135145ms step_avg:92.69ms +step:1459/1645 train_time:135239ms step_avg:92.69ms +step:1460/1645 train_time:135332ms step_avg:92.69ms +step:1461/1645 train_time:135426ms step_avg:92.69ms +step:1462/1645 train_time:135521ms step_avg:92.70ms +step:1463/1645 train_time:135615ms step_avg:92.70ms +step:1464/1645 train_time:135709ms step_avg:92.70ms +step:1465/1645 train_time:135803ms step_avg:92.70ms +step:1466/1645 train_time:135897ms step_avg:92.70ms +step:1467/1645 train_time:135991ms step_avg:92.70ms +step:1468/1645 train_time:136085ms step_avg:92.70ms +step:1469/1645 train_time:136178ms step_avg:92.70ms +step:1470/1645 train_time:136272ms step_avg:92.70ms +step:1471/1645 train_time:136365ms step_avg:92.70ms +step:1472/1645 train_time:136458ms step_avg:92.70ms +step:1473/1645 train_time:136552ms step_avg:92.70ms +step:1474/1645 train_time:136645ms step_avg:92.70ms +step:1475/1645 train_time:136739ms step_avg:92.70ms +step:1476/1645 train_time:136834ms step_avg:92.71ms +step:1477/1645 train_time:136927ms step_avg:92.71ms +step:1478/1645 train_time:137020ms step_avg:92.71ms +step:1479/1645 train_time:137115ms step_avg:92.71ms +step:1480/1645 train_time:137209ms step_avg:92.71ms +step:1481/1645 train_time:137303ms step_avg:92.71ms +step:1482/1645 train_time:137396ms step_avg:92.71ms +step:1483/1645 train_time:137490ms step_avg:92.71ms +step:1484/1645 train_time:137584ms step_avg:92.71ms +step:1485/1645 train_time:137678ms step_avg:92.71ms +step:1486/1645 train_time:137771ms step_avg:92.71ms +step:1487/1645 train_time:137865ms step_avg:92.71ms +step:1488/1645 train_time:137959ms step_avg:92.71ms +step:1489/1645 train_time:138052ms step_avg:92.71ms +step:1490/1645 train_time:138145ms step_avg:92.71ms +step:1491/1645 train_time:138239ms step_avg:92.72ms +step:1492/1645 train_time:138332ms step_avg:92.72ms +step:1493/1645 train_time:138425ms step_avg:92.72ms +step:1494/1645 train_time:138519ms step_avg:92.72ms +step:1495/1645 train_time:138613ms step_avg:92.72ms +step:1496/1645 train_time:138707ms step_avg:92.72ms +step:1497/1645 train_time:138801ms step_avg:92.72ms +step:1498/1645 train_time:138894ms step_avg:92.72ms +step:1499/1645 train_time:138987ms step_avg:92.72ms +step:1500/1645 train_time:139081ms step_avg:92.72ms +step:1500/1645 val_loss:3.3100 train_time:139176ms step_avg:92.78ms +step:1501/1645 train_time:139196ms step_avg:92.74ms +step:1502/1645 train_time:139274ms step_avg:92.73ms +step:1503/1645 train_time:139370ms step_avg:92.73ms +step:1504/1645 train_time:139463ms step_avg:92.73ms +step:1505/1645 train_time:139556ms step_avg:92.73ms +step:1506/1645 train_time:139649ms step_avg:92.73ms +step:1507/1645 train_time:139741ms step_avg:92.73ms +step:1508/1645 train_time:139835ms step_avg:92.73ms +step:1509/1645 train_time:139927ms step_avg:92.73ms +step:1510/1645 train_time:140020ms step_avg:92.73ms +step:1511/1645 train_time:140114ms step_avg:92.73ms +step:1512/1645 train_time:140210ms step_avg:92.73ms +step:1513/1645 train_time:140306ms step_avg:92.73ms +step:1514/1645 train_time:140401ms step_avg:92.73ms +step:1515/1645 train_time:140493ms step_avg:92.73ms +step:1516/1645 train_time:140587ms step_avg:92.74ms +step:1517/1645 train_time:140680ms step_avg:92.74ms +step:1518/1645 train_time:140773ms step_avg:92.74ms +step:1519/1645 train_time:140866ms step_avg:92.74ms +step:1520/1645 train_time:140959ms step_avg:92.74ms +step:1521/1645 train_time:141053ms step_avg:92.74ms +step:1522/1645 train_time:141148ms step_avg:92.74ms +step:1523/1645 
train_time:141243ms step_avg:92.74ms +step:1524/1645 train_time:141338ms step_avg:92.74ms +step:1525/1645 train_time:141433ms step_avg:92.74ms +step:1526/1645 train_time:141526ms step_avg:92.74ms +step:1527/1645 train_time:141619ms step_avg:92.74ms +step:1528/1645 train_time:141712ms step_avg:92.74ms +step:1529/1645 train_time:141805ms step_avg:92.74ms +step:1530/1645 train_time:141897ms step_avg:92.74ms +step:1531/1645 train_time:141990ms step_avg:92.74ms +step:1532/1645 train_time:142083ms step_avg:92.74ms +step:1533/1645 train_time:142177ms step_avg:92.74ms +step:1534/1645 train_time:142271ms step_avg:92.75ms +step:1535/1645 train_time:142366ms step_avg:92.75ms +step:1536/1645 train_time:142460ms step_avg:92.75ms +step:1537/1645 train_time:142554ms step_avg:92.75ms +step:1538/1645 train_time:142648ms step_avg:92.75ms +step:1539/1645 train_time:142741ms step_avg:92.75ms +step:1540/1645 train_time:142835ms step_avg:92.75ms +step:1541/1645 train_time:142928ms step_avg:92.75ms +step:1542/1645 train_time:143022ms step_avg:92.75ms +step:1543/1645 train_time:143116ms step_avg:92.75ms +step:1544/1645 train_time:143210ms step_avg:92.75ms +step:1545/1645 train_time:143305ms step_avg:92.75ms +step:1546/1645 train_time:143398ms step_avg:92.75ms +step:1547/1645 train_time:143492ms step_avg:92.75ms +step:1548/1645 train_time:143586ms step_avg:92.76ms +step:1549/1645 train_time:143679ms step_avg:92.76ms +step:1550/1645 train_time:143773ms step_avg:92.76ms +step:1551/1645 train_time:143866ms step_avg:92.76ms +step:1552/1645 train_time:143958ms step_avg:92.76ms +step:1553/1645 train_time:144053ms step_avg:92.76ms +step:1554/1645 train_time:144148ms step_avg:92.76ms +step:1555/1645 train_time:144242ms step_avg:92.76ms +step:1556/1645 train_time:144337ms step_avg:92.76ms +step:1557/1645 train_time:144430ms step_avg:92.76ms +step:1558/1645 train_time:144524ms step_avg:92.76ms +step:1559/1645 train_time:144618ms step_avg:92.76ms +step:1560/1645 train_time:144711ms step_avg:92.76ms +step:1561/1645 train_time:144805ms step_avg:92.76ms +step:1562/1645 train_time:144898ms step_avg:92.76ms +step:1563/1645 train_time:144991ms step_avg:92.76ms +step:1564/1645 train_time:145085ms step_avg:92.77ms +step:1565/1645 train_time:145178ms step_avg:92.77ms +step:1566/1645 train_time:145272ms step_avg:92.77ms +step:1567/1645 train_time:145366ms step_avg:92.77ms +step:1568/1645 train_time:145459ms step_avg:92.77ms +step:1569/1645 train_time:145553ms step_avg:92.77ms +step:1570/1645 train_time:145648ms step_avg:92.77ms +step:1571/1645 train_time:145741ms step_avg:92.77ms +step:1572/1645 train_time:145834ms step_avg:92.77ms +step:1573/1645 train_time:145927ms step_avg:92.77ms +step:1574/1645 train_time:146022ms step_avg:92.77ms +step:1575/1645 train_time:146116ms step_avg:92.77ms +step:1576/1645 train_time:146209ms step_avg:92.77ms +step:1577/1645 train_time:146303ms step_avg:92.77ms +step:1578/1645 train_time:146397ms step_avg:92.77ms +step:1579/1645 train_time:146491ms step_avg:92.77ms +step:1580/1645 train_time:146585ms step_avg:92.78ms +step:1581/1645 train_time:146680ms step_avg:92.78ms +step:1582/1645 train_time:146773ms step_avg:92.78ms +step:1583/1645 train_time:146867ms step_avg:92.78ms +step:1584/1645 train_time:146961ms step_avg:92.78ms +step:1585/1645 train_time:147054ms step_avg:92.78ms +step:1586/1645 train_time:147149ms step_avg:92.78ms +step:1587/1645 train_time:147243ms step_avg:92.78ms +step:1588/1645 train_time:147337ms step_avg:92.78ms +step:1589/1645 train_time:147431ms step_avg:92.78ms +step:1590/1645 
train_time:147524ms step_avg:92.78ms +step:1591/1645 train_time:147617ms step_avg:92.78ms +step:1592/1645 train_time:147711ms step_avg:92.78ms +step:1593/1645 train_time:147804ms step_avg:92.78ms +step:1594/1645 train_time:147897ms step_avg:92.78ms +step:1595/1645 train_time:147990ms step_avg:92.78ms +step:1596/1645 train_time:148084ms step_avg:92.78ms +step:1597/1645 train_time:148177ms step_avg:92.78ms +step:1598/1645 train_time:148271ms step_avg:92.79ms +step:1599/1645 train_time:148364ms step_avg:92.79ms +step:1600/1645 train_time:148457ms step_avg:92.79ms +step:1601/1645 train_time:148553ms step_avg:92.79ms +step:1602/1645 train_time:148646ms step_avg:92.79ms +step:1603/1645 train_time:148740ms step_avg:92.79ms +step:1604/1645 train_time:148833ms step_avg:92.79ms +step:1605/1645 train_time:148927ms step_avg:92.79ms +step:1606/1645 train_time:149020ms step_avg:92.79ms +step:1607/1645 train_time:149114ms step_avg:92.79ms +step:1608/1645 train_time:149208ms step_avg:92.79ms +step:1609/1645 train_time:149301ms step_avg:92.79ms +step:1610/1645 train_time:149395ms step_avg:92.79ms +step:1611/1645 train_time:149489ms step_avg:92.79ms +step:1612/1645 train_time:149583ms step_avg:92.79ms +step:1613/1645 train_time:149677ms step_avg:92.79ms +step:1614/1645 train_time:149771ms step_avg:92.79ms +step:1615/1645 train_time:149865ms step_avg:92.80ms +step:1616/1645 train_time:149959ms step_avg:92.80ms +step:1617/1645 train_time:150051ms step_avg:92.80ms +step:1618/1645 train_time:150145ms step_avg:92.80ms +step:1619/1645 train_time:150238ms step_avg:92.80ms +step:1620/1645 train_time:150332ms step_avg:92.80ms +step:1621/1645 train_time:150425ms step_avg:92.80ms +step:1622/1645 train_time:150519ms step_avg:92.80ms +step:1623/1645 train_time:150613ms step_avg:92.80ms +step:1624/1645 train_time:150707ms step_avg:92.80ms +step:1625/1645 train_time:150800ms step_avg:92.80ms +step:1625/1645 val_loss:3.2859 train_time:150894ms step_avg:92.86ms +step:1626/1645 train_time:150914ms step_avg:92.81ms +step:1627/1645 train_time:150991ms step_avg:92.80ms +step:1628/1645 train_time:151088ms step_avg:92.81ms +step:1629/1645 train_time:151182ms step_avg:92.81ms +step:1630/1645 train_time:151275ms step_avg:92.81ms +step:1631/1645 train_time:151367ms step_avg:92.81ms +step:1632/1645 train_time:151461ms step_avg:92.81ms +step:1633/1645 train_time:151554ms step_avg:92.81ms +step:1634/1645 train_time:151648ms step_avg:92.81ms +step:1635/1645 train_time:151740ms step_avg:92.81ms +step:1636/1645 train_time:151834ms step_avg:92.81ms +step:1637/1645 train_time:151929ms step_avg:92.81ms +step:1638/1645 train_time:152024ms step_avg:92.81ms +step:1639/1645 train_time:152119ms step_avg:92.81ms +step:1640/1645 train_time:152213ms step_avg:92.81ms +step:1641/1645 train_time:152306ms step_avg:92.81ms +step:1642/1645 train_time:152399ms step_avg:92.81ms +step:1643/1645 train_time:152492ms step_avg:92.81ms +step:1644/1645 train_time:152586ms step_avg:92.81ms +step:1645/1645 train_time:152678ms step_avg:92.81ms +step:1645/1645 val_loss:3.2801 train_time:152773ms step_avg:92.87ms +peak memory allocated: 32074 MiB reserved: 46756 MiB diff --git a/records/091825_Smear/4e8d8366-3db6-43fc-aba0-03be7d484dd3.txt b/records/091825_Smear/4e8d8366-3db6-43fc-aba0-03be7d484dd3.txt new file mode 100644 index 000000000..bd6ae7cb3 --- /dev/null +++ b/records/091825_Smear/4e8d8366-3db6-43fc-aba0-03be7d484dd3.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging 
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
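+        # For reference: modulo the reduce_scatter/all_gather sharding below and any
+        # per-parameter lr_mul/wd_mul, each 2D parameter p effectively receives this
+        # single-GPU update (a minimal sketch, not additional optimizer logic):
+        #     buf = state["momentum_buffer"]
+        #     buf.lerp_(p.grad, 1 - momentum)      # momentum EMA
+        #     g = p.grad.lerp(buf, momentum)       # Nesterov-style blend
+        #     v = newton_schulz_triton(g)          # orthogonalize the update
+        #     p.mul_(1 - lr * wd).add_(v, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)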
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
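+
+                # Note on the lerp calls below: buf.lerp_(g, 1 - mu) computes
+                # buf = mu * buf + (1 - mu) * g, and g.lerp(buf, mu) gives
+                # (1 - mu) * g + mu * buf, i.e. a Nesterov-style look-ahead blend
+                # of the fresh gradient with the momentum buffer.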
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + smear_lambda = self.scalars[5 * len(self.blocks)] + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == 
BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last document carried one extra token for the _targets offset; trim its end index
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # allows the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1645 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"smear/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
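+
+# Aside, an illustrative sketch (assumed helper, never invoked by the run): the softcap
+# comment in GPT.forward compresses a small identity. Since tanh(u) = 2*sigmoid(2u) - 1,
+# the shifted form 30*sigmoid(x/7.5) equals 15*tanh(x/15) + 15, i.e. a tanh softcap of
+# 15 mapped into the range (0, 30). A quick numerical check of that equivalence:
+def _softcap_identity_check():
+    x = torch.linspace(-100.0, 100.0, steps=1001)
+    shifted = 30 * torch.sigmoid(x / 7.5)  # form applied to the logits above
+    capped = 15 * torch.tanh(x / 15) + 15  # classic tanh softcap, scaled and shifted by +15
+    assert torch.allclose(shifted, capped, atol=1e-4)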
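+
+# Illustrative sketch of the shard layout consumed by _load_data_shard above: a
+# 256-entry int32 header (magic 20240520, version 1, claimed token count) followed by
+# the raw uint16 token ids. _write_data_shard is an assumed name, not part of this run:
+def _write_data_shard(path: Path, tokens: Tensor):
+    assert tokens.dtype == torch.uint16
+    header = torch.zeros(256, dtype=torch.int32)
+    header[0] = 20240520        # magic number checked on load
+    header[1] = 1               # version checked on load
+    header[2] = tokens.numel()  # claimed token count, validated against bytes read
+    with open(path, "wb") as f:
+        f.write(header.numpy().tobytes())
+        f.write(tokens.numpy().tobytes())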
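+
+# Illustrative sketch of BOSFinder.next_batch packing (assumed demo, never invoked by
+# the run): documents are taken whole, each starting at a BOS token, and a rank's batch
+# is packed until it holds exactly num_tokens_local + 1 tokens, so that inputs and
+# targets can be offset by one position.
+def _bos_packing_demo():
+    toks = torch.tensor([BOS_ID, 1, 2, BOS_ID, 3, 4, 5, BOS_ID, 6, 7, 8, 9])
+    finder = BOSFinder(toks, world_size=1)
+    starts, ends = finder.next_batch(num_tokens_local=7, max_seq_len=8)
+    # three documents packed: [0, 3), [3, 7), and a truncated [7, 8) -> 8 tokens total
+    assert starts[0] == [0, 3, 7] and ends[0] == [3, 7, 8]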
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the initial state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws != ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:48:27 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 34C P0 124W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:137ms step_avg:137.07ms +step:2/1645 train_time:160ms step_avg:79.87ms +step:3/1645 train_time:226ms step_avg:75.21ms +step:4/1645 train_time:315ms step_avg:78.74ms +step:5/1645 train_time:406ms step_avg:81.11ms +step:6/1645 train_time:496ms step_avg:82.67ms +step:7/1645 train_time:587ms step_avg:83.83ms +step:8/1645 train_time:678ms step_avg:84.79ms +step:9/1645 train_time:769ms step_avg:85.45ms +step:10/1645 train_time:859ms step_avg:85.95ms +step:11/1645 train_time:950ms step_avg:86.38ms +step:12/1645 train_time:1046ms step_avg:87.17ms +step:13/1645 train_time:1143ms step_avg:87.93ms +step:14/1645 train_time:1237ms step_avg:88.38ms +step:15/1645 train_time:1329ms step_avg:88.57ms +step:16/1645 train_time:1420ms step_avg:88.75ms +step:17/1645 train_time:1511ms step_avg:88.90ms +step:18/1645 train_time:1602ms step_avg:89.02ms +step:19/1645 train_time:1693ms step_avg:89.12ms +step:20/1645 train_time:1784ms step_avg:89.22ms +step:21/1645 train_time:1877ms step_avg:89.36ms +step:22/1645 train_time:1968ms step_avg:89.45ms +step:23/1645 
train_time:2061ms step_avg:89.60ms +step:24/1645 train_time:2155ms step_avg:89.79ms +step:25/1645 train_time:2248ms step_avg:89.93ms +step:26/1645 train_time:2341ms step_avg:90.04ms +step:27/1645 train_time:2433ms step_avg:90.10ms +step:28/1645 train_time:2524ms step_avg:90.15ms +step:29/1645 train_time:2615ms step_avg:90.17ms +step:30/1645 train_time:2706ms step_avg:90.21ms +step:31/1645 train_time:2797ms step_avg:90.23ms +step:32/1645 train_time:2889ms step_avg:90.28ms +step:33/1645 train_time:2982ms step_avg:90.38ms +step:34/1645 train_time:3075ms step_avg:90.45ms +step:35/1645 train_time:3168ms step_avg:90.51ms +step:36/1645 train_time:3262ms step_avg:90.61ms +step:37/1645 train_time:3354ms step_avg:90.64ms +step:38/1645 train_time:3446ms step_avg:90.68ms +step:39/1645 train_time:3537ms step_avg:90.70ms +step:40/1645 train_time:3629ms step_avg:90.72ms +step:41/1645 train_time:3719ms step_avg:90.72ms +step:42/1645 train_time:3811ms step_avg:90.73ms +step:43/1645 train_time:3903ms step_avg:90.76ms +step:44/1645 train_time:3994ms step_avg:90.77ms +step:45/1645 train_time:4087ms step_avg:90.83ms +step:46/1645 train_time:4179ms step_avg:90.85ms +step:47/1645 train_time:4271ms step_avg:90.88ms +step:48/1645 train_time:4363ms step_avg:90.90ms +step:49/1645 train_time:4457ms step_avg:90.95ms +step:50/1645 train_time:4548ms step_avg:90.97ms +step:51/1645 train_time:4640ms step_avg:90.98ms +step:52/1645 train_time:4732ms step_avg:90.99ms +step:53/1645 train_time:4823ms step_avg:90.99ms +step:54/1645 train_time:4914ms step_avg:91.01ms +step:55/1645 train_time:5006ms step_avg:91.01ms +step:56/1645 train_time:5098ms step_avg:91.04ms +step:57/1645 train_time:5190ms step_avg:91.05ms +step:58/1645 train_time:5282ms step_avg:91.07ms +step:59/1645 train_time:5375ms step_avg:91.09ms +step:60/1645 train_time:5467ms step_avg:91.11ms +step:61/1645 train_time:5559ms step_avg:91.14ms +step:62/1645 train_time:5651ms step_avg:91.15ms +step:63/1645 train_time:5742ms step_avg:91.15ms +step:64/1645 train_time:5834ms step_avg:91.15ms +step:65/1645 train_time:5925ms step_avg:91.16ms +step:66/1645 train_time:6017ms step_avg:91.16ms +step:67/1645 train_time:6109ms step_avg:91.17ms +step:68/1645 train_time:6202ms step_avg:91.21ms +step:69/1645 train_time:6293ms step_avg:91.21ms +step:70/1645 train_time:6386ms step_avg:91.23ms +step:71/1645 train_time:6479ms step_avg:91.26ms +step:72/1645 train_time:6571ms step_avg:91.26ms +step:73/1645 train_time:6663ms step_avg:91.27ms +step:74/1645 train_time:6755ms step_avg:91.28ms +step:75/1645 train_time:6847ms step_avg:91.29ms +step:76/1645 train_time:6939ms step_avg:91.30ms +step:77/1645 train_time:7029ms step_avg:91.29ms +step:78/1645 train_time:7121ms step_avg:91.29ms +step:79/1645 train_time:7213ms step_avg:91.30ms +step:80/1645 train_time:7305ms step_avg:91.31ms +step:81/1645 train_time:7397ms step_avg:91.32ms +step:82/1645 train_time:7489ms step_avg:91.33ms +step:83/1645 train_time:7582ms step_avg:91.35ms +step:84/1645 train_time:7674ms step_avg:91.36ms +step:85/1645 train_time:7766ms step_avg:91.36ms +step:86/1645 train_time:7857ms step_avg:91.37ms +step:87/1645 train_time:7950ms step_avg:91.38ms +step:88/1645 train_time:8041ms step_avg:91.38ms +step:89/1645 train_time:8133ms step_avg:91.39ms +step:90/1645 train_time:8224ms step_avg:91.38ms +step:91/1645 train_time:8316ms step_avg:91.38ms +step:92/1645 train_time:8409ms step_avg:91.40ms +step:93/1645 train_time:8500ms step_avg:91.40ms +step:94/1645 train_time:8592ms step_avg:91.40ms +step:95/1645 train_time:8685ms 
step_avg:91.42ms +step:96/1645 train_time:8778ms step_avg:91.43ms +step:97/1645 train_time:8870ms step_avg:91.44ms +step:98/1645 train_time:8963ms step_avg:91.46ms +step:99/1645 train_time:9054ms step_avg:91.46ms +step:100/1645 train_time:9146ms step_avg:91.46ms +step:101/1645 train_time:9237ms step_avg:91.46ms +step:102/1645 train_time:9328ms step_avg:91.45ms +step:103/1645 train_time:9420ms step_avg:91.45ms +step:104/1645 train_time:9511ms step_avg:91.46ms +step:105/1645 train_time:9604ms step_avg:91.47ms +step:106/1645 train_time:9696ms step_avg:91.47ms +step:107/1645 train_time:9788ms step_avg:91.47ms +step:108/1645 train_time:9880ms step_avg:91.48ms +step:109/1645 train_time:9972ms step_avg:91.48ms +step:110/1645 train_time:10064ms step_avg:91.49ms +step:111/1645 train_time:10155ms step_avg:91.48ms +step:112/1645 train_time:10246ms step_avg:91.48ms +step:113/1645 train_time:10338ms step_avg:91.48ms +step:114/1645 train_time:10429ms step_avg:91.48ms +step:115/1645 train_time:10520ms step_avg:91.48ms +step:116/1645 train_time:10612ms step_avg:91.48ms +step:117/1645 train_time:10704ms step_avg:91.48ms +step:118/1645 train_time:10795ms step_avg:91.49ms +step:119/1645 train_time:10887ms step_avg:91.49ms +step:120/1645 train_time:10981ms step_avg:91.51ms +step:121/1645 train_time:11073ms step_avg:91.51ms +step:122/1645 train_time:11165ms step_avg:91.52ms +step:123/1645 train_time:11257ms step_avg:91.52ms +step:124/1645 train_time:11349ms step_avg:91.52ms +step:125/1645 train_time:11440ms step_avg:91.52ms +step:125/1645 val_loss:4.3122 train_time:11532ms step_avg:92.26ms +step:126/1645 train_time:11554ms step_avg:91.70ms +step:127/1645 train_time:11630ms step_avg:91.58ms +step:128/1645 train_time:11731ms step_avg:91.65ms +step:129/1645 train_time:11824ms step_avg:91.66ms +step:130/1645 train_time:11915ms step_avg:91.66ms +step:131/1645 train_time:12006ms step_avg:91.65ms +step:132/1645 train_time:12097ms step_avg:91.64ms +step:133/1645 train_time:12187ms step_avg:91.63ms +step:134/1645 train_time:12278ms step_avg:91.62ms +step:135/1645 train_time:12368ms step_avg:91.61ms +step:136/1645 train_time:12461ms step_avg:91.62ms +step:137/1645 train_time:12555ms step_avg:91.64ms +step:138/1645 train_time:12648ms step_avg:91.65ms +step:139/1645 train_time:12743ms step_avg:91.67ms +step:140/1645 train_time:12834ms step_avg:91.67ms +step:141/1645 train_time:12925ms step_avg:91.67ms +step:142/1645 train_time:13017ms step_avg:91.67ms +step:143/1645 train_time:13108ms step_avg:91.67ms +step:144/1645 train_time:13199ms step_avg:91.66ms +step:145/1645 train_time:13290ms step_avg:91.66ms +step:146/1645 train_time:13381ms step_avg:91.65ms +step:147/1645 train_time:13474ms step_avg:91.66ms +step:148/1645 train_time:13566ms step_avg:91.66ms +step:149/1645 train_time:13659ms step_avg:91.67ms +step:150/1645 train_time:13752ms step_avg:91.68ms +step:151/1645 train_time:13844ms step_avg:91.68ms +step:152/1645 train_time:13936ms step_avg:91.68ms +step:153/1645 train_time:14027ms step_avg:91.68ms +step:154/1645 train_time:14119ms step_avg:91.68ms +step:155/1645 train_time:14210ms step_avg:91.68ms +step:156/1645 train_time:14301ms step_avg:91.67ms +step:157/1645 train_time:14392ms step_avg:91.67ms +step:158/1645 train_time:14484ms step_avg:91.67ms +step:159/1645 train_time:14577ms step_avg:91.68ms +step:160/1645 train_time:14670ms step_avg:91.69ms +step:161/1645 train_time:14763ms step_avg:91.69ms +step:162/1645 train_time:14856ms step_avg:91.70ms +step:163/1645 train_time:14947ms step_avg:91.70ms +step:164/1645 
train_time:15038ms step_avg:91.70ms +step:165/1645 train_time:15129ms step_avg:91.69ms +step:166/1645 train_time:15220ms step_avg:91.69ms +step:167/1645 train_time:15312ms step_avg:91.69ms +step:168/1645 train_time:15402ms step_avg:91.68ms +step:169/1645 train_time:15494ms step_avg:91.68ms +step:170/1645 train_time:15586ms step_avg:91.68ms +step:171/1645 train_time:15678ms step_avg:91.69ms +step:172/1645 train_time:15770ms step_avg:91.69ms +step:173/1645 train_time:15862ms step_avg:91.69ms +step:174/1645 train_time:15954ms step_avg:91.69ms +step:175/1645 train_time:16045ms step_avg:91.68ms +step:176/1645 train_time:16136ms step_avg:91.68ms +step:177/1645 train_time:16226ms step_avg:91.67ms +step:178/1645 train_time:16318ms step_avg:91.67ms +step:179/1645 train_time:16408ms step_avg:91.67ms +step:180/1645 train_time:16500ms step_avg:91.67ms +step:181/1645 train_time:16593ms step_avg:91.67ms +step:182/1645 train_time:16684ms step_avg:91.67ms +step:183/1645 train_time:16777ms step_avg:91.68ms +step:184/1645 train_time:16869ms step_avg:91.68ms +step:185/1645 train_time:16962ms step_avg:91.68ms +step:186/1645 train_time:17053ms step_avg:91.68ms +step:187/1645 train_time:17144ms step_avg:91.68ms +step:188/1645 train_time:17235ms step_avg:91.68ms +step:189/1645 train_time:17327ms step_avg:91.68ms +step:190/1645 train_time:17418ms step_avg:91.67ms +step:191/1645 train_time:17509ms step_avg:91.67ms +step:192/1645 train_time:17602ms step_avg:91.68ms +step:193/1645 train_time:17695ms step_avg:91.68ms +step:194/1645 train_time:17786ms step_avg:91.68ms +step:195/1645 train_time:17879ms step_avg:91.69ms +step:196/1645 train_time:17971ms step_avg:91.69ms +step:197/1645 train_time:18062ms step_avg:91.69ms +step:198/1645 train_time:18154ms step_avg:91.69ms +step:199/1645 train_time:18245ms step_avg:91.68ms +step:200/1645 train_time:18336ms step_avg:91.68ms +step:201/1645 train_time:18428ms step_avg:91.68ms +step:202/1645 train_time:18519ms step_avg:91.68ms +step:203/1645 train_time:18611ms step_avg:91.68ms +step:204/1645 train_time:18704ms step_avg:91.69ms +step:205/1645 train_time:18796ms step_avg:91.69ms +step:206/1645 train_time:18888ms step_avg:91.69ms +step:207/1645 train_time:18981ms step_avg:91.70ms +step:208/1645 train_time:19073ms step_avg:91.70ms +step:209/1645 train_time:19164ms step_avg:91.69ms +step:210/1645 train_time:19255ms step_avg:91.69ms +step:211/1645 train_time:19346ms step_avg:91.69ms +step:212/1645 train_time:19437ms step_avg:91.69ms +step:213/1645 train_time:19529ms step_avg:91.68ms +step:214/1645 train_time:19621ms step_avg:91.68ms +step:215/1645 train_time:19712ms step_avg:91.68ms +step:216/1645 train_time:19805ms step_avg:91.69ms +step:217/1645 train_time:19897ms step_avg:91.69ms +step:218/1645 train_time:19988ms step_avg:91.69ms +step:219/1645 train_time:20080ms step_avg:91.69ms +step:220/1645 train_time:20173ms step_avg:91.69ms +step:221/1645 train_time:20264ms step_avg:91.69ms +step:222/1645 train_time:20356ms step_avg:91.69ms +step:223/1645 train_time:20447ms step_avg:91.69ms +step:224/1645 train_time:20538ms step_avg:91.69ms +step:225/1645 train_time:20630ms step_avg:91.69ms +step:226/1645 train_time:20721ms step_avg:91.69ms +step:227/1645 train_time:20813ms step_avg:91.69ms +step:228/1645 train_time:20904ms step_avg:91.68ms +step:229/1645 train_time:20996ms step_avg:91.69ms +step:230/1645 train_time:21089ms step_avg:91.69ms +step:231/1645 train_time:21182ms step_avg:91.70ms +step:232/1645 train_time:21273ms step_avg:91.69ms +step:233/1645 train_time:21363ms step_avg:91.69ms 
+step:234/1645 train_time:21456ms step_avg:91.69ms +step:235/1645 train_time:21546ms step_avg:91.69ms +step:236/1645 train_time:21638ms step_avg:91.69ms +step:237/1645 train_time:21730ms step_avg:91.69ms +step:238/1645 train_time:21822ms step_avg:91.69ms +step:239/1645 train_time:21913ms step_avg:91.69ms +step:240/1645 train_time:22004ms step_avg:91.69ms +step:241/1645 train_time:22096ms step_avg:91.69ms +step:242/1645 train_time:22188ms step_avg:91.69ms +step:243/1645 train_time:22280ms step_avg:91.69ms +step:244/1645 train_time:22371ms step_avg:91.69ms +step:245/1645 train_time:22463ms step_avg:91.69ms +step:246/1645 train_time:22554ms step_avg:91.68ms +step:247/1645 train_time:22646ms step_avg:91.69ms +step:248/1645 train_time:22738ms step_avg:91.68ms +step:249/1645 train_time:22829ms step_avg:91.68ms +step:250/1645 train_time:22921ms step_avg:91.68ms +step:250/1645 val_loss:3.9663 train_time:23012ms step_avg:92.05ms +step:251/1645 train_time:23037ms step_avg:91.78ms +step:252/1645 train_time:23107ms step_avg:91.69ms +step:253/1645 train_time:23202ms step_avg:91.71ms +step:254/1645 train_time:23294ms step_avg:91.71ms +step:255/1645 train_time:23385ms step_avg:91.71ms +step:256/1645 train_time:23476ms step_avg:91.70ms +step:257/1645 train_time:23567ms step_avg:91.70ms +step:258/1645 train_time:23657ms step_avg:91.70ms +step:259/1645 train_time:23749ms step_avg:91.70ms +step:260/1645 train_time:23840ms step_avg:91.69ms +step:261/1645 train_time:23931ms step_avg:91.69ms +step:262/1645 train_time:24024ms step_avg:91.70ms +step:263/1645 train_time:24117ms step_avg:91.70ms +step:264/1645 train_time:24210ms step_avg:91.70ms +step:265/1645 train_time:24301ms step_avg:91.70ms +step:266/1645 train_time:24393ms step_avg:91.70ms +step:267/1645 train_time:24483ms step_avg:91.70ms +step:268/1645 train_time:24575ms step_avg:91.70ms +step:269/1645 train_time:24665ms step_avg:91.69ms +step:270/1645 train_time:24757ms step_avg:91.69ms +step:271/1645 train_time:24848ms step_avg:91.69ms +step:272/1645 train_time:24940ms step_avg:91.69ms +step:273/1645 train_time:25033ms step_avg:91.70ms +step:274/1645 train_time:25126ms step_avg:91.70ms +step:275/1645 train_time:25218ms step_avg:91.70ms +step:276/1645 train_time:25310ms step_avg:91.70ms +step:277/1645 train_time:25402ms step_avg:91.70ms +step:278/1645 train_time:25494ms step_avg:91.70ms +step:279/1645 train_time:25585ms step_avg:91.70ms +step:280/1645 train_time:25676ms step_avg:91.70ms +step:281/1645 train_time:25767ms step_avg:91.70ms +step:282/1645 train_time:25858ms step_avg:91.69ms +step:283/1645 train_time:25950ms step_avg:91.70ms +step:284/1645 train_time:26043ms step_avg:91.70ms +step:285/1645 train_time:26135ms step_avg:91.70ms +step:286/1645 train_time:26228ms step_avg:91.71ms +step:287/1645 train_time:26320ms step_avg:91.71ms +step:288/1645 train_time:26411ms step_avg:91.71ms +step:289/1645 train_time:26503ms step_avg:91.71ms +step:290/1645 train_time:26594ms step_avg:91.70ms +step:291/1645 train_time:26685ms step_avg:91.70ms +step:292/1645 train_time:26776ms step_avg:91.70ms +step:293/1645 train_time:26869ms step_avg:91.70ms +step:294/1645 train_time:26959ms step_avg:91.70ms +step:295/1645 train_time:27052ms step_avg:91.70ms +step:296/1645 train_time:27144ms step_avg:91.70ms +step:297/1645 train_time:27236ms step_avg:91.70ms +step:298/1645 train_time:27327ms step_avg:91.70ms +step:299/1645 train_time:27419ms step_avg:91.70ms +step:300/1645 train_time:27513ms step_avg:91.71ms +step:301/1645 train_time:27604ms step_avg:91.71ms +step:302/1645 
train_time:27694ms step_avg:91.70ms +step:303/1645 train_time:27785ms step_avg:91.70ms +step:304/1645 train_time:27877ms step_avg:91.70ms +step:305/1645 train_time:27968ms step_avg:91.70ms +step:306/1645 train_time:28060ms step_avg:91.70ms +step:307/1645 train_time:28152ms step_avg:91.70ms +step:308/1645 train_time:28244ms step_avg:91.70ms +step:309/1645 train_time:28336ms step_avg:91.70ms +step:310/1645 train_time:28428ms step_avg:91.70ms +step:311/1645 train_time:28520ms step_avg:91.70ms +step:312/1645 train_time:28613ms step_avg:91.71ms +step:313/1645 train_time:28704ms step_avg:91.70ms +step:314/1645 train_time:28794ms step_avg:91.70ms +step:315/1645 train_time:28886ms step_avg:91.70ms +step:316/1645 train_time:28978ms step_avg:91.70ms +step:317/1645 train_time:29069ms step_avg:91.70ms +step:318/1645 train_time:29161ms step_avg:91.70ms +step:319/1645 train_time:29254ms step_avg:91.71ms +step:320/1645 train_time:29346ms step_avg:91.71ms +step:321/1645 train_time:29437ms step_avg:91.70ms +step:322/1645 train_time:29529ms step_avg:91.70ms +step:323/1645 train_time:29620ms step_avg:91.70ms +step:324/1645 train_time:29712ms step_avg:91.70ms +step:325/1645 train_time:29802ms step_avg:91.70ms +step:326/1645 train_time:29893ms step_avg:91.70ms +step:327/1645 train_time:29985ms step_avg:91.70ms +step:328/1645 train_time:30076ms step_avg:91.70ms +step:329/1645 train_time:30167ms step_avg:91.69ms +step:330/1645 train_time:30260ms step_avg:91.70ms +step:331/1645 train_time:30351ms step_avg:91.70ms +step:332/1645 train_time:30444ms step_avg:91.70ms +step:333/1645 train_time:30536ms step_avg:91.70ms +step:334/1645 train_time:30628ms step_avg:91.70ms +step:335/1645 train_time:30719ms step_avg:91.70ms +step:336/1645 train_time:30811ms step_avg:91.70ms +step:337/1645 train_time:30902ms step_avg:91.70ms +step:338/1645 train_time:30993ms step_avg:91.70ms +step:339/1645 train_time:31084ms step_avg:91.69ms +step:340/1645 train_time:31176ms step_avg:91.69ms +step:341/1645 train_time:31266ms step_avg:91.69ms +step:342/1645 train_time:31358ms step_avg:91.69ms +step:343/1645 train_time:31451ms step_avg:91.69ms +step:344/1645 train_time:31543ms step_avg:91.70ms +step:345/1645 train_time:31635ms step_avg:91.70ms +step:346/1645 train_time:31727ms step_avg:91.70ms +step:347/1645 train_time:31819ms step_avg:91.70ms +step:348/1645 train_time:31910ms step_avg:91.70ms +step:349/1645 train_time:32002ms step_avg:91.70ms +step:350/1645 train_time:32095ms step_avg:91.70ms +step:351/1645 train_time:32185ms step_avg:91.70ms +step:352/1645 train_time:32277ms step_avg:91.70ms +step:353/1645 train_time:32369ms step_avg:91.70ms +step:354/1645 train_time:32461ms step_avg:91.70ms +step:355/1645 train_time:32554ms step_avg:91.70ms +step:356/1645 train_time:32646ms step_avg:91.70ms +step:357/1645 train_time:32737ms step_avg:91.70ms +step:358/1645 train_time:32829ms step_avg:91.70ms +step:359/1645 train_time:32921ms step_avg:91.70ms +step:360/1645 train_time:33012ms step_avg:91.70ms +step:361/1645 train_time:33104ms step_avg:91.70ms +step:362/1645 train_time:33195ms step_avg:91.70ms +step:363/1645 train_time:33286ms step_avg:91.70ms +step:364/1645 train_time:33378ms step_avg:91.70ms +step:365/1645 train_time:33470ms step_avg:91.70ms +step:366/1645 train_time:33562ms step_avg:91.70ms +step:367/1645 train_time:33654ms step_avg:91.70ms +step:368/1645 train_time:33746ms step_avg:91.70ms +step:369/1645 train_time:33837ms step_avg:91.70ms +step:370/1645 train_time:33929ms step_avg:91.70ms +step:371/1645 train_time:34020ms step_avg:91.70ms 
+step:372/1645 train_time:34112ms step_avg:91.70ms +step:373/1645 train_time:34205ms step_avg:91.70ms +step:374/1645 train_time:34295ms step_avg:91.70ms +step:375/1645 train_time:34387ms step_avg:91.70ms +step:375/1645 val_loss:3.8136 train_time:34479ms step_avg:91.94ms +step:376/1645 train_time:34503ms step_avg:91.76ms +step:377/1645 train_time:34574ms step_avg:91.71ms +step:378/1645 train_time:34669ms step_avg:91.72ms +step:379/1645 train_time:34762ms step_avg:91.72ms +step:380/1645 train_time:34852ms step_avg:91.72ms +step:381/1645 train_time:34943ms step_avg:91.71ms +step:382/1645 train_time:35033ms step_avg:91.71ms +step:383/1645 train_time:35124ms step_avg:91.71ms +step:384/1645 train_time:35215ms step_avg:91.71ms +step:385/1645 train_time:35305ms step_avg:91.70ms +step:386/1645 train_time:35397ms step_avg:91.70ms +step:387/1645 train_time:35490ms step_avg:91.71ms +step:388/1645 train_time:35583ms step_avg:91.71ms +step:389/1645 train_time:35675ms step_avg:91.71ms +step:390/1645 train_time:35768ms step_avg:91.71ms +step:391/1645 train_time:35859ms step_avg:91.71ms +step:392/1645 train_time:35950ms step_avg:91.71ms +step:393/1645 train_time:36041ms step_avg:91.71ms +step:394/1645 train_time:36132ms step_avg:91.70ms +step:395/1645 train_time:36223ms step_avg:91.70ms +step:396/1645 train_time:36314ms step_avg:91.70ms +step:397/1645 train_time:36406ms step_avg:91.70ms +step:398/1645 train_time:36499ms step_avg:91.71ms +step:399/1645 train_time:36591ms step_avg:91.71ms +step:400/1645 train_time:36684ms step_avg:91.71ms +step:401/1645 train_time:36777ms step_avg:91.71ms +step:402/1645 train_time:36869ms step_avg:91.71ms +step:403/1645 train_time:36960ms step_avg:91.71ms +step:404/1645 train_time:37051ms step_avg:91.71ms +step:405/1645 train_time:37142ms step_avg:91.71ms +step:406/1645 train_time:37232ms step_avg:91.71ms +step:407/1645 train_time:37324ms step_avg:91.71ms +step:408/1645 train_time:37416ms step_avg:91.71ms +step:409/1645 train_time:37508ms step_avg:91.71ms +step:410/1645 train_time:37601ms step_avg:91.71ms +step:411/1645 train_time:37693ms step_avg:91.71ms +step:412/1645 train_time:37785ms step_avg:91.71ms +step:413/1645 train_time:37878ms step_avg:91.71ms +step:414/1645 train_time:37969ms step_avg:91.71ms +step:415/1645 train_time:38060ms step_avg:91.71ms +step:416/1645 train_time:38151ms step_avg:91.71ms +step:417/1645 train_time:38243ms step_avg:91.71ms +step:418/1645 train_time:38333ms step_avg:91.71ms +step:419/1645 train_time:38425ms step_avg:91.71ms +step:420/1645 train_time:38517ms step_avg:91.71ms +step:421/1645 train_time:38609ms step_avg:91.71ms +step:422/1645 train_time:38702ms step_avg:91.71ms +step:423/1645 train_time:38794ms step_avg:91.71ms +step:424/1645 train_time:38886ms step_avg:91.71ms +step:425/1645 train_time:38980ms step_avg:91.72ms +step:426/1645 train_time:39070ms step_avg:91.71ms +step:427/1645 train_time:39161ms step_avg:91.71ms +step:428/1645 train_time:39252ms step_avg:91.71ms +step:429/1645 train_time:39343ms step_avg:91.71ms +step:430/1645 train_time:39433ms step_avg:91.71ms +step:431/1645 train_time:39525ms step_avg:91.71ms +step:432/1645 train_time:39616ms step_avg:91.70ms +step:433/1645 train_time:39709ms step_avg:91.71ms +step:434/1645 train_time:39802ms step_avg:91.71ms +step:435/1645 train_time:39893ms step_avg:91.71ms +step:436/1645 train_time:39985ms step_avg:91.71ms +step:437/1645 train_time:40077ms step_avg:91.71ms +step:438/1645 train_time:40168ms step_avg:91.71ms +step:439/1645 train_time:40260ms step_avg:91.71ms +step:440/1645 
train_time:40351ms step_avg:91.71ms +step:441/1645 train_time:40442ms step_avg:91.70ms +step:442/1645 train_time:40533ms step_avg:91.70ms +step:443/1645 train_time:40625ms step_avg:91.70ms +step:444/1645 train_time:40716ms step_avg:91.70ms +step:445/1645 train_time:40808ms step_avg:91.70ms +step:446/1645 train_time:40901ms step_avg:91.71ms +step:447/1645 train_time:40994ms step_avg:91.71ms +step:448/1645 train_time:41087ms step_avg:91.71ms +step:449/1645 train_time:41177ms step_avg:91.71ms +step:450/1645 train_time:41269ms step_avg:91.71ms +step:451/1645 train_time:41359ms step_avg:91.71ms +step:452/1645 train_time:41450ms step_avg:91.70ms +step:453/1645 train_time:41541ms step_avg:91.70ms +step:454/1645 train_time:41633ms step_avg:91.70ms +step:455/1645 train_time:41724ms step_avg:91.70ms +step:456/1645 train_time:41817ms step_avg:91.70ms +step:457/1645 train_time:41907ms step_avg:91.70ms +step:458/1645 train_time:42000ms step_avg:91.70ms +step:459/1645 train_time:42093ms step_avg:91.71ms +step:460/1645 train_time:42184ms step_avg:91.70ms +step:461/1645 train_time:42275ms step_avg:91.70ms +step:462/1645 train_time:42367ms step_avg:91.70ms +step:463/1645 train_time:42458ms step_avg:91.70ms +step:464/1645 train_time:42549ms step_avg:91.70ms +step:465/1645 train_time:42641ms step_avg:91.70ms +step:466/1645 train_time:42732ms step_avg:91.70ms +step:467/1645 train_time:42823ms step_avg:91.70ms +step:468/1645 train_time:42914ms step_avg:91.70ms +step:469/1645 train_time:43007ms step_avg:91.70ms +step:470/1645 train_time:43099ms step_avg:91.70ms +step:471/1645 train_time:43192ms step_avg:91.70ms +step:472/1645 train_time:43285ms step_avg:91.71ms +step:473/1645 train_time:43376ms step_avg:91.70ms +step:474/1645 train_time:43467ms step_avg:91.70ms +step:475/1645 train_time:43559ms step_avg:91.70ms +step:476/1645 train_time:43650ms step_avg:91.70ms +step:477/1645 train_time:43742ms step_avg:91.70ms +step:478/1645 train_time:43833ms step_avg:91.70ms +step:479/1645 train_time:43924ms step_avg:91.70ms +step:480/1645 train_time:44016ms step_avg:91.70ms +step:481/1645 train_time:44108ms step_avg:91.70ms +step:482/1645 train_time:44201ms step_avg:91.70ms +step:483/1645 train_time:44293ms step_avg:91.70ms +step:484/1645 train_time:44385ms step_avg:91.70ms +step:485/1645 train_time:44476ms step_avg:91.70ms +step:486/1645 train_time:44567ms step_avg:91.70ms +step:487/1645 train_time:44659ms step_avg:91.70ms +step:488/1645 train_time:44751ms step_avg:91.70ms +step:489/1645 train_time:44842ms step_avg:91.70ms +step:490/1645 train_time:44934ms step_avg:91.70ms +step:491/1645 train_time:45025ms step_avg:91.70ms +step:492/1645 train_time:45116ms step_avg:91.70ms +step:493/1645 train_time:45209ms step_avg:91.70ms +step:494/1645 train_time:45301ms step_avg:91.70ms +step:495/1645 train_time:45394ms step_avg:91.70ms +step:496/1645 train_time:45485ms step_avg:91.70ms +step:497/1645 train_time:45576ms step_avg:91.70ms +step:498/1645 train_time:45668ms step_avg:91.70ms +step:499/1645 train_time:45759ms step_avg:91.70ms +step:500/1645 train_time:45850ms step_avg:91.70ms +step:500/1645 val_loss:3.7121 train_time:45943ms step_avg:91.89ms +step:501/1645 train_time:45963ms step_avg:91.74ms +step:502/1645 train_time:46038ms step_avg:91.71ms +step:503/1645 train_time:46133ms step_avg:91.72ms +step:504/1645 train_time:46225ms step_avg:91.72ms +step:505/1645 train_time:46317ms step_avg:91.72ms +step:506/1645 train_time:46407ms step_avg:91.71ms +step:507/1645 train_time:46498ms step_avg:91.71ms +step:508/1645 train_time:46588ms 
step_avg:91.71ms +step:509/1645 train_time:46680ms step_avg:91.71ms +step:510/1645 train_time:46771ms step_avg:91.71ms +step:511/1645 train_time:46862ms step_avg:91.71ms +step:512/1645 train_time:46955ms step_avg:91.71ms +step:513/1645 train_time:47049ms step_avg:91.71ms +step:514/1645 train_time:47141ms step_avg:91.71ms +step:515/1645 train_time:47234ms step_avg:91.72ms +step:516/1645 train_time:47326ms step_avg:91.72ms +step:517/1645 train_time:47417ms step_avg:91.72ms +step:518/1645 train_time:47508ms step_avg:91.71ms +step:519/1645 train_time:47598ms step_avg:91.71ms +step:520/1645 train_time:47690ms step_avg:91.71ms +step:521/1645 train_time:47780ms step_avg:91.71ms +step:522/1645 train_time:47873ms step_avg:91.71ms +step:523/1645 train_time:47964ms step_avg:91.71ms +step:524/1645 train_time:48057ms step_avg:91.71ms +step:525/1645 train_time:48150ms step_avg:91.71ms +step:526/1645 train_time:48243ms step_avg:91.72ms +step:527/1645 train_time:48335ms step_avg:91.72ms +step:528/1645 train_time:48426ms step_avg:91.72ms +step:529/1645 train_time:48518ms step_avg:91.72ms +step:530/1645 train_time:48609ms step_avg:91.72ms +step:531/1645 train_time:48700ms step_avg:91.71ms +step:532/1645 train_time:48791ms step_avg:91.71ms +step:533/1645 train_time:48885ms step_avg:91.72ms +step:534/1645 train_time:48976ms step_avg:91.71ms +step:535/1645 train_time:49067ms step_avg:91.71ms +step:536/1645 train_time:49160ms step_avg:91.72ms +step:537/1645 train_time:49252ms step_avg:91.72ms +step:538/1645 train_time:49344ms step_avg:91.72ms +step:539/1645 train_time:49435ms step_avg:91.72ms +step:540/1645 train_time:49526ms step_avg:91.72ms +step:541/1645 train_time:49618ms step_avg:91.72ms +step:542/1645 train_time:49709ms step_avg:91.71ms +step:543/1645 train_time:49800ms step_avg:91.71ms +step:544/1645 train_time:49892ms step_avg:91.71ms +step:545/1645 train_time:49984ms step_avg:91.71ms +step:546/1645 train_time:50076ms step_avg:91.71ms +step:547/1645 train_time:50168ms step_avg:91.71ms +step:548/1645 train_time:50260ms step_avg:91.71ms +step:549/1645 train_time:50353ms step_avg:91.72ms +step:550/1645 train_time:50445ms step_avg:91.72ms +step:551/1645 train_time:50537ms step_avg:91.72ms +step:552/1645 train_time:50630ms step_avg:91.72ms +step:553/1645 train_time:50722ms step_avg:91.72ms +step:554/1645 train_time:50816ms step_avg:91.73ms +step:555/1645 train_time:50908ms step_avg:91.73ms +step:556/1645 train_time:51001ms step_avg:91.73ms +step:557/1645 train_time:51095ms step_avg:91.73ms +step:558/1645 train_time:51188ms step_avg:91.73ms +step:559/1645 train_time:51282ms step_avg:91.74ms +step:560/1645 train_time:51375ms step_avg:91.74ms +step:561/1645 train_time:51468ms step_avg:91.74ms +step:562/1645 train_time:51561ms step_avg:91.74ms +step:563/1645 train_time:51654ms step_avg:91.75ms +step:564/1645 train_time:51746ms step_avg:91.75ms +step:565/1645 train_time:51839ms step_avg:91.75ms +step:566/1645 train_time:51932ms step_avg:91.75ms +step:567/1645 train_time:52024ms step_avg:91.75ms +step:568/1645 train_time:52118ms step_avg:91.76ms +step:569/1645 train_time:52212ms step_avg:91.76ms +step:570/1645 train_time:52305ms step_avg:91.76ms +step:571/1645 train_time:52398ms step_avg:91.76ms +step:572/1645 train_time:52490ms step_avg:91.77ms +step:573/1645 train_time:52583ms step_avg:91.77ms +step:574/1645 train_time:52676ms step_avg:91.77ms +step:575/1645 train_time:52770ms step_avg:91.77ms +step:576/1645 train_time:52861ms step_avg:91.77ms +step:577/1645 train_time:52955ms step_avg:91.78ms +step:578/1645 
train_time:53048ms step_avg:91.78ms +step:579/1645 train_time:53141ms step_avg:91.78ms +step:580/1645 train_time:53235ms step_avg:91.78ms +step:581/1645 train_time:53328ms step_avg:91.79ms +step:582/1645 train_time:53421ms step_avg:91.79ms +step:583/1645 train_time:53515ms step_avg:91.79ms +step:584/1645 train_time:53608ms step_avg:91.79ms +step:585/1645 train_time:53702ms step_avg:91.80ms +step:586/1645 train_time:53794ms step_avg:91.80ms +step:587/1645 train_time:53886ms step_avg:91.80ms +step:588/1645 train_time:53979ms step_avg:91.80ms +step:589/1645 train_time:54072ms step_avg:91.80ms +step:590/1645 train_time:54165ms step_avg:91.80ms +step:591/1645 train_time:54258ms step_avg:91.81ms +step:592/1645 train_time:54352ms step_avg:91.81ms +step:593/1645 train_time:54445ms step_avg:91.81ms +step:594/1645 train_time:54538ms step_avg:91.81ms +step:595/1645 train_time:54631ms step_avg:91.82ms +step:596/1645 train_time:54724ms step_avg:91.82ms +step:597/1645 train_time:54817ms step_avg:91.82ms +step:598/1645 train_time:54909ms step_avg:91.82ms +step:599/1645 train_time:55003ms step_avg:91.82ms +step:600/1645 train_time:55096ms step_avg:91.83ms +step:601/1645 train_time:55189ms step_avg:91.83ms +step:602/1645 train_time:55282ms step_avg:91.83ms +step:603/1645 train_time:55375ms step_avg:91.83ms +step:604/1645 train_time:55468ms step_avg:91.84ms +step:605/1645 train_time:55561ms step_avg:91.84ms +step:606/1645 train_time:55655ms step_avg:91.84ms +step:607/1645 train_time:55748ms step_avg:91.84ms +step:608/1645 train_time:55841ms step_avg:91.84ms +step:609/1645 train_time:55933ms step_avg:91.84ms +step:610/1645 train_time:56026ms step_avg:91.85ms +step:611/1645 train_time:56119ms step_avg:91.85ms +step:612/1645 train_time:56212ms step_avg:91.85ms +step:613/1645 train_time:56305ms step_avg:91.85ms +step:614/1645 train_time:56399ms step_avg:91.85ms +step:615/1645 train_time:56493ms step_avg:91.86ms +step:616/1645 train_time:56586ms step_avg:91.86ms +step:617/1645 train_time:56679ms step_avg:91.86ms +step:618/1645 train_time:56772ms step_avg:91.86ms +step:619/1645 train_time:56865ms step_avg:91.87ms +step:620/1645 train_time:56957ms step_avg:91.87ms +step:621/1645 train_time:57050ms step_avg:91.87ms +step:622/1645 train_time:57142ms step_avg:91.87ms +step:623/1645 train_time:57236ms step_avg:91.87ms +step:624/1645 train_time:57329ms step_avg:91.87ms +step:625/1645 train_time:57422ms step_avg:91.88ms +step:625/1645 val_loss:3.6130 train_time:57517ms step_avg:92.03ms +step:626/1645 train_time:57537ms step_avg:91.91ms +step:627/1645 train_time:57612ms step_avg:91.89ms +step:628/1645 train_time:57714ms step_avg:91.90ms +step:629/1645 train_time:57809ms step_avg:91.91ms +step:630/1645 train_time:57902ms step_avg:91.91ms +step:631/1645 train_time:57994ms step_avg:91.91ms +step:632/1645 train_time:58085ms step_avg:91.91ms +step:633/1645 train_time:58177ms step_avg:91.91ms +step:634/1645 train_time:58269ms step_avg:91.91ms +step:635/1645 train_time:58361ms step_avg:91.91ms +step:636/1645 train_time:58456ms step_avg:91.91ms +step:637/1645 train_time:58552ms step_avg:91.92ms +step:638/1645 train_time:58646ms step_avg:91.92ms +step:639/1645 train_time:58740ms step_avg:91.93ms +step:640/1645 train_time:58833ms step_avg:91.93ms +step:641/1645 train_time:58926ms step_avg:91.93ms +step:642/1645 train_time:59018ms step_avg:91.93ms +step:643/1645 train_time:59111ms step_avg:91.93ms +step:644/1645 train_time:59203ms step_avg:91.93ms +step:645/1645 train_time:59295ms step_avg:91.93ms +step:646/1645 train_time:59387ms 
step_avg:91.93ms +step:647/1645 train_time:59481ms step_avg:91.93ms +step:648/1645 train_time:59577ms step_avg:91.94ms +step:649/1645 train_time:59670ms step_avg:91.94ms +step:650/1645 train_time:59763ms step_avg:91.94ms +step:651/1645 train_time:59856ms step_avg:91.94ms +step:652/1645 train_time:59948ms step_avg:91.95ms +step:653/1645 train_time:60041ms step_avg:91.95ms +step:654/1645 train_time:60133ms step_avg:91.95ms +step:655/1645 train_time:60226ms step_avg:91.95ms +step:656/1645 train_time:60318ms step_avg:91.95ms +step:657/1645 train_time:60412ms step_avg:91.95ms +step:658/1645 train_time:60505ms step_avg:91.95ms +step:659/1645 train_time:60598ms step_avg:91.96ms +step:660/1645 train_time:60694ms step_avg:91.96ms +step:661/1645 train_time:60786ms step_avg:91.96ms +step:662/1645 train_time:60878ms step_avg:91.96ms +step:663/1645 train_time:60972ms step_avg:91.96ms +step:664/1645 train_time:61065ms step_avg:91.96ms +step:665/1645 train_time:61157ms step_avg:91.97ms +step:666/1645 train_time:61250ms step_avg:91.97ms +step:667/1645 train_time:61342ms step_avg:91.97ms +step:668/1645 train_time:61435ms step_avg:91.97ms +step:669/1645 train_time:61529ms step_avg:91.97ms +step:670/1645 train_time:61622ms step_avg:91.97ms +step:671/1645 train_time:61715ms step_avg:91.97ms +step:672/1645 train_time:61808ms step_avg:91.98ms +step:673/1645 train_time:61901ms step_avg:91.98ms +step:674/1645 train_time:61995ms step_avg:91.98ms +step:675/1645 train_time:62088ms step_avg:91.98ms +step:676/1645 train_time:62180ms step_avg:91.98ms +step:677/1645 train_time:62272ms step_avg:91.98ms +step:678/1645 train_time:62365ms step_avg:91.98ms +step:679/1645 train_time:62458ms step_avg:91.98ms +step:680/1645 train_time:62550ms step_avg:91.99ms +step:681/1645 train_time:62643ms step_avg:91.99ms +step:682/1645 train_time:62736ms step_avg:91.99ms +step:683/1645 train_time:62829ms step_avg:91.99ms +step:684/1645 train_time:62922ms step_avg:91.99ms +step:685/1645 train_time:63016ms step_avg:91.99ms +step:686/1645 train_time:63109ms step_avg:92.00ms +step:687/1645 train_time:63201ms step_avg:92.00ms +step:688/1645 train_time:63294ms step_avg:92.00ms +step:689/1645 train_time:63387ms step_avg:92.00ms +step:690/1645 train_time:63480ms step_avg:92.00ms +step:691/1645 train_time:63573ms step_avg:92.00ms +step:692/1645 train_time:63666ms step_avg:92.00ms +step:693/1645 train_time:63760ms step_avg:92.01ms +step:694/1645 train_time:63853ms step_avg:92.01ms +step:695/1645 train_time:63946ms step_avg:92.01ms +step:696/1645 train_time:64039ms step_avg:92.01ms +step:697/1645 train_time:64132ms step_avg:92.01ms +step:698/1645 train_time:64225ms step_avg:92.01ms +step:699/1645 train_time:64318ms step_avg:92.01ms +step:700/1645 train_time:64411ms step_avg:92.02ms +step:701/1645 train_time:64504ms step_avg:92.02ms +step:702/1645 train_time:64597ms step_avg:92.02ms +step:703/1645 train_time:64690ms step_avg:92.02ms +step:704/1645 train_time:64782ms step_avg:92.02ms +step:705/1645 train_time:64875ms step_avg:92.02ms +step:706/1645 train_time:64968ms step_avg:92.02ms +step:707/1645 train_time:65061ms step_avg:92.02ms +step:708/1645 train_time:65154ms step_avg:92.03ms +step:709/1645 train_time:65247ms step_avg:92.03ms +step:710/1645 train_time:65340ms step_avg:92.03ms +step:711/1645 train_time:65432ms step_avg:92.03ms +step:712/1645 train_time:65525ms step_avg:92.03ms +step:713/1645 train_time:65621ms step_avg:92.03ms +step:714/1645 train_time:65713ms step_avg:92.03ms +step:715/1645 train_time:65805ms step_avg:92.04ms +step:716/1645 
train_time:65898ms step_avg:92.04ms +step:717/1645 train_time:65990ms step_avg:92.04ms +step:718/1645 train_time:66082ms step_avg:92.04ms +step:719/1645 train_time:66176ms step_avg:92.04ms +step:720/1645 train_time:66269ms step_avg:92.04ms +step:721/1645 train_time:66361ms step_avg:92.04ms +step:722/1645 train_time:66454ms step_avg:92.04ms +step:723/1645 train_time:66547ms step_avg:92.04ms +step:724/1645 train_time:66640ms step_avg:92.04ms +step:725/1645 train_time:66734ms step_avg:92.05ms +step:726/1645 train_time:66827ms step_avg:92.05ms +step:727/1645 train_time:66921ms step_avg:92.05ms +step:728/1645 train_time:67014ms step_avg:92.05ms +step:729/1645 train_time:67106ms step_avg:92.05ms +step:730/1645 train_time:67199ms step_avg:92.05ms +step:731/1645 train_time:67292ms step_avg:92.05ms +step:732/1645 train_time:67385ms step_avg:92.06ms +step:733/1645 train_time:67477ms step_avg:92.06ms +step:734/1645 train_time:67570ms step_avg:92.06ms +step:735/1645 train_time:67663ms step_avg:92.06ms +step:736/1645 train_time:67756ms step_avg:92.06ms +step:737/1645 train_time:67849ms step_avg:92.06ms +step:738/1645 train_time:67941ms step_avg:92.06ms +step:739/1645 train_time:68034ms step_avg:92.06ms +step:740/1645 train_time:68126ms step_avg:92.06ms +step:741/1645 train_time:68220ms step_avg:92.06ms +step:742/1645 train_time:68313ms step_avg:92.07ms +step:743/1645 train_time:68406ms step_avg:92.07ms +step:744/1645 train_time:68500ms step_avg:92.07ms +step:745/1645 train_time:68593ms step_avg:92.07ms +step:746/1645 train_time:68686ms step_avg:92.07ms +step:747/1645 train_time:68778ms step_avg:92.07ms +step:748/1645 train_time:68871ms step_avg:92.07ms +step:749/1645 train_time:68963ms step_avg:92.07ms +step:750/1645 train_time:69055ms step_avg:92.07ms +step:750/1645 val_loss:3.5612 train_time:69148ms step_avg:92.20ms +step:751/1645 train_time:69169ms step_avg:92.10ms +step:752/1645 train_time:69245ms step_avg:92.08ms +step:753/1645 train_time:69338ms step_avg:92.08ms +step:754/1645 train_time:69430ms step_avg:92.08ms +step:755/1645 train_time:69522ms step_avg:92.08ms +step:756/1645 train_time:69614ms step_avg:92.08ms +step:757/1645 train_time:69706ms step_avg:92.08ms +step:758/1645 train_time:69798ms step_avg:92.08ms +step:759/1645 train_time:69891ms step_avg:92.08ms +step:760/1645 train_time:69984ms step_avg:92.08ms +step:761/1645 train_time:70079ms step_avg:92.09ms +step:762/1645 train_time:70173ms step_avg:92.09ms +step:763/1645 train_time:70267ms step_avg:92.09ms +step:764/1645 train_time:70360ms step_avg:92.09ms +step:765/1645 train_time:70454ms step_avg:92.10ms +step:766/1645 train_time:70547ms step_avg:92.10ms +step:767/1645 train_time:70639ms step_avg:92.10ms +step:768/1645 train_time:70731ms step_avg:92.10ms +step:769/1645 train_time:70824ms step_avg:92.10ms +step:770/1645 train_time:70916ms step_avg:92.10ms +step:771/1645 train_time:71009ms step_avg:92.10ms +step:772/1645 train_time:71102ms step_avg:92.10ms +step:773/1645 train_time:71195ms step_avg:92.10ms +step:774/1645 train_time:71288ms step_avg:92.10ms +step:775/1645 train_time:71382ms step_avg:92.11ms +step:776/1645 train_time:71475ms step_avg:92.11ms +step:777/1645 train_time:71567ms step_avg:92.11ms +step:778/1645 train_time:71660ms step_avg:92.11ms +step:779/1645 train_time:71753ms step_avg:92.11ms +step:780/1645 train_time:71845ms step_avg:92.11ms +step:781/1645 train_time:71939ms step_avg:92.11ms +step:782/1645 train_time:72032ms step_avg:92.11ms +step:783/1645 train_time:72125ms step_avg:92.11ms +step:784/1645 train_time:72219ms 
step_avg:92.12ms +step:785/1645 train_time:72312ms step_avg:92.12ms +step:786/1645 train_time:72405ms step_avg:92.12ms +step:787/1645 train_time:72498ms step_avg:92.12ms +step:788/1645 train_time:72593ms step_avg:92.12ms +step:789/1645 train_time:72683ms step_avg:92.12ms +step:790/1645 train_time:72776ms step_avg:92.12ms +step:791/1645 train_time:72868ms step_avg:92.12ms +step:792/1645 train_time:72961ms step_avg:92.12ms +step:793/1645 train_time:73055ms step_avg:92.12ms +step:794/1645 train_time:73148ms step_avg:92.13ms +step:795/1645 train_time:73241ms step_avg:92.13ms +step:796/1645 train_time:73334ms step_avg:92.13ms +step:797/1645 train_time:73427ms step_avg:92.13ms +step:798/1645 train_time:73520ms step_avg:92.13ms +step:799/1645 train_time:73613ms step_avg:92.13ms +step:800/1645 train_time:73705ms step_avg:92.13ms +step:801/1645 train_time:73799ms step_avg:92.13ms +step:802/1645 train_time:73892ms step_avg:92.13ms +step:803/1645 train_time:73984ms step_avg:92.13ms +step:804/1645 train_time:74078ms step_avg:92.14ms +step:805/1645 train_time:74171ms step_avg:92.14ms +step:806/1645 train_time:74264ms step_avg:92.14ms +step:807/1645 train_time:74358ms step_avg:92.14ms +step:808/1645 train_time:74451ms step_avg:92.14ms +step:809/1645 train_time:74544ms step_avg:92.14ms +step:810/1645 train_time:74636ms step_avg:92.14ms +step:811/1645 train_time:74728ms step_avg:92.14ms +step:812/1645 train_time:74821ms step_avg:92.14ms +step:813/1645 train_time:74914ms step_avg:92.14ms +step:814/1645 train_time:75007ms step_avg:92.15ms +step:815/1645 train_time:75100ms step_avg:92.15ms +step:816/1645 train_time:75193ms step_avg:92.15ms +step:817/1645 train_time:75286ms step_avg:92.15ms +step:818/1645 train_time:75380ms step_avg:92.15ms +step:819/1645 train_time:75473ms step_avg:92.15ms +step:820/1645 train_time:75565ms step_avg:92.15ms +step:821/1645 train_time:75658ms step_avg:92.15ms +step:822/1645 train_time:75751ms step_avg:92.15ms +step:823/1645 train_time:75843ms step_avg:92.15ms +step:824/1645 train_time:75936ms step_avg:92.16ms +step:825/1645 train_time:76029ms step_avg:92.16ms +step:826/1645 train_time:76122ms step_avg:92.16ms +step:827/1645 train_time:76216ms step_avg:92.16ms +step:828/1645 train_time:76311ms step_avg:92.16ms +step:829/1645 train_time:76403ms step_avg:92.16ms +step:830/1645 train_time:76496ms step_avg:92.16ms +step:831/1645 train_time:76588ms step_avg:92.16ms +step:832/1645 train_time:76680ms step_avg:92.16ms +step:833/1645 train_time:76773ms step_avg:92.16ms +step:834/1645 train_time:76865ms step_avg:92.16ms +step:835/1645 train_time:76958ms step_avg:92.17ms +step:836/1645 train_time:77051ms step_avg:92.17ms +step:837/1645 train_time:77144ms step_avg:92.17ms +step:838/1645 train_time:77238ms step_avg:92.17ms +step:839/1645 train_time:77331ms step_avg:92.17ms +step:840/1645 train_time:77424ms step_avg:92.17ms +step:841/1645 train_time:77518ms step_avg:92.17ms +step:842/1645 train_time:77611ms step_avg:92.17ms +step:843/1645 train_time:77703ms step_avg:92.17ms +step:844/1645 train_time:77795ms step_avg:92.17ms +step:845/1645 train_time:77888ms step_avg:92.18ms +step:846/1645 train_time:77980ms step_avg:92.18ms +step:847/1645 train_time:78074ms step_avg:92.18ms +step:848/1645 train_time:78166ms step_avg:92.18ms +step:849/1645 train_time:78260ms step_avg:92.18ms +step:850/1645 train_time:78354ms step_avg:92.18ms +step:851/1645 train_time:78448ms step_avg:92.18ms +step:852/1645 train_time:78541ms step_avg:92.18ms +step:853/1645 train_time:78634ms step_avg:92.18ms +step:854/1645 
train_time:78726ms step_avg:92.19ms +step:855/1645 train_time:78821ms step_avg:92.19ms +step:856/1645 train_time:78912ms step_avg:92.19ms +step:857/1645 train_time:79005ms step_avg:92.19ms +step:858/1645 train_time:79097ms step_avg:92.19ms +step:859/1645 train_time:79189ms step_avg:92.19ms +step:860/1645 train_time:79282ms step_avg:92.19ms +step:861/1645 train_time:79376ms step_avg:92.19ms +step:862/1645 train_time:79468ms step_avg:92.19ms +step:863/1645 train_time:79561ms step_avg:92.19ms +step:864/1645 train_time:79654ms step_avg:92.19ms +step:865/1645 train_time:79747ms step_avg:92.19ms +step:866/1645 train_time:79840ms step_avg:92.19ms +step:867/1645 train_time:79933ms step_avg:92.19ms +step:868/1645 train_time:80027ms step_avg:92.20ms +step:869/1645 train_time:80119ms step_avg:92.20ms +step:870/1645 train_time:80211ms step_avg:92.20ms +step:871/1645 train_time:80304ms step_avg:92.20ms +step:872/1645 train_time:80397ms step_avg:92.20ms +step:873/1645 train_time:80489ms step_avg:92.20ms +step:874/1645 train_time:80582ms step_avg:92.20ms +step:875/1645 train_time:80675ms step_avg:92.20ms +step:875/1645 val_loss:3.5148 train_time:80768ms step_avg:92.31ms +step:876/1645 train_time:80789ms step_avg:92.22ms +step:877/1645 train_time:80866ms step_avg:92.21ms +step:878/1645 train_time:80963ms step_avg:92.21ms +step:879/1645 train_time:81055ms step_avg:92.21ms +step:880/1645 train_time:81147ms step_avg:92.21ms +step:881/1645 train_time:81238ms step_avg:92.21ms +step:882/1645 train_time:81330ms step_avg:92.21ms +step:883/1645 train_time:81423ms step_avg:92.21ms +step:884/1645 train_time:81515ms step_avg:92.21ms +step:885/1645 train_time:81607ms step_avg:92.21ms +step:886/1645 train_time:81701ms step_avg:92.21ms +step:887/1645 train_time:81797ms step_avg:92.22ms +step:888/1645 train_time:81892ms step_avg:92.22ms +step:889/1645 train_time:81985ms step_avg:92.22ms +step:890/1645 train_time:82079ms step_avg:92.22ms +step:891/1645 train_time:82170ms step_avg:92.22ms +step:892/1645 train_time:82262ms step_avg:92.22ms +step:893/1645 train_time:82354ms step_avg:92.22ms +step:894/1645 train_time:82447ms step_avg:92.22ms +step:895/1645 train_time:82539ms step_avg:92.22ms +step:896/1645 train_time:82630ms step_avg:92.22ms +step:897/1645 train_time:82724ms step_avg:92.22ms +step:898/1645 train_time:82818ms step_avg:92.23ms +step:899/1645 train_time:82912ms step_avg:92.23ms +step:900/1645 train_time:83005ms step_avg:92.23ms +step:901/1645 train_time:83098ms step_avg:92.23ms +step:902/1645 train_time:83191ms step_avg:92.23ms +step:903/1645 train_time:83283ms step_avg:92.23ms +step:904/1645 train_time:83375ms step_avg:92.23ms +step:905/1645 train_time:83467ms step_avg:92.23ms +step:906/1645 train_time:83560ms step_avg:92.23ms +step:907/1645 train_time:83653ms step_avg:92.23ms +step:908/1645 train_time:83746ms step_avg:92.23ms +step:909/1645 train_time:83840ms step_avg:92.23ms +step:910/1645 train_time:83934ms step_avg:92.23ms +step:911/1645 train_time:84027ms step_avg:92.24ms +step:912/1645 train_time:84119ms step_avg:92.24ms +step:913/1645 train_time:84211ms step_avg:92.24ms +step:914/1645 train_time:84304ms step_avg:92.24ms +step:915/1645 train_time:84397ms step_avg:92.24ms +step:916/1645 train_time:84489ms step_avg:92.24ms +step:917/1645 train_time:84582ms step_avg:92.24ms +step:918/1645 train_time:84675ms step_avg:92.24ms +step:919/1645 train_time:84768ms step_avg:92.24ms +step:920/1645 train_time:84862ms step_avg:92.24ms +step:921/1645 train_time:84956ms step_avg:92.24ms +step:922/1645 train_time:85049ms 
step_avg:92.24ms +step:923/1645 train_time:85141ms step_avg:92.24ms +step:924/1645 train_time:85234ms step_avg:92.24ms +step:925/1645 train_time:85327ms step_avg:92.25ms +step:926/1645 train_time:85419ms step_avg:92.25ms +step:927/1645 train_time:85511ms step_avg:92.25ms +step:928/1645 train_time:85604ms step_avg:92.25ms +step:929/1645 train_time:85698ms step_avg:92.25ms +step:930/1645 train_time:85791ms step_avg:92.25ms +step:931/1645 train_time:85884ms step_avg:92.25ms +step:932/1645 train_time:85978ms step_avg:92.25ms +step:933/1645 train_time:86070ms step_avg:92.25ms +step:934/1645 train_time:86163ms step_avg:92.25ms +step:935/1645 train_time:86256ms step_avg:92.25ms +step:936/1645 train_time:86349ms step_avg:92.25ms +step:937/1645 train_time:86441ms step_avg:92.25ms +step:938/1645 train_time:86533ms step_avg:92.25ms +step:939/1645 train_time:86626ms step_avg:92.25ms +step:940/1645 train_time:86719ms step_avg:92.25ms +step:941/1645 train_time:86812ms step_avg:92.25ms +step:942/1645 train_time:86905ms step_avg:92.26ms +step:943/1645 train_time:86999ms step_avg:92.26ms +step:944/1645 train_time:87092ms step_avg:92.26ms +step:945/1645 train_time:87185ms step_avg:92.26ms +step:946/1645 train_time:87277ms step_avg:92.26ms +step:947/1645 train_time:87370ms step_avg:92.26ms +step:948/1645 train_time:87463ms step_avg:92.26ms +step:949/1645 train_time:87556ms step_avg:92.26ms +step:950/1645 train_time:87649ms step_avg:92.26ms +step:951/1645 train_time:87742ms step_avg:92.26ms +step:952/1645 train_time:87835ms step_avg:92.26ms +step:953/1645 train_time:87928ms step_avg:92.26ms +step:954/1645 train_time:88021ms step_avg:92.27ms +step:955/1645 train_time:88115ms step_avg:92.27ms +step:956/1645 train_time:88209ms step_avg:92.27ms +step:957/1645 train_time:88300ms step_avg:92.27ms +step:958/1645 train_time:88393ms step_avg:92.27ms +step:959/1645 train_time:88486ms step_avg:92.27ms +step:960/1645 train_time:88579ms step_avg:92.27ms +step:961/1645 train_time:88672ms step_avg:92.27ms +step:962/1645 train_time:88765ms step_avg:92.27ms +step:963/1645 train_time:88857ms step_avg:92.27ms +step:964/1645 train_time:88950ms step_avg:92.27ms +step:965/1645 train_time:89044ms step_avg:92.27ms +step:966/1645 train_time:89137ms step_avg:92.27ms +step:967/1645 train_time:89229ms step_avg:92.27ms +step:968/1645 train_time:89322ms step_avg:92.27ms +step:969/1645 train_time:89415ms step_avg:92.28ms +step:970/1645 train_time:89507ms step_avg:92.27ms +step:971/1645 train_time:89601ms step_avg:92.28ms +step:972/1645 train_time:89694ms step_avg:92.28ms +step:973/1645 train_time:89786ms step_avg:92.28ms +step:974/1645 train_time:89879ms step_avg:92.28ms +step:975/1645 train_time:89972ms step_avg:92.28ms +step:976/1645 train_time:90064ms step_avg:92.28ms +step:977/1645 train_time:90158ms step_avg:92.28ms +step:978/1645 train_time:90251ms step_avg:92.28ms +step:979/1645 train_time:90343ms step_avg:92.28ms +step:980/1645 train_time:90436ms step_avg:92.28ms +step:981/1645 train_time:90529ms step_avg:92.28ms +step:982/1645 train_time:90622ms step_avg:92.28ms +step:983/1645 train_time:90715ms step_avg:92.28ms +step:984/1645 train_time:90808ms step_avg:92.28ms +step:985/1645 train_time:90902ms step_avg:92.29ms +step:986/1645 train_time:90995ms step_avg:92.29ms +step:987/1645 train_time:91087ms step_avg:92.29ms +step:988/1645 train_time:91180ms step_avg:92.29ms +step:989/1645 train_time:91273ms step_avg:92.29ms +step:990/1645 train_time:91365ms step_avg:92.29ms +step:991/1645 train_time:91458ms step_avg:92.29ms +step:992/1645 
train_time:91551ms step_avg:92.29ms +step:993/1645 train_time:91643ms step_avg:92.29ms +step:994/1645 train_time:91736ms step_avg:92.29ms +step:995/1645 train_time:91829ms step_avg:92.29ms +step:996/1645 train_time:91922ms step_avg:92.29ms +step:997/1645 train_time:92016ms step_avg:92.29ms +step:998/1645 train_time:92109ms step_avg:92.29ms +step:999/1645 train_time:92203ms step_avg:92.29ms +step:1000/1645 train_time:92295ms step_avg:92.30ms +step:1000/1645 val_loss:3.4650 train_time:92388ms step_avg:92.39ms +step:1001/1645 train_time:92409ms step_avg:92.32ms +step:1002/1645 train_time:92486ms step_avg:92.30ms +step:1003/1645 train_time:92581ms step_avg:92.30ms +step:1004/1645 train_time:92673ms step_avg:92.30ms +step:1005/1645 train_time:92765ms step_avg:92.30ms +step:1006/1645 train_time:92857ms step_avg:92.30ms +step:1007/1645 train_time:92948ms step_avg:92.30ms +step:1008/1645 train_time:93040ms step_avg:92.30ms +step:1009/1645 train_time:93132ms step_avg:92.30ms +step:1010/1645 train_time:93225ms step_avg:92.30ms +step:1011/1645 train_time:93318ms step_avg:92.30ms +step:1012/1645 train_time:93413ms step_avg:92.30ms +step:1013/1645 train_time:93508ms step_avg:92.31ms +step:1014/1645 train_time:93603ms step_avg:92.31ms +step:1015/1645 train_time:93695ms step_avg:92.31ms +step:1016/1645 train_time:93788ms step_avg:92.31ms +step:1017/1645 train_time:93880ms step_avg:92.31ms +step:1018/1645 train_time:93971ms step_avg:92.31ms +step:1019/1645 train_time:94063ms step_avg:92.31ms +step:1020/1645 train_time:94155ms step_avg:92.31ms +step:1021/1645 train_time:94249ms step_avg:92.31ms +step:1022/1645 train_time:94343ms step_avg:92.31ms +step:1023/1645 train_time:94437ms step_avg:92.31ms +step:1024/1645 train_time:94531ms step_avg:92.32ms +step:1025/1645 train_time:94624ms step_avg:92.32ms +step:1026/1645 train_time:94717ms step_avg:92.32ms +step:1027/1645 train_time:94809ms step_avg:92.32ms +step:1028/1645 train_time:94902ms step_avg:92.32ms +step:1029/1645 train_time:94993ms step_avg:92.32ms +step:1030/1645 train_time:95085ms step_avg:92.32ms +step:1031/1645 train_time:95177ms step_avg:92.32ms +step:1032/1645 train_time:95270ms step_avg:92.32ms +step:1033/1645 train_time:95364ms step_avg:92.32ms +step:1034/1645 train_time:95457ms step_avg:92.32ms +step:1035/1645 train_time:95550ms step_avg:92.32ms +step:1036/1645 train_time:95643ms step_avg:92.32ms +step:1037/1645 train_time:95737ms step_avg:92.32ms +step:1038/1645 train_time:95830ms step_avg:92.32ms +step:1039/1645 train_time:95922ms step_avg:92.32ms +step:1040/1645 train_time:96014ms step_avg:92.32ms +step:1041/1645 train_time:96106ms step_avg:92.32ms +step:1042/1645 train_time:96200ms step_avg:92.32ms +step:1043/1645 train_time:96294ms step_avg:92.32ms +step:1044/1645 train_time:96386ms step_avg:92.32ms +step:1045/1645 train_time:96479ms step_avg:92.32ms +step:1046/1645 train_time:96572ms step_avg:92.32ms +step:1047/1645 train_time:96665ms step_avg:92.33ms +step:1048/1645 train_time:96758ms step_avg:92.33ms +step:1049/1645 train_time:96850ms step_avg:92.33ms +step:1050/1645 train_time:96944ms step_avg:92.33ms +step:1051/1645 train_time:97037ms step_avg:92.33ms +step:1052/1645 train_time:97129ms step_avg:92.33ms +step:1053/1645 train_time:97221ms step_avg:92.33ms +step:1054/1645 train_time:97314ms step_avg:92.33ms +step:1055/1645 train_time:97407ms step_avg:92.33ms +step:1056/1645 train_time:97501ms step_avg:92.33ms +step:1057/1645 train_time:97594ms step_avg:92.33ms +step:1058/1645 train_time:97688ms step_avg:92.33ms +step:1059/1645 
train_time:97781ms step_avg:92.33ms +step:1060/1645 train_time:97873ms step_avg:92.33ms +step:1061/1645 train_time:97967ms step_avg:92.33ms +step:1062/1645 train_time:98060ms step_avg:92.34ms +step:1063/1645 train_time:98152ms step_avg:92.34ms +step:1064/1645 train_time:98245ms step_avg:92.34ms +step:1065/1645 train_time:98339ms step_avg:92.34ms +step:1066/1645 train_time:98432ms step_avg:92.34ms +step:1067/1645 train_time:98525ms step_avg:92.34ms +step:1068/1645 train_time:98618ms step_avg:92.34ms +step:1069/1645 train_time:98711ms step_avg:92.34ms +step:1070/1645 train_time:98805ms step_avg:92.34ms +step:1071/1645 train_time:98897ms step_avg:92.34ms +step:1072/1645 train_time:98990ms step_avg:92.34ms +step:1073/1645 train_time:99082ms step_avg:92.34ms +step:1074/1645 train_time:99175ms step_avg:92.34ms +step:1075/1645 train_time:99268ms step_avg:92.34ms +step:1076/1645 train_time:99361ms step_avg:92.34ms +step:1077/1645 train_time:99454ms step_avg:92.34ms +step:1078/1645 train_time:99547ms step_avg:92.34ms +step:1079/1645 train_time:99640ms step_avg:92.34ms +step:1080/1645 train_time:99733ms step_avg:92.35ms +step:1081/1645 train_time:99826ms step_avg:92.35ms +step:1082/1645 train_time:99920ms step_avg:92.35ms +step:1083/1645 train_time:100013ms step_avg:92.35ms +step:1084/1645 train_time:100105ms step_avg:92.35ms +step:1085/1645 train_time:100197ms step_avg:92.35ms +step:1086/1645 train_time:100290ms step_avg:92.35ms +step:1087/1645 train_time:100383ms step_avg:92.35ms +step:1088/1645 train_time:100476ms step_avg:92.35ms +step:1089/1645 train_time:100569ms step_avg:92.35ms +step:1090/1645 train_time:100662ms step_avg:92.35ms +step:1091/1645 train_time:100755ms step_avg:92.35ms +step:1092/1645 train_time:100848ms step_avg:92.35ms +step:1093/1645 train_time:100942ms step_avg:92.35ms +step:1094/1645 train_time:101035ms step_avg:92.35ms +step:1095/1645 train_time:101127ms step_avg:92.35ms +step:1096/1645 train_time:101219ms step_avg:92.35ms +step:1097/1645 train_time:101313ms step_avg:92.35ms +step:1098/1645 train_time:101405ms step_avg:92.35ms +step:1099/1645 train_time:101499ms step_avg:92.36ms +step:1100/1645 train_time:101592ms step_avg:92.36ms +step:1101/1645 train_time:101686ms step_avg:92.36ms +step:1102/1645 train_time:101779ms step_avg:92.36ms +step:1103/1645 train_time:101873ms step_avg:92.36ms +step:1104/1645 train_time:101968ms step_avg:92.36ms +step:1105/1645 train_time:102061ms step_avg:92.36ms +step:1106/1645 train_time:102154ms step_avg:92.36ms +step:1107/1645 train_time:102247ms step_avg:92.36ms +step:1108/1645 train_time:102340ms step_avg:92.37ms +step:1109/1645 train_time:102434ms step_avg:92.37ms +step:1110/1645 train_time:102527ms step_avg:92.37ms +step:1111/1645 train_time:102622ms step_avg:92.37ms +step:1112/1645 train_time:102715ms step_avg:92.37ms +step:1113/1645 train_time:102808ms step_avg:92.37ms +step:1114/1645 train_time:102901ms step_avg:92.37ms +step:1115/1645 train_time:102995ms step_avg:92.37ms +step:1116/1645 train_time:103088ms step_avg:92.37ms +step:1117/1645 train_time:103182ms step_avg:92.37ms +step:1118/1645 train_time:103275ms step_avg:92.37ms +step:1119/1645 train_time:103369ms step_avg:92.38ms +step:1120/1645 train_time:103463ms step_avg:92.38ms +step:1121/1645 train_time:103556ms step_avg:92.38ms +step:1122/1645 train_time:103649ms step_avg:92.38ms +step:1123/1645 train_time:103745ms step_avg:92.38ms +step:1124/1645 train_time:103838ms step_avg:92.38ms +step:1125/1645 train_time:103932ms step_avg:92.38ms +step:1125/1645 val_loss:3.4117 
train_time:104025ms step_avg:92.47ms +step:1126/1645 train_time:104046ms step_avg:92.40ms +step:1127/1645 train_time:104126ms step_avg:92.39ms +step:1128/1645 train_time:104226ms step_avg:92.40ms +step:1129/1645 train_time:104320ms step_avg:92.40ms +step:1130/1645 train_time:104413ms step_avg:92.40ms +step:1131/1645 train_time:104505ms step_avg:92.40ms +step:1132/1645 train_time:104598ms step_avg:92.40ms +step:1133/1645 train_time:104690ms step_avg:92.40ms +step:1134/1645 train_time:104783ms step_avg:92.40ms +step:1135/1645 train_time:104875ms step_avg:92.40ms +step:1136/1645 train_time:104971ms step_avg:92.40ms +step:1137/1645 train_time:105066ms step_avg:92.41ms +step:1138/1645 train_time:105161ms step_avg:92.41ms +step:1139/1645 train_time:105256ms step_avg:92.41ms +step:1140/1645 train_time:105350ms step_avg:92.41ms +step:1141/1645 train_time:105443ms step_avg:92.41ms +step:1142/1645 train_time:105535ms step_avg:92.41ms +step:1143/1645 train_time:105627ms step_avg:92.41ms +step:1144/1645 train_time:105720ms step_avg:92.41ms +step:1145/1645 train_time:105813ms step_avg:92.41ms +step:1146/1645 train_time:105907ms step_avg:92.41ms +step:1147/1645 train_time:106000ms step_avg:92.41ms +step:1148/1645 train_time:106097ms step_avg:92.42ms +step:1149/1645 train_time:106192ms step_avg:92.42ms +step:1150/1645 train_time:106288ms step_avg:92.42ms +step:1151/1645 train_time:106381ms step_avg:92.42ms +step:1152/1645 train_time:106473ms step_avg:92.42ms +step:1153/1645 train_time:106566ms step_avg:92.42ms +step:1154/1645 train_time:106659ms step_avg:92.43ms +step:1155/1645 train_time:106752ms step_avg:92.43ms +step:1156/1645 train_time:106845ms step_avg:92.43ms +step:1157/1645 train_time:106938ms step_avg:92.43ms +step:1158/1645 train_time:107033ms step_avg:92.43ms +step:1159/1645 train_time:107128ms step_avg:92.43ms +step:1160/1645 train_time:107221ms step_avg:92.43ms +step:1161/1645 train_time:107315ms step_avg:92.43ms +step:1162/1645 train_time:107409ms step_avg:92.43ms +step:1163/1645 train_time:107503ms step_avg:92.44ms +step:1164/1645 train_time:107597ms step_avg:92.44ms +step:1165/1645 train_time:107688ms step_avg:92.44ms +step:1166/1645 train_time:107781ms step_avg:92.44ms +step:1167/1645 train_time:107873ms step_avg:92.44ms +step:1168/1645 train_time:107967ms step_avg:92.44ms +step:1169/1645 train_time:108060ms step_avg:92.44ms +step:1170/1645 train_time:108155ms step_avg:92.44ms +step:1171/1645 train_time:108250ms step_avg:92.44ms +step:1172/1645 train_time:108343ms step_avg:92.44ms +step:1173/1645 train_time:108436ms step_avg:92.44ms +step:1174/1645 train_time:108530ms step_avg:92.44ms +step:1175/1645 train_time:108624ms step_avg:92.45ms +step:1176/1645 train_time:108718ms step_avg:92.45ms +step:1177/1645 train_time:108811ms step_avg:92.45ms +step:1178/1645 train_time:108905ms step_avg:92.45ms +step:1179/1645 train_time:108998ms step_avg:92.45ms +step:1180/1645 train_time:109092ms step_avg:92.45ms +step:1181/1645 train_time:109186ms step_avg:92.45ms +step:1182/1645 train_time:109280ms step_avg:92.45ms +step:1183/1645 train_time:109375ms step_avg:92.46ms +step:1184/1645 train_time:109468ms step_avg:92.46ms +step:1185/1645 train_time:109561ms step_avg:92.46ms +step:1186/1645 train_time:109656ms step_avg:92.46ms +step:1187/1645 train_time:109749ms step_avg:92.46ms +step:1188/1645 train_time:109842ms step_avg:92.46ms +step:1189/1645 train_time:109935ms step_avg:92.46ms +step:1190/1645 train_time:110029ms step_avg:92.46ms +step:1191/1645 train_time:110122ms step_avg:92.46ms +step:1192/1645 
train_time:110217ms step_avg:92.46ms +step:1193/1645 train_time:110314ms step_avg:92.47ms +step:1194/1645 train_time:110405ms step_avg:92.47ms +step:1195/1645 train_time:110499ms step_avg:92.47ms +step:1196/1645 train_time:110592ms step_avg:92.47ms +step:1197/1645 train_time:110685ms step_avg:92.47ms +step:1198/1645 train_time:110779ms step_avg:92.47ms +step:1199/1645 train_time:110872ms step_avg:92.47ms +step:1200/1645 train_time:110965ms step_avg:92.47ms +step:1201/1645 train_time:111059ms step_avg:92.47ms +step:1202/1645 train_time:111154ms step_avg:92.47ms +step:1203/1645 train_time:111247ms step_avg:92.47ms +step:1204/1645 train_time:111340ms step_avg:92.48ms +step:1205/1645 train_time:111433ms step_avg:92.48ms +step:1206/1645 train_time:111527ms step_avg:92.48ms +step:1207/1645 train_time:111621ms step_avg:92.48ms +step:1208/1645 train_time:111715ms step_avg:92.48ms +step:1209/1645 train_time:111807ms step_avg:92.48ms +step:1210/1645 train_time:111901ms step_avg:92.48ms +step:1211/1645 train_time:111994ms step_avg:92.48ms +step:1212/1645 train_time:112088ms step_avg:92.48ms +step:1213/1645 train_time:112181ms step_avg:92.48ms +step:1214/1645 train_time:112276ms step_avg:92.48ms +step:1215/1645 train_time:112368ms step_avg:92.48ms +step:1216/1645 train_time:112463ms step_avg:92.49ms +step:1217/1645 train_time:112556ms step_avg:92.49ms +step:1218/1645 train_time:112650ms step_avg:92.49ms +step:1219/1645 train_time:112744ms step_avg:92.49ms +step:1220/1645 train_time:112836ms step_avg:92.49ms +step:1221/1645 train_time:112930ms step_avg:92.49ms +step:1222/1645 train_time:113025ms step_avg:92.49ms +step:1223/1645 train_time:113117ms step_avg:92.49ms +step:1224/1645 train_time:113211ms step_avg:92.49ms +step:1225/1645 train_time:113304ms step_avg:92.49ms +step:1226/1645 train_time:113398ms step_avg:92.49ms +step:1227/1645 train_time:113492ms step_avg:92.50ms +step:1228/1645 train_time:113586ms step_avg:92.50ms +step:1229/1645 train_time:113679ms step_avg:92.50ms +step:1230/1645 train_time:113772ms step_avg:92.50ms +step:1231/1645 train_time:113865ms step_avg:92.50ms +step:1232/1645 train_time:113959ms step_avg:92.50ms +step:1233/1645 train_time:114052ms step_avg:92.50ms +step:1234/1645 train_time:114146ms step_avg:92.50ms +step:1235/1645 train_time:114240ms step_avg:92.50ms +step:1236/1645 train_time:114335ms step_avg:92.50ms +step:1237/1645 train_time:114430ms step_avg:92.51ms +step:1238/1645 train_time:114522ms step_avg:92.51ms +step:1239/1645 train_time:114615ms step_avg:92.51ms +step:1240/1645 train_time:114708ms step_avg:92.51ms +step:1241/1645 train_time:114802ms step_avg:92.51ms +step:1242/1645 train_time:114896ms step_avg:92.51ms +step:1243/1645 train_time:114990ms step_avg:92.51ms +step:1244/1645 train_time:115084ms step_avg:92.51ms +step:1245/1645 train_time:115177ms step_avg:92.51ms +step:1246/1645 train_time:115271ms step_avg:92.51ms +step:1247/1645 train_time:115365ms step_avg:92.51ms +step:1248/1645 train_time:115458ms step_avg:92.51ms +step:1249/1645 train_time:115552ms step_avg:92.52ms +step:1250/1645 train_time:115646ms step_avg:92.52ms +step:1250/1645 val_loss:3.3735 train_time:115739ms step_avg:92.59ms +step:1251/1645 train_time:115760ms step_avg:92.53ms +step:1252/1645 train_time:115838ms step_avg:92.52ms +step:1253/1645 train_time:115933ms step_avg:92.52ms +step:1254/1645 train_time:116028ms step_avg:92.53ms +step:1255/1645 train_time:116120ms step_avg:92.53ms +step:1256/1645 train_time:116213ms step_avg:92.53ms +step:1257/1645 train_time:116305ms step_avg:92.53ms 
+step:1258/1645 train_time:116398ms step_avg:92.53ms +step:1259/1645 train_time:116491ms step_avg:92.53ms +step:1260/1645 train_time:116584ms step_avg:92.53ms +step:1261/1645 train_time:116679ms step_avg:92.53ms +step:1262/1645 train_time:116774ms step_avg:92.53ms +step:1263/1645 train_time:116869ms step_avg:92.53ms +step:1264/1645 train_time:116964ms step_avg:92.53ms +step:1265/1645 train_time:117057ms step_avg:92.54ms +step:1266/1645 train_time:117151ms step_avg:92.54ms +step:1267/1645 train_time:117243ms step_avg:92.54ms +step:1268/1645 train_time:117336ms step_avg:92.54ms +step:1269/1645 train_time:117430ms step_avg:92.54ms +step:1270/1645 train_time:117523ms step_avg:92.54ms +step:1271/1645 train_time:117616ms step_avg:92.54ms +step:1272/1645 train_time:117710ms step_avg:92.54ms +step:1273/1645 train_time:117805ms step_avg:92.54ms +step:1274/1645 train_time:117900ms step_avg:92.54ms +step:1275/1645 train_time:117994ms step_avg:92.54ms +step:1276/1645 train_time:118088ms step_avg:92.55ms +step:1277/1645 train_time:118183ms step_avg:92.55ms +step:1278/1645 train_time:118275ms step_avg:92.55ms +step:1279/1645 train_time:118368ms step_avg:92.55ms +step:1280/1645 train_time:118462ms step_avg:92.55ms +step:1281/1645 train_time:118554ms step_avg:92.55ms +step:1282/1645 train_time:118648ms step_avg:92.55ms +step:1283/1645 train_time:118742ms step_avg:92.55ms +step:1284/1645 train_time:118836ms step_avg:92.55ms +step:1285/1645 train_time:118930ms step_avg:92.55ms +step:1286/1645 train_time:119024ms step_avg:92.55ms +step:1287/1645 train_time:119118ms step_avg:92.55ms +step:1288/1645 train_time:119211ms step_avg:92.56ms +step:1289/1645 train_time:119305ms step_avg:92.56ms +step:1290/1645 train_time:119398ms step_avg:92.56ms +step:1291/1645 train_time:119493ms step_avg:92.56ms +step:1292/1645 train_time:119585ms step_avg:92.56ms +step:1293/1645 train_time:119678ms step_avg:92.56ms +step:1294/1645 train_time:119772ms step_avg:92.56ms +step:1295/1645 train_time:119867ms step_avg:92.56ms +step:1296/1645 train_time:119961ms step_avg:92.56ms +step:1297/1645 train_time:120055ms step_avg:92.56ms +step:1298/1645 train_time:120150ms step_avg:92.57ms +step:1299/1645 train_time:120243ms step_avg:92.57ms +step:1300/1645 train_time:120337ms step_avg:92.57ms +step:1301/1645 train_time:120430ms step_avg:92.57ms +step:1302/1645 train_time:120525ms step_avg:92.57ms +step:1303/1645 train_time:120618ms step_avg:92.57ms +step:1304/1645 train_time:120712ms step_avg:92.57ms +step:1305/1645 train_time:120806ms step_avg:92.57ms +step:1306/1645 train_time:120900ms step_avg:92.57ms +step:1307/1645 train_time:120995ms step_avg:92.57ms +step:1308/1645 train_time:121088ms step_avg:92.58ms +step:1309/1645 train_time:121181ms step_avg:92.58ms +step:1310/1645 train_time:121274ms step_avg:92.58ms +step:1311/1645 train_time:121367ms step_avg:92.58ms +step:1312/1645 train_time:121461ms step_avg:92.58ms +step:1313/1645 train_time:121555ms step_avg:92.58ms +step:1314/1645 train_time:121648ms step_avg:92.58ms +step:1315/1645 train_time:121741ms step_avg:92.58ms +step:1316/1645 train_time:121835ms step_avg:92.58ms +step:1317/1645 train_time:121929ms step_avg:92.58ms +step:1318/1645 train_time:122023ms step_avg:92.58ms +step:1319/1645 train_time:122117ms step_avg:92.58ms +step:1320/1645 train_time:122210ms step_avg:92.58ms +step:1321/1645 train_time:122305ms step_avg:92.59ms +step:1322/1645 train_time:122398ms step_avg:92.59ms +step:1323/1645 train_time:122493ms step_avg:92.59ms +step:1324/1645 train_time:122587ms step_avg:92.59ms 
+step:1325/1645 train_time:122681ms step_avg:92.59ms +step:1326/1645 train_time:122774ms step_avg:92.59ms +step:1327/1645 train_time:122868ms step_avg:92.59ms +step:1328/1645 train_time:122962ms step_avg:92.59ms +step:1329/1645 train_time:123056ms step_avg:92.59ms +step:1330/1645 train_time:123149ms step_avg:92.59ms +step:1331/1645 train_time:123243ms step_avg:92.59ms +step:1332/1645 train_time:123336ms step_avg:92.59ms +step:1333/1645 train_time:123430ms step_avg:92.60ms +step:1334/1645 train_time:123524ms step_avg:92.60ms +step:1335/1645 train_time:123618ms step_avg:92.60ms +step:1336/1645 train_time:123712ms step_avg:92.60ms +step:1337/1645 train_time:123805ms step_avg:92.60ms +step:1338/1645 train_time:123899ms step_avg:92.60ms +step:1339/1645 train_time:123993ms step_avg:92.60ms +step:1340/1645 train_time:124086ms step_avg:92.60ms +step:1341/1645 train_time:124180ms step_avg:92.60ms +step:1342/1645 train_time:124273ms step_avg:92.60ms +step:1343/1645 train_time:124366ms step_avg:92.60ms +step:1344/1645 train_time:124460ms step_avg:92.60ms +step:1345/1645 train_time:124554ms step_avg:92.60ms +step:1346/1645 train_time:124649ms step_avg:92.61ms +step:1347/1645 train_time:124742ms step_avg:92.61ms +step:1348/1645 train_time:124836ms step_avg:92.61ms +step:1349/1645 train_time:124930ms step_avg:92.61ms +step:1350/1645 train_time:125023ms step_avg:92.61ms +step:1351/1645 train_time:125116ms step_avg:92.61ms +step:1352/1645 train_time:125211ms step_avg:92.61ms +step:1353/1645 train_time:125304ms step_avg:92.61ms +step:1354/1645 train_time:125397ms step_avg:92.61ms +step:1355/1645 train_time:125491ms step_avg:92.61ms +step:1356/1645 train_time:125584ms step_avg:92.61ms +step:1357/1645 train_time:125678ms step_avg:92.61ms +step:1358/1645 train_time:125771ms step_avg:92.61ms +step:1359/1645 train_time:125864ms step_avg:92.62ms +step:1360/1645 train_time:125957ms step_avg:92.62ms +step:1361/1645 train_time:126051ms step_avg:92.62ms +step:1362/1645 train_time:126144ms step_avg:92.62ms +step:1363/1645 train_time:126238ms step_avg:92.62ms +step:1364/1645 train_time:126332ms step_avg:92.62ms +step:1365/1645 train_time:126426ms step_avg:92.62ms +step:1366/1645 train_time:126521ms step_avg:92.62ms +step:1367/1645 train_time:126614ms step_avg:92.62ms +step:1368/1645 train_time:126708ms step_avg:92.62ms +step:1369/1645 train_time:126802ms step_avg:92.62ms +step:1370/1645 train_time:126895ms step_avg:92.62ms +step:1371/1645 train_time:126989ms step_avg:92.63ms +step:1372/1645 train_time:127083ms step_avg:92.63ms +step:1373/1645 train_time:127177ms step_avg:92.63ms +step:1374/1645 train_time:127269ms step_avg:92.63ms +step:1375/1645 train_time:127364ms step_avg:92.63ms +step:1375/1645 val_loss:3.3383 train_time:127457ms step_avg:92.70ms +step:1376/1645 train_time:127478ms step_avg:92.64ms +step:1377/1645 train_time:127554ms step_avg:92.63ms +step:1378/1645 train_time:127650ms step_avg:92.63ms +step:1379/1645 train_time:127744ms step_avg:92.64ms +step:1380/1645 train_time:127837ms step_avg:92.64ms +step:1381/1645 train_time:127930ms step_avg:92.64ms +step:1382/1645 train_time:128023ms step_avg:92.64ms +step:1383/1645 train_time:128116ms step_avg:92.64ms +step:1384/1645 train_time:128209ms step_avg:92.64ms +step:1385/1645 train_time:128301ms step_avg:92.64ms +step:1386/1645 train_time:128395ms step_avg:92.64ms +step:1387/1645 train_time:128491ms step_avg:92.64ms +step:1388/1645 train_time:128586ms step_avg:92.64ms +step:1389/1645 train_time:128680ms step_avg:92.64ms +step:1390/1645 train_time:128774ms 
step_avg:92.64ms +step:1391/1645 train_time:128868ms step_avg:92.64ms +step:1392/1645 train_time:128961ms step_avg:92.64ms +step:1393/1645 train_time:129055ms step_avg:92.65ms +step:1394/1645 train_time:129148ms step_avg:92.65ms +step:1395/1645 train_time:129240ms step_avg:92.65ms +step:1396/1645 train_time:129333ms step_avg:92.65ms +step:1397/1645 train_time:129427ms step_avg:92.65ms +step:1398/1645 train_time:129522ms step_avg:92.65ms +step:1399/1645 train_time:129616ms step_avg:92.65ms +step:1400/1645 train_time:129711ms step_avg:92.65ms +step:1401/1645 train_time:129804ms step_avg:92.65ms +step:1402/1645 train_time:129898ms step_avg:92.65ms +step:1403/1645 train_time:129991ms step_avg:92.65ms +step:1404/1645 train_time:130085ms step_avg:92.65ms +step:1405/1645 train_time:130178ms step_avg:92.65ms +step:1406/1645 train_time:130271ms step_avg:92.65ms +step:1407/1645 train_time:130364ms step_avg:92.65ms +step:1408/1645 train_time:130460ms step_avg:92.66ms +step:1409/1645 train_time:130554ms step_avg:92.66ms +step:1410/1645 train_time:130648ms step_avg:92.66ms +step:1411/1645 train_time:130742ms step_avg:92.66ms +step:1412/1645 train_time:130836ms step_avg:92.66ms +step:1413/1645 train_time:130930ms step_avg:92.66ms +step:1414/1645 train_time:131024ms step_avg:92.66ms +step:1415/1645 train_time:131116ms step_avg:92.66ms +step:1416/1645 train_time:131209ms step_avg:92.66ms +step:1417/1645 train_time:131302ms step_avg:92.66ms +step:1418/1645 train_time:131395ms step_avg:92.66ms +step:1419/1645 train_time:131489ms step_avg:92.66ms +step:1420/1645 train_time:131583ms step_avg:92.66ms +step:1421/1645 train_time:131678ms step_avg:92.67ms +step:1422/1645 train_time:131771ms step_avg:92.67ms +step:1423/1645 train_time:131864ms step_avg:92.67ms +step:1424/1645 train_time:131959ms step_avg:92.67ms +step:1425/1645 train_time:132052ms step_avg:92.67ms +step:1426/1645 train_time:132145ms step_avg:92.67ms +step:1427/1645 train_time:132239ms step_avg:92.67ms +step:1428/1645 train_time:132332ms step_avg:92.67ms +step:1429/1645 train_time:132426ms step_avg:92.67ms +step:1430/1645 train_time:132520ms step_avg:92.67ms +step:1431/1645 train_time:132614ms step_avg:92.67ms +step:1432/1645 train_time:132708ms step_avg:92.67ms +step:1433/1645 train_time:132802ms step_avg:92.67ms +step:1434/1645 train_time:132895ms step_avg:92.67ms +step:1435/1645 train_time:132989ms step_avg:92.68ms +step:1436/1645 train_time:133082ms step_avg:92.68ms +step:1437/1645 train_time:133176ms step_avg:92.68ms +step:1438/1645 train_time:133268ms step_avg:92.68ms +step:1439/1645 train_time:133362ms step_avg:92.68ms +step:1440/1645 train_time:133456ms step_avg:92.68ms +step:1441/1645 train_time:133549ms step_avg:92.68ms +step:1442/1645 train_time:133643ms step_avg:92.68ms +step:1443/1645 train_time:133736ms step_avg:92.68ms +step:1444/1645 train_time:133831ms step_avg:92.68ms +step:1445/1645 train_time:133925ms step_avg:92.68ms +step:1446/1645 train_time:134019ms step_avg:92.68ms +step:1447/1645 train_time:134112ms step_avg:92.68ms +step:1448/1645 train_time:134205ms step_avg:92.68ms +step:1449/1645 train_time:134299ms step_avg:92.68ms +step:1450/1645 train_time:134393ms step_avg:92.69ms +step:1451/1645 train_time:134486ms step_avg:92.69ms +step:1452/1645 train_time:134580ms step_avg:92.69ms +step:1453/1645 train_time:134674ms step_avg:92.69ms +step:1454/1645 train_time:134767ms step_avg:92.69ms +step:1455/1645 train_time:134863ms step_avg:92.69ms +step:1456/1645 train_time:134957ms step_avg:92.69ms +step:1457/1645 train_time:135050ms 
step_avg:92.69ms +step:1458/1645 train_time:135143ms step_avg:92.69ms +step:1459/1645 train_time:135237ms step_avg:92.69ms +step:1460/1645 train_time:135330ms step_avg:92.69ms +step:1461/1645 train_time:135424ms step_avg:92.69ms +step:1462/1645 train_time:135519ms step_avg:92.69ms +step:1463/1645 train_time:135612ms step_avg:92.69ms +step:1464/1645 train_time:135706ms step_avg:92.70ms +step:1465/1645 train_time:135800ms step_avg:92.70ms +step:1466/1645 train_time:135894ms step_avg:92.70ms +step:1467/1645 train_time:135988ms step_avg:92.70ms +step:1468/1645 train_time:136081ms step_avg:92.70ms +step:1469/1645 train_time:136176ms step_avg:92.70ms +step:1470/1645 train_time:136269ms step_avg:92.70ms +step:1471/1645 train_time:136362ms step_avg:92.70ms +step:1472/1645 train_time:136455ms step_avg:92.70ms +step:1473/1645 train_time:136549ms step_avg:92.70ms +step:1474/1645 train_time:136643ms step_avg:92.70ms +step:1475/1645 train_time:136738ms step_avg:92.70ms +step:1476/1645 train_time:136832ms step_avg:92.70ms +step:1477/1645 train_time:136925ms step_avg:92.71ms +step:1478/1645 train_time:137018ms step_avg:92.71ms +step:1479/1645 train_time:137111ms step_avg:92.71ms +step:1480/1645 train_time:137206ms step_avg:92.71ms +step:1481/1645 train_time:137298ms step_avg:92.71ms +step:1482/1645 train_time:137391ms step_avg:92.71ms +step:1483/1645 train_time:137485ms step_avg:92.71ms +step:1484/1645 train_time:137579ms step_avg:92.71ms +step:1485/1645 train_time:137672ms step_avg:92.71ms +step:1486/1645 train_time:137766ms step_avg:92.71ms +step:1487/1645 train_time:137860ms step_avg:92.71ms +step:1488/1645 train_time:137954ms step_avg:92.71ms +step:1489/1645 train_time:138047ms step_avg:92.71ms +step:1490/1645 train_time:138141ms step_avg:92.71ms +step:1491/1645 train_time:138235ms step_avg:92.71ms +step:1492/1645 train_time:138328ms step_avg:92.71ms +step:1493/1645 train_time:138422ms step_avg:92.71ms +step:1494/1645 train_time:138516ms step_avg:92.71ms +step:1495/1645 train_time:138610ms step_avg:92.72ms +step:1496/1645 train_time:138702ms step_avg:92.72ms +step:1497/1645 train_time:138796ms step_avg:92.72ms +step:1498/1645 train_time:138890ms step_avg:92.72ms +step:1499/1645 train_time:138984ms step_avg:92.72ms +step:1500/1645 train_time:139077ms step_avg:92.72ms +step:1500/1645 val_loss:3.3084 train_time:139171ms step_avg:92.78ms +step:1501/1645 train_time:139192ms step_avg:92.73ms +step:1502/1645 train_time:139269ms step_avg:92.72ms +step:1503/1645 train_time:139365ms step_avg:92.72ms +step:1504/1645 train_time:139457ms step_avg:92.72ms +step:1505/1645 train_time:139550ms step_avg:92.72ms +step:1506/1645 train_time:139643ms step_avg:92.72ms +step:1507/1645 train_time:139736ms step_avg:92.72ms +step:1508/1645 train_time:139829ms step_avg:92.72ms +step:1509/1645 train_time:139921ms step_avg:92.72ms +step:1510/1645 train_time:140014ms step_avg:92.72ms +step:1511/1645 train_time:140108ms step_avg:92.73ms +step:1512/1645 train_time:140203ms step_avg:92.73ms +step:1513/1645 train_time:140298ms step_avg:92.73ms +step:1514/1645 train_time:140393ms step_avg:92.73ms +step:1515/1645 train_time:140486ms step_avg:92.73ms +step:1516/1645 train_time:140579ms step_avg:92.73ms +step:1517/1645 train_time:140672ms step_avg:92.73ms +step:1518/1645 train_time:140764ms step_avg:92.73ms +step:1519/1645 train_time:140858ms step_avg:92.73ms +step:1520/1645 train_time:140950ms step_avg:92.73ms +step:1521/1645 train_time:141043ms step_avg:92.73ms +step:1522/1645 train_time:141138ms step_avg:92.73ms +step:1523/1645 
train_time:141232ms step_avg:92.73ms +step:1524/1645 train_time:141326ms step_avg:92.73ms +step:1525/1645 train_time:141420ms step_avg:92.73ms +step:1526/1645 train_time:141513ms step_avg:92.73ms +step:1527/1645 train_time:141607ms step_avg:92.74ms +step:1528/1645 train_time:141700ms step_avg:92.74ms +step:1529/1645 train_time:141793ms step_avg:92.74ms +step:1530/1645 train_time:141886ms step_avg:92.74ms +step:1531/1645 train_time:141979ms step_avg:92.74ms +step:1532/1645 train_time:142072ms step_avg:92.74ms +step:1533/1645 train_time:142166ms step_avg:92.74ms +step:1534/1645 train_time:142261ms step_avg:92.74ms +step:1535/1645 train_time:142354ms step_avg:92.74ms +step:1536/1645 train_time:142448ms step_avg:92.74ms +step:1537/1645 train_time:142541ms step_avg:92.74ms +step:1538/1645 train_time:142635ms step_avg:92.74ms +step:1539/1645 train_time:142728ms step_avg:92.74ms +step:1540/1645 train_time:142821ms step_avg:92.74ms +step:1541/1645 train_time:142915ms step_avg:92.74ms +step:1542/1645 train_time:143008ms step_avg:92.74ms +step:1543/1645 train_time:143102ms step_avg:92.74ms +step:1544/1645 train_time:143195ms step_avg:92.74ms +step:1545/1645 train_time:143290ms step_avg:92.74ms +step:1546/1645 train_time:143384ms step_avg:92.74ms +step:1547/1645 train_time:143478ms step_avg:92.75ms +step:1548/1645 train_time:143572ms step_avg:92.75ms +step:1549/1645 train_time:143666ms step_avg:92.75ms +step:1550/1645 train_time:143759ms step_avg:92.75ms +step:1551/1645 train_time:143852ms step_avg:92.75ms +step:1552/1645 train_time:143946ms step_avg:92.75ms +step:1553/1645 train_time:144038ms step_avg:92.75ms +step:1554/1645 train_time:144132ms step_avg:92.75ms +step:1555/1645 train_time:144226ms step_avg:92.75ms +step:1556/1645 train_time:144319ms step_avg:92.75ms +step:1557/1645 train_time:144412ms step_avg:92.75ms +step:1558/1645 train_time:144507ms step_avg:92.75ms +step:1559/1645 train_time:144600ms step_avg:92.75ms +step:1560/1645 train_time:144693ms step_avg:92.75ms +step:1561/1645 train_time:144788ms step_avg:92.75ms +step:1562/1645 train_time:144882ms step_avg:92.75ms +step:1563/1645 train_time:144975ms step_avg:92.75ms +step:1564/1645 train_time:145068ms step_avg:92.75ms +step:1565/1645 train_time:145162ms step_avg:92.76ms +step:1566/1645 train_time:145256ms step_avg:92.76ms +step:1567/1645 train_time:145350ms step_avg:92.76ms +step:1568/1645 train_time:145443ms step_avg:92.76ms +step:1569/1645 train_time:145536ms step_avg:92.76ms +step:1570/1645 train_time:145630ms step_avg:92.76ms +step:1571/1645 train_time:145723ms step_avg:92.76ms +step:1572/1645 train_time:145817ms step_avg:92.76ms +step:1573/1645 train_time:145911ms step_avg:92.76ms +step:1574/1645 train_time:146005ms step_avg:92.76ms +step:1575/1645 train_time:146097ms step_avg:92.76ms +step:1576/1645 train_time:146192ms step_avg:92.76ms +step:1577/1645 train_time:146286ms step_avg:92.76ms +step:1578/1645 train_time:146381ms step_avg:92.76ms +step:1579/1645 train_time:146474ms step_avg:92.76ms +step:1580/1645 train_time:146568ms step_avg:92.76ms +step:1581/1645 train_time:146662ms step_avg:92.77ms +step:1582/1645 train_time:146756ms step_avg:92.77ms +step:1583/1645 train_time:146849ms step_avg:92.77ms +step:1584/1645 train_time:146943ms step_avg:92.77ms +step:1585/1645 train_time:147035ms step_avg:92.77ms +step:1586/1645 train_time:147129ms step_avg:92.77ms +step:1587/1645 train_time:147222ms step_avg:92.77ms +step:1588/1645 train_time:147315ms step_avg:92.77ms +step:1589/1645 train_time:147409ms step_avg:92.77ms +step:1590/1645 
train_time:147503ms step_avg:92.77ms +step:1591/1645 train_time:147597ms step_avg:92.77ms +step:1592/1645 train_time:147692ms step_avg:92.77ms +step:1593/1645 train_time:147786ms step_avg:92.77ms +step:1594/1645 train_time:147879ms step_avg:92.77ms +step:1595/1645 train_time:147972ms step_avg:92.77ms +step:1596/1645 train_time:148066ms step_avg:92.77ms +step:1597/1645 train_time:148159ms step_avg:92.77ms +step:1598/1645 train_time:148253ms step_avg:92.77ms +step:1599/1645 train_time:148346ms step_avg:92.77ms +step:1600/1645 train_time:148439ms step_avg:92.77ms +step:1601/1645 train_time:148532ms step_avg:92.77ms +step:1602/1645 train_time:148625ms step_avg:92.77ms +step:1603/1645 train_time:148719ms step_avg:92.78ms +step:1604/1645 train_time:148812ms step_avg:92.78ms +step:1605/1645 train_time:148905ms step_avg:92.78ms +step:1606/1645 train_time:148999ms step_avg:92.78ms +step:1607/1645 train_time:149094ms step_avg:92.78ms +step:1608/1645 train_time:149186ms step_avg:92.78ms +step:1609/1645 train_time:149280ms step_avg:92.78ms +step:1610/1645 train_time:149374ms step_avg:92.78ms +step:1611/1645 train_time:149468ms step_avg:92.78ms +step:1612/1645 train_time:149562ms step_avg:92.78ms +step:1613/1645 train_time:149655ms step_avg:92.78ms +step:1614/1645 train_time:149749ms step_avg:92.78ms +step:1615/1645 train_time:149842ms step_avg:92.78ms +step:1616/1645 train_time:149935ms step_avg:92.78ms +step:1617/1645 train_time:150029ms step_avg:92.78ms +step:1618/1645 train_time:150123ms step_avg:92.78ms +step:1619/1645 train_time:150217ms step_avg:92.78ms +step:1620/1645 train_time:150311ms step_avg:92.78ms +step:1621/1645 train_time:150405ms step_avg:92.79ms +step:1622/1645 train_time:150499ms step_avg:92.79ms +step:1623/1645 train_time:150593ms step_avg:92.79ms +step:1624/1645 train_time:150686ms step_avg:92.79ms +step:1625/1645 train_time:150780ms step_avg:92.79ms +step:1625/1645 val_loss:3.2846 train_time:150873ms step_avg:92.84ms +step:1626/1645 train_time:150893ms step_avg:92.80ms +step:1627/1645 train_time:150971ms step_avg:92.79ms +step:1628/1645 train_time:151066ms step_avg:92.79ms +step:1629/1645 train_time:151159ms step_avg:92.79ms +step:1630/1645 train_time:151251ms step_avg:92.79ms +step:1631/1645 train_time:151344ms step_avg:92.79ms +step:1632/1645 train_time:151437ms step_avg:92.79ms +step:1633/1645 train_time:151530ms step_avg:92.79ms +step:1634/1645 train_time:151623ms step_avg:92.79ms +step:1635/1645 train_time:151716ms step_avg:92.79ms +step:1636/1645 train_time:151811ms step_avg:92.79ms +step:1637/1645 train_time:151907ms step_avg:92.80ms +step:1638/1645 train_time:152005ms step_avg:92.80ms +step:1639/1645 train_time:152100ms step_avg:92.80ms +step:1640/1645 train_time:152194ms step_avg:92.80ms +step:1641/1645 train_time:152287ms step_avg:92.80ms +step:1642/1645 train_time:152380ms step_avg:92.80ms +step:1643/1645 train_time:152473ms step_avg:92.80ms +step:1644/1645 train_time:152567ms step_avg:92.80ms +step:1645/1645 train_time:152659ms step_avg:92.80ms +step:1645/1645 val_loss:3.2786 train_time:152754ms step_avg:92.86ms +peak memory allocated: 32074 MiB reserved: 46736 MiB diff --git a/records/091825_Smear/692ba682-7466-4b59-a31a-7e0adcb55b4b.txt b/records/091825_Smear/692ba682-7466-4b59-a31a-7e0adcb55b4b.txt new file mode 100644 index 000000000..8b370005d --- /dev/null +++ b/records/091825_Smear/692ba682-7466-4b59-a31a-7e0adcb55b4b.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging 
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
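+        # For orientation: per parameter, the sharded/batched code below computes
+        # the plain Muon update (a sketch; per-param lr_mul/wd_mul multipliers omitted):
+        #   buf.lerp_(g, 1 - momentum)                 # buf = m*buf + (1-m)*g
+        #   g = g.lerp(buf, momentum)                  # Nesterov-style blend
+        #   v = newton_schulz_triton(g)                # orthogonalize the update
+        #   p.mul_(1 - lr * weight_decay)              # decoupled weight decay
+        #   p.add_(v, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)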
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor so the batched zeropower input stays aligned.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
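+                # Optimizer state (the momentum buffer) is created lazily and only
+                # for parameters in this rank's shard, so each rank holds roughly
+                # 1/world_size of the total momentum memory.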
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
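+        # Every rank receives the full padded stack; the loop below copies back
+        # only the real parameters, so zero-padded tail entries are discarded.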
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
+
+
+class DistAdam(torch.optim.Optimizer):
+    def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    # DistributedAdam implementation by @vagrawal
+    @torch.compile
+    @torch.no_grad()
+    def step(self):
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        reduce_scatter_futures: list[torch.Future] = []
+        all_gather_futures: list[torch.Future] = []
+        grad_slices = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            for base_i in range(len(params)):
+                grad = params[base_i].grad
+                rank_size = grad.shape[0] // world_size
+                grad_slice = torch.empty_like(grad[:rank_size])
+                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
+                grad_slices.append(grad_slice)
+
+        idx = 0
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            eps = group['eps']
+            wd = group['weight_decay']
+            params = group['params']
+            for base in range(len(params)):
+                reduce_scatter_futures[idx].wait()
+                p = params[base]
+                rank_size = p.shape[0] // world_size
+                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
+                state = self.state[p]
+                g_slice = grad_slices[idx]
+                # State init
+                if not state:
+                    state["step"] = torch.tensor(0, dtype=torch.int64, device=p.device)
+                    state["exp_avg"] = torch.zeros(p_slice.shape, dtype=torch.bfloat16, device=p_slice.device)
+                    state["exp_avg_sq"] = torch.zeros(p_slice.shape, dtype=torch.bfloat16, device=p_slice.device)
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = state["exp_avg_sq"]
+                state["step"] += 1
+                t = state["step"]
+                # weight decay
+                if wd != 0:
+                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
+                    p_slice.mul_(1 - eff_weight_decay)
+                # update running averages
+                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
+                # bias corrections
+                bias1 = 1 - beta1 ** t
+                bias2 = 1 - beta2 ** t
+                # compute step
+                denom = exp_avg_sq.sqrt().add_(eps)
+                step_size = lr * (torch.sqrt(bias2) / bias1)
+                update = exp_avg.div(denom).mul_(step_size)
+                p_slice.add_(other=update, alpha=-1.0)
+                idx += 1
+                all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
+        torch.futures.collect_all(all_gather_futures).wait()
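+
+# Both optimizers above share one ZeRO-style pattern: reduce_scatter the gradients
+# so each rank averages and updates only its own shard of the parameters, then
+# all_gather the updated shards so every rank ends the step with identical
+# weights. A minimal synchronous sketch of that pattern (illustrative only, not
+# called by this script; assumes dist is initialized and p.shape[0] is divisible
+# by the world size):
+@torch.no_grad()
+def _sharded_sgd_step_demo(p: Tensor, lr: float):
+    rank, world_size = dist.get_rank(), dist.get_world_size()
+    rank_size = p.shape[0] // world_size
+    g_slice = torch.empty_like(p.grad[:rank_size])
+    # each rank receives the world-average of its own slice of the gradient
+    dist.reduce_scatter_tensor(g_slice, p.grad, op=dist.ReduceOp.AVG)
+    p_slice = p[rank * rank_size:(rank + 1) * rank_size]
+    p_slice.add_(g_slice, alpha=-lr)  # update only the local shard
+    # re-assemble the full parameter from every rank's updated shard
+    dist.all_gather_into_tensor(p, p_slice)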
+
+# -----------------------------------------------------------------------------
+# PyTorch nn.Module definitions for the model
+
+def norm(x: Tensor):
+    return F.rms_norm(x, (x.size(-1),))
+
+class CastedLinear(nn.Linear):
+    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
+        super().__init__(in_features, out_features, bias=False)
+        self.use_fp8 = use_fp8
+        self.x_s = x_s
+        self.w_s = w_s
+        self.grad_s = grad_s
+
+    def reset_parameters(self) -> None:
+        std = 0.5 * (self.in_features ** -0.5)  # 0.5 is a bit better than the default 1/sqrt(3)
+        bound = (3 ** 0.5) * std
+        with torch.no_grad():
+            self.weight.uniform_(-bound, bound)
+
+    def forward(self, x: Tensor):
+        if self.use_fp8 and self.training:
+            _x = x.flatten(0, -2)
+            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
+            return out.reshape(*x.shape[:-1], -1)
+        else:
+            return F.linear(x, self.weight.type_as(x))
+
+# yarn implementation @classiclarryd
+class Yarn(nn.Module):
+    def __init__(self, head_dim, max_seq_len):
+        super().__init__()
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.reset()
+
+    def reset(self):
+        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device)
+        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
+        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)])
+        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device)
+        theta = torch.outer(t, angular_freq)
+        self.cos = nn.Buffer(theta.cos().to(torch.bfloat16), persistent=False)
+        self.sin = nn.Buffer(theta.sin().to(torch.bfloat16), persistent=False)
+        self.angular_freq = angular_freq
+        # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.1
+
+    def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32):
+        rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi)
+        scaling_factor = old_window / new_window
+        interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
+        self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
+        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
+        theta = torch.outer(t, self.angular_freq)
+        self.cos.copy_(theta.cos())
+        self.sin.copy_(theta.sin())
+        self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1
+
+def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
+    assert cos.size(0) >= x_BTHD.size(-3)
+    cos, sin = (
+        cos[None, : x_BTHD.size(-3), None, :],
+        sin[None, : x_BTHD.size(-3), None, :],
+    )
+    x1, x2 = x_BTHD.chunk(2, dim=-1)
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat((y1, y2), 3)
+
+@dataclass
+class AttnArgs:
+    ve: torch.Tensor
+    sa_lambdas: torch.Tensor
+    seqlens: torch.Tensor
+    bm_size: int
+    cos: torch.Tensor
+    sin: torch.Tensor
+    attn_scale: float
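+
+# A small worked sketch of the YaRN rescaling in Yarn.apply above (illustrative,
+# not called by the training script; `block_size` stands in for args.block_size).
+# Frequencies that complete fewer than `alpha` rotations over the old window are
+# fully rescaled by old_window/new_window (positions interpolated), frequencies
+# beyond `beta` rotations are kept as-is, and the band in between is linearly
+# interpolated:
+def _yarn_rescale_demo(angular_freq: Tensor, old_window: int, new_window: int,
+                       block_size: int = 128, alpha: int = 1, beta: int = 32) -> Tensor:
+    rotations = block_size * old_window * angular_freq / (2 * torch.pi)
+    scaling_factor = old_window / new_window
+    interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
+    # weight 0 -> multiply by scaling_factor; weight 1 -> leave the frequency unchanged
+    return angular_freq * (scaling_factor + interpolation_weight * (1 - scaling_factor))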
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std  # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound)  # init QKV weights
+            self.qkvo_w[3].zero_()  # init output weights to zero
+
+        # sparse gated attention to enable a context-based no-op by @classiclarryd
+        self.attn_gate = CastedLinear(12, num_heads)
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, attn_args: AttnArgs):
+        B, T = x.size(0), x.size(1)  # batch size, sequence length
+        assert B == 1, "varlen sequences requires B == 1"
+        assert T % 16 == 0
+        # unpack attention args
+        cos, sin = attn_args.cos, attn_args.sin
+        ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k)  # QK norm @Grad62304977
+        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+        if ve is not None:
+            v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v)  # @KoszarskyB & @Grad62304977
+        else:  # skip mid-layers token value embeddings by @YouJiacheng
+            v = sa_lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        # per-head sigmoid gate computed from the first 12 channels of the block input
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim)  # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because the optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by the 8-GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std  # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_()  # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square()  # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None
+        # skip the MLP of the first layer by @EmelyanenkoK
+        self.mlp = MLP(dim) if layer_idx != 0 else None
+
+    def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), attn_args)
+        if self.mlp is not None:
+            x = x + self.mlp(norm(x))
+        return x
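+
+# The attention gate above is the change this record tests: each head's output is
+# scaled by the sigmoid of a linear read-out of the first 12 channels of the
+# block input, and the gate weights start at zero so every gate opens at 0.5. A
+# self-contained sketch of the idea (illustrative shapes and names, not the exact
+# module):
+def _sparse_attn_gate_demo(x: Tensor, head_out: Tensor, gate_w: Tensor) -> Tensor:
+    # x: (B, T, dim) block input; head_out: (B, T, H, Dh) per-head attention output
+    # gate_w: (H, gate_dim) zero-initialized gate weights
+    gate_dim = gate_w.size(-1)
+    gate = torch.sigmoid(F.linear(x[..., :gate_dim], gate_w))  # (B, T, H)
+    return head_out * gate.unsqueeze(-1)  # a gate near 0 lets a head no-op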
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)])
+        self.yarn = Yarn(head_dim, max_seq_len)
+        # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
+        use_fp8 = not os.environ.get("DISABLE_FP8", False)
+        self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448)
+        self.lm_head.weight.detach().zero_()  # @Grad62304977
+        # Add learnable skip connection weights for decoder layers
+        assert num_layers % 2 == 0
+        pad = (-num_layers * 6) % dist.get_world_size()
+        self.scalars = nn.Parameter(
+            torch.cat(
+                [
+                    -1.5 * torch.ones(num_layers),  # skip_weights -> σ(-1.5) ≈ 0.18
+                    *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)],  # block lambdas
+                    *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)],  # SA lambdas
+                    torch.zeros(num_layers),  # extra zero params for smear_lambda
+                    torch.ones(pad),
+                ]
+            )
+        )
+        self.smear_gate = CastedLinear(12, 1)
+        self.smear_gate.weight.detach().zero_()
+
+        # set learning rates
+        for param in self.embed.parameters():
+            param.lr_mul = 75.
+        for param in self.value_embeds.parameters():
+            param.lr_mul = 75.
+        self.lm_head.weight.lr_mul = 1.0
+        self.scalars.lr_mul = 5.0
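+
+    # Layout of the flat self.scalars vector, with L = num_layers:
+    #   [0, L)    skip-gate logits (forward uses the first L/2, one per decoder-half block)
+    #   [L, 3L)   per-block (lambda_x, lambda_x0) residual-mixing pairs
+    #   [3L, 5L)  per-block value-embedding mixing pairs (sa_lambdas)
+    #   [5L, 6L)  smear_lambda slots (forward reads index 5L)
+    #   remainder ones-padding so the scalar count divides the world size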
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int):
+        assert input_seq.ndim == 1
+
+        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
+        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        final_bm = ws_final_layer * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])  # use of norm here by @Grad62304977
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale,
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n])  # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following the Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32)  # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2])  # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)  # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy())  # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
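+
+# For reference, a hedged sketch of a writer producing shards that
+# _load_data_shard accepts: a 256-int32 header (magic 20240520, version 1, token
+# count) followed by uint16 token ids. Illustrative only; the real shards come
+# from the fineweb preprocessing pipeline.
+def _write_data_shard_demo(path: str, tokens) -> None:
+    import numpy as np  # local import, mirroring the script's lazy-import style
+    header = np.zeros(256, dtype=np.int32)
+    header[0] = 20240520  # magic number checked by _load_data_shard
+    header[1] = 1  # format version
+    header[2] = len(tokens)  # number of tokens following the header
+    with open(path, "wb") as f:
+        f.write(header.tobytes())
+        f.write(np.asarray(tokens, dtype=np.uint16).tobytes())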
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full BOS index after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing its BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a BOS token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted; switch to the preloaded next shard and retry.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # the last document ran one token long to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True),
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
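+
+# Note on the generator's .send() protocol (illustrative; the loops below drive
+# it with plain next()): since the batch tuple is produced by a bare `yield`
+# expression, a caller can retune the batch geometry mid-run, e.g.
+#
+#     loader = distributed_data_generator(args.train_files, args.train_batch_size,
+#                                         args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+#     inputs, targets, cum_seqlens = next(loader)
+#     inputs, targets, cum_seqlens = loader.send((new_num_tokens, new_max_seq_len, new_grad_accum_steps))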
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1645  # number of iterations to run
+    cooldown_frac: float = 0.5  # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"smear/{uuid.uuid4()}"
+    val_loss_every: int = 125  # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20  # the final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0)  # this process will do logging, checkpointing etc.
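+
+# Batch-geometry bookkeeping (a worked example of the settings above): the run
+# is tuned for 8 GPUs. On world_size=8, grad_accum_steps=1 and each rank sees
+# train_batch_size/8 tokens per optimizer step; on world_size=4,
+# grad_accum_steps=2, so each rank runs two sequential micro-batches and the
+# global token count per optimizer step (train_batch_size) is unchanged.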
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess  # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers])  # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)]  # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1)  # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:40:45 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 34C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 34C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1645 train_time:137ms step_avg:137.41ms +step:2/1645 train_time:159ms step_avg:79.52ms +step:3/1645 train_time:226ms step_avg:75.39ms +step:4/1645 train_time:316ms step_avg:78.97ms +step:5/1645 train_time:407ms step_avg:81.36ms +step:6/1645 train_time:497ms step_avg:82.90ms +step:7/1645 train_time:588ms step_avg:84.02ms +step:8/1645 train_time:680ms step_avg:84.96ms +step:9/1645 train_time:771ms step_avg:85.63ms +step:10/1645 train_time:861ms step_avg:86.13ms +step:11/1645 train_time:952ms step_avg:86.56ms +step:12/1645 train_time:1044ms step_avg:86.96ms +step:13/1645 train_time:1137ms step_avg:87.47ms +step:14/1645 train_time:1230ms step_avg:87.85ms +step:15/1645 train_time:1322ms step_avg:88.13ms +step:16/1645 train_time:1414ms step_avg:88.36ms +step:17/1645 train_time:1505ms step_avg:88.55ms +step:18/1645 train_time:1597ms step_avg:88.70ms +step:19/1645 train_time:1688ms step_avg:88.87ms +step:20/1645 train_time:1780ms step_avg:89.00ms +step:21/1645 train_time:1871ms step_avg:89.09ms +step:22/1645 train_time:1963ms step_avg:89.21ms +step:23/1645 
train_time:2054ms step_avg:89.30ms +step:24/1645 train_time:2147ms step_avg:89.46ms +step:25/1645 train_time:2240ms step_avg:89.58ms +step:26/1645 train_time:2332ms step_avg:89.68ms +step:27/1645 train_time:2425ms step_avg:89.82ms +step:28/1645 train_time:2516ms step_avg:89.87ms +step:29/1645 train_time:2608ms step_avg:89.94ms +step:30/1645 train_time:2700ms step_avg:89.99ms +step:31/1645 train_time:2792ms step_avg:90.05ms +step:32/1645 train_time:2884ms step_avg:90.13ms +step:33/1645 train_time:2975ms step_avg:90.15ms +step:34/1645 train_time:3067ms step_avg:90.20ms +step:35/1645 train_time:3159ms step_avg:90.26ms +step:36/1645 train_time:3251ms step_avg:90.31ms +step:37/1645 train_time:3344ms step_avg:90.37ms +step:38/1645 train_time:3435ms step_avg:90.41ms +step:39/1645 train_time:3527ms step_avg:90.44ms +step:40/1645 train_time:3620ms step_avg:90.49ms +step:41/1645 train_time:3711ms step_avg:90.52ms +step:42/1645 train_time:3804ms step_avg:90.57ms +step:43/1645 train_time:3896ms step_avg:90.60ms +step:44/1645 train_time:3987ms step_avg:90.62ms +step:45/1645 train_time:4079ms step_avg:90.64ms +step:46/1645 train_time:4171ms step_avg:90.68ms +step:47/1645 train_time:4263ms step_avg:90.70ms +step:48/1645 train_time:4354ms step_avg:90.72ms +step:49/1645 train_time:4446ms step_avg:90.74ms +step:50/1645 train_time:4538ms step_avg:90.75ms +step:51/1645 train_time:4629ms step_avg:90.77ms +step:52/1645 train_time:4721ms step_avg:90.79ms +step:53/1645 train_time:4813ms step_avg:90.82ms +step:54/1645 train_time:4906ms step_avg:90.85ms +step:55/1645 train_time:4998ms step_avg:90.87ms +step:56/1645 train_time:5089ms step_avg:90.88ms +step:57/1645 train_time:5181ms step_avg:90.90ms +step:58/1645 train_time:5273ms step_avg:90.92ms +step:59/1645 train_time:5365ms step_avg:90.94ms +step:60/1645 train_time:5458ms step_avg:90.97ms +step:61/1645 train_time:5549ms step_avg:90.97ms +step:62/1645 train_time:5641ms step_avg:90.98ms +step:63/1645 train_time:5733ms step_avg:91.00ms +step:64/1645 train_time:5826ms step_avg:91.03ms +step:65/1645 train_time:5918ms step_avg:91.05ms +step:66/1645 train_time:6010ms step_avg:91.06ms +step:67/1645 train_time:6102ms step_avg:91.07ms +step:68/1645 train_time:6195ms step_avg:91.10ms +step:69/1645 train_time:6287ms step_avg:91.12ms +step:70/1645 train_time:6379ms step_avg:91.13ms +step:71/1645 train_time:6470ms step_avg:91.12ms +step:72/1645 train_time:6561ms step_avg:91.13ms +step:73/1645 train_time:6653ms step_avg:91.13ms +step:74/1645 train_time:6744ms step_avg:91.14ms +step:75/1645 train_time:6836ms step_avg:91.15ms +step:76/1645 train_time:6929ms step_avg:91.17ms +step:77/1645 train_time:7021ms step_avg:91.19ms +step:78/1645 train_time:7113ms step_avg:91.20ms +step:79/1645 train_time:7205ms step_avg:91.21ms +step:80/1645 train_time:7298ms step_avg:91.22ms +step:81/1645 train_time:7390ms step_avg:91.23ms +step:82/1645 train_time:7481ms step_avg:91.23ms +step:83/1645 train_time:7572ms step_avg:91.23ms +step:84/1645 train_time:7663ms step_avg:91.23ms +step:85/1645 train_time:7754ms step_avg:91.23ms +step:86/1645 train_time:7846ms step_avg:91.24ms +step:87/1645 train_time:7938ms step_avg:91.24ms +step:88/1645 train_time:8030ms step_avg:91.25ms +step:89/1645 train_time:8123ms step_avg:91.26ms +step:90/1645 train_time:8215ms step_avg:91.28ms +step:91/1645 train_time:8307ms step_avg:91.29ms +step:92/1645 train_time:8400ms step_avg:91.30ms +step:93/1645 train_time:8491ms step_avg:91.30ms +step:94/1645 train_time:8582ms step_avg:91.30ms +step:95/1645 train_time:8674ms 
step_avg:91.31ms +step:96/1645 train_time:8765ms step_avg:91.30ms +step:97/1645 train_time:8857ms step_avg:91.31ms +step:98/1645 train_time:8949ms step_avg:91.31ms +step:99/1645 train_time:9040ms step_avg:91.31ms +step:100/1645 train_time:9132ms step_avg:91.32ms +step:101/1645 train_time:9224ms step_avg:91.33ms +step:102/1645 train_time:9316ms step_avg:91.33ms +step:103/1645 train_time:9408ms step_avg:91.34ms +step:104/1645 train_time:9501ms step_avg:91.35ms +step:105/1645 train_time:9592ms step_avg:91.35ms +step:106/1645 train_time:9684ms step_avg:91.35ms +step:107/1645 train_time:9775ms step_avg:91.36ms +step:108/1645 train_time:9866ms step_avg:91.35ms +step:109/1645 train_time:9959ms step_avg:91.36ms +step:110/1645 train_time:10050ms step_avg:91.37ms +step:111/1645 train_time:10143ms step_avg:91.38ms +step:112/1645 train_time:10234ms step_avg:91.38ms +step:113/1645 train_time:10327ms step_avg:91.39ms +step:114/1645 train_time:10421ms step_avg:91.41ms +step:115/1645 train_time:10512ms step_avg:91.41ms +step:116/1645 train_time:10603ms step_avg:91.40ms +step:117/1645 train_time:10695ms step_avg:91.41ms +step:118/1645 train_time:10786ms step_avg:91.41ms +step:119/1645 train_time:10877ms step_avg:91.40ms +step:120/1645 train_time:10968ms step_avg:91.40ms +step:121/1645 train_time:11059ms step_avg:91.40ms +step:122/1645 train_time:11152ms step_avg:91.41ms +step:123/1645 train_time:11243ms step_avg:91.41ms +step:124/1645 train_time:11335ms step_avg:91.41ms +step:125/1645 train_time:11428ms step_avg:91.43ms +step:125/1645 val_loss:4.3128 train_time:11521ms step_avg:92.16ms +step:126/1645 train_time:11535ms step_avg:91.55ms +step:127/1645 train_time:11616ms step_avg:91.46ms +step:128/1645 train_time:11718ms step_avg:91.55ms +step:129/1645 train_time:11812ms step_avg:91.57ms +step:130/1645 train_time:11904ms step_avg:91.57ms +step:131/1645 train_time:11995ms step_avg:91.57ms +step:132/1645 train_time:12086ms step_avg:91.56ms +step:133/1645 train_time:12176ms step_avg:91.55ms +step:134/1645 train_time:12267ms step_avg:91.54ms +step:135/1645 train_time:12357ms step_avg:91.54ms +step:136/1645 train_time:12448ms step_avg:91.53ms +step:137/1645 train_time:12540ms step_avg:91.53ms +step:138/1645 train_time:12633ms step_avg:91.55ms +step:139/1645 train_time:12729ms step_avg:91.57ms +step:140/1645 train_time:12822ms step_avg:91.58ms +step:141/1645 train_time:12913ms step_avg:91.58ms +step:142/1645 train_time:13005ms step_avg:91.58ms +step:143/1645 train_time:13095ms step_avg:91.58ms +step:144/1645 train_time:13187ms step_avg:91.57ms +step:145/1645 train_time:13278ms step_avg:91.57ms +step:146/1645 train_time:13369ms step_avg:91.57ms +step:147/1645 train_time:13460ms step_avg:91.56ms +step:148/1645 train_time:13552ms step_avg:91.56ms +step:149/1645 train_time:13645ms step_avg:91.58ms +step:150/1645 train_time:13738ms step_avg:91.59ms +step:151/1645 train_time:13831ms step_avg:91.60ms +step:152/1645 train_time:13924ms step_avg:91.60ms +step:153/1645 train_time:14014ms step_avg:91.60ms +step:154/1645 train_time:14106ms step_avg:91.60ms +step:155/1645 train_time:14198ms step_avg:91.60ms +step:156/1645 train_time:14288ms step_avg:91.59ms +step:157/1645 train_time:14380ms step_avg:91.59ms +step:158/1645 train_time:14470ms step_avg:91.58ms +step:159/1645 train_time:14562ms step_avg:91.58ms +step:160/1645 train_time:14655ms step_avg:91.59ms +step:161/1645 train_time:14749ms step_avg:91.61ms +step:162/1645 train_time:14844ms step_avg:91.63ms +step:163/1645 train_time:14934ms step_avg:91.62ms +step:164/1645 
train_time:15026ms step_avg:91.62ms +step:165/1645 train_time:15117ms step_avg:91.62ms +step:166/1645 train_time:15208ms step_avg:91.62ms +step:167/1645 train_time:15299ms step_avg:91.61ms +step:168/1645 train_time:15390ms step_avg:91.61ms +step:169/1645 train_time:15481ms step_avg:91.61ms +step:170/1645 train_time:15572ms step_avg:91.60ms +step:171/1645 train_time:15665ms step_avg:91.61ms +step:172/1645 train_time:15758ms step_avg:91.62ms +step:173/1645 train_time:15850ms step_avg:91.62ms +step:174/1645 train_time:15942ms step_avg:91.62ms +step:175/1645 train_time:16033ms step_avg:91.62ms +step:176/1645 train_time:16126ms step_avg:91.62ms +step:177/1645 train_time:16217ms step_avg:91.62ms +step:178/1645 train_time:16308ms step_avg:91.62ms +step:179/1645 train_time:16399ms step_avg:91.62ms +step:180/1645 train_time:16490ms step_avg:91.61ms +step:181/1645 train_time:16582ms step_avg:91.61ms +step:182/1645 train_time:16673ms step_avg:91.61ms +step:183/1645 train_time:16765ms step_avg:91.61ms +step:184/1645 train_time:16857ms step_avg:91.62ms +step:185/1645 train_time:16949ms step_avg:91.62ms +step:186/1645 train_time:17041ms step_avg:91.62ms +step:187/1645 train_time:17133ms step_avg:91.62ms +step:188/1645 train_time:17225ms step_avg:91.62ms +step:189/1645 train_time:17316ms step_avg:91.62ms +step:190/1645 train_time:17407ms step_avg:91.61ms +step:191/1645 train_time:17498ms step_avg:91.61ms +step:192/1645 train_time:17590ms step_avg:91.61ms +step:193/1645 train_time:17681ms step_avg:91.61ms +step:194/1645 train_time:17773ms step_avg:91.62ms +step:195/1645 train_time:17867ms step_avg:91.63ms +step:196/1645 train_time:17960ms step_avg:91.63ms +step:197/1645 train_time:18051ms step_avg:91.63ms +step:198/1645 train_time:18143ms step_avg:91.63ms +step:199/1645 train_time:18234ms step_avg:91.63ms +step:200/1645 train_time:18325ms step_avg:91.63ms +step:201/1645 train_time:18416ms step_avg:91.62ms +step:202/1645 train_time:18507ms step_avg:91.62ms +step:203/1645 train_time:18598ms step_avg:91.62ms +step:204/1645 train_time:18689ms step_avg:91.61ms +step:205/1645 train_time:18781ms step_avg:91.61ms +step:206/1645 train_time:18873ms step_avg:91.62ms +step:207/1645 train_time:18966ms step_avg:91.62ms +step:208/1645 train_time:19059ms step_avg:91.63ms +step:209/1645 train_time:19150ms step_avg:91.63ms +step:210/1645 train_time:19242ms step_avg:91.63ms +step:211/1645 train_time:19333ms step_avg:91.63ms +step:212/1645 train_time:19424ms step_avg:91.62ms +step:213/1645 train_time:19516ms step_avg:91.63ms +step:214/1645 train_time:19608ms step_avg:91.62ms +step:215/1645 train_time:19700ms step_avg:91.63ms +step:216/1645 train_time:19791ms step_avg:91.63ms +step:217/1645 train_time:19884ms step_avg:91.63ms +step:218/1645 train_time:19976ms step_avg:91.63ms +step:219/1645 train_time:20068ms step_avg:91.64ms +step:220/1645 train_time:20160ms step_avg:91.64ms +step:221/1645 train_time:20253ms step_avg:91.64ms +step:222/1645 train_time:20344ms step_avg:91.64ms +step:223/1645 train_time:20436ms step_avg:91.64ms +step:224/1645 train_time:20528ms step_avg:91.64ms +step:225/1645 train_time:20619ms step_avg:91.64ms +step:226/1645 train_time:20710ms step_avg:91.64ms +step:227/1645 train_time:20801ms step_avg:91.64ms +step:228/1645 train_time:20894ms step_avg:91.64ms +step:229/1645 train_time:20986ms step_avg:91.64ms +step:230/1645 train_time:21077ms step_avg:91.64ms +step:231/1645 train_time:21170ms step_avg:91.64ms +step:232/1645 train_time:21261ms step_avg:91.64ms +step:233/1645 train_time:21353ms step_avg:91.64ms 
+step:234/1645 train_time:21445ms step_avg:91.64ms +step:235/1645 train_time:21536ms step_avg:91.64ms +step:236/1645 train_time:21627ms step_avg:91.64ms +step:237/1645 train_time:21719ms step_avg:91.64ms +step:238/1645 train_time:21810ms step_avg:91.64ms +step:239/1645 train_time:21901ms step_avg:91.64ms +step:240/1645 train_time:21993ms step_avg:91.64ms +step:241/1645 train_time:22085ms step_avg:91.64ms +step:242/1645 train_time:22177ms step_avg:91.64ms +step:243/1645 train_time:22269ms step_avg:91.64ms +step:244/1645 train_time:22360ms step_avg:91.64ms +step:245/1645 train_time:22452ms step_avg:91.64ms +step:246/1645 train_time:22544ms step_avg:91.64ms +step:247/1645 train_time:22636ms step_avg:91.65ms +step:248/1645 train_time:22728ms step_avg:91.65ms +step:249/1645 train_time:22819ms step_avg:91.64ms +step:250/1645 train_time:22911ms step_avg:91.64ms +step:250/1645 val_loss:3.9734 train_time:23002ms step_avg:92.01ms +step:251/1645 train_time:23017ms step_avg:91.70ms +step:252/1645 train_time:23096ms step_avg:91.65ms +step:253/1645 train_time:23190ms step_avg:91.66ms +step:254/1645 train_time:23281ms step_avg:91.66ms +step:255/1645 train_time:23372ms step_avg:91.66ms +step:256/1645 train_time:23463ms step_avg:91.65ms +step:257/1645 train_time:23553ms step_avg:91.65ms +step:258/1645 train_time:23644ms step_avg:91.64ms +step:259/1645 train_time:23735ms step_avg:91.64ms +step:260/1645 train_time:23826ms step_avg:91.64ms +step:261/1645 train_time:23919ms step_avg:91.64ms +step:262/1645 train_time:24013ms step_avg:91.65ms +step:263/1645 train_time:24106ms step_avg:91.66ms +step:264/1645 train_time:24198ms step_avg:91.66ms +step:265/1645 train_time:24290ms step_avg:91.66ms +step:266/1645 train_time:24380ms step_avg:91.66ms +step:267/1645 train_time:24472ms step_avg:91.66ms +step:268/1645 train_time:24563ms step_avg:91.65ms +step:269/1645 train_time:24654ms step_avg:91.65ms +step:270/1645 train_time:24744ms step_avg:91.65ms +step:271/1645 train_time:24837ms step_avg:91.65ms +step:272/1645 train_time:24928ms step_avg:91.65ms +step:273/1645 train_time:25021ms step_avg:91.65ms +step:274/1645 train_time:25114ms step_avg:91.66ms +step:275/1645 train_time:25205ms step_avg:91.65ms +step:276/1645 train_time:25298ms step_avg:91.66ms +step:277/1645 train_time:25389ms step_avg:91.66ms +step:278/1645 train_time:25481ms step_avg:91.66ms +step:279/1645 train_time:25572ms step_avg:91.66ms +step:280/1645 train_time:25663ms step_avg:91.65ms +step:281/1645 train_time:25753ms step_avg:91.65ms +step:282/1645 train_time:25845ms step_avg:91.65ms +step:283/1645 train_time:25937ms step_avg:91.65ms +step:284/1645 train_time:26029ms step_avg:91.65ms +step:285/1645 train_time:26121ms step_avg:91.65ms +step:286/1645 train_time:26214ms step_avg:91.66ms +step:287/1645 train_time:26307ms step_avg:91.66ms +step:288/1645 train_time:26399ms step_avg:91.66ms +step:289/1645 train_time:26491ms step_avg:91.67ms +step:290/1645 train_time:26582ms step_avg:91.66ms +step:291/1645 train_time:26673ms step_avg:91.66ms +step:292/1645 train_time:26764ms step_avg:91.66ms +step:293/1645 train_time:26856ms step_avg:91.66ms +step:294/1645 train_time:26948ms step_avg:91.66ms +step:295/1645 train_time:27039ms step_avg:91.66ms +step:296/1645 train_time:27131ms step_avg:91.66ms +step:297/1645 train_time:27223ms step_avg:91.66ms +step:298/1645 train_time:27316ms step_avg:91.66ms +step:299/1645 train_time:27408ms step_avg:91.66ms +step:300/1645 train_time:27500ms step_avg:91.67ms +step:301/1645 train_time:27591ms step_avg:91.67ms +step:302/1645 
train_time:27683ms step_avg:91.66ms +step:303/1645 train_time:27773ms step_avg:91.66ms +step:304/1645 train_time:27864ms step_avg:91.66ms +step:305/1645 train_time:27955ms step_avg:91.66ms +step:306/1645 train_time:28047ms step_avg:91.66ms +step:307/1645 train_time:28139ms step_avg:91.66ms +step:308/1645 train_time:28231ms step_avg:91.66ms +step:309/1645 train_time:28323ms step_avg:91.66ms +step:310/1645 train_time:28416ms step_avg:91.66ms +step:311/1645 train_time:28507ms step_avg:91.66ms +step:312/1645 train_time:28599ms step_avg:91.66ms +step:313/1645 train_time:28691ms step_avg:91.66ms +step:314/1645 train_time:28782ms step_avg:91.66ms +step:315/1645 train_time:28874ms step_avg:91.66ms +step:316/1645 train_time:28965ms step_avg:91.66ms +step:317/1645 train_time:29056ms step_avg:91.66ms +step:318/1645 train_time:29147ms step_avg:91.66ms +step:319/1645 train_time:29239ms step_avg:91.66ms +step:320/1645 train_time:29331ms step_avg:91.66ms +step:321/1645 train_time:29422ms step_avg:91.66ms +step:322/1645 train_time:29514ms step_avg:91.66ms +step:323/1645 train_time:29606ms step_avg:91.66ms +step:324/1645 train_time:29698ms step_avg:91.66ms +step:325/1645 train_time:29790ms step_avg:91.66ms +step:326/1645 train_time:29881ms step_avg:91.66ms +step:327/1645 train_time:29973ms step_avg:91.66ms +step:328/1645 train_time:30064ms step_avg:91.66ms +step:329/1645 train_time:30155ms step_avg:91.66ms +step:330/1645 train_time:30247ms step_avg:91.66ms +step:331/1645 train_time:30339ms step_avg:91.66ms +step:332/1645 train_time:30430ms step_avg:91.66ms +step:333/1645 train_time:30521ms step_avg:91.66ms +step:334/1645 train_time:30613ms step_avg:91.66ms +step:335/1645 train_time:30705ms step_avg:91.66ms +step:336/1645 train_time:30796ms step_avg:91.65ms +step:337/1645 train_time:30888ms step_avg:91.66ms +step:338/1645 train_time:30980ms step_avg:91.66ms +step:339/1645 train_time:31072ms step_avg:91.66ms +step:340/1645 train_time:31162ms step_avg:91.65ms +step:341/1645 train_time:31254ms step_avg:91.65ms +step:342/1645 train_time:31345ms step_avg:91.65ms +step:343/1645 train_time:31437ms step_avg:91.65ms +step:344/1645 train_time:31528ms step_avg:91.65ms +step:345/1645 train_time:31621ms step_avg:91.65ms +step:346/1645 train_time:31714ms step_avg:91.66ms +step:347/1645 train_time:31805ms step_avg:91.66ms +step:348/1645 train_time:31898ms step_avg:91.66ms +step:349/1645 train_time:31989ms step_avg:91.66ms +step:350/1645 train_time:32080ms step_avg:91.66ms +step:351/1645 train_time:32174ms step_avg:91.66ms +step:352/1645 train_time:32264ms step_avg:91.66ms +step:353/1645 train_time:32354ms step_avg:91.66ms +step:354/1645 train_time:32446ms step_avg:91.65ms +step:355/1645 train_time:32538ms step_avg:91.66ms +step:356/1645 train_time:32629ms step_avg:91.66ms +step:357/1645 train_time:32721ms step_avg:91.66ms +step:358/1645 train_time:32813ms step_avg:91.66ms +step:359/1645 train_time:32905ms step_avg:91.66ms +step:360/1645 train_time:32998ms step_avg:91.66ms +step:361/1645 train_time:33088ms step_avg:91.66ms +step:362/1645 train_time:33180ms step_avg:91.66ms +step:363/1645 train_time:33272ms step_avg:91.66ms +step:364/1645 train_time:33363ms step_avg:91.66ms +step:365/1645 train_time:33454ms step_avg:91.65ms +step:366/1645 train_time:33545ms step_avg:91.65ms +step:367/1645 train_time:33637ms step_avg:91.65ms +step:368/1645 train_time:33729ms step_avg:91.65ms +step:369/1645 train_time:33821ms step_avg:91.66ms +step:370/1645 train_time:33913ms step_avg:91.66ms +step:371/1645 train_time:34006ms step_avg:91.66ms 
+step:372/1645 train_time:34098ms step_avg:91.66ms +step:373/1645 train_time:34189ms step_avg:91.66ms +step:374/1645 train_time:34282ms step_avg:91.66ms +step:375/1645 train_time:34372ms step_avg:91.66ms +step:375/1645 val_loss:3.8190 train_time:34463ms step_avg:91.90ms +step:376/1645 train_time:34478ms step_avg:91.70ms +step:377/1645 train_time:34558ms step_avg:91.67ms +step:378/1645 train_time:34653ms step_avg:91.67ms +step:379/1645 train_time:34745ms step_avg:91.68ms +step:380/1645 train_time:34836ms step_avg:91.67ms +step:381/1645 train_time:34927ms step_avg:91.67ms +step:382/1645 train_time:35017ms step_avg:91.67ms +step:383/1645 train_time:35108ms step_avg:91.67ms +step:384/1645 train_time:35199ms step_avg:91.66ms +step:385/1645 train_time:35290ms step_avg:91.66ms +step:386/1645 train_time:35381ms step_avg:91.66ms +step:387/1645 train_time:35474ms step_avg:91.67ms +step:388/1645 train_time:35569ms step_avg:91.67ms +step:389/1645 train_time:35661ms step_avg:91.67ms +step:390/1645 train_time:35754ms step_avg:91.68ms +step:391/1645 train_time:35846ms step_avg:91.68ms +step:392/1645 train_time:35937ms step_avg:91.68ms +step:393/1645 train_time:36028ms step_avg:91.67ms +step:394/1645 train_time:36119ms step_avg:91.67ms +step:395/1645 train_time:36209ms step_avg:91.67ms +step:396/1645 train_time:36300ms step_avg:91.67ms +step:397/1645 train_time:36392ms step_avg:91.67ms +step:398/1645 train_time:36483ms step_avg:91.67ms +step:399/1645 train_time:36576ms step_avg:91.67ms +step:400/1645 train_time:36669ms step_avg:91.67ms +step:401/1645 train_time:36761ms step_avg:91.67ms +step:402/1645 train_time:36855ms step_avg:91.68ms +step:403/1645 train_time:36948ms step_avg:91.68ms +step:404/1645 train_time:37038ms step_avg:91.68ms +step:405/1645 train_time:37129ms step_avg:91.68ms +step:406/1645 train_time:37220ms step_avg:91.67ms +step:407/1645 train_time:37310ms step_avg:91.67ms +step:408/1645 train_time:37402ms step_avg:91.67ms +step:409/1645 train_time:37494ms step_avg:91.67ms +step:410/1645 train_time:37586ms step_avg:91.67ms +step:411/1645 train_time:37678ms step_avg:91.67ms +step:412/1645 train_time:37770ms step_avg:91.68ms +step:413/1645 train_time:37862ms step_avg:91.67ms +step:414/1645 train_time:37954ms step_avg:91.68ms +step:415/1645 train_time:38046ms step_avg:91.68ms +step:416/1645 train_time:38137ms step_avg:91.67ms +step:417/1645 train_time:38228ms step_avg:91.67ms +step:418/1645 train_time:38319ms step_avg:91.67ms +step:419/1645 train_time:38411ms step_avg:91.67ms +step:420/1645 train_time:38503ms step_avg:91.67ms +step:421/1645 train_time:38595ms step_avg:91.67ms +step:422/1645 train_time:38687ms step_avg:91.67ms +step:423/1645 train_time:38779ms step_avg:91.68ms +step:424/1645 train_time:38871ms step_avg:91.68ms +step:425/1645 train_time:38962ms step_avg:91.68ms +step:426/1645 train_time:39054ms step_avg:91.68ms +step:427/1645 train_time:39145ms step_avg:91.67ms +step:428/1645 train_time:39237ms step_avg:91.67ms +step:429/1645 train_time:39328ms step_avg:91.67ms +step:430/1645 train_time:39419ms step_avg:91.67ms +step:431/1645 train_time:39511ms step_avg:91.67ms +step:432/1645 train_time:39602ms step_avg:91.67ms +step:433/1645 train_time:39694ms step_avg:91.67ms +step:434/1645 train_time:39786ms step_avg:91.67ms +step:435/1645 train_time:39878ms step_avg:91.67ms +step:436/1645 train_time:39969ms step_avg:91.67ms +step:437/1645 train_time:40060ms step_avg:91.67ms +step:438/1645 train_time:40152ms step_avg:91.67ms +step:439/1645 train_time:40244ms step_avg:91.67ms +step:440/1645 
[ per-step train_time/step_avg lines for steps 440-1645 condensed; step_avg drifts from 91.67ms to 92.83ms over this range. Validation checkpoints and final stats: ]
step:500/1645 val_loss:3.7160 train_time:45929ms step_avg:91.86ms
step:625/1645 val_loss:3.6127 train_time:57495ms step_avg:91.99ms
step:750/1645 val_loss:3.5624 train_time:69129ms step_avg:92.17ms
step:875/1645 val_loss:3.5158 train_time:80758ms step_avg:92.29ms
step:1000/1645 val_loss:3.4649 train_time:92382ms step_avg:92.38ms
step:1125/1645 val_loss:3.4131 train_time:104029ms step_avg:92.47ms
step:1250/1645 val_loss:3.3746 train_time:115743ms step_avg:92.59ms
step:1375/1645 val_loss:3.3397 train_time:127466ms step_avg:92.70ms
step:1500/1645 val_loss:3.3096 train_time:139192ms step_avg:92.79ms
step:1625/1645 val_loss:3.2858 train_time:150918ms step_avg:92.87ms
step:1645/1645 val_loss:3.2798 train_time:152796ms step_avg:92.89ms
peak memory allocated: 31659 MiB reserved: 46856 MiB
diff --git a/records/091825_Smear/81535293-56a7-49f3-925c-569441d4f87c.txt b/records/091825_Smear/81535293-56a7-49f3-925c-569441d4f87c.txt
new file mode 100644
index 000000000..13ba842e2
--- /dev/null
+++ b/records/091825_Smear/81535293-56a7-49f3-925c-569441d4f87c.txt
@@ -0,0 +1,3089 @@
import os
import sys

with open(sys.argv[0]) as f:
    code = f.read()  # read the code of this file ASAP, for logging
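# (Descriptive note, not part of the original record: the script reads its own
# source up front so the training log it prints can be stored together with the
# exact code that produced it -- each record file in this patch is such a
# code+log transcript.)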
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
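+        # High-level map of this step (each collective is launched async so
+        # communication overlaps with compute):
+        #   1. reduce-scatter the stacked per-group gradients, so each rank ends
+        #      up owning the averaged gradients for one shard of the parameters,
+        #   2. apply momentum and Newton-Schulz orthogonalization to that local
+        #      shard only,
+        #   3. all-gather the updated parameters back to every rank.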
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
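+                # The two lerps below implement Nesterov-style momentum: the buffer
+                # keeps an exponential moving average of gradients, and the update
+                # blends the raw gradient with the refreshed buffer (a lookahead)
+                # rather than using the buffer alone.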
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
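+        # Waiting only here, rather than inside the per-group loop above, keeps the
+        # all_gathers for every group in flight simultaneously.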
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + smear_lambda = self.scalars[5 * len(self.blocks)] + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == 
BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
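+                # The preloader thread has been reading and BOS-indexing the next
+                # shard in the background, so this get() usually returns at once.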
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1645 # number of iterations to run + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"smear/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer shows no degradation with context length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
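+# Worked example of the batch arithmetic above (illustrative only):
+#   >>> 2048 * 24 * 8        # args.train_batch_size: global tokens per optimizer step
+#   393216
+#   >>> 393216 // 8          # tokens per GPU-equivalent per step
+#   49152
+# so e.g. world_size = 4 runs grad_accum_steps = 2 micro-steps per optimizer step,
+# keeping the tokens per step identical to an 8-GPU run.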
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up 
each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.reset()
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the initial state so the warmup steps don't count as training
+model.yarn.reset()
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+########################################
+#       Training and validation        #
+########################################
+
+ws = args.ws_schedule[0]
+train_steps = args.num_iterations
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:20:47 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 117W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 34C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 33C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:128ms step_avg:128.03ms +step:2/1645 train_time:145ms step_avg:72.74ms +step:3/1645 train_time:219ms step_avg:72.96ms +step:4/1645 train_time:308ms step_avg:77.09ms +step:5/1645 train_time:399ms step_avg:79.82ms +step:6/1645 train_time:490ms step_avg:81.65ms +step:7/1645 train_time:581ms step_avg:82.97ms +step:8/1645 train_time:671ms step_avg:83.92ms +step:9/1645 train_time:762ms step_avg:84.67ms +step:10/1645 train_time:853ms step_avg:85.31ms +step:11/1645 train_time:944ms step_avg:85.81ms +step:12/1645 train_time:1038ms step_avg:86.52ms +step:13/1645 train_time:1136ms step_avg:87.40ms +step:14/1645 train_time:1232ms step_avg:88.02ms +step:15/1645 train_time:1324ms step_avg:88.28ms +step:16/1645 train_time:1415ms step_avg:88.46ms +step:17/1645 train_time:1508ms step_avg:88.68ms +step:18/1645 train_time:1599ms step_avg:88.81ms +step:19/1645 train_time:1691ms step_avg:89.00ms +step:20/1645 train_time:1782ms step_avg:89.11ms +step:21/1645 train_time:1873ms step_avg:89.19ms +step:22/1645 train_time:1965ms step_avg:89.33ms +step:23/1645 
train_time:2058ms step_avg:89.47ms
[... per-step log lines for steps 24-1456 elided; each reads "step:N/1645 train_time:<ms> step_avg:<ms>", originally one entry per line, with step_avg rising smoothly from ~89.5ms to ~92.7ms; the periodic validation checkpoints follow ...]
+step:125/1645 val_loss:4.3196 train_time:11529ms step_avg:92.23ms
+step:250/1645 val_loss:3.9665 train_time:23011ms step_avg:92.04ms
+step:375/1645 val_loss:3.8146 train_time:34481ms step_avg:91.95ms
+step:500/1645 val_loss:3.7127 train_time:45948ms step_avg:91.90ms
+step:625/1645 val_loss:3.6109 train_time:57522ms step_avg:92.03ms
+step:750/1645 val_loss:3.5579 train_time:69163ms step_avg:92.22ms
+step:875/1645 val_loss:3.5128 train_time:80794ms step_avg:92.34ms
+step:1000/1645 val_loss:3.4611 train_time:92428ms step_avg:92.43ms
+step:1125/1645 val_loss:3.4085 train_time:104075ms step_avg:92.51ms
+step:1250/1645 val_loss:3.3704 train_time:115795ms step_avg:92.64ms
+step:1375/1645 val_loss:3.3356 train_time:127519ms step_avg:92.74ms
+step:1457/1645 train_time:135115ms
step_avg:92.74ms +step:1458/1645 train_time:135209ms step_avg:92.74ms +step:1459/1645 train_time:135302ms step_avg:92.74ms +step:1460/1645 train_time:135396ms step_avg:92.74ms +step:1461/1645 train_time:135489ms step_avg:92.74ms +step:1462/1645 train_time:135583ms step_avg:92.74ms +step:1463/1645 train_time:135677ms step_avg:92.74ms +step:1464/1645 train_time:135770ms step_avg:92.74ms +step:1465/1645 train_time:135863ms step_avg:92.74ms +step:1466/1645 train_time:135957ms step_avg:92.74ms +step:1467/1645 train_time:136051ms step_avg:92.74ms +step:1468/1645 train_time:136145ms step_avg:92.74ms +step:1469/1645 train_time:136238ms step_avg:92.74ms +step:1470/1645 train_time:136331ms step_avg:92.74ms +step:1471/1645 train_time:136425ms step_avg:92.74ms +step:1472/1645 train_time:136519ms step_avg:92.74ms +step:1473/1645 train_time:136612ms step_avg:92.74ms +step:1474/1645 train_time:136706ms step_avg:92.74ms +step:1475/1645 train_time:136799ms step_avg:92.75ms +step:1476/1645 train_time:136893ms step_avg:92.75ms +step:1477/1645 train_time:136987ms step_avg:92.75ms +step:1478/1645 train_time:137081ms step_avg:92.75ms +step:1479/1645 train_time:137176ms step_avg:92.75ms +step:1480/1645 train_time:137270ms step_avg:92.75ms +step:1481/1645 train_time:137363ms step_avg:92.75ms +step:1482/1645 train_time:137457ms step_avg:92.75ms +step:1483/1645 train_time:137550ms step_avg:92.75ms +step:1484/1645 train_time:137645ms step_avg:92.75ms +step:1485/1645 train_time:137740ms step_avg:92.75ms +step:1486/1645 train_time:137832ms step_avg:92.75ms +step:1487/1645 train_time:137926ms step_avg:92.75ms +step:1488/1645 train_time:138020ms step_avg:92.76ms +step:1489/1645 train_time:138113ms step_avg:92.76ms +step:1490/1645 train_time:138207ms step_avg:92.76ms +step:1491/1645 train_time:138301ms step_avg:92.76ms +step:1492/1645 train_time:138394ms step_avg:92.76ms +step:1493/1645 train_time:138488ms step_avg:92.76ms +step:1494/1645 train_time:138582ms step_avg:92.76ms +step:1495/1645 train_time:138675ms step_avg:92.76ms +step:1496/1645 train_time:138769ms step_avg:92.76ms +step:1497/1645 train_time:138863ms step_avg:92.76ms +step:1498/1645 train_time:138956ms step_avg:92.76ms +step:1499/1645 train_time:139051ms step_avg:92.76ms +step:1500/1645 train_time:139145ms step_avg:92.76ms +step:1500/1645 val_loss:3.3060 train_time:139238ms step_avg:92.83ms +step:1501/1645 train_time:139260ms step_avg:92.78ms +step:1502/1645 train_time:139336ms step_avg:92.77ms +step:1503/1645 train_time:139434ms step_avg:92.77ms +step:1504/1645 train_time:139527ms step_avg:92.77ms +step:1505/1645 train_time:139620ms step_avg:92.77ms +step:1506/1645 train_time:139713ms step_avg:92.77ms +step:1507/1645 train_time:139805ms step_avg:92.77ms +step:1508/1645 train_time:139899ms step_avg:92.77ms +step:1509/1645 train_time:139992ms step_avg:92.77ms +step:1510/1645 train_time:140085ms step_avg:92.77ms +step:1511/1645 train_time:140179ms step_avg:92.77ms +step:1512/1645 train_time:140275ms step_avg:92.77ms +step:1513/1645 train_time:140369ms step_avg:92.78ms +step:1514/1645 train_time:140464ms step_avg:92.78ms +step:1515/1645 train_time:140557ms step_avg:92.78ms +step:1516/1645 train_time:140651ms step_avg:92.78ms +step:1517/1645 train_time:140743ms step_avg:92.78ms +step:1518/1645 train_time:140836ms step_avg:92.78ms +step:1519/1645 train_time:140929ms step_avg:92.78ms +step:1520/1645 train_time:141023ms step_avg:92.78ms +step:1521/1645 train_time:141117ms step_avg:92.78ms +step:1522/1645 train_time:141211ms step_avg:92.78ms +step:1523/1645 
train_time:141305ms step_avg:92.78ms +step:1524/1645 train_time:141400ms step_avg:92.78ms +step:1525/1645 train_time:141493ms step_avg:92.78ms +step:1526/1645 train_time:141587ms step_avg:92.78ms +step:1527/1645 train_time:141680ms step_avg:92.78ms +step:1528/1645 train_time:141773ms step_avg:92.78ms +step:1529/1645 train_time:141866ms step_avg:92.78ms +step:1530/1645 train_time:141959ms step_avg:92.78ms +step:1531/1645 train_time:142053ms step_avg:92.78ms +step:1532/1645 train_time:142146ms step_avg:92.78ms +step:1533/1645 train_time:142240ms step_avg:92.79ms +step:1534/1645 train_time:142335ms step_avg:92.79ms +step:1535/1645 train_time:142429ms step_avg:92.79ms +step:1536/1645 train_time:142524ms step_avg:92.79ms +step:1537/1645 train_time:142617ms step_avg:92.79ms +step:1538/1645 train_time:142712ms step_avg:92.79ms +step:1539/1645 train_time:142804ms step_avg:92.79ms +step:1540/1645 train_time:142897ms step_avg:92.79ms +step:1541/1645 train_time:142990ms step_avg:92.79ms +step:1542/1645 train_time:143083ms step_avg:92.79ms +step:1543/1645 train_time:143177ms step_avg:92.79ms +step:1544/1645 train_time:143272ms step_avg:92.79ms +step:1545/1645 train_time:143365ms step_avg:92.79ms +step:1546/1645 train_time:143459ms step_avg:92.79ms +step:1547/1645 train_time:143552ms step_avg:92.79ms +step:1548/1645 train_time:143646ms step_avg:92.79ms +step:1549/1645 train_time:143740ms step_avg:92.80ms +step:1550/1645 train_time:143833ms step_avg:92.80ms +step:1551/1645 train_time:143926ms step_avg:92.80ms +step:1552/1645 train_time:144020ms step_avg:92.80ms +step:1553/1645 train_time:144114ms step_avg:92.80ms +step:1554/1645 train_time:144208ms step_avg:92.80ms +step:1555/1645 train_time:144302ms step_avg:92.80ms +step:1556/1645 train_time:144396ms step_avg:92.80ms +step:1557/1645 train_time:144490ms step_avg:92.80ms +step:1558/1645 train_time:144585ms step_avg:92.80ms +step:1559/1645 train_time:144678ms step_avg:92.80ms +step:1560/1645 train_time:144772ms step_avg:92.80ms +step:1561/1645 train_time:144866ms step_avg:92.80ms +step:1562/1645 train_time:144959ms step_avg:92.80ms +step:1563/1645 train_time:145054ms step_avg:92.80ms +step:1564/1645 train_time:145148ms step_avg:92.81ms +step:1565/1645 train_time:145242ms step_avg:92.81ms +step:1566/1645 train_time:145336ms step_avg:92.81ms +step:1567/1645 train_time:145430ms step_avg:92.81ms +step:1568/1645 train_time:145523ms step_avg:92.81ms +step:1569/1645 train_time:145617ms step_avg:92.81ms +step:1570/1645 train_time:145711ms step_avg:92.81ms +step:1571/1645 train_time:145804ms step_avg:92.81ms +step:1572/1645 train_time:145897ms step_avg:92.81ms +step:1573/1645 train_time:145991ms step_avg:92.81ms +step:1574/1645 train_time:146084ms step_avg:92.81ms +step:1575/1645 train_time:146178ms step_avg:92.81ms +step:1576/1645 train_time:146272ms step_avg:92.81ms +step:1577/1645 train_time:146365ms step_avg:92.81ms +step:1578/1645 train_time:146459ms step_avg:92.81ms +step:1579/1645 train_time:146552ms step_avg:92.81ms +step:1580/1645 train_time:146647ms step_avg:92.81ms +step:1581/1645 train_time:146741ms step_avg:92.82ms +step:1582/1645 train_time:146834ms step_avg:92.82ms +step:1583/1645 train_time:146927ms step_avg:92.82ms +step:1584/1645 train_time:147021ms step_avg:92.82ms +step:1585/1645 train_time:147114ms step_avg:92.82ms +step:1586/1645 train_time:147207ms step_avg:92.82ms +step:1587/1645 train_time:147301ms step_avg:92.82ms +step:1588/1645 train_time:147395ms step_avg:92.82ms +step:1589/1645 train_time:147488ms step_avg:92.82ms +step:1590/1645 
train_time:147582ms step_avg:92.82ms +step:1591/1645 train_time:147676ms step_avg:92.82ms +step:1592/1645 train_time:147769ms step_avg:92.82ms +step:1593/1645 train_time:147864ms step_avg:92.82ms +step:1594/1645 train_time:147957ms step_avg:92.82ms +step:1595/1645 train_time:148052ms step_avg:92.82ms +step:1596/1645 train_time:148145ms step_avg:92.82ms +step:1597/1645 train_time:148239ms step_avg:92.82ms +step:1598/1645 train_time:148332ms step_avg:92.82ms +step:1599/1645 train_time:148425ms step_avg:92.82ms +step:1600/1645 train_time:148519ms step_avg:92.82ms +step:1601/1645 train_time:148613ms step_avg:92.82ms +step:1602/1645 train_time:148706ms step_avg:92.83ms +step:1603/1645 train_time:148800ms step_avg:92.83ms +step:1604/1645 train_time:148894ms step_avg:92.83ms +step:1605/1645 train_time:148988ms step_avg:92.83ms +step:1606/1645 train_time:149081ms step_avg:92.83ms +step:1607/1645 train_time:149174ms step_avg:92.83ms +step:1608/1645 train_time:149268ms step_avg:92.83ms +step:1609/1645 train_time:149362ms step_avg:92.83ms +step:1610/1645 train_time:149456ms step_avg:92.83ms +step:1611/1645 train_time:149551ms step_avg:92.83ms +step:1612/1645 train_time:149644ms step_avg:92.83ms +step:1613/1645 train_time:149737ms step_avg:92.83ms +step:1614/1645 train_time:149831ms step_avg:92.83ms +step:1615/1645 train_time:149924ms step_avg:92.83ms +step:1616/1645 train_time:150019ms step_avg:92.83ms +step:1617/1645 train_time:150112ms step_avg:92.83ms +step:1618/1645 train_time:150206ms step_avg:92.83ms +step:1619/1645 train_time:150300ms step_avg:92.83ms +step:1620/1645 train_time:150393ms step_avg:92.84ms +step:1621/1645 train_time:150487ms step_avg:92.84ms +step:1622/1645 train_time:150581ms step_avg:92.84ms +step:1623/1645 train_time:150675ms step_avg:92.84ms +step:1624/1645 train_time:150768ms step_avg:92.84ms +step:1625/1645 train_time:150862ms step_avg:92.84ms +step:1625/1645 val_loss:3.2821 train_time:150956ms step_avg:92.90ms +step:1626/1645 train_time:150976ms step_avg:92.85ms +step:1627/1645 train_time:151053ms step_avg:92.84ms +step:1628/1645 train_time:151149ms step_avg:92.84ms +step:1629/1645 train_time:151242ms step_avg:92.84ms +step:1630/1645 train_time:151335ms step_avg:92.84ms +step:1631/1645 train_time:151427ms step_avg:92.84ms +step:1632/1645 train_time:151520ms step_avg:92.84ms +step:1633/1645 train_time:151613ms step_avg:92.84ms +step:1634/1645 train_time:151706ms step_avg:92.84ms +step:1635/1645 train_time:151799ms step_avg:92.84ms +step:1636/1645 train_time:151896ms step_avg:92.85ms +step:1637/1645 train_time:151991ms step_avg:92.85ms +step:1638/1645 train_time:152085ms step_avg:92.85ms +step:1639/1645 train_time:152179ms step_avg:92.85ms +step:1640/1645 train_time:152272ms step_avg:92.85ms +step:1641/1645 train_time:152366ms step_avg:92.85ms +step:1642/1645 train_time:152459ms step_avg:92.85ms +step:1643/1645 train_time:152552ms step_avg:92.85ms +step:1644/1645 train_time:152645ms step_avg:92.85ms +step:1645/1645 train_time:152739ms step_avg:92.85ms +step:1645/1645 val_loss:3.2765 train_time:152834ms step_avg:92.91ms +peak memory allocated: 32074 MiB reserved: 46896 MiB diff --git a/records/091825_Smear/898a21a4-3cf7-4c32-a61b-c3427618ae7b.txt b/records/091825_Smear/898a21a4-3cf7-4c32-a61b-c3427618ae7b.txt new file mode 100644 index 000000000..90590e863 --- /dev/null +++ b/records/091825_Smear/898a21a4-3cf7-4c32-a61b-c3427618ae7b.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging 
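+# [editor note] comment added in editing, not part of the original run: reading sys.argv[0]
+# above captures this script's own source; print0(code) near the bottom writes it into the
+# log, which is why every record .txt is self-contained (full script, then console output).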
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
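+ # [editor note] outline added in editing: step() runs in three overlapped phases:
+ # (1) per shape-group, reduce_scatter the stacked grads so each rank owns one chunk;
+ # (2) each rank applies momentum and a batched Newton-Schulz to its chunk only;
+ # (3) all_gather the updated parameter chunks back to every rank.
+ # Each group is padded up to a multiple of world_size: e.g. 11 params of one shape on
+ # 8 GPUs pad to 16, so chunk_size = 2 per rank; padding entries are zeros and skipped.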
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p.
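+ # [editor note] comment added in editing: momentum_buffer.lerp_(grad, 1 - m) below keeps
+ # the EMA buf = m * buf + (1 - m) * grad, and grad.lerp(momentum_buffer, m) then forms a
+ # Nesterov-style update (1 - m) * grad + m * buf, which is what gets orthogonalized.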
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
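+ # [editor note] comment added in editing: the gathers were launched group-by-group above,
+ # so waiting only here lets the all_gather of one group overlap compute for the next.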
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + smear_lambda = self.scalars[5 * len(self.blocks)] + #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == 
BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter == 5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter += 1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration.
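+ # [editor note] comment added in editing: preloader.get() only blocks if the background
+ # thread (kicked off one shard earlier) hasn't finished reading the next shard and
+ # indexing its BOS positions; in steady state this swap is effectively free.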
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document ran one token long to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1645 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"smear/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
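+
+# Minimal usage sketch of the data generator above (illustrative only, not part of the
+# run; assumes the torchrun env vars above are set):
+#
+#     loader = distributed_data_generator(args.train_files, args.train_batch_size,
+#                                         args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+#     inputs, targets, cum_seqlens = next(loader)  # one local micro-batch, already on cuda
+#     # .send() delivers new (num_tokens, max_seq_len, grad_accum_steps) and returns the next batch:
+#     inputs, targets, cum_seqlens = loader.send((args.train_batch_size, args.train_max_seq_len, grad_accum_steps))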
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# window size schedule: short attention windows early in training, longer ones later
+def get_ws(step: int):
+    if step == args.num_iterations:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
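+
+# Worked example of the two schedules above (given num_iterations=1645,
+# cooldown_frac=0.5, ws_schedule=(3, 7, 11)):
+#   get_lr: the multiplier is 1.0 for the first half of training (x < 0.5), then decays
+#           linearly toward 0.1; e.g. at x = 0.75, w = 0.5 and lr = 0.5*1.0 + 0.5*0.1 = 0.55.
+#   get_ws: the attention window steps through 3 -> 7 -> 11 blocks over equal thirds of
+#           training, and the final validation step uses (ws_validate, ws_validate_final_layer) = (13, 20).
+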
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    # NOTE: the record elides a span starting here; the lines from this `elif` through
+    # the start of the training loop are reconstructed from the surrounding code and
+    # may differ in detail from the original run.
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# reset the model and optimizer state to how they were before warmup
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+prev_ws = ws
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws, ws_final_layer = get_ws(step)
+    if ws != prev_ws:
+        model.yarn.apply(prev_ws, ws) # re-scale rotary embeddings when the window schedule advances
+        prev_ws = ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:44:37 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 29C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 35C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 34C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:129ms step_avg:129.06ms +step:2/1645 train_time:149ms step_avg:74.73ms +step:3/1645 train_time:218ms step_avg:72.57ms +step:4/1645 train_time:307ms step_avg:76.85ms +step:5/1645 train_time:398ms step_avg:79.68ms +step:6/1645 train_time:489ms step_avg:81.52ms +step:7/1645 train_time:580ms step_avg:82.83ms +step:8/1645 train_time:671ms step_avg:83.83ms +step:9/1645 train_time:761ms step_avg:84.59ms +step:10/1645 train_time:852ms step_avg:85.16ms +step:11/1645 train_time:942ms step_avg:85.68ms +step:12/1645 train_time:1037ms step_avg:86.41ms +step:13/1645 train_time:1132ms step_avg:87.06ms +step:14/1645 train_time:1226ms step_avg:87.54ms +step:15/1645 train_time:1318ms step_avg:87.88ms +step:16/1645 train_time:1411ms step_avg:88.16ms +step:17/1645 train_time:1502ms step_avg:88.35ms +step:18/1645 train_time:1594ms step_avg:88.54ms +step:19/1645 train_time:1685ms step_avg:88.67ms +step:20/1645 train_time:1776ms step_avg:88.79ms +step:21/1645 train_time:1867ms step_avg:88.89ms +step:22/1645 train_time:1959ms step_avg:89.05ms +step:23/1645 
train_time:2052ms step_avg:89.22ms +step:24/1645 train_time:2145ms step_avg:89.36ms +step:25/1645 train_time:2238ms step_avg:89.50ms +step:26/1645 train_time:2330ms step_avg:89.61ms +step:27/1645 train_time:2422ms step_avg:89.71ms +step:28/1645 train_time:2515ms step_avg:89.81ms +step:29/1645 train_time:2607ms step_avg:89.88ms +step:30/1645 train_time:2698ms step_avg:89.94ms +step:31/1645 train_time:2790ms step_avg:90.00ms +step:32/1645 train_time:2882ms step_avg:90.05ms +step:33/1645 train_time:2974ms step_avg:90.11ms +step:34/1645 train_time:3066ms step_avg:90.18ms +step:35/1645 train_time:3160ms step_avg:90.29ms +step:36/1645 train_time:3253ms step_avg:90.36ms +step:37/1645 train_time:3345ms step_avg:90.41ms +step:38/1645 train_time:3438ms step_avg:90.48ms +step:39/1645 train_time:3530ms step_avg:90.51ms +step:40/1645 train_time:3623ms step_avg:90.56ms +step:41/1645 train_time:3715ms step_avg:90.61ms +step:42/1645 train_time:3806ms step_avg:90.62ms +step:43/1645 train_time:3898ms step_avg:90.65ms +step:44/1645 train_time:3989ms step_avg:90.67ms +step:45/1645 train_time:4082ms step_avg:90.71ms +step:46/1645 train_time:4174ms step_avg:90.73ms +step:47/1645 train_time:4267ms step_avg:90.78ms +step:48/1645 train_time:4359ms step_avg:90.81ms +step:49/1645 train_time:4451ms step_avg:90.83ms +step:50/1645 train_time:4543ms step_avg:90.86ms +step:51/1645 train_time:4635ms step_avg:90.89ms +step:52/1645 train_time:4727ms step_avg:90.91ms +step:53/1645 train_time:4819ms step_avg:90.93ms +step:54/1645 train_time:4910ms step_avg:90.93ms +step:55/1645 train_time:5002ms step_avg:90.94ms +step:56/1645 train_time:5094ms step_avg:90.97ms +step:57/1645 train_time:5186ms step_avg:90.98ms +step:58/1645 train_time:5278ms step_avg:90.99ms +step:59/1645 train_time:5370ms step_avg:91.02ms +step:60/1645 train_time:5462ms step_avg:91.03ms +step:61/1645 train_time:5554ms step_avg:91.05ms +step:62/1645 train_time:5646ms step_avg:91.07ms +step:63/1645 train_time:5738ms step_avg:91.09ms +step:64/1645 train_time:5830ms step_avg:91.10ms +step:65/1645 train_time:5921ms step_avg:91.10ms +step:66/1645 train_time:6014ms step_avg:91.12ms +step:67/1645 train_time:6106ms step_avg:91.13ms +step:68/1645 train_time:6199ms step_avg:91.16ms +step:69/1645 train_time:6290ms step_avg:91.16ms +step:70/1645 train_time:6383ms step_avg:91.18ms +step:71/1645 train_time:6474ms step_avg:91.19ms +step:72/1645 train_time:6566ms step_avg:91.20ms +step:73/1645 train_time:6658ms step_avg:91.21ms +step:74/1645 train_time:6750ms step_avg:91.21ms +step:75/1645 train_time:6842ms step_avg:91.22ms +step:76/1645 train_time:6934ms step_avg:91.24ms +step:77/1645 train_time:7025ms step_avg:91.24ms +step:78/1645 train_time:7119ms step_avg:91.26ms +step:79/1645 train_time:7210ms step_avg:91.27ms +step:80/1645 train_time:7302ms step_avg:91.28ms +step:81/1645 train_time:7394ms step_avg:91.29ms +step:82/1645 train_time:7485ms step_avg:91.29ms +step:83/1645 train_time:7578ms step_avg:91.31ms +step:84/1645 train_time:7669ms step_avg:91.30ms +step:85/1645 train_time:7760ms step_avg:91.30ms +step:86/1645 train_time:7852ms step_avg:91.31ms +step:87/1645 train_time:7944ms step_avg:91.31ms +step:88/1645 train_time:8036ms step_avg:91.32ms +step:89/1645 train_time:8128ms step_avg:91.32ms +step:90/1645 train_time:8221ms step_avg:91.34ms +step:91/1645 train_time:8312ms step_avg:91.34ms +step:92/1645 train_time:8404ms step_avg:91.35ms +step:93/1645 train_time:8495ms step_avg:91.35ms +step:94/1645 train_time:8586ms step_avg:91.34ms +step:95/1645 train_time:8678ms 
step_avg:91.34ms +step:96/1645 train_time:8769ms step_avg:91.34ms +step:97/1645 train_time:8861ms step_avg:91.35ms +step:98/1645 train_time:8952ms step_avg:91.35ms +step:99/1645 train_time:9044ms step_avg:91.35ms +step:100/1645 train_time:9136ms step_avg:91.36ms +step:101/1645 train_time:9228ms step_avg:91.37ms +step:102/1645 train_time:9321ms step_avg:91.38ms +step:103/1645 train_time:9413ms step_avg:91.39ms +step:104/1645 train_time:9504ms step_avg:91.39ms +step:105/1645 train_time:9596ms step_avg:91.39ms +step:106/1645 train_time:9687ms step_avg:91.39ms +step:107/1645 train_time:9778ms step_avg:91.38ms +step:108/1645 train_time:9869ms step_avg:91.38ms +step:109/1645 train_time:9961ms step_avg:91.39ms +step:110/1645 train_time:10052ms step_avg:91.39ms +step:111/1645 train_time:10144ms step_avg:91.39ms +step:112/1645 train_time:10237ms step_avg:91.40ms +step:113/1645 train_time:10330ms step_avg:91.41ms +step:114/1645 train_time:10422ms step_avg:91.42ms +step:115/1645 train_time:10513ms step_avg:91.42ms +step:116/1645 train_time:10604ms step_avg:91.41ms +step:117/1645 train_time:10696ms step_avg:91.42ms +step:118/1645 train_time:10789ms step_avg:91.43ms +step:119/1645 train_time:10881ms step_avg:91.43ms +step:120/1645 train_time:10972ms step_avg:91.44ms +step:121/1645 train_time:11063ms step_avg:91.43ms +step:122/1645 train_time:11154ms step_avg:91.43ms +step:123/1645 train_time:11246ms step_avg:91.43ms +step:124/1645 train_time:11339ms step_avg:91.44ms +step:125/1645 train_time:11431ms step_avg:91.45ms +step:125/1645 val_loss:4.3213 train_time:11524ms step_avg:92.19ms +step:126/1645 train_time:11545ms step_avg:91.63ms +step:127/1645 train_time:11620ms step_avg:91.50ms +step:128/1645 train_time:11722ms step_avg:91.58ms +step:129/1645 train_time:11820ms step_avg:91.63ms +step:130/1645 train_time:11913ms step_avg:91.64ms +step:131/1645 train_time:12004ms step_avg:91.63ms +step:132/1645 train_time:12094ms step_avg:91.62ms +step:133/1645 train_time:12185ms step_avg:91.62ms +step:134/1645 train_time:12275ms step_avg:91.61ms +step:135/1645 train_time:12366ms step_avg:91.60ms +step:136/1645 train_time:12457ms step_avg:91.59ms +step:137/1645 train_time:12548ms step_avg:91.59ms +step:138/1645 train_time:12641ms step_avg:91.60ms +step:139/1645 train_time:12735ms step_avg:91.62ms +step:140/1645 train_time:12828ms step_avg:91.63ms +step:141/1645 train_time:12921ms step_avg:91.64ms +step:142/1645 train_time:13013ms step_avg:91.64ms +step:143/1645 train_time:13104ms step_avg:91.63ms +step:144/1645 train_time:13195ms step_avg:91.63ms +step:145/1645 train_time:13286ms step_avg:91.62ms +step:146/1645 train_time:13376ms step_avg:91.62ms +step:147/1645 train_time:13467ms step_avg:91.61ms +step:148/1645 train_time:13558ms step_avg:91.61ms +step:149/1645 train_time:13650ms step_avg:91.61ms +step:150/1645 train_time:13743ms step_avg:91.62ms +step:151/1645 train_time:13835ms step_avg:91.62ms +step:152/1645 train_time:13928ms step_avg:91.63ms +step:153/1645 train_time:14020ms step_avg:91.63ms +step:154/1645 train_time:14113ms step_avg:91.64ms +step:155/1645 train_time:14205ms step_avg:91.65ms +step:156/1645 train_time:14296ms step_avg:91.64ms +step:157/1645 train_time:14387ms step_avg:91.63ms +step:158/1645 train_time:14478ms step_avg:91.63ms +step:159/1645 train_time:14569ms step_avg:91.63ms +step:160/1645 train_time:14661ms step_avg:91.63ms +step:161/1645 train_time:14754ms step_avg:91.64ms +step:162/1645 train_time:14847ms step_avg:91.65ms +step:163/1645 train_time:14938ms step_avg:91.65ms +step:164/1645 
train_time:15031ms step_avg:91.65ms +step:165/1645 train_time:15123ms step_avg:91.65ms +step:166/1645 train_time:15215ms step_avg:91.65ms +step:167/1645 train_time:15306ms step_avg:91.65ms +step:168/1645 train_time:15397ms step_avg:91.65ms +step:169/1645 train_time:15488ms step_avg:91.64ms +step:170/1645 train_time:15579ms step_avg:91.64ms +step:171/1645 train_time:15670ms step_avg:91.64ms +step:172/1645 train_time:15763ms step_avg:91.64ms +step:173/1645 train_time:15855ms step_avg:91.64ms +step:174/1645 train_time:15947ms step_avg:91.65ms +step:175/1645 train_time:16039ms step_avg:91.65ms +step:176/1645 train_time:16131ms step_avg:91.65ms +step:177/1645 train_time:16222ms step_avg:91.65ms +step:178/1645 train_time:16313ms step_avg:91.65ms +step:179/1645 train_time:16404ms step_avg:91.64ms +step:180/1645 train_time:16495ms step_avg:91.64ms +step:181/1645 train_time:16586ms step_avg:91.64ms +step:182/1645 train_time:16679ms step_avg:91.64ms +step:183/1645 train_time:16771ms step_avg:91.64ms +step:184/1645 train_time:16863ms step_avg:91.65ms +step:185/1645 train_time:16956ms step_avg:91.65ms +step:186/1645 train_time:17048ms step_avg:91.66ms +step:187/1645 train_time:17139ms step_avg:91.65ms +step:188/1645 train_time:17230ms step_avg:91.65ms +step:189/1645 train_time:17321ms step_avg:91.65ms +step:190/1645 train_time:17413ms step_avg:91.64ms +step:191/1645 train_time:17503ms step_avg:91.64ms +step:192/1645 train_time:17595ms step_avg:91.64ms +step:193/1645 train_time:17686ms step_avg:91.64ms +step:194/1645 train_time:17779ms step_avg:91.65ms +step:195/1645 train_time:17871ms step_avg:91.65ms +step:196/1645 train_time:17964ms step_avg:91.65ms +step:197/1645 train_time:18057ms step_avg:91.66ms +step:198/1645 train_time:18148ms step_avg:91.66ms +step:199/1645 train_time:18240ms step_avg:91.66ms +step:200/1645 train_time:18332ms step_avg:91.66ms +step:201/1645 train_time:18423ms step_avg:91.66ms +step:202/1645 train_time:18514ms step_avg:91.65ms +step:203/1645 train_time:18605ms step_avg:91.65ms +step:204/1645 train_time:18696ms step_avg:91.65ms +step:205/1645 train_time:18787ms step_avg:91.65ms +step:206/1645 train_time:18880ms step_avg:91.65ms +step:207/1645 train_time:18972ms step_avg:91.65ms +step:208/1645 train_time:19066ms step_avg:91.66ms +step:209/1645 train_time:19156ms step_avg:91.66ms +step:210/1645 train_time:19248ms step_avg:91.66ms +step:211/1645 train_time:19340ms step_avg:91.66ms +step:212/1645 train_time:19431ms step_avg:91.65ms +step:213/1645 train_time:19523ms step_avg:91.66ms +step:214/1645 train_time:19614ms step_avg:91.65ms +step:215/1645 train_time:19705ms step_avg:91.65ms +step:216/1645 train_time:19797ms step_avg:91.65ms +step:217/1645 train_time:19888ms step_avg:91.65ms +step:218/1645 train_time:19980ms step_avg:91.65ms +step:219/1645 train_time:20073ms step_avg:91.66ms +step:220/1645 train_time:20165ms step_avg:91.66ms +step:221/1645 train_time:20257ms step_avg:91.66ms +step:222/1645 train_time:20348ms step_avg:91.66ms +step:223/1645 train_time:20440ms step_avg:91.66ms +step:224/1645 train_time:20532ms step_avg:91.66ms +step:225/1645 train_time:20622ms step_avg:91.65ms +step:226/1645 train_time:20713ms step_avg:91.65ms +step:227/1645 train_time:20804ms step_avg:91.65ms +step:228/1645 train_time:20896ms step_avg:91.65ms +step:229/1645 train_time:20988ms step_avg:91.65ms +step:230/1645 train_time:21083ms step_avg:91.67ms +step:231/1645 train_time:21177ms step_avg:91.67ms +step:232/1645 train_time:21267ms step_avg:91.67ms +step:233/1645 train_time:21358ms step_avg:91.67ms 
+step:234/1645 train_time:21450ms step_avg:91.67ms +step:235/1645 train_time:21542ms step_avg:91.67ms +step:236/1645 train_time:21633ms step_avg:91.66ms +step:237/1645 train_time:21724ms step_avg:91.66ms +step:238/1645 train_time:21815ms step_avg:91.66ms +step:239/1645 train_time:21906ms step_avg:91.66ms +step:240/1645 train_time:21998ms step_avg:91.66ms +step:241/1645 train_time:22089ms step_avg:91.66ms +step:242/1645 train_time:22183ms step_avg:91.67ms +step:243/1645 train_time:22276ms step_avg:91.67ms +step:244/1645 train_time:22367ms step_avg:91.67ms +step:245/1645 train_time:22460ms step_avg:91.67ms +step:246/1645 train_time:22551ms step_avg:91.67ms +step:247/1645 train_time:22642ms step_avg:91.67ms +step:248/1645 train_time:22734ms step_avg:91.67ms +step:249/1645 train_time:22824ms step_avg:91.66ms +step:250/1645 train_time:22916ms step_avg:91.66ms +step:250/1645 val_loss:3.9739 train_time:23007ms step_avg:92.03ms +step:251/1645 train_time:23027ms step_avg:91.74ms +step:252/1645 train_time:23103ms step_avg:91.68ms +step:253/1645 train_time:23197ms step_avg:91.69ms +step:254/1645 train_time:23290ms step_avg:91.69ms +step:255/1645 train_time:23380ms step_avg:91.69ms +step:256/1645 train_time:23471ms step_avg:91.68ms +step:257/1645 train_time:23561ms step_avg:91.68ms +step:258/1645 train_time:23652ms step_avg:91.67ms +step:259/1645 train_time:23743ms step_avg:91.67ms +step:260/1645 train_time:23834ms step_avg:91.67ms +step:261/1645 train_time:23927ms step_avg:91.67ms +step:262/1645 train_time:24021ms step_avg:91.68ms +step:263/1645 train_time:24115ms step_avg:91.69ms +step:264/1645 train_time:24207ms step_avg:91.69ms +step:265/1645 train_time:24299ms step_avg:91.69ms +step:266/1645 train_time:24390ms step_avg:91.69ms +step:267/1645 train_time:24481ms step_avg:91.69ms +step:268/1645 train_time:24572ms step_avg:91.69ms +step:269/1645 train_time:24663ms step_avg:91.68ms +step:270/1645 train_time:24754ms step_avg:91.68ms +step:271/1645 train_time:24847ms step_avg:91.69ms +step:272/1645 train_time:24939ms step_avg:91.69ms +step:273/1645 train_time:25031ms step_avg:91.69ms +step:274/1645 train_time:25124ms step_avg:91.69ms +step:275/1645 train_time:25217ms step_avg:91.70ms +step:276/1645 train_time:25309ms step_avg:91.70ms +step:277/1645 train_time:25400ms step_avg:91.70ms +step:278/1645 train_time:25491ms step_avg:91.70ms +step:279/1645 train_time:25582ms step_avg:91.69ms +step:280/1645 train_time:25673ms step_avg:91.69ms +step:281/1645 train_time:25763ms step_avg:91.69ms +step:282/1645 train_time:25855ms step_avg:91.68ms +step:283/1645 train_time:25947ms step_avg:91.68ms +step:284/1645 train_time:26039ms step_avg:91.69ms +step:285/1645 train_time:26132ms step_avg:91.69ms +step:286/1645 train_time:26225ms step_avg:91.70ms +step:287/1645 train_time:26317ms step_avg:91.70ms +step:288/1645 train_time:26409ms step_avg:91.70ms +step:289/1645 train_time:26500ms step_avg:91.70ms +step:290/1645 train_time:26591ms step_avg:91.69ms +step:291/1645 train_time:26682ms step_avg:91.69ms +step:292/1645 train_time:26773ms step_avg:91.69ms +step:293/1645 train_time:26864ms step_avg:91.69ms +step:294/1645 train_time:26957ms step_avg:91.69ms +step:295/1645 train_time:27048ms step_avg:91.69ms +step:296/1645 train_time:27141ms step_avg:91.69ms +step:297/1645 train_time:27234ms step_avg:91.70ms +step:298/1645 train_time:27326ms step_avg:91.70ms +step:299/1645 train_time:27417ms step_avg:91.70ms +step:300/1645 train_time:27509ms step_avg:91.70ms +step:301/1645 train_time:27600ms step_avg:91.69ms +step:302/1645 
train_time:27692ms step_avg:91.69ms +step:303/1645 train_time:27783ms step_avg:91.69ms +step:304/1645 train_time:27874ms step_avg:91.69ms +step:305/1645 train_time:27965ms step_avg:91.69ms +step:306/1645 train_time:28057ms step_avg:91.69ms +step:307/1645 train_time:28149ms step_avg:91.69ms +step:308/1645 train_time:28241ms step_avg:91.69ms +step:309/1645 train_time:28334ms step_avg:91.69ms +step:310/1645 train_time:28425ms step_avg:91.69ms +step:311/1645 train_time:28518ms step_avg:91.70ms +step:312/1645 train_time:28610ms step_avg:91.70ms +step:313/1645 train_time:28701ms step_avg:91.69ms +step:314/1645 train_time:28792ms step_avg:91.69ms +step:315/1645 train_time:28883ms step_avg:91.69ms +step:316/1645 train_time:28974ms step_avg:91.69ms +step:317/1645 train_time:29067ms step_avg:91.69ms +step:318/1645 train_time:29159ms step_avg:91.69ms +step:319/1645 train_time:29250ms step_avg:91.69ms +step:320/1645 train_time:29342ms step_avg:91.69ms +step:321/1645 train_time:29434ms step_avg:91.69ms +step:322/1645 train_time:29526ms step_avg:91.70ms +step:323/1645 train_time:29618ms step_avg:91.70ms +step:324/1645 train_time:29709ms step_avg:91.70ms +step:325/1645 train_time:29800ms step_avg:91.69ms +step:326/1645 train_time:29891ms step_avg:91.69ms +step:327/1645 train_time:29982ms step_avg:91.69ms +step:328/1645 train_time:30075ms step_avg:91.69ms +step:329/1645 train_time:30166ms step_avg:91.69ms +step:330/1645 train_time:30258ms step_avg:91.69ms +step:331/1645 train_time:30350ms step_avg:91.69ms +step:332/1645 train_time:30441ms step_avg:91.69ms +step:333/1645 train_time:30535ms step_avg:91.70ms +step:334/1645 train_time:30628ms step_avg:91.70ms +step:335/1645 train_time:30719ms step_avg:91.70ms +step:336/1645 train_time:30811ms step_avg:91.70ms +step:337/1645 train_time:30901ms step_avg:91.70ms +step:338/1645 train_time:30993ms step_avg:91.70ms +step:339/1645 train_time:31084ms step_avg:91.69ms +step:340/1645 train_time:31176ms step_avg:91.69ms +step:341/1645 train_time:31267ms step_avg:91.69ms +step:342/1645 train_time:31359ms step_avg:91.69ms +step:343/1645 train_time:31450ms step_avg:91.69ms +step:344/1645 train_time:31542ms step_avg:91.69ms +step:345/1645 train_time:31636ms step_avg:91.70ms +step:346/1645 train_time:31728ms step_avg:91.70ms +step:347/1645 train_time:31820ms step_avg:91.70ms +step:348/1645 train_time:31911ms step_avg:91.70ms +step:349/1645 train_time:32002ms step_avg:91.70ms +step:350/1645 train_time:32094ms step_avg:91.70ms +step:351/1645 train_time:32185ms step_avg:91.70ms +step:352/1645 train_time:32277ms step_avg:91.70ms +step:353/1645 train_time:32369ms step_avg:91.70ms +step:354/1645 train_time:32460ms step_avg:91.70ms +step:355/1645 train_time:32553ms step_avg:91.70ms +step:356/1645 train_time:32645ms step_avg:91.70ms +step:357/1645 train_time:32736ms step_avg:91.70ms +step:358/1645 train_time:32829ms step_avg:91.70ms +step:359/1645 train_time:32920ms step_avg:91.70ms +step:360/1645 train_time:33012ms step_avg:91.70ms +step:361/1645 train_time:33104ms step_avg:91.70ms +step:362/1645 train_time:33195ms step_avg:91.70ms +step:363/1645 train_time:33287ms step_avg:91.70ms +step:364/1645 train_time:33377ms step_avg:91.70ms +step:365/1645 train_time:33469ms step_avg:91.70ms +step:366/1645 train_time:33561ms step_avg:91.70ms +step:367/1645 train_time:33653ms step_avg:91.70ms +step:368/1645 train_time:33744ms step_avg:91.69ms +step:369/1645 train_time:33837ms step_avg:91.70ms +step:370/1645 train_time:33929ms step_avg:91.70ms +step:371/1645 train_time:34021ms step_avg:91.70ms 
+step:372/1645 train_time:34113ms step_avg:91.70ms +step:373/1645 train_time:34204ms step_avg:91.70ms +step:374/1645 train_time:34297ms step_avg:91.70ms +step:375/1645 train_time:34389ms step_avg:91.70ms +step:375/1645 val_loss:3.8159 train_time:34480ms step_avg:91.95ms +step:376/1645 train_time:34500ms step_avg:91.76ms +step:377/1645 train_time:34575ms step_avg:91.71ms +step:378/1645 train_time:34669ms step_avg:91.72ms +step:379/1645 train_time:34761ms step_avg:91.72ms +step:380/1645 train_time:34852ms step_avg:91.72ms +step:381/1645 train_time:34944ms step_avg:91.72ms +step:382/1645 train_time:35034ms step_avg:91.71ms +step:383/1645 train_time:35125ms step_avg:91.71ms +step:384/1645 train_time:35216ms step_avg:91.71ms +step:385/1645 train_time:35306ms step_avg:91.71ms +step:386/1645 train_time:35398ms step_avg:91.70ms +step:387/1645 train_time:35491ms step_avg:91.71ms +step:388/1645 train_time:35585ms step_avg:91.71ms +step:389/1645 train_time:35678ms step_avg:91.72ms +step:390/1645 train_time:35769ms step_avg:91.72ms +step:391/1645 train_time:35861ms step_avg:91.72ms +step:392/1645 train_time:35953ms step_avg:91.72ms +step:393/1645 train_time:36044ms step_avg:91.71ms +step:394/1645 train_time:36134ms step_avg:91.71ms +step:395/1645 train_time:36225ms step_avg:91.71ms +step:396/1645 train_time:36316ms step_avg:91.71ms +step:397/1645 train_time:36408ms step_avg:91.71ms +step:398/1645 train_time:36499ms step_avg:91.71ms +step:399/1645 train_time:36593ms step_avg:91.71ms +step:400/1645 train_time:36686ms step_avg:91.71ms +step:401/1645 train_time:36777ms step_avg:91.71ms +step:402/1645 train_time:36870ms step_avg:91.72ms +step:403/1645 train_time:36962ms step_avg:91.72ms +step:404/1645 train_time:37054ms step_avg:91.72ms +step:405/1645 train_time:37145ms step_avg:91.72ms +step:406/1645 train_time:37236ms step_avg:91.71ms +step:407/1645 train_time:37327ms step_avg:91.71ms +step:408/1645 train_time:37419ms step_avg:91.71ms +step:409/1645 train_time:37511ms step_avg:91.71ms +step:410/1645 train_time:37603ms step_avg:91.71ms +step:411/1645 train_time:37695ms step_avg:91.72ms +step:412/1645 train_time:37787ms step_avg:91.72ms +step:413/1645 train_time:37879ms step_avg:91.72ms +step:414/1645 train_time:37971ms step_avg:91.72ms +step:415/1645 train_time:38063ms step_avg:91.72ms +step:416/1645 train_time:38154ms step_avg:91.72ms +step:417/1645 train_time:38245ms step_avg:91.71ms +step:418/1645 train_time:38336ms step_avg:91.71ms +step:419/1645 train_time:38428ms step_avg:91.71ms +step:420/1645 train_time:38519ms step_avg:91.71ms +step:421/1645 train_time:38611ms step_avg:91.71ms +step:422/1645 train_time:38703ms step_avg:91.71ms +step:423/1645 train_time:38795ms step_avg:91.71ms +step:424/1645 train_time:38887ms step_avg:91.72ms +step:425/1645 train_time:38980ms step_avg:91.72ms +step:426/1645 train_time:39072ms step_avg:91.72ms +step:427/1645 train_time:39165ms step_avg:91.72ms +step:428/1645 train_time:39256ms step_avg:91.72ms +step:429/1645 train_time:39347ms step_avg:91.72ms +step:430/1645 train_time:39438ms step_avg:91.72ms +step:431/1645 train_time:39529ms step_avg:91.72ms +step:432/1645 train_time:39621ms step_avg:91.71ms +step:433/1645 train_time:39713ms step_avg:91.72ms +step:434/1645 train_time:39805ms step_avg:91.72ms +step:435/1645 train_time:39896ms step_avg:91.71ms +step:436/1645 train_time:39989ms step_avg:91.72ms +step:437/1645 train_time:40081ms step_avg:91.72ms +step:438/1645 train_time:40174ms step_avg:91.72ms +step:439/1645 train_time:40266ms step_avg:91.72ms +step:440/1645 
train_time:40356ms step_avg:91.72ms +step:441/1645 train_time:40448ms step_avg:91.72ms +step:442/1645 train_time:40539ms step_avg:91.72ms +step:443/1645 train_time:40631ms step_avg:91.72ms +step:444/1645 train_time:40722ms step_avg:91.72ms +step:445/1645 train_time:40814ms step_avg:91.72ms +step:446/1645 train_time:40907ms step_avg:91.72ms +step:447/1645 train_time:40998ms step_avg:91.72ms +step:448/1645 train_time:41091ms step_avg:91.72ms +step:449/1645 train_time:41183ms step_avg:91.72ms +step:450/1645 train_time:41274ms step_avg:91.72ms +step:451/1645 train_time:41365ms step_avg:91.72ms +step:452/1645 train_time:41456ms step_avg:91.72ms +step:453/1645 train_time:41548ms step_avg:91.72ms +step:454/1645 train_time:41639ms step_avg:91.72ms +step:455/1645 train_time:41731ms step_avg:91.72ms +step:456/1645 train_time:41822ms step_avg:91.71ms +step:457/1645 train_time:41913ms step_avg:91.71ms +step:458/1645 train_time:42005ms step_avg:91.71ms +step:459/1645 train_time:42097ms step_avg:91.71ms +step:460/1645 train_time:42191ms step_avg:91.72ms +step:461/1645 train_time:42282ms step_avg:91.72ms +step:462/1645 train_time:42373ms step_avg:91.72ms +step:463/1645 train_time:42465ms step_avg:91.72ms +step:464/1645 train_time:42556ms step_avg:91.72ms +step:465/1645 train_time:42649ms step_avg:91.72ms +step:466/1645 train_time:42741ms step_avg:91.72ms +step:467/1645 train_time:42832ms step_avg:91.72ms +step:468/1645 train_time:42923ms step_avg:91.72ms +step:469/1645 train_time:43015ms step_avg:91.72ms +step:470/1645 train_time:43106ms step_avg:91.72ms +step:471/1645 train_time:43198ms step_avg:91.72ms +step:472/1645 train_time:43291ms step_avg:91.72ms +step:473/1645 train_time:43384ms step_avg:91.72ms +step:474/1645 train_time:43476ms step_avg:91.72ms +step:475/1645 train_time:43567ms step_avg:91.72ms +step:476/1645 train_time:43659ms step_avg:91.72ms +step:477/1645 train_time:43750ms step_avg:91.72ms +step:478/1645 train_time:43842ms step_avg:91.72ms +step:479/1645 train_time:43933ms step_avg:91.72ms +step:480/1645 train_time:44024ms step_avg:91.72ms +step:481/1645 train_time:44117ms step_avg:91.72ms +step:482/1645 train_time:44208ms step_avg:91.72ms +step:483/1645 train_time:44299ms step_avg:91.72ms +step:484/1645 train_time:44392ms step_avg:91.72ms +step:485/1645 train_time:44485ms step_avg:91.72ms +step:486/1645 train_time:44576ms step_avg:91.72ms +step:487/1645 train_time:44667ms step_avg:91.72ms +step:488/1645 train_time:44759ms step_avg:91.72ms +step:489/1645 train_time:44851ms step_avg:91.72ms +step:490/1645 train_time:44943ms step_avg:91.72ms +step:491/1645 train_time:45034ms step_avg:91.72ms +step:492/1645 train_time:45125ms step_avg:91.72ms +step:493/1645 train_time:45217ms step_avg:91.72ms +step:494/1645 train_time:45309ms step_avg:91.72ms +step:495/1645 train_time:45400ms step_avg:91.72ms +step:496/1645 train_time:45493ms step_avg:91.72ms +step:497/1645 train_time:45585ms step_avg:91.72ms +step:498/1645 train_time:45676ms step_avg:91.72ms +step:499/1645 train_time:45769ms step_avg:91.72ms +step:500/1645 train_time:45862ms step_avg:91.72ms +step:500/1645 val_loss:3.7151 train_time:45954ms step_avg:91.91ms +step:501/1645 train_time:45974ms step_avg:91.76ms +step:502/1645 train_time:46049ms step_avg:91.73ms +step:503/1645 train_time:46143ms step_avg:91.74ms +step:504/1645 train_time:46234ms step_avg:91.73ms +step:505/1645 train_time:46325ms step_avg:91.73ms +step:506/1645 train_time:46416ms step_avg:91.73ms +step:507/1645 train_time:46506ms step_avg:91.73ms +step:508/1645 train_time:46597ms 
step_avg:91.73ms +step:509/1645 train_time:46687ms step_avg:91.72ms +step:510/1645 train_time:46778ms step_avg:91.72ms +step:511/1645 train_time:46869ms step_avg:91.72ms +step:512/1645 train_time:46962ms step_avg:91.72ms +step:513/1645 train_time:47055ms step_avg:91.73ms +step:514/1645 train_time:47148ms step_avg:91.73ms +step:515/1645 train_time:47241ms step_avg:91.73ms +step:516/1645 train_time:47332ms step_avg:91.73ms +step:517/1645 train_time:47425ms step_avg:91.73ms +step:518/1645 train_time:47515ms step_avg:91.73ms +step:519/1645 train_time:47606ms step_avg:91.73ms +step:520/1645 train_time:47697ms step_avg:91.73ms +step:521/1645 train_time:47788ms step_avg:91.72ms +step:522/1645 train_time:47880ms step_avg:91.72ms +step:523/1645 train_time:47972ms step_avg:91.72ms +step:524/1645 train_time:48064ms step_avg:91.72ms +step:525/1645 train_time:48157ms step_avg:91.73ms +step:526/1645 train_time:48249ms step_avg:91.73ms +step:527/1645 train_time:48341ms step_avg:91.73ms +step:528/1645 train_time:48433ms step_avg:91.73ms +step:529/1645 train_time:48524ms step_avg:91.73ms +step:530/1645 train_time:48616ms step_avg:91.73ms +step:531/1645 train_time:48706ms step_avg:91.73ms +step:532/1645 train_time:48798ms step_avg:91.72ms +step:533/1645 train_time:48889ms step_avg:91.72ms +step:534/1645 train_time:48980ms step_avg:91.72ms +step:535/1645 train_time:49071ms step_avg:91.72ms +step:536/1645 train_time:49163ms step_avg:91.72ms +step:537/1645 train_time:49255ms step_avg:91.72ms +step:538/1645 train_time:49347ms step_avg:91.72ms +step:539/1645 train_time:49438ms step_avg:91.72ms +step:540/1645 train_time:49530ms step_avg:91.72ms +step:541/1645 train_time:49621ms step_avg:91.72ms +step:542/1645 train_time:49713ms step_avg:91.72ms +step:543/1645 train_time:49805ms step_avg:91.72ms +step:544/1645 train_time:49896ms step_avg:91.72ms +step:545/1645 train_time:49987ms step_avg:91.72ms +step:546/1645 train_time:50079ms step_avg:91.72ms +step:547/1645 train_time:50171ms step_avg:91.72ms +step:548/1645 train_time:50263ms step_avg:91.72ms +step:549/1645 train_time:50354ms step_avg:91.72ms +step:550/1645 train_time:50446ms step_avg:91.72ms +step:551/1645 train_time:50539ms step_avg:91.72ms +step:552/1645 train_time:50631ms step_avg:91.72ms +step:553/1645 train_time:50725ms step_avg:91.73ms +step:554/1645 train_time:50818ms step_avg:91.73ms +step:555/1645 train_time:50910ms step_avg:91.73ms +step:556/1645 train_time:51004ms step_avg:91.73ms +step:557/1645 train_time:51097ms step_avg:91.74ms +step:558/1645 train_time:51190ms step_avg:91.74ms +step:559/1645 train_time:51282ms step_avg:91.74ms +step:560/1645 train_time:51375ms step_avg:91.74ms +step:561/1645 train_time:51468ms step_avg:91.74ms +step:562/1645 train_time:51561ms step_avg:91.74ms +step:563/1645 train_time:51653ms step_avg:91.75ms +step:564/1645 train_time:51747ms step_avg:91.75ms +step:565/1645 train_time:51840ms step_avg:91.75ms +step:566/1645 train_time:51932ms step_avg:91.75ms +step:567/1645 train_time:52025ms step_avg:91.76ms +step:568/1645 train_time:52118ms step_avg:91.76ms +step:569/1645 train_time:52210ms step_avg:91.76ms +step:570/1645 train_time:52303ms step_avg:91.76ms +step:571/1645 train_time:52396ms step_avg:91.76ms +step:572/1645 train_time:52489ms step_avg:91.76ms +step:573/1645 train_time:52582ms step_avg:91.77ms +step:574/1645 train_time:52675ms step_avg:91.77ms +step:575/1645 train_time:52768ms step_avg:91.77ms +step:576/1645 train_time:52861ms step_avg:91.77ms +step:577/1645 train_time:52954ms step_avg:91.77ms +step:578/1645 
train_time:53048ms step_avg:91.78ms +step:579/1645 train_time:53140ms step_avg:91.78ms +step:580/1645 train_time:53233ms step_avg:91.78ms +step:581/1645 train_time:53326ms step_avg:91.78ms +step:582/1645 train_time:53419ms step_avg:91.78ms +step:583/1645 train_time:53511ms step_avg:91.79ms +step:584/1645 train_time:53604ms step_avg:91.79ms +step:585/1645 train_time:53697ms step_avg:91.79ms +step:586/1645 train_time:53790ms step_avg:91.79ms +step:587/1645 train_time:53882ms step_avg:91.79ms +step:588/1645 train_time:53975ms step_avg:91.79ms +step:589/1645 train_time:54068ms step_avg:91.80ms +step:590/1645 train_time:54161ms step_avg:91.80ms +step:591/1645 train_time:54253ms step_avg:91.80ms +step:592/1645 train_time:54347ms step_avg:91.80ms +step:593/1645 train_time:54440ms step_avg:91.80ms +step:594/1645 train_time:54533ms step_avg:91.81ms +step:595/1645 train_time:54626ms step_avg:91.81ms +step:596/1645 train_time:54719ms step_avg:91.81ms +step:597/1645 train_time:54812ms step_avg:91.81ms +step:598/1645 train_time:54905ms step_avg:91.81ms +step:599/1645 train_time:54998ms step_avg:91.82ms +step:600/1645 train_time:55090ms step_avg:91.82ms +step:601/1645 train_time:55184ms step_avg:91.82ms +step:602/1645 train_time:55276ms step_avg:91.82ms +step:603/1645 train_time:55368ms step_avg:91.82ms +step:604/1645 train_time:55462ms step_avg:91.82ms +step:605/1645 train_time:55554ms step_avg:91.83ms +step:606/1645 train_time:55648ms step_avg:91.83ms +step:607/1645 train_time:55742ms step_avg:91.83ms +step:608/1645 train_time:55835ms step_avg:91.83ms +step:609/1645 train_time:55928ms step_avg:91.84ms +step:610/1645 train_time:56022ms step_avg:91.84ms +step:611/1645 train_time:56114ms step_avg:91.84ms +step:612/1645 train_time:56207ms step_avg:91.84ms +step:613/1645 train_time:56299ms step_avg:91.84ms +step:614/1645 train_time:56393ms step_avg:91.85ms +step:615/1645 train_time:56485ms step_avg:91.85ms +step:616/1645 train_time:56577ms step_avg:91.85ms +step:617/1645 train_time:56670ms step_avg:91.85ms +step:618/1645 train_time:56763ms step_avg:91.85ms +step:619/1645 train_time:56856ms step_avg:91.85ms +step:620/1645 train_time:56949ms step_avg:91.85ms +step:621/1645 train_time:57042ms step_avg:91.85ms +step:622/1645 train_time:57135ms step_avg:91.86ms +step:623/1645 train_time:57227ms step_avg:91.86ms +step:624/1645 train_time:57320ms step_avg:91.86ms +step:625/1645 train_time:57412ms step_avg:91.86ms +step:625/1645 val_loss:3.6132 train_time:57505ms step_avg:92.01ms +step:626/1645 train_time:57526ms step_avg:91.89ms +step:627/1645 train_time:57604ms step_avg:91.87ms +step:628/1645 train_time:57706ms step_avg:91.89ms +step:629/1645 train_time:57798ms step_avg:91.89ms +step:630/1645 train_time:57890ms step_avg:91.89ms +step:631/1645 train_time:57982ms step_avg:91.89ms +step:632/1645 train_time:58073ms step_avg:91.89ms +step:633/1645 train_time:58165ms step_avg:91.89ms +step:634/1645 train_time:58257ms step_avg:91.89ms +step:635/1645 train_time:58349ms step_avg:91.89ms +step:636/1645 train_time:58441ms step_avg:91.89ms +step:637/1645 train_time:58537ms step_avg:91.89ms +step:638/1645 train_time:58632ms step_avg:91.90ms +step:639/1645 train_time:58727ms step_avg:91.91ms +step:640/1645 train_time:58821ms step_avg:91.91ms +step:641/1645 train_time:58914ms step_avg:91.91ms +step:642/1645 train_time:59006ms step_avg:91.91ms +step:643/1645 train_time:59099ms step_avg:91.91ms +step:644/1645 train_time:59192ms step_avg:91.91ms +step:645/1645 train_time:59284ms step_avg:91.91ms +step:646/1645 train_time:59376ms 
step_avg:91.91ms +step:647/1645 train_time:59469ms step_avg:91.92ms +step:648/1645 train_time:59564ms step_avg:91.92ms +step:649/1645 train_time:59658ms step_avg:91.92ms +step:650/1645 train_time:59750ms step_avg:91.92ms +step:651/1645 train_time:59844ms step_avg:91.93ms +step:652/1645 train_time:59937ms step_avg:91.93ms +step:653/1645 train_time:60029ms step_avg:91.93ms +step:654/1645 train_time:60121ms step_avg:91.93ms +step:655/1645 train_time:60213ms step_avg:91.93ms +step:656/1645 train_time:60307ms step_avg:91.93ms +step:657/1645 train_time:60399ms step_avg:91.93ms +step:658/1645 train_time:60492ms step_avg:91.93ms +step:659/1645 train_time:60586ms step_avg:91.94ms +step:660/1645 train_time:60680ms step_avg:91.94ms +step:661/1645 train_time:60774ms step_avg:91.94ms +step:662/1645 train_time:60866ms step_avg:91.94ms +step:663/1645 train_time:60959ms step_avg:91.94ms +step:664/1645 train_time:61052ms step_avg:91.95ms +step:665/1645 train_time:61144ms step_avg:91.95ms +step:666/1645 train_time:61237ms step_avg:91.95ms +step:667/1645 train_time:61330ms step_avg:91.95ms +step:668/1645 train_time:61423ms step_avg:91.95ms +step:669/1645 train_time:61515ms step_avg:91.95ms +step:670/1645 train_time:61609ms step_avg:91.95ms +step:671/1645 train_time:61702ms step_avg:91.96ms +step:672/1645 train_time:61795ms step_avg:91.96ms +step:673/1645 train_time:61888ms step_avg:91.96ms +step:674/1645 train_time:61982ms step_avg:91.96ms +step:675/1645 train_time:62075ms step_avg:91.96ms +step:676/1645 train_time:62167ms step_avg:91.96ms +step:677/1645 train_time:62260ms step_avg:91.96ms +step:678/1645 train_time:62353ms step_avg:91.97ms +step:679/1645 train_time:62445ms step_avg:91.97ms +step:680/1645 train_time:62537ms step_avg:91.97ms +step:681/1645 train_time:62630ms step_avg:91.97ms +step:682/1645 train_time:62723ms step_avg:91.97ms +step:683/1645 train_time:62817ms step_avg:91.97ms +step:684/1645 train_time:62909ms step_avg:91.97ms +step:685/1645 train_time:63003ms step_avg:91.98ms +step:686/1645 train_time:63096ms step_avg:91.98ms +step:687/1645 train_time:63189ms step_avg:91.98ms +step:688/1645 train_time:63281ms step_avg:91.98ms +step:689/1645 train_time:63374ms step_avg:91.98ms +step:690/1645 train_time:63467ms step_avg:91.98ms +step:691/1645 train_time:63559ms step_avg:91.98ms +step:692/1645 train_time:63651ms step_avg:91.98ms +step:693/1645 train_time:63744ms step_avg:91.98ms +step:694/1645 train_time:63837ms step_avg:91.98ms +step:695/1645 train_time:63930ms step_avg:91.99ms +step:696/1645 train_time:64024ms step_avg:91.99ms +step:697/1645 train_time:64117ms step_avg:91.99ms +step:698/1645 train_time:64210ms step_avg:91.99ms +step:699/1645 train_time:64303ms step_avg:91.99ms +step:700/1645 train_time:64397ms step_avg:92.00ms +step:701/1645 train_time:64489ms step_avg:92.00ms +step:702/1645 train_time:64583ms step_avg:92.00ms +step:703/1645 train_time:64675ms step_avg:92.00ms +step:704/1645 train_time:64767ms step_avg:92.00ms +step:705/1645 train_time:64861ms step_avg:92.00ms +step:706/1645 train_time:64953ms step_avg:92.00ms +step:707/1645 train_time:65047ms step_avg:92.00ms +step:708/1645 train_time:65139ms step_avg:92.00ms +step:709/1645 train_time:65232ms step_avg:92.01ms +step:710/1645 train_time:65325ms step_avg:92.01ms +step:711/1645 train_time:65418ms step_avg:92.01ms +step:712/1645 train_time:65511ms step_avg:92.01ms +step:713/1645 train_time:65605ms step_avg:92.01ms +step:714/1645 train_time:65698ms step_avg:92.01ms +step:715/1645 train_time:65791ms step_avg:92.02ms +step:716/1645 
train_time:65884ms step_avg:92.02ms +step:717/1645 train_time:65977ms step_avg:92.02ms +step:718/1645 train_time:66070ms step_avg:92.02ms +step:719/1645 train_time:66164ms step_avg:92.02ms +step:720/1645 train_time:66257ms step_avg:92.02ms +step:721/1645 train_time:66349ms step_avg:92.02ms +step:722/1645 train_time:66442ms step_avg:92.03ms +step:723/1645 train_time:66535ms step_avg:92.03ms +step:724/1645 train_time:66627ms step_avg:92.03ms +step:725/1645 train_time:66720ms step_avg:92.03ms +step:726/1645 train_time:66812ms step_avg:92.03ms +step:727/1645 train_time:66905ms step_avg:92.03ms +step:728/1645 train_time:67000ms step_avg:92.03ms +step:729/1645 train_time:67092ms step_avg:92.03ms +step:730/1645 train_time:67186ms step_avg:92.04ms +step:731/1645 train_time:67279ms step_avg:92.04ms +step:732/1645 train_time:67371ms step_avg:92.04ms +step:733/1645 train_time:67465ms step_avg:92.04ms +step:734/1645 train_time:67558ms step_avg:92.04ms +step:735/1645 train_time:67651ms step_avg:92.04ms +step:736/1645 train_time:67743ms step_avg:92.04ms +step:737/1645 train_time:67836ms step_avg:92.04ms +step:738/1645 train_time:67928ms step_avg:92.04ms +step:739/1645 train_time:68021ms step_avg:92.05ms +step:740/1645 train_time:68114ms step_avg:92.05ms +step:741/1645 train_time:68209ms step_avg:92.05ms +step:742/1645 train_time:68302ms step_avg:92.05ms +step:743/1645 train_time:68394ms step_avg:92.05ms +step:744/1645 train_time:68488ms step_avg:92.05ms +step:745/1645 train_time:68580ms step_avg:92.05ms +step:746/1645 train_time:68673ms step_avg:92.06ms +step:747/1645 train_time:68765ms step_avg:92.06ms +step:748/1645 train_time:68859ms step_avg:92.06ms +step:749/1645 train_time:68951ms step_avg:92.06ms +step:750/1645 train_time:69044ms step_avg:92.06ms +step:750/1645 val_loss:3.5603 train_time:69137ms step_avg:92.18ms +step:751/1645 train_time:69158ms step_avg:92.09ms +step:752/1645 train_time:69234ms step_avg:92.07ms +step:753/1645 train_time:69329ms step_avg:92.07ms +step:754/1645 train_time:69422ms step_avg:92.07ms +step:755/1645 train_time:69513ms step_avg:92.07ms +step:756/1645 train_time:69606ms step_avg:92.07ms +step:757/1645 train_time:69698ms step_avg:92.07ms +step:758/1645 train_time:69789ms step_avg:92.07ms +step:759/1645 train_time:69881ms step_avg:92.07ms +step:760/1645 train_time:69974ms step_avg:92.07ms +step:761/1645 train_time:70067ms step_avg:92.07ms +step:762/1645 train_time:70161ms step_avg:92.08ms +step:763/1645 train_time:70256ms step_avg:92.08ms +step:764/1645 train_time:70350ms step_avg:92.08ms +step:765/1645 train_time:70444ms step_avg:92.08ms +step:766/1645 train_time:70536ms step_avg:92.08ms +step:767/1645 train_time:70628ms step_avg:92.08ms +step:768/1645 train_time:70720ms step_avg:92.08ms +step:769/1645 train_time:70812ms step_avg:92.08ms +step:770/1645 train_time:70905ms step_avg:92.08ms +step:771/1645 train_time:70997ms step_avg:92.08ms +step:772/1645 train_time:71090ms step_avg:92.09ms +step:773/1645 train_time:71184ms step_avg:92.09ms +step:774/1645 train_time:71277ms step_avg:92.09ms +step:775/1645 train_time:71371ms step_avg:92.09ms +step:776/1645 train_time:71465ms step_avg:92.09ms +step:777/1645 train_time:71557ms step_avg:92.09ms +step:778/1645 train_time:71650ms step_avg:92.09ms +step:779/1645 train_time:71743ms step_avg:92.10ms +step:780/1645 train_time:71835ms step_avg:92.10ms +step:781/1645 train_time:71928ms step_avg:92.10ms +step:782/1645 train_time:72020ms step_avg:92.10ms +step:783/1645 train_time:72114ms step_avg:92.10ms +step:784/1645 train_time:72207ms 
step_avg:92.10ms +step:785/1645 train_time:72300ms step_avg:92.10ms +step:786/1645 train_time:72393ms step_avg:92.10ms +step:787/1645 train_time:72486ms step_avg:92.10ms +step:788/1645 train_time:72578ms step_avg:92.10ms +step:789/1645 train_time:72672ms step_avg:92.11ms +step:790/1645 train_time:72766ms step_avg:92.11ms +step:791/1645 train_time:72858ms step_avg:92.11ms +step:792/1645 train_time:72951ms step_avg:92.11ms +step:793/1645 train_time:73044ms step_avg:92.11ms +step:794/1645 train_time:73138ms step_avg:92.11ms +step:795/1645 train_time:73231ms step_avg:92.11ms +step:796/1645 train_time:73325ms step_avg:92.12ms +step:797/1645 train_time:73418ms step_avg:92.12ms +step:798/1645 train_time:73511ms step_avg:92.12ms +step:799/1645 train_time:73604ms step_avg:92.12ms +step:800/1645 train_time:73696ms step_avg:92.12ms +step:801/1645 train_time:73789ms step_avg:92.12ms +step:802/1645 train_time:73882ms step_avg:92.12ms +step:803/1645 train_time:73975ms step_avg:92.12ms +step:804/1645 train_time:74069ms step_avg:92.13ms +step:805/1645 train_time:74162ms step_avg:92.13ms +step:806/1645 train_time:74253ms step_avg:92.13ms +step:807/1645 train_time:74347ms step_avg:92.13ms +step:808/1645 train_time:74440ms step_avg:92.13ms +step:809/1645 train_time:74533ms step_avg:92.13ms +step:810/1645 train_time:74627ms step_avg:92.13ms +step:811/1645 train_time:74719ms step_avg:92.13ms +step:812/1645 train_time:74812ms step_avg:92.13ms +step:813/1645 train_time:74904ms step_avg:92.13ms +step:814/1645 train_time:74997ms step_avg:92.13ms +step:815/1645 train_time:75090ms step_avg:92.13ms +step:816/1645 train_time:75183ms step_avg:92.14ms +step:817/1645 train_time:75275ms step_avg:92.14ms +step:818/1645 train_time:75369ms step_avg:92.14ms +step:819/1645 train_time:75462ms step_avg:92.14ms +step:820/1645 train_time:75554ms step_avg:92.14ms +step:821/1645 train_time:75648ms step_avg:92.14ms +step:822/1645 train_time:75742ms step_avg:92.14ms +step:823/1645 train_time:75836ms step_avg:92.15ms +step:824/1645 train_time:75928ms step_avg:92.15ms +step:825/1645 train_time:76021ms step_avg:92.15ms +step:826/1645 train_time:76113ms step_avg:92.15ms +step:827/1645 train_time:76207ms step_avg:92.15ms +step:828/1645 train_time:76300ms step_avg:92.15ms +step:829/1645 train_time:76393ms step_avg:92.15ms +step:830/1645 train_time:76485ms step_avg:92.15ms +step:831/1645 train_time:76577ms step_avg:92.15ms +step:832/1645 train_time:76671ms step_avg:92.15ms +step:833/1645 train_time:76764ms step_avg:92.15ms +step:834/1645 train_time:76857ms step_avg:92.16ms +step:835/1645 train_time:76950ms step_avg:92.16ms +step:836/1645 train_time:77043ms step_avg:92.16ms +step:837/1645 train_time:77136ms step_avg:92.16ms +step:838/1645 train_time:77229ms step_avg:92.16ms +step:839/1645 train_time:77322ms step_avg:92.16ms +step:840/1645 train_time:77415ms step_avg:92.16ms +step:841/1645 train_time:77509ms step_avg:92.16ms +step:842/1645 train_time:77602ms step_avg:92.16ms +step:843/1645 train_time:77694ms step_avg:92.16ms +step:844/1645 train_time:77789ms step_avg:92.17ms +step:845/1645 train_time:77881ms step_avg:92.17ms +step:846/1645 train_time:77974ms step_avg:92.17ms +step:847/1645 train_time:78066ms step_avg:92.17ms +step:848/1645 train_time:78159ms step_avg:92.17ms +step:849/1645 train_time:78252ms step_avg:92.17ms +step:850/1645 train_time:78347ms step_avg:92.17ms +step:851/1645 train_time:78441ms step_avg:92.17ms +step:852/1645 train_time:78534ms step_avg:92.18ms +step:853/1645 train_time:78626ms step_avg:92.18ms +step:854/1645 
train_time:78719ms step_avg:92.18ms +step:855/1645 train_time:78811ms step_avg:92.18ms +step:856/1645 train_time:78905ms step_avg:92.18ms +step:857/1645 train_time:78998ms step_avg:92.18ms +step:858/1645 train_time:79090ms step_avg:92.18ms +step:859/1645 train_time:79184ms step_avg:92.18ms +step:860/1645 train_time:79276ms step_avg:92.18ms +step:861/1645 train_time:79369ms step_avg:92.18ms +step:862/1645 train_time:79463ms step_avg:92.18ms +step:863/1645 train_time:79555ms step_avg:92.18ms +step:864/1645 train_time:79648ms step_avg:92.19ms +step:865/1645 train_time:79742ms step_avg:92.19ms +step:866/1645 train_time:79835ms step_avg:92.19ms +step:867/1645 train_time:79928ms step_avg:92.19ms +step:868/1645 train_time:80021ms step_avg:92.19ms +step:869/1645 train_time:80113ms step_avg:92.19ms +step:870/1645 train_time:80207ms step_avg:92.19ms +step:871/1645 train_time:80300ms step_avg:92.19ms +step:872/1645 train_time:80393ms step_avg:92.19ms +step:873/1645 train_time:80486ms step_avg:92.19ms +step:874/1645 train_time:80579ms step_avg:92.20ms +step:875/1645 train_time:80672ms step_avg:92.20ms +step:875/1645 val_loss:3.5135 train_time:80766ms step_avg:92.30ms +step:876/1645 train_time:80792ms step_avg:92.23ms +step:877/1645 train_time:80869ms step_avg:92.21ms +step:878/1645 train_time:80964ms step_avg:92.21ms +step:879/1645 train_time:81057ms step_avg:92.22ms +step:880/1645 train_time:81149ms step_avg:92.21ms +step:881/1645 train_time:81241ms step_avg:92.22ms +step:882/1645 train_time:81333ms step_avg:92.21ms +step:883/1645 train_time:81424ms step_avg:92.21ms +step:884/1645 train_time:81516ms step_avg:92.21ms +step:885/1645 train_time:81608ms step_avg:92.21ms +step:886/1645 train_time:81701ms step_avg:92.21ms +step:887/1645 train_time:81796ms step_avg:92.22ms +step:888/1645 train_time:81891ms step_avg:92.22ms +step:889/1645 train_time:81986ms step_avg:92.22ms +step:890/1645 train_time:82080ms step_avg:92.22ms +step:891/1645 train_time:82172ms step_avg:92.22ms +step:892/1645 train_time:82265ms step_avg:92.23ms +step:893/1645 train_time:82357ms step_avg:92.23ms +step:894/1645 train_time:82450ms step_avg:92.23ms +step:895/1645 train_time:82542ms step_avg:92.23ms +step:896/1645 train_time:82634ms step_avg:92.22ms +step:897/1645 train_time:82727ms step_avg:92.23ms +step:898/1645 train_time:82821ms step_avg:92.23ms +step:899/1645 train_time:82915ms step_avg:92.23ms +step:900/1645 train_time:83010ms step_avg:92.23ms +step:901/1645 train_time:83105ms step_avg:92.24ms +step:902/1645 train_time:83197ms step_avg:92.24ms +step:903/1645 train_time:83290ms step_avg:92.24ms +step:904/1645 train_time:83383ms step_avg:92.24ms +step:905/1645 train_time:83476ms step_avg:92.24ms +step:906/1645 train_time:83569ms step_avg:92.24ms +step:907/1645 train_time:83661ms step_avg:92.24ms +step:908/1645 train_time:83753ms step_avg:92.24ms +step:909/1645 train_time:83847ms step_avg:92.24ms +step:910/1645 train_time:83941ms step_avg:92.24ms +step:911/1645 train_time:84035ms step_avg:92.24ms +step:912/1645 train_time:84127ms step_avg:92.24ms +step:913/1645 train_time:84219ms step_avg:92.24ms +step:914/1645 train_time:84312ms step_avg:92.24ms +step:915/1645 train_time:84405ms step_avg:92.25ms +step:916/1645 train_time:84498ms step_avg:92.25ms +step:917/1645 train_time:84590ms step_avg:92.25ms +step:918/1645 train_time:84683ms step_avg:92.25ms +step:919/1645 train_time:84776ms step_avg:92.25ms +step:920/1645 train_time:84870ms step_avg:92.25ms +step:921/1645 train_time:84964ms step_avg:92.25ms +step:922/1645 train_time:85059ms 
step_avg:92.25ms +step:923/1645 train_time:85150ms step_avg:92.25ms +step:924/1645 train_time:85243ms step_avg:92.25ms +step:925/1645 train_time:85336ms step_avg:92.26ms +step:926/1645 train_time:85429ms step_avg:92.26ms +step:927/1645 train_time:85521ms step_avg:92.26ms +step:928/1645 train_time:85614ms step_avg:92.26ms +step:929/1645 train_time:85707ms step_avg:92.26ms +step:930/1645 train_time:85800ms step_avg:92.26ms +step:931/1645 train_time:85892ms step_avg:92.26ms +step:932/1645 train_time:85987ms step_avg:92.26ms +step:933/1645 train_time:86081ms step_avg:92.26ms +step:934/1645 train_time:86174ms step_avg:92.26ms +step:935/1645 train_time:86267ms step_avg:92.26ms +step:936/1645 train_time:86360ms step_avg:92.27ms +step:937/1645 train_time:86453ms step_avg:92.27ms +step:938/1645 train_time:86546ms step_avg:92.27ms +step:939/1645 train_time:86639ms step_avg:92.27ms +step:940/1645 train_time:86731ms step_avg:92.27ms +step:941/1645 train_time:86825ms step_avg:92.27ms +step:942/1645 train_time:86917ms step_avg:92.27ms +step:943/1645 train_time:87010ms step_avg:92.27ms +step:944/1645 train_time:87104ms step_avg:92.27ms +step:945/1645 train_time:87197ms step_avg:92.27ms +step:946/1645 train_time:87290ms step_avg:92.27ms +step:947/1645 train_time:87384ms step_avg:92.27ms +step:948/1645 train_time:87477ms step_avg:92.28ms +step:949/1645 train_time:87569ms step_avg:92.28ms +step:950/1645 train_time:87662ms step_avg:92.28ms +step:951/1645 train_time:87755ms step_avg:92.28ms +step:952/1645 train_time:87848ms step_avg:92.28ms +step:953/1645 train_time:87941ms step_avg:92.28ms +step:954/1645 train_time:88034ms step_avg:92.28ms +step:955/1645 train_time:88127ms step_avg:92.28ms +step:956/1645 train_time:88220ms step_avg:92.28ms +step:957/1645 train_time:88312ms step_avg:92.28ms +step:958/1645 train_time:88406ms step_avg:92.28ms +step:959/1645 train_time:88499ms step_avg:92.28ms +step:960/1645 train_time:88593ms step_avg:92.28ms +step:961/1645 train_time:88685ms step_avg:92.28ms +step:962/1645 train_time:88779ms step_avg:92.29ms +step:963/1645 train_time:88871ms step_avg:92.29ms +step:964/1645 train_time:88964ms step_avg:92.29ms +step:965/1645 train_time:89057ms step_avg:92.29ms +step:966/1645 train_time:89150ms step_avg:92.29ms +step:967/1645 train_time:89243ms step_avg:92.29ms +step:968/1645 train_time:89337ms step_avg:92.29ms +step:969/1645 train_time:89429ms step_avg:92.29ms +step:970/1645 train_time:89522ms step_avg:92.29ms +step:971/1645 train_time:89615ms step_avg:92.29ms +step:972/1645 train_time:89708ms step_avg:92.29ms +step:973/1645 train_time:89802ms step_avg:92.29ms +step:974/1645 train_time:89894ms step_avg:92.29ms +step:975/1645 train_time:89987ms step_avg:92.29ms +step:976/1645 train_time:90081ms step_avg:92.30ms +step:977/1645 train_time:90173ms step_avg:92.30ms +step:978/1645 train_time:90266ms step_avg:92.30ms +step:979/1645 train_time:90359ms step_avg:92.30ms +step:980/1645 train_time:90452ms step_avg:92.30ms +step:981/1645 train_time:90545ms step_avg:92.30ms +step:982/1645 train_time:90638ms step_avg:92.30ms +step:983/1645 train_time:90731ms step_avg:92.30ms +step:984/1645 train_time:90824ms step_avg:92.30ms +step:985/1645 train_time:90917ms step_avg:92.30ms +step:986/1645 train_time:91009ms step_avg:92.30ms +step:987/1645 train_time:91103ms step_avg:92.30ms +step:988/1645 train_time:91196ms step_avg:92.30ms +step:989/1645 train_time:91290ms step_avg:92.31ms +step:990/1645 train_time:91384ms step_avg:92.31ms +step:991/1645 train_time:91476ms step_avg:92.31ms +step:992/1645 
train_time:91569ms step_avg:92.31ms +step:993/1645 train_time:91661ms step_avg:92.31ms +step:994/1645 train_time:91754ms step_avg:92.31ms +step:995/1645 train_time:91847ms step_avg:92.31ms +step:996/1645 train_time:91940ms step_avg:92.31ms +step:997/1645 train_time:92032ms step_avg:92.31ms +step:998/1645 train_time:92125ms step_avg:92.31ms +step:999/1645 train_time:92218ms step_avg:92.31ms +step:1000/1645 train_time:92311ms step_avg:92.31ms +step:1000/1645 val_loss:3.4640 train_time:92404ms step_avg:92.40ms +step:1001/1645 train_time:92430ms step_avg:92.34ms +step:1002/1645 train_time:92504ms step_avg:92.32ms +step:1003/1645 train_time:92600ms step_avg:92.32ms +step:1004/1645 train_time:92693ms step_avg:92.32ms +step:1005/1645 train_time:92784ms step_avg:92.32ms +step:1006/1645 train_time:92876ms step_avg:92.32ms +step:1007/1645 train_time:92968ms step_avg:92.32ms +step:1008/1645 train_time:93059ms step_avg:92.32ms +step:1009/1645 train_time:93151ms step_avg:92.32ms +step:1010/1645 train_time:93244ms step_avg:92.32ms +step:1011/1645 train_time:93336ms step_avg:92.32ms +step:1012/1645 train_time:93432ms step_avg:92.32ms +step:1013/1645 train_time:93526ms step_avg:92.33ms +step:1014/1645 train_time:93621ms step_avg:92.33ms +step:1015/1645 train_time:93714ms step_avg:92.33ms +step:1016/1645 train_time:93806ms step_avg:92.33ms +step:1017/1645 train_time:93899ms step_avg:92.33ms +step:1018/1645 train_time:93991ms step_avg:92.33ms +step:1019/1645 train_time:94083ms step_avg:92.33ms +step:1020/1645 train_time:94175ms step_avg:92.33ms +step:1021/1645 train_time:94267ms step_avg:92.33ms +step:1022/1645 train_time:94360ms step_avg:92.33ms +step:1023/1645 train_time:94454ms step_avg:92.33ms +step:1024/1645 train_time:94547ms step_avg:92.33ms +step:1025/1645 train_time:94640ms step_avg:92.33ms +step:1026/1645 train_time:94733ms step_avg:92.33ms +step:1027/1645 train_time:94826ms step_avg:92.33ms +step:1028/1645 train_time:94918ms step_avg:92.33ms +step:1029/1645 train_time:95010ms step_avg:92.33ms +step:1030/1645 train_time:95103ms step_avg:92.33ms +step:1031/1645 train_time:95196ms step_avg:92.33ms +step:1032/1645 train_time:95288ms step_avg:92.33ms +step:1033/1645 train_time:95381ms step_avg:92.33ms +step:1034/1645 train_time:95474ms step_avg:92.33ms +step:1035/1645 train_time:95567ms step_avg:92.34ms +step:1036/1645 train_time:95661ms step_avg:92.34ms +step:1037/1645 train_time:95753ms step_avg:92.34ms +step:1038/1645 train_time:95846ms step_avg:92.34ms +step:1039/1645 train_time:95939ms step_avg:92.34ms +step:1040/1645 train_time:96032ms step_avg:92.34ms +step:1041/1645 train_time:96124ms step_avg:92.34ms +step:1042/1645 train_time:96217ms step_avg:92.34ms +step:1043/1645 train_time:96310ms step_avg:92.34ms +step:1044/1645 train_time:96403ms step_avg:92.34ms +step:1045/1645 train_time:96496ms step_avg:92.34ms +step:1046/1645 train_time:96589ms step_avg:92.34ms +step:1047/1645 train_time:96682ms step_avg:92.34ms +step:1048/1645 train_time:96775ms step_avg:92.34ms +step:1049/1645 train_time:96868ms step_avg:92.34ms +step:1050/1645 train_time:96961ms step_avg:92.34ms +step:1051/1645 train_time:97053ms step_avg:92.34ms +step:1052/1645 train_time:97146ms step_avg:92.34ms +step:1053/1645 train_time:97238ms step_avg:92.34ms +step:1054/1645 train_time:97331ms step_avg:92.34ms +step:1055/1645 train_time:97424ms step_avg:92.35ms +step:1056/1645 train_time:97518ms step_avg:92.35ms +step:1057/1645 train_time:97611ms step_avg:92.35ms +step:1058/1645 train_time:97704ms step_avg:92.35ms +step:1059/1645 
train_time:97796ms step_avg:92.35ms +step:1060/1645 train_time:97889ms step_avg:92.35ms +step:1061/1645 train_time:97983ms step_avg:92.35ms +step:1062/1645 train_time:98076ms step_avg:92.35ms +step:1063/1645 train_time:98168ms step_avg:92.35ms +step:1064/1645 train_time:98260ms step_avg:92.35ms +step:1065/1645 train_time:98354ms step_avg:92.35ms +step:1066/1645 train_time:98446ms step_avg:92.35ms +step:1067/1645 train_time:98539ms step_avg:92.35ms +step:1068/1645 train_time:98632ms step_avg:92.35ms +step:1069/1645 train_time:98726ms step_avg:92.35ms +step:1070/1645 train_time:98820ms step_avg:92.35ms +step:1071/1645 train_time:98913ms step_avg:92.36ms +step:1072/1645 train_time:99006ms step_avg:92.36ms +step:1073/1645 train_time:99098ms step_avg:92.36ms +step:1074/1645 train_time:99191ms step_avg:92.36ms +step:1075/1645 train_time:99284ms step_avg:92.36ms +step:1076/1645 train_time:99376ms step_avg:92.36ms +step:1077/1645 train_time:99469ms step_avg:92.36ms +step:1078/1645 train_time:99561ms step_avg:92.36ms +step:1079/1645 train_time:99654ms step_avg:92.36ms +step:1080/1645 train_time:99747ms step_avg:92.36ms +step:1081/1645 train_time:99840ms step_avg:92.36ms +step:1082/1645 train_time:99933ms step_avg:92.36ms +step:1083/1645 train_time:100026ms step_avg:92.36ms +step:1084/1645 train_time:100120ms step_avg:92.36ms +step:1085/1645 train_time:100212ms step_avg:92.36ms +step:1086/1645 train_time:100304ms step_avg:92.36ms +step:1087/1645 train_time:100397ms step_avg:92.36ms +step:1088/1645 train_time:100490ms step_avg:92.36ms +step:1089/1645 train_time:100583ms step_avg:92.36ms +step:1090/1645 train_time:100676ms step_avg:92.36ms +step:1091/1645 train_time:100769ms step_avg:92.36ms +step:1092/1645 train_time:100864ms step_avg:92.37ms +step:1093/1645 train_time:100956ms step_avg:92.37ms +step:1094/1645 train_time:101048ms step_avg:92.37ms +step:1095/1645 train_time:101141ms step_avg:92.37ms +step:1096/1645 train_time:101233ms step_avg:92.37ms +step:1097/1645 train_time:101327ms step_avg:92.37ms +step:1098/1645 train_time:101420ms step_avg:92.37ms +step:1099/1645 train_time:101512ms step_avg:92.37ms +step:1100/1645 train_time:101606ms step_avg:92.37ms +step:1101/1645 train_time:101700ms step_avg:92.37ms +step:1102/1645 train_time:101793ms step_avg:92.37ms +step:1103/1645 train_time:101887ms step_avg:92.37ms +step:1104/1645 train_time:101980ms step_avg:92.37ms +step:1105/1645 train_time:102074ms step_avg:92.37ms +step:1106/1645 train_time:102167ms step_avg:92.38ms +step:1107/1645 train_time:102260ms step_avg:92.38ms +step:1108/1645 train_time:102354ms step_avg:92.38ms +step:1109/1645 train_time:102447ms step_avg:92.38ms +step:1110/1645 train_time:102541ms step_avg:92.38ms +step:1111/1645 train_time:102633ms step_avg:92.38ms +step:1112/1645 train_time:102728ms step_avg:92.38ms +step:1113/1645 train_time:102823ms step_avg:92.38ms +step:1114/1645 train_time:102916ms step_avg:92.38ms +step:1115/1645 train_time:103009ms step_avg:92.38ms +step:1116/1645 train_time:103103ms step_avg:92.39ms +step:1117/1645 train_time:103196ms step_avg:92.39ms +step:1118/1645 train_time:103289ms step_avg:92.39ms +step:1119/1645 train_time:103382ms step_avg:92.39ms +step:1120/1645 train_time:103476ms step_avg:92.39ms +step:1121/1645 train_time:103569ms step_avg:92.39ms +step:1122/1645 train_time:103663ms step_avg:92.39ms +step:1123/1645 train_time:103756ms step_avg:92.39ms +step:1124/1645 train_time:103850ms step_avg:92.39ms +step:1125/1645 train_time:103944ms step_avg:92.39ms +step:1125/1645 val_loss:3.4105 
train_time:104037ms step_avg:92.48ms +step:1126/1645 train_time:104063ms step_avg:92.42ms +step:1127/1645 train_time:104142ms step_avg:92.41ms +step:1128/1645 train_time:104241ms step_avg:92.41ms +step:1129/1645 train_time:104336ms step_avg:92.41ms +step:1130/1645 train_time:104429ms step_avg:92.42ms +step:1131/1645 train_time:104521ms step_avg:92.41ms +step:1132/1645 train_time:104614ms step_avg:92.42ms +step:1133/1645 train_time:104706ms step_avg:92.42ms +step:1134/1645 train_time:104798ms step_avg:92.41ms +step:1135/1645 train_time:104891ms step_avg:92.41ms +step:1136/1645 train_time:104984ms step_avg:92.42ms +step:1137/1645 train_time:105078ms step_avg:92.42ms +step:1138/1645 train_time:105176ms step_avg:92.42ms +step:1139/1645 train_time:105272ms step_avg:92.42ms +step:1140/1645 train_time:105366ms step_avg:92.43ms +step:1141/1645 train_time:105459ms step_avg:92.43ms +step:1142/1645 train_time:105551ms step_avg:92.43ms +step:1143/1645 train_time:105645ms step_avg:92.43ms +step:1144/1645 train_time:105737ms step_avg:92.43ms +step:1145/1645 train_time:105830ms step_avg:92.43ms +step:1146/1645 train_time:105923ms step_avg:92.43ms +step:1147/1645 train_time:106016ms step_avg:92.43ms +step:1148/1645 train_time:106111ms step_avg:92.43ms +step:1149/1645 train_time:106207ms step_avg:92.43ms +step:1150/1645 train_time:106301ms step_avg:92.44ms +step:1151/1645 train_time:106395ms step_avg:92.44ms +step:1152/1645 train_time:106489ms step_avg:92.44ms +step:1153/1645 train_time:106581ms step_avg:92.44ms +step:1154/1645 train_time:106675ms step_avg:92.44ms +step:1155/1645 train_time:106768ms step_avg:92.44ms +step:1156/1645 train_time:106861ms step_avg:92.44ms +step:1157/1645 train_time:106953ms step_avg:92.44ms +step:1158/1645 train_time:107047ms step_avg:92.44ms +step:1159/1645 train_time:107141ms step_avg:92.44ms +step:1160/1645 train_time:107235ms step_avg:92.44ms +step:1161/1645 train_time:107330ms step_avg:92.45ms +step:1162/1645 train_time:107423ms step_avg:92.45ms +step:1163/1645 train_time:107516ms step_avg:92.45ms +step:1164/1645 train_time:107610ms step_avg:92.45ms +step:1165/1645 train_time:107703ms step_avg:92.45ms +step:1166/1645 train_time:107796ms step_avg:92.45ms +step:1167/1645 train_time:107889ms step_avg:92.45ms +step:1168/1645 train_time:107982ms step_avg:92.45ms +step:1169/1645 train_time:108076ms step_avg:92.45ms +step:1170/1645 train_time:108170ms step_avg:92.45ms +step:1171/1645 train_time:108265ms step_avg:92.45ms +step:1172/1645 train_time:108359ms step_avg:92.46ms +step:1173/1645 train_time:108453ms step_avg:92.46ms +step:1174/1645 train_time:108547ms step_avg:92.46ms +step:1175/1645 train_time:108639ms step_avg:92.46ms +step:1176/1645 train_time:108733ms step_avg:92.46ms +step:1177/1645 train_time:108826ms step_avg:92.46ms +step:1178/1645 train_time:108919ms step_avg:92.46ms +step:1179/1645 train_time:109013ms step_avg:92.46ms +step:1180/1645 train_time:109108ms step_avg:92.46ms +step:1181/1645 train_time:109202ms step_avg:92.47ms +step:1182/1645 train_time:109295ms step_avg:92.47ms +step:1183/1645 train_time:109389ms step_avg:92.47ms +step:1184/1645 train_time:109483ms step_avg:92.47ms +step:1185/1645 train_time:109576ms step_avg:92.47ms +step:1186/1645 train_time:109670ms step_avg:92.47ms +step:1187/1645 train_time:109763ms step_avg:92.47ms +step:1188/1645 train_time:109856ms step_avg:92.47ms +step:1189/1645 train_time:109950ms step_avg:92.47ms +step:1190/1645 train_time:110044ms step_avg:92.47ms +step:1191/1645 train_time:110138ms step_avg:92.47ms +step:1192/1645 
train_time:110233ms step_avg:92.48ms +step:1193/1645 train_time:110326ms step_avg:92.48ms +step:1194/1645 train_time:110420ms step_avg:92.48ms +step:1195/1645 train_time:110513ms step_avg:92.48ms +step:1196/1645 train_time:110606ms step_avg:92.48ms +step:1197/1645 train_time:110700ms step_avg:92.48ms +step:1198/1645 train_time:110793ms step_avg:92.48ms +step:1199/1645 train_time:110887ms step_avg:92.48ms +step:1200/1645 train_time:110980ms step_avg:92.48ms +step:1201/1645 train_time:111073ms step_avg:92.48ms +step:1202/1645 train_time:111167ms step_avg:92.49ms +step:1203/1645 train_time:111261ms step_avg:92.49ms +step:1204/1645 train_time:111354ms step_avg:92.49ms +step:1205/1645 train_time:111448ms step_avg:92.49ms +step:1206/1645 train_time:111542ms step_avg:92.49ms +step:1207/1645 train_time:111635ms step_avg:92.49ms +step:1208/1645 train_time:111729ms step_avg:92.49ms +step:1209/1645 train_time:111822ms step_avg:92.49ms +step:1210/1645 train_time:111916ms step_avg:92.49ms +step:1211/1645 train_time:112010ms step_avg:92.49ms +step:1212/1645 train_time:112103ms step_avg:92.49ms +step:1213/1645 train_time:112197ms step_avg:92.50ms +step:1214/1645 train_time:112291ms step_avg:92.50ms +step:1215/1645 train_time:112385ms step_avg:92.50ms +step:1216/1645 train_time:112478ms step_avg:92.50ms +step:1217/1645 train_time:112573ms step_avg:92.50ms +step:1218/1645 train_time:112667ms step_avg:92.50ms +step:1219/1645 train_time:112761ms step_avg:92.50ms +step:1220/1645 train_time:112853ms step_avg:92.50ms +step:1221/1645 train_time:112946ms step_avg:92.50ms +step:1222/1645 train_time:113039ms step_avg:92.50ms +step:1223/1645 train_time:113133ms step_avg:92.50ms +step:1224/1645 train_time:113227ms step_avg:92.51ms +step:1225/1645 train_time:113321ms step_avg:92.51ms +step:1226/1645 train_time:113415ms step_avg:92.51ms +step:1227/1645 train_time:113508ms step_avg:92.51ms +step:1228/1645 train_time:113601ms step_avg:92.51ms +step:1229/1645 train_time:113695ms step_avg:92.51ms +step:1230/1645 train_time:113789ms step_avg:92.51ms +step:1231/1645 train_time:113882ms step_avg:92.51ms +step:1232/1645 train_time:113975ms step_avg:92.51ms +step:1233/1645 train_time:114070ms step_avg:92.51ms +step:1234/1645 train_time:114163ms step_avg:92.51ms +step:1235/1645 train_time:114257ms step_avg:92.52ms +step:1236/1645 train_time:114350ms step_avg:92.52ms +step:1237/1645 train_time:114444ms step_avg:92.52ms +step:1238/1645 train_time:114538ms step_avg:92.52ms +step:1239/1645 train_time:114632ms step_avg:92.52ms +step:1240/1645 train_time:114726ms step_avg:92.52ms +step:1241/1645 train_time:114820ms step_avg:92.52ms +step:1242/1645 train_time:114913ms step_avg:92.52ms +step:1243/1645 train_time:115006ms step_avg:92.52ms +step:1244/1645 train_time:115101ms step_avg:92.52ms +step:1245/1645 train_time:115194ms step_avg:92.53ms +step:1246/1645 train_time:115289ms step_avg:92.53ms +step:1247/1645 train_time:115382ms step_avg:92.53ms +step:1248/1645 train_time:115477ms step_avg:92.53ms +step:1249/1645 train_time:115570ms step_avg:92.53ms +step:1250/1645 train_time:115664ms step_avg:92.53ms +step:1250/1645 val_loss:3.3726 train_time:115758ms step_avg:92.61ms +step:1251/1645 train_time:115779ms step_avg:92.55ms +step:1252/1645 train_time:115859ms step_avg:92.54ms +step:1253/1645 train_time:115953ms step_avg:92.54ms +step:1254/1645 train_time:116046ms step_avg:92.54ms +step:1255/1645 train_time:116138ms step_avg:92.54ms +step:1256/1645 train_time:116230ms step_avg:92.54ms +step:1257/1645 train_time:116322ms step_avg:92.54ms 
+step:1258/1645 train_time:116414ms step_avg:92.54ms +step:1259/1645 train_time:116507ms step_avg:92.54ms +step:1260/1645 train_time:116602ms step_avg:92.54ms +step:1261/1645 train_time:116697ms step_avg:92.54ms +step:1262/1645 train_time:116794ms step_avg:92.55ms +step:1263/1645 train_time:116889ms step_avg:92.55ms +step:1264/1645 train_time:116985ms step_avg:92.55ms +step:1265/1645 train_time:117079ms step_avg:92.55ms +step:1266/1645 train_time:117171ms step_avg:92.55ms +step:1267/1645 train_time:117264ms step_avg:92.55ms +step:1268/1645 train_time:117357ms step_avg:92.55ms +step:1269/1645 train_time:117449ms step_avg:92.55ms +step:1270/1645 train_time:117542ms step_avg:92.55ms +step:1271/1645 train_time:117636ms step_avg:92.55ms +step:1272/1645 train_time:117730ms step_avg:92.56ms +step:1273/1645 train_time:117825ms step_avg:92.56ms +step:1274/1645 train_time:117920ms step_avg:92.56ms +step:1275/1645 train_time:118015ms step_avg:92.56ms +step:1276/1645 train_time:118108ms step_avg:92.56ms +step:1277/1645 train_time:118202ms step_avg:92.56ms +step:1278/1645 train_time:118295ms step_avg:92.56ms +step:1279/1645 train_time:118388ms step_avg:92.56ms +step:1280/1645 train_time:118481ms step_avg:92.56ms +step:1281/1645 train_time:118573ms step_avg:92.56ms +step:1282/1645 train_time:118667ms step_avg:92.56ms +step:1283/1645 train_time:118762ms step_avg:92.57ms +step:1284/1645 train_time:118856ms step_avg:92.57ms +step:1285/1645 train_time:118951ms step_avg:92.57ms +step:1286/1645 train_time:119045ms step_avg:92.57ms +step:1287/1645 train_time:119138ms step_avg:92.57ms +step:1288/1645 train_time:119232ms step_avg:92.57ms +step:1289/1645 train_time:119325ms step_avg:92.57ms +step:1290/1645 train_time:119419ms step_avg:92.57ms +step:1291/1645 train_time:119512ms step_avg:92.57ms +step:1292/1645 train_time:119607ms step_avg:92.57ms +step:1293/1645 train_time:119701ms step_avg:92.58ms +step:1294/1645 train_time:119797ms step_avg:92.58ms +step:1295/1645 train_time:119891ms step_avg:92.58ms +step:1296/1645 train_time:119985ms step_avg:92.58ms +step:1297/1645 train_time:120079ms step_avg:92.58ms +step:1298/1645 train_time:120172ms step_avg:92.58ms +step:1299/1645 train_time:120265ms step_avg:92.58ms +step:1300/1645 train_time:120359ms step_avg:92.58ms +step:1301/1645 train_time:120452ms step_avg:92.58ms +step:1302/1645 train_time:120544ms step_avg:92.58ms +step:1303/1645 train_time:120638ms step_avg:92.58ms +step:1304/1645 train_time:120732ms step_avg:92.59ms +step:1305/1645 train_time:120825ms step_avg:92.59ms +step:1306/1645 train_time:120920ms step_avg:92.59ms +step:1307/1645 train_time:121013ms step_avg:92.59ms +step:1308/1645 train_time:121106ms step_avg:92.59ms +step:1309/1645 train_time:121200ms step_avg:92.59ms +step:1310/1645 train_time:121293ms step_avg:92.59ms +step:1311/1645 train_time:121387ms step_avg:92.59ms +step:1312/1645 train_time:121480ms step_avg:92.59ms +step:1313/1645 train_time:121574ms step_avg:92.59ms +step:1314/1645 train_time:121667ms step_avg:92.59ms +step:1315/1645 train_time:121761ms step_avg:92.59ms +step:1316/1645 train_time:121854ms step_avg:92.59ms +step:1317/1645 train_time:121948ms step_avg:92.59ms +step:1318/1645 train_time:122041ms step_avg:92.60ms +step:1319/1645 train_time:122135ms step_avg:92.60ms +step:1320/1645 train_time:122229ms step_avg:92.60ms +step:1321/1645 train_time:122323ms step_avg:92.60ms +step:1322/1645 train_time:122417ms step_avg:92.60ms +step:1323/1645 train_time:122510ms step_avg:92.60ms +step:1324/1645 train_time:122604ms step_avg:92.60ms 
+step:1325/1645 train_time:122699ms step_avg:92.60ms +step:1326/1645 train_time:122793ms step_avg:92.60ms +step:1327/1645 train_time:122885ms step_avg:92.60ms +step:1328/1645 train_time:122979ms step_avg:92.60ms +step:1329/1645 train_time:123073ms step_avg:92.61ms +step:1330/1645 train_time:123167ms step_avg:92.61ms +step:1331/1645 train_time:123261ms step_avg:92.61ms +step:1332/1645 train_time:123356ms step_avg:92.61ms +step:1333/1645 train_time:123449ms step_avg:92.61ms +step:1334/1645 train_time:123542ms step_avg:92.61ms +step:1335/1645 train_time:123636ms step_avg:92.61ms +step:1336/1645 train_time:123729ms step_avg:92.61ms +step:1337/1645 train_time:123823ms step_avg:92.61ms +step:1338/1645 train_time:123917ms step_avg:92.61ms +step:1339/1645 train_time:124010ms step_avg:92.61ms +step:1340/1645 train_time:124105ms step_avg:92.62ms +step:1341/1645 train_time:124200ms step_avg:92.62ms +step:1342/1645 train_time:124295ms step_avg:92.62ms +step:1343/1645 train_time:124388ms step_avg:92.62ms +step:1344/1645 train_time:124482ms step_avg:92.62ms +step:1345/1645 train_time:124575ms step_avg:92.62ms +step:1346/1645 train_time:124668ms step_avg:92.62ms +step:1347/1645 train_time:124761ms step_avg:92.62ms +step:1348/1645 train_time:124855ms step_avg:92.62ms +step:1349/1645 train_time:124949ms step_avg:92.62ms +step:1350/1645 train_time:125042ms step_avg:92.62ms +step:1351/1645 train_time:125136ms step_avg:92.62ms +step:1352/1645 train_time:125230ms step_avg:92.63ms +step:1353/1645 train_time:125324ms step_avg:92.63ms +step:1354/1645 train_time:125418ms step_avg:92.63ms +step:1355/1645 train_time:125512ms step_avg:92.63ms +step:1356/1645 train_time:125605ms step_avg:92.63ms +step:1357/1645 train_time:125699ms step_avg:92.63ms +step:1358/1645 train_time:125793ms step_avg:92.63ms +step:1359/1645 train_time:125887ms step_avg:92.63ms +step:1360/1645 train_time:125980ms step_avg:92.63ms +step:1361/1645 train_time:126074ms step_avg:92.63ms +step:1362/1645 train_time:126167ms step_avg:92.63ms +step:1363/1645 train_time:126262ms step_avg:92.64ms +step:1364/1645 train_time:126356ms step_avg:92.64ms +step:1365/1645 train_time:126449ms step_avg:92.64ms +step:1366/1645 train_time:126542ms step_avg:92.64ms +step:1367/1645 train_time:126635ms step_avg:92.64ms +step:1368/1645 train_time:126729ms step_avg:92.64ms +step:1369/1645 train_time:126822ms step_avg:92.64ms +step:1370/1645 train_time:126916ms step_avg:92.64ms +step:1371/1645 train_time:127010ms step_avg:92.64ms +step:1372/1645 train_time:127104ms step_avg:92.64ms +step:1373/1645 train_time:127199ms step_avg:92.64ms +step:1374/1645 train_time:127292ms step_avg:92.64ms +step:1375/1645 train_time:127386ms step_avg:92.64ms +step:1375/1645 val_loss:3.3375 train_time:127479ms step_avg:92.71ms +step:1376/1645 train_time:127505ms step_avg:92.66ms +step:1377/1645 train_time:127579ms step_avg:92.65ms +step:1378/1645 train_time:127675ms step_avg:92.65ms +step:1379/1645 train_time:127768ms step_avg:92.65ms +step:1380/1645 train_time:127861ms step_avg:92.65ms +step:1381/1645 train_time:127954ms step_avg:92.65ms +step:1382/1645 train_time:128045ms step_avg:92.65ms +step:1383/1645 train_time:128138ms step_avg:92.65ms +step:1384/1645 train_time:128231ms step_avg:92.65ms +step:1385/1645 train_time:128325ms step_avg:92.65ms +step:1386/1645 train_time:128419ms step_avg:92.65ms +step:1387/1645 train_time:128515ms step_avg:92.66ms +step:1388/1645 train_time:128611ms step_avg:92.66ms +step:1389/1645 train_time:128706ms step_avg:92.66ms +step:1390/1645 train_time:128800ms 
step_avg:92.66ms +step:1391/1645 train_time:128894ms step_avg:92.66ms +step:1392/1645 train_time:128988ms step_avg:92.66ms +step:1393/1645 train_time:129080ms step_avg:92.66ms +step:1394/1645 train_time:129173ms step_avg:92.66ms +step:1395/1645 train_time:129267ms step_avg:92.66ms +step:1396/1645 train_time:129360ms step_avg:92.66ms +step:1397/1645 train_time:129456ms step_avg:92.67ms +step:1398/1645 train_time:129551ms step_avg:92.67ms +step:1399/1645 train_time:129645ms step_avg:92.67ms +step:1400/1645 train_time:129739ms step_avg:92.67ms +step:1401/1645 train_time:129833ms step_avg:92.67ms +step:1402/1645 train_time:129926ms step_avg:92.67ms +step:1403/1645 train_time:130020ms step_avg:92.67ms +step:1404/1645 train_time:130113ms step_avg:92.67ms +step:1405/1645 train_time:130205ms step_avg:92.67ms +step:1406/1645 train_time:130298ms step_avg:92.67ms +step:1407/1645 train_time:130392ms step_avg:92.67ms +step:1408/1645 train_time:130486ms step_avg:92.67ms +step:1409/1645 train_time:130580ms step_avg:92.68ms +step:1410/1645 train_time:130674ms step_avg:92.68ms +step:1411/1645 train_time:130768ms step_avg:92.68ms +step:1412/1645 train_time:130862ms step_avg:92.68ms +step:1413/1645 train_time:130955ms step_avg:92.68ms +step:1414/1645 train_time:131049ms step_avg:92.68ms +step:1415/1645 train_time:131142ms step_avg:92.68ms +step:1416/1645 train_time:131236ms step_avg:92.68ms +step:1417/1645 train_time:131329ms step_avg:92.68ms +step:1418/1645 train_time:131422ms step_avg:92.68ms +step:1419/1645 train_time:131516ms step_avg:92.68ms +step:1420/1645 train_time:131610ms step_avg:92.68ms +step:1421/1645 train_time:131703ms step_avg:92.68ms +step:1422/1645 train_time:131797ms step_avg:92.68ms +step:1423/1645 train_time:131891ms step_avg:92.69ms +step:1424/1645 train_time:131984ms step_avg:92.69ms +step:1425/1645 train_time:132077ms step_avg:92.69ms +step:1426/1645 train_time:132172ms step_avg:92.69ms +step:1427/1645 train_time:132266ms step_avg:92.69ms +step:1428/1645 train_time:132360ms step_avg:92.69ms +step:1429/1645 train_time:132454ms step_avg:92.69ms +step:1430/1645 train_time:132548ms step_avg:92.69ms +step:1431/1645 train_time:132641ms step_avg:92.69ms +step:1432/1645 train_time:132735ms step_avg:92.69ms +step:1433/1645 train_time:132829ms step_avg:92.69ms +step:1434/1645 train_time:132923ms step_avg:92.69ms +step:1435/1645 train_time:133018ms step_avg:92.70ms +step:1436/1645 train_time:133110ms step_avg:92.69ms +step:1437/1645 train_time:133204ms step_avg:92.70ms +step:1438/1645 train_time:133296ms step_avg:92.70ms +step:1439/1645 train_time:133391ms step_avg:92.70ms +step:1440/1645 train_time:133484ms step_avg:92.70ms +step:1441/1645 train_time:133577ms step_avg:92.70ms +step:1442/1645 train_time:133671ms step_avg:92.70ms +step:1443/1645 train_time:133765ms step_avg:92.70ms +step:1444/1645 train_time:133859ms step_avg:92.70ms +step:1445/1645 train_time:133953ms step_avg:92.70ms +step:1446/1645 train_time:134047ms step_avg:92.70ms +step:1447/1645 train_time:134140ms step_avg:92.70ms +step:1448/1645 train_time:134234ms step_avg:92.70ms +step:1449/1645 train_time:134328ms step_avg:92.70ms +step:1450/1645 train_time:134422ms step_avg:92.70ms +step:1451/1645 train_time:134516ms step_avg:92.71ms +step:1452/1645 train_time:134610ms step_avg:92.71ms +step:1453/1645 train_time:134704ms step_avg:92.71ms +step:1454/1645 train_time:134798ms step_avg:92.71ms +step:1455/1645 train_time:134893ms step_avg:92.71ms +step:1456/1645 train_time:134987ms step_avg:92.71ms +step:1457/1645 train_time:135080ms 
step_avg:92.71ms +step:1458/1645 train_time:135172ms step_avg:92.71ms +step:1459/1645 train_time:135266ms step_avg:92.71ms +step:1460/1645 train_time:135360ms step_avg:92.71ms +step:1461/1645 train_time:135454ms step_avg:92.71ms +step:1462/1645 train_time:135548ms step_avg:92.71ms +step:1463/1645 train_time:135640ms step_avg:92.71ms +step:1464/1645 train_time:135735ms step_avg:92.72ms +step:1465/1645 train_time:135829ms step_avg:92.72ms +step:1466/1645 train_time:135924ms step_avg:92.72ms +step:1467/1645 train_time:136018ms step_avg:92.72ms +step:1468/1645 train_time:136111ms step_avg:92.72ms +step:1469/1645 train_time:136205ms step_avg:92.72ms +step:1470/1645 train_time:136298ms step_avg:92.72ms +step:1471/1645 train_time:136392ms step_avg:92.72ms +step:1472/1645 train_time:136486ms step_avg:92.72ms +step:1473/1645 train_time:136578ms step_avg:92.72ms +step:1474/1645 train_time:136672ms step_avg:92.72ms +step:1475/1645 train_time:136766ms step_avg:92.72ms +step:1476/1645 train_time:136860ms step_avg:92.72ms +step:1477/1645 train_time:136953ms step_avg:92.72ms +step:1478/1645 train_time:137046ms step_avg:92.72ms +step:1479/1645 train_time:137140ms step_avg:92.72ms +step:1480/1645 train_time:137233ms step_avg:92.73ms +step:1481/1645 train_time:137327ms step_avg:92.73ms +step:1482/1645 train_time:137421ms step_avg:92.73ms +step:1483/1645 train_time:137514ms step_avg:92.73ms +step:1484/1645 train_time:137608ms step_avg:92.73ms +step:1485/1645 train_time:137702ms step_avg:92.73ms +step:1486/1645 train_time:137795ms step_avg:92.73ms +step:1487/1645 train_time:137889ms step_avg:92.73ms +step:1488/1645 train_time:137983ms step_avg:92.73ms +step:1489/1645 train_time:138077ms step_avg:92.73ms +step:1490/1645 train_time:138171ms step_avg:92.73ms +step:1491/1645 train_time:138265ms step_avg:92.73ms +step:1492/1645 train_time:138358ms step_avg:92.73ms +step:1493/1645 train_time:138451ms step_avg:92.73ms +step:1494/1645 train_time:138545ms step_avg:92.73ms +step:1495/1645 train_time:138640ms step_avg:92.74ms +step:1496/1645 train_time:138733ms step_avg:92.74ms +step:1497/1645 train_time:138828ms step_avg:92.74ms +step:1498/1645 train_time:138921ms step_avg:92.74ms +step:1499/1645 train_time:139016ms step_avg:92.74ms +step:1500/1645 train_time:139109ms step_avg:92.74ms +step:1500/1645 val_loss:3.3082 train_time:139202ms step_avg:92.80ms +step:1501/1645 train_time:139227ms step_avg:92.76ms +step:1502/1645 train_time:139300ms step_avg:92.74ms +step:1503/1645 train_time:139394ms step_avg:92.74ms +step:1504/1645 train_time:139487ms step_avg:92.74ms +step:1505/1645 train_time:139580ms step_avg:92.74ms +step:1506/1645 train_time:139674ms step_avg:92.74ms +step:1507/1645 train_time:139766ms step_avg:92.74ms +step:1508/1645 train_time:139859ms step_avg:92.74ms +step:1509/1645 train_time:139952ms step_avg:92.75ms +step:1510/1645 train_time:140046ms step_avg:92.75ms +step:1511/1645 train_time:140141ms step_avg:92.75ms +step:1512/1645 train_time:140238ms step_avg:92.75ms +step:1513/1645 train_time:140332ms step_avg:92.75ms +step:1514/1645 train_time:140426ms step_avg:92.75ms +step:1515/1645 train_time:140519ms step_avg:92.75ms +step:1516/1645 train_time:140613ms step_avg:92.75ms +step:1517/1645 train_time:140705ms step_avg:92.75ms +step:1518/1645 train_time:140797ms step_avg:92.75ms +step:1519/1645 train_time:140891ms step_avg:92.75ms +step:1520/1645 train_time:140984ms step_avg:92.75ms +step:1521/1645 train_time:141078ms step_avg:92.75ms +step:1522/1645 train_time:141174ms step_avg:92.76ms +step:1523/1645 
train_time:141268ms step_avg:92.76ms +step:1524/1645 train_time:141363ms step_avg:92.76ms +step:1525/1645 train_time:141456ms step_avg:92.76ms +step:1526/1645 train_time:141550ms step_avg:92.76ms +step:1527/1645 train_time:141643ms step_avg:92.76ms +step:1528/1645 train_time:141736ms step_avg:92.76ms +step:1529/1645 train_time:141828ms step_avg:92.76ms +step:1530/1645 train_time:141921ms step_avg:92.76ms +step:1531/1645 train_time:142016ms step_avg:92.76ms +step:1532/1645 train_time:142110ms step_avg:92.76ms +step:1533/1645 train_time:142203ms step_avg:92.76ms +step:1534/1645 train_time:142297ms step_avg:92.76ms +step:1535/1645 train_time:142393ms step_avg:92.76ms +step:1536/1645 train_time:142486ms step_avg:92.76ms +step:1537/1645 train_time:142580ms step_avg:92.77ms +step:1538/1645 train_time:142673ms step_avg:92.77ms +step:1539/1645 train_time:142766ms step_avg:92.77ms +step:1540/1645 train_time:142859ms step_avg:92.77ms +step:1541/1645 train_time:142953ms step_avg:92.77ms +step:1542/1645 train_time:143046ms step_avg:92.77ms +step:1543/1645 train_time:143140ms step_avg:92.77ms +step:1544/1645 train_time:143234ms step_avg:92.77ms +step:1545/1645 train_time:143328ms step_avg:92.77ms +step:1546/1645 train_time:143422ms step_avg:92.77ms +step:1547/1645 train_time:143517ms step_avg:92.77ms +step:1548/1645 train_time:143610ms step_avg:92.77ms +step:1549/1645 train_time:143704ms step_avg:92.77ms +step:1550/1645 train_time:143798ms step_avg:92.77ms +step:1551/1645 train_time:143892ms step_avg:92.77ms +step:1552/1645 train_time:143986ms step_avg:92.77ms +step:1553/1645 train_time:144079ms step_avg:92.77ms +step:1554/1645 train_time:144173ms step_avg:92.78ms +step:1555/1645 train_time:144267ms step_avg:92.78ms +step:1556/1645 train_time:144360ms step_avg:92.78ms +step:1557/1645 train_time:144455ms step_avg:92.78ms +step:1558/1645 train_time:144550ms step_avg:92.78ms +step:1559/1645 train_time:144643ms step_avg:92.78ms +step:1560/1645 train_time:144736ms step_avg:92.78ms +step:1561/1645 train_time:144830ms step_avg:92.78ms +step:1562/1645 train_time:144923ms step_avg:92.78ms +step:1563/1645 train_time:145017ms step_avg:92.78ms +step:1564/1645 train_time:145110ms step_avg:92.78ms +step:1565/1645 train_time:145204ms step_avg:92.78ms +step:1566/1645 train_time:145299ms step_avg:92.78ms +step:1567/1645 train_time:145393ms step_avg:92.78ms +step:1568/1645 train_time:145487ms step_avg:92.78ms +step:1569/1645 train_time:145581ms step_avg:92.79ms +step:1570/1645 train_time:145674ms step_avg:92.79ms +step:1571/1645 train_time:145767ms step_avg:92.79ms +step:1572/1645 train_time:145860ms step_avg:92.79ms +step:1573/1645 train_time:145955ms step_avg:92.79ms +step:1574/1645 train_time:146048ms step_avg:92.79ms +step:1575/1645 train_time:146142ms step_avg:92.79ms +step:1576/1645 train_time:146236ms step_avg:92.79ms +step:1577/1645 train_time:146330ms step_avg:92.79ms +step:1578/1645 train_time:146424ms step_avg:92.79ms +step:1579/1645 train_time:146518ms step_avg:92.79ms +step:1580/1645 train_time:146611ms step_avg:92.79ms +step:1581/1645 train_time:146704ms step_avg:92.79ms +step:1582/1645 train_time:146797ms step_avg:92.79ms +step:1583/1645 train_time:146891ms step_avg:92.79ms +step:1584/1645 train_time:146985ms step_avg:92.79ms +step:1585/1645 train_time:147077ms step_avg:92.79ms +step:1586/1645 train_time:147171ms step_avg:92.79ms +step:1587/1645 train_time:147264ms step_avg:92.79ms +step:1588/1645 train_time:147358ms step_avg:92.79ms +step:1589/1645 train_time:147452ms step_avg:92.80ms +step:1590/1645 
train_time:147546ms step_avg:92.80ms +step:1591/1645 train_time:147640ms step_avg:92.80ms +step:1592/1645 train_time:147733ms step_avg:92.80ms +step:1593/1645 train_time:147827ms step_avg:92.80ms +step:1594/1645 train_time:147920ms step_avg:92.80ms +step:1595/1645 train_time:148014ms step_avg:92.80ms +step:1596/1645 train_time:148108ms step_avg:92.80ms +step:1597/1645 train_time:148202ms step_avg:92.80ms +step:1598/1645 train_time:148295ms step_avg:92.80ms +step:1599/1645 train_time:148389ms step_avg:92.80ms +step:1600/1645 train_time:148482ms step_avg:92.80ms +step:1601/1645 train_time:148576ms step_avg:92.80ms +step:1602/1645 train_time:148670ms step_avg:92.80ms +step:1603/1645 train_time:148762ms step_avg:92.80ms +step:1604/1645 train_time:148856ms step_avg:92.80ms +step:1605/1645 train_time:148950ms step_avg:92.80ms +step:1606/1645 train_time:149044ms step_avg:92.80ms +step:1607/1645 train_time:149137ms step_avg:92.80ms +step:1608/1645 train_time:149230ms step_avg:92.80ms +step:1609/1645 train_time:149324ms step_avg:92.81ms +step:1610/1645 train_time:149419ms step_avg:92.81ms +step:1611/1645 train_time:149512ms step_avg:92.81ms +step:1612/1645 train_time:149607ms step_avg:92.81ms +step:1613/1645 train_time:149700ms step_avg:92.81ms +step:1614/1645 train_time:149793ms step_avg:92.81ms +step:1615/1645 train_time:149888ms step_avg:92.81ms +step:1616/1645 train_time:149982ms step_avg:92.81ms +step:1617/1645 train_time:150076ms step_avg:92.81ms +step:1618/1645 train_time:150169ms step_avg:92.81ms +step:1619/1645 train_time:150263ms step_avg:92.81ms +step:1620/1645 train_time:150356ms step_avg:92.81ms +step:1621/1645 train_time:150450ms step_avg:92.81ms +step:1622/1645 train_time:150544ms step_avg:92.81ms +step:1623/1645 train_time:150638ms step_avg:92.81ms +step:1624/1645 train_time:150731ms step_avg:92.81ms +step:1625/1645 train_time:150825ms step_avg:92.82ms +step:1625/1645 val_loss:3.2844 train_time:150919ms step_avg:92.87ms +step:1626/1645 train_time:150939ms step_avg:92.83ms +step:1627/1645 train_time:151016ms step_avg:92.82ms +step:1628/1645 train_time:151113ms step_avg:92.82ms +step:1629/1645 train_time:151206ms step_avg:92.82ms +step:1630/1645 train_time:151299ms step_avg:92.82ms +step:1631/1645 train_time:151392ms step_avg:92.82ms +step:1632/1645 train_time:151485ms step_avg:92.82ms +step:1633/1645 train_time:151578ms step_avg:92.82ms +step:1634/1645 train_time:151671ms step_avg:92.82ms +step:1635/1645 train_time:151764ms step_avg:92.82ms +step:1636/1645 train_time:151859ms step_avg:92.82ms +step:1637/1645 train_time:151954ms step_avg:92.82ms +step:1638/1645 train_time:152049ms step_avg:92.83ms +step:1639/1645 train_time:152143ms step_avg:92.83ms +step:1640/1645 train_time:152237ms step_avg:92.83ms +step:1641/1645 train_time:152330ms step_avg:92.83ms +step:1642/1645 train_time:152423ms step_avg:92.83ms +step:1643/1645 train_time:152517ms step_avg:92.83ms +step:1644/1645 train_time:152610ms step_avg:92.83ms +step:1645/1645 train_time:152704ms step_avg:92.83ms +step:1645/1645 val_loss:3.2787 train_time:152798ms step_avg:92.89ms +peak memory allocated: 31659 MiB reserved: 46816 MiB diff --git a/records/091825_Smear/README.md b/records/091825_Smear/README.md new file mode 100644 index 000000000..8b1872052 --- /dev/null +++ b/records/091825_Smear/README.md @@ -0,0 +1,35 @@ +## New WR 152.7s: Smear token embeddings 1 position forward, -15 steps + +This PR builds on all recent WR improvements including PR #127. 
From inspecting trained model weights, it was observed that multiple attention heads were consistently attending to the prior token. However, attention is a computationally inefficient way to attend to the prior token. This functionality is instead built in directly below, in a lightweight manner. The first 12 dimensions of the residual stream/embed are used to gate both the smear module and attention. Roughly speaking, the model finds (token + 0.07·prior_token) a more useful embedding representation than (token) alone.
+
+Note: This improvement is more marginal than the timing change would indicate. The prior WR had a mean loss of 3.2781. When controlling for loss, the impact of this change appears closer to 5 steps based on testing.
+
+```
+self.smear_gate = CastedLinear(12, 1)
+self.smear_gate.weight.detach().zero_()
+
+x = self.embed(input_seq)
+# smear token embed forward 1 position
+smear_lambda = self.scalars[5 * len(self.blocks)]
+smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+x = x0 = norm(x[None])
+```
+
+Validation:
+```
+import scipy.stats
+import torch
+
+accs = [3.2781,3.2792,3.2765,3.2796,3.2803,3.2801,3.2787,3.2798,3.2787,3.2786]
+
+times = [152.771,152.816,152.834,152.755,152.789,152.773,152.815,152.796,152.798,152.754]
+
+print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue)
+# p=0.0084
+print("acc:", torch.std_mean(torch.tensor(accs)))
+# acc: (tensor(0.0011), tensor(3.2790))
+
+print("time:", torch.std_mean(torch.tensor(times)))
+# time: (tensor(0.0269), tensor(152.7901))
+```
\ No newline at end of file
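The smear update is easier to see on toy data. A minimal, ungated sketch of the mechanism (the constant 0.07 stands in for a trained smear_lambda * sigmoid(gate) value; the record's version computes this gate per token, as in the README snippet above):

```
import torch

torch.manual_seed(0)
T, D = 6, 8
x = torch.randn(T, D)  # token embeddings, one row per position

gate = 0.07  # stand-in for the learned, per-token gate value
x_smeared = torch.cat([x[:1], x[1:] + gate * x[:-1]])  # position 0 is unchanged

# every position t >= 1 now carries a trace of its predecessor
assert torch.allclose(x_smeared[3], x[3] + gate * x[2])
```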
diff --git a/records/091825_Smear/cc16404d-92d4-48c1-b9b0-b906360363b4.txt b/records/091825_Smear/cc16404d-92d4-48c1-b9b0-b906360363b4.txt
new file mode 100644
index 000000000..750f49a3f
--- /dev/null
+++ b/records/091825_Smear/cc16404d-92d4-48c1-b9b0-b906360363b4.txt
@@ -0,0 +1,3089 @@
+import os
+import sys
+
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import copy
+import glob
+import math
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+
+torch.empty(
+    1, device="cuda", requires_grad=True
+).backward() # prevents a bug on some systems
+import torch._dynamo as dynamo
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+from torch import Tensor, nn
+
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
+
+@mm_op.register_fake
+def _(x: Tensor, w: Tensor, *_):
+    assert x.ndim == w.ndim == 2
+    assert x.shape[1] == w.shape[1]
+    assert x.device == w.device
+    assert x.is_contiguous() and w.is_contiguous()
+    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
+
+@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
+def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
+    @torch.compile
+    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
+        assert grad.is_contiguous()
+        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
+        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
+        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
+        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
+        grad_x = torch._scaled_mm(
+            grad_f8,
+            w_f8.T.contiguous().T,
+            out_dtype=torch.bfloat16,
+            scale_a=grad_inv_s,
+            scale_b=w_inv_s,
+            use_fast_accum=False,
+        )
+        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
+        grad_w = torch._scaled_mm(
+            x_f8.T.contiguous(),
+            grad_f8.T.contiguous().T,
+            out_dtype=torch.float32,
+            scale_a=x_inv_s,
+            scale_b=grad_inv_s,
+            use_fast_accum=False,
+        ).T
+        return grad_x, grad_w
+
+    return impl(g, x_f8, w_f8)
+
+@mm_backward_op.register_fake
+def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
+    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
+
+def backward(ctx, grad_out: Tensor, *_):
+    x_f8, w_f8 = ctx.saved_tensors
+    x_s, w_s, grad_s = ctx.scales
+    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
+        grad_out, x_f8, w_f8, x_s, w_s, grad_s
+    )
+    return grad_x, grad_w, None, None, None
+
+def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
+    *_, x_s, w_s, grad_s = inputs
+    _, x_f8, w_f8 = output
+    ctx.save_for_backward(x_f8, w_f8)
+    ctx.scales = x_s, w_s, grad_s
+    ctx.set_materialize_grads(False)
+
+mm_op.register_autograd(backward, setup_context=setup_context)
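Note the dtype asymmetry above: activations and weights are cast to e4m3, while gradients are cast to e5m2. The two float8 formats trade mantissa (precision) against exponent (range), which can be inspected directly (a small sketch; assumes a PyTorch build with float8 dtypes, available since 2.1):

```
import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    fi = torch.finfo(dt)
    print(dt, "max:", fi.max, "eps:", fi.eps)
# e4m3 has more mantissa bits, suiting activations/weights;
# e5m2 has more exponent bits (larger range), suiting gradients.
```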
+
+# -----------------------------------------------------------------------------
+# Triton kernel for symmetric matrix multiplication by @byronxu99
+
+def _get_autotune_configs():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": bm,
+                "BLOCK_SIZE_N": bn,
+                "BLOCK_SIZE_K": bk,
+                "GROUP_SIZE_M": 8,
+                "LOWER_UPPER": 1,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        for bm in [64, 128]
+        for bn in [64, 128, 256]
+        for bk in [64, 128]
+        for stages, warps in [(3, 4), (3, 8), (4, 4)]
+        if bm // bn <= 2 and bn // bm <= 2
+    ]
+
+@triton.jit
+def _pid_to_block(
+    pid,
+    M,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+    # Map PID to a single matrix in batch
+    batch_idx = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+
+    # Map PID to 2D grid of blocks
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+    m_idx = pid_m * BLOCK_SIZE_M
+    n_idx = pid_n * BLOCK_SIZE_N
+    return batch_idx, m_idx, n_idx
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_1_kernel(
+    A_ptr, C_ptr,
+    M, K,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_1(A: torch.Tensor, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = A @ A.T
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert out.size(-2) == M, "Output matrix has incorrect shape"
+    assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_1_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        K=K,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+    )
+    return out
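A quick way to sanity-check the symmetric-matmul kernel is to compare it against a plain matmul (a sketch, assuming a CUDA device and the definitions above are in scope):

```
A = torch.randn(4, 256, 128, device="cuda", dtype=torch.bfloat16)
C = torch.empty(4, 256, 256, device="cuda", dtype=torch.bfloat16)
ns_line_1(A, out=C)                           # Triton: C = A @ A.T (both triangles written)
ref = A @ A.mT                                # plain matmul reference
print((C.float() - ref.float()).abs().max())  # should be ~0 up to bf16 rounding
```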
+
+@triton.autotune(
+    configs=_get_autotune_configs(),
+    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+)
+@triton.jit
+def ns_line_2_kernel(
+    A_ptr, C_ptr,
+    M,
+    a_stride_b, a_stride_r, a_stride_c,
+    c_stride_b, c_stride_r, c_stride_c,
+    alpha, beta,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    LOWER_UPPER: tl.constexpr,
+):
+    # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A
+    # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels
+    pid = tl.program_id(axis=0)
+    batch_idx, m_idx, n_idx = _pid_to_block(
+        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+    )
+
+    # Skip blocks that don't need to be computed
+    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+    if skip_block_below_diag or skip_block_above_diag:
+        return
+
+    # Index into one matrix of batch
+    A_ptr += batch_idx * a_stride_b
+    C_ptr += batch_idx * c_stride_b
+
+    # Create pointer arrays for A and A.T
+    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # Accumulate over blocks of K
+    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
+        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
+        accumulator = tl.dot(a, at, accumulator)
+        a_ptrs += BLOCK_SIZE_K * a_stride_c
+        at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+    # Load block of A to add (corresponds to the current block of C)
+    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
+    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
+    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)
+
+    # Apply alpha and beta
+    accumulator *= alpha
+    accumulator += a_add * beta
+
+    out_dtype = C_ptr.dtype.element_ty
+    output = accumulator.to(out_dtype)
+
+    # Store block of C
+    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+    tl.store(c_ptrs, output, mask=c_mask)
+
+    # Store block of C mirrored across the diagonal
+    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor):
+    """
+    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A
+    """
+    assert A.ndim == 2 or A.ndim == 3
+    M, K = A.shape[-2:]
+    assert M == K, "Input matrix must be square"
+    assert out.size(-2) == M
+    assert out.size(-1) == M
+
+    batch_size = A.size(0) if A.ndim == 3 else 1
+    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+    grid = lambda meta: (
+        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+    )
+    ns_line_2_kernel[grid](
+        A_ptr=A,
+        C_ptr=out,
+        M=M,
+        a_stride_b=input_batch_stride,
+        a_stride_r=A.stride(-2),
+        a_stride_c=A.stride(-1),
+        c_stride_b=output_batch_stride,
+        c_stride_r=out.stride(-2),
+        c_stride_c=out.stride(-1),
+        alpha=alpha,
+        beta=beta,
+    )
+    return out
+
+@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower
+def newton_schulz_triton(G: torch.Tensor):
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+
+    # Allocate buffers
+    X = X.contiguous()
+    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
+    B = torch.empty_like(A)
+    C = torch.empty_like(X)
+
+    ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm
+
+    # Perform the NS iterations
+    for _ in range(5):
+        ns_line_1(X, out=A) # A = X @ X.mT
+        ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A
+        ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X
+        X, C = C, X # Swap references to avoid unnecessary copies
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
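The five fused iterations above implement the standard quintic Newton-Schulz orthogonalization. An equivalent plain-PyTorch version, useful as a sketch for checking the Triton path against (same coefficients, same normalization):

```
def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT          # ns_line_1
        B = b * A + c * A @ A # ns_line_2
        X = a * X + B @ X     # ns_line_3
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
```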
+
+# -----------------------------------------------------------------------------
+# Muon optimizer
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
+    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
+    """
+    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        params = list(params)
+        sizes = {p.shape for p in params}
+        # create one buffer per unique parameter-size
+        param_groups = []
+        for size in sizes:
+            group_params = [p for p in params if p.shape == size]
+            param_groups.append(dict(params=group_params))
+        super().__init__(param_groups, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        # Efficient systems-wise implementation of step developed by @YouJiacheng,
+        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
+        # @ryanyang0, and @vagrawal.
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0] # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i] # This gradient corresponds to the current parameter p.
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                batch = original_shape[0] * original_shape[1]
+                # Flatten all but the first two dims after batch
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Add the computed zeropower update to the parameters in the buffer.
+            # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params): # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
+        for info in all_gather_infos:
+            info["gather_future"].wait()
+            stacked_params = info["stacked_params"]
+            orig_params = info["orig_params"]
+
+            unstacked_params = torch.unbind(stacked_params)
+            for i, p in enumerate(orig_params):
+                p.copy_(unstacked_params[i], non_blocking=True)
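The two lerp calls in step() are a compact Nesterov-style momentum: lerp_(grad, 1 - m) decays the buffer toward the fresh gradient, and grad.lerp(buffer, m) blends the fresh gradient back in. A small numerical check of the identities used above:

```
import torch

m = 0.95
g = torch.randn(64, 64)
buf = torch.randn(64, 64)

buf_new = buf.lerp(g, 1 - m)  # == m * buf + (1 - m) * g
update = g.lerp(buf_new, m)   # == (1 - m) * g + m * buf_new

assert torch.allclose(buf_new, m * buf + (1 - m) * g, atol=1e-6)
assert torch.allclose(update, (1 - m) * g + m * buf_new, atol=1e-6)
```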
+ updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. + for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, 
async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: 
suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x 
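+
+# A minimal sketch of the sparse attention gate used in CausalSelfAttention
+# above, assuming toy sizes (B=1, T=16, 6 heads of head_dim 128, 12 gate input
+# channels); the helper name `_attn_gate_sketch` is illustrative, not part of
+# the recorded run. The gate projects the first 12 channels of the layer input
+# to one logit per head and scales that head's output by its sigmoid; with the
+# zero-initialized gate weight, every gate starts at sigmoid(0) = 0.5.
+def _attn_gate_sketch():
+    B, T, num_heads, head_dim, gate_dim = 1, 16, 6, 128, 12
+    x = torch.randn(B, T, num_heads * head_dim)                # layer input
+    y = torch.randn(B, T, num_heads, head_dim)                 # stand-in for the attention output
+    gate_w = torch.zeros(num_heads, gate_dim)                  # zero init, matching attn_gate
+    gate = torch.sigmoid(F.linear(x[..., :gate_dim], gate_w))  # (B, T, num_heads)
+    return y * gate.view(B, T, num_heads, 1)                   # per-head gating enables a context-based no-op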
+ +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+ ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+ assert len(ve) == len(self.blocks)
+
+ long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+ final_bm = ws_final_layer * args.block_size
+ bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm]
+ assert len(bm_sizes) == len(self.blocks)
+
+ #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+ smear_lambda = self.scalars[5 * len(self.blocks)]
+ x = self.embed(input_seq)
+
+ # smear token embed forward 1 position
+ smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+ x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+ x = x0 = norm(x[None])
+
+ # U-net design by @brendanh0gan
+ skip_connections = []
+ skip_weights = self.scalars[:(len(self.blocks) // 2)]
+ lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+ sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+ n = len(self.blocks) // 2
+
+ for i in range(len(self.blocks)):
+ attn_args = AttnArgs(
+ ve=ve[i],
+ sa_lambdas=sa_lambdas[i],
+ seqlens=seqlens,
+ bm_size=bm_sizes[i],
+ cos=self.yarn.cos,
+ sin=self.yarn.sin,
+ attn_scale=self.yarn.attn_scale
+ )
+ if i >= n:
+ gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+ x = x + gate * skip_connections.pop()
+ x = self.blocks[i](x, x0, lambdas[i], attn_args)
+ if i < n:
+ skip_connections.append(x)
+
+ x = norm(x)
+ logits = self.lm_head(x).float()
+ # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+ logits = 30 * torch.sigmoid(logits / 7.5)
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+ return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+import threading # used below by BOSFinder and DataPreloader for async BOS indexing and shard loading
+
+def _load_data_shard(file: Path):
+ header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+ assert header[1] == 1, "unsupported version"
+ num_tokens = int(header[2]) # number of tokens (claimed)
+ with file.open("rb", buffering=0) as f:
+ tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+ f.seek(256 * 4)
+ nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+ assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+ return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+ # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+ def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+ # Precompute BOS positions once per shard
+ self.tokens = tokens
+ self.size = tokens.numel()
+ self.quickload = quickload
+ if quickload:
+ # only scan first 4 million tokens, then kick off an async thread to scan the rest
+ self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.thread = None
+ self.ready = threading.Event()
+ self.start()
+ else:
+ self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.i = 0
+ self.world_size = world_size
+ self.batch_iter = 0
+
+ def _load(self):
+ self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.ready.set()
+
+ def start(self):
+ self.ready.clear()
+ self.thread = threading.Thread(target=self._load)
+ self.thread.start()
+
+ def get(self):
+ if self.thread:
+ self.ready.wait()
+ self.thread.join()
+ self.bos_idx = self.bos_idx_async
+
+ def next_batch(self, num_tokens_local: int, max_seq_len: int):
+ # if quickload was used, repoint to the full dataset after 5 batches
+ if self.quickload and self.batch_iter == 5:
+ self.get()
+ n = len(self.bos_idx)
+ starts = [[] for _ in range(self.world_size)]
+ ends = [[] for _ in range(self.world_size)]
+
+ idx = self.i
+ for r in range(self.world_size):
+ cur_len = 0
+ while cur_len <= num_tokens_local:
+ if idx >= n:
+ raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.") # note: avoid referencing cur here, which may be unbound
+ cur = self.bos_idx[idx]
+ starts[r].append(cur)
+ end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+ cur + max_seq_len,
+ cur + num_tokens_local - cur_len + 1)
+ ends[r].append(end)
+ cur_len += end - cur
+ idx += 1
+
+ assert cur_len == num_tokens_local + 1
+ self.i = idx
+ self.batch_iter += 1
+ return starts, ends
+
+class DataPreloader:
+ # Helper for asynchronously loading next shard and indexing bos tokens
+ def __init__(self, file_iter, world_size: int = 1):
+ self.file_iter = file_iter
+ self.world_size = world_size
+ self.thread = None
+ self.data = None
+ self.ready = threading.Event()
+
+ def _load(self):
+ tokens = _load_data_shard(next(self.file_iter))
+ self.data = (tokens, BOSFinder(tokens, self.world_size))
+ self.ready.set()
+
+ def start(self):
+ self.ready.clear()
+ self.thread = threading.Thread(target=self._load)
+ self.thread.start()
+
+ def get(self):
+ if self.thread:
+ self.ready.wait()
+ self.thread.join()
+ return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+ # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+ num_tokens = num_tokens // grad_accum_steps
+
+ files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+ if not files:
+ raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+ file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+ tokens = _load_data_shard(next(file_iter))
+ if align_to_bos:
+ finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+ preloader = DataPreloader(file_iter, world_size)
+ preloader.start()
+ else:
+ pos = 0 # for unaligned case
+
+ while True:
+ num_tokens_local = num_tokens // world_size
+ max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+ if align_to_bos:
+ try:
+ seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+ start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+ except StopIteration:
+ # This shard is exhausted, load the next one in the next loop iteration.
+ tokens, finder = preloader.get()
+ preloader.start()
+ continue
+
+ buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+ _inputs = buf[:-1]
+ _targets = buf[1:]
+ end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+ cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+ else:
+ if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+
+ pos_local = pos + rank * num_tokens_local
+ buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+ _inputs = buf[:-1].view(num_tokens_local, )
+ _targets = buf[1:].view(num_tokens_local, )
+
+ cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+ pos += num_tokens
+
+ _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+ _cum_lengths[0] = 0
+ _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+ new_params = yield (
+ _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+ _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+ _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+ )
+
+ if new_params is not None:
+ # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+ new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+ assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+ num_tokens = new_num_tokens
+ max_seq_len = new_max_seq_len
+ grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+ # data
+ train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+ val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+ val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+ train_batch_size: int = 2048 * 24 * 8
+ train_max_seq_len: int = 128 * 16
+ val_batch_size: int = 4 * 64 * 1024 * 8
+ # optimization
+ num_iterations: int = 1645 # number of iterations to run
+ cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+ # evaluation and logging
+ run_id: str = f"smear/{uuid.uuid4()}"
+ val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+ save_checkpoint: bool = False
+ # attention masking
+ block_size: int = 128
+ ws_schedule: tuple = (3, 7, 11)
+ ws_validate: int = 13 # increase final validation ws @classiclarryd
+ ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
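+
+# A minimal sketch, assuming toy document lengths (4, 2, 5) and max_num_docs=8,
+# of the cumulative-boundary tensor that distributed_data_generator yields as
+# cum_seqlens: entry 0 is 0, entries 1..k are cumulative document lengths, and
+# the tail is padded with the total token count so the tensor shape does not
+# depend on how many documents landed in the batch. The helper name
+# `_cu_seqlens_sketch` is illustrative, not part of the recorded run.
+def _cu_seqlens_sketch():
+    doc_lengths = torch.tensor([4, 2, 5])
+    num_tokens_local = int(doc_lengths.sum())  # 11
+    max_num_docs = 8                           # assumed static padding size
+    cu = torch.full((max_num_docs,), num_tokens_local)
+    cu[0] = 0
+    cu[1:len(doc_lengths) + 1] = doc_lengths.cumsum(0)
+    return cu  # tensor([ 0,  4,  6, 11, 11, 11, 11, 11])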
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up 
each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Thu Sep 18 17:13:03 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 24C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 27C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 28C P0 113W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 26C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 25C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 29C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 28C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 26C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:134ms step_avg:133.92ms +step:2/1645 train_time:153ms step_avg:76.49ms +step:3/1645 train_time:224ms step_avg:74.77ms +step:4/1645 train_time:314ms step_avg:78.52ms +step:5/1645 train_time:405ms step_avg:80.98ms +step:6/1645 train_time:496ms step_avg:82.62ms +step:7/1645 train_time:586ms step_avg:83.76ms +step:8/1645 train_time:677ms step_avg:84.66ms +step:9/1645 train_time:769ms step_avg:85.43ms +step:10/1645 train_time:860ms step_avg:85.97ms +step:11/1645 train_time:951ms step_avg:86.42ms +step:12/1645 train_time:1045ms step_avg:87.04ms +step:13/1645 train_time:1141ms step_avg:87.76ms +step:14/1645 train_time:1235ms step_avg:88.19ms +step:15/1645 train_time:1326ms step_avg:88.41ms +step:16/1645 train_time:1417ms step_avg:88.59ms +step:17/1645 train_time:1510ms step_avg:88.83ms +step:18/1645 train_time:1601ms step_avg:88.93ms +step:19/1645 train_time:1692ms step_avg:89.05ms +step:20/1645 train_time:1783ms step_avg:89.17ms +step:21/1645 train_time:1875ms step_avg:89.27ms +step:22/1645 train_time:1967ms step_avg:89.39ms +step:23/1645 
train_time:2059ms step_avg:89.51ms +step:24/1645 train_time:2154ms step_avg:89.75ms +step:25/1645 train_time:2247ms step_avg:89.90ms +step:26/1645 train_time:2340ms step_avg:89.98ms +step:27/1645 train_time:2432ms step_avg:90.09ms +step:28/1645 train_time:2523ms step_avg:90.12ms +step:29/1645 train_time:2615ms step_avg:90.17ms +step:30/1645 train_time:2707ms step_avg:90.22ms +step:31/1645 train_time:2798ms step_avg:90.25ms +step:32/1645 train_time:2890ms step_avg:90.32ms +step:33/1645 train_time:2981ms step_avg:90.34ms +step:34/1645 train_time:3074ms step_avg:90.40ms +step:35/1645 train_time:3167ms step_avg:90.49ms +step:36/1645 train_time:3259ms step_avg:90.54ms +step:37/1645 train_time:3352ms step_avg:90.60ms +step:38/1645 train_time:3444ms step_avg:90.63ms +step:39/1645 train_time:3535ms step_avg:90.65ms +step:40/1645 train_time:3627ms step_avg:90.67ms +step:41/1645 train_time:3718ms step_avg:90.69ms +step:42/1645 train_time:3811ms step_avg:90.73ms +step:43/1645 train_time:3902ms step_avg:90.75ms +step:44/1645 train_time:3994ms step_avg:90.78ms +step:45/1645 train_time:4087ms step_avg:90.82ms +step:46/1645 train_time:4179ms step_avg:90.84ms +step:47/1645 train_time:4272ms step_avg:90.89ms +step:48/1645 train_time:4364ms step_avg:90.92ms +step:49/1645 train_time:4456ms step_avg:90.94ms +step:50/1645 train_time:4547ms step_avg:90.95ms +step:51/1645 train_time:4639ms step_avg:90.95ms +step:52/1645 train_time:4730ms step_avg:90.96ms +step:53/1645 train_time:4821ms step_avg:90.96ms +step:54/1645 train_time:4913ms step_avg:90.98ms +step:55/1645 train_time:5006ms step_avg:91.01ms +step:56/1645 train_time:5098ms step_avg:91.04ms +step:57/1645 train_time:5191ms step_avg:91.06ms +step:58/1645 train_time:5283ms step_avg:91.09ms +step:59/1645 train_time:5375ms step_avg:91.10ms +step:60/1645 train_time:5467ms step_avg:91.12ms +step:61/1645 train_time:5558ms step_avg:91.12ms +step:62/1645 train_time:5650ms step_avg:91.13ms +step:63/1645 train_time:5741ms step_avg:91.13ms +step:64/1645 train_time:5833ms step_avg:91.14ms +step:65/1645 train_time:5924ms step_avg:91.14ms +step:66/1645 train_time:6016ms step_avg:91.15ms +step:67/1645 train_time:6109ms step_avg:91.18ms +step:68/1645 train_time:6201ms step_avg:91.19ms +step:69/1645 train_time:6294ms step_avg:91.22ms +step:70/1645 train_time:6387ms step_avg:91.24ms +step:71/1645 train_time:6478ms step_avg:91.24ms +step:72/1645 train_time:6571ms step_avg:91.26ms +step:73/1645 train_time:6662ms step_avg:91.27ms +step:74/1645 train_time:6754ms step_avg:91.27ms +step:75/1645 train_time:6845ms step_avg:91.27ms +step:76/1645 train_time:6937ms step_avg:91.27ms +step:77/1645 train_time:7029ms step_avg:91.28ms +step:78/1645 train_time:7121ms step_avg:91.29ms +step:79/1645 train_time:7212ms step_avg:91.30ms +step:80/1645 train_time:7305ms step_avg:91.31ms +step:81/1645 train_time:7396ms step_avg:91.31ms +step:82/1645 train_time:7489ms step_avg:91.32ms +step:83/1645 train_time:7580ms step_avg:91.33ms +step:84/1645 train_time:7672ms step_avg:91.33ms +step:85/1645 train_time:7764ms step_avg:91.34ms +step:86/1645 train_time:7855ms step_avg:91.34ms +step:87/1645 train_time:7946ms step_avg:91.34ms +step:88/1645 train_time:8038ms step_avg:91.34ms +step:89/1645 train_time:8131ms step_avg:91.36ms +step:90/1645 train_time:8224ms step_avg:91.37ms +step:91/1645 train_time:8316ms step_avg:91.39ms +step:92/1645 train_time:8408ms step_avg:91.40ms +step:93/1645 train_time:8500ms step_avg:91.40ms +step:94/1645 train_time:8592ms step_avg:91.41ms +step:95/1645 train_time:8685ms 
step_avg:91.42ms +step:96/1645 train_time:8776ms step_avg:91.42ms +step:97/1645 train_time:8868ms step_avg:91.43ms +step:98/1645 train_time:8960ms step_avg:91.43ms +step:99/1645 train_time:9052ms step_avg:91.43ms +step:100/1645 train_time:9143ms step_avg:91.43ms +step:101/1645 train_time:9235ms step_avg:91.44ms +step:102/1645 train_time:9328ms step_avg:91.45ms +step:103/1645 train_time:9420ms step_avg:91.46ms +step:104/1645 train_time:9512ms step_avg:91.46ms +step:105/1645 train_time:9604ms step_avg:91.47ms +step:106/1645 train_time:9695ms step_avg:91.47ms +step:107/1645 train_time:9788ms step_avg:91.48ms +step:108/1645 train_time:9879ms step_avg:91.47ms +step:109/1645 train_time:9971ms step_avg:91.47ms +step:110/1645 train_time:10062ms step_avg:91.47ms +step:111/1645 train_time:10154ms step_avg:91.47ms +step:112/1645 train_time:10246ms step_avg:91.48ms +step:113/1645 train_time:10337ms step_avg:91.48ms +step:114/1645 train_time:10431ms step_avg:91.50ms +step:115/1645 train_time:10523ms step_avg:91.50ms +step:116/1645 train_time:10614ms step_avg:91.50ms +step:117/1645 train_time:10705ms step_avg:91.50ms +step:118/1645 train_time:10797ms step_avg:91.50ms +step:119/1645 train_time:10889ms step_avg:91.50ms +step:120/1645 train_time:10980ms step_avg:91.50ms +step:121/1645 train_time:11072ms step_avg:91.50ms +step:122/1645 train_time:11163ms step_avg:91.50ms +step:123/1645 train_time:11255ms step_avg:91.50ms +step:124/1645 train_time:11347ms step_avg:91.51ms +step:125/1645 train_time:11438ms step_avg:91.51ms +step:125/1645 val_loss:4.3312 train_time:11530ms step_avg:92.24ms +step:126/1645 train_time:11546ms step_avg:91.64ms +step:127/1645 train_time:11629ms step_avg:91.56ms +step:128/1645 train_time:11732ms step_avg:91.66ms +step:129/1645 train_time:11827ms step_avg:91.68ms +step:130/1645 train_time:11919ms step_avg:91.69ms +step:131/1645 train_time:12011ms step_avg:91.69ms +step:132/1645 train_time:12102ms step_avg:91.68ms +step:133/1645 train_time:12192ms step_avg:91.67ms +step:134/1645 train_time:12283ms step_avg:91.66ms +step:135/1645 train_time:12373ms step_avg:91.65ms +step:136/1645 train_time:12464ms step_avg:91.65ms +step:137/1645 train_time:12555ms step_avg:91.64ms +step:138/1645 train_time:12648ms step_avg:91.65ms +step:139/1645 train_time:12744ms step_avg:91.68ms +step:140/1645 train_time:12838ms step_avg:91.70ms +step:141/1645 train_time:12929ms step_avg:91.70ms +step:142/1645 train_time:13020ms step_avg:91.69ms +step:143/1645 train_time:13111ms step_avg:91.69ms +step:144/1645 train_time:13202ms step_avg:91.68ms +step:145/1645 train_time:13293ms step_avg:91.67ms +step:146/1645 train_time:13383ms step_avg:91.67ms +step:147/1645 train_time:13474ms step_avg:91.66ms +step:148/1645 train_time:13565ms step_avg:91.66ms +step:149/1645 train_time:13658ms step_avg:91.66ms +step:150/1645 train_time:13750ms step_avg:91.67ms +step:151/1645 train_time:13844ms step_avg:91.68ms +step:152/1645 train_time:13936ms step_avg:91.69ms +step:153/1645 train_time:14028ms step_avg:91.69ms +step:154/1645 train_time:14120ms step_avg:91.69ms +step:155/1645 train_time:14211ms step_avg:91.68ms +step:156/1645 train_time:14302ms step_avg:91.68ms +step:157/1645 train_time:14393ms step_avg:91.67ms +step:158/1645 train_time:14484ms step_avg:91.67ms +step:159/1645 train_time:14575ms step_avg:91.67ms +step:160/1645 train_time:14666ms step_avg:91.66ms +step:161/1645 train_time:14759ms step_avg:91.67ms +step:162/1645 train_time:14851ms step_avg:91.68ms +step:163/1645 train_time:14944ms step_avg:91.68ms +step:164/1645 
train_time:15037ms step_avg:91.69ms +step:165/1645 train_time:15129ms step_avg:91.69ms +step:166/1645 train_time:15221ms step_avg:91.69ms +step:167/1645 train_time:15313ms step_avg:91.69ms +step:168/1645 train_time:15404ms step_avg:91.69ms +step:169/1645 train_time:15495ms step_avg:91.68ms +step:170/1645 train_time:15586ms step_avg:91.68ms +step:171/1645 train_time:15678ms step_avg:91.68ms +step:172/1645 train_time:15769ms step_avg:91.68ms +step:173/1645 train_time:15862ms step_avg:91.69ms +step:174/1645 train_time:15953ms step_avg:91.69ms +step:175/1645 train_time:16045ms step_avg:91.69ms +step:176/1645 train_time:16137ms step_avg:91.69ms +step:177/1645 train_time:16229ms step_avg:91.69ms +step:178/1645 train_time:16321ms step_avg:91.69ms +step:179/1645 train_time:16413ms step_avg:91.69ms +step:180/1645 train_time:16505ms step_avg:91.69ms +step:181/1645 train_time:16596ms step_avg:91.69ms +step:182/1645 train_time:16687ms step_avg:91.69ms +step:183/1645 train_time:16779ms step_avg:91.69ms +step:184/1645 train_time:16870ms step_avg:91.69ms +step:185/1645 train_time:16963ms step_avg:91.69ms +step:186/1645 train_time:17055ms step_avg:91.69ms +step:187/1645 train_time:17146ms step_avg:91.69ms +step:188/1645 train_time:17238ms step_avg:91.69ms +step:189/1645 train_time:17330ms step_avg:91.69ms +step:190/1645 train_time:17422ms step_avg:91.69ms +step:191/1645 train_time:17514ms step_avg:91.70ms +step:192/1645 train_time:17606ms step_avg:91.70ms +step:193/1645 train_time:17698ms step_avg:91.70ms +step:194/1645 train_time:17789ms step_avg:91.69ms +step:195/1645 train_time:17880ms step_avg:91.69ms +step:196/1645 train_time:17971ms step_avg:91.69ms +step:197/1645 train_time:18063ms step_avg:91.69ms +step:198/1645 train_time:18154ms step_avg:91.69ms +step:199/1645 train_time:18246ms step_avg:91.69ms +step:200/1645 train_time:18338ms step_avg:91.69ms +step:201/1645 train_time:18429ms step_avg:91.69ms +step:202/1645 train_time:18521ms step_avg:91.69ms +step:203/1645 train_time:18612ms step_avg:91.68ms +step:204/1645 train_time:18703ms step_avg:91.68ms +step:205/1645 train_time:18796ms step_avg:91.69ms +step:206/1645 train_time:18886ms step_avg:91.68ms +step:207/1645 train_time:18978ms step_avg:91.68ms +step:208/1645 train_time:19069ms step_avg:91.68ms +step:209/1645 train_time:19160ms step_avg:91.67ms +step:210/1645 train_time:19252ms step_avg:91.67ms +step:211/1645 train_time:19344ms step_avg:91.68ms +step:212/1645 train_time:19435ms step_avg:91.67ms +step:213/1645 train_time:19527ms step_avg:91.68ms +step:214/1645 train_time:19619ms step_avg:91.68ms +step:215/1645 train_time:19711ms step_avg:91.68ms +step:216/1645 train_time:19803ms step_avg:91.68ms +step:217/1645 train_time:19896ms step_avg:91.68ms +step:218/1645 train_time:19986ms step_avg:91.68ms +step:219/1645 train_time:20078ms step_avg:91.68ms +step:220/1645 train_time:20169ms step_avg:91.68ms +step:221/1645 train_time:20261ms step_avg:91.68ms +step:222/1645 train_time:20353ms step_avg:91.68ms +step:223/1645 train_time:20444ms step_avg:91.68ms +step:224/1645 train_time:20537ms step_avg:91.68ms +step:225/1645 train_time:20628ms step_avg:91.68ms +step:226/1645 train_time:20720ms step_avg:91.68ms +step:227/1645 train_time:20812ms step_avg:91.68ms +step:228/1645 train_time:20904ms step_avg:91.68ms +step:229/1645 train_time:20995ms step_avg:91.68ms +step:230/1645 train_time:21086ms step_avg:91.68ms +step:231/1645 train_time:21178ms step_avg:91.68ms +step:232/1645 train_time:21270ms step_avg:91.68ms +step:233/1645 train_time:21362ms step_avg:91.68ms 
+step:234/1645 train_time:21454ms step_avg:91.69ms +step:235/1645 train_time:21546ms step_avg:91.68ms +step:236/1645 train_time:21638ms step_avg:91.69ms +step:237/1645 train_time:21730ms step_avg:91.69ms +step:238/1645 train_time:21821ms step_avg:91.69ms +step:239/1645 train_time:21913ms step_avg:91.69ms +step:240/1645 train_time:22004ms step_avg:91.69ms +step:241/1645 train_time:22096ms step_avg:91.69ms +step:242/1645 train_time:22187ms step_avg:91.68ms +step:243/1645 train_time:22279ms step_avg:91.68ms +step:244/1645 train_time:22369ms step_avg:91.68ms +step:245/1645 train_time:22461ms step_avg:91.68ms +step:246/1645 train_time:22553ms step_avg:91.68ms +step:247/1645 train_time:22644ms step_avg:91.68ms +step:248/1645 train_time:22737ms step_avg:91.68ms +step:249/1645 train_time:22828ms step_avg:91.68ms +step:250/1645 train_time:22921ms step_avg:91.68ms +step:250/1645 val_loss:3.9747 train_time:23014ms step_avg:92.06ms +step:251/1645 train_time:23029ms step_avg:91.75ms +step:252/1645 train_time:23108ms step_avg:91.70ms +step:253/1645 train_time:23204ms step_avg:91.72ms +step:254/1645 train_time:23296ms step_avg:91.72ms +step:255/1645 train_time:23387ms step_avg:91.71ms +step:256/1645 train_time:23478ms step_avg:91.71ms +step:257/1645 train_time:23569ms step_avg:91.71ms +step:258/1645 train_time:23660ms step_avg:91.70ms +step:259/1645 train_time:23750ms step_avg:91.70ms +step:260/1645 train_time:23842ms step_avg:91.70ms +step:261/1645 train_time:23933ms step_avg:91.70ms +step:262/1645 train_time:24025ms step_avg:91.70ms +step:263/1645 train_time:24118ms step_avg:91.70ms +step:264/1645 train_time:24210ms step_avg:91.70ms +step:265/1645 train_time:24303ms step_avg:91.71ms +step:266/1645 train_time:24394ms step_avg:91.71ms +step:267/1645 train_time:24485ms step_avg:91.71ms +step:268/1645 train_time:24577ms step_avg:91.70ms +step:269/1645 train_time:24667ms step_avg:91.70ms +step:270/1645 train_time:24758ms step_avg:91.70ms +step:271/1645 train_time:24849ms step_avg:91.69ms +step:272/1645 train_time:24940ms step_avg:91.69ms +step:273/1645 train_time:25032ms step_avg:91.69ms +step:274/1645 train_time:25123ms step_avg:91.69ms +step:275/1645 train_time:25215ms step_avg:91.69ms +step:276/1645 train_time:25306ms step_avg:91.69ms +step:277/1645 train_time:25399ms step_avg:91.69ms +step:278/1645 train_time:25490ms step_avg:91.69ms +step:279/1645 train_time:25582ms step_avg:91.69ms +step:280/1645 train_time:25674ms step_avg:91.69ms +step:281/1645 train_time:25765ms step_avg:91.69ms +step:282/1645 train_time:25856ms step_avg:91.69ms +step:283/1645 train_time:25947ms step_avg:91.69ms +step:284/1645 train_time:26039ms step_avg:91.69ms +step:285/1645 train_time:26130ms step_avg:91.69ms +step:286/1645 train_time:26222ms step_avg:91.69ms +step:287/1645 train_time:26314ms step_avg:91.69ms +step:288/1645 train_time:26406ms step_avg:91.69ms +step:289/1645 train_time:26499ms step_avg:91.69ms +step:290/1645 train_time:26591ms step_avg:91.69ms +step:291/1645 train_time:26683ms step_avg:91.69ms +step:292/1645 train_time:26774ms step_avg:91.69ms +step:293/1645 train_time:26865ms step_avg:91.69ms +step:294/1645 train_time:26957ms step_avg:91.69ms +step:295/1645 train_time:27048ms step_avg:91.69ms +step:296/1645 train_time:27139ms step_avg:91.69ms +step:297/1645 train_time:27232ms step_avg:91.69ms +step:298/1645 train_time:27324ms step_avg:91.69ms +step:299/1645 train_time:27416ms step_avg:91.69ms +step:300/1645 train_time:27508ms step_avg:91.69ms +step:301/1645 train_time:27602ms step_avg:91.70ms +step:302/1645 
train_time:27693ms step_avg:91.70ms +step:303/1645 train_time:27784ms step_avg:91.70ms +step:304/1645 train_time:27875ms step_avg:91.70ms +step:305/1645 train_time:27967ms step_avg:91.69ms +step:306/1645 train_time:28059ms step_avg:91.69ms +step:307/1645 train_time:28149ms step_avg:91.69ms +step:308/1645 train_time:28241ms step_avg:91.69ms +step:309/1645 train_time:28333ms step_avg:91.69ms +step:310/1645 train_time:28424ms step_avg:91.69ms +step:311/1645 train_time:28515ms step_avg:91.69ms +step:312/1645 train_time:28607ms step_avg:91.69ms +step:313/1645 train_time:28700ms step_avg:91.69ms +step:314/1645 train_time:28792ms step_avg:91.69ms +step:315/1645 train_time:28882ms step_avg:91.69ms +step:316/1645 train_time:28974ms step_avg:91.69ms +step:317/1645 train_time:29066ms step_avg:91.69ms +step:318/1645 train_time:29157ms step_avg:91.69ms +step:319/1645 train_time:29249ms step_avg:91.69ms +step:320/1645 train_time:29340ms step_avg:91.69ms +step:321/1645 train_time:29431ms step_avg:91.69ms +step:322/1645 train_time:29523ms step_avg:91.69ms +step:323/1645 train_time:29614ms step_avg:91.69ms +step:324/1645 train_time:29706ms step_avg:91.68ms +step:325/1645 train_time:29799ms step_avg:91.69ms +step:326/1645 train_time:29889ms step_avg:91.68ms +step:327/1645 train_time:29982ms step_avg:91.69ms +step:328/1645 train_time:30074ms step_avg:91.69ms +step:329/1645 train_time:30165ms step_avg:91.69ms +step:330/1645 train_time:30256ms step_avg:91.68ms +step:331/1645 train_time:30348ms step_avg:91.69ms +step:332/1645 train_time:30439ms step_avg:91.68ms +step:333/1645 train_time:30531ms step_avg:91.68ms +step:334/1645 train_time:30623ms step_avg:91.68ms +step:335/1645 train_time:30714ms step_avg:91.68ms +step:336/1645 train_time:30806ms step_avg:91.68ms +step:337/1645 train_time:30898ms step_avg:91.69ms +step:338/1645 train_time:30989ms step_avg:91.68ms +step:339/1645 train_time:31081ms step_avg:91.68ms +step:340/1645 train_time:31173ms step_avg:91.68ms +step:341/1645 train_time:31264ms step_avg:91.68ms +step:342/1645 train_time:31356ms step_avg:91.68ms +step:343/1645 train_time:31447ms step_avg:91.68ms +step:344/1645 train_time:31538ms step_avg:91.68ms +step:345/1645 train_time:31629ms step_avg:91.68ms +step:346/1645 train_time:31721ms step_avg:91.68ms +step:347/1645 train_time:31814ms step_avg:91.68ms +step:348/1645 train_time:31905ms step_avg:91.68ms +step:349/1645 train_time:31997ms step_avg:91.68ms +step:350/1645 train_time:32088ms step_avg:91.68ms +step:351/1645 train_time:32181ms step_avg:91.68ms +step:352/1645 train_time:32272ms step_avg:91.68ms +step:353/1645 train_time:32364ms step_avg:91.68ms +step:354/1645 train_time:32456ms step_avg:91.68ms +step:355/1645 train_time:32547ms step_avg:91.68ms +step:356/1645 train_time:32638ms step_avg:91.68ms +step:357/1645 train_time:32729ms step_avg:91.68ms +step:358/1645 train_time:32821ms step_avg:91.68ms +step:359/1645 train_time:32913ms step_avg:91.68ms +step:360/1645 train_time:33005ms step_avg:91.68ms +step:361/1645 train_time:33097ms step_avg:91.68ms +step:362/1645 train_time:33189ms step_avg:91.68ms +step:363/1645 train_time:33282ms step_avg:91.69ms +step:364/1645 train_time:33373ms step_avg:91.68ms +step:365/1645 train_time:33464ms step_avg:91.68ms +step:366/1645 train_time:33556ms step_avg:91.68ms +step:367/1645 train_time:33647ms step_avg:91.68ms +step:368/1645 train_time:33738ms step_avg:91.68ms +step:369/1645 train_time:33829ms step_avg:91.68ms +step:370/1645 train_time:33921ms step_avg:91.68ms +step:371/1645 train_time:34013ms step_avg:91.68ms 
+step:372/1645 train_time:34105ms step_avg:91.68ms +step:373/1645 train_time:34196ms step_avg:91.68ms +step:374/1645 train_time:34288ms step_avg:91.68ms +step:375/1645 train_time:34381ms step_avg:91.68ms +step:375/1645 val_loss:3.8144 train_time:34473ms step_avg:91.93ms +step:376/1645 train_time:34489ms step_avg:91.72ms +step:377/1645 train_time:34569ms step_avg:91.69ms +step:378/1645 train_time:34664ms step_avg:91.70ms +step:379/1645 train_time:34756ms step_avg:91.71ms +step:380/1645 train_time:34848ms step_avg:91.70ms +step:381/1645 train_time:34938ms step_avg:91.70ms +step:382/1645 train_time:35029ms step_avg:91.70ms +step:383/1645 train_time:35120ms step_avg:91.70ms +step:384/1645 train_time:35211ms step_avg:91.70ms +step:385/1645 train_time:35302ms step_avg:91.69ms +step:386/1645 train_time:35392ms step_avg:91.69ms +step:387/1645 train_time:35485ms step_avg:91.69ms +step:388/1645 train_time:35578ms step_avg:91.70ms +step:389/1645 train_time:35670ms step_avg:91.70ms +step:390/1645 train_time:35762ms step_avg:91.70ms +step:391/1645 train_time:35854ms step_avg:91.70ms +step:392/1645 train_time:35945ms step_avg:91.70ms +step:393/1645 train_time:36036ms step_avg:91.70ms +step:394/1645 train_time:36128ms step_avg:91.69ms +step:395/1645 train_time:36219ms step_avg:91.69ms +step:396/1645 train_time:36310ms step_avg:91.69ms +step:397/1645 train_time:36402ms step_avg:91.69ms +step:398/1645 train_time:36494ms step_avg:91.69ms +step:399/1645 train_time:36586ms step_avg:91.70ms +step:400/1645 train_time:36678ms step_avg:91.70ms +step:401/1645 train_time:36770ms step_avg:91.69ms +step:402/1645 train_time:36863ms step_avg:91.70ms +step:403/1645 train_time:36955ms step_avg:91.70ms +step:404/1645 train_time:37047ms step_avg:91.70ms +step:405/1645 train_time:37138ms step_avg:91.70ms +step:406/1645 train_time:37229ms step_avg:91.70ms +step:407/1645 train_time:37321ms step_avg:91.70ms +step:408/1645 train_time:37412ms step_avg:91.70ms +step:409/1645 train_time:37503ms step_avg:91.69ms +step:410/1645 train_time:37595ms step_avg:91.69ms +step:411/1645 train_time:37687ms step_avg:91.69ms +step:412/1645 train_time:37778ms step_avg:91.69ms +step:413/1645 train_time:37869ms step_avg:91.69ms +step:414/1645 train_time:37961ms step_avg:91.69ms +step:415/1645 train_time:38053ms step_avg:91.69ms +step:416/1645 train_time:38144ms step_avg:91.69ms +step:417/1645 train_time:38236ms step_avg:91.69ms +step:418/1645 train_time:38327ms step_avg:91.69ms +step:419/1645 train_time:38419ms step_avg:91.69ms +step:420/1645 train_time:38510ms step_avg:91.69ms +step:421/1645 train_time:38601ms step_avg:91.69ms +step:422/1645 train_time:38693ms step_avg:91.69ms +step:423/1645 train_time:38785ms step_avg:91.69ms +step:424/1645 train_time:38876ms step_avg:91.69ms +step:425/1645 train_time:38969ms step_avg:91.69ms +step:426/1645 train_time:39060ms step_avg:91.69ms +step:427/1645 train_time:39152ms step_avg:91.69ms +step:428/1645 train_time:39243ms step_avg:91.69ms +step:429/1645 train_time:39334ms step_avg:91.69ms +step:430/1645 train_time:39425ms step_avg:91.69ms +step:431/1645 train_time:39517ms step_avg:91.69ms +step:432/1645 train_time:39608ms step_avg:91.69ms +step:433/1645 train_time:39701ms step_avg:91.69ms +step:434/1645 train_time:39792ms step_avg:91.69ms +step:435/1645 train_time:39884ms step_avg:91.69ms +step:436/1645 train_time:39975ms step_avg:91.69ms +step:437/1645 train_time:40067ms step_avg:91.69ms +step:438/1645 train_time:40158ms step_avg:91.68ms +step:439/1645 train_time:40249ms step_avg:91.68ms +step:440/1645 
train_time:40341ms step_avg:91.68ms +step:441/1645 train_time:40433ms step_avg:91.68ms +step:442/1645 train_time:40524ms step_avg:91.68ms +step:443/1645 train_time:40615ms step_avg:91.68ms +step:444/1645 train_time:40707ms step_avg:91.68ms +step:445/1645 train_time:40798ms step_avg:91.68ms +step:446/1645 train_time:40889ms step_avg:91.68ms +step:447/1645 train_time:40982ms step_avg:91.68ms +step:448/1645 train_time:41073ms step_avg:91.68ms +step:449/1645 train_time:41165ms step_avg:91.68ms +step:450/1645 train_time:41256ms step_avg:91.68ms +step:451/1645 train_time:41348ms step_avg:91.68ms +step:452/1645 train_time:41440ms step_avg:91.68ms +step:453/1645 train_time:41532ms step_avg:91.68ms +step:454/1645 train_time:41623ms step_avg:91.68ms +step:455/1645 train_time:41715ms step_avg:91.68ms +step:456/1645 train_time:41807ms step_avg:91.68ms +step:457/1645 train_time:41898ms step_avg:91.68ms +step:458/1645 train_time:41991ms step_avg:91.68ms +step:459/1645 train_time:42082ms step_avg:91.68ms +step:460/1645 train_time:42173ms step_avg:91.68ms +step:461/1645 train_time:42265ms step_avg:91.68ms +step:462/1645 train_time:42356ms step_avg:91.68ms +step:463/1645 train_time:42448ms step_avg:91.68ms +step:464/1645 train_time:42540ms step_avg:91.68ms +step:465/1645 train_time:42631ms step_avg:91.68ms +step:466/1645 train_time:42723ms step_avg:91.68ms +step:467/1645 train_time:42816ms step_avg:91.68ms +step:468/1645 train_time:42906ms step_avg:91.68ms +step:469/1645 train_time:42997ms step_avg:91.68ms +step:470/1645 train_time:43088ms step_avg:91.68ms +step:471/1645 train_time:43180ms step_avg:91.68ms +step:472/1645 train_time:43271ms step_avg:91.68ms +step:473/1645 train_time:43363ms step_avg:91.68ms +step:474/1645 train_time:43456ms step_avg:91.68ms +step:475/1645 train_time:43548ms step_avg:91.68ms +step:476/1645 train_time:43641ms step_avg:91.68ms +step:477/1645 train_time:43732ms step_avg:91.68ms +step:478/1645 train_time:43824ms step_avg:91.68ms +step:479/1645 train_time:43915ms step_avg:91.68ms +step:480/1645 train_time:44006ms step_avg:91.68ms +step:481/1645 train_time:44098ms step_avg:91.68ms +step:482/1645 train_time:44189ms step_avg:91.68ms +step:483/1645 train_time:44281ms step_avg:91.68ms +step:484/1645 train_time:44373ms step_avg:91.68ms +step:485/1645 train_time:44464ms step_avg:91.68ms +step:486/1645 train_time:44558ms step_avg:91.68ms +step:487/1645 train_time:44649ms step_avg:91.68ms +step:488/1645 train_time:44741ms step_avg:91.68ms +step:489/1645 train_time:44832ms step_avg:91.68ms +step:490/1645 train_time:44924ms step_avg:91.68ms +step:491/1645 train_time:45015ms step_avg:91.68ms +step:492/1645 train_time:45106ms step_avg:91.68ms +step:493/1645 train_time:45199ms step_avg:91.68ms +step:494/1645 train_time:45290ms step_avg:91.68ms +step:495/1645 train_time:45382ms step_avg:91.68ms +step:496/1645 train_time:45473ms step_avg:91.68ms +step:497/1645 train_time:45565ms step_avg:91.68ms +step:498/1645 train_time:45656ms step_avg:91.68ms +step:499/1645 train_time:45748ms step_avg:91.68ms +step:500/1645 train_time:45840ms step_avg:91.68ms +step:500/1645 val_loss:3.7121 train_time:45932ms step_avg:91.86ms +step:501/1645 train_time:45947ms step_avg:91.71ms +step:502/1645 train_time:46028ms step_avg:91.69ms +step:503/1645 train_time:46121ms step_avg:91.69ms +step:504/1645 train_time:46214ms step_avg:91.69ms +step:505/1645 train_time:46304ms step_avg:91.69ms +step:506/1645 train_time:46395ms step_avg:91.69ms +step:507/1645 train_time:46486ms step_avg:91.69ms +step:508/1645 train_time:46577ms 
step_avg:91.69ms +step:509/1645 train_time:46667ms step_avg:91.68ms +step:510/1645 train_time:46758ms step_avg:91.68ms +step:511/1645 train_time:46849ms step_avg:91.68ms +step:512/1645 train_time:46942ms step_avg:91.68ms +step:513/1645 train_time:47035ms step_avg:91.69ms +step:514/1645 train_time:47127ms step_avg:91.69ms +step:515/1645 train_time:47219ms step_avg:91.69ms +step:516/1645 train_time:47310ms step_avg:91.69ms +step:517/1645 train_time:47402ms step_avg:91.69ms +step:518/1645 train_time:47493ms step_avg:91.69ms +step:519/1645 train_time:47585ms step_avg:91.69ms +step:520/1645 train_time:47675ms step_avg:91.68ms +step:521/1645 train_time:47766ms step_avg:91.68ms +step:522/1645 train_time:47857ms step_avg:91.68ms +step:523/1645 train_time:47949ms step_avg:91.68ms +step:524/1645 train_time:48041ms step_avg:91.68ms +step:525/1645 train_time:48134ms step_avg:91.68ms +step:526/1645 train_time:48226ms step_avg:91.68ms +step:527/1645 train_time:48317ms step_avg:91.68ms +step:528/1645 train_time:48409ms step_avg:91.68ms +step:529/1645 train_time:48501ms step_avg:91.68ms +step:530/1645 train_time:48592ms step_avg:91.68ms +step:531/1645 train_time:48683ms step_avg:91.68ms +step:532/1645 train_time:48774ms step_avg:91.68ms +step:533/1645 train_time:48865ms step_avg:91.68ms +step:534/1645 train_time:48957ms step_avg:91.68ms +step:535/1645 train_time:49048ms step_avg:91.68ms +step:536/1645 train_time:49140ms step_avg:91.68ms +step:537/1645 train_time:49233ms step_avg:91.68ms +step:538/1645 train_time:49324ms step_avg:91.68ms +step:539/1645 train_time:49416ms step_avg:91.68ms +step:540/1645 train_time:49508ms step_avg:91.68ms +step:541/1645 train_time:49599ms step_avg:91.68ms +step:542/1645 train_time:49691ms step_avg:91.68ms +step:543/1645 train_time:49782ms step_avg:91.68ms +step:544/1645 train_time:49874ms step_avg:91.68ms +step:545/1645 train_time:49965ms step_avg:91.68ms +step:546/1645 train_time:50057ms step_avg:91.68ms +step:547/1645 train_time:50148ms step_avg:91.68ms +step:548/1645 train_time:50239ms step_avg:91.68ms +step:549/1645 train_time:50331ms step_avg:91.68ms +step:550/1645 train_time:50423ms step_avg:91.68ms +step:551/1645 train_time:50517ms step_avg:91.68ms +step:552/1645 train_time:50610ms step_avg:91.69ms +step:553/1645 train_time:50704ms step_avg:91.69ms +step:554/1645 train_time:50796ms step_avg:91.69ms +step:555/1645 train_time:50889ms step_avg:91.69ms +step:556/1645 train_time:50982ms step_avg:91.69ms +step:557/1645 train_time:51076ms step_avg:91.70ms +step:558/1645 train_time:51170ms step_avg:91.70ms +step:559/1645 train_time:51262ms step_avg:91.70ms +step:560/1645 train_time:51355ms step_avg:91.71ms +step:561/1645 train_time:51448ms step_avg:91.71ms +step:562/1645 train_time:51540ms step_avg:91.71ms +step:563/1645 train_time:51634ms step_avg:91.71ms +step:564/1645 train_time:51726ms step_avg:91.71ms +step:565/1645 train_time:51819ms step_avg:91.71ms +step:566/1645 train_time:51912ms step_avg:91.72ms +step:567/1645 train_time:52007ms step_avg:91.72ms +step:568/1645 train_time:52099ms step_avg:91.72ms +step:569/1645 train_time:52191ms step_avg:91.72ms +step:570/1645 train_time:52284ms step_avg:91.73ms +step:571/1645 train_time:52376ms step_avg:91.73ms +step:572/1645 train_time:52470ms step_avg:91.73ms +step:573/1645 train_time:52563ms step_avg:91.73ms +step:574/1645 train_time:52656ms step_avg:91.74ms +step:575/1645 train_time:52749ms step_avg:91.74ms +step:576/1645 train_time:52841ms step_avg:91.74ms +step:577/1645 train_time:52935ms step_avg:91.74ms +step:578/1645 
train_time:53028ms step_avg:91.74ms +step:579/1645 train_time:53120ms step_avg:91.74ms +step:580/1645 train_time:53215ms step_avg:91.75ms +step:581/1645 train_time:53308ms step_avg:91.75ms +step:582/1645 train_time:53400ms step_avg:91.75ms +step:583/1645 train_time:53493ms step_avg:91.75ms +step:584/1645 train_time:53586ms step_avg:91.76ms +step:585/1645 train_time:53679ms step_avg:91.76ms +step:586/1645 train_time:53772ms step_avg:91.76ms +step:587/1645 train_time:53865ms step_avg:91.76ms +step:588/1645 train_time:53958ms step_avg:91.77ms +step:589/1645 train_time:54051ms step_avg:91.77ms +step:590/1645 train_time:54144ms step_avg:91.77ms +step:591/1645 train_time:54237ms step_avg:91.77ms +step:592/1645 train_time:54330ms step_avg:91.77ms +step:593/1645 train_time:54423ms step_avg:91.78ms +step:594/1645 train_time:54518ms step_avg:91.78ms +step:595/1645 train_time:54611ms step_avg:91.78ms +step:596/1645 train_time:54703ms step_avg:91.78ms +step:597/1645 train_time:54796ms step_avg:91.79ms +step:598/1645 train_time:54889ms step_avg:91.79ms +step:599/1645 train_time:54982ms step_avg:91.79ms +step:600/1645 train_time:55075ms step_avg:91.79ms +step:601/1645 train_time:55168ms step_avg:91.79ms +step:602/1645 train_time:55260ms step_avg:91.79ms +step:603/1645 train_time:55353ms step_avg:91.80ms +step:604/1645 train_time:55445ms step_avg:91.80ms +step:605/1645 train_time:55538ms step_avg:91.80ms +step:606/1645 train_time:55631ms step_avg:91.80ms +step:607/1645 train_time:55724ms step_avg:91.80ms +step:608/1645 train_time:55818ms step_avg:91.81ms +step:609/1645 train_time:55911ms step_avg:91.81ms +step:610/1645 train_time:56004ms step_avg:91.81ms +step:611/1645 train_time:56096ms step_avg:91.81ms +step:612/1645 train_time:56188ms step_avg:91.81ms +step:613/1645 train_time:56281ms step_avg:91.81ms +step:614/1645 train_time:56374ms step_avg:91.81ms +step:615/1645 train_time:56468ms step_avg:91.82ms +step:616/1645 train_time:56560ms step_avg:91.82ms +step:617/1645 train_time:56653ms step_avg:91.82ms +step:618/1645 train_time:56745ms step_avg:91.82ms +step:619/1645 train_time:56838ms step_avg:91.82ms +step:620/1645 train_time:56932ms step_avg:91.83ms +step:621/1645 train_time:57024ms step_avg:91.83ms +step:622/1645 train_time:57117ms step_avg:91.83ms +step:623/1645 train_time:57210ms step_avg:91.83ms +step:624/1645 train_time:57303ms step_avg:91.83ms +step:625/1645 train_time:57396ms step_avg:91.83ms +step:625/1645 val_loss:3.6105 train_time:57489ms step_avg:91.98ms +step:626/1645 train_time:57512ms step_avg:91.87ms +step:627/1645 train_time:57586ms step_avg:91.84ms +step:628/1645 train_time:57685ms step_avg:91.86ms +step:629/1645 train_time:57778ms step_avg:91.86ms +step:630/1645 train_time:57870ms step_avg:91.86ms +step:631/1645 train_time:57961ms step_avg:91.86ms +step:632/1645 train_time:58052ms step_avg:91.85ms +step:633/1645 train_time:58144ms step_avg:91.85ms +step:634/1645 train_time:58235ms step_avg:91.85ms +step:635/1645 train_time:58327ms step_avg:91.85ms +step:636/1645 train_time:58424ms step_avg:91.86ms +step:637/1645 train_time:58522ms step_avg:91.87ms +step:638/1645 train_time:58617ms step_avg:91.88ms +step:639/1645 train_time:58710ms step_avg:91.88ms +step:640/1645 train_time:58803ms step_avg:91.88ms +step:641/1645 train_time:58895ms step_avg:91.88ms +step:642/1645 train_time:58988ms step_avg:91.88ms +step:643/1645 train_time:59080ms step_avg:91.88ms +step:644/1645 train_time:59172ms step_avg:91.88ms +step:645/1645 train_time:59263ms step_avg:91.88ms +step:646/1645 train_time:59356ms 
step_avg:91.88ms +step:647/1645 train_time:59450ms step_avg:91.89ms +step:648/1645 train_time:59544ms step_avg:91.89ms +step:649/1645 train_time:59637ms step_avg:91.89ms +step:650/1645 train_time:59730ms step_avg:91.89ms +step:651/1645 train_time:59823ms step_avg:91.89ms +step:652/1645 train_time:59917ms step_avg:91.90ms +step:653/1645 train_time:60009ms step_avg:91.90ms +step:654/1645 train_time:60102ms step_avg:91.90ms +step:655/1645 train_time:60194ms step_avg:91.90ms +step:656/1645 train_time:60285ms step_avg:91.90ms +step:657/1645 train_time:60379ms step_avg:91.90ms +step:658/1645 train_time:60472ms step_avg:91.90ms +step:659/1645 train_time:60565ms step_avg:91.90ms +step:660/1645 train_time:60659ms step_avg:91.91ms +step:661/1645 train_time:60752ms step_avg:91.91ms +step:662/1645 train_time:60847ms step_avg:91.91ms +step:663/1645 train_time:60941ms step_avg:91.92ms +step:664/1645 train_time:61034ms step_avg:91.92ms +step:665/1645 train_time:61126ms step_avg:91.92ms +step:666/1645 train_time:61218ms step_avg:91.92ms +step:667/1645 train_time:61311ms step_avg:91.92ms +step:668/1645 train_time:61404ms step_avg:91.92ms +step:669/1645 train_time:61497ms step_avg:91.92ms +step:670/1645 train_time:61590ms step_avg:91.93ms +step:671/1645 train_time:61684ms step_avg:91.93ms +step:672/1645 train_time:61777ms step_avg:91.93ms +step:673/1645 train_time:61870ms step_avg:91.93ms +step:674/1645 train_time:61963ms step_avg:91.93ms +step:675/1645 train_time:62056ms step_avg:91.93ms +step:676/1645 train_time:62149ms step_avg:91.94ms +step:677/1645 train_time:62241ms step_avg:91.94ms +step:678/1645 train_time:62334ms step_avg:91.94ms +step:679/1645 train_time:62427ms step_avg:91.94ms +step:680/1645 train_time:62519ms step_avg:91.94ms +step:681/1645 train_time:62612ms step_avg:91.94ms +step:682/1645 train_time:62705ms step_avg:91.94ms +step:683/1645 train_time:62797ms step_avg:91.94ms +step:684/1645 train_time:62891ms step_avg:91.95ms +step:685/1645 train_time:62984ms step_avg:91.95ms +step:686/1645 train_time:63077ms step_avg:91.95ms +step:687/1645 train_time:63169ms step_avg:91.95ms +step:688/1645 train_time:63261ms step_avg:91.95ms +step:689/1645 train_time:63354ms step_avg:91.95ms +step:690/1645 train_time:63447ms step_avg:91.95ms +step:691/1645 train_time:63540ms step_avg:91.95ms +step:692/1645 train_time:63633ms step_avg:91.96ms +step:693/1645 train_time:63726ms step_avg:91.96ms +step:694/1645 train_time:63819ms step_avg:91.96ms +step:695/1645 train_time:63912ms step_avg:91.96ms +step:696/1645 train_time:64006ms step_avg:91.96ms +step:697/1645 train_time:64098ms step_avg:91.96ms +step:698/1645 train_time:64190ms step_avg:91.96ms +step:699/1645 train_time:64283ms step_avg:91.96ms +step:700/1645 train_time:64376ms step_avg:91.97ms +step:701/1645 train_time:64468ms step_avg:91.97ms +step:702/1645 train_time:64562ms step_avg:91.97ms +step:703/1645 train_time:64655ms step_avg:91.97ms +step:704/1645 train_time:64748ms step_avg:91.97ms +step:705/1645 train_time:64842ms step_avg:91.97ms +step:706/1645 train_time:64935ms step_avg:91.98ms +step:707/1645 train_time:65028ms step_avg:91.98ms +step:708/1645 train_time:65121ms step_avg:91.98ms +step:709/1645 train_time:65214ms step_avg:91.98ms +step:710/1645 train_time:65306ms step_avg:91.98ms +step:711/1645 train_time:65399ms step_avg:91.98ms +step:712/1645 train_time:65492ms step_avg:91.98ms +step:713/1645 train_time:65585ms step_avg:91.98ms +step:714/1645 train_time:65679ms step_avg:91.99ms +step:715/1645 train_time:65771ms step_avg:91.99ms +step:716/1645 
train_time:65864ms step_avg:91.99ms +step:717/1645 train_time:65958ms step_avg:91.99ms +step:718/1645 train_time:66050ms step_avg:91.99ms +step:719/1645 train_time:66143ms step_avg:91.99ms +step:720/1645 train_time:66236ms step_avg:91.99ms +step:721/1645 train_time:66328ms step_avg:92.00ms +step:722/1645 train_time:66422ms step_avg:92.00ms +step:723/1645 train_time:66516ms step_avg:92.00ms +step:724/1645 train_time:66608ms step_avg:92.00ms +step:725/1645 train_time:66700ms step_avg:92.00ms +step:726/1645 train_time:66792ms step_avg:92.00ms +step:727/1645 train_time:66886ms step_avg:92.00ms +step:728/1645 train_time:66979ms step_avg:92.00ms +step:729/1645 train_time:67072ms step_avg:92.01ms +step:730/1645 train_time:67164ms step_avg:92.01ms +step:731/1645 train_time:67258ms step_avg:92.01ms +step:732/1645 train_time:67351ms step_avg:92.01ms +step:733/1645 train_time:67444ms step_avg:92.01ms +step:734/1645 train_time:67538ms step_avg:92.01ms +step:735/1645 train_time:67630ms step_avg:92.01ms +step:736/1645 train_time:67724ms step_avg:92.02ms +step:737/1645 train_time:67816ms step_avg:92.02ms +step:738/1645 train_time:67909ms step_avg:92.02ms +step:739/1645 train_time:68002ms step_avg:92.02ms +step:740/1645 train_time:68095ms step_avg:92.02ms +step:741/1645 train_time:68188ms step_avg:92.02ms +step:742/1645 train_time:68281ms step_avg:92.02ms +step:743/1645 train_time:68375ms step_avg:92.03ms +step:744/1645 train_time:68468ms step_avg:92.03ms +step:745/1645 train_time:68561ms step_avg:92.03ms +step:746/1645 train_time:68654ms step_avg:92.03ms +step:747/1645 train_time:68747ms step_avg:92.03ms +step:748/1645 train_time:68841ms step_avg:92.03ms +step:749/1645 train_time:68935ms step_avg:92.04ms +step:750/1645 train_time:69029ms step_avg:92.04ms +step:750/1645 val_loss:3.5580 train_time:69120ms step_avg:92.16ms +step:751/1645 train_time:69143ms step_avg:92.07ms +step:752/1645 train_time:69216ms step_avg:92.04ms +step:753/1645 train_time:69309ms step_avg:92.04ms +step:754/1645 train_time:69401ms step_avg:92.04ms +step:755/1645 train_time:69494ms step_avg:92.05ms +step:756/1645 train_time:69587ms step_avg:92.05ms +step:757/1645 train_time:69679ms step_avg:92.05ms +step:758/1645 train_time:69772ms step_avg:92.05ms +step:759/1645 train_time:69864ms step_avg:92.05ms +step:760/1645 train_time:69957ms step_avg:92.05ms +step:761/1645 train_time:70051ms step_avg:92.05ms +step:762/1645 train_time:70145ms step_avg:92.05ms +step:763/1645 train_time:70239ms step_avg:92.06ms +step:764/1645 train_time:70332ms step_avg:92.06ms +step:765/1645 train_time:70425ms step_avg:92.06ms +step:766/1645 train_time:70518ms step_avg:92.06ms +step:767/1645 train_time:70610ms step_avg:92.06ms +step:768/1645 train_time:70702ms step_avg:92.06ms +step:769/1645 train_time:70795ms step_avg:92.06ms +step:770/1645 train_time:70887ms step_avg:92.06ms +step:771/1645 train_time:70980ms step_avg:92.06ms +step:772/1645 train_time:71075ms step_avg:92.07ms +step:773/1645 train_time:71168ms step_avg:92.07ms +step:774/1645 train_time:71260ms step_avg:92.07ms +step:775/1645 train_time:71354ms step_avg:92.07ms +step:776/1645 train_time:71446ms step_avg:92.07ms +step:777/1645 train_time:71539ms step_avg:92.07ms +step:778/1645 train_time:71630ms step_avg:92.07ms +step:779/1645 train_time:71723ms step_avg:92.07ms +step:780/1645 train_time:71815ms step_avg:92.07ms +step:781/1645 train_time:71907ms step_avg:92.07ms +step:782/1645 train_time:72001ms step_avg:92.07ms +step:783/1645 train_time:72094ms step_avg:92.07ms +step:784/1645 train_time:72187ms 
step_avg:92.08ms +step:785/1645 train_time:72281ms step_avg:92.08ms +step:786/1645 train_time:72374ms step_avg:92.08ms +step:787/1645 train_time:72466ms step_avg:92.08ms +step:788/1645 train_time:72558ms step_avg:92.08ms +step:789/1645 train_time:72651ms step_avg:92.08ms +step:790/1645 train_time:72744ms step_avg:92.08ms +step:791/1645 train_time:72837ms step_avg:92.08ms +step:792/1645 train_time:72929ms step_avg:92.08ms +step:793/1645 train_time:73022ms step_avg:92.08ms +step:794/1645 train_time:73115ms step_avg:92.08ms +step:795/1645 train_time:73208ms step_avg:92.09ms +step:796/1645 train_time:73301ms step_avg:92.09ms +step:797/1645 train_time:73394ms step_avg:92.09ms +step:798/1645 train_time:73488ms step_avg:92.09ms +step:799/1645 train_time:73582ms step_avg:92.09ms +step:800/1645 train_time:73675ms step_avg:92.09ms +step:801/1645 train_time:73768ms step_avg:92.09ms +step:802/1645 train_time:73860ms step_avg:92.09ms +step:803/1645 train_time:73955ms step_avg:92.10ms +step:804/1645 train_time:74047ms step_avg:92.10ms +step:805/1645 train_time:74139ms step_avg:92.10ms +step:806/1645 train_time:74232ms step_avg:92.10ms +step:807/1645 train_time:74325ms step_avg:92.10ms +step:808/1645 train_time:74418ms step_avg:92.10ms +step:809/1645 train_time:74512ms step_avg:92.10ms +step:810/1645 train_time:74604ms step_avg:92.10ms +step:811/1645 train_time:74698ms step_avg:92.11ms +step:812/1645 train_time:74792ms step_avg:92.11ms +step:813/1645 train_time:74884ms step_avg:92.11ms +step:814/1645 train_time:74978ms step_avg:92.11ms +step:815/1645 train_time:75070ms step_avg:92.11ms +step:816/1645 train_time:75163ms step_avg:92.11ms +step:817/1645 train_time:75256ms step_avg:92.11ms +step:818/1645 train_time:75349ms step_avg:92.11ms +step:819/1645 train_time:75442ms step_avg:92.12ms +step:820/1645 train_time:75535ms step_avg:92.12ms +step:821/1645 train_time:75628ms step_avg:92.12ms +step:822/1645 train_time:75721ms step_avg:92.12ms +step:823/1645 train_time:75814ms step_avg:92.12ms +step:824/1645 train_time:75907ms step_avg:92.12ms +step:825/1645 train_time:75999ms step_avg:92.12ms +step:826/1645 train_time:76092ms step_avg:92.12ms +step:827/1645 train_time:76185ms step_avg:92.12ms +step:828/1645 train_time:76278ms step_avg:92.12ms +step:829/1645 train_time:76371ms step_avg:92.12ms +step:830/1645 train_time:76464ms step_avg:92.13ms +step:831/1645 train_time:76557ms step_avg:92.13ms +step:832/1645 train_time:76650ms step_avg:92.13ms +step:833/1645 train_time:76743ms step_avg:92.13ms +step:834/1645 train_time:76836ms step_avg:92.13ms +step:835/1645 train_time:76929ms step_avg:92.13ms +step:836/1645 train_time:77021ms step_avg:92.13ms +step:837/1645 train_time:77115ms step_avg:92.13ms +step:838/1645 train_time:77207ms step_avg:92.13ms +step:839/1645 train_time:77300ms step_avg:92.13ms +step:840/1645 train_time:77394ms step_avg:92.14ms +step:841/1645 train_time:77487ms step_avg:92.14ms +step:842/1645 train_time:77580ms step_avg:92.14ms +step:843/1645 train_time:77673ms step_avg:92.14ms +step:844/1645 train_time:77766ms step_avg:92.14ms +step:845/1645 train_time:77858ms step_avg:92.14ms +step:846/1645 train_time:77952ms step_avg:92.14ms +step:847/1645 train_time:78044ms step_avg:92.14ms +step:848/1645 train_time:78137ms step_avg:92.14ms +step:849/1645 train_time:78230ms step_avg:92.14ms +step:850/1645 train_time:78322ms step_avg:92.14ms +step:851/1645 train_time:78417ms step_avg:92.15ms +step:852/1645 train_time:78511ms step_avg:92.15ms +step:853/1645 train_time:78604ms step_avg:92.15ms +step:854/1645 
train_time:78697ms step_avg:92.15ms +step:855/1645 train_time:78790ms step_avg:92.15ms +step:856/1645 train_time:78883ms step_avg:92.15ms +step:857/1645 train_time:78977ms step_avg:92.15ms +step:858/1645 train_time:79069ms step_avg:92.15ms +step:859/1645 train_time:79161ms step_avg:92.16ms +step:860/1645 train_time:79255ms step_avg:92.16ms +step:861/1645 train_time:79348ms step_avg:92.16ms +step:862/1645 train_time:79441ms step_avg:92.16ms +step:863/1645 train_time:79534ms step_avg:92.16ms +step:864/1645 train_time:79626ms step_avg:92.16ms +step:865/1645 train_time:79719ms step_avg:92.16ms +step:866/1645 train_time:79812ms step_avg:92.16ms +step:867/1645 train_time:79904ms step_avg:92.16ms +step:868/1645 train_time:79997ms step_avg:92.16ms +step:869/1645 train_time:80091ms step_avg:92.16ms +step:870/1645 train_time:80184ms step_avg:92.17ms +step:871/1645 train_time:80277ms step_avg:92.17ms +step:872/1645 train_time:80369ms step_avg:92.17ms +step:873/1645 train_time:80462ms step_avg:92.17ms +step:874/1645 train_time:80555ms step_avg:92.17ms +step:875/1645 train_time:80649ms step_avg:92.17ms +step:875/1645 val_loss:3.5132 train_time:80742ms step_avg:92.28ms +step:876/1645 train_time:80763ms step_avg:92.20ms +step:877/1645 train_time:80839ms step_avg:92.18ms +step:878/1645 train_time:80937ms step_avg:92.18ms +step:879/1645 train_time:81030ms step_avg:92.18ms +step:880/1645 train_time:81122ms step_avg:92.18ms +step:881/1645 train_time:81214ms step_avg:92.18ms +step:882/1645 train_time:81305ms step_avg:92.18ms +step:883/1645 train_time:81399ms step_avg:92.18ms +step:884/1645 train_time:81492ms step_avg:92.18ms +step:885/1645 train_time:81584ms step_avg:92.18ms +step:886/1645 train_time:81676ms step_avg:92.19ms +step:887/1645 train_time:81770ms step_avg:92.19ms +step:888/1645 train_time:81865ms step_avg:92.19ms +step:889/1645 train_time:81960ms step_avg:92.19ms +step:890/1645 train_time:82053ms step_avg:92.19ms +step:891/1645 train_time:82145ms step_avg:92.19ms +step:892/1645 train_time:82238ms step_avg:92.20ms +step:893/1645 train_time:82331ms step_avg:92.20ms +step:894/1645 train_time:82423ms step_avg:92.20ms +step:895/1645 train_time:82516ms step_avg:92.20ms +step:896/1645 train_time:82608ms step_avg:92.20ms +step:897/1645 train_time:82702ms step_avg:92.20ms +step:898/1645 train_time:82795ms step_avg:92.20ms +step:899/1645 train_time:82889ms step_avg:92.20ms +step:900/1645 train_time:82983ms step_avg:92.20ms +step:901/1645 train_time:83077ms step_avg:92.21ms +step:902/1645 train_time:83170ms step_avg:92.21ms +step:903/1645 train_time:83263ms step_avg:92.21ms +step:904/1645 train_time:83355ms step_avg:92.21ms +step:905/1645 train_time:83448ms step_avg:92.21ms +step:906/1645 train_time:83540ms step_avg:92.21ms +step:907/1645 train_time:83633ms step_avg:92.21ms +step:908/1645 train_time:83725ms step_avg:92.21ms +step:909/1645 train_time:83819ms step_avg:92.21ms +step:910/1645 train_time:83912ms step_avg:92.21ms +step:911/1645 train_time:84005ms step_avg:92.21ms +step:912/1645 train_time:84098ms step_avg:92.21ms +step:913/1645 train_time:84191ms step_avg:92.21ms +step:914/1645 train_time:84283ms step_avg:92.21ms +step:915/1645 train_time:84376ms step_avg:92.21ms +step:916/1645 train_time:84468ms step_avg:92.21ms +step:917/1645 train_time:84561ms step_avg:92.21ms +step:918/1645 train_time:84654ms step_avg:92.22ms +step:919/1645 train_time:84747ms step_avg:92.22ms +step:920/1645 train_time:84840ms step_avg:92.22ms +step:921/1645 train_time:84934ms step_avg:92.22ms +step:922/1645 train_time:85027ms 
step_avg:92.22ms +step:923/1645 train_time:85121ms step_avg:92.22ms +step:924/1645 train_time:85214ms step_avg:92.22ms +step:925/1645 train_time:85307ms step_avg:92.22ms +step:926/1645 train_time:85400ms step_avg:92.22ms +step:927/1645 train_time:85492ms step_avg:92.22ms +step:928/1645 train_time:85585ms step_avg:92.23ms +step:929/1645 train_time:85677ms step_avg:92.23ms +step:930/1645 train_time:85770ms step_avg:92.23ms +step:931/1645 train_time:85863ms step_avg:92.23ms +step:932/1645 train_time:85956ms step_avg:92.23ms +step:933/1645 train_time:86049ms step_avg:92.23ms +step:934/1645 train_time:86142ms step_avg:92.23ms +step:935/1645 train_time:86234ms step_avg:92.23ms +step:936/1645 train_time:86328ms step_avg:92.23ms +step:937/1645 train_time:86421ms step_avg:92.23ms +step:938/1645 train_time:86513ms step_avg:92.23ms +step:939/1645 train_time:86606ms step_avg:92.23ms +step:940/1645 train_time:86700ms step_avg:92.23ms +step:941/1645 train_time:86792ms step_avg:92.23ms +step:942/1645 train_time:86886ms step_avg:92.24ms +step:943/1645 train_time:86978ms step_avg:92.24ms +step:944/1645 train_time:87070ms step_avg:92.24ms +step:945/1645 train_time:87163ms step_avg:92.24ms +step:946/1645 train_time:87256ms step_avg:92.24ms +step:947/1645 train_time:87349ms step_avg:92.24ms +step:948/1645 train_time:87442ms step_avg:92.24ms +step:949/1645 train_time:87535ms step_avg:92.24ms +step:950/1645 train_time:87629ms step_avg:92.24ms +step:951/1645 train_time:87722ms step_avg:92.24ms +step:952/1645 train_time:87816ms step_avg:92.24ms +step:953/1645 train_time:87908ms step_avg:92.24ms +step:954/1645 train_time:88002ms step_avg:92.25ms +step:955/1645 train_time:88095ms step_avg:92.25ms +step:956/1645 train_time:88188ms step_avg:92.25ms +step:957/1645 train_time:88281ms step_avg:92.25ms +step:958/1645 train_time:88373ms step_avg:92.25ms +step:959/1645 train_time:88466ms step_avg:92.25ms +step:960/1645 train_time:88559ms step_avg:92.25ms +step:961/1645 train_time:88652ms step_avg:92.25ms +step:962/1645 train_time:88746ms step_avg:92.25ms +step:963/1645 train_time:88838ms step_avg:92.25ms +step:964/1645 train_time:88931ms step_avg:92.25ms +step:965/1645 train_time:89024ms step_avg:92.25ms +step:966/1645 train_time:89117ms step_avg:92.25ms +step:967/1645 train_time:89210ms step_avg:92.25ms +step:968/1645 train_time:89303ms step_avg:92.26ms +step:969/1645 train_time:89396ms step_avg:92.26ms +step:970/1645 train_time:89488ms step_avg:92.26ms +step:971/1645 train_time:89581ms step_avg:92.26ms +step:972/1645 train_time:89674ms step_avg:92.26ms +step:973/1645 train_time:89767ms step_avg:92.26ms +step:974/1645 train_time:89860ms step_avg:92.26ms +step:975/1645 train_time:89953ms step_avg:92.26ms +step:976/1645 train_time:90046ms step_avg:92.26ms +step:977/1645 train_time:90141ms step_avg:92.26ms +step:978/1645 train_time:90234ms step_avg:92.26ms +step:979/1645 train_time:90327ms step_avg:92.26ms +step:980/1645 train_time:90421ms step_avg:92.27ms +step:981/1645 train_time:90514ms step_avg:92.27ms +step:982/1645 train_time:90607ms step_avg:92.27ms +step:983/1645 train_time:90702ms step_avg:92.27ms +step:984/1645 train_time:90794ms step_avg:92.27ms +step:985/1645 train_time:90886ms step_avg:92.27ms +step:986/1645 train_time:90979ms step_avg:92.27ms +step:987/1645 train_time:91072ms step_avg:92.27ms +step:988/1645 train_time:91164ms step_avg:92.27ms +step:989/1645 train_time:91258ms step_avg:92.27ms +step:990/1645 train_time:91350ms step_avg:92.27ms +step:991/1645 train_time:91443ms step_avg:92.27ms +step:992/1645 
train_time:91536ms step_avg:92.27ms +step:993/1645 train_time:91630ms step_avg:92.28ms +step:994/1645 train_time:91722ms step_avg:92.28ms +step:995/1645 train_time:91816ms step_avg:92.28ms +step:996/1645 train_time:91909ms step_avg:92.28ms +step:997/1645 train_time:92003ms step_avg:92.28ms +step:998/1645 train_time:92095ms step_avg:92.28ms +step:999/1645 train_time:92188ms step_avg:92.28ms +step:1000/1645 train_time:92281ms step_avg:92.28ms +step:1000/1645 val_loss:3.4628 train_time:92373ms step_avg:92.37ms +step:1001/1645 train_time:92398ms step_avg:92.31ms +step:1002/1645 train_time:92470ms step_avg:92.29ms +step:1003/1645 train_time:92565ms step_avg:92.29ms +step:1004/1645 train_time:92657ms step_avg:92.29ms +step:1005/1645 train_time:92748ms step_avg:92.29ms +step:1006/1645 train_time:92840ms step_avg:92.29ms +step:1007/1645 train_time:92932ms step_avg:92.29ms +step:1008/1645 train_time:93025ms step_avg:92.29ms +step:1009/1645 train_time:93117ms step_avg:92.29ms +step:1010/1645 train_time:93211ms step_avg:92.29ms +step:1011/1645 train_time:93304ms step_avg:92.29ms +step:1012/1645 train_time:93398ms step_avg:92.29ms +step:1013/1645 train_time:93493ms step_avg:92.29ms +step:1014/1645 train_time:93586ms step_avg:92.29ms +step:1015/1645 train_time:93678ms step_avg:92.29ms +step:1016/1645 train_time:93771ms step_avg:92.29ms +step:1017/1645 train_time:93863ms step_avg:92.29ms +step:1018/1645 train_time:93956ms step_avg:92.29ms +step:1019/1645 train_time:94049ms step_avg:92.30ms +step:1020/1645 train_time:94141ms step_avg:92.30ms +step:1021/1645 train_time:94234ms step_avg:92.30ms +step:1022/1645 train_time:94328ms step_avg:92.30ms +step:1023/1645 train_time:94421ms step_avg:92.30ms +step:1024/1645 train_time:94515ms step_avg:92.30ms +step:1025/1645 train_time:94608ms step_avg:92.30ms +step:1026/1645 train_time:94700ms step_avg:92.30ms +step:1027/1645 train_time:94793ms step_avg:92.30ms +step:1028/1645 train_time:94886ms step_avg:92.30ms +step:1029/1645 train_time:94978ms step_avg:92.30ms +step:1030/1645 train_time:95071ms step_avg:92.30ms +step:1031/1645 train_time:95164ms step_avg:92.30ms +step:1032/1645 train_time:95257ms step_avg:92.30ms +step:1033/1645 train_time:95351ms step_avg:92.30ms +step:1034/1645 train_time:95444ms step_avg:92.31ms +step:1035/1645 train_time:95538ms step_avg:92.31ms +step:1036/1645 train_time:95630ms step_avg:92.31ms +step:1037/1645 train_time:95723ms step_avg:92.31ms +step:1038/1645 train_time:95816ms step_avg:92.31ms +step:1039/1645 train_time:95909ms step_avg:92.31ms +step:1040/1645 train_time:96001ms step_avg:92.31ms +step:1041/1645 train_time:96094ms step_avg:92.31ms +step:1042/1645 train_time:96186ms step_avg:92.31ms +step:1043/1645 train_time:96280ms step_avg:92.31ms +step:1044/1645 train_time:96374ms step_avg:92.31ms +step:1045/1645 train_time:96466ms step_avg:92.31ms +step:1046/1645 train_time:96560ms step_avg:92.31ms +step:1047/1645 train_time:96653ms step_avg:92.31ms +step:1048/1645 train_time:96746ms step_avg:92.32ms +step:1049/1645 train_time:96839ms step_avg:92.32ms +step:1050/1645 train_time:96931ms step_avg:92.32ms +step:1051/1645 train_time:97024ms step_avg:92.32ms +step:1052/1645 train_time:97117ms step_avg:92.32ms +step:1053/1645 train_time:97210ms step_avg:92.32ms +step:1054/1645 train_time:97302ms step_avg:92.32ms +step:1055/1645 train_time:97396ms step_avg:92.32ms +step:1056/1645 train_time:97490ms step_avg:92.32ms +step:1057/1645 train_time:97583ms step_avg:92.32ms +step:1058/1645 train_time:97678ms step_avg:92.32ms +step:1059/1645 
train_time:97771ms step_avg:92.32ms +step:1060/1645 train_time:97865ms step_avg:92.33ms +step:1061/1645 train_time:97958ms step_avg:92.33ms +step:1062/1645 train_time:98051ms step_avg:92.33ms +step:1063/1645 train_time:98143ms step_avg:92.33ms +step:1064/1645 train_time:98236ms step_avg:92.33ms +step:1065/1645 train_time:98328ms step_avg:92.33ms +step:1066/1645 train_time:98421ms step_avg:92.33ms +step:1067/1645 train_time:98514ms step_avg:92.33ms +step:1068/1645 train_time:98607ms step_avg:92.33ms +step:1069/1645 train_time:98700ms step_avg:92.33ms +step:1070/1645 train_time:98795ms step_avg:92.33ms +step:1071/1645 train_time:98888ms step_avg:92.33ms +step:1072/1645 train_time:98980ms step_avg:92.33ms +step:1073/1645 train_time:99073ms step_avg:92.33ms +step:1074/1645 train_time:99166ms step_avg:92.33ms +step:1075/1645 train_time:99258ms step_avg:92.33ms +step:1076/1645 train_time:99351ms step_avg:92.33ms +step:1077/1645 train_time:99443ms step_avg:92.33ms +step:1078/1645 train_time:99537ms step_avg:92.33ms +step:1079/1645 train_time:99630ms step_avg:92.34ms +step:1080/1645 train_time:99723ms step_avg:92.34ms +step:1081/1645 train_time:99817ms step_avg:92.34ms +step:1082/1645 train_time:99910ms step_avg:92.34ms +step:1083/1645 train_time:100002ms step_avg:92.34ms +step:1084/1645 train_time:100096ms step_avg:92.34ms +step:1085/1645 train_time:100189ms step_avg:92.34ms +step:1086/1645 train_time:100281ms step_avg:92.34ms +step:1087/1645 train_time:100374ms step_avg:92.34ms +step:1088/1645 train_time:100467ms step_avg:92.34ms +step:1089/1645 train_time:100560ms step_avg:92.34ms +step:1090/1645 train_time:100653ms step_avg:92.34ms +step:1091/1645 train_time:100746ms step_avg:92.34ms +step:1092/1645 train_time:100839ms step_avg:92.34ms +step:1093/1645 train_time:100933ms step_avg:92.34ms +step:1094/1645 train_time:101026ms step_avg:92.35ms +step:1095/1645 train_time:101119ms step_avg:92.35ms +step:1096/1645 train_time:101212ms step_avg:92.35ms +step:1097/1645 train_time:101304ms step_avg:92.35ms +step:1098/1645 train_time:101396ms step_avg:92.35ms +step:1099/1645 train_time:101490ms step_avg:92.35ms +step:1100/1645 train_time:101585ms step_avg:92.35ms +step:1101/1645 train_time:101679ms step_avg:92.35ms +step:1102/1645 train_time:101773ms step_avg:92.35ms +step:1103/1645 train_time:101866ms step_avg:92.35ms +step:1104/1645 train_time:101960ms step_avg:92.36ms +step:1105/1645 train_time:102055ms step_avg:92.36ms +step:1106/1645 train_time:102148ms step_avg:92.36ms +step:1107/1645 train_time:102241ms step_avg:92.36ms +step:1108/1645 train_time:102334ms step_avg:92.36ms +step:1109/1645 train_time:102427ms step_avg:92.36ms +step:1110/1645 train_time:102520ms step_avg:92.36ms +step:1111/1645 train_time:102614ms step_avg:92.36ms +step:1112/1645 train_time:102708ms step_avg:92.36ms +step:1113/1645 train_time:102800ms step_avg:92.36ms +step:1114/1645 train_time:102895ms step_avg:92.37ms +step:1115/1645 train_time:102989ms step_avg:92.37ms +step:1116/1645 train_time:103083ms step_avg:92.37ms +step:1117/1645 train_time:103177ms step_avg:92.37ms +step:1118/1645 train_time:103270ms step_avg:92.37ms +step:1119/1645 train_time:103364ms step_avg:92.37ms +step:1120/1645 train_time:103457ms step_avg:92.37ms +step:1121/1645 train_time:103551ms step_avg:92.37ms +step:1122/1645 train_time:103645ms step_avg:92.38ms +step:1123/1645 train_time:103739ms step_avg:92.38ms +step:1124/1645 train_time:103833ms step_avg:92.38ms +step:1125/1645 train_time:103926ms step_avg:92.38ms +step:1125/1645 val_loss:3.4103 
train_time:104020ms step_avg:92.46ms +step:1126/1645 train_time:104041ms step_avg:92.40ms +step:1127/1645 train_time:104122ms step_avg:92.39ms +step:1128/1645 train_time:104223ms step_avg:92.40ms +step:1129/1645 train_time:104319ms step_avg:92.40ms +step:1130/1645 train_time:104411ms step_avg:92.40ms +step:1131/1645 train_time:104503ms step_avg:92.40ms +step:1132/1645 train_time:104596ms step_avg:92.40ms +step:1133/1645 train_time:104688ms step_avg:92.40ms +step:1134/1645 train_time:104781ms step_avg:92.40ms +step:1135/1645 train_time:104873ms step_avg:92.40ms +step:1136/1645 train_time:104966ms step_avg:92.40ms +step:1137/1645 train_time:105061ms step_avg:92.40ms +step:1138/1645 train_time:105158ms step_avg:92.41ms +step:1139/1645 train_time:105253ms step_avg:92.41ms +step:1140/1645 train_time:105348ms step_avg:92.41ms +step:1141/1645 train_time:105441ms step_avg:92.41ms +step:1142/1645 train_time:105535ms step_avg:92.41ms +step:1143/1645 train_time:105627ms step_avg:92.41ms +step:1144/1645 train_time:105720ms step_avg:92.41ms +step:1145/1645 train_time:105812ms step_avg:92.41ms +step:1146/1645 train_time:105905ms step_avg:92.41ms +step:1147/1645 train_time:105999ms step_avg:92.41ms +step:1148/1645 train_time:106093ms step_avg:92.42ms +step:1149/1645 train_time:106188ms step_avg:92.42ms +step:1150/1645 train_time:106282ms step_avg:92.42ms +step:1151/1645 train_time:106376ms step_avg:92.42ms +step:1152/1645 train_time:106469ms step_avg:92.42ms +step:1153/1645 train_time:106563ms step_avg:92.42ms +step:1154/1645 train_time:106656ms step_avg:92.42ms +step:1155/1645 train_time:106749ms step_avg:92.42ms +step:1156/1645 train_time:106842ms step_avg:92.42ms +step:1157/1645 train_time:106935ms step_avg:92.42ms +step:1158/1645 train_time:107029ms step_avg:92.43ms +step:1159/1645 train_time:107123ms step_avg:92.43ms +step:1160/1645 train_time:107217ms step_avg:92.43ms +step:1161/1645 train_time:107312ms step_avg:92.43ms +step:1162/1645 train_time:107405ms step_avg:92.43ms +step:1163/1645 train_time:107498ms step_avg:92.43ms +step:1164/1645 train_time:107592ms step_avg:92.43ms +step:1165/1645 train_time:107685ms step_avg:92.43ms +step:1166/1645 train_time:107779ms step_avg:92.43ms +step:1167/1645 train_time:107872ms step_avg:92.43ms +step:1168/1645 train_time:107965ms step_avg:92.44ms +step:1169/1645 train_time:108058ms step_avg:92.44ms +step:1170/1645 train_time:108152ms step_avg:92.44ms +step:1171/1645 train_time:108247ms step_avg:92.44ms +step:1172/1645 train_time:108341ms step_avg:92.44ms +step:1173/1645 train_time:108434ms step_avg:92.44ms +step:1174/1645 train_time:108528ms step_avg:92.44ms +step:1175/1645 train_time:108621ms step_avg:92.44ms +step:1176/1645 train_time:108714ms step_avg:92.44ms +step:1177/1645 train_time:108807ms step_avg:92.44ms +step:1178/1645 train_time:108900ms step_avg:92.45ms +step:1179/1645 train_time:108994ms step_avg:92.45ms +step:1180/1645 train_time:109087ms step_avg:92.45ms +step:1181/1645 train_time:109181ms step_avg:92.45ms +step:1182/1645 train_time:109274ms step_avg:92.45ms +step:1183/1645 train_time:109368ms step_avg:92.45ms +step:1184/1645 train_time:109462ms step_avg:92.45ms +step:1185/1645 train_time:109555ms step_avg:92.45ms +step:1186/1645 train_time:109648ms step_avg:92.45ms +step:1187/1645 train_time:109742ms step_avg:92.45ms +step:1188/1645 train_time:109835ms step_avg:92.45ms +step:1189/1645 train_time:109928ms step_avg:92.45ms +step:1190/1645 train_time:110022ms step_avg:92.46ms +step:1191/1645 train_time:110115ms step_avg:92.46ms +step:1192/1645 
train_time:110209ms step_avg:92.46ms +step:1193/1645 train_time:110303ms step_avg:92.46ms +step:1194/1645 train_time:110397ms step_avg:92.46ms +step:1195/1645 train_time:110490ms step_avg:92.46ms +step:1196/1645 train_time:110584ms step_avg:92.46ms +step:1197/1645 train_time:110677ms step_avg:92.46ms +step:1198/1645 train_time:110771ms step_avg:92.46ms +step:1199/1645 train_time:110865ms step_avg:92.46ms +step:1200/1645 train_time:110958ms step_avg:92.46ms +step:1201/1645 train_time:111051ms step_avg:92.47ms +step:1202/1645 train_time:111145ms step_avg:92.47ms +step:1203/1645 train_time:111239ms step_avg:92.47ms +step:1204/1645 train_time:111333ms step_avg:92.47ms +step:1205/1645 train_time:111427ms step_avg:92.47ms +step:1206/1645 train_time:111520ms step_avg:92.47ms +step:1207/1645 train_time:111613ms step_avg:92.47ms +step:1208/1645 train_time:111707ms step_avg:92.47ms +step:1209/1645 train_time:111800ms step_avg:92.47ms +step:1210/1645 train_time:111893ms step_avg:92.47ms +step:1211/1645 train_time:111987ms step_avg:92.47ms +step:1212/1645 train_time:112080ms step_avg:92.48ms +step:1213/1645 train_time:112173ms step_avg:92.48ms +step:1214/1645 train_time:112267ms step_avg:92.48ms +step:1215/1645 train_time:112360ms step_avg:92.48ms +step:1216/1645 train_time:112454ms step_avg:92.48ms +step:1217/1645 train_time:112547ms step_avg:92.48ms +step:1218/1645 train_time:112641ms step_avg:92.48ms +step:1219/1645 train_time:112734ms step_avg:92.48ms +step:1220/1645 train_time:112827ms step_avg:92.48ms +step:1221/1645 train_time:112922ms step_avg:92.48ms +step:1222/1645 train_time:113015ms step_avg:92.48ms +step:1223/1645 train_time:113108ms step_avg:92.48ms +step:1224/1645 train_time:113204ms step_avg:92.49ms +step:1225/1645 train_time:113297ms step_avg:92.49ms +step:1226/1645 train_time:113390ms step_avg:92.49ms +step:1227/1645 train_time:113484ms step_avg:92.49ms +step:1228/1645 train_time:113578ms step_avg:92.49ms +step:1229/1645 train_time:113670ms step_avg:92.49ms +step:1230/1645 train_time:113765ms step_avg:92.49ms +step:1231/1645 train_time:113858ms step_avg:92.49ms +step:1232/1645 train_time:113951ms step_avg:92.49ms +step:1233/1645 train_time:114045ms step_avg:92.49ms +step:1234/1645 train_time:114138ms step_avg:92.49ms +step:1235/1645 train_time:114231ms step_avg:92.50ms +step:1236/1645 train_time:114325ms step_avg:92.50ms +step:1237/1645 train_time:114419ms step_avg:92.50ms +step:1238/1645 train_time:114513ms step_avg:92.50ms +step:1239/1645 train_time:114606ms step_avg:92.50ms +step:1240/1645 train_time:114700ms step_avg:92.50ms +step:1241/1645 train_time:114793ms step_avg:92.50ms +step:1242/1645 train_time:114886ms step_avg:92.50ms +step:1243/1645 train_time:114980ms step_avg:92.50ms +step:1244/1645 train_time:115073ms step_avg:92.50ms +step:1245/1645 train_time:115167ms step_avg:92.50ms +step:1246/1645 train_time:115261ms step_avg:92.50ms +step:1247/1645 train_time:115354ms step_avg:92.51ms +step:1248/1645 train_time:115448ms step_avg:92.51ms +step:1249/1645 train_time:115541ms step_avg:92.51ms +step:1250/1645 train_time:115634ms step_avg:92.51ms +step:1250/1645 val_loss:3.3722 train_time:115728ms step_avg:92.58ms +step:1251/1645 train_time:115753ms step_avg:92.53ms +step:1252/1645 train_time:115828ms step_avg:92.51ms +step:1253/1645 train_time:115924ms step_avg:92.52ms +step:1254/1645 train_time:116018ms step_avg:92.52ms +step:1255/1645 train_time:116110ms step_avg:92.52ms +step:1256/1645 train_time:116202ms step_avg:92.52ms +step:1257/1645 train_time:116294ms step_avg:92.52ms 
+step:1258/1645 train_time:116387ms step_avg:92.52ms +step:1259/1645 train_time:116479ms step_avg:92.52ms +step:1260/1645 train_time:116573ms step_avg:92.52ms +step:1261/1645 train_time:116667ms step_avg:92.52ms +step:1262/1645 train_time:116764ms step_avg:92.52ms +step:1263/1645 train_time:116859ms step_avg:92.53ms +step:1264/1645 train_time:116954ms step_avg:92.53ms +step:1265/1645 train_time:117048ms step_avg:92.53ms +step:1266/1645 train_time:117141ms step_avg:92.53ms +step:1267/1645 train_time:117234ms step_avg:92.53ms +step:1268/1645 train_time:117328ms step_avg:92.53ms +step:1269/1645 train_time:117422ms step_avg:92.53ms +step:1270/1645 train_time:117515ms step_avg:92.53ms +step:1271/1645 train_time:117608ms step_avg:92.53ms +step:1272/1645 train_time:117702ms step_avg:92.53ms +step:1273/1645 train_time:117797ms step_avg:92.53ms +step:1274/1645 train_time:117891ms step_avg:92.54ms +step:1275/1645 train_time:117985ms step_avg:92.54ms +step:1276/1645 train_time:118078ms step_avg:92.54ms +step:1277/1645 train_time:118172ms step_avg:92.54ms +step:1278/1645 train_time:118266ms step_avg:92.54ms +step:1279/1645 train_time:118359ms step_avg:92.54ms +step:1280/1645 train_time:118452ms step_avg:92.54ms +step:1281/1645 train_time:118544ms step_avg:92.54ms +step:1282/1645 train_time:118638ms step_avg:92.54ms +step:1283/1645 train_time:118732ms step_avg:92.54ms +step:1284/1645 train_time:118826ms step_avg:92.54ms +step:1285/1645 train_time:118921ms step_avg:92.55ms +step:1286/1645 train_time:119015ms step_avg:92.55ms +step:1287/1645 train_time:119109ms step_avg:92.55ms +step:1288/1645 train_time:119202ms step_avg:92.55ms +step:1289/1645 train_time:119296ms step_avg:92.55ms +step:1290/1645 train_time:119389ms step_avg:92.55ms +step:1291/1645 train_time:119483ms step_avg:92.55ms +step:1292/1645 train_time:119576ms step_avg:92.55ms +step:1293/1645 train_time:119669ms step_avg:92.55ms +step:1294/1645 train_time:119763ms step_avg:92.55ms +step:1295/1645 train_time:119859ms step_avg:92.56ms +step:1296/1645 train_time:119953ms step_avg:92.56ms +step:1297/1645 train_time:120047ms step_avg:92.56ms +step:1298/1645 train_time:120141ms step_avg:92.56ms +step:1299/1645 train_time:120234ms step_avg:92.56ms +step:1300/1645 train_time:120328ms step_avg:92.56ms +step:1301/1645 train_time:120421ms step_avg:92.56ms +step:1302/1645 train_time:120514ms step_avg:92.56ms +step:1303/1645 train_time:120607ms step_avg:92.56ms +step:1304/1645 train_time:120700ms step_avg:92.56ms +step:1305/1645 train_time:120794ms step_avg:92.56ms +step:1306/1645 train_time:120888ms step_avg:92.56ms +step:1307/1645 train_time:120981ms step_avg:92.56ms +step:1308/1645 train_time:121076ms step_avg:92.57ms +step:1309/1645 train_time:121169ms step_avg:92.57ms +step:1310/1645 train_time:121263ms step_avg:92.57ms +step:1311/1645 train_time:121357ms step_avg:92.57ms +step:1312/1645 train_time:121451ms step_avg:92.57ms +step:1313/1645 train_time:121544ms step_avg:92.57ms +step:1314/1645 train_time:121637ms step_avg:92.57ms +step:1315/1645 train_time:121730ms step_avg:92.57ms +step:1316/1645 train_time:121824ms step_avg:92.57ms +step:1317/1645 train_time:121917ms step_avg:92.57ms +step:1318/1645 train_time:122011ms step_avg:92.57ms +step:1319/1645 train_time:122105ms step_avg:92.57ms +step:1320/1645 train_time:122199ms step_avg:92.57ms +step:1321/1645 train_time:122293ms step_avg:92.58ms +step:1322/1645 train_time:122386ms step_avg:92.58ms +step:1323/1645 train_time:122480ms step_avg:92.58ms +step:1324/1645 train_time:122574ms step_avg:92.58ms 
+step:1325/1645 train_time:122668ms step_avg:92.58ms +step:1326/1645 train_time:122762ms step_avg:92.58ms +step:1327/1645 train_time:122855ms step_avg:92.58ms +step:1328/1645 train_time:122949ms step_avg:92.58ms +step:1329/1645 train_time:123042ms step_avg:92.58ms +step:1330/1645 train_time:123136ms step_avg:92.58ms +step:1331/1645 train_time:123230ms step_avg:92.58ms +step:1332/1645 train_time:123324ms step_avg:92.59ms +step:1333/1645 train_time:123418ms step_avg:92.59ms +step:1334/1645 train_time:123510ms step_avg:92.59ms +step:1335/1645 train_time:123604ms step_avg:92.59ms +step:1336/1645 train_time:123698ms step_avg:92.59ms +step:1337/1645 train_time:123792ms step_avg:92.59ms +step:1338/1645 train_time:123885ms step_avg:92.59ms +step:1339/1645 train_time:123979ms step_avg:92.59ms +step:1340/1645 train_time:124073ms step_avg:92.59ms +step:1341/1645 train_time:124169ms step_avg:92.59ms +step:1342/1645 train_time:124262ms step_avg:92.59ms +step:1343/1645 train_time:124356ms step_avg:92.60ms +step:1344/1645 train_time:124449ms step_avg:92.60ms +step:1345/1645 train_time:124542ms step_avg:92.60ms +step:1346/1645 train_time:124636ms step_avg:92.60ms +step:1347/1645 train_time:124730ms step_avg:92.60ms +step:1348/1645 train_time:124823ms step_avg:92.60ms +step:1349/1645 train_time:124917ms step_avg:92.60ms +step:1350/1645 train_time:125010ms step_avg:92.60ms +step:1351/1645 train_time:125104ms step_avg:92.60ms +step:1352/1645 train_time:125197ms step_avg:92.60ms +step:1353/1645 train_time:125291ms step_avg:92.60ms +step:1354/1645 train_time:125385ms step_avg:92.60ms +step:1355/1645 train_time:125478ms step_avg:92.60ms +step:1356/1645 train_time:125572ms step_avg:92.60ms +step:1357/1645 train_time:125665ms step_avg:92.61ms +step:1358/1645 train_time:125759ms step_avg:92.61ms +step:1359/1645 train_time:125852ms step_avg:92.61ms +step:1360/1645 train_time:125945ms step_avg:92.61ms +step:1361/1645 train_time:126039ms step_avg:92.61ms +step:1362/1645 train_time:126132ms step_avg:92.61ms +step:1363/1645 train_time:126226ms step_avg:92.61ms +step:1364/1645 train_time:126319ms step_avg:92.61ms +step:1365/1645 train_time:126412ms step_avg:92.61ms +step:1366/1645 train_time:126506ms step_avg:92.61ms +step:1367/1645 train_time:126599ms step_avg:92.61ms +step:1368/1645 train_time:126693ms step_avg:92.61ms +step:1369/1645 train_time:126787ms step_avg:92.61ms +step:1370/1645 train_time:126880ms step_avg:92.61ms +step:1371/1645 train_time:126974ms step_avg:92.61ms +step:1372/1645 train_time:127068ms step_avg:92.62ms +step:1373/1645 train_time:127162ms step_avg:92.62ms +step:1374/1645 train_time:127256ms step_avg:92.62ms +step:1375/1645 train_time:127350ms step_avg:92.62ms +step:1375/1645 val_loss:3.3377 train_time:127444ms step_avg:92.69ms +step:1376/1645 train_time:127470ms step_avg:92.64ms +step:1377/1645 train_time:127543ms step_avg:92.62ms +step:1378/1645 train_time:127638ms step_avg:92.63ms +step:1379/1645 train_time:127731ms step_avg:92.63ms +step:1380/1645 train_time:127825ms step_avg:92.63ms +step:1381/1645 train_time:127917ms step_avg:92.63ms +step:1382/1645 train_time:128009ms step_avg:92.63ms +step:1383/1645 train_time:128102ms step_avg:92.63ms +step:1384/1645 train_time:128197ms step_avg:92.63ms +step:1385/1645 train_time:128291ms step_avg:92.63ms +step:1386/1645 train_time:128385ms step_avg:92.63ms +step:1387/1645 train_time:128480ms step_avg:92.63ms +step:1388/1645 train_time:128575ms step_avg:92.63ms +step:1389/1645 train_time:128669ms step_avg:92.63ms +step:1390/1645 train_time:128762ms 
step_avg:92.63ms +step:1391/1645 train_time:128855ms step_avg:92.63ms +step:1392/1645 train_time:128948ms step_avg:92.64ms +step:1393/1645 train_time:129041ms step_avg:92.64ms +step:1394/1645 train_time:129135ms step_avg:92.64ms +step:1395/1645 train_time:129229ms step_avg:92.64ms +step:1396/1645 train_time:129323ms step_avg:92.64ms +step:1397/1645 train_time:129417ms step_avg:92.64ms +step:1398/1645 train_time:129511ms step_avg:92.64ms +step:1399/1645 train_time:129606ms step_avg:92.64ms +step:1400/1645 train_time:129699ms step_avg:92.64ms +step:1401/1645 train_time:129793ms step_avg:92.64ms +step:1402/1645 train_time:129887ms step_avg:92.64ms +step:1403/1645 train_time:129980ms step_avg:92.64ms +step:1404/1645 train_time:130072ms step_avg:92.64ms +step:1405/1645 train_time:130166ms step_avg:92.64ms +step:1406/1645 train_time:130260ms step_avg:92.65ms +step:1407/1645 train_time:130355ms step_avg:92.65ms +step:1408/1645 train_time:130449ms step_avg:92.65ms +step:1409/1645 train_time:130543ms step_avg:92.65ms +step:1410/1645 train_time:130639ms step_avg:92.65ms +step:1411/1645 train_time:130731ms step_avg:92.65ms +step:1412/1645 train_time:130826ms step_avg:92.65ms +step:1413/1645 train_time:130918ms step_avg:92.65ms +step:1414/1645 train_time:131012ms step_avg:92.65ms +step:1415/1645 train_time:131105ms step_avg:92.65ms +step:1416/1645 train_time:131199ms step_avg:92.65ms +step:1417/1645 train_time:131292ms step_avg:92.66ms +step:1418/1645 train_time:131386ms step_avg:92.66ms +step:1419/1645 train_time:131481ms step_avg:92.66ms +step:1420/1645 train_time:131575ms step_avg:92.66ms +step:1421/1645 train_time:131668ms step_avg:92.66ms +step:1422/1645 train_time:131762ms step_avg:92.66ms +step:1423/1645 train_time:131856ms step_avg:92.66ms +step:1424/1645 train_time:131950ms step_avg:92.66ms +step:1425/1645 train_time:132044ms step_avg:92.66ms +step:1426/1645 train_time:132136ms step_avg:92.66ms +step:1427/1645 train_time:132229ms step_avg:92.66ms +step:1428/1645 train_time:132323ms step_avg:92.66ms +step:1429/1645 train_time:132417ms step_avg:92.66ms +step:1430/1645 train_time:132510ms step_avg:92.66ms +step:1431/1645 train_time:132605ms step_avg:92.67ms +step:1432/1645 train_time:132699ms step_avg:92.67ms +step:1433/1645 train_time:132792ms step_avg:92.67ms +step:1434/1645 train_time:132886ms step_avg:92.67ms +step:1435/1645 train_time:132980ms step_avg:92.67ms +step:1436/1645 train_time:133074ms step_avg:92.67ms +step:1437/1645 train_time:133167ms step_avg:92.67ms +step:1438/1645 train_time:133261ms step_avg:92.67ms +step:1439/1645 train_time:133355ms step_avg:92.67ms +step:1440/1645 train_time:133450ms step_avg:92.67ms +step:1441/1645 train_time:133542ms step_avg:92.67ms +step:1442/1645 train_time:133636ms step_avg:92.67ms +step:1443/1645 train_time:133729ms step_avg:92.67ms +step:1444/1645 train_time:133823ms step_avg:92.68ms +step:1445/1645 train_time:133916ms step_avg:92.68ms +step:1446/1645 train_time:134011ms step_avg:92.68ms +step:1447/1645 train_time:134104ms step_avg:92.68ms +step:1448/1645 train_time:134197ms step_avg:92.68ms +step:1449/1645 train_time:134292ms step_avg:92.68ms +step:1450/1645 train_time:134386ms step_avg:92.68ms +step:1451/1645 train_time:134479ms step_avg:92.68ms +step:1452/1645 train_time:134573ms step_avg:92.68ms +step:1453/1645 train_time:134666ms step_avg:92.68ms +step:1454/1645 train_time:134760ms step_avg:92.68ms +step:1455/1645 train_time:134853ms step_avg:92.68ms +step:1456/1645 train_time:134947ms step_avg:92.68ms +step:1457/1645 train_time:135042ms 
step_avg:92.69ms +step:1458/1645 train_time:135136ms step_avg:92.69ms +step:1459/1645 train_time:135229ms step_avg:92.69ms +step:1460/1645 train_time:135323ms step_avg:92.69ms +step:1461/1645 train_time:135417ms step_avg:92.69ms +step:1462/1645 train_time:135511ms step_avg:92.69ms +step:1463/1645 train_time:135605ms step_avg:92.69ms +step:1464/1645 train_time:135700ms step_avg:92.69ms +step:1465/1645 train_time:135794ms step_avg:92.69ms +step:1466/1645 train_time:135887ms step_avg:92.69ms +step:1467/1645 train_time:135981ms step_avg:92.69ms +step:1468/1645 train_time:136075ms step_avg:92.69ms +step:1469/1645 train_time:136168ms step_avg:92.69ms +step:1470/1645 train_time:136263ms step_avg:92.70ms +step:1471/1645 train_time:136356ms step_avg:92.70ms +step:1472/1645 train_time:136449ms step_avg:92.70ms +step:1473/1645 train_time:136543ms step_avg:92.70ms +step:1474/1645 train_time:136637ms step_avg:92.70ms +step:1475/1645 train_time:136731ms step_avg:92.70ms +step:1476/1645 train_time:136825ms step_avg:92.70ms +step:1477/1645 train_time:136918ms step_avg:92.70ms +step:1478/1645 train_time:137012ms step_avg:92.70ms +step:1479/1645 train_time:137106ms step_avg:92.70ms +step:1480/1645 train_time:137200ms step_avg:92.70ms +step:1481/1645 train_time:137294ms step_avg:92.70ms +step:1482/1645 train_time:137387ms step_avg:92.70ms +step:1483/1645 train_time:137481ms step_avg:92.70ms +step:1484/1645 train_time:137575ms step_avg:92.71ms +step:1485/1645 train_time:137669ms step_avg:92.71ms +step:1486/1645 train_time:137762ms step_avg:92.71ms +step:1487/1645 train_time:137856ms step_avg:92.71ms +step:1488/1645 train_time:137950ms step_avg:92.71ms +step:1489/1645 train_time:138044ms step_avg:92.71ms +step:1490/1645 train_time:138137ms step_avg:92.71ms +step:1491/1645 train_time:138231ms step_avg:92.71ms +step:1492/1645 train_time:138325ms step_avg:92.71ms +step:1493/1645 train_time:138418ms step_avg:92.71ms +step:1494/1645 train_time:138511ms step_avg:92.71ms +step:1495/1645 train_time:138605ms step_avg:92.71ms +step:1496/1645 train_time:138698ms step_avg:92.71ms +step:1497/1645 train_time:138792ms step_avg:92.71ms +step:1498/1645 train_time:138886ms step_avg:92.71ms +step:1499/1645 train_time:138980ms step_avg:92.72ms +step:1500/1645 train_time:139074ms step_avg:92.72ms +step:1500/1645 val_loss:3.3079 train_time:139168ms step_avg:92.78ms +step:1501/1645 train_time:139194ms step_avg:92.73ms +step:1502/1645 train_time:139266ms step_avg:92.72ms +step:1503/1645 train_time:139362ms step_avg:92.72ms +step:1504/1645 train_time:139455ms step_avg:92.72ms +step:1505/1645 train_time:139548ms step_avg:92.72ms +step:1506/1645 train_time:139640ms step_avg:92.72ms +step:1507/1645 train_time:139732ms step_avg:92.72ms +step:1508/1645 train_time:139827ms step_avg:92.72ms +step:1509/1645 train_time:139920ms step_avg:92.72ms +step:1510/1645 train_time:140014ms step_avg:92.72ms +step:1511/1645 train_time:140109ms step_avg:92.73ms +step:1512/1645 train_time:140206ms step_avg:92.73ms +step:1513/1645 train_time:140299ms step_avg:92.73ms +step:1514/1645 train_time:140394ms step_avg:92.73ms +step:1515/1645 train_time:140487ms step_avg:92.73ms +step:1516/1645 train_time:140581ms step_avg:92.73ms +step:1517/1645 train_time:140674ms step_avg:92.73ms +step:1518/1645 train_time:140767ms step_avg:92.73ms +step:1519/1645 train_time:140861ms step_avg:92.73ms +step:1520/1645 train_time:140954ms step_avg:92.73ms +step:1521/1645 train_time:141047ms step_avg:92.73ms +step:1522/1645 train_time:141143ms step_avg:92.73ms +step:1523/1645 
train_time:141237ms step_avg:92.74ms +step:1524/1645 train_time:141332ms step_avg:92.74ms +step:1525/1645 train_time:141426ms step_avg:92.74ms +step:1526/1645 train_time:141519ms step_avg:92.74ms +step:1527/1645 train_time:141612ms step_avg:92.74ms +step:1528/1645 train_time:141706ms step_avg:92.74ms +step:1529/1645 train_time:141799ms step_avg:92.74ms +step:1530/1645 train_time:141892ms step_avg:92.74ms +step:1531/1645 train_time:141985ms step_avg:92.74ms +step:1532/1645 train_time:142078ms step_avg:92.74ms +step:1533/1645 train_time:142172ms step_avg:92.74ms +step:1534/1645 train_time:142266ms step_avg:92.74ms +step:1535/1645 train_time:142360ms step_avg:92.74ms +step:1536/1645 train_time:142454ms step_avg:92.74ms +step:1537/1645 train_time:142548ms step_avg:92.74ms +step:1538/1645 train_time:142641ms step_avg:92.74ms +step:1539/1645 train_time:142734ms step_avg:92.74ms +step:1540/1645 train_time:142828ms step_avg:92.75ms +step:1541/1645 train_time:142922ms step_avg:92.75ms +step:1542/1645 train_time:143015ms step_avg:92.75ms +step:1543/1645 train_time:143109ms step_avg:92.75ms +step:1544/1645 train_time:143203ms step_avg:92.75ms +step:1545/1645 train_time:143297ms step_avg:92.75ms +step:1546/1645 train_time:143391ms step_avg:92.75ms +step:1547/1645 train_time:143485ms step_avg:92.75ms +step:1548/1645 train_time:143578ms step_avg:92.75ms +step:1549/1645 train_time:143671ms step_avg:92.75ms +step:1550/1645 train_time:143765ms step_avg:92.75ms +step:1551/1645 train_time:143859ms step_avg:92.75ms +step:1552/1645 train_time:143952ms step_avg:92.75ms +step:1553/1645 train_time:144047ms step_avg:92.75ms +step:1554/1645 train_time:144141ms step_avg:92.75ms +step:1555/1645 train_time:144234ms step_avg:92.76ms +step:1556/1645 train_time:144329ms step_avg:92.76ms +step:1557/1645 train_time:144423ms step_avg:92.76ms +step:1558/1645 train_time:144517ms step_avg:92.76ms +step:1559/1645 train_time:144611ms step_avg:92.76ms +step:1560/1645 train_time:144705ms step_avg:92.76ms +step:1561/1645 train_time:144798ms step_avg:92.76ms +step:1562/1645 train_time:144892ms step_avg:92.76ms +step:1563/1645 train_time:144986ms step_avg:92.76ms +step:1564/1645 train_time:145079ms step_avg:92.76ms +step:1565/1645 train_time:145172ms step_avg:92.76ms +step:1566/1645 train_time:145267ms step_avg:92.76ms +step:1567/1645 train_time:145362ms step_avg:92.76ms +step:1568/1645 train_time:145456ms step_avg:92.77ms +step:1569/1645 train_time:145549ms step_avg:92.77ms +step:1570/1645 train_time:145644ms step_avg:92.77ms +step:1571/1645 train_time:145737ms step_avg:92.77ms +step:1572/1645 train_time:145830ms step_avg:92.77ms +step:1573/1645 train_time:145924ms step_avg:92.77ms +step:1574/1645 train_time:146019ms step_avg:92.77ms +step:1575/1645 train_time:146111ms step_avg:92.77ms +step:1576/1645 train_time:146204ms step_avg:92.77ms +step:1577/1645 train_time:146298ms step_avg:92.77ms +step:1578/1645 train_time:146392ms step_avg:92.77ms +step:1579/1645 train_time:146486ms step_avg:92.77ms +step:1580/1645 train_time:146581ms step_avg:92.77ms +step:1581/1645 train_time:146675ms step_avg:92.77ms +step:1582/1645 train_time:146768ms step_avg:92.77ms +step:1583/1645 train_time:146861ms step_avg:92.77ms +step:1584/1645 train_time:146955ms step_avg:92.77ms +step:1585/1645 train_time:147049ms step_avg:92.78ms +step:1586/1645 train_time:147142ms step_avg:92.78ms +step:1587/1645 train_time:147236ms step_avg:92.78ms +step:1588/1645 train_time:147329ms step_avg:92.78ms +step:1589/1645 train_time:147423ms step_avg:92.78ms +step:1590/1645 
train_time:147516ms step_avg:92.78ms +step:1591/1645 train_time:147611ms step_avg:92.78ms +step:1592/1645 train_time:147705ms step_avg:92.78ms +step:1593/1645 train_time:147798ms step_avg:92.78ms +step:1594/1645 train_time:147892ms step_avg:92.78ms +step:1595/1645 train_time:147985ms step_avg:92.78ms +step:1596/1645 train_time:148079ms step_avg:92.78ms +step:1597/1645 train_time:148172ms step_avg:92.78ms +step:1598/1645 train_time:148266ms step_avg:92.78ms +step:1599/1645 train_time:148360ms step_avg:92.78ms +step:1600/1645 train_time:148454ms step_avg:92.78ms +step:1601/1645 train_time:148548ms step_avg:92.78ms +step:1602/1645 train_time:148642ms step_avg:92.79ms +step:1603/1645 train_time:148736ms step_avg:92.79ms +step:1604/1645 train_time:148830ms step_avg:92.79ms +step:1605/1645 train_time:148924ms step_avg:92.79ms +step:1606/1645 train_time:149017ms step_avg:92.79ms +step:1607/1645 train_time:149111ms step_avg:92.79ms +step:1608/1645 train_time:149204ms step_avg:92.79ms +step:1609/1645 train_time:149297ms step_avg:92.79ms +step:1610/1645 train_time:149391ms step_avg:92.79ms +step:1611/1645 train_time:149485ms step_avg:92.79ms +step:1612/1645 train_time:149579ms step_avg:92.79ms +step:1613/1645 train_time:149673ms step_avg:92.79ms +step:1614/1645 train_time:149767ms step_avg:92.79ms +step:1615/1645 train_time:149861ms step_avg:92.79ms +step:1616/1645 train_time:149955ms step_avg:92.79ms +step:1617/1645 train_time:150049ms step_avg:92.79ms +step:1618/1645 train_time:150143ms step_avg:92.80ms +step:1619/1645 train_time:150236ms step_avg:92.80ms +step:1620/1645 train_time:150330ms step_avg:92.80ms +step:1621/1645 train_time:150424ms step_avg:92.80ms +step:1622/1645 train_time:150518ms step_avg:92.80ms +step:1623/1645 train_time:150611ms step_avg:92.80ms +step:1624/1645 train_time:150705ms step_avg:92.80ms +step:1625/1645 train_time:150798ms step_avg:92.80ms +step:1625/1645 val_loss:3.2838 train_time:150893ms step_avg:92.86ms +step:1626/1645 train_time:150918ms step_avg:92.82ms +step:1627/1645 train_time:150991ms step_avg:92.80ms +step:1628/1645 train_time:151087ms step_avg:92.81ms +step:1629/1645 train_time:151181ms step_avg:92.81ms +step:1630/1645 train_time:151273ms step_avg:92.81ms +step:1631/1645 train_time:151366ms step_avg:92.81ms +step:1632/1645 train_time:151459ms step_avg:92.81ms +step:1633/1645 train_time:151552ms step_avg:92.81ms +step:1634/1645 train_time:151645ms step_avg:92.81ms +step:1635/1645 train_time:151739ms step_avg:92.81ms +step:1636/1645 train_time:151833ms step_avg:92.81ms +step:1637/1645 train_time:151928ms step_avg:92.81ms +step:1638/1645 train_time:152023ms step_avg:92.81ms +step:1639/1645 train_time:152117ms step_avg:92.81ms +step:1640/1645 train_time:152211ms step_avg:92.81ms +step:1641/1645 train_time:152305ms step_avg:92.81ms +step:1642/1645 train_time:152398ms step_avg:92.81ms +step:1643/1645 train_time:152491ms step_avg:92.81ms +step:1644/1645 train_time:152583ms step_avg:92.81ms +step:1645/1645 train_time:152677ms step_avg:92.81ms +step:1645/1645 val_loss:3.2781 train_time:152771ms step_avg:92.87ms +peak memory allocated: 31659 MiB reserved: 46796 MiB diff --git a/records/091825_Smear/fd1a1ca9-afff-4881-a686-8b56bad5901b.txt b/records/091825_Smear/fd1a1ca9-afff-4881-a686-8b56bad5901b.txt new file mode 100644 index 000000000..2b561a8a4 --- /dev/null +++ b/records/091825_Smear/fd1a1ca9-afff-4881-a686-8b56bad5901b.txt @@ -0,0 +1,3089 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging 
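+# Note: each records/*.txt file added by this patch is a run record: the full
+# training script followed by its step-by-step training log (the step:/val_loss:
+# lines seen above). Reading this file's own source here is what lets the run
+# embed its exact code in the record for reproducibility.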
+import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# 
----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def 
ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + 
tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
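+        # High-level flow of this step, overlapping communication with compute:
+        #   1) for every same-shape group, reduce_scatter the stacked gradients so each
+        #      rank owns the cross-rank-averaged gradients for one shard of the params (async)
+        #   2) each rank applies momentum + Newton-Schulz orthogonalization to its shard only
+        #   3) the updated parameter shards are all_gather'ed back so every rank finishes
+        #      the step with identical weights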
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
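+                # The momentum handling below is Nesterov-flavored: the persistent buffer
+                # is lerped toward the fresh gradient, and the update fed to Newton-Schulz
+                # is the fresh gradient lerped back toward that buffer.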
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value 
residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + self.smear_gate = CastedLinear(12, 1) + self.smear_gate.weight.detach().zero_() + + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        assert len(ve) == len(self.blocks)
+
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+        final_bm = ws_final_layer * args.block_size
+        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm]
+        assert len(bm_sizes) == len(self.blocks)
+
+        smear_lambda = self.scalars[5 * len(self.blocks)]
+        #x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
+        x = self.embed(input_seq)
+
+        # smear token embed forward 1 position
+        smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+        x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+        x = x0 = norm(x[None])
+
+        # U-net design by @brendanh0gan
+        skip_connections = []
+        skip_weights = self.scalars[:(len(self.blocks) // 2)]
+        lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+        sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+        n = len(self.blocks) // 2
+
+        for i in range(len(self.blocks)):
+            attn_args = AttnArgs(
+                ve=ve[i],
+                sa_lambdas=sa_lambdas[i],
+                seqlens=seqlens,
+                bm_size=bm_sizes[i],
+                cos=self.yarn.cos,
+                sin=self.yarn.sin,
+                attn_scale=self.yarn.attn_scale
+            )
+            if i >= n:
+                gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+                x = x + gate * skip_connections.pop()
+            x = self.blocks[i](x, x0, lambdas[i], attn_args)
+            if i < n:
+                skip_connections.append(x)
+
+        x = norm(x)
+        logits = self.lm_head(x).float()
+        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+        logits = 30 * torch.sigmoid(logits / 7.5)
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+        return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+    assert header[1] == 1, "unsupported version"
+    num_tokens = int(header[2]) # number of tokens (claimed)
+    with file.open("rb", buffering=0) as f:
+        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+        f.seek(256 * 4)
+        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+    return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+    # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+        # Precompute BOS positions once per shard
+        self.tokens = tokens
+        self.size = tokens.numel()
+        self.quickload = quickload
+        if quickload:
+            # only scan first 4 million tokens, then kickoff async thread to scan rest
+            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+            self.thread = None
+            self.ready = threading.Event()
+            self.start()
+        else:
+            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
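+                # swap in the shard the DataPreloader finished in the background, then
+                # immediately kick off loading of the shard after it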
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1645 # number of iterations to run
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"smear/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
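+# Illustrative reference, not part of the original run: a plain-PyTorch sketch of the
+# quintic Newton-Schulz iteration that the Triton kernels above implement (same
+# coefficients and 5 iterations). The name newton_schulz_reference is made up here;
+# it is never called during training and is only useful for sanity-checking
+# newton_schulz_triton on random inputs, e.g.
+#   G = torch.randn(4, 768, 768, device="cuda")
+#   torch.testing.assert_close(newton_schulz_reference(G), newton_schulz_triton(G), rtol=0.1, atol=0.1)
+def newton_schulz_reference(G: Tensor) -> Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT  # operate on the wide orientation
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # bound the spectral norm near 1
+    for _ in range(5):
+        A = X @ X.mT           # what ns_line_1 computes
+        B = b * A + c * A @ A  # what ns_line_2 computes
+        X = a * X + B @ X      # what ns_line_3 (addmm/baddbmm) computes
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X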
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations: + return args.ws_validate, args.ws_validate_final_layer + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up 
each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        # the source text of this record is garbled from here through the validation
+        # check below (an angle-bracketed span was stripped); the intervening lines are
+        # a reconstruction following the usual modded-nanogpt warmup-reset-train shape
+        model.yarn.reset()
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# reinitialize model and optimizer state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del initial_state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+ws = args.ws_schedule[0]
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws != ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.11.10 (main, Sep 7 2024, 18:35:41) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250721+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Thu Sep 18 17:16:54 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 27C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 28C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 34C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 33C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1645 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1645 train_time:137ms step_avg:136.88ms +step:2/1645 train_time:156ms step_avg:78.07ms +step:3/1645 train_time:227ms step_avg:75.80ms +step:4/1645 train_time:317ms step_avg:79.27ms +step:5/1645 train_time:408ms step_avg:81.54ms +step:6/1645 train_time:499ms step_avg:83.17ms +step:7/1645 train_time:590ms step_avg:84.26ms +step:8/1645 train_time:681ms step_avg:85.17ms +step:9/1645 train_time:772ms step_avg:85.82ms +step:10/1645 train_time:863ms step_avg:86.33ms +step:11/1645 train_time:954ms step_avg:86.73ms +step:12/1645 train_time:1048ms step_avg:87.33ms +step:13/1645 train_time:1144ms step_avg:87.99ms +step:14/1645 train_time:1238ms step_avg:88.44ms +step:15/1645 train_time:1333ms step_avg:88.84ms +step:16/1645 train_time:1425ms step_avg:89.06ms +step:17/1645 train_time:1516ms step_avg:89.20ms +step:18/1645 train_time:1608ms step_avg:89.31ms +step:19/1645 train_time:1699ms step_avg:89.44ms +step:20/1645 train_time:1791ms step_avg:89.53ms +step:21/1645 train_time:1881ms step_avg:89.59ms +step:22/1645 train_time:1974ms step_avg:89.71ms +step:23/1645 
train_time:2066ms step_avg:89.84ms
+[steps 24-1456: per-step train_time/step_avg log lines; train_time advances roughly 92ms per step while step_avg climbs steadily from 89.84ms to 92.72ms; the val_loss checkpoints logged every 125 steps are kept below]
+step:125/1645 val_loss:4.2992 train_time:11533ms step_avg:92.26ms
+step:250/1645 val_loss:3.9651 train_time:23016ms step_avg:92.07ms
+step:375/1645 val_loss:3.8148 train_time:34488ms step_avg:91.97ms
+step:500/1645 val_loss:3.7130 train_time:45961ms step_avg:91.92ms
+step:625/1645 val_loss:3.6119 train_time:57530ms step_avg:92.05ms
+step:750/1645 val_loss:3.5612 train_time:69161ms step_avg:92.21ms
+step:875/1645 val_loss:3.5125 train_time:80793ms step_avg:92.33ms
+step:1000/1645 val_loss:3.4642 train_time:92429ms step_avg:92.43ms
+step:1125/1645 val_loss:3.4111 train_time:104069ms step_avg:92.51ms
+step:1250/1645 val_loss:3.3728 train_time:115787ms step_avg:92.63ms
+step:1375/1645 val_loss:3.3387 train_time:127501ms step_avg:92.73ms
+step:1457/1645 train_time:135096ms
step_avg:92.72ms +step:1458/1645 train_time:135190ms step_avg:92.72ms +step:1459/1645 train_time:135283ms step_avg:92.72ms +step:1460/1645 train_time:135377ms step_avg:92.72ms +step:1461/1645 train_time:135471ms step_avg:92.73ms +step:1462/1645 train_time:135565ms step_avg:92.73ms +step:1463/1645 train_time:135659ms step_avg:92.73ms +step:1464/1645 train_time:135753ms step_avg:92.73ms +step:1465/1645 train_time:135846ms step_avg:92.73ms +step:1466/1645 train_time:135940ms step_avg:92.73ms +step:1467/1645 train_time:136033ms step_avg:92.73ms +step:1468/1645 train_time:136126ms step_avg:92.73ms +step:1469/1645 train_time:136220ms step_avg:92.73ms +step:1470/1645 train_time:136313ms step_avg:92.73ms +step:1471/1645 train_time:136407ms step_avg:92.73ms +step:1472/1645 train_time:136500ms step_avg:92.73ms +step:1473/1645 train_time:136594ms step_avg:92.73ms +step:1474/1645 train_time:136688ms step_avg:92.73ms +step:1475/1645 train_time:136782ms step_avg:92.73ms +step:1476/1645 train_time:136876ms step_avg:92.73ms +step:1477/1645 train_time:136970ms step_avg:92.74ms +step:1478/1645 train_time:137064ms step_avg:92.74ms +step:1479/1645 train_time:137157ms step_avg:92.74ms +step:1480/1645 train_time:137251ms step_avg:92.74ms +step:1481/1645 train_time:137345ms step_avg:92.74ms +step:1482/1645 train_time:137438ms step_avg:92.74ms +step:1483/1645 train_time:137532ms step_avg:92.74ms +step:1484/1645 train_time:137626ms step_avg:92.74ms +step:1485/1645 train_time:137720ms step_avg:92.74ms +step:1486/1645 train_time:137813ms step_avg:92.74ms +step:1487/1645 train_time:137907ms step_avg:92.74ms +step:1488/1645 train_time:138000ms step_avg:92.74ms +step:1489/1645 train_time:138095ms step_avg:92.74ms +step:1490/1645 train_time:138189ms step_avg:92.74ms +step:1491/1645 train_time:138283ms step_avg:92.75ms +step:1492/1645 train_time:138376ms step_avg:92.75ms +step:1493/1645 train_time:138469ms step_avg:92.75ms +step:1494/1645 train_time:138564ms step_avg:92.75ms +step:1495/1645 train_time:138658ms step_avg:92.75ms +step:1496/1645 train_time:138751ms step_avg:92.75ms +step:1497/1645 train_time:138845ms step_avg:92.75ms +step:1498/1645 train_time:138938ms step_avg:92.75ms +step:1499/1645 train_time:139032ms step_avg:92.75ms +step:1500/1645 train_time:139126ms step_avg:92.75ms +step:1500/1645 val_loss:3.3087 train_time:139219ms step_avg:92.81ms +step:1501/1645 train_time:139241ms step_avg:92.77ms +step:1502/1645 train_time:139316ms step_avg:92.75ms +step:1503/1645 train_time:139412ms step_avg:92.76ms +step:1504/1645 train_time:139505ms step_avg:92.76ms +step:1505/1645 train_time:139598ms step_avg:92.76ms +step:1506/1645 train_time:139690ms step_avg:92.76ms +step:1507/1645 train_time:139784ms step_avg:92.76ms +step:1508/1645 train_time:139877ms step_avg:92.76ms +step:1509/1645 train_time:139970ms step_avg:92.76ms +step:1510/1645 train_time:140063ms step_avg:92.76ms +step:1511/1645 train_time:140158ms step_avg:92.76ms +step:1512/1645 train_time:140255ms step_avg:92.76ms +step:1513/1645 train_time:140351ms step_avg:92.76ms +step:1514/1645 train_time:140446ms step_avg:92.76ms +step:1515/1645 train_time:140539ms step_avg:92.76ms +step:1516/1645 train_time:140632ms step_avg:92.76ms +step:1517/1645 train_time:140725ms step_avg:92.77ms +step:1518/1645 train_time:140819ms step_avg:92.77ms +step:1519/1645 train_time:140912ms step_avg:92.77ms +step:1520/1645 train_time:141005ms step_avg:92.77ms +step:1521/1645 train_time:141099ms step_avg:92.77ms +step:1522/1645 train_time:141193ms step_avg:92.77ms +step:1523/1645 
train_time:141288ms step_avg:92.77ms +step:1524/1645 train_time:141382ms step_avg:92.77ms +step:1525/1645 train_time:141476ms step_avg:92.77ms +step:1526/1645 train_time:141569ms step_avg:92.77ms +step:1527/1645 train_time:141663ms step_avg:92.77ms +step:1528/1645 train_time:141756ms step_avg:92.77ms +step:1529/1645 train_time:141849ms step_avg:92.77ms +step:1530/1645 train_time:141942ms step_avg:92.77ms +step:1531/1645 train_time:142035ms step_avg:92.77ms +step:1532/1645 train_time:142129ms step_avg:92.77ms +step:1533/1645 train_time:142224ms step_avg:92.77ms +step:1534/1645 train_time:142318ms step_avg:92.78ms +step:1535/1645 train_time:142412ms step_avg:92.78ms +step:1536/1645 train_time:142506ms step_avg:92.78ms +step:1537/1645 train_time:142600ms step_avg:92.78ms +step:1538/1645 train_time:142693ms step_avg:92.78ms +step:1539/1645 train_time:142787ms step_avg:92.78ms +step:1540/1645 train_time:142880ms step_avg:92.78ms +step:1541/1645 train_time:142974ms step_avg:92.78ms +step:1542/1645 train_time:143066ms step_avg:92.78ms +step:1543/1645 train_time:143161ms step_avg:92.78ms +step:1544/1645 train_time:143255ms step_avg:92.78ms +step:1545/1645 train_time:143349ms step_avg:92.78ms +step:1546/1645 train_time:143442ms step_avg:92.78ms +step:1547/1645 train_time:143536ms step_avg:92.78ms +step:1548/1645 train_time:143629ms step_avg:92.78ms +step:1549/1645 train_time:143723ms step_avg:92.78ms +step:1550/1645 train_time:143817ms step_avg:92.79ms +step:1551/1645 train_time:143910ms step_avg:92.79ms +step:1552/1645 train_time:144005ms step_avg:92.79ms +step:1553/1645 train_time:144098ms step_avg:92.79ms +step:1554/1645 train_time:144191ms step_avg:92.79ms +step:1555/1645 train_time:144285ms step_avg:92.79ms +step:1556/1645 train_time:144379ms step_avg:92.79ms +step:1557/1645 train_time:144473ms step_avg:92.79ms +step:1558/1645 train_time:144566ms step_avg:92.79ms +step:1559/1645 train_time:144659ms step_avg:92.79ms +step:1560/1645 train_time:144753ms step_avg:92.79ms +step:1561/1645 train_time:144846ms step_avg:92.79ms +step:1562/1645 train_time:144940ms step_avg:92.79ms +step:1563/1645 train_time:145033ms step_avg:92.79ms +step:1564/1645 train_time:145127ms step_avg:92.79ms +step:1565/1645 train_time:145221ms step_avg:92.79ms +step:1566/1645 train_time:145315ms step_avg:92.79ms +step:1567/1645 train_time:145409ms step_avg:92.79ms +step:1568/1645 train_time:145502ms step_avg:92.79ms +step:1569/1645 train_time:145596ms step_avg:92.80ms +step:1570/1645 train_time:145689ms step_avg:92.80ms +step:1571/1645 train_time:145783ms step_avg:92.80ms +step:1572/1645 train_time:145876ms step_avg:92.80ms +step:1573/1645 train_time:145969ms step_avg:92.80ms +step:1574/1645 train_time:146062ms step_avg:92.80ms +step:1575/1645 train_time:146157ms step_avg:92.80ms +step:1576/1645 train_time:146250ms step_avg:92.80ms +step:1577/1645 train_time:146344ms step_avg:92.80ms +step:1578/1645 train_time:146438ms step_avg:92.80ms +step:1579/1645 train_time:146532ms step_avg:92.80ms +step:1580/1645 train_time:146627ms step_avg:92.80ms +step:1581/1645 train_time:146722ms step_avg:92.80ms +step:1582/1645 train_time:146817ms step_avg:92.80ms +step:1583/1645 train_time:146910ms step_avg:92.80ms +step:1584/1645 train_time:147003ms step_avg:92.81ms +step:1585/1645 train_time:147097ms step_avg:92.81ms +step:1586/1645 train_time:147191ms step_avg:92.81ms +step:1587/1645 train_time:147285ms step_avg:92.81ms +step:1588/1645 train_time:147378ms step_avg:92.81ms +step:1589/1645 train_time:147472ms step_avg:92.81ms +step:1590/1645 
train_time:147565ms step_avg:92.81ms +step:1591/1645 train_time:147658ms step_avg:92.81ms +step:1592/1645 train_time:147751ms step_avg:92.81ms +step:1593/1645 train_time:147844ms step_avg:92.81ms +step:1594/1645 train_time:147938ms step_avg:92.81ms +step:1595/1645 train_time:148031ms step_avg:92.81ms +step:1596/1645 train_time:148126ms step_avg:92.81ms +step:1597/1645 train_time:148221ms step_avg:92.81ms +step:1598/1645 train_time:148315ms step_avg:92.81ms +step:1599/1645 train_time:148408ms step_avg:92.81ms +step:1600/1645 train_time:148502ms step_avg:92.81ms +step:1601/1645 train_time:148596ms step_avg:92.81ms +step:1602/1645 train_time:148689ms step_avg:92.81ms +step:1603/1645 train_time:148783ms step_avg:92.82ms +step:1604/1645 train_time:148877ms step_avg:92.82ms +step:1605/1645 train_time:148970ms step_avg:92.82ms +step:1606/1645 train_time:149063ms step_avg:92.82ms +step:1607/1645 train_time:149157ms step_avg:92.82ms +step:1608/1645 train_time:149251ms step_avg:92.82ms +step:1609/1645 train_time:149345ms step_avg:92.82ms +step:1610/1645 train_time:149439ms step_avg:92.82ms +step:1611/1645 train_time:149534ms step_avg:92.82ms +step:1612/1645 train_time:149628ms step_avg:92.82ms +step:1613/1645 train_time:149722ms step_avg:92.82ms +step:1614/1645 train_time:149816ms step_avg:92.82ms +step:1615/1645 train_time:149909ms step_avg:92.82ms +step:1616/1645 train_time:150003ms step_avg:92.82ms +step:1617/1645 train_time:150096ms step_avg:92.82ms +step:1618/1645 train_time:150190ms step_avg:92.82ms +step:1619/1645 train_time:150283ms step_avg:92.82ms +step:1620/1645 train_time:150378ms step_avg:92.83ms +step:1621/1645 train_time:150472ms step_avg:92.83ms +step:1622/1645 train_time:150565ms step_avg:92.83ms +step:1623/1645 train_time:150658ms step_avg:92.83ms +step:1624/1645 train_time:150752ms step_avg:92.83ms +step:1625/1645 train_time:150846ms step_avg:92.83ms +step:1625/1645 val_loss:3.2849 train_time:150940ms step_avg:92.89ms +step:1626/1645 train_time:150963ms step_avg:92.84ms +step:1627/1645 train_time:151038ms step_avg:92.83ms +step:1628/1645 train_time:151133ms step_avg:92.83ms +step:1629/1645 train_time:151225ms step_avg:92.83ms +step:1630/1645 train_time:151317ms step_avg:92.83ms +step:1631/1645 train_time:151410ms step_avg:92.83ms +step:1632/1645 train_time:151504ms step_avg:92.83ms +step:1633/1645 train_time:151597ms step_avg:92.83ms +step:1634/1645 train_time:151689ms step_avg:92.83ms +step:1635/1645 train_time:151781ms step_avg:92.83ms +step:1636/1645 train_time:151877ms step_avg:92.83ms +step:1637/1645 train_time:151973ms step_avg:92.84ms +step:1638/1645 train_time:152068ms step_avg:92.84ms +step:1639/1645 train_time:152162ms step_avg:92.84ms +step:1640/1645 train_time:152255ms step_avg:92.84ms +step:1641/1645 train_time:152349ms step_avg:92.84ms +step:1642/1645 train_time:152442ms step_avg:92.84ms +step:1643/1645 train_time:152536ms step_avg:92.84ms +step:1644/1645 train_time:152629ms step_avg:92.84ms +step:1645/1645 train_time:152722ms step_avg:92.84ms +step:1645/1645 val_loss:3.2792 train_time:152816ms step_avg:92.90ms +peak memory allocated: 32074 MiB reserved: 47734 MiB diff --git a/train_gpt.py b/train_gpt.py index 4b2935897..a661d63d3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -872,6 +872,8 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i super().__init__() vocab_size = next_multiple_of_n(vocab_size, n=128) self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
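# the gate reads only the first 12 channels of the token embedding; with the zero weight init below and the smear_lambda scalar initialized to 0 in self.scalars, the smear starts as an exact no-op and is learned during training +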
self.smear_gate.weight.detach().zero_() # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) @@ -884,7 +886,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i self.lm_head.weight.detach().zero_() # @Grad62304977 # Add learnable skip connection weights for decoder layers assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() + pad = (-num_layers * 6) % dist.get_world_size() self.scalars = nn.Parameter( torch.cat( [ @@ -896,6 +898,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i *[ torch.tensor([0.5, 0.5]) for _ in range(num_layers) ], # SA lambdas + torch.zeros(num_layers), # extra zero params for smear_lambda torch.ones(pad), ] ) @@ -921,7 +924,13 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm] assert len(bm_sizes) == len(self.blocks) - x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) # U-net design by @brendanh0gan skip_connections = [] @@ -1144,7 +1153,7 @@ class Hyperparameters: train_max_seq_len: int = 128 * 16 val_batch_size: int = 4 * 64 * 1024 * 8 # optimization - num_iterations: int = 1660 # number of iterations to run + num_iterations: int = 1645 # number of iterations to run cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate # evaluation and logging run_id: str = f"{uuid.uuid4()}" @@ -1221,6 +1230,7 @@ def nvidia_smi(): embed_params = [p for n, p in model.named_parameters() if "embed" in n] scalar_params = [p for p in model.parameters() if p.ndim < 2] head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] # init the optimizer(s) # small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence @@ -1232,7 +1242,7 @@ def nvidia_smi(): eps=1e-8, weight_decay=0.0, ) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) optimizers = [optimizer1, optimizer2] for opt in optimizers: for group in opt.param_groups: From 2fe38e24c92f23360e6df5ef39eb7294d4358fbf Mon Sep 17 00:00:00 2001 From: larry dial Date: Sun, 21 Sep 2025 17:36:19 -0700 Subject: [PATCH 14/14] dropattn --- .../01fc4a96-f2a0-47a1-8a6a-c7d10bac99fe.txt | 3138 +++++++++++++++++ .../06350b97-4a98-47da-90c1-3b957af8af6c.txt | 3138 +++++++++++++++++ .../278c1540-3b42-4ab0-94ed-273830bdfa11.txt | 3138 +++++++++++++++++ .../79383017-eb05-4857-842a-7866b00571b4.txt | 3138 +++++++++++++++++ records/092125_DropAttn/README.md | 44 + .../a7f9849e-8c31-4e1a-9149-ab466d7c80b6.txt | 3138 +++++++++++++++++ .../ab8c620e-3d52-42eb-b46e-d69b608b22bc.txt | 3138 +++++++++++++++++ .../bc936c5a-1d9f-4405-8648-de50e4d5aca6.txt | 3138 +++++++++++++++++ .../be55679c-393d-432f-882d-287e7cfa727d.txt | 3138 +++++++++++++++++ .../d511e5c8-cce8-43ff-bac2-5168366ba47c.txt | 3138 +++++++++++++++++ .../e5a48f93-373e-4ff2-903b-5303bf912330.txt | 3138 +++++++++++++++++ train_gpt.py | 80 +- 12 files changed, 31460 insertions(+), 44 deletions(-) create mode 100644 records/092125_DropAttn/01fc4a96-f2a0-47a1-8a6a-c7d10bac99fe.txt create mode 100644 records/092125_DropAttn/06350b97-4a98-47da-90c1-3b957af8af6c.txt create mode 100644 records/092125_DropAttn/278c1540-3b42-4ab0-94ed-273830bdfa11.txt create mode 100644 records/092125_DropAttn/79383017-eb05-4857-842a-7866b00571b4.txt create mode 100644 records/092125_DropAttn/README.md create mode 100644 records/092125_DropAttn/a7f9849e-8c31-4e1a-9149-ab466d7c80b6.txt create mode 100644 records/092125_DropAttn/ab8c620e-3d52-42eb-b46e-d69b608b22bc.txt create mode 100644 records/092125_DropAttn/bc936c5a-1d9f-4405-8648-de50e4d5aca6.txt create mode 100644 records/092125_DropAttn/be55679c-393d-432f-882d-287e7cfa727d.txt create mode 100644 records/092125_DropAttn/d511e5c8-cce8-43ff-bac2-5168366ba47c.txt create mode 100644 records/092125_DropAttn/e5a48f93-373e-4ff2-903b-5303bf912330.txt diff --git a/records/092125_DropAttn/01fc4a96-f2a0-47a1-8a6a-c7d10bac99fe.txt b/records/092125_DropAttn/01fc4a96-f2a0-47a1-8a6a-c7d10bac99fe.txt new file mode 100644 index 000000000..ee60774a1 --- /dev/null +++ b/records/092125_DropAttn/01fc4a96-f2a0-47a1-8a6a-c7d10bac99fe.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
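# e.g. for M = 512 with 128x128 output blocks (one possible autotune choice), each matrix owns a 4x4 grid of 16 tiles, so pid = 21 maps to batch index 1 and local tile 5, before the swizzle below +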
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input if it must be padded. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p.
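+ # Nesterov-style momentum: lerp_ below gives buf = momentum * buf + (1 - momentum) * grad, and the tensor sent to Newton-Schulz is grad.lerp(buf, momentum) = (1 - momentum) * grad + momentum * buf.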
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
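+ # Only the first len(orig_params) rows of each gathered stack are real parameters; the padded tail rows are ignored because the loop below enumerates orig_params only.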
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
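# in GPT.forward this gate computes, for every position t >= 1, x_t += smear_lambda * sigmoid(w @ x_t[:12]) * x_{t-1}, i.e. a learned blend of the previous token's embedding into the current one +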
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), # extra zero params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ...
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
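# batch_iter counts batches served; after 5 batches next_batch swaps in the full BOS index built by the background thread (quickload only scans the first 4M tokens up front) +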
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
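+                # get() blocks until the background thread has finished loading
+                # and BOS-indexing the next shard; restarting the preloader right
+                # away lets the following shard's disk I/O overlap with training.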
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last document ran one token long to supply _targets; trim it for the length bookkeeping
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
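+
+# A minimal, self-contained sketch (illustrative only, never called) of the
+# greedy BOS-aligned packing that BOSFinder.next_batch performs for a single
+# rank: take whole documents until the local token budget is filled, truncating
+# each span to max_seq_len. `bos_idx` and `size` stand in for a toy shard here.
+def _sketch_pack_bos_aligned(bos_idx: list[int], size: int, num_tokens_local: int, max_seq_len: int):
+    starts, ends, cur_len, i = [], [], 0, 0
+    while cur_len <= num_tokens_local:
+        if i >= len(bos_idx):
+            raise StopIteration("hit tail of shard")
+        cur = bos_idx[i]
+        nxt = bos_idx[i + 1] if i + 1 < len(bos_idx) else size
+        # cap the final span so the batch ends at exactly num_tokens_local + 1
+        # tokens (the extra token supplies the shifted _targets sequence)
+        end = min(nxt, cur + max_seq_len, cur + num_tokens_local - cur_len + 1)
+        starts.append(cur)
+        ends.append(end)
+        cur_len += end - cur
+        i += 1
+    return starts, ends
+# e.g. _sketch_pack_bos_aligned([0, 5, 12, 30], size=40, num_tokens_local=16, max_seq_len=8)
+# returns starts=[0, 5, 12], ends=[5, 12, 17]: 17 tokens, i.e. the budget plus one.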
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    head_dim=128,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, (nn.Embedding, nn.Linear)):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
+smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n]
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(
+    scalar_params + head_params + embed_params,
+    lr=0.008,
+    betas=(0.8, 0.95),
+    eps=1e-8,
+    weight_decay=0.0,
+)
+optimizer2 = Muon(hidden_matrix_params + smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = min(0.9999, step / args.num_iterations)
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    if step == args.num_iterations + args.iteration_extension:
+        return args.ws_validate, args.ws_validate_final_layer
+    x = min(step / (1 + args.num_iterations), 0.9999)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+ws = args.ws_schedule[0]
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    new_ws = args.ws_schedule[step % len(args.ws_schedule)] #
each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sun Sep 21 22:33:55 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 43C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 38C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 62023 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 62024 C /usr/bin/python3 614MiB | +| 0 N/A N/A 62025 C /usr/bin/python3 614MiB | +| 0 N/A N/A 62026 C /usr/bin/python3 614MiB | +| 0 N/A N/A 62027 C /usr/bin/python3 614MiB | +| 0 N/A N/A 62028 C /usr/bin/python3 614MiB | +| 0 N/A N/A 62029 C /usr/bin/python3 614MiB | +| 0 N/A N/A 62030 C /usr/bin/python3 614MiB | +| 1 N/A N/A 62024 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 62025 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 62026 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 62027 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 62028 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 62029 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 62030 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:159ms step_avg:158.58ms +step:2/1680 train_time:182ms step_avg:90.96ms +step:3/1680 train_time:244ms step_avg:81.43ms +step:4/1680 train_time:331ms step_avg:82.63ms +step:5/1680 train_time:418ms step_avg:83.66ms +step:6/1680 train_time:506ms step_avg:84.41ms +step:7/1680 train_time:595ms step_avg:84.95ms +step:8/1680 train_time:683ms 
step_avg:85.35ms +step:9/1680 train_time:771ms step_avg:85.65ms +step:10/1680 train_time:859ms step_avg:85.93ms +step:11/1680 train_time:948ms step_avg:86.15ms +step:12/1680 train_time:1038ms step_avg:86.52ms +step:13/1680 train_time:1132ms step_avg:87.04ms +step:14/1680 train_time:1223ms step_avg:87.34ms +step:15/1680 train_time:1312ms step_avg:87.45ms +step:16/1680 train_time:1402ms step_avg:87.60ms +step:17/1680 train_time:1491ms step_avg:87.69ms +step:18/1680 train_time:1579ms step_avg:87.74ms +step:19/1680 train_time:1668ms step_avg:87.81ms +step:20/1680 train_time:1756ms step_avg:87.82ms +step:21/1680 train_time:1845ms step_avg:87.86ms +step:22/1680 train_time:1934ms step_avg:87.89ms +step:23/1680 train_time:2025ms step_avg:88.04ms +step:24/1680 train_time:2115ms step_avg:88.14ms +step:25/1680 train_time:2206ms step_avg:88.23ms +step:26/1680 train_time:2295ms step_avg:88.29ms +step:27/1680 train_time:2385ms step_avg:88.35ms +step:28/1680 train_time:2474ms step_avg:88.36ms +step:29/1680 train_time:2563ms step_avg:88.37ms +step:30/1680 train_time:2652ms step_avg:88.39ms +step:31/1680 train_time:2740ms step_avg:88.40ms +step:32/1680 train_time:2829ms step_avg:88.41ms +step:33/1680 train_time:2918ms step_avg:88.41ms +step:34/1680 train_time:3007ms step_avg:88.45ms +step:35/1680 train_time:3097ms step_avg:88.48ms +step:36/1680 train_time:3187ms step_avg:88.53ms +step:37/1680 train_time:3277ms step_avg:88.57ms +step:38/1680 train_time:3367ms step_avg:88.62ms +step:39/1680 train_time:3457ms step_avg:88.63ms +step:40/1680 train_time:3545ms step_avg:88.62ms +step:41/1680 train_time:3634ms step_avg:88.63ms +step:42/1680 train_time:3723ms step_avg:88.65ms +step:43/1680 train_time:3813ms step_avg:88.68ms +step:44/1680 train_time:3901ms step_avg:88.65ms +step:45/1680 train_time:3990ms step_avg:88.67ms +step:46/1680 train_time:4079ms step_avg:88.67ms +step:47/1680 train_time:4169ms step_avg:88.70ms +step:48/1680 train_time:4259ms step_avg:88.73ms +step:49/1680 train_time:4348ms step_avg:88.74ms +step:50/1680 train_time:4438ms step_avg:88.76ms +step:51/1680 train_time:4527ms step_avg:88.77ms +step:52/1680 train_time:4617ms step_avg:88.79ms +step:53/1680 train_time:4707ms step_avg:88.81ms +step:54/1680 train_time:4796ms step_avg:88.81ms +step:55/1680 train_time:4884ms step_avg:88.80ms +step:56/1680 train_time:4973ms step_avg:88.81ms +step:57/1680 train_time:5064ms step_avg:88.83ms +step:58/1680 train_time:5153ms step_avg:88.84ms +step:59/1680 train_time:5242ms step_avg:88.85ms +step:60/1680 train_time:5331ms step_avg:88.85ms +step:61/1680 train_time:5420ms step_avg:88.86ms +step:62/1680 train_time:5510ms step_avg:88.87ms +step:63/1680 train_time:5599ms step_avg:88.87ms +step:64/1680 train_time:5688ms step_avg:88.87ms +step:65/1680 train_time:5776ms step_avg:88.87ms +step:66/1680 train_time:5866ms step_avg:88.88ms +step:67/1680 train_time:5955ms step_avg:88.88ms +step:68/1680 train_time:6045ms step_avg:88.90ms +step:69/1680 train_time:6134ms step_avg:88.91ms +step:70/1680 train_time:6224ms step_avg:88.92ms +step:71/1680 train_time:6313ms step_avg:88.91ms +step:72/1680 train_time:6402ms step_avg:88.92ms +step:73/1680 train_time:6491ms step_avg:88.92ms +step:74/1680 train_time:6580ms step_avg:88.92ms +step:75/1680 train_time:6669ms step_avg:88.92ms +step:76/1680 train_time:6759ms step_avg:88.93ms +step:77/1680 train_time:6848ms step_avg:88.94ms +step:78/1680 train_time:6938ms step_avg:88.94ms +step:79/1680 train_time:7027ms step_avg:88.95ms +step:80/1680 train_time:7116ms step_avg:88.95ms +step:81/1680 
train_time:7206ms step_avg:88.96ms +step:82/1680 train_time:7295ms step_avg:88.96ms +step:83/1680 train_time:7384ms step_avg:88.96ms +step:84/1680 train_time:7473ms step_avg:88.97ms +step:85/1680 train_time:7562ms step_avg:88.97ms +step:86/1680 train_time:7651ms step_avg:88.97ms +step:87/1680 train_time:7740ms step_avg:88.96ms +step:88/1680 train_time:7829ms step_avg:88.96ms +step:89/1680 train_time:7918ms step_avg:88.96ms +step:90/1680 train_time:8007ms step_avg:88.96ms +step:91/1680 train_time:8096ms step_avg:88.97ms +step:92/1680 train_time:8185ms step_avg:88.97ms +step:93/1680 train_time:8275ms step_avg:88.98ms +step:94/1680 train_time:8364ms step_avg:88.98ms +step:95/1680 train_time:8453ms step_avg:88.98ms +step:96/1680 train_time:8542ms step_avg:88.98ms +step:97/1680 train_time:8631ms step_avg:88.98ms +step:98/1680 train_time:8719ms step_avg:88.97ms +step:99/1680 train_time:8808ms step_avg:88.97ms +step:100/1680 train_time:8898ms step_avg:88.98ms +step:101/1680 train_time:8987ms step_avg:88.99ms +step:102/1680 train_time:9076ms step_avg:88.98ms +step:103/1680 train_time:9167ms step_avg:89.00ms +step:104/1680 train_time:9256ms step_avg:89.00ms +step:105/1680 train_time:9346ms step_avg:89.01ms +step:106/1680 train_time:9435ms step_avg:89.01ms +step:107/1680 train_time:9524ms step_avg:89.00ms +step:108/1680 train_time:9614ms step_avg:89.01ms +step:109/1680 train_time:9704ms step_avg:89.02ms +step:110/1680 train_time:9793ms step_avg:89.02ms +step:111/1680 train_time:9881ms step_avg:89.02ms +step:112/1680 train_time:9971ms step_avg:89.02ms +step:113/1680 train_time:10060ms step_avg:89.02ms +step:114/1680 train_time:10149ms step_avg:89.02ms +step:115/1680 train_time:10238ms step_avg:89.03ms +step:116/1680 train_time:10328ms step_avg:89.03ms +step:117/1680 train_time:10416ms step_avg:89.03ms +step:118/1680 train_time:10505ms step_avg:89.03ms +step:119/1680 train_time:10594ms step_avg:89.03ms +step:120/1680 train_time:10683ms step_avg:89.02ms +step:121/1680 train_time:10772ms step_avg:89.03ms +step:122/1680 train_time:10861ms step_avg:89.02ms +step:123/1680 train_time:10949ms step_avg:89.02ms +step:124/1680 train_time:11038ms step_avg:89.02ms +step:125/1680 train_time:11127ms step_avg:89.02ms +step:125/1680 val_loss:4.3057 train_time:11217ms step_avg:89.74ms +step:126/1680 train_time:11241ms step_avg:89.21ms +step:127/1680 train_time:11309ms step_avg:89.05ms +step:128/1680 train_time:11408ms step_avg:89.13ms +step:129/1680 train_time:11501ms step_avg:89.16ms +step:130/1680 train_time:11590ms step_avg:89.16ms +step:131/1680 train_time:11679ms step_avg:89.15ms +step:132/1680 train_time:11766ms step_avg:89.14ms +step:133/1680 train_time:11854ms step_avg:89.13ms +step:134/1680 train_time:11942ms step_avg:89.12ms +step:135/1680 train_time:12030ms step_avg:89.11ms +step:136/1680 train_time:12119ms step_avg:89.11ms +step:137/1680 train_time:12209ms step_avg:89.12ms +step:138/1680 train_time:12299ms step_avg:89.12ms +step:139/1680 train_time:12391ms step_avg:89.14ms +step:140/1680 train_time:12481ms step_avg:89.15ms +step:141/1680 train_time:12571ms step_avg:89.16ms +step:142/1680 train_time:12661ms step_avg:89.16ms +step:143/1680 train_time:12749ms step_avg:89.16ms +step:144/1680 train_time:12838ms step_avg:89.15ms +step:145/1680 train_time:12926ms step_avg:89.14ms +step:146/1680 train_time:13014ms step_avg:89.14ms +step:147/1680 train_time:13102ms step_avg:89.13ms +step:148/1680 train_time:13190ms step_avg:89.12ms +step:149/1680 train_time:13280ms step_avg:89.12ms +step:150/1680 train_time:13370ms 
step_avg:89.14ms +step:151/1680 train_time:13461ms step_avg:89.14ms +step:152/1680 train_time:13550ms step_avg:89.15ms +step:153/1680 train_time:13640ms step_avg:89.15ms +step:154/1680 train_time:13729ms step_avg:89.15ms +step:155/1680 train_time:13820ms step_avg:89.16ms +step:156/1680 train_time:13908ms step_avg:89.15ms +step:157/1680 train_time:13997ms step_avg:89.15ms +step:158/1680 train_time:14085ms step_avg:89.15ms +step:159/1680 train_time:14173ms step_avg:89.14ms +step:160/1680 train_time:14262ms step_avg:89.14ms +step:161/1680 train_time:14351ms step_avg:89.14ms +step:162/1680 train_time:14441ms step_avg:89.14ms +step:163/1680 train_time:14531ms step_avg:89.15ms +step:164/1680 train_time:14621ms step_avg:89.15ms +step:165/1680 train_time:14710ms step_avg:89.15ms +step:166/1680 train_time:14800ms step_avg:89.16ms +step:167/1680 train_time:14888ms step_avg:89.15ms +step:168/1680 train_time:14976ms step_avg:89.15ms +step:169/1680 train_time:15065ms step_avg:89.14ms +step:170/1680 train_time:15154ms step_avg:89.14ms +step:171/1680 train_time:15243ms step_avg:89.14ms +step:172/1680 train_time:15332ms step_avg:89.14ms +step:173/1680 train_time:15421ms step_avg:89.14ms +step:174/1680 train_time:15511ms step_avg:89.14ms +step:175/1680 train_time:15600ms step_avg:89.14ms +step:176/1680 train_time:15689ms step_avg:89.14ms +step:177/1680 train_time:15778ms step_avg:89.14ms +step:178/1680 train_time:15867ms step_avg:89.14ms +step:179/1680 train_time:15956ms step_avg:89.14ms +step:180/1680 train_time:16044ms step_avg:89.13ms +step:181/1680 train_time:16133ms step_avg:89.13ms +step:182/1680 train_time:16222ms step_avg:89.13ms +step:183/1680 train_time:16312ms step_avg:89.14ms +step:184/1680 train_time:16401ms step_avg:89.14ms +step:185/1680 train_time:16490ms step_avg:89.14ms +step:186/1680 train_time:16579ms step_avg:89.14ms +step:187/1680 train_time:16668ms step_avg:89.14ms +step:188/1680 train_time:16758ms step_avg:89.14ms +step:189/1680 train_time:16846ms step_avg:89.13ms +step:190/1680 train_time:16935ms step_avg:89.13ms +step:191/1680 train_time:17025ms step_avg:89.13ms +step:192/1680 train_time:17114ms step_avg:89.13ms +step:193/1680 train_time:17203ms step_avg:89.13ms +step:194/1680 train_time:17291ms step_avg:89.13ms +step:195/1680 train_time:17380ms step_avg:89.13ms +step:196/1680 train_time:17468ms step_avg:89.12ms +step:197/1680 train_time:17557ms step_avg:89.12ms +step:198/1680 train_time:17646ms step_avg:89.12ms +step:199/1680 train_time:17736ms step_avg:89.12ms +step:200/1680 train_time:17825ms step_avg:89.12ms +step:201/1680 train_time:17914ms step_avg:89.12ms +step:202/1680 train_time:18003ms step_avg:89.12ms +step:203/1680 train_time:18092ms step_avg:89.12ms +step:204/1680 train_time:18181ms step_avg:89.12ms +step:205/1680 train_time:18270ms step_avg:89.12ms +step:206/1680 train_time:18359ms step_avg:89.12ms +step:207/1680 train_time:18449ms step_avg:89.12ms +step:208/1680 train_time:18538ms step_avg:89.13ms +step:209/1680 train_time:18627ms step_avg:89.12ms +step:210/1680 train_time:18716ms step_avg:89.13ms +step:211/1680 train_time:18806ms step_avg:89.13ms +step:212/1680 train_time:18895ms step_avg:89.13ms +step:213/1680 train_time:18984ms step_avg:89.13ms +step:214/1680 train_time:19073ms step_avg:89.13ms +step:215/1680 train_time:19162ms step_avg:89.13ms +step:216/1680 train_time:19252ms step_avg:89.13ms +step:217/1680 train_time:19340ms step_avg:89.13ms +step:218/1680 train_time:19430ms step_avg:89.13ms +step:219/1680 train_time:19520ms step_avg:89.13ms +step:220/1680 
train_time:19609ms step_avg:89.13ms +step:221/1680 train_time:19698ms step_avg:89.13ms +step:222/1680 train_time:19786ms step_avg:89.13ms +step:223/1680 train_time:19874ms step_avg:89.12ms +step:224/1680 train_time:19963ms step_avg:89.12ms +step:225/1680 train_time:20052ms step_avg:89.12ms +step:226/1680 train_time:20141ms step_avg:89.12ms +step:227/1680 train_time:20231ms step_avg:89.12ms +step:228/1680 train_time:20320ms step_avg:89.12ms +step:229/1680 train_time:20410ms step_avg:89.13ms +step:230/1680 train_time:20498ms step_avg:89.12ms +step:231/1680 train_time:20588ms step_avg:89.13ms +step:232/1680 train_time:20677ms step_avg:89.12ms +step:233/1680 train_time:20765ms step_avg:89.12ms +step:234/1680 train_time:20853ms step_avg:89.12ms +step:235/1680 train_time:20942ms step_avg:89.12ms +step:236/1680 train_time:21031ms step_avg:89.12ms +step:237/1680 train_time:21121ms step_avg:89.12ms +step:238/1680 train_time:21210ms step_avg:89.12ms +step:239/1680 train_time:21299ms step_avg:89.12ms +step:240/1680 train_time:21389ms step_avg:89.12ms +step:241/1680 train_time:21478ms step_avg:89.12ms +step:242/1680 train_time:21567ms step_avg:89.12ms +step:243/1680 train_time:21657ms step_avg:89.12ms +step:244/1680 train_time:21745ms step_avg:89.12ms +step:245/1680 train_time:21834ms step_avg:89.12ms +step:246/1680 train_time:21923ms step_avg:89.12ms +step:247/1680 train_time:22012ms step_avg:89.12ms +step:248/1680 train_time:22101ms step_avg:89.12ms +step:249/1680 train_time:22189ms step_avg:89.11ms +step:250/1680 train_time:22279ms step_avg:89.11ms +step:250/1680 val_loss:3.9744 train_time:22370ms step_avg:89.48ms +step:251/1680 train_time:22393ms step_avg:89.21ms +step:252/1680 train_time:22463ms step_avg:89.14ms +step:253/1680 train_time:22558ms step_avg:89.16ms +step:254/1680 train_time:22650ms step_avg:89.17ms +step:255/1680 train_time:22738ms step_avg:89.17ms +step:256/1680 train_time:22826ms step_avg:89.16ms +step:257/1680 train_time:22913ms step_avg:89.16ms +step:258/1680 train_time:23001ms step_avg:89.15ms +step:259/1680 train_time:23088ms step_avg:89.14ms +step:260/1680 train_time:23176ms step_avg:89.14ms +step:261/1680 train_time:23264ms step_avg:89.13ms +step:262/1680 train_time:23352ms step_avg:89.13ms +step:263/1680 train_time:23442ms step_avg:89.13ms +step:264/1680 train_time:23533ms step_avg:89.14ms +step:265/1680 train_time:23623ms step_avg:89.14ms +step:266/1680 train_time:23713ms step_avg:89.15ms +step:267/1680 train_time:23801ms step_avg:89.14ms +step:268/1680 train_time:23890ms step_avg:89.14ms +step:269/1680 train_time:23977ms step_avg:89.13ms +step:270/1680 train_time:24066ms step_avg:89.13ms +step:271/1680 train_time:24154ms step_avg:89.13ms +step:272/1680 train_time:24242ms step_avg:89.12ms +step:273/1680 train_time:24329ms step_avg:89.12ms +step:274/1680 train_time:24418ms step_avg:89.12ms +step:275/1680 train_time:24508ms step_avg:89.12ms +step:276/1680 train_time:24598ms step_avg:89.12ms +step:277/1680 train_time:24688ms step_avg:89.13ms +step:278/1680 train_time:24777ms step_avg:89.12ms +step:279/1680 train_time:24865ms step_avg:89.12ms +step:280/1680 train_time:24954ms step_avg:89.12ms +step:281/1680 train_time:25042ms step_avg:89.12ms +step:282/1680 train_time:25131ms step_avg:89.12ms +step:283/1680 train_time:25219ms step_avg:89.11ms +step:284/1680 train_time:25308ms step_avg:89.11ms +step:285/1680 train_time:25397ms step_avg:89.11ms +step:286/1680 train_time:25487ms step_avg:89.11ms +step:287/1680 train_time:25576ms step_avg:89.12ms +step:288/1680 train_time:25666ms 
step_avg:89.12ms +step:289/1680 train_time:25755ms step_avg:89.12ms +step:290/1680 train_time:25845ms step_avg:89.12ms +step:291/1680 train_time:25934ms step_avg:89.12ms +step:292/1680 train_time:26023ms step_avg:89.12ms +step:293/1680 train_time:26111ms step_avg:89.12ms +step:294/1680 train_time:26200ms step_avg:89.11ms +step:295/1680 train_time:26288ms step_avg:89.11ms +step:296/1680 train_time:26377ms step_avg:89.11ms +step:297/1680 train_time:26466ms step_avg:89.11ms +step:298/1680 train_time:26556ms step_avg:89.11ms +step:299/1680 train_time:26646ms step_avg:89.12ms +step:300/1680 train_time:26736ms step_avg:89.12ms +step:301/1680 train_time:26825ms step_avg:89.12ms +step:302/1680 train_time:26914ms step_avg:89.12ms +step:303/1680 train_time:27003ms step_avg:89.12ms +step:304/1680 train_time:27091ms step_avg:89.12ms +step:305/1680 train_time:27180ms step_avg:89.11ms +step:306/1680 train_time:27268ms step_avg:89.11ms +step:307/1680 train_time:27356ms step_avg:89.11ms +step:308/1680 train_time:27445ms step_avg:89.11ms +step:309/1680 train_time:27535ms step_avg:89.11ms +step:310/1680 train_time:27624ms step_avg:89.11ms +step:311/1680 train_time:27713ms step_avg:89.11ms +step:312/1680 train_time:27803ms step_avg:89.11ms +step:313/1680 train_time:27892ms step_avg:89.11ms +step:314/1680 train_time:27981ms step_avg:89.11ms +step:315/1680 train_time:28070ms step_avg:89.11ms +step:316/1680 train_time:28158ms step_avg:89.11ms +step:317/1680 train_time:28247ms step_avg:89.11ms +step:318/1680 train_time:28337ms step_avg:89.11ms +step:319/1680 train_time:28426ms step_avg:89.11ms +step:320/1680 train_time:28514ms step_avg:89.11ms +step:321/1680 train_time:28604ms step_avg:89.11ms +step:322/1680 train_time:28693ms step_avg:89.11ms +step:323/1680 train_time:28782ms step_avg:89.11ms +step:324/1680 train_time:28872ms step_avg:89.11ms +step:325/1680 train_time:28960ms step_avg:89.11ms +step:326/1680 train_time:29049ms step_avg:89.11ms +step:327/1680 train_time:29138ms step_avg:89.11ms +step:328/1680 train_time:29227ms step_avg:89.11ms +step:329/1680 train_time:29316ms step_avg:89.10ms +step:330/1680 train_time:29404ms step_avg:89.10ms +step:331/1680 train_time:29494ms step_avg:89.11ms +step:332/1680 train_time:29584ms step_avg:89.11ms +step:333/1680 train_time:29673ms step_avg:89.11ms +step:334/1680 train_time:29763ms step_avg:89.11ms +step:335/1680 train_time:29852ms step_avg:89.11ms +step:336/1680 train_time:29941ms step_avg:89.11ms +step:337/1680 train_time:30030ms step_avg:89.11ms +step:338/1680 train_time:30119ms step_avg:89.11ms +step:339/1680 train_time:30208ms step_avg:89.11ms +step:340/1680 train_time:30297ms step_avg:89.11ms +step:341/1680 train_time:30385ms step_avg:89.11ms +step:342/1680 train_time:30475ms step_avg:89.11ms +step:343/1680 train_time:30565ms step_avg:89.11ms +step:344/1680 train_time:30655ms step_avg:89.11ms +step:345/1680 train_time:30743ms step_avg:89.11ms +step:346/1680 train_time:30833ms step_avg:89.11ms +step:347/1680 train_time:30923ms step_avg:89.11ms +step:348/1680 train_time:31011ms step_avg:89.11ms +step:349/1680 train_time:31100ms step_avg:89.11ms +step:350/1680 train_time:31189ms step_avg:89.11ms +step:351/1680 train_time:31278ms step_avg:89.11ms +step:352/1680 train_time:31367ms step_avg:89.11ms +step:353/1680 train_time:31457ms step_avg:89.11ms +step:354/1680 train_time:31545ms step_avg:89.11ms +step:355/1680 train_time:31635ms step_avg:89.11ms +step:356/1680 train_time:31724ms step_avg:89.11ms +step:357/1680 train_time:31813ms step_avg:89.11ms +step:358/1680 
train_time:31902ms step_avg:89.11ms +step:359/1680 train_time:31992ms step_avg:89.11ms +step:360/1680 train_time:32081ms step_avg:89.11ms +step:361/1680 train_time:32170ms step_avg:89.11ms +step:362/1680 train_time:32259ms step_avg:89.11ms +step:363/1680 train_time:32348ms step_avg:89.11ms +step:364/1680 train_time:32436ms step_avg:89.11ms +step:365/1680 train_time:32525ms step_avg:89.11ms +step:366/1680 train_time:32614ms step_avg:89.11ms +step:367/1680 train_time:32703ms step_avg:89.11ms +step:368/1680 train_time:32791ms step_avg:89.11ms +step:369/1680 train_time:32880ms step_avg:89.11ms +step:370/1680 train_time:32968ms step_avg:89.10ms +step:371/1680 train_time:33057ms step_avg:89.10ms +step:372/1680 train_time:33146ms step_avg:89.10ms +step:373/1680 train_time:33236ms step_avg:89.10ms +step:374/1680 train_time:33325ms step_avg:89.10ms +step:375/1680 train_time:33414ms step_avg:89.10ms +step:375/1680 val_loss:3.8185 train_time:33504ms step_avg:89.34ms +step:376/1680 train_time:33527ms step_avg:89.17ms +step:377/1680 train_time:33595ms step_avg:89.11ms +step:378/1680 train_time:33691ms step_avg:89.13ms +step:379/1680 train_time:33783ms step_avg:89.14ms +step:380/1680 train_time:33871ms step_avg:89.14ms +step:381/1680 train_time:33959ms step_avg:89.13ms +step:382/1680 train_time:34047ms step_avg:89.13ms +step:383/1680 train_time:34135ms step_avg:89.13ms +step:384/1680 train_time:34223ms step_avg:89.12ms +step:385/1680 train_time:34311ms step_avg:89.12ms +step:386/1680 train_time:34398ms step_avg:89.12ms +step:387/1680 train_time:34487ms step_avg:89.11ms +step:388/1680 train_time:34577ms step_avg:89.12ms +step:389/1680 train_time:34669ms step_avg:89.12ms +step:390/1680 train_time:34759ms step_avg:89.13ms +step:391/1680 train_time:34848ms step_avg:89.13ms +step:392/1680 train_time:34937ms step_avg:89.13ms +step:393/1680 train_time:35026ms step_avg:89.12ms +step:394/1680 train_time:35113ms step_avg:89.12ms +step:395/1680 train_time:35202ms step_avg:89.12ms +step:396/1680 train_time:35291ms step_avg:89.12ms +step:397/1680 train_time:35379ms step_avg:89.12ms +step:398/1680 train_time:35467ms step_avg:89.11ms +step:399/1680 train_time:35556ms step_avg:89.11ms +step:400/1680 train_time:35647ms step_avg:89.12ms +step:401/1680 train_time:35737ms step_avg:89.12ms +step:402/1680 train_time:35828ms step_avg:89.12ms +step:403/1680 train_time:35916ms step_avg:89.12ms +step:404/1680 train_time:36005ms step_avg:89.12ms +step:405/1680 train_time:36094ms step_avg:89.12ms +step:406/1680 train_time:36183ms step_avg:89.12ms +step:407/1680 train_time:36270ms step_avg:89.12ms +step:408/1680 train_time:36358ms step_avg:89.11ms +step:409/1680 train_time:36447ms step_avg:89.11ms +step:410/1680 train_time:36535ms step_avg:89.11ms +step:411/1680 train_time:36624ms step_avg:89.11ms +step:412/1680 train_time:36714ms step_avg:89.11ms +step:413/1680 train_time:36804ms step_avg:89.11ms +step:414/1680 train_time:36894ms step_avg:89.12ms +step:415/1680 train_time:36983ms step_avg:89.12ms +step:416/1680 train_time:37072ms step_avg:89.12ms +step:417/1680 train_time:37162ms step_avg:89.12ms +step:418/1680 train_time:37250ms step_avg:89.11ms +step:419/1680 train_time:37338ms step_avg:89.11ms +step:420/1680 train_time:37427ms step_avg:89.11ms +step:421/1680 train_time:37516ms step_avg:89.11ms +step:422/1680 train_time:37604ms step_avg:89.11ms +step:423/1680 train_time:37693ms step_avg:89.11ms +step:424/1680 train_time:37782ms step_avg:89.11ms +step:425/1680 train_time:37872ms step_avg:89.11ms +step:426/1680 train_time:37960ms 
step_avg:89.11ms +step:427/1680 train_time:38050ms step_avg:89.11ms +step:428/1680 train_time:38139ms step_avg:89.11ms +step:429/1680 train_time:38228ms step_avg:89.11ms +step:430/1680 train_time:38317ms step_avg:89.11ms +step:431/1680 train_time:38406ms step_avg:89.11ms +step:432/1680 train_time:38495ms step_avg:89.11ms +step:433/1680 train_time:38584ms step_avg:89.11ms +step:434/1680 train_time:38674ms step_avg:89.11ms +step:435/1680 train_time:38763ms step_avg:89.11ms +step:436/1680 train_time:38852ms step_avg:89.11ms +step:437/1680 train_time:38942ms step_avg:89.11ms +step:438/1680 train_time:39031ms step_avg:89.11ms +step:439/1680 train_time:39120ms step_avg:89.11ms +step:440/1680 train_time:39209ms step_avg:89.11ms +step:441/1680 train_time:39298ms step_avg:89.11ms +step:442/1680 train_time:39387ms step_avg:89.11ms +step:443/1680 train_time:39476ms step_avg:89.11ms +step:444/1680 train_time:39564ms step_avg:89.11ms +step:445/1680 train_time:39654ms step_avg:89.11ms +step:446/1680 train_time:39743ms step_avg:89.11ms +step:447/1680 train_time:39833ms step_avg:89.11ms +step:448/1680 train_time:39922ms step_avg:89.11ms +step:449/1680 train_time:40011ms step_avg:89.11ms +step:450/1680 train_time:40100ms step_avg:89.11ms +step:451/1680 train_time:40189ms step_avg:89.11ms +step:452/1680 train_time:40277ms step_avg:89.11ms +step:453/1680 train_time:40366ms step_avg:89.11ms +step:454/1680 train_time:40455ms step_avg:89.11ms +step:455/1680 train_time:40544ms step_avg:89.11ms +step:456/1680 train_time:40634ms step_avg:89.11ms +step:457/1680 train_time:40724ms step_avg:89.11ms +step:458/1680 train_time:40813ms step_avg:89.11ms +step:459/1680 train_time:40902ms step_avg:89.11ms +step:460/1680 train_time:40991ms step_avg:89.11ms +step:461/1680 train_time:41081ms step_avg:89.11ms +step:462/1680 train_time:41169ms step_avg:89.11ms +step:463/1680 train_time:41258ms step_avg:89.11ms +step:464/1680 train_time:41347ms step_avg:89.11ms +step:465/1680 train_time:41436ms step_avg:89.11ms +step:466/1680 train_time:41526ms step_avg:89.11ms +step:467/1680 train_time:41616ms step_avg:89.11ms +step:468/1680 train_time:41703ms step_avg:89.11ms +step:469/1680 train_time:41793ms step_avg:89.11ms +step:470/1680 train_time:41882ms step_avg:89.11ms +step:471/1680 train_time:41972ms step_avg:89.11ms +step:472/1680 train_time:42061ms step_avg:89.11ms +step:473/1680 train_time:42150ms step_avg:89.11ms +step:474/1680 train_time:42239ms step_avg:89.11ms +step:475/1680 train_time:42328ms step_avg:89.11ms +step:476/1680 train_time:42417ms step_avg:89.11ms +step:477/1680 train_time:42506ms step_avg:89.11ms +step:478/1680 train_time:42595ms step_avg:89.11ms +step:479/1680 train_time:42683ms step_avg:89.11ms +step:480/1680 train_time:42772ms step_avg:89.11ms +step:481/1680 train_time:42861ms step_avg:89.11ms +step:482/1680 train_time:42951ms step_avg:89.11ms +step:483/1680 train_time:43041ms step_avg:89.11ms +step:484/1680 train_time:43130ms step_avg:89.11ms +step:485/1680 train_time:43218ms step_avg:89.11ms +step:486/1680 train_time:43308ms step_avg:89.11ms +step:487/1680 train_time:43397ms step_avg:89.11ms +step:488/1680 train_time:43485ms step_avg:89.11ms +step:489/1680 train_time:43573ms step_avg:89.11ms +step:490/1680 train_time:43662ms step_avg:89.11ms +step:491/1680 train_time:43752ms step_avg:89.11ms +step:492/1680 train_time:43841ms step_avg:89.11ms +step:493/1680 train_time:43930ms step_avg:89.11ms +step:494/1680 train_time:44020ms step_avg:89.11ms +step:495/1680 train_time:44109ms step_avg:89.11ms +step:496/1680 
train_time:44198ms step_avg:89.11ms +step:497/1680 train_time:44288ms step_avg:89.11ms +step:498/1680 train_time:44377ms step_avg:89.11ms +step:499/1680 train_time:44466ms step_avg:89.11ms +step:500/1680 train_time:44555ms step_avg:89.11ms +step:500/1680 val_loss:3.7194 train_time:44645ms step_avg:89.29ms +step:501/1680 train_time:44667ms step_avg:89.16ms +step:502/1680 train_time:44736ms step_avg:89.12ms +step:503/1680 train_time:44834ms step_avg:89.13ms +step:504/1680 train_time:44924ms step_avg:89.13ms +step:505/1680 train_time:45013ms step_avg:89.13ms +step:506/1680 train_time:45101ms step_avg:89.13ms +step:507/1680 train_time:45190ms step_avg:89.13ms +step:508/1680 train_time:45278ms step_avg:89.13ms +step:509/1680 train_time:45366ms step_avg:89.13ms +step:510/1680 train_time:45454ms step_avg:89.13ms +step:511/1680 train_time:45542ms step_avg:89.12ms +step:512/1680 train_time:45631ms step_avg:89.12ms +step:513/1680 train_time:45721ms step_avg:89.13ms +step:514/1680 train_time:45813ms step_avg:89.13ms +step:515/1680 train_time:45904ms step_avg:89.13ms +step:516/1680 train_time:45994ms step_avg:89.14ms +step:517/1680 train_time:46083ms step_avg:89.13ms +step:518/1680 train_time:46171ms step_avg:89.13ms +step:519/1680 train_time:46260ms step_avg:89.13ms +step:520/1680 train_time:46349ms step_avg:89.13ms +step:521/1680 train_time:46437ms step_avg:89.13ms +step:522/1680 train_time:46525ms step_avg:89.13ms +step:523/1680 train_time:46613ms step_avg:89.13ms +step:524/1680 train_time:46702ms step_avg:89.13ms +step:525/1680 train_time:46793ms step_avg:89.13ms +step:526/1680 train_time:46885ms step_avg:89.13ms +step:527/1680 train_time:46976ms step_avg:89.14ms +step:528/1680 train_time:47062ms step_avg:89.13ms +step:529/1680 train_time:47151ms step_avg:89.13ms +step:530/1680 train_time:47240ms step_avg:89.13ms +step:531/1680 train_time:47329ms step_avg:89.13ms +step:532/1680 train_time:47418ms step_avg:89.13ms +step:533/1680 train_time:47506ms step_avg:89.13ms +step:534/1680 train_time:47596ms step_avg:89.13ms +step:535/1680 train_time:47685ms step_avg:89.13ms +step:536/1680 train_time:47774ms step_avg:89.13ms +step:537/1680 train_time:47864ms step_avg:89.13ms +step:538/1680 train_time:47954ms step_avg:89.13ms +step:539/1680 train_time:48043ms step_avg:89.13ms +step:540/1680 train_time:48131ms step_avg:89.13ms +step:541/1680 train_time:48221ms step_avg:89.13ms +step:542/1680 train_time:48309ms step_avg:89.13ms +step:543/1680 train_time:48398ms step_avg:89.13ms +step:544/1680 train_time:48487ms step_avg:89.13ms +step:545/1680 train_time:48575ms step_avg:89.13ms +step:546/1680 train_time:48665ms step_avg:89.13ms +step:547/1680 train_time:48754ms step_avg:89.13ms +step:548/1680 train_time:48843ms step_avg:89.13ms +step:549/1680 train_time:48934ms step_avg:89.13ms +step:550/1680 train_time:49024ms step_avg:89.14ms +step:551/1680 train_time:49114ms step_avg:89.14ms +step:552/1680 train_time:49203ms step_avg:89.14ms +step:553/1680 train_time:49293ms step_avg:89.14ms +step:554/1680 train_time:49383ms step_avg:89.14ms +step:555/1680 train_time:49473ms step_avg:89.14ms +step:556/1680 train_time:49563ms step_avg:89.14ms +step:557/1680 train_time:49653ms step_avg:89.14ms +step:558/1680 train_time:49744ms step_avg:89.15ms +step:559/1680 train_time:49835ms step_avg:89.15ms +step:560/1680 train_time:49925ms step_avg:89.15ms +step:561/1680 train_time:50016ms step_avg:89.16ms +step:562/1680 train_time:50106ms step_avg:89.16ms +step:563/1680 train_time:50197ms step_avg:89.16ms +step:564/1680 train_time:50288ms 
step_avg:89.16ms +step:565/1680 train_time:50377ms step_avg:89.16ms +step:566/1680 train_time:50468ms step_avg:89.17ms +step:567/1680 train_time:50557ms step_avg:89.17ms +step:568/1680 train_time:50649ms step_avg:89.17ms +step:569/1680 train_time:50739ms step_avg:89.17ms +step:570/1680 train_time:50829ms step_avg:89.17ms +step:571/1680 train_time:50920ms step_avg:89.18ms +step:572/1680 train_time:51011ms step_avg:89.18ms +step:573/1680 train_time:51101ms step_avg:89.18ms +step:574/1680 train_time:51191ms step_avg:89.18ms +step:575/1680 train_time:51282ms step_avg:89.19ms +step:576/1680 train_time:51372ms step_avg:89.19ms +step:577/1680 train_time:51462ms step_avg:89.19ms +step:578/1680 train_time:51552ms step_avg:89.19ms +step:579/1680 train_time:51643ms step_avg:89.19ms +step:580/1680 train_time:51733ms step_avg:89.19ms +step:581/1680 train_time:51823ms step_avg:89.20ms +step:582/1680 train_time:51913ms step_avg:89.20ms +step:583/1680 train_time:52004ms step_avg:89.20ms +step:584/1680 train_time:52094ms step_avg:89.20ms +step:585/1680 train_time:52184ms step_avg:89.20ms +step:586/1680 train_time:52274ms step_avg:89.21ms +step:587/1680 train_time:52367ms step_avg:89.21ms +step:588/1680 train_time:52457ms step_avg:89.21ms +step:589/1680 train_time:52547ms step_avg:89.21ms +step:590/1680 train_time:52637ms step_avg:89.22ms +step:591/1680 train_time:52727ms step_avg:89.22ms +step:592/1680 train_time:52817ms step_avg:89.22ms +step:593/1680 train_time:52907ms step_avg:89.22ms +step:594/1680 train_time:52997ms step_avg:89.22ms +step:595/1680 train_time:53089ms step_avg:89.22ms +step:596/1680 train_time:53178ms step_avg:89.23ms +step:597/1680 train_time:53274ms step_avg:89.24ms +step:598/1680 train_time:53359ms step_avg:89.23ms +step:599/1680 train_time:53451ms step_avg:89.23ms +step:600/1680 train_time:53541ms step_avg:89.23ms +step:601/1680 train_time:53631ms step_avg:89.24ms +step:602/1680 train_time:53722ms step_avg:89.24ms +step:603/1680 train_time:53811ms step_avg:89.24ms +step:604/1680 train_time:53902ms step_avg:89.24ms +step:605/1680 train_time:53992ms step_avg:89.24ms +step:606/1680 train_time:54083ms step_avg:89.25ms +step:607/1680 train_time:54173ms step_avg:89.25ms +step:608/1680 train_time:54263ms step_avg:89.25ms +step:609/1680 train_time:54353ms step_avg:89.25ms +step:610/1680 train_time:54443ms step_avg:89.25ms +step:611/1680 train_time:54534ms step_avg:89.25ms +step:612/1680 train_time:54624ms step_avg:89.25ms +step:613/1680 train_time:54714ms step_avg:89.26ms +step:614/1680 train_time:54803ms step_avg:89.26ms +step:615/1680 train_time:54893ms step_avg:89.26ms +step:616/1680 train_time:54984ms step_avg:89.26ms +step:617/1680 train_time:55075ms step_avg:89.26ms +step:618/1680 train_time:55166ms step_avg:89.27ms +step:619/1680 train_time:55256ms step_avg:89.27ms +step:620/1680 train_time:55348ms step_avg:89.27ms +step:621/1680 train_time:55438ms step_avg:89.27ms +step:622/1680 train_time:55528ms step_avg:89.27ms +step:623/1680 train_time:55618ms step_avg:89.27ms +step:624/1680 train_time:55708ms step_avg:89.27ms +step:625/1680 train_time:55798ms step_avg:89.28ms +step:625/1680 val_loss:3.6190 train_time:55890ms step_avg:89.42ms +step:626/1680 train_time:55913ms step_avg:89.32ms +step:627/1680 train_time:55981ms step_avg:89.28ms +step:628/1680 train_time:56079ms step_avg:89.30ms +step:629/1680 train_time:56171ms step_avg:89.30ms +step:630/1680 train_time:56260ms step_avg:89.30ms +step:631/1680 train_time:56349ms step_avg:89.30ms +step:632/1680 train_time:56437ms step_avg:89.30ms 
+step:633/1680 train_time:56526ms step_avg:89.30ms +step:634/1680 train_time:56615ms step_avg:89.30ms +step:635/1680 train_time:56704ms step_avg:89.30ms +step:636/1680 train_time:56795ms step_avg:89.30ms +step:637/1680 train_time:56887ms step_avg:89.30ms +step:638/1680 train_time:56979ms step_avg:89.31ms +step:639/1680 train_time:57072ms step_avg:89.31ms +step:640/1680 train_time:57163ms step_avg:89.32ms +step:641/1680 train_time:57253ms step_avg:89.32ms +step:642/1680 train_time:57343ms step_avg:89.32ms +step:643/1680 train_time:57432ms step_avg:89.32ms +step:644/1680 train_time:57521ms step_avg:89.32ms +step:645/1680 train_time:57611ms step_avg:89.32ms +step:646/1680 train_time:57700ms step_avg:89.32ms +step:647/1680 train_time:57791ms step_avg:89.32ms +step:648/1680 train_time:57883ms step_avg:89.33ms +step:649/1680 train_time:57973ms step_avg:89.33ms +step:650/1680 train_time:58064ms step_avg:89.33ms +step:651/1680 train_time:58154ms step_avg:89.33ms +step:652/1680 train_time:58245ms step_avg:89.33ms +step:653/1680 train_time:58335ms step_avg:89.33ms +step:654/1680 train_time:58425ms step_avg:89.33ms +step:655/1680 train_time:58514ms step_avg:89.33ms +step:656/1680 train_time:58604ms step_avg:89.33ms +step:657/1680 train_time:58693ms step_avg:89.33ms +step:658/1680 train_time:58783ms step_avg:89.34ms +step:659/1680 train_time:58873ms step_avg:89.34ms +step:660/1680 train_time:58965ms step_avg:89.34ms +step:661/1680 train_time:59056ms step_avg:89.34ms +step:662/1680 train_time:59146ms step_avg:89.34ms +step:663/1680 train_time:59237ms step_avg:89.35ms +step:664/1680 train_time:59327ms step_avg:89.35ms +step:665/1680 train_time:59417ms step_avg:89.35ms +step:666/1680 train_time:59506ms step_avg:89.35ms +step:667/1680 train_time:59596ms step_avg:89.35ms +step:668/1680 train_time:59685ms step_avg:89.35ms +step:669/1680 train_time:59775ms step_avg:89.35ms +step:670/1680 train_time:59866ms step_avg:89.35ms +step:671/1680 train_time:59956ms step_avg:89.35ms +step:672/1680 train_time:60049ms step_avg:89.36ms +step:673/1680 train_time:60138ms step_avg:89.36ms +step:674/1680 train_time:60229ms step_avg:89.36ms +step:675/1680 train_time:60319ms step_avg:89.36ms +step:676/1680 train_time:60408ms step_avg:89.36ms +step:677/1680 train_time:60499ms step_avg:89.36ms +step:678/1680 train_time:60588ms step_avg:89.36ms +step:679/1680 train_time:60683ms step_avg:89.37ms +step:680/1680 train_time:60767ms step_avg:89.36ms +step:681/1680 train_time:60858ms step_avg:89.36ms +step:682/1680 train_time:60949ms step_avg:89.37ms +step:683/1680 train_time:61039ms step_avg:89.37ms +step:684/1680 train_time:61129ms step_avg:89.37ms +step:685/1680 train_time:61219ms step_avg:89.37ms +step:686/1680 train_time:61309ms step_avg:89.37ms +step:687/1680 train_time:61400ms step_avg:89.37ms +step:688/1680 train_time:61490ms step_avg:89.38ms +step:689/1680 train_time:61580ms step_avg:89.38ms +step:690/1680 train_time:61670ms step_avg:89.38ms +step:691/1680 train_time:61760ms step_avg:89.38ms +step:692/1680 train_time:61851ms step_avg:89.38ms +step:693/1680 train_time:61941ms step_avg:89.38ms +step:694/1680 train_time:62033ms step_avg:89.38ms +step:695/1680 train_time:62123ms step_avg:89.39ms +step:696/1680 train_time:62213ms step_avg:89.39ms +step:697/1680 train_time:62303ms step_avg:89.39ms +step:698/1680 train_time:62394ms step_avg:89.39ms +step:699/1680 train_time:62483ms step_avg:89.39ms +step:700/1680 train_time:62573ms step_avg:89.39ms +step:701/1680 train_time:62663ms step_avg:89.39ms +step:702/1680 train_time:62754ms 
step_avg:89.39ms +step:703/1680 train_time:62844ms step_avg:89.39ms +step:704/1680 train_time:62934ms step_avg:89.40ms +step:705/1680 train_time:63024ms step_avg:89.40ms +step:706/1680 train_time:63115ms step_avg:89.40ms +step:707/1680 train_time:63205ms step_avg:89.40ms +step:708/1680 train_time:63295ms step_avg:89.40ms +step:709/1680 train_time:63388ms step_avg:89.41ms +step:710/1680 train_time:63475ms step_avg:89.40ms +step:711/1680 train_time:63566ms step_avg:89.40ms +step:712/1680 train_time:63655ms step_avg:89.40ms +step:713/1680 train_time:63745ms step_avg:89.40ms +step:714/1680 train_time:63835ms step_avg:89.40ms +step:715/1680 train_time:63925ms step_avg:89.41ms +step:716/1680 train_time:64015ms step_avg:89.41ms +step:717/1680 train_time:64105ms step_avg:89.41ms +step:718/1680 train_time:64195ms step_avg:89.41ms +step:719/1680 train_time:64285ms step_avg:89.41ms +step:720/1680 train_time:64375ms step_avg:89.41ms +step:721/1680 train_time:64466ms step_avg:89.41ms +step:722/1680 train_time:64556ms step_avg:89.41ms +step:723/1680 train_time:64645ms step_avg:89.41ms +step:724/1680 train_time:64735ms step_avg:89.41ms +step:725/1680 train_time:64825ms step_avg:89.41ms +step:726/1680 train_time:64915ms step_avg:89.41ms +step:727/1680 train_time:65005ms step_avg:89.42ms +step:728/1680 train_time:65095ms step_avg:89.42ms +step:729/1680 train_time:65185ms step_avg:89.42ms +step:730/1680 train_time:65276ms step_avg:89.42ms +step:731/1680 train_time:65366ms step_avg:89.42ms +step:732/1680 train_time:65456ms step_avg:89.42ms +step:733/1680 train_time:65546ms step_avg:89.42ms +step:734/1680 train_time:65636ms step_avg:89.42ms +step:735/1680 train_time:65726ms step_avg:89.42ms +step:736/1680 train_time:65816ms step_avg:89.42ms +step:737/1680 train_time:65906ms step_avg:89.42ms +step:738/1680 train_time:65995ms step_avg:89.42ms +step:739/1680 train_time:66087ms step_avg:89.43ms +step:740/1680 train_time:66177ms step_avg:89.43ms +step:741/1680 train_time:66267ms step_avg:89.43ms +step:742/1680 train_time:66357ms step_avg:89.43ms +step:743/1680 train_time:66446ms step_avg:89.43ms +step:744/1680 train_time:66536ms step_avg:89.43ms +step:745/1680 train_time:66626ms step_avg:89.43ms +step:746/1680 train_time:66716ms step_avg:89.43ms +step:747/1680 train_time:66805ms step_avg:89.43ms +step:748/1680 train_time:66895ms step_avg:89.43ms +step:749/1680 train_time:66986ms step_avg:89.43ms +step:750/1680 train_time:67076ms step_avg:89.43ms +step:750/1680 val_loss:3.5659 train_time:67168ms step_avg:89.56ms +step:751/1680 train_time:67191ms step_avg:89.47ms +step:752/1680 train_time:67262ms step_avg:89.44ms +step:753/1680 train_time:67359ms step_avg:89.45ms +step:754/1680 train_time:67451ms step_avg:89.46ms +step:755/1680 train_time:67541ms step_avg:89.46ms +step:756/1680 train_time:67630ms step_avg:89.46ms +step:757/1680 train_time:67719ms step_avg:89.46ms +step:758/1680 train_time:67808ms step_avg:89.46ms +step:759/1680 train_time:67898ms step_avg:89.46ms +step:760/1680 train_time:67986ms step_avg:89.46ms +step:761/1680 train_time:68075ms step_avg:89.46ms +step:762/1680 train_time:68165ms step_avg:89.46ms +step:763/1680 train_time:68257ms step_avg:89.46ms +step:764/1680 train_time:68349ms step_avg:89.46ms +step:765/1680 train_time:68443ms step_avg:89.47ms +step:766/1680 train_time:68533ms step_avg:89.47ms +step:767/1680 train_time:68624ms step_avg:89.47ms +step:768/1680 train_time:68714ms step_avg:89.47ms +step:769/1680 train_time:68803ms step_avg:89.47ms +step:770/1680 train_time:68893ms step_avg:89.47ms 
+step:771/1680 train_time:68981ms step_avg:89.47ms +step:772/1680 train_time:69070ms step_avg:89.47ms +step:773/1680 train_time:69160ms step_avg:89.47ms +step:774/1680 train_time:69250ms step_avg:89.47ms +step:775/1680 train_time:69342ms step_avg:89.47ms +step:776/1680 train_time:69433ms step_avg:89.48ms +step:777/1680 train_time:69525ms step_avg:89.48ms +step:778/1680 train_time:69616ms step_avg:89.48ms +step:779/1680 train_time:69706ms step_avg:89.48ms +step:780/1680 train_time:69795ms step_avg:89.48ms +step:781/1680 train_time:69884ms step_avg:89.48ms +step:782/1680 train_time:69974ms step_avg:89.48ms +step:783/1680 train_time:70064ms step_avg:89.48ms +step:784/1680 train_time:70153ms step_avg:89.48ms +step:785/1680 train_time:70244ms step_avg:89.48ms +step:786/1680 train_time:70334ms step_avg:89.48ms +step:787/1680 train_time:70425ms step_avg:89.49ms +step:788/1680 train_time:70516ms step_avg:89.49ms +step:789/1680 train_time:70607ms step_avg:89.49ms +step:790/1680 train_time:70698ms step_avg:89.49ms +step:791/1680 train_time:70787ms step_avg:89.49ms +step:792/1680 train_time:70877ms step_avg:89.49ms +step:793/1680 train_time:70967ms step_avg:89.49ms +step:794/1680 train_time:71057ms step_avg:89.49ms +step:795/1680 train_time:71146ms step_avg:89.49ms +step:796/1680 train_time:71236ms step_avg:89.49ms +step:797/1680 train_time:71326ms step_avg:89.49ms +step:798/1680 train_time:71417ms step_avg:89.49ms +step:799/1680 train_time:71507ms step_avg:89.50ms +step:800/1680 train_time:71598ms step_avg:89.50ms +step:801/1680 train_time:71688ms step_avg:89.50ms +step:802/1680 train_time:71779ms step_avg:89.50ms +step:803/1680 train_time:71868ms step_avg:89.50ms +step:804/1680 train_time:71958ms step_avg:89.50ms +step:805/1680 train_time:72049ms step_avg:89.50ms +step:806/1680 train_time:72139ms step_avg:89.50ms +step:807/1680 train_time:72228ms step_avg:89.50ms +step:808/1680 train_time:72318ms step_avg:89.50ms +step:809/1680 train_time:72408ms step_avg:89.50ms +step:810/1680 train_time:72499ms step_avg:89.51ms +step:811/1680 train_time:72589ms step_avg:89.51ms +step:812/1680 train_time:72680ms step_avg:89.51ms +step:813/1680 train_time:72769ms step_avg:89.51ms +step:814/1680 train_time:72859ms step_avg:89.51ms +step:815/1680 train_time:72949ms step_avg:89.51ms +step:816/1680 train_time:73039ms step_avg:89.51ms +step:817/1680 train_time:73129ms step_avg:89.51ms +step:818/1680 train_time:73218ms step_avg:89.51ms +step:819/1680 train_time:73308ms step_avg:89.51ms +step:820/1680 train_time:73399ms step_avg:89.51ms +step:821/1680 train_time:73488ms step_avg:89.51ms +step:822/1680 train_time:73579ms step_avg:89.51ms +step:823/1680 train_time:73669ms step_avg:89.51ms +step:824/1680 train_time:73759ms step_avg:89.51ms +step:825/1680 train_time:73848ms step_avg:89.51ms +step:826/1680 train_time:73938ms step_avg:89.51ms +step:827/1680 train_time:74028ms step_avg:89.51ms +step:828/1680 train_time:74119ms step_avg:89.52ms +step:829/1680 train_time:74209ms step_avg:89.52ms +step:830/1680 train_time:74299ms step_avg:89.52ms +step:831/1680 train_time:74389ms step_avg:89.52ms +step:832/1680 train_time:74479ms step_avg:89.52ms +step:833/1680 train_time:74569ms step_avg:89.52ms +step:834/1680 train_time:74659ms step_avg:89.52ms +step:835/1680 train_time:74748ms step_avg:89.52ms +step:836/1680 train_time:74843ms step_avg:89.53ms +step:837/1680 train_time:74928ms step_avg:89.52ms +step:838/1680 train_time:75019ms step_avg:89.52ms +step:839/1680 train_time:75109ms step_avg:89.52ms +step:840/1680 train_time:75199ms 
step_avg:89.52ms +step:841/1680 train_time:75290ms step_avg:89.52ms +step:842/1680 train_time:75380ms step_avg:89.53ms +step:843/1680 train_time:75470ms step_avg:89.53ms +step:844/1680 train_time:75560ms step_avg:89.53ms +step:845/1680 train_time:75651ms step_avg:89.53ms +step:846/1680 train_time:75742ms step_avg:89.53ms +step:847/1680 train_time:75832ms step_avg:89.53ms +step:848/1680 train_time:75922ms step_avg:89.53ms +step:849/1680 train_time:76013ms step_avg:89.53ms +step:850/1680 train_time:76104ms step_avg:89.53ms +step:851/1680 train_time:76193ms step_avg:89.53ms +step:852/1680 train_time:76284ms step_avg:89.54ms +step:853/1680 train_time:76375ms step_avg:89.54ms +step:854/1680 train_time:76465ms step_avg:89.54ms +step:855/1680 train_time:76555ms step_avg:89.54ms +step:856/1680 train_time:76646ms step_avg:89.54ms +step:857/1680 train_time:76741ms step_avg:89.55ms +step:858/1680 train_time:76827ms step_avg:89.54ms +step:859/1680 train_time:76918ms step_avg:89.54ms +step:860/1680 train_time:77009ms step_avg:89.54ms +step:861/1680 train_time:77098ms step_avg:89.54ms +step:862/1680 train_time:77188ms step_avg:89.55ms +step:863/1680 train_time:77279ms step_avg:89.55ms +step:864/1680 train_time:77369ms step_avg:89.55ms +step:865/1680 train_time:77458ms step_avg:89.55ms +step:866/1680 train_time:77548ms step_avg:89.55ms +step:867/1680 train_time:77639ms step_avg:89.55ms +step:868/1680 train_time:77728ms step_avg:89.55ms +step:869/1680 train_time:77819ms step_avg:89.55ms +step:870/1680 train_time:77909ms step_avg:89.55ms +step:871/1680 train_time:77999ms step_avg:89.55ms +step:872/1680 train_time:78089ms step_avg:89.55ms +step:873/1680 train_time:78178ms step_avg:89.55ms +step:874/1680 train_time:78268ms step_avg:89.55ms +step:875/1680 train_time:78358ms step_avg:89.55ms +step:875/1680 val_loss:3.5196 train_time:78450ms step_avg:89.66ms +step:876/1680 train_time:78473ms step_avg:89.58ms +step:877/1680 train_time:78545ms step_avg:89.56ms +step:878/1680 train_time:78640ms step_avg:89.57ms +step:879/1680 train_time:78733ms step_avg:89.57ms +step:880/1680 train_time:78823ms step_avg:89.57ms +step:881/1680 train_time:78912ms step_avg:89.57ms +step:882/1680 train_time:79001ms step_avg:89.57ms +step:883/1680 train_time:79089ms step_avg:89.57ms +step:884/1680 train_time:79178ms step_avg:89.57ms +step:885/1680 train_time:79267ms step_avg:89.57ms +step:886/1680 train_time:79356ms step_avg:89.57ms +step:887/1680 train_time:79447ms step_avg:89.57ms +step:888/1680 train_time:79539ms step_avg:89.57ms +step:889/1680 train_time:79633ms step_avg:89.58ms +step:890/1680 train_time:79725ms step_avg:89.58ms +step:891/1680 train_time:79815ms step_avg:89.58ms +step:892/1680 train_time:79906ms step_avg:89.58ms +step:893/1680 train_time:79995ms step_avg:89.58ms +step:894/1680 train_time:80085ms step_avg:89.58ms +step:895/1680 train_time:80177ms step_avg:89.58ms +step:896/1680 train_time:80264ms step_avg:89.58ms +step:897/1680 train_time:80353ms step_avg:89.58ms +step:898/1680 train_time:80444ms step_avg:89.58ms +step:899/1680 train_time:80536ms step_avg:89.58ms +step:900/1680 train_time:80628ms step_avg:89.59ms +step:901/1680 train_time:80720ms step_avg:89.59ms +step:902/1680 train_time:80810ms step_avg:89.59ms +step:903/1680 train_time:80900ms step_avg:89.59ms +step:904/1680 train_time:80989ms step_avg:89.59ms +step:905/1680 train_time:81078ms step_avg:89.59ms +step:906/1680 train_time:81168ms step_avg:89.59ms +step:907/1680 train_time:81258ms step_avg:89.59ms +step:908/1680 train_time:81348ms step_avg:89.59ms 
+step:909/1680 train_time:81438ms step_avg:89.59ms +step:910/1680 train_time:81529ms step_avg:89.59ms +step:911/1680 train_time:81621ms step_avg:89.59ms +step:912/1680 train_time:81712ms step_avg:89.60ms +step:913/1680 train_time:81804ms step_avg:89.60ms +step:914/1680 train_time:81894ms step_avg:89.60ms +step:915/1680 train_time:81985ms step_avg:89.60ms +step:916/1680 train_time:82075ms step_avg:89.60ms +step:917/1680 train_time:82165ms step_avg:89.60ms +step:918/1680 train_time:82254ms step_avg:89.60ms +step:919/1680 train_time:82344ms step_avg:89.60ms +step:920/1680 train_time:82434ms step_avg:89.60ms +step:921/1680 train_time:82524ms step_avg:89.60ms +step:922/1680 train_time:82614ms step_avg:89.60ms +step:923/1680 train_time:82705ms step_avg:89.61ms +step:924/1680 train_time:82796ms step_avg:89.61ms +step:925/1680 train_time:82887ms step_avg:89.61ms +step:926/1680 train_time:82978ms step_avg:89.61ms +step:927/1680 train_time:83068ms step_avg:89.61ms +step:928/1680 train_time:83157ms step_avg:89.61ms +step:929/1680 train_time:83248ms step_avg:89.61ms +step:930/1680 train_time:83339ms step_avg:89.61ms +step:931/1680 train_time:83429ms step_avg:89.61ms +step:932/1680 train_time:83519ms step_avg:89.61ms +step:933/1680 train_time:83609ms step_avg:89.61ms +step:934/1680 train_time:83699ms step_avg:89.61ms +step:935/1680 train_time:83790ms step_avg:89.61ms +step:936/1680 train_time:83880ms step_avg:89.62ms +step:937/1680 train_time:83970ms step_avg:89.62ms +step:938/1680 train_time:84059ms step_avg:89.62ms +step:939/1680 train_time:84150ms step_avg:89.62ms +step:940/1680 train_time:84240ms step_avg:89.62ms +step:941/1680 train_time:84331ms step_avg:89.62ms +step:942/1680 train_time:84420ms step_avg:89.62ms +step:943/1680 train_time:84510ms step_avg:89.62ms +step:944/1680 train_time:84601ms step_avg:89.62ms +step:945/1680 train_time:84691ms step_avg:89.62ms +step:946/1680 train_time:84783ms step_avg:89.62ms +step:947/1680 train_time:84874ms step_avg:89.62ms +step:948/1680 train_time:84963ms step_avg:89.62ms +step:949/1680 train_time:85053ms step_avg:89.62ms +step:950/1680 train_time:85143ms step_avg:89.62ms +step:951/1680 train_time:85233ms step_avg:89.62ms +step:952/1680 train_time:85323ms step_avg:89.63ms +step:953/1680 train_time:85414ms step_avg:89.63ms +step:954/1680 train_time:85503ms step_avg:89.63ms +step:955/1680 train_time:85594ms step_avg:89.63ms +step:956/1680 train_time:85685ms step_avg:89.63ms +step:957/1680 train_time:85775ms step_avg:89.63ms +step:958/1680 train_time:85867ms step_avg:89.63ms +step:959/1680 train_time:85957ms step_avg:89.63ms +step:960/1680 train_time:86048ms step_avg:89.63ms +step:961/1680 train_time:86138ms step_avg:89.63ms +step:962/1680 train_time:86227ms step_avg:89.63ms +step:963/1680 train_time:86318ms step_avg:89.63ms +step:964/1680 train_time:86407ms step_avg:89.63ms +step:965/1680 train_time:86499ms step_avg:89.64ms +step:966/1680 train_time:86588ms step_avg:89.64ms +step:967/1680 train_time:86678ms step_avg:89.64ms +step:968/1680 train_time:86769ms step_avg:89.64ms +step:969/1680 train_time:86860ms step_avg:89.64ms +step:970/1680 train_time:86950ms step_avg:89.64ms +step:971/1680 train_time:87040ms step_avg:89.64ms +step:972/1680 train_time:87131ms step_avg:89.64ms +step:973/1680 train_time:87221ms step_avg:89.64ms +step:974/1680 train_time:87310ms step_avg:89.64ms +step:975/1680 train_time:87400ms step_avg:89.64ms +step:976/1680 train_time:87489ms step_avg:89.64ms +step:977/1680 train_time:87579ms step_avg:89.64ms +step:978/1680 train_time:87669ms 
step_avg:89.64ms +step:979/1680 train_time:87759ms step_avg:89.64ms +step:980/1680 train_time:87849ms step_avg:89.64ms +step:981/1680 train_time:87939ms step_avg:89.64ms +step:982/1680 train_time:88029ms step_avg:89.64ms +step:983/1680 train_time:88119ms step_avg:89.64ms +step:984/1680 train_time:88209ms step_avg:89.64ms +step:985/1680 train_time:88299ms step_avg:89.64ms +step:986/1680 train_time:88389ms step_avg:89.64ms +step:987/1680 train_time:88480ms step_avg:89.65ms +step:988/1680 train_time:88569ms step_avg:89.65ms +step:989/1680 train_time:88660ms step_avg:89.65ms +step:990/1680 train_time:88750ms step_avg:89.65ms +step:991/1680 train_time:88841ms step_avg:89.65ms +step:992/1680 train_time:88931ms step_avg:89.65ms +step:993/1680 train_time:89021ms step_avg:89.65ms +step:994/1680 train_time:89111ms step_avg:89.65ms +step:995/1680 train_time:89201ms step_avg:89.65ms +step:996/1680 train_time:89292ms step_avg:89.65ms +step:997/1680 train_time:89381ms step_avg:89.65ms +step:998/1680 train_time:89471ms step_avg:89.65ms +step:999/1680 train_time:89560ms step_avg:89.65ms +step:1000/1680 train_time:89651ms step_avg:89.65ms +step:1000/1680 val_loss:3.4696 train_time:89742ms step_avg:89.74ms +step:1001/1680 train_time:89765ms step_avg:89.68ms +step:1002/1680 train_time:89837ms step_avg:89.66ms +step:1003/1680 train_time:89932ms step_avg:89.66ms +step:1004/1680 train_time:90023ms step_avg:89.66ms +step:1005/1680 train_time:90113ms step_avg:89.67ms +step:1006/1680 train_time:90203ms step_avg:89.66ms +step:1007/1680 train_time:90292ms step_avg:89.66ms +step:1008/1680 train_time:90382ms step_avg:89.66ms +step:1009/1680 train_time:90470ms step_avg:89.66ms +step:1010/1680 train_time:90559ms step_avg:89.66ms +step:1011/1680 train_time:90648ms step_avg:89.66ms +step:1012/1680 train_time:90738ms step_avg:89.66ms +step:1013/1680 train_time:90830ms step_avg:89.66ms +step:1014/1680 train_time:90922ms step_avg:89.67ms +step:1015/1680 train_time:91013ms step_avg:89.67ms +step:1016/1680 train_time:91104ms step_avg:89.67ms +step:1017/1680 train_time:91195ms step_avg:89.67ms +step:1018/1680 train_time:91285ms step_avg:89.67ms +step:1019/1680 train_time:91375ms step_avg:89.67ms +step:1020/1680 train_time:91464ms step_avg:89.67ms +step:1021/1680 train_time:91553ms step_avg:89.67ms +step:1022/1680 train_time:91643ms step_avg:89.67ms +step:1023/1680 train_time:91733ms step_avg:89.67ms +step:1024/1680 train_time:91824ms step_avg:89.67ms +step:1025/1680 train_time:91915ms step_avg:89.67ms +step:1026/1680 train_time:92007ms step_avg:89.68ms +step:1027/1680 train_time:92098ms step_avg:89.68ms +step:1028/1680 train_time:92189ms step_avg:89.68ms +step:1029/1680 train_time:92279ms step_avg:89.68ms +step:1030/1680 train_time:92369ms step_avg:89.68ms +step:1031/1680 train_time:92457ms step_avg:89.68ms +step:1032/1680 train_time:92547ms step_avg:89.68ms +step:1033/1680 train_time:92637ms step_avg:89.68ms +step:1034/1680 train_time:92727ms step_avg:89.68ms +step:1035/1680 train_time:92816ms step_avg:89.68ms +step:1036/1680 train_time:92907ms step_avg:89.68ms +step:1037/1680 train_time:92998ms step_avg:89.68ms +step:1038/1680 train_time:93089ms step_avg:89.68ms +step:1039/1680 train_time:93180ms step_avg:89.68ms +step:1040/1680 train_time:93270ms step_avg:89.68ms +step:1041/1680 train_time:93360ms step_avg:89.68ms +step:1042/1680 train_time:93450ms step_avg:89.68ms +step:1043/1680 train_time:93539ms step_avg:89.68ms +step:1044/1680 train_time:93630ms step_avg:89.68ms +step:1045/1680 train_time:93719ms step_avg:89.68ms 
+step:1046/1680 train_time:93810ms step_avg:89.68ms +step:1047/1680 train_time:93900ms step_avg:89.68ms +step:1048/1680 train_time:93991ms step_avg:89.69ms +step:1049/1680 train_time:94081ms step_avg:89.69ms +step:1050/1680 train_time:94172ms step_avg:89.69ms +step:1051/1680 train_time:94262ms step_avg:89.69ms +step:1052/1680 train_time:94352ms step_avg:89.69ms +step:1053/1680 train_time:94442ms step_avg:89.69ms +step:1054/1680 train_time:94532ms step_avg:89.69ms +step:1055/1680 train_time:94621ms step_avg:89.69ms +step:1056/1680 train_time:94712ms step_avg:89.69ms +step:1057/1680 train_time:94802ms step_avg:89.69ms +step:1058/1680 train_time:94893ms step_avg:89.69ms +step:1059/1680 train_time:94983ms step_avg:89.69ms +step:1060/1680 train_time:95074ms step_avg:89.69ms +step:1061/1680 train_time:95165ms step_avg:89.69ms +step:1062/1680 train_time:95257ms step_avg:89.70ms +step:1063/1680 train_time:95346ms step_avg:89.69ms +step:1064/1680 train_time:95435ms step_avg:89.69ms +step:1065/1680 train_time:95525ms step_avg:89.70ms +step:1066/1680 train_time:95615ms step_avg:89.70ms +step:1067/1680 train_time:95705ms step_avg:89.70ms +step:1068/1680 train_time:95796ms step_avg:89.70ms +step:1069/1680 train_time:95886ms step_avg:89.70ms +step:1070/1680 train_time:95976ms step_avg:89.70ms +step:1071/1680 train_time:96067ms step_avg:89.70ms +step:1072/1680 train_time:96158ms step_avg:89.70ms +step:1073/1680 train_time:96248ms step_avg:89.70ms +step:1074/1680 train_time:96337ms step_avg:89.70ms +step:1075/1680 train_time:96428ms step_avg:89.70ms +step:1076/1680 train_time:96517ms step_avg:89.70ms +step:1077/1680 train_time:96607ms step_avg:89.70ms +step:1078/1680 train_time:96697ms step_avg:89.70ms +step:1079/1680 train_time:96787ms step_avg:89.70ms +step:1080/1680 train_time:96877ms step_avg:89.70ms +step:1081/1680 train_time:96967ms step_avg:89.70ms +step:1082/1680 train_time:97057ms step_avg:89.70ms +step:1083/1680 train_time:97151ms step_avg:89.71ms +step:1084/1680 train_time:97238ms step_avg:89.70ms +step:1085/1680 train_time:97328ms step_avg:89.70ms +step:1086/1680 train_time:97418ms step_avg:89.70ms +step:1087/1680 train_time:97508ms step_avg:89.70ms +step:1088/1680 train_time:97598ms step_avg:89.70ms +step:1089/1680 train_time:97687ms step_avg:89.70ms +step:1090/1680 train_time:97777ms step_avg:89.70ms +step:1091/1680 train_time:97868ms step_avg:89.71ms +step:1092/1680 train_time:97958ms step_avg:89.71ms +step:1093/1680 train_time:98049ms step_avg:89.71ms +step:1094/1680 train_time:98139ms step_avg:89.71ms +step:1095/1680 train_time:98231ms step_avg:89.71ms +step:1096/1680 train_time:98321ms step_avg:89.71ms +step:1097/1680 train_time:98411ms step_avg:89.71ms +step:1098/1680 train_time:98501ms step_avg:89.71ms +step:1099/1680 train_time:98592ms step_avg:89.71ms +step:1100/1680 train_time:98682ms step_avg:89.71ms +step:1101/1680 train_time:98773ms step_avg:89.71ms +step:1102/1680 train_time:98865ms step_avg:89.71ms +step:1103/1680 train_time:98956ms step_avg:89.72ms +step:1104/1680 train_time:99047ms step_avg:89.72ms +step:1105/1680 train_time:99139ms step_avg:89.72ms +step:1106/1680 train_time:99231ms step_avg:89.72ms +step:1107/1680 train_time:99321ms step_avg:89.72ms +step:1108/1680 train_time:99412ms step_avg:89.72ms +step:1109/1680 train_time:99502ms step_avg:89.72ms +step:1110/1680 train_time:99592ms step_avg:89.72ms +step:1111/1680 train_time:99683ms step_avg:89.72ms +step:1112/1680 train_time:99773ms step_avg:89.72ms +step:1113/1680 train_time:99864ms step_avg:89.73ms +step:1114/1680 
train_time:99955ms step_avg:89.73ms +step:1115/1680 train_time:100047ms step_avg:89.73ms +step:1116/1680 train_time:100138ms step_avg:89.73ms +step:1117/1680 train_time:100230ms step_avg:89.73ms +step:1118/1680 train_time:100320ms step_avg:89.73ms +step:1119/1680 train_time:100411ms step_avg:89.73ms +step:1120/1680 train_time:100501ms step_avg:89.73ms +step:1121/1680 train_time:100592ms step_avg:89.73ms +step:1122/1680 train_time:100682ms step_avg:89.73ms +step:1123/1680 train_time:100773ms step_avg:89.74ms +step:1124/1680 train_time:100863ms step_avg:89.74ms +step:1125/1680 train_time:100953ms step_avg:89.74ms +step:1125/1680 val_loss:3.4156 train_time:101046ms step_avg:89.82ms +step:1126/1680 train_time:101069ms step_avg:89.76ms +step:1127/1680 train_time:101140ms step_avg:89.74ms +step:1128/1680 train_time:101238ms step_avg:89.75ms +step:1129/1680 train_time:101330ms step_avg:89.75ms +step:1130/1680 train_time:101421ms step_avg:89.75ms +step:1131/1680 train_time:101511ms step_avg:89.75ms +step:1132/1680 train_time:101601ms step_avg:89.75ms +step:1133/1680 train_time:101691ms step_avg:89.75ms +step:1134/1680 train_time:101781ms step_avg:89.75ms +step:1135/1680 train_time:101871ms step_avg:89.75ms +step:1136/1680 train_time:101961ms step_avg:89.75ms +step:1137/1680 train_time:102053ms step_avg:89.76ms +step:1138/1680 train_time:102147ms step_avg:89.76ms +step:1139/1680 train_time:102241ms step_avg:89.76ms +step:1140/1680 train_time:102332ms step_avg:89.77ms +step:1141/1680 train_time:102423ms step_avg:89.77ms +step:1142/1680 train_time:102514ms step_avg:89.77ms +step:1143/1680 train_time:102604ms step_avg:89.77ms +step:1144/1680 train_time:102694ms step_avg:89.77ms +step:1145/1680 train_time:102783ms step_avg:89.77ms +step:1146/1680 train_time:102873ms step_avg:89.77ms +step:1147/1680 train_time:102963ms step_avg:89.77ms +step:1148/1680 train_time:103055ms step_avg:89.77ms +step:1149/1680 train_time:103147ms step_avg:89.77ms +step:1150/1680 train_time:103241ms step_avg:89.77ms +step:1151/1680 train_time:103332ms step_avg:89.78ms +step:1152/1680 train_time:103423ms step_avg:89.78ms +step:1153/1680 train_time:103514ms step_avg:89.78ms +step:1154/1680 train_time:103604ms step_avg:89.78ms +step:1155/1680 train_time:103694ms step_avg:89.78ms +step:1156/1680 train_time:103784ms step_avg:89.78ms +step:1157/1680 train_time:103874ms step_avg:89.78ms +step:1158/1680 train_time:103965ms step_avg:89.78ms +step:1159/1680 train_time:104055ms step_avg:89.78ms +step:1160/1680 train_time:104147ms step_avg:89.78ms +step:1161/1680 train_time:104241ms step_avg:89.79ms +step:1162/1680 train_time:104331ms step_avg:89.79ms +step:1163/1680 train_time:104422ms step_avg:89.79ms +step:1164/1680 train_time:104513ms step_avg:89.79ms +step:1165/1680 train_time:104604ms step_avg:89.79ms +step:1166/1680 train_time:104694ms step_avg:89.79ms +step:1167/1680 train_time:104784ms step_avg:89.79ms +step:1168/1680 train_time:104874ms step_avg:89.79ms +step:1169/1680 train_time:104965ms step_avg:89.79ms +step:1170/1680 train_time:105055ms step_avg:89.79ms +step:1171/1680 train_time:105147ms step_avg:89.79ms +step:1172/1680 train_time:105237ms step_avg:89.79ms +step:1173/1680 train_time:105329ms step_avg:89.79ms +step:1174/1680 train_time:105420ms step_avg:89.80ms +step:1175/1680 train_time:105511ms step_avg:89.80ms +step:1176/1680 train_time:105601ms step_avg:89.80ms +step:1177/1680 train_time:105692ms step_avg:89.80ms +step:1178/1680 train_time:105782ms step_avg:89.80ms +step:1179/1680 train_time:105873ms step_avg:89.80ms 
+step:1180/1680 train_time:105963ms step_avg:89.80ms +step:1181/1680 train_time:106054ms step_avg:89.80ms +step:1182/1680 train_time:106145ms step_avg:89.80ms +step:1183/1680 train_time:106236ms step_avg:89.80ms +step:1184/1680 train_time:106329ms step_avg:89.80ms +step:1185/1680 train_time:106419ms step_avg:89.81ms +step:1186/1680 train_time:106510ms step_avg:89.81ms +step:1187/1680 train_time:106601ms step_avg:89.81ms +step:1188/1680 train_time:106692ms step_avg:89.81ms +step:1189/1680 train_time:106783ms step_avg:89.81ms +step:1190/1680 train_time:106874ms step_avg:89.81ms +step:1191/1680 train_time:106964ms step_avg:89.81ms +step:1192/1680 train_time:107054ms step_avg:89.81ms +step:1193/1680 train_time:107145ms step_avg:89.81ms +step:1194/1680 train_time:107235ms step_avg:89.81ms +step:1195/1680 train_time:107326ms step_avg:89.81ms +step:1196/1680 train_time:107419ms step_avg:89.81ms +step:1197/1680 train_time:107507ms step_avg:89.81ms +step:1198/1680 train_time:107598ms step_avg:89.81ms +step:1199/1680 train_time:107688ms step_avg:89.81ms +step:1200/1680 train_time:107779ms step_avg:89.82ms +step:1201/1680 train_time:107870ms step_avg:89.82ms +step:1202/1680 train_time:107960ms step_avg:89.82ms +step:1203/1680 train_time:108051ms step_avg:89.82ms +step:1204/1680 train_time:108141ms step_avg:89.82ms +step:1205/1680 train_time:108232ms step_avg:89.82ms +step:1206/1680 train_time:108323ms step_avg:89.82ms +step:1207/1680 train_time:108414ms step_avg:89.82ms +step:1208/1680 train_time:108505ms step_avg:89.82ms +step:1209/1680 train_time:108595ms step_avg:89.82ms +step:1210/1680 train_time:108686ms step_avg:89.82ms +step:1211/1680 train_time:108776ms step_avg:89.82ms +step:1212/1680 train_time:108868ms step_avg:89.83ms +step:1213/1680 train_time:108958ms step_avg:89.83ms +step:1214/1680 train_time:109049ms step_avg:89.83ms +step:1215/1680 train_time:109140ms step_avg:89.83ms +step:1216/1680 train_time:109231ms step_avg:89.83ms +step:1217/1680 train_time:109322ms step_avg:89.83ms +step:1218/1680 train_time:109413ms step_avg:89.83ms +step:1219/1680 train_time:109504ms step_avg:89.83ms +step:1220/1680 train_time:109595ms step_avg:89.83ms +step:1221/1680 train_time:109686ms step_avg:89.83ms +step:1222/1680 train_time:109776ms step_avg:89.83ms +step:1223/1680 train_time:109867ms step_avg:89.83ms +step:1224/1680 train_time:109958ms step_avg:89.83ms +step:1225/1680 train_time:110049ms step_avg:89.84ms +step:1226/1680 train_time:110139ms step_avg:89.84ms +step:1227/1680 train_time:110230ms step_avg:89.84ms +step:1228/1680 train_time:110323ms step_avg:89.84ms +step:1229/1680 train_time:110413ms step_avg:89.84ms +step:1230/1680 train_time:110503ms step_avg:89.84ms +step:1231/1680 train_time:110594ms step_avg:89.84ms +step:1232/1680 train_time:110685ms step_avg:89.84ms +step:1233/1680 train_time:110776ms step_avg:89.84ms +step:1234/1680 train_time:110868ms step_avg:89.84ms +step:1235/1680 train_time:110959ms step_avg:89.85ms +step:1236/1680 train_time:111049ms step_avg:89.85ms +step:1237/1680 train_time:111140ms step_avg:89.85ms +step:1238/1680 train_time:111231ms step_avg:89.85ms +step:1239/1680 train_time:111322ms step_avg:89.85ms +step:1240/1680 train_time:111413ms step_avg:89.85ms +step:1241/1680 train_time:111504ms step_avg:89.85ms +step:1242/1680 train_time:111594ms step_avg:89.85ms +step:1243/1680 train_time:111686ms step_avg:89.85ms +step:1244/1680 train_time:111776ms step_avg:89.85ms +step:1245/1680 train_time:111867ms step_avg:89.85ms +step:1246/1680 train_time:111957ms step_avg:89.85ms 
+step:1247/1680 train_time:112047ms step_avg:89.85ms +step:1248/1680 train_time:112138ms step_avg:89.85ms +step:1249/1680 train_time:112228ms step_avg:89.85ms +step:1250/1680 train_time:112318ms step_avg:89.85ms +step:1250/1680 val_loss:3.3782 train_time:112410ms step_avg:89.93ms +step:1251/1680 train_time:112433ms step_avg:89.87ms +step:1252/1680 train_time:112506ms step_avg:89.86ms +step:1253/1680 train_time:112600ms step_avg:89.86ms +step:1254/1680 train_time:112691ms step_avg:89.87ms +step:1255/1680 train_time:112782ms step_avg:89.87ms +step:1256/1680 train_time:112872ms step_avg:89.87ms +step:1257/1680 train_time:112961ms step_avg:89.87ms +step:1258/1680 train_time:113052ms step_avg:89.87ms +step:1259/1680 train_time:113142ms step_avg:89.87ms +step:1260/1680 train_time:113232ms step_avg:89.87ms +step:1261/1680 train_time:113323ms step_avg:89.87ms +step:1262/1680 train_time:113417ms step_avg:89.87ms +step:1263/1680 train_time:113510ms step_avg:89.87ms +step:1264/1680 train_time:113603ms step_avg:89.88ms +step:1265/1680 train_time:113695ms step_avg:89.88ms +step:1266/1680 train_time:113786ms step_avg:89.88ms +step:1267/1680 train_time:113875ms step_avg:89.88ms +step:1268/1680 train_time:113966ms step_avg:89.88ms +step:1269/1680 train_time:114055ms step_avg:89.88ms +step:1270/1680 train_time:114145ms step_avg:89.88ms +step:1271/1680 train_time:114234ms step_avg:89.88ms +step:1272/1680 train_time:114324ms step_avg:89.88ms +step:1273/1680 train_time:114417ms step_avg:89.88ms +step:1274/1680 train_time:114510ms step_avg:89.88ms +step:1275/1680 train_time:114603ms step_avg:89.88ms +step:1276/1680 train_time:114695ms step_avg:89.89ms +step:1277/1680 train_time:114785ms step_avg:89.89ms +step:1278/1680 train_time:114875ms step_avg:89.89ms +step:1279/1680 train_time:114965ms step_avg:89.89ms +step:1280/1680 train_time:115056ms step_avg:89.89ms +step:1281/1680 train_time:115146ms step_avg:89.89ms +step:1282/1680 train_time:115237ms step_avg:89.89ms +step:1283/1680 train_time:115327ms step_avg:89.89ms +step:1284/1680 train_time:115419ms step_avg:89.89ms +step:1285/1680 train_time:115511ms step_avg:89.89ms +step:1286/1680 train_time:115602ms step_avg:89.89ms +step:1287/1680 train_time:115693ms step_avg:89.89ms +step:1288/1680 train_time:115784ms step_avg:89.89ms +step:1289/1680 train_time:115875ms step_avg:89.89ms +step:1290/1680 train_time:115965ms step_avg:89.90ms +step:1291/1680 train_time:116055ms step_avg:89.90ms +step:1292/1680 train_time:116146ms step_avg:89.90ms +step:1293/1680 train_time:116237ms step_avg:89.90ms +step:1294/1680 train_time:116327ms step_avg:89.90ms +step:1295/1680 train_time:116419ms step_avg:89.90ms +step:1296/1680 train_time:116511ms step_avg:89.90ms +step:1297/1680 train_time:116602ms step_avg:89.90ms +step:1298/1680 train_time:116693ms step_avg:89.90ms +step:1299/1680 train_time:116785ms step_avg:89.90ms +step:1300/1680 train_time:116875ms step_avg:89.90ms +step:1301/1680 train_time:116965ms step_avg:89.90ms +step:1302/1680 train_time:117055ms step_avg:89.90ms +step:1303/1680 train_time:117145ms step_avg:89.90ms +step:1304/1680 train_time:117236ms step_avg:89.90ms +step:1305/1680 train_time:117326ms step_avg:89.91ms +step:1306/1680 train_time:117419ms step_avg:89.91ms +step:1307/1680 train_time:117510ms step_avg:89.91ms +step:1308/1680 train_time:117602ms step_avg:89.91ms +step:1309/1680 train_time:117694ms step_avg:89.91ms +step:1310/1680 train_time:117785ms step_avg:89.91ms +step:1311/1680 train_time:117875ms step_avg:89.91ms +step:1312/1680 train_time:117966ms 
step_avg:89.91ms +step:1313/1680 train_time:118056ms step_avg:89.91ms +step:1314/1680 train_time:118147ms step_avg:89.91ms +step:1315/1680 train_time:118238ms step_avg:89.92ms +step:1316/1680 train_time:118328ms step_avg:89.92ms +step:1317/1680 train_time:118420ms step_avg:89.92ms +step:1318/1680 train_time:118513ms step_avg:89.92ms +step:1319/1680 train_time:118605ms step_avg:89.92ms +step:1320/1680 train_time:118697ms step_avg:89.92ms +step:1321/1680 train_time:118787ms step_avg:89.92ms +step:1322/1680 train_time:118878ms step_avg:89.92ms +step:1323/1680 train_time:118969ms step_avg:89.92ms +step:1324/1680 train_time:119060ms step_avg:89.92ms +step:1325/1680 train_time:119151ms step_avg:89.93ms +step:1326/1680 train_time:119242ms step_avg:89.93ms +step:1327/1680 train_time:119332ms step_avg:89.93ms +step:1328/1680 train_time:119422ms step_avg:89.93ms +step:1329/1680 train_time:119513ms step_avg:89.93ms +step:1330/1680 train_time:119604ms step_avg:89.93ms +step:1331/1680 train_time:119695ms step_avg:89.93ms +step:1332/1680 train_time:119786ms step_avg:89.93ms +step:1333/1680 train_time:119878ms step_avg:89.93ms +step:1334/1680 train_time:119969ms step_avg:89.93ms +step:1335/1680 train_time:120059ms step_avg:89.93ms +step:1336/1680 train_time:120150ms step_avg:89.93ms +step:1337/1680 train_time:120242ms step_avg:89.93ms +step:1338/1680 train_time:120333ms step_avg:89.93ms +step:1339/1680 train_time:120423ms step_avg:89.94ms +step:1340/1680 train_time:120515ms step_avg:89.94ms +step:1341/1680 train_time:120604ms step_avg:89.94ms +step:1342/1680 train_time:120696ms step_avg:89.94ms +step:1343/1680 train_time:120786ms step_avg:89.94ms +step:1344/1680 train_time:120878ms step_avg:89.94ms +step:1345/1680 train_time:120970ms step_avg:89.94ms +step:1346/1680 train_time:121060ms step_avg:89.94ms +step:1347/1680 train_time:121151ms step_avg:89.94ms +step:1348/1680 train_time:121242ms step_avg:89.94ms +step:1349/1680 train_time:121333ms step_avg:89.94ms +step:1350/1680 train_time:121424ms step_avg:89.94ms +step:1351/1680 train_time:121515ms step_avg:89.94ms +step:1352/1680 train_time:121605ms step_avg:89.94ms +step:1353/1680 train_time:121696ms step_avg:89.95ms +step:1354/1680 train_time:121787ms step_avg:89.95ms +step:1355/1680 train_time:121878ms step_avg:89.95ms +step:1356/1680 train_time:121968ms step_avg:89.95ms +step:1357/1680 train_time:122058ms step_avg:89.95ms +step:1358/1680 train_time:122150ms step_avg:89.95ms +step:1359/1680 train_time:122241ms step_avg:89.95ms +step:1360/1680 train_time:122331ms step_avg:89.95ms +step:1361/1680 train_time:122421ms step_avg:89.95ms +step:1362/1680 train_time:122512ms step_avg:89.95ms +step:1363/1680 train_time:122603ms step_avg:89.95ms +step:1364/1680 train_time:122694ms step_avg:89.95ms +step:1365/1680 train_time:122784ms step_avg:89.95ms +step:1366/1680 train_time:122875ms step_avg:89.95ms +step:1367/1680 train_time:122966ms step_avg:89.95ms +step:1368/1680 train_time:123057ms step_avg:89.95ms +step:1369/1680 train_time:123148ms step_avg:89.95ms +step:1370/1680 train_time:123239ms step_avg:89.96ms +step:1371/1680 train_time:123330ms step_avg:89.96ms +step:1372/1680 train_time:123421ms step_avg:89.96ms +step:1373/1680 train_time:123512ms step_avg:89.96ms +step:1374/1680 train_time:123603ms step_avg:89.96ms +step:1375/1680 train_time:123694ms step_avg:89.96ms +step:1375/1680 val_loss:3.3428 train_time:123786ms step_avg:90.03ms +step:1376/1680 train_time:123809ms step_avg:89.98ms +step:1377/1680 train_time:123881ms step_avg:89.96ms +step:1378/1680 
train_time:123978ms step_avg:89.97ms +step:1379/1680 train_time:124069ms step_avg:89.97ms +step:1380/1680 train_time:124158ms step_avg:89.97ms +step:1381/1680 train_time:124248ms step_avg:89.97ms +step:1382/1680 train_time:124337ms step_avg:89.97ms +step:1383/1680 train_time:124427ms step_avg:89.97ms +step:1384/1680 train_time:124516ms step_avg:89.97ms +step:1385/1680 train_time:124606ms step_avg:89.97ms +step:1386/1680 train_time:124696ms step_avg:89.97ms +step:1387/1680 train_time:124787ms step_avg:89.97ms +step:1388/1680 train_time:124881ms step_avg:89.97ms +step:1389/1680 train_time:124974ms step_avg:89.97ms +step:1390/1680 train_time:125066ms step_avg:89.98ms +step:1391/1680 train_time:125157ms step_avg:89.98ms +step:1392/1680 train_time:125247ms step_avg:89.98ms +step:1393/1680 train_time:125337ms step_avg:89.98ms +step:1394/1680 train_time:125427ms step_avg:89.98ms +step:1395/1680 train_time:125517ms step_avg:89.98ms +step:1396/1680 train_time:125607ms step_avg:89.98ms +step:1397/1680 train_time:125696ms step_avg:89.98ms +step:1398/1680 train_time:125787ms step_avg:89.98ms +step:1399/1680 train_time:125879ms step_avg:89.98ms +step:1400/1680 train_time:125971ms step_avg:89.98ms +step:1401/1680 train_time:126063ms step_avg:89.98ms +step:1402/1680 train_time:126155ms step_avg:89.98ms +step:1403/1680 train_time:126245ms step_avg:89.98ms +step:1404/1680 train_time:126336ms step_avg:89.98ms +step:1405/1680 train_time:126426ms step_avg:89.98ms +step:1406/1680 train_time:126515ms step_avg:89.98ms +step:1407/1680 train_time:126605ms step_avg:89.98ms +step:1408/1680 train_time:126696ms step_avg:89.98ms +step:1409/1680 train_time:126788ms step_avg:89.98ms +step:1410/1680 train_time:126879ms step_avg:89.99ms +step:1411/1680 train_time:126970ms step_avg:89.99ms +step:1412/1680 train_time:127062ms step_avg:89.99ms +step:1413/1680 train_time:127154ms step_avg:89.99ms +step:1414/1680 train_time:127245ms step_avg:89.99ms +step:1415/1680 train_time:127335ms step_avg:89.99ms +step:1416/1680 train_time:127425ms step_avg:89.99ms +step:1417/1680 train_time:127515ms step_avg:89.99ms +step:1418/1680 train_time:127608ms step_avg:89.99ms +step:1419/1680 train_time:127695ms step_avg:89.99ms +step:1420/1680 train_time:127786ms step_avg:89.99ms +step:1421/1680 train_time:127878ms step_avg:89.99ms +step:1422/1680 train_time:127969ms step_avg:89.99ms +step:1423/1680 train_time:128060ms step_avg:89.99ms +step:1424/1680 train_time:128151ms step_avg:89.99ms +step:1425/1680 train_time:128242ms step_avg:89.99ms +step:1426/1680 train_time:128333ms step_avg:89.99ms +step:1427/1680 train_time:128423ms step_avg:90.00ms +step:1428/1680 train_time:128513ms step_avg:90.00ms +step:1429/1680 train_time:128604ms step_avg:90.00ms +step:1430/1680 train_time:128695ms step_avg:90.00ms +step:1431/1680 train_time:128786ms step_avg:90.00ms +step:1432/1680 train_time:128877ms step_avg:90.00ms +step:1433/1680 train_time:128968ms step_avg:90.00ms +step:1434/1680 train_time:129059ms step_avg:90.00ms +step:1435/1680 train_time:129150ms step_avg:90.00ms +step:1436/1680 train_time:129242ms step_avg:90.00ms +step:1437/1680 train_time:129333ms step_avg:90.00ms +step:1438/1680 train_time:129423ms step_avg:90.00ms +step:1439/1680 train_time:129513ms step_avg:90.00ms +step:1440/1680 train_time:129603ms step_avg:90.00ms +step:1441/1680 train_time:129693ms step_avg:90.00ms +step:1442/1680 train_time:129785ms step_avg:90.00ms +step:1443/1680 train_time:129876ms step_avg:90.00ms +step:1444/1680 train_time:129967ms step_avg:90.00ms +step:1445/1680 
train_time:130059ms step_avg:90.01ms +step:1446/1680 train_time:130151ms step_avg:90.01ms +step:1447/1680 train_time:130243ms step_avg:90.01ms +step:1448/1680 train_time:130333ms step_avg:90.01ms +step:1449/1680 train_time:130424ms step_avg:90.01ms +step:1450/1680 train_time:130514ms step_avg:90.01ms +step:1451/1680 train_time:130607ms step_avg:90.01ms +step:1452/1680 train_time:130694ms step_avg:90.01ms +step:1453/1680 train_time:130785ms step_avg:90.01ms +step:1454/1680 train_time:130876ms step_avg:90.01ms +step:1455/1680 train_time:130966ms step_avg:90.01ms +step:1456/1680 train_time:131057ms step_avg:90.01ms +step:1457/1680 train_time:131148ms step_avg:90.01ms +step:1458/1680 train_time:131240ms step_avg:90.01ms +step:1459/1680 train_time:131331ms step_avg:90.01ms +step:1460/1680 train_time:131422ms step_avg:90.02ms +step:1461/1680 train_time:131514ms step_avg:90.02ms +step:1462/1680 train_time:131605ms step_avg:90.02ms +step:1463/1680 train_time:131695ms step_avg:90.02ms +step:1464/1680 train_time:131785ms step_avg:90.02ms +step:1465/1680 train_time:131876ms step_avg:90.02ms +step:1466/1680 train_time:131966ms step_avg:90.02ms +step:1467/1680 train_time:132059ms step_avg:90.02ms +step:1468/1680 train_time:132151ms step_avg:90.02ms +step:1469/1680 train_time:132243ms step_avg:90.02ms +step:1470/1680 train_time:132334ms step_avg:90.02ms +step:1471/1680 train_time:132425ms step_avg:90.02ms +step:1472/1680 train_time:132515ms step_avg:90.02ms +step:1473/1680 train_time:132607ms step_avg:90.02ms +step:1474/1680 train_time:132697ms step_avg:90.02ms +step:1475/1680 train_time:132787ms step_avg:90.03ms +step:1476/1680 train_time:132878ms step_avg:90.03ms +step:1477/1680 train_time:132969ms step_avg:90.03ms +step:1478/1680 train_time:133060ms step_avg:90.03ms +step:1479/1680 train_time:133151ms step_avg:90.03ms +step:1480/1680 train_time:133243ms step_avg:90.03ms +step:1481/1680 train_time:133334ms step_avg:90.03ms +step:1482/1680 train_time:133425ms step_avg:90.03ms +step:1483/1680 train_time:133516ms step_avg:90.03ms +step:1484/1680 train_time:133607ms step_avg:90.03ms +step:1485/1680 train_time:133698ms step_avg:90.03ms +step:1486/1680 train_time:133788ms step_avg:90.03ms +step:1487/1680 train_time:133879ms step_avg:90.03ms +step:1488/1680 train_time:133970ms step_avg:90.03ms +step:1489/1680 train_time:134062ms step_avg:90.03ms +step:1490/1680 train_time:134154ms step_avg:90.04ms +step:1491/1680 train_time:134244ms step_avg:90.04ms +step:1492/1680 train_time:134334ms step_avg:90.04ms +step:1493/1680 train_time:134425ms step_avg:90.04ms +step:1494/1680 train_time:134516ms step_avg:90.04ms +step:1495/1680 train_time:134607ms step_avg:90.04ms +step:1496/1680 train_time:134697ms step_avg:90.04ms +step:1497/1680 train_time:134787ms step_avg:90.04ms +step:1498/1680 train_time:134879ms step_avg:90.04ms +step:1499/1680 train_time:134970ms step_avg:90.04ms +step:1500/1680 train_time:135063ms step_avg:90.04ms +step:1500/1680 val_loss:3.3137 train_time:135157ms step_avg:90.10ms +step:1501/1680 train_time:135180ms step_avg:90.06ms +step:1502/1680 train_time:135253ms step_avg:90.05ms +step:1503/1680 train_time:135352ms step_avg:90.05ms +step:1504/1680 train_time:135446ms step_avg:90.06ms +step:1505/1680 train_time:135534ms step_avg:90.06ms +step:1506/1680 train_time:135623ms step_avg:90.05ms +step:1507/1680 train_time:135712ms step_avg:90.05ms +step:1508/1680 train_time:135802ms step_avg:90.05ms +step:1509/1680 train_time:135892ms step_avg:90.05ms +step:1510/1680 train_time:135982ms step_avg:90.05ms 
+step:1511/1680 train_time:136071ms step_avg:90.05ms +step:1512/1680 train_time:136164ms step_avg:90.06ms +step:1513/1680 train_time:136259ms step_avg:90.06ms +step:1514/1680 train_time:136352ms step_avg:90.06ms +step:1515/1680 train_time:136444ms step_avg:90.06ms +step:1516/1680 train_time:136534ms step_avg:90.06ms +step:1517/1680 train_time:136623ms step_avg:90.06ms +step:1518/1680 train_time:136713ms step_avg:90.06ms +step:1519/1680 train_time:136802ms step_avg:90.06ms +step:1520/1680 train_time:136892ms step_avg:90.06ms +step:1521/1680 train_time:136982ms step_avg:90.06ms +step:1522/1680 train_time:137072ms step_avg:90.06ms +step:1523/1680 train_time:137163ms step_avg:90.06ms +step:1524/1680 train_time:137256ms step_avg:90.06ms +step:1525/1680 train_time:137349ms step_avg:90.07ms +step:1526/1680 train_time:137439ms step_avg:90.06ms +step:1527/1680 train_time:137530ms step_avg:90.07ms +step:1528/1680 train_time:137621ms step_avg:90.07ms +step:1529/1680 train_time:137711ms step_avg:90.07ms +step:1530/1680 train_time:137801ms step_avg:90.07ms +step:1531/1680 train_time:137891ms step_avg:90.07ms +step:1532/1680 train_time:137981ms step_avg:90.07ms +step:1533/1680 train_time:138072ms step_avg:90.07ms +step:1534/1680 train_time:138163ms step_avg:90.07ms +step:1535/1680 train_time:138254ms step_avg:90.07ms +step:1536/1680 train_time:138346ms step_avg:90.07ms +step:1537/1680 train_time:138439ms step_avg:90.07ms +step:1538/1680 train_time:138530ms step_avg:90.07ms +step:1539/1680 train_time:138621ms step_avg:90.07ms +step:1540/1680 train_time:138711ms step_avg:90.07ms +step:1541/1680 train_time:138801ms step_avg:90.07ms +step:1542/1680 train_time:138891ms step_avg:90.07ms +step:1543/1680 train_time:138981ms step_avg:90.07ms +step:1544/1680 train_time:139071ms step_avg:90.07ms +step:1545/1680 train_time:139162ms step_avg:90.07ms +step:1546/1680 train_time:139254ms step_avg:90.07ms +step:1547/1680 train_time:139344ms step_avg:90.07ms +step:1548/1680 train_time:139435ms step_avg:90.07ms +step:1549/1680 train_time:139526ms step_avg:90.08ms +step:1550/1680 train_time:139618ms step_avg:90.08ms +step:1551/1680 train_time:139709ms step_avg:90.08ms +step:1552/1680 train_time:139800ms step_avg:90.08ms +step:1553/1680 train_time:139890ms step_avg:90.08ms +step:1554/1680 train_time:139981ms step_avg:90.08ms +step:1555/1680 train_time:140071ms step_avg:90.08ms +step:1556/1680 train_time:140163ms step_avg:90.08ms +step:1557/1680 train_time:140255ms step_avg:90.08ms +step:1558/1680 train_time:140345ms step_avg:90.08ms +step:1559/1680 train_time:140435ms step_avg:90.08ms +step:1560/1680 train_time:140527ms step_avg:90.08ms +step:1561/1680 train_time:140618ms step_avg:90.08ms +step:1562/1680 train_time:140708ms step_avg:90.08ms +step:1563/1680 train_time:140798ms step_avg:90.08ms +step:1564/1680 train_time:140889ms step_avg:90.08ms +step:1565/1680 train_time:140981ms step_avg:90.08ms +step:1566/1680 train_time:141072ms step_avg:90.08ms +step:1567/1680 train_time:141163ms step_avg:90.09ms +step:1568/1680 train_time:141254ms step_avg:90.09ms +step:1569/1680 train_time:141345ms step_avg:90.09ms +step:1570/1680 train_time:141435ms step_avg:90.09ms +step:1571/1680 train_time:141526ms step_avg:90.09ms +step:1572/1680 train_time:141617ms step_avg:90.09ms +step:1573/1680 train_time:141708ms step_avg:90.09ms +step:1574/1680 train_time:141799ms step_avg:90.09ms +step:1575/1680 train_time:141889ms step_avg:90.09ms +step:1576/1680 train_time:141980ms step_avg:90.09ms +step:1577/1680 train_time:142070ms step_avg:90.09ms 
+step:1578/1680 train_time:142161ms step_avg:90.09ms +step:1579/1680 train_time:142252ms step_avg:90.09ms +step:1580/1680 train_time:142342ms step_avg:90.09ms +step:1581/1680 train_time:142433ms step_avg:90.09ms +step:1582/1680 train_time:142524ms step_avg:90.09ms +step:1583/1680 train_time:142614ms step_avg:90.09ms +step:1584/1680 train_time:142704ms step_avg:90.09ms +step:1585/1680 train_time:142795ms step_avg:90.09ms +step:1586/1680 train_time:142886ms step_avg:90.09ms +step:1587/1680 train_time:142976ms step_avg:90.09ms +step:1588/1680 train_time:143067ms step_avg:90.09ms +step:1589/1680 train_time:143157ms step_avg:90.09ms +step:1590/1680 train_time:143248ms step_avg:90.09ms +step:1591/1680 train_time:143339ms step_avg:90.09ms +step:1592/1680 train_time:143430ms step_avg:90.09ms +step:1593/1680 train_time:143522ms step_avg:90.10ms +step:1594/1680 train_time:143614ms step_avg:90.10ms +step:1595/1680 train_time:143703ms step_avg:90.10ms +step:1596/1680 train_time:143794ms step_avg:90.10ms +step:1597/1680 train_time:143885ms step_avg:90.10ms +step:1598/1680 train_time:143975ms step_avg:90.10ms +step:1599/1680 train_time:144066ms step_avg:90.10ms +step:1600/1680 train_time:144156ms step_avg:90.10ms +step:1601/1680 train_time:144251ms step_avg:90.10ms +step:1602/1680 train_time:144338ms step_avg:90.10ms +step:1603/1680 train_time:144428ms step_avg:90.10ms +step:1604/1680 train_time:144521ms step_avg:90.10ms +step:1605/1680 train_time:144614ms step_avg:90.10ms +step:1606/1680 train_time:144704ms step_avg:90.10ms +step:1607/1680 train_time:144795ms step_avg:90.10ms +step:1608/1680 train_time:144886ms step_avg:90.10ms +step:1609/1680 train_time:144977ms step_avg:90.10ms +step:1610/1680 train_time:145067ms step_avg:90.10ms +step:1611/1680 train_time:145158ms step_avg:90.10ms +step:1612/1680 train_time:145249ms step_avg:90.11ms +step:1613/1680 train_time:145340ms step_avg:90.11ms +step:1614/1680 train_time:145431ms step_avg:90.11ms +step:1615/1680 train_time:145521ms step_avg:90.11ms +step:1616/1680 train_time:145612ms step_avg:90.11ms +step:1617/1680 train_time:145703ms step_avg:90.11ms +step:1618/1680 train_time:145794ms step_avg:90.11ms +step:1619/1680 train_time:145884ms step_avg:90.11ms +step:1620/1680 train_time:145975ms step_avg:90.11ms +step:1621/1680 train_time:146065ms step_avg:90.11ms +step:1622/1680 train_time:146155ms step_avg:90.11ms +step:1623/1680 train_time:146247ms step_avg:90.11ms +step:1624/1680 train_time:146337ms step_avg:90.11ms +step:1625/1680 train_time:146429ms step_avg:90.11ms +step:1625/1680 val_loss:3.2902 train_time:146521ms step_avg:90.17ms +step:1626/1680 train_time:146544ms step_avg:90.13ms +step:1627/1680 train_time:146615ms step_avg:90.11ms +step:1628/1680 train_time:146709ms step_avg:90.12ms +step:1629/1680 train_time:146801ms step_avg:90.12ms +step:1630/1680 train_time:146890ms step_avg:90.12ms +step:1631/1680 train_time:146981ms step_avg:90.12ms +step:1632/1680 train_time:147070ms step_avg:90.12ms +step:1633/1680 train_time:147159ms step_avg:90.12ms +step:1634/1680 train_time:147248ms step_avg:90.12ms +step:1635/1680 train_time:147338ms step_avg:90.11ms +step:1636/1680 train_time:147428ms step_avg:90.12ms +step:1637/1680 train_time:147522ms step_avg:90.12ms +step:1638/1680 train_time:147614ms step_avg:90.12ms +step:1639/1680 train_time:147706ms step_avg:90.12ms +step:1640/1680 train_time:147798ms step_avg:90.12ms +step:1641/1680 train_time:147889ms step_avg:90.12ms +step:1642/1680 train_time:147980ms step_avg:90.12ms +step:1643/1680 train_time:148071ms 
step_avg:90.12ms +step:1644/1680 train_time:148161ms step_avg:90.12ms +step:1645/1680 train_time:148251ms step_avg:90.12ms +step:1646/1680 train_time:148342ms step_avg:90.12ms +step:1647/1680 train_time:148432ms step_avg:90.12ms +step:1648/1680 train_time:148524ms step_avg:90.12ms +step:1649/1680 train_time:148616ms step_avg:90.12ms +step:1650/1680 train_time:148707ms step_avg:90.13ms +step:1651/1680 train_time:148798ms step_avg:90.13ms +step:1652/1680 train_time:148890ms step_avg:90.13ms +step:1653/1680 train_time:148981ms step_avg:90.13ms +step:1654/1680 train_time:149071ms step_avg:90.13ms +step:1655/1680 train_time:149162ms step_avg:90.13ms +step:1656/1680 train_time:149253ms step_avg:90.13ms +step:1657/1680 train_time:149343ms step_avg:90.13ms +step:1658/1680 train_time:149434ms step_avg:90.13ms +step:1659/1680 train_time:149526ms step_avg:90.13ms +step:1660/1680 train_time:149616ms step_avg:90.13ms +step:1661/1680 train_time:149708ms step_avg:90.13ms +step:1662/1680 train_time:149799ms step_avg:90.13ms +step:1663/1680 train_time:149891ms step_avg:90.13ms +step:1664/1680 train_time:149982ms step_avg:90.13ms +step:1665/1680 train_time:150072ms step_avg:90.13ms +step:1666/1680 train_time:150163ms step_avg:90.13ms +step:1667/1680 train_time:150253ms step_avg:90.13ms +step:1668/1680 train_time:150343ms step_avg:90.13ms +step:1669/1680 train_time:150434ms step_avg:90.13ms +step:1670/1680 train_time:150524ms step_avg:90.13ms +step:1671/1680 train_time:150616ms step_avg:90.14ms +step:1672/1680 train_time:150706ms step_avg:90.14ms +step:1673/1680 train_time:150799ms step_avg:90.14ms +step:1674/1680 train_time:150888ms step_avg:90.14ms +step:1675/1680 train_time:150980ms step_avg:90.14ms +step:1676/1680 train_time:151072ms step_avg:90.14ms +step:1677/1680 train_time:151162ms step_avg:90.14ms +step:1678/1680 train_time:151254ms step_avg:90.14ms +step:1679/1680 train_time:151343ms step_avg:90.14ms +step:1680/1680 train_time:151433ms step_avg:90.14ms +step:1680/1680 val_loss:3.2798 train_time:151526ms step_avg:90.19ms +peak memory allocated: 31255 MiB reserved: 46334 MiB diff --git a/records/092125_DropAttn/06350b97-4a98-47da-90c1-3b957af8af6c.txt b/records/092125_DropAttn/06350b97-4a98-47da-90c1-3b957af8af6c.txt new file mode 100644 index 000000000..6a8d4a811 --- /dev/null +++ b/records/092125_DropAttn/06350b97-4a98-47da-90c1-3b957af8af6c.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
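+        # The step below is organized in three phases, in the spirit of ZeRO,
+        # so that communication overlaps with compute:
+        #   1. reduce_scatter: average gradients across ranks, leaving each
+        #      rank with a 1/world_size shard of every group's stacked grads;
+        #   2. local update: each rank applies momentum, Newton-Schulz
+        #      orthogonalization, weight decay, and the LR step to its shard;
+        #   3. all_gather: reassemble the updated shards so every rank ends
+        #      the step with identical full parameters.
+        # Ignoring sharding, padding, and the lr_mul/wd_mul attributes, the
+        # per-parameter update is roughly:
+        #     buf.lerp_(g, 1 - momentum)                 # EMA of gradients
+        #     u = g.lerp(buf, momentum)                  # Nesterov-style blend
+        #     p.mul_(1 - lr * weight_decay)
+        #     p.add_(newton_schulz_triton(u), alpha=-lr * max(1, h / w) ** 0.5)
+        # where (h, w) are the parameter's last two dimensions.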
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
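+                # The next few lines implement Nesterov-style momentum via lerp:
+                #   momentum_buffer <- momentum * momentum_buffer + (1 - momentum) * grad
+                #   update_grad     <- (1 - momentum) * grad + momentum * momentum_buffer
+                # update_grad is then orthogonalized by the batched Newton-Schulz call below.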
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
+        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    # report idx here: cur may not be bound yet on the first pass
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
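+
+# Reference-only sketch (never called in this run): a plain-PyTorch equivalent of the
+# quintic Newton-Schulz iteration that newton_schulz_triton implements above with fused
+# Triton kernels. The helper name is illustrative, not part of the original script.
+def newton_schulz_reference(G: Tensor, steps: int = 5) -> Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT # operate on the wide orientation
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) # ensure spectral norm is at most 1
+    for _ in range(steps):
+        A = X @ X.mT              # ns_line_1
+        B = b * A + c * A @ A     # ns_line_2
+        X = a * X + B @ X         # ns_line_3
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X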
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
+    # each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.reset()
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the pre-warmup state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+ws = args.ws_schedule[0]
+del initial_state
+
+########################################
+#        Training and validation       #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sun Sep 21 22:37:42 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 43C P0 127W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 39C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 63689 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 63690 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63691 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63692 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63693 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63694 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63695 C /usr/bin/python3 614MiB | +| 0 N/A N/A 63696 C /usr/bin/python3 614MiB | +| 1 N/A N/A 63690 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 63691 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 63692 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 63693 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 63694 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 63695 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 63696 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:152ms step_avg:152.15ms +step:2/1680 train_time:176ms step_avg:87.82ms +step:3/1680 train_time:238ms step_avg:79.18ms +step:4/1680 train_time:325ms step_avg:81.16ms +step:5/1680 train_time:412ms step_avg:82.49ms +step:6/1680 train_time:503ms step_avg:83.78ms +step:7/1680 train_time:589ms step_avg:84.12ms +step:8/1680 train_time:677ms 
step_avg:84.68ms +step:9/1680 train_time:765ms step_avg:85.05ms +step:10/1680 train_time:854ms step_avg:85.38ms +step:11/1680 train_time:942ms step_avg:85.66ms +step:12/1680 train_time:1031ms step_avg:85.92ms +step:13/1680 train_time:1124ms step_avg:86.48ms +step:14/1680 train_time:1216ms step_avg:86.84ms +step:15/1680 train_time:1305ms step_avg:87.03ms +step:16/1680 train_time:1395ms step_avg:87.17ms +step:17/1680 train_time:1484ms step_avg:87.28ms +step:18/1680 train_time:1572ms step_avg:87.36ms +step:19/1680 train_time:1661ms step_avg:87.41ms +step:20/1680 train_time:1750ms step_avg:87.50ms +step:21/1680 train_time:1839ms step_avg:87.58ms +step:22/1680 train_time:1928ms step_avg:87.61ms +step:23/1680 train_time:2018ms step_avg:87.73ms +step:24/1680 train_time:2108ms step_avg:87.85ms +step:25/1680 train_time:2199ms step_avg:87.96ms +step:26/1680 train_time:2288ms step_avg:88.01ms +step:27/1680 train_time:2378ms step_avg:88.06ms +step:28/1680 train_time:2467ms step_avg:88.11ms +step:29/1680 train_time:2556ms step_avg:88.14ms +step:30/1680 train_time:2646ms step_avg:88.19ms +step:31/1680 train_time:2735ms step_avg:88.22ms +step:32/1680 train_time:2824ms step_avg:88.24ms +step:33/1680 train_time:2912ms step_avg:88.26ms +step:34/1680 train_time:3002ms step_avg:88.30ms +step:35/1680 train_time:3092ms step_avg:88.34ms +step:36/1680 train_time:3182ms step_avg:88.38ms +step:37/1680 train_time:3270ms step_avg:88.38ms +step:38/1680 train_time:3359ms step_avg:88.39ms +step:39/1680 train_time:3448ms step_avg:88.41ms +step:40/1680 train_time:3537ms step_avg:88.44ms +step:41/1680 train_time:3627ms step_avg:88.46ms +step:42/1680 train_time:3716ms step_avg:88.48ms +step:43/1680 train_time:3805ms step_avg:88.49ms +step:44/1680 train_time:3895ms step_avg:88.51ms +step:45/1680 train_time:3984ms step_avg:88.54ms +step:46/1680 train_time:4074ms step_avg:88.57ms +step:47/1680 train_time:4165ms step_avg:88.61ms +step:48/1680 train_time:4254ms step_avg:88.63ms +step:49/1680 train_time:4343ms step_avg:88.64ms +step:50/1680 train_time:4433ms step_avg:88.65ms +step:51/1680 train_time:4522ms step_avg:88.66ms +step:52/1680 train_time:4610ms step_avg:88.66ms +step:53/1680 train_time:4699ms step_avg:88.66ms +step:54/1680 train_time:4787ms step_avg:88.65ms +step:55/1680 train_time:4877ms step_avg:88.67ms +step:56/1680 train_time:4966ms step_avg:88.67ms +step:57/1680 train_time:5055ms step_avg:88.68ms +step:58/1680 train_time:5145ms step_avg:88.70ms +step:59/1680 train_time:5235ms step_avg:88.73ms +step:60/1680 train_time:5324ms step_avg:88.74ms +step:61/1680 train_time:5414ms step_avg:88.75ms +step:62/1680 train_time:5503ms step_avg:88.76ms +step:63/1680 train_time:5592ms step_avg:88.76ms +step:64/1680 train_time:5681ms step_avg:88.77ms +step:65/1680 train_time:5771ms step_avg:88.78ms +step:66/1680 train_time:5859ms step_avg:88.78ms +step:67/1680 train_time:5948ms step_avg:88.78ms +step:68/1680 train_time:6038ms step_avg:88.79ms +step:69/1680 train_time:6127ms step_avg:88.80ms +step:70/1680 train_time:6216ms step_avg:88.80ms +step:71/1680 train_time:6306ms step_avg:88.82ms +step:72/1680 train_time:6394ms step_avg:88.80ms +step:73/1680 train_time:6484ms step_avg:88.82ms +step:74/1680 train_time:6573ms step_avg:88.82ms +step:75/1680 train_time:6662ms step_avg:88.83ms +step:76/1680 train_time:6752ms step_avg:88.84ms +step:77/1680 train_time:6841ms step_avg:88.84ms +step:78/1680 train_time:6930ms step_avg:88.84ms +step:79/1680 train_time:7019ms step_avg:88.85ms +step:80/1680 train_time:7108ms step_avg:88.85ms +step:81/1680 
train_time:7197ms step_avg:88.85ms +step:82/1680 train_time:7287ms step_avg:88.86ms +step:83/1680 train_time:7376ms step_avg:88.87ms +step:84/1680 train_time:7465ms step_avg:88.87ms +step:85/1680 train_time:7555ms step_avg:88.89ms +step:86/1680 train_time:7645ms step_avg:88.90ms +step:87/1680 train_time:7734ms step_avg:88.90ms +step:88/1680 train_time:7824ms step_avg:88.91ms +step:89/1680 train_time:7913ms step_avg:88.91ms +step:90/1680 train_time:8002ms step_avg:88.91ms +step:91/1680 train_time:8092ms step_avg:88.92ms +step:92/1680 train_time:8181ms step_avg:88.92ms +step:93/1680 train_time:8271ms step_avg:88.94ms +step:94/1680 train_time:8360ms step_avg:88.93ms +step:95/1680 train_time:8449ms step_avg:88.94ms +step:96/1680 train_time:8538ms step_avg:88.94ms +step:97/1680 train_time:8627ms step_avg:88.94ms +step:98/1680 train_time:8716ms step_avg:88.94ms +step:99/1680 train_time:8805ms step_avg:88.94ms +step:100/1680 train_time:8895ms step_avg:88.95ms +step:101/1680 train_time:8984ms step_avg:88.95ms +step:102/1680 train_time:9073ms step_avg:88.95ms +step:103/1680 train_time:9163ms step_avg:88.96ms +step:104/1680 train_time:9252ms step_avg:88.96ms +step:105/1680 train_time:9341ms step_avg:88.96ms +step:106/1680 train_time:9430ms step_avg:88.96ms +step:107/1680 train_time:9519ms step_avg:88.96ms +step:108/1680 train_time:9608ms step_avg:88.96ms +step:109/1680 train_time:9696ms step_avg:88.96ms +step:110/1680 train_time:9785ms step_avg:88.96ms +step:111/1680 train_time:9874ms step_avg:88.96ms +step:112/1680 train_time:9964ms step_avg:88.96ms +step:113/1680 train_time:10053ms step_avg:88.96ms +step:114/1680 train_time:10143ms step_avg:88.97ms +step:115/1680 train_time:10233ms step_avg:88.98ms +step:116/1680 train_time:10322ms step_avg:88.98ms +step:117/1680 train_time:10411ms step_avg:88.98ms +step:118/1680 train_time:10500ms step_avg:88.98ms +step:119/1680 train_time:10589ms step_avg:88.99ms +step:120/1680 train_time:10678ms step_avg:88.98ms +step:121/1680 train_time:10767ms step_avg:88.98ms +step:122/1680 train_time:10856ms step_avg:88.99ms +step:123/1680 train_time:10945ms step_avg:88.99ms +step:124/1680 train_time:11035ms step_avg:88.99ms +step:125/1680 train_time:11125ms step_avg:89.00ms +step:125/1680 val_loss:4.3167 train_time:11216ms step_avg:89.73ms +step:126/1680 train_time:11239ms step_avg:89.20ms +step:127/1680 train_time:11309ms step_avg:89.05ms +step:128/1680 train_time:11405ms step_avg:89.10ms +step:129/1680 train_time:11496ms step_avg:89.12ms +step:130/1680 train_time:11585ms step_avg:89.12ms +step:131/1680 train_time:11673ms step_avg:89.11ms +step:132/1680 train_time:11761ms step_avg:89.10ms +step:133/1680 train_time:11849ms step_avg:89.09ms +step:134/1680 train_time:11937ms step_avg:89.09ms +step:135/1680 train_time:12025ms step_avg:89.08ms +step:136/1680 train_time:12113ms step_avg:89.07ms +step:137/1680 train_time:12202ms step_avg:89.07ms +step:138/1680 train_time:12293ms step_avg:89.08ms +step:139/1680 train_time:12384ms step_avg:89.10ms +step:140/1680 train_time:12475ms step_avg:89.11ms +step:141/1680 train_time:12565ms step_avg:89.11ms +step:142/1680 train_time:12654ms step_avg:89.12ms +step:143/1680 train_time:12743ms step_avg:89.11ms +step:144/1680 train_time:12831ms step_avg:89.11ms +step:145/1680 train_time:12920ms step_avg:89.11ms +step:146/1680 train_time:13010ms step_avg:89.11ms +step:147/1680 train_time:13098ms step_avg:89.10ms +step:148/1680 train_time:13185ms step_avg:89.09ms +step:149/1680 train_time:13274ms step_avg:89.09ms +step:150/1680 train_time:13364ms 
step_avg:89.09ms +step:151/1680 train_time:13454ms step_avg:89.10ms +step:152/1680 train_time:13544ms step_avg:89.10ms +step:153/1680 train_time:13633ms step_avg:89.11ms +step:154/1680 train_time:13723ms step_avg:89.11ms +step:155/1680 train_time:13812ms step_avg:89.11ms +step:156/1680 train_time:13901ms step_avg:89.11ms +step:157/1680 train_time:13989ms step_avg:89.10ms +step:158/1680 train_time:14076ms step_avg:89.09ms +step:159/1680 train_time:14165ms step_avg:89.09ms +step:160/1680 train_time:14253ms step_avg:89.08ms +step:161/1680 train_time:14343ms step_avg:89.09ms +step:162/1680 train_time:14433ms step_avg:89.09ms +step:163/1680 train_time:14523ms step_avg:89.10ms +step:164/1680 train_time:14612ms step_avg:89.10ms +step:165/1680 train_time:14701ms step_avg:89.10ms +step:166/1680 train_time:14791ms step_avg:89.10ms +step:167/1680 train_time:14880ms step_avg:89.10ms +step:168/1680 train_time:14969ms step_avg:89.10ms +step:169/1680 train_time:15061ms step_avg:89.12ms +step:170/1680 train_time:15147ms step_avg:89.10ms +step:171/1680 train_time:15235ms step_avg:89.09ms +step:172/1680 train_time:15324ms step_avg:89.09ms +step:173/1680 train_time:15413ms step_avg:89.09ms +step:174/1680 train_time:15503ms step_avg:89.10ms +step:175/1680 train_time:15592ms step_avg:89.10ms +step:176/1680 train_time:15682ms step_avg:89.10ms +step:177/1680 train_time:15771ms step_avg:89.10ms +step:178/1680 train_time:15860ms step_avg:89.10ms +step:179/1680 train_time:15949ms step_avg:89.10ms +step:180/1680 train_time:16038ms step_avg:89.10ms +step:181/1680 train_time:16127ms step_avg:89.10ms +step:182/1680 train_time:16215ms step_avg:89.09ms +step:183/1680 train_time:16304ms step_avg:89.09ms +step:184/1680 train_time:16393ms step_avg:89.09ms +step:185/1680 train_time:16482ms step_avg:89.09ms +step:186/1680 train_time:16571ms step_avg:89.09ms +step:187/1680 train_time:16661ms step_avg:89.09ms +step:188/1680 train_time:16750ms step_avg:89.10ms +step:189/1680 train_time:16839ms step_avg:89.10ms +step:190/1680 train_time:16928ms step_avg:89.10ms +step:191/1680 train_time:17017ms step_avg:89.09ms +step:192/1680 train_time:17106ms step_avg:89.09ms +step:193/1680 train_time:17195ms step_avg:89.09ms +step:194/1680 train_time:17284ms step_avg:89.09ms +step:195/1680 train_time:17373ms step_avg:89.09ms +step:196/1680 train_time:17461ms step_avg:89.09ms +step:197/1680 train_time:17551ms step_avg:89.09ms +step:198/1680 train_time:17640ms step_avg:89.09ms +step:199/1680 train_time:17729ms step_avg:89.09ms +step:200/1680 train_time:17819ms step_avg:89.09ms +step:201/1680 train_time:17908ms step_avg:89.09ms +step:202/1680 train_time:17997ms step_avg:89.09ms +step:203/1680 train_time:18086ms step_avg:89.10ms +step:204/1680 train_time:18175ms step_avg:89.10ms +step:205/1680 train_time:18264ms step_avg:89.09ms +step:206/1680 train_time:18353ms step_avg:89.09ms +step:207/1680 train_time:18441ms step_avg:89.09ms +step:208/1680 train_time:18530ms step_avg:89.09ms +step:209/1680 train_time:18619ms step_avg:89.08ms +step:210/1680 train_time:18708ms step_avg:89.08ms +step:211/1680 train_time:18798ms step_avg:89.09ms +step:212/1680 train_time:18887ms step_avg:89.09ms +step:213/1680 train_time:18976ms step_avg:89.09ms +step:214/1680 train_time:19064ms step_avg:89.09ms +step:215/1680 train_time:19153ms step_avg:89.08ms +step:216/1680 train_time:19243ms step_avg:89.09ms +step:217/1680 train_time:19332ms step_avg:89.09ms +step:218/1680 train_time:19421ms step_avg:89.09ms +step:219/1680 train_time:19510ms step_avg:89.09ms +step:220/1680 
train_time:19599ms step_avg:89.08ms +step:221/1680 train_time:19688ms step_avg:89.09ms +step:222/1680 train_time:19777ms step_avg:89.08ms +step:223/1680 train_time:19865ms step_avg:89.08ms +step:224/1680 train_time:19954ms step_avg:89.08ms +step:225/1680 train_time:20043ms step_avg:89.08ms +step:226/1680 train_time:20132ms step_avg:89.08ms +step:227/1680 train_time:20221ms step_avg:89.08ms +step:228/1680 train_time:20309ms step_avg:89.08ms +step:229/1680 train_time:20399ms step_avg:89.08ms +step:230/1680 train_time:20488ms step_avg:89.08ms +step:231/1680 train_time:20577ms step_avg:89.08ms +step:232/1680 train_time:20666ms step_avg:89.08ms +step:233/1680 train_time:20755ms step_avg:89.08ms +step:234/1680 train_time:20843ms step_avg:89.07ms +step:235/1680 train_time:20932ms step_avg:89.07ms +step:236/1680 train_time:21021ms step_avg:89.07ms +step:237/1680 train_time:21110ms step_avg:89.07ms +step:238/1680 train_time:21199ms step_avg:89.07ms +step:239/1680 train_time:21288ms step_avg:89.07ms +step:240/1680 train_time:21378ms step_avg:89.07ms +step:241/1680 train_time:21467ms step_avg:89.07ms +step:242/1680 train_time:21556ms step_avg:89.08ms +step:243/1680 train_time:21645ms step_avg:89.07ms +step:244/1680 train_time:21733ms step_avg:89.07ms +step:245/1680 train_time:21822ms step_avg:89.07ms +step:246/1680 train_time:21911ms step_avg:89.07ms +step:247/1680 train_time:22001ms step_avg:89.07ms +step:248/1680 train_time:22090ms step_avg:89.07ms +step:249/1680 train_time:22180ms step_avg:89.07ms +step:250/1680 train_time:22268ms step_avg:89.07ms +step:250/1680 val_loss:3.9755 train_time:22357ms step_avg:89.43ms +step:251/1680 train_time:22380ms step_avg:89.16ms +step:252/1680 train_time:22449ms step_avg:89.08ms +step:253/1680 train_time:22544ms step_avg:89.11ms +step:254/1680 train_time:22638ms step_avg:89.13ms +step:255/1680 train_time:22728ms step_avg:89.13ms +step:256/1680 train_time:22816ms step_avg:89.13ms +step:257/1680 train_time:22904ms step_avg:89.12ms +step:258/1680 train_time:22992ms step_avg:89.12ms +step:259/1680 train_time:23080ms step_avg:89.11ms +step:260/1680 train_time:23168ms step_avg:89.11ms +step:261/1680 train_time:23256ms step_avg:89.10ms +step:262/1680 train_time:23344ms step_avg:89.10ms +step:263/1680 train_time:23434ms step_avg:89.10ms +step:264/1680 train_time:23525ms step_avg:89.11ms +step:265/1680 train_time:23616ms step_avg:89.12ms +step:266/1680 train_time:23706ms step_avg:89.12ms +step:267/1680 train_time:23795ms step_avg:89.12ms +step:268/1680 train_time:23883ms step_avg:89.12ms +step:269/1680 train_time:23971ms step_avg:89.11ms +step:270/1680 train_time:24059ms step_avg:89.11ms +step:271/1680 train_time:24147ms step_avg:89.10ms +step:272/1680 train_time:24235ms step_avg:89.10ms +step:273/1680 train_time:24324ms step_avg:89.10ms +step:274/1680 train_time:24412ms step_avg:89.10ms +step:275/1680 train_time:24503ms step_avg:89.10ms +step:276/1680 train_time:24593ms step_avg:89.11ms +step:277/1680 train_time:24684ms step_avg:89.11ms +step:278/1680 train_time:24773ms step_avg:89.11ms +step:279/1680 train_time:24861ms step_avg:89.11ms +step:280/1680 train_time:24949ms step_avg:89.10ms +step:281/1680 train_time:25038ms step_avg:89.10ms +step:282/1680 train_time:25126ms step_avg:89.10ms +step:283/1680 train_time:25215ms step_avg:89.10ms +step:284/1680 train_time:25303ms step_avg:89.10ms +step:285/1680 train_time:25393ms step_avg:89.10ms +step:286/1680 train_time:25482ms step_avg:89.10ms +step:287/1680 train_time:25572ms step_avg:89.10ms +step:288/1680 train_time:25662ms 
step_avg:89.11ms +step:289/1680 train_time:25751ms step_avg:89.10ms +step:290/1680 train_time:25841ms step_avg:89.11ms +step:291/1680 train_time:25929ms step_avg:89.10ms +step:292/1680 train_time:26018ms step_avg:89.10ms +step:293/1680 train_time:26106ms step_avg:89.10ms +step:294/1680 train_time:26195ms step_avg:89.10ms +step:295/1680 train_time:26283ms step_avg:89.10ms +step:296/1680 train_time:26372ms step_avg:89.09ms +step:297/1680 train_time:26460ms step_avg:89.09ms +step:298/1680 train_time:26549ms step_avg:89.09ms +step:299/1680 train_time:26639ms step_avg:89.09ms +step:300/1680 train_time:26727ms step_avg:89.09ms +step:301/1680 train_time:26817ms step_avg:89.09ms +step:302/1680 train_time:26906ms step_avg:89.09ms +step:303/1680 train_time:26994ms step_avg:89.09ms +step:304/1680 train_time:27082ms step_avg:89.09ms +step:305/1680 train_time:27171ms step_avg:89.08ms +step:306/1680 train_time:27259ms step_avg:89.08ms +step:307/1680 train_time:27348ms step_avg:89.08ms +step:308/1680 train_time:27437ms step_avg:89.08ms +step:309/1680 train_time:27526ms step_avg:89.08ms +step:310/1680 train_time:27615ms step_avg:89.08ms +step:311/1680 train_time:27704ms step_avg:89.08ms +step:312/1680 train_time:27793ms step_avg:89.08ms +step:313/1680 train_time:27882ms step_avg:89.08ms +step:314/1680 train_time:27971ms step_avg:89.08ms +step:315/1680 train_time:28061ms step_avg:89.08ms +step:316/1680 train_time:28150ms step_avg:89.08ms +step:317/1680 train_time:28239ms step_avg:89.08ms +step:318/1680 train_time:28327ms step_avg:89.08ms +step:319/1680 train_time:28416ms step_avg:89.08ms +step:320/1680 train_time:28505ms step_avg:89.08ms +step:321/1680 train_time:28594ms step_avg:89.08ms +step:322/1680 train_time:28683ms step_avg:89.08ms +step:323/1680 train_time:28771ms step_avg:89.08ms +step:324/1680 train_time:28861ms step_avg:89.08ms +step:325/1680 train_time:28949ms step_avg:89.07ms +step:326/1680 train_time:29037ms step_avg:89.07ms +step:327/1680 train_time:29126ms step_avg:89.07ms +step:328/1680 train_time:29215ms step_avg:89.07ms +step:329/1680 train_time:29304ms step_avg:89.07ms +step:330/1680 train_time:29392ms step_avg:89.07ms +step:331/1680 train_time:29482ms step_avg:89.07ms +step:332/1680 train_time:29571ms step_avg:89.07ms +step:333/1680 train_time:29661ms step_avg:89.07ms +step:334/1680 train_time:29750ms step_avg:89.07ms +step:335/1680 train_time:29839ms step_avg:89.07ms +step:336/1680 train_time:29928ms step_avg:89.07ms +step:337/1680 train_time:30017ms step_avg:89.07ms +step:338/1680 train_time:30105ms step_avg:89.07ms +step:339/1680 train_time:30195ms step_avg:89.07ms +step:340/1680 train_time:30283ms step_avg:89.07ms +step:341/1680 train_time:30372ms step_avg:89.07ms +step:342/1680 train_time:30461ms step_avg:89.07ms +step:343/1680 train_time:30549ms step_avg:89.06ms +step:344/1680 train_time:30638ms step_avg:89.06ms +step:345/1680 train_time:30727ms step_avg:89.06ms +step:346/1680 train_time:30816ms step_avg:89.06ms +step:347/1680 train_time:30905ms step_avg:89.06ms +step:348/1680 train_time:30994ms step_avg:89.06ms +step:349/1680 train_time:31083ms step_avg:89.06ms +step:350/1680 train_time:31172ms step_avg:89.06ms +step:351/1680 train_time:31262ms step_avg:89.06ms +step:352/1680 train_time:31349ms step_avg:89.06ms +step:353/1680 train_time:31438ms step_avg:89.06ms +step:354/1680 train_time:31526ms step_avg:89.06ms +step:355/1680 train_time:31618ms step_avg:89.07ms +step:356/1680 train_time:31705ms step_avg:89.06ms +step:357/1680 train_time:31794ms step_avg:89.06ms +step:358/1680 
train_time:31883ms step_avg:89.06ms
+step:359/1680 train_time:31972ms step_avg:89.06ms
+step:360/1680 train_time:32062ms step_avg:89.06ms
+step:361/1680 train_time:32150ms step_avg:89.06ms
+step:362/1680 train_time:32239ms step_avg:89.06ms
+step:363/1680 train_time:32327ms step_avg:89.06ms
+step:364/1680 train_time:32416ms step_avg:89.06ms
+step:365/1680 train_time:32505ms step_avg:89.05ms
+step:366/1680 train_time:32594ms step_avg:89.05ms
+step:367/1680 train_time:32684ms step_avg:89.06ms
+step:368/1680 train_time:32773ms step_avg:89.06ms
+step:369/1680 train_time:32863ms step_avg:89.06ms
+step:370/1680 train_time:32951ms step_avg:89.06ms
+step:371/1680 train_time:33041ms step_avg:89.06ms
+step:372/1680 train_time:33130ms step_avg:89.06ms
+step:373/1680 train_time:33220ms step_avg:89.06ms
+step:374/1680 train_time:33310ms step_avg:89.06ms
+step:375/1680 train_time:33399ms step_avg:89.06ms
+step:375/1680 val_loss:3.8193 train_time:33488ms step_avg:89.30ms
+step:376/1680 train_time:33511ms step_avg:89.12ms
+step:377/1680 train_time:33580ms step_avg:89.07ms
+step:378/1680 train_time:33672ms step_avg:89.08ms
+step:379/1680 train_time:33761ms step_avg:89.08ms
+step:380/1680 train_time:33849ms step_avg:89.08ms
+step:381/1680 train_time:33937ms step_avg:89.07ms
+step:382/1680 train_time:34026ms step_avg:89.07ms
+step:383/1680 train_time:34113ms step_avg:89.07ms
+step:384/1680 train_time:34201ms step_avg:89.07ms
+step:385/1680 train_time:34290ms step_avg:89.06ms
+step:386/1680 train_time:34379ms step_avg:89.06ms
+step:387/1680 train_time:34469ms step_avg:89.07ms
+step:388/1680 train_time:34560ms step_avg:89.07ms
+step:389/1680 train_time:34650ms step_avg:89.07ms
+step:390/1680 train_time:34740ms step_avg:89.08ms
+step:391/1680 train_time:34828ms step_avg:89.08ms
+step:392/1680 train_time:34917ms step_avg:89.07ms
+step:393/1680 train_time:35006ms step_avg:89.07ms
+step:394/1680 train_time:35094ms step_avg:89.07ms
+step:395/1680 train_time:35187ms step_avg:89.08ms
+step:396/1680 train_time:35270ms step_avg:89.07ms
+step:397/1680 train_time:35359ms step_avg:89.06ms
+step:398/1680 train_time:35448ms step_avg:89.07ms
+step:399/1680 train_time:35538ms step_avg:89.07ms
+step:400/1680 train_time:35628ms step_avg:89.07ms
+step:401/1680 train_time:35718ms step_avg:89.07ms
+step:402/1680 train_time:35808ms step_avg:89.07ms
+step:403/1680 train_time:35896ms step_avg:89.07ms
+step:404/1680 train_time:35986ms step_avg:89.07ms
+step:405/1680 train_time:36074ms step_avg:89.07ms
+step:406/1680 train_time:36162ms step_avg:89.07ms
+step:407/1680 train_time:36250ms step_avg:89.07ms
+step:408/1680 train_time:36338ms step_avg:89.06ms
+step:409/1680 train_time:36428ms step_avg:89.07ms
+step:410/1680 train_time:36517ms step_avg:89.07ms
+step:411/1680 train_time:36607ms step_avg:89.07ms
+step:412/1680 train_time:36697ms step_avg:89.07ms
+step:413/1680 train_time:36787ms step_avg:89.07ms
+step:414/1680 train_time:36876ms step_avg:89.07ms
+step:415/1680 train_time:36964ms step_avg:89.07ms
+step:416/1680 train_time:37053ms step_avg:89.07ms
+step:417/1680 train_time:37141ms step_avg:89.07ms
+step:418/1680 train_time:37229ms step_avg:89.07ms
+step:419/1680 train_time:37319ms step_avg:89.07ms
+step:420/1680 train_time:37408ms step_avg:89.07ms
+step:421/1680 train_time:37497ms step_avg:89.07ms
+step:422/1680 train_time:37587ms step_avg:89.07ms
+step:423/1680 train_time:37677ms step_avg:89.07ms
+step:424/1680 train_time:37766ms step_avg:89.07ms
+step:425/1680 train_time:37856ms step_avg:89.07ms
+step:426/1680 train_time:37945ms step_avg:89.07ms
+step:427/1680 train_time:38034ms step_avg:89.07ms
+step:428/1680 train_time:38122ms step_avg:89.07ms
+step:429/1680 train_time:38211ms step_avg:89.07ms
+step:430/1680 train_time:38299ms step_avg:89.07ms
+step:431/1680 train_time:38389ms step_avg:89.07ms
+step:432/1680 train_time:38477ms step_avg:89.07ms
+step:433/1680 train_time:38567ms step_avg:89.07ms
+step:434/1680 train_time:38656ms step_avg:89.07ms
+step:435/1680 train_time:38745ms step_avg:89.07ms
+step:436/1680 train_time:38835ms step_avg:89.07ms
+step:437/1680 train_time:38925ms step_avg:89.07ms
+step:438/1680 train_time:39014ms step_avg:89.07ms
+step:439/1680 train_time:39104ms step_avg:89.08ms
+step:440/1680 train_time:39192ms step_avg:89.07ms
+step:441/1680 train_time:39281ms step_avg:89.07ms
+step:442/1680 train_time:39369ms step_avg:89.07ms
+step:443/1680 train_time:39458ms step_avg:89.07ms
+step:444/1680 train_time:39548ms step_avg:89.07ms
+step:445/1680 train_time:39637ms step_avg:89.07ms
+step:446/1680 train_time:39727ms step_avg:89.07ms
+step:447/1680 train_time:39816ms step_avg:89.07ms
+step:448/1680 train_time:39905ms step_avg:89.07ms
+step:449/1680 train_time:39995ms step_avg:89.08ms
+step:450/1680 train_time:40085ms step_avg:89.08ms
+step:451/1680 train_time:40174ms step_avg:89.08ms
+step:452/1680 train_time:40262ms step_avg:89.08ms
+step:453/1680 train_time:40351ms step_avg:89.07ms
+step:454/1680 train_time:40440ms step_avg:89.07ms
+step:455/1680 train_time:40529ms step_avg:89.08ms
+step:456/1680 train_time:40619ms step_avg:89.08ms
+step:457/1680 train_time:40708ms step_avg:89.08ms
+step:458/1680 train_time:40798ms step_avg:89.08ms
+step:459/1680 train_time:40888ms step_avg:89.08ms
+step:460/1680 train_time:40977ms step_avg:89.08ms
+step:461/1680 train_time:41066ms step_avg:89.08ms
+step:462/1680 train_time:41154ms step_avg:89.08ms
+step:463/1680 train_time:41243ms step_avg:89.08ms
+step:464/1680 train_time:41331ms step_avg:89.08ms
+step:465/1680 train_time:41419ms step_avg:89.07ms
+step:466/1680 train_time:41508ms step_avg:89.07ms
+step:467/1680 train_time:41597ms step_avg:89.07ms
+step:468/1680 train_time:41686ms step_avg:89.07ms
+step:469/1680 train_time:41775ms step_avg:89.07ms
+step:470/1680 train_time:41865ms step_avg:89.07ms
+step:471/1680 train_time:41954ms step_avg:89.07ms
+step:472/1680 train_time:42043ms step_avg:89.07ms
+step:473/1680 train_time:42132ms step_avg:89.07ms
+step:474/1680 train_time:42221ms step_avg:89.07ms
+step:475/1680 train_time:42310ms step_avg:89.07ms
+step:476/1680 train_time:42398ms step_avg:89.07ms
+step:477/1680 train_time:42489ms step_avg:89.07ms
+step:478/1680 train_time:42578ms step_avg:89.08ms
+step:479/1680 train_time:42667ms step_avg:89.08ms
+step:480/1680 train_time:42756ms step_avg:89.07ms
+step:481/1680 train_time:42845ms step_avg:89.07ms
+step:482/1680 train_time:42934ms step_avg:89.07ms
+step:483/1680 train_time:43023ms step_avg:89.08ms
+step:484/1680 train_time:43112ms step_avg:89.07ms
+step:485/1680 train_time:43201ms step_avg:89.07ms
+step:486/1680 train_time:43289ms step_avg:89.07ms
+step:487/1680 train_time:43379ms step_avg:89.07ms
+step:488/1680 train_time:43468ms step_avg:89.07ms
+step:489/1680 train_time:43557ms step_avg:89.07ms
+step:490/1680 train_time:43646ms step_avg:89.07ms
+step:491/1680 train_time:43736ms step_avg:89.07ms
+step:492/1680 train_time:43825ms step_avg:89.08ms
+step:493/1680 train_time:43915ms step_avg:89.08ms
+step:494/1680 train_time:44005ms step_avg:89.08ms
+step:495/1680 train_time:44094ms step_avg:89.08ms
+step:496/1680 train_time:44184ms step_avg:89.08ms
+step:497/1680 train_time:44273ms step_avg:89.08ms
+step:498/1680 train_time:44361ms step_avg:89.08ms
+step:499/1680 train_time:44450ms step_avg:89.08ms
+step:500/1680 train_time:44538ms step_avg:89.08ms
+step:500/1680 val_loss:3.7167 train_time:44629ms step_avg:89.26ms
+step:501/1680 train_time:44652ms step_avg:89.12ms
+step:502/1680 train_time:44721ms step_avg:89.09ms
+step:503/1680 train_time:44813ms step_avg:89.09ms
+step:504/1680 train_time:44903ms step_avg:89.09ms
+step:505/1680 train_time:44992ms step_avg:89.09ms
+step:506/1680 train_time:45079ms step_avg:89.09ms
+step:507/1680 train_time:45168ms step_avg:89.09ms
+step:508/1680 train_time:45257ms step_avg:89.09ms
+step:509/1680 train_time:45345ms step_avg:89.09ms
+step:510/1680 train_time:45433ms step_avg:89.08ms
+step:511/1680 train_time:45522ms step_avg:89.08ms
+step:512/1680 train_time:45612ms step_avg:89.09ms
+step:513/1680 train_time:45704ms step_avg:89.09ms
+step:514/1680 train_time:45795ms step_avg:89.09ms
+step:515/1680 train_time:45885ms step_avg:89.10ms
+step:516/1680 train_time:45975ms step_avg:89.10ms
+step:517/1680 train_time:46063ms step_avg:89.10ms
+step:518/1680 train_time:46152ms step_avg:89.10ms
+step:519/1680 train_time:46240ms step_avg:89.10ms
+step:520/1680 train_time:46328ms step_avg:89.09ms
+step:521/1680 train_time:46417ms step_avg:89.09ms
+step:522/1680 train_time:46505ms step_avg:89.09ms
+step:523/1680 train_time:46597ms step_avg:89.09ms
+step:524/1680 train_time:46686ms step_avg:89.10ms
+step:525/1680 train_time:46776ms step_avg:89.10ms
+step:526/1680 train_time:46865ms step_avg:89.10ms
+step:527/1680 train_time:46955ms step_avg:89.10ms
+step:528/1680 train_time:47045ms step_avg:89.10ms
+step:529/1680 train_time:47134ms step_avg:89.10ms
+step:530/1680 train_time:47222ms step_avg:89.10ms
+step:531/1680 train_time:47310ms step_avg:89.10ms
+step:532/1680 train_time:47398ms step_avg:89.09ms
+step:533/1680 train_time:47487ms step_avg:89.09ms
+step:534/1680 train_time:47576ms step_avg:89.09ms
+step:535/1680 train_time:47665ms step_avg:89.09ms
+step:536/1680 train_time:47756ms step_avg:89.10ms
+step:537/1680 train_time:47846ms step_avg:89.10ms
+step:538/1680 train_time:47935ms step_avg:89.10ms
+step:539/1680 train_time:48024ms step_avg:89.10ms
+step:540/1680 train_time:48113ms step_avg:89.10ms
+step:541/1680 train_time:48202ms step_avg:89.10ms
+step:542/1680 train_time:48291ms step_avg:89.10ms
+step:543/1680 train_time:48378ms step_avg:89.09ms
+step:544/1680 train_time:48467ms step_avg:89.09ms
+step:545/1680 train_time:48556ms step_avg:89.09ms
+step:546/1680 train_time:48645ms step_avg:89.09ms
+step:547/1680 train_time:48735ms step_avg:89.10ms
+step:548/1680 train_time:48824ms step_avg:89.10ms
+step:549/1680 train_time:48916ms step_avg:89.10ms
+step:550/1680 train_time:49006ms step_avg:89.10ms
+step:551/1680 train_time:49097ms step_avg:89.11ms
+step:552/1680 train_time:49187ms step_avg:89.11ms
+step:553/1680 train_time:49277ms step_avg:89.11ms
+step:554/1680 train_time:49366ms step_avg:89.11ms
+step:555/1680 train_time:49457ms step_avg:89.11ms
+step:556/1680 train_time:49547ms step_avg:89.11ms
+step:557/1680 train_time:49638ms step_avg:89.12ms
+step:558/1680 train_time:49729ms step_avg:89.12ms
+step:559/1680 train_time:49819ms step_avg:89.12ms
+step:560/1680 train_time:49908ms step_avg:89.12ms
+step:561/1680 train_time:49998ms step_avg:89.12ms
+step:562/1680 train_time:50090ms step_avg:89.13ms
+step:563/1680 train_time:50179ms step_avg:89.13ms
+step:564/1680 train_time:50271ms step_avg:89.13ms
+step:565/1680 train_time:50359ms step_avg:89.13ms
+step:566/1680 train_time:50450ms step_avg:89.13ms
+step:567/1680 train_time:50539ms step_avg:89.13ms
+step:568/1680 train_time:50630ms step_avg:89.14ms
+step:569/1680 train_time:50720ms step_avg:89.14ms
+step:570/1680 train_time:50810ms step_avg:89.14ms
+step:571/1680 train_time:50901ms step_avg:89.14ms
+step:572/1680 train_time:50991ms step_avg:89.14ms
+step:573/1680 train_time:51081ms step_avg:89.15ms
+step:574/1680 train_time:51173ms step_avg:89.15ms
+step:575/1680 train_time:51263ms step_avg:89.15ms
+step:576/1680 train_time:51353ms step_avg:89.16ms
+step:577/1680 train_time:51444ms step_avg:89.16ms
+step:578/1680 train_time:51534ms step_avg:89.16ms
+step:579/1680 train_time:51624ms step_avg:89.16ms
+step:580/1680 train_time:51715ms step_avg:89.16ms
+step:581/1680 train_time:51806ms step_avg:89.17ms
+step:582/1680 train_time:51897ms step_avg:89.17ms
+step:583/1680 train_time:51986ms step_avg:89.17ms
+step:584/1680 train_time:52077ms step_avg:89.17ms
+step:585/1680 train_time:52167ms step_avg:89.17ms
+step:586/1680 train_time:52257ms step_avg:89.18ms
+step:587/1680 train_time:52347ms step_avg:89.18ms
+step:588/1680 train_time:52437ms step_avg:89.18ms
+step:589/1680 train_time:52528ms step_avg:89.18ms
+step:590/1680 train_time:52617ms step_avg:89.18ms
+step:591/1680 train_time:52707ms step_avg:89.18ms
+step:592/1680 train_time:52797ms step_avg:89.18ms
+step:593/1680 train_time:52887ms step_avg:89.19ms
+step:594/1680 train_time:52978ms step_avg:89.19ms
+step:595/1680 train_time:53068ms step_avg:89.19ms
+step:596/1680 train_time:53159ms step_avg:89.19ms
+step:597/1680 train_time:53249ms step_avg:89.19ms
+step:598/1680 train_time:53339ms step_avg:89.20ms
+step:599/1680 train_time:53429ms step_avg:89.20ms
+step:600/1680 train_time:53519ms step_avg:89.20ms
+step:601/1680 train_time:53608ms step_avg:89.20ms
+step:602/1680 train_time:53699ms step_avg:89.20ms
+step:603/1680 train_time:53790ms step_avg:89.20ms
+step:604/1680 train_time:53880ms step_avg:89.21ms
+step:605/1680 train_time:53969ms step_avg:89.21ms
+step:606/1680 train_time:54059ms step_avg:89.21ms
+step:607/1680 train_time:54150ms step_avg:89.21ms
+step:608/1680 train_time:54241ms step_avg:89.21ms
+step:609/1680 train_time:54332ms step_avg:89.21ms
+step:610/1680 train_time:54421ms step_avg:89.22ms
+step:611/1680 train_time:54511ms step_avg:89.22ms
+step:612/1680 train_time:54602ms step_avg:89.22ms
+step:613/1680 train_time:54693ms step_avg:89.22ms
+step:614/1680 train_time:54783ms step_avg:89.22ms
+step:615/1680 train_time:54874ms step_avg:89.23ms
+step:616/1680 train_time:54965ms step_avg:89.23ms
+step:617/1680 train_time:55054ms step_avg:89.23ms
+step:618/1680 train_time:55145ms step_avg:89.23ms
+step:619/1680 train_time:55235ms step_avg:89.23ms
+step:620/1680 train_time:55325ms step_avg:89.23ms
+step:621/1680 train_time:55416ms step_avg:89.24ms
+step:622/1680 train_time:55506ms step_avg:89.24ms
+step:623/1680 train_time:55596ms step_avg:89.24ms
+step:624/1680 train_time:55686ms step_avg:89.24ms
+step:625/1680 train_time:55777ms step_avg:89.24ms
+step:625/1680 val_loss:3.6150 train_time:55869ms step_avg:89.39ms
+step:626/1680 train_time:55892ms step_avg:89.28ms
+step:627/1680 train_time:55960ms step_avg:89.25ms
+step:628/1680 train_time:56059ms step_avg:89.27ms
+step:629/1680 train_time:56152ms step_avg:89.27ms
+step:630/1680 train_time:56240ms step_avg:89.27ms
+step:631/1680 train_time:56329ms step_avg:89.27ms
+step:632/1680 train_time:56418ms step_avg:89.27ms
+step:633/1680 train_time:56507ms step_avg:89.27ms
+step:634/1680 train_time:56595ms step_avg:89.27ms
+step:635/1680 train_time:56685ms step_avg:89.27ms
+step:636/1680 train_time:56775ms step_avg:89.27ms
+step:637/1680 train_time:56867ms step_avg:89.27ms
+step:638/1680 train_time:56960ms step_avg:89.28ms
+step:639/1680 train_time:57052ms step_avg:89.28ms
+step:640/1680 train_time:57150ms step_avg:89.30ms
+step:641/1680 train_time:57233ms step_avg:89.29ms
+step:642/1680 train_time:57323ms step_avg:89.29ms
+step:643/1680 train_time:57412ms step_avg:89.29ms
+step:644/1680 train_time:57501ms step_avg:89.29ms
+step:645/1680 train_time:57591ms step_avg:89.29ms
+step:646/1680 train_time:57681ms step_avg:89.29ms
+step:647/1680 train_time:57773ms step_avg:89.29ms
+step:648/1680 train_time:57864ms step_avg:89.30ms
+step:649/1680 train_time:57955ms step_avg:89.30ms
+step:650/1680 train_time:58047ms step_avg:89.30ms
+step:651/1680 train_time:58138ms step_avg:89.31ms
+step:652/1680 train_time:58228ms step_avg:89.31ms
+step:653/1680 train_time:58318ms step_avg:89.31ms
+step:654/1680 train_time:58407ms step_avg:89.31ms
+step:655/1680 train_time:58497ms step_avg:89.31ms
+step:656/1680 train_time:58586ms step_avg:89.31ms
+step:657/1680 train_time:58675ms step_avg:89.31ms
+step:658/1680 train_time:58766ms step_avg:89.31ms
+step:659/1680 train_time:58856ms step_avg:89.31ms
+step:660/1680 train_time:58947ms step_avg:89.31ms
+step:661/1680 train_time:59038ms step_avg:89.32ms
+step:662/1680 train_time:59128ms step_avg:89.32ms
+step:663/1680 train_time:59219ms step_avg:89.32ms
+step:664/1680 train_time:59309ms step_avg:89.32ms
+step:665/1680 train_time:59398ms step_avg:89.32ms
+step:666/1680 train_time:59488ms step_avg:89.32ms
+step:667/1680 train_time:59577ms step_avg:89.32ms
+step:668/1680 train_time:59667ms step_avg:89.32ms
+step:669/1680 train_time:59756ms step_avg:89.32ms
+step:670/1680 train_time:59847ms step_avg:89.32ms
+step:671/1680 train_time:59937ms step_avg:89.32ms
+step:672/1680 train_time:60027ms step_avg:89.33ms
+step:673/1680 train_time:60118ms step_avg:89.33ms
+step:674/1680 train_time:60209ms step_avg:89.33ms
+step:675/1680 train_time:60299ms step_avg:89.33ms
+step:676/1680 train_time:60389ms step_avg:89.33ms
+step:677/1680 train_time:60478ms step_avg:89.33ms
+step:678/1680 train_time:60568ms step_avg:89.33ms
+step:679/1680 train_time:60658ms step_avg:89.33ms
+step:680/1680 train_time:60748ms step_avg:89.34ms
+step:681/1680 train_time:60839ms step_avg:89.34ms
+step:682/1680 train_time:60930ms step_avg:89.34ms
+step:683/1680 train_time:61020ms step_avg:89.34ms
+step:684/1680 train_time:61111ms step_avg:89.34ms
+step:685/1680 train_time:61200ms step_avg:89.34ms
+step:686/1680 train_time:61291ms step_avg:89.35ms
+step:687/1680 train_time:61382ms step_avg:89.35ms
+step:688/1680 train_time:61471ms step_avg:89.35ms
+step:689/1680 train_time:61562ms step_avg:89.35ms
+step:690/1680 train_time:61653ms step_avg:89.35ms
+step:691/1680 train_time:61743ms step_avg:89.35ms
+step:692/1680 train_time:61833ms step_avg:89.35ms
+step:693/1680 train_time:61925ms step_avg:89.36ms
+step:694/1680 train_time:62016ms step_avg:89.36ms
+step:695/1680 train_time:62106ms step_avg:89.36ms
+step:696/1680 train_time:62196ms step_avg:89.36ms
+step:697/1680 train_time:62286ms step_avg:89.36ms
+step:698/1680 train_time:62376ms step_avg:89.36ms
+step:699/1680 train_time:62466ms step_avg:89.36ms
+step:700/1680 train_time:62556ms step_avg:89.37ms
+step:701/1680 train_time:62645ms step_avg:89.37ms
+step:702/1680 train_time:62735ms step_avg:89.37ms
+step:703/1680 train_time:62825ms step_avg:89.37ms
+step:704/1680 train_time:62915ms step_avg:89.37ms
+step:705/1680 train_time:63006ms step_avg:89.37ms
+step:706/1680 train_time:63096ms step_avg:89.37ms
+step:707/1680 train_time:63186ms step_avg:89.37ms
+step:708/1680 train_time:63277ms step_avg:89.37ms
+step:709/1680 train_time:63367ms step_avg:89.37ms
+step:710/1680 train_time:63457ms step_avg:89.38ms
+step:711/1680 train_time:63547ms step_avg:89.38ms
+step:712/1680 train_time:63638ms step_avg:89.38ms
+step:713/1680 train_time:63727ms step_avg:89.38ms
+step:714/1680 train_time:63818ms step_avg:89.38ms
+step:715/1680 train_time:63908ms step_avg:89.38ms
+step:716/1680 train_time:63998ms step_avg:89.38ms
+step:717/1680 train_time:64089ms step_avg:89.38ms
+step:718/1680 train_time:64179ms step_avg:89.39ms
+step:719/1680 train_time:64269ms step_avg:89.39ms
+step:720/1680 train_time:64358ms step_avg:89.39ms
+step:721/1680 train_time:64448ms step_avg:89.39ms
+step:722/1680 train_time:64538ms step_avg:89.39ms
+step:723/1680 train_time:64628ms step_avg:89.39ms
+step:724/1680 train_time:64718ms step_avg:89.39ms
+step:725/1680 train_time:64808ms step_avg:89.39ms
+step:726/1680 train_time:64898ms step_avg:89.39ms
+step:727/1680 train_time:64989ms step_avg:89.39ms
+step:728/1680 train_time:65079ms step_avg:89.39ms
+step:729/1680 train_time:65169ms step_avg:89.40ms
+step:730/1680 train_time:65260ms step_avg:89.40ms
+step:731/1680 train_time:65350ms step_avg:89.40ms
+step:732/1680 train_time:65441ms step_avg:89.40ms
+step:733/1680 train_time:65530ms step_avg:89.40ms
+step:734/1680 train_time:65620ms step_avg:89.40ms
+step:735/1680 train_time:65710ms step_avg:89.40ms
+step:736/1680 train_time:65799ms step_avg:89.40ms
+step:737/1680 train_time:65889ms step_avg:89.40ms
+step:738/1680 train_time:65979ms step_avg:89.40ms
+step:739/1680 train_time:66068ms step_avg:89.40ms
+step:740/1680 train_time:66159ms step_avg:89.40ms
+step:741/1680 train_time:66252ms step_avg:89.41ms
+step:742/1680 train_time:66340ms step_avg:89.41ms
+step:743/1680 train_time:66430ms step_avg:89.41ms
+step:744/1680 train_time:66519ms step_avg:89.41ms
+step:745/1680 train_time:66609ms step_avg:89.41ms
+step:746/1680 train_time:66699ms step_avg:89.41ms
+step:747/1680 train_time:66789ms step_avg:89.41ms
+step:748/1680 train_time:66879ms step_avg:89.41ms
+step:749/1680 train_time:66970ms step_avg:89.41ms
+step:750/1680 train_time:67060ms step_avg:89.41ms
+step:750/1680 val_loss:3.5642 train_time:67152ms step_avg:89.54ms
+step:751/1680 train_time:67175ms step_avg:89.45ms
+step:752/1680 train_time:67247ms step_avg:89.42ms
+step:753/1680 train_time:67343ms step_avg:89.43ms
+step:754/1680 train_time:67433ms step_avg:89.43ms
+step:755/1680 train_time:67523ms step_avg:89.43ms
+step:756/1680 train_time:67612ms step_avg:89.43ms
+step:757/1680 train_time:67702ms step_avg:89.43ms
+step:758/1680 train_time:67792ms step_avg:89.44ms
+step:759/1680 train_time:67881ms step_avg:89.44ms
+step:760/1680 train_time:67970ms step_avg:89.43ms
+step:761/1680 train_time:68059ms step_avg:89.43ms
+step:762/1680 train_time:68150ms step_avg:89.44ms
+step:763/1680 train_time:68243ms step_avg:89.44ms
+step:764/1680 train_time:68335ms step_avg:89.44ms
+step:765/1680 train_time:68426ms step_avg:89.45ms
+step:766/1680 train_time:68516ms step_avg:89.45ms
+step:767/1680 train_time:68606ms step_avg:89.45ms
+step:768/1680 train_time:68700ms step_avg:89.45ms
+step:769/1680 train_time:68785ms step_avg:89.45ms
+step:770/1680 train_time:68874ms step_avg:89.45ms
+step:771/1680 train_time:68962ms step_avg:89.45ms
+step:772/1680 train_time:69052ms step_avg:89.45ms
+step:773/1680 train_time:69142ms step_avg:89.45ms
+step:774/1680 train_time:69234ms step_avg:89.45ms
+step:775/1680 train_time:69327ms step_avg:89.45ms
+step:776/1680 train_time:69417ms step_avg:89.46ms
+step:777/1680 train_time:69508ms step_avg:89.46ms
+step:778/1680 train_time:69599ms step_avg:89.46ms
+step:779/1680 train_time:69689ms step_avg:89.46ms
+step:780/1680 train_time:69779ms step_avg:89.46ms
+step:781/1680 train_time:69869ms step_avg:89.46ms
+step:782/1680 train_time:69959ms step_avg:89.46ms
+step:783/1680 train_time:70048ms step_avg:89.46ms
+step:784/1680 train_time:70139ms step_avg:89.46ms
+step:785/1680 train_time:70231ms step_avg:89.47ms
+step:786/1680 train_time:70323ms step_avg:89.47ms
+step:787/1680 train_time:70414ms step_avg:89.47ms
+step:788/1680 train_time:70505ms step_avg:89.47ms
+step:789/1680 train_time:70595ms step_avg:89.47ms
+step:790/1680 train_time:70685ms step_avg:89.47ms
+step:791/1680 train_time:70775ms step_avg:89.47ms
+step:792/1680 train_time:70864ms step_avg:89.47ms
+step:793/1680 train_time:70954ms step_avg:89.48ms
+step:794/1680 train_time:71043ms step_avg:89.47ms
+step:795/1680 train_time:71133ms step_avg:89.48ms
+step:796/1680 train_time:71224ms step_avg:89.48ms
+step:797/1680 train_time:71315ms step_avg:89.48ms
+step:798/1680 train_time:71406ms step_avg:89.48ms
+step:799/1680 train_time:71496ms step_avg:89.48ms
+step:800/1680 train_time:71587ms step_avg:89.48ms
+step:801/1680 train_time:71676ms step_avg:89.48ms
+step:802/1680 train_time:71766ms step_avg:89.48ms
+step:803/1680 train_time:71857ms step_avg:89.49ms
+step:804/1680 train_time:71946ms step_avg:89.48ms
+step:805/1680 train_time:72035ms step_avg:89.48ms
+step:806/1680 train_time:72125ms step_avg:89.48ms
+step:807/1680 train_time:72215ms step_avg:89.49ms
+step:808/1680 train_time:72306ms step_avg:89.49ms
+step:809/1680 train_time:72397ms step_avg:89.49ms
+step:810/1680 train_time:72487ms step_avg:89.49ms
+step:811/1680 train_time:72578ms step_avg:89.49ms
+step:812/1680 train_time:72668ms step_avg:89.49ms
+step:813/1680 train_time:72758ms step_avg:89.49ms
+step:814/1680 train_time:72847ms step_avg:89.49ms
+step:815/1680 train_time:72936ms step_avg:89.49ms
+step:816/1680 train_time:73027ms step_avg:89.49ms
+step:817/1680 train_time:73116ms step_avg:89.49ms
+step:818/1680 train_time:73206ms step_avg:89.49ms
+step:819/1680 train_time:73297ms step_avg:89.50ms
+step:820/1680 train_time:73388ms step_avg:89.50ms
+step:821/1680 train_time:73477ms step_avg:89.50ms
+step:822/1680 train_time:73568ms step_avg:89.50ms
+step:823/1680 train_time:73658ms step_avg:89.50ms
+step:824/1680 train_time:73748ms step_avg:89.50ms
+step:825/1680 train_time:73839ms step_avg:89.50ms
+step:826/1680 train_time:73929ms step_avg:89.50ms
+step:827/1680 train_time:74018ms step_avg:89.50ms
+step:828/1680 train_time:74108ms step_avg:89.50ms
+step:829/1680 train_time:74198ms step_avg:89.50ms
+step:830/1680 train_time:74289ms step_avg:89.50ms
+step:831/1680 train_time:74379ms step_avg:89.51ms
+step:832/1680 train_time:74469ms step_avg:89.51ms
+step:833/1680 train_time:74560ms step_avg:89.51ms
+step:834/1680 train_time:74650ms step_avg:89.51ms
+step:835/1680 train_time:74740ms step_avg:89.51ms
+step:836/1680 train_time:74831ms step_avg:89.51ms
+step:837/1680 train_time:74922ms step_avg:89.51ms
+step:838/1680 train_time:75012ms step_avg:89.51ms
+step:839/1680 train_time:75102ms step_avg:89.51ms
+step:840/1680 train_time:75192ms step_avg:89.51ms
+step:841/1680 train_time:75282ms step_avg:89.52ms
+step:842/1680 train_time:75372ms step_avg:89.52ms
+step:843/1680 train_time:75463ms step_avg:89.52ms
+step:844/1680 train_time:75553ms step_avg:89.52ms
+step:845/1680 train_time:75643ms step_avg:89.52ms
+step:846/1680 train_time:75734ms step_avg:89.52ms
+step:847/1680 train_time:75825ms step_avg:89.52ms
+step:848/1680 train_time:75914ms step_avg:89.52ms
+step:849/1680 train_time:76008ms step_avg:89.53ms
+step:850/1680 train_time:76094ms step_avg:89.52ms
+step:851/1680 train_time:76185ms step_avg:89.52ms
+step:852/1680 train_time:76274ms step_avg:89.52ms
+step:853/1680 train_time:76364ms step_avg:89.52ms
+step:854/1680 train_time:76455ms step_avg:89.53ms
+step:855/1680 train_time:76545ms step_avg:89.53ms
+step:856/1680 train_time:76635ms step_avg:89.53ms
+step:857/1680 train_time:76727ms step_avg:89.53ms
+step:858/1680 train_time:76817ms step_avg:89.53ms
+step:859/1680 train_time:76907ms step_avg:89.53ms
+step:860/1680 train_time:76997ms step_avg:89.53ms
+step:861/1680 train_time:77087ms step_avg:89.53ms
+step:862/1680 train_time:77177ms step_avg:89.53ms
+step:863/1680 train_time:77267ms step_avg:89.53ms
+step:864/1680 train_time:77357ms step_avg:89.53ms
+step:865/1680 train_time:77448ms step_avg:89.53ms
+step:866/1680 train_time:77539ms step_avg:89.54ms
+step:867/1680 train_time:77629ms step_avg:89.54ms
+step:868/1680 train_time:77719ms step_avg:89.54ms
+step:869/1680 train_time:77809ms step_avg:89.54ms
+step:870/1680 train_time:77900ms step_avg:89.54ms
+step:871/1680 train_time:77991ms step_avg:89.54ms
+step:872/1680 train_time:78081ms step_avg:89.54ms
+step:873/1680 train_time:78170ms step_avg:89.54ms
+step:874/1680 train_time:78261ms step_avg:89.54ms
+step:875/1680 train_time:78351ms step_avg:89.54ms
+step:875/1680 val_loss:3.5176 train_time:78443ms step_avg:89.65ms
+step:876/1680 train_time:78466ms step_avg:89.57ms
+step:877/1680 train_time:78537ms step_avg:89.55ms
+step:878/1680 train_time:78635ms step_avg:89.56ms
+step:879/1680 train_time:78727ms step_avg:89.56ms
+step:880/1680 train_time:78817ms step_avg:89.56ms
+step:881/1680 train_time:78906ms step_avg:89.56ms
+step:882/1680 train_time:78995ms step_avg:89.56ms
+step:883/1680 train_time:79085ms step_avg:89.56ms
+step:884/1680 train_time:79175ms step_avg:89.56ms
+step:885/1680 train_time:79264ms step_avg:89.56ms
+step:886/1680 train_time:79353ms step_avg:89.56ms
+step:887/1680 train_time:79447ms step_avg:89.57ms
+step:888/1680 train_time:79537ms step_avg:89.57ms
+step:889/1680 train_time:79630ms step_avg:89.57ms
+step:890/1680 train_time:79722ms step_avg:89.58ms
+step:891/1680 train_time:79811ms step_avg:89.57ms
+step:892/1680 train_time:79901ms step_avg:89.57ms
+step:893/1680 train_time:79990ms step_avg:89.57ms
+step:894/1680 train_time:80080ms step_avg:89.57ms
+step:895/1680 train_time:80169ms step_avg:89.57ms
+step:896/1680 train_time:80259ms step_avg:89.57ms
+step:897/1680 train_time:80349ms step_avg:89.58ms
+step:898/1680 train_time:80440ms step_avg:89.58ms
+step:899/1680 train_time:80532ms step_avg:89.58ms
+step:900/1680 train_time:80622ms step_avg:89.58ms
+step:901/1680 train_time:80713ms step_avg:89.58ms
+step:902/1680 train_time:80803ms step_avg:89.58ms
+step:903/1680 train_time:80894ms step_avg:89.58ms
+step:904/1680 train_time:80984ms step_avg:89.58ms
+step:905/1680 train_time:81075ms step_avg:89.59ms
+step:906/1680 train_time:81164ms step_avg:89.58ms
+step:907/1680 train_time:81254ms step_avg:89.59ms
+step:908/1680 train_time:81344ms step_avg:89.59ms
+step:909/1680 train_time:81435ms step_avg:89.59ms
+step:910/1680 train_time:81526ms step_avg:89.59ms
+step:911/1680 train_time:81616ms step_avg:89.59ms
+step:912/1680 train_time:81707ms step_avg:89.59ms
+step:913/1680 train_time:81798ms step_avg:89.59ms
+step:914/1680 train_time:81887ms step_avg:89.59ms
+step:915/1680 train_time:81978ms step_avg:89.59ms
+step:916/1680 train_time:82067ms step_avg:89.59ms
+step:917/1680 train_time:82157ms step_avg:89.59ms
+step:918/1680 train_time:82247ms step_avg:89.59ms
+step:919/1680 train_time:82338ms step_avg:89.60ms
+step:920/1680 train_time:82428ms step_avg:89.60ms
+step:921/1680 train_time:82518ms step_avg:89.60ms
+step:922/1680 train_time:82610ms step_avg:89.60ms
+step:923/1680 train_time:82700ms step_avg:89.60ms
+step:924/1680 train_time:82790ms step_avg:89.60ms
+step:925/1680 train_time:82880ms step_avg:89.60ms
+step:926/1680 train_time:82971ms step_avg:89.60ms
+step:927/1680 train_time:83060ms step_avg:89.60ms
+step:928/1680 train_time:83152ms step_avg:89.60ms
+step:929/1680 train_time:83241ms step_avg:89.60ms
+step:930/1680 train_time:83332ms step_avg:89.60ms
+step:931/1680 train_time:83421ms step_avg:89.60ms
+step:932/1680 train_time:83512ms step_avg:89.61ms
+step:933/1680 train_time:83602ms step_avg:89.61ms
+step:934/1680 train_time:83692ms step_avg:89.61ms
+step:935/1680 train_time:83783ms step_avg:89.61ms
+step:936/1680 train_time:83873ms step_avg:89.61ms
+step:937/1680 train_time:83962ms step_avg:89.61ms
+step:938/1680 train_time:84052ms step_avg:89.61ms
+step:939/1680 train_time:84142ms step_avg:89.61ms
+step:940/1680 train_time:84233ms step_avg:89.61ms
+step:941/1680 train_time:84323ms step_avg:89.61ms
+step:942/1680 train_time:84414ms step_avg:89.61ms
+step:943/1680 train_time:84504ms step_avg:89.61ms
+step:944/1680 train_time:84593ms step_avg:89.61ms
+step:945/1680 train_time:84685ms step_avg:89.61ms
+step:946/1680 train_time:84775ms step_avg:89.61ms
+step:947/1680 train_time:84865ms step_avg:89.61ms
+step:948/1680 train_time:84956ms step_avg:89.62ms
+step:949/1680 train_time:85046ms step_avg:89.62ms
+step:950/1680 train_time:85138ms step_avg:89.62ms
+step:951/1680 train_time:85228ms step_avg:89.62ms
+step:952/1680 train_time:85318ms step_avg:89.62ms
+step:953/1680 train_time:85408ms step_avg:89.62ms
+step:954/1680 train_time:85499ms step_avg:89.62ms
+step:955/1680 train_time:85590ms step_avg:89.62ms
+step:956/1680 train_time:85679ms step_avg:89.62ms
+step:957/1680 train_time:85770ms step_avg:89.62ms
+step:958/1680 train_time:85860ms step_avg:89.62ms
+step:959/1680 train_time:85950ms step_avg:89.62ms
+step:960/1680 train_time:86040ms step_avg:89.63ms
+step:961/1680 train_time:86131ms step_avg:89.63ms
+step:962/1680 train_time:86222ms step_avg:89.63ms
+step:963/1680 train_time:86311ms step_avg:89.63ms
+step:964/1680 train_time:86401ms step_avg:89.63ms
+step:965/1680 train_time:86491ms step_avg:89.63ms
+step:966/1680 train_time:86581ms step_avg:89.63ms
+step:967/1680 train_time:86671ms step_avg:89.63ms
+step:968/1680 train_time:86761ms step_avg:89.63ms
+step:969/1680 train_time:86852ms step_avg:89.63ms
+step:970/1680 train_time:86942ms step_avg:89.63ms
+step:971/1680 train_time:87033ms step_avg:89.63ms
+step:972/1680 train_time:87122ms step_avg:89.63ms
+step:973/1680 train_time:87213ms step_avg:89.63ms
+step:974/1680 train_time:87303ms step_avg:89.63ms
+step:975/1680 train_time:87392ms step_avg:89.63ms
+step:976/1680 train_time:87482ms step_avg:89.63ms
+step:977/1680 train_time:87572ms step_avg:89.63ms
+step:978/1680 train_time:87662ms step_avg:89.63ms
+step:979/1680 train_time:87752ms step_avg:89.63ms
+step:980/1680 train_time:87842ms step_avg:89.64ms
+step:981/1680 train_time:87933ms step_avg:89.64ms
+step:982/1680 train_time:88023ms step_avg:89.64ms
+step:983/1680 train_time:88114ms step_avg:89.64ms
+step:984/1680 train_time:88204ms step_avg:89.64ms
+step:985/1680 train_time:88294ms step_avg:89.64ms
+step:986/1680 train_time:88385ms step_avg:89.64ms
+step:987/1680 train_time:88475ms step_avg:89.64ms
+step:988/1680 train_time:88565ms step_avg:89.64ms
+step:989/1680 train_time:88655ms step_avg:89.64ms
+step:990/1680 train_time:88745ms step_avg:89.64ms
+step:991/1680 train_time:88836ms step_avg:89.64ms
+step:992/1680 train_time:88926ms step_avg:89.64ms
+step:993/1680 train_time:89016ms step_avg:89.64ms
+step:994/1680 train_time:89106ms step_avg:89.64ms
+step:995/1680 train_time:89195ms step_avg:89.64ms
+step:996/1680 train_time:89286ms step_avg:89.64ms
+step:997/1680 train_time:89377ms step_avg:89.65ms
+step:998/1680 train_time:89467ms step_avg:89.65ms
+step:999/1680 train_time:89558ms step_avg:89.65ms
+step:1000/1680 train_time:89648ms step_avg:89.65ms
+step:1000/1680 val_loss:3.4690 train_time:89741ms step_avg:89.74ms
+step:1001/1680 train_time:89764ms step_avg:89.67ms
+step:1002/1680 train_time:89837ms step_avg:89.66ms
+step:1003/1680 train_time:89935ms step_avg:89.67ms
+step:1004/1680 train_time:90027ms step_avg:89.67ms
+step:1005/1680 train_time:90118ms step_avg:89.67ms
+step:1006/1680 train_time:90207ms step_avg:89.67ms
+step:1007/1680 train_time:90296ms step_avg:89.67ms
+step:1008/1680 train_time:90386ms step_avg:89.67ms
+step:1009/1680 train_time:90475ms step_avg:89.67ms
+step:1010/1680 train_time:90564ms step_avg:89.67ms
+step:1011/1680 train_time:90653ms step_avg:89.67ms
+step:1012/1680 train_time:90744ms step_avg:89.67ms
+step:1013/1680 train_time:90836ms step_avg:89.67ms
+step:1014/1680 train_time:90928ms step_avg:89.67ms
+step:1015/1680 train_time:91020ms step_avg:89.67ms
+step:1016/1680 train_time:91111ms step_avg:89.68ms
+step:1017/1680 train_time:91201ms step_avg:89.68ms
+step:1018/1680 train_time:91291ms step_avg:89.68ms
+step:1019/1680 train_time:91381ms step_avg:89.68ms
+step:1020/1680 train_time:91470ms step_avg:89.68ms
+step:1021/1680 train_time:91559ms step_avg:89.68ms
+step:1022/1680 train_time:91649ms step_avg:89.68ms
+step:1023/1680 train_time:91738ms step_avg:89.68ms
+step:1024/1680 train_time:91829ms step_avg:89.68ms
+step:1025/1680 train_time:91921ms step_avg:89.68ms
+step:1026/1680 train_time:92012ms step_avg:89.68ms
+step:1027/1680 train_time:92103ms step_avg:89.68ms
+step:1028/1680 train_time:92193ms step_avg:89.68ms
+step:1029/1680 train_time:92283ms step_avg:89.68ms
+step:1030/1680 train_time:92372ms step_avg:89.68ms
+step:1031/1680 train_time:92462ms step_avg:89.68ms
+step:1032/1680 train_time:92551ms step_avg:89.68ms
+step:1033/1680 train_time:92640ms step_avg:89.68ms
+step:1034/1680 train_time:92731ms step_avg:89.68ms
+step:1035/1680 train_time:92821ms step_avg:89.68ms
+step:1036/1680 train_time:92913ms step_avg:89.68ms
+step:1037/1680 train_time:93002ms step_avg:89.68ms
+step:1038/1680 train_time:93093ms step_avg:89.68ms
+step:1039/1680 train_time:93183ms step_avg:89.69ms
+step:1040/1680 train_time:93272ms step_avg:89.68ms
+step:1041/1680 train_time:93362ms step_avg:89.68ms
+step:1042/1680 train_time:93452ms step_avg:89.68ms
+step:1043/1680 train_time:93542ms step_avg:89.69ms
+step:1044/1680 train_time:93631ms step_avg:89.69ms
+step:1045/1680 train_time:93721ms step_avg:89.68ms
+step:1046/1680 train_time:93811ms step_avg:89.69ms
+step:1047/1680 train_time:93902ms step_avg:89.69ms
+step:1048/1680 train_time:93992ms step_avg:89.69ms
+step:1049/1680 train_time:94084ms step_avg:89.69ms
+step:1050/1680 train_time:94173ms step_avg:89.69ms
+step:1051/1680 train_time:94263ms step_avg:89.69ms
+step:1052/1680 train_time:94353ms step_avg:89.69ms
+step:1053/1680 train_time:94442ms step_avg:89.69ms
+step:1054/1680 train_time:94532ms step_avg:89.69ms
+step:1055/1680 train_time:94621ms step_avg:89.69ms
+step:1056/1680 train_time:94711ms step_avg:89.69ms
+step:1057/1680 train_time:94802ms step_avg:89.69ms
+step:1058/1680 train_time:94892ms step_avg:89.69ms
+step:1059/1680 train_time:94983ms step_avg:89.69ms
+step:1060/1680 train_time:95074ms step_avg:89.69ms
+step:1061/1680 train_time:95164ms step_avg:89.69ms
+step:1062/1680 train_time:95254ms step_avg:89.69ms
+step:1063/1680 train_time:95343ms step_avg:89.69ms
+step:1064/1680 train_time:95434ms step_avg:89.69ms
+step:1065/1680 train_time:95523ms step_avg:89.69ms
+step:1066/1680 train_time:95613ms step_avg:89.69ms
+step:1067/1680 train_time:95703ms step_avg:89.69ms
+step:1068/1680 train_time:95794ms step_avg:89.69ms
+step:1069/1680 train_time:95886ms step_avg:89.70ms
+step:1070/1680 train_time:95975ms step_avg:89.70ms
+step:1071/1680 train_time:96065ms step_avg:89.70ms
+step:1072/1680 train_time:96157ms step_avg:89.70ms
+step:1073/1680 train_time:96247ms step_avg:89.70ms
+step:1074/1680 train_time:96336ms step_avg:89.70ms
+step:1075/1680 train_time:96427ms step_avg:89.70ms
+step:1076/1680 train_time:96516ms step_avg:89.70ms
+step:1077/1680 train_time:96606ms step_avg:89.70ms
+step:1078/1680 train_time:96697ms step_avg:89.70ms
+step:1079/1680 train_time:96788ms step_avg:89.70ms
+step:1080/1680 train_time:96879ms step_avg:89.70ms
+step:1081/1680 train_time:96970ms step_avg:89.70ms
+step:1082/1680 train_time:97062ms step_avg:89.71ms
+step:1083/1680 train_time:97152ms step_avg:89.71ms
+step:1084/1680 train_time:97241ms step_avg:89.71ms
+step:1085/1680 train_time:97331ms step_avg:89.71ms
+step:1086/1680 train_time:97421ms step_avg:89.71ms
+step:1087/1680 train_time:97511ms step_avg:89.71ms
+step:1088/1680 train_time:97601ms step_avg:89.71ms
+step:1089/1680 train_time:97690ms step_avg:89.71ms
+step:1090/1680 train_time:97781ms step_avg:89.71ms
+step:1091/1680 train_time:97871ms step_avg:89.71ms
+step:1092/1680 train_time:97961ms step_avg:89.71ms
+step:1093/1680 train_time:98051ms step_avg:89.71ms
+step:1094/1680 train_time:98141ms step_avg:89.71ms
+step:1095/1680 train_time:98232ms step_avg:89.71ms
+step:1096/1680 train_time:98323ms step_avg:89.71ms
+step:1097/1680 train_time:98413ms step_avg:89.71ms
+step:1098/1680 train_time:98505ms step_avg:89.71ms
+step:1099/1680 train_time:98596ms step_avg:89.71ms
+step:1100/1680 train_time:98688ms step_avg:89.72ms
+step:1101/1680 train_time:98781ms step_avg:89.72ms
+step:1102/1680 train_time:98870ms step_avg:89.72ms
+step:1103/1680 train_time:98962ms step_avg:89.72ms
+step:1104/1680 train_time:99052ms step_avg:89.72ms
+step:1105/1680 train_time:99143ms step_avg:89.72ms
+step:1106/1680 train_time:99234ms step_avg:89.72ms
+step:1107/1680 train_time:99326ms step_avg:89.73ms
+step:1108/1680 train_time:99417ms step_avg:89.73ms
+step:1109/1680 train_time:99508ms step_avg:89.73ms
+step:1110/1680 train_time:99599ms step_avg:89.73ms
+step:1111/1680 train_time:99690ms step_avg:89.73ms
+step:1112/1680 train_time:99782ms step_avg:89.73ms
+step:1113/1680 train_time:99872ms step_avg:89.73ms
+step:1114/1680 train_time:99962ms step_avg:89.73ms
+step:1115/1680 train_time:100053ms step_avg:89.73ms
+step:1116/1680 train_time:100143ms step_avg:89.73ms
+step:1117/1680 train_time:100234ms step_avg:89.73ms
+step:1118/1680 train_time:100325ms step_avg:89.74ms
+step:1119/1680 train_time:100416ms step_avg:89.74ms
+step:1120/1680 train_time:100507ms step_avg:89.74ms
+step:1121/1680 train_time:100598ms step_avg:89.74ms
+step:1122/1680 train_time:100689ms step_avg:89.74ms
+step:1123/1680 train_time:100781ms step_avg:89.74ms
+step:1124/1680 train_time:100871ms step_avg:89.74ms
+step:1125/1680 train_time:100963ms step_avg:89.74ms
+step:1125/1680 val_loss:3.4153 train_time:101055ms step_avg:89.83ms
+step:1126/1680 train_time:101077ms step_avg:89.77ms
+step:1127/1680 train_time:101147ms step_avg:89.75ms
+step:1128/1680 train_time:101249ms step_avg:89.76ms
+step:1129/1680 train_time:101344ms step_avg:89.76ms
+step:1130/1680 train_time:101434ms step_avg:89.76ms
+step:1131/1680 train_time:101523ms step_avg:89.76ms
+step:1132/1680 train_time:101613ms step_avg:89.76ms
+step:1133/1680 train_time:101703ms step_avg:89.76ms
+step:1134/1680 train_time:101793ms step_avg:89.76ms
+step:1135/1680 train_time:101882ms step_avg:89.76ms
+step:1136/1680 train_time:101972ms step_avg:89.76ms
+step:1137/1680 train_time:102064ms step_avg:89.77ms
+step:1138/1680 train_time:102157ms step_avg:89.77ms
+step:1139/1680 train_time:102252ms step_avg:89.77ms
+step:1140/1680 train_time:102344ms step_avg:89.78ms
+step:1141/1680 train_time:102434ms step_avg:89.78ms
+step:1142/1680 train_time:102524ms step_avg:89.78ms
+step:1143/1680 train_time:102614ms step_avg:89.78ms
+step:1144/1680 train_time:102704ms step_avg:89.78ms
+step:1145/1680 train_time:102793ms step_avg:89.78ms
+step:1146/1680 train_time:102883ms step_avg:89.78ms
+step:1147/1680 train_time:102973ms step_avg:89.78ms
+step:1148/1680 train_time:103065ms step_avg:89.78ms
+step:1149/1680 train_time:103159ms step_avg:89.78ms
+step:1150/1680 train_time:103251ms step_avg:89.78ms
+step:1151/1680 train_time:103342ms step_avg:89.78ms
+step:1152/1680 train_time:103433ms step_avg:89.79ms
+step:1153/1680 train_time:103524ms step_avg:89.79ms
+step:1154/1680 train_time:103615ms step_avg:89.79ms
+step:1155/1680 train_time:103705ms step_avg:89.79ms
+step:1156/1680 train_time:103796ms step_avg:89.79ms
+step:1157/1680 train_time:103885ms step_avg:89.79ms
+step:1158/1680 train_time:103975ms step_avg:89.79ms
+step:1159/1680 train_time:104066ms step_avg:89.79ms
+step:1160/1680 train_time:104159ms step_avg:89.79ms
+step:1161/1680 train_time:104251ms step_avg:89.79ms
+step:1162/1680 train_time:104342ms step_avg:89.80ms
+step:1163/1680 train_time:104433ms step_avg:89.80ms
+step:1164/1680 train_time:104525ms step_avg:89.80ms
+step:1165/1680 train_time:104616ms step_avg:89.80ms
+step:1166/1680 train_time:104707ms step_avg:89.80ms
+step:1167/1680 train_time:104797ms step_avg:89.80ms
+step:1168/1680 train_time:104887ms step_avg:89.80ms
+step:1169/1680 train_time:104979ms step_avg:89.80ms
+step:1170/1680 train_time:105068ms step_avg:89.80ms
+step:1171/1680 train_time:105160ms step_avg:89.80ms
+step:1172/1680 train_time:105251ms step_avg:89.80ms
+step:1173/1680 train_time:105343ms step_avg:89.81ms
+step:1174/1680 train_time:105433ms step_avg:89.81ms
+step:1175/1680 train_time:105525ms step_avg:89.81ms
+step:1176/1680 train_time:105615ms step_avg:89.81ms
+step:1177/1680 train_time:105707ms step_avg:89.81ms
+step:1178/1680 train_time:105797ms step_avg:89.81ms
+step:1179/1680 train_time:105888ms step_avg:89.81ms
+step:1180/1680 train_time:105978ms step_avg:89.81ms
+step:1181/1680 train_time:106069ms step_avg:89.81ms
+step:1182/1680 train_time:106160ms step_avg:89.81ms
+step:1183/1680 train_time:106251ms step_avg:89.81ms
+step:1184/1680 train_time:106342ms step_avg:89.82ms
+step:1185/1680 train_time:106434ms step_avg:89.82ms
+step:1186/1680 train_time:106525ms step_avg:89.82ms
+step:1187/1680 train_time:106615ms step_avg:89.82ms
+step:1188/1680 train_time:106706ms step_avg:89.82ms
+step:1189/1680 train_time:106797ms step_avg:89.82ms
+step:1190/1680 train_time:106888ms step_avg:89.82ms
+step:1191/1680 train_time:106979ms step_avg:89.82ms
+step:1192/1680 train_time:107069ms step_avg:89.82ms
+step:1193/1680 train_time:107160ms step_avg:89.82ms
+step:1194/1680 train_time:107251ms step_avg:89.82ms
+step:1195/1680 train_time:107341ms step_avg:89.82ms
+step:1196/1680 train_time:107431ms step_avg:89.83ms
+step:1197/1680 train_time:107522ms step_avg:89.83ms
+step:1198/1680 train_time:107614ms step_avg:89.83ms
+step:1199/1680 train_time:107705ms step_avg:89.83ms
+step:1200/1680 train_time:107798ms step_avg:89.83ms
+step:1201/1680 train_time:107887ms step_avg:89.83ms
+step:1202/1680 train_time:107978ms step_avg:89.83ms
+step:1203/1680 train_time:108068ms step_avg:89.83ms
+step:1204/1680 train_time:108158ms step_avg:89.83ms
+step:1205/1680 train_time:108249ms step_avg:89.83ms
+step:1206/1680 train_time:108340ms step_avg:89.83ms
+step:1207/1680 train_time:108431ms step_avg:89.83ms
+step:1208/1680 train_time:108522ms step_avg:89.84ms
+step:1209/1680 train_time:108613ms step_avg:89.84ms
+step:1210/1680 train_time:108705ms step_avg:89.84ms
+step:1211/1680 train_time:108795ms step_avg:89.84ms
+step:1212/1680 train_time:108886ms step_avg:89.84ms
+step:1213/1680 train_time:108977ms step_avg:89.84ms
+step:1214/1680 train_time:109068ms step_avg:89.84ms
+step:1215/1680 train_time:109159ms step_avg:89.84ms
+step:1216/1680 train_time:109249ms step_avg:89.84ms
+step:1217/1680 train_time:109340ms step_avg:89.84ms
+step:1218/1680 train_time:109430ms step_avg:89.84ms
+step:1219/1680 train_time:109521ms step_avg:89.85ms
+step:1220/1680 train_time:109613ms step_avg:89.85ms
+step:1221/1680 train_time:109704ms step_avg:89.85ms
+step:1222/1680 train_time:109794ms step_avg:89.85ms
+step:1223/1680 train_time:109886ms step_avg:89.85ms
+step:1224/1680 train_time:109977ms step_avg:89.85ms
+step:1225/1680 train_time:110067ms step_avg:89.85ms
+step:1226/1680 train_time:110158ms step_avg:89.85ms
+step:1227/1680 train_time:110248ms step_avg:89.85ms
+step:1228/1680 train_time:110338ms step_avg:89.85ms
+step:1229/1680 train_time:110429ms step_avg:89.85ms
+step:1230/1680 train_time:110519ms step_avg:89.85ms
+step:1231/1680 train_time:110611ms step_avg:89.85ms
+step:1232/1680 train_time:110702ms step_avg:89.86ms
+step:1233/1680 train_time:110793ms step_avg:89.86ms
+step:1234/1680 train_time:110884ms step_avg:89.86ms
+step:1235/1680 train_time:110974ms step_avg:89.86ms
+step:1236/1680 train_time:111066ms step_avg:89.86ms
+step:1237/1680 train_time:111157ms step_avg:89.86ms
+step:1238/1680 train_time:111248ms step_avg:89.86ms
+step:1239/1680 train_time:111339ms step_avg:89.86ms
+step:1240/1680 train_time:111430ms step_avg:89.86ms
+step:1241/1680 train_time:111521ms step_avg:89.86ms
+step:1242/1680 train_time:111612ms step_avg:89.86ms
+step:1243/1680 train_time:111703ms step_avg:89.87ms
+step:1244/1680 train_time:111793ms step_avg:89.87ms
+step:1245/1680 train_time:111885ms step_avg:89.87ms
+step:1246/1680 train_time:111976ms step_avg:89.87ms
+step:1247/1680 train_time:112067ms step_avg:89.87ms
+step:1248/1680 train_time:112158ms step_avg:89.87ms
+step:1249/1680 train_time:112247ms step_avg:89.87ms
+step:1250/1680 train_time:112339ms step_avg:89.87ms
+step:1250/1680 val_loss:3.3766 train_time:112431ms step_avg:89.94ms
+step:1251/1680 train_time:112454ms step_avg:89.89ms
+step:1252/1680 train_time:112525ms step_avg:89.88ms
+step:1253/1680 train_time:112622ms step_avg:89.88ms
+step:1254/1680 train_time:112713ms step_avg:89.88ms
+step:1255/1680 train_time:112804ms step_avg:89.88ms
+step:1256/1680 train_time:112894ms step_avg:89.88ms
+step:1257/1680 train_time:112984ms step_avg:89.88ms
+step:1258/1680 train_time:113073ms step_avg:89.88ms
+step:1259/1680 train_time:113163ms step_avg:89.88ms
+step:1260/1680 train_time:113253ms step_avg:89.88ms
+step:1261/1680 train_time:113343ms step_avg:89.88ms
+step:1262/1680 train_time:113434ms step_avg:89.88ms
+step:1263/1680 train_time:113527ms step_avg:89.89ms
+step:1264/1680 train_time:113621ms step_avg:89.89ms
+step:1265/1680 train_time:113712ms step_avg:89.89ms
+step:1266/1680 train_time:113803ms step_avg:89.89ms
+step:1267/1680 train_time:113894ms step_avg:89.89ms
+step:1268/1680 train_time:113984ms step_avg:89.89ms
+step:1269/1680 train_time:114074ms step_avg:89.89ms
+step:1270/1680 train_time:114164ms step_avg:89.89ms
+step:1271/1680 train_time:114254ms step_avg:89.89ms
+step:1272/1680 train_time:114345ms step_avg:89.89ms
+step:1273/1680 train_time:114437ms step_avg:89.90ms
+step:1274/1680 train_time:114529ms step_avg:89.90ms
+step:1275/1680 train_time:114623ms step_avg:89.90ms
+step:1276/1680 train_time:114714ms step_avg:89.90ms
+step:1277/1680 train_time:114805ms step_avg:89.90ms
+step:1278/1680 train_time:114895ms step_avg:89.90ms
+step:1279/1680 train_time:114986ms step_avg:89.90ms
+step:1280/1680 train_time:115076ms step_avg:89.90ms
+step:1281/1680 train_time:115166ms step_avg:89.90ms
+step:1282/1680 train_time:115256ms step_avg:89.90ms
+step:1283/1680 train_time:115347ms step_avg:89.90ms
+step:1284/1680 train_time:115439ms step_avg:89.91ms
+step:1285/1680 train_time:115531ms step_avg:89.91ms
+step:1286/1680 train_time:115623ms step_avg:89.91ms
+step:1287/1680 train_time:115715ms step_avg:89.91ms
+step:1288/1680 train_time:115808ms step_avg:89.91ms
+step:1289/1680 train_time:115899ms step_avg:89.91ms
+step:1290/1680 train_time:115989ms step_avg:89.91ms
+step:1291/1680 train_time:116079ms step_avg:89.91ms
+step:1292/1680 train_time:116169ms step_avg:89.91ms
+step:1293/1680 train_time:116258ms step_avg:89.91ms
+step:1294/1680 train_time:116350ms step_avg:89.91ms
+step:1295/1680 train_time:116440ms step_avg:89.92ms
+step:1296/1680 train_time:116534ms step_avg:89.92ms
+step:1297/1680 train_time:116623ms step_avg:89.92ms
+step:1298/1680 train_time:116714ms step_avg:89.92ms
+step:1299/1680 train_time:116807ms step_avg:89.92ms
+step:1300/1680 train_time:116897ms step_avg:89.92ms
+step:1301/1680 train_time:116987ms step_avg:89.92ms
+step:1302/1680 train_time:117078ms step_avg:89.92ms
+step:1303/1680 train_time:117168ms step_avg:89.92ms
+step:1304/1680 train_time:117259ms step_avg:89.92ms
+step:1305/1680 train_time:117350ms step_avg:89.92ms
+step:1306/1680 train_time:117441ms step_avg:89.92ms
+step:1307/1680 train_time:117532ms step_avg:89.92ms
+step:1308/1680 train_time:117623ms step_avg:89.93ms
+step:1309/1680 train_time:117714ms step_avg:89.93ms
+step:1310/1680 train_time:117806ms step_avg:89.93ms
+step:1311/1680 train_time:117897ms step_avg:89.93ms
+step:1312/1680 train_time:117988ms step_avg:89.93ms
+step:1313/1680 train_time:118078ms step_avg:89.93ms
+step:1314/1680 train_time:118169ms step_avg:89.93ms
+step:1315/1680 train_time:118258ms step_avg:89.93ms
+step:1316/1680 train_time:118349ms step_avg:89.93ms
+step:1317/1680 train_time:118439ms step_avg:89.93ms
+step:1318/1680 train_time:118530ms step_avg:89.93ms
+step:1319/1680 train_time:118622ms step_avg:89.93ms
+step:1320/1680 train_time:118713ms step_avg:89.93ms
+step:1321/1680 train_time:118804ms step_avg:89.94ms
+step:1322/1680 train_time:118895ms step_avg:89.94ms
+step:1323/1680 train_time:118986ms step_avg:89.94ms
+step:1324/1680 train_time:119077ms step_avg:89.94ms
+step:1325/1680 train_time:119168ms step_avg:89.94ms
+step:1326/1680 train_time:119259ms step_avg:89.94ms
+step:1327/1680 train_time:119349ms step_avg:89.94ms
+step:1328/1680 train_time:119440ms step_avg:89.94ms
+step:1329/1680 train_time:119531ms step_avg:89.94ms
+step:1330/1680 train_time:119622ms step_avg:89.94ms
+step:1331/1680 train_time:119714ms step_avg:89.94ms
+step:1332/1680 train_time:119805ms step_avg:89.94ms
+step:1333/1680 train_time:119895ms step_avg:89.94ms
+step:1334/1680 train_time:119986ms step_avg:89.94ms
+step:1335/1680 train_time:120077ms step_avg:89.95ms
+step:1336/1680 train_time:120168ms step_avg:89.95ms
+step:1337/1680 train_time:120258ms step_avg:89.95ms
+step:1338/1680 train_time:120349ms step_avg:89.95ms
+step:1339/1680 train_time:120440ms step_avg:89.95ms
+step:1340/1680 train_time:120530ms step_avg:89.95ms
+step:1341/1680 train_time:120621ms step_avg:89.95ms
+step:1342/1680 train_time:120712ms step_avg:89.95ms
+step:1343/1680 train_time:120803ms step_avg:89.95ms
+step:1344/1680 train_time:120894ms step_avg:89.95ms
+step:1345/1680 train_time:120985ms step_avg:89.95ms
+step:1346/1680 train_time:121075ms step_avg:89.95ms
+step:1347/1680 train_time:121167ms step_avg:89.95ms
+step:1348/1680 train_time:121257ms step_avg:89.95ms
+step:1349/1680 train_time:121348ms step_avg:89.95ms
+step:1350/1680 train_time:121439ms step_avg:89.95ms
+step:1351/1680 train_time:121529ms step_avg:89.95ms
+step:1352/1680 train_time:121620ms step_avg:89.96ms
+step:1353/1680 train_time:121712ms step_avg:89.96ms
+step:1354/1680 train_time:121803ms step_avg:89.96ms
+step:1355/1680 train_time:121894ms step_avg:89.96ms
+step:1356/1680 train_time:121985ms step_avg:89.96ms
+step:1357/1680 train_time:122076ms step_avg:89.96ms
+step:1358/1680 train_time:122167ms step_avg:89.96ms
+step:1359/1680 train_time:122257ms step_avg:89.96ms
+step:1360/1680 train_time:122348ms step_avg:89.96ms
+step:1361/1680 train_time:122440ms step_avg:89.96ms
+step:1362/1680 train_time:122529ms step_avg:89.96ms
+step:1363/1680 train_time:122621ms step_avg:89.96ms
+step:1364/1680 train_time:122712ms step_avg:89.97ms
+step:1365/1680 train_time:122804ms step_avg:89.97ms
+step:1366/1680 train_time:122894ms step_avg:89.97ms
+step:1367/1680 train_time:122986ms step_avg:89.97ms
+step:1368/1680 train_time:123076ms step_avg:89.97ms
+step:1369/1680 train_time:123167ms step_avg:89.97ms
+step:1370/1680 train_time:123258ms step_avg:89.97ms
+step:1371/1680 train_time:123348ms step_avg:89.97ms
+step:1372/1680 train_time:123439ms step_avg:89.97ms
+step:1373/1680 train_time:123529ms step_avg:89.97ms
+step:1374/1680 train_time:123621ms step_avg:89.97ms
+step:1375/1680 train_time:123712ms step_avg:89.97ms
+step:1375/1680 val_loss:3.3425 train_time:123805ms step_avg:90.04ms
+step:1376/1680 train_time:123828ms step_avg:89.99ms
+step:1377/1680 train_time:123902ms step_avg:89.98ms
+step:1378/1680 train_time:124000ms step_avg:89.99ms
+step:1379/1680 train_time:124091ms step_avg:89.99ms
+step:1380/1680 train_time:124181ms step_avg:89.99ms
+step:1381/1680 train_time:124271ms step_avg:89.99ms
+step:1382/1680 train_time:124360ms step_avg:89.99ms
+step:1383/1680 train_time:124450ms step_avg:89.99ms
+step:1384/1680 train_time:124540ms step_avg:89.99ms
+step:1385/1680 train_time:124629ms step_avg:89.99ms
+step:1386/1680 train_time:124720ms step_avg:89.99ms
+step:1387/1680 train_time:124812ms step_avg:89.99ms
+step:1388/1680 train_time:124906ms step_avg:89.99ms
+step:1389/1680 train_time:125000ms step_avg:89.99ms
+step:1390/1680 train_time:125092ms step_avg:89.99ms
+step:1391/1680 train_time:125183ms step_avg:90.00ms
+step:1392/1680 train_time:125273ms step_avg:90.00ms
+step:1393/1680 train_time:125363ms step_avg:89.99ms
+step:1394/1680 train_time:125452ms step_avg:89.99ms
+step:1395/1680 train_time:125542ms step_avg:89.99ms
+step:1396/1680 train_time:125632ms step_avg:89.99ms
+step:1397/1680 train_time:125722ms step_avg:89.99ms
+step:1398/1680 train_time:125813ms step_avg:90.00ms
+step:1399/1680 train_time:125906ms step_avg:90.00ms
+step:1400/1680 train_time:125997ms step_avg:90.00ms
+step:1401/1680 train_time:126090ms step_avg:90.00ms
+step:1402/1680 train_time:126181ms step_avg:90.00ms
+step:1403/1680 train_time:126272ms step_avg:90.00ms
+step:1404/1680 train_time:126362ms step_avg:90.00ms
+step:1405/1680 train_time:126452ms step_avg:90.00ms
+step:1406/1680 train_time:126542ms step_avg:90.00ms
+step:1407/1680 train_time:126632ms step_avg:90.00ms
+step:1408/1680 train_time:126723ms step_avg:90.00ms
+step:1409/1680 train_time:126814ms step_avg:90.00ms
+step:1410/1680 train_time:126905ms step_avg:90.00ms
+step:1411/1680 train_time:126997ms step_avg:90.00ms
+step:1412/1680 train_time:127088ms step_avg:90.01ms
+step:1413/1680 train_time:127180ms step_avg:90.01ms
+step:1414/1680 train_time:127272ms step_avg:90.01ms
+step:1415/1680 train_time:127362ms step_avg:90.01ms
+step:1416/1680 train_time:127452ms step_avg:90.01ms
+step:1417/1680 train_time:127542ms step_avg:90.01ms
+step:1418/1680 train_time:127632ms step_avg:90.01ms
+step:1419/1680 train_time:127722ms step_avg:90.01ms
+step:1420/1680 train_time:127813ms step_avg:90.01ms
+step:1421/1680 train_time:127904ms step_avg:90.01ms
+step:1422/1680 train_time:127996ms step_avg:90.01ms
+step:1423/1680 train_time:128088ms step_avg:90.01ms
+step:1424/1680 train_time:128180ms step_avg:90.01ms
+step:1425/1680 train_time:128271ms step_avg:90.02ms
+step:1426/1680 train_time:128363ms step_avg:90.02ms
+step:1427/1680 train_time:128453ms step_avg:90.02ms
+step:1428/1680 train_time:128544ms step_avg:90.02ms
+step:1429/1680 train_time:128634ms step_avg:90.02ms
+step:1430/1680 train_time:128725ms step_avg:90.02ms
+step:1431/1680 train_time:128816ms step_avg:90.02ms
+step:1432/1680 train_time:128908ms step_avg:90.02ms
+step:1433/1680 train_time:128999ms step_avg:90.02ms
+step:1434/1680 train_time:129091ms step_avg:90.02ms
+step:1435/1680 train_time:129181ms step_avg:90.02ms
+step:1436/1680 train_time:129273ms step_avg:90.02ms
+step:1437/1680 train_time:129363ms step_avg:90.02ms
+step:1438/1680 train_time:129453ms step_avg:90.02ms
+step:1439/1680 train_time:129543ms step_avg:90.02ms
+step:1440/1680 train_time:129633ms step_avg:90.02ms
+step:1441/1680 train_time:129724ms step_avg:90.02ms
+step:1442/1680 train_time:129816ms step_avg:90.02ms
+step:1443/1680 train_time:129908ms step_avg:90.03ms
+step:1444/1680 train_time:129998ms step_avg:90.03ms
+step:1445/1680 train_time:130090ms step_avg:90.03ms
+step:1446/1680 train_time:130183ms step_avg:90.03ms
+step:1447/1680 train_time:130273ms step_avg:90.03ms
+step:1448/1680 train_time:130364ms step_avg:90.03ms
+step:1449/1680 train_time:130454ms step_avg:90.03ms
+step:1450/1680 train_time:130545ms step_avg:90.03ms
+step:1451/1680 train_time:130636ms step_avg:90.03ms
+step:1452/1680 train_time:130726ms step_avg:90.03ms
+step:1453/1680 train_time:130816ms step_avg:90.03ms
+step:1454/1680 train_time:130907ms step_avg:90.03ms
+step:1455/1680 train_time:130999ms step_avg:90.03ms
+step:1456/1680 train_time:131091ms step_avg:90.04ms
+step:1457/1680 train_time:131182ms step_avg:90.04ms
+step:1458/1680 train_time:131272ms step_avg:90.04ms
+step:1459/1680 train_time:131363ms step_avg:90.04ms
+step:1460/1680 train_time:131454ms step_avg:90.04ms
+step:1461/1680 train_time:131544ms step_avg:90.04ms
+step:1462/1680 train_time:131635ms step_avg:90.04ms
+step:1463/1680 train_time:131725ms step_avg:90.04ms
+step:1464/1680 train_time:131816ms step_avg:90.04ms
+step:1465/1680 train_time:131907ms step_avg:90.04ms
+step:1466/1680 train_time:131998ms step_avg:90.04ms
+step:1467/1680 train_time:132090ms step_avg:90.04ms
+step:1468/1680 train_time:132182ms step_avg:90.04ms
+step:1469/1680 train_time:132273ms step_avg:90.04ms
+step:1470/1680 train_time:132364ms step_avg:90.04ms
+step:1471/1680 train_time:132454ms step_avg:90.04ms
+step:1472/1680 train_time:132544ms step_avg:90.04ms
+step:1473/1680 train_time:132635ms step_avg:90.04ms
+step:1474/1680 train_time:132726ms step_avg:90.04ms
+step:1475/1680 train_time:132818ms step_avg:90.05ms
+step:1476/1680 train_time:132911ms step_avg:90.05ms
+step:1477/1680 train_time:133001ms step_avg:90.05ms
+step:1478/1680 train_time:133092ms step_avg:90.05ms
+step:1479/1680 train_time:133184ms step_avg:90.05ms
+step:1480/1680 train_time:133275ms step_avg:90.05ms
+step:1481/1680 train_time:133366ms step_avg:90.05ms
+step:1482/1680 train_time:133456ms step_avg:90.05ms
+step:1483/1680 train_time:133548ms step_avg:90.05ms
+step:1484/1680 train_time:133638ms step_avg:90.05ms
+step:1485/1680 train_time:133729ms step_avg:90.05ms
+step:1486/1680 train_time:133820ms step_avg:90.05ms
+step:1487/1680 train_time:133911ms step_avg:90.05ms
+step:1488/1680 train_time:134003ms step_avg:90.06ms
+step:1489/1680 train_time:134094ms step_avg:90.06ms
+step:1490/1680 train_time:134185ms step_avg:90.06ms
+step:1491/1680 train_time:134276ms step_avg:90.06ms
+step:1492/1680 train_time:134367ms step_avg:90.06ms
+step:1493/1680 train_time:134459ms step_avg:90.06ms
+step:1494/1680 train_time:134550ms step_avg:90.06ms
+step:1495/1680 train_time:134641ms step_avg:90.06ms
+step:1496/1680 train_time:134731ms step_avg:90.06ms
+step:1497/1680 train_time:134822ms step_avg:90.06ms
+step:1498/1680 train_time:134912ms step_avg:90.06ms
+step:1499/1680 train_time:135003ms step_avg:90.06ms
+step:1500/1680 train_time:135095ms step_avg:90.06ms
+step:1500/1680 val_loss:3.3127 train_time:135186ms step_avg:90.12ms
+step:1501/1680 train_time:135209ms step_avg:90.08ms
+step:1502/1680 train_time:135282ms step_avg:90.07ms
+step:1503/1680 train_time:135376ms step_avg:90.07ms
+step:1504/1680 train_time:135467ms step_avg:90.07ms
+step:1505/1680 train_time:135557ms step_avg:90.07ms
+step:1506/1680 train_time:135646ms step_avg:90.07ms
+step:1507/1680 train_time:135736ms step_avg:90.07ms
+step:1508/1680 train_time:135826ms step_avg:90.07ms
+step:1509/1680 train_time:135915ms step_avg:90.07ms
+step:1510/1680 train_time:136005ms step_avg:90.07ms
+step:1511/1680 train_time:136095ms step_avg:90.07ms
+step:1512/1680 train_time:136188ms step_avg:90.07ms
+step:1513/1680 train_time:136281ms step_avg:90.07ms
+step:1514/1680 train_time:136375ms step_avg:90.08ms
+step:1515/1680 train_time:136469ms step_avg:90.08ms
+step:1516/1680 train_time:136559ms step_avg:90.08ms
+step:1517/1680 train_time:136649ms step_avg:90.08ms
+step:1518/1680 train_time:136739ms step_avg:90.08ms
+step:1519/1680 train_time:136830ms step_avg:90.08ms
+step:1520/1680 train_time:136920ms step_avg:90.08ms
+step:1521/1680 train_time:137009ms step_avg:90.08ms
+step:1522/1680 train_time:137099ms step_avg:90.08ms
+step:1523/1680 train_time:137191ms step_avg:90.08ms
+step:1524/1680 train_time:137283ms step_avg:90.08ms
+step:1525/1680 train_time:137375ms step_avg:90.08ms
+step:1526/1680 train_time:137466ms step_avg:90.08ms
+step:1527/1680 train_time:137558ms step_avg:90.08ms
+step:1528/1680 train_time:137649ms step_avg:90.08ms
+step:1529/1680 train_time:137740ms step_avg:90.08ms
+step:1530/1680 train_time:137831ms step_avg:90.09ms
+step:1531/1680 train_time:137921ms step_avg:90.09ms
+step:1532/1680 train_time:138012ms step_avg:90.09ms
+step:1533/1680 train_time:138103ms step_avg:90.09ms
+step:1534/1680 train_time:138194ms step_avg:90.09ms
+step:1535/1680 train_time:138285ms step_avg:90.09ms
+step:1536/1680 train_time:138377ms step_avg:90.09ms
+step:1537/1680 train_time:138469ms step_avg:90.09ms
+step:1538/1680 train_time:138560ms step_avg:90.09ms
+step:1539/1680 train_time:138650ms step_avg:90.09ms
+step:1540/1680 train_time:138741ms step_avg:90.09ms
+step:1541/1680 train_time:138831ms step_avg:90.09ms
+step:1542/1680 train_time:138921ms step_avg:90.09ms
+step:1543/1680 train_time:139012ms step_avg:90.09ms
+step:1544/1680 train_time:139102ms step_avg:90.09ms
+step:1545/1680 train_time:139194ms step_avg:90.09ms
+step:1546/1680 train_time:139284ms step_avg:90.09ms
+step:1547/1680 train_time:139375ms step_avg:90.09ms
+step:1548/1680 train_time:139467ms step_avg:90.09ms
+step:1549/1680 train_time:139558ms step_avg:90.10ms
+step:1550/1680 train_time:139648ms step_avg:90.10ms
+step:1551/1680 train_time:139739ms step_avg:90.10ms
+step:1552/1680 train_time:139830ms step_avg:90.10ms
+step:1553/1680 train_time:139921ms step_avg:90.10ms
+step:1554/1680 train_time:140012ms step_avg:90.10ms
+step:1555/1680 train_time:140102ms step_avg:90.10ms
+step:1556/1680 train_time:140194ms step_avg:90.10ms
+step:1557/1680 train_time:140285ms step_avg:90.10ms
+step:1558/1680 train_time:140376ms step_avg:90.10ms
+step:1559/1680 train_time:140467ms step_avg:90.10ms
+step:1560/1680 train_time:140559ms step_avg:90.10ms
+step:1561/1680 train_time:140651ms step_avg:90.10ms
+step:1562/1680 train_time:140741ms step_avg:90.10ms
+step:1563/1680 train_time:140832ms step_avg:90.10ms
+step:1564/1680 train_time:140929ms step_avg:90.11ms
+step:1565/1680 train_time:141015ms step_avg:90.11ms
+step:1566/1680 train_time:141106ms step_avg:90.11ms
+step:1567/1680 train_time:141196ms step_avg:90.11ms
+step:1568/1680 train_time:141287ms step_avg:90.11ms
+step:1569/1680 train_time:141379ms step_avg:90.11ms
+step:1570/1680 train_time:141470ms step_avg:90.11ms
+step:1571/1680 train_time:141561ms step_avg:90.11ms
+step:1572/1680 train_time:141652ms step_avg:90.11ms
+step:1573/1680 train_time:141742ms step_avg:90.11ms
+step:1574/1680 train_time:141833ms step_avg:90.11ms
+step:1575/1680 train_time:141925ms step_avg:90.11ms
+step:1576/1680 train_time:142016ms step_avg:90.11ms
+step:1577/1680 train_time:142108ms step_avg:90.11ms
+step:1578/1680 train_time:142198ms step_avg:90.11ms +step:1579/1680 train_time:142290ms step_avg:90.11ms +step:1580/1680 train_time:142381ms step_avg:90.11ms +step:1581/1680 train_time:142473ms step_avg:90.12ms +step:1582/1680 train_time:142563ms step_avg:90.12ms +step:1583/1680 train_time:142654ms step_avg:90.12ms +step:1584/1680 train_time:142744ms step_avg:90.12ms +step:1585/1680 train_time:142835ms step_avg:90.12ms +step:1586/1680 train_time:142926ms step_avg:90.12ms +step:1587/1680 train_time:143016ms step_avg:90.12ms +step:1588/1680 train_time:143107ms step_avg:90.12ms +step:1589/1680 train_time:143198ms step_avg:90.12ms +step:1590/1680 train_time:143289ms step_avg:90.12ms +step:1591/1680 train_time:143379ms step_avg:90.12ms +step:1592/1680 train_time:143471ms step_avg:90.12ms +step:1593/1680 train_time:143562ms step_avg:90.12ms +step:1594/1680 train_time:143652ms step_avg:90.12ms +step:1595/1680 train_time:143742ms step_avg:90.12ms +step:1596/1680 train_time:143833ms step_avg:90.12ms +step:1597/1680 train_time:143924ms step_avg:90.12ms +step:1598/1680 train_time:144014ms step_avg:90.12ms +step:1599/1680 train_time:144105ms step_avg:90.12ms +step:1600/1680 train_time:144196ms step_avg:90.12ms +step:1601/1680 train_time:144287ms step_avg:90.12ms +step:1602/1680 train_time:144377ms step_avg:90.12ms +step:1603/1680 train_time:144467ms step_avg:90.12ms +step:1604/1680 train_time:144558ms step_avg:90.12ms +step:1605/1680 train_time:144649ms step_avg:90.12ms +step:1606/1680 train_time:144740ms step_avg:90.12ms +step:1607/1680 train_time:144831ms step_avg:90.12ms +step:1608/1680 train_time:144922ms step_avg:90.13ms +step:1609/1680 train_time:145013ms step_avg:90.13ms +step:1610/1680 train_time:145104ms step_avg:90.13ms +step:1611/1680 train_time:145194ms step_avg:90.13ms +step:1612/1680 train_time:145285ms step_avg:90.13ms +step:1613/1680 train_time:145376ms step_avg:90.13ms +step:1614/1680 train_time:145467ms step_avg:90.13ms +step:1615/1680 train_time:145558ms step_avg:90.13ms +step:1616/1680 train_time:145649ms step_avg:90.13ms +step:1617/1680 train_time:145740ms step_avg:90.13ms +step:1618/1680 train_time:145830ms step_avg:90.13ms +step:1619/1680 train_time:145922ms step_avg:90.13ms +step:1620/1680 train_time:146013ms step_avg:90.13ms +step:1621/1680 train_time:146105ms step_avg:90.13ms +step:1622/1680 train_time:146196ms step_avg:90.13ms +step:1623/1680 train_time:146286ms step_avg:90.13ms +step:1624/1680 train_time:146377ms step_avg:90.13ms +step:1625/1680 train_time:146468ms step_avg:90.13ms +step:1625/1680 val_loss:3.2891 train_time:146560ms step_avg:90.19ms +step:1626/1680 train_time:146583ms step_avg:90.15ms +step:1627/1680 train_time:146653ms step_avg:90.14ms +step:1628/1680 train_time:146745ms step_avg:90.14ms +step:1629/1680 train_time:146835ms step_avg:90.14ms +step:1630/1680 train_time:146925ms step_avg:90.14ms +step:1631/1680 train_time:147014ms step_avg:90.14ms +step:1632/1680 train_time:147104ms step_avg:90.14ms +step:1633/1680 train_time:147193ms step_avg:90.14ms +step:1634/1680 train_time:147283ms step_avg:90.14ms +step:1635/1680 train_time:147373ms step_avg:90.14ms +step:1636/1680 train_time:147464ms step_avg:90.14ms +step:1637/1680 train_time:147556ms step_avg:90.14ms +step:1638/1680 train_time:147650ms step_avg:90.14ms +step:1639/1680 train_time:147742ms step_avg:90.14ms +step:1640/1680 train_time:147834ms step_avg:90.14ms +step:1641/1680 train_time:147924ms step_avg:90.14ms +step:1642/1680 train_time:148014ms step_avg:90.14ms +step:1643/1680 train_time:148104ms 
step_avg:90.14ms +step:1644/1680 train_time:148193ms step_avg:90.14ms +step:1645/1680 train_time:148284ms step_avg:90.14ms +step:1646/1680 train_time:148374ms step_avg:90.14ms +step:1647/1680 train_time:148465ms step_avg:90.14ms +step:1648/1680 train_time:148557ms step_avg:90.14ms +step:1649/1680 train_time:148649ms step_avg:90.15ms +step:1650/1680 train_time:148742ms step_avg:90.15ms +step:1651/1680 train_time:148832ms step_avg:90.15ms +step:1652/1680 train_time:148923ms step_avg:90.15ms +step:1653/1680 train_time:149012ms step_avg:90.15ms +step:1654/1680 train_time:149103ms step_avg:90.15ms +step:1655/1680 train_time:149193ms step_avg:90.15ms +step:1656/1680 train_time:149283ms step_avg:90.15ms +step:1657/1680 train_time:149375ms step_avg:90.15ms +step:1658/1680 train_time:149465ms step_avg:90.15ms +step:1659/1680 train_time:149557ms step_avg:90.15ms +step:1660/1680 train_time:149648ms step_avg:90.15ms +step:1661/1680 train_time:149740ms step_avg:90.15ms +step:1662/1680 train_time:149833ms step_avg:90.15ms +step:1663/1680 train_time:149924ms step_avg:90.15ms +step:1664/1680 train_time:150014ms step_avg:90.15ms +step:1665/1680 train_time:150105ms step_avg:90.15ms +step:1666/1680 train_time:150196ms step_avg:90.15ms +step:1667/1680 train_time:150287ms step_avg:90.15ms +step:1668/1680 train_time:150377ms step_avg:90.15ms +step:1669/1680 train_time:150468ms step_avg:90.15ms +step:1670/1680 train_time:150559ms step_avg:90.15ms +step:1671/1680 train_time:150649ms step_avg:90.16ms +step:1672/1680 train_time:150741ms step_avg:90.16ms +step:1673/1680 train_time:150833ms step_avg:90.16ms +step:1674/1680 train_time:150924ms step_avg:90.16ms +step:1675/1680 train_time:151014ms step_avg:90.16ms +step:1676/1680 train_time:151105ms step_avg:90.16ms +step:1677/1680 train_time:151196ms step_avg:90.16ms +step:1678/1680 train_time:151286ms step_avg:90.16ms +step:1679/1680 train_time:151377ms step_avg:90.16ms +step:1680/1680 train_time:151467ms step_avg:90.16ms +step:1680/1680 val_loss:3.2786 train_time:151559ms step_avg:90.21ms +peak memory allocated: 31255 MiB reserved: 46554 MiB diff --git a/records/092125_DropAttn/278c1540-3b42-4ab0-94ed-273830bdfa11.txt b/records/092125_DropAttn/278c1540-3b42-4ab0-94ed-273830bdfa11.txt new file mode 100644 index 000000000..b2a62c216 --- /dev/null +++ b/records/092125_DropAttn/278c1540-3b42-4ab0-94ed-273830bdfa11.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
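[Editorial sketch, not part of the logged run] A plain-PyTorch reference for the quintic Newton-Schulz iteration that the Triton kernels above implement; useful for cross-checking newton_schulz_triton (same coefficients and the same Frobenius pre-normalization):

import torch

def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # One iteration: A = X @ X.mT; B = b*A + c*A@A; X = a*X + B @ X
    # (matches ns_line_1 / ns_line_2 / ns_line_3 above, since A is symmetric)
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT  # operate on the wide orientation, as the Triton path does
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # Frobenius norm >= spectral norm, so spectral norm <= 1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

# e.g.: torch.testing.assert_close(newton_schulz_reference(G), newton_schulz_triton(G), atol=1e-2, rtol=0)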
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: Fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor as the zeropower input for this padded slot. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p.
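# (editorial sketch note, not part of the logged run) The two lerps just below
# implement Nesterov-style momentum. With momentum mu,
# momentum_buffer.lerp_(grad, 1 - mu) leaves
#     buf_new = mu * buf_old + (1 - mu) * grad,
# and then update_grad = grad.lerp(buf_new, mu) = (1 - mu) * grad + mu * buf_new,
# so the tensor handed to Newton-Schulz is the raw gradient blended with the
# freshly updated momentum buffer.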
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
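# (editorial sketch note, not part of the logged run) Communication pattern of this step():
#   pass 1: reduce_scatter(AVG) shards the averaged gradients, ~1/world_size of each
#           group's params per rank (zero-padded so the count divides evenly; e.g. a
#           hypothetical 21 same-shaped params on 8 ranks gives
#           padded_num_params = (21 + 8 - 1) // 8 * 8 = 24 and chunk_size = 3);
#   pass 2: each rank runs the batched Newton-Schulz only on its own chunk;
#   pass 3, below: all_gather_into_tensor broadcasts the updated shards back, so every
#           rank finishes the step with identical parameters.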
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
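# (editorial sketch note, not part of the logged run) next_batch() packs BOS-aligned
# documents per rank until it holds exactly num_tokens_local + 1 tokens (the extra
# token feeds the shifted targets), truncating any single document to max_seq_len;
# the StopIteration handled above fires when the precomputed BOS index runs out, at
# which point we swap in the next shard, which DataPreloader has already been
# scanning on a background thread: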
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension = 40 + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"v1/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer shows no degradation with context length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
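[Editorial sketch, not part of the logged run] Worked arithmetic for the batch-size and learning-rate hyperparameters above, assuming the 8-GPU configuration this record was produced on (lr_mult mirrors get_lr, defined further down in the script):

world_size = 8
grad_accum_steps = 8 // world_size                       # 1
train_batch_size = 2048 * 24 * 8                         # 393,216 tokens per optimizer step
tokens_per_rank = train_batch_size // (grad_accum_steps * world_size)  # 49,152
val_batch_size = 4 * 64 * 1024 * 8                       # 2,097,152 tokens per val batch
val_steps = grad_accum_steps * 10485760 // val_batch_size              # 5 val batches

# Stable-then-decay LR multiplier; step 1230 is 75% of the 1640 iterations:
num_iterations, cooldown_frac = 1640, 0.5
def lr_mult(step: int) -> float:
    x = min(0.9999, step / num_iterations)
    if x < 1 - cooldown_frac:
        return 1.0
    w = (1 - x) / cooldown_frac
    return w * 1.0 + (1 - w) * 0.1

assert lr_mult(0) == 1.0
assert abs(lr_mult(1230) - 0.55) < 1e-12  # w = 0.5 -> 0.5*1.0 + 0.5*0.1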
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sun Sep 21 23:25:01 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 43C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 39C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 40C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 86064 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 86065 C /usr/bin/python3 614MiB | +| 0 N/A N/A 86066 C /usr/bin/python3 614MiB | +| 0 N/A N/A 86067 C /usr/bin/python3 614MiB | +| 0 N/A N/A 86068 C /usr/bin/python3 614MiB | +| 0 N/A N/A 86069 C /usr/bin/python3 614MiB | +| 0 N/A N/A 86070 C /usr/bin/python3 614MiB | +| 0 N/A N/A 86071 C /usr/bin/python3 614MiB | +| 1 N/A N/A 86065 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 86066 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 86067 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 86068 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 86069 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 86070 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 86071 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:159ms step_avg:159.33ms +step:2/1680 train_time:181ms step_avg:90.71ms +step:3/1680 train_time:245ms step_avg:81.60ms +step:4/1680 train_time:332ms step_avg:82.89ms +step:5/1680 train_time:420ms step_avg:83.96ms +step:6/1680 train_time:508ms step_avg:84.63ms +step:7/1680 train_time:596ms step_avg:85.18ms +step:8/1680 train_time:684ms 
step_avg:85.56ms +step:9/1680 train_time:773ms step_avg:85.85ms +step:10/1680 train_time:861ms step_avg:86.06ms +step:11/1680 train_time:949ms step_avg:86.28ms +step:12/1680 train_time:1038ms step_avg:86.48ms +step:13/1680 train_time:1129ms step_avg:86.82ms +step:14/1680 train_time:1219ms step_avg:87.10ms +step:15/1680 train_time:1310ms step_avg:87.32ms +step:16/1680 train_time:1399ms step_avg:87.44ms +step:17/1680 train_time:1488ms step_avg:87.54ms +step:18/1680 train_time:1577ms step_avg:87.62ms +step:19/1680 train_time:1670ms step_avg:87.91ms +step:20/1680 train_time:1755ms step_avg:87.74ms +step:21/1680 train_time:1844ms step_avg:87.79ms +step:22/1680 train_time:1932ms step_avg:87.83ms +step:23/1680 train_time:2022ms step_avg:87.91ms +step:24/1680 train_time:2113ms step_avg:88.03ms +step:25/1680 train_time:2203ms step_avg:88.11ms +step:26/1680 train_time:2293ms step_avg:88.20ms +step:27/1680 train_time:2383ms step_avg:88.25ms +step:28/1680 train_time:2472ms step_avg:88.28ms +step:29/1680 train_time:2561ms step_avg:88.29ms +step:30/1680 train_time:2650ms step_avg:88.34ms +step:31/1680 train_time:2739ms step_avg:88.35ms +step:32/1680 train_time:2827ms step_avg:88.36ms +step:33/1680 train_time:2915ms step_avg:88.34ms +step:34/1680 train_time:3004ms step_avg:88.34ms +step:35/1680 train_time:3093ms step_avg:88.38ms +step:36/1680 train_time:3182ms step_avg:88.40ms +step:37/1680 train_time:3272ms step_avg:88.42ms +step:38/1680 train_time:3361ms step_avg:88.44ms +step:39/1680 train_time:3450ms step_avg:88.47ms +step:40/1680 train_time:3540ms step_avg:88.51ms +step:41/1680 train_time:3630ms step_avg:88.54ms +step:42/1680 train_time:3718ms step_avg:88.53ms +step:43/1680 train_time:3807ms step_avg:88.53ms +step:44/1680 train_time:3896ms step_avg:88.53ms +step:45/1680 train_time:3984ms step_avg:88.53ms +step:46/1680 train_time:4073ms step_avg:88.55ms +step:47/1680 train_time:4163ms step_avg:88.57ms +step:48/1680 train_time:4252ms step_avg:88.59ms +step:49/1680 train_time:4341ms step_avg:88.59ms +step:50/1680 train_time:4431ms step_avg:88.61ms +step:51/1680 train_time:4520ms step_avg:88.62ms +step:52/1680 train_time:4610ms step_avg:88.65ms +step:53/1680 train_time:4699ms step_avg:88.66ms +step:54/1680 train_time:4789ms step_avg:88.68ms +step:55/1680 train_time:4877ms step_avg:88.68ms +step:56/1680 train_time:4966ms step_avg:88.68ms +step:57/1680 train_time:5055ms step_avg:88.68ms +step:58/1680 train_time:5144ms step_avg:88.68ms +step:59/1680 train_time:5233ms step_avg:88.69ms +step:60/1680 train_time:5321ms step_avg:88.69ms +step:61/1680 train_time:5411ms step_avg:88.70ms +step:62/1680 train_time:5499ms step_avg:88.69ms +step:63/1680 train_time:5588ms step_avg:88.69ms +step:64/1680 train_time:5678ms step_avg:88.73ms +step:65/1680 train_time:5765ms step_avg:88.69ms +step:66/1680 train_time:5853ms step_avg:88.68ms +step:67/1680 train_time:5942ms step_avg:88.69ms +step:68/1680 train_time:6032ms step_avg:88.70ms +step:69/1680 train_time:6121ms step_avg:88.71ms +step:70/1680 train_time:6210ms step_avg:88.71ms +step:71/1680 train_time:6299ms step_avg:88.72ms +step:72/1680 train_time:6389ms step_avg:88.73ms +step:73/1680 train_time:6478ms step_avg:88.74ms +step:74/1680 train_time:6567ms step_avg:88.74ms +step:75/1680 train_time:6655ms step_avg:88.74ms +step:76/1680 train_time:6744ms step_avg:88.74ms +step:77/1680 train_time:6833ms step_avg:88.74ms +step:78/1680 train_time:6922ms step_avg:88.74ms +step:79/1680 train_time:7011ms step_avg:88.75ms +step:80/1680 train_time:7100ms step_avg:88.75ms +step:81/1680 
train_time:7189ms step_avg:88.76ms +step:82/1680 train_time:7278ms step_avg:88.76ms +step:83/1680 train_time:7367ms step_avg:88.76ms +step:84/1680 train_time:7457ms step_avg:88.77ms +step:85/1680 train_time:7545ms step_avg:88.77ms +step:86/1680 train_time:7634ms step_avg:88.77ms +step:87/1680 train_time:7723ms step_avg:88.77ms +step:88/1680 train_time:7812ms step_avg:88.77ms +step:89/1680 train_time:7901ms step_avg:88.77ms +step:90/1680 train_time:7989ms step_avg:88.77ms +step:91/1680 train_time:8079ms step_avg:88.78ms +step:92/1680 train_time:8168ms step_avg:88.78ms +step:93/1680 train_time:8256ms step_avg:88.78ms +step:94/1680 train_time:8345ms step_avg:88.78ms +step:95/1680 train_time:8434ms step_avg:88.78ms +step:96/1680 train_time:8522ms step_avg:88.77ms +step:97/1680 train_time:8612ms step_avg:88.78ms +step:98/1680 train_time:8700ms step_avg:88.78ms +step:99/1680 train_time:8789ms step_avg:88.78ms +step:100/1680 train_time:8878ms step_avg:88.78ms +step:101/1680 train_time:8971ms step_avg:88.82ms +step:102/1680 train_time:9056ms step_avg:88.78ms +step:103/1680 train_time:9145ms step_avg:88.79ms +step:104/1680 train_time:9234ms step_avg:88.79ms +step:105/1680 train_time:9323ms step_avg:88.79ms +step:106/1680 train_time:9411ms step_avg:88.79ms +step:107/1680 train_time:9501ms step_avg:88.79ms +step:108/1680 train_time:9590ms step_avg:88.80ms +step:109/1680 train_time:9679ms step_avg:88.80ms +step:110/1680 train_time:9768ms step_avg:88.80ms +step:111/1680 train_time:9857ms step_avg:88.80ms +step:112/1680 train_time:9945ms step_avg:88.80ms +step:113/1680 train_time:10034ms step_avg:88.80ms +step:114/1680 train_time:10122ms step_avg:88.79ms +step:115/1680 train_time:10212ms step_avg:88.80ms +step:116/1680 train_time:10301ms step_avg:88.80ms +step:117/1680 train_time:10390ms step_avg:88.80ms +step:118/1680 train_time:10478ms step_avg:88.80ms +step:119/1680 train_time:10567ms step_avg:88.80ms +step:120/1680 train_time:10656ms step_avg:88.80ms +step:121/1680 train_time:10745ms step_avg:88.80ms +step:122/1680 train_time:10834ms step_avg:88.81ms +step:123/1680 train_time:10923ms step_avg:88.80ms +step:124/1680 train_time:11012ms step_avg:88.80ms +step:125/1680 train_time:11100ms step_avg:88.80ms +step:125/1680 val_loss:4.3203 train_time:11190ms step_avg:89.52ms +step:126/1680 train_time:11212ms step_avg:88.99ms +step:127/1680 train_time:11282ms step_avg:88.84ms +step:128/1680 train_time:11377ms step_avg:88.88ms +step:129/1680 train_time:11470ms step_avg:88.91ms +step:130/1680 train_time:11559ms step_avg:88.91ms +step:131/1680 train_time:11647ms step_avg:88.91ms +step:132/1680 train_time:11735ms step_avg:88.90ms +step:133/1680 train_time:11823ms step_avg:88.89ms +step:134/1680 train_time:11911ms step_avg:88.89ms +step:135/1680 train_time:11999ms step_avg:88.88ms +step:136/1680 train_time:12087ms step_avg:88.87ms +step:137/1680 train_time:12175ms step_avg:88.87ms +step:138/1680 train_time:12265ms step_avg:88.88ms +step:139/1680 train_time:12355ms step_avg:88.89ms +step:140/1680 train_time:12446ms step_avg:88.90ms +step:141/1680 train_time:12535ms step_avg:88.90ms +step:142/1680 train_time:12625ms step_avg:88.91ms +step:143/1680 train_time:12713ms step_avg:88.90ms +step:144/1680 train_time:12802ms step_avg:88.91ms +step:145/1680 train_time:12891ms step_avg:88.90ms +step:146/1680 train_time:12979ms step_avg:88.90ms +step:147/1680 train_time:13067ms step_avg:88.89ms +step:148/1680 train_time:13156ms step_avg:88.89ms +step:149/1680 train_time:13247ms step_avg:88.90ms +step:150/1680 train_time:13337ms 
step_avg:88.91ms +step:151/1680 train_time:13426ms step_avg:88.92ms +step:152/1680 train_time:13516ms step_avg:88.92ms +step:153/1680 train_time:13606ms step_avg:88.93ms +step:154/1680 train_time:13694ms step_avg:88.92ms +step:155/1680 train_time:13784ms step_avg:88.93ms +step:156/1680 train_time:13871ms step_avg:88.92ms +step:157/1680 train_time:13960ms step_avg:88.92ms +step:158/1680 train_time:14048ms step_avg:88.91ms +step:159/1680 train_time:14137ms step_avg:88.91ms +step:160/1680 train_time:14226ms step_avg:88.91ms +step:161/1680 train_time:14315ms step_avg:88.91ms +step:162/1680 train_time:14404ms step_avg:88.92ms +step:163/1680 train_time:14493ms step_avg:88.92ms +step:164/1680 train_time:14583ms step_avg:88.92ms +step:165/1680 train_time:14671ms step_avg:88.92ms +step:166/1680 train_time:14760ms step_avg:88.92ms +step:167/1680 train_time:14848ms step_avg:88.91ms +step:168/1680 train_time:14937ms step_avg:88.91ms +step:169/1680 train_time:15025ms step_avg:88.90ms +step:170/1680 train_time:15113ms step_avg:88.90ms +step:171/1680 train_time:15203ms step_avg:88.90ms +step:172/1680 train_time:15292ms step_avg:88.91ms +step:173/1680 train_time:15381ms step_avg:88.91ms +step:174/1680 train_time:15470ms step_avg:88.91ms +step:175/1680 train_time:15560ms step_avg:88.91ms +step:176/1680 train_time:15649ms step_avg:88.92ms +step:177/1680 train_time:15738ms step_avg:88.92ms +step:178/1680 train_time:15827ms step_avg:88.92ms +step:179/1680 train_time:15916ms step_avg:88.91ms +step:180/1680 train_time:16004ms step_avg:88.91ms +step:181/1680 train_time:16092ms step_avg:88.91ms +step:182/1680 train_time:16181ms step_avg:88.91ms +step:183/1680 train_time:16270ms step_avg:88.91ms +step:184/1680 train_time:16360ms step_avg:88.91ms +step:185/1680 train_time:16449ms step_avg:88.91ms +step:186/1680 train_time:16537ms step_avg:88.91ms +step:187/1680 train_time:16627ms step_avg:88.91ms +step:188/1680 train_time:16716ms step_avg:88.92ms +step:189/1680 train_time:16806ms step_avg:88.92ms +step:190/1680 train_time:16893ms step_avg:88.91ms +step:191/1680 train_time:16983ms step_avg:88.92ms +step:192/1680 train_time:17071ms step_avg:88.91ms +step:193/1680 train_time:17159ms step_avg:88.91ms +step:194/1680 train_time:17248ms step_avg:88.91ms +step:195/1680 train_time:17338ms step_avg:88.91ms +step:196/1680 train_time:17428ms step_avg:88.92ms +step:197/1680 train_time:17517ms step_avg:88.92ms +step:198/1680 train_time:17606ms step_avg:88.92ms +step:199/1680 train_time:17695ms step_avg:88.92ms +step:200/1680 train_time:17785ms step_avg:88.93ms +step:201/1680 train_time:17873ms step_avg:88.92ms +step:202/1680 train_time:17963ms step_avg:88.92ms +step:203/1680 train_time:18051ms step_avg:88.92ms +step:204/1680 train_time:18139ms step_avg:88.92ms +step:205/1680 train_time:18227ms step_avg:88.91ms +step:206/1680 train_time:18317ms step_avg:88.92ms +step:207/1680 train_time:18406ms step_avg:88.92ms +step:208/1680 train_time:18494ms step_avg:88.92ms +step:209/1680 train_time:18583ms step_avg:88.92ms +step:210/1680 train_time:18672ms step_avg:88.91ms +step:211/1680 train_time:18762ms step_avg:88.92ms +step:212/1680 train_time:18851ms step_avg:88.92ms +step:213/1680 train_time:18940ms step_avg:88.92ms +step:214/1680 train_time:19029ms step_avg:88.92ms +step:215/1680 train_time:19118ms step_avg:88.92ms +step:216/1680 train_time:19207ms step_avg:88.92ms +step:217/1680 train_time:19296ms step_avg:88.92ms +step:218/1680 train_time:19385ms step_avg:88.92ms +step:219/1680 train_time:19473ms step_avg:88.92ms +step:220/1680 
train_time:19563ms step_avg:88.92ms +step:221/1680 train_time:19652ms step_avg:88.92ms +step:222/1680 train_time:19741ms step_avg:88.92ms +step:223/1680 train_time:19830ms step_avg:88.92ms +step:224/1680 train_time:19920ms step_avg:88.93ms +step:225/1680 train_time:20009ms step_avg:88.93ms +step:226/1680 train_time:20098ms step_avg:88.93ms +step:227/1680 train_time:20187ms step_avg:88.93ms +step:228/1680 train_time:20275ms step_avg:88.93ms +step:229/1680 train_time:20365ms step_avg:88.93ms +step:230/1680 train_time:20454ms step_avg:88.93ms +step:231/1680 train_time:20544ms step_avg:88.93ms +step:232/1680 train_time:20633ms step_avg:88.93ms +step:233/1680 train_time:20722ms step_avg:88.94ms +step:234/1680 train_time:20812ms step_avg:88.94ms +step:235/1680 train_time:20901ms step_avg:88.94ms +step:236/1680 train_time:20990ms step_avg:88.94ms +step:237/1680 train_time:21079ms step_avg:88.94ms +step:238/1680 train_time:21168ms step_avg:88.94ms +step:239/1680 train_time:21257ms step_avg:88.94ms +step:240/1680 train_time:21346ms step_avg:88.94ms +step:241/1680 train_time:21435ms step_avg:88.94ms +step:242/1680 train_time:21524ms step_avg:88.94ms +step:243/1680 train_time:21613ms step_avg:88.94ms +step:244/1680 train_time:21702ms step_avg:88.94ms +step:245/1680 train_time:21791ms step_avg:88.94ms +step:246/1680 train_time:21880ms step_avg:88.94ms +step:247/1680 train_time:21969ms step_avg:88.94ms +step:248/1680 train_time:22059ms step_avg:88.95ms +step:249/1680 train_time:22148ms step_avg:88.95ms +step:250/1680 train_time:22237ms step_avg:88.95ms +step:250/1680 val_loss:3.9720 train_time:22327ms step_avg:89.31ms +step:251/1680 train_time:22349ms step_avg:89.04ms +step:252/1680 train_time:22418ms step_avg:88.96ms +step:253/1680 train_time:22511ms step_avg:88.98ms +step:254/1680 train_time:22604ms step_avg:88.99ms +step:255/1680 train_time:22692ms step_avg:88.99ms +step:256/1680 train_time:22779ms step_avg:88.98ms +step:257/1680 train_time:22867ms step_avg:88.98ms +step:258/1680 train_time:22955ms step_avg:88.97ms +step:259/1680 train_time:23043ms step_avg:88.97ms +step:260/1680 train_time:23131ms step_avg:88.96ms +step:261/1680 train_time:23219ms step_avg:88.96ms +step:262/1680 train_time:23308ms step_avg:88.96ms +step:263/1680 train_time:23400ms step_avg:88.97ms +step:264/1680 train_time:23490ms step_avg:88.98ms +step:265/1680 train_time:23580ms step_avg:88.98ms +step:266/1680 train_time:23669ms step_avg:88.98ms +step:267/1680 train_time:23759ms step_avg:88.98ms +step:268/1680 train_time:23847ms step_avg:88.98ms +step:269/1680 train_time:23935ms step_avg:88.98ms +step:270/1680 train_time:24024ms step_avg:88.98ms +step:271/1680 train_time:24112ms step_avg:88.97ms +step:272/1680 train_time:24200ms step_avg:88.97ms +step:273/1680 train_time:24289ms step_avg:88.97ms +step:274/1680 train_time:24379ms step_avg:88.98ms +step:275/1680 train_time:24469ms step_avg:88.98ms +step:276/1680 train_time:24559ms step_avg:88.98ms +step:277/1680 train_time:24648ms step_avg:88.98ms +step:278/1680 train_time:24737ms step_avg:88.98ms +step:279/1680 train_time:24826ms step_avg:88.98ms +step:280/1680 train_time:24914ms step_avg:88.98ms +step:281/1680 train_time:25002ms step_avg:88.98ms +step:282/1680 train_time:25090ms step_avg:88.97ms +step:283/1680 train_time:25179ms step_avg:88.97ms +step:284/1680 train_time:25267ms step_avg:88.97ms +step:285/1680 train_time:25357ms step_avg:88.97ms +step:286/1680 train_time:25448ms step_avg:88.98ms +step:287/1680 train_time:25537ms step_avg:88.98ms +step:288/1680 train_time:25626ms 
step_avg:88.98ms +step:289/1680 train_time:25716ms step_avg:88.98ms +step:290/1680 train_time:25805ms step_avg:88.98ms +step:291/1680 train_time:25894ms step_avg:88.98ms +step:292/1680 train_time:25983ms step_avg:88.98ms +step:293/1680 train_time:26072ms step_avg:88.98ms +step:294/1680 train_time:26161ms step_avg:88.98ms +step:295/1680 train_time:26249ms step_avg:88.98ms +step:296/1680 train_time:26338ms step_avg:88.98ms +step:297/1680 train_time:26427ms step_avg:88.98ms +step:298/1680 train_time:26517ms step_avg:88.98ms +step:299/1680 train_time:26606ms step_avg:88.98ms +step:300/1680 train_time:26698ms step_avg:88.99ms +step:301/1680 train_time:26785ms step_avg:88.99ms +step:302/1680 train_time:26873ms step_avg:88.98ms +step:303/1680 train_time:26963ms step_avg:88.99ms +step:304/1680 train_time:27051ms step_avg:88.98ms +step:305/1680 train_time:27140ms step_avg:88.98ms +step:306/1680 train_time:27228ms step_avg:88.98ms +step:307/1680 train_time:27317ms step_avg:88.98ms +step:308/1680 train_time:27406ms step_avg:88.98ms +step:309/1680 train_time:27496ms step_avg:88.98ms +step:310/1680 train_time:27585ms step_avg:88.98ms +step:311/1680 train_time:27675ms step_avg:88.99ms +step:312/1680 train_time:27764ms step_avg:88.99ms +step:313/1680 train_time:27852ms step_avg:88.98ms +step:314/1680 train_time:27941ms step_avg:88.98ms +step:315/1680 train_time:28029ms step_avg:88.98ms +step:316/1680 train_time:28119ms step_avg:88.98ms +step:317/1680 train_time:28207ms step_avg:88.98ms +step:318/1680 train_time:28295ms step_avg:88.98ms +step:319/1680 train_time:28385ms step_avg:88.98ms +step:320/1680 train_time:28474ms step_avg:88.98ms +step:321/1680 train_time:28563ms step_avg:88.98ms +step:322/1680 train_time:28652ms step_avg:88.98ms +step:323/1680 train_time:28741ms step_avg:88.98ms +step:324/1680 train_time:28830ms step_avg:88.98ms +step:325/1680 train_time:28920ms step_avg:88.98ms +step:326/1680 train_time:29008ms step_avg:88.98ms +step:327/1680 train_time:29097ms step_avg:88.98ms +step:328/1680 train_time:29187ms step_avg:88.98ms +step:329/1680 train_time:29275ms step_avg:88.98ms +step:330/1680 train_time:29365ms step_avg:88.98ms +step:331/1680 train_time:29454ms step_avg:88.98ms +step:332/1680 train_time:29544ms step_avg:88.99ms +step:333/1680 train_time:29634ms step_avg:88.99ms +step:334/1680 train_time:29723ms step_avg:88.99ms +step:335/1680 train_time:29812ms step_avg:88.99ms +step:336/1680 train_time:29901ms step_avg:88.99ms +step:337/1680 train_time:29989ms step_avg:88.99ms +step:338/1680 train_time:30078ms step_avg:88.99ms +step:339/1680 train_time:30167ms step_avg:88.99ms +step:340/1680 train_time:30257ms step_avg:88.99ms +step:341/1680 train_time:30347ms step_avg:88.99ms +step:342/1680 train_time:30435ms step_avg:88.99ms +step:343/1680 train_time:30525ms step_avg:88.99ms +step:344/1680 train_time:30614ms step_avg:88.99ms +step:345/1680 train_time:30703ms step_avg:88.99ms +step:346/1680 train_time:30791ms step_avg:88.99ms +step:347/1680 train_time:30881ms step_avg:88.99ms +step:348/1680 train_time:30969ms step_avg:88.99ms +step:349/1680 train_time:31058ms step_avg:88.99ms +step:350/1680 train_time:31147ms step_avg:88.99ms +step:351/1680 train_time:31236ms step_avg:88.99ms +step:352/1680 train_time:31325ms step_avg:88.99ms +step:353/1680 train_time:31414ms step_avg:88.99ms +step:354/1680 train_time:31504ms step_avg:88.99ms +step:355/1680 train_time:31593ms step_avg:89.00ms +step:356/1680 train_time:31683ms step_avg:89.00ms +step:357/1680 train_time:31772ms step_avg:89.00ms +step:358/1680 
train_time:31862ms step_avg:89.00ms +step:359/1680 train_time:31951ms step_avg:89.00ms +step:360/1680 train_time:32040ms step_avg:89.00ms +step:361/1680 train_time:32128ms step_avg:89.00ms +step:362/1680 train_time:32218ms step_avg:89.00ms +step:363/1680 train_time:32306ms step_avg:89.00ms +step:364/1680 train_time:32395ms step_avg:89.00ms +step:365/1680 train_time:32484ms step_avg:89.00ms +step:366/1680 train_time:32573ms step_avg:89.00ms +step:367/1680 train_time:32662ms step_avg:89.00ms +step:368/1680 train_time:32750ms step_avg:88.99ms +step:369/1680 train_time:32840ms step_avg:89.00ms +step:370/1680 train_time:32928ms step_avg:88.99ms +step:371/1680 train_time:33017ms step_avg:88.99ms +step:372/1680 train_time:33105ms step_avg:88.99ms +step:373/1680 train_time:33195ms step_avg:89.00ms +step:374/1680 train_time:33282ms step_avg:88.99ms +step:375/1680 train_time:33372ms step_avg:88.99ms +step:375/1680 val_loss:3.8171 train_time:33463ms step_avg:89.23ms +step:376/1680 train_time:33484ms step_avg:89.05ms +step:377/1680 train_time:33553ms step_avg:89.00ms +step:378/1680 train_time:33650ms step_avg:89.02ms +step:379/1680 train_time:33740ms step_avg:89.02ms +step:380/1680 train_time:33828ms step_avg:89.02ms +step:381/1680 train_time:33916ms step_avg:89.02ms +step:382/1680 train_time:34004ms step_avg:89.02ms +step:383/1680 train_time:34093ms step_avg:89.02ms +step:384/1680 train_time:34181ms step_avg:89.01ms +step:385/1680 train_time:34269ms step_avg:89.01ms +step:386/1680 train_time:34357ms step_avg:89.01ms +step:387/1680 train_time:34445ms step_avg:89.01ms +step:388/1680 train_time:34536ms step_avg:89.01ms +step:389/1680 train_time:34628ms step_avg:89.02ms +step:390/1680 train_time:34718ms step_avg:89.02ms +step:391/1680 train_time:34807ms step_avg:89.02ms +step:392/1680 train_time:34896ms step_avg:89.02ms +step:393/1680 train_time:34984ms step_avg:89.02ms +step:394/1680 train_time:35073ms step_avg:89.02ms +step:395/1680 train_time:35162ms step_avg:89.02ms +step:396/1680 train_time:35249ms step_avg:89.01ms +step:397/1680 train_time:35338ms step_avg:89.01ms +step:398/1680 train_time:35426ms step_avg:89.01ms +step:399/1680 train_time:35515ms step_avg:89.01ms +step:400/1680 train_time:35605ms step_avg:89.01ms +step:401/1680 train_time:35695ms step_avg:89.01ms +step:402/1680 train_time:35784ms step_avg:89.01ms +step:403/1680 train_time:35872ms step_avg:89.01ms +step:404/1680 train_time:35962ms step_avg:89.01ms +step:405/1680 train_time:36050ms step_avg:89.01ms +step:406/1680 train_time:36139ms step_avg:89.01ms +step:407/1680 train_time:36227ms step_avg:89.01ms +step:408/1680 train_time:36315ms step_avg:89.01ms +step:409/1680 train_time:36404ms step_avg:89.01ms +step:410/1680 train_time:36493ms step_avg:89.01ms +step:411/1680 train_time:36582ms step_avg:89.01ms +step:412/1680 train_time:36673ms step_avg:89.01ms +step:413/1680 train_time:36763ms step_avg:89.01ms +step:414/1680 train_time:36851ms step_avg:89.01ms +step:415/1680 train_time:36940ms step_avg:89.01ms +step:416/1680 train_time:37029ms step_avg:89.01ms +step:417/1680 train_time:37117ms step_avg:89.01ms +step:418/1680 train_time:37206ms step_avg:89.01ms +step:419/1680 train_time:37295ms step_avg:89.01ms +step:420/1680 train_time:37383ms step_avg:89.01ms +step:421/1680 train_time:37472ms step_avg:89.01ms +step:422/1680 train_time:37561ms step_avg:89.01ms +step:423/1680 train_time:37649ms step_avg:89.01ms +step:424/1680 train_time:37740ms step_avg:89.01ms +step:425/1680 train_time:37828ms step_avg:89.01ms +step:426/1680 train_time:37916ms 
step_avg:89.01ms +step:427/1680 train_time:38005ms step_avg:89.01ms +step:428/1680 train_time:38095ms step_avg:89.01ms +step:429/1680 train_time:38183ms step_avg:89.01ms +step:430/1680 train_time:38272ms step_avg:89.01ms +step:431/1680 train_time:38361ms step_avg:89.00ms +step:432/1680 train_time:38450ms step_avg:89.00ms +step:433/1680 train_time:38539ms step_avg:89.01ms +step:434/1680 train_time:38628ms step_avg:89.00ms +step:435/1680 train_time:38717ms step_avg:89.00ms +step:436/1680 train_time:38807ms step_avg:89.01ms +step:437/1680 train_time:38896ms step_avg:89.01ms +step:438/1680 train_time:38985ms step_avg:89.01ms +step:439/1680 train_time:39074ms step_avg:89.01ms +step:440/1680 train_time:39163ms step_avg:89.01ms +step:441/1680 train_time:39251ms step_avg:89.00ms +step:442/1680 train_time:39341ms step_avg:89.01ms +step:443/1680 train_time:39429ms step_avg:89.00ms +step:444/1680 train_time:39518ms step_avg:89.00ms +step:445/1680 train_time:39607ms step_avg:89.00ms +step:446/1680 train_time:39696ms step_avg:89.00ms +step:447/1680 train_time:39784ms step_avg:89.00ms +step:448/1680 train_time:39873ms step_avg:89.00ms +step:449/1680 train_time:39962ms step_avg:89.00ms +step:450/1680 train_time:40051ms step_avg:89.00ms +step:451/1680 train_time:40141ms step_avg:89.00ms +step:452/1680 train_time:40230ms step_avg:89.00ms +step:453/1680 train_time:40320ms step_avg:89.01ms +step:454/1680 train_time:40408ms step_avg:89.00ms +step:455/1680 train_time:40497ms step_avg:89.00ms +step:456/1680 train_time:40586ms step_avg:89.00ms +step:457/1680 train_time:40674ms step_avg:89.00ms +step:458/1680 train_time:40762ms step_avg:89.00ms +step:459/1680 train_time:40851ms step_avg:89.00ms +step:460/1680 train_time:40941ms step_avg:89.00ms +step:461/1680 train_time:41029ms step_avg:89.00ms +step:462/1680 train_time:41118ms step_avg:89.00ms +step:463/1680 train_time:41208ms step_avg:89.00ms +step:464/1680 train_time:41297ms step_avg:89.00ms +step:465/1680 train_time:41385ms step_avg:89.00ms +step:466/1680 train_time:41474ms step_avg:89.00ms +step:467/1680 train_time:41563ms step_avg:89.00ms +step:468/1680 train_time:41652ms step_avg:89.00ms +step:469/1680 train_time:41741ms step_avg:89.00ms +step:470/1680 train_time:41829ms step_avg:89.00ms +step:471/1680 train_time:41918ms step_avg:89.00ms +step:472/1680 train_time:42007ms step_avg:89.00ms +step:473/1680 train_time:42096ms step_avg:89.00ms +step:474/1680 train_time:42185ms step_avg:89.00ms +step:475/1680 train_time:42274ms step_avg:89.00ms +step:476/1680 train_time:42363ms step_avg:89.00ms +step:477/1680 train_time:42455ms step_avg:89.00ms +step:478/1680 train_time:42541ms step_avg:89.00ms +step:479/1680 train_time:42629ms step_avg:89.00ms +step:480/1680 train_time:42719ms step_avg:89.00ms +step:481/1680 train_time:42808ms step_avg:89.00ms +step:482/1680 train_time:42897ms step_avg:89.00ms +step:483/1680 train_time:42985ms step_avg:89.00ms +step:484/1680 train_time:43074ms step_avg:89.00ms +step:485/1680 train_time:43163ms step_avg:89.00ms +step:486/1680 train_time:43252ms step_avg:89.00ms +step:487/1680 train_time:43340ms step_avg:88.99ms +step:488/1680 train_time:43429ms step_avg:88.99ms +step:489/1680 train_time:43519ms step_avg:89.00ms +step:490/1680 train_time:43607ms step_avg:88.99ms +step:491/1680 train_time:43696ms step_avg:88.99ms +step:492/1680 train_time:43785ms step_avg:88.99ms +step:493/1680 train_time:43874ms step_avg:88.99ms +step:494/1680 train_time:43963ms step_avg:88.99ms +step:495/1680 train_time:44052ms step_avg:88.99ms +step:496/1680 
train_time:44141ms step_avg:88.99ms +step:497/1680 train_time:44229ms step_avg:88.99ms +step:498/1680 train_time:44318ms step_avg:88.99ms +step:499/1680 train_time:44407ms step_avg:88.99ms +step:500/1680 train_time:44496ms step_avg:88.99ms +step:500/1680 val_loss:3.7151 train_time:44586ms step_avg:89.17ms +step:501/1680 train_time:44608ms step_avg:89.04ms +step:502/1680 train_time:44678ms step_avg:89.00ms +step:503/1680 train_time:44772ms step_avg:89.01ms +step:504/1680 train_time:44863ms step_avg:89.01ms +step:505/1680 train_time:44952ms step_avg:89.01ms +step:506/1680 train_time:45041ms step_avg:89.01ms +step:507/1680 train_time:45129ms step_avg:89.01ms +step:508/1680 train_time:45217ms step_avg:89.01ms +step:509/1680 train_time:45305ms step_avg:89.01ms +step:510/1680 train_time:45393ms step_avg:89.01ms +step:511/1680 train_time:45481ms step_avg:89.00ms +step:512/1680 train_time:45570ms step_avg:89.00ms +step:513/1680 train_time:45661ms step_avg:89.01ms +step:514/1680 train_time:45752ms step_avg:89.01ms +step:515/1680 train_time:45842ms step_avg:89.01ms +step:516/1680 train_time:45931ms step_avg:89.01ms +step:517/1680 train_time:46019ms step_avg:89.01ms +step:518/1680 train_time:46108ms step_avg:89.01ms +step:519/1680 train_time:46197ms step_avg:89.01ms +step:520/1680 train_time:46285ms step_avg:89.01ms +step:521/1680 train_time:46374ms step_avg:89.01ms +step:522/1680 train_time:46462ms step_avg:89.01ms +step:523/1680 train_time:46551ms step_avg:89.01ms +step:524/1680 train_time:46640ms step_avg:89.01ms +step:525/1680 train_time:46730ms step_avg:89.01ms +step:526/1680 train_time:46820ms step_avg:89.01ms +step:527/1680 train_time:46908ms step_avg:89.01ms +step:528/1680 train_time:46997ms step_avg:89.01ms +step:529/1680 train_time:47086ms step_avg:89.01ms +step:530/1680 train_time:47175ms step_avg:89.01ms +step:531/1680 train_time:47263ms step_avg:89.01ms +step:532/1680 train_time:47351ms step_avg:89.01ms +step:533/1680 train_time:47439ms step_avg:89.00ms +step:534/1680 train_time:47528ms step_avg:89.00ms +step:535/1680 train_time:47617ms step_avg:89.00ms +step:536/1680 train_time:47706ms step_avg:89.00ms +step:537/1680 train_time:47795ms step_avg:89.00ms +step:538/1680 train_time:47885ms step_avg:89.01ms +step:539/1680 train_time:47974ms step_avg:89.01ms +step:540/1680 train_time:48063ms step_avg:89.00ms +step:541/1680 train_time:48151ms step_avg:89.00ms +step:542/1680 train_time:48240ms step_avg:89.00ms +step:543/1680 train_time:48328ms step_avg:89.00ms +step:544/1680 train_time:48417ms step_avg:89.00ms +step:545/1680 train_time:48506ms step_avg:89.00ms +step:546/1680 train_time:48596ms step_avg:89.00ms +step:547/1680 train_time:48685ms step_avg:89.00ms +step:548/1680 train_time:48775ms step_avg:89.01ms +step:549/1680 train_time:48865ms step_avg:89.01ms +step:550/1680 train_time:48956ms step_avg:89.01ms +step:551/1680 train_time:49046ms step_avg:89.01ms +step:552/1680 train_time:49136ms step_avg:89.01ms +step:553/1680 train_time:49227ms step_avg:89.02ms +step:554/1680 train_time:49318ms step_avg:89.02ms +step:555/1680 train_time:49407ms step_avg:89.02ms +step:556/1680 train_time:49497ms step_avg:89.02ms +step:557/1680 train_time:49587ms step_avg:89.02ms +step:558/1680 train_time:49677ms step_avg:89.03ms +step:559/1680 train_time:49767ms step_avg:89.03ms +step:560/1680 train_time:49858ms step_avg:89.03ms +step:561/1680 train_time:49948ms step_avg:89.03ms +step:562/1680 train_time:50039ms step_avg:89.04ms +step:563/1680 train_time:50129ms step_avg:89.04ms +step:564/1680 train_time:50219ms 
step_avg:89.04ms +step:565/1680 train_time:50309ms step_avg:89.04ms +step:566/1680 train_time:50399ms step_avg:89.04ms +step:567/1680 train_time:50489ms step_avg:89.05ms +step:568/1680 train_time:50579ms step_avg:89.05ms +step:569/1680 train_time:50669ms step_avg:89.05ms +step:570/1680 train_time:50759ms step_avg:89.05ms +step:571/1680 train_time:50851ms step_avg:89.06ms +step:572/1680 train_time:50940ms step_avg:89.06ms +step:573/1680 train_time:51030ms step_avg:89.06ms +step:574/1680 train_time:51120ms step_avg:89.06ms +step:575/1680 train_time:51210ms step_avg:89.06ms +step:576/1680 train_time:51300ms step_avg:89.06ms +step:577/1680 train_time:51391ms step_avg:89.07ms +step:578/1680 train_time:51481ms step_avg:89.07ms +step:579/1680 train_time:51571ms step_avg:89.07ms +step:580/1680 train_time:51661ms step_avg:89.07ms +step:581/1680 train_time:51752ms step_avg:89.07ms +step:582/1680 train_time:51842ms step_avg:89.08ms +step:583/1680 train_time:51934ms step_avg:89.08ms +step:584/1680 train_time:52024ms step_avg:89.08ms +step:585/1680 train_time:52115ms step_avg:89.09ms +step:586/1680 train_time:52205ms step_avg:89.09ms +step:587/1680 train_time:52294ms step_avg:89.09ms +step:588/1680 train_time:52384ms step_avg:89.09ms +step:589/1680 train_time:52475ms step_avg:89.09ms +step:590/1680 train_time:52565ms step_avg:89.09ms +step:591/1680 train_time:52655ms step_avg:89.09ms +step:592/1680 train_time:52744ms step_avg:89.10ms +step:593/1680 train_time:52836ms step_avg:89.10ms +step:594/1680 train_time:52927ms step_avg:89.10ms +step:595/1680 train_time:53020ms step_avg:89.11ms +step:596/1680 train_time:53109ms step_avg:89.11ms +step:597/1680 train_time:53199ms step_avg:89.11ms +step:598/1680 train_time:53289ms step_avg:89.11ms +step:599/1680 train_time:53379ms step_avg:89.11ms +step:600/1680 train_time:53468ms step_avg:89.11ms +step:601/1680 train_time:53559ms step_avg:89.12ms +step:602/1680 train_time:53648ms step_avg:89.12ms +step:603/1680 train_time:53739ms step_avg:89.12ms +step:604/1680 train_time:53829ms step_avg:89.12ms +step:605/1680 train_time:53919ms step_avg:89.12ms +step:606/1680 train_time:54009ms step_avg:89.12ms +step:607/1680 train_time:54100ms step_avg:89.13ms +step:608/1680 train_time:54191ms step_avg:89.13ms +step:609/1680 train_time:54280ms step_avg:89.13ms +step:610/1680 train_time:54370ms step_avg:89.13ms +step:611/1680 train_time:54460ms step_avg:89.13ms +step:612/1680 train_time:54550ms step_avg:89.13ms +step:613/1680 train_time:54640ms step_avg:89.13ms +step:614/1680 train_time:54730ms step_avg:89.14ms +step:615/1680 train_time:54820ms step_avg:89.14ms +step:616/1680 train_time:54910ms step_avg:89.14ms +step:617/1680 train_time:55001ms step_avg:89.14ms +step:618/1680 train_time:55091ms step_avg:89.14ms +step:619/1680 train_time:55181ms step_avg:89.15ms +step:620/1680 train_time:55271ms step_avg:89.15ms +step:621/1680 train_time:55362ms step_avg:89.15ms +step:622/1680 train_time:55452ms step_avg:89.15ms +step:623/1680 train_time:55542ms step_avg:89.15ms +step:624/1680 train_time:55632ms step_avg:89.15ms +step:625/1680 train_time:55722ms step_avg:89.16ms +step:625/1680 val_loss:3.6139 train_time:55815ms step_avg:89.30ms +step:626/1680 train_time:55837ms step_avg:89.20ms +step:627/1680 train_time:55908ms step_avg:89.17ms +step:628/1680 train_time:56009ms step_avg:89.19ms +step:629/1680 train_time:56101ms step_avg:89.19ms +step:630/1680 train_time:56189ms step_avg:89.19ms +step:631/1680 train_time:56278ms step_avg:89.19ms +step:632/1680 train_time:56367ms step_avg:89.19ms 
+step:633/1680 train_time:56456ms step_avg:89.19ms +step:634/1680 train_time:56544ms step_avg:89.19ms +step:635/1680 train_time:56633ms step_avg:89.19ms +step:636/1680 train_time:56725ms step_avg:89.19ms +step:637/1680 train_time:56820ms step_avg:89.20ms +step:638/1680 train_time:56912ms step_avg:89.20ms +step:639/1680 train_time:57005ms step_avg:89.21ms +step:640/1680 train_time:57095ms step_avg:89.21ms +step:641/1680 train_time:57186ms step_avg:89.21ms +step:642/1680 train_time:57276ms step_avg:89.22ms +step:643/1680 train_time:57366ms step_avg:89.22ms +step:644/1680 train_time:57455ms step_avg:89.22ms +step:645/1680 train_time:57543ms step_avg:89.21ms +step:646/1680 train_time:57634ms step_avg:89.22ms +step:647/1680 train_time:57724ms step_avg:89.22ms +step:648/1680 train_time:57816ms step_avg:89.22ms +step:649/1680 train_time:57908ms step_avg:89.23ms +step:650/1680 train_time:57999ms step_avg:89.23ms +step:651/1680 train_time:58089ms step_avg:89.23ms +step:652/1680 train_time:58179ms step_avg:89.23ms +step:653/1680 train_time:58269ms step_avg:89.23ms +step:654/1680 train_time:58359ms step_avg:89.23ms +step:655/1680 train_time:58448ms step_avg:89.23ms +step:656/1680 train_time:58538ms step_avg:89.24ms +step:657/1680 train_time:58628ms step_avg:89.24ms +step:658/1680 train_time:58718ms step_avg:89.24ms +step:659/1680 train_time:58810ms step_avg:89.24ms +step:660/1680 train_time:58902ms step_avg:89.24ms +step:661/1680 train_time:58993ms step_avg:89.25ms +step:662/1680 train_time:59082ms step_avg:89.25ms +step:663/1680 train_time:59172ms step_avg:89.25ms +step:664/1680 train_time:59263ms step_avg:89.25ms +step:665/1680 train_time:59352ms step_avg:89.25ms +step:666/1680 train_time:59442ms step_avg:89.25ms +step:667/1680 train_time:59531ms step_avg:89.25ms +step:668/1680 train_time:59621ms step_avg:89.25ms +step:669/1680 train_time:59711ms step_avg:89.25ms +step:670/1680 train_time:59801ms step_avg:89.26ms +step:671/1680 train_time:59892ms step_avg:89.26ms +step:672/1680 train_time:59982ms step_avg:89.26ms +step:673/1680 train_time:60072ms step_avg:89.26ms +step:674/1680 train_time:60162ms step_avg:89.26ms +step:675/1680 train_time:60252ms step_avg:89.26ms +step:676/1680 train_time:60342ms step_avg:89.26ms +step:677/1680 train_time:60433ms step_avg:89.27ms +step:678/1680 train_time:60522ms step_avg:89.27ms +step:679/1680 train_time:60612ms step_avg:89.27ms +step:680/1680 train_time:60702ms step_avg:89.27ms +step:681/1680 train_time:60793ms step_avg:89.27ms +step:682/1680 train_time:60883ms step_avg:89.27ms +step:683/1680 train_time:60973ms step_avg:89.27ms +step:684/1680 train_time:61063ms step_avg:89.27ms +step:685/1680 train_time:61154ms step_avg:89.28ms +step:686/1680 train_time:61244ms step_avg:89.28ms +step:687/1680 train_time:61334ms step_avg:89.28ms +step:688/1680 train_time:61424ms step_avg:89.28ms +step:689/1680 train_time:61515ms step_avg:89.28ms +step:690/1680 train_time:61604ms step_avg:89.28ms +step:691/1680 train_time:61694ms step_avg:89.28ms +step:692/1680 train_time:61784ms step_avg:89.28ms +step:693/1680 train_time:61875ms step_avg:89.29ms +step:694/1680 train_time:61965ms step_avg:89.29ms +step:695/1680 train_time:62055ms step_avg:89.29ms +step:696/1680 train_time:62145ms step_avg:89.29ms +step:697/1680 train_time:62236ms step_avg:89.29ms +step:698/1680 train_time:62326ms step_avg:89.29ms +step:699/1680 train_time:62416ms step_avg:89.29ms +step:700/1680 train_time:62506ms step_avg:89.29ms +step:701/1680 train_time:62596ms step_avg:89.30ms +step:702/1680 train_time:62686ms 
step_avg:89.30ms +step:703/1680 train_time:62776ms step_avg:89.30ms +step:704/1680 train_time:62866ms step_avg:89.30ms +step:705/1680 train_time:62956ms step_avg:89.30ms +step:706/1680 train_time:63046ms step_avg:89.30ms +step:707/1680 train_time:63137ms step_avg:89.30ms +step:708/1680 train_time:63227ms step_avg:89.30ms +step:709/1680 train_time:63317ms step_avg:89.30ms +step:710/1680 train_time:63407ms step_avg:89.31ms +step:711/1680 train_time:63498ms step_avg:89.31ms +step:712/1680 train_time:63587ms step_avg:89.31ms +step:713/1680 train_time:63677ms step_avg:89.31ms +step:714/1680 train_time:63766ms step_avg:89.31ms +step:715/1680 train_time:63856ms step_avg:89.31ms +step:716/1680 train_time:63945ms step_avg:89.31ms +step:717/1680 train_time:64036ms step_avg:89.31ms +step:718/1680 train_time:64125ms step_avg:89.31ms +step:719/1680 train_time:64216ms step_avg:89.31ms +step:720/1680 train_time:64307ms step_avg:89.31ms +step:721/1680 train_time:64398ms step_avg:89.32ms +step:722/1680 train_time:64487ms step_avg:89.32ms +step:723/1680 train_time:64577ms step_avg:89.32ms +step:724/1680 train_time:64667ms step_avg:89.32ms +step:725/1680 train_time:64758ms step_avg:89.32ms +step:726/1680 train_time:64847ms step_avg:89.32ms +step:727/1680 train_time:64938ms step_avg:89.32ms +step:728/1680 train_time:65027ms step_avg:89.32ms +step:729/1680 train_time:65117ms step_avg:89.32ms +step:730/1680 train_time:65207ms step_avg:89.32ms +step:731/1680 train_time:65298ms step_avg:89.33ms +step:732/1680 train_time:65389ms step_avg:89.33ms +step:733/1680 train_time:65479ms step_avg:89.33ms +step:734/1680 train_time:65570ms step_avg:89.33ms +step:735/1680 train_time:65660ms step_avg:89.33ms +step:736/1680 train_time:65749ms step_avg:89.33ms +step:737/1680 train_time:65840ms step_avg:89.34ms +step:738/1680 train_time:65930ms step_avg:89.34ms +step:739/1680 train_time:66021ms step_avg:89.34ms +step:740/1680 train_time:66113ms step_avg:89.34ms +step:741/1680 train_time:66202ms step_avg:89.34ms +step:742/1680 train_time:66292ms step_avg:89.34ms +step:743/1680 train_time:66382ms step_avg:89.34ms +step:744/1680 train_time:66472ms step_avg:89.34ms +step:745/1680 train_time:66562ms step_avg:89.34ms +step:746/1680 train_time:66651ms step_avg:89.34ms +step:747/1680 train_time:66740ms step_avg:89.34ms +step:748/1680 train_time:66831ms step_avg:89.35ms +step:749/1680 train_time:66921ms step_avg:89.35ms +step:750/1680 train_time:67011ms step_avg:89.35ms +step:750/1680 val_loss:3.5641 train_time:67102ms step_avg:89.47ms +step:751/1680 train_time:67124ms step_avg:89.38ms +step:752/1680 train_time:67196ms step_avg:89.36ms +step:753/1680 train_time:67291ms step_avg:89.36ms +step:754/1680 train_time:67382ms step_avg:89.37ms +step:755/1680 train_time:67473ms step_avg:89.37ms +step:756/1680 train_time:67562ms step_avg:89.37ms +step:757/1680 train_time:67651ms step_avg:89.37ms +step:758/1680 train_time:67741ms step_avg:89.37ms +step:759/1680 train_time:67829ms step_avg:89.37ms +step:760/1680 train_time:67918ms step_avg:89.37ms +step:761/1680 train_time:68007ms step_avg:89.37ms +step:762/1680 train_time:68098ms step_avg:89.37ms +step:763/1680 train_time:68191ms step_avg:89.37ms +step:764/1680 train_time:68283ms step_avg:89.38ms +step:765/1680 train_time:68374ms step_avg:89.38ms +step:766/1680 train_time:68463ms step_avg:89.38ms +step:767/1680 train_time:68554ms step_avg:89.38ms +step:768/1680 train_time:68643ms step_avg:89.38ms +step:769/1680 train_time:68733ms step_avg:89.38ms +step:770/1680 train_time:68822ms step_avg:89.38ms 
+step:771/1680 train_time:68911ms step_avg:89.38ms +step:772/1680 train_time:69001ms step_avg:89.38ms +step:773/1680 train_time:69091ms step_avg:89.38ms +step:774/1680 train_time:69182ms step_avg:89.38ms +step:775/1680 train_time:69274ms step_avg:89.39ms +step:776/1680 train_time:69364ms step_avg:89.39ms +step:777/1680 train_time:69455ms step_avg:89.39ms +step:778/1680 train_time:69546ms step_avg:89.39ms +step:779/1680 train_time:69634ms step_avg:89.39ms +step:780/1680 train_time:69723ms step_avg:89.39ms +step:781/1680 train_time:69813ms step_avg:89.39ms +step:782/1680 train_time:69903ms step_avg:89.39ms +step:783/1680 train_time:69992ms step_avg:89.39ms +step:784/1680 train_time:70082ms step_avg:89.39ms +step:785/1680 train_time:70174ms step_avg:89.39ms +step:786/1680 train_time:70263ms step_avg:89.39ms +step:787/1680 train_time:70354ms step_avg:89.40ms +step:788/1680 train_time:70444ms step_avg:89.40ms +step:789/1680 train_time:70534ms step_avg:89.40ms +step:790/1680 train_time:70624ms step_avg:89.40ms +step:791/1680 train_time:70715ms step_avg:89.40ms +step:792/1680 train_time:70803ms step_avg:89.40ms +step:793/1680 train_time:70893ms step_avg:89.40ms +step:794/1680 train_time:70984ms step_avg:89.40ms +step:795/1680 train_time:71074ms step_avg:89.40ms +step:796/1680 train_time:71163ms step_avg:89.40ms +step:797/1680 train_time:71254ms step_avg:89.40ms +step:798/1680 train_time:71345ms step_avg:89.40ms +step:799/1680 train_time:71436ms step_avg:89.41ms +step:800/1680 train_time:71525ms step_avg:89.41ms +step:801/1680 train_time:71615ms step_avg:89.41ms +step:802/1680 train_time:71705ms step_avg:89.41ms +step:803/1680 train_time:71795ms step_avg:89.41ms +step:804/1680 train_time:71884ms step_avg:89.41ms +step:805/1680 train_time:71974ms step_avg:89.41ms +step:806/1680 train_time:72064ms step_avg:89.41ms +step:807/1680 train_time:72155ms step_avg:89.41ms +step:808/1680 train_time:72244ms step_avg:89.41ms +step:809/1680 train_time:72340ms step_avg:89.42ms +step:810/1680 train_time:72424ms step_avg:89.41ms +step:811/1680 train_time:72515ms step_avg:89.41ms +step:812/1680 train_time:72605ms step_avg:89.42ms +step:813/1680 train_time:72695ms step_avg:89.42ms +step:814/1680 train_time:72785ms step_avg:89.42ms +step:815/1680 train_time:72875ms step_avg:89.42ms +step:816/1680 train_time:72965ms step_avg:89.42ms +step:817/1680 train_time:73055ms step_avg:89.42ms +step:818/1680 train_time:73146ms step_avg:89.42ms +step:819/1680 train_time:73237ms step_avg:89.42ms +step:820/1680 train_time:73327ms step_avg:89.42ms +step:821/1680 train_time:73417ms step_avg:89.42ms +step:822/1680 train_time:73508ms step_avg:89.43ms +step:823/1680 train_time:73598ms step_avg:89.43ms +step:824/1680 train_time:73689ms step_avg:89.43ms +step:825/1680 train_time:73780ms step_avg:89.43ms +step:826/1680 train_time:73870ms step_avg:89.43ms +step:827/1680 train_time:73960ms step_avg:89.43ms +step:828/1680 train_time:74050ms step_avg:89.43ms +step:829/1680 train_time:74140ms step_avg:89.43ms +step:830/1680 train_time:74230ms step_avg:89.43ms +step:831/1680 train_time:74319ms step_avg:89.43ms +step:832/1680 train_time:74410ms step_avg:89.43ms +step:833/1680 train_time:74501ms step_avg:89.44ms +step:834/1680 train_time:74590ms step_avg:89.44ms +step:835/1680 train_time:74680ms step_avg:89.44ms +step:836/1680 train_time:74771ms step_avg:89.44ms +step:837/1680 train_time:74860ms step_avg:89.44ms +step:838/1680 train_time:74951ms step_avg:89.44ms +step:839/1680 train_time:75042ms step_avg:89.44ms +step:840/1680 train_time:75131ms 
step_avg:89.44ms +step:841/1680 train_time:75221ms step_avg:89.44ms +step:842/1680 train_time:75311ms step_avg:89.44ms +step:843/1680 train_time:75401ms step_avg:89.44ms +step:844/1680 train_time:75491ms step_avg:89.44ms +step:845/1680 train_time:75582ms step_avg:89.45ms +step:846/1680 train_time:75673ms step_avg:89.45ms +step:847/1680 train_time:75762ms step_avg:89.45ms +step:848/1680 train_time:75853ms step_avg:89.45ms +step:849/1680 train_time:75943ms step_avg:89.45ms +step:850/1680 train_time:76034ms step_avg:89.45ms +step:851/1680 train_time:76124ms step_avg:89.45ms +step:852/1680 train_time:76214ms step_avg:89.45ms +step:853/1680 train_time:76303ms step_avg:89.45ms +step:854/1680 train_time:76394ms step_avg:89.45ms +step:855/1680 train_time:76484ms step_avg:89.45ms +step:856/1680 train_time:76575ms step_avg:89.46ms +step:857/1680 train_time:76665ms step_avg:89.46ms +step:858/1680 train_time:76756ms step_avg:89.46ms +step:859/1680 train_time:76846ms step_avg:89.46ms +step:860/1680 train_time:76940ms step_avg:89.47ms +step:861/1680 train_time:77026ms step_avg:89.46ms +step:862/1680 train_time:77116ms step_avg:89.46ms +step:863/1680 train_time:77205ms step_avg:89.46ms +step:864/1680 train_time:77295ms step_avg:89.46ms +step:865/1680 train_time:77385ms step_avg:89.46ms +step:866/1680 train_time:77476ms step_avg:89.46ms +step:867/1680 train_time:77566ms step_avg:89.46ms +step:868/1680 train_time:77656ms step_avg:89.47ms +step:869/1680 train_time:77746ms step_avg:89.47ms +step:870/1680 train_time:77838ms step_avg:89.47ms +step:871/1680 train_time:77928ms step_avg:89.47ms +step:872/1680 train_time:78018ms step_avg:89.47ms +step:873/1680 train_time:78108ms step_avg:89.47ms +step:874/1680 train_time:78198ms step_avg:89.47ms +step:875/1680 train_time:78287ms step_avg:89.47ms +step:875/1680 val_loss:3.5174 train_time:78379ms step_avg:89.58ms +step:876/1680 train_time:78401ms step_avg:89.50ms +step:877/1680 train_time:78476ms step_avg:89.48ms +step:878/1680 train_time:78571ms step_avg:89.49ms +step:879/1680 train_time:78663ms step_avg:89.49ms +step:880/1680 train_time:78753ms step_avg:89.49ms +step:881/1680 train_time:78842ms step_avg:89.49ms +step:882/1680 train_time:78930ms step_avg:89.49ms +step:883/1680 train_time:79021ms step_avg:89.49ms +step:884/1680 train_time:79110ms step_avg:89.49ms +step:885/1680 train_time:79199ms step_avg:89.49ms +step:886/1680 train_time:79289ms step_avg:89.49ms +step:887/1680 train_time:79380ms step_avg:89.49ms +step:888/1680 train_time:79470ms step_avg:89.49ms +step:889/1680 train_time:79563ms step_avg:89.50ms +step:890/1680 train_time:79655ms step_avg:89.50ms +step:891/1680 train_time:79745ms step_avg:89.50ms +step:892/1680 train_time:79836ms step_avg:89.50ms +step:893/1680 train_time:79926ms step_avg:89.50ms +step:894/1680 train_time:80015ms step_avg:89.50ms +step:895/1680 train_time:80104ms step_avg:89.50ms +step:896/1680 train_time:80194ms step_avg:89.50ms +step:897/1680 train_time:80283ms step_avg:89.50ms +step:898/1680 train_time:80373ms step_avg:89.50ms +step:899/1680 train_time:80464ms step_avg:89.50ms +step:900/1680 train_time:80556ms step_avg:89.51ms +step:901/1680 train_time:80647ms step_avg:89.51ms +step:902/1680 train_time:80738ms step_avg:89.51ms +step:903/1680 train_time:80828ms step_avg:89.51ms +step:904/1680 train_time:80919ms step_avg:89.51ms +step:905/1680 train_time:81008ms step_avg:89.51ms +step:906/1680 train_time:81097ms step_avg:89.51ms +step:907/1680 train_time:81187ms step_avg:89.51ms +step:908/1680 train_time:81276ms step_avg:89.51ms 
+step:909/1680 train_time:81366ms step_avg:89.51ms +step:910/1680 train_time:81456ms step_avg:89.51ms +step:911/1680 train_time:81547ms step_avg:89.51ms +step:912/1680 train_time:81638ms step_avg:89.52ms +step:913/1680 train_time:81728ms step_avg:89.52ms +step:914/1680 train_time:81820ms step_avg:89.52ms +step:915/1680 train_time:81910ms step_avg:89.52ms +step:916/1680 train_time:82001ms step_avg:89.52ms +step:917/1680 train_time:82091ms step_avg:89.52ms +step:918/1680 train_time:82181ms step_avg:89.52ms +step:919/1680 train_time:82271ms step_avg:89.52ms +step:920/1680 train_time:82361ms step_avg:89.52ms +step:921/1680 train_time:82450ms step_avg:89.52ms +step:922/1680 train_time:82542ms step_avg:89.52ms +step:923/1680 train_time:82632ms step_avg:89.53ms +step:924/1680 train_time:82722ms step_avg:89.53ms +step:925/1680 train_time:82813ms step_avg:89.53ms +step:926/1680 train_time:82903ms step_avg:89.53ms +step:927/1680 train_time:82993ms step_avg:89.53ms +step:928/1680 train_time:83083ms step_avg:89.53ms +step:929/1680 train_time:83174ms step_avg:89.53ms +step:930/1680 train_time:83263ms step_avg:89.53ms +step:931/1680 train_time:83353ms step_avg:89.53ms +step:932/1680 train_time:83443ms step_avg:89.53ms +step:933/1680 train_time:83533ms step_avg:89.53ms +step:934/1680 train_time:83623ms step_avg:89.53ms +step:935/1680 train_time:83713ms step_avg:89.53ms +step:936/1680 train_time:83803ms step_avg:89.53ms +step:937/1680 train_time:83894ms step_avg:89.54ms +step:938/1680 train_time:83985ms step_avg:89.54ms +step:939/1680 train_time:84075ms step_avg:89.54ms +step:940/1680 train_time:84165ms step_avg:89.54ms +step:941/1680 train_time:84256ms step_avg:89.54ms +step:942/1680 train_time:84345ms step_avg:89.54ms +step:943/1680 train_time:84435ms step_avg:89.54ms +step:944/1680 train_time:84525ms step_avg:89.54ms +step:945/1680 train_time:84616ms step_avg:89.54ms +step:946/1680 train_time:84705ms step_avg:89.54ms +step:947/1680 train_time:84797ms step_avg:89.54ms +step:948/1680 train_time:84887ms step_avg:89.54ms +step:949/1680 train_time:84978ms step_avg:89.54ms +step:950/1680 train_time:85067ms step_avg:89.54ms +step:951/1680 train_time:85157ms step_avg:89.54ms +step:952/1680 train_time:85247ms step_avg:89.55ms +step:953/1680 train_time:85337ms step_avg:89.55ms +step:954/1680 train_time:85427ms step_avg:89.55ms +step:955/1680 train_time:85517ms step_avg:89.55ms +step:956/1680 train_time:85606ms step_avg:89.55ms +step:957/1680 train_time:85696ms step_avg:89.55ms +step:958/1680 train_time:85786ms step_avg:89.55ms +step:959/1680 train_time:85876ms step_avg:89.55ms +step:960/1680 train_time:85966ms step_avg:89.55ms +step:961/1680 train_time:86056ms step_avg:89.55ms +step:962/1680 train_time:86147ms step_avg:89.55ms +step:963/1680 train_time:86236ms step_avg:89.55ms +step:964/1680 train_time:86327ms step_avg:89.55ms +step:965/1680 train_time:86417ms step_avg:89.55ms +step:966/1680 train_time:86506ms step_avg:89.55ms +step:967/1680 train_time:86596ms step_avg:89.55ms +step:968/1680 train_time:86687ms step_avg:89.55ms +step:969/1680 train_time:86778ms step_avg:89.55ms +step:970/1680 train_time:86867ms step_avg:89.55ms +step:971/1680 train_time:86957ms step_avg:89.55ms +step:972/1680 train_time:87047ms step_avg:89.55ms +step:973/1680 train_time:87137ms step_avg:89.56ms +step:974/1680 train_time:87227ms step_avg:89.56ms +step:975/1680 train_time:87317ms step_avg:89.56ms +step:976/1680 train_time:87407ms step_avg:89.56ms +step:977/1680 train_time:87497ms step_avg:89.56ms +step:978/1680 train_time:87587ms 
step_avg:89.56ms +step:979/1680 train_time:87677ms step_avg:89.56ms +step:980/1680 train_time:87767ms step_avg:89.56ms +step:981/1680 train_time:87857ms step_avg:89.56ms +step:982/1680 train_time:87947ms step_avg:89.56ms +step:983/1680 train_time:88037ms step_avg:89.56ms +step:984/1680 train_time:88127ms step_avg:89.56ms +step:985/1680 train_time:88217ms step_avg:89.56ms +step:986/1680 train_time:88306ms step_avg:89.56ms +step:987/1680 train_time:88396ms step_avg:89.56ms +step:988/1680 train_time:88486ms step_avg:89.56ms +step:989/1680 train_time:88576ms step_avg:89.56ms +step:990/1680 train_time:88666ms step_avg:89.56ms +step:991/1680 train_time:88756ms step_avg:89.56ms +step:992/1680 train_time:88846ms step_avg:89.56ms +step:993/1680 train_time:88936ms step_avg:89.56ms +step:994/1680 train_time:89027ms step_avg:89.56ms +step:995/1680 train_time:89117ms step_avg:89.56ms +step:996/1680 train_time:89206ms step_avg:89.56ms +step:997/1680 train_time:89296ms step_avg:89.56ms +step:998/1680 train_time:89386ms step_avg:89.57ms +step:999/1680 train_time:89476ms step_avg:89.57ms +step:1000/1680 train_time:89566ms step_avg:89.57ms +step:1000/1680 val_loss:3.4674 train_time:89658ms step_avg:89.66ms +step:1001/1680 train_time:89680ms step_avg:89.59ms +step:1002/1680 train_time:89751ms step_avg:89.57ms +step:1003/1680 train_time:89845ms step_avg:89.58ms +step:1004/1680 train_time:89937ms step_avg:89.58ms +step:1005/1680 train_time:90027ms step_avg:89.58ms +step:1006/1680 train_time:90115ms step_avg:89.58ms +step:1007/1680 train_time:90204ms step_avg:89.58ms +step:1008/1680 train_time:90293ms step_avg:89.58ms +step:1009/1680 train_time:90381ms step_avg:89.58ms +step:1010/1680 train_time:90471ms step_avg:89.58ms +step:1011/1680 train_time:90560ms step_avg:89.58ms +step:1012/1680 train_time:90651ms step_avg:89.58ms +step:1013/1680 train_time:90743ms step_avg:89.58ms +step:1014/1680 train_time:90836ms step_avg:89.58ms +step:1015/1680 train_time:90928ms step_avg:89.58ms +step:1016/1680 train_time:91018ms step_avg:89.58ms +step:1017/1680 train_time:91107ms step_avg:89.58ms +step:1018/1680 train_time:91196ms step_avg:89.58ms +step:1019/1680 train_time:91286ms step_avg:89.58ms +step:1020/1680 train_time:91375ms step_avg:89.58ms +step:1021/1680 train_time:91464ms step_avg:89.58ms +step:1022/1680 train_time:91553ms step_avg:89.58ms +step:1023/1680 train_time:91644ms step_avg:89.58ms +step:1024/1680 train_time:91734ms step_avg:89.58ms +step:1025/1680 train_time:91825ms step_avg:89.59ms +step:1026/1680 train_time:91915ms step_avg:89.59ms +step:1027/1680 train_time:92006ms step_avg:89.59ms +step:1028/1680 train_time:92095ms step_avg:89.59ms +step:1029/1680 train_time:92185ms step_avg:89.59ms +step:1030/1680 train_time:92275ms step_avg:89.59ms +step:1031/1680 train_time:92365ms step_avg:89.59ms +step:1032/1680 train_time:92454ms step_avg:89.59ms +step:1033/1680 train_time:92544ms step_avg:89.59ms +step:1034/1680 train_time:92634ms step_avg:89.59ms +step:1035/1680 train_time:92725ms step_avg:89.59ms +step:1036/1680 train_time:92815ms step_avg:89.59ms +step:1037/1680 train_time:92906ms step_avg:89.59ms +step:1038/1680 train_time:92996ms step_avg:89.59ms +step:1039/1680 train_time:93087ms step_avg:89.59ms +step:1040/1680 train_time:93177ms step_avg:89.59ms +step:1041/1680 train_time:93267ms step_avg:89.59ms +step:1042/1680 train_time:93357ms step_avg:89.59ms +step:1043/1680 train_time:93447ms step_avg:89.59ms +step:1044/1680 train_time:93537ms step_avg:89.59ms +step:1045/1680 train_time:93627ms step_avg:89.60ms 
+step:1046/1680 train_time:93718ms step_avg:89.60ms +step:1047/1680 train_time:93809ms step_avg:89.60ms +step:1048/1680 train_time:93899ms step_avg:89.60ms +step:1049/1680 train_time:93989ms step_avg:89.60ms +step:1050/1680 train_time:94079ms step_avg:89.60ms +step:1051/1680 train_time:94169ms step_avg:89.60ms +step:1052/1680 train_time:94260ms step_avg:89.60ms +step:1053/1680 train_time:94350ms step_avg:89.60ms +step:1054/1680 train_time:94440ms step_avg:89.60ms +step:1055/1680 train_time:94530ms step_avg:89.60ms +step:1056/1680 train_time:94620ms step_avg:89.60ms +step:1057/1680 train_time:94710ms step_avg:89.60ms +step:1058/1680 train_time:94801ms step_avg:89.60ms +step:1059/1680 train_time:94891ms step_avg:89.60ms +step:1060/1680 train_time:94980ms step_avg:89.60ms +step:1061/1680 train_time:95072ms step_avg:89.61ms +step:1062/1680 train_time:95162ms step_avg:89.61ms +step:1063/1680 train_time:95252ms step_avg:89.61ms +step:1064/1680 train_time:95342ms step_avg:89.61ms +step:1065/1680 train_time:95432ms step_avg:89.61ms +step:1066/1680 train_time:95522ms step_avg:89.61ms +step:1067/1680 train_time:95613ms step_avg:89.61ms +step:1068/1680 train_time:95704ms step_avg:89.61ms +step:1069/1680 train_time:95794ms step_avg:89.61ms +step:1070/1680 train_time:95884ms step_avg:89.61ms +step:1071/1680 train_time:95974ms step_avg:89.61ms +step:1072/1680 train_time:96064ms step_avg:89.61ms +step:1073/1680 train_time:96154ms step_avg:89.61ms +step:1074/1680 train_time:96245ms step_avg:89.61ms +step:1075/1680 train_time:96335ms step_avg:89.61ms +step:1076/1680 train_time:96425ms step_avg:89.61ms +step:1077/1680 train_time:96515ms step_avg:89.61ms +step:1078/1680 train_time:96605ms step_avg:89.61ms +step:1079/1680 train_time:96694ms step_avg:89.61ms +step:1080/1680 train_time:96784ms step_avg:89.61ms +step:1081/1680 train_time:96873ms step_avg:89.61ms +step:1082/1680 train_time:96964ms step_avg:89.62ms +step:1083/1680 train_time:97053ms step_avg:89.62ms +step:1084/1680 train_time:97144ms step_avg:89.62ms +step:1085/1680 train_time:97238ms step_avg:89.62ms +step:1086/1680 train_time:97323ms step_avg:89.62ms +step:1087/1680 train_time:97412ms step_avg:89.62ms +step:1088/1680 train_time:97503ms step_avg:89.62ms +step:1089/1680 train_time:97592ms step_avg:89.62ms +step:1090/1680 train_time:97682ms step_avg:89.62ms +step:1091/1680 train_time:97772ms step_avg:89.62ms +step:1092/1680 train_time:97862ms step_avg:89.62ms +step:1093/1680 train_time:97952ms step_avg:89.62ms +step:1094/1680 train_time:98042ms step_avg:89.62ms +step:1095/1680 train_time:98133ms step_avg:89.62ms +step:1096/1680 train_time:98225ms step_avg:89.62ms +step:1097/1680 train_time:98315ms step_avg:89.62ms +step:1098/1680 train_time:98406ms step_avg:89.62ms +step:1099/1680 train_time:98497ms step_avg:89.62ms +step:1100/1680 train_time:98588ms step_avg:89.63ms +step:1101/1680 train_time:98680ms step_avg:89.63ms +step:1102/1680 train_time:98770ms step_avg:89.63ms +step:1103/1680 train_time:98861ms step_avg:89.63ms +step:1104/1680 train_time:98952ms step_avg:89.63ms +step:1105/1680 train_time:99043ms step_avg:89.63ms +step:1106/1680 train_time:99133ms step_avg:89.63ms +step:1107/1680 train_time:99223ms step_avg:89.63ms +step:1108/1680 train_time:99315ms step_avg:89.63ms +step:1109/1680 train_time:99406ms step_avg:89.64ms +step:1110/1680 train_time:99497ms step_avg:89.64ms +step:1111/1680 train_time:99588ms step_avg:89.64ms +step:1112/1680 train_time:99678ms step_avg:89.64ms +step:1113/1680 train_time:99770ms step_avg:89.64ms +step:1114/1680 
train_time:99860ms step_avg:89.64ms +step:1115/1680 train_time:99951ms step_avg:89.64ms +step:1116/1680 train_time:100042ms step_avg:89.64ms +step:1117/1680 train_time:100137ms step_avg:89.65ms +step:1118/1680 train_time:100223ms step_avg:89.64ms +step:1119/1680 train_time:100314ms step_avg:89.65ms +step:1120/1680 train_time:100408ms step_avg:89.65ms +step:1121/1680 train_time:100495ms step_avg:89.65ms +step:1122/1680 train_time:100586ms step_avg:89.65ms +step:1123/1680 train_time:100676ms step_avg:89.65ms +step:1124/1680 train_time:100767ms step_avg:89.65ms +step:1125/1680 train_time:100858ms step_avg:89.65ms +step:1125/1680 val_loss:3.4146 train_time:100951ms step_avg:89.73ms +step:1126/1680 train_time:100972ms step_avg:89.67ms +step:1127/1680 train_time:101047ms step_avg:89.66ms +step:1128/1680 train_time:101147ms step_avg:89.67ms +step:1129/1680 train_time:101239ms step_avg:89.67ms +step:1130/1680 train_time:101329ms step_avg:89.67ms +step:1131/1680 train_time:101418ms step_avg:89.67ms +step:1132/1680 train_time:101508ms step_avg:89.67ms +step:1133/1680 train_time:101598ms step_avg:89.67ms +step:1134/1680 train_time:101688ms step_avg:89.67ms +step:1135/1680 train_time:101777ms step_avg:89.67ms +step:1136/1680 train_time:101868ms step_avg:89.67ms +step:1137/1680 train_time:101959ms step_avg:89.67ms +step:1138/1680 train_time:102053ms step_avg:89.68ms +step:1139/1680 train_time:102148ms step_avg:89.68ms +step:1140/1680 train_time:102239ms step_avg:89.68ms +step:1141/1680 train_time:102330ms step_avg:89.68ms +step:1142/1680 train_time:102421ms step_avg:89.69ms +step:1143/1680 train_time:102511ms step_avg:89.69ms +step:1144/1680 train_time:102600ms step_avg:89.69ms +step:1145/1680 train_time:102690ms step_avg:89.69ms +step:1146/1680 train_time:102780ms step_avg:89.69ms +step:1147/1680 train_time:102871ms step_avg:89.69ms +step:1148/1680 train_time:102962ms step_avg:89.69ms +step:1149/1680 train_time:103054ms step_avg:89.69ms +step:1150/1680 train_time:103147ms step_avg:89.69ms +step:1151/1680 train_time:103238ms step_avg:89.69ms +step:1152/1680 train_time:103331ms step_avg:89.70ms +step:1153/1680 train_time:103422ms step_avg:89.70ms +step:1154/1680 train_time:103513ms step_avg:89.70ms +step:1155/1680 train_time:103602ms step_avg:89.70ms +step:1156/1680 train_time:103692ms step_avg:89.70ms +step:1157/1680 train_time:103782ms step_avg:89.70ms +step:1158/1680 train_time:103873ms step_avg:89.70ms +step:1159/1680 train_time:103964ms step_avg:89.70ms +step:1160/1680 train_time:104055ms step_avg:89.70ms +step:1161/1680 train_time:104146ms step_avg:89.70ms +step:1162/1680 train_time:104238ms step_avg:89.71ms +step:1163/1680 train_time:104329ms step_avg:89.71ms +step:1164/1680 train_time:104420ms step_avg:89.71ms +step:1165/1680 train_time:104512ms step_avg:89.71ms +step:1166/1680 train_time:104602ms step_avg:89.71ms +step:1167/1680 train_time:104692ms step_avg:89.71ms +step:1168/1680 train_time:104782ms step_avg:89.71ms +step:1169/1680 train_time:104873ms step_avg:89.71ms +step:1170/1680 train_time:104964ms step_avg:89.71ms +step:1171/1680 train_time:105055ms step_avg:89.71ms +step:1172/1680 train_time:105146ms step_avg:89.71ms +step:1173/1680 train_time:105237ms step_avg:89.72ms +step:1174/1680 train_time:105329ms step_avg:89.72ms +step:1175/1680 train_time:105420ms step_avg:89.72ms +step:1176/1680 train_time:105512ms step_avg:89.72ms +step:1177/1680 train_time:105603ms step_avg:89.72ms +step:1178/1680 train_time:105694ms step_avg:89.72ms +step:1179/1680 train_time:105784ms step_avg:89.72ms 
+step:1180/1680 train_time:105874ms step_avg:89.72ms +step:1181/1680 train_time:105966ms step_avg:89.73ms +step:1182/1680 train_time:106056ms step_avg:89.73ms +step:1183/1680 train_time:106148ms step_avg:89.73ms +step:1184/1680 train_time:106239ms step_avg:89.73ms +step:1185/1680 train_time:106330ms step_avg:89.73ms +step:1186/1680 train_time:106422ms step_avg:89.73ms +step:1187/1680 train_time:106513ms step_avg:89.73ms +step:1188/1680 train_time:106603ms step_avg:89.73ms +step:1189/1680 train_time:106694ms step_avg:89.73ms +step:1190/1680 train_time:106784ms step_avg:89.73ms +step:1191/1680 train_time:106874ms step_avg:89.73ms +step:1192/1680 train_time:106965ms step_avg:89.74ms +step:1193/1680 train_time:107055ms step_avg:89.74ms +step:1194/1680 train_time:107146ms step_avg:89.74ms +step:1195/1680 train_time:107237ms step_avg:89.74ms +step:1196/1680 train_time:107328ms step_avg:89.74ms +step:1197/1680 train_time:107418ms step_avg:89.74ms +step:1198/1680 train_time:107510ms step_avg:89.74ms +step:1199/1680 train_time:107601ms step_avg:89.74ms +step:1200/1680 train_time:107691ms step_avg:89.74ms +step:1201/1680 train_time:107781ms step_avg:89.74ms +step:1202/1680 train_time:107872ms step_avg:89.74ms +step:1203/1680 train_time:107962ms step_avg:89.74ms +step:1204/1680 train_time:108053ms step_avg:89.75ms +step:1205/1680 train_time:108144ms step_avg:89.75ms +step:1206/1680 train_time:108235ms step_avg:89.75ms +step:1207/1680 train_time:108327ms step_avg:89.75ms +step:1208/1680 train_time:108418ms step_avg:89.75ms +step:1209/1680 train_time:108510ms step_avg:89.75ms +step:1210/1680 train_time:108602ms step_avg:89.75ms +step:1211/1680 train_time:108693ms step_avg:89.75ms +step:1212/1680 train_time:108784ms step_avg:89.76ms +step:1213/1680 train_time:108874ms step_avg:89.76ms +step:1214/1680 train_time:108965ms step_avg:89.76ms +step:1215/1680 train_time:109055ms step_avg:89.76ms +step:1216/1680 train_time:109146ms step_avg:89.76ms +step:1217/1680 train_time:109237ms step_avg:89.76ms +step:1218/1680 train_time:109329ms step_avg:89.76ms +step:1219/1680 train_time:109420ms step_avg:89.76ms +step:1220/1680 train_time:109512ms step_avg:89.76ms +step:1221/1680 train_time:109603ms step_avg:89.76ms +step:1222/1680 train_time:109693ms step_avg:89.77ms +step:1223/1680 train_time:109784ms step_avg:89.77ms +step:1224/1680 train_time:109874ms step_avg:89.77ms +step:1225/1680 train_time:109965ms step_avg:89.77ms +step:1226/1680 train_time:110055ms step_avg:89.77ms +step:1227/1680 train_time:110146ms step_avg:89.77ms +step:1228/1680 train_time:110237ms step_avg:89.77ms +step:1229/1680 train_time:110327ms step_avg:89.77ms +step:1230/1680 train_time:110417ms step_avg:89.77ms +step:1231/1680 train_time:110509ms step_avg:89.77ms +step:1232/1680 train_time:110600ms step_avg:89.77ms +step:1233/1680 train_time:110691ms step_avg:89.77ms +step:1234/1680 train_time:110782ms step_avg:89.77ms +step:1235/1680 train_time:110872ms step_avg:89.78ms +step:1236/1680 train_time:110963ms step_avg:89.78ms +step:1237/1680 train_time:111053ms step_avg:89.78ms +step:1238/1680 train_time:111143ms step_avg:89.78ms +step:1239/1680 train_time:111234ms step_avg:89.78ms +step:1240/1680 train_time:111325ms step_avg:89.78ms +step:1241/1680 train_time:111415ms step_avg:89.78ms +step:1242/1680 train_time:111506ms step_avg:89.78ms +step:1243/1680 train_time:111598ms step_avg:89.78ms +step:1244/1680 train_time:111689ms step_avg:89.78ms +step:1245/1680 train_time:111779ms step_avg:89.78ms +step:1246/1680 train_time:111870ms step_avg:89.78ms 
+step:1247/1680 train_time:111962ms step_avg:89.78ms +step:1248/1680 train_time:112052ms step_avg:89.79ms +step:1249/1680 train_time:112143ms step_avg:89.79ms +step:1250/1680 train_time:112234ms step_avg:89.79ms +step:1250/1680 val_loss:3.3761 train_time:112326ms step_avg:89.86ms +step:1251/1680 train_time:112348ms step_avg:89.81ms +step:1252/1680 train_time:112421ms step_avg:89.79ms +step:1253/1680 train_time:112518ms step_avg:89.80ms +step:1254/1680 train_time:112609ms step_avg:89.80ms +step:1255/1680 train_time:112698ms step_avg:89.80ms +step:1256/1680 train_time:112788ms step_avg:89.80ms +step:1257/1680 train_time:112878ms step_avg:89.80ms +step:1258/1680 train_time:112967ms step_avg:89.80ms +step:1259/1680 train_time:113057ms step_avg:89.80ms +step:1260/1680 train_time:113147ms step_avg:89.80ms +step:1261/1680 train_time:113237ms step_avg:89.80ms +step:1262/1680 train_time:113330ms step_avg:89.80ms +step:1263/1680 train_time:113422ms step_avg:89.80ms +step:1264/1680 train_time:113515ms step_avg:89.81ms +step:1265/1680 train_time:113606ms step_avg:89.81ms +step:1266/1680 train_time:113696ms step_avg:89.81ms +step:1267/1680 train_time:113786ms step_avg:89.81ms +step:1268/1680 train_time:113876ms step_avg:89.81ms +step:1269/1680 train_time:113966ms step_avg:89.81ms +step:1270/1680 train_time:114056ms step_avg:89.81ms +step:1271/1680 train_time:114151ms step_avg:89.81ms +step:1272/1680 train_time:114237ms step_avg:89.81ms +step:1273/1680 train_time:114329ms step_avg:89.81ms +step:1274/1680 train_time:114422ms step_avg:89.81ms +step:1275/1680 train_time:114513ms step_avg:89.81ms +step:1276/1680 train_time:114605ms step_avg:89.82ms +step:1277/1680 train_time:114696ms step_avg:89.82ms +step:1278/1680 train_time:114786ms step_avg:89.82ms +step:1279/1680 train_time:114877ms step_avg:89.82ms +step:1280/1680 train_time:114967ms step_avg:89.82ms +step:1281/1680 train_time:115057ms step_avg:89.82ms +step:1282/1680 train_time:115148ms step_avg:89.82ms +step:1283/1680 train_time:115239ms step_avg:89.82ms +step:1284/1680 train_time:115330ms step_avg:89.82ms +step:1285/1680 train_time:115422ms step_avg:89.82ms +step:1286/1680 train_time:115514ms step_avg:89.82ms +step:1287/1680 train_time:115605ms step_avg:89.83ms +step:1288/1680 train_time:115697ms step_avg:89.83ms +step:1289/1680 train_time:115788ms step_avg:89.83ms +step:1290/1680 train_time:115879ms step_avg:89.83ms +step:1291/1680 train_time:115969ms step_avg:89.83ms +step:1292/1680 train_time:116059ms step_avg:89.83ms +step:1293/1680 train_time:116150ms step_avg:89.83ms +step:1294/1680 train_time:116239ms step_avg:89.83ms +step:1295/1680 train_time:116330ms step_avg:89.83ms +step:1296/1680 train_time:116421ms step_avg:89.83ms +step:1297/1680 train_time:116512ms step_avg:89.83ms +step:1298/1680 train_time:116603ms step_avg:89.83ms +step:1299/1680 train_time:116694ms step_avg:89.83ms +step:1300/1680 train_time:116784ms step_avg:89.83ms +step:1301/1680 train_time:116876ms step_avg:89.84ms +step:1302/1680 train_time:116967ms step_avg:89.84ms +step:1303/1680 train_time:117059ms step_avg:89.84ms +step:1304/1680 train_time:117150ms step_avg:89.84ms +step:1305/1680 train_time:117240ms step_avg:89.84ms +step:1306/1680 train_time:117330ms step_avg:89.84ms +step:1307/1680 train_time:117420ms step_avg:89.84ms +step:1308/1680 train_time:117512ms step_avg:89.84ms +step:1309/1680 train_time:117603ms step_avg:89.84ms +step:1310/1680 train_time:117694ms step_avg:89.84ms +step:1311/1680 train_time:117786ms step_avg:89.84ms +step:1312/1680 train_time:117877ms 
step_avg:89.85ms +step:1313/1680 train_time:117968ms step_avg:89.85ms +step:1314/1680 train_time:118060ms step_avg:89.85ms +step:1315/1680 train_time:118152ms step_avg:89.85ms +step:1316/1680 train_time:118243ms step_avg:89.85ms +step:1317/1680 train_time:118333ms step_avg:89.85ms +step:1318/1680 train_time:118423ms step_avg:89.85ms +step:1319/1680 train_time:118515ms step_avg:89.85ms +step:1320/1680 train_time:118605ms step_avg:89.85ms +step:1321/1680 train_time:118697ms step_avg:89.85ms +step:1322/1680 train_time:118788ms step_avg:89.85ms +step:1323/1680 train_time:118879ms step_avg:89.86ms +step:1324/1680 train_time:118969ms step_avg:89.86ms +step:1325/1680 train_time:119061ms step_avg:89.86ms +step:1326/1680 train_time:119152ms step_avg:89.86ms +step:1327/1680 train_time:119242ms step_avg:89.86ms +step:1328/1680 train_time:119333ms step_avg:89.86ms +step:1329/1680 train_time:119424ms step_avg:89.86ms +step:1330/1680 train_time:119515ms step_avg:89.86ms +step:1331/1680 train_time:119605ms step_avg:89.86ms +step:1332/1680 train_time:119696ms step_avg:89.86ms +step:1333/1680 train_time:119786ms step_avg:89.86ms +step:1334/1680 train_time:119877ms step_avg:89.86ms +step:1335/1680 train_time:119969ms step_avg:89.86ms +step:1336/1680 train_time:120059ms step_avg:89.86ms +step:1337/1680 train_time:120151ms step_avg:89.87ms +step:1338/1680 train_time:120241ms step_avg:89.87ms +step:1339/1680 train_time:120332ms step_avg:89.87ms +step:1340/1680 train_time:120422ms step_avg:89.87ms +step:1341/1680 train_time:120513ms step_avg:89.87ms +step:1342/1680 train_time:120603ms step_avg:89.87ms +step:1343/1680 train_time:120695ms step_avg:89.87ms +step:1344/1680 train_time:120785ms step_avg:89.87ms +step:1345/1680 train_time:120876ms step_avg:89.87ms +step:1346/1680 train_time:120966ms step_avg:89.87ms +step:1347/1680 train_time:121058ms step_avg:89.87ms +step:1348/1680 train_time:121150ms step_avg:89.87ms +step:1349/1680 train_time:121241ms step_avg:89.87ms +step:1350/1680 train_time:121331ms step_avg:89.88ms +step:1351/1680 train_time:121423ms step_avg:89.88ms +step:1352/1680 train_time:121514ms step_avg:89.88ms +step:1353/1680 train_time:121606ms step_avg:89.88ms +step:1354/1680 train_time:121696ms step_avg:89.88ms +step:1355/1680 train_time:121786ms step_avg:89.88ms +step:1356/1680 train_time:121878ms step_avg:89.88ms +step:1357/1680 train_time:121968ms step_avg:89.88ms +step:1358/1680 train_time:122059ms step_avg:89.88ms +step:1359/1680 train_time:122150ms step_avg:89.88ms +step:1360/1680 train_time:122240ms step_avg:89.88ms +step:1361/1680 train_time:122331ms step_avg:89.88ms +step:1362/1680 train_time:122421ms step_avg:89.88ms +step:1363/1680 train_time:122513ms step_avg:89.88ms +step:1364/1680 train_time:122604ms step_avg:89.89ms +step:1365/1680 train_time:122695ms step_avg:89.89ms +step:1366/1680 train_time:122786ms step_avg:89.89ms +step:1367/1680 train_time:122877ms step_avg:89.89ms +step:1368/1680 train_time:122968ms step_avg:89.89ms +step:1369/1680 train_time:123060ms step_avg:89.89ms +step:1370/1680 train_time:123151ms step_avg:89.89ms +step:1371/1680 train_time:123241ms step_avg:89.89ms +step:1372/1680 train_time:123334ms step_avg:89.89ms +step:1373/1680 train_time:123424ms step_avg:89.89ms +step:1374/1680 train_time:123516ms step_avg:89.90ms +step:1375/1680 train_time:123606ms step_avg:89.90ms +step:1375/1680 val_loss:3.3420 train_time:123699ms step_avg:89.96ms +step:1376/1680 train_time:123721ms step_avg:89.91ms +step:1377/1680 train_time:123793ms step_avg:89.90ms +step:1378/1680 
train_time:123889ms step_avg:89.90ms +step:1379/1680 train_time:123982ms step_avg:89.91ms +step:1380/1680 train_time:124073ms step_avg:89.91ms +step:1381/1680 train_time:124163ms step_avg:89.91ms +step:1382/1680 train_time:124254ms step_avg:89.91ms +step:1383/1680 train_time:124343ms step_avg:89.91ms +step:1384/1680 train_time:124433ms step_avg:89.91ms +step:1385/1680 train_time:124522ms step_avg:89.91ms +step:1386/1680 train_time:124612ms step_avg:89.91ms +step:1387/1680 train_time:124703ms step_avg:89.91ms +step:1388/1680 train_time:124797ms step_avg:89.91ms +step:1389/1680 train_time:124890ms step_avg:89.91ms +step:1390/1680 train_time:124983ms step_avg:89.92ms +step:1391/1680 train_time:125075ms step_avg:89.92ms +step:1392/1680 train_time:125165ms step_avg:89.92ms +step:1393/1680 train_time:125256ms step_avg:89.92ms +step:1394/1680 train_time:125346ms step_avg:89.92ms +step:1395/1680 train_time:125436ms step_avg:89.92ms +step:1396/1680 train_time:125525ms step_avg:89.92ms +step:1397/1680 train_time:125615ms step_avg:89.92ms +step:1398/1680 train_time:125706ms step_avg:89.92ms +step:1399/1680 train_time:125798ms step_avg:89.92ms +step:1400/1680 train_time:125890ms step_avg:89.92ms +step:1401/1680 train_time:125984ms step_avg:89.92ms +step:1402/1680 train_time:126075ms step_avg:89.93ms +step:1403/1680 train_time:126166ms step_avg:89.93ms +step:1404/1680 train_time:126257ms step_avg:89.93ms +step:1405/1680 train_time:126347ms step_avg:89.93ms +step:1406/1680 train_time:126436ms step_avg:89.93ms +step:1407/1680 train_time:126526ms step_avg:89.93ms +step:1408/1680 train_time:126617ms step_avg:89.93ms +step:1409/1680 train_time:126708ms step_avg:89.93ms +step:1410/1680 train_time:126800ms step_avg:89.93ms +step:1411/1680 train_time:126892ms step_avg:89.93ms +step:1412/1680 train_time:126983ms step_avg:89.93ms +step:1413/1680 train_time:127074ms step_avg:89.93ms +step:1414/1680 train_time:127166ms step_avg:89.93ms +step:1415/1680 train_time:127256ms step_avg:89.93ms +step:1416/1680 train_time:127346ms step_avg:89.93ms +step:1417/1680 train_time:127436ms step_avg:89.93ms +step:1418/1680 train_time:127527ms step_avg:89.93ms +step:1419/1680 train_time:127617ms step_avg:89.93ms +step:1420/1680 train_time:127709ms step_avg:89.94ms +step:1421/1680 train_time:127800ms step_avg:89.94ms +step:1422/1680 train_time:127898ms step_avg:89.94ms +step:1423/1680 train_time:127984ms step_avg:89.94ms +step:1424/1680 train_time:128074ms step_avg:89.94ms +step:1425/1680 train_time:128165ms step_avg:89.94ms +step:1426/1680 train_time:128256ms step_avg:89.94ms +step:1427/1680 train_time:128346ms step_avg:89.94ms +step:1428/1680 train_time:128436ms step_avg:89.94ms +step:1429/1680 train_time:128527ms step_avg:89.94ms +step:1430/1680 train_time:128619ms step_avg:89.94ms +step:1431/1680 train_time:128708ms step_avg:89.94ms +step:1432/1680 train_time:128800ms step_avg:89.94ms +step:1433/1680 train_time:128891ms step_avg:89.94ms +step:1434/1680 train_time:128983ms step_avg:89.95ms +step:1435/1680 train_time:129073ms step_avg:89.95ms +step:1436/1680 train_time:129164ms step_avg:89.95ms +step:1437/1680 train_time:129254ms step_avg:89.95ms +step:1438/1680 train_time:129344ms step_avg:89.95ms +step:1439/1680 train_time:129435ms step_avg:89.95ms +step:1440/1680 train_time:129526ms step_avg:89.95ms +step:1441/1680 train_time:129617ms step_avg:89.95ms +step:1442/1680 train_time:129708ms step_avg:89.95ms +step:1443/1680 train_time:129799ms step_avg:89.95ms +step:1444/1680 train_time:129891ms step_avg:89.95ms +step:1445/1680 
train_time:129984ms step_avg:89.95ms +step:1446/1680 train_time:130075ms step_avg:89.96ms +step:1447/1680 train_time:130165ms step_avg:89.95ms +step:1448/1680 train_time:130255ms step_avg:89.96ms +step:1449/1680 train_time:130346ms step_avg:89.96ms +step:1450/1680 train_time:130437ms step_avg:89.96ms +step:1451/1680 train_time:130527ms step_avg:89.96ms +step:1452/1680 train_time:130619ms step_avg:89.96ms +step:1453/1680 train_time:130709ms step_avg:89.96ms +step:1454/1680 train_time:130800ms step_avg:89.96ms +step:1455/1680 train_time:130891ms step_avg:89.96ms +step:1456/1680 train_time:130982ms step_avg:89.96ms +step:1457/1680 train_time:131073ms step_avg:89.96ms +step:1458/1680 train_time:131163ms step_avg:89.96ms +step:1459/1680 train_time:131254ms step_avg:89.96ms +step:1460/1680 train_time:131345ms step_avg:89.96ms +step:1461/1680 train_time:131435ms step_avg:89.96ms +step:1462/1680 train_time:131526ms step_avg:89.96ms +step:1463/1680 train_time:131618ms step_avg:89.96ms +step:1464/1680 train_time:131708ms step_avg:89.96ms +step:1465/1680 train_time:131801ms step_avg:89.97ms +step:1466/1680 train_time:131894ms step_avg:89.97ms +step:1467/1680 train_time:131984ms step_avg:89.97ms +step:1468/1680 train_time:132076ms step_avg:89.97ms +step:1469/1680 train_time:132166ms step_avg:89.97ms +step:1470/1680 train_time:132257ms step_avg:89.97ms +step:1471/1680 train_time:132346ms step_avg:89.97ms +step:1472/1680 train_time:132437ms step_avg:89.97ms +step:1473/1680 train_time:132529ms step_avg:89.97ms +step:1474/1680 train_time:132620ms step_avg:89.97ms +step:1475/1680 train_time:132710ms step_avg:89.97ms +step:1476/1680 train_time:132801ms step_avg:89.97ms +step:1477/1680 train_time:132893ms step_avg:89.97ms +step:1478/1680 train_time:132984ms step_avg:89.98ms +step:1479/1680 train_time:133075ms step_avg:89.98ms +step:1480/1680 train_time:133166ms step_avg:89.98ms +step:1481/1680 train_time:133258ms step_avg:89.98ms +step:1482/1680 train_time:133348ms step_avg:89.98ms +step:1483/1680 train_time:133438ms step_avg:89.98ms +step:1484/1680 train_time:133528ms step_avg:89.98ms +step:1485/1680 train_time:133619ms step_avg:89.98ms +step:1486/1680 train_time:133709ms step_avg:89.98ms +step:1487/1680 train_time:133801ms step_avg:89.98ms +step:1488/1680 train_time:133892ms step_avg:89.98ms +step:1489/1680 train_time:133983ms step_avg:89.98ms +step:1490/1680 train_time:134074ms step_avg:89.98ms +step:1491/1680 train_time:134164ms step_avg:89.98ms +step:1492/1680 train_time:134255ms step_avg:89.98ms +step:1493/1680 train_time:134345ms step_avg:89.98ms +step:1494/1680 train_time:134436ms step_avg:89.98ms +step:1495/1680 train_time:134526ms step_avg:89.98ms +step:1496/1680 train_time:134617ms step_avg:89.98ms +step:1497/1680 train_time:134707ms step_avg:89.98ms +step:1498/1680 train_time:134799ms step_avg:89.99ms +step:1499/1680 train_time:134891ms step_avg:89.99ms +step:1500/1680 train_time:134981ms step_avg:89.99ms +step:1500/1680 val_loss:3.3121 train_time:135073ms step_avg:90.05ms +step:1501/1680 train_time:135095ms step_avg:90.00ms +step:1502/1680 train_time:135166ms step_avg:89.99ms +step:1503/1680 train_time:135261ms step_avg:89.99ms +step:1504/1680 train_time:135351ms step_avg:89.99ms +step:1505/1680 train_time:135442ms step_avg:89.99ms +step:1506/1680 train_time:135532ms step_avg:89.99ms +step:1507/1680 train_time:135622ms step_avg:89.99ms +step:1508/1680 train_time:135712ms step_avg:89.99ms +step:1509/1680 train_time:135803ms step_avg:90.00ms +step:1510/1680 train_time:135893ms step_avg:90.00ms 
+step:1511/1680 train_time:135983ms step_avg:90.00ms +step:1512/1680 train_time:136076ms step_avg:90.00ms +step:1513/1680 train_time:136168ms step_avg:90.00ms +step:1514/1680 train_time:136261ms step_avg:90.00ms +step:1515/1680 train_time:136354ms step_avg:90.00ms +step:1516/1680 train_time:136444ms step_avg:90.00ms +step:1517/1680 train_time:136534ms step_avg:90.00ms +step:1518/1680 train_time:136623ms step_avg:90.00ms +step:1519/1680 train_time:136714ms step_avg:90.00ms +step:1520/1680 train_time:136804ms step_avg:90.00ms +step:1521/1680 train_time:136894ms step_avg:90.00ms +step:1522/1680 train_time:136986ms step_avg:90.00ms +step:1523/1680 train_time:137077ms step_avg:90.00ms +step:1524/1680 train_time:137169ms step_avg:90.01ms +step:1525/1680 train_time:137260ms step_avg:90.01ms +step:1526/1680 train_time:137353ms step_avg:90.01ms +step:1527/1680 train_time:137449ms step_avg:90.01ms +step:1528/1680 train_time:137535ms step_avg:90.01ms +step:1529/1680 train_time:137625ms step_avg:90.01ms +step:1530/1680 train_time:137715ms step_avg:90.01ms +step:1531/1680 train_time:137805ms step_avg:90.01ms +step:1532/1680 train_time:137896ms step_avg:90.01ms +step:1533/1680 train_time:137988ms step_avg:90.01ms +step:1534/1680 train_time:138078ms step_avg:90.01ms +step:1535/1680 train_time:138169ms step_avg:90.01ms +step:1536/1680 train_time:138260ms step_avg:90.01ms +step:1537/1680 train_time:138355ms step_avg:90.02ms +step:1538/1680 train_time:138447ms step_avg:90.02ms +step:1539/1680 train_time:138536ms step_avg:90.02ms +step:1540/1680 train_time:138627ms step_avg:90.02ms +step:1541/1680 train_time:138717ms step_avg:90.02ms +step:1542/1680 train_time:138807ms step_avg:90.02ms +step:1543/1680 train_time:138898ms step_avg:90.02ms +step:1544/1680 train_time:138989ms step_avg:90.02ms +step:1545/1680 train_time:139080ms step_avg:90.02ms +step:1546/1680 train_time:139171ms step_avg:90.02ms +step:1547/1680 train_time:139261ms step_avg:90.02ms +step:1548/1680 train_time:139354ms step_avg:90.02ms +step:1549/1680 train_time:139446ms step_avg:90.02ms +step:1550/1680 train_time:139536ms step_avg:90.02ms +step:1551/1680 train_time:139626ms step_avg:90.02ms +step:1552/1680 train_time:139716ms step_avg:90.02ms +step:1553/1680 train_time:139807ms step_avg:90.02ms +step:1554/1680 train_time:139897ms step_avg:90.02ms +step:1555/1680 train_time:139987ms step_avg:90.02ms +step:1556/1680 train_time:140078ms step_avg:90.02ms +step:1557/1680 train_time:140170ms step_avg:90.03ms +step:1558/1680 train_time:140261ms step_avg:90.03ms +step:1559/1680 train_time:140354ms step_avg:90.03ms +step:1560/1680 train_time:140446ms step_avg:90.03ms +step:1561/1680 train_time:140536ms step_avg:90.03ms +step:1562/1680 train_time:140627ms step_avg:90.03ms +step:1563/1680 train_time:140717ms step_avg:90.03ms +step:1564/1680 train_time:140807ms step_avg:90.03ms +step:1565/1680 train_time:140897ms step_avg:90.03ms +step:1566/1680 train_time:140988ms step_avg:90.03ms +step:1567/1680 train_time:141078ms step_avg:90.03ms +step:1568/1680 train_time:141169ms step_avg:90.03ms +step:1569/1680 train_time:141259ms step_avg:90.03ms +step:1570/1680 train_time:141351ms step_avg:90.03ms +step:1571/1680 train_time:141442ms step_avg:90.03ms +step:1572/1680 train_time:141533ms step_avg:90.03ms +step:1573/1680 train_time:141624ms step_avg:90.03ms +step:1574/1680 train_time:141715ms step_avg:90.03ms +step:1575/1680 train_time:141805ms step_avg:90.04ms +step:1576/1680 train_time:141896ms step_avg:90.04ms +step:1577/1680 train_time:141986ms step_avg:90.04ms 
+step:1578/1680 train_time:142078ms step_avg:90.04ms +step:1579/1680 train_time:142168ms step_avg:90.04ms +step:1580/1680 train_time:142260ms step_avg:90.04ms +step:1581/1680 train_time:142352ms step_avg:90.04ms +step:1582/1680 train_time:142443ms step_avg:90.04ms +step:1583/1680 train_time:142534ms step_avg:90.04ms +step:1584/1680 train_time:142625ms step_avg:90.04ms +step:1585/1680 train_time:142715ms step_avg:90.04ms +step:1586/1680 train_time:142806ms step_avg:90.04ms +step:1587/1680 train_time:142897ms step_avg:90.04ms +step:1588/1680 train_time:142987ms step_avg:90.04ms +step:1589/1680 train_time:143078ms step_avg:90.04ms +step:1590/1680 train_time:143168ms step_avg:90.04ms +step:1591/1680 train_time:143259ms step_avg:90.04ms +step:1592/1680 train_time:143350ms step_avg:90.04ms +step:1593/1680 train_time:143441ms step_avg:90.04ms +step:1594/1680 train_time:143532ms step_avg:90.05ms +step:1595/1680 train_time:143623ms step_avg:90.05ms +step:1596/1680 train_time:143714ms step_avg:90.05ms +step:1597/1680 train_time:143805ms step_avg:90.05ms +step:1598/1680 train_time:143895ms step_avg:90.05ms +step:1599/1680 train_time:143985ms step_avg:90.05ms +step:1600/1680 train_time:144076ms step_avg:90.05ms +step:1601/1680 train_time:144167ms step_avg:90.05ms +step:1602/1680 train_time:144258ms step_avg:90.05ms +step:1603/1680 train_time:144349ms step_avg:90.05ms +step:1604/1680 train_time:144439ms step_avg:90.05ms +step:1605/1680 train_time:144531ms step_avg:90.05ms +step:1606/1680 train_time:144622ms step_avg:90.05ms +step:1607/1680 train_time:144713ms step_avg:90.05ms +step:1608/1680 train_time:144803ms step_avg:90.05ms +step:1609/1680 train_time:144894ms step_avg:90.05ms +step:1610/1680 train_time:144986ms step_avg:90.05ms +step:1611/1680 train_time:145077ms step_avg:90.05ms +step:1612/1680 train_time:145168ms step_avg:90.05ms +step:1613/1680 train_time:145259ms step_avg:90.06ms +step:1614/1680 train_time:145351ms step_avg:90.06ms +step:1615/1680 train_time:145441ms step_avg:90.06ms +step:1616/1680 train_time:145532ms step_avg:90.06ms +step:1617/1680 train_time:145623ms step_avg:90.06ms +step:1618/1680 train_time:145715ms step_avg:90.06ms +step:1619/1680 train_time:145806ms step_avg:90.06ms +step:1620/1680 train_time:145897ms step_avg:90.06ms +step:1621/1680 train_time:145989ms step_avg:90.06ms +step:1622/1680 train_time:146079ms step_avg:90.06ms +step:1623/1680 train_time:146170ms step_avg:90.06ms +step:1624/1680 train_time:146260ms step_avg:90.06ms +step:1625/1680 train_time:146352ms step_avg:90.06ms +step:1625/1680 val_loss:3.2884 train_time:146442ms step_avg:90.12ms +step:1626/1680 train_time:146464ms step_avg:90.08ms +step:1627/1680 train_time:146536ms step_avg:90.07ms +step:1628/1680 train_time:146630ms step_avg:90.07ms +step:1629/1680 train_time:146722ms step_avg:90.07ms +step:1630/1680 train_time:146812ms step_avg:90.07ms +step:1631/1680 train_time:146903ms step_avg:90.07ms +step:1632/1680 train_time:146991ms step_avg:90.07ms +step:1633/1680 train_time:147080ms step_avg:90.07ms +step:1634/1680 train_time:147170ms step_avg:90.07ms +step:1635/1680 train_time:147260ms step_avg:90.07ms +step:1636/1680 train_time:147351ms step_avg:90.07ms +step:1637/1680 train_time:147444ms step_avg:90.07ms +step:1638/1680 train_time:147537ms step_avg:90.07ms +step:1639/1680 train_time:147628ms step_avg:90.07ms +step:1640/1680 train_time:147720ms step_avg:90.07ms +step:1641/1680 train_time:147810ms step_avg:90.07ms +step:1642/1680 train_time:147901ms step_avg:90.07ms +step:1643/1680 train_time:147991ms 
step_avg:90.07ms +step:1644/1680 train_time:148081ms step_avg:90.07ms +step:1645/1680 train_time:148170ms step_avg:90.07ms +step:1646/1680 train_time:148261ms step_avg:90.07ms +step:1647/1680 train_time:148351ms step_avg:90.07ms +step:1648/1680 train_time:148444ms step_avg:90.08ms +step:1649/1680 train_time:148536ms step_avg:90.08ms +step:1650/1680 train_time:148627ms step_avg:90.08ms +step:1651/1680 train_time:148718ms step_avg:90.08ms +step:1652/1680 train_time:148808ms step_avg:90.08ms +step:1653/1680 train_time:148899ms step_avg:90.08ms +step:1654/1680 train_time:148989ms step_avg:90.08ms +step:1655/1680 train_time:149081ms step_avg:90.08ms +step:1656/1680 train_time:149172ms step_avg:90.08ms +step:1657/1680 train_time:149262ms step_avg:90.08ms +step:1658/1680 train_time:149351ms step_avg:90.08ms +step:1659/1680 train_time:149443ms step_avg:90.08ms +step:1660/1680 train_time:149532ms step_avg:90.08ms +step:1661/1680 train_time:149624ms step_avg:90.08ms +step:1662/1680 train_time:149716ms step_avg:90.08ms +step:1663/1680 train_time:149806ms step_avg:90.08ms +step:1664/1680 train_time:149896ms step_avg:90.08ms +step:1665/1680 train_time:149987ms step_avg:90.08ms +step:1666/1680 train_time:150078ms step_avg:90.08ms +step:1667/1680 train_time:150168ms step_avg:90.08ms +step:1668/1680 train_time:150259ms step_avg:90.08ms +step:1669/1680 train_time:150349ms step_avg:90.08ms +step:1670/1680 train_time:150440ms step_avg:90.08ms +step:1671/1680 train_time:150531ms step_avg:90.08ms +step:1672/1680 train_time:150622ms step_avg:90.08ms +step:1673/1680 train_time:150713ms step_avg:90.09ms +step:1674/1680 train_time:150804ms step_avg:90.09ms +step:1675/1680 train_time:150895ms step_avg:90.09ms +step:1676/1680 train_time:150986ms step_avg:90.09ms +step:1677/1680 train_time:151075ms step_avg:90.09ms +step:1678/1680 train_time:151166ms step_avg:90.09ms +step:1679/1680 train_time:151256ms step_avg:90.09ms +step:1680/1680 train_time:151347ms step_avg:90.09ms +step:1680/1680 val_loss:3.2777 train_time:151440ms step_avg:90.14ms +peak memory allocated: 31255 MiB reserved: 46774 MiB diff --git a/records/092125_DropAttn/79383017-eb05-4857-842a-7866b00571b4.txt b/records/092125_DropAttn/79383017-eb05-4857-842a-7866b00571b4.txt new file mode 100644 index 000000000..2ad25f137 --- /dev/null +++ b/records/092125_DropAttn/79383017-eb05-4857-842a-7866b00571b4.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
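+        # The step is a three-phase pipeline: (1) async reduce_scatter so each
+        # rank averages only its shard of every group's stacked gradients,
+        # (2) a batched Newton-Schulz orthogonalization of that local shard, and
+        # (3) async all_gather to broadcast the updated parameters to all ranks.
+        # A minimal single-GPU sketch of the per-parameter update implemented
+        # below (illustrative only; omits the lr_mul/wd_mul attribute scaling):
+        #   buf.lerp_(p.grad, 1 - momentum)          # update momentum buffer
+        #   update = p.grad.lerp(buf, momentum)      # Nesterov-style blend
+        #   update = newton_schulz_triton(update)    # orthogonalize the update
+        #   p.mul_(1 - lr * weight_decay)            # decoupled weight decay
+        #   p.add_(update, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)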
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming they are constant
+            # for same-shaped params. This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input so the batch stays rectangular.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # this gradient corresponds to the current parameter p
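+                # Optimizer state is created lazily on first touch; the momentum
+                # buffer is keyed by the parameter object itself, and because the
+                # shard assignment is deterministic, the same rank updates it
+                # every step.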
+                state = self.state[p]
+
+                # Initialize momentum buffer if not present
+                if not state:
+                    state["momentum_buffer"] = torch.zeros_like(grad)
+
+                momentum_buffer = state["momentum_buffer"]
+
+                # Apply momentum update directly to the persistent momentum buffer in-place.
+                momentum_buffer.lerp_(grad, 1 - group["momentum"])
+
+                # Compute the actual `update_grad` for zeropower. This creates a new tensor.
+                update_grad = grad.lerp(momentum_buffer, group["momentum"])
+                update_grads_for_zeropower.append(update_grad)
+
+                # Copy the current parameter value into the temporary buffer.
+                updated_param_chunk[i].copy_(p)
+
+                # Apply weight decay directly to the buffer.
+                updated_param_chunk[i].mul_(1 - eff_weight_decay_val)
+
+            # Stack the individual `update_grad` tensors for efficient batched zeropower computation.
+            batched_update_grads = torch.stack(update_grads_for_zeropower)
+
+            # Compute zeropower for the entire chunk in a single, batched call.
+            original_shape = batched_update_grads.shape
+            if batched_update_grads.ndim > 3:
+                assert batched_update_grads.ndim == 4
+                # Flatten the two leading dims so Newton-Schulz sees a 3D batch of matrices.
+                batch = original_shape[0] * original_shape[1]
+                d1 = original_shape[2]
+                d2 = original_shape[3]
+                batched = batched_update_grads.view(batch, d1, d2)
+                v_chunk = newton_schulz_triton(batched)
+                v_chunk = v_chunk.view(original_shape)
+            else:
+                v_chunk = newton_schulz_triton(batched_update_grads)
+
+            # Apply the zeropower output (v_chunk) to the `updated_param_chunk` buffer.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+                if param_idx >= len(params):  # Skip padded entries again.
+                    continue
+
+                # Add the computed zeropower update to the parameter in the buffer.
+                updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val)
+
+            stacked_params = torch.empty(
+                (info["padded_num_params"], *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+            gather_future = dist.all_gather_into_tensor(
+                stacked_params, updated_param_chunk, async_op=True
+            ).get_future()
+
+            all_gather_infos.append(
+                {
+                    "gather_future": gather_future,
+                    "stacked_params": stacked_params,
+                    "orig_params": params,
+                }
+            )
+
+        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
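+        # Waiting here, rather than immediately after each launch, lets the
+        # all_gather for later groups overlap with the copies for earlier ones.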
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading the next shard and indexing its BOS tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
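+                # preloader.get() blocks until the background thread has finished
+                # reading the shard and indexing its BOS positions; start() then
+                # immediately kicks off the load of the following shard, so disk
+                # I/O and BOS scanning overlap with training steps.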
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # trim the final document: the batch gathers one extra token for the _targets shift
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
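+# Illustrative aside (a minimal sketch, not executed anywhere in this run): the
+# greedy packing rule that BOSFinder.next_batch above applies per rank. Documents
+# are taken in order, each truncated to max_seq_len, until the local budget is
+# exceeded by exactly one token (the extra token supplies the final target).
+# Unlike the real class, this sketch omits the StopIteration raised when the
+# shard runs out of BOS positions; the toy arguments below are hypothetical.
+def _sketch_pack_one_rank(bos_idx: list[int], size: int, num_tokens_local: int, max_seq_len: int):
+    starts, ends, cur_len, idx = [], [], 0, 0
+    while cur_len <= num_tokens_local:
+        cur = bos_idx[idx] # each sequence starts at a document boundary
+        starts.append(cur)
+        nxt = bos_idx[idx + 1] if idx + 1 < len(bos_idx) else size
+        end = min(nxt, cur + max_seq_len, cur + num_tokens_local - cur_len + 1)
+        ends.append(end)
+        cur_len += end - cur
+        idx += 1
+    assert cur_len == num_tokens_local + 1
+    return starts, ends
+# e.g. _sketch_pack_one_rank([0, 5, 9], size=20, num_tokens_local=12, max_seq_len=8)
+# returns ([0, 5, 9], [5, 9, 13]): 13 tokens packed for a budget of 12 inputs + 1 target.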
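+# Illustrative aside (not executed): the per-head output gate from
+# CausalSelfAttention above. Each head's output is scaled by a sigmoid gate read
+# off the first 12 channels of the block input, so a head can learn a
+# context-dependent no-op; the zero-initialized gate weight gives g = 0.5
+# everywhere at step 0. Tensor shapes are hypothetical; torch and Tensor are
+# imported at the top of this file.
+def _sketch_attn_gate(y: Tensor, x: Tensor, gate_weight: Tensor) -> Tensor:
+    # y: (B, T, H, D) attention output; x: (B, T, model_dim); gate_weight: (H, 12)
+    g = torch.sigmoid(x[..., :gate_weight.size(-1)] @ gate_weight.T) # (B, T, H)
+    return y * g.unsqueeze(-1) # broadcast over head_dim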
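+# Illustrative aside (not executed): the NTK-by-parts blend in Yarn.apply above.
+# Components rotating fewer than alpha full turns over the old window are rescaled
+# by old_window/new_window, components above beta turns are left unchanged, and
+# the region in between interpolates linearly. A hypothetical standalone version:
+def _sketch_yarn_freqs(angular_freq: Tensor, old_window: int, new_window: int, block_size: int, alpha: int = 1, beta: int = 32) -> Tensor:
+    rotations = block_size * old_window * angular_freq / (2 * torch.pi)
+    s = old_window / new_window # scaling factor, < 1 when the window grows
+    w = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
+    return angular_freq * (s + w * (1 - s)) # w=0 -> full rescale, w=1 -> unchanged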
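+# Note on the logit softcap in GPT.forward above: 30 * sigmoid(z / 7.5) equals
+# 15 * tanh(z / 15) + 15 (via tanh(u) = 2*sigmoid(2u) - 1), i.e. a tanh softcap
+# at 15 shifted to stay positive.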
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sun Sep 21 23:32:35 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 40C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 39C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 89370 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 89371 C /usr/bin/python3 614MiB | +| 0 N/A N/A 89372 C /usr/bin/python3 614MiB | +| 0 N/A N/A 89373 C /usr/bin/python3 614MiB | +| 0 N/A N/A 89374 C /usr/bin/python3 614MiB | +| 0 N/A N/A 89375 C /usr/bin/python3 614MiB | +| 0 N/A N/A 89376 C /usr/bin/python3 614MiB | +| 0 N/A N/A 89377 C /usr/bin/python3 614MiB | +| 1 N/A N/A 89371 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 89372 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 89373 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 89374 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 89375 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 89376 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 89377 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.04ms +step:1/1680 train_time:157ms step_avg:157.46ms +step:2/1680 train_time:180ms step_avg:90.12ms +step:3/1680 train_time:243ms step_avg:81.10ms +step:4/1680 train_time:330ms step_avg:82.52ms +step:5/1680 train_time:418ms step_avg:83.61ms +step:6/1680 train_time:506ms step_avg:84.33ms +step:7/1680 train_time:594ms step_avg:84.92ms +step:8/1680 train_time:683ms 
step_avg:85.39ms +step:9/1680 train_time:772ms step_avg:85.75ms +step:10/1680 train_time:860ms step_avg:86.04ms +step:11/1680 train_time:948ms step_avg:86.22ms +step:12/1680 train_time:1038ms step_avg:86.47ms +step:13/1680 train_time:1129ms step_avg:86.81ms +step:14/1680 train_time:1221ms step_avg:87.22ms +step:15/1680 train_time:1311ms step_avg:87.40ms +step:16/1680 train_time:1400ms step_avg:87.53ms +step:17/1680 train_time:1489ms step_avg:87.56ms +step:18/1680 train_time:1577ms step_avg:87.61ms +step:19/1680 train_time:1666ms step_avg:87.67ms +step:20/1680 train_time:1754ms step_avg:87.68ms +step:21/1680 train_time:1842ms step_avg:87.72ms +step:22/1680 train_time:1931ms step_avg:87.76ms +step:23/1680 train_time:2020ms step_avg:87.84ms +step:24/1680 train_time:2110ms step_avg:87.91ms +step:25/1680 train_time:2200ms step_avg:88.01ms +step:26/1680 train_time:2291ms step_avg:88.10ms +step:27/1680 train_time:2381ms step_avg:88.19ms +step:28/1680 train_time:2470ms step_avg:88.20ms +step:29/1680 train_time:2559ms step_avg:88.24ms +step:30/1680 train_time:2648ms step_avg:88.27ms +step:31/1680 train_time:2738ms step_avg:88.31ms +step:32/1680 train_time:2826ms step_avg:88.31ms +step:33/1680 train_time:2915ms step_avg:88.34ms +step:34/1680 train_time:3004ms step_avg:88.36ms +step:35/1680 train_time:3093ms step_avg:88.38ms +step:36/1680 train_time:3184ms step_avg:88.43ms +step:37/1680 train_time:3274ms step_avg:88.48ms +step:38/1680 train_time:3364ms step_avg:88.51ms +step:39/1680 train_time:3454ms step_avg:88.56ms +step:40/1680 train_time:3544ms step_avg:88.59ms +step:41/1680 train_time:3633ms step_avg:88.60ms +step:42/1680 train_time:3722ms step_avg:88.62ms +step:43/1680 train_time:3811ms step_avg:88.62ms +step:44/1680 train_time:3899ms step_avg:88.62ms +step:45/1680 train_time:3988ms step_avg:88.63ms +step:46/1680 train_time:4078ms step_avg:88.65ms +step:47/1680 train_time:4168ms step_avg:88.68ms +step:48/1680 train_time:4258ms step_avg:88.70ms +step:49/1680 train_time:4347ms step_avg:88.72ms +step:50/1680 train_time:4436ms step_avg:88.73ms +step:51/1680 train_time:4526ms step_avg:88.74ms +step:52/1680 train_time:4614ms step_avg:88.73ms +step:53/1680 train_time:4704ms step_avg:88.75ms +step:54/1680 train_time:4792ms step_avg:88.75ms +step:55/1680 train_time:4882ms step_avg:88.77ms +step:56/1680 train_time:4971ms step_avg:88.77ms +step:57/1680 train_time:5061ms step_avg:88.78ms +step:58/1680 train_time:5150ms step_avg:88.80ms +step:59/1680 train_time:5241ms step_avg:88.82ms +step:60/1680 train_time:5329ms step_avg:88.82ms +step:61/1680 train_time:5419ms step_avg:88.84ms +step:62/1680 train_time:5508ms step_avg:88.84ms +step:63/1680 train_time:5598ms step_avg:88.85ms +step:64/1680 train_time:5687ms step_avg:88.86ms +step:65/1680 train_time:5776ms step_avg:88.86ms +step:66/1680 train_time:5865ms step_avg:88.86ms +step:67/1680 train_time:5953ms step_avg:88.86ms +step:68/1680 train_time:6043ms step_avg:88.86ms +step:69/1680 train_time:6132ms step_avg:88.87ms +step:70/1680 train_time:6221ms step_avg:88.88ms +step:71/1680 train_time:6311ms step_avg:88.88ms +step:72/1680 train_time:6400ms step_avg:88.89ms +step:73/1680 train_time:6489ms step_avg:88.89ms +step:74/1680 train_time:6578ms step_avg:88.90ms +step:75/1680 train_time:6667ms step_avg:88.90ms +step:76/1680 train_time:6757ms step_avg:88.90ms +step:77/1680 train_time:6845ms step_avg:88.90ms +step:78/1680 train_time:6933ms step_avg:88.89ms +step:79/1680 train_time:7022ms step_avg:88.89ms +step:80/1680 train_time:7111ms step_avg:88.89ms +step:81/1680 
train_time:7201ms step_avg:88.90ms +step:82/1680 train_time:7290ms step_avg:88.90ms +step:83/1680 train_time:7379ms step_avg:88.91ms +step:84/1680 train_time:7468ms step_avg:88.91ms +step:85/1680 train_time:7557ms step_avg:88.91ms +step:86/1680 train_time:7647ms step_avg:88.92ms +step:87/1680 train_time:7736ms step_avg:88.92ms +step:88/1680 train_time:7826ms step_avg:88.93ms +step:89/1680 train_time:7914ms step_avg:88.93ms +step:90/1680 train_time:8004ms step_avg:88.93ms +step:91/1680 train_time:8093ms step_avg:88.93ms +step:92/1680 train_time:8183ms step_avg:88.94ms +step:93/1680 train_time:8271ms step_avg:88.94ms +step:94/1680 train_time:8361ms step_avg:88.94ms +step:95/1680 train_time:8450ms step_avg:88.95ms +step:96/1680 train_time:8540ms step_avg:88.96ms +step:97/1680 train_time:8629ms step_avg:88.96ms +step:98/1680 train_time:8718ms step_avg:88.96ms +step:99/1680 train_time:8808ms step_avg:88.97ms +step:100/1680 train_time:8897ms step_avg:88.97ms +step:101/1680 train_time:8986ms step_avg:88.97ms +step:102/1680 train_time:9075ms step_avg:88.97ms +step:103/1680 train_time:9165ms step_avg:88.98ms +step:104/1680 train_time:9253ms step_avg:88.97ms +step:105/1680 train_time:9342ms step_avg:88.97ms +step:106/1680 train_time:9431ms step_avg:88.97ms +step:107/1680 train_time:9520ms step_avg:88.98ms +step:108/1680 train_time:9609ms step_avg:88.98ms +step:109/1680 train_time:9699ms step_avg:88.98ms +step:110/1680 train_time:9788ms step_avg:88.98ms +step:111/1680 train_time:9878ms step_avg:88.99ms +step:112/1680 train_time:9967ms step_avg:88.99ms +step:113/1680 train_time:10055ms step_avg:88.99ms +step:114/1680 train_time:10145ms step_avg:88.99ms +step:115/1680 train_time:10234ms step_avg:88.99ms +step:116/1680 train_time:10323ms step_avg:88.99ms +step:117/1680 train_time:10412ms step_avg:88.99ms +step:118/1680 train_time:10501ms step_avg:88.99ms +step:119/1680 train_time:10590ms step_avg:88.99ms +step:120/1680 train_time:10680ms step_avg:89.00ms +step:121/1680 train_time:10768ms step_avg:88.99ms +step:122/1680 train_time:10858ms step_avg:89.00ms +step:123/1680 train_time:10948ms step_avg:89.01ms +step:124/1680 train_time:11036ms step_avg:89.00ms +step:125/1680 train_time:11125ms step_avg:89.00ms +step:125/1680 val_loss:4.3252 train_time:11215ms step_avg:89.72ms +step:126/1680 train_time:11237ms step_avg:89.18ms +step:127/1680 train_time:11306ms step_avg:89.02ms +step:128/1680 train_time:11403ms step_avg:89.08ms +step:129/1680 train_time:11495ms step_avg:89.11ms +step:130/1680 train_time:11584ms step_avg:89.11ms +step:131/1680 train_time:11671ms step_avg:89.09ms +step:132/1680 train_time:11759ms step_avg:89.08ms +step:133/1680 train_time:11847ms step_avg:89.07ms +step:134/1680 train_time:11935ms step_avg:89.06ms +step:135/1680 train_time:12023ms step_avg:89.06ms +step:136/1680 train_time:12111ms step_avg:89.05ms +step:137/1680 train_time:12201ms step_avg:89.06ms +step:138/1680 train_time:12290ms step_avg:89.06ms +step:139/1680 train_time:12381ms step_avg:89.07ms +step:140/1680 train_time:12471ms step_avg:89.08ms +step:141/1680 train_time:12562ms step_avg:89.09ms +step:142/1680 train_time:12650ms step_avg:89.09ms +step:143/1680 train_time:12739ms step_avg:89.08ms +step:144/1680 train_time:12827ms step_avg:89.08ms +step:145/1680 train_time:12915ms step_avg:89.07ms +step:146/1680 train_time:13004ms step_avg:89.07ms +step:147/1680 train_time:13093ms step_avg:89.07ms +step:148/1680 train_time:13182ms step_avg:89.07ms +step:149/1680 train_time:13271ms step_avg:89.07ms +step:150/1680 train_time:13361ms 
step_avg:89.07ms +step:151/1680 train_time:13451ms step_avg:89.08ms +step:152/1680 train_time:13541ms step_avg:89.08ms +step:153/1680 train_time:13630ms step_avg:89.08ms +step:154/1680 train_time:13720ms step_avg:89.09ms +step:155/1680 train_time:13808ms step_avg:89.08ms +step:156/1680 train_time:13898ms step_avg:89.09ms +step:157/1680 train_time:13986ms step_avg:89.08ms +step:158/1680 train_time:14075ms step_avg:89.08ms +step:159/1680 train_time:14163ms step_avg:89.08ms +step:160/1680 train_time:14252ms step_avg:89.07ms +step:161/1680 train_time:14341ms step_avg:89.07ms +step:162/1680 train_time:14430ms step_avg:89.07ms +step:163/1680 train_time:14519ms step_avg:89.08ms +step:164/1680 train_time:14608ms step_avg:89.07ms +step:165/1680 train_time:14698ms step_avg:89.08ms +step:166/1680 train_time:14787ms step_avg:89.08ms +step:167/1680 train_time:14875ms step_avg:89.07ms +step:168/1680 train_time:14964ms step_avg:89.07ms +step:169/1680 train_time:15053ms step_avg:89.07ms +step:170/1680 train_time:15143ms step_avg:89.07ms +step:171/1680 train_time:15231ms step_avg:89.07ms +step:172/1680 train_time:15320ms step_avg:89.07ms +step:173/1680 train_time:15409ms step_avg:89.07ms +step:174/1680 train_time:15498ms step_avg:89.07ms +step:175/1680 train_time:15587ms step_avg:89.07ms +step:176/1680 train_time:15676ms step_avg:89.07ms +step:177/1680 train_time:15765ms step_avg:89.07ms +step:178/1680 train_time:15854ms step_avg:89.07ms +step:179/1680 train_time:15942ms step_avg:89.06ms +step:180/1680 train_time:16030ms step_avg:89.06ms +step:181/1680 train_time:16119ms step_avg:89.06ms +step:182/1680 train_time:16207ms step_avg:89.05ms +step:183/1680 train_time:16296ms step_avg:89.05ms +step:184/1680 train_time:16386ms step_avg:89.05ms +step:185/1680 train_time:16474ms step_avg:89.05ms +step:186/1680 train_time:16563ms step_avg:89.05ms +step:187/1680 train_time:16652ms step_avg:89.05ms +step:188/1680 train_time:16742ms step_avg:89.05ms +step:189/1680 train_time:16830ms step_avg:89.05ms +step:190/1680 train_time:16919ms step_avg:89.05ms +step:191/1680 train_time:17008ms step_avg:89.05ms +step:192/1680 train_time:17097ms step_avg:89.05ms +step:193/1680 train_time:17185ms step_avg:89.04ms +step:194/1680 train_time:17274ms step_avg:89.04ms +step:195/1680 train_time:17363ms step_avg:89.04ms +step:196/1680 train_time:17451ms step_avg:89.04ms +step:197/1680 train_time:17540ms step_avg:89.04ms +step:198/1680 train_time:17629ms step_avg:89.04ms +step:199/1680 train_time:17719ms step_avg:89.04ms +step:200/1680 train_time:17807ms step_avg:89.04ms +step:201/1680 train_time:17896ms step_avg:89.04ms +step:202/1680 train_time:17985ms step_avg:89.04ms +step:203/1680 train_time:18074ms step_avg:89.03ms +step:204/1680 train_time:18163ms step_avg:89.03ms +step:205/1680 train_time:18252ms step_avg:89.03ms +step:206/1680 train_time:18341ms step_avg:89.03ms +step:207/1680 train_time:18430ms step_avg:89.03ms +step:208/1680 train_time:18519ms step_avg:89.03ms +step:209/1680 train_time:18608ms step_avg:89.03ms +step:210/1680 train_time:18698ms step_avg:89.04ms +step:211/1680 train_time:18787ms step_avg:89.04ms +step:212/1680 train_time:18876ms step_avg:89.04ms +step:213/1680 train_time:18965ms step_avg:89.04ms +step:214/1680 train_time:19054ms step_avg:89.04ms +step:215/1680 train_time:19143ms step_avg:89.04ms +step:216/1680 train_time:19231ms step_avg:89.03ms +step:217/1680 train_time:19321ms step_avg:89.04ms +step:218/1680 train_time:19409ms step_avg:89.03ms +step:219/1680 train_time:19499ms step_avg:89.04ms +step:220/1680 
train_time:19589ms step_avg:89.04ms +step:221/1680 train_time:19680ms step_avg:89.05ms +step:222/1680 train_time:19768ms step_avg:89.05ms +step:223/1680 train_time:19857ms step_avg:89.05ms +step:224/1680 train_time:19946ms step_avg:89.04ms +step:225/1680 train_time:20035ms step_avg:89.04ms +step:226/1680 train_time:20124ms step_avg:89.05ms +step:227/1680 train_time:20213ms step_avg:89.05ms +step:228/1680 train_time:20303ms step_avg:89.05ms +step:229/1680 train_time:20392ms step_avg:89.05ms +step:230/1680 train_time:20481ms step_avg:89.05ms +step:231/1680 train_time:20570ms step_avg:89.05ms +step:232/1680 train_time:20659ms step_avg:89.05ms +step:233/1680 train_time:20748ms step_avg:89.05ms +step:234/1680 train_time:20838ms step_avg:89.05ms +step:235/1680 train_time:20926ms step_avg:89.05ms +step:236/1680 train_time:21016ms step_avg:89.05ms +step:237/1680 train_time:21105ms step_avg:89.05ms +step:238/1680 train_time:21193ms step_avg:89.05ms +step:239/1680 train_time:21282ms step_avg:89.05ms +step:240/1680 train_time:21371ms step_avg:89.05ms +step:241/1680 train_time:21461ms step_avg:89.05ms +step:242/1680 train_time:21549ms step_avg:89.05ms +step:243/1680 train_time:21639ms step_avg:89.05ms +step:244/1680 train_time:21728ms step_avg:89.05ms +step:245/1680 train_time:21818ms step_avg:89.05ms +step:246/1680 train_time:21907ms step_avg:89.05ms +step:247/1680 train_time:21996ms step_avg:89.05ms +step:248/1680 train_time:22085ms step_avg:89.05ms +step:249/1680 train_time:22174ms step_avg:89.05ms +step:250/1680 train_time:22265ms step_avg:89.06ms +step:250/1680 val_loss:3.9699 train_time:22355ms step_avg:89.42ms +step:251/1680 train_time:22377ms step_avg:89.15ms +step:252/1680 train_time:22445ms step_avg:89.07ms +step:253/1680 train_time:22542ms step_avg:89.10ms +step:254/1680 train_time:22634ms step_avg:89.11ms +step:255/1680 train_time:22723ms step_avg:89.11ms +step:256/1680 train_time:22812ms step_avg:89.11ms +step:257/1680 train_time:22899ms step_avg:89.10ms +step:258/1680 train_time:22986ms step_avg:89.09ms +step:259/1680 train_time:23074ms step_avg:89.09ms +step:260/1680 train_time:23162ms step_avg:89.08ms +step:261/1680 train_time:23250ms step_avg:89.08ms +step:262/1680 train_time:23338ms step_avg:89.08ms +step:263/1680 train_time:23429ms step_avg:89.08ms +step:264/1680 train_time:23521ms step_avg:89.09ms +step:265/1680 train_time:23611ms step_avg:89.10ms +step:266/1680 train_time:23700ms step_avg:89.10ms +step:267/1680 train_time:23789ms step_avg:89.10ms +step:268/1680 train_time:23879ms step_avg:89.10ms +step:269/1680 train_time:23967ms step_avg:89.10ms +step:270/1680 train_time:24055ms step_avg:89.09ms +step:271/1680 train_time:24143ms step_avg:89.09ms +step:272/1680 train_time:24232ms step_avg:89.09ms +step:273/1680 train_time:24319ms step_avg:89.08ms +step:274/1680 train_time:24409ms step_avg:89.08ms +step:275/1680 train_time:24499ms step_avg:89.09ms +step:276/1680 train_time:24588ms step_avg:89.09ms +step:277/1680 train_time:24678ms step_avg:89.09ms +step:278/1680 train_time:24768ms step_avg:89.09ms +step:279/1680 train_time:24857ms step_avg:89.09ms +step:280/1680 train_time:24945ms step_avg:89.09ms +step:281/1680 train_time:25034ms step_avg:89.09ms +step:282/1680 train_time:25121ms step_avg:89.08ms +step:283/1680 train_time:25209ms step_avg:89.08ms +step:284/1680 train_time:25298ms step_avg:89.08ms +step:285/1680 train_time:25387ms step_avg:89.08ms +step:286/1680 train_time:25477ms step_avg:89.08ms +step:287/1680 train_time:25566ms step_avg:89.08ms +step:288/1680 train_time:25656ms 
step_avg:89.08ms +step:289/1680 train_time:25746ms step_avg:89.09ms +step:290/1680 train_time:25835ms step_avg:89.09ms +step:291/1680 train_time:25924ms step_avg:89.09ms +step:292/1680 train_time:26012ms step_avg:89.08ms +step:293/1680 train_time:26101ms step_avg:89.08ms +step:294/1680 train_time:26190ms step_avg:89.08ms +step:295/1680 train_time:26278ms step_avg:89.08ms +step:296/1680 train_time:26367ms step_avg:89.08ms +step:297/1680 train_time:26456ms step_avg:89.08ms +step:298/1680 train_time:26545ms step_avg:89.08ms +step:299/1680 train_time:26634ms step_avg:89.08ms +step:300/1680 train_time:26723ms step_avg:89.08ms +step:301/1680 train_time:26813ms step_avg:89.08ms +step:302/1680 train_time:26901ms step_avg:89.08ms +step:303/1680 train_time:26990ms step_avg:89.08ms +step:304/1680 train_time:27079ms step_avg:89.08ms +step:305/1680 train_time:27168ms step_avg:89.07ms +step:306/1680 train_time:27257ms step_avg:89.07ms +step:307/1680 train_time:27345ms step_avg:89.07ms +step:308/1680 train_time:27435ms step_avg:89.07ms +step:309/1680 train_time:27524ms step_avg:89.07ms +step:310/1680 train_time:27613ms step_avg:89.07ms +step:311/1680 train_time:27701ms step_avg:89.07ms +step:312/1680 train_time:27791ms step_avg:89.07ms +step:313/1680 train_time:27880ms step_avg:89.07ms +step:314/1680 train_time:27970ms step_avg:89.08ms +step:315/1680 train_time:28058ms step_avg:89.07ms +step:316/1680 train_time:28147ms step_avg:89.07ms +step:317/1680 train_time:28236ms step_avg:89.07ms +step:318/1680 train_time:28323ms step_avg:89.07ms +step:319/1680 train_time:28412ms step_avg:89.07ms +step:320/1680 train_time:28501ms step_avg:89.07ms +step:321/1680 train_time:28590ms step_avg:89.06ms +step:322/1680 train_time:28680ms step_avg:89.07ms +step:323/1680 train_time:28769ms step_avg:89.07ms +step:324/1680 train_time:28858ms step_avg:89.07ms +step:325/1680 train_time:28947ms step_avg:89.07ms +step:326/1680 train_time:29037ms step_avg:89.07ms +step:327/1680 train_time:29125ms step_avg:89.07ms +step:328/1680 train_time:29214ms step_avg:89.07ms +step:329/1680 train_time:29302ms step_avg:89.07ms +step:330/1680 train_time:29391ms step_avg:89.06ms +step:331/1680 train_time:29480ms step_avg:89.06ms +step:332/1680 train_time:29569ms step_avg:89.06ms +step:333/1680 train_time:29658ms step_avg:89.06ms +step:334/1680 train_time:29747ms step_avg:89.06ms +step:335/1680 train_time:29837ms step_avg:89.06ms +step:336/1680 train_time:29926ms step_avg:89.07ms +step:337/1680 train_time:30015ms step_avg:89.07ms +step:338/1680 train_time:30104ms step_avg:89.07ms +step:339/1680 train_time:30194ms step_avg:89.07ms +step:340/1680 train_time:30282ms step_avg:89.06ms +step:341/1680 train_time:30371ms step_avg:89.06ms +step:342/1680 train_time:30459ms step_avg:89.06ms +step:343/1680 train_time:30549ms step_avg:89.06ms +step:344/1680 train_time:30637ms step_avg:89.06ms +step:345/1680 train_time:30727ms step_avg:89.06ms +step:346/1680 train_time:30816ms step_avg:89.06ms +step:347/1680 train_time:30905ms step_avg:89.06ms +step:348/1680 train_time:30995ms step_avg:89.06ms +step:349/1680 train_time:31084ms step_avg:89.06ms +step:350/1680 train_time:31173ms step_avg:89.06ms +step:351/1680 train_time:31261ms step_avg:89.06ms +step:352/1680 train_time:31349ms step_avg:89.06ms +step:353/1680 train_time:31439ms step_avg:89.06ms +step:354/1680 train_time:31527ms step_avg:89.06ms +step:355/1680 train_time:31616ms step_avg:89.06ms +step:356/1680 train_time:31705ms step_avg:89.06ms +step:357/1680 train_time:31795ms step_avg:89.06ms +step:358/1680 
train_time:31883ms step_avg:89.06ms +step:359/1680 train_time:31973ms step_avg:89.06ms +step:360/1680 train_time:32061ms step_avg:89.06ms +step:361/1680 train_time:32150ms step_avg:89.06ms +step:362/1680 train_time:32238ms step_avg:89.06ms +step:363/1680 train_time:32328ms step_avg:89.06ms +step:364/1680 train_time:32417ms step_avg:89.06ms +step:365/1680 train_time:32507ms step_avg:89.06ms +step:366/1680 train_time:32596ms step_avg:89.06ms +step:367/1680 train_time:32684ms step_avg:89.06ms +step:368/1680 train_time:32773ms step_avg:89.06ms +step:369/1680 train_time:32861ms step_avg:89.05ms +step:370/1680 train_time:32951ms step_avg:89.06ms +step:371/1680 train_time:33040ms step_avg:89.06ms +step:372/1680 train_time:33130ms step_avg:89.06ms +step:373/1680 train_time:33218ms step_avg:89.06ms +step:374/1680 train_time:33307ms step_avg:89.06ms +step:375/1680 train_time:33396ms step_avg:89.06ms +step:375/1680 val_loss:3.8187 train_time:33486ms step_avg:89.30ms +step:376/1680 train_time:33508ms step_avg:89.12ms +step:377/1680 train_time:33577ms step_avg:89.06ms +step:378/1680 train_time:33671ms step_avg:89.08ms +step:379/1680 train_time:33763ms step_avg:89.08ms +step:380/1680 train_time:33851ms step_avg:89.08ms +step:381/1680 train_time:33939ms step_avg:89.08ms +step:382/1680 train_time:34027ms step_avg:89.08ms +step:383/1680 train_time:34115ms step_avg:89.07ms +step:384/1680 train_time:34203ms step_avg:89.07ms +step:385/1680 train_time:34291ms step_avg:89.07ms +step:386/1680 train_time:34379ms step_avg:89.07ms +step:387/1680 train_time:34467ms step_avg:89.06ms +step:388/1680 train_time:34559ms step_avg:89.07ms +step:389/1680 train_time:34649ms step_avg:89.07ms +step:390/1680 train_time:34740ms step_avg:89.08ms +step:391/1680 train_time:34829ms step_avg:89.08ms +step:392/1680 train_time:34917ms step_avg:89.07ms +step:393/1680 train_time:35006ms step_avg:89.07ms +step:394/1680 train_time:35096ms step_avg:89.08ms +step:395/1680 train_time:35183ms step_avg:89.07ms +step:396/1680 train_time:35271ms step_avg:89.07ms +step:397/1680 train_time:35360ms step_avg:89.07ms +step:398/1680 train_time:35448ms step_avg:89.07ms +step:399/1680 train_time:35538ms step_avg:89.07ms +step:400/1680 train_time:35627ms step_avg:89.07ms +step:401/1680 train_time:35718ms step_avg:89.07ms +step:402/1680 train_time:35807ms step_avg:89.07ms +step:403/1680 train_time:35896ms step_avg:89.07ms +step:404/1680 train_time:35985ms step_avg:89.07ms +step:405/1680 train_time:36074ms step_avg:89.07ms +step:406/1680 train_time:36162ms step_avg:89.07ms +step:407/1680 train_time:36250ms step_avg:89.07ms +step:408/1680 train_time:36339ms step_avg:89.07ms +step:409/1680 train_time:36427ms step_avg:89.06ms +step:410/1680 train_time:36516ms step_avg:89.06ms +step:411/1680 train_time:36605ms step_avg:89.06ms +step:412/1680 train_time:36696ms step_avg:89.07ms +step:413/1680 train_time:36784ms step_avg:89.06ms +step:414/1680 train_time:36873ms step_avg:89.07ms +step:415/1680 train_time:36962ms step_avg:89.07ms +step:416/1680 train_time:37051ms step_avg:89.06ms +step:417/1680 train_time:37140ms step_avg:89.06ms +step:418/1680 train_time:37227ms step_avg:89.06ms +step:419/1680 train_time:37316ms step_avg:89.06ms +step:420/1680 train_time:37405ms step_avg:89.06ms +step:421/1680 train_time:37494ms step_avg:89.06ms +step:422/1680 train_time:37584ms step_avg:89.06ms +step:423/1680 train_time:37673ms step_avg:89.06ms +step:424/1680 train_time:37763ms step_avg:89.06ms +step:425/1680 train_time:37851ms step_avg:89.06ms +step:426/1680 train_time:37941ms 
step_avg:89.06ms +step:427/1680 train_time:38029ms step_avg:89.06ms +step:428/1680 train_time:38118ms step_avg:89.06ms +step:429/1680 train_time:38206ms step_avg:89.06ms +step:430/1680 train_time:38294ms step_avg:89.06ms +step:431/1680 train_time:38383ms step_avg:89.06ms +step:432/1680 train_time:38473ms step_avg:89.06ms +step:433/1680 train_time:38562ms step_avg:89.06ms +step:434/1680 train_time:38650ms step_avg:89.06ms +step:435/1680 train_time:38740ms step_avg:89.06ms +step:436/1680 train_time:38829ms step_avg:89.06ms +step:437/1680 train_time:38919ms step_avg:89.06ms +step:438/1680 train_time:39008ms step_avg:89.06ms +step:439/1680 train_time:39098ms step_avg:89.06ms +step:440/1680 train_time:39186ms step_avg:89.06ms +step:441/1680 train_time:39275ms step_avg:89.06ms +step:442/1680 train_time:39363ms step_avg:89.06ms +step:443/1680 train_time:39452ms step_avg:89.06ms +step:444/1680 train_time:39542ms step_avg:89.06ms +step:445/1680 train_time:39631ms step_avg:89.06ms +step:446/1680 train_time:39720ms step_avg:89.06ms +step:447/1680 train_time:39810ms step_avg:89.06ms +step:448/1680 train_time:39899ms step_avg:89.06ms +step:449/1680 train_time:39988ms step_avg:89.06ms +step:450/1680 train_time:40077ms step_avg:89.06ms +step:451/1680 train_time:40166ms step_avg:89.06ms +step:452/1680 train_time:40255ms step_avg:89.06ms +step:453/1680 train_time:40344ms step_avg:89.06ms +step:454/1680 train_time:40433ms step_avg:89.06ms +step:455/1680 train_time:40522ms step_avg:89.06ms +step:456/1680 train_time:40611ms step_avg:89.06ms +step:457/1680 train_time:40702ms step_avg:89.06ms +step:458/1680 train_time:40790ms step_avg:89.06ms +step:459/1680 train_time:40879ms step_avg:89.06ms +step:460/1680 train_time:40968ms step_avg:89.06ms +step:461/1680 train_time:41057ms step_avg:89.06ms +step:462/1680 train_time:41145ms step_avg:89.06ms +step:463/1680 train_time:41234ms step_avg:89.06ms +step:464/1680 train_time:41323ms step_avg:89.06ms +step:465/1680 train_time:41412ms step_avg:89.06ms +step:466/1680 train_time:41501ms step_avg:89.06ms +step:467/1680 train_time:41591ms step_avg:89.06ms +step:468/1680 train_time:41681ms step_avg:89.06ms +step:469/1680 train_time:41770ms step_avg:89.06ms +step:470/1680 train_time:41860ms step_avg:89.06ms +step:471/1680 train_time:41948ms step_avg:89.06ms +step:472/1680 train_time:42037ms step_avg:89.06ms +step:473/1680 train_time:42125ms step_avg:89.06ms +step:474/1680 train_time:42214ms step_avg:89.06ms +step:475/1680 train_time:42303ms step_avg:89.06ms +step:476/1680 train_time:42392ms step_avg:89.06ms +step:477/1680 train_time:42480ms step_avg:89.06ms +step:478/1680 train_time:42569ms step_avg:89.06ms +step:479/1680 train_time:42659ms step_avg:89.06ms +step:480/1680 train_time:42748ms step_avg:89.06ms +step:481/1680 train_time:42838ms step_avg:89.06ms +step:482/1680 train_time:42927ms step_avg:89.06ms +step:483/1680 train_time:43017ms step_avg:89.06ms +step:484/1680 train_time:43105ms step_avg:89.06ms +step:485/1680 train_time:43194ms step_avg:89.06ms +step:486/1680 train_time:43283ms step_avg:89.06ms +step:487/1680 train_time:43372ms step_avg:89.06ms +step:488/1680 train_time:43462ms step_avg:89.06ms +step:489/1680 train_time:43550ms step_avg:89.06ms +step:490/1680 train_time:43640ms step_avg:89.06ms +step:491/1680 train_time:43729ms step_avg:89.06ms +step:492/1680 train_time:43818ms step_avg:89.06ms +step:493/1680 train_time:43906ms step_avg:89.06ms +step:494/1680 train_time:43995ms step_avg:89.06ms +step:495/1680 train_time:44084ms step_avg:89.06ms +step:496/1680 
train_time:44173ms step_avg:89.06ms +step:497/1680 train_time:44262ms step_avg:89.06ms +step:498/1680 train_time:44350ms step_avg:89.06ms +step:499/1680 train_time:44440ms step_avg:89.06ms +step:500/1680 train_time:44529ms step_avg:89.06ms +step:500/1680 val_loss:3.7149 train_time:44619ms step_avg:89.24ms +step:501/1680 train_time:44641ms step_avg:89.10ms +step:502/1680 train_time:44709ms step_avg:89.06ms +step:503/1680 train_time:44803ms step_avg:89.07ms +step:504/1680 train_time:44893ms step_avg:89.07ms +step:505/1680 train_time:44982ms step_avg:89.07ms +step:506/1680 train_time:45070ms step_avg:89.07ms +step:507/1680 train_time:45158ms step_avg:89.07ms +step:508/1680 train_time:45247ms step_avg:89.07ms +step:509/1680 train_time:45335ms step_avg:89.07ms +step:510/1680 train_time:45422ms step_avg:89.06ms +step:511/1680 train_time:45511ms step_avg:89.06ms +step:512/1680 train_time:45601ms step_avg:89.06ms +step:513/1680 train_time:45692ms step_avg:89.07ms +step:514/1680 train_time:45783ms step_avg:89.07ms +step:515/1680 train_time:45873ms step_avg:89.07ms +step:516/1680 train_time:45962ms step_avg:89.07ms +step:517/1680 train_time:46051ms step_avg:89.07ms +step:518/1680 train_time:46139ms step_avg:89.07ms +step:519/1680 train_time:46228ms step_avg:89.07ms +step:520/1680 train_time:46316ms step_avg:89.07ms +step:521/1680 train_time:46405ms step_avg:89.07ms +step:522/1680 train_time:46494ms step_avg:89.07ms +step:523/1680 train_time:46583ms step_avg:89.07ms +step:524/1680 train_time:46673ms step_avg:89.07ms +step:525/1680 train_time:46763ms step_avg:89.07ms +step:526/1680 train_time:46853ms step_avg:89.07ms +step:527/1680 train_time:46942ms step_avg:89.07ms +step:528/1680 train_time:47031ms step_avg:89.07ms +step:529/1680 train_time:47119ms step_avg:89.07ms +step:530/1680 train_time:47208ms step_avg:89.07ms +step:531/1680 train_time:47296ms step_avg:89.07ms +step:532/1680 train_time:47385ms step_avg:89.07ms +step:533/1680 train_time:47473ms step_avg:89.07ms +step:534/1680 train_time:47562ms step_avg:89.07ms +step:535/1680 train_time:47651ms step_avg:89.07ms +step:536/1680 train_time:47740ms step_avg:89.07ms +step:537/1680 train_time:47830ms step_avg:89.07ms +step:538/1680 train_time:47920ms step_avg:89.07ms +step:539/1680 train_time:48008ms step_avg:89.07ms +step:540/1680 train_time:48097ms step_avg:89.07ms +step:541/1680 train_time:48186ms step_avg:89.07ms +step:542/1680 train_time:48274ms step_avg:89.07ms +step:543/1680 train_time:48362ms step_avg:89.07ms +step:544/1680 train_time:48451ms step_avg:89.06ms +step:545/1680 train_time:48542ms step_avg:89.07ms +step:546/1680 train_time:48631ms step_avg:89.07ms +step:547/1680 train_time:48720ms step_avg:89.07ms +step:548/1680 train_time:48809ms step_avg:89.07ms +step:549/1680 train_time:48899ms step_avg:89.07ms +step:550/1680 train_time:48990ms step_avg:89.07ms +step:551/1680 train_time:49081ms step_avg:89.08ms +step:552/1680 train_time:49171ms step_avg:89.08ms +step:553/1680 train_time:49260ms step_avg:89.08ms +step:554/1680 train_time:49351ms step_avg:89.08ms +step:555/1680 train_time:49441ms step_avg:89.08ms +step:556/1680 train_time:49531ms step_avg:89.08ms +step:557/1680 train_time:49622ms step_avg:89.09ms +step:558/1680 train_time:49712ms step_avg:89.09ms +step:559/1680 train_time:49802ms step_avg:89.09ms +step:560/1680 train_time:49892ms step_avg:89.09ms +step:561/1680 train_time:49983ms step_avg:89.10ms +step:562/1680 train_time:50073ms step_avg:89.10ms +step:563/1680 train_time:50164ms step_avg:89.10ms +step:564/1680 train_time:50254ms 
step_avg:89.10ms +step:565/1680 train_time:50345ms step_avg:89.11ms +step:566/1680 train_time:50434ms step_avg:89.11ms +step:567/1680 train_time:50525ms step_avg:89.11ms +step:568/1680 train_time:50615ms step_avg:89.11ms +step:569/1680 train_time:50705ms step_avg:89.11ms +step:570/1680 train_time:50796ms step_avg:89.11ms +step:571/1680 train_time:50887ms step_avg:89.12ms +step:572/1680 train_time:50977ms step_avg:89.12ms +step:573/1680 train_time:51068ms step_avg:89.12ms +step:574/1680 train_time:51158ms step_avg:89.13ms +step:575/1680 train_time:51249ms step_avg:89.13ms +step:576/1680 train_time:51339ms step_avg:89.13ms +step:577/1680 train_time:51429ms step_avg:89.13ms +step:578/1680 train_time:51519ms step_avg:89.13ms +step:579/1680 train_time:51609ms step_avg:89.13ms +step:580/1680 train_time:51699ms step_avg:89.14ms +step:581/1680 train_time:51789ms step_avg:89.14ms +step:582/1680 train_time:51880ms step_avg:89.14ms +step:583/1680 train_time:51970ms step_avg:89.14ms +step:584/1680 train_time:52061ms step_avg:89.15ms +step:585/1680 train_time:52151ms step_avg:89.15ms +step:586/1680 train_time:52242ms step_avg:89.15ms +step:587/1680 train_time:52332ms step_avg:89.15ms +step:588/1680 train_time:52422ms step_avg:89.15ms +step:589/1680 train_time:52513ms step_avg:89.16ms +step:590/1680 train_time:52603ms step_avg:89.16ms +step:591/1680 train_time:52694ms step_avg:89.16ms +step:592/1680 train_time:52784ms step_avg:89.16ms +step:593/1680 train_time:52874ms step_avg:89.16ms +step:594/1680 train_time:52965ms step_avg:89.17ms +step:595/1680 train_time:53056ms step_avg:89.17ms +step:596/1680 train_time:53148ms step_avg:89.17ms +step:597/1680 train_time:53238ms step_avg:89.18ms +step:598/1680 train_time:53328ms step_avg:89.18ms +step:599/1680 train_time:53418ms step_avg:89.18ms +step:600/1680 train_time:53509ms step_avg:89.18ms +step:601/1680 train_time:53599ms step_avg:89.18ms +step:602/1680 train_time:53689ms step_avg:89.18ms +step:603/1680 train_time:53779ms step_avg:89.19ms +step:604/1680 train_time:53869ms step_avg:89.19ms +step:605/1680 train_time:53960ms step_avg:89.19ms +step:606/1680 train_time:54050ms step_avg:89.19ms +step:607/1680 train_time:54141ms step_avg:89.19ms +step:608/1680 train_time:54231ms step_avg:89.20ms +step:609/1680 train_time:54321ms step_avg:89.20ms +step:610/1680 train_time:54411ms step_avg:89.20ms +step:611/1680 train_time:54502ms step_avg:89.20ms +step:612/1680 train_time:54592ms step_avg:89.20ms +step:613/1680 train_time:54682ms step_avg:89.20ms +step:614/1680 train_time:54771ms step_avg:89.20ms +step:615/1680 train_time:54861ms step_avg:89.21ms +step:616/1680 train_time:54951ms step_avg:89.21ms +step:617/1680 train_time:55041ms step_avg:89.21ms +step:618/1680 train_time:55132ms step_avg:89.21ms +step:619/1680 train_time:55222ms step_avg:89.21ms +step:620/1680 train_time:55312ms step_avg:89.21ms +step:621/1680 train_time:55402ms step_avg:89.21ms +step:622/1680 train_time:55494ms step_avg:89.22ms +step:623/1680 train_time:55584ms step_avg:89.22ms +step:624/1680 train_time:55675ms step_avg:89.22ms +step:625/1680 train_time:55765ms step_avg:89.22ms +step:625/1680 val_loss:3.6163 train_time:55856ms step_avg:89.37ms +step:626/1680 train_time:55879ms step_avg:89.26ms +step:627/1680 train_time:55950ms step_avg:89.24ms +step:628/1680 train_time:56049ms step_avg:89.25ms +step:629/1680 train_time:56138ms step_avg:89.25ms +step:630/1680 train_time:56228ms step_avg:89.25ms +step:631/1680 train_time:56316ms step_avg:89.25ms +step:632/1680 train_time:56405ms step_avg:89.25ms 
+step:633/1680 train_time:56495ms step_avg:89.25ms +step:634/1680 train_time:56583ms step_avg:89.25ms +step:635/1680 train_time:56672ms step_avg:89.25ms +step:636/1680 train_time:56762ms step_avg:89.25ms +step:637/1680 train_time:56856ms step_avg:89.26ms +step:638/1680 train_time:56948ms step_avg:89.26ms +step:639/1680 train_time:57039ms step_avg:89.26ms +step:640/1680 train_time:57130ms step_avg:89.27ms +step:641/1680 train_time:57220ms step_avg:89.27ms +step:642/1680 train_time:57309ms step_avg:89.27ms +step:643/1680 train_time:57398ms step_avg:89.27ms +step:644/1680 train_time:57488ms step_avg:89.27ms +step:645/1680 train_time:57577ms step_avg:89.27ms +step:646/1680 train_time:57667ms step_avg:89.27ms +step:647/1680 train_time:57757ms step_avg:89.27ms +step:648/1680 train_time:57849ms step_avg:89.27ms +step:649/1680 train_time:57941ms step_avg:89.28ms +step:650/1680 train_time:58032ms step_avg:89.28ms +step:651/1680 train_time:58124ms step_avg:89.28ms +step:652/1680 train_time:58214ms step_avg:89.29ms +step:653/1680 train_time:58304ms step_avg:89.29ms +step:654/1680 train_time:58394ms step_avg:89.29ms +step:655/1680 train_time:58483ms step_avg:89.29ms +step:656/1680 train_time:58573ms step_avg:89.29ms +step:657/1680 train_time:58662ms step_avg:89.29ms +step:658/1680 train_time:58753ms step_avg:89.29ms +step:659/1680 train_time:58844ms step_avg:89.29ms +step:660/1680 train_time:58935ms step_avg:89.30ms +step:661/1680 train_time:59026ms step_avg:89.30ms +step:662/1680 train_time:59117ms step_avg:89.30ms +step:663/1680 train_time:59208ms step_avg:89.30ms +step:664/1680 train_time:59298ms step_avg:89.30ms +step:665/1680 train_time:59388ms step_avg:89.31ms +step:666/1680 train_time:59478ms step_avg:89.31ms +step:667/1680 train_time:59567ms step_avg:89.31ms +step:668/1680 train_time:59657ms step_avg:89.31ms +step:669/1680 train_time:59746ms step_avg:89.31ms +step:670/1680 train_time:59836ms step_avg:89.31ms +step:671/1680 train_time:59927ms step_avg:89.31ms +step:672/1680 train_time:60018ms step_avg:89.31ms +step:673/1680 train_time:60109ms step_avg:89.32ms +step:674/1680 train_time:60201ms step_avg:89.32ms +step:675/1680 train_time:60292ms step_avg:89.32ms +step:676/1680 train_time:60383ms step_avg:89.32ms +step:677/1680 train_time:60473ms step_avg:89.32ms +step:678/1680 train_time:60562ms step_avg:89.33ms +step:679/1680 train_time:60653ms step_avg:89.33ms +step:680/1680 train_time:60742ms step_avg:89.33ms +step:681/1680 train_time:60832ms step_avg:89.33ms +step:682/1680 train_time:60923ms step_avg:89.33ms +step:683/1680 train_time:61013ms step_avg:89.33ms +step:684/1680 train_time:61103ms step_avg:89.33ms +step:685/1680 train_time:61193ms step_avg:89.33ms +step:686/1680 train_time:61285ms step_avg:89.34ms +step:687/1680 train_time:61375ms step_avg:89.34ms +step:688/1680 train_time:61464ms step_avg:89.34ms +step:689/1680 train_time:61553ms step_avg:89.34ms +step:690/1680 train_time:61644ms step_avg:89.34ms +step:691/1680 train_time:61733ms step_avg:89.34ms +step:692/1680 train_time:61823ms step_avg:89.34ms +step:693/1680 train_time:61913ms step_avg:89.34ms +step:694/1680 train_time:62003ms step_avg:89.34ms +step:695/1680 train_time:62094ms step_avg:89.34ms +step:696/1680 train_time:62184ms step_avg:89.34ms +step:697/1680 train_time:62274ms step_avg:89.35ms +step:698/1680 train_time:62365ms step_avg:89.35ms +step:699/1680 train_time:62455ms step_avg:89.35ms +step:700/1680 train_time:62545ms step_avg:89.35ms +step:701/1680 train_time:62635ms step_avg:89.35ms +step:702/1680 train_time:62725ms 
step_avg:89.35ms +step:703/1680 train_time:62815ms step_avg:89.35ms +step:704/1680 train_time:62905ms step_avg:89.35ms +step:705/1680 train_time:62995ms step_avg:89.36ms +step:706/1680 train_time:63086ms step_avg:89.36ms +step:707/1680 train_time:63177ms step_avg:89.36ms +step:708/1680 train_time:63267ms step_avg:89.36ms +step:709/1680 train_time:63357ms step_avg:89.36ms +step:710/1680 train_time:63447ms step_avg:89.36ms +step:711/1680 train_time:63537ms step_avg:89.36ms +step:712/1680 train_time:63627ms step_avg:89.36ms +step:713/1680 train_time:63716ms step_avg:89.36ms +step:714/1680 train_time:63806ms step_avg:89.36ms +step:715/1680 train_time:63896ms step_avg:89.37ms +step:716/1680 train_time:63987ms step_avg:89.37ms +step:717/1680 train_time:64077ms step_avg:89.37ms +step:718/1680 train_time:64168ms step_avg:89.37ms +step:719/1680 train_time:64257ms step_avg:89.37ms +step:720/1680 train_time:64347ms step_avg:89.37ms +step:721/1680 train_time:64437ms step_avg:89.37ms +step:722/1680 train_time:64528ms step_avg:89.37ms +step:723/1680 train_time:64618ms step_avg:89.37ms +step:724/1680 train_time:64707ms step_avg:89.37ms +step:725/1680 train_time:64797ms step_avg:89.38ms +step:726/1680 train_time:64887ms step_avg:89.38ms +step:727/1680 train_time:64977ms step_avg:89.38ms +step:728/1680 train_time:65068ms step_avg:89.38ms +step:729/1680 train_time:65158ms step_avg:89.38ms +step:730/1680 train_time:65248ms step_avg:89.38ms +step:731/1680 train_time:65338ms step_avg:89.38ms +step:732/1680 train_time:65429ms step_avg:89.38ms +step:733/1680 train_time:65519ms step_avg:89.39ms +step:734/1680 train_time:65610ms step_avg:89.39ms +step:735/1680 train_time:65700ms step_avg:89.39ms +step:736/1680 train_time:65790ms step_avg:89.39ms +step:737/1680 train_time:65879ms step_avg:89.39ms +step:738/1680 train_time:65969ms step_avg:89.39ms +step:739/1680 train_time:66059ms step_avg:89.39ms +step:740/1680 train_time:66150ms step_avg:89.39ms +step:741/1680 train_time:66240ms step_avg:89.39ms +step:742/1680 train_time:66330ms step_avg:89.39ms +step:743/1680 train_time:66420ms step_avg:89.39ms +step:744/1680 train_time:66511ms step_avg:89.40ms +step:745/1680 train_time:66601ms step_avg:89.40ms +step:746/1680 train_time:66691ms step_avg:89.40ms +step:747/1680 train_time:66782ms step_avg:89.40ms +step:748/1680 train_time:66873ms step_avg:89.40ms +step:749/1680 train_time:66963ms step_avg:89.40ms +step:750/1680 train_time:67053ms step_avg:89.40ms +step:750/1680 val_loss:3.5647 train_time:67145ms step_avg:89.53ms +step:751/1680 train_time:67167ms step_avg:89.44ms +step:752/1680 train_time:67240ms step_avg:89.42ms +step:753/1680 train_time:67336ms step_avg:89.42ms +step:754/1680 train_time:67428ms step_avg:89.43ms +step:755/1680 train_time:67517ms step_avg:89.43ms +step:756/1680 train_time:67606ms step_avg:89.43ms +step:757/1680 train_time:67696ms step_avg:89.43ms +step:758/1680 train_time:67784ms step_avg:89.43ms +step:759/1680 train_time:67874ms step_avg:89.43ms +step:760/1680 train_time:67964ms step_avg:89.43ms +step:761/1680 train_time:68054ms step_avg:89.43ms +step:762/1680 train_time:68145ms step_avg:89.43ms +step:763/1680 train_time:68238ms step_avg:89.43ms +step:764/1680 train_time:68331ms step_avg:89.44ms +step:765/1680 train_time:68423ms step_avg:89.44ms +step:766/1680 train_time:68514ms step_avg:89.44ms +step:767/1680 train_time:68605ms step_avg:89.45ms +step:768/1680 train_time:68694ms step_avg:89.45ms +step:769/1680 train_time:68783ms step_avg:89.44ms +step:770/1680 train_time:68872ms step_avg:89.44ms 
+step:771/1680 train_time:68961ms step_avg:89.44ms +step:772/1680 train_time:69051ms step_avg:89.44ms +step:773/1680 train_time:69142ms step_avg:89.45ms +step:774/1680 train_time:69234ms step_avg:89.45ms +step:775/1680 train_time:69325ms step_avg:89.45ms +step:776/1680 train_time:69417ms step_avg:89.45ms +step:777/1680 train_time:69507ms step_avg:89.46ms +step:778/1680 train_time:69598ms step_avg:89.46ms +step:779/1680 train_time:69687ms step_avg:89.46ms +step:780/1680 train_time:69777ms step_avg:89.46ms +step:781/1680 train_time:69866ms step_avg:89.46ms +step:782/1680 train_time:69956ms step_avg:89.46ms +step:783/1680 train_time:70046ms step_avg:89.46ms +step:784/1680 train_time:70137ms step_avg:89.46ms +step:785/1680 train_time:70227ms step_avg:89.46ms +step:786/1680 train_time:70318ms step_avg:89.46ms +step:787/1680 train_time:70410ms step_avg:89.47ms +step:788/1680 train_time:70500ms step_avg:89.47ms +step:789/1680 train_time:70590ms step_avg:89.47ms +step:790/1680 train_time:70679ms step_avg:89.47ms +step:791/1680 train_time:70769ms step_avg:89.47ms +step:792/1680 train_time:70858ms step_avg:89.47ms +step:793/1680 train_time:70948ms step_avg:89.47ms +step:794/1680 train_time:71038ms step_avg:89.47ms +step:795/1680 train_time:71128ms step_avg:89.47ms +step:796/1680 train_time:71218ms step_avg:89.47ms +step:797/1680 train_time:71310ms step_avg:89.47ms +step:798/1680 train_time:71401ms step_avg:89.47ms +step:799/1680 train_time:71492ms step_avg:89.48ms +step:800/1680 train_time:71581ms step_avg:89.48ms +step:801/1680 train_time:71671ms step_avg:89.48ms +step:802/1680 train_time:71761ms step_avg:89.48ms +step:803/1680 train_time:71851ms step_avg:89.48ms +step:804/1680 train_time:71941ms step_avg:89.48ms +step:805/1680 train_time:72030ms step_avg:89.48ms +step:806/1680 train_time:72120ms step_avg:89.48ms +step:807/1680 train_time:72211ms step_avg:89.48ms +step:808/1680 train_time:72301ms step_avg:89.48ms +step:809/1680 train_time:72393ms step_avg:89.48ms +step:810/1680 train_time:72483ms step_avg:89.49ms +step:811/1680 train_time:72573ms step_avg:89.49ms +step:812/1680 train_time:72663ms step_avg:89.49ms +step:813/1680 train_time:72754ms step_avg:89.49ms +step:814/1680 train_time:72843ms step_avg:89.49ms +step:815/1680 train_time:72933ms step_avg:89.49ms +step:816/1680 train_time:73023ms step_avg:89.49ms +step:817/1680 train_time:73114ms step_avg:89.49ms +step:818/1680 train_time:73204ms step_avg:89.49ms +step:819/1680 train_time:73296ms step_avg:89.49ms +step:820/1680 train_time:73386ms step_avg:89.50ms +step:821/1680 train_time:73477ms step_avg:89.50ms +step:822/1680 train_time:73567ms step_avg:89.50ms +step:823/1680 train_time:73657ms step_avg:89.50ms +step:824/1680 train_time:73747ms step_avg:89.50ms +step:825/1680 train_time:73837ms step_avg:89.50ms +step:826/1680 train_time:73927ms step_avg:89.50ms +step:827/1680 train_time:74017ms step_avg:89.50ms +step:828/1680 train_time:74108ms step_avg:89.50ms +step:829/1680 train_time:74198ms step_avg:89.50ms +step:830/1680 train_time:74288ms step_avg:89.50ms +step:831/1680 train_time:74378ms step_avg:89.50ms +step:832/1680 train_time:74468ms step_avg:89.51ms +step:833/1680 train_time:74558ms step_avg:89.51ms +step:834/1680 train_time:74649ms step_avg:89.51ms +step:835/1680 train_time:74739ms step_avg:89.51ms +step:836/1680 train_time:74829ms step_avg:89.51ms +step:837/1680 train_time:74920ms step_avg:89.51ms +step:838/1680 train_time:75011ms step_avg:89.51ms +step:839/1680 train_time:75100ms step_avg:89.51ms +step:840/1680 train_time:75191ms 
step_avg:89.51ms +step:841/1680 train_time:75281ms step_avg:89.51ms +step:842/1680 train_time:75371ms step_avg:89.51ms +step:843/1680 train_time:75461ms step_avg:89.51ms +step:844/1680 train_time:75551ms step_avg:89.52ms +step:845/1680 train_time:75641ms step_avg:89.52ms +step:846/1680 train_time:75732ms step_avg:89.52ms +step:847/1680 train_time:75822ms step_avg:89.52ms +step:848/1680 train_time:75911ms step_avg:89.52ms +step:849/1680 train_time:76001ms step_avg:89.52ms +step:850/1680 train_time:76093ms step_avg:89.52ms +step:851/1680 train_time:76182ms step_avg:89.52ms +step:852/1680 train_time:76272ms step_avg:89.52ms +step:853/1680 train_time:76362ms step_avg:89.52ms +step:854/1680 train_time:76452ms step_avg:89.52ms +step:855/1680 train_time:76542ms step_avg:89.52ms +step:856/1680 train_time:76632ms step_avg:89.52ms +step:857/1680 train_time:76722ms step_avg:89.52ms +step:858/1680 train_time:76812ms step_avg:89.53ms +step:859/1680 train_time:76902ms step_avg:89.53ms +step:860/1680 train_time:76994ms step_avg:89.53ms +step:861/1680 train_time:77084ms step_avg:89.53ms +step:862/1680 train_time:77175ms step_avg:89.53ms +step:863/1680 train_time:77266ms step_avg:89.53ms +step:864/1680 train_time:77356ms step_avg:89.53ms +step:865/1680 train_time:77447ms step_avg:89.53ms +step:866/1680 train_time:77537ms step_avg:89.53ms +step:867/1680 train_time:77627ms step_avg:89.54ms +step:868/1680 train_time:77718ms step_avg:89.54ms +step:869/1680 train_time:77808ms step_avg:89.54ms +step:870/1680 train_time:77898ms step_avg:89.54ms +step:871/1680 train_time:77988ms step_avg:89.54ms +step:872/1680 train_time:78078ms step_avg:89.54ms +step:873/1680 train_time:78168ms step_avg:89.54ms +step:874/1680 train_time:78258ms step_avg:89.54ms +step:875/1680 train_time:78348ms step_avg:89.54ms +step:875/1680 val_loss:3.5170 train_time:78439ms step_avg:89.65ms +step:876/1680 train_time:78461ms step_avg:89.57ms +step:877/1680 train_time:78534ms step_avg:89.55ms +step:878/1680 train_time:78632ms step_avg:89.56ms +step:879/1680 train_time:78722ms step_avg:89.56ms +step:880/1680 train_time:78813ms step_avg:89.56ms +step:881/1680 train_time:78902ms step_avg:89.56ms +step:882/1680 train_time:78991ms step_avg:89.56ms +step:883/1680 train_time:79080ms step_avg:89.56ms +step:884/1680 train_time:79169ms step_avg:89.56ms +step:885/1680 train_time:79258ms step_avg:89.56ms +step:886/1680 train_time:79347ms step_avg:89.56ms +step:887/1680 train_time:79438ms step_avg:89.56ms +step:888/1680 train_time:79531ms step_avg:89.56ms +step:889/1680 train_time:79623ms step_avg:89.57ms +step:890/1680 train_time:79715ms step_avg:89.57ms +step:891/1680 train_time:79806ms step_avg:89.57ms +step:892/1680 train_time:79896ms step_avg:89.57ms +step:893/1680 train_time:79986ms step_avg:89.57ms +step:894/1680 train_time:80075ms step_avg:89.57ms +step:895/1680 train_time:80165ms step_avg:89.57ms +step:896/1680 train_time:80254ms step_avg:89.57ms +step:897/1680 train_time:80343ms step_avg:89.57ms +step:898/1680 train_time:80433ms step_avg:89.57ms +step:899/1680 train_time:80525ms step_avg:89.57ms +step:900/1680 train_time:80616ms step_avg:89.57ms +step:901/1680 train_time:80707ms step_avg:89.57ms +step:902/1680 train_time:80798ms step_avg:89.58ms +step:903/1680 train_time:80888ms step_avg:89.58ms +step:904/1680 train_time:80977ms step_avg:89.58ms +step:905/1680 train_time:81067ms step_avg:89.58ms +step:906/1680 train_time:81157ms step_avg:89.58ms +step:907/1680 train_time:81246ms step_avg:89.58ms +step:908/1680 train_time:81335ms step_avg:89.58ms 
+step:909/1680 train_time:81425ms step_avg:89.58ms +step:910/1680 train_time:81515ms step_avg:89.58ms +step:911/1680 train_time:81606ms step_avg:89.58ms +step:912/1680 train_time:81697ms step_avg:89.58ms +step:913/1680 train_time:81787ms step_avg:89.58ms +step:914/1680 train_time:81877ms step_avg:89.58ms +step:915/1680 train_time:81968ms step_avg:89.58ms +step:916/1680 train_time:82058ms step_avg:89.58ms +step:917/1680 train_time:82150ms step_avg:89.59ms +step:918/1680 train_time:82240ms step_avg:89.59ms +step:919/1680 train_time:82329ms step_avg:89.59ms +step:920/1680 train_time:82420ms step_avg:89.59ms +step:921/1680 train_time:82510ms step_avg:89.59ms +step:922/1680 train_time:82601ms step_avg:89.59ms +step:923/1680 train_time:82692ms step_avg:89.59ms +step:924/1680 train_time:82783ms step_avg:89.59ms +step:925/1680 train_time:82872ms step_avg:89.59ms +step:926/1680 train_time:82963ms step_avg:89.59ms +step:927/1680 train_time:83054ms step_avg:89.59ms +step:928/1680 train_time:83144ms step_avg:89.60ms +step:929/1680 train_time:83234ms step_avg:89.60ms +step:930/1680 train_time:83324ms step_avg:89.60ms +step:931/1680 train_time:83413ms step_avg:89.60ms +step:932/1680 train_time:83504ms step_avg:89.60ms +step:933/1680 train_time:83594ms step_avg:89.60ms +step:934/1680 train_time:83685ms step_avg:89.60ms +step:935/1680 train_time:83775ms step_avg:89.60ms +step:936/1680 train_time:83866ms step_avg:89.60ms +step:937/1680 train_time:83956ms step_avg:89.60ms +step:938/1680 train_time:84046ms step_avg:89.60ms +step:939/1680 train_time:84135ms step_avg:89.60ms +step:940/1680 train_time:84225ms step_avg:89.60ms +step:941/1680 train_time:84314ms step_avg:89.60ms +step:942/1680 train_time:84405ms step_avg:89.60ms +step:943/1680 train_time:84494ms step_avg:89.60ms +step:944/1680 train_time:84585ms step_avg:89.60ms +step:945/1680 train_time:84675ms step_avg:89.60ms +step:946/1680 train_time:84765ms step_avg:89.60ms +step:947/1680 train_time:84856ms step_avg:89.61ms +step:948/1680 train_time:84947ms step_avg:89.61ms +step:949/1680 train_time:85037ms step_avg:89.61ms +step:950/1680 train_time:85127ms step_avg:89.61ms +step:951/1680 train_time:85216ms step_avg:89.61ms +step:952/1680 train_time:85307ms step_avg:89.61ms +step:953/1680 train_time:85396ms step_avg:89.61ms +step:954/1680 train_time:85487ms step_avg:89.61ms +step:955/1680 train_time:85576ms step_avg:89.61ms +step:956/1680 train_time:85667ms step_avg:89.61ms +step:957/1680 train_time:85758ms step_avg:89.61ms +step:958/1680 train_time:85848ms step_avg:89.61ms +step:959/1680 train_time:85938ms step_avg:89.61ms +step:960/1680 train_time:86028ms step_avg:89.61ms +step:961/1680 train_time:86118ms step_avg:89.61ms +step:962/1680 train_time:86209ms step_avg:89.61ms +step:963/1680 train_time:86298ms step_avg:89.61ms +step:964/1680 train_time:86389ms step_avg:89.62ms +step:965/1680 train_time:86478ms step_avg:89.61ms +step:966/1680 train_time:86568ms step_avg:89.62ms +step:967/1680 train_time:86659ms step_avg:89.62ms +step:968/1680 train_time:86750ms step_avg:89.62ms +step:969/1680 train_time:86841ms step_avg:89.62ms +step:970/1680 train_time:86931ms step_avg:89.62ms +step:971/1680 train_time:87020ms step_avg:89.62ms +step:972/1680 train_time:87110ms step_avg:89.62ms +step:973/1680 train_time:87200ms step_avg:89.62ms +step:974/1680 train_time:87290ms step_avg:89.62ms +step:975/1680 train_time:87380ms step_avg:89.62ms +step:976/1680 train_time:87470ms step_avg:89.62ms +step:977/1680 train_time:87560ms step_avg:89.62ms +step:978/1680 train_time:87651ms 
step_avg:89.62ms +step:979/1680 train_time:87741ms step_avg:89.62ms +step:980/1680 train_time:87831ms step_avg:89.62ms +step:981/1680 train_time:87921ms step_avg:89.62ms +step:982/1680 train_time:88012ms step_avg:89.63ms +step:983/1680 train_time:88102ms step_avg:89.63ms +step:984/1680 train_time:88193ms step_avg:89.63ms +step:985/1680 train_time:88283ms step_avg:89.63ms +step:986/1680 train_time:88373ms step_avg:89.63ms +step:987/1680 train_time:88462ms step_avg:89.63ms +step:988/1680 train_time:88552ms step_avg:89.63ms +step:989/1680 train_time:88642ms step_avg:89.63ms +step:990/1680 train_time:88732ms step_avg:89.63ms +step:991/1680 train_time:88822ms step_avg:89.63ms +step:992/1680 train_time:88913ms step_avg:89.63ms +step:993/1680 train_time:89003ms step_avg:89.63ms +step:994/1680 train_time:89093ms step_avg:89.63ms +step:995/1680 train_time:89184ms step_avg:89.63ms +step:996/1680 train_time:89274ms step_avg:89.63ms +step:997/1680 train_time:89364ms step_avg:89.63ms +step:998/1680 train_time:89454ms step_avg:89.63ms +step:999/1680 train_time:89544ms step_avg:89.63ms +step:1000/1680 train_time:89635ms step_avg:89.63ms +step:1000/1680 val_loss:3.4686 train_time:89727ms step_avg:89.73ms +step:1001/1680 train_time:89749ms step_avg:89.66ms +step:1002/1680 train_time:89821ms step_avg:89.64ms +step:1003/1680 train_time:89915ms step_avg:89.65ms +step:1004/1680 train_time:90005ms step_avg:89.65ms +step:1005/1680 train_time:90094ms step_avg:89.65ms +step:1006/1680 train_time:90184ms step_avg:89.65ms +step:1007/1680 train_time:90273ms step_avg:89.65ms +step:1008/1680 train_time:90362ms step_avg:89.65ms +step:1009/1680 train_time:90451ms step_avg:89.64ms +step:1010/1680 train_time:90540ms step_avg:89.64ms +step:1011/1680 train_time:90629ms step_avg:89.64ms +step:1012/1680 train_time:90720ms step_avg:89.64ms +step:1013/1680 train_time:90814ms step_avg:89.65ms +step:1014/1680 train_time:90907ms step_avg:89.65ms +step:1015/1680 train_time:90998ms step_avg:89.65ms +step:1016/1680 train_time:91087ms step_avg:89.65ms +step:1017/1680 train_time:91177ms step_avg:89.65ms +step:1018/1680 train_time:91267ms step_avg:89.65ms +step:1019/1680 train_time:91356ms step_avg:89.65ms +step:1020/1680 train_time:91445ms step_avg:89.65ms +step:1021/1680 train_time:91534ms step_avg:89.65ms +step:1022/1680 train_time:91623ms step_avg:89.65ms +step:1023/1680 train_time:91713ms step_avg:89.65ms +step:1024/1680 train_time:91805ms step_avg:89.65ms +step:1025/1680 train_time:91896ms step_avg:89.66ms +step:1026/1680 train_time:91987ms step_avg:89.66ms +step:1027/1680 train_time:92078ms step_avg:89.66ms +step:1028/1680 train_time:92168ms step_avg:89.66ms +step:1029/1680 train_time:92258ms step_avg:89.66ms +step:1030/1680 train_time:92347ms step_avg:89.66ms +step:1031/1680 train_time:92436ms step_avg:89.66ms +step:1032/1680 train_time:92526ms step_avg:89.66ms +step:1033/1680 train_time:92616ms step_avg:89.66ms +step:1034/1680 train_time:92706ms step_avg:89.66ms +step:1035/1680 train_time:92797ms step_avg:89.66ms +step:1036/1680 train_time:92888ms step_avg:89.66ms +step:1037/1680 train_time:92978ms step_avg:89.66ms +step:1038/1680 train_time:93068ms step_avg:89.66ms +step:1039/1680 train_time:93159ms step_avg:89.66ms +step:1040/1680 train_time:93249ms step_avg:89.66ms +step:1041/1680 train_time:93339ms step_avg:89.66ms +step:1042/1680 train_time:93428ms step_avg:89.66ms +step:1043/1680 train_time:93518ms step_avg:89.66ms +step:1044/1680 train_time:93608ms step_avg:89.66ms +step:1045/1680 train_time:93698ms step_avg:89.66ms 
+step:1046/1680 train_time:93788ms step_avg:89.66ms +step:1047/1680 train_time:93879ms step_avg:89.66ms +step:1048/1680 train_time:93969ms step_avg:89.67ms +step:1049/1680 train_time:94059ms step_avg:89.67ms +step:1050/1680 train_time:94150ms step_avg:89.67ms +step:1051/1680 train_time:94241ms step_avg:89.67ms +step:1052/1680 train_time:94330ms step_avg:89.67ms +step:1053/1680 train_time:94421ms step_avg:89.67ms +step:1054/1680 train_time:94511ms step_avg:89.67ms +step:1055/1680 train_time:94601ms step_avg:89.67ms +step:1056/1680 train_time:94691ms step_avg:89.67ms +step:1057/1680 train_time:94781ms step_avg:89.67ms +step:1058/1680 train_time:94872ms step_avg:89.67ms +step:1059/1680 train_time:94963ms step_avg:89.67ms +step:1060/1680 train_time:95053ms step_avg:89.67ms +step:1061/1680 train_time:95143ms step_avg:89.67ms +step:1062/1680 train_time:95234ms step_avg:89.67ms +step:1063/1680 train_time:95324ms step_avg:89.67ms +step:1064/1680 train_time:95414ms step_avg:89.67ms +step:1065/1680 train_time:95505ms step_avg:89.68ms +step:1066/1680 train_time:95594ms step_avg:89.68ms +step:1067/1680 train_time:95685ms step_avg:89.68ms +step:1068/1680 train_time:95775ms step_avg:89.68ms +step:1069/1680 train_time:95865ms step_avg:89.68ms +step:1070/1680 train_time:95955ms step_avg:89.68ms +step:1071/1680 train_time:96046ms step_avg:89.68ms +step:1072/1680 train_time:96135ms step_avg:89.68ms +step:1073/1680 train_time:96226ms step_avg:89.68ms +step:1074/1680 train_time:96315ms step_avg:89.68ms +step:1075/1680 train_time:96405ms step_avg:89.68ms +step:1076/1680 train_time:96495ms step_avg:89.68ms +step:1077/1680 train_time:96584ms step_avg:89.68ms +step:1078/1680 train_time:96675ms step_avg:89.68ms +step:1079/1680 train_time:96765ms step_avg:89.68ms +step:1080/1680 train_time:96855ms step_avg:89.68ms +step:1081/1680 train_time:96946ms step_avg:89.68ms +step:1082/1680 train_time:97037ms step_avg:89.68ms +step:1083/1680 train_time:97127ms step_avg:89.68ms +step:1084/1680 train_time:97218ms step_avg:89.68ms +step:1085/1680 train_time:97309ms step_avg:89.69ms +step:1086/1680 train_time:97399ms step_avg:89.69ms +step:1087/1680 train_time:97489ms step_avg:89.69ms +step:1088/1680 train_time:97580ms step_avg:89.69ms +step:1089/1680 train_time:97669ms step_avg:89.69ms +step:1090/1680 train_time:97759ms step_avg:89.69ms +step:1091/1680 train_time:97849ms step_avg:89.69ms +step:1092/1680 train_time:97939ms step_avg:89.69ms +step:1093/1680 train_time:98028ms step_avg:89.69ms +step:1094/1680 train_time:98119ms step_avg:89.69ms +step:1095/1680 train_time:98210ms step_avg:89.69ms +step:1096/1680 train_time:98301ms step_avg:89.69ms +step:1097/1680 train_time:98391ms step_avg:89.69ms +step:1098/1680 train_time:98482ms step_avg:89.69ms +step:1099/1680 train_time:98572ms step_avg:89.69ms +step:1100/1680 train_time:98663ms step_avg:89.69ms +step:1101/1680 train_time:98754ms step_avg:89.69ms +step:1102/1680 train_time:98844ms step_avg:89.70ms +step:1103/1680 train_time:98936ms step_avg:89.70ms +step:1104/1680 train_time:99026ms step_avg:89.70ms +step:1105/1680 train_time:99117ms step_avg:89.70ms +step:1106/1680 train_time:99208ms step_avg:89.70ms +step:1107/1680 train_time:99299ms step_avg:89.70ms +step:1108/1680 train_time:99389ms step_avg:89.70ms +step:1109/1680 train_time:99481ms step_avg:89.70ms +step:1110/1680 train_time:99571ms step_avg:89.70ms +step:1111/1680 train_time:99663ms step_avg:89.71ms +step:1112/1680 train_time:99753ms step_avg:89.71ms +step:1113/1680 train_time:99845ms step_avg:89.71ms +step:1114/1680 
train_time:99935ms step_avg:89.71ms +step:1115/1680 train_time:100026ms step_avg:89.71ms +step:1116/1680 train_time:100117ms step_avg:89.71ms +step:1117/1680 train_time:100208ms step_avg:89.71ms +step:1118/1680 train_time:100299ms step_avg:89.71ms +step:1119/1680 train_time:100389ms step_avg:89.71ms +step:1120/1680 train_time:100480ms step_avg:89.71ms +step:1121/1680 train_time:100570ms step_avg:89.71ms +step:1122/1680 train_time:100661ms step_avg:89.72ms +step:1123/1680 train_time:100751ms step_avg:89.72ms +step:1124/1680 train_time:100842ms step_avg:89.72ms +step:1125/1680 train_time:100933ms step_avg:89.72ms +step:1125/1680 val_loss:3.4147 train_time:101026ms step_avg:89.80ms +step:1126/1680 train_time:101048ms step_avg:89.74ms +step:1127/1680 train_time:101118ms step_avg:89.72ms +step:1128/1680 train_time:101222ms step_avg:89.74ms +step:1129/1680 train_time:101313ms step_avg:89.74ms +step:1130/1680 train_time:101403ms step_avg:89.74ms +step:1131/1680 train_time:101493ms step_avg:89.74ms +step:1132/1680 train_time:101583ms step_avg:89.74ms +step:1133/1680 train_time:101673ms step_avg:89.74ms +step:1134/1680 train_time:101762ms step_avg:89.74ms +step:1135/1680 train_time:101852ms step_avg:89.74ms +step:1136/1680 train_time:101944ms step_avg:89.74ms +step:1137/1680 train_time:102037ms step_avg:89.74ms +step:1138/1680 train_time:102130ms step_avg:89.75ms +step:1139/1680 train_time:102223ms step_avg:89.75ms +step:1140/1680 train_time:102313ms step_avg:89.75ms +step:1141/1680 train_time:102404ms step_avg:89.75ms +step:1142/1680 train_time:102494ms step_avg:89.75ms +step:1143/1680 train_time:102584ms step_avg:89.75ms +step:1144/1680 train_time:102674ms step_avg:89.75ms +step:1145/1680 train_time:102763ms step_avg:89.75ms +step:1146/1680 train_time:102854ms step_avg:89.75ms +step:1147/1680 train_time:102946ms step_avg:89.75ms +step:1148/1680 train_time:103037ms step_avg:89.75ms +step:1149/1680 train_time:103130ms step_avg:89.76ms +step:1150/1680 train_time:103221ms step_avg:89.76ms +step:1151/1680 train_time:103312ms step_avg:89.76ms +step:1152/1680 train_time:103403ms step_avg:89.76ms +step:1153/1680 train_time:103494ms step_avg:89.76ms +step:1154/1680 train_time:103583ms step_avg:89.76ms +step:1155/1680 train_time:103673ms step_avg:89.76ms +step:1156/1680 train_time:103763ms step_avg:89.76ms +step:1157/1680 train_time:103853ms step_avg:89.76ms +step:1158/1680 train_time:103945ms step_avg:89.76ms +step:1159/1680 train_time:104036ms step_avg:89.76ms +step:1160/1680 train_time:104128ms step_avg:89.77ms +step:1161/1680 train_time:104220ms step_avg:89.77ms +step:1162/1680 train_time:104312ms step_avg:89.77ms +step:1163/1680 train_time:104402ms step_avg:89.77ms +step:1164/1680 train_time:104493ms step_avg:89.77ms +step:1165/1680 train_time:104583ms step_avg:89.77ms +step:1166/1680 train_time:104674ms step_avg:89.77ms +step:1167/1680 train_time:104764ms step_avg:89.77ms +step:1168/1680 train_time:104855ms step_avg:89.77ms +step:1169/1680 train_time:104947ms step_avg:89.78ms +step:1170/1680 train_time:105038ms step_avg:89.78ms +step:1171/1680 train_time:105130ms step_avg:89.78ms +step:1172/1680 train_time:105221ms step_avg:89.78ms +step:1173/1680 train_time:105312ms step_avg:89.78ms +step:1174/1680 train_time:105403ms step_avg:89.78ms +step:1175/1680 train_time:105494ms step_avg:89.78ms +step:1176/1680 train_time:105583ms step_avg:89.78ms +step:1177/1680 train_time:105675ms step_avg:89.78ms +step:1178/1680 train_time:105765ms step_avg:89.78ms +step:1179/1680 train_time:105856ms step_avg:89.78ms 
+step:1180/1680 train_time:105947ms step_avg:89.79ms +step:1181/1680 train_time:106038ms step_avg:89.79ms +step:1182/1680 train_time:106130ms step_avg:89.79ms +step:1183/1680 train_time:106221ms step_avg:89.79ms +step:1184/1680 train_time:106312ms step_avg:89.79ms +step:1185/1680 train_time:106402ms step_avg:89.79ms +step:1186/1680 train_time:106492ms step_avg:89.79ms +step:1187/1680 train_time:106582ms step_avg:89.79ms +step:1188/1680 train_time:106672ms step_avg:89.79ms +step:1189/1680 train_time:106762ms step_avg:89.79ms +step:1190/1680 train_time:106853ms step_avg:89.79ms +step:1191/1680 train_time:106945ms step_avg:89.79ms +step:1192/1680 train_time:107036ms step_avg:89.79ms +step:1193/1680 train_time:107127ms step_avg:89.80ms +step:1194/1680 train_time:107218ms step_avg:89.80ms +step:1195/1680 train_time:107309ms step_avg:89.80ms +step:1196/1680 train_time:107399ms step_avg:89.80ms +step:1197/1680 train_time:107489ms step_avg:89.80ms +step:1198/1680 train_time:107580ms step_avg:89.80ms +step:1199/1680 train_time:107671ms step_avg:89.80ms +step:1200/1680 train_time:107760ms step_avg:89.80ms +step:1201/1680 train_time:107852ms step_avg:89.80ms +step:1202/1680 train_time:107943ms step_avg:89.80ms +step:1203/1680 train_time:108035ms step_avg:89.80ms +step:1204/1680 train_time:108127ms step_avg:89.81ms +step:1205/1680 train_time:108217ms step_avg:89.81ms +step:1206/1680 train_time:108308ms step_avg:89.81ms +step:1207/1680 train_time:108399ms step_avg:89.81ms +step:1208/1680 train_time:108490ms step_avg:89.81ms +step:1209/1680 train_time:108580ms step_avg:89.81ms +step:1210/1680 train_time:108672ms step_avg:89.81ms +step:1211/1680 train_time:108762ms step_avg:89.81ms +step:1212/1680 train_time:108854ms step_avg:89.81ms +step:1213/1680 train_time:108946ms step_avg:89.81ms +step:1214/1680 train_time:109037ms step_avg:89.82ms +step:1215/1680 train_time:109127ms step_avg:89.82ms +step:1216/1680 train_time:109218ms step_avg:89.82ms +step:1217/1680 train_time:109309ms step_avg:89.82ms +step:1218/1680 train_time:109400ms step_avg:89.82ms +step:1219/1680 train_time:109490ms step_avg:89.82ms +step:1220/1680 train_time:109580ms step_avg:89.82ms +step:1221/1680 train_time:109670ms step_avg:89.82ms +step:1222/1680 train_time:109760ms step_avg:89.82ms +step:1223/1680 train_time:109852ms step_avg:89.82ms +step:1224/1680 train_time:109942ms step_avg:89.82ms +step:1225/1680 train_time:110035ms step_avg:89.82ms +step:1226/1680 train_time:110127ms step_avg:89.83ms +step:1227/1680 train_time:110218ms step_avg:89.83ms +step:1228/1680 train_time:110308ms step_avg:89.83ms +step:1229/1680 train_time:110399ms step_avg:89.83ms +step:1230/1680 train_time:110489ms step_avg:89.83ms +step:1231/1680 train_time:110579ms step_avg:89.83ms +step:1232/1680 train_time:110669ms step_avg:89.83ms +step:1233/1680 train_time:110760ms step_avg:89.83ms +step:1234/1680 train_time:110852ms step_avg:89.83ms +step:1235/1680 train_time:110944ms step_avg:89.83ms +step:1236/1680 train_time:111035ms step_avg:89.83ms +step:1237/1680 train_time:111126ms step_avg:89.83ms +step:1238/1680 train_time:111217ms step_avg:89.84ms +step:1239/1680 train_time:111308ms step_avg:89.84ms +step:1240/1680 train_time:111398ms step_avg:89.84ms +step:1241/1680 train_time:111488ms step_avg:89.84ms +step:1242/1680 train_time:111579ms step_avg:89.84ms +step:1243/1680 train_time:111669ms step_avg:89.84ms +step:1244/1680 train_time:111760ms step_avg:89.84ms +step:1245/1680 train_time:111852ms step_avg:89.84ms +step:1246/1680 train_time:111943ms step_avg:89.84ms 
+step:1247/1680 train_time:112033ms step_avg:89.84ms +step:1248/1680 train_time:112124ms step_avg:89.84ms +step:1249/1680 train_time:112216ms step_avg:89.84ms +step:1250/1680 train_time:112307ms step_avg:89.85ms +step:1250/1680 val_loss:3.3759 train_time:112399ms step_avg:89.92ms +step:1251/1680 train_time:112421ms step_avg:89.86ms +step:1252/1680 train_time:112495ms step_avg:89.85ms +step:1253/1680 train_time:112591ms step_avg:89.86ms +step:1254/1680 train_time:112682ms step_avg:89.86ms +step:1255/1680 train_time:112772ms step_avg:89.86ms +step:1256/1680 train_time:112862ms step_avg:89.86ms +step:1257/1680 train_time:112952ms step_avg:89.86ms +step:1258/1680 train_time:113042ms step_avg:89.86ms +step:1259/1680 train_time:113132ms step_avg:89.86ms +step:1260/1680 train_time:113221ms step_avg:89.86ms +step:1261/1680 train_time:113312ms step_avg:89.86ms +step:1262/1680 train_time:113404ms step_avg:89.86ms +step:1263/1680 train_time:113498ms step_avg:89.86ms +step:1264/1680 train_time:113591ms step_avg:89.87ms +step:1265/1680 train_time:113682ms step_avg:89.87ms +step:1266/1680 train_time:113773ms step_avg:89.87ms +step:1267/1680 train_time:113863ms step_avg:89.87ms +step:1268/1680 train_time:113953ms step_avg:89.87ms +step:1269/1680 train_time:114042ms step_avg:89.87ms +step:1270/1680 train_time:114132ms step_avg:89.87ms +step:1271/1680 train_time:114221ms step_avg:89.87ms +step:1272/1680 train_time:114312ms step_avg:89.87ms +step:1273/1680 train_time:114405ms step_avg:89.87ms +step:1274/1680 train_time:114496ms step_avg:89.87ms +step:1275/1680 train_time:114589ms step_avg:89.87ms +step:1276/1680 train_time:114679ms step_avg:89.87ms +step:1277/1680 train_time:114771ms step_avg:89.88ms +step:1278/1680 train_time:114862ms step_avg:89.88ms +step:1279/1680 train_time:114951ms step_avg:89.88ms +step:1280/1680 train_time:115041ms step_avg:89.88ms +step:1281/1680 train_time:115131ms step_avg:89.88ms +step:1282/1680 train_time:115222ms step_avg:89.88ms +step:1283/1680 train_time:115313ms step_avg:89.88ms +step:1284/1680 train_time:115405ms step_avg:89.88ms +step:1285/1680 train_time:115496ms step_avg:89.88ms +step:1286/1680 train_time:115589ms step_avg:89.88ms +step:1287/1680 train_time:115680ms step_avg:89.88ms +step:1288/1680 train_time:115774ms step_avg:89.89ms +step:1289/1680 train_time:115866ms step_avg:89.89ms +step:1290/1680 train_time:115956ms step_avg:89.89ms +step:1291/1680 train_time:116048ms step_avg:89.89ms +step:1292/1680 train_time:116137ms step_avg:89.89ms +step:1293/1680 train_time:116227ms step_avg:89.89ms +step:1294/1680 train_time:116318ms step_avg:89.89ms +step:1295/1680 train_time:116409ms step_avg:89.89ms +step:1296/1680 train_time:116499ms step_avg:89.89ms +step:1297/1680 train_time:116591ms step_avg:89.89ms +step:1298/1680 train_time:116683ms step_avg:89.89ms +step:1299/1680 train_time:116775ms step_avg:89.90ms +step:1300/1680 train_time:116866ms step_avg:89.90ms +step:1301/1680 train_time:116956ms step_avg:89.90ms +step:1302/1680 train_time:117047ms step_avg:89.90ms +step:1303/1680 train_time:117137ms step_avg:89.90ms +step:1304/1680 train_time:117228ms step_avg:89.90ms +step:1305/1680 train_time:117318ms step_avg:89.90ms +step:1306/1680 train_time:117409ms step_avg:89.90ms +step:1307/1680 train_time:117499ms step_avg:89.90ms +step:1308/1680 train_time:117590ms step_avg:89.90ms +step:1309/1680 train_time:117681ms step_avg:89.90ms +step:1310/1680 train_time:117774ms step_avg:89.90ms +step:1311/1680 train_time:117865ms step_avg:89.90ms +step:1312/1680 train_time:117956ms 
step_avg:89.91ms +step:1313/1680 train_time:118047ms step_avg:89.91ms +step:1314/1680 train_time:118137ms step_avg:89.91ms +step:1315/1680 train_time:118227ms step_avg:89.91ms +step:1316/1680 train_time:118318ms step_avg:89.91ms +step:1317/1680 train_time:118410ms step_avg:89.91ms +step:1318/1680 train_time:118500ms step_avg:89.91ms +step:1319/1680 train_time:118592ms step_avg:89.91ms +step:1320/1680 train_time:118684ms step_avg:89.91ms +step:1321/1680 train_time:118775ms step_avg:89.91ms +step:1322/1680 train_time:118866ms step_avg:89.91ms +step:1323/1680 train_time:118957ms step_avg:89.91ms +step:1324/1680 train_time:119047ms step_avg:89.91ms +step:1325/1680 train_time:119138ms step_avg:89.92ms +step:1326/1680 train_time:119228ms step_avg:89.92ms +step:1327/1680 train_time:119318ms step_avg:89.92ms +step:1328/1680 train_time:119409ms step_avg:89.92ms +step:1329/1680 train_time:119501ms step_avg:89.92ms +step:1330/1680 train_time:119592ms step_avg:89.92ms +step:1331/1680 train_time:119683ms step_avg:89.92ms +step:1332/1680 train_time:119774ms step_avg:89.92ms +step:1333/1680 train_time:119865ms step_avg:89.92ms +step:1334/1680 train_time:119956ms step_avg:89.92ms +step:1335/1680 train_time:120047ms step_avg:89.92ms +step:1336/1680 train_time:120138ms step_avg:89.92ms +step:1337/1680 train_time:120229ms step_avg:89.92ms +step:1338/1680 train_time:120319ms step_avg:89.92ms +step:1339/1680 train_time:120410ms step_avg:89.93ms +step:1340/1680 train_time:120500ms step_avg:89.93ms +step:1341/1680 train_time:120591ms step_avg:89.93ms +step:1342/1680 train_time:120683ms step_avg:89.93ms +step:1343/1680 train_time:120774ms step_avg:89.93ms +step:1344/1680 train_time:120865ms step_avg:89.93ms +step:1345/1680 train_time:120955ms step_avg:89.93ms +step:1346/1680 train_time:121046ms step_avg:89.93ms +step:1347/1680 train_time:121138ms step_avg:89.93ms +step:1348/1680 train_time:121229ms step_avg:89.93ms +step:1349/1680 train_time:121321ms step_avg:89.93ms +step:1350/1680 train_time:121413ms step_avg:89.94ms +step:1351/1680 train_time:121505ms step_avg:89.94ms +step:1352/1680 train_time:121596ms step_avg:89.94ms +step:1353/1680 train_time:121687ms step_avg:89.94ms +step:1354/1680 train_time:121778ms step_avg:89.94ms +step:1355/1680 train_time:121870ms step_avg:89.94ms +step:1356/1680 train_time:121960ms step_avg:89.94ms +step:1357/1680 train_time:122051ms step_avg:89.94ms +step:1358/1680 train_time:122141ms step_avg:89.94ms +step:1359/1680 train_time:122232ms step_avg:89.94ms +step:1360/1680 train_time:122322ms step_avg:89.94ms +step:1361/1680 train_time:122413ms step_avg:89.94ms +step:1362/1680 train_time:122504ms step_avg:89.94ms +step:1363/1680 train_time:122595ms step_avg:89.95ms +step:1364/1680 train_time:122686ms step_avg:89.95ms +step:1365/1680 train_time:122777ms step_avg:89.95ms +step:1366/1680 train_time:122867ms step_avg:89.95ms +step:1367/1680 train_time:122958ms step_avg:89.95ms +step:1368/1680 train_time:123048ms step_avg:89.95ms +step:1369/1680 train_time:123139ms step_avg:89.95ms +step:1370/1680 train_time:123229ms step_avg:89.95ms +step:1371/1680 train_time:123320ms step_avg:89.95ms +step:1372/1680 train_time:123411ms step_avg:89.95ms +step:1373/1680 train_time:123503ms step_avg:89.95ms +step:1374/1680 train_time:123594ms step_avg:89.95ms +step:1375/1680 train_time:123686ms step_avg:89.95ms +step:1375/1680 val_loss:3.3412 train_time:123778ms step_avg:90.02ms +step:1376/1680 train_time:123800ms step_avg:89.97ms +step:1377/1680 train_time:123873ms step_avg:89.96ms +step:1378/1680 
train_time:123969ms step_avg:89.96ms +step:1379/1680 train_time:124060ms step_avg:89.96ms +step:1380/1680 train_time:124151ms step_avg:89.96ms +step:1381/1680 train_time:124241ms step_avg:89.96ms +step:1382/1680 train_time:124330ms step_avg:89.96ms +step:1383/1680 train_time:124420ms step_avg:89.96ms +step:1384/1680 train_time:124510ms step_avg:89.96ms +step:1385/1680 train_time:124600ms step_avg:89.96ms +step:1386/1680 train_time:124690ms step_avg:89.96ms +step:1387/1680 train_time:124783ms step_avg:89.97ms +step:1388/1680 train_time:124876ms step_avg:89.97ms +step:1389/1680 train_time:124969ms step_avg:89.97ms +step:1390/1680 train_time:125061ms step_avg:89.97ms +step:1391/1680 train_time:125152ms step_avg:89.97ms +step:1392/1680 train_time:125242ms step_avg:89.97ms +step:1393/1680 train_time:125332ms step_avg:89.97ms +step:1394/1680 train_time:125421ms step_avg:89.97ms +step:1395/1680 train_time:125511ms step_avg:89.97ms +step:1396/1680 train_time:125601ms step_avg:89.97ms +step:1397/1680 train_time:125691ms step_avg:89.97ms +step:1398/1680 train_time:125783ms step_avg:89.97ms +step:1399/1680 train_time:125874ms step_avg:89.97ms +step:1400/1680 train_time:125967ms step_avg:89.98ms +step:1401/1680 train_time:126058ms step_avg:89.98ms +step:1402/1680 train_time:126150ms step_avg:89.98ms +step:1403/1680 train_time:126241ms step_avg:89.98ms +step:1404/1680 train_time:126331ms step_avg:89.98ms +step:1405/1680 train_time:126421ms step_avg:89.98ms +step:1406/1680 train_time:126512ms step_avg:89.98ms +step:1407/1680 train_time:126602ms step_avg:89.98ms +step:1408/1680 train_time:126693ms step_avg:89.98ms +step:1409/1680 train_time:126784ms step_avg:89.98ms +step:1410/1680 train_time:126877ms step_avg:89.98ms +step:1411/1680 train_time:126969ms step_avg:89.99ms +step:1412/1680 train_time:127060ms step_avg:89.99ms +step:1413/1680 train_time:127152ms step_avg:89.99ms +step:1414/1680 train_time:127243ms step_avg:89.99ms +step:1415/1680 train_time:127333ms step_avg:89.99ms +step:1416/1680 train_time:127424ms step_avg:89.99ms +step:1417/1680 train_time:127514ms step_avg:89.99ms +step:1418/1680 train_time:127604ms step_avg:89.99ms +step:1419/1680 train_time:127694ms step_avg:89.99ms +step:1420/1680 train_time:127784ms step_avg:89.99ms +step:1421/1680 train_time:127876ms step_avg:89.99ms +step:1422/1680 train_time:127967ms step_avg:89.99ms +step:1423/1680 train_time:128057ms step_avg:89.99ms +step:1424/1680 train_time:128150ms step_avg:89.99ms +step:1425/1680 train_time:128242ms step_avg:89.99ms +step:1426/1680 train_time:128332ms step_avg:89.99ms +step:1427/1680 train_time:128424ms step_avg:90.00ms +step:1428/1680 train_time:128514ms step_avg:90.00ms +step:1429/1680 train_time:128604ms step_avg:90.00ms +step:1430/1680 train_time:128694ms step_avg:90.00ms +step:1431/1680 train_time:128785ms step_avg:90.00ms +step:1432/1680 train_time:128876ms step_avg:90.00ms +step:1433/1680 train_time:128967ms step_avg:90.00ms +step:1434/1680 train_time:129057ms step_avg:90.00ms +step:1435/1680 train_time:129150ms step_avg:90.00ms +step:1436/1680 train_time:129242ms step_avg:90.00ms +step:1437/1680 train_time:129331ms step_avg:90.00ms +step:1438/1680 train_time:129422ms step_avg:90.00ms +step:1439/1680 train_time:129512ms step_avg:90.00ms +step:1440/1680 train_time:129603ms step_avg:90.00ms +step:1441/1680 train_time:129693ms step_avg:90.00ms +step:1442/1680 train_time:129784ms step_avg:90.00ms +step:1443/1680 train_time:129874ms step_avg:90.00ms +step:1444/1680 train_time:129965ms step_avg:90.00ms +step:1445/1680 
train_time:130056ms step_avg:90.00ms +step:1446/1680 train_time:130149ms step_avg:90.01ms +step:1447/1680 train_time:130239ms step_avg:90.01ms +step:1448/1680 train_time:130330ms step_avg:90.01ms +step:1449/1680 train_time:130422ms step_avg:90.01ms +step:1450/1680 train_time:130512ms step_avg:90.01ms +step:1451/1680 train_time:130603ms step_avg:90.01ms +step:1452/1680 train_time:130693ms step_avg:90.01ms +step:1453/1680 train_time:130784ms step_avg:90.01ms +step:1454/1680 train_time:130874ms step_avg:90.01ms +step:1455/1680 train_time:130966ms step_avg:90.01ms +step:1456/1680 train_time:131057ms step_avg:90.01ms +step:1457/1680 train_time:131148ms step_avg:90.01ms +step:1458/1680 train_time:131239ms step_avg:90.01ms +step:1459/1680 train_time:131331ms step_avg:90.01ms +step:1460/1680 train_time:131423ms step_avg:90.02ms +step:1461/1680 train_time:131513ms step_avg:90.02ms +step:1462/1680 train_time:131604ms step_avg:90.02ms +step:1463/1680 train_time:131695ms step_avg:90.02ms +step:1464/1680 train_time:131786ms step_avg:90.02ms +step:1465/1680 train_time:131876ms step_avg:90.02ms +step:1466/1680 train_time:131968ms step_avg:90.02ms +step:1467/1680 train_time:132058ms step_avg:90.02ms +step:1468/1680 train_time:132149ms step_avg:90.02ms +step:1469/1680 train_time:132241ms step_avg:90.02ms +step:1470/1680 train_time:132332ms step_avg:90.02ms +step:1471/1680 train_time:132425ms step_avg:90.02ms +step:1472/1680 train_time:132516ms step_avg:90.02ms +step:1473/1680 train_time:132608ms step_avg:90.03ms +step:1474/1680 train_time:132699ms step_avg:90.03ms +step:1475/1680 train_time:132789ms step_avg:90.03ms +step:1476/1680 train_time:132879ms step_avg:90.03ms +step:1477/1680 train_time:132970ms step_avg:90.03ms +step:1478/1680 train_time:133061ms step_avg:90.03ms +step:1479/1680 train_time:133152ms step_avg:90.03ms +step:1480/1680 train_time:133243ms step_avg:90.03ms +step:1481/1680 train_time:133334ms step_avg:90.03ms +step:1482/1680 train_time:133426ms step_avg:90.03ms +step:1483/1680 train_time:133517ms step_avg:90.03ms +step:1484/1680 train_time:133608ms step_avg:90.03ms +step:1485/1680 train_time:133700ms step_avg:90.03ms +step:1486/1680 train_time:133789ms step_avg:90.03ms +step:1487/1680 train_time:133880ms step_avg:90.03ms +step:1488/1680 train_time:133970ms step_avg:90.03ms +step:1489/1680 train_time:134060ms step_avg:90.03ms +step:1490/1680 train_time:134151ms step_avg:90.03ms +step:1491/1680 train_time:134242ms step_avg:90.04ms +step:1492/1680 train_time:134333ms step_avg:90.04ms +step:1493/1680 train_time:134425ms step_avg:90.04ms +step:1494/1680 train_time:134517ms step_avg:90.04ms +step:1495/1680 train_time:134610ms step_avg:90.04ms +step:1496/1680 train_time:134701ms step_avg:90.04ms +step:1497/1680 train_time:134792ms step_avg:90.04ms +step:1498/1680 train_time:134882ms step_avg:90.04ms +step:1499/1680 train_time:134973ms step_avg:90.04ms +step:1500/1680 train_time:135063ms step_avg:90.04ms +step:1500/1680 val_loss:3.3118 train_time:135156ms step_avg:90.10ms +step:1501/1680 train_time:135178ms step_avg:90.06ms +step:1502/1680 train_time:135252ms step_avg:90.05ms +step:1503/1680 train_time:135349ms step_avg:90.05ms +step:1504/1680 train_time:135441ms step_avg:90.05ms +step:1505/1680 train_time:135533ms step_avg:90.06ms +step:1506/1680 train_time:135623ms step_avg:90.06ms +step:1507/1680 train_time:135712ms step_avg:90.05ms +step:1508/1680 train_time:135802ms step_avg:90.05ms +step:1509/1680 train_time:135890ms step_avg:90.05ms +step:1510/1680 train_time:135980ms step_avg:90.05ms 
+step:1511/1680 train_time:136070ms step_avg:90.05ms +step:1512/1680 train_time:136162ms step_avg:90.05ms +step:1513/1680 train_time:136256ms step_avg:90.06ms +step:1514/1680 train_time:136352ms step_avg:90.06ms +step:1515/1680 train_time:136447ms step_avg:90.06ms +step:1516/1680 train_time:136538ms step_avg:90.06ms +step:1517/1680 train_time:136629ms step_avg:90.07ms +step:1518/1680 train_time:136718ms step_avg:90.06ms +step:1519/1680 train_time:136809ms step_avg:90.07ms +step:1520/1680 train_time:136899ms step_avg:90.07ms +step:1521/1680 train_time:136989ms step_avg:90.07ms +step:1522/1680 train_time:137081ms step_avg:90.07ms +step:1523/1680 train_time:137171ms step_avg:90.07ms +step:1524/1680 train_time:137263ms step_avg:90.07ms +step:1525/1680 train_time:137355ms step_avg:90.07ms +step:1526/1680 train_time:137446ms step_avg:90.07ms +step:1527/1680 train_time:137538ms step_avg:90.07ms +step:1528/1680 train_time:137629ms step_avg:90.07ms +step:1529/1680 train_time:137719ms step_avg:90.07ms +step:1530/1680 train_time:137810ms step_avg:90.07ms +step:1531/1680 train_time:137900ms step_avg:90.07ms +step:1532/1680 train_time:137990ms step_avg:90.07ms +step:1533/1680 train_time:138081ms step_avg:90.07ms +step:1534/1680 train_time:138171ms step_avg:90.07ms +step:1535/1680 train_time:138262ms step_avg:90.07ms +step:1536/1680 train_time:138353ms step_avg:90.07ms +step:1537/1680 train_time:138445ms step_avg:90.07ms +step:1538/1680 train_time:138535ms step_avg:90.08ms +step:1539/1680 train_time:138626ms step_avg:90.08ms +step:1540/1680 train_time:138717ms step_avg:90.08ms +step:1541/1680 train_time:138808ms step_avg:90.08ms +step:1542/1680 train_time:138898ms step_avg:90.08ms +step:1543/1680 train_time:138989ms step_avg:90.08ms +step:1544/1680 train_time:139079ms step_avg:90.08ms +step:1545/1680 train_time:139170ms step_avg:90.08ms +step:1546/1680 train_time:139261ms step_avg:90.08ms +step:1547/1680 train_time:139351ms step_avg:90.08ms +step:1548/1680 train_time:139443ms step_avg:90.08ms +step:1549/1680 train_time:139534ms step_avg:90.08ms +step:1550/1680 train_time:139625ms step_avg:90.08ms +step:1551/1680 train_time:139716ms step_avg:90.08ms +step:1552/1680 train_time:139807ms step_avg:90.08ms +step:1553/1680 train_time:139898ms step_avg:90.08ms +step:1554/1680 train_time:139988ms step_avg:90.08ms +step:1555/1680 train_time:140079ms step_avg:90.08ms +step:1556/1680 train_time:140170ms step_avg:90.08ms +step:1557/1680 train_time:140261ms step_avg:90.08ms +step:1558/1680 train_time:140352ms step_avg:90.08ms +step:1559/1680 train_time:140442ms step_avg:90.08ms +step:1560/1680 train_time:140533ms step_avg:90.09ms +step:1561/1680 train_time:140625ms step_avg:90.09ms +step:1562/1680 train_time:140715ms step_avg:90.09ms +step:1563/1680 train_time:140806ms step_avg:90.09ms +step:1564/1680 train_time:140896ms step_avg:90.09ms +step:1565/1680 train_time:140988ms step_avg:90.09ms +step:1566/1680 train_time:141080ms step_avg:90.09ms +step:1567/1680 train_time:141170ms step_avg:90.09ms +step:1568/1680 train_time:141262ms step_avg:90.09ms +step:1569/1680 train_time:141352ms step_avg:90.09ms +step:1570/1680 train_time:141443ms step_avg:90.09ms +step:1571/1680 train_time:141533ms step_avg:90.09ms +step:1572/1680 train_time:141624ms step_avg:90.09ms +step:1573/1680 train_time:141715ms step_avg:90.09ms +step:1574/1680 train_time:141806ms step_avg:90.09ms +step:1575/1680 train_time:141895ms step_avg:90.09ms +step:1576/1680 train_time:141987ms step_avg:90.09ms +step:1577/1680 train_time:142078ms step_avg:90.09ms 
+step:1578/1680 train_time:142170ms step_avg:90.09ms +step:1579/1680 train_time:142260ms step_avg:90.09ms +step:1580/1680 train_time:142350ms step_avg:90.10ms +step:1581/1680 train_time:142441ms step_avg:90.10ms +step:1582/1680 train_time:142531ms step_avg:90.10ms +step:1583/1680 train_time:142622ms step_avg:90.10ms +step:1584/1680 train_time:142713ms step_avg:90.10ms +step:1585/1680 train_time:142804ms step_avg:90.10ms +step:1586/1680 train_time:142894ms step_avg:90.10ms +step:1587/1680 train_time:142986ms step_avg:90.10ms +step:1588/1680 train_time:143077ms step_avg:90.10ms +step:1589/1680 train_time:143169ms step_avg:90.10ms +step:1590/1680 train_time:143260ms step_avg:90.10ms +step:1591/1680 train_time:143350ms step_avg:90.10ms +step:1592/1680 train_time:143441ms step_avg:90.10ms +step:1593/1680 train_time:143532ms step_avg:90.10ms +step:1594/1680 train_time:143623ms step_avg:90.10ms +step:1595/1680 train_time:143714ms step_avg:90.10ms +step:1596/1680 train_time:143804ms step_avg:90.10ms +step:1597/1680 train_time:143895ms step_avg:90.10ms +step:1598/1680 train_time:143987ms step_avg:90.10ms +step:1599/1680 train_time:144078ms step_avg:90.11ms +step:1600/1680 train_time:144169ms step_avg:90.11ms +step:1601/1680 train_time:144261ms step_avg:90.11ms +step:1602/1680 train_time:144351ms step_avg:90.11ms +step:1603/1680 train_time:144442ms step_avg:90.11ms +step:1604/1680 train_time:144532ms step_avg:90.11ms +step:1605/1680 train_time:144624ms step_avg:90.11ms +step:1606/1680 train_time:144714ms step_avg:90.11ms +step:1607/1680 train_time:144805ms step_avg:90.11ms +step:1608/1680 train_time:144895ms step_avg:90.11ms +step:1609/1680 train_time:144988ms step_avg:90.11ms +step:1610/1680 train_time:145081ms step_avg:90.11ms +step:1611/1680 train_time:145171ms step_avg:90.11ms +step:1612/1680 train_time:145261ms step_avg:90.11ms +step:1613/1680 train_time:145352ms step_avg:90.11ms +step:1614/1680 train_time:145442ms step_avg:90.11ms +step:1615/1680 train_time:145532ms step_avg:90.11ms +step:1616/1680 train_time:145623ms step_avg:90.11ms +step:1617/1680 train_time:145714ms step_avg:90.11ms +step:1618/1680 train_time:145805ms step_avg:90.11ms +step:1619/1680 train_time:145895ms step_avg:90.11ms +step:1620/1680 train_time:145987ms step_avg:90.12ms +step:1621/1680 train_time:146077ms step_avg:90.12ms +step:1622/1680 train_time:146169ms step_avg:90.12ms +step:1623/1680 train_time:146260ms step_avg:90.12ms +step:1624/1680 train_time:146351ms step_avg:90.12ms +step:1625/1680 train_time:146442ms step_avg:90.12ms +step:1625/1680 val_loss:3.2879 train_time:146533ms step_avg:90.17ms +step:1626/1680 train_time:146555ms step_avg:90.13ms +step:1627/1680 train_time:146628ms step_avg:90.12ms +step:1628/1680 train_time:146725ms step_avg:90.13ms +step:1629/1680 train_time:146817ms step_avg:90.13ms +step:1630/1680 train_time:146907ms step_avg:90.13ms +step:1631/1680 train_time:146996ms step_avg:90.13ms +step:1632/1680 train_time:147086ms step_avg:90.13ms +step:1633/1680 train_time:147175ms step_avg:90.13ms +step:1634/1680 train_time:147265ms step_avg:90.13ms +step:1635/1680 train_time:147354ms step_avg:90.12ms +step:1636/1680 train_time:147445ms step_avg:90.13ms +step:1637/1680 train_time:147536ms step_avg:90.13ms +step:1638/1680 train_time:147632ms step_avg:90.13ms +step:1639/1680 train_time:147727ms step_avg:90.13ms +step:1640/1680 train_time:147818ms step_avg:90.13ms +step:1641/1680 train_time:147908ms step_avg:90.13ms +step:1642/1680 train_time:147998ms step_avg:90.13ms +step:1643/1680 train_time:148088ms 
step_avg:90.13ms +step:1644/1680 train_time:148178ms step_avg:90.13ms +step:1645/1680 train_time:148268ms step_avg:90.13ms +step:1646/1680 train_time:148359ms step_avg:90.13ms +step:1647/1680 train_time:148449ms step_avg:90.13ms +step:1648/1680 train_time:148541ms step_avg:90.13ms +step:1649/1680 train_time:148633ms step_avg:90.14ms +step:1650/1680 train_time:148726ms step_avg:90.14ms +step:1651/1680 train_time:148817ms step_avg:90.14ms +step:1652/1680 train_time:148909ms step_avg:90.14ms +step:1653/1680 train_time:148999ms step_avg:90.14ms +step:1654/1680 train_time:149089ms step_avg:90.14ms +step:1655/1680 train_time:149180ms step_avg:90.14ms +step:1656/1680 train_time:149270ms step_avg:90.14ms +step:1657/1680 train_time:149360ms step_avg:90.14ms +step:1658/1680 train_time:149451ms step_avg:90.14ms +step:1659/1680 train_time:149542ms step_avg:90.14ms +step:1660/1680 train_time:149635ms step_avg:90.14ms +step:1661/1680 train_time:149727ms step_avg:90.14ms +step:1662/1680 train_time:149820ms step_avg:90.14ms +step:1663/1680 train_time:149911ms step_avg:90.15ms +step:1664/1680 train_time:150002ms step_avg:90.15ms +step:1665/1680 train_time:150092ms step_avg:90.15ms +step:1666/1680 train_time:150183ms step_avg:90.15ms +step:1667/1680 train_time:150272ms step_avg:90.15ms +step:1668/1680 train_time:150363ms step_avg:90.15ms +step:1669/1680 train_time:150453ms step_avg:90.15ms +step:1670/1680 train_time:150544ms step_avg:90.15ms +step:1671/1680 train_time:150635ms step_avg:90.15ms +step:1672/1680 train_time:150727ms step_avg:90.15ms +step:1673/1680 train_time:150819ms step_avg:90.15ms +step:1674/1680 train_time:150911ms step_avg:90.15ms +step:1675/1680 train_time:151002ms step_avg:90.15ms +step:1676/1680 train_time:151092ms step_avg:90.15ms +step:1677/1680 train_time:151183ms step_avg:90.15ms +step:1678/1680 train_time:151273ms step_avg:90.15ms +step:1679/1680 train_time:151364ms step_avg:90.15ms +step:1680/1680 train_time:151454ms step_avg:90.15ms +step:1680/1680 val_loss:3.2774 train_time:151546ms step_avg:90.21ms +peak memory allocated: 31255 MiB reserved: 46194 MiB
diff --git a/records/092125_DropAttn/README.md b/records/092125_DropAttn/README.md
new file mode 100644
index 000000000..4b7230c3a
--- /dev/null
+++ b/records/092125_DropAttn/README.md
@@ -0,0 +1,44 @@
+## New WR 151.5s: Drop first attn layer, extend all long windows for validation, update schedule
+
+This PR builds on all recent WR improvements including PR #130. Updates:
+* Drop the first attention layer.
+* Increase the step count from 1645 to 1680.
+* Extend all long windows to size 20 for validation (-0.001 loss, or ~10 steps = 1s).
+* Add arg iteration_extension to specify the number of steps to continue training at the final lr and ws.
+
+Several factors led to dropping the first attention layer:
+* The first attention layer was [observed](https://medium.com/@larry36d/formation-of-induction-heads-in-modded-nanogpt-5eb899de89e4) to make no meaningful contribution to induction heads.
+* PR #120 dropped the first MLP, which resulted in two attention layers with no intermediate transformation.
+* PR #130 added a smear module, which increased the network's ability to pass information between tokens.
+
+Reasons for iteration_extension:
+* Showed a 0.001 improvement in loss.
+* Easier to fine-tune. Without this parameter, changes to the step count have a large effect on the entire second half of training, making it harder to isolate the impact of changes. This is because lr and ws are tied to the step count. E.g., if you increase the step count by 10, you will suddenly see a different loss at step 1000. A minimal sketch of the mechanism is shown below.
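+
+The sketch below illustrates the clamping idea only; the function and argument names (`base_iterations`, `schedule_frac`) and the cooldown shape are hypothetical, not the exact code in this PR. The point is that schedules stay keyed to the base step count, so the extra steps reuse the final lr and ws.
+```
+# Hypothetical sketch: `base_iterations` is the original step count; the run
+# loops for base_iterations + iteration_extension steps in total.
+def schedule_frac(step: int, base_iterations: int) -> float:
+    # Steps past base_iterations all map to the end of the schedule,
+    # so lr and window size hold their final values.
+    return min(step / base_iterations, 1.0)
+
+def get_lr(step: int, base_iterations: int, cooldown_frac: float = 0.4) -> float:
+    # Illustrative linear cooldown with a small floor (the real schedule may
+    # differ); the floor keeps extension steps training at a nonzero lr.
+    x = schedule_frac(step, base_iterations)
+    return 1.0 if x < 1 - cooldown_frac else max((1 - x) / cooldown_frac, 0.1)
+```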
+
+### Future Opportunities
+This change brings the total number of [4,768,768] attention variables to 10. There are 22 MLP variables of size [768x4,768]. In Muon, attention is batched such that 6/16ths of the gradient calcs are on padding tokens. There may be a way to move 2 of the attention variables into the MLP batch, such that MLP is 24/24 and attn is 8/8, instead of MLP being 22/24 and attn being 10/16.
+
+### Investigating Muon for 1D variables
+Currently the attention gates and smear gate are passed into Muon. From light inspection, the implementation of Newton-Schulz appears to roughly apply F.normalize(x, p=2, dim=-1) for 1D variables. This normalization makes all steps cover roughly the same distance, regardless of the gradient. So for 1D variables Muon turns into an exponential smoothing over prior gradients, where each step is normalized to be roughly the same size. This seems somewhat reasonable. Swapping these variables over to Adam gave roughly a 0.5s runtime increase and no improvement in loss. Directly replacing Newton-Schulz with F.normalize(x, p=2, dim=-1) for these variables showed slightly worse performance. I do not understand the theory here yet, but empirically the performance is good.
+
+
+## Validation:
+Code syntax/naming was lightly refactored after performing the validation runs. Loss is roughly 0.001 lower than the prior record, which is roughly equal to 1s.
+```
+import scipy.stats
+import torch
+
+accs = [3.2786, 3.2798, 3.2762, 3.2781, 3.2778, 3.2801, 3.2774, 3.2772,
+        3.2777, 3.2789]
+
+times = [151.559, 151.526, 151.516, 151.527, 151.606, 151.771, 151.546,
+         151.547, 151.44, 151.872]
+
+print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue)
+# p=0.0005
+print("acc:", torch.std_mean(torch.tensor(accs)))
+# acc: (tensor(0.0012), tensor(3.2782))
+
+print("time:", torch.std_mean(torch.tensor(times)))
+# time: (tensor(0.1305), tensor(151.5910))
+```
\ No newline at end of file
diff --git a/records/092125_DropAttn/a7f9849e-8c31-4e1a-9149-ab466d7c80b6.txt b/records/092125_DropAttn/a7f9849e-8c31-4e1a-9149-ab466d7c80b6.txt
new file mode 100644
index 000000000..6e296080c
--- /dev/null
+++ b/records/092125_DropAttn/a7f9849e-8c31-4e1a-9149-ab466d7c80b6.txt
@@ -0,0 +1,3138 @@
+import os
+import sys
+
+with open(sys.argv[0]) as f:
+    code = f.read()  # read the code of this file ASAP, for logging
+import copy
+import glob
+import math
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from itertools import accumulate
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+
+torch.empty(
+    1, device="cuda", requires_grad=True
+).backward()  # prevents a bug on some systems
+import torch._dynamo as dynamo
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+from torch import Tensor, nn
+
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: +
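+    # FP8 fast path: cast x and w to float8_e4m3fn using the fixed scales
+    # x_s and w_s, multiply with torch._scaled_mm (fast accumulation), and
+    # return the bfloat16 product together with the fp8 casts so that the
+    # registered backward can reuse them without re-quantizing.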
@torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n 
* BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: 
tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + 
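+# Sanity-check sketch for the two kernels above (illustrative only, never called during
+# the run; assumes a CUDA device). Tolerances are loose because inputs, outputs, and the
+# PyTorch reference matmuls are all bfloat16.
+def _check_ns_kernels():
+    A = torch.randn(4, 192, 256, device="cuda", dtype=torch.bfloat16)
+    C = torch.empty(4, 192, 192, device="cuda", dtype=torch.bfloat16)
+    ns_line_1(A, out=C) # C = A @ A.T
+    torch.testing.assert_close(C, A @ A.mT, rtol=0.05, atol=0.5)
+    S = torch.empty_like(C)
+    ns_line_2(C, alpha=2.0, beta=-1.0, out=S) # S = 2 * C @ C.T - C
+    torch.testing.assert_close(S, 2 * C @ C.mT - C, rtol=0.05, atol=5.0)
+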
+@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
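+        # The step runs in three phases so that communication overlaps compute:
+        #   1) reduce-scatter the stacked gradients of each group, leaving each rank one shard,
+        #   2) momentum + Newton-Schulz orthogonalization computed only on the local shard,
+        #   3) all-gather the updated shards back into the full parameter tensors.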
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
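+                # Nesterov-style momentum: the buffer tracks the running average, and the
+                # update re-blends the fresh gradient (grad.lerp(buf, m) == (1 - m) * grad + m * buf).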
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
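+        # the first 12 dims of each token embedding drive a per-token sigmoid gate that
+        # smears the previous token's embedding forward (see forward()); zero-init weights
+        # plus the zero-init smear_lambda scalar keep the smear path off at the start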
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
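+        # batches served so far; after 5 batches next_batch() swaps in the full async BOS index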
        self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40 # extra steps run after the LR schedule ends
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
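+
+# Usage sketch for distributed_data_generator (illustrative, not executed here): each
+# next() yields (inputs, targets, cum_seqlens) for this rank, and the generator can be
+# re-parameterized mid-stream via .send(), which returns the next batch, e.g.
+#
+#   loader = distributed_data_generator(args.train_files, args.train_batch_size,
+#                                       args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+#   inputs, targets, cum_seqlens = next(loader)
+#   inputs, targets, cum_seqlens = loader.send((2 * args.train_batch_size,
+#                                               2 * args.train_max_seq_len, grad_accum_steps))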
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.reset() # back to the base RoPE params for the smallest window
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# re-initialize the state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+model.yarn.reset()
+del initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+ws = args.ws_schedule[0]
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sun Sep 21 23:28:47 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 40C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 87713 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 87714 C /usr/bin/python3 614MiB | +| 0 N/A N/A 87715 C /usr/bin/python3 614MiB | +| 0 N/A N/A 87716 C /usr/bin/python3 614MiB | +| 0 N/A N/A 87717 C /usr/bin/python3 614MiB | +| 0 N/A N/A 87718 C /usr/bin/python3 614MiB | +| 0 N/A N/A 87719 C /usr/bin/python3 614MiB | +| 0 N/A N/A 87720 C /usr/bin/python3 614MiB | +| 1 N/A N/A 87714 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 87715 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 87716 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 87717 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 87718 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 87719 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 87720 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:158ms step_avg:157.91ms +step:2/1680 train_time:181ms step_avg:90.57ms +step:3/1680 train_time:243ms step_avg:81.05ms +step:4/1680 train_time:330ms step_avg:82.50ms +step:5/1680 train_time:418ms step_avg:83.59ms +step:6/1680 train_time:506ms step_avg:84.30ms +step:7/1680 train_time:594ms step_avg:84.86ms +step:8/1680 train_time:682ms 
step_avg:85.28ms +step:9/1680 train_time:772ms step_avg:85.73ms +step:10/1680 train_time:860ms step_avg:85.98ms +step:11/1680 train_time:948ms step_avg:86.22ms +step:12/1680 train_time:1037ms step_avg:86.43ms +step:13/1680 train_time:1129ms step_avg:86.88ms +step:14/1680 train_time:1221ms step_avg:87.21ms +step:15/1680 train_time:1311ms step_avg:87.40ms +step:16/1680 train_time:1400ms step_avg:87.52ms +step:17/1680 train_time:1489ms step_avg:87.60ms +step:18/1680 train_time:1578ms step_avg:87.65ms +step:19/1680 train_time:1667ms step_avg:87.73ms +step:20/1680 train_time:1756ms step_avg:87.79ms +step:21/1680 train_time:1844ms step_avg:87.82ms +step:22/1680 train_time:1933ms step_avg:87.88ms +step:23/1680 train_time:2023ms step_avg:87.95ms +step:24/1680 train_time:2113ms step_avg:88.04ms +step:25/1680 train_time:2203ms step_avg:88.10ms +step:26/1680 train_time:2293ms step_avg:88.17ms +step:27/1680 train_time:2382ms step_avg:88.22ms +step:28/1680 train_time:2472ms step_avg:88.27ms +step:29/1680 train_time:2561ms step_avg:88.29ms +step:30/1680 train_time:2650ms step_avg:88.32ms +step:31/1680 train_time:2739ms step_avg:88.35ms +step:32/1680 train_time:2828ms step_avg:88.38ms +step:33/1680 train_time:2917ms step_avg:88.38ms +step:34/1680 train_time:3005ms step_avg:88.39ms +step:35/1680 train_time:3094ms step_avg:88.41ms +step:36/1680 train_time:3184ms step_avg:88.45ms +step:37/1680 train_time:3274ms step_avg:88.50ms +step:38/1680 train_time:3364ms step_avg:88.52ms +step:39/1680 train_time:3453ms step_avg:88.55ms +step:40/1680 train_time:3543ms step_avg:88.57ms +step:41/1680 train_time:3631ms step_avg:88.57ms +step:42/1680 train_time:3720ms step_avg:88.57ms +step:43/1680 train_time:3810ms step_avg:88.61ms +step:44/1680 train_time:3899ms step_avg:88.61ms +step:45/1680 train_time:3988ms step_avg:88.62ms +step:46/1680 train_time:4077ms step_avg:88.63ms +step:47/1680 train_time:4166ms step_avg:88.64ms +step:48/1680 train_time:4255ms step_avg:88.64ms +step:49/1680 train_time:4344ms step_avg:88.65ms +step:50/1680 train_time:4433ms step_avg:88.67ms +step:51/1680 train_time:4522ms step_avg:88.67ms +step:52/1680 train_time:4611ms step_avg:88.67ms +step:53/1680 train_time:4700ms step_avg:88.68ms +step:54/1680 train_time:4790ms step_avg:88.71ms +step:55/1680 train_time:4879ms step_avg:88.71ms +step:56/1680 train_time:4969ms step_avg:88.73ms +step:57/1680 train_time:5057ms step_avg:88.72ms +step:58/1680 train_time:5147ms step_avg:88.75ms +step:59/1680 train_time:5236ms step_avg:88.75ms +step:60/1680 train_time:5326ms step_avg:88.76ms +step:61/1680 train_time:5415ms step_avg:88.78ms +step:62/1680 train_time:5504ms step_avg:88.78ms +step:63/1680 train_time:5593ms step_avg:88.78ms +step:64/1680 train_time:5682ms step_avg:88.79ms +step:65/1680 train_time:5772ms step_avg:88.80ms +step:66/1680 train_time:5861ms step_avg:88.81ms +step:67/1680 train_time:5951ms step_avg:88.82ms +step:68/1680 train_time:6040ms step_avg:88.82ms +step:69/1680 train_time:6129ms step_avg:88.83ms +step:70/1680 train_time:6218ms step_avg:88.83ms +step:71/1680 train_time:6309ms step_avg:88.86ms +step:72/1680 train_time:6398ms step_avg:88.86ms +step:73/1680 train_time:6487ms step_avg:88.86ms +step:74/1680 train_time:6576ms step_avg:88.86ms +step:75/1680 train_time:6665ms step_avg:88.86ms +step:76/1680 train_time:6754ms step_avg:88.87ms +step:77/1680 train_time:6843ms step_avg:88.87ms +step:78/1680 train_time:6932ms step_avg:88.88ms +step:79/1680 train_time:7022ms step_avg:88.88ms +step:80/1680 train_time:7111ms step_avg:88.89ms +step:81/1680 
train_time:7199ms step_avg:88.88ms
+step:82/1680 train_time:7289ms step_avg:88.88ms
+[steps 83-124/1680: per-step train_time/step_avg entries, step_avg ~88.9ms]
+step:125/1680 train_time:11117ms step_avg:88.94ms
+step:125/1680 val_loss:4.3124 train_time:11208ms step_avg:89.66ms
+step:126/1680 train_time:11230ms step_avg:89.13ms
+[steps 127-249/1680: per-step train_time/step_avg entries, step_avg ~89.0ms]
+step:250/1680 train_time:22249ms step_avg:89.00ms
+step:250/1680 val_loss:3.9699 train_time:22339ms step_avg:89.36ms
+step:251/1680 train_time:22361ms step_avg:89.09ms
+[steps 252-374/1680: per-step train_time/step_avg entries, step_avg ~89.0ms]
+step:375/1680 train_time:33375ms step_avg:89.00ms
+step:375/1680 val_loss:3.8123 train_time:33465ms step_avg:89.24ms
+step:376/1680 train_time:33487ms step_avg:89.06ms
+[steps 377-499/1680: per-step train_time/step_avg entries, step_avg ~89.0ms]
+step:500/1680 train_time:44515ms step_avg:89.03ms
+step:500/1680 val_loss:3.7144 train_time:44605ms step_avg:89.21ms
+step:501/1680 train_time:44627ms step_avg:89.08ms
+[steps 502-624/1680: per-step train_time/step_avg entries, step_avg ~89.1ms]
+step:625/1680 train_time:55760ms step_avg:89.22ms
+step:625/1680 val_loss:3.6145 train_time:55852ms step_avg:89.36ms
+step:626/1680 train_time:55874ms step_avg:89.26ms
+[steps 627-749/1680: per-step train_time/step_avg entries, step_avg ~89.3ms]
+step:750/1680 train_time:67055ms step_avg:89.41ms
+step:750/1680 val_loss:3.5620 train_time:67148ms step_avg:89.53ms
+step:751/1680 train_time:67171ms step_avg:89.44ms
+[steps 752-874/1680: per-step train_time/step_avg entries, step_avg ~89.5ms]
+step:875/1680 train_time:78351ms step_avg:89.54ms
+step:875/1680 val_loss:3.5163 train_time:78443ms step_avg:89.65ms
+step:876/1680 train_time:78465ms step_avg:89.57ms
+[steps 877-999/1680: per-step train_time/step_avg entries, step_avg ~89.6ms]
+step:1000/1680 train_time:89638ms step_avg:89.64ms
+step:1000/1680 val_loss:3.4684 train_time:89730ms step_avg:89.73ms
+step:1001/1680 train_time:89752ms step_avg:89.66ms
+[steps 1002-1124/1680: per-step train_time/step_avg entries, step_avg ~89.7ms]
+step:1125/1680 train_time:100948ms step_avg:89.73ms
+step:1125/1680 val_loss:3.4148 train_time:101040ms step_avg:89.81ms
+step:1126/1680 train_time:101062ms step_avg:89.75ms
+[steps 1127-1249/1680: per-step train_time/step_avg entries, step_avg ~89.8ms]
+step:1250/1680 train_time:112319ms step_avg:89.86ms
+step:1250/1680 val_loss:3.3764 train_time:112412ms step_avg:89.93ms
+step:1251/1680 train_time:112434ms step_avg:89.88ms
+[steps 1252-1374/1680: per-step train_time/step_avg entries, step_avg ~89.9ms]
+step:1375/1680 train_time:123694ms step_avg:89.96ms
+step:1375/1680 val_loss:3.3412 train_time:123787ms step_avg:90.03ms
+step:1376/1680 train_time:123809ms step_avg:89.98ms
+[steps 1377-1499/1680: per-step train_time/step_avg entries, step_avg ~90.0ms]
+step:1500/1680 train_time:135070ms step_avg:90.05ms
+step:1500/1680 val_loss:3.3117 train_time:135161ms step_avg:90.11ms
+step:1501/1680 train_time:135183ms step_avg:90.06ms
+step:1502/1680 train_time:135254ms step_avg:90.05ms
+step:1503/1680 train_time:135352ms step_avg:90.05ms
+step:1504/1680 train_time:135444ms step_avg:90.06ms
+step:1505/1680 train_time:135533ms step_avg:90.06ms
+step:1506/1680 train_time:135623ms step_avg:90.05ms
+step:1507/1680 train_time:135712ms step_avg:90.05ms
+step:1508/1680 train_time:135803ms step_avg:90.05ms
+step:1509/1680 train_time:135892ms step_avg:90.05ms
+step:1510/1680 train_time:135983ms step_avg:90.06ms
+step:1511/1680 train_time:136073ms step_avg:90.05ms +step:1512/1680 train_time:136165ms step_avg:90.06ms +step:1513/1680 train_time:136259ms step_avg:90.06ms +step:1514/1680 train_time:136352ms step_avg:90.06ms +step:1515/1680 train_time:136444ms step_avg:90.06ms +step:1516/1680 train_time:136534ms step_avg:90.06ms +step:1517/1680 train_time:136624ms step_avg:90.06ms +step:1518/1680 train_time:136714ms step_avg:90.06ms +step:1519/1680 train_time:136805ms step_avg:90.06ms +step:1520/1680 train_time:136896ms step_avg:90.06ms +step:1521/1680 train_time:136986ms step_avg:90.06ms +step:1522/1680 train_time:137077ms step_avg:90.06ms +step:1523/1680 train_time:137169ms step_avg:90.06ms +step:1524/1680 train_time:137261ms step_avg:90.07ms +step:1525/1680 train_time:137352ms step_avg:90.07ms +step:1526/1680 train_time:137444ms step_avg:90.07ms +step:1527/1680 train_time:137535ms step_avg:90.07ms +step:1528/1680 train_time:137625ms step_avg:90.07ms +step:1529/1680 train_time:137715ms step_avg:90.07ms +step:1530/1680 train_time:137806ms step_avg:90.07ms +step:1531/1680 train_time:137897ms step_avg:90.07ms +step:1532/1680 train_time:137988ms step_avg:90.07ms +step:1533/1680 train_time:138079ms step_avg:90.07ms +step:1534/1680 train_time:138170ms step_avg:90.07ms +step:1535/1680 train_time:138262ms step_avg:90.07ms +step:1536/1680 train_time:138353ms step_avg:90.07ms +step:1537/1680 train_time:138445ms step_avg:90.08ms +step:1538/1680 train_time:138536ms step_avg:90.08ms +step:1539/1680 train_time:138627ms step_avg:90.08ms +step:1540/1680 train_time:138716ms step_avg:90.08ms +step:1541/1680 train_time:138807ms step_avg:90.08ms +step:1542/1680 train_time:138897ms step_avg:90.08ms +step:1543/1680 train_time:138989ms step_avg:90.08ms +step:1544/1680 train_time:139080ms step_avg:90.08ms +step:1545/1680 train_time:139171ms step_avg:90.08ms +step:1546/1680 train_time:139262ms step_avg:90.08ms +step:1547/1680 train_time:139352ms step_avg:90.08ms +step:1548/1680 train_time:139444ms step_avg:90.08ms +step:1549/1680 train_time:139535ms step_avg:90.08ms +step:1550/1680 train_time:139625ms step_avg:90.08ms +step:1551/1680 train_time:139715ms step_avg:90.08ms +step:1552/1680 train_time:139806ms step_avg:90.08ms +step:1553/1680 train_time:139896ms step_avg:90.08ms +step:1554/1680 train_time:139988ms step_avg:90.08ms +step:1555/1680 train_time:140080ms step_avg:90.08ms +step:1556/1680 train_time:140170ms step_avg:90.08ms +step:1557/1680 train_time:140261ms step_avg:90.08ms +step:1558/1680 train_time:140352ms step_avg:90.08ms +step:1559/1680 train_time:140444ms step_avg:90.09ms +step:1560/1680 train_time:140534ms step_avg:90.09ms +step:1561/1680 train_time:140626ms step_avg:90.09ms +step:1562/1680 train_time:140716ms step_avg:90.09ms +step:1563/1680 train_time:140808ms step_avg:90.09ms +step:1564/1680 train_time:140898ms step_avg:90.09ms +step:1565/1680 train_time:140989ms step_avg:90.09ms +step:1566/1680 train_time:141079ms step_avg:90.09ms +step:1567/1680 train_time:141170ms step_avg:90.09ms +step:1568/1680 train_time:141261ms step_avg:90.09ms +step:1569/1680 train_time:141352ms step_avg:90.09ms +step:1570/1680 train_time:141444ms step_avg:90.09ms +step:1571/1680 train_time:141534ms step_avg:90.09ms +step:1572/1680 train_time:141625ms step_avg:90.09ms +step:1573/1680 train_time:141715ms step_avg:90.09ms +step:1574/1680 train_time:141807ms step_avg:90.09ms +step:1575/1680 train_time:141899ms step_avg:90.09ms +step:1576/1680 train_time:141989ms step_avg:90.09ms +step:1577/1680 train_time:142080ms step_avg:90.10ms 
+step:1578/1680 train_time:142171ms step_avg:90.10ms +step:1579/1680 train_time:142262ms step_avg:90.10ms +step:1580/1680 train_time:142353ms step_avg:90.10ms +step:1581/1680 train_time:142444ms step_avg:90.10ms +step:1582/1680 train_time:142535ms step_avg:90.10ms +step:1583/1680 train_time:142626ms step_avg:90.10ms +step:1584/1680 train_time:142716ms step_avg:90.10ms +step:1585/1680 train_time:142808ms step_avg:90.10ms +step:1586/1680 train_time:142899ms step_avg:90.10ms +step:1587/1680 train_time:142990ms step_avg:90.10ms +step:1588/1680 train_time:143081ms step_avg:90.10ms +step:1589/1680 train_time:143172ms step_avg:90.10ms +step:1590/1680 train_time:143262ms step_avg:90.10ms +step:1591/1680 train_time:143353ms step_avg:90.10ms +step:1592/1680 train_time:143443ms step_avg:90.10ms +step:1593/1680 train_time:143533ms step_avg:90.10ms +step:1594/1680 train_time:143624ms step_avg:90.10ms +step:1595/1680 train_time:143714ms step_avg:90.10ms +step:1596/1680 train_time:143805ms step_avg:90.10ms +step:1597/1680 train_time:143896ms step_avg:90.10ms +step:1598/1680 train_time:143987ms step_avg:90.10ms +step:1599/1680 train_time:144079ms step_avg:90.11ms +step:1600/1680 train_time:144171ms step_avg:90.11ms +step:1601/1680 train_time:144262ms step_avg:90.11ms +step:1602/1680 train_time:144352ms step_avg:90.11ms +step:1603/1680 train_time:144443ms step_avg:90.11ms +step:1604/1680 train_time:144534ms step_avg:90.11ms +step:1605/1680 train_time:144625ms step_avg:90.11ms +step:1606/1680 train_time:144716ms step_avg:90.11ms +step:1607/1680 train_time:144806ms step_avg:90.11ms +step:1608/1680 train_time:144898ms step_avg:90.11ms +step:1609/1680 train_time:144989ms step_avg:90.11ms +step:1610/1680 train_time:145081ms step_avg:90.11ms +step:1611/1680 train_time:145172ms step_avg:90.11ms +step:1612/1680 train_time:145263ms step_avg:90.11ms +step:1613/1680 train_time:145354ms step_avg:90.11ms +step:1614/1680 train_time:145445ms step_avg:90.11ms +step:1615/1680 train_time:145535ms step_avg:90.11ms +step:1616/1680 train_time:145627ms step_avg:90.12ms +step:1617/1680 train_time:145717ms step_avg:90.12ms +step:1618/1680 train_time:145809ms step_avg:90.12ms +step:1619/1680 train_time:145901ms step_avg:90.12ms +step:1620/1680 train_time:145991ms step_avg:90.12ms +step:1621/1680 train_time:146084ms step_avg:90.12ms +step:1622/1680 train_time:146174ms step_avg:90.12ms +step:1623/1680 train_time:146265ms step_avg:90.12ms +step:1624/1680 train_time:146355ms step_avg:90.12ms +step:1625/1680 train_time:146446ms step_avg:90.12ms +step:1625/1680 val_loss:3.2878 train_time:146537ms step_avg:90.18ms +step:1626/1680 train_time:146559ms step_avg:90.13ms +step:1627/1680 train_time:146632ms step_avg:90.12ms +step:1628/1680 train_time:146729ms step_avg:90.13ms +step:1629/1680 train_time:146822ms step_avg:90.13ms +step:1630/1680 train_time:146912ms step_avg:90.13ms +step:1631/1680 train_time:147002ms step_avg:90.13ms +step:1632/1680 train_time:147091ms step_avg:90.13ms +step:1633/1680 train_time:147181ms step_avg:90.13ms +step:1634/1680 train_time:147270ms step_avg:90.13ms +step:1635/1680 train_time:147360ms step_avg:90.13ms +step:1636/1680 train_time:147450ms step_avg:90.13ms +step:1637/1680 train_time:147542ms step_avg:90.13ms +step:1638/1680 train_time:147636ms step_avg:90.13ms +step:1639/1680 train_time:147728ms step_avg:90.13ms +step:1640/1680 train_time:147820ms step_avg:90.13ms +step:1641/1680 train_time:147911ms step_avg:90.13ms +step:1642/1680 train_time:148002ms step_avg:90.14ms +step:1643/1680 train_time:148092ms 
step_avg:90.14ms +step:1644/1680 train_time:148182ms step_avg:90.14ms +step:1645/1680 train_time:148271ms step_avg:90.13ms +step:1646/1680 train_time:148361ms step_avg:90.13ms +step:1647/1680 train_time:148451ms step_avg:90.13ms +step:1648/1680 train_time:148543ms step_avg:90.14ms +step:1649/1680 train_time:148636ms step_avg:90.14ms +step:1650/1680 train_time:148728ms step_avg:90.14ms +step:1651/1680 train_time:148820ms step_avg:90.14ms +step:1652/1680 train_time:148911ms step_avg:90.14ms +step:1653/1680 train_time:149002ms step_avg:90.14ms +step:1654/1680 train_time:149092ms step_avg:90.14ms +step:1655/1680 train_time:149183ms step_avg:90.14ms +step:1656/1680 train_time:149273ms step_avg:90.14ms +step:1657/1680 train_time:149363ms step_avg:90.14ms +step:1658/1680 train_time:149454ms step_avg:90.14ms +step:1659/1680 train_time:149546ms step_avg:90.14ms +step:1660/1680 train_time:149637ms step_avg:90.14ms +step:1661/1680 train_time:149729ms step_avg:90.14ms +step:1662/1680 train_time:149821ms step_avg:90.14ms +step:1663/1680 train_time:149911ms step_avg:90.15ms +step:1664/1680 train_time:150003ms step_avg:90.15ms +step:1665/1680 train_time:150093ms step_avg:90.15ms +step:1666/1680 train_time:150183ms step_avg:90.15ms +step:1667/1680 train_time:150273ms step_avg:90.15ms +step:1668/1680 train_time:150364ms step_avg:90.15ms +step:1669/1680 train_time:150453ms step_avg:90.15ms +step:1670/1680 train_time:150544ms step_avg:90.15ms +step:1671/1680 train_time:150636ms step_avg:90.15ms +step:1672/1680 train_time:150727ms step_avg:90.15ms +step:1673/1680 train_time:150819ms step_avg:90.15ms +step:1674/1680 train_time:150911ms step_avg:90.15ms +step:1675/1680 train_time:151003ms step_avg:90.15ms +step:1676/1680 train_time:151093ms step_avg:90.15ms +step:1677/1680 train_time:151183ms step_avg:90.15ms +step:1678/1680 train_time:151273ms step_avg:90.15ms +step:1679/1680 train_time:151364ms step_avg:90.15ms +step:1680/1680 train_time:151454ms step_avg:90.15ms +step:1680/1680 val_loss:3.2772 train_time:151547ms step_avg:90.21ms +peak memory allocated: 31255 MiB reserved: 46494 MiB diff --git a/records/092125_DropAttn/ab8c620e-3d52-42eb-b46e-d69b608b22bc.txt b/records/092125_DropAttn/ab8c620e-3d52-42eb-b46e-d69b608b22bc.txt new file mode 100644 index 000000000..5d033e0c7 --- /dev/null +++ b/records/092125_DropAttn/ab8c620e-3d52-42eb-b46e-d69b608b22bc.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
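+        # A conceptual summary (editorial sketch, not code from the original run):
+        # per 2D parameter this step amounts to
+        #     buf    = momentum * buf + (1 - momentum) * grad            # the lerp_ below
+        #     update = newton_schulz_triton(grad.lerp(buf, momentum))    # orthogonalize
+        #     p      = (1 - lr * wd) * p - lr * max(1, rows / cols) ** 0.5 * update
+        # (with the per-parameter lr_mul / wd_mul multipliers folded in); everything
+        # below is the sharded, batched, communication-overlapped version of this.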
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[
+                    i
+                ]  # This gradient corresponds to the current parameter p.
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
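+        # (Step pipeline recap, added for orientation: phase 1 launched async
+        # reduce_scatter of the stacked grads; phase 2 waited per group, ran batched
+        # Newton-Schulz on this rank's shard, and launched async all_gather; this
+        # final phase waits on the gathers and writes the parameters back.)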
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
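+# (Worked example of the world_size / grad_accum setup above, for illustration:
+# WORLD_SIZE=8 gives grad_accum_steps=1, a single fwd/bwd per optimizer step;
+# WORLD_SIZE=4 gives grad_accum_steps=2, so each step accumulates two half-size
+# micro-batches and the total tokens per optimizer step stays the same.)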
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+# re-initialize the state saved above so the warmup steps aren't counted as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+
+########################################
+#      Training and validation         #
+########################################
+
+train_steps = args.num_iterations + args.iteration_extension
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws, ws_final_layer = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sun Sep 21 23:17:28 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 29C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 35C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 82769 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 82770 C /usr/bin/python3 614MiB | +| 0 N/A N/A 82771 C /usr/bin/python3 614MiB | +| 0 N/A N/A 82772 C /usr/bin/python3 614MiB | +| 0 N/A N/A 82773 C /usr/bin/python3 614MiB | +| 0 N/A N/A 82774 C /usr/bin/python3 614MiB | +| 0 N/A N/A 82775 C /usr/bin/python3 614MiB | +| 0 N/A N/A 82776 C /usr/bin/python3 614MiB | +| 1 N/A N/A 82770 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 82771 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 82772 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 82773 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 82774 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 82775 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 82776 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:151ms step_avg:151.32ms +step:2/1680 train_time:176ms step_avg:88.00ms +step:3/1680 train_time:237ms step_avg:78.91ms +step:4/1680 train_time:324ms step_avg:80.91ms +step:5/1680 train_time:412ms step_avg:82.36ms +step:6/1680 train_time:500ms step_avg:83.26ms +step:7/1680 train_time:587ms step_avg:83.91ms +step:8/1680 train_time:676ms 
step_avg:84.46ms +step:9/1680 train_time:764ms step_avg:84.85ms +step:10/1680 train_time:852ms step_avg:85.19ms +step:11/1680 train_time:940ms step_avg:85.47ms +step:12/1680 train_time:1031ms step_avg:85.89ms +step:13/1680 train_time:1122ms step_avg:86.33ms +step:14/1680 train_time:1213ms step_avg:86.66ms +step:15/1680 train_time:1302ms step_avg:86.83ms +step:16/1680 train_time:1392ms step_avg:86.98ms +step:17/1680 train_time:1480ms step_avg:87.07ms +step:18/1680 train_time:1568ms step_avg:87.13ms +step:19/1680 train_time:1656ms step_avg:87.18ms +step:20/1680 train_time:1744ms step_avg:87.22ms +step:21/1680 train_time:1833ms step_avg:87.28ms +step:22/1680 train_time:1922ms step_avg:87.35ms +step:23/1680 train_time:2011ms step_avg:87.42ms +step:24/1680 train_time:2101ms step_avg:87.54ms +step:25/1680 train_time:2193ms step_avg:87.70ms +step:26/1680 train_time:2282ms step_avg:87.77ms +step:27/1680 train_time:2372ms step_avg:87.84ms +step:28/1680 train_time:2460ms step_avg:87.87ms +step:29/1680 train_time:2549ms step_avg:87.91ms +step:30/1680 train_time:2638ms step_avg:87.94ms +step:31/1680 train_time:2727ms step_avg:87.96ms +step:32/1680 train_time:2815ms step_avg:87.98ms +step:33/1680 train_time:2904ms step_avg:87.99ms +step:34/1680 train_time:2993ms step_avg:88.03ms +step:35/1680 train_time:3083ms step_avg:88.09ms +step:36/1680 train_time:3173ms step_avg:88.15ms +step:37/1680 train_time:3263ms step_avg:88.20ms +step:38/1680 train_time:3353ms step_avg:88.25ms +step:39/1680 train_time:3443ms step_avg:88.27ms +step:40/1680 train_time:3532ms step_avg:88.31ms +step:41/1680 train_time:3622ms step_avg:88.34ms +step:42/1680 train_time:3711ms step_avg:88.36ms +step:43/1680 train_time:3801ms step_avg:88.39ms +step:44/1680 train_time:3889ms step_avg:88.39ms +step:45/1680 train_time:3978ms step_avg:88.41ms +step:46/1680 train_time:4068ms step_avg:88.43ms +step:47/1680 train_time:4157ms step_avg:88.46ms +step:48/1680 train_time:4247ms step_avg:88.48ms +step:49/1680 train_time:4336ms step_avg:88.50ms +step:50/1680 train_time:4426ms step_avg:88.51ms +step:51/1680 train_time:4516ms step_avg:88.55ms +step:52/1680 train_time:4605ms step_avg:88.56ms +step:53/1680 train_time:4695ms step_avg:88.58ms +step:54/1680 train_time:4783ms step_avg:88.57ms +step:55/1680 train_time:4872ms step_avg:88.58ms +step:56/1680 train_time:4961ms step_avg:88.59ms +step:57/1680 train_time:5050ms step_avg:88.59ms +step:58/1680 train_time:5139ms step_avg:88.60ms +step:59/1680 train_time:5229ms step_avg:88.62ms +step:60/1680 train_time:5318ms step_avg:88.63ms +step:61/1680 train_time:5407ms step_avg:88.64ms +step:62/1680 train_time:5497ms step_avg:88.66ms +step:63/1680 train_time:5586ms step_avg:88.66ms +step:64/1680 train_time:5674ms step_avg:88.66ms +step:65/1680 train_time:5763ms step_avg:88.66ms +step:66/1680 train_time:5852ms step_avg:88.66ms +step:67/1680 train_time:5940ms step_avg:88.66ms +step:68/1680 train_time:6030ms step_avg:88.67ms +step:69/1680 train_time:6119ms step_avg:88.68ms +step:70/1680 train_time:6208ms step_avg:88.69ms +step:71/1680 train_time:6298ms step_avg:88.71ms +step:72/1680 train_time:6387ms step_avg:88.71ms +step:73/1680 train_time:6476ms step_avg:88.71ms +step:74/1680 train_time:6565ms step_avg:88.72ms +step:75/1680 train_time:6656ms step_avg:88.74ms +step:76/1680 train_time:6745ms step_avg:88.75ms +step:77/1680 train_time:6834ms step_avg:88.76ms +step:78/1680 train_time:6923ms step_avg:88.75ms +step:79/1680 train_time:7012ms step_avg:88.76ms +step:80/1680 train_time:7101ms step_avg:88.76ms +step:81/1680 
train_time:7190ms step_avg:88.77ms +step:82/1680 train_time:7280ms step_avg:88.78ms +step:83/1680 train_time:7369ms step_avg:88.78ms +step:84/1680 train_time:7457ms step_avg:88.78ms +step:85/1680 train_time:7546ms step_avg:88.78ms +step:86/1680 train_time:7636ms step_avg:88.80ms +step:87/1680 train_time:7725ms step_avg:88.79ms +step:88/1680 train_time:7815ms step_avg:88.81ms +step:89/1680 train_time:7904ms step_avg:88.81ms +step:90/1680 train_time:7993ms step_avg:88.82ms +step:91/1680 train_time:8082ms step_avg:88.81ms +step:92/1680 train_time:8171ms step_avg:88.82ms +step:93/1680 train_time:8260ms step_avg:88.82ms +step:94/1680 train_time:8349ms step_avg:88.82ms +step:95/1680 train_time:8438ms step_avg:88.82ms +step:96/1680 train_time:8528ms step_avg:88.83ms +step:97/1680 train_time:8617ms step_avg:88.84ms +step:98/1680 train_time:8707ms step_avg:88.85ms +step:99/1680 train_time:8797ms step_avg:88.85ms +step:100/1680 train_time:8885ms step_avg:88.85ms +step:101/1680 train_time:8974ms step_avg:88.85ms +step:102/1680 train_time:9063ms step_avg:88.85ms +step:103/1680 train_time:9152ms step_avg:88.85ms +step:104/1680 train_time:9241ms step_avg:88.85ms +step:105/1680 train_time:9330ms step_avg:88.85ms +step:106/1680 train_time:9420ms step_avg:88.87ms +step:107/1680 train_time:9508ms step_avg:88.86ms +step:108/1680 train_time:9597ms step_avg:88.86ms +step:109/1680 train_time:9686ms step_avg:88.86ms +step:110/1680 train_time:9776ms step_avg:88.87ms +step:111/1680 train_time:9865ms step_avg:88.88ms +step:112/1680 train_time:9955ms step_avg:88.88ms +step:113/1680 train_time:10043ms step_avg:88.88ms +step:114/1680 train_time:10133ms step_avg:88.89ms +step:115/1680 train_time:10222ms step_avg:88.89ms +step:116/1680 train_time:10311ms step_avg:88.89ms +step:117/1680 train_time:10400ms step_avg:88.89ms +step:118/1680 train_time:10489ms step_avg:88.89ms +step:119/1680 train_time:10577ms step_avg:88.88ms +step:120/1680 train_time:10666ms step_avg:88.88ms +step:121/1680 train_time:10755ms step_avg:88.88ms +step:122/1680 train_time:10844ms step_avg:88.88ms +step:123/1680 train_time:10934ms step_avg:88.89ms +step:124/1680 train_time:11022ms step_avg:88.89ms +step:125/1680 train_time:11112ms step_avg:88.89ms +step:125/1680 val_loss:4.2962 train_time:11203ms step_avg:89.62ms +step:126/1680 train_time:11227ms step_avg:89.10ms +step:127/1680 train_time:11294ms step_avg:88.93ms +step:128/1680 train_time:11393ms step_avg:89.01ms +step:129/1680 train_time:11487ms step_avg:89.05ms +step:130/1680 train_time:11575ms step_avg:89.04ms +step:131/1680 train_time:11664ms step_avg:89.04ms +step:132/1680 train_time:11752ms step_avg:89.03ms +step:133/1680 train_time:11839ms step_avg:89.02ms +step:134/1680 train_time:11927ms step_avg:89.01ms +step:135/1680 train_time:12015ms step_avg:89.00ms +step:136/1680 train_time:12104ms step_avg:89.00ms +step:137/1680 train_time:12192ms step_avg:88.99ms +step:138/1680 train_time:12282ms step_avg:89.00ms +step:139/1680 train_time:12374ms step_avg:89.02ms +step:140/1680 train_time:12466ms step_avg:89.04ms +step:141/1680 train_time:12555ms step_avg:89.04ms +step:142/1680 train_time:12644ms step_avg:89.05ms +step:143/1680 train_time:12733ms step_avg:89.04ms +step:144/1680 train_time:12821ms step_avg:89.04ms +step:145/1680 train_time:12909ms step_avg:89.03ms +step:146/1680 train_time:12998ms step_avg:89.03ms +step:147/1680 train_time:13085ms step_avg:89.02ms +step:148/1680 train_time:13174ms step_avg:89.01ms +step:149/1680 train_time:13263ms step_avg:89.02ms +step:150/1680 train_time:13354ms 
step_avg:89.02ms +step:151/1680 train_time:13444ms step_avg:89.03ms +step:152/1680 train_time:13534ms step_avg:89.04ms +step:153/1680 train_time:13624ms step_avg:89.04ms +step:154/1680 train_time:13712ms step_avg:89.04ms +step:155/1680 train_time:13802ms step_avg:89.04ms +step:156/1680 train_time:13891ms step_avg:89.04ms +step:157/1680 train_time:13979ms step_avg:89.04ms +step:158/1680 train_time:14067ms step_avg:89.03ms +step:159/1680 train_time:14156ms step_avg:89.03ms +step:160/1680 train_time:14245ms step_avg:89.03ms +step:161/1680 train_time:14333ms step_avg:89.03ms +step:162/1680 train_time:14424ms step_avg:89.04ms +step:163/1680 train_time:14512ms step_avg:89.03ms +step:164/1680 train_time:14603ms step_avg:89.04ms +step:165/1680 train_time:14691ms step_avg:89.04ms +step:166/1680 train_time:14780ms step_avg:89.04ms +step:167/1680 train_time:14870ms step_avg:89.04ms +step:168/1680 train_time:14958ms step_avg:89.04ms +step:169/1680 train_time:15048ms step_avg:89.04ms +step:170/1680 train_time:15136ms step_avg:89.04ms +step:171/1680 train_time:15225ms step_avg:89.03ms +step:172/1680 train_time:15314ms step_avg:89.03ms +step:173/1680 train_time:15404ms step_avg:89.04ms +step:174/1680 train_time:15493ms step_avg:89.04ms +step:175/1680 train_time:15583ms step_avg:89.05ms +step:176/1680 train_time:15672ms step_avg:89.04ms +step:177/1680 train_time:15761ms step_avg:89.04ms +step:178/1680 train_time:15850ms step_avg:89.05ms +step:179/1680 train_time:15938ms step_avg:89.04ms +step:180/1680 train_time:16027ms step_avg:89.04ms +step:181/1680 train_time:16116ms step_avg:89.04ms +step:182/1680 train_time:16205ms step_avg:89.04ms +step:183/1680 train_time:16295ms step_avg:89.04ms +step:184/1680 train_time:16385ms step_avg:89.05ms +step:185/1680 train_time:16473ms step_avg:89.04ms +step:186/1680 train_time:16562ms step_avg:89.04ms +step:187/1680 train_time:16651ms step_avg:89.04ms +step:188/1680 train_time:16741ms step_avg:89.05ms +step:189/1680 train_time:16830ms step_avg:89.05ms +step:190/1680 train_time:16919ms step_avg:89.05ms +step:191/1680 train_time:17007ms step_avg:89.04ms +step:192/1680 train_time:17095ms step_avg:89.04ms +step:193/1680 train_time:17184ms step_avg:89.04ms +step:194/1680 train_time:17273ms step_avg:89.03ms +step:195/1680 train_time:17363ms step_avg:89.04ms +step:196/1680 train_time:17451ms step_avg:89.04ms +step:197/1680 train_time:17541ms step_avg:89.04ms +step:198/1680 train_time:17629ms step_avg:89.04ms +step:199/1680 train_time:17719ms step_avg:89.04ms +step:200/1680 train_time:17807ms step_avg:89.04ms +step:201/1680 train_time:17897ms step_avg:89.04ms +step:202/1680 train_time:17985ms step_avg:89.04ms +step:203/1680 train_time:18074ms step_avg:89.03ms +step:204/1680 train_time:18163ms step_avg:89.04ms +step:205/1680 train_time:18252ms step_avg:89.03ms +step:206/1680 train_time:18340ms step_avg:89.03ms +step:207/1680 train_time:18429ms step_avg:89.03ms +step:208/1680 train_time:18518ms step_avg:89.03ms +step:209/1680 train_time:18607ms step_avg:89.03ms +step:210/1680 train_time:18696ms step_avg:89.03ms +step:211/1680 train_time:18785ms step_avg:89.03ms +step:212/1680 train_time:18874ms step_avg:89.03ms +step:213/1680 train_time:18964ms step_avg:89.03ms +step:214/1680 train_time:19053ms step_avg:89.03ms +step:215/1680 train_time:19142ms step_avg:89.03ms +step:216/1680 train_time:19231ms step_avg:89.03ms +step:217/1680 train_time:19321ms step_avg:89.03ms +step:218/1680 train_time:19409ms step_avg:89.03ms +step:219/1680 train_time:19499ms step_avg:89.04ms +step:220/1680 
train_time:19588ms step_avg:89.04ms +step:221/1680 train_time:19676ms step_avg:89.03ms +step:222/1680 train_time:19765ms step_avg:89.03ms +step:223/1680 train_time:19853ms step_avg:89.03ms +step:224/1680 train_time:19943ms step_avg:89.03ms +step:225/1680 train_time:20031ms step_avg:89.03ms +step:226/1680 train_time:20121ms step_avg:89.03ms +step:227/1680 train_time:20210ms step_avg:89.03ms +step:228/1680 train_time:20299ms step_avg:89.03ms +step:229/1680 train_time:20388ms step_avg:89.03ms +step:230/1680 train_time:20477ms step_avg:89.03ms +step:231/1680 train_time:20566ms step_avg:89.03ms +step:232/1680 train_time:20655ms step_avg:89.03ms +step:233/1680 train_time:20743ms step_avg:89.03ms +step:234/1680 train_time:20833ms step_avg:89.03ms +step:235/1680 train_time:20923ms step_avg:89.03ms +step:236/1680 train_time:21011ms step_avg:89.03ms +step:237/1680 train_time:21100ms step_avg:89.03ms +step:238/1680 train_time:21189ms step_avg:89.03ms +step:239/1680 train_time:21279ms step_avg:89.03ms +step:240/1680 train_time:21368ms step_avg:89.03ms +step:241/1680 train_time:21457ms step_avg:89.03ms +step:242/1680 train_time:21545ms step_avg:89.03ms +step:243/1680 train_time:21634ms step_avg:89.03ms +step:244/1680 train_time:21724ms step_avg:89.03ms +step:245/1680 train_time:21814ms step_avg:89.04ms +step:246/1680 train_time:21904ms step_avg:89.04ms +step:247/1680 train_time:21993ms step_avg:89.04ms +step:248/1680 train_time:22081ms step_avg:89.04ms +step:249/1680 train_time:22170ms step_avg:89.04ms +step:250/1680 train_time:22259ms step_avg:89.04ms +step:250/1680 val_loss:3.9619 train_time:22351ms step_avg:89.40ms +step:251/1680 train_time:22374ms step_avg:89.14ms +step:252/1680 train_time:22444ms step_avg:89.06ms +step:253/1680 train_time:22543ms step_avg:89.10ms +step:254/1680 train_time:22634ms step_avg:89.11ms +step:255/1680 train_time:22723ms step_avg:89.11ms +step:256/1680 train_time:22811ms step_avg:89.10ms +step:257/1680 train_time:22899ms step_avg:89.10ms +step:258/1680 train_time:22987ms step_avg:89.10ms +step:259/1680 train_time:23075ms step_avg:89.09ms +step:260/1680 train_time:23164ms step_avg:89.09ms +step:261/1680 train_time:23252ms step_avg:89.09ms +step:262/1680 train_time:23341ms step_avg:89.09ms +step:263/1680 train_time:23431ms step_avg:89.09ms +step:264/1680 train_time:23524ms step_avg:89.10ms +step:265/1680 train_time:23615ms step_avg:89.11ms +step:266/1680 train_time:23705ms step_avg:89.12ms +step:267/1680 train_time:23795ms step_avg:89.12ms +step:268/1680 train_time:23883ms step_avg:89.12ms +step:269/1680 train_time:23971ms step_avg:89.11ms +step:270/1680 train_time:24059ms step_avg:89.11ms +step:271/1680 train_time:24147ms step_avg:89.10ms +step:272/1680 train_time:24235ms step_avg:89.10ms +step:273/1680 train_time:24324ms step_avg:89.10ms +step:274/1680 train_time:24412ms step_avg:89.10ms +step:275/1680 train_time:24503ms step_avg:89.10ms +step:276/1680 train_time:24593ms step_avg:89.11ms +step:277/1680 train_time:24683ms step_avg:89.11ms +step:278/1680 train_time:24773ms step_avg:89.11ms +step:279/1680 train_time:24863ms step_avg:89.11ms +step:280/1680 train_time:24952ms step_avg:89.11ms +step:281/1680 train_time:25041ms step_avg:89.11ms +step:282/1680 train_time:25129ms step_avg:89.11ms +step:283/1680 train_time:25218ms step_avg:89.11ms +step:284/1680 train_time:25307ms step_avg:89.11ms +step:285/1680 train_time:25396ms step_avg:89.11ms +step:286/1680 train_time:25486ms step_avg:89.11ms +step:287/1680 train_time:25575ms step_avg:89.11ms +step:288/1680 train_time:25665ms 
step_avg:89.11ms +step:289/1680 train_time:25755ms step_avg:89.12ms +step:290/1680 train_time:25844ms step_avg:89.12ms +step:291/1680 train_time:25934ms step_avg:89.12ms +step:292/1680 train_time:26022ms step_avg:89.12ms +step:293/1680 train_time:26110ms step_avg:89.11ms +step:294/1680 train_time:26200ms step_avg:89.12ms +step:295/1680 train_time:26288ms step_avg:89.11ms +step:296/1680 train_time:26378ms step_avg:89.11ms +step:297/1680 train_time:26467ms step_avg:89.11ms +step:298/1680 train_time:26557ms step_avg:89.12ms +step:299/1680 train_time:26646ms step_avg:89.12ms +step:300/1680 train_time:26736ms step_avg:89.12ms +step:301/1680 train_time:26825ms step_avg:89.12ms +step:302/1680 train_time:26914ms step_avg:89.12ms +step:303/1680 train_time:27003ms step_avg:89.12ms +step:304/1680 train_time:27092ms step_avg:89.12ms +step:305/1680 train_time:27181ms step_avg:89.12ms +step:306/1680 train_time:27270ms step_avg:89.12ms +step:307/1680 train_time:27360ms step_avg:89.12ms +step:308/1680 train_time:27449ms step_avg:89.12ms +step:309/1680 train_time:27538ms step_avg:89.12ms +step:310/1680 train_time:27628ms step_avg:89.12ms +step:311/1680 train_time:27717ms step_avg:89.12ms +step:312/1680 train_time:27806ms step_avg:89.12ms +step:313/1680 train_time:27896ms step_avg:89.12ms +step:314/1680 train_time:27985ms step_avg:89.12ms +step:315/1680 train_time:28074ms step_avg:89.12ms +step:316/1680 train_time:28162ms step_avg:89.12ms +step:317/1680 train_time:28251ms step_avg:89.12ms +step:318/1680 train_time:28340ms step_avg:89.12ms +step:319/1680 train_time:28429ms step_avg:89.12ms +step:320/1680 train_time:28519ms step_avg:89.12ms +step:321/1680 train_time:28608ms step_avg:89.12ms +step:322/1680 train_time:28698ms step_avg:89.12ms +step:323/1680 train_time:28787ms step_avg:89.12ms +step:324/1680 train_time:28877ms step_avg:89.13ms +step:325/1680 train_time:28965ms step_avg:89.12ms +step:326/1680 train_time:29054ms step_avg:89.12ms +step:327/1680 train_time:29143ms step_avg:89.12ms +step:328/1680 train_time:29231ms step_avg:89.12ms +step:329/1680 train_time:29320ms step_avg:89.12ms +step:330/1680 train_time:29408ms step_avg:89.12ms +step:331/1680 train_time:29498ms step_avg:89.12ms +step:332/1680 train_time:29587ms step_avg:89.12ms +step:333/1680 train_time:29677ms step_avg:89.12ms +step:334/1680 train_time:29766ms step_avg:89.12ms +step:335/1680 train_time:29856ms step_avg:89.12ms +step:336/1680 train_time:29945ms step_avg:89.12ms +step:337/1680 train_time:30035ms step_avg:89.12ms +step:338/1680 train_time:30124ms step_avg:89.12ms +step:339/1680 train_time:30212ms step_avg:89.12ms +step:340/1680 train_time:30303ms step_avg:89.13ms +step:341/1680 train_time:30391ms step_avg:89.12ms +step:342/1680 train_time:30481ms step_avg:89.12ms +step:343/1680 train_time:30570ms step_avg:89.13ms +step:344/1680 train_time:30661ms step_avg:89.13ms +step:345/1680 train_time:30750ms step_avg:89.13ms +step:346/1680 train_time:30839ms step_avg:89.13ms +step:347/1680 train_time:30928ms step_avg:89.13ms +step:348/1680 train_time:31018ms step_avg:89.13ms +step:349/1680 train_time:31108ms step_avg:89.13ms +step:350/1680 train_time:31197ms step_avg:89.13ms +step:351/1680 train_time:31286ms step_avg:89.13ms +step:352/1680 train_time:31375ms step_avg:89.13ms +step:353/1680 train_time:31464ms step_avg:89.13ms +step:354/1680 train_time:31553ms step_avg:89.13ms +step:355/1680 train_time:31642ms step_avg:89.13ms +step:356/1680 train_time:31732ms step_avg:89.14ms +step:357/1680 train_time:31822ms step_avg:89.14ms +step:358/1680 
train_time:31910ms step_avg:89.13ms +step:359/1680 train_time:32000ms step_avg:89.14ms +step:360/1680 train_time:32088ms step_avg:89.13ms +step:361/1680 train_time:32177ms step_avg:89.13ms +step:362/1680 train_time:32266ms step_avg:89.13ms +step:363/1680 train_time:32354ms step_avg:89.13ms +step:364/1680 train_time:32443ms step_avg:89.13ms +step:365/1680 train_time:32532ms step_avg:89.13ms +step:366/1680 train_time:32622ms step_avg:89.13ms +step:367/1680 train_time:32710ms step_avg:89.13ms +step:368/1680 train_time:32800ms step_avg:89.13ms +step:369/1680 train_time:32889ms step_avg:89.13ms +step:370/1680 train_time:32978ms step_avg:89.13ms +step:371/1680 train_time:33067ms step_avg:89.13ms +step:372/1680 train_time:33157ms step_avg:89.13ms +step:373/1680 train_time:33246ms step_avg:89.13ms +step:374/1680 train_time:33335ms step_avg:89.13ms +step:375/1680 train_time:33424ms step_avg:89.13ms +step:375/1680 val_loss:3.8123 train_time:33515ms step_avg:89.37ms +step:376/1680 train_time:33541ms step_avg:89.21ms +step:377/1680 train_time:33609ms step_avg:89.15ms +step:378/1680 train_time:33702ms step_avg:89.16ms +step:379/1680 train_time:33790ms step_avg:89.16ms +step:380/1680 train_time:33879ms step_avg:89.15ms +step:381/1680 train_time:33967ms step_avg:89.15ms +step:382/1680 train_time:34056ms step_avg:89.15ms +step:383/1680 train_time:34145ms step_avg:89.15ms +step:384/1680 train_time:34232ms step_avg:89.15ms +step:385/1680 train_time:34322ms step_avg:89.15ms +step:386/1680 train_time:34410ms step_avg:89.15ms +step:387/1680 train_time:34501ms step_avg:89.15ms +step:388/1680 train_time:34592ms step_avg:89.16ms +step:389/1680 train_time:34683ms step_avg:89.16ms +step:390/1680 train_time:34773ms step_avg:89.16ms +step:391/1680 train_time:34863ms step_avg:89.16ms +step:392/1680 train_time:34950ms step_avg:89.16ms +step:393/1680 train_time:35039ms step_avg:89.16ms +step:394/1680 train_time:35127ms step_avg:89.16ms +step:395/1680 train_time:35217ms step_avg:89.16ms +step:396/1680 train_time:35305ms step_avg:89.15ms +step:397/1680 train_time:35393ms step_avg:89.15ms +step:398/1680 train_time:35483ms step_avg:89.15ms +step:399/1680 train_time:35572ms step_avg:89.15ms +step:400/1680 train_time:35664ms step_avg:89.16ms +step:401/1680 train_time:35753ms step_avg:89.16ms +step:402/1680 train_time:35844ms step_avg:89.16ms +step:403/1680 train_time:35932ms step_avg:89.16ms +step:404/1680 train_time:36022ms step_avg:89.16ms +step:405/1680 train_time:36110ms step_avg:89.16ms +step:406/1680 train_time:36199ms step_avg:89.16ms +step:407/1680 train_time:36288ms step_avg:89.16ms +step:408/1680 train_time:36377ms step_avg:89.16ms +step:409/1680 train_time:36466ms step_avg:89.16ms +step:410/1680 train_time:36556ms step_avg:89.16ms +step:411/1680 train_time:36646ms step_avg:89.16ms +step:412/1680 train_time:36736ms step_avg:89.16ms +step:413/1680 train_time:36826ms step_avg:89.17ms +step:414/1680 train_time:36915ms step_avg:89.17ms +step:415/1680 train_time:37005ms step_avg:89.17ms +step:416/1680 train_time:37093ms step_avg:89.17ms +step:417/1680 train_time:37182ms step_avg:89.17ms +step:418/1680 train_time:37271ms step_avg:89.17ms +step:419/1680 train_time:37361ms step_avg:89.17ms +step:420/1680 train_time:37450ms step_avg:89.17ms +step:421/1680 train_time:37539ms step_avg:89.17ms +step:422/1680 train_time:37629ms step_avg:89.17ms +step:423/1680 train_time:37718ms step_avg:89.17ms +step:424/1680 train_time:37808ms step_avg:89.17ms +step:425/1680 train_time:37897ms step_avg:89.17ms +step:426/1680 train_time:37986ms 
step_avg:89.17ms +step:427/1680 train_time:38074ms step_avg:89.17ms +step:428/1680 train_time:38164ms step_avg:89.17ms +step:429/1680 train_time:38252ms step_avg:89.17ms +step:430/1680 train_time:38341ms step_avg:89.17ms +step:431/1680 train_time:38430ms step_avg:89.17ms +step:432/1680 train_time:38519ms step_avg:89.17ms +step:433/1680 train_time:38609ms step_avg:89.17ms +step:434/1680 train_time:38699ms step_avg:89.17ms +step:435/1680 train_time:38789ms step_avg:89.17ms +step:436/1680 train_time:38878ms step_avg:89.17ms +step:437/1680 train_time:38968ms step_avg:89.17ms +step:438/1680 train_time:39057ms step_avg:89.17ms +step:439/1680 train_time:39146ms step_avg:89.17ms +step:440/1680 train_time:39235ms step_avg:89.17ms +step:441/1680 train_time:39324ms step_avg:89.17ms +step:442/1680 train_time:39413ms step_avg:89.17ms +step:443/1680 train_time:39502ms step_avg:89.17ms +step:444/1680 train_time:39591ms step_avg:89.17ms +step:445/1680 train_time:39680ms step_avg:89.17ms +step:446/1680 train_time:39771ms step_avg:89.17ms +step:447/1680 train_time:39860ms step_avg:89.17ms +step:448/1680 train_time:39949ms step_avg:89.17ms +step:449/1680 train_time:40039ms step_avg:89.17ms +step:450/1680 train_time:40128ms step_avg:89.17ms +step:451/1680 train_time:40217ms step_avg:89.17ms +step:452/1680 train_time:40305ms step_avg:89.17ms +step:453/1680 train_time:40395ms step_avg:89.17ms +step:454/1680 train_time:40484ms step_avg:89.17ms +step:455/1680 train_time:40573ms step_avg:89.17ms +step:456/1680 train_time:40663ms step_avg:89.17ms +step:457/1680 train_time:40752ms step_avg:89.17ms +step:458/1680 train_time:40843ms step_avg:89.18ms +step:459/1680 train_time:40932ms step_avg:89.18ms +step:460/1680 train_time:41022ms step_avg:89.18ms +step:461/1680 train_time:41111ms step_avg:89.18ms +step:462/1680 train_time:41200ms step_avg:89.18ms +step:463/1680 train_time:41288ms step_avg:89.18ms +step:464/1680 train_time:41377ms step_avg:89.17ms +step:465/1680 train_time:41466ms step_avg:89.18ms +step:466/1680 train_time:41555ms step_avg:89.17ms +step:467/1680 train_time:41645ms step_avg:89.18ms +step:468/1680 train_time:41735ms step_avg:89.18ms +step:469/1680 train_time:41825ms step_avg:89.18ms +step:470/1680 train_time:41915ms step_avg:89.18ms +step:471/1680 train_time:42005ms step_avg:89.18ms +step:472/1680 train_time:42094ms step_avg:89.18ms +step:473/1680 train_time:42185ms step_avg:89.19ms +step:474/1680 train_time:42273ms step_avg:89.18ms +step:475/1680 train_time:42364ms step_avg:89.19ms +step:476/1680 train_time:42452ms step_avg:89.19ms +step:477/1680 train_time:42542ms step_avg:89.19ms +step:478/1680 train_time:42631ms step_avg:89.19ms +step:479/1680 train_time:42720ms step_avg:89.19ms +step:480/1680 train_time:42809ms step_avg:89.19ms +step:481/1680 train_time:42898ms step_avg:89.19ms +step:482/1680 train_time:42988ms step_avg:89.19ms +step:483/1680 train_time:43076ms step_avg:89.19ms +step:484/1680 train_time:43166ms step_avg:89.19ms +step:485/1680 train_time:43255ms step_avg:89.19ms +step:486/1680 train_time:43345ms step_avg:89.19ms +step:487/1680 train_time:43434ms step_avg:89.19ms +step:488/1680 train_time:43523ms step_avg:89.19ms +step:489/1680 train_time:43612ms step_avg:89.19ms +step:490/1680 train_time:43702ms step_avg:89.19ms +step:491/1680 train_time:43791ms step_avg:89.19ms +step:492/1680 train_time:43880ms step_avg:89.19ms +step:493/1680 train_time:43971ms step_avg:89.19ms +step:494/1680 train_time:44060ms step_avg:89.19ms +step:495/1680 train_time:44149ms step_avg:89.19ms +step:496/1680 
train_time:44239ms step_avg:89.19ms +step:497/1680 train_time:44328ms step_avg:89.19ms +step:498/1680 train_time:44417ms step_avg:89.19ms +step:499/1680 train_time:44507ms step_avg:89.19ms +step:500/1680 train_time:44597ms step_avg:89.19ms +step:500/1680 val_loss:3.7141 train_time:44688ms step_avg:89.38ms +step:501/1680 train_time:44711ms step_avg:89.24ms +step:502/1680 train_time:44780ms step_avg:89.20ms +step:503/1680 train_time:44875ms step_avg:89.22ms +step:504/1680 train_time:44965ms step_avg:89.22ms +step:505/1680 train_time:45053ms step_avg:89.21ms +step:506/1680 train_time:45142ms step_avg:89.21ms +step:507/1680 train_time:45230ms step_avg:89.21ms +step:508/1680 train_time:45318ms step_avg:89.21ms +step:509/1680 train_time:45407ms step_avg:89.21ms +step:510/1680 train_time:45496ms step_avg:89.21ms +step:511/1680 train_time:45585ms step_avg:89.21ms +step:512/1680 train_time:45677ms step_avg:89.21ms +step:513/1680 train_time:45769ms step_avg:89.22ms +step:514/1680 train_time:45860ms step_avg:89.22ms +step:515/1680 train_time:45951ms step_avg:89.22ms +step:516/1680 train_time:46040ms step_avg:89.22ms +step:517/1680 train_time:46129ms step_avg:89.22ms +step:518/1680 train_time:46218ms step_avg:89.22ms +step:519/1680 train_time:46306ms step_avg:89.22ms +step:520/1680 train_time:46394ms step_avg:89.22ms +step:521/1680 train_time:46483ms step_avg:89.22ms +step:522/1680 train_time:46572ms step_avg:89.22ms +step:523/1680 train_time:46661ms step_avg:89.22ms +step:524/1680 train_time:46753ms step_avg:89.22ms +step:525/1680 train_time:46843ms step_avg:89.22ms +step:526/1680 train_time:46933ms step_avg:89.23ms +step:527/1680 train_time:47022ms step_avg:89.23ms +step:528/1680 train_time:47112ms step_avg:89.23ms +step:529/1680 train_time:47200ms step_avg:89.23ms +step:530/1680 train_time:47289ms step_avg:89.23ms +step:531/1680 train_time:47378ms step_avg:89.22ms +step:532/1680 train_time:47467ms step_avg:89.22ms +step:533/1680 train_time:47556ms step_avg:89.22ms +step:534/1680 train_time:47645ms step_avg:89.22ms +step:535/1680 train_time:47735ms step_avg:89.22ms +step:536/1680 train_time:47824ms step_avg:89.22ms +step:537/1680 train_time:47915ms step_avg:89.23ms +step:538/1680 train_time:48004ms step_avg:89.23ms +step:539/1680 train_time:48094ms step_avg:89.23ms +step:540/1680 train_time:48183ms step_avg:89.23ms +step:541/1680 train_time:48273ms step_avg:89.23ms +step:542/1680 train_time:48362ms step_avg:89.23ms +step:543/1680 train_time:48451ms step_avg:89.23ms +step:544/1680 train_time:48540ms step_avg:89.23ms +step:545/1680 train_time:48629ms step_avg:89.23ms +step:546/1680 train_time:48719ms step_avg:89.23ms +step:547/1680 train_time:48809ms step_avg:89.23ms +step:548/1680 train_time:48899ms step_avg:89.23ms +step:549/1680 train_time:48990ms step_avg:89.24ms +step:550/1680 train_time:49081ms step_avg:89.24ms +step:551/1680 train_time:49171ms step_avg:89.24ms +step:552/1680 train_time:49261ms step_avg:89.24ms +step:553/1680 train_time:49351ms step_avg:89.24ms +step:554/1680 train_time:49441ms step_avg:89.24ms +step:555/1680 train_time:49531ms step_avg:89.25ms +step:556/1680 train_time:49623ms step_avg:89.25ms +step:557/1680 train_time:49714ms step_avg:89.25ms +step:558/1680 train_time:49804ms step_avg:89.25ms +step:559/1680 train_time:49895ms step_avg:89.26ms +step:560/1680 train_time:49985ms step_avg:89.26ms +step:561/1680 train_time:50077ms step_avg:89.26ms +step:562/1680 train_time:50168ms step_avg:89.27ms +step:563/1680 train_time:50259ms step_avg:89.27ms +step:564/1680 train_time:50349ms 
step_avg:89.27ms +step:565/1680 train_time:50440ms step_avg:89.27ms +step:566/1680 train_time:50531ms step_avg:89.28ms +step:567/1680 train_time:50622ms step_avg:89.28ms +step:568/1680 train_time:50712ms step_avg:89.28ms +step:569/1680 train_time:50803ms step_avg:89.28ms +step:570/1680 train_time:50894ms step_avg:89.29ms +step:571/1680 train_time:50984ms step_avg:89.29ms +step:572/1680 train_time:51075ms step_avg:89.29ms +step:573/1680 train_time:51166ms step_avg:89.29ms +step:574/1680 train_time:51257ms step_avg:89.30ms +step:575/1680 train_time:51347ms step_avg:89.30ms +step:576/1680 train_time:51437ms step_avg:89.30ms +step:577/1680 train_time:51528ms step_avg:89.30ms +step:578/1680 train_time:51619ms step_avg:89.31ms +step:579/1680 train_time:51709ms step_avg:89.31ms +step:580/1680 train_time:51800ms step_avg:89.31ms +step:581/1680 train_time:51890ms step_avg:89.31ms +step:582/1680 train_time:51980ms step_avg:89.31ms +step:583/1680 train_time:52070ms step_avg:89.31ms +step:584/1680 train_time:52161ms step_avg:89.32ms +step:585/1680 train_time:52251ms step_avg:89.32ms +step:586/1680 train_time:52342ms step_avg:89.32ms +step:587/1680 train_time:52433ms step_avg:89.32ms +step:588/1680 train_time:52524ms step_avg:89.33ms +step:589/1680 train_time:52615ms step_avg:89.33ms +step:590/1680 train_time:52706ms step_avg:89.33ms +step:591/1680 train_time:52797ms step_avg:89.34ms +step:592/1680 train_time:52888ms step_avg:89.34ms +step:593/1680 train_time:52979ms step_avg:89.34ms +step:594/1680 train_time:53069ms step_avg:89.34ms +step:595/1680 train_time:53159ms step_avg:89.34ms +step:596/1680 train_time:53250ms step_avg:89.34ms +step:597/1680 train_time:53340ms step_avg:89.35ms +step:598/1680 train_time:53430ms step_avg:89.35ms +step:599/1680 train_time:53520ms step_avg:89.35ms +step:600/1680 train_time:53611ms step_avg:89.35ms +step:601/1680 train_time:53701ms step_avg:89.35ms +step:602/1680 train_time:53792ms step_avg:89.35ms +step:603/1680 train_time:53882ms step_avg:89.36ms +step:604/1680 train_time:53973ms step_avg:89.36ms +step:605/1680 train_time:54063ms step_avg:89.36ms +step:606/1680 train_time:54153ms step_avg:89.36ms +step:607/1680 train_time:54244ms step_avg:89.36ms +step:608/1680 train_time:54334ms step_avg:89.36ms +step:609/1680 train_time:54424ms step_avg:89.37ms +step:610/1680 train_time:54515ms step_avg:89.37ms +step:611/1680 train_time:54605ms step_avg:89.37ms +step:612/1680 train_time:54695ms step_avg:89.37ms +step:613/1680 train_time:54785ms step_avg:89.37ms +step:614/1680 train_time:54877ms step_avg:89.38ms +step:615/1680 train_time:54968ms step_avg:89.38ms +step:616/1680 train_time:55058ms step_avg:89.38ms +step:617/1680 train_time:55148ms step_avg:89.38ms +step:618/1680 train_time:55238ms step_avg:89.38ms +step:619/1680 train_time:55328ms step_avg:89.38ms +step:620/1680 train_time:55419ms step_avg:89.39ms +step:621/1680 train_time:55510ms step_avg:89.39ms +step:622/1680 train_time:55600ms step_avg:89.39ms +step:623/1680 train_time:55690ms step_avg:89.39ms +step:624/1680 train_time:55781ms step_avg:89.39ms +step:625/1680 train_time:55872ms step_avg:89.39ms +step:625/1680 val_loss:3.6143 train_time:55963ms step_avg:89.54ms +step:626/1680 train_time:55987ms step_avg:89.44ms +step:627/1680 train_time:56058ms step_avg:89.41ms +step:628/1680 train_time:56158ms step_avg:89.42ms +step:629/1680 train_time:56249ms step_avg:89.43ms +step:630/1680 train_time:56339ms step_avg:89.43ms +step:631/1680 train_time:56427ms step_avg:89.43ms +step:632/1680 train_time:56517ms step_avg:89.43ms 
+step:633/1680 train_time:56606ms step_avg:89.42ms +step:634/1680 train_time:56695ms step_avg:89.42ms +step:635/1680 train_time:56783ms step_avg:89.42ms +step:636/1680 train_time:56872ms step_avg:89.42ms +step:637/1680 train_time:56962ms step_avg:89.42ms +step:638/1680 train_time:57057ms step_avg:89.43ms +step:639/1680 train_time:57150ms step_avg:89.44ms +step:640/1680 train_time:57241ms step_avg:89.44ms +step:641/1680 train_time:57331ms step_avg:89.44ms +step:642/1680 train_time:57422ms step_avg:89.44ms +step:643/1680 train_time:57510ms step_avg:89.44ms +step:644/1680 train_time:57599ms step_avg:89.44ms +step:645/1680 train_time:57688ms step_avg:89.44ms +step:646/1680 train_time:57778ms step_avg:89.44ms +step:647/1680 train_time:57867ms step_avg:89.44ms +step:648/1680 train_time:57958ms step_avg:89.44ms +step:649/1680 train_time:58050ms step_avg:89.45ms +step:650/1680 train_time:58143ms step_avg:89.45ms +step:651/1680 train_time:58234ms step_avg:89.45ms +step:652/1680 train_time:58324ms step_avg:89.45ms +step:653/1680 train_time:58415ms step_avg:89.46ms +step:654/1680 train_time:58505ms step_avg:89.46ms +step:655/1680 train_time:58596ms step_avg:89.46ms +step:656/1680 train_time:58685ms step_avg:89.46ms +step:657/1680 train_time:58774ms step_avg:89.46ms +step:658/1680 train_time:58864ms step_avg:89.46ms +step:659/1680 train_time:58954ms step_avg:89.46ms +step:660/1680 train_time:59045ms step_avg:89.46ms +step:661/1680 train_time:59136ms step_avg:89.46ms +step:662/1680 train_time:59228ms step_avg:89.47ms +step:663/1680 train_time:59319ms step_avg:89.47ms +step:664/1680 train_time:59411ms step_avg:89.47ms +step:665/1680 train_time:59500ms step_avg:89.47ms +step:666/1680 train_time:59591ms step_avg:89.48ms +step:667/1680 train_time:59681ms step_avg:89.48ms +step:668/1680 train_time:59771ms step_avg:89.48ms +step:669/1680 train_time:59861ms step_avg:89.48ms +step:670/1680 train_time:59951ms step_avg:89.48ms +step:671/1680 train_time:60041ms step_avg:89.48ms +step:672/1680 train_time:60132ms step_avg:89.48ms +step:673/1680 train_time:60224ms step_avg:89.49ms +step:674/1680 train_time:60316ms step_avg:89.49ms +step:675/1680 train_time:60407ms step_avg:89.49ms +step:676/1680 train_time:60499ms step_avg:89.49ms +step:677/1680 train_time:60589ms step_avg:89.50ms +step:678/1680 train_time:60679ms step_avg:89.50ms +step:679/1680 train_time:60769ms step_avg:89.50ms +step:680/1680 train_time:60859ms step_avg:89.50ms +step:681/1680 train_time:60949ms step_avg:89.50ms +step:682/1680 train_time:61040ms step_avg:89.50ms +step:683/1680 train_time:61130ms step_avg:89.50ms +step:684/1680 train_time:61221ms step_avg:89.51ms +step:685/1680 train_time:61313ms step_avg:89.51ms +step:686/1680 train_time:61403ms step_avg:89.51ms +step:687/1680 train_time:61494ms step_avg:89.51ms +step:688/1680 train_time:61583ms step_avg:89.51ms +step:689/1680 train_time:61673ms step_avg:89.51ms +step:690/1680 train_time:61764ms step_avg:89.51ms +step:691/1680 train_time:61854ms step_avg:89.51ms +step:692/1680 train_time:61944ms step_avg:89.51ms +step:693/1680 train_time:62034ms step_avg:89.52ms +step:694/1680 train_time:62125ms step_avg:89.52ms +step:695/1680 train_time:62217ms step_avg:89.52ms +step:696/1680 train_time:62307ms step_avg:89.52ms +step:697/1680 train_time:62399ms step_avg:89.53ms +step:698/1680 train_time:62491ms step_avg:89.53ms +step:699/1680 train_time:62582ms step_avg:89.53ms +step:700/1680 train_time:62674ms step_avg:89.53ms +step:701/1680 train_time:62763ms step_avg:89.53ms +step:702/1680 train_time:62853ms 
step_avg:89.53ms +step:703/1680 train_time:62943ms step_avg:89.53ms +step:704/1680 train_time:63033ms step_avg:89.54ms +step:705/1680 train_time:63124ms step_avg:89.54ms +step:706/1680 train_time:63215ms step_avg:89.54ms +step:707/1680 train_time:63305ms step_avg:89.54ms +step:708/1680 train_time:63396ms step_avg:89.54ms +step:709/1680 train_time:63487ms step_avg:89.54ms +step:710/1680 train_time:63578ms step_avg:89.55ms +step:711/1680 train_time:63668ms step_avg:89.55ms +step:712/1680 train_time:63758ms step_avg:89.55ms +step:713/1680 train_time:63848ms step_avg:89.55ms +step:714/1680 train_time:63939ms step_avg:89.55ms +step:715/1680 train_time:64029ms step_avg:89.55ms +step:716/1680 train_time:64120ms step_avg:89.55ms +step:717/1680 train_time:64210ms step_avg:89.55ms +step:718/1680 train_time:64301ms step_avg:89.56ms +step:719/1680 train_time:64391ms step_avg:89.56ms +step:720/1680 train_time:64481ms step_avg:89.56ms +step:721/1680 train_time:64571ms step_avg:89.56ms +step:722/1680 train_time:64662ms step_avg:89.56ms +step:723/1680 train_time:64753ms step_avg:89.56ms +step:724/1680 train_time:64843ms step_avg:89.56ms +step:725/1680 train_time:64933ms step_avg:89.56ms +step:726/1680 train_time:65024ms step_avg:89.56ms +step:727/1680 train_time:65114ms step_avg:89.56ms +step:728/1680 train_time:65204ms step_avg:89.57ms +step:729/1680 train_time:65294ms step_avg:89.57ms +step:730/1680 train_time:65385ms step_avg:89.57ms +step:731/1680 train_time:65476ms step_avg:89.57ms +step:732/1680 train_time:65566ms step_avg:89.57ms +step:733/1680 train_time:65657ms step_avg:89.57ms +step:734/1680 train_time:65747ms step_avg:89.57ms +step:735/1680 train_time:65838ms step_avg:89.58ms +step:736/1680 train_time:65928ms step_avg:89.58ms +step:737/1680 train_time:66020ms step_avg:89.58ms +step:738/1680 train_time:66111ms step_avg:89.58ms +step:739/1680 train_time:66201ms step_avg:89.58ms +step:740/1680 train_time:66292ms step_avg:89.58ms +step:741/1680 train_time:66383ms step_avg:89.59ms +step:742/1680 train_time:66473ms step_avg:89.59ms +step:743/1680 train_time:66564ms step_avg:89.59ms +step:744/1680 train_time:66654ms step_avg:89.59ms +step:745/1680 train_time:66744ms step_avg:89.59ms +step:746/1680 train_time:66835ms step_avg:89.59ms +step:747/1680 train_time:66925ms step_avg:89.59ms +step:748/1680 train_time:67016ms step_avg:89.59ms +step:749/1680 train_time:67106ms step_avg:89.59ms +step:750/1680 train_time:67197ms step_avg:89.60ms +step:750/1680 val_loss:3.5631 train_time:67289ms step_avg:89.72ms +step:751/1680 train_time:67313ms step_avg:89.63ms +step:752/1680 train_time:67384ms step_avg:89.61ms +step:753/1680 train_time:67482ms step_avg:89.62ms +step:754/1680 train_time:67574ms step_avg:89.62ms +step:755/1680 train_time:67664ms step_avg:89.62ms +step:756/1680 train_time:67754ms step_avg:89.62ms +step:757/1680 train_time:67843ms step_avg:89.62ms +step:758/1680 train_time:67932ms step_avg:89.62ms +step:759/1680 train_time:68021ms step_avg:89.62ms +step:760/1680 train_time:68110ms step_avg:89.62ms +step:761/1680 train_time:68200ms step_avg:89.62ms +step:762/1680 train_time:68291ms step_avg:89.62ms +step:763/1680 train_time:68384ms step_avg:89.63ms +step:764/1680 train_time:68477ms step_avg:89.63ms +step:765/1680 train_time:68569ms step_avg:89.63ms +step:766/1680 train_time:68660ms step_avg:89.63ms +step:767/1680 train_time:68750ms step_avg:89.64ms +step:768/1680 train_time:68840ms step_avg:89.64ms +step:769/1680 train_time:68930ms step_avg:89.64ms +step:770/1680 train_time:69019ms step_avg:89.64ms 
+step:771/1680 train_time:69108ms step_avg:89.63ms +step:772/1680 train_time:69198ms step_avg:89.63ms +step:773/1680 train_time:69289ms step_avg:89.64ms +step:774/1680 train_time:69380ms step_avg:89.64ms +step:775/1680 train_time:69472ms step_avg:89.64ms +step:776/1680 train_time:69564ms step_avg:89.64ms +step:777/1680 train_time:69655ms step_avg:89.65ms +step:778/1680 train_time:69745ms step_avg:89.65ms +step:779/1680 train_time:69835ms step_avg:89.65ms +step:780/1680 train_time:69925ms step_avg:89.65ms +step:781/1680 train_time:70015ms step_avg:89.65ms +step:782/1680 train_time:70104ms step_avg:89.65ms +step:783/1680 train_time:70194ms step_avg:89.65ms +step:784/1680 train_time:70284ms step_avg:89.65ms +step:785/1680 train_time:70375ms step_avg:89.65ms +step:786/1680 train_time:70466ms step_avg:89.65ms +step:787/1680 train_time:70558ms step_avg:89.65ms +step:788/1680 train_time:70649ms step_avg:89.66ms +step:789/1680 train_time:70740ms step_avg:89.66ms +step:790/1680 train_time:70830ms step_avg:89.66ms +step:791/1680 train_time:70920ms step_avg:89.66ms +step:792/1680 train_time:71010ms step_avg:89.66ms +step:793/1680 train_time:71100ms step_avg:89.66ms +step:794/1680 train_time:71191ms step_avg:89.66ms +step:795/1680 train_time:71281ms step_avg:89.66ms +step:796/1680 train_time:71372ms step_avg:89.66ms +step:797/1680 train_time:71463ms step_avg:89.67ms +step:798/1680 train_time:71555ms step_avg:89.67ms +step:799/1680 train_time:71645ms step_avg:89.67ms +step:800/1680 train_time:71736ms step_avg:89.67ms +step:801/1680 train_time:71826ms step_avg:89.67ms +step:802/1680 train_time:71917ms step_avg:89.67ms +step:803/1680 train_time:72007ms step_avg:89.67ms +step:804/1680 train_time:72097ms step_avg:89.67ms +step:805/1680 train_time:72187ms step_avg:89.67ms +step:806/1680 train_time:72278ms step_avg:89.67ms +step:807/1680 train_time:72369ms step_avg:89.68ms +step:808/1680 train_time:72461ms step_avg:89.68ms +step:809/1680 train_time:72552ms step_avg:89.68ms +step:810/1680 train_time:72642ms step_avg:89.68ms +step:811/1680 train_time:72732ms step_avg:89.68ms +step:812/1680 train_time:72822ms step_avg:89.68ms +step:813/1680 train_time:72913ms step_avg:89.68ms +step:814/1680 train_time:73002ms step_avg:89.68ms +step:815/1680 train_time:73093ms step_avg:89.68ms +step:816/1680 train_time:73183ms step_avg:89.69ms +step:817/1680 train_time:73273ms step_avg:89.69ms +step:818/1680 train_time:73364ms step_avg:89.69ms +step:819/1680 train_time:73455ms step_avg:89.69ms +step:820/1680 train_time:73544ms step_avg:89.69ms +step:821/1680 train_time:73636ms step_avg:89.69ms +step:822/1680 train_time:73727ms step_avg:89.69ms +step:823/1680 train_time:73818ms step_avg:89.69ms +step:824/1680 train_time:73909ms step_avg:89.69ms +step:825/1680 train_time:73999ms step_avg:89.70ms +step:826/1680 train_time:74089ms step_avg:89.70ms +step:827/1680 train_time:74180ms step_avg:89.70ms +step:828/1680 train_time:74270ms step_avg:89.70ms +step:829/1680 train_time:74361ms step_avg:89.70ms +step:830/1680 train_time:74451ms step_avg:89.70ms +step:831/1680 train_time:74541ms step_avg:89.70ms +step:832/1680 train_time:74631ms step_avg:89.70ms +step:833/1680 train_time:74721ms step_avg:89.70ms +step:834/1680 train_time:74812ms step_avg:89.70ms +step:835/1680 train_time:74902ms step_avg:89.70ms +step:836/1680 train_time:74993ms step_avg:89.70ms +step:837/1680 train_time:75083ms step_avg:89.70ms +step:838/1680 train_time:75173ms step_avg:89.70ms +step:839/1680 train_time:75264ms step_avg:89.71ms +step:840/1680 train_time:75355ms 
step_avg:89.71ms +step:841/1680 train_time:75445ms step_avg:89.71ms +step:842/1680 train_time:75536ms step_avg:89.71ms +step:843/1680 train_time:75627ms step_avg:89.71ms +step:844/1680 train_time:75718ms step_avg:89.71ms +step:845/1680 train_time:75809ms step_avg:89.71ms +step:846/1680 train_time:75899ms step_avg:89.72ms +step:847/1680 train_time:75989ms step_avg:89.72ms +step:848/1680 train_time:76079ms step_avg:89.72ms +step:849/1680 train_time:76169ms step_avg:89.72ms +step:850/1680 train_time:76260ms step_avg:89.72ms +step:851/1680 train_time:76351ms step_avg:89.72ms +step:852/1680 train_time:76441ms step_avg:89.72ms +step:853/1680 train_time:76531ms step_avg:89.72ms +step:854/1680 train_time:76621ms step_avg:89.72ms +step:855/1680 train_time:76712ms step_avg:89.72ms +step:856/1680 train_time:76802ms step_avg:89.72ms +step:857/1680 train_time:76894ms step_avg:89.72ms +step:858/1680 train_time:76984ms step_avg:89.73ms +step:859/1680 train_time:77074ms step_avg:89.73ms +step:860/1680 train_time:77165ms step_avg:89.73ms +step:861/1680 train_time:77256ms step_avg:89.73ms +step:862/1680 train_time:77345ms step_avg:89.73ms +step:863/1680 train_time:77436ms step_avg:89.73ms +step:864/1680 train_time:77526ms step_avg:89.73ms +step:865/1680 train_time:77617ms step_avg:89.73ms +step:866/1680 train_time:77707ms step_avg:89.73ms +step:867/1680 train_time:77798ms step_avg:89.73ms +step:868/1680 train_time:77888ms step_avg:89.73ms +step:869/1680 train_time:77979ms step_avg:89.73ms +step:870/1680 train_time:78069ms step_avg:89.73ms +step:871/1680 train_time:78160ms step_avg:89.74ms +step:872/1680 train_time:78251ms step_avg:89.74ms +step:873/1680 train_time:78341ms step_avg:89.74ms +step:874/1680 train_time:78431ms step_avg:89.74ms +step:875/1680 train_time:78521ms step_avg:89.74ms +step:875/1680 val_loss:3.5187 train_time:78612ms step_avg:89.84ms +step:876/1680 train_time:78636ms step_avg:89.77ms +step:877/1680 train_time:78706ms step_avg:89.74ms +step:878/1680 train_time:78805ms step_avg:89.76ms +step:879/1680 train_time:78897ms step_avg:89.76ms +step:880/1680 train_time:78986ms step_avg:89.76ms +step:881/1680 train_time:79075ms step_avg:89.76ms +step:882/1680 train_time:79164ms step_avg:89.76ms +step:883/1680 train_time:79254ms step_avg:89.75ms +step:884/1680 train_time:79342ms step_avg:89.75ms +step:885/1680 train_time:79432ms step_avg:89.75ms +step:886/1680 train_time:79521ms step_avg:89.75ms +step:887/1680 train_time:79612ms step_avg:89.75ms +step:888/1680 train_time:79706ms step_avg:89.76ms +step:889/1680 train_time:79801ms step_avg:89.77ms +step:890/1680 train_time:79893ms step_avg:89.77ms +step:891/1680 train_time:79983ms step_avg:89.77ms +step:892/1680 train_time:80073ms step_avg:89.77ms +step:893/1680 train_time:80163ms step_avg:89.77ms +step:894/1680 train_time:80253ms step_avg:89.77ms +step:895/1680 train_time:80342ms step_avg:89.77ms +step:896/1680 train_time:80431ms step_avg:89.77ms +step:897/1680 train_time:80520ms step_avg:89.77ms +step:898/1680 train_time:80611ms step_avg:89.77ms +step:899/1680 train_time:80703ms step_avg:89.77ms +step:900/1680 train_time:80796ms step_avg:89.77ms +step:901/1680 train_time:80886ms step_avg:89.77ms +step:902/1680 train_time:80976ms step_avg:89.77ms +step:903/1680 train_time:81066ms step_avg:89.77ms +step:904/1680 train_time:81156ms step_avg:89.77ms +step:905/1680 train_time:81246ms step_avg:89.77ms +step:906/1680 train_time:81336ms step_avg:89.77ms +step:907/1680 train_time:81425ms step_avg:89.77ms +step:908/1680 train_time:81515ms step_avg:89.77ms 
+step:909/1680 train_time:81606ms step_avg:89.78ms +step:910/1680 train_time:81697ms step_avg:89.78ms +step:911/1680 train_time:81788ms step_avg:89.78ms +step:912/1680 train_time:81880ms step_avg:89.78ms +step:913/1680 train_time:81972ms step_avg:89.78ms +step:914/1680 train_time:82062ms step_avg:89.78ms +step:915/1680 train_time:82153ms step_avg:89.78ms +step:916/1680 train_time:82243ms step_avg:89.78ms +step:917/1680 train_time:82333ms step_avg:89.78ms +step:918/1680 train_time:82423ms step_avg:89.79ms +step:919/1680 train_time:82513ms step_avg:89.79ms +step:920/1680 train_time:82603ms step_avg:89.79ms +step:921/1680 train_time:82694ms step_avg:89.79ms +step:922/1680 train_time:82784ms step_avg:89.79ms +step:923/1680 train_time:82875ms step_avg:89.79ms +step:924/1680 train_time:82965ms step_avg:89.79ms +step:925/1680 train_time:83057ms step_avg:89.79ms +step:926/1680 train_time:83146ms step_avg:89.79ms +step:927/1680 train_time:83237ms step_avg:89.79ms +step:928/1680 train_time:83326ms step_avg:89.79ms +step:929/1680 train_time:83417ms step_avg:89.79ms +step:930/1680 train_time:83507ms step_avg:89.79ms +step:931/1680 train_time:83599ms step_avg:89.79ms +step:932/1680 train_time:83690ms step_avg:89.80ms +step:933/1680 train_time:83780ms step_avg:89.80ms +step:934/1680 train_time:83871ms step_avg:89.80ms +step:935/1680 train_time:83962ms step_avg:89.80ms +step:936/1680 train_time:84053ms step_avg:89.80ms +step:937/1680 train_time:84143ms step_avg:89.80ms +step:938/1680 train_time:84233ms step_avg:89.80ms +step:939/1680 train_time:84323ms step_avg:89.80ms +step:940/1680 train_time:84413ms step_avg:89.80ms +step:941/1680 train_time:84504ms step_avg:89.80ms +step:942/1680 train_time:84594ms step_avg:89.80ms +step:943/1680 train_time:84685ms step_avg:89.80ms +step:944/1680 train_time:84776ms step_avg:89.80ms +step:945/1680 train_time:84866ms step_avg:89.81ms +step:946/1680 train_time:84958ms step_avg:89.81ms +step:947/1680 train_time:85048ms step_avg:89.81ms +step:948/1680 train_time:85139ms step_avg:89.81ms +step:949/1680 train_time:85229ms step_avg:89.81ms +step:950/1680 train_time:85319ms step_avg:89.81ms +step:951/1680 train_time:85410ms step_avg:89.81ms +step:952/1680 train_time:85501ms step_avg:89.81ms +step:953/1680 train_time:85592ms step_avg:89.81ms +step:954/1680 train_time:85682ms step_avg:89.81ms +step:955/1680 train_time:85773ms step_avg:89.81ms +step:956/1680 train_time:85864ms step_avg:89.82ms +step:957/1680 train_time:85954ms step_avg:89.82ms +step:958/1680 train_time:86045ms step_avg:89.82ms +step:959/1680 train_time:86135ms step_avg:89.82ms +step:960/1680 train_time:86225ms step_avg:89.82ms +step:961/1680 train_time:86315ms step_avg:89.82ms +step:962/1680 train_time:86405ms step_avg:89.82ms +step:963/1680 train_time:86496ms step_avg:89.82ms +step:964/1680 train_time:86586ms step_avg:89.82ms +step:965/1680 train_time:86677ms step_avg:89.82ms +step:966/1680 train_time:86767ms step_avg:89.82ms +step:967/1680 train_time:86859ms step_avg:89.82ms +step:968/1680 train_time:86949ms step_avg:89.82ms +step:969/1680 train_time:87039ms step_avg:89.82ms +step:970/1680 train_time:87129ms step_avg:89.82ms +step:971/1680 train_time:87219ms step_avg:89.82ms +step:972/1680 train_time:87309ms step_avg:89.82ms +step:973/1680 train_time:87400ms step_avg:89.83ms +step:974/1680 train_time:87491ms step_avg:89.83ms +step:975/1680 train_time:87582ms step_avg:89.83ms +step:976/1680 train_time:87672ms step_avg:89.83ms +step:977/1680 train_time:87763ms step_avg:89.83ms +step:978/1680 train_time:87854ms 
step_avg:89.83ms +step:979/1680 train_time:87944ms step_avg:89.83ms +step:980/1680 train_time:88034ms step_avg:89.83ms +step:981/1680 train_time:88125ms step_avg:89.83ms +step:982/1680 train_time:88216ms step_avg:89.83ms +step:983/1680 train_time:88307ms step_avg:89.83ms +step:984/1680 train_time:88397ms step_avg:89.83ms +step:985/1680 train_time:88487ms step_avg:89.83ms +step:986/1680 train_time:88578ms step_avg:89.84ms +step:987/1680 train_time:88669ms step_avg:89.84ms +step:988/1680 train_time:88760ms step_avg:89.84ms +step:989/1680 train_time:88851ms step_avg:89.84ms +step:990/1680 train_time:88942ms step_avg:89.84ms +step:991/1680 train_time:89032ms step_avg:89.84ms +step:992/1680 train_time:89122ms step_avg:89.84ms +step:993/1680 train_time:89212ms step_avg:89.84ms +step:994/1680 train_time:89303ms step_avg:89.84ms +step:995/1680 train_time:89393ms step_avg:89.84ms +step:996/1680 train_time:89483ms step_avg:89.84ms +step:997/1680 train_time:89574ms step_avg:89.84ms +step:998/1680 train_time:89664ms step_avg:89.84ms +step:999/1680 train_time:89755ms step_avg:89.85ms +step:1000/1680 train_time:89846ms step_avg:89.85ms +step:1000/1680 val_loss:3.4697 train_time:89938ms step_avg:89.94ms +step:1001/1680 train_time:89962ms step_avg:89.87ms +step:1002/1680 train_time:90030ms step_avg:89.85ms +step:1003/1680 train_time:90127ms step_avg:89.86ms +step:1004/1680 train_time:90218ms step_avg:89.86ms +step:1005/1680 train_time:90307ms step_avg:89.86ms +step:1006/1680 train_time:90397ms step_avg:89.86ms +step:1007/1680 train_time:90485ms step_avg:89.86ms +step:1008/1680 train_time:90575ms step_avg:89.86ms +step:1009/1680 train_time:90664ms step_avg:89.86ms +step:1010/1680 train_time:90753ms step_avg:89.85ms +step:1011/1680 train_time:90842ms step_avg:89.85ms +step:1012/1680 train_time:90933ms step_avg:89.85ms +step:1013/1680 train_time:91024ms step_avg:89.86ms +step:1014/1680 train_time:91116ms step_avg:89.86ms +step:1015/1680 train_time:91208ms step_avg:89.86ms +step:1016/1680 train_time:91298ms step_avg:89.86ms +step:1017/1680 train_time:91387ms step_avg:89.86ms +step:1018/1680 train_time:91477ms step_avg:89.86ms +step:1019/1680 train_time:91566ms step_avg:89.86ms +step:1020/1680 train_time:91655ms step_avg:89.86ms +step:1021/1680 train_time:91745ms step_avg:89.86ms +step:1022/1680 train_time:91836ms step_avg:89.86ms +step:1023/1680 train_time:91927ms step_avg:89.86ms +step:1024/1680 train_time:92019ms step_avg:89.86ms +step:1025/1680 train_time:92110ms step_avg:89.86ms +step:1026/1680 train_time:92201ms step_avg:89.86ms +step:1027/1680 train_time:92292ms step_avg:89.87ms +step:1028/1680 train_time:92381ms step_avg:89.86ms +step:1029/1680 train_time:92470ms step_avg:89.86ms +step:1030/1680 train_time:92560ms step_avg:89.86ms +step:1031/1680 train_time:92650ms step_avg:89.86ms +step:1032/1680 train_time:92740ms step_avg:89.86ms +step:1033/1680 train_time:92831ms step_avg:89.87ms +step:1034/1680 train_time:92922ms step_avg:89.87ms +step:1035/1680 train_time:93014ms step_avg:89.87ms +step:1036/1680 train_time:93105ms step_avg:89.87ms +step:1037/1680 train_time:93196ms step_avg:89.87ms +step:1038/1680 train_time:93286ms step_avg:89.87ms +step:1039/1680 train_time:93377ms step_avg:89.87ms +step:1040/1680 train_time:93467ms step_avg:89.87ms +step:1041/1680 train_time:93558ms step_avg:89.87ms +step:1042/1680 train_time:93647ms step_avg:89.87ms +step:1043/1680 train_time:93738ms step_avg:89.87ms +step:1044/1680 train_time:93827ms step_avg:89.87ms +step:1045/1680 train_time:93918ms step_avg:89.87ms 
+step:1046/1680 train_time:94009ms step_avg:89.87ms +step:1047/1680 train_time:94101ms step_avg:89.88ms +step:1048/1680 train_time:94192ms step_avg:89.88ms +step:1049/1680 train_time:94283ms step_avg:89.88ms +step:1050/1680 train_time:94373ms step_avg:89.88ms +step:1051/1680 train_time:94463ms step_avg:89.88ms +step:1052/1680 train_time:94554ms step_avg:89.88ms +step:1053/1680 train_time:94644ms step_avg:89.88ms +step:1054/1680 train_time:94734ms step_avg:89.88ms +step:1055/1680 train_time:94824ms step_avg:89.88ms +step:1056/1680 train_time:94914ms step_avg:89.88ms +step:1057/1680 train_time:95005ms step_avg:89.88ms +step:1058/1680 train_time:95097ms step_avg:89.88ms +step:1059/1680 train_time:95187ms step_avg:89.88ms +step:1060/1680 train_time:95278ms step_avg:89.89ms +step:1061/1680 train_time:95368ms step_avg:89.89ms +step:1062/1680 train_time:95459ms step_avg:89.89ms +step:1063/1680 train_time:95549ms step_avg:89.89ms +step:1064/1680 train_time:95640ms step_avg:89.89ms +step:1065/1680 train_time:95730ms step_avg:89.89ms +step:1066/1680 train_time:95820ms step_avg:89.89ms +step:1067/1680 train_time:95911ms step_avg:89.89ms +step:1068/1680 train_time:96002ms step_avg:89.89ms +step:1069/1680 train_time:96093ms step_avg:89.89ms +step:1070/1680 train_time:96183ms step_avg:89.89ms +step:1071/1680 train_time:96274ms step_avg:89.89ms +step:1072/1680 train_time:96364ms step_avg:89.89ms +step:1073/1680 train_time:96454ms step_avg:89.89ms +step:1074/1680 train_time:96545ms step_avg:89.89ms +step:1075/1680 train_time:96635ms step_avg:89.89ms +step:1076/1680 train_time:96726ms step_avg:89.89ms +step:1077/1680 train_time:96816ms step_avg:89.89ms +step:1078/1680 train_time:96907ms step_avg:89.89ms +step:1079/1680 train_time:96998ms step_avg:89.90ms +step:1080/1680 train_time:97088ms step_avg:89.90ms +step:1081/1680 train_time:97179ms step_avg:89.90ms +step:1082/1680 train_time:97270ms step_avg:89.90ms +step:1083/1680 train_time:97360ms step_avg:89.90ms +step:1084/1680 train_time:97450ms step_avg:89.90ms +step:1085/1680 train_time:97541ms step_avg:89.90ms +step:1086/1680 train_time:97631ms step_avg:89.90ms +step:1087/1680 train_time:97722ms step_avg:89.90ms +step:1088/1680 train_time:97813ms step_avg:89.90ms +step:1089/1680 train_time:97903ms step_avg:89.90ms +step:1090/1680 train_time:97993ms step_avg:89.90ms +step:1091/1680 train_time:98083ms step_avg:89.90ms +step:1092/1680 train_time:98174ms step_avg:89.90ms +step:1093/1680 train_time:98264ms step_avg:89.90ms +step:1094/1680 train_time:98354ms step_avg:89.90ms +step:1095/1680 train_time:98445ms step_avg:89.90ms +step:1096/1680 train_time:98536ms step_avg:89.91ms +step:1097/1680 train_time:98626ms step_avg:89.91ms +step:1098/1680 train_time:98718ms step_avg:89.91ms +step:1099/1680 train_time:98811ms step_avg:89.91ms +step:1100/1680 train_time:98902ms step_avg:89.91ms +step:1101/1680 train_time:98993ms step_avg:89.91ms +step:1102/1680 train_time:99085ms step_avg:89.91ms +step:1103/1680 train_time:99175ms step_avg:89.91ms +step:1104/1680 train_time:99267ms step_avg:89.92ms +step:1105/1680 train_time:99357ms step_avg:89.92ms +step:1106/1680 train_time:99448ms step_avg:89.92ms +step:1107/1680 train_time:99540ms step_avg:89.92ms +step:1108/1680 train_time:99631ms step_avg:89.92ms +step:1109/1680 train_time:99723ms step_avg:89.92ms +step:1110/1680 train_time:99815ms step_avg:89.92ms +step:1111/1680 train_time:99906ms step_avg:89.92ms +step:1112/1680 train_time:99998ms step_avg:89.93ms +step:1113/1680 train_time:100088ms step_avg:89.93ms +step:1114/1680 
train_time:100179ms step_avg:89.93ms +step:1115/1680 train_time:100271ms step_avg:89.93ms +step:1116/1680 train_time:100361ms step_avg:89.93ms +step:1117/1680 train_time:100452ms step_avg:89.93ms +step:1118/1680 train_time:100543ms step_avg:89.93ms +step:1119/1680 train_time:100635ms step_avg:89.93ms +step:1120/1680 train_time:100725ms step_avg:89.93ms +step:1121/1680 train_time:100817ms step_avg:89.93ms +step:1122/1680 train_time:100908ms step_avg:89.94ms +step:1123/1680 train_time:101000ms step_avg:89.94ms +step:1124/1680 train_time:101090ms step_avg:89.94ms +step:1125/1680 train_time:101181ms step_avg:89.94ms +step:1125/1680 val_loss:3.4161 train_time:101273ms step_avg:90.02ms +step:1126/1680 train_time:101297ms step_avg:89.96ms +step:1127/1680 train_time:101367ms step_avg:89.94ms +step:1128/1680 train_time:101468ms step_avg:89.95ms +step:1129/1680 train_time:101561ms step_avg:89.96ms +step:1130/1680 train_time:101652ms step_avg:89.96ms +step:1131/1680 train_time:101741ms step_avg:89.96ms +step:1132/1680 train_time:101831ms step_avg:89.96ms +step:1133/1680 train_time:101921ms step_avg:89.96ms +step:1134/1680 train_time:102011ms step_avg:89.96ms +step:1135/1680 train_time:102100ms step_avg:89.96ms +step:1136/1680 train_time:102190ms step_avg:89.96ms +step:1137/1680 train_time:102282ms step_avg:89.96ms +step:1138/1680 train_time:102375ms step_avg:89.96ms +step:1139/1680 train_time:102469ms step_avg:89.96ms +step:1140/1680 train_time:102562ms step_avg:89.97ms +step:1141/1680 train_time:102653ms step_avg:89.97ms +step:1142/1680 train_time:102743ms step_avg:89.97ms +step:1143/1680 train_time:102833ms step_avg:89.97ms +step:1144/1680 train_time:102924ms step_avg:89.97ms +step:1145/1680 train_time:103014ms step_avg:89.97ms +step:1146/1680 train_time:103103ms step_avg:89.97ms +step:1147/1680 train_time:103193ms step_avg:89.97ms +step:1148/1680 train_time:103284ms step_avg:89.97ms +step:1149/1680 train_time:103376ms step_avg:89.97ms +step:1150/1680 train_time:103469ms step_avg:89.97ms +step:1151/1680 train_time:103561ms step_avg:89.97ms +step:1152/1680 train_time:103652ms step_avg:89.98ms +step:1153/1680 train_time:103743ms step_avg:89.98ms +step:1154/1680 train_time:103833ms step_avg:89.98ms +step:1155/1680 train_time:103925ms step_avg:89.98ms +step:1156/1680 train_time:104015ms step_avg:89.98ms +step:1157/1680 train_time:104105ms step_avg:89.98ms +step:1158/1680 train_time:104196ms step_avg:89.98ms +step:1159/1680 train_time:104286ms step_avg:89.98ms +step:1160/1680 train_time:104379ms step_avg:89.98ms +step:1161/1680 train_time:104471ms step_avg:89.98ms +step:1162/1680 train_time:104563ms step_avg:89.99ms +step:1163/1680 train_time:104655ms step_avg:89.99ms +step:1164/1680 train_time:104747ms step_avg:89.99ms +step:1165/1680 train_time:104838ms step_avg:89.99ms +step:1166/1680 train_time:104928ms step_avg:89.99ms +step:1167/1680 train_time:105019ms step_avg:89.99ms +step:1168/1680 train_time:105110ms step_avg:89.99ms +step:1169/1680 train_time:105201ms step_avg:89.99ms +step:1170/1680 train_time:105292ms step_avg:89.99ms +step:1171/1680 train_time:105384ms step_avg:89.99ms +step:1172/1680 train_time:105474ms step_avg:90.00ms +step:1173/1680 train_time:105565ms step_avg:90.00ms +step:1174/1680 train_time:105658ms step_avg:90.00ms +step:1175/1680 train_time:105749ms step_avg:90.00ms +step:1176/1680 train_time:105840ms step_avg:90.00ms +step:1177/1680 train_time:105930ms step_avg:90.00ms +step:1178/1680 train_time:106022ms step_avg:90.00ms +step:1179/1680 train_time:106112ms step_avg:90.00ms 
+step:1180/1680 train_time:106203ms step_avg:90.00ms +step:1181/1680 train_time:106295ms step_avg:90.00ms +step:1182/1680 train_time:106386ms step_avg:90.01ms +step:1183/1680 train_time:106478ms step_avg:90.01ms +step:1184/1680 train_time:106569ms step_avg:90.01ms +step:1185/1680 train_time:106660ms step_avg:90.01ms +step:1186/1680 train_time:106751ms step_avg:90.01ms +step:1187/1680 train_time:106842ms step_avg:90.01ms +step:1188/1680 train_time:106932ms step_avg:90.01ms +step:1189/1680 train_time:107023ms step_avg:90.01ms +step:1190/1680 train_time:107115ms step_avg:90.01ms +step:1191/1680 train_time:107206ms step_avg:90.01ms +step:1192/1680 train_time:107297ms step_avg:90.01ms +step:1193/1680 train_time:107388ms step_avg:90.02ms +step:1194/1680 train_time:107479ms step_avg:90.02ms +step:1195/1680 train_time:107571ms step_avg:90.02ms +step:1196/1680 train_time:107662ms step_avg:90.02ms +step:1197/1680 train_time:107753ms step_avg:90.02ms +step:1198/1680 train_time:107844ms step_avg:90.02ms +step:1199/1680 train_time:107934ms step_avg:90.02ms +step:1200/1680 train_time:108024ms step_avg:90.02ms +step:1201/1680 train_time:108116ms step_avg:90.02ms +step:1202/1680 train_time:108206ms step_avg:90.02ms +step:1203/1680 train_time:108296ms step_avg:90.02ms +step:1204/1680 train_time:108387ms step_avg:90.02ms +step:1205/1680 train_time:108478ms step_avg:90.02ms +step:1206/1680 train_time:108569ms step_avg:90.02ms +step:1207/1680 train_time:108660ms step_avg:90.03ms +step:1208/1680 train_time:108752ms step_avg:90.03ms +step:1209/1680 train_time:108843ms step_avg:90.03ms +step:1210/1680 train_time:108933ms step_avg:90.03ms +step:1211/1680 train_time:109024ms step_avg:90.03ms +step:1212/1680 train_time:109115ms step_avg:90.03ms +step:1213/1680 train_time:109206ms step_avg:90.03ms +step:1214/1680 train_time:109297ms step_avg:90.03ms +step:1215/1680 train_time:109387ms step_avg:90.03ms +step:1216/1680 train_time:109479ms step_avg:90.03ms +step:1217/1680 train_time:109570ms step_avg:90.03ms +step:1218/1680 train_time:109661ms step_avg:90.03ms +step:1219/1680 train_time:109752ms step_avg:90.03ms +step:1220/1680 train_time:109844ms step_avg:90.04ms +step:1221/1680 train_time:109936ms step_avg:90.04ms +step:1222/1680 train_time:110026ms step_avg:90.04ms +step:1223/1680 train_time:110117ms step_avg:90.04ms +step:1224/1680 train_time:110208ms step_avg:90.04ms +step:1225/1680 train_time:110299ms step_avg:90.04ms +step:1226/1680 train_time:110391ms step_avg:90.04ms +step:1227/1680 train_time:110482ms step_avg:90.04ms +step:1228/1680 train_time:110573ms step_avg:90.04ms +step:1229/1680 train_time:110664ms step_avg:90.04ms +step:1230/1680 train_time:110755ms step_avg:90.04ms +step:1231/1680 train_time:110846ms step_avg:90.05ms +step:1232/1680 train_time:110938ms step_avg:90.05ms +step:1233/1680 train_time:111029ms step_avg:90.05ms +step:1234/1680 train_time:111121ms step_avg:90.05ms +step:1235/1680 train_time:111212ms step_avg:90.05ms +step:1236/1680 train_time:111302ms step_avg:90.05ms +step:1237/1680 train_time:111393ms step_avg:90.05ms +step:1238/1680 train_time:111484ms step_avg:90.05ms +step:1239/1680 train_time:111574ms step_avg:90.05ms +step:1240/1680 train_time:111665ms step_avg:90.05ms +step:1241/1680 train_time:111756ms step_avg:90.05ms +step:1242/1680 train_time:111847ms step_avg:90.05ms +step:1243/1680 train_time:111938ms step_avg:90.05ms +step:1244/1680 train_time:112029ms step_avg:90.06ms +step:1245/1680 train_time:112121ms step_avg:90.06ms +step:1246/1680 train_time:112211ms step_avg:90.06ms 
+step:1247/1680 train_time:112302ms step_avg:90.06ms +step:1248/1680 train_time:112393ms step_avg:90.06ms +step:1249/1680 train_time:112483ms step_avg:90.06ms +step:1250/1680 train_time:112574ms step_avg:90.06ms +step:1250/1680 val_loss:3.3773 train_time:112667ms step_avg:90.13ms +step:1251/1680 train_time:112691ms step_avg:90.08ms +step:1252/1680 train_time:112761ms step_avg:90.06ms +step:1253/1680 train_time:112858ms step_avg:90.07ms +step:1254/1680 train_time:112951ms step_avg:90.07ms +step:1255/1680 train_time:113042ms step_avg:90.07ms +step:1256/1680 train_time:113133ms step_avg:90.07ms +step:1257/1680 train_time:113223ms step_avg:90.07ms +step:1258/1680 train_time:113313ms step_avg:90.07ms +step:1259/1680 train_time:113403ms step_avg:90.07ms +step:1260/1680 train_time:113492ms step_avg:90.07ms +step:1261/1680 train_time:113582ms step_avg:90.07ms +step:1262/1680 train_time:113674ms step_avg:90.07ms +step:1263/1680 train_time:113768ms step_avg:90.08ms +step:1264/1680 train_time:113861ms step_avg:90.08ms +step:1265/1680 train_time:113953ms step_avg:90.08ms +step:1266/1680 train_time:114044ms step_avg:90.08ms +step:1267/1680 train_time:114135ms step_avg:90.08ms +step:1268/1680 train_time:114226ms step_avg:90.08ms +step:1269/1680 train_time:114316ms step_avg:90.08ms +step:1270/1680 train_time:114407ms step_avg:90.08ms +step:1271/1680 train_time:114496ms step_avg:90.08ms +step:1272/1680 train_time:114586ms step_avg:90.08ms +step:1273/1680 train_time:114678ms step_avg:90.09ms +step:1274/1680 train_time:114772ms step_avg:90.09ms +step:1275/1680 train_time:114864ms step_avg:90.09ms +step:1276/1680 train_time:114956ms step_avg:90.09ms +step:1277/1680 train_time:115047ms step_avg:90.09ms +step:1278/1680 train_time:115137ms step_avg:90.09ms +step:1279/1680 train_time:115227ms step_avg:90.09ms +step:1280/1680 train_time:115318ms step_avg:90.09ms +step:1281/1680 train_time:115408ms step_avg:90.09ms +step:1282/1680 train_time:115498ms step_avg:90.09ms +step:1283/1680 train_time:115589ms step_avg:90.09ms +step:1284/1680 train_time:115681ms step_avg:90.09ms +step:1285/1680 train_time:115773ms step_avg:90.10ms +step:1286/1680 train_time:115865ms step_avg:90.10ms +step:1287/1680 train_time:115957ms step_avg:90.10ms +step:1288/1680 train_time:116048ms step_avg:90.10ms +step:1289/1680 train_time:116138ms step_avg:90.10ms +step:1290/1680 train_time:116230ms step_avg:90.10ms +step:1291/1680 train_time:116320ms step_avg:90.10ms +step:1292/1680 train_time:116411ms step_avg:90.10ms +step:1293/1680 train_time:116502ms step_avg:90.10ms +step:1294/1680 train_time:116593ms step_avg:90.10ms +step:1295/1680 train_time:116684ms step_avg:90.10ms +step:1296/1680 train_time:116776ms step_avg:90.10ms +step:1297/1680 train_time:116867ms step_avg:90.11ms +step:1298/1680 train_time:116960ms step_avg:90.11ms +step:1299/1680 train_time:117051ms step_avg:90.11ms +step:1300/1680 train_time:117142ms step_avg:90.11ms +step:1301/1680 train_time:117233ms step_avg:90.11ms +step:1302/1680 train_time:117323ms step_avg:90.11ms +step:1303/1680 train_time:117414ms step_avg:90.11ms +step:1304/1680 train_time:117505ms step_avg:90.11ms +step:1305/1680 train_time:117595ms step_avg:90.11ms +step:1306/1680 train_time:117685ms step_avg:90.11ms +step:1307/1680 train_time:117777ms step_avg:90.11ms +step:1308/1680 train_time:117869ms step_avg:90.11ms +step:1309/1680 train_time:117960ms step_avg:90.11ms +step:1310/1680 train_time:118053ms step_avg:90.12ms +step:1311/1680 train_time:118145ms step_avg:90.12ms +step:1312/1680 train_time:118235ms 
step_avg:90.12ms +step:1313/1680 train_time:118327ms step_avg:90.12ms +step:1314/1680 train_time:118417ms step_avg:90.12ms +step:1315/1680 train_time:118508ms step_avg:90.12ms +step:1316/1680 train_time:118598ms step_avg:90.12ms +step:1317/1680 train_time:118688ms step_avg:90.12ms +step:1318/1680 train_time:118779ms step_avg:90.12ms +step:1319/1680 train_time:118871ms step_avg:90.12ms +step:1320/1680 train_time:118962ms step_avg:90.12ms +step:1321/1680 train_time:119053ms step_avg:90.12ms +step:1322/1680 train_time:119145ms step_avg:90.12ms +step:1323/1680 train_time:119236ms step_avg:90.13ms +step:1324/1680 train_time:119328ms step_avg:90.13ms +step:1325/1680 train_time:119419ms step_avg:90.13ms +step:1326/1680 train_time:119509ms step_avg:90.13ms +step:1327/1680 train_time:119598ms step_avg:90.13ms +step:1328/1680 train_time:119689ms step_avg:90.13ms +step:1329/1680 train_time:119780ms step_avg:90.13ms +step:1330/1680 train_time:119871ms step_avg:90.13ms +step:1331/1680 train_time:119962ms step_avg:90.13ms +step:1332/1680 train_time:120054ms step_avg:90.13ms +step:1333/1680 train_time:120145ms step_avg:90.13ms +step:1334/1680 train_time:120236ms step_avg:90.13ms +step:1335/1680 train_time:120327ms step_avg:90.13ms +step:1336/1680 train_time:120418ms step_avg:90.13ms +step:1337/1680 train_time:120509ms step_avg:90.13ms +step:1338/1680 train_time:120599ms step_avg:90.13ms +step:1339/1680 train_time:120690ms step_avg:90.13ms +step:1340/1680 train_time:120780ms step_avg:90.13ms +step:1341/1680 train_time:120871ms step_avg:90.14ms +step:1342/1680 train_time:120963ms step_avg:90.14ms +step:1343/1680 train_time:121055ms step_avg:90.14ms +step:1344/1680 train_time:121147ms step_avg:90.14ms +step:1345/1680 train_time:121238ms step_avg:90.14ms +step:1346/1680 train_time:121328ms step_avg:90.14ms +step:1347/1680 train_time:121419ms step_avg:90.14ms +step:1348/1680 train_time:121511ms step_avg:90.14ms +step:1349/1680 train_time:121602ms step_avg:90.14ms +step:1350/1680 train_time:121692ms step_avg:90.14ms +step:1351/1680 train_time:121783ms step_avg:90.14ms +step:1352/1680 train_time:121874ms step_avg:90.14ms +step:1353/1680 train_time:121965ms step_avg:90.14ms +step:1354/1680 train_time:122057ms step_avg:90.15ms +step:1355/1680 train_time:122149ms step_avg:90.15ms +step:1356/1680 train_time:122240ms step_avg:90.15ms +step:1357/1680 train_time:122332ms step_avg:90.15ms +step:1358/1680 train_time:122423ms step_avg:90.15ms +step:1359/1680 train_time:122514ms step_avg:90.15ms +step:1360/1680 train_time:122605ms step_avg:90.15ms +step:1361/1680 train_time:122696ms step_avg:90.15ms +step:1362/1680 train_time:122786ms step_avg:90.15ms +step:1363/1680 train_time:122878ms step_avg:90.15ms +step:1364/1680 train_time:122969ms step_avg:90.15ms +step:1365/1680 train_time:123060ms step_avg:90.15ms +step:1366/1680 train_time:123152ms step_avg:90.16ms +step:1367/1680 train_time:123242ms step_avg:90.16ms +step:1368/1680 train_time:123333ms step_avg:90.16ms +step:1369/1680 train_time:123425ms step_avg:90.16ms +step:1370/1680 train_time:123516ms step_avg:90.16ms +step:1371/1680 train_time:123607ms step_avg:90.16ms +step:1372/1680 train_time:123698ms step_avg:90.16ms +step:1373/1680 train_time:123789ms step_avg:90.16ms +step:1374/1680 train_time:123880ms step_avg:90.16ms +step:1375/1680 train_time:123971ms step_avg:90.16ms +step:1375/1680 val_loss:3.3431 train_time:124062ms step_avg:90.23ms +step:1376/1680 train_time:124086ms step_avg:90.18ms +step:1377/1680 train_time:124156ms step_avg:90.16ms +step:1378/1680 
train_time:124255ms step_avg:90.17ms +step:1379/1680 train_time:124347ms step_avg:90.17ms +step:1380/1680 train_time:124437ms step_avg:90.17ms +step:1381/1680 train_time:124526ms step_avg:90.17ms +step:1382/1680 train_time:124616ms step_avg:90.17ms +step:1383/1680 train_time:124706ms step_avg:90.17ms +step:1384/1680 train_time:124796ms step_avg:90.17ms +step:1385/1680 train_time:124885ms step_avg:90.17ms +step:1386/1680 train_time:124976ms step_avg:90.17ms +step:1387/1680 train_time:125068ms step_avg:90.17ms +step:1388/1680 train_time:125160ms step_avg:90.17ms +step:1389/1680 train_time:125255ms step_avg:90.18ms +step:1390/1680 train_time:125348ms step_avg:90.18ms +step:1391/1680 train_time:125438ms step_avg:90.18ms +step:1392/1680 train_time:125529ms step_avg:90.18ms +step:1393/1680 train_time:125619ms step_avg:90.18ms +step:1394/1680 train_time:125709ms step_avg:90.18ms +step:1395/1680 train_time:125800ms step_avg:90.18ms +step:1396/1680 train_time:125890ms step_avg:90.18ms +step:1397/1680 train_time:125980ms step_avg:90.18ms +step:1398/1680 train_time:126072ms step_avg:90.18ms +step:1399/1680 train_time:126164ms step_avg:90.18ms +step:1400/1680 train_time:126257ms step_avg:90.18ms +step:1401/1680 train_time:126349ms step_avg:90.18ms +step:1402/1680 train_time:126440ms step_avg:90.19ms +step:1403/1680 train_time:126531ms step_avg:90.19ms +step:1404/1680 train_time:126621ms step_avg:90.19ms +step:1405/1680 train_time:126711ms step_avg:90.19ms +step:1406/1680 train_time:126801ms step_avg:90.19ms +step:1407/1680 train_time:126892ms step_avg:90.19ms +step:1408/1680 train_time:126981ms step_avg:90.19ms +step:1409/1680 train_time:127074ms step_avg:90.19ms +step:1410/1680 train_time:127166ms step_avg:90.19ms +step:1411/1680 train_time:127257ms step_avg:90.19ms +step:1412/1680 train_time:127349ms step_avg:90.19ms +step:1413/1680 train_time:127440ms step_avg:90.19ms +step:1414/1680 train_time:127531ms step_avg:90.19ms +step:1415/1680 train_time:127622ms step_avg:90.19ms +step:1416/1680 train_time:127712ms step_avg:90.19ms +step:1417/1680 train_time:127802ms step_avg:90.19ms +step:1418/1680 train_time:127892ms step_avg:90.19ms +step:1419/1680 train_time:127983ms step_avg:90.19ms +step:1420/1680 train_time:128074ms step_avg:90.19ms +step:1421/1680 train_time:128166ms step_avg:90.19ms +step:1422/1680 train_time:128257ms step_avg:90.20ms +step:1423/1680 train_time:128349ms step_avg:90.20ms +step:1424/1680 train_time:128439ms step_avg:90.20ms +step:1425/1680 train_time:128530ms step_avg:90.20ms +step:1426/1680 train_time:128621ms step_avg:90.20ms +step:1427/1680 train_time:128712ms step_avg:90.20ms +step:1428/1680 train_time:128802ms step_avg:90.20ms +step:1429/1680 train_time:128893ms step_avg:90.20ms +step:1430/1680 train_time:128983ms step_avg:90.20ms +step:1431/1680 train_time:129075ms step_avg:90.20ms +step:1432/1680 train_time:129166ms step_avg:90.20ms +step:1433/1680 train_time:129258ms step_avg:90.20ms +step:1434/1680 train_time:129349ms step_avg:90.20ms +step:1435/1680 train_time:129440ms step_avg:90.20ms +step:1436/1680 train_time:129531ms step_avg:90.20ms +step:1437/1680 train_time:129622ms step_avg:90.20ms +step:1438/1680 train_time:129713ms step_avg:90.20ms +step:1439/1680 train_time:129804ms step_avg:90.20ms +step:1440/1680 train_time:129894ms step_avg:90.20ms +step:1441/1680 train_time:129985ms step_avg:90.21ms +step:1442/1680 train_time:130077ms step_avg:90.21ms +step:1443/1680 train_time:130168ms step_avg:90.21ms +step:1444/1680 train_time:130259ms step_avg:90.21ms +step:1445/1680 
train_time:130350ms step_avg:90.21ms +step:1446/1680 train_time:130442ms step_avg:90.21ms +step:1447/1680 train_time:130532ms step_avg:90.21ms +step:1448/1680 train_time:130624ms step_avg:90.21ms +step:1449/1680 train_time:130716ms step_avg:90.21ms +step:1450/1680 train_time:130807ms step_avg:90.21ms +step:1451/1680 train_time:130897ms step_avg:90.21ms +step:1452/1680 train_time:130989ms step_avg:90.21ms +step:1453/1680 train_time:131079ms step_avg:90.21ms +step:1454/1680 train_time:131171ms step_avg:90.21ms +step:1455/1680 train_time:131261ms step_avg:90.21ms +step:1456/1680 train_time:131353ms step_avg:90.21ms +step:1457/1680 train_time:131443ms step_avg:90.21ms +step:1458/1680 train_time:131535ms step_avg:90.22ms +step:1459/1680 train_time:131626ms step_avg:90.22ms +step:1460/1680 train_time:131717ms step_avg:90.22ms +step:1461/1680 train_time:131808ms step_avg:90.22ms +step:1462/1680 train_time:131900ms step_avg:90.22ms +step:1463/1680 train_time:131991ms step_avg:90.22ms +step:1464/1680 train_time:132082ms step_avg:90.22ms +step:1465/1680 train_time:132174ms step_avg:90.22ms +step:1466/1680 train_time:132265ms step_avg:90.22ms +step:1467/1680 train_time:132356ms step_avg:90.22ms +step:1468/1680 train_time:132446ms step_avg:90.22ms +step:1469/1680 train_time:132537ms step_avg:90.22ms +step:1470/1680 train_time:132628ms step_avg:90.22ms +step:1471/1680 train_time:132719ms step_avg:90.22ms +step:1472/1680 train_time:132810ms step_avg:90.22ms +step:1473/1680 train_time:132902ms step_avg:90.23ms +step:1474/1680 train_time:132993ms step_avg:90.23ms +step:1475/1680 train_time:133083ms step_avg:90.23ms +step:1476/1680 train_time:133174ms step_avg:90.23ms +step:1477/1680 train_time:133266ms step_avg:90.23ms +step:1478/1680 train_time:133356ms step_avg:90.23ms +step:1479/1680 train_time:133447ms step_avg:90.23ms +step:1480/1680 train_time:133539ms step_avg:90.23ms +step:1481/1680 train_time:133629ms step_avg:90.23ms +step:1482/1680 train_time:133720ms step_avg:90.23ms +step:1483/1680 train_time:133812ms step_avg:90.23ms +step:1484/1680 train_time:133902ms step_avg:90.23ms +step:1485/1680 train_time:133994ms step_avg:90.23ms +step:1486/1680 train_time:134084ms step_avg:90.23ms +step:1487/1680 train_time:134176ms step_avg:90.23ms +step:1488/1680 train_time:134269ms step_avg:90.23ms +step:1489/1680 train_time:134360ms step_avg:90.23ms +step:1490/1680 train_time:134450ms step_avg:90.24ms +step:1491/1680 train_time:134541ms step_avg:90.24ms +step:1492/1680 train_time:134631ms step_avg:90.24ms +step:1493/1680 train_time:134723ms step_avg:90.24ms +step:1494/1680 train_time:134814ms step_avg:90.24ms +step:1495/1680 train_time:134904ms step_avg:90.24ms +step:1496/1680 train_time:134995ms step_avg:90.24ms +step:1497/1680 train_time:135086ms step_avg:90.24ms +step:1498/1680 train_time:135178ms step_avg:90.24ms +step:1499/1680 train_time:135269ms step_avg:90.24ms +step:1500/1680 train_time:135360ms step_avg:90.24ms +step:1500/1680 val_loss:3.3134 train_time:135453ms step_avg:90.30ms +step:1501/1680 train_time:135476ms step_avg:90.26ms +step:1502/1680 train_time:135549ms step_avg:90.25ms +step:1503/1680 train_time:135646ms step_avg:90.25ms +step:1504/1680 train_time:135737ms step_avg:90.25ms +step:1505/1680 train_time:135828ms step_avg:90.25ms +step:1506/1680 train_time:135918ms step_avg:90.25ms +step:1507/1680 train_time:136008ms step_avg:90.25ms +step:1508/1680 train_time:136098ms step_avg:90.25ms +step:1509/1680 train_time:136187ms step_avg:90.25ms +step:1510/1680 train_time:136278ms step_avg:90.25ms 
+step:1511/1680 train_time:136368ms step_avg:90.25ms +step:1512/1680 train_time:136459ms step_avg:90.25ms +step:1513/1680 train_time:136553ms step_avg:90.25ms +step:1514/1680 train_time:136647ms step_avg:90.26ms +step:1515/1680 train_time:136738ms step_avg:90.26ms +step:1516/1680 train_time:136830ms step_avg:90.26ms +step:1517/1680 train_time:136920ms step_avg:90.26ms +step:1518/1680 train_time:137009ms step_avg:90.26ms +step:1519/1680 train_time:137100ms step_avg:90.26ms +step:1520/1680 train_time:137189ms step_avg:90.26ms +step:1521/1680 train_time:137279ms step_avg:90.26ms +step:1522/1680 train_time:137370ms step_avg:90.26ms +step:1523/1680 train_time:137461ms step_avg:90.26ms +step:1524/1680 train_time:137553ms step_avg:90.26ms +step:1525/1680 train_time:137647ms step_avg:90.26ms +step:1526/1680 train_time:137738ms step_avg:90.26ms +step:1527/1680 train_time:137829ms step_avg:90.26ms +step:1528/1680 train_time:137920ms step_avg:90.26ms +step:1529/1680 train_time:138010ms step_avg:90.26ms +step:1530/1680 train_time:138101ms step_avg:90.26ms +step:1531/1680 train_time:138190ms step_avg:90.26ms +step:1532/1680 train_time:138280ms step_avg:90.26ms +step:1533/1680 train_time:138371ms step_avg:90.26ms +step:1534/1680 train_time:138462ms step_avg:90.26ms +step:1535/1680 train_time:138554ms step_avg:90.26ms +step:1536/1680 train_time:138646ms step_avg:90.26ms +step:1537/1680 train_time:138737ms step_avg:90.26ms +step:1538/1680 train_time:138830ms step_avg:90.27ms +step:1539/1680 train_time:138922ms step_avg:90.27ms +step:1540/1680 train_time:139012ms step_avg:90.27ms +step:1541/1680 train_time:139104ms step_avg:90.27ms +step:1542/1680 train_time:139194ms step_avg:90.27ms +step:1543/1680 train_time:139284ms step_avg:90.27ms +step:1544/1680 train_time:139374ms step_avg:90.27ms +step:1545/1680 train_time:139466ms step_avg:90.27ms +step:1546/1680 train_time:139557ms step_avg:90.27ms +step:1547/1680 train_time:139648ms step_avg:90.27ms +step:1548/1680 train_time:139739ms step_avg:90.27ms +step:1549/1680 train_time:139830ms step_avg:90.27ms +step:1550/1680 train_time:139921ms step_avg:90.27ms +step:1551/1680 train_time:140011ms step_avg:90.27ms +step:1552/1680 train_time:140103ms step_avg:90.27ms +step:1553/1680 train_time:140192ms step_avg:90.27ms +step:1554/1680 train_time:140282ms step_avg:90.27ms +step:1555/1680 train_time:140373ms step_avg:90.27ms +step:1556/1680 train_time:140465ms step_avg:90.27ms +step:1557/1680 train_time:140556ms step_avg:90.27ms +step:1558/1680 train_time:140648ms step_avg:90.27ms +step:1559/1680 train_time:140740ms step_avg:90.28ms +step:1560/1680 train_time:140831ms step_avg:90.28ms +step:1561/1680 train_time:140923ms step_avg:90.28ms +step:1562/1680 train_time:141015ms step_avg:90.28ms +step:1563/1680 train_time:141106ms step_avg:90.28ms +step:1564/1680 train_time:141197ms step_avg:90.28ms +step:1565/1680 train_time:141288ms step_avg:90.28ms +step:1566/1680 train_time:141379ms step_avg:90.28ms +step:1567/1680 train_time:141470ms step_avg:90.28ms +step:1568/1680 train_time:141562ms step_avg:90.28ms +step:1569/1680 train_time:141654ms step_avg:90.28ms +step:1570/1680 train_time:141744ms step_avg:90.28ms +step:1571/1680 train_time:141835ms step_avg:90.28ms +step:1572/1680 train_time:141928ms step_avg:90.29ms +step:1573/1680 train_time:142021ms step_avg:90.29ms +step:1574/1680 train_time:142112ms step_avg:90.29ms +step:1575/1680 train_time:142202ms step_avg:90.29ms +step:1576/1680 train_time:142292ms step_avg:90.29ms +step:1577/1680 train_time:142382ms step_avg:90.29ms 
+step:1578/1680 train_time:142474ms step_avg:90.29ms +step:1579/1680 train_time:142565ms step_avg:90.29ms +step:1580/1680 train_time:142656ms step_avg:90.29ms +step:1581/1680 train_time:142747ms step_avg:90.29ms +step:1582/1680 train_time:142838ms step_avg:90.29ms +step:1583/1680 train_time:142930ms step_avg:90.29ms +step:1584/1680 train_time:143023ms step_avg:90.29ms +step:1585/1680 train_time:143113ms step_avg:90.29ms +step:1586/1680 train_time:143205ms step_avg:90.29ms +step:1587/1680 train_time:143296ms step_avg:90.29ms +step:1588/1680 train_time:143387ms step_avg:90.29ms +step:1589/1680 train_time:143479ms step_avg:90.29ms +step:1590/1680 train_time:143569ms step_avg:90.29ms +step:1591/1680 train_time:143660ms step_avg:90.30ms +step:1592/1680 train_time:143751ms step_avg:90.30ms +step:1593/1680 train_time:143842ms step_avg:90.30ms +step:1594/1680 train_time:143934ms step_avg:90.30ms +step:1595/1680 train_time:144025ms step_avg:90.30ms +step:1596/1680 train_time:144115ms step_avg:90.30ms +step:1597/1680 train_time:144207ms step_avg:90.30ms +step:1598/1680 train_time:144299ms step_avg:90.30ms +step:1599/1680 train_time:144390ms step_avg:90.30ms +step:1600/1680 train_time:144481ms step_avg:90.30ms +step:1601/1680 train_time:144572ms step_avg:90.30ms +step:1602/1680 train_time:144662ms step_avg:90.30ms +step:1603/1680 train_time:144754ms step_avg:90.30ms +step:1604/1680 train_time:144845ms step_avg:90.30ms +step:1605/1680 train_time:144936ms step_avg:90.30ms +step:1606/1680 train_time:145028ms step_avg:90.30ms +step:1607/1680 train_time:145119ms step_avg:90.30ms +step:1608/1680 train_time:145211ms step_avg:90.31ms +step:1609/1680 train_time:145303ms step_avg:90.31ms +step:1610/1680 train_time:145393ms step_avg:90.31ms +step:1611/1680 train_time:145484ms step_avg:90.31ms +step:1612/1680 train_time:145575ms step_avg:90.31ms +step:1613/1680 train_time:145666ms step_avg:90.31ms +step:1614/1680 train_time:145757ms step_avg:90.31ms +step:1615/1680 train_time:145848ms step_avg:90.31ms +step:1616/1680 train_time:145939ms step_avg:90.31ms +step:1617/1680 train_time:146031ms step_avg:90.31ms +step:1618/1680 train_time:146122ms step_avg:90.31ms +step:1619/1680 train_time:146212ms step_avg:90.31ms +step:1620/1680 train_time:146304ms step_avg:90.31ms +step:1621/1680 train_time:146396ms step_avg:90.31ms +step:1622/1680 train_time:146487ms step_avg:90.31ms +step:1623/1680 train_time:146577ms step_avg:90.31ms +step:1624/1680 train_time:146668ms step_avg:90.31ms +step:1625/1680 train_time:146759ms step_avg:90.31ms +step:1625/1680 val_loss:3.2897 train_time:146851ms step_avg:90.37ms +step:1626/1680 train_time:146875ms step_avg:90.33ms +step:1627/1680 train_time:146950ms step_avg:90.32ms +step:1628/1680 train_time:147046ms step_avg:90.32ms +step:1629/1680 train_time:147139ms step_avg:90.32ms +step:1630/1680 train_time:147230ms step_avg:90.32ms +step:1631/1680 train_time:147319ms step_avg:90.32ms +step:1632/1680 train_time:147409ms step_avg:90.32ms +step:1633/1680 train_time:147498ms step_avg:90.32ms +step:1634/1680 train_time:147588ms step_avg:90.32ms +step:1635/1680 train_time:147677ms step_avg:90.32ms +step:1636/1680 train_time:147767ms step_avg:90.32ms +step:1637/1680 train_time:147859ms step_avg:90.32ms +step:1638/1680 train_time:147951ms step_avg:90.32ms +step:1639/1680 train_time:148046ms step_avg:90.33ms +step:1640/1680 train_time:148139ms step_avg:90.33ms +step:1641/1680 train_time:148230ms step_avg:90.33ms +step:1642/1680 train_time:148320ms step_avg:90.33ms +step:1643/1680 train_time:148411ms 
step_avg:90.33ms +step:1644/1680 train_time:148501ms step_avg:90.33ms +step:1645/1680 train_time:148591ms step_avg:90.33ms +step:1646/1680 train_time:148681ms step_avg:90.33ms +step:1647/1680 train_time:148771ms step_avg:90.33ms +step:1648/1680 train_time:148864ms step_avg:90.33ms +step:1649/1680 train_time:148957ms step_avg:90.33ms +step:1650/1680 train_time:149049ms step_avg:90.33ms +step:1651/1680 train_time:149141ms step_avg:90.33ms +step:1652/1680 train_time:149233ms step_avg:90.33ms +step:1653/1680 train_time:149324ms step_avg:90.34ms +step:1654/1680 train_time:149414ms step_avg:90.34ms +step:1655/1680 train_time:149505ms step_avg:90.34ms +step:1656/1680 train_time:149594ms step_avg:90.33ms +step:1657/1680 train_time:149685ms step_avg:90.33ms +step:1658/1680 train_time:149774ms step_avg:90.33ms +step:1659/1680 train_time:149866ms step_avg:90.34ms +step:1660/1680 train_time:149957ms step_avg:90.34ms +step:1661/1680 train_time:150048ms step_avg:90.34ms +step:1662/1680 train_time:150141ms step_avg:90.34ms +step:1663/1680 train_time:150233ms step_avg:90.34ms +step:1664/1680 train_time:150325ms step_avg:90.34ms +step:1665/1680 train_time:150415ms step_avg:90.34ms +step:1666/1680 train_time:150506ms step_avg:90.34ms +step:1667/1680 train_time:150596ms step_avg:90.34ms +step:1668/1680 train_time:150686ms step_avg:90.34ms +step:1669/1680 train_time:150777ms step_avg:90.34ms +step:1670/1680 train_time:150868ms step_avg:90.34ms +step:1671/1680 train_time:150959ms step_avg:90.34ms +step:1672/1680 train_time:151050ms step_avg:90.34ms +step:1673/1680 train_time:151143ms step_avg:90.34ms +step:1674/1680 train_time:151235ms step_avg:90.34ms +step:1675/1680 train_time:151326ms step_avg:90.34ms +step:1676/1680 train_time:151418ms step_avg:90.34ms +step:1677/1680 train_time:151509ms step_avg:90.35ms +step:1678/1680 train_time:151600ms step_avg:90.35ms +step:1679/1680 train_time:151690ms step_avg:90.35ms +step:1680/1680 train_time:151780ms step_avg:90.35ms +step:1680/1680 val_loss:3.2789 train_time:151872ms step_avg:90.40ms +peak memory allocated: 31255 MiB reserved: 46814 MiB diff --git a/records/092125_DropAttn/bc936c5a-1d9f-4405-8648-de50e4d5aca6.txt b/records/092125_DropAttn/bc936c5a-1d9f-4405-8648-de50e4d5aca6.txt new file mode 100644 index 000000000..eef388b49 --- /dev/null +++ b/records/092125_DropAttn/bc936c5a-1d9f-4405-8648-de50e4d5aca6.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
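+ # How the static scales work (all of this is visible in the impl below):
+ # x_s, w_s and the grad_s parameter that follows are caller-chosen scale
+ # factors. Inputs are divided by them before the float8 cast so values fit
+ # the narrow e4m3/e5m2 ranges, and the same factors are handed to
+ # torch._scaled_mm as scale_a/scale_b to be multiplied back into the
+ # bfloat16 output.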
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
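+ # tl.swizzle2d (call completed below) remaps the row-major program grid
+ # into groups of GROUP_SIZE_M rows, so consecutively scheduled programs
+ # revisit the same tiles of A while they are still resident in L2. This is
+ # the usual Triton grouped-ordering idiom; the group size of 8 set in the
+ # autotune configs is an empirical choice.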
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
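+ # alpha and beta (declared next) fold the Newton-Schulz polynomial
+ # coefficients into this kernel: it computes C = alpha * (A @ A.T) + beta * A
+ # in one pass, so the b*A + c*A@A step needs no separate elementwise kernel.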
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
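+ # (batch strides above are set to 0 for 2-D inputs, so batch_idx * stride
+ # degenerates to 0 and the same kernel serves both the batched and
+ # single-matrix cases)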
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
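+ # Rough shape of the distributed step, with illustrative numbers rather
+ # than values from this run: with world_size=8 and 11 same-shaped params,
+ # the stack is padded to 16 and each rank owns a chunk of 2. Pass 1
+ # reduce-scatters averaged grads so a rank only receives its own chunk;
+ # pass 2 runs momentum + Newton-Schulz on that chunk alone (1/world_size
+ # of the orthogonalization work per rank); pass 3 all-gathers the updated
+ # params back to every rank. All collectives are async, so communication
+ # overlaps the Newton-Schulz compute.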
+ rank = dist.get_rank() + world_size = dist.get_world_size() + group_infos = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + if not params: + continue + + num_params = len(params) + padded_num_params = ( + (num_params + world_size - 1) // world_size * world_size + ) + + grads_to_stack = [p.grad for p in params] + if padded_num_params > num_params: + padding_grad = torch.zeros_like(params[0].grad) + grads_to_stack.extend( + [padding_grad] * (padded_num_params - num_params) + ) + + stacked_grads = torch.stack(grads_to_stack) + + chunk_size = padded_num_params // world_size + grad_chunk = torch.empty( + (chunk_size, *params[0].grad.shape), + dtype=stacked_grads.dtype, + device=stacked_grads.device, + ) + + reduce_future = dist.reduce_scatter_tensor( + grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True + ).get_future() + + group_infos.append( + { + "params": params, + "grad_chunk": grad_chunk, + "reduce_future": reduce_future, + "chunk_size": chunk_size, + "padded_num_params": padded_num_params, + } + ) + + all_gather_infos = [] + # Second pass: wait for gradients, compute updates for the local shard of parameters, + # and launch all async all_gather operations. + for group, info in zip(self.param_groups, group_infos): + info["reduce_future"].wait() + + params = info["params"] + grad_chunk = info["grad_chunk"] + chunk_size = info["chunk_size"] + start_idx = rank * chunk_size + + # Determine effective LR and WD once per group, assuming constant for same-shaped params. + # This helps in vectorizing operations later. + p_example = params[0] # All params in a group have the same shape. + eff_lr_val = ( + group["lr"] + * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5 + * getattr(p_example, "lr_mul", 1.0) + ) + eff_weight_decay_val = ( + group["lr"] + * group["weight_decay"] + * getattr(p_example, "wd_mul", 1.0) + ) + + # Prepare a contiguous buffer for the updated parameters for this rank's chunk. + # This buffer will serve as the input_tensor for dist.all_gather_into_tensor. + updated_param_chunk = torch.empty( + (chunk_size, *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + + # List to collect update_grad tensors for batched zeropower computation. + update_grads_for_zeropower = [] + + # Process each parameter in this rank's chunk. + for i in range(chunk_size): + param_idx = start_idx + i + + if param_idx >= len(params): + # For padding: fill the corresponding part of the updated_param_chunk with zeros. + # These padded entries will not be used by other ranks in the all_gather, but + # initializing them prevents uninitialized memory access issues. + updated_param_chunk[i].zero_() + # Also append a zero tensor for zeropower input so the batch stays aligned. + update_grads_for_zeropower.append( + torch.zeros_like(params[0].grad) + ) + continue + p = params[param_idx] + grad = grad_chunk[ + i + ] # This gradient corresponds to the current parameter p.
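+ # The momentum update below is Nesterov-style, written with lerp:
+ # buf <- m*buf + (1-m)*g, then update = (1-m)*g + m*buf, which expands to
+ # (1-m^2)*g + m^2*buf_old -- the incoming gradient gets extra weight
+ # through the freshly updated buffer before orthogonalization.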
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
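+ # (Waiting is deferred to this last pass so every group's all_gather is
+ # already in flight before the first wait; the copy-back for one group
+ # then overlaps communication for the others.)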
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length
+ assert B == 1, "varlen sequences require B == 1"
+ assert T % 16 == 0
+ # unpack attention args
+ cos, sin = attn_args.cos, attn_args.sin
+ ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas
+ seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
+
+ q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+ q, k = norm(q), norm(k) # QK norm @Grad62304977
+ q, k = rotary(q, cos, sin), rotary(k, cos, sin)
+ if ve is not None:
+ v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+ else: # skip mid-layers token value embeddings by @YouJiacheng
+ v = sa_lambdas[0] * v
+
+ max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+ # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+ y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+ causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
+ y = y.view(B, T, self.num_heads, self.head_dim)
+ y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
+ y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+ y = F.linear(y, self.qkvo_w[3].type_as(y))
+ return y
+
+class MLP(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ hdim = 4 * dim
+ # make both matrices have the same shape because the optimizer sorts params by shape
+ # (2 matrices x 12 layers = 24 total, 22 in practice since the first-layer MLP is skipped;
+ # Muon pads each group to a multiple of the world size anyway)
+ self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+ self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+ std = 0.5 * (dim ** -0.5)
+ bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+ with torch.no_grad():
+ self.c_fc.uniform_(-bound, bound)
+ self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+ def forward(self, x: Tensor):
+ x = F.linear(x, self.c_fc.T.type_as(x))
+ x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+ x = F.linear(x, self.c_proj.type_as(x))
+ return x
+
+class Block(nn.Module):
+ def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
+ super().__init__()
+ # skip attention of blocks.0 and of blocks.7 (the 8th layer); the latter by @YouJiacheng
+ self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx != 0 else None
+ # skip the MLP block of the first layer by @EmelyanenkoK
+ self.mlp = MLP(dim) if layer_idx != 0 else None
+
+ def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs):
+ x = lambdas[0] * x + lambdas[1] * x0
+ if self.attn is not None:
+ x = x + self.attn(norm(x), attn_args)
+ if self.mlp is not None:
+ x = x + self.mlp(norm(x))
+ return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+ return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int):
+ super().__init__()
+ vocab_size = next_multiple_of_n(vocab_size, n=128)
+ self.embed = nn.Embedding(vocab_size, model_dim)
+ self.smear_gate = CastedLinear(12, 1)
+
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
+ ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+ assert len(ve) == len(self.blocks)
+
+ long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
+ final_bm = ws_final_layer * args.block_size
+ first_bm = 1 * args.block_size
+ bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm]
+ assert len(bm_sizes) == len(self.blocks)
+
+ x = self.embed(input_seq)
+
+ # smear token embed forward 1 position @classiclarryd
+ smear_lambda = self.scalars[5 * len(self.blocks)]
+ smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
+ x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
+ x = x0 = norm(x[None])
+
+ # U-net design by @brendanh0gan
+ skip_connections = []
+ skip_weights = self.scalars[:(len(self.blocks) // 2)]
+ lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2)
+ sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2)
+
+ n = len(self.blocks) // 2
+
+ for i in range(1, len(self.blocks)):
+ attn_args = AttnArgs(
+ ve=ve[i],
+ sa_lambdas=sa_lambdas[i],
+ seqlens=seqlens,
+ bm_size=bm_sizes[i],
+ cos=self.yarn.cos,
+ sin=self.yarn.sin,
+ attn_scale=self.yarn.attn_scale
+ )
+ if i >= n and i < 11: # the last block (index 11) gets no skip connection
+ gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
+ x = x + gate * skip_connections.pop()
+ x = self.blocks[i](x, x0, lambdas[i], attn_args)
+ if i < n:
+ skip_connections.append(x)
+
+ x = norm(x)
+ logits = self.lm_head(x).float()
+ # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
+ logits = 30 * torch.sigmoid(logits / 7.5)
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
+ return loss
+
+# -----------------------------------------------------------------------------
+# Distributed data loader
+
+def _load_data_shard(file: Path):
+ header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+ assert header[1] == 1, "unsupported version"
+ num_tokens = int(header[2]) # number of tokens (claimed)
+ with file.open("rb", buffering=0) as f:
+ tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
+ f.seek(256 * 4)
+ nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
+ assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+ return tokens
+
+BOS_ID = 50256
+
+class BOSFinder:
+ # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd
+ def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
+ # Precompute BOS positions once per shard
+ self.tokens = tokens
+ self.size = tokens.numel()
+ self.quickload = quickload
+ if quickload:
+ # only scan the first 4 million tokens, then kick off an async thread to scan the rest
+ self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.thread = None
+ self.ready = threading.Event()
+ self.start()
+ else:
+ self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.i = 0
+ self.world_size = world_size
+ self.batch_iter = 0
+
+ def _load(self):
+ self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+ self.ready.set()
+
+ def start(self):
+ self.ready.clear()
+ self.thread = threading.Thread(target=self._load)
+ self.thread.start()
+
+ def get(self):
+ if self.thread:
+ self.ready.wait()
+ self.thread.join()
+ self.bos_idx = self.bos_idx_async
+
+ def next_batch(self, num_tokens_local: int, max_seq_len: int):
+ # if quickload was used, repoint to the full BOS index after 5 batches
+ if self.quickload and self.batch_iter == 5:
+ self.get()
+ n = len(self.bos_idx)
+ starts = [[] for _ in range(self.world_size)]
+ ends = [[] for _ in range(self.world_size)]
+
+ idx = self.i
+ for r in range(self.world_size):
+ cur_len = 0
+ while cur_len <= num_tokens_local:
+ if idx >= n:
+ raise StopIteration(f"Insufficient BOS tokens after index {idx}; hit tail of shard.")
+ cur = self.bos_idx[idx]
+ starts[r].append(cur)
+ end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+ cur + max_seq_len,
+ cur + num_tokens_local - cur_len + 1)
+ ends[r].append(end)
+ cur_len += end - cur
+ idx += 1
+
+ assert cur_len == num_tokens_local + 1
+ self.i = idx
+ self.batch_iter += 1
+ return starts, ends
+
+class DataPreloader:
+ # Helper for asynchronously loading the next shard and indexing its BOS tokens
+ def __init__(self, file_iter, world_size: int = 1):
+ self.file_iter = file_iter
+ self.world_size = world_size
+ self.thread = None
+ self.data = None
+ self.ready = threading.Event()
+
+ def _load(self):
+ tokens = _load_data_shard(next(self.file_iter))
+ self.data = (tokens, BOSFinder(tokens, self.world_size))
+ self.ready.set()
+
+ def start(self):
+ self.ready.clear()
+ self.thread = threading.Thread(target=self._load)
+ self.thread.start()
+
+ def get(self):
+ if self.thread:
+ self.ready.wait()
+ self.thread.join()
+ return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+ # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+ num_tokens = num_tokens // grad_accum_steps
+
+ files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+ if not files:
+ raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+ file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+ tokens = _load_data_shard(next(file_iter))
+ if align_to_bos:
+ finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+ preloader = DataPreloader(file_iter, world_size)
+ preloader.start()
+ else:
+ pos = 0 # for unaligned case
+
+ while True:
+ num_tokens_local = num_tokens // world_size
+ max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+ if align_to_bos:
+ try:
+ seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+ start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+ except StopIteration:
+ # This shard is exhausted, load the next one in the next loop iteration.
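+ # The DataPreloader indexed the next shard's BOS positions on a background
+ # thread while this shard was being consumed, so get() should return almost
+ # immediately; start() then kicks off the scan for the shard after it.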
+ tokens, finder = preloader.get()
+ preloader.start()
+ continue
+
+ buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+ _inputs = buf[:-1]
+ _targets = buf[1:]
+ end_idxs[-1] -= 1 # shorten the last document by one token to account for the _targets offset
+ cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+ else:
+ if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+ tokens, pos = _load_data_shard(next(file_iter)), 0
+
+ pos_local = pos + rank * num_tokens_local
+ buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+ _inputs = buf[:-1].view(num_tokens_local)
+ _targets = buf[1:].view(num_tokens_local)
+
+ cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+ pos += num_tokens
+
+ _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+ _cum_lengths[0] = 0
+ _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+ new_params = yield (
+ _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+ _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+ _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+ )
+
+ if new_params is not None:
+ # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+ new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+ assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+ num_tokens = new_num_tokens
+ max_seq_len = new_max_seq_len
+ grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+ # data
+ train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+ val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+ val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+ train_batch_size: int = 2048 * 24 * 8
+ train_max_seq_len: int = 128 * 16
+ val_batch_size: int = 4 * 64 * 1024 * 8
+ # optimization
+ num_iterations: int = 1640 # number of iterations to run
+ iteration_extension: int = 40
+ cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+ # evaluation and logging
+ run_id: str = f"v1/{uuid.uuid4()}"
+ val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+ save_checkpoint: bool = False
+ # attention masking
+ block_size: int = 128
+ ws_schedule: tuple = (3, 7, 11)
+ ws_validate: int = 13 # increase final validation ws @classiclarryd
+ ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
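+
+# Minimal standalone sketch (illustrative, not part of the training path) of the
+# generator .send() handshake used by distributed_data_generator above: the
+# training loop can push a new (num_tokens, max_seq_len, grad_accum_steps) tuple
+# mid-stream and the generator reconfigures before building the next batch.
+def _sketch_send_protocol():
+    def gen(num_tokens: int):
+        new_params = None
+        while True:
+            if new_params is not None:
+                num_tokens = new_params # adopt the new configuration
+            new_params = yield num_tokens # hand out a "batch" (just its size here)
+    g = gen(1024)
+    assert next(g) == 1024 # prime the generator and get the first batch
+    assert g.send(2048) == 2048 # .send() delivers the new config and resumes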
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params
+ if new_ws > ws:
+ model.yarn.apply(ws, new_ws)
+ ws = new_ws
+ elif new_ws < ws:
+ model.yarn.reset() # cycled back to the smallest window: restore base RoPE freqs and attn scale
+ ws = new_ws
+ model(inputs, targets, cum_seqlens, ws, ws).backward()
+ for opt in optimizers:
+ opt.step()
+ model.zero_grad(set_to_none=True)
+# restore the initial state so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+ opt.load_state_dict(opt_state)
+del initial_state
+model.yarn.reset() # yarn cos/sin buffers are non-persistent, so reset them explicitly
+ws = args.ws_schedule[0]
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+train_steps = args.num_iterations + args.iteration_extension
+ws_final_layer = ws
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+for step in range(train_steps + 1):
+ last_step = (step == train_steps)
+ # advance the attention-window schedule; growing the window re-applies YaRN
+ new_ws, new_ws_final_layer = get_ws(step)
+ if new_ws > ws:
+ model.yarn.apply(ws, new_ws)
+ ws, ws_final_layer = new_ws, new_ws_final_layer
+
+ # --------------- VALIDATION SECTION -----------------
+ if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+ # stop the clock
+ torch.cuda.synchronize()
+ training_time_ms += 1000 * (time.perf_counter() - t0)
+ model.eval()
+ assert args.val_tokens % args.val_batch_size == 0
+ val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+ val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+ val_loss = 0
+ with torch.no_grad():
+ for _ in range(val_steps):
+ inputs, targets, cum_seqlens = next(val_loader)
+ val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+ val_loss /= val_steps
+ del val_loader
+ dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+ print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+ model.train()
+ # start the clock again
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+
+ if last_step:
+ if master_process and args.save_checkpoint:
+ log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+ os.makedirs(f"logs/{run_id}", exist_ok=True)
+ torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+ # the last step only has the validation loop, so break to avoid training
+ break
+
+ # --------------- TRAINING SECTION -----------------
+ for _ in range(grad_accum_steps):
+ inputs, targets, cum_seqlens = next(train_loader)
+ model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+ # set optimization hyperparameters
+ for opt in optimizers:
+ for group in opt.param_groups:
+ group["lr"] = group["initial_lr"] * get_lr(step)
+ for group in optimizer2.param_groups:
+ frac = min(step / 300, 1) # momentum warmup for muon
+ group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+ # step the optimizers
+ for opt in optimizers:
+ opt.step()
+ # null the gradients
+ model.zero_grad(set_to_none=True)
+ # logging
+ approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+ print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sun Sep 21 22:41:29 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 44C P0 126W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 39C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 40C P0 127W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 65397 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 65398 C /usr/bin/python3 614MiB | +| 0 N/A N/A 65399 C /usr/bin/python3 614MiB | +| 0 N/A N/A 65400 C /usr/bin/python3 614MiB | +| 0 N/A N/A 65401 C /usr/bin/python3 614MiB | +| 0 N/A N/A 65402 C /usr/bin/python3 614MiB | +| 0 N/A N/A 65403 C /usr/bin/python3 614MiB | +| 0 N/A N/A 65404 C /usr/bin/python3 614MiB | +| 1 N/A N/A 65398 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 65399 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 65400 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 65401 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 65402 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 65403 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 65404 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:154ms step_avg:154.18ms +step:2/1680 train_time:180ms step_avg:89.95ms +step:3/1680 train_time:240ms step_avg:80.11ms +step:4/1680 train_time:327ms step_avg:81.83ms +step:5/1680 train_time:416ms step_avg:83.12ms +step:6/1680 train_time:504ms step_avg:83.96ms +step:7/1680 train_time:592ms step_avg:84.61ms +step:8/1680 train_time:680ms 
step_avg:85.05ms +step:9/1680 train_time:769ms step_avg:85.40ms +step:10/1680 train_time:857ms step_avg:85.67ms +step:11/1680 train_time:945ms step_avg:85.92ms +step:12/1680 train_time:1037ms step_avg:86.44ms +step:13/1680 train_time:1130ms step_avg:86.95ms +step:14/1680 train_time:1221ms step_avg:87.23ms +step:15/1680 train_time:1311ms step_avg:87.40ms +step:16/1680 train_time:1401ms step_avg:87.53ms +step:17/1680 train_time:1490ms step_avg:87.63ms +step:18/1680 train_time:1578ms step_avg:87.67ms +step:19/1680 train_time:1666ms step_avg:87.70ms +step:20/1680 train_time:1755ms step_avg:87.73ms +step:21/1680 train_time:1843ms step_avg:87.77ms +step:22/1680 train_time:1932ms step_avg:87.80ms +step:23/1680 train_time:2021ms step_avg:87.88ms +step:24/1680 train_time:2114ms step_avg:88.09ms +step:25/1680 train_time:2205ms step_avg:88.21ms +step:26/1680 train_time:2295ms step_avg:88.26ms +step:27/1680 train_time:2384ms step_avg:88.30ms +step:28/1680 train_time:2473ms step_avg:88.33ms +step:29/1680 train_time:2563ms step_avg:88.39ms +step:30/1680 train_time:2653ms step_avg:88.43ms +step:31/1680 train_time:2741ms step_avg:88.43ms +step:32/1680 train_time:2830ms step_avg:88.44ms +step:33/1680 train_time:2919ms step_avg:88.45ms +step:34/1680 train_time:3009ms step_avg:88.49ms +step:35/1680 train_time:3098ms step_avg:88.52ms +step:36/1680 train_time:3188ms step_avg:88.55ms +step:37/1680 train_time:3277ms step_avg:88.58ms +step:38/1680 train_time:3367ms step_avg:88.60ms +step:39/1680 train_time:3456ms step_avg:88.62ms +step:40/1680 train_time:3546ms step_avg:88.64ms +step:41/1680 train_time:3635ms step_avg:88.66ms +step:42/1680 train_time:3724ms step_avg:88.67ms +step:43/1680 train_time:3814ms step_avg:88.70ms +step:44/1680 train_time:3903ms step_avg:88.69ms +step:45/1680 train_time:3992ms step_avg:88.70ms +step:46/1680 train_time:4081ms step_avg:88.72ms +step:47/1680 train_time:4171ms step_avg:88.75ms +step:48/1680 train_time:4261ms step_avg:88.76ms +step:49/1680 train_time:4351ms step_avg:88.80ms +step:50/1680 train_time:4440ms step_avg:88.80ms +step:51/1680 train_time:4530ms step_avg:88.83ms +step:52/1680 train_time:4619ms step_avg:88.83ms +step:53/1680 train_time:4708ms step_avg:88.83ms +step:54/1680 train_time:4797ms step_avg:88.83ms +step:55/1680 train_time:4885ms step_avg:88.83ms +step:56/1680 train_time:4974ms step_avg:88.83ms +step:57/1680 train_time:5064ms step_avg:88.84ms +step:58/1680 train_time:5153ms step_avg:88.85ms +step:59/1680 train_time:5243ms step_avg:88.87ms +step:60/1680 train_time:5333ms step_avg:88.88ms +step:61/1680 train_time:5423ms step_avg:88.89ms +step:62/1680 train_time:5512ms step_avg:88.91ms +step:63/1680 train_time:5602ms step_avg:88.92ms +step:64/1680 train_time:5691ms step_avg:88.92ms +step:65/1680 train_time:5780ms step_avg:88.92ms +step:66/1680 train_time:5870ms step_avg:88.94ms +step:67/1680 train_time:5960ms step_avg:88.95ms +step:68/1680 train_time:6048ms step_avg:88.95ms +step:69/1680 train_time:6137ms step_avg:88.95ms +step:70/1680 train_time:6226ms step_avg:88.94ms +step:71/1680 train_time:6316ms step_avg:88.95ms +step:72/1680 train_time:6405ms step_avg:88.96ms +step:73/1680 train_time:6494ms step_avg:88.96ms +step:74/1680 train_time:6582ms step_avg:88.95ms +step:75/1680 train_time:6672ms step_avg:88.96ms +step:76/1680 train_time:6761ms step_avg:88.96ms +step:77/1680 train_time:6850ms step_avg:88.96ms +step:78/1680 train_time:6940ms step_avg:88.97ms +step:79/1680 train_time:7029ms step_avg:88.97ms +step:80/1680 train_time:7118ms step_avg:88.98ms +step:81/1680 
train_time:7207ms step_avg:88.98ms +step:82/1680 train_time:7297ms step_avg:88.99ms +step:83/1680 train_time:7386ms step_avg:88.98ms +step:84/1680 train_time:7474ms step_avg:88.98ms +step:85/1680 train_time:7564ms step_avg:88.98ms +step:86/1680 train_time:7652ms step_avg:88.98ms +step:87/1680 train_time:7742ms step_avg:88.99ms +step:88/1680 train_time:7832ms step_avg:89.00ms +step:89/1680 train_time:7921ms step_avg:89.00ms +step:90/1680 train_time:8010ms step_avg:89.00ms +step:91/1680 train_time:8099ms step_avg:89.00ms +step:92/1680 train_time:8187ms step_avg:88.99ms +step:93/1680 train_time:8276ms step_avg:88.99ms +step:94/1680 train_time:8366ms step_avg:89.00ms +step:95/1680 train_time:8455ms step_avg:89.00ms +step:96/1680 train_time:8544ms step_avg:89.00ms +step:97/1680 train_time:8633ms step_avg:89.00ms +step:98/1680 train_time:8722ms step_avg:89.00ms +step:99/1680 train_time:8812ms step_avg:89.01ms +step:100/1680 train_time:8900ms step_avg:89.00ms +step:101/1680 train_time:8989ms step_avg:89.00ms +step:102/1680 train_time:9078ms step_avg:89.00ms +step:103/1680 train_time:9167ms step_avg:89.00ms +step:104/1680 train_time:9255ms step_avg:88.99ms +step:105/1680 train_time:9345ms step_avg:89.00ms +step:106/1680 train_time:9433ms step_avg:88.99ms +step:107/1680 train_time:9523ms step_avg:89.00ms +step:108/1680 train_time:9612ms step_avg:89.00ms +step:109/1680 train_time:9701ms step_avg:89.00ms +step:110/1680 train_time:9791ms step_avg:89.01ms +step:111/1680 train_time:9879ms step_avg:89.00ms +step:112/1680 train_time:9967ms step_avg:88.99ms +step:113/1680 train_time:10056ms step_avg:88.99ms +step:114/1680 train_time:10145ms step_avg:89.00ms +step:115/1680 train_time:10234ms step_avg:88.99ms +step:116/1680 train_time:10323ms step_avg:88.99ms +step:117/1680 train_time:10412ms step_avg:88.99ms +step:118/1680 train_time:10502ms step_avg:89.00ms +step:119/1680 train_time:10591ms step_avg:89.00ms +step:120/1680 train_time:10680ms step_avg:89.00ms +step:121/1680 train_time:10770ms step_avg:89.00ms +step:122/1680 train_time:10858ms step_avg:89.00ms +step:123/1680 train_time:10947ms step_avg:89.00ms +step:124/1680 train_time:11036ms step_avg:89.00ms +step:125/1680 train_time:11124ms step_avg:88.99ms +step:125/1680 val_loss:4.3230 train_time:11215ms step_avg:89.72ms +step:126/1680 train_time:11242ms step_avg:89.22ms +step:127/1680 train_time:11305ms step_avg:89.02ms +step:128/1680 train_time:11402ms step_avg:89.08ms +step:129/1680 train_time:11495ms step_avg:89.11ms +step:130/1680 train_time:11584ms step_avg:89.11ms +step:131/1680 train_time:11673ms step_avg:89.11ms +step:132/1680 train_time:11762ms step_avg:89.11ms +step:133/1680 train_time:11849ms step_avg:89.09ms +step:134/1680 train_time:11937ms step_avg:89.08ms +step:135/1680 train_time:12025ms step_avg:89.07ms +step:136/1680 train_time:12113ms step_avg:89.06ms +step:137/1680 train_time:12201ms step_avg:89.06ms +step:138/1680 train_time:12290ms step_avg:89.06ms +step:139/1680 train_time:12381ms step_avg:89.07ms +step:140/1680 train_time:12471ms step_avg:89.08ms +step:141/1680 train_time:12561ms step_avg:89.08ms +step:142/1680 train_time:12650ms step_avg:89.08ms +step:143/1680 train_time:12739ms step_avg:89.09ms +step:144/1680 train_time:12828ms step_avg:89.08ms +step:145/1680 train_time:12916ms step_avg:89.08ms +step:146/1680 train_time:13005ms step_avg:89.07ms +step:147/1680 train_time:13093ms step_avg:89.07ms +step:148/1680 train_time:13182ms step_avg:89.06ms +step:149/1680 train_time:13270ms step_avg:89.06ms +step:150/1680 train_time:13361ms 
step_avg:89.07ms +step:151/1680 train_time:13451ms step_avg:89.08ms +step:152/1680 train_time:13540ms step_avg:89.08ms +step:153/1680 train_time:13629ms step_avg:89.08ms +step:154/1680 train_time:13718ms step_avg:89.08ms +step:155/1680 train_time:13806ms step_avg:89.07ms +step:156/1680 train_time:13895ms step_avg:89.07ms +step:157/1680 train_time:13984ms step_avg:89.07ms +step:158/1680 train_time:14073ms step_avg:89.07ms +step:159/1680 train_time:14163ms step_avg:89.08ms +step:160/1680 train_time:14252ms step_avg:89.07ms +step:161/1680 train_time:14342ms step_avg:89.08ms +step:162/1680 train_time:14431ms step_avg:89.08ms +step:163/1680 train_time:14520ms step_avg:89.08ms +step:164/1680 train_time:14610ms step_avg:89.08ms +step:165/1680 train_time:14699ms step_avg:89.09ms +step:166/1680 train_time:14788ms step_avg:89.08ms +step:167/1680 train_time:14876ms step_avg:89.08ms +step:168/1680 train_time:14965ms step_avg:89.08ms +step:169/1680 train_time:15054ms step_avg:89.08ms +step:170/1680 train_time:15143ms step_avg:89.08ms +step:171/1680 train_time:15231ms step_avg:89.07ms +step:172/1680 train_time:15320ms step_avg:89.07ms +step:173/1680 train_time:15409ms step_avg:89.07ms +step:174/1680 train_time:15498ms step_avg:89.07ms +step:175/1680 train_time:15588ms step_avg:89.07ms +step:176/1680 train_time:15678ms step_avg:89.08ms +step:177/1680 train_time:15766ms step_avg:89.08ms +step:178/1680 train_time:15856ms step_avg:89.08ms +step:179/1680 train_time:15944ms step_avg:89.07ms +step:180/1680 train_time:16033ms step_avg:89.07ms +step:181/1680 train_time:16122ms step_avg:89.07ms +step:182/1680 train_time:16210ms step_avg:89.07ms +step:183/1680 train_time:16299ms step_avg:89.07ms +step:184/1680 train_time:16388ms step_avg:89.06ms +step:185/1680 train_time:16477ms step_avg:89.06ms +step:186/1680 train_time:16566ms step_avg:89.06ms +step:187/1680 train_time:16656ms step_avg:89.07ms +step:188/1680 train_time:16745ms step_avg:89.07ms +step:189/1680 train_time:16834ms step_avg:89.07ms +step:190/1680 train_time:16923ms step_avg:89.07ms +step:191/1680 train_time:17012ms step_avg:89.07ms +step:192/1680 train_time:17101ms step_avg:89.07ms +step:193/1680 train_time:17190ms step_avg:89.07ms +step:194/1680 train_time:17279ms step_avg:89.07ms +step:195/1680 train_time:17368ms step_avg:89.07ms +step:196/1680 train_time:17457ms step_avg:89.07ms +step:197/1680 train_time:17546ms step_avg:89.07ms +step:198/1680 train_time:17636ms step_avg:89.07ms +step:199/1680 train_time:17725ms step_avg:89.07ms +step:200/1680 train_time:17815ms step_avg:89.07ms +step:201/1680 train_time:17904ms step_avg:89.07ms +step:202/1680 train_time:17992ms step_avg:89.07ms +step:203/1680 train_time:18082ms step_avg:89.07ms +step:204/1680 train_time:18171ms step_avg:89.07ms +step:205/1680 train_time:18260ms step_avg:89.07ms +step:206/1680 train_time:18349ms step_avg:89.07ms +step:207/1680 train_time:18438ms step_avg:89.07ms +step:208/1680 train_time:18527ms step_avg:89.07ms +step:209/1680 train_time:18615ms step_avg:89.07ms +step:210/1680 train_time:18704ms step_avg:89.07ms +step:211/1680 train_time:18794ms step_avg:89.07ms +step:212/1680 train_time:18882ms step_avg:89.07ms +step:213/1680 train_time:18971ms step_avg:89.07ms +step:214/1680 train_time:19060ms step_avg:89.07ms +step:215/1680 train_time:19149ms step_avg:89.06ms +step:216/1680 train_time:19238ms step_avg:89.06ms +step:217/1680 train_time:19327ms step_avg:89.06ms +step:218/1680 train_time:19416ms step_avg:89.07ms +step:219/1680 train_time:19505ms step_avg:89.07ms +step:220/1680 
train_time:19594ms step_avg:89.06ms +step:221/1680 train_time:19684ms step_avg:89.07ms +step:222/1680 train_time:19774ms step_avg:89.07ms +step:223/1680 train_time:19863ms step_avg:89.07ms +step:224/1680 train_time:19952ms step_avg:89.07ms +step:225/1680 train_time:20041ms step_avg:89.07ms +step:226/1680 train_time:20130ms step_avg:89.07ms +step:227/1680 train_time:20218ms step_avg:89.07ms +step:228/1680 train_time:20308ms step_avg:89.07ms +step:229/1680 train_time:20398ms step_avg:89.07ms +step:230/1680 train_time:20487ms step_avg:89.07ms +step:231/1680 train_time:20576ms step_avg:89.07ms +step:232/1680 train_time:20665ms step_avg:89.07ms +step:233/1680 train_time:20754ms step_avg:89.07ms +step:234/1680 train_time:20843ms step_avg:89.07ms +step:235/1680 train_time:20932ms step_avg:89.07ms +step:236/1680 train_time:21021ms step_avg:89.07ms +step:237/1680 train_time:21109ms step_avg:89.07ms +step:238/1680 train_time:21198ms step_avg:89.07ms +step:239/1680 train_time:21287ms step_avg:89.07ms +step:240/1680 train_time:21376ms step_avg:89.07ms +step:241/1680 train_time:21465ms step_avg:89.07ms +step:242/1680 train_time:21554ms step_avg:89.07ms +step:243/1680 train_time:21644ms step_avg:89.07ms +step:244/1680 train_time:21733ms step_avg:89.07ms +step:245/1680 train_time:21821ms step_avg:89.07ms +step:246/1680 train_time:21911ms step_avg:89.07ms +step:247/1680 train_time:22000ms step_avg:89.07ms +step:248/1680 train_time:22089ms step_avg:89.07ms +step:249/1680 train_time:22178ms step_avg:89.07ms +step:250/1680 train_time:22267ms step_avg:89.07ms +step:250/1680 val_loss:3.9675 train_time:22358ms step_avg:89.43ms +step:251/1680 train_time:22381ms step_avg:89.17ms +step:252/1680 train_time:22451ms step_avg:89.09ms +step:253/1680 train_time:22543ms step_avg:89.10ms +step:254/1680 train_time:22636ms step_avg:89.12ms +step:255/1680 train_time:22725ms step_avg:89.12ms +step:256/1680 train_time:22813ms step_avg:89.11ms +step:257/1680 train_time:22901ms step_avg:89.11ms +step:258/1680 train_time:22989ms step_avg:89.10ms +step:259/1680 train_time:23078ms step_avg:89.10ms +step:260/1680 train_time:23166ms step_avg:89.10ms +step:261/1680 train_time:23254ms step_avg:89.10ms +step:262/1680 train_time:23344ms step_avg:89.10ms +step:263/1680 train_time:23435ms step_avg:89.11ms +step:264/1680 train_time:23526ms step_avg:89.11ms +step:265/1680 train_time:23618ms step_avg:89.12ms +step:266/1680 train_time:23707ms step_avg:89.12ms +step:267/1680 train_time:23796ms step_avg:89.12ms +step:268/1680 train_time:23885ms step_avg:89.12ms +step:269/1680 train_time:23973ms step_avg:89.12ms +step:270/1680 train_time:24061ms step_avg:89.12ms +step:271/1680 train_time:24150ms step_avg:89.11ms +step:272/1680 train_time:24239ms step_avg:89.11ms +step:273/1680 train_time:24327ms step_avg:89.11ms +step:274/1680 train_time:24417ms step_avg:89.11ms +step:275/1680 train_time:24507ms step_avg:89.12ms +step:276/1680 train_time:24598ms step_avg:89.12ms +step:277/1680 train_time:24688ms step_avg:89.13ms +step:278/1680 train_time:24777ms step_avg:89.13ms +step:279/1680 train_time:24866ms step_avg:89.13ms +step:280/1680 train_time:24956ms step_avg:89.13ms +step:281/1680 train_time:25044ms step_avg:89.13ms +step:282/1680 train_time:25133ms step_avg:89.12ms +step:283/1680 train_time:25222ms step_avg:89.12ms +step:284/1680 train_time:25311ms step_avg:89.12ms +step:285/1680 train_time:25400ms step_avg:89.12ms +step:286/1680 train_time:25490ms step_avg:89.13ms +step:287/1680 train_time:25581ms step_avg:89.13ms +step:288/1680 train_time:25670ms 
step_avg:89.13ms +step:289/1680 train_time:25759ms step_avg:89.13ms +step:290/1680 train_time:25848ms step_avg:89.13ms +step:291/1680 train_time:25938ms step_avg:89.13ms +step:292/1680 train_time:26027ms step_avg:89.13ms +step:293/1680 train_time:26116ms step_avg:89.13ms +step:294/1680 train_time:26205ms step_avg:89.13ms +step:295/1680 train_time:26294ms step_avg:89.13ms +step:296/1680 train_time:26382ms step_avg:89.13ms +step:297/1680 train_time:26472ms step_avg:89.13ms +step:298/1680 train_time:26561ms step_avg:89.13ms +step:299/1680 train_time:26651ms step_avg:89.13ms +step:300/1680 train_time:26740ms step_avg:89.13ms +step:301/1680 train_time:26830ms step_avg:89.14ms +step:302/1680 train_time:26920ms step_avg:89.14ms +step:303/1680 train_time:27009ms step_avg:89.14ms +step:304/1680 train_time:27097ms step_avg:89.14ms +step:305/1680 train_time:27187ms step_avg:89.14ms +step:306/1680 train_time:27276ms step_avg:89.14ms +step:307/1680 train_time:27365ms step_avg:89.14ms +step:308/1680 train_time:27454ms step_avg:89.14ms +step:309/1680 train_time:27544ms step_avg:89.14ms +step:310/1680 train_time:27633ms step_avg:89.14ms +step:311/1680 train_time:27722ms step_avg:89.14ms +step:312/1680 train_time:27811ms step_avg:89.14ms +step:313/1680 train_time:27900ms step_avg:89.14ms +step:314/1680 train_time:27990ms step_avg:89.14ms +step:315/1680 train_time:28078ms step_avg:89.14ms +step:316/1680 train_time:28168ms step_avg:89.14ms +step:317/1680 train_time:28256ms step_avg:89.14ms +step:318/1680 train_time:28346ms step_avg:89.14ms +step:319/1680 train_time:28435ms step_avg:89.14ms +step:320/1680 train_time:28524ms step_avg:89.14ms +step:321/1680 train_time:28614ms step_avg:89.14ms +step:322/1680 train_time:28703ms step_avg:89.14ms +step:323/1680 train_time:28792ms step_avg:89.14ms +step:324/1680 train_time:28880ms step_avg:89.14ms +step:325/1680 train_time:28970ms step_avg:89.14ms +step:326/1680 train_time:29058ms step_avg:89.14ms +step:327/1680 train_time:29147ms step_avg:89.13ms +step:328/1680 train_time:29236ms step_avg:89.13ms +step:329/1680 train_time:29324ms step_avg:89.13ms +step:330/1680 train_time:29413ms step_avg:89.13ms +step:331/1680 train_time:29502ms step_avg:89.13ms +step:332/1680 train_time:29590ms step_avg:89.13ms +step:333/1680 train_time:29679ms step_avg:89.13ms +step:334/1680 train_time:29767ms step_avg:89.12ms +step:335/1680 train_time:29857ms step_avg:89.12ms +step:336/1680 train_time:29946ms step_avg:89.12ms +step:337/1680 train_time:30035ms step_avg:89.12ms +step:338/1680 train_time:30124ms step_avg:89.13ms +step:339/1680 train_time:30213ms step_avg:89.12ms +step:340/1680 train_time:30301ms step_avg:89.12ms +step:341/1680 train_time:30390ms step_avg:89.12ms +step:342/1680 train_time:30479ms step_avg:89.12ms +step:343/1680 train_time:30568ms step_avg:89.12ms +step:344/1680 train_time:30656ms step_avg:89.12ms +step:345/1680 train_time:30746ms step_avg:89.12ms +step:346/1680 train_time:30835ms step_avg:89.12ms +step:347/1680 train_time:30924ms step_avg:89.12ms +step:348/1680 train_time:31014ms step_avg:89.12ms +step:349/1680 train_time:31104ms step_avg:89.12ms +step:350/1680 train_time:31193ms step_avg:89.12ms +step:351/1680 train_time:31281ms step_avg:89.12ms +step:352/1680 train_time:31371ms step_avg:89.12ms +step:353/1680 train_time:31459ms step_avg:89.12ms +step:354/1680 train_time:31548ms step_avg:89.12ms +step:355/1680 train_time:31637ms step_avg:89.12ms +step:356/1680 train_time:31726ms step_avg:89.12ms +step:357/1680 train_time:31816ms step_avg:89.12ms +step:358/1680 
train_time:31904ms step_avg:89.12ms +step:359/1680 train_time:31993ms step_avg:89.12ms +step:360/1680 train_time:32082ms step_avg:89.12ms +step:361/1680 train_time:32172ms step_avg:89.12ms +step:362/1680 train_time:32260ms step_avg:89.12ms +step:363/1680 train_time:32350ms step_avg:89.12ms +step:364/1680 train_time:32438ms step_avg:89.12ms +step:365/1680 train_time:32527ms step_avg:89.12ms +step:366/1680 train_time:32616ms step_avg:89.11ms +step:367/1680 train_time:32704ms step_avg:89.11ms +step:368/1680 train_time:32794ms step_avg:89.11ms +step:369/1680 train_time:32883ms step_avg:89.11ms +step:370/1680 train_time:32972ms step_avg:89.11ms +step:371/1680 train_time:33060ms step_avg:89.11ms +step:372/1680 train_time:33149ms step_avg:89.11ms +step:373/1680 train_time:33238ms step_avg:89.11ms +step:374/1680 train_time:33327ms step_avg:89.11ms +step:375/1680 train_time:33418ms step_avg:89.11ms +step:375/1680 val_loss:3.8208 train_time:33508ms step_avg:89.35ms +step:376/1680 train_time:33530ms step_avg:89.18ms +step:377/1680 train_time:33600ms step_avg:89.12ms +step:378/1680 train_time:33697ms step_avg:89.14ms +step:379/1680 train_time:33788ms step_avg:89.15ms +step:380/1680 train_time:33877ms step_avg:89.15ms +step:381/1680 train_time:33965ms step_avg:89.15ms +step:382/1680 train_time:34053ms step_avg:89.14ms +step:383/1680 train_time:34141ms step_avg:89.14ms +step:384/1680 train_time:34229ms step_avg:89.14ms +step:385/1680 train_time:34316ms step_avg:89.13ms +step:386/1680 train_time:34404ms step_avg:89.13ms +step:387/1680 train_time:34493ms step_avg:89.13ms +step:388/1680 train_time:34584ms step_avg:89.13ms +step:389/1680 train_time:34678ms step_avg:89.15ms +step:390/1680 train_time:34769ms step_avg:89.15ms +step:391/1680 train_time:34857ms step_avg:89.15ms +step:392/1680 train_time:34946ms step_avg:89.15ms +step:393/1680 train_time:35034ms step_avg:89.15ms +step:394/1680 train_time:35123ms step_avg:89.14ms +step:395/1680 train_time:35211ms step_avg:89.14ms +step:396/1680 train_time:35300ms step_avg:89.14ms +step:397/1680 train_time:35388ms step_avg:89.14ms +step:398/1680 train_time:35476ms step_avg:89.14ms +step:399/1680 train_time:35566ms step_avg:89.14ms +step:400/1680 train_time:35657ms step_avg:89.14ms +step:401/1680 train_time:35747ms step_avg:89.14ms +step:402/1680 train_time:35837ms step_avg:89.15ms +step:403/1680 train_time:35926ms step_avg:89.15ms +step:404/1680 train_time:36014ms step_avg:89.14ms +step:405/1680 train_time:36103ms step_avg:89.14ms +step:406/1680 train_time:36191ms step_avg:89.14ms +step:407/1680 train_time:36279ms step_avg:89.14ms +step:408/1680 train_time:36367ms step_avg:89.14ms +step:409/1680 train_time:36456ms step_avg:89.13ms +step:410/1680 train_time:36546ms step_avg:89.14ms +step:411/1680 train_time:36636ms step_avg:89.14ms +step:412/1680 train_time:36726ms step_avg:89.14ms +step:413/1680 train_time:36816ms step_avg:89.14ms +step:414/1680 train_time:36907ms step_avg:89.15ms +step:415/1680 train_time:36996ms step_avg:89.15ms +step:416/1680 train_time:37085ms step_avg:89.15ms +step:417/1680 train_time:37174ms step_avg:89.15ms +step:418/1680 train_time:37262ms step_avg:89.14ms +step:419/1680 train_time:37350ms step_avg:89.14ms +step:420/1680 train_time:37440ms step_avg:89.14ms +step:421/1680 train_time:37528ms step_avg:89.14ms +step:422/1680 train_time:37617ms step_avg:89.14ms +step:423/1680 train_time:37706ms step_avg:89.14ms +step:424/1680 train_time:37796ms step_avg:89.14ms +step:425/1680 train_time:37885ms step_avg:89.14ms +step:426/1680 train_time:37973ms 
step_avg:89.14ms +step:427/1680 train_time:38062ms step_avg:89.14ms +step:428/1680 train_time:38151ms step_avg:89.14ms +step:429/1680 train_time:38240ms step_avg:89.14ms +step:430/1680 train_time:38329ms step_avg:89.14ms +step:431/1680 train_time:38417ms step_avg:89.13ms +step:432/1680 train_time:38506ms step_avg:89.13ms +step:433/1680 train_time:38595ms step_avg:89.13ms +step:434/1680 train_time:38684ms step_avg:89.13ms +step:435/1680 train_time:38774ms step_avg:89.14ms +step:436/1680 train_time:38863ms step_avg:89.14ms +step:437/1680 train_time:38953ms step_avg:89.14ms +step:438/1680 train_time:39041ms step_avg:89.14ms +step:439/1680 train_time:39130ms step_avg:89.14ms +step:440/1680 train_time:39219ms step_avg:89.13ms +step:441/1680 train_time:39308ms step_avg:89.13ms +step:442/1680 train_time:39396ms step_avg:89.13ms +step:443/1680 train_time:39484ms step_avg:89.13ms +step:444/1680 train_time:39574ms step_avg:89.13ms +step:445/1680 train_time:39663ms step_avg:89.13ms +step:446/1680 train_time:39753ms step_avg:89.13ms +step:447/1680 train_time:39842ms step_avg:89.13ms +step:448/1680 train_time:39932ms step_avg:89.13ms +step:449/1680 train_time:40021ms step_avg:89.13ms +step:450/1680 train_time:40110ms step_avg:89.13ms +step:451/1680 train_time:40199ms step_avg:89.13ms +step:452/1680 train_time:40287ms step_avg:89.13ms +step:453/1680 train_time:40376ms step_avg:89.13ms +step:454/1680 train_time:40465ms step_avg:89.13ms +step:455/1680 train_time:40553ms step_avg:89.13ms +step:456/1680 train_time:40644ms step_avg:89.13ms +step:457/1680 train_time:40734ms step_avg:89.13ms +step:458/1680 train_time:40823ms step_avg:89.13ms +step:459/1680 train_time:40912ms step_avg:89.13ms +step:460/1680 train_time:41002ms step_avg:89.13ms +step:461/1680 train_time:41090ms step_avg:89.13ms +step:462/1680 train_time:41179ms step_avg:89.13ms +step:463/1680 train_time:41268ms step_avg:89.13ms +step:464/1680 train_time:41357ms step_avg:89.13ms +step:465/1680 train_time:41446ms step_avg:89.13ms +step:466/1680 train_time:41535ms step_avg:89.13ms +step:467/1680 train_time:41623ms step_avg:89.13ms +step:468/1680 train_time:41712ms step_avg:89.13ms +step:469/1680 train_time:41802ms step_avg:89.13ms +step:470/1680 train_time:41892ms step_avg:89.13ms +step:471/1680 train_time:41981ms step_avg:89.13ms +step:472/1680 train_time:42071ms step_avg:89.13ms +step:473/1680 train_time:42160ms step_avg:89.13ms +step:474/1680 train_time:42248ms step_avg:89.13ms +step:475/1680 train_time:42337ms step_avg:89.13ms +step:476/1680 train_time:42426ms step_avg:89.13ms +step:477/1680 train_time:42514ms step_avg:89.13ms +step:478/1680 train_time:42603ms step_avg:89.13ms +step:479/1680 train_time:42691ms step_avg:89.13ms +step:480/1680 train_time:42780ms step_avg:89.13ms +step:481/1680 train_time:42869ms step_avg:89.12ms +step:482/1680 train_time:42958ms step_avg:89.12ms +step:483/1680 train_time:43047ms step_avg:89.12ms +step:484/1680 train_time:43136ms step_avg:89.12ms +step:485/1680 train_time:43224ms step_avg:89.12ms +step:486/1680 train_time:43314ms step_avg:89.12ms +step:487/1680 train_time:43403ms step_avg:89.12ms +step:488/1680 train_time:43491ms step_avg:89.12ms +step:489/1680 train_time:43580ms step_avg:89.12ms +step:490/1680 train_time:43668ms step_avg:89.12ms +step:491/1680 train_time:43757ms step_avg:89.12ms +step:492/1680 train_time:43846ms step_avg:89.12ms +step:493/1680 train_time:43935ms step_avg:89.12ms +step:494/1680 train_time:44024ms step_avg:89.12ms +step:495/1680 train_time:44114ms step_avg:89.12ms +step:496/1680 
train_time:44203ms step_avg:89.12ms +step:497/1680 train_time:44294ms step_avg:89.12ms +step:498/1680 train_time:44383ms step_avg:89.12ms +step:499/1680 train_time:44473ms step_avg:89.12ms +step:500/1680 train_time:44562ms step_avg:89.12ms +step:500/1680 val_loss:3.7171 train_time:44653ms step_avg:89.31ms +step:501/1680 train_time:44675ms step_avg:89.17ms +step:502/1680 train_time:44746ms step_avg:89.13ms +step:503/1680 train_time:44841ms step_avg:89.15ms +step:504/1680 train_time:44931ms step_avg:89.15ms +step:505/1680 train_time:45019ms step_avg:89.15ms +step:506/1680 train_time:45107ms step_avg:89.15ms +step:507/1680 train_time:45195ms step_avg:89.14ms +step:508/1680 train_time:45283ms step_avg:89.14ms +step:509/1680 train_time:45371ms step_avg:89.14ms +step:510/1680 train_time:45459ms step_avg:89.13ms +step:511/1680 train_time:45547ms step_avg:89.13ms +step:512/1680 train_time:45637ms step_avg:89.13ms +step:513/1680 train_time:45728ms step_avg:89.14ms +step:514/1680 train_time:45818ms step_avg:89.14ms +step:515/1680 train_time:45908ms step_avg:89.14ms +step:516/1680 train_time:45997ms step_avg:89.14ms +step:517/1680 train_time:46086ms step_avg:89.14ms +step:518/1680 train_time:46175ms step_avg:89.14ms +step:519/1680 train_time:46263ms step_avg:89.14ms +step:520/1680 train_time:46352ms step_avg:89.14ms +step:521/1680 train_time:46440ms step_avg:89.14ms +step:522/1680 train_time:46528ms step_avg:89.13ms +step:523/1680 train_time:46616ms step_avg:89.13ms +step:524/1680 train_time:46706ms step_avg:89.13ms +step:525/1680 train_time:46797ms step_avg:89.14ms +step:526/1680 train_time:46887ms step_avg:89.14ms +step:527/1680 train_time:46976ms step_avg:89.14ms +step:528/1680 train_time:47066ms step_avg:89.14ms +step:529/1680 train_time:47155ms step_avg:89.14ms +step:530/1680 train_time:47244ms step_avg:89.14ms +step:531/1680 train_time:47333ms step_avg:89.14ms +step:532/1680 train_time:47421ms step_avg:89.14ms +step:533/1680 train_time:47510ms step_avg:89.14ms +step:534/1680 train_time:47599ms step_avg:89.14ms +step:535/1680 train_time:47688ms step_avg:89.14ms +step:536/1680 train_time:47778ms step_avg:89.14ms +step:537/1680 train_time:47868ms step_avg:89.14ms +step:538/1680 train_time:47957ms step_avg:89.14ms +step:539/1680 train_time:48047ms step_avg:89.14ms +step:540/1680 train_time:48136ms step_avg:89.14ms +step:541/1680 train_time:48225ms step_avg:89.14ms +step:542/1680 train_time:48314ms step_avg:89.14ms +step:543/1680 train_time:48402ms step_avg:89.14ms +step:544/1680 train_time:48491ms step_avg:89.14ms +step:545/1680 train_time:48580ms step_avg:89.14ms +step:546/1680 train_time:48668ms step_avg:89.14ms +step:547/1680 train_time:48758ms step_avg:89.14ms +step:548/1680 train_time:48847ms step_avg:89.14ms +step:549/1680 train_time:48938ms step_avg:89.14ms +step:550/1680 train_time:49029ms step_avg:89.14ms +step:551/1680 train_time:49119ms step_avg:89.14ms +step:552/1680 train_time:49209ms step_avg:89.15ms +step:553/1680 train_time:49299ms step_avg:89.15ms +step:554/1680 train_time:49389ms step_avg:89.15ms +step:555/1680 train_time:49478ms step_avg:89.15ms +step:556/1680 train_time:49569ms step_avg:89.15ms +step:557/1680 train_time:49659ms step_avg:89.15ms +step:558/1680 train_time:49748ms step_avg:89.15ms +step:559/1680 train_time:49839ms step_avg:89.16ms +step:560/1680 train_time:49931ms step_avg:89.16ms +step:561/1680 train_time:50020ms step_avg:89.16ms +step:562/1680 train_time:50110ms step_avg:89.16ms +step:563/1680 train_time:50200ms step_avg:89.17ms +step:564/1680 train_time:50292ms 
step_avg:89.17ms +step:565/1680 train_time:50381ms step_avg:89.17ms +step:566/1680 train_time:50472ms step_avg:89.17ms +step:567/1680 train_time:50562ms step_avg:89.18ms +step:568/1680 train_time:50653ms step_avg:89.18ms +step:569/1680 train_time:50742ms step_avg:89.18ms +step:570/1680 train_time:50834ms step_avg:89.18ms +step:571/1680 train_time:50923ms step_avg:89.18ms +step:572/1680 train_time:51014ms step_avg:89.19ms +step:573/1680 train_time:51104ms step_avg:89.19ms +step:574/1680 train_time:51196ms step_avg:89.19ms +step:575/1680 train_time:51286ms step_avg:89.19ms +step:576/1680 train_time:51376ms step_avg:89.20ms +step:577/1680 train_time:51466ms step_avg:89.20ms +step:578/1680 train_time:51556ms step_avg:89.20ms +step:579/1680 train_time:51646ms step_avg:89.20ms +step:580/1680 train_time:51738ms step_avg:89.20ms +step:581/1680 train_time:51829ms step_avg:89.21ms +step:582/1680 train_time:51919ms step_avg:89.21ms +step:583/1680 train_time:52009ms step_avg:89.21ms +step:584/1680 train_time:52099ms step_avg:89.21ms +step:585/1680 train_time:52189ms step_avg:89.21ms +step:586/1680 train_time:52279ms step_avg:89.21ms +step:587/1680 train_time:52369ms step_avg:89.21ms +step:588/1680 train_time:52459ms step_avg:89.22ms +step:589/1680 train_time:52549ms step_avg:89.22ms +step:590/1680 train_time:52639ms step_avg:89.22ms +step:591/1680 train_time:52730ms step_avg:89.22ms +step:592/1680 train_time:52820ms step_avg:89.22ms +step:593/1680 train_time:52910ms step_avg:89.22ms +step:594/1680 train_time:53000ms step_avg:89.23ms +step:595/1680 train_time:53090ms step_avg:89.23ms +step:596/1680 train_time:53180ms step_avg:89.23ms +step:597/1680 train_time:53270ms step_avg:89.23ms +step:598/1680 train_time:53360ms step_avg:89.23ms +step:599/1680 train_time:53450ms step_avg:89.23ms +step:600/1680 train_time:53540ms step_avg:89.23ms +step:601/1680 train_time:53631ms step_avg:89.24ms +step:602/1680 train_time:53721ms step_avg:89.24ms +step:603/1680 train_time:53811ms step_avg:89.24ms +step:604/1680 train_time:53901ms step_avg:89.24ms +step:605/1680 train_time:53991ms step_avg:89.24ms +step:606/1680 train_time:54082ms step_avg:89.24ms +step:607/1680 train_time:54171ms step_avg:89.24ms +step:608/1680 train_time:54262ms step_avg:89.25ms +step:609/1680 train_time:54351ms step_avg:89.25ms +step:610/1680 train_time:54441ms step_avg:89.25ms +step:611/1680 train_time:54531ms step_avg:89.25ms +step:612/1680 train_time:54621ms step_avg:89.25ms +step:613/1680 train_time:54712ms step_avg:89.25ms +step:614/1680 train_time:54803ms step_avg:89.26ms +step:615/1680 train_time:54893ms step_avg:89.26ms +step:616/1680 train_time:54983ms step_avg:89.26ms +step:617/1680 train_time:55073ms step_avg:89.26ms +step:618/1680 train_time:55163ms step_avg:89.26ms +step:619/1680 train_time:55254ms step_avg:89.26ms +step:620/1680 train_time:55344ms step_avg:89.26ms +step:621/1680 train_time:55434ms step_avg:89.27ms +step:622/1680 train_time:55523ms step_avg:89.27ms +step:623/1680 train_time:55613ms step_avg:89.27ms +step:624/1680 train_time:55704ms step_avg:89.27ms +step:625/1680 train_time:55795ms step_avg:89.27ms +step:625/1680 val_loss:3.6153 train_time:55887ms step_avg:89.42ms +step:626/1680 train_time:55910ms step_avg:89.31ms +step:627/1680 train_time:55980ms step_avg:89.28ms +step:628/1680 train_time:56079ms step_avg:89.30ms +step:629/1680 train_time:56171ms step_avg:89.30ms +step:630/1680 train_time:56260ms step_avg:89.30ms +step:631/1680 train_time:56348ms step_avg:89.30ms +step:632/1680 train_time:56437ms step_avg:89.30ms 
+step:633/1680 train_time:56526ms step_avg:89.30ms +step:634/1680 train_time:56615ms step_avg:89.30ms +step:635/1680 train_time:56704ms step_avg:89.30ms +step:636/1680 train_time:56793ms step_avg:89.30ms +step:637/1680 train_time:56884ms step_avg:89.30ms +step:638/1680 train_time:56977ms step_avg:89.31ms +step:639/1680 train_time:57071ms step_avg:89.31ms +step:640/1680 train_time:57163ms step_avg:89.32ms +step:641/1680 train_time:57252ms step_avg:89.32ms +step:642/1680 train_time:57341ms step_avg:89.32ms +step:643/1680 train_time:57430ms step_avg:89.32ms +step:644/1680 train_time:57520ms step_avg:89.32ms +step:645/1680 train_time:57609ms step_avg:89.32ms +step:646/1680 train_time:57699ms step_avg:89.32ms +step:647/1680 train_time:57788ms step_avg:89.32ms +step:648/1680 train_time:57879ms step_avg:89.32ms +step:649/1680 train_time:57970ms step_avg:89.32ms +step:650/1680 train_time:58061ms step_avg:89.33ms +step:651/1680 train_time:58153ms step_avg:89.33ms +step:652/1680 train_time:58244ms step_avg:89.33ms +step:653/1680 train_time:58334ms step_avg:89.33ms +step:654/1680 train_time:58424ms step_avg:89.33ms +step:655/1680 train_time:58513ms step_avg:89.33ms +step:656/1680 train_time:58603ms step_avg:89.33ms +step:657/1680 train_time:58692ms step_avg:89.33ms +step:658/1680 train_time:58782ms step_avg:89.33ms +step:659/1680 train_time:58872ms step_avg:89.34ms +step:660/1680 train_time:58963ms step_avg:89.34ms +step:661/1680 train_time:59054ms step_avg:89.34ms +step:662/1680 train_time:59145ms step_avg:89.34ms +step:663/1680 train_time:59236ms step_avg:89.34ms +step:664/1680 train_time:59326ms step_avg:89.35ms +step:665/1680 train_time:59416ms step_avg:89.35ms +step:666/1680 train_time:59507ms step_avg:89.35ms +step:667/1680 train_time:59597ms step_avg:89.35ms +step:668/1680 train_time:59687ms step_avg:89.35ms +step:669/1680 train_time:59777ms step_avg:89.35ms +step:670/1680 train_time:59867ms step_avg:89.35ms +step:671/1680 train_time:59957ms step_avg:89.36ms +step:672/1680 train_time:60048ms step_avg:89.36ms +step:673/1680 train_time:60139ms step_avg:89.36ms +step:674/1680 train_time:60229ms step_avg:89.36ms +step:675/1680 train_time:60320ms step_avg:89.36ms +step:676/1680 train_time:60409ms step_avg:89.36ms +step:677/1680 train_time:60500ms step_avg:89.37ms +step:678/1680 train_time:60590ms step_avg:89.37ms +step:679/1680 train_time:60680ms step_avg:89.37ms +step:680/1680 train_time:60770ms step_avg:89.37ms +step:681/1680 train_time:60860ms step_avg:89.37ms +step:682/1680 train_time:60950ms step_avg:89.37ms +step:683/1680 train_time:61041ms step_avg:89.37ms +step:684/1680 train_time:61131ms step_avg:89.37ms +step:685/1680 train_time:61223ms step_avg:89.38ms +step:686/1680 train_time:61313ms step_avg:89.38ms +step:687/1680 train_time:61403ms step_avg:89.38ms +step:688/1680 train_time:61492ms step_avg:89.38ms +step:689/1680 train_time:61581ms step_avg:89.38ms +step:690/1680 train_time:61672ms step_avg:89.38ms +step:691/1680 train_time:61762ms step_avg:89.38ms +step:692/1680 train_time:61852ms step_avg:89.38ms +step:693/1680 train_time:61942ms step_avg:89.38ms +step:694/1680 train_time:62033ms step_avg:89.39ms +step:695/1680 train_time:62125ms step_avg:89.39ms +step:696/1680 train_time:62215ms step_avg:89.39ms +step:697/1680 train_time:62305ms step_avg:89.39ms +step:698/1680 train_time:62396ms step_avg:89.39ms +step:699/1680 train_time:62487ms step_avg:89.39ms +step:700/1680 train_time:62576ms step_avg:89.39ms +step:701/1680 train_time:62666ms step_avg:89.40ms +step:702/1680 train_time:62757ms 
step_avg:89.40ms +step:703/1680 train_time:62847ms step_avg:89.40ms +step:704/1680 train_time:62937ms step_avg:89.40ms +step:705/1680 train_time:63028ms step_avg:89.40ms +step:706/1680 train_time:63118ms step_avg:89.40ms +step:707/1680 train_time:63208ms step_avg:89.40ms +step:708/1680 train_time:63299ms step_avg:89.41ms +step:709/1680 train_time:63389ms step_avg:89.41ms +step:710/1680 train_time:63480ms step_avg:89.41ms +step:711/1680 train_time:63569ms step_avg:89.41ms +step:712/1680 train_time:63660ms step_avg:89.41ms +step:713/1680 train_time:63750ms step_avg:89.41ms +step:714/1680 train_time:63839ms step_avg:89.41ms +step:715/1680 train_time:63930ms step_avg:89.41ms +step:716/1680 train_time:64020ms step_avg:89.41ms +step:717/1680 train_time:64111ms step_avg:89.42ms +step:718/1680 train_time:64201ms step_avg:89.42ms +step:719/1680 train_time:64292ms step_avg:89.42ms +step:720/1680 train_time:64383ms step_avg:89.42ms +step:721/1680 train_time:64473ms step_avg:89.42ms +step:722/1680 train_time:64564ms step_avg:89.42ms +step:723/1680 train_time:64653ms step_avg:89.42ms +step:724/1680 train_time:64743ms step_avg:89.42ms +step:725/1680 train_time:64833ms step_avg:89.43ms +step:726/1680 train_time:64923ms step_avg:89.42ms +step:727/1680 train_time:65012ms step_avg:89.43ms +step:728/1680 train_time:65103ms step_avg:89.43ms +step:729/1680 train_time:65193ms step_avg:89.43ms +step:730/1680 train_time:65285ms step_avg:89.43ms +step:731/1680 train_time:65375ms step_avg:89.43ms +step:732/1680 train_time:65465ms step_avg:89.43ms +step:733/1680 train_time:65556ms step_avg:89.43ms +step:734/1680 train_time:65646ms step_avg:89.44ms +step:735/1680 train_time:65736ms step_avg:89.44ms +step:736/1680 train_time:65827ms step_avg:89.44ms +step:737/1680 train_time:65917ms step_avg:89.44ms +step:738/1680 train_time:66007ms step_avg:89.44ms +step:739/1680 train_time:66098ms step_avg:89.44ms +step:740/1680 train_time:66189ms step_avg:89.44ms +step:741/1680 train_time:66280ms step_avg:89.45ms +step:742/1680 train_time:66369ms step_avg:89.45ms +step:743/1680 train_time:66460ms step_avg:89.45ms +step:744/1680 train_time:66549ms step_avg:89.45ms +step:745/1680 train_time:66640ms step_avg:89.45ms +step:746/1680 train_time:66730ms step_avg:89.45ms +step:747/1680 train_time:66821ms step_avg:89.45ms +step:748/1680 train_time:66912ms step_avg:89.45ms +step:749/1680 train_time:67001ms step_avg:89.45ms +step:750/1680 train_time:67092ms step_avg:89.46ms +step:750/1680 val_loss:3.5667 train_time:67184ms step_avg:89.58ms +step:751/1680 train_time:67208ms step_avg:89.49ms +step:752/1680 train_time:67279ms step_avg:89.47ms +step:753/1680 train_time:67377ms step_avg:89.48ms +step:754/1680 train_time:67470ms step_avg:89.48ms +step:755/1680 train_time:67560ms step_avg:89.48ms +step:756/1680 train_time:67649ms step_avg:89.48ms +step:757/1680 train_time:67739ms step_avg:89.48ms +step:758/1680 train_time:67828ms step_avg:89.48ms +step:759/1680 train_time:67917ms step_avg:89.48ms +step:760/1680 train_time:68006ms step_avg:89.48ms +step:761/1680 train_time:68095ms step_avg:89.48ms +step:762/1680 train_time:68184ms step_avg:89.48ms +step:763/1680 train_time:68277ms step_avg:89.48ms +step:764/1680 train_time:68370ms step_avg:89.49ms +step:765/1680 train_time:68462ms step_avg:89.49ms +step:766/1680 train_time:68554ms step_avg:89.50ms +step:767/1680 train_time:68644ms step_avg:89.50ms +step:768/1680 train_time:68734ms step_avg:89.50ms +step:769/1680 train_time:68824ms step_avg:89.50ms +step:770/1680 train_time:68914ms step_avg:89.50ms 
+step:771/1680 train_time:69003ms step_avg:89.50ms +step:772/1680 train_time:69092ms step_avg:89.50ms +step:773/1680 train_time:69182ms step_avg:89.50ms +step:774/1680 train_time:69274ms step_avg:89.50ms +step:775/1680 train_time:69366ms step_avg:89.50ms +step:776/1680 train_time:69458ms step_avg:89.51ms +step:777/1680 train_time:69549ms step_avg:89.51ms +step:778/1680 train_time:69639ms step_avg:89.51ms +step:779/1680 train_time:69729ms step_avg:89.51ms +step:780/1680 train_time:69819ms step_avg:89.51ms +step:781/1680 train_time:69910ms step_avg:89.51ms +step:782/1680 train_time:69999ms step_avg:89.51ms +step:783/1680 train_time:70088ms step_avg:89.51ms +step:784/1680 train_time:70177ms step_avg:89.51ms +step:785/1680 train_time:70268ms step_avg:89.51ms +step:786/1680 train_time:70358ms step_avg:89.51ms +step:787/1680 train_time:70449ms step_avg:89.52ms +step:788/1680 train_time:70539ms step_avg:89.52ms +step:789/1680 train_time:70629ms step_avg:89.52ms +step:790/1680 train_time:70719ms step_avg:89.52ms +step:791/1680 train_time:70809ms step_avg:89.52ms +step:792/1680 train_time:70898ms step_avg:89.52ms +step:793/1680 train_time:70988ms step_avg:89.52ms +step:794/1680 train_time:71078ms step_avg:89.52ms +step:795/1680 train_time:71168ms step_avg:89.52ms +step:796/1680 train_time:71257ms step_avg:89.52ms +step:797/1680 train_time:71348ms step_avg:89.52ms +step:798/1680 train_time:71438ms step_avg:89.52ms +step:799/1680 train_time:71529ms step_avg:89.52ms +step:800/1680 train_time:71619ms step_avg:89.52ms +step:801/1680 train_time:71709ms step_avg:89.52ms +step:802/1680 train_time:71799ms step_avg:89.53ms +step:803/1680 train_time:71889ms step_avg:89.53ms +step:804/1680 train_time:71978ms step_avg:89.52ms +step:805/1680 train_time:72068ms step_avg:89.53ms +step:806/1680 train_time:72158ms step_avg:89.53ms +step:807/1680 train_time:72248ms step_avg:89.53ms +step:808/1680 train_time:72338ms step_avg:89.53ms +step:809/1680 train_time:72428ms step_avg:89.53ms +step:810/1680 train_time:72518ms step_avg:89.53ms +step:811/1680 train_time:72608ms step_avg:89.53ms +step:812/1680 train_time:72698ms step_avg:89.53ms +step:813/1680 train_time:72788ms step_avg:89.53ms +step:814/1680 train_time:72878ms step_avg:89.53ms +step:815/1680 train_time:72969ms step_avg:89.53ms +step:816/1680 train_time:73059ms step_avg:89.53ms +step:817/1680 train_time:73149ms step_avg:89.53ms +step:818/1680 train_time:73239ms step_avg:89.53ms +step:819/1680 train_time:73330ms step_avg:89.54ms +step:820/1680 train_time:73419ms step_avg:89.54ms +step:821/1680 train_time:73512ms step_avg:89.54ms +step:822/1680 train_time:73602ms step_avg:89.54ms +step:823/1680 train_time:73692ms step_avg:89.54ms +step:824/1680 train_time:73782ms step_avg:89.54ms +step:825/1680 train_time:73873ms step_avg:89.54ms +step:826/1680 train_time:73963ms step_avg:89.54ms +step:827/1680 train_time:74054ms step_avg:89.55ms +step:828/1680 train_time:74144ms step_avg:89.55ms +step:829/1680 train_time:74235ms step_avg:89.55ms +step:830/1680 train_time:74325ms step_avg:89.55ms +step:831/1680 train_time:74415ms step_avg:89.55ms +step:832/1680 train_time:74505ms step_avg:89.55ms +step:833/1680 train_time:74595ms step_avg:89.55ms +step:834/1680 train_time:74686ms step_avg:89.55ms +step:835/1680 train_time:74776ms step_avg:89.55ms +step:836/1680 train_time:74867ms step_avg:89.55ms +step:837/1680 train_time:74957ms step_avg:89.55ms +step:838/1680 train_time:75047ms step_avg:89.55ms +step:839/1680 train_time:75137ms step_avg:89.56ms +step:840/1680 train_time:75227ms 
step_avg:89.56ms +step:841/1680 train_time:75317ms step_avg:89.56ms +step:842/1680 train_time:75407ms step_avg:89.56ms +step:843/1680 train_time:75497ms step_avg:89.56ms +step:844/1680 train_time:75587ms step_avg:89.56ms +step:845/1680 train_time:75677ms step_avg:89.56ms +step:846/1680 train_time:75768ms step_avg:89.56ms +step:847/1680 train_time:75859ms step_avg:89.56ms +step:848/1680 train_time:75950ms step_avg:89.56ms +step:849/1680 train_time:76040ms step_avg:89.56ms +step:850/1680 train_time:76131ms step_avg:89.57ms +step:851/1680 train_time:76221ms step_avg:89.57ms +step:852/1680 train_time:76311ms step_avg:89.57ms +step:853/1680 train_time:76401ms step_avg:89.57ms +step:854/1680 train_time:76490ms step_avg:89.57ms +step:855/1680 train_time:76581ms step_avg:89.57ms +step:856/1680 train_time:76670ms step_avg:89.57ms +step:857/1680 train_time:76761ms step_avg:89.57ms +step:858/1680 train_time:76851ms step_avg:89.57ms +step:859/1680 train_time:76941ms step_avg:89.57ms +step:860/1680 train_time:77033ms step_avg:89.57ms +step:861/1680 train_time:77123ms step_avg:89.57ms +step:862/1680 train_time:77213ms step_avg:89.57ms +step:863/1680 train_time:77303ms step_avg:89.57ms +step:864/1680 train_time:77393ms step_avg:89.58ms +step:865/1680 train_time:77484ms step_avg:89.58ms +step:866/1680 train_time:77574ms step_avg:89.58ms +step:867/1680 train_time:77664ms step_avg:89.58ms +step:868/1680 train_time:77755ms step_avg:89.58ms +step:869/1680 train_time:77845ms step_avg:89.58ms +step:870/1680 train_time:77936ms step_avg:89.58ms +step:871/1680 train_time:78026ms step_avg:89.58ms +step:872/1680 train_time:78116ms step_avg:89.58ms +step:873/1680 train_time:78207ms step_avg:89.58ms +step:874/1680 train_time:78297ms step_avg:89.58ms +step:875/1680 train_time:78387ms step_avg:89.58ms +step:875/1680 val_loss:3.5187 train_time:78479ms step_avg:89.69ms +step:876/1680 train_time:78501ms step_avg:89.61ms +step:877/1680 train_time:78572ms step_avg:89.59ms +step:878/1680 train_time:78670ms step_avg:89.60ms +step:879/1680 train_time:78761ms step_avg:89.60ms +step:880/1680 train_time:78850ms step_avg:89.60ms +step:881/1680 train_time:78939ms step_avg:89.60ms +step:882/1680 train_time:79028ms step_avg:89.60ms +step:883/1680 train_time:79118ms step_avg:89.60ms +step:884/1680 train_time:79207ms step_avg:89.60ms +step:885/1680 train_time:79296ms step_avg:89.60ms +step:886/1680 train_time:79386ms step_avg:89.60ms +step:887/1680 train_time:79477ms step_avg:89.60ms +step:888/1680 train_time:79569ms step_avg:89.60ms +step:889/1680 train_time:79662ms step_avg:89.61ms +step:890/1680 train_time:79753ms step_avg:89.61ms +step:891/1680 train_time:79843ms step_avg:89.61ms +step:892/1680 train_time:79933ms step_avg:89.61ms +step:893/1680 train_time:80023ms step_avg:89.61ms +step:894/1680 train_time:80115ms step_avg:89.61ms +step:895/1680 train_time:80204ms step_avg:89.61ms +step:896/1680 train_time:80293ms step_avg:89.61ms +step:897/1680 train_time:80383ms step_avg:89.61ms +step:898/1680 train_time:80473ms step_avg:89.61ms +step:899/1680 train_time:80564ms step_avg:89.62ms +step:900/1680 train_time:80655ms step_avg:89.62ms +step:901/1680 train_time:80746ms step_avg:89.62ms +step:902/1680 train_time:80837ms step_avg:89.62ms +step:903/1680 train_time:80927ms step_avg:89.62ms +step:904/1680 train_time:81017ms step_avg:89.62ms +step:905/1680 train_time:81108ms step_avg:89.62ms +step:906/1680 train_time:81197ms step_avg:89.62ms +step:907/1680 train_time:81287ms step_avg:89.62ms +step:908/1680 train_time:81378ms step_avg:89.62ms 
+step:909/1680 train_time:81468ms step_avg:89.62ms +step:910/1680 train_time:81559ms step_avg:89.62ms +step:911/1680 train_time:81650ms step_avg:89.63ms +step:912/1680 train_time:81741ms step_avg:89.63ms +step:913/1680 train_time:81831ms step_avg:89.63ms +step:914/1680 train_time:81921ms step_avg:89.63ms +step:915/1680 train_time:82012ms step_avg:89.63ms +step:916/1680 train_time:82102ms step_avg:89.63ms +step:917/1680 train_time:82192ms step_avg:89.63ms +step:918/1680 train_time:82281ms step_avg:89.63ms +step:919/1680 train_time:82371ms step_avg:89.63ms +step:920/1680 train_time:82461ms step_avg:89.63ms +step:921/1680 train_time:82551ms step_avg:89.63ms +step:922/1680 train_time:82641ms step_avg:89.63ms +step:923/1680 train_time:82731ms step_avg:89.63ms +step:924/1680 train_time:82822ms step_avg:89.63ms +step:925/1680 train_time:82911ms step_avg:89.63ms +step:926/1680 train_time:83002ms step_avg:89.63ms +step:927/1680 train_time:83092ms step_avg:89.64ms +step:928/1680 train_time:83183ms step_avg:89.64ms +step:929/1680 train_time:83274ms step_avg:89.64ms +step:930/1680 train_time:83364ms step_avg:89.64ms +step:931/1680 train_time:83454ms step_avg:89.64ms +step:932/1680 train_time:83544ms step_avg:89.64ms +step:933/1680 train_time:83634ms step_avg:89.64ms +step:934/1680 train_time:83724ms step_avg:89.64ms +step:935/1680 train_time:83816ms step_avg:89.64ms +step:936/1680 train_time:83905ms step_avg:89.64ms +step:937/1680 train_time:83996ms step_avg:89.64ms +step:938/1680 train_time:84086ms step_avg:89.64ms +step:939/1680 train_time:84176ms step_avg:89.64ms +step:940/1680 train_time:84266ms step_avg:89.64ms +step:941/1680 train_time:84355ms step_avg:89.64ms +step:942/1680 train_time:84445ms step_avg:89.64ms +step:943/1680 train_time:84536ms step_avg:89.65ms +step:944/1680 train_time:84626ms step_avg:89.65ms +step:945/1680 train_time:84717ms step_avg:89.65ms +step:946/1680 train_time:84807ms step_avg:89.65ms +step:947/1680 train_time:84898ms step_avg:89.65ms +step:948/1680 train_time:84989ms step_avg:89.65ms +step:949/1680 train_time:85080ms step_avg:89.65ms +step:950/1680 train_time:85170ms step_avg:89.65ms +step:951/1680 train_time:85260ms step_avg:89.65ms +step:952/1680 train_time:85350ms step_avg:89.65ms +step:953/1680 train_time:85440ms step_avg:89.65ms +step:954/1680 train_time:85529ms step_avg:89.65ms +step:955/1680 train_time:85620ms step_avg:89.65ms +step:956/1680 train_time:85710ms step_avg:89.65ms +step:957/1680 train_time:85799ms step_avg:89.65ms +step:958/1680 train_time:85890ms step_avg:89.66ms +step:959/1680 train_time:85980ms step_avg:89.66ms +step:960/1680 train_time:86070ms step_avg:89.66ms +step:961/1680 train_time:86160ms step_avg:89.66ms +step:962/1680 train_time:86250ms step_avg:89.66ms +step:963/1680 train_time:86340ms step_avg:89.66ms +step:964/1680 train_time:86430ms step_avg:89.66ms +step:965/1680 train_time:86520ms step_avg:89.66ms +step:966/1680 train_time:86610ms step_avg:89.66ms +step:967/1680 train_time:86700ms step_avg:89.66ms +step:968/1680 train_time:86791ms step_avg:89.66ms +step:969/1680 train_time:86881ms step_avg:89.66ms +step:970/1680 train_time:86970ms step_avg:89.66ms +step:971/1680 train_time:87060ms step_avg:89.66ms +step:972/1680 train_time:87150ms step_avg:89.66ms +step:973/1680 train_time:87240ms step_avg:89.66ms +step:974/1680 train_time:87330ms step_avg:89.66ms +step:975/1680 train_time:87420ms step_avg:89.66ms +step:976/1680 train_time:87509ms step_avg:89.66ms +step:977/1680 train_time:87599ms step_avg:89.66ms +step:978/1680 train_time:87689ms 
step_avg:89.66ms +step:979/1680 train_time:87779ms step_avg:89.66ms +step:980/1680 train_time:87869ms step_avg:89.66ms +step:981/1680 train_time:87959ms step_avg:89.66ms +step:982/1680 train_time:88050ms step_avg:89.66ms +step:983/1680 train_time:88140ms step_avg:89.66ms +step:984/1680 train_time:88231ms step_avg:89.67ms +step:985/1680 train_time:88322ms step_avg:89.67ms +step:986/1680 train_time:88411ms step_avg:89.67ms +step:987/1680 train_time:88500ms step_avg:89.67ms +step:988/1680 train_time:88591ms step_avg:89.67ms +step:989/1680 train_time:88682ms step_avg:89.67ms +step:990/1680 train_time:88773ms step_avg:89.67ms +step:991/1680 train_time:88863ms step_avg:89.67ms +step:992/1680 train_time:88954ms step_avg:89.67ms +step:993/1680 train_time:89044ms step_avg:89.67ms +step:994/1680 train_time:89135ms step_avg:89.67ms +step:995/1680 train_time:89225ms step_avg:89.67ms +step:996/1680 train_time:89316ms step_avg:89.68ms +step:997/1680 train_time:89407ms step_avg:89.68ms +step:998/1680 train_time:89497ms step_avg:89.68ms +step:999/1680 train_time:89588ms step_avg:89.68ms +step:1000/1680 train_time:89678ms step_avg:89.68ms +step:1000/1680 val_loss:3.4681 train_time:89770ms step_avg:89.77ms +step:1001/1680 train_time:89793ms step_avg:89.70ms +step:1002/1680 train_time:89865ms step_avg:89.69ms +step:1003/1680 train_time:89962ms step_avg:89.69ms +step:1004/1680 train_time:90054ms step_avg:89.70ms +step:1005/1680 train_time:90143ms step_avg:89.69ms +step:1006/1680 train_time:90233ms step_avg:89.69ms +step:1007/1680 train_time:90322ms step_avg:89.69ms +step:1008/1680 train_time:90411ms step_avg:89.69ms +step:1009/1680 train_time:90499ms step_avg:89.69ms +step:1010/1680 train_time:90588ms step_avg:89.69ms +step:1011/1680 train_time:90678ms step_avg:89.69ms +step:1012/1680 train_time:90767ms step_avg:89.69ms +step:1013/1680 train_time:90860ms step_avg:89.69ms +step:1014/1680 train_time:90954ms step_avg:89.70ms +step:1015/1680 train_time:91045ms step_avg:89.70ms +step:1016/1680 train_time:91136ms step_avg:89.70ms +step:1017/1680 train_time:91226ms step_avg:89.70ms +step:1018/1680 train_time:91315ms step_avg:89.70ms +step:1019/1680 train_time:91405ms step_avg:89.70ms +step:1020/1680 train_time:91494ms step_avg:89.70ms +step:1021/1680 train_time:91583ms step_avg:89.70ms +step:1022/1680 train_time:91673ms step_avg:89.70ms +step:1023/1680 train_time:91763ms step_avg:89.70ms +step:1024/1680 train_time:91855ms step_avg:89.70ms +step:1025/1680 train_time:91947ms step_avg:89.70ms +step:1026/1680 train_time:92037ms step_avg:89.70ms +step:1027/1680 train_time:92128ms step_avg:89.71ms +step:1028/1680 train_time:92219ms step_avg:89.71ms +step:1029/1680 train_time:92309ms step_avg:89.71ms +step:1030/1680 train_time:92399ms step_avg:89.71ms +step:1031/1680 train_time:92489ms step_avg:89.71ms +step:1032/1680 train_time:92578ms step_avg:89.71ms +step:1033/1680 train_time:92668ms step_avg:89.71ms +step:1034/1680 train_time:92759ms step_avg:89.71ms +step:1035/1680 train_time:92850ms step_avg:89.71ms +step:1036/1680 train_time:92942ms step_avg:89.71ms +step:1037/1680 train_time:93034ms step_avg:89.71ms +step:1038/1680 train_time:93125ms step_avg:89.72ms +step:1039/1680 train_time:93215ms step_avg:89.72ms +step:1040/1680 train_time:93304ms step_avg:89.72ms +step:1041/1680 train_time:93394ms step_avg:89.72ms +step:1042/1680 train_time:93483ms step_avg:89.72ms +step:1043/1680 train_time:93573ms step_avg:89.71ms +step:1044/1680 train_time:93663ms step_avg:89.72ms +step:1045/1680 train_time:93753ms step_avg:89.72ms 
+step:1046/1680 train_time:93844ms step_avg:89.72ms +step:1047/1680 train_time:93935ms step_avg:89.72ms +step:1048/1680 train_time:94025ms step_avg:89.72ms +step:1049/1680 train_time:94117ms step_avg:89.72ms +step:1050/1680 train_time:94207ms step_avg:89.72ms +step:1051/1680 train_time:94298ms step_avg:89.72ms +step:1052/1680 train_time:94388ms step_avg:89.72ms +step:1053/1680 train_time:94477ms step_avg:89.72ms +step:1054/1680 train_time:94567ms step_avg:89.72ms +step:1055/1680 train_time:94656ms step_avg:89.72ms +step:1056/1680 train_time:94746ms step_avg:89.72ms +step:1057/1680 train_time:94837ms step_avg:89.72ms +step:1058/1680 train_time:94927ms step_avg:89.72ms +step:1059/1680 train_time:95019ms step_avg:89.73ms +step:1060/1680 train_time:95109ms step_avg:89.73ms +step:1061/1680 train_time:95200ms step_avg:89.73ms +step:1062/1680 train_time:95290ms step_avg:89.73ms +step:1063/1680 train_time:95380ms step_avg:89.73ms +step:1064/1680 train_time:95469ms step_avg:89.73ms +step:1065/1680 train_time:95559ms step_avg:89.73ms +step:1066/1680 train_time:95649ms step_avg:89.73ms +step:1067/1680 train_time:95739ms step_avg:89.73ms +step:1068/1680 train_time:95830ms step_avg:89.73ms +step:1069/1680 train_time:95922ms step_avg:89.73ms +step:1070/1680 train_time:96012ms step_avg:89.73ms +step:1071/1680 train_time:96102ms step_avg:89.73ms +step:1072/1680 train_time:96193ms step_avg:89.73ms +step:1073/1680 train_time:96283ms step_avg:89.73ms +step:1074/1680 train_time:96373ms step_avg:89.73ms +step:1075/1680 train_time:96463ms step_avg:89.73ms +step:1076/1680 train_time:96553ms step_avg:89.73ms +step:1077/1680 train_time:96642ms step_avg:89.73ms +step:1078/1680 train_time:96733ms step_avg:89.73ms +step:1079/1680 train_time:96823ms step_avg:89.73ms +step:1080/1680 train_time:96913ms step_avg:89.73ms +step:1081/1680 train_time:97003ms step_avg:89.73ms +step:1082/1680 train_time:97095ms step_avg:89.74ms +step:1083/1680 train_time:97184ms step_avg:89.74ms +step:1084/1680 train_time:97275ms step_avg:89.74ms +step:1085/1680 train_time:97365ms step_avg:89.74ms +step:1086/1680 train_time:97455ms step_avg:89.74ms +step:1087/1680 train_time:97545ms step_avg:89.74ms +step:1088/1680 train_time:97636ms step_avg:89.74ms +step:1089/1680 train_time:97726ms step_avg:89.74ms +step:1090/1680 train_time:97816ms step_avg:89.74ms +step:1091/1680 train_time:97906ms step_avg:89.74ms +step:1092/1680 train_time:97996ms step_avg:89.74ms +step:1093/1680 train_time:98086ms step_avg:89.74ms +step:1094/1680 train_time:98175ms step_avg:89.74ms +step:1095/1680 train_time:98266ms step_avg:89.74ms +step:1096/1680 train_time:98357ms step_avg:89.74ms +step:1097/1680 train_time:98448ms step_avg:89.74ms +step:1098/1680 train_time:98538ms step_avg:89.74ms +step:1099/1680 train_time:98630ms step_avg:89.75ms +step:1100/1680 train_time:98722ms step_avg:89.75ms +step:1101/1680 train_time:98813ms step_avg:89.75ms +step:1102/1680 train_time:98903ms step_avg:89.75ms +step:1103/1680 train_time:98994ms step_avg:89.75ms +step:1104/1680 train_time:99084ms step_avg:89.75ms +step:1105/1680 train_time:99175ms step_avg:89.75ms +step:1106/1680 train_time:99266ms step_avg:89.75ms +step:1107/1680 train_time:99357ms step_avg:89.75ms +step:1108/1680 train_time:99447ms step_avg:89.75ms +step:1109/1680 train_time:99538ms step_avg:89.75ms +step:1110/1680 train_time:99630ms step_avg:89.76ms +step:1111/1680 train_time:99721ms step_avg:89.76ms +step:1112/1680 train_time:99812ms step_avg:89.76ms +step:1113/1680 train_time:99902ms step_avg:89.76ms +step:1114/1680 
train_time:99993ms step_avg:89.76ms +step:1115/1680 train_time:100084ms step_avg:89.76ms +step:1116/1680 train_time:100175ms step_avg:89.76ms +step:1117/1680 train_time:100267ms step_avg:89.76ms +step:1118/1680 train_time:100358ms step_avg:89.77ms +step:1119/1680 train_time:100448ms step_avg:89.77ms +step:1120/1680 train_time:100539ms step_avg:89.77ms +step:1121/1680 train_time:100630ms step_avg:89.77ms +step:1122/1680 train_time:100721ms step_avg:89.77ms +step:1123/1680 train_time:100812ms step_avg:89.77ms +step:1124/1680 train_time:100902ms step_avg:89.77ms +step:1125/1680 train_time:100994ms step_avg:89.77ms +step:1125/1680 val_loss:3.4154 train_time:101086ms step_avg:89.85ms +step:1126/1680 train_time:101108ms step_avg:89.79ms +step:1127/1680 train_time:101181ms step_avg:89.78ms +step:1128/1680 train_time:101277ms step_avg:89.78ms +step:1129/1680 train_time:101368ms step_avg:89.79ms +step:1130/1680 train_time:101458ms step_avg:89.79ms +step:1131/1680 train_time:101549ms step_avg:89.79ms +step:1132/1680 train_time:101639ms step_avg:89.79ms +step:1133/1680 train_time:101729ms step_avg:89.79ms +step:1134/1680 train_time:101818ms step_avg:89.79ms +step:1135/1680 train_time:101908ms step_avg:89.79ms +step:1136/1680 train_time:101999ms step_avg:89.79ms +step:1137/1680 train_time:102090ms step_avg:89.79ms +step:1138/1680 train_time:102186ms step_avg:89.79ms +step:1139/1680 train_time:102279ms step_avg:89.80ms +step:1140/1680 train_time:102371ms step_avg:89.80ms +step:1141/1680 train_time:102461ms step_avg:89.80ms +step:1142/1680 train_time:102552ms step_avg:89.80ms +step:1143/1680 train_time:102642ms step_avg:89.80ms +step:1144/1680 train_time:102733ms step_avg:89.80ms +step:1145/1680 train_time:102823ms step_avg:89.80ms +step:1146/1680 train_time:102914ms step_avg:89.80ms +step:1147/1680 train_time:103005ms step_avg:89.80ms +step:1148/1680 train_time:103097ms step_avg:89.81ms +step:1149/1680 train_time:103189ms step_avg:89.81ms +step:1150/1680 train_time:103281ms step_avg:89.81ms +step:1151/1680 train_time:103372ms step_avg:89.81ms +step:1152/1680 train_time:103462ms step_avg:89.81ms +step:1153/1680 train_time:103553ms step_avg:89.81ms +step:1154/1680 train_time:103644ms step_avg:89.81ms +step:1155/1680 train_time:103735ms step_avg:89.81ms +step:1156/1680 train_time:103825ms step_avg:89.81ms +step:1157/1680 train_time:103915ms step_avg:89.81ms +step:1158/1680 train_time:104007ms step_avg:89.82ms +step:1159/1680 train_time:104098ms step_avg:89.82ms +step:1160/1680 train_time:104191ms step_avg:89.82ms +step:1161/1680 train_time:104283ms step_avg:89.82ms +step:1162/1680 train_time:104375ms step_avg:89.82ms +step:1163/1680 train_time:104466ms step_avg:89.82ms +step:1164/1680 train_time:104556ms step_avg:89.83ms +step:1165/1680 train_time:104647ms step_avg:89.83ms +step:1166/1680 train_time:104737ms step_avg:89.83ms +step:1167/1680 train_time:104828ms step_avg:89.83ms +step:1168/1680 train_time:104918ms step_avg:89.83ms +step:1169/1680 train_time:105008ms step_avg:89.83ms +step:1170/1680 train_time:105099ms step_avg:89.83ms +step:1171/1680 train_time:105191ms step_avg:89.83ms +step:1172/1680 train_time:105282ms step_avg:89.83ms +step:1173/1680 train_time:105374ms step_avg:89.83ms +step:1174/1680 train_time:105466ms step_avg:89.83ms +step:1175/1680 train_time:105558ms step_avg:89.84ms +step:1176/1680 train_time:105648ms step_avg:89.84ms +step:1177/1680 train_time:105738ms step_avg:89.84ms +step:1178/1680 train_time:105829ms step_avg:89.84ms +step:1179/1680 train_time:105919ms step_avg:89.84ms 
+step:1180/1680 train_time:106010ms step_avg:89.84ms +step:1181/1680 train_time:106101ms step_avg:89.84ms +step:1182/1680 train_time:106192ms step_avg:89.84ms +step:1183/1680 train_time:106284ms step_avg:89.84ms +step:1184/1680 train_time:106376ms step_avg:89.84ms +step:1185/1680 train_time:106467ms step_avg:89.85ms +step:1186/1680 train_time:106558ms step_avg:89.85ms +step:1187/1680 train_time:106649ms step_avg:89.85ms +step:1188/1680 train_time:106739ms step_avg:89.85ms +step:1189/1680 train_time:106829ms step_avg:89.85ms +step:1190/1680 train_time:106919ms step_avg:89.85ms +step:1191/1680 train_time:107011ms step_avg:89.85ms +step:1192/1680 train_time:107102ms step_avg:89.85ms +step:1193/1680 train_time:107193ms step_avg:89.85ms +step:1194/1680 train_time:107286ms step_avg:89.85ms +step:1195/1680 train_time:107378ms step_avg:89.86ms +step:1196/1680 train_time:107469ms step_avg:89.86ms +step:1197/1680 train_time:107560ms step_avg:89.86ms +step:1198/1680 train_time:107651ms step_avg:89.86ms +step:1199/1680 train_time:107741ms step_avg:89.86ms +step:1200/1680 train_time:107831ms step_avg:89.86ms +step:1201/1680 train_time:107922ms step_avg:89.86ms +step:1202/1680 train_time:108014ms step_avg:89.86ms +step:1203/1680 train_time:108104ms step_avg:89.86ms +step:1204/1680 train_time:108194ms step_avg:89.86ms +step:1205/1680 train_time:108285ms step_avg:89.86ms +step:1206/1680 train_time:108377ms step_avg:89.86ms +step:1207/1680 train_time:108468ms step_avg:89.87ms +step:1208/1680 train_time:108560ms step_avg:89.87ms +step:1209/1680 train_time:108650ms step_avg:89.87ms +step:1210/1680 train_time:108741ms step_avg:89.87ms +step:1211/1680 train_time:108832ms step_avg:89.87ms +step:1212/1680 train_time:108924ms step_avg:89.87ms +step:1213/1680 train_time:109014ms step_avg:89.87ms +step:1214/1680 train_time:109105ms step_avg:89.87ms +step:1215/1680 train_time:109196ms step_avg:89.87ms +step:1216/1680 train_time:109288ms step_avg:89.87ms +step:1217/1680 train_time:109379ms step_avg:89.88ms +step:1218/1680 train_time:109470ms step_avg:89.88ms +step:1219/1680 train_time:109561ms step_avg:89.88ms +step:1220/1680 train_time:109651ms step_avg:89.88ms +step:1221/1680 train_time:109742ms step_avg:89.88ms +step:1222/1680 train_time:109833ms step_avg:89.88ms +step:1223/1680 train_time:109924ms step_avg:89.88ms +step:1224/1680 train_time:110015ms step_avg:89.88ms +step:1225/1680 train_time:110106ms step_avg:89.88ms +step:1226/1680 train_time:110197ms step_avg:89.88ms +step:1227/1680 train_time:110288ms step_avg:89.88ms +step:1228/1680 train_time:110379ms step_avg:89.89ms +step:1229/1680 train_time:110470ms step_avg:89.89ms +step:1230/1680 train_time:110560ms step_avg:89.89ms +step:1231/1680 train_time:110653ms step_avg:89.89ms +step:1232/1680 train_time:110744ms step_avg:89.89ms +step:1233/1680 train_time:110835ms step_avg:89.89ms +step:1234/1680 train_time:110926ms step_avg:89.89ms +step:1235/1680 train_time:111016ms step_avg:89.89ms +step:1236/1680 train_time:111108ms step_avg:89.89ms +step:1237/1680 train_time:111200ms step_avg:89.89ms +step:1238/1680 train_time:111290ms step_avg:89.90ms +step:1239/1680 train_time:111382ms step_avg:89.90ms +step:1240/1680 train_time:111472ms step_avg:89.90ms +step:1241/1680 train_time:111563ms step_avg:89.90ms +step:1242/1680 train_time:111654ms step_avg:89.90ms +step:1243/1680 train_time:111745ms step_avg:89.90ms +step:1244/1680 train_time:111835ms step_avg:89.90ms +step:1245/1680 train_time:111926ms step_avg:89.90ms +step:1246/1680 train_time:112016ms step_avg:89.90ms 
+step:1247/1680 train_time:112108ms step_avg:89.90ms +step:1248/1680 train_time:112199ms step_avg:89.90ms +step:1249/1680 train_time:112289ms step_avg:89.90ms +step:1250/1680 train_time:112381ms step_avg:89.91ms +step:1250/1680 val_loss:3.3768 train_time:112473ms step_avg:89.98ms +step:1251/1680 train_time:112496ms step_avg:89.92ms +step:1252/1680 train_time:112572ms step_avg:89.91ms +step:1253/1680 train_time:112667ms step_avg:89.92ms +step:1254/1680 train_time:112759ms step_avg:89.92ms +step:1255/1680 train_time:112850ms step_avg:89.92ms +step:1256/1680 train_time:112939ms step_avg:89.92ms +step:1257/1680 train_time:113028ms step_avg:89.92ms +step:1258/1680 train_time:113118ms step_avg:89.92ms +step:1259/1680 train_time:113208ms step_avg:89.92ms +step:1260/1680 train_time:113298ms step_avg:89.92ms +step:1261/1680 train_time:113388ms step_avg:89.92ms +step:1262/1680 train_time:113483ms step_avg:89.92ms +step:1263/1680 train_time:113577ms step_avg:89.93ms +step:1264/1680 train_time:113671ms step_avg:89.93ms +step:1265/1680 train_time:113762ms step_avg:89.93ms +step:1266/1680 train_time:113852ms step_avg:89.93ms +step:1267/1680 train_time:113943ms step_avg:89.93ms +step:1268/1680 train_time:114033ms step_avg:89.93ms +step:1269/1680 train_time:114122ms step_avg:89.93ms +step:1270/1680 train_time:114211ms step_avg:89.93ms +step:1271/1680 train_time:114301ms step_avg:89.93ms +step:1272/1680 train_time:114391ms step_avg:89.93ms +step:1273/1680 train_time:114483ms step_avg:89.93ms +step:1274/1680 train_time:114576ms step_avg:89.93ms +step:1275/1680 train_time:114669ms step_avg:89.94ms +step:1276/1680 train_time:114760ms step_avg:89.94ms +step:1277/1680 train_time:114851ms step_avg:89.94ms +step:1278/1680 train_time:114941ms step_avg:89.94ms +step:1279/1680 train_time:115032ms step_avg:89.94ms +step:1280/1680 train_time:115122ms step_avg:89.94ms +step:1281/1680 train_time:115212ms step_avg:89.94ms +step:1282/1680 train_time:115302ms step_avg:89.94ms +step:1283/1680 train_time:115393ms step_avg:89.94ms +step:1284/1680 train_time:115484ms step_avg:89.94ms +step:1285/1680 train_time:115576ms step_avg:89.94ms +step:1286/1680 train_time:115667ms step_avg:89.94ms +step:1287/1680 train_time:115759ms step_avg:89.94ms +step:1288/1680 train_time:115850ms step_avg:89.95ms +step:1289/1680 train_time:115941ms step_avg:89.95ms +step:1290/1680 train_time:116032ms step_avg:89.95ms +step:1291/1680 train_time:116122ms step_avg:89.95ms +step:1292/1680 train_time:116212ms step_avg:89.95ms +step:1293/1680 train_time:116302ms step_avg:89.95ms +step:1294/1680 train_time:116393ms step_avg:89.95ms +step:1295/1680 train_time:116484ms step_avg:89.95ms +step:1296/1680 train_time:116575ms step_avg:89.95ms +step:1297/1680 train_time:116666ms step_avg:89.95ms +step:1298/1680 train_time:116758ms step_avg:89.95ms +step:1299/1680 train_time:116849ms step_avg:89.95ms +step:1300/1680 train_time:116940ms step_avg:89.95ms +step:1301/1680 train_time:117030ms step_avg:89.95ms +step:1302/1680 train_time:117120ms step_avg:89.95ms +step:1303/1680 train_time:117210ms step_avg:89.95ms +step:1304/1680 train_time:117301ms step_avg:89.95ms +step:1305/1680 train_time:117392ms step_avg:89.96ms +step:1306/1680 train_time:117484ms step_avg:89.96ms +step:1307/1680 train_time:117574ms step_avg:89.96ms +step:1308/1680 train_time:117666ms step_avg:89.96ms +step:1309/1680 train_time:117757ms step_avg:89.96ms +step:1310/1680 train_time:117848ms step_avg:89.96ms +step:1311/1680 train_time:117940ms step_avg:89.96ms +step:1312/1680 train_time:118030ms 
step_avg:89.96ms +step:1313/1680 train_time:118121ms step_avg:89.96ms +step:1314/1680 train_time:118212ms step_avg:89.96ms +step:1315/1680 train_time:118301ms step_avg:89.96ms +step:1316/1680 train_time:118392ms step_avg:89.96ms +step:1317/1680 train_time:118483ms step_avg:89.96ms +step:1318/1680 train_time:118574ms step_avg:89.97ms +step:1319/1680 train_time:118665ms step_avg:89.97ms +step:1320/1680 train_time:118756ms step_avg:89.97ms +step:1321/1680 train_time:118847ms step_avg:89.97ms +step:1322/1680 train_time:118940ms step_avg:89.97ms +step:1323/1680 train_time:119032ms step_avg:89.97ms +step:1324/1680 train_time:119122ms step_avg:89.97ms +step:1325/1680 train_time:119212ms step_avg:89.97ms +step:1326/1680 train_time:119302ms step_avg:89.97ms +step:1327/1680 train_time:119393ms step_avg:89.97ms +step:1328/1680 train_time:119484ms step_avg:89.97ms +step:1329/1680 train_time:119574ms step_avg:89.97ms +step:1330/1680 train_time:119665ms step_avg:89.97ms +step:1331/1680 train_time:119756ms step_avg:89.97ms +step:1332/1680 train_time:119849ms step_avg:89.98ms +step:1333/1680 train_time:119940ms step_avg:89.98ms +step:1334/1680 train_time:120032ms step_avg:89.98ms +step:1335/1680 train_time:120122ms step_avg:89.98ms +step:1336/1680 train_time:120213ms step_avg:89.98ms +step:1337/1680 train_time:120303ms step_avg:89.98ms +step:1338/1680 train_time:120394ms step_avg:89.98ms +step:1339/1680 train_time:120484ms step_avg:89.98ms +step:1340/1680 train_time:120575ms step_avg:89.98ms +step:1341/1680 train_time:120666ms step_avg:89.98ms +step:1342/1680 train_time:120758ms step_avg:89.98ms +step:1343/1680 train_time:120849ms step_avg:89.98ms +step:1344/1680 train_time:120940ms step_avg:89.98ms +step:1345/1680 train_time:121031ms step_avg:89.99ms +step:1346/1680 train_time:121121ms step_avg:89.99ms +step:1347/1680 train_time:121212ms step_avg:89.99ms +step:1348/1680 train_time:121303ms step_avg:89.99ms +step:1349/1680 train_time:121393ms step_avg:89.99ms +step:1350/1680 train_time:121483ms step_avg:89.99ms +step:1351/1680 train_time:121574ms step_avg:89.99ms +step:1352/1680 train_time:121667ms step_avg:89.99ms +step:1353/1680 train_time:121759ms step_avg:89.99ms +step:1354/1680 train_time:121850ms step_avg:89.99ms +step:1355/1680 train_time:121941ms step_avg:89.99ms +step:1356/1680 train_time:122032ms step_avg:89.99ms +step:1357/1680 train_time:122123ms step_avg:89.99ms +step:1358/1680 train_time:122213ms step_avg:90.00ms +step:1359/1680 train_time:122304ms step_avg:90.00ms +step:1360/1680 train_time:122394ms step_avg:90.00ms +step:1361/1680 train_time:122485ms step_avg:90.00ms +step:1362/1680 train_time:122576ms step_avg:90.00ms +step:1363/1680 train_time:122668ms step_avg:90.00ms +step:1364/1680 train_time:122759ms step_avg:90.00ms +step:1365/1680 train_time:122850ms step_avg:90.00ms +step:1366/1680 train_time:122941ms step_avg:90.00ms +step:1367/1680 train_time:123032ms step_avg:90.00ms +step:1368/1680 train_time:123123ms step_avg:90.00ms +step:1369/1680 train_time:123215ms step_avg:90.00ms +step:1370/1680 train_time:123306ms step_avg:90.00ms +step:1371/1680 train_time:123397ms step_avg:90.01ms +step:1372/1680 train_time:123488ms step_avg:90.01ms +step:1373/1680 train_time:123579ms step_avg:90.01ms +step:1374/1680 train_time:123669ms step_avg:90.01ms +step:1375/1680 train_time:123759ms step_avg:90.01ms +step:1375/1680 val_loss:3.3426 train_time:123852ms step_avg:90.07ms +step:1376/1680 train_time:123874ms step_avg:90.02ms +step:1377/1680 train_time:123944ms step_avg:90.01ms +step:1378/1680 
train_time:124039ms step_avg:90.01ms +step:1379/1680 train_time:124129ms step_avg:90.01ms +step:1380/1680 train_time:124219ms step_avg:90.01ms +step:1381/1680 train_time:124310ms step_avg:90.01ms +step:1382/1680 train_time:124401ms step_avg:90.01ms +step:1383/1680 train_time:124491ms step_avg:90.01ms +step:1384/1680 train_time:124581ms step_avg:90.01ms +step:1385/1680 train_time:124671ms step_avg:90.01ms +step:1386/1680 train_time:124761ms step_avg:90.02ms +step:1387/1680 train_time:124854ms step_avg:90.02ms +step:1388/1680 train_time:124946ms step_avg:90.02ms +step:1389/1680 train_time:125038ms step_avg:90.02ms +step:1390/1680 train_time:125129ms step_avg:90.02ms +step:1391/1680 train_time:125221ms step_avg:90.02ms +step:1392/1680 train_time:125311ms step_avg:90.02ms +step:1393/1680 train_time:125401ms step_avg:90.02ms +step:1394/1680 train_time:125491ms step_avg:90.02ms +step:1395/1680 train_time:125581ms step_avg:90.02ms +step:1396/1680 train_time:125671ms step_avg:90.02ms +step:1397/1680 train_time:125762ms step_avg:90.02ms +step:1398/1680 train_time:125853ms step_avg:90.02ms +step:1399/1680 train_time:125945ms step_avg:90.03ms +step:1400/1680 train_time:126038ms step_avg:90.03ms +step:1401/1680 train_time:126130ms step_avg:90.03ms +step:1402/1680 train_time:126222ms step_avg:90.03ms +step:1403/1680 train_time:126312ms step_avg:90.03ms +step:1404/1680 train_time:126403ms step_avg:90.03ms +step:1405/1680 train_time:126493ms step_avg:90.03ms +step:1406/1680 train_time:126582ms step_avg:90.03ms +step:1407/1680 train_time:126673ms step_avg:90.03ms +step:1408/1680 train_time:126763ms step_avg:90.03ms +step:1409/1680 train_time:126854ms step_avg:90.03ms +step:1410/1680 train_time:126945ms step_avg:90.03ms +step:1411/1680 train_time:127037ms step_avg:90.03ms +step:1412/1680 train_time:127129ms step_avg:90.03ms +step:1413/1680 train_time:127221ms step_avg:90.04ms +step:1414/1680 train_time:127312ms step_avg:90.04ms +step:1415/1680 train_time:127403ms step_avg:90.04ms +step:1416/1680 train_time:127493ms step_avg:90.04ms +step:1417/1680 train_time:127583ms step_avg:90.04ms +step:1418/1680 train_time:127674ms step_avg:90.04ms +step:1419/1680 train_time:127764ms step_avg:90.04ms +step:1420/1680 train_time:127855ms step_avg:90.04ms +step:1421/1680 train_time:127946ms step_avg:90.04ms +step:1422/1680 train_time:128037ms step_avg:90.04ms +step:1423/1680 train_time:128128ms step_avg:90.04ms +step:1424/1680 train_time:128220ms step_avg:90.04ms +step:1425/1680 train_time:128311ms step_avg:90.04ms +step:1426/1680 train_time:128403ms step_avg:90.04ms +step:1427/1680 train_time:128493ms step_avg:90.04ms +step:1428/1680 train_time:128586ms step_avg:90.05ms +step:1429/1680 train_time:128675ms step_avg:90.05ms +step:1430/1680 train_time:128765ms step_avg:90.05ms +step:1431/1680 train_time:128855ms step_avg:90.05ms +step:1432/1680 train_time:128946ms step_avg:90.05ms +step:1433/1680 train_time:129037ms step_avg:90.05ms +step:1434/1680 train_time:129128ms step_avg:90.05ms +step:1435/1680 train_time:129219ms step_avg:90.05ms +step:1436/1680 train_time:129311ms step_avg:90.05ms +step:1437/1680 train_time:129401ms step_avg:90.05ms +step:1438/1680 train_time:129492ms step_avg:90.05ms +step:1439/1680 train_time:129583ms step_avg:90.05ms +step:1440/1680 train_time:129674ms step_avg:90.05ms +step:1441/1680 train_time:129765ms step_avg:90.05ms +step:1442/1680 train_time:129855ms step_avg:90.05ms +step:1443/1680 train_time:129946ms step_avg:90.05ms +step:1444/1680 train_time:130037ms step_avg:90.05ms +step:1445/1680 
train_time:130128ms step_avg:90.05ms +step:1446/1680 train_time:130220ms step_avg:90.06ms +step:1447/1680 train_time:130310ms step_avg:90.06ms +step:1448/1680 train_time:130401ms step_avg:90.06ms +step:1449/1680 train_time:130492ms step_avg:90.06ms +step:1450/1680 train_time:130583ms step_avg:90.06ms +step:1451/1680 train_time:130673ms step_avg:90.06ms +step:1452/1680 train_time:130764ms step_avg:90.06ms +step:1453/1680 train_time:130854ms step_avg:90.06ms +step:1454/1680 train_time:130945ms step_avg:90.06ms +step:1455/1680 train_time:131036ms step_avg:90.06ms +step:1456/1680 train_time:131126ms step_avg:90.06ms +step:1457/1680 train_time:131217ms step_avg:90.06ms +step:1458/1680 train_time:131309ms step_avg:90.06ms +step:1459/1680 train_time:131399ms step_avg:90.06ms +step:1460/1680 train_time:131490ms step_avg:90.06ms +step:1461/1680 train_time:131581ms step_avg:90.06ms +step:1462/1680 train_time:131672ms step_avg:90.06ms +step:1463/1680 train_time:131763ms step_avg:90.06ms +step:1464/1680 train_time:131854ms step_avg:90.06ms +step:1465/1680 train_time:131944ms step_avg:90.06ms +step:1466/1680 train_time:132035ms step_avg:90.06ms +step:1467/1680 train_time:132125ms step_avg:90.07ms +step:1468/1680 train_time:132217ms step_avg:90.07ms +step:1469/1680 train_time:132308ms step_avg:90.07ms +step:1470/1680 train_time:132399ms step_avg:90.07ms +step:1471/1680 train_time:132490ms step_avg:90.07ms +step:1472/1680 train_time:132581ms step_avg:90.07ms +step:1473/1680 train_time:132672ms step_avg:90.07ms +step:1474/1680 train_time:132763ms step_avg:90.07ms +step:1475/1680 train_time:132853ms step_avg:90.07ms +step:1476/1680 train_time:132944ms step_avg:90.07ms +step:1477/1680 train_time:133034ms step_avg:90.07ms +step:1478/1680 train_time:133125ms step_avg:90.07ms +step:1479/1680 train_time:133216ms step_avg:90.07ms +step:1480/1680 train_time:133307ms step_avg:90.07ms +step:1481/1680 train_time:133398ms step_avg:90.07ms +step:1482/1680 train_time:133490ms step_avg:90.07ms +step:1483/1680 train_time:133580ms step_avg:90.07ms +step:1484/1680 train_time:133671ms step_avg:90.08ms +step:1485/1680 train_time:133764ms step_avg:90.08ms +step:1486/1680 train_time:133854ms step_avg:90.08ms +step:1487/1680 train_time:133945ms step_avg:90.08ms +step:1488/1680 train_time:134035ms step_avg:90.08ms +step:1489/1680 train_time:134126ms step_avg:90.08ms +step:1490/1680 train_time:134217ms step_avg:90.08ms +step:1491/1680 train_time:134308ms step_avg:90.08ms +step:1492/1680 train_time:134399ms step_avg:90.08ms +step:1493/1680 train_time:134492ms step_avg:90.08ms +step:1494/1680 train_time:134583ms step_avg:90.08ms +step:1495/1680 train_time:134674ms step_avg:90.08ms +step:1496/1680 train_time:134765ms step_avg:90.08ms +step:1497/1680 train_time:134855ms step_avg:90.08ms +step:1498/1680 train_time:134947ms step_avg:90.08ms +step:1499/1680 train_time:135037ms step_avg:90.09ms +step:1500/1680 train_time:135128ms step_avg:90.09ms +step:1500/1680 val_loss:3.3125 train_time:135220ms step_avg:90.15ms +step:1501/1680 train_time:135243ms step_avg:90.10ms +step:1502/1680 train_time:135316ms step_avg:90.09ms +step:1503/1680 train_time:135411ms step_avg:90.09ms +step:1504/1680 train_time:135502ms step_avg:90.09ms +step:1505/1680 train_time:135591ms step_avg:90.09ms +step:1506/1680 train_time:135681ms step_avg:90.09ms +step:1507/1680 train_time:135771ms step_avg:90.09ms +step:1508/1680 train_time:135861ms step_avg:90.09ms +step:1509/1680 train_time:135951ms step_avg:90.09ms +step:1510/1680 train_time:136042ms step_avg:90.09ms 
+step:1511/1680 train_time:136133ms step_avg:90.09ms +step:1512/1680 train_time:136226ms step_avg:90.10ms +step:1513/1680 train_time:136320ms step_avg:90.10ms +step:1514/1680 train_time:136411ms step_avg:90.10ms +step:1515/1680 train_time:136503ms step_avg:90.10ms +step:1516/1680 train_time:136594ms step_avg:90.10ms +step:1517/1680 train_time:136685ms step_avg:90.10ms +step:1518/1680 train_time:136775ms step_avg:90.10ms +step:1519/1680 train_time:136864ms step_avg:90.10ms +step:1520/1680 train_time:136954ms step_avg:90.10ms +step:1521/1680 train_time:137045ms step_avg:90.10ms +step:1522/1680 train_time:137136ms step_avg:90.10ms +step:1523/1680 train_time:137228ms step_avg:90.10ms +step:1524/1680 train_time:137321ms step_avg:90.11ms +step:1525/1680 train_time:137412ms step_avg:90.11ms +step:1526/1680 train_time:137503ms step_avg:90.11ms +step:1527/1680 train_time:137595ms step_avg:90.11ms +step:1528/1680 train_time:137686ms step_avg:90.11ms +step:1529/1680 train_time:137777ms step_avg:90.11ms +step:1530/1680 train_time:137867ms step_avg:90.11ms +step:1531/1680 train_time:137957ms step_avg:90.11ms +step:1532/1680 train_time:138048ms step_avg:90.11ms +step:1533/1680 train_time:138138ms step_avg:90.11ms +step:1534/1680 train_time:138230ms step_avg:90.11ms +step:1535/1680 train_time:138323ms step_avg:90.11ms +step:1536/1680 train_time:138415ms step_avg:90.11ms +step:1537/1680 train_time:138505ms step_avg:90.11ms +step:1538/1680 train_time:138597ms step_avg:90.11ms +step:1539/1680 train_time:138687ms step_avg:90.12ms +step:1540/1680 train_time:138778ms step_avg:90.12ms +step:1541/1680 train_time:138867ms step_avg:90.12ms +step:1542/1680 train_time:138958ms step_avg:90.12ms +step:1543/1680 train_time:139049ms step_avg:90.12ms +step:1544/1680 train_time:139138ms step_avg:90.12ms +step:1545/1680 train_time:139230ms step_avg:90.12ms +step:1546/1680 train_time:139322ms step_avg:90.12ms +step:1547/1680 train_time:139412ms step_avg:90.12ms +step:1548/1680 train_time:139504ms step_avg:90.12ms +step:1549/1680 train_time:139596ms step_avg:90.12ms +step:1550/1680 train_time:139687ms step_avg:90.12ms +step:1551/1680 train_time:139778ms step_avg:90.12ms +step:1552/1680 train_time:139867ms step_avg:90.12ms +step:1553/1680 train_time:139958ms step_avg:90.12ms +step:1554/1680 train_time:140048ms step_avg:90.12ms +step:1555/1680 train_time:140139ms step_avg:90.12ms +step:1556/1680 train_time:140230ms step_avg:90.12ms +step:1557/1680 train_time:140322ms step_avg:90.12ms +step:1558/1680 train_time:140413ms step_avg:90.12ms +step:1559/1680 train_time:140505ms step_avg:90.12ms +step:1560/1680 train_time:140597ms step_avg:90.13ms +step:1561/1680 train_time:140688ms step_avg:90.13ms +step:1562/1680 train_time:140778ms step_avg:90.13ms +step:1563/1680 train_time:140868ms step_avg:90.13ms +step:1564/1680 train_time:140958ms step_avg:90.13ms +step:1565/1680 train_time:141048ms step_avg:90.13ms +step:1566/1680 train_time:141139ms step_avg:90.13ms +step:1567/1680 train_time:141230ms step_avg:90.13ms +step:1568/1680 train_time:141322ms step_avg:90.13ms +step:1569/1680 train_time:141413ms step_avg:90.13ms +step:1570/1680 train_time:141505ms step_avg:90.13ms +step:1571/1680 train_time:141597ms step_avg:90.13ms +step:1572/1680 train_time:141688ms step_avg:90.13ms +step:1573/1680 train_time:141779ms step_avg:90.13ms +step:1574/1680 train_time:141869ms step_avg:90.13ms +step:1575/1680 train_time:141959ms step_avg:90.13ms +step:1576/1680 train_time:142050ms step_avg:90.13ms +step:1577/1680 train_time:142140ms step_avg:90.13ms 
+step:1578/1680 train_time:142232ms step_avg:90.13ms +step:1579/1680 train_time:142323ms step_avg:90.13ms +step:1580/1680 train_time:142414ms step_avg:90.14ms +step:1581/1680 train_time:142506ms step_avg:90.14ms +step:1582/1680 train_time:142599ms step_avg:90.14ms +step:1583/1680 train_time:142690ms step_avg:90.14ms +step:1584/1680 train_time:142782ms step_avg:90.14ms +step:1585/1680 train_time:142873ms step_avg:90.14ms +step:1586/1680 train_time:142963ms step_avg:90.14ms +step:1587/1680 train_time:143054ms step_avg:90.14ms +step:1588/1680 train_time:143144ms step_avg:90.14ms +step:1589/1680 train_time:143235ms step_avg:90.14ms +step:1590/1680 train_time:143326ms step_avg:90.14ms +step:1591/1680 train_time:143418ms step_avg:90.14ms +step:1592/1680 train_time:143508ms step_avg:90.14ms +step:1593/1680 train_time:143600ms step_avg:90.14ms +step:1594/1680 train_time:143693ms step_avg:90.15ms +step:1595/1680 train_time:143785ms step_avg:90.15ms +step:1596/1680 train_time:143875ms step_avg:90.15ms +step:1597/1680 train_time:143966ms step_avg:90.15ms +step:1598/1680 train_time:144056ms step_avg:90.15ms +step:1599/1680 train_time:144147ms step_avg:90.15ms +step:1600/1680 train_time:144238ms step_avg:90.15ms +step:1601/1680 train_time:144328ms step_avg:90.15ms +step:1602/1680 train_time:144418ms step_avg:90.15ms +step:1603/1680 train_time:144508ms step_avg:90.15ms +step:1604/1680 train_time:144600ms step_avg:90.15ms +step:1605/1680 train_time:144691ms step_avg:90.15ms +step:1606/1680 train_time:144783ms step_avg:90.15ms +step:1607/1680 train_time:144874ms step_avg:90.15ms +step:1608/1680 train_time:144965ms step_avg:90.15ms +step:1609/1680 train_time:145056ms step_avg:90.15ms +step:1610/1680 train_time:145147ms step_avg:90.15ms +step:1611/1680 train_time:145237ms step_avg:90.15ms +step:1612/1680 train_time:145328ms step_avg:90.15ms +step:1613/1680 train_time:145418ms step_avg:90.15ms +step:1614/1680 train_time:145509ms step_avg:90.15ms +step:1615/1680 train_time:145600ms step_avg:90.15ms +step:1616/1680 train_time:145691ms step_avg:90.16ms +step:1617/1680 train_time:145783ms step_avg:90.16ms +step:1618/1680 train_time:145874ms step_avg:90.16ms +step:1619/1680 train_time:145965ms step_avg:90.16ms +step:1620/1680 train_time:146056ms step_avg:90.16ms +step:1621/1680 train_time:146147ms step_avg:90.16ms +step:1622/1680 train_time:146237ms step_avg:90.16ms +step:1623/1680 train_time:146327ms step_avg:90.16ms +step:1624/1680 train_time:146418ms step_avg:90.16ms +step:1625/1680 train_time:146508ms step_avg:90.16ms +step:1625/1680 val_loss:3.2885 train_time:146600ms step_avg:90.22ms +step:1626/1680 train_time:146622ms step_avg:90.17ms +step:1627/1680 train_time:146694ms step_avg:90.16ms +step:1628/1680 train_time:146793ms step_avg:90.17ms +step:1629/1680 train_time:146885ms step_avg:90.17ms +step:1630/1680 train_time:146975ms step_avg:90.17ms +step:1631/1680 train_time:147064ms step_avg:90.17ms +step:1632/1680 train_time:147153ms step_avg:90.17ms +step:1633/1680 train_time:147242ms step_avg:90.17ms +step:1634/1680 train_time:147332ms step_avg:90.17ms +step:1635/1680 train_time:147422ms step_avg:90.17ms +step:1636/1680 train_time:147511ms step_avg:90.17ms +step:1637/1680 train_time:147602ms step_avg:90.17ms +step:1638/1680 train_time:147696ms step_avg:90.17ms +step:1639/1680 train_time:147789ms step_avg:90.17ms +step:1640/1680 train_time:147881ms step_avg:90.17ms +step:1641/1680 train_time:147972ms step_avg:90.17ms +step:1642/1680 train_time:148062ms step_avg:90.17ms +step:1643/1680 train_time:148152ms 
step_avg:90.17ms +step:1644/1680 train_time:148242ms step_avg:90.17ms +step:1645/1680 train_time:148331ms step_avg:90.17ms +step:1646/1680 train_time:148421ms step_avg:90.17ms +step:1647/1680 train_time:148512ms step_avg:90.17ms +step:1648/1680 train_time:148603ms step_avg:90.17ms +step:1649/1680 train_time:148695ms step_avg:90.17ms +step:1650/1680 train_time:148787ms step_avg:90.17ms +step:1651/1680 train_time:148880ms step_avg:90.18ms +step:1652/1680 train_time:148971ms step_avg:90.18ms +step:1653/1680 train_time:149061ms step_avg:90.18ms +step:1654/1680 train_time:149151ms step_avg:90.18ms +step:1655/1680 train_time:149241ms step_avg:90.18ms +step:1656/1680 train_time:149331ms step_avg:90.18ms +step:1657/1680 train_time:149421ms step_avg:90.18ms +step:1658/1680 train_time:149512ms step_avg:90.18ms +step:1659/1680 train_time:149603ms step_avg:90.18ms +step:1660/1680 train_time:149694ms step_avg:90.18ms +step:1661/1680 train_time:149785ms step_avg:90.18ms +step:1662/1680 train_time:149877ms step_avg:90.18ms +step:1663/1680 train_time:149969ms step_avg:90.18ms +step:1664/1680 train_time:150060ms step_avg:90.18ms +step:1665/1680 train_time:150152ms step_avg:90.18ms +step:1666/1680 train_time:150242ms step_avg:90.18ms +step:1667/1680 train_time:150333ms step_avg:90.18ms +step:1668/1680 train_time:150423ms step_avg:90.18ms +step:1669/1680 train_time:150514ms step_avg:90.18ms +step:1670/1680 train_time:150605ms step_avg:90.18ms +step:1671/1680 train_time:150696ms step_avg:90.18ms +step:1672/1680 train_time:150787ms step_avg:90.18ms +step:1673/1680 train_time:150879ms step_avg:90.18ms +step:1674/1680 train_time:150971ms step_avg:90.19ms +step:1675/1680 train_time:151062ms step_avg:90.19ms +step:1676/1680 train_time:151152ms step_avg:90.19ms +step:1677/1680 train_time:151243ms step_avg:90.19ms +step:1678/1680 train_time:151333ms step_avg:90.19ms +step:1679/1680 train_time:151423ms step_avg:90.19ms +step:1680/1680 train_time:151514ms step_avg:90.19ms +step:1680/1680 val_loss:3.2778 train_time:151606ms step_avg:90.24ms +peak memory allocated: 31255 MiB reserved: 46494 MiB diff --git a/records/092125_DropAttn/be55679c-393d-432f-882d-287e7cfa727d.txt b/records/092125_DropAttn/be55679c-393d-432f-882d-287e7cfa727d.txt new file mode 100644 index 000000000..673f5f971 --- /dev/null +++ b/records/092125_DropAttn/be55679c-393d-432f-882d-287e7cfa727d.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: Fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
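+                # The two lerps below implement Nesterov-style momentum: the buffer becomes
+                # mu * buf + (1 - mu) * grad, and the direction sent to Newton-Schulz is
+                # grad.lerp(buf, mu) = (1 - mu) * grad + mu * buf, i.e. the gradient blended
+                # toward the freshly updated buffer.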
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
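+        # Because whole updated parameters (not update deltas) are gathered, each rank only
+        # ever materializes momentum state for the parameters in its own shard.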
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
self.batch_iter = 0 + + def _load(self): + self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + self.bos_idx = self.bos_idx_async + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + # if quickload was used, repoint to the full dataset after 5 batches + if self.quickload and self.batch_iter==5: + self.get() + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + self.batch_iter+=1 + return starts, ends + +class DataPreloader: + # Helper for asynchronously loading next shard and indexing bos tokens + def __init__(self, file_iter, world_size: int = 1): + self.file_iter = file_iter + self.world_size = world_size + self.thread = None + self.data = None + self.ready = threading.Event() + + def _load(self): + tokens = _load_data_shard(next(self.file_iter)) + self.data = (tokens, BOSFinder(tokens, self.world_size)) + self.ready.set() + + def start(self): + self.ready.clear() + self.thread = threading.Thread(target=self._load) + self.thread.start() + + def get(self): + if self.thread: + self.ready.wait() + self.thread.join() + return self.data + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + if align_to_bos: + finder = BOSFinder(tokens, world_size=world_size, quickload=True) + preloader = DataPreloader(file_iter, world_size) + preloader.start() + else: + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. 
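+                # DataPreloader double-buffers shards: .get() joins the background thread
+                # started earlier, and .start() immediately begins loading and BOS-indexing
+                # the following shard while training continues.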
+ tokens, finder = preloader.get() + preloader.start() + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1640 # number of iterations to run + iteration_extension = 40 + cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = f"v1/{uuid.uuid4()}" + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + ws_validate: int = 13 # increase final validation ws @classiclarryd + ws_validate_final_layer: int = 20 # final layer shows no degradation with context length + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 
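+
+# Aside: distributed_data_generator above can be re-parameterized mid-training because
+# a generator's `yield` expression evaluates to whatever the caller passes via .send().
+# A minimal self-contained sketch of that pattern (toy names, not the real loader):
+#
+#   def batches(batch_size):
+#       while True:
+#           new_size = yield list(range(batch_size))
+#           if new_size is not None:
+#               batch_size = new_size
+#
+#   gen = batches(4)
+#   assert len(next(gen)) == 4
+#   assert len(gen.send(8)) == 8  # generator resumes with the new batch size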
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sun Sep 21 22:20:27 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 35C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 9762 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 9763 C /usr/bin/python3 614MiB | +| 0 N/A N/A 9764 C /usr/bin/python3 614MiB | +| 0 N/A N/A 9765 C /usr/bin/python3 614MiB | +| 0 N/A N/A 9766 C /usr/bin/python3 614MiB | +| 0 N/A N/A 9767 C /usr/bin/python3 614MiB | +| 0 N/A N/A 9768 C /usr/bin/python3 614MiB | +| 0 N/A N/A 9769 C /usr/bin/python3 614MiB | +| 1 N/A N/A 9763 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 9764 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 9765 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 9766 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 9767 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 9768 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 9769 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:156ms step_avg:155.71ms +step:2/1680 train_time:181ms step_avg:90.66ms +step:3/1680 train_time:243ms step_avg:80.89ms +step:4/1680 train_time:330ms step_avg:82.39ms +step:5/1680 train_time:417ms step_avg:83.43ms +step:6/1680 train_time:505ms step_avg:84.22ms +step:7/1680 train_time:593ms step_avg:84.76ms +step:8/1680 train_time:682ms step_avg:85.25ms 
+step:9/1680 train_time:772ms step_avg:85.72ms +step:10/1680 train_time:858ms step_avg:85.84ms +step:11/1680 train_time:947ms step_avg:86.07ms +step:12/1680 train_time:1037ms step_avg:86.45ms +step:13/1680 train_time:1129ms step_avg:86.85ms +step:14/1680 train_time:1221ms step_avg:87.19ms +step:15/1680 train_time:1312ms step_avg:87.47ms +step:16/1680 train_time:1400ms step_avg:87.50ms +step:17/1680 train_time:1489ms step_avg:87.61ms +step:18/1680 train_time:1578ms step_avg:87.68ms +step:19/1680 train_time:1668ms step_avg:87.80ms +step:20/1680 train_time:1757ms step_avg:87.83ms +step:21/1680 train_time:1845ms step_avg:87.86ms +step:22/1680 train_time:1934ms step_avg:87.91ms +step:23/1680 train_time:2024ms step_avg:88.01ms +step:24/1680 train_time:2115ms step_avg:88.12ms +step:25/1680 train_time:2206ms step_avg:88.23ms +step:26/1680 train_time:2296ms step_avg:88.29ms +step:27/1680 train_time:2385ms step_avg:88.34ms +step:28/1680 train_time:2475ms step_avg:88.41ms +step:29/1680 train_time:2565ms step_avg:88.45ms +step:30/1680 train_time:2654ms step_avg:88.46ms +step:31/1680 train_time:2743ms step_avg:88.49ms +step:32/1680 train_time:2832ms step_avg:88.51ms +step:33/1680 train_time:2921ms step_avg:88.51ms +step:34/1680 train_time:3011ms step_avg:88.55ms +step:35/1680 train_time:3100ms step_avg:88.58ms +step:36/1680 train_time:3191ms step_avg:88.64ms +step:37/1680 train_time:3281ms step_avg:88.67ms +step:38/1680 train_time:3370ms step_avg:88.69ms +step:39/1680 train_time:3460ms step_avg:88.71ms +step:40/1680 train_time:3549ms step_avg:88.72ms +step:41/1680 train_time:3638ms step_avg:88.73ms +step:42/1680 train_time:3727ms step_avg:88.74ms +step:43/1680 train_time:3816ms step_avg:88.74ms +step:44/1680 train_time:3906ms step_avg:88.76ms +step:45/1680 train_time:3995ms step_avg:88.79ms +step:46/1680 train_time:4085ms step_avg:88.81ms +step:47/1680 train_time:4175ms step_avg:88.82ms +step:48/1680 train_time:4264ms step_avg:88.84ms +step:49/1680 train_time:4354ms step_avg:88.86ms +step:50/1680 train_time:4444ms step_avg:88.88ms +step:51/1680 train_time:4534ms step_avg:88.90ms +step:52/1680 train_time:4624ms step_avg:88.92ms +step:53/1680 train_time:4714ms step_avg:88.94ms +step:54/1680 train_time:4803ms step_avg:88.95ms +step:55/1680 train_time:4892ms step_avg:88.95ms +step:56/1680 train_time:4981ms step_avg:88.95ms +step:57/1680 train_time:5071ms step_avg:88.97ms +step:58/1680 train_time:5160ms step_avg:88.97ms +step:59/1680 train_time:5249ms step_avg:88.97ms +step:60/1680 train_time:5339ms step_avg:88.98ms +step:61/1680 train_time:5428ms step_avg:88.99ms +step:62/1680 train_time:5518ms step_avg:89.00ms +step:63/1680 train_time:5608ms step_avg:89.02ms +step:64/1680 train_time:5698ms step_avg:89.03ms +step:65/1680 train_time:5787ms step_avg:89.02ms +step:66/1680 train_time:5876ms step_avg:89.03ms +step:67/1680 train_time:5966ms step_avg:89.05ms +step:68/1680 train_time:6056ms step_avg:89.06ms +step:69/1680 train_time:6145ms step_avg:89.06ms +step:70/1680 train_time:6235ms step_avg:89.08ms +step:71/1680 train_time:6325ms step_avg:89.09ms +step:72/1680 train_time:6415ms step_avg:89.10ms +step:73/1680 train_time:6505ms step_avg:89.11ms +step:74/1680 train_time:6595ms step_avg:89.12ms +step:75/1680 train_time:6684ms step_avg:89.12ms +step:76/1680 train_time:6774ms step_avg:89.13ms +step:77/1680 train_time:6864ms step_avg:89.14ms +step:78/1680 train_time:6953ms step_avg:89.14ms +step:79/1680 train_time:7042ms step_avg:89.14ms +step:80/1680 train_time:7131ms step_avg:89.14ms +step:81/1680 
train_time:7220ms step_avg:89.14ms +step:82/1680 train_time:7309ms step_avg:89.14ms +step:83/1680 train_time:7398ms step_avg:89.14ms +step:84/1680 train_time:7488ms step_avg:89.15ms +step:85/1680 train_time:7578ms step_avg:89.15ms +step:86/1680 train_time:7668ms step_avg:89.16ms +step:87/1680 train_time:7758ms step_avg:89.17ms +step:88/1680 train_time:7847ms step_avg:89.18ms +step:89/1680 train_time:7937ms step_avg:89.18ms +step:90/1680 train_time:8026ms step_avg:89.18ms +step:91/1680 train_time:8115ms step_avg:89.18ms +step:92/1680 train_time:8205ms step_avg:89.18ms +step:93/1680 train_time:8295ms step_avg:89.20ms +step:94/1680 train_time:8385ms step_avg:89.20ms +step:95/1680 train_time:8474ms step_avg:89.20ms +step:96/1680 train_time:8563ms step_avg:89.20ms +step:97/1680 train_time:8653ms step_avg:89.21ms +step:98/1680 train_time:8744ms step_avg:89.22ms +step:99/1680 train_time:8833ms step_avg:89.22ms +step:100/1680 train_time:8923ms step_avg:89.23ms +step:101/1680 train_time:9012ms step_avg:89.23ms +step:102/1680 train_time:9102ms step_avg:89.23ms +step:103/1680 train_time:9191ms step_avg:89.23ms +step:104/1680 train_time:9281ms step_avg:89.24ms +step:105/1680 train_time:9370ms step_avg:89.24ms +step:106/1680 train_time:9460ms step_avg:89.24ms +step:107/1680 train_time:9549ms step_avg:89.24ms +step:108/1680 train_time:9638ms step_avg:89.25ms +step:109/1680 train_time:9729ms step_avg:89.26ms +step:110/1680 train_time:9819ms step_avg:89.26ms +step:111/1680 train_time:9908ms step_avg:89.26ms +step:112/1680 train_time:9997ms step_avg:89.26ms +step:113/1680 train_time:10089ms step_avg:89.29ms +step:114/1680 train_time:10176ms step_avg:89.26ms +step:115/1680 train_time:10265ms step_avg:89.26ms +step:116/1680 train_time:10355ms step_avg:89.26ms +step:117/1680 train_time:10444ms step_avg:89.27ms +step:118/1680 train_time:10534ms step_avg:89.27ms +step:119/1680 train_time:10624ms step_avg:89.27ms +step:120/1680 train_time:10713ms step_avg:89.28ms +step:121/1680 train_time:10802ms step_avg:89.28ms +step:122/1680 train_time:10893ms step_avg:89.28ms +step:123/1680 train_time:10981ms step_avg:89.28ms +step:124/1680 train_time:11070ms step_avg:89.27ms +step:125/1680 train_time:11158ms step_avg:89.26ms +step:125/1680 val_loss:4.3132 train_time:11248ms step_avg:89.98ms +step:126/1680 train_time:11272ms step_avg:89.46ms +step:127/1680 train_time:11340ms step_avg:89.29ms +step:128/1680 train_time:11437ms step_avg:89.35ms +step:129/1680 train_time:11529ms step_avg:89.38ms +step:130/1680 train_time:11619ms step_avg:89.37ms +step:131/1680 train_time:11707ms step_avg:89.37ms +step:132/1680 train_time:11795ms step_avg:89.36ms +step:133/1680 train_time:11883ms step_avg:89.35ms +step:134/1680 train_time:11971ms step_avg:89.34ms +step:135/1680 train_time:12060ms step_avg:89.33ms +step:136/1680 train_time:12149ms step_avg:89.33ms +step:137/1680 train_time:12238ms step_avg:89.33ms +step:138/1680 train_time:12328ms step_avg:89.33ms +step:139/1680 train_time:12420ms step_avg:89.35ms +step:140/1680 train_time:12511ms step_avg:89.36ms +step:141/1680 train_time:12601ms step_avg:89.37ms +step:142/1680 train_time:12690ms step_avg:89.37ms +step:143/1680 train_time:12779ms step_avg:89.36ms +step:144/1680 train_time:12868ms step_avg:89.36ms +step:145/1680 train_time:12956ms step_avg:89.35ms +step:146/1680 train_time:13044ms step_avg:89.35ms +step:147/1680 train_time:13132ms step_avg:89.33ms +step:148/1680 train_time:13221ms step_avg:89.33ms +step:149/1680 train_time:13311ms step_avg:89.33ms +step:150/1680 train_time:13401ms 
step_avg:89.34ms +step:151/1680 train_time:13492ms step_avg:89.35ms +step:152/1680 train_time:13582ms step_avg:89.35ms +step:153/1680 train_time:13672ms step_avg:89.36ms +step:154/1680 train_time:13760ms step_avg:89.35ms +step:155/1680 train_time:13850ms step_avg:89.35ms +step:156/1680 train_time:13938ms step_avg:89.35ms +step:157/1680 train_time:14027ms step_avg:89.34ms +step:158/1680 train_time:14115ms step_avg:89.34ms +step:159/1680 train_time:14204ms step_avg:89.33ms +step:160/1680 train_time:14293ms step_avg:89.33ms +step:161/1680 train_time:14384ms step_avg:89.34ms +step:162/1680 train_time:14473ms step_avg:89.34ms +step:163/1680 train_time:14563ms step_avg:89.35ms +step:164/1680 train_time:14653ms step_avg:89.35ms +step:165/1680 train_time:14742ms step_avg:89.35ms +step:166/1680 train_time:14831ms step_avg:89.34ms +step:167/1680 train_time:14920ms step_avg:89.34ms +step:168/1680 train_time:15009ms step_avg:89.34ms +step:169/1680 train_time:15099ms step_avg:89.34ms +step:170/1680 train_time:15187ms step_avg:89.34ms +step:171/1680 train_time:15276ms step_avg:89.33ms +step:172/1680 train_time:15366ms step_avg:89.34ms +step:173/1680 train_time:15455ms step_avg:89.34ms +step:174/1680 train_time:15545ms step_avg:89.34ms +step:175/1680 train_time:15634ms step_avg:89.34ms +step:176/1680 train_time:15723ms step_avg:89.33ms +step:177/1680 train_time:15811ms step_avg:89.33ms +step:178/1680 train_time:15900ms step_avg:89.33ms +step:179/1680 train_time:15990ms step_avg:89.33ms +step:180/1680 train_time:16079ms step_avg:89.33ms +step:181/1680 train_time:16169ms step_avg:89.33ms +step:182/1680 train_time:16258ms step_avg:89.33ms +step:183/1680 train_time:16347ms step_avg:89.33ms +step:184/1680 train_time:16437ms step_avg:89.33ms +step:185/1680 train_time:16526ms step_avg:89.33ms +step:186/1680 train_time:16615ms step_avg:89.33ms +step:187/1680 train_time:16704ms step_avg:89.33ms +step:188/1680 train_time:16793ms step_avg:89.32ms +step:189/1680 train_time:16882ms step_avg:89.32ms +step:190/1680 train_time:16970ms step_avg:89.32ms +step:191/1680 train_time:17060ms step_avg:89.32ms +step:192/1680 train_time:17149ms step_avg:89.32ms +step:193/1680 train_time:17239ms step_avg:89.32ms +step:194/1680 train_time:17328ms step_avg:89.32ms +step:195/1680 train_time:17417ms step_avg:89.32ms +step:196/1680 train_time:17507ms step_avg:89.32ms +step:197/1680 train_time:17598ms step_avg:89.33ms +step:198/1680 train_time:17688ms step_avg:89.33ms +step:199/1680 train_time:17777ms step_avg:89.33ms +step:200/1680 train_time:17866ms step_avg:89.33ms +step:201/1680 train_time:17955ms step_avg:89.33ms +step:202/1680 train_time:18044ms step_avg:89.32ms +step:203/1680 train_time:18133ms step_avg:89.32ms +step:204/1680 train_time:18222ms step_avg:89.32ms +step:205/1680 train_time:18311ms step_avg:89.32ms +step:206/1680 train_time:18401ms step_avg:89.32ms +step:207/1680 train_time:18490ms step_avg:89.33ms +step:208/1680 train_time:18580ms step_avg:89.33ms +step:209/1680 train_time:18670ms step_avg:89.33ms +step:210/1680 train_time:18759ms step_avg:89.33ms +step:211/1680 train_time:18849ms step_avg:89.33ms +step:212/1680 train_time:18939ms step_avg:89.33ms +step:213/1680 train_time:19028ms step_avg:89.33ms +step:214/1680 train_time:19117ms step_avg:89.33ms +step:215/1680 train_time:19206ms step_avg:89.33ms +step:216/1680 train_time:19295ms step_avg:89.33ms +step:217/1680 train_time:19385ms step_avg:89.33ms +step:218/1680 train_time:19474ms step_avg:89.33ms +step:219/1680 train_time:19563ms step_avg:89.33ms +step:220/1680 
train_time:19653ms step_avg:89.33ms +step:221/1680 train_time:19742ms step_avg:89.33ms +step:222/1680 train_time:19830ms step_avg:89.33ms +step:223/1680 train_time:19920ms step_avg:89.33ms +step:224/1680 train_time:20009ms step_avg:89.33ms +step:225/1680 train_time:20099ms step_avg:89.33ms +step:226/1680 train_time:20188ms step_avg:89.33ms +step:227/1680 train_time:20277ms step_avg:89.32ms +step:228/1680 train_time:20366ms step_avg:89.33ms +step:229/1680 train_time:20455ms step_avg:89.33ms +step:230/1680 train_time:20545ms step_avg:89.33ms +step:231/1680 train_time:20634ms step_avg:89.32ms +step:232/1680 train_time:20723ms step_avg:89.32ms +step:233/1680 train_time:20811ms step_avg:89.32ms +step:234/1680 train_time:20901ms step_avg:89.32ms +step:235/1680 train_time:20990ms step_avg:89.32ms +step:236/1680 train_time:21078ms step_avg:89.32ms +step:237/1680 train_time:21168ms step_avg:89.31ms +step:238/1680 train_time:21256ms step_avg:89.31ms +step:239/1680 train_time:21346ms step_avg:89.31ms +step:240/1680 train_time:21435ms step_avg:89.31ms +step:241/1680 train_time:21524ms step_avg:89.31ms +step:242/1680 train_time:21614ms step_avg:89.32ms +step:243/1680 train_time:21703ms step_avg:89.31ms +step:244/1680 train_time:21792ms step_avg:89.31ms +step:245/1680 train_time:21881ms step_avg:89.31ms +step:246/1680 train_time:21971ms step_avg:89.31ms +step:247/1680 train_time:22060ms step_avg:89.31ms +step:248/1680 train_time:22150ms step_avg:89.32ms +step:249/1680 train_time:22239ms step_avg:89.31ms +step:250/1680 train_time:22329ms step_avg:89.31ms +step:250/1680 val_loss:3.9718 train_time:22419ms step_avg:89.68ms +step:251/1680 train_time:22444ms step_avg:89.42ms +step:252/1680 train_time:22512ms step_avg:89.33ms +step:253/1680 train_time:22607ms step_avg:89.36ms +step:254/1680 train_time:22697ms step_avg:89.36ms +step:255/1680 train_time:22786ms step_avg:89.36ms +step:256/1680 train_time:22874ms step_avg:89.35ms +step:257/1680 train_time:22962ms step_avg:89.35ms +step:258/1680 train_time:23050ms step_avg:89.34ms +step:259/1680 train_time:23139ms step_avg:89.34ms +step:260/1680 train_time:23227ms step_avg:89.33ms +step:261/1680 train_time:23315ms step_avg:89.33ms +step:262/1680 train_time:23404ms step_avg:89.33ms +step:263/1680 train_time:23497ms step_avg:89.34ms +step:264/1680 train_time:23589ms step_avg:89.35ms +step:265/1680 train_time:23679ms step_avg:89.36ms +step:266/1680 train_time:23769ms step_avg:89.36ms +step:267/1680 train_time:23858ms step_avg:89.36ms +step:268/1680 train_time:23947ms step_avg:89.35ms +step:269/1680 train_time:24035ms step_avg:89.35ms +step:270/1680 train_time:24123ms step_avg:89.35ms +step:271/1680 train_time:24212ms step_avg:89.34ms +step:272/1680 train_time:24300ms step_avg:89.34ms +step:273/1680 train_time:24389ms step_avg:89.34ms +step:274/1680 train_time:24479ms step_avg:89.34ms +step:275/1680 train_time:24570ms step_avg:89.35ms +step:276/1680 train_time:24660ms step_avg:89.35ms +step:277/1680 train_time:24749ms step_avg:89.35ms +step:278/1680 train_time:24839ms step_avg:89.35ms +step:279/1680 train_time:24928ms step_avg:89.35ms +step:280/1680 train_time:25017ms step_avg:89.35ms +step:281/1680 train_time:25106ms step_avg:89.35ms +step:282/1680 train_time:25196ms step_avg:89.35ms +step:283/1680 train_time:25285ms step_avg:89.35ms +step:284/1680 train_time:25374ms step_avg:89.34ms +step:285/1680 train_time:25463ms step_avg:89.34ms +step:286/1680 train_time:25553ms step_avg:89.35ms +step:287/1680 train_time:25643ms step_avg:89.35ms +step:288/1680 train_time:25732ms 
step_avg:89.35ms +step:289/1680 train_time:25822ms step_avg:89.35ms +step:290/1680 train_time:25912ms step_avg:89.35ms +step:291/1680 train_time:26001ms step_avg:89.35ms +step:292/1680 train_time:26090ms step_avg:89.35ms +step:293/1680 train_time:26183ms step_avg:89.36ms +step:294/1680 train_time:26267ms step_avg:89.34ms +step:295/1680 train_time:26356ms step_avg:89.34ms +step:296/1680 train_time:26446ms step_avg:89.34ms +step:297/1680 train_time:26536ms step_avg:89.35ms +step:298/1680 train_time:26625ms step_avg:89.35ms +step:299/1680 train_time:26715ms step_avg:89.35ms +step:300/1680 train_time:26804ms step_avg:89.35ms +step:301/1680 train_time:26894ms step_avg:89.35ms +step:302/1680 train_time:26983ms step_avg:89.35ms +step:303/1680 train_time:27072ms step_avg:89.35ms +step:304/1680 train_time:27162ms step_avg:89.35ms +step:305/1680 train_time:27250ms step_avg:89.34ms +step:306/1680 train_time:27339ms step_avg:89.34ms +step:307/1680 train_time:27428ms step_avg:89.34ms +step:308/1680 train_time:27517ms step_avg:89.34ms +step:309/1680 train_time:27607ms step_avg:89.34ms +step:310/1680 train_time:27699ms step_avg:89.35ms +step:311/1680 train_time:27788ms step_avg:89.35ms +step:312/1680 train_time:27878ms step_avg:89.35ms +step:313/1680 train_time:27968ms step_avg:89.35ms +step:314/1680 train_time:28057ms step_avg:89.35ms +step:315/1680 train_time:28146ms step_avg:89.35ms +step:316/1680 train_time:28235ms step_avg:89.35ms +step:317/1680 train_time:28325ms step_avg:89.35ms +step:318/1680 train_time:28414ms step_avg:89.35ms +step:319/1680 train_time:28503ms step_avg:89.35ms +step:320/1680 train_time:28592ms step_avg:89.35ms +step:321/1680 train_time:28681ms step_avg:89.35ms +step:322/1680 train_time:28771ms step_avg:89.35ms +step:323/1680 train_time:28861ms step_avg:89.35ms +step:324/1680 train_time:28950ms step_avg:89.35ms +step:325/1680 train_time:29039ms step_avg:89.35ms +step:326/1680 train_time:29129ms step_avg:89.35ms +step:327/1680 train_time:29218ms step_avg:89.35ms +step:328/1680 train_time:29307ms step_avg:89.35ms +step:329/1680 train_time:29397ms step_avg:89.35ms +step:330/1680 train_time:29487ms step_avg:89.35ms +step:331/1680 train_time:29577ms step_avg:89.36ms +step:332/1680 train_time:29666ms step_avg:89.36ms +step:333/1680 train_time:29756ms step_avg:89.36ms +step:334/1680 train_time:29845ms step_avg:89.36ms +step:335/1680 train_time:29934ms step_avg:89.36ms +step:336/1680 train_time:30024ms step_avg:89.36ms +step:337/1680 train_time:30114ms step_avg:89.36ms +step:338/1680 train_time:30202ms step_avg:89.36ms +step:339/1680 train_time:30292ms step_avg:89.36ms +step:340/1680 train_time:30385ms step_avg:89.37ms +step:341/1680 train_time:30470ms step_avg:89.35ms +step:342/1680 train_time:30560ms step_avg:89.36ms +step:343/1680 train_time:30648ms step_avg:89.35ms +step:344/1680 train_time:30738ms step_avg:89.36ms +step:345/1680 train_time:30828ms step_avg:89.36ms +step:346/1680 train_time:30916ms step_avg:89.35ms +step:347/1680 train_time:31005ms step_avg:89.35ms +step:348/1680 train_time:31095ms step_avg:89.35ms +step:349/1680 train_time:31184ms step_avg:89.35ms +step:350/1680 train_time:31275ms step_avg:89.36ms +step:351/1680 train_time:31364ms step_avg:89.36ms +step:352/1680 train_time:31454ms step_avg:89.36ms +step:353/1680 train_time:31544ms step_avg:89.36ms +step:354/1680 train_time:31634ms step_avg:89.36ms +step:355/1680 train_time:31724ms step_avg:89.36ms +step:356/1680 train_time:31812ms step_avg:89.36ms +step:357/1680 train_time:31901ms step_avg:89.36ms +step:358/1680 
train_time:31989ms step_avg:89.36ms +step:359/1680 train_time:32079ms step_avg:89.36ms +step:360/1680 train_time:32168ms step_avg:89.36ms +step:361/1680 train_time:32258ms step_avg:89.36ms +step:362/1680 train_time:32348ms step_avg:89.36ms +step:363/1680 train_time:32438ms step_avg:89.36ms +step:364/1680 train_time:32527ms step_avg:89.36ms +step:365/1680 train_time:32616ms step_avg:89.36ms +step:366/1680 train_time:32706ms step_avg:89.36ms +step:367/1680 train_time:32796ms step_avg:89.36ms +step:368/1680 train_time:32885ms step_avg:89.36ms +step:369/1680 train_time:32974ms step_avg:89.36ms +step:370/1680 train_time:33064ms step_avg:89.36ms +step:371/1680 train_time:33154ms step_avg:89.36ms +step:372/1680 train_time:33243ms step_avg:89.36ms +step:373/1680 train_time:33333ms step_avg:89.36ms +step:374/1680 train_time:33422ms step_avg:89.36ms +step:375/1680 train_time:33512ms step_avg:89.37ms +step:375/1680 val_loss:3.8163 train_time:33602ms step_avg:89.61ms +step:376/1680 train_time:33627ms step_avg:89.43ms +step:377/1680 train_time:33699ms step_avg:89.39ms +step:378/1680 train_time:33795ms step_avg:89.40ms +step:379/1680 train_time:33885ms step_avg:89.41ms +step:380/1680 train_time:33974ms step_avg:89.41ms +step:381/1680 train_time:34063ms step_avg:89.41ms +step:382/1680 train_time:34154ms step_avg:89.41ms +step:383/1680 train_time:34241ms step_avg:89.40ms +step:384/1680 train_time:34329ms step_avg:89.40ms +step:385/1680 train_time:34417ms step_avg:89.40ms +step:386/1680 train_time:34505ms step_avg:89.39ms +step:387/1680 train_time:34594ms step_avg:89.39ms +step:388/1680 train_time:34685ms step_avg:89.40ms +step:389/1680 train_time:34777ms step_avg:89.40ms +step:390/1680 train_time:34867ms step_avg:89.40ms +step:391/1680 train_time:34957ms step_avg:89.40ms +step:392/1680 train_time:35046ms step_avg:89.40ms +step:393/1680 train_time:35136ms step_avg:89.40ms +step:394/1680 train_time:35225ms step_avg:89.40ms +step:395/1680 train_time:35313ms step_avg:89.40ms +step:396/1680 train_time:35402ms step_avg:89.40ms +step:397/1680 train_time:35491ms step_avg:89.40ms +step:398/1680 train_time:35580ms step_avg:89.40ms +step:399/1680 train_time:35671ms step_avg:89.40ms +step:400/1680 train_time:35761ms step_avg:89.40ms +step:401/1680 train_time:35851ms step_avg:89.40ms +step:402/1680 train_time:35941ms step_avg:89.41ms +step:403/1680 train_time:36031ms step_avg:89.41ms +step:404/1680 train_time:36120ms step_avg:89.41ms +step:405/1680 train_time:36209ms step_avg:89.40ms +step:406/1680 train_time:36297ms step_avg:89.40ms +step:407/1680 train_time:36386ms step_avg:89.40ms +step:408/1680 train_time:36475ms step_avg:89.40ms +step:409/1680 train_time:36564ms step_avg:89.40ms +step:410/1680 train_time:36654ms step_avg:89.40ms +step:411/1680 train_time:36744ms step_avg:89.40ms +step:412/1680 train_time:36835ms step_avg:89.40ms +step:413/1680 train_time:36925ms step_avg:89.41ms +step:414/1680 train_time:37015ms step_avg:89.41ms +step:415/1680 train_time:37104ms step_avg:89.41ms +step:416/1680 train_time:37194ms step_avg:89.41ms +step:417/1680 train_time:37283ms step_avg:89.41ms +step:418/1680 train_time:37372ms step_avg:89.41ms +step:419/1680 train_time:37460ms step_avg:89.40ms +step:420/1680 train_time:37551ms step_avg:89.41ms +step:421/1680 train_time:37640ms step_avg:89.41ms +step:422/1680 train_time:37729ms step_avg:89.41ms +step:423/1680 train_time:37820ms step_avg:89.41ms +step:424/1680 train_time:37909ms step_avg:89.41ms +step:425/1680 train_time:37999ms step_avg:89.41ms +step:426/1680 train_time:38087ms 
step_avg:89.41ms +step:427/1680 train_time:38176ms step_avg:89.41ms +step:428/1680 train_time:38265ms step_avg:89.40ms +step:429/1680 train_time:38355ms step_avg:89.40ms +step:430/1680 train_time:38443ms step_avg:89.40ms +step:431/1680 train_time:38533ms step_avg:89.40ms +step:432/1680 train_time:38623ms step_avg:89.40ms +step:433/1680 train_time:38713ms step_avg:89.41ms +step:434/1680 train_time:38802ms step_avg:89.41ms +step:435/1680 train_time:38892ms step_avg:89.41ms +step:436/1680 train_time:38982ms step_avg:89.41ms +step:437/1680 train_time:39072ms step_avg:89.41ms +step:438/1680 train_time:39162ms step_avg:89.41ms +step:439/1680 train_time:39251ms step_avg:89.41ms +step:440/1680 train_time:39340ms step_avg:89.41ms +step:441/1680 train_time:39429ms step_avg:89.41ms +step:442/1680 train_time:39518ms step_avg:89.41ms +step:443/1680 train_time:39607ms step_avg:89.41ms +step:444/1680 train_time:39697ms step_avg:89.41ms +step:445/1680 train_time:39787ms step_avg:89.41ms +step:446/1680 train_time:39876ms step_avg:89.41ms +step:447/1680 train_time:39966ms step_avg:89.41ms +step:448/1680 train_time:40055ms step_avg:89.41ms +step:449/1680 train_time:40145ms step_avg:89.41ms +step:450/1680 train_time:40234ms step_avg:89.41ms +step:451/1680 train_time:40324ms step_avg:89.41ms +step:452/1680 train_time:40413ms step_avg:89.41ms +step:453/1680 train_time:40503ms step_avg:89.41ms +step:454/1680 train_time:40593ms step_avg:89.41ms +step:455/1680 train_time:40682ms step_avg:89.41ms +step:456/1680 train_time:40771ms step_avg:89.41ms +step:457/1680 train_time:40861ms step_avg:89.41ms +step:458/1680 train_time:40951ms step_avg:89.41ms +step:459/1680 train_time:41043ms step_avg:89.42ms +step:460/1680 train_time:41129ms step_avg:89.41ms +step:461/1680 train_time:41218ms step_avg:89.41ms +step:462/1680 train_time:41308ms step_avg:89.41ms +step:463/1680 train_time:41397ms step_avg:89.41ms +step:464/1680 train_time:41486ms step_avg:89.41ms +step:465/1680 train_time:41576ms step_avg:89.41ms +step:466/1680 train_time:41665ms step_avg:89.41ms +step:467/1680 train_time:41755ms step_avg:89.41ms +step:468/1680 train_time:41845ms step_avg:89.41ms +step:469/1680 train_time:41934ms step_avg:89.41ms +step:470/1680 train_time:42025ms step_avg:89.41ms +step:471/1680 train_time:42114ms step_avg:89.41ms +step:472/1680 train_time:42203ms step_avg:89.41ms +step:473/1680 train_time:42294ms step_avg:89.42ms +step:474/1680 train_time:42384ms step_avg:89.42ms +step:475/1680 train_time:42474ms step_avg:89.42ms +step:476/1680 train_time:42563ms step_avg:89.42ms +step:477/1680 train_time:42654ms step_avg:89.42ms +step:478/1680 train_time:42741ms step_avg:89.42ms +step:479/1680 train_time:42830ms step_avg:89.42ms +step:480/1680 train_time:42920ms step_avg:89.42ms +step:481/1680 train_time:43010ms step_avg:89.42ms +step:482/1680 train_time:43100ms step_avg:89.42ms +step:483/1680 train_time:43189ms step_avg:89.42ms +step:484/1680 train_time:43278ms step_avg:89.42ms +step:485/1680 train_time:43368ms step_avg:89.42ms +step:486/1680 train_time:43458ms step_avg:89.42ms +step:487/1680 train_time:43548ms step_avg:89.42ms +step:488/1680 train_time:43637ms step_avg:89.42ms +step:489/1680 train_time:43727ms step_avg:89.42ms +step:490/1680 train_time:43816ms step_avg:89.42ms +step:491/1680 train_time:43905ms step_avg:89.42ms +step:492/1680 train_time:43995ms step_avg:89.42ms +step:493/1680 train_time:44084ms step_avg:89.42ms +step:494/1680 train_time:44173ms step_avg:89.42ms +step:495/1680 train_time:44262ms step_avg:89.42ms +step:496/1680 
train_time:44352ms step_avg:89.42ms +step:497/1680 train_time:44441ms step_avg:89.42ms +step:498/1680 train_time:44531ms step_avg:89.42ms +step:499/1680 train_time:44620ms step_avg:89.42ms +step:500/1680 train_time:44710ms step_avg:89.42ms +step:500/1680 val_loss:3.7177 train_time:44800ms step_avg:89.60ms +step:501/1680 train_time:44826ms step_avg:89.47ms +step:502/1680 train_time:44895ms step_avg:89.43ms +step:503/1680 train_time:44988ms step_avg:89.44ms +step:504/1680 train_time:45079ms step_avg:89.44ms +step:505/1680 train_time:45168ms step_avg:89.44ms +step:506/1680 train_time:45256ms step_avg:89.44ms +step:507/1680 train_time:45345ms step_avg:89.44ms +step:508/1680 train_time:45434ms step_avg:89.44ms +step:509/1680 train_time:45523ms step_avg:89.44ms +step:510/1680 train_time:45611ms step_avg:89.43ms +step:511/1680 train_time:45700ms step_avg:89.43ms +step:512/1680 train_time:45790ms step_avg:89.43ms +step:513/1680 train_time:45881ms step_avg:89.44ms +step:514/1680 train_time:45972ms step_avg:89.44ms +step:515/1680 train_time:46063ms step_avg:89.44ms +step:516/1680 train_time:46153ms step_avg:89.44ms +step:517/1680 train_time:46242ms step_avg:89.44ms +step:518/1680 train_time:46331ms step_avg:89.44ms +step:519/1680 train_time:46420ms step_avg:89.44ms +step:520/1680 train_time:46508ms step_avg:89.44ms +step:521/1680 train_time:46596ms step_avg:89.44ms +step:522/1680 train_time:46685ms step_avg:89.44ms +step:523/1680 train_time:46774ms step_avg:89.43ms +step:524/1680 train_time:46863ms step_avg:89.43ms +step:525/1680 train_time:46955ms step_avg:89.44ms +step:526/1680 train_time:47046ms step_avg:89.44ms +step:527/1680 train_time:47136ms step_avg:89.44ms +step:528/1680 train_time:47226ms step_avg:89.44ms +step:529/1680 train_time:47316ms step_avg:89.44ms +step:530/1680 train_time:47405ms step_avg:89.44ms +step:531/1680 train_time:47494ms step_avg:89.44ms +step:532/1680 train_time:47583ms step_avg:89.44ms +step:533/1680 train_time:47672ms step_avg:89.44ms +step:534/1680 train_time:47761ms step_avg:89.44ms +step:535/1680 train_time:47850ms step_avg:89.44ms +step:536/1680 train_time:47941ms step_avg:89.44ms +step:537/1680 train_time:48031ms step_avg:89.44ms +step:538/1680 train_time:48121ms step_avg:89.44ms +step:539/1680 train_time:48211ms step_avg:89.44ms +step:540/1680 train_time:48301ms step_avg:89.45ms +step:541/1680 train_time:48389ms step_avg:89.44ms +step:542/1680 train_time:48478ms step_avg:89.44ms +step:543/1680 train_time:48567ms step_avg:89.44ms +step:544/1680 train_time:48657ms step_avg:89.44ms +step:545/1680 train_time:48746ms step_avg:89.44ms +step:546/1680 train_time:48836ms step_avg:89.44ms +step:547/1680 train_time:48926ms step_avg:89.44ms +step:548/1680 train_time:49017ms step_avg:89.45ms +step:549/1680 train_time:49108ms step_avg:89.45ms +step:550/1680 train_time:49200ms step_avg:89.45ms +step:551/1680 train_time:49290ms step_avg:89.46ms +step:552/1680 train_time:49380ms step_avg:89.46ms +step:553/1680 train_time:49470ms step_avg:89.46ms +step:554/1680 train_time:49560ms step_avg:89.46ms +step:555/1680 train_time:49650ms step_avg:89.46ms +step:556/1680 train_time:49742ms step_avg:89.46ms +step:557/1680 train_time:49832ms step_avg:89.47ms +step:558/1680 train_time:49922ms step_avg:89.47ms +step:559/1680 train_time:50013ms step_avg:89.47ms +step:560/1680 train_time:50106ms step_avg:89.48ms +step:561/1680 train_time:50196ms step_avg:89.48ms +step:562/1680 train_time:50287ms step_avg:89.48ms +step:563/1680 train_time:50379ms step_avg:89.48ms +step:564/1680 train_time:50469ms 
step_avg:89.48ms +step:565/1680 train_time:50559ms step_avg:89.49ms +step:566/1680 train_time:50649ms step_avg:89.49ms +step:567/1680 train_time:50740ms step_avg:89.49ms +step:568/1680 train_time:50830ms step_avg:89.49ms +step:569/1680 train_time:50921ms step_avg:89.49ms +step:570/1680 train_time:51012ms step_avg:89.49ms +step:571/1680 train_time:51103ms step_avg:89.50ms +step:572/1680 train_time:51194ms step_avg:89.50ms +step:573/1680 train_time:51286ms step_avg:89.50ms +step:574/1680 train_time:51377ms step_avg:89.51ms +step:575/1680 train_time:51467ms step_avg:89.51ms +step:576/1680 train_time:51558ms step_avg:89.51ms +step:577/1680 train_time:51649ms step_avg:89.51ms +step:578/1680 train_time:51739ms step_avg:89.51ms +step:579/1680 train_time:51829ms step_avg:89.52ms +step:580/1680 train_time:51921ms step_avg:89.52ms +step:581/1680 train_time:52011ms step_avg:89.52ms +step:582/1680 train_time:52102ms step_avg:89.52ms +step:583/1680 train_time:52193ms step_avg:89.52ms +step:584/1680 train_time:52283ms step_avg:89.53ms +step:585/1680 train_time:52374ms step_avg:89.53ms +step:586/1680 train_time:52465ms step_avg:89.53ms +step:587/1680 train_time:52556ms step_avg:89.53ms +step:588/1680 train_time:52648ms step_avg:89.54ms +step:589/1680 train_time:52739ms step_avg:89.54ms +step:590/1680 train_time:52830ms step_avg:89.54ms +step:591/1680 train_time:52922ms step_avg:89.55ms +step:592/1680 train_time:53012ms step_avg:89.55ms +step:593/1680 train_time:53103ms step_avg:89.55ms +step:594/1680 train_time:53193ms step_avg:89.55ms +step:595/1680 train_time:53284ms step_avg:89.55ms +step:596/1680 train_time:53375ms step_avg:89.56ms +step:597/1680 train_time:53466ms step_avg:89.56ms +step:598/1680 train_time:53556ms step_avg:89.56ms +step:599/1680 train_time:53647ms step_avg:89.56ms +step:600/1680 train_time:53738ms step_avg:89.56ms +step:601/1680 train_time:53828ms step_avg:89.56ms +step:602/1680 train_time:53920ms step_avg:89.57ms +step:603/1680 train_time:54011ms step_avg:89.57ms +step:604/1680 train_time:54101ms step_avg:89.57ms +step:605/1680 train_time:54191ms step_avg:89.57ms +step:606/1680 train_time:54283ms step_avg:89.58ms +step:607/1680 train_time:54374ms step_avg:89.58ms +step:608/1680 train_time:54464ms step_avg:89.58ms +step:609/1680 train_time:54555ms step_avg:89.58ms +step:610/1680 train_time:54646ms step_avg:89.58ms +step:611/1680 train_time:54737ms step_avg:89.59ms +step:612/1680 train_time:54828ms step_avg:89.59ms +step:613/1680 train_time:54918ms step_avg:89.59ms +step:614/1680 train_time:55009ms step_avg:89.59ms +step:615/1680 train_time:55100ms step_avg:89.59ms +step:616/1680 train_time:55190ms step_avg:89.59ms +step:617/1680 train_time:55282ms step_avg:89.60ms +step:618/1680 train_time:55373ms step_avg:89.60ms +step:619/1680 train_time:55464ms step_avg:89.60ms +step:620/1680 train_time:55554ms step_avg:89.60ms +step:621/1680 train_time:55645ms step_avg:89.60ms +step:622/1680 train_time:55735ms step_avg:89.61ms +step:623/1680 train_time:55826ms step_avg:89.61ms +step:624/1680 train_time:55918ms step_avg:89.61ms +step:625/1680 train_time:56008ms step_avg:89.61ms +step:625/1680 val_loss:3.6172 train_time:56101ms step_avg:89.76ms +step:626/1680 train_time:56126ms step_avg:89.66ms +step:627/1680 train_time:56194ms step_avg:89.62ms +step:628/1680 train_time:56294ms step_avg:89.64ms +step:629/1680 train_time:56385ms step_avg:89.64ms +step:630/1680 train_time:56474ms step_avg:89.64ms +step:631/1680 train_time:56562ms step_avg:89.64ms +step:632/1680 train_time:56650ms step_avg:89.64ms 
+step:633/1680 train_time:56738ms step_avg:89.63ms +step:634/1680 train_time:56828ms step_avg:89.63ms +step:635/1680 train_time:56916ms step_avg:89.63ms +step:636/1680 train_time:57007ms step_avg:89.63ms +step:637/1680 train_time:57098ms step_avg:89.64ms +step:638/1680 train_time:57190ms step_avg:89.64ms +step:639/1680 train_time:57282ms step_avg:89.64ms +step:640/1680 train_time:57373ms step_avg:89.65ms +step:641/1680 train_time:57462ms step_avg:89.64ms +step:642/1680 train_time:57552ms step_avg:89.64ms +step:643/1680 train_time:57642ms step_avg:89.65ms +step:644/1680 train_time:57731ms step_avg:89.64ms +step:645/1680 train_time:57820ms step_avg:89.64ms +step:646/1680 train_time:57909ms step_avg:89.64ms +step:647/1680 train_time:57999ms step_avg:89.64ms +step:648/1680 train_time:58089ms step_avg:89.64ms +step:649/1680 train_time:58181ms step_avg:89.65ms +step:650/1680 train_time:58272ms step_avg:89.65ms +step:651/1680 train_time:58362ms step_avg:89.65ms +step:652/1680 train_time:58452ms step_avg:89.65ms +step:653/1680 train_time:58542ms step_avg:89.65ms +step:654/1680 train_time:58632ms step_avg:89.65ms +step:655/1680 train_time:58722ms step_avg:89.65ms +step:656/1680 train_time:58811ms step_avg:89.65ms +step:657/1680 train_time:58900ms step_avg:89.65ms +step:658/1680 train_time:58990ms step_avg:89.65ms +step:659/1680 train_time:59081ms step_avg:89.65ms +step:660/1680 train_time:59171ms step_avg:89.65ms +step:661/1680 train_time:59261ms step_avg:89.65ms +step:662/1680 train_time:59352ms step_avg:89.66ms +step:663/1680 train_time:59443ms step_avg:89.66ms +step:664/1680 train_time:59532ms step_avg:89.66ms +step:665/1680 train_time:59621ms step_avg:89.66ms +step:666/1680 train_time:59711ms step_avg:89.66ms +step:667/1680 train_time:59801ms step_avg:89.66ms +step:668/1680 train_time:59891ms step_avg:89.66ms +step:669/1680 train_time:59980ms step_avg:89.66ms +step:670/1680 train_time:60070ms step_avg:89.66ms +step:671/1680 train_time:60160ms step_avg:89.66ms +step:672/1680 train_time:60250ms step_avg:89.66ms +step:673/1680 train_time:60341ms step_avg:89.66ms +step:674/1680 train_time:60431ms step_avg:89.66ms +step:675/1680 train_time:60521ms step_avg:89.66ms +step:676/1680 train_time:60611ms step_avg:89.66ms +step:677/1680 train_time:60701ms step_avg:89.66ms +step:678/1680 train_time:60791ms step_avg:89.66ms +step:679/1680 train_time:60880ms step_avg:89.66ms +step:680/1680 train_time:60970ms step_avg:89.66ms +step:681/1680 train_time:61059ms step_avg:89.66ms +step:682/1680 train_time:61150ms step_avg:89.66ms +step:683/1680 train_time:61240ms step_avg:89.66ms +step:684/1680 train_time:61330ms step_avg:89.66ms +step:685/1680 train_time:61420ms step_avg:89.66ms +step:686/1680 train_time:61509ms step_avg:89.66ms +step:687/1680 train_time:61599ms step_avg:89.66ms +step:688/1680 train_time:61689ms step_avg:89.66ms +step:689/1680 train_time:61779ms step_avg:89.67ms +step:690/1680 train_time:61869ms step_avg:89.66ms +step:691/1680 train_time:61958ms step_avg:89.66ms +step:692/1680 train_time:62048ms step_avg:89.67ms +step:693/1680 train_time:62139ms step_avg:89.67ms +step:694/1680 train_time:62230ms step_avg:89.67ms +step:695/1680 train_time:62321ms step_avg:89.67ms +step:696/1680 train_time:62411ms step_avg:89.67ms +step:697/1680 train_time:62501ms step_avg:89.67ms +step:698/1680 train_time:62592ms step_avg:89.67ms +step:699/1680 train_time:62682ms step_avg:89.67ms +step:700/1680 train_time:62772ms step_avg:89.67ms +step:701/1680 train_time:62862ms step_avg:89.67ms +step:702/1680 train_time:62951ms 
step_avg:89.67ms +step:703/1680 train_time:63042ms step_avg:89.68ms +step:704/1680 train_time:63132ms step_avg:89.68ms +step:705/1680 train_time:63222ms step_avg:89.68ms +step:706/1680 train_time:63312ms step_avg:89.68ms +step:707/1680 train_time:63402ms step_avg:89.68ms +step:708/1680 train_time:63492ms step_avg:89.68ms +step:709/1680 train_time:63584ms step_avg:89.68ms +step:710/1680 train_time:63674ms step_avg:89.68ms +step:711/1680 train_time:63763ms step_avg:89.68ms +step:712/1680 train_time:63854ms step_avg:89.68ms +step:713/1680 train_time:63944ms step_avg:89.68ms +step:714/1680 train_time:64034ms step_avg:89.68ms +step:715/1680 train_time:64123ms step_avg:89.68ms +step:716/1680 train_time:64213ms step_avg:89.68ms +step:717/1680 train_time:64305ms step_avg:89.69ms +step:718/1680 train_time:64395ms step_avg:89.69ms +step:719/1680 train_time:64485ms step_avg:89.69ms +step:720/1680 train_time:64575ms step_avg:89.69ms +step:721/1680 train_time:64666ms step_avg:89.69ms +step:722/1680 train_time:64757ms step_avg:89.69ms +step:723/1680 train_time:64847ms step_avg:89.69ms +step:724/1680 train_time:64937ms step_avg:89.69ms +step:725/1680 train_time:65027ms step_avg:89.69ms +step:726/1680 train_time:65117ms step_avg:89.69ms +step:727/1680 train_time:65208ms step_avg:89.69ms +step:728/1680 train_time:65299ms step_avg:89.70ms +step:729/1680 train_time:65389ms step_avg:89.70ms +step:730/1680 train_time:65480ms step_avg:89.70ms +step:731/1680 train_time:65570ms step_avg:89.70ms +step:732/1680 train_time:65660ms step_avg:89.70ms +step:733/1680 train_time:65751ms step_avg:89.70ms +step:734/1680 train_time:65841ms step_avg:89.70ms +step:735/1680 train_time:65931ms step_avg:89.70ms +step:736/1680 train_time:66023ms step_avg:89.71ms +step:737/1680 train_time:66111ms step_avg:89.70ms +step:738/1680 train_time:66201ms step_avg:89.70ms +step:739/1680 train_time:66290ms step_avg:89.70ms +step:740/1680 train_time:66381ms step_avg:89.70ms +step:741/1680 train_time:66471ms step_avg:89.70ms +step:742/1680 train_time:66561ms step_avg:89.70ms +step:743/1680 train_time:66652ms step_avg:89.71ms +step:744/1680 train_time:66742ms step_avg:89.71ms +step:745/1680 train_time:66832ms step_avg:89.71ms +step:746/1680 train_time:66922ms step_avg:89.71ms +step:747/1680 train_time:67012ms step_avg:89.71ms +step:748/1680 train_time:67101ms step_avg:89.71ms +step:749/1680 train_time:67191ms step_avg:89.71ms +step:750/1680 train_time:67281ms step_avg:89.71ms +step:750/1680 val_loss:3.5669 train_time:67372ms step_avg:89.83ms +step:751/1680 train_time:67398ms step_avg:89.74ms +step:752/1680 train_time:67466ms step_avg:89.72ms +step:753/1680 train_time:67561ms step_avg:89.72ms +step:754/1680 train_time:67652ms step_avg:89.72ms +step:755/1680 train_time:67742ms step_avg:89.73ms +step:756/1680 train_time:67832ms step_avg:89.72ms +step:757/1680 train_time:67921ms step_avg:89.72ms +step:758/1680 train_time:68010ms step_avg:89.72ms +step:759/1680 train_time:68099ms step_avg:89.72ms +step:760/1680 train_time:68188ms step_avg:89.72ms +step:761/1680 train_time:68278ms step_avg:89.72ms +step:762/1680 train_time:68370ms step_avg:89.72ms +step:763/1680 train_time:68463ms step_avg:89.73ms +step:764/1680 train_time:68556ms step_avg:89.73ms +step:765/1680 train_time:68648ms step_avg:89.74ms +step:766/1680 train_time:68738ms step_avg:89.74ms +step:767/1680 train_time:68828ms step_avg:89.74ms +step:768/1680 train_time:68917ms step_avg:89.74ms +step:769/1680 train_time:69006ms step_avg:89.74ms +step:770/1680 train_time:69095ms step_avg:89.73ms 
+step:771/1680 train_time:69184ms step_avg:89.73ms +step:772/1680 train_time:69274ms step_avg:89.73ms +step:773/1680 train_time:69364ms step_avg:89.73ms +step:774/1680 train_time:69456ms step_avg:89.74ms +step:775/1680 train_time:69548ms step_avg:89.74ms +step:776/1680 train_time:69640ms step_avg:89.74ms +step:777/1680 train_time:69731ms step_avg:89.74ms +step:778/1680 train_time:69821ms step_avg:89.74ms +step:779/1680 train_time:69911ms step_avg:89.74ms +step:780/1680 train_time:70002ms step_avg:89.75ms +step:781/1680 train_time:70091ms step_avg:89.75ms +step:782/1680 train_time:70181ms step_avg:89.75ms +step:783/1680 train_time:70271ms step_avg:89.75ms +step:784/1680 train_time:70361ms step_avg:89.75ms +step:785/1680 train_time:70452ms step_avg:89.75ms +step:786/1680 train_time:70544ms step_avg:89.75ms +step:787/1680 train_time:70636ms step_avg:89.75ms +step:788/1680 train_time:70726ms step_avg:89.75ms +step:789/1680 train_time:70816ms step_avg:89.75ms +step:790/1680 train_time:70906ms step_avg:89.75ms +step:791/1680 train_time:70995ms step_avg:89.75ms +step:792/1680 train_time:71084ms step_avg:89.75ms +step:793/1680 train_time:71174ms step_avg:89.75ms +step:794/1680 train_time:71264ms step_avg:89.75ms +step:795/1680 train_time:71353ms step_avg:89.75ms +step:796/1680 train_time:71443ms step_avg:89.75ms +step:797/1680 train_time:71534ms step_avg:89.75ms +step:798/1680 train_time:71625ms step_avg:89.76ms +step:799/1680 train_time:71715ms step_avg:89.76ms +step:800/1680 train_time:71805ms step_avg:89.76ms +step:801/1680 train_time:71897ms step_avg:89.76ms +step:802/1680 train_time:71987ms step_avg:89.76ms +step:803/1680 train_time:72077ms step_avg:89.76ms +step:804/1680 train_time:72166ms step_avg:89.76ms +step:805/1680 train_time:72256ms step_avg:89.76ms +step:806/1680 train_time:72347ms step_avg:89.76ms +step:807/1680 train_time:72438ms step_avg:89.76ms +step:808/1680 train_time:72529ms step_avg:89.76ms +step:809/1680 train_time:72620ms step_avg:89.76ms +step:810/1680 train_time:72711ms step_avg:89.77ms +step:811/1680 train_time:72801ms step_avg:89.77ms +step:812/1680 train_time:72892ms step_avg:89.77ms +step:813/1680 train_time:72983ms step_avg:89.77ms +step:814/1680 train_time:73072ms step_avg:89.77ms +step:815/1680 train_time:73162ms step_avg:89.77ms +step:816/1680 train_time:73252ms step_avg:89.77ms +step:817/1680 train_time:73342ms step_avg:89.77ms +step:818/1680 train_time:73433ms step_avg:89.77ms +step:819/1680 train_time:73523ms step_avg:89.77ms +step:820/1680 train_time:73613ms step_avg:89.77ms +step:821/1680 train_time:73703ms step_avg:89.77ms +step:822/1680 train_time:73793ms step_avg:89.77ms +step:823/1680 train_time:73883ms step_avg:89.77ms +step:824/1680 train_time:73974ms step_avg:89.77ms +step:825/1680 train_time:74066ms step_avg:89.78ms +step:826/1680 train_time:74155ms step_avg:89.78ms +step:827/1680 train_time:74245ms step_avg:89.78ms +step:828/1680 train_time:74335ms step_avg:89.78ms +step:829/1680 train_time:74425ms step_avg:89.78ms +step:830/1680 train_time:74515ms step_avg:89.78ms +step:831/1680 train_time:74606ms step_avg:89.78ms +step:832/1680 train_time:74696ms step_avg:89.78ms +step:833/1680 train_time:74787ms step_avg:89.78ms +step:834/1680 train_time:74877ms step_avg:89.78ms +step:835/1680 train_time:74967ms step_avg:89.78ms +step:836/1680 train_time:75059ms step_avg:89.78ms +step:837/1680 train_time:75148ms step_avg:89.78ms +step:838/1680 train_time:75238ms step_avg:89.78ms +step:839/1680 train_time:75328ms step_avg:89.78ms +step:840/1680 train_time:75417ms 
step_avg:89.78ms +step:841/1680 train_time:75509ms step_avg:89.78ms +step:842/1680 train_time:75599ms step_avg:89.78ms +step:843/1680 train_time:75689ms step_avg:89.78ms +step:844/1680 train_time:75779ms step_avg:89.79ms +step:845/1680 train_time:75870ms step_avg:89.79ms +step:846/1680 train_time:75961ms step_avg:89.79ms +step:847/1680 train_time:76052ms step_avg:89.79ms +step:848/1680 train_time:76142ms step_avg:89.79ms +step:849/1680 train_time:76232ms step_avg:89.79ms +step:850/1680 train_time:76322ms step_avg:89.79ms +step:851/1680 train_time:76413ms step_avg:89.79ms +step:852/1680 train_time:76502ms step_avg:89.79ms +step:853/1680 train_time:76593ms step_avg:89.79ms +step:854/1680 train_time:76683ms step_avg:89.79ms +step:855/1680 train_time:76778ms step_avg:89.80ms +step:856/1680 train_time:76862ms step_avg:89.79ms +step:857/1680 train_time:76953ms step_avg:89.79ms +step:858/1680 train_time:77043ms step_avg:89.79ms +step:859/1680 train_time:77134ms step_avg:89.79ms +step:860/1680 train_time:77223ms step_avg:89.79ms +step:861/1680 train_time:77313ms step_avg:89.79ms +step:862/1680 train_time:77403ms step_avg:89.79ms +step:863/1680 train_time:77493ms step_avg:89.79ms +step:864/1680 train_time:77582ms step_avg:89.79ms +step:865/1680 train_time:77672ms step_avg:89.79ms +step:866/1680 train_time:77763ms step_avg:89.80ms +step:867/1680 train_time:77853ms step_avg:89.80ms +step:868/1680 train_time:77943ms step_avg:89.80ms +step:869/1680 train_time:78033ms step_avg:89.80ms +step:870/1680 train_time:78123ms step_avg:89.80ms +step:871/1680 train_time:78213ms step_avg:89.80ms +step:872/1680 train_time:78303ms step_avg:89.80ms +step:873/1680 train_time:78393ms step_avg:89.80ms +step:874/1680 train_time:78483ms step_avg:89.80ms +step:875/1680 train_time:78573ms step_avg:89.80ms +step:875/1680 val_loss:3.5204 train_time:78664ms step_avg:89.90ms +step:876/1680 train_time:78689ms step_avg:89.83ms +step:877/1680 train_time:78757ms step_avg:89.80ms +step:878/1680 train_time:78853ms step_avg:89.81ms +step:879/1680 train_time:78944ms step_avg:89.81ms +step:880/1680 train_time:79033ms step_avg:89.81ms +step:881/1680 train_time:79122ms step_avg:89.81ms +step:882/1680 train_time:79211ms step_avg:89.81ms +step:883/1680 train_time:79300ms step_avg:89.81ms +step:884/1680 train_time:79389ms step_avg:89.81ms +step:885/1680 train_time:79478ms step_avg:89.81ms +step:886/1680 train_time:79569ms step_avg:89.81ms +step:887/1680 train_time:79659ms step_avg:89.81ms +step:888/1680 train_time:79751ms step_avg:89.81ms +step:889/1680 train_time:79844ms step_avg:89.81ms +step:890/1680 train_time:79935ms step_avg:89.81ms +step:891/1680 train_time:80025ms step_avg:89.81ms +step:892/1680 train_time:80115ms step_avg:89.82ms +step:893/1680 train_time:80206ms step_avg:89.82ms +step:894/1680 train_time:80295ms step_avg:89.82ms +step:895/1680 train_time:80385ms step_avg:89.82ms +step:896/1680 train_time:80474ms step_avg:89.82ms +step:897/1680 train_time:80565ms step_avg:89.82ms +step:898/1680 train_time:80655ms step_avg:89.82ms +step:899/1680 train_time:80749ms step_avg:89.82ms +step:900/1680 train_time:80841ms step_avg:89.82ms +step:901/1680 train_time:80932ms step_avg:89.82ms +step:902/1680 train_time:81022ms step_avg:89.82ms +step:903/1680 train_time:81112ms step_avg:89.82ms +step:904/1680 train_time:81202ms step_avg:89.83ms +step:905/1680 train_time:81292ms step_avg:89.83ms +step:906/1680 train_time:81381ms step_avg:89.82ms +step:907/1680 train_time:81470ms step_avg:89.82ms +step:908/1680 train_time:81560ms step_avg:89.82ms 
+step:909/1680 train_time:81650ms step_avg:89.82ms +step:910/1680 train_time:81744ms step_avg:89.83ms +step:911/1680 train_time:81835ms step_avg:89.83ms +step:912/1680 train_time:81926ms step_avg:89.83ms +step:913/1680 train_time:82016ms step_avg:89.83ms +step:914/1680 train_time:82107ms step_avg:89.83ms +step:915/1680 train_time:82198ms step_avg:89.83ms +step:916/1680 train_time:82287ms step_avg:89.83ms +step:917/1680 train_time:82376ms step_avg:89.83ms +step:918/1680 train_time:82466ms step_avg:89.83ms +step:919/1680 train_time:82555ms step_avg:89.83ms +step:920/1680 train_time:82646ms step_avg:89.83ms +step:921/1680 train_time:82737ms step_avg:89.83ms +step:922/1680 train_time:82828ms step_avg:89.83ms +step:923/1680 train_time:82918ms step_avg:89.84ms +step:924/1680 train_time:83008ms step_avg:89.84ms +step:925/1680 train_time:83098ms step_avg:89.84ms +step:926/1680 train_time:83190ms step_avg:89.84ms +step:927/1680 train_time:83279ms step_avg:89.84ms +step:928/1680 train_time:83370ms step_avg:89.84ms +step:929/1680 train_time:83459ms step_avg:89.84ms +step:930/1680 train_time:83550ms step_avg:89.84ms +step:931/1680 train_time:83640ms step_avg:89.84ms +step:932/1680 train_time:83730ms step_avg:89.84ms +step:933/1680 train_time:83821ms step_avg:89.84ms +step:934/1680 train_time:83911ms step_avg:89.84ms +step:935/1680 train_time:84003ms step_avg:89.84ms +step:936/1680 train_time:84094ms step_avg:89.84ms +step:937/1680 train_time:84185ms step_avg:89.85ms +step:938/1680 train_time:84276ms step_avg:89.85ms +step:939/1680 train_time:84366ms step_avg:89.85ms +step:940/1680 train_time:84457ms step_avg:89.85ms +step:941/1680 train_time:84547ms step_avg:89.85ms +step:942/1680 train_time:84637ms step_avg:89.85ms +step:943/1680 train_time:84727ms step_avg:89.85ms +step:944/1680 train_time:84818ms step_avg:89.85ms +step:945/1680 train_time:84907ms step_avg:89.85ms +step:946/1680 train_time:84998ms step_avg:89.85ms +step:947/1680 train_time:85089ms step_avg:89.85ms +step:948/1680 train_time:85180ms step_avg:89.85ms +step:949/1680 train_time:85271ms step_avg:89.85ms +step:950/1680 train_time:85361ms step_avg:89.85ms +step:951/1680 train_time:85451ms step_avg:89.85ms +step:952/1680 train_time:85542ms step_avg:89.85ms +step:953/1680 train_time:85631ms step_avg:89.85ms +step:954/1680 train_time:85722ms step_avg:89.86ms +step:955/1680 train_time:85812ms step_avg:89.86ms +step:956/1680 train_time:85903ms step_avg:89.86ms +step:957/1680 train_time:85993ms step_avg:89.86ms +step:958/1680 train_time:86083ms step_avg:89.86ms +step:959/1680 train_time:86174ms step_avg:89.86ms +step:960/1680 train_time:86264ms step_avg:89.86ms +step:961/1680 train_time:86355ms step_avg:89.86ms +step:962/1680 train_time:86446ms step_avg:89.86ms +step:963/1680 train_time:86535ms step_avg:89.86ms +step:964/1680 train_time:86626ms step_avg:89.86ms +step:965/1680 train_time:86716ms step_avg:89.86ms +step:966/1680 train_time:86807ms step_avg:89.86ms +step:967/1680 train_time:86897ms step_avg:89.86ms +step:968/1680 train_time:86987ms step_avg:89.86ms +step:969/1680 train_time:87078ms step_avg:89.86ms +step:970/1680 train_time:87168ms step_avg:89.86ms +step:971/1680 train_time:87258ms step_avg:89.86ms +step:972/1680 train_time:87349ms step_avg:89.86ms +step:973/1680 train_time:87441ms step_avg:89.87ms +step:974/1680 train_time:87530ms step_avg:89.87ms +step:975/1680 train_time:87620ms step_avg:89.87ms +step:976/1680 train_time:87711ms step_avg:89.87ms +step:977/1680 train_time:87801ms step_avg:89.87ms +step:978/1680 train_time:87891ms 
step_avg:89.87ms +step:979/1680 train_time:87981ms step_avg:89.87ms +step:980/1680 train_time:88072ms step_avg:89.87ms +step:981/1680 train_time:88163ms step_avg:89.87ms +step:982/1680 train_time:88253ms step_avg:89.87ms +step:983/1680 train_time:88345ms step_avg:89.87ms +step:984/1680 train_time:88436ms step_avg:89.87ms +step:985/1680 train_time:88526ms step_avg:89.87ms +step:986/1680 train_time:88616ms step_avg:89.87ms +step:987/1680 train_time:88707ms step_avg:89.88ms +step:988/1680 train_time:88796ms step_avg:89.87ms +step:989/1680 train_time:88886ms step_avg:89.87ms +step:990/1680 train_time:88976ms step_avg:89.88ms +step:991/1680 train_time:89067ms step_avg:89.88ms +step:992/1680 train_time:89157ms step_avg:89.88ms +step:993/1680 train_time:89247ms step_avg:89.88ms +step:994/1680 train_time:89337ms step_avg:89.88ms +step:995/1680 train_time:89428ms step_avg:89.88ms +step:996/1680 train_time:89519ms step_avg:89.88ms +step:997/1680 train_time:89609ms step_avg:89.88ms +step:998/1680 train_time:89700ms step_avg:89.88ms +step:999/1680 train_time:89790ms step_avg:89.88ms +step:1000/1680 train_time:89880ms step_avg:89.88ms +step:1000/1680 val_loss:3.4706 train_time:89971ms step_avg:89.97ms +step:1001/1680 train_time:89996ms step_avg:89.91ms +step:1002/1680 train_time:90064ms step_avg:89.88ms +step:1003/1680 train_time:90161ms step_avg:89.89ms +step:1004/1680 train_time:90251ms step_avg:89.89ms +step:1005/1680 train_time:90342ms step_avg:89.89ms +step:1006/1680 train_time:90432ms step_avg:89.89ms +step:1007/1680 train_time:90521ms step_avg:89.89ms +step:1008/1680 train_time:90611ms step_avg:89.89ms +step:1009/1680 train_time:90700ms step_avg:89.89ms +step:1010/1680 train_time:90790ms step_avg:89.89ms +step:1011/1680 train_time:90880ms step_avg:89.89ms +step:1012/1680 train_time:90971ms step_avg:89.89ms +step:1013/1680 train_time:91064ms step_avg:89.90ms +step:1014/1680 train_time:91157ms step_avg:89.90ms +step:1015/1680 train_time:91247ms step_avg:89.90ms +step:1016/1680 train_time:91338ms step_avg:89.90ms +step:1017/1680 train_time:91429ms step_avg:89.90ms +step:1018/1680 train_time:91519ms step_avg:89.90ms +step:1019/1680 train_time:91608ms step_avg:89.90ms +step:1020/1680 train_time:91697ms step_avg:89.90ms +step:1021/1680 train_time:91787ms step_avg:89.90ms +step:1022/1680 train_time:91877ms step_avg:89.90ms +step:1023/1680 train_time:91967ms step_avg:89.90ms +step:1024/1680 train_time:92059ms step_avg:89.90ms +step:1025/1680 train_time:92151ms step_avg:89.90ms +step:1026/1680 train_time:92242ms step_avg:89.90ms +step:1027/1680 train_time:92333ms step_avg:89.91ms +step:1028/1680 train_time:92423ms step_avg:89.91ms +step:1029/1680 train_time:92513ms step_avg:89.91ms +step:1030/1680 train_time:92603ms step_avg:89.91ms +step:1031/1680 train_time:92692ms step_avg:89.90ms +step:1032/1680 train_time:92782ms step_avg:89.90ms +step:1033/1680 train_time:92871ms step_avg:89.90ms +step:1034/1680 train_time:92962ms step_avg:89.91ms +step:1035/1680 train_time:93053ms step_avg:89.91ms +step:1036/1680 train_time:93143ms step_avg:89.91ms +step:1037/1680 train_time:93234ms step_avg:89.91ms +step:1038/1680 train_time:93324ms step_avg:89.91ms +step:1039/1680 train_time:93415ms step_avg:89.91ms +step:1040/1680 train_time:93505ms step_avg:89.91ms +step:1041/1680 train_time:93595ms step_avg:89.91ms +step:1042/1680 train_time:93684ms step_avg:89.91ms +step:1043/1680 train_time:93775ms step_avg:89.91ms +step:1044/1680 train_time:93865ms step_avg:89.91ms +step:1045/1680 train_time:93955ms step_avg:89.91ms 
+step:1046/1680 train_time:94045ms step_avg:89.91ms +step:1047/1680 train_time:94136ms step_avg:89.91ms +step:1048/1680 train_time:94226ms step_avg:89.91ms +step:1049/1680 train_time:94316ms step_avg:89.91ms +step:1050/1680 train_time:94407ms step_avg:89.91ms +step:1051/1680 train_time:94496ms step_avg:89.91ms +step:1052/1680 train_time:94586ms step_avg:89.91ms +step:1053/1680 train_time:94677ms step_avg:89.91ms +step:1054/1680 train_time:94766ms step_avg:89.91ms +step:1055/1680 train_time:94856ms step_avg:89.91ms +step:1056/1680 train_time:94946ms step_avg:89.91ms +step:1057/1680 train_time:95038ms step_avg:89.91ms +step:1058/1680 train_time:95129ms step_avg:89.91ms +step:1059/1680 train_time:95220ms step_avg:89.91ms +step:1060/1680 train_time:95310ms step_avg:89.92ms +step:1061/1680 train_time:95400ms step_avg:89.92ms +step:1062/1680 train_time:95491ms step_avg:89.92ms +step:1063/1680 train_time:95582ms step_avg:89.92ms +step:1064/1680 train_time:95671ms step_avg:89.92ms +step:1065/1680 train_time:95762ms step_avg:89.92ms +step:1066/1680 train_time:95852ms step_avg:89.92ms +step:1067/1680 train_time:95943ms step_avg:89.92ms +step:1068/1680 train_time:96034ms step_avg:89.92ms +step:1069/1680 train_time:96124ms step_avg:89.92ms +step:1070/1680 train_time:96215ms step_avg:89.92ms +step:1071/1680 train_time:96306ms step_avg:89.92ms +step:1072/1680 train_time:96397ms step_avg:89.92ms +step:1073/1680 train_time:96486ms step_avg:89.92ms +step:1074/1680 train_time:96577ms step_avg:89.92ms +step:1075/1680 train_time:96666ms step_avg:89.92ms +step:1076/1680 train_time:96756ms step_avg:89.92ms +step:1077/1680 train_time:96846ms step_avg:89.92ms +step:1078/1680 train_time:96936ms step_avg:89.92ms +step:1079/1680 train_time:97027ms step_avg:89.92ms +step:1080/1680 train_time:97117ms step_avg:89.92ms +step:1081/1680 train_time:97207ms step_avg:89.92ms +step:1082/1680 train_time:97298ms step_avg:89.92ms +step:1083/1680 train_time:97387ms step_avg:89.92ms +step:1084/1680 train_time:97479ms step_avg:89.93ms +step:1085/1680 train_time:97568ms step_avg:89.92ms +step:1086/1680 train_time:97659ms step_avg:89.93ms +step:1087/1680 train_time:97749ms step_avg:89.93ms +step:1088/1680 train_time:97840ms step_avg:89.93ms +step:1089/1680 train_time:97929ms step_avg:89.93ms +step:1090/1680 train_time:98019ms step_avg:89.93ms +step:1091/1680 train_time:98109ms step_avg:89.93ms +step:1092/1680 train_time:98200ms step_avg:89.93ms +step:1093/1680 train_time:98291ms step_avg:89.93ms +step:1094/1680 train_time:98381ms step_avg:89.93ms +step:1095/1680 train_time:98472ms step_avg:89.93ms +step:1096/1680 train_time:98563ms step_avg:89.93ms +step:1097/1680 train_time:98654ms step_avg:89.93ms +step:1098/1680 train_time:98744ms step_avg:89.93ms +step:1099/1680 train_time:98836ms step_avg:89.93ms +step:1100/1680 train_time:98926ms step_avg:89.93ms +step:1101/1680 train_time:99017ms step_avg:89.93ms +step:1102/1680 train_time:99108ms step_avg:89.93ms +step:1103/1680 train_time:99199ms step_avg:89.94ms +step:1104/1680 train_time:99290ms step_avg:89.94ms +step:1105/1680 train_time:99381ms step_avg:89.94ms +step:1106/1680 train_time:99472ms step_avg:89.94ms +step:1107/1680 train_time:99569ms step_avg:89.95ms +step:1108/1680 train_time:99656ms step_avg:89.94ms +step:1109/1680 train_time:99747ms step_avg:89.94ms +step:1110/1680 train_time:99837ms step_avg:89.94ms +step:1111/1680 train_time:99928ms step_avg:89.94ms +step:1112/1680 train_time:100018ms step_avg:89.94ms +step:1113/1680 train_time:100109ms step_avg:89.95ms +step:1114/1680 
train_time:100200ms step_avg:89.95ms +step:1115/1680 train_time:100291ms step_avg:89.95ms +step:1116/1680 train_time:100382ms step_avg:89.95ms +step:1117/1680 train_time:100473ms step_avg:89.95ms +step:1118/1680 train_time:100565ms step_avg:89.95ms +step:1119/1680 train_time:100655ms step_avg:89.95ms +step:1120/1680 train_time:100746ms step_avg:89.95ms +step:1121/1680 train_time:100836ms step_avg:89.95ms +step:1122/1680 train_time:100927ms step_avg:89.95ms +step:1123/1680 train_time:101017ms step_avg:89.95ms +step:1124/1680 train_time:101109ms step_avg:89.95ms +step:1125/1680 train_time:101200ms step_avg:89.96ms +step:1125/1680 val_loss:3.4175 train_time:101292ms step_avg:90.04ms +step:1126/1680 train_time:101319ms step_avg:89.98ms +step:1127/1680 train_time:101388ms step_avg:89.96ms +step:1128/1680 train_time:101487ms step_avg:89.97ms +step:1129/1680 train_time:101577ms step_avg:89.97ms +step:1130/1680 train_time:101667ms step_avg:89.97ms +step:1131/1680 train_time:101757ms step_avg:89.97ms +step:1132/1680 train_time:101847ms step_avg:89.97ms +step:1133/1680 train_time:101937ms step_avg:89.97ms +step:1134/1680 train_time:102026ms step_avg:89.97ms +step:1135/1680 train_time:102116ms step_avg:89.97ms +step:1136/1680 train_time:102208ms step_avg:89.97ms +step:1137/1680 train_time:102304ms step_avg:89.98ms +step:1138/1680 train_time:102396ms step_avg:89.98ms +step:1139/1680 train_time:102488ms step_avg:89.98ms +step:1140/1680 train_time:102582ms step_avg:89.98ms +step:1141/1680 train_time:102668ms step_avg:89.98ms +step:1142/1680 train_time:102758ms step_avg:89.98ms +step:1143/1680 train_time:102848ms step_avg:89.98ms +step:1144/1680 train_time:102938ms step_avg:89.98ms +step:1145/1680 train_time:103027ms step_avg:89.98ms +step:1146/1680 train_time:103117ms step_avg:89.98ms +step:1147/1680 train_time:103208ms step_avg:89.98ms +step:1148/1680 train_time:103300ms step_avg:89.98ms +step:1149/1680 train_time:103392ms step_avg:89.98ms +step:1150/1680 train_time:103482ms step_avg:89.98ms +step:1151/1680 train_time:103573ms step_avg:89.99ms +step:1152/1680 train_time:103664ms step_avg:89.99ms +step:1153/1680 train_time:103754ms step_avg:89.99ms +step:1154/1680 train_time:103844ms step_avg:89.99ms +step:1155/1680 train_time:103935ms step_avg:89.99ms +step:1156/1680 train_time:104024ms step_avg:89.99ms +step:1157/1680 train_time:104115ms step_avg:89.99ms +step:1158/1680 train_time:104208ms step_avg:89.99ms +step:1159/1680 train_time:104299ms step_avg:89.99ms +step:1160/1680 train_time:104389ms step_avg:89.99ms +step:1161/1680 train_time:104480ms step_avg:89.99ms +step:1162/1680 train_time:104572ms step_avg:89.99ms +step:1163/1680 train_time:104662ms step_avg:89.99ms +step:1164/1680 train_time:104753ms step_avg:89.99ms +step:1165/1680 train_time:104843ms step_avg:89.99ms +step:1166/1680 train_time:104933ms step_avg:89.99ms +step:1167/1680 train_time:105024ms step_avg:89.99ms +step:1168/1680 train_time:105115ms step_avg:90.00ms +step:1169/1680 train_time:105206ms step_avg:90.00ms +step:1170/1680 train_time:105296ms step_avg:90.00ms +step:1171/1680 train_time:105387ms step_avg:90.00ms +step:1172/1680 train_time:105478ms step_avg:90.00ms +step:1173/1680 train_time:105569ms step_avg:90.00ms +step:1174/1680 train_time:105659ms step_avg:90.00ms +step:1175/1680 train_time:105750ms step_avg:90.00ms +step:1176/1680 train_time:105840ms step_avg:90.00ms +step:1177/1680 train_time:105931ms step_avg:90.00ms +step:1178/1680 train_time:106021ms step_avg:90.00ms +step:1179/1680 train_time:106111ms step_avg:90.00ms 
+step:1180/1680 train_time:106203ms step_avg:90.00ms +step:1181/1680 train_time:106294ms step_avg:90.00ms +step:1182/1680 train_time:106385ms step_avg:90.00ms +step:1183/1680 train_time:106476ms step_avg:90.00ms +step:1184/1680 train_time:106567ms step_avg:90.01ms +step:1185/1680 train_time:106658ms step_avg:90.01ms +step:1186/1680 train_time:106748ms step_avg:90.01ms +step:1187/1680 train_time:106839ms step_avg:90.01ms +step:1188/1680 train_time:106930ms step_avg:90.01ms +step:1189/1680 train_time:107020ms step_avg:90.01ms +step:1190/1680 train_time:107111ms step_avg:90.01ms +step:1191/1680 train_time:107203ms step_avg:90.01ms +step:1192/1680 train_time:107298ms step_avg:90.02ms +step:1193/1680 train_time:107385ms step_avg:90.01ms +step:1194/1680 train_time:107476ms step_avg:90.01ms +step:1195/1680 train_time:107566ms step_avg:90.01ms +step:1196/1680 train_time:107656ms step_avg:90.01ms +step:1197/1680 train_time:107746ms step_avg:90.01ms +step:1198/1680 train_time:107836ms step_avg:90.01ms +step:1199/1680 train_time:107927ms step_avg:90.01ms +step:1200/1680 train_time:108017ms step_avg:90.01ms +step:1201/1680 train_time:108107ms step_avg:90.01ms +step:1202/1680 train_time:108198ms step_avg:90.02ms +step:1203/1680 train_time:108288ms step_avg:90.02ms +step:1204/1680 train_time:108379ms step_avg:90.02ms +step:1205/1680 train_time:108470ms step_avg:90.02ms +step:1206/1680 train_time:108561ms step_avg:90.02ms +step:1207/1680 train_time:108652ms step_avg:90.02ms +step:1208/1680 train_time:108743ms step_avg:90.02ms +step:1209/1680 train_time:108833ms step_avg:90.02ms +step:1210/1680 train_time:108924ms step_avg:90.02ms +step:1211/1680 train_time:109014ms step_avg:90.02ms +step:1212/1680 train_time:109105ms step_avg:90.02ms +step:1213/1680 train_time:109196ms step_avg:90.02ms +step:1214/1680 train_time:109286ms step_avg:90.02ms +step:1215/1680 train_time:109376ms step_avg:90.02ms +step:1216/1680 train_time:109467ms step_avg:90.02ms +step:1217/1680 train_time:109558ms step_avg:90.02ms +step:1218/1680 train_time:109648ms step_avg:90.02ms +step:1219/1680 train_time:109739ms step_avg:90.02ms +step:1220/1680 train_time:109830ms step_avg:90.02ms +step:1221/1680 train_time:109920ms step_avg:90.02ms +step:1222/1680 train_time:110010ms step_avg:90.02ms +step:1223/1680 train_time:110101ms step_avg:90.03ms +step:1224/1680 train_time:110193ms step_avg:90.03ms +step:1225/1680 train_time:110284ms step_avg:90.03ms +step:1226/1680 train_time:110375ms step_avg:90.03ms +step:1227/1680 train_time:110466ms step_avg:90.03ms +step:1228/1680 train_time:110557ms step_avg:90.03ms +step:1229/1680 train_time:110648ms step_avg:90.03ms +step:1230/1680 train_time:110739ms step_avg:90.03ms +step:1231/1680 train_time:110829ms step_avg:90.03ms +step:1232/1680 train_time:110920ms step_avg:90.03ms +step:1233/1680 train_time:111011ms step_avg:90.03ms +step:1234/1680 train_time:111102ms step_avg:90.03ms +step:1235/1680 train_time:111192ms step_avg:90.03ms +step:1236/1680 train_time:111282ms step_avg:90.03ms +step:1237/1680 train_time:111374ms step_avg:90.04ms +step:1238/1680 train_time:111465ms step_avg:90.04ms +step:1239/1680 train_time:111555ms step_avg:90.04ms +step:1240/1680 train_time:111646ms step_avg:90.04ms +step:1241/1680 train_time:111736ms step_avg:90.04ms +step:1242/1680 train_time:111826ms step_avg:90.04ms +step:1243/1680 train_time:111918ms step_avg:90.04ms +step:1244/1680 train_time:112008ms step_avg:90.04ms +step:1245/1680 train_time:112099ms step_avg:90.04ms +step:1246/1680 train_time:112189ms step_avg:90.04ms 
+step:1247/1680 train_time:112280ms step_avg:90.04ms +step:1248/1680 train_time:112371ms step_avg:90.04ms +step:1249/1680 train_time:112461ms step_avg:90.04ms +step:1250/1680 train_time:112554ms step_avg:90.04ms +step:1250/1680 val_loss:3.3796 train_time:112647ms step_avg:90.12ms +step:1251/1680 train_time:112672ms step_avg:90.07ms +step:1252/1680 train_time:112742ms step_avg:90.05ms +step:1253/1680 train_time:112836ms step_avg:90.05ms +step:1254/1680 train_time:112927ms step_avg:90.05ms +step:1255/1680 train_time:113016ms step_avg:90.05ms +step:1256/1680 train_time:113106ms step_avg:90.05ms +step:1257/1680 train_time:113195ms step_avg:90.05ms +step:1258/1680 train_time:113284ms step_avg:90.05ms +step:1259/1680 train_time:113374ms step_avg:90.05ms +step:1260/1680 train_time:113464ms step_avg:90.05ms +step:1261/1680 train_time:113555ms step_avg:90.05ms +step:1262/1680 train_time:113648ms step_avg:90.05ms +step:1263/1680 train_time:113742ms step_avg:90.06ms +step:1264/1680 train_time:113834ms step_avg:90.06ms +step:1265/1680 train_time:113925ms step_avg:90.06ms +step:1266/1680 train_time:114015ms step_avg:90.06ms +step:1267/1680 train_time:114105ms step_avg:90.06ms +step:1268/1680 train_time:114195ms step_avg:90.06ms +step:1269/1680 train_time:114284ms step_avg:90.06ms +step:1270/1680 train_time:114375ms step_avg:90.06ms +step:1271/1680 train_time:114465ms step_avg:90.06ms +step:1272/1680 train_time:114555ms step_avg:90.06ms +step:1273/1680 train_time:114647ms step_avg:90.06ms +step:1274/1680 train_time:114739ms step_avg:90.06ms +step:1275/1680 train_time:114830ms step_avg:90.06ms +step:1276/1680 train_time:114922ms step_avg:90.06ms +step:1277/1680 train_time:115012ms step_avg:90.06ms +step:1278/1680 train_time:115102ms step_avg:90.06ms +step:1279/1680 train_time:115192ms step_avg:90.06ms +step:1280/1680 train_time:115283ms step_avg:90.06ms +step:1281/1680 train_time:115373ms step_avg:90.06ms +step:1282/1680 train_time:115463ms step_avg:90.06ms +step:1283/1680 train_time:115553ms step_avg:90.06ms +step:1284/1680 train_time:115645ms step_avg:90.07ms +step:1285/1680 train_time:115738ms step_avg:90.07ms +step:1286/1680 train_time:115829ms step_avg:90.07ms +step:1287/1680 train_time:115921ms step_avg:90.07ms +step:1288/1680 train_time:116012ms step_avg:90.07ms +step:1289/1680 train_time:116103ms step_avg:90.07ms +step:1290/1680 train_time:116193ms step_avg:90.07ms +step:1291/1680 train_time:116284ms step_avg:90.07ms +step:1292/1680 train_time:116374ms step_avg:90.07ms +step:1293/1680 train_time:116464ms step_avg:90.07ms +step:1294/1680 train_time:116554ms step_avg:90.07ms +step:1295/1680 train_time:116646ms step_avg:90.07ms +step:1296/1680 train_time:116737ms step_avg:90.07ms +step:1297/1680 train_time:116829ms step_avg:90.08ms +step:1298/1680 train_time:116921ms step_avg:90.08ms +step:1299/1680 train_time:117013ms step_avg:90.08ms +step:1300/1680 train_time:117103ms step_avg:90.08ms +step:1301/1680 train_time:117194ms step_avg:90.08ms +step:1302/1680 train_time:117284ms step_avg:90.08ms +step:1303/1680 train_time:117374ms step_avg:90.08ms +step:1304/1680 train_time:117465ms step_avg:90.08ms +step:1305/1680 train_time:117555ms step_avg:90.08ms +step:1306/1680 train_time:117646ms step_avg:90.08ms +step:1307/1680 train_time:117738ms step_avg:90.08ms +step:1308/1680 train_time:117829ms step_avg:90.08ms +step:1309/1680 train_time:117921ms step_avg:90.08ms +step:1310/1680 train_time:118014ms step_avg:90.09ms +step:1311/1680 train_time:118105ms step_avg:90.09ms +step:1312/1680 train_time:118195ms 
step_avg:90.09ms +step:1313/1680 train_time:118286ms step_avg:90.09ms +step:1314/1680 train_time:118376ms step_avg:90.09ms +step:1315/1680 train_time:118467ms step_avg:90.09ms +step:1316/1680 train_time:118557ms step_avg:90.09ms +step:1317/1680 train_time:118647ms step_avg:90.09ms +step:1318/1680 train_time:118743ms step_avg:90.09ms +step:1319/1680 train_time:118829ms step_avg:90.09ms +step:1320/1680 train_time:118920ms step_avg:90.09ms +step:1321/1680 train_time:119012ms step_avg:90.09ms +step:1322/1680 train_time:119103ms step_avg:90.09ms +step:1323/1680 train_time:119194ms step_avg:90.09ms +step:1324/1680 train_time:119284ms step_avg:90.09ms +step:1325/1680 train_time:119375ms step_avg:90.09ms +step:1326/1680 train_time:119466ms step_avg:90.10ms +step:1327/1680 train_time:119556ms step_avg:90.10ms +step:1328/1680 train_time:119647ms step_avg:90.10ms +step:1329/1680 train_time:119739ms step_avg:90.10ms +step:1330/1680 train_time:119829ms step_avg:90.10ms +step:1331/1680 train_time:119920ms step_avg:90.10ms +step:1332/1680 train_time:120012ms step_avg:90.10ms +step:1333/1680 train_time:120102ms step_avg:90.10ms +step:1334/1680 train_time:120192ms step_avg:90.10ms +step:1335/1680 train_time:120283ms step_avg:90.10ms +step:1336/1680 train_time:120374ms step_avg:90.10ms +step:1337/1680 train_time:120465ms step_avg:90.10ms +step:1338/1680 train_time:120555ms step_avg:90.10ms +step:1339/1680 train_time:120645ms step_avg:90.10ms +step:1340/1680 train_time:120735ms step_avg:90.10ms +step:1341/1680 train_time:120826ms step_avg:90.10ms +step:1342/1680 train_time:120917ms step_avg:90.10ms +step:1343/1680 train_time:121009ms step_avg:90.10ms +step:1344/1680 train_time:121100ms step_avg:90.10ms +step:1345/1680 train_time:121191ms step_avg:90.10ms +step:1346/1680 train_time:121281ms step_avg:90.10ms +step:1347/1680 train_time:121372ms step_avg:90.11ms +step:1348/1680 train_time:121463ms step_avg:90.11ms +step:1349/1680 train_time:121557ms step_avg:90.11ms +step:1350/1680 train_time:121644ms step_avg:90.11ms +step:1351/1680 train_time:121735ms step_avg:90.11ms +step:1352/1680 train_time:121826ms step_avg:90.11ms +step:1353/1680 train_time:121917ms step_avg:90.11ms +step:1354/1680 train_time:122008ms step_avg:90.11ms +step:1355/1680 train_time:122099ms step_avg:90.11ms +step:1356/1680 train_time:122190ms step_avg:90.11ms +step:1357/1680 train_time:122281ms step_avg:90.11ms +step:1358/1680 train_time:122372ms step_avg:90.11ms +step:1359/1680 train_time:122462ms step_avg:90.11ms +step:1360/1680 train_time:122552ms step_avg:90.11ms +step:1361/1680 train_time:122643ms step_avg:90.11ms +step:1362/1680 train_time:122733ms step_avg:90.11ms +step:1363/1680 train_time:122825ms step_avg:90.11ms +step:1364/1680 train_time:122917ms step_avg:90.11ms +step:1365/1680 train_time:123007ms step_avg:90.12ms +step:1366/1680 train_time:123099ms step_avg:90.12ms +step:1367/1680 train_time:123189ms step_avg:90.12ms +step:1368/1680 train_time:123280ms step_avg:90.12ms +step:1369/1680 train_time:123372ms step_avg:90.12ms +step:1370/1680 train_time:123462ms step_avg:90.12ms +step:1371/1680 train_time:123553ms step_avg:90.12ms +step:1372/1680 train_time:123643ms step_avg:90.12ms +step:1373/1680 train_time:123734ms step_avg:90.12ms +step:1374/1680 train_time:123824ms step_avg:90.12ms +step:1375/1680 train_time:123915ms step_avg:90.12ms +step:1375/1680 val_loss:3.3448 train_time:124007ms step_avg:90.19ms +step:1376/1680 train_time:124032ms step_avg:90.14ms +step:1377/1680 train_time:124104ms step_avg:90.13ms +step:1378/1680 
train_time:124200ms step_avg:90.13ms +step:1379/1680 train_time:124291ms step_avg:90.13ms +step:1380/1680 train_time:124381ms step_avg:90.13ms +step:1381/1680 train_time:124471ms step_avg:90.13ms +step:1382/1680 train_time:124560ms step_avg:90.13ms +step:1383/1680 train_time:124650ms step_avg:90.13ms +step:1384/1680 train_time:124740ms step_avg:90.13ms +step:1385/1680 train_time:124829ms step_avg:90.13ms +step:1386/1680 train_time:124919ms step_avg:90.13ms +step:1387/1680 train_time:125010ms step_avg:90.13ms +step:1388/1680 train_time:125105ms step_avg:90.13ms +step:1389/1680 train_time:125200ms step_avg:90.14ms +step:1390/1680 train_time:125296ms step_avg:90.14ms +step:1391/1680 train_time:125383ms step_avg:90.14ms +step:1392/1680 train_time:125473ms step_avg:90.14ms +step:1393/1680 train_time:125563ms step_avg:90.14ms +step:1394/1680 train_time:125653ms step_avg:90.14ms +step:1395/1680 train_time:125742ms step_avg:90.14ms +step:1396/1680 train_time:125832ms step_avg:90.14ms +step:1397/1680 train_time:125922ms step_avg:90.14ms +step:1398/1680 train_time:126013ms step_avg:90.14ms +step:1399/1680 train_time:126105ms step_avg:90.14ms +step:1400/1680 train_time:126197ms step_avg:90.14ms +step:1401/1680 train_time:126290ms step_avg:90.14ms +step:1402/1680 train_time:126381ms step_avg:90.14ms +step:1403/1680 train_time:126473ms step_avg:90.14ms +step:1404/1680 train_time:126563ms step_avg:90.14ms +step:1405/1680 train_time:126654ms step_avg:90.15ms +step:1406/1680 train_time:126743ms step_avg:90.14ms +step:1407/1680 train_time:126832ms step_avg:90.14ms +step:1408/1680 train_time:126923ms step_avg:90.14ms +step:1409/1680 train_time:127014ms step_avg:90.14ms +step:1410/1680 train_time:127106ms step_avg:90.15ms +step:1411/1680 train_time:127198ms step_avg:90.15ms +step:1412/1680 train_time:127291ms step_avg:90.15ms +step:1413/1680 train_time:127383ms step_avg:90.15ms +step:1414/1680 train_time:127474ms step_avg:90.15ms +step:1415/1680 train_time:127564ms step_avg:90.15ms +step:1416/1680 train_time:127655ms step_avg:90.15ms +step:1417/1680 train_time:127745ms step_avg:90.15ms +step:1418/1680 train_time:127835ms step_avg:90.15ms +step:1419/1680 train_time:127925ms step_avg:90.15ms +step:1420/1680 train_time:128015ms step_avg:90.15ms +step:1421/1680 train_time:128108ms step_avg:90.15ms +step:1422/1680 train_time:128200ms step_avg:90.15ms +step:1423/1680 train_time:128291ms step_avg:90.16ms +step:1424/1680 train_time:128382ms step_avg:90.16ms +step:1425/1680 train_time:128474ms step_avg:90.16ms +step:1426/1680 train_time:128564ms step_avg:90.16ms +step:1427/1680 train_time:128656ms step_avg:90.16ms +step:1428/1680 train_time:128746ms step_avg:90.16ms +step:1429/1680 train_time:128837ms step_avg:90.16ms +step:1430/1680 train_time:128928ms step_avg:90.16ms +step:1431/1680 train_time:129018ms step_avg:90.16ms +step:1432/1680 train_time:129109ms step_avg:90.16ms +step:1433/1680 train_time:129201ms step_avg:90.16ms +step:1434/1680 train_time:129293ms step_avg:90.16ms +step:1435/1680 train_time:129385ms step_avg:90.16ms +step:1436/1680 train_time:129477ms step_avg:90.17ms +step:1437/1680 train_time:129567ms step_avg:90.17ms +step:1438/1680 train_time:129658ms step_avg:90.17ms +step:1439/1680 train_time:129748ms step_avg:90.17ms +step:1440/1680 train_time:129839ms step_avg:90.17ms +step:1441/1680 train_time:129930ms step_avg:90.17ms +step:1442/1680 train_time:130021ms step_avg:90.17ms +step:1443/1680 train_time:130112ms step_avg:90.17ms +step:1444/1680 train_time:130203ms step_avg:90.17ms +step:1445/1680 
train_time:130294ms step_avg:90.17ms +step:1446/1680 train_time:130386ms step_avg:90.17ms +step:1447/1680 train_time:130478ms step_avg:90.17ms +step:1448/1680 train_time:130568ms step_avg:90.17ms +step:1449/1680 train_time:130660ms step_avg:90.17ms +step:1450/1680 train_time:130751ms step_avg:90.17ms +step:1451/1680 train_time:130841ms step_avg:90.17ms +step:1452/1680 train_time:130931ms step_avg:90.17ms +step:1453/1680 train_time:131022ms step_avg:90.17ms +step:1454/1680 train_time:131112ms step_avg:90.17ms +step:1455/1680 train_time:131203ms step_avg:90.17ms +step:1456/1680 train_time:131294ms step_avg:90.17ms +step:1457/1680 train_time:131386ms step_avg:90.18ms +step:1458/1680 train_time:131477ms step_avg:90.18ms +step:1459/1680 train_time:131568ms step_avg:90.18ms +step:1460/1680 train_time:131660ms step_avg:90.18ms +step:1461/1680 train_time:131750ms step_avg:90.18ms +step:1462/1680 train_time:131841ms step_avg:90.18ms +step:1463/1680 train_time:131932ms step_avg:90.18ms +step:1464/1680 train_time:132023ms step_avg:90.18ms +step:1465/1680 train_time:132115ms step_avg:90.18ms +step:1466/1680 train_time:132205ms step_avg:90.18ms +step:1467/1680 train_time:132297ms step_avg:90.18ms +step:1468/1680 train_time:132387ms step_avg:90.18ms +step:1469/1680 train_time:132479ms step_avg:90.18ms +step:1470/1680 train_time:132570ms step_avg:90.18ms +step:1471/1680 train_time:132661ms step_avg:90.18ms +step:1472/1680 train_time:132752ms step_avg:90.18ms +step:1473/1680 train_time:132842ms step_avg:90.18ms +step:1474/1680 train_time:132933ms step_avg:90.19ms +step:1475/1680 train_time:133024ms step_avg:90.19ms +step:1476/1680 train_time:133114ms step_avg:90.19ms +step:1477/1680 train_time:133205ms step_avg:90.19ms +step:1478/1680 train_time:133296ms step_avg:90.19ms +step:1479/1680 train_time:133387ms step_avg:90.19ms +step:1480/1680 train_time:133479ms step_avg:90.19ms +step:1481/1680 train_time:133570ms step_avg:90.19ms +step:1482/1680 train_time:133662ms step_avg:90.19ms +step:1483/1680 train_time:133752ms step_avg:90.19ms +step:1484/1680 train_time:133843ms step_avg:90.19ms +step:1485/1680 train_time:133935ms step_avg:90.19ms +step:1486/1680 train_time:134025ms step_avg:90.19ms +step:1487/1680 train_time:134116ms step_avg:90.19ms +step:1488/1680 train_time:134206ms step_avg:90.19ms +step:1489/1680 train_time:134296ms step_avg:90.19ms +step:1490/1680 train_time:134387ms step_avg:90.19ms +step:1491/1680 train_time:134478ms step_avg:90.19ms +step:1492/1680 train_time:134569ms step_avg:90.19ms +step:1493/1680 train_time:134660ms step_avg:90.19ms +step:1494/1680 train_time:134750ms step_avg:90.19ms +step:1495/1680 train_time:134841ms step_avg:90.19ms +step:1496/1680 train_time:134932ms step_avg:90.19ms +step:1497/1680 train_time:135023ms step_avg:90.20ms +step:1498/1680 train_time:135113ms step_avg:90.20ms +step:1499/1680 train_time:135204ms step_avg:90.20ms +step:1500/1680 train_time:135295ms step_avg:90.20ms +step:1500/1680 val_loss:3.3150 train_time:135386ms step_avg:90.26ms +step:1501/1680 train_time:135411ms step_avg:90.21ms +step:1502/1680 train_time:135482ms step_avg:90.20ms +step:1503/1680 train_time:135580ms step_avg:90.21ms +step:1504/1680 train_time:135671ms step_avg:90.21ms +step:1505/1680 train_time:135761ms step_avg:90.21ms +step:1506/1680 train_time:135851ms step_avg:90.21ms +step:1507/1680 train_time:135941ms step_avg:90.21ms +step:1508/1680 train_time:136033ms step_avg:90.21ms +step:1509/1680 train_time:136119ms step_avg:90.20ms +step:1510/1680 train_time:136209ms step_avg:90.20ms 
+step:1511/1680 train_time:136299ms step_avg:90.20ms +step:1512/1680 train_time:136392ms step_avg:90.21ms +step:1513/1680 train_time:136485ms step_avg:90.21ms +step:1514/1680 train_time:136578ms step_avg:90.21ms +step:1515/1680 train_time:136669ms step_avg:90.21ms +step:1516/1680 train_time:136759ms step_avg:90.21ms +step:1517/1680 train_time:136850ms step_avg:90.21ms +step:1518/1680 train_time:136940ms step_avg:90.21ms +step:1519/1680 train_time:137029ms step_avg:90.21ms +step:1520/1680 train_time:137119ms step_avg:90.21ms +step:1521/1680 train_time:137209ms step_avg:90.21ms +step:1522/1680 train_time:137299ms step_avg:90.21ms +step:1523/1680 train_time:137390ms step_avg:90.21ms +step:1524/1680 train_time:137482ms step_avg:90.21ms +step:1525/1680 train_time:137575ms step_avg:90.21ms +step:1526/1680 train_time:137667ms step_avg:90.21ms +step:1527/1680 train_time:137759ms step_avg:90.22ms +step:1528/1680 train_time:137849ms step_avg:90.22ms +step:1529/1680 train_time:137940ms step_avg:90.22ms +step:1530/1680 train_time:138030ms step_avg:90.22ms +step:1531/1680 train_time:138120ms step_avg:90.22ms +step:1532/1680 train_time:138210ms step_avg:90.22ms +step:1533/1680 train_time:138301ms step_avg:90.22ms +step:1534/1680 train_time:138392ms step_avg:90.22ms +step:1535/1680 train_time:138483ms step_avg:90.22ms +step:1536/1680 train_time:138576ms step_avg:90.22ms +step:1537/1680 train_time:138668ms step_avg:90.22ms +step:1538/1680 train_time:138759ms step_avg:90.22ms +step:1539/1680 train_time:138851ms step_avg:90.22ms +step:1540/1680 train_time:138943ms step_avg:90.22ms +step:1541/1680 train_time:139031ms step_avg:90.22ms +step:1542/1680 train_time:139121ms step_avg:90.22ms +step:1543/1680 train_time:139211ms step_avg:90.22ms +step:1544/1680 train_time:139302ms step_avg:90.22ms +step:1545/1680 train_time:139393ms step_avg:90.22ms +step:1546/1680 train_time:139484ms step_avg:90.22ms +step:1547/1680 train_time:139575ms step_avg:90.22ms +step:1548/1680 train_time:139667ms step_avg:90.22ms +step:1549/1680 train_time:139759ms step_avg:90.23ms +step:1550/1680 train_time:139849ms step_avg:90.23ms +step:1551/1680 train_time:139939ms step_avg:90.23ms +step:1552/1680 train_time:140030ms step_avg:90.23ms +step:1553/1680 train_time:140120ms step_avg:90.23ms +step:1554/1680 train_time:140211ms step_avg:90.23ms +step:1555/1680 train_time:140302ms step_avg:90.23ms +step:1556/1680 train_time:140394ms step_avg:90.23ms +step:1557/1680 train_time:140484ms step_avg:90.23ms +step:1558/1680 train_time:140575ms step_avg:90.23ms +step:1559/1680 train_time:140668ms step_avg:90.23ms +step:1560/1680 train_time:140760ms step_avg:90.23ms +step:1561/1680 train_time:140850ms step_avg:90.23ms +step:1562/1680 train_time:140941ms step_avg:90.23ms +step:1563/1680 train_time:141031ms step_avg:90.23ms +step:1564/1680 train_time:141123ms step_avg:90.23ms +step:1565/1680 train_time:141213ms step_avg:90.23ms +step:1566/1680 train_time:141305ms step_avg:90.23ms +step:1567/1680 train_time:141396ms step_avg:90.23ms +step:1568/1680 train_time:141486ms step_avg:90.23ms +step:1569/1680 train_time:141579ms step_avg:90.24ms +step:1570/1680 train_time:141670ms step_avg:90.24ms +step:1571/1680 train_time:141761ms step_avg:90.24ms +step:1572/1680 train_time:141852ms step_avg:90.24ms +step:1573/1680 train_time:141942ms step_avg:90.24ms +step:1574/1680 train_time:142033ms step_avg:90.24ms +step:1575/1680 train_time:142124ms step_avg:90.24ms +step:1576/1680 train_time:142215ms step_avg:90.24ms +step:1577/1680 train_time:142305ms step_avg:90.24ms 
+step:1578/1680 train_time:142397ms step_avg:90.24ms +step:1579/1680 train_time:142489ms step_avg:90.24ms +step:1580/1680 train_time:142580ms step_avg:90.24ms +step:1581/1680 train_time:142671ms step_avg:90.24ms +step:1582/1680 train_time:142762ms step_avg:90.24ms +step:1583/1680 train_time:142853ms step_avg:90.24ms +step:1584/1680 train_time:142943ms step_avg:90.24ms +step:1585/1680 train_time:143033ms step_avg:90.24ms +step:1586/1680 train_time:143124ms step_avg:90.24ms +step:1587/1680 train_time:143215ms step_avg:90.24ms +step:1588/1680 train_time:143305ms step_avg:90.24ms +step:1589/1680 train_time:143396ms step_avg:90.24ms +step:1590/1680 train_time:143487ms step_avg:90.24ms +step:1591/1680 train_time:143579ms step_avg:90.24ms +step:1592/1680 train_time:143669ms step_avg:90.24ms +step:1593/1680 train_time:143760ms step_avg:90.24ms +step:1594/1680 train_time:143851ms step_avg:90.25ms +step:1595/1680 train_time:143941ms step_avg:90.25ms +step:1596/1680 train_time:144031ms step_avg:90.25ms +step:1597/1680 train_time:144122ms step_avg:90.25ms +step:1598/1680 train_time:144212ms step_avg:90.25ms +step:1599/1680 train_time:144302ms step_avg:90.24ms +step:1600/1680 train_time:144392ms step_avg:90.25ms +step:1601/1680 train_time:144483ms step_avg:90.25ms +step:1602/1680 train_time:144574ms step_avg:90.25ms +step:1603/1680 train_time:144664ms step_avg:90.25ms +step:1604/1680 train_time:144755ms step_avg:90.25ms +step:1605/1680 train_time:144846ms step_avg:90.25ms +step:1606/1680 train_time:144938ms step_avg:90.25ms +step:1607/1680 train_time:145029ms step_avg:90.25ms +step:1608/1680 train_time:145120ms step_avg:90.25ms +step:1609/1680 train_time:145211ms step_avg:90.25ms +step:1610/1680 train_time:145302ms step_avg:90.25ms +step:1611/1680 train_time:145392ms step_avg:90.25ms +step:1612/1680 train_time:145483ms step_avg:90.25ms +step:1613/1680 train_time:145574ms step_avg:90.25ms +step:1614/1680 train_time:145665ms step_avg:90.25ms +step:1615/1680 train_time:145756ms step_avg:90.25ms +step:1616/1680 train_time:145847ms step_avg:90.25ms +step:1617/1680 train_time:145940ms step_avg:90.25ms +step:1618/1680 train_time:146035ms step_avg:90.26ms +step:1619/1680 train_time:146121ms step_avg:90.25ms +step:1620/1680 train_time:146211ms step_avg:90.25ms +step:1621/1680 train_time:146302ms step_avg:90.25ms +step:1622/1680 train_time:146393ms step_avg:90.25ms +step:1623/1680 train_time:146483ms step_avg:90.25ms +step:1624/1680 train_time:146575ms step_avg:90.26ms +step:1625/1680 train_time:146666ms step_avg:90.26ms +step:1625/1680 val_loss:3.2913 train_time:146759ms step_avg:90.31ms +step:1626/1680 train_time:146783ms step_avg:90.27ms +step:1627/1680 train_time:146854ms step_avg:90.26ms +step:1628/1680 train_time:146948ms step_avg:90.26ms +step:1629/1680 train_time:147040ms step_avg:90.26ms +step:1630/1680 train_time:147130ms step_avg:90.26ms +step:1631/1680 train_time:147220ms step_avg:90.26ms +step:1632/1680 train_time:147310ms step_avg:90.26ms +step:1633/1680 train_time:147400ms step_avg:90.26ms +step:1634/1680 train_time:147491ms step_avg:90.26ms +step:1635/1680 train_time:147580ms step_avg:90.26ms +step:1636/1680 train_time:147671ms step_avg:90.26ms +step:1637/1680 train_time:147764ms step_avg:90.27ms +step:1638/1680 train_time:147856ms step_avg:90.27ms +step:1639/1680 train_time:147949ms step_avg:90.27ms +step:1640/1680 train_time:148042ms step_avg:90.27ms +step:1641/1680 train_time:148133ms step_avg:90.27ms +step:1642/1680 train_time:148224ms step_avg:90.27ms +step:1643/1680 train_time:148314ms 
step_avg:90.27ms +step:1644/1680 train_time:148404ms step_avg:90.27ms +step:1645/1680 train_time:148495ms step_avg:90.27ms +step:1646/1680 train_time:148585ms step_avg:90.27ms +step:1647/1680 train_time:148676ms step_avg:90.27ms +step:1648/1680 train_time:148768ms step_avg:90.27ms +step:1649/1680 train_time:148859ms step_avg:90.27ms +step:1650/1680 train_time:148951ms step_avg:90.27ms +step:1651/1680 train_time:149042ms step_avg:90.27ms +step:1652/1680 train_time:149133ms step_avg:90.27ms +step:1653/1680 train_time:149224ms step_avg:90.27ms +step:1654/1680 train_time:149314ms step_avg:90.27ms +step:1655/1680 train_time:149404ms step_avg:90.27ms +step:1656/1680 train_time:149495ms step_avg:90.27ms +step:1657/1680 train_time:149586ms step_avg:90.28ms +step:1658/1680 train_time:149677ms step_avg:90.28ms +step:1659/1680 train_time:149768ms step_avg:90.28ms +step:1660/1680 train_time:149860ms step_avg:90.28ms +step:1661/1680 train_time:149951ms step_avg:90.28ms +step:1662/1680 train_time:150042ms step_avg:90.28ms +step:1663/1680 train_time:150133ms step_avg:90.28ms +step:1664/1680 train_time:150223ms step_avg:90.28ms +step:1665/1680 train_time:150313ms step_avg:90.28ms +step:1666/1680 train_time:150404ms step_avg:90.28ms +step:1667/1680 train_time:150493ms step_avg:90.28ms +step:1668/1680 train_time:150585ms step_avg:90.28ms +step:1669/1680 train_time:150679ms step_avg:90.28ms +step:1670/1680 train_time:150768ms step_avg:90.28ms +step:1671/1680 train_time:150860ms step_avg:90.28ms +step:1672/1680 train_time:150951ms step_avg:90.28ms +step:1673/1680 train_time:151042ms step_avg:90.28ms +step:1674/1680 train_time:151133ms step_avg:90.28ms +step:1675/1680 train_time:151223ms step_avg:90.28ms +step:1676/1680 train_time:151315ms step_avg:90.28ms +step:1677/1680 train_time:151405ms step_avg:90.28ms +step:1678/1680 train_time:151496ms step_avg:90.28ms +step:1679/1680 train_time:151587ms step_avg:90.28ms +step:1680/1680 train_time:151678ms step_avg:90.28ms +step:1680/1680 val_loss:3.2801 train_time:151771ms step_avg:90.34ms +peak memory allocated: 31255 MiB reserved: 46654 MiB diff --git a/records/092125_DropAttn/d511e5c8-cce8-43ff-bac2-5168366ba47c.txt b/records/092125_DropAttn/d511e5c8-cce8-43ff-bac2-5168366ba47c.txt new file mode 100644 index 000000000..d10a8223d --- /dev/null +++ b/records/092125_DropAttn/d511e5c8-cce8-43ff-bac2-5168366ba47c.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
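+        # Overall flow: (1) reduce_scatter averages the stacked grads so that each
+        # rank ends up holding the gradients for one chunk of the parameters,
+        # (2) each rank runs Newton-Schulz on its chunk only, and (3) all_gather
+        # redistributes the updated parameters. Every rank therefore
+        # orthogonalizes only ~1/world_size of the matrices while staying in sync.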
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor as zeropower input so the batch stays rectangular.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
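+                # What follows is standard SGD-momentum with a Nesterov-style blend:
+                # the buffer is updated in place, then the update direction is the
+                # raw grad interpolated toward the buffer by the momentum factor.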
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
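+        # Only the first len(orig_params) rows of each stacked_params are copied
+        # back below; padded rows (zero-filled by whichever rank owned them) are ignored.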
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            # greedily pack BOS-aligned documents until this rank's quota is filled
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
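+                # (The DataPreloader has been reading the next shard and indexing
+                # its BOS positions in a background thread, so this handoff is cheap.)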
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension: int = 40
+    cooldown_frac: float = 0.5 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13 # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
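+
+# A minimal sketch (not executed as part of this run) of the .send() protocol
+# that distributed_data_generator supports above: besides plain next(), a
+# caller can hand the generator a new (num_tokens, max_seq_len,
+# grad_accum_steps) tuple mid-stream. The doubled sequence length below is
+# illustrative only.
+#
+#   loader = distributed_data_generator(args.train_files, args.train_batch_size,
+#                                       args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+#   inputs, targets, cum_seqlens = next(loader)  # batch built with the old sizes
+#   inputs, targets, cum_seqlens = loader.send(
+#       (args.train_batch_size, args.train_max_seq_len * 2, grad_accum_steps))
+#   # every later next(loader) keeps using the new sizes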
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params
+    if new_ws > ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    elif new_ws < ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+    model(inputs, targets, cum_seqlens, ws, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+
+# restore the state saved above so the warmup steps don't count as training
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del initial_state
+ws = args.ws_schedule[0]
+del train_loader
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+
+########################################
+#        Training and validation       #
+########################################
+
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations + args.iteration_extension
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    new_ws, ws_final_layer = get_ws(step)
+    if new_ws != ws:
+        model.yarn.apply(ws, new_ws)
+        ws = new_ws
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+====================================================================================================
+Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
+Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6
+Running Triton version 3.4.0
+Sun Sep 21 23:21:14 2025
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 45C P0 128W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 32C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 40C P0 125W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 41C P0 126W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 40C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 84409 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 84410 C /usr/bin/python3 614MiB | +| 0 N/A N/A 84411 C /usr/bin/python3 614MiB | +| 0 N/A N/A 84412 C /usr/bin/python3 614MiB | +| 0 N/A N/A 84413 C /usr/bin/python3 614MiB | +| 0 N/A N/A 84414 C /usr/bin/python3 614MiB | +| 0 N/A N/A 84415 C /usr/bin/python3 614MiB | +| 0 N/A N/A 84416 C /usr/bin/python3 614MiB | +| 1 N/A N/A 84410 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 84411 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 84412 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 84413 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 84414 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 84415 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 84416 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:152ms step_avg:152.49ms +step:2/1680 train_time:177ms step_avg:88.30ms +step:3/1680 train_time:238ms step_avg:79.35ms +step:4/1680 train_time:326ms step_avg:81.41ms +step:5/1680 train_time:413ms step_avg:82.67ms +step:6/1680 train_time:501ms step_avg:83.56ms +step:7/1680 train_time:603ms step_avg:86.20ms +step:8/1680 train_time:693ms 
step_avg:86.57ms +step:9/1680 train_time:781ms step_avg:86.78ms +step:10/1680 train_time:869ms step_avg:86.92ms +step:11/1680 train_time:957ms step_avg:87.00ms +step:12/1680 train_time:1045ms step_avg:87.12ms +step:13/1680 train_time:1135ms step_avg:87.29ms +step:14/1680 train_time:1225ms step_avg:87.51ms +step:15/1680 train_time:1314ms step_avg:87.61ms +step:16/1680 train_time:1404ms step_avg:87.77ms +step:17/1680 train_time:1494ms step_avg:87.88ms +step:18/1680 train_time:1584ms step_avg:88.00ms +step:19/1680 train_time:1674ms step_avg:88.11ms +step:20/1680 train_time:1764ms step_avg:88.21ms +step:21/1680 train_time:1853ms step_avg:88.25ms +step:22/1680 train_time:1942ms step_avg:88.29ms +step:23/1680 train_time:2031ms step_avg:88.32ms +step:24/1680 train_time:2120ms step_avg:88.34ms +step:25/1680 train_time:2210ms step_avg:88.38ms +step:26/1680 train_time:2300ms step_avg:88.45ms +step:27/1680 train_time:2388ms step_avg:88.46ms +step:28/1680 train_time:2478ms step_avg:88.51ms +step:29/1680 train_time:2567ms step_avg:88.53ms +step:30/1680 train_time:2657ms step_avg:88.58ms +step:31/1680 train_time:2747ms step_avg:88.60ms +step:32/1680 train_time:2836ms step_avg:88.61ms +step:33/1680 train_time:2925ms step_avg:88.62ms +step:34/1680 train_time:3013ms step_avg:88.62ms +step:35/1680 train_time:3102ms step_avg:88.62ms +step:36/1680 train_time:3191ms step_avg:88.63ms +step:37/1680 train_time:3280ms step_avg:88.64ms +step:38/1680 train_time:3369ms step_avg:88.66ms +step:39/1680 train_time:3459ms step_avg:88.68ms +step:40/1680 train_time:3547ms step_avg:88.69ms +step:41/1680 train_time:3638ms step_avg:88.73ms +step:42/1680 train_time:3727ms step_avg:88.75ms +step:43/1680 train_time:3816ms step_avg:88.75ms +step:44/1680 train_time:3906ms step_avg:88.77ms +step:45/1680 train_time:3995ms step_avg:88.77ms +step:46/1680 train_time:4084ms step_avg:88.78ms +step:47/1680 train_time:4173ms step_avg:88.79ms +step:48/1680 train_time:4263ms step_avg:88.81ms +step:49/1680 train_time:4352ms step_avg:88.82ms +step:50/1680 train_time:4442ms step_avg:88.83ms +step:51/1680 train_time:4531ms step_avg:88.84ms +step:52/1680 train_time:4620ms step_avg:88.85ms +step:53/1680 train_time:4710ms step_avg:88.86ms +step:54/1680 train_time:4800ms step_avg:88.88ms +step:55/1680 train_time:4889ms step_avg:88.89ms +step:56/1680 train_time:4977ms step_avg:88.88ms +step:57/1680 train_time:5066ms step_avg:88.87ms +step:58/1680 train_time:5155ms step_avg:88.87ms +step:59/1680 train_time:5244ms step_avg:88.88ms +step:60/1680 train_time:5333ms step_avg:88.88ms +step:61/1680 train_time:5423ms step_avg:88.90ms +step:62/1680 train_time:5512ms step_avg:88.91ms +step:63/1680 train_time:5603ms step_avg:88.93ms +step:64/1680 train_time:5692ms step_avg:88.93ms +step:65/1680 train_time:5781ms step_avg:88.94ms +step:66/1680 train_time:5870ms step_avg:88.94ms +step:67/1680 train_time:5960ms step_avg:88.95ms +step:68/1680 train_time:6050ms step_avg:88.97ms +step:69/1680 train_time:6139ms step_avg:88.97ms +step:70/1680 train_time:6228ms step_avg:88.97ms +step:71/1680 train_time:6317ms step_avg:88.97ms +step:72/1680 train_time:6406ms step_avg:88.97ms +step:73/1680 train_time:6496ms step_avg:88.98ms +step:74/1680 train_time:6585ms step_avg:88.99ms +step:75/1680 train_time:6674ms step_avg:88.99ms +step:76/1680 train_time:6764ms step_avg:89.00ms +step:77/1680 train_time:6853ms step_avg:89.00ms +step:78/1680 train_time:6943ms step_avg:89.01ms +step:79/1680 train_time:7031ms step_avg:89.00ms +step:80/1680 train_time:7120ms step_avg:89.00ms +step:81/1680 
train_time:7209ms step_avg:89.00ms +step:82/1680 train_time:7298ms step_avg:89.00ms +step:83/1680 train_time:7387ms step_avg:89.00ms +step:84/1680 train_time:7476ms step_avg:89.00ms +step:85/1680 train_time:7566ms step_avg:89.01ms +step:86/1680 train_time:7655ms step_avg:89.01ms +step:87/1680 train_time:7745ms step_avg:89.03ms +step:88/1680 train_time:7834ms step_avg:89.02ms +step:89/1680 train_time:7923ms step_avg:89.02ms +step:90/1680 train_time:8012ms step_avg:89.02ms +step:91/1680 train_time:8101ms step_avg:89.02ms +step:92/1680 train_time:8190ms step_avg:89.03ms +step:93/1680 train_time:8279ms step_avg:89.02ms +step:94/1680 train_time:8368ms step_avg:89.03ms +step:95/1680 train_time:8457ms step_avg:89.03ms +step:96/1680 train_time:8547ms step_avg:89.03ms +step:97/1680 train_time:8636ms step_avg:89.03ms +step:98/1680 train_time:8725ms step_avg:89.03ms +step:99/1680 train_time:8814ms step_avg:89.03ms +step:100/1680 train_time:8903ms step_avg:89.03ms +step:101/1680 train_time:8992ms step_avg:89.03ms +step:102/1680 train_time:9081ms step_avg:89.03ms +step:103/1680 train_time:9170ms step_avg:89.03ms +step:104/1680 train_time:9258ms step_avg:89.02ms +step:105/1680 train_time:9348ms step_avg:89.03ms +step:106/1680 train_time:9437ms step_avg:89.03ms +step:107/1680 train_time:9526ms step_avg:89.03ms +step:108/1680 train_time:9614ms step_avg:89.02ms +step:109/1680 train_time:9704ms step_avg:89.03ms +step:110/1680 train_time:9793ms step_avg:89.03ms +step:111/1680 train_time:9882ms step_avg:89.03ms +step:112/1680 train_time:9971ms step_avg:89.02ms +step:113/1680 train_time:10059ms step_avg:89.02ms +step:114/1680 train_time:10148ms step_avg:89.02ms +step:115/1680 train_time:10237ms step_avg:89.01ms +step:116/1680 train_time:10326ms step_avg:89.02ms +step:117/1680 train_time:10414ms step_avg:89.01ms +step:118/1680 train_time:10503ms step_avg:89.01ms +step:119/1680 train_time:10593ms step_avg:89.01ms +step:120/1680 train_time:10681ms step_avg:89.01ms +step:121/1680 train_time:10771ms step_avg:89.01ms +step:122/1680 train_time:10860ms step_avg:89.01ms +step:123/1680 train_time:10949ms step_avg:89.02ms +step:124/1680 train_time:11038ms step_avg:89.02ms +step:125/1680 train_time:11127ms step_avg:89.01ms +step:125/1680 val_loss:4.3123 train_time:11216ms step_avg:89.73ms +step:126/1680 train_time:11239ms step_avg:89.20ms +step:127/1680 train_time:11305ms step_avg:89.02ms +step:128/1680 train_time:11402ms step_avg:89.08ms +step:129/1680 train_time:11497ms step_avg:89.13ms +step:130/1680 train_time:11588ms step_avg:89.14ms +step:131/1680 train_time:11676ms step_avg:89.13ms +step:132/1680 train_time:11763ms step_avg:89.12ms +step:133/1680 train_time:11851ms step_avg:89.11ms +step:134/1680 train_time:11939ms step_avg:89.09ms +step:135/1680 train_time:12027ms step_avg:89.09ms +step:136/1680 train_time:12114ms step_avg:89.07ms +step:137/1680 train_time:12202ms step_avg:89.06ms +step:138/1680 train_time:12291ms step_avg:89.06ms +step:139/1680 train_time:12381ms step_avg:89.07ms +step:140/1680 train_time:12472ms step_avg:89.08ms +step:141/1680 train_time:12562ms step_avg:89.09ms +step:142/1680 train_time:12652ms step_avg:89.10ms +step:143/1680 train_time:12740ms step_avg:89.09ms +step:144/1680 train_time:12828ms step_avg:89.08ms +step:145/1680 train_time:12916ms step_avg:89.08ms +step:146/1680 train_time:13004ms step_avg:89.07ms +step:147/1680 train_time:13092ms step_avg:89.06ms +step:148/1680 train_time:13180ms step_avg:89.05ms +step:149/1680 train_time:13269ms step_avg:89.06ms +step:150/1680 train_time:13358ms 
step_avg:89.05ms +step:151/1680 train_time:13448ms step_avg:89.06ms +step:152/1680 train_time:13537ms step_avg:89.06ms +step:153/1680 train_time:13626ms step_avg:89.06ms +step:154/1680 train_time:13716ms step_avg:89.06ms +step:155/1680 train_time:13804ms step_avg:89.06ms +step:156/1680 train_time:13893ms step_avg:89.06ms +step:157/1680 train_time:13981ms step_avg:89.05ms +step:158/1680 train_time:14069ms step_avg:89.04ms +step:159/1680 train_time:14157ms step_avg:89.04ms +step:160/1680 train_time:14246ms step_avg:89.04ms +step:161/1680 train_time:14335ms step_avg:89.04ms +step:162/1680 train_time:14425ms step_avg:89.04ms +step:163/1680 train_time:14514ms step_avg:89.04ms +step:164/1680 train_time:14603ms step_avg:89.04ms +step:165/1680 train_time:14692ms step_avg:89.04ms +step:166/1680 train_time:14780ms step_avg:89.04ms +step:167/1680 train_time:14869ms step_avg:89.04ms +step:168/1680 train_time:14958ms step_avg:89.03ms +step:169/1680 train_time:15047ms step_avg:89.03ms +step:170/1680 train_time:15135ms step_avg:89.03ms +step:171/1680 train_time:15224ms step_avg:89.03ms +step:172/1680 train_time:15312ms step_avg:89.02ms +step:173/1680 train_time:15400ms step_avg:89.02ms +step:174/1680 train_time:15489ms step_avg:89.02ms +step:175/1680 train_time:15578ms step_avg:89.02ms +step:176/1680 train_time:15667ms step_avg:89.02ms +step:177/1680 train_time:15757ms step_avg:89.02ms +step:178/1680 train_time:15846ms step_avg:89.02ms +step:179/1680 train_time:15934ms step_avg:89.02ms +step:180/1680 train_time:16022ms step_avg:89.01ms +step:181/1680 train_time:16111ms step_avg:89.01ms +step:182/1680 train_time:16200ms step_avg:89.01ms +step:183/1680 train_time:16289ms step_avg:89.01ms +step:184/1680 train_time:16377ms step_avg:89.01ms +step:185/1680 train_time:16465ms step_avg:89.00ms +step:186/1680 train_time:16555ms step_avg:89.00ms +step:187/1680 train_time:16644ms step_avg:89.01ms +step:188/1680 train_time:16733ms step_avg:89.01ms +step:189/1680 train_time:16822ms step_avg:89.00ms +step:190/1680 train_time:16911ms step_avg:89.01ms +step:191/1680 train_time:17000ms step_avg:89.01ms +step:192/1680 train_time:17089ms step_avg:89.00ms +step:193/1680 train_time:17177ms step_avg:89.00ms +step:194/1680 train_time:17266ms step_avg:89.00ms +step:195/1680 train_time:17355ms step_avg:89.00ms +step:196/1680 train_time:17443ms step_avg:89.00ms +step:197/1680 train_time:17533ms step_avg:89.00ms +step:198/1680 train_time:17621ms step_avg:89.00ms +step:199/1680 train_time:17710ms step_avg:89.00ms +step:200/1680 train_time:17799ms step_avg:89.00ms +step:201/1680 train_time:17887ms step_avg:88.99ms +step:202/1680 train_time:17976ms step_avg:88.99ms +step:203/1680 train_time:18064ms step_avg:88.99ms +step:204/1680 train_time:18153ms step_avg:88.99ms +step:205/1680 train_time:18242ms step_avg:88.99ms +step:206/1680 train_time:18332ms step_avg:88.99ms +step:207/1680 train_time:18420ms step_avg:88.99ms +step:208/1680 train_time:18509ms step_avg:88.98ms +step:209/1680 train_time:18599ms step_avg:88.99ms +step:210/1680 train_time:18688ms step_avg:88.99ms +step:211/1680 train_time:18777ms step_avg:88.99ms +step:212/1680 train_time:18865ms step_avg:88.99ms +step:213/1680 train_time:18954ms step_avg:88.99ms +step:214/1680 train_time:19042ms step_avg:88.98ms +step:215/1680 train_time:19131ms step_avg:88.98ms +step:216/1680 train_time:19219ms step_avg:88.98ms +step:217/1680 train_time:19308ms step_avg:88.98ms +step:218/1680 train_time:19396ms step_avg:88.97ms +step:219/1680 train_time:19485ms step_avg:88.97ms +step:220/1680 
train_time:19574ms step_avg:88.97ms +step:221/1680 train_time:19663ms step_avg:88.97ms +step:222/1680 train_time:19752ms step_avg:88.97ms +step:223/1680 train_time:19840ms step_avg:88.97ms +step:224/1680 train_time:19930ms step_avg:88.97ms +step:225/1680 train_time:20018ms step_avg:88.97ms +step:226/1680 train_time:20108ms step_avg:88.97ms +step:227/1680 train_time:20196ms step_avg:88.97ms +step:228/1680 train_time:20285ms step_avg:88.97ms +step:229/1680 train_time:20374ms step_avg:88.97ms +step:230/1680 train_time:20462ms step_avg:88.97ms +step:231/1680 train_time:20552ms step_avg:88.97ms +step:232/1680 train_time:20640ms step_avg:88.96ms +step:233/1680 train_time:20728ms step_avg:88.96ms +step:234/1680 train_time:20817ms step_avg:88.96ms +step:235/1680 train_time:20906ms step_avg:88.96ms +step:236/1680 train_time:20995ms step_avg:88.96ms +step:237/1680 train_time:21084ms step_avg:88.96ms +step:238/1680 train_time:21173ms step_avg:88.96ms +step:239/1680 train_time:21261ms step_avg:88.96ms +step:240/1680 train_time:21350ms step_avg:88.96ms +step:241/1680 train_time:21438ms step_avg:88.96ms +step:242/1680 train_time:21527ms step_avg:88.96ms +step:243/1680 train_time:21616ms step_avg:88.96ms +step:244/1680 train_time:21705ms step_avg:88.95ms +step:245/1680 train_time:21794ms step_avg:88.96ms +step:246/1680 train_time:21883ms step_avg:88.96ms +step:247/1680 train_time:21972ms step_avg:88.96ms +step:248/1680 train_time:22061ms step_avg:88.95ms +step:249/1680 train_time:22149ms step_avg:88.95ms +step:250/1680 train_time:22238ms step_avg:88.95ms +step:250/1680 val_loss:3.9734 train_time:22328ms step_avg:89.31ms +step:251/1680 train_time:22351ms step_avg:89.05ms +step:252/1680 train_time:22422ms step_avg:88.98ms +step:253/1680 train_time:22516ms step_avg:89.00ms +step:254/1680 train_time:22606ms step_avg:89.00ms +step:255/1680 train_time:22695ms step_avg:89.00ms +step:256/1680 train_time:22782ms step_avg:88.99ms +step:257/1680 train_time:22870ms step_avg:88.99ms +step:258/1680 train_time:22959ms step_avg:88.99ms +step:259/1680 train_time:23047ms step_avg:88.98ms +step:260/1680 train_time:23134ms step_avg:88.98ms +step:261/1680 train_time:23222ms step_avg:88.97ms +step:262/1680 train_time:23312ms step_avg:88.98ms +step:263/1680 train_time:23403ms step_avg:88.98ms +step:264/1680 train_time:23494ms step_avg:88.99ms +step:265/1680 train_time:23584ms step_avg:89.00ms +step:266/1680 train_time:23674ms step_avg:89.00ms +step:267/1680 train_time:23761ms step_avg:88.99ms +step:268/1680 train_time:23850ms step_avg:88.99ms +step:269/1680 train_time:23938ms step_avg:88.99ms +step:270/1680 train_time:24026ms step_avg:88.99ms +step:271/1680 train_time:24114ms step_avg:88.98ms +step:272/1680 train_time:24202ms step_avg:88.98ms +step:273/1680 train_time:24292ms step_avg:88.98ms +step:274/1680 train_time:24382ms step_avg:88.99ms +step:275/1680 train_time:24472ms step_avg:88.99ms +step:276/1680 train_time:24561ms step_avg:88.99ms +step:277/1680 train_time:24651ms step_avg:88.99ms +step:278/1680 train_time:24739ms step_avg:88.99ms +step:279/1680 train_time:24828ms step_avg:88.99ms +step:280/1680 train_time:24916ms step_avg:88.99ms +step:281/1680 train_time:25006ms step_avg:88.99ms +step:282/1680 train_time:25094ms step_avg:88.99ms +step:283/1680 train_time:25182ms step_avg:88.98ms +step:284/1680 train_time:25272ms step_avg:88.99ms +step:285/1680 train_time:25361ms step_avg:88.99ms +step:286/1680 train_time:25451ms step_avg:88.99ms +step:287/1680 train_time:25540ms step_avg:88.99ms +step:288/1680 train_time:25630ms 
step_avg:88.99ms +step:289/1680 train_time:25719ms step_avg:88.99ms +step:290/1680 train_time:25808ms step_avg:88.99ms +step:291/1680 train_time:25897ms step_avg:88.99ms +step:292/1680 train_time:25985ms step_avg:88.99ms +step:293/1680 train_time:26073ms step_avg:88.99ms +step:294/1680 train_time:26161ms step_avg:88.98ms +step:295/1680 train_time:26250ms step_avg:88.98ms +step:296/1680 train_time:26338ms step_avg:88.98ms +step:297/1680 train_time:26427ms step_avg:88.98ms +step:298/1680 train_time:26516ms step_avg:88.98ms +step:299/1680 train_time:26605ms step_avg:88.98ms +step:300/1680 train_time:26694ms step_avg:88.98ms +step:301/1680 train_time:26783ms step_avg:88.98ms +step:302/1680 train_time:26872ms step_avg:88.98ms +step:303/1680 train_time:26960ms step_avg:88.98ms +step:304/1680 train_time:27049ms step_avg:88.98ms +step:305/1680 train_time:27137ms step_avg:88.97ms +step:306/1680 train_time:27226ms step_avg:88.97ms +step:307/1680 train_time:27315ms step_avg:88.97ms +step:308/1680 train_time:27403ms step_avg:88.97ms +step:309/1680 train_time:27491ms step_avg:88.97ms +step:310/1680 train_time:27580ms step_avg:88.97ms +step:311/1680 train_time:27669ms step_avg:88.97ms +step:312/1680 train_time:27758ms step_avg:88.97ms +step:313/1680 train_time:27847ms step_avg:88.97ms +step:314/1680 train_time:27936ms step_avg:88.97ms +step:315/1680 train_time:28024ms step_avg:88.97ms +step:316/1680 train_time:28114ms step_avg:88.97ms +step:317/1680 train_time:28202ms step_avg:88.96ms +step:318/1680 train_time:28291ms step_avg:88.97ms +step:319/1680 train_time:28380ms step_avg:88.96ms +step:320/1680 train_time:28469ms step_avg:88.97ms +step:321/1680 train_time:28559ms step_avg:88.97ms +step:322/1680 train_time:28647ms step_avg:88.97ms +step:323/1680 train_time:28736ms step_avg:88.97ms +step:324/1680 train_time:28825ms step_avg:88.97ms +step:325/1680 train_time:28914ms step_avg:88.97ms +step:326/1680 train_time:29002ms step_avg:88.96ms +step:327/1680 train_time:29092ms step_avg:88.97ms +step:328/1680 train_time:29180ms step_avg:88.96ms +step:329/1680 train_time:29269ms step_avg:88.96ms +step:330/1680 train_time:29358ms step_avg:88.96ms +step:331/1680 train_time:29446ms step_avg:88.96ms +step:332/1680 train_time:29535ms step_avg:88.96ms +step:333/1680 train_time:29624ms step_avg:88.96ms +step:334/1680 train_time:29713ms step_avg:88.96ms +step:335/1680 train_time:29802ms step_avg:88.96ms +step:336/1680 train_time:29891ms step_avg:88.96ms +step:337/1680 train_time:29979ms step_avg:88.96ms +step:338/1680 train_time:30068ms step_avg:88.96ms +step:339/1680 train_time:30157ms step_avg:88.96ms +step:340/1680 train_time:30246ms step_avg:88.96ms +step:341/1680 train_time:30336ms step_avg:88.96ms +step:342/1680 train_time:30425ms step_avg:88.96ms +step:343/1680 train_time:30514ms step_avg:88.96ms +step:344/1680 train_time:30603ms step_avg:88.96ms +step:345/1680 train_time:30692ms step_avg:88.96ms +step:346/1680 train_time:30781ms step_avg:88.96ms +step:347/1680 train_time:30870ms step_avg:88.96ms +step:348/1680 train_time:30959ms step_avg:88.96ms +step:349/1680 train_time:31048ms step_avg:88.96ms +step:350/1680 train_time:31139ms step_avg:88.97ms +step:351/1680 train_time:31225ms step_avg:88.96ms +step:352/1680 train_time:31314ms step_avg:88.96ms +step:353/1680 train_time:31403ms step_avg:88.96ms +step:354/1680 train_time:31493ms step_avg:88.96ms +step:355/1680 train_time:31581ms step_avg:88.96ms +step:356/1680 train_time:31670ms step_avg:88.96ms +step:357/1680 train_time:31759ms step_avg:88.96ms +step:358/1680 
train_time:31848ms step_avg:88.96ms +step:359/1680 train_time:31936ms step_avg:88.96ms +step:360/1680 train_time:32025ms step_avg:88.96ms +step:361/1680 train_time:32115ms step_avg:88.96ms +step:362/1680 train_time:32203ms step_avg:88.96ms +step:363/1680 train_time:32292ms step_avg:88.96ms +step:364/1680 train_time:32381ms step_avg:88.96ms +step:365/1680 train_time:32469ms step_avg:88.96ms +step:366/1680 train_time:32558ms step_avg:88.96ms +step:367/1680 train_time:32648ms step_avg:88.96ms +step:368/1680 train_time:32737ms step_avg:88.96ms +step:369/1680 train_time:32826ms step_avg:88.96ms +step:370/1680 train_time:32915ms step_avg:88.96ms +step:371/1680 train_time:33003ms step_avg:88.96ms +step:372/1680 train_time:33092ms step_avg:88.96ms +step:373/1680 train_time:33181ms step_avg:88.96ms +step:374/1680 train_time:33270ms step_avg:88.96ms +step:375/1680 train_time:33359ms step_avg:88.96ms +step:375/1680 val_loss:3.8267 train_time:33449ms step_avg:89.20ms +step:376/1680 train_time:33472ms step_avg:89.02ms +step:377/1680 train_time:33543ms step_avg:88.97ms +step:378/1680 train_time:33640ms step_avg:88.99ms +step:379/1680 train_time:33729ms step_avg:88.99ms +step:380/1680 train_time:33818ms step_avg:88.99ms +step:381/1680 train_time:33905ms step_avg:88.99ms +step:382/1680 train_time:33993ms step_avg:88.99ms +step:383/1680 train_time:34081ms step_avg:88.99ms +step:384/1680 train_time:34170ms step_avg:88.98ms +step:385/1680 train_time:34258ms step_avg:88.98ms +step:386/1680 train_time:34346ms step_avg:88.98ms +step:387/1680 train_time:34435ms step_avg:88.98ms +step:388/1680 train_time:34528ms step_avg:88.99ms +step:389/1680 train_time:34619ms step_avg:88.99ms +step:390/1680 train_time:34708ms step_avg:89.00ms +step:391/1680 train_time:34798ms step_avg:89.00ms +step:392/1680 train_time:34887ms step_avg:89.00ms +step:393/1680 train_time:34975ms step_avg:89.00ms +step:394/1680 train_time:35063ms step_avg:88.99ms +step:395/1680 train_time:35152ms step_avg:88.99ms +step:396/1680 train_time:35240ms step_avg:88.99ms +step:397/1680 train_time:35328ms step_avg:88.99ms +step:398/1680 train_time:35417ms step_avg:88.99ms +step:399/1680 train_time:35507ms step_avg:88.99ms +step:400/1680 train_time:35598ms step_avg:89.00ms +step:401/1680 train_time:35688ms step_avg:89.00ms +step:402/1680 train_time:35777ms step_avg:89.00ms +step:403/1680 train_time:35867ms step_avg:89.00ms +step:404/1680 train_time:35957ms step_avg:89.00ms +step:405/1680 train_time:36045ms step_avg:89.00ms +step:406/1680 train_time:36134ms step_avg:89.00ms +step:407/1680 train_time:36222ms step_avg:89.00ms +step:408/1680 train_time:36310ms step_avg:88.99ms +step:409/1680 train_time:36398ms step_avg:88.99ms +step:410/1680 train_time:36487ms step_avg:88.99ms +step:411/1680 train_time:36577ms step_avg:88.99ms +step:412/1680 train_time:36666ms step_avg:89.00ms +step:413/1680 train_time:36756ms step_avg:89.00ms +step:414/1680 train_time:36845ms step_avg:89.00ms +step:415/1680 train_time:36934ms step_avg:89.00ms +step:416/1680 train_time:37023ms step_avg:89.00ms +step:417/1680 train_time:37111ms step_avg:89.00ms +step:418/1680 train_time:37200ms step_avg:88.99ms +step:419/1680 train_time:37288ms step_avg:88.99ms +step:420/1680 train_time:37377ms step_avg:88.99ms +step:421/1680 train_time:37466ms step_avg:88.99ms +step:422/1680 train_time:37555ms step_avg:88.99ms +step:423/1680 train_time:37645ms step_avg:88.99ms +step:424/1680 train_time:37735ms step_avg:89.00ms +step:425/1680 train_time:37824ms step_avg:89.00ms +step:426/1680 train_time:37913ms 
step_avg:89.00ms +step:427/1680 train_time:38003ms step_avg:89.00ms +step:428/1680 train_time:38092ms step_avg:89.00ms +step:429/1680 train_time:38181ms step_avg:89.00ms +step:430/1680 train_time:38270ms step_avg:89.00ms +step:431/1680 train_time:38360ms step_avg:89.00ms +step:432/1680 train_time:38449ms step_avg:89.00ms +step:433/1680 train_time:38538ms step_avg:89.00ms +step:434/1680 train_time:38627ms step_avg:89.00ms +step:435/1680 train_time:38716ms step_avg:89.00ms +step:436/1680 train_time:38806ms step_avg:89.00ms +step:437/1680 train_time:38895ms step_avg:89.01ms +step:438/1680 train_time:38990ms step_avg:89.02ms +step:439/1680 train_time:39074ms step_avg:89.01ms +step:440/1680 train_time:39163ms step_avg:89.01ms +step:441/1680 train_time:39252ms step_avg:89.01ms +step:442/1680 train_time:39341ms step_avg:89.01ms +step:443/1680 train_time:39430ms step_avg:89.01ms +step:444/1680 train_time:39519ms step_avg:89.01ms +step:445/1680 train_time:39607ms step_avg:89.00ms +step:446/1680 train_time:39696ms step_avg:89.00ms +step:447/1680 train_time:39785ms step_avg:89.00ms +step:448/1680 train_time:39874ms step_avg:89.01ms +step:449/1680 train_time:39964ms step_avg:89.01ms +step:450/1680 train_time:40053ms step_avg:89.01ms +step:451/1680 train_time:40142ms step_avg:89.01ms +step:452/1680 train_time:40231ms step_avg:89.01ms +step:453/1680 train_time:40320ms step_avg:89.01ms +step:454/1680 train_time:40409ms step_avg:89.01ms +step:455/1680 train_time:40500ms step_avg:89.01ms +step:456/1680 train_time:40586ms step_avg:89.01ms +step:457/1680 train_time:40675ms step_avg:89.00ms +step:458/1680 train_time:40764ms step_avg:89.00ms +step:459/1680 train_time:40854ms step_avg:89.01ms +step:460/1680 train_time:40942ms step_avg:89.01ms +step:461/1680 train_time:41031ms step_avg:89.01ms +step:462/1680 train_time:41121ms step_avg:89.01ms +step:463/1680 train_time:41210ms step_avg:89.01ms +step:464/1680 train_time:41300ms step_avg:89.01ms +step:465/1680 train_time:41389ms step_avg:89.01ms +step:466/1680 train_time:41479ms step_avg:89.01ms +step:467/1680 train_time:41567ms step_avg:89.01ms +step:468/1680 train_time:41656ms step_avg:89.01ms +step:469/1680 train_time:41744ms step_avg:89.01ms +step:470/1680 train_time:41833ms step_avg:89.01ms +step:471/1680 train_time:41922ms step_avg:89.01ms +step:472/1680 train_time:42011ms step_avg:89.01ms +step:473/1680 train_time:42101ms step_avg:89.01ms +step:474/1680 train_time:42190ms step_avg:89.01ms +step:475/1680 train_time:42279ms step_avg:89.01ms +step:476/1680 train_time:42368ms step_avg:89.01ms +step:477/1680 train_time:42457ms step_avg:89.01ms +step:478/1680 train_time:42546ms step_avg:89.01ms +step:479/1680 train_time:42635ms step_avg:89.01ms +step:480/1680 train_time:42724ms step_avg:89.01ms +step:481/1680 train_time:42813ms step_avg:89.01ms +step:482/1680 train_time:42902ms step_avg:89.01ms +step:483/1680 train_time:42991ms step_avg:89.01ms +step:484/1680 train_time:43080ms step_avg:89.01ms +step:485/1680 train_time:43169ms step_avg:89.01ms +step:486/1680 train_time:43259ms step_avg:89.01ms +step:487/1680 train_time:43347ms step_avg:89.01ms +step:488/1680 train_time:43437ms step_avg:89.01ms +step:489/1680 train_time:43525ms step_avg:89.01ms +step:490/1680 train_time:43614ms step_avg:89.01ms +step:491/1680 train_time:43704ms step_avg:89.01ms +step:492/1680 train_time:43793ms step_avg:89.01ms +step:493/1680 train_time:43882ms step_avg:89.01ms +step:494/1680 train_time:43971ms step_avg:89.01ms +step:495/1680 train_time:44061ms step_avg:89.01ms +step:496/1680 
train_time:44150ms step_avg:89.01ms +step:497/1680 train_time:44239ms step_avg:89.01ms +step:498/1680 train_time:44328ms step_avg:89.01ms +step:499/1680 train_time:44417ms step_avg:89.01ms +step:500/1680 train_time:44506ms step_avg:89.01ms +step:500/1680 val_loss:3.7202 train_time:44596ms step_avg:89.19ms +step:501/1680 train_time:44619ms step_avg:89.06ms +step:502/1680 train_time:44688ms step_avg:89.02ms +step:503/1680 train_time:44781ms step_avg:89.03ms +step:504/1680 train_time:44872ms step_avg:89.03ms +step:505/1680 train_time:44961ms step_avg:89.03ms +step:506/1680 train_time:45049ms step_avg:89.03ms +step:507/1680 train_time:45137ms step_avg:89.03ms +step:508/1680 train_time:45225ms step_avg:89.03ms +step:509/1680 train_time:45313ms step_avg:89.02ms +step:510/1680 train_time:45401ms step_avg:89.02ms +step:511/1680 train_time:45489ms step_avg:89.02ms +step:512/1680 train_time:45579ms step_avg:89.02ms +step:513/1680 train_time:45668ms step_avg:89.02ms +step:514/1680 train_time:45759ms step_avg:89.03ms +step:515/1680 train_time:45850ms step_avg:89.03ms +step:516/1680 train_time:45939ms step_avg:89.03ms +step:517/1680 train_time:46029ms step_avg:89.03ms +step:518/1680 train_time:46117ms step_avg:89.03ms +step:519/1680 train_time:46206ms step_avg:89.03ms +step:520/1680 train_time:46295ms step_avg:89.03ms +step:521/1680 train_time:46384ms step_avg:89.03ms +step:522/1680 train_time:46472ms step_avg:89.03ms +step:523/1680 train_time:46561ms step_avg:89.03ms +step:524/1680 train_time:46651ms step_avg:89.03ms +step:525/1680 train_time:46741ms step_avg:89.03ms +step:526/1680 train_time:46830ms step_avg:89.03ms +step:527/1680 train_time:46920ms step_avg:89.03ms +step:528/1680 train_time:47009ms step_avg:89.03ms +step:529/1680 train_time:47098ms step_avg:89.03ms +step:530/1680 train_time:47187ms step_avg:89.03ms +step:531/1680 train_time:47275ms step_avg:89.03ms +step:532/1680 train_time:47364ms step_avg:89.03ms +step:533/1680 train_time:47453ms step_avg:89.03ms +step:534/1680 train_time:47541ms step_avg:89.03ms +step:535/1680 train_time:47631ms step_avg:89.03ms +step:536/1680 train_time:47720ms step_avg:89.03ms +step:537/1680 train_time:47809ms step_avg:89.03ms +step:538/1680 train_time:47898ms step_avg:89.03ms +step:539/1680 train_time:47988ms step_avg:89.03ms +step:540/1680 train_time:48077ms step_avg:89.03ms +step:541/1680 train_time:48165ms step_avg:89.03ms +step:542/1680 train_time:48254ms step_avg:89.03ms +step:543/1680 train_time:48342ms step_avg:89.03ms +step:544/1680 train_time:48432ms step_avg:89.03ms +step:545/1680 train_time:48521ms step_avg:89.03ms +step:546/1680 train_time:48611ms step_avg:89.03ms +step:547/1680 train_time:48700ms step_avg:89.03ms +step:548/1680 train_time:48789ms step_avg:89.03ms +step:549/1680 train_time:48879ms step_avg:89.03ms +step:550/1680 train_time:48970ms step_avg:89.04ms +step:551/1680 train_time:49061ms step_avg:89.04ms +step:552/1680 train_time:49151ms step_avg:89.04ms +step:553/1680 train_time:49240ms step_avg:89.04ms +step:554/1680 train_time:49331ms step_avg:89.04ms +step:555/1680 train_time:49421ms step_avg:89.05ms +step:556/1680 train_time:49512ms step_avg:89.05ms +step:557/1680 train_time:49602ms step_avg:89.05ms +step:558/1680 train_time:49692ms step_avg:89.05ms +step:559/1680 train_time:49782ms step_avg:89.06ms +step:560/1680 train_time:49872ms step_avg:89.06ms +step:561/1680 train_time:49963ms step_avg:89.06ms +step:562/1680 train_time:50056ms step_avg:89.07ms +step:563/1680 train_time:50144ms step_avg:89.07ms +step:564/1680 train_time:50234ms 
step_avg:89.07ms +step:565/1680 train_time:50324ms step_avg:89.07ms +step:566/1680 train_time:50416ms step_avg:89.07ms +step:567/1680 train_time:50506ms step_avg:89.08ms +step:568/1680 train_time:50596ms step_avg:89.08ms +step:569/1680 train_time:50687ms step_avg:89.08ms +step:570/1680 train_time:50776ms step_avg:89.08ms +step:571/1680 train_time:50867ms step_avg:89.08ms +step:572/1680 train_time:50958ms step_avg:89.09ms +step:573/1680 train_time:51049ms step_avg:89.09ms +step:574/1680 train_time:51139ms step_avg:89.09ms +step:575/1680 train_time:51230ms step_avg:89.10ms +step:576/1680 train_time:51320ms step_avg:89.10ms +step:577/1680 train_time:51411ms step_avg:89.10ms +step:578/1680 train_time:51501ms step_avg:89.10ms +step:579/1680 train_time:51591ms step_avg:89.10ms +step:580/1680 train_time:51682ms step_avg:89.11ms +step:581/1680 train_time:51772ms step_avg:89.11ms +step:582/1680 train_time:51862ms step_avg:89.11ms +step:583/1680 train_time:51953ms step_avg:89.11ms +step:584/1680 train_time:52043ms step_avg:89.11ms +step:585/1680 train_time:52133ms step_avg:89.12ms +step:586/1680 train_time:52224ms step_avg:89.12ms +step:587/1680 train_time:52315ms step_avg:89.12ms +step:588/1680 train_time:52404ms step_avg:89.12ms +step:589/1680 train_time:52495ms step_avg:89.13ms +step:590/1680 train_time:52586ms step_avg:89.13ms +step:591/1680 train_time:52676ms step_avg:89.13ms +step:592/1680 train_time:52767ms step_avg:89.13ms +step:593/1680 train_time:52857ms step_avg:89.14ms +step:594/1680 train_time:52947ms step_avg:89.14ms +step:595/1680 train_time:53038ms step_avg:89.14ms +step:596/1680 train_time:53128ms step_avg:89.14ms +step:597/1680 train_time:53219ms step_avg:89.14ms +step:598/1680 train_time:53309ms step_avg:89.15ms +step:599/1680 train_time:53399ms step_avg:89.15ms +step:600/1680 train_time:53489ms step_avg:89.15ms +step:601/1680 train_time:53579ms step_avg:89.15ms +step:602/1680 train_time:53669ms step_avg:89.15ms +step:603/1680 train_time:53760ms step_avg:89.15ms +step:604/1680 train_time:53850ms step_avg:89.16ms +step:605/1680 train_time:53939ms step_avg:89.16ms +step:606/1680 train_time:54031ms step_avg:89.16ms +step:607/1680 train_time:54121ms step_avg:89.16ms +step:608/1680 train_time:54211ms step_avg:89.16ms +step:609/1680 train_time:54301ms step_avg:89.16ms +step:610/1680 train_time:54391ms step_avg:89.17ms +step:611/1680 train_time:54481ms step_avg:89.17ms +step:612/1680 train_time:54572ms step_avg:89.17ms +step:613/1680 train_time:54662ms step_avg:89.17ms +step:614/1680 train_time:54754ms step_avg:89.18ms +step:615/1680 train_time:54844ms step_avg:89.18ms +step:616/1680 train_time:54935ms step_avg:89.18ms +step:617/1680 train_time:55025ms step_avg:89.18ms +step:618/1680 train_time:55116ms step_avg:89.18ms +step:619/1680 train_time:55206ms step_avg:89.19ms +step:620/1680 train_time:55297ms step_avg:89.19ms +step:621/1680 train_time:55386ms step_avg:89.19ms +step:622/1680 train_time:55477ms step_avg:89.19ms +step:623/1680 train_time:55566ms step_avg:89.19ms +step:624/1680 train_time:55657ms step_avg:89.19ms +step:625/1680 train_time:55748ms step_avg:89.20ms +step:625/1680 val_loss:3.6196 train_time:55839ms step_avg:89.34ms +step:626/1680 train_time:55862ms step_avg:89.24ms +step:627/1680 train_time:55931ms step_avg:89.20ms +step:628/1680 train_time:56031ms step_avg:89.22ms +step:629/1680 train_time:56122ms step_avg:89.22ms +step:630/1680 train_time:56211ms step_avg:89.22ms +step:631/1680 train_time:56301ms step_avg:89.22ms +step:632/1680 train_time:56390ms step_avg:89.23ms 
+step:633/1680 train_time:56480ms step_avg:89.23ms +step:634/1680 train_time:56568ms step_avg:89.22ms +step:635/1680 train_time:56657ms step_avg:89.22ms +step:636/1680 train_time:56747ms step_avg:89.22ms +step:637/1680 train_time:56843ms step_avg:89.24ms +step:638/1680 train_time:56936ms step_avg:89.24ms +step:639/1680 train_time:57027ms step_avg:89.24ms +step:640/1680 train_time:57117ms step_avg:89.25ms +step:641/1680 train_time:57207ms step_avg:89.25ms +step:642/1680 train_time:57297ms step_avg:89.25ms +step:643/1680 train_time:57387ms step_avg:89.25ms +step:644/1680 train_time:57476ms step_avg:89.25ms +step:645/1680 train_time:57565ms step_avg:89.25ms +step:646/1680 train_time:57654ms step_avg:89.25ms +step:647/1680 train_time:57744ms step_avg:89.25ms +step:648/1680 train_time:57836ms step_avg:89.25ms +step:649/1680 train_time:57927ms step_avg:89.26ms +step:650/1680 train_time:58018ms step_avg:89.26ms +step:651/1680 train_time:58108ms step_avg:89.26ms +step:652/1680 train_time:58199ms step_avg:89.26ms +step:653/1680 train_time:58289ms step_avg:89.26ms +step:654/1680 train_time:58379ms step_avg:89.26ms +step:655/1680 train_time:58468ms step_avg:89.26ms +step:656/1680 train_time:58557ms step_avg:89.26ms +step:657/1680 train_time:58646ms step_avg:89.26ms +step:658/1680 train_time:58738ms step_avg:89.27ms +step:659/1680 train_time:58826ms step_avg:89.27ms +step:660/1680 train_time:58918ms step_avg:89.27ms +step:661/1680 train_time:59008ms step_avg:89.27ms +step:662/1680 train_time:59098ms step_avg:89.27ms +step:663/1680 train_time:59189ms step_avg:89.27ms +step:664/1680 train_time:59279ms step_avg:89.28ms +step:665/1680 train_time:59369ms step_avg:89.28ms +step:666/1680 train_time:59459ms step_avg:89.28ms +step:667/1680 train_time:59549ms step_avg:89.28ms +step:668/1680 train_time:59639ms step_avg:89.28ms +step:669/1680 train_time:59729ms step_avg:89.28ms +step:670/1680 train_time:59819ms step_avg:89.28ms +step:671/1680 train_time:59909ms step_avg:89.28ms +step:672/1680 train_time:60000ms step_avg:89.29ms +step:673/1680 train_time:60090ms step_avg:89.29ms +step:674/1680 train_time:60181ms step_avg:89.29ms +step:675/1680 train_time:60271ms step_avg:89.29ms +step:676/1680 train_time:60361ms step_avg:89.29ms +step:677/1680 train_time:60451ms step_avg:89.29ms +step:678/1680 train_time:60542ms step_avg:89.29ms +step:679/1680 train_time:60632ms step_avg:89.30ms +step:680/1680 train_time:60723ms step_avg:89.30ms +step:681/1680 train_time:60813ms step_avg:89.30ms +step:682/1680 train_time:60904ms step_avg:89.30ms +step:683/1680 train_time:60994ms step_avg:89.30ms +step:684/1680 train_time:61083ms step_avg:89.30ms +step:685/1680 train_time:61175ms step_avg:89.31ms +step:686/1680 train_time:61265ms step_avg:89.31ms +step:687/1680 train_time:61355ms step_avg:89.31ms +step:688/1680 train_time:61445ms step_avg:89.31ms +step:689/1680 train_time:61535ms step_avg:89.31ms +step:690/1680 train_time:61625ms step_avg:89.31ms +step:691/1680 train_time:61715ms step_avg:89.31ms +step:692/1680 train_time:61805ms step_avg:89.31ms +step:693/1680 train_time:61894ms step_avg:89.31ms +step:694/1680 train_time:61984ms step_avg:89.31ms +step:695/1680 train_time:62075ms step_avg:89.32ms +step:696/1680 train_time:62166ms step_avg:89.32ms +step:697/1680 train_time:62255ms step_avg:89.32ms +step:698/1680 train_time:62346ms step_avg:89.32ms +step:699/1680 train_time:62436ms step_avg:89.32ms +step:700/1680 train_time:62526ms step_avg:89.32ms +step:701/1680 train_time:62616ms step_avg:89.32ms +step:702/1680 train_time:62707ms 
step_avg:89.33ms +step:703/1680 train_time:62797ms step_avg:89.33ms +step:704/1680 train_time:62887ms step_avg:89.33ms +step:705/1680 train_time:62977ms step_avg:89.33ms +step:706/1680 train_time:63067ms step_avg:89.33ms +step:707/1680 train_time:63157ms step_avg:89.33ms +step:708/1680 train_time:63248ms step_avg:89.33ms +step:709/1680 train_time:63339ms step_avg:89.34ms +step:710/1680 train_time:63428ms step_avg:89.34ms +step:711/1680 train_time:63519ms step_avg:89.34ms +step:712/1680 train_time:63609ms step_avg:89.34ms +step:713/1680 train_time:63699ms step_avg:89.34ms +step:714/1680 train_time:63788ms step_avg:89.34ms +step:715/1680 train_time:63878ms step_avg:89.34ms +step:716/1680 train_time:63969ms step_avg:89.34ms +step:717/1680 train_time:64059ms step_avg:89.34ms +step:718/1680 train_time:64149ms step_avg:89.34ms +step:719/1680 train_time:64241ms step_avg:89.35ms +step:720/1680 train_time:64331ms step_avg:89.35ms +step:721/1680 train_time:64421ms step_avg:89.35ms +step:722/1680 train_time:64512ms step_avg:89.35ms +step:723/1680 train_time:64603ms step_avg:89.35ms +step:724/1680 train_time:64693ms step_avg:89.36ms +step:725/1680 train_time:64783ms step_avg:89.36ms +step:726/1680 train_time:64872ms step_avg:89.36ms +step:727/1680 train_time:64963ms step_avg:89.36ms +step:728/1680 train_time:65053ms step_avg:89.36ms +step:729/1680 train_time:65143ms step_avg:89.36ms +step:730/1680 train_time:65234ms step_avg:89.36ms +step:731/1680 train_time:65324ms step_avg:89.36ms +step:732/1680 train_time:65415ms step_avg:89.36ms +step:733/1680 train_time:65505ms step_avg:89.37ms +step:734/1680 train_time:65595ms step_avg:89.37ms +step:735/1680 train_time:65685ms step_avg:89.37ms +step:736/1680 train_time:65775ms step_avg:89.37ms +step:737/1680 train_time:65865ms step_avg:89.37ms +step:738/1680 train_time:65954ms step_avg:89.37ms +step:739/1680 train_time:66045ms step_avg:89.37ms +step:740/1680 train_time:66136ms step_avg:89.37ms +step:741/1680 train_time:66226ms step_avg:89.37ms +step:742/1680 train_time:66317ms step_avg:89.38ms +step:743/1680 train_time:66407ms step_avg:89.38ms +step:744/1680 train_time:66497ms step_avg:89.38ms +step:745/1680 train_time:66588ms step_avg:89.38ms +step:746/1680 train_time:66678ms step_avg:89.38ms +step:747/1680 train_time:66768ms step_avg:89.38ms +step:748/1680 train_time:66858ms step_avg:89.38ms +step:749/1680 train_time:66948ms step_avg:89.38ms +step:750/1680 train_time:67039ms step_avg:89.39ms +step:750/1680 val_loss:3.5657 train_time:67129ms step_avg:89.51ms +step:751/1680 train_time:67152ms step_avg:89.42ms +step:752/1680 train_time:67224ms step_avg:89.39ms +step:753/1680 train_time:67322ms step_avg:89.40ms +step:754/1680 train_time:67412ms step_avg:89.41ms +step:755/1680 train_time:67502ms step_avg:89.41ms +step:756/1680 train_time:67591ms step_avg:89.41ms +step:757/1680 train_time:67680ms step_avg:89.41ms +step:758/1680 train_time:67769ms step_avg:89.40ms +step:759/1680 train_time:67858ms step_avg:89.40ms +step:760/1680 train_time:67948ms step_avg:89.40ms +step:761/1680 train_time:68036ms step_avg:89.40ms +step:762/1680 train_time:68127ms step_avg:89.41ms +step:763/1680 train_time:68220ms step_avg:89.41ms +step:764/1680 train_time:68312ms step_avg:89.41ms +step:765/1680 train_time:68404ms step_avg:89.42ms +step:766/1680 train_time:68494ms step_avg:89.42ms +step:767/1680 train_time:68584ms step_avg:89.42ms +step:768/1680 train_time:68674ms step_avg:89.42ms +step:769/1680 train_time:68763ms step_avg:89.42ms +step:770/1680 train_time:68852ms step_avg:89.42ms 
+step:771/1680 train_time:68942ms step_avg:89.42ms +step:772/1680 train_time:69031ms step_avg:89.42ms +step:773/1680 train_time:69122ms step_avg:89.42ms +step:774/1680 train_time:69212ms step_avg:89.42ms +step:775/1680 train_time:69304ms step_avg:89.43ms +step:776/1680 train_time:69396ms step_avg:89.43ms +step:777/1680 train_time:69487ms step_avg:89.43ms +step:778/1680 train_time:69578ms step_avg:89.43ms +step:779/1680 train_time:69668ms step_avg:89.43ms +step:780/1680 train_time:69758ms step_avg:89.43ms +step:781/1680 train_time:69847ms step_avg:89.43ms +step:782/1680 train_time:69937ms step_avg:89.43ms +step:783/1680 train_time:70027ms step_avg:89.43ms +step:784/1680 train_time:70117ms step_avg:89.44ms +step:785/1680 train_time:70207ms step_avg:89.44ms +step:786/1680 train_time:70298ms step_avg:89.44ms +step:787/1680 train_time:70389ms step_avg:89.44ms +step:788/1680 train_time:70480ms step_avg:89.44ms +step:789/1680 train_time:70570ms step_avg:89.44ms +step:790/1680 train_time:70661ms step_avg:89.44ms +step:791/1680 train_time:70750ms step_avg:89.44ms +step:792/1680 train_time:70841ms step_avg:89.45ms +step:793/1680 train_time:70931ms step_avg:89.45ms +step:794/1680 train_time:71021ms step_avg:89.45ms +step:795/1680 train_time:71110ms step_avg:89.45ms +step:796/1680 train_time:71201ms step_avg:89.45ms +step:797/1680 train_time:71292ms step_avg:89.45ms +step:798/1680 train_time:71384ms step_avg:89.45ms +step:799/1680 train_time:71475ms step_avg:89.46ms +step:800/1680 train_time:71565ms step_avg:89.46ms +step:801/1680 train_time:71656ms step_avg:89.46ms +step:802/1680 train_time:71746ms step_avg:89.46ms +step:803/1680 train_time:71836ms step_avg:89.46ms +step:804/1680 train_time:71926ms step_avg:89.46ms +step:805/1680 train_time:72015ms step_avg:89.46ms +step:806/1680 train_time:72105ms step_avg:89.46ms +step:807/1680 train_time:72194ms step_avg:89.46ms +step:808/1680 train_time:72285ms step_avg:89.46ms +step:809/1680 train_time:72376ms step_avg:89.46ms +step:810/1680 train_time:72467ms step_avg:89.46ms +step:811/1680 train_time:72557ms step_avg:89.47ms +step:812/1680 train_time:72648ms step_avg:89.47ms +step:813/1680 train_time:72739ms step_avg:89.47ms +step:814/1680 train_time:72828ms step_avg:89.47ms +step:815/1680 train_time:72918ms step_avg:89.47ms +step:816/1680 train_time:73008ms step_avg:89.47ms +step:817/1680 train_time:73100ms step_avg:89.47ms +step:818/1680 train_time:73188ms step_avg:89.47ms +step:819/1680 train_time:73278ms step_avg:89.47ms +step:820/1680 train_time:73368ms step_avg:89.47ms +step:821/1680 train_time:73459ms step_avg:89.48ms +step:822/1680 train_time:73549ms step_avg:89.48ms +step:823/1680 train_time:73640ms step_avg:89.48ms +step:824/1680 train_time:73730ms step_avg:89.48ms +step:825/1680 train_time:73820ms step_avg:89.48ms +step:826/1680 train_time:73910ms step_avg:89.48ms +step:827/1680 train_time:73999ms step_avg:89.48ms +step:828/1680 train_time:74089ms step_avg:89.48ms +step:829/1680 train_time:74179ms step_avg:89.48ms +step:830/1680 train_time:74269ms step_avg:89.48ms +step:831/1680 train_time:74360ms step_avg:89.48ms +step:832/1680 train_time:74450ms step_avg:89.48ms +step:833/1680 train_time:74541ms step_avg:89.49ms +step:834/1680 train_time:74631ms step_avg:89.49ms +step:835/1680 train_time:74722ms step_avg:89.49ms +step:836/1680 train_time:74812ms step_avg:89.49ms +step:837/1680 train_time:74902ms step_avg:89.49ms +step:838/1680 train_time:74991ms step_avg:89.49ms +step:839/1680 train_time:75082ms step_avg:89.49ms +step:840/1680 train_time:75171ms 
step_avg:89.49ms +step:841/1680 train_time:75262ms step_avg:89.49ms +step:842/1680 train_time:75352ms step_avg:89.49ms +step:843/1680 train_time:75442ms step_avg:89.49ms +step:844/1680 train_time:75532ms step_avg:89.49ms +step:845/1680 train_time:75624ms step_avg:89.50ms +step:846/1680 train_time:75715ms step_avg:89.50ms +step:847/1680 train_time:75805ms step_avg:89.50ms +step:848/1680 train_time:75896ms step_avg:89.50ms +step:849/1680 train_time:75986ms step_avg:89.50ms +step:850/1680 train_time:76076ms step_avg:89.50ms +step:851/1680 train_time:76165ms step_avg:89.50ms +step:852/1680 train_time:76256ms step_avg:89.50ms +step:853/1680 train_time:76347ms step_avg:89.50ms +step:854/1680 train_time:76437ms step_avg:89.51ms +step:855/1680 train_time:76528ms step_avg:89.51ms +step:856/1680 train_time:76619ms step_avg:89.51ms +step:857/1680 train_time:76709ms step_avg:89.51ms +step:858/1680 train_time:76799ms step_avg:89.51ms +step:859/1680 train_time:76889ms step_avg:89.51ms +step:860/1680 train_time:76980ms step_avg:89.51ms +step:861/1680 train_time:77070ms step_avg:89.51ms +step:862/1680 train_time:77160ms step_avg:89.51ms +step:863/1680 train_time:77250ms step_avg:89.51ms +step:864/1680 train_time:77340ms step_avg:89.51ms +step:865/1680 train_time:77430ms step_avg:89.51ms +step:866/1680 train_time:77521ms step_avg:89.52ms +step:867/1680 train_time:77611ms step_avg:89.52ms +step:868/1680 train_time:77702ms step_avg:89.52ms +step:869/1680 train_time:77792ms step_avg:89.52ms +step:870/1680 train_time:77883ms step_avg:89.52ms +step:871/1680 train_time:77973ms step_avg:89.52ms +step:872/1680 train_time:78064ms step_avg:89.52ms +step:873/1680 train_time:78154ms step_avg:89.52ms +step:874/1680 train_time:78244ms step_avg:89.52ms +step:875/1680 train_time:78335ms step_avg:89.53ms +step:875/1680 val_loss:3.5196 train_time:78426ms step_avg:89.63ms +step:876/1680 train_time:78450ms step_avg:89.55ms +step:877/1680 train_time:78521ms step_avg:89.53ms +step:878/1680 train_time:78620ms step_avg:89.54ms +step:879/1680 train_time:78712ms step_avg:89.55ms +step:880/1680 train_time:78800ms step_avg:89.55ms +step:881/1680 train_time:78889ms step_avg:89.55ms +step:882/1680 train_time:78978ms step_avg:89.54ms +step:883/1680 train_time:79067ms step_avg:89.54ms +step:884/1680 train_time:79156ms step_avg:89.54ms +step:885/1680 train_time:79244ms step_avg:89.54ms +step:886/1680 train_time:79334ms step_avg:89.54ms +step:887/1680 train_time:79424ms step_avg:89.54ms +step:888/1680 train_time:79517ms step_avg:89.55ms +step:889/1680 train_time:79611ms step_avg:89.55ms +step:890/1680 train_time:79704ms step_avg:89.55ms +step:891/1680 train_time:79795ms step_avg:89.56ms +step:892/1680 train_time:79885ms step_avg:89.56ms +step:893/1680 train_time:79974ms step_avg:89.56ms +step:894/1680 train_time:80062ms step_avg:89.56ms +step:895/1680 train_time:80152ms step_avg:89.56ms +step:896/1680 train_time:80241ms step_avg:89.55ms +step:897/1680 train_time:80330ms step_avg:89.55ms +step:898/1680 train_time:80420ms step_avg:89.55ms +step:899/1680 train_time:80512ms step_avg:89.56ms +step:900/1680 train_time:80604ms step_avg:89.56ms +step:901/1680 train_time:80698ms step_avg:89.56ms +step:902/1680 train_time:80789ms step_avg:89.57ms +step:903/1680 train_time:80879ms step_avg:89.57ms +step:904/1680 train_time:80970ms step_avg:89.57ms +step:905/1680 train_time:81060ms step_avg:89.57ms +step:906/1680 train_time:81149ms step_avg:89.57ms +step:907/1680 train_time:81240ms step_avg:89.57ms +step:908/1680 train_time:81329ms step_avg:89.57ms 
+step:909/1680 train_time:81419ms step_avg:89.57ms +step:910/1680 train_time:81510ms step_avg:89.57ms +step:911/1680 train_time:81601ms step_avg:89.57ms +step:912/1680 train_time:81692ms step_avg:89.58ms +step:913/1680 train_time:81784ms step_avg:89.58ms +step:914/1680 train_time:81874ms step_avg:89.58ms +step:915/1680 train_time:81965ms step_avg:89.58ms +step:916/1680 train_time:82054ms step_avg:89.58ms +step:917/1680 train_time:82144ms step_avg:89.58ms +step:918/1680 train_time:82234ms step_avg:89.58ms +step:919/1680 train_time:82323ms step_avg:89.58ms +step:920/1680 train_time:82413ms step_avg:89.58ms +step:921/1680 train_time:82503ms step_avg:89.58ms +step:922/1680 train_time:82595ms step_avg:89.58ms +step:923/1680 train_time:82685ms step_avg:89.58ms +step:924/1680 train_time:82776ms step_avg:89.58ms +step:925/1680 train_time:82867ms step_avg:89.59ms +step:926/1680 train_time:82957ms step_avg:89.59ms +step:927/1680 train_time:83047ms step_avg:89.59ms +step:928/1680 train_time:83138ms step_avg:89.59ms +step:929/1680 train_time:83228ms step_avg:89.59ms +step:930/1680 train_time:83317ms step_avg:89.59ms +step:931/1680 train_time:83408ms step_avg:89.59ms +step:932/1680 train_time:83498ms step_avg:89.59ms +step:933/1680 train_time:83588ms step_avg:89.59ms +step:934/1680 train_time:83679ms step_avg:89.59ms +step:935/1680 train_time:83770ms step_avg:89.59ms +step:936/1680 train_time:83861ms step_avg:89.60ms +step:937/1680 train_time:83951ms step_avg:89.60ms +step:938/1680 train_time:84042ms step_avg:89.60ms +step:939/1680 train_time:84132ms step_avg:89.60ms +step:940/1680 train_time:84221ms step_avg:89.60ms +step:941/1680 train_time:84311ms step_avg:89.60ms +step:942/1680 train_time:84401ms step_avg:89.60ms +step:943/1680 train_time:84490ms step_avg:89.60ms +step:944/1680 train_time:84580ms step_avg:89.60ms +step:945/1680 train_time:84671ms step_avg:89.60ms +step:946/1680 train_time:84761ms step_avg:89.60ms +step:947/1680 train_time:84851ms step_avg:89.60ms +step:948/1680 train_time:84941ms step_avg:89.60ms +step:949/1680 train_time:85032ms step_avg:89.60ms +step:950/1680 train_time:85122ms step_avg:89.60ms +step:951/1680 train_time:85213ms step_avg:89.60ms +step:952/1680 train_time:85303ms step_avg:89.60ms +step:953/1680 train_time:85393ms step_avg:89.60ms +step:954/1680 train_time:85483ms step_avg:89.60ms +step:955/1680 train_time:85573ms step_avg:89.61ms +step:956/1680 train_time:85663ms step_avg:89.61ms +step:957/1680 train_time:85753ms step_avg:89.61ms +step:958/1680 train_time:85843ms step_avg:89.61ms +step:959/1680 train_time:85934ms step_avg:89.61ms +step:960/1680 train_time:86023ms step_avg:89.61ms +step:961/1680 train_time:86114ms step_avg:89.61ms +step:962/1680 train_time:86204ms step_avg:89.61ms +step:963/1680 train_time:86295ms step_avg:89.61ms +step:964/1680 train_time:86385ms step_avg:89.61ms +step:965/1680 train_time:86475ms step_avg:89.61ms +step:966/1680 train_time:86566ms step_avg:89.61ms +step:967/1680 train_time:86655ms step_avg:89.61ms +step:968/1680 train_time:86745ms step_avg:89.61ms +step:969/1680 train_time:86839ms step_avg:89.62ms +step:970/1680 train_time:86927ms step_avg:89.62ms +step:971/1680 train_time:87017ms step_avg:89.62ms +step:972/1680 train_time:87109ms step_avg:89.62ms +step:973/1680 train_time:87198ms step_avg:89.62ms +step:974/1680 train_time:87289ms step_avg:89.62ms +step:975/1680 train_time:87378ms step_avg:89.62ms +step:976/1680 train_time:87468ms step_avg:89.62ms +step:977/1680 train_time:87558ms step_avg:89.62ms +step:978/1680 train_time:87648ms 
step_avg:89.62ms +step:979/1680 train_time:87739ms step_avg:89.62ms +step:980/1680 train_time:87830ms step_avg:89.62ms +step:981/1680 train_time:87920ms step_avg:89.62ms +step:982/1680 train_time:88011ms step_avg:89.62ms +step:983/1680 train_time:88101ms step_avg:89.62ms +step:984/1680 train_time:88192ms step_avg:89.63ms +step:985/1680 train_time:88282ms step_avg:89.63ms +step:986/1680 train_time:88372ms step_avg:89.63ms +step:987/1680 train_time:88462ms step_avg:89.63ms +step:988/1680 train_time:88552ms step_avg:89.63ms +step:989/1680 train_time:88642ms step_avg:89.63ms +step:990/1680 train_time:88732ms step_avg:89.63ms +step:991/1680 train_time:88822ms step_avg:89.63ms +step:992/1680 train_time:88913ms step_avg:89.63ms +step:993/1680 train_time:89003ms step_avg:89.63ms +step:994/1680 train_time:89094ms step_avg:89.63ms +step:995/1680 train_time:89183ms step_avg:89.63ms +step:996/1680 train_time:89274ms step_avg:89.63ms +step:997/1680 train_time:89364ms step_avg:89.63ms +step:998/1680 train_time:89455ms step_avg:89.63ms +step:999/1680 train_time:89544ms step_avg:89.63ms +step:1000/1680 train_time:89638ms step_avg:89.64ms +step:1000/1680 val_loss:3.4689 train_time:89726ms step_avg:89.73ms +step:1001/1680 train_time:89750ms step_avg:89.66ms +step:1002/1680 train_time:89819ms step_avg:89.64ms +step:1003/1680 train_time:89916ms step_avg:89.65ms +step:1004/1680 train_time:90007ms step_avg:89.65ms +step:1005/1680 train_time:90097ms step_avg:89.65ms +step:1006/1680 train_time:90186ms step_avg:89.65ms +step:1007/1680 train_time:90277ms step_avg:89.65ms +step:1008/1680 train_time:90366ms step_avg:89.65ms +step:1009/1680 train_time:90455ms step_avg:89.65ms +step:1010/1680 train_time:90544ms step_avg:89.65ms +step:1011/1680 train_time:90633ms step_avg:89.65ms +step:1012/1680 train_time:90724ms step_avg:89.65ms +step:1013/1680 train_time:90816ms step_avg:89.65ms +step:1014/1680 train_time:90908ms step_avg:89.65ms +step:1015/1680 train_time:91000ms step_avg:89.66ms +step:1016/1680 train_time:91091ms step_avg:89.66ms +step:1017/1680 train_time:91181ms step_avg:89.66ms +step:1018/1680 train_time:91271ms step_avg:89.66ms +step:1019/1680 train_time:91361ms step_avg:89.66ms +step:1020/1680 train_time:91450ms step_avg:89.66ms +step:1021/1680 train_time:91540ms step_avg:89.66ms +step:1022/1680 train_time:91629ms step_avg:89.66ms +step:1023/1680 train_time:91719ms step_avg:89.66ms +step:1024/1680 train_time:91811ms step_avg:89.66ms +step:1025/1680 train_time:91902ms step_avg:89.66ms +step:1026/1680 train_time:91993ms step_avg:89.66ms +step:1027/1680 train_time:92082ms step_avg:89.66ms +step:1028/1680 train_time:92173ms step_avg:89.66ms +step:1029/1680 train_time:92262ms step_avg:89.66ms +step:1030/1680 train_time:92352ms step_avg:89.66ms +step:1031/1680 train_time:92442ms step_avg:89.66ms +step:1032/1680 train_time:92531ms step_avg:89.66ms +step:1033/1680 train_time:92620ms step_avg:89.66ms +step:1034/1680 train_time:92710ms step_avg:89.66ms +step:1035/1680 train_time:92801ms step_avg:89.66ms +step:1036/1680 train_time:92892ms step_avg:89.66ms +step:1037/1680 train_time:92983ms step_avg:89.67ms +step:1038/1680 train_time:93074ms step_avg:89.67ms +step:1039/1680 train_time:93164ms step_avg:89.67ms +step:1040/1680 train_time:93254ms step_avg:89.67ms +step:1041/1680 train_time:93343ms step_avg:89.67ms +step:1042/1680 train_time:93434ms step_avg:89.67ms +step:1043/1680 train_time:93523ms step_avg:89.67ms +step:1044/1680 train_time:93613ms step_avg:89.67ms +step:1045/1680 train_time:93702ms step_avg:89.67ms 
+step:1046/1680 train_time:93793ms step_avg:89.67ms +step:1047/1680 train_time:93883ms step_avg:89.67ms +step:1048/1680 train_time:93974ms step_avg:89.67ms +step:1049/1680 train_time:94064ms step_avg:89.67ms +step:1050/1680 train_time:94155ms step_avg:89.67ms +step:1051/1680 train_time:94244ms step_avg:89.67ms +step:1052/1680 train_time:94334ms step_avg:89.67ms +step:1053/1680 train_time:94424ms step_avg:89.67ms +step:1054/1680 train_time:94514ms step_avg:89.67ms +step:1055/1680 train_time:94603ms step_avg:89.67ms +step:1056/1680 train_time:94694ms step_avg:89.67ms +step:1057/1680 train_time:94783ms step_avg:89.67ms +step:1058/1680 train_time:94874ms step_avg:89.67ms +step:1059/1680 train_time:94963ms step_avg:89.67ms +step:1060/1680 train_time:95055ms step_avg:89.67ms +step:1061/1680 train_time:95146ms step_avg:89.68ms +step:1062/1680 train_time:95238ms step_avg:89.68ms +step:1063/1680 train_time:95328ms step_avg:89.68ms +step:1064/1680 train_time:95419ms step_avg:89.68ms +step:1065/1680 train_time:95509ms step_avg:89.68ms +step:1066/1680 train_time:95599ms step_avg:89.68ms +step:1067/1680 train_time:95689ms step_avg:89.68ms +step:1068/1680 train_time:95779ms step_avg:89.68ms +step:1069/1680 train_time:95869ms step_avg:89.68ms +step:1070/1680 train_time:95960ms step_avg:89.68ms +step:1071/1680 train_time:96050ms step_avg:89.68ms +step:1072/1680 train_time:96140ms step_avg:89.68ms +step:1073/1680 train_time:96231ms step_avg:89.68ms +step:1074/1680 train_time:96321ms step_avg:89.68ms +step:1075/1680 train_time:96411ms step_avg:89.69ms +step:1076/1680 train_time:96501ms step_avg:89.69ms +step:1077/1680 train_time:96591ms step_avg:89.69ms +step:1078/1680 train_time:96682ms step_avg:89.69ms +step:1079/1680 train_time:96772ms step_avg:89.69ms +step:1080/1680 train_time:96862ms step_avg:89.69ms +step:1081/1680 train_time:96953ms step_avg:89.69ms +step:1082/1680 train_time:97043ms step_avg:89.69ms +step:1083/1680 train_time:97134ms step_avg:89.69ms +step:1084/1680 train_time:97224ms step_avg:89.69ms +step:1085/1680 train_time:97313ms step_avg:89.69ms +step:1086/1680 train_time:97404ms step_avg:89.69ms +step:1087/1680 train_time:97494ms step_avg:89.69ms +step:1088/1680 train_time:97585ms step_avg:89.69ms +step:1089/1680 train_time:97675ms step_avg:89.69ms +step:1090/1680 train_time:97764ms step_avg:89.69ms +step:1091/1680 train_time:97855ms step_avg:89.69ms +step:1092/1680 train_time:97945ms step_avg:89.69ms +step:1093/1680 train_time:98037ms step_avg:89.70ms +step:1094/1680 train_time:98126ms step_avg:89.69ms +step:1095/1680 train_time:98217ms step_avg:89.70ms +step:1096/1680 train_time:98307ms step_avg:89.70ms +step:1097/1680 train_time:98399ms step_avg:89.70ms +step:1098/1680 train_time:98489ms step_avg:89.70ms +step:1099/1680 train_time:98580ms step_avg:89.70ms +step:1100/1680 train_time:98670ms step_avg:89.70ms +step:1101/1680 train_time:98761ms step_avg:89.70ms +step:1102/1680 train_time:98853ms step_avg:89.70ms +step:1103/1680 train_time:98943ms step_avg:89.70ms +step:1104/1680 train_time:99034ms step_avg:89.70ms +step:1105/1680 train_time:99125ms step_avg:89.71ms +step:1106/1680 train_time:99215ms step_avg:89.71ms +step:1107/1680 train_time:99307ms step_avg:89.71ms +step:1108/1680 train_time:99398ms step_avg:89.71ms +step:1109/1680 train_time:99489ms step_avg:89.71ms +step:1110/1680 train_time:99580ms step_avg:89.71ms +step:1111/1680 train_time:99671ms step_avg:89.71ms +step:1112/1680 train_time:99762ms step_avg:89.71ms +step:1113/1680 train_time:99853ms step_avg:89.71ms +step:1114/1680 
train_time:99943ms step_avg:89.72ms +step:1115/1680 train_time:100035ms step_avg:89.72ms +step:1116/1680 train_time:100125ms step_avg:89.72ms +step:1117/1680 train_time:100216ms step_avg:89.72ms +step:1118/1680 train_time:100307ms step_avg:89.72ms +step:1119/1680 train_time:100398ms step_avg:89.72ms +step:1120/1680 train_time:100489ms step_avg:89.72ms +step:1121/1680 train_time:100580ms step_avg:89.72ms +step:1122/1680 train_time:100672ms step_avg:89.73ms +step:1123/1680 train_time:100762ms step_avg:89.73ms +step:1124/1680 train_time:100854ms step_avg:89.73ms +step:1125/1680 train_time:100944ms step_avg:89.73ms +step:1125/1680 val_loss:3.4153 train_time:101037ms step_avg:89.81ms +step:1126/1680 train_time:101060ms step_avg:89.75ms +step:1127/1680 train_time:101131ms step_avg:89.74ms +step:1128/1680 train_time:101230ms step_avg:89.74ms +step:1129/1680 train_time:101324ms step_avg:89.75ms +step:1130/1680 train_time:101414ms step_avg:89.75ms +step:1131/1680 train_time:101504ms step_avg:89.75ms +step:1132/1680 train_time:101594ms step_avg:89.75ms +step:1133/1680 train_time:101684ms step_avg:89.75ms +step:1134/1680 train_time:101773ms step_avg:89.75ms +step:1135/1680 train_time:101863ms step_avg:89.75ms +step:1136/1680 train_time:101952ms step_avg:89.75ms +step:1137/1680 train_time:102043ms step_avg:89.75ms +step:1138/1680 train_time:102135ms step_avg:89.75ms +step:1139/1680 train_time:102229ms step_avg:89.75ms +step:1140/1680 train_time:102323ms step_avg:89.76ms +step:1141/1680 train_time:102414ms step_avg:89.76ms +step:1142/1680 train_time:102505ms step_avg:89.76ms +step:1143/1680 train_time:102595ms step_avg:89.76ms +step:1144/1680 train_time:102684ms step_avg:89.76ms +step:1145/1680 train_time:102774ms step_avg:89.76ms +step:1146/1680 train_time:102864ms step_avg:89.76ms +step:1147/1680 train_time:102954ms step_avg:89.76ms +step:1148/1680 train_time:103046ms step_avg:89.76ms +step:1149/1680 train_time:103138ms step_avg:89.76ms +step:1150/1680 train_time:103230ms step_avg:89.77ms +step:1151/1680 train_time:103324ms step_avg:89.77ms +step:1152/1680 train_time:103415ms step_avg:89.77ms +step:1153/1680 train_time:103506ms step_avg:89.77ms +step:1154/1680 train_time:103596ms step_avg:89.77ms +step:1155/1680 train_time:103687ms step_avg:89.77ms +step:1156/1680 train_time:103776ms step_avg:89.77ms +step:1157/1680 train_time:103866ms step_avg:89.77ms +step:1158/1680 train_time:103956ms step_avg:89.77ms +step:1159/1680 train_time:104047ms step_avg:89.77ms +step:1160/1680 train_time:104138ms step_avg:89.77ms +step:1161/1680 train_time:104229ms step_avg:89.78ms +step:1162/1680 train_time:104322ms step_avg:89.78ms +step:1163/1680 train_time:104413ms step_avg:89.78ms +step:1164/1680 train_time:104504ms step_avg:89.78ms +step:1165/1680 train_time:104595ms step_avg:89.78ms +step:1166/1680 train_time:104685ms step_avg:89.78ms +step:1167/1680 train_time:104775ms step_avg:89.78ms +step:1168/1680 train_time:104866ms step_avg:89.78ms +step:1169/1680 train_time:104957ms step_avg:89.78ms +step:1170/1680 train_time:105047ms step_avg:89.78ms +step:1171/1680 train_time:105138ms step_avg:89.79ms +step:1172/1680 train_time:105229ms step_avg:89.79ms +step:1173/1680 train_time:105322ms step_avg:89.79ms +step:1174/1680 train_time:105414ms step_avg:89.79ms +step:1175/1680 train_time:105504ms step_avg:89.79ms +step:1176/1680 train_time:105595ms step_avg:89.79ms +step:1177/1680 train_time:105685ms step_avg:89.79ms +step:1178/1680 train_time:105775ms step_avg:89.79ms +step:1179/1680 train_time:105866ms step_avg:89.79ms 
+step:1180/1680 train_time:105956ms step_avg:89.79ms +step:1181/1680 train_time:106046ms step_avg:89.79ms +step:1182/1680 train_time:106138ms step_avg:89.79ms +step:1183/1680 train_time:106228ms step_avg:89.80ms +step:1184/1680 train_time:106320ms step_avg:89.80ms +step:1185/1680 train_time:106411ms step_avg:89.80ms +step:1186/1680 train_time:106502ms step_avg:89.80ms +step:1187/1680 train_time:106592ms step_avg:89.80ms +step:1188/1680 train_time:106682ms step_avg:89.80ms +step:1189/1680 train_time:106773ms step_avg:89.80ms +step:1190/1680 train_time:106864ms step_avg:89.80ms +step:1191/1680 train_time:106954ms step_avg:89.80ms +step:1192/1680 train_time:107046ms step_avg:89.80ms +step:1193/1680 train_time:107138ms step_avg:89.81ms +step:1194/1680 train_time:107229ms step_avg:89.81ms +step:1195/1680 train_time:107320ms step_avg:89.81ms +step:1196/1680 train_time:107410ms step_avg:89.81ms +step:1197/1680 train_time:107501ms step_avg:89.81ms +step:1198/1680 train_time:107591ms step_avg:89.81ms +step:1199/1680 train_time:107682ms step_avg:89.81ms +step:1200/1680 train_time:107773ms step_avg:89.81ms +step:1201/1680 train_time:107863ms step_avg:89.81ms +step:1202/1680 train_time:107953ms step_avg:89.81ms +step:1203/1680 train_time:108045ms step_avg:89.81ms +step:1204/1680 train_time:108136ms step_avg:89.81ms +step:1205/1680 train_time:108227ms step_avg:89.82ms +step:1206/1680 train_time:108319ms step_avg:89.82ms +step:1207/1680 train_time:108409ms step_avg:89.82ms +step:1208/1680 train_time:108500ms step_avg:89.82ms +step:1209/1680 train_time:108590ms step_avg:89.82ms +step:1210/1680 train_time:108681ms step_avg:89.82ms +step:1211/1680 train_time:108772ms step_avg:89.82ms +step:1212/1680 train_time:108863ms step_avg:89.82ms +step:1213/1680 train_time:108953ms step_avg:89.82ms +step:1214/1680 train_time:109045ms step_avg:89.82ms +step:1215/1680 train_time:109136ms step_avg:89.82ms +step:1216/1680 train_time:109227ms step_avg:89.83ms +step:1217/1680 train_time:109319ms step_avg:89.83ms +step:1218/1680 train_time:109410ms step_avg:89.83ms +step:1219/1680 train_time:109500ms step_avg:89.83ms +step:1220/1680 train_time:109590ms step_avg:89.83ms +step:1221/1680 train_time:109680ms step_avg:89.83ms +step:1222/1680 train_time:109771ms step_avg:89.83ms +step:1223/1680 train_time:109862ms step_avg:89.83ms +step:1224/1680 train_time:109952ms step_avg:89.83ms +step:1225/1680 train_time:110043ms step_avg:89.83ms +step:1226/1680 train_time:110134ms step_avg:89.83ms +step:1227/1680 train_time:110225ms step_avg:89.83ms +step:1228/1680 train_time:110318ms step_avg:89.84ms +step:1229/1680 train_time:110408ms step_avg:89.84ms +step:1230/1680 train_time:110499ms step_avg:89.84ms +step:1231/1680 train_time:110590ms step_avg:89.84ms +step:1232/1680 train_time:110681ms step_avg:89.84ms +step:1233/1680 train_time:110771ms step_avg:89.84ms +step:1234/1680 train_time:110861ms step_avg:89.84ms +step:1235/1680 train_time:110951ms step_avg:89.84ms +step:1236/1680 train_time:111043ms step_avg:89.84ms +step:1237/1680 train_time:111133ms step_avg:89.84ms +step:1238/1680 train_time:111225ms step_avg:89.84ms +step:1239/1680 train_time:111316ms step_avg:89.84ms +step:1240/1680 train_time:111407ms step_avg:89.84ms +step:1241/1680 train_time:111497ms step_avg:89.84ms +step:1242/1680 train_time:111588ms step_avg:89.85ms +step:1243/1680 train_time:111679ms step_avg:89.85ms +step:1244/1680 train_time:111770ms step_avg:89.85ms +step:1245/1680 train_time:111861ms step_avg:89.85ms +step:1246/1680 train_time:111954ms step_avg:89.85ms 
+step:1247/1680 train_time:112043ms step_avg:89.85ms +step:1248/1680 train_time:112133ms step_avg:89.85ms +step:1249/1680 train_time:112224ms step_avg:89.85ms +step:1250/1680 train_time:112315ms step_avg:89.85ms +step:1250/1680 val_loss:3.3770 train_time:112407ms step_avg:89.93ms +step:1251/1680 train_time:112430ms step_avg:89.87ms +step:1252/1680 train_time:112503ms step_avg:89.86ms +step:1253/1680 train_time:112601ms step_avg:89.87ms +step:1254/1680 train_time:112693ms step_avg:89.87ms +step:1255/1680 train_time:112783ms step_avg:89.87ms +step:1256/1680 train_time:112873ms step_avg:89.87ms +step:1257/1680 train_time:112962ms step_avg:89.87ms +step:1258/1680 train_time:113052ms step_avg:89.87ms +step:1259/1680 train_time:113142ms step_avg:89.87ms +step:1260/1680 train_time:113231ms step_avg:89.87ms +step:1261/1680 train_time:113321ms step_avg:89.87ms +step:1262/1680 train_time:113413ms step_avg:89.87ms +step:1263/1680 train_time:113506ms step_avg:89.87ms +step:1264/1680 train_time:113600ms step_avg:89.87ms +step:1265/1680 train_time:113692ms step_avg:89.88ms +step:1266/1680 train_time:113783ms step_avg:89.88ms +step:1267/1680 train_time:113880ms step_avg:89.88ms +step:1268/1680 train_time:113964ms step_avg:89.88ms +step:1269/1680 train_time:114054ms step_avg:89.88ms +step:1270/1680 train_time:114143ms step_avg:89.88ms +step:1271/1680 train_time:114233ms step_avg:89.88ms +step:1272/1680 train_time:114323ms step_avg:89.88ms +step:1273/1680 train_time:114414ms step_avg:89.88ms +step:1274/1680 train_time:114507ms step_avg:89.88ms +step:1275/1680 train_time:114599ms step_avg:89.88ms +step:1276/1680 train_time:114691ms step_avg:89.88ms +step:1277/1680 train_time:114783ms step_avg:89.88ms +step:1278/1680 train_time:114874ms step_avg:89.89ms +step:1279/1680 train_time:114964ms step_avg:89.89ms +step:1280/1680 train_time:115055ms step_avg:89.89ms +step:1281/1680 train_time:115145ms step_avg:89.89ms +step:1282/1680 train_time:115235ms step_avg:89.89ms +step:1283/1680 train_time:115325ms step_avg:89.89ms +step:1284/1680 train_time:115417ms step_avg:89.89ms +step:1285/1680 train_time:115508ms step_avg:89.89ms +step:1286/1680 train_time:115599ms step_avg:89.89ms +step:1287/1680 train_time:115690ms step_avg:89.89ms +step:1288/1680 train_time:115783ms step_avg:89.89ms +step:1289/1680 train_time:115875ms step_avg:89.90ms +step:1290/1680 train_time:115965ms step_avg:89.90ms +step:1291/1680 train_time:116056ms step_avg:89.90ms +step:1292/1680 train_time:116146ms step_avg:89.90ms +step:1293/1680 train_time:116235ms step_avg:89.90ms +step:1294/1680 train_time:116326ms step_avg:89.90ms +step:1295/1680 train_time:116417ms step_avg:89.90ms +step:1296/1680 train_time:116509ms step_avg:89.90ms +step:1297/1680 train_time:116600ms step_avg:89.90ms +step:1298/1680 train_time:116691ms step_avg:89.90ms +step:1299/1680 train_time:116784ms step_avg:89.90ms +step:1300/1680 train_time:116875ms step_avg:89.90ms +step:1301/1680 train_time:116966ms step_avg:89.90ms +step:1302/1680 train_time:117057ms step_avg:89.91ms +step:1303/1680 train_time:117147ms step_avg:89.91ms +step:1304/1680 train_time:117239ms step_avg:89.91ms +step:1305/1680 train_time:117329ms step_avg:89.91ms +step:1306/1680 train_time:117420ms step_avg:89.91ms +step:1307/1680 train_time:117510ms step_avg:89.91ms +step:1308/1680 train_time:117601ms step_avg:89.91ms +step:1309/1680 train_time:117693ms step_avg:89.91ms +step:1310/1680 train_time:117784ms step_avg:89.91ms +step:1311/1680 train_time:117875ms step_avg:89.91ms +step:1312/1680 train_time:117966ms 
step_avg:89.91ms +step:1313/1680 train_time:118058ms step_avg:89.91ms +step:1314/1680 train_time:118148ms step_avg:89.92ms +step:1315/1680 train_time:118239ms step_avg:89.92ms +step:1316/1680 train_time:118330ms step_avg:89.92ms +step:1317/1680 train_time:118420ms step_avg:89.92ms +step:1318/1680 train_time:118511ms step_avg:89.92ms +step:1319/1680 train_time:118602ms step_avg:89.92ms +step:1320/1680 train_time:118692ms step_avg:89.92ms +step:1321/1680 train_time:118784ms step_avg:89.92ms +step:1322/1680 train_time:118876ms step_avg:89.92ms +step:1323/1680 train_time:118966ms step_avg:89.92ms +step:1324/1680 train_time:119057ms step_avg:89.92ms +step:1325/1680 train_time:119147ms step_avg:89.92ms +step:1326/1680 train_time:119237ms step_avg:89.92ms +step:1327/1680 train_time:119327ms step_avg:89.92ms +step:1328/1680 train_time:119418ms step_avg:89.92ms +step:1329/1680 train_time:119509ms step_avg:89.92ms +step:1330/1680 train_time:119600ms step_avg:89.92ms +step:1331/1680 train_time:119690ms step_avg:89.93ms +step:1332/1680 train_time:119782ms step_avg:89.93ms +step:1333/1680 train_time:119872ms step_avg:89.93ms +step:1334/1680 train_time:119964ms step_avg:89.93ms +step:1335/1680 train_time:120054ms step_avg:89.93ms +step:1336/1680 train_time:120145ms step_avg:89.93ms +step:1337/1680 train_time:120235ms step_avg:89.93ms +step:1338/1680 train_time:120325ms step_avg:89.93ms +step:1339/1680 train_time:120415ms step_avg:89.93ms +step:1340/1680 train_time:120507ms step_avg:89.93ms +step:1341/1680 train_time:120597ms step_avg:89.93ms +step:1342/1680 train_time:120687ms step_avg:89.93ms +step:1343/1680 train_time:120778ms step_avg:89.93ms +step:1344/1680 train_time:120869ms step_avg:89.93ms +step:1345/1680 train_time:120961ms step_avg:89.93ms +step:1346/1680 train_time:121052ms step_avg:89.93ms +step:1347/1680 train_time:121144ms step_avg:89.94ms +step:1348/1680 train_time:121234ms step_avg:89.94ms +step:1349/1680 train_time:121325ms step_avg:89.94ms +step:1350/1680 train_time:121416ms step_avg:89.94ms +step:1351/1680 train_time:121507ms step_avg:89.94ms +step:1352/1680 train_time:121598ms step_avg:89.94ms +step:1353/1680 train_time:121689ms step_avg:89.94ms +step:1354/1680 train_time:121780ms step_avg:89.94ms +step:1355/1680 train_time:121871ms step_avg:89.94ms +step:1356/1680 train_time:121962ms step_avg:89.94ms +step:1357/1680 train_time:122053ms step_avg:89.94ms +step:1358/1680 train_time:122144ms step_avg:89.94ms +step:1359/1680 train_time:122236ms step_avg:89.95ms +step:1360/1680 train_time:122326ms step_avg:89.95ms +step:1361/1680 train_time:122417ms step_avg:89.95ms +step:1362/1680 train_time:122507ms step_avg:89.95ms +step:1363/1680 train_time:122598ms step_avg:89.95ms +step:1364/1680 train_time:122688ms step_avg:89.95ms +step:1365/1680 train_time:122779ms step_avg:89.95ms +step:1366/1680 train_time:122869ms step_avg:89.95ms +step:1367/1680 train_time:122961ms step_avg:89.95ms +step:1368/1680 train_time:123053ms step_avg:89.95ms +step:1369/1680 train_time:123143ms step_avg:89.95ms +step:1370/1680 train_time:123234ms step_avg:89.95ms +step:1371/1680 train_time:123324ms step_avg:89.95ms +step:1372/1680 train_time:123415ms step_avg:89.95ms +step:1373/1680 train_time:123506ms step_avg:89.95ms +step:1374/1680 train_time:123597ms step_avg:89.95ms +step:1375/1680 train_time:123687ms step_avg:89.95ms +step:1375/1680 val_loss:3.3424 train_time:123779ms step_avg:90.02ms +step:1376/1680 train_time:123802ms step_avg:89.97ms +step:1377/1680 train_time:123875ms step_avg:89.96ms +step:1378/1680 
train_time:123970ms step_avg:89.96ms +step:1379/1680 train_time:124061ms step_avg:89.96ms +step:1380/1680 train_time:124150ms step_avg:89.96ms +step:1381/1680 train_time:124240ms step_avg:89.96ms +step:1382/1680 train_time:124330ms step_avg:89.96ms +step:1383/1680 train_time:124419ms step_avg:89.96ms +step:1384/1680 train_time:124509ms step_avg:89.96ms +step:1385/1680 train_time:124600ms step_avg:89.96ms +step:1386/1680 train_time:124690ms step_avg:89.96ms +step:1387/1680 train_time:124782ms step_avg:89.97ms +step:1388/1680 train_time:124875ms step_avg:89.97ms +step:1389/1680 train_time:124968ms step_avg:89.97ms +step:1390/1680 train_time:125060ms step_avg:89.97ms +step:1391/1680 train_time:125151ms step_avg:89.97ms +step:1392/1680 train_time:125242ms step_avg:89.97ms +step:1393/1680 train_time:125331ms step_avg:89.97ms +step:1394/1680 train_time:125421ms step_avg:89.97ms +step:1395/1680 train_time:125510ms step_avg:89.97ms +step:1396/1680 train_time:125601ms step_avg:89.97ms +step:1397/1680 train_time:125692ms step_avg:89.97ms +step:1398/1680 train_time:125783ms step_avg:89.97ms +step:1399/1680 train_time:125876ms step_avg:89.98ms +step:1400/1680 train_time:125968ms step_avg:89.98ms +step:1401/1680 train_time:126061ms step_avg:89.98ms +step:1402/1680 train_time:126152ms step_avg:89.98ms +step:1403/1680 train_time:126243ms step_avg:89.98ms +step:1404/1680 train_time:126333ms step_avg:89.98ms +step:1405/1680 train_time:126423ms step_avg:89.98ms +step:1406/1680 train_time:126514ms step_avg:89.98ms +step:1407/1680 train_time:126604ms step_avg:89.98ms +step:1408/1680 train_time:126695ms step_avg:89.98ms +step:1409/1680 train_time:126786ms step_avg:89.98ms +step:1410/1680 train_time:126878ms step_avg:89.98ms +step:1411/1680 train_time:126969ms step_avg:89.99ms +step:1412/1680 train_time:127061ms step_avg:89.99ms +step:1413/1680 train_time:127151ms step_avg:89.99ms +step:1414/1680 train_time:127242ms step_avg:89.99ms +step:1415/1680 train_time:127332ms step_avg:89.99ms +step:1416/1680 train_time:127423ms step_avg:89.99ms +step:1417/1680 train_time:127513ms step_avg:89.99ms +step:1418/1680 train_time:127603ms step_avg:89.99ms +step:1419/1680 train_time:127694ms step_avg:89.99ms +step:1420/1680 train_time:127785ms step_avg:89.99ms +step:1421/1680 train_time:127877ms step_avg:89.99ms +step:1422/1680 train_time:127969ms step_avg:89.99ms +step:1423/1680 train_time:128060ms step_avg:89.99ms +step:1424/1680 train_time:128150ms step_avg:89.99ms +step:1425/1680 train_time:128242ms step_avg:89.99ms +step:1426/1680 train_time:128333ms step_avg:89.99ms +step:1427/1680 train_time:128424ms step_avg:90.00ms +step:1428/1680 train_time:128514ms step_avg:90.00ms +step:1429/1680 train_time:128605ms step_avg:90.00ms +step:1430/1680 train_time:128697ms step_avg:90.00ms +step:1431/1680 train_time:128787ms step_avg:90.00ms +step:1432/1680 train_time:128878ms step_avg:90.00ms +step:1433/1680 train_time:128970ms step_avg:90.00ms +step:1434/1680 train_time:129061ms step_avg:90.00ms +step:1435/1680 train_time:129152ms step_avg:90.00ms +step:1436/1680 train_time:129243ms step_avg:90.00ms +step:1437/1680 train_time:129334ms step_avg:90.00ms +step:1438/1680 train_time:129425ms step_avg:90.00ms +step:1439/1680 train_time:129515ms step_avg:90.00ms +step:1440/1680 train_time:129606ms step_avg:90.00ms +step:1441/1680 train_time:129697ms step_avg:90.01ms +step:1442/1680 train_time:129788ms step_avg:90.01ms +step:1443/1680 train_time:129878ms step_avg:90.01ms +step:1444/1680 train_time:129969ms step_avg:90.01ms +step:1445/1680 
train_time:130061ms step_avg:90.01ms +step:1446/1680 train_time:130152ms step_avg:90.01ms +step:1447/1680 train_time:130243ms step_avg:90.01ms +step:1448/1680 train_time:130333ms step_avg:90.01ms +step:1449/1680 train_time:130425ms step_avg:90.01ms +step:1450/1680 train_time:130515ms step_avg:90.01ms +step:1451/1680 train_time:130606ms step_avg:90.01ms +step:1452/1680 train_time:130696ms step_avg:90.01ms +step:1453/1680 train_time:130786ms step_avg:90.01ms +step:1454/1680 train_time:130876ms step_avg:90.01ms +step:1455/1680 train_time:130967ms step_avg:90.01ms +step:1456/1680 train_time:131059ms step_avg:90.01ms +step:1457/1680 train_time:131150ms step_avg:90.01ms +step:1458/1680 train_time:131241ms step_avg:90.01ms +step:1459/1680 train_time:131332ms step_avg:90.02ms +step:1460/1680 train_time:131425ms step_avg:90.02ms +step:1461/1680 train_time:131516ms step_avg:90.02ms +step:1462/1680 train_time:131607ms step_avg:90.02ms +step:1463/1680 train_time:131698ms step_avg:90.02ms +step:1464/1680 train_time:131788ms step_avg:90.02ms +step:1465/1680 train_time:131879ms step_avg:90.02ms +step:1466/1680 train_time:131969ms step_avg:90.02ms +step:1467/1680 train_time:132060ms step_avg:90.02ms +step:1468/1680 train_time:132151ms step_avg:90.02ms +step:1469/1680 train_time:132243ms step_avg:90.02ms +step:1470/1680 train_time:132333ms step_avg:90.02ms +step:1471/1680 train_time:132424ms step_avg:90.02ms +step:1472/1680 train_time:132516ms step_avg:90.02ms +step:1473/1680 train_time:132606ms step_avg:90.02ms +step:1474/1680 train_time:132697ms step_avg:90.03ms +step:1475/1680 train_time:132788ms step_avg:90.03ms +step:1476/1680 train_time:132879ms step_avg:90.03ms +step:1477/1680 train_time:132969ms step_avg:90.03ms +step:1478/1680 train_time:133060ms step_avg:90.03ms +step:1479/1680 train_time:133151ms step_avg:90.03ms +step:1480/1680 train_time:133243ms step_avg:90.03ms +step:1481/1680 train_time:133333ms step_avg:90.03ms +step:1482/1680 train_time:133425ms step_avg:90.03ms +step:1483/1680 train_time:133517ms step_avg:90.03ms +step:1484/1680 train_time:133608ms step_avg:90.03ms +step:1485/1680 train_time:133699ms step_avg:90.03ms +step:1486/1680 train_time:133789ms step_avg:90.03ms +step:1487/1680 train_time:133881ms step_avg:90.03ms +step:1488/1680 train_time:133971ms step_avg:90.03ms +step:1489/1680 train_time:134062ms step_avg:90.03ms +step:1490/1680 train_time:134152ms step_avg:90.03ms +step:1491/1680 train_time:134243ms step_avg:90.04ms +step:1492/1680 train_time:134333ms step_avg:90.04ms +step:1493/1680 train_time:134424ms step_avg:90.04ms +step:1494/1680 train_time:134516ms step_avg:90.04ms +step:1495/1680 train_time:134607ms step_avg:90.04ms +step:1496/1680 train_time:134698ms step_avg:90.04ms +step:1497/1680 train_time:134788ms step_avg:90.04ms +step:1498/1680 train_time:134879ms step_avg:90.04ms +step:1499/1680 train_time:134969ms step_avg:90.04ms +step:1500/1680 train_time:135059ms step_avg:90.04ms +step:1500/1680 val_loss:3.3126 train_time:135152ms step_avg:90.10ms +step:1501/1680 train_time:135175ms step_avg:90.06ms +step:1502/1680 train_time:135245ms step_avg:90.04ms +step:1503/1680 train_time:135345ms step_avg:90.05ms +step:1504/1680 train_time:135437ms step_avg:90.05ms +step:1505/1680 train_time:135526ms step_avg:90.05ms +step:1506/1680 train_time:135616ms step_avg:90.05ms +step:1507/1680 train_time:135705ms step_avg:90.05ms +step:1508/1680 train_time:135795ms step_avg:90.05ms +step:1509/1680 train_time:135884ms step_avg:90.05ms +step:1510/1680 train_time:135975ms step_avg:90.05ms 
+step:1511/1680 train_time:136065ms step_avg:90.05ms +step:1512/1680 train_time:136159ms step_avg:90.05ms +step:1513/1680 train_time:136252ms step_avg:90.05ms +step:1514/1680 train_time:136345ms step_avg:90.06ms +step:1515/1680 train_time:136436ms step_avg:90.06ms +step:1516/1680 train_time:136527ms step_avg:90.06ms +step:1517/1680 train_time:136618ms step_avg:90.06ms +step:1518/1680 train_time:136707ms step_avg:90.06ms +step:1519/1680 train_time:136798ms step_avg:90.06ms +step:1520/1680 train_time:136887ms step_avg:90.06ms +step:1521/1680 train_time:136977ms step_avg:90.06ms +step:1522/1680 train_time:137067ms step_avg:90.06ms +step:1523/1680 train_time:137160ms step_avg:90.06ms +step:1524/1680 train_time:137252ms step_avg:90.06ms +step:1525/1680 train_time:137343ms step_avg:90.06ms +step:1526/1680 train_time:137434ms step_avg:90.06ms +step:1527/1680 train_time:137525ms step_avg:90.06ms +step:1528/1680 train_time:137616ms step_avg:90.06ms +step:1529/1680 train_time:137706ms step_avg:90.06ms +step:1530/1680 train_time:137797ms step_avg:90.06ms +step:1531/1680 train_time:137887ms step_avg:90.06ms +step:1532/1680 train_time:137978ms step_avg:90.06ms +step:1533/1680 train_time:138069ms step_avg:90.06ms +step:1534/1680 train_time:138161ms step_avg:90.07ms +step:1535/1680 train_time:138251ms step_avg:90.07ms +step:1536/1680 train_time:138343ms step_avg:90.07ms +step:1537/1680 train_time:138434ms step_avg:90.07ms +step:1538/1680 train_time:138524ms step_avg:90.07ms +step:1539/1680 train_time:138615ms step_avg:90.07ms +step:1540/1680 train_time:138705ms step_avg:90.07ms +step:1541/1680 train_time:138796ms step_avg:90.07ms +step:1542/1680 train_time:138886ms step_avg:90.07ms +step:1543/1680 train_time:138977ms step_avg:90.07ms +step:1544/1680 train_time:139070ms step_avg:90.07ms +step:1545/1680 train_time:139159ms step_avg:90.07ms +step:1546/1680 train_time:139250ms step_avg:90.07ms +step:1547/1680 train_time:139341ms step_avg:90.07ms +step:1548/1680 train_time:139433ms step_avg:90.07ms +step:1549/1680 train_time:139523ms step_avg:90.07ms +step:1550/1680 train_time:139614ms step_avg:90.07ms +step:1551/1680 train_time:139704ms step_avg:90.07ms +step:1552/1680 train_time:139795ms step_avg:90.07ms +step:1553/1680 train_time:139885ms step_avg:90.07ms +step:1554/1680 train_time:139976ms step_avg:90.07ms +step:1555/1680 train_time:140066ms step_avg:90.07ms +step:1556/1680 train_time:140158ms step_avg:90.08ms +step:1557/1680 train_time:140249ms step_avg:90.08ms +step:1558/1680 train_time:140339ms step_avg:90.08ms +step:1559/1680 train_time:140430ms step_avg:90.08ms +step:1560/1680 train_time:140521ms step_avg:90.08ms +step:1561/1680 train_time:140611ms step_avg:90.08ms +step:1562/1680 train_time:140702ms step_avg:90.08ms +step:1563/1680 train_time:140792ms step_avg:90.08ms +step:1564/1680 train_time:140882ms step_avg:90.08ms +step:1565/1680 train_time:140972ms step_avg:90.08ms +step:1566/1680 train_time:141064ms step_avg:90.08ms +step:1567/1680 train_time:141156ms step_avg:90.08ms +step:1568/1680 train_time:141247ms step_avg:90.08ms +step:1569/1680 train_time:141339ms step_avg:90.08ms +step:1570/1680 train_time:141430ms step_avg:90.08ms +step:1571/1680 train_time:141521ms step_avg:90.08ms +step:1572/1680 train_time:141612ms step_avg:90.08ms +step:1573/1680 train_time:141702ms step_avg:90.08ms +step:1574/1680 train_time:141793ms step_avg:90.08ms +step:1575/1680 train_time:141884ms step_avg:90.09ms +step:1576/1680 train_time:141974ms step_avg:90.09ms +step:1577/1680 train_time:142064ms step_avg:90.09ms 
+step:1578/1680 train_time:142156ms step_avg:90.09ms +step:1579/1680 train_time:142246ms step_avg:90.09ms +step:1580/1680 train_time:142338ms step_avg:90.09ms +step:1581/1680 train_time:142429ms step_avg:90.09ms +step:1582/1680 train_time:142521ms step_avg:90.09ms +step:1583/1680 train_time:142612ms step_avg:90.09ms +step:1584/1680 train_time:142702ms step_avg:90.09ms +step:1585/1680 train_time:142793ms step_avg:90.09ms +step:1586/1680 train_time:142883ms step_avg:90.09ms +step:1587/1680 train_time:142975ms step_avg:90.09ms +step:1588/1680 train_time:143065ms step_avg:90.09ms +step:1589/1680 train_time:143157ms step_avg:90.09ms +step:1590/1680 train_time:143247ms step_avg:90.09ms +step:1591/1680 train_time:143339ms step_avg:90.09ms +step:1592/1680 train_time:143430ms step_avg:90.09ms +step:1593/1680 train_time:143521ms step_avg:90.09ms +step:1594/1680 train_time:143612ms step_avg:90.10ms +step:1595/1680 train_time:143703ms step_avg:90.10ms +step:1596/1680 train_time:143793ms step_avg:90.10ms +step:1597/1680 train_time:143885ms step_avg:90.10ms +step:1598/1680 train_time:143975ms step_avg:90.10ms +step:1599/1680 train_time:144066ms step_avg:90.10ms +step:1600/1680 train_time:144157ms step_avg:90.10ms +step:1601/1680 train_time:144248ms step_avg:90.10ms +step:1602/1680 train_time:144339ms step_avg:90.10ms +step:1603/1680 train_time:144429ms step_avg:90.10ms +step:1604/1680 train_time:144520ms step_avg:90.10ms +step:1605/1680 train_time:144611ms step_avg:90.10ms +step:1606/1680 train_time:144701ms step_avg:90.10ms +step:1607/1680 train_time:144791ms step_avg:90.10ms +step:1608/1680 train_time:144883ms step_avg:90.10ms +step:1609/1680 train_time:144973ms step_avg:90.10ms +step:1610/1680 train_time:145068ms step_avg:90.10ms +step:1611/1680 train_time:145155ms step_avg:90.10ms +step:1612/1680 train_time:145245ms step_avg:90.10ms +step:1613/1680 train_time:145337ms step_avg:90.10ms +step:1614/1680 train_time:145427ms step_avg:90.10ms +step:1615/1680 train_time:145518ms step_avg:90.10ms +step:1616/1680 train_time:145609ms step_avg:90.10ms +step:1617/1680 train_time:145699ms step_avg:90.10ms +step:1618/1680 train_time:145790ms step_avg:90.11ms +step:1619/1680 train_time:145880ms step_avg:90.11ms +step:1620/1680 train_time:145971ms step_avg:90.11ms +step:1621/1680 train_time:146063ms step_avg:90.11ms +step:1622/1680 train_time:146154ms step_avg:90.11ms +step:1623/1680 train_time:146244ms step_avg:90.11ms +step:1624/1680 train_time:146336ms step_avg:90.11ms +step:1625/1680 train_time:146427ms step_avg:90.11ms +step:1625/1680 val_loss:3.2887 train_time:146520ms step_avg:90.17ms +step:1626/1680 train_time:146543ms step_avg:90.12ms +step:1627/1680 train_time:146613ms step_avg:90.11ms +step:1628/1680 train_time:146712ms step_avg:90.12ms +step:1629/1680 train_time:146808ms step_avg:90.12ms +step:1630/1680 train_time:146896ms step_avg:90.12ms +step:1631/1680 train_time:146986ms step_avg:90.12ms +step:1632/1680 train_time:147076ms step_avg:90.12ms +step:1633/1680 train_time:147165ms step_avg:90.12ms +step:1634/1680 train_time:147255ms step_avg:90.12ms +step:1635/1680 train_time:147345ms step_avg:90.12ms +step:1636/1680 train_time:147435ms step_avg:90.12ms +step:1637/1680 train_time:147525ms step_avg:90.12ms +step:1638/1680 train_time:147619ms step_avg:90.12ms +step:1639/1680 train_time:147712ms step_avg:90.12ms +step:1640/1680 train_time:147804ms step_avg:90.12ms +step:1641/1680 train_time:147896ms step_avg:90.13ms +step:1642/1680 train_time:147987ms step_avg:90.13ms +step:1643/1680 train_time:148077ms 
step_avg:90.13ms +step:1644/1680 train_time:148166ms step_avg:90.13ms +step:1645/1680 train_time:148256ms step_avg:90.13ms +step:1646/1680 train_time:148345ms step_avg:90.12ms +step:1647/1680 train_time:148435ms step_avg:90.12ms +step:1648/1680 train_time:148526ms step_avg:90.12ms +step:1649/1680 train_time:148618ms step_avg:90.13ms +step:1650/1680 train_time:148710ms step_avg:90.13ms +step:1651/1680 train_time:148801ms step_avg:90.13ms +step:1652/1680 train_time:148893ms step_avg:90.13ms +step:1653/1680 train_time:148983ms step_avg:90.13ms +step:1654/1680 train_time:149074ms step_avg:90.13ms +step:1655/1680 train_time:149164ms step_avg:90.13ms +step:1656/1680 train_time:149255ms step_avg:90.13ms +step:1657/1680 train_time:149345ms step_avg:90.13ms +step:1658/1680 train_time:149435ms step_avg:90.13ms +step:1659/1680 train_time:149525ms step_avg:90.13ms +step:1660/1680 train_time:149617ms step_avg:90.13ms +step:1661/1680 train_time:149709ms step_avg:90.13ms +step:1662/1680 train_time:149801ms step_avg:90.13ms +step:1663/1680 train_time:149892ms step_avg:90.13ms +step:1664/1680 train_time:149983ms step_avg:90.13ms +step:1665/1680 train_time:150074ms step_avg:90.13ms +step:1666/1680 train_time:150165ms step_avg:90.14ms +step:1667/1680 train_time:150255ms step_avg:90.14ms +step:1668/1680 train_time:150345ms step_avg:90.13ms +step:1669/1680 train_time:150435ms step_avg:90.14ms +step:1670/1680 train_time:150526ms step_avg:90.14ms +step:1671/1680 train_time:150617ms step_avg:90.14ms +step:1672/1680 train_time:150708ms step_avg:90.14ms +step:1673/1680 train_time:150800ms step_avg:90.14ms +step:1674/1680 train_time:150890ms step_avg:90.14ms +step:1675/1680 train_time:150981ms step_avg:90.14ms +step:1676/1680 train_time:151073ms step_avg:90.14ms +step:1677/1680 train_time:151163ms step_avg:90.14ms +step:1678/1680 train_time:151255ms step_avg:90.14ms +step:1679/1680 train_time:151344ms step_avg:90.14ms +step:1680/1680 train_time:151434ms step_avg:90.14ms +step:1680/1680 val_loss:3.2781 train_time:151527ms step_avg:90.19ms +peak memory allocated: 31255 MiB reserved: 46474 MiB diff --git a/records/092125_DropAttn/e5a48f93-373e-4ff2-903b-5303bf912330.txt b/records/092125_DropAttn/e5a48f93-373e-4ff2-903b-5303bf912330.txt new file mode 100644 index 000000000..1d5802fd0 --- /dev/null +++ b/records/092125_DropAttn/e5a48f93-373e-4ff2-903b-5303bf912330.txt @@ -0,0 +1,3138 @@ +import os +import sys + +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import copy +import glob +import math +import threading +import time +import uuid +from dataclasses import dataclass +from itertools import accumulate +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch + +torch.empty( + 1, device="cuda", requires_grad=True +).backward() # prevents a bug on some systems +import torch._dynamo as dynamo +import torch.distributed as dist +import torch.nn.functional as F + +# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +from torch import Tensor, nn + +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, 
grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, 
GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + 
alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + 
c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
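+        # The update is sharded across ranks in three phases (matching the code below):
+        #   1) reduce_scatter: average the stacked per-group gradients and hand
+        #      each rank ownership of one chunk of them,
+        #   2) local update: momentum + batched Newton-Schulz orthogonalization
+        #      + weight decay / LR, applied only to this rank's chunk,
+        #   3) all_gather: broadcast the updated chunks so every rank finishes
+        #      the step with identical parameters.
+        # The collectives are launched async so communication overlaps with compute.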
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        group_infos = []
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+            if not params:
+                continue
+
+            num_params = len(params)
+            padded_num_params = (
+                (num_params + world_size - 1) // world_size * world_size
+            )
+
+            grads_to_stack = [p.grad for p in params]
+            if padded_num_params > num_params:
+                padding_grad = torch.zeros_like(params[0].grad)
+                grads_to_stack.extend(
+                    [padding_grad] * (padded_num_params - num_params)
+                )
+
+            stacked_grads = torch.stack(grads_to_stack)
+
+            chunk_size = padded_num_params // world_size
+            grad_chunk = torch.empty(
+                (chunk_size, *params[0].grad.shape),
+                dtype=stacked_grads.dtype,
+                device=stacked_grads.device,
+            )
+
+            reduce_future = dist.reduce_scatter_tensor(
+                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
+            ).get_future()
+
+            group_infos.append(
+                {
+                    "params": params,
+                    "grad_chunk": grad_chunk,
+                    "reduce_future": reduce_future,
+                    "chunk_size": chunk_size,
+                    "padded_num_params": padded_num_params,
+                }
+            )
+
+        all_gather_infos = []
+        # Second pass: wait for gradients, compute updates for the local shard of parameters,
+        # and launch all async all_gather operations.
+        for group, info in zip(self.param_groups, group_infos):
+            info["reduce_future"].wait()
+
+            params = info["params"]
+            grad_chunk = info["grad_chunk"]
+            chunk_size = info["chunk_size"]
+            start_idx = rank * chunk_size
+
+            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
+            # This helps in vectorizing operations later.
+            p_example = params[0]  # All params in a group have the same shape.
+            eff_lr_val = (
+                group["lr"]
+                * max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5
+                * getattr(p_example, "lr_mul", 1.0)
+            )
+            eff_weight_decay_val = (
+                group["lr"]
+                * group["weight_decay"]
+                * getattr(p_example, "wd_mul", 1.0)
+            )
+
+            # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
+            # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
+            updated_param_chunk = torch.empty(
+                (chunk_size, *params[0].shape),
+                dtype=params[0].dtype,
+                device=params[0].device,
+            )
+
+            # List to collect update_grad tensors for batched zeropower computation.
+            update_grads_for_zeropower = []
+
+            # Process each parameter in this rank's chunk.
+            for i in range(chunk_size):
+                param_idx = start_idx + i
+
+                if param_idx >= len(params):
+                    # For padding: fill the corresponding part of the updated_param_chunk with zeros.
+                    # These padded entries will not be used by other ranks in the all_gather, but
+                    # initializing them prevents uninitialized memory access issues.
+                    updated_param_chunk[i].zero_()
+                    # Also append a zero tensor for zeropower input if it must be padded.
+                    update_grads_for_zeropower.append(
+                        torch.zeros_like(params[0].grad)
+                    )
+                    continue
+                p = params[param_idx]
+                grad = grad_chunk[i]  # This gradient corresponds to the current parameter p.
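+                # Worked form of the two lerp calls below, with m = group["momentum"]:
+                #   buf    <- buf + (1 - m) * (grad - buf) = m * buf + (1 - m) * grad
+                #   update  = grad + m * (buf - grad)      = (1 - m) * grad + m * buf
+                # i.e. an EMA of the gradient plus a Nesterov-style blend of the
+                # fresh gradient into the update that gets orthogonalized.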
+ state = self.state[p] + + # Initialize momentum buffer if not present + if not state: + state["momentum_buffer"] = torch.zeros_like(grad) + + momentum_buffer = state["momentum_buffer"] + + # Apply momentum update directly to the persistent momentum buffer in-place. + momentum_buffer.lerp_(grad, 1 - group["momentum"]) + + # Compute the actual `update_grad` for zeropower. This creates a new tensor. + update_grad = grad.lerp(momentum_buffer, group["momentum"]) + update_grads_for_zeropower.append(update_grad) + + # Copy the current parameter value into the temporary buffer. + updated_param_chunk[i].copy_(p) + + # Apply weight decay directly to the buffer. + updated_param_chunk[i].mul_(1 - eff_weight_decay_val) + + # Stack the individual `update_grad` tensors for efficient batched zeropower computation. + batched_update_grads = torch.stack(update_grads_for_zeropower) + + # Compute zeropower for the entire chunk in a single, batched call. + + original_shape = batched_update_grads.shape + if batched_update_grads.ndim > 3: + assert batched_update_grads.ndim == 4 + batch = original_shape[0] * original_shape[1] + # Flatten all but the first two dims after batch + d1 = original_shape[2] + d2 = original_shape[3] + batched = batched_update_grads.view(batch, d1, d2) + v_chunk = newton_schulz_triton(batched) + v_chunk = v_chunk.view(original_shape) + else: + v_chunk = newton_schulz_triton(batched_update_grads) + + # Add the computed zeropower update to the parameters in the buffer. + # This loop applies the zeropower output (v_chunk) to the `updated_param_chunk` buffer. + for i in range(chunk_size): + param_idx = start_idx + i + if param_idx >= len(params): # Skip padded entries again. + continue + + # Add the computed zeropower update to the parameter in the buffer. + updated_param_chunk[i].add_(v_chunk[i], alpha=-eff_lr_val) + + stacked_params = torch.empty( + (info["padded_num_params"], *params[0].shape), + dtype=params[0].dtype, + device=params[0].device, + ) + gather_future = dist.all_gather_into_tensor( + stacked_params, updated_param_chunk, async_op=True + ).get_future() + + all_gather_infos.append( + { + "gather_future": gather_future, + "stacked_params": stacked_params, + "orig_params": params, + } + ) + + # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. 
+ for info in all_gather_infos: + info["gather_future"].wait() + stacked_params = info["stacked_params"] + orig_params = info["orig_params"] + + unstacked_params = torch.unbind(stacked_params) + for i, p in enumerate(orig_params): + p.copy_(unstacked_params[i], non_blocking=True) + + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state["step"] = torch.tensor( + 0, dtype=torch.int64, device=p.device + ) + state["exp_avg"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + state["exp_avg_sq"] = torch.zeros( + p_slice.shape, + dtype=torch.bfloat16, + device=p_slice.device, + ) + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + t = state["step"] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def 
reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +# yarn implementation @classiclarryd +class Yarn(nn.Module): + def __init__(self, head_dim, max_seq_len): + super().__init__() + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.reset() + + def reset(self): + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) + theta = torch.outer(t, angular_freq) + self.cos = nn.Buffer( + theta.cos().to(torch.bfloat16), persistent=False + ) + self.sin = nn.Buffer( + theta.sin().to(torch.bfloat16), persistent=False + ) + self.angular_freq = angular_freq + # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.1 + + def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): + rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) + scaling_factor = old_window / new_window + interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) + self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) + t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) + theta = torch.outer(t, self.angular_freq) + self.cos.copy_(theta.cos()) + self.sin.copy_(theta.sin()) + self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 + +def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): + assert cos.size(0) >= x_BTHD.size(-3) + cos, sin = ( + cos[None, : x_BTHD.size(-3), None, :], + sin[None, : x_BTHD.size(-3), None, :], + ) + x1, x2 = x_BTHD.chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3) + +@dataclass +class AttnArgs: + ve: torch.Tensor + sa_lambdas: torch.Tensor + seqlens: torch.Tensor + bm_size: int + cos: torch.Tensor + sin: torch.Tensor + attn_scale: float + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate = CastedLinear(12, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, attn_args: AttnArgs): + B, 
T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + # unpack attention args + cos, sin = attn_args.cos, attn_args.sin + ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas + seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = rotary(q, cos, sin), rotary(k, cos, sin) + if ve is not None: + v = sa_lambdas[0] * v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = sa_lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 and layer_idx !=0 else None + # skip MLP blocks for first MLP layer by @EmelyanenkoK + self.mlp = MLP(dim) if layer_idx != 0 else None + + def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), attn_args) + if self.mlp is not None: + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + self.smear_gate = CastedLinear(12, 1) + 
self.smear_gate.weight.detach().zero_() + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) + self.yarn = Yarn(head_dim, max_seq_len) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 6) % dist.get_world_size() + self.scalars = nn.Parameter( + torch.cat( + [ + -1.5 + * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 + *[ + torch.tensor([1.0, 0.0]) for _ in range(num_layers) + ], # block lambdas + *[ + torch.tensor([0.5, 0.5]) for _ in range(num_layers) + ], # SA lambdas + torch.zeros(num_layers), #extra zeros params for smear_lambda + torch.ones(pad), + ] + ) + ) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + final_bm = ws_final_layer * args.block_size + first_bm = 1 * args.block_size + bm_sizes = [None, short_bm, short_bm, short_bm, final_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, final_bm] + assert len(bm_sizes) == len(self.blocks) + + x = self.embed(input_seq) + + # smear token embed forward 1 position @classiclarryd + smear_lambda = self.scalars[5 * len(self.blocks)] + smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) + x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) + x = x0 = norm(x[None]) + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(1,len(self.blocks)): + attn_args = AttnArgs( + ve=ve[i], + sa_lambdas=sa_lambdas[i], + seqlens=seqlens, + bm_size=bm_sizes[i], + cos=self.yarn.cos, + sin=self.yarn.sin, + attn_scale=self.yarn.attn_scale + ) + if i >= n and i<11: + gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) + x = x + gate * skip_connections.pop() + x = self.blocks[i](x, x0, lambdas[i], attn_args) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): + # Precompute BOS positions once per shard + self.tokens=tokens + self.size = tokens.numel() + self.quickload = quickload + if quickload: + # only scan first 4 million tokens, then kickoff async thread to scan rest + self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.thread = None + self.ready = threading.Event() + self.start() + else: + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + 
self.batch_iter = 0
+
+    def _load(self):
+        self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+            self.bos_idx = self.bos_idx_async
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        # if quickload was used, repoint to the full dataset after 5 batches
+        if self.quickload and self.batch_iter == 5:
+            self.get()
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+        self.batch_iter += 1
+        return starts, ends
+
+class DataPreloader:
+    # Helper for asynchronously loading next shard and indexing bos tokens
+    def __init__(self, file_iter, world_size: int = 1):
+        self.file_iter = file_iter
+        self.world_size = world_size
+        self.thread = None
+        self.data = None
+        self.ready = threading.Event()
+
+    def _load(self):
+        tokens = _load_data_shard(next(self.file_iter))
+        self.data = (tokens, BOSFinder(tokens, self.world_size))
+        self.ready.set()
+
+    def start(self):
+        self.ready.clear()
+        self.thread = threading.Thread(target=self._load)
+        self.thread.start()
+
+    def get(self):
+        if self.thread:
+            self.ready.wait()
+            self.thread.join()
+        return self.data
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning-of-Sequence token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files)  # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    if align_to_bos:
+        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
+        preloader = DataPreloader(file_iter, world_size)
+        preloader.start()
+    else:
+        pos = 0  # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128)  # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
+                tokens, finder = preloader.get()
+                preloader.start()
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1  # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens):  # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "num_tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
+    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1640  # number of iterations to run
+    iteration_extension: int = 40
+    cooldown_frac: float = 0.5  # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = f"v1/{uuid.uuid4()}"
+    val_loss_every: int = 125  # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+    ws_validate: int = 13  # increase final validation ws @classiclarryd
+    ws_validate_final_layer: int = 20  # final layer shows no degradation with context length
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0)  # this process will do logging, checkpointing etc.
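+
+# A minimal sketch of the BOS-packing rule implemented by BOSFinder.next_batch:
+# each sequence starts at a BOS token and ends at whichever comes first, the
+# next BOS, the max_seq_len cap, or the rank's remaining token budget (+1 for
+# the shifted targets). `_pack_one_sequence` is an illustrative helper only;
+# it is not called by the training code.
+def _pack_one_sequence(bos_idx: list[int], i: int, shard_size: int, max_seq_len: int, budget_left: int) -> tuple[int, int]:
+    start = bos_idx[i]
+    next_bos = bos_idx[i + 1] if i + 1 < len(bos_idx) else shard_size  # end of shard if no later BOS
+    end = min(next_bos, start + max_seq_len, start + budget_left + 1)
+    return start, end
+
+# e.g. bos_idx = [0, 5, 900], shard_size = 1000, max_seq_len = 512:
+#   _pack_one_sequence(bos_idx, 1, 1000, 512, 300) == (5, 306)  # budget-capped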
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + head_dim=128, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, (nn.Embedding, nn.Linear)): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] +smear_gate_params = [p for n, p in model.named_parameters() if "smear" in n] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam( + scalar_params + head_params + embed_params, + lr=0.008, + betas=(0.8, 0.95), + eps=1e-8, + weight_decay=0.0, +) +optimizer2 = Muon(hidden_matrix_params+smear_gate_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = min(0.9999,step / args.num_iterations) + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + if step == args.num_iterations+args.iteration_extension: + return args.ws_validate, args.ws_validate_final_layer + x = min(step / (1 + args.num_iterations),0.9999) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +ws=args.ws_schedule[0] +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + new_ws = args.ws_schedule[step % len(args.ws_schedule)] # 
each window size is a new graph, need to warm up each with YaRN params + if new_ws > ws: + model.yarn.apply(ws, new_ws) + ws = new_ws + elif new_ws 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250724+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Sun Sep 21 22:30:08 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 42C P0 124W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 37C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 38C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 60331 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 60332 C /usr/bin/python3 614MiB | +| 0 N/A N/A 60333 C /usr/bin/python3 614MiB | +| 0 N/A N/A 60334 C /usr/bin/python3 614MiB | +| 0 N/A N/A 60335 C /usr/bin/python3 614MiB | +| 0 N/A N/A 60336 C /usr/bin/python3 614MiB | +| 0 N/A N/A 60337 C /usr/bin/python3 614MiB | +| 0 N/A N/A 60338 C /usr/bin/python3 614MiB | +| 1 N/A N/A 60332 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 60333 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 60334 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 60335 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 60336 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 60337 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 60338 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1680 val_loss:10.8258 train_time:0ms step_avg:0.03ms +step:1/1680 train_time:152ms step_avg:152.07ms +step:2/1680 train_time:175ms step_avg:87.64ms +step:3/1680 train_time:237ms step_avg:79.15ms +step:4/1680 train_time:325ms step_avg:81.15ms +step:5/1680 train_time:413ms step_avg:82.57ms +step:6/1680 train_time:504ms step_avg:83.98ms +step:7/1680 train_time:589ms step_avg:84.20ms +step:8/1680 train_time:678ms 
step_avg:84.73ms +step:9/1680 train_time:766ms step_avg:85.09ms +step:10/1680 train_time:854ms step_avg:85.36ms +step:11/1680 train_time:942ms step_avg:85.61ms +step:12/1680 train_time:1031ms step_avg:85.94ms +step:13/1680 train_time:1124ms step_avg:86.43ms +step:14/1680 train_time:1215ms step_avg:86.76ms +step:15/1680 train_time:1304ms step_avg:86.93ms +step:16/1680 train_time:1394ms step_avg:87.11ms +step:17/1680 train_time:1483ms step_avg:87.24ms +step:18/1680 train_time:1572ms step_avg:87.32ms +step:19/1680 train_time:1661ms step_avg:87.40ms +step:20/1680 train_time:1750ms step_avg:87.50ms +step:21/1680 train_time:1839ms step_avg:87.57ms +step:22/1680 train_time:1928ms step_avg:87.65ms +step:23/1680 train_time:2019ms step_avg:87.77ms +step:24/1680 train_time:2111ms step_avg:87.95ms +step:25/1680 train_time:2201ms step_avg:88.05ms +step:26/1680 train_time:2294ms step_avg:88.24ms +step:27/1680 train_time:2381ms step_avg:88.20ms +step:28/1680 train_time:2471ms step_avg:88.25ms +step:29/1680 train_time:2560ms step_avg:88.26ms +step:30/1680 train_time:2649ms step_avg:88.29ms +step:31/1680 train_time:2737ms step_avg:88.30ms +step:32/1680 train_time:2826ms step_avg:88.32ms +step:33/1680 train_time:2915ms step_avg:88.34ms +step:34/1680 train_time:3005ms step_avg:88.38ms +step:35/1680 train_time:3095ms step_avg:88.42ms +step:36/1680 train_time:3185ms step_avg:88.46ms +step:37/1680 train_time:3274ms step_avg:88.49ms +step:38/1680 train_time:3364ms step_avg:88.52ms +step:39/1680 train_time:3454ms step_avg:88.55ms +step:40/1680 train_time:3543ms step_avg:88.57ms +step:41/1680 train_time:3632ms step_avg:88.58ms +step:42/1680 train_time:3721ms step_avg:88.59ms +step:43/1680 train_time:3810ms step_avg:88.60ms +step:44/1680 train_time:3900ms step_avg:88.63ms +step:45/1680 train_time:3989ms step_avg:88.65ms +step:46/1680 train_time:4079ms step_avg:88.67ms +step:47/1680 train_time:4169ms step_avg:88.69ms +step:48/1680 train_time:4258ms step_avg:88.70ms +step:49/1680 train_time:4347ms step_avg:88.71ms +step:50/1680 train_time:4437ms step_avg:88.74ms +step:51/1680 train_time:4526ms step_avg:88.75ms +step:52/1680 train_time:4616ms step_avg:88.77ms +step:53/1680 train_time:4705ms step_avg:88.77ms +step:54/1680 train_time:4794ms step_avg:88.77ms +step:55/1680 train_time:4882ms step_avg:88.76ms +step:56/1680 train_time:4971ms step_avg:88.76ms +step:57/1680 train_time:5060ms step_avg:88.77ms +step:58/1680 train_time:5150ms step_avg:88.79ms +step:59/1680 train_time:5239ms step_avg:88.79ms +step:60/1680 train_time:5328ms step_avg:88.79ms +step:61/1680 train_time:5417ms step_avg:88.80ms +step:62/1680 train_time:5506ms step_avg:88.81ms +step:63/1680 train_time:5595ms step_avg:88.81ms +step:64/1680 train_time:5684ms step_avg:88.81ms +step:65/1680 train_time:5773ms step_avg:88.81ms +step:66/1680 train_time:5861ms step_avg:88.80ms +step:67/1680 train_time:5951ms step_avg:88.82ms +step:68/1680 train_time:6040ms step_avg:88.82ms +step:69/1680 train_time:6128ms step_avg:88.81ms +step:70/1680 train_time:6217ms step_avg:88.81ms +step:71/1680 train_time:6305ms step_avg:88.81ms +step:72/1680 train_time:6394ms step_avg:88.81ms +step:73/1680 train_time:6483ms step_avg:88.81ms +step:74/1680 train_time:6573ms step_avg:88.82ms +step:75/1680 train_time:6662ms step_avg:88.82ms +step:76/1680 train_time:6751ms step_avg:88.83ms +step:77/1680 train_time:6839ms step_avg:88.82ms +step:78/1680 train_time:6928ms step_avg:88.81ms +step:79/1680 train_time:7016ms step_avg:88.82ms +step:80/1680 train_time:7106ms step_avg:88.83ms +step:81/1680 
train_time:7195ms step_avg:88.83ms +step:82/1680 train_time:7284ms step_avg:88.83ms +step:83/1680 train_time:7373ms step_avg:88.84ms +step:84/1680 train_time:7462ms step_avg:88.84ms +step:85/1680 train_time:7551ms step_avg:88.83ms +step:86/1680 train_time:7641ms step_avg:88.85ms +step:87/1680 train_time:7730ms step_avg:88.84ms +step:88/1680 train_time:7819ms step_avg:88.85ms +step:89/1680 train_time:7908ms step_avg:88.85ms +step:90/1680 train_time:7997ms step_avg:88.86ms +step:91/1680 train_time:8087ms step_avg:88.86ms +step:92/1680 train_time:8177ms step_avg:88.88ms +step:93/1680 train_time:8265ms step_avg:88.87ms +step:94/1680 train_time:8356ms step_avg:88.89ms +step:95/1680 train_time:8445ms step_avg:88.89ms +step:96/1680 train_time:8534ms step_avg:88.90ms +step:97/1680 train_time:8623ms step_avg:88.90ms +step:98/1680 train_time:8712ms step_avg:88.90ms +step:99/1680 train_time:8801ms step_avg:88.90ms +step:100/1680 train_time:8889ms step_avg:88.89ms +step:101/1680 train_time:8979ms step_avg:88.90ms +step:102/1680 train_time:9068ms step_avg:88.90ms +step:103/1680 train_time:9158ms step_avg:88.91ms +step:104/1680 train_time:9246ms step_avg:88.91ms +step:105/1680 train_time:9336ms step_avg:88.92ms +step:106/1680 train_time:9426ms step_avg:88.92ms +step:107/1680 train_time:9515ms step_avg:88.93ms +step:108/1680 train_time:9605ms step_avg:88.93ms +step:109/1680 train_time:9693ms step_avg:88.93ms +step:110/1680 train_time:9782ms step_avg:88.93ms +step:111/1680 train_time:9871ms step_avg:88.93ms +step:112/1680 train_time:9960ms step_avg:88.93ms +step:113/1680 train_time:10049ms step_avg:88.93ms +step:114/1680 train_time:10138ms step_avg:88.93ms +step:115/1680 train_time:10227ms step_avg:88.93ms +step:116/1680 train_time:10316ms step_avg:88.93ms +step:117/1680 train_time:10404ms step_avg:88.92ms +step:118/1680 train_time:10494ms step_avg:88.93ms +step:119/1680 train_time:10583ms step_avg:88.93ms +step:120/1680 train_time:10672ms step_avg:88.93ms +step:121/1680 train_time:10760ms step_avg:88.92ms +step:122/1680 train_time:10848ms step_avg:88.92ms +step:123/1680 train_time:10937ms step_avg:88.92ms +step:124/1680 train_time:11027ms step_avg:88.92ms +step:125/1680 train_time:11116ms step_avg:88.93ms +step:125/1680 val_loss:4.3255 train_time:11206ms step_avg:89.65ms +step:126/1680 train_time:11229ms step_avg:89.12ms +step:127/1680 train_time:11298ms step_avg:88.96ms +step:128/1680 train_time:11395ms step_avg:89.03ms +step:129/1680 train_time:11488ms step_avg:89.06ms +step:130/1680 train_time:11578ms step_avg:89.06ms +step:131/1680 train_time:11666ms step_avg:89.06ms +step:132/1680 train_time:11754ms step_avg:89.05ms +step:133/1680 train_time:11842ms step_avg:89.04ms +step:134/1680 train_time:11930ms step_avg:89.03ms +step:135/1680 train_time:12017ms step_avg:89.02ms +step:136/1680 train_time:12105ms step_avg:89.01ms +step:137/1680 train_time:12195ms step_avg:89.01ms +step:138/1680 train_time:12286ms step_avg:89.03ms +step:139/1680 train_time:12376ms step_avg:89.03ms +step:140/1680 train_time:12466ms step_avg:89.04ms +step:141/1680 train_time:12555ms step_avg:89.04ms +step:142/1680 train_time:12645ms step_avg:89.05ms +step:143/1680 train_time:12733ms step_avg:89.05ms +step:144/1680 train_time:12822ms step_avg:89.04ms +step:145/1680 train_time:12910ms step_avg:89.04ms +step:146/1680 train_time:12999ms step_avg:89.03ms +step:147/1680 train_time:13088ms step_avg:89.03ms +step:148/1680 train_time:13177ms step_avg:89.03ms +step:149/1680 train_time:13266ms step_avg:89.03ms +step:150/1680 train_time:13355ms 
step_avg:89.03ms +step:151/1680 train_time:13445ms step_avg:89.04ms +step:152/1680 train_time:13534ms step_avg:89.04ms +step:153/1680 train_time:13624ms step_avg:89.04ms +step:154/1680 train_time:13712ms step_avg:89.04ms +step:155/1680 train_time:13802ms step_avg:89.04ms +step:156/1680 train_time:13889ms step_avg:89.04ms +step:157/1680 train_time:13979ms step_avg:89.04ms +step:158/1680 train_time:14067ms step_avg:89.03ms +step:159/1680 train_time:14156ms step_avg:89.03ms +step:160/1680 train_time:14245ms step_avg:89.03ms +step:161/1680 train_time:14334ms step_avg:89.03ms +step:162/1680 train_time:14424ms step_avg:89.04ms +step:163/1680 train_time:14513ms step_avg:89.04ms +step:164/1680 train_time:14603ms step_avg:89.04ms +step:165/1680 train_time:14692ms step_avg:89.04ms +step:166/1680 train_time:14781ms step_avg:89.04ms +step:167/1680 train_time:14869ms step_avg:89.04ms +step:168/1680 train_time:14958ms step_avg:89.04ms +step:169/1680 train_time:15047ms step_avg:89.03ms +step:170/1680 train_time:15137ms step_avg:89.04ms +step:171/1680 train_time:15225ms step_avg:89.04ms +step:172/1680 train_time:15315ms step_avg:89.04ms +step:173/1680 train_time:15404ms step_avg:89.04ms +step:174/1680 train_time:15494ms step_avg:89.04ms +step:175/1680 train_time:15583ms step_avg:89.04ms +step:176/1680 train_time:15672ms step_avg:89.04ms +step:177/1680 train_time:15761ms step_avg:89.04ms +step:178/1680 train_time:15850ms step_avg:89.05ms +step:179/1680 train_time:15939ms step_avg:89.04ms +step:180/1680 train_time:16027ms step_avg:89.04ms +step:181/1680 train_time:16116ms step_avg:89.04ms +step:182/1680 train_time:16206ms step_avg:89.04ms +step:183/1680 train_time:16295ms step_avg:89.05ms +step:184/1680 train_time:16384ms step_avg:89.04ms +step:185/1680 train_time:16473ms step_avg:89.05ms +step:186/1680 train_time:16563ms step_avg:89.05ms +step:187/1680 train_time:16652ms step_avg:89.05ms +step:188/1680 train_time:16742ms step_avg:89.05ms +step:189/1680 train_time:16831ms step_avg:89.05ms +step:190/1680 train_time:16920ms step_avg:89.05ms +step:191/1680 train_time:17008ms step_avg:89.05ms +step:192/1680 train_time:17097ms step_avg:89.04ms +step:193/1680 train_time:17186ms step_avg:89.05ms +step:194/1680 train_time:17275ms step_avg:89.05ms +step:195/1680 train_time:17364ms step_avg:89.05ms +step:196/1680 train_time:17453ms step_avg:89.05ms +step:197/1680 train_time:17542ms step_avg:89.05ms +step:198/1680 train_time:17631ms step_avg:89.05ms +step:199/1680 train_time:17722ms step_avg:89.05ms +step:200/1680 train_time:17811ms step_avg:89.05ms +step:201/1680 train_time:17900ms step_avg:89.06ms +step:202/1680 train_time:17989ms step_avg:89.05ms +step:203/1680 train_time:18078ms step_avg:89.05ms +step:204/1680 train_time:18166ms step_avg:89.05ms +step:205/1680 train_time:18255ms step_avg:89.05ms +step:206/1680 train_time:18343ms step_avg:89.05ms +step:207/1680 train_time:18433ms step_avg:89.05ms +step:208/1680 train_time:18523ms step_avg:89.05ms +step:209/1680 train_time:18611ms step_avg:89.05ms +step:210/1680 train_time:18700ms step_avg:89.05ms +step:211/1680 train_time:18789ms step_avg:89.05ms +step:212/1680 train_time:18878ms step_avg:89.05ms +step:213/1680 train_time:18966ms step_avg:89.04ms +step:214/1680 train_time:19055ms step_avg:89.04ms +step:215/1680 train_time:19144ms step_avg:89.04ms +step:216/1680 train_time:19233ms step_avg:89.04ms +step:217/1680 train_time:19323ms step_avg:89.04ms +step:218/1680 train_time:19412ms step_avg:89.04ms +step:219/1680 train_time:19502ms step_avg:89.05ms +step:220/1680 
train_time:19591ms step_avg:89.05ms +step:221/1680 train_time:19680ms step_avg:89.05ms +step:222/1680 train_time:19769ms step_avg:89.05ms +step:223/1680 train_time:19858ms step_avg:89.05ms +step:224/1680 train_time:19947ms step_avg:89.05ms +step:225/1680 train_time:20036ms step_avg:89.05ms +step:226/1680 train_time:20125ms step_avg:89.05ms +step:227/1680 train_time:20214ms step_avg:89.05ms +step:228/1680 train_time:20303ms step_avg:89.05ms +step:229/1680 train_time:20392ms step_avg:89.05ms +step:230/1680 train_time:20481ms step_avg:89.05ms +step:231/1680 train_time:20570ms step_avg:89.05ms +step:232/1680 train_time:20659ms step_avg:89.05ms +step:233/1680 train_time:20748ms step_avg:89.05ms +step:234/1680 train_time:20836ms step_avg:89.04ms +step:235/1680 train_time:20926ms step_avg:89.05ms +step:236/1680 train_time:21015ms step_avg:89.05ms +step:237/1680 train_time:21105ms step_avg:89.05ms +step:238/1680 train_time:21195ms step_avg:89.05ms +step:239/1680 train_time:21283ms step_avg:89.05ms +step:240/1680 train_time:21372ms step_avg:89.05ms +step:241/1680 train_time:21461ms step_avg:89.05ms +step:242/1680 train_time:21551ms step_avg:89.05ms +step:243/1680 train_time:21641ms step_avg:89.06ms +step:244/1680 train_time:21729ms step_avg:89.05ms +step:245/1680 train_time:21819ms step_avg:89.06ms +step:246/1680 train_time:21908ms step_avg:89.06ms +step:247/1680 train_time:21997ms step_avg:89.06ms +step:248/1680 train_time:22086ms step_avg:89.06ms +step:249/1680 train_time:22175ms step_avg:89.06ms +step:250/1680 train_time:22263ms step_avg:89.05ms +step:250/1680 val_loss:3.9772 train_time:22353ms step_avg:89.41ms +step:251/1680 train_time:22376ms step_avg:89.15ms +step:252/1680 train_time:22446ms step_avg:89.07ms +step:253/1680 train_time:22541ms step_avg:89.09ms +step:254/1680 train_time:22631ms step_avg:89.10ms +step:255/1680 train_time:22719ms step_avg:89.09ms +step:256/1680 train_time:22806ms step_avg:89.09ms +step:257/1680 train_time:22894ms step_avg:89.08ms +step:258/1680 train_time:22982ms step_avg:89.08ms +step:259/1680 train_time:23070ms step_avg:89.07ms +step:260/1680 train_time:23157ms step_avg:89.07ms +step:261/1680 train_time:23245ms step_avg:89.06ms +step:262/1680 train_time:23335ms step_avg:89.07ms +step:263/1680 train_time:23427ms step_avg:89.08ms +step:264/1680 train_time:23518ms step_avg:89.08ms +step:265/1680 train_time:23608ms step_avg:89.09ms +step:266/1680 train_time:23697ms step_avg:89.09ms +step:267/1680 train_time:23786ms step_avg:89.09ms +step:268/1680 train_time:23875ms step_avg:89.09ms +step:269/1680 train_time:23964ms step_avg:89.08ms +step:270/1680 train_time:24052ms step_avg:89.08ms +step:271/1680 train_time:24140ms step_avg:89.08ms +step:272/1680 train_time:24228ms step_avg:89.07ms +step:273/1680 train_time:24317ms step_avg:89.07ms +step:274/1680 train_time:24408ms step_avg:89.08ms +step:275/1680 train_time:24497ms step_avg:89.08ms +step:276/1680 train_time:24588ms step_avg:89.09ms +step:277/1680 train_time:24677ms step_avg:89.09ms +step:278/1680 train_time:24766ms step_avg:89.08ms +step:279/1680 train_time:24855ms step_avg:89.08ms +step:280/1680 train_time:24943ms step_avg:89.08ms +step:281/1680 train_time:25030ms step_avg:89.07ms +step:282/1680 train_time:25118ms step_avg:89.07ms +step:283/1680 train_time:25206ms step_avg:89.07ms +step:284/1680 train_time:25295ms step_avg:89.07ms +step:285/1680 train_time:25385ms step_avg:89.07ms +step:286/1680 train_time:25474ms step_avg:89.07ms +step:287/1680 train_time:25564ms step_avg:89.07ms +step:288/1680 train_time:25653ms 
step_avg:89.07ms +step:289/1680 train_time:25742ms step_avg:89.07ms +step:290/1680 train_time:25832ms step_avg:89.08ms +step:291/1680 train_time:25920ms step_avg:89.07ms +step:292/1680 train_time:26008ms step_avg:89.07ms +step:293/1680 train_time:26098ms step_avg:89.07ms +step:294/1680 train_time:26186ms step_avg:89.07ms +step:295/1680 train_time:26274ms step_avg:89.07ms +step:296/1680 train_time:26363ms step_avg:89.07ms +step:297/1680 train_time:26453ms step_avg:89.07ms +step:298/1680 train_time:26542ms step_avg:89.07ms +step:299/1680 train_time:26631ms step_avg:89.07ms +step:300/1680 train_time:26721ms step_avg:89.07ms +step:301/1680 train_time:26809ms step_avg:89.07ms +step:302/1680 train_time:26898ms step_avg:89.07ms +step:303/1680 train_time:26986ms step_avg:89.06ms +step:304/1680 train_time:27075ms step_avg:89.06ms +step:305/1680 train_time:27164ms step_avg:89.06ms +step:306/1680 train_time:27253ms step_avg:89.06ms +step:307/1680 train_time:27341ms step_avg:89.06ms +step:308/1680 train_time:27431ms step_avg:89.06ms +step:309/1680 train_time:27520ms step_avg:89.06ms +step:310/1680 train_time:27610ms step_avg:89.06ms +step:311/1680 train_time:27699ms step_avg:89.06ms +step:312/1680 train_time:27788ms step_avg:89.06ms +step:313/1680 train_time:27877ms step_avg:89.06ms +step:314/1680 train_time:27965ms step_avg:89.06ms +step:315/1680 train_time:28054ms step_avg:89.06ms +step:316/1680 train_time:28143ms step_avg:89.06ms +step:317/1680 train_time:28232ms step_avg:89.06ms +step:318/1680 train_time:28321ms step_avg:89.06ms +step:319/1680 train_time:28410ms step_avg:89.06ms +step:320/1680 train_time:28498ms step_avg:89.06ms +step:321/1680 train_time:28588ms step_avg:89.06ms +step:322/1680 train_time:28677ms step_avg:89.06ms +step:323/1680 train_time:28766ms step_avg:89.06ms +step:324/1680 train_time:28856ms step_avg:89.06ms +step:325/1680 train_time:28944ms step_avg:89.06ms +step:326/1680 train_time:29032ms step_avg:89.06ms +step:327/1680 train_time:29121ms step_avg:89.05ms +step:328/1680 train_time:29210ms step_avg:89.05ms +step:329/1680 train_time:29298ms step_avg:89.05ms +step:330/1680 train_time:29387ms step_avg:89.05ms +step:331/1680 train_time:29476ms step_avg:89.05ms +step:332/1680 train_time:29565ms step_avg:89.05ms +step:333/1680 train_time:29655ms step_avg:89.05ms +step:334/1680 train_time:29744ms step_avg:89.05ms +step:335/1680 train_time:29833ms step_avg:89.05ms +step:336/1680 train_time:29922ms step_avg:89.05ms +step:337/1680 train_time:30012ms step_avg:89.06ms +step:338/1680 train_time:30101ms step_avg:89.06ms +step:339/1680 train_time:30189ms step_avg:89.05ms +step:340/1680 train_time:30278ms step_avg:89.05ms +step:341/1680 train_time:30367ms step_avg:89.05ms +step:342/1680 train_time:30456ms step_avg:89.05ms +step:343/1680 train_time:30545ms step_avg:89.05ms +step:344/1680 train_time:30634ms step_avg:89.05ms +step:345/1680 train_time:30723ms step_avg:89.05ms +step:346/1680 train_time:30812ms step_avg:89.05ms +step:347/1680 train_time:30901ms step_avg:89.05ms +step:348/1680 train_time:30990ms step_avg:89.05ms +step:349/1680 train_time:31078ms step_avg:89.05ms +step:350/1680 train_time:31167ms step_avg:89.05ms +step:351/1680 train_time:31257ms step_avg:89.05ms +step:352/1680 train_time:31346ms step_avg:89.05ms +step:353/1680 train_time:31435ms step_avg:89.05ms +step:354/1680 train_time:31524ms step_avg:89.05ms +step:355/1680 train_time:31614ms step_avg:89.05ms +step:356/1680 train_time:31704ms step_avg:89.06ms +step:357/1680 train_time:31794ms step_avg:89.06ms +step:358/1680 
train_time:31883ms step_avg:89.06ms +step:359/1680 train_time:31972ms step_avg:89.06ms +step:360/1680 train_time:32061ms step_avg:89.06ms +step:361/1680 train_time:32149ms step_avg:89.06ms +step:362/1680 train_time:32237ms step_avg:89.05ms +step:363/1680 train_time:32326ms step_avg:89.05ms +step:364/1680 train_time:32414ms step_avg:89.05ms +step:365/1680 train_time:32503ms step_avg:89.05ms +step:366/1680 train_time:32592ms step_avg:89.05ms +step:367/1680 train_time:32681ms step_avg:89.05ms +step:368/1680 train_time:32770ms step_avg:89.05ms +step:369/1680 train_time:32859ms step_avg:89.05ms +step:370/1680 train_time:32949ms step_avg:89.05ms +step:371/1680 train_time:33038ms step_avg:89.05ms +step:372/1680 train_time:33127ms step_avg:89.05ms +step:373/1680 train_time:33216ms step_avg:89.05ms +step:374/1680 train_time:33304ms step_avg:89.05ms +step:375/1680 train_time:33393ms step_avg:89.05ms +step:375/1680 val_loss:3.8210 train_time:33484ms step_avg:89.29ms +step:376/1680 train_time:33506ms step_avg:89.11ms +step:377/1680 train_time:33576ms step_avg:89.06ms +step:378/1680 train_time:33678ms step_avg:89.09ms +step:379/1680 train_time:33764ms step_avg:89.09ms +step:380/1680 train_time:33854ms step_avg:89.09ms +step:381/1680 train_time:33942ms step_avg:89.09ms +step:382/1680 train_time:34030ms step_avg:89.09ms +step:383/1680 train_time:34119ms step_avg:89.08ms +step:384/1680 train_time:34206ms step_avg:89.08ms +step:385/1680 train_time:34294ms step_avg:89.07ms +step:386/1680 train_time:34382ms step_avg:89.07ms +step:387/1680 train_time:34470ms step_avg:89.07ms +step:388/1680 train_time:34560ms step_avg:89.07ms +step:389/1680 train_time:34651ms step_avg:89.08ms +step:390/1680 train_time:34742ms step_avg:89.08ms +step:391/1680 train_time:34831ms step_avg:89.08ms +step:392/1680 train_time:34921ms step_avg:89.08ms +step:393/1680 train_time:35010ms step_avg:89.08ms +step:394/1680 train_time:35099ms step_avg:89.08ms +step:395/1680 train_time:35189ms step_avg:89.09ms +step:396/1680 train_time:35275ms step_avg:89.08ms +step:397/1680 train_time:35363ms step_avg:89.07ms +step:398/1680 train_time:35451ms step_avg:89.07ms +step:399/1680 train_time:35541ms step_avg:89.08ms +step:400/1680 train_time:35631ms step_avg:89.08ms +step:401/1680 train_time:35722ms step_avg:89.08ms +step:402/1680 train_time:35811ms step_avg:89.08ms +step:403/1680 train_time:35900ms step_avg:89.08ms +step:404/1680 train_time:35989ms step_avg:89.08ms +step:405/1680 train_time:36078ms step_avg:89.08ms +step:406/1680 train_time:36167ms step_avg:89.08ms +step:407/1680 train_time:36255ms step_avg:89.08ms +step:408/1680 train_time:36343ms step_avg:89.08ms +step:409/1680 train_time:36431ms step_avg:89.07ms +step:410/1680 train_time:36521ms step_avg:89.08ms +step:411/1680 train_time:36610ms step_avg:89.08ms +step:412/1680 train_time:36700ms step_avg:89.08ms +step:413/1680 train_time:36789ms step_avg:89.08ms +step:414/1680 train_time:36878ms step_avg:89.08ms +step:415/1680 train_time:36967ms step_avg:89.08ms +step:416/1680 train_time:37056ms step_avg:89.08ms +step:417/1680 train_time:37145ms step_avg:89.08ms +step:418/1680 train_time:37233ms step_avg:89.07ms +step:419/1680 train_time:37322ms step_avg:89.07ms +step:420/1680 train_time:37410ms step_avg:89.07ms +step:421/1680 train_time:37499ms step_avg:89.07ms +step:422/1680 train_time:37588ms step_avg:89.07ms +step:423/1680 train_time:37679ms step_avg:89.08ms +step:424/1680 train_time:37768ms step_avg:89.07ms +step:425/1680 train_time:37858ms step_avg:89.08ms +step:426/1680 train_time:37947ms 
step_avg:89.08ms +step:427/1680 train_time:38036ms step_avg:89.08ms +step:428/1680 train_time:38124ms step_avg:89.08ms +step:429/1680 train_time:38214ms step_avg:89.08ms +step:430/1680 train_time:38303ms step_avg:89.08ms +step:431/1680 train_time:38391ms step_avg:89.08ms +step:432/1680 train_time:38480ms step_avg:89.07ms +step:433/1680 train_time:38570ms step_avg:89.08ms +step:434/1680 train_time:38660ms step_avg:89.08ms +step:435/1680 train_time:38749ms step_avg:89.08ms +step:436/1680 train_time:38838ms step_avg:89.08ms +step:437/1680 train_time:38928ms step_avg:89.08ms +step:438/1680 train_time:39016ms step_avg:89.08ms +step:439/1680 train_time:39105ms step_avg:89.08ms +step:440/1680 train_time:39194ms step_avg:89.08ms +step:441/1680 train_time:39284ms step_avg:89.08ms +step:442/1680 train_time:39372ms step_avg:89.08ms +step:443/1680 train_time:39460ms step_avg:89.08ms +step:444/1680 train_time:39549ms step_avg:89.07ms +step:445/1680 train_time:39638ms step_avg:89.07ms +step:446/1680 train_time:39727ms step_avg:89.07ms +step:447/1680 train_time:39816ms step_avg:89.07ms +step:448/1680 train_time:39905ms step_avg:89.07ms +step:449/1680 train_time:39994ms step_avg:89.07ms +step:450/1680 train_time:40083ms step_avg:89.07ms +step:451/1680 train_time:40173ms step_avg:89.07ms +step:452/1680 train_time:40261ms step_avg:89.07ms +step:453/1680 train_time:40350ms step_avg:89.07ms +step:454/1680 train_time:40440ms step_avg:89.07ms +step:455/1680 train_time:40530ms step_avg:89.08ms +step:456/1680 train_time:40619ms step_avg:89.08ms +step:457/1680 train_time:40709ms step_avg:89.08ms +step:458/1680 train_time:40798ms step_avg:89.08ms +step:459/1680 train_time:40887ms step_avg:89.08ms +step:460/1680 train_time:40976ms step_avg:89.08ms +step:461/1680 train_time:41065ms step_avg:89.08ms +step:462/1680 train_time:41154ms step_avg:89.08ms +step:463/1680 train_time:41243ms step_avg:89.08ms +step:464/1680 train_time:41332ms step_avg:89.08ms +step:465/1680 train_time:41421ms step_avg:89.08ms +step:466/1680 train_time:41510ms step_avg:89.08ms +step:467/1680 train_time:41599ms step_avg:89.08ms +step:468/1680 train_time:41689ms step_avg:89.08ms +step:469/1680 train_time:41779ms step_avg:89.08ms +step:470/1680 train_time:41868ms step_avg:89.08ms +step:471/1680 train_time:41958ms step_avg:89.08ms +step:472/1680 train_time:42048ms step_avg:89.08ms +step:473/1680 train_time:42137ms step_avg:89.08ms +step:474/1680 train_time:42225ms step_avg:89.08ms +step:475/1680 train_time:42315ms step_avg:89.08ms +step:476/1680 train_time:42403ms step_avg:89.08ms +step:477/1680 train_time:42493ms step_avg:89.08ms +step:478/1680 train_time:42581ms step_avg:89.08ms +step:479/1680 train_time:42671ms step_avg:89.08ms +step:480/1680 train_time:42760ms step_avg:89.08ms +step:481/1680 train_time:42850ms step_avg:89.09ms +step:482/1680 train_time:42940ms step_avg:89.09ms +step:483/1680 train_time:43030ms step_avg:89.09ms +step:484/1680 train_time:43119ms step_avg:89.09ms +step:485/1680 train_time:43208ms step_avg:89.09ms +step:486/1680 train_time:43298ms step_avg:89.09ms +step:487/1680 train_time:43386ms step_avg:89.09ms +step:488/1680 train_time:43475ms step_avg:89.09ms +step:489/1680 train_time:43563ms step_avg:89.09ms +step:490/1680 train_time:43652ms step_avg:89.09ms +step:491/1680 train_time:43742ms step_avg:89.09ms +step:492/1680 train_time:43831ms step_avg:89.09ms +step:493/1680 train_time:43920ms step_avg:89.09ms +step:494/1680 train_time:44009ms step_avg:89.09ms +step:495/1680 train_time:44099ms step_avg:89.09ms +step:496/1680 
train_time:44189ms step_avg:89.09ms +step:497/1680 train_time:44279ms step_avg:89.09ms +step:498/1680 train_time:44368ms step_avg:89.09ms +step:499/1680 train_time:44457ms step_avg:89.09ms +step:500/1680 train_time:44546ms step_avg:89.09ms +step:500/1680 val_loss:3.7202 train_time:44636ms step_avg:89.27ms +step:501/1680 train_time:44658ms step_avg:89.14ms +step:502/1680 train_time:44727ms step_avg:89.10ms +step:503/1680 train_time:44823ms step_avg:89.11ms +step:504/1680 train_time:44913ms step_avg:89.11ms +step:505/1680 train_time:45001ms step_avg:89.11ms +step:506/1680 train_time:45089ms step_avg:89.11ms +step:507/1680 train_time:45177ms step_avg:89.11ms +step:508/1680 train_time:45265ms step_avg:89.10ms +step:509/1680 train_time:45352ms step_avg:89.10ms +step:510/1680 train_time:45440ms step_avg:89.10ms +step:511/1680 train_time:45529ms step_avg:89.10ms +step:512/1680 train_time:45619ms step_avg:89.10ms +step:513/1680 train_time:45710ms step_avg:89.10ms +step:514/1680 train_time:45802ms step_avg:89.11ms +step:515/1680 train_time:45894ms step_avg:89.11ms +step:516/1680 train_time:45982ms step_avg:89.11ms +step:517/1680 train_time:46070ms step_avg:89.11ms +step:518/1680 train_time:46159ms step_avg:89.11ms +step:519/1680 train_time:46247ms step_avg:89.11ms +step:520/1680 train_time:46335ms step_avg:89.11ms +step:521/1680 train_time:46423ms step_avg:89.10ms +step:522/1680 train_time:46511ms step_avg:89.10ms +step:523/1680 train_time:46600ms step_avg:89.10ms +step:524/1680 train_time:46690ms step_avg:89.10ms +step:525/1680 train_time:46780ms step_avg:89.10ms +step:526/1680 train_time:46870ms step_avg:89.11ms +step:527/1680 train_time:46960ms step_avg:89.11ms +step:528/1680 train_time:47049ms step_avg:89.11ms +step:529/1680 train_time:47139ms step_avg:89.11ms +step:530/1680 train_time:47227ms step_avg:89.11ms +step:531/1680 train_time:47315ms step_avg:89.11ms +step:532/1680 train_time:47403ms step_avg:89.10ms +step:533/1680 train_time:47491ms step_avg:89.10ms +step:534/1680 train_time:47580ms step_avg:89.10ms +step:535/1680 train_time:47670ms step_avg:89.10ms +step:536/1680 train_time:47761ms step_avg:89.11ms +step:537/1680 train_time:47851ms step_avg:89.11ms +step:538/1680 train_time:47941ms step_avg:89.11ms +step:539/1680 train_time:48031ms step_avg:89.11ms +step:540/1680 train_time:48120ms step_avg:89.11ms +step:541/1680 train_time:48210ms step_avg:89.11ms +step:542/1680 train_time:48300ms step_avg:89.11ms +step:543/1680 train_time:48389ms step_avg:89.11ms +step:544/1680 train_time:48477ms step_avg:89.11ms +step:545/1680 train_time:48565ms step_avg:89.11ms +step:546/1680 train_time:48655ms step_avg:89.11ms +step:547/1680 train_time:48744ms step_avg:89.11ms +step:548/1680 train_time:48833ms step_avg:89.11ms +step:549/1680 train_time:48923ms step_avg:89.11ms +step:550/1680 train_time:49013ms step_avg:89.11ms +step:551/1680 train_time:49104ms step_avg:89.12ms +step:552/1680 train_time:49194ms step_avg:89.12ms +step:553/1680 train_time:49284ms step_avg:89.12ms +step:554/1680 train_time:49374ms step_avg:89.12ms +step:555/1680 train_time:49463ms step_avg:89.12ms +step:556/1680 train_time:49553ms step_avg:89.12ms +step:557/1680 train_time:49644ms step_avg:89.13ms +step:558/1680 train_time:49734ms step_avg:89.13ms +step:559/1680 train_time:49825ms step_avg:89.13ms +step:560/1680 train_time:49915ms step_avg:89.13ms +step:561/1680 train_time:50005ms step_avg:89.14ms +step:562/1680 train_time:50096ms step_avg:89.14ms +step:563/1680 train_time:50186ms step_avg:89.14ms +step:564/1680 train_time:50277ms 
step_avg:89.14ms +step:565/1680 train_time:50367ms step_avg:89.14ms +step:566/1680 train_time:50457ms step_avg:89.15ms +step:567/1680 train_time:50547ms step_avg:89.15ms +step:568/1680 train_time:50638ms step_avg:89.15ms +step:569/1680 train_time:50728ms step_avg:89.15ms +step:570/1680 train_time:50817ms step_avg:89.15ms +step:571/1680 train_time:50907ms step_avg:89.15ms +step:572/1680 train_time:50997ms step_avg:89.16ms +step:573/1680 train_time:51088ms step_avg:89.16ms +step:574/1680 train_time:51178ms step_avg:89.16ms +step:575/1680 train_time:51268ms step_avg:89.16ms +step:576/1680 train_time:51358ms step_avg:89.16ms +step:577/1680 train_time:51450ms step_avg:89.17ms +step:578/1680 train_time:51537ms step_avg:89.16ms +step:579/1680 train_time:51627ms step_avg:89.17ms +step:580/1680 train_time:51716ms step_avg:89.17ms +step:581/1680 train_time:51806ms step_avg:89.17ms +step:582/1680 train_time:51896ms step_avg:89.17ms +step:583/1680 train_time:51987ms step_avg:89.17ms +step:584/1680 train_time:52078ms step_avg:89.17ms +step:585/1680 train_time:52167ms step_avg:89.17ms +step:586/1680 train_time:52258ms step_avg:89.18ms +step:587/1680 train_time:52348ms step_avg:89.18ms +step:588/1680 train_time:52438ms step_avg:89.18ms +step:589/1680 train_time:52528ms step_avg:89.18ms +step:590/1680 train_time:52618ms step_avg:89.18ms +step:591/1680 train_time:52709ms step_avg:89.19ms +step:592/1680 train_time:52799ms step_avg:89.19ms +step:593/1680 train_time:52889ms step_avg:89.19ms +step:594/1680 train_time:52979ms step_avg:89.19ms +step:595/1680 train_time:53069ms step_avg:89.19ms +step:596/1680 train_time:53159ms step_avg:89.19ms +step:597/1680 train_time:53250ms step_avg:89.20ms +step:598/1680 train_time:53341ms step_avg:89.20ms +step:599/1680 train_time:53431ms step_avg:89.20ms +step:600/1680 train_time:53521ms step_avg:89.20ms +step:601/1680 train_time:53612ms step_avg:89.20ms +step:602/1680 train_time:53702ms step_avg:89.21ms +step:603/1680 train_time:53792ms step_avg:89.21ms +step:604/1680 train_time:53883ms step_avg:89.21ms +step:605/1680 train_time:53974ms step_avg:89.21ms +step:606/1680 train_time:54064ms step_avg:89.21ms +step:607/1680 train_time:54155ms step_avg:89.22ms +step:608/1680 train_time:54245ms step_avg:89.22ms +step:609/1680 train_time:54335ms step_avg:89.22ms +step:610/1680 train_time:54425ms step_avg:89.22ms +step:611/1680 train_time:54516ms step_avg:89.22ms +step:612/1680 train_time:54606ms step_avg:89.22ms +step:613/1680 train_time:54695ms step_avg:89.23ms +step:614/1680 train_time:54785ms step_avg:89.23ms +step:615/1680 train_time:54875ms step_avg:89.23ms +step:616/1680 train_time:54965ms step_avg:89.23ms +step:617/1680 train_time:55055ms step_avg:89.23ms +step:618/1680 train_time:55145ms step_avg:89.23ms +step:619/1680 train_time:55236ms step_avg:89.23ms +step:620/1680 train_time:55326ms step_avg:89.24ms +step:621/1680 train_time:55416ms step_avg:89.24ms +step:622/1680 train_time:55506ms step_avg:89.24ms +step:623/1680 train_time:55597ms step_avg:89.24ms +step:624/1680 train_time:55686ms step_avg:89.24ms +step:625/1680 train_time:55777ms step_avg:89.24ms +step:625/1680 val_loss:3.6163 train_time:55868ms step_avg:89.39ms +step:626/1680 train_time:55891ms step_avg:89.28ms +step:627/1680 train_time:55962ms step_avg:89.25ms +step:628/1680 train_time:56064ms step_avg:89.27ms +step:629/1680 train_time:56157ms step_avg:89.28ms +step:630/1680 train_time:56248ms step_avg:89.28ms +step:631/1680 train_time:56337ms step_avg:89.28ms +step:632/1680 train_time:56426ms step_avg:89.28ms 
+step:633/1680 train_time:56515ms step_avg:89.28ms +step:634/1680 train_time:56603ms step_avg:89.28ms +step:635/1680 train_time:56693ms step_avg:89.28ms +step:636/1680 train_time:56783ms step_avg:89.28ms +step:637/1680 train_time:56874ms step_avg:89.28ms +step:638/1680 train_time:56967ms step_avg:89.29ms +step:639/1680 train_time:57060ms step_avg:89.30ms +step:640/1680 train_time:57152ms step_avg:89.30ms +step:641/1680 train_time:57242ms step_avg:89.30ms +step:642/1680 train_time:57331ms step_avg:89.30ms +step:643/1680 train_time:57420ms step_avg:89.30ms +step:644/1680 train_time:57510ms step_avg:89.30ms +step:645/1680 train_time:57599ms step_avg:89.30ms +step:646/1680 train_time:57688ms step_avg:89.30ms +step:647/1680 train_time:57778ms step_avg:89.30ms +step:648/1680 train_time:57868ms step_avg:89.30ms +step:649/1680 train_time:57959ms step_avg:89.31ms +step:650/1680 train_time:58051ms step_avg:89.31ms +step:651/1680 train_time:58141ms step_avg:89.31ms +step:652/1680 train_time:58232ms step_avg:89.31ms +step:653/1680 train_time:58322ms step_avg:89.31ms +step:654/1680 train_time:58412ms step_avg:89.32ms +step:655/1680 train_time:58501ms step_avg:89.31ms +step:656/1680 train_time:58591ms step_avg:89.32ms +step:657/1680 train_time:58680ms step_avg:89.32ms +step:658/1680 train_time:58770ms step_avg:89.32ms +step:659/1680 train_time:58860ms step_avg:89.32ms +step:660/1680 train_time:58950ms step_avg:89.32ms +step:661/1680 train_time:59041ms step_avg:89.32ms +step:662/1680 train_time:59132ms step_avg:89.32ms +step:663/1680 train_time:59223ms step_avg:89.33ms +step:664/1680 train_time:59313ms step_avg:89.33ms +step:665/1680 train_time:59403ms step_avg:89.33ms +step:666/1680 train_time:59493ms step_avg:89.33ms +step:667/1680 train_time:59583ms step_avg:89.33ms +step:668/1680 train_time:59673ms step_avg:89.33ms +step:669/1680 train_time:59761ms step_avg:89.33ms +step:670/1680 train_time:59852ms step_avg:89.33ms +step:671/1680 train_time:59942ms step_avg:89.33ms +step:672/1680 train_time:60032ms step_avg:89.33ms +step:673/1680 train_time:60123ms step_avg:89.34ms +step:674/1680 train_time:60213ms step_avg:89.34ms +step:675/1680 train_time:60304ms step_avg:89.34ms +step:676/1680 train_time:60394ms step_avg:89.34ms +step:677/1680 train_time:60483ms step_avg:89.34ms +step:678/1680 train_time:60573ms step_avg:89.34ms +step:679/1680 train_time:60662ms step_avg:89.34ms +step:680/1680 train_time:60752ms step_avg:89.34ms +step:681/1680 train_time:60841ms step_avg:89.34ms +step:682/1680 train_time:60931ms step_avg:89.34ms +step:683/1680 train_time:61021ms step_avg:89.34ms +step:684/1680 train_time:61111ms step_avg:89.34ms +step:685/1680 train_time:61202ms step_avg:89.35ms +step:686/1680 train_time:61293ms step_avg:89.35ms +step:687/1680 train_time:61383ms step_avg:89.35ms +step:688/1680 train_time:61472ms step_avg:89.35ms +step:689/1680 train_time:61563ms step_avg:89.35ms +step:690/1680 train_time:61652ms step_avg:89.35ms +step:691/1680 train_time:61741ms step_avg:89.35ms +step:692/1680 train_time:61830ms step_avg:89.35ms +step:693/1680 train_time:61920ms step_avg:89.35ms +step:694/1680 train_time:62011ms step_avg:89.35ms +step:695/1680 train_time:62100ms step_avg:89.35ms +step:696/1680 train_time:62191ms step_avg:89.36ms +step:697/1680 train_time:62282ms step_avg:89.36ms +step:698/1680 train_time:62371ms step_avg:89.36ms +step:699/1680 train_time:62461ms step_avg:89.36ms +step:700/1680 train_time:62551ms step_avg:89.36ms +step:701/1680 train_time:62641ms step_avg:89.36ms +step:702/1680 train_time:62731ms 
step_avg:89.36ms +step:703/1680 train_time:62820ms step_avg:89.36ms +step:704/1680 train_time:62911ms step_avg:89.36ms +step:705/1680 train_time:63001ms step_avg:89.36ms +step:706/1680 train_time:63091ms step_avg:89.36ms +step:707/1680 train_time:63182ms step_avg:89.37ms +step:708/1680 train_time:63273ms step_avg:89.37ms +step:709/1680 train_time:63363ms step_avg:89.37ms +step:710/1680 train_time:63454ms step_avg:89.37ms +step:711/1680 train_time:63544ms step_avg:89.37ms +step:712/1680 train_time:63634ms step_avg:89.37ms +step:713/1680 train_time:63724ms step_avg:89.37ms +step:714/1680 train_time:63814ms step_avg:89.38ms +step:715/1680 train_time:63904ms step_avg:89.38ms +step:716/1680 train_time:63994ms step_avg:89.38ms +step:717/1680 train_time:64084ms step_avg:89.38ms +step:718/1680 train_time:64174ms step_avg:89.38ms +step:719/1680 train_time:64264ms step_avg:89.38ms +step:720/1680 train_time:64354ms step_avg:89.38ms +step:721/1680 train_time:64444ms step_avg:89.38ms +step:722/1680 train_time:64534ms step_avg:89.38ms +step:723/1680 train_time:64624ms step_avg:89.38ms +step:724/1680 train_time:64714ms step_avg:89.38ms +step:725/1680 train_time:64804ms step_avg:89.38ms +step:726/1680 train_time:64894ms step_avg:89.39ms +step:727/1680 train_time:64985ms step_avg:89.39ms +step:728/1680 train_time:65075ms step_avg:89.39ms +step:729/1680 train_time:65165ms step_avg:89.39ms +step:730/1680 train_time:65256ms step_avg:89.39ms +step:731/1680 train_time:65346ms step_avg:89.39ms +step:732/1680 train_time:65436ms step_avg:89.39ms +step:733/1680 train_time:65526ms step_avg:89.39ms +step:734/1680 train_time:65616ms step_avg:89.40ms +step:735/1680 train_time:65707ms step_avg:89.40ms +step:736/1680 train_time:65797ms step_avg:89.40ms +step:737/1680 train_time:65887ms step_avg:89.40ms +step:738/1680 train_time:65978ms step_avg:89.40ms +step:739/1680 train_time:66069ms step_avg:89.40ms +step:740/1680 train_time:66159ms step_avg:89.40ms +step:741/1680 train_time:66249ms step_avg:89.41ms +step:742/1680 train_time:66340ms step_avg:89.41ms +step:743/1680 train_time:66430ms step_avg:89.41ms +step:744/1680 train_time:66520ms step_avg:89.41ms +step:745/1680 train_time:66611ms step_avg:89.41ms +step:746/1680 train_time:66700ms step_avg:89.41ms +step:747/1680 train_time:66791ms step_avg:89.41ms +step:748/1680 train_time:66883ms step_avg:89.42ms +step:749/1680 train_time:66970ms step_avg:89.41ms +step:750/1680 train_time:67060ms step_avg:89.41ms +step:750/1680 val_loss:3.5666 train_time:67152ms step_avg:89.54ms +step:751/1680 train_time:67176ms step_avg:89.45ms +step:752/1680 train_time:67250ms step_avg:89.43ms +step:753/1680 train_time:67344ms step_avg:89.43ms +step:754/1680 train_time:67436ms step_avg:89.44ms +step:755/1680 train_time:67526ms step_avg:89.44ms +step:756/1680 train_time:67615ms step_avg:89.44ms +step:757/1680 train_time:67704ms step_avg:89.44ms +step:758/1680 train_time:67793ms step_avg:89.44ms +step:759/1680 train_time:67882ms step_avg:89.44ms +step:760/1680 train_time:67971ms step_avg:89.44ms +step:761/1680 train_time:68060ms step_avg:89.44ms +step:762/1680 train_time:68152ms step_avg:89.44ms +step:763/1680 train_time:68244ms step_avg:89.44ms +step:764/1680 train_time:68337ms step_avg:89.45ms +step:765/1680 train_time:68430ms step_avg:89.45ms +step:766/1680 train_time:68520ms step_avg:89.45ms +step:767/1680 train_time:68609ms step_avg:89.45ms +step:768/1680 train_time:68699ms step_avg:89.45ms +step:769/1680 train_time:68789ms step_avg:89.45ms +step:770/1680 train_time:68878ms step_avg:89.45ms 
+step:771/1680 train_time:68966ms step_avg:89.45ms +step:772/1680 train_time:69056ms step_avg:89.45ms +step:773/1680 train_time:69147ms step_avg:89.45ms +step:774/1680 train_time:69238ms step_avg:89.45ms +step:775/1680 train_time:69330ms step_avg:89.46ms +step:776/1680 train_time:69421ms step_avg:89.46ms +step:777/1680 train_time:69512ms step_avg:89.46ms +step:778/1680 train_time:69602ms step_avg:89.46ms +step:779/1680 train_time:69692ms step_avg:89.46ms +step:780/1680 train_time:69781ms step_avg:89.46ms +step:781/1680 train_time:69870ms step_avg:89.46ms +step:782/1680 train_time:69960ms step_avg:89.46ms +step:783/1680 train_time:70049ms step_avg:89.46ms +step:784/1680 train_time:70139ms step_avg:89.46ms +step:785/1680 train_time:70230ms step_avg:89.46ms +step:786/1680 train_time:70321ms step_avg:89.47ms +step:787/1680 train_time:70411ms step_avg:89.47ms +step:788/1680 train_time:70502ms step_avg:89.47ms +step:789/1680 train_time:70593ms step_avg:89.47ms +step:790/1680 train_time:70683ms step_avg:89.47ms +step:791/1680 train_time:70773ms step_avg:89.47ms +step:792/1680 train_time:70863ms step_avg:89.47ms +step:793/1680 train_time:70953ms step_avg:89.47ms +step:794/1680 train_time:71042ms step_avg:89.47ms +step:795/1680 train_time:71132ms step_avg:89.47ms +step:796/1680 train_time:71222ms step_avg:89.47ms +step:797/1680 train_time:71313ms step_avg:89.48ms +step:798/1680 train_time:71403ms step_avg:89.48ms +step:799/1680 train_time:71494ms step_avg:89.48ms +step:800/1680 train_time:71584ms step_avg:89.48ms +step:801/1680 train_time:71675ms step_avg:89.48ms +step:802/1680 train_time:71764ms step_avg:89.48ms +step:803/1680 train_time:71854ms step_avg:89.48ms +step:804/1680 train_time:71944ms step_avg:89.48ms +step:805/1680 train_time:72034ms step_avg:89.48ms +step:806/1680 train_time:72124ms step_avg:89.48ms +step:807/1680 train_time:72214ms step_avg:89.48ms +step:808/1680 train_time:72303ms step_avg:89.48ms +step:809/1680 train_time:72395ms step_avg:89.49ms +step:810/1680 train_time:72486ms step_avg:89.49ms +step:811/1680 train_time:72576ms step_avg:89.49ms +step:812/1680 train_time:72667ms step_avg:89.49ms +step:813/1680 train_time:72757ms step_avg:89.49ms +step:814/1680 train_time:72848ms step_avg:89.49ms +step:815/1680 train_time:72938ms step_avg:89.49ms +step:816/1680 train_time:73028ms step_avg:89.50ms +step:817/1680 train_time:73118ms step_avg:89.50ms +step:818/1680 train_time:73208ms step_avg:89.50ms +step:819/1680 train_time:73298ms step_avg:89.50ms +step:820/1680 train_time:73389ms step_avg:89.50ms +step:821/1680 train_time:73479ms step_avg:89.50ms +step:822/1680 train_time:73569ms step_avg:89.50ms +step:823/1680 train_time:73659ms step_avg:89.50ms +step:824/1680 train_time:73749ms step_avg:89.50ms +step:825/1680 train_time:73839ms step_avg:89.50ms +step:826/1680 train_time:73929ms step_avg:89.50ms +step:827/1680 train_time:74018ms step_avg:89.50ms +step:828/1680 train_time:74108ms step_avg:89.50ms +step:829/1680 train_time:74198ms step_avg:89.50ms +step:830/1680 train_time:74288ms step_avg:89.50ms +step:831/1680 train_time:74378ms step_avg:89.50ms +step:832/1680 train_time:74469ms step_avg:89.51ms +step:833/1680 train_time:74560ms step_avg:89.51ms +step:834/1680 train_time:74649ms step_avg:89.51ms +step:835/1680 train_time:74740ms step_avg:89.51ms +step:836/1680 train_time:74830ms step_avg:89.51ms +step:837/1680 train_time:74920ms step_avg:89.51ms +step:838/1680 train_time:75010ms step_avg:89.51ms +step:839/1680 train_time:75100ms step_avg:89.51ms +step:840/1680 train_time:75190ms 
step_avg:89.51ms +step:841/1680 train_time:75280ms step_avg:89.51ms +step:842/1680 train_time:75370ms step_avg:89.51ms +step:843/1680 train_time:75461ms step_avg:89.51ms +step:844/1680 train_time:75550ms step_avg:89.51ms +step:845/1680 train_time:75641ms step_avg:89.52ms +step:846/1680 train_time:75730ms step_avg:89.52ms +step:847/1680 train_time:75821ms step_avg:89.52ms +step:848/1680 train_time:75911ms step_avg:89.52ms +step:849/1680 train_time:76001ms step_avg:89.52ms +step:850/1680 train_time:76091ms step_avg:89.52ms +step:851/1680 train_time:76181ms step_avg:89.52ms +step:852/1680 train_time:76271ms step_avg:89.52ms +step:853/1680 train_time:76361ms step_avg:89.52ms +step:854/1680 train_time:76451ms step_avg:89.52ms +step:855/1680 train_time:76541ms step_avg:89.52ms +step:856/1680 train_time:76631ms step_avg:89.52ms +step:857/1680 train_time:76721ms step_avg:89.52ms +step:858/1680 train_time:76810ms step_avg:89.52ms +step:859/1680 train_time:76901ms step_avg:89.52ms +step:860/1680 train_time:76992ms step_avg:89.53ms +step:861/1680 train_time:77081ms step_avg:89.53ms +step:862/1680 train_time:77172ms step_avg:89.53ms +step:863/1680 train_time:77261ms step_avg:89.53ms +step:864/1680 train_time:77352ms step_avg:89.53ms +step:865/1680 train_time:77442ms step_avg:89.53ms +step:866/1680 train_time:77532ms step_avg:89.53ms +step:867/1680 train_time:77622ms step_avg:89.53ms +step:868/1680 train_time:77712ms step_avg:89.53ms +step:869/1680 train_time:77802ms step_avg:89.53ms +step:870/1680 train_time:77892ms step_avg:89.53ms +step:871/1680 train_time:77981ms step_avg:89.53ms +step:872/1680 train_time:78071ms step_avg:89.53ms +step:873/1680 train_time:78161ms step_avg:89.53ms +step:874/1680 train_time:78251ms step_avg:89.53ms +step:875/1680 train_time:78340ms step_avg:89.53ms +step:875/1680 val_loss:3.5192 train_time:78431ms step_avg:89.64ms +step:876/1680 train_time:78454ms step_avg:89.56ms +step:877/1680 train_time:78525ms step_avg:89.54ms +step:878/1680 train_time:78623ms step_avg:89.55ms +step:879/1680 train_time:78715ms step_avg:89.55ms +step:880/1680 train_time:78804ms step_avg:89.55ms +step:881/1680 train_time:78894ms step_avg:89.55ms +step:882/1680 train_time:78983ms step_avg:89.55ms +step:883/1680 train_time:79071ms step_avg:89.55ms +step:884/1680 train_time:79160ms step_avg:89.55ms +step:885/1680 train_time:79249ms step_avg:89.55ms +step:886/1680 train_time:79338ms step_avg:89.55ms +step:887/1680 train_time:79429ms step_avg:89.55ms +step:888/1680 train_time:79522ms step_avg:89.55ms +step:889/1680 train_time:79616ms step_avg:89.56ms +step:890/1680 train_time:79706ms step_avg:89.56ms +step:891/1680 train_time:79797ms step_avg:89.56ms +step:892/1680 train_time:79888ms step_avg:89.56ms +step:893/1680 train_time:79978ms step_avg:89.56ms +step:894/1680 train_time:80067ms step_avg:89.56ms +step:895/1680 train_time:80156ms step_avg:89.56ms +step:896/1680 train_time:80246ms step_avg:89.56ms +step:897/1680 train_time:80336ms step_avg:89.56ms +step:898/1680 train_time:80428ms step_avg:89.56ms +step:899/1680 train_time:80520ms step_avg:89.57ms +step:900/1680 train_time:80611ms step_avg:89.57ms +step:901/1680 train_time:80702ms step_avg:89.57ms +step:902/1680 train_time:80793ms step_avg:89.57ms +step:903/1680 train_time:80884ms step_avg:89.57ms +step:904/1680 train_time:80975ms step_avg:89.57ms +step:905/1680 train_time:81065ms step_avg:89.57ms +step:906/1680 train_time:81154ms step_avg:89.57ms +step:907/1680 train_time:81244ms step_avg:89.57ms +step:908/1680 train_time:81334ms step_avg:89.57ms 
+step:909/1680 train_time:81425ms step_avg:89.58ms +step:910/1680 train_time:81516ms step_avg:89.58ms +step:911/1680 train_time:81607ms step_avg:89.58ms +step:912/1680 train_time:81698ms step_avg:89.58ms +step:913/1680 train_time:81788ms step_avg:89.58ms +step:914/1680 train_time:81879ms step_avg:89.58ms +step:915/1680 train_time:81969ms step_avg:89.58ms +step:916/1680 train_time:82058ms step_avg:89.58ms +step:917/1680 train_time:82147ms step_avg:89.58ms +step:918/1680 train_time:82238ms step_avg:89.58ms +step:919/1680 train_time:82327ms step_avg:89.58ms +step:920/1680 train_time:82418ms step_avg:89.58ms +step:921/1680 train_time:82508ms step_avg:89.59ms +step:922/1680 train_time:82599ms step_avg:89.59ms +step:923/1680 train_time:82690ms step_avg:89.59ms +step:924/1680 train_time:82781ms step_avg:89.59ms +step:925/1680 train_time:82871ms step_avg:89.59ms +step:926/1680 train_time:82961ms step_avg:89.59ms +step:927/1680 train_time:83051ms step_avg:89.59ms +step:928/1680 train_time:83141ms step_avg:89.59ms +step:929/1680 train_time:83230ms step_avg:89.59ms +step:930/1680 train_time:83319ms step_avg:89.59ms +step:931/1680 train_time:83410ms step_avg:89.59ms +step:932/1680 train_time:83500ms step_avg:89.59ms +step:933/1680 train_time:83591ms step_avg:89.59ms +step:934/1680 train_time:83681ms step_avg:89.59ms +step:935/1680 train_time:83771ms step_avg:89.59ms +step:936/1680 train_time:83862ms step_avg:89.60ms +step:937/1680 train_time:83951ms step_avg:89.60ms +step:938/1680 train_time:84041ms step_avg:89.60ms +step:939/1680 train_time:84131ms step_avg:89.60ms +step:940/1680 train_time:84221ms step_avg:89.60ms +step:941/1680 train_time:84311ms step_avg:89.60ms +step:942/1680 train_time:84400ms step_avg:89.60ms +step:943/1680 train_time:84493ms step_avg:89.60ms +step:944/1680 train_time:84580ms step_avg:89.60ms +step:945/1680 train_time:84671ms step_avg:89.60ms +step:946/1680 train_time:84761ms step_avg:89.60ms +step:947/1680 train_time:84852ms step_avg:89.60ms +step:948/1680 train_time:84943ms step_avg:89.60ms +step:949/1680 train_time:85032ms step_avg:89.60ms +step:950/1680 train_time:85122ms step_avg:89.60ms +step:951/1680 train_time:85213ms step_avg:89.60ms +step:952/1680 train_time:85303ms step_avg:89.60ms +step:953/1680 train_time:85393ms step_avg:89.60ms +step:954/1680 train_time:85483ms step_avg:89.61ms +step:955/1680 train_time:85574ms step_avg:89.61ms +step:956/1680 train_time:85663ms step_avg:89.61ms +step:957/1680 train_time:85754ms step_avg:89.61ms +step:958/1680 train_time:85844ms step_avg:89.61ms +step:959/1680 train_time:85934ms step_avg:89.61ms +step:960/1680 train_time:86026ms step_avg:89.61ms +step:961/1680 train_time:86116ms step_avg:89.61ms +step:962/1680 train_time:86206ms step_avg:89.61ms +step:963/1680 train_time:86296ms step_avg:89.61ms +step:964/1680 train_time:86386ms step_avg:89.61ms +step:965/1680 train_time:86477ms step_avg:89.61ms +step:966/1680 train_time:86567ms step_avg:89.61ms +step:967/1680 train_time:86657ms step_avg:89.61ms +step:968/1680 train_time:86747ms step_avg:89.62ms +step:969/1680 train_time:86838ms step_avg:89.62ms +step:970/1680 train_time:86928ms step_avg:89.62ms +step:971/1680 train_time:87018ms step_avg:89.62ms +step:972/1680 train_time:87108ms step_avg:89.62ms +step:973/1680 train_time:87197ms step_avg:89.62ms +step:974/1680 train_time:87287ms step_avg:89.62ms +step:975/1680 train_time:87377ms step_avg:89.62ms +step:976/1680 train_time:87468ms step_avg:89.62ms +step:977/1680 train_time:87558ms step_avg:89.62ms +step:978/1680 train_time:87648ms 
step_avg:89.62ms +step:979/1680 train_time:87739ms step_avg:89.62ms +step:980/1680 train_time:87829ms step_avg:89.62ms +step:981/1680 train_time:87919ms step_avg:89.62ms +step:982/1680 train_time:88009ms step_avg:89.62ms +step:983/1680 train_time:88100ms step_avg:89.62ms +step:984/1680 train_time:88190ms step_avg:89.62ms +step:985/1680 train_time:88279ms step_avg:89.62ms +step:986/1680 train_time:88369ms step_avg:89.62ms +step:987/1680 train_time:88459ms step_avg:89.62ms +step:988/1680 train_time:88549ms step_avg:89.62ms +step:989/1680 train_time:88639ms step_avg:89.62ms +step:990/1680 train_time:88729ms step_avg:89.63ms +step:991/1680 train_time:88819ms step_avg:89.63ms +step:992/1680 train_time:88909ms step_avg:89.63ms +step:993/1680 train_time:89000ms step_avg:89.63ms +step:994/1680 train_time:89090ms step_avg:89.63ms +step:995/1680 train_time:89181ms step_avg:89.63ms +step:996/1680 train_time:89271ms step_avg:89.63ms +step:997/1680 train_time:89360ms step_avg:89.63ms +step:998/1680 train_time:89451ms step_avg:89.63ms +step:999/1680 train_time:89541ms step_avg:89.63ms +step:1000/1680 train_time:89630ms step_avg:89.63ms +step:1000/1680 val_loss:3.4683 train_time:89721ms step_avg:89.72ms +step:1001/1680 train_time:89745ms step_avg:89.65ms +step:1002/1680 train_time:89814ms step_avg:89.63ms +step:1003/1680 train_time:89906ms step_avg:89.64ms +step:1004/1680 train_time:89996ms step_avg:89.64ms +step:1005/1680 train_time:90085ms step_avg:89.64ms +step:1006/1680 train_time:90174ms step_avg:89.64ms +step:1007/1680 train_time:90263ms step_avg:89.64ms +step:1008/1680 train_time:90353ms step_avg:89.64ms +step:1009/1680 train_time:90442ms step_avg:89.64ms +step:1010/1680 train_time:90532ms step_avg:89.64ms +step:1011/1680 train_time:90621ms step_avg:89.64ms +step:1012/1680 train_time:90712ms step_avg:89.64ms +step:1013/1680 train_time:90803ms step_avg:89.64ms +step:1014/1680 train_time:90895ms step_avg:89.64ms +step:1015/1680 train_time:90985ms step_avg:89.64ms +step:1016/1680 train_time:91075ms step_avg:89.64ms +step:1017/1680 train_time:91166ms step_avg:89.64ms +step:1018/1680 train_time:91256ms step_avg:89.64ms +step:1019/1680 train_time:91345ms step_avg:89.64ms +step:1020/1680 train_time:91434ms step_avg:89.64ms +step:1021/1680 train_time:91524ms step_avg:89.64ms +step:1022/1680 train_time:91614ms step_avg:89.64ms +step:1023/1680 train_time:91705ms step_avg:89.64ms +step:1024/1680 train_time:91797ms step_avg:89.65ms +step:1025/1680 train_time:91888ms step_avg:89.65ms +step:1026/1680 train_time:91978ms step_avg:89.65ms +step:1027/1680 train_time:92069ms step_avg:89.65ms +step:1028/1680 train_time:92159ms step_avg:89.65ms +step:1029/1680 train_time:92249ms step_avg:89.65ms +step:1030/1680 train_time:92343ms step_avg:89.65ms +step:1031/1680 train_time:92428ms step_avg:89.65ms +step:1032/1680 train_time:92518ms step_avg:89.65ms +step:1033/1680 train_time:92608ms step_avg:89.65ms +step:1034/1680 train_time:92698ms step_avg:89.65ms +step:1035/1680 train_time:92789ms step_avg:89.65ms +step:1036/1680 train_time:92879ms step_avg:89.65ms +step:1037/1680 train_time:92970ms step_avg:89.65ms +step:1038/1680 train_time:93060ms step_avg:89.65ms +step:1039/1680 train_time:93150ms step_avg:89.65ms +step:1040/1680 train_time:93240ms step_avg:89.65ms +step:1041/1680 train_time:93330ms step_avg:89.65ms +step:1042/1680 train_time:93419ms step_avg:89.65ms +step:1043/1680 train_time:93509ms step_avg:89.65ms +step:1044/1680 train_time:93599ms step_avg:89.65ms +step:1045/1680 train_time:93689ms step_avg:89.65ms 
+step:1046/1680 train_time:93780ms step_avg:89.66ms +step:1047/1680 train_time:93870ms step_avg:89.66ms +step:1048/1680 train_time:93961ms step_avg:89.66ms +step:1049/1680 train_time:94051ms step_avg:89.66ms +step:1050/1680 train_time:94143ms step_avg:89.66ms +step:1051/1680 train_time:94233ms step_avg:89.66ms +step:1052/1680 train_time:94323ms step_avg:89.66ms +step:1053/1680 train_time:94412ms step_avg:89.66ms +step:1054/1680 train_time:94501ms step_avg:89.66ms +step:1055/1680 train_time:94590ms step_avg:89.66ms +step:1056/1680 train_time:94680ms step_avg:89.66ms +step:1057/1680 train_time:94770ms step_avg:89.66ms +step:1058/1680 train_time:94860ms step_avg:89.66ms +step:1059/1680 train_time:94950ms step_avg:89.66ms +step:1060/1680 train_time:95041ms step_avg:89.66ms +step:1061/1680 train_time:95132ms step_avg:89.66ms +step:1062/1680 train_time:95222ms step_avg:89.66ms +step:1063/1680 train_time:95311ms step_avg:89.66ms +step:1064/1680 train_time:95401ms step_avg:89.66ms +step:1065/1680 train_time:95491ms step_avg:89.66ms +step:1066/1680 train_time:95581ms step_avg:89.66ms +step:1067/1680 train_time:95672ms step_avg:89.66ms +step:1068/1680 train_time:95761ms step_avg:89.66ms +step:1069/1680 train_time:95851ms step_avg:89.66ms +step:1070/1680 train_time:95942ms step_avg:89.67ms +step:1071/1680 train_time:96033ms step_avg:89.67ms +step:1072/1680 train_time:96123ms step_avg:89.67ms +step:1073/1680 train_time:96213ms step_avg:89.67ms +step:1074/1680 train_time:96303ms step_avg:89.67ms +step:1075/1680 train_time:96393ms step_avg:89.67ms +step:1076/1680 train_time:96482ms step_avg:89.67ms +step:1077/1680 train_time:96572ms step_avg:89.67ms +step:1078/1680 train_time:96663ms step_avg:89.67ms +step:1079/1680 train_time:96752ms step_avg:89.67ms +step:1080/1680 train_time:96842ms step_avg:89.67ms +step:1081/1680 train_time:96932ms step_avg:89.67ms +step:1082/1680 train_time:97024ms step_avg:89.67ms +step:1083/1680 train_time:97114ms step_avg:89.67ms +step:1084/1680 train_time:97204ms step_avg:89.67ms +step:1085/1680 train_time:97294ms step_avg:89.67ms +step:1086/1680 train_time:97384ms step_avg:89.67ms +step:1087/1680 train_time:97474ms step_avg:89.67ms +step:1088/1680 train_time:97565ms step_avg:89.67ms +step:1089/1680 train_time:97655ms step_avg:89.67ms +step:1090/1680 train_time:97745ms step_avg:89.67ms +step:1091/1680 train_time:97835ms step_avg:89.67ms +step:1092/1680 train_time:97925ms step_avg:89.67ms +step:1093/1680 train_time:98015ms step_avg:89.67ms +step:1094/1680 train_time:98105ms step_avg:89.68ms +step:1095/1680 train_time:98195ms step_avg:89.68ms +step:1096/1680 train_time:98286ms step_avg:89.68ms +step:1097/1680 train_time:98377ms step_avg:89.68ms +step:1098/1680 train_time:98469ms step_avg:89.68ms +step:1099/1680 train_time:98560ms step_avg:89.68ms +step:1100/1680 train_time:98651ms step_avg:89.68ms +step:1101/1680 train_time:98742ms step_avg:89.68ms +step:1102/1680 train_time:98833ms step_avg:89.69ms +step:1103/1680 train_time:98923ms step_avg:89.69ms +step:1104/1680 train_time:99014ms step_avg:89.69ms +step:1105/1680 train_time:99105ms step_avg:89.69ms +step:1106/1680 train_time:99197ms step_avg:89.69ms +step:1107/1680 train_time:99287ms step_avg:89.69ms +step:1108/1680 train_time:99379ms step_avg:89.69ms +step:1109/1680 train_time:99470ms step_avg:89.69ms +step:1110/1680 train_time:99561ms step_avg:89.69ms +step:1111/1680 train_time:99652ms step_avg:89.70ms +step:1112/1680 train_time:99742ms step_avg:89.70ms +step:1113/1680 train_time:99833ms step_avg:89.70ms +step:1114/1680 
train_time:99924ms step_avg:89.70ms +step:1115/1680 train_time:100015ms step_avg:89.70ms +step:1116/1680 train_time:100105ms step_avg:89.70ms +step:1117/1680 train_time:100196ms step_avg:89.70ms +step:1118/1680 train_time:100287ms step_avg:89.70ms +step:1119/1680 train_time:100379ms step_avg:89.70ms +step:1120/1680 train_time:100470ms step_avg:89.71ms +step:1121/1680 train_time:100562ms step_avg:89.71ms +step:1122/1680 train_time:100652ms step_avg:89.71ms +step:1123/1680 train_time:100743ms step_avg:89.71ms +step:1124/1680 train_time:100834ms step_avg:89.71ms +step:1125/1680 train_time:100925ms step_avg:89.71ms +step:1125/1680 val_loss:3.4146 train_time:101017ms step_avg:89.79ms +step:1126/1680 train_time:101041ms step_avg:89.73ms +step:1127/1680 train_time:101112ms step_avg:89.72ms +step:1128/1680 train_time:101210ms step_avg:89.72ms +step:1129/1680 train_time:101307ms step_avg:89.73ms +step:1130/1680 train_time:101398ms step_avg:89.73ms +step:1131/1680 train_time:101488ms step_avg:89.73ms +step:1132/1680 train_time:101578ms step_avg:89.73ms +step:1133/1680 train_time:101668ms step_avg:89.73ms +step:1134/1680 train_time:101757ms step_avg:89.73ms +step:1135/1680 train_time:101847ms step_avg:89.73ms +step:1136/1680 train_time:101936ms step_avg:89.73ms +step:1137/1680 train_time:102029ms step_avg:89.74ms +step:1138/1680 train_time:102119ms step_avg:89.74ms +step:1139/1680 train_time:102213ms step_avg:89.74ms +step:1140/1680 train_time:102306ms step_avg:89.74ms +step:1141/1680 train_time:102399ms step_avg:89.74ms +step:1142/1680 train_time:102490ms step_avg:89.75ms +step:1143/1680 train_time:102580ms step_avg:89.75ms +step:1144/1680 train_time:102669ms step_avg:89.75ms +step:1145/1680 train_time:102760ms step_avg:89.75ms +step:1146/1680 train_time:102849ms step_avg:89.75ms +step:1147/1680 train_time:102939ms step_avg:89.75ms +step:1148/1680 train_time:103028ms step_avg:89.75ms +step:1149/1680 train_time:103120ms step_avg:89.75ms +step:1150/1680 train_time:103213ms step_avg:89.75ms +step:1151/1680 train_time:103305ms step_avg:89.75ms +step:1152/1680 train_time:103398ms step_avg:89.76ms +step:1153/1680 train_time:103488ms step_avg:89.76ms +step:1154/1680 train_time:103578ms step_avg:89.76ms +step:1155/1680 train_time:103669ms step_avg:89.76ms +step:1156/1680 train_time:103758ms step_avg:89.76ms +step:1157/1680 train_time:103848ms step_avg:89.76ms +step:1158/1680 train_time:103938ms step_avg:89.76ms +step:1159/1680 train_time:104028ms step_avg:89.76ms +step:1160/1680 train_time:104119ms step_avg:89.76ms +step:1161/1680 train_time:104211ms step_avg:89.76ms +step:1162/1680 train_time:104303ms step_avg:89.76ms +step:1163/1680 train_time:104395ms step_avg:89.76ms +step:1164/1680 train_time:104487ms step_avg:89.77ms +step:1165/1680 train_time:104577ms step_avg:89.77ms +step:1166/1680 train_time:104667ms step_avg:89.77ms +step:1167/1680 train_time:104758ms step_avg:89.77ms +step:1168/1680 train_time:104848ms step_avg:89.77ms +step:1169/1680 train_time:104938ms step_avg:89.77ms +step:1170/1680 train_time:105028ms step_avg:89.77ms +step:1171/1680 train_time:105119ms step_avg:89.77ms +step:1172/1680 train_time:105210ms step_avg:89.77ms +step:1173/1680 train_time:105301ms step_avg:89.77ms +step:1174/1680 train_time:105392ms step_avg:89.77ms +step:1175/1680 train_time:105484ms step_avg:89.77ms +step:1176/1680 train_time:105575ms step_avg:89.77ms +step:1177/1680 train_time:105666ms step_avg:89.78ms +step:1178/1680 train_time:105757ms step_avg:89.78ms +step:1179/1680 train_time:105847ms step_avg:89.78ms 
+step:1180/1680 train_time:105938ms step_avg:89.78ms +step:1181/1680 train_time:106028ms step_avg:89.78ms +step:1182/1680 train_time:106119ms step_avg:89.78ms +step:1183/1680 train_time:106210ms step_avg:89.78ms +step:1184/1680 train_time:106302ms step_avg:89.78ms +step:1185/1680 train_time:106392ms step_avg:89.78ms +step:1186/1680 train_time:106483ms step_avg:89.78ms +step:1187/1680 train_time:106575ms step_avg:89.78ms +step:1188/1680 train_time:106665ms step_avg:89.79ms +step:1189/1680 train_time:106756ms step_avg:89.79ms +step:1190/1680 train_time:106846ms step_avg:89.79ms +step:1191/1680 train_time:106938ms step_avg:89.79ms +step:1192/1680 train_time:107028ms step_avg:89.79ms +step:1193/1680 train_time:107119ms step_avg:89.79ms +step:1194/1680 train_time:107211ms step_avg:89.79ms +step:1195/1680 train_time:107302ms step_avg:89.79ms +step:1196/1680 train_time:107392ms step_avg:89.79ms +step:1197/1680 train_time:107482ms step_avg:89.79ms +step:1198/1680 train_time:107573ms step_avg:89.79ms +step:1199/1680 train_time:107664ms step_avg:89.80ms +step:1200/1680 train_time:107755ms step_avg:89.80ms +step:1201/1680 train_time:107846ms step_avg:89.80ms +step:1202/1680 train_time:107937ms step_avg:89.80ms +step:1203/1680 train_time:108027ms step_avg:89.80ms +step:1204/1680 train_time:108118ms step_avg:89.80ms +step:1205/1680 train_time:108209ms step_avg:89.80ms +step:1206/1680 train_time:108300ms step_avg:89.80ms +step:1207/1680 train_time:108391ms step_avg:89.80ms +step:1208/1680 train_time:108481ms step_avg:89.80ms +step:1209/1680 train_time:108573ms step_avg:89.80ms +step:1210/1680 train_time:108664ms step_avg:89.81ms +step:1211/1680 train_time:108755ms step_avg:89.81ms +step:1212/1680 train_time:108847ms step_avg:89.81ms +step:1213/1680 train_time:108939ms step_avg:89.81ms +step:1214/1680 train_time:109030ms step_avg:89.81ms +step:1215/1680 train_time:109120ms step_avg:89.81ms +step:1216/1680 train_time:109211ms step_avg:89.81ms +step:1217/1680 train_time:109302ms step_avg:89.81ms +step:1218/1680 train_time:109393ms step_avg:89.81ms +step:1219/1680 train_time:109485ms step_avg:89.82ms +step:1220/1680 train_time:109575ms step_avg:89.82ms +step:1221/1680 train_time:109666ms step_avg:89.82ms +step:1222/1680 train_time:109757ms step_avg:89.82ms +step:1223/1680 train_time:109847ms step_avg:89.82ms +step:1224/1680 train_time:109938ms step_avg:89.82ms +step:1225/1680 train_time:110029ms step_avg:89.82ms +step:1226/1680 train_time:110120ms step_avg:89.82ms +step:1227/1680 train_time:110210ms step_avg:89.82ms +step:1228/1680 train_time:110300ms step_avg:89.82ms +step:1229/1680 train_time:110391ms step_avg:89.82ms +step:1230/1680 train_time:110483ms step_avg:89.82ms +step:1231/1680 train_time:110575ms step_avg:89.82ms +step:1232/1680 train_time:110666ms step_avg:89.83ms +step:1233/1680 train_time:110758ms step_avg:89.83ms +step:1234/1680 train_time:110849ms step_avg:89.83ms +step:1235/1680 train_time:110939ms step_avg:89.83ms +step:1236/1680 train_time:111030ms step_avg:89.83ms +step:1237/1680 train_time:111122ms step_avg:89.83ms +step:1238/1680 train_time:111213ms step_avg:89.83ms +step:1239/1680 train_time:111303ms step_avg:89.83ms +step:1240/1680 train_time:111394ms step_avg:89.83ms +step:1241/1680 train_time:111484ms step_avg:89.83ms +step:1242/1680 train_time:111575ms step_avg:89.84ms +step:1243/1680 train_time:111666ms step_avg:89.84ms +step:1244/1680 train_time:111756ms step_avg:89.84ms +step:1245/1680 train_time:111848ms step_avg:89.84ms +step:1246/1680 train_time:111939ms step_avg:89.84ms 
+step:1247/1680 train_time:112030ms step_avg:89.84ms +step:1248/1680 train_time:112121ms step_avg:89.84ms +step:1249/1680 train_time:112212ms step_avg:89.84ms +step:1250/1680 train_time:112302ms step_avg:89.84ms +step:1250/1680 val_loss:3.3761 train_time:112395ms step_avg:89.92ms +step:1251/1680 train_time:112418ms step_avg:89.86ms +step:1252/1680 train_time:112491ms step_avg:89.85ms +step:1253/1680 train_time:112589ms step_avg:89.86ms +step:1254/1680 train_time:112680ms step_avg:89.86ms +step:1255/1680 train_time:112769ms step_avg:89.86ms +step:1256/1680 train_time:112859ms step_avg:89.86ms +step:1257/1680 train_time:112948ms step_avg:89.85ms +step:1258/1680 train_time:113037ms step_avg:89.85ms +step:1259/1680 train_time:113127ms step_avg:89.85ms +step:1260/1680 train_time:113216ms step_avg:89.85ms +step:1261/1680 train_time:113306ms step_avg:89.85ms +step:1262/1680 train_time:113400ms step_avg:89.86ms +step:1263/1680 train_time:113494ms step_avg:89.86ms +step:1264/1680 train_time:113586ms step_avg:89.86ms +step:1265/1680 train_time:113679ms step_avg:89.87ms +step:1266/1680 train_time:113769ms step_avg:89.87ms +step:1267/1680 train_time:113860ms step_avg:89.87ms +step:1268/1680 train_time:113951ms step_avg:89.87ms +step:1269/1680 train_time:114040ms step_avg:89.87ms +step:1270/1680 train_time:114130ms step_avg:89.87ms +step:1271/1680 train_time:114220ms step_avg:89.87ms +step:1272/1680 train_time:114310ms step_avg:89.87ms +step:1273/1680 train_time:114402ms step_avg:89.87ms +step:1274/1680 train_time:114495ms step_avg:89.87ms +step:1275/1680 train_time:114587ms step_avg:89.87ms +step:1276/1680 train_time:114678ms step_avg:89.87ms +step:1277/1680 train_time:114774ms step_avg:89.88ms +step:1278/1680 train_time:114860ms step_avg:89.87ms +step:1279/1680 train_time:114951ms step_avg:89.88ms +step:1280/1680 train_time:115041ms step_avg:89.88ms +step:1281/1680 train_time:115130ms step_avg:89.88ms +step:1282/1680 train_time:115221ms step_avg:89.88ms +step:1283/1680 train_time:115312ms step_avg:89.88ms +step:1284/1680 train_time:115403ms step_avg:89.88ms +step:1285/1680 train_time:115494ms step_avg:89.88ms +step:1286/1680 train_time:115586ms step_avg:89.88ms +step:1287/1680 train_time:115678ms step_avg:89.88ms +step:1288/1680 train_time:115768ms step_avg:89.88ms +step:1289/1680 train_time:115860ms step_avg:89.88ms +step:1290/1680 train_time:115951ms step_avg:89.88ms +step:1291/1680 train_time:116041ms step_avg:89.88ms +step:1292/1680 train_time:116130ms step_avg:89.88ms +step:1293/1680 train_time:116220ms step_avg:89.88ms +step:1294/1680 train_time:116311ms step_avg:89.88ms +step:1295/1680 train_time:116402ms step_avg:89.89ms +step:1296/1680 train_time:116492ms step_avg:89.89ms +step:1297/1680 train_time:116583ms step_avg:89.89ms +step:1298/1680 train_time:116676ms step_avg:89.89ms +step:1299/1680 train_time:116766ms step_avg:89.89ms +step:1300/1680 train_time:116858ms step_avg:89.89ms +step:1301/1680 train_time:116949ms step_avg:89.89ms +step:1302/1680 train_time:117039ms step_avg:89.89ms +step:1303/1680 train_time:117129ms step_avg:89.89ms +step:1304/1680 train_time:117219ms step_avg:89.89ms +step:1305/1680 train_time:117311ms step_avg:89.89ms +step:1306/1680 train_time:117401ms step_avg:89.89ms +step:1307/1680 train_time:117492ms step_avg:89.89ms +step:1308/1680 train_time:117583ms step_avg:89.90ms +step:1309/1680 train_time:117675ms step_avg:89.90ms +step:1310/1680 train_time:117766ms step_avg:89.90ms +step:1311/1680 train_time:117857ms step_avg:89.90ms +step:1312/1680 train_time:117948ms 
step_avg:89.90ms +step:1313/1680 train_time:118039ms step_avg:89.90ms +step:1314/1680 train_time:118131ms step_avg:89.90ms +step:1315/1680 train_time:118221ms step_avg:89.90ms +step:1316/1680 train_time:118312ms step_avg:89.90ms +step:1317/1680 train_time:118402ms step_avg:89.90ms +step:1318/1680 train_time:118492ms step_avg:89.90ms +step:1319/1680 train_time:118583ms step_avg:89.90ms +step:1320/1680 train_time:118674ms step_avg:89.90ms +step:1321/1680 train_time:118767ms step_avg:89.91ms +step:1322/1680 train_time:118858ms step_avg:89.91ms +step:1323/1680 train_time:118949ms step_avg:89.91ms +step:1324/1680 train_time:119040ms step_avg:89.91ms +step:1325/1680 train_time:119130ms step_avg:89.91ms +step:1326/1680 train_time:119221ms step_avg:89.91ms +step:1327/1680 train_time:119311ms step_avg:89.91ms +step:1328/1680 train_time:119402ms step_avg:89.91ms +step:1329/1680 train_time:119492ms step_avg:89.91ms +step:1330/1680 train_time:119583ms step_avg:89.91ms +step:1331/1680 train_time:119675ms step_avg:89.91ms +step:1332/1680 train_time:119767ms step_avg:89.91ms +step:1333/1680 train_time:119858ms step_avg:89.92ms +step:1334/1680 train_time:119948ms step_avg:89.92ms +step:1335/1680 train_time:120039ms step_avg:89.92ms +step:1336/1680 train_time:120131ms step_avg:89.92ms +step:1337/1680 train_time:120222ms step_avg:89.92ms +step:1338/1680 train_time:120313ms step_avg:89.92ms +step:1339/1680 train_time:120403ms step_avg:89.92ms +step:1340/1680 train_time:120493ms step_avg:89.92ms +step:1341/1680 train_time:120584ms step_avg:89.92ms +step:1342/1680 train_time:120677ms step_avg:89.92ms +step:1343/1680 train_time:120767ms step_avg:89.92ms +step:1344/1680 train_time:120859ms step_avg:89.93ms +step:1345/1680 train_time:120951ms step_avg:89.93ms +step:1346/1680 train_time:121042ms step_avg:89.93ms +step:1347/1680 train_time:121132ms step_avg:89.93ms +step:1348/1680 train_time:121223ms step_avg:89.93ms +step:1349/1680 train_time:121314ms step_avg:89.93ms +step:1350/1680 train_time:121404ms step_avg:89.93ms +step:1351/1680 train_time:121495ms step_avg:89.93ms +step:1352/1680 train_time:121586ms step_avg:89.93ms +step:1353/1680 train_time:121678ms step_avg:89.93ms +step:1354/1680 train_time:121769ms step_avg:89.93ms +step:1355/1680 train_time:121860ms step_avg:89.93ms +step:1356/1680 train_time:121952ms step_avg:89.94ms +step:1357/1680 train_time:122042ms step_avg:89.94ms +step:1358/1680 train_time:122133ms step_avg:89.94ms +step:1359/1680 train_time:122223ms step_avg:89.94ms +step:1360/1680 train_time:122314ms step_avg:89.94ms +step:1361/1680 train_time:122404ms step_avg:89.94ms +step:1362/1680 train_time:122495ms step_avg:89.94ms +step:1363/1680 train_time:122586ms step_avg:89.94ms +step:1364/1680 train_time:122677ms step_avg:89.94ms +step:1365/1680 train_time:122768ms step_avg:89.94ms +step:1366/1680 train_time:122860ms step_avg:89.94ms +step:1367/1680 train_time:122952ms step_avg:89.94ms +step:1368/1680 train_time:123042ms step_avg:89.94ms +step:1369/1680 train_time:123132ms step_avg:89.94ms +step:1370/1680 train_time:123223ms step_avg:89.94ms +step:1371/1680 train_time:123313ms step_avg:89.94ms +step:1372/1680 train_time:123404ms step_avg:89.94ms +step:1373/1680 train_time:123495ms step_avg:89.95ms +step:1374/1680 train_time:123586ms step_avg:89.95ms +step:1375/1680 train_time:123677ms step_avg:89.95ms +step:1375/1680 val_loss:3.3409 train_time:123769ms step_avg:90.01ms +step:1376/1680 train_time:123792ms step_avg:89.97ms +step:1377/1680 train_time:123862ms step_avg:89.95ms +step:1378/1680 
train_time:123959ms step_avg:89.96ms +step:1379/1680 train_time:124050ms step_avg:89.96ms +step:1380/1680 train_time:124139ms step_avg:89.96ms +step:1381/1680 train_time:124229ms step_avg:89.96ms +step:1382/1680 train_time:124318ms step_avg:89.96ms +step:1383/1680 train_time:124408ms step_avg:89.95ms +step:1384/1680 train_time:124497ms step_avg:89.95ms +step:1385/1680 train_time:124587ms step_avg:89.95ms +step:1386/1680 train_time:124678ms step_avg:89.96ms +step:1387/1680 train_time:124770ms step_avg:89.96ms +step:1388/1680 train_time:124863ms step_avg:89.96ms +step:1389/1680 train_time:124956ms step_avg:89.96ms +step:1390/1680 train_time:125047ms step_avg:89.96ms +step:1391/1680 train_time:125138ms step_avg:89.96ms +step:1392/1680 train_time:125228ms step_avg:89.96ms +step:1393/1680 train_time:125318ms step_avg:89.96ms +step:1394/1680 train_time:125407ms step_avg:89.96ms +step:1395/1680 train_time:125497ms step_avg:89.96ms +step:1396/1680 train_time:125587ms step_avg:89.96ms +step:1397/1680 train_time:125678ms step_avg:89.96ms +step:1398/1680 train_time:125769ms step_avg:89.96ms +step:1399/1680 train_time:125861ms step_avg:89.96ms +step:1400/1680 train_time:125953ms step_avg:89.97ms +step:1401/1680 train_time:126045ms step_avg:89.97ms +step:1402/1680 train_time:126135ms step_avg:89.97ms +step:1403/1680 train_time:126226ms step_avg:89.97ms +step:1404/1680 train_time:126317ms step_avg:89.97ms +step:1405/1680 train_time:126407ms step_avg:89.97ms +step:1406/1680 train_time:126496ms step_avg:89.97ms +step:1407/1680 train_time:126586ms step_avg:89.97ms +step:1408/1680 train_time:126677ms step_avg:89.97ms +step:1409/1680 train_time:126769ms step_avg:89.97ms +step:1410/1680 train_time:126860ms step_avg:89.97ms +step:1411/1680 train_time:126951ms step_avg:89.97ms +step:1412/1680 train_time:127044ms step_avg:89.97ms +step:1413/1680 train_time:127136ms step_avg:89.98ms +step:1414/1680 train_time:127227ms step_avg:89.98ms +step:1415/1680 train_time:127318ms step_avg:89.98ms +step:1416/1680 train_time:127408ms step_avg:89.98ms +step:1417/1680 train_time:127498ms step_avg:89.98ms +step:1418/1680 train_time:127588ms step_avg:89.98ms +step:1419/1680 train_time:127678ms step_avg:89.98ms +step:1420/1680 train_time:127770ms step_avg:89.98ms +step:1421/1680 train_time:127861ms step_avg:89.98ms +step:1422/1680 train_time:127954ms step_avg:89.98ms +step:1423/1680 train_time:128044ms step_avg:89.98ms +step:1424/1680 train_time:128135ms step_avg:89.98ms +step:1425/1680 train_time:128227ms step_avg:89.98ms +step:1426/1680 train_time:128318ms step_avg:89.98ms +step:1427/1680 train_time:128408ms step_avg:89.98ms +step:1428/1680 train_time:128498ms step_avg:89.98ms +step:1429/1680 train_time:128589ms step_avg:89.99ms +step:1430/1680 train_time:128680ms step_avg:89.99ms +step:1431/1680 train_time:128770ms step_avg:89.99ms +step:1432/1680 train_time:128861ms step_avg:89.99ms +step:1433/1680 train_time:128952ms step_avg:89.99ms +step:1434/1680 train_time:129043ms step_avg:89.99ms +step:1435/1680 train_time:129134ms step_avg:89.99ms +step:1436/1680 train_time:129225ms step_avg:89.99ms +step:1437/1680 train_time:129316ms step_avg:89.99ms +step:1438/1680 train_time:129406ms step_avg:89.99ms +step:1439/1680 train_time:129497ms step_avg:89.99ms +step:1440/1680 train_time:129588ms step_avg:89.99ms +step:1441/1680 train_time:129678ms step_avg:89.99ms +step:1442/1680 train_time:129768ms step_avg:89.99ms +step:1443/1680 train_time:129859ms step_avg:89.99ms +step:1444/1680 train_time:129950ms step_avg:89.99ms +step:1445/1680 
train_time:130042ms step_avg:89.99ms +step:1446/1680 train_time:130133ms step_avg:89.99ms +step:1447/1680 train_time:130224ms step_avg:90.00ms +step:1448/1680 train_time:130315ms step_avg:90.00ms +step:1449/1680 train_time:130406ms step_avg:90.00ms +step:1450/1680 train_time:130496ms step_avg:90.00ms +step:1451/1680 train_time:130587ms step_avg:90.00ms +step:1452/1680 train_time:130676ms step_avg:90.00ms +step:1453/1680 train_time:130767ms step_avg:90.00ms +step:1454/1680 train_time:130858ms step_avg:90.00ms +step:1455/1680 train_time:130949ms step_avg:90.00ms +step:1456/1680 train_time:131040ms step_avg:90.00ms +step:1457/1680 train_time:131132ms step_avg:90.00ms +step:1458/1680 train_time:131223ms step_avg:90.00ms +step:1459/1680 train_time:131314ms step_avg:90.00ms +step:1460/1680 train_time:131406ms step_avg:90.00ms +step:1461/1680 train_time:131496ms step_avg:90.00ms +step:1462/1680 train_time:131587ms step_avg:90.00ms +step:1463/1680 train_time:131677ms step_avg:90.00ms +step:1464/1680 train_time:131768ms step_avg:90.01ms +step:1465/1680 train_time:131858ms step_avg:90.01ms +step:1466/1680 train_time:131948ms step_avg:90.01ms +step:1467/1680 train_time:132039ms step_avg:90.01ms +step:1468/1680 train_time:132130ms step_avg:90.01ms +step:1469/1680 train_time:132221ms step_avg:90.01ms +step:1470/1680 train_time:132311ms step_avg:90.01ms +step:1471/1680 train_time:132404ms step_avg:90.01ms +step:1472/1680 train_time:132495ms step_avg:90.01ms +step:1473/1680 train_time:132586ms step_avg:90.01ms +step:1474/1680 train_time:132677ms step_avg:90.01ms +step:1475/1680 train_time:132767ms step_avg:90.01ms +step:1476/1680 train_time:132858ms step_avg:90.01ms +step:1477/1680 train_time:132948ms step_avg:90.01ms +step:1478/1680 train_time:133039ms step_avg:90.01ms +step:1479/1680 train_time:133130ms step_avg:90.01ms +step:1480/1680 train_time:133222ms step_avg:90.02ms +step:1481/1680 train_time:133313ms step_avg:90.02ms +step:1482/1680 train_time:133406ms step_avg:90.02ms +step:1483/1680 train_time:133497ms step_avg:90.02ms +step:1484/1680 train_time:133588ms step_avg:90.02ms +step:1485/1680 train_time:133679ms step_avg:90.02ms +step:1486/1680 train_time:133768ms step_avg:90.02ms +step:1487/1680 train_time:133859ms step_avg:90.02ms +step:1488/1680 train_time:133950ms step_avg:90.02ms +step:1489/1680 train_time:134041ms step_avg:90.02ms +step:1490/1680 train_time:134136ms step_avg:90.02ms +step:1491/1680 train_time:134222ms step_avg:90.02ms +step:1492/1680 train_time:134313ms step_avg:90.02ms +step:1493/1680 train_time:134405ms step_avg:90.02ms +step:1494/1680 train_time:134495ms step_avg:90.02ms +step:1495/1680 train_time:134587ms step_avg:90.02ms +step:1496/1680 train_time:134678ms step_avg:90.03ms +step:1497/1680 train_time:134768ms step_avg:90.03ms +step:1498/1680 train_time:134858ms step_avg:90.03ms +step:1499/1680 train_time:134949ms step_avg:90.03ms +step:1500/1680 train_time:135039ms step_avg:90.03ms +step:1500/1680 val_loss:3.3110 train_time:135131ms step_avg:90.09ms +step:1501/1680 train_time:135154ms step_avg:90.04ms +step:1502/1680 train_time:135227ms step_avg:90.03ms +step:1503/1680 train_time:135325ms step_avg:90.04ms +step:1504/1680 train_time:135415ms step_avg:90.04ms +step:1505/1680 train_time:135505ms step_avg:90.04ms +step:1506/1680 train_time:135595ms step_avg:90.04ms +step:1507/1680 train_time:135684ms step_avg:90.04ms +step:1508/1680 train_time:135775ms step_avg:90.04ms +step:1509/1680 train_time:135864ms step_avg:90.04ms +step:1510/1680 train_time:135954ms step_avg:90.04ms 
+step:1511/1680 train_time:136044ms step_avg:90.04ms +step:1512/1680 train_time:136137ms step_avg:90.04ms +step:1513/1680 train_time:136231ms step_avg:90.04ms +step:1514/1680 train_time:136325ms step_avg:90.04ms +step:1515/1680 train_time:136417ms step_avg:90.04ms +step:1516/1680 train_time:136507ms step_avg:90.04ms +step:1517/1680 train_time:136598ms step_avg:90.04ms +step:1518/1680 train_time:136688ms step_avg:90.04ms +step:1519/1680 train_time:136778ms step_avg:90.04ms +step:1520/1680 train_time:136867ms step_avg:90.04ms +step:1521/1680 train_time:136957ms step_avg:90.04ms +step:1522/1680 train_time:137048ms step_avg:90.04ms +step:1523/1680 train_time:137140ms step_avg:90.05ms +step:1524/1680 train_time:137232ms step_avg:90.05ms +step:1525/1680 train_time:137323ms step_avg:90.05ms +step:1526/1680 train_time:137414ms step_avg:90.05ms +step:1527/1680 train_time:137505ms step_avg:90.05ms +step:1528/1680 train_time:137595ms step_avg:90.05ms +step:1529/1680 train_time:137685ms step_avg:90.05ms +step:1530/1680 train_time:137775ms step_avg:90.05ms +step:1531/1680 train_time:137865ms step_avg:90.05ms +step:1532/1680 train_time:137956ms step_avg:90.05ms +step:1533/1680 train_time:138047ms step_avg:90.05ms +step:1534/1680 train_time:138138ms step_avg:90.05ms +step:1535/1680 train_time:138230ms step_avg:90.05ms +step:1536/1680 train_time:138321ms step_avg:90.05ms +step:1537/1680 train_time:138412ms step_avg:90.05ms +step:1538/1680 train_time:138503ms step_avg:90.05ms +step:1539/1680 train_time:138594ms step_avg:90.05ms +step:1540/1680 train_time:138684ms step_avg:90.05ms +step:1541/1680 train_time:138774ms step_avg:90.05ms +step:1542/1680 train_time:138864ms step_avg:90.05ms +step:1543/1680 train_time:138954ms step_avg:90.05ms +step:1544/1680 train_time:139045ms step_avg:90.06ms +step:1545/1680 train_time:139137ms step_avg:90.06ms +step:1546/1680 train_time:139228ms step_avg:90.06ms +step:1547/1680 train_time:139320ms step_avg:90.06ms +step:1548/1680 train_time:139411ms step_avg:90.06ms +step:1549/1680 train_time:139502ms step_avg:90.06ms +step:1550/1680 train_time:139593ms step_avg:90.06ms +step:1551/1680 train_time:139683ms step_avg:90.06ms +step:1552/1680 train_time:139773ms step_avg:90.06ms +step:1553/1680 train_time:139863ms step_avg:90.06ms +step:1554/1680 train_time:139954ms step_avg:90.06ms +step:1555/1680 train_time:140046ms step_avg:90.06ms +step:1556/1680 train_time:140136ms step_avg:90.06ms +step:1557/1680 train_time:140227ms step_avg:90.06ms +step:1558/1680 train_time:140320ms step_avg:90.06ms +step:1559/1680 train_time:140411ms step_avg:90.06ms +step:1560/1680 train_time:140502ms step_avg:90.07ms +step:1561/1680 train_time:140593ms step_avg:90.07ms +step:1562/1680 train_time:140683ms step_avg:90.07ms +step:1563/1680 train_time:140773ms step_avg:90.07ms +step:1564/1680 train_time:140864ms step_avg:90.07ms +step:1565/1680 train_time:140954ms step_avg:90.07ms +step:1566/1680 train_time:141045ms step_avg:90.07ms +step:1567/1680 train_time:141135ms step_avg:90.07ms +step:1568/1680 train_time:141226ms step_avg:90.07ms +step:1569/1680 train_time:141319ms step_avg:90.07ms +step:1570/1680 train_time:141410ms step_avg:90.07ms +step:1571/1680 train_time:141501ms step_avg:90.07ms +step:1572/1680 train_time:141591ms step_avg:90.07ms +step:1573/1680 train_time:141682ms step_avg:90.07ms +step:1574/1680 train_time:141772ms step_avg:90.07ms +step:1575/1680 train_time:141864ms step_avg:90.07ms +step:1576/1680 train_time:141954ms step_avg:90.07ms +step:1577/1680 train_time:142045ms step_avg:90.07ms 
+step:1578/1680 train_time:142136ms step_avg:90.07ms
+step:1579/1680 train_time:142227ms step_avg:90.07ms
+step:1580/1680 train_time:142319ms step_avg:90.08ms
+step:1581/1680 train_time:142410ms step_avg:90.08ms
+step:1582/1680 train_time:142501ms step_avg:90.08ms
+step:1583/1680 train_time:142592ms step_avg:90.08ms
+step:1584/1680 train_time:142683ms step_avg:90.08ms
+step:1585/1680 train_time:142773ms step_avg:90.08ms
+step:1586/1680 train_time:142863ms step_avg:90.08ms
+step:1587/1680 train_time:142954ms step_avg:90.08ms
+step:1588/1680 train_time:143045ms step_avg:90.08ms
+step:1589/1680 train_time:143136ms step_avg:90.08ms
+step:1590/1680 train_time:143226ms step_avg:90.08ms
+step:1591/1680 train_time:143317ms step_avg:90.08ms
+step:1592/1680 train_time:143408ms step_avg:90.08ms
+step:1593/1680 train_time:143500ms step_avg:90.08ms
+step:1594/1680 train_time:143592ms step_avg:90.08ms
+step:1595/1680 train_time:143683ms step_avg:90.08ms
+step:1596/1680 train_time:143773ms step_avg:90.08ms
+step:1597/1680 train_time:143864ms step_avg:90.08ms
+step:1598/1680 train_time:143955ms step_avg:90.08ms
+step:1599/1680 train_time:144046ms step_avg:90.09ms
+step:1600/1680 train_time:144136ms step_avg:90.09ms
+step:1601/1680 train_time:144227ms step_avg:90.09ms
+step:1602/1680 train_time:144317ms step_avg:90.09ms
+step:1603/1680 train_time:144408ms step_avg:90.09ms
+step:1604/1680 train_time:144500ms step_avg:90.09ms
+step:1605/1680 train_time:144592ms step_avg:90.09ms
+step:1606/1680 train_time:144683ms step_avg:90.09ms
+step:1607/1680 train_time:144774ms step_avg:90.09ms
+step:1608/1680 train_time:144864ms step_avg:90.09ms
+step:1609/1680 train_time:144955ms step_avg:90.09ms
+step:1610/1680 train_time:145048ms step_avg:90.09ms
+step:1611/1680 train_time:145139ms step_avg:90.09ms
+step:1612/1680 train_time:145230ms step_avg:90.09ms
+step:1613/1680 train_time:145321ms step_avg:90.09ms
+step:1614/1680 train_time:145411ms step_avg:90.09ms
+step:1615/1680 train_time:145502ms step_avg:90.09ms
+step:1616/1680 train_time:145594ms step_avg:90.10ms
+step:1617/1680 train_time:145685ms step_avg:90.10ms
+step:1618/1680 train_time:145778ms step_avg:90.10ms
+step:1619/1680 train_time:145869ms step_avg:90.10ms
+step:1620/1680 train_time:145959ms step_avg:90.10ms
+step:1621/1680 train_time:146051ms step_avg:90.10ms
+step:1622/1680 train_time:146143ms step_avg:90.10ms
+step:1623/1680 train_time:146234ms step_avg:90.10ms
+step:1624/1680 train_time:146324ms step_avg:90.10ms
+step:1625/1680 train_time:146414ms step_avg:90.10ms
+step:1625/1680 val_loss:3.2870 train_time:146506ms step_avg:90.16ms
+step:1626/1680 train_time:146529ms step_avg:90.12ms
+step:1627/1680 train_time:146600ms step_avg:90.10ms
+step:1628/1680 train_time:146699ms step_avg:90.11ms
+step:1629/1680 train_time:146791ms step_avg:90.11ms
+step:1630/1680 train_time:146882ms step_avg:90.11ms
+step:1631/1680 train_time:146972ms step_avg:90.11ms
+step:1632/1680 train_time:147061ms step_avg:90.11ms
+step:1633/1680 train_time:147151ms step_avg:90.11ms
+step:1634/1680 train_time:147241ms step_avg:90.11ms
+step:1635/1680 train_time:147332ms step_avg:90.11ms
+step:1636/1680 train_time:147420ms step_avg:90.11ms
+step:1637/1680 train_time:147514ms step_avg:90.11ms
+step:1638/1680 train_time:147609ms step_avg:90.12ms
+step:1639/1680 train_time:147701ms step_avg:90.12ms
+step:1640/1680 train_time:147793ms step_avg:90.12ms
+step:1641/1680 train_time:147886ms step_avg:90.12ms
+step:1642/1680 train_time:147975ms step_avg:90.12ms
+step:1643/1680 train_time:148065ms step_avg:90.12ms
+step:1644/1680 train_time:148155ms step_avg:90.12ms
+step:1645/1680 train_time:148244ms step_avg:90.12ms
+step:1646/1680 train_time:148335ms step_avg:90.12ms
+step:1647/1680 train_time:148425ms step_avg:90.12ms
+step:1648/1680 train_time:148516ms step_avg:90.12ms
+step:1649/1680 train_time:148608ms step_avg:90.12ms
+step:1650/1680 train_time:148700ms step_avg:90.12ms
+step:1651/1680 train_time:148792ms step_avg:90.12ms
+step:1652/1680 train_time:148884ms step_avg:90.12ms
+step:1653/1680 train_time:148974ms step_avg:90.12ms
+step:1654/1680 train_time:149064ms step_avg:90.12ms
+step:1655/1680 train_time:149154ms step_avg:90.12ms
+step:1656/1680 train_time:149244ms step_avg:90.12ms
+step:1657/1680 train_time:149335ms step_avg:90.12ms
+step:1658/1680 train_time:149425ms step_avg:90.12ms
+step:1659/1680 train_time:149517ms step_avg:90.12ms
+step:1660/1680 train_time:149608ms step_avg:90.13ms
+step:1661/1680 train_time:149700ms step_avg:90.13ms
+step:1662/1680 train_time:149792ms step_avg:90.13ms
+step:1663/1680 train_time:149884ms step_avg:90.13ms
+step:1664/1680 train_time:149975ms step_avg:90.13ms
+step:1665/1680 train_time:150065ms step_avg:90.13ms
+step:1666/1680 train_time:150155ms step_avg:90.13ms
+step:1667/1680 train_time:150245ms step_avg:90.13ms
+step:1668/1680 train_time:150335ms step_avg:90.13ms
+step:1669/1680 train_time:150425ms step_avg:90.13ms
+step:1670/1680 train_time:150515ms step_avg:90.13ms
+step:1671/1680 train_time:150607ms step_avg:90.13ms
+step:1672/1680 train_time:150698ms step_avg:90.13ms
+step:1673/1680 train_time:150789ms step_avg:90.13ms
+step:1674/1680 train_time:150880ms step_avg:90.13ms
+step:1675/1680 train_time:150972ms step_avg:90.13ms
+step:1676/1680 train_time:151063ms step_avg:90.13ms
+step:1677/1680 train_time:151154ms step_avg:90.13ms
+step:1678/1680 train_time:151244ms step_avg:90.13ms
+step:1679/1680 train_time:151334ms step_avg:90.13ms
+step:1680/1680 train_time:151424ms step_avg:90.13ms
+step:1680/1680 val_loss:3.2762 train_time:151516ms step_avg:90.19ms
+peak memory allocated: 31255 MiB reserved: 46514 MiB
diff --git a/train_gpt.py b/train_gpt.py
index a661d63d3..95e7f47ac 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -493,18 +493,6 @@ def step(self):
                     * getattr(p_example, "wd_mul", 1.0)
                 )
 
-            # Determine effective LR and WD once per group, assuming constant for same-shaped params.
-            eff_lr_val = (
-                group["lr"]
-                * max(1, params[0].size(-2) / params[0].size(-1)) ** 0.5
-                * getattr(params[0], "lr_mul", 1.0)
-            )
-            eff_weight_decay_val = (
-                group["lr"]
-                * group["weight_decay"]
-                * getattr(params[0], "wd_mul", 1.0)
-            )
-
             # Prepare a contiguous buffer for the updated parameters for this rank's chunk.
             # This buffer will serve as the input_tensor for dist.all_gather_into_tensor.
             updated_param_chunk = torch.empty(
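
The hunk above deletes the per-group effective-LR/weight-decay computation; judging from the surviving context, the same quantities are now derived per parameter from its shape and its optional lr_mul/wd_mul attributes. A minimal sketch of that scaling rule, using hypothetical helper names rather than code from this patch:

    def effective_lr(base_lr, p):
        # tall matrices (rows > cols) get a sqrt(rows / cols) boost, floored at 1
        return base_lr * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0)

    def effective_weight_decay(base_lr, weight_decay, p):
        # weight decay is multiplied by the LR, so it shrinks with the LR schedule
        return base_lr * weight_decay * getattr(p, "wd_mul", 1.0)
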
@@ -849,7 +837,7 @@ class Block(nn.Module):
     def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
         super().__init__()
         # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
-        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx != 7 else None
+        self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None
         # skip MLP blocks for first MLP layer by @EmelyanenkoK
         self.mlp = MLP(dim) if layer_idx != 0 else None
@@ -911,17 +899,17 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: i
         self.lm_head.weight.lr_mul = 1.0
         self.scalars.lr_mul = 5.0
 
-    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int, ws_final_layer: int):
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int):
         assert input_seq.ndim == 1
 
         ve = [value_embed(input_seq) for value_embed in self.value_embeds]
         # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
-        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
+        ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
         assert len(ve) == len(self.blocks)
 
-        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
-        final_bm = ws_final_layer * args.block_size
-        bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, final_bm]
+        short_bm = ws_short * args.block_size
+        long_bm = ws_long * args.block_size
+        bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm]
         assert len(bm_sizes) == len(self.blocks)
 
         x = self.embed(input_seq)
@@ -940,7 +928,8 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in
         n = len(self.blocks) // 2
 
-        for i in range(len(self.blocks)):
+        # skip layer zero
+        for i in range(1,len(self.blocks)):
             attn_args = AttnArgs(
                 ve=ve[i],
                 sa_lambdas=sa_lambdas[i],
@@ -950,7 +939,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: in
                 sin=self.yarn.sin,
                 attn_scale=self.yarn.attn_scale
             )
-            if i >= n:
+            if i >= n and i<11:
                 gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1)
                 x = x + gate * skip_connections.pop()
             x = self.blocks[i](x, x0, lambdas[i], attn_args)
@@ -1153,7 +1142,8 @@ class Hyperparameters:
     train_max_seq_len: int = 128 * 16
     val_batch_size: int = 4 * 64 * 1024 * 8
     # optimization
-    num_iterations: int = 1645 # number of iterations to run
+    num_iterations: int = 1640 # number of iterations to run
+    iteration_extension = 40 # number of iterations to continue training at final cooldown and window size
     cooldown_frac: int = 0.5 # fraction of training spent cooling down the learning rate
     # evaluation and logging
     run_id: str = f"{uuid.uuid4()}"
@@ -1162,8 +1152,8 @@
     # attention masking
     block_size: int = 128
     ws_schedule: tuple = (3, 7, 11)
-    ws_validate: int = 13 # increase final validation ws @classiclarryd
-    ws_validate_final_layer: int = 20 # final layer shows no degradation with context length
+    ws_validate: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd
+    ws_long_validate: int = 20 # extend long windows out even further
 
 args = Hyperparameters()
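
For reference, the forward hunk above maps the two window sizes onto per-layer block-mask sizes, with None marking the two layers whose attention is skipped entirely. A small sketch with assumed example values (the get_ws hunk further below pairs each long window ws with a short window ws // 2):

    block_size = 128
    ws_short, ws_long = 5, 11         # example values: ws_schedule[-1] // 2 and ws_schedule[-1]
    short_bm = ws_short * block_size  # 640-token sliding window
    long_bm = ws_long * block_size    # 1408-token sliding window
    bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm,
                short_bm, None, short_bm, short_bm, short_bm, long_bm]
    # layers 0 and 7 have no attention; only layers 4 and 11 attend over the long window
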
@@ -1250,7 +1240,7 @@ def nvidia_smi():
 
 # learning rate schedule: stable then decay
 def get_lr(step: int):
-    x = step / args.num_iterations
+    x = min(0.9999,step / args.num_iterations)
     assert 0 <= x < 1
     lr = 1.0
     if x >= 1 - args.cooldown_frac:
@@ -1259,12 +1249,12 @@ def get_ws(step: int):
-    if step == args.num_iterations:
-        return args.ws_validate, args.ws_validate_final_layer
-    x = step / (1 + args.num_iterations)
+    if step == args.num_iterations+args.iteration_extension:
+        return args.ws_validate//2, args.ws_validate
+    x = min(step / (1 + args.num_iterations),0.9999)
     assert 0 <= x < 1
     ws_idx = int(len(args.ws_schedule) * x)
-    return args.ws_schedule[ws_idx], args.ws_schedule[ws_idx]
+    return args.ws_schedule[ws_idx]//2, args.ws_schedule[ws_idx]
 
 model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
@@ -1277,17 +1267,17 @@ def get_ws(step: int):
 initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
 train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
-ws=args.ws_schedule[0]
+ws_long = args.ws_schedule[0]
 for step in range(warmup_steps):
     inputs, targets, cum_seqlens = next(train_loader)
-    new_ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each with YaRN params
-    if new_ws > ws:
-        model.yarn.apply(ws, new_ws)
-        ws = new_ws
-    elif new_ws ws_long:
+        model.yarn.apply(ws_long, new_ws_long)
+        ws_long = new_ws_long
+    elif new_ws_long 0 and step % args.val_loss_every == 0):
+    if last_step:
+        ws_long = args.ws_long_validate
     # stop the clock
     torch.cuda.synchronize()
     training_time_ms += 1000 * (time.perf_counter() - t0)
@@ -1329,7 +1321,7 @@ def get_ws(step: int):
     with torch.no_grad():
         for _ in range(val_steps):
             inputs, targets, cum_seqlens = next(val_loader)
-            val_loss += model(inputs, targets, cum_seqlens, ws, ws_final_layer)
+            val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long)
     val_loss /= val_steps
     del val_loader
     dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
@@ -1350,7 +1342,7 @@ def get_ws(step: int):
     # --------------- TRAINING SECTION -----------------
     for _ in range(grad_accum_steps):
         inputs, targets, cum_seqlens = next(train_loader)
-        model(inputs, targets, cum_seqlens, ws, ws_final_layer).backward()
+        model(inputs, targets, cum_seqlens, ws_short, ws_long).backward()
     # set optimization hyperparameters
     for opt in optimizers:
         for group in opt.param_groups:
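
Taken together, get_lr and get_ws now tolerate steps past num_iterations: the min(0.9999, ...) clamps keep the 0 <= x < 1 asserts satisfied during the iteration_extension steps, which keep training at the fully cooled-down LR and the final window sizes, while the very last step switches to the validation windows (the loop separately bumps the long window to ws_long_validate = 20 at last_step). A self-contained sketch of the two schedules; names and defaults mirror the diff, and the linear-cooldown body of get_lr is an assumption since the hunk cuts off at the if:

    num_iterations, iteration_extension = 1640, 40
    cooldown_frac = 0.5
    ws_schedule = (3, 7, 11)
    ws_validate = 13

    def get_lr(step: int) -> float:
        x = min(0.9999, step / num_iterations)  # clamp so extension steps stay in range
        assert 0 <= x < 1
        return 1.0 if x < 1 - cooldown_frac else (1 - x) / cooldown_frac  # assumed linear cooldown

    def get_ws(step: int) -> tuple[int, int]:  # (short window, long window)
        if step == num_iterations + iteration_extension:
            return ws_validate // 2, ws_validate
        x = min(step / (1 + num_iterations), 0.9999)
        ws = ws_schedule[int(len(ws_schedule) * x)]
        return ws // 2, ws

    for s in (0, 1000, 1639, 1660, 1680):
        print(s, round(get_lr(s), 4), get_ws(s))
    # e.g. step 1660 -> lr 0.0002, windows (5, 11); step 1680 -> windows (6, 13)
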